From 293bdc8e72ae73f431383c77b9f8579a78aad9b7 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Sep 2018 10:54:42 +0800 Subject: [PATCH 001/843] support url.Values in request param --- req.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/req.go b/req.go index f8a9ae73..decab412 100644 --- a/req.go +++ b/req.go @@ -214,6 +214,13 @@ func (r *Req) Do(method, rawurl string, vs ...interface{}) (resp *Resp, err erro return nil, err } delayedFunc = append(delayedFunc, fn) + case url.Values: + p := param{vv} + if method == "GET" || method == "HEAD" { + queryParam.Copy(p) + } else { + formParam.Copy(p) + } case Param: if method == "GET" || method == "HEAD" { queryParam.Adds(vv) From bc0679ab8b32c507810efbc37122451c03eaf203 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Sep 2018 21:15:33 +0800 Subject: [PATCH 002/843] support *os.File in request param --- req.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/req.go b/req.go index decab412..ae6a11e5 100644 --- a/req.go +++ b/req.go @@ -429,9 +429,17 @@ func setContentType(req *http.Request, contentType string) { func setBodyReader(req *http.Request, resp *Resp, rd io.Reader) func() { var rc io.ReadCloser - if trc, ok := rd.(io.ReadCloser); ok { - rc = trc - } else { + switch r := rd.(type) { + case *os.File: + stat, err := r.Stat() + if err == nil { + req.ContentLength = stat.Size() + } + rc = r + + case io.ReadCloser: + rc = r + default: rc = ioutil.NopCloser(rd) } bw := &bodyWrapper{ From fc61e19f4f80caf381411d80d906bbfb1aceaba7 Mon Sep 17 00:00:00 2001 From: Salvador Guzman Date: Tue, 30 Oct 2018 03:04:08 -0700 Subject: [PATCH 003/843] fixed small typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index bdc77fc9..c41d5ae0 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ There is a default `Req` object, all of its' public methods are wrapped by the r := req.New() r.Get(url) -// use req package to initiate reqeust. +// use req package to initiate request. req.Get(url) ``` You can use `req.New()` to create lots of `*Req` as client with independent configuration From 38b12ac77d21d816f2ca9ab42e771dd3a886d2cc Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 23 Nov 2018 11:02:36 +0800 Subject: [PATCH 004/843] add support for time cost --- dump.go | 14 +++++++++++++- req.go | 10 +++++++++- resp.go | 14 ++++++++++++-- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/dump.go b/dump.go index 6c17cdaf..ce6d3a5b 100644 --- a/dump.go +++ b/dump.go @@ -3,6 +3,7 @@ package req import ( "bufio" "bytes" + "fmt" "io" "io/ioutil" "net" @@ -84,8 +85,10 @@ type dumpBuffer struct { } func (b *dumpBuffer) Write(p []byte) { + if b.Len() > 0 { + b.Buffer.WriteString("\r\n\r\n") + } b.Buffer.Write(p) - b.Buffer.WriteString("\r\n\r\n") } func (b *dumpBuffer) WriteString(s string) { @@ -189,8 +192,17 @@ func (r *Resp) dumpResponse(dump *dumpBuffer) { } } +// Cost return the time cost of the request +func (r *Resp) Cost() time.Duration { + return r.cost +} + +// Dump dump the request func (r *Resp) Dump() string { dump := new(dumpBuffer) + if r.r.flag&Lcost != 0 { + dump.WriteString(fmt.Sprint(r.cost)) + } r.dumpRequest(dump) l := dump.Len() if l > 0 { diff --git a/req.go b/req.go index ae6a11e5..760d79d5 100644 --- a/req.go +++ b/req.go @@ -320,7 +320,15 @@ func (r *Req) Do(method, rawurl string, vs ...interface{}) (resp *Resp, err erro resp.client = r.Client() } - response, err := resp.client.Do(req) + var response *http.Response + if r.flag&Lcost != 0 { + before := time.Now() + response, err = resp.client.Do(req) + after := time.Now() + resp.cost = after.Sub(before) + } else { + response, err = resp.client.Do(req) + } if err != nil { return nil, err } diff --git a/resp.go b/resp.go index 03ed679e..eb56b1bd 100644 --- a/resp.go +++ b/resp.go @@ -18,6 +18,7 @@ type Resp struct { req *http.Request resp *http.Response client *http.Client + cost time.Duration *multipartHelper reqBody []byte respBody []byte @@ -149,7 +150,11 @@ var regNewline = regexp.MustCompile(`\n|\r`) func (r *Resp) autoFormat(s fmt.State) { req := r.req - fmt.Fprint(s, req.Method, " ", req.URL.String()) + if r.r.flag&Lcost != 0 { + fmt.Fprint(s, req.Method, " ", req.URL.String(), " ", r.cost) + } else { + fmt.Fprint(s, req.Method, " ", req.URL.String()) + } // test if it is should be outputed pretty var pretty bool @@ -180,7 +185,11 @@ func (r *Resp) autoFormat(s fmt.State) { func (r *Resp) miniFormat(s fmt.State) { req := r.req - fmt.Fprint(s, req.Method, " ", req.URL.String()) + if r.r.flag&Lcost != 0 { + fmt.Fprint(s, req.Method, " ", req.URL.String(), " ", r.cost) + } else { + fmt.Fprint(s, req.Method, " ", req.URL.String()) + } if r.r.flag&LreqBody != 0 && len(r.reqBody) > 0 { // request body str := regNewline.ReplaceAllString(string(r.reqBody), " ") fmt.Fprint(s, " ", str) @@ -191,6 +200,7 @@ func (r *Resp) miniFormat(s fmt.State) { } } +// Format fort the response func (r *Resp) Format(s fmt.State, verb rune) { if r == nil || r.req == nil { return From b355b6f6842fbda1bee879f2846cb79424cad4de Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 25 Nov 2018 18:04:31 +0800 Subject: [PATCH 005/843] support context.Context --- req.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/req.go b/req.go index 760d79d5..1e782e1f 100644 --- a/req.go +++ b/req.go @@ -3,6 +3,7 @@ package req import ( "bytes" "compress/gzip" + "context" "encoding/json" "encoding/xml" "errors" @@ -254,6 +255,8 @@ func (r *Req) Do(method, rawurl string, vs ...interface{}) (resp *Resp, err erro resp.downloadProgress = vv case func(int64, int64): progress = vv + case context.Context: + req = req.WithContext(vv) case error: return nil, vv } From 936e820f1d37ce004ad49bfef925079558f84fcf Mon Sep 17 00:00:00 2001 From: Karl Gustav Date: Thu, 16 May 2019 10:29:02 +0200 Subject: [PATCH 006/843] Fix one char spelling mistake smae as above --> same as above --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c41d5ae0..ede817cb 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,7 @@ Output in simple way (default format) ``` go r, _ := req.Get(url, param) log.Printf("%v\n", r) // GET http://foo.bar/api?name=roc&cmd=add {"code":"0","msg":"success"} -log.Prinln(r) // smae as above +log.Prinln(r) // same as above ``` ### `%-v` or `%-s` From b713f8758bce73e0840dc34e304b8ad2911c62fc Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 20 Jun 2019 10:14:22 +0800 Subject: [PATCH 007/843] fix nil pointer when pass ctx --- req.go | 1 + 1 file changed, 1 insertion(+) diff --git a/req.go b/req.go index 1e782e1f..d1b3e712 100644 --- a/req.go +++ b/req.go @@ -257,6 +257,7 @@ func (r *Req) Do(method, rawurl string, vs ...interface{}) (resp *Resp, err erro progress = vv case context.Context: req = req.WithContext(vv) + resp.req = req case error: return nil, vv } From 7bd39462dbc460fa5950ce089e10cb96bec92f28 Mon Sep 17 00:00:00 2001 From: dvornikov Date: Wed, 14 Aug 2019 12:11:47 +0300 Subject: [PATCH 008/843] Add example with context.Context --- README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ede817cb..d6560087 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,8 @@ Examples [Cookie](#Cookie) [Set Timeout](#Set-Timeout) [Set Proxy](#Set-Proxy) -[Customize Client](#Customize-Client) +[Customize Client](#Customize-Client) +[Set context.Context](#Context) ## Basic ``` go @@ -281,6 +282,12 @@ Set a simple proxy (use fixed proxy url for every request) req.SetProxyUrl("http://my.proxy.com:23456") ``` +## Set context.Context +You can pass context.Context in simple way: +```go +r, _ := req.Get(url, context.Background()) +``` + ## Customize Client Use `SetClient` to change the default underlying `*http.Client` ``` go From ba3ab259576399d39cf55090d3ebf99d93ed734c Mon Sep 17 00:00:00 2001 From: Maddie Zhan Date: Mon, 26 Aug 2019 01:30:21 +0800 Subject: [PATCH 009/843] Report progress on upload and download finish --- req.go | 5 +++++ resp.go | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/req.go b/req.go index d1b3e712..e3da406e 100644 --- a/req.go +++ b/req.go @@ -514,6 +514,11 @@ func (m *multipartHelper) Upload(req *http.Request) { var current int64 buf := make([]byte, 1024) var lastTime time.Time + + defer func() { + m.uploadProgress(current, total) + }() + upload = func(w io.Writer, r io.Reader) error { for { n, err := r.Read(buf) diff --git a/resp.go b/resp.go index eb56b1bd..a7658253 100644 --- a/resp.go +++ b/resp.go @@ -124,6 +124,11 @@ func (r *Resp) download(file *os.File) error { total := r.resp.ContentLength var current int64 var lastTime time.Time + + defer func() { + r.downloadProgress(current, total) + }() + for { l, err := b.Read(p) if l > 0 { From d0ec57625b503541723cebaa657b5aee6182bf98 Mon Sep 17 00:00:00 2001 From: Maddie Zhan Date: Mon, 26 Aug 2019 15:18:57 +0800 Subject: [PATCH 010/843] Allow customization of progress reporting interval --- req.go | 30 +++++++++++++++++------------- resp.go | 2 +- setting.go | 12 ++++++++++++ 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/req.go b/req.go index d1b3e712..645c5c33 100644 --- a/req.go +++ b/req.go @@ -121,15 +121,17 @@ func BodyXML(v interface{}) *bodyXml { // Req is a convenient client for initiating requests type Req struct { - client *http.Client - jsonEncOpts *jsonEncOpts - xmlEncOpts *xmlEncOpts - flag int + client *http.Client + jsonEncOpts *jsonEncOpts + xmlEncOpts *xmlEncOpts + flag int + progressInterval time.Duration } // New create a new *Req func New() *Req { - return &Req{flag: LstdFlags} + // default progress reporting interval is 200 milliseconds + return &Req{flag: LstdFlags, progressInterval: 200 * time.Millisecond} } type param struct { @@ -277,9 +279,10 @@ func (r *Req) Do(method, rawurl string, vs ...interface{}) (resp *Resp, err erro up = UploadProgress(progress) } multipartHelper := &multipartHelper{ - form: formParam.Values, - uploads: uploads, - uploadProgress: up, + form: formParam.Values, + uploads: uploads, + uploadProgress: up, + progressInterval: resp.r.progressInterval, } multipartHelper.Upload(req) resp.multipartHelper = multipartHelper @@ -484,10 +487,11 @@ func (b *bodyWrapper) Read(p []byte) (n int, err error) { } type multipartHelper struct { - form url.Values - uploads []FileUpload - dump []byte - uploadProgress UploadProgress + form url.Values + uploads []FileUpload + dump []byte + uploadProgress UploadProgress + progressInterval time.Duration } func (m *multipartHelper) Upload(req *http.Request) { @@ -523,7 +527,7 @@ func (m *multipartHelper) Upload(req *http.Request) { return _err } current += int64(n) - if now := time.Now(); now.Sub(lastTime) > 200*time.Millisecond { + if now := time.Now(); now.Sub(lastTime) > m.progressInterval { lastTime = now m.uploadProgress(current, total) } diff --git a/resp.go b/resp.go index eb56b1bd..87ad51e6 100644 --- a/resp.go +++ b/resp.go @@ -132,7 +132,7 @@ func (r *Resp) download(file *os.File) error { return _err } current += int64(l) - if now := time.Now(); now.Sub(lastTime) > 200*time.Millisecond { + if now := time.Now(); now.Sub(lastTime) > r.r.progressInterval { lastTime = now r.downloadProgress(current, total) } diff --git a/setting.go b/setting.go index 74235f37..ee771e07 100644 --- a/setting.go +++ b/setting.go @@ -234,3 +234,15 @@ func (r *Req) SetXMLIndent(prefix, indent string) { func SetXMLIndent(prefix, indent string) { std.SetXMLIndent(prefix, indent) } + +// SetProgressInterval sets the progress reporting interval of both +// UploadProgress and DownloadProgress handler +func (r *Req) SetProgressInterval(interval time.Duration) { + r.progressInterval = interval +} + +// SetProgressInterval sets the progress reporting interval of both +// UploadProgress and DownloadProgress handler for the default client +func SetProgressInterval(interval time.Duration) { + std.SetProgressInterval(interval) +} From 98209d1daaa4dfc9818676d210559fa3b9a4185a Mon Sep 17 00:00:00 2001 From: Aether Date: Thu, 19 Sep 2019 12:05:28 +0800 Subject: [PATCH 011/843] init go mod --- go.mod | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 go.mod diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..301337f2 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/aetherwu/req + +go 1.12 From def36f2fd4ef2482c35c8ec4d169a532806d0ed8 Mon Sep 17 00:00:00 2001 From: Aether Date: Thu, 19 Sep 2019 12:19:51 +0800 Subject: [PATCH 012/843] fix repo path for go mod --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 301337f2..433bcc07 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/aetherwu/req +module github.com/imroc/req go 1.12 From d25f6f07d264fa4caca6d8cb85e5206253b5c348 Mon Sep 17 00:00:00 2001 From: Jirawat Harnsiriwatanakit Date: Sat, 25 Jan 2020 22:41:12 +0700 Subject: [PATCH 013/843] add function to parse for header --- req.go | 10 ++++++++++ req_test.go | 20 ++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/req.go b/req.go index b1f3487a..25e314da 100644 --- a/req.go +++ b/req.go @@ -48,6 +48,16 @@ func (h Header) Clone() Header { return hh } +func ParseStruct(h Header, v interface{}) Header { + data, err := json.Marshal(v) + if err != nil { + return h + } + + err = json.Unmarshal(data, &h) + return h +} + // Param represents http request param type Param map[string]interface{} diff --git a/req_test.go b/req_test.go index a863941a..34fca5ec 100644 --- a/req_test.go +++ b/req_test.go @@ -267,6 +267,26 @@ func TestHeader(t *testing.T) { } } +func TestParseStruct(t *testing.T) { + + type HeaderStruct struct { + UserAgent string `json:"User-Agent"` + Authorization string `json:"Authorization"` + } + + h := HeaderStruct{ + "V1.0.0", + "roc", + } + + var header Header + header = ParseStruct(header, h) + + if header["User-Agent"] != "V1.0.0" && header["Authorization"] != "roc" { + t.Fatal("struct parser for header is not working") + } +} + func TestUpload(t *testing.T) { str := "hello req" file := ioutil.NopCloser(strings.NewReader(str)) From ae0ae39b73f6ddbd54e6fc147b0fc6ca803f0da5 Mon Sep 17 00:00:00 2001 From: Jirawat Harnsiriwatanakit Date: Sat, 25 Jan 2020 22:46:22 +0700 Subject: [PATCH 014/843] add Header from struct function --- req.go | 9 +++++++++ req_test.go | 22 +++++++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/req.go b/req.go index 25e314da..b0424a84 100644 --- a/req.go +++ b/req.go @@ -48,6 +48,7 @@ func (h Header) Clone() Header { return hh } +// ParseStruct parse struct into header func ParseStruct(h Header, v interface{}) Header { data, err := json.Marshal(v) if err != nil { @@ -58,6 +59,14 @@ func ParseStruct(h Header, v interface{}) Header { return h } +// HeaderFromStruct init header from struct +func HeaderFromStruct(v interface{}) Header { + + var header Header + header = ParseStruct(header, v) + return header +} + // Param represents http request param type Param map[string]interface{} diff --git a/req_test.go b/req_test.go index 34fca5ec..f37ad4fe 100644 --- a/req_test.go +++ b/req_test.go @@ -270,7 +270,7 @@ func TestHeader(t *testing.T) { func TestParseStruct(t *testing.T) { type HeaderStruct struct { - UserAgent string `json:"User-Agent"` + UserAgent string `json:"User-Agent"` Authorization string `json:"Authorization"` } @@ -282,6 +282,26 @@ func TestParseStruct(t *testing.T) { var header Header header = ParseStruct(header, h) + if header["User-Agent"] != "V1.0.0" && header["Authorization"] != "roc" { + t.Fatal("struct parser for header is not working") + } + +} + +func TestHeaderFromStruct(t *testing.T) { + + type HeaderStruct struct { + UserAgent string `json:"User-Agent"` + Authorization string `json:"Authorization"` + } + + h := HeaderStruct{ + "V1.0.0", + "roc", + } + + header := HeaderFromStruct(h) + if header["User-Agent"] != "V1.0.0" && header["Authorization"] != "roc" { t.Fatal("struct parser for header is not working") } From 756091ea3676c1d6ec8c8f4c6e6212bfbae13cef Mon Sep 17 00:00:00 2001 From: Jirawat Harnsiriwatanakit Date: Sat, 25 Jan 2020 22:52:54 +0700 Subject: [PATCH 015/843] move all header type and functions to new header.go file --- header.go | 41 +++++++++++++++++++++++++++++++++++++++++ header_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ req.go | 33 --------------------------------- req_test.go | 40 ---------------------------------------- 4 files changed, 88 insertions(+), 73 deletions(-) create mode 100644 header.go create mode 100644 header_test.go diff --git a/header.go b/header.go new file mode 100644 index 00000000..ac1be32b --- /dev/null +++ b/header.go @@ -0,0 +1,41 @@ +/* + GoLang code created by Jirawat Harnsiriwatanakit https://github.com/kazekim +*/ + +package req + +import "encoding/json" + +// Header represents http request header +type Header map[string]string + +func (h Header) Clone() Header { + if h == nil { + return nil + } + hh := Header{} + for k, v := range h { + hh[k] = v + } + return hh +} + +// ParseStruct parse struct into header +func ParseStruct(h Header, v interface{}) Header { + data, err := json.Marshal(v) + if err != nil { + return h + } + + err = json.Unmarshal(data, &h) + return h +} + +// HeaderFromStruct init header from struct +func HeaderFromStruct(v interface{}) Header { + + var header Header + header = ParseStruct(header, v) + return header +} + diff --git a/header_test.go b/header_test.go new file mode 100644 index 00000000..5e56fac4 --- /dev/null +++ b/header_test.go @@ -0,0 +1,47 @@ +/* + GoLang code created by Jirawat Harnsiriwatanakit https://github.com/kazekim +*/ + +package req + +import "testing" + +func TestParseStruct(t *testing.T) { + + type HeaderStruct struct { + UserAgent string `json:"User-Agent"` + Authorization string `json:"Authorization"` + } + + h := HeaderStruct{ + "V1.0.0", + "roc", + } + + var header Header + header = ParseStruct(header, h) + + if header["User-Agent"] != h.UserAgent && header["Authorization"] != h.Authorization { + t.Fatal("struct parser for header is not working") + } + +} + +func TestHeaderFromStruct(t *testing.T) { + + type HeaderStruct struct { + UserAgent string `json:"User-Agent"` + Authorization string `json:"Authorization"` + } + + h := HeaderStruct{ + "V1.0.0", + "roc", + } + + header := HeaderFromStruct(h) + + if header["User-Agent"] != h.UserAgent && header["Authorization"] != h.Authorization { + t.Fatal("struct parser for header is not working") + } +} diff --git a/req.go b/req.go index b0424a84..01937f04 100644 --- a/req.go +++ b/req.go @@ -34,39 +34,6 @@ const ( LstdFlags = LreqHead | LreqBody | LrespHead | LrespBody ) -// Header represents http request header -type Header map[string]string - -func (h Header) Clone() Header { - if h == nil { - return nil - } - hh := Header{} - for k, v := range h { - hh[k] = v - } - return hh -} - -// ParseStruct parse struct into header -func ParseStruct(h Header, v interface{}) Header { - data, err := json.Marshal(v) - if err != nil { - return h - } - - err = json.Unmarshal(data, &h) - return h -} - -// HeaderFromStruct init header from struct -func HeaderFromStruct(v interface{}) Header { - - var header Header - header = ParseStruct(header, v) - return header -} - // Param represents http request param type Param map[string]interface{} diff --git a/req_test.go b/req_test.go index f37ad4fe..a863941a 100644 --- a/req_test.go +++ b/req_test.go @@ -267,46 +267,6 @@ func TestHeader(t *testing.T) { } } -func TestParseStruct(t *testing.T) { - - type HeaderStruct struct { - UserAgent string `json:"User-Agent"` - Authorization string `json:"Authorization"` - } - - h := HeaderStruct{ - "V1.0.0", - "roc", - } - - var header Header - header = ParseStruct(header, h) - - if header["User-Agent"] != "V1.0.0" && header["Authorization"] != "roc" { - t.Fatal("struct parser for header is not working") - } - -} - -func TestHeaderFromStruct(t *testing.T) { - - type HeaderStruct struct { - UserAgent string `json:"User-Agent"` - Authorization string `json:"Authorization"` - } - - h := HeaderStruct{ - "V1.0.0", - "roc", - } - - header := HeaderFromStruct(h) - - if header["User-Agent"] != "V1.0.0" && header["Authorization"] != "roc" { - t.Fatal("struct parser for header is not working") - } -} - func TestUpload(t *testing.T) { str := "hello req" file := ioutil.NopCloser(strings.NewReader(str)) From 472d7da5e8bdab2622f237d6cd9a93ba79812c80 Mon Sep 17 00:00:00 2001 From: Jirawat Harnsiriwatanakit Date: Sat, 25 Jan 2020 22:59:51 +0700 Subject: [PATCH 016/843] update read me file to add HeaderFromStruct function usage --- README.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/README.md b/README.md index d6560087..d7f5201c 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,25 @@ header.Set("Accept", "application/json") req.Get("https://www.baidu.com", header) ``` +#### Set Header From Struct +Use `HeaderFromStruct` func to parse your struct +``` go +type HeaderStruct struct { + UserAgent string `json:"User-Agent"` + Authorization string `json:"Authorization"` + } + + h := HeaderStruct{ + "V1.0.0", + "roc", + } + + authHeader := req.HeaderFromStruct(h) + req.Get("https://www.baidu.com", authHeader, req.Header{"User-Agent": "V1.1"}) +``` + +Note: Please add tag 'json' to your argument in struct to let you customize the key name of your header + ## Set Param Use `req.Param` (it is actually a `map[string]interface{}`) ``` go From d679d25f09a22a186b72b8bb286ff06a7ac6d4a0 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 26 Jan 2020 15:39:49 +0800 Subject: [PATCH 017/843] update README --- README.md | 14 +++++++------- doc/README_cn.md | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index d7f5201c..7d54deb4 100644 --- a/README.md +++ b/README.md @@ -105,14 +105,14 @@ header.Set("Accept", "application/json") req.Get("https://www.baidu.com", header) ``` -#### Set Header From Struct -Use `HeaderFromStruct` func to parse your struct +You can also set header from struct, use `HeaderFromStruct` func to parse your struct ``` go type HeaderStruct struct { - UserAgent string `json:"User-Agent"` - Authorization string `json:"Authorization"` - } + UserAgent string `json:"User-Agent"` + Authorization string `json:"Authorization"` +} +func main(){ h := HeaderStruct{ "V1.0.0", "roc", @@ -120,9 +120,9 @@ type HeaderStruct struct { authHeader := req.HeaderFromStruct(h) req.Get("https://www.baidu.com", authHeader, req.Header{"User-Agent": "V1.1"}) +} ``` - -Note: Please add tag 'json' to your argument in struct to let you customize the key name of your header +> Note: Please add tag 'json' to your argument in struct to let you customize the key name of your header ## Set Param Use `req.Param` (it is actually a `map[string]interface{}`) diff --git a/doc/README_cn.md b/doc/README_cn.md index d5405bb2..3f724ea9 100644 --- a/doc/README_cn.md +++ b/doc/README_cn.md @@ -97,6 +97,25 @@ header.Set("Accept", "application/json") req.Get("https://www.baidu.com", header) ``` +你可以使用 `struct` 来设置请求头,用 `HeaderFromStruct` 这个函数来解析你的 `struct` +``` go +type HeaderStruct struct { + UserAgent string `json:"User-Agent"` + Authorization string `json:"Authorization"` +} + +func main(){ + h := HeaderStruct{ + "V1.0.0", + "roc", + } + + authHeader := req.HeaderFromStruct(h) + req.Get("https://www.baidu.com", authHeader, req.Header{"User-Agent": "V1.1"}) +} +``` +> 注:请给你的 struct 加上 json tag. + ## 设置请求参数 Use `req.Param` (它实际上是一个 `map[string]interface{}`) ``` go From 74fbafcbf1179eb36a92d93944d56613ad7abe5e Mon Sep 17 00:00:00 2001 From: Erick Sosa Date: Tue, 22 Sep 2020 13:19:35 -0400 Subject: [PATCH 018/843] fix: short variable declaration error should be `:=` instead of `=` --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7d54deb4..1f30e0ba 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ param := req.Param{ "cmd": "add", } // only url is required, others are optional. -r, err = req.Post("http://foo.bar/api", header, param) +r, err := req.Post("http://foo.bar/api", header, param) if err != nil { log.Fatal(err) } From a7f7e9c8ce87c23f15a77f2f3546ae322c0ab7d1 Mon Sep 17 00:00:00 2001 From: raymonder jin Date: Thu, 16 Sep 2021 15:15:48 +0800 Subject: [PATCH 019/843] fix typo --- README.md | 4 ++-- resp.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 1f30e0ba..8319dabb 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ req.Post(url, req.BodyXML(&bar)) ``` ## Debug -Set global variable `req.Debug` to true, it will print detail infomation for every request. +Set global variable `req.Debug` to true, it will print detail information for every request. ``` go req.Debug = true req.Post("http://localhost/test" "hi") @@ -169,7 +169,7 @@ req.Post("http://localhost/test" "hi") ![post](doc/post.png) ## Output Format -You can use different kind of output format to log the request and response infomation in your log file in defferent scenarios. For example, use `%+v` output format in the development phase, it allows you to observe the details. Use `%v` or `%-v` output format in production phase, just log the information necessarily. +You can use different kind of output format to log the request and response information in your log file in defferent scenarios. For example, use `%+v` output format in the development phase, it allows you to observe the details. Use `%v` or `%-v` output format in production phase, just log the information necessarily. ### `%+v` or `%+s` Output in detail diff --git a/resp.go b/resp.go index b464c2b1..e5c36fce 100644 --- a/resp.go +++ b/resp.go @@ -43,7 +43,7 @@ func (r *Resp) Bytes() []byte { } // ToBytes returns response body as []byte, -// return error if error happend when reading +// return error if error happened when reading // the response body func (r *Resp) ToBytes() ([]byte, error) { if r.err != nil { @@ -69,7 +69,7 @@ func (r *Resp) String() string { } // ToString returns response body as string, -// return error if error happend when reading +// return error if error happened when reading // the response body func (r *Resp) ToString() (string, error) { data, err := r.ToBytes() From 31fdd27ffdcd2c0fa20c8d37eb659fe816360440 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 27 Oct 2021 12:53:47 +0800 Subject: [PATCH 020/843] still print debug info when err happens --- dump.go | 5 +++-- req.go | 11 +++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/dump.go b/dump.go index ce6d3a5b..52e527e9 100644 --- a/dump.go +++ b/dump.go @@ -207,10 +207,11 @@ func (r *Resp) Dump() string { l := dump.Len() if l > 0 { dump.WriteString("=================================") - l = dump.Len() } - r.dumpResponse(dump) + if r.resp != nil { + r.dumpResponse(dump) + } return dump.String() } diff --git a/req.go b/req.go index 01937f04..452547dd 100644 --- a/req.go +++ b/req.go @@ -171,6 +171,13 @@ func (r *Req) Do(method, rawurl string, vs ...interface{}) (resp *Resp, err erro } resp = &Resp{req: req, r: r} + // output detail if Debug is enabled + if Debug { + defer func(resp *Resp) { + fmt.Println(resp.Dump()) + }(resp) + } + var queryParam param var formParam param var uploads []FileUpload @@ -340,10 +347,6 @@ func (r *Req) Do(method, rawurl string, vs ...interface{}) (resp *Resp, err erro response.Body = body } - // output detail if Debug is enabled - if Debug { - fmt.Println(resp.Dump()) - } return } From 73066f56c5972f983279b2845b9f0a6248fed97e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=81=E6=96=87=E6=B1=9F?= Date: Mon, 8 Nov 2021 17:38:21 +0800 Subject: [PATCH 021/843] feat: add ReservedHeader --- header.go | 1 + req.go | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/header.go b/header.go index ac1be32b..dbc59925 100644 --- a/header.go +++ b/header.go @@ -39,3 +39,4 @@ func HeaderFromStruct(v interface{}) Header { return header } +type ReservedHeader map[string]string diff --git a/req.go b/req.go index 452547dd..99c04b3d 100644 --- a/req.go +++ b/req.go @@ -198,6 +198,10 @@ func (r *Req) Do(method, rawurl string, vs ...interface{}) (resp *Resp, err erro req.Header.Add(key, value) } } + case ReservedHeader: + for key, value := range vv { + req.Header[key] = []string{value} + } case *bodyJson: fn, err := setBodyJson(req, resp, r.jsonEncOpts, vv.v) if err != nil { From f7ae3a6a8afc0712847b98520276d7e2fcf1bbb5 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 18:29:45 +0800 Subject: [PATCH 022/843] update README: v2 news notice --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 8319dabb..02cdd259 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,10 @@ A golang http request library for humans +News +======== +The new v2 version is under development, come and try it by clicking [here](https://github.com/imroc/req/tree/v2)! Features ======== From a5e3331ea7d8b804c7257ca36d6390351e586ffb Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 18:35:01 +0800 Subject: [PATCH 023/843] add news warning --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 02cdd259..74d70efa 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ A golang http request library for humans -News +:warning:News ======== The new v2 version is under development, come and try it by clicking [here](https://github.com/imroc/req/tree/v2)! From cf27a29ca23d7d81c1fa14394a31d30dd7e83721 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 10:50:07 +0800 Subject: [PATCH 024/843] born v2 --- .gitignore | 34 + LICENSE | 222 +- README.md | 324 +- body.go | 124 + client.go | 130 + common.go | 9 + decode.go | 106 + doc/README_cn.md | 314 - doc/post.png | Bin 53255 -> 0 bytes dump.go | 324 +- dump_test.go | 62 - go.mod | 8 +- go.sum | 13 + h2_bundle.go | 10681 +++++++++++++++++++++++++++++ header.go | 134 +- header_test.go | 47 - http.go | 242 + http_request.go | 327 + http_response.go | 92 + internal/ascii/print.go | 61 + internal/ascii/print_test.go | 95 + internal/chunked.go | 261 + internal/chunked_test.go | 241 + internal/godebug/godebug.go | 34 + internal/godebug/godebug_test.go | 34 + req.go | 690 -- req_test.go | 313 - request.go | 226 + resp.go | 220 - resp_test.go | 130 - response.go | 66 + roundtrip.go | 25 + setting.go | 248 - setting_test.go | 62 - socks_bundle.go | 473 ++ textproto_reader.go | 844 +++ transfer.go | 1126 +++ transport.go | 2938 ++++++++ transport_default_js.go | 17 + transport_default_other.go | 17 + 40 files changed, 18520 insertions(+), 2794 deletions(-) create mode 100644 .gitignore create mode 100644 body.go create mode 100644 client.go create mode 100644 common.go create mode 100644 decode.go delete mode 100644 doc/README_cn.md delete mode 100644 doc/post.png delete mode 100644 dump_test.go create mode 100644 go.sum create mode 100644 h2_bundle.go delete mode 100644 header_test.go create mode 100644 http.go create mode 100644 http_request.go create mode 100644 http_response.go create mode 100644 internal/ascii/print.go create mode 100644 internal/ascii/print_test.go create mode 100644 internal/chunked.go create mode 100644 internal/chunked_test.go create mode 100644 internal/godebug/godebug.go create mode 100644 internal/godebug/godebug_test.go delete mode 100644 req.go delete mode 100644 req_test.go create mode 100644 request.go delete mode 100644 resp.go delete mode 100644 resp_test.go create mode 100644 response.go create mode 100644 roundtrip.go delete mode 100644 setting.go delete mode 100644 setting_test.go create mode 100644 socks_bundle.go create mode 100644 textproto_reader.go create mode 100644 transfer.go create mode 100644 transport.go create mode 100644 transport_default_js.go create mode 100644 transport_default_other.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..d2bf3553 --- /dev/null +++ b/.gitignore @@ -0,0 +1,34 @@ +# IDE ignore +.idea/ +*.ipr +*.iml +*.iws +.vscode/ + +# Emacs save files +*~ +\#*\# +.\#* + +# Vim-related files +[._]*.s[a-w][a-z] +[._]s[a-w][a-z] +*.un~ +Session.vim +.netrwhist + +# make-related metadata +/.make/ + +# temp ignore +*.log +*.cache +*.diff +*.exe +*.exe~ +*.patch +*.tmp +*.swp + +# OSX trash +.DS_Store \ No newline at end of file diff --git a/LICENSE b/LICENSE index 8dada3ed..70f3d40a 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,21 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +MIT License + +Copyright (c) 2017-2022 roc + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md index 74d70efa..360d0d36 100644 --- a/README.md +++ b/README.md @@ -1,331 +1,19 @@ # req -[![GoDoc](https://godoc.org/github.com/imroc/req?status.svg)](https://godoc.org/github.com/imroc/req) +[![GoDoc](https://pkg.go.dev/badge/github.com/imroc/req.svg)](https://pkg.go.dev/github.com/imroc/req) A golang http request library for humans -:warning:News -======== - -The new v2 version is under development, come and try it by clicking [here](https://github.com/imroc/req/tree/v2)! - Features ======== -- Light weight -- Simple -- Easy play with JSON and XML -- Easy for debug and logging -- Easy file uploads and downloads -- Easy manage cookie -- Easy set up proxy -- Easy set timeout -- Easy customize http client - - -Document -======== -[中文](doc/README_cn.md) +* Simple and chainable methods for client and request settings +* Rich syntax sugar, greatly improving development efficiency +* Powerful debugging capabilities +* The settings can be dynamically adjusted, making it possible to debug in the production environment Install ======= ``` sh -go get github.com/imroc/req -``` - -Overview -======= -`req` implements a friendly API over Go's existing `net/http` library. - -`Req` and `Resp` are two most important struct, you can think of `Req` as a client that initiate HTTP requests, `Resp` as a information container for the request and response. They all provide simple and convenient APIs that allows you to do a lot of things. -``` go -func (r *Req) Post(url string, v ...interface{}) (*Resp, error) -``` - -In most cases, only url is required, others are optional, like headers, params, files or body etc. - -There is a default `Req` object, all of its' public methods are wrapped by the `req` package, so you can also think of `req` package as a `Req` object -``` go -// use Req object to initiate requests. -r := req.New() -r.Get(url) - -// use req package to initiate request. -req.Get(url) -``` -You can use `req.New()` to create lots of `*Req` as client with independent configuration - -Examples -======= -[Basic](#Basic) -[Set Header](#Set-Header) -[Set Param](#Set-Param) -[Set Body](#Set-Body) -[Debug](#Debug) -[Output Format](#Format) -[ToJSON & ToXML](#ToJSON-ToXML) -[Get *http.Response](#Response) -[Upload](#Upload) -[Download](#Download) -[Cookie](#Cookie) -[Set Timeout](#Set-Timeout) -[Set Proxy](#Set-Proxy) -[Customize Client](#Customize-Client) -[Set context.Context](#Context) - -## Basic -``` go -header := req.Header{ - "Accept": "application/json", - "Authorization": "Basic YWRtaW46YWRtaW4=", -} -param := req.Param{ - "name": "imroc", - "cmd": "add", -} -// only url is required, others are optional. -r, err := req.Post("http://foo.bar/api", header, param) -if err != nil { - log.Fatal(err) -} -r.ToJSON(&foo) // response => struct/map -log.Printf("%+v", r) // print info (try it, you may surprise) -``` - -## Set Header -Use `req.Header` (it is actually a `map[string]string`) -``` go -authHeader := req.Header{ - "Accept": "application/json", - "Authorization": "Basic YWRtaW46YWRtaW4=", -} -req.Get("https://www.baidu.com", authHeader, req.Header{"User-Agent": "V1.1"}) -``` -use `http.Header` -``` go -header := make(http.Header) -header.Set("Accept", "application/json") -req.Get("https://www.baidu.com", header) -``` - -You can also set header from struct, use `HeaderFromStruct` func to parse your struct -``` go -type HeaderStruct struct { - UserAgent string `json:"User-Agent"` - Authorization string `json:"Authorization"` -} - -func main(){ - h := HeaderStruct{ - "V1.0.0", - "roc", - } - - authHeader := req.HeaderFromStruct(h) - req.Get("https://www.baidu.com", authHeader, req.Header{"User-Agent": "V1.1"}) -} -``` -> Note: Please add tag 'json' to your argument in struct to let you customize the key name of your header - -## Set Param -Use `req.Param` (it is actually a `map[string]interface{}`) -``` go -param := req.Param{ - "id": "imroc", - "pwd": "roc", -} -req.Get("http://foo.bar/api", param) // http://foo.bar/api?id=imroc&pwd=roc -req.Post(url, param) // body => id=imroc&pwd=roc -``` -use `req.QueryParam` force to append params to the url (it is also actually a `map[string]interface{}`) -``` go -req.Post("http://foo.bar/api", req.Param{"name": "roc", "age": "22"}, req.QueryParam{"access_token": "fedledGF9Hg9ehTU"}) -/* -POST /api?access_token=fedledGF9Hg9ehTU HTTP/1.1 -Host: foo.bar -User-Agent: Go-http-client/1.1 -Content-Length: 15 -Content-Type: application/x-www-form-urlencoded;charset=UTF-8 -Accept-Encoding: gzip - -age=22&name=roc -*/ -``` - -## Set Body -Put `string`, `[]byte` and `io.Reader` as body directly. -``` go -req.Post(url, "id=roc&cmd=query") -``` -Put object as xml or json body (add `Content-Type` header automatically) -``` go -req.Post(url, req.BodyJSON(&foo)) -req.Post(url, req.BodyXML(&bar)) -``` - -## Debug -Set global variable `req.Debug` to true, it will print detail information for every request. -``` go -req.Debug = true -req.Post("http://localhost/test" "hi") -``` -![post](doc/post.png) - -## Output Format -You can use different kind of output format to log the request and response information in your log file in defferent scenarios. For example, use `%+v` output format in the development phase, it allows you to observe the details. Use `%v` or `%-v` output format in production phase, just log the information necessarily. - -### `%+v` or `%+s` -Output in detail -``` go -r, _ := req.Post(url, header, param) -log.Printf("%+v", r) // output the same format as Debug is enabled -``` - -### `%v` or `%s` -Output in simple way (default format) -``` go -r, _ := req.Get(url, param) -log.Printf("%v\n", r) // GET http://foo.bar/api?name=roc&cmd=add {"code":"0","msg":"success"} -log.Prinln(r) // same as above -``` - -### `%-v` or `%-s` -Output in simple way and keep all in one line (request body or response body may have multiple lines, this format will replace `"\r"` or `"\n"` with `" "`, it's useful when doing some search in your log file) - -### Flag -You can call `SetFlags` to control the output content, decide which pieces can be output. -``` go -const ( - LreqHead = 1 << iota // output request head (request line and request header) - LreqBody // output request body - LrespHead // output response head (response line and response header) - LrespBody // output response body - Lcost // output time costed by the request - LstdFlags = LreqHead | LreqBody | LrespHead | LrespBody -) -``` -``` go -req.SetFlags(req.LreqHead | req.LreqBody | req.LrespHead) -``` - -### Monitoring time consuming -``` go -req.SetFlags(req.LstdFlags | req.Lcost) // output format add time costed by request -r,_ := req.Get(url) -log.Println(r) // http://foo.bar/api 3.260802ms {"code":0 "msg":"success"} -if r.Cost() > 3 * time.Second { // check cost - log.Println("WARN: slow request:", r) -} -``` - -## ToJSON & ToXML -``` go -r, _ := req.Get(url) -r.ToJSON(&foo) -r, _ = req.Post(url, req.BodyXML(&bar)) -r.ToXML(&baz) -``` - -## Get *http.Response -```go -// func (r *Req) Response() *http.Response -r, _ := req.Get(url) -resp := r.Response() -fmt.Println(resp.StatusCode) -``` - -## Upload -Use `req.File` to match files -``` go -req.Post(url, req.File("imroc.png"), req.File("/Users/roc/Pictures/*.png")) -``` -Use `req.FileUpload` to fully control -``` go -file, _ := os.Open("imroc.png") -req.Post(url, req.FileUpload{ - File: file, - FieldName: "file", // FieldName is form field name - FileName: "avatar.png", //Filename is the name of the file that you wish to upload. We use this to guess the mimetype as well as pass it onto the server -}) -``` -Use `req.UploadProgress` to listen upload progress -```go -progress := func(current, total int64) { - fmt.Println(float32(current)/float32(total)*100, "%") -} -req.Post(url, req.File("/Users/roc/Pictures/*.png"), req.UploadProgress(progress)) -fmt.Println("upload complete") -``` - -## Download -``` go -r, _ := req.Get(url) -r.ToFile("imroc.png") -``` -Use `req.DownloadProgress` to listen download progress -```go -progress := func(current, total int64) { - fmt.Println(float32(current)/float32(total)*100, "%") -} -r, _ := req.Get(url, req.DownloadProgress(progress)) -r.ToFile("hello.mp4") -fmt.Println("download complete") -``` - -## Cookie -By default, the underlying `*http.Client` will manage your cookie(send cookie header to server automatically if server has set a cookie for you), you can disable it by calling this function : -``` go -req.EnableCookie(false) -``` -and you can set cookie in request just using `*http.Cookie` -``` go -cookie := new(http.Cookie) -// ...... -req.Get(url, cookie) -``` - -## Set Timeout -``` go -req.SetTimeout(50 * time.Second) -``` - -## Set Proxy -By default, req use proxy from system environment if `http_proxy` or `https_proxy` is specified, you can set a custom proxy or disable it by set `nil` -``` go -req.SetProxy(func(r *http.Request) (*url.URL, error) { - if strings.Contains(r.URL.Hostname(), "google") { - return url.Parse("http://my.vpn.com:23456") - } - return nil, nil -}) -``` -Set a simple proxy (use fixed proxy url for every request) -``` go -req.SetProxyUrl("http://my.proxy.com:23456") -``` - -## Set context.Context -You can pass context.Context in simple way: -```go -r, _ := req.Get(url, context.Background()) -``` - -## Customize Client -Use `SetClient` to change the default underlying `*http.Client` -``` go -req.SetClient(client) -``` -Specify independent http client for some requests -``` go -client := &http.Client{Timeout: 30 * time.Second} -req.Get(url, client) -``` -Change some properties of default client you want -``` go -req.Client().Jar, _ = cookiejar.New(nil) -trans, _ := req.Client().Transport.(*http.Transport) -trans.MaxIdleConns = 20 -trans.TLSHandshakeTimeout = 20 * time.Second -trans.DisableKeepAlives = true -trans.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} +go get github.com/imroc/req/v2 ``` diff --git a/body.go b/body.go new file mode 100644 index 00000000..c4644261 --- /dev/null +++ b/body.go @@ -0,0 +1,124 @@ +package req + +import ( + "encoding/json" + "encoding/xml" + "io" + "io/ioutil" + "net/http" + "os" + "strings" +) + +type Body struct { + io.ReadCloser + resp *http.Response +} + +func (body Body) MustSave(dst io.Writer) { + err := body.Save(dst) + if err != nil { + panic(err) + } +} + +func (body Body) Save(dst io.Writer) error { + if dst == nil { + return nil // TODO: return error + } + _, err := io.Copy(dst, body.ReadCloser) + body.Close() + return err +} + +func (body Body) MustSaveFile(filename string) { + err := body.SaveFile(filename) + if err != nil { + panic(err) + } +} + +func (body Body) SaveFile(filename string) error { + if filename == "" { + return nil // TODO: return error + } + file, err := os.Create(filename) + if err != nil { + return err + } + _, err = io.Copy(file, body.ReadCloser) + body.Close() + return err +} + +func (body Body) MustUnmarshalJson(v interface{}) { + err := body.UnmarshalJson(v) + if err != nil { + panic(err) + } +} + +func (body Body) UnmarshalJson(v interface{}) error { + b, err := body.Bytes() + if err != nil { + return err + } + return json.Unmarshal(b, v) +} + +func (body Body) MustUnmarshalXml(v interface{}) { + err := body.UnmarshalXml(v) + if err != nil { + panic(err) + } +} + +func (body Body) UnmarshalXml(v interface{}) error { + b, err := body.Bytes() + if err != nil { + return err + } + return xml.Unmarshal(b, v) +} +func (body Body) MustUnmarshal(v interface{}) { + err := body.Unmarshal(v) + if err != nil { + panic(err) + } +} + +func (body Body) Unmarshal(v interface{}) error { + contentType := body.resp.Header.Get("Content-Type") + if strings.Contains(contentType, "json") { + return body.UnmarshalJson(v) + } else if strings.Contains(contentType, "xml") { + return body.UnmarshalXml(v) + } + return body.UnmarshalJson(v) +} + +func (body Body) MustString() string { + b, err := body.Bytes() + if err != nil { + panic(err) + } + return string(b) +} + +func (body Body) String() (string, error) { + b, err := body.Bytes() + return string(b), err +} + +func (body Body) Bytes() ([]byte, error) { + defer body.Close() + return ioutil.ReadAll(body.ReadCloser) +} + +func (body Body) MustBytes() []byte { + b, err := body.Bytes() + if err != nil { + panic(err) + } + return b +} diff --git a/client.go b/client.go new file mode 100644 index 00000000..a3577155 --- /dev/null +++ b/client.go @@ -0,0 +1,130 @@ +package req + +import ( + "encoding/json" + "golang.org/x/net/publicsuffix" + "net" + "net/http" + "net/http/cookiejar" + "time" +) + +func DefaultClient() *Client { + return defaultClient +} + +func SetDefaultClient(c *Client) { + if c != nil { + defaultClient = c + } +} + +var defaultClient *Client = C() + +type Client struct { + t *Transport + t2 *http2Transport + dumpOptions *DumpOptions + httpClient *http.Client + dialer *net.Dialer + jsonDecoder *json.Decoder + commonHeader map[string]string +} + +func (c *Client) R() *Request { + req := &http.Request{ + Header: make(http.Header), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + } + return &Request{ + client: c, + httpRequest: req, + } +} + + +func (c *Client) ResponseOptions(opts ResponseOptions) *Client { + c.t.ResponseOptions = opts + return c +} + +func (c *Client) ResponseOption(opts ...ResponseOption) *Client { + for _, opt := range opts { + opt(&c.t.ResponseOptions) + } + return c +} + +func (c *Client) Timeout(d time.Duration) *Client { + c.httpClient.Timeout = d + return c +} + +// NewRequest is the alias of R() +func (c *Client) NewRequest() *Request { + return c.R() +} + +func (c *Client) DumpOptions(opt *DumpOptions) *Client { + c.dumpOptions = opt + return c +} + +func (c *Client) DisableDump() *Client { + c.t.DisableDump() + return c +} + +func (c *Client) UserAgent(userAgent string) *Client { + return c.Header("User-Agent", userAgent) +} + +func (c *Client) Header(key, value string) *Client { + if c.commonHeader == nil { + c.commonHeader = make(map[string]string) + } + c.commonHeader[key] = value + return c +} + +func (c *Client) EnableDump(opts ...DumpOption) *Client { + if len(opts) > 0 { + if c.dumpOptions == nil { + c.dumpOptions = &DumpOptions{} + } + c.dumpOptions.Set(opts...) + } + c.t.EnableDump(c.dumpOptions) + return c +} + +// NewClient is the alias of C() +func NewClient() *Client { + return C() +} + +func C() *Client { + t := &Transport{ + ForceAttemptHTTP2: true, + Proxy: http.ProxyFromEnvironment, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + t2, _ := http2ConfigureTransports(t) + jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + httpClient := &http.Client{ + Transport: t, + Jar: jar, + Timeout: 2 * time.Minute, + } + c := &Client{ + httpClient: httpClient, + t: t, + t2: t2, + } + return c +} diff --git a/common.go b/common.go new file mode 100644 index 00000000..7af8cd06 --- /dev/null +++ b/common.go @@ -0,0 +1,9 @@ +package req + +const ( + CONTENT_TYPE_APPLICATION_JSON_UTF8 = "application/json; charset=UTF-8" + CONTENT_TYPE_APPLICATION_XML_UTF8 = "application/xml; charset=UTF-8" + CONTENT_TYPE_TEXT_XML_UTF8 = "text/xml; charset=UTF-8" + CONTENT_TYPE_TEXT_HTML_UTF8 = "text/html; charset=UTF-8" + CONTENT_TYPE_TEXT_PLAIN_UTF8 = "text/plain; charset=UTF-8" +) diff --git a/decode.go b/decode.go new file mode 100644 index 00000000..aeef41b2 --- /dev/null +++ b/decode.go @@ -0,0 +1,106 @@ +package req + +import ( + "fmt" + htmlcharset "golang.org/x/net/html/charset" + "golang.org/x/text/encoding/charmap" + "io" + "strings" +) + +func responseBodyIsText(contentType string) bool { + for _, keyword := range []string{"text", "json", "xml", "html", "java"} { + if strings.Contains(contentType, keyword) { + return true + } + } + return false +} + +type decodeReaderCloser struct { + io.ReadCloser + decodeReader io.Reader +} + +func (d *decodeReaderCloser) Read(p []byte) (n int, err error) { + return d.decodeReader.Read(p) +} + +type autoDecodeReadCloser struct { + io.ReadCloser + decodeReader io.Reader + detected bool + peek []byte +} + +func (a *autoDecodeReadCloser) peekRead(p []byte) (n int, err error) { + n, err = a.ReadCloser.Read(p) + if n == 0 || err != nil { + return + } + a.detected = true + enc, name, _ := htmlcharset.DetermineEncoding(p[:n], "") + if enc == charmap.Windows1252 { + return + } + if name != "" { + fmt.Println("content charset detected:", name) + } + if enc == nil { + return + } + dc := enc.NewDecoder() + a.decodeReader = dc.Reader(a.ReadCloser) + var pp []byte + pp, err = dc.Bytes(p[:n]) + if err != nil { + return + } + if len(pp) > len(p) { + a.peek = make([]byte, len(pp)-len(p)) + copy(a.peek, pp[len(p):]) + copy(p, pp[:len(p)]) + n = len(p) + return + } + copy(p, pp) + n = len(p) + return +} + +func (a *autoDecodeReadCloser) peekDrain(p []byte) (n int, err error) { + if len(a.peek) > len(p) { + copy(p, a.peek[:len(p)]) + peek := make([]byte, len(a.peek)-len(p)) + copy(peek, a.peek[len(p):]) + a.peek = peek + n = len(p) + return + } + if len(a.peek) == len(p) { + copy(p, a.peek) + n = len(p) + a.peek = nil + return + } + pp := make([]byte, len(p)-len(a.peek)) + nn, err := a.decodeReader.Read(pp) + n = len(a.peek) + nn + copy(p[:len(a.peek)], a.peek) + copy(p[len(a.peek):], pp[:nn]) + a.peek = nil + return +} + +func (a *autoDecodeReadCloser) Read(p []byte) (n int, err error) { + if !a.detected { + return a.peekRead(p) + } + if a.peek != nil { + return a.peekDrain(p) + } + if a.decodeReader != nil { + return a.decodeReader.Read(p) + } + return a.ReadCloser.Read(p) // can not determine charset, not decode +} diff --git a/doc/README_cn.md b/doc/README_cn.md deleted file mode 100644 index 3f724ea9..00000000 --- a/doc/README_cn.md +++ /dev/null @@ -1,314 +0,0 @@ -# req -[![GoDoc](https://godoc.org/github.com/imroc/req?status.svg)](https://godoc.org/github.com/imroc/req) - -Go语言人性化HTTP请求库 - - -特性 -======== - -- 轻量级 -- 简单 -- 容易操作JSON和XML -- 容易调试和日志记录 -- 容易上传和下载文件 -- 容易管理Cookie -- 容易设置代理 -- 容易设置超时 -- 容易自定义HTTP客户端 - -安装 -======= -``` sh -go get github.com/imroc/req -``` - -概要 -======= -`req` 基于标准库 `net/http` 实现了一个友好的API. - -`Req` 和 `Resp` 是两个最重要的结构体, 你可以把 `Req` 看作客户端, 把`Resp` 看作存放请求及其响应的容器,它们都提供许多简洁方便的API,让你可以很轻松做很多很多事情。 -``` go -func (r *Req) Post(url string, v ...interface{}) (*Resp, error) -``` - -大多情况下,发起请求只有url是必选参数,其它都可选,比如请求头、请求参数、文件或请求体等。 - -包中含一个默认的 `Req` 对象, 它所有的公有方法都被`req`包对应的公有方法包装了,所以大多数情况下,你直接可以把`req`包看作一个`Req`对象来使用。 -``` go -// 创建Req对象来发起请求 -r := req.New() -r.Get(url) - -// 直接使用req包发起请求 -req.Get(url) -``` -你可以使用 `req.New()` 方法来创建 `*Req` 作为一个单独的客户端 - -例子 -======= -[基础用法](#Basic) -[设置请求头](#Set-Header) -[设置请求参数](#Set-Param) -[设置请求体](#Set-Body) -[调试](#Debug) -[输出格式](#Format) -[ToJSON & ToXML](#ToJSON-ToXML) -[获取 *http.Response](#Response) -[上传](#Upload) -[下载](#Download) -[Cookie](#Cookie) -[设置超时](#Set-Timeout) -[设置代理](#Set-Proxy) -[自定义 http.Client](#Customize-Client) - -## 基础用法 -``` go -header := req.Header{ - "Accept": "application/json", - "Authorization": "Basic YWRtaW46YWRtaW4=", -} -param := req.Param{ - "name": "imroc", - "cmd": "add", -} -// 只有url必选,其它参数都是可选 -r, err = req.Post("http://foo.bar/api", header, param) -if err != nil { - log.Fatal(err) -} -r.ToJSON(&foo) // 响应体转成对象 -log.Printf("%+v", r) // 打印详细信息 -``` - -## 设置请求头 -使用 `req.Header` (它实际上是一个 `map[string]string`) -``` go -authHeader := req.Header{ - "Accept": "application/json", - "Authorization": "Basic YWRtaW46YWRtaW4=", -} -req.Get("https://www.baidu.com", authHeader, req.Header{"User-Agent": "V1.1"}) -``` -使用 `http.Header` -``` go -header := make(http.Header) -header.Set("Accept", "application/json") -req.Get("https://www.baidu.com", header) -``` - -你可以使用 `struct` 来设置请求头,用 `HeaderFromStruct` 这个函数来解析你的 `struct` -``` go -type HeaderStruct struct { - UserAgent string `json:"User-Agent"` - Authorization string `json:"Authorization"` -} - -func main(){ - h := HeaderStruct{ - "V1.0.0", - "roc", - } - - authHeader := req.HeaderFromStruct(h) - req.Get("https://www.baidu.com", authHeader, req.Header{"User-Agent": "V1.1"}) -} -``` -> 注:请给你的 struct 加上 json tag. - -## 设置请求参数 -Use `req.Param` (它实际上是一个 `map[string]interface{}`) -``` go -param := req.Param{ - "id": "imroc", - "pwd": "roc", -} -req.Get("http://foo.bar/api", param) // http://foo.bar/api?id=imroc&pwd=roc -req.Post(url, param) // 请求体 => id=imroc&pwd=roc -``` -使用 `req.QueryParam` 强制将请求参数拼在url后面 (它实际上也是一个 `map[string]interface{}`) -``` go -req.Post("http://foo.bar/api", req.Param{"name": "roc", "age": "22"}, req.QueryParam{"access_token": "fedledGF9Hg9ehTU"}) -/* -POST /api?access_token=fedledGF9Hg9ehTU HTTP/1.1 -Host: foo.bar -User-Agent: Go-http-client/1.1 -Content-Length: 15 -Content-Type: application/x-www-form-urlencoded;charset=UTF-8 -Accept-Encoding: gzip - -age=22&name=roc -*/ -``` - -## 设置请求体 -Put `string`, `[]byte` and `io.Reader` as body directly. -``` go -req.Post(url, "id=roc&cmd=query") -``` -将对象作为JSON或XML请求体(自动添加 `Content-Type` 请求头) -``` go -req.Post(url, req.BodyJSON(&foo)) -req.Post(url, req.BodyXML(&bar)) -``` - -## 调试 -将全局变量 `req.Debug` 设置为`true`,将会把所有请求的详细信息打印在标准输出。 -``` go -req.Debug = true -req.Post("http://localhost/test" "hi") -``` -![post](post.png) - -## 输出格式 -您可以使用指定类型的输出格式在日志文件中记录请求和响应的信息。例如,在开发阶段使用`%+v`格式,可以让你观察请求和响应的细节信息。 在生产阶段使用`%v`或`%-v`输出格式,只记录所需要的信息。 - -### `%+v` 或 `%+s` -详细输出 -``` go -r, _ := req.Post(url, header, param) -log.Printf("%+v", r) // 输出格式和Debug开启时的格式一样 -``` - -### `%v` 或 `%s` -简单输出(默认格式) -``` go -r, _ := req.Get(url, param) -log.Printf("%v\n", r) // GET http://foo.bar/api?name=roc&cmd=add {"code":"0","msg":"success"} -log.Prinln(r) // 和上面一样 -``` - -### `%-v` 或 `%-s` -简单输出并保持所有内容在一行内(请求体或响应体可能包含多行,这种格式会将所有换行、回车替换成`" "`, 这在会让你在查日志的时候非常有用) - -### Flag -你可以调用 `SetFlags` 控制输出内容,决定哪些部分能够被输出。 -``` go -const ( - LreqHead = 1 << iota // 输出请求首部(包含请求行和请求头) - LreqBody // 输出请求体 - LrespHead // 输出响应首部(包含响应行和响应头) - LrespBody // 输出响应体 - Lcost // 输出请求所消耗掉时长 - LstdFlags = LreqHead | LreqBody | LrespHead | LrespBody -) -``` -``` go -req.SetFlags(req.LreqHead | req.LreqBody | req.LrespHead) -``` - -### 监控请求耗时 -``` go -req.SetFlags(req.LstdFlags | req.Lcost) // 输出格式显示请求耗时 -r,_ := req.Get(url) -log.Println(r) // http://foo.bar/api 3.260802ms {"code":0 "msg":"success"} -if r.Cost() > 3 * time.Second { // 检查耗时 - log.Println("WARN: slow request:", r) -} -``` - -## ToJSON & ToXML -``` go -r, _ := req.Get(url) -r.ToJSON(&foo) -r, _ = req.Post(url, req.BodyXML(&bar)) -r.ToXML(&baz) -``` - -## 获取 *http.Response -```go -// func (r *Req) Response() *http.Response -r, _ := req.Get(url) -resp := r.Response() -fmt.Println(resp.StatusCode) -``` - -## 上传 -使用 `req.File` 匹配文件 -``` go -req.Post(url, req.File("imroc.png"), req.File("/Users/roc/Pictures/*.png")) -``` -使用 `req.FileUpload` 细粒度控制上传 -``` go -file, _ := os.Open("imroc.png") -req.Post(url, req.FileUpload{ - File: file, - FieldName: "file", // FieldName 是表单字段名 - FileName: "avatar.png", // Filename 是要上传的文件的名称,我们使用它来猜测mimetype,并将其上传到服务器上 -}) -``` -使用`req.UploadProgress`监听上传进度 -```go -progress := func(current, total int64) { - fmt.Println(float32(current)/float32(total)*100, "%") -} -req.Post(url, req.File("/Users/roc/Pictures/*.png"), req.UploadProgress(progress)) -fmt.Println("upload complete") -``` - -## 下载 -``` go -r, _ := req.Get(url) -r.ToFile("imroc.png") -``` -使用`req.DownloadProgress`监听下载进度 -```go -progress := func(current, total int64) { - fmt.Println(float32(current)/float32(total)*100, "%") -} -r, _ := req.Get(url, req.DownloadProgress(progress)) -r.ToFile("hello.mp4") -fmt.Println("download complete") -``` - -## Cookie -默认情况下,底层的 `*http.Client` 会自动管理你的cookie(如果服务器给你发了cookie,之后的请求它会自动带上cookie请求头给服务器), 你可以调用这个方法取消自动管理: -``` go -req.EnableCookie(false) -``` -你还可以在发送请求的时候自己传入 `*http.Cookie` -``` go -cookie := new(http.Cookie) -// ...... -req.Get(url, cookie) -``` - -## 设置超时 -``` go -req.SetTimeout(50 * time.Second) -``` - -## 设置代理 -默认情况下,如果系统环境变量有 `http_proxy` 或 `https_proxy` ,req会讲对应的地址作为对应协议的代理,你也可以自定义设置代理,或者将其置为`nil`,即取消代理。 -``` go -req.SetProxy(func(r *http.Request) (*url.URL, error) { - if strings.Contains(r.URL.Hostname(), "google") { - return url.Parse("http://my.vpn.com:23456") - } - return nil, nil -}) -``` -设置简单代理(将所有请求都转发到指定代理url地址上) -``` go -req.SetProxyUrl("http://my.proxy.com:23456") -``` - -## 自定义HTTP客户端 -使用 `SetClient` 改变底层的 `*http.Client` -``` go -req.SetClient(client) -``` -给某个请求制定特定的 `*http.Client` -``` go -client := &http.Client{Timeout: 30 * time.Second} -req.Get(url, client) -``` -改变底层 `*http.Client` 的某些属性 -``` go -req.Client().Jar, _ = cookiejar.New(nil) -trans, _ := req.Client().Transport.(*http.Transport) -trans.MaxIdleConns = 20 -trans.TLSHandshakeTimeout = 20 * time.Second -trans.DisableKeepAlives = true -trans.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} -``` diff --git a/doc/post.png b/doc/post.png deleted file mode 100644 index 934d867b7c12429900c4cc67f7d1927474e734da..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 53255 zcmce8Q$S|j8*epXvdzg&c9WZIyC&P7Y`e*JO}1^@?#;G!YCiqX#koBf@6FTNdp~>a zwbrwKSOv>S3By8RK>+~)!HS9q$N>R?(7gYsK7+qUW_<`}fPg?@P5Aj`MEUs%WNfVr zP0S5|fY`#r6*Y}<6j8XF?TmQU`b$qDKTZ$G))P^ zB@T|UF*d(sNejxj(I2Lu^t9J|5s=@&FLT$>#zB!R%@B}t*WCNc?1$zbC2=q&nKV_d z8JxOZrJbd_o;Dhtf$cXIsl-R40HY;fZ$xG1Ci0yELUvy^3VDK&pffUrnI zJJqVM^LM#wEzYjKW|Yh9Om^Y+lX zG;He4ciVA$gqs*{Qvs%fz!c?8gAx!TM|$stV4mnJ=Cm>;O^=H9cd~U;99Au@%f<5Qv$dBCGq){@Kf z5zhoPN|qst+A?xWYeT+x?D?@=Q30rf###4<oFI z^xHmJ$fsZK?Hn>E54$=GJU_NLi`=@~W_&w2aNED~i>HQo&P#~(MeXqo^2`$AyWIo4 zR<4L4e8M!v+7b1niOGRer`^)D$lo6qhe2nZ!52q6di0jdd=sb+we>1j1uvEj3zuG^ zZah4CknJ5qcgt||XDRmfY=fo91qKZYy8Z+xDxC zG*?Me|lfU$}j=B%IkGhY!cWQz}a}1|#wQbT^mQ01-7n@i<=I{X~`+<3kW@U~E@@OI6FAxX~hR z&>J&%$G7c|3_m?yW)4<*>`L~iy=qxG(%RS0?B}?7-->~Z!aadAGu#wqv|p{>*9S;z z5mh@NAT*MXAF!w#@i`C>FOaAJpTbYz<5pM&1?A+km$TtOw0v-Lm@kFKisF>?xlE7; z;+SeFhnA!2J$^$Wkil0mdC&}<9EV2$_CtxSm~{V9n?$AGH_mz0}7tc|k>46fn+Ze}}PGw_B?-yD__4k6Z5b zPv^>ydi|h@K*-u45?+vW-ic)A>$*R_qeVf|0h4z8^#@0gk_Kz#zZ!f5NnwDMG)N`n zD>FYX`KuWv%?iYScY);hLQR02R-hvP`@ZjuLZI*dy`zM77lL-LK^GG$^w7T#uK;wp z|9d~=L^@CvnP67E#-v2YJseF>e*mi?OAT~VejYqX_RQIC;?tNcWx6V!5iXZ(PddKbrT z*~YG95M-9%!iw6%9$h{YThhJss-6LrfW#r1s?-caD{@DSgC=&w64ChpvU@|->%LG< z!lS2vwVouLpOv3^$?pCxp(yL3o!{ZxOy{KjRCJuDQfvCnLspLBiEcjbv@d|UQKqQ? zIg-LciN$uW^LK+W{<@ci|JRH}U@r6PYg4Ljks|pHP9YctJtvbd402v`Z%=K~T+HmB z;zwk=R%T)Jm|_bNcj6)uV87w7SAh9{J=Fg@^?15vt{bjQrsG$JrX#-eif&eGPeff4 znB@y;T2!wl1zx4^*2Ee$!=Zj7sSxku=MZQ*^Irrij~tU%=-2^)frvarmZ`QLKAqG4Rxk8|;vIYjaY2lT$ z9aL+{4BScTqNWIQt8eSGHY4967&C5w84b0iT+piUJwEenOOrF7K3FXVmc+{ zDiB6+nxip}pb+DTM--W30_9=I7py9HuPA(Ne_~>pV|j?J^GMtN?+O_z(1C? z$ zw;EL@2`S;QOFpHC8#g-xHHU~e%a-=E#r#^3+Z#|0FWRe-Yw`|VxF{jw5rsldp3p9_ z3nG3^Hm&cHIocE?!@^-x3pd9)G~sC*7PdTFh~{D4=(_2I+&t)QC`P0Vmf*K0s1IzJn#ESlT4enADFIm^N}{$( zxI&4<&zj+1h~BM#)&gKiP>WRTj14p4DVgd|v;9lCAaZEh{zB|CvHK^};)Xcxi+MF{ z#3n_HLGr6g7uRhWJe%kF>-Hl^XA`;B&;9sr`N}JV5!P(om^5VdjMiarVt1%SzIaqG zdy(4hXWGLUaYKt!6{*PIjZg4g*c+c;G~uPPo^kHaXxX0|c;(a{m&bI^ii|GW)iiU$ zJ*u$%ooa@${R>lL-&Z;cae{qBbWguz?gmgh9klQ!3I0R)86eE#vinI8`lix5q!w_K z)WzqF$hPmD$>u$ItAR*jyY!<1l`+$UQT{?Iq^Z*u{lC8H`)i+cSUJr^M&DzhB$RbL zs}72E5G;*h4(|5T=kK(Oq-BAMqmS{+fzUCr1*Nc-Z-z?q2i$9>r&$ypn&e(Q-djf}hadf7|zZt$Q%(Rh2 z{Phx7*mJOh0nSHryie^6A|n3S(lOYP^_1|dfIcSMiWgz%XXkQTKBAot(|N!Zp&D~gz~+|p00behI|^P_ykBqUo; zn0EA5ZCFDGOrTr#ISJEAiiHr4{ML>7_1m0!6XSa>p9%CfUXcjCK+u0wo_spsgfRqC zzgn=++}M|lt%mcGIns&T=xs~r95TL!hTNnv70Ful^W7Epl3`>Rgz1XQJkw|6S=Hy| zH8D_nY`_TFBLv(hap<+Td9g)+Rmy0^1OP+FUBWFUM#*3jn3{qWeo1*`LB1Ec;|%=$ z#Y~3b7vj*MrP}K-6@})nrE1GS`Yb53MkaLY`NDB6bx0l|@+L7E9A8}-(jD_`k)!hu zS5}R;?Z?Tsw;J?~#l6OT$eUInf4XB>5w7ZTl}w~K!JQ(@`+Awf$U zVBjZX^x+GH9@#kgRb`@f48?+BN^F?H-ays-p$iSdTs3T zYT2UGXw+xN{BeUSr7ET+li7eB^h&vv_V3=QJ`QZJ+T3uj{Ebi5 z-?Z0cA?-0#me?L$I$h#iCNkgQ3w(PIVD1ZbICJO6jayG?;$;{cmvDO-@T>&8 z#l_^u*j5GQqsVL;N+7g6K1zxx(CX} z@>{Kjn42#shxAcj=DJu7R^c&JImzwM4l_~WBT362uL!lBoQcaRlV(3c6|jKrUsHixD~;nj+Dx(P136)3 zUx3RNFR2V$9dI5@Y-GmnQjsna49#CwVW`%xSB+_P4LIM^A|V1RmbIWKgMeGC;&a_jojT^gXT>J_uW z(iP=4c=XJ;)J%rc2Q>QX>2Q|b*z|`MrN2W#XoworY&yW6jiF5*;c3u^qV%y zA<)x}Z_t&4-Fp5!-R+|vzW;XFf`}LgqUuvq+XfuXkdF+bw3en(e`!5<8lF7d;tE?Dt$DTV$v%^)rBOgFm z;0AY0iGX3@8J&f3TG&Ro5tUFSo0;wLTz4M;}x75 z;PDpvQAcJ}9JcvJYE3;NP_wq?gxH2w1+*+d#lFOEq%9Ygp+6IwVqxNjckp?a08lLw z;)7{Y&5#aOxV*t*>vZ9Gp+@47IqActWBR8MS%U}ZF_D^}LhD*M&NpmrF_Iz8ci>e*lgJD(ze>K95@Qt6nm`%Mp*N6DcY+p2MpB#Wk?~MPVz7pkNk;F4Yv@-!ITjaloM_* zAl}ek13$plkbNe!DoCB|cCPu*I6M4dZIJtW-yp>R#S-{dmC^YFx@Lo4j&WBl>U(%e#34v;KU!X*Ba z+48lwl6R@zSvX8Odc7%Zymcgtzo80kFEedP>$978knP#V^L&%3rMLQSKmBxZW{EZV zx;>RKtU9^=fcZN8b81IWd~%(mTxjj(N(52tOL|uZr3xFvT^*p6Rg156ImB@sBWw5| ziQA9ysr#FDh&Dh5e6!BA;=m_7Q1&(fMa~>GxAaMBI?CGx%6(+|01&-xw4S@h&6ruW zr+w0u2Nb$aD4Q-%%>4(sh5>t@$)5coQl{D=<_m}!J@7@8~zTy_?ZL} zy9E|@!}iLA$?x1?e`4nv(QU{%Owbb&H`^Z8bSmItQeAP3Rx&UhzTmISdA!NVX|}XP z#Q4B}$N}VYaW7Q#c4M~xt_)&+Clz|sLMu4` zxkjJF;KePrgDT3KMoe*4$rsMoCO99;e9=Zq(`1EhYTZO12M7*T-^KkdSti(M@ zM@Xm^6VBgU$a`)*Sl;)%AitT}3k@OZA14-!Vev5b(M~vN7Qs8Pultiw1rDFLj=`b#DuaMy_?7B>u zxX%mhUVRKqLjyVPZlCHdm8pNE#5d~F;=rifKRU`5}bZKp|gRnBNK($P% z?y5J;I_Z->dKNj7*P4Z;`5 z0SC*UFMDqsjsf{wsJI(`H&16NjpKfJ6wm%D#`QD8%hu`f#?-A6yBs%*uCft$heFVp`vgH8;!jJ8^qOCP5q0M z=Kn!TdKFN@?JuzAGy^E1BjcK0X#v#qc`bg!@+DDfk*ja_AhwM1`cZA5;%Cm!fCt?3 zIoKMKWU`wNF%jXM{|mpDOHFjou&z&vJ_ttF|oco#$DYtH|AKpV(S_{2X~m$g>|k_{N^mWHZA-`P$Mi1 z?cHo4Gdxy55s!8qk9?CERsKS9f)_-&?$73^tf3g_;EXwyyM!b>eZ&ps-kh!pr$=IjzVdaOHR%JIsEM z-I(;qzosvN!bd0A3{J$y>)L$xVO!mz(%(mm*(KwU)`aL?&{b9&!7b1}myt5v__;u? z0LG?b17|QxBbY7lmRWeLy+Qc;dQcjb<}cGMh7@7>0n(~abPUAYV&vR1y9~XZC)xmm zcw8P>=c}3PH?j@aJIrd)0^-eJ^Kaj{gAnt;uIb!nOs>n>3?mJ(D&JOA~5AaQP$uL1~hwNw5X5xfZ9OqjIf!G&(fpg9i{A|I*uExtY zot4d|%;wd2f|0(u&w(0?pfQek_4$zpEWC%pyy53Qenst=`&1ERLi5PL7eN!X5VImP zCXK%846spOqQAIYFL>__CDg@7o7WkNRBrdBtH|fE`WsgRF^&V8=zMgl6ZDOrMzX%N18+FIl|1XOO1LLI5ffmcZu!s* z8d2omN$NbjX@OmP`1wr&kyJK^qhWVicZzDy1vu^^Z@n(Gd4faL50uGEm)1F(AXu8% z>mIwZsN8?m81GC-(7QGBlisu2<7ixOnFKsu%Xnih9a?+v{AIh72SAG_-y`_REuwz2 zHIE5R_>{e@-VHF-)27sRajcBtw)W_*DQ9CFdIn2IXD~lAU#ZZg*1iA?ChWMq1hTNY zSNb>)Y|~{^;BKf*>FS>B7WX-#J)DkEE@D%CJ$t@)%jNlEyb+T|VY3a-705`zM=kUY z<|Al8OlbHyaxR}?zlRV==`&o_Gpsv@3Pz{d3#&RTgUvrb(ixND|D1JeTAM@#p82$`hgdvUL5`K{_K8kVCdZcZYrnAkS8Jk1pC&T8LU&(Qi z7O)VTRi==#4UiaI=ILDXn{r&w*d(rmS@350%Yh~!D5(he1QU$i^c>J)Ag4xpeHn~~ ze~JQxH@%F8yjy-T~Sbu>ll zQxXiy?H>ZIhv4IR62``dD8)rGT;HLccQg-%NT#ZchUfWCb zJ*h?(Fx0~^VA22yM$Pj;69I+O`>WoygZA;-Aj05>@1tb6o2aNO3m)lmrq#LOr&pZ; zPdPqUm}w!Nkj}I^0Fkhtp+sD}P^DA19^7eII{b8bPB%F~8b>YKS;)Qlj3j;C3y$xM zQ(2SuWbHO@SmJ7eZGkEmtc`aSQIqqCdt1qQGeBdW9xKj2Yc)Z7=wdUVsT@9pgh_5n zEhZM4rjbl53KN~^|K`MclA`N`#SgYtq$>VYk~q0&fj@-F=SXitXM^MKcaF`HU~_8p zMIk!p?@5d9ojC&9G6ib0j zhxv3`0J4FM8r!U1QExOK_}bpx*Xrp}yAFnr-j1dU-N>6ohGahS&U3HrEJ&RAA7t}2 zj(o)dB~(3m(#9QQDCt0kcIbOiHnQu{4!X(S6w8}qH8!y!XITFMX8XD(r!qYvmt3Lp z=0J2tkaJ(Jr0>pU^G^@6QgzJDOOz^OfrmAJwwB@`$1z1h&l4I9{38MBGhB%jK1w@7 z^62EETo#E%_8J-+&78Uok z1nR4`xHjK#u$Eq^osoho!b$ntnltm|X@)SA5EMhTA5Xo$z456-ggQq?21#s8&ZfOg4tUzfsW9s zp~8vl%rv{Ph+dvq1o-y2V1O0-JL!7wbF9Z`+Wf@cwP_bxw;7xDPcC{1KSC1N>>=IL zzNf(3tfd7w$w7`EY^B(VG0G+3mV0-G2q$f1$waL` zpcQtFqRTF#(X@P_%Ad1-i_n(ZVtop;(y$Wt7+`E-D`mpZ{yA{?{<=DbBihDU3(ih2=f!qOK^9$( zrIrQ-<*gu0BysYq&=k4ks~SHA6uxH2;vF#&UU$RndSpbo#QpgyJ5#uFv+iHA{-0P9AL1dV4&JzscsZg z^i|TCm(`p^#~d)%p&|lTqsUu#uBDyu305|?I>;Ufq_YcHb(z9!=;b+JruZyb(yH4F z*af9RC=6lJ#qGr=}sOshPi!7&9M-3YR4aLR|noPR&@|P8RVE^FCj)ooqgdYp} zWPSumXm@KCXOu&((RqFM6m`J45pGYvE1PQ&Fe!MC-rTj6OeQhvwlaQv1riJ%`bZqC zW0ViskGF&4VB#FN2wZbzPDkZg(b!l83KFGXYe^KK)pE>jB#Ckr&ns+XE0)eU1@uGL%J!rVn1vAoUMeh+?&VS3sh+d9V59}k;3t5lb zC5c|-HW{W>ePPnjjuktPAsb6l=kZyWX)}C2Qs)I3deUhZ;jqt)$r3b}Xw_p$5_Lc_ zZ#eglYUvzT)|sDhVBla9FzhLDb~H&du1L+DyZ{RFm1e96vYe5qQ5z5z+wTW6PvC^O z3W-BQk30-(%7mqNEwvxTIG`1HX<7>lmPEKTS;JN_SXQILak0BzF5HUk^6NJw>`m7i|1#WWe&|myT{?dGZVz+wAimBOfe`a2%yWRAYHTy+S?iA-B`BFHcl>mf^iO$g z_Ot&=i1bb6!p)qKLzjboWF38tITqhf#w&P#E>-3u2&%#h>tEJChobcdaQXv^{RX{N)L*B za@kk$HT9+e;j?9rL2UK`42epUHC~J{l#8)70;|=JZuAU~us+065@vE5P49>5m}oD) znl#3D>HfrFNt~uuh<_T)Fma_sZ&r=VzHu>8&>w#7xO-<;h$&2^Ot}wO>SnlFFh@iE zjm-E^%gc=cOG5P;p@u+j0_XsT&j= z@e7f;)|vi7&$Vapy|q$7FUwsKmcw=9I)3Y}qO8f!Y7KkL+p%^P@c5_YxW10|q6KqU z6=Ka%Wy;zz4ayqLA<{f7LL;vNO?VnR!NS1h%;gP$gPz9M)TyJotwEeBwMxg0dXB@i zwJAP)?r-i;u%0WQa$bI9N@5HP^VbYt713l!{RGil` zQJde6>5-?r;rJ$op{g(%541Rsy6hA*`jT0R(#2r|3$$KYp=w$YA?=!ieYpT@??{Zf z2)&76j9&pgy`6rRN*p%^#&XydX9=_fWTT3P|Kexo^$|A#s73z2G5Yk*9!+K(wmRX{ z#58R;XQl7L`5whdzn7=e2soIlQ)xfVH5vTi9))GQ_6a*jg z7QPwEl5El7$tdA+3y4b)$BY~ z6q5_bqOm};FVp6{acq>u>aV16V^f5%J=Od|>AmnA(7=%}uZNVTKQx^S=I{Lv%!?F5 zBV`R~g^Iaz74>jnEr#K&8f19b#7ZUiT$L&`P0G^i=;5Z%@e z&_*hG{7UxEE@TZf4Tf<{J_wJr86JyIX~vKH3WF4kyD2`AB_ zxcLd;@kWu(9AVinIEE4|hhbevr<#I*A*8H!jb`8xr94C{lFK>E?%T521)}xitnRFx z#=v^iDF^;BL!o}Mg!6-WF5W4sk+IZIvu>zzQ>WeyjQsnJ&2lSP9n5NYJihr3{wach z5`|piAuY2qfwG)n$Cq!!2{5+5(_4)O30RY%Lij;m4M3tv=nO-|gJ1a$vKXupJ>RaC z))A4i@M=IOm;YD#wG=9wjhH3ER68d9wD2zbmbv9`_G1r6i;0N`E{pgyT?75l;?Mbl zXuqT=o=4)VkMfvRm@~6ky~aEsEiNE5)~&!LXV@~FQ{QIcnAbbyjP^9I>{7yeNw&A& zK&;oji6p0HZp>wafL)9W6^Hf|)*O{1(SxK(F+mVcx5AR5lZD-!`MHmbg3kqoONNRL zUaTl9N9kRy*jJ5k+`J4aNJI%PI4gt{)OD|+g)WZ8$Azh5nGhC3bnHe&{EO8{ z>R)oEJ_(f|u37q6BVRoJD%znNsNHz+EQQI=$MJ{k;r6DN*N{8SF2MSha>KD6SnnG; zE}uqIH5+3~KXF+WI7LXBg>5ix$I~|?0fttZ*@&UqOM>|O`f$$q5H;L>Inz`R|a8|W}ZSdvZs&5ZMh^-ST?dEzSRE+@;cDL^=Z zrlX~EtOg2`PiDm?Aj`93j@`*S*}8E!45!Jte_c=rEZgi|GdHOjIK^Rg`Vh>&uVmy` zHo>@`kf=mk_Che<B*V{#EQ!j!R zdR)R>=QcFB$W44WbF*>t^Vssmx-0^%ryGM)62Ir*D`lWVFdcWgQsU(i2eQ9T+;$r)4%OeCq6JdG*6$gRzC6dF7DI zXr2#_OzP}L_{f*onZ)>8Twr}|FPDV`HyN&Z8>}L zis8uQTYH5v*Gxg}^H#bBXKydv?WEZnVLU~P*^l#+k`_1`_pNG-O_<`u>ORZ zNjtBe?qyloCdk}23IV+h3tpz7xiypC!i(TnvZWIQ%O{8Qq6>GM%&)0W&rO951Q(W6 zq`fVJHoyVBg73^zbQzc35d?j60u<@t&EER}bO6k&YZ6Iba;3Y!ch{GXY5y-*w2zn` z?aQ1}e6jub!)?h8oD6PuploSkjc4AnIx46IDrbbhSI7?iHfuGv`FX@y{T$!weu*UU z@Gi4ttzM^wGQGR0OJW@eK11p{1_1c!2+4OClNV*#*vfXs9iqb?G5(}HO)70OFrq3% zulXBh*f#f3wS42}XWFQo-S0%uR!_|~-;S|34>|axa2Sqw(W48_O(MF6W z+2=@c_I%MM9vAMnnIyXsT#1pR@lk@iee7sV`movn9p`*1;Jp)#2;T2*9omh=mL8-E zahRFZYksxocA*QQ*YpR1F$t6=U| zN9U8mQm-o_bQXau-ZNHge{Bu~(H6_DM% zn)}8$tZ6+4)e2FEI%z9}%$OAKWTN{3$vc_Iu<2O-ZSZz=jM|uI+^Pf#6;^3iY`W-^ z*4twFBWvw~8moYm0(-V}4UMa> z+`3>pd>rWL*!Fyauv{^a?VjzDV)4nLZ?cC4zg+aqvtd7T2;%!6A}%QkuEmvU5H1p_ zwngl(F<$0@ZDLz;_$5hM0)&(+`Btz+7gXnIMIpTGzICqced#({)0a(NJiD!*`(vMbre*iFd`z)2Z$CEHgB#&y;ji4dy>(Uy805$p7`bmYjn6)`ToWE-g4`S3kILj_H?^S~kZY zfwkAT*RjWhYI7|>_#XCn%1W|=&iwvN>BaH(;QUy7bpi-YMUe7oa<`YfpH`0!Y$O-+)3yiVDhSNLuwCp!o z0MXx@hsM%~PN+gpk+-I0A|k)WJOKRRxeV?Q2JY$0O2>03^4Rourlu7(re*x|nPCUU zGhc+Ol>iMFI@jvR(WJ-m!crB0$pU$c;W z&Qcz0oO#^k{Hf1HrPTHHT4xPr4TyO_%&mc33`*L55v)ep_tuc(Q4c~xTy$30_~A3a zLn%=t`_J4Iw7}srB|Wp!z3UBuhF2W$#%okTg}@mrlbWMas1<2fuGLEEhknUlr=l{? z2po|(mF?EOI%=s&KZN$>5n>k8i!l$#?~k|`cAQbh$mB!3Ry0>X@vcS~MOU3sqwMAe z4<|#Zk34PLQaAsxRY`rqwUge3j{7Eq|g7mu5buh8A9CEkjCg_ZeP> zp`P9segbjMeRPr4h&CZ0&F3%8m`n?m~kw-mCOA&uDZQB!tisWF4?)TUGj@ zBnusMe9X7<8oH*UZRJ9j91!Oudq$;vG4Z0(7q+KBHH9g-yd~TIo|tl&oGm~~Mqs3t zx#WV%1u0wCiS|Gw67GiI;4Bj;W=t!6$)U_Dgq zG1r@g64peukvyBnX-_5-$K#|XDS1V%nNToR)^^mBaBkL0Jy>qU+J3sI!{(2cKfE-F2mS(|E4oRECk% z?wA=-`$Ip)z}yJ)Btiy&F6M4zcyEMKQ!IuO{;q-E{pGiKh3x$p$_b)F5sM#J_A#nN z7CJJR;&T12jSM@hAPQkWk*MK|7Xs4I4PNl2AH$-gB*=P>=^Jxl(KNVx9wnZ)>Fp|j z8)F?-gYW9;tc$HHA%6BRzf9i<=Cyas zA!%0b{8foh_5N5SsIsZf56=xi!!3pu6SNbNPSX_IFt8;L`B`ny^6zOD0N6CPudPbC&V1+~Jl*-;IpG9s2))qL+`T%)=oQNr{6 zVID%G)4$(^@Esbhp4^9(IVYh%N84g`pD#4Oez*^g`R@Q6k6{B5haE)ndJu*K9$g+(V51~~P^bvs z>1=ulf=nArm|?Cjf|Tx!;LXFD>%f;~0)wY%*;9lIug~*?LTB3C|9C*?b|Axyw-|OH z1rd1Vim)2RwYo%K)a>0HFmaizJbx*Pko~&Y%0BKLh@y(0bSy!8zIpiUvsdz+NSk#6 zh8nT|kBTS|_|B$JZc@cx@^!na+{a!&d{VrX`|No-W_sEu?D-6PJ9A+i%7M-ufdtGz ziKwa35*BS3tG}W9hgc^<(8f+c1PT|eSOtX}H>{*(H;vY%Vd+Zv67^$;zcvf(B zns~n1>|hKSWzj3QDrBDcZ5XZLZoHl5oOe4JyEV{`!e?SBI6IB4_eMETGi_T*gTAPG z6_uVGW-?vyson~(6h~{l8~HTZflbrD=FpGy-UD=oAETuvuW<36xJnM0%U^Uj)9Q3v z`vHJb@^$$PlLgM`UjtLb0rhnbNx@&FK8QIkuO24pk#f}ctt2lK8(s~$^HoorCI-~s zLV}}~B_m>S6x^_wNvZ>*o-o7fWnFViq)BYYz5lpbo9&LiYOmBGbty~uoh52;UI?eM zElQl7Gt%v~_1??ALh1K^PC*?n5zhLGSU1+Jr)CqJoXX!VH>?t<6CBNg1x^2bhe9s+ zp`4)DIMzQFO`Jz9S(2{S!@G=c)^O3}r_JMv`wG;*o`u`D+vPHB7wlrRD(W6Tdh^sa~sBUqT zo$c(z4iw@En&KKZKBlyudBZ(S{6Ed}LA?=$H0&pMlE3&=II?CvfWuyW338Vtk@}D0Bnf3vi56w99nt>% z(18>cR{!K)_S}BOb{msKH|}%Z9ck0!V{S;Y$xzQMGojjp^%ZVGIm>=UGA-o6A``u3Gfw>!8*T9na`jcG$VD@jOm7Hsj0tzalhtcTG zt4ge#g7k#6rK%VK$_&EsaCjp3A2Sr&Hc`A*Mb&5NR1ahW#hWj~$mDz`vwWOM1=*sc z3B$B1!DQ*|{UUxp#M zCH!ws(?oYK`$TAYnMde$BXo3h%xA?{HHE_?*;xW&671BR=HqGAvJ#Y{@2Q-s`PHP#M79uP)9;B%_8)}U zgTQwAWMPp^37YOrI-Gpd&CL8DymD()tRHJf3mKYFYE1i0i(sf7v1b-G(MWeNxW1hz z1}#Z5H3r3H%a;5P%}KSYlz`W9WX3sP7M{QAYtf#%LozddsNbmdc!^c2os+lgbM%IT z?`Iz_GT8N?B}ohRrb1k~&-nr_x!Lq+&h}vQp`Ipo2oNI3*w`rftw~t1NegjL{4ez; zW>!?*v)m|L!9oUBl)?Fli-HEvDhvuMT6V$uVuDB8EbR&5Co+vwQuMDm4d2hk4BG@wsftl^t z1dn64oI$}te=J38JYTwR`bRvGqq#Ldc|xX z12o-uI(Gi<6LgO=11ml0C6}Fm!UpoQ(SU&G7JMn$X!$7rFvZY873R$5=}?!XylUX- z0WhHyXbY@+ndAusrPQYrIQ%?8DFLhYsHIhEn}(CUfiC@$s{`bSilDbcmdzk{p1ifI z!1HIghv>|pgkj7-juhP>-ghf7{WKCfZ5u2E*&U2T{oc}5uk)Mehd&ESzuNZ0pJka0 z9X`|2Iy)};&1bpV_$B6laRfNGP-HS zx?@=~7f~Axv&lW4ExFff|mjDfn0~jPt;b`Qw#@7 z36$?{1l~eWkh4^AGY;Bir`-`Z>qg@(dRyRn>t*ikn6`V5x5+(X{wZ}TV;=mx9gH7S z?+039@{5u72$boxb;yu{xY0#t6?18$&h(fhFTc0xvog;n7yi5hkAmgEq}J-;v5Kzm zPBxWjX#3b3S%7$>qGS2qzdTtWWyR*HXh~r!hVB3HdFfFoL{89PUY6PKAPtsvUOMRC zwONnK527ER{$XvDg(;)%IH_I3^PvUDjE6{KQ4MJ7+j3w%Knq4XkPZe;JaOTUsbW~v z?5=LU;n1e5afQG#MY5wwY5I`alK38vuRXE0=g8{GGVKli?V2MZr>V(Ol4Ksc6=NY- z6nw+~YTPF>_M6$Xz=`OjdFpRsBExCqXT(4M3!vUuC#SODn9_GC0zYCOWi%I%9oz#u^qB)9~3 z2=1=I-QC^Y-Q5Z9?(XjH?gR@?aF@BnzV}{D&9C_}@~Tcz)UDfn`s~xad#}}N@2k5e zx?)t4Afhshsh!}#l*|8Wpk{ojfrGVK&CK@8>w)5a@?#NUBlDCFWYSOY{tAT5b@yrd^#WU|t$kVJ1nlW>4qU3O$x9>qQp){`LhWgW@G54? zFhIl?Y12FC0#3bTyd(@8enf%+*dx|bXw(nX861?>1a3& zNLTBmXn#*Z>=8~ej%mH-q#xUs@j%2yjd8(r~b88{T;#~RDT!iIK=LN~yI$XGP zL2qHkwaao2av+Ki`8f*VNW2~n0y}!Zeafkiz^V3G1e+2;8Mg@Q_BUtP|$fp+3 z3Fx*;WE+roXey%~{bEnt*$@Gu}&iX=` zesCWK`PJJCoeI0CYdVhEE`2fdnscK*#0NqR^f$!n1q8w2wz|9`$T~6zB!56?%0;=^ zFA&PBN!#n0W_AVh!}e{i13()us85Y`5-)dS@f2}XT(Q9Ebb^&k2QomLE+Dwauu2eH zq(>W}5!Z%_MjD(gWj{RZ+r@qs!OlLVma7Kdchn01FrQV1#}KMl~8 z{jhMm8raa1#Eh&Lo8#=h{)JVUAu6}FaJo9SCArJ|aXo2f(Q2Z;#;@`*J)ZUD(NQyA zP>=09wVtu}1uDEWl?t zjDTl=|9sLZ^VRR2d*8B_7ZIMbHPI%*cL#_J%mO_?E-L{rt_62NfI@+fkIl)EkzbK} z?lB9y4CpL6rd*A{cv(&rTx^wc?Jl^KavZen(Xad%@}Lt-`b?bIc2hlzK8YRy`k{WE zi7KMb6Y6y(~^Czar~05OhtwWK#c=()@bPMe6YtMuB-dzHrG$Eu*xU-aCB*$>H%G6I?p1(Y<9(a~%g)j~3dl$aR23&wyRF8c_% zDMeCRkBgE*!(Q$1Og!wAD;i%}q|L?1I0#l{hzM!L_7#p1r4KY8$NPJ@Z5>w+Fd z%2N+g7ViB19mhiBm-HvXDDaa8>tD!6p!Ey#S;l=vlL2rZwKQ~M7V-V)?lcPQ)Y!tl z?srz6=2F+83cmg-5h^)A%HIbUEcb~xS(JUjW(J!AW?=p#^3(xQ+g*N6b!|Yd9>bAs z)e39#54W$Uy&obMLj?^ zw+4<7b5c}S3;6f?H6-~ zinN(zZ`K> zm9GKHsZ8%@6@Fb`HeXz?LwD_ctVIToIOVNBTB7<)io;X zv{K6NgwY}k)t=!fif4>u_Gd}S2A=aNuU;+z@3#)*Zx1eV&MSd_Q!vR}WXBo_#AFEd z-V*fgSB7qM5Tj%Fa2M5dnZv5SIpBuc_Dfh9!=@v*+KSQHP5sUMaHKh!u@Lp zHPP8;a(-5(?-m{&X?*jk{6mPwMfgQ@Qr*dBh+Pu~NSZ+LiG_C-D5UaQy5l}Zl@RWF zLS1c1?Y1&Z`mTF-WWnI!Kc7)=#R1gDR|*J4WSjMTN(&huF|prfBm*R9UZasB1rqjt zrNb7Vv|v$>VVh}GMto3MYG4)^DIMq26>}p$Kz;L!B7J?Vu+$)u#uEt>S-Y8saC09O z_owoMg+R|-^9R(1GCx|r&^0}&)*cEma#jmU#uwVaL!&=BnZm4h|4k_KzDy;|W5q1M zA1lsB=2OPktZWHGZ)Yhvgw(p=3@|q)BkOd*EgVWpG*)Pi0|RGciMS^JoB3!>;MD(? ziR3_%ta40u$_v907QZjz4R#wHWwo4zs(LIE+O&-bg*7Y;jUf`wEAQotv?rxtzDMqr z)o}A);CHE#>@WCT%>kEuD0O3&HdArs1!rpXfzAm9EWv8$3Vwe&VZn$n*Jisr~z8Yz*e>@LT>L)IC(CbzJNqXy$`n|`J4LGmi1fQ1gzMuRK6NalQqMbR`@#xE*K@xlLji}sBtvBo1;k;i z^1eY^3Vf_kNuqi?W~G#Tny!!vxvJ+p;%z_aiwfejQ~3oX6uZDM2=WN|P|dQ82Nu$A zJUKoZ^I8N(0nR7o-qQU$X9=h;X1wxI<3F@S$2>l4(@BVg6Q@*pBcc#Kt6tE_fQeaH zPFw}O9>mbols8@be!7{3pi7Gg^XO9seJgt5KHEu3sZz1}2?=GlZD}k1>@?C6S7EOR zUDH4CX~Nw@T}~~^i3vH@XMX3~q;o8cLme!dEpqIXF+ju<$C1M(sc4pxhjOCM1c&`| z&Q9*N+|;2*O3`jl`ga!A2a}g#aDL7l#LZe?ux7&SWnj_A_vpT{j`q+fl|%*af&r;t z@9a~+vXX0qH~5TsK~$_y8`#|9kndCRwm-usKnxqx&WC$<>lniGuo%N2QDIZobogCu zr(x=QtT@n9z@J_-CDXD{0D>NbSg(bdi3a#LkqlN1KqT{0Z}CecgP&lgNoyM~&3g-O z&YHhe(CrEkuQ17WaBk+!yoqr#lYWVD!si5RS#L=Fa$DTYP4+1+Jb#{J+*=1Rze!(I z0n!(V4Jp`Pq8#2K>KV`FU!oj|S-6E|d$8TCh~r6!!4=i#Q6T1HD2(?(9hb zKb>mab?1lmG(ca&gS(Op(AQ*1BBI9c3Y{3X`D|Z|eUIkvmEhf>wXg|L+7udAwmwTx zB3o{lLOlM$<|d%PhbYeJf>^=j478Z=mL~R9>uaC>9N)*@2PM|=F55L>wut8gbQ?tj zJj!j(xJl%r@imb;lU0gV?$}Vtv%0=0UcDDsrnJjV(GMY|_}ae;ZrD?R=nZY{n)S4M zXh>_dc$>mi^6xi2K&q%%E|(d!eXPpZXb0}lCGLtiZZtG+drun-v+lnHL-{e|ow-GG z$mdzHQ-}IDfNzPjdNW3QLXC- zj?~^oP7@zNEG**9eB03lrIo4$DN0k&W9f+7VD+Jk8{~0AqFDd3?_ic&bO%OqufKFO z9nNvmyBJ-)+VwpUX~OVQbSQKut(b>fpURZrQgpsly&6yKow$2OW7WaG$ z+b?Yu8%@^M!z60#k^wps^%s0h|D^H*n*{pjC}!{`y+C-Dg05jTV->^iM_iBYg#zWR zs6(Ws4vqdRH@FE`eYcMlgQuo2H{%5XnCQSM}9d?&;X@E@^aQ zg#tvU3@Hhg`)6qfBsQaCsG)T?{Ekkt;D=$r7&qPBNMVjiqHf)Ur3jq`V^UTp<;}eG z8cg~HJ8pIHx012X$8CT!0%u@|=;`f4>=g>JaAPf;6`O5>^(J)gmzx}i%6|Htv)Bwn zkw-_YzAS8ZN*H@$GEG>sgcYW~ktoeHMFWL9fXYn83%el;ZFQIQ>FY9dM6@o3=r`AU zX>b{Uke2!Y5>!e}4O{+c-md?r5#Nbm*|d@1su{VhtQ2~&hcBaaQqY@H!@2&elGyR? ztV&!QvTLh;U_wN@(4$3YqVpfV1jp!&FUbqj;bTNuXC&j3k#7sp~Mgj#H;g z9?$4nx-7`j3$vbW03sGzPJoDII0wMC#NZu6e7Ivfrg!;(N-^6hH1vmendzTendEjY z04z%Jn*p|Q^5Lp2xasav1I@}z1W}zjfpLv zi~IZ$8%*g?`w4Ynes3yH2BX?WSi8^>^^76g?I$lr{+f~OfkI-314>w=eRjbuO#WCV z5(!kFbTsY1|eAZdX^AFZ{aepW5>u6rT7^E7Ee1897`u9ABD#>&bz z`@REe-}Rm}P6hx~QxihmVOmeU(hJg-W1<=xPJlh#hv-elRBMs4_p>IOoE?W&d^c0A zxDzN78~du@fJ{qA*BcmPb;;*h*aJub8S$-ln2JPe zKmpT(1ee^mrw@iF0nz|oNZ7dQW}CdsKv`0BZpLuY-i(E%ZL^OwwO&c2i}8!F1ztB> zbcSukg-`s9q=S#6Wyab&G3F|O-XjHu*onW}pKvcqReyWyiHw401jUXm?8n1L74(Qf zK$Jm%d&W^bpqV3sl8j8dLnQ#HgYw163W5oct%r8PX`mUnqv=;U{f88KxR{*8aVf9O zNgBRT;I6I39xlh*#1H)+>V+wi4###+7SgZc4j+HfOTcp#1)s5J$~)Kf!A5c7L3)h9 zH@~cmAm`$GC%gS-qaqg9Ws;K`?XN%1-Y@x4AyePS&NRi+y6;=pZNPFAOUs-C(;N5P z?U#|E3bsOYSDf$fKy|EY$p}0=e9RY%03_9Q_7Z>hEGC(@H?_=8yXkSNgZiS*hMp>` z&?UbER)i`Zo+C@QoFXI*nG@0_%sgJzZ3HgMx4*I!7;C^jI5<)yFyA(f47AYEipX8+ zFe|SJ1`(+?xxW*Zih;k27<^=zOUs-J*r$hhI)BPd?G$1y_&|P*;64~(SX4iKkoe{} zPT1n8V;n*av>f@$y0njpQFj@PXE^I+^_FxfMg~0f**aL?Y)d3j2{}qX^gXa}kx8|> z3ncNi>D2k0zFn02ytv9xnppp8Xf+BxKA|*X&506)$$-^^AptT1zl0&j-GgoRcM!m? zB>-SlkWi}IJ^|zq?g<;w-qvA7qThW3ZY$4;Rd{VQ&X)qp%#;zg5fAH zyfPN-+Hns`!g%K*b3_lh4Y!VlMElE5a`-cVcY+Yp3i(3my=nEIGO?dj-zcvCM<-SO zCd_c&0r)fE&P|;EFk|&^%$Vv)*Ysa#)L7!{8Qp(^)%d?542^8pufA$B` zSM{%S?ha*fuiSl7mC9g$@=MPbyzA1HIPVupA&bGsfSAkPBrE&mEASp6II3TWG7An^ z_-1B*P7$D^e|AhercPr7o=e9QW^GM(wV@{zi!kQeCNh0_c%X%$@lkjY5E4gCA^ z0DN)|5?Ju(8(yHhMr zWs@$Z3S%W*)64ILQI1Q^+aLq(IYX(b3C`Kr+H24`pW2P(#dcHqUn|ixZZE6^8r07q zyK-7lKVtTASAEr- zt!qi&=7fxT_vZqY)FMzP<=LB_Zp1>nKRY-G;0IDX}9 zr`&21{rfJgoZ%tL2i_x`U6u3+g;!;oF9^uZ(k~h%RhDZrRKssYF%+O0mDU6_%w`=+ zueAHcSr~P8i&c2%jKngFivR896QlQc7qu2&rJ`nwVzYsBiPDP`-QN#apA3R~N08xSi8OFbM_bLZ9 zYTOTr?f%yvp0Ji3bb_|ag+MheY096bl=i?m&D|ocn;K>-4hh5b24+wZ?aIEpDalm! z{p8sHdXoUXqEid-og^W`3JxCsi9r})dkP+%Ed8%{N?YKK=!@5FTEot**uRrYjQ~_& zVPF@}o0s9=6v~c$TPLL`Cj5V`n#to85DJC}7X7!`S7IP^$|+^4*^s{v$nEmJ z>fag)k^rVx%dXKVMEW;GFhC=8Db%ljyZw{~u)tTtHmNWF3=R!wM4w3=^tW3aR!G3Q z9(T!7__wan0ga5fwLkssmfH?Ka}z%8nz!uVqtXMG{*+fY&fjjm_X)Ur2w9K(RsKDJ z(WlJKY#@K0f89>Ky&t>zk9Gf^U=oMBhacpF@4s&U*D$}`a*i#cR%aJ*Q>ym1MMp4@2%L-wwzvD0da5CsL5vSnurxJN%hO--xuI41NX zRKnWFfxJHckEa45UY0$5c1-Ci>{PXGEe=`Q@5J_m!`0Y>l_H1y)oLM)0xw)}QOtpt zvrx8NjI%eh)zKaT9vA_YOsYESkQBkvG|9*9>p<4DSeXoVwgLcjA|(tUS>k{GYvtX@ ze6c4%0O=;^_lh5bCiTy6zJuK7?{tj(2t}PT!#84>2_inG#VTpCvnAZ?B1XBBEzZpr zxA3va2w@VY4QC-VjMXZX8E7LQwCK}ovksOHuXL6zN-tBw0eez(C^6g-O;c1TBhZPI z2Gy^oSG;OeV+_m9)N+cz?a&Ey#Am{W#DKaFJ|ombV0#g1@5a3j)Pkc!zd-MCHZdzV zp8n5&U;MYWX1!eD2)MBmAiDF6{`w#wcStKBR(EF^^nZ>E^v$>7sD=F7HzpGRPV|4j z^1pU@n>7DB8Qwx*?*F?cL+?YZIdyV0r*yD2ytp{UvNW_n;+#-DbA;~6{&H4KJqRPq zGEFUCnhV;PL1A$-KSt+J?3$6WvExVR`_5KSd#+{>UCBBfb7cm4VdmkJS<)RhUSVqV zdo(c;T0?v~-48M`L|Vs4itqL$wtK54&elI4O%+0RD?spZ7QgJdt;Rh_Rw|8ei~U?+ zuJuGQIZAm-pk1|l$$Q}ZEaL%|ICfB68j(+Q*IX=SL(%`S#Ctp3=w@KBL4o~up#9y~ zOS5P9r63^-hXbZENPbEVtV}TC;H97{n;YF$H5KTwDiVspm(YA#mAeR!IBqN}0aB7u z)f-A$-nIrRv*4NfAWyvr*ObEJax7 z;GQA3{t5c%h|n}ZNLXZyHaOk}YpYSxe<41mVI9Ul=j31wv!y;kP=Lo_i+tMV2;-rl zmN!N-YFEDGSB+hrp8r^7#t*bb9i$E-UU*K)8lvfh6ly@a>QZ?vOrypt9c2AG)C7h< zGi`)|&Oiaufrf(7?Mx`l(mp~{O{Ft@oa7>@npo*JKjAixP{je-o3+eV!;)D56UyN?D-F&SIqgnkj4&yYg>*?XFuS%C?r?dL7SRi`LhLyxq24cL}# zawiVWPP?q58L|GLbLg?)PsDrC-w46Ps7qG~-aAq^-`i(yGEHPvb+v31=H;PXgOg2F zx}=|pb&fU}WcFcAge}LY%^6i)9Si0`l_;C{I@F3b^A#*)8DybS9b8tHJ2A!2Em5Rf z&AcBNXN!0B5X}Q0v>2(cDC>-#`CRGZjY_5A;v#C}hbmfSGev$~X}Vz`$FNAqs$&)S zE8yn_2hQ0^Wn0fOIYzaVL>v&VNSmB79-f}Dbq`G1(j`jc=;oxzSED)MQKGt%s-?2SdrAt2E0KE{L zb!%EcqZLVQoq9L8Sf6s5WVylLcz^-SyJ|gV`S@|&g>gYa%W@%;Y))QsV+r-Vdjgpk z4F^3gy=m(ZY(v^P;MY*KzJdx2S)oNoGy}5wA1eX+mlY*~H7yA8nO1mMgW!c>EpU&P z;&uENVYgaKDi2v_CM=2y8Fx>ae$WNyU0jmEo~dyvFO`6$#8j>DnFZTvIOkJy*|NvB zj4g2{K1m8j#1Z~U&-O=um!0B*5#xrn zX?re!9Hq;+5+NnDWNrNXHyOaK37inVxz&s8`=xiI|23a^XIwgveOcp`0lf{cLCgJm znHs~r4fH)4_456_Vu{-Uhk~Gzp+{wg#B4Js_OVFzT*ke;HBvBRj(AbpZKIfFFR@>dt$qv zD{*EXGS@rq6C`rRnGqJ2^WoC&9}-okz6H8!BsU(x)at1Bl`tD)`g<=6B3ZCDVDK_q z4gGvFKJUM-tMd%&SYq#Bc$|*=xS)?4Av>~v8QtqGk=QJxB-YtS2Spho+HLJbuc=?aOBednsF zQiL(nZ4G@VpKJ|iI_-5uUZV%K9DNbf<`L$F$2JXSp)p~uy$PP7-gvwXT$3-j0Hk=4 zY##FLnfb?b4nWKVH{ILS)M4H-T5C-lAH#`sH$gqq6#C(1kbi=T^USVy3OVQSz}m!s zpKyAPven~wwQ_!&9QL=129)||?&+Hj1^HR~J@u-w(VUO3QKuSiDf#2Krap{1U(fV& z2(mA)?{)k8wh!ZutlDtLSoM|H?^ItK$2}d}*Kil6gA)=K!js3PGoni|TFw1wpOxd~ z*(#ADD?Y@CYcpUC#>vG~nw1q^ljTpq4O-1O?|!G^rgY!SW$=XC9X?-YeVM=&6C05) z@Z+kqfp*|L+`zes@BH*Ds`j>m05xocy7}&Cx?oc7u0yi+Nj^X{hMb;hg2kCXQE77F zn&stcIkeFHs-Zq0l<)LrOkyM4s!&73xHocJUYSDS+~i=n1F27*GYpvqxD;iESckA5*3!NTNH&-H!%G?kx@3)HwEOjxzMF~7bUbte6BF5 z60T$~51$B-JE1|+mow6ajAtX-2l9K6(FhRZR>8s7_crKD#+_Lxl?C=x>zppg+r1yN zDjrc`80byV)zV|sF>N$Up`WNP#XOC%P;@Sn-_I{`KQ;tSZLv_S=FP6TZ{0vC|KI{s z)JXDrIxL9B&9OZU>9Js!rPg?|!&Lp_!FEt|SR?m!WqN-Uk@mpwfqbBakD@olWsLJE zt!l0o;!m?Zbq#<5Sb+N3OPXvQm9UTkT*JfTjWrz^)a26Kc_1K-bE}42a zJgHhrEeu50IH{flLo%s1o~7Lr>W}(BnHXE-H_@5T>;YVoKy0UhTsi^Ul@$r4?@i*Ky)w_@XSZPF5KYbxCAppago{G(Hw`o!&H?I)x+jAKMcQc5Ooq; zKD$csnTxFEH<)M0N4fWXSKh0u42JuWR3sv&tXDs`c?)mBDldil%6uW38o{5u4^1*C zEKAsOctBSYukxgG#B+3T2M`BD3fbWQ;t7MGyhm28tXX%k-~!bmmzr_me%5;uuFyhF zKrXd|Ch2L{rBww! zVrR{z%|@#zLKVU4zvy%OIaMM{1Ed*UMLEHsq&9n!)7c*ai<<1C69t=SSGYq*N-*+J+`;M3g2$(298iHrs{ zd~bE9sjE6fP;Qm&2sNxXuMy7*?^;Av-qW?OT8Ijc;d8i=8yLvlmgxZ3fA4l1$+CPs zNzBz!%iWQm1ZG_=&xrP*)b{zk2RszUk~jG9!ir9<);-lrkfLdC7Lze@`bIsR>&ua( z(a-RnFt&@R5;qRj`gPxHteNQ^al;YD9qfds+(5k|)UqVY)2c(3TdW-e71xxwYm-*i z@TVR5j>Ek{w_U|S;Rrqz4anl%Sl}m-T5K)$Uv#E6;L)t6R0g~0YP7b5(2wx+(}O{< z!BZ#~orbN1{2!(|c#hW)AJ8#^FWPdPoSR)?XL@mPrIXmrHgssl2V8JDi2F+D)`+yv zm|-51Mso!{9K!J)zVr}9a?<%VXHb4M(P<`EOFWHNU+Ch>Rx%Aw1r(zahTZIB=(wYw zvaM3p$#5yIa#PnoDRgn8V`L(e7_%NEbp`7S43S20EBKZI!~UCOJ0m*7gnS5eJj=t@ zA=3yJv}{A(a5x<(M(3Jw!w>29+8QR-62t0msgR4@68wy>LXu)-x?~|P+Ix1!;`TU} z;PlF6I<yv5O;rD(wmbO+0E6RK(3VU1duK!7 zsx0JnIi(OEHiOM(Qz9JX9_E>q&1wDAUGPpRRDXJ$0%M?!`6sfF0zD7FX5ofl&*tP6{K9fLn3Z<# z#PyAu-`UfmD+l!3LiP8*quKW9coNKkT6RX2t~)mz6qmGPOG`}-C_QFhZMkJ4QC<1& z5b=8NJ!-IKz#+)^OGdX)j|H8Dj<&a`nX+%`g|ibP6Mj zVm^6wd>h&AsB5YB7fjzWr>xW2-sX(e0X)QhYD&k2uwMU{cjSv5OKV0=c2$*n1Xfv) zsOP>cXvOfDmTfyawanMWFOiNkg&6(0xE-rI>|Gph6ROb<5_I312LCleuI|oPeXQ_< zEy-AP=uuk8pK(I0ar`n2eEe~1sEhF;#eP8efwIZvxh zz92WMp*JIQrXG!|wSqPjz-Uh_tGK7Rg*R-w*qha3V|SeNfHWL{V5r6fLOE9@+8kkM zZ{nwy$B~4+&0$J%pw|G=my^Wohb@)Ki~jj228JD~E;_DX?!ifo)}B% zD>hESWQp~Is}`kXc%j=^VCep1VS#MR)Y4~j%77;uxfib65QN6%u|4_0t_aM*g%QYA z59I*owB3r;usmnu;)tdt*>kZ?T1@dk8d_7yB^?@r*5;h^39r$k8VhSu*=VX349e-w zUcjM@w`j;4zGb(MIpN_9s%TlmS_8sG4K1!Ru#QMxpI!+5;rJTTk@lX*a4jHJEGamy z=FnyGNBSJn;;jv|i&kIs4DZjkMOl_UklW@u%`^HnoF7I9ER*ZECo23smE?kJ87 zw1lq$n3oHqvEIKXv^V4(gId10?EGt#dF5w8K9f(Om7@l=SB}U0StZs!DvROCcUSCY zrBCnl8BPO0sevRrRLBKF-^dsHcyRh8#G9Tat&G}TF zrt!noi#lwKmnXkbt=OU;!}E`%ZsGe%_J^HmtpHqPQ)FD8inaJw4ZfKQl@moS7bE(7pfbyr4{g| zf4YF)>ogXo5j#wsPvaw4c6F6;bzQJ32F4p;aI8m`j)L}Tu#6+Opp9Tt)fRj`2mm9G zy97C5lUC9=f>>B5?Tv+1!VYjRvusAAFnP*Hr{{7Y{*l+W-QRR^c1EuzATDWsf{((Q ztQIFTSi4d1(|6ig>wv>658zG41Wie0`<*n3@rxWw#6YDzbrICXK~z~+BxS|XBbh0eB`6Ud}^!qjml;|0+A@gZ$XeisC;)8 zRdW?L6oO{T_C0b*BQW^l=nS>A`8oLc=1-0%b2H(Uh%_1K1V2+c?KtX*z?s7RqJ6ya zVnC_ZLr1m^-v=<*C=Y}?W|n0&y{S&sZ*?Mi5u@N8>bv^2cazf;%zz&W5Jt0jUgEE% zpLJG8xUt25UaK_&q17nsF3lwvRegYOZ>R*}C9Lnr3ZJ|jd;KIJo>+cD>NV64P@)v4gjG8zayD;r zxMjCHhGiN{G`rjk=9AF(4T#o{J||3xrau)~1RGD)MOh)?;xhQsWmr+(!M+nM2M zt((8B4ocw*NqRTfMCJ+|Is5<&z7XQnYA%Y_9r?rK`x)+x+u<6(7QR3KJSS*m&Srk8 zF)xSA>?o!{H?RF!pE8zMU0?p8bSti`?f%0;1g73TV1QBR#$Qdl9(MJWs0_u3=(hCP z%lN|F1~Q|?)=(BGs`Mo`zY2L6p)3F91Sx@mJ)jeQ@SUJ4MdAJsWX~UuABq?6YJ+)1 zG6~d44J(P80$QS6)VaDj_wRCGS5zwBmQyfox_WA8)8>f3M|Yf|!{?tTu|-3#Pi_X; zz_?~PrZvuT1J_Wj5iDC?28j*Ti~uC$Vc4C6#8Df&!h z8f~QD3l+#7ZF-%egVY31jcX_ngFbnk(BO{lUQXI5`4;n`umLu<&?}xNj|n(aGz7ye zfSB~GlzYDdn z*z$>8JnyK|_F9JA!^EHbT81#T`FaD2(+^$GdjRP@nn?f$IC|yK=2dnFOq|f09jCtH zKsd%Z$UV2hq zk|JG04+~kuQn>W;0p?|w+NJk^j$tjR)XwzQL30ILTBGDAwe(27;*>*1X>LlCm(Sb^ zP22S+)^;?3Zh2Db)kN>yI|+8hpKfgCOSMp*qpJb$tMkR?eUOYApW)mdmaJ1_^X3us z>*j}IOTW}-n|O+sbsr3J=AcK&EH}>*s zv3lb`p-dIEVG-VQ6fRey=8x*-8te|ElQ&rm!?@`loi_N_GnMmi7E#VBkjMQTC(u=} zNEdwqsy%f<8^QRp)tvglVPS8Qn~sdEL8~VjPjC3vV!FmgOLaw#05-ei8sE-oN4n)$ zJ5{9#2zFxVE3XY}_3lIybaulf!c&uW106{m2j(>t;fmC@iMSKCIOSYl$Jgs0B0IT+ zI7GORhan8UZHuoDL~`aPK1Hj7iWxbXSqm-rMID$K*Wg#Fh~SlWs$BEmtuY*y-Cqa- ztR(trbakQ<)j7-fdFwPt_*= zSL7_znbzrP%e4asa7*6PC?-|;tM_fBlF%Uw>^bNZ7leFI8EfL{15U3Rz9K`s* zE-i3(b_A!lwJ@08gv*DPOO&kb(475!cJG-jS_Yof9Lv#Ei)YHnLVLXdBD0WAtI_?0 zM3&#s*iY1pkAE{pmvs#S^<%>Kp}-$F_{xvlNxskOPPgH&oy)8@_4{CHK5q=zcRWG> zM8+#?6qPaeh4gQl(H2SzvfdivuzMijWP8FwPhK3e5uRuI_=mm#Xq~|5iEw`dVE~$; z2Hj6})o!Fy{0HW#0XQ8J1U(ior+;5+sBf4X%JGNDujd1J_D^q8ENyOG_kDsWhBP?w~>|h8v#JQn)KyQzfs(%0MzXV6fXRRPTm7w%w2Z#qEh*1 zS3(iI8POStRaTqy!)c=bQF)S=64bVsWrdq&{`;S;%7H&_MuOC8QRcM#Slo4jRMNj9 zC>B9SFX^MM{kOB$ZuNymw;o3FL9x#7nSGEnL@_Z=^I=O!@^a*ATR5uG>!t3RTyj%O{f%TsSkE@U|bJUW_ysre{AHRNiPd^9ON z$aQ4CCA|sfQNR}0KthE3^_)+-P0lU%nttt-UI8&TPysw1CbJ0Y2~SGPEV{g-6CD(V z&yxdMNmd01yB&1-#Z+=pAT<7l#@ae48X^8mzfsye%c!4eMSYz<GWHmR_qOYuJys zgHl1z6A8F<(oHKTsR^>-bRG?Ksat9-oboXT zpC{VQ*saF1p(NcrQJwu<6%?5rZm`q9oC^@my=J$qhkf=1>Acr&Ku1f6(sE~py2{a% zV4=cnLIiujA#HrmgEM`N;`f1ve@azyCPqY$Sa@Z*4i+gzU|vvK^_*JmQGVDIaE`e6Y1f_bzE z3^Y4wq8JZaac*H(np{~5?l4scp`l@RJobwp*7id?@4l1_p2F)pHtsW3J*<>mPoT4X zeIH3N%4een{~_MISK8Uz;tHPwqNmRfTGUKnM2ooosS)&=lF33&GZPIBps z=OitCGg@GxYVNh!QZ0tR#+*)XzzM@(82GYdxq!2IW7-og@(p|#*1^Nf$E0=_(U}DC^XAP6yjZQrTXeHZ(M3o)3PiuSgHPYdo^5K{L9VIf9>LVe;IKJ@l2Z z>3$hI6%f5Uq}%n!rCyIVKTec-*2EBQvQ2^~QCIYnGp2ThHS2kZwy#&q2URwMM}-F7 zti7y4jfHn!lYy_0qeI(&8q}_3wlTNpPcTXF#3*M-#OP9C19DLt9qhoD7iU_W3o8br zfb;Gd<;bp5;EjbH8P}p8j*{b2QAvdhyXQ}Iy}U+cvO~u{Dd*Si zW&IFFLY5YKL@e=1ePU}69J=r}Yu|s4j{34vH>f6F`lQUU0z~to6a(ivPc8&VfT7A& zMesR|BiM1?@dkGb&zO`LXPH_T{u12|!&i;thp))mt! zb^d7b?O2ucxdrp?q}T(%ikn5(7@XcCnRfVw#0ZtcN^S(UjA;~J7K@q=}RF*;EM$ap$O=nQdt_rOmk z(A|mcMoT^;N8d8=JUbOJ=zO=t8ajSWV{gff-YGtDqpM9s6H4;_Qe#;hX`0-)lPgHK zoAXB==4dd-*8^FQxGjNB#!v8`Yj#2fAuZ=Rslmdu=$Tjh9uMBHO!&Xq8$3P>5lhjL z&#K?jq8RyJ;Mrwuuw@X;o>l5+o!`T5ML$>Mln(2X_tb(?l~Z!J(nn0}BxEV4I?Z!D zf>y`OO&1pZhr?NN3dl`X+?hTMccLa%N7J@5YjVD%e$8rAl5d}Hi*E9AVT}u&m3UZl zQg}*2+yZZ458f@`_Q7&ajyLNy=>ux)K>FrC0ysg!RzMk)@@ab!>)VhTZS$9d+Er*h|+$r4$V^bK}mDaQ4>k?Zq8S?`o8Gi(shDZWthZ42 zN|!UlQEB;VTv#hx1=N$a@3=hOzO~rawI#vitU(uHA=)2{^`+=8(7|TrFNGb3bMZkALe zeEy1Agvr3xlC8uL5E0=L?c2 z+efkSW@wfqS_D!<+aA6LbLZ!ZG zMw-~Wskm9ySJXorvhep6plIEm0vmdnSLhDdl3ZFTYv4(E&x7^(l-I+aZ5FIen>*m2 zjt+g55y0}0X9BzLqttys5t6Lv{)(4b1STH2N(wW6p2+YTNIhF&!?cXyjDG$db6#0g zaQ(h*`Cam{1%JFHkbcPS;d?3U78y@s37~~M;*skpuT=gol^1s)N=GySXf_ivbn6z|(H2j=>0Wrwh-ZW6 zNhfJ-%znqkWxA4er@wQ$Miaw^AuI=OzQxd^3kyYlLEAc_P7u}Cy_&p7asP8f-!}vz zt!ewBpak%$l#0zH&eW*g&;6?C#-cg1~*cg!%tU83hG=ioS~D~4_+YiG_*+U@|J@W z7Nr%*^(yb^0t%xJjyg#V@PAM z3i@Pws3+F6?gffnGtFzNdDRh_+Dp0D=3=&FkD@RtD#-i5PTw9w-cI=%f%Z8)ArEo~ zTI|V)DWP!@lvC>djpVCQU-n&NMxO@Bgcvas5cM3;Kx{@Qrj6{hnC>)7ISuCn-61 z^3euP_o!0}L;kqzVE&aq+&ygA&mOuYn$S{_ZY*!BP4}dVWnm){$T$cwhuA^QXIZO@ z*6XFF7;T5dqn)Iv zq2i}^Z#i{I^?kBuL>^!IKwZYt?rN(pL#gMO5H&d60c~8fh41v;%~SI*T4^1MK$~eW zB+;l;ZNGNyRsBhKaR@pSNQ)gG2r`)@??PD~jDF-i3ordP^SH^0pF>BBs8P4~gomG9 z9k@6vDc9>_TgocO1w=$Hibc|!TnX#zbjC^W2n8G05Ep$`(rl8)EWM(U zVy*AN8(u(1j?^~JlC7_txTG>`E!vEeMyr@U7P6VgE;aZxt2S7Ht8B5Fog_hu{_p4Fm`hT!R+w?hb{7 z5ZpbuySqDuySux)OIL30eXqMmfAvrQbdCBtW82whowL{8bI!H;TnU)6a`XHbBKa)T zf}Kg5gfx>6lcdW9|}#c}0H-BnElvO2iRLg9qFz+4B9$?y{g@sdL>FIni%?sl7`b)FfRS|Wrf@z-*fOJAqQ#X8EcJ7G z*5>bVSKy91>WYp!owYn|8djZA1moa!A0IO}%zOq`IOf(oI;B4go%$y-ifeBA13a0n zulaKMu1Sh-KdBYNqk4>&l-3r+R$=Uzgy@0w>Bo$4^4PJL$oQUG{M$kqEf*^k3fH4- zEOQ&#>WlDmgz>R?V?zv-z?n%=vAnB`uBOEVUqEHR?(P?i}8P95~#x{iQbj|UOvJ)=NGmC&Jpow4-d|M5KDiq?KyR;{|2JMd0UdWV%(+}*f@Cms0nzlyjiO`=672m3WTx*XPUTOhEh3;Ewr z?mhPXqca$l-u(fm*_7FCBR^wn+WX2&8muZ?w5H)-Jik*By5?%YpyjiRYgi<*|5b%B z3`JH@S?%q2dmP=z9ojor@wMrCUL}Xw5xxlw%%X9REs|+hHWHP&Y$q?LcKo9T=Zj)X z3!uBQk(jke^dPi~o0-<-wA@1PB=5Gbtwkcs{K7HK>y9Fa2_6s!A5`hc8Q%0=U&_Vd zN-@YN@evOPHyeh=klL-=#+DqPtTk;iblSyB_YZ7oGbki0W~~rQVfRAedT`f}a~N?` zt!g{+yh?%`$y;HcxuPmnP8?ECa@3i1r4oL@u86C@&;z@bGx+wWo&{=9TR)xlZ2%L| zA3lz?(<%BjC4aZuEhUv){r$}0G>wm@cgo${euVueni@&+{Vkl&F#6LgIO&g*;mafM z9MSaduL}I3o=mIiWQ8Q3)ncYW=F5szP<39n_*(-W5fO>LzLMr=I^w8&d{-n+mL=24 zC564uyCs3)Xy-7?s|!K&9?3T-Vc-U)GNuowCg`Y;xu->SBaSam6Ww^cV@(m+rf&aQrTzy_c&d+VZCba6jw^EXdo;uHBr}_^rfM zgy(m>Q)xR_Ia2hai89Vo6WJ`DT$lcBY#;Dsa!RMcR9l*@&nac%Ht$zw{$oLl;lk8xMhmC>*sXHgVV6>W$P2}`NYrwM_R3y zByJvUw{OY%5if5Zu+3ymAV>ceoItej&mcMvuj6MgpJSKOa*fGA=~bkz_?eY*65SQl zb>oFI0G}!s;b#EaXOL&?Rik2;oT5%j2fNL?vZi~|27@fakOdiS_Kt_K$ZrH^fxP|S z6j7aGg&wA}-0zh*+HlNASOS5n@=fnhV?_1xyTRnn!F68Iqm}+@r+hH?JmjRAk`<2h zccvyxM?&psFX&%ppF@YnD%zM}-5v2VKH%1pBp%KR-7Cx#xNdcTB0Z{xzDFK1(^BBl z6L9MNWXN0pM0!>K`ZL$?$9yu0g9Ur$F|?mQ|EZG<(|hU8{9cq@R|A@;{gZwP3f&Gk zW)QQIe3Jz>Rf`&Z6^(wx7giN_vO^EP=;x(}5869XB*$sgRW3iO)pr8IMfK5Mpp9)t zX11dM9)9m_t2|PEP)nD8)_2xDw zUsOST6h!C@%$+w&3xm_#NK7*g{B+Dj$vLw-ec=i(WYAC)S>RyHhuSw-8m$ce;^p_um3_hqmgQd6_BKkYMMm?K>YLN)g~^`0T8wF!bm=2K0~zI zNwA;|J+7><^qnb!Q^MOuWrEGz1-m1zi835UkCkZE{j&Ouy#g@@+rAR*TJd3u7%#!X zTe){U8?{-F)iu*$X#i~25`6JHw0q{KGm2xu zPyHB5A_H2VKFe`s)Z_djre1a8nbh20=B4&H9=~qW=xF91_lDrzKakQN?V;^%=)8Qi zN`jc9@e%2l5ZMl?!^ey&q}rBwg)*gESV;KdeBeQB1rt=QCkBgz#q){Q_ui(Vg3r5- zGx%ujHt0#3%AloHeM=eu?h~`(KtSY-(I;-}C@P?Hj)L?DbJ;1SqAFXB$=?IB>*R|B zh@wudh2}8Lvu+y0bbKBtaVB%X`)|HX4}G|Ls}@R|)WdY0jdsV>KML|Om=*WU=YHkp z(YK7RSa;C#6o+g?M@^s1c~d+-kE%vA-d+hZ7cS&8ac7&XR;^yDb3RP@6gUt&m1+fn zWlJJ*s0u$^*_2xT=rkn^;h_5^h?OUxZ~1G|sT zhKU)5nodJ?PB88kd6^EnnKp290>ziASNo{v;I$q7x-mfDMh}?LsG5ZybJ~s-j}yw} zwS3s2-4SD7xnQ+SPbg<`bFLWQr1RjYJlXJQk`!EhcFNxrk$##L=f*%!D*yxnZM-n* zI%^T!B^BtWoP~k=2c=PVpqUYETZ{`MZ`>5sjytY5##|sw)o|EY|xn1_?>$C*EBXF;qGi*PJ>JjCCKj zPrW*!DcL#jL+`fBtVBX9=Tyy5%xVYKAGJ}kL>IEA94IO$qNiLuD%tE^%fr0*sxOhO z=C9}$PNHKSFfruH-FR~ANCbO1yTt10VEhPKBd}HjcMnpe?FMrV@AP*#Y!M`c|Ig@| z5CeunY1fwcEV=wSSsbmon^nd^f$>VlfLTG2(BTt-W1k$*^M(0BNwZD88EVuDguR;l zz%wG5x6hB*`!Nak<_Ck=ojr0MJ;XCskGqdg{9gADQf6AkJKJ+v!7}S?VD8FyX52dW165cq zY918){#}8`^*}|UeRA|{TDlxGb4rgWF9C{TRG_I_rtolq7n4yWQ5#s`h`qb5&pCK& zbWW?FIfFPcAnoR-c}YmN3U`9GJ!E~?tT)H)d6vDRCGP^GEZ+8vrVDC>* zt0<2xh~9;gP^n~yMjEJSwUuPrOODbj|};29EN?CGzJztT@q>L+))U>+*B{?@qvspXg3UW1Mh_ns-xXhVV40DA%?zf z9eyg3TO{Gp=8Qni(vAt{A}(u9e*z%QgsQuYrWHBWO;Ee!XP2#bE*bOoK_M7ze?Du5 z+gSV9Efp#^#>P`gGn4h<8kehn{rpmQDLyOumTYm;i(RJ7w##2CW%(c^QcB?mH%t`Erv4$w{V^74?P?N7Hqb3yaa?yuX|q+twg4 zJg3aP9BRWpDUUmUT=n`DQ^)W{f!mp5qO6Gkl8{w+wV=#)D+@Z``sz7Gt`*?+IAstL zd0v!Z-P#FqL`LZyinu^wsl3P;U)oprU0 zy%f@K(S4rA>Bd#poJOuA=@uxvfX9pWKR=xl-hdy4@?8 z-B^?XlIBuXbJ{*1@LpZeYQInAaQz}f%)3>GjG$>}@cp;1y=i2z&qwPxwN>d!UKM;C zx}lV?7-3WQQ3)Ozb3gYFmWeBMJtnh09VUW1()eO6X3g%zYOdv1{qfni_fJvPJhPF@ zUN0vUx9%b~TMx;xC%)E?Ve4b5oa1YolJgXO`<)-zce0+et-7vNv}yGTzqhaR5k)hD z%20K`nSyQBN%8NRVMB5VUn)SiCtgQZs<){4nm72eBsw!Lp*61&@&sw;E-jjB7OJNi zQ+em4gtdb$)lKK5vAJ>+EKG~YTF^JXXs+Mv!o8+bupQ~Eq3OgHW%DRr2Ne(`P2=|QvvLXT#(iK6L-dj z{A(8&2G$u2dfw5PqpF;e2LuZ>OkPAOrgX5`vIVd9p|vjov71Fl`2w zI<6S{e_`VO&f32?u|N(1SPT|3XlTA~==aZ<{~ z(O;^3S@R)eCr*30S>K&+6Ct!-!62VeK8?WF`(E{{q=#?uyWZf#Ea`u(?S7mnN( zOj*Twc`v!0-?5fh5N9=6p(^m=4x4a^wX4@sI4$WzcuH5%76VUyt)(JbA>22KE`89l zgL^^7y$9HA9=yG#0Ig8grI$fyC$q0FoEBymiyoLRW{Q4U zH$;?H<S0C&TXs6!?XUN1KW>dWN{vs5$7fwM50BU4s4JRRaKBJ!ndEmc$ajBm zbSW!!+|vYj2#F}8_GFrUY2t#@t8}%5)_gp_+@Ndg$*QoS^t@3d!HAuTL;2!tjn6hr ztdLVJdFfs)sRhjXL-;5dkJ%!gmlfOFX3eOzWe03&l8@zvhn3#eJB78B@SN>cUXxIF zN$-+V?RD8Sgw&=o!2n`2Yc~W8KANiJ=IFygnEBy0i-8B>t%OVV9p5P1ZUfAlC2dw; zMR7*^)QhmVQE;xDoHkBt%i6-8^8&@BCBEvZ;oy|!f6L7v779HYl*8L7#f*r-~|wKT2!6Twcq( zI*dhIcX!r zmbb=v$95k{Dny%vy`Bk~!UWU6X6 z!#Vp4iV2hx+izGgmWj!KVEhHGP{nkseV(ziq}{Q(u`SAqe`L_v>k*ts@9wR4G+4(r zFqqpAUd>q58gDEGY#m#KL_i?K#3Taw-L4)RVI*5>zGoaPiA2~?hi@=dY#<`GOc7nV z`fD8#OYu^(YE`$JtjgjAj+mJ+r-HWsyLS3BgUy(1f~(82NhqXzFm}ahtGfiIa5(??|6llAP9_o%I~d zK)I*AK5?P8sc=~Jtf}*M*GIo#g(QfMnDzZDCmR#4pln0f|LQa}c+_KW3MxFKseZ^U z1jb1Z%I{J3(6nr^n=!jSj!fG@vl5M#>T&ncH3?g87x{g8n=3=cn;UIj+#|pk<2p-359R0 zF32<`qk7FMwu(XIA}x>S8@Z9}caT{q@H)aPfnwXeAD?^dTGZyMgomhbCqDx&tFox$ z#r6^)dXdGmMwmy=CPGHfAVVJZ2`9nxsVb`q{h#MCmA$Q3UOq>*=P;S*$fr!*6U~Yr z@lT_S^0h!xYEgvJMoXV4@%2a3dRJwgH|qf?=@!obwzY+5ePrVGHW94Y5%4u8wGPtA zKQsu`3Dn;I)px{mMxhsGDcJdp(|Djfo7h4becJ9yP6-U$%+Ah}-Pp^ZVr%5OR-1I8 zK||jmaouw@`5oF{yB^<45?^P$i7_OltThaG?%AcrYmY>8==b`<{9_lI<@NE-DnjO{wYrR{wWHi+R z$uAsy!?U!M#d*1j{fs>I9de>ynZ`DH!+mXC-0q*}u2BlB8&#p-6PWU#e|<6pxpkKa zZ}OD5d`F7EtH&DpBw*B8U&V%9)*XF$`qfM@`4N6`Je`YuJ6IuFOZ~3`ljP+P zvV`3*H~hXj`NFkMC-MLi-Mrg^ddo>!EWZa@(4eD=Iw1ytHK#SoX7QN{zg2-C2&cGg zMqS-DOsh(4xXS+2D1V-2z_%qd+Z~cm+Ft|7%GgE@^iVD7OVix7Ph-6FoNw4c$g7IA_5^fF{tiK>t zp^=ypCn@hpjI*oCsEl;eUtLi)#z_=RPc7A+eb_x08E?q)VwqIi>+)9*Q!Z-vb|D}^^JIzfRGJ`# zW&F^$*gMj1ml1SFj4g2TRYQXbf~=2nlfGHIV)b*eC`z)R#; zq(kI9`F_X@(0lym9tMk9Jz+4d*Jqt71j9V?MMdhOHPIXTMJ z@*7zNp5#Z}=hK_bRBYD($wnA{hN#1tY5Fm(kf-~4V8vm#4rEVz0@I{er`>k_$rEYI zNwLoQvAfa9!Z+WAMl!6MbImXZx9^vItq?=^GF)P< zvnZDwZn%-YjYitw7o<9QKtIbQ8C}xjAN4>Ia>K0GVtPAlxY9!rBjGZ}7%Acy7EwQh z_ccOcAGUmOMK$CMt||klbT!VCR{j+rnWNsEs|Qe%mh-JWsI5A%s!hGSxLJ6bOZgp- z6HXWlq$NT9+BBA%2uQYI?pcZvXy%=O89Wm7cC^U2FWCQ>4=9*bLtbsGxy~vqMQ8JM z^M~~-Iz1}?2|=-?cVxYIvP~Ef#WzhN-KKq>1Egh*#}eBAx@`8W8j|mBdg~`t$uz-XIvIp#kFbc zd#v>8_bAzY$&b=f;uQm~Sqk-=z}1SZJ2Nm~xcyp&(Az9JV^b6&o(2`dZ=Y0(*zrS>o|8H=#1{03SDg!ut_a!+wkQx2979ll za~8W*>*nF$ZHDH3sChwEQXo;cY=Ro*{XEplxSPmp{-pHm-icAU74C_qR)U z{m)vO^6A{0PRF-~)if|BJ;uv34d}Ld)nADeX5H`@C5oj7>(YhX%gHc|N3ioKHEVH> zcr4h_CPVCX=$eTOChC!iT(vzWon`c@jW{wNGb#v9ooQhLIv7!T*K1w@6OLkV z)QX(r0OU`7ZV{bo$0xS+CsaP~BbDUS`s6|03$<<8^Oxh#=eprTB2KER*GQGoQwZ&C zaQkiO4q=fr+D&+4O51nmanKwTOeD^{l^;!F1w-?T*DDL3Y!zE)3U_d$lYP=khfry# zmK>m99)r-R;sf?gZ{^0yt4EPMf?uHSpF*FLcL(#-dd3#uwN>#>q;8K(K<9n|yzVCJKYS;XKt-i( zPzTwNl=n8C;Y@_o39+#H%y&xThFFwvt=o;Oj|u@AgqAlHA{d1_OOqOPmY`ABhDS>* zgAYH?KFck@&pIsK%6^@cxk~H4EcT$uBNd%uv9kRxRZ%17nH_#kKzbfS*jeGBVOmZ+ ziXF#e{QDQ#PkClGD9jm4kc3fQe|(G@IIz5sFI`Z`$$IoFi{*`Ur|>fpZeONL+RboR zare7(AQ}wpwoj1SNL6Lyc)+Q@U;KGT=X_tJmhb#!QMYIjfnVm^!%kIb(X+-uiE3}T zcYLIJyw#+FvTaiQagiDSgi}do0AA@nHk#cQUl@~tXSO}9XkwIvm%uJ0*spySDRirk zic7w_l*yFY>@HvuFu6oDRfVL{ouyKd^=h4aP z)Ra+)v%KZpZr9ik@&1kY4~zYyCE<*iuD*I4((BA>Ys6bQ=VRW;>|KxjoZ0j+^Mvm&FgrD$ZH;g73%Imy>EK9ye64fl zFPGPFCTVYr)zu{j0nI8u>}Xs1Qnc#C(XdcUEx&70p*D`03L~rOOriK-a+vytJlIFW z@K+h7v_Zz#lUjJ1>3akKxF7}E9S%7+-b(~tKIBjXiKW6vNMp96a|$XNtbEGJdbm8_ zwz51=%zWYeap>r%+4a=J$@!-4a;15eo_JM#gDK6c6M)5_lqs@bymDhTrQP|`i+3t- zcgX2@v||H*fZ;5Hq)k0Icl;%bQ=i}$(~85oB*uM$izbGe3_LfJ8&kdkpDNxpHr;FD zH&E)CELb!WCIp7bmq#xDy7OD_X&6X`NOP|#mwjFpmnm*F2`bourg%ADx(DYh=<(ivT(aNi97?UQvU(Zz+UPsq z8q;?$W{1eoJ@zKzS*sattR&>|Tk38*!*)(=*&2!|ozzTJ@dK2hYHH53-QHw) z=eyQaHOj5#t?taG8Zytji+j9Mson@C^6iig@aebkjAYOKPF|BP^Sm#o$5$Fp3Vo)0 zojKQ*;U{0#7t_o)5rv8CL)hQb`YZY+tZE#0%H!(`KF=Yi#@>H*Z_6C>DFIH^f5gQaU;^UhsEzMok z_T$CQfDZqHAF%IBr|nu@-P$&A_lDZZvg3i&T2#SO$*wiqV_tem4e!zhrqE_u$3q+n zLAflN2gB)bsVgVI_A$9BJw;|aZb9a?LhS=-0siZ?R#{^(V-sG}AT&Ydc0%Do4Inzs z&S#KEV6$DA&k1W(!nhL}bgOr&g-hUW{hIzO@a_vu)({DIX-jX(K8oY_PThpnJPw@F z^e7mMfJ?_cv%660d<3r4U0vNA4QU0d#;8O6Q}(G}EB#e9))r2h4KVBzDUJsx0LiPl z)1)!u$N=ECXM?3{`?ySiMY(nWum51ih$^`PX0z)YcI}t()cE8Hsn!OD?*0k;Eo5ek z=v-;a2?z(f?Rp#6y0QJOsJcw|epJ34Y#h(TCeG*E6#OyIfZgI}S`Pmgvp|RUrs2US zU(oN#+?AQpGH`=DIq8F(TD(8ZXE*;k2v4raTechA9DNeKwy-4@esp#&pe_noa^zBb z<-IpqSSzR39idS}Ro*Zh#BpqAJ4^;m$(VZYKG#C~kDuoQnu?H%oQc66$GcU$TP*Fa@uVY!(5hX1WY7F9UPi#``Vd;4C?uw zhbE6acJfS6KRx}lU!V28Ah_P3eS(H-$?AK8W@)eP_C=JW!dso~b%KVX?w?DJqnp#< zJeva_Ce`e!%>kuVk=*4E)(xIh#|W*Ml3q&&rPLdj=`m#`wK&J&vo1g1aPmdWw_-MQ zzj9aWE{+n|){Af}=dFhC9^P(FoqC3sLLtVvH@&$Zuo$gGs-Ee1)eAv-@0yPxS)6sJ z_GtBWfo&8t{4hHQ)IK#ukm_Q)udvrsY0$=uvA=IgC4|N(Ks!*iov}8;Q7&N;9M3Wz ze+5&zpMN7Q9f(~AIx}~-%p*2S3lWF1qmQyep1;wE*P9}h_%)FE!}<)K2;=Rn;8Qv zNjcpXUw>G=zO%8^^eN)G(H#~7PzO<9He8I>^_36vl5Upxak?kHop~thQ%bLBq=zu!>}BF&ZQbd;mMYT$6(7RvaGm=?}-_Fl`$@7AAXdlib>> z5Dx@D%5SSdK5&S=6*Rb%7qV?3&uiCK!P|`Z;8yvgjbS%HvAYo6eY}Bkb$K;W@a|~$ zbnvd+EY>|VK`D{Mq~ePQB8&Y2V}e(%3KJ)%@B6r`f+{%;Di1;;yNxDCXpi_to6-W; zWHyNhh8Pt-}T>d`T1#3Q`O*|j2O z{w~n#v0(J8^Y*(KdU;@v7QeABOC0k^Tv6+4H}PcN=8lT?MrQ-3JYj*ycv*c#a(prk zW4w9_tI=jsEzc*92dt337`&9-6|i@}r4u+wMte+kfK~_wyd4!_cQ8n_*;s`}NTL-s z)Znc9fQ65KUsP28v~SFH>y!d9Lfq(gQ6%-h6;X|%9iXUhh)W{XZI%H%juk^PDOX4d z0zZBz>|_HhGon>l=O2= zI#-}HiG-c0x_-rZPEo;>SKfgvFmhypZ&_|_#r1$LdofsjLfa4;f|!wJTt203)k&^_ zrb3=vPEt|1!2$Nud5_{yn=vg(Ut}xd_hxTp%$*H!V%WG}Z?V$9M!bQJC&REh!|WUE zqJvk1Z5&;jB`k_)9tS+xrCg!zFb|8ib_cbBLr+-C^s|PJ27Mv5+6{1YNaOpT4ysOQ z>C+w9UYc?c&s8SA-0rqYntYZY9Py%kG^^0TO8=4Y7D;1h%18?<{SD+kp%Cyebapj_ z+@+)-H}C_}(*G8qVYaJy^9IpfTv$Nye+#ea1%wCELMr|#{Yl}6cW}bOP%xz8Kg|h2 zT1Pr(Vfm+NNC_i^Pc|$HNv`sLnue6n5rb49=cCH~+qqo36cj9zkr^xV-_GB_Ai+Z_ zK%xcz+aAJi+V*y^l;zI1|At2nso;0_`}co6vclli|1KDA{&O$?4Cy1J0vai%!hb#H zzJbPNtx3TBZ%BUc@apk4^!5JhaR^IiO)r<@GV{M7K^_3^|Ko)G|C8{4F`NI-Ey0@> zU4e&Q?$sUMLWnqDs|ubG-I?RX0Nm;Z;nS`;g+HLg8Fhrp5>yyP57sg2Fs?X2`L7H8 z_uRj0klC&MEchJ~zYT=!84Fs$co5I(-Ipg%yq;=p3qnesnouBo#WU?5Bzb@EmW zAi1|t4f}9Nhg}LA!%!KTWZ1ml!~4nkU*nJ?%Xv4~OPQmEg5&qEiF%7cUWyOGkLXw_3qgC1#BdhewXV@ur=Y40f z!0XORWYyL^D_RPh!v)gwmkMYx5tgcCnYmBDYB|ybl(kq8?5kKne(yD#eg^L*j!$iX zR!C_)M;VOCbgILNUEYQof z=@wViPr{OYKaTVdTXL%6SW6> zkqfI6-z5#eLLYm&l0o&^g-aX4^7P%-MgR0WCF}ro#B%`_e3= zno`mtfnZQMDe>{~9+uIxtDeZVzs%mIC=dUtj1V{YEO17c1 zs^z07?~2VR{*gy6!h_ncLFW87OaFd1#h|?RO6yfwBN9-mNuG=~$1-DDVQ`tvRo>^$ zUrQ6`+^y>`#GmRsvb3dtE3vmnR&$FvWHPP4T^Dqm6n>&Db$f>|(UaA}wTD}6l4Dv_ z)ERjra4oP?NWJ|a_5l|Wg5&Rfe$LcpT};NGoR)ebEZMjo61Cl6~3@%xt<1i&4>Pvz7au+TIH}aeQgv zzJ;@K7ry&D%vvt_dMRY}PfOS2J*DI@MjDG(wDA>{w*-pKtf&S-=n%JA~)c;^_|~-rs@ysSCH7T!eg+^qgNdPsP3BDKQ$@ z4hhvdsR>^bbe~?7JSJn}GD}WATOjT25|tV64-g$Y+GuyezuuWVZP=pz`$i%lK;Fpu zGKy};*t6+#SXPYmUm9C4ngy9zbx z#m+Vi2~O)*i1ja%`K{xewg@?M>>c+gmu|*D_Vj<(Kpj1pOhgSm#!b~WcxPNsd9}!^ zetdn#mAMyrrIpX5RbK=K`-FW}ZQWJ~#s#a^nVCP-N6ISv<^=8F!M*1Foto_CuTvfx zdIjQx;6b&^AVU&6>Rf0^EO1L$QmQ1A{fRn;=c}u8AtvuY4!J{$^v%;;MlgC?9!?@~ zcL}qRG;-+W*`al2l-glrN#gIqEyn1b{f^0m+gCOw$uAq;2&`b2`L?|$x7n& z`-4N3g2$X3>8>-dzKhwS_q%NK&QWc(phQL(0Nx716cO*Q8B{*Owr+9!1eq-wpV>aZ z+;{-Otgxolcp-kgu>!l=3758M$@u%M7X-#1wY4 z|FJDGNAGSJrdD-_W=E(0*_6$fe+ER-zq$t#LjSiX>nzCu>efA;JSFI*p+oytL3k& zIZVmsW-<*><1xT8m6gqQ#RQUvnfSELLKx>fOLJ>xUfv_r2kxQKclKyx3s%v?Z6lC=nAA$X^pYZ=w0y=*2h0@?Z fKQ&r(?G@^EOhxr^CW7e= 0 { - if l >= left { - n = left - err = io.EOF - } else { - n = l - } - d.off += n - for i := 0; i < n; i++ { - p[i] = '*' - } +func DumpAll() DumpOption { + return func(o *DumpOptions) { + o.RequestHead = true + o.RequestBody = true + o.ResponseHead = true + o.ResponseBody = true } +} - return +func DumpTo(output io.Writer) DumpOption { + return func(o *DumpOptions) { + o.Output = output + } } -func (d *dummyBody) Close() error { - return nil +func DumpToFile(filename string) DumpOption { + return func(o *DumpOptions) { + file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + panic(err) + } + o.Output = file + } } -type dumpBuffer struct { - bytes.Buffer +func (d *dumper) WrapReadCloser(rc io.ReadCloser) io.ReadCloser { + return &dumpReadCloser{rc, d} } -func (b *dumpBuffer) Write(p []byte) { - if b.Len() > 0 { - b.Buffer.WriteString("\r\n\r\n") +type dumpReadCloser struct { + io.ReadCloser + dump *dumper +} + +func (r *dumpReadCloser) Read(p []byte) (n int, err error) { + n, err = r.ReadCloser.Read(p) + r.dump.dump(p[:n]) + if err == io.EOF { + r.dump.dump([]byte("\r\n")) } - b.Buffer.Write(p) + return } -func (b *dumpBuffer) WriteString(s string) { - b.Write([]byte(s)) +func (d *dumper) WrapWriteCloser(rc io.WriteCloser) io.WriteCloser { + return &dumpWriteCloser{rc, d} } -func (r *Resp) dumpRequest(dump *dumpBuffer) { - head := r.r.flag&LreqHead != 0 - body := r.r.flag&LreqBody != 0 +type dumpWriteCloser struct { + io.WriteCloser + dump *dumper +} - if head { - r.dumpReqHead(dump) - } - if body { - if r.multipartHelper != nil { - dump.Write(r.multipartHelper.Dump()) - } else if len(r.reqBody) > 0 { - dump.Write(r.reqBody) - } - } +func (w *dumpWriteCloser) Write(p []byte) (n int, err error) { + n, err = w.WriteCloser.Write(p) + w.dump.dump(p[:n]) + return } -func (r *Resp) dumpReqHead(dump *dumpBuffer) { - reqSend := new(http.Request) - *reqSend = *r.req - if reqSend.URL.Scheme == "https" { - reqSend.URL = new(url.URL) - *reqSend.URL = *r.req.URL - reqSend.URL.Scheme = "http" +func (d *dumper) WrapReader(r io.Reader) io.Reader { + return &dumpReader{ + r: r, + dump: d, } +} - if reqSend.ContentLength > 0 { - reqSend.Body = &dummyBody{N: int(reqSend.ContentLength)} - } else { - reqSend.Body = &dummyBody{N: 1} - } +type dumpReader struct { + r io.Reader + dump *dumper +} + +func (r *dumpReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + r.dump.dump(p[:n]) + return +} - // Use the actual Transport code to record what we would send - // on the wire, but not using TCP. Use a Transport with a - // custom dialer that returns a fake net.Conn that waits - // for the full input (and recording it), and then responds - // with a dummy response. - var buf bytes.Buffer // records the output - pr, pw := io.Pipe() - defer pw.Close() - dr := &delegateReader{c: make(chan io.Reader)} +type dumpWriter struct { + w io.Writer + dump *dumper +} + +func (w *dumpWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + w.dump.dump(p[:n]) + return +} - t := &http.Transport{ - Dial: func(net, addr string) (net.Conn, error) { - return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil - }, +func (d *dumper) WrapWriter(w io.Writer) io.Writer { + return &dumpWriter{ + w: w, + dump: d, } - defer t.CloseIdleConnections() +} - client := new(http.Client) - *client = *r.client - client.Transport = t +type dumper struct { + *DumpOptions + ch chan []byte +} - // Wait for the request before replying with a dummy response: - go func() { - req, err := http.ReadRequest(bufio.NewReader(pr)) - if err == nil { - // Ensure all the body is read; otherwise - // we'll get a partial dump. - io.Copy(ioutil.Discard, req.Body) - req.Body.Close() - } +func DefaultDumpOptions() *DumpOptions { + return defaultDumpOptions +} - dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") - pr.Close() - }() - - _, err := client.Do(reqSend) - if err != nil { - dump.WriteString(err.Error()) - } else { - reqDump := buf.Bytes() - if i := bytes.Index(reqDump, []byte("\r\n\r\n")); i >= 0 { - reqDump = reqDump[:i] - } - dump.Write(reqDump) +var defaultDumpOptions = &DumpOptions{ + Output: os.Stdout, + RequestBody: true, + ResponseBody: true, + ResponseHead: true, + RequestHead: true, +} + +func newDumper(opt *DumpOptions) *dumper { + if opt == nil { + opt = defaultDumpOptions + } + if opt.Output == nil { + opt.Output = os.Stdout } + d := &dumper{ + DumpOptions: opt, + ch: make(chan []byte, 20), + } + return d } -func (r *Resp) dumpResponse(dump *dumpBuffer) { - head := r.r.flag&LrespHead != 0 - body := r.r.flag&LrespBody != 0 - if head { - respDump, err := httputil.DumpResponse(r.resp, false) - if err != nil { - dump.WriteString(err.Error()) - } else { - if i := bytes.Index(respDump, []byte("\r\n\r\n")); i >= 0 { - respDump = respDump[:i] - } - dump.Write(respDump) - } +func (d *dumper) dump(p []byte) { + if len(p) == 0 { + return } - if body && len(r.Bytes()) > 0 { - dump.Write(r.Bytes()) + if d.Async { + b := make([]byte, len(p)) + copy(b, p) + d.ch <- b + return } + d.Output.Write(p) } -// Cost return the time cost of the request -func (r *Resp) Cost() time.Duration { - return r.cost +func (d *dumper) Stop() { + d.ch <- nil } -// Dump dump the request -func (r *Resp) Dump() string { - dump := new(dumpBuffer) - if r.r.flag&Lcost != 0 { - dump.WriteString(fmt.Sprint(r.cost)) - } - r.dumpRequest(dump) - l := dump.Len() - if l > 0 { - dump.WriteString("=================================") +func (d *dumper) Start() { + for b := range d.ch { + if b == nil { + fmt.Println("stop dump") + return + } + d.Output.Write(b) } +} - if r.resp != nil { - r.dumpResponse(dump) - } +func (t *Transport) EnableDump(opt *DumpOptions) { + dump := newDumper(opt) + t.dump = dump + go dump.Start() +} - return dump.String() +func (t *Transport) DisableDump() { + if t.dump != nil { + t.dump.Stop() + t.dump = nil + } } diff --git a/dump_test.go b/dump_test.go deleted file mode 100644 index 98d1bbe4..00000000 --- a/dump_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package req - -import ( - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestDumpText(t *testing.T) { - SetFlags(LstdFlags | Lcost) - reqBody := "request body" - respBody := "response body" - reqHeader := "Request-Header" - respHeader := "Response-Header" - handler := func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(respHeader, "req") - w.Write([]byte(respBody)) - } - ts := httptest.NewServer(http.HandlerFunc(handler)) - header := Header{ - reqHeader: "hello", - } - resp, err := Post(ts.URL, header, reqBody) - if err != nil { - t.Fatal(err) - } - dump := resp.Dump() - for _, keyword := range []string{reqBody, respBody, reqHeader, respHeader} { - if !strings.Contains(dump, keyword) { - t.Errorf("dump missing part, want: %s", keyword) - } - } -} - -func TestDumpUpload(t *testing.T) { - SetFlags(LreqBody) - file1 := ioutil.NopCloser(strings.NewReader("file1")) - uploads := []FileUpload{ - { - FileName: "1.txt", - FieldName: "media", - File: file1, - }, - } - ts := newDefaultTestServer() - r, err := Post(ts.URL, uploads, Param{"hello": "req"}) - if err != nil { - t.Fatal(err) - } - dump := r.Dump() - contains := []string{ - `Content-Disposition: form-data; name="hello"`, - `Content-Disposition: form-data; name="media"; filename="1.txt"`, - } - for _, contain := range contains { - if !strings.Contains(dump, contain) { - t.Errorf("multipart dump should contains: %s", contain) - } - } -} diff --git a/go.mod b/go.mod index 433bcc07..da3c0110 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,9 @@ module github.com/imroc/req -go 1.12 +go 1.16 + +require ( + github.com/hashicorp/go-multierror v1.1.1 + golang.org/x/net v0.0.0-20220111093109-d55c255bac03 + golang.org/x/text v0.3.7 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..59f8d4e3 --- /dev/null +++ b/go.sum @@ -0,0 +1,13 @@ +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= +golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/h2_bundle.go b/h2_bundle.go new file mode 100644 index 00000000..e14f7fa0 --- /dev/null +++ b/h2_bundle.go @@ -0,0 +1,10681 @@ +//go:build !nethttpomithttp2 +// +build !nethttpomithttp2 + +// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. +// $ bundle -o=h2_bundle.go -prefix=http2 -tags=!nethttpomithttp2 golang.org/x/net/http2 + +// Package http2 implements the HTTP/2 protocol. +// +// This package is low-level and intended to be used directly by very +// few people. Most users will use it indirectly through the automatic +// use by the net/http package (from Go 1.6 and later). +// For use in earlier Go versions see ConfigureServer. (Transport support +// requires Go 1.6 or later) +// +// See https://http2.github.io/ for more information on HTTP/2. +// +// See https://http2.golang.org/ for a test server running this code. +// + +package req + +import ( + "bufio" + "bytes" + "compress/gzip" + "context" + "crypto/rand" + "crypto/tls" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "math" + mathrand "math/rand" + "net" + "net/http" + "net/http/httptrace" + "net/textproto" + "net/url" + "os" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http2/hpack" + "golang.org/x/net/idna" +) + +// The HTTP protocols are defined in terms of ASCII, not Unicode. This file +// contains helper functions which may use Unicode-aware functions which would +// otherwise be unsafe and could introduce vulnerabilities if used improperly. + +// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func http2asciiEqualFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if http2lower(s[i]) != http2lower(t[i]) { + return false + } + } + return true +} + +// lower returns the ASCII lowercase version of b. +func http2lower(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// isASCIIPrint returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func http2isASCIIPrint(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} + +// asciiToLower returns the lowercase version of s if s is ASCII and printable, +// and whether or not it was. +func http2asciiToLower(s string) (lower string, ok bool) { + if !http2isASCIIPrint(s) { + return "", false + } + return strings.ToLower(s), true +} + +// A list of the possible cipher suite ids. Taken from +// https://www.iana.org/assignments/tls-parameters/tls-parameters.txt + +const ( + http2cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000 + http2cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001 + http2cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002 + http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003 + http2cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004 + http2cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006 + http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007 + http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008 + http2cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009 + http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A + http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B + http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C + http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D + http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E + http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F + http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010 + http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011 + http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012 + http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013 + http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014 + http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015 + http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016 + http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017 + http2cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018 + http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019 + http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A + http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B + // Reserved uint16 = 0x001C-1D + http2cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F + http2cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020 + http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021 + http2cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022 + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023 + http2cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024 + http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025 + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026 + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027 + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028 + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029 + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B + http2cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030 + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031 + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033 + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034 + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036 + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037 + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038 + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039 + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A + http2cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046 + // Reserved uint16 = 0x0047-4F + // Reserved uint16 = 0x0050-58 + // Reserved uint16 = 0x0059-5C + // Unassigned uint16 = 0x005D-5F + // Reserved uint16 = 0x0060-66 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067 + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068 + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069 + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D + // Unassigned uint16 = 0x006E-83 + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089 + http2cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A + http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D + http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E + http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091 + http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092 + http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093 + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094 + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095 + http2cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096 + http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097 + http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098 + http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099 + http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A + http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B + http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C + http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D + http2cipher_TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009E + http2cipher_TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009F + http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0 + http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1 + http2cipher_TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A2 + http2cipher_TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A3 + http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4 + http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5 + http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6 + http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7 + http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8 + http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9 + http2cipher_TLS_DHE_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AA + http2cipher_TLS_DHE_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AB + http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC + http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF + http2cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0 + http2cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1 + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3 + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4 + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5 + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6 + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7 + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8 + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5 + // Unassigned uint16 = 0x00C6-FE + http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF + // Unassigned uint16 = 0x01-55,* + http2cipher_TLS_FALLBACK_SCSV uint16 = 0x5600 + // Unassigned uint16 = 0x5601 - 0xC000 + http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA uint16 = 0xC001 + http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA uint16 = 0xC002 + http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC003 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC004 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC005 + http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA uint16 = 0xC006 + http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xC007 + http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC008 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC009 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC00A + http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA uint16 = 0xC00B + http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA uint16 = 0xC00C + http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC00D + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC00E + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC00F + http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA uint16 = 0xC010 + http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xC011 + http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC012 + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC013 + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC014 + http2cipher_TLS_ECDH_anon_WITH_NULL_SHA uint16 = 0xC015 + http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA uint16 = 0xC016 + http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0xC017 + http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA uint16 = 0xC018 + http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA uint16 = 0xC019 + http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01A + http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01B + http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01C + http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA uint16 = 0xC01D + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC01E + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA uint16 = 0xC01F + http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA uint16 = 0xC020 + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC021 + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA uint16 = 0xC022 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC023 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC024 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC025 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC026 + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC027 + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC028 + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC029 + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC02A + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02B + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02C + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02D + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02E + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02F + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC030 + http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC031 + http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC032 + http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA uint16 = 0xC033 + http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0xC034 + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0xC035 + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0xC036 + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0xC037 + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0xC038 + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA uint16 = 0xC039 + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256 uint16 = 0xC03A + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384 uint16 = 0xC03B + http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03C + http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03D + http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03E + http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03F + http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC040 + http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC041 + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC042 + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC043 + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC044 + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC045 + http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC046 + http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC047 + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC048 + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC049 + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04A + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04B + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04C + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04D + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04E + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04F + http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC050 + http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC051 + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC052 + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC053 + http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC054 + http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC055 + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC056 + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC057 + http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC058 + http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC059 + http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05A + http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05B + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05C + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05D + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05E + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05F + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC060 + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC061 + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC062 + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC063 + http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC064 + http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC065 + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC066 + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC067 + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC068 + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC069 + http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06A + http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06B + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06C + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06D + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06E + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06F + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC070 + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC071 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC072 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC073 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC074 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC075 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC076 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC077 + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC078 + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC079 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07A + http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07B + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07C + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07D + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07E + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07F + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC080 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC081 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC082 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC083 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC084 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC085 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC086 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC087 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC088 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC089 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08A + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08B + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08C + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08D + http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08E + http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08F + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC090 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC091 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC092 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC093 + http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC094 + http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC095 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC096 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC097 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC098 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC099 + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC09A + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC09B + http2cipher_TLS_RSA_WITH_AES_128_CCM uint16 = 0xC09C + http2cipher_TLS_RSA_WITH_AES_256_CCM uint16 = 0xC09D + http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM uint16 = 0xC09E + http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM uint16 = 0xC09F + http2cipher_TLS_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A0 + http2cipher_TLS_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A1 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A2 + http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A3 + http2cipher_TLS_PSK_WITH_AES_128_CCM uint16 = 0xC0A4 + http2cipher_TLS_PSK_WITH_AES_256_CCM uint16 = 0xC0A5 + http2cipher_TLS_DHE_PSK_WITH_AES_128_CCM uint16 = 0xC0A6 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CCM uint16 = 0xC0A7 + http2cipher_TLS_PSK_WITH_AES_128_CCM_8 uint16 = 0xC0A8 + http2cipher_TLS_PSK_WITH_AES_256_CCM_8 uint16 = 0xC0A9 + http2cipher_TLS_PSK_DHE_WITH_AES_128_CCM_8 uint16 = 0xC0AA + http2cipher_TLS_PSK_DHE_WITH_AES_256_CCM_8 uint16 = 0xC0AB + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM uint16 = 0xC0AC + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM uint16 = 0xC0AD + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 uint16 = 0xC0AE + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8 uint16 = 0xC0AF + // Unassigned uint16 = 0xC0B0-FF + // Unassigned uint16 = 0xC1-CB,* + // Unassigned uint16 = 0xCC00-A7 + http2cipher_TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA8 + http2cipher_TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA9 + http2cipher_TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAA + http2cipher_TLS_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAB + http2cipher_TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAC + http2cipher_TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAD + http2cipher_TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAE +) + +// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. +// References: +// https://tools.ietf.org/html/rfc7540#appendix-A +// Reject cipher suites from Appendix A. +// "This list includes those cipher suites that do not +// offer an ephemeral key exchange and those that are +// based on the TLS null, stream or block cipher type" +func http2isBadCipher(cipher uint16) bool { + switch cipher { + case http2cipher_TLS_NULL_WITH_NULL_NULL, + http2cipher_TLS_RSA_WITH_NULL_MD5, + http2cipher_TLS_RSA_WITH_NULL_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_RSA_WITH_RC4_128_MD5, + http2cipher_TLS_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5, + http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_DH_anon_WITH_RC4_128_MD5, + http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_KRB5_WITH_DES_CBC_SHA, + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_KRB5_WITH_RC4_128_SHA, + http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA, + http2cipher_TLS_KRB5_WITH_DES_CBC_MD5, + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5, + http2cipher_TLS_KRB5_WITH_RC4_128_MD5, + http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_PSK_WITH_NULL_SHA, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA, + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA, + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_WITH_NULL_SHA256, + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA, + http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_PSK_WITH_NULL_SHA256, + http2cipher_TLS_PSK_WITH_NULL_SHA384, + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV, + http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA, + http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_NULL_SHA, + http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_AES_128_CCM, + http2cipher_TLS_RSA_WITH_AES_256_CCM, + http2cipher_TLS_RSA_WITH_AES_128_CCM_8, + http2cipher_TLS_RSA_WITH_AES_256_CCM_8, + http2cipher_TLS_PSK_WITH_AES_128_CCM, + http2cipher_TLS_PSK_WITH_AES_256_CCM, + http2cipher_TLS_PSK_WITH_AES_128_CCM_8, + http2cipher_TLS_PSK_WITH_AES_256_CCM_8: + return true + default: + return false + } +} + +// ClientConnPool manages a pool of HTTP/2 client connections. +type http2ClientConnPool interface { + // GetClientConn returns a specific HTTP/2 connection (usually + // a TLS-TCP connection) to an HTTP/2 server. On success, the + // returned ClientConn accounts for the upcoming RoundTrip + // call, so the caller should not omit it. If the caller needs + // to, ClientConn.RoundTrip can be called with a bogus + // new(http.Request) to release the stream reservation. + GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) + MarkDead(*http2ClientConn) +} + +// clientConnPoolIdleCloser is the interface implemented by ClientConnPool +// implementations which can close their idle connections. +type http2clientConnPoolIdleCloser interface { + http2ClientConnPool + closeIdleConnections() +} + +var ( + _ http2clientConnPoolIdleCloser = (*http2clientConnPool)(nil) + _ http2clientConnPoolIdleCloser = http2noDialClientConnPool{} +) + +// TODO: use singleflight for dialing and addConnCalls? +type http2clientConnPool struct { + t *http2Transport + + mu sync.Mutex // TODO: maybe switch to RWMutex + // TODO: add support for sharing conns based on cert names + // (e.g. share conn for googleapis.com and appspot.com) + conns map[string][]*http2ClientConn // key is host:port + dialing map[string]*http2dialCall // currently in-flight dials + keys map[*http2ClientConn][]string + addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeeded calls +} + +func (p *http2clientConnPool) GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) { + return p.getClientConn(req, addr, http2dialOnMiss) +} + +const ( + http2dialOnMiss = true + http2noDialOnMiss = false +) + +func (p *http2clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { + // TODO(dneil): Dial a new connection when t.DisableKeepAlives is set? + if http2isConnectionCloseRequest(req) && dialOnMiss { + // It gets its own connection. + http2traceGetConn(req, addr) + const singleUse = true + cc, err := p.t.dialClientConn(req.Context(), addr, singleUse) + if err != nil { + return nil, err + } + return cc, nil + } + for { + p.mu.Lock() + for _, cc := range p.conns[addr] { + if cc.ReserveNewRequest() { + // When a connection is presented to us by the net/http package, + // the GetConn hook has already been called. + // Don't call it a second time here. + if !cc.getConnCalled { + http2traceGetConn(req, addr) + } + cc.getConnCalled = false + p.mu.Unlock() + return cc, nil + } + } + if !dialOnMiss { + p.mu.Unlock() + return nil, http2ErrNoCachedConn + } + http2traceGetConn(req, addr) + call := p.getStartDialLocked(req.Context(), addr) + p.mu.Unlock() + <-call.done + if http2shouldRetryDial(call, req) { + continue + } + cc, err := call.res, call.err + if err != nil { + return nil, err + } + if cc.ReserveNewRequest() { + return cc, nil + } + } +} + +// dialCall is an in-flight Transport dial call to a host. +type http2dialCall struct { + _ http2incomparable + p *http2clientConnPool + // the context associated with the request + // that created this dialCall + ctx context.Context + done chan struct{} // closed when done + res *http2ClientConn // valid after done is closed + err error // valid after done is closed +} + +// requires p.mu is held. +func (p *http2clientConnPool) getStartDialLocked(ctx context.Context, addr string) *http2dialCall { + if call, ok := p.dialing[addr]; ok { + // A dial is already in-flight. Don't start another. + return call + } + call := &http2dialCall{p: p, done: make(chan struct{}), ctx: ctx} + if p.dialing == nil { + p.dialing = make(map[string]*http2dialCall) + } + p.dialing[addr] = call + go call.dial(call.ctx, addr) + return call +} + +// run in its own goroutine. +func (c *http2dialCall) dial(ctx context.Context, addr string) { + const singleUse = false // shared conn + c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse) + close(c.done) + + c.p.mu.Lock() + delete(c.p.dialing, addr) + if c.err == nil { + c.p.addConnLocked(addr, c.res) + } + c.p.mu.Unlock() +} + +// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't +// already exist. It coalesces concurrent calls with the same key. +// This is used by the http1 Transport code when it creates a new connection. Because +// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know +// the protocol), it can get into a situation where it has multiple TLS connections. +// This code decides which ones live or die. +// The return value used is whether c was used. +// c is never closed. +func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c *tls.Conn) (used bool, err error) { + p.mu.Lock() + for _, cc := range p.conns[key] { + if cc.CanTakeNewRequest() { + p.mu.Unlock() + return false, nil + } + } + call, dup := p.addConnCalls[key] + if !dup { + if p.addConnCalls == nil { + p.addConnCalls = make(map[string]*http2addConnCall) + } + call = &http2addConnCall{ + p: p, + done: make(chan struct{}), + } + p.addConnCalls[key] = call + go call.run(t, key, c) + } + p.mu.Unlock() + + <-call.done + if call.err != nil { + return false, call.err + } + return !dup, nil +} + +type http2addConnCall struct { + _ http2incomparable + p *http2clientConnPool + done chan struct{} // closed when done + err error +} + +func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) { + cc, err := t.NewClientConn(tc) + + p := c.p + p.mu.Lock() + if err != nil { + c.err = err + } else { + cc.getConnCalled = true // already called by the net/http package + p.addConnLocked(key, cc) + } + delete(p.addConnCalls, key) + p.mu.Unlock() + close(c.done) +} + +// p.mu must be held +func (p *http2clientConnPool) addConnLocked(key string, cc *http2ClientConn) { + for _, v := range p.conns[key] { + if v == cc { + return + } + } + if p.conns == nil { + p.conns = make(map[string][]*http2ClientConn) + } + if p.keys == nil { + p.keys = make(map[*http2ClientConn][]string) + } + p.conns[key] = append(p.conns[key], cc) + p.keys[cc] = append(p.keys[cc], key) +} + +func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) { + p.mu.Lock() + defer p.mu.Unlock() + for _, key := range p.keys[cc] { + vv, ok := p.conns[key] + if !ok { + continue + } + newList := http2filterOutClientConn(vv, cc) + if len(newList) > 0 { + p.conns[key] = newList + } else { + delete(p.conns, key) + } + } + delete(p.keys, cc) +} + +func (p *http2clientConnPool) closeIdleConnections() { + p.mu.Lock() + defer p.mu.Unlock() + // TODO: don't close a cc if it was just added to the pool + // milliseconds ago and has never been used. There's currently + // a small race window with the HTTP/1 Transport's integration + // where it can add an idle conn just before using it, and + // somebody else can concurrently call CloseIdleConns and + // break some caller's RoundTrip. + for _, vv := range p.conns { + for _, cc := range vv { + cc.closeIfIdle() + } + } +} + +func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn { + out := in[:0] + for _, v := range in { + if v != exclude { + out = append(out, v) + } + } + // If we filtered it out, zero out the last item to prevent + // the GC from seeing it. + if len(in) != len(out) { + in[len(in)-1] = nil + } + return out +} + +// noDialClientConnPool is an implementation of http2.ClientConnPool +// which never dials. We let the HTTP/1.1 client dial and use its TLS +// connection instead. +type http2noDialClientConnPool struct{ *http2clientConnPool } + +func (p http2noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) { + return p.getClientConn(req, addr, http2noDialOnMiss) +} + +// shouldRetryDial reports whether the current request should +// retry dialing after the call finished unsuccessfully, for example +// if the dial was canceled because of a context cancellation or +// deadline expiry. +func http2shouldRetryDial(call *http2dialCall, req *http.Request) bool { + if call.err == nil { + // No error, no need to retry + return false + } + if call.ctx == req.Context() { + // If the call has the same context as the request, the dial + // should not be retried, since any cancellation will have come + // from this request. + return false + } + if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) { + // If the call error is not because of a context cancellation or a deadline expiry, + // the dial should not be retried. + return false + } + // Only retry if the error is a context cancellation error or deadline expiry + // and the context associated with the call was canceled or expired. + return call.ctx.Err() != nil +} + +// Buffer chunks are allocated from a pool to reduce pressure on GC. +// The maximum wasted space per dataBuffer is 2x the largest size class, +// which happens when the dataBuffer has multiple chunks and there is +// one unread byte in both the first and last chunks. We use a few size +// classes to minimize overheads for servers that typically receive very +// small request bodies. +// +// TODO: Benchmark to determine if the pools are necessary. The GC may have +// improved enough that we can instead allocate chunks like this: +// make([]byte, max(16<<10, expectedBytesRemaining)) +var ( + http2dataChunkSizeClasses = []int{ + 1 << 10, + 2 << 10, + 4 << 10, + 8 << 10, + 16 << 10, + } + http2dataChunkPools = [...]sync.Pool{ + {New: func() interface{} { return make([]byte, 1<<10) }}, + {New: func() interface{} { return make([]byte, 2<<10) }}, + {New: func() interface{} { return make([]byte, 4<<10) }}, + {New: func() interface{} { return make([]byte, 8<<10) }}, + {New: func() interface{} { return make([]byte, 16<<10) }}, + } +) + +func http2getDataBufferChunk(size int64) []byte { + i := 0 + for ; i < len(http2dataChunkSizeClasses)-1; i++ { + if size <= int64(http2dataChunkSizeClasses[i]) { + break + } + } + return http2dataChunkPools[i].Get().([]byte) +} + +func http2putDataBufferChunk(p []byte) { + for i, n := range http2dataChunkSizeClasses { + if len(p) == n { + http2dataChunkPools[i].Put(p) + return + } + } + panic(fmt.Sprintf("unexpected buffer len=%v", len(p))) +} + +// dataBuffer is an io.ReadWriter backed by a list of data chunks. +// Each dataBuffer is used to read DATA frames on a single stream. +// The buffer is divided into chunks so the server can limit the +// total memory used by a single connection without limiting the +// request body size on any single stream. +type http2dataBuffer struct { + chunks [][]byte + r int // next byte to read is chunks[0][r] + w int // next byte to write is chunks[len(chunks)-1][w] + size int // total buffered bytes + expected int64 // we expect at least this many bytes in future Write calls (ignored if <= 0) +} + +var http2errReadEmpty = errors.New("read from empty dataBuffer") + +// Read copies bytes from the buffer into p. +// It is an error to read when no data is available. +func (b *http2dataBuffer) Read(p []byte) (int, error) { + if b.size == 0 { + return 0, http2errReadEmpty + } + var ntotal int + for len(p) > 0 && b.size > 0 { + readFrom := b.bytesFromFirstChunk() + n := copy(p, readFrom) + p = p[n:] + ntotal += n + b.r += n + b.size -= n + // If the first chunk has been consumed, advance to the next chunk. + if b.r == len(b.chunks[0]) { + http2putDataBufferChunk(b.chunks[0]) + end := len(b.chunks) - 1 + copy(b.chunks[:end], b.chunks[1:]) + b.chunks[end] = nil + b.chunks = b.chunks[:end] + b.r = 0 + } + } + return ntotal, nil +} + +func (b *http2dataBuffer) bytesFromFirstChunk() []byte { + if len(b.chunks) == 1 { + return b.chunks[0][b.r:b.w] + } + return b.chunks[0][b.r:] +} + +// Len returns the number of bytes of the unread portion of the buffer. +func (b *http2dataBuffer) Len() int { + return b.size +} + +// Write appends p to the buffer. +func (b *http2dataBuffer) Write(p []byte) (int, error) { + ntotal := len(p) + for len(p) > 0 { + // If the last chunk is empty, allocate a new chunk. Try to allocate + // enough to fully copy p plus any additional bytes we expect to + // receive. However, this may allocate less than len(p). + want := int64(len(p)) + if b.expected > want { + want = b.expected + } + chunk := b.lastChunkOrAlloc(want) + n := copy(chunk[b.w:], p) + p = p[n:] + b.w += n + b.size += n + b.expected -= int64(n) + } + return ntotal, nil +} + +func (b *http2dataBuffer) lastChunkOrAlloc(want int64) []byte { + if len(b.chunks) != 0 { + last := b.chunks[len(b.chunks)-1] + if b.w < len(last) { + return last + } + } + chunk := http2getDataBufferChunk(want) + b.chunks = append(b.chunks, chunk) + b.w = 0 + return chunk +} + +// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. +type http2ErrCode uint32 + +const ( + http2ErrCodeNo http2ErrCode = 0x0 + http2ErrCodeProtocol http2ErrCode = 0x1 + http2ErrCodeInternal http2ErrCode = 0x2 + http2ErrCodeFlowControl http2ErrCode = 0x3 + http2ErrCodeSettingsTimeout http2ErrCode = 0x4 + http2ErrCodeStreamClosed http2ErrCode = 0x5 + http2ErrCodeFrameSize http2ErrCode = 0x6 + http2ErrCodeRefusedStream http2ErrCode = 0x7 + http2ErrCodeCancel http2ErrCode = 0x8 + http2ErrCodeCompression http2ErrCode = 0x9 + http2ErrCodeConnect http2ErrCode = 0xa + http2ErrCodeEnhanceYourCalm http2ErrCode = 0xb + http2ErrCodeInadequateSecurity http2ErrCode = 0xc + http2ErrCodeHTTP11Required http2ErrCode = 0xd +) + +var http2errCodeName = map[http2ErrCode]string{ + http2ErrCodeNo: "NO_ERROR", + http2ErrCodeProtocol: "PROTOCOL_ERROR", + http2ErrCodeInternal: "INTERNAL_ERROR", + http2ErrCodeFlowControl: "FLOW_CONTROL_ERROR", + http2ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT", + http2ErrCodeStreamClosed: "STREAM_CLOSED", + http2ErrCodeFrameSize: "FRAME_SIZE_ERROR", + http2ErrCodeRefusedStream: "REFUSED_STREAM", + http2ErrCodeCancel: "CANCEL", + http2ErrCodeCompression: "COMPRESSION_ERROR", + http2ErrCodeConnect: "CONNECT_ERROR", + http2ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM", + http2ErrCodeInadequateSecurity: "INADEQUATE_SECURITY", + http2ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED", +} + +func (e http2ErrCode) String() string { + if s, ok := http2errCodeName[e]; ok { + return s + } + return fmt.Sprintf("unknown error code 0x%x", uint32(e)) +} + +func (e http2ErrCode) stringToken() string { + if s, ok := http2errCodeName[e]; ok { + return s + } + return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e)) +} + +// ConnectionError is an error that results in the termination of the +// entire connection. +type http2ConnectionError http2ErrCode + +func (e http2ConnectionError) Error() string { + return fmt.Sprintf("connection error: %s", http2ErrCode(e)) +} + +// StreamError is an error that only affects one stream within an +// HTTP/2 connection. +type http2StreamError struct { + StreamID uint32 + Code http2ErrCode + Cause error // optional additional detail +} + +// errFromPeer is a sentinel error value for StreamError.Cause to +// indicate that the StreamError was sent from the peer over the wire +// and wasn't locally generated in the Transport. +var http2errFromPeer = errors.New("received from peer") + +func http2streamError(id uint32, code http2ErrCode) http2StreamError { + return http2StreamError{StreamID: id, Code: code} +} + +func (e http2StreamError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause) + } + return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) +} + +// 6.9.1 The Flow Control Window +// "If a sender receives a WINDOW_UPDATE that causes a flow control +// window to exceed this maximum it MUST terminate either the stream +// or the connection, as appropriate. For streams, [...]; for the +// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code." +type http2goAwayFlowError struct{} + +func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" } + +// connError represents an HTTP/2 ConnectionError error code, along +// with a string (for debugging) explaining why. +// +// Errors of this type are only returned by the frame parser functions +// and converted into ConnectionError(Code), after stashing away +// the Reason into the Framer's errDetail field, accessible via +// the (*Framer).ErrorDetail method. +type http2connError struct { + Code http2ErrCode // the ConnectionError error code + Reason string // additional reason +} + +func (e http2connError) Error() string { + return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason) +} + +type http2pseudoHeaderError string + +func (e http2pseudoHeaderError) Error() string { + return fmt.Sprintf("invalid pseudo-header %q", string(e)) +} + +type http2duplicatePseudoHeaderError string + +func (e http2duplicatePseudoHeaderError) Error() string { + return fmt.Sprintf("duplicate pseudo-header %q", string(e)) +} + +type http2headerFieldNameError string + +func (e http2headerFieldNameError) Error() string { + return fmt.Sprintf("invalid header field name %q", string(e)) +} + +type http2headerFieldValueError string + +func (e http2headerFieldValueError) Error() string { + return fmt.Sprintf("invalid header field value %q", string(e)) +} + +var ( + http2errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers") + http2errPseudoAfterRegular = errors.New("pseudo header field after regular") +) + +// flow is the flow control window's size. +type http2flow struct { + _ http2incomparable + + // n is the number of DATA bytes we're allowed to send. + // A flow is kept both on a conn and a per-stream. + n int32 + + // conn points to the shared connection-level flow that is + // shared by all streams on that conn. It is nil for the flow + // that's on the conn directly. + conn *http2flow +} + +func (f *http2flow) setConnFlow(cf *http2flow) { f.conn = cf } + +func (f *http2flow) available() int32 { + n := f.n + if f.conn != nil && f.conn.n < n { + n = f.conn.n + } + return n +} + +func (f *http2flow) take(n int32) { + if n > f.available() { + panic("internal error: took too much") + } + f.n -= n + if f.conn != nil { + f.conn.n -= n + } +} + +// add adds n bytes (positive or negative) to the flow control window. +// It returns false if the sum would exceed 2^31-1. +func (f *http2flow) add(n int32) bool { + sum := f.n + n + if (sum > n) == (f.n > 0) { + f.n = sum + return true + } + return false +} + +const http2frameHeaderLen = 9 + +var http2padZeros = make([]byte, 255) // zeros for padding + +// A FrameType is a registered frame type as defined in +// http://http2.github.io/http2-spec/#rfc.section.11.2 +type http2FrameType uint8 + +const ( + http2FrameData http2FrameType = 0x0 + http2FrameHeaders http2FrameType = 0x1 + http2FramePriority http2FrameType = 0x2 + http2FrameRSTStream http2FrameType = 0x3 + http2FrameSettings http2FrameType = 0x4 + http2FramePushPromise http2FrameType = 0x5 + http2FramePing http2FrameType = 0x6 + http2FrameGoAway http2FrameType = 0x7 + http2FrameWindowUpdate http2FrameType = 0x8 + http2FrameContinuation http2FrameType = 0x9 +) + +var http2frameName = map[http2FrameType]string{ + http2FrameData: "DATA", + http2FrameHeaders: "HEADERS", + http2FramePriority: "PRIORITY", + http2FrameRSTStream: "RST_STREAM", + http2FrameSettings: "SETTINGS", + http2FramePushPromise: "PUSH_PROMISE", + http2FramePing: "PING", + http2FrameGoAway: "GOAWAY", + http2FrameWindowUpdate: "WINDOW_UPDATE", + http2FrameContinuation: "CONTINUATION", +} + +func (t http2FrameType) String() string { + if s, ok := http2frameName[t]; ok { + return s + } + return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t)) +} + +// Flags is a bitmask of HTTP/2 flags. +// The meaning of flags varies depending on the frame type. +type http2Flags uint8 + +// Has reports whether f contains all (0 or more) flags in v. +func (f http2Flags) Has(v http2Flags) bool { + return (f & v) == v +} + +// Frame-specific FrameHeader flag bits. +const ( + // Data Frame + http2FlagDataEndStream http2Flags = 0x1 + http2FlagDataPadded http2Flags = 0x8 + + // Headers Frame + http2FlagHeadersEndStream http2Flags = 0x1 + http2FlagHeadersEndHeaders http2Flags = 0x4 + http2FlagHeadersPadded http2Flags = 0x8 + http2FlagHeadersPriority http2Flags = 0x20 + + // Settings Frame + http2FlagSettingsAck http2Flags = 0x1 + + // Ping Frame + http2FlagPingAck http2Flags = 0x1 + + // Continuation Frame + http2FlagContinuationEndHeaders http2Flags = 0x4 + + http2FlagPushPromiseEndHeaders http2Flags = 0x4 + http2FlagPushPromisePadded http2Flags = 0x8 +) + +var http2flagName = map[http2FrameType]map[http2Flags]string{ + http2FrameData: { + http2FlagDataEndStream: "END_STREAM", + http2FlagDataPadded: "PADDED", + }, + http2FrameHeaders: { + http2FlagHeadersEndStream: "END_STREAM", + http2FlagHeadersEndHeaders: "END_HEADERS", + http2FlagHeadersPadded: "PADDED", + http2FlagHeadersPriority: "PRIORITY", + }, + http2FrameSettings: { + http2FlagSettingsAck: "ACK", + }, + http2FramePing: { + http2FlagPingAck: "ACK", + }, + http2FrameContinuation: { + http2FlagContinuationEndHeaders: "END_HEADERS", + }, + http2FramePushPromise: { + http2FlagPushPromiseEndHeaders: "END_HEADERS", + http2FlagPushPromisePadded: "PADDED", + }, +} + +// a frameParser parses a frame given its FrameHeader and payload +// bytes. The length of payload will always equal fh.Length (which +// might be 0). +type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) + +var http2frameParsers = map[http2FrameType]http2frameParser{ + http2FrameData: http2parseDataFrame, + http2FrameHeaders: http2parseHeadersFrame, + http2FramePriority: http2parsePriorityFrame, + http2FrameRSTStream: http2parseRSTStreamFrame, + http2FrameSettings: http2parseSettingsFrame, + http2FramePushPromise: http2parsePushPromise, + http2FramePing: http2parsePingFrame, + http2FrameGoAway: http2parseGoAwayFrame, + http2FrameWindowUpdate: http2parseWindowUpdateFrame, + http2FrameContinuation: http2parseContinuationFrame, +} + +func http2typeFrameParser(t http2FrameType) http2frameParser { + if f := http2frameParsers[t]; f != nil { + return f + } + return http2parseUnknownFrame +} + +// A FrameHeader is the 9 byte header of all HTTP/2 frames. +// +// See http://http2.github.io/http2-spec/#FrameHeader +type http2FrameHeader struct { + valid bool // caller can access []byte fields in the Frame + + // Type is the 1 byte frame type. There are ten standard frame + // types, but extension frame types may be written by WriteRawFrame + // and will be returned by ReadFrame (as UnknownFrame). + Type http2FrameType + + // Flags are the 1 byte of 8 potential bit flags per frame. + // They are specific to the frame type. + Flags http2Flags + + // Length is the length of the frame, not including the 9 byte header. + // The maximum size is one byte less than 16MB (uint24), but only + // frames up to 16KB are allowed without peer agreement. + Length uint32 + + // StreamID is which stream this frame is for. Certain frames + // are not stream-specific, in which case this field is 0. + StreamID uint32 +} + +// Header returns h. It exists so FrameHeaders can be embedded in other +// specific frame types and implement the Frame interface. +func (h http2FrameHeader) Header() http2FrameHeader { return h } + +func (h http2FrameHeader) String() string { + var buf bytes.Buffer + buf.WriteString("[FrameHeader ") + h.writeDebug(&buf) + buf.WriteByte(']') + return buf.String() +} + +func (h http2FrameHeader) writeDebug(buf *bytes.Buffer) { + buf.WriteString(h.Type.String()) + if h.Flags != 0 { + buf.WriteString(" flags=") + set := 0 + for i := uint8(0); i < 8; i++ { + if h.Flags&(1< 1 { + buf.WriteByte('|') + } + name := http2flagName[h.Type][http2Flags(1<>24), + byte(streamID>>16), + byte(streamID>>8), + byte(streamID)) +} + +func (f *http2Framer) endWrite() error { + // Now that we know the final size, fill in the FrameHeader in + // the space previously reserved for it. Abuse append. + length := len(f.wbuf) - http2frameHeaderLen + if length >= (1 << 24) { + return http2ErrFrameTooLarge + } + _ = append(f.wbuf[:0], + byte(length>>16), + byte(length>>8), + byte(length)) + if f.logWrites { + f.logWrite() + } + + n, err := f.w.Write(f.wbuf) + if err == nil && n != len(f.wbuf) { + err = io.ErrShortWrite + } + return err +} + +func (f *http2Framer) logWrite() { + if f.debugFramer == nil { + f.debugFramerBuf = new(bytes.Buffer) + f.debugFramer = http2NewFramer(nil, f.debugFramerBuf, f.dump) + f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below + // Let us read anything, even if we accidentally wrote it + // in the wrong order: + f.debugFramer.AllowIllegalReads = true + } + f.debugFramerBuf.Write(f.wbuf) + fr, err := f.debugFramer.ReadFrame() + if err != nil { + f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f) + return + } + f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) +} + +func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) } + +func (f *http2Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) } + +func (f *http2Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) } + +func (f *http2Framer) writeUint32(v uint32) { + f.wbuf = append(f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +const ( + http2minMaxFrameSize = 1 << 14 + http2maxFrameSize = 1<<24 - 1 +) + +// SetReuseFrames allows the Framer to reuse Frames. +// If called on a Framer, Frames returned by calls to ReadFrame are only +// valid until the next call to ReadFrame. +func (fr *http2Framer) SetReuseFrames() { + if fr.frameCache != nil { + return + } + fr.frameCache = &http2frameCache{} +} + +type http2frameCache struct { + dataFrame http2DataFrame +} + +func (fc *http2frameCache) getDataFrame() *http2DataFrame { + if fc == nil { + return &http2DataFrame{} + } + return &fc.dataFrame +} + +// NewFramer returns a Framer that writes frames to w and reads them from r. +func http2NewFramer(w io.Writer, r io.Reader, dump *dumper) *http2Framer { + fr := &http2Framer{ + dump: dump, + w: w, + r: r, + countError: func(string) {}, + logReads: http2logFrameReads, + logWrites: http2logFrameWrites, + debugReadLoggerf: log.Printf, + debugWriteLoggerf: log.Printf, + } + fr.getReadBuf = func(size uint32) []byte { + if cap(fr.readBuf) >= int(size) { + return fr.readBuf[:size] + } + fr.readBuf = make([]byte, size) + return fr.readBuf + } + fr.SetMaxReadFrameSize(http2maxFrameSize) + if dump != nil && dump.RequestBody { + fr.WriteData = func(streamID uint32, endStream bool, data []byte) error { + dump.dump(data) + return fr.writeData(streamID, endStream, data) + } + } else { + fr.WriteData = fr.writeData + } + return fr +} + +// SetMaxReadFrameSize sets the maximum size of a frame +// that will be read by a subsequent call to ReadFrame. +// It is the caller's responsibility to advertise this +// limit with a SETTINGS frame. +func (fr *http2Framer) SetMaxReadFrameSize(v uint32) { + if v > http2maxFrameSize { + v = http2maxFrameSize + } + fr.maxReadSize = v +} + +// ErrorDetail returns a more detailed error of the last error +// returned by Framer.ReadFrame. For instance, if ReadFrame +// returns a StreamError with code PROTOCOL_ERROR, ErrorDetail +// will say exactly what was invalid. ErrorDetail is not guaranteed +// to return a non-nil value and like the rest of the http2 package, +// its return value is not protected by an API compatibility promise. +// ErrorDetail is reset after the next call to ReadFrame. +func (fr *http2Framer) ErrorDetail() error { + return fr.errDetail +} + +// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer +// sends a frame that is larger than declared with SetMaxReadFrameSize. +var http2ErrFrameTooLarge = errors.New("http2: frame too large") + +// terminalReadFrameError reports whether err is an unrecoverable +// error from ReadFrame and no other frames should be read. +func http2terminalReadFrameError(err error) bool { + if _, ok := err.(http2StreamError); ok { + return false + } + return err != nil +} + +// ReadFrame reads a single frame. The returned Frame is only valid +// until the next call to ReadFrame. +// +// If the frame is larger than previously set with SetMaxReadFrameSize, the +// returned error is ErrFrameTooLarge. Other errors may be of type +// ConnectionError, StreamError, or anything else from the underlying +// reader. +func (fr *http2Framer) ReadFrame() (http2Frame, error) { + fr.errDetail = nil + if fr.lastFrame != nil { + fr.lastFrame.invalidate() + } + fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r) + if err != nil { + return nil, err + } + if fh.Length > fr.maxReadSize { + return nil, http2ErrFrameTooLarge + } + payload := fr.getReadBuf(fh.Length) + if _, err := io.ReadFull(fr.r, payload); err != nil { + return nil, err + } + f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload) + if err != nil { + if ce, ok := err.(http2connError); ok { + return nil, fr.connError(ce.Code, ce.Reason) + } + return nil, err + } + if err := fr.checkFrameOrder(f); err != nil { + return nil, err + } + if fr.logReads { + fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) + } + if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { + hr, err := fr.readMetaFrame(f.(*http2HeadersFrame)) + if fr.dump != nil && err == nil && fr.dump.ResponseHead { + fr.dump.dump([]byte("\r\n")) + } + return hr, err + } + return f, nil +} + +// connError returns ConnectionError(code) but first +// stashes away a public reason to the caller can optionally relay it +// to the peer before hanging up on them. This might help others debug +// their implementations. +func (fr *http2Framer) connError(code http2ErrCode, reason string) error { + fr.errDetail = errors.New(reason) + return http2ConnectionError(code) +} + +// checkFrameOrder reports an error if f is an invalid frame to return +// next from ReadFrame. Mostly it checks whether HEADERS and +// CONTINUATION frames are contiguous. +func (fr *http2Framer) checkFrameOrder(f http2Frame) error { + last := fr.lastFrame + fr.lastFrame = f + if fr.AllowIllegalReads { + return nil + } + + fh := f.Header() + if fr.lastHeaderStream != 0 { + if fh.Type != http2FrameContinuation { + return fr.connError(http2ErrCodeProtocol, + fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d", + fh.Type, fh.StreamID, + last.Header().Type, fr.lastHeaderStream)) + } + if fh.StreamID != fr.lastHeaderStream { + return fr.connError(http2ErrCodeProtocol, + fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d", + fh.StreamID, fr.lastHeaderStream)) + } + } else if fh.Type == http2FrameContinuation { + return fr.connError(http2ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID)) + } + + switch fh.Type { + case http2FrameHeaders, http2FrameContinuation: + if fh.Flags.Has(http2FlagHeadersEndHeaders) { + fr.lastHeaderStream = 0 + } else { + fr.lastHeaderStream = fh.StreamID + } + } + + return nil +} + +// A DataFrame conveys arbitrary, variable-length sequences of octets +// associated with a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.1 +type http2DataFrame struct { + http2FrameHeader + data []byte +} + +func (f *http2DataFrame) StreamEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagDataEndStream) +} + +// Data returns the frame's data octets, not including any padding +// size byte or padding suffix bytes. +// The caller must not retain the returned memory past the next +// call to ReadFrame. +func (f *http2DataFrame) Data() []byte { + f.checkValid() + return f.data +} + +func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { + if fh.StreamID == 0 { + // DATA frames MUST be associated with a stream. If a + // DATA frame is received whose stream identifier + // field is 0x0, the recipient MUST respond with a + // connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. + countError("frame_data_stream_0") + return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"} + } + f := fc.getDataFrame() + f.http2FrameHeader = fh + + var padSize byte + if fh.Flags.Has(http2FlagDataPadded) { + var err error + payload, padSize, err = http2readByte(payload) + if err != nil { + countError("frame_data_pad_byte_short") + return nil, err + } + } + if int(padSize) > len(payload) { + // If the length of the padding is greater than the + // length of the frame payload, the recipient MUST + // treat this as a connection error. + // Filed: https://github.com/http2/http2-spec/issues/610 + countError("frame_data_pad_too_big") + return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"} + } + f.data = payload[:len(payload)-int(padSize)] + return f, nil +} + +var ( + http2errStreamID = errors.New("invalid stream ID") + http2errDepStreamID = errors.New("invalid dependent stream ID") + http2errPadLength = errors.New("pad length too large") + http2errPadBytes = errors.New("padding bytes must all be zeros unless AllowIllegalWrites is enabled") +) + +func http2validStreamIDOrZero(streamID uint32) bool { + return streamID&(1<<31) == 0 +} + +func http2validStreamID(streamID uint32) bool { + return streamID != 0 && streamID&(1<<31) == 0 +} + +// writeData writes a DATA frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility not to violate the maximum frame size +// and to not call other Write methods concurrently. +func (f *http2Framer) writeData(streamID uint32, endStream bool, data []byte) error { + return f.WriteDataPadded(streamID, endStream, data, nil) +} + +// WriteDataPadded writes a DATA frame with optional padding. +// +// If pad is nil, the padding bit is not sent. +// The length of pad must not exceed 255 bytes. +// The bytes of pad must all be zero, unless f.AllowIllegalWrites is set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility not to violate the maximum frame size +// and to not call other Write methods concurrently. +func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + if len(pad) > 0 { + if len(pad) > 255 { + return http2errPadLength + } + if !f.AllowIllegalWrites { + for _, b := range pad { + if b != 0 { + // "Padding octets MUST be set to zero when sending." + return http2errPadBytes + } + } + } + } + var flags http2Flags + if endStream { + flags |= http2FlagDataEndStream + } + if pad != nil { + flags |= http2FlagDataPadded + } + f.startWrite(http2FrameData, flags, streamID) + if pad != nil { + f.wbuf = append(f.wbuf, byte(len(pad))) + } + f.wbuf = append(f.wbuf, data...) + f.wbuf = append(f.wbuf, pad...) + return f.endWrite() +} + +// A SettingsFrame conveys configuration parameters that affect how +// endpoints communicate, such as preferences and constraints on peer +// behavior. +// +// See http://http2.github.io/http2-spec/#SETTINGS +type http2SettingsFrame struct { + http2FrameHeader + p []byte +} + +func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 { + // When this (ACK 0x1) bit is set, the payload of the + // SETTINGS frame MUST be empty. Receipt of a + // SETTINGS frame with the ACK flag set and a length + // field value other than 0 MUST be treated as a + // connection error (Section 5.4.1) of type + // FRAME_SIZE_ERROR. + countError("frame_settings_ack_with_length") + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + if fh.StreamID != 0 { + // SETTINGS frames always apply to a connection, + // never a single stream. The stream identifier for a + // SETTINGS frame MUST be zero (0x0). If an endpoint + // receives a SETTINGS frame whose stream identifier + // field is anything other than 0x0, the endpoint MUST + // respond with a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR. + countError("frame_settings_has_stream") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + if len(p)%6 != 0 { + countError("frame_settings_mod_6") + // Expecting even number of 6 byte settings. + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + f := &http2SettingsFrame{http2FrameHeader: fh, p: p} + if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 { + countError("frame_settings_window_size_too_big") + // Values above the maximum flow control window size of 2^31 - 1 MUST + // be treated as a connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR. + return nil, http2ConnectionError(http2ErrCodeFlowControl) + } + return f, nil +} + +func (f *http2SettingsFrame) IsAck() bool { + return f.http2FrameHeader.Flags.Has(http2FlagSettingsAck) +} + +func (f *http2SettingsFrame) Value(id http2SettingID) (v uint32, ok bool) { + f.checkValid() + for i := 0; i < f.NumSettings(); i++ { + if s := f.Setting(i); s.ID == id { + return s.Val, true + } + } + return 0, false +} + +// Setting returns the setting from the frame at the given 0-based index. +// The index must be >= 0 and less than f.NumSettings(). +func (f *http2SettingsFrame) Setting(i int) http2Setting { + buf := f.p + return http2Setting{ + ID: http2SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])), + Val: binary.BigEndian.Uint32(buf[i*6+2 : i*6+6]), + } +} + +func (f *http2SettingsFrame) NumSettings() int { return len(f.p) / 6 } + +// HasDuplicates reports whether f contains any duplicate setting IDs. +func (f *http2SettingsFrame) HasDuplicates() bool { + num := f.NumSettings() + if num == 0 { + return false + } + // If it's small enough (the common case), just do the n^2 + // thing and avoid a map allocation. + if num < 10 { + for i := 0; i < num; i++ { + idi := f.Setting(i).ID + for j := i + 1; j < num; j++ { + idj := f.Setting(j).ID + if idi == idj { + return true + } + } + } + return false + } + seen := map[http2SettingID]bool{} + for i := 0; i < num; i++ { + id := f.Setting(i).ID + if seen[id] { + return true + } + seen[id] = true + } + return false +} + +// ForeachSetting runs fn for each setting. +// It stops and returns the first error. +func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error { + f.checkValid() + for i := 0; i < f.NumSettings(); i++ { + if err := fn(f.Setting(i)); err != nil { + return err + } + } + return nil +} + +// WriteSettings writes a SETTINGS frame with zero or more settings +// specified and the ACK bit not set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteSettings(settings ...http2Setting) error { + f.startWrite(http2FrameSettings, 0, 0) + for _, s := range settings { + f.writeUint16(uint16(s.ID)) + f.writeUint32(s.Val) + } + return f.endWrite() +} + +// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteSettingsAck() error { + f.startWrite(http2FrameSettings, http2FlagSettingsAck, 0) + return f.endWrite() +} + +// A PingFrame is a mechanism for measuring a minimal round trip time +// from the sender, as well as determining whether an idle connection +// is still functional. +// See http://http2.github.io/http2-spec/#rfc.section.6.7 +type http2PingFrame struct { + http2FrameHeader + Data [8]byte +} + +func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) } + +func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { + if len(payload) != 8 { + countError("frame_ping_length") + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + if fh.StreamID != 0 { + countError("frame_ping_has_stream") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + f := &http2PingFrame{http2FrameHeader: fh} + copy(f.Data[:], payload) + return f, nil +} + +func (f *http2Framer) WritePing(ack bool, data [8]byte) error { + var flags http2Flags + if ack { + flags = http2FlagPingAck + } + f.startWrite(http2FramePing, flags, 0) + f.writeBytes(data[:]) + return f.endWrite() +} + +// A GoAwayFrame informs the remote peer to stop creating streams on this connection. +// See http://http2.github.io/http2-spec/#rfc.section.6.8 +type http2GoAwayFrame struct { + http2FrameHeader + LastStreamID uint32 + ErrCode http2ErrCode + debugData []byte +} + +// DebugData returns any debug data in the GOAWAY frame. Its contents +// are not defined. +// The caller must not retain the returned memory past the next +// call to ReadFrame. +func (f *http2GoAwayFrame) DebugData() []byte { + f.checkValid() + return f.debugData +} + +func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + if fh.StreamID != 0 { + countError("frame_goaway_has_stream") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + if len(p) < 8 { + countError("frame_goaway_short") + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + return &http2GoAwayFrame{ + http2FrameHeader: fh, + LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1), + ErrCode: http2ErrCode(binary.BigEndian.Uint32(p[4:8])), + debugData: p[8:], + }, nil +} + +func (f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debugData []byte) error { + f.startWrite(http2FrameGoAway, 0, 0) + f.writeUint32(maxStreamID & (1<<31 - 1)) + f.writeUint32(uint32(code)) + f.writeBytes(debugData) + return f.endWrite() +} + +// An UnknownFrame is the frame type returned when the frame type is unknown +// or no specific frame type parser exists. +type http2UnknownFrame struct { + http2FrameHeader + p []byte +} + +// Payload returns the frame's payload (after the header). It is not +// valid to call this method after a subsequent call to +// Framer.ReadFrame, nor is it valid to retain the returned slice. +// The memory is owned by the Framer and is invalidated when the next +// frame is read. +func (f *http2UnknownFrame) Payload() []byte { + f.checkValid() + return f.p +} + +func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + return &http2UnknownFrame{fh, p}, nil +} + +// A WindowUpdateFrame is used to implement flow control. +// See http://http2.github.io/http2-spec/#rfc.section.6.9 +type http2WindowUpdateFrame struct { + http2FrameHeader + Increment uint32 // never read with high bit set +} + +func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + if len(p) != 4 { + countError("frame_windowupdate_bad_len") + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit + if inc == 0 { + // A receiver MUST treat the receipt of a + // WINDOW_UPDATE frame with an flow control window + // increment of 0 as a stream error (Section 5.4.2) of + // type PROTOCOL_ERROR; errors on the connection flow + // control window MUST be treated as a connection + // error (Section 5.4.1). + if fh.StreamID == 0 { + countError("frame_windowupdate_zero_inc_conn") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + countError("frame_windowupdate_zero_inc_stream") + return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) + } + return &http2WindowUpdateFrame{ + http2FrameHeader: fh, + Increment: inc, + }, nil +} + +// WriteWindowUpdate writes a WINDOW_UPDATE frame. +// The increment value must be between 1 and 2,147,483,647, inclusive. +// If the Stream ID is zero, the window update applies to the +// connection as a whole. +func (f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error { + // "The legal range for the increment to the flow control window is 1 to 2^31-1 (2,147,483,647) octets." + if (incr < 1 || incr > 2147483647) && !f.AllowIllegalWrites { + return errors.New("illegal window increment value") + } + f.startWrite(http2FrameWindowUpdate, 0, streamID) + f.writeUint32(incr) + return f.endWrite() +} + +// A HeadersFrame is used to open a stream and additionally carries a +// header block fragment. +type http2HeadersFrame struct { + http2FrameHeader + + // Priority is set if FlagHeadersPriority is set in the FrameHeader. + Priority http2PriorityParam + + headerFragBuf []byte // not owned +} + +func (f *http2HeadersFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2HeadersFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndHeaders) +} + +func (f *http2HeadersFrame) StreamEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndStream) +} + +func (f *http2HeadersFrame) HasPriority() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority) +} + +func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { + hf := &http2HeadersFrame{ + http2FrameHeader: fh, + } + if fh.StreamID == 0 { + // HEADERS frames MUST be associated with a stream. If a HEADERS frame + // is received whose stream identifier field is 0x0, the recipient MUST + // respond with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. + countError("frame_headers_zero_stream") + return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"} + } + var padLength uint8 + if fh.Flags.Has(http2FlagHeadersPadded) { + if p, padLength, err = http2readByte(p); err != nil { + countError("frame_headers_pad_short") + return + } + } + if fh.Flags.Has(http2FlagHeadersPriority) { + var v uint32 + p, v, err = http2readUint32(p) + if err != nil { + countError("frame_headers_prio_short") + return nil, err + } + hf.Priority.StreamDep = v & 0x7fffffff + hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set + p, hf.Priority.Weight, err = http2readByte(p) + if err != nil { + countError("frame_headers_prio_weight_short") + return nil, err + } + } + if len(p)-int(padLength) < 0 { + countError("frame_headers_pad_too_big") + return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) + } + hf.headerFragBuf = p[:len(p)-int(padLength)] + return hf, nil +} + +// HeadersFrameParam are the parameters for writing a HEADERS frame. +type http2HeadersFrameParam struct { + // StreamID is the required Stream ID to initiate. + StreamID uint32 + // BlockFragment is part (or all) of a Header Block. + BlockFragment []byte + + // EndStream indicates that the header block is the last that + // the endpoint will send for the identified stream. Setting + // this flag causes the stream to enter one of "half closed" + // states. + EndStream bool + + // EndHeaders indicates that this frame contains an entire + // header block and is not followed by any + // CONTINUATION frames. + EndHeaders bool + + // PadLength is the optional number of bytes of zeros to add + // to this frame. + PadLength uint8 + + // Priority, if non-zero, includes stream priority information + // in the HEADER frame. + Priority http2PriorityParam +} + +// WriteHeaders writes a single HEADERS frame. +// +// This is a low-level header writing method. Encoding headers and +// splitting them into any necessary CONTINUATION frames is handled +// elsewhere. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { + if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if p.PadLength != 0 { + flags |= http2FlagHeadersPadded + } + if p.EndStream { + flags |= http2FlagHeadersEndStream + } + if p.EndHeaders { + flags |= http2FlagHeadersEndHeaders + } + if !p.Priority.IsZero() { + flags |= http2FlagHeadersPriority + } + f.startWrite(http2FrameHeaders, flags, p.StreamID) + if p.PadLength != 0 { + f.writeByte(p.PadLength) + } + if !p.Priority.IsZero() { + v := p.Priority.StreamDep + if !http2validStreamIDOrZero(v) && !f.AllowIllegalWrites { + return http2errDepStreamID + } + if p.Priority.Exclusive { + v |= 1 << 31 + } + f.writeUint32(v) + f.writeByte(p.Priority.Weight) + } + f.wbuf = append(f.wbuf, p.BlockFragment...) + f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...) + return f.endWrite() +} + +// A PriorityFrame specifies the sender-advised priority of a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.3 +type http2PriorityFrame struct { + http2FrameHeader + http2PriorityParam +} + +// PriorityParam are the stream prioritzation parameters. +type http2PriorityParam struct { + // StreamDep is a 31-bit stream identifier for the + // stream that this stream depends on. Zero means no + // dependency. + StreamDep uint32 + + // Exclusive is whether the dependency is exclusive. + Exclusive bool + + // Weight is the stream's zero-indexed weight. It should be + // set together with StreamDep, or neither should be set. Per + // the spec, "Add one to the value to obtain a weight between + // 1 and 256." + Weight uint8 +} + +func (p http2PriorityParam) IsZero() bool { + return p == http2PriorityParam{} +} + +func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { + if fh.StreamID == 0 { + countError("frame_priority_zero_stream") + return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"} + } + if len(payload) != 5 { + countError("frame_priority_bad_length") + return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))} + } + v := binary.BigEndian.Uint32(payload[:4]) + streamID := v & 0x7fffffff // mask off high bit + return &http2PriorityFrame{ + http2FrameHeader: fh, + http2PriorityParam: http2PriorityParam{ + Weight: payload[4], + StreamDep: streamID, + Exclusive: streamID != v, // was high bit set? + }, + }, nil +} + +// WritePriority writes a PRIORITY frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + if !http2validStreamIDOrZero(p.StreamDep) { + return http2errDepStreamID + } + f.startWrite(http2FramePriority, 0, streamID) + v := p.StreamDep + if p.Exclusive { + v |= 1 << 31 + } + f.writeUint32(v) + f.writeByte(p.Weight) + return f.endWrite() +} + +// A RSTStreamFrame allows for abnormal termination of a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.4 +type http2RSTStreamFrame struct { + http2FrameHeader + ErrCode http2ErrCode +} + +func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + if len(p) != 4 { + countError("frame_rststream_bad_len") + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + if fh.StreamID == 0 { + countError("frame_rststream_zero_stream") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil +} + +// WriteRSTStream writes a RST_STREAM frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + f.startWrite(http2FrameRSTStream, 0, streamID) + f.writeUint32(uint32(code)) + return f.endWrite() +} + +// A ContinuationFrame is used to continue a sequence of header block fragments. +// See http://http2.github.io/http2-spec/#rfc.section.6.10 +type http2ContinuationFrame struct { + http2FrameHeader + headerFragBuf []byte +} + +func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + if fh.StreamID == 0 { + countError("frame_continuation_zero_stream") + return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"} + } + return &http2ContinuationFrame{fh, p}, nil +} + +func (f *http2ContinuationFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2ContinuationFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagContinuationEndHeaders) +} + +// WriteContinuation writes a CONTINUATION frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if endHeaders { + flags |= http2FlagContinuationEndHeaders + } + f.startWrite(http2FrameContinuation, flags, streamID) + f.wbuf = append(f.wbuf, headerBlockFragment...) + return f.endWrite() +} + +// A PushPromiseFrame is used to initiate a server stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.6 +type http2PushPromiseFrame struct { + http2FrameHeader + PromiseID uint32 + headerFragBuf []byte // not owned +} + +func (f *http2PushPromiseFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2PushPromiseFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders) +} + +func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { + pp := &http2PushPromiseFrame{ + http2FrameHeader: fh, + } + if pp.StreamID == 0 { + // PUSH_PROMISE frames MUST be associated with an existing, + // peer-initiated stream. The stream identifier of a + // PUSH_PROMISE frame indicates the stream it is associated + // with. If the stream identifier field specifies the value + // 0x0, a recipient MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + countError("frame_pushpromise_zero_stream") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + // The PUSH_PROMISE frame includes optional padding. + // Padding fields and flags are identical to those defined for DATA frames + var padLength uint8 + if fh.Flags.Has(http2FlagPushPromisePadded) { + if p, padLength, err = http2readByte(p); err != nil { + countError("frame_pushpromise_pad_short") + return + } + } + + p, pp.PromiseID, err = http2readUint32(p) + if err != nil { + countError("frame_pushpromise_promiseid_short") + return + } + pp.PromiseID = pp.PromiseID & (1<<31 - 1) + + if int(padLength) > len(p) { + // like the DATA frame, error out if padding is longer than the body. + countError("frame_pushpromise_pad_too_big") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + pp.headerFragBuf = p[:len(p)-int(padLength)] + return pp, nil +} + +// PushPromiseParam are the parameters for writing a PUSH_PROMISE frame. +type http2PushPromiseParam struct { + // StreamID is the required Stream ID to initiate. + StreamID uint32 + + // PromiseID is the required Stream ID which this + // Push Promises + PromiseID uint32 + + // BlockFragment is part (or all) of a Header Block. + BlockFragment []byte + + // EndHeaders indicates that this frame contains an entire + // header block and is not followed by any + // CONTINUATION frames. + EndHeaders bool + + // PadLength is the optional number of bytes of zeros to add + // to this frame. + PadLength uint8 +} + +// WritePushPromise writes a single PushPromise Frame. +// +// As with Header Frames, This is the low level call for writing +// individual frames. Continuation frames are handled elsewhere. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WritePushPromise(p http2PushPromiseParam) error { + if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if p.PadLength != 0 { + flags |= http2FlagPushPromisePadded + } + if p.EndHeaders { + flags |= http2FlagPushPromiseEndHeaders + } + f.startWrite(http2FramePushPromise, flags, p.StreamID) + if p.PadLength != 0 { + f.writeByte(p.PadLength) + } + if !http2validStreamID(p.PromiseID) && !f.AllowIllegalWrites { + return http2errStreamID + } + f.writeUint32(p.PromiseID) + f.wbuf = append(f.wbuf, p.BlockFragment...) + f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...) + return f.endWrite() +} + +// WriteRawFrame writes a raw frame. This can be used to write +// extension frames unknown to this package. +func (f *http2Framer) WriteRawFrame(t http2FrameType, flags http2Flags, streamID uint32, payload []byte) error { + f.startWrite(t, flags, streamID) + f.writeBytes(payload) + return f.endWrite() +} + +func http2readByte(p []byte) (remain []byte, b byte, err error) { + if len(p) == 0 { + return nil, 0, io.ErrUnexpectedEOF + } + return p[1:], p[0], nil +} + +func http2readUint32(p []byte) (remain []byte, v uint32, err error) { + if len(p) < 4 { + return nil, 0, io.ErrUnexpectedEOF + } + return p[4:], binary.BigEndian.Uint32(p[:4]), nil +} + +type http2streamEnder interface { + StreamEnded() bool +} + +type http2headersEnder interface { + HeadersEnded() bool +} + +type http2headersOrContinuation interface { + http2headersEnder + HeaderBlockFragment() []byte +} + +// A MetaHeadersFrame is the representation of one HEADERS frame and +// zero or more contiguous CONTINUATION frames and the decoding of +// their HPACK-encoded contents. +// +// This type of frame does not appear on the wire and is only returned +// by the Framer when Framer.ReadMetaHeaders is set. +type http2MetaHeadersFrame struct { + *http2HeadersFrame + + // Fields are the fields contained in the HEADERS and + // CONTINUATION frames. The underlying slice is owned by the + // Framer and must not be retained after the next call to + // ReadFrame. + // + // Fields are guaranteed to be in the correct http2 order and + // not have unknown pseudo header fields or invalid header + // field names or values. Required pseudo header fields may be + // missing, however. Use the MetaHeadersFrame.Pseudo accessor + // method access pseudo headers. + Fields []hpack.HeaderField + + // Truncated is whether the max header list size limit was hit + // and Fields is incomplete. The hpack decoder state is still + // valid, however. + Truncated bool +} + +// PseudoValue returns the given pseudo header field's value. +// The provided pseudo field should not contain the leading colon. +func (mh *http2MetaHeadersFrame) PseudoValue(pseudo string) string { + for _, hf := range mh.Fields { + if !hf.IsPseudo() { + return "" + } + if hf.Name[1:] == pseudo { + return hf.Value + } + } + return "" +} + +// RegularFields returns the regular (non-pseudo) header fields of mh. +// The caller does not own the returned slice. +func (mh *http2MetaHeadersFrame) RegularFields() []hpack.HeaderField { + for i, hf := range mh.Fields { + if !hf.IsPseudo() { + return mh.Fields[i:] + } + } + return nil +} + +// PseudoFields returns the pseudo header fields of mh. +// The caller does not own the returned slice. +func (mh *http2MetaHeadersFrame) PseudoFields() []hpack.HeaderField { + for i, hf := range mh.Fields { + if !hf.IsPseudo() { + return mh.Fields[:i] + } + } + return mh.Fields +} + +func (mh *http2MetaHeadersFrame) checkPseudos() error { + var isRequest, isResponse bool + pf := mh.PseudoFields() + for i, hf := range pf { + switch hf.Name { + case ":method", ":path", ":scheme", ":authority": + isRequest = true + case ":status": + isResponse = true + default: + return http2pseudoHeaderError(hf.Name) + } + // Check for duplicates. + // This would be a bad algorithm, but N is 4. + // And this doesn't allocate. + for _, hf2 := range pf[:i] { + if hf.Name == hf2.Name { + return http2duplicatePseudoHeaderError(hf.Name) + } + } + } + if isRequest && isResponse { + return http2errMixPseudoHeaderTypes + } + return nil +} + +func (fr *http2Framer) maxHeaderStringLen() int { + v := fr.maxHeaderListSize() + if uint32(int(v)) == v { + return int(v) + } + // They had a crazy big number for MaxHeaderBytes anyway, + // so give them unlimited header lengths: + return 0 +} + +// readMetaFrame returns 0 or more CONTINUATION frames from fr and +// merge them into the provided hf and returns a MetaHeadersFrame +// with the decoded hpack values. +func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFrame, error) { + if fr.AllowIllegalReads { + return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders") + } + mh := &http2MetaHeadersFrame{ + http2HeadersFrame: hf, + } + var remainSize = fr.maxHeaderListSize() + var sawRegular bool + + var invalid error // pseudo header field errors + hdec := fr.ReadMetaHeaders + hdec.SetEmitEnabled(true) + hdec.SetMaxStringLength(fr.maxHeaderStringLen()) + rawEmitFunc := func(hf hpack.HeaderField) { + if http2VerboseLogs && fr.logReads { + fr.debugReadLoggerf("http2: decoded hpack field %+v", hf) + } + if !httpguts.ValidHeaderFieldValue(hf.Value) { + invalid = http2headerFieldValueError(hf.Value) + } + isPseudo := strings.HasPrefix(hf.Name, ":") + if isPseudo { + if sawRegular { + invalid = http2errPseudoAfterRegular + } + } else { + sawRegular = true + if !http2validWireHeaderFieldName(hf.Name) { + invalid = http2headerFieldNameError(hf.Name) + } + } + + if invalid != nil { + hdec.SetEmitEnabled(false) + return + } + + size := hf.Size() + if size > remainSize { + hdec.SetEmitEnabled(false) + mh.Truncated = true + return + } + remainSize -= size + + mh.Fields = append(mh.Fields, hf) + } + emitFunc := rawEmitFunc + if fr.dump != nil && fr.dump.ResponseHead { + emitFunc = func(hf hpack.HeaderField) { + fr.dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) + rawEmitFunc(hf) + } + } + hdec.SetEmitFunc(emitFunc) + // Lose reference to MetaHeadersFrame: + defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {}) + + var hc http2headersOrContinuation = hf + for { + frag := hc.HeaderBlockFragment() + if _, err := hdec.Write(frag); err != nil { + return nil, http2ConnectionError(http2ErrCodeCompression) + } + + if hc.HeadersEnded() { + break + } + if f, err := fr.ReadFrame(); err != nil { + return nil, err + } else { + hc = f.(*http2ContinuationFrame) // guaranteed by checkFrameOrder + } + } + + mh.http2HeadersFrame.headerFragBuf = nil + mh.http2HeadersFrame.invalidate() + + if err := hdec.Close(); err != nil { + return nil, http2ConnectionError(http2ErrCodeCompression) + } + if invalid != nil { + fr.errDetail = invalid + if http2VerboseLogs { + log.Printf("http2: invalid header: %v", invalid) + } + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, invalid} + } + if err := mh.checkPseudos(); err != nil { + fr.errDetail = err + if http2VerboseLogs { + log.Printf("http2: invalid pseudo headers: %v", err) + } + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, err} + } + return mh, nil +} + +func http2summarizeFrame(f http2Frame) string { + var buf bytes.Buffer + f.Header().writeDebug(&buf) + switch f := f.(type) { + case *http2SettingsFrame: + n := 0 + f.ForeachSetting(func(s http2Setting) error { + n++ + if n == 1 { + buf.WriteString(", settings:") + } + fmt.Fprintf(&buf, " %v=%v,", s.ID, s.Val) + return nil + }) + if n > 0 { + buf.Truncate(buf.Len() - 1) // remove trailing comma + } + case *http2DataFrame: + data := f.Data() + const max = 256 + if len(data) > max { + data = data[:max] + } + fmt.Fprintf(&buf, " data=%q", data) + if len(f.Data()) > max { + fmt.Fprintf(&buf, " (%d bytes omitted)", len(f.Data())-max) + } + case *http2WindowUpdateFrame: + if f.StreamID == 0 { + buf.WriteString(" (conn)") + } + fmt.Fprintf(&buf, " incr=%v", f.Increment) + case *http2PingFrame: + fmt.Fprintf(&buf, " ping=%q", f.Data[:]) + case *http2GoAwayFrame: + fmt.Fprintf(&buf, " LastStreamID=%v ErrCode=%v Debug=%q", + f.LastStreamID, f.ErrCode, f.debugData) + case *http2RSTStreamFrame: + fmt.Fprintf(&buf, " ErrCode=%v", f.ErrCode) + } + return buf.String() +} + +func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { + return trace != nil && trace.WroteHeaderField != nil +} + +func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) { + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField(k, []string{v}) + } +} + +func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error { + if trace != nil { + return trace.Got1xxResponse + } + return nil +} + +// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS +// connection. +func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) { + dialer := &tls.Dialer{ + Config: cfg, + } + cn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed + return tlsCn, nil +} + +var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" + +type http2goroutineLock uint64 + +func http2newGoroutineLock() http2goroutineLock { + if !http2DebugGoroutines { + return 0 + } + return http2goroutineLock(http2curGoroutineID()) +} + +func (g http2goroutineLock) check() { + if !http2DebugGoroutines { + return + } + if http2curGoroutineID() != uint64(g) { + panic("running on the wrong goroutine") + } +} + +func (g http2goroutineLock) checkNotOn() { + if !http2DebugGoroutines { + return + } + if http2curGoroutineID() == uint64(g) { + panic("running on the wrong goroutine") + } +} + +var http2goroutineSpace = []byte("goroutine ") + +func http2curGoroutineID() uint64 { + bp := http2littleBuf.Get().(*[]byte) + defer http2littleBuf.Put(bp) + b := *bp + b = b[:runtime.Stack(b, false)] + // Parse the 4707 out of "goroutine 4707 [" + b = bytes.TrimPrefix(b, http2goroutineSpace) + i := bytes.IndexByte(b, ' ') + if i < 0 { + panic(fmt.Sprintf("No space found in %q", b)) + } + b = b[:i] + n, err := http2parseUintBytes(b, 10, 64) + if err != nil { + panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err)) + } + return n +} + +var http2littleBuf = sync.Pool{ + New: func() interface{} { + buf := make([]byte, 64) + return &buf + }, +} + +// parseUintBytes is like strconv.ParseUint, but using a []byte. +func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) { + var cutoff, maxVal uint64 + + if bitSize == 0 { + bitSize = int(strconv.IntSize) + } + + s0 := s + switch { + case len(s) < 1: + err = strconv.ErrSyntax + goto Error + + case 2 <= base && base <= 36: + // valid base; nothing to do + + case base == 0: + // Look for octal, hex prefix. + switch { + case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'): + base = 16 + s = s[2:] + if len(s) < 1 { + err = strconv.ErrSyntax + goto Error + } + case s[0] == '0': + base = 8 + default: + base = 10 + } + + default: + err = errors.New("invalid base " + strconv.Itoa(base)) + goto Error + } + + n = 0 + cutoff = http2cutoff64(base) + maxVal = 1<= base { + n = 0 + err = strconv.ErrSyntax + goto Error + } + + if n >= cutoff { + // n*base overflows + n = 1<<64 - 1 + err = strconv.ErrRange + goto Error + } + n *= uint64(base) + + n1 := n + uint64(v) + if n1 < n || n1 > maxVal { + // n+v overflows + n = 1<<64 - 1 + err = strconv.ErrRange + goto Error + } + n = n1 + } + + return n, nil + +Error: + return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err} +} + +// Return the first number n such that n*base >= 1<<64. +func http2cutoff64(base int) uint64 { + if base < 2 { + return 0 + } + return (1<<64-1)/uint64(base) + 1 +} + +var ( + http2commonBuildOnce sync.Once + http2commonLowerHeader map[string]string // Go-Canonical-Case -> lower-case + http2commonCanonHeader map[string]string // lower-case -> Go-Canonical-Case +) + +func http2buildCommonHeaderMapsOnce() { + http2commonBuildOnce.Do(http2buildCommonHeaderMaps) +} + +func http2buildCommonHeaderMaps() { + common := []string{ + "accept", + "accept-charset", + "accept-encoding", + "accept-language", + "accept-ranges", + "age", + "access-control-allow-origin", + "allow", + "authorization", + "cache-control", + "content-disposition", + "content-encoding", + "content-language", + "content-length", + "content-location", + "content-range", + "content-type", + "cookie", + "date", + "etag", + "expect", + "expires", + "from", + "host", + "if-match", + "if-modified-since", + "if-none-match", + "if-unmodified-since", + "last-modified", + "link", + "location", + "max-forwards", + "proxy-authenticate", + "proxy-authorization", + "range", + "referer", + "refresh", + "retry-after", + "server", + "set-cookie", + "strict-transport-security", + "trailer", + "transfer-encoding", + "user-agent", + "vary", + "via", + "www-authenticate", + } + http2commonLowerHeader = make(map[string]string, len(common)) + http2commonCanonHeader = make(map[string]string, len(common)) + for _, v := range common { + chk := http.CanonicalHeaderKey(v) + http2commonLowerHeader[chk] = v + http2commonCanonHeader[v] = chk + } +} + +func http2lowerHeader(v string) (lower string, ascii bool) { + http2buildCommonHeaderMapsOnce() + if s, ok := http2commonLowerHeader[v]; ok { + return s, true + } + return http2asciiToLower(v) +} + +var ( + http2VerboseLogs bool + http2logFrameWrites bool + http2logFrameReads bool + http2inTests bool +) + +func init() { + e := os.Getenv("GODEBUG") + if strings.Contains(e, "http2debug=1") { + http2VerboseLogs = true + } + if strings.Contains(e, "http2debug=2") { + http2VerboseLogs = true + http2logFrameWrites = true + http2logFrameReads = true + } +} + +const ( + // ClientPreface is the string that must be sent by new + // connections from clients. + http2ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + + // SETTINGS_MAX_FRAME_SIZE default + // http://http2.github.io/http2-spec/#rfc.section.6.5.2 + http2initialMaxFrameSize = 16384 + + // NextProtoTLS is the NPN/ALPN protocol negotiated during + // HTTP/2's TLS setup. + http2NextProtoTLS = "h2" + + // http://http2.github.io/http2-spec/#SettingValues + http2initialHeaderTableSize = 4096 + + http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size + + http2defaultMaxReadFrameSize = 1 << 20 +) + +var ( + http2clientPreface = []byte(http2ClientPreface) +) + +type http2streamState int + +// HTTP/2 stream states. +// +// See http://tools.ietf.org/html/rfc7540#section-5.1. +// +// For simplicity, the server code merges "reserved (local)" into +// "half-closed (remote)". This is one less state transition to track. +// The only downside is that we send PUSH_PROMISEs slightly less +// liberally than allowable. More discussion here: +// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html +// +// "reserved (remote)" is omitted since the client code does not +// support server push. +const ( + http2stateIdle http2streamState = iota + http2stateOpen + http2stateHalfClosedLocal + http2stateHalfClosedRemote + http2stateClosed +) + +var http2stateName = [...]string{ + http2stateIdle: "Idle", + http2stateOpen: "Open", + http2stateHalfClosedLocal: "HalfClosedLocal", + http2stateHalfClosedRemote: "HalfClosedRemote", + http2stateClosed: "Closed", +} + +func (st http2streamState) String() string { + return http2stateName[st] +} + +// Setting is a setting parameter: which setting it is, and its value. +type http2Setting struct { + // ID is which setting is being set. + // See http://http2.github.io/http2-spec/#SettingValues + ID http2SettingID + + // Val is the value. + Val uint32 +} + +func (s http2Setting) String() string { + return fmt.Sprintf("[%v = %d]", s.ID, s.Val) +} + +// Valid reports whether the setting is valid. +func (s http2Setting) Valid() error { + // Limits and error codes from 6.5.2 Defined SETTINGS Parameters + switch s.ID { + case http2SettingEnablePush: + if s.Val != 1 && s.Val != 0 { + return http2ConnectionError(http2ErrCodeProtocol) + } + case http2SettingInitialWindowSize: + if s.Val > 1<<31-1 { + return http2ConnectionError(http2ErrCodeFlowControl) + } + case http2SettingMaxFrameSize: + if s.Val < 16384 || s.Val > 1<<24-1 { + return http2ConnectionError(http2ErrCodeProtocol) + } + } + return nil +} + +// A SettingID is an HTTP/2 setting as defined in +// http://http2.github.io/http2-spec/#iana-settings +type http2SettingID uint16 + +const ( + http2SettingHeaderTableSize http2SettingID = 0x1 + http2SettingEnablePush http2SettingID = 0x2 + http2SettingMaxConcurrentStreams http2SettingID = 0x3 + http2SettingInitialWindowSize http2SettingID = 0x4 + http2SettingMaxFrameSize http2SettingID = 0x5 + http2SettingMaxHeaderListSize http2SettingID = 0x6 +) + +var http2settingName = map[http2SettingID]string{ + http2SettingHeaderTableSize: "HEADER_TABLE_SIZE", + http2SettingEnablePush: "ENABLE_PUSH", + http2SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", + http2SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", + http2SettingMaxFrameSize: "MAX_FRAME_SIZE", + http2SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", +} + +func (s http2SettingID) String() string { + if v, ok := http2settingName[s]; ok { + return v + } + return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) +} + +// validWireHeaderFieldName reports whether v is a valid header field +// name (key). See httpguts.ValidHeaderName for the base rules. +// +// Further, http2 says: +// "Just as in HTTP/1.x, header field names are strings of ASCII +// characters that are compared in a case-insensitive +// fashion. However, header field names MUST be converted to +// lowercase prior to their encoding in HTTP/2. " +func http2validWireHeaderFieldName(v string) bool { + if len(v) == 0 { + return false + } + for _, r := range v { + if !httpguts.IsTokenRune(r) { + return false + } + if 'A' <= r && r <= 'Z' { + return false + } + } + return true +} + +func http2httpCodeString(code int) string { + switch code { + case 200: + return "200" + case 404: + return "404" + } + return strconv.Itoa(code) +} + +// from pkg io +type http2stringWriter interface { + WriteString(s string) (n int, err error) +} + +// A gate lets two goroutines coordinate their activities. +type http2gate chan struct{} + +func (g http2gate) Done() { g <- struct{}{} } + +func (g http2gate) Wait() { <-g } + +// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). +type http2closeWaiter chan struct{} + +// Init makes a closeWaiter usable. +// It exists because so a closeWaiter value can be placed inside a +// larger struct and have the Mutex and Cond's memory in the same +// allocation. +func (cw *http2closeWaiter) Init() { + *cw = make(chan struct{}) +} + +// Close marks the closeWaiter as closed and unblocks any waiters. +func (cw http2closeWaiter) Close() { + close(cw) +} + +// Wait waits for the closeWaiter to become closed. +func (cw http2closeWaiter) Wait() { + <-cw +} + +// bufferedWriter is a buffered writer that writes to w. +// Its buffered writer is lazily allocated as needed, to minimize +// idle memory usage with many connections. +type http2bufferedWriter struct { + _ http2incomparable + w io.Writer // immutable + bw *bufio.Writer // non-nil when data is buffered +} + +func http2newBufferedWriter(w io.Writer) *http2bufferedWriter { + return &http2bufferedWriter{w: w} +} + +// bufWriterPoolBufferSize is the size of bufio.Writer's +// buffers created using bufWriterPool. +// +// TODO: pick a less arbitrary value? this is a bit under +// (3 x typical 1500 byte MTU) at least. Other than that, +// not much thought went into it. +const http2bufWriterPoolBufferSize = 4 << 10 + +var http2bufWriterPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize) + }, +} + +func (w *http2bufferedWriter) Available() int { + if w.bw == nil { + return http2bufWriterPoolBufferSize + } + return w.bw.Available() +} + +func (w *http2bufferedWriter) Write(p []byte) (n int, err error) { + if w.bw == nil { + bw := http2bufWriterPool.Get().(*bufio.Writer) + bw.Reset(w.w) + w.bw = bw + } + return w.bw.Write(p) +} + +func (w *http2bufferedWriter) Flush() error { + bw := w.bw + if bw == nil { + return nil + } + err := bw.Flush() + bw.Reset(nil) + http2bufWriterPool.Put(bw) + w.bw = nil + return err +} + +func http2mustUint31(v int32) uint32 { + if v < 0 || v > 2147483647 { + panic("out of range") + } + return uint32(v) +} + +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 7230, section 3.3. +func http2bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +type http2httpError struct { + _ http2incomparable + msg string + timeout bool +} + +func (e *http2httpError) Error() string { return e.msg } + +func (e *http2httpError) Timeout() bool { return e.timeout } + +func (e *http2httpError) Temporary() bool { return true } + +var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true} + +type http2connectionStater interface { + ConnectionState() tls.ConnectionState +} + +var http2sorterPool = sync.Pool{New: func() interface{} { return new(http2sorter) }} + +type http2sorter struct { + v []string // owned by sorter +} + +func (s *http2sorter) Len() int { return len(s.v) } + +func (s *http2sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] } + +func (s *http2sorter) Less(i, j int) bool { return s.v[i] < s.v[j] } + +// Keys returns the sorted keys of h. +// +// The returned slice is only valid until s used again or returned to +// its pool. +func (s *http2sorter) Keys(h http.Header) []string { + keys := s.v[:0] + for k := range h { + keys = append(keys, k) + } + s.v = keys + sort.Sort(s) + return keys +} + +func (s *http2sorter) SortStrings(ss []string) { + // Our sorter works on s.v, which sorter owns, so + // stash it away while we sort the user's buffer. + save := s.v + s.v = ss + sort.Sort(s) + s.v = save +} + +// validPseudoPath reports whether v is a valid :path pseudo-header +// value. It must be either: +// +// *) a non-empty string starting with '/' +// *) the string '*', for OPTIONS requests. +// +// For now this is only used a quick check for deciding when to clean +// up Opaque URLs before sending requests from the Transport. +// See golang.org/issue/16847 +// +// We used to enforce that the path also didn't start with "//", but +// Google's GFE accepts such paths and Chrome sends them, so ignore +// that part of the spec. See golang.org/issue/19103. +func http2validPseudoPath(v string) bool { + return (len(v) > 0 && v[0] == '/') || v == "*" +} + +// incomparable is a zero-width, non-comparable type. Adding it to a struct +// makes that struct also non-comparable, and generally doesn't add +// any size (as long as it's first). +type http2incomparable [0]func() + +// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like +// io.Pipe except there are no PipeReader/PipeWriter halves, and the +// underlying buffer is an interface. (io.Pipe is always unbuffered) +type http2pipe struct { + mu sync.Mutex + c sync.Cond // c.L lazily initialized to &p.mu + b http2pipeBuffer // nil when done reading + unread int // bytes unread when done + err error // read error once empty. non-nil means closed. + breakErr error // immediate read error (caller doesn't see rest of b) + donec chan struct{} // closed on error + readFn func() // optional code to run in Read before error +} + +type http2pipeBuffer interface { + Len() int + io.Writer + io.Reader +} + +// setBuffer initializes the pipe buffer. +// It has no effect if the pipe is already closed. +func (p *http2pipe) setBuffer(b http2pipeBuffer) { + p.mu.Lock() + defer p.mu.Unlock() + if p.err != nil || p.breakErr != nil { + return + } + p.b = b +} + +func (p *http2pipe) Len() int { + p.mu.Lock() + defer p.mu.Unlock() + if p.b == nil { + return p.unread + } + return p.b.Len() +} + +// Read waits until data is available and copies bytes +// from the buffer into p. +func (p *http2pipe) Read(d []byte) (n int, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + for { + if p.breakErr != nil { + return 0, p.breakErr + } + if p.b != nil && p.b.Len() > 0 { + return p.b.Read(d) + } + if p.err != nil { + if p.readFn != nil { + p.readFn() // e.g. copy trailers + p.readFn = nil // not sticky like p.err + } + p.b = nil + return 0, p.err + } + p.c.Wait() + } +} + +var http2errClosedPipeWrite = errors.New("write on closed buffer") + +// Write copies bytes from p into the buffer and wakes a reader. +// It is an error to write more data than the buffer can hold. +func (p *http2pipe) Write(d []byte) (n int, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + defer p.c.Signal() + if p.err != nil { + return 0, http2errClosedPipeWrite + } + if p.breakErr != nil { + p.unread += len(d) + return len(d), nil // discard when there is no reader + } + return p.b.Write(d) +} + +// CloseWithError causes the next Read (waking up a current blocked +// Read if needed) to return the provided err after all data has been +// read. +// +// The error must be non-nil. +func (p *http2pipe) CloseWithError(err error) { p.closeWithError(&p.err, err, nil) } + +// BreakWithError causes the next Read (waking up a current blocked +// Read if needed) to return the provided err immediately, without +// waiting for unread data. +func (p *http2pipe) BreakWithError(err error) { p.closeWithError(&p.breakErr, err, nil) } + +// closeWithErrorAndCode is like CloseWithError but also sets some code to run +// in the caller's goroutine before returning the error. +func (p *http2pipe) closeWithErrorAndCode(err error, fn func()) { p.closeWithError(&p.err, err, fn) } + +func (p *http2pipe) closeWithError(dst *error, err error, fn func()) { + if err == nil { + panic("err must be non-nil") + } + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + defer p.c.Signal() + if *dst != nil { + // Already been done. + return + } + p.readFn = fn + if dst == &p.breakErr { + if p.b != nil { + p.unread += p.b.Len() + } + p.b = nil + } + *dst = err + p.closeDoneLocked() +} + +// requires p.mu be held. +func (p *http2pipe) closeDoneLocked() { + if p.donec == nil { + return + } + // Close if unclosed. This isn't racy since we always + // hold p.mu while closing. + select { + case <-p.donec: + default: + close(p.donec) + } +} + +// Err returns the error (if any) first set by BreakWithError or CloseWithError. +func (p *http2pipe) Err() error { + p.mu.Lock() + defer p.mu.Unlock() + if p.breakErr != nil { + return p.breakErr + } + return p.err +} + +// Done returns a channel which is closed if and when this pipe is closed +// with CloseWithError. +func (p *http2pipe) Done() <-chan struct{} { + p.mu.Lock() + defer p.mu.Unlock() + if p.donec == nil { + p.donec = make(chan struct{}) + if p.err != nil || p.breakErr != nil { + // Already hit an error. + p.closeDoneLocked() + } + } + return p.donec +} + +const ( + http2prefaceTimeout = 10 * time.Second + http2firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway + http2handlerChunkWriteSize = 4 << 10 + http2defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? + http2maxQueuedControlFrames = 10000 +) + +var ( + http2errClientDisconnected = errors.New("client disconnected") + http2errClosedBody = errors.New("body closed by handler") + http2errHandlerComplete = errors.New("http2: request body closed due to handler exiting") + http2errStreamClosed = errors.New("http2: stream closed") +) + +var http2responseWriterStatePool = sync.Pool{ + New: func() interface{} { + rws := &http2responseWriterState{} + rws.bw = bufio.NewWriterSize(http2chunkWriter{rws}, http2handlerChunkWriteSize) + return rws + }, +} + +// Test hooks. +var ( + http2testHookOnConn func() + http2testHookGetServerConn func(*http2serverConn) + http2testHookOnPanicMu *sync.Mutex // nil except in tests + http2testHookOnPanic func(sc *http2serverConn, panicVal interface{}) (rePanic bool) +) + +// Server is an HTTP/2 server. +type http2Server struct { + // MaxHandlers limits the number of http.Handler ServeHTTP goroutines + // which may run at a time over all connections. + // Negative or zero no limit. + // TODO: implement + MaxHandlers int + + // MaxConcurrentStreams optionally specifies the number of + // concurrent streams that each client may have open at a + // time. This is unrelated to the number of http.Handler goroutines + // which may be active globally, which is MaxHandlers. + // If zero, MaxConcurrentStreams defaults to at least 100, per + // the HTTP/2 spec's recommendations. + MaxConcurrentStreams uint32 + + // MaxReadFrameSize optionally specifies the largest frame + // this server is willing to read. A valid value is between + // 16k and 16M, inclusive. If zero or otherwise invalid, a + // default value is used. + MaxReadFrameSize uint32 + + // PermitProhibitedCipherSuites, if true, permits the use of + // cipher suites prohibited by the HTTP/2 spec. + PermitProhibitedCipherSuites bool + + // IdleTimeout specifies how long until idle clients should be + // closed with a GOAWAY frame. PING frames are not considered + // activity for the purposes of IdleTimeout. + IdleTimeout time.Duration + + // MaxUploadBufferPerConnection is the size of the initial flow + // control window for each connections. The HTTP/2 spec does not + // allow this to be smaller than 65535 or larger than 2^32-1. + // If the value is outside this range, a default value will be + // used instead. + MaxUploadBufferPerConnection int32 + + // MaxUploadBufferPerStream is the size of the initial flow control + // window for each stream. The HTTP/2 spec does not allow this to + // be larger than 2^32-1. If the value is zero or larger than the + // maximum, a default value will be used instead. + MaxUploadBufferPerStream int32 + + // NewWriteScheduler constructs a write scheduler for a connection. + // If nil, a default scheduler is chosen. + NewWriteScheduler func() http2WriteScheduler + + // CountError, if non-nil, is called on HTTP/2 server errors. + // It's intended to increment a metric for monitoring, such + // as an expvar or Prometheus metric. + // The errType consists of only ASCII word characters. + CountError func(errType string) + + // Internal state. This is a pointer (rather than embedded directly) + // so that we don't embed a Mutex in this struct, which will make the + // struct non-copyable, which might break some callers. + state *http2serverInternalState +} + +func (s *http2Server) initialConnRecvWindowSize() int32 { + if s.MaxUploadBufferPerConnection > http2initialWindowSize { + return s.MaxUploadBufferPerConnection + } + return 1 << 20 +} + +func (s *http2Server) initialStreamRecvWindowSize() int32 { + if s.MaxUploadBufferPerStream > 0 { + return s.MaxUploadBufferPerStream + } + return 1 << 20 +} + +func (s *http2Server) maxReadFrameSize() uint32 { + if v := s.MaxReadFrameSize; v >= http2minMaxFrameSize && v <= http2maxFrameSize { + return v + } + return http2defaultMaxReadFrameSize +} + +func (s *http2Server) maxConcurrentStreams() uint32 { + if v := s.MaxConcurrentStreams; v > 0 { + return v + } + return http2defaultMaxStreams +} + +// maxQueuedControlFrames is the maximum number of control frames like +// SETTINGS, PING and RST_STREAM that will be queued for writing before +// the connection is closed to prevent memory exhaustion attacks. +func (s *http2Server) maxQueuedControlFrames() int { + // TODO: if anybody asks, add a Server field, and remember to define the + // behavior of negative values. + return http2maxQueuedControlFrames +} + +type http2serverInternalState struct { + mu sync.Mutex + activeConns map[*http2serverConn]struct{} +} + +func (s *http2serverInternalState) registerConn(sc *http2serverConn) { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + s.activeConns[sc] = struct{}{} + s.mu.Unlock() +} + +func (s *http2serverInternalState) unregisterConn(sc *http2serverConn) { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + delete(s.activeConns, sc) + s.mu.Unlock() +} + +func (s *http2serverInternalState) startGracefulShutdown() { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + for sc := range s.activeConns { + sc.startGracefulShutdown() + } + s.mu.Unlock() +} + +// ServeConnOpts are options for the Server.ServeConn method. +type http2ServeConnOpts struct { + // Context is the base context to use. + // If nil, context.Background is used. + Context context.Context + + // BaseConfig optionally sets the base configuration + // for values. If nil, defaults are used. + BaseConfig *http.Server + + // Handler specifies which handler to use for processing + // requests. If nil, BaseConfig.Handler is used. If BaseConfig + // or BaseConfig.Handler is nil, http.DefaultServeMux is used. + Handler http.Handler +} + +func (o *http2ServeConnOpts) context() context.Context { + if o != nil && o.Context != nil { + return o.Context + } + return context.Background() +} + +func (o *http2ServeConnOpts) baseConfig() *http.Server { + if o != nil && o.BaseConfig != nil { + return o.BaseConfig + } + return new(http.Server) +} + +func (o *http2ServeConnOpts) handler() http.Handler { + if o != nil { + if o.Handler != nil { + return o.Handler + } + if o.BaseConfig != nil && o.BaseConfig.Handler != nil { + return o.BaseConfig.Handler + } + } + return http.DefaultServeMux +} + +func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx context.Context, cancel func()) { + ctx, cancel = context.WithCancel(opts.context()) + ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr()) + if hs := opts.baseConfig(); hs != nil { + ctx = context.WithValue(ctx, http.ServerContextKey, hs) + } + return +} + +func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) { + sc.vlogf("http2: server rejecting conn: %v, %s", err, debug) + // ignoring errors. hanging up anyway. + sc.framer.WriteGoAway(0, err, []byte(debug)) + sc.bw.Flush() + sc.conn.Close() +} + +type http2serverConn struct { + // Immutable: + srv *http2Server + hs *http.Server + conn net.Conn + bw *http2bufferedWriter // writing to conn + handler http.Handler + baseCtx context.Context + framer *http2Framer + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan http2readFrameResult // written by serverConn.readFrames + wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve + wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes + bodyReadCh chan http2bodyReadMsg // from handlers -> serve + serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop + flow http2flow // conn-wide (not stream-specific) outbound flow control + inflow http2flow // conn-wide inbound flow control + tlsState *tls.ConnectionState // shared by all handlers, like net/http + remoteAddrStr string + writeSched http2WriteScheduler + + // Everything following is owned by the serve loop; use serveG.check(): + serveG http2goroutineLock // used to verify funcs are on serve() + pushEnabled bool + sawFirstSettings bool // got the initial SETTINGS frame after the preface + needToSendSettingsAck bool + unackedSettings int // how many SETTINGS have we sent without ACKs? + queuedControlFrames int // control frames in the writeSched queue + clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) + advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client + curClientStreams uint32 // number of open streams initiated by the client + curPushedStreams uint32 // number of open streams initiated by server push + maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests + maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes + streams map[uint32]*http2stream + initialStreamSendWindowSize int32 + maxFrameSize int32 + headerTableSize uint32 + peerMaxHeaderListSize uint32 // zero means unknown (default) + canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case + writingFrame bool // started writing a frame (on serve goroutine or separate) + writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh + needsFrameFlush bool // last frame write wasn't a flush + inGoAway bool // we've started to or sent GOAWAY + inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop + needToSendGoAway bool // we need to schedule a GOAWAY frame write + goAwayCode http2ErrCode + shutdownTimer *time.Timer // nil until used + idleTimer *time.Timer // nil if unused + + // Owned by the writeFrameAsync goroutine: + headerWriteBuf bytes.Buffer + hpackEncoder *hpack.Encoder + + // Used by startGracefulShutdown. + shutdownOnce sync.Once +} + +func (sc *http2serverConn) maxHeaderListSize() uint32 { + n := sc.hs.MaxHeaderBytes + if n <= 0 { + n = http.DefaultMaxHeaderBytes + } + // http2's count is in a slightly different unit and includes 32 bytes per pair. + // So, take the net/http.Server value and pad it up a bit, assuming 10 headers. + const perFieldOverhead = 32 // per http2 spec + const typicalHeaders = 10 // conservative + return uint32(n + typicalHeaders*perFieldOverhead) +} + +func (sc *http2serverConn) curOpenStreams() uint32 { + sc.serveG.check() + return sc.curClientStreams + sc.curPushedStreams +} + +// stream represents a stream. This is the minimal metadata needed by +// the serve goroutine. Most of the actual stream state is owned by +// the http.Handler's goroutine in the responseWriter. Because the +// responseWriter's responseWriterState is recycled at the end of a +// handler, this struct intentionally has no pointer to the +// *responseWriter{,State} itself, as the Handler ending nils out the +// responseWriter's state field. +type http2stream struct { + // immutable: + sc *http2serverConn + id uint32 + body *http2pipe // non-nil if expecting DATA frames + cw http2closeWaiter // closed wait stream transitions to closed state + ctx context.Context + cancelCtx func() + + // owned by serverConn's serve loop: + bodyBytes int64 // body bytes seen so far + declBodyBytes int64 // or -1 if undeclared + flow http2flow // limits writing from Handler to client + inflow http2flow // what the client is allowed to POST/etc to us + state http2streamState + resetQueued bool // RST_STREAM queued for write; set by sc.resetStream + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + writeDeadline *time.Timer // nil if unused + + trailer http.Header // accumulated trailers + reqTrailer http.Header // handler's Request.Trailer +} + +func (sc *http2serverConn) Framer() *http2Framer { return sc.framer } + +func (sc *http2serverConn) CloseConn() error { return sc.conn.Close() } + +func (sc *http2serverConn) Flush() error { return sc.bw.Flush() } + +func (sc *http2serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) { + return sc.hpackEncoder, &sc.headerWriteBuf +} + +func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2stream) { + sc.serveG.check() + // http://tools.ietf.org/html/rfc7540#section-5.1 + if st, ok := sc.streams[streamID]; ok { + return st.state, st + } + // "The first use of a new stream identifier implicitly closes all + // streams in the "idle" state that might have been initiated by + // that peer with a lower-valued stream identifier. For example, if + // a client sends a HEADERS frame on stream 7 without ever sending a + // frame on stream 5, then stream 5 transitions to the "closed" + // state when the first frame for stream 7 is sent or received." + if streamID%2 == 1 { + if streamID <= sc.maxClientStreamID { + return http2stateClosed, nil + } + } else { + if streamID <= sc.maxPushPromiseID { + return http2stateClosed, nil + } + } + return http2stateIdle, nil +} + +// setConnState calls the net/http ConnState hook for this connection, if configured. +// Note that the net/http package does StateNew and StateClosed for us. +// There is currently no plan for StateHijacked or hijacking HTTP/2 connections. +func (sc *http2serverConn) setConnState(state http.ConnState) { + if sc.hs.ConnState != nil { + sc.hs.ConnState(sc.conn, state) + } +} + +func (sc *http2serverConn) vlogf(format string, args ...interface{}) { + if http2VerboseLogs { + sc.logf(format, args...) + } +} + +func (sc *http2serverConn) logf(format string, args ...interface{}) { + if lg := sc.hs.ErrorLog; lg != nil { + lg.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +// errno returns v's underlying uintptr, else 0. +// +// TODO: remove this helper function once http2 can use build +// tags. See comment in isClosedConnError. +func http2errno(v error) uintptr { + if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr { + return uintptr(rv.Uint()) + } + return 0 +} + +// isClosedConnError reports whether err is an error from use of a closed +// network connection. +func http2isClosedConnError(err error) bool { + if err == nil { + return false + } + + // TODO: remove this string search and be more like the Windows + // case below. That might involve modifying the standard library + // to return better error types. + str := err.Error() + if strings.Contains(str, "use of closed network connection") { + return true + } + + // TODO(bradfitz): x/tools/cmd/bundle doesn't really support + // build tags, so I can't make an http2_windows.go file with + // Windows-specific stuff. Fix that and move this, once we + // have a way to bundle this into std's net/http somehow. + if runtime.GOOS == "windows" { + if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { + if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" { + const WSAECONNABORTED = 10053 + const WSAECONNRESET = 10054 + if n := http2errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED { + return true + } + } + } + } + return false +} + +func (sc *http2serverConn) condlogf(err error, format string, args ...interface{}) { + if err == nil { + return + } + if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) || err == http2errPrefaceTimeout { + // Boring, expected errors. + sc.vlogf(format, args...) + } else { + sc.logf(format, args...) + } +} + +func (sc *http2serverConn) canonicalHeader(v string) string { + sc.serveG.check() + http2buildCommonHeaderMapsOnce() + cv, ok := http2commonCanonHeader[v] + if ok { + return cv + } + cv, ok = sc.canonHeader[v] + if ok { + return cv + } + if sc.canonHeader == nil { + sc.canonHeader = make(map[string]string) + } + cv = http.CanonicalHeaderKey(v) + // maxCachedCanonicalHeaders is an arbitrarily-chosen limit on the number of + // entries in the canonHeader cache. This should be larger than the number + // of unique, uncommon header keys likely to be sent by the peer, while not + // so high as to permit unreasonable memory usage if the peer sends an unbounded + // number of unique header keys. + const maxCachedCanonicalHeaders = 32 + if len(sc.canonHeader) < maxCachedCanonicalHeaders { + sc.canonHeader[v] = cv + } + return cv +} + +type http2readFrameResult struct { + f http2Frame // valid until readMore is called + err error + + // readMore should be called once the consumer no longer needs or + // retains f. After readMore, f is invalid and more frames can be + // read. + readMore func() +} + +// readFrames is the loop that reads incoming frames. +// It takes care to only read one frame at a time, blocking until the +// consumer is done with the frame. +// It's run on its own goroutine. +func (sc *http2serverConn) readFrames() { + gate := make(http2gate) + gateDone := gate.Done + for { + f, err := sc.framer.ReadFrame() + select { + case sc.readFrameCh <- http2readFrameResult{f, err, gateDone}: + case <-sc.doneServing: + return + } + select { + case <-gate: + case <-sc.doneServing: + return + } + if http2terminalReadFrameError(err) { + return + } + } +} + +// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. +type http2frameWriteResult struct { + _ http2incomparable + wr http2FrameWriteRequest // what was written (or attempted) + err error // result of the writeFrame call +} + +// writeFrameAsync runs in its own goroutine and writes a single frame +// and then reports when it's done. +// At most one goroutine can be running writeFrameAsync at a time per +// serverConn. +func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { + err := wr.write.writeFrame(sc) + sc.wroteFrameCh <- http2frameWriteResult{wr: wr, err: err} +} + +func (sc *http2serverConn) closeAllStreamsOnConnClose() { + sc.serveG.check() + for _, st := range sc.streams { + sc.closeStream(st, http2errClientDisconnected) + } +} + +func (sc *http2serverConn) stopShutdownTimer() { + sc.serveG.check() + if t := sc.shutdownTimer; t != nil { + t.Stop() + } +} + +func (sc *http2serverConn) notePanic() { + // Note: this is for serverConn.serve panicking, not http.Handler code. + if http2testHookOnPanicMu != nil { + http2testHookOnPanicMu.Lock() + defer http2testHookOnPanicMu.Unlock() + } + if http2testHookOnPanic != nil { + if e := recover(); e != nil { + if http2testHookOnPanic(sc, e) { + panic(e) + } + } + } +} + +func (sc *http2serverConn) serve() { + sc.serveG.check() + defer sc.notePanic() + defer sc.conn.Close() + defer sc.closeAllStreamsOnConnClose() + defer sc.stopShutdownTimer() + defer close(sc.doneServing) // unblocks handlers trying to send + + if http2VerboseLogs { + sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) + } + + sc.writeFrame(http2FrameWriteRequest{ + write: http2writeSettings{ + {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, + {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, + {http2SettingMaxHeaderListSize, sc.maxHeaderListSize()}, + {http2SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, + }, + }) + sc.unackedSettings++ + + // Each connection starts with initialWindowSize inflow tokens. + // If a higher value is configured, we add more tokens. + if diff := sc.srv.initialConnRecvWindowSize() - http2initialWindowSize; diff > 0 { + sc.sendWindowUpdate(nil, int(diff)) + } + + if err := sc.readPreface(); err != nil { + sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err) + return + } + // Now that we've got the preface, get us out of the + // "StateNew" state. We can't go directly to idle, though. + // Active means we read some data and anticipate a request. We'll + // do another Active when we get a HEADERS frame. + sc.setConnState(http.StateActive) + sc.setConnState(http.StateIdle) + + if sc.srv.IdleTimeout != 0 { + sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) + defer sc.idleTimer.Stop() + } + + go sc.readFrames() // closed by defer sc.conn.Close above + + settingsTimer := time.AfterFunc(http2firstSettingsTimeout, sc.onSettingsTimer) + defer settingsTimer.Stop() + + loopNum := 0 + for { + loopNum++ + select { + case wr := <-sc.wantWriteFrameCh: + if se, ok := wr.write.(http2StreamError); ok { + sc.resetStream(se) + break + } + sc.writeFrame(wr) + case res := <-sc.wroteFrameCh: + sc.wroteFrame(res) + case res := <-sc.readFrameCh: + // Process any written frames before reading new frames from the client since a + // written frame could have triggered a new stream to be started. + if sc.writingFrameAsync { + select { + case wroteRes := <-sc.wroteFrameCh: + sc.wroteFrame(wroteRes) + default: + } + } + if !sc.processFrameFromReader(res) { + return + } + res.readMore() + if settingsTimer != nil { + settingsTimer.Stop() + settingsTimer = nil + } + case m := <-sc.bodyReadCh: + sc.noteBodyRead(m.st, m.n) + case msg := <-sc.serveMsgCh: + switch v := msg.(type) { + case func(int): + v(loopNum) // for testing + case *http2serverMessage: + switch v { + case http2settingsTimerMsg: + sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) + return + case http2idleTimerMsg: + sc.vlogf("connection is idle") + sc.goAway(http2ErrCodeNo) + case http2shutdownTimerMsg: + sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) + return + case http2gracefulShutdownMsg: + sc.startGracefulShutdownInternal() + default: + panic("unknown timer") + } + case *http2startPushRequest: + sc.startPush(v) + default: + panic(fmt.Sprintf("unexpected type %T", v)) + } + } + + // If the peer is causing us to generate a lot of control frames, + // but not reading them from us, assume they are trying to make us + // run out of memory. + if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() { + sc.vlogf("http2: too many control frames in send queue, closing connection") + return + } + + // Start the shutdown timer after sending a GOAWAY. When sending GOAWAY + // with no error code (graceful shutdown), don't start the timer until + // all open streams have been completed. + sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame + gracefulShutdownComplete := sc.goAwayCode == http2ErrCodeNo && sc.curOpenStreams() == 0 + if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != http2ErrCodeNo || gracefulShutdownComplete) { + sc.shutDownIn(http2goAwayTimeout) + } + } +} + +func (sc *http2serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) { + select { + case <-sc.doneServing: + case <-sharedCh: + close(privateCh) + } +} + +type http2serverMessage int + +// Message values sent to serveMsgCh. +var ( + http2settingsTimerMsg = new(http2serverMessage) + http2idleTimerMsg = new(http2serverMessage) + http2shutdownTimerMsg = new(http2serverMessage) + http2gracefulShutdownMsg = new(http2serverMessage) +) + +func (sc *http2serverConn) onSettingsTimer() { sc.sendServeMsg(http2settingsTimerMsg) } + +func (sc *http2serverConn) onIdleTimer() { sc.sendServeMsg(http2idleTimerMsg) } + +func (sc *http2serverConn) onShutdownTimer() { sc.sendServeMsg(http2shutdownTimerMsg) } + +func (sc *http2serverConn) sendServeMsg(msg interface{}) { + sc.serveG.checkNotOn() // NOT + select { + case sc.serveMsgCh <- msg: + case <-sc.doneServing: + } +} + +var http2errPrefaceTimeout = errors.New("timeout waiting for client preface") + +// readPreface reads the ClientPreface greeting from the peer or +// returns errPrefaceTimeout on timeout, or an error if the greeting +// is invalid. +func (sc *http2serverConn) readPreface() error { + errc := make(chan error, 1) + go func() { + // Read the client preface + buf := make([]byte, len(http2ClientPreface)) + if _, err := io.ReadFull(sc.conn, buf); err != nil { + errc <- err + } else if !bytes.Equal(buf, http2clientPreface) { + errc <- fmt.Errorf("bogus greeting %q", buf) + } else { + errc <- nil + } + }() + timer := time.NewTimer(http2prefaceTimeout) // TODO: configurable on *Server? + defer timer.Stop() + select { + case <-timer.C: + return http2errPrefaceTimeout + case err := <-errc: + if err == nil { + if http2VerboseLogs { + sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr()) + } + } + return err + } +} + +var http2errChanPool = sync.Pool{ + New: func() interface{} { return make(chan error, 1) }, +} + +var http2writeDataPool = sync.Pool{ + New: func() interface{} { return new(http2writeData) }, +} + +// writeDataFromHandler writes DATA response frames from a handler on +// the given stream. +func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte, endStream bool) error { + ch := http2errChanPool.Get().(chan error) + writeArg := http2writeDataPool.Get().(*http2writeData) + *writeArg = http2writeData{stream.id, data, endStream} + err := sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: writeArg, + stream: stream, + done: ch, + }) + if err != nil { + return err + } + var frameWriteDone bool // the frame write is done (successfully or not) + select { + case err = <-ch: + frameWriteDone = true + case <-sc.doneServing: + return http2errClientDisconnected + case <-stream.cw: + // If both ch and stream.cw were ready (as might + // happen on the final Write after an http.Handler + // ends), prefer the write result. Otherwise this + // might just be us successfully closing the stream. + // The writeFrameAsync and serve goroutines guarantee + // that the ch send will happen before the stream.cw + // close. + select { + case err = <-ch: + frameWriteDone = true + default: + return http2errStreamClosed + } + } + http2errChanPool.Put(ch) + if frameWriteDone { + http2writeDataPool.Put(writeArg) + } + return err +} + +// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts +// if the connection has gone away. +// +// This must not be run from the serve goroutine itself, else it might +// deadlock writing to sc.wantWriteFrameCh (which is only mildly +// buffered and is read by serve itself). If you're on the serve +// goroutine, call writeFrame instead. +func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error { + sc.serveG.checkNotOn() // NOT + select { + case sc.wantWriteFrameCh <- wr: + return nil + case <-sc.doneServing: + // Serve loop is gone. + // Client has closed their connection to the server. + return http2errClientDisconnected + } +} + +// writeFrame schedules a frame to write and sends it if there's nothing +// already being written. +// +// There is no pushback here (the serve goroutine never blocks). It's +// the http.Handlers that block, waiting for their previous frames to +// make it onto the wire +// +// If you're not on the serve goroutine, use writeFrameFromHandler instead. +func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { + sc.serveG.check() + + // If true, wr will not be written and wr.done will not be signaled. + var ignoreWrite bool + + // We are not allowed to write frames on closed streams. RFC 7540 Section + // 5.1.1 says: "An endpoint MUST NOT send frames other than PRIORITY on + // a closed stream." Our server never sends PRIORITY, so that exception + // does not apply. + // + // The serverConn might close an open stream while the stream's handler + // is still running. For example, the server might close a stream when it + // receives bad data from the client. If this happens, the handler might + // attempt to write a frame after the stream has been closed (since the + // handler hasn't yet been notified of the close). In this case, we simply + // ignore the frame. The handler will notice that the stream is closed when + // it waits for the frame to be written. + // + // As an exception to this rule, we allow sending RST_STREAM after close. + // This allows us to immediately reject new streams without tracking any + // state for those streams (except for the queued RST_STREAM frame). This + // may result in duplicate RST_STREAMs in some cases, but the client should + // ignore those. + if wr.StreamID() != 0 { + _, isReset := wr.write.(http2StreamError) + if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset { + ignoreWrite = true + } + } + + // Don't send a 100-continue response if we've already sent headers. + // See golang.org/issue/14030. + switch wr.write.(type) { + case *http2writeResHeaders: + wr.stream.wroteHeaders = true + case http2write100ContinueHeadersFrame: + if wr.stream.wroteHeaders { + // We do not need to notify wr.done because this frame is + // never written with wr.done != nil. + if wr.done != nil { + panic("wr.done != nil for write100ContinueHeadersFrame") + } + ignoreWrite = true + } + } + + if !ignoreWrite { + if wr.isControl() { + sc.queuedControlFrames++ + // For extra safety, detect wraparounds, which should not happen, + // and pull the plug. + if sc.queuedControlFrames < 0 { + sc.conn.Close() + } + } + sc.writeSched.Push(wr) + } + sc.scheduleFrameWrite() +} + +// startFrameWrite starts a goroutine to write wr (in a separate +// goroutine since that might block on the network), and updates the +// serve goroutine's state about the world, updated from info in wr. +func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { + sc.serveG.check() + if sc.writingFrame { + panic("internal error: can only be writing one frame at a time") + } + + st := wr.stream + if st != nil { + switch st.state { + case http2stateHalfClosedLocal: + switch wr.write.(type) { + case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate: + // RFC 7540 Section 5.1 allows sending RST_STREAM, PRIORITY, and WINDOW_UPDATE + // in this state. (We never send PRIORITY from the server, so that is not checked.) + default: + panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr)) + } + case http2stateClosed: + panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr)) + } + } + if wpp, ok := wr.write.(*http2writePushPromise); ok { + var err error + wpp.promisedID, err = wpp.allocatePromisedID() + if err != nil { + sc.writingFrameAsync = false + wr.replyToWriter(err) + return + } + } + + sc.writingFrame = true + sc.needsFrameFlush = true + if wr.write.staysWithinBuffer(sc.bw.Available()) { + sc.writingFrameAsync = false + err := wr.write.writeFrame(sc) + sc.wroteFrame(http2frameWriteResult{wr: wr, err: err}) + } else { + sc.writingFrameAsync = true + go sc.writeFrameAsync(wr) + } +} + +// errHandlerPanicked is the error given to any callers blocked in a read from +// Request.Body when the main goroutine panics. Since most handlers read in the +// main ServeHTTP goroutine, this will show up rarely. +var http2errHandlerPanicked = errors.New("http2: handler panicked") + +// wroteFrame is called on the serve goroutine with the result of +// whatever happened on writeFrameAsync. +func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { + sc.serveG.check() + if !sc.writingFrame { + panic("internal error: expected to be already writing a frame") + } + sc.writingFrame = false + sc.writingFrameAsync = false + + wr := res.wr + + if http2writeEndsStream(wr.write) { + st := wr.stream + if st == nil { + panic("internal error: expecting non-nil stream") + } + switch st.state { + case http2stateOpen: + // Here we would go to stateHalfClosedLocal in + // theory, but since our handler is done and + // the net/http package provides no mechanism + // for closing a ResponseWriter while still + // reading data (see possible TODO at top of + // this file), we go into closed state here + // anyway, after telling the peer we're + // hanging up on them. We'll transition to + // stateClosed after the RST_STREAM frame is + // written. + st.state = http2stateHalfClosedLocal + // Section 8.1: a server MAY request that the client abort + // transmission of a request without error by sending a + // RST_STREAM with an error code of NO_ERROR after sending + // a complete response. + sc.resetStream(http2streamError(st.id, http2ErrCodeNo)) + case http2stateHalfClosedRemote: + sc.closeStream(st, http2errHandlerComplete) + } + } else { + switch v := wr.write.(type) { + case http2StreamError: + // st may be unknown if the RST_STREAM was generated to reject bad input. + if st, ok := sc.streams[v.StreamID]; ok { + sc.closeStream(st, v) + } + case http2handlerPanicRST: + sc.closeStream(wr.stream, http2errHandlerPanicked) + } + } + + // Reply (if requested) to unblock the ServeHTTP goroutine. + wr.replyToWriter(res.err) + + sc.scheduleFrameWrite() +} + +// scheduleFrameWrite tickles the frame writing scheduler. +// +// If a frame is already being written, nothing happens. This will be called again +// when the frame is done being written. +// +// If a frame isn't being written and we need to send one, the best frame +// to send is selected by writeSched. +// +// If a frame isn't being written and there's nothing else to send, we +// flush the write buffer. +func (sc *http2serverConn) scheduleFrameWrite() { + sc.serveG.check() + if sc.writingFrame || sc.inFrameScheduleLoop { + return + } + sc.inFrameScheduleLoop = true + for !sc.writingFrameAsync { + if sc.needToSendGoAway { + sc.needToSendGoAway = false + sc.startFrameWrite(http2FrameWriteRequest{ + write: &http2writeGoAway{ + maxStreamID: sc.maxClientStreamID, + code: sc.goAwayCode, + }, + }) + continue + } + if sc.needToSendSettingsAck { + sc.needToSendSettingsAck = false + sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}}) + continue + } + if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo { + if wr, ok := sc.writeSched.Pop(); ok { + if wr.isControl() { + sc.queuedControlFrames-- + } + sc.startFrameWrite(wr) + continue + } + } + if sc.needsFrameFlush { + sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}}) + sc.needsFrameFlush = false // after startFrameWrite, since it sets this true + continue + } + break + } + sc.inFrameScheduleLoop = false +} + +// startGracefulShutdown gracefully shuts down a connection. This +// sends GOAWAY with ErrCodeNo to tell the client we're gracefully +// shutting down. The connection isn't closed until all current +// streams are done. +// +// startGracefulShutdown returns immediately; it does not wait until +// the connection has shut down. +func (sc *http2serverConn) startGracefulShutdown() { + sc.serveG.checkNotOn() // NOT + sc.shutdownOnce.Do(func() { sc.sendServeMsg(http2gracefulShutdownMsg) }) +} + +// After sending GOAWAY with an error code (non-graceful shutdown), the +// connection will close after goAwayTimeout. +// +// If we close the connection immediately after sending GOAWAY, there may +// be unsent data in our kernel receive buffer, which will cause the kernel +// to send a TCP RST on close() instead of a FIN. This RST will abort the +// connection immediately, whether or not the client had received the GOAWAY. +// +// Ideally we should delay for at least 1 RTT + epsilon so the client has +// a chance to read the GOAWAY and stop sending messages. Measuring RTT +// is hard, so we approximate with 1 second. See golang.org/issue/18701. +// +// This is a var so it can be shorter in tests, where all requests uses the +// loopback interface making the expected RTT very small. +// +// TODO: configurable? +var http2goAwayTimeout = 1 * time.Second + +func (sc *http2serverConn) startGracefulShutdownInternal() { + sc.goAway(http2ErrCodeNo) +} + +func (sc *http2serverConn) goAway(code http2ErrCode) { + sc.serveG.check() + if sc.inGoAway { + return + } + sc.inGoAway = true + sc.needToSendGoAway = true + sc.goAwayCode = code + sc.scheduleFrameWrite() +} + +func (sc *http2serverConn) shutDownIn(d time.Duration) { + sc.serveG.check() + sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) +} + +func (sc *http2serverConn) resetStream(se http2StreamError) { + sc.serveG.check() + sc.writeFrame(http2FrameWriteRequest{write: se}) + if st, ok := sc.streams[se.StreamID]; ok { + st.resetQueued = true + } +} + +// processFrameFromReader processes the serve loop's read from readFrameCh from the +// frame-reading goroutine. +// processFrameFromReader returns whether the connection should be kept open. +func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool { + sc.serveG.check() + err := res.err + if err != nil { + if err == http2ErrFrameTooLarge { + sc.goAway(http2ErrCodeFrameSize) + return true // goAway will close the loop + } + clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) + if clientGone { + // TODO: could we also get into this state if + // the peer does a half close + // (e.g. CloseWrite) because they're done + // sending frames but they're still wanting + // our open replies? Investigate. + // TODO: add CloseWrite to crypto/tls.Conn first + // so we have a way to test this? I suppose + // just for testing we could have a non-TLS mode. + return false + } + } else { + f := res.f + if http2VerboseLogs { + sc.vlogf("http2: server read frame %v", http2summarizeFrame(f)) + } + err = sc.processFrame(f) + if err == nil { + return true + } + } + + switch ev := err.(type) { + case http2StreamError: + sc.resetStream(ev) + return true + case http2goAwayFlowError: + sc.goAway(http2ErrCodeFlowControl) + return true + case http2ConnectionError: + sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev) + sc.goAway(http2ErrCode(ev)) + return true // goAway will handle shutdown + default: + if res.err != nil { + sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err) + } else { + sc.logf("http2: server closing client connection: %v", err) + } + return false + } +} + +func (sc *http2serverConn) processFrame(f http2Frame) error { + sc.serveG.check() + + // First frame received must be SETTINGS. + if !sc.sawFirstSettings { + if _, ok := f.(*http2SettingsFrame); !ok { + return sc.countError("first_settings", http2ConnectionError(http2ErrCodeProtocol)) + } + sc.sawFirstSettings = true + } + + switch f := f.(type) { + case *http2SettingsFrame: + return sc.processSettings(f) + case *http2MetaHeadersFrame: + return sc.processHeaders(f) + case *http2WindowUpdateFrame: + return sc.processWindowUpdate(f) + case *http2PingFrame: + return sc.processPing(f) + case *http2DataFrame: + return sc.processData(f) + case *http2RSTStreamFrame: + return sc.processResetStream(f) + case *http2PriorityFrame: + return sc.processPriority(f) + case *http2GoAwayFrame: + return sc.processGoAway(f) + case *http2PushPromiseFrame: + // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE + // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. + return sc.countError("push_promise", http2ConnectionError(http2ErrCodeProtocol)) + default: + sc.vlogf("http2: server ignoring frame: %v", f.Header()) + return nil + } +} + +func (sc *http2serverConn) processPing(f *http2PingFrame) error { + sc.serveG.check() + if f.IsAck() { + // 6.7 PING: " An endpoint MUST NOT respond to PING frames + // containing this flag." + return nil + } + if f.StreamID != 0 { + // "PING frames are not associated with any individual + // stream. If a PING frame is received with a stream + // identifier field value other than 0x0, the recipient MUST + // respond with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR." + return sc.countError("ping_on_stream", http2ConnectionError(http2ErrCodeProtocol)) + } + if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { + return nil + } + sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}}) + return nil +} + +func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error { + sc.serveG.check() + switch { + case f.StreamID != 0: // stream-level flow control + state, st := sc.state(f.StreamID) + if state == http2stateIdle { + // Section 5.1: "Receiving any frame other than HEADERS + // or PRIORITY on a stream in this state MUST be + // treated as a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR." + return sc.countError("stream_idle", http2ConnectionError(http2ErrCodeProtocol)) + } + if st == nil { + // "WINDOW_UPDATE can be sent by a peer that has sent a + // frame bearing the END_STREAM flag. This means that a + // receiver could receive a WINDOW_UPDATE frame on a "half + // closed (remote)" or "closed" stream. A receiver MUST + // NOT treat this as an error, see Section 5.1." + return nil + } + if !st.flow.add(int32(f.Increment)) { + return sc.countError("bad_flow", http2streamError(f.StreamID, http2ErrCodeFlowControl)) + } + default: // connection-level flow control + if !sc.flow.add(int32(f.Increment)) { + return http2goAwayFlowError{} + } + } + sc.scheduleFrameWrite() + return nil +} + +func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { + sc.serveG.check() + + state, st := sc.state(f.StreamID) + if state == http2stateIdle { + // 6.4 "RST_STREAM frames MUST NOT be sent for a + // stream in the "idle" state. If a RST_STREAM frame + // identifying an idle stream is received, the + // recipient MUST treat this as a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + return sc.countError("reset_idle_stream", http2ConnectionError(http2ErrCodeProtocol)) + } + if st != nil { + st.cancelCtx() + sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode)) + } + return nil +} + +func (sc *http2serverConn) closeStream(st *http2stream, err error) { + sc.serveG.check() + if st.state == http2stateIdle || st.state == http2stateClosed { + panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) + } + st.state = http2stateClosed + if st.writeDeadline != nil { + st.writeDeadline.Stop() + } + if st.isPushed() { + sc.curPushedStreams-- + } else { + sc.curClientStreams-- + } + delete(sc.streams, st.id) + if len(sc.streams) == 0 { + sc.setConnState(http.StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer.Reset(sc.srv.IdleTimeout) + } + if http2h1ServerKeepAlivesDisabled(sc.hs) { + sc.startGracefulShutdownInternal() + } + } + if p := st.body; p != nil { + // Return any buffered unread bytes worth of conn-level flow control. + // See golang.org/issue/16481 + sc.sendWindowUpdate(nil, p.Len()) + + p.CloseWithError(err) + } + st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc + sc.writeSched.CloseStream(st.id) +} + +func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { + sc.serveG.check() + if f.IsAck() { + sc.unackedSettings-- + if sc.unackedSettings < 0 { + // Why is the peer ACKing settings we never sent? + // The spec doesn't mention this case, but + // hang up on them anyway. + return sc.countError("ack_mystery", http2ConnectionError(http2ErrCodeProtocol)) + } + return nil + } + if f.NumSettings() > 100 || f.HasDuplicates() { + // This isn't actually in the spec, but hang up on + // suspiciously large settings frames or those with + // duplicate entries. + return sc.countError("settings_big_or_dups", http2ConnectionError(http2ErrCodeProtocol)) + } + if err := f.ForeachSetting(sc.processSetting); err != nil { + return err + } + // TODO: judging by RFC 7540, Section 6.5.3 each SETTINGS frame should be + // acknowledged individually, even if multiple are received before the ACK. + sc.needToSendSettingsAck = true + sc.scheduleFrameWrite() + return nil +} + +func (sc *http2serverConn) processSetting(s http2Setting) error { + sc.serveG.check() + if err := s.Valid(); err != nil { + return err + } + if http2VerboseLogs { + sc.vlogf("http2: server processing setting %v", s) + } + switch s.ID { + case http2SettingHeaderTableSize: + sc.headerTableSize = s.Val + sc.hpackEncoder.SetMaxDynamicTableSize(s.Val) + case http2SettingEnablePush: + sc.pushEnabled = s.Val != 0 + case http2SettingMaxConcurrentStreams: + sc.clientMaxStreams = s.Val + case http2SettingInitialWindowSize: + return sc.processSettingInitialWindowSize(s.Val) + case http2SettingMaxFrameSize: + sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31 + case http2SettingMaxHeaderListSize: + sc.peerMaxHeaderListSize = s.Val + default: + // Unknown setting: "An endpoint that receives a SETTINGS + // frame with any unknown or unsupported identifier MUST + // ignore that setting." + if http2VerboseLogs { + sc.vlogf("http2: server ignoring unknown setting %v", s) + } + } + return nil +} + +func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { + sc.serveG.check() + // Note: val already validated to be within range by + // processSetting's Valid call. + + // "A SETTINGS frame can alter the initial flow control window + // size for all current streams. When the value of + // SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST + // adjust the size of all stream flow control windows that it + // maintains by the difference between the new value and the + // old value." + old := sc.initialStreamSendWindowSize + sc.initialStreamSendWindowSize = int32(val) + growth := int32(val) - old // may be negative + for _, st := range sc.streams { + if !st.flow.add(growth) { + // 6.9.2 Initial Flow Control Window Size + // "An endpoint MUST treat a change to + // SETTINGS_INITIAL_WINDOW_SIZE that causes any flow + // control window to exceed the maximum size as a + // connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR." + return sc.countError("setting_win_size", http2ConnectionError(http2ErrCodeFlowControl)) + } + } + return nil +} + +func (sc *http2serverConn) processData(f *http2DataFrame) error { + sc.serveG.check() + id := f.Header().StreamID + if sc.inGoAway && (sc.goAwayCode != http2ErrCodeNo || id > sc.maxClientStreamID) { + // Discard all DATA frames if the GOAWAY is due to an + // error, or: + // + // Section 6.8: After sending a GOAWAY frame, the sender + // can discard frames for streams initiated by the + // receiver with identifiers higher than the identified + // last stream. + return nil + } + + data := f.Data() + state, st := sc.state(id) + if id == 0 || state == http2stateIdle { + // Section 6.1: "DATA frames MUST be associated with a + // stream. If a DATA frame is received whose stream + // identifier field is 0x0, the recipient MUST respond + // with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR." + // + // Section 5.1: "Receiving any frame other than HEADERS + // or PRIORITY on a stream in this state MUST be + // treated as a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR." + return sc.countError("data_on_idle", http2ConnectionError(http2ErrCodeProtocol)) + } + + // "If a DATA frame is received whose stream is not in "open" + // or "half closed (local)" state, the recipient MUST respond + // with a stream error (Section 5.4.2) of type STREAM_CLOSED." + if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued { + // This includes sending a RST_STREAM if the stream is + // in stateHalfClosedLocal (which currently means that + // the http.Handler returned, so it's done reading & + // done writing). Try to stop the client from sending + // more DATA. + + // But still enforce their connection-level flow control, + // and return any flow control bytes since we're not going + // to consume them. + if sc.inflow.available() < int32(f.Length) { + return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl)) + } + // Deduct the flow control from inflow, since we're + // going to immediately add it back in + // sendWindowUpdate, which also schedules sending the + // frames. + sc.inflow.take(int32(f.Length)) + sc.sendWindowUpdate(nil, int(f.Length)) // conn-level + + if st != nil && st.resetQueued { + // Already have a stream error in flight. Don't send another. + return nil + } + return sc.countError("closed", http2streamError(id, http2ErrCodeStreamClosed)) + } + if st.body == nil { + panic("internal error: should have a body in this state") + } + + // Sender sending more than they'd declared? + if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { + st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) + // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the + // value of a content-length header field does not equal the sum of the + // DATA frame payload lengths that form the body. + return sc.countError("send_too_much", http2streamError(id, http2ErrCodeProtocol)) + } + if f.Length > 0 { + // Check whether the client has flow control quota. + if st.inflow.available() < int32(f.Length) { + return sc.countError("flow_on_data_length", http2streamError(id, http2ErrCodeFlowControl)) + } + st.inflow.take(int32(f.Length)) + + if len(data) > 0 { + wrote, err := st.body.Write(data) + if err != nil { + sc.sendWindowUpdate(nil, int(f.Length)-wrote) + return sc.countError("body_write_err", http2streamError(id, http2ErrCodeStreamClosed)) + } + if wrote != len(data) { + panic("internal error: bad Writer") + } + st.bodyBytes += int64(len(data)) + } + + // Return any padded flow control now, since we won't + // refund it later on body reads. + if pad := int32(f.Length) - int32(len(data)); pad > 0 { + sc.sendWindowUpdate32(nil, pad) + sc.sendWindowUpdate32(st, pad) + } + } + if f.StreamEnded() { + st.endStream() + } + return nil +} + +func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error { + sc.serveG.check() + if f.ErrCode != http2ErrCodeNo { + sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } else { + sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } + sc.startGracefulShutdownInternal() + // http://tools.ietf.org/html/rfc7540#section-6.8 + // We should not create any new streams, which means we should disable push. + sc.pushEnabled = false + return nil +} + +// isPushed reports whether the stream is server-initiated. +func (st *http2stream) isPushed() bool { + return st.id%2 == 0 +} + +// endStream closes a Request.Body's pipe. It is called when a DATA +// frame says a request body is over (or after trailers). +func (st *http2stream) endStream() { + sc := st.sc + sc.serveG.check() + + if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes { + st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes", + st.declBodyBytes, st.bodyBytes)) + } else { + st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest) + st.body.CloseWithError(io.EOF) + } + st.state = http2stateHalfClosedRemote +} + +// copyTrailersToHandlerRequest is run in the Handler's goroutine in +// its Request.Body.Read just before it gets io.EOF. +func (st *http2stream) copyTrailersToHandlerRequest() { + for k, vv := range st.trailer { + if _, ok := st.reqTrailer[k]; ok { + // Only copy it over it was pre-declared. + st.reqTrailer[k] = vv + } + } +} + +// onWriteTimeout is run on its own goroutine (from time.AfterFunc) +// when the stream's WriteTimeout has fired. +func (st *http2stream) onWriteTimeout() { + st.sc.writeFrameFromHandler(http2FrameWriteRequest{write: http2streamError(st.id, http2ErrCodeInternal)}) +} + +func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { + sc.serveG.check() + id := f.StreamID + if sc.inGoAway { + // Ignore. + return nil + } + // http://tools.ietf.org/html/rfc7540#section-5.1.1 + // Streams initiated by a client MUST use odd-numbered stream + // identifiers. [...] An endpoint that receives an unexpected + // stream identifier MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + if id%2 != 1 { + return sc.countError("headers_even", http2ConnectionError(http2ErrCodeProtocol)) + } + // A HEADERS frame can be used to create a new stream or + // send a trailer for an open one. If we already have a stream + // open, let it process its own HEADERS frame (trailers at this + // point, if it's valid). + if st := sc.streams[f.StreamID]; st != nil { + if st.resetQueued { + // We're sending RST_STREAM to close the stream, so don't bother + // processing this frame. + return nil + } + // RFC 7540, sec 5.1: If an endpoint receives additional frames, other than + // WINDOW_UPDATE, PRIORITY, or RST_STREAM, for a stream that is in + // this state, it MUST respond with a stream error (Section 5.4.2) of + // type STREAM_CLOSED. + if st.state == http2stateHalfClosedRemote { + return sc.countError("headers_half_closed", http2streamError(id, http2ErrCodeStreamClosed)) + } + return st.processTrailerHeaders(f) + } + + // [...] The identifier of a newly established stream MUST be + // numerically greater than all streams that the initiating + // endpoint has opened or reserved. [...] An endpoint that + // receives an unexpected stream identifier MUST respond with + // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. + if id <= sc.maxClientStreamID { + return sc.countError("stream_went_down", http2ConnectionError(http2ErrCodeProtocol)) + } + sc.maxClientStreamID = id + + if sc.idleTimer != nil { + sc.idleTimer.Stop() + } + + // http://tools.ietf.org/html/rfc7540#section-5.1.2 + // [...] Endpoints MUST NOT exceed the limit set by their peer. An + // endpoint that receives a HEADERS frame that causes their + // advertised concurrent stream limit to be exceeded MUST treat + // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR + // or REFUSED_STREAM. + if sc.curClientStreams+1 > sc.advMaxStreams { + if sc.unackedSettings == 0 { + // They should know better. + return sc.countError("over_max_streams", http2streamError(id, http2ErrCodeProtocol)) + } + // Assume it's a network race, where they just haven't + // received our last SETTINGS update. But actually + // this can't happen yet, because we don't yet provide + // a way for users to adjust server parameters at + // runtime. + return sc.countError("over_max_streams_race", http2streamError(id, http2ErrCodeRefusedStream)) + } + + initialState := http2stateOpen + if f.StreamEnded() { + initialState = http2stateHalfClosedRemote + } + st := sc.newStream(id, 0, initialState) + + if f.HasPriority() { + if err := sc.checkPriority(f.StreamID, f.Priority); err != nil { + return err + } + sc.writeSched.AdjustStream(st.id, f.Priority) + } + + rw, req, err := sc.newWriterAndRequest(st, f) + if err != nil { + return err + } + st.reqTrailer = req.Trailer + if st.reqTrailer != nil { + st.trailer = make(http.Header) + } + st.body = req.Body.(*http2requestBody).pipe // may be nil + st.declBodyBytes = req.ContentLength + + handler := sc.handler.ServeHTTP + if f.Truncated { + // Their header list was too long. Send a 431 error. + handler = http2handleHeaderListTooLong + } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil { + handler = http2new400Handler(err) + } + + // The net/http package sets the read deadline from the + // http.Server.ReadTimeout during the TLS handshake, but then + // passes the connection off to us with the deadline already + // set. Disarm it here after the request headers are read, + // similar to how the http1 server works. Here it's + // technically more like the http1 Server's ReadHeaderTimeout + // (in Go 1.8), though. That's a more sane option anyway. + if sc.hs.ReadTimeout != 0 { + sc.conn.SetReadDeadline(time.Time{}) + } + + go sc.runHandler(rw, req, handler) + return nil +} + +func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { + sc := st.sc + sc.serveG.check() + if st.gotTrailerHeader { + return sc.countError("dup_trailers", http2ConnectionError(http2ErrCodeProtocol)) + } + st.gotTrailerHeader = true + if !f.StreamEnded() { + return sc.countError("trailers_not_ended", http2streamError(st.id, http2ErrCodeProtocol)) + } + + if len(f.PseudoFields()) > 0 { + return sc.countError("trailers_pseudo", http2streamError(st.id, http2ErrCodeProtocol)) + } + if st.trailer != nil { + for _, hf := range f.RegularFields() { + key := sc.canonicalHeader(hf.Name) + if !httpguts.ValidTrailerHeader(key) { + // TODO: send more details to the peer somehow. But http2 has + // no way to send debug data at a stream level. Discuss with + // HTTP folk. + return sc.countError("trailers_bogus", http2streamError(st.id, http2ErrCodeProtocol)) + } + st.trailer[key] = append(st.trailer[key], hf.Value) + } + } + st.endStream() + return nil +} + +func (sc *http2serverConn) checkPriority(streamID uint32, p http2PriorityParam) error { + if streamID == p.StreamDep { + // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat + // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR." + // Section 5.3.3 says that a stream can depend on one of its dependencies, + // so it's only self-dependencies that are forbidden. + return sc.countError("priority", http2streamError(streamID, http2ErrCodeProtocol)) + } + return nil +} + +func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { + if sc.inGoAway { + return nil + } + if err := sc.checkPriority(f.StreamID, f.http2PriorityParam); err != nil { + return err + } + sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam) + return nil +} + +func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream { + sc.serveG.check() + if id == 0 { + panic("internal error: cannot create stream with id 0") + } + + ctx, cancelCtx := context.WithCancel(sc.baseCtx) + st := &http2stream{ + sc: sc, + id: id, + state: state, + ctx: ctx, + cancelCtx: cancelCtx, + } + st.cw.Init() + st.flow.conn = &sc.flow // link to conn-level counter + st.flow.add(sc.initialStreamSendWindowSize) + st.inflow.conn = &sc.inflow // link to conn-level counter + st.inflow.add(sc.srv.initialStreamRecvWindowSize()) + if sc.hs.WriteTimeout != 0 { + st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) + } + + sc.streams[id] = st + sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID}) + if st.isPushed() { + sc.curPushedStreams++ + } else { + sc.curClientStreams++ + } + if sc.curOpenStreams() == 1 { + sc.setConnState(http.StateActive) + } + + return st +} + +func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *http.Request, error) { + sc.serveG.check() + + rp := http2requestParam{ + method: f.PseudoValue("method"), + scheme: f.PseudoValue("scheme"), + authority: f.PseudoValue("authority"), + path: f.PseudoValue("path"), + } + + isConnect := rp.method == "CONNECT" + if isConnect { + if rp.path != "" || rp.scheme != "" || rp.authority == "" { + return nil, nil, sc.countError("bad_connect", http2streamError(f.StreamID, http2ErrCodeProtocol)) + } + } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { + // See 8.1.2.6 Malformed Requests and Responses: + // + // Malformed requests or responses that are detected + // MUST be treated as a stream error (Section 5.4.2) + // of type PROTOCOL_ERROR." + // + // 8.1.2.3 Request Pseudo-Header Fields + // "All HTTP/2 requests MUST include exactly one valid + // value for the :method, :scheme, and :path + // pseudo-header fields" + return nil, nil, sc.countError("bad_path_method", http2streamError(f.StreamID, http2ErrCodeProtocol)) + } + + bodyOpen := !f.StreamEnded() + if rp.method == "HEAD" && bodyOpen { + // HEAD requests can't have bodies + return nil, nil, sc.countError("head_body", http2streamError(f.StreamID, http2ErrCodeProtocol)) + } + + rp.header = make(http.Header) + for _, hf := range f.RegularFields() { + rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) + } + if rp.authority == "" { + rp.authority = rp.header.Get("Host") + } + + rw, req, err := sc.newWriterAndRequestNoBody(st, rp) + if err != nil { + return nil, nil, err + } + if bodyOpen { + if vv, ok := rp.header["Content-Length"]; ok { + if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil { + req.ContentLength = int64(cl) + } else { + req.ContentLength = 0 + } + } else { + req.ContentLength = -1 + } + req.Body.(*http2requestBody).pipe = &http2pipe{ + b: &http2dataBuffer{expected: req.ContentLength}, + } + } + return rw, req, nil +} + +type http2requestParam struct { + method string + scheme, authority, path string + header http.Header +} + +func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *http.Request, error) { + sc.serveG.check() + + var tlsState *tls.ConnectionState // nil if not scheme https + if rp.scheme == "https" { + tlsState = sc.tlsState + } + + needsContinue := rp.header.Get("Expect") == "100-continue" + if needsContinue { + rp.header.Del("Expect") + } + // Merge Cookie headers into one "; "-delimited value. + if cookies := rp.header["Cookie"]; len(cookies) > 1 { + rp.header.Set("Cookie", strings.Join(cookies, "; ")) + } + + // Setup Trailers + var trailer http.Header + for _, v := range rp.header["Trailer"] { + for _, key := range strings.Split(v, ",") { + key = http.CanonicalHeaderKey(textproto.TrimString(key)) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + // Bogus. (copy of http1 rules) + // Ignore. + default: + if trailer == nil { + trailer = make(http.Header) + } + trailer[key] = nil + } + } + } + delete(rp.header, "Trailer") + + var url_ *url.URL + var requestURI string + if rp.method == "CONNECT" { + url_ = &url.URL{Host: rp.authority} + requestURI = rp.authority // mimic HTTP/1 server behavior + } else { + var err error + url_, err = url.ParseRequestURI(rp.path) + if err != nil { + return nil, nil, sc.countError("bad_path", http2streamError(st.id, http2ErrCodeProtocol)) + } + requestURI = rp.path + } + + body := &http2requestBody{ + conn: sc, + stream: st, + needsContinue: needsContinue, + } + req := &http.Request{ + Method: rp.method, + URL: url_, + RemoteAddr: sc.remoteAddrStr, + Header: rp.header, + RequestURI: requestURI, + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + TLS: tlsState, + Host: rp.authority, + Body: body, + Trailer: trailer, + } + req = req.WithContext(st.ctx) + + rws := http2responseWriterStatePool.Get().(*http2responseWriterState) + bwSave := rws.bw + *rws = http2responseWriterState{} // zero all the fields + rws.conn = sc + rws.bw = bwSave + rws.bw.Reset(http2chunkWriter{rws}) + rws.stream = st + rws.req = req + rws.body = body + + rw := &http2responseWriter{rws: rws} + return rw, req, nil +} + +// Run on its own goroutine. +func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { + didPanic := true + defer func() { + rw.rws.stream.cancelCtx() + if didPanic { + e := recover() + sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: http2handlerPanicRST{rw.rws.stream.id}, + stream: rw.rws.stream, + }) + // Same as net/http: + if e != nil && e != http.ErrAbortHandler { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) + } + return + } + rw.handlerDone() + }() + handler(rw, req) + didPanic = false +} + +func http2handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) { + // 10.5.1 Limits on Header Block Size: + // .. "A server that receives a larger header block than it is + // willing to handle can send an HTTP 431 (Request Header Fields Too + // Large) status code" + const statusRequestHeaderFieldsTooLarge = 431 // only in Go 1.6+ + w.WriteHeader(statusRequestHeaderFieldsTooLarge) + io.WriteString(w, "

HTTP Error 431

Request Header Field(s) Too Large

") +} + +// called from handler goroutines. +// h may be nil. +func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeResHeaders) error { + sc.serveG.checkNotOn() // NOT on + var errc chan error + if headerData.h != nil { + // If there's a header map (which we don't own), so we have to block on + // waiting for this frame to be written, so an http.Flush mid-handler + // writes out the correct value of keys, before a handler later potentially + // mutates it. + errc = http2errChanPool.Get().(chan error) + } + if err := sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: headerData, + stream: st, + done: errc, + }); err != nil { + return err + } + if errc != nil { + select { + case err := <-errc: + http2errChanPool.Put(errc) + return err + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + } + } + return nil +} + +// called from handler goroutines. +func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) { + sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: http2write100ContinueHeadersFrame{st.id}, + stream: st, + }) +} + +// A bodyReadMsg tells the server loop that the http.Handler read n +// bytes of the DATA from the client on the given stream. +type http2bodyReadMsg struct { + st *http2stream + n int +} + +// called from handler goroutines. +// Notes that the handler for the given stream ID read n bytes of its body +// and schedules flow control tokens to be sent. +func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) { + sc.serveG.checkNotOn() // NOT on + if n > 0 { + select { + case sc.bodyReadCh <- http2bodyReadMsg{st, n}: + case <-sc.doneServing: + } + } +} + +func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) { + sc.serveG.check() + sc.sendWindowUpdate(nil, n) // conn-level + if st.state != http2stateHalfClosedRemote && st.state != http2stateClosed { + // Don't send this WINDOW_UPDATE if the stream is closed + // remotely. + sc.sendWindowUpdate(st, n) + } +} + +// st may be nil for conn-level +func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) { + sc.serveG.check() + // "The legal range for the increment to the flow control + // window is 1 to 2^31-1 (2,147,483,647) octets." + // A Go Read call on 64-bit machines could in theory read + // a larger Read than this. Very unlikely, but we handle it here + // rather than elsewhere for now. + const maxUint31 = 1<<31 - 1 + for n >= maxUint31 { + sc.sendWindowUpdate32(st, maxUint31) + n -= maxUint31 + } + sc.sendWindowUpdate32(st, int32(n)) +} + +// st may be nil for conn-level +func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { + sc.serveG.check() + if n == 0 { + return + } + if n < 0 { + panic("negative update") + } + var streamID uint32 + if st != nil { + streamID = st.id + } + sc.writeFrame(http2FrameWriteRequest{ + write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)}, + stream: st, + }) + var ok bool + if st == nil { + ok = sc.inflow.add(n) + } else { + ok = st.inflow.add(n) + } + if !ok { + panic("internal error; sent too many window updates without decrements?") + } +} + +// requestBody is the Handler's Request.Body type. +// Read and Close may be called concurrently. +type http2requestBody struct { + _ http2incomparable + stream *http2stream + conn *http2serverConn + closed bool // for use by Close only + sawEOF bool // for use by Read only + pipe *http2pipe // non-nil if we have a HTTP entity message body + needsContinue bool // need to send a 100-continue +} + +func (b *http2requestBody) Close() error { + if b.pipe != nil && !b.closed { + b.pipe.BreakWithError(http2errClosedBody) + } + b.closed = true + return nil +} + +func (b *http2requestBody) Read(p []byte) (n int, err error) { + if b.needsContinue { + b.needsContinue = false + b.conn.write100ContinueHeaders(b.stream) + } + if b.pipe == nil || b.sawEOF { + return 0, io.EOF + } + n, err = b.pipe.Read(p) + if err == io.EOF { + b.sawEOF = true + } + if b.conn == nil && http2inTests { + return + } + b.conn.noteBodyReadFromHandler(b.stream, n, err) + return +} + +// responseWriter is the http.ResponseWriter implementation. It's +// intentionally small (1 pointer wide) to minimize garbage. The +// responseWriterState pointer inside is zeroed at the end of a +// request (in handlerDone) and calls on the responseWriter thereafter +// simply crash (caller's mistake), but the much larger responseWriterState +// and buffers are reused between multiple requests. +type http2responseWriter struct { + rws *http2responseWriterState +} + +// Optional http.ResponseWriter interfaces implemented. +var ( + _ http.CloseNotifier = (*http2responseWriter)(nil) + _ http.Flusher = (*http2responseWriter)(nil) + _ http2stringWriter = (*http2responseWriter)(nil) +) + +type http2responseWriterState struct { + // immutable within a request: + stream *http2stream + req *http.Request + body *http2requestBody // to close at end of request, if DATA frames didn't + conn *http2serverConn + + // TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc + bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState} + + // mutated by http.Handler goroutine: + handlerHeader http.Header // nil until called + snapHeader http.Header // snapshot of handlerHeader at WriteHeader time + trailers []string // set in writeChunk + status int // status code passed to WriteHeader + wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. + sentHeader bool // have we sent the header frame? + handlerDone bool // handler has finished + dirty bool // a Write failed; don't reuse this responseWriterState + + sentContentLen int64 // non-zero if handler set a Content-Length header + wroteBytes int64 + + closeNotifierMu sync.Mutex // guards closeNotifierCh + closeNotifierCh chan bool // nil until first used +} + +type http2chunkWriter struct{ rws *http2responseWriterState } + +func (cw http2chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } + +func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 } + +func (rws *http2responseWriterState) hasNonemptyTrailers() bool { + for _, trailer := range rws.trailers { + if _, ok := rws.handlerHeader[trailer]; ok { + return true + } + } + return false +} + +// declareTrailer is called for each Trailer header when the +// response header is written. It notes that a header will need to be +// written in the trailers at the end of the response. +func (rws *http2responseWriterState) declareTrailer(k string) { + k = http.CanonicalHeaderKey(k) + if !httpguts.ValidTrailerHeader(k) { + // Forbidden by RFC 7230, section 4.1.2. + rws.conn.logf("ignoring invalid trailer %q", k) + return + } + if !http2strSliceContains(rws.trailers, k) { + rws.trailers = append(rws.trailers, k) + } +} + +// writeChunk writes chunks from the bufio.Writer. But because +// bufio.Writer may bypass its chunking, sometimes p may be +// arbitrarily large. +// +// writeChunk is also responsible (on the first chunk) for sending the +// HEADER response. +func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { + if !rws.wroteHeader { + rws.writeHeader(200) + } + + isHeadResp := rws.req.Method == "HEAD" + if !rws.sentHeader { + rws.sentHeader = true + var ctype, clen string + if clen = rws.snapHeader.Get("Content-Length"); clen != "" { + rws.snapHeader.Del("Content-Length") + if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { + rws.sentContentLen = int64(cl) + } else { + clen = "" + } + } + if clen == "" && rws.handlerDone && http2bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) { + clen = strconv.Itoa(len(p)) + } + _, hasContentType := rws.snapHeader["Content-Type"] + // If the Content-Encoding is non-blank, we shouldn't + // sniff the body. See Issue golang.org/issue/31753. + ce := rws.snapHeader.Get("Content-Encoding") + hasCE := len(ce) > 0 + if !hasCE && !hasContentType && http2bodyAllowedForStatus(rws.status) && len(p) > 0 { + ctype = http.DetectContentType(p) + } + var date string + if _, ok := rws.snapHeader["Date"]; !ok { + // TODO(bradfitz): be faster here, like net/http? measure. + date = time.Now().UTC().Format(http.TimeFormat) + } + + for _, v := range rws.snapHeader["Trailer"] { + http2foreachHeaderElement(v, rws.declareTrailer) + } + + // "Connection" headers aren't allowed in HTTP/2 (RFC 7540, 8.1.2.2), + // but respect "Connection" == "close" to mean sending a GOAWAY and tearing + // down the TCP connection when idle, like we do for HTTP/1. + // TODO: remove more Connection-specific header fields here, in addition + // to "Connection". + if _, ok := rws.snapHeader["Connection"]; ok { + v := rws.snapHeader.Get("Connection") + delete(rws.snapHeader, "Connection") + if v == "close" { + rws.conn.startGracefulShutdown() + } + } + + endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp + err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ + streamID: rws.stream.id, + httpResCode: rws.status, + h: rws.snapHeader, + endStream: endStream, + contentType: ctype, + contentLength: clen, + date: date, + }) + if err != nil { + rws.dirty = true + return 0, err + } + if endStream { + return 0, nil + } + } + if isHeadResp { + return len(p), nil + } + if len(p) == 0 && !rws.handlerDone { + return 0, nil + } + + if rws.handlerDone { + rws.promoteUndeclaredTrailers() + } + + // only send trailers if they have actually been defined by the + // server handler. + hasNonemptyTrailers := rws.hasNonemptyTrailers() + endStream := rws.handlerDone && !hasNonemptyTrailers + if len(p) > 0 || endStream { + // only send a 0 byte DATA frame if we're ending the stream. + if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil { + rws.dirty = true + return 0, err + } + } + + if rws.handlerDone && hasNonemptyTrailers { + err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ + streamID: rws.stream.id, + h: rws.handlerHeader, + trailers: rws.trailers, + endStream: true, + }) + if err != nil { + rws.dirty = true + } + return len(p), err + } + return len(p), nil +} + +// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys +// that, if present, signals that the map entry is actually for +// the response trailers, and not the response headers. The prefix +// is stripped after the ServeHTTP call finishes and the values are +// sent in the trailers. +// +// This mechanism is intended only for trailers that are not known +// prior to the headers being written. If the set of trailers is fixed +// or known before the header is written, the normal Go trailers mechanism +// is preferred: +// https://golang.org/pkg/net/http/#ResponseWriter +// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +const http2TrailerPrefix = "Trailer:" + +// promoteUndeclaredTrailers permits http.Handlers to set trailers +// after the header has already been flushed. Because the Go +// ResponseWriter interface has no way to set Trailers (only the +// Header), and because we didn't want to expand the ResponseWriter +// interface, and because nobody used trailers, and because RFC 7230 +// says you SHOULD (but not must) predeclare any trailers in the +// header, the official ResponseWriter rules said trailers in Go must +// be predeclared, and then we reuse the same ResponseWriter.Header() +// map to mean both Headers and Trailers. When it's time to write the +// Trailers, we pick out the fields of Headers that were declared as +// trailers. That worked for a while, until we found the first major +// user of Trailers in the wild: gRPC (using them only over http2), +// and gRPC libraries permit setting trailers mid-stream without +// predeclaring them. So: change of plans. We still permit the old +// way, but we also permit this hack: if a Header() key begins with +// "Trailer:", the suffix of that key is a Trailer. Because ':' is an +// invalid token byte anyway, there is no ambiguity. (And it's already +// filtered out) It's mildly hacky, but not terrible. +// +// This method runs after the Handler is done and promotes any Header +// fields to be trailers. +func (rws *http2responseWriterState) promoteUndeclaredTrailers() { + for k, vv := range rws.handlerHeader { + if !strings.HasPrefix(k, http2TrailerPrefix) { + continue + } + trailerKey := strings.TrimPrefix(k, http2TrailerPrefix) + rws.declareTrailer(trailerKey) + rws.handlerHeader[http.CanonicalHeaderKey(trailerKey)] = vv + } + + if len(rws.trailers) > 1 { + sorter := http2sorterPool.Get().(*http2sorter) + sorter.SortStrings(rws.trailers) + http2sorterPool.Put(sorter) + } +} + +func (w *http2responseWriter) Flush() { + rws := w.rws + if rws == nil { + panic("Header called after Handler finished") + } + if rws.bw.Buffered() > 0 { + if err := rws.bw.Flush(); err != nil { + // Ignore the error. The frame writer already knows. + return + } + } else { + // The bufio.Writer won't call chunkWriter.Write + // (writeChunk with zero bytes, so we have to do it + // ourselves to force the HTTP response header and/or + // final DATA frame (with END_STREAM) to be sent. + rws.writeChunk(nil) + } +} + +func (w *http2responseWriter) CloseNotify() <-chan bool { + rws := w.rws + if rws == nil { + panic("CloseNotify called after Handler finished") + } + rws.closeNotifierMu.Lock() + ch := rws.closeNotifierCh + if ch == nil { + ch = make(chan bool, 1) + rws.closeNotifierCh = ch + cw := rws.stream.cw + go func() { + cw.Wait() // wait for close + ch <- true + }() + } + rws.closeNotifierMu.Unlock() + return ch +} + +func (w *http2responseWriter) Header() http.Header { + rws := w.rws + if rws == nil { + panic("Header called after Handler finished") + } + if rws.handlerHeader == nil { + rws.handlerHeader = make(http.Header) + } + return rws.handlerHeader +} + +// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode. +func http2checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at http://httpwg.org/specs/rfc7231.html#status.codes) + // and we might block under 200 (once we have more mature 1xx support). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/2, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + +func (w *http2responseWriter) WriteHeader(code int) { + rws := w.rws + if rws == nil { + panic("WriteHeader called after Handler finished") + } + rws.writeHeader(code) +} + +func (rws *http2responseWriterState) writeHeader(code int) { + if !rws.wroteHeader { + http2checkWriteHeaderCode(code) + rws.wroteHeader = true + rws.status = code + if len(rws.handlerHeader) > 0 { + rws.snapHeader = http2cloneHeader(rws.handlerHeader) + } + } +} + +func http2cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +// The Life Of A Write is like this: +// +// * Handler calls w.Write or w.WriteString -> +// * -> rws.bw (*bufio.Writer) -> +// * (Handler might call Flush) +// * -> chunkWriter{rws} +// * -> responseWriterState.writeChunk(p []byte) +// * -> responseWriterState.writeChunk (most of the magic; see comment there) +func (w *http2responseWriter) Write(p []byte) (n int, err error) { + return w.write(len(p), p, "") +} + +func (w *http2responseWriter) WriteString(s string) (n int, err error) { + return w.write(len(s), nil, s) +} + +// either dataB or dataS is non-zero. +func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) { + rws := w.rws + if rws == nil { + panic("Write called after Handler finished") + } + if !rws.wroteHeader { + w.WriteHeader(200) + } + if !http2bodyAllowedForStatus(rws.status) { + return 0, http.ErrBodyNotAllowed + } + rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) // only one can be set + if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen { + // TODO: send a RST_STREAM + return 0, errors.New("http2: handler wrote more than declared Content-Length") + } + + if dataB != nil { + return rws.bw.Write(dataB) + } else { + return rws.bw.WriteString(dataS) + } +} + +func (w *http2responseWriter) handlerDone() { + rws := w.rws + dirty := rws.dirty + rws.handlerDone = true + w.Flush() + w.rws = nil + if !dirty { + // Only recycle the pool if all prior Write calls to + // the serverConn goroutine completed successfully. If + // they returned earlier due to resets from the peer + // there might still be write goroutines outstanding + // from the serverConn referencing the rws memory. See + // issue 20704. + http2responseWriterStatePool.Put(rws) + } +} + +// Push errors. +var ( + http2ErrRecursivePush = errors.New("http2: recursive push not allowed") + http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") +) + +var _ http.Pusher = (*http2responseWriter)(nil) + +func (w *http2responseWriter) Push(target string, opts *http.PushOptions) error { + st := w.rws.stream + sc := st.sc + sc.serveG.checkNotOn() + + // No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream." + // http://tools.ietf.org/html/rfc7540#section-6.6 + if st.isPushed() { + return http2ErrRecursivePush + } + + if opts == nil { + opts = new(http.PushOptions) + } + + // Default options. + if opts.Method == "" { + opts.Method = "GET" + } + if opts.Header == nil { + opts.Header = http.Header{} + } + wantScheme := "http" + if w.rws.req.TLS != nil { + wantScheme = "https" + } + + // Validate the request. + u, err := url.Parse(target) + if err != nil { + return err + } + if u.Scheme == "" { + if !strings.HasPrefix(target, "/") { + return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target) + } + u.Scheme = wantScheme + u.Host = w.rws.req.Host + } else { + if u.Scheme != wantScheme { + return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme) + } + if u.Host == "" { + return errors.New("URL must have a host") + } + } + for k := range opts.Header { + if strings.HasPrefix(k, ":") { + return fmt.Errorf("promised request headers cannot include pseudo header %q", k) + } + // These headers are meaningful only if the request has a body, + // but PUSH_PROMISE requests cannot have a body. + // http://tools.ietf.org/html/rfc7540#section-8.2 + // Also disallow Host, since the promised URL must be absolute. + if http2asciiEqualFold(k, "content-length") || + http2asciiEqualFold(k, "content-encoding") || + http2asciiEqualFold(k, "trailer") || + http2asciiEqualFold(k, "te") || + http2asciiEqualFold(k, "expect") || + http2asciiEqualFold(k, "host") { + return fmt.Errorf("promised request headers cannot include %q", k) + } + } + if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil { + return err + } + + // The RFC effectively limits promised requests to GET and HEAD: + // "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]" + // http://tools.ietf.org/html/rfc7540#section-8.2 + if opts.Method != "GET" && opts.Method != "HEAD" { + return fmt.Errorf("method %q must be GET or HEAD", opts.Method) + } + + msg := &http2startPushRequest{ + parent: st, + method: opts.Method, + url: u, + header: http2cloneHeader(opts.Header), + done: http2errChanPool.Get().(chan error), + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case sc.serveMsgCh <- msg: + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case err := <-msg.done: + http2errChanPool.Put(msg.done) + return err + } +} + +type http2startPushRequest struct { + parent *http2stream + method string + url *url.URL + header http.Header + done chan error +} + +func (sc *http2serverConn) startPush(msg *http2startPushRequest) { + sc.serveG.check() + + // http://tools.ietf.org/html/rfc7540#section-6.6. + // PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that + // is in either the "open" or "half-closed (remote)" state. + if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote { + // responseWriter.Push checks that the stream is peer-initiated. + msg.done <- http2errStreamClosed + return + } + + // http://tools.ietf.org/html/rfc7540#section-6.6. + if !sc.pushEnabled { + msg.done <- http.ErrNotSupported + return + } + + // PUSH_PROMISE frames must be sent in increasing order by stream ID, so + // we allocate an ID for the promised stream lazily, when the PUSH_PROMISE + // is written. Once the ID is allocated, we start the request handler. + allocatePromisedID := func() (uint32, error) { + sc.serveG.check() + + // Check this again, just in case. Technically, we might have received + // an updated SETTINGS by the time we got around to writing this frame. + if !sc.pushEnabled { + return 0, http.ErrNotSupported + } + // http://tools.ietf.org/html/rfc7540#section-6.5.2. + if sc.curPushedStreams+1 > sc.clientMaxStreams { + return 0, http2ErrPushLimitReached + } + + // http://tools.ietf.org/html/rfc7540#section-5.1.1. + // Streams initiated by the server MUST use even-numbered identifiers. + // A server that is unable to establish a new stream identifier can send a GOAWAY + // frame so that the client is forced to open a new connection for new streams. + if sc.maxPushPromiseID+2 >= 1<<31 { + sc.startGracefulShutdownInternal() + return 0, http2ErrPushLimitReached + } + sc.maxPushPromiseID += 2 + promisedID := sc.maxPushPromiseID + + // http://tools.ietf.org/html/rfc7540#section-8.2. + // Strictly speaking, the new stream should start in "reserved (local)", then + // transition to "half closed (remote)" after sending the initial HEADERS, but + // we start in "half closed (remote)" for simplicity. + // See further comments at the definition of stateHalfClosedRemote. + promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote) + rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{ + method: msg.method, + scheme: msg.url.Scheme, + authority: msg.url.Host, + path: msg.url.RequestURI(), + header: http2cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE + }) + if err != nil { + // Should not happen, since we've already validated msg.url. + panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) + } + + go sc.runHandler(rw, req, sc.handler.ServeHTTP) + return promisedID, nil + } + + sc.writeFrame(http2FrameWriteRequest{ + write: &http2writePushPromise{ + streamID: msg.parent.id, + method: msg.method, + url: msg.url, + h: msg.header, + allocatePromisedID: allocatePromisedID, + }, + stream: msg.parent, + done: msg.done, + }) +} + +// foreachHeaderElement splits v according to the "#rule" construction +// in RFC 7230 section 7 and calls fn for each non-empty element. +func http2foreachHeaderElement(v string, fn func(string)) { + v = textproto.TrimString(v) + if v == "" { + return + } + if !strings.Contains(v, ",") { + fn(v) + return + } + for _, f := range strings.Split(v, ",") { + if f = textproto.TrimString(f); f != "" { + fn(f) + } + } +} + +// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 +var http2connHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Connection", + "Transfer-Encoding", + "Upgrade", +} + +// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, +// per RFC 7540 Section 8.1.2.2. +// The returned error is reported to users. +func http2checkValidHTTP2RequestHeaders(h http.Header) error { + for _, k := range http2connHeaders { + if _, ok := h[k]; ok { + return fmt.Errorf("request header %q is not valid in HTTP/2", k) + } + } + te := h["Te"] + if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { + return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) + } + return nil +} + +func http2new400Handler(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + http.Error(w, err.Error(), http.StatusBadRequest) + } +} + +// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives +// disabled. See comments on h1ServerShutdownChan above for why +// the code is written this way. +func http2h1ServerKeepAlivesDisabled(hs *http.Server) bool { + var x interface{} = hs + type I interface { + doKeepAlives() bool + } + if hs, ok := x.(I); ok { + return !hs.doKeepAlives() + } + return false +} + +func (sc *http2serverConn) countError(name string, err error) error { + if sc == nil || sc.srv == nil { + return err + } + f := sc.srv.CountError + if f == nil { + return err + } + var typ string + var code http2ErrCode + switch e := err.(type) { + case http2ConnectionError: + typ = "conn" + code = http2ErrCode(e) + case http2StreamError: + typ = "stream" + code = http2ErrCode(e.Code) + default: + return err + } + codeStr := http2errCodeName[code] + if codeStr == "" { + codeStr = strconv.Itoa(int(code)) + } + f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name)) + return err +} + +const ( + // transportDefaultConnFlow is how many connection-level flow control + // tokens we give the server at start-up, past the default 64k. + http2transportDefaultConnFlow = 1 << 30 + + // transportDefaultStreamFlow is how many stream-level flow + // control tokens we announce to the peer, and how many bytes + // we buffer per stream. + http2transportDefaultStreamFlow = 4 << 20 + + // transportDefaultStreamMinRefresh is the minimum number of bytes we'll send + // a stream-level WINDOW_UPDATE for at a time. + http2transportDefaultStreamMinRefresh = 4 << 10 + + http2defaultUserAgent = "Go-http-client/2.0" + + // initialMaxConcurrentStreams is a connections maxConcurrentStreams until + // it's received servers initial SETTINGS frame, which corresponds with the + // spec's minimum recommended value. + http2initialMaxConcurrentStreams = 100 + + // defaultMaxConcurrentStreams is a connections default maxConcurrentStreams + // if the server doesn't include one in its initial SETTINGS frame. + http2defaultMaxConcurrentStreams = 1000 +) + +// Transport is an HTTP/2 Transport. +// +// A Transport internally caches connections to servers. It is safe +// for concurrent use by multiple goroutines. +type http2Transport struct { + // DialTLS specifies an optional dial function for creating + // TLS connections for requests. + // + // If DialTLS is nil, tls.Dial is used. + // + // If the returned net.Conn has a ConnectionState method like tls.Conn, + // it will be used to set http.Response.TLS. + DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error) + + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + + // ConnPool optionally specifies an alternate connection pool to use. + // If nil, the default is used. + ConnPool http2ClientConnPool + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool + + // AllowHTTP, if true, permits HTTP/2 requests using the insecure, + // plain-text "http" scheme. Note that this does not enable h2c support. + AllowHTTP bool + + // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to + // send in the initial settings frame. It is how many bytes + // of response headers are allowed. Unlike the http2 spec, zero here + // means to use a default limit (currently 10MB). If you actually + // want to advertise an unlimited value to the peer, Transport + // interprets the highest possible value here (0xffffffff or 1<<32-1) + // to mean no limit. + MaxHeaderListSize uint32 + + // StrictMaxConcurrentStreams controls whether the server's + // SETTINGS_MAX_CONCURRENT_STREAMS should be respected + // globally. If false, new TCP connections are created to the + // server as needed to keep each under the per-connection + // SETTINGS_MAX_CONCURRENT_STREAMS limit. If true, the + // server's SETTINGS_MAX_CONCURRENT_STREAMS is interpreted as + // a global limit and callers of RoundTrip block when needed, + // waiting for their turn. + StrictMaxConcurrentStreams bool + + // ReadIdleTimeout is the timeout after which a health check using ping + // frame will be carried out if no frame is received on the connection. + // Note that a ping response will is considered a received frame, so if + // there is no other traffic on the connection, the health check will + // be performed every ReadIdleTimeout interval. + // If zero, no health check is performed. + ReadIdleTimeout time.Duration + + // PingTimeout is the timeout after which the connection will be closed + // if a response to Ping is not received. + // Defaults to 15s. + PingTimeout time.Duration + + // WriteByteTimeout is the timeout after which the connection will be + // closed no data can be written to it. The timeout begins when data is + // available to write, and is extended whenever any bytes are written. + WriteByteTimeout time.Duration + + // CountError, if non-nil, is called on HTTP/2 transport errors. + // It's intended to increment a metric for monitoring, such + // as an expvar or Prometheus metric. + // The errType consists of only ASCII word characters. + CountError func(errType string) + + // t1, if non-nil, is the standard library Transport using + // this transport. Its settings are used (but not its + // RoundTrip method, etc). + t1 *Transport + + connPoolOnce sync.Once + connPoolOrDef http2ClientConnPool // non-nil version of ConnPool +} + +func (t *http2Transport) maxHeaderListSize() uint32 { + if t.MaxHeaderListSize == 0 { + return 10 << 20 + } + if t.MaxHeaderListSize == 0xffffffff { + return 0 + } + return t.MaxHeaderListSize +} + +func (t *http2Transport) disableCompression() bool { + return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) +} + +func (t *http2Transport) pingTimeout() time.Duration { + if t.PingTimeout == 0 { + return 15 * time.Second + } + return t.PingTimeout + +} + +// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. +// It returns an error if t1 has already been HTTP/2-enabled. +// +// Use ConfigureTransports instead to configure the HTTP/2 Transport. +func http2ConfigureTransport(t1 *Transport) error { + _, err := http2ConfigureTransports(t1) + return err +} + +// ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2. +// It returns a new HTTP/2 Transport for further configuration. +// It returns an error if t1 has already been HTTP/2-enabled. +func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { + connPool := new(http2clientConnPool) + t2 := &http2Transport{ + ConnPool: http2noDialClientConnPool{connPool}, + t1: t1, + } + connPool.t = t2 + if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil { + return nil, err + } + if t1.TLSClientConfig == nil { + t1.TLSClientConfig = new(tls.Config) + } + if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { + t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) + } + if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { + t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") + } + upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper { + addr := http2authorityAddr("https", authority) + if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { + go c.Close() + return http2erringRoundTripper{err} + } else if !used { + // Turns out we don't need this c. + // For example, two goroutines made requests to the same host + // at the same time, both kicking off TCP dials. (since protocol + // was unknown) + go c.Close() + } + return t2 + } + if m := t1.TLSNextProto; len(m) == 0 { + t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{ + "h2": upgradeFn, + } + } else { + m["h2"] = upgradeFn + } + return t2, nil +} + +func (t *http2Transport) connPool() http2ClientConnPool { + t.connPoolOnce.Do(t.initConnPool) + return t.connPoolOrDef +} + +func (t *http2Transport) initConnPool() { + if t.ConnPool != nil { + t.connPoolOrDef = t.ConnPool + } else { + t.connPoolOrDef = &http2clientConnPool{t: t} + } +} + +// ClientConn is the state of a single HTTP/2 client connection to an +// HTTP/2 server. +type http2ClientConn struct { + writeHeader func(name, value string) + t *http2Transport + tconn net.Conn // usually *tls.Conn, except specialized impls + tlsState *tls.ConnectionState // nil only for specialized impls + reused uint32 // whether conn is being reused; atomic + singleUse bool // whether being used for a single http.Request + getConnCalled bool // used by clientConnPool + + // readLoop goroutine fields: + readerDone chan struct{} // closed on error + readerErr error // set before readerDone is closed + + idleTimeout time.Duration // or 0 for never + idleTimer *time.Timer + + mu sync.Mutex // guards following + cond *sync.Cond // hold mu; broadcast on flow/closed changes + flow http2flow // our conn-level flow control quota (cs.flow is per stream) + inflow http2flow // peer's conn-level flow control + doNotReuse bool // whether conn is marked to not be reused for any future requests + closing bool + closed bool + seenSettings bool // true if we've seen a settings frame, false otherwise + wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back + goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received + goAwayDebug string // goAway frame's debug data, retained as a string + streams map[uint32]*http2clientStream // client-initiated + streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip + nextStreamID uint32 + pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams + pings map[[8]byte]chan struct{} // in flight ping data to notification channel + br *bufio.Reader + lastActive time.Time + lastIdle time.Time // time last idle + // Settings from peer: (also guarded by wmu) + maxFrameSize uint32 + maxConcurrentStreams uint32 + peerMaxHeaderListSize uint64 + initialWindowSize uint32 + + // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests. + // Write to reqHeaderMu to lock it, read from it to unlock. + // Lock reqmu BEFORE mu or wmu. + reqHeaderMu chan struct{} + + // wmu is held while writing. + // Acquire BEFORE mu when holding both, to avoid blocking mu on network writes. + // Only acquire both at the same time when changing peer settings. + wmu sync.Mutex + bw *bufio.Writer + fr *http2Framer + werr error // first write error that has occurred + hbuf bytes.Buffer // HPACK encoder writes into this + henc *hpack.Encoder +} + +// clientStream is the state for a single HTTP/2 stream. One of these +// is created for each Transport.RoundTrip call. +type http2clientStream struct { + cc *http2ClientConn + + // Fields of Request that we may access even after the response body is closed. + ctx context.Context + reqCancel <-chan struct{} + + trace *httptrace.ClientTrace // or nil + ID uint32 + bufPipe http2pipe // buffered pipe with the flow-controlled response payload + requestedGzip bool + isHead bool + + abortOnce sync.Once + abort chan struct{} // closed to signal stream should end immediately + abortErr error // set if abort is closed + + peerClosed chan struct{} // closed when the peer sends an END_STREAM flag + donec chan struct{} // closed after the stream is in the closed state + on100 chan struct{} // buffered; written to if a 100 is received + + respHeaderRecv chan struct{} // closed when headers are received + res *http.Response // set if respHeaderRecv is closed + + flow http2flow // guarded by cc.mu + inflow http2flow // guarded by cc.mu + bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read + readErr error // sticky read error; owned by transportResponseBody.Read + + reqBody io.ReadCloser + reqBodyContentLength int64 // -1 means unknown + reqBodyClosed bool // body has been closed; guarded by cc.mu + + // owned by writeRequest: + sentEndStream bool // sent an END_STREAM flag to the peer + sentHeaders bool + + // owned by clientConnReadLoop: + firstByte bool // got the first response byte + pastHeaders bool // got first MetaHeadersFrame (actual headers) + pastTrailers bool // got optional second MetaHeadersFrame (trailers) + num1xx uint8 // number of 1xx responses seen + readClosed bool // peer sent an END_STREAM flag + readAborted bool // read loop reset the stream + + trailer http.Header // accumulated trailers + resTrailer *http.Header // client's Response.Trailer +} + +var http2got1xxFuncForTests func(int, textproto.MIMEHeader) error + +// get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func, +// if any. It returns nil if not set or if the Go version is too old. +func (cs *http2clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error { + if fn := http2got1xxFuncForTests; fn != nil { + return fn + } + return http2traceGot1xxResponseFunc(cs.trace) +} + +func (cs *http2clientStream) abortStream(err error) { + cs.cc.mu.Lock() + defer cs.cc.mu.Unlock() + cs.abortStreamLocked(err) +} + +func (cs *http2clientStream) abortStreamLocked(err error) { + cs.abortOnce.Do(func() { + cs.abortErr = err + close(cs.abort) + }) + if cs.reqBody != nil && !cs.reqBodyClosed { + cs.reqBody.Close() + cs.reqBodyClosed = true + } + // TODO(dneil): Clean up tests where cs.cc.cond is nil. + if cs.cc.cond != nil { + // Wake up writeRequestBody if it is waiting on flow control. + cs.cc.cond.Broadcast() + } +} + +func (cs *http2clientStream) abortRequestBodyWrite() { + cc := cs.cc + cc.mu.Lock() + defer cc.mu.Unlock() + if cs.reqBody != nil && !cs.reqBodyClosed { + cs.reqBody.Close() + cs.reqBodyClosed = true + cc.cond.Broadcast() + } +} + +type http2stickyErrWriter struct { + conn net.Conn + timeout time.Duration + err *error +} + +func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) { + if *sew.err != nil { + return 0, *sew.err + } + for { + if sew.timeout != 0 { + sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout)) + } + nn, err := sew.conn.Write(p[n:]) + n += nn + if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) { + // Keep extending the deadline so long as we're making progress. + continue + } + if sew.timeout != 0 { + sew.conn.SetWriteDeadline(time.Time{}) + } + *sew.err = err + return n, err + } +} + +// noCachedConnError is the concrete type of ErrNoCachedConn, which +// needs to be detected by net/http regardless of whether it's its +// bundled version (in h2_bundle.go with a rewritten type name) or +// from a user's x/net/http2. As such, as it has a unique method name +// (IsHTTP2NoCachedConnError) that net/http sniffs for via func +// isNoCachedConnError. +type http2noCachedConnError struct{} + +func (http2noCachedConnError) IsHTTP2NoCachedConnError() {} + +func (http2noCachedConnError) Error() string { return "http2: no cached connection was available" } + +// isNoCachedConnError reports whether err is of type noCachedConnError +// or its equivalent renamed type in net/http2's h2_bundle.go. Both types +// may coexist in the same running program. +func http2isNoCachedConnError(err error) bool { + _, ok := err.(interface{ IsHTTP2NoCachedConnError() }) + return ok +} + +var http2ErrNoCachedConn error = http2noCachedConnError{} + +// RoundTripOpt are options for the Transport.RoundTripOpt method. +type http2RoundTripOpt struct { + // OnlyCachedConn controls whether RoundTripOpt may + // create a new TCP connection. If set true and + // no cached connection is available, RoundTripOpt + // will return ErrNoCachedConn. + OnlyCachedConn bool +} + +func (t *http2Transport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripOpt(req, http2RoundTripOpt{}) +} + +// authorityAddr returns a given authority (a host/IP, or host:port / ip:port) +// and returns a host:port. The port 443 is added if needed. +func http2authorityAddr(scheme string, authority string) (addr string) { + host, port, err := net.SplitHostPort(authority) + if err != nil { // authority didn't have a port + port = "443" + if scheme == "http" { + port = "80" + } + host = authority + } + if a, err := idna.ToASCII(host); err == nil { + host = a + } + // IPv6 address literal, without a port: + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port + } + return net.JoinHostPort(host, port) +} + +// RoundTripOpt is like RoundTrip, but takes options. +func (t *http2Transport) RoundTripOpt(req *http.Request, opt http2RoundTripOpt) (*http.Response, error) { + if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { + return nil, errors.New("http2: unsupported scheme") + } + + addr := http2authorityAddr(req.URL.Scheme, req.URL.Host) + for retry := 0; ; retry++ { + cc, err := t.connPool().GetClientConn(req, addr) + if err != nil { + t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) + return nil, err + } + reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1) + http2traceGotConn(req, cc, reused) + res, err := cc.RoundTrip(req) + if err != nil && retry <= 6 { + if req, err = http2shouldRetryRequest(req, err); err == nil { + // After the first retry, do exponential backoff with 10% jitter. + if retry == 0 { + continue + } + backoff := float64(uint(1) << (uint(retry) - 1)) + backoff += backoff * (0.1 * mathrand.Float64()) + select { + case <-time.After(time.Second * time.Duration(backoff)): + continue + case <-req.Context().Done(): + err = req.Context().Err() + } + } + } + if err != nil { + t.vlogf("RoundTrip failure: %v", err) + return nil, err + } + return res, nil + } +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle. +// It does not interrupt any connections currently in use. +func (t *http2Transport) CloseIdleConnections() { + if cp, ok := t.connPool().(http2clientConnPoolIdleCloser); ok { + cp.closeIdleConnections() + } +} + +var ( + http2errClientConnClosed = errors.New("http2: client conn is closed") + http2errClientConnUnusable = errors.New("http2: client conn not usable") + http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") +) + +// shouldRetryRequest is called by RoundTrip when a request fails to get +// response headers. It is always called with a non-nil error. +// It returns either a request to retry (either the same request, or a +// modified clone), or an error if the request can't be replayed. +func http2shouldRetryRequest(req *http.Request, err error) (*http.Request, error) { + if !http2canRetryError(err) { + return nil, err + } + // If the Body is nil (or http.NoBody), it's safe to reuse + // this request and its Body. + if req.Body == nil || req.Body == http.NoBody { + return req, nil + } + + // If the request body can be reset back to its original + // state via the optional req.GetBody, do that. + if req.GetBody != nil { + body, err := req.GetBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = body + return &newReq, nil + } + + // The Request.Body can't reset back to the beginning, but we + // don't seem to have started to read from it yet, so reuse + // the request directly. + if err == http2errClientConnUnusable { + return req, nil + } + + return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err) +} + +func http2canRetryError(err error) bool { + if err == http2errClientConnUnusable || err == http2errClientConnGotGoAway { + return true + } + if se, ok := err.(http2StreamError); ok { + if se.Code == http2ErrCodeProtocol && se.Cause == http2errFromPeer { + // See golang/go#47635, golang/go#42777 + return true + } + return se.Code == http2ErrCodeRefusedStream + } + return false +} + +func (t *http2Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*http2ClientConn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host)) + if err != nil { + return nil, err + } + return t.newClientConn(tconn, singleUse) +} + +func (t *http2Transport) newTLSConfig(host string) *tls.Config { + cfg := new(tls.Config) + if t.TLSClientConfig != nil { + *cfg = *t.TLSClientConfig.Clone() + } + if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { + cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) + } + if cfg.ServerName == "" { + cfg.ServerName = host + } + return cfg +} + +func (t *http2Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) { + if t.DialTLS != nil { + return t.DialTLS + } + return func(network, addr string, cfg *tls.Config) (net.Conn, error) { + tlsCn, err := t.dialTLSWithContext(ctx, network, addr, cfg) + if err != nil { + return nil, err + } + state := tlsCn.ConnectionState() + if p := state.NegotiatedProtocol; p != http2NextProtoTLS { + return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS) + } + if !state.NegotiatedProtocolIsMutual { + return nil, errors.New("http2: could not negotiate protocol mutually") + } + return tlsCn, nil + } +} + +// disableKeepAlives reports whether connections should be closed as +// soon as possible after handling the first request. +func (t *http2Transport) disableKeepAlives() bool { + return t.t1 != nil && t.t1.DisableKeepAlives +} + +func (t *http2Transport) expectContinueTimeout() time.Duration { + if t.t1 == nil { + return 0 + } + return t.t1.ExpectContinueTimeout +} + +func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { + return t.newClientConn(c, t.disableKeepAlives()) +} + +func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) { + cc := &http2ClientConn{ + t: t, + tconn: c, + readerDone: make(chan struct{}), + nextStreamID: 1, + maxFrameSize: 16 << 10, // spec default + initialWindowSize: 65535, // spec default + maxConcurrentStreams: http2initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings. + peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. + streams: make(map[uint32]*http2clientStream), + singleUse: singleUse, + wantSettingsAck: true, + pings: make(map[[8]byte]chan struct{}), + reqHeaderMu: make(chan struct{}, 1), + } + if t.t1.dump != nil && t.t1.dump.RequestHead { + cc.writeHeader = func(name, value string) { + t.t1.dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + cc._writeHeader(name, value) + } + } else { + cc.writeHeader = cc._writeHeader + } + if d := t.idleConnTimeout(); d != 0 { + cc.idleTimeout = d + cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) + } + if http2VerboseLogs { + t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) + } + + cc.cond = sync.NewCond(&cc.mu) + cc.flow.add(int32(http2initialWindowSize)) + + // TODO: adjust this writer size to account for frame size + + // MTU + crypto/tls record padding. + cc.bw = bufio.NewWriter(http2stickyErrWriter{ + conn: c, + timeout: t.WriteByteTimeout, + err: &cc.werr, + }) + cc.br = bufio.NewReader(c) + cc.fr = http2NewFramer(cc.bw, cc.br, cc.t.t1.dump) + if t.CountError != nil { + cc.fr.countError = t.CountError + } + cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + cc.fr.MaxHeaderListSize = t.maxHeaderListSize() + + // TODO: SetMaxDynamicTableSize, SetMaxDynamicTableSizeLimit on + // henc in response to SETTINGS frames? + cc.henc = hpack.NewEncoder(&cc.hbuf) + + if t.AllowHTTP { + cc.nextStreamID = 3 + } + + if cs, ok := c.(http2connectionStater); ok { + state := cs.ConnectionState() + cc.tlsState = &state + } + + initialSettings := []http2Setting{ + {ID: http2SettingEnablePush, Val: 0}, + {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, + } + if max := t.maxHeaderListSize(); max != 0 { + initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) + } + + cc.bw.Write(http2clientPreface) + cc.fr.WriteSettings(initialSettings...) + cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow) + cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize) + cc.bw.Flush() + if cc.werr != nil { + cc.Close() + return nil, cc.werr + } + + go cc.readLoop() + return cc, nil +} + +func (cc *http2ClientConn) healthCheck() { + pingTimeout := cc.t.pingTimeout() + // We don't need to periodically ping in the health check, because the readLoop of ClientConn will + // trigger the healthCheck again if there is no frame received. + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + err := cc.Ping(ctx) + if err != nil { + cc.closeForLostPing() + cc.t.connPool().MarkDead(cc) + return + } +} + +// SetDoNotReuse marks cc as not reusable for future HTTP requests. +func (cc *http2ClientConn) SetDoNotReuse() { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.doNotReuse = true +} + +func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { + cc.mu.Lock() + defer cc.mu.Unlock() + + old := cc.goAway + cc.goAway = f + + // Merge the previous and current GoAway error frames. + if cc.goAwayDebug == "" { + cc.goAwayDebug = string(f.DebugData()) + } + if old != nil && old.ErrCode != http2ErrCodeNo { + cc.goAway.ErrCode = old.ErrCode + } + last := f.LastStreamID + for streamID, cs := range cc.streams { + if streamID > last { + cs.abortStreamLocked(http2errClientConnGotGoAway) + } + } +} + +// CanTakeNewRequest reports whether the connection can take a new request, +// meaning it has not been closed or received or sent a GOAWAY. +// +// If the caller is going to immediately make a new request on this +// connection, use ReserveNewRequest instead. +func (cc *http2ClientConn) CanTakeNewRequest() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.canTakeNewRequestLocked() +} + +// ReserveNewRequest is like CanTakeNewRequest but also reserves a +// concurrent stream in cc. The reservation is decremented on the +// next call to RoundTrip. +func (cc *http2ClientConn) ReserveNewRequest() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + if st := cc.idleStateLocked(); !st.canTakeNewRequest { + return false + } + cc.streamsReserved++ + return true +} + +// ClientConnState describes the state of a ClientConn. +type http2ClientConnState struct { + // Closed is whether the connection is closed. + Closed bool + + // Closing is whether the connection is in the process of + // closing. It may be closing due to shutdown, being a + // single-use connection, being marked as DoNotReuse, or + // having received a GOAWAY frame. + Closing bool + + // StreamsActive is how many streams are active. + StreamsActive int + + // StreamsReserved is how many streams have been reserved via + // ClientConn.ReserveNewRequest. + StreamsReserved int + + // StreamsPending is how many requests have been sent in excess + // of the peer's advertised MaxConcurrentStreams setting and + // are waiting for other streams to complete. + StreamsPending int + + // MaxConcurrentStreams is how many concurrent streams the + // peer advertised as acceptable. Zero means no SETTINGS + // frame has been received yet. + MaxConcurrentStreams uint32 + + // LastIdle, if non-zero, is when the connection last + // transitioned to idle state. + LastIdle time.Time +} + +// State returns a snapshot of cc's state. +func (cc *http2ClientConn) State() http2ClientConnState { + cc.wmu.Lock() + maxConcurrent := cc.maxConcurrentStreams + if !cc.seenSettings { + maxConcurrent = 0 + } + cc.wmu.Unlock() + + cc.mu.Lock() + defer cc.mu.Unlock() + return http2ClientConnState{ + Closed: cc.closed, + Closing: cc.closing || cc.singleUse || cc.doNotReuse || cc.goAway != nil, + StreamsActive: len(cc.streams), + StreamsReserved: cc.streamsReserved, + StreamsPending: cc.pendingRequests, + LastIdle: cc.lastIdle, + MaxConcurrentStreams: maxConcurrent, + } +} + +// clientConnIdleState describes the suitability of a client +// connection to initiate a new RoundTrip request. +type http2clientConnIdleState struct { + canTakeNewRequest bool +} + +func (cc *http2ClientConn) idleState() http2clientConnIdleState { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.idleStateLocked() +} + +func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) { + if cc.singleUse && cc.nextStreamID > 1 { + return + } + var maxConcurrentOkay bool + if cc.t.StrictMaxConcurrentStreams { + // We'll tell the caller we can take a new request to + // prevent the caller from dialing a new TCP + // connection, but then we'll block later before + // writing it. + maxConcurrentOkay = true + } else { + maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams) + } + + st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay && + !cc.doNotReuse && + int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 && + !cc.tooIdleLocked() + return +} + +func (cc *http2ClientConn) canTakeNewRequestLocked() bool { + st := cc.idleStateLocked() + return st.canTakeNewRequest +} + +// tooIdleLocked reports whether this connection has been been sitting idle +// for too much wall time. +func (cc *http2ClientConn) tooIdleLocked() bool { + // The Round(0) strips the monontonic clock reading so the + // times are compared based on their wall time. We don't want + // to reuse a connection that's been sitting idle during + // VM/laptop suspend if monotonic time was also frozen. + return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && time.Since(cc.lastIdle.Round(0)) > cc.idleTimeout +} + +// onIdleTimeout is called from a time.AfterFunc goroutine. It will +// only be called when we're idle, but because we're coming from a new +// goroutine, there could be a new request coming in at the same time, +// so this simply calls the synchronized closeIfIdle to shut down this +// connection. The timer could just call closeIfIdle, but this is more +// clear. +func (cc *http2ClientConn) onIdleTimeout() { + cc.closeIfIdle() +} + +func (cc *http2ClientConn) closeIfIdle() { + cc.mu.Lock() + if len(cc.streams) > 0 || cc.streamsReserved > 0 { + cc.mu.Unlock() + return + } + cc.closed = true + nextID := cc.nextStreamID + // TODO: do clients send GOAWAY too? maybe? Just Close: + cc.mu.Unlock() + + if http2VerboseLogs { + cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, nextID-2) + } + cc.tconn.Close() +} + +func (cc *http2ClientConn) isDoNotReuseAndIdle() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.doNotReuse && len(cc.streams) == 0 +} + +var http2shutdownEnterWaitStateHook = func() {} + +// Shutdown gracefully closes the client connection, waiting for running streams to complete. +func (cc *http2ClientConn) Shutdown(ctx context.Context) error { + if err := cc.sendGoAway(); err != nil { + return err + } + // Wait for all in-flight streams to complete or connection to close + done := make(chan error, 1) + cancelled := false // guarded by cc.mu + go func() { + cc.mu.Lock() + defer cc.mu.Unlock() + for { + if len(cc.streams) == 0 || cc.closed { + cc.closed = true + done <- cc.tconn.Close() + break + } + if cancelled { + break + } + cc.cond.Wait() + } + }() + http2shutdownEnterWaitStateHook() + select { + case err := <-done: + return err + case <-ctx.Done(): + cc.mu.Lock() + // Free the goroutine above + cancelled = true + cc.cond.Broadcast() + cc.mu.Unlock() + return ctx.Err() + } +} + +func (cc *http2ClientConn) sendGoAway() error { + cc.mu.Lock() + closing := cc.closing + cc.closing = true + maxStreamID := cc.nextStreamID + cc.mu.Unlock() + if closing { + // GOAWAY sent already + return nil + } + + cc.wmu.Lock() + defer cc.wmu.Unlock() + // Send a graceful shutdown frame to server + if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil { + return err + } + if err := cc.bw.Flush(); err != nil { + return err + } + // Prevent new requests + return nil +} + +// closes the client connection immediately. In-flight requests are interrupted. +// err is sent to streams. +func (cc *http2ClientConn) closeForError(err error) error { + cc.mu.Lock() + cc.closed = true + for _, cs := range cc.streams { + cs.abortStreamLocked(err) + } + defer cc.cond.Broadcast() + defer cc.mu.Unlock() + return cc.tconn.Close() +} + +// Close closes the client connection immediately. +// +// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. +func (cc *http2ClientConn) Close() error { + err := errors.New("http2: client connection force closed via ClientConn.Close") + return cc.closeForError(err) +} + +// closes the client connection immediately. In-flight requests are interrupted. +func (cc *http2ClientConn) closeForLostPing() error { + err := errors.New("http2: client connection lost") + if f := cc.t.CountError; f != nil { + f("conn_close_lost_ping") + } + return cc.closeForError(err) +} + +// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not +// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. +var http2errRequestCanceled = errors.New("net/http: request canceled") + +func http2commaSeparatedTrailers(req *http.Request) (string, error) { + keys := make([]string, 0, len(req.Trailer)) + for k := range req.Trailer { + k = http.CanonicalHeaderKey(k) + switch k { + case "Transfer-Encoding", "Trailer", "Content-Length": + return "", fmt.Errorf("invalid Trailer key %q", k) + } + keys = append(keys, k) + } + if len(keys) > 0 { + sort.Strings(keys) + return strings.Join(keys, ","), nil + } + return "", nil +} + +func (cc *http2ClientConn) responseHeaderTimeout() time.Duration { + if cc.t.t1 != nil { + return cc.t.t1.ResponseHeaderTimeout + } + // No way to do this (yet?) with just an http2.Transport. Probably + // no need. Request.Cancel this is the new way. We only need to support + // this for compatibility with the old http.Transport fields when + // we're doing transparent http2. + return 0 +} + +// checkConnHeaders checks whether req has any invalid connection-level headers. +// per RFC 7540 section 8.1.2.2: Connection-Specific Header Fields. +// Certain headers are special-cased as okay but not transmitted later. +func http2checkConnHeaders(req *http.Request) error { + if v := req.Header.Get("Upgrade"); v != "" { + return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"]) + } + if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { + return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv) + } + if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !http2asciiEqualFold(vv[0], "close") && !http2asciiEqualFold(vv[0], "keep-alive")) { + return fmt.Errorf("http2: invalid Connection request header: %q", vv) + } + return nil +} + +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func http2actualContentLength(req *http.Request) int64 { + if req.Body == nil || req.Body == http.NoBody { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} + +func (cc *http2ClientConn) decrStreamReservations() { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.decrStreamReservationsLocked() +} + +func (cc *http2ClientConn) decrStreamReservationsLocked() { + if cc.streamsReserved > 0 { + cc.streamsReserved-- + } +} + +func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + cs := &http2clientStream{ + cc: cc, + ctx: ctx, + reqCancel: req.Cancel, + isHead: req.Method == "HEAD", + reqBody: req.Body, + reqBodyContentLength: http2actualContentLength(req), + trace: httptrace.ContextClientTrace(ctx), + peerClosed: make(chan struct{}), + abort: make(chan struct{}), + respHeaderRecv: make(chan struct{}), + donec: make(chan struct{}), + } + go cs.doRequest(req) + + waitDone := func() error { + select { + case <-cs.donec: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled + } + } + + handleResponseHeaders := func() (*http.Response, error) { + res := cs.res + if res.StatusCode > 299 { + // On error or status code 3xx, 4xx, 5xx, etc abort any + // ongoing write, assuming that the server doesn't care + // about our request body. If the server replied with 1xx or + // 2xx, however, then assume the server DOES potentially + // want our body (e.g. full-duplex streaming: + // golang.org/issue/13444). If it turns out the server + // doesn't, they'll RST_STREAM us soon enough. This is a + // heuristic to avoid adding knobs to Transport. Hopefully + // we can keep it. + cs.abortRequestBodyWrite() + } + res.Request = req + res.TLS = cc.tlsState + if res.Body == http2noBody && http2actualContentLength(req) == 0 { + // If there isn't a request or response body still being + // written, then wait for the stream to be closed before + // RoundTrip returns. + if err := waitDone(); err != nil { + return nil, err + } + } + return res, nil + } + + for { + select { + case <-cs.respHeaderRecv: + return handleResponseHeaders() + case <-cs.abort: + select { + case <-cs.respHeaderRecv: + // If both cs.respHeaderRecv and cs.abort are signaling, + // pick respHeaderRecv. The server probably wrote the + // response and immediately reset the stream. + // golang.org/issue/49645 + return handleResponseHeaders() + default: + waitDone() + return nil, cs.abortErr + } + case <-ctx.Done(): + err := ctx.Err() + cs.abortStream(err) + return nil, err + case <-cs.reqCancel: + cs.abortStream(http2errRequestCanceled) + return nil, http2errRequestCanceled + } + } +} + +// doRequest runs for the duration of the request lifetime. +// +// It sends the request and performs post-request cleanup (closing Request.Body, etc.). +func (cs *http2clientStream) doRequest(req *http.Request) { + err := cs.writeRequest(req) + cs.cleanupWriteRequest(err) +} + +// writeRequest sends a request. +// +// It returns nil after the request is written, the response read, +// and the request stream is half-closed by the peer. +// +// It returns non-nil if the request ends otherwise. +// If the returned error is StreamError, the error Code may be used in resetting the stream. +func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { + cc := cs.cc + ctx := cs.ctx + + if err := http2checkConnHeaders(req); err != nil { + return err + } + + // Acquire the new-request lock by writing to reqHeaderMu. + // This lock guards the critical section covering allocating a new stream ID + // (requires mu) and creating the stream (requires wmu). + if cc.reqHeaderMu == nil { + panic("RoundTrip on uninitialized ClientConn") // for tests + } + select { + case cc.reqHeaderMu <- struct{}{}: + case <-cs.reqCancel: + return http2errRequestCanceled + case <-ctx.Done(): + return ctx.Err() + } + + cc.mu.Lock() + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } + cc.decrStreamReservationsLocked() + if err := cc.awaitOpenSlotForStreamLocked(cs); err != nil { + cc.mu.Unlock() + <-cc.reqHeaderMu + return err + } + cc.addStreamLocked(cs) // assigns stream ID + if http2isConnectionCloseRequest(req) { + cc.doNotReuse = true + } + cc.mu.Unlock() + + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? + if !cc.t.disableCompression() && + req.Header.Get("Accept-Encoding") == "" && + req.Header.Get("Range") == "" && + !cs.isHead { + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: https://zlib.net/zlib_faq.html#faq39 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // http://trac.nginx.org/nginx/ticket/358 + // https://golang.org/issue/5522 + // + // We don't request gzip if the request is for a range, since + // auto-decoding a portion of a gzipped document will just fail + // anyway. See https://golang.org/issue/8923 + cs.requestedGzip = true + } + + continueTimeout := cc.t.expectContinueTimeout() + if continueTimeout != 0 { + if !httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") { + continueTimeout = 0 + } else { + cs.on100 = make(chan struct{}, 1) + } + } + + // Past this point (where we send request headers), it is possible for + // RoundTrip to return successfully. Since the RoundTrip contract permits + // the caller to "mutate or reuse" the Request after closing the Response's Body, + // we must take care when referencing the Request from here on. + err = cs.encodeAndWriteHeaders(req) + <-cc.reqHeaderMu + if err != nil { + return err + } + + hasBody := cs.reqBodyContentLength != 0 + if !hasBody { + cs.sentEndStream = true + } else { + if continueTimeout != 0 { + http2traceWait100Continue(cs.trace) + timer := time.NewTimer(continueTimeout) + select { + case <-timer.C: + err = nil + case <-cs.on100: + err = nil + case <-cs.abort: + err = cs.abortErr + case <-ctx.Done(): + err = ctx.Err() + case <-cs.reqCancel: + err = http2errRequestCanceled + } + timer.Stop() + if err != nil { + http2traceWroteRequest(cs.trace, err) + return err + } + } + + if err = cs.writeRequestBody(req); err != nil { + if err != http2errStopReqBodyWrite { + http2traceWroteRequest(cs.trace, err) + return err + } + } else { + cs.sentEndStream = true + if dump := cs.cc.t.t1.dump; dump != nil && dump.RequestBody { + dump.dump([]byte("\r\n\r\n")) + } + } + } + + http2traceWroteRequest(cs.trace, err) + + var respHeaderTimer <-chan time.Time + var respHeaderRecv chan struct{} + if d := cc.responseHeaderTimeout(); d != 0 { + timer := time.NewTimer(d) + defer timer.Stop() + respHeaderTimer = timer.C + respHeaderRecv = cs.respHeaderRecv + } + // Wait until the peer half-closes its end of the stream, + // or until the request is aborted (via context, error, or otherwise), + // whichever comes first. + for { + select { + case <-cs.peerClosed: + return nil + case <-respHeaderTimer: + return http2errTimeout + case <-respHeaderRecv: + respHeaderRecv = nil + respHeaderTimer = nil // keep waiting for END_STREAM + case <-cs.abort: + return cs.abortErr + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled + } + } +} + +func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request) error { + cc := cs.cc + ctx := cs.ctx + + cc.wmu.Lock() + defer cc.wmu.Unlock() + + // If the request was canceled while waiting for cc.mu, just quit. + select { + case <-cs.abort: + return cs.abortErr + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled + default: + } + + // Encode headers. + // + // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is + // sent by writeRequestBody below, along with any Trailers, + // again in form HEADERS{1}, CONTINUATION{0,}) + trailers, err := http2commaSeparatedTrailers(req) + if err != nil { + return err + } + hasTrailers := trailers != "" + contentLen := http2actualContentLength(req) + hasBody := contentLen != 0 + hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen) + if err != nil { + return err + } + + // Write the request. + endStream := !hasBody && !hasTrailers + cs.sentHeaders = true + err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) + http2traceWroteHeaders(cs.trace) + return err +} + +// cleanupWriteRequest performs post-request tasks. +// +// If err (the result of writeRequest) is non-nil and the stream is not closed, +// cleanupWriteRequest will send a reset to the peer. +func (cs *http2clientStream) cleanupWriteRequest(err error) { + cc := cs.cc + + if cs.ID == 0 { + // We were canceled before creating the stream, so return our reservation. + cc.decrStreamReservations() + } + + // TODO: write h12Compare test showing whether + // Request.Body is closed by the Transport, + // and in multiple cases: server replies <=299 and >299 + // while still writing request body + cc.mu.Lock() + bodyClosed := cs.reqBodyClosed + cs.reqBodyClosed = true + cc.mu.Unlock() + if !bodyClosed && cs.reqBody != nil { + cs.reqBody.Close() + } + + if err != nil && cs.sentEndStream { + // If the connection is closed immediately after the response is read, + // we may be aborted before finishing up here. If the stream was closed + // cleanly on both sides, there is no error. + select { + case <-cs.peerClosed: + err = nil + default: + } + } + if err != nil { + cs.abortStream(err) // possibly redundant, but harmless + if cs.sentHeaders { + if se, ok := err.(http2StreamError); ok { + if se.Cause != http2errFromPeer { + cc.writeStreamReset(cs.ID, se.Code, err) + } + } else { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) + } + } + cs.bufPipe.CloseWithError(err) // no-op if already closed + } else { + if cs.sentHeaders && !cs.sentEndStream { + cc.writeStreamReset(cs.ID, http2ErrCodeNo, nil) + } + cs.bufPipe.CloseWithError(http2errRequestCanceled) + } + if cs.ID != 0 { + cc.forgetStreamID(cs.ID) + } + + cc.wmu.Lock() + werr := cc.werr + cc.wmu.Unlock() + if werr != nil { + cc.Close() + } + + close(cs.donec) +} + +// awaitOpenSlotForStream waits until len(streams) < maxConcurrentStreams. +// Must hold cc.mu. +func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) error { + for { + cc.lastActive = time.Now() + if cc.closed || !cc.canTakeNewRequestLocked() { + return http2errClientConnUnusable + } + cc.lastIdle = time.Time{} + if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) { + return nil + } + cc.pendingRequests++ + cc.cond.Wait() + cc.pendingRequests-- + select { + case <-cs.abort: + return cs.abortErr + default: + } + } +} + +// requires cc.wmu be held +func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize int, hdrs []byte) error { + first := true // first frame written (HEADERS is first, then CONTINUATION) + for len(hdrs) > 0 && cc.werr == nil { + chunk := hdrs + if len(chunk) > maxFrameSize { + chunk = chunk[:maxFrameSize] + } + hdrs = hdrs[len(chunk):] + endHeaders := len(hdrs) == 0 + if first { + cc.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: streamID, + BlockFragment: chunk, + EndStream: endStream, + EndHeaders: endHeaders, + }) + first = false + } else { + cc.fr.WriteContinuation(streamID, endHeaders, chunk) + } + } + cc.bw.Flush() + return cc.werr +} + +// internal error values; they don't escape to callers +var ( + // abort request body write; don't send cancel + http2errStopReqBodyWrite = errors.New("http2: aborting request body write") + + // abort request body write, but send stream reset of cancel. + http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request") + + http2errReqBodyTooLong = errors.New("http2: request body larger than specified content length") +) + +// frameScratchBufferLen returns the length of a buffer to use for +// outgoing request bodies to read/write to/from. +// +// It returns max(1, min(peer's advertised max frame size, +// Request.ContentLength+1, 512KB)). +func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int { + const max = 512 << 10 + n := int64(maxFrameSize) + if n > max { + n = max + } + if cl := cs.reqBodyContentLength; cl != -1 && cl+1 < n { + // Add an extra byte past the declared content-length to + // give the caller's Request.Body io.textprotoReader a chance to + // give us more bytes than they declared, so we can catch it + // early. + n = cl + 1 + } + if n < 1 { + return 1 + } + return int(n) // doesn't truncate; max is 512K +} + +var http2bufPool sync.Pool // of *[]byte + +func (cs *http2clientStream) writeRequestBody(req *http.Request) (err error) { + cc := cs.cc + body := cs.reqBody + sentEnd := false // whether we sent the final DATA frame w/ END_STREAM + + hasTrailers := req.Trailer != nil + remainLen := cs.reqBodyContentLength + hasContentLen := remainLen != -1 + + cc.mu.Lock() + maxFrameSize := int(cc.maxFrameSize) + cc.mu.Unlock() + + // Scratch buffer for reading into & writing from. + scratchLen := cs.frameScratchBufferLen(maxFrameSize) + var buf []byte + if bp, ok := http2bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen { + defer http2bufPool.Put(bp) + buf = *bp + } else { + buf = make([]byte, scratchLen) + defer http2bufPool.Put(&buf) + } + + var sawEOF bool + for !sawEOF { + n, err := body.Read(buf[:len(buf)]) + if hasContentLen { + remainLen -= int64(n) + if remainLen == 0 && err == nil { + // The request body's Content-Length was predeclared and + // we just finished reading it all, but the underlying io.textprotoReader + // returned the final chunk with a nil error (which is one of + // the two valid things a textprotoReader can do at EOF). Because we'd prefer + // to send the END_STREAM bit early, double-check that we're actually + // at EOF. Subsequent reads should return (0, EOF) at this point. + // If either value is different, we return an error in one of two ways below. + var scratch [1]byte + var n1 int + n1, err = body.Read(scratch[:]) + remainLen -= int64(n1) + } + if remainLen < 0 { + err = http2errReqBodyTooLong + return err + } + } + if err != nil { + cc.mu.Lock() + bodyClosed := cs.reqBodyClosed + cc.mu.Unlock() + switch { + case bodyClosed: + return http2errStopReqBodyWrite + case err == io.EOF: + sawEOF = true + err = nil + default: + return err + } + } + + remain := buf[:n] + for len(remain) > 0 && err == nil { + var allowed int32 + allowed, err = cs.awaitFlowControl(len(remain)) + if err != nil { + return err + } + cc.wmu.Lock() + data := remain[:allowed] + remain = remain[allowed:] + sentEnd = sawEOF && len(remain) == 0 && !hasTrailers + err = cc.fr.WriteData(cs.ID, sentEnd, data) + if err == nil { + // TODO(bradfitz): this flush is for latency, not bandwidth. + // Most requests won't need this. Make this opt-in or + // opt-out? Use some heuristic on the body type? Nagel-like + // timers? Based on 'n'? Only last chunk of this for loop, + // unless flow control tokens are low? For now, always. + // If we change this, see comment below. + err = cc.bw.Flush() + } + cc.wmu.Unlock() + } + if err != nil { + return err + } + } + + if sentEnd { + // Already sent END_STREAM (which implies we have no + // trailers) and flushed, because currently all + // WriteData frames above get a flush. So we're done. + return nil + } + + // Since the RoundTrip contract permits the caller to "mutate or reuse" + // a request after the Response's Body is closed, verify that this hasn't + // happened before accessing the trailers. + cc.mu.Lock() + trailer := req.Trailer + err = cs.abortErr + cc.mu.Unlock() + if err != nil { + return err + } + + cc.wmu.Lock() + defer cc.wmu.Unlock() + var trls []byte + if len(trailer) > 0 { + trls, err = cc.encodeTrailers(trailer) + if err != nil { + return err + } + } + + // Two ways to send END_STREAM: either with trailers, or + // with an empty DATA frame. + if len(trls) > 0 { + err = cc.writeHeaders(cs.ID, true, maxFrameSize, trls) + } else { + err = cc.fr.WriteData(cs.ID, true, nil) + } + if ferr := cc.bw.Flush(); ferr != nil && err == nil { + err = ferr + } + return err +} + +// awaitFlowControl waits for [1, min(maxBytes, cc.cs.maxFrameSize)] flow +// control tokens from the server. +// It returns either the non-zero number of tokens taken or an error +// if the stream is dead. +func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) { + cc := cs.cc + ctx := cs.ctx + cc.mu.Lock() + defer cc.mu.Unlock() + for { + if cc.closed { + return 0, http2errClientConnClosed + } + if cs.reqBodyClosed { + return 0, http2errStopReqBodyWrite + } + select { + case <-cs.abort: + return 0, cs.abortErr + case <-ctx.Done(): + return 0, ctx.Err() + case <-cs.reqCancel: + return 0, http2errRequestCanceled + default: + } + if a := cs.flow.available(); a > 0 { + take := a + if int(take) > maxBytes { + + take = int32(maxBytes) // can't truncate int; take is int32 + } + if take > int32(cc.maxFrameSize) { + take = int32(cc.maxFrameSize) + } + cs.flow.take(take) + return take, nil + } + cc.cond.Wait() + } +} + +var http2errNilRequestURL = errors.New("http2: Request.URI is nil") + +// requires cc.wmu be held. +func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { + cc.hbuf.Reset() + if req.URL == nil { + return nil, http2errNilRequestURL + } + + host := req.Host + if host == "" { + host = req.URL.Host + } + host, err := httpguts.PunycodeHostPort(host) + if err != nil { + return nil, err + } + + var path string + if req.Method != "CONNECT" { + path = req.URL.RequestURI() + if !http2validPseudoPath(path) { + orig := path + path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) + if !http2validPseudoPath(path) { + if req.URL.Opaque != "" { + return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) + } else { + return nil, fmt.Errorf("invalid request :path %q", orig) + } + } + } + } + + // Check for any invalid headers and return an error before we + // potentially pollute our hpack state. (We want to be able to + // continue to reuse the hpack encoder for future requests) + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("invalid HTTP header name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) + } + } + } + + enumerateHeaders := func(f func(name, value string)) { + // 8.1.2.3 Request Pseudo-Header Fields + // The :path pseudo-header field includes the path and query parts of the + // target URI (the path-absolute production and optionally a '?' character + // followed by the query production (see Sections 3.3 and 3.4 of + // [RFC3986]). + f(":authority", host) + m := req.Method + if m == "" { + m = http.MethodGet + } + f(":method", m) + if req.Method != "CONNECT" { + f(":path", path) + f(":scheme", req.URL.Scheme) + } + if trailers != "" { + f("trailer", trailers) + } + + var didUA bool + for k, vv := range req.Header { + if http2asciiEqualFold(k, "host") || http2asciiEqualFold(k, "content-length") { + // Host is :authority, already sent. + // Content-Length is automatic, set below. + continue + } else if http2asciiEqualFold(k, "connection") || + http2asciiEqualFold(k, "proxy-connection") || + http2asciiEqualFold(k, "transfer-encoding") || + http2asciiEqualFold(k, "upgrade") || + http2asciiEqualFold(k, "keep-alive") { + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + continue + } else if http2asciiEqualFold(k, "user-agent") { + // Match Go's http1 behavior: at most one + // User-Agent. If set to nil or empty string, + // then omit it. Otherwise if not mentioned, + // include the default (below). + didUA = true + if len(vv) < 1 { + continue + } + vv = vv[:1] + if vv[0] == "" { + continue + } + } else if http2asciiEqualFold(k, "cookie") { + // Per 8.1.2.5 To allow for better compression efficiency, the + // Cookie header field MAY be split into separate header fields, + // each with one or more cookie-pairs. + for _, v := range vv { + for { + p := strings.IndexByte(v, ';') + if p < 0 { + break + } + f("cookie", v[:p]) + p++ + // strip space after semicolon if any. + for p+1 <= len(v) && v[p] == ' ' { + p++ + } + v = v[p:] + } + if len(v) > 0 { + f("cookie", v) + } + } + continue + } + + for _, v := range vv { + f(k, v) + } + } + if http2shouldSendReqContentLength(req.Method, contentLength) { + f("content-length", strconv.FormatInt(contentLength, 10)) + } + if addGzipHeader { + f("accept-encoding", "gzip") + } + if !didUA { + f("user-agent", http2defaultUserAgent) + } + } + + // Do a first pass over the headers counting bytes to ensure + // we don't exceed cc.peerMaxHeaderListSize. This is done as a + // separate pass before encoding the headers to prevent + // modifying the hpack state. + hlSize := uint64(0) + enumerateHeaders(func(name, value string) { + hf := hpack.HeaderField{Name: name, Value: value} + hlSize += uint64(hf.Size()) + }) + + if hlSize > cc.peerMaxHeaderListSize { + return nil, http2errRequestHeaderListSize + } + + trace := httptrace.ContextClientTrace(req.Context()) + traceHeaders := http2traceHasWroteHeaderField(trace) + + // Header list size is ok. Write the headers. + enumerateHeaders(func(name, value string) { + name, ascii := http2asciiToLower(name) + if !ascii { + // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header + // field names have to be ASCII characters (just as in HTTP/1.x). + return + } + cc.writeHeader(name, value) + if traceHeaders { + http2traceWroteHeaderField(trace, name, value) + } + }) + + if cc.t.t1.dump != nil && cc.t.t1.dump.RequestHead { + cc.t.t1.dump.dump([]byte("\r\n")) + } + + return cc.hbuf.Bytes(), nil +} + +// shouldSendReqContentLength reports whether the http2.Transport should send +// a "content-length" request header. This logic is basically a copy of the net/http +// transferWriter.shouldSendContentLength. +// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). +// -1 means unknown. +func http2shouldSendReqContentLength(method string, contentLength int64) bool { + if contentLength > 0 { + return true + } + if contentLength < 0 { + return false + } + // For zero bodies, whether we send a content-length depends on the method. + // It also kinda doesn't matter for http2 either way, with END_STREAM. + switch method { + case "POST", "PUT", "PATCH": + return true + default: + return false + } +} + +// requires cc.wmu be held. +func (cc *http2ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) { + cc.hbuf.Reset() + + hlSize := uint64(0) + for k, vv := range trailer { + for _, v := range vv { + hf := hpack.HeaderField{Name: k, Value: v} + hlSize += uint64(hf.Size()) + } + } + if hlSize > cc.peerMaxHeaderListSize { + return nil, http2errRequestHeaderListSize + } + + for k, vv := range trailer { + lowKey, ascii := http2asciiToLower(k) + if !ascii { + // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header + // field names have to be ASCII characters (just as in HTTP/1.x). + continue + } + // Transfer-Encoding, etc.. have already been filtered at the + // start of RoundTrip + for _, v := range vv { + cc.writeHeader(lowKey, v) + } + } + return cc.hbuf.Bytes(), nil +} + +func (cc *http2ClientConn) _writeHeader(name, value string) { + if http2VerboseLogs { + log.Printf("http2: Transport encoding header %q = %q", name, value) + } + cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) +} + +type http2resAndError struct { + _ http2incomparable + res *http.Response + err error +} + +// requires cc.mu be held. +func (cc *http2ClientConn) addStreamLocked(cs *http2clientStream) { + cs.flow.add(int32(cc.initialWindowSize)) + cs.flow.setConnFlow(&cc.flow) + cs.inflow.add(http2transportDefaultStreamFlow) + cs.inflow.setConnFlow(&cc.inflow) + cs.ID = cc.nextStreamID + cc.nextStreamID += 2 + cc.streams[cs.ID] = cs + if cs.ID == 0 { + panic("assigned stream ID 0") + } +} + +func (cc *http2ClientConn) forgetStreamID(id uint32) { + cc.mu.Lock() + slen := len(cc.streams) + delete(cc.streams, id) + if len(cc.streams) != slen-1 { + panic("forgetting unknown stream id") + } + cc.lastActive = time.Now() + if len(cc.streams) == 0 && cc.idleTimer != nil { + cc.idleTimer.Reset(cc.idleTimeout) + cc.lastIdle = time.Now() + } + // Wake up writeRequestBody via clientStream.awaitFlowControl and + // wake up RoundTrip if there is a pending request. + cc.cond.Broadcast() + + closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() + if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { + if http2VerboseLogs { + cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) + } + cc.closed = true + defer cc.tconn.Close() + } + + cc.mu.Unlock() +} + +// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. +type http2clientConnReadLoop struct { + _ http2incomparable + cc *http2ClientConn +} + +// readLoop runs in its own goroutine and reads and dispatches frames. +func (cc *http2ClientConn) readLoop() { + rl := &http2clientConnReadLoop{cc: cc} + defer rl.cleanup() + cc.readerErr = rl.run() + if ce, ok := cc.readerErr.(http2ConnectionError); ok { + cc.wmu.Lock() + cc.fr.WriteGoAway(0, http2ErrCode(ce), nil) + cc.wmu.Unlock() + } +} + +// GoAwayError is returned by the Transport when the server closes the +// TCP connection after sending a GOAWAY frame. +type http2GoAwayError struct { + LastStreamID uint32 + ErrCode http2ErrCode + DebugData string +} + +func (e http2GoAwayError) Error() string { + return fmt.Sprintf("http2: server sent GOAWAY and closed the connection; LastStreamID=%v, ErrCode=%v, debug=%q", + e.LastStreamID, e.ErrCode, e.DebugData) +} + +func http2isEOFOrNetReadError(err error) bool { + if err == io.EOF { + return true + } + ne, ok := err.(*net.OpError) + return ok && ne.Op == "read" +} + +func (rl *http2clientConnReadLoop) cleanup() { + cc := rl.cc + defer cc.tconn.Close() + defer cc.t.connPool().MarkDead(cc) + defer close(cc.readerDone) + + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } + + // Close any response bodies if the server closes prematurely. + // TODO: also do this if we've written the headers but not + // gotten a response yet. + err := cc.readerErr + cc.mu.Lock() + if cc.goAway != nil && http2isEOFOrNetReadError(err) { + err = http2GoAwayError{ + LastStreamID: cc.goAway.LastStreamID, + ErrCode: cc.goAway.ErrCode, + DebugData: cc.goAwayDebug, + } + } else if err == io.EOF { + err = io.ErrUnexpectedEOF + } + cc.closed = true + for _, cs := range cc.streams { + select { + case <-cs.peerClosed: + // The server closed the stream before closing the conn, + // so no need to interrupt it. + default: + cs.abortStreamLocked(err) + } + } + cc.cond.Broadcast() + cc.mu.Unlock() +} + +// countReadFrameError calls Transport.CountError with a string +// representing err. +func (cc *http2ClientConn) countReadFrameError(err error) { + f := cc.t.CountError + if f == nil || err == nil { + return + } + if ce, ok := err.(http2ConnectionError); ok { + errCode := http2ErrCode(ce) + f(fmt.Sprintf("read_frame_conn_error_%s", errCode.stringToken())) + return + } + if errors.Is(err, io.EOF) { + f("read_frame_eof") + return + } + if errors.Is(err, io.ErrUnexpectedEOF) { + f("read_frame_unexpected_eof") + return + } + if errors.Is(err, http2ErrFrameTooLarge) { + f("read_frame_too_large") + return + } + f("read_frame_other") +} + +func (rl *http2clientConnReadLoop) run() error { + cc := rl.cc + gotSettings := false + readIdleTimeout := cc.t.ReadIdleTimeout + var t *time.Timer + if readIdleTimeout != 0 { + t = time.AfterFunc(readIdleTimeout, cc.healthCheck) + defer t.Stop() + } + for { + f, err := cc.fr.ReadFrame() + if t != nil { + t.Reset(readIdleTimeout) + } + if err != nil { + cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) + } + if se, ok := err.(http2StreamError); ok { + if cs := rl.streamByID(se.StreamID); cs != nil { + if se.Cause == nil { + se.Cause = cc.fr.errDetail + } + rl.endStreamError(cs, se) + } + continue + } else if err != nil { + cc.countReadFrameError(err) + return err + } + if http2VerboseLogs { + cc.vlogf("http2: Transport received %s", http2summarizeFrame(f)) + } + if !gotSettings { + if _, ok := f.(*http2SettingsFrame); !ok { + cc.logf("protocol error: received %T before a SETTINGS frame", f) + return http2ConnectionError(http2ErrCodeProtocol) + } + gotSettings = true + } + + switch f := f.(type) { + case *http2MetaHeadersFrame: + err = rl.processHeaders(f) + case *http2DataFrame: + err = rl.processData(f) + case *http2GoAwayFrame: + err = rl.processGoAway(f) + case *http2RSTStreamFrame: + err = rl.processResetStream(f) + case *http2SettingsFrame: + err = rl.processSettings(f) + case *http2PushPromiseFrame: + err = rl.processPushPromise(f) + case *http2WindowUpdateFrame: + err = rl.processWindowUpdate(f) + case *http2PingFrame: + err = rl.processPing(f) + default: + cc.logf("Transport: unhandled response frame type %T", f) + } + if err != nil { + if http2VerboseLogs { + cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, http2summarizeFrame(f), err) + } + return err + } + } +} + +func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error { + cs := rl.streamByID(f.StreamID) + if cs == nil { + // We'd get here if we canceled a request while the + // server had its response still in flight. So if this + // was just something we canceled, ignore it. + return nil + } + if cs.readClosed { + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + Cause: errors.New("protocol error: headers after END_STREAM"), + }) + return nil + } + if !cs.firstByte { + if cs.trace != nil { + // TODO(bradfitz): move first response byte earlier, + // when we first read the 9 byte header, not waiting + // until all the HEADERS+CONTINUATION frames have been + // merged. This works for now. + http2traceFirstResponseByte(cs.trace) + } + cs.firstByte = true + } + if !cs.pastHeaders { + cs.pastHeaders = true + } else { + return rl.processTrailers(cs, f) + } + + res, err := rl.handleResponse(cs, f) + if err != nil { + if _, ok := err.(http2ConnectionError); ok { + return err + } + // Any other error type is a stream error. + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + Cause: err, + }) + return nil // return nil from process* funcs to keep conn alive + } + if res == nil { + // (nil, nil) special case. See handleResponse docs. + return nil + } + cs.resTrailer = &res.Trailer + cs.res = res + close(cs.respHeaderRecv) + if f.StreamEnded() { + rl.endStream(cs) + } + return nil +} + +// may return error types nil, or ConnectionError. Any other error value +// is a StreamError of type ErrCodeProtocol. The returned error in that case +// is the detail. +// +// As a special case, handleResponse may return (nil, nil) to skip the +// frame (currently only used for 1xx responses). +func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http2MetaHeadersFrame) (*http.Response, error) { + if f.Truncated { + return nil, http2errResponseHeaderListSize + } + + status := f.PseudoValue("status") + if status == "" { + return nil, errors.New("malformed response from server: missing status pseudo header") + } + statusCode, err := strconv.Atoi(status) + if err != nil { + return nil, errors.New("malformed response from server: malformed non-numeric status pseudo header") + } + + regularFields := f.RegularFields() + strs := make([]string, len(regularFields)) + header := make(http.Header, len(regularFields)) + res := &http.Response{ + Proto: "HTTP/2.0", + ProtoMajor: 2, + Header: header, + StatusCode: statusCode, + Status: status + " " + http.StatusText(statusCode), + } + for _, hf := range regularFields { + key := http.CanonicalHeaderKey(hf.Name) + if key == "Trailer" { + t := res.Trailer + if t == nil { + t = make(http.Header) + res.Trailer = t + } + http2foreachHeaderElement(hf.Value, func(v string) { + t[http.CanonicalHeaderKey(v)] = nil + }) + } else { + vv := header[key] + if vv == nil && len(strs) > 0 { + // More than likely this will be a single-element key. + // Most headers aren't multi-valued. + // Set the capacity on strs[0] to 1, so any future append + // won't extend the slice into the other strings. + vv, strs = strs[:1:1], strs[1:] + vv[0] = hf.Value + header[key] = vv + } else { + header[key] = append(vv, hf.Value) + } + } + } + + if statusCode >= 100 && statusCode <= 199 { + if f.StreamEnded() { + return nil, errors.New("1xx informational response with END_STREAM flag") + } + cs.num1xx++ + const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http + if cs.num1xx > max1xxResponses { + return nil, errors.New("http2: too many 1xx informational responses") + } + if fn := cs.get1xxTraceFunc(); fn != nil { + if err := fn(statusCode, textproto.MIMEHeader(header)); err != nil { + return nil, err + } + } + if statusCode == 100 { + http2traceGot100Continue(cs.trace) + select { + case cs.on100 <- struct{}{}: + default: + } + } + cs.pastHeaders = false // do it all again + return nil, nil + } + + res.ContentLength = -1 + if clens := res.Header["Content-Length"]; len(clens) == 1 { + if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil { + res.ContentLength = int64(cl) + } else { + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. + } + } else if len(clens) > 1 { + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. + } else if f.StreamEnded() && !cs.isHead { + res.ContentLength = 0 + } + + if cs.isHead { + res.Body = http2noBody + return res, nil + } + + if f.StreamEnded() { + if res.ContentLength > 0 { + res.Body = http2missingBody{} + } else { + res.Body = http2noBody + } + return res, nil + } + + cs.bufPipe.setBuffer(&http2dataBuffer{expected: res.ContentLength}) + cs.bytesRemain = res.ContentLength + res.Body = http2transportResponseBody{cs} + + if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = &http2gzipReader{body: res.Body} + res.Uncompressed = true + } + + return res, nil +} + +func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *http2MetaHeadersFrame) error { + if cs.pastTrailers { + // Too many HEADERS frames for this stream. + return http2ConnectionError(http2ErrCodeProtocol) + } + cs.pastTrailers = true + if !f.StreamEnded() { + // We expect that any headers for trailers also + // has END_STREAM. + return http2ConnectionError(http2ErrCodeProtocol) + } + if len(f.PseudoFields()) > 0 { + // No pseudo header fields are defined for trailers. + // TODO: ConnectionError might be overly harsh? Check. + return http2ConnectionError(http2ErrCodeProtocol) + } + + trailer := make(http.Header) + for _, hf := range f.RegularFields() { + key := http.CanonicalHeaderKey(hf.Name) + trailer[key] = append(trailer[key], hf.Value) + } + cs.trailer = trailer + + rl.endStream(cs) + return nil +} + +// transportResponseBody is the concrete type of Transport.RoundTrip's +// Response.Body. It is an io.ReadCloser. +type http2transportResponseBody struct { + cs *http2clientStream +} + +func (b http2transportResponseBody) Read(p []byte) (n int, err error) { + cs := b.cs + cc := cs.cc + + if cs.readErr != nil { + return 0, cs.readErr + } + n, err = b.cs.bufPipe.Read(p) + if cs.bytesRemain != -1 { + if int64(n) > cs.bytesRemain { + n = int(cs.bytesRemain) + if err == nil { + err = errors.New("net/http: server replied with more than declared Content-Length; truncated") + cs.abortStream(err) + } + cs.readErr = err + return int(cs.bytesRemain), err + } + cs.bytesRemain -= int64(n) + if err == io.EOF && cs.bytesRemain > 0 { + err = io.ErrUnexpectedEOF + cs.readErr = err + return n, err + } + } + if n == 0 { + // No flow control tokens to send back. + return + } + + cc.mu.Lock() + var connAdd, streamAdd int32 + // Check the conn-level first, before the stream-level. + if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 { + connAdd = http2transportDefaultConnFlow - v + cc.inflow.add(connAdd) + } + if err == nil { // No need to refresh if the stream is over or failed. + // Consider any buffered body data (read from the conn but not + // consumed by the client) when computing flow control for this + // stream. + v := int(cs.inflow.available()) + cs.bufPipe.Len() + if v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { + streamAdd = int32(http2transportDefaultStreamFlow - v) + cs.inflow.add(streamAdd) + } + } + cc.mu.Unlock() + + if connAdd != 0 || streamAdd != 0 { + cc.wmu.Lock() + defer cc.wmu.Unlock() + if connAdd != 0 { + cc.fr.WriteWindowUpdate(0, http2mustUint31(connAdd)) + } + if streamAdd != 0 { + cc.fr.WriteWindowUpdate(cs.ID, http2mustUint31(streamAdd)) + } + cc.bw.Flush() + } + return +} + +var http2errClosedResponseBody = errors.New("http2: response body closed") + +func (b http2transportResponseBody) Close() error { + cs := b.cs + cc := cs.cc + + unread := cs.bufPipe.Len() + if unread > 0 { + cc.mu.Lock() + // Return connection-level flow control. + if unread > 0 { + cc.inflow.add(int32(unread)) + } + cc.mu.Unlock() + + // TODO(dneil): Acquiring this mutex can block indefinitely. + // Move flow control return to a goroutine? + cc.wmu.Lock() + // Return connection-level flow control. + if unread > 0 { + cc.fr.WriteWindowUpdate(0, uint32(unread)) + } + cc.bw.Flush() + cc.wmu.Unlock() + } + + cs.bufPipe.BreakWithError(http2errClosedResponseBody) + cs.abortStream(http2errClosedResponseBody) + + select { + case <-cs.donec: + case <-cs.ctx.Done(): + // See golang/go#49366: The net/http package can cancel the + // request context after the response body is fully read. + // Don't treat this as an error. + return nil + case <-cs.reqCancel: + return http2errRequestCanceled + } + return nil +} + +func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { + cc := rl.cc + cs := rl.streamByID(f.StreamID) + data := f.Data() + if cs == nil { + cc.mu.Lock() + neverSent := cc.nextStreamID + cc.mu.Unlock() + if f.StreamID >= neverSent { + // We never asked for this. + cc.logf("http2: Transport received unsolicited DATA frame; closing connection") + return http2ConnectionError(http2ErrCodeProtocol) + } + // We probably did ask for this, but canceled. Just ignore it. + // TODO: be stricter here? only silently ignore things which + // we canceled, but not things which were closed normally + // by the peer? Tough without accumulating too much state. + + // But at least return their flow control: + if f.Length > 0 { + cc.mu.Lock() + cc.inflow.add(int32(f.Length)) + cc.mu.Unlock() + + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(f.Length)) + cc.bw.Flush() + cc.wmu.Unlock() + } + return nil + } + if cs.readClosed { + cc.logf("protocol error: received DATA after END_STREAM") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil + } + if !cs.firstByte { + cc.logf("protocol error: received DATA before a HEADERS frame") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil + } + if f.Length > 0 { + if cs.isHead && len(data) > 0 { + cc.logf("protocol error: received DATA on a HEAD request") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil + } + // Check connection-level flow control. + cc.mu.Lock() + if cs.inflow.available() >= int32(f.Length) { + cs.inflow.take(int32(f.Length)) + } else { + cc.mu.Unlock() + return http2ConnectionError(http2ErrCodeFlowControl) + } + // Return any padded flow control now, since we won't + // refund it later on body reads. + var refund int + if pad := int(f.Length) - len(data); pad > 0 { + refund += pad + } + + didReset := false + var err error + if len(data) > 0 { + if _, err = cs.bufPipe.Write(data); err != nil { + // Return len(data) now if the stream is already closed, + // since data will never be read. + didReset = true + refund += len(data) + } + } + + if refund > 0 { + cc.inflow.add(int32(refund)) + if !didReset { + cs.inflow.add(int32(refund)) + } + } + cc.mu.Unlock() + + if refund > 0 { + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(refund)) + if !didReset { + cc.fr.WriteWindowUpdate(cs.ID, uint32(refund)) + } + cc.bw.Flush() + cc.wmu.Unlock() + } + + if err != nil { + rl.endStreamError(cs, err) + return nil + } + } + + if f.StreamEnded() { + rl.endStream(cs) + } + return nil +} + +func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { + // TODO: check that any declared content-length matches, like + // server.go's (*stream).endStream method. + if !cs.readClosed { + cs.readClosed = true + // Close cs.bufPipe and cs.peerClosed with cc.mu held to avoid a + // race condition: The caller can read io.EOF from Response.Body + // and close the body before we close cs.peerClosed, causing + // cleanupWriteRequest to send a RST_STREAM. + rl.cc.mu.Lock() + defer rl.cc.mu.Unlock() + cs.bufPipe.closeWithErrorAndCode(io.EOF, cs.copyTrailers) + close(cs.peerClosed) + } +} + +func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) { + cs.readAborted = true + cs.abortStream(err) +} + +func (rl *http2clientConnReadLoop) streamByID(id uint32) *http2clientStream { + rl.cc.mu.Lock() + defer rl.cc.mu.Unlock() + cs := rl.cc.streams[id] + if cs != nil && !cs.readAborted { + return cs + } + return nil +} + +func (cs *http2clientStream) copyTrailers() { + for k, vv := range cs.trailer { + t := cs.resTrailer + if *t == nil { + *t = make(http.Header) + } + (*t)[k] = vv + } +} + +func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { + cc := rl.cc + cc.t.connPool().MarkDead(cc) + if f.ErrCode != 0 { + // TODO: deal with GOAWAY more. particularly the error code + cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) + if fn := cc.t.CountError; fn != nil { + fn("recv_goaway_" + f.ErrCode.stringToken()) + } + + } + cc.setGoAway(f) + return nil +} + +func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error { + cc := rl.cc + // Locking both mu and wmu here allows frame encoding to read settings with only wmu held. + // Acquiring wmu when f.IsAck() is unnecessary, but convenient and mostly harmless. + cc.wmu.Lock() + defer cc.wmu.Unlock() + + if err := rl.processSettingsNoWrite(f); err != nil { + return err + } + if !f.IsAck() { + cc.fr.WriteSettingsAck() + cc.bw.Flush() + } + return nil +} + +func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) error { + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() + + if f.IsAck() { + if cc.wantSettingsAck { + cc.wantSettingsAck = false + return nil + } + return http2ConnectionError(http2ErrCodeProtocol) + } + + var seenMaxConcurrentStreams bool + err := f.ForeachSetting(func(s http2Setting) error { + switch s.ID { + case http2SettingMaxFrameSize: + cc.maxFrameSize = s.Val + case http2SettingMaxConcurrentStreams: + cc.maxConcurrentStreams = s.Val + seenMaxConcurrentStreams = true + case http2SettingMaxHeaderListSize: + cc.peerMaxHeaderListSize = uint64(s.Val) + case http2SettingInitialWindowSize: + // Values above the maximum flow-control + // window size of 2^31-1 MUST be treated as a + // connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR. + if s.Val > math.MaxInt32 { + return http2ConnectionError(http2ErrCodeFlowControl) + } + + // Adjust flow control of currently-open + // frames by the difference of the old initial + // window size and this one. + delta := int32(s.Val) - int32(cc.initialWindowSize) + for _, cs := range cc.streams { + cs.flow.add(delta) + } + cc.cond.Broadcast() + + cc.initialWindowSize = s.Val + default: + // TODO(bradfitz): handle more settings? SETTINGS_HEADER_TABLE_SIZE probably. + cc.vlogf("Unhandled Setting: %v", s) + } + return nil + }) + if err != nil { + return err + } + + if !cc.seenSettings { + if !seenMaxConcurrentStreams { + // This was the servers initial SETTINGS frame and it + // didn't contain a MAX_CONCURRENT_STREAMS field so + // increase the number of concurrent streams this + // connection can establish to our default. + cc.maxConcurrentStreams = http2defaultMaxConcurrentStreams + } + cc.seenSettings = true + } + + return nil +} + +func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error { + cc := rl.cc + cs := rl.streamByID(f.StreamID) + if f.StreamID != 0 && cs == nil { + return nil + } + + cc.mu.Lock() + defer cc.mu.Unlock() + + fl := &cc.flow + if cs != nil { + fl = &cs.flow + } + if !fl.add(int32(f.Increment)) { + return http2ConnectionError(http2ErrCodeFlowControl) + } + cc.cond.Broadcast() + return nil +} + +func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error { + cs := rl.streamByID(f.StreamID) + if cs == nil { + // TODO: return error if server tries to RST_STREAM an idle stream + return nil + } + serr := http2streamError(cs.ID, f.ErrCode) + serr.Cause = http2errFromPeer + if f.ErrCode == http2ErrCodeProtocol { + rl.cc.SetDoNotReuse() + } + if fn := cs.cc.t.CountError; fn != nil { + fn("recv_rststream_" + f.ErrCode.stringToken()) + } + cs.abortStream(serr) + + cs.bufPipe.CloseWithError(serr) + return nil +} + +// Ping sends a PING frame to the server and waits for the ack. +func (cc *http2ClientConn) Ping(ctx context.Context) error { + c := make(chan struct{}) + // Generate a random payload + var p [8]byte + for { + if _, err := rand.Read(p[:]); err != nil { + return err + } + cc.mu.Lock() + // check for dup before insert + if _, found := cc.pings[p]; !found { + cc.pings[p] = c + cc.mu.Unlock() + break + } + cc.mu.Unlock() + } + errc := make(chan error, 1) + go func() { + cc.wmu.Lock() + defer cc.wmu.Unlock() + if err := cc.fr.WritePing(false, p); err != nil { + errc <- err + return + } + if err := cc.bw.Flush(); err != nil { + errc <- err + return + } + }() + select { + case <-c: + return nil + case err := <-errc: + return err + case <-ctx.Done(): + return ctx.Err() + case <-cc.readerDone: + // connection closed + return cc.readerErr + } +} + +func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { + if f.IsAck() { + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() + // If ack, notify listener if any + if c, ok := cc.pings[f.Data]; ok { + close(c) + delete(cc.pings, f.Data) + } + return nil + } + cc := rl.cc + cc.wmu.Lock() + defer cc.wmu.Unlock() + if err := cc.fr.WritePing(true, f.Data); err != nil { + return err + } + return cc.bw.Flush() +} + +func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) error { + // We told the peer we don't want them. + // Spec says: + // "PUSH_PROMISE MUST NOT be sent if the SETTINGS_ENABLE_PUSH + // setting of the peer endpoint is set to 0. An endpoint that + // has set this setting and has received acknowledgement MUST + // treat the receipt of a PUSH_PROMISE frame as a connection + // error (Section 5.4.1) of type PROTOCOL_ERROR." + return http2ConnectionError(http2ErrCodeProtocol) +} + +func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) { + // TODO: map err to more interesting error codes, once the + // HTTP community comes up with some. But currently for + // RST_STREAM there's no equivalent to GOAWAY frame's debug + // data, and the error codes are all pretty vague ("cancel"). + cc.wmu.Lock() + cc.fr.WriteRSTStream(streamID, code) + cc.bw.Flush() + cc.wmu.Unlock() +} + +var ( + http2errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") + http2errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit") +) + +func (cc *http2ClientConn) logf(format string, args ...interface{}) { + cc.t.logf(format, args...) +} + +func (cc *http2ClientConn) vlogf(format string, args ...interface{}) { + cc.t.vlogf(format, args...) +} + +func (t *http2Transport) vlogf(format string, args ...interface{}) { + if http2VerboseLogs { + t.logf(format, args...) + } +} + +func (t *http2Transport) logf(format string, args ...interface{}) { + log.Printf(format, args...) +} + +var http2noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) + +type http2missingBody struct{} + +func (http2missingBody) Close() error { return nil } + +func (http2missingBody) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF } + +func http2strSliceContains(ss []string, s string) bool { + for _, v := range ss { + if v == s { + return true + } + } + return false +} + +type http2erringRoundTripper struct{ err error } + +func (rt http2erringRoundTripper) RoundTripErr() error { return rt.err } + +func (rt http2erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return nil, rt.err +} + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type http2gzipReader struct { + _ http2incomparable + body io.ReadCloser // underlying Response.Body + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // sticky error +} + +func (gz *http2gzipReader) Read(p []byte) (n int, err error) { + if gz.zerr != nil { + return 0, gz.zerr + } + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + gz.zerr = err + return 0, err + } + } + return gz.zr.Read(p) +} + +func (gz *http2gzipReader) Close() error { + return gz.body.Close() +} + +type http2errorReader struct{ err error } + +func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } + +// isConnectionCloseRequest reports whether req should use its own +// connection for a single request and then close the connection. +func http2isConnectionCloseRequest(req *http.Request) bool { + return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close") +} + +// registerHTTPSProtocol calls Transport.RegisterProtocol but +// converting panics into errors. +func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("%v", e) + } + }() + t.RegisterProtocol("https", rt) + return nil +} + +// noDialH2RoundTripper is a RoundTripper which only tries to complete the request +// if there's already has a cached connection to the host. +// (The field is exported so it can be accessed via reflect from net/http; tested +// by TestNoDialH2RoundTripperType) +type http2noDialH2RoundTripper struct{ *http2Transport } + +func (rt http2noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + res, err := rt.http2Transport.RoundTrip(req) + if http2isNoCachedConnError(err) { + return nil, http.ErrSkipAltProtocol + } + return res, err +} + +func (t *http2Transport) idleConnTimeout() time.Duration { + if t.t1 != nil { + return t.t1.IdleConnTimeout + } + return 0 +} + +func http2traceGetConn(req *http.Request, hostPort string) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GetConn == nil { + return + } + trace.GetConn(hostPort) +} + +func http2traceGotConn(req *http.Request, cc *http2ClientConn, reused bool) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GotConn == nil { + return + } + ci := httptrace.GotConnInfo{Conn: cc.tconn} + ci.Reused = reused + cc.mu.Lock() + ci.WasIdle = len(cc.streams) == 0 && reused + if ci.WasIdle && !cc.lastActive.IsZero() { + ci.IdleTime = time.Now().Sub(cc.lastActive) + } + cc.mu.Unlock() + + trace.GotConn(ci) +} + +func http2traceWroteHeaders(trace *httptrace.ClientTrace) { + if trace != nil && trace.WroteHeaders != nil { + trace.WroteHeaders() + } +} + +func http2traceGot100Continue(trace *httptrace.ClientTrace) { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() + } +} + +func http2traceWait100Continue(trace *httptrace.ClientTrace) { + if trace != nil && trace.Wait100Continue != nil { + trace.Wait100Continue() + } +} + +func http2traceWroteRequest(trace *httptrace.ClientTrace, err error) { + if trace != nil && trace.WroteRequest != nil { + trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) + } +} + +func http2traceFirstResponseByte(trace *httptrace.ClientTrace) { + if trace != nil && trace.GotFirstResponseByte != nil { + trace.GotFirstResponseByte() + } +} + +// writeFramer is implemented by any type that is used to write frames. +type http2writeFramer interface { + writeFrame(http2writeContext) error + + // staysWithinBuffer reports whether this writer promises that + // it will only write less than or equal to size bytes, and it + // won't Flush the write context. + staysWithinBuffer(size int) bool +} + +// writeContext is the interface needed by the various frame writer +// types below. All the writeFrame methods below are scheduled via the +// frame writing scheduler (see writeScheduler in writesched.go). +// +// This interface is implemented by *serverConn. +// +// TODO: decide whether to a) use this in the client code (which didn't +// end up using this yet, because it has a simpler design, not +// currently implementing priorities), or b) delete this and +// make the server code a bit more concrete. +type http2writeContext interface { + Framer() *http2Framer + Flush() error + CloseConn() error + // HeaderEncoder returns an HPACK encoder that writes to the + // returned buffer. + HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) +} + +// writeEndsStream reports whether w writes a frame that will transition +// the stream to a half-closed local state. This returns false for RST_STREAM, +// which closes the entire stream (not just the local half). +func http2writeEndsStream(w http2writeFramer) bool { + switch v := w.(type) { + case *http2writeData: + return v.endStream + case *http2writeResHeaders: + return v.endStream + case nil: + // This can only happen if the caller reuses w after it's + // been intentionally nil'ed out to prevent use. Keep this + // here to catch future refactoring breaking it. + panic("writeEndsStream called on nil writeFramer") + } + return false +} + +type http2flushFrameWriter struct{} + +func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error { + return ctx.Flush() +} + +func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false } + +type http2writeSettings []http2Setting + +func (s http2writeSettings) staysWithinBuffer(max int) bool { + const settingSize = 6 // uint16 + uint32 + return http2frameHeaderLen+settingSize*len(s) <= max + +} + +func (s http2writeSettings) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteSettings([]http2Setting(s)...) +} + +type http2writeGoAway struct { + maxStreamID uint32 + code http2ErrCode +} + +func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { + err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil) + ctx.Flush() // ignore error: we're hanging up on them anyway + return err +} + +func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } // flushes + +type http2writeData struct { + streamID uint32 + p []byte + endStream bool +} + +func (w *http2writeData) String() string { + return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream) +} + +func (w *http2writeData) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) +} + +func (w *http2writeData) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.p) <= max +} + +// handlerPanicRST is the message sent from handler goroutines when +// the handler panics. +type http2handlerPanicRST struct { + StreamID uint32 +} + +func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal) +} + +func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + +func (se http2StreamError) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) +} + +func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + +type http2writePingAck struct{ pf *http2PingFrame } + +func (w http2writePingAck) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WritePing(true, w.pf.Data) +} + +func (w http2writePingAck) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.pf.Data) <= max +} + +type http2writeSettingsAck struct{} + +func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteSettingsAck() +} + +func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max } + +// splitHeaderBlock splits headerBlock into fragments so that each fragment fits +// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true +// for the first/last fragment, respectively. +func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error { + // For now we're lazy and just pick the minimum MAX_FRAME_SIZE + // that all peers must support (16KB). Later we could care + // more and send larger frames if the peer advertised it, but + // there's little point. Most headers are small anyway (so we + // generally won't have CONTINUATION frames), and extra frames + // only waste 9 bytes anyway. + const maxFrameSize = 16384 + + first := true + for len(headerBlock) > 0 { + frag := headerBlock + if len(frag) > maxFrameSize { + frag = frag[:maxFrameSize] + } + headerBlock = headerBlock[len(frag):] + if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil { + return err + } + first = false + } + return nil +} + +// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames +// for HTTP response headers or trailers from a server handler. +type http2writeResHeaders struct { + streamID uint32 + httpResCode int // 0 means no ":status" line + h http.Header // may be nil + trailers []string // if non-nil, which keys of h to write. nil means all. + endStream bool + + date string + contentType string + contentLength string +} + +func http2encKV(enc *hpack.Encoder, k, v string) { + if http2VerboseLogs { + log.Printf("http2: server encoding header %q = %q", k, v) + } + enc.WriteField(hpack.HeaderField{Name: k, Value: v}) +} + +func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { + // TODO: this is a common one. It'd be nice to return true + // here and get into the fast path if we could be clever and + // calculate the size fast enough, or at least a conservative + // upper bound that usually fires. (Maybe if w.h and + // w.trailers are nil, so we don't need to enumerate it.) + // Otherwise I'm afraid that just calculating the length to + // answer this question would be slower than the ~2µs benefit. + return false +} + +func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + if w.httpResCode != 0 { + http2encKV(enc, ":status", http2httpCodeString(w.httpResCode)) + } + + http2encodeHeaders(enc, w.h, w.trailers) + + if w.contentType != "" { + http2encKV(enc, "content-type", w.contentType) + } + if w.contentLength != "" { + http2encKV(enc, "content-length", w.contentLength) + } + if w.date != "" { + http2encKV(enc, "date", w.date) + } + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 && w.trailers == nil { + panic("unexpected empty hpack") + } + + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} + +func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: frag, + EndStream: w.endStream, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + } +} + +// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. +type http2writePushPromise struct { + streamID uint32 // pusher stream + method string // for :method + url *url.URL // for :scheme, :authority, :path + h http.Header + + // Creates an ID for a pushed stream. This runs on serveG just before + // the frame is written. The returned ID is copied to promisedID. + allocatePromisedID func() (uint32, error) + promisedID uint32 +} + +func (w *http2writePushPromise) staysWithinBuffer(max int) bool { + // TODO: see writeResHeaders.staysWithinBuffer + return false +} + +func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + http2encKV(enc, ":method", w.method) + http2encKV(enc, ":scheme", w.url.Scheme) + http2encKV(enc, ":authority", w.url.Host) + http2encKV(enc, ":path", w.url.RequestURI()) + http2encodeHeaders(enc, w.h, nil) + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 { + panic("unexpected empty hpack") + } + + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} + +func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WritePushPromise(http2PushPromiseParam{ + StreamID: w.streamID, + PromiseID: w.promisedID, + BlockFragment: frag, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + } +} + +type http2write100ContinueHeadersFrame struct { + streamID uint32 +} + +func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + http2encKV(enc, ":status", "100") + return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: buf.Bytes(), + EndStream: false, + EndHeaders: true, + }) +} + +func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { + // Sloppy but conservative: + return 9+2*(len(":status")+len("100")) <= max +} + +type http2writeWindowUpdate struct { + streamID uint32 // or 0 for conn-level + n uint32 +} + +func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + +func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) +} + +// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) +// is encoded only if k is in keys. +func http2encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { + if keys == nil { + sorter := http2sorterPool.Get().(*http2sorter) + // Using defer here, since the returned keys from the + // sorter.Keys method is only valid until the sorter + // is returned: + defer http2sorterPool.Put(sorter) + keys = sorter.Keys(h) + } + for _, k := range keys { + vv := h[k] + k, ascii := http2lowerHeader(k) + if !ascii { + // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header + // field names have to be ASCII characters (just as in HTTP/1.x). + continue + } + if !http2validWireHeaderFieldName(k) { + // Skip it as backup paranoia. Per + // golang.org/issue/14048, these should + // already be rejected at a higher level. + continue + } + isTE := k == "transfer-encoding" + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // TODO: return an error? golang.org/issue/14048 + // For now just omit it. + continue + } + // TODO: more of "8.1.2.2 Connection-Specific Header Fields" + if isTE && v != "trailers" { + continue + } + http2encKV(enc, k, v) + } + } +} + +// WriteScheduler is the interface implemented by HTTP/2 write schedulers. +// Methods are never called concurrently. +type http2WriteScheduler interface { + // OpenStream opens a new stream in the write scheduler. + // It is illegal to call this with streamID=0 or with a streamID that is + // already open -- the call may panic. + OpenStream(streamID uint32, options http2OpenStreamOptions) + + // CloseStream closes a stream in the write scheduler. Any frames queued on + // this stream should be discarded. It is illegal to call this on a stream + // that is not open -- the call may panic. + CloseStream(streamID uint32) + + // AdjustStream adjusts the priority of the given stream. This may be called + // on a stream that has not yet been opened or has been closed. Note that + // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See: + // https://tools.ietf.org/html/rfc7540#section-5.1 + AdjustStream(streamID uint32, priority http2PriorityParam) + + // Push queues a frame in the scheduler. In most cases, this will not be + // called with wr.StreamID()!=0 unless that stream is currently open. The one + // exception is RST_STREAM frames, which may be sent on idle or closed streams. + Push(wr http2FrameWriteRequest) + + // Pop dequeues the next frame to write. Returns false if no frames can + // be written. Frames with a given wr.StreamID() are Pop'd in the same + // order they are Push'd, except RST_STREAM frames. No frames should be + // discarded except by CloseStream. + Pop() (wr http2FrameWriteRequest, ok bool) +} + +// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream. +type http2OpenStreamOptions struct { + // PusherID is zero if the stream was initiated by the client. Otherwise, + // PusherID names the stream that pushed the newly opened stream. + PusherID uint32 +} + +// FrameWriteRequest is a request to write a frame. +type http2FrameWriteRequest struct { + // write is the interface value that does the writing, once the + // WriteScheduler has selected this frame to write. The write + // functions are all defined in write.go. + write http2writeFramer + + // stream is the stream on which this frame will be written. + // nil for non-stream frames like PING and SETTINGS. + // nil for RST_STREAM streams, which use the StreamError.StreamID field instead. + stream *http2stream + + // done, if non-nil, must be a buffered channel with space for + // 1 message and is sent the return value from write (or an + // earlier error) when the frame has been written. + done chan error +} + +// StreamID returns the id of the stream this frame will be written to. +// 0 is used for non-stream frames such as PING and SETTINGS. +func (wr http2FrameWriteRequest) StreamID() uint32 { + if wr.stream == nil { + if se, ok := wr.write.(http2StreamError); ok { + // (*serverConn).resetStream doesn't set + // stream because it doesn't necessarily have + // one. So special case this type of write + // message. + return se.StreamID + } + return 0 + } + return wr.stream.id +} + +// isControl reports whether wr is a control frame for MaxQueuedControlFrames +// purposes. That includes non-stream frames and RST_STREAM frames. +func (wr http2FrameWriteRequest) isControl() bool { + return wr.stream == nil +} + +// DataSize returns the number of flow control bytes that must be consumed +// to write this entire frame. This is 0 for non-DATA frames. +func (wr http2FrameWriteRequest) DataSize() int { + if wd, ok := wr.write.(*http2writeData); ok { + return len(wd.p) + } + return 0 +} + +// Consume consumes min(n, available) bytes from this frame, where available +// is the number of flow control bytes available on the stream. Consume returns +// 0, 1, or 2 frames, where the integer return value gives the number of frames +// returned. +// +// If flow control prevents consuming any bytes, this returns (_, _, 0). If +// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this +// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and +// 'rest' contains the remaining bytes. The consumed bytes are deducted from the +// underlying stream's flow control budget. +func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) { + var empty http2FrameWriteRequest + + // Non-DATA frames are always consumed whole. + wd, ok := wr.write.(*http2writeData) + if !ok || len(wd.p) == 0 { + return wr, empty, 1 + } + + // Might need to split after applying limits. + allowed := wr.stream.flow.available() + if n < allowed { + allowed = n + } + if wr.stream.sc.maxFrameSize < allowed { + allowed = wr.stream.sc.maxFrameSize + } + if allowed <= 0 { + return empty, empty, 0 + } + if len(wd.p) > int(allowed) { + wr.stream.flow.take(allowed) + consumed := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[:allowed], + // Even if the original had endStream set, there + // are bytes remaining because len(wd.p) > allowed, + // so we know endStream is false. + endStream: false, + }, + // Our caller is blocking on the final DATA frame, not + // this intermediate frame, so no need to wait. + done: nil, + } + rest := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[allowed:], + endStream: wd.endStream, + }, + done: wr.done, + } + return consumed, rest, 2 + } + + // The frame is consumed whole. + // NB: This cast cannot overflow because allowed is <= math.MaxInt32. + wr.stream.flow.take(int32(len(wd.p))) + return wr, empty, 1 +} + +// String is for debugging only. +func (wr http2FrameWriteRequest) String() string { + var des string + if s, ok := wr.write.(fmt.Stringer); ok { + des = s.String() + } else { + des = fmt.Sprintf("%T", wr.write) + } + return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des) +} + +// replyToWriter sends err to wr.done and panics if the send must block +// This does nothing if wr.done is nil. +func (wr *http2FrameWriteRequest) replyToWriter(err error) { + if wr.done == nil { + return + } + select { + case wr.done <- err: + default: + panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write)) + } + wr.write = nil // prevent use (assume it's tainted after wr.done send) +} + +// writeQueue is used by implementations of WriteScheduler. +type http2writeQueue struct { + s []http2FrameWriteRequest +} + +func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } + +func (q *http2writeQueue) push(wr http2FrameWriteRequest) { + q.s = append(q.s, wr) +} + +func (q *http2writeQueue) shift() http2FrameWriteRequest { + if len(q.s) == 0 { + panic("invalid use of queue") + } + wr := q.s[0] + // TODO: less copy-happy queue. + copy(q.s, q.s[1:]) + q.s[len(q.s)-1] = http2FrameWriteRequest{} + q.s = q.s[:len(q.s)-1] + return wr +} + +// consume consumes up to n bytes from q.s[0]. If the frame is +// entirely consumed, it is removed from the queue. If the frame +// is partially consumed, the frame is kept with the consumed +// bytes removed. Returns true iff any bytes were consumed. +func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) { + if len(q.s) == 0 { + return http2FrameWriteRequest{}, false + } + consumed, rest, numresult := q.s[0].Consume(n) + switch numresult { + case 0: + return http2FrameWriteRequest{}, false + case 1: + q.shift() + case 2: + q.s[0] = rest + } + return consumed, true +} + +type http2writeQueuePool []*http2writeQueue + +// put inserts an unused writeQueue into the pool. + +// put inserts an unused writeQueue into the pool. +func (p *http2writeQueuePool) put(q *http2writeQueue) { + for i := range q.s { + q.s[i] = http2FrameWriteRequest{} + } + q.s = q.s[:0] + *p = append(*p, q) +} + +// get returns an empty writeQueue. +func (p *http2writeQueuePool) get() *http2writeQueue { + ln := len(*p) + if ln == 0 { + return new(http2writeQueue) + } + x := ln - 1 + q := (*p)[x] + (*p)[x] = nil + *p = (*p)[:x] + return q +} + +// RFC 7540, Section 5.3.5: the default weight is 16. +const http2priorityDefaultWeight = 15 // 16 = 15 + 1 + +// PriorityWriteSchedulerConfig configures a priorityWriteScheduler. +type http2PriorityWriteSchedulerConfig struct { + // MaxClosedNodesInTree controls the maximum number of closed streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // "It is possible for a stream to become closed while prioritization + // information ... is in transit. ... This potentially creates suboptimal + // prioritization, since the stream could be given a priority that is + // different from what is intended. To avoid these problems, an endpoint + // SHOULD retain stream prioritization state for a period after streams + // become closed. The longer state is retained, the lower the chance that + // streams are assigned incorrect or default priority values." + MaxClosedNodesInTree int + + // MaxIdleNodesInTree controls the maximum number of idle streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // Similarly, streams that are in the "idle" state can be assigned + // priority or become a parent of other streams. This allows for the + // creation of a grouping node in the dependency tree, which enables + // more flexible expressions of priority. Idle streams begin with a + // default priority (Section 5.3.5). + MaxIdleNodesInTree int + + // ThrottleOutOfOrderWrites enables write throttling to help ensure that + // data is delivered in priority order. This works around a race where + // stream B depends on stream A and both streams are about to call Write + // to queue DATA frames. If B wins the race, a naive scheduler would eagerly + // write as much data from B as possible, but this is suboptimal because A + // is a higher-priority stream. With throttling enabled, we write a small + // amount of data from B to minimize the amount of bandwidth that B can + // steal from A. + ThrottleOutOfOrderWrites bool +} + +// NewPriorityWriteScheduler constructs a WriteScheduler that schedules +// frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3. +// If cfg is nil, default options are used. +func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler { + if cfg == nil { + // For justification of these defaults, see: + // https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY + cfg = &http2PriorityWriteSchedulerConfig{ + MaxClosedNodesInTree: 10, + MaxIdleNodesInTree: 10, + ThrottleOutOfOrderWrites: false, + } + } + + ws := &http2priorityWriteScheduler{ + nodes: make(map[uint32]*http2priorityNode), + maxClosedNodesInTree: cfg.MaxClosedNodesInTree, + maxIdleNodesInTree: cfg.MaxIdleNodesInTree, + enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, + } + ws.nodes[0] = &ws.root + if cfg.ThrottleOutOfOrderWrites { + ws.writeThrottleLimit = 1024 + } else { + ws.writeThrottleLimit = math.MaxInt32 + } + return ws +} + +type http2priorityNodeState int + +const ( + http2priorityNodeOpen http2priorityNodeState = iota + http2priorityNodeClosed + http2priorityNodeIdle +) + +// priorityNode is a node in an HTTP/2 priority tree. +// Each node is associated with a single stream ID. +// See RFC 7540, Section 5.3. +type http2priorityNode struct { + q http2writeQueue // queue of pending frames to write + id uint32 // id of the stream, or 0 for the root of the tree + weight uint8 // the actual weight is weight+1, so the value is in [1,256] + state http2priorityNodeState // open | closed | idle + bytes int64 // number of bytes written by this node, or 0 if closed + subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree + + // These links form the priority tree. + parent *http2priorityNode + kids *http2priorityNode // start of the kids list + prev, next *http2priorityNode // doubly-linked list of siblings +} + +func (n *http2priorityNode) setParent(parent *http2priorityNode) { + if n == parent { + panic("setParent to self") + } + if n.parent == parent { + return + } + // Unlink from current parent. + if parent := n.parent; parent != nil { + if n.prev == nil { + parent.kids = n.next + } else { + n.prev.next = n.next + } + if n.next != nil { + n.next.prev = n.prev + } + } + // Link to new parent. + // If parent=nil, remove n from the tree. + // Always insert at the head of parent.kids (this is assumed by walkReadyInOrder). + n.parent = parent + if parent == nil { + n.next = nil + n.prev = nil + } else { + n.next = parent.kids + n.prev = nil + if n.next != nil { + n.next.prev = n + } + parent.kids = n + } +} + +func (n *http2priorityNode) addBytes(b int64) { + n.bytes += b + for ; n != nil; n = n.parent { + n.subtreeBytes += b + } +} + +// walkReadyInOrder iterates over the tree in priority order, calling f for each node +// with a non-empty write queue. When f returns true, this function returns true and the +// walk halts. tmp is used as scratch space for sorting. +// +// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true +// if any ancestor p of n is still open (ignoring the root node). +func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool { + if !n.q.empty() && f(n, openParent) { + return true + } + if n.kids == nil { + return false + } + + // Don't consider the root "open" when updating openParent since + // we can't send data frames on the root stream (only control frames). + if n.id != 0 { + openParent = openParent || (n.state == http2priorityNodeOpen) + } + + // Common case: only one kid or all kids have the same weight. + // Some clients don't use weights; other clients (like web browsers) + // use mostly-linear priority trees. + w := n.kids.weight + needSort := false + for k := n.kids.next; k != nil; k = k.next { + if k.weight != w { + needSort = true + break + } + } + if !needSort { + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } + } + return false + } + + // Uncommon case: sort the child nodes. We remove the kids from the parent, + // then re-insert after sorting so we can reuse tmp for future sort calls. + *tmp = (*tmp)[:0] + for n.kids != nil { + *tmp = append(*tmp, n.kids) + n.kids.setParent(nil) + } + sort.Sort(http2sortPriorityNodeSiblings(*tmp)) + for i := len(*tmp) - 1; i >= 0; i-- { + (*tmp)[i].setParent(n) // setParent inserts at the head of n.kids + } + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } + } + return false +} + +type http2sortPriorityNodeSiblings []*http2priorityNode + +func (z http2sortPriorityNodeSiblings) Len() int { return len(z) } + +func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } + +func (z http2sortPriorityNodeSiblings) Less(i, k int) bool { + // Prefer the subtree that has sent fewer bytes relative to its weight. + // See sections 5.3.2 and 5.3.4. + wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) + wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) + if bi == 0 && bk == 0 { + return wi >= wk + } + if bk == 0 { + return false + } + return bi/bk <= wi/wk +} + +type http2priorityWriteScheduler struct { + // root is the root of the priority tree, where root.id = 0. + // The root queues control frames that are not associated with any stream. + root http2priorityNode + + // nodes maps stream ids to priority tree nodes. + nodes map[uint32]*http2priorityNode + + // maxID is the maximum stream id in nodes. + maxID uint32 + + // lists of nodes that have been closed or are idle, but are kept in + // the tree for improved prioritization. When the lengths exceed either + // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. + closedNodes, idleNodes []*http2priorityNode + + // From the config. + maxClosedNodesInTree int + maxIdleNodesInTree int + writeThrottleLimit int32 + enableWriteThrottle bool + + // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. + tmp []*http2priorityNode + + // pool of empty queues for reuse. + queuePool http2writeQueuePool +} + +func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + // The stream may be currently idle but cannot be opened or closed. + if curr := ws.nodes[streamID]; curr != nil { + if curr.state != http2priorityNodeIdle { + panic(fmt.Sprintf("stream %d already opened", streamID)) + } + curr.state = http2priorityNodeOpen + return + } + + // RFC 7540, Section 5.3.5: + // "All streams are initially assigned a non-exclusive dependency on stream 0x0. + // Pushed streams initially depend on their associated stream. In both cases, + // streams are assigned a default weight of 16." + parent := ws.nodes[options.PusherID] + if parent == nil { + parent = &ws.root + } + n := &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeOpen, + } + n.setParent(parent) + ws.nodes[streamID] = n + if streamID > ws.maxID { + ws.maxID = streamID + } +} + +func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) { + if streamID == 0 { + panic("violation of WriteScheduler interface: cannot close stream 0") + } + if ws.nodes[streamID] == nil { + panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) + } + if ws.nodes[streamID].state != http2priorityNodeOpen { + panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) + } + + n := ws.nodes[streamID] + n.state = http2priorityNodeClosed + n.addBytes(-n.bytes) + + q := n.q + ws.queuePool.put(&q) + n.q.s = nil + if ws.maxClosedNodesInTree > 0 { + ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) + } else { + ws.removeNode(n) + } +} + +func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { + if streamID == 0 { + panic("adjustPriority on root") + } + + // If streamID does not exist, there are two cases: + // - A closed stream that has been removed (this will have ID <= maxID) + // - An idle stream that is being used for "grouping" (this will have ID > maxID) + n := ws.nodes[streamID] + if n == nil { + if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 { + return + } + ws.maxID = streamID + n = &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeIdle, + } + n.setParent(&ws.root) + ws.nodes[streamID] = n + ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n) + } + + // Section 5.3.1: A dependency on a stream that is not currently in the tree + // results in that stream being given a default priority (Section 5.3.5). + parent := ws.nodes[priority.StreamDep] + if parent == nil { + n.setParent(&ws.root) + n.weight = http2priorityDefaultWeight + return + } + + // Ignore if the client tries to make a node its own parent. + if n == parent { + return + } + + // Section 5.3.3: + // "If a stream is made dependent on one of its own dependencies, the + // formerly dependent stream is first moved to be dependent on the + // reprioritized stream's previous parent. The moved dependency retains + // its weight." + // + // That is: if parent depends on n, move parent to depend on n.parent. + for x := parent.parent; x != nil; x = x.parent { + if x == n { + parent.setParent(n.parent) + break + } + } + + // Section 5.3.3: The exclusive flag causes the stream to become the sole + // dependency of its parent stream, causing other dependencies to become + // dependent on the exclusive stream. + if priority.Exclusive { + k := parent.kids + for k != nil { + next := k.next + if k != n { + k.setParent(n) + } + k = next + } + } + + n.setParent(parent) + n.weight = priority.Weight +} + +func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { + var n *http2priorityNode + if id := wr.StreamID(); id == 0 { + n = &ws.root + } else { + n = ws.nodes[id] + if n == nil { + // id is an idle or closed stream. wr should not be a HEADERS or + // DATA frame. However, wr can be a RST_STREAM. In this case, we + // push wr onto the root, rather than creating a new priorityNode, + // since RST_STREAM is tiny and the stream's priority is unknown + // anyway. See issue #17919. + if wr.DataSize() > 0 { + panic("add DATA on non-open stream") + } + n = &ws.root + } + } + n.q.push(wr) +} + +func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) { + ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool { + limit := int32(math.MaxInt32) + if openParent { + limit = ws.writeThrottleLimit + } + wr, ok = n.q.consume(limit) + if !ok { + return false + } + n.addBytes(int64(wr.DataSize())) + // If B depends on A and B continuously has data available but A + // does not, gradually increase the throttling limit to allow B to + // steal more and more bandwidth from A. + if openParent { + ws.writeThrottleLimit += 1024 + if ws.writeThrottleLimit < 0 { + ws.writeThrottleLimit = math.MaxInt32 + } + } else if ws.enableWriteThrottle { + ws.writeThrottleLimit = 1024 + } + return true + }) + return wr, ok +} + +func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorityNode, maxSize int, n *http2priorityNode) { + if maxSize == 0 { + return + } + if len(*list) == maxSize { + // Remove the oldest node, then shift left. + ws.removeNode((*list)[0]) + x := (*list)[1:] + copy(*list, x) + *list = (*list)[:len(x)] + } + *list = append(*list, n) +} + +func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) { + for k := n.kids; k != nil; k = k.next { + k.setParent(n.parent) + } + n.setParent(nil) + delete(ws.nodes, n.id) +} + +// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2 +// priorities. Control frames like SETTINGS and PING are written before DATA +// frames, but if no control frames are queued and multiple streams have queued +// HEADERS or DATA frames, Pop selects a ready stream arbitrarily. +func http2NewRandomWriteScheduler() http2WriteScheduler { + return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)} +} + +type http2randomWriteScheduler struct { + // zero are frames not associated with a specific stream. + zero http2writeQueue + + // sq contains the stream-specific queues, keyed by stream ID. + // When a stream is idle, closed, or emptied, it's deleted + // from the map. + sq map[uint32]*http2writeQueue + + // pool of empty queues for reuse. + queuePool http2writeQueuePool +} + +func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + // no-op: idle streams are not tracked +} + +func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { + q, ok := ws.sq[streamID] + if !ok { + return + } + delete(ws.sq, streamID) + ws.queuePool.put(q) +} + +func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { + // no-op: priorities are ignored +} + +func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { + if wr.isControl() { + ws.zero.push(wr) + return + } + id := wr.StreamID() + q, ok := ws.sq[id] + if !ok { + q = ws.queuePool.get() + ws.sq[id] = q + } + q.push(wr) +} + +func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { + // Control and RST_STREAM frames first. + if !ws.zero.empty() { + return ws.zero.shift(), true + } + // Iterate over all non-idle streams until finding one that can be consumed. + for streamID, q := range ws.sq { + if wr, ok := q.consume(math.MaxInt32); ok { + if q.empty() { + delete(ws.sq, streamID) + ws.queuePool.put(q) + } + return wr, true + } + } + return http2FrameWriteRequest{}, false +} diff --git a/header.go b/header.go index dbc59925..28e2a9ae 100644 --- a/header.go +++ b/header.go @@ -1,42 +1,118 @@ -/* - GoLang code created by Jirawat Harnsiriwatanakit https://github.com/kazekim -*/ - package req -import "encoding/json" +import ( + "golang.org/x/net/http/httpguts" + "io" + "net/http" + "net/http/httptrace" + "net/textproto" + "sort" + "strings" + "sync" +) -// Header represents http request header -type Header map[string]string +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") -func (h Header) Clone() Header { - if h == nil { - return nil - } - hh := Header{} - for k, v := range h { - hh[k] = v - } - return hh +// stringWriter implements WriteString on a Writer. +type stringWriter struct { + w io.Writer +} + +func (w stringWriter) WriteString(s string) (n int, err error) { + return w.w.Write([]byte(s)) } -// ParseStruct parse struct into header -func ParseStruct(h Header, v interface{}) Header { - data, err := json.Marshal(v) - if err != nil { - return h +type keyValues struct { + key string + values []string +} + +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +var headerSorterPool = sync.Pool{ + New: func() any { return new(headerSorter) }, +} + +// get is like Get, but key must already be in CanonicalHeaderKey form. +func headerGet(h http.Header, key string) string { + if v := h[key]; len(v) > 0 { + return v[0] } + return "" +} - err = json.Unmarshal(data, &h) - return h +// has reports whether h has the provided key defined, even if it's +// set to 0-length slice. +func headerHas(h http.Header, key string) bool { + _, ok := h[key] + return ok } -// HeaderFromStruct init header from struct -func HeaderFromStruct(v interface{}) Header { +// sortedKeyValues returns h's keys sorted in the returned kvs +// slice. The headerSorter used to sort is also returned, for possible +// return to headerSorterCache. +func headerSortedKeyValues(h http.Header, exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { + hs = headerSorterPool.Get().(*headerSorter) + if cap(hs.kvs) < len(h) { + hs.kvs = make([]keyValues, 0, len(h)) + } + kvs = hs.kvs[:0] + for k, vv := range h { + if !exclude[k] { + kvs = append(kvs, keyValues{k, vv}) + } + } + hs.kvs = kvs + sort.Sort(hs) + return kvs, hs +} - var header Header - header = ParseStruct(header, v) - return header +func headerWrite(h http.Header, w io.Writer, trace *httptrace.ClientTrace) error { + return headerWriteSubset(h, w, nil, trace) } -type ReservedHeader map[string]string +func headerWriteSubset(h http.Header, w io.Writer, exclude map[string]bool, trace *httptrace.ClientTrace) error { + ws, ok := w.(io.StringWriter) + if !ok { + ws = stringWriter{w} + } + kvs, sorter := headerSortedKeyValues(h, exclude) + var formattedVals []string + for _, kv := range kvs { + if !httpguts.ValidHeaderFieldName(kv.key) { + // This could be an error. In the common case of + // writing response headers, however, we have no good + // way to provide the error back to the server + // handler, so just drop invalid headers instead. + continue + } + for _, v := range kv.values { + v = headerNewlineToSpace.Replace(v) + v = textproto.TrimString(v) + for _, s := range []string{kv.key, ": ", v, "\r\n"} { + if _, err := ws.WriteString(s); err != nil { + headerSorterPool.Put(sorter) + return err + } + } + if trace != nil && trace.WroteHeaderField != nil { + formattedVals = append(formattedVals, v) + } + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField(kv.key, formattedVals) + formattedVals = nil + } + } + headerSorterPool.Put(sorter) + return nil +} diff --git a/header_test.go b/header_test.go deleted file mode 100644 index 5e56fac4..00000000 --- a/header_test.go +++ /dev/null @@ -1,47 +0,0 @@ -/* - GoLang code created by Jirawat Harnsiriwatanakit https://github.com/kazekim -*/ - -package req - -import "testing" - -func TestParseStruct(t *testing.T) { - - type HeaderStruct struct { - UserAgent string `json:"User-Agent"` - Authorization string `json:"Authorization"` - } - - h := HeaderStruct{ - "V1.0.0", - "roc", - } - - var header Header - header = ParseStruct(header, h) - - if header["User-Agent"] != h.UserAgent && header["Authorization"] != h.Authorization { - t.Fatal("struct parser for header is not working") - } - -} - -func TestHeaderFromStruct(t *testing.T) { - - type HeaderStruct struct { - UserAgent string `json:"User-Agent"` - Authorization string `json:"Authorization"` - } - - h := HeaderStruct{ - "V1.0.0", - "roc", - } - - header := HeaderFromStruct(h) - - if header["User-Agent"] != h.UserAgent && header["Authorization"] != h.Authorization { - t.Fatal("struct parser for header is not working") - } -} diff --git a/http.go b/http.go new file mode 100644 index 00000000..55f0a2be --- /dev/null +++ b/http.go @@ -0,0 +1,242 @@ +package req + +import ( + "encoding/base64" + "fmt" + "github.com/imroc/req/internal/ascii" + "golang.org/x/net/http/httpguts" + "golang.org/x/net/idna" + "io" + "net" + "net/http" + "net/textproto" + "strings" +) + +// maxInt64 is the effective "infinite" value for the Server and +// Transport's byte-limiting readers. +const maxInt64 = 1<<63 - 1 + +// incomparable is a zero-width, non-comparable type. Adding it to a struct +// makes that struct also non-comparable, and generally doesn't add +// any size (as long as it's first). +type incomparable [0]func() + +// bodyIsWritable reports whether the Body supports writing. The +// Transport returns Writable bodies for 101 Switching Protocols +// responses. +// The Transport uses this method to determine whether a persistent +// connection is done being managed from its perspective. Once we +// return a writable response body to a user, the net/http package is +// done managing that connection. +func bodyIsWritable(r *http.Response) bool { + _, ok := r.Body.(io.Writer) + return ok +} + +// isProtocolSwitch reports whether the response code and header +// indicate a successful protocol upgrade response. +func isProtocolSwitch(r *http.Response) bool { + return isProtocolSwitchResponse(r.StatusCode, r.Header) +} + +// isProtocolSwitchResponse reports whether the response code and +// response header indicate a successful protocol upgrade response. +func isProtocolSwitchResponse(code int, h http.Header) bool { + return code == http.StatusSwitchingProtocols && isProtocolSwitchHeader(h) +} + +// isProtocolSwitchHeader reports whether the request or response header +// is for a protocol switch. +func isProtocolSwitchHeader(h http.Header) bool { + return h.Get("Upgrade") != "" && + httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") +} + +// NoBody is an io.ReadCloser with no bytes. Read always returns EOF +// and Close always returns nil. It can be used in an outgoing client +// request to explicitly signal that a request has zero bytes. +// An alternative, however, is to simply set Request.Body to nil. +var NoBody = noBody{} + +type noBody struct{} + +func (noBody) Read([]byte) (int, error) { return 0, io.EOF } +func (noBody) Close() error { return nil } +func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } + +var ( + // verify that an io.Copy from NoBody won't require a buffer: + _ io.WriterTo = NoBody + _ io.ReadCloser = NoBody +) + +type readResult struct { + _ incomparable + n int + err error + b byte // byte read, if n == 1 +} + +// hasToken reports whether token appears with v, ASCII +// case-insensitive, with space or comma boundaries. +// token must be all lowercase. +// v may contain mixed cased. +func hasToken(v, token string) bool { + if len(token) > len(v) || token == "" { + return false + } + if v == token { + return true + } + for sp := 0; sp <= len(v)-len(token); sp++ { + // Check that first character is good. + // The token is ASCII, so checking only a single byte + // is sufficient. We skip this potential starting + // position if both the first byte and its potential + // ASCII uppercase equivalent (b|0x20) don't match. + // False positives ('^' => '~') are caught by EqualFold. + if b := v[sp]; b != token[0] && b|0x20 != token[0] { + continue + } + // Check that start pos is on a valid token boundary. + if sp > 0 && !isTokenBoundary(v[sp-1]) { + continue + } + // Check that end pos is on a valid token boundary. + if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) { + continue + } + if ascii.EqualFold(v[sp:sp+len(token)], token) { + return true + } + } + return false +} + +func isTokenBoundary(b byte) bool { + return b == ' ' || b == ',' || b == '\t' +} + +func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } + +// foreachHeaderElement splits v according to the "#rule" construction +// in RFC 7230 section 7 and calls fn for each non-empty element. +func foreachHeaderElement(v string, fn func(string)) { + v = textproto.TrimString(v) + if v == "" { + return + } + if !strings.Contains(v, ",") { + fn(v) + return + } + for _, f := range strings.Split(v, ",") { + if f = textproto.TrimString(f); f != "" { + fn(f) + } + } +} + +// maxPostHandlerReadBytes is the max number of Request.Body bytes not +// consumed by a handler that the server will read from the client +// in order to keep a connection alive. If there are more bytes than +// this then the server to be paranoid instead sends a "Connection: +// close" response. +// +// This number is approximately what a typical machine's TCP buffer +// size is anyway. (if we have the bytes on the machine, we might as +// well read them) +const maxPostHandlerReadBytes = 256 << 10 + +// NOTE: This is not intended to reflect the actual Go version being used. +// It was changed at the time of Go 1.1 release because the former User-Agent +// had ended up blocked by some intrusion detection systems. +// See https://codereview.appspot.com/7532043. +const defaultUserAgent = "Go-http-client/1.1" + +func idnaASCII(v string) (string, error) { + // TODO: Consider removing this check after verifying performance is okay. + // Right now punycode verification, length checks, context checks, and the + // permissible character tests are all omitted. It also prevents the ToASCII + // call from salvaging an invalid IDN, when possible. As a result it may be + // possible to have two IDNs that appear identical to the user where the + // ASCII-only version causes an error downstream whereas the non-ASCII + // version does not. + // Note that for correct ASCII IDNs ToASCII will only do considerably more + // work, but it will not cause an allocation. + if ascii.Is(v) { + return v, nil + } + return idna.Lookup.ToASCII(v) +} + +// cleanHost cleans up the host sent in request's Host header. +// +// It both strips anything after '/' or ' ', and puts the value +// into Punycode form, if necessary. +// +// Ideally we'd clean the Host header according to the spec: +// https://tools.ietf.org/html/rfc7230#section-5.4 (Host = uri-host [ ":" port ]") +// https://tools.ietf.org/html/rfc7230#section-2.7 (uri-host -> rfc3986's host) +// https://tools.ietf.org/html/rfc3986#section-3.2.2 (definition of host) +// But practically, what we are trying to avoid is the situation in +// issue 11206, where a malformed Host header used in the proxy context +// would create a bad request. So it is enough to just truncate at the +// first offending character. +func cleanHost(in string) string { + if i := strings.IndexAny(in, " /"); i != -1 { + in = in[:i] + } + host, port, err := net.SplitHostPort(in) + if err != nil { // input was just a host + a, err := idnaASCII(in) + if err != nil { + return in // garbage in, garbage out + } + return a + } + a, err := idnaASCII(host) + if err != nil { + return in // garbage in, garbage out + } + return net.JoinHostPort(a, port) +} + +// removeZone removes IPv6 zone identifier from host. +// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" +func removeZone(host string) string { + if !strings.HasPrefix(host, "[") { + return host + } + i := strings.LastIndex(host, "]") + if i < 0 { + return host + } + j := strings.LastIndex(host[:i], "%") + if j < 0 { + return host + } + return host[:j] + host[i:] +} + +// stringContainsCTLByte reports whether s contains any ASCII control character. +func stringContainsCTLByte(s string) bool { + for i := 0; i < len(s); i++ { + b := s[i] + if b < ' ' || b == 0x7f { + return true + } + } + return false +} + +// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt +// "To receive authorization, the client sends the userid and password, +// separated by a single colon (":") character, within a base64 +// encoded string in the credentials." +// It is not meant to be urlencoded. +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} diff --git a/http_request.go b/http_request.go new file mode 100644 index 00000000..49439527 --- /dev/null +++ b/http_request.go @@ -0,0 +1,327 @@ +package req + +import ( + "bufio" + "errors" + "fmt" + "github.com/imroc/req/internal/ascii" + "golang.org/x/net/http/httpguts" + "io" + "net/http" + "net/http/httptrace" + "strings" +) + +// Given a string of the form "host", "host:port", or "[ipv6::address]:port", +// return true if the string includes a port. +func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } + +// removeEmptyPort strips the empty port in ":port" to "" +// as mandated by RFC 3986 Section 6.2.3. +func removeEmptyPort(host string) string { + if hasPort(host) { + return strings.TrimSuffix(host, ":") + } + return host +} + +func isNotToken(r rune) bool { + return !httpguts.IsTokenRune(r) +} + +func validMethod(method string) bool { + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1* + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + +func closeBody(r *http.Request) error { + if r.Body == nil { + return nil + } + return r.Body.Close() +} + +// requestBodyReadError wraps an error from (*Request).write to indicate +// that the error came from a Read call on the Request.Body. +// This error type should not escape the net/http package to users. +type requestBodyReadError struct{ error } + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} + +// outgoingLength reports the Content-Length of this outgoing (Client) request. +// It maps 0 into -1 (unknown) when the Body is non-nil. +func outgoingLength(r *http.Request) int64 { + if r.Body == nil || r.Body == NoBody { + return 0 + } + if r.ContentLength != 0 { + return r.ContentLength + } + return -1 +} + +// errMissingHost is returned by Write when there is no Host or URL present in +// the Request. +var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") + +// extraHeaders may be nil +// waitForContinue may be nil +// always closes body +func requestWrite(r *http.Request, w io.Writer, usingProxy bool, extraHeaders http.Header, waitForContinue func() bool, dump *dumper) (err error) { + trace := httptrace.ContextClientTrace(r.Context()) + if trace != nil && trace.WroteRequest != nil { + defer func() { + trace.WroteRequest(httptrace.WroteRequestInfo{ + Err: err, + }) + }() + } + closed := false + defer func() { + if closed { + return + } + if closeErr := closeRequestBody(r); closeErr != nil && err == nil { + err = closeErr + } + }() + + // Find the target host. Prefer the Host: header, but if that + // is not given, use the host from the request URL. + // + // Clean the host, in case it arrives with unexpected stuff in it. + host := cleanHost(r.Host) + if host == "" { + if r.URL == nil { + return errMissingHost + } + host = cleanHost(r.URL.Host) + } + + // According to RFC 6874, an HTTP client, proxy, or other + // intermediary must remove any IPv6 zone identifier attached + // to an outgoing URI. + host = removeZone(host) + + ruri := r.URL.RequestURI() + if usingProxy && r.URL.Scheme != "" && r.URL.Opaque == "" { + ruri = r.URL.Scheme + "://" + host + ruri + } else if r.Method == "CONNECT" && r.URL.Path == "" { + // CONNECT requests normally give just the host and port, not a full URL. + ruri = host + if r.URL.Opaque != "" { + ruri = r.URL.Opaque + } + } + if stringContainsCTLByte(ruri) { + return errors.New("net/http: can't write control character in Request.URL") + } + // TODO: validate r.Method too? At least it's less likely to + // come from an attacker (more likely to be a constant in + // code). + + // Wrap the writer in a bufio Writer if it's not already buffered. + // Don't always call NewWriter, as that forces a bytes.Buffer + // and other small bufio Writers to have a minimum 4k buffer + // size. + var bw *bufio.Writer + if _, ok := w.(io.ByteWriter); !ok { + bw = bufio.NewWriter(w) + w = bw + } + + rw := w // raw writer + if dump != nil && dump.RequestHead { + w = dump.WrapWriter(w) + } + + _, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(r.Method, "GET"), ruri) + if err != nil { + return err + } + + // Header lines + _, err = fmt.Fprintf(w, "Host: %s\r\n", host) + if err != nil { + return err + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("Host", []string{host}) + } + + // Use the defaultUserAgent unless the Header contains one, which + // may be blank to not send the header. + userAgent := defaultUserAgent + if headerHas(r.Header, "User-Agent") { + userAgent = r.Header.Get("User-Agent") + } + if userAgent != "" { + _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) + if err != nil { + return err + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("User-Agent", []string{userAgent}) + } + } + + // Process Body,ContentLength,Close,Trailer + tw, err := newTransferWriter(r) + if err != nil { + return err + } + err = tw.writeHeader(w, trace) + if err != nil { + return err + } + + err = headerWriteSubset(r.Header, w, reqWriteExcludeHeader, trace) + if err != nil { + return err + } + + if extraHeaders != nil { + err = headerWrite(extraHeaders, w, trace) + if err != nil { + return err + } + } + + _, err = io.WriteString(w, "\r\n") + if err != nil { + return err + } + + if trace != nil && trace.WroteHeaders != nil { + trace.WroteHeaders() + } + + // Flush and wait for 100-continue if expected. + if waitForContinue != nil { + if bw, ok := w.(*bufio.Writer); ok { + err = bw.Flush() + if err != nil { + return err + } + } + if trace != nil && trace.Wait100Continue != nil { + trace.Wait100Continue() + } + if !waitForContinue() { + closed = true + closeRequestBody(r) + return nil + } + } + + if bw, ok := w.(*bufio.Writer); ok && tw.FlushHeaders { + if err := bw.Flush(); err != nil { + return err + } + } + + // Write body and trailer + closed = true + err = tw.writeBody(rw, dump) + if err != nil { + if tw.bodyReadError == err { + err = requestBodyReadError{err} + } + return err + } + + if bw != nil { + return bw.Flush() + } + return nil +} + +func closeRequestBody(r *http.Request) error { + if r.Body == nil { + return nil + } + return r.Body.Close() +} + +// Headers that Request.Write handles itself and should be skipped. +var reqWriteExcludeHeader = map[string]bool{ + "Host": true, // not in Header map anyway + "User-Agent": true, + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// requestMethodUsuallyLacksBody reports whether the given request +// method is one that typically does not involve a request body. +// This is used by the Transport (via +// transferWriter.shouldSendChunkedRequestBody) to determine whether +// we try to test-read a byte from a non-nil Request.Body when +// Request.outgoingLength() returns -1. See the comments in +// shouldSendChunkedRequestBody. +func requestMethodUsuallyLacksBody(method string) bool { + switch method { + case "GET", "HEAD", "DELETE", "OPTIONS", "PROPFIND", "SEARCH": + return true + } + return false +} + +// requiresHTTP1 reports whether this request requires being sent on +// an HTTP/1 connection. +func requestRequiresHTTP1(r *http.Request) bool { + return hasToken(r.Header.Get("Connection"), "upgrade") && + ascii.EqualFold(r.Header.Get("Upgrade"), "websocket") +} + +func isReplayable(r *http.Request) bool { + if r.Body == nil || r.Body == NoBody || r.GetBody != nil { + switch valueOrDefault(r.Method, "GET") { + case "GET", "HEAD", "OPTIONS", "TRACE": + return true + } + // The Idempotency-Key, while non-standard, is widely used to + // mean a POST or other request is idempotent. See + // https://golang.org/issue/19943#issuecomment-421092421 + if headerHas(r.Header, "Idempotency-Key") || headerHas(r.Header, "X-Idempotency-Key") { + return true + } + } + return false +} + +func reqExpectsContinue(r *http.Request) bool { + return hasToken(headerGet(r.Header, "Expect"), "100-continue") +} + +func reqWantsHttp10KeepAlive(r *http.Request) bool { + if r.ProtoMajor != 1 || r.ProtoMinor != 0 { + return false + } + return hasToken(headerGet(r.Header, "Connection"), "keep-alive") +} + +func reqWantsClose(r *http.Request) bool { + if r.Close { + return true + } + return hasToken(headerGet(r.Header, "Connection"), "close") +} diff --git a/http_response.go b/http_response.go new file mode 100644 index 00000000..db636701 --- /dev/null +++ b/http_response.go @@ -0,0 +1,92 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP Response reading and parsing. + +package req + +import ( + "io" + "net/http" + "strconv" + "strings" +) + +var respExcludeHeader = map[string]bool{ + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// ReadResponse reads and returns an HTTP response from r. +// The req parameter optionally specifies the Request that corresponds +// to this Response. If nil, a GET request is assumed. +// Clients must call resp.Body.Close when finished reading resp.Body. +// After that call, clients can inspect resp.Trailer to find key/value +// pairs included in the response trailer. +func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) { + //var tp headReader + tp := newTextprotoReader(pc.br, pc.t.dump) + resp := &http.Response{ + Request: req, + } + + // Parse the first line of the response. + line, err := tp.ReadLine() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + proto, status, ok := strings.Cut(line, " ") + if !ok { + return nil, badStringError("malformed HTTP response", line) + } + resp.Proto = proto + resp.Status = strings.TrimLeft(status, " ") + + statusCode, _, _ := strings.Cut(resp.Status, " ") + if len(statusCode) != 3 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + resp.StatusCode, err = strconv.Atoi(statusCode) + if err != nil || resp.StatusCode < 0 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + if resp.ProtoMajor, resp.ProtoMinor, ok = http.ParseHTTPVersion(resp.Proto); !ok { + return nil, badStringError("malformed HTTP version", resp.Proto) + } + + // Parse the response headers. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + resp.Header = http.Header(mimeHeader) + + fixPragmaCacheControl(resp.Header) + + err = readTransfer(resp, pc.br) + if err != nil { + return nil, err + } + + return resp, nil +} + +// RFC 7234, section 5.4: Should treat +// Pragma: no-cache +// like +// Cache-Control: no-cache +func fixPragmaCacheControl(header http.Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { + if _, presentcc := header["Cache-Control"]; !presentcc { + header["Cache-Control"] = []string{"no-cache"} + } + } +} diff --git a/internal/ascii/print.go b/internal/ascii/print.go new file mode 100644 index 00000000..585e5bab --- /dev/null +++ b/internal/ascii/print.go @@ -0,0 +1,61 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ascii + +import ( + "strings" + "unicode" +) + +// EqualFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func EqualFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if lower(s[i]) != lower(t[i]) { + return false + } + } + return true +} + +// lower returns the ASCII lowercase version of b. +func lower(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// IsPrint returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func IsPrint(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} + +// Is returns whether s is ASCII. +func Is(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] > unicode.MaxASCII { + return false + } + } + return true +} + +// ToLower returns the lowercase version of s if s is ASCII and printable. +func ToLower(s string) (lower string, ok bool) { + if !IsPrint(s) { + return "", false + } + return strings.ToLower(s), true +} diff --git a/internal/ascii/print_test.go b/internal/ascii/print_test.go new file mode 100644 index 00000000..0b7767ca --- /dev/null +++ b/internal/ascii/print_test.go @@ -0,0 +1,95 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ascii + +import "testing" + +func TestEqualFold(t *testing.T) { + var tests = []struct { + name string + a, b string + want bool + }{ + { + name: "empty", + want: true, + }, + { + name: "simple match", + a: "CHUNKED", + b: "chunked", + want: true, + }, + { + name: "same string", + a: "chunked", + b: "chunked", + want: true, + }, + { + name: "Unicode Kelvin symbol", + a: "chunKed", // This "K" is 'KELVIN SIGN' (\u212A) + b: "chunked", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := EqualFold(tt.a, tt.b); got != tt.want { + t.Errorf("AsciiEqualFold(%q,%q): got %v want %v", tt.a, tt.b, got, tt.want) + } + }) + } +} + +func TestIsPrint(t *testing.T) { + var tests = []struct { + name string + in string + want bool + }{ + { + name: "empty", + want: true, + }, + { + name: "ASCII low", + in: "This is a space: ' '", + want: true, + }, + { + name: "ASCII high", + in: "This is a tilde: '~'", + want: true, + }, + { + name: "ASCII low non-print", + in: "This is a unit separator: \x1F", + want: false, + }, + { + name: "Ascii high non-print", + in: "This is a Delete: \x7F", + want: false, + }, + { + name: "Unicode letter", + in: "Today it's 280K outside: it's freezing!", // This "K" is 'KELVIN SIGN' (\u212A) + want: false, + }, + { + name: "Unicode emoji", + in: "Gophers like 🧀", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsPrint(tt.in); got != tt.want { + t.Errorf("IsASCIIPrint(%q): got %v want %v", tt.in, got, tt.want) + } + }) + } +} diff --git a/internal/chunked.go b/internal/chunked.go new file mode 100644 index 00000000..37a72e90 --- /dev/null +++ b/internal/chunked.go @@ -0,0 +1,261 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The wire protocol for HTTP's "chunked" Transfer-Encoding. + +// Package internal contains HTTP internals shared by net/http and +// net/http/httputil. +package internal + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" +) + +const maxLineLength = 4096 // assumed <= bufio.defaultBufSize + +var ErrLineTooLong = errors.New("header line too long") + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r io.Reader) io.Reader { + br, ok := r.(*bufio.Reader) + if !ok { + br = bufio.NewReader(r) + } + return &chunkedReader{r: br} +} + +type chunkedReader struct { + r *bufio.Reader + n uint64 // unread bytes in chunk + err error + buf [2]byte + checkEnd bool // whether need to check for \r\n chunk footer +} + +func (cr *chunkedReader) beginChunk() { + // chunk-size CRLF + var line []byte + line, cr.err = readChunkLine(cr.r) + if cr.err != nil { + return + } + cr.n, cr.err = parseHexUint(line) + if cr.err != nil { + return + } + if cr.n == 0 { + cr.err = io.EOF + } +} + +func (cr *chunkedReader) chunkHeaderAvailable() bool { + n := cr.r.Buffered() + if n > 0 { + peek, _ := cr.r.Peek(n) + return bytes.IndexByte(peek, '\n') >= 0 + } + return false +} + +func (cr *chunkedReader) Read(b []uint8) (n int, err error) { + for cr.err == nil { + if cr.checkEnd { + if n > 0 && cr.r.Buffered() < 2 { + // We have some data. Return early (per the io.Reader + // contract) instead of potentially blocking while + // reading more. + break + } + if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { + if string(cr.buf[:]) != "\r\n" { + cr.err = errors.New("malformed chunked encoding") + break + } + } else { + if cr.err == io.EOF { + cr.err = io.ErrUnexpectedEOF + } + break + } + cr.checkEnd = false + } + if cr.n == 0 { + if n > 0 && !cr.chunkHeaderAvailable() { + // We've read enough. Don't potentially block + // reading a new chunk header. + break + } + cr.beginChunk() + continue + } + if len(b) == 0 { + break + } + rbuf := b + if uint64(len(rbuf)) > cr.n { + rbuf = rbuf[:cr.n] + } + var n0 int + n0, cr.err = cr.r.Read(rbuf) + n += n0 + b = b[n0:] + cr.n -= uint64(n0) + // If we're at the end of a chunk, read the next two + // bytes to verify they are "\r\n". + if cr.n == 0 && cr.err == nil { + cr.checkEnd = true + } else if cr.err == io.EOF { + cr.err = io.ErrUnexpectedEOF + } + } + return n, cr.err +} + +// Read a line of bytes (up to \n) from b. +// Give up if the line exceeds maxLineLength. +// The returned bytes are owned by the bufio.Reader +// so they are only valid until the next bufio read. +func readChunkLine(b *bufio.Reader) ([]byte, error) { + p, err := b.ReadSlice('\n') + if err != nil { + // We always know when EOF is coming. + // If the caller asked for a line, there should be a line. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } else if err == bufio.ErrBufferFull { + err = ErrLineTooLong + } + return nil, err + } + if len(p) >= maxLineLength { + return nil, ErrLineTooLong + } + p = trimTrailingWhitespace(p) + p, err = removeChunkExtension(p) + if err != nil { + return nil, err + } + return p, nil +} + +func trimTrailingWhitespace(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] + } + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +var semi = []byte(";") + +// removeChunkExtension removes any chunk-extension from p. +// For example, +// "0" => "0" +// "0;token" => "0" +// "0;token=val" => "0" +// `0;token="quoted string"` => "0" +func removeChunkExtension(p []byte) ([]byte, error) { + p, _, _ = bytes.Cut(p, semi) + // TODO: care about exact syntax of chunk extensions? We're + // ignoring and stripping them anyway. For now just never + // return an error. + return p, nil +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream but does +// not send the final CRLF that appears after trailers; trailers and the last +// CRLF must be written separately. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using newChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return &chunkedWriter{w} +} + +// Writing to chunkedWriter translates to writing in HTTP chunked Transfer +// Encoding wire format to the underlying Wire chunkedWriter. +type chunkedWriter struct { + Wire io.Writer +} + +// Write the contents of data as one chunk to Wire. +// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has +// a bug since it does not check for success of io.WriteString +func (cw *chunkedWriter) Write(data []byte) (n int, err error) { + + // Don't send 0-length data. It looks like EOF for chunked encoding. + if len(data) == 0 { + return 0, nil + } + + if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil { + return 0, err + } + if n, err = cw.Wire.Write(data); err != nil { + return + } + if n != len(data) { + err = io.ErrShortWrite + return + } + if _, err = io.WriteString(cw.Wire, "\r\n"); err != nil { + return + } + if bw, ok := cw.Wire.(*FlushAfterChunkWriter); ok { + err = bw.Flush() + } + return +} + +func (cw *chunkedWriter) Close() error { + _, err := io.WriteString(cw.Wire, "0\r\n") + return err +} + +// FlushAfterChunkWriter signals from the caller of NewChunkedWriter +// that each chunk should be followed by a flush. It is used by the +// http.Transport code to keep the buffering behavior for headers and +// trailers, but flush out chunks aggressively in the middle for +// request bodies which may be generated slowly. See Issue 6574. +type FlushAfterChunkWriter struct { + *bufio.Writer +} + +func parseHexUint(v []byte) (n uint64, err error) { + for i, b := range v { + switch { + case '0' <= b && b <= '9': + b = b - '0' + case 'a' <= b && b <= 'f': + b = b - 'a' + 10 + case 'A' <= b && b <= 'F': + b = b - 'A' + 10 + default: + return 0, errors.New("invalid byte in chunk length") + } + if i == 16 { + return 0, errors.New("http chunk length too large") + } + n <<= 4 + n |= uint64(b) + } + return +} diff --git a/internal/chunked_test.go b/internal/chunked_test.go new file mode 100644 index 00000000..5e29a786 --- /dev/null +++ b/internal/chunked_test.go @@ -0,0 +1,241 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strings" + "testing" + "testing/iotest" +) + +func TestChunk(t *testing.T) { + var b bytes.Buffer + + w := NewChunkedWriter(&b) + const chunk1 = "hello, " + const chunk2 = "world! 0123456789abcdef" + w.Write([]byte(chunk1)) + w.Write([]byte(chunk2)) + w.Close() + + if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e { + t.Fatalf("chunk writer wrote %q; want %q", g, e) + } + + r := NewChunkedReader(&b) + data, err := io.ReadAll(r) + if err != nil { + t.Logf(`data: "%s"`, data) + t.Fatalf("ReadAll from reader: %v", err) + } + if g, e := string(data), chunk1+chunk2; g != e { + t.Errorf("chunk reader read %q; want %q", g, e) + } +} + +func TestChunkReadMultiple(t *testing.T) { + // Bunch of small chunks, all read together. + { + var b bytes.Buffer + w := NewChunkedWriter(&b) + w.Write([]byte("foo")) + w.Write([]byte("bar")) + w.Close() + + r := NewChunkedReader(&b) + buf := make([]byte, 10) + n, err := r.Read(buf) + if n != 6 || err != io.EOF { + t.Errorf("Read = %d, %v; want 6, EOF", n, err) + } + buf = buf[:n] + if string(buf) != "foobar" { + t.Errorf("Read = %q; want %q", buf, "foobar") + } + } + + // One big chunk followed by a little chunk, but the small bufio.Reader size + // should prevent the second chunk header from being read. + { + var b bytes.Buffer + w := NewChunkedWriter(&b) + // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes, + // the same as the bufio ReaderSize below (the minimum), so even + // though we're going to try to Read with a buffer larger enough to also + // receive "foo", the second chunk header won't be read yet. + const fillBufChunk = "0123456789a" + const shortChunk = "foo" + w.Write([]byte(fillBufChunk)) + w.Write([]byte(shortChunk)) + w.Close() + + r := NewChunkedReader(bufio.NewReaderSize(&b, 16)) + buf := make([]byte, len(fillBufChunk)+len(shortChunk)) + n, err := r.Read(buf) + if n != len(fillBufChunk) || err != nil { + t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk)) + } + buf = buf[:n] + if string(buf) != fillBufChunk { + t.Errorf("Read = %q; want %q", buf, fillBufChunk) + } + + n, err = r.Read(buf) + if n != len(shortChunk) || err != io.EOF { + t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk)) + } + } + + // And test that we see an EOF chunk, even though our buffer is already full: + { + r := NewChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n"))) + buf := make([]byte, 3) + n, err := r.Read(buf) + if n != 3 || err != io.EOF { + t.Errorf("Read = %d, %v; want 3, EOF", n, err) + } + if string(buf) != "foo" { + t.Errorf("buf = %q; want foo", buf) + } + } +} + +func TestChunkReaderAllocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + var buf bytes.Buffer + w := NewChunkedWriter(&buf) + a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") + w.Write(a) + w.Write(b) + w.Write(c) + w.Close() + + readBuf := make([]byte, len(a)+len(b)+len(c)+1) + byter := bytes.NewReader(buf.Bytes()) + bufr := bufio.NewReader(byter) + mallocs := testing.AllocsPerRun(100, func() { + byter.Seek(0, io.SeekStart) + bufr.Reset(byter) + r := NewChunkedReader(bufr) + n, err := io.ReadFull(r, readBuf) + if n != len(readBuf)-1 { + t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Fatalf("read error = %v; want ErrUnexpectedEOF", err) + } + }) + if mallocs > 1.5 { + t.Errorf("mallocs = %v; want 1", mallocs) + } +} + +func TestParseHexUint(t *testing.T) { + type testCase struct { + in string + want uint64 + wantErr string + } + tests := []testCase{ + {"x", 0, "invalid byte in chunk length"}, + {"0000000000000000", 0, ""}, + {"0000000000000001", 1, ""}, + {"ffffffffffffffff", 1<<64 - 1, ""}, + {"000000000000bogus", 0, "invalid byte in chunk length"}, + {"00000000000000000", 0, "http chunk length too large"}, // could accept if we wanted + {"10000000000000000", 0, "http chunk length too large"}, + {"00000000000000001", 0, "http chunk length too large"}, // could accept if we wanted + } + for i := uint64(0); i <= 1234; i++ { + tests = append(tests, testCase{in: fmt.Sprintf("%x", i), want: i}) + } + for _, tt := range tests { + got, err := parseHexUint([]byte(tt.in)) + if tt.wantErr != "" { + if !strings.Contains(fmt.Sprint(err), tt.wantErr) { + t.Errorf("parseHexUint(%q) = %v, %v; want error %q", tt.in, got, err, tt.wantErr) + } + } else { + if err != nil || got != tt.want { + t.Errorf("parseHexUint(%q) = %v, %v; want %v", tt.in, got, err, tt.want) + } + } + } +} + +func TestChunkReadingIgnoresExtensions(t *testing.T) { + in := "7;ext=\"some quoted string\"\r\n" + // token=quoted string + "hello, \r\n" + + "17;someext\r\n" + // token without value + "world! 0123456789abcdef\r\n" + + "0;someextension=sometoken\r\n" // token=token + data, err := io.ReadAll(NewChunkedReader(strings.NewReader(in))) + if err != nil { + t.Fatalf("ReadAll = %q, %v", data, err) + } + if g, e := string(data), "hello, world! 0123456789abcdef"; g != e { + t.Errorf("read %q; want %q", g, e) + } +} + +// Issue 17355: ChunkedReader shouldn't block waiting for more data +// if it can return something. +func TestChunkReadPartial(t *testing.T) { + pr, pw := io.Pipe() + go func() { + pw.Write([]byte("7\r\n1234567")) + }() + cr := NewChunkedReader(pr) + readBuf := make([]byte, 7) + n, err := cr.Read(readBuf) + if err != nil { + t.Fatal(err) + } + want := "1234567" + if n != 7 || string(readBuf) != want { + t.Fatalf("Read: %v %q; want %d, %q", n, readBuf[:n], len(want), want) + } + go func() { + pw.Write([]byte("xx")) + }() + _, err = cr.Read(readBuf) + if got := fmt.Sprint(err); !strings.Contains(got, "malformed") { + t.Fatalf("second read = %v; want malformed error", err) + } + +} + +// Issue 48861: ChunkedReader should report incomplete chunks +func TestIncompleteChunk(t *testing.T) { + const valid = "4\r\nabcd\r\n" + "5\r\nabc\r\n\r\n" + "0\r\n" + + for i := 0; i < len(valid); i++ { + incomplete := valid[:i] + r := NewChunkedReader(strings.NewReader(incomplete)) + if _, err := io.ReadAll(r); err != io.ErrUnexpectedEOF { + t.Errorf("expected io.ErrUnexpectedEOF for %q, got %v", incomplete, err) + } + } + + r := NewChunkedReader(strings.NewReader(valid)) + if _, err := io.ReadAll(r); err != nil { + t.Errorf("unexpected error for %q: %v", valid, err) + } +} + +func TestChunkEndReadError(t *testing.T) { + readErr := fmt.Errorf("chunk end read error") + + r := NewChunkedReader(io.MultiReader(strings.NewReader("4\r\nabcd"), iotest.ErrReader(readErr))) + if _, err := io.ReadAll(r); err != readErr { + t.Errorf("expected %v, got %v", readErr, err) + } +} diff --git a/internal/godebug/godebug.go b/internal/godebug/godebug.go new file mode 100644 index 00000000..ac434e5f --- /dev/null +++ b/internal/godebug/godebug.go @@ -0,0 +1,34 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package godebug parses the GODEBUG environment variable. +package godebug + +import "os" + +// Get returns the value for the provided GODEBUG key. +func Get(key string) string { + return get(os.Getenv("GODEBUG"), key) +} + +// get returns the value part of key=value in s (a GODEBUG value). +func get(s, key string) string { + for i := 0; i < len(s)-len(key)-1; i++ { + if i > 0 && s[i-1] != ',' { + continue + } + afterKey := s[i+len(key):] + if afterKey[0] != '=' || s[i:i+len(key)] != key { + continue + } + val := afterKey[1:] + for i, b := range val { + if b == ',' { + return val[:i] + } + } + return val + } + return "" +} diff --git a/internal/godebug/godebug_test.go b/internal/godebug/godebug_test.go new file mode 100644 index 00000000..41b9117b --- /dev/null +++ b/internal/godebug/godebug_test.go @@ -0,0 +1,34 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package godebug + +import "testing" + +func TestGet(t *testing.T) { + tests := []struct { + godebug string + key string + want string + }{ + {"", "", ""}, + {"", "foo", ""}, + {"foo=bar", "foo", "bar"}, + {"foo=bar,after=x", "foo", "bar"}, + {"before=x,foo=bar,after=x", "foo", "bar"}, + {"before=x,foo=bar", "foo", "bar"}, + {",,,foo=bar,,,", "foo", "bar"}, + {"foodecoy=wrong,foo=bar", "foo", "bar"}, + {"foo=", "foo", ""}, + {"foo", "foo", ""}, + {",foo", "foo", ""}, + {"foo=bar,baz", "loooooooong", ""}, + } + for _, tt := range tests { + got := get(tt.godebug, tt.key) + if got != tt.want { + t.Errorf("get(%q, %q) = %q; want %q", tt.godebug, tt.key, got, tt.want) + } + } +} diff --git a/req.go b/req.go deleted file mode 100644 index 99c04b3d..00000000 --- a/req.go +++ /dev/null @@ -1,690 +0,0 @@ -package req - -import ( - "bytes" - "compress/gzip" - "context" - "encoding/json" - "encoding/xml" - "errors" - "fmt" - "io" - "io/ioutil" - "mime/multipart" - "net/http" - "net/textproto" - "net/url" - "os" - "path/filepath" - "strconv" - "strings" - "time" -) - -// default *Req -var std = New() - -// flags to decide which part can be outputed -const ( - LreqHead = 1 << iota // output request head (request line and request header) - LreqBody // output request body - LrespHead // output response head (response line and response header) - LrespBody // output response body - Lcost // output time costed by the request - LstdFlags = LreqHead | LreqBody | LrespHead | LrespBody -) - -// Param represents http request param -type Param map[string]interface{} - -// QueryParam is used to force append http request param to the uri -type QueryParam map[string]interface{} - -// Host is used for set request's Host -type Host string - -// FileUpload represents a file to upload -type FileUpload struct { - // filename in multipart form. - FileName string - // form field name - FieldName string - // file to uplaod, required - File io.ReadCloser -} - -type DownloadProgress func(current, total int64) - -type UploadProgress func(current, total int64) - -// File upload files matching the name pattern such as -// /usr/*/bin/go* (assuming the Separator is '/') -func File(patterns ...string) interface{} { - matches := []string{} - for _, pattern := range patterns { - m, err := filepath.Glob(pattern) - if err != nil { - return err - } - matches = append(matches, m...) - } - if len(matches) == 0 { - return errors.New("req: no file have been matched") - } - uploads := []FileUpload{} - for _, match := range matches { - if s, e := os.Stat(match); e != nil || s.IsDir() { - continue - } - file, _ := os.Open(match) - uploads = append(uploads, FileUpload{ - File: file, - FileName: filepath.Base(match), - FieldName: "media", - }) - } - - return uploads -} - -type bodyJson struct { - v interface{} -} - -type bodyXml struct { - v interface{} -} - -// BodyJSON make the object be encoded in json format and set it to the request body -func BodyJSON(v interface{}) *bodyJson { - return &bodyJson{v: v} -} - -// BodyXML make the object be encoded in xml format and set it to the request body -func BodyXML(v interface{}) *bodyXml { - return &bodyXml{v: v} -} - -// Req is a convenient client for initiating requests -type Req struct { - client *http.Client - jsonEncOpts *jsonEncOpts - xmlEncOpts *xmlEncOpts - flag int - progressInterval time.Duration -} - -// New create a new *Req -func New() *Req { - // default progress reporting interval is 200 milliseconds - return &Req{flag: LstdFlags, progressInterval: 200 * time.Millisecond} -} - -type param struct { - url.Values -} - -func (p *param) getValues() url.Values { - if p.Values == nil { - p.Values = make(url.Values) - } - return p.Values -} - -func (p *param) Copy(pp param) { - if pp.Values == nil { - return - } - vs := p.getValues() - for key, values := range pp.Values { - for _, value := range values { - vs.Add(key, value) - } - } -} -func (p *param) Adds(m map[string]interface{}) { - if len(m) == 0 { - return - } - vs := p.getValues() - for k, v := range m { - vs.Add(k, fmt.Sprint(v)) - } -} - -func (p *param) Empty() bool { - return p.Values == nil -} - -// Do execute a http request with sepecify method and url, -// and it can also have some optional params, depending on your needs. -func (r *Req) Do(method, rawurl string, vs ...interface{}) (resp *Resp, err error) { - if rawurl == "" { - return nil, errors.New("req: url not specified") - } - req := &http.Request{ - Method: method, - Header: make(http.Header), - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - } - resp = &Resp{req: req, r: r} - - // output detail if Debug is enabled - if Debug { - defer func(resp *Resp) { - fmt.Println(resp.Dump()) - }(resp) - } - - var queryParam param - var formParam param - var uploads []FileUpload - var uploadProgress UploadProgress - var progress func(int64, int64) - var delayedFunc []func() - var lastFunc []func() - - for _, v := range vs { - switch vv := v.(type) { - case Header: - for key, value := range vv { - req.Header.Add(key, value) - } - case http.Header: - for key, values := range vv { - for _, value := range values { - req.Header.Add(key, value) - } - } - case ReservedHeader: - for key, value := range vv { - req.Header[key] = []string{value} - } - case *bodyJson: - fn, err := setBodyJson(req, resp, r.jsonEncOpts, vv.v) - if err != nil { - return nil, err - } - delayedFunc = append(delayedFunc, fn) - case *bodyXml: - fn, err := setBodyXml(req, resp, r.xmlEncOpts, vv.v) - if err != nil { - return nil, err - } - delayedFunc = append(delayedFunc, fn) - case url.Values: - p := param{vv} - if method == "GET" || method == "HEAD" { - queryParam.Copy(p) - } else { - formParam.Copy(p) - } - case Param: - if method == "GET" || method == "HEAD" { - queryParam.Adds(vv) - } else { - formParam.Adds(vv) - } - case QueryParam: - queryParam.Adds(vv) - case string: - setBodyBytes(req, resp, []byte(vv)) - case []byte: - setBodyBytes(req, resp, vv) - case bytes.Buffer: - setBodyBytes(req, resp, vv.Bytes()) - case *http.Client: - resp.client = vv - case FileUpload: - uploads = append(uploads, vv) - case []FileUpload: - uploads = append(uploads, vv...) - case *http.Cookie: - req.AddCookie(vv) - case Host: - req.Host = string(vv) - case io.Reader: - fn := setBodyReader(req, resp, vv) - lastFunc = append(lastFunc, fn) - case UploadProgress: - uploadProgress = vv - case DownloadProgress: - resp.downloadProgress = vv - case func(int64, int64): - progress = vv - case context.Context: - req = req.WithContext(vv) - resp.req = req - case error: - return nil, vv - } - } - - if length := req.Header.Get("Content-Length"); length != "" { - if l, err := strconv.ParseInt(length, 10, 64); err == nil { - req.ContentLength = l - } - } - - if len(uploads) > 0 && (req.Method == "POST" || req.Method == "PUT") { // multipart - var up UploadProgress - if uploadProgress != nil { - up = uploadProgress - } else if progress != nil { - up = UploadProgress(progress) - } - multipartHelper := &multipartHelper{ - form: formParam.Values, - uploads: uploads, - uploadProgress: up, - progressInterval: resp.r.progressInterval, - } - multipartHelper.Upload(req) - resp.multipartHelper = multipartHelper - } else { - if progress != nil { - resp.downloadProgress = DownloadProgress(progress) - } - if !formParam.Empty() { - if req.Body != nil { - queryParam.Copy(formParam) - } else { - setBodyBytes(req, resp, []byte(formParam.Encode())) - setContentType(req, "application/x-www-form-urlencoded; charset=UTF-8") - } - } - } - - if !queryParam.Empty() { - paramStr := queryParam.Encode() - if strings.IndexByte(rawurl, '?') == -1 { - rawurl = rawurl + "?" + paramStr - } else { - rawurl = rawurl + "&" + paramStr - } - } - - u, err := url.Parse(rawurl) - if err != nil { - return nil, err - } - req.URL = u - - if host := req.Header.Get("Host"); host != "" { - req.Host = host - } - - for _, fn := range delayedFunc { - fn() - } - - if resp.client == nil { - resp.client = r.Client() - } - - var response *http.Response - if r.flag&Lcost != 0 { - before := time.Now() - response, err = resp.client.Do(req) - after := time.Now() - resp.cost = after.Sub(before) - } else { - response, err = resp.client.Do(req) - } - if err != nil { - return nil, err - } - - for _, fn := range lastFunc { - fn() - } - - resp.resp = response - - if _, ok := resp.client.Transport.(*http.Transport); ok && response.Header.Get("Content-Encoding") == "gzip" && req.Header.Get("Accept-Encoding") != "" { - body, err := gzip.NewReader(response.Body) - if err != nil { - return nil, err - } - response.Body = body - } - - return -} - -func setBodyBytes(req *http.Request, resp *Resp, data []byte) { - resp.reqBody = data - req.Body = ioutil.NopCloser(bytes.NewReader(data)) - req.ContentLength = int64(len(data)) -} - -func setBodyJson(req *http.Request, resp *Resp, opts *jsonEncOpts, v interface{}) (func(), error) { - var data []byte - switch vv := v.(type) { - case string: - data = []byte(vv) - case []byte: - data = vv - case *bytes.Buffer: - data = vv.Bytes() - default: - if opts != nil { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.SetIndent(opts.indentPrefix, opts.indentValue) - enc.SetEscapeHTML(opts.escapeHTML) - err := enc.Encode(v) - if err != nil { - return nil, err - } - data = buf.Bytes() - } else { - var err error - data, err = json.Marshal(v) - if err != nil { - return nil, err - } - } - } - setBodyBytes(req, resp, data) - delayedFunc := func() { - setContentType(req, "application/json; charset=UTF-8") - } - return delayedFunc, nil -} - -func setBodyXml(req *http.Request, resp *Resp, opts *xmlEncOpts, v interface{}) (func(), error) { - var data []byte - switch vv := v.(type) { - case string: - data = []byte(vv) - case []byte: - data = vv - case *bytes.Buffer: - data = vv.Bytes() - default: - if opts != nil { - var buf bytes.Buffer - enc := xml.NewEncoder(&buf) - enc.Indent(opts.prefix, opts.indent) - err := enc.Encode(v) - if err != nil { - return nil, err - } - data = buf.Bytes() - } else { - var err error - data, err = xml.Marshal(v) - if err != nil { - return nil, err - } - } - } - setBodyBytes(req, resp, data) - delayedFunc := func() { - setContentType(req, "application/xml; charset=UTF-8") - } - return delayedFunc, nil -} - -func setContentType(req *http.Request, contentType string) { - if req.Header.Get("Content-Type") == "" { - req.Header.Set("Content-Type", contentType) - } -} - -func setBodyReader(req *http.Request, resp *Resp, rd io.Reader) func() { - var rc io.ReadCloser - switch r := rd.(type) { - case *os.File: - stat, err := r.Stat() - if err == nil { - req.ContentLength = stat.Size() - } - rc = r - - case io.ReadCloser: - rc = r - default: - rc = ioutil.NopCloser(rd) - } - bw := &bodyWrapper{ - ReadCloser: rc, - limit: 102400, - } - req.Body = bw - lastFunc := func() { - resp.reqBody = bw.buf.Bytes() - } - return lastFunc -} - -type bodyWrapper struct { - io.ReadCloser - buf bytes.Buffer - limit int -} - -func (b *bodyWrapper) Read(p []byte) (n int, err error) { - n, err = b.ReadCloser.Read(p) - if left := b.limit - b.buf.Len(); left > 0 && n > 0 { - if n <= left { - b.buf.Write(p[:n]) - } else { - b.buf.Write(p[:left]) - } - } - return -} - -type multipartHelper struct { - form url.Values - uploads []FileUpload - dump []byte - uploadProgress UploadProgress - progressInterval time.Duration -} - -func (m *multipartHelper) Upload(req *http.Request) { - pr, pw := io.Pipe() - bodyWriter := multipart.NewWriter(pw) - go func() { - for key, values := range m.form { - for _, value := range values { - bodyWriter.WriteField(key, value) - } - } - var upload func(io.Writer, io.Reader) error - if m.uploadProgress != nil { - var total int64 - for _, up := range m.uploads { - if file, ok := up.File.(*os.File); ok { - stat, err := file.Stat() - if err != nil { - continue - } - total += stat.Size() - } - } - var current int64 - buf := make([]byte, 1024) - var lastTime time.Time - - defer func() { - m.uploadProgress(current, total) - }() - - upload = func(w io.Writer, r io.Reader) error { - for { - n, err := r.Read(buf) - if n > 0 { - _, _err := w.Write(buf[:n]) - if _err != nil { - return _err - } - current += int64(n) - if now := time.Now(); now.Sub(lastTime) > m.progressInterval { - lastTime = now - m.uploadProgress(current, total) - } - } - if err == io.EOF { - return nil - } - if err != nil { - return err - } - } - } - } - - i := 0 - for _, up := range m.uploads { - if up.FieldName == "" { - i++ - up.FieldName = "file" + strconv.Itoa(i) - } - fileWriter, err := bodyWriter.CreateFormFile(up.FieldName, up.FileName) - if err != nil { - continue - } - //iocopy - if upload == nil { - io.Copy(fileWriter, up.File) - } else { - if _, ok := up.File.(*os.File); ok { - upload(fileWriter, up.File) - } else { - io.Copy(fileWriter, up.File) - } - } - up.File.Close() - } - bodyWriter.Close() - pw.Close() - }() - req.Header.Set("Content-Type", bodyWriter.FormDataContentType()) - req.Body = ioutil.NopCloser(pr) -} - -func (m *multipartHelper) Dump() []byte { - if m.dump != nil { - return m.dump - } - var buf bytes.Buffer - bodyWriter := multipart.NewWriter(&buf) - for key, values := range m.form { - for _, value := range values { - m.writeField(bodyWriter, key, value) - } - } - for _, up := range m.uploads { - m.writeFile(bodyWriter, up.FieldName, up.FileName) - } - bodyWriter.Close() - m.dump = buf.Bytes() - return m.dump -} - -func (m *multipartHelper) writeField(w *multipart.Writer, fieldname, value string) error { - h := make(textproto.MIMEHeader) - h.Set("Content-Disposition", - fmt.Sprintf(`form-data; name="%s"`, fieldname)) - p, err := w.CreatePart(h) - if err != nil { - return err - } - _, err = p.Write([]byte(value)) - return err -} - -func (m *multipartHelper) writeFile(w *multipart.Writer, fieldname, filename string) error { - h := make(textproto.MIMEHeader) - h.Set("Content-Disposition", - fmt.Sprintf(`form-data; name="%s"; filename="%s"`, - fieldname, filename)) - h.Set("Content-Type", "application/octet-stream") - p, err := w.CreatePart(h) - if err != nil { - return err - } - _, err = p.Write([]byte("******")) - return err -} - -// Get execute a http GET request -func (r *Req) Get(url string, v ...interface{}) (*Resp, error) { - return r.Do("GET", url, v...) -} - -// Post execute a http POST request -func (r *Req) Post(url string, v ...interface{}) (*Resp, error) { - return r.Do("POST", url, v...) -} - -// Put execute a http PUT request -func (r *Req) Put(url string, v ...interface{}) (*Resp, error) { - return r.Do("PUT", url, v...) -} - -// Patch execute a http PATCH request -func (r *Req) Patch(url string, v ...interface{}) (*Resp, error) { - return r.Do("PATCH", url, v...) -} - -// Delete execute a http DELETE request -func (r *Req) Delete(url string, v ...interface{}) (*Resp, error) { - return r.Do("DELETE", url, v...) -} - -// Head execute a http HEAD request -func (r *Req) Head(url string, v ...interface{}) (*Resp, error) { - return r.Do("HEAD", url, v...) -} - -// Options execute a http OPTIONS request -func (r *Req) Options(url string, v ...interface{}) (*Resp, error) { - return r.Do("OPTIONS", url, v...) -} - -// Get execute a http GET request -func Get(url string, v ...interface{}) (*Resp, error) { - return std.Get(url, v...) -} - -// Post execute a http POST request -func Post(url string, v ...interface{}) (*Resp, error) { - return std.Post(url, v...) -} - -// Put execute a http PUT request -func Put(url string, v ...interface{}) (*Resp, error) { - return std.Put(url, v...) -} - -// Head execute a http HEAD request -func Head(url string, v ...interface{}) (*Resp, error) { - return std.Head(url, v...) -} - -// Options execute a http OPTIONS request -func Options(url string, v ...interface{}) (*Resp, error) { - return std.Options(url, v...) -} - -// Delete execute a http DELETE request -func Delete(url string, v ...interface{}) (*Resp, error) { - return std.Delete(url, v...) -} - -// Patch execute a http PATCH request -func Patch(url string, v ...interface{}) (*Resp, error) { - return std.Patch(url, v...) -} - -// Do execute request. -func Do(method, url string, v ...interface{}) (*Resp, error) { - return std.Do(method, url, v...) -} diff --git a/req_test.go b/req_test.go deleted file mode 100644 index a863941a..00000000 --- a/req_test.go +++ /dev/null @@ -1,313 +0,0 @@ -package req - -import ( - "bytes" - "encoding/json" - "encoding/xml" - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestUrlParam(t *testing.T) { - m := map[string]interface{}{ - "access_token": "123abc", - "name": "roc", - "enc": "中文", - } - queryHandler := func(w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - for key, value := range m { - if v := query.Get(key); value != v { - t.Errorf("query param %s = %s; want = %s", key, v, value) - } - } - } - ts := httptest.NewServer(http.HandlerFunc(queryHandler)) - _, err := Get(ts.URL, QueryParam(m)) - if err != nil { - t.Fatal(err) - } - _, err = Head(ts.URL, Param(m)) - if err != nil { - t.Fatal(err) - } - _, err = Put(ts.URL, QueryParam(m)) - if err != nil { - t.Fatal(err) - } -} - -func TestFormParam(t *testing.T) { - formParam := Param{ - "access_token": "123abc", - "name": "roc", - "enc": "中文", - } - formHandler := func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - for key, value := range formParam { - if v := r.FormValue(key); value != v { - t.Errorf("form param %s = %s; want = %s", key, v, value) - } - } - } - ts := httptest.NewServer(http.HandlerFunc(formHandler)) - url := ts.URL - _, err := Post(url, formParam) - if err != nil { - t.Fatal(err) - } -} - -func TestParamWithBody(t *testing.T) { - reqBody := "request body" - p := Param{ - "name": "roc", - "job": "programmer", - } - buf := bytes.NewBufferString(reqBody) - ts := newDefaultTestServer() - r, err := Post(ts.URL, p, buf) - if err != nil { - t.Fatal(err) - } - if r.Request().URL.Query().Get("name") != "roc" { - t.Error("param should in the url when set body manually") - } - if string(r.reqBody) != reqBody { - t.Error("request body not equal") - } -} - -func TestParamBoth(t *testing.T) { - urlParam := QueryParam{ - "access_token": "123abc", - "enc": "中文", - } - formParam := Param{ - "name": "roc", - "job": "软件工程师", - } - handler := func(w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - for key, value := range urlParam { - if v := query.Get(key); value != v { - t.Errorf("query param %s = %s; want = %s", key, v, value) - } - } - r.ParseForm() - for key, value := range formParam { - if v := r.FormValue(key); value != v { - t.Errorf("form param %s = %s; want = %s", key, v, value) - } - } - } - ts := httptest.NewServer(http.HandlerFunc(handler)) - url := ts.URL - _, err := Patch(url, urlParam, formParam) - if err != nil { - t.Fatal(err) - } - -} - -func TestBody(t *testing.T) { - body := "request body" - handler := func(w http.ResponseWriter, r *http.Request) { - bs, err := ioutil.ReadAll(r.Body) - if err != nil { - t.Fatal(err) - } - if string(bs) != body { - t.Errorf("body = %s; want = %s", bs, body) - } - } - ts := httptest.NewServer(http.HandlerFunc(handler)) - - // string - _, err := Post(ts.URL, body) - if err != nil { - t.Fatal(err) - } - - // []byte - _, err = Post(ts.URL, []byte(body)) - if err != nil { - t.Fatal(err) - } - - // *bytes.Buffer - var buf bytes.Buffer - buf.WriteString(body) - _, err = Post(ts.URL, &buf) - if err != nil { - t.Fatal(err) - } - - // io.Reader - _, err = Post(ts.URL, strings.NewReader(body)) - if err != nil { - t.Fatal(err) - } -} - -func TestBodyJSON(t *testing.T) { - type content struct { - Code int `json:"code"` - Msg string `json:"msg"` - } - c := content{ - Code: 1, - Msg: "ok", - } - checkData := func(data []byte) { - var cc content - err := json.Unmarshal(data, &cc) - if err != nil { - t.Fatal(err) - } - if cc != c { - t.Errorf("request body = %+v; want = %+v", cc, c) - } - } - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - data, err := ioutil.ReadAll(r.Body) - if err != nil { - t.Fatal(err) - } - checkData(data) - }) - - ts := httptest.NewServer(handler) - resp, err := Post(ts.URL, BodyJSON(&c)) - if err != nil { - t.Fatal(err) - } - checkData(resp.reqBody) - - SetJSONEscapeHTML(false) - SetJSONIndent("", "\t") - resp, err = Put(ts.URL, BodyJSON(&c)) - if err != nil { - t.Fatal(err) - } - checkData(resp.reqBody) -} - -func TestBodyXML(t *testing.T) { - type content struct { - Code int `xml:"code"` - Msg string `xml:"msg"` - } - c := content{ - Code: 1, - Msg: "ok", - } - checkData := func(data []byte) { - var cc content - err := xml.Unmarshal(data, &cc) - if err != nil { - t.Fatal(err) - } - if cc != c { - t.Errorf("request body = %+v; want = %+v", cc, c) - } - } - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - data, err := ioutil.ReadAll(r.Body) - if err != nil { - t.Fatal(err) - } - checkData(data) - }) - - ts := httptest.NewServer(handler) - resp, err := Post(ts.URL, BodyXML(&c)) - if err != nil { - t.Fatal(err) - } - checkData(resp.reqBody) - - SetXMLIndent("", " ") - resp, err = Put(ts.URL, BodyXML(&c)) - if err != nil { - t.Fatal(err) - } - checkData(resp.reqBody) -} - -func TestHeader(t *testing.T) { - header := Header{ - "User-Agent": "V1.0.0", - "Authorization": "roc", - } - handler := func(w http.ResponseWriter, r *http.Request) { - for key, value := range header { - if v := r.Header.Get(key); value != v { - t.Errorf("header %q = %s; want = %s", key, v, value) - } - } - } - ts := httptest.NewServer(http.HandlerFunc(handler)) - _, err := Head(ts.URL, header) - if err != nil { - t.Fatal(err) - } - - httpHeader := make(http.Header) - for key, value := range header { - httpHeader.Add(key, value) - } - _, err = Head(ts.URL, httpHeader) - if err != nil { - t.Fatal(err) - } -} - -func TestUpload(t *testing.T) { - str := "hello req" - file := ioutil.NopCloser(strings.NewReader(str)) - upload := FileUpload{ - File: file, - FieldName: "media", - FileName: "hello.txt", - } - handler := func(w http.ResponseWriter, r *http.Request) { - mr, err := r.MultipartReader() - if err != nil { - t.Fatal(err) - } - for { - p, err := mr.NextPart() - if err != nil { - break - } - if p.FileName() != upload.FileName { - t.Errorf("filename = %s; want = %s", p.FileName(), upload.FileName) - } - if p.FormName() != upload.FieldName { - t.Errorf("formname = %s; want = %s", p.FileName(), upload.FileName) - } - data, err := ioutil.ReadAll(p) - if err != nil { - t.Fatal(err) - } - if string(data) != str { - t.Errorf("file content = %s; want = %s", data, str) - } - } - } - ts := httptest.NewServer(http.HandlerFunc(handler)) - _, err := Post(ts.URL, upload) - if err != nil { - t.Fatal(err) - } - ts = newDefaultTestServer() - _, err = Post(ts.URL, File("*.go")) - if err != nil { - t.Fatal(err) - } -} diff --git a/request.go b/request.go new file mode 100644 index 00000000..d885b18a --- /dev/null +++ b/request.go @@ -0,0 +1,226 @@ +package req + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/hashicorp/go-multierror" + "io" + "net/http" + urlpkg "net/url" + "strings" +) + +type Request struct { + error error + client *Client + httpRequest *http.Request +} + +func New() *Request { + return defaultClient.R() +} + +func (r *Request) appendError(err error) { + r.error = multierror.Append(r.error, err) +} + +func (r *Request) Error() error { + return r.error +} + +func (r *Request) Method(method string) *Request { + if method == "" { + // We document that "" means "GET" for Request.Method, and people have + // relied on that from NewRequest, so keep that working. + // We still enforce validMethod for non-empty methods. + method = "GET" + } + if !validMethod(method) { + err := fmt.Errorf("net/http: invalid method %q", method) + if err != nil { + r.appendError(err) + } + } + r.httpRequest.Method = method + r.httpRequest = r.httpRequest.WithContext(context.Background()) + return r +} + +func (r *Request) URL(url string) *Request { + u, err := urlpkg.Parse(url) + if err != nil { + r.appendError(err) + return r + } + // The host's colon:port should be normalized. See Issue 14836. + u.Host = removeEmptyPort(u.Host) + r.httpRequest.URL = u + r.httpRequest.Host = u.Host + return r +} + +func (r *Request) send(method, url string) (*Response, error) { + return r.Method(method).URL(url).Send() +} + +func (r *Request) MustGet(url string) *Response { + resp, err := r.Get(url) + if err != nil { + panic(err) + } + return resp +} + +func (r *Request) Get(url string) (*Response, error) { + return r.send(http.MethodGet, url) +} + +func (r *Request) MustPost(url string) *Response { + resp, err := r.Post(url) + if err != nil { + panic(err) + } + return resp +} + +func (r *Request) Post(url string) (*Response, error) { + return r.send(http.MethodPost, url) +} + +func (r *Request) MustPut(url string) *Response { + resp, err := r.Put(url) + if err != nil { + panic(err) + } + return resp +} + +func (r *Request) Put(url string) (*Response, error) { + return r.send(http.MethodPut, url) +} + +func (r *Request) MustPatch(url string) *Response { + resp, err := r.Patch(url) + if err != nil { + panic(err) + } + return resp +} + +func (r *Request) Patch(url string) (*Response, error) { + return r.send(http.MethodPatch, url) +} + +func (r *Request) MustDelete(url string) *Response { + resp, err := r.Delete(url) + if err != nil { + panic(err) + } + return resp +} + +func (r *Request) Delete(url string) (*Response, error) { + return r.send(http.MethodDelete, url) +} + +func (r *Request) MustOptions(url string) *Response { + resp, err := r.Options(url) + if err != nil { + panic(err) + } + return resp +} + +func (r *Request) Options(url string) (*Response, error) { + return r.send(http.MethodOptions, url) +} + +func (r *Request) MustHead(url string) (*Response, error) { + return r.send(http.MethodHead, url) +} + +func (r *Request) Head(url string) (*Response, error) { + return r.send(http.MethodHead, url) +} + +func (r *Request) Body(body interface{}) *Request { + if body == nil { + return r + } + switch b := body.(type) { + case io.ReadCloser: + r.httpRequest.Body = b + case io.Reader: + r.httpRequest.Body = io.NopCloser(b) + case []byte: + r.BodyBytes(b) + case string: + r.BodyString(b) + } + return r +} + +func (r *Request) BodyBytes(body []byte) *Request { + r.httpRequest.Body = io.NopCloser(bytes.NewReader(body)) + return r +} + +func (r *Request) BodyString(body string) *Request { + r.httpRequest.Body = io.NopCloser(strings.NewReader(body)) + return r +} + +func (r *Request) BodyJsonString(body string) *Request { + r.httpRequest.Body = io.NopCloser(strings.NewReader(body)) + r.setContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) + return r +} + +func (r *Request) BodyJsonBytes(body []byte) *Request { + r.httpRequest.Body = io.NopCloser(bytes.NewReader(body)) + r.setContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) + return r +} + +func (r *Request) BodyJsonMarshal(v interface{}) *Request { + b, err := json.Marshal(v) + if err != nil { + r.appendError(err) + return r + } + return r.BodyBytes(b) +} + +func (r *Request) setContentType(contentType string) *Request { + r.httpRequest.Header.Set("Content-Type", contentType) + return r +} + +func (r *Request) execute() (resp *Response, err error) { + if r.error != nil { + return nil, r.error + } + for k, v := range r.client.commonHeader { + if r.httpRequest.Header.Get(k) == "" { + r.httpRequest.Header.Set(k, v) + } + } + httpResponse, err := r.client.httpClient.Do(r.httpRequest) + if err != nil { + return + } + resp = &Response{ + request: r, + Response: httpResponse, + } + if r.client.t.Discard { + io.Copy(io.Discard, httpResponse.Body) + } + return +} + +func (r *Request) Send() (resp *Response, err error) { + return r.execute() +} diff --git a/resp.go b/resp.go deleted file mode 100644 index e5c36fce..00000000 --- a/resp.go +++ /dev/null @@ -1,220 +0,0 @@ -package req - -import ( - "encoding/json" - "encoding/xml" - "fmt" - "io" - "io/ioutil" - "net/http" - "os" - "regexp" - "time" -) - -// Resp represents a request with it's response -type Resp struct { - r *Req - req *http.Request - resp *http.Response - client *http.Client - cost time.Duration - *multipartHelper - reqBody []byte - respBody []byte - downloadProgress DownloadProgress - err error // delayed error -} - -// Request returns *http.Request -func (r *Resp) Request() *http.Request { - return r.req -} - -// Response returns *http.Response -func (r *Resp) Response() *http.Response { - return r.resp -} - -// Bytes returns response body as []byte -func (r *Resp) Bytes() []byte { - data, _ := r.ToBytes() - return data -} - -// ToBytes returns response body as []byte, -// return error if error happened when reading -// the response body -func (r *Resp) ToBytes() ([]byte, error) { - if r.err != nil { - return nil, r.err - } - if r.respBody != nil { - return r.respBody, nil - } - defer r.resp.Body.Close() - respBody, err := ioutil.ReadAll(r.resp.Body) - if err != nil { - r.err = err - return nil, err - } - r.respBody = respBody - return r.respBody, nil -} - -// String returns response body as string -func (r *Resp) String() string { - data, _ := r.ToBytes() - return string(data) -} - -// ToString returns response body as string, -// return error if error happened when reading -// the response body -func (r *Resp) ToString() (string, error) { - data, err := r.ToBytes() - return string(data), err -} - -// ToJSON convert json response body to struct or map -func (r *Resp) ToJSON(v interface{}) error { - data, err := r.ToBytes() - if err != nil { - return err - } - return json.Unmarshal(data, v) -} - -// ToXML convert xml response body to struct or map -func (r *Resp) ToXML(v interface{}) error { - data, err := r.ToBytes() - if err != nil { - return err - } - return xml.Unmarshal(data, v) -} - -// ToFile download the response body to file with optional download callback -func (r *Resp) ToFile(name string) error { - //TODO set name to the suffix of url path if name == "" - file, err := os.Create(name) - if err != nil { - return err - } - defer file.Close() - - if r.respBody != nil { - _, err = file.Write(r.respBody) - return err - } - - if r.downloadProgress != nil && r.resp.ContentLength > 0 { - return r.download(file) - } - - defer r.resp.Body.Close() - _, err = io.Copy(file, r.resp.Body) - return err -} - -func (r *Resp) download(file *os.File) error { - p := make([]byte, 1024) - b := r.resp.Body - defer b.Close() - total := r.resp.ContentLength - var current int64 - var lastTime time.Time - - defer func() { - r.downloadProgress(current, total) - }() - - for { - l, err := b.Read(p) - if l > 0 { - _, _err := file.Write(p[:l]) - if _err != nil { - return _err - } - current += int64(l) - if now := time.Now(); now.Sub(lastTime) > r.r.progressInterval { - lastTime = now - r.downloadProgress(current, total) - } - } - if err != nil { - if err == io.EOF { - return nil - } - return err - } - } -} - -var regNewline = regexp.MustCompile(`\n|\r`) - -func (r *Resp) autoFormat(s fmt.State) { - req := r.req - if r.r.flag&Lcost != 0 { - fmt.Fprint(s, req.Method, " ", req.URL.String(), " ", r.cost) - } else { - fmt.Fprint(s, req.Method, " ", req.URL.String()) - } - - // test if it is should be outputed pretty - var pretty bool - var parts []string - addPart := func(part string) { - if part == "" { - return - } - parts = append(parts, part) - if !pretty && regNewline.MatchString(part) { - pretty = true - } - } - if r.r.flag&LreqBody != 0 { // request body - addPart(string(r.reqBody)) - } - if r.r.flag&LrespBody != 0 { // response body - addPart(r.String()) - } - - for _, part := range parts { - if pretty { - fmt.Fprint(s, "\n") - } - fmt.Fprint(s, " ", part) - } -} - -func (r *Resp) miniFormat(s fmt.State) { - req := r.req - if r.r.flag&Lcost != 0 { - fmt.Fprint(s, req.Method, " ", req.URL.String(), " ", r.cost) - } else { - fmt.Fprint(s, req.Method, " ", req.URL.String()) - } - if r.r.flag&LreqBody != 0 && len(r.reqBody) > 0 { // request body - str := regNewline.ReplaceAllString(string(r.reqBody), " ") - fmt.Fprint(s, " ", str) - } - if r.r.flag&LrespBody != 0 && r.String() != "" { // response body - str := regNewline.ReplaceAllString(r.String(), " ") - fmt.Fprint(s, " ", str) - } -} - -// Format fort the response -func (r *Resp) Format(s fmt.State, verb rune) { - if r == nil || r.req == nil { - return - } - if s.Flag('+') { // include header and format pretty. - fmt.Fprint(s, r.Dump()) - } else if s.Flag('-') { // keep all informations in one line. - r.miniFormat(s) - } else { // auto - r.autoFormat(s) - } -} diff --git a/resp_test.go b/resp_test.go deleted file mode 100644 index 6e881a3b..00000000 --- a/resp_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package req - -import ( - "encoding/json" - "encoding/xml" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestToJSON(t *testing.T) { - type Result struct { - Code int `json:"code"` - Msg string `json:"msg"` - } - r1 := Result{ - Code: 1, - Msg: "ok", - } - handler := func(w http.ResponseWriter, r *http.Request) { - data, _ := json.Marshal(&r1) - w.Write(data) - } - ts := httptest.NewServer(http.HandlerFunc(handler)) - r, err := Get(ts.URL) - if err != nil { - t.Fatal(err) - } - var r2 Result - err = r.ToJSON(&r2) - if err != nil { - t.Fatal(err) - } - if r1 != r2 { - t.Errorf("json response body = %+v; want = %+v", r2, r1) - } -} - -func TestToXML(t *testing.T) { - type Result struct { - XMLName xml.Name - Code int `xml:"code"` - Msg string `xml:"msg"` - } - r1 := Result{ - XMLName: xml.Name{Local: "result"}, - Code: 1, - Msg: "ok", - } - handler := func(w http.ResponseWriter, r *http.Request) { - data, _ := xml.Marshal(&r1) - w.Write(data) - } - ts := httptest.NewServer(http.HandlerFunc(handler)) - r, err := Get(ts.URL) - if err != nil { - t.Fatal(err) - } - var r2 Result - err = r.ToXML(&r2) - if err != nil { - t.Fatal(err) - } - if r1 != r2 { - t.Errorf("xml response body = %+v; want = %+v", r2, r1) - } -} - -func TestFormat(t *testing.T) { - SetFlags(LstdFlags | Lcost) - reqHeader := "Request-Header" - respHeader := "Response-Header" - reqBody := "request body" - respBody1 := "response body 1" - respBody2 := "response body 2" - respBody := fmt.Sprintf("%s\n%s", respBody1, respBody2) - handler := func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(respHeader, "req") - w.Write([]byte(respBody)) - } - ts := httptest.NewServer(http.HandlerFunc(handler)) - - // %v - r, err := Post(ts.URL, reqBody, Header{reqHeader: "hello"}) - if err != nil { - t.Fatal(err) - } - str := fmt.Sprintf("%v", r) - for _, keyword := range []string{ts.URL, reqBody, respBody} { - if !strings.Contains(str, keyword) { - t.Errorf("format %%v output lack of part, want: %s", keyword) - } - } - - // %-v - str = fmt.Sprintf("%-v", r) - for _, keyword := range []string{ts.URL, respBody1 + " " + respBody2} { - if !strings.Contains(str, keyword) { - t.Errorf("format %%-v output lack of part, want: %s", keyword) - } - } - - // %+v - str = fmt.Sprintf("%+v", r) - for _, keyword := range []string{reqBody, respBody, reqHeader, respHeader} { - if !strings.Contains(str, keyword) { - t.Errorf("format %%+v output lack of part, want: %s", keyword) - } - } -} - -func TestBytesAndString(t *testing.T) { - respBody := "response body" - handler := func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(respBody)) - } - ts := httptest.NewServer(http.HandlerFunc(handler)) - r, err := Get(ts.URL) - if err != nil { - t.Fatal(err) - } - if string(r.Bytes()) != respBody { - t.Errorf("response body = %s; want = %s", r.Bytes(), respBody) - } - if r.String() != respBody { - t.Errorf("response body = %s; want = %s", r.String(), respBody) - } -} diff --git a/response.go b/response.go new file mode 100644 index 00000000..b36718a1 --- /dev/null +++ b/response.go @@ -0,0 +1,66 @@ +package req + +import ( + "net/http" + "strings" +) + +type ResponseOptions struct { + // DisableAutoDecode, if true, prevents auto detect response + // body's charset and decode it to utf-8 + DisableAutoDecode bool + + // AutoDecodeContentType specifies an optional function for determine + // whether the response body should been auto decode to utf-8. + // Only valid when DisableAutoDecode is true. + AutoDecodeContentType func(contentType string) bool + + Discard bool +} + +type ResponseOption func(o *ResponseOptions) + +func DiscardBody() ResponseOption { + return func(o *ResponseOptions) { + o.Discard = true + } +} + +// DisableAutoDecode disable the response body auto-decode to improve performance. +func DisableAutoDecode() ResponseOption { + return func(o *ResponseOptions) { + o.DisableAutoDecode = true + } +} + +// AutoDecodeContentTypeFunc customize the function to determine whether response +// body should auto decode with specified content type. +func AutoDecodeContentTypeFunc(fn func(contentType string) bool) ResponseOption { + return func(o *ResponseOptions) { + o.AutoDecodeContentType = fn + } +} + +// AutoDecodeContentType specifies that the response body should been auto-decoded +// when content type contains keywords that here given. +func AutoDecodeContentType(contentTypes ...string) ResponseOption { + return func(o *ResponseOptions) { + o.AutoDecodeContentType = func(contentType string) bool { + for _, t := range contentTypes { + if strings.Contains(contentType, t) { + return true + } + } + return false + } + } +} + +type Response struct { + *http.Response + request *Request +} + +func (r *Response) Body() Body { + return Body{r.Response.Body, r.Response} +} diff --git a/roundtrip.go b/roundtrip.go new file mode 100644 index 00000000..a8a01996 --- /dev/null +++ b/roundtrip.go @@ -0,0 +1,25 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !js || !wasm + +package req + +import "net/http" + +// RoundTrip implements the RoundTripper interface. +// +// For higher-level HTTP client support (such as handling of cookies +// and redirects), see Get, Post, and the Client type. +// +// Like the RoundTripper interface, the error types returned +// by RoundTrip are unspecified. +func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + resp, err = t.roundTrip(req) + if err != nil { + return + } + t.handleResponseBody(resp) + return +} diff --git a/setting.go b/setting.go deleted file mode 100644 index ee771e07..00000000 --- a/setting.go +++ /dev/null @@ -1,248 +0,0 @@ -package req - -import ( - "crypto/tls" - "errors" - "net" - "net/http" - "net/http/cookiejar" - "net/url" - "time" -) - -// create a default client -func newClient() *http.Client { - jar, _ := cookiejar.New(nil) - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - return &http.Client{ - Jar: jar, - Transport: transport, - Timeout: 2 * time.Minute, - } -} - -// Client return the default underlying http client -func (r *Req) Client() *http.Client { - if r.client == nil { - r.client = newClient() - } - return r.client -} - -// Client return the default underlying http client -func Client() *http.Client { - return std.Client() -} - -// SetClient sets the underlying http.Client. -func (r *Req) SetClient(client *http.Client) { - r.client = client // use default if client == nil -} - -// SetClient sets the default http.Client for requests. -func SetClient(client *http.Client) { - std.SetClient(client) -} - -// SetFlags control display format of *Resp -func (r *Req) SetFlags(flags int) { - r.flag = flags -} - -// SetFlags control display format of *Resp -func SetFlags(flags int) { - std.SetFlags(flags) -} - -// Flags return output format for the *Resp -func (r *Req) Flags() int { - return r.flag -} - -// Flags return output format for the *Resp -func Flags() int { - return std.Flags() -} - -func (r *Req) getTransport() *http.Transport { - trans, _ := r.Client().Transport.(*http.Transport) - return trans -} - -// EnableInsecureTLS allows insecure https -func (r *Req) EnableInsecureTLS(enable bool) { - trans := r.getTransport() - if trans == nil { - return - } - if trans.TLSClientConfig == nil { - trans.TLSClientConfig = &tls.Config{} - } - trans.TLSClientConfig.InsecureSkipVerify = enable -} - -func EnableInsecureTLS(enable bool) { - std.EnableInsecureTLS(enable) -} - -// EnableCookieenable or disable cookie manager -func (r *Req) EnableCookie(enable bool) { - if enable { - jar, _ := cookiejar.New(nil) - r.Client().Jar = jar - } else { - r.Client().Jar = nil - } -} - -// EnableCookieenable or disable cookie manager -func EnableCookie(enable bool) { - std.EnableCookie(enable) -} - -// SetTimeout sets the timeout for every request -func (r *Req) SetTimeout(d time.Duration) { - r.Client().Timeout = d -} - -// SetTimeout sets the timeout for every request -func SetTimeout(d time.Duration) { - std.SetTimeout(d) -} - -// SetProxyUrl set the simple proxy with fixed proxy url -func (r *Req) SetProxyUrl(rawurl string) error { - trans := r.getTransport() - if trans == nil { - return errors.New("req: no transport") - } - u, err := url.Parse(rawurl) - if err != nil { - return err - } - trans.Proxy = http.ProxyURL(u) - return nil -} - -// SetProxyUrl set the simple proxy with fixed proxy url -func SetProxyUrl(rawurl string) error { - return std.SetProxyUrl(rawurl) -} - -// SetProxy sets the proxy for every request -func (r *Req) SetProxy(proxy func(*http.Request) (*url.URL, error)) error { - trans := r.getTransport() - if trans == nil { - return errors.New("req: no transport") - } - trans.Proxy = proxy - return nil -} - -// SetProxy sets the proxy for every request -func SetProxy(proxy func(*http.Request) (*url.URL, error)) error { - return std.SetProxy(proxy) -} - -type jsonEncOpts struct { - indentPrefix string - indentValue string - escapeHTML bool -} - -func (r *Req) getJSONEncOpts() *jsonEncOpts { - if r.jsonEncOpts == nil { - r.jsonEncOpts = &jsonEncOpts{escapeHTML: true} - } - return r.jsonEncOpts -} - -// SetJSONEscapeHTML specifies whether problematic HTML characters -// should be escaped inside JSON quoted strings. -// The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e -// to avoid certain safety problems that can arise when embedding JSON in HTML. -// -// In non-HTML settings where the escaping interferes with the readability -// of the output, SetEscapeHTML(false) disables this behavior. -func (r *Req) SetJSONEscapeHTML(escape bool) { - opts := r.getJSONEncOpts() - opts.escapeHTML = escape -} - -// SetJSONEscapeHTML specifies whether problematic HTML characters -// should be escaped inside JSON quoted strings. -// The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e -// to avoid certain safety problems that can arise when embedding JSON in HTML. -// -// In non-HTML settings where the escaping interferes with the readability -// of the output, SetEscapeHTML(false) disables this behavior. -func SetJSONEscapeHTML(escape bool) { - std.SetJSONEscapeHTML(escape) -} - -// SetJSONIndent instructs the encoder to format each subsequent encoded -// value as if indented by the package-level function Indent(dst, src, prefix, indent). -// Calling SetIndent("", "") disables indentation. -func (r *Req) SetJSONIndent(prefix, indent string) { - opts := r.getJSONEncOpts() - opts.indentPrefix = prefix - opts.indentValue = indent -} - -// SetJSONIndent instructs the encoder to format each subsequent encoded -// value as if indented by the package-level function Indent(dst, src, prefix, indent). -// Calling SetIndent("", "") disables indentation. -func SetJSONIndent(prefix, indent string) { - std.SetJSONIndent(prefix, indent) -} - -type xmlEncOpts struct { - prefix string - indent string -} - -func (r *Req) getXMLEncOpts() *xmlEncOpts { - if r.xmlEncOpts == nil { - r.xmlEncOpts = &xmlEncOpts{} - } - return r.xmlEncOpts -} - -// SetXMLIndent sets the encoder to generate XML in which each element -// begins on a new indented line that starts with prefix and is followed by -// one or more copies of indent according to the nesting depth. -func (r *Req) SetXMLIndent(prefix, indent string) { - opts := r.getXMLEncOpts() - opts.prefix = prefix - opts.indent = indent -} - -// SetXMLIndent sets the encoder to generate XML in which each element -// begins on a new indented line that starts with prefix and is followed by -// one or more copies of indent according to the nesting depth. -func SetXMLIndent(prefix, indent string) { - std.SetXMLIndent(prefix, indent) -} - -// SetProgressInterval sets the progress reporting interval of both -// UploadProgress and DownloadProgress handler -func (r *Req) SetProgressInterval(interval time.Duration) { - r.progressInterval = interval -} - -// SetProgressInterval sets the progress reporting interval of both -// UploadProgress and DownloadProgress handler for the default client -func SetProgressInterval(interval time.Duration) { - std.SetProgressInterval(interval) -} diff --git a/setting_test.go b/setting_test.go deleted file mode 100644 index e71a6d7d..00000000 --- a/setting_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package req - -import ( - "net/http" - "net/http/httptest" - "testing" - "time" -) - -func newDefaultTestServer() *httptest.Server { - handler := func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("hi")) - } - return httptest.NewServer(http.HandlerFunc(handler)) -} - -func TestSetClient(t *testing.T) { - - ts := newDefaultTestServer() - - client := &http.Client{} - SetClient(client) - _, err := Get(ts.URL) - if err != nil { - t.Errorf("error after set client: %v", err) - } - - SetClient(nil) - _, err = Get(ts.URL) - if err != nil { - t.Errorf("error after set client to nil: %v", err) - } - - client = Client() - if trans, ok := client.Transport.(*http.Transport); ok { - trans.MaxIdleConns = 1 - trans.DisableKeepAlives = true - _, err = Get(ts.URL) - if err != nil { - t.Errorf("error after change client's transport: %v", err) - } - } else { - t.Errorf("transport is not http.Transport: %+#v", client.Transport) - } -} - -func TestSetting(t *testing.T) { - defer func() { - if rc := recover(); rc != nil { - t.Errorf("panic happened while change setting: %v", rc) - } - }() - SetTimeout(2 * time.Second) - EnableCookie(false) - EnableCookie(true) - EnableInsecureTLS(true) - SetJSONIndent("", " ") - SetJSONEscapeHTML(false) - SetXMLIndent("", "\t") - SetProxyUrl("http://localhost:8080") - SetProxy(nil) -} diff --git a/socks_bundle.go b/socks_bundle.go new file mode 100644 index 00000000..940e0bb3 --- /dev/null +++ b/socks_bundle.go @@ -0,0 +1,473 @@ +// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. +//go:generate bundle -o socks_bundle.go -prefix socks golang.org/x/net/internal/socks + +// Package socks provides a SOCKS version 5 client implementation. +// +// SOCKS protocol version 5 is defined in RFC 1928. +// Username/Password authentication for SOCKS version 5 is defined in +// RFC 1929. +// + +package req + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "time" +) + +var ( + socksnoDeadline = time.Time{} + socksaLongTimeAgo = time.Unix(1, 0) +) + +func (d *socksDialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) { + host, port, err := sockssplitHostPort(address) + if err != nil { + return nil, err + } + if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { + c.SetDeadline(deadline) + defer c.SetDeadline(socksnoDeadline) + } + if ctx != context.Background() { + errCh := make(chan error, 1) + done := make(chan struct{}) + defer func() { + close(done) + if ctxErr == nil { + ctxErr = <-errCh + } + }() + go func() { + select { + case <-ctx.Done(): + c.SetDeadline(socksaLongTimeAgo) + errCh <- ctx.Err() + case <-done: + errCh <- nil + } + }() + } + + b := make([]byte, 0, 6+len(host)) // the size here is just an estimate + b = append(b, socksVersion5) + if len(d.AuthMethods) == 0 || d.Authenticate == nil { + b = append(b, 1, byte(socksAuthMethodNotRequired)) + } else { + ams := d.AuthMethods + if len(ams) > 255 { + return nil, errors.New("too many authentication methods") + } + b = append(b, byte(len(ams))) + for _, am := range ams { + b = append(b, byte(am)) + } + } + if _, ctxErr = c.Write(b); ctxErr != nil { + return + } + + if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil { + return + } + if b[0] != socksVersion5 { + return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + } + am := socksAuthMethod(b[1]) + if am == socksAuthMethodNoAcceptableMethods { + return nil, errors.New("no acceptable authentication methods") + } + if d.Authenticate != nil { + if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil { + return + } + } + + b = b[:0] + b = append(b, socksVersion5, byte(d.cmd), 0) + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + b = append(b, socksAddrTypeIPv4) + b = append(b, ip4...) + } else if ip6 := ip.To16(); ip6 != nil { + b = append(b, socksAddrTypeIPv6) + b = append(b, ip6...) + } else { + return nil, errors.New("unknown address type") + } + } else { + if len(host) > 255 { + return nil, errors.New("FQDN too long") + } + b = append(b, socksAddrTypeFQDN) + b = append(b, byte(len(host))) + b = append(b, host...) + } + b = append(b, byte(port>>8), byte(port)) + if _, ctxErr = c.Write(b); ctxErr != nil { + return + } + + if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil { + return + } + if b[0] != socksVersion5 { + return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + } + if cmdErr := socksReply(b[1]); cmdErr != socksStatusSucceeded { + return nil, errors.New("unknown error " + cmdErr.String()) + } + if b[2] != 0 { + return nil, errors.New("non-zero reserved field") + } + l := 2 + var a socksAddr + switch b[3] { + case socksAddrTypeIPv4: + l += net.IPv4len + a.IP = make(net.IP, net.IPv4len) + case socksAddrTypeIPv6: + l += net.IPv6len + a.IP = make(net.IP, net.IPv6len) + case socksAddrTypeFQDN: + if _, err := io.ReadFull(c, b[:1]); err != nil { + return nil, err + } + l += int(b[0]) + default: + return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3]))) + } + if cap(b) < l { + b = make([]byte, l) + } else { + b = b[:l] + } + if _, ctxErr = io.ReadFull(c, b); ctxErr != nil { + return + } + if a.IP != nil { + copy(a.IP, b) + } else { + a.Name = string(b[:len(b)-2]) + } + a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1]) + return &a, nil +} + +func sockssplitHostPort(address string) (string, int, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return "", 0, err + } + portnum, err := strconv.Atoi(port) + if err != nil { + return "", 0, err + } + if 1 > portnum || portnum > 0xffff { + return "", 0, errors.New("port number out of range " + port) + } + return host, portnum, nil +} + +// A Command represents a SOCKS command. +type socksCommand int + +func (cmd socksCommand) String() string { + switch cmd { + case socksCmdConnect: + return "socks connect" + case sockscmdBind: + return "socks bind" + default: + return "socks " + strconv.Itoa(int(cmd)) + } +} + +// An AuthMethod represents a SOCKS authentication method. +type socksAuthMethod int + +// A Reply represents a SOCKS command reply code. +type socksReply int + +func (code socksReply) String() string { + switch code { + case socksStatusSucceeded: + return "succeeded" + case 0x01: + return "general SOCKS server failure" + case 0x02: + return "connection not allowed by ruleset" + case 0x03: + return "network unreachable" + case 0x04: + return "host unreachable" + case 0x05: + return "connection refused" + case 0x06: + return "TTL expired" + case 0x07: + return "command not supported" + case 0x08: + return "address type not supported" + default: + return "unknown code: " + strconv.Itoa(int(code)) + } +} + +// Wire protocol constants. +const ( + socksVersion5 = 0x05 + + socksAddrTypeIPv4 = 0x01 + socksAddrTypeFQDN = 0x03 + socksAddrTypeIPv6 = 0x04 + + socksCmdConnect socksCommand = 0x01 // establishes an active-open forward proxy connection + sockscmdBind socksCommand = 0x02 // establishes a passive-open forward proxy connection + + socksAuthMethodNotRequired socksAuthMethod = 0x00 // no authentication required + socksAuthMethodUsernamePassword socksAuthMethod = 0x02 // use username/password + socksAuthMethodNoAcceptableMethods socksAuthMethod = 0xff // no acceptable authentication methods + + socksStatusSucceeded socksReply = 0x00 +) + +// An Addr represents a SOCKS-specific address. +// Either Name or IP is used exclusively. +type socksAddr struct { + Name string // fully-qualified domain name + IP net.IP + Port int +} + +func (a *socksAddr) Network() string { return "socks" } + +func (a *socksAddr) String() string { + if a == nil { + return "" + } + port := strconv.Itoa(a.Port) + if a.IP == nil { + return net.JoinHostPort(a.Name, port) + } + return net.JoinHostPort(a.IP.String(), port) +} + +// A Conn represents a forward proxy connection. +type socksConn struct { + net.Conn + + boundAddr net.Addr +} + +// BoundAddr returns the address assigned by the proxy server for +// connecting to the command target address from the proxy server. +func (c *socksConn) BoundAddr() net.Addr { + if c == nil { + return nil + } + return c.boundAddr +} + +// A Dialer holds SOCKS-specific options. +type socksDialer struct { + cmd socksCommand // either CmdConnect or cmdBind + proxyNetwork string // network between a proxy server and a client + proxyAddress string // proxy server address + + // ProxyDial specifies the optional dial function for + // establishing the transport connection. + ProxyDial func(context.Context, string, string) (net.Conn, error) + + // AuthMethods specifies the list of request authentication + // methods. + // If empty, SOCKS client requests only AuthMethodNotRequired. + AuthMethods []socksAuthMethod + + // Authenticate specifies the optional authentication + // function. It must be non-nil when AuthMethods is not empty. + // It must return an error when the authentication is failed. + Authenticate func(context.Context, io.ReadWriter, socksAuthMethod) error +} + +// DialContext connects to the provided address on the provided +// network. +// +// The returned error value may be a net.OpError. When the Op field of +// net.OpError contains "socks", the Source field contains a proxy +// server address and the Addr field contains a command target +// address. +// +// See func Dial of the net package of standard library for a +// description of the network and address parameters. +func (d *socksDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if ctx == nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} + } + var err error + var c net.Conn + if d.ProxyDial != nil { + c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress) + } else { + var dd net.Dialer + c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress) + } + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + a, err := d.connect(ctx, c, address) + if err != nil { + c.Close() + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + return &socksConn{Conn: c, boundAddr: a}, nil +} + +// DialWithConn initiates a connection from SOCKS server to the target +// network and address using the connection c that is already +// connected to the SOCKS server. +// +// It returns the connection's local address assigned by the SOCKS +// server. +func (d *socksDialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if ctx == nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} + } + a, err := d.connect(ctx, c, address) + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + return a, nil +} + +// Dial connects to the provided address on the provided network. +// +// Unlike DialContext, it returns a raw transport connection instead +// of a forward proxy connection. +// +// Deprecated: Use DialContext or DialWithConn instead. +func (d *socksDialer) Dial(network, address string) (net.Conn, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + var err error + var c net.Conn + if d.ProxyDial != nil { + c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress) + } else { + c, err = net.Dial(d.proxyNetwork, d.proxyAddress) + } + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil { + c.Close() + return nil, err + } + return c, nil +} + +func (d *socksDialer) validateTarget(network, address string) error { + switch network { + case "tcp", "tcp6", "tcp4": + default: + return errors.New("network not implemented") + } + switch d.cmd { + case socksCmdConnect, sockscmdBind: + default: + return errors.New("command not implemented") + } + return nil +} + +func (d *socksDialer) pathAddrs(address string) (proxy, dst net.Addr, err error) { + for i, s := range []string{d.proxyAddress, address} { + host, port, err := sockssplitHostPort(s) + if err != nil { + return nil, nil, err + } + a := &socksAddr{Port: port} + a.IP = net.ParseIP(host) + if a.IP == nil { + a.Name = host + } + if i == 0 { + proxy = a + } else { + dst = a + } + } + return +} + +// NewDialer returns a new Dialer that dials through the provided +// proxy server's network and address. +func socksNewDialer(network, address string) *socksDialer { + return &socksDialer{proxyNetwork: network, proxyAddress: address, cmd: socksCmdConnect} +} + +const ( + socksauthUsernamePasswordVersion = 0x01 + socksauthStatusSucceeded = 0x00 +) + +// UsernamePassword are the credentials for the username/password +// authentication method. +type socksUsernamePassword struct { + Username string + Password string +} + +// Authenticate authenticates a pair of username and password with the +// proxy server. +func (up *socksUsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth socksAuthMethod) error { + switch auth { + case socksAuthMethodNotRequired: + return nil + case socksAuthMethodUsernamePassword: + if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) == 0 || len(up.Password) > 255 { + return errors.New("invalid username/password") + } + b := []byte{socksauthUsernamePasswordVersion} + b = append(b, byte(len(up.Username))) + b = append(b, up.Username...) + b = append(b, byte(len(up.Password))) + b = append(b, up.Password...) + // TODO(mikio): handle IO deadlines and cancelation if + // necessary + if _, err := rw.Write(b); err != nil { + return err + } + if _, err := io.ReadFull(rw, b[:2]); err != nil { + return err + } + if b[0] != socksauthUsernamePasswordVersion { + return errors.New("invalid username/password version") + } + if b[1] != socksauthStatusSucceeded { + return errors.New("username/password authentication failed") + } + return nil + } + return errors.New("unsupported authentication method " + strconv.Itoa(int(auth))) +} diff --git a/textproto_reader.go b/textproto_reader.go new file mode 100644 index 00000000..12cec208 --- /dev/null +++ b/textproto_reader.go @@ -0,0 +1,844 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/textproto" + "strconv" + "strings" + "sync" +) + +// An Error represents a numeric error response from a server. +type codeError struct { + Code int + Msg string +} + +func (e *codeError) Error() string { + return fmt.Sprintf("%03d %s", e.Code, e.Msg) +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +func isASCIILetter(b byte) bool { + b |= 0x20 // make lower case + return 'a' <= b && b <= 'z' +} + +// A textprotoReader implements convenience methods for reading requests +// or responses from a text protocol network connection. +type textprotoReader struct { + R *bufio.Reader + dot *dotReader + buf []byte // a re-usable buffer for readContinuedLineSlice + readLine func() (line []byte, isPrefix bool, err error) +} + +// NewReader returns a new textprotoReader reading from r. +// +// To avoid denial of service attacks, the provided bufio.Reader +// should be reading from an io.LimitReader or similar textprotoReader to bound +// the size of responses. +func newTextprotoReader(r *bufio.Reader, dump *dumper) *textprotoReader { + commonHeaderOnce.Do(initCommonHeader) + t := &textprotoReader{R: r} + if dump != nil && dump.ResponseHead { + t.readLine = func() (line []byte, isPrefix bool, err error) { + line, err = t.R.ReadSlice('\n') + if len(line) == 0 { + if err != nil { + line = nil + } + return + } + err = nil + dump.dump(line) + if line[len(line)-1] == '\n' { + drop := 1 + if len(line) > 1 && line[len(line)-2] == '\r' { + drop = 2 + } + line = line[:len(line)-drop] + } + return + } + } else { + t.readLine = t.R.ReadLine + } + return t +} + +// ReadLine reads a single line from r, +// eliding the final \n or \r\n from the returned string. +func (r *textprotoReader) ReadLine() (string, error) { + line, err := r.readLineSlice() + return string(line), err +} + +// ReadLineBytes is like ReadLine but returns a []byte instead of a string. +func (r *textprotoReader) ReadLineBytes() ([]byte, error) { + line, err := r.readLineSlice() + if line != nil { + buf := make([]byte, len(line)) + copy(buf, line) + line = buf + } + return line, err +} + +func (r *textprotoReader) readLineSlice() ([]byte, error) { + r.closeDot() + var line []byte + + for { + l, more, err := r.readLine() + if err != nil { + return nil, err + } + // Avoid the copy if the first call produced a full line. + if line == nil && !more { + return l, nil + } + line = append(line, l...) + if !more { + break + } + } + return line, nil +} + +// ReadContinuedLine reads a possibly continued line from r, +// eliding the final trailing ASCII white space. +// Lines after the first are considered continuations if they +// begin with a space or tab character. In the returned data, +// continuation lines are separated from the previous line +// only by a single space: the newline and leading white space +// are removed. +// +// For example, consider this input: +// +// Line 1 +// continued... +// Line 2 +// +// The first call to ReadContinuedLine will return "Line 1 continued..." +// and the second will return "Line 2". +// +// Empty lines are never continued. +// +func (r *textprotoReader) ReadContinuedLine() (string, error) { + line, err := r.readContinuedLineSlice(noValidation) + return string(line), err +} + +// trim returns s with leading and trailing spaces and tabs removed. +// It does not assume Unicode or UTF-8. +func trim(s []byte) []byte { + i := 0 + for i < len(s) && (s[i] == ' ' || s[i] == '\t') { + i++ + } + n := len(s) + for n > i && (s[n-1] == ' ' || s[n-1] == '\t') { + n-- + } + return s[i:n] +} + +// ReadContinuedLineBytes is like ReadContinuedLine but +// returns a []byte instead of a string. +func (r *textprotoReader) ReadContinuedLineBytes() ([]byte, error) { + line, err := r.readContinuedLineSlice(noValidation) + if line != nil { + buf := make([]byte, len(line)) + copy(buf, line) + line = buf + } + return line, err +} + +// readContinuedLineSlice reads continued lines from the reader buffer, +// returning a byte slice with all lines. The validateFirstLine function +// is run on the first read line, and if it returns an error then this +// error is returned from readContinuedLineSlice. +func (r *textprotoReader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) { + if validateFirstLine == nil { + return nil, fmt.Errorf("missing validateFirstLine func") + } + + // Read the first line. + line, err := r.readLineSlice() + if err != nil { + return nil, err + } + if len(line) == 0 { // blank line - no continuation + return line, nil + } + + if err := validateFirstLine(line); err != nil { + return nil, err + } + + // Optimistically assume that we have started to buffer the next line + // and it starts with an ASCII letter (the next header key), or a blank + // line, so we can avoid copying that buffered data around in memory + // and skipping over non-existent whitespace. + if r.R.Buffered() > 1 { + peek, _ := r.R.Peek(2) + if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') || + len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' { + return trim(line), nil + } + } + + // ReadByte or the next readLineSlice will flush the read buffer; + // copy the slice into buf. + r.buf = append(r.buf[:0], trim(line)...) + + // Read continuation lines. + for r.skipSpace() > 0 { + line, err := r.readLineSlice() + if err != nil { + break + } + r.buf = append(r.buf, ' ') + r.buf = append(r.buf, trim(line)...) + } + return r.buf, nil +} + +// skipSpace skips R over all spaces and returns the number of bytes skipped. +func (r *textprotoReader) skipSpace() int { + n := 0 + for { + c, err := r.R.ReadByte() + if err != nil { + // Bufio will keep err until next read. + break + } + if c != ' ' && c != '\t' { + r.R.UnreadByte() + break + } + n++ + } + return n +} + +func (r *textprotoReader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) { + line, err := r.ReadLine() + if err != nil { + return + } + return parseCodeLine(line, expectCode) +} + +// A protocolError describes a protocol violation such +// as an invalid response or a hung-up connection. +type protocolError string + +func (p protocolError) Error() string { + return string(p) +} + +func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) { + if len(line) < 4 || line[3] != ' ' && line[3] != '-' { + err = protocolError("short response: " + line) + return + } + continued = line[3] == '-' + code, err = strconv.Atoi(line[0:3]) + if err != nil || code < 100 { + err = protocolError("invalid response code: " + line) + return + } + message = line[4:] + if 1 <= expectCode && expectCode < 10 && code/100 != expectCode || + 10 <= expectCode && expectCode < 100 && code/10 != expectCode || + 100 <= expectCode && expectCode < 1000 && code != expectCode { + err = &codeError{code, message} + } + return +} + +// ReadCodeLine reads a response code line of the form +// code message +// where code is a three-digit status code and the message +// extends to the rest of the line. An example of such a line is: +// 220 plan9.bell-labs.com ESMTP +// +// If the prefix of the status does not match the digits in expectCode, +// ReadCodeLine returns with err set to &codeError{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// If the response is multi-line, ReadCodeLine returns an error. +// +// An expectCode <= 0 disables the check of the status code. +// +func (r *textprotoReader) ReadCodeLine(expectCode int) (code int, message string, err error) { + code, continued, message, err := r.readCodeLine(expectCode) + if err == nil && continued { + err = protocolError("unexpected multi-line response: " + message) + } + return +} + +// ReadResponse reads a multi-line response of the form: +// +// code-message line 1 +// code-message line 2 +// ... +// code message line n +// +// where code is a three-digit status code. The first line starts with the +// code and a hyphen. The response is terminated by a line that starts +// with the same code followed by a space. Each line in message is +// separated by a newline (\n). +// +// See page 36 of RFC 959 (https://www.ietf.org/rfc/rfc959.txt) for +// details of another form of response accepted: +// +// code-message line 1 +// message line 2 +// ... +// code message line n +// +// If the prefix of the status does not match the digits in expectCode, +// ReadResponse returns with err set to &codeError{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// An expectCode <= 0 disables the check of the status code. +// +func (r *textprotoReader) ReadResponse(expectCode int) (code int, message string, err error) { + code, continued, message, err := r.readCodeLine(expectCode) + multi := continued + for continued { + line, err := r.ReadLine() + if err != nil { + return 0, "", err + } + + var code2 int + var moreMessage string + code2, continued, moreMessage, err = parseCodeLine(line, 0) + if err != nil || code2 != code { + message += "\n" + strings.TrimRight(line, "\r\n") + continued = true + continue + } + message += "\n" + moreMessage + } + if err != nil && multi && message != "" { + // replace one line error message with all lines (full message) + err = &codeError{code, message} + } + return +} + +// DotReader returns a new textprotoReader that satisfies Reads using the +// decoded text of a dot-encoded block read from r. +// The returned textprotoReader is only valid until the next call +// to a method on r. +// +// Dot encoding is a common framing used for data blocks +// in text protocols such as SMTP. The data consists of a sequence +// of lines, each of which ends in "\r\n". The sequence itself +// ends at a line containing just a dot: ".\r\n". Lines beginning +// with a dot are escaped with an additional dot to avoid +// looking like the end of the sequence. +// +// The decoded form returned by the textprotoReader's Read method +// rewrites the "\r\n" line endings into the simpler "\n", +// removes leading dot escapes if present, and stops with error io.EOF +// after consuming (and discarding) the end-of-sequence line. +func (r *textprotoReader) DotReader() io.Reader { + r.closeDot() + r.dot = &dotReader{r: r} + return r.dot +} + +type dotReader struct { + r *textprotoReader + state int +} + +// Read satisfies reads by decoding dot-encoded data read from d.r. +func (d *dotReader) Read(b []byte) (n int, err error) { + // Run data through a simple state machine to + // elide leading dots, rewrite trailing \r\n into \n, + // and detect ending .\r\n line. + const ( + stateBeginLine = iota // beginning of line; initial state; must be zero + stateDot // read . at beginning of line + stateDotCR // read .\r at beginning of line + stateCR // read \r (possibly at end of line) + stateData // reading data in middle of line + stateEOF // reached .\r\n end marker line + ) + br := d.r.R + for n < len(b) && d.state != stateEOF { + var c byte + c, err = br.ReadByte() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + break + } + switch d.state { + case stateBeginLine: + if c == '.' { + d.state = stateDot + continue + } + if c == '\r' { + d.state = stateCR + continue + } + d.state = stateData + + case stateDot: + if c == '\r' { + d.state = stateDotCR + continue + } + if c == '\n' { + d.state = stateEOF + continue + } + d.state = stateData + + case stateDotCR: + if c == '\n' { + d.state = stateEOF + continue + } + // Not part of .\r\n. + // Consume leading dot and emit saved \r. + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateCR: + if c == '\n' { + d.state = stateBeginLine + break + } + // Not part of \r\n. Emit saved \r + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateData: + if c == '\r' { + d.state = stateCR + continue + } + if c == '\n' { + d.state = stateBeginLine + } + } + b[n] = c + n++ + } + if err == nil && d.state == stateEOF { + err = io.EOF + } + if err != nil && d.r.dot == d { + d.r.dot = nil + } + return +} + +// closeDot drains the current DotReader if any, +// making sure that it reads until the ending dot line. +func (r *textprotoReader) closeDot() { + if r.dot == nil { + return + } + buf := make([]byte, 128) + for r.dot != nil { + // When Read reaches EOF or an error, + // it will set r.dot == nil. + r.dot.Read(buf) + } +} + +// ReadDotBytes reads a dot-encoding and returns the decoded data. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *textprotoReader) ReadDotBytes() ([]byte, error) { + return io.ReadAll(r.DotReader()) +} + +// ReadDotLines reads a dot-encoding and returns a slice +// containing the decoded lines, with the final \r\n or \n elided from each. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *textprotoReader) ReadDotLines() ([]string, error) { + // We could use ReadDotBytes and then Split it, + // but reading a line at a time avoids needing a + // large contiguous block of memory and is simpler. + var v []string + var err error + for { + var line string + line, err = r.ReadLine() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + break + } + + // Dot by itself marks end; otherwise cut one dot. + if len(line) > 0 && line[0] == '.' { + if len(line) == 1 { + break + } + line = line[1:] + } + v = append(v, line) + } + return v, err +} + +var colon = []byte(":") + +// Readtextproto.MIMEHeader reads a MIME-style header from r. +// The header is a sequence of possibly continued Key: Value lines +// ending in a blank line. +// The returned map m maps Canonicaltextproto.MIMEHeaderKey(key) to a +// sequence of values in the same order encountered in the input. +// +// For example, consider this input: +// +// My-Key: Value 1 +// Long-Key: Even +// Longer Value +// My-Key: Value 2 +// +// Given that input, Readtextproto.MIMEHeader returns the map: +// +// map[string][]string{ +// "My-Key": {"Value 1", "Value 2"}, +// "Long-Key": {"Even Longer Value"}, +// } +// +func (r *textprotoReader) ReadMIMEHeader() (textproto.MIMEHeader, error) { + // Avoid lots of small slice allocations later by allocating one + // large one ahead of time which we'll cut up into smaller + // slices. If this isn't big enough later, we allocate small ones. + var strs []string + hint := r.upcomingHeaderNewlines() + if hint > 0 { + strs = make([]string, hint) + } + + m := make(textproto.MIMEHeader, hint) + + // The first line cannot start with a leading space. + if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') { + line, err := r.readLineSlice() + if err != nil { + return m, err + } + return m, protocolError("malformed MIME header initial line: " + string(line)) + } + + for { + kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon) + if len(kv) == 0 { + return m, err + } + + // Key ends at first colon. + k, v, ok := bytes.Cut(kv, colon) + if !ok { + return m, protocolError("malformed MIME header line: " + string(kv)) + } + key := canonicalMIMEHeaderKey(k) + + // As per RFC 7230 field-name is a token, tokens consist of one or more chars. + // We could return a protocolError here, but better to be liberal in what we + // accept, so if we get an empty key, skip it. + if key == "" { + continue + } + + // Skip initial spaces in value. + value := strings.TrimLeft(string(v), " \t") + + vv := m[key] + if vv == nil && len(strs) > 0 { + // More than likely this will be a single-element key. + // Most headers aren't multi-valued. + // Set the capacity on strs[0] to 1, so any future append + // won't extend the slice into the other strings. + vv, strs = strs[:1:1], strs[1:] + vv[0] = value + m[key] = vv + } else { + m[key] = append(vv, value) + } + + if err != nil { + return m, err + } + } +} + +// noValidation is a no-op validation func for readContinuedLineSlice +// that permits any lines. +func noValidation(_ []byte) error { return nil } + +// mustHaveFieldNameColon ensures that, per RFC 7230, the +// field-name is on a single line, so the first line must +// contain a colon. +func mustHaveFieldNameColon(line []byte) error { + if bytes.IndexByte(line, ':') < 0 { + return protocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line)) + } + return nil +} + +var nl = []byte("\n") + +// upcomingHeaderNewlines returns an approximation of the number of newlines +// that will be in this header. If it gets confused, it returns 0. +func (r *textprotoReader) upcomingHeaderNewlines() (n int) { + // Try to determine the 'hint' size. + r.R.Peek(1) // force a buffer load if empty + s := r.R.Buffered() + if s == 0 { + return + } + peek, _ := r.R.Peek(s) + return bytes.Count(peek, nl) +} + +// Canonicaltextproto.MIMEHeaderKey returns the canonical format of the +// MIME header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// MIME header keys are assumed to be ASCII only. +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalMIMEHeaderKey(s string) string { + commonHeaderOnce.Do(initCommonHeader) + + // Quick check for canonical encoding. + upper := true + for i := 0; i < len(s); i++ { + c := s[i] + if !validHeaderFieldByte(c) { + return s + } + if upper && 'a' <= c && c <= 'z' { + return canonicalMIMEHeaderKey([]byte(s)) + } + if !upper && 'A' <= c && c <= 'Z' { + return canonicalMIMEHeaderKey([]byte(s)) + } + upper = c == '-' + } + return s +} + +const toLower = 'a' - 'A' + +// validHeaderFieldByte reports whether b is a valid byte in a header +// field name. RFC 7230 says: +// header-field = field-name ":" OWS field-value OWS +// field-name = token +// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / +// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA +// token = 1*tchar +func validHeaderFieldByte(b byte) bool { + return int(b) < len(isTokenTable) && isTokenTable[b] +} + +// canonicalMIMEHeaderKey is like Canonicaltextproto.MIMEHeaderKey but is +// allowed to mutate the provided byte slice before returning the +// string. +// +// For invalid inputs (if a contains spaces or non-token bytes), a +// is unchanged and a string copy is returned. +func canonicalMIMEHeaderKey(a []byte) string { + // See if a looks like a header key. If not, return it unchanged. + for _, c := range a { + if validHeaderFieldByte(c) { + continue + } + // Don't canonicalize. + return string(a) + } + + upper := true + for i, c := range a { + // Canonicalize: first letter upper case + // and upper case after each dash. + // (Host, User-Agent, If-Modified-Since). + // MIME headers are ASCII only, so no Unicode issues. + if upper && 'a' <= c && c <= 'z' { + c -= toLower + } else if !upper && 'A' <= c && c <= 'Z' { + c += toLower + } + a[i] = c + upper = c == '-' // for next time + } + // The compiler recognizes m[string(byteSlice)] as a special + // case, so a copy of a's bytes into a new string does not + // happen in this map lookup: + if v := commonHeader[string(a)]; v != "" { + return v + } + return string(a) +} + +// commonHeader interns common header strings. +var commonHeader map[string]string + +var commonHeaderOnce sync.Once + +func initCommonHeader() { + commonHeader = make(map[string]string) + for _, v := range []string{ + "Accept", + "Accept-Charset", + "Accept-Encoding", + "Accept-Language", + "Accept-Ranges", + "Cache-Control", + "Cc", + "Connection", + "Content-Id", + "Content-Language", + "Content-Length", + "Content-Transfer-Encoding", + "Content-Type", + "Cookie", + "Date", + "Dkim-Signature", + "Etag", + "Expires", + "From", + "Host", + "If-Modified-Since", + "If-None-Match", + "In-Reply-To", + "Last-Modified", + "Location", + "Message-Id", + "Mime-Version", + "Pragma", + "Received", + "Return-Path", + "Server", + "Set-Cookie", + "Subject", + "To", + "User-Agent", + "Via", + "X-Forwarded-For", + "X-Imforwards", + "X-Powered-By", + } { + commonHeader[v] = v + } +} + +// isTokenTable is a copy of net/http/lex.go's isTokenTable. +// See https://httpwg.github.io/specs/rfc7230.html#rule.token.separators +var isTokenTable = [127]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} diff --git a/transfer.go b/transfer.go new file mode 100644 index 00000000..88a23b5e --- /dev/null +++ b/transfer.go @@ -0,0 +1,1126 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "github.com/imroc/req/internal" + "github.com/imroc/req/internal/ascii" + "io" + "net/http" + "net/http/httptrace" + "net/textproto" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "golang.org/x/net/http/httpguts" +) + +// ErrLineTooLong is returned when reading request or response bodies +// with malformed chunked encoding. +var ErrLineTooLong = internal.ErrLineTooLong + +type errorReader struct { + err error +} + +func (r errorReader) Read(p []byte) (n int, err error) { + return 0, r.err +} + +type byteReader struct { + b byte + done bool +} + +func (br *byteReader) Read(p []byte) (n int, err error) { + if br.done { + return 0, io.EOF + } + if len(p) == 0 { + return 0, nil + } + br.done = true + p[0] = br.b + return 1, io.EOF +} + +// transferWriter inspects the fields of a user-supplied Request or Response, +// sanitizes them without changing the user object and provides methods for +// writing the respective header, body and trailer in wire format. +type transferWriter struct { + Method string + Body io.Reader + BodyCloser io.Closer + ResponseToHEAD bool + ContentLength int64 // -1 means unknown, 0 means exactly none + Close bool + TransferEncoding []string + Header http.Header + Trailer http.Header + IsResponse bool + bodyReadError error // any non-EOF error from reading Body + + FlushHeaders bool // flush headers to network before body + ByteReadCh chan readResult // non-nil if probeRequestBody called +} + +func newTransferWriter(r any) (t *transferWriter, err error) { + t = &transferWriter{} + + // Extract relevant fields + atLeastHTTP11 := false + switch rr := r.(type) { + case *http.Request: + if rr.ContentLength != 0 && rr.Body == nil { + return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength) + } + t.Method = valueOrDefault(rr.Method, "GET") + t.Close = rr.Close + t.TransferEncoding = rr.TransferEncoding + t.Header = rr.Header + t.Trailer = rr.Trailer + t.Body = rr.Body + t.BodyCloser = rr.Body + t.ContentLength = outgoingLength(rr) + if t.ContentLength < 0 && len(t.TransferEncoding) == 0 && t.shouldSendChunkedRequestBody() { + t.TransferEncoding = []string{"chunked"} + } + // If there's a body, conservatively flush the headers + // to any bufio.Writer we're writing to, just in case + // the server needs the headers early, before we copy + // the body and possibly block. We make an exception + // for the common standard library in-memory types, + // though, to avoid unnecessary TCP packets on the + // wire. (Issue 22088.) + if t.ContentLength != 0 && !isKnownInMemoryReader(t.Body) { + t.FlushHeaders = true + } + + atLeastHTTP11 = true // Transport requests are always 1.1 or 2.0 + case *http.Response: + t.IsResponse = true + if rr.Request != nil { + t.Method = rr.Request.Method + } + t.Body = rr.Body + t.BodyCloser = rr.Body + t.ContentLength = rr.ContentLength + t.Close = rr.Close + t.TransferEncoding = rr.TransferEncoding + t.Header = rr.Header + t.Trailer = rr.Trailer + atLeastHTTP11 = rr.ProtoAtLeast(1, 1) + t.ResponseToHEAD = noResponseBodyExpected(t.Method) + } + + // Sanitize Body,ContentLength,TransferEncoding + if t.ResponseToHEAD { + t.Body = nil + if chunked(t.TransferEncoding) { + t.ContentLength = -1 + } + } else { + if !atLeastHTTP11 || t.Body == nil { + t.TransferEncoding = nil + } + if chunked(t.TransferEncoding) { + t.ContentLength = -1 + } else if t.Body == nil { // no chunking, no body + t.ContentLength = 0 + } + } + + // Sanitize Trailer + if !chunked(t.TransferEncoding) { + t.Trailer = nil + } + + return t, nil +} + +// shouldSendChunkedRequestBody reports whether we should try to send a +// chunked request body to the server. In particular, the case we really +// want to prevent is sending a GET or other typically-bodyless request to a +// server with a chunked body when the body has zero bytes, since GETs with +// bodies (while acceptable according to specs), even zero-byte chunked +// bodies, are approximately never seen in the wild and confuse most +// servers. See Issue 18257, as one example. +// +// The only reason we'd send such a request is if the user set the Body to a +// non-nil value (say, io.NopCloser(bytes.NewReader(nil))) and didn't +// set ContentLength, or NewRequest set it to -1 (unknown), so then we assume +// there's bytes to send. +// +// This code tries to read a byte from the Request.Body in such cases to see +// whether the body actually has content (super rare) or is actually just +// a non-nil content-less ReadCloser (the more common case). In that more +// common case, we act as if their Body were nil instead, and don't send +// a body. +func (t *transferWriter) shouldSendChunkedRequestBody() bool { + // Note that t.ContentLength is the corrected content length + // from rr.outgoingLength, so 0 actually means zero, not unknown. + if t.ContentLength >= 0 || t.Body == nil { // redundant checks; caller did them + return false + } + if t.Method == "CONNECT" { + return false + } + if requestMethodUsuallyLacksBody(t.Method) { + // Only probe the Request.Body for GET/HEAD/DELETE/etc + // requests, because it's only those types of requests + // that confuse servers. + t.probeRequestBody() // adjusts t.Body, t.ContentLength + return t.Body != nil + } + // For all other request types (PUT, POST, PATCH, or anything + // made-up we've never heard of), assume it's normal and the server + // can deal with a chunked request body. Maybe we'll adjust this + // later. + return true +} + +// probeRequestBody reads a byte from t.Body to see whether it's empty +// (returns io.EOF right away). +// +// But because we've had problems with this blocking users in the past +// (issue 17480) when the body is a pipe (perhaps waiting on the response +// headers before the pipe is fed data), we need to be careful and bound how +// long we wait for it. This delay will only affect users if all the following +// are true: +// * the request body blocks +// * the content length is not set (or set to -1) +// * the method doesn't usually have a body (GET, HEAD, DELETE, ...) +// * there is no transfer-encoding=chunked already set. +// In other words, this delay will not normally affect anybody, and there +// are workarounds if it does. +func (t *transferWriter) probeRequestBody() { + t.ByteReadCh = make(chan readResult, 1) + go func(body io.Reader) { + var buf [1]byte + var rres readResult + rres.n, rres.err = body.Read(buf[:]) + if rres.n == 1 { + rres.b = buf[0] + } + t.ByteReadCh <- rres + close(t.ByteReadCh) + }(t.Body) + timer := time.NewTimer(200 * time.Millisecond) + select { + case rres := <-t.ByteReadCh: + timer.Stop() + if rres.n == 0 && rres.err == io.EOF { + // It was empty. + t.Body = nil + t.ContentLength = 0 + } else if rres.n == 1 { + if rres.err != nil { + t.Body = io.MultiReader(&byteReader{b: rres.b}, errorReader{rres.err}) + } else { + t.Body = io.MultiReader(&byteReader{b: rres.b}, t.Body) + } + } else if rres.err != nil { + t.Body = errorReader{rres.err} + } + case <-timer.C: + // Too slow. Don't wait. Read it later, and keep + // assuming that this is ContentLength == -1 + // (unknown), which means we'll send a + // "Transfer-Encoding: chunked" header. + t.Body = io.MultiReader(finishAsyncByteRead{t}, t.Body) + // Request that Request.Write flush the headers to the + // network before writing the body, since our body may not + // become readable until it's seen the response headers. + t.FlushHeaders = true + } +} + +func noResponseBodyExpected(requestMethod string) bool { + return requestMethod == "HEAD" +} + +func (t *transferWriter) shouldSendContentLength() bool { + if chunked(t.TransferEncoding) { + return false + } + if t.ContentLength > 0 { + return true + } + if t.ContentLength < 0 { + return false + } + // Many servers expect a Content-Length for these methods + if t.Method == "POST" || t.Method == "PUT" || t.Method == "PATCH" { + return true + } + if t.ContentLength == 0 && isIdentity(t.TransferEncoding) { + if t.Method == "GET" || t.Method == "HEAD" { + return false + } + return true + } + + return false +} + +func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) error { + if t.Close && !hasToken(headerGet(t.Header, "Connection"), "close") { + if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil { + return err + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("Connection", []string{"close"}) + } + } + + // Write Content-Length and/or Transfer-Encoding whose values are a + // function of the sanitized field triple (Body, ContentLength, + // TransferEncoding) + if t.shouldSendContentLength() { + if _, err := io.WriteString(w, "Content-Length: "); err != nil { + return err + } + if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil { + return err + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("Content-Length", []string{strconv.FormatInt(t.ContentLength, 10)}) + } + } else if chunked(t.TransferEncoding) { + if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil { + return err + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("Transfer-Encoding", []string{"chunked"}) + } + } + + // Write Trailer header + if t.Trailer != nil { + keys := make([]string, 0, len(t.Trailer)) + for k := range t.Trailer { + k = http.CanonicalHeaderKey(k) + switch k { + case "Transfer-Encoding", "Trailer", "Content-Length": + return badStringError("invalid Trailer key", k) + } + keys = append(keys, k) + } + if len(keys) > 0 { + sort.Strings(keys) + // TODO: could do better allocation-wise here, but trailers are rare, + // so being lazy for now. + if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil { + return err + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("Trailer", keys) + } + } + } + + return nil +} + +// always closes t.BodyCloser +func (t *transferWriter) writeBody(w io.Writer, dump *dumper) (err error) { + var ncopy int64 + closed := false + defer func() { + if closed || t.BodyCloser == nil { + return + } + if closeErr := t.BodyCloser.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() + + rw := w // raw writer + if dump != nil && dump.RequestBody { + w = dump.WrapWriter(w) + } + + // Write body. We "unwrap" the body first if it was wrapped in a + // nopCloser or readTrackingBody. This is to ensure that we can take advantage of + // OS-level optimizations in the event that the body is an + // *os.File. + if t.Body != nil { + var body = t.unwrapBody() + if chunked(t.TransferEncoding) { + if bw, ok := rw.(*bufio.Writer); ok && !t.IsResponse { + rw = &internal.FlushAfterChunkWriter{Writer: bw} + } + cw := internal.NewChunkedWriter(rw) + if dump != nil && dump.RequestBody { + cw = dump.WrapWriteCloser(cw) + } + _, err = t.doBodyCopy(cw, body) + if err == nil { + err = cw.Close() + } + } else if t.ContentLength == -1 { + dst := w + if t.Method == "CONNECT" { + dst = bufioFlushWriter{dst} + } + ncopy, err = t.doBodyCopy(dst, body) + } else { + ncopy, err = t.doBodyCopy(w, io.LimitReader(body, t.ContentLength)) + if err != nil { + return err + } + var nextra int64 + nextra, err = t.doBodyCopy(io.Discard, body) + ncopy += nextra + } + if err != nil { + return err + } + if dump != nil && dump.RequestBody { + dump.dump([]byte("\r\n")) + } + } + if t.BodyCloser != nil { + closed = true + if err := t.BodyCloser.Close(); err != nil { + return err + } + } + + if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy { + return fmt.Errorf("http: ContentLength=%d with Body length %d", + t.ContentLength, ncopy) + } + + if chunked(t.TransferEncoding) { + // Write Trailer header + if t.Trailer != nil { + if err := t.Trailer.Write(w); err != nil { + return err + } + } + // Last chunk, empty trailer + _, err = io.WriteString(w, "\r\n") + } + return err +} + +// doBodyCopy wraps a copy operation, with any resulting error also +// being saved in bodyReadError. +// +// This function is only intended for use in writeBody. +func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) { + n, err = io.Copy(dst, src) + if err != nil && err != io.EOF { + t.bodyReadError = err + } + return +} + +// unwrapBodyReader unwraps the body's inner reader if it's a +// nopCloser. This is to ensure that body writes sourced from local +// files (*os.File types) are properly optimized. +// +// This function is only intended for use in writeBody. +func (t *transferWriter) unwrapBody() io.Reader { + if reflect.TypeOf(t.Body) == nopCloserType { + return reflect.ValueOf(t.Body).Field(0).Interface().(io.Reader) + } + if r, ok := t.Body.(*readTrackingBody); ok { + r.didRead = true + return r.ReadCloser + } + return t.Body +} + +type transferReader struct { + // Input + Header http.Header + StatusCode int + RequestMethod string + ProtoMajor int + ProtoMinor int + // Output + Body io.ReadCloser + ContentLength int64 + Chunked bool + Close bool + Trailer http.Header +} + +func (t *transferReader) protoAtLeast(m, n int) bool { + return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) +} + +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 7230, section 3.3. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +var ( + suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"} + suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"} +) + +func suppressedHeaders(status int) []string { + switch { + case status == 304: + // RFC 7232 section 4.1 + return suppressedHeaders304 + case !bodyAllowedForStatus(status): + return suppressedHeadersNoBody + } + return nil +} + +// msg is *http.Request or *http.Response. +func readTransfer(msg any, r *bufio.Reader) (err error) { + t := &transferReader{RequestMethod: "GET"} + + // Unify input + isResponse := false + switch rr := msg.(type) { + case *http.Response: + t.Header = rr.Header + t.StatusCode = rr.StatusCode + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header, true) + isResponse = true + if rr.Request != nil { + t.RequestMethod = rr.Request.Method + } + case *http.Request: + t.Header = rr.Header + t.RequestMethod = rr.Method + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + // Transfer semantics for Requests are exactly like those for + // Responses with status code 200, responding to a GET method + t.StatusCode = 200 + t.Close = rr.Close + default: + panic("unexpected type") + } + + // Default to HTTP/1.1 + if t.ProtoMajor == 0 && t.ProtoMinor == 0 { + t.ProtoMajor, t.ProtoMinor = 1, 1 + } + + // Transfer-Encoding: chunked, and overriding Content-Length. + if err := t.parseTransferEncoding(); err != nil { + return err + } + + realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.Chunked) + if err != nil { + return err + } + if isResponse && t.RequestMethod == "HEAD" { + if n, err := parseContentLength(headerGet(t.Header, "Content-Length")); err != nil { + return err + } else { + t.ContentLength = n + } + } else { + t.ContentLength = realLength + } + + // Trailer + t.Trailer, err = fixTrailer(t.Header, t.Chunked) + if err != nil { + return err + } + + // If there is no Content-Length or chunked Transfer-Encoding on a *Response + // and the status is not 1xx, 204 or 304, then the body is unbounded. + // See RFC 7230, section 3.3. + switch msg.(type) { + case *Response: + if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) { + // Unbounded body. + t.Close = true + } + } + + // Prepare body reader. ContentLength < 0 means chunked encoding + // or close connection when finished, since multipart is not supported yet + switch { + case t.Chunked: + if noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { + t.Body = NoBody + } else { + t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} + } + case realLength == 0: + t.Body = NoBody + case realLength > 0: + t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} + default: + // realLength < 0, i.e. "Content-Length" not mentioned in header + if t.Close { + // Close semantics (i.e. HTTP/1.0) + t.Body = &body{src: r, closing: t.Close} + } else { + // Persistent connection (i.e. HTTP/1.1) + t.Body = NoBody + } + } + + // Unify output + switch rr := msg.(type) { + case *http.Request: + rr.Body = t.Body + rr.ContentLength = t.ContentLength + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } + rr.Close = t.Close + rr.Trailer = t.Trailer + case *http.Response: + rr.Body = t.Body + rr.ContentLength = t.ContentLength + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } + rr.Close = t.Close + rr.Trailer = t.Trailer + } + + return nil +} + +// Checks whether chunked is part of the encodings stack +func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } + +// Checks whether the encoding is explicitly "identity". +func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" } + +// unsupportedTEError reports unsupported transfer-encodings. +type unsupportedTEError struct { + err string +} + +func (uste *unsupportedTEError) Error() string { + return uste.err +} + +// isUnsupportedTEError checks if the error is of type +// unsupportedTEError. It is usually invoked with a non-nil err. +func isUnsupportedTEError(err error) bool { + _, ok := err.(*unsupportedTEError) + return ok +} + +// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. +func (t *transferReader) parseTransferEncoding() error { + raw, present := t.Header["Transfer-Encoding"] + if !present { + return nil + } + delete(t.Header, "Transfer-Encoding") + + // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. + if !t.protoAtLeast(1, 1) { + return nil + } + + // Like nginx, we only support a single Transfer-Encoding header field, and + // only if set to "chunked". This is one of the most security sensitive + // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it + // strict and simple. + if len(raw) != 1 { + return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} + } + if !ascii.EqualFold(textproto.TrimString(raw[0]), "chunked") { + return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} + } + + // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field + // in any message that contains a Transfer-Encoding header field." + // + // but also: "If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. Such a message might indicate an attempt to perform + // request smuggling (Section 9.5) or response splitting (Section 9.4) and + // ought to be handled as an error. A sender MUST remove the received + // Content-Length field prior to forwarding such a message downstream." + // + // Reportedly, these appear in the wild. + delete(t.Header, "Content-Length") + + t.Chunked = true + return nil +} + +// Determine the expected body length, using RFC 7230 Section 3.3. This +// function is not a method, because ultimately it should be shared by +// ReadResponse and ReadRequest. +func fixLength(isResponse bool, status int, requestMethod string, header http.Header, chunked bool) (int64, error) { + isRequest := !isResponse + contentLens := header["Content-Length"] + + // Hardening against HTTP request smuggling + if len(contentLens) > 1 { + // Per RFC 7230 Section 3.3.2, prevent multiple + // Content-Length headers if they differ in value. + // If there are dups of the value, remove the dups. + // See Issue 16490. + first := textproto.TrimString(contentLens[0]) + for _, ct := range contentLens[1:] { + if first != textproto.TrimString(ct) { + return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens) + } + } + + // deduplicate Content-Length + header.Del("Content-Length") + header.Add("Content-Length", first) + + contentLens = header["Content-Length"] + } + + // Logic based on response type or status + if noResponseBodyExpected(requestMethod) { + // For HTTP requests, as part of hardening against request + // smuggling (RFC 7230), don't allow a Content-Length header for + // methods which don't permit bodies. As an exception, allow + // exactly one Content-Length header if its value is "0". + if isRequest && len(contentLens) > 0 && !(len(contentLens) == 1 && contentLens[0] == "0") { + return 0, fmt.Errorf("http: method cannot contain a Content-Length; got %q", contentLens) + } + return 0, nil + } + if status/100 == 1 { + return 0, nil + } + switch status { + case 204, 304: + return 0, nil + } + + // Logic based on Transfer-Encoding + if chunked { + return -1, nil + } + + // Logic based on Content-Length + var cl string + if len(contentLens) == 1 { + cl = textproto.TrimString(contentLens[0]) + } + if cl != "" { + n, err := parseContentLength(cl) + if err != nil { + return -1, err + } + return n, nil + } + header.Del("Content-Length") + + if isRequest { + // RFC 7230 neither explicitly permits nor forbids an + // entity-body on a GET request so we permit one if + // declared, but we default to 0 here (not -1 below) + // if there's no mention of a body. + // Likewise, all other request methods are assumed to have + // no body if neither Transfer-Encoding chunked nor a + // Content-Length are set. + return 0, nil + } + + // Body-EOF logic based on other methods (like closing, or chunked coding) + return -1, nil +} + +// Determine whether to hang up after sending a request and body, or +// receiving a response and body +// 'header' is the request headers +func shouldClose(major, minor int, header http.Header, removeCloseHeader bool) bool { + if major < 1 { + return true + } + + conv := header["Connection"] + hasClose := httpguts.HeaderValuesContainsToken(conv, "close") + if major == 1 && minor == 0 { + return hasClose || !httpguts.HeaderValuesContainsToken(conv, "keep-alive") + } + + if hasClose && removeCloseHeader { + header.Del("Connection") + } + + return hasClose +} + +// Parse the trailer header +func fixTrailer(header http.Header, chunked bool) (http.Header, error) { + vv, ok := header["Trailer"] + if !ok { + return nil, nil + } + if !chunked { + // Trailer and no chunking: + // this is an invalid use case for trailer header. + // Nevertheless, no error will be returned and we + // let users decide if this is a valid HTTP message. + // The Trailer header will be kept in Response.Header + // but not populate Response.Trailer. + // See issue #27197. + return nil, nil + } + header.Del("Trailer") + + trailer := make(http.Header) + var err error + for _, v := range vv { + foreachHeaderElement(v, func(key string) { + key = http.CanonicalHeaderKey(key) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + if err == nil { + err = badStringError("bad trailer key", key) + return + } + } + trailer[key] = nil + }) + } + if err != nil { + return nil, err + } + if len(trailer) == 0 { + return nil, nil + } + return trailer, nil +} + +// body turns a textprotoReader into a ReadCloser. +// Close ensures that the body has been fully read +// and then reads the trailer if necessary. +type body struct { + src io.Reader + hdr any // non-nil (Response or Request) value means read trailer + r *bufio.Reader // underlying wire-format reader for the trailer + closing bool // is the connection to be closed after reading body? + doEarlyClose bool // whether Close should stop early + + mu sync.Mutex // guards following, and calls to Read and Close + sawEOF bool + closed bool + earlyClose bool // Close called and we didn't read to the end of src + onHitEOF func() // if non-nil, func to call when EOF is Read +} + +// ErrBodyReadAfterClose is returned when reading a Request or Response +// Body after the body has been closed. This typically happens when the body is +// read after an HTTP Handler calls WriteHeader or Write on its +// ResponseWriter. +var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") + +func (b *body) Read(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return 0, ErrBodyReadAfterClose + } + return b.readLocked(p) +} + +// Must hold b.mu. +func (b *body) readLocked(p []byte) (n int, err error) { + if b.sawEOF { + return 0, io.EOF + } + n, err = b.src.Read(p) + + if err == io.EOF { + b.sawEOF = true + // Chunked case. Read the trailer. + if b.hdr != nil { + if e := b.readTrailer(); e != nil { + err = e + // Something went wrong in the trailer, we must not allow any + // further reads of any kind to succeed from body, nor any + // subsequent requests on the server connection. See + // golang.org/issue/12027 + b.sawEOF = false + b.closed = true + } + b.hdr = nil + } else { + // If the server declared the Content-Length, our body is a LimitedReader + // and we need to check whether this EOF arrived early. + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 { + err = io.ErrUnexpectedEOF + } + } + } + + // If we can return an EOF here along with the read data, do + // so. This is optional per the io.textprotoReader contract, but doing + // so helps the HTTP transport code recycle its connection + // earlier (since it will see this EOF itself), even if the + // client doesn't do future reads or Close. + if err == nil && n > 0 { + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 { + err = io.EOF + b.sawEOF = true + } + } + + if b.sawEOF && b.onHitEOF != nil { + b.onHitEOF() + } + + return n, err +} + +var ( + singleCRLF = []byte("\r\n") + doubleCRLF = []byte("\r\n\r\n") +) + +func seeUpcomingDoubleCRLF(r *bufio.Reader) bool { + for peekSize := 4; ; peekSize++ { + // This loop stops when Peek returns an error, + // which it does when r's buffer has been filled. + buf, err := r.Peek(peekSize) + if bytes.HasSuffix(buf, doubleCRLF) { + return true + } + if err != nil { + break + } + } + return false +} + +var errTrailerEOF = errors.New("http: unexpected EOF reading trailer") + +func (b *body) readTrailer() error { + // The common case, since nobody uses trailers. + buf, err := b.r.Peek(2) + if bytes.Equal(buf, singleCRLF) { + b.r.Discard(2) + return nil + } + if len(buf) < 2 { + return errTrailerEOF + } + if err != nil { + return err + } + + // Make sure there's a header terminator coming up, to prevent + // a DoS with an unbounded size Trailer. It's not easy to + // slip in a LimitReader here, as textproto.NewReader requires + // a concrete *bufio.textprotoReader. Also, we can't get all the way + // back up to our conn's LimitedReader that *might* be backing + // this bufio.textprotoReader. Instead, a hack: we iteratively Peek up + // to the bufio.textprotoReader's max size, looking for a double CRLF. + // This limits the trailer to the underlying buffer size, typically 4kB. + if !seeUpcomingDoubleCRLF(b.r) { + return errors.New("http: suspiciously long trailer after chunked body") + } + + hdr, err := textproto.NewReader(b.r).ReadMIMEHeader() + if err != nil { + if err == io.EOF { + return errTrailerEOF + } + return err + } + switch rr := b.hdr.(type) { + case *http.Request: + mergeSetHeader(&rr.Trailer, http.Header(hdr)) + case *http.Response: + mergeSetHeader(&rr.Trailer, http.Header(hdr)) + } + return nil +} + +func mergeSetHeader(dst *http.Header, src http.Header) { + if *dst == nil { + *dst = src + return + } + for k, vv := range src { + (*dst)[k] = vv + } +} + +// unreadDataSizeLocked returns the number of bytes of unread input. +// It returns -1 if unknown. +// b.mu must be held. +func (b *body) unreadDataSizeLocked() int64 { + if lr, ok := b.src.(*io.LimitedReader); ok { + return lr.N + } + return -1 +} + +func (b *body) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return nil + } + var err error + switch { + case b.sawEOF: + // Already saw EOF, so no need going to look for it. + case b.hdr == nil && b.closing: + // no trailer and closing the connection next. + // no point in reading to EOF. + case b.doEarlyClose: + // Read up to maxPostHandlerReadBytes bytes of the body, looking + // for EOF (and trailers), so we can re-use this connection. + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > maxPostHandlerReadBytes { + // There was a declared Content-Length, and we have more bytes remaining + // than our maxPostHandlerReadBytes tolerance. So, give up. + b.earlyClose = true + } else { + var n int64 + // Consume the body, or, which will also lead to us reading + // the trailer headers after the body, if present. + n, err = io.CopyN(io.Discard, bodyLocked{b}, maxPostHandlerReadBytes) + if err == io.EOF { + err = nil + } + if n == maxPostHandlerReadBytes { + b.earlyClose = true + } + } + default: + // Fully consume the body, which will also lead to us reading + // the trailer headers after the body, if present. + _, err = io.Copy(io.Discard, bodyLocked{b}) + } + b.closed = true + return err +} + +func (b *body) didEarlyClose() bool { + b.mu.Lock() + defer b.mu.Unlock() + return b.earlyClose +} + +// bodyRemains reports whether future Read calls might +// yield data. +func (b *body) bodyRemains() bool { + b.mu.Lock() + defer b.mu.Unlock() + return !b.sawEOF +} + +func (b *body) registerOnHitEOF(fn func()) { + b.mu.Lock() + defer b.mu.Unlock() + b.onHitEOF = fn +} + +// bodyLocked is an io.Reader reading from a *body when its mutex is +// already held. +type bodyLocked struct { + b *body +} + +func (bl bodyLocked) Read(p []byte) (n int, err error) { + if bl.b.closed { + return 0, ErrBodyReadAfterClose + } + return bl.b.readLocked(p) +} + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +func parseContentLength(cl string) (int64, error) { + cl = textproto.TrimString(cl) + if cl == "" { + return -1, nil + } + n, err := strconv.ParseUint(cl, 10, 63) + if err != nil { + return 0, badStringError("bad Content-Length", cl) + } + return int64(n), nil + +} + +// finishAsyncByteRead finishes reading the 1-byte sniff +// from the ContentLength==0, Body!=nil case. +type finishAsyncByteRead struct { + tw *transferWriter +} + +func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return + } + rres := <-fr.tw.ByteReadCh + n, err = rres.n, rres.err + if n == 1 { + p[0] = rres.b + } + if err == nil { + err = io.EOF + } + return +} + +var nopCloserType = reflect.TypeOf(io.NopCloser(nil)) + +// isKnownInMemoryReader reports whether r is a type known to not +// block on Read. Its caller uses this as an optional optimization to +// send fewer TCP packets. +func isKnownInMemoryReader(r io.Reader) bool { + switch r.(type) { + case *bytes.Reader, *bytes.Buffer, *strings.Reader: + return true + } + if reflect.TypeOf(r) == nopCloserType { + return isKnownInMemoryReader(reflect.ValueOf(r).Field(0).Interface().(io.Reader)) + } + if r, ok := r.(*readTrackingBody); ok { + return isKnownInMemoryReader(r.ReadCloser) + } + return false +} + +// bufioFlushWriter is an io.Writer wrapper that flushes all writes +// on its wrapped writer if it's a *bufio.Writer. +type bufioFlushWriter struct{ w io.Writer } + +func (fw bufioFlushWriter) Write(p []byte) (n int, err error) { + n, err = fw.w.Write(p) + if bw, ok := fw.w.(*bufio.Writer); n > 0 && ok { + ferr := bw.Flush() + if ferr != nil && err == nil { + err = ferr + } + } + return +} diff --git a/transport.go b/transport.go new file mode 100644 index 00000000..8c0cd99a --- /dev/null +++ b/transport.go @@ -0,0 +1,2938 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP client implementation. See RFC 7230 through 7235. +// +// This is the low-level Transport implementation of http.RoundTripper. +// The high-level interface is in client.go. + +package req + +import ( + "bufio" + "compress/gzip" + "container/list" + "context" + "crypto/tls" + "errors" + "fmt" + "github.com/imroc/req/internal/ascii" + "github.com/imroc/req/internal/godebug" + htmlcharset "golang.org/x/net/html/charset" + "golang.org/x/text/encoding/ianaindex" + "io" + "log" + "mime" + "net" + "net/http" + "net/http/httptrace" + "net/textproto" + "net/url" + "reflect" + "strings" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http/httpproxy" +) + +// DefaultMaxIdleConnsPerHost is the default value of Transport's +// MaxIdleConnsPerHost. +const DefaultMaxIdleConnsPerHost = 2 + +// Transport is an implementation of http.RoundTripper that supports HTTP, +// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT). +// +// By default, Transport caches connections for future re-use. +// This may leave many open connections when accessing many hosts. +// This behavior can be managed using Transport's CloseIdleConnections method +// and the MaxIdleConnsPerHost and DisableKeepAlives fields. +// +// Transports should be reused instead of created as needed. +// Transports are safe for concurrent use by multiple goroutines. +// +// A Transport is a low-level primitive for making HTTP and HTTPS requests. +// For high-level functionality, such as cookies and redirects, see Client. +// +// Transport uses HTTP/1.1 for HTTP URLs and either HTTP/1.1 or HTTP/2 +// for HTTPS URLs, depending on whether the server supports HTTP/2, +// and how the Transport is configured. The DefaultTransport supports HTTP/2. +// To explicitly enable HTTP/2 on a transport, use golang.org/x/net/http2 +// and call ConfigureTransport. See the package docs for more about HTTP/2. +// +// Responses with status codes in the 1xx range are either handled +// automatically (100 expect-continue) or ignored. The one +// exception is HTTP status code 101 (Switching Protocols), which is +// considered a terminal status and returned by RoundTrip. To see the +// ignored 1xx responses, use the httptrace trace package's +// ClientTrace.Got1xxResponse. +// +// Transport only retries a request upon encountering a network error +// if the request is idempotent and either has no body or has its +// Request.GetBody defined. HTTP requests are considered idempotent if +// they have HTTP methods GET, HEAD, OPTIONS, or TRACE; or if their +// Header map contains an "Idempotency-Key" or "X-Idempotency-Key" +// entry. If the idempotency key value is a zero-length slice, the +// request is treated as idempotent but the header is not sent on the +// wire. +type Transport struct { + idleMu sync.Mutex + closeIdle bool // user has requested to close all idle conns + idleConn map[connectMethodKey][]*persistConn // most recently used at end + idleConnWait map[connectMethodKey]wantConnQueue // waiting getConns + idleLRU connLRU + + reqMu sync.Mutex + reqCanceler map[cancelKey]func(error) + + altMu sync.Mutex // guards changing altProto only + altProto atomic.Value // of nil or map[string]http.RoundTripper, key is URI scheme + + connsPerHostMu sync.Mutex + connsPerHost map[connectMethodKey]int + connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns + + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // + // The proxy type is determined by the URL scheme. "http", + // "https", and "socks5" are supported. If the scheme is empty, + // "http" is assumed. + // + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + + // DialContext specifies the dial function for creating unencrypted TCP connections. + // If DialContext is nil (and the deprecated Dial below is also nil), + // then the transport dials using package net. + // + // DialContext runs concurrently with calls to RoundTrip. + // A RoundTrip call that initiates a dial may end up using + // a connection dialed previously when the earlier connection + // becomes idle before the later DialContext completes. + DialContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // Dial specifies the dial function for creating unencrypted TCP connections. + // + // Dial runs concurrently with calls to RoundTrip. + // A RoundTrip call that initiates a dial may end up using + // a connection dialed previously when the earlier connection + // becomes idle before the later Dial completes. + // + // Deprecated: Use DialContext instead, which allows the transport + // to cancel dials as soon as they are no longer needed. + // If both are set, DialContext takes priority. + Dial func(network, addr string) (net.Conn, error) + + // DialTLSContext specifies an optional dial function for creating + // TLS connections for non-proxied HTTPS requests. + // + // If DialTLSContext is nil (and the deprecated DialTLS below is also nil), + // DialContext and TLSClientConfig are used. + // + // If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS + // requests and the TLSClientConfig and TLSHandshakeTimeout + // are ignored. The returned net.Conn is assumed to already be + // past the TLS handshake. + DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // DialTLS specifies an optional dial function for creating + // TLS connections for non-proxied HTTPS requests. + // + // Deprecated: Use DialTLSContext instead, which allows the transport + // to cancel dials as soon as they are no longer needed. + // If both are set, DialTLSContext takes priority. + DialTLS func(network, addr string) (net.Conn, error) + + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. + // If nil, the default configuration is used. + // If non-nil, HTTP/2 support may not be enabled by default. + TLSClientConfig *tls.Config + + // TLSHandshakeTimeout specifies the maximum amount of time waiting to + // wait for a TLS handshake. Zero means no timeout. + TLSHandshakeTimeout time.Duration + + // DisableKeepAlives, if true, disables HTTP keep-alives and + // will only use the connection to the server for a single + // HTTP request. + // + // This is unrelated to the similarly named TCP keep-alives. + DisableKeepAlives bool + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool + + // MaxIdleConns controls the maximum number of idle (keep-alive) + // connections across all hosts. Zero means no limit. + MaxIdleConns int + + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle + // (keep-alive) connections to keep per-host. If zero, + // DefaultMaxIdleConnsPerHost is used. + MaxIdleConnsPerHost int + + // MaxConnsPerHost optionally limits the total number of + // connections per host, including connections in the dialing, + // active, and idle states. On limit violation, dials will block. + // + // Zero means no limit. + MaxConnsPerHost int + + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout time.Duration + + // ResponseHeaderTimeout, if non-zero, specifies the amount of + // time to wait for a server's response headers after fully + // writing the request (including its body, if any). This + // time does not include the time to read the response body. + ResponseHeaderTimeout time.Duration + + // ExpectContinueTimeout, if non-zero, specifies the amount of + // time to wait for a server's first response headers after fully + // writing the request headers if the request has an + // "Expect: 100-continue" header. Zero means no timeout and + // causes the body to be sent immediately, without + // waiting for the server to approve. + // This time does not include the time to send the request header. + ExpectContinueTimeout time.Duration + + // TLSNextProto specifies how the Transport switches to an + // alternate protocol (such as HTTP/2) after a TLS ALPN + // protocol negotiation. If Transport dials an TLS connection + // with a non-empty protocol name and TLSNextProto contains a + // map entry for that key (such as "h2"), then the func is + // called with the request's authority (such as "example.com" + // or "example.com:1234") and the TLS connection. The function + // must return a http.RoundTripper that then handles the request. + // If TLSNextProto is not nil, HTTP/2 support is not enabled + // automatically. + TLSNextProto map[string]func(authority string, c *tls.Conn) http.RoundTripper + + // ProxyConnectHeader optionally specifies headers to send to + // proxies during CONNECT requests. + // To set the header dynamically, see GetProxyConnectHeader. + ProxyConnectHeader http.Header + + // GetProxyConnectHeader optionally specifies a func to return + // headers to send to proxyURL during a CONNECT request to the + // ip:port target. + // If it returns an error, the Transport's RoundTrip fails with + // that error. It can return (nil, nil) to not add headers. + // If GetProxyConnectHeader is non-nil, ProxyConnectHeader is + // ignored. + GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) + + // MaxResponseHeaderBytes specifies a limit on how many + // response bytes are allowed in the server's response + // header. + // + // Zero means to use a default limit. + MaxResponseHeaderBytes int64 + + // WriteBufferSize specifies the size of the write buffer used + // when writing to the transport. + // If zero, a default (currently 4KB) is used. + WriteBufferSize int + + // ReadBufferSize specifies the size of the read buffer used + // when reading from the transport. + // If zero, a default (currently 4KB) is used. + ReadBufferSize int + + // nextProtoOnce guards initialization of TLSNextProto and + // h2transport (via onceSetNextProtoDefaults) + nextProtoOnce sync.Once + h2transport h2Transport // non-nil if http2 wired up + tlsNextProtoWasNil bool // whether TLSNextProto was nil when the Once fired + + // ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero + // Dial, DialTLS, or DialContext func or TLSClientConfig is provided. + // By default, use of any those fields conservatively disables HTTP/2. + // To use a custom dialer or TLS config and still attempt HTTP/2 + // upgrades, set this to true. + ForceAttemptHTTP2 bool + + ResponseOptions + dump *dumper +} + +func (t *Transport) handleResponseBody(res *http.Response) { + t.autoDecodeResponseBody(res) + t.dumpResponseBody(res) +} + +func (t *Transport) dumpResponseBody(res *http.Response) { + if t.dump == nil || !t.dump.ResponseBody { + return + } + res.Body = t.dump.WrapReadCloser(res.Body) +} + +func (t *Transport) autoDecodeResponseBody(res *http.Response) { + if t.ResponseOptions.DisableAutoDecode { + return + } + contentType := res.Header.Get("Content-Type") + var shouldDecode func(contentType string) bool + if t.ResponseOptions.AutoDecodeContentType != nil { + shouldDecode = t.ResponseOptions.AutoDecodeContentType + } else { + shouldDecode = responseBodyIsText + } + if !shouldDecode(contentType) { + return + } + _, params, err := mime.ParseMediaType(contentType) + if err != nil { + panic(err) + } + if charset, ok := params["charset"]; ok { + fmt.Println("chartset", charset, "detected") + if strings.Contains(charset, "utf-8") || strings.Contains(charset, "utf8") { // do not decode utf-8 + fmt.Println("decode not needed") + return + } + enc, _ := htmlcharset.Lookup(charset) + if enc == nil { + enc, err = ianaindex.MIME.Encoding(charset) + if err != nil { + fmt.Println("chartset", charset, "not supported:", err.Error(), "; cancel decode") + return + } + } + if enc == nil { + return + } + decodeReader := enc.NewDecoder().Reader(res.Body) + res.Body = &decodeReaderCloser{res.Body, decodeReader} + return + } + res.Body = &autoDecodeReadCloser{ReadCloser: res.Body} +} + +// A cancelKey is the key of the reqCanceler map. +// We wrap the *http.Request in this type since we want to use the original request, +// not any transient one created by roundTrip. +type cancelKey struct { + req *http.Request +} + +func (t *Transport) writeBufferSize() int { + if t.WriteBufferSize > 0 { + return t.WriteBufferSize + } + return 4 << 10 +} + +func (t *Transport) readBufferSize() int { + if t.ReadBufferSize > 0 { + return t.ReadBufferSize + } + return 4 << 10 +} + +// Clone returns a deep copy of t's exported fields. +func (t *Transport) Clone() *Transport { + t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + t2 := &Transport{ + Proxy: t.Proxy, + DialContext: t.DialContext, + Dial: t.Dial, + DialTLS: t.DialTLS, + DialTLSContext: t.DialTLSContext, + TLSHandshakeTimeout: t.TLSHandshakeTimeout, + DisableKeepAlives: t.DisableKeepAlives, + DisableCompression: t.DisableCompression, + MaxIdleConns: t.MaxIdleConns, + MaxIdleConnsPerHost: t.MaxIdleConnsPerHost, + MaxConnsPerHost: t.MaxConnsPerHost, + IdleConnTimeout: t.IdleConnTimeout, + ResponseHeaderTimeout: t.ResponseHeaderTimeout, + ExpectContinueTimeout: t.ExpectContinueTimeout, + ProxyConnectHeader: t.ProxyConnectHeader.Clone(), + GetProxyConnectHeader: t.GetProxyConnectHeader, + MaxResponseHeaderBytes: t.MaxResponseHeaderBytes, + ForceAttemptHTTP2: t.ForceAttemptHTTP2, + WriteBufferSize: t.WriteBufferSize, + ReadBufferSize: t.ReadBufferSize, + } + if t.TLSClientConfig != nil { + t2.TLSClientConfig = t.TLSClientConfig.Clone() + } + if !t.tlsNextProtoWasNil { + npm := map[string]func(authority string, c *tls.Conn) http.RoundTripper{} + for k, v := range t.TLSNextProto { + npm[k] = v + } + t2.TLSNextProto = npm + } + return t2 +} + +// h2Transport is the interface we expect to be able to call from +// net/http against an *http2.Transport that's either bundled into +// h2_bundle.go or supplied by the user via x/net/http2. +// +// We name it with the "h2" prefix to stay out of the "http2" prefix +// namespace used by x/tools/cmd/bundle for h2_bundle.go. +type h2Transport interface { + CloseIdleConnections() +} + +func (t *Transport) hasCustomTLSDialer() bool { + return t.DialTLS != nil || t.DialTLSContext != nil +} + +// onceSetNextProtoDefaults initializes TLSNextProto. +// It must be called via t.nextProtoOnce.Do. +func (t *Transport) onceSetNextProtoDefaults() { + t.tlsNextProtoWasNil = (t.TLSNextProto == nil) + if godebug.Get("http2client") == "0" { + return + } + + // If they've already configured http2 with + // golang.org/x/net/http2 instead of the bundled copy, try to + // get at its http2.Transport value (via the "https" + // altproto map) so we can call CloseIdleConnections on it if + // requested. (Issue 22891) + altProto, _ := t.altProto.Load().(map[string]http.RoundTripper) + if rv := reflect.ValueOf(altProto["https"]); rv.IsValid() && rv.Type().Kind() == reflect.Struct && rv.Type().NumField() == 1 { + if v := rv.Field(0); v.CanInterface() { + if h2i, ok := v.Interface().(h2Transport); ok { + t.h2transport = h2i + return + } + } + } + + if t.TLSNextProto != nil { + // This is the documented way to disable http2 on a + // Transport. + return + } + if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()) { + // Be conservative and don't automatically enable + // http2 if they've specified a custom TLS config or + // custom dialers. Let them opt-in themselves via + // http2.ConfigureTransport so we don't surprise them + // by modifying their tls.Config. Issue 14275. + // However, if ForceAttemptHTTP2 is true, it overrides the above checks. + return + } + t2, err := http2ConfigureTransports(t) + if err != nil { + log.Printf("Error enabling Transport HTTP/2 support: %v", err) + return + } + t.h2transport = t2 + + // Auto-configure the http2.Transport's MaxHeaderListSize from + // the http.Transport's MaxResponseHeaderBytes. They don't + // exactly mean the same thing, but they're close. + // + // TODO: also add this to x/net/http2.Configure Transport, behind + // a +build go1.7 build tag: + if limit1 := t.MaxResponseHeaderBytes; limit1 != 0 && t2.MaxHeaderListSize == 0 { + const h2max = 1<<32 - 1 + if limit1 >= h2max { + t2.MaxHeaderListSize = h2max + } else { + t2.MaxHeaderListSize = uint32(limit1) + } + } +} + +// ProxyFromEnvironment returns the URL of the proxy to use for a +// given request, as indicated by the environment variables +// HTTP_PROXY, HTTPS_PROXY and NO_PROXY (or the lowercase versions +// thereof). HTTPS_PROXY takes precedence over HTTP_PROXY for https +// requests. +// +// The environment values may be either a complete URL or a +// "host[:port]", in which case the "http" scheme is assumed. +// The schemes "http", "https", and "socks5" are supported. +// An error is returned if the value is a different form. +// +// A nil URL and nil error are returned if no proxy is defined in the +// environment, or a proxy should not be used for the given request, +// as defined by NO_PROXY. +// +// As a special case, if req.URL.Host is "localhost" (with or without +// a port number), then a nil URL and nil error will be returned. +func ProxyFromEnvironment(req *http.Request) (*url.URL, error) { + return envProxyFunc()(req.URL) +} + +// ProxyURL returns a proxy function (for use in a Transport) +// that always returns the same URL. +func ProxyURL(fixedURL *url.URL) func(*http.Request) (*url.URL, error) { + return func(*http.Request) (*url.URL, error) { + return fixedURL, nil + } +} + +// transportRequest is a wrapper around a *http.Request that adds +// optional extra headers to write and stores any error to return +// from roundTrip. +type transportRequest struct { + *http.Request // original request, not to be mutated + extra http.Header // extra headers to write, or nil + trace *httptrace.ClientTrace // optional + cancelKey cancelKey + + mu sync.Mutex // guards err + err error // first setError value for mapRoundTripError to consider +} + +func (tr *transportRequest) extraHeaders() http.Header { + if tr.extra == nil { + tr.extra = make(http.Header) + } + return tr.extra +} + +func (tr *transportRequest) setError(err error) { + tr.mu.Lock() + if tr.err == nil { + tr.err = err + } + tr.mu.Unlock() +} + +// useRegisteredProtocol reports whether an alternate protocol (as registered +// with Transport.RegisterProtocol) should be respected for this request. +func (t *Transport) useRegisteredProtocol(req *http.Request) bool { + if req.URL.Scheme == "https" && requestRequiresHTTP1(req) { + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return false + } + return true +} + +// alternatehttp.RoundTripper returns the alternate http.RoundTripper to use +// for this request if the Request's URL scheme requires one, +// or nil for the normal case of using the Transport. +func (t *Transport) alternateRoundTripper(req *http.Request) http.RoundTripper { + if !t.useRegisteredProtocol(req) { + return nil + } + altProto, _ := t.altProto.Load().(map[string]http.RoundTripper) + return altProto[req.URL.Scheme] +} + +// roundTrip implements a http.RoundTripper over HTTP. +func (t *Transport) roundTrip(req *http.Request) (*http.Response, error) { + t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + ctx := req.Context() + trace := httptrace.ContextClientTrace(ctx) + + if req.URL == nil { + closeBody(req) + return nil, errors.New("http: nil Request.URL") + } + if req.Header == nil { + closeBody(req) + return nil, errors.New("http: nil Request.Header") + } + scheme := req.URL.Scheme + isHTTP := scheme == "http" || scheme == "https" + if isHTTP { + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + closeBody(req) + return nil, fmt.Errorf("net/http: invalid header field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + closeBody(req) + return nil, fmt.Errorf("net/http: invalid header field value %q for key %v", v, k) + } + } + } + } + + origReq := req + cancelKey := cancelKey{origReq} + req = setupRewindBody(req) + + if altRT := t.alternateRoundTripper(req); altRT != nil { + if resp, err := altRT.RoundTrip(req); err != http.ErrSkipAltProtocol { + return resp, err + } + var err error + req, err = rewindBody(req) + if err != nil { + return nil, err + } + } + if !isHTTP { + closeBody(req) + return nil, badStringError("unsupported protocol scheme", scheme) + } + if req.Method != "" && !validMethod(req.Method) { + closeBody(req) + return nil, fmt.Errorf("net/http: invalid method %q", req.Method) + } + if req.URL.Host == "" { + closeBody(req) + return nil, errors.New("http: no Host in request URL") + } + + for { + select { + case <-ctx.Done(): + closeBody(req) + return nil, ctx.Err() + default: + } + + // treq gets modified by roundTrip, so we need to recreate for each retry. + treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey} + cm, err := t.connectMethodForRequest(treq) + if err != nil { + closeBody(req) + return nil, err + } + + // Get the cached or newly-created connection to either the + // host (for http or https), the http proxy, or the http proxy + // pre-CONNECTed to https server. In any case, we'll be ready + // to send it requests. + pconn, err := t.getConn(treq, cm) + if err != nil { + t.setReqCanceler(cancelKey, nil) + closeBody(req) + return nil, err + } + + var resp *http.Response + if pconn.alt != nil { + // HTTP/2 path. + t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest + resp, err = pconn.alt.RoundTrip(req) + } else { + resp, err = pconn.roundTrip(treq) + } + if err == nil { + resp.Request = origReq + return resp, nil + } + + // Failed. Clean up and determine whether to retry. + if http2isNoCachedConnError(err) { + if t.removeIdleConn(pconn) { + t.decConnsPerHost(pconn.cacheKey) + } + } else if !pconn.shouldRetryRequest(req, err) { + // Issue 16465: return underlying net.Conn.Read error from peek, + // as we've historically done. + if e, ok := err.(transportReadFromServerError); ok { + err = e.err + } + return nil, err + } + testHookRoundTripRetried() + + // Rewind the body if we're able to. + req, err = rewindBody(req) + if err != nil { + return nil, err + } + } +} + +var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss") + +type readTrackingBody struct { + io.ReadCloser + didRead bool + didClose bool +} + +func (r *readTrackingBody) Read(data []byte) (int, error) { + r.didRead = true + return r.ReadCloser.Read(data) +} + +func (r *readTrackingBody) Close() error { + r.didClose = true + return r.ReadCloser.Close() +} + +// setupRewindBody returns a new request with a custom body wrapper +// that can report whether the body needs rewinding. +// This lets rewindBody avoid an error result when the request +// does not have GetBody but the body hasn't been read at all yet. +func setupRewindBody(req *http.Request) *http.Request { + if req.Body == nil || req.Body == NoBody { + return req + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: req.Body} + return &newReq +} + +// rewindBody returns a new request with the body rewound. +// It returns req unmodified if the body does not need rewinding. +// rewindBody takes care of closing req.Body when appropriate +// (in all cases except when rewindBody returns req unmodified). +func rewindBody(req *http.Request) (rewound *http.Request, err error) { + if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) { + return req, nil // nothing to rewind + } + if !req.Body.(*readTrackingBody).didClose { + closeBody(req) + } + if req.GetBody == nil { + return nil, errCannotRewind + } + body, err := req.GetBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: body} + return &newReq, nil +} + +// shouldRetryRequest reports whether we should retry sending a failed +// HTTP request on a new connection. The non-nil input error is the +// error from roundTrip. +func (pc *persistConn) shouldRetryRequest(req *http.Request, err error) bool { + if http2isNoCachedConnError(err) { + // Issue 16582: if the user started a bunch of + // requests at once, they can all pick the same conn + // and violate the server's max concurrent streams. + // Instead, match the HTTP/1 behavior for now and dial + // again to get a new TCP connection, rather than failing + // this request. + return true + } + if err == errMissingHost { + // User error. + return false + } + if !pc.isReused() { + // This was a fresh connection. There's no reason the server + // should've hung up on us. + // + // Also, if we retried now, we could loop forever + // creating new connections and retrying if the server + // is just hanging up on us because it doesn't like + // our request (as opposed to sending an error). + return false + } + if _, ok := err.(nothingWrittenError); ok { + // We never wrote anything, so it's safe to retry, if there's no body or we + // can "rewind" the body with GetBody. + return outgoingLength(req) == 0 || req.GetBody != nil + } + if !isReplayable(req) { + // Don't retry non-idempotent requests. + return false + } + if _, ok := err.(transportReadFromServerError); ok { + // We got some non-EOF net.Conn.Read failure reading + // the 1st response byte from the server. + return true + } + if err == errServerClosedIdle { + // The server replied with io.EOF while we were trying to + // read the response. Probably an unfortunately keep-alive + // timeout, just as the client was writing a request. + return true + } + return false // conservatively +} + +// RegisterProtocol registers a new protocol with scheme. +// The Transport will pass requests using the given scheme to rt. +// It is rt's responsibility to simulate HTTP request semantics. +// +// RegisterProtocol can be used by other packages to provide +// implementations of protocol schemes like "ftp" or "file". +// +// If rt.RoundTrip returns ErrSkipAltProtocol, the Transport will +// handle the RoundTrip itself for that one request, as if the +// protocol were not registered. +func (t *Transport) RegisterProtocol(scheme string, rt http.RoundTripper) { + t.altMu.Lock() + defer t.altMu.Unlock() + oldMap, _ := t.altProto.Load().(map[string]http.RoundTripper) + if _, exists := oldMap[scheme]; exists { + panic("protocol " + scheme + " already registered") + } + newMap := make(map[string]http.RoundTripper) + for k, v := range oldMap { + newMap[k] = v + } + newMap[scheme] = rt + t.altProto.Store(newMap) +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle in +// a "keep-alive" state. It does not interrupt any connections currently +// in use. +func (t *Transport) CloseIdleConnections() { + t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + t.idleMu.Lock() + m := t.idleConn + t.idleConn = nil + t.closeIdle = true // close newly idle connections + t.idleLRU = connLRU{} + t.idleMu.Unlock() + for _, conns := range m { + for _, pconn := range conns { + pconn.close(errCloseIdleConns) + } + } + if t2 := t.h2transport; t2 != nil { + t2.CloseIdleConnections() + } +} + +// CancelRequest cancels an in-flight request by closing its connection. +// CancelRequest should only be called after RoundTrip has returned. +// +// Deprecated: Use Request.WithContext to create a request with a +// cancelable context instead. CancelRequest cannot cancel HTTP/2 +// requests. +func (t *Transport) CancelRequest(req *http.Request) { + t.cancelRequest(cancelKey{req}, errRequestCanceled) +} + +// Cancel an in-flight request, recording the error value. +// Returns whether the request was canceled. +func (t *Transport) cancelRequest(key cancelKey, err error) bool { + // This function must not return until the cancel func has completed. + // See: https://golang.org/issue/34658 + t.reqMu.Lock() + defer t.reqMu.Unlock() + cancel := t.reqCanceler[key] + delete(t.reqCanceler, key) + if cancel != nil { + cancel(err) + } + + return cancel != nil +} + +// +// Private implementation past this point. +// + +var ( + // proxyConfigOnce guards proxyConfig + envProxyOnce sync.Once + envProxyFuncValue func(*url.URL) (*url.URL, error) +) + +// defaultProxyConfig returns a ProxyConfig value looked up +// from the environment. This mitigates expensive lookups +// on some platforms (e.g. Windows). +func envProxyFunc() func(*url.URL) (*url.URL, error) { + envProxyOnce.Do(func() { + envProxyFuncValue = httpproxy.FromEnvironment().ProxyFunc() + }) + return envProxyFuncValue +} + +// resetProxyConfig is used by tests. +func resetProxyConfig() { + envProxyOnce = sync.Once{} + envProxyFuncValue = nil +} + +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + cm.targetAddr = canonicalAddr(treq.URL) + if t.Proxy != nil { + cm.proxyURL, err = t.Proxy(treq.Request) + } + cm.onlyH1 = requestRequiresHTTP1(treq.Request) + return cm, err +} + +// proxyAuth returns the Proxy-Authorization header to set +// on requests, if applicable. +func (cm *connectMethod) proxyAuth() string { + if cm.proxyURL == nil { + return "" + } + if u := cm.proxyURL.User; u != nil { + username := u.Username() + password, _ := u.Password() + return "Basic " + basicAuth(username, password) + } + return "" +} + +// error values for debugging and testing, not seen by users. +var ( + errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") + errConnBroken = errors.New("http: putIdleConn: connection is in bad state") + errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") + errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") + errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") + errCloseIdleConns = errors.New("http: CloseIdleConnections called") + errReadLoopExiting = errors.New("http: persistConn.readLoop exiting") + errIdleConnTimeout = errors.New("http: idle connection timeout") + + // errServerClosedIdle is not seen by users for idempotent requests, but may be + // seen by a user if the server shuts down an idle connection and sends its FIN + // in flight with already-written POST body bytes from the client. + // See https://github.com/golang/go/issues/19943#issuecomment-355607646 + errServerClosedIdle = errors.New("http: server closed idle connection") +) + +// transportReadFromServerError is used by Transport.readLoop when the +// 1 byte peek read fails and we're actually anticipating a response. +// Usually this is just due to the inherent keep-alive shut down race, +// where the server closed the connection at the same time the client +// wrote. The underlying err field is usually io.EOF or some +// ECONNRESET sort of thing which varies by platform. But it might be +// the user's custom net.Conn.Read error too, so we carry it along for +// them to return from Transport.RoundTrip. +type transportReadFromServerError struct { + err error +} + +func (e transportReadFromServerError) Unwrap() error { return e.err } + +func (e transportReadFromServerError) Error() string { + return fmt.Sprintf("net/http: Transport failed to read from server: %v", e.err) +} + +func (t *Transport) putOrCloseIdleConn(pconn *persistConn) { + if err := t.tryPutIdleConn(pconn); err != nil { + pconn.close(err) + } +} + +func (t *Transport) maxIdleConnsPerHost() int { + if v := t.MaxIdleConnsPerHost; v != 0 { + return v + } + return DefaultMaxIdleConnsPerHost +} + +// tryPutIdleConn adds pconn to the list of idle persistent connections awaiting +// a new request. +// If pconn is no longer needed or not in a good state, tryPutIdleConn returns +// an error explaining why it wasn't registered. +// tryPutIdleConn does not close pconn. Use putOrCloseIdleConn instead for that. +func (t *Transport) tryPutIdleConn(pconn *persistConn) error { + if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 { + return errKeepAlivesDisabled + } + if pconn.isBroken() { + return errConnBroken + } + pconn.markReused() + + t.idleMu.Lock() + defer t.idleMu.Unlock() + + // HTTP/2 (pconn.alt != nil) connections do not come out of the idle list, + // because multiple goroutines can use them simultaneously. + // If this is an HTTP/2 connection being “returned,” we're done. + if pconn.alt != nil && t.idleLRU.m[pconn] != nil { + return nil + } + + // Deliver pconn to goroutine waiting for idle connection, if any. + // (They may be actively dialing, but this conn is ready first. + // Chrome calls this socket late binding. + // See https://www.chromium.org/developers/design-documents/network-stack#TOC-Connection-Management.) + key := pconn.cacheKey + if q, ok := t.idleConnWait[key]; ok { + done := false + if pconn.alt == nil { + // HTTP/1. + // Loop over the waiting list until we find a w that isn't done already, and hand it pconn. + for q.len() > 0 { + w := q.popFront() + if w.tryDeliver(pconn, nil) { + done = true + break + } + } + } else { + // HTTP/2. + // Can hand the same pconn to everyone in the waiting list, + // and we still won't be done: we want to put it in the idle + // list unconditionally, for any future clients too. + for q.len() > 0 { + w := q.popFront() + w.tryDeliver(pconn, nil) + } + } + if q.len() == 0 { + delete(t.idleConnWait, key) + } else { + t.idleConnWait[key] = q + } + if done { + return nil + } + } + + if t.closeIdle { + return errCloseIdle + } + if t.idleConn == nil { + t.idleConn = make(map[connectMethodKey][]*persistConn) + } + idles := t.idleConn[key] + if len(idles) >= t.maxIdleConnsPerHost() { + return errTooManyIdleHost + } + for _, exist := range idles { + if exist == pconn { + log.Fatalf("dup idle pconn %p in freelist", pconn) + } + } + t.idleConn[key] = append(idles, pconn) + t.idleLRU.add(pconn) + if t.MaxIdleConns != 0 && t.idleLRU.len() > t.MaxIdleConns { + oldest := t.idleLRU.removeOldest() + oldest.close(errTooManyIdle) + t.removeIdleConnLocked(oldest) + } + + // Set idle timer, but only for HTTP/1 (pconn.alt == nil). + // The HTTP/2 implementation manages the idle timer itself + // (see idleConnTimeout in h2_bundle.go). + if t.IdleConnTimeout > 0 && pconn.alt == nil { + if pconn.idleTimer != nil { + pconn.idleTimer.Reset(t.IdleConnTimeout) + } else { + pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle) + } + } + pconn.idleAt = time.Now() + return nil +} + +// queueForIdleConn queues w to receive the next idle connection for w.cm. +// As an optimization hint to the caller, queueForIdleConn reports whether +// it successfully delivered an already-idle connection. +func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) { + if t.DisableKeepAlives { + return false + } + + t.idleMu.Lock() + defer t.idleMu.Unlock() + + // Stop closing connections that become idle - we might want one. + // (That is, undo the effect of t.CloseIdleConnections.) + t.closeIdle = false + + if w == nil { + // Happens in test hook. + return false + } + + // If IdleConnTimeout is set, calculate the oldest + // persistConn.idleAt time we're willing to use a cached idle + // conn. + var oldTime time.Time + if t.IdleConnTimeout > 0 { + oldTime = time.Now().Add(-t.IdleConnTimeout) + } + + // Look for most recently-used idle connection. + if list, ok := t.idleConn[w.key]; ok { + stop := false + delivered := false + for len(list) > 0 && !stop { + pconn := list[len(list)-1] + + // See whether this connection has been idle too long, considering + // only the wall time (the Round(0)), in case this is a laptop or VM + // coming out of suspend with previously cached idle connections. + tooOld := !oldTime.IsZero() && pconn.idleAt.Round(0).Before(oldTime) + if tooOld { + // Async cleanup. Launch in its own goroutine (as if a + // time.AfterFunc called it); it acquires idleMu, which we're + // holding, and does a synchronous net.Conn.Close. + go pconn.closeConnIfStillIdle() + } + if pconn.isBroken() || tooOld { + // If either persistConn.readLoop has marked the connection + // broken, but Transport.removeIdleConn has not yet removed it + // from the idle list, or if this persistConn is too old (it was + // idle too long), then ignore it and look for another. In both + // cases it's already in the process of being closed. + list = list[:len(list)-1] + continue + } + delivered = w.tryDeliver(pconn, nil) + if delivered { + if pconn.alt != nil { + // HTTP/2: multiple clients can share pconn. + // Leave it in the list. + } else { + // HTTP/1: only one client can use pconn. + // Remove it from the list. + t.idleLRU.remove(pconn) + list = list[:len(list)-1] + } + } + stop = true + } + if len(list) > 0 { + t.idleConn[w.key] = list + } else { + delete(t.idleConn, w.key) + } + if stop { + return delivered + } + } + + // Register to receive next connection that becomes idle. + if t.idleConnWait == nil { + t.idleConnWait = make(map[connectMethodKey]wantConnQueue) + } + q := t.idleConnWait[w.key] + q.cleanFront() + q.pushBack(w) + t.idleConnWait[w.key] = q + return false +} + +// removeIdleConn marks pconn as dead. +func (t *Transport) removeIdleConn(pconn *persistConn) bool { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return t.removeIdleConnLocked(pconn) +} + +// t.idleMu must be held. +func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool { + if pconn.idleTimer != nil { + pconn.idleTimer.Stop() + } + t.idleLRU.remove(pconn) + key := pconn.cacheKey + pconns := t.idleConn[key] + var removed bool + switch len(pconns) { + case 0: + // Nothing + case 1: + if pconns[0] == pconn { + delete(t.idleConn, key) + removed = true + } + default: + for i, v := range pconns { + if v != pconn { + continue + } + // Slide down, keeping most recently-used + // conns at the end. + copy(pconns[i:], pconns[i+1:]) + t.idleConn[key] = pconns[:len(pconns)-1] + removed = true + break + } + } + return removed +} + +func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqCanceler == nil { + t.reqCanceler = make(map[cancelKey]func(error)) + } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } +} + +// replaceReqCanceler replaces an existing cancel function. If there is no cancel function +// for the request, we don't set the function and return false. +// Since CancelRequest will clear the canceler, we can use the return value to detect if +// the request was canceled since the last setReqCancel call. +func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { + t.reqMu.Lock() + defer t.reqMu.Unlock() + _, ok := t.reqCanceler[key] + if !ok { + return false + } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } + return true +} + +var zeroDialer net.Dialer + +func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) { + if t.DialContext != nil { + return t.DialContext(ctx, network, addr) + } + if t.Dial != nil { + c, err := t.Dial(network, addr) + if c == nil && err == nil { + err = errors.New("net/http: Transport.Dial hook returned (nil, nil)") + } + return c, err + } + return zeroDialer.DialContext(ctx, network, addr) +} + +// A wantConn records state about a wanted connection +// (that is, an active call to getConn). +// The conn may be gotten by dialing or by finding an idle connection, +// or a cancellation may make the conn no longer wanted. +// These three options are racing against each other and use +// wantConn to coordinate and agree about the winning outcome. +type wantConn struct { + cm connectMethod + key connectMethodKey // cm.key() + ctx context.Context // context for dial + ready chan struct{} // closed when pc, err pair is delivered + + // hooks for testing to know when dials are done + // beforeDial is called in the getConn goroutine when the dial is queued. + // afterDial is called when the dial is completed or canceled. + beforeDial func() + afterDial func() + + mu sync.Mutex // protects pc, err, close(ready) + pc *persistConn + err error +} + +// waiting reports whether w is still waiting for an answer (connection or error). +func (w *wantConn) waiting() bool { + select { + case <-w.ready: + return false + default: + return true + } +} + +// tryDeliver attempts to deliver pc, err to w and reports whether it succeeded. +func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { + w.mu.Lock() + defer w.mu.Unlock() + + if w.pc != nil || w.err != nil { + return false + } + + w.pc = pc + w.err = err + if w.pc == nil && w.err == nil { + panic("net/http: internal error: misuse of tryDeliver") + } + close(w.ready) + return true +} + +// cancel marks w as no longer wanting a result (for example, due to cancellation). +// If a connection has been delivered already, cancel returns it with t.putOrCloseIdleConn. +func (w *wantConn) cancel(t *Transport, err error) { + w.mu.Lock() + if w.pc == nil && w.err == nil { + close(w.ready) // catch misbehavior in future delivery + } + pc := w.pc + w.pc = nil + w.err = err + w.mu.Unlock() + + if pc != nil { + t.putOrCloseIdleConn(pc) + } +} + +// A wantConnQueue is a queue of wantConns. +type wantConnQueue struct { + // This is a queue, not a deque. + // It is split into two stages - head[headPos:] and tail. + // popFront is trivial (headPos++) on the first stage, and + // pushBack is trivial (append) on the second stage. + // If the first stage is empty, popFront can swap the + // first and second stages to remedy the situation. + // + // This two-stage split is analogous to the use of two lists + // in Okasaki's purely functional queue but without the + // overhead of reversing the list when swapping stages. + head []*wantConn + headPos int + tail []*wantConn +} + +// len returns the number of items in the queue. +func (q *wantConnQueue) len() int { + return len(q.head) - q.headPos + len(q.tail) +} + +// pushBack adds w to the back of the queue. +func (q *wantConnQueue) pushBack(w *wantConn) { + q.tail = append(q.tail, w) +} + +// popFront removes and returns the wantConn at the front of the queue. +func (q *wantConnQueue) popFront() *wantConn { + if q.headPos >= len(q.head) { + if len(q.tail) == 0 { + return nil + } + // Pick up tail as new head, clear tail. + q.head, q.headPos, q.tail = q.tail, 0, q.head[:0] + } + w := q.head[q.headPos] + q.head[q.headPos] = nil + q.headPos++ + return w +} + +// peekFront returns the wantConn at the front of the queue without removing it. +func (q *wantConnQueue) peekFront() *wantConn { + if q.headPos < len(q.head) { + return q.head[q.headPos] + } + if len(q.tail) > 0 { + return q.tail[0] + } + return nil +} + +// cleanFront pops any wantConns that are no longer waiting from the head of the +// queue, reporting whether any were popped. +func (q *wantConnQueue) cleanFront() (cleaned bool) { + for { + w := q.peekFront() + if w == nil || w.waiting() { + return cleaned + } + q.popFront() + cleaned = true + } +} + +func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) { + if t.DialTLSContext != nil { + conn, err = t.DialTLSContext(ctx, network, addr) + } else { + conn, err = t.DialTLS(network, addr) + } + if conn == nil && err == nil { + err = errors.New("net/http: Transport.DialTLS or DialTLSContext returned (nil, nil)") + } + return +} + +// getConn dials and creates a new persistConn to the target as +// specified in the connectMethod. This includes doing a proxy CONNECT +// and/or setting up TLS. If this doesn't return an error, the persistConn +// is ready to write requests to. +func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) { + req := treq.Request + trace := treq.trace + ctx := req.Context() + if trace != nil && trace.GetConn != nil { + trace.GetConn(cm.addr()) + } + + w := &wantConn{ + cm: cm, + key: cm.key(), + ctx: ctx, + ready: make(chan struct{}, 1), + beforeDial: testHookPrePendingDial, + afterDial: testHookPostPendingDial, + } + defer func() { + if err != nil { + w.cancel(t, err) + } + }() + + // Queue for idle connection. + if delivered := t.queueForIdleConn(w); delivered { + pc := w.pc + // Trace only for HTTP/1. + // HTTP/2 calls trace.GotConn itself. + if pc.alt == nil && trace != nil && trace.GotConn != nil { + trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) + } + // set request canceler to some non-nil function so we + // can detect whether it was cleared between now and when + // we enter roundTrip + t.setReqCanceler(treq.cancelKey, func(error) {}) + return pc, nil + } + + cancelc := make(chan error, 1) + t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err }) + + // Queue for permission to dial. + t.queueForDial(w) + + // Wait for completion or cancellation. + select { + case <-w.ready: + // Trace success but only for HTTP/1. + // HTTP/2 calls trace.GotConn itself. + if w.pc != nil && w.pc.alt == nil && trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) + } + if w.err != nil { + // If the request has been canceled, that's probably + // what caused w.err; if so, prefer to return the + // cancellation error (see golang.org/issue/16049). + select { + case <-req.Cancel: + return nil, errRequestCanceledConn + case <-req.Context().Done(): + return nil, req.Context().Err() + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err + default: + // return below + } + } + return w.pc, w.err + case <-req.Cancel: + return nil, errRequestCanceledConn + case <-req.Context().Done(): + return nil, req.Context().Err() + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err + } +} + +// queueForDial queues w to wait for permission to begin dialing. +// Once w receives permission to dial, it will do so in a separate goroutine. +func (t *Transport) queueForDial(w *wantConn) { + w.beforeDial() + if t.MaxConnsPerHost <= 0 { + go t.dialConnFor(w) + return + } + + t.connsPerHostMu.Lock() + defer t.connsPerHostMu.Unlock() + + if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost { + if t.connsPerHost == nil { + t.connsPerHost = make(map[connectMethodKey]int) + } + t.connsPerHost[w.key] = n + 1 + go t.dialConnFor(w) + return + } + + if t.connsPerHostWait == nil { + t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue) + } + q := t.connsPerHostWait[w.key] + q.cleanFront() + q.pushBack(w) + t.connsPerHostWait[w.key] = q +} + +// dialConnFor dials on behalf of w and delivers the result to w. +// dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()]. +// If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. +func (t *Transport) dialConnFor(w *wantConn) { + defer w.afterDial() + + pc, err := t.dialConn(w.ctx, w.cm) + delivered := w.tryDeliver(pc, err) + if err == nil && (!delivered || pc.alt != nil) { + // pconn was not passed to w, + // or it is HTTP/2 and can be shared. + // Add to the idle connection pool. + t.putOrCloseIdleConn(pc) + } + if err != nil { + t.decConnsPerHost(w.key) + } +} + +// decConnsPerHost decrements the per-host connection count for key, +// which may in turn give a different waiting goroutine permission to dial. +func (t *Transport) decConnsPerHost(key connectMethodKey) { + if t.MaxConnsPerHost <= 0 { + return + } + + t.connsPerHostMu.Lock() + defer t.connsPerHostMu.Unlock() + n := t.connsPerHost[key] + if n == 0 { + // Shouldn't happen, but if it does, the counting is buggy and could + // easily lead to a silent deadlock, so report the problem loudly. + panic("net/http: internal error: connCount underflow") + } + + // Can we hand this count to a goroutine still waiting to dial? + // (Some goroutines on the wait list may have timed out or + // gotten a connection another way. If they're all gone, + // we don't want to kick off any spurious dial operations.) + if q := t.connsPerHostWait[key]; q.len() > 0 { + done := false + for q.len() > 0 { + w := q.popFront() + if w.waiting() { + go t.dialConnFor(w) + done = true + break + } + } + if q.len() == 0 { + delete(t.connsPerHostWait, key) + } else { + // q is a value (like a slice), so we have to store + // the updated q back into the map. + t.connsPerHostWait[key] = q + } + if done { + return + } + } + + // Otherwise, decrement the recorded count. + if n--; n == 0 { + delete(t.connsPerHost, key) + } else { + t.connsPerHost[key] = n + } +} + +// Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS +// tunnel, this function establishes a nested TLS session inside the encrypted channel. +// The remote endpoint's name may be overridden by TLSClientConfig.ServerName. +func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace) error { + // Initiate TLS and check remote host name against certificate. + cfg := cloneTLSConfig(pconn.t.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = name + } + if pconn.cacheKey.onlyH1 { + cfg.NextProtos = nil + } + plainConn := pconn.conn + tlsConn := tls.Client(plainConn, cfg) + errc := make(chan error, 2) + var timer *time.Timer // for canceling TLS handshake + if d := pconn.t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := tlsConn.HandshakeContext(ctx) + if timer != nil { + timer.Stop() + } + errc <- err + }() + if err := <-errc; err != nil { + plainConn.Close() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } + return err + } + cs := tlsConn.ConnectionState() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(cs, nil) + } + pconn.tlsState = &cs + pconn.conn = tlsConn + return nil +} + +type erringRoundTripper interface { + RoundTripErr() error +} + +func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { + pconn = &persistConn{ + t: t, + cacheKey: cm.key(), + reqch: make(chan requestAndChan, 1), + writech: make(chan writeRequest, 1), + closech: make(chan struct{}), + writeErrCh: make(chan error, 1), + writeLoopDone: make(chan struct{}), + } + trace := httptrace.ContextClientTrace(ctx) + wrapErr := func(err error) error { + if cm.proxyURL != nil { + // Return a typed error, per Issue 16997 + return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err} + } + return err + } + if cm.scheme() == "https" && t.hasCustomTLSDialer() { + var err error + pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) + if err != nil { + return nil, wrapErr(err) + } + if tc, ok := pconn.conn.(*tls.Conn); ok { + // Handshake here, in case DialTLS didn't. TLSNextProto below + // depends on it for knowing the connection state. + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + if err := tc.HandshakeContext(ctx); err != nil { + go pconn.conn.Close() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } + return nil, err + } + cs := tc.ConnectionState() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(cs, nil) + } + pconn.tlsState = &cs + } + } else { + conn, err := t.dial(ctx, "tcp", cm.addr()) + if err != nil { + return nil, wrapErr(err) + } + pconn.conn = conn + if cm.scheme() == "https" { + var firstTLSHost string + if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { + return nil, wrapErr(err) + } + if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil { + return nil, wrapErr(err) + } + } + } + + // Proxy setup. + switch { + case cm.proxyURL == nil: + // Do nothing. Not using a proxy. + case cm.proxyURL.Scheme == "socks5": + conn := pconn.conn + d := socksNewDialer("tcp", conn.RemoteAddr().String()) + if u := cm.proxyURL.User; u != nil { + auth := &socksUsernamePassword{ + Username: u.Username(), + } + auth.Password, _ = u.Password() + d.AuthMethods = []socksAuthMethod{ + socksAuthMethodNotRequired, + socksAuthMethodUsernamePassword, + } + d.Authenticate = auth.Authenticate + } + if _, err := d.DialWithConn(ctx, conn, "tcp", cm.targetAddr); err != nil { + conn.Close() + return nil, err + } + case cm.targetScheme == "http": + pconn.isProxy = true + if pa := cm.proxyAuth(); pa != "" { + pconn.mutateHeaderFunc = func(h http.Header) { + h.Set("Proxy-Authorization", pa) + } + } + case cm.targetScheme == "https": + conn := pconn.conn + var hdr http.Header + if t.GetProxyConnectHeader != nil { + var err error + hdr, err = t.GetProxyConnectHeader(ctx, cm.proxyURL, cm.targetAddr) + if err != nil { + conn.Close() + return nil, err + } + } else { + hdr = t.ProxyConnectHeader + } + if hdr == nil { + hdr = make(http.Header) + } + if pa := cm.proxyAuth(); pa != "" { + hdr = hdr.Clone() + hdr.Set("Proxy-Authorization", pa) + } + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: cm.targetAddr}, + Host: cm.targetAddr, + Header: hdr, + } + + // If there's no done channel (no deadline or cancellation + // from the caller possible), at least set some (long) + // timeout here. This will make sure we don't block forever + // and leak a goroutine if the connection stops replying + // after the TCP connect. + connectCtx := ctx + if ctx.Done() == nil { + newCtx, cancel := context.WithTimeout(ctx, 1*time.Minute) + defer cancel() + connectCtx = newCtx + } + + didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails + var ( + resp *http.Response + err error // write or read error + ) + // Write the CONNECT request & read the response. + go func() { + defer close(didReadResponse) + err = connectReq.Write(conn) + if err != nil { + return + } + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(conn) + resp, err = http.ReadResponse(br, connectReq) + }() + select { + case <-connectCtx.Done(): + conn.Close() + <-didReadResponse + return nil, connectCtx.Err() + case <-didReadResponse: + // resp or err now set + } + if err != nil { + conn.Close() + return nil, err + } + if resp.StatusCode != 200 { + _, text, ok := strings.Cut(resp.Status, " ") + conn.Close() + if !ok { + return nil, errors.New("unknown status code") + } + return nil, errors.New(text) + } + } + + if cm.proxyURL != nil && cm.targetScheme == "https" { + if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil { + return nil, err + } + } + + if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { + if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { + alt := next(cm.targetAddr, pconn.conn.(*tls.Conn)) + if e, ok := alt.(erringRoundTripper); ok { + // pconn.conn was closed by next (http2configureTransports.upgradeFn). + return nil, e.RoundTripErr() + } + return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil + } + } + + pconn.br = bufio.NewReaderSize(pconn, t.readBufferSize()) + pconn.bw = bufio.NewWriterSize(persistConnWriter{pconn}, t.writeBufferSize()) + + go pconn.readLoop() + go pconn.writeLoop() + return pconn, nil +} + +// persistConnWriter is the io.Writer written to by pc.bw. +// It accumulates the number of bytes written to the underlying conn, +// so the retry logic can determine whether any bytes made it across +// the wire. +// This is exactly 1 pointer field wide so it can go into an interface +// without allocation. +type persistConnWriter struct { + pc *persistConn +} + +func (w persistConnWriter) Write(p []byte) (n int, err error) { + n, err = w.pc.conn.Write(p) + w.pc.nwrite += int64(n) + return +} + +// ReadFrom exposes persistConnWriter's underlying Conn to io.Copy and if +// the Conn implements io.ReaderFrom, it can take advantage of optimizations +// such as sendfile. +func (w persistConnWriter) ReadFrom(r io.Reader) (n int64, err error) { + n, err = io.Copy(w.pc.conn, r) + w.pc.nwrite += n + return +} + +var _ io.ReaderFrom = (*persistConnWriter)(nil) + +// connectMethod is the map key (in its String form) for keeping persistent +// TCP connections alive for subsequent HTTP requests. +// +// A connect method may be of the following types: +// +// connectMethod.key().String() Description +// ------------------------------ ------------------------- +// |http|foo.com http directly to server, no proxy +// |https|foo.com https directly to server, no proxy +// |https,h1|foo.com https directly to server w/o HTTP/2, no proxy +// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com +// http://proxy.com|http http to proxy, http to anywhere after that +// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com +// socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com +// https://proxy.com|https|foo.com https to proxy, then CONNECT to foo.com +// https://proxy.com|http https to proxy, http to anywhere after that +// +type connectMethod struct { + _ incomparable + proxyURL *url.URL // nil for no proxy, else full proxy URL + targetScheme string // "http" or "https" + // If proxyURL specifies an http or https proxy, and targetScheme is http (not https), + // then targetAddr is not included in the connect method key, because the socket can + // be reused for different targetAddr values. + targetAddr string + onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 +} + +func (cm *connectMethod) key() connectMethodKey { + proxyStr := "" + targetAddr := cm.targetAddr + if cm.proxyURL != nil { + proxyStr = cm.proxyURL.String() + if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" { + targetAddr = "" + } + } + return connectMethodKey{ + proxy: proxyStr, + scheme: cm.targetScheme, + addr: targetAddr, + onlyH1: cm.onlyH1, + } +} + +// scheme returns the first hop scheme: http, https, or socks5 +func (cm *connectMethod) scheme() string { + if cm.proxyURL != nil { + return cm.proxyURL.Scheme + } + return cm.targetScheme +} + +// addr returns the first hop "host:port" to which we need to TCP connect. +func (cm *connectMethod) addr() string { + if cm.proxyURL != nil { + return canonicalAddr(cm.proxyURL) + } + return cm.targetAddr +} + +// tlsHost returns the host name to match against the peer's +// TLS certificate. +func (cm *connectMethod) tlsHost() string { + h := cm.targetAddr + if hasPort(h) { + h = h[:strings.LastIndex(h, ":")] + } + return h +} + +// connectMethodKey is the map key version of connectMethod, with a +// stringified proxy URL (or the empty string) instead of a pointer to +// a URL. +type connectMethodKey struct { + proxy, scheme, addr string + onlyH1 bool +} + +func (k connectMethodKey) String() string { + // Only used by tests. + var h1 string + if k.onlyH1 { + h1 = ",h1" + } + return fmt.Sprintf("%s|%s%s|%s", k.proxy, k.scheme, h1, k.addr) +} + +// persistConn wraps a connection, usually a persistent one +// (but may be used for non-keep-alive requests as well) +type persistConn struct { + // alt optionally specifies the TLS NextProto http.RoundTripper. + // This is used for HTTP/2 today and future protocols later. + // If it's non-nil, the rest of the fields are unused. + alt http.RoundTripper + + t *Transport + cacheKey connectMethodKey + conn net.Conn + tlsState *tls.ConnectionState + br *bufio.Reader // from conn + bw *bufio.Writer // to conn + nwrite int64 // bytes written + reqch chan requestAndChan // written by roundTrip; read by readLoop + writech chan writeRequest // written by roundTrip; read by writeLoop + closech chan struct{} // closed when conn closed + isProxy bool + sawEOF bool // whether we've seen EOF from conn; owned by readLoop + readLimit int64 // bytes allowed to be read; owned by readLoop + // writeErrCh passes the request write error (usually nil) + // from the writeLoop goroutine to the readLoop which passes + // it off to the res.Body reader, which then uses it to decide + // whether or not a connection can be reused. Issue 7569. + writeErrCh chan error + + writeLoopDone chan struct{} // closed when write loop ends + + // Both guarded by Transport.idleMu: + idleAt time.Time // time it last become idle + idleTimer *time.Timer // holding an AfterFunc to close it + + mu sync.Mutex // guards following fields + numExpectedResponses int + closed error // set non-nil when conn is closed, before closech is closed + canceledErr error // set non-nil if conn is canceled + broken bool // an error has happened on this connection; marked broken so it's not reused. + reused bool // whether conn has had successful request/response and is being reused. + // mutateHeaderFunc is an optional func to modify extra + // headers on each outbound request before it's written. (the + // original Request given to RoundTrip is not modified) + mutateHeaderFunc func(http.Header) +} + +func (pc *persistConn) maxHeaderResponseSize() int64 { + if v := pc.t.MaxResponseHeaderBytes; v != 0 { + return v + } + return 10 << 20 // conservative default; same as http2 +} + +func (pc *persistConn) Read(p []byte) (n int, err error) { + if pc.readLimit <= 0 { + return 0, fmt.Errorf("read limit of %d bytes exhausted", pc.maxHeaderResponseSize()) + } + if int64(len(p)) > pc.readLimit { + p = p[:pc.readLimit] + } + n, err = pc.conn.Read(p) + if err == io.EOF { + pc.sawEOF = true + } + pc.readLimit -= int64(n) + return +} + +// isBroken reports whether this connection is in a known broken state. +func (pc *persistConn) isBroken() bool { + pc.mu.Lock() + b := pc.closed != nil + pc.mu.Unlock() + return b +} + +// canceled returns non-nil if the connection was closed due to +// CancelRequest or due to context cancellation. +func (pc *persistConn) canceled() error { + pc.mu.Lock() + defer pc.mu.Unlock() + return pc.canceledErr +} + +// isReused reports whether this connection has been used before. +func (pc *persistConn) isReused() bool { + pc.mu.Lock() + r := pc.reused + pc.mu.Unlock() + return r +} + +func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnInfo) { + pc.mu.Lock() + defer pc.mu.Unlock() + t.Reused = pc.reused + t.Conn = pc.conn + t.WasIdle = true + if !idleAt.IsZero() { + t.IdleTime = time.Since(idleAt) + } + return +} + +func (pc *persistConn) cancelRequest(err error) { + pc.mu.Lock() + defer pc.mu.Unlock() + pc.canceledErr = err + pc.closeLocked(errRequestCanceled) +} + +// closeConnIfStillIdle closes the connection if it's still sitting idle. +// This is what's called by the persistConn's idleTimer, and is run in its +// own goroutine. +func (pc *persistConn) closeConnIfStillIdle() { + t := pc.t + t.idleMu.Lock() + defer t.idleMu.Unlock() + if _, ok := t.idleLRU.m[pc]; !ok { + // Not idle. + return + } + t.removeIdleConnLocked(pc) + pc.close(errIdleConnTimeout) +} + +// mapRoundTripError returns the appropriate error value for +// persistConn.roundTrip. +// +// The provided err is the first error that (*persistConn).roundTrip +// happened to receive from its select statement. +// +// The startBytesWritten value should be the value of pc.nwrite before the roundTrip +// started writing the request. +func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritten int64, err error) error { + if err == nil { + return nil + } + + // Wait for the writeLoop goroutine to terminate to avoid data + // races on callers who mutate the request on failure. + // + // When resc in pc.roundTrip and hence rc.ch receives a responseAndError + // with a non-nil error it implies that the persistConn is either closed + // or closing. Waiting on pc.writeLoopDone is hence safe as all callers + // close closech which in turn ensures writeLoop returns. + <-pc.writeLoopDone + + // If the request was canceled, that's better than network + // failures that were likely the result of tearing down the + // connection. + if cerr := pc.canceled(); cerr != nil { + return cerr + } + + // See if an error was set explicitly. + req.mu.Lock() + reqErr := req.err + req.mu.Unlock() + if reqErr != nil { + return reqErr + } + + if err == errServerClosedIdle { + // Don't decorate + return err + } + + if _, ok := err.(transportReadFromServerError); ok { + // Don't decorate + return err + } + if pc.isBroken() { + if pc.nwrite == startBytesWritten { + return nothingWrittenError{err} + } + return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err) + } + return err +} + +// errCallerOwnsConn is an internal sentinel error used when we hand +// off a writable response.Body to the caller. We use this to prevent +// closing a net.Conn that is now owned by the caller. +var errCallerOwnsConn = errors.New("read loop ending; caller owns writable underlying conn") + +func (pc *persistConn) readLoop() { + closeErr := errReadLoopExiting // default value, if not changed below + defer func() { + pc.close(closeErr) + pc.t.removeIdleConn(pc) + }() + + tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { + if err := pc.t.tryPutIdleConn(pc); err != nil { + closeErr = err + if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + trace.PutIdleConn(err) + } + return false + } + if trace != nil && trace.PutIdleConn != nil { + trace.PutIdleConn(nil) + } + return true + } + + // eofc is used to block caller goroutines reading from Response.Body + // at EOF until this goroutines has (potentially) added the connection + // back to the idle pool. + eofc := make(chan struct{}) + defer close(eofc) // unblock reader on errors + + // Read this once, before loop starts. (to avoid races in tests) + testHookMu.Lock() + testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead + testHookMu.Unlock() + + alive := true + for alive { + pc.readLimit = pc.maxHeaderResponseSize() + _, err := pc.br.Peek(1) + + pc.mu.Lock() + if pc.numExpectedResponses == 0 { + pc.readLoopPeekFailLocked(err) + pc.mu.Unlock() + return + } + pc.mu.Unlock() + + rc := <-pc.reqch + trace := httptrace.ContextClientTrace(rc.req.Context()) + + var resp *http.Response + if err == nil { + resp, err = pc.readResponse(rc, trace) + } else { + err = transportReadFromServerError{err} + closeErr = err + } + + if err != nil { + if pc.readLimit <= 0 { + err = fmt.Errorf("net/http: server response headers exceeded %d bytes; aborted", pc.maxHeaderResponseSize()) + } + + select { + case rc.ch <- responseAndError{err: err}: + case <-rc.callerGone: + return + } + return + } + pc.readLimit = maxInt64 // effectively no limit for response bodies + + pc.mu.Lock() + pc.numExpectedResponses-- + pc.mu.Unlock() + + bodyWritable := bodyIsWritable(resp) + hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 + + if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. + alive = false + } + + if !hasBody || bodyWritable { + replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) + + // Put the idle conn back into the pool before we send the response + // so if they process it quickly and make another request, they'll + // get this same conn. But we use the unbuffered channel 'rc' + // to guarantee that persistConn.roundTrip got out of its select + // potentially waiting for this persistConn to close. + alive = alive && + !pc.sawEOF && + pc.wroteRequest() && + replaced && tryPutIdleConn(trace) + + if bodyWritable { + closeErr = errCallerOwnsConn + } + + select { + case rc.ch <- responseAndError{res: resp}: + case <-rc.callerGone: + return + } + + // Now that they've read from the unbuffered channel, they're safely + // out of the select that also waits on this goroutine to die, so + // we're allowed to exit now if needed (if alive is false) + testHookReadLoopBeforeNextRead() + continue + } + + waitForBodyRead := make(chan bool, 2) + body := &bodyEOFSignal{ + body: resp.Body, + earlyCloseFn: func() error { + waitForBodyRead <- false + <-eofc // will be closed by deferred call at the end of the function + return nil + + }, + fn: func(err error) error { + isEOF := err == io.EOF + waitForBodyRead <- isEOF + if isEOF { + <-eofc // see comment above eofc declaration + } else if err != nil { + if cerr := pc.canceled(); cerr != nil { + return cerr + } + } + return err + }, + } + + resp.Body = body + if rc.addedGzip && ascii.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + resp.Body = &gzipReader{body: body} + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + resp.Uncompressed = true + } + + select { + case rc.ch <- responseAndError{res: resp}: + case <-rc.callerGone: + return + } + + // Before looping back to the top of this function and peeking on + // the bufio.textprotoReader, wait for the caller goroutine to finish + // reading the response body. (or for cancellation or death) + select { + case bodyEOF := <-waitForBodyRead: + replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool + alive = alive && + bodyEOF && + !pc.sawEOF && + pc.wroteRequest() && + replaced && tryPutIdleConn(trace) + if bodyEOF { + eofc <- struct{}{} + } + case <-rc.req.Cancel: + alive = false + pc.t.CancelRequest(rc.req) + case <-rc.req.Context().Done(): + alive = false + pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) + case <-pc.closech: + alive = false + } + + testHookReadLoopBeforeNextRead() + } +} + +func (pc *persistConn) readLoopPeekFailLocked(peekErr error) { + if pc.closed != nil { + return + } + if n := pc.br.Buffered(); n > 0 { + buf, _ := pc.br.Peek(n) + if is408Message(buf) { + pc.closeLocked(errServerClosedIdle) + return + } else { + log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", buf, peekErr) + } + } + if peekErr == io.EOF { + // common case. + pc.closeLocked(errServerClosedIdle) + } else { + pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %v", peekErr)) + } +} + +// is408Message reports whether buf has the prefix of an +// HTTP 408 Request Timeout response. +// See golang.org/issue/32310. +func is408Message(buf []byte) bool { + if len(buf) < len("HTTP/1.x 408") { + return false + } + if string(buf[:7]) != "HTTP/1." { + return false + } + return string(buf[8:12]) == " 408" +} + +// readResponse reads an HTTP response (or two, in the case of "Expect: +// 100-continue") from the server. It returns the final non-100 one. +// trace is optional. +func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTrace) (resp *http.Response, err error) { + if trace != nil && trace.GotFirstResponseByte != nil { + if peek, err := pc.br.Peek(1); err == nil && len(peek) == 1 { + trace.GotFirstResponseByte() + } + } + num1xx := 0 // number of informational 1xx headers received + const max1xxResponses = 5 // arbitrary bound on number of informational responses + + continueCh := rc.continueCh + for { + resp, err = pc._readResponse(rc.req) + if err != nil { + return + } + resCode := resp.StatusCode + if continueCh != nil { + if resCode == 100 { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() + } + continueCh <- struct{}{} + continueCh = nil + } else if resCode >= 200 { + close(continueCh) + continueCh = nil + } + } + is1xx := 100 <= resCode && resCode <= 199 + // treat 101 as a terminal status, see issue 26161 + is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols + + if is1xxNonTerminal { + num1xx++ + if num1xx > max1xxResponses { + return nil, errors.New("net/http: too many 1xx informational responses") + } + pc.readLimit = pc.maxHeaderResponseSize() // reset the limit + if trace != nil && trace.Got1xxResponse != nil { + if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(resp.Header)); err != nil { + return nil, err + } + } + continue + } + break + } + if isProtocolSwitch(resp) { + resp.Body = newReadWriteCloserBody(pc.br, pc.conn) + } + + resp.TLS = pc.tlsState + return +} + +// waitForContinue returns the function to block until +// any response, timeout or connection close. After any of them, +// the function returns a bool which indicates if the body should be sent. +func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool { + if continueCh == nil { + return nil + } + return func() bool { + timer := time.NewTimer(pc.t.ExpectContinueTimeout) + defer timer.Stop() + + select { + case _, ok := <-continueCh: + return ok + case <-timer.C: + return true + case <-pc.closech: + return false + } + } +} + +func newReadWriteCloserBody(br *bufio.Reader, rwc io.ReadWriteCloser) io.ReadWriteCloser { + body := &readWriteCloserBody{ReadWriteCloser: rwc} + if br.Buffered() != 0 { + body.br = br + } + return body +} + +// readWriteCloserBody is the Response.Body type used when we want to +// give users write access to the Body through the underlying +// connection (TCP, unless using custom dialers). This is then +// the concrete type for a Response.Body on the 101 Switching +// Protocols response, as used by WebSockets, h2c, etc. +type readWriteCloserBody struct { + _ incomparable + br *bufio.Reader // used until empty + io.ReadWriteCloser +} + +func (b *readWriteCloserBody) Read(p []byte) (n int, err error) { + if b.br != nil { + if n := b.br.Buffered(); len(p) > n { + p = p[:n] + } + n, err = b.br.Read(p) + if b.br.Buffered() == 0 { + b.br = nil + } + return n, err + } + return b.ReadWriteCloser.Read(p) +} + +// nothingWrittenError wraps a write errors which ended up writing zero bytes. +type nothingWrittenError struct { + error +} + +func (pc *persistConn) writeLoop() { + defer close(pc.writeLoopDone) + for { + select { + case wr := <-pc.writech: + startBytesWritten := pc.nwrite + err := requestWrite(wr.req.Request, pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh), pc.t.dump) + if bre, ok := err.(requestBodyReadError); ok { + err = bre.error + // Errors reading from the user's + // Request.Body are high priority. + // Set it here before sending on the + // channels below or calling + // pc.close() which tears down + // connections and causes other + // errors. + wr.req.setError(err) + } + if err == nil { + err = pc.bw.Flush() + } + if err != nil { + if pc.nwrite == startBytesWritten { + err = nothingWrittenError{err} + } + } + pc.writeErrCh <- err // to the body reader, which might recycle us + wr.ch <- err // to the roundTrip function + if err != nil { + pc.close(err) + return + } + case <-pc.closech: + return + } + } +} + +// maxWriteWaitBeforeConnReuse is how long the a Transport RoundTrip +// will wait to see the Request's Body.Write result after getting a +// response from the server. See comments in (*persistConn).wroteRequest. +const maxWriteWaitBeforeConnReuse = 50 * time.Millisecond + +// wroteRequest is a check before recycling a connection that the previous write +// (from writeLoop above) happened and was successful. +func (pc *persistConn) wroteRequest() bool { + select { + case err := <-pc.writeErrCh: + // Common case: the write happened well before the response, so + // avoid creating a timer. + return err == nil + default: + // Rare case: the request was written in writeLoop above but + // before it could send to pc.writeErrCh, the reader read it + // all, processed it, and called us here. In this case, give the + // write goroutine a bit of time to finish its send. + // + // Less rare case: We also get here in the legitimate case of + // Issue 7569, where the writer is still writing (or stalled), + // but the server has already replied. In this case, we don't + // want to wait too long, and we want to return false so this + // connection isn't re-used. + t := time.NewTimer(maxWriteWaitBeforeConnReuse) + defer t.Stop() + select { + case err := <-pc.writeErrCh: + return err == nil + case <-t.C: + return false + } + } +} + +// responseAndError is how the goroutine reading from an HTTP/1 server +// communicates with the goroutine doing the RoundTrip. +type responseAndError struct { + _ incomparable + res *http.Response // else use this response (see res method) + err error +} + +type requestAndChan struct { + _ incomparable + req *http.Request + cancelKey cancelKey + ch chan responseAndError // unbuffered; always send in select on callerGone + + // whether the Transport (as opposed to the user client code) + // added the Accept-Encoding gzip header. If the Transport + // set it, only then do we transparently decode the gzip. + addedGzip bool + + // Optional blocking chan for Expect: 100-continue (for send). + // If the request has an "Expect: 100-continue" header and + // the server responds 100 Continue, readLoop send a value + // to writeLoop via this chan. + continueCh chan<- struct{} + + callerGone <-chan struct{} // closed when roundTrip caller has returned +} + +// A writeRequest is sent by the caller's goroutine to the +// writeLoop's goroutine to write a request while the read loop +// concurrently waits on both the write response and the server's +// reply. +type writeRequest struct { + req *transportRequest + ch chan<- error + + // Optional blocking chan for Expect: 100-continue (for receive). + // If not nil, writeLoop blocks sending request body until + // it receives from this chan. + continueCh <-chan struct{} +} + +type httpError struct { + err string + timeout bool +} + +func (e *httpError) Error() string { return e.err } +func (e *httpError) Timeout() bool { return e.timeout } +func (e *httpError) Temporary() bool { return true } + +var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} + +// errRequestCanceled is set to be identical to the one from h2 to facilitate +// testing. +var errRequestCanceled = http2errRequestCanceled +var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? + +func nop() {} + +// testHooks. Always non-nil. +var ( + testHookEnterRoundTrip = nop + testHookWaitResLoop = nop + testHookRoundTripRetried = nop + testHookPrePendingDial = nop + testHookPostPendingDial = nop + + testHookMu sync.Locker = fakeLocker{} // guards following + testHookReadLoopBeforeNextRead = nop +) + +func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, err error) { + testHookEnterRoundTrip() + if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { + pc.t.putOrCloseIdleConn(pc) + return nil, errRequestCanceled + } + pc.mu.Lock() + pc.numExpectedResponses++ + headerFn := pc.mutateHeaderFunc + pc.mu.Unlock() + + if headerFn != nil { + headerFn(req.extraHeaders()) + } + + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempt to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + if !pc.t.DisableCompression && + req.Header.Get("Accept-Encoding") == "" && + req.Header.Get("Range") == "" && + req.Method != "HEAD" { + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: https://zlib.net/zlib_faq.html#faq39 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // https://trac.nginx.org/nginx/ticket/358 + // https://golang.org/issue/5522 + // + // We don't request gzip if the request is for a range, since + // auto-decoding a portion of a gzipped document will just fail + // anyway. See https://golang.org/issue/8923 + requestedGzip = true + req.extraHeaders().Set("Accept-Encoding", "gzip") + } + + var continueCh chan struct{} + if req.ProtoAtLeast(1, 1) && req.Body != nil && reqExpectsContinue(req.Request) { + continueCh = make(chan struct{}, 1) + } + + if pc.t.DisableKeepAlives && + !reqWantsClose(req.Request) && + !isProtocolSwitchHeader(req.Header) { + req.extraHeaders().Set("Connection", "close") + } + + gone := make(chan struct{}) + defer close(gone) + + defer func() { + if err != nil { + pc.t.setReqCanceler(req.cancelKey, nil) + } + }() + + const debugRoundTrip = false + + // Write the request concurrently with waiting for a response, + // in case the server decides to reply before reading our full + // request body. + startBytesWritten := pc.nwrite + writeErrCh := make(chan error, 1) + pc.writech <- writeRequest{req, writeErrCh, continueCh} + + resc := make(chan responseAndError) + pc.reqch <- requestAndChan{ + req: req.Request, + cancelKey: req.cancelKey, + ch: resc, + addedGzip: requestedGzip, + continueCh: continueCh, + callerGone: gone, + } + + var respHeaderTimer <-chan time.Time + cancelChan := req.Request.Cancel + ctxDoneChan := req.Context().Done() + pcClosed := pc.closech + canceled := false + for { + testHookWaitResLoop() + select { + case err := <-writeErrCh: + if debugRoundTrip { + req.logf("writeErrCh resv: %T/%#v", err, err) + } + if err != nil { + pc.close(fmt.Errorf("write error: %v", err)) + return nil, pc.mapRoundTripError(req, startBytesWritten, err) + } + if d := pc.t.ResponseHeaderTimeout; d > 0 { + if debugRoundTrip { + req.logf("starting timer for %v", d) + } + timer := time.NewTimer(d) + defer timer.Stop() // prevent leaks + respHeaderTimer = timer.C + } + case <-pcClosed: + pcClosed = nil + if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { + if debugRoundTrip { + req.logf("closech recv: %T %#v", pc.closed, pc.closed) + } + return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) + } + case <-respHeaderTimer: + if debugRoundTrip { + req.logf("timeout waiting for response headers.") + } + pc.close(errTimeout) + return nil, errTimeout + case re := <-resc: + if (re.res == nil) == (re.err == nil) { + panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil)) + } + if debugRoundTrip { + req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) + } + if re.err != nil { + return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) + } + return re.res, nil + case <-cancelChan: + canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) + cancelChan = nil + case <-ctxDoneChan: + canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err()) + cancelChan = nil + ctxDoneChan = nil + } + } +} + +// tLogKey is a context WithValue key for test debugging contexts containing +// a t.Logf func. See export_test.go's Request.WithT method. +type tLogKey struct{} + +func (tr *transportRequest) logf(format string, args ...any) { + if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...any)); ok { + logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...) + } +} + +// markReused marks this connection as having been successfully used for a +// request and response. +func (pc *persistConn) markReused() { + pc.mu.Lock() + pc.reused = true + pc.mu.Unlock() +} + +// close closes the underlying TCP connection and closes +// the pc.closech channel. +// +// The provided err is only for testing and debugging; in normal +// circumstances it should never be seen by users. +func (pc *persistConn) close(err error) { + pc.mu.Lock() + defer pc.mu.Unlock() + pc.closeLocked(err) +} + +func (pc *persistConn) closeLocked(err error) { + if err == nil { + panic("nil error") + } + pc.broken = true + if pc.closed == nil { + pc.closed = err + pc.t.decConnsPerHost(pc.cacheKey) + // Close HTTP/1 (pc.alt == nil) connection. + // HTTP/2 closes its connection itself. + if pc.alt == nil { + if err != errCallerOwnsConn { + pc.conn.Close() + } + close(pc.closech) + } + } + pc.mutateHeaderFunc = nil +} + +var portMap = map[string]string{ + "http": "80", + "https": "443", + "socks5": "1080", +} + +// canonicalAddr returns url.Host but always with a ":port" suffix +func canonicalAddr(url *url.URL) string { + addr := url.Hostname() + if v, err := idnaASCII(addr); err == nil { + addr = v + } + port := url.Port() + if port == "" { + port = portMap[url.Scheme] + } + return net.JoinHostPort(addr, port) +} + +// bodyEOFSignal is used by the HTTP/1 transport when reading response +// bodies to make sure we see the end of a response body before +// proceeding and reading on the connection again. +// +// It wraps a ReadCloser but runs fn (if non-nil) at most +// once, right before its final (error-producing) Read or Close call +// returns. fn should return the new error to return from Read or Close. +// +// If earlyCloseFn is non-nil and Close is called before io.EOF is +// seen, earlyCloseFn is called instead of fn, and its return value is +// the return value from Close. +type bodyEOFSignal struct { + body io.ReadCloser + mu sync.Mutex // guards following 4 fields + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) error // err will be nil on Read io.EOF + earlyCloseFn func() error // optional alt Close func used if io.EOF not seen +} + +var errReadOnClosedResBody = errors.New("http: read on closed response body") + +func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { + es.mu.Lock() + closed, rerr := es.closed, es.rerr + es.mu.Unlock() + if closed { + return 0, errReadOnClosedResBody + } + if rerr != nil { + return 0, rerr + } + + n, err = es.body.Read(p) + if err != nil { + es.mu.Lock() + defer es.mu.Unlock() + if es.rerr == nil { + es.rerr = err + } + err = es.condfn(err) + } + return +} + +func (es *bodyEOFSignal) Close() error { + es.mu.Lock() + defer es.mu.Unlock() + if es.closed { + return nil + } + es.closed = true + if es.earlyCloseFn != nil && es.rerr != io.EOF { + return es.earlyCloseFn() + } + err := es.body.Close() + return es.condfn(err) +} + +// caller must hold es.mu. +func (es *bodyEOFSignal) condfn(err error) error { + if es.fn == nil { + return err + } + err = es.fn(err) + es.fn = nil + return err +} + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type gzipReader struct { + _ incomparable + body *bodyEOFSignal // underlying HTTP/1 response body framing + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // any error from gzip.NewReader; sticky +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zr == nil { + if gz.zerr == nil { + gz.zr, gz.zerr = gzip.NewReader(gz.body) + } + if gz.zerr != nil { + return 0, gz.zerr + } + } + + gz.body.mu.Lock() + if gz.body.closed { + err = errReadOnClosedResBody + } + gz.body.mu.Unlock() + + if err != nil { + return 0, err + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} + +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } + +// fakeLocker is a sync.Locker which does nothing. It's used to guard +// test-only fields when not under test, to avoid runtime atomic +// overhead. +type fakeLocker struct{} + +func (fakeLocker) Lock() {} +func (fakeLocker) Unlock() {} + +// cloneTLSConfig returns a shallow clone of cfg, or a new zero tls.Config if +// cfg is nil. This is safe to call even if cfg is in active use by a TLS +// client or server. +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} + +type connLRU struct { + ll *list.List // list.Element.Value type is of *persistConn + m map[*persistConn]*list.Element +} + +// add adds pc to the head of the linked list. +func (cl *connLRU) add(pc *persistConn) { + if cl.ll == nil { + cl.ll = list.New() + cl.m = make(map[*persistConn]*list.Element) + } + ele := cl.ll.PushFront(pc) + if _, ok := cl.m[pc]; ok { + panic("persistConn was already in LRU") + } + cl.m[pc] = ele +} + +func (cl *connLRU) removeOldest() *persistConn { + ele := cl.ll.Back() + pc := ele.Value.(*persistConn) + cl.ll.Remove(ele) + delete(cl.m, pc) + return pc +} + +// remove removes pc from cl. +func (cl *connLRU) remove(pc *persistConn) { + if ele, ok := cl.m[pc]; ok { + cl.ll.Remove(ele) + delete(cl.m, pc) + } +} + +// len returns the number of items in the cache. +func (cl *connLRU) len() int { + return len(cl.m) +} diff --git a/transport_default_js.go b/transport_default_js.go new file mode 100644 index 00000000..7cd8e335 --- /dev/null +++ b/transport_default_js.go @@ -0,0 +1,17 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build js && wasm +// +build js,wasm + +package req + +import ( + "context" + "net" +) + +func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return nil +} diff --git a/transport_default_other.go b/transport_default_other.go new file mode 100644 index 00000000..8191ea79 --- /dev/null +++ b/transport_default_other.go @@ -0,0 +1,17 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !(js && wasm) +// +build !js !wasm + +package req + +import ( + "context" + "net" +) + +func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return dialer.DialContext +} From d25097cccec5ae0c43ff50a1647c9a9769b7ac9f Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 11:10:18 +0800 Subject: [PATCH 025/843] change go mod to v2 --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index da3c0110..30e8e60f 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/imroc/req +module github.com/imroc/req/v2 go 1.16 From 580c5740971146c6b71601f7f81e29889e64ca7a Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 11:21:07 +0800 Subject: [PATCH 026/843] fix v2 import path --- http_request.go | 2 +- transfer.go | 4 ++-- transport.go | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/http_request.go b/http_request.go index 49439527..1a50b374 100644 --- a/http_request.go +++ b/http_request.go @@ -4,7 +4,7 @@ import ( "bufio" "errors" "fmt" - "github.com/imroc/req/internal/ascii" + "github.com/imroc/req/v2/internal/ascii" "golang.org/x/net/http/httpguts" "io" "net/http" diff --git a/transfer.go b/transfer.go index 88a23b5e..2d09d0c5 100644 --- a/transfer.go +++ b/transfer.go @@ -9,8 +9,8 @@ import ( "bytes" "errors" "fmt" - "github.com/imroc/req/internal" - "github.com/imroc/req/internal/ascii" + "github.com/imroc/req/v2/internal" + "github.com/imroc/req/v2/internal/ascii" "io" "net/http" "net/http/httptrace" diff --git a/transport.go b/transport.go index 8c0cd99a..bca0b287 100644 --- a/transport.go +++ b/transport.go @@ -17,8 +17,8 @@ import ( "crypto/tls" "errors" "fmt" - "github.com/imroc/req/internal/ascii" - "github.com/imroc/req/internal/godebug" + "github.com/imroc/req/v2/internal/ascii" + "github.com/imroc/req/v2/internal/godebug" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" "io" From 2ad70115c0153e63cfafbd2d8349d67df4c7cbb9 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 11:27:09 +0800 Subject: [PATCH 027/843] any --> interface{} --- header.go | 2 +- http.go | 2 +- transfer.go | 6 +++--- transport.go | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/header.go b/header.go index 28e2a9ae..eb991c49 100644 --- a/header.go +++ b/header.go @@ -39,7 +39,7 @@ func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kv func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } var headerSorterPool = sync.Pool{ - New: func() any { return new(headerSorter) }, + New: func() interface{} { return new(headerSorter) }, } // get is like Get, but key must already be in CanonicalHeaderKey form. diff --git a/http.go b/http.go index 55f0a2be..54fe1369 100644 --- a/http.go +++ b/http.go @@ -3,7 +3,7 @@ package req import ( "encoding/base64" "fmt" - "github.com/imroc/req/internal/ascii" + "github.com/imroc/req/v2/internal/ascii" "golang.org/x/net/http/httpguts" "golang.org/x/net/idna" "io" diff --git a/transfer.go b/transfer.go index 2d09d0c5..9b9c2d1b 100644 --- a/transfer.go +++ b/transfer.go @@ -74,7 +74,7 @@ type transferWriter struct { ByteReadCh chan readResult // non-nil if probeRequestBody called } -func newTransferWriter(r any) (t *transferWriter, err error) { +func newTransferWriter(r interface{}) (t *transferWriter, err error) { t = &transferWriter{} // Extract relevant fields @@ -493,7 +493,7 @@ func suppressedHeaders(status int) []string { } // msg is *http.Request or *http.Response. -func readTransfer(msg any, r *bufio.Reader) (err error) { +func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t := &transferReader{RequestMethod: "GET"} // Unify input @@ -821,7 +821,7 @@ func fixTrailer(header http.Header, chunked bool) (http.Header, error) { // and then reads the trailer if necessary. type body struct { src io.Reader - hdr any // non-nil (Response or Request) value means read trailer + hdr interface{} // non-nil (Response or Request) value means read trailer r *bufio.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? doEarlyClose bool // whether Close should stop early diff --git a/transport.go b/transport.go index bca0b287..0faeccad 100644 --- a/transport.go +++ b/transport.go @@ -2706,8 +2706,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, er // a t.Logf func. See export_test.go's Request.WithT method. type tLogKey struct{} -func (tr *transportRequest) logf(format string, args ...any) { - if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...any)); ok { +func (tr *transportRequest) logf(format string, args ...interface{}) { + if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...interface{})); ok { logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...) } } From 6fba77dfb4f8d7fdc6c09faf7b57d86340da26be Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 11:44:23 +0800 Subject: [PATCH 028/843] port go1.18's strings.Cut to cutString to support lower go version --- common.go | 16 ++++++++++++++++ http_response.go | 4 ++-- transport.go | 2 +- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/common.go b/common.go index 7af8cd06..c85804eb 100644 --- a/common.go +++ b/common.go @@ -1,5 +1,7 @@ package req +import "strings" + const ( CONTENT_TYPE_APPLICATION_JSON_UTF8 = "application/json; charset=UTF-8" CONTENT_TYPE_APPLICATION_XML_UTF8 = "application/xml; charset=UTF-8" @@ -7,3 +9,17 @@ const ( CONTENT_TYPE_TEXT_HTML_UTF8 = "text/html; charset=UTF-8" CONTENT_TYPE_TEXT_PLAIN_UTF8 = "text/plain; charset=UTF-8" ) + + +// cutString is a string util function which is copied +// from go1.18 strings package. +// cutString slices s around the first instance of sep, +// returning the text before and after sep. +// The found result reports whether sep appears in s. +// If sep does not appear in s, cut returns s, "", false. +func cutString(s, sep string) (before, after string, found bool) { + if i := strings.Index(s, sep); i >= 0 { + return s[:i], s[i+len(sep):], true + } + return s, "", false +} diff --git a/http_response.go b/http_response.go index db636701..ac9cbfa4 100644 --- a/http_response.go +++ b/http_response.go @@ -40,14 +40,14 @@ func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) } return nil, err } - proto, status, ok := strings.Cut(line, " ") + proto, status, ok := cutString(line, " ") if !ok { return nil, badStringError("malformed HTTP response", line) } resp.Proto = proto resp.Status = strings.TrimLeft(status, " ") - statusCode, _, _ := strings.Cut(resp.Status, " ") + statusCode, _, _ := cutString(resp.Status, " ") if len(statusCode) != 3 { return nil, badStringError("malformed HTTP status code", statusCode) } diff --git a/transport.go b/transport.go index 0faeccad..5d0153f8 100644 --- a/transport.go +++ b/transport.go @@ -1752,7 +1752,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers return nil, err } if resp.StatusCode != 200 { - _, text, ok := strings.Cut(resp.Status, " ") + _, text, ok := cutString(resp.Status, " ") conn.Close() if !ok { return nil, errors.New("unknown status code") From 6b00f479135dea68600cb4625a3144af11ff6286 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 11:45:39 +0800 Subject: [PATCH 029/843] fix cutString comments --- common.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common.go b/common.go index c85804eb..bddf56ea 100644 --- a/common.go +++ b/common.go @@ -11,7 +11,7 @@ const ( ) -// cutString is a string util function which is copied +// cutString is a string util function which is ported // from go1.18 strings package. // cutString slices s around the first instance of sep, // returning the text before and after sep. From 1b0158803eba4bdf41d447179a9bc6120ff6a2ed Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 11:51:18 +0800 Subject: [PATCH 030/843] port go1.18's bytes.Cut to bytesCut to support lower go version --- internal/chunked.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/internal/chunked.go b/internal/chunked.go index 37a72e90..9163b1de 100644 --- a/internal/chunked.go +++ b/internal/chunked.go @@ -168,13 +168,26 @@ var semi = []byte(";") // "0;token=val" => "0" // `0;token="quoted string"` => "0" func removeChunkExtension(p []byte) ([]byte, error) { - p, _, _ = bytes.Cut(p, semi) + p, _, _ = bytesCut(p, semi) // TODO: care about exact syntax of chunk extensions? We're // ignoring and stripping them anyway. For now just never // return an error. return p, nil } +// bytesCut slices s around the first instance of sep, +// returning the text before and after sep. +// The found result reports whether sep appears in s. +// If sep does not appear in s, cut returns s, nil, false. +// +// Cut returns slices of the original slice s, not copies. +func bytesCut(s, sep []byte) (before, after []byte, found bool) { + if i := bytes.Index(s, sep); i >= 0 { + return s[:i], s[i+len(sep):], true + } + return s, nil, false +} + // NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP // "chunked" format before writing them to w. Closing the returned chunkedWriter // sends the final 0-length chunk that marks the end of the stream but does From e341469b85adfb47a8c85896db70b3509fb87f07 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 11:53:51 +0800 Subject: [PATCH 031/843] fix another bytes.Cut --- common.go | 18 +++++++++++++++++- textproto_reader.go | 2 +- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/common.go b/common.go index bddf56ea..08741aba 100644 --- a/common.go +++ b/common.go @@ -1,6 +1,9 @@ package req -import "strings" +import ( + "bytes" + "strings" +) const ( CONTENT_TYPE_APPLICATION_JSON_UTF8 = "application/json; charset=UTF-8" @@ -23,3 +26,16 @@ func cutString(s, sep string) (before, after string, found bool) { } return s, "", false } + +// bytesCut slices s around the first instance of sep, +// returning the text before and after sep. +// The found result reports whether sep appears in s. +// If sep does not appear in s, cut returns s, nil, false. +// +// Cut returns slices of the original slice s, not copies. +func bytesCut(s, sep []byte) (before, after []byte, found bool) { + if i := bytes.Index(s, sep); i >= 0 { + return s[:i], s[i+len(sep):], true + } + return s, nil, false +} \ No newline at end of file diff --git a/textproto_reader.go b/textproto_reader.go index 12cec208..0519876e 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -564,7 +564,7 @@ func (r *textprotoReader) ReadMIMEHeader() (textproto.MIMEHeader, error) { } // Key ends at first colon. - k, v, ok := bytes.Cut(kv, colon) + k, v, ok := bytesCut(kv, colon) if !ok { return m, protocolError("malformed MIME header line: " + string(kv)) } From 7336f86c7e92ff9d2df24d9e899353a1afe049c7 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 13:02:09 +0800 Subject: [PATCH 032/843] supoort lower go version: change import to ioutil from io 1. io.Discard --> ioutil.Discard 2. io.NopCloser --> ioutil.NopCloser 3. now req is compatible with go1.13+ --- request.go | 13 +++++++------ transfer.go | 11 ++++++----- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/request.go b/request.go index d885b18a..09d32ac1 100644 --- a/request.go +++ b/request.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/hashicorp/go-multierror" "io" + "io/ioutil" "net/http" urlpkg "net/url" "strings" @@ -153,7 +154,7 @@ func (r *Request) Body(body interface{}) *Request { case io.ReadCloser: r.httpRequest.Body = b case io.Reader: - r.httpRequest.Body = io.NopCloser(b) + r.httpRequest.Body = ioutil.NopCloser(b) case []byte: r.BodyBytes(b) case string: @@ -163,23 +164,23 @@ func (r *Request) Body(body interface{}) *Request { } func (r *Request) BodyBytes(body []byte) *Request { - r.httpRequest.Body = io.NopCloser(bytes.NewReader(body)) + r.httpRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) return r } func (r *Request) BodyString(body string) *Request { - r.httpRequest.Body = io.NopCloser(strings.NewReader(body)) + r.httpRequest.Body = ioutil.NopCloser(strings.NewReader(body)) return r } func (r *Request) BodyJsonString(body string) *Request { - r.httpRequest.Body = io.NopCloser(strings.NewReader(body)) + r.httpRequest.Body = ioutil.NopCloser(strings.NewReader(body)) r.setContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) return r } func (r *Request) BodyJsonBytes(body []byte) *Request { - r.httpRequest.Body = io.NopCloser(bytes.NewReader(body)) + r.httpRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) r.setContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) return r } @@ -216,7 +217,7 @@ func (r *Request) execute() (resp *Response, err error) { Response: httpResponse, } if r.client.t.Discard { - io.Copy(io.Discard, httpResponse.Body) + io.Copy(ioutil.Discard, httpResponse.Body) } return } diff --git a/transfer.go b/transfer.go index 9b9c2d1b..cf35a702 100644 --- a/transfer.go +++ b/transfer.go @@ -12,6 +12,7 @@ import ( "github.com/imroc/req/v2/internal" "github.com/imroc/req/v2/internal/ascii" "io" + "io/ioutil" "net/http" "net/http/httptrace" "net/textproto" @@ -157,7 +158,7 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { // servers. See Issue 18257, as one example. // // The only reason we'd send such a request is if the user set the Body to a -// non-nil value (say, io.NopCloser(bytes.NewReader(nil))) and didn't +// non-nil value (say, ioutil.NopCloser(bytes.NewReader(nil))) and didn't // set ContentLength, or NewRequest set it to -1 (unknown), so then we assume // there's bytes to send. // @@ -380,7 +381,7 @@ func (t *transferWriter) writeBody(w io.Writer, dump *dumper) (err error) { return err } var nextra int64 - nextra, err = t.doBodyCopy(io.Discard, body) + nextra, err = t.doBodyCopy(ioutil.Discard, body) ncopy += nextra } if err != nil { @@ -1005,7 +1006,7 @@ func (b *body) Close() error { var n int64 // Consume the body, or, which will also lead to us reading // the trailer headers after the body, if present. - n, err = io.CopyN(io.Discard, bodyLocked{b}, maxPostHandlerReadBytes) + n, err = io.CopyN(ioutil.Discard, bodyLocked{b}, maxPostHandlerReadBytes) if err == io.EOF { err = nil } @@ -1016,7 +1017,7 @@ func (b *body) Close() error { default: // Fully consume the body, which will also lead to us reading // the trailer headers after the body, if present. - _, err = io.Copy(io.Discard, bodyLocked{b}) + _, err = io.Copy(ioutil.Discard, bodyLocked{b}) } b.closed = true return err @@ -1091,7 +1092,7 @@ func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { return } -var nopCloserType = reflect.TypeOf(io.NopCloser(nil)) +var nopCloserType = reflect.TypeOf(ioutil.NopCloser(nil)) // isKnownInMemoryReader reports whether r is a type known to not // block on Read. Its caller uses this as an optional optimization to From ad9f707c11c0e6dc573f73625ba88ddfb3e7b98f Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 13:06:44 +0800 Subject: [PATCH 033/843] update go mod: compatible with go1.13+ --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 30e8e60f..8c7b6bc0 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/imroc/req/v2 -go 1.16 +go 1.13 require ( github.com/hashicorp/go-multierror v1.1.1 From 9649ed93aae13ff99812eb40369a9e3f886ac36b Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 13:11:27 +0800 Subject: [PATCH 034/843] move bytes.Cut and strings.Cut to internal util --- common.go | 34 +--------------------------------- http_response.go | 5 +++-- internal/chunked.go | 16 ++-------------- internal/util/util.go | 30 ++++++++++++++++++++++++++++++ textproto_reader.go | 3 ++- transport.go | 3 ++- 6 files changed, 40 insertions(+), 51 deletions(-) create mode 100644 internal/util/util.go diff --git a/common.go b/common.go index 08741aba..279a6cf9 100644 --- a/common.go +++ b/common.go @@ -1,41 +1,9 @@ package req -import ( - "bytes" - "strings" -) - const ( CONTENT_TYPE_APPLICATION_JSON_UTF8 = "application/json; charset=UTF-8" CONTENT_TYPE_APPLICATION_XML_UTF8 = "application/xml; charset=UTF-8" CONTENT_TYPE_TEXT_XML_UTF8 = "text/xml; charset=UTF-8" CONTENT_TYPE_TEXT_HTML_UTF8 = "text/html; charset=UTF-8" CONTENT_TYPE_TEXT_PLAIN_UTF8 = "text/plain; charset=UTF-8" -) - - -// cutString is a string util function which is ported -// from go1.18 strings package. -// cutString slices s around the first instance of sep, -// returning the text before and after sep. -// The found result reports whether sep appears in s. -// If sep does not appear in s, cut returns s, "", false. -func cutString(s, sep string) (before, after string, found bool) { - if i := strings.Index(s, sep); i >= 0 { - return s[:i], s[i+len(sep):], true - } - return s, "", false -} - -// bytesCut slices s around the first instance of sep, -// returning the text before and after sep. -// The found result reports whether sep appears in s. -// If sep does not appear in s, cut returns s, nil, false. -// -// Cut returns slices of the original slice s, not copies. -func bytesCut(s, sep []byte) (before, after []byte, found bool) { - if i := bytes.Index(s, sep); i >= 0 { - return s[:i], s[i+len(sep):], true - } - return s, nil, false -} \ No newline at end of file +) \ No newline at end of file diff --git a/http_response.go b/http_response.go index ac9cbfa4..6b0d2620 100644 --- a/http_response.go +++ b/http_response.go @@ -7,6 +7,7 @@ package req import ( + "github.com/imroc/req/v2/internal/util" "io" "net/http" "strconv" @@ -40,14 +41,14 @@ func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) } return nil, err } - proto, status, ok := cutString(line, " ") + proto, status, ok := util.CutString(line, " ") if !ok { return nil, badStringError("malformed HTTP response", line) } resp.Proto = proto resp.Status = strings.TrimLeft(status, " ") - statusCode, _, _ := cutString(resp.Status, " ") + statusCode, _, _ := util.CutString(resp.Status, " ") if len(statusCode) != 3 { return nil, badStringError("malformed HTTP status code", statusCode) } diff --git a/internal/chunked.go b/internal/chunked.go index 9163b1de..64bcb4d3 100644 --- a/internal/chunked.go +++ b/internal/chunked.go @@ -13,6 +13,7 @@ import ( "bytes" "errors" "fmt" + "github.com/imroc/req/v2/internal/util" "io" ) @@ -168,26 +169,13 @@ var semi = []byte(";") // "0;token=val" => "0" // `0;token="quoted string"` => "0" func removeChunkExtension(p []byte) ([]byte, error) { - p, _, _ = bytesCut(p, semi) + p, _, _ = util.CutBytes(p, semi) // TODO: care about exact syntax of chunk extensions? We're // ignoring and stripping them anyway. For now just never // return an error. return p, nil } -// bytesCut slices s around the first instance of sep, -// returning the text before and after sep. -// The found result reports whether sep appears in s. -// If sep does not appear in s, cut returns s, nil, false. -// -// Cut returns slices of the original slice s, not copies. -func bytesCut(s, sep []byte) (before, after []byte, found bool) { - if i := bytes.Index(s, sep); i >= 0 { - return s[:i], s[i+len(sep):], true - } - return s, nil, false -} - // NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP // "chunked" format before writing them to w. Closing the returned chunkedWriter // sends the final 0-length chunk that marks the end of the stream but does diff --git a/internal/util/util.go b/internal/util/util.go new file mode 100644 index 00000000..a6b933ca --- /dev/null +++ b/internal/util/util.go @@ -0,0 +1,30 @@ +package util + +import ( + "bytes" + "strings" +) + +// CutString slices s around the first instance of sep, +// returning the text before and after sep. +// The found result reports whether sep appears in s. +// If sep does not appear in s, cut returns s, "", false. +func CutString(s, sep string) (before, after string, found bool) { + if i := strings.Index(s, sep); i >= 0 { + return s[:i], s[i+len(sep):], true + } + return s, "", false +} + +// CutBytes slices s around the first instance of sep, +// returning the text before and after sep. +// The found result reports whether sep appears in s. +// If sep does not appear in s, cut returns s, nil, false. +// +// CutBytes returns slices of the original slice s, not copies. +func CutBytes(s, sep []byte) (before, after []byte, found bool) { + if i := bytes.Index(s, sep); i >= 0 { + return s[:i], s[i+len(sep):], true + } + return s, nil, false +} diff --git a/textproto_reader.go b/textproto_reader.go index 0519876e..182ad4c4 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "fmt" + "github.com/imroc/req/v2/internal/util" "io" "net/textproto" "strconv" @@ -564,7 +565,7 @@ func (r *textprotoReader) ReadMIMEHeader() (textproto.MIMEHeader, error) { } // Key ends at first colon. - k, v, ok := bytesCut(kv, colon) + k, v, ok := util.CutBytes(kv, colon) if !ok { return m, protocolError("malformed MIME header line: " + string(kv)) } diff --git a/transport.go b/transport.go index 5d0153f8..11dd57e4 100644 --- a/transport.go +++ b/transport.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/imroc/req/v2/internal/ascii" "github.com/imroc/req/v2/internal/godebug" + "github.com/imroc/req/v2/internal/util" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" "io" @@ -1752,7 +1753,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers return nil, err } if resp.StatusCode != 200 { - _, text, ok := cutString(resp.Status, " ") + _, text, ok := util.CutString(resp.Status, " ") conn.Close() if !ok { return nil, errors.New("unknown status code") From 97ed56b11b158d9c3493812b8e8914b9ec77c8b8 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 13:16:56 +0800 Subject: [PATCH 035/843] small refactor --- common.go => content_type.go | 0 internal/ascii/print.go | 4 ++++ textproto_reader.go | 4 ++-- 3 files changed, 6 insertions(+), 2 deletions(-) rename common.go => content_type.go (100%) diff --git a/common.go b/content_type.go similarity index 100% rename from common.go rename to content_type.go diff --git a/internal/ascii/print.go b/internal/ascii/print.go index 585e5bab..69f32262 100644 --- a/internal/ascii/print.go +++ b/internal/ascii/print.go @@ -59,3 +59,7 @@ func ToLower(s string) (lower string, ok bool) { } return strings.ToLower(s), true } + +func IsSpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} diff --git a/textproto_reader.go b/textproto_reader.go index 182ad4c4..7e1a283c 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -517,7 +517,7 @@ func (r *textprotoReader) ReadDotLines() ([]string, error) { var colon = []byte(":") -// Readtextproto.MIMEHeader reads a MIME-style header from r. +// ReadMIMEHeader reads a MIME-style header from r. // The header is a sequence of possibly continued Key: Value lines // ending in a blank line. // The returned map m maps Canonicaltextproto.MIMEHeaderKey(key) to a @@ -629,7 +629,7 @@ func (r *textprotoReader) upcomingHeaderNewlines() (n int) { return bytes.Count(peek, nl) } -// Canonicaltextproto.MIMEHeaderKey returns the canonical format of the +// CanonicalMIMEHeaderKey returns the canonical format of the // MIME header key s. The canonicalization converts the first // letter and any letter following a hyphen to upper case; // the rest are converted to lowercase. For example, the From 37e9eb0c57188d28182ffb76d9524275426761ee Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 13:57:11 +0800 Subject: [PATCH 036/843] small refactor --- README.md | 35 +++++++++++++++++++++++++---------- client.go | 21 +++++++++------------ request.go | 2 +- response.go | 16 +++++++--------- 4 files changed, 42 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 360d0d36..4cf118b7 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,34 @@ # req + [![GoDoc](https://pkg.go.dev/badge/github.com/imroc/req.svg)](https://pkg.go.dev/github.com/imroc/req) -A golang http request library for humans +A golang http request library for humans. -Features -======== +## Features -* Simple and chainable methods for client and request settings -* Rich syntax sugar, greatly improving development efficiency -* Powerful debugging capabilities -* The settings can be dynamically adjusted, making it possible to debug in the production environment +* Simple and chainable methods for client and request settings. +* Rich syntax sugar, greatly improving development efficiency. +* Automatically detect charset and decode to utf-8. +* Powerful debugging capabilities (logging, tracing, and event dump the requests and responses content). +* The settings can be dynamically adjusted, making it possible to debug in the production environment. +* Easy to integrate with existing code, just replace client's Transport you can dump requests and reponses to debug. +## Install -Install -======= ``` sh -go get github.com/imroc/req/v2 +go get github.com/imroc/req/v2@v2.0.0-alpha.0 +``` + +## Usage + +Import req in your code: + +```go +import "github.com/imroc/req/v2" +``` + +Prepare client: + +```go +req.C().UserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:95.0) Gecko/20100101 Firefox/95.0") ``` diff --git a/client.go b/client.go index a3577155..3607fcde 100644 --- a/client.go +++ b/client.go @@ -44,13 +44,11 @@ func (c *Client) R() *Request { } } - -func (c *Client) ResponseOptions(opts ResponseOptions) *Client { - c.t.ResponseOptions = opts - return c +func (c *Client) DebugMode() *Client { + return c.AutoDecodeTextContent().EnableDump(DumpAll()).UserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36") } -func (c *Client) ResponseOption(opts ...ResponseOption) *Client { +func (c *Client) ResponseOptions(opts ...ResponseOption) *Client { for _, opt := range opts { opt(&c.t.ResponseOptions) } @@ -67,21 +65,20 @@ func (c *Client) NewRequest() *Request { return c.R() } -func (c *Client) DumpOptions(opt *DumpOptions) *Client { - c.dumpOptions = opt - return c -} - func (c *Client) DisableDump() *Client { c.t.DisableDump() return c } +func (c *Client) AutoDecodeTextContent() *Client { + return c.ResponseOptions(AutoDecodeTextContent()) +} + func (c *Client) UserAgent(userAgent string) *Client { - return c.Header("User-Agent", userAgent) + return c.CommonHeader("User-Agent", userAgent) } -func (c *Client) Header(key, value string) *Client { +func (c *Client) CommonHeader(key, value string) *Client { if c.commonHeader == nil { c.commonHeader = make(map[string]string) } diff --git a/request.go b/request.go index 09d32ac1..d231d82f 100644 --- a/request.go +++ b/request.go @@ -216,7 +216,7 @@ func (r *Request) execute() (resp *Response, err error) { request: r, Response: httpResponse, } - if r.client.t.Discard { + if r.client.t.AutoDiscard { io.Copy(ioutil.Discard, httpResponse.Body) } return diff --git a/response.go b/response.go index b36718a1..76bbc44b 100644 --- a/response.go +++ b/response.go @@ -15,14 +15,16 @@ type ResponseOptions struct { // Only valid when DisableAutoDecode is true. AutoDecodeContentType func(contentType string) bool - Discard bool + // AutoDiscard, if true, read all response body and discard automatically, + // useful when test + AutoDiscard bool } type ResponseOption func(o *ResponseOptions) -func DiscardBody() ResponseOption { +func DiscardResponseBody() ResponseOption { return func(o *ResponseOptions) { - o.Discard = true + o.AutoDiscard = true } } @@ -33,12 +35,8 @@ func DisableAutoDecode() ResponseOption { } } -// AutoDecodeContentTypeFunc customize the function to determine whether response -// body should auto decode with specified content type. -func AutoDecodeContentTypeFunc(fn func(contentType string) bool) ResponseOption { - return func(o *ResponseOptions) { - o.AutoDecodeContentType = fn - } +func AutoDecodeTextContent() ResponseOption { + return AutoDecodeContentType("text", "json", "xml", "html", "java") } // AutoDecodeContentType specifies that the response body should been auto-decoded From 796d16d978b498a39912f7bc32d89de5c77edc62 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 14:52:40 +0800 Subject: [PATCH 037/843] add some sugar method 1. resp.Discard() 2. client.AutoDiscardResponseBody() 3. client.TestMode() --- client.go | 10 ++++++++++ request.go | 2 +- response.go | 7 +++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 3607fcde..c70da145 100644 --- a/client.go +++ b/client.go @@ -44,6 +44,16 @@ func (c *Client) R() *Request { } } +func (c *Client) AutoDiscardResponseBody() *Client { + return c.ResponseOptions(DiscardResponseBody()) +} + +// TestMode is like DebugMode, but discard response body, so you can +// dump responses without read response body +func (c *Client) TestMode() *Client { + return c.DebugMode().AutoDiscardResponseBody() +} + func (c *Client) DebugMode() *Client { return c.AutoDecodeTextContent().EnableDump(DumpAll()).UserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36") } diff --git a/request.go b/request.go index d231d82f..5e2a2936 100644 --- a/request.go +++ b/request.go @@ -217,7 +217,7 @@ func (r *Request) execute() (resp *Response, err error) { Response: httpResponse, } if r.client.t.AutoDiscard { - io.Copy(ioutil.Discard, httpResponse.Body) + err = resp.Discard() } return } diff --git a/response.go b/response.go index 76bbc44b..202f0196 100644 --- a/response.go +++ b/response.go @@ -1,6 +1,8 @@ package req import ( + "io" + "io/ioutil" "net/http" "strings" ) @@ -62,3 +64,8 @@ type Response struct { func (r *Response) Body() Body { return Body{r.Response.Body, r.Response} } + +func (r *Response) Discard() error { + _, err := io.Copy(ioutil.Discard, r.Response.Body) + return err +} From e8647502632b8df2ea57d62c338f917ffdbc3339 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 15:55:51 +0800 Subject: [PATCH 038/843] add comments --- client.go | 12 +++++++++++- decode.go | 7 ++----- dump.go | 20 +++++++++++++------- transport.go | 5 ++--- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index c70da145..e29c3655 100644 --- a/client.go +++ b/client.go @@ -54,6 +54,9 @@ func (c *Client) TestMode() *Client { return c.DebugMode().AutoDiscardResponseBody() } +// DebugMode enables dump for requests and responses, and set user +// agent to pretend to be a web browser, Avoid returning abnormal +// data from some sites. func (c *Client) DebugMode() *Client { return c.AutoDecodeTextContent().EnableDump(DumpAll()).UserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36") } @@ -96,12 +99,19 @@ func (c *Client) CommonHeader(key, value string) *Client { return c } +// EnableDump enables dump requests and responses, allowing you +// to clearly see the content of all requests and responses,which +// is very convenient for debugging APIs. +// EnableDump accepet options for custom the dump behavior: +// 1. DumpAsync: dump asynchronously, can be used for debugging in +// production environment without affecting performance. +// 2. DumpHead, DumpBody, func (c *Client) EnableDump(opts ...DumpOption) *Client { if len(opts) > 0 { if c.dumpOptions == nil { c.dumpOptions = &DumpOptions{} } - c.dumpOptions.Set(opts...) + c.dumpOptions.set(opts...) } c.t.EnableDump(c.dumpOptions) return c diff --git a/decode.go b/decode.go index aeef41b2..f2281194 100644 --- a/decode.go +++ b/decode.go @@ -1,7 +1,6 @@ package req import ( - "fmt" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/charmap" "io" @@ -39,13 +38,11 @@ func (a *autoDecodeReadCloser) peekRead(p []byte) (n int, err error) { return } a.detected = true - enc, name, _ := htmlcharset.DetermineEncoding(p[:n], "") + enc, _, _ := htmlcharset.DetermineEncoding(p[:n], "") + // TODO: log chartset name if enc == charmap.Windows1252 { return } - if name != "" { - fmt.Println("content charset detected:", name) - } if enc == nil { return } diff --git a/dump.go b/dump.go index 20b41b69..d3e65b95 100644 --- a/dump.go +++ b/dump.go @@ -1,11 +1,11 @@ package req import ( - "fmt" "io" "os" ) +// DumpOptions controls the dump behavior. type DumpOptions struct { Output io.Writer RequestHead bool @@ -15,20 +15,25 @@ type DumpOptions struct { Async bool } -func (do *DumpOptions) Set(opts ...DumpOption) { +func (do *DumpOptions) set(opts ...DumpOption) { for _, opt := range opts { opt(do) } } +// DumpOption configures the underlying DumpOptions type DumpOption func(*DumpOptions) +// DumpAsync indicates that the dump should be done asynchronously, +// can be used for debugging in production environment without +// affecting performance. func DumpAsync() DumpOption { return func(o *DumpOptions) { o.Async = true } } +// DumpHead indicates that should dump the head of requests and responses. func DumpHead() DumpOption { return func(o *DumpOptions) { o.RequestHead = true @@ -36,6 +41,7 @@ func DumpHead() DumpOption { } } +// DumpBody indicates that should dump the body of requests and responses. func DumpBody() DumpOption { return func(o *DumpOptions) { o.RequestBody = true @@ -43,6 +49,7 @@ func DumpBody() DumpOption { } } +// DumpRequest indicates that should dump the requests' head and response. func DumpRequest() DumpOption { return func(o *DumpOptions) { o.RequestHead = true @@ -50,6 +57,7 @@ func DumpRequest() DumpOption { } } +// DumpResponse indicates that should dump the responses' head and response. func DumpResponse() DumpOption { return func(o *DumpOptions) { o.ResponseHead = true @@ -57,6 +65,7 @@ func DumpResponse() DumpOption { } } +// DumpAll indicates that should dump both requests and responses' head and body. func DumpAll() DumpOption { return func(o *DumpOptions) { o.RequestHead = true @@ -66,12 +75,14 @@ func DumpAll() DumpOption { } } +// DumpTo indicates that the content should dump to the specified destination. func DumpTo(output io.Writer) DumpOption { return func(o *DumpOptions) { o.Output = output } } +// DumpToFile indicates that the content should dump to the specified filename. func DumpToFile(filename string) DumpOption { return func(o *DumpOptions) { file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0600) @@ -156,10 +167,6 @@ type dumper struct { ch chan []byte } -func DefaultDumpOptions() *DumpOptions { - return defaultDumpOptions -} - var defaultDumpOptions = &DumpOptions{ Output: os.Stdout, RequestBody: true, @@ -202,7 +209,6 @@ func (d *dumper) Stop() { func (d *dumper) Start() { for b := range d.ch { if b == nil { - fmt.Println("stop dump") return } d.Output.Write(b) diff --git a/transport.go b/transport.go index 11dd57e4..b5b33e9c 100644 --- a/transport.go +++ b/transport.go @@ -304,16 +304,15 @@ func (t *Transport) autoDecodeResponseBody(res *http.Response) { panic(err) } if charset, ok := params["charset"]; ok { - fmt.Println("chartset", charset, "detected") + // TODO: log charset if strings.Contains(charset, "utf-8") || strings.Contains(charset, "utf8") { // do not decode utf-8 - fmt.Println("decode not needed") return } enc, _ := htmlcharset.Lookup(charset) if enc == nil { enc, err = ianaindex.MIME.Encoding(charset) if err != nil { - fmt.Println("chartset", charset, "not supported:", err.Error(), "; cancel decode") + // TODO: log charset not supported return } } From 193a3012fb2c3eb4e9f58f8a7c53301cf5e936ce Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 16:19:09 +0800 Subject: [PATCH 039/843] optimize auto decode text content settings --- response.go | 26 +++++++++++++++++--------- transport.go | 2 +- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/response.go b/response.go index 202f0196..77bb5968 100644 --- a/response.go +++ b/response.go @@ -37,22 +37,30 @@ func DisableAutoDecode() ResponseOption { } } +var textContentTypes = []string{"text", "json", "xml", "html", "java"} + func AutoDecodeTextContent() ResponseOption { - return AutoDecodeContentType("text", "json", "xml", "html", "java") + return AutoDecodeContentType(textContentTypes...) +} + +var autoDecodeText = autoDecodeContentTypeFunc(textContentTypes...) + +func autoDecodeContentTypeFunc(contentTypes ...string) func(contentType string) bool { + return func(contentType string) bool { + for _, t := range contentTypes { + if strings.Contains(contentType, t) { + return true + } + } + return false + } } // AutoDecodeContentType specifies that the response body should been auto-decoded // when content type contains keywords that here given. func AutoDecodeContentType(contentTypes ...string) ResponseOption { return func(o *ResponseOptions) { - o.AutoDecodeContentType = func(contentType string) bool { - for _, t := range contentTypes { - if strings.Contains(contentType, t) { - return true - } - } - return false - } + o.AutoDecodeContentType = autoDecodeContentTypeFunc(contentTypes...) } } diff --git a/transport.go b/transport.go index b5b33e9c..eb39e864 100644 --- a/transport.go +++ b/transport.go @@ -294,7 +294,7 @@ func (t *Transport) autoDecodeResponseBody(res *http.Response) { if t.ResponseOptions.AutoDecodeContentType != nil { shouldDecode = t.ResponseOptions.AutoDecodeContentType } else { - shouldDecode = responseBodyIsText + shouldDecode = autoDecodeText } if !shouldDecode(contentType) { return From 780f21922b6326cd7ca2c81502e2559a1daa67fc Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 16:42:27 +0800 Subject: [PATCH 040/843] client support Clone() --- client.go | 35 +++++++++++++++++++++++++++++------ dump.go | 14 ++++++++++++++ 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index e29c3655..f88fca17 100644 --- a/client.go +++ b/client.go @@ -3,7 +3,6 @@ package req import ( "encoding/json" "golang.org/x/net/publicsuffix" - "net" "net/http" "net/http/cookiejar" "time" @@ -26,11 +25,21 @@ type Client struct { t2 *http2Transport dumpOptions *DumpOptions httpClient *http.Client - dialer *net.Dialer jsonDecoder *json.Decoder commonHeader map[string]string } +func copyCommonHeader(h map[string]string) map[string]string { + if h == nil { + return nil + } + m := make(map[string]string) + for k, v := range h { + m[k] = v + } + return m +} + func (c *Client) R() *Request { req := &http.Request{ Header: make(http.Header), @@ -102,10 +111,9 @@ func (c *Client) CommonHeader(key, value string) *Client { // EnableDump enables dump requests and responses, allowing you // to clearly see the content of all requests and responses,which // is very convenient for debugging APIs. -// EnableDump accepet options for custom the dump behavior: -// 1. DumpAsync: dump asynchronously, can be used for debugging in -// production environment without affecting performance. -// 2. DumpHead, DumpBody, +// EnableDump accepet options for custom the dump behavior, such +// as DumpAsync, DumpHead, DumpBody, DumpRequest, DumpResponse, +// DumpAll, DumpTo, DumpToFile func (c *Client) EnableDump(opts ...DumpOption) *Client { if len(opts) > 0 { if c.dumpOptions == nil { @@ -122,6 +130,21 @@ func NewClient() *Client { return C() } +func (c *Client) Clone() *Client { + t := c.t.Clone() + t2, _ := http2ConfigureTransports(t) + cc := *c.httpClient + cc.Transport = t + return &Client{ + httpClient: &cc, + t: t, + t2: t2, + dumpOptions: c.dumpOptions.Clone(), + jsonDecoder: c.jsonDecoder, + commonHeader: copyCommonHeader(c.commonHeader), + } +} + func C() *Client { t := &Transport{ ForceAttemptHTTP2: true, diff --git a/dump.go b/dump.go index d3e65b95..38eaa693 100644 --- a/dump.go +++ b/dump.go @@ -15,6 +15,20 @@ type DumpOptions struct { Async bool } +func (do *DumpOptions) Clone() *DumpOptions { + if do == nil { + return nil + } + return &DumpOptions{ + Output: do.Output, + RequestHead: do.RequestHead, + RequestBody: do.RequestBody, + ResponseHead: do.ResponseHead, + ResponseBody: do.ResponseBody, + Async: do.Async, + } +} + func (do *DumpOptions) set(opts ...DumpOption) { for _, opt := range opts { opt(do) From 7f831581cf21c76bb4c661f89b417f04a3ac12b7 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 16:51:26 +0800 Subject: [PATCH 041/843] fix transport Clone() --- dump.go | 10 ++++++++++ transport.go | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/dump.go b/dump.go index 38eaa693..548ba5aa 100644 --- a/dump.go +++ b/dump.go @@ -203,6 +203,16 @@ func newDumper(opt *DumpOptions) *dumper { return d } +func (d *dumper) Clone() *dumper { + if d == nil { + return nil + } + return &dumper{ + DumpOptions: d.DumpOptions.Clone(), + ch: make(chan []byte, 20), + } +} + func (d *dumper) dump(p []byte) { if len(p) == 0 { return diff --git a/transport.go b/transport.go index eb39e864..441cbb2a 100644 --- a/transport.go +++ b/transport.go @@ -350,6 +350,7 @@ func (t *Transport) readBufferSize() int { // Clone returns a deep copy of t's exported fields. func (t *Transport) Clone() *Transport { t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + t2 := &Transport{ Proxy: t.Proxy, DialContext: t.DialContext, @@ -371,6 +372,11 @@ func (t *Transport) Clone() *Transport { ForceAttemptHTTP2: t.ForceAttemptHTTP2, WriteBufferSize: t.WriteBufferSize, ReadBufferSize: t.ReadBufferSize, + ResponseOptions: t.ResponseOptions, + dump: t.dump.Clone(), + } + if t.dump != nil { + t.dump.Start() } if t.TLSClientConfig != nil { t2.TLSClientConfig = t.TLSClientConfig.Clone() From db967a38683bb335d386111c61358b4bc5ae7038 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 17:36:47 +0800 Subject: [PATCH 042/843] update README --- README.md | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4cf118b7..360aa231 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ A golang http request library for humans. go get github.com/imroc/req/v2@v2.0.0-alpha.0 ``` -## Usage +## Quick Start Import req in your code: @@ -30,5 +30,73 @@ import "github.com/imroc/req/v2" Prepare client: ```go -req.C().UserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:95.0) Gecko/20100101 Firefox/95.0") +client := req.C().UserAgent("req v2") ``` + +Use client to create and send request: + +```go +resp, err : client.R().Header("test", "req").Get(url) +``` + +You can also use the default client when test: + +```go +// customize default client settings +req.DefaultClient().UserAgent("req v2") + +// create and send request using default client +resp, err := req.New().Header("test", "req").Get(url) +``` + +You can also simply do it with one line of code like this: + +```go +resp, err := req.DefaultClient().UserAgent("req v2").R().Header("test", "req").Get(url) +``` + +Want to debug requests? Just enable dump: +```go +client := req.C().EnableDump() // enable dump +client.R().Get("https://api.github.com/users/imroc").MustString() // send request and read response body +``` + +Now you can see the request and response content that has been dumped: + +```txt +:authority: api.github.com +:method: GET +:path: /users/imroc +:scheme: https +user-agent: req v2 +accept-encoding: gzip + +:status: 200 +server: GitHub.com +date: Fri, 21 Jan 2022 09:31:43 GMT +content-type: application/json; charset=utf-8 +cache-control: public, max-age=60, s-maxage=60 +vary: Accept, Accept-Encoding, Accept, X-Requested-With +etag: W/"fe5acddc5c01a01153ebc4068a1f067dadfa7a7dc9a025f44b37b0a0a50e2c55" +last-modified: Thu, 08 Jul 2021 12:11:23 GMT +x-github-media-type: github.v3; format=json +access-control-expose-headers: ETag, Link, Location, Retry-After, X-GitHub-OTP, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Used, X-RateLimit-Resource, X-RateLimit-Reset, X-OAuth-Scopes, X-Accepted-OAuth-Scopes, X-Poll-Interval, X-GitHub-Media-Type, X-GitHub-SSO, X-GitHub-Request-Id, Deprecation, Sunset +access-control-allow-origin: * +strict-transport-security: max-age=31536000; includeSubdomains; preload +x-frame-options: deny +x-content-type-options: nosniff +x-xss-protection: 0 +referrer-policy: origin-when-cross-origin, strict-origin-when-cross-origin +content-security-policy: default-src 'none' +content-encoding: gzip +x-ratelimit-limit: 60 +x-ratelimit-remaining: 59 +x-ratelimit-reset: 1642761103 +x-ratelimit-resource: core +x-ratelimit-used: 1 +accept-ranges: bytes +content-length: 486 +x-github-request-id: AF10:6205:BA107D:D614F2:61EA7D7E + +{"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":128,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2021-07-08T12:11:23Z"} +``` \ No newline at end of file From 37995d32cb966b409ef4ede08ad9527b437ed2a5 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 17:39:01 +0800 Subject: [PATCH 043/843] change default user agent to req/v2 --- h2_bundle.go | 2 +- http.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index e14f7fa0..37d1af0e 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -6572,7 +6572,7 @@ const ( // a stream-level WINDOW_UPDATE for at a time. http2transportDefaultStreamMinRefresh = 4 << 10 - http2defaultUserAgent = "Go-http-client/2.0" + http2defaultUserAgent = "req/v2" // initialMaxConcurrentStreams is a connections maxConcurrentStreams until // it's received servers initial SETTINGS frame, which corresponds with the diff --git a/http.go b/http.go index 54fe1369..2fb58319 100644 --- a/http.go +++ b/http.go @@ -153,7 +153,7 @@ const maxPostHandlerReadBytes = 256 << 10 // It was changed at the time of Go 1.1 release because the former User-Agent // had ended up blocked by some intrusion detection systems. // See https://codereview.appspot.com/7532043. -const defaultUserAgent = "Go-http-client/1.1" +const defaultUserAgent = "req/v2" func idnaASCII(v string) (string, error) { // TODO: Consider removing this check after verifying performance is okay. From 4a5994b8c9b603fcf81c718c531625c4e8828637 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 17:45:20 +0800 Subject: [PATCH 044/843] remove Body, merge into Response --- body.go | 124 ----------------------------------------------- response.go | 4 -- response_body.go | 118 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 128 deletions(-) delete mode 100644 body.go create mode 100644 response_body.go diff --git a/body.go b/body.go deleted file mode 100644 index c4644261..00000000 --- a/body.go +++ /dev/null @@ -1,124 +0,0 @@ -package req - -import ( - "encoding/json" - "encoding/xml" - "io" - "io/ioutil" - "net/http" - "os" - "strings" -) - -type Body struct { - io.ReadCloser - resp *http.Response -} - -func (body Body) MustSave(dst io.Writer) { - err := body.Save(dst) - if err != nil { - panic(err) - } -} - -func (body Body) Save(dst io.Writer) error { - if dst == nil { - return nil // TODO: return error - } - _, err := io.Copy(dst, body.ReadCloser) - body.Close() - return err -} - -func (body Body) MustSaveFile(filename string) { - err := body.SaveFile(filename) - if err != nil { - panic(err) - } -} - -func (body Body) SaveFile(filename string) error { - if filename == "" { - return nil // TODO: return error - } - file, err := os.Create(filename) - if err != nil { - return err - } - _, err = io.Copy(file, body.ReadCloser) - body.Close() - return err -} - -func (body Body) MustUnmarshalJson(v interface{}) { - err := body.UnmarshalJson(v) - if err != nil { - panic(err) - } -} - -func (body Body) UnmarshalJson(v interface{}) error { - b, err := body.Bytes() - if err != nil { - return err - } - return json.Unmarshal(b, v) -} - -func (body Body) MustUnmarshalXml(v interface{}) { - err := body.UnmarshalXml(v) - if err != nil { - panic(err) - } -} - -func (body Body) UnmarshalXml(v interface{}) error { - b, err := body.Bytes() - if err != nil { - return err - } - return xml.Unmarshal(b, v) -} -func (body Body) MustUnmarshal(v interface{}) { - err := body.Unmarshal(v) - if err != nil { - panic(err) - } -} - -func (body Body) Unmarshal(v interface{}) error { - contentType := body.resp.Header.Get("Content-Type") - if strings.Contains(contentType, "json") { - return body.UnmarshalJson(v) - } else if strings.Contains(contentType, "xml") { - return body.UnmarshalXml(v) - } - return body.UnmarshalJson(v) -} - -func (body Body) MustString() string { - b, err := body.Bytes() - if err != nil { - panic(err) - } - return string(b) -} - -func (body Body) String() (string, error) { - b, err := body.Bytes() - return string(b), err -} - -func (body Body) Bytes() ([]byte, error) { - defer body.Close() - return ioutil.ReadAll(body.ReadCloser) -} - -func (body Body) MustBytes() []byte { - b, err := body.Bytes() - if err != nil { - panic(err) - } - return b -} diff --git a/response.go b/response.go index 77bb5968..0f3c65d2 100644 --- a/response.go +++ b/response.go @@ -69,10 +69,6 @@ type Response struct { request *Request } -func (r *Response) Body() Body { - return Body{r.Response.Body, r.Response} -} - func (r *Response) Discard() error { _, err := io.Copy(ioutil.Discard, r.Response.Body) return err diff --git a/response_body.go b/response_body.go new file mode 100644 index 00000000..56e1430f --- /dev/null +++ b/response_body.go @@ -0,0 +1,118 @@ +package req + +import ( + "encoding/json" + "encoding/xml" + "io" + "io/ioutil" + "os" + "strings" +) + +func (r *Response) MustSave(dst io.Writer) { + err := r.Save(dst) + if err != nil { + panic(err) + } +} + +func (r *Response) Save(dst io.Writer) error { + if dst == nil { + return nil // TODO: return error + } + _, err := io.Copy(dst, r.Body) + r.Body.Close() + return err +} + +func (r *Response) MustSaveFile(filename string) { + err := r.SaveFile(filename) + if err != nil { + panic(err) + } +} + +func (r *Response) SaveFile(filename string) error { + if filename == "" { + return nil // TODO: return error + } + file, err := os.Create(filename) + if err != nil { + return err + } + _, err = io.Copy(file, r.Body) + r.Body.Close() + return err +} + +func (r *Response) MustUnmarshalJson(v interface{}) { + err := r.UnmarshalJson(v) + if err != nil { + panic(err) + } +} + +func (r *Response) UnmarshalJson(v interface{}) error { + b, err := r.Bytes() + if err != nil { + return err + } + return json.Unmarshal(b, v) +} + +func (r *Response) MustUnmarshalXml(v interface{}) { + err := r.UnmarshalXml(v) + if err != nil { + panic(err) + } +} + +func (r *Response) UnmarshalXml(v interface{}) error { + b, err := r.Bytes() + if err != nil { + return err + } + return xml.Unmarshal(b, v) +} +func (r *Response) MustUnmarshal(v interface{}) { + err := r.Unmarshal(v) + if err != nil { + panic(err) + } +} + +func (r *Response) Unmarshal(v interface{}) error { + contentType := r.Header.Get("Content-Type") + if strings.Contains(contentType, "json") { + return r.UnmarshalJson(v) + } else if strings.Contains(contentType, "xml") { + return r.UnmarshalXml(v) + } + return r.UnmarshalJson(v) +} + +func (r *Response) MustString() string { + b, err := r.Bytes() + if err != nil { + panic(err) + } + return string(b) +} + +func (r *Response) String() (string, error) { + b, err := r.Bytes() + return string(b), err +} + +func (r *Response) Bytes() ([]byte, error) { + defer r.Body.Close() + return ioutil.ReadAll(r.Body) +} + +func (r *Response) MustBytes() []byte { + b, err := r.Bytes() + if err != nil { + panic(err) + } + return b +} From 7a739cc8a87a8247275bdfa6de92add24ad2b880 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 17:49:03 +0800 Subject: [PATCH 045/843] update README --- README.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 360aa231..763ff227 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,10 @@ A golang http request library for humans. * Simple and chainable methods for client and request settings. * Rich syntax sugar, greatly improving development efficiency. -* Automatically detect charset and decode to utf-8. -* Powerful debugging capabilities (logging, tracing, and event dump the requests and responses content). -* The settings can be dynamically adjusted, making it possible to debug in the production environment. -* Easy to integrate with existing code, just replace client's Transport you can dump requests and reponses to debug. +* Automatically detect charset and decode it to utf-8. +* Powerful debugging capabilities (logging, tracing, and even dump the requests and responses' content). +* All settings can be dynamically adjusted, making it possible to debug in the production environment. +* Easy to integrate with existing code, just replace client's Transport then you can dump content as req to debug APIs. ## Install @@ -36,7 +36,7 @@ client := req.C().UserAgent("req v2") Use client to create and send request: ```go -resp, err : client.R().Header("test", "req").Get(url) +resp, err: client.R().Header("test", "req").Get(url) ``` You can also use the default client when test: @@ -56,6 +56,7 @@ resp, err := req.DefaultClient().UserAgent("req v2").R().Header("test", "req").G ``` Want to debug requests? Just enable dump: + ```go client := req.C().EnableDump() // enable dump client.R().Get("https://api.github.com/users/imroc").MustString() // send request and read response body From bc694e8f3853ebb1ab7253b0b2212d32b229c5a3 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 18:03:06 +0800 Subject: [PATCH 046/843] update README --- README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 763ff227..54bd45d5 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ import "github.com/imroc/req/v2" Prepare client: ```go -client := req.C().UserAgent("req v2") +client := req.C().UserAgent("req/v2") ``` Use client to create and send request: @@ -62,14 +62,14 @@ client := req.C().EnableDump() // enable dump client.R().Get("https://api.github.com/users/imroc").MustString() // send request and read response body ``` -Now you can see the request and response content that has been dumped: +Now you can see the request and response content that has been dumped : ```txt :authority: api.github.com :method: GET :path: /users/imroc :scheme: https -user-agent: req v2 +user-agent: req/v2 accept-encoding: gzip :status: 200 @@ -100,4 +100,6 @@ content-length: 486 x-github-request-id: AF10:6205:BA107D:D614F2:61EA7D7E {"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":128,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2021-07-08T12:11:23Z"} -``` \ No newline at end of file +``` + +> Here we can see it's http2 format, because req will try http2 by default if server support. \ No newline at end of file From e9d11edd97dd303d41809343ea5e9b503cc325ef Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 18:18:29 +0800 Subject: [PATCH 047/843] update README --- README.md | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 54bd45d5..39aa12b1 100644 --- a/README.md +++ b/README.md @@ -27,16 +27,17 @@ Import req in your code: import "github.com/imroc/req/v2" ``` -Prepare client: +Prepare a client: ```go -client := req.C().UserAgent("req/v2") +client := req.C().UserAgent("req/v2") // client settings is chainable ``` Use client to create and send request: ```go -resp, err: client.R().Header("test", "req").Get(url) +// use R() to create a new request, and request settings is also chainable +resp, err: client.R().Header("test", "req").Body("test").Get("https://test.example.com") ``` You can also use the default client when test: @@ -58,8 +59,10 @@ resp, err := req.DefaultClient().UserAgent("req v2").R().Header("test", "req").G Want to debug requests? Just enable dump: ```go -client := req.C().EnableDump() // enable dump -client.R().Get("https://api.github.com/users/imroc").MustString() // send request and read response body +// create client and enable dump +client := req.C().EnableDump() +// send request and read response body +client.R().Get("https://api.github.com/users/imroc").MustString() ``` Now you can see the request and response content that has been dumped : @@ -102,4 +105,4 @@ x-github-request-id: AF10:6205:BA107D:D614F2:61EA7D7E {"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":128,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2021-07-08T12:11:23Z"} ``` -> Here we can see it's http2 format, because req will try http2 by default if server support. \ No newline at end of file +> Here we can see the content is in http2 format, because req will try http2 by default if server support. \ No newline at end of file From 02115392ff924eafc1a18e8e2390a8d6c82ecf76 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 18:21:23 +0800 Subject: [PATCH 048/843] update README --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 39aa12b1..42862739 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ import "github.com/imroc/req/v2" Prepare a client: ```go -client := req.C().UserAgent("req/v2") // client settings is chainable +client := req.C().UserAgent("custom-client") // client settings is chainable ``` Use client to create and send request: @@ -44,7 +44,7 @@ You can also use the default client when test: ```go // customize default client settings -req.DefaultClient().UserAgent("req v2") +req.DefaultClient().UserAgent("custom-client") // create and send request using default client resp, err := req.New().Header("test", "req").Get(url) @@ -53,14 +53,14 @@ resp, err := req.New().Header("test", "req").Get(url) You can also simply do it with one line of code like this: ```go -resp, err := req.DefaultClient().UserAgent("req v2").R().Header("test", "req").Get(url) +resp, err := req.DefaultClient().UserAgent("custom-client").R().Header("test", "req").Get(url) ``` Want to debug requests? Just enable dump: ```go // create client and enable dump -client := req.C().EnableDump() +client := req.C().UserAgent("custom-client").EnableDump() // send request and read response body client.R().Get("https://api.github.com/users/imroc").MustString() ``` @@ -72,7 +72,7 @@ Now you can see the request and response content that has been dumped : :method: GET :path: /users/imroc :scheme: https -user-agent: req/v2 +user-agent: custom-client accept-encoding: gzip :status: 200 From 7188e92c85303d991bc933399671f649b5eb56b6 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 20:31:43 +0800 Subject: [PATCH 049/843] update README --- README.md | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 42862739..8e81342e 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ client := req.C().UserAgent("custom-client").EnableDump() client.R().Get("https://api.github.com/users/imroc").MustString() ``` -Now you can see the request and response content that has been dumped : +Now you can see the request and response content that has been dumped: ```txt :authority: api.github.com @@ -105,4 +105,22 @@ x-github-request-id: AF10:6205:BA107D:D614F2:61EA7D7E {"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":128,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2021-07-08T12:11:23Z"} ``` -> Here we can see the content is in http2 format, because req will try http2 by default if server support. \ No newline at end of file +> Here we can see the content is in http2 format, because req will try http2 by default if server support. + +## Debug + +Simple example: + +```go +// dump head content asynchronously and save it to file +client := req.C().EnableDump().DumpHead().DumpAsync().DumpToFile("reqdump.log") +resp, err := client.R().Body(body).Post(url) +... +``` + +All dump options: +* `DumpAsync()` indicates that the dump should be done asynchronously, + +## License + +Req released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file From 13811c9011259e4ed1df1dc1da2fcc05523c5ef7 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 20:32:05 +0800 Subject: [PATCH 050/843] clone defaultDumpOptions --- client.go | 2 ++ dump.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index f88fca17..00737312 100644 --- a/client.go +++ b/client.go @@ -120,6 +120,8 @@ func (c *Client) EnableDump(opts ...DumpOption) *Client { c.dumpOptions = &DumpOptions{} } c.dumpOptions.set(opts...) + } else if c.dumpOptions == nil { + c.dumpOptions = defaultDumpOptions.Clone() } c.t.EnableDump(c.dumpOptions) return c diff --git a/dump.go b/dump.go index 548ba5aa..15418bab 100644 --- a/dump.go +++ b/dump.go @@ -191,7 +191,7 @@ var defaultDumpOptions = &DumpOptions{ func newDumper(opt *DumpOptions) *dumper { if opt == nil { - opt = defaultDumpOptions + opt = defaultDumpOptions.Clone() } if opt.Output == nil { opt.Output = os.Stdout From f9f3b70119b51d00b5e9a402803247e32ce801d4 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 21:29:51 +0800 Subject: [PATCH 051/843] add logging capability --- client.go | 21 ++++++++++++++++++++- logger.go | 39 +++++++++++++++++++++++++++++++++++++++ request.go | 1 + 3 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 logger.go diff --git a/client.go b/client.go index 00737312..5afb2124 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "golang.org/x/net/publicsuffix" "net/http" "net/http/cookiejar" + "os" "time" ) @@ -21,6 +22,7 @@ func SetDefaultClient(c *Client) { var defaultClient *Client = C() type Client struct { + log Logger t *Transport t2 *http2Transport dumpOptions *DumpOptions @@ -63,11 +65,27 @@ func (c *Client) TestMode() *Client { return c.DebugMode().AutoDiscardResponseBody() } +const ( + userAgentFirefox = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:95.0) Gecko/20100101 Firefox/95.0" + userAgentChrome = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36" +) + // DebugMode enables dump for requests and responses, and set user // agent to pretend to be a web browser, Avoid returning abnormal // data from some sites. func (c *Client) DebugMode() *Client { - return c.AutoDecodeTextContent().EnableDump(DumpAll()).UserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36") + return c.AutoDecodeTextContent(). + EnableDump(DumpAll()). + Logger(NewLogger(os.Stdout)). + UserAgent(userAgentChrome) +} + +func (c *Client) Logger(log Logger) *Client { + if log == nil { + return c + } + c.log = log + return c } func (c *Client) ResponseOptions(opts ...ResponseOption) *Client { @@ -164,6 +182,7 @@ func C() *Client { Timeout: 2 * time.Minute, } c := &Client{ + log: &emptyLogger{}, httpClient: httpClient, t: t, t2: t2, diff --git a/logger.go b/logger.go new file mode 100644 index 00000000..ce9dfbd5 --- /dev/null +++ b/logger.go @@ -0,0 +1,39 @@ +package req + +import ( + "fmt" + "io" + "os" +) + +type Logger interface { + Println(v ...interface{}) +} + +type logger struct { + w io.Writer +} + +func (l *logger) Println(v ...interface{}) { + fmt.Fprintln(l.w, v...) +} + +func NewLogger(output io.Writer) Logger { + if output == nil { + output = os.Stdout + } + return &logger{output} +} + +type emptyLogger struct{} + +func (l *emptyLogger) Println(v ...interface{}) {} + +func logp(logger Logger, s string) { + logger.Println("[req]", s) +} + +func logf(logger Logger, format string, v ...interface{}) { + s := fmt.Sprintf(format, v...) + logp(logger, s) +} diff --git a/request.go b/request.go index 5e2a2936..08f4814f 100644 --- a/request.go +++ b/request.go @@ -208,6 +208,7 @@ func (r *Request) execute() (resp *Response, err error) { r.httpRequest.Header.Set(k, v) } } + logf(r.client.log, "%s %s", r.httpRequest.Method, r.httpRequest.URL.String()) httpResponse, err := r.client.httpClient.Do(r.httpRequest) if err != nil { return From 8371d3c86f7d9bb7be1135c77a14dc097dc67697 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 21 Jan 2022 21:30:38 +0800 Subject: [PATCH 052/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8e81342e..3c598819 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ A golang http request library for humans. ## Install ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.0 +go get github.com/imroc/req/v2@v2.0.0-alpha.1 ``` ## Quick Start From 6230f51c833ad00f3ae500f6df467b65dcf38b2c Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 08:59:34 +0800 Subject: [PATCH 053/843] refactor dump options --- README.md | 7 +-- client.go | 154 +++++++++++++++++++++++++++++++++++++++++++++------ dump.go | 98 ++++---------------------------- h2_bundle.go | 2 +- 4 files changed, 151 insertions(+), 110 deletions(-) diff --git a/README.md b/README.md index 3c598819..29e81acd 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ Want to debug requests? Just enable dump: ```go // create client and enable dump -client := req.C().UserAgent("custom-client").EnableDump() +client := req.C().UserAgent("custom-client").DumpAll() // send request and read response body client.R().Get("https://api.github.com/users/imroc").MustString() ``` @@ -113,14 +113,11 @@ Simple example: ```go // dump head content asynchronously and save it to file -client := req.C().EnableDump().DumpHead().DumpAsync().DumpToFile("reqdump.log") +client := req.C().DumpOnlyHead().DumpAsync().DumpToFile("reqdump.log") resp, err := client.R().Body(body).Post(url) ... ``` -All dump options: -* `DumpAsync()` indicates that the dump should be done asynchronously, - ## License Req released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file diff --git a/client.go b/client.go index 5afb2124..cab8ee4a 100644 --- a/client.go +++ b/client.go @@ -3,6 +3,7 @@ package req import ( "encoding/json" "golang.org/x/net/publicsuffix" + "io" "net/http" "net/http/cookiejar" "os" @@ -75,7 +76,7 @@ const ( // data from some sites. func (c *Client) DebugMode() *Client { return c.AutoDecodeTextContent(). - EnableDump(DumpAll()). + Dump(true). Logger(NewLogger(os.Stdout)). UserAgent(userAgentChrome) } @@ -100,16 +101,108 @@ func (c *Client) Timeout(d time.Duration) *Client { return c } -// NewRequest is the alias of R() -func (c *Client) NewRequest() *Request { - return c.R() +func (c *Client) getDumpOptions() *DumpOptions { + if c.dumpOptions == nil { + c.dumpOptions = newDefaultDumpOptions() + } + return c.dumpOptions +} + +func (c *Client) enableDump() { + if c.t.dump != nil { // dump already started + return + } + c.t.EnableDump(c.getDumpOptions()) +} + +// DumpToFile indicates that the content should dump to the specified filename. +func (c *Client) DumpToFile(filename string) *Client { + file, err := os.Create(filename) + if err != nil { + logf(c.log, "create dump file error: %v", err) + return c + } + c.getDumpOptions().Output = file + return c +} + +// DumpTo indicates that the content should dump to the specified destination. +func (c *Client) DumpTo(output io.Writer) *Client { + c.getDumpOptions().Output = output + c.enableDump() + return c +} + +// DumpAsync indicates that the dump should be done asynchronously, +// can be used for debugging in production environment without +// affecting performance. +func (c *Client) DumpAsync() *Client { + o := c.getDumpOptions() + o.Async = true + c.enableDump() + return c +} + +// DumpOnlyResponse indicates that should dump the responses' head and response. +func (c *Client) DumpOnlyResponse() *Client { + o := c.getDumpOptions() + o.ResponseHead = true + o.ResponseBody = true + o.RequestBody = false + o.RequestHead = false + c.enableDump() + return c +} + +// DumpOnlyRequest indicates that should dump the requests' head and response. +func (c *Client) DumpOnlyRequest() *Client { + o := c.getDumpOptions() + o.RequestHead = true + o.RequestBody = true + o.ResponseBody = false + o.ResponseHead = false + c.enableDump() + return c +} + +// DumpOnlyBody indicates that should dump the body of requests and responses. +func (c *Client) DumpOnlyBody() *Client { + o := c.getDumpOptions() + o.RequestBody = true + o.ResponseBody = true + o.RequestHead = false + o.ResponseHead = false + c.enableDump() + return c } -func (c *Client) DisableDump() *Client { - c.t.DisableDump() +// DumpOnlyHead indicates that should dump the head of requests and responses. +func (c *Client) DumpOnlyHead() *Client { + o := c.getDumpOptions() + o.RequestHead = true + o.ResponseHead = true + o.RequestBody = false + o.ResponseBody = false + c.enableDump() return c } +// DumpAll indicates that should dump both requests and responses' head and body. +func (c *Client) DumpAll() *Client { + o := c.getDumpOptions() + o.RequestHead = true + o.RequestBody = true + o.ResponseHead = true + o.ResponseBody = true + c.enableDump() + return c +} + +// NewRequest is the alias of R() +func (c *Client) NewRequest() *Request { + return c.R() +} + func (c *Client) AutoDecodeTextContent() *Client { return c.ResponseOptions(AutoDecodeTextContent()) } @@ -126,24 +219,49 @@ func (c *Client) CommonHeader(key, value string) *Client { return c } +// Dump if true, enables dump requests and responses, allowing you +// to clearly see the content of all requests and responses,which +// is very convenient for debugging APIs. +// Dump if false, disable the dump behaviour. +func (c *Client) Dump(enable bool) *Client { + if !enable { + c.t.DisableDump() + return c + } + c.enableDump() + return c +} + +// DumpOptions configures the underlying Transport's DumpOptions +func (c *Client) DumpOptions(opt *DumpOptions) *Client { + if opt == nil { + return c + } + c.dumpOptions = opt + if c.t.dump != nil { + c.t.dump.DumpOptions = opt + } + return c +} + // EnableDump enables dump requests and responses, allowing you // to clearly see the content of all requests and responses,which // is very convenient for debugging APIs. // EnableDump accepet options for custom the dump behavior, such // as DumpAsync, DumpHead, DumpBody, DumpRequest, DumpResponse, // DumpAll, DumpTo, DumpToFile -func (c *Client) EnableDump(opts ...DumpOption) *Client { - if len(opts) > 0 { - if c.dumpOptions == nil { - c.dumpOptions = &DumpOptions{} - } - c.dumpOptions.set(opts...) - } else if c.dumpOptions == nil { - c.dumpOptions = defaultDumpOptions.Clone() - } - c.t.EnableDump(c.dumpOptions) - return c -} +//func (c *Client) EnableDump(opts ...DumpOption) *Client { +// if len(opts) > 0 { +// if c.dumpOptions == nil { +// c.dumpOptions = &DumpOptions{} +// } +// c.dumpOptions.set(opts...) +// } else if c.dumpOptions == nil { +// c.dumpOptions = defaultDumpOptions.Clone() +// } +// c.t.EnableDump(c.dumpOptions) +// return c +//} // NewClient is the alias of C() func NewClient() *Client { diff --git a/dump.go b/dump.go index 15418bab..332a5510 100644 --- a/dump.go +++ b/dump.go @@ -19,92 +19,8 @@ func (do *DumpOptions) Clone() *DumpOptions { if do == nil { return nil } - return &DumpOptions{ - Output: do.Output, - RequestHead: do.RequestHead, - RequestBody: do.RequestBody, - ResponseHead: do.ResponseHead, - ResponseBody: do.ResponseBody, - Async: do.Async, - } -} - -func (do *DumpOptions) set(opts ...DumpOption) { - for _, opt := range opts { - opt(do) - } -} - -// DumpOption configures the underlying DumpOptions -type DumpOption func(*DumpOptions) - -// DumpAsync indicates that the dump should be done asynchronously, -// can be used for debugging in production environment without -// affecting performance. -func DumpAsync() DumpOption { - return func(o *DumpOptions) { - o.Async = true - } -} - -// DumpHead indicates that should dump the head of requests and responses. -func DumpHead() DumpOption { - return func(o *DumpOptions) { - o.RequestHead = true - o.ResponseHead = true - } -} - -// DumpBody indicates that should dump the body of requests and responses. -func DumpBody() DumpOption { - return func(o *DumpOptions) { - o.RequestBody = true - o.ResponseBody = true - } -} - -// DumpRequest indicates that should dump the requests' head and response. -func DumpRequest() DumpOption { - return func(o *DumpOptions) { - o.RequestHead = true - o.RequestBody = true - } -} - -// DumpResponse indicates that should dump the responses' head and response. -func DumpResponse() DumpOption { - return func(o *DumpOptions) { - o.ResponseHead = true - o.ResponseBody = true - } -} - -// DumpAll indicates that should dump both requests and responses' head and body. -func DumpAll() DumpOption { - return func(o *DumpOptions) { - o.RequestHead = true - o.RequestBody = true - o.ResponseHead = true - o.ResponseBody = true - } -} - -// DumpTo indicates that the content should dump to the specified destination. -func DumpTo(output io.Writer) DumpOption { - return func(o *DumpOptions) { - o.Output = output - } -} - -// DumpToFile indicates that the content should dump to the specified filename. -func DumpToFile(filename string) DumpOption { - return func(o *DumpOptions) { - file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0600) - if err != nil { - panic(err) - } - o.Output = file - } + d := *do + return &d } func (d *dumper) WrapReadCloser(rc io.ReadCloser) io.ReadCloser { @@ -181,6 +97,16 @@ type dumper struct { ch chan []byte } +func newDefaultDumpOptions() *DumpOptions { + return &DumpOptions{ + Output: os.Stdout, + RequestBody: true, + ResponseBody: true, + ResponseHead: true, + RequestHead: true, + } +} + var defaultDumpOptions = &DumpOptions{ Output: os.Stdout, RequestBody: true, diff --git a/h2_bundle.go b/h2_bundle.go index 37d1af0e..871357e0 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -1865,7 +1865,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { } if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { hr, err := fr.readMetaFrame(f.(*http2HeadersFrame)) - if fr.dump != nil && err == nil && fr.dump.ResponseHead { + if err == nil && fr.dump != nil && fr.dump.ResponseHead { fr.dump.dump([]byte("\r\n")) } return hr, err From 6aef22768630910edfc8f39c918697211d6a4e76 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 09:10:07 +0800 Subject: [PATCH 054/843] add more comments --- client.go | 27 +++++++-------------------- logger.go | 4 ++++ 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index cab8ee4a..037cc614 100644 --- a/client.go +++ b/client.go @@ -81,6 +81,7 @@ func (c *Client) DebugMode() *Client { UserAgent(userAgentChrome) } +// Logger set the logger for req. func (c *Client) Logger(log Logger) *Client { if log == nil { return c @@ -96,6 +97,7 @@ func (c *Client) ResponseOptions(opts ...ResponseOption) *Client { return c } +// Timeout set the timeout for all requests. func (c *Client) Timeout(d time.Duration) *Client { c.httpClient.Timeout = d return c @@ -207,10 +209,12 @@ func (c *Client) AutoDecodeTextContent() *Client { return c.ResponseOptions(AutoDecodeTextContent()) } +// UserAgent set the "User-Agent" header for all requests. func (c *Client) UserAgent(userAgent string) *Client { return c.CommonHeader("User-Agent", userAgent) } +// CommonHeader set the common header for all requests. func (c *Client) CommonHeader(key, value string) *Client { if c.commonHeader == nil { c.commonHeader = make(map[string]string) @@ -244,30 +248,12 @@ func (c *Client) DumpOptions(opt *DumpOptions) *Client { return c } -// EnableDump enables dump requests and responses, allowing you -// to clearly see the content of all requests and responses,which -// is very convenient for debugging APIs. -// EnableDump accepet options for custom the dump behavior, such -// as DumpAsync, DumpHead, DumpBody, DumpRequest, DumpResponse, -// DumpAll, DumpTo, DumpToFile -//func (c *Client) EnableDump(opts ...DumpOption) *Client { -// if len(opts) > 0 { -// if c.dumpOptions == nil { -// c.dumpOptions = &DumpOptions{} -// } -// c.dumpOptions.set(opts...) -// } else if c.dumpOptions == nil { -// c.dumpOptions = defaultDumpOptions.Clone() -// } -// c.t.EnableDump(c.dumpOptions) -// return c -//} - -// NewClient is the alias of C() +// NewClient is the alias of C func NewClient() *Client { return C() } +// Clone copy and returns the Client func (c *Client) Clone() *Client { t := c.t.Clone() t2, _ := http2ConfigureTransports(t) @@ -283,6 +269,7 @@ func (c *Client) Clone() *Client { } } +// C create a new client. func C() *Client { t := &Transport{ ForceAttemptHTTP2: true, diff --git a/logger.go b/logger.go index ce9dfbd5..98b16647 100644 --- a/logger.go +++ b/logger.go @@ -6,6 +6,9 @@ import ( "os" ) +// Logger is the logging interface that req used internal, +// you can set the Logger for client if you want to see req's +// internal logging information. type Logger interface { Println(v ...interface{}) } @@ -18,6 +21,7 @@ func (l *logger) Println(v ...interface{}) { fmt.Fprintln(l.w, v...) } +// NewLogger create a simple Logger. func NewLogger(output io.Writer) Logger { if output == nil { output = os.Stdout From b8031c0e32f0bbbe23192a039c3f90fab9d2ed94 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 09:29:35 +0800 Subject: [PATCH 055/843] add more comments --- client.go | 4 ++++ request.go | 38 ++++++++++++++++++++++++++++++++++++-- response.go | 2 ++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 037cc614..afdbfe3a 100644 --- a/client.go +++ b/client.go @@ -10,10 +10,12 @@ import ( "time" ) +// DefaultClient returns the global default Client. func DefaultClient() *Client { return defaultClient } +// SetDefaultClient override the global default Client. func SetDefaultClient(c *Client) { if c != nil { defaultClient = c @@ -22,6 +24,7 @@ func SetDefaultClient(c *Client) { var defaultClient *Client = C() +// Client is the req's http client. type Client struct { log Logger t *Transport @@ -43,6 +46,7 @@ func copyCommonHeader(h map[string]string) map[string]string { return m } +// R create a new request. func (c *Client) R() *Request { req := &http.Request{ Header: make(http.Header), diff --git a/request.go b/request.go index 08f4814f..d0ae90a4 100644 --- a/request.go +++ b/request.go @@ -13,12 +13,14 @@ import ( "strings" ) +// Request is the http request type Request struct { error error client *Client httpRequest *http.Request } +// New create a new request using the global default client. func New() *Request { return defaultClient.R() } @@ -27,10 +29,13 @@ func (r *Request) appendError(err error) { r.error = multierror.Append(r.error, err) } +// Error return the underlying error, not nil if some error +// happend when constructing the request. func (r *Request) Error() error { return r.error } +// Method set the http request method. func (r *Request) Method(method string) *Request { if method == "" { // We document that "" means "GET" for Request.Method, and people have @@ -49,6 +54,7 @@ func (r *Request) Method(method string) *Request { return r } +// URL set the http request url. func (r *Request) URL(url string) *Request { u, err := urlpkg.Parse(url) if err != nil { @@ -66,6 +72,7 @@ func (r *Request) send(method, url string) (*Response, error) { return r.Method(method).URL(url).Send() } +// MustGet like Get, panic if error happens. func (r *Request) MustGet(url string) *Response { resp, err := r.Get(url) if err != nil { @@ -74,10 +81,12 @@ func (r *Request) MustGet(url string) *Response { return resp } +// Get send the request with GET method and specified url. func (r *Request) Get(url string) (*Response, error) { return r.send(http.MethodGet, url) } +// MustPost like Post, panic if error happens. func (r *Request) MustPost(url string) *Response { resp, err := r.Post(url) if err != nil { @@ -86,10 +95,12 @@ func (r *Request) MustPost(url string) *Response { return resp } +// Post send the request with POST method and specified url. func (r *Request) Post(url string) (*Response, error) { return r.send(http.MethodPost, url) } +// MustPut like Put, panic if error happens. func (r *Request) MustPut(url string) *Response { resp, err := r.Put(url) if err != nil { @@ -98,10 +109,12 @@ func (r *Request) MustPut(url string) *Response { return resp } +// Put send the request with Put method and specified url. func (r *Request) Put(url string) (*Response, error) { return r.send(http.MethodPut, url) } +// MustPatch like Patch, panic if error happens. func (r *Request) MustPatch(url string) *Response { resp, err := r.Patch(url) if err != nil { @@ -110,10 +123,12 @@ func (r *Request) MustPatch(url string) *Response { return resp } +// Patch send the request with PATCH method and specified url. func (r *Request) Patch(url string) (*Response, error) { return r.send(http.MethodPatch, url) } +// MustDelete like Delete, panic if error happens. func (r *Request) MustDelete(url string) *Response { resp, err := r.Delete(url) if err != nil { @@ -122,10 +137,12 @@ func (r *Request) MustDelete(url string) *Response { return resp } +// Delete send the request with DELETE method and specified url. func (r *Request) Delete(url string) (*Response, error) { return r.send(http.MethodDelete, url) } +// MustOptions like Options, panic if error happens. func (r *Request) MustOptions(url string) *Response { resp, err := r.Options(url) if err != nil { @@ -134,18 +151,26 @@ func (r *Request) MustOptions(url string) *Response { return resp } +// Options send the request with OPTIONS method and specified url. func (r *Request) Options(url string) (*Response, error) { return r.send(http.MethodOptions, url) } -func (r *Request) MustHead(url string) (*Response, error) { - return r.send(http.MethodHead, url) +// MustHead like Head, panic if error happens. +func (r *Request) MustHead(url string) *Response { + resp, err := r.send(http.MethodHead, url) + if err != nil { + panic(err) + } + return resp } +// Head send the request with HEAD method and specified url. func (r *Request) Head(url string) (*Response, error) { return r.send(http.MethodHead, url) } +// Body set the request body. func (r *Request) Body(body interface{}) *Request { if body == nil { return r @@ -163,28 +188,36 @@ func (r *Request) Body(body interface{}) *Request { return r } +// BodyBytes set the request body as []byte. func (r *Request) BodyBytes(body []byte) *Request { r.httpRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) return r } +// BodyString set the request body as string. func (r *Request) BodyString(body string) *Request { r.httpRequest.Body = ioutil.NopCloser(strings.NewReader(body)) return r } +// BodyJsonString set the request body as string and set Content-Type header +// as "application/json; charset=UTF-8" func (r *Request) BodyJsonString(body string) *Request { r.httpRequest.Body = ioutil.NopCloser(strings.NewReader(body)) r.setContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) return r } +// BodyJsonBytes set the request body as []byte and set Content-Type header +// as "application/json; charset=UTF-8" func (r *Request) BodyJsonBytes(body []byte) *Request { r.httpRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) r.setContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) return r } +// BodyJsonMarshal set the request body that marshaled from object, and +// set Content-Type header as "application/json; charset=UTF-8" func (r *Request) BodyJsonMarshal(v interface{}) *Request { b, err := json.Marshal(v) if err != nil { @@ -223,6 +256,7 @@ func (r *Request) execute() (resp *Response, err error) { return } +// Send sends the request. func (r *Request) Send() (resp *Response, err error) { return r.execute() } diff --git a/response.go b/response.go index 0f3c65d2..ff604384 100644 --- a/response.go +++ b/response.go @@ -7,6 +7,7 @@ import ( "strings" ) +// ResponseOptions determines that how should the response been processed. type ResponseOptions struct { // DisableAutoDecode, if true, prevents auto detect response // body's charset and decode it to utf-8 @@ -64,6 +65,7 @@ func AutoDecodeContentType(contentTypes ...string) ResponseOption { } } +// Response is the http response. type Response struct { *http.Response request *Request From 9f71dd10e1275ca52bcbe52b6c050abeefa4e8b0 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 10:04:17 +0800 Subject: [PATCH 056/843] refactor auto decode --- client.go | 33 +++++-- decode.go | 19 +--- internal/charsetutil/charsetutil.go | 148 ++++++++++++++++++++++++++++ request.go | 2 +- response.go | 27 ----- transport.go | 5 +- 6 files changed, 181 insertions(+), 53 deletions(-) create mode 100644 internal/charsetutil/charsetutil.go diff --git a/client.go b/client.go index afdbfe3a..fcca90e4 100644 --- a/client.go +++ b/client.go @@ -61,7 +61,8 @@ func (c *Client) R() *Request { } func (c *Client) AutoDiscardResponseBody() *Client { - return c.ResponseOptions(DiscardResponseBody()) + c.getResponseOptions().AutoDiscard = true + return c } // TestMode is like DebugMode, but discard response body, so you can @@ -79,7 +80,7 @@ const ( // agent to pretend to be a web browser, Avoid returning abnormal // data from some sites. func (c *Client) DebugMode() *Client { - return c.AutoDecodeTextContent(). + return c.AutoDecodeTextType(). Dump(true). Logger(NewLogger(os.Stdout)). UserAgent(userAgentChrome) @@ -94,10 +95,18 @@ func (c *Client) Logger(log Logger) *Client { return c } -func (c *Client) ResponseOptions(opts ...ResponseOption) *Client { - for _, opt := range opts { - opt(&c.t.ResponseOptions) +func (c *Client) getResponseOptions() *ResponseOptions { + if c.t.ResponseOptions == nil { + c.t.ResponseOptions = &ResponseOptions{} + } + return c.t.ResponseOptions +} + +func (c *Client) ResponseOptions(opt *ResponseOptions) *Client { + if opt == nil { + return c } + c.t.ResponseOptions = opt return c } @@ -209,8 +218,18 @@ func (c *Client) NewRequest() *Request { return c.R() } -func (c *Client) AutoDecodeTextContent() *Client { - return c.ResponseOptions(AutoDecodeTextContent()) +// AutoDecodeAllType indicates that try autodetect and decode all content type. +func (c *Client) AutoDecodeAllType() *Client { + c.getResponseOptions().AutoDecodeContentType = func(contentType string) bool { + return true + } + return c +} + +// AutoDecodeTextType indicates that only try autodetect and decode the text content type. +func (c *Client) AutoDecodeTextType() *Client { + c.getResponseOptions().AutoDecodeContentType = autoDecodeText + return c } // UserAgent set the "User-Agent" header for all requests. diff --git a/decode.go b/decode.go index f2281194..473b4e62 100644 --- a/decode.go +++ b/decode.go @@ -1,21 +1,10 @@ package req import ( - htmlcharset "golang.org/x/net/html/charset" - "golang.org/x/text/encoding/charmap" + "github.com/imroc/req/v2/internal/charsetutil" "io" - "strings" ) -func responseBodyIsText(contentType string) bool { - for _, keyword := range []string{"text", "json", "xml", "html", "java"} { - if strings.Contains(contentType, keyword) { - return true - } - } - return false -} - type decodeReaderCloser struct { io.ReadCloser decodeReader io.Reader @@ -38,11 +27,7 @@ func (a *autoDecodeReadCloser) peekRead(p []byte) (n int, err error) { return } a.detected = true - enc, _, _ := htmlcharset.DetermineEncoding(p[:n], "") - // TODO: log chartset name - if enc == charmap.Windows1252 { - return - } + enc := charsetutil.FindEncoding(p) if enc == nil { return } diff --git a/internal/charsetutil/charsetutil.go b/internal/charsetutil/charsetutil.go new file mode 100644 index 00000000..7c3d05bf --- /dev/null +++ b/internal/charsetutil/charsetutil.go @@ -0,0 +1,148 @@ +package charsetutil + +import ( + "bytes" + "golang.org/x/net/html" + htmlcharset "golang.org/x/net/html/charset" + "golang.org/x/text/encoding" + "strings" +) + +var boms = []struct { + bom []byte + enc string +}{ + {[]byte{0xfe, 0xff}, "utf-16be"}, + {[]byte{0xff, 0xfe}, "utf-16le"}, + {[]byte{0xef, 0xbb, 0xbf}, "utf-8"}, +} + +func FindEncoding(content []byte) encoding.Encoding { + if len(content) == 0 { + return nil + } + for _, b := range boms { + if bytes.HasPrefix(content, b.bom) { + e, _ := htmlcharset.Lookup(b.enc) + if e != nil { + return e + } + } + } + e, _ := prescan(content) + if e != nil { + return e + } + return nil +} + +func prescan(content []byte) (e encoding.Encoding, name string) { + z := html.NewTokenizer(bytes.NewReader(content)) + for { + switch z.Next() { + case html.ErrorToken: + return nil, "" + + case html.StartTagToken, html.SelfClosingTagToken: + tagName, hasAttr := z.TagName() + if !bytes.Equal(tagName, []byte("meta")) { + continue + } + attrList := make(map[string]bool) + gotPragma := false + + const ( + dontKnow = iota + doNeedPragma + doNotNeedPragma + ) + needPragma := dontKnow + + name = "" + e = nil + for hasAttr { + var key, val []byte + key, val, hasAttr = z.TagAttr() + ks := string(key) + if attrList[ks] { + continue + } + attrList[ks] = true + for i, c := range val { + if 'A' <= c && c <= 'Z' { + val[i] = c + 0x20 + } + } + + switch ks { + case "http-equiv": + if bytes.Equal(val, []byte("content-type")) { + gotPragma = true + } + + case "content": + if e == nil { + name = fromMetaElement(string(val)) + if name != "" { + e, name = htmlcharset.Lookup(name) + if e != nil { + needPragma = doNeedPragma + } + } + } + + case "charset": + e, name = htmlcharset.Lookup(string(val)) + needPragma = doNotNeedPragma + } + } + + if needPragma == dontKnow || needPragma == doNeedPragma && !gotPragma { + continue + } + + if strings.HasPrefix(name, "utf-16") { + name = "utf-8" + e = encoding.Nop + } + + if e != nil { + return e, name + } + } + } +} + +func fromMetaElement(s string) string { + for s != "" { + csLoc := strings.Index(s, "charset") + if csLoc == -1 { + return "" + } + s = s[csLoc+len("charset"):] + s = strings.TrimLeft(s, " \t\n\f\r") + if !strings.HasPrefix(s, "=") { + continue + } + s = s[1:] + s = strings.TrimLeft(s, " \t\n\f\r") + if s == "" { + return "" + } + if q := s[0]; q == '"' || q == '\'' { + s = s[1:] + closeQuote := strings.IndexRune(s, rune(q)) + if closeQuote == -1 { + return "" + } + return s[:closeQuote] + } + + end := strings.IndexAny(s, "; \t\n\f\r") + if end == -1 { + end = len(s) + } + return s[:end] + } + return "" +} diff --git a/request.go b/request.go index d0ae90a4..38139d41 100644 --- a/request.go +++ b/request.go @@ -250,7 +250,7 @@ func (r *Request) execute() (resp *Response, err error) { request: r, Response: httpResponse, } - if r.client.t.AutoDiscard { + if r.client.t.ResponseOptions != nil && r.client.t.ResponseOptions.AutoDiscard { err = resp.Discard() } return diff --git a/response.go b/response.go index ff604384..0eecad14 100644 --- a/response.go +++ b/response.go @@ -23,27 +23,8 @@ type ResponseOptions struct { AutoDiscard bool } -type ResponseOption func(o *ResponseOptions) - -func DiscardResponseBody() ResponseOption { - return func(o *ResponseOptions) { - o.AutoDiscard = true - } -} - -// DisableAutoDecode disable the response body auto-decode to improve performance. -func DisableAutoDecode() ResponseOption { - return func(o *ResponseOptions) { - o.DisableAutoDecode = true - } -} - var textContentTypes = []string{"text", "json", "xml", "html", "java"} -func AutoDecodeTextContent() ResponseOption { - return AutoDecodeContentType(textContentTypes...) -} - var autoDecodeText = autoDecodeContentTypeFunc(textContentTypes...) func autoDecodeContentTypeFunc(contentTypes ...string) func(contentType string) bool { @@ -57,14 +38,6 @@ func autoDecodeContentTypeFunc(contentTypes ...string) func(contentType string) } } -// AutoDecodeContentType specifies that the response body should been auto-decoded -// when content type contains keywords that here given. -func AutoDecodeContentType(contentTypes ...string) ResponseOption { - return func(o *ResponseOptions) { - o.AutoDecodeContentType = autoDecodeContentTypeFunc(contentTypes...) - } -} - // Response is the http response. type Response struct { *http.Response diff --git a/transport.go b/transport.go index 441cbb2a..1419f0d1 100644 --- a/transport.go +++ b/transport.go @@ -269,7 +269,7 @@ type Transport struct { // upgrades, set this to true. ForceAttemptHTTP2 bool - ResponseOptions + *ResponseOptions dump *dumper } @@ -286,6 +286,9 @@ func (t *Transport) dumpResponseBody(res *http.Response) { } func (t *Transport) autoDecodeResponseBody(res *http.Response) { + if t.ResponseOptions == nil { + return + } if t.ResponseOptions.DisableAutoDecode { return } From 89569b3ee10cf86328aa94a4e83de03e4ba7625b Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 10:05:05 +0800 Subject: [PATCH 057/843] release v2.0.0-alpha.2 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 29e81acd..b47427a3 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ A golang http request library for humans. ## Install ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.1 +go get github.com/imroc/req/v2@v2.0.0-alpha.2 ``` ## Quick Start From 9d7640e00d03c4af8908010c36743102e006a7f0 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 15:14:51 +0800 Subject: [PATCH 058/843] expose GetDumpOptions --- client.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index fcca90e4..4e0b195e 100644 --- a/client.go +++ b/client.go @@ -116,7 +116,7 @@ func (c *Client) Timeout(d time.Duration) *Client { return c } -func (c *Client) getDumpOptions() *DumpOptions { +func (c *Client) GetDumpOptions() *DumpOptions { if c.dumpOptions == nil { c.dumpOptions = newDefaultDumpOptions() } @@ -127,7 +127,7 @@ func (c *Client) enableDump() { if c.t.dump != nil { // dump already started return } - c.t.EnableDump(c.getDumpOptions()) + c.t.EnableDump(c.GetDumpOptions()) } // DumpToFile indicates that the content should dump to the specified filename. @@ -137,13 +137,13 @@ func (c *Client) DumpToFile(filename string) *Client { logf(c.log, "create dump file error: %v", err) return c } - c.getDumpOptions().Output = file + c.GetDumpOptions().Output = file return c } // DumpTo indicates that the content should dump to the specified destination. func (c *Client) DumpTo(output io.Writer) *Client { - c.getDumpOptions().Output = output + c.GetDumpOptions().Output = output c.enableDump() return c } @@ -152,7 +152,7 @@ func (c *Client) DumpTo(output io.Writer) *Client { // can be used for debugging in production environment without // affecting performance. func (c *Client) DumpAsync() *Client { - o := c.getDumpOptions() + o := c.GetDumpOptions() o.Async = true c.enableDump() return c @@ -160,7 +160,7 @@ func (c *Client) DumpAsync() *Client { // DumpOnlyResponse indicates that should dump the responses' head and response. func (c *Client) DumpOnlyResponse() *Client { - o := c.getDumpOptions() + o := c.GetDumpOptions() o.ResponseHead = true o.ResponseBody = true o.RequestBody = false @@ -171,7 +171,7 @@ func (c *Client) DumpOnlyResponse() *Client { // DumpOnlyRequest indicates that should dump the requests' head and response. func (c *Client) DumpOnlyRequest() *Client { - o := c.getDumpOptions() + o := c.GetDumpOptions() o.RequestHead = true o.RequestBody = true o.ResponseBody = false @@ -182,7 +182,7 @@ func (c *Client) DumpOnlyRequest() *Client { // DumpOnlyBody indicates that should dump the body of requests and responses. func (c *Client) DumpOnlyBody() *Client { - o := c.getDumpOptions() + o := c.GetDumpOptions() o.RequestBody = true o.ResponseBody = true o.RequestHead = false @@ -193,7 +193,7 @@ func (c *Client) DumpOnlyBody() *Client { // DumpOnlyHead indicates that should dump the head of requests and responses. func (c *Client) DumpOnlyHead() *Client { - o := c.getDumpOptions() + o := c.GetDumpOptions() o.RequestHead = true o.ResponseHead = true o.RequestBody = false @@ -204,7 +204,7 @@ func (c *Client) DumpOnlyHead() *Client { // DumpAll indicates that should dump both requests and responses' head and body. func (c *Client) DumpAll() *Client { - o := c.getDumpOptions() + o := c.GetDumpOptions() o.RequestHead = true o.RequestBody = true o.ResponseHead = true From 99c1482c7bad0ff6f9994b2083da620bf24fc22f Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 15:22:56 +0800 Subject: [PATCH 059/843] some refactor --- client.go | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index 4e0b195e..9f66408d 100644 --- a/client.go +++ b/client.go @@ -61,7 +61,7 @@ func (c *Client) R() *Request { } func (c *Client) AutoDiscardResponseBody() *Client { - c.getResponseOptions().AutoDiscard = true + c.GetResponseOptions().AutoDiscard = true return c } @@ -82,12 +82,12 @@ const ( func (c *Client) DebugMode() *Client { return c.AutoDecodeTextType(). Dump(true). - Logger(NewLogger(os.Stdout)). + SetLogger(NewLogger(os.Stdout)). UserAgent(userAgentChrome) } -// Logger set the logger for req. -func (c *Client) Logger(log Logger) *Client { +// SetLogger set the logger for req. +func (c *Client) SetLogger(log Logger) *Client { if log == nil { return c } @@ -95,14 +95,15 @@ func (c *Client) Logger(log Logger) *Client { return c } -func (c *Client) getResponseOptions() *ResponseOptions { +func (c *Client) GetResponseOptions() *ResponseOptions { if c.t.ResponseOptions == nil { c.t.ResponseOptions = &ResponseOptions{} } return c.t.ResponseOptions } -func (c *Client) ResponseOptions(opt *ResponseOptions) *Client { +// ResponseOptions set the ResponseOptions for the underlying Transport. +func (c *Client) SetResponseOptions(opt *ResponseOptions) *Client { if opt == nil { return c } @@ -220,7 +221,7 @@ func (c *Client) NewRequest() *Request { // AutoDecodeAllType indicates that try autodetect and decode all content type. func (c *Client) AutoDecodeAllType() *Client { - c.getResponseOptions().AutoDecodeContentType = func(contentType string) bool { + c.GetResponseOptions().AutoDecodeContentType = func(contentType string) bool { return true } return c @@ -228,7 +229,7 @@ func (c *Client) AutoDecodeAllType() *Client { // AutoDecodeTextType indicates that only try autodetect and decode the text content type. func (c *Client) AutoDecodeTextType() *Client { - c.getResponseOptions().AutoDecodeContentType = autoDecodeText + c.GetResponseOptions().AutoDecodeContentType = autoDecodeText return c } From 66ff5bb3d6ddce20db10ff7bf58de080d3a1fff2 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 15:49:05 +0800 Subject: [PATCH 060/843] refactor Send --- request.go | 48 ++++++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/request.go b/request.go index 38139d41..6b2c1582 100644 --- a/request.go +++ b/request.go @@ -68,8 +68,17 @@ func (r *Request) URL(url string) *Request { return r } -func (r *Request) send(method, url string) (*Response, error) { - return r.Method(method).URL(url).Send() +func (r *Request) Send(method, url string) (*Response, error) { + r.httpRequest.Method = method + u, err := urlpkg.Parse(url) + if err != nil { + return nil, err + } + // The host's colon:port should be normalized. See Issue 14836. + u.Host = removeEmptyPort(u.Host) + r.httpRequest.URL = u + r.httpRequest.Host = u.Host + return r.execute() } // MustGet like Get, panic if error happens. @@ -81,9 +90,9 @@ func (r *Request) MustGet(url string) *Response { return resp } -// Get send the request with GET method and specified url. +// Get Send the request with GET method and specified url. func (r *Request) Get(url string) (*Response, error) { - return r.send(http.MethodGet, url) + return r.Send(http.MethodGet, url) } // MustPost like Post, panic if error happens. @@ -95,9 +104,9 @@ func (r *Request) MustPost(url string) *Response { return resp } -// Post send the request with POST method and specified url. +// Post Send the request with POST method and specified url. func (r *Request) Post(url string) (*Response, error) { - return r.send(http.MethodPost, url) + return r.Send(http.MethodPost, url) } // MustPut like Put, panic if error happens. @@ -109,9 +118,9 @@ func (r *Request) MustPut(url string) *Response { return resp } -// Put send the request with Put method and specified url. +// Put Send the request with Put method and specified url. func (r *Request) Put(url string) (*Response, error) { - return r.send(http.MethodPut, url) + return r.Send(http.MethodPut, url) } // MustPatch like Patch, panic if error happens. @@ -123,9 +132,9 @@ func (r *Request) MustPatch(url string) *Response { return resp } -// Patch send the request with PATCH method and specified url. +// Patch Send the request with PATCH method and specified url. func (r *Request) Patch(url string) (*Response, error) { - return r.send(http.MethodPatch, url) + return r.Send(http.MethodPatch, url) } // MustDelete like Delete, panic if error happens. @@ -137,9 +146,9 @@ func (r *Request) MustDelete(url string) *Response { return resp } -// Delete send the request with DELETE method and specified url. +// Delete Send the request with DELETE method and specified url. func (r *Request) Delete(url string) (*Response, error) { - return r.send(http.MethodDelete, url) + return r.Send(http.MethodDelete, url) } // MustOptions like Options, panic if error happens. @@ -151,23 +160,23 @@ func (r *Request) MustOptions(url string) *Response { return resp } -// Options send the request with OPTIONS method and specified url. +// Options Send the request with OPTIONS method and specified url. func (r *Request) Options(url string) (*Response, error) { - return r.send(http.MethodOptions, url) + return r.Send(http.MethodOptions, url) } // MustHead like Head, panic if error happens. func (r *Request) MustHead(url string) *Response { - resp, err := r.send(http.MethodHead, url) + resp, err := r.Send(http.MethodHead, url) if err != nil { panic(err) } return resp } -// Head send the request with HEAD method and specified url. +// Head Send the request with HEAD method and specified url. func (r *Request) Head(url string) (*Response, error) { - return r.send(http.MethodHead, url) + return r.Send(http.MethodHead, url) } // Body set the request body. @@ -255,8 +264,3 @@ func (r *Request) execute() (resp *Response, err error) { } return } - -// Send sends the request. -func (r *Request) Send() (resp *Response, err error) { - return r.execute() -} From 11dc6ce5398676c3a1f0223bd254bf480a91328e Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 15:53:39 +0800 Subject: [PATCH 061/843] client.DumpOptions --> client.SetDumpOptions --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 9f66408d..c308cf95 100644 --- a/client.go +++ b/client.go @@ -261,7 +261,7 @@ func (c *Client) Dump(enable bool) *Client { } // DumpOptions configures the underlying Transport's DumpOptions -func (c *Client) DumpOptions(opt *DumpOptions) *Client { +func (c *Client) SetDumpOptions(opt *DumpOptions) *Client { if opt == nil { return c } From 24752de670859353039f140b16cc6e8a7ede9a6d Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 16:02:29 +0800 Subject: [PATCH 062/843] add some proxy function --- client.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/client.go b/client.go index c308cf95..a64a00e9 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/http/cookiejar" + "net/url" "os" "time" ) @@ -272,6 +273,27 @@ func (c *Client) SetDumpOptions(opt *DumpOptions) *Client { return c } +// Proxy set the proxy function. +func (c *Client) Proxy(proxy func(*http.Request) (*url.URL, error)) *Client { + c.t.Proxy = proxy + return c +} + +func (c *Client) ProxyFromEnv() *Client { + c.t.Proxy = http.ProxyFromEnvironment + return c +} + +func (c *Client) ProxyURL(proxyUrl string) *Client { + u, err := url.Parse(proxyUrl) + if err != nil { + logf(c.log, "failed to parse proxy url %s: %v", proxyUrl, err) + return c + } + c.t.Proxy = http.ProxyURL(u) + return c +} + // NewClient is the alias of C func NewClient() *Client { return C() From 92a6bfb3fcb8496019784405ea3f18a2cbde5a55 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 16:20:23 +0800 Subject: [PATCH 063/843] support path params --- request.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/request.go b/request.go index 6b2c1582..c387e6af 100644 --- a/request.go +++ b/request.go @@ -15,6 +15,7 @@ import ( // Request is the http request type Request struct { + pathParams map[string]string error error client *Client httpRequest *http.Request @@ -25,6 +26,21 @@ func New() *Request { return defaultClient.R() } +func (r *Request) PathParams(params map[string]string) *Request { + for key, value := range params { + r.PathParam(key, value) + } + return r +} + +func (r *Request) PathParam(key, value string) *Request { + if r.pathParams == nil { + r.pathParams = make(map[string]string) + } + r.pathParams[key] = value + return r +} + func (r *Request) appendError(err error) { r.error = multierror.Append(r.error, err) } @@ -69,7 +85,19 @@ func (r *Request) URL(url string) *Request { } func (r *Request) Send(method, url string) (*Response, error) { + if r.error != nil { + return nil, r.error + } + r.httpRequest.Method = method + + // handle path params + if len(r.pathParams) > 0 { + for k, v := range r.pathParams { + url = strings.Replace(url, "{"+k+"}", urlpkg.PathEscape(v), -1) + } + } + u, err := urlpkg.Parse(url) if err != nil { return nil, err From 1250135f525c260a31c5e3c240111a855ef5d533 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 18:14:00 +0800 Subject: [PATCH 064/843] refactor api style, add Set* methods --- client.go | 87 +++++++++++++++++++++++++++++------------------------ request.go | 88 ++++++++++++++++++++++++------------------------------ 2 files changed, 87 insertions(+), 88 deletions(-) diff --git a/client.go b/client.go index a64a00e9..68bc3ae1 100644 --- a/client.go +++ b/client.go @@ -11,6 +11,11 @@ import ( "time" ) +type ( + // RequestMiddleware type is for request middleware, called before a request is sent + RequestMiddleware func(*Client, *Request) error +) + // DefaultClient returns the global default Client. func DefaultClient() *Client { return defaultClient @@ -81,10 +86,10 @@ const ( // agent to pretend to be a web browser, Avoid returning abnormal // data from some sites. func (c *Client) DebugMode() *Client { - return c.AutoDecodeTextType(). - Dump(true). + return c.EnableAutoDecodeTextType(). + EnableDumpAll(). SetLogger(NewLogger(os.Stdout)). - UserAgent(userAgentChrome) + SetUserAgent(userAgentChrome) } // SetLogger set the logger for req. @@ -132,8 +137,8 @@ func (c *Client) enableDump() { c.t.EnableDump(c.GetDumpOptions()) } -// DumpToFile indicates that the content should dump to the specified filename. -func (c *Client) DumpToFile(filename string) *Client { +// EnableDumpToFile indicates that the content should dump to the specified filename. +func (c *Client) EnableDumpToFile(filename string) *Client { file, err := os.Create(filename) if err != nil { logf(c.log, "create dump file error: %v", err) @@ -143,25 +148,25 @@ func (c *Client) DumpToFile(filename string) *Client { return c } -// DumpTo indicates that the content should dump to the specified destination. -func (c *Client) DumpTo(output io.Writer) *Client { +// EnableDumpTo indicates that the content should dump to the specified destination. +func (c *Client) EnableDumpTo(output io.Writer) *Client { c.GetDumpOptions().Output = output c.enableDump() return c } -// DumpAsync indicates that the dump should be done asynchronously, +// EnableDumpAsync indicates that the dump should be done asynchronously, // can be used for debugging in production environment without // affecting performance. -func (c *Client) DumpAsync() *Client { +func (c *Client) EnableDumpAsync() *Client { o := c.GetDumpOptions() o.Async = true c.enableDump() return c } -// DumpOnlyResponse indicates that should dump the responses' head and response. -func (c *Client) DumpOnlyResponse() *Client { +// EnableDumpOnlyResponse indicates that should dump the responses' head and response. +func (c *Client) EnableDumpOnlyResponse() *Client { o := c.GetDumpOptions() o.ResponseHead = true o.ResponseBody = true @@ -171,8 +176,8 @@ func (c *Client) DumpOnlyResponse() *Client { return c } -// DumpOnlyRequest indicates that should dump the requests' head and response. -func (c *Client) DumpOnlyRequest() *Client { +// EnableDumpOnlyRequest indicates that should dump the requests' head and response. +func (c *Client) EnableDumpOnlyRequest() *Client { o := c.GetDumpOptions() o.RequestHead = true o.RequestBody = true @@ -182,8 +187,8 @@ func (c *Client) DumpOnlyRequest() *Client { return c } -// DumpOnlyBody indicates that should dump the body of requests and responses. -func (c *Client) DumpOnlyBody() *Client { +// EnableDumpOnlyBody indicates that should dump the body of requests and responses. +func (c *Client) EnableDumpOnlyBody() *Client { o := c.GetDumpOptions() o.RequestBody = true o.ResponseBody = true @@ -193,8 +198,8 @@ func (c *Client) DumpOnlyBody() *Client { return c } -// DumpOnlyHead indicates that should dump the head of requests and responses. -func (c *Client) DumpOnlyHead() *Client { +// EnableDumpOnlyHead indicates that should dump the head of requests and responses. +func (c *Client) EnableDumpOnlyHead() *Client { o := c.GetDumpOptions() o.RequestHead = true o.ResponseHead = true @@ -204,8 +209,8 @@ func (c *Client) DumpOnlyHead() *Client { return c } -// DumpAll indicates that should dump both requests and responses' head and body. -func (c *Client) DumpAll() *Client { +// EnableDumpAll indicates that should dump both requests and responses' head and body. +func (c *Client) EnableDumpAll() *Client { o := c.GetDumpOptions() o.RequestHead = true o.RequestBody = true @@ -220,27 +225,27 @@ func (c *Client) NewRequest() *Request { return c.R() } -// AutoDecodeAllType indicates that try autodetect and decode all content type. -func (c *Client) AutoDecodeAllType() *Client { +// EnableAutoDecodeAllType indicates that try autodetect and decode all content type. +func (c *Client) EnableAutoDecodeAllType() *Client { c.GetResponseOptions().AutoDecodeContentType = func(contentType string) bool { return true } return c } -// AutoDecodeTextType indicates that only try autodetect and decode the text content type. -func (c *Client) AutoDecodeTextType() *Client { +// EnableAutoDecodeTextType indicates that only try autodetect and decode the text content type. +func (c *Client) EnableAutoDecodeTextType() *Client { c.GetResponseOptions().AutoDecodeContentType = autoDecodeText return c } -// UserAgent set the "User-Agent" header for all requests. -func (c *Client) UserAgent(userAgent string) *Client { - return c.CommonHeader("User-Agent", userAgent) +// SetUserAgent set the "User-Agent" header for all requests. +func (c *Client) SetUserAgent(userAgent string) *Client { + return c.SetCommonHeader("User-Agent", userAgent) } -// CommonHeader set the common header for all requests. -func (c *Client) CommonHeader(key, value string) *Client { +// SetCommonHeader set the common header for all requests. +func (c *Client) SetCommonHeader(key, value string) *Client { if c.commonHeader == nil { c.commonHeader = make(map[string]string) } @@ -248,20 +253,24 @@ func (c *Client) CommonHeader(key, value string) *Client { return c } -// Dump if true, enables dump requests and responses, allowing you +// EnableDump enables dump requests and responses, allowing you // to clearly see the content of all requests and responses,which // is very convenient for debugging APIs. -// Dump if false, disable the dump behaviour. -func (c *Client) Dump(enable bool) *Client { - if !enable { - c.t.DisableDump() +func (c *Client) EnableDump() *Client { + if c.t.dump != nil { // dump already started return c } - c.enableDump() + c.t.EnableDump(c.GetDumpOptions()) + return c +} + +// DisableDump stop the dump. +func (c *Client) DisableDump() *Client { + c.t.DisableDump() return c } -// DumpOptions configures the underlying Transport's DumpOptions +// SetDumpOptions configures the underlying Transport's DumpOptions func (c *Client) SetDumpOptions(opt *DumpOptions) *Client { if opt == nil { return c @@ -273,18 +282,18 @@ func (c *Client) SetDumpOptions(opt *DumpOptions) *Client { return c } -// Proxy set the proxy function. -func (c *Client) Proxy(proxy func(*http.Request) (*url.URL, error)) *Client { +// SetProxy set the proxy function. +func (c *Client) SetProxy(proxy func(*http.Request) (*url.URL, error)) *Client { c.t.Proxy = proxy return c } -func (c *Client) ProxyFromEnv() *Client { +func (c *Client) SetProxyFromEnv() *Client { c.t.Proxy = http.ProxyFromEnvironment return c } -func (c *Client) ProxyURL(proxyUrl string) *Client { +func (c *Client) SetProxyURL(proxyUrl string) *Client { u, err := url.Parse(proxyUrl) if err != nil { logf(c.log, "failed to parse proxy url %s: %v", proxyUrl, err) diff --git a/request.go b/request.go index c387e6af..e0a3da69 100644 --- a/request.go +++ b/request.go @@ -2,7 +2,6 @@ package req import ( "bytes" - "context" "encoding/json" "fmt" "github.com/hashicorp/go-multierror" @@ -15,7 +14,9 @@ import ( // Request is the http request type Request struct { + URL string pathParams map[string]string + query urlpkg.Values error error client *Client httpRequest *http.Request @@ -26,14 +27,14 @@ func New() *Request { return defaultClient.R() } -func (r *Request) PathParams(params map[string]string) *Request { +func (r *Request) SetPathParams(params map[string]string) *Request { for key, value := range params { - r.PathParam(key, value) + r.SetPathParam(key, value) } return r } -func (r *Request) PathParam(key, value string) *Request { +func (r *Request) SetPathParam(key, value string) *Request { if r.pathParams == nil { r.pathParams = make(map[string]string) } @@ -51,8 +52,11 @@ func (r *Request) Error() error { return r.error } -// Method set the http request method. -func (r *Request) Method(method string) *Request { +func (r *Request) Send(method, url string) (*Response, error) { + if r.error != nil { + return nil, r.error + } + if method == "" { // We document that "" means "GET" for Request.Method, and people have // relied on that from NewRequest, so keep that working. @@ -62,33 +66,9 @@ func (r *Request) Method(method string) *Request { if !validMethod(method) { err := fmt.Errorf("net/http: invalid method %q", method) if err != nil { - r.appendError(err) + return nil, err } } - r.httpRequest.Method = method - r.httpRequest = r.httpRequest.WithContext(context.Background()) - return r -} - -// URL set the http request url. -func (r *Request) URL(url string) *Request { - u, err := urlpkg.Parse(url) - if err != nil { - r.appendError(err) - return r - } - // The host's colon:port should be normalized. See Issue 14836. - u.Host = removeEmptyPort(u.Host) - r.httpRequest.URL = u - r.httpRequest.Host = u.Host - return r -} - -func (r *Request) Send(method, url string) (*Response, error) { - if r.error != nil { - return nil, r.error - } - r.httpRequest.Method = method // handle path params @@ -102,6 +82,16 @@ func (r *Request) Send(method, url string) (*Response, error) { if err != nil { return nil, err } + + // handle query params + if len(r.query) > 0 { + if len(strings.TrimSpace(u.RawQuery)) == 0 { // empty query + u.RawQuery = r.query.Encode() + } else { // query not empty + u.RawQuery = u.RawQuery + "&" + r.query.Encode() + } + } + // The host's colon:port should be normalized. See Issue 14836. u.Host = removeEmptyPort(u.Host) r.httpRequest.URL = u @@ -207,8 +197,8 @@ func (r *Request) Head(url string) (*Response, error) { return r.Send(http.MethodHead, url) } -// Body set the request body. -func (r *Request) Body(body interface{}) *Request { +// SetBody set the request body. +func (r *Request) SetBody(body interface{}) *Request { if body == nil { return r } @@ -218,53 +208,53 @@ func (r *Request) Body(body interface{}) *Request { case io.Reader: r.httpRequest.Body = ioutil.NopCloser(b) case []byte: - r.BodyBytes(b) + r.SetBodyBytes(b) case string: - r.BodyString(b) + r.SetBodyString(b) } return r } -// BodyBytes set the request body as []byte. -func (r *Request) BodyBytes(body []byte) *Request { +// SetBodyBytes set the request body as []byte. +func (r *Request) SetBodyBytes(body []byte) *Request { r.httpRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) return r } -// BodyString set the request body as string. -func (r *Request) BodyString(body string) *Request { +// SetBodyString set the request body as string. +func (r *Request) SetBodyString(body string) *Request { r.httpRequest.Body = ioutil.NopCloser(strings.NewReader(body)) return r } -// BodyJsonString set the request body as string and set Content-Type header +// SetBodyJsonString set the request body as string and set Content-Type header // as "application/json; charset=UTF-8" -func (r *Request) BodyJsonString(body string) *Request { +func (r *Request) SetBodyJsonString(body string) *Request { r.httpRequest.Body = ioutil.NopCloser(strings.NewReader(body)) - r.setContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) + r.SetContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) return r } -// BodyJsonBytes set the request body as []byte and set Content-Type header +// SetBodyJsonBytes set the request body as []byte and set Content-Type header // as "application/json; charset=UTF-8" -func (r *Request) BodyJsonBytes(body []byte) *Request { +func (r *Request) SetBodyJsonBytes(body []byte) *Request { r.httpRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) - r.setContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) + r.SetContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) return r } -// BodyJsonMarshal set the request body that marshaled from object, and +// SetBodyJsonMarshal set the request body that marshaled from object, and // set Content-Type header as "application/json; charset=UTF-8" -func (r *Request) BodyJsonMarshal(v interface{}) *Request { +func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { b, err := json.Marshal(v) if err != nil { r.appendError(err) return r } - return r.BodyBytes(b) + return r.SetBodyBytes(b) } -func (r *Request) setContentType(contentType string) *Request { +func (r *Request) SetContentType(contentType string) *Request { r.httpRequest.Header.Set("Content-Type", contentType) return r } From 8cd02ac9d2991c5e79a19d59ac4c28967504fa78 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 18:52:53 +0800 Subject: [PATCH 065/843] support middleware --- client.go | 53 +++++++++++++++++---------- internal/util/util.go | 5 +++ middleware.go | 85 +++++++++++++++++++++++++++++++++++++++++++ request.go | 74 ++++++++++++++++++++++++------------- 4 files changed, 173 insertions(+), 44 deletions(-) create mode 100644 middleware.go diff --git a/client.go b/client.go index 68bc3ae1..63d340ca 100644 --- a/client.go +++ b/client.go @@ -2,20 +2,17 @@ package req import ( "encoding/json" + "github.com/imroc/req/v2/internal/util" "golang.org/x/net/publicsuffix" "io" "net/http" "net/http/cookiejar" "net/url" "os" + "strings" "time" ) -type ( - // RequestMiddleware type is for request middleware, called before a request is sent - RequestMiddleware func(*Client, *Request) error -) - // DefaultClient returns the global default Client. func DefaultClient() *Client { return defaultClient @@ -32,13 +29,18 @@ var defaultClient *Client = C() // Client is the req's http client. type Client struct { - log Logger - t *Transport - t2 *http2Transport - dumpOptions *DumpOptions - httpClient *http.Client - jsonDecoder *json.Decoder - commonHeader map[string]string + HostURL string + PathParams map[string]string + QueryParams url.Values + scheme string + log Logger + t *Transport + t2 *http2Transport + dumpOptions *DumpOptions + httpClient *http.Client + jsonDecoder *json.Decoder + commonHeader map[string]string + beforeRequest []RequestMiddleware } func copyCommonHeader(h map[string]string) map[string]string { @@ -92,6 +94,15 @@ func (c *Client) DebugMode() *Client { SetUserAgent(userAgentChrome) } +// SetScheme method sets custom scheme in the Resty client. It's way to override default. +// client.SetScheme("http") +func (c *Client) SetScheme(scheme string) *Client { + if !util.IsStringEmpty(scheme) { + c.scheme = strings.TrimSpace(scheme) + } + return c +} + // SetLogger set the logger for req. func (c *Client) SetLogger(log Logger) *Client { if log == nil { @@ -108,7 +119,7 @@ func (c *Client) GetResponseOptions() *ResponseOptions { return c.t.ResponseOptions } -// ResponseOptions set the ResponseOptions for the underlying Transport. +// SetResponseOptions set the ResponseOptions for the underlying Transport. func (c *Client) SetResponseOptions(opt *ResponseOptions) *Client { if opt == nil { return c @@ -117,8 +128,8 @@ func (c *Client) SetResponseOptions(opt *ResponseOptions) *Client { return c } -// Timeout set the timeout for all requests. -func (c *Client) Timeout(d time.Duration) *Client { +// SetTimeout set the timeout for all requests. +func (c *Client) SetTimeout(d time.Duration) *Client { c.httpClient.Timeout = d return c } @@ -341,11 +352,15 @@ func C() *Client { Jar: jar, Timeout: 2 * time.Minute, } + beforeRequest := []RequestMiddleware{ + parseRequestURL, + } c := &Client{ - log: &emptyLogger{}, - httpClient: httpClient, - t: t, - t2: t2, + beforeRequest: beforeRequest, + log: &emptyLogger{}, + httpClient: httpClient, + t: t, + t2: t2, } return c } diff --git a/internal/util/util.go b/internal/util/util.go index a6b933ca..a4c0b4d2 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -28,3 +28,8 @@ func CutBytes(s, sep []byte) (before, after []byte, found bool) { } return s, nil, false } + +// IsStringEmpty method tells whether given string is empty or not +func IsStringEmpty(str string) bool { + return len(strings.TrimSpace(str)) == 0 +} \ No newline at end of file diff --git a/middleware.go b/middleware.go new file mode 100644 index 00000000..18c48244 --- /dev/null +++ b/middleware.go @@ -0,0 +1,85 @@ +package req + +import ( + "github.com/imroc/req/v2/internal/util" + "net/url" + "strings" +) + +type ( + // RequestMiddleware type is for request middleware, called before a request is sent + RequestMiddleware func(*Client, *Request) error +) + +func parseRequestURL(c *Client, r *Request) error { + if len(r.PathParams) > 0 { + for p, v := range r.PathParams { + r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1) + } + } + if len(c.PathParams) > 0 { + for p, v := range c.PathParams { + r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1) + } + } + + // Parsing request URL + reqURL, err := url.Parse(r.URL) + if err != nil { + return err + } + + // If Request.URL is relative path then added c.HostURL into + // the request URL otherwise Request.URL will be used as-is + if !reqURL.IsAbs() { + r.URL = reqURL.String() + if len(r.URL) > 0 && r.URL[0] != '/' { + r.URL = "/" + r.URL + } + + reqURL, err = url.Parse(c.HostURL + r.URL) + if err != nil { + return err + } + } + + // GH #407 && #318 + if reqURL.Scheme == "" && len(c.scheme) > 0 { + reqURL.Scheme = c.scheme + } + + // Adding Query Param + query := make(url.Values) + for k, v := range c.QueryParams { + for _, iv := range v { + query.Add(k, iv) + } + } + + for k, v := range r.QueryParams { + // remove query param from client level by key + // since overrides happens for that key in the request + query.Del(k) + + for _, iv := range v { + query.Add(k, iv) + } + } + + // Preserve query string order partially. + // Since not feasible in `SetQuery*` resty methods, because + // standard package `url.Encode(...)` sorts the query params + // alphabetically + if len(query) > 0 { + if util.IsStringEmpty(reqURL.RawQuery) { + reqURL.RawQuery = query.Encode() + } else { + reqURL.RawQuery = reqURL.RawQuery + "&" + query.Encode() + } + } + + r.URL = reqURL.String() + + return nil +} + diff --git a/request.go b/request.go index e0a3da69..bc727ebd 100644 --- a/request.go +++ b/request.go @@ -9,17 +9,20 @@ import ( "io/ioutil" "net/http" urlpkg "net/url" + "os" "strings" ) // Request is the http request type Request struct { - URL string - pathParams map[string]string - query urlpkg.Values - error error - client *Client - httpRequest *http.Request + URL string + PathParams map[string]string + QueryParams urlpkg.Values + error error + client *Client + httpRequest *http.Request + isSaveResponse bool + output io.WriteCloser } // New create a new request using the global default client. @@ -27,6 +30,36 @@ func New() *Request { return defaultClient.R() } +func (r *Request) SetOutputFile(file string) *Request { + output, err := os.Create(file) + if err != nil { + r.appendError(err) + return r + } + return r.SetOutput(output) +} + +func (r *Request) SetOutput(output io.WriteCloser) *Request { + r.output = output + r.isSaveResponse = true + return r +} + +func (r *Request) SetQueryParams(params map[string]string) *Request { + for k, v := range params { + r.SetQueryParam(k, v) + } + return r +} + +func (r *Request) SetQueryParam(key, value string) *Request { + if r.QueryParams == nil { + r.QueryParams = make(urlpkg.Values) + } + r.QueryParams.Set(key, value) + return r +} + func (r *Request) SetPathParams(params map[string]string) *Request { for key, value := range params { r.SetPathParam(key, value) @@ -35,10 +68,10 @@ func (r *Request) SetPathParams(params map[string]string) *Request { } func (r *Request) SetPathParam(key, value string) *Request { - if r.pathParams == nil { - r.pathParams = make(map[string]string) + if r.PathParams == nil { + r.PathParams = make(map[string]string) } - r.pathParams[key] = value + r.PathParams[key] = value return r } @@ -57,6 +90,8 @@ func (r *Request) Send(method, url string) (*Response, error) { return nil, r.error } + r.URL = url + if method == "" { // We document that "" means "GET" for Request.Method, and people have // relied on that from NewRequest, so keep that working. @@ -71,28 +106,17 @@ func (r *Request) Send(method, url string) (*Response, error) { } r.httpRequest.Method = method - // handle path params - if len(r.pathParams) > 0 { - for k, v := range r.pathParams { - url = strings.Replace(url, "{"+k+"}", urlpkg.PathEscape(v), -1) + for _, f := range r.client.beforeRequest { + if err := f(r.client, r); err != nil { + return nil, err } } - u, err := urlpkg.Parse(url) + // The host's colon:port should be normalized. See Issue 14836. + u, err := urlpkg.Parse(r.URL) if err != nil { return nil, err } - - // handle query params - if len(r.query) > 0 { - if len(strings.TrimSpace(u.RawQuery)) == 0 { // empty query - u.RawQuery = r.query.Encode() - } else { // query not empty - u.RawQuery = u.RawQuery + "&" + r.query.Encode() - } - } - - // The host's colon:port should be normalized. See Issue 14836. u.Host = removeEmptyPort(u.Host) r.httpRequest.URL = u r.httpRequest.Host = u.Host From 0c0f5266cbf52cbecf87c186d9badcfbc02fdd54 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 20:38:10 +0800 Subject: [PATCH 066/843] some refactor --- client.go | 43 ++++++++++++++++++++++--------------------- request.go | 12 +++++++++--- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/client.go b/client.go index 63d340ca..aa56ec89 100644 --- a/client.go +++ b/client.go @@ -29,18 +29,19 @@ var defaultClient *Client = C() // Client is the req's http client. type Client struct { - HostURL string - PathParams map[string]string - QueryParams url.Values - scheme string - log Logger - t *Transport - t2 *http2Transport - dumpOptions *DumpOptions - httpClient *http.Client - jsonDecoder *json.Decoder - commonHeader map[string]string - beforeRequest []RequestMiddleware + HostURL string + PathParams map[string]string + QueryParams url.Values + scheme string + log Logger + t *Transport + t2 *http2Transport + dumpOptions *DumpOptions + httpClient *http.Client + jsonDecoder *json.Decoder + commonHeader map[string]string + beforeRequest []RequestMiddleware + udBeforeRequest []RequestMiddleware } func copyCommonHeader(h map[string]string) map[string]string { @@ -267,17 +268,12 @@ func (c *Client) SetCommonHeader(key, value string) *Client { // EnableDump enables dump requests and responses, allowing you // to clearly see the content of all requests and responses,which // is very convenient for debugging APIs. -func (c *Client) EnableDump() *Client { - if c.t.dump != nil { // dump already started +func (c *Client) EnableDump(enable bool) *Client { + if !enable { + c.t.DisableDump() return c } - c.t.EnableDump(c.GetDumpOptions()) - return c -} - -// DisableDump stop the dump. -func (c *Client) DisableDump() *Client { - c.t.DisableDump() + c.enableDump() return c } @@ -304,6 +300,11 @@ func (c *Client) SetProxyFromEnv() *Client { return c } +func (c *Client) OnBeforeRequest(m RequestMiddleware) *Client { + c.udBeforeRequest = append(c.udBeforeRequest, m) + return c +} + func (c *Client) SetProxyURL(proxyUrl string) *Client { u, err := url.Parse(proxyUrl) if err != nil { diff --git a/request.go b/request.go index bc727ebd..db3d6bcb 100644 --- a/request.go +++ b/request.go @@ -90,8 +90,6 @@ func (r *Request) Send(method, url string) (*Response, error) { return nil, r.error } - r.URL = url - if method == "" { // We document that "" means "GET" for Request.Method, and people have // relied on that from NewRequest, so keep that working. @@ -99,13 +97,21 @@ func (r *Request) Send(method, url string) (*Response, error) { method = "GET" } if !validMethod(method) { - err := fmt.Errorf("net/http: invalid method %q", method) + err := fmt.Errorf("req: invalid method %q", method) if err != nil { return nil, err } } r.httpRequest.Method = method + r.URL = url + + for _, f := range r.client.udBeforeRequest { + if err := f(r.client, r); err != nil { + return nil, err + } + } + for _, f := range r.client.beforeRequest { if err := f(r.client, r); err != nil { return nil, err From f10409308c1c63724f4411b507714a3c8a000462 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 21:03:08 +0800 Subject: [PATCH 067/843] name refactor --- client.go | 72 +++++++++++++++++++++++++++++---------------- dump.go | 32 ++++++++++---------- h2_bundle.go | 8 ++--- http_request.go | 2 +- request.go | 2 +- textproto_reader.go | 2 +- 6 files changed, 69 insertions(+), 49 deletions(-) diff --git a/client.go b/client.go index aa56ec89..baf53c76 100644 --- a/client.go +++ b/client.go @@ -39,12 +39,25 @@ type Client struct { dumpOptions *DumpOptions httpClient *http.Client jsonDecoder *json.Decoder - commonHeader map[string]string + Headers map[string]string beforeRequest []RequestMiddleware udBeforeRequest []RequestMiddleware } -func copyCommonHeader(h map[string]string) map[string]string { +func cloneUrlValues(v url.Values) url.Values { + if v == nil { + return nil + } + vv := make(url.Values) + for key, values := range v { + for _, value := range values { + vv.Add(key, value) + } + } + return vv +} + +func cloneMap(h map[string]string) map[string]string { if h == nil { return nil } @@ -180,10 +193,10 @@ func (c *Client) EnableDumpAsync() *Client { // EnableDumpOnlyResponse indicates that should dump the responses' head and response. func (c *Client) EnableDumpOnlyResponse() *Client { o := c.GetDumpOptions() - o.ResponseHead = true + o.ResponseHeader = true o.ResponseBody = true o.RequestBody = false - o.RequestHead = false + o.RequestHeader = false c.enableDump() return c } @@ -191,10 +204,10 @@ func (c *Client) EnableDumpOnlyResponse() *Client { // EnableDumpOnlyRequest indicates that should dump the requests' head and response. func (c *Client) EnableDumpOnlyRequest() *Client { o := c.GetDumpOptions() - o.RequestHead = true + o.RequestHeader = true o.RequestBody = true o.ResponseBody = false - o.ResponseHead = false + o.ResponseHeader = false c.enableDump() return c } @@ -204,17 +217,17 @@ func (c *Client) EnableDumpOnlyBody() *Client { o := c.GetDumpOptions() o.RequestBody = true o.ResponseBody = true - o.RequestHead = false - o.ResponseHead = false + o.RequestHeader = false + o.ResponseHeader = false c.enableDump() return c } -// EnableDumpOnlyHead indicates that should dump the head of requests and responses. -func (c *Client) EnableDumpOnlyHead() *Client { +// EnableDumpOnlyHeader indicates that should dump the head of requests and responses. +func (c *Client) EnableDumpOnlyHeader() *Client { o := c.GetDumpOptions() - o.RequestHead = true - o.ResponseHead = true + o.RequestHeader = true + o.ResponseHeader = true o.RequestBody = false o.ResponseBody = false c.enableDump() @@ -224,9 +237,9 @@ func (c *Client) EnableDumpOnlyHead() *Client { // EnableDumpAll indicates that should dump both requests and responses' head and body. func (c *Client) EnableDumpAll() *Client { o := c.GetDumpOptions() - o.RequestHead = true + o.RequestHeader = true o.RequestBody = true - o.ResponseHead = true + o.ResponseHeader = true o.ResponseBody = true c.enableDump() return c @@ -253,15 +266,15 @@ func (c *Client) EnableAutoDecodeTextType() *Client { // SetUserAgent set the "User-Agent" header for all requests. func (c *Client) SetUserAgent(userAgent string) *Client { - return c.SetCommonHeader("User-Agent", userAgent) + return c.SetHeader("User-Agent", userAgent) } -// SetCommonHeader set the common header for all requests. -func (c *Client) SetCommonHeader(key, value string) *Client { - if c.commonHeader == nil { - c.commonHeader = make(map[string]string) +// SetHeader set the common header for all requests. +func (c *Client) SetHeader(key, value string) *Client { + if c.Headers == nil { + c.Headers = make(map[string]string) } - c.commonHeader[key] = value + c.Headers[key] = value return c } @@ -327,12 +340,19 @@ func (c *Client) Clone() *Client { cc := *c.httpClient cc.Transport = t return &Client{ - httpClient: &cc, - t: t, - t2: t2, - dumpOptions: c.dumpOptions.Clone(), - jsonDecoder: c.jsonDecoder, - commonHeader: copyCommonHeader(c.commonHeader), + httpClient: &cc, + t: t, + t2: t2, + dumpOptions: c.dumpOptions.Clone(), + jsonDecoder: c.jsonDecoder, + Headers: cloneMap(c.Headers), + PathParams: cloneMap(c.PathParams), + QueryParams: cloneUrlValues(c.QueryParams), + HostURL: c.HostURL, + scheme: c.scheme, + log: c.log, + beforeRequest: c.beforeRequest, + udBeforeRequest: c.udBeforeRequest, } } diff --git a/dump.go b/dump.go index 332a5510..49426c86 100644 --- a/dump.go +++ b/dump.go @@ -7,12 +7,12 @@ import ( // DumpOptions controls the dump behavior. type DumpOptions struct { - Output io.Writer - RequestHead bool - RequestBody bool - ResponseHead bool - ResponseBody bool - Async bool + Output io.Writer + RequestHeader bool + RequestBody bool + ResponseHeader bool + ResponseBody bool + Async bool } func (do *DumpOptions) Clone() *DumpOptions { @@ -99,20 +99,20 @@ type dumper struct { func newDefaultDumpOptions() *DumpOptions { return &DumpOptions{ - Output: os.Stdout, - RequestBody: true, - ResponseBody: true, - ResponseHead: true, - RequestHead: true, + Output: os.Stdout, + RequestBody: true, + ResponseBody: true, + ResponseHeader: true, + RequestHeader: true, } } var defaultDumpOptions = &DumpOptions{ - Output: os.Stdout, - RequestBody: true, - ResponseBody: true, - ResponseHead: true, - RequestHead: true, + Output: os.Stdout, + RequestBody: true, + ResponseBody: true, + ResponseHeader: true, + RequestHeader: true, } func newDumper(opt *DumpOptions) *dumper { diff --git a/h2_bundle.go b/h2_bundle.go index 871357e0..57c3e218 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -1865,7 +1865,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { } if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { hr, err := fr.readMetaFrame(f.(*http2HeadersFrame)) - if err == nil && fr.dump != nil && fr.dump.ResponseHead { + if err == nil && fr.dump != nil && fr.dump.ResponseHeader { fr.dump.dump([]byte("\r\n")) } return hr, err @@ -2910,7 +2910,7 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr mh.Fields = append(mh.Fields, hf) } emitFunc := rawEmitFunc - if fr.dump != nil && fr.dump.ResponseHead { + if fr.dump != nil && fr.dump.ResponseHeader { emitFunc = func(hf hpack.HeaderField) { fr.dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) rawEmitFunc(hf) @@ -7184,7 +7184,7 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), } - if t.t1.dump != nil && t.t1.dump.RequestHead { + if t.t1.dump != nil && t.t1.dump.RequestHeader { cc.writeHeader = func(name, value string) { t.t1.dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) cc._writeHeader(name, value) @@ -8403,7 +8403,7 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, } }) - if cc.t.t1.dump != nil && cc.t.t1.dump.RequestHead { + if cc.t.t1.dump != nil && cc.t.t1.dump.RequestHeader { cc.t.t1.dump.dump([]byte("\r\n")) } diff --git a/http_request.go b/http_request.go index 1a50b374..5c50ef7b 100644 --- a/http_request.go +++ b/http_request.go @@ -149,7 +149,7 @@ func requestWrite(r *http.Request, w io.Writer, usingProxy bool, extraHeaders ht } rw := w // raw writer - if dump != nil && dump.RequestHead { + if dump != nil && dump.RequestHeader { w = dump.WrapWriter(w) } diff --git a/request.go b/request.go index db3d6bcb..231b329c 100644 --- a/request.go +++ b/request.go @@ -293,7 +293,7 @@ func (r *Request) execute() (resp *Response, err error) { if r.error != nil { return nil, r.error } - for k, v := range r.client.commonHeader { + for k, v := range r.client.Headers { if r.httpRequest.Header.Get(k) == "" { r.httpRequest.Header.Set(k, v) } diff --git a/textproto_reader.go b/textproto_reader.go index 7e1a283c..d31c14db 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -52,7 +52,7 @@ type textprotoReader struct { func newTextprotoReader(r *bufio.Reader, dump *dumper) *textprotoReader { commonHeaderOnce.Do(initCommonHeader) t := &textprotoReader{R: r} - if dump != nil && dump.ResponseHead { + if dump != nil && dump.ResponseHeader { t.readLine = func() (line []byte, isPrefix bool, err error) { line, err = t.R.ReadSlice('\n') if len(line) == 0 { From e23023cf986564b7daa77594ffeb4b7dd5ccfb6b Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 21:06:51 +0800 Subject: [PATCH 068/843] expose RawRequest --- client.go | 4 ++-- request.go | 30 +++++++++++++++--------------- response.go | 9 +-------- 3 files changed, 18 insertions(+), 25 deletions(-) diff --git a/client.go b/client.go index baf53c76..c42d4e63 100644 --- a/client.go +++ b/client.go @@ -77,8 +77,8 @@ func (c *Client) R() *Request { ProtoMinor: 1, } return &Request{ - client: c, - httpRequest: req, + client: c, + RawRequest: req, } } diff --git a/request.go b/request.go index 231b329c..00b301dc 100644 --- a/request.go +++ b/request.go @@ -20,7 +20,7 @@ type Request struct { QueryParams urlpkg.Values error error client *Client - httpRequest *http.Request + RawRequest *http.Request isSaveResponse bool output io.WriteCloser } @@ -102,7 +102,7 @@ func (r *Request) Send(method, url string) (*Response, error) { return nil, err } } - r.httpRequest.Method = method + r.RawRequest.Method = method r.URL = url @@ -124,8 +124,8 @@ func (r *Request) Send(method, url string) (*Response, error) { return nil, err } u.Host = removeEmptyPort(u.Host) - r.httpRequest.URL = u - r.httpRequest.Host = u.Host + r.RawRequest.URL = u + r.RawRequest.Host = u.Host return r.execute() } @@ -234,9 +234,9 @@ func (r *Request) SetBody(body interface{}) *Request { } switch b := body.(type) { case io.ReadCloser: - r.httpRequest.Body = b + r.RawRequest.Body = b case io.Reader: - r.httpRequest.Body = ioutil.NopCloser(b) + r.RawRequest.Body = ioutil.NopCloser(b) case []byte: r.SetBodyBytes(b) case string: @@ -247,20 +247,20 @@ func (r *Request) SetBody(body interface{}) *Request { // SetBodyBytes set the request body as []byte. func (r *Request) SetBodyBytes(body []byte) *Request { - r.httpRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) + r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) return r } // SetBodyString set the request body as string. func (r *Request) SetBodyString(body string) *Request { - r.httpRequest.Body = ioutil.NopCloser(strings.NewReader(body)) + r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(body)) return r } // SetBodyJsonString set the request body as string and set Content-Type header // as "application/json; charset=UTF-8" func (r *Request) SetBodyJsonString(body string) *Request { - r.httpRequest.Body = ioutil.NopCloser(strings.NewReader(body)) + r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(body)) r.SetContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) return r } @@ -268,7 +268,7 @@ func (r *Request) SetBodyJsonString(body string) *Request { // SetBodyJsonBytes set the request body as []byte and set Content-Type header // as "application/json; charset=UTF-8" func (r *Request) SetBodyJsonBytes(body []byte) *Request { - r.httpRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) + r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) r.SetContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) return r } @@ -285,7 +285,7 @@ func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { } func (r *Request) SetContentType(contentType string) *Request { - r.httpRequest.Header.Set("Content-Type", contentType) + r.RawRequest.Header.Set("Content-Type", contentType) return r } @@ -294,12 +294,12 @@ func (r *Request) execute() (resp *Response, err error) { return nil, r.error } for k, v := range r.client.Headers { - if r.httpRequest.Header.Get(k) == "" { - r.httpRequest.Header.Set(k, v) + if r.RawRequest.Header.Get(k) == "" { + r.RawRequest.Header.Set(k, v) } } - logf(r.client.log, "%s %s", r.httpRequest.Method, r.httpRequest.URL.String()) - httpResponse, err := r.client.httpClient.Do(r.httpRequest) + logf(r.client.log, "%s %s", r.RawRequest.Method, r.RawRequest.URL.String()) + httpResponse, err := r.client.httpClient.Do(r.RawRequest) if err != nil { return } diff --git a/response.go b/response.go index 0eecad14..d073adca 100644 --- a/response.go +++ b/response.go @@ -1,8 +1,6 @@ package req import ( - "io" - "io/ioutil" "net/http" "strings" ) @@ -42,9 +40,4 @@ func autoDecodeContentTypeFunc(contentTypes ...string) func(contentType string) type Response struct { *http.Response request *Request -} - -func (r *Response) Discard() error { - _, err := io.Copy(ioutil.Discard, r.Response.Body) - return err -} +} \ No newline at end of file From 77c324f51b71e939180981e161e02b4d328ecd20 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 22:09:34 +0800 Subject: [PATCH 069/843] client and request refactor --- client.go | 130 +++++++++++++++++++++++++++++++++++++------------- middleware.go | 20 +++++++- request.go | 79 +++++++----------------------- response.go | 4 -- 4 files changed, 133 insertions(+), 100 deletions(-) diff --git a/client.go b/client.go index c42d4e63..2ff153b8 100644 --- a/client.go +++ b/client.go @@ -7,7 +7,7 @@ import ( "io" "net/http" "net/http/cookiejar" - "net/url" + urlpkg "net/url" "os" "strings" "time" @@ -29,26 +29,40 @@ var defaultClient *Client = C() // Client is the req's http client. type Client struct { - HostURL string - PathParams map[string]string - QueryParams url.Values - scheme string - log Logger - t *Transport - t2 *http2Transport - dumpOptions *DumpOptions - httpClient *http.Client - jsonDecoder *json.Decoder - Headers map[string]string - beforeRequest []RequestMiddleware - udBeforeRequest []RequestMiddleware -} - -func cloneUrlValues(v url.Values) url.Values { + HostURL string + PathParams map[string]string + QueryParams urlpkg.Values + Headers http.Header + disableAutoReadResponse bool + scheme string + log Logger + t *Transport + t2 *http2Transport + dumpOptions *DumpOptions + httpClient *http.Client + jsonDecoder *json.Decoder + beforeRequest []RequestMiddleware + udBeforeRequest []RequestMiddleware +} + +func cloneHeaders(hdrs http.Header) http.Header { + if hdrs == nil { + return nil + } + h := make(http.Header) + for k, vs := range hdrs { + for _, v := range vs { + h.Add(k, v) + } + } + return h +} + +func cloneUrlValues(v urlpkg.Values) urlpkg.Values { if v == nil { return nil } - vv := make(url.Values) + vv := make(urlpkg.Values) for key, values := range v { for _, value := range values { vv.Add(key, value) @@ -82,17 +96,6 @@ func (c *Client) R() *Request { } } -func (c *Client) AutoDiscardResponseBody() *Client { - c.GetResponseOptions().AutoDiscard = true - return c -} - -// TestMode is like DebugMode, but discard response body, so you can -// dump responses without read response body -func (c *Client) TestMode() *Client { - return c.DebugMode().AutoDiscardResponseBody() -} - const ( userAgentFirefox = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:95.0) Gecko/20100101 Firefox/95.0" userAgentChrome = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36" @@ -250,6 +253,11 @@ func (c *Client) NewRequest() *Request { return c.R() } +func (c *Client) DisableAutoReadResponse(disable bool) *Client { + c.disableAutoReadResponse = disable + return c +} + // EnableAutoDecodeAllType indicates that try autodetect and decode all content type. func (c *Client) EnableAutoDecodeAllType() *Client { c.GetResponseOptions().AutoDecodeContentType = func(contentType string) bool { @@ -269,12 +277,19 @@ func (c *Client) SetUserAgent(userAgent string) *Client { return c.SetHeader("User-Agent", userAgent) } +func (c *Client) SetHeaders(hdrs map[string]string) *Client { + for k, v := range hdrs { + c.SetHeader(k, v) + } + return c +} + // SetHeader set the common header for all requests. func (c *Client) SetHeader(key, value string) *Client { if c.Headers == nil { - c.Headers = make(map[string]string) + c.Headers = make(http.Header) } - c.Headers[key] = value + c.Headers.Set(key, value) return c } @@ -303,7 +318,7 @@ func (c *Client) SetDumpOptions(opt *DumpOptions) *Client { } // SetProxy set the proxy function. -func (c *Client) SetProxy(proxy func(*http.Request) (*url.URL, error)) *Client { +func (c *Client) SetProxy(proxy func(*http.Request) (*urlpkg.URL, error)) *Client { c.t.Proxy = proxy return c } @@ -319,7 +334,7 @@ func (c *Client) OnBeforeRequest(m RequestMiddleware) *Client { } func (c *Client) SetProxyURL(proxyUrl string) *Client { - u, err := url.Parse(proxyUrl) + u, err := urlpkg.Parse(proxyUrl) if err != nil { logf(c.log, "failed to parse proxy url %s: %v", proxyUrl, err) return c @@ -345,7 +360,7 @@ func (c *Client) Clone() *Client { t2: t2, dumpOptions: c.dumpOptions.Clone(), jsonDecoder: c.jsonDecoder, - Headers: cloneMap(c.Headers), + Headers: cloneHeaders(c.Headers), PathParams: cloneMap(c.PathParams), QueryParams: cloneUrlValues(c.QueryParams), HostURL: c.HostURL, @@ -375,6 +390,7 @@ func C() *Client { } beforeRequest := []RequestMiddleware{ parseRequestURL, + parseRequestHeader, } c := &Client{ beforeRequest: beforeRequest, @@ -385,3 +401,49 @@ func C() *Client { } return c } + +func (c *Client) Do(r *Request) (resp *Response, err error) { + for _, f := range r.client.udBeforeRequest { + if err := f(r.client, r); err != nil { + return nil, err + } + } + + for _, f := range r.client.beforeRequest { + if err := f(r.client, r); err != nil { + return nil, err + } + } + setRequestURL(r.RawRequest, r.URL) + setRequestHeader(r) + + logf(c.log, "%s %s", r.RawRequest.Method, r.RawRequest.URL.String()) + httpResponse, err := c.httpClient.Do(r.RawRequest) + if err != nil { + return + } + resp = &Response{ + request: r, + Response: httpResponse, + } + return +} + +func setRequestHeader(r *Request) { + if r.Headers == nil { + r.Headers = make(http.Header) + } + r.RawRequest.Header = r.Headers +} + +func setRequestURL(r *http.Request, url string) error { + // The host's colon:port should be normalized. See Issue 14836. + u, err := urlpkg.Parse(url) + if err != nil { + return err + } + u.Host = removeEmptyPort(u.Host) + r.URL = u + r.Host = u.Host + return nil +} diff --git a/middleware.go b/middleware.go index 18c48244..de6e09fc 100644 --- a/middleware.go +++ b/middleware.go @@ -2,6 +2,7 @@ package req import ( "github.com/imroc/req/v2/internal/util" + "net/http" "net/url" "strings" ) @@ -11,6 +12,24 @@ type ( RequestMiddleware func(*Client, *Request) error ) +func parseRequestHeader(c *Client, r *Request) error { + if c.Headers == nil { + return nil + } + hdr := make(http.Header) + for k := range c.Headers { + hdr[k] = append(hdr[k], c.Headers[k]...) + } + + for k := range r.Headers { + hdr.Del(k) + hdr[k] = append(hdr[k], r.Headers[k]...) + } + + r.Headers = hdr + return nil +} + func parseRequestURL(c *Client, r *Request) error { if len(r.PathParams) > 0 { for p, v := range r.PathParams { @@ -82,4 +101,3 @@ func parseRequestURL(c *Client, r *Request) error { return nil } - diff --git a/request.go b/request.go index 00b301dc..3a4c248b 100644 --- a/request.go +++ b/request.go @@ -3,7 +3,6 @@ package req import ( "bytes" "encoding/json" - "fmt" "github.com/hashicorp/go-multierror" "io" "io/ioutil" @@ -18,6 +17,7 @@ type Request struct { URL string PathParams map[string]string QueryParams urlpkg.Values + Headers http.Header error error client *Client RawRequest *http.Request @@ -30,6 +30,22 @@ func New() *Request { return defaultClient.R() } +func (r *Request) SetHeaders(hdrs map[string]string) *Request { + for k, v := range hdrs { + r.SetHeader(k, v) + } + return r +} + +// SetHeader set the common header for all requests. +func (r *Request) SetHeader(key, value string) *Request { + if r.Headers == nil { + r.Headers = make(http.Header) + } + r.Headers.Set(key, value) + return r +} + func (r *Request) SetOutputFile(file string) *Request { output, err := os.Create(file) if err != nil { @@ -89,44 +105,9 @@ func (r *Request) Send(method, url string) (*Response, error) { if r.error != nil { return nil, r.error } - - if method == "" { - // We document that "" means "GET" for Request.Method, and people have - // relied on that from NewRequest, so keep that working. - // We still enforce validMethod for non-empty methods. - method = "GET" - } - if !validMethod(method) { - err := fmt.Errorf("req: invalid method %q", method) - if err != nil { - return nil, err - } - } r.RawRequest.Method = method - r.URL = url - - for _, f := range r.client.udBeforeRequest { - if err := f(r.client, r); err != nil { - return nil, err - } - } - - for _, f := range r.client.beforeRequest { - if err := f(r.client, r); err != nil { - return nil, err - } - } - - // The host's colon:port should be normalized. See Issue 14836. - u, err := urlpkg.Parse(r.URL) - if err != nil { - return nil, err - } - u.Host = removeEmptyPort(u.Host) - r.RawRequest.URL = u - r.RawRequest.Host = u.Host - return r.execute() + return r.client.Do(r) } // MustGet like Get, panic if error happens. @@ -288,27 +269,3 @@ func (r *Request) SetContentType(contentType string) *Request { r.RawRequest.Header.Set("Content-Type", contentType) return r } - -func (r *Request) execute() (resp *Response, err error) { - if r.error != nil { - return nil, r.error - } - for k, v := range r.client.Headers { - if r.RawRequest.Header.Get(k) == "" { - r.RawRequest.Header.Set(k, v) - } - } - logf(r.client.log, "%s %s", r.RawRequest.Method, r.RawRequest.URL.String()) - httpResponse, err := r.client.httpClient.Do(r.RawRequest) - if err != nil { - return - } - resp = &Response{ - request: r, - Response: httpResponse, - } - if r.client.t.ResponseOptions != nil && r.client.t.ResponseOptions.AutoDiscard { - err = resp.Discard() - } - return -} diff --git a/response.go b/response.go index d073adca..7a87da9e 100644 --- a/response.go +++ b/response.go @@ -15,10 +15,6 @@ type ResponseOptions struct { // whether the response body should been auto decode to utf-8. // Only valid when DisableAutoDecode is true. AutoDecodeContentType func(contentType string) bool - - // AutoDiscard, if true, read all response body and discard automatically, - // useful when test - AutoDiscard bool } var textContentTypes = []string{"text", "json", "xml", "html", "java"} From 1fe1dc3805b5787d1d691b6f6b106bc5c049dd52 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 22:21:52 +0800 Subject: [PATCH 070/843] auto read response body when not downloading --- client.go | 40 ++++++++++++++++++++++++---------------- response.go | 3 ++- response_body.go | 3 +++ 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/client.go b/client.go index 2ff153b8..2972f233 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "github.com/imroc/req/v2/internal/util" "golang.org/x/net/publicsuffix" "io" + "io/ioutil" "net/http" "net/http/cookiejar" urlpkg "net/url" @@ -355,19 +356,20 @@ func (c *Client) Clone() *Client { cc := *c.httpClient cc.Transport = t return &Client{ - httpClient: &cc, - t: t, - t2: t2, - dumpOptions: c.dumpOptions.Clone(), - jsonDecoder: c.jsonDecoder, - Headers: cloneHeaders(c.Headers), - PathParams: cloneMap(c.PathParams), - QueryParams: cloneUrlValues(c.QueryParams), - HostURL: c.HostURL, - scheme: c.scheme, - log: c.log, - beforeRequest: c.beforeRequest, - udBeforeRequest: c.udBeforeRequest, + httpClient: &cc, + t: t, + t2: t2, + dumpOptions: c.dumpOptions.Clone(), + jsonDecoder: c.jsonDecoder, + Headers: cloneHeaders(c.Headers), + PathParams: cloneMap(c.PathParams), + QueryParams: cloneUrlValues(c.QueryParams), + HostURL: c.HostURL, + scheme: c.scheme, + log: c.log, + beforeRequest: c.beforeRequest, + udBeforeRequest: c.udBeforeRequest, + disableAutoReadResponse: c.disableAutoReadResponse, } } @@ -419,13 +421,19 @@ func (c *Client) Do(r *Request) (resp *Response, err error) { logf(c.log, "%s %s", r.RawRequest.Method, r.RawRequest.URL.String()) httpResponse, err := c.httpClient.Do(r.RawRequest) - if err != nil { - return - } resp = &Response{ request: r, Response: httpResponse, } + + if err != nil || c.disableAutoReadResponse { + return + } + body, err := ioutil.ReadAll(httpResponse.Body) + if err != nil { + return + } + resp.body = body return } diff --git a/response.go b/response.go index 7a87da9e..68c4a486 100644 --- a/response.go +++ b/response.go @@ -36,4 +36,5 @@ func autoDecodeContentTypeFunc(contentTypes ...string) func(contentType string) type Response struct { *http.Response request *Request -} \ No newline at end of file + body []byte +} diff --git a/response_body.go b/response_body.go index 56e1430f..3978597f 100644 --- a/response_body.go +++ b/response_body.go @@ -105,6 +105,9 @@ func (r *Response) String() (string, error) { } func (r *Response) Bytes() ([]byte, error) { + if r.body != nil { + return r.body, nil + } defer r.Body.Close() return ioutil.ReadAll(r.Body) } From 19e4dc77541b13f38f0263a2dcda495dbf456506 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 22:37:48 +0800 Subject: [PATCH 071/843] save file refactor --- README.md | 14 +++++++------- client.go | 17 ++++++++++++++++- response_body.go | 38 -------------------------------------- 3 files changed, 23 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index b47427a3..c30eb7b9 100644 --- a/README.md +++ b/README.md @@ -30,14 +30,14 @@ import "github.com/imroc/req/v2" Prepare a client: ```go -client := req.C().UserAgent("custom-client") // client settings is chainable +client := req.C().SetUserAgent("custom-client") // client settings is chainable ``` Use client to create and send request: ```go // use R() to create a new request, and request settings is also chainable -resp, err: client.R().Header("test", "req").Body("test").Get("https://test.example.com") +resp, err: client.R().SetHeader("test", "req").SetBody("test").Get("https://test.example.com") ``` You can also use the default client when test: @@ -53,16 +53,16 @@ resp, err := req.New().Header("test", "req").Get(url) You can also simply do it with one line of code like this: ```go -resp, err := req.DefaultClient().UserAgent("custom-client").R().Header("test", "req").Get(url) +resp, err := req.DefaultClient().UserAgent("custom-client").R().SetHeader("test", "req").Get(url) ``` Want to debug requests? Just enable dump: ```go // create client and enable dump -client := req.C().UserAgent("custom-client").DumpAll() +client := req.C().UserAgent("custom-client").EnableDumpAll() // send request and read response body -client.R().Get("https://api.github.com/users/imroc").MustString() +client.R().Get("https://api.github.com/users/imroc") ``` Now you can see the request and response content that has been dumped: @@ -113,8 +113,8 @@ Simple example: ```go // dump head content asynchronously and save it to file -client := req.C().DumpOnlyHead().DumpAsync().DumpToFile("reqdump.log") -resp, err := client.R().Body(body).Post(url) +client := req.C().DumpOnlyHeader().DumpAsync().DumpToFile("reqdump.log") +resp, err := client.R().SetBody(body).Post(url) ... ``` diff --git a/client.go b/client.go index 2972f233..4dde1e48 100644 --- a/client.go +++ b/client.go @@ -426,9 +426,24 @@ func (c *Client) Do(r *Request) (resp *Response, err error) { Response: httpResponse, } - if err != nil || c.disableAutoReadResponse { + if err != nil { + return + } + + if r.isSaveResponse { + defer func() { + httpResponse.Body.Close() + r.output.Close() + }() + _, err = io.Copy(r.output, httpResponse.Body) return } + + if c.disableAutoReadResponse { + return + } + + // auto read response body body, err := ioutil.ReadAll(httpResponse.Body) if err != nil { return diff --git a/response_body.go b/response_body.go index 3978597f..f6a36755 100644 --- a/response_body.go +++ b/response_body.go @@ -3,48 +3,10 @@ package req import ( "encoding/json" "encoding/xml" - "io" "io/ioutil" - "os" "strings" ) -func (r *Response) MustSave(dst io.Writer) { - err := r.Save(dst) - if err != nil { - panic(err) - } -} - -func (r *Response) Save(dst io.Writer) error { - if dst == nil { - return nil // TODO: return error - } - _, err := io.Copy(dst, r.Body) - r.Body.Close() - return err -} - -func (r *Response) MustSaveFile(filename string) { - err := r.SaveFile(filename) - if err != nil { - panic(err) - } -} - -func (r *Response) SaveFile(filename string) error { - if filename == "" { - return nil // TODO: return error - } - file, err := os.Create(filename) - if err != nil { - return err - } - _, err = io.Copy(file, r.Body) - r.Body.Close() - return err -} - func (r *Response) MustUnmarshalJson(v interface{}) { err := r.UnmarshalJson(v) if err != nil { From 7b378daea112540893302300d24baec41927f6d1 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 22:38:25 +0800 Subject: [PATCH 072/843] bump to v2.0.0-alpha.3 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c30eb7b9..3abb4769 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ A golang http request library for humans. ## Install ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.2 +go get github.com/imroc/req/v2@v2.0.0-alpha.3 ``` ## Quick Start From 8f0b736075c5a1b14d290efbd1443243386040da Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 22 Jan 2022 22:39:35 +0800 Subject: [PATCH 073/843] fix README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3abb4769..67ac7969 100644 --- a/README.md +++ b/README.md @@ -44,10 +44,10 @@ You can also use the default client when test: ```go // customize default client settings -req.DefaultClient().UserAgent("custom-client") +req.DefaultClient().SetUserAgent("custom-client") // create and send request using default client -resp, err := req.New().Header("test", "req").Get(url) +resp, err := req.New().SetHeader("test", "req").Get(url) ``` You can also simply do it with one line of code like this: From 5d1f5731fc2120b1804f53f9a6dd4246498422fe Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 11:36:56 +0800 Subject: [PATCH 074/843] parse response body --- client.go | 71 ++++++++++++++++++++++++++----------------- internal/util/util.go | 37 +++++++++++++++++++++- middleware.go | 65 +++++++++++++++++++++++++++++++++++++++ request.go | 20 ++++++++---- response.go | 26 +++++++++++++++- response_body.go | 7 ++++- 6 files changed, 189 insertions(+), 37 deletions(-) diff --git a/client.go b/client.go index 4dde1e48..20f5439f 100644 --- a/client.go +++ b/client.go @@ -2,10 +2,10 @@ package req import ( "encoding/json" + "encoding/xml" "github.com/imroc/req/v2/internal/util" "golang.org/x/net/publicsuffix" "io" - "io/ioutil" "net/http" "net/http/cookiejar" urlpkg "net/url" @@ -30,10 +30,15 @@ var defaultClient *Client = C() // Client is the req's http client. type Client struct { - HostURL string - PathParams map[string]string - QueryParams urlpkg.Values - Headers http.Header + HostURL string + PathParams map[string]string + QueryParams urlpkg.Values + Headers http.Header + JSONMarshal func(v interface{}) ([]byte, error) + JSONUnmarshal func(data []byte, v interface{}) error + XMLMarshal func(v interface{}) ([]byte, error) + XMLUnmarshal func(data []byte, v interface{}) error + disableAutoReadResponse bool scheme string log Logger @@ -44,6 +49,7 @@ type Client struct { jsonDecoder *json.Decoder beforeRequest []RequestMiddleware udBeforeRequest []RequestMiddleware + afterResponse []ResponseMiddleware } func cloneHeaders(hdrs http.Header) http.Header { @@ -334,6 +340,11 @@ func (c *Client) OnBeforeRequest(m RequestMiddleware) *Client { return c } +func (c *Client) OnAfterResponse(m ResponseMiddleware) *Client { + c.afterResponse = append(c.afterResponse, m) + return c +} + func (c *Client) SetProxyURL(proxyUrl string) *Client { u, err := urlpkg.Parse(proxyUrl) if err != nil { @@ -394,61 +405,65 @@ func C() *Client { parseRequestURL, parseRequestHeader, } + afterResponse := []ResponseMiddleware{ + parseResponseBody, + handleDownload, + } c := &Client{ beforeRequest: beforeRequest, + afterResponse: afterResponse, log: &emptyLogger{}, httpClient: httpClient, t: t, t2: t2, + JSONMarshal: json.Marshal, + JSONUnmarshal: json.Unmarshal, + XMLMarshal: xml.Marshal, + XMLUnmarshal: xml.Unmarshal, } return c } func (c *Client) Do(r *Request) (resp *Response, err error) { + for _, f := range r.client.udBeforeRequest { - if err := f(r.client, r); err != nil { - return nil, err + if err = f(r.client, r); err != nil { + return } } for _, f := range r.client.beforeRequest { - if err := f(r.client, r); err != nil { - return nil, err + if err = f(r.client, r); err != nil { + return } } + setRequestURL(r.RawRequest, r.URL) setRequestHeader(r) logf(c.log, "%s %s", r.RawRequest.Method, r.RawRequest.URL.String()) httpResponse, err := c.httpClient.Do(r.RawRequest) - resp = &Response{ - request: r, - Response: httpResponse, - } - if err != nil { return } - if r.isSaveResponse { - defer func() { - httpResponse.Body.Close() - r.output.Close() - }() - _, err = io.Copy(r.output, httpResponse.Body) - return + resp = &Response{ + Request: r, + Response: httpResponse, } - if c.disableAutoReadResponse { - return + if !c.disableAutoReadResponse && !r.isSaveResponse { // auto read response body + _, err = resp.Bytes() + if err != nil { + return + } } - // auto read response body - body, err := ioutil.ReadAll(httpResponse.Body) - if err != nil { - return + for _, f := range r.client.afterResponse { + if err = f(r.client, resp); err != nil { + return + } } - resp.body = body return } diff --git a/internal/util/util.go b/internal/util/util.go index a4c0b4d2..58773aab 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -2,9 +2,35 @@ package util import ( "bytes" + "reflect" + "regexp" "strings" ) +var ( + jsonCheck = regexp.MustCompile(`(?i:(application|text)/(json|.*\+json|json\-.*)(;|$))`) + xmlCheck = regexp.MustCompile(`(?i:(application|text)/(xml|.*\+xml)(;|$))`) +) + +// IsJSONType method is to check JSON content type or not +func IsJSONType(ct string) bool { + return jsonCheck.MatchString(ct) +} + +// IsXMLType method is to check XML content type or not +func IsXMLType(ct string) bool { + return xmlCheck.MatchString(ct) +} + + +func GetPointer(v interface{}) interface{} { + vv := reflect.ValueOf(v) + if vv.Kind() == reflect.Ptr { + return v + } + return reflect.New(vv.Type()).Interface() +} + // CutString slices s around the first instance of sep, // returning the text before and after sep. // The found result reports whether sep appears in s. @@ -32,4 +58,13 @@ func CutBytes(s, sep []byte) (before, after []byte, found bool) { // IsStringEmpty method tells whether given string is empty or not func IsStringEmpty(str string) bool { return len(strings.TrimSpace(str)) == 0 -} \ No newline at end of file +} + +func FirstNonEmpty(v ...string) string { + for _, s := range v { + if !IsStringEmpty(s) { + return s + } + } + return "" +} diff --git a/middleware.go b/middleware.go index de6e09fc..64bf950a 100644 --- a/middleware.go +++ b/middleware.go @@ -1,7 +1,10 @@ package req import ( + "bytes" "github.com/imroc/req/v2/internal/util" + "io" + "io/ioutil" "net/http" "net/url" "strings" @@ -10,8 +13,70 @@ import ( type ( // RequestMiddleware type is for request middleware, called before a request is sent RequestMiddleware func(*Client, *Request) error + + // ResponseMiddleware type is for response middleware, called after a response has been received + ResponseMiddleware func(*Client, *Response) error ) +// unmarshalc content into object from JSON or XML +func unmarshalc(c *Client, ct string, b []byte, d interface{}) (err error) { + if util.IsJSONType(ct) { + err = c.JSONUnmarshal(b, d) + } else if util.IsXMLType(ct) { + err = c.XMLUnmarshal(b, d) + } + return +} + +func parseResponseBody(c *Client, r *Response) (err error) { + if r.StatusCode == http.StatusNoContent { + return + } + body, err := r.Bytes() // in case req.SetResult with cient.DisalbeAutoReadResponse(true) + if err != nil { + return + } + // Handles only JSON or XML content type + ct := util.FirstNonEmpty(r.GetContentType()) + if r.IsSuccess() && r.Request.Result != nil { + r.Request.Error = nil + if util.IsJSONType(ct) { + return c.JSONUnmarshal(body, r.Request.Result) + } else if util.IsXMLType(ct) { + return c.XMLUnmarshal(body, r.Request.Result) + } + } + if r.IsError() && r.Request.Error != nil { + r.Request.Result = nil + if util.IsJSONType(ct) { + return c.JSONUnmarshal(body, r.Request.Error) + } else if util.IsXMLType(ct) { + return c.XMLUnmarshal(body, r.Request.Error) + } + } + return +} + +func handleDownload(c *Client, r *Response) error { + if !r.Request.isSaveResponse { + return nil + } + + var body io.ReadCloser + if r.body != nil { // already read + body = ioutil.NopCloser(bytes.NewReader(r.body)) + } else { + body = r.Body + } + + defer func() { + body.Close() + r.Request.output.Close() + }() + _, err := io.Copy(r.Request.output, body) + return err +} + func parseRequestHeader(c *Client, r *Request) error { if c.Headers == nil { return nil diff --git a/request.go b/request.go index 3a4c248b..75543ebf 100644 --- a/request.go +++ b/request.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "github.com/hashicorp/go-multierror" + "github.com/imroc/req/v2/internal/util" "io" "io/ioutil" "net/http" @@ -18,10 +19,13 @@ type Request struct { PathParams map[string]string QueryParams urlpkg.Values Headers http.Header + Result interface{} + Error interface{} error error client *Client RawRequest *http.Request isSaveResponse bool + isMultiPart bool output io.WriteCloser } @@ -30,6 +34,16 @@ func New() *Request { return defaultClient.R() } +func (r *Request) SetResult(result interface{}) *Request { + r.Result = util.GetPointer(result) + return r +} + +func (r *Request) SetError(error interface{}) *Request { + r.Error = util.GetPointer(error) + return r +} + func (r *Request) SetHeaders(hdrs map[string]string) *Request { for k, v := range hdrs { r.SetHeader(k, v) @@ -95,12 +109,6 @@ func (r *Request) appendError(err error) { r.error = multierror.Append(r.error, err) } -// Error return the underlying error, not nil if some error -// happend when constructing the request. -func (r *Request) Error() error { - return r.error -} - func (r *Request) Send(method, url string) (*Response, error) { if r.error != nil { return nil, r.error diff --git a/response.go b/response.go index 68c4a486..f3fba4e7 100644 --- a/response.go +++ b/response.go @@ -35,6 +35,30 @@ func autoDecodeContentTypeFunc(contentTypes ...string) func(contentType string) // Response is the http response. type Response struct { *http.Response - request *Request + Request *Request body []byte } + +// IsSuccess method returns true if HTTP status `code >= 200 and <= 299` otherwise false. +func (r *Response) IsSuccess() bool { + return r.StatusCode > 199 && r.StatusCode < 300 +} + +// IsError method returns true if HTTP status `code >= 400` otherwise false. +func (r *Response) IsError() bool { + return r.StatusCode > 399 +} + +func (r *Response) GetContentType() string { + return r.Header.Get("Content-Type") +} + +// Result method returns the response value as an object if it has one +func (r *Response) Result() interface{} { + return r.Request.Result +} + +// Error method returns the error object if it has one +func (r *Response) Error() interface{} { + return r.Request.Error +} diff --git a/response_body.go b/response_body.go index f6a36755..1ebf1af3 100644 --- a/response_body.go +++ b/response_body.go @@ -71,7 +71,12 @@ func (r *Response) Bytes() ([]byte, error) { return r.body, nil } defer r.Body.Close() - return ioutil.ReadAll(r.Body) + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return nil, err + } + r.body = body + return body, nil } func (r *Response) MustBytes() []byte { From 9936d63cb5c5900fe55d0c6f7ad072342c6ca26c Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 13:13:59 +0800 Subject: [PATCH 075/843] refactor logger and debug --- client.go | 26 +++++++++++++++--------- logger.go | 59 +++++++++++++++++++++++++++++++----------------------- request.go | 14 +++++++++++++ 3 files changed, 65 insertions(+), 34 deletions(-) diff --git a/client.go b/client.go index 20f5439f..af9ed818 100644 --- a/client.go +++ b/client.go @@ -38,6 +38,7 @@ type Client struct { JSONUnmarshal func(data []byte, v interface{}) error XMLMarshal func(v interface{}) ([]byte, error) XMLUnmarshal func(data []byte, v interface{}) error + Debug bool disableAutoReadResponse bool scheme string @@ -108,13 +109,18 @@ const ( userAgentChrome = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36" ) -// DebugMode enables dump for requests and responses, and set user +func (c *Client) EnableDebug(enable bool) *Client { + c.Debug = enable + return c +} + +// DevMode enables dump for requests and responses, and set user // agent to pretend to be a web browser, Avoid returning abnormal // data from some sites. -func (c *Client) DebugMode() *Client { +func (c *Client) DevMode() *Client { return c.EnableAutoDecodeTextType(). EnableDumpAll(). - SetLogger(NewLogger(os.Stdout)). + EnableDebug(true). SetUserAgent(userAgentChrome) } @@ -127,9 +133,10 @@ func (c *Client) SetScheme(scheme string) *Client { return c } -// SetLogger set the logger for req. +// SetLogger set the logger for req, set to nil to disable logger. func (c *Client) SetLogger(log Logger) *Client { if log == nil { + c.log = &disableLogger{} return c } c.log = log @@ -176,7 +183,7 @@ func (c *Client) enableDump() { func (c *Client) EnableDumpToFile(filename string) *Client { file, err := os.Create(filename) if err != nil { - logf(c.log, "create dump file error: %v", err) + c.log.Errorf("create dump file error: %v", err) return c } c.GetDumpOptions().Output = file @@ -348,7 +355,7 @@ func (c *Client) OnAfterResponse(m ResponseMiddleware) *Client { func (c *Client) SetProxyURL(proxyUrl string) *Client { u, err := urlpkg.Parse(proxyUrl) if err != nil { - logf(c.log, "failed to parse proxy url %s: %v", proxyUrl, err) + c.log.Errorf("failed to parse proxy url %s: %v", proxyUrl, err) return c } c.t.Proxy = http.ProxyURL(u) @@ -412,7 +419,7 @@ func C() *Client { c := &Client{ beforeRequest: beforeRequest, afterResponse: afterResponse, - log: &emptyLogger{}, + log: createLogger(), httpClient: httpClient, t: t, t2: t2, @@ -440,8 +447,9 @@ func (c *Client) Do(r *Request) (resp *Response, err error) { setRequestURL(r.RawRequest, r.URL) setRequestHeader(r) - - logf(c.log, "%s %s", r.RawRequest.Method, r.RawRequest.URL.String()) + if c.Debug { + c.log.Debugf("%s %s", r.RawRequest.Method, r.RawRequest.URL.String()) + } httpResponse, err := c.httpClient.Do(r.RawRequest) if err != nil { return diff --git a/logger.go b/logger.go index 98b16647..faf22d3e 100644 --- a/logger.go +++ b/logger.go @@ -1,43 +1,52 @@ package req import ( - "fmt" - "io" + "log" "os" ) -// Logger is the logging interface that req used internal, -// you can set the Logger for client if you want to see req's -// internal logging information. +// Logger interface is to abstract the logging from Resty. Gives control to +// the Resty users, choice of the logger. type Logger interface { - Println(v ...interface{}) + Errorf(format string, v ...interface{}) + Warnf(format string, v ...interface{}) + Debugf(format string, v ...interface{}) } -type logger struct { - w io.Writer +func createLogger() *logger { + l := &logger{l: log.New(os.Stderr, "", log.Ldate|log.Lmicroseconds)} + return l } -func (l *logger) Println(v ...interface{}) { - fmt.Fprintln(l.w, v...) -} +var _ Logger = (*logger)(nil) -// NewLogger create a simple Logger. -func NewLogger(output io.Writer) Logger { - if output == nil { - output = os.Stdout - } - return &logger{output} -} +type disableLogger struct{} + +func (l *disableLogger) Errorf(format string, v ...interface{}) {} +func (l *disableLogger) Warnf(format string, v ...interface{}) {} +func (l *disableLogger) Debugf(format string, v ...interface{}) {} -type emptyLogger struct{} +type logger struct { + l *log.Logger +} -func (l *emptyLogger) Println(v ...interface{}) {} +func (l *logger) Errorf(format string, v ...interface{}) { + l.output("ERROR", format, v...) +} -func logp(logger Logger, s string) { - logger.Println("[req]", s) +func (l *logger) Warnf(format string, v ...interface{}) { + l.output("WARN", format, v...) } -func logf(logger Logger, format string, v ...interface{}) { - s := fmt.Sprintf(format, v...) - logp(logger, s) +func (l *logger) Debugf(format string, v ...interface{}) { + l.output("DEBUG", format, v...) } + +func (l *logger) output(level, format string, v ...interface{}) { + format = level + " [req] " + format + if len(v) == 0 { + l.l.Print(format) + return + } + l.l.Printf(format, v...) +} \ No newline at end of file diff --git a/request.go b/request.go index 75543ebf..f12881f5 100644 --- a/request.go +++ b/request.go @@ -34,6 +34,20 @@ func New() *Request { return defaultClient.R() } +func (r *Request) SetQueryString(query string) *Request { + params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) + if err == nil { + for p, v := range params { + for _, pv := range v { + r.QueryParams.Add(p, pv) + } + } + } else { + r.client.log.Errorf("%v", err) + } + return r +} + func (r *Request) SetResult(result interface{}) *Request { r.Result = util.GetPointer(result) return r From 234009e64d37ed341b3a3a5c220dea8bc7131ca1 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 13:18:18 +0800 Subject: [PATCH 076/843] default dump to stderr --- dump.go | 14 +++----------- logger.go | 2 +- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/dump.go b/dump.go index 49426c86..a27f039b 100644 --- a/dump.go +++ b/dump.go @@ -99,7 +99,7 @@ type dumper struct { func newDefaultDumpOptions() *DumpOptions { return &DumpOptions{ - Output: os.Stdout, + Output: os.Stderr, RequestBody: true, ResponseBody: true, ResponseHeader: true, @@ -107,20 +107,12 @@ func newDefaultDumpOptions() *DumpOptions { } } -var defaultDumpOptions = &DumpOptions{ - Output: os.Stdout, - RequestBody: true, - ResponseBody: true, - ResponseHeader: true, - RequestHeader: true, -} - func newDumper(opt *DumpOptions) *dumper { if opt == nil { - opt = defaultDumpOptions.Clone() + opt = newDefaultDumpOptions() } if opt.Output == nil { - opt.Output = os.Stdout + opt.Output = os.Stderr } d := &dumper{ DumpOptions: opt, diff --git a/logger.go b/logger.go index faf22d3e..56b6516a 100644 --- a/logger.go +++ b/logger.go @@ -49,4 +49,4 @@ func (l *logger) output(level, format string, v ...interface{}) { return } l.l.Printf(format, v...) -} \ No newline at end of file +} From 815d469ce41ef2f00b6ad26935c208aced042197 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 14:21:16 +0800 Subject: [PATCH 077/843] update README --- README.md | 73 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 67ac7969..2e13b635 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ A golang http request library for humans. go get github.com/imroc/req/v2@v2.0.0-alpha.3 ``` -## Quick Start +## Usage Import req in your code: @@ -27,47 +27,50 @@ Import req in your code: import "github.com/imroc/req/v2" ``` -Prepare a client: +[Quick Start](#Quick-Start) -```go -client := req.C().SetUserAgent("custom-client") // client settings is chainable -``` +### Quick Start -Use client to create and send request: +**Simple GET** ```go -// use R() to create a new request, and request settings is also chainable -resp, err: client.R().SetHeader("test", "req").SetBody("test").Get("https://test.example.com") +// Create and send a request with the global default client +resp, err := req.New().Get("https://api.github.com/users/imroc") + +// Create and send a request with the custom client +client := req.C() +resp, err := client.R().Get("https://api.github.com/users/imroc") ``` -You can also use the default client when test: +**Client Settings** ```go -// customize default client settings -req.DefaultClient().SetUserAgent("custom-client") - -// create and send request using default client -resp, err := req.New().SetHeader("test", "req").Get(url) +client.SetUserAgent("my-custom-client"). + SetTimeout(5 * time.Second). + DevMode() ``` -You can also simply do it with one line of code like this: +**Request Settings** ```go -resp, err := req.DefaultClient().UserAgent("custom-client").R().SetHeader("test", "req").Get(url) +var result Result +resp, err := client.R(). + SetResult(&result). + SetHeader("Accept", "application/json"). + SetQeuryParam("page", "1"). + SetPathParam("userId", "imroc"). + Get(url) ``` -Want to debug requests? Just enable dump: +### Debug ```go -// create client and enable dump -client := req.C().UserAgent("custom-client").EnableDumpAll() -// send request and read response body +// Set EnableDump to true, default dump all content to stderr, +// including both header and body of request and response +client := req.C().EnableDump(true) client.R().Get("https://api.github.com/users/imroc") -``` - -Now you can see the request and response content that has been dumped: -```txt +/* Output :authority: api.github.com :method: GET :path: /users/imroc @@ -103,21 +106,19 @@ content-length: 486 x-github-request-id: AF10:6205:BA107D:D614F2:61EA7D7E {"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":128,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2021-07-08T12:11:23Z"} -``` - -> Here we can see the content is in http2 format, because req will try http2 by default if server support. - -## Debug -Simple example: - -```go -// dump head content asynchronously and save it to file -client := req.C().DumpOnlyHeader().DumpAsync().DumpToFile("reqdump.log") -resp, err := client.R().SetBody(body).Post(url) -... +*/ + +// dump header content asynchronously and save it to file +client := req.C(). + EnableDumpOnlyHeader(). // only dump the header of request and response + EnableDumpAsync(). // dump asynchronously to improve performance + EnableDumpToFile("reqdump.log") // dump to file without printing it out +client.Get(url) ``` +## PathParams and QeuryParams + ## License Req released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file From 8eca518a68da7c5530910abca10dc9dbfc7bf50e Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 14:21:53 +0800 Subject: [PATCH 078/843] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 2e13b635..17260fb8 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ import "github.com/imroc/req/v2" ``` [Quick Start](#Quick-Start) +[Debug](#Debug) ### Quick Start From e0b90c169012668b5a49d8786c7e8be71cd9cfec Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 14:22:13 +0800 Subject: [PATCH 079/843] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 17260fb8..2b27e851 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ import "github.com/imroc/req/v2" ``` [Quick Start](#Quick-Start) + [Debug](#Debug) ### Quick Start From 2852275fab7682135b49b1646a3da1f8f3e8fbe0 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 14:23:33 +0800 Subject: [PATCH 080/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2b27e851..c424977a 100644 --- a/README.md +++ b/README.md @@ -27,8 +27,8 @@ Import req in your code: import "github.com/imroc/req/v2" ``` +### Content [Quick Start](#Quick-Start) - [Debug](#Debug) ### Quick Start From 3a2a01faedc23982f74f88f34fb8b6a4d87206e8 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 14:26:52 +0800 Subject: [PATCH 081/843] update README --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c424977a..569e2644 100644 --- a/README.md +++ b/README.md @@ -27,9 +27,10 @@ Import req in your code: import "github.com/imroc/req/v2" ``` -### Content -[Quick Start](#Quick-Start) -[Debug](#Debug) +### Table of Contents + +* [Quick Start](#Quick-Start) +* [Debug](#Debug) ### Quick Start From a08a58af82517bf74572b9ce8eb21a23fa14f1fe Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 14:50:38 +0800 Subject: [PATCH 082/843] update README --- README.md | 49 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 569e2644..4bc6ac5a 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ import "github.com/imroc/req/v2" * [Quick Start](#Quick-Start) * [Debug](#Debug) +* [PathParams and QueryParams](#PathParams-QueryParams) ### Quick Start @@ -67,6 +68,8 @@ resp, err := client.R(). ### Debug +**Dump the content of request and response** + ```go // Set EnableDump to true, default dump all content to stderr, // including both header and body of request and response @@ -112,15 +115,51 @@ x-github-request-id: AF10:6205:BA107D:D614F2:61EA7D7E */ -// dump header content asynchronously and save it to file +// Dump header content asynchronously and save it to file client := req.C(). - EnableDumpOnlyHeader(). // only dump the header of request and response - EnableDumpAsync(). // dump asynchronously to improve performance - EnableDumpToFile("reqdump.log") // dump to file without printing it out + EnableDumpOnlyHeader(). // Only dump the header of request and response + EnableDumpAsync(). // Dump asynchronously to improve performance + EnableDumpToFile("reqdump.log") // Dump to file without printing it out client.Get(url) ``` -## PathParams and QeuryParams +**Logging** + +```go +// Logging is enabled by default, but only output warning and error message. +// EnableDebug set to true to enable debug level logging. +client := req.C().EnableDebug(true) +client.R().Get("https://api.github.com/users/imroc") +// Output +// 2022/01/23 14:33:04.755019 DEBUG [req] GET https://api.github.com/users/imroc + +// SetLogger with nil to disable all log +client.SetLogger(nil) + +// Or customize the logger with your own implementation. +client.SetLogger(logger) +``` + +### PathParams and QueryParams + +```go +client := req.C().EnableDebug(true) +client.R(). + SetPathParam("owner", "imroc"). + SetPathParams(map[string]string{ + "repo": "req", + "path": "README.md", + }). + SetQueryParam("a", "a"). + SetQueryParams(map[string]string{ + "b": "b", + "c": "c", + }). + SetQueryString("d=d&e=e"). + Get("https://api.github.com/repos/{owner}/{repo}/contents/{path}?x=x") +// Output +// 2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md?x=x&a=a&b=b&c=c&d=d&e=e +``` ## License From a88f9c9487cdeb252f3cc644534e3594d91956ed Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 16:05:06 +0800 Subject: [PATCH 083/843] some fix --- README.md | 118 +++++++++++++++++++++++++++++++++++++++++++------- client.go | 59 +++++++++++++++++++++++-- middleware.go | 31 ++++++------- request.go | 11 +++++ 4 files changed, 186 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 4bc6ac5a..0172bad9 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,8 @@ import "github.com/imroc/req/v2" * [Quick Start](#Quick-Start) * [Debug](#Debug) -* [PathParams and QueryParams](#PathParams-QueryParams) +* [PathParam and QueryParam](#PathParam-QueryParam) +* [Header and Cookie](#Header-Cookie) ### Quick Start @@ -112,7 +113,6 @@ content-length: 486 x-github-request-id: AF10:6205:BA107D:D614F2:61EA7D7E {"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":128,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2021-07-08T12:11:23Z"} - */ // Dump header content asynchronously and save it to file @@ -126,8 +126,8 @@ client.Get(url) **Logging** ```go -// Logging is enabled by default, but only output warning and error message. -// EnableDebug set to true to enable debug level logging. +// Logging is enabled by default, but only output warning and error message to stderr. +// EnableDebug set to true to enable debug level message logging. client := req.C().EnableDebug(true) client.R().Get("https://api.github.com/users/imroc") // Output @@ -140,25 +140,113 @@ client.SetLogger(nil) client.SetLogger(logger) ``` -### PathParams and QueryParams +**DevMode** + +If you want to enable all debug features (dump, debug logging and tracing), just call `DevMode()`: ```go -client := req.C().EnableDebug(true) +client := req.C().DevMode() +client.R().Get("https://imroc.cc") +``` + +### PathParam and QueryParam + +```go +client := req.C().DevMode() client.R(). - SetPathParam("owner", "imroc"). - SetPathParams(map[string]string{ + SetPathParam("owner", "imroc"). // Set a path param, which will replace the vairable in url path + SetPathParams(map[string]string{ // Set multiple path params at once "repo": "req", "path": "README.md", - }). - SetQueryParam("a", "a"). - SetQueryParams(map[string]string{ + }).SetQueryParam("a", "a"). // Set a query param, which will be encoded as query parameter in url + SetQueryParams(map[string]string{ // Set multiple query params at once "b": "b", "c": "c", - }). - SetQueryString("d=d&e=e"). + }).SetQueryString("d=d&e=e"). // Set query params as a raw query string Get("https://api.github.com/repos/{owner}/{repo}/contents/{path}?x=x") -// Output -// 2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md?x=x&a=a&b=b&c=c&d=d&e=e +/* Output +2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md?x=x&a=a&b=b&c=c&d=d&e=e +... +*/ + +// You can also set the common PathParam and QueryParam for every request on client +client.SetPathParam(k1, v1). + SetPathParams(pathParams). + SetQueryParam(k2, v2). + SetQueryParams(queryParams). + SetQueryString(queryString). + +resp, err := client.Get(url1) +... + +resp, err := client.Get(url2) +... +``` + +### Header and Cookie + +```go +// Let's dump the header to see what's going on +client := req.C().EnableDumpOnlyHeader() + +// Send a request with multiple headers and cookies +resp, err := client.R(). + SetHeader("Accept", "application/json"). // Set one header + SetHeaders(map[string]string{ // Set multiple headers at once + "My-Custom-Header": "My Custom Value", + "User": "imroc", + }).SetCookie(&http.Cookie{ // Set one cookie + Name: "imroc/req", + Value: "This is my custome cookie value", + Path: "/", + Domain: "baidu.com", + MaxAge: 36000, + HttpOnly: false, + Secure: true, + }).SetCookies([]*http.Cookie{ // Set multiple cookies at once + &http.Cookie{ + Name: "testcookie1", + Value: "testcookie1 value", + Path: "/", + Domain: "baidu.com", + MaxAge: 36000, + HttpOnly: false, + Secure: true, + }, + &http.Cookie{ + Name: "testcookie2", + Value: "testcookie2 value", + Path: "/", + Domain: "baidu.com", + MaxAge: 36000, + HttpOnly: false, + Secure: true, + }, + }).Get("https://www.baidu.com/") + +/* Output +GET / HTTP/1.1 +Host: www.baidu.com +User-Agent: req/v2 +Accept: application/json +Cookie: imroc/req="This is my custome cookie value"; testcookie1="testcookie1 value"; testcookie2="testcookie2 value" +My-Custom-Header: My Custom Value +User: imroc +Accept-Encoding: gzip + +... +*/ + +// You can also set the common header and cookie for every request on client. +client.SetHeader(header). + SetHeaders(headers). + SetCookie(cookie). + SetCookies(cookies) + +resp, err := client.R().Get(url1) +... +resp, err := client.R().Get(url2) +... ``` ## License diff --git a/client.go b/client.go index af9ed818..c2a85a28 100644 --- a/client.go +++ b/client.go @@ -34,6 +34,7 @@ type Client struct { PathParams map[string]string QueryParams urlpkg.Values Headers http.Header + Cookies []*http.Cookie JSONMarshal func(v interface{}) ([]byte, error) JSONUnmarshal func(data []byte, v interface{}) error XMLMarshal func(v interface{}) ([]byte, error) @@ -102,6 +103,48 @@ func (c *Client) R() *Request { client: c, RawRequest: req, } + +} +func (c *Client) SetQueryParams(params map[string]string) *Client { + for k, v := range params { + c.SetQueryParam(k, v) + } + return c +} + +func (c *Client) SetQueryParam(key, value string) *Client { + if c.QueryParams == nil { + c.QueryParams = make(urlpkg.Values) + } + c.QueryParams.Set(key, value) + return c +} + +func (c *Client) SetQueryString(query string) *Client { + params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) + if err == nil { + if c.QueryParams == nil { + c.QueryParams = make(urlpkg.Values) + } + for p, v := range params { + for _, pv := range v { + c.QueryParams.Add(p, pv) + } + } + } else { + c.log.Warnf("failed to parse query string (%s): %v", query, err) + } + return c +} + +func (c *Client) SetCookie(hc *http.Cookie) *Client { + c.Cookies = append(c.Cookies, hc) + return c +} + +func (c *Client) SetCookies(cs []*http.Cookie) *Client { + c.Cookies = append(c.Cookies, cs...) + return c } const ( @@ -411,6 +454,7 @@ func C() *Client { beforeRequest := []RequestMiddleware{ parseRequestURL, parseRequestHeader, + parseRequestCookie, } afterResponse := []ResponseMiddleware{ parseResponseBody, @@ -431,6 +475,11 @@ func C() *Client { return c } +func setupRequest(r *Request) { + setRequestURL(r.RawRequest, r.URL) + setRequestHeaderAndCookie(r) +} + func (c *Client) Do(r *Request) (resp *Response, err error) { for _, f := range r.client.udBeforeRequest { @@ -445,11 +494,12 @@ func (c *Client) Do(r *Request) (resp *Response, err error) { } } - setRequestURL(r.RawRequest, r.URL) - setRequestHeader(r) + setupRequest(r) + if c.Debug { c.log.Debugf("%s %s", r.RawRequest.Method, r.RawRequest.URL.String()) } + httpResponse, err := c.httpClient.Do(r.RawRequest) if err != nil { return @@ -475,11 +525,14 @@ func (c *Client) Do(r *Request) (resp *Response, err error) { return } -func setRequestHeader(r *Request) { +func setRequestHeaderAndCookie(r *Request) { if r.Headers == nil { r.Headers = make(http.Header) } r.RawRequest.Header = r.Headers + for _, cookie := range r.Cookies { + r.RawRequest.AddCookie(cookie) + } } func setRequestURL(r *http.Request, url string) error { diff --git a/middleware.go b/middleware.go index 64bf950a..1cb58c24 100644 --- a/middleware.go +++ b/middleware.go @@ -18,16 +18,6 @@ type ( ResponseMiddleware func(*Client, *Response) error ) -// unmarshalc content into object from JSON or XML -func unmarshalc(c *Client, ct string, b []byte, d interface{}) (err error) { - if util.IsJSONType(ct) { - err = c.JSONUnmarshal(b, d) - } else if util.IsXMLType(ct) { - err = c.XMLUnmarshal(b, d) - } - return -} - func parseResponseBody(c *Client, r *Response) (err error) { if r.StatusCode == http.StatusNoContent { return @@ -81,17 +71,28 @@ func parseRequestHeader(c *Client, r *Request) error { if c.Headers == nil { return nil } - hdr := make(http.Header) + if r.Headers == nil { + r.Headers = make(http.Header) + } for k := range c.Headers { - hdr[k] = append(hdr[k], c.Headers[k]...) + r.Headers[k] = append(r.Headers[k], c.Headers[k]...) } for k := range r.Headers { - hdr.Del(k) - hdr[k] = append(hdr[k], r.Headers[k]...) + r.Headers.Del(k) + r.Headers[k] = append(r.Headers[k], r.Headers[k]...) } - r.Headers = hdr + return nil +} + +func parseRequestCookie(c *Client, r *Request) error { + if len(c.Cookies) == 0 { + return nil + } + for _, ck := range c.Cookies { + r.Cookies = append(r.Cookies, ck) + } return nil } diff --git a/request.go b/request.go index f12881f5..9a091874 100644 --- a/request.go +++ b/request.go @@ -19,6 +19,7 @@ type Request struct { PathParams map[string]string QueryParams urlpkg.Values Headers http.Header + Cookies []*http.Cookie Result interface{} Error interface{} error error @@ -34,6 +35,16 @@ func New() *Request { return defaultClient.R() } +func (r *Request) SetCookie(hc *http.Cookie) *Request { + r.Cookies = append(r.Cookies, hc) + return r +} + +func (r *Request) SetCookies(rs []*http.Cookie) *Request { + r.Cookies = append(r.Cookies, rs...) + return r +} + func (r *Request) SetQueryString(query string) *Request { params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) if err == nil { From 010bd6c182e78384708ca08298777493825b29df Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 16:06:18 +0800 Subject: [PATCH 084/843] unexpose client.Do --- client.go | 2 +- request.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index c2a85a28..eb8dc3ef 100644 --- a/client.go +++ b/client.go @@ -480,7 +480,7 @@ func setupRequest(r *Request) { setRequestHeaderAndCookie(r) } -func (c *Client) Do(r *Request) (resp *Response, err error) { +func (c *Client) do(r *Request) (resp *Response, err error) { for _, f := range r.client.udBeforeRequest { if err = f(r.client, r); err != nil { diff --git a/request.go b/request.go index 9a091874..6f83cac9 100644 --- a/request.go +++ b/request.go @@ -140,7 +140,7 @@ func (r *Request) Send(method, url string) (*Response, error) { } r.RawRequest.Method = method r.URL = url - return r.client.Do(r) + return r.client.do(r) } // MustGet like Get, panic if error happens. From a9cf2d57291824d66992e9bd6e42b5cb3c7f4b41 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 16:14:38 +0800 Subject: [PATCH 085/843] update README --- README.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0172bad9..a9c7079a 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,22 @@ client := req.C(). EnableDumpAsync(). // Dump asynchronously to improve performance EnableDumpToFile("reqdump.log") // Dump to file without printing it out client.Get(url) + +// Enable dump with fully customized settings +opt := &req.DumpOptions{ + Output: os.Stdout, + RequestHeader: true, + ResponseBody: true, + RequestBody: false, + ResponseHeader: false, + Async: false, + } +client := req.C().SetDumpOptions(opt).EnableDump(true) +client.R().Get("https://www.baidu.com/") + +// Change settings dynamiclly +opt.ResponseBody = false +client.R().Get("https://www.baidu.com/") ``` **Logging** @@ -191,7 +207,7 @@ client := req.C().EnableDumpOnlyHeader() // Send a request with multiple headers and cookies resp, err := client.R(). - SetHeader("Accept", "application/json"). // Set one header + SetHeader("Accept", "application/json"). // Set one header SetHeaders(map[string]string{ // Set multiple headers at once "My-Custom-Header": "My Custom Value", "User": "imroc", From ba3f59c4823abb6c892eb034f3a0e9325c7a90d7 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 16:17:58 +0800 Subject: [PATCH 086/843] bump to v2.0.0-alpha.4 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a9c7079a..b41b8454 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ A golang http request library for humans. ## Install ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.3 +go get github.com/imroc/req/v2@v2.0.0-alpha.4 ``` ## Usage From 943a66562a439beeb6bb4bdfb36a919c7049cdeb Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 16:26:48 +0800 Subject: [PATCH 087/843] update README --- README.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b41b8454..2ee81bdd 100644 --- a/README.md +++ b/README.md @@ -6,12 +6,11 @@ A golang http request library for humans. ## Features -* Simple and chainable methods for client and request settings. -* Rich syntax sugar, greatly improving development efficiency. +* Simple and chainable methods for client and request settings, rich syntax sugar, less code and more efficiency. * Automatically detect charset and decode it to utf-8. * Powerful debugging capabilities (logging, tracing, and even dump the requests and responses' content). -* All settings can be dynamically adjusted, making it possible to debug in the production environment. -* Easy to integrate with existing code, just replace client's Transport then you can dump content as req to debug APIs. +* All settings can be changed dynamically, making it possible to debug in the production environment. +* Easy to integrate with existing code, just replace the Transport of existing http.Client, then you can dump content as req to debug APIs. ## Install From d2bf574655fe6efd544190b9d41222ebc81c78fc Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 16:41:22 +0800 Subject: [PATCH 088/843] client support SetTLSClientConfig --- client.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/client.go b/client.go index eb8dc3ef..456526ed 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package req import ( + "crypto/tls" "encoding/json" "encoding/xml" "github.com/imroc/req/v2/internal/util" @@ -105,6 +106,12 @@ func (c *Client) R() *Request { } } + +func (c *Client) SetTLSClientConfig(conf *tls.Config) *Client { + c.t.TLSClientConfig = conf + return c +} + func (c *Client) SetQueryParams(params map[string]string) *Client { for k, v := range params { c.SetQueryParam(k, v) From b43d85201c8c3806411cb216814ee27a5dfcf158 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 16:52:21 +0800 Subject: [PATCH 089/843] client support DisableCompression and DisableKeepAlives --- client.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/client.go b/client.go index 456526ed..88030ec3 100644 --- a/client.go +++ b/client.go @@ -107,6 +107,16 @@ func (c *Client) R() *Request { } +func (c *Client) DisableKeepAlives(disable bool) *Client { + c.t.DisableKeepAlives = disable + return c +} + +func (c *Client) DisableCompression(disable bool) *Client { + c.t.DisableCompression = disable + return c +} + func (c *Client) SetTLSClientConfig(conf *tls.Config) *Client { c.t.TLSClientConfig = conf return c From 702b286ac5ea05ef12ef0a87b91840355b3dce1f Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 19:38:43 +0800 Subject: [PATCH 090/843] support more RedirectPolicy --- client.go | 24 +++++++++++++ middleware.go | 9 ++--- redirect.go | 99 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 6 deletions(-) create mode 100644 redirect.go diff --git a/client.go b/client.go index 88030ec3..ced7d2e1 100644 --- a/client.go +++ b/client.go @@ -15,6 +15,11 @@ import ( "time" ) +var ( + hdrUserAgentKey = "User-Agent" + hdrUserAgentValue = "req/v2 (https://github.com/imroc/req)" +) + // DefaultClient returns the global default Client. func DefaultClient() *Client { return defaultClient @@ -107,6 +112,25 @@ func (c *Client) R() *Request { } +func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { + if len(policies) == 0 { + return c + } + c.httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + for _, f := range policies { + if f == nil { + continue + } + err := f(req, via) + if err != nil { + return err + } + } + return nil + } + return c +} + func (c *Client) DisableKeepAlives(disable bool) *Client { c.t.DisableKeepAlives = disable return c diff --git a/middleware.go b/middleware.go index 1cb58c24..1f5b7d18 100644 --- a/middleware.go +++ b/middleware.go @@ -75,12 +75,9 @@ func parseRequestHeader(c *Client, r *Request) error { r.Headers = make(http.Header) } for k := range c.Headers { - r.Headers[k] = append(r.Headers[k], c.Headers[k]...) - } - - for k := range r.Headers { - r.Headers.Del(k) - r.Headers[k] = append(r.Headers[k], r.Headers[k]...) + if r.Headers.Get(k) == "" { + r.Headers.Add(k, c.Headers.Get(k)) + } } return nil diff --git a/redirect.go b/redirect.go new file mode 100644 index 00000000..ac1c2535 --- /dev/null +++ b/redirect.go @@ -0,0 +1,99 @@ +package req + +import ( + "errors" + "fmt" + "net" + "net/http" + "strings" +) + +type RedirectPolicy func(req *http.Request, via []*http.Request) error + +// MaxRedirectPolicy specifies the max number of redirect +func MaxRedirectPolicy(noOfRedirect int) RedirectPolicy { + return func(req *http.Request, via []*http.Request) error { + if len(via) >= noOfRedirect { + return fmt.Errorf("stopped after %d redirects", noOfRedirect) + } + return nil + } +} + +// NoRedirectPolicy disable redirect behaviour +func NoRedirectPolicy() RedirectPolicy { + return func(req *http.Request, via []*http.Request) error { + return errors.New("auto redirect is disabled") + } +} + +func SameDomainRedirectPolicy() RedirectPolicy { + return func(req *http.Request, via []*http.Request) error { + if getDomain(req.URL.Host) != getDomain(via[0].URL.Host) { + return errors.New("different domain name is not allowed") + } + return nil + } +} + +// SameHostRedirectPolicy allows redirect only if the redirected host +// is the same as original host, e.g. redirect to "www.imroc.cc" from +// "imroc.cc" is not the allowed. +func SameHostRedirectPolicy() RedirectPolicy { + return func(req *http.Request, via []*http.Request) error { + if getHostname(req.URL.Host) != getHostname(via[0].URL.Host) { + return errors.New("different host name is not allowed") + } + return nil + } +} + +// AllowedHostRedirectPolicy allows redirect only if the redirected host +// match one of the host that specified. +func AllowedHostRedirectPolicy(hosts ...string) RedirectPolicy { + m := make(map[string]bool) + for _, h := range hosts { + m[strings.ToLower(getHostname(h))] = true + } + + return func(req *http.Request, via []*http.Request) error { + if _, ok := m[getHostname(req.URL.Host)]; !ok { + return errors.New("redirect host is not allowed") + } + return nil + } +} + +// AllowedDomainRedirectPolicy allows redirect only if the redirected domain +// match one of the domain that specified. +func AllowedDomainRedirectPolicy(hosts ...string) RedirectPolicy { + domains := make(map[string]bool) + for _, h := range hosts { + domains[strings.ToLower(getDomain(h))] = true + } + + return func(req *http.Request, via []*http.Request) error { + if _, ok := domains[getDomain(req.URL.Host)]; !ok { + return errors.New("redirect domain is not allowed") + } + return nil + } +} + +func getHostname(host string) (hostname string) { + if strings.Index(host, ":") > 0 { + host, _, _ = net.SplitHostPort(host) + } + hostname = strings.ToLower(host) + return +} + +func getDomain(host string) string { + host = getHostname(host) + ss := strings.Split(host, ".") + if len(ss) < 3 { + return host + } + ss = ss[1:] + return strings.Join(ss, ".") +} From c3fdc9bafa4bbe1d56cb99f9420c72be1abaed8c Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 19:47:32 +0800 Subject: [PATCH 091/843] change default user agent --- client.go | 3 +-- h2_bundle.go | 4 +--- http.go | 6 ------ http_request.go | 2 +- 4 files changed, 3 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index ced7d2e1..7c68be2c 100644 --- a/client.go +++ b/client.go @@ -109,7 +109,6 @@ func (c *Client) R() *Request { client: c, RawRequest: req, } - } func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { @@ -372,7 +371,7 @@ func (c *Client) EnableAutoDecodeTextType() *Client { // SetUserAgent set the "User-Agent" header for all requests. func (c *Client) SetUserAgent(userAgent string) *Client { - return c.SetHeader("User-Agent", userAgent) + return c.SetHeader(hdrUserAgentKey, userAgent) } func (c *Client) SetHeaders(hdrs map[string]string) *Client { diff --git a/h2_bundle.go b/h2_bundle.go index 57c3e218..698d13ef 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -6572,8 +6572,6 @@ const ( // a stream-level WINDOW_UPDATE for at a time. http2transportDefaultStreamMinRefresh = 4 << 10 - http2defaultUserAgent = "req/v2" - // initialMaxConcurrentStreams is a connections maxConcurrentStreams until // it's received servers initial SETTINGS frame, which corresponds with the // spec's minimum recommended value. @@ -8368,7 +8366,7 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, f("accept-encoding", "gzip") } if !didUA { - f("user-agent", http2defaultUserAgent) + f("user-agent", hdrUserAgentValue) } } diff --git a/http.go b/http.go index 2fb58319..361baae7 100644 --- a/http.go +++ b/http.go @@ -149,12 +149,6 @@ func foreachHeaderElement(v string, fn func(string)) { // well read them) const maxPostHandlerReadBytes = 256 << 10 -// NOTE: This is not intended to reflect the actual Go version being used. -// It was changed at the time of Go 1.1 release because the former User-Agent -// had ended up blocked by some intrusion detection systems. -// See https://codereview.appspot.com/7532043. -const defaultUserAgent = "req/v2" - func idnaASCII(v string) (string, error) { // TODO: Consider removing this check after verifying performance is okay. // Right now punycode verification, length checks, context checks, and the diff --git a/http_request.go b/http_request.go index 5c50ef7b..a7597abf 100644 --- a/http_request.go +++ b/http_request.go @@ -169,7 +169,7 @@ func requestWrite(r *http.Request, w io.Writer, usingProxy bool, extraHeaders ht // Use the defaultUserAgent unless the Header contains one, which // may be blank to not send the header. - userAgent := defaultUserAgent + userAgent := hdrUserAgentValue if headerHas(r.Header, "User-Agent") { userAgent = r.Header.Get("User-Agent") } From e317bd6301404525f79baea11d366fc707605db1 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 20:05:27 +0800 Subject: [PATCH 092/843] support basic auth --- client.go | 12 +++++++++--- internal/util/util.go | 16 +++++++++++++++- middleware.go | 1 - request.go | 5 +++++ 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 7c68be2c..bf1074b1 100644 --- a/client.go +++ b/client.go @@ -374,6 +374,11 @@ func (c *Client) SetUserAgent(userAgent string) *Client { return c.SetHeader(hdrUserAgentKey, userAgent) } +func (c *Client) SetBasicAuth(username, password string) *Client { + c.SetHeader("Authorization", util.BasicAuthHeaderValue(username, password)) + return c +} + func (c *Client) SetHeaders(hdrs map[string]string) *Client { for k, v := range hdrs { c.SetHeader(k, v) @@ -566,10 +571,11 @@ func (c *Client) do(r *Request) (resp *Response, err error) { } func setRequestHeaderAndCookie(r *Request) { - if r.Headers == nil { - r.Headers = make(http.Header) + for k, vs := range r.Headers { + for _, v := range vs { + r.RawRequest.Header.Add(k, v) + } } - r.RawRequest.Header = r.Headers for _, cookie := range r.Cookies { r.RawRequest.AddCookie(cookie) } diff --git a/internal/util/util.go b/internal/util/util.go index 58773aab..4286f1b6 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -2,6 +2,7 @@ package util import ( "bytes" + "encoding/base64" "reflect" "regexp" "strings" @@ -22,7 +23,6 @@ func IsXMLType(ct string) bool { return xmlCheck.MatchString(ct) } - func GetPointer(v interface{}) interface{} { vv := reflect.ValueOf(v) if vv.Kind() == reflect.Ptr { @@ -68,3 +68,17 @@ func FirstNonEmpty(v ...string) string { } return "" } + +// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt +// "To receive authorization, the client sends the userid and password, +// separated by a single colon (":") character, within a base64 +// encoded string in the credentials." +// It is not meant to be urlencoded. +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func BasicAuthHeaderValue(username, password string) string { + return "Basic " + basicAuth(username, password) +} diff --git a/middleware.go b/middleware.go index 1f5b7d18..74ae2165 100644 --- a/middleware.go +++ b/middleware.go @@ -79,7 +79,6 @@ func parseRequestHeader(c *Client, r *Request) error { r.Headers.Add(k, c.Headers.Get(k)) } } - return nil } diff --git a/request.go b/request.go index 6f83cac9..4a7146b2 100644 --- a/request.go +++ b/request.go @@ -69,6 +69,11 @@ func (r *Request) SetError(error interface{}) *Request { return r } +func (r *Request) SetBasicAuth(username, password string) *Request { + r.SetHeader("Authorization", util.BasicAuthHeaderValue(username, password)) + return r +} + func (r *Request) SetHeaders(hdrs map[string]string) *Request { for k, v := range hdrs { r.SetHeader(k, v) From 077a009c2217c97480853bc8f09c3ba20dfd33cd Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 20:39:24 +0800 Subject: [PATCH 093/843] support set client cert and root cert --- client.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/client.go b/client.go index bf1074b1..831da43e 100644 --- a/client.go +++ b/client.go @@ -2,11 +2,13 @@ package req import ( "crypto/tls" + "crypto/x509" "encoding/json" "encoding/xml" "github.com/imroc/req/v2/internal/util" "golang.org/x/net/publicsuffix" "io" + "io/ioutil" "net/http" "net/http/cookiejar" urlpkg "net/url" @@ -111,6 +113,59 @@ func (c *Client) R() *Request { } } +// SetCertFromFile helps to set client certificates from cert and key file +func (c *Client) SetCertFromFile(certFile, keyFile string) *Client { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + c.log.Errorf("failed to load client cert: %v", err) + return c + } + config := c.tlsConfig() + config.Certificates = append(config.Certificates, cert) + return c +} + +// SetCerts helps to set client certificates +func (c *Client) SetCerts(certs ...tls.Certificate) *Client { + config := c.tlsConfig() + config.Certificates = append(config.Certificates, certs...) + return c +} + +func (c *Client) appendRootCertData(data []byte) { + config := c.tlsConfig() + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + config.RootCAs.AppendCertsFromPEM(data) + return +} + +// SetRootCertFromString helps to set root CA cert from string +func (c *Client) SetRootCertFromString(pemContent string) *Client { + c.appendRootCertData([]byte(pemContent)) + return c +} + +// SetRootCertFromFile helps to set root CA cert from file +func (c *Client) SetRootCertFromFile(pemFilePath string) *Client { + rootPemData, err := ioutil.ReadFile(pemFilePath) + if err != nil { + c.log.Errorf("failed to read root cert file: %v", err) + return c + } + c.appendRootCertData(rootPemData) + return c +} + +func (c *Client) tlsConfig() *tls.Config { + if c.t.TLSClientConfig == nil { + c.t.TLSClientConfig = &tls.Config{} + } + return c.t.TLSClientConfig +} + +// SetRedirectPolicy helps to set the RedirectPolicy func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { if len(policies) == 0 { return c From e2663e17730d231e7da9fc99633469c795152e07 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 20:39:45 +0800 Subject: [PATCH 094/843] bump to v2.0.0-alpha.5 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2ee81bdd..0b75df10 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ A golang http request library for humans. ## Install ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.4 +go get github.com/imroc/req/v2@v2.0.0-alpha.5 ``` ## Usage From 23ff0bf047c63c00d0c5447d98fb0c433631af05 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 20:56:32 +0800 Subject: [PATCH 095/843] update README --- README.md | 108 +++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 79 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 0b75df10..994fbd06 100644 --- a/README.md +++ b/README.md @@ -30,8 +30,10 @@ import "github.com/imroc/req/v2" * [Quick Start](#Quick-Start) * [Debug](#Debug) -* [PathParam and QueryParam](#PathParam-QueryParam) -* [Header and Cookie](#Header-Cookie) +* [PathParam](#PathParam) +* [QueryParam](#QueryParam) +* [Header](#Header) +* [Cookie](#Cookie) ### Quick Start @@ -71,8 +73,8 @@ resp, err := client.R(). **Dump the content of request and response** ```go -// Set EnableDump to true, default dump all content to stderr, -// including both header and body of request and response +// Set EnableDump to true, dump all content to stderr by default, +// including both the header and body of all request and response client := req.C().EnableDump(true) client.R().Get("https://api.github.com/users/imroc") @@ -81,7 +83,7 @@ client.R().Get("https://api.github.com/users/imroc") :method: GET :path: /users/imroc :scheme: https -user-agent: custom-client +user-agent: req/v2 (https://github.com/imroc/req) accept-encoding: gzip :status: 200 @@ -164,7 +166,9 @@ client := req.C().DevMode() client.R().Get("https://imroc.cc") ``` -### PathParam and QueryParam +### PathParam + +Use `PathParam` to replace variable in the url path: ```go client := req.C().DevMode() @@ -173,44 +177,96 @@ client.R(). SetPathParams(map[string]string{ // Set multiple path params at once "repo": "req", "path": "README.md", - }).SetQueryParam("a", "a"). // Set a query param, which will be encoded as query parameter in url + }).Get("https://api.github.com/repos/{owner}/{repo}/contents/{path}") +/* Output +2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md +... +*/ + +// You can also set the common PathParam for every request on client +client.SetPathParam(k1, v1).SetPathParams(pathParams) + +resp, err := client.Get(url1) +... + +resp, err := client.Get(url2) +... +``` + +### QueryParam + +```go +client := req.C().DevMode() +client.R(). + SetQueryParam("a", "a"). // Set a query param, which will be encoded as query parameter in url SetQueryParams(map[string]string{ // Set multiple query params at once "b": "b", "c": "c", }).SetQueryString("d=d&e=e"). // Set query params as a raw query string - Get("https://api.github.com/repos/{owner}/{repo}/contents/{path}?x=x") + Get("https://api.github.com/repos/imroc/req/contents/README.md?x=x") /* Output 2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md?x=x&a=a&b=b&c=c&d=d&e=e ... */ -// You can also set the common PathParam and QueryParam for every request on client -client.SetPathParam(k1, v1). - SetPathParams(pathParams). - SetQueryParam(k2, v2). +// You can also set the common QueryParam for every request on client +client.SetQueryParam(k, v). SetQueryParams(queryParams). SetQueryString(queryString). -resp, err := client.Get(url1) +resp1, err := client.Get(url1) ... - -resp, err := client.Get(url2) +resp2, err := client.Get(url2) ... ``` -### Header and Cookie +### Header ```go // Let's dump the header to see what's going on client := req.C().EnableDumpOnlyHeader() // Send a request with multiple headers and cookies -resp, err := client.R(). +client.R(). SetHeader("Accept", "application/json"). // Set one header SetHeaders(map[string]string{ // Set multiple headers at once "My-Custom-Header": "My Custom Value", "User": "imroc", - }).SetCookie(&http.Cookie{ // Set one cookie + }).Get("https://www.baidu.com/") + +/* Output +GET / HTTP/1.1 +Host: www.baidu.com +User-Agent: req/v2 (https://github.com/imroc/req) +Accept: application/json +My-Custom-Header: My Custom Value +User: imroc +Accept-Encoding: gzip + +... +*/ + +// You can also set the common header and cookie for every request on client. +client.SetHeader(header). + SetHeaders(headers). + SetCookie(cookie). + SetCookies(cookies) + +resp1, err := client.R().Get(url1) +... +resp2, err := client.R().Get(url2) +... +``` + +### Cookie + +```go +// Let's dump the header to see what's going on +client := req.C().EnableDumpOnlyHeader() + +// Send a request with multiple headers and cookies +client.R(). + SetCookie(&http.Cookie{ // Set one cookie Name: "imroc/req", Value: "This is my custome cookie value", Path: "/", @@ -242,26 +298,20 @@ resp, err := client.R(). /* Output GET / HTTP/1.1 Host: www.baidu.com -User-Agent: req/v2 +User-Agent: req/v2 (https://github.com/imroc/req) Accept: application/json Cookie: imroc/req="This is my custome cookie value"; testcookie1="testcookie1 value"; testcookie2="testcookie2 value" -My-Custom-Header: My Custom Value -User: imroc Accept-Encoding: gzip ... */ -// You can also set the common header and cookie for every request on client. -client.SetHeader(header). - SetHeaders(headers). - SetCookie(cookie). - SetCookies(cookies) +// You can also set the common cookie for every request on client. +client.SetCookie(cookie).SetCookies(cookies) -resp, err := client.R().Get(url1) -... -resp, err := client.R().Get(url2) +resp1, err := client.R().Get(url1) ... +resp2, err := client.R().Get(url2) ``` ## License From 08345a011e492df87168230b9901dd2a7f7f5f93 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 23 Jan 2022 20:58:29 +0800 Subject: [PATCH 096/843] update README --- README.md | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 994fbd06..99743560 100644 --- a/README.md +++ b/README.md @@ -186,10 +186,10 @@ client.R(). // You can also set the common PathParam for every request on client client.SetPathParam(k1, v1).SetPathParams(pathParams) -resp, err := client.Get(url1) +resp1, err := client.Get(url1) ... -resp, err := client.Get(url2) +resp2, err := client.Get(url2) ... ``` @@ -247,10 +247,7 @@ Accept-Encoding: gzip */ // You can also set the common header and cookie for every request on client. -client.SetHeader(header). - SetHeaders(headers). - SetCookie(cookie). - SetCookies(cookies) +client.SetHeader(header).SetHeaders(headers) resp1, err := client.R().Get(url1) ... From bb4ebe45161b96dccabc71e2ab079532f8484155 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 11:09:56 +0800 Subject: [PATCH 097/843] support global wrapper --- README.md | 37 +++++++-- client.go | 218 +++++++++++++++++++++++++++++++++++++++++++++++------ logger.go | 11 ++- request.go | 141 +++++++++++++++++++++++++++++++++- 4 files changed, 371 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 99743560..eeec7c54 100644 --- a/README.md +++ b/README.md @@ -41,31 +41,54 @@ import "github.com/imroc/req/v2" ```go // Create and send a request with the global default client -resp, err := req.New().Get("https://api.github.com/users/imroc") +req.DevMode() +resp, err := req.Get("https://api.github.com/users/imroc") // Create and send a request with the custom client -client := req.C() +client := req.C().DevMode() resp, err := client.R().Get("https://api.github.com/users/imroc") ``` **Client Settings** ```go -client.SetUserAgent("my-custom-client"). - SetTimeout(5 * time.Second). - DevMode() +// Create a client with custom client settings +client := req.C().SetUserAgent("my-custom-client").SetTimeout(5 * time.Second).DevMode() + +// You can also configure the global client using the same chaining method, req wraps glbal +// method for default client +req.SetUserAgent("my-custom-client").SetTimeout(5 * time.Second).DevMode() ``` **Request Settings** ```go +// Use client to create a request with custom request settings +client := req.C() var result Result resp, err := client.R(). - SetResult(&result). SetHeader("Accept", "application/json"). SetQeuryParam("page", "1"). SetPathParam("userId", "imroc"). - Get(url) + SetResult(&result). + Get(url) + +// You can also create a request using global client using the same chaining method +resp, err := req.R(). + SetHeader("Accept", "application/json"). + SetQeuryParam("page", "1"). + SetPathParam("userId", "imroc"). + SetResult(&result). + Get(url) + +// You can even also create a request without calling R(), cuz req +// wraps global method for request, and create a request using +// the default client automatically. +resp, err := req.SetHeader("Accept", "application/json"). + SetQeuryParam("page", "1"). + SetPathParam("userId", "imroc"). + SetResult(&result). + Get(url) ``` ### Debug diff --git a/client.go b/client.go index 831da43e..efc6abbd 100644 --- a/client.go +++ b/client.go @@ -99,6 +99,10 @@ func cloneMap(h map[string]string) map[string]string { return m } +func R() *Request { + return defaultClient.R() +} + // R create a new request. func (c *Client) R() *Request { req := &http.Request{ @@ -113,6 +117,10 @@ func (c *Client) R() *Request { } } +func SetCertFromFile(certFile, keyFile string) *Client { + return defaultClient.SetCertFromFile(certFile, keyFile) +} + // SetCertFromFile helps to set client certificates from cert and key file func (c *Client) SetCertFromFile(certFile, keyFile string) *Client { cert, err := tls.LoadX509KeyPair(certFile, keyFile) @@ -125,6 +133,10 @@ func (c *Client) SetCertFromFile(certFile, keyFile string) *Client { return c } +func SetCerts(certs ...tls.Certificate) *Client { + return defaultClient.SetCerts(certs...) +} + // SetCerts helps to set client certificates func (c *Client) SetCerts(certs ...tls.Certificate) *Client { config := c.tlsConfig() @@ -141,12 +153,20 @@ func (c *Client) appendRootCertData(data []byte) { return } +func SetRootCertFromString(pemContent string) *Client { + return defaultClient.SetRootCertFromString(pemContent) +} + // SetRootCertFromString helps to set root CA cert from string func (c *Client) SetRootCertFromString(pemContent string) *Client { c.appendRootCertData([]byte(pemContent)) return c } +func SetRootCertFromFile(pemFilePath string) *Client { + return defaultClient.SetRootCertFromFile(pemFilePath) +} + // SetRootCertFromFile helps to set root CA cert from file func (c *Client) SetRootCertFromFile(pemFilePath string) *Client { rootPemData, err := ioutil.ReadFile(pemFilePath) @@ -165,6 +185,10 @@ func (c *Client) tlsConfig() *tls.Config { return c.t.TLSClientConfig } +func SetRedirectPolicy(policies ...RedirectPolicy) *Client { + return defaultClient.SetRedirectPolicy(policies...) +} + // SetRedirectPolicy helps to set the RedirectPolicy func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { if len(policies) == 0 { @@ -185,29 +209,49 @@ func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { return c } +func DisableKeepAlives(disable bool) *Client { + return defaultClient.DisableKeepAlives(disable) +} + func (c *Client) DisableKeepAlives(disable bool) *Client { c.t.DisableKeepAlives = disable return c } +func DisableCompression(disable bool) *Client { + return defaultClient.DisableCompression(disable) +} + func (c *Client) DisableCompression(disable bool) *Client { c.t.DisableCompression = disable return c } +func SetTLSClientConfig(conf *tls.Config) *Client { + return defaultClient.SetTLSClientConfig(conf) +} + func (c *Client) SetTLSClientConfig(conf *tls.Config) *Client { c.t.TLSClientConfig = conf return c } -func (c *Client) SetQueryParams(params map[string]string) *Client { +func SetCommonQueryParams(params map[string]string) *Client { + return defaultClient.SetCommonQueryParams(params) +} + +func (c *Client) SetCommonQueryParams(params map[string]string) *Client { for k, v := range params { - c.SetQueryParam(k, v) + c.SetCommonQueryParam(k, v) } return c } -func (c *Client) SetQueryParam(key, value string) *Client { +func SetCommonQueryParam(key, value string) *Client { + return defaultClient.SetCommonQueryParam(key, value) +} + +func (c *Client) SetCommonQueryParam(key, value string) *Client { if c.QueryParams == nil { c.QueryParams = make(urlpkg.Values) } @@ -215,7 +259,11 @@ func (c *Client) SetQueryParam(key, value string) *Client { return c } -func (c *Client) SetQueryString(query string) *Client { +func SetCommonQueryString(query string) *Client { + return defaultClient.SetCommonQueryString(query) +} + +func (c *Client) SetCommonQueryString(query string) *Client { params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) if err == nil { if c.QueryParams == nil { @@ -232,12 +280,20 @@ func (c *Client) SetQueryString(query string) *Client { return c } -func (c *Client) SetCookie(hc *http.Cookie) *Client { +func SetCommonCookie(hc *http.Cookie) *Client { + return defaultClient.SetCommonCookie(hc) +} + +func (c *Client) SetCommonCookie(hc *http.Cookie) *Client { c.Cookies = append(c.Cookies, hc) return c } -func (c *Client) SetCookies(cs []*http.Cookie) *Client { +func SetCommonCookies(cs []*http.Cookie) *Client { + return defaultClient.SetCommonCookies(cs) +} + +func (c *Client) SetCommonCookies(cs []*http.Cookie) *Client { c.Cookies = append(c.Cookies, cs...) return c } @@ -247,21 +303,32 @@ const ( userAgentChrome = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36" ) +func EnableDebug(enable bool) *Client { + return defaultClient.EnableDebug(enable) +} + func (c *Client) EnableDebug(enable bool) *Client { c.Debug = enable return c } +func DevMode() *Client { + return defaultClient.DevMode() +} + // DevMode enables dump for requests and responses, and set user // agent to pretend to be a web browser, Avoid returning abnormal // data from some sites. func (c *Client) DevMode() *Client { - return c.EnableAutoDecodeTextType(). - EnableDumpAll(). + return c.EnableDumpAll(). EnableDebug(true). SetUserAgent(userAgentChrome) } +func SetScheme(scheme string) *Client { + return defaultClient.SetScheme(scheme) +} + // SetScheme method sets custom scheme in the Resty client. It's way to override default. // client.SetScheme("http") func (c *Client) SetScheme(scheme string) *Client { @@ -271,6 +338,10 @@ func (c *Client) SetScheme(scheme string) *Client { return c } +func SetLogger(log Logger) *Client { + return defaultClient.SetLogger(log) +} + // SetLogger set the logger for req, set to nil to disable logger. func (c *Client) SetLogger(log Logger) *Client { if log == nil { @@ -281,6 +352,10 @@ func (c *Client) SetLogger(log Logger) *Client { return c } +func GetResponseOptions() *ResponseOptions { + return defaultClient.GetResponseOptions() +} + func (c *Client) GetResponseOptions() *ResponseOptions { if c.t.ResponseOptions == nil { c.t.ResponseOptions = &ResponseOptions{} @@ -288,6 +363,10 @@ func (c *Client) GetResponseOptions() *ResponseOptions { return c.t.ResponseOptions } +func SetResponseOptions(opt *ResponseOptions) *Client { + return defaultClient.SetResponseOptions(opt) +} + // SetResponseOptions set the ResponseOptions for the underlying Transport. func (c *Client) SetResponseOptions(opt *ResponseOptions) *Client { if opt == nil { @@ -297,12 +376,20 @@ func (c *Client) SetResponseOptions(opt *ResponseOptions) *Client { return c } +func SetTimeout(d time.Duration) *Client { + return defaultClient.SetTimeout(d) +} + // SetTimeout set the timeout for all requests. func (c *Client) SetTimeout(d time.Duration) *Client { c.httpClient.Timeout = d return c } +func GetDumpOptions() *DumpOptions { + return defaultClient.GetDumpOptions() +} + func (c *Client) GetDumpOptions() *DumpOptions { if c.dumpOptions == nil { c.dumpOptions = newDefaultDumpOptions() @@ -317,6 +404,10 @@ func (c *Client) enableDump() { c.t.EnableDump(c.GetDumpOptions()) } +func EnableDumpToFile(filename string) *Client { + return defaultClient.EnableDumpToFile(filename) +} + // EnableDumpToFile indicates that the content should dump to the specified filename. func (c *Client) EnableDumpToFile(filename string) *Client { file, err := os.Create(filename) @@ -328,6 +419,10 @@ func (c *Client) EnableDumpToFile(filename string) *Client { return c } +func EnableDumpTo(output io.Writer) *Client { + return defaultClient.EnableDumpTo(output) +} + // EnableDumpTo indicates that the content should dump to the specified destination. func (c *Client) EnableDumpTo(output io.Writer) *Client { c.GetDumpOptions().Output = output @@ -335,6 +430,10 @@ func (c *Client) EnableDumpTo(output io.Writer) *Client { return c } +func EnableDumpAsync() *Client { + return defaultClient.EnableDumpAsync() +} + // EnableDumpAsync indicates that the dump should be done asynchronously, // can be used for debugging in production environment without // affecting performance. @@ -345,6 +444,10 @@ func (c *Client) EnableDumpAsync() *Client { return c } +func EnableDumpOnlyResponse() *Client { + return defaultClient.EnableDumpOnlyResponse() +} + // EnableDumpOnlyResponse indicates that should dump the responses' head and response. func (c *Client) EnableDumpOnlyResponse() *Client { o := c.GetDumpOptions() @@ -356,6 +459,10 @@ func (c *Client) EnableDumpOnlyResponse() *Client { return c } +func EnableDumpOnlyRequest() *Client { + return defaultClient.EnableDumpOnlyRequest() +} + // EnableDumpOnlyRequest indicates that should dump the requests' head and response. func (c *Client) EnableDumpOnlyRequest() *Client { o := c.GetDumpOptions() @@ -367,6 +474,10 @@ func (c *Client) EnableDumpOnlyRequest() *Client { return c } +func EnableDumpOnlyBody() *Client { + return defaultClient.EnableDumpOnlyBody() +} + // EnableDumpOnlyBody indicates that should dump the body of requests and responses. func (c *Client) EnableDumpOnlyBody() *Client { o := c.GetDumpOptions() @@ -378,6 +489,10 @@ func (c *Client) EnableDumpOnlyBody() *Client { return c } +func EnableDumpOnlyHeader() *Client { + return defaultClient.EnableDumpOnlyHeader() +} + // EnableDumpOnlyHeader indicates that should dump the head of requests and responses. func (c *Client) EnableDumpOnlyHeader() *Client { o := c.GetDumpOptions() @@ -389,6 +504,10 @@ func (c *Client) EnableDumpOnlyHeader() *Client { return c } +func EnableDumpAll() *Client { + return defaultClient.EnableDumpAll() +} + // EnableDumpAll indicates that should dump both requests and responses' head and body. func (c *Client) EnableDumpAll() *Client { o := c.GetDumpOptions() @@ -400,49 +519,83 @@ func (c *Client) EnableDumpAll() *Client { return c } +func NewRequest() *Request { + return defaultClient.R() +} + // NewRequest is the alias of R() func (c *Client) NewRequest() *Request { return c.R() } +func DisableAutoReadResponse(disable bool) *Client { + return defaultClient.DisableAutoReadResponse(disable) +} + func (c *Client) DisableAutoReadResponse(disable bool) *Client { c.disableAutoReadResponse = disable return c } +func EnableAutoDecodeAllType() *Client { + return defaultClient.EnableAutoDecodeAllType() +} + // EnableAutoDecodeAllType indicates that try autodetect and decode all content type. func (c *Client) EnableAutoDecodeAllType() *Client { - c.GetResponseOptions().AutoDecodeContentType = func(contentType string) bool { + opt := c.GetResponseOptions() + opt.AutoDecodeContentType = func(contentType string) bool { return true } + opt.DisableAutoDecode = false return c } -// EnableAutoDecodeTextType indicates that only try autodetect and decode the text content type. -func (c *Client) EnableAutoDecodeTextType() *Client { - c.GetResponseOptions().AutoDecodeContentType = autoDecodeText +func DisableAutoDecode(disable bool) *Client { + return defaultClient.DisableAutoDecode(disable) +} + +// DisableAutoDecode disable auto detect charset and decode to utf-8 +func (c *Client) DisableAutoDecode(disable bool) *Client { + c.GetResponseOptions().DisableAutoDecode = disable return c } +func SetUserAgent(userAgent string) *Client { + return defaultClient.SetUserAgent(userAgent) +} + // SetUserAgent set the "User-Agent" header for all requests. func (c *Client) SetUserAgent(userAgent string) *Client { - return c.SetHeader(hdrUserAgentKey, userAgent) + return c.SetCommonHeader(hdrUserAgentKey, userAgent) +} + +func SetCommonBasicAuth(username, password string) *Client { + return defaultClient.SetCommonBasicAuth(username, password) } -func (c *Client) SetBasicAuth(username, password string) *Client { - c.SetHeader("Authorization", util.BasicAuthHeaderValue(username, password)) +func (c *Client) SetCommonBasicAuth(username, password string) *Client { + c.SetCommonHeader("Authorization", util.BasicAuthHeaderValue(username, password)) return c } -func (c *Client) SetHeaders(hdrs map[string]string) *Client { +func SetCommonHeaders(hdrs map[string]string) *Client { + return defaultClient.SetCommonHeaders(hdrs) +} + +func (c *Client) SetCommonHeaders(hdrs map[string]string) *Client { for k, v := range hdrs { - c.SetHeader(k, v) + c.SetCommonHeader(k, v) } return c } -// SetHeader set the common header for all requests. -func (c *Client) SetHeader(key, value string) *Client { +func SetCommonHeader(key, value string) *Client { + return defaultClient.SetCommonHeader(key, value) +} + +// SetCommonHeader set the common header for all requests. +func (c *Client) SetCommonHeader(key, value string) *Client { if c.Headers == nil { c.Headers = make(http.Header) } @@ -450,6 +603,10 @@ func (c *Client) SetHeader(key, value string) *Client { return c } +func EnableDump(enable bool) *Client { + return defaultClient.EnableDump(enable) +} + // EnableDump enables dump requests and responses, allowing you // to clearly see the content of all requests and responses,which // is very convenient for debugging APIs. @@ -462,6 +619,10 @@ func (c *Client) EnableDump(enable bool) *Client { return c } +func SetDumpOptions(opt *DumpOptions) *Client { + return defaultClient.SetDumpOptions(opt) +} + // SetDumpOptions configures the underlying Transport's DumpOptions func (c *Client) SetDumpOptions(opt *DumpOptions) *Client { if opt == nil { @@ -474,15 +635,18 @@ func (c *Client) SetDumpOptions(opt *DumpOptions) *Client { return c } +func SetProxy(proxy func(*http.Request) (*urlpkg.URL, error)) *Client { + return defaultClient.SetProxy(proxy) +} + // SetProxy set the proxy function. func (c *Client) SetProxy(proxy func(*http.Request) (*urlpkg.URL, error)) *Client { c.t.Proxy = proxy return c } -func (c *Client) SetProxyFromEnv() *Client { - c.t.Proxy = http.ProxyFromEnvironment - return c +func OnBeforeRequest(m RequestMiddleware) *Client { + return defaultClient.OnBeforeRequest(m) } func (c *Client) OnBeforeRequest(m RequestMiddleware) *Client { @@ -490,11 +654,19 @@ func (c *Client) OnBeforeRequest(m RequestMiddleware) *Client { return c } +func OnAfterResponse(m ResponseMiddleware) *Client { + return defaultClient.OnAfterResponse(m) +} + func (c *Client) OnAfterResponse(m ResponseMiddleware) *Client { c.afterResponse = append(c.afterResponse, m) return c } +func SetProxyURL(proxyUrl string) *Client { + return defaultClient.SetProxyURL(proxyUrl) +} + func (c *Client) SetProxyURL(proxyUrl string) *Client { u, err := urlpkg.Parse(proxyUrl) if err != nil { @@ -563,7 +735,7 @@ func C() *Client { c := &Client{ beforeRequest: beforeRequest, afterResponse: afterResponse, - log: createLogger(), + log: createDefaultLogger(), httpClient: httpClient, t: t, t2: t2, diff --git a/logger.go b/logger.go index 56b6516a..03a6727b 100644 --- a/logger.go +++ b/logger.go @@ -1,6 +1,7 @@ package req import ( + "io" "log" "os" ) @@ -13,9 +14,13 @@ type Logger interface { Debugf(format string, v ...interface{}) } -func createLogger() *logger { - l := &logger{l: log.New(os.Stderr, "", log.Ldate|log.Lmicroseconds)} - return l +// NewLogger create a Logger wraps the *log.Logger +func NewLogger(output io.Writer, prefix string, flag int) Logger { + return &logger{l: log.New(output, prefix, flag)} +} + +func createDefaultLogger() Logger { + return NewLogger(os.Stderr, "", log.Ldate|log.Lmicroseconds) } var _ Logger = (*logger)(nil) diff --git a/request.go b/request.go index 4a7146b2..be4e8925 100644 --- a/request.go +++ b/request.go @@ -30,9 +30,8 @@ type Request struct { output io.WriteCloser } -// New create a new request using the global default client. -func New() *Request { - return defaultClient.R() +func SetCookie(hc *http.Cookie) *Request { + return defaultClient.R().SetCookie(hc) } func (r *Request) SetCookie(hc *http.Cookie) *Request { @@ -40,11 +39,19 @@ func (r *Request) SetCookie(hc *http.Cookie) *Request { return r } +func SetCookies(rs []*http.Cookie) *Request { + return defaultClient.R().SetCookies(rs) +} + func (r *Request) SetCookies(rs []*http.Cookie) *Request { r.Cookies = append(r.Cookies, rs...) return r } +func SetQueryString(query string) *Request { + return defaultClient.R().SetQueryString(query) +} + func (r *Request) SetQueryString(query string) *Request { params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) if err == nil { @@ -59,21 +66,37 @@ func (r *Request) SetQueryString(query string) *Request { return r } +func SetResult(result interface{}) *Request { + return defaultClient.R().SetResult(result) +} + func (r *Request) SetResult(result interface{}) *Request { r.Result = util.GetPointer(result) return r } +func SetError(error interface{}) *Request { + return defaultClient.R().SetError(error) +} + func (r *Request) SetError(error interface{}) *Request { r.Error = util.GetPointer(error) return r } +func SetBasicAuth(username, password string) *Request { + return defaultClient.R().SetBasicAuth(username, password) +} + func (r *Request) SetBasicAuth(username, password string) *Request { r.SetHeader("Authorization", util.BasicAuthHeaderValue(username, password)) return r } +func SetHeaders(hdrs map[string]string) *Request { + return defaultClient.R().SetHeaders(hdrs) +} + func (r *Request) SetHeaders(hdrs map[string]string) *Request { for k, v := range hdrs { r.SetHeader(k, v) @@ -81,6 +104,10 @@ func (r *Request) SetHeaders(hdrs map[string]string) *Request { return r } +func SetHeader(key, value string) *Request { + return defaultClient.R().SetHeader(key, value) +} + // SetHeader set the common header for all requests. func (r *Request) SetHeader(key, value string) *Request { if r.Headers == nil { @@ -90,6 +117,10 @@ func (r *Request) SetHeader(key, value string) *Request { return r } +func SetOutputFile(file string) *Request { + return defaultClient.R().SetOutputFile(file) +} + func (r *Request) SetOutputFile(file string) *Request { output, err := os.Create(file) if err != nil { @@ -99,12 +130,20 @@ func (r *Request) SetOutputFile(file string) *Request { return r.SetOutput(output) } +func SetOutput(output io.WriteCloser) *Request { + return defaultClient.R().SetOutput(output) +} + func (r *Request) SetOutput(output io.WriteCloser) *Request { r.output = output r.isSaveResponse = true return r } +func SetQueryParams(params map[string]string) *Request { + return defaultClient.R().SetQueryParams(params) +} + func (r *Request) SetQueryParams(params map[string]string) *Request { for k, v := range params { r.SetQueryParam(k, v) @@ -112,6 +151,10 @@ func (r *Request) SetQueryParams(params map[string]string) *Request { return r } +func SetQueryParam(key, value string) *Request { + return defaultClient.R().SetQueryParam(key, value) +} + func (r *Request) SetQueryParam(key, value string) *Request { if r.QueryParams == nil { r.QueryParams = make(urlpkg.Values) @@ -120,6 +163,10 @@ func (r *Request) SetQueryParam(key, value string) *Request { return r } +func SetPathParams(params map[string]string) *Request { + return defaultClient.R().SetPathParams(params) +} + func (r *Request) SetPathParams(params map[string]string) *Request { for key, value := range params { r.SetPathParam(key, value) @@ -127,6 +174,10 @@ func (r *Request) SetPathParams(params map[string]string) *Request { return r } +func SetPathParam(key, value string) *Request { + return defaultClient.R().SetPathParam(key, value) +} + func (r *Request) SetPathParam(key, value string) *Request { if r.PathParams == nil { r.PathParams = make(map[string]string) @@ -148,6 +199,10 @@ func (r *Request) Send(method, url string) (*Response, error) { return r.client.do(r) } +func MustGet(url string) *Response { + return defaultClient.R().MustGet(url) +} + // MustGet like Get, panic if error happens. func (r *Request) MustGet(url string) *Response { resp, err := r.Get(url) @@ -157,11 +212,19 @@ func (r *Request) MustGet(url string) *Response { return resp } +func Get(url string) (*Response, error) { + return defaultClient.R().Get(url) +} + // Get Send the request with GET method and specified url. func (r *Request) Get(url string) (*Response, error) { return r.Send(http.MethodGet, url) } +func MustPost(url string) *Response { + return defaultClient.R().MustPost(url) +} + // MustPost like Post, panic if error happens. func (r *Request) MustPost(url string) *Response { resp, err := r.Post(url) @@ -171,11 +234,19 @@ func (r *Request) MustPost(url string) *Response { return resp } +func Post(url string) (*Response, error) { + return defaultClient.R().Post(url) +} + // Post Send the request with POST method and specified url. func (r *Request) Post(url string) (*Response, error) { return r.Send(http.MethodPost, url) } +func MustPut(url string) *Response { + return defaultClient.R().MustPut(url) +} + // MustPut like Put, panic if error happens. func (r *Request) MustPut(url string) *Response { resp, err := r.Put(url) @@ -185,11 +256,19 @@ func (r *Request) MustPut(url string) *Response { return resp } +func Put(url string) (*Response, error) { + return defaultClient.R().Put(url) +} + // Put Send the request with Put method and specified url. func (r *Request) Put(url string) (*Response, error) { return r.Send(http.MethodPut, url) } +func MustPatch(url string) *Response { + return defaultClient.R().MustPatch(url) +} + // MustPatch like Patch, panic if error happens. func (r *Request) MustPatch(url string) *Response { resp, err := r.Patch(url) @@ -199,11 +278,19 @@ func (r *Request) MustPatch(url string) *Response { return resp } +func Patch(url string) (*Response, error) { + return defaultClient.R().Patch(url) +} + // Patch Send the request with PATCH method and specified url. func (r *Request) Patch(url string) (*Response, error) { return r.Send(http.MethodPatch, url) } +func MustDelete(url string) *Response { + return defaultClient.R().MustDelete(url) +} + // MustDelete like Delete, panic if error happens. func (r *Request) MustDelete(url string) *Response { resp, err := r.Delete(url) @@ -213,11 +300,19 @@ func (r *Request) MustDelete(url string) *Response { return resp } +func Delete(url string) (*Response, error) { + return defaultClient.R().Delete(url) +} + // Delete Send the request with DELETE method and specified url. func (r *Request) Delete(url string) (*Response, error) { return r.Send(http.MethodDelete, url) } +func MustOptions(url string) *Response { + return defaultClient.R().MustOptions(url) +} + // MustOptions like Options, panic if error happens. func (r *Request) MustOptions(url string) *Response { resp, err := r.Options(url) @@ -227,11 +322,19 @@ func (r *Request) MustOptions(url string) *Response { return resp } +func Options(url string) (*Response, error) { + return defaultClient.R().Options(url) +} + // Options Send the request with OPTIONS method and specified url. func (r *Request) Options(url string) (*Response, error) { return r.Send(http.MethodOptions, url) } +func MustHead(url string) *Response { + return defaultClient.R().MustHead(url) +} + // MustHead like Head, panic if error happens. func (r *Request) MustHead(url string) *Response { resp, err := r.Send(http.MethodHead, url) @@ -241,11 +344,19 @@ func (r *Request) MustHead(url string) *Response { return resp } +func Head(url string) (*Response, error) { + return defaultClient.R().Head(url) +} + // Head Send the request with HEAD method and specified url. func (r *Request) Head(url string) (*Response, error) { return r.Send(http.MethodHead, url) } +func SetBody(body interface{}) *Request { + return defaultClient.R().SetBody(body) +} + // SetBody set the request body. func (r *Request) SetBody(body interface{}) *Request { if body == nil { @@ -264,18 +375,30 @@ func (r *Request) SetBody(body interface{}) *Request { return r } +func SetBodyBytes(body []byte) *Request { + return defaultClient.R().SetBodyBytes(body) +} + // SetBodyBytes set the request body as []byte. func (r *Request) SetBodyBytes(body []byte) *Request { r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) return r } +func SetBodyString(body string) *Request { + return defaultClient.R().SetBodyString(body) +} + // SetBodyString set the request body as string. func (r *Request) SetBodyString(body string) *Request { r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(body)) return r } +func SetBodyJsonString(body string) *Request { + return defaultClient.R().SetBodyJsonString(body) +} + // SetBodyJsonString set the request body as string and set Content-Type header // as "application/json; charset=UTF-8" func (r *Request) SetBodyJsonString(body string) *Request { @@ -284,6 +407,10 @@ func (r *Request) SetBodyJsonString(body string) *Request { return r } +func SetBodyJsonBytes(body []byte) *Request { + return defaultClient.R().SetBodyJsonBytes(body) +} + // SetBodyJsonBytes set the request body as []byte and set Content-Type header // as "application/json; charset=UTF-8" func (r *Request) SetBodyJsonBytes(body []byte) *Request { @@ -292,6 +419,10 @@ func (r *Request) SetBodyJsonBytes(body []byte) *Request { return r } +func SetBodyJsonMarshal(v interface{}) *Request { + return defaultClient.R().SetBodyJsonMarshal(v) +} + // SetBodyJsonMarshal set the request body that marshaled from object, and // set Content-Type header as "application/json; charset=UTF-8" func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { @@ -303,6 +434,10 @@ func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { return r.SetBodyBytes(b) } +func SetContentType(contentType string) *Request { + return defaultClient.R().SetContentType(contentType) +} + func (r *Request) SetContentType(contentType string) *Request { r.RawRequest.Header.Set("Content-Type", contentType) return r From 556ccd08487b1af4b7386c7991a2a48f73bf8c21 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 11:10:50 +0800 Subject: [PATCH 098/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index eeec7c54..a58da75e 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ A golang http request library for humans. ## Install ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.5 +go get github.com/imroc/req/v2@v2.0.0-alpha.6 ``` ## Usage From 4e735812d348182f36bf265a0c8d7fe39a46f274 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 11:13:25 +0800 Subject: [PATCH 099/843] update README --- README.md | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index a58da75e..fe18d15f 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ A golang http request library for humans. go get github.com/imroc/req/v2@v2.0.0-alpha.6 ``` -## Usage +## Quick-Start Import req in your code: @@ -26,17 +26,6 @@ Import req in your code: import "github.com/imroc/req/v2" ``` -### Table of Contents - -* [Quick Start](#Quick-Start) -* [Debug](#Debug) -* [PathParam](#PathParam) -* [QueryParam](#QueryParam) -* [Header](#Header) -* [Cookie](#Cookie) - -### Quick Start - **Simple GET** ```go @@ -91,6 +80,14 @@ resp, err := req.SetHeader("Accept", "application/json"). Get(url) ``` +## Examples + +* [Debug](#Debug) +* [PathParam](#PathParam) +* [QueryParam](#QueryParam) +* [Header](#Header) +* [Cookie](#Cookie) + ### Debug **Dump the content of request and response** From 2a7bf14bd9956d6c04d99c2c493e84433fd5c3b9 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 11:14:49 +0800 Subject: [PATCH 100/843] update README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index fe18d15f..3a0175db 100644 --- a/README.md +++ b/README.md @@ -12,14 +12,14 @@ A golang http request library for humans. * All settings can be changed dynamically, making it possible to debug in the production environment. * Easy to integrate with existing code, just replace the Transport of existing http.Client, then you can dump content as req to debug APIs. -## Install +## Quick-Start + +**Install** ``` sh go get github.com/imroc/req/v2@v2.0.0-alpha.6 ``` -## Quick-Start - Import req in your code: ```go From b12e22398af4ad3d3adc0dd324372ee4b50fe5c0 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 11:19:25 +0800 Subject: [PATCH 101/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3a0175db..c945c1e1 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ A golang http request library for humans. go get github.com/imroc/req/v2@v2.0.0-alpha.6 ``` -Import req in your code: +**Import** ```go import "github.com/imroc/req/v2" From dd203b201ac23f0fe96d552aa6999257a2b8b0ef Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 11:21:36 +0800 Subject: [PATCH 102/843] update README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c945c1e1..fe782e24 100644 --- a/README.md +++ b/README.md @@ -56,9 +56,9 @@ req.SetUserAgent("my-custom-client").SetTimeout(5 * time.Second).DevMode() client := req.C() var result Result resp, err := client.R(). - SetHeader("Accept", "application/json"). - SetQeuryParam("page", "1"). - SetPathParam("userId", "imroc"). + SetHeader("Accept", "application/json"). + SetQeuryParam("page", "1"). + SetPathParam("userId", "imroc"). SetResult(&result). Get(url) From a733c46b41da35908bd8c807ef4105b57e7b21a2 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 14:09:36 +0800 Subject: [PATCH 103/843] bump to v2.0.0-alpha.7 --- README.md | 8 +-- dump.go | 2 +- examples/find-popular-repo/go.mod | 12 +++++ examples/find-popular-repo/go.sum | 15 ++++++ examples/find-popular-repo/main.go | 86 ++++++++++++++++++++++++++++++ logger.go | 2 +- 6 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 examples/find-popular-repo/go.mod create mode 100644 examples/find-popular-repo/go.sum create mode 100644 examples/find-popular-repo/main.go diff --git a/README.md b/README.md index fe782e24..7a5a2405 100644 --- a/README.md +++ b/README.md @@ -53,11 +53,11 @@ req.SetUserAgent("my-custom-client").SetTimeout(5 * time.Second).DevMode() ```go // Use client to create a request with custom request settings -client := req.C() +client := req.C().DevMode() var result Result resp, err := client.R(). SetHeader("Accept", "application/json"). - SetQeuryParam("page", "1"). + SetQueryParam("page", "1"). SetPathParam("userId", "imroc"). SetResult(&result). Get(url) @@ -65,7 +65,7 @@ resp, err := client.R(). // You can also create a request using global client using the same chaining method resp, err := req.R(). SetHeader("Accept", "application/json"). - SetQeuryParam("page", "1"). + SetQueryParam("page", "1"). SetPathParam("userId", "imroc"). SetResult(&result). Get(url) @@ -74,7 +74,7 @@ resp, err := req.R(). // wraps global method for request, and create a request using // the default client automatically. resp, err := req.SetHeader("Accept", "application/json"). - SetQeuryParam("page", "1"). + SetQueryParam("page", "1"). SetPathParam("userId", "imroc"). SetResult(&result). Get(url) diff --git a/dump.go b/dump.go index a27f039b..69d60ce9 100644 --- a/dump.go +++ b/dump.go @@ -99,7 +99,7 @@ type dumper struct { func newDefaultDumpOptions() *DumpOptions { return &DumpOptions{ - Output: os.Stderr, + Output: os.Stdout, RequestBody: true, ResponseBody: true, ResponseHeader: true, diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod new file mode 100644 index 00000000..5e797a41 --- /dev/null +++ b/examples/find-popular-repo/go.mod @@ -0,0 +1,12 @@ +module find-popular-repo + +go 1.18 + +require github.com/imroc/req/v2 v2.0.0-alpha.7 + +require ( + github.com/hashicorp/errwrap v1.0.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + golang.org/x/net v0.0.0-20220111093109-d55c255bac03 // indirect + golang.org/x/text v0.3.7 // indirect +) diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum new file mode 100644 index 00000000..b2958219 --- /dev/null +++ b/examples/find-popular-repo/go.sum @@ -0,0 +1,15 @@ +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/imroc/req/v2 v2.0.0-alpha.6 h1:VTa/8ZXkNByK11dPB7SxxKoYLr72+7drmS1ecU6b5Os= +github.com/imroc/req/v2 v2.0.0-alpha.6/go.mod h1:4WwxUxbU5+ETYyQ02G8SxeJoYq2REnH7jEqj/AHJxhQ= +golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= +golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/examples/find-popular-repo/main.go b/examples/find-popular-repo/main.go new file mode 100644 index 00000000..1c619973 --- /dev/null +++ b/examples/find-popular-repo/main.go @@ -0,0 +1,86 @@ +package main + +import ( + "fmt" + "github.com/imroc/req/v2" + "errors" + "strconv" +) + +type Repo struct { + Name string `json:"name"` + Star int `json:"stargazers_count"` +} +type ErrorMessage struct { + Message string `json:"message"` +} + +func init() { + req.EnableDebug(true) + // Uncomment DevMode below if you want to see more details + // req.DevMode() +} + +var username = "imroc" + +func main() { + repo, star, err := findTheMostPopularRepo(username) + if err != nil { + fmt.Println(err) + return + } + fmt.Printf("The most popular repo of %s is %s, which have %d stars\n", username, repo, star) +} + + +func findTheMostPopularRepo(username string) (repo string, star int, err error) { + + var popularRepo Repo + var resp *req.Response + + for page := 1; ; page++ { + repos := []*Repo{} + errMsg := ErrorMessage{} + resp, err = req.SetHeader("Accept", "application/vnd.github.v3+json"). + SetQueryParams(map[string]string{ + "type": "owner", + "page": strconv.Itoa(page), + "per_page": "100", + "sort": "updated", + "direction": "desc", + }). + SetPathParam("username", username). + SetResult(&repos). + SetError(&errMsg). + Get("https://api.github.com/users/{username}/repos") + + if err != nil { + return + } + + if resp.IsSuccess() { // HTTP status `code >= 200 and <= 299` is considred as success + for _, repo := range repos { + if repo.Star >= popularRepo.Star { + popularRepo = *repo + } + } + if len(repo) == 100 { // Try Next page + continue + } + // All repos have been traversed, return the final result + repo = popularRepo.Name + star = popularRepo.Star + return + } else if resp.IsError() { // HTTP status `code >= 400` is considred as an error + // Extract the error message, wrap and return err + err = errors.New(errMsg.Message) + return + } + + // Unkown http status code, record and return error, here we can use + // MustString() to get response body, cuz body have already been read + // and no error returned. + err = fmt.Errorf("unkown error. status code %d; body: %s", resp.StatusCode, resp.MustString()) + return + } +} diff --git a/logger.go b/logger.go index 03a6727b..e87e61a4 100644 --- a/logger.go +++ b/logger.go @@ -20,7 +20,7 @@ func NewLogger(output io.Writer, prefix string, flag int) Logger { } func createDefaultLogger() Logger { - return NewLogger(os.Stderr, "", log.Ldate|log.Lmicroseconds) + return NewLogger(os.Stdout, "", log.Ldate|log.Lmicroseconds) } var _ Logger = (*logger)(nil) From 19236c12e2d738cb6eb573bb288b443727cb31f5 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 14:25:30 +0800 Subject: [PATCH 104/843] add a example --- README.md | 2 +- examples/find-popular-repo/Makefile | 2 ++ examples/find-popular-repo/README.md | 17 +++++++++++++++++ examples/find-popular-repo/go.sum | 4 ++-- examples/find-popular-repo/main.go | 4 ++-- 5 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 examples/find-popular-repo/Makefile create mode 100644 examples/find-popular-repo/README.md diff --git a/README.md b/README.md index 7a5a2405..fc4109e7 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ A golang http request library for humans. **Install** ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.6 +go get github.com/imroc/req/v2@v2.0.0-alpha.7 ``` **Import** diff --git a/examples/find-popular-repo/Makefile b/examples/find-popular-repo/Makefile new file mode 100644 index 00000000..50ee7bfd --- /dev/null +++ b/examples/find-popular-repo/Makefile @@ -0,0 +1,2 @@ +run: + go run . \ No newline at end of file diff --git a/examples/find-popular-repo/README.md b/examples/find-popular-repo/README.md new file mode 100644 index 00000000..63e20762 --- /dev/null +++ b/examples/find-popular-repo/README.md @@ -0,0 +1,17 @@ +# find-popular-repo + +This is a runable example of req, using this Github API: [List repositories for a user](https://docs.github.com/cn/rest/reference/repos#list-repositories-for-a-user) to find someone's the most popular github repo. + +## How to run + +```bash +make run +``` + +## Modify it + +Change the global `username` vairable to your own github username: + +```go +var username = "imroc" +``` \ No newline at end of file diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index b2958219..f4deef7a 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -2,8 +2,8 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.0.0-alpha.6 h1:VTa/8ZXkNByK11dPB7SxxKoYLr72+7drmS1ecU6b5Os= -github.com/imroc/req/v2 v2.0.0-alpha.6/go.mod h1:4WwxUxbU5+ETYyQ02G8SxeJoYq2REnH7jEqj/AHJxhQ= +github.com/imroc/req/v2 v2.0.0-alpha.7 h1:FMrqFmIQfcIlbxWunocPeL2iowfmtBBDQh25/w/MNm8= +github.com/imroc/req/v2 v2.0.0-alpha.7/go.mod h1:3POMCRC7mUbCcscEp9wpihSyZLUVYWqvmHnwTdL6kJY= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/examples/find-popular-repo/main.go b/examples/find-popular-repo/main.go index 1c619973..c0f1e503 100644 --- a/examples/find-popular-repo/main.go +++ b/examples/find-popular-repo/main.go @@ -1,9 +1,9 @@ package main import ( + "errors" "fmt" "github.com/imroc/req/v2" - "errors" "strconv" ) @@ -21,6 +21,7 @@ func init() { // req.DevMode() } +// Change the name if you want var username = "imroc" func main() { @@ -32,7 +33,6 @@ func main() { fmt.Printf("The most popular repo of %s is %s, which have %d stars\n", username, repo, star) } - func findTheMostPopularRepo(username string) (repo string, star int, err error) { var popularRepo Repo From dfa0f019e680a4fb037204b10914cf76d7dc128c Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 14:27:02 +0800 Subject: [PATCH 105/843] update README --- examples/find-popular-repo/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/find-popular-repo/README.md b/examples/find-popular-repo/README.md index 63e20762..5eae3292 100644 --- a/examples/find-popular-repo/README.md +++ b/examples/find-popular-repo/README.md @@ -1,6 +1,6 @@ # find-popular-repo -This is a runable example of req, using this Github API: [List repositories for a user](https://docs.github.com/cn/rest/reference/repos#list-repositories-for-a-user) to find someone's the most popular github repo. +This is a runable example of req, using the Github API [List repositories for a user](https://docs.github.com/cn/rest/reference/repos#list-repositories-for-a-user) to find someone's the most popular github repo. ## How to run From 2a2aaac460ddf494d1d47c6ac012df337819f7e9 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 14:32:57 +0800 Subject: [PATCH 106/843] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index fc4109e7..91f22691 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,8 @@ resp, err := req.SetHeader("Accept", "application/json"). * [Header](#Header) * [Cookie](#Cookie) +You can find more complete and runnable examples [here](examples). + ### Debug **Dump the content of request and response** From 001ea72aa1399610f97a5f1d0247151fb689897e Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 14:34:27 +0800 Subject: [PATCH 107/843] update README --- examples/find-popular-repo/Makefile | 2 -- examples/find-popular-repo/README.md | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) delete mode 100644 examples/find-popular-repo/Makefile diff --git a/examples/find-popular-repo/Makefile b/examples/find-popular-repo/Makefile deleted file mode 100644 index 50ee7bfd..00000000 --- a/examples/find-popular-repo/Makefile +++ /dev/null @@ -1,2 +0,0 @@ -run: - go run . \ No newline at end of file diff --git a/examples/find-popular-repo/README.md b/examples/find-popular-repo/README.md index 5eae3292..cbd79750 100644 --- a/examples/find-popular-repo/README.md +++ b/examples/find-popular-repo/README.md @@ -5,7 +5,7 @@ This is a runable example of req, using the Github API [List repositories for a ## How to run ```bash -make run +go run . ``` ## Modify it From 053edc8df0f7a7e635ba2a58d03bafc044fa2cee Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 17:05:52 +0800 Subject: [PATCH 108/843] update README --- README.md | 43 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 91f22691..7d916383 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,12 @@ A golang http request library for humans. -## Features +* [Features](#Features) +* [Quick Start](#Quick-Start) +* [API Design](#API-Design) +* [Examples](#Examples) + +## Features * Simple and chainable methods for client and request settings, rich syntax sugar, less code and more efficiency. * Automatically detect charset and decode it to utf-8. @@ -12,7 +17,7 @@ A golang http request library for humans. * All settings can be changed dynamically, making it possible to debug in the production environment. * Easy to integrate with existing code, just replace the Transport of existing http.Client, then you can dump content as req to debug APIs. -## Quick-Start +## Quick Start **Install** @@ -41,11 +46,11 @@ resp, err := client.R().Get("https://api.github.com/users/imroc") **Client Settings** ```go -// Create a client with custom client settings +// Create a client with custom client settings using chainable method client := req.C().SetUserAgent("my-custom-client").SetTimeout(5 * time.Second).DevMode() -// You can also configure the global client using the same chaining method, req wraps glbal -// method for default client +// You can also configure the global client using the same chainable method, +// req wraps global method for default client req.SetUserAgent("my-custom-client").SetTimeout(5 * time.Second).DevMode() ``` @@ -80,7 +85,33 @@ resp, err := req.SetHeader("Accept", "application/json"). Get(url) ``` -## Examples +## API Design + +**Global Wrapper for Testing Purposes** + +`req` wraps global methods of `Client` and `Request` for testing purposes, so that you don't even need to create the client or request explicitly, just test API with minimal code like this: + +```go +req.DevMode().SetCommonBasicAuth("imroc", "123456") +req.SetBodyJsonString(`{"nickname":"roc", "email":"roc@imroc.cc"}`).Post("https://api.exmaple.com/profile") +``` + +**Conmmon Methods and Override** + +There are some similar methods between `Client` and `Request`, the pattern is like `Request.SetXXX` corresponding to `Client.SetCommonXXX`, client settings take effect for all requests, but will be overridden if the request sets the same setting. + +The common methods list is: + +* `Request.SetHeader` vs `Client.SetCommonHeader` +* `Request.SetHeaders` vs `Client.SetCommonHeaders` +* `Request.SetCookie` vs `Client.SetCommonCookie` +* `Request.SetCookies` vs `Client.SetCommonCookies` +* `Request.SetBasicAuth` vs `Client.SetCommonBasicAuth` +* `Request.SetQueryParam` vs `Client.SetCommonQueryParam` +* `Request.SetQueryParams` vs `Client.SetCommonQueryParams` +* `Request.SetQueryParamString` vs `Client.SetCommonQueryParamString` + +## Examples * [Debug](#Debug) * [PathParam](#PathParam) From a066935aa81a719d035172f1e8246a200b2eae80 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 17:23:30 +0800 Subject: [PATCH 109/843] update README --- README.md | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 7d916383..88367a7b 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,10 @@ resp, err := req.SetHeader("Accept", "application/json"). ## API Design +**Settings and Chainable Methods** + +`Request` and `Client` is the most important object, and both need to support a lot of settings, req provide rich chainable methods out of the box, making it very simple and intuitive to initiate any request you want. + **Global Wrapper for Testing Purposes** `req` wraps global methods of `Client` and `Request` for testing purposes, so that you don't even need to create the client or request explicitly, just test API with minimal code like this: @@ -109,15 +113,15 @@ The common methods list is: * `Request.SetBasicAuth` vs `Client.SetCommonBasicAuth` * `Request.SetQueryParam` vs `Client.SetCommonQueryParam` * `Request.SetQueryParams` vs `Client.SetCommonQueryParams` -* `Request.SetQueryParamString` vs `Client.SetCommonQueryParamString` +* `Request.SetQueryString` vs `Client.SetCommonQueryString` ## Examples * [Debug](#Debug) -* [PathParam](#PathParam) -* [QueryParam](#QueryParam) -* [Header](#Header) -* [Cookie](#Cookie) +* [Set Path Parameter](#PathParam) +* [Set Query Parameter](#QueryParam) +* [Set Header](#Header) +* [Set Cookie](#Cookie) You can find more complete and runnable examples [here](examples). @@ -219,7 +223,7 @@ client := req.C().DevMode() client.R().Get("https://imroc.cc") ``` -### PathParam +### Set Path Parameter Use `PathParam` to replace variable in the url path: @@ -246,7 +250,7 @@ resp2, err := client.Get(url2) ... ``` -### QueryParam +### Set Query Parameter ```go client := req.C().DevMode() @@ -273,7 +277,7 @@ resp2, err := client.Get(url2) ... ``` -### Header +### Set Header ```go // Let's dump the header to see what's going on @@ -308,7 +312,7 @@ resp2, err := client.R().Get(url2) ... ``` -### Cookie +### Set Cookie ```go // Let's dump the header to see what's going on From fe558d37410905ed58836fa355bd4c178fa87a04 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 17:29:14 +0800 Subject: [PATCH 110/843] update README --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 88367a7b..dde85761 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,7 @@ You can find more complete and runnable examples [here](examples). **Dump the content of request and response** ```go -// Set EnableDump to true, dump all content to stderr by default, +// Set EnableDump to true, dump all content to stdout by default, // including both the header and body of all request and response client := req.C().EnableDump(true) client.R().Get("https://api.github.com/users/imroc") @@ -178,7 +178,7 @@ client := req.C(). EnableDumpOnlyHeader(). // Only dump the header of request and response EnableDumpAsync(). // Dump asynchronously to improve performance EnableDumpToFile("reqdump.log") // Dump to file without printing it out -client.Get(url) +client.R().Get(url) // Enable dump with fully customized settings opt := &req.DumpOptions{ @@ -197,10 +197,10 @@ opt.ResponseBody = false client.R().Get("https://www.baidu.com/") ``` -**Logging** +**Debug Logging** ```go -// Logging is enabled by default, but only output warning and error message to stderr. +// Logging is enabled by default, but only output warning and error message to stdout. // EnableDebug set to true to enable debug level message logging. client := req.C().EnableDebug(true) client.R().Get("https://api.github.com/users/imroc") From 941ba01d3102745247d0f1d0091ca4eb62684c35 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 17:58:01 +0800 Subject: [PATCH 111/843] refactor set cert --- README.md | 23 +++++++++++++++++++++++ client.go | 28 +++++++++++++++------------- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index dde85761..93b33b3c 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,7 @@ The common methods list is: * [Set Query Parameter](#QueryParam) * [Set Header](#Header) * [Set Cookie](#Cookie) +* [Set Cert](#Cert) You can find more complete and runnable examples [here](examples). @@ -368,6 +369,28 @@ resp1, err := client.R().Get(url1) resp2, err := client.R().Get(url2) ``` +### Set Cert + +```go +// Set root cert and client cert from file path +client := req.C(). + SetRootCertFromFile("/path/to/root/certs/pemFile1.pem", "/path/to/root/certs/pemFile2.pem", "/path/to/root/certs/pemFile3.pem"). // Set root cert from one or more pem files + SetCertFromFile("/path/to/client/certs/client.pem", "/path/to/client/certs/client.key") // Set client cert and key cert file + +// You can also set root cert from string +client.SetRootCertFromString("-----BEGIN CERTIFICATE-----XXXXXX-----END CERTIFICATE-----") + +// And set client cert with +cert1, err := tls.LoadX509KeyPair("/path/to/client/certs/client.pem", "/path/to/client/certs/client.key") +if err != nil { + log.Fatalf("ERROR client certificate: %s", err) +} +// ... + +// you can add more certs if you want +client.SetCert(cert1, cert2, cert3) +``` + ## License Req released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file diff --git a/client.go b/client.go index efc6abbd..f2c453e0 100644 --- a/client.go +++ b/client.go @@ -133,12 +133,12 @@ func (c *Client) SetCertFromFile(certFile, keyFile string) *Client { return c } -func SetCerts(certs ...tls.Certificate) *Client { - return defaultClient.SetCerts(certs...) +func SetCert(certs ...tls.Certificate) *Client { + return defaultClient.SetCert(certs...) } -// SetCerts helps to set client certificates -func (c *Client) SetCerts(certs ...tls.Certificate) *Client { +// SetCert helps to set client certificates +func (c *Client) SetCert(certs ...tls.Certificate) *Client { config := c.tlsConfig() config.Certificates = append(config.Certificates, certs...) return c @@ -163,18 +163,20 @@ func (c *Client) SetRootCertFromString(pemContent string) *Client { return c } -func SetRootCertFromFile(pemFilePath string) *Client { - return defaultClient.SetRootCertFromFile(pemFilePath) +func SetRootCertFromFile(pemFiles ...string) *Client { + return defaultClient.SetRootCertFromFile(pemFiles...) } -// SetRootCertFromFile helps to set root CA cert from file -func (c *Client) SetRootCertFromFile(pemFilePath string) *Client { - rootPemData, err := ioutil.ReadFile(pemFilePath) - if err != nil { - c.log.Errorf("failed to read root cert file: %v", err) - return c +// SetRootCertFromFile helps to set root cert from files +func (c *Client) SetRootCertFromFile(pemFiles ...string) *Client { + for _, pemFile := range pemFiles { + rootPemData, err := ioutil.ReadFile(pemFile) + if err != nil { + c.log.Errorf("failed to read root cert file: %v", err) + return c + } + c.appendRootCertData(rootPemData) } - c.appendRootCertData(rootPemData) return c } From 6df917d7153a9fd75458880204d0fd602baeaeb6 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 19:47:38 +0800 Subject: [PATCH 112/843] support bearer token and update doc --- README.md | 35 ++++++++++++++++++++++++++++------- client.go | 8 ++++++++ request.go | 11 +++++++++-- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 93b33b3c..c5dbce3a 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,7 @@ The common methods list is: * [Set Header](#Header) * [Set Cookie](#Cookie) * [Set Cert](#Cert) +* [Set Basic Auth and Bearer Token](#Auth) You can find more complete and runnable examples [here](examples). @@ -175,10 +176,9 @@ x-github-request-id: AF10:6205:BA107D:D614F2:61EA7D7E */ // Dump header content asynchronously and save it to file -client := req.C(). - EnableDumpOnlyHeader(). // Only dump the header of request and response - EnableDumpAsync(). // Dump asynchronously to improve performance - EnableDumpToFile("reqdump.log") // Dump to file without printing it out +client.EnableDumpOnlyHeader(). // Only dump the header of request and response + EnableDumpAsync(). // Dump asynchronously to improve performance + EnableDumpToFile("reqdump.log") // Dump to file without printing it out client.R().Get(url) // Enable dump with fully customized settings @@ -190,7 +190,7 @@ opt := &req.DumpOptions{ ResponseHeader: false, Async: false, } -client := req.C().SetDumpOptions(opt).EnableDump(true) +client.SetDumpOptions(opt).EnableDump(true) client.R().Get("https://www.baidu.com/") // Change settings dynamiclly @@ -230,6 +230,7 @@ Use `PathParam` to replace variable in the url path: ```go client := req.C().DevMode() + client.R(). SetPathParam("owner", "imroc"). // Set a path param, which will replace the vairable in url path SetPathParams(map[string]string{ // Set multiple path params at once @@ -255,6 +256,7 @@ resp2, err := client.Get(url2) ```go client := req.C().DevMode() + client.R(). SetQueryParam("a", "a"). // Set a query param, which will be encoded as query parameter in url SetQueryParams(map[string]string{ // Set multiple query params at once @@ -372,9 +374,10 @@ resp2, err := client.R().Get(url2) ### Set Cert ```go +client := req.R() + // Set root cert and client cert from file path -client := req.C(). - SetRootCertFromFile("/path/to/root/certs/pemFile1.pem", "/path/to/root/certs/pemFile2.pem", "/path/to/root/certs/pemFile3.pem"). // Set root cert from one or more pem files +client.SetRootCertFromFile("/path/to/root/certs/pemFile1.pem", "/path/to/root/certs/pemFile2.pem", "/path/to/root/certs/pemFile3.pem"). // Set root cert from one or more pem files SetCertFromFile("/path/to/client/certs/client.pem", "/path/to/client/certs/client.key") // Set client cert and key cert file // You can also set root cert from string @@ -391,6 +394,24 @@ if err != nil { client.SetCert(cert1, cert2, cert3) ``` +### Set Basic Auth and Bearer Token + +```go +client := req.C() + +// Set basic auth for all request +client.SetCommonBasicAuth("imroc", "123456") + +// Set bearer token for all request +client.SetCommonBearerToken("MDc0ZTg5YmU4Yzc5MjAzZGJjM2ZiMzkz") + +// Set basic auth for a request, will override client's basic auth setting. +client.R().SetBasicAuth("myusername", "mypassword").Get("https://api.example.com/profile") + +// Set bearer token for a request, will override client's bearer token setting. +client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.example.com/profile") +``` + ## License Req released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file diff --git a/client.go b/client.go index f2c453e0..a96c7018 100644 --- a/client.go +++ b/client.go @@ -572,6 +572,14 @@ func (c *Client) SetUserAgent(userAgent string) *Client { return c.SetCommonHeader(hdrUserAgentKey, userAgent) } +func SetCommonBearerAuthToken(token string) *Client { + return defaultClient.SetCommonBearerAuthToken(token) +} + +func (c *Client) SetCommonBearerAuthToken(token string) *Client { + return c.SetCommonHeader("Authorization", "Bearer "+token) +} + func SetCommonBasicAuth(username, password string) *Client { return defaultClient.SetCommonBasicAuth(username, password) } diff --git a/request.go b/request.go index be4e8925..08eb79c5 100644 --- a/request.go +++ b/request.go @@ -84,13 +84,20 @@ func (r *Request) SetError(error interface{}) *Request { return r } +func SetBearerAuthToken(token string) *Request { + return defaultClient.R().SetBearerAuthToken(token) +} + +func (r *Request) SetBearerAuthToken(token string) *Request { + return r.SetHeader("Authorization", "Bearer "+token) +} + func SetBasicAuth(username, password string) *Request { return defaultClient.R().SetBasicAuth(username, password) } func (r *Request) SetBasicAuth(username, password string) *Request { - r.SetHeader("Authorization", util.BasicAuthHeaderValue(username, password)) - return r + return r.SetHeader("Authorization", util.BasicAuthHeaderValue(username, password)) } func SetHeaders(hdrs map[string]string) *Request { From 35c8eeae04724e815196baec74a7516ca48828df Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 19:49:48 +0800 Subject: [PATCH 113/843] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index c5dbce3a..91e045f2 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,7 @@ The common methods list is: * `Request.SetCookie` vs `Client.SetCommonCookie` * `Request.SetCookies` vs `Client.SetCommonCookies` * `Request.SetBasicAuth` vs `Client.SetCommonBasicAuth` +* `Request.SetBearerToken` vs `Client.SetCommonBearerToken` * `Request.SetQueryParam` vs `Client.SetCommonQueryParam` * `Request.SetQueryParams` vs `Client.SetCommonQueryParams` * `Request.SetQueryString` vs `Client.SetCommonQueryString` From 67b39df60e81f7eab488a53cf5b766950b30d30c Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 20:01:35 +0800 Subject: [PATCH 114/843] update README --- README.md | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 91e045f2..03c33a3f 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ A golang http request library for humans. ## Features * Simple and chainable methods for client and request settings, rich syntax sugar, less code and more efficiency. -* Automatically detect charset and decode it to utf-8. +* Automatically detect charset and decode it to utf-8 by default. * Powerful debugging capabilities (logging, tracing, and even dump the requests and responses' content). * All settings can be changed dynamically, making it possible to debug in the production environment. * Easy to integrate with existing code, just replace the Transport of existing http.Client, then you can dump content as req to debug APIs. @@ -34,7 +34,8 @@ import "github.com/imroc/req/v2" **Simple GET** ```go -// Create and send a request with the global default client +// Create and send a request with the global default client, use +// DevMode to see all details, try and suprise :) req.DevMode() resp, err := req.Get("https://api.github.com/users/imroc") @@ -96,7 +97,7 @@ resp, err := req.SetHeader("Accept", "application/json"). `req` wraps global methods of `Client` and `Request` for testing purposes, so that you don't even need to create the client or request explicitly, just test API with minimal code like this: ```go -req.DevMode().SetCommonBasicAuth("imroc", "123456") +req.SetCommonBasicAuth("imroc", "123456").DevMode() req.SetBodyJsonString(`{"nickname":"roc", "email":"roc@imroc.cc"}`).Post("https://api.exmaple.com/profile") ``` @@ -104,18 +105,6 @@ req.SetBodyJsonString(`{"nickname":"roc", "email":"roc@imroc.cc"}`).Post("https: There are some similar methods between `Client` and `Request`, the pattern is like `Request.SetXXX` corresponding to `Client.SetCommonXXX`, client settings take effect for all requests, but will be overridden if the request sets the same setting. -The common methods list is: - -* `Request.SetHeader` vs `Client.SetCommonHeader` -* `Request.SetHeaders` vs `Client.SetCommonHeaders` -* `Request.SetCookie` vs `Client.SetCommonCookie` -* `Request.SetCookies` vs `Client.SetCommonCookies` -* `Request.SetBasicAuth` vs `Client.SetCommonBasicAuth` -* `Request.SetBearerToken` vs `Client.SetCommonBearerToken` -* `Request.SetQueryParam` vs `Client.SetCommonQueryParam` -* `Request.SetQueryParams` vs `Client.SetCommonQueryParams` -* `Request.SetQueryString` vs `Client.SetCommonQueryString` - ## Examples * [Debug](#Debug) From 74adceb8232b5abe1d176462df94c3d1391f4ed3 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 21:03:08 +0800 Subject: [PATCH 115/843] update --- README.md | 79 ++++++++++++++++++++----------------------- client.go | 10 ++++++ internal/util/util.go | 12 +++++++ middleware.go | 32 +++++++++++++++--- request.go | 31 ++++++++--------- 5 files changed, 100 insertions(+), 64 deletions(-) diff --git a/README.md b/README.md index 03c33a3f..1a57167d 100644 --- a/README.md +++ b/README.md @@ -6,14 +6,20 @@ A golang http request library for humans. * [Features](#Features) * [Quick Start](#Quick-Start) -* [API Design](#API-Design) -* [Examples](#Examples) +* [Debug](#Debug) +* [Set Path and Query Parameter](#Param) +* [Set Header and Cookie](#Header-Cookie) +* [Set Cert](#Cert) +* [Set Basic Auth and Bearer Token](#Auth) +* [Use Global Methods](#Global) ## Features -* Simple and chainable methods for client and request settings, rich syntax sugar, less code and more efficiency. +* Simple and chainable methods for client and request settings, less code and more efficiency. +* Global wrapper of both `Client` and `Request` for testing purposes, so that you don't even need to create the client or request explicitly, make API testing minimal, and even replace tools like postman, curl with code (see [Use Global Methods](#Global)). +* There are some common settings between Client level and Request level, you can override Client settings at Request level if you want to (common settings pattern is `Request.SetXXX` vs `Client.SetCommonXXX`). * Automatically detect charset and decode it to utf-8 by default. -* Powerful debugging capabilities (logging, tracing, and even dump the requests and responses' content). +* Powerful debugging capabilities, including logging, tracing, and even dump the requests and responses' content (see [Debug](#Debug])). * All settings can be changed dynamically, making it possible to debug in the production environment. * Easy to integrate with existing code, just replace the Transport of existing http.Client, then you can dump content as req to debug APIs. @@ -86,38 +92,7 @@ resp, err := req.SetHeader("Accept", "application/json"). Get(url) ``` -## API Design - -**Settings and Chainable Methods** - -`Request` and `Client` is the most important object, and both need to support a lot of settings, req provide rich chainable methods out of the box, making it very simple and intuitive to initiate any request you want. - -**Global Wrapper for Testing Purposes** - -`req` wraps global methods of `Client` and `Request` for testing purposes, so that you don't even need to create the client or request explicitly, just test API with minimal code like this: - -```go -req.SetCommonBasicAuth("imroc", "123456").DevMode() -req.SetBodyJsonString(`{"nickname":"roc", "email":"roc@imroc.cc"}`).Post("https://api.exmaple.com/profile") -``` - -**Conmmon Methods and Override** - -There are some similar methods between `Client` and `Request`, the pattern is like `Request.SetXXX` corresponding to `Client.SetCommonXXX`, client settings take effect for all requests, but will be overridden if the request sets the same setting. - -## Examples - -* [Debug](#Debug) -* [Set Path Parameter](#PathParam) -* [Set Query Parameter](#QueryParam) -* [Set Header](#Header) -* [Set Cookie](#Cookie) -* [Set Cert](#Cert) -* [Set Basic Auth and Bearer Token](#Auth) - -You can find more complete and runnable examples [here](examples). - -### Debug +## Debug **Dump the content of request and response** @@ -214,7 +189,9 @@ client := req.C().DevMode() client.R().Get("https://imroc.cc") ``` -### Set Path Parameter +## Set Path and Query Parameter + +**Set Path Parameter** Use `PathParam` to replace variable in the url path: @@ -226,7 +203,7 @@ client.R(). SetPathParams(map[string]string{ // Set multiple path params at once "repo": "req", "path": "README.md", - }).Get("https://api.github.com/repos/{owner}/{repo}/contents/{path}") + }).Get("https://api.github.com/repos/{owner}/{repo}/contents/{path}") // path parameter will replace path variable in the url /* Output 2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md ... @@ -242,7 +219,7 @@ resp2, err := client.Get(url2) ... ``` -### Set Query Parameter +**Set Query Parameter** ```go client := req.C().DevMode() @@ -270,8 +247,9 @@ resp2, err := client.Get(url2) ... ``` -### Set Header +## Set Header and Cookie +**Set Header** ```go // Let's dump the header to see what's going on client := req.C().EnableDumpOnlyHeader() @@ -305,7 +283,7 @@ resp2, err := client.R().Get(url2) ... ``` -### Set Cookie +**Set Cookie** ```go // Let's dump the header to see what's going on @@ -361,7 +339,7 @@ resp1, err := client.R().Get(url1) resp2, err := client.R().Get(url2) ``` -### Set Cert +## Set Cert ```go client := req.R() @@ -384,7 +362,7 @@ if err != nil { client.SetCert(cert1, cert2, cert3) ``` -### Set Basic Auth and Bearer Token +## Set Basic Auth and Bearer Token ```go client := req.C() @@ -402,6 +380,21 @@ client.R().SetBasicAuth("myusername", "mypassword").Get("https://api.example.com client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.example.com/profile") ``` +## Use Global Methods + +`req` wrap methods of both `Client` and `Request` with global methods, very convenient when doing api testing, no need to explicitly create clients and requests to minimize the amount of code. + +```go +req.SetTimeout(5 * time.Second). + SetCommonBasicAuth("imroc", "123456"). + SetUserAgent("my api client"). + DevMode() + +req.SetQueryParam("page", "2"). + SetHeader("Accept", "application/json"). + Get("https://api.example.com/repos") +``` + ## License Req released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file diff --git a/client.go b/client.go index a96c7018..22d95312 100644 --- a/client.go +++ b/client.go @@ -49,6 +49,7 @@ type Client struct { XMLUnmarshal func(data []byte, v interface{}) error Debug bool + outputDirectory string disableAutoReadResponse bool scheme string log Logger @@ -117,6 +118,15 @@ func (c *Client) R() *Request { } } +func SetOutputDirectory(dir string) *Client { + return defaultClient.SetOutputDirectory(dir) +} + +func (c *Client) SetOutputDirectory(dir string) *Client { + c.outputDirectory = dir + return c +} + func SetCertFromFile(certFile, keyFile string) *Client { return defaultClient.SetCertFromFile(certFile, keyFile) } diff --git a/internal/util/util.go b/internal/util/util.go index 4286f1b6..43d2582d 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -3,6 +3,7 @@ package util import ( "bytes" "encoding/base64" + "os" "reflect" "regexp" "strings" @@ -82,3 +83,14 @@ func basicAuth(username, password string) string { func BasicAuthHeaderValue(username, password string) string { return "Basic " + basicAuth(username, password) } + +func CreateDirectory(dir string) (err error) { + if _, err = os.Stat(dir); err != nil { + if os.IsNotExist(err) { + if err = os.MkdirAll(dir, 0755); err != nil { + return + } + } + } + return +} diff --git a/middleware.go b/middleware.go index 74ae2165..9ed25e1a 100644 --- a/middleware.go +++ b/middleware.go @@ -7,6 +7,8 @@ import ( "io/ioutil" "net/http" "net/url" + "os" + "path/filepath" "strings" ) @@ -47,24 +49,44 @@ func parseResponseBody(c *Client, r *Response) (err error) { return } -func handleDownload(c *Client, r *Response) error { +func handleDownload(c *Client, r *Response) (err error) { if !r.Request.isSaveResponse { return nil } - var body io.ReadCloser + if r.body != nil { // already read body = ioutil.NopCloser(bytes.NewReader(r.body)) } else { body = r.Body } + var output io.WriteCloser + if r.Request.outputFile != "" { + file := r.Request.outputFile + if c.outputDirectory != "" && !filepath.IsAbs(file) { + file = c.outputDirectory + string(filepath.Separator) + file + } + + file = filepath.Clean(file) + + if err = util.CreateDirectory(filepath.Dir(file)); err != nil { + return err + } + output, err = os.Create(file) + if err != nil { + return + } + } else { + output = r.Request.output // must not nil + } + defer func() { body.Close() - r.Request.output.Close() + output.Close() }() - _, err := io.Copy(r.Request.output, body) - return err + _, err = io.Copy(output, body) + return } func parseRequestHeader(c *Client, r *Request) error { diff --git a/request.go b/request.go index 08eb79c5..ae374b1d 100644 --- a/request.go +++ b/request.go @@ -15,16 +15,18 @@ import ( // Request is the http request type Request struct { - URL string - PathParams map[string]string - QueryParams urlpkg.Values - Headers http.Header - Cookies []*http.Cookie - Result interface{} - Error interface{} - error error - client *Client - RawRequest *http.Request + URL string + PathParams map[string]string + QueryParams urlpkg.Values + Headers http.Header + Cookies []*http.Cookie + Result interface{} + Error interface{} + error error + client *Client + RawRequest *http.Request + + outputFile string isSaveResponse bool isMultiPart bool output io.WriteCloser @@ -129,12 +131,9 @@ func SetOutputFile(file string) *Request { } func (r *Request) SetOutputFile(file string) *Request { - output, err := os.Create(file) - if err != nil { - r.appendError(err) - return r - } - return r.SetOutput(output) + r.isSaveResponse = true + r.outputFile = file + return r } func SetOutput(output io.WriteCloser) *Request { From 033b75e10e28529ae93f5642c44087f1fad18e49 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 21:04:54 +0800 Subject: [PATCH 116/843] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 1a57167d..c640935d 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ A golang http request library for humans. +**Table of Contents** + * [Features](#Features) * [Quick Start](#Quick-Start) * [Debug](#Debug) From 3382c35636cab57647e981d477912fa8f7388675 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 24 Jan 2022 21:31:55 +0800 Subject: [PATCH 117/843] bump to v2.0.0-alpha.8 --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index c640935d..0fa5262d 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ A golang http request library for humans. * [Debug](#Debug) * [Set Path and Query Parameter](#Param) * [Set Header and Cookie](#Header-Cookie) -* [Set Cert](#Cert) +* [Set Certificates](#Cert) * [Set Basic Auth and Bearer Token](#Auth) * [Use Global Methods](#Global) @@ -21,7 +21,7 @@ A golang http request library for humans. * Global wrapper of both `Client` and `Request` for testing purposes, so that you don't even need to create the client or request explicitly, make API testing minimal, and even replace tools like postman, curl with code (see [Use Global Methods](#Global)). * There are some common settings between Client level and Request level, you can override Client settings at Request level if you want to (common settings pattern is `Request.SetXXX` vs `Client.SetCommonXXX`). * Automatically detect charset and decode it to utf-8 by default. -* Powerful debugging capabilities, including logging, tracing, and even dump the requests and responses' content (see [Debug](#Debug])). +* Powerful debugging capabilities, including logging, tracing, and even dump the requests and responses' content (see [Debug](#Debug)). * All settings can be changed dynamically, making it possible to debug in the production environment. * Easy to integrate with existing code, just replace the Transport of existing http.Client, then you can dump content as req to debug APIs. @@ -30,7 +30,7 @@ A golang http request library for humans. **Install** ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.7 +go get github.com/imroc/req/v2@v2.0.0-alpha.8 ``` **Import** @@ -341,7 +341,7 @@ resp1, err := client.R().Get(url1) resp2, err := client.R().Get(url2) ``` -## Set Cert +## Set Certificates ```go client := req.R() From a396a9f285d181be5d4cf27724891569afa9959e Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 10:20:41 +0800 Subject: [PATCH 118/843] EnableDebug-->EnableDebugLog and update README --- README.md | 18 ++++++++---------- client.go | 16 ++++++++-------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 0fa5262d..78c905a4 100644 --- a/README.md +++ b/README.md @@ -8,22 +8,20 @@ A golang http request library for humans. * [Features](#Features) * [Quick Start](#Quick-Start) -* [Debug](#Debug) +* [Debugging](#Debugging) * [Set Path and Query Parameter](#Param) * [Set Header and Cookie](#Header-Cookie) * [Set Certificates](#Cert) * [Set Basic Auth and Bearer Token](#Auth) -* [Use Global Methods](#Global) +* [Use Global Wrapper Methods](#Global) ## Features -* Simple and chainable methods for client and request settings, less code and more efficiency. -* Global wrapper of both `Client` and `Request` for testing purposes, so that you don't even need to create the client or request explicitly, make API testing minimal, and even replace tools like postman, curl with code (see [Use Global Methods](#Global)). -* There are some common settings between Client level and Request level, you can override Client settings at Request level if you want to (common settings pattern is `Request.SetXXX` vs `Client.SetCommonXXX`). -* Automatically detect charset and decode it to utf-8 by default. -* Powerful debugging capabilities, including logging, tracing, and even dump the requests and responses' content (see [Debug](#Debug)). -* All settings can be changed dynamically, making it possible to debug in the production environment. -* Easy to integrate with existing code, just replace the Transport of existing http.Client, then you can dump content as req to debug APIs. +* Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. +* Powerful [Debugging](#Debugging) capabilities, including debug logs, performance traces, and even dump complete request and response content. +* [Use Global Wrapper Methods](#Global) to test HTTP APIs with minimal code. +* Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default. +* Exportable `Transport`, just replace the Transport of existing http.Client with `*req.Transport`, then you can dump the content as `req` does to debug APIs with minimal code change. ## Quick Start @@ -94,7 +92,7 @@ resp, err := req.SetHeader("Accept", "application/json"). Get(url) ``` -## Debug +## Debugging **Dump the content of request and response** diff --git a/client.go b/client.go index 22d95312..cfd41930 100644 --- a/client.go +++ b/client.go @@ -46,8 +46,8 @@ type Client struct { JSONMarshal func(v interface{}) ([]byte, error) JSONUnmarshal func(data []byte, v interface{}) error XMLMarshal func(v interface{}) ([]byte, error) - XMLUnmarshal func(data []byte, v interface{}) error - Debug bool + XMLUnmarshal func(data []byte, v interface{}) error + DebugLog bool outputDirectory string disableAutoReadResponse bool @@ -315,12 +315,12 @@ const ( userAgentChrome = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36" ) -func EnableDebug(enable bool) *Client { - return defaultClient.EnableDebug(enable) +func EnableDebugLog(enable bool) *Client { + return defaultClient.EnableDebugLog(enable) } -func (c *Client) EnableDebug(enable bool) *Client { - c.Debug = enable +func (c *Client) EnableDebugLog(enable bool) *Client { + c.DebugLog = enable return c } @@ -333,7 +333,7 @@ func DevMode() *Client { // data from some sites. func (c *Client) DevMode() *Client { return c.EnableDumpAll(). - EnableDebug(true). + EnableDebugLog(true). SetUserAgent(userAgentChrome) } @@ -788,7 +788,7 @@ func (c *Client) do(r *Request) (resp *Response, err error) { setupRequest(r) - if c.Debug { + if c.DebugLog { c.log.Debugf("%s %s", r.RawRequest.Method, r.RawRequest.URL.String()) } From 41204ccd71b59dd3bfcc230fac3c9c3ffbf9e0b6 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 10:20:58 +0800 Subject: [PATCH 119/843] update README --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 78c905a4..a72ad4a8 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ A golang http request library for humans. * [Features](#Features) * [Quick Start](#Quick-Start) * [Debugging](#Debugging) -* [Set Path and Query Parameter](#Param) +* [Set Path Parameter and Query Parameter](#Param) * [Set Header and Cookie](#Header-Cookie) * [Set Certificates](#Cert) * [Set Basic Auth and Bearer Token](#Auth) @@ -20,7 +20,7 @@ A golang http request library for humans. * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. * Powerful [Debugging](#Debugging) capabilities, including debug logs, performance traces, and even dump complete request and response content. * [Use Global Wrapper Methods](#Global) to test HTTP APIs with minimal code. -* Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default. +* Detect the charset of response body and decode it to UTF-8 automatically to avoid garbled characters by default. * Exportable `Transport`, just replace the Transport of existing http.Client with `*req.Transport`, then you can dump the content as `req` does to debug APIs with minimal code change. ## Quick Start @@ -163,12 +163,12 @@ opt.ResponseBody = false client.R().Get("https://www.baidu.com/") ``` -**Debug Logging** +**Debug Log** ```go // Logging is enabled by default, but only output warning and error message to stdout. -// EnableDebug set to true to enable debug level message logging. -client := req.C().EnableDebug(true) +// EnableDebugLog set to true to enable debug level message logging. +client := req.C().EnableDebugLog(true) client.R().Get("https://api.github.com/users/imroc") // Output // 2022/01/23 14:33:04.755019 DEBUG [req] GET https://api.github.com/users/imroc @@ -189,7 +189,7 @@ client := req.C().DevMode() client.R().Get("https://imroc.cc") ``` -## Set Path and Query Parameter +## Set Path Parameter and Query Parameter **Set Path Parameter** From 4c53b47934ecbb97e3006a88f8280ded2d033e18 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 11:00:45 +0800 Subject: [PATCH 120/843] update README --- README.md | 68 +++++++++++++------------------------------------------ 1 file changed, 16 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index a72ad4a8..20b4a532 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![GoDoc](https://pkg.go.dev/badge/github.com/imroc/req.svg)](https://pkg.go.dev/github.com/imroc/req) -A golang http request library for humans. +Simplified golang http client library with magic, happy sending requests, less code and more efficiency. **Table of Contents** @@ -37,59 +37,23 @@ go get github.com/imroc/req/v2@v2.0.0-alpha.8 import "github.com/imroc/req/v2" ``` -**Simple GET** - ```go -// Create and send a request with the global default client, use -// DevMode to see all details, try and suprise :) +// For test, you can create and send a request with the global default +// client, use DevMode to see all details, try and suprise :) req.DevMode() -resp, err := req.Get("https://api.github.com/users/imroc") - -// Create and send a request with the custom client -client := req.C().DevMode() -resp, err := client.R().Get("https://api.github.com/users/imroc") -``` - -**Client Settings** - -```go -// Create a client with custom client settings using chainable method -client := req.C().SetUserAgent("my-custom-client").SetTimeout(5 * time.Second).DevMode() - -// You can also configure the global client using the same chainable method, -// req wraps global method for default client -req.SetUserAgent("my-custom-client").SetTimeout(5 * time.Second).DevMode() -``` - -**Request Settings** - -```go -// Use client to create a request with custom request settings -client := req.C().DevMode() -var result Result -resp, err := client.R(). - SetHeader("Accept", "application/json"). - SetQueryParam("page", "1"). - SetPathParam("userId", "imroc"). - SetResult(&result). - Get(url) - -// You can also create a request using global client using the same chaining method -resp, err := req.R(). - SetHeader("Accept", "application/json"). - SetQueryParam("page", "1"). - SetPathParam("userId", "imroc"). - SetResult(&result). - Get(url) - -// You can even also create a request without calling R(), cuz req -// wraps global method for request, and create a request using -// the default client automatically. -resp, err := req.SetHeader("Accept", "application/json"). +req.Get("https://api.github.com/users/imroc") + +// Create and send a request with the custom client and settings +client := req.C(). // Use C() to create a client + SetUserAgent("my-custom-client"). // Chainable client settings + SetTimeout(5 * time.Second). + DevMode() +resp, err := client.R(). // Use R() to create a request + SetHeader("Accept", "application/vnd.github.v3+json"). // Chainable request settings + SetPathParam("username", "imroc"). SetQueryParam("page", "1"). - SetPathParam("userId", "imroc"). SetResult(&result). - Get(url) + Get("https://api.github.com/users/{username}/repos") ``` ## Debugging @@ -140,11 +104,11 @@ x-github-request-id: AF10:6205:BA107D:D614F2:61EA7D7E {"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":128,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2021-07-08T12:11:23Z"} */ -// Dump header content asynchronously and save it to file +// Customize dump settings with predefined convenience settings. client.EnableDumpOnlyHeader(). // Only dump the header of request and response EnableDumpAsync(). // Dump asynchronously to improve performance EnableDumpToFile("reqdump.log") // Dump to file without printing it out -client.R().Get(url) +client.R().Get(url) // Send request to see the content that have been dumpped // Enable dump with fully customized settings opt := &req.DumpOptions{ From bc3170e8edb5065f64b0bbf13958a204dbed33f4 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 11:03:00 +0800 Subject: [PATCH 121/843] update README --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 20b4a532..c31c7460 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,7 @@ client.R().Get("https://imroc.cc") **Set Path Parameter** -Use `PathParam` to replace variable in the url path: +Use `SetPathParam` or `SetPathParams` to replace variable in the url path: ```go client := req.C().DevMode() @@ -185,6 +185,8 @@ resp2, err := client.Get(url2) **Set Query Parameter** +Use `SetQueryParam`, `SetQueryParams` or `SetQueryString` to append url query parameter: + ```go client := req.C().DevMode() From 877a5dabd28c99fc6729c2b8981e057109088941 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 11:16:19 +0800 Subject: [PATCH 122/843] update README --- README.md | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index c31c7460..2d159d1b 100644 --- a/README.md +++ b/README.md @@ -13,13 +13,13 @@ Simplified golang http client library with magic, happy sending requests, less c * [Set Header and Cookie](#Header-Cookie) * [Set Certificates](#Cert) * [Set Basic Auth and Bearer Token](#Auth) -* [Use Global Wrapper Methods](#Global) +* [Testing with Global Wrapper Methods](#Global) ## Features * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. * Powerful [Debugging](#Debugging) capabilities, including debug logs, performance traces, and even dump complete request and response content. -* [Use Global Wrapper Methods](#Global) to test HTTP APIs with minimal code. +* [Testing with Global Wrapper Methods](#Global) with minimal code. * Detect the charset of response body and decode it to UTF-8 automatically to avoid garbled characters by default. * Exportable `Transport`, just replace the Transport of existing http.Client with `*req.Transport`, then you can dump the content as `req` does to debug APIs with minimal code change. @@ -346,16 +346,23 @@ client.R().SetBasicAuth("myusername", "mypassword").Get("https://api.example.com client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.example.com/profile") ``` -## Use Global Methods +## Use Global Methods to Test -`req` wrap methods of both `Client` and `Request` with global methods, very convenient when doing api testing, no need to explicitly create clients and requests to minimize the amount of code. +`req` wrap methods of both `Client` and `Request` with global methods, which is delegated to default client, it's very convenient when making API test. ```go +// Call the global methods just like the Client's methods, +// so you can treat package name `req` as a Client, and +// you don't need to create any client explicitly. req.SetTimeout(5 * time.Second). SetCommonBasicAuth("imroc", "123456"). SetUserAgent("my api client"). DevMode() +// Call the global method just like the Request's method, +// which will create request automatically using the default +// client, so you can treat package name `req` as a Request, +// and you don't need to create request explicitly. req.SetQueryParam("page", "2"). SetHeader("Accept", "application/json"). Get("https://api.example.com/repos") From 02e05ea1c24e2624154728c0539769d13c589d74 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 11:17:07 +0800 Subject: [PATCH 123/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2d159d1b..b401193e 100644 --- a/README.md +++ b/README.md @@ -346,7 +346,7 @@ client.R().SetBasicAuth("myusername", "mypassword").Get("https://api.example.com client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.example.com/profile") ``` -## Use Global Methods to Test +## Testing with Use Global Methods `req` wrap methods of both `Client` and `Request` with global methods, which is delegated to default client, it's very convenient when making API test. From aca96044731cb701c0fa920bfb1340c130af0ce8 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 12:12:14 +0800 Subject: [PATCH 124/843] small refactor --- README.md | 2 +- client.go | 12 ++++++++++-- decode.go | 7 ++++++- internal/charsetutil/charsetutil.go | 10 ++++++++-- request.go | 1 - transport.go | 6 ++++-- 6 files changed, 29 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index b401193e..777ad53a 100644 --- a/README.md +++ b/README.md @@ -370,4 +370,4 @@ req.SetQueryParam("page", "2"). ## License -Req released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file +`Req` released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file diff --git a/client.go b/client.go index cfd41930..ffb92c45 100644 --- a/client.go +++ b/client.go @@ -46,8 +46,8 @@ type Client struct { JSONMarshal func(v interface{}) ([]byte, error) JSONUnmarshal func(data []byte, v interface{}) error XMLMarshal func(v interface{}) ([]byte, error) - XMLUnmarshal func(data []byte, v interface{}) error - DebugLog bool + XMLUnmarshal func(data []byte, v interface{}) error + DebugLog bool outputDirectory string disableAutoReadResponse bool @@ -382,6 +382,7 @@ func SetResponseOptions(opt *ResponseOptions) *Client { // SetResponseOptions set the ResponseOptions for the underlying Transport. func (c *Client) SetResponseOptions(opt *ResponseOptions) *Client { if opt == nil { + c.log.Warnf("ignore nil *ResponseOptions") return c } c.t.ResponseOptions = opt @@ -729,6 +730,7 @@ func (c *Client) Clone() *Client { // C create a new client. func C() *Client { t := &Transport{ + ResponseOptions: &ResponseOptions{}, ForceAttemptHTTP2: true, Proxy: http.ProxyFromEnvironment, MaxIdleConns: 100, @@ -764,6 +766,12 @@ func C() *Client { XMLMarshal: xml.Marshal, XMLUnmarshal: xml.Unmarshal, } + + t.DebugFunc = func(format string, v ...interface{}) { + if c.DebugLog { + c.log.Debugf(format, v...) + } + } return c } diff --git a/decode.go b/decode.go index 473b4e62..4a24b1a3 100644 --- a/decode.go +++ b/decode.go @@ -14,8 +14,13 @@ func (d *decodeReaderCloser) Read(p []byte) (n int, err error) { return d.decodeReader.Read(p) } +func newAutoDecodeReadCloser(input io.ReadCloser, t *Transport) *autoDecodeReadCloser { + return &autoDecodeReadCloser{ReadCloser: input, t: t} +} + type autoDecodeReadCloser struct { io.ReadCloser + t *Transport decodeReader io.Reader detected bool peek []byte @@ -27,7 +32,7 @@ func (a *autoDecodeReadCloser) peekRead(p []byte) (n int, err error) { return } a.detected = true - enc := charsetutil.FindEncoding(p) + enc := charsetutil.FindEncoding(p, a.t.DebugFunc) if enc == nil { return } diff --git a/internal/charsetutil/charsetutil.go b/internal/charsetutil/charsetutil.go index 7c3d05bf..1ee956e1 100644 --- a/internal/charsetutil/charsetutil.go +++ b/internal/charsetutil/charsetutil.go @@ -17,7 +17,7 @@ var boms = []struct { {[]byte{0xef, 0xbb, 0xbf}, "utf-8"}, } -func FindEncoding(content []byte) encoding.Encoding { +func FindEncoding(content []byte, debugf func(format string, v ...interface{})) encoding.Encoding { if len(content) == 0 { return nil } @@ -29,7 +29,13 @@ func FindEncoding(content []byte) encoding.Encoding { } } } - e, _ := prescan(content) + e, name := prescan(content) + if strings.ToLower(name) == "utf-8" { + if debugf != nil { + debugf("%s charset found in the meta tag content, no need to decode", name) + } + return nil + } if e != nil { return e } diff --git a/request.go b/request.go index ae374b1d..dd46d127 100644 --- a/request.go +++ b/request.go @@ -9,7 +9,6 @@ import ( "io/ioutil" "net/http" urlpkg "net/url" - "os" "strings" ) diff --git a/transport.go b/transport.go index 1419f0d1..af131f89 100644 --- a/transport.go +++ b/transport.go @@ -270,7 +270,9 @@ type Transport struct { ForceAttemptHTTP2 bool *ResponseOptions - dump *dumper + + dump *dumper + DebugFunc func(format string, v ...interface{}) } func (t *Transport) handleResponseBody(res *http.Response) { @@ -326,7 +328,7 @@ func (t *Transport) autoDecodeResponseBody(res *http.Response) { res.Body = &decodeReaderCloser{res.Body, decodeReader} return } - res.Body = &autoDecodeReadCloser{ReadCloser: res.Body} + res.Body = newAutoDecodeReadCloser(res.Body, t) } // A cancelKey is the key of the reqCanceler map. From 684289acb6c30e516cff307c593e3932e7896698 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 12:13:23 +0800 Subject: [PATCH 125/843] rename DebugFunc to Debugf --- client.go | 2 +- decode.go | 2 +- transport.go | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index ffb92c45..9f837699 100644 --- a/client.go +++ b/client.go @@ -767,7 +767,7 @@ func C() *Client { XMLUnmarshal: xml.Unmarshal, } - t.DebugFunc = func(format string, v ...interface{}) { + t.Debugf = func(format string, v ...interface{}) { if c.DebugLog { c.log.Debugf(format, v...) } diff --git a/decode.go b/decode.go index 4a24b1a3..82522f35 100644 --- a/decode.go +++ b/decode.go @@ -32,7 +32,7 @@ func (a *autoDecodeReadCloser) peekRead(p []byte) (n int, err error) { return } a.detected = true - enc := charsetutil.FindEncoding(p, a.t.DebugFunc) + enc := charsetutil.FindEncoding(p, a.t.Debugf) if enc == nil { return } diff --git a/transport.go b/transport.go index af131f89..e8e26e7d 100644 --- a/transport.go +++ b/transport.go @@ -271,8 +271,8 @@ type Transport struct { *ResponseOptions - dump *dumper - DebugFunc func(format string, v ...interface{}) + dump *dumper + Debugf func(format string, v ...interface{}) } func (t *Transport) handleResponseBody(res *http.Response) { From 9b83e2ca08054245e52e08c708251e9bed37c411 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 13:17:47 +0800 Subject: [PATCH 126/843] refactor response body --- response_body.go | 66 +++++++++++++++++------------------------------- 1 file changed, 23 insertions(+), 43 deletions(-) diff --git a/response_body.go b/response_body.go index 1ebf1af3..ec8c87ba 100644 --- a/response_body.go +++ b/response_body.go @@ -1,46 +1,24 @@ package req import ( - "encoding/json" - "encoding/xml" "io/ioutil" "strings" ) -func (r *Response) MustUnmarshalJson(v interface{}) { - err := r.UnmarshalJson(v) - if err != nil { - panic(err) - } -} - func (r *Response) UnmarshalJson(v interface{}) error { - b, err := r.Bytes() + b, err := r.ToBytes() if err != nil { return err } - return json.Unmarshal(b, v) -} - -func (r *Response) MustUnmarshalXml(v interface{}) { - err := r.UnmarshalXml(v) - if err != nil { - panic(err) - } + return r.Request.client.JSONUnmarshal(b, v) } func (r *Response) UnmarshalXml(v interface{}) error { - b, err := r.Bytes() + b, err := r.ToBytes() if err != nil { return err } - return xml.Unmarshal(b, v) -} -func (r *Response) MustUnmarshal(v interface{}) { - err := r.Unmarshal(v) - if err != nil { - panic(err) - } + return r.Request.client.XMLUnmarshal(b, v) } func (r *Response) Unmarshal(v interface{}) error { @@ -53,20 +31,30 @@ func (r *Response) Unmarshal(v interface{}) error { return r.UnmarshalJson(v) } -func (r *Response) MustString() string { - b, err := r.Bytes() - if err != nil { - panic(err) - } - return string(b) +// Bytes return the response body as []bytes that hava already been read, could be +// nil if not read, the following cases are already read: +// 1. `Request.SetResult` or `Request.SetError` is called. +// 2. `Client.DisableAutoReadResponse(false)` is not called, +// also `Request.SetOutput` and `Request.SetOutputFile` is not called. +func (r *Response) Bytes() []byte { + return r.body +} + +// String return the response body as string that hava already been read, could be +// nil if not read, the following cases are already read: +// 1. `Request.SetResult` or `Request.SetError` is called. +// 2. `Client.DisableAutoReadResponse(false)` is not called, +// also `Request.SetOutput` and `Request.SetOutputFile` is not called. +func (r *Response) String() string { + return string(r.body) } -func (r *Response) String() (string, error) { - b, err := r.Bytes() +func (r *Response) ToString() (string, error) { + b, err := r.ToBytes() return string(b), err } -func (r *Response) Bytes() ([]byte, error) { +func (r *Response) ToBytes() ([]byte, error) { if r.body != nil { return r.body, nil } @@ -78,11 +66,3 @@ func (r *Response) Bytes() ([]byte, error) { r.body = body return body, nil } - -func (r *Response) MustBytes() []byte { - b, err := r.Bytes() - if err != nil { - panic(err) - } - return b -} From f513a067f0e78a56cb60ce9bcda9d317d0d902c5 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 13:20:28 +0800 Subject: [PATCH 127/843] update example --- examples/find-popular-repo/main.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/find-popular-repo/main.go b/examples/find-popular-repo/main.go index c0f1e503..1916eb2d 100644 --- a/examples/find-popular-repo/main.go +++ b/examples/find-popular-repo/main.go @@ -3,6 +3,7 @@ package main import ( "errors" "fmt" + "github.com/imroc/req" "github.com/imroc/req/v2" "strconv" ) @@ -78,9 +79,9 @@ func findTheMostPopularRepo(username string) (repo string, star int, err error) } // Unkown http status code, record and return error, here we can use - // MustString() to get response body, cuz body have already been read - // and no error returned. - err = fmt.Errorf("unkown error. status code %d; body: %s", resp.StatusCode, resp.MustString()) + // String() to get response body, cuz response body have already been read + // and no error returned, do not need to use ToString(). + err = fmt.Errorf("unkown error. status code %d; body: %s", resp.StatusCode, resp.String()) return } } From 0a7bd2b92d0da0367a36b77774669f1679308905 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 13:20:59 +0800 Subject: [PATCH 128/843] bump to v2.0.0-alpha.9 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 777ad53a..9d2f518d 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ Simplified golang http client library with magic, happy sending requests, less c **Install** ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.8 +go get github.com/imroc/req/v2@v2.0.0-alpha.9 ``` **Import** From 2ab7eb58a667e946fd7f2c051ac5d19394ebc57e Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 13:26:18 +0800 Subject: [PATCH 129/843] fix Bytes() and MustString() --- README.md | 2 +- client.go | 2 +- examples/find-popular-repo/go.mod | 11 ++--------- examples/find-popular-repo/go.sum | 4 ++-- examples/find-popular-repo/main.go | 4 ++-- middleware.go | 2 +- 6 files changed, 9 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 9d2f518d..212e25f0 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ Simplified golang http client library with magic, happy sending requests, less c **Install** ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.9 +go get github.com/imroc/req/v2@v2.0.0-alpha.10 ``` **Import** diff --git a/client.go b/client.go index 9f837699..fcca1140 100644 --- a/client.go +++ b/client.go @@ -811,7 +811,7 @@ func (c *Client) do(r *Request) (resp *Response, err error) { } if !c.disableAutoReadResponse && !r.isSaveResponse { // auto read response body - _, err = resp.Bytes() + _, err = resp.ToBytes() if err != nil { return } diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index 5e797a41..4ef8552d 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -1,12 +1,5 @@ module find-popular-repo -go 1.18 +go 1.13 -require github.com/imroc/req/v2 v2.0.0-alpha.7 - -require ( - github.com/hashicorp/errwrap v1.0.0 // indirect - github.com/hashicorp/go-multierror v1.1.1 // indirect - golang.org/x/net v0.0.0-20220111093109-d55c255bac03 // indirect - golang.org/x/text v0.3.7 // indirect -) +require github.com/imroc/req/v2 v2.0.0-alpha.10 diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index f4deef7a..5ab73cda 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -2,8 +2,8 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.0.0-alpha.7 h1:FMrqFmIQfcIlbxWunocPeL2iowfmtBBDQh25/w/MNm8= -github.com/imroc/req/v2 v2.0.0-alpha.7/go.mod h1:3POMCRC7mUbCcscEp9wpihSyZLUVYWqvmHnwTdL6kJY= +github.com/imroc/req/v2 v2.0.0-alpha.9 h1:EQKXapWrxmLgYA9VoW7d1ha36J5rwQiEc2v5V/srgAw= +github.com/imroc/req/v2 v2.0.0-alpha.9/go.mod h1:3POMCRC7mUbCcscEp9wpihSyZLUVYWqvmHnwTdL6kJY= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/examples/find-popular-repo/main.go b/examples/find-popular-repo/main.go index 1916eb2d..22b922d4 100644 --- a/examples/find-popular-repo/main.go +++ b/examples/find-popular-repo/main.go @@ -3,9 +3,9 @@ package main import ( "errors" "fmt" - "github.com/imroc/req" - "github.com/imroc/req/v2" "strconv" + + "github.com/imroc/req/v2" ) type Repo struct { diff --git a/middleware.go b/middleware.go index 9ed25e1a..4253c102 100644 --- a/middleware.go +++ b/middleware.go @@ -24,7 +24,7 @@ func parseResponseBody(c *Client, r *Response) (err error) { if r.StatusCode == http.StatusNoContent { return } - body, err := r.Bytes() // in case req.SetResult with cient.DisalbeAutoReadResponse(true) + body, err := r.ToBytes() // in case req.SetResult or req.SetError with cient.DisalbeAutoReadResponse(true) if err != nil { return } From e2e234e075306fefb58c6a3fc619ddf475126af5 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 13:31:24 +0800 Subject: [PATCH 130/843] fix example --- examples/find-popular-repo/go.sum | 4 ++-- examples/find-popular-repo/main.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index 5ab73cda..64e8d105 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -2,8 +2,8 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.0.0-alpha.9 h1:EQKXapWrxmLgYA9VoW7d1ha36J5rwQiEc2v5V/srgAw= -github.com/imroc/req/v2 v2.0.0-alpha.9/go.mod h1:3POMCRC7mUbCcscEp9wpihSyZLUVYWqvmHnwTdL6kJY= +github.com/imroc/req/v2 v2.0.0-alpha.10 h1:twsgj8MfXqMXg1bzzIeHM5j8lXhpBlHYj5PW7Lj7gW8= +github.com/imroc/req/v2 v2.0.0-alpha.10/go.mod h1:3POMCRC7mUbCcscEp9wpihSyZLUVYWqvmHnwTdL6kJY= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/examples/find-popular-repo/main.go b/examples/find-popular-repo/main.go index 22b922d4..9cc4a9e1 100644 --- a/examples/find-popular-repo/main.go +++ b/examples/find-popular-repo/main.go @@ -17,7 +17,7 @@ type ErrorMessage struct { } func init() { - req.EnableDebug(true) + req.EnableDebugLog(true) // Uncomment DevMode below if you want to see more details // req.DevMode() } From 8ab401c33c7dd046bd7abcea252cac4dcf305364 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 13:44:28 +0800 Subject: [PATCH 131/843] update README --- README.md | 78 ++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 212e25f0..3da618ec 100644 --- a/README.md +++ b/README.md @@ -9,11 +9,11 @@ Simplified golang http client library with magic, happy sending requests, less c * [Features](#Features) * [Quick Start](#Quick-Start) * [Debugging](#Debugging) -* [Set Path Parameter and Query Parameter](#Param) -* [Set Header and Cookie](#Header-Cookie) -* [Set Certificates](#Cert) -* [Set Basic Auth and Bearer Token](#Auth) * [Testing with Global Wrapper Methods](#Global) +* [Path Parameter and Query Parameter](#Param) +* [Header and Cookie](#Header-Cookie) +* [Custom Client and Root Certificates](#Cert) +* [Basic Auth and Bearer Token](#Auth) ## Features @@ -146,14 +146,36 @@ client.SetLogger(logger) **DevMode** -If you want to enable all debug features (dump, debug logging and tracing), just call `DevMode()`: +If you want to enable all debug features (dump, debug log and tracing), just call `DevMode()`: ```go client := req.C().DevMode() client.R().Get("https://imroc.cc") ``` -## Set Path Parameter and Query Parameter +## Testing with Use Global Methods + +`req` wrap methods of both `Client` and `Request` with global methods, which is delegated to default client, it's very convenient when making API test. + +```go +// Call the global methods just like the Client's methods, +// so you can treat package name `req` as a Client, and +// you don't need to create any client explicitly. +req.SetTimeout(5 * time.Second). + SetCommonBasicAuth("imroc", "123456"). + SetUserAgent("my api client"). + DevMode() + +// Call the global method just like the Request's method, +// which will create request automatically using the default +// client, so you can treat package name `req` as a Request, +// and you don't need to create request explicitly. +req.SetQueryParam("page", "2"). + SetHeader("Accept", "application/json"). + Get("https://api.example.com/repos") +``` + +## Path Parameter and Query Parameter **Set Path Parameter** @@ -213,7 +235,7 @@ resp2, err := client.Get(url2) ... ``` -## Set Header and Cookie +## Header and Cookie **Set Header** ```go @@ -305,7 +327,7 @@ resp1, err := client.R().Get(url1) resp2, err := client.R().Get(url2) ``` -## Set Certificates +## Custom Client and Root Certificates ```go client := req.R() @@ -328,7 +350,7 @@ if err != nil { client.SetCert(cert1, cert2, cert3) ``` -## Set Basic Auth and Bearer Token +## Basic Auth and Bearer Token ```go client := req.C() @@ -346,28 +368,32 @@ client.R().SetBasicAuth("myusername", "mypassword").Get("https://api.example.com client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.example.com/profile") ``` -## Testing with Use Global Methods - -`req` wrap methods of both `Client` and `Request` with global methods, which is delegated to default client, it's very convenient when making API test. +## Download and Upload ```go -// Call the global methods just like the Client's methods, -// so you can treat package name `req` as a Client, and -// you don't need to create any client explicitly. -req.SetTimeout(5 * time.Second). - SetCommonBasicAuth("imroc", "123456"). - SetUserAgent("my api client"). - DevMode() +// Create a client with default download direcotry +client := req.C().SetOutputDirectory("/path/to/download") + +// Download to relative file path, this will be downloaded +// to /path/to/download/test.jpg +client.R().SetOutputFile("test.jpg").Get(url) + +// Download to absolute file path, ignore the output directory +// setting from Client +client.R().SetOutputFile("/tmp/test.jpg").Get(url) + +// You can also save file to any `io.WriteCloser` +file, err := os.Create("/tmp/test.jpg") +if err != nil { + fmt.Println(err) + return +} +client.R().SetOutput(file).Get(url) -// Call the global method just like the Request's method, -// which will create request automatically using the default -// client, so you can treat package name `req` as a Request, -// and you don't need to create request explicitly. -req.SetQueryParam("page", "2"). - SetHeader("Accept", "application/json"). - Get("https://api.example.com/repos") ``` + + ## License `Req` released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file From f45d6d77d80c91a8440f2fd558c783091de32088 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 17:36:38 +0800 Subject: [PATCH 132/843] support upload --- client.go | 30 ++++++++ examples/upload/uploadclient/go.mod | 3 + examples/upload/uploadserver/go.mod | 5 ++ examples/upload/uploadserver/go.sum | 54 ++++++++++++++ examples/upload/uploadserver/main.go | 35 ++++++++++ file.go | 9 +++ go.mod | 1 + go.sum | 48 +++++++++++++ middleware.go | 101 +++++++++++++++++++++++++++ request.go | 81 ++++++++++++++++++++- 10 files changed, 366 insertions(+), 1 deletion(-) create mode 100644 examples/upload/uploadclient/go.mod create mode 100644 examples/upload/uploadserver/go.mod create mode 100644 examples/upload/uploadserver/go.sum create mode 100644 examples/upload/uploadserver/main.go create mode 100644 file.go diff --git a/client.go b/client.go index fcca1140..fa83daac 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,7 @@ import ( var ( hdrUserAgentKey = "User-Agent" hdrUserAgentValue = "req/v2 (https://github.com/imroc/req)" + hdrContentTypeKey = "Content-Type" ) // DefaultClient returns the global default Client. @@ -457,6 +458,34 @@ func (c *Client) EnableDumpAsync() *Client { return c } +func EnableDumpNoRequestBody() *Client { + return defaultClient.EnableDumpNoRequestBody() +} + +func (c *Client) EnableDumpNoRequestBody() *Client { + o := c.GetDumpOptions() + o.ResponseHeader = true + o.ResponseBody = true + o.RequestBody = false + o.RequestHeader = true + c.enableDump() + return c +} + +func EnableDumpNoResponseBody() *Client { + return defaultClient.EnableDumpNoResponseBody() +} + +func (c *Client) EnableDumpNoResponseBody() *Client { + o := c.GetDumpOptions() + o.ResponseHeader = true + o.ResponseBody = false + o.RequestBody = true + o.RequestHeader = true + c.enableDump() + return c +} + func EnableDumpOnlyResponse() *Client { return defaultClient.EnableDumpOnlyResponse() } @@ -749,6 +778,7 @@ func C() *Client { parseRequestURL, parseRequestHeader, parseRequestCookie, + parseRequestBody, } afterResponse := []ResponseMiddleware{ parseResponseBody, diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod new file mode 100644 index 00000000..170e1f51 --- /dev/null +++ b/examples/upload/uploadclient/go.mod @@ -0,0 +1,3 @@ +module uploadclient + +go 1.18 diff --git a/examples/upload/uploadserver/go.mod b/examples/upload/uploadserver/go.mod new file mode 100644 index 00000000..94acf8b9 --- /dev/null +++ b/examples/upload/uploadserver/go.mod @@ -0,0 +1,5 @@ +module uploadserver + +go 1.13 + +require github.com/gin-gonic/gin v1.7.7 diff --git a/examples/upload/uploadserver/go.sum b/examples/upload/uploadserver/go.sum new file mode 100644 index 00000000..5ee9be12 --- /dev/null +++ b/examples/upload/uploadserver/go.sum @@ -0,0 +1,54 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= +github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= +github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= +github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42 h1:vEOn+mP2zCOVzKckCZy6YsCtDblrpj/w7B9nxGNELpg= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/examples/upload/uploadserver/main.go b/examples/upload/uploadserver/main.go new file mode 100644 index 00000000..d9c2c78a --- /dev/null +++ b/examples/upload/uploadserver/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "net/http" + "path/filepath" + + "github.com/gin-gonic/gin" +) + +func main() { + router := gin.Default() + router.POST("/upload", func(c *gin.Context) { + name := c.PostForm("name") + email := c.PostForm("email") + + // Multipart form + form, err := c.MultipartForm() + if err != nil { + c.String(http.StatusBadRequest, "get form err: %s", err.Error()) + return + } + files := form.File["files"] + + for _, file := range files { + filename := filepath.Base(file.Filename) + if err := c.SaveUploadedFile(file, filename); err != nil { + c.String(http.StatusBadRequest, "upload file err: %s", err.Error()) + return + } + } + + c.String(http.StatusOK, "Uploaded successfully %d files with fields name=%s and email=%s.", len(files), name, email) + }) + router.Run(":8888") +} diff --git a/file.go b/file.go new file mode 100644 index 00000000..52026d4d --- /dev/null +++ b/file.go @@ -0,0 +1,9 @@ +package req + +import "io" + +type uploadFile struct { + ParamName string + FilePath string + io.Reader +} diff --git a/go.mod b/go.mod index 8c7b6bc0..b27ffece 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/imroc/req/v2 go 1.13 require ( + github.com/gin-gonic/gin v1.7.7 // indirect github.com/hashicorp/go-multierror v1.1.1 golang.org/x/net v0.0.0-20220111093109-d55c255bac03 golang.org/x/text v0.3.7 diff --git a/go.sum b/go.sum index 59f8d4e3..80bdb59e 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,61 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= +github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= +github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/middleware.go b/middleware.go index 4253c102..f3240820 100644 --- a/middleware.go +++ b/middleware.go @@ -2,10 +2,13 @@ package req import ( "bytes" + "fmt" "github.com/imroc/req/v2/internal/util" "io" "io/ioutil" + "mime/multipart" "net/http" + "net/textproto" "net/url" "os" "path/filepath" @@ -20,6 +23,104 @@ type ( ResponseMiddleware func(*Client, *Response) error ) +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +func createMultipartHeader(param, fileName, contentType string) textproto.MIMEHeader { + hdr := make(textproto.MIMEHeader) + + var contentDispositionValue string + if util.IsStringEmpty(fileName) { + contentDispositionValue = fmt.Sprintf(`form-data; name="%s"`, param) + } else { + contentDispositionValue = fmt.Sprintf(`form-data; name="%s"; filename="%s"`, + param, escapeQuotes(fileName)) + } + hdr.Set("Content-Disposition", contentDispositionValue) + + if !util.IsStringEmpty(contentType) { + hdr.Set(hdrContentTypeKey, contentType) + } + return hdr +} + +func closeq(v interface{}) { + if c, ok := v.(io.Closer); ok { + c.Close() + } +} + +type multipartBody struct { + *io.PipeReader + closed bool +} + +func (r *multipartBody) Read(p []byte) (n int, err error) { + if r.closed { + return 0, io.EOF + } + n, err = r.PipeReader.Read(p) + if err != nil { + r.closed = true + err = nil + } + return +} + +func writeMultipartFormFile(w *multipart.Writer, fieldName, fileName string, r io.Reader) error { + defer closeq(r) + // Auto detect actual multipart content type + cbuf := make([]byte, 512) + size, err := r.Read(cbuf) + if err != nil && err != io.EOF { + return err + } + + pw, err := w.CreatePart(createMultipartHeader(fieldName, fileName, http.DetectContentType(cbuf))) + if err != nil { + return err + } + + if _, err = pw.Write(cbuf[:size]); err != nil { + return err + } + + _, err = io.Copy(pw, r) + return err +} + +func writeMultiPart(c *Client, r *Request, w *multipart.Writer, pw *io.PipeWriter) { + for k, vs := range r.FormData { + for _, v := range vs { + w.WriteField(k, v) + } + } + for _, file := range r.uploadFiles { + writeMultipartFormFile(w, file.ParamName, file.FilePath, file.Reader) + } + w.Close() // close multipart to write tailer boundary + pw.Close() // close pipe writer so that pipe reader could get EOF, and stop upload +} + +func handleMultiPart(c *Client, r *Request) (err error) { + pr, pw := io.Pipe() + r.RawRequest.Body = pr + w := multipart.NewWriter(pw) + r.RawRequest.Header.Set(hdrContentTypeKey, w.FormDataContentType()) + go writeMultiPart(c, r, w, pw) + return +} + +func parseRequestBody(c *Client, r *Request) (err error) { + if r.isMultiPart { + err = handleMultiPart(c, r) + } + return +} + func parseResponseBody(c *Client, r *Response) (err error) { if r.StatusCode == http.StatusNoContent { return diff --git a/request.go b/request.go index dd46d127..a97f5546 100644 --- a/request.go +++ b/request.go @@ -9,6 +9,7 @@ import ( "io/ioutil" "net/http" urlpkg "net/url" + "os" "strings" ) @@ -17,6 +18,7 @@ type Request struct { URL string PathParams map[string]string QueryParams urlpkg.Values + FormData urlpkg.Values Headers http.Header Cookies []*http.Cookie Result interface{} @@ -25,12 +27,44 @@ type Request struct { client *Client RawRequest *http.Request + isMultiPart bool + uploadFiles []*uploadFile + uploadReader []io.ReadCloser outputFile string isSaveResponse bool - isMultiPart bool output io.WriteCloser } +func SetFormDataFromValues(data urlpkg.Values) *Request { + return defaultClient.R().SetFormDataFromValues(data) +} + +func (r *Request) SetFormDataFromValues(data urlpkg.Values) *Request { + if r.FormData == nil { + r.FormData = urlpkg.Values{} + } + for k, v := range data { + for _, kv := range v { + r.FormData.Add(k, kv) + } + } + return r +} + +func SetFormData(data map[string]string) *Request { + return defaultClient.R().SetFormData(data) +} + +func (r *Request) SetFormData(data map[string]string) *Request { + if r.FormData == nil { + r.FormData = urlpkg.Values{} + } + for k, v := range data { + r.FormData.Set(k, v) + } + return r +} + func SetCookie(hc *http.Cookie) *Request { return defaultClient.R().SetCookie(hc) } @@ -67,6 +101,51 @@ func (r *Request) SetQueryString(query string) *Request { return r } +func SetFileReader(paramName, filePath string, reader io.Reader) *Request { + return defaultClient.R().SetFileReader(paramName, filePath, reader) +} + +func (r *Request) SetFileReader(paramName, filePath string, reader io.Reader) *Request { + r.isMultiPart = true + r.uploadFiles = append(r.uploadFiles, &uploadFile{ + ParamName: paramName, + FilePath: filePath, + Reader: reader, + }) + return r +} + +func SetFiles(files map[string]string) *Request { + return defaultClient.R().SetFiles(files) +} + +func (r *Request) SetFiles(files map[string]string) *Request { + for k, v := range files { + r.SetFile(k, v) + } + return r +} + +func SetFile(paramName, filePath string) *Request { + return defaultClient.R().SetFile(paramName, filePath) +} + +func (r *Request) SetFile(paramName, filePath string) *Request { + r.isMultiPart = true + file, err := os.Open(filePath) + if err != nil { + r.client.log.Errorf("failed to open %s: %v", filePath, err) + r.appendError(err) + return r + } + r.uploadFiles = append(r.uploadFiles, &uploadFile{ + ParamName: paramName, + FilePath: filePath, + Reader: file, + }) + return r +} + func SetResult(result interface{}) *Request { return defaultClient.R().SetResult(result) } From 097afdd42940f65ff1b9257c934688270941650d Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 17:46:53 +0800 Subject: [PATCH 133/843] add uploadclient example --- examples/upload/uploadclient/go.mod | 10 +++++- examples/upload/uploadclient/go.sum | 47 ++++++++++++++++++++++++++++ examples/upload/uploadclient/main.go | 29 +++++++++++++++++ 3 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 examples/upload/uploadclient/go.sum create mode 100644 examples/upload/uploadclient/main.go diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod index 170e1f51..bd99106d 100644 --- a/examples/upload/uploadclient/go.mod +++ b/examples/upload/uploadclient/go.mod @@ -1,3 +1,11 @@ module uploadclient -go 1.18 +go 1.13 + +require ( + github.com/hashicorp/errwrap v1.0.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/imroc/req/v2 v2.0.0-alpha.11 // indirect + golang.org/x/net v0.0.0-20220111093109-d55c255bac03 // indirect + golang.org/x/text v0.3.7 // indirect +) diff --git a/examples/upload/uploadclient/go.sum b/examples/upload/uploadclient/go.sum new file mode 100644 index 00000000..3875b817 --- /dev/null +++ b/examples/upload/uploadclient/go.sum @@ -0,0 +1,47 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/imroc/req/v2 v2.0.0-alpha.11 h1:VJuYAcVcPVJUd9YQ6Vv5S2djHhl9mQS+XKoOkfjAMHw= +github.com/imroc/req/v2 v2.0.0-alpha.11/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= +golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/examples/upload/uploadclient/main.go b/examples/upload/uploadclient/main.go new file mode 100644 index 00000000..69690cc5 --- /dev/null +++ b/examples/upload/uploadclient/main.go @@ -0,0 +1,29 @@ +package main + +import "github.com/imroc/req/v2" + +func main() { + req.EnableDumpNoRequestBody() + req.SetFile("files", "../../../README.md"). + SetFile("files", "../../../LICENSE"). + SetFormData(map[string]string{ + "name": "imroc", + "email": "roc@imroc.cc", + }). + Post("http://127.0.0.1:8888/upload") + /* Output + POST /upload HTTP/1.1 + Host: 127.0.0.1:8888 + User-Agent: req/v2 (https://github.com/imroc/req) + Transfer-Encoding: chunked + Content-Type: multipart/form-data; boundary=6af1b071a682709355cf5fb15b9cf9e793df7a45e5cd1eb7c413f2e72bf6 + Accept-Encoding: gzip + + HTTP/1.1 200 OK + Content-Type: text/plain; charset=utf-8 + Date: Tue, 25 Jan 2022 09:40:36 GMT + Content-Length: 76 + + Uploaded successfully 2 files with fields name=imroc and email=roc@imroc.cc. + */ +} From 382c476a4b990035e1390070bcc58d531da73ee7 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 18:01:26 +0800 Subject: [PATCH 134/843] update README --- README.md | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3da618ec..7a02c95f 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ Simplified golang http client library with magic, happy sending requests, less c **Install** ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.10 +go get github.com/imroc/req/v2@v2.0.0-alpha.11 ``` **Import** @@ -370,6 +370,8 @@ client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.e ## Download and Upload +**Download** + ```go // Create a client with default download direcotry client := req.C().SetOutputDirectory("/path/to/download") @@ -389,10 +391,28 @@ if err != nil { return } client.R().SetOutput(file).Get(url) - ``` +**Multipart Upload** + +```go +client := req.().EnableDumpNoRequestBody() // Request body contains unreadable binary, do not dump + +client.R().SetFile("pic", "test.jpg"). // Set form param name and filename + SetFile("pic", "/path/to/roc.png"). // Multiple files using the same form param name + SetFiles(map[string]string{ // Set multiple files using map + "exe": "test.exe", + "src": "main.go", + }). + SetFormData(map[string]string{ // Set from param using map + "name": "imroc", + "email": "roc@imroc.cc", + }). + SetFromDataFromValues(values). // You can also set form data using `url.Values` + Post("http://127.0.0.1:8888/upload") +*/ +``` ## License From bf600664c87a4aa11ba00a292787cf383fc1ad5a Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 18:03:32 +0800 Subject: [PATCH 135/843] update README --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7a02c95f..5ac7ef8f 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ Simplified golang http client library with magic, happy sending requests, less c * [Header and Cookie](#Header-Cookie) * [Custom Client and Root Certificates](#Cert) * [Basic Auth and Bearer Token](#Auth) +* [Download and Upload](#Download) ## Features @@ -374,7 +375,7 @@ client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.e ```go // Create a client with default download direcotry -client := req.C().SetOutputDirectory("/path/to/download") +client := req.C().SetOutputDirectory("/path/to/download").EnableDumpNoResponseBody() // Download to relative file path, this will be downloaded // to /path/to/download/test.jpg From 9c9b450ce6c9a8485a30822ffa59b9f3fd079f36 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 19:27:46 +0800 Subject: [PATCH 136/843] set form data without multipart --- client.go | 6 ------ common.go | 11 +++++++++++ content_type.go | 9 --------- middleware.go | 10 +++++++++- request.go | 10 +++++----- 5 files changed, 25 insertions(+), 21 deletions(-) create mode 100644 common.go delete mode 100644 content_type.go diff --git a/client.go b/client.go index fa83daac..ddade806 100644 --- a/client.go +++ b/client.go @@ -17,12 +17,6 @@ import ( "time" ) -var ( - hdrUserAgentKey = "User-Agent" - hdrUserAgentValue = "req/v2 (https://github.com/imroc/req)" - hdrContentTypeKey = "Content-Type" -) - // DefaultClient returns the global default Client. func DefaultClient() *Client { return defaultClient diff --git a/common.go b/common.go new file mode 100644 index 00000000..6bcb147a --- /dev/null +++ b/common.go @@ -0,0 +1,11 @@ +package req + +const ( + hdrUserAgentKey = "User-Agent" + hdrUserAgentValue = "req/v2 (https://github.com/imroc/req)" + hdrContentTypeKey = "Content-Type" + plainTextType = "text/plain; charset=utf-8" + jsonContentType = "application/json; charset=utf-8" + xmlContentType = "text/xml; charset=utf-8" + formContentType = "application/x-www-form-urlencoded" +) diff --git a/content_type.go b/content_type.go deleted file mode 100644 index 279a6cf9..00000000 --- a/content_type.go +++ /dev/null @@ -1,9 +0,0 @@ -package req - -const ( - CONTENT_TYPE_APPLICATION_JSON_UTF8 = "application/json; charset=UTF-8" - CONTENT_TYPE_APPLICATION_XML_UTF8 = "application/xml; charset=UTF-8" - CONTENT_TYPE_TEXT_XML_UTF8 = "text/xml; charset=UTF-8" - CONTENT_TYPE_TEXT_HTML_UTF8 = "text/html; charset=UTF-8" - CONTENT_TYPE_TEXT_PLAIN_UTF8 = "text/plain; charset=UTF-8" -) \ No newline at end of file diff --git a/middleware.go b/middleware.go index f3240820..03000eb8 100644 --- a/middleware.go +++ b/middleware.go @@ -114,9 +114,17 @@ func handleMultiPart(c *Client, r *Request) (err error) { return } +func handleFormData(r *Request) { + r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(r.FormData.Encode())) +} + func parseRequestBody(c *Client, r *Request) (err error) { if r.isMultiPart { - err = handleMultiPart(c, r) + return handleMultiPart(c, r) + } + if len(r.FormData) > 0 { + handleFormData(r) + return } return } diff --git a/request.go b/request.go index a97f5546..d2f3e06f 100644 --- a/request.go +++ b/request.go @@ -484,10 +484,10 @@ func SetBodyJsonString(body string) *Request { } // SetBodyJsonString set the request body as string and set Content-Type header -// as "application/json; charset=UTF-8" +// as "application/json; charset=utf-8" func (r *Request) SetBodyJsonString(body string) *Request { r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(body)) - r.SetContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) + r.SetContentType(jsonContentType) return r } @@ -496,10 +496,10 @@ func SetBodyJsonBytes(body []byte) *Request { } // SetBodyJsonBytes set the request body as []byte and set Content-Type header -// as "application/json; charset=UTF-8" +// as "application/json; charset=utf-8" func (r *Request) SetBodyJsonBytes(body []byte) *Request { r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) - r.SetContentType(CONTENT_TYPE_APPLICATION_JSON_UTF8) + r.SetContentType(jsonContentType) return r } @@ -508,7 +508,7 @@ func SetBodyJsonMarshal(v interface{}) *Request { } // SetBodyJsonMarshal set the request body that marshaled from object, and -// set Content-Type header as "application/json; charset=UTF-8" +// set Content-Type header as "application/json; charset=utf-8" func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { b, err := json.Marshal(v) if err != nil { From 21ed0c252f20d8a2e479fbe56a6eae958ab98f43 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 25 Jan 2022 20:15:31 +0800 Subject: [PATCH 137/843] refactor set body --- middleware.go | 37 ++++++++++++++++++------------------- request.go | 18 +++++++++++++++--- response.go | 2 +- 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/middleware.go b/middleware.go index 03000eb8..1841499e 100644 --- a/middleware.go +++ b/middleware.go @@ -129,31 +129,30 @@ func parseRequestBody(c *Client, r *Request) (err error) { return } -func parseResponseBody(c *Client, r *Response) (err error) { - if r.StatusCode == http.StatusNoContent { - return - } +func unmarshalBody(c *Client, r *Response, v interface{}) (err error) { body, err := r.ToBytes() // in case req.SetResult or req.SetError with cient.DisalbeAutoReadResponse(true) if err != nil { return } - // Handles only JSON or XML content type ct := util.FirstNonEmpty(r.GetContentType()) - if r.IsSuccess() && r.Request.Result != nil { - r.Request.Error = nil - if util.IsJSONType(ct) { - return c.JSONUnmarshal(body, r.Request.Result) - } else if util.IsXMLType(ct) { - return c.XMLUnmarshal(body, r.Request.Result) - } + if util.IsJSONType(ct) { + return c.JSONUnmarshal(body, v) + } else if util.IsXMLType(ct) { + return c.XMLUnmarshal(body, v) } - if r.IsError() && r.Request.Error != nil { - r.Request.Result = nil - if util.IsJSONType(ct) { - return c.JSONUnmarshal(body, r.Request.Error) - } else if util.IsXMLType(ct) { - return c.XMLUnmarshal(body, r.Request.Error) - } + return +} + +func parseResponseBody(c *Client, r *Response) (err error) { + if r.StatusCode == http.StatusNoContent { + return + } + // Handles only JSON or XML content type + if r.Request.Result != nil && r.IsSuccess() { + unmarshalBody(c, r, r.Request.Result) + } + if r.Request.Error != nil && r.IsError() { + unmarshalBody(c, r, r.Request.Error) } return } diff --git a/request.go b/request.go index d2f3e06f..46ff2f88 100644 --- a/request.go +++ b/request.go @@ -2,7 +2,6 @@ package req import ( "bytes" - "encoding/json" "github.com/hashicorp/go-multierror" "github.com/imroc/req/v2/internal/util" "io" @@ -510,7 +509,20 @@ func SetBodyJsonMarshal(v interface{}) *Request { // SetBodyJsonMarshal set the request body that marshaled from object, and // set Content-Type header as "application/json; charset=utf-8" func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { - b, err := json.Marshal(v) + b, err := r.client.JSONMarshal(v) + if err != nil { + r.appendError(err) + return r + } + return r.SetBodyBytes(b) +} + +func SetBodyXmlMarshal(v interface{}) *Request { + return defaultClient.R().SetBodyXmlMarshal(v) +} + +func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { + b, err := r.client.XMLMarshal(v) if err != nil { r.appendError(err) return r @@ -523,6 +535,6 @@ func SetContentType(contentType string) *Request { } func (r *Request) SetContentType(contentType string) *Request { - r.RawRequest.Header.Set("Content-Type", contentType) + r.RawRequest.Header.Set(hdrContentTypeKey, contentType) return r } diff --git a/response.go b/response.go index f3fba4e7..c3ca6bd0 100644 --- a/response.go +++ b/response.go @@ -50,7 +50,7 @@ func (r *Response) IsError() bool { } func (r *Response) GetContentType() string { - return r.Header.Get("Content-Type") + return r.Header.Get(hdrContentTypeKey) } // Result method returns the response value as an object if it has one From 5a42efa2579ab6fd49b25f6c9e0d2a1bd8357606 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 14:44:24 +0800 Subject: [PATCH 138/843] support trace --- README.md | 107 +++++++++++++++++++++------------- client.go | 32 ++++++++-- middleware.go | 1 + request.go | 91 ++++++++++++++++++++++++++++- response.go | 28 ++++++++- response_body.go | 1 + trace.go | 148 +++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 361 insertions(+), 47 deletions(-) create mode 100644 trace.go diff --git a/README.md b/README.md index 5ac7ef8f..5d015526 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,6 @@ Simplified golang http client library with magic, happy sending requests, less c * [Features](#Features) * [Quick Start](#Quick-Start) * [Debugging](#Debugging) -* [Testing with Global Wrapper Methods](#Global) * [Path Parameter and Query Parameter](#Param) * [Header and Cookie](#Header-Cookie) * [Custom Client and Root Certificates](#Cert) @@ -29,7 +28,7 @@ Simplified golang http client library with magic, happy sending requests, less c **Install** ``` sh -go get github.com/imroc/req/v2@v2.0.0-alpha.11 +go get github.com/imroc/req/v2@v2.0.0-beta.0 ``` **Import** @@ -59,57 +58,49 @@ resp, err := client.R(). // Use R() to create a request ## Debugging -**Dump the content of request and response** +**Dump the Content** ```go // Set EnableDump to true, dump all content to stdout by default, // including both the header and body of all request and response client := req.C().EnableDump(true) -client.R().Get("https://api.github.com/users/imroc") +client.R().Get("https://httpbin.org/get") /* Output -:authority: api.github.com +:authority: httpbin.org :method: GET -:path: /users/imroc +:path: /get :scheme: https -user-agent: req/v2 (https://github.com/imroc/req) +user-agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36 accept-encoding: gzip :status: 200 -server: GitHub.com -date: Fri, 21 Jan 2022 09:31:43 GMT -content-type: application/json; charset=utf-8 -cache-control: public, max-age=60, s-maxage=60 -vary: Accept, Accept-Encoding, Accept, X-Requested-With -etag: W/"fe5acddc5c01a01153ebc4068a1f067dadfa7a7dc9a025f44b37b0a0a50e2c55" -last-modified: Thu, 08 Jul 2021 12:11:23 GMT -x-github-media-type: github.v3; format=json -access-control-expose-headers: ETag, Link, Location, Retry-After, X-GitHub-OTP, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Used, X-RateLimit-Resource, X-RateLimit-Reset, X-OAuth-Scopes, X-Accepted-OAuth-Scopes, X-Poll-Interval, X-GitHub-Media-Type, X-GitHub-SSO, X-GitHub-Request-Id, Deprecation, Sunset +date: Wed, 26 Jan 2022 06:39:20 GMT +content-type: application/json +content-length: 372 +server: gunicorn/19.9.0 access-control-allow-origin: * -strict-transport-security: max-age=31536000; includeSubdomains; preload -x-frame-options: deny -x-content-type-options: nosniff -x-xss-protection: 0 -referrer-policy: origin-when-cross-origin, strict-origin-when-cross-origin -content-security-policy: default-src 'none' -content-encoding: gzip -x-ratelimit-limit: 60 -x-ratelimit-remaining: 59 -x-ratelimit-reset: 1642761103 -x-ratelimit-resource: core -x-ratelimit-used: 1 -accept-ranges: bytes -content-length: 486 -x-github-request-id: AF10:6205:BA107D:D614F2:61EA7D7E - -{"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":128,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2021-07-08T12:11:23Z"} +access-control-allow-credentials: true + +{ + "args": {}, + "headers": { + "Accept-Encoding": "gzip", + "Host": "httpbin.org", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36", + "X-Amzn-Trace-Id": "Root=1-61f0ec98-5958c02662de26e458b7672b" + }, + "origin": "103.7.29.30", + "url": "https://httpbin.org/get" +} */ // Customize dump settings with predefined convenience settings. client.EnableDumpOnlyHeader(). // Only dump the header of request and response EnableDumpAsync(). // Dump asynchronously to improve performance EnableDumpToFile("reqdump.log") // Dump to file without printing it out -client.R().Get(url) // Send request to see the content that have been dumpped +// Send request to see the content that have been dumpped +client.R().Get(url) // Enable dump with fully customized settings opt := &req.DumpOptions{ @@ -128,15 +119,17 @@ opt.ResponseBody = false client.R().Get("https://www.baidu.com/") ``` -**Debug Log** +**EnableDebugLog for Deeper Insights** ```go -// Logging is enabled by default, but only output warning and error message to stdout. -// EnableDebugLog set to true to enable debug level message logging. +// Logging is enabled by default, but only output the warning and error message. +// set `EnableDebugLog` to true to enable debug level logging. client := req.C().EnableDebugLog(true) client.R().Get("https://api.github.com/users/imroc") -// Output -// 2022/01/23 14:33:04.755019 DEBUG [req] GET https://api.github.com/users/imroc +/* Output +2022/01/23 14:33:04.755019 DEBUG [req] GET https://api.github.com/users/imroc +... +*/ // SetLogger with nil to disable all log client.SetLogger(nil) @@ -145,6 +138,40 @@ client.SetLogger(nil) client.SetLogger(logger) ``` +**EnableTrace to Analyze Performance** + +```go +// Enable trace at request level +client := req.C() +resp, err := client.R().EnableTrace(true).Get("https://api.github.com/users/imroc") +if err != nil { + log.Fatal(err) +} +ti := resp.TraceInfo() // Use `resp.Request.TraceInfo()` to avoid unnecessary copy in production +fmt.Println(ti) +fmt.Println("--------") +k, v := ti.MaxTime() +fmt.Printf("Max time is %s which tooks %v\n", k, v) + +/* Output +TotalTime : 1.342805875s +DNSLookupTime : 7.549292ms +TCPConnectTime : 567.833µs +TLSHandshakeTime : 536.604041ms +FirstResponseTime : 797.466708ms +ResponseTime : 374.875µs +IsConnReused: : false +RemoteAddr : 192.30.255.117:443 +-------- +Max time is FirstResponseTime which tooks 797.466708ms +*/ + +// Enable trace at client level +client.EnableTraceAll() +resp, err = client.R().Get(url) +// ... +``` + **DevMode** If you want to enable all debug features (dump, debug log and tracing), just call `DevMode()`: @@ -154,7 +181,7 @@ client := req.C().DevMode() client.R().Get("https://imroc.cc") ``` -## Testing with Use Global Methods +**Testing with Use Global Methods** `req` wrap methods of both `Client` and `Request` with global methods, which is delegated to default client, it's very convenient when making API test. diff --git a/client.go b/client.go index ddade806..11ec0fae 100644 --- a/client.go +++ b/client.go @@ -44,6 +44,7 @@ type Client struct { XMLUnmarshal func(data []byte, v interface{}) error DebugLog bool + trace bool outputDirectory string disableAutoReadResponse bool scheme string @@ -721,6 +722,15 @@ func (c *Client) SetProxyURL(proxyUrl string) *Client { return c } +func EnableTraceAll(enable bool) *Client { + return defaultClient.EnableTraceAll(enable) +} + +func (c *Client) EnableTraceAll(enable bool) *Client { + c.trace = enable + return c +} + // NewClient is the alias of C func NewClient() *Client { return C() @@ -802,10 +812,13 @@ func C() *Client { func setupRequest(r *Request) { setRequestURL(r.RawRequest, r.URL) setRequestHeaderAndCookie(r) + setTrace(r) } func (c *Client) do(r *Request) (resp *Response, err error) { + resp = &Response{} + for _, f := range r.client.udBeforeRequest { if err = f(r.client, r); err != nil { return @@ -824,15 +837,14 @@ func (c *Client) do(r *Request) (resp *Response, err error) { c.log.Debugf("%s %s", r.RawRequest.Method, r.RawRequest.URL.String()) } + r.StartTime = time.Now() httpResponse, err := c.httpClient.Do(r.RawRequest) if err != nil { return } - resp = &Response{ - Request: r, - Response: httpResponse, - } + resp.Request = r + resp.Response = httpResponse if !c.disableAutoReadResponse && !r.isSaveResponse { // auto read response body _, err = resp.ToBytes() @@ -849,6 +861,18 @@ func (c *Client) do(r *Request) (resp *Response, err error) { return } +func setTrace(r *Request) { + if r.trace == nil { + if r.client.trace { + r.trace = &clientTrace{} + } else { + return + } + } + r.ctx = r.trace.createContext(r.Context()) + r.RawRequest = r.RawRequest.WithContext(r.ctx) +} + func setRequestHeaderAndCookie(r *Request) { for k, vs := range r.Headers { for _, v := range vs { diff --git a/middleware.go b/middleware.go index 1841499e..caf4bbf9 100644 --- a/middleware.go +++ b/middleware.go @@ -194,6 +194,7 @@ func handleDownload(c *Client, r *Response) (err error) { output.Close() }() _, err = io.Copy(output, body) + r.setReceivedAt() return } diff --git a/request.go b/request.go index 46ff2f88..40e5233b 100644 --- a/request.go +++ b/request.go @@ -2,6 +2,7 @@ package req import ( "bytes" + "context" "github.com/hashicorp/go-multierror" "github.com/imroc/req/v2/internal/util" "io" @@ -10,6 +11,7 @@ import ( urlpkg "net/url" "os" "strings" + "time" ) // Request is the http request @@ -25,13 +27,63 @@ type Request struct { error error client *Client RawRequest *http.Request + StartTime time.Time + ctx context.Context isMultiPart bool uploadFiles []*uploadFile uploadReader []io.ReadCloser outputFile string isSaveResponse bool output io.WriteCloser + trace *clientTrace +} + +func (r *Request) TraceInfo() TraceInfo { + ct := r.trace + + if ct == nil { + return TraceInfo{} + } + + ti := TraceInfo{ + DNSLookupTime: ct.dnsDone.Sub(ct.dnsStart), + TLSHandshakeTime: ct.tlsHandshakeDone.Sub(ct.tlsHandshakeStart), + FirstResponseTime: ct.gotFirstResponseByte.Sub(ct.gotConn), + IsConnReused: ct.gotConnInfo.Reused, + IsConnWasIdle: ct.gotConnInfo.WasIdle, + ConnIdleTime: ct.gotConnInfo.IdleTime, + } + + // Calculate the total time accordingly, + // when connection is reused + if ct.gotConnInfo.Reused { + ti.TotalTime = ct.endTime.Sub(ct.getConn) + } else { + ti.TotalTime = ct.endTime.Sub(ct.dnsStart) + } + + // Only calculate on successful connections + if !ct.connectDone.IsZero() { + ti.TCPConnectTime = ct.connectDone.Sub(ct.dnsDone) + } + + // Only calculate on successful connections + if !ct.gotConn.IsZero() { + ti.ConnectTime = ct.gotConn.Sub(ct.getConn) + } + + // Only calculate on successful connections + if !ct.gotFirstResponseByte.IsZero() { + ti.ResponseTime = ct.endTime.Sub(ct.gotFirstResponseByte) + } + + // Capture remote address info when connection is non-nil + if ct.gotConnInfo.Conn != nil { + ti.RemoteAddr = ct.gotConnInfo.Conn.RemoteAddr() + } + + return ti } func SetFormDataFromValues(data urlpkg.Values) *Request { @@ -275,7 +327,7 @@ func (r *Request) appendError(err error) { func (r *Request) Send(method, url string) (*Response, error) { if r.error != nil { - return nil, r.error + return &Response{}, r.error } r.RawRequest.Method = method r.URL = url @@ -538,3 +590,40 @@ func (r *Request) SetContentType(contentType string) *Request { r.RawRequest.Header.Set(hdrContentTypeKey, contentType) return r } + +// Context method returns the Context if its already set in request +// otherwise it creates new one using `context.Background()`. +func (r *Request) Context() context.Context { + if r.ctx == nil { + r.ctx = context.Background() + } + return r.ctx +} + +func SetContext(ctx context.Context) *Request { + return defaultClient.R().SetContext(ctx) +} + +// SetContext method sets the context.Context for current Request. It allows +// to interrupt the request execution if ctx.Done() channel is closed. +// See https://blog.golang.org/context article and the "context" package +// documentation. +func (r *Request) SetContext(ctx context.Context) *Request { + r.ctx = ctx + return r +} + +func EnableTrace(enable bool) *Request { + return defaultClient.R().EnableTrace(enable) +} + +func (r *Request) EnableTrace(enable bool) *Request { + if enable { + if r.trace == nil { + r.trace = &clientTrace{} + } + } else { + r.trace = nil + } + return r +} diff --git a/response.go b/response.go index c3ca6bd0..eafebfd0 100644 --- a/response.go +++ b/response.go @@ -3,6 +3,7 @@ package req import ( "net/http" "strings" + "time" ) // ResponseOptions determines that how should the response been processed. @@ -35,8 +36,9 @@ func autoDecodeContentTypeFunc(contentTypes ...string) func(contentType string) // Response is the http response. type Response struct { *http.Response - Request *Request - body []byte + Request *Request + body []byte + receivedAt time.Time } // IsSuccess method returns true if HTTP status `code >= 200 and <= 299` otherwise false. @@ -62,3 +64,25 @@ func (r *Response) Result() interface{} { func (r *Response) Error() interface{} { return r.Request.Error } + +func (r *Response) TraceInfo() TraceInfo { + return r.Request.TraceInfo() +} + +func (r *Response) TotalTime() time.Duration { + if r.Request.trace != nil { + return r.Request.TraceInfo().TotalTime + } + return r.receivedAt.Sub(r.Request.StartTime) +} + +func (r *Response) ReceivedAt() time.Time { + return r.receivedAt +} + +func (r *Response) setReceivedAt() { + r.receivedAt = time.Now() + if r.Request.trace != nil { + r.Request.trace.endTime = r.receivedAt + } +} diff --git a/response_body.go b/response_body.go index ec8c87ba..3c1e069a 100644 --- a/response_body.go +++ b/response_body.go @@ -60,6 +60,7 @@ func (r *Response) ToBytes() ([]byte, error) { } defer r.Body.Close() body, err := ioutil.ReadAll(r.Body) + r.setReceivedAt() if err != nil { return nil, err } diff --git a/trace.go b/trace.go new file mode 100644 index 00000000..6b1a8aa5 --- /dev/null +++ b/trace.go @@ -0,0 +1,148 @@ +package req + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http/httptrace" + "time" +) + +const ( + traceFmt = `TotalTime : %v +DNSLookupTime : %v +TCPConnectTime : %v +TLSHandshakeTime : %v +FirstResponseTime : %v +ResponseTime : %v +IsConnReused: : false +RemoteAddr : %v` + traceReusedFmt = `TotalTime : %v +FirstResponseTime : %v +ResponseTime : %v +IsConnReused: : true +RemoteAddr : %v` +) + +func (t TraceInfo) MaxTime() (maxName string, maxValue time.Duration) { + m := map[string]time.Duration{ + "DNSLookupTime": t.DNSLookupTime, + "TCPConnectTime": t.TCPConnectTime, + "TLSHandshakeTime": t.TLSHandshakeTime, + "FirstResponseTime": t.FirstResponseTime, + "ResponseTime": t.ResponseTime, + } + for k, v := range m { + if v > maxValue { + maxName = k + maxValue = v + } + } + return +} + +func (t TraceInfo) String() string { + if t.RemoteAddr == nil { + return "uncompleted request" + } + if t.IsConnReused { + return fmt.Sprintf(traceReusedFmt, t.TotalTime, t.FirstResponseTime, t.ResponseTime, t.RemoteAddr) + } else { + return fmt.Sprintf(traceFmt, t.TotalTime, t.DNSLookupTime, t.TCPConnectTime, t.TLSHandshakeTime, t.FirstResponseTime, t.ResponseTime, t.RemoteAddr) + } +} + +type TraceInfo struct { + // DNSLookupTime is a duration that transport took to perform + // DNS lookup. + DNSLookupTime time.Duration + + // ConnectTime is a duration that took to obtain a successful connection. + ConnectTime time.Duration + + // TCPConnectTime is a duration that took to obtain the TCP connection. + TCPConnectTime time.Duration + + // TLSHandshakeTime is a duration that TLS handshake took place. + TLSHandshakeTime time.Duration + + // FirstResponseTime is a duration that server took to respond first byte. + FirstResponseTime time.Duration + + // ResponseTime is a duration since first response byte from server to + // request completion. + ResponseTime time.Duration + + // TotalTime is a duration that total request took end-to-end. + TotalTime time.Duration + + // IsConnReused is whether this connection has been previously + // used for another HTTP request. + IsConnReused bool + + // IsConnWasIdle is whether this connection was obtained from an + // idle pool. + IsConnWasIdle bool + + // ConnIdleTime is a duration how long the connection was previously + // idle, if IsConnWasIdle is true. + ConnIdleTime time.Duration + + // RemoteAddr returns the remote network address. + RemoteAddr net.Addr +} + +type clientTrace struct { + getConn time.Time + dnsStart time.Time + dnsDone time.Time + connectDone time.Time + tlsHandshakeStart time.Time + tlsHandshakeDone time.Time + gotConn time.Time + gotFirstResponseByte time.Time + endTime time.Time + gotConnInfo httptrace.GotConnInfo +} + +func (t *clientTrace) createContext(ctx context.Context) context.Context { + return httptrace.WithClientTrace( + ctx, + &httptrace.ClientTrace{ + DNSStart: func(_ httptrace.DNSStartInfo) { + t.dnsStart = time.Now() + }, + DNSDone: func(_ httptrace.DNSDoneInfo) { + t.dnsDone = time.Now() + }, + ConnectStart: func(_, _ string) { + if t.dnsDone.IsZero() { + t.dnsDone = time.Now() + } + if t.dnsStart.IsZero() { + t.dnsStart = t.dnsDone + } + }, + ConnectDone: func(net, addr string, err error) { + t.connectDone = time.Now() + }, + GetConn: func(_ string) { + t.getConn = time.Now() + }, + GotConn: func(ci httptrace.GotConnInfo) { + t.gotConn = time.Now() + t.gotConnInfo = ci + }, + GotFirstResponseByte: func() { + t.gotFirstResponseByte = time.Now() + }, + TLSHandshakeStart: func() { + t.tlsHandshakeStart = time.Now() + }, + TLSHandshakeDone: func(_ tls.ConnectionState, _ error) { + t.tlsHandshakeDone = time.Now() + }, + }, + ) +} From 33299dcefeec0a48172f6ee74d80b471ec223023 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 14:46:54 +0800 Subject: [PATCH 139/843] EnableTraceAll in DevMode --- client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client.go b/client.go index 11ec0fae..3baeb168 100644 --- a/client.go +++ b/client.go @@ -330,6 +330,7 @@ func DevMode() *Client { func (c *Client) DevMode() *Client { return c.EnableDumpAll(). EnableDebugLog(true). + EnableTraceAll(true). SetUserAgent(userAgentChrome) } From 358cc0df321d4be69970589a33a2d94ff6c2f358 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 16:37:43 +0800 Subject: [PATCH 140/843] some refactor --- README.md | 38 ++++++++- client.go | 125 ++++++++++++++++------------ decode.go | 5 +- internal/charsetutil/charsetutil.go | 25 +++--- middleware.go | 5 +- response.go | 12 ++- transport.go | 3 + 7 files changed, 133 insertions(+), 80 deletions(-) diff --git a/README.md b/README.md index 5d015526..336551d2 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ Simplified golang http client library with magic, happy sending requests, less c * [Custom Client and Root Certificates](#Cert) * [Basic Auth and Bearer Token](#Auth) * [Download and Upload](#Download) +* [Auto-Decoding](#AutoDecode) ## Features @@ -125,9 +126,11 @@ client.R().Get("https://www.baidu.com/") // Logging is enabled by default, but only output the warning and error message. // set `EnableDebugLog` to true to enable debug level logging. client := req.C().EnableDebugLog(true) -client.R().Get("https://api.github.com/users/imroc") +client.R().Get("http://baidu.com/s?wd=req") /* Output -2022/01/23 14:33:04.755019 DEBUG [req] GET https://api.github.com/users/imroc +2022/01/26 15:46:29.279368 DEBUG [req] GET http://baidu.com/s?wd=req +2022/01/26 15:46:29.469653 DEBUG [req] charset iso-8859-1 detected in Content-Type, auto-decode to utf-8 +2022/01/26 15:46:29.469713 DEBUG [req] GET http://www.baidu.com/s?wd=req ... */ @@ -442,6 +445,37 @@ client.R().SetFile("pic", "test.jpg"). // Set form param name and filename ``` +## Auto-Decoding + +`Req` detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default. + +Its principle is to detect whether `Content-Type` header at first, if it's not the text content type (json, xml, html and so on), `req` will not try to decode. If it is, then `req` will try to find the charset information, if it's not included in the header, it will try to sniff the body's content to determine the charset, if found and is not utf-8, then decode it to utf-8 automatically, if the charset is not sure, it will not decode, and leave the body untouched. + +You can also disable if you don't need or care a lot about performance: + +```go +client.DisableAutoDecode(true) +``` + +Also you can make some customization: + +```go +// Try to auto-detect and decode all content types (some server may return incorrect Content-Type header) +client.SetAutoDecodeAllType() + +// Only auto-detect and decode content which `Content-Type` header contains "html" or "json" +client.SetAutoDecodeContentType("html", "json") + +// Or you can customize the function to determine whether to decode +fn := func(contentType string) bool { + if regexContentType.MatchString(contentType) { + return true + } + return false +} +client.SetAutoDecodeAllTypeFunc(fn) +``` + ## License `Req` released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file diff --git a/client.go b/client.go index 3baeb168..a19201e4 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "encoding/json" "encoding/xml" + "errors" "github.com/imroc/req/v2/internal/util" "golang.org/x/net/publicsuffix" "io" @@ -33,7 +34,7 @@ var defaultClient *Client = C() // Client is the req's http client. type Client struct { - HostURL string + BaseURL string PathParams map[string]string QueryParams urlpkg.Values Headers http.Header @@ -114,6 +115,15 @@ func (c *Client) R() *Request { } } +func SetBaseURL(u string) *Client { + return defaultClient.SetBaseURL(u) +} + +func (c *Client) SetBaseURL(u string) *Client { + c.BaseURL = strings.TrimRight(u, "/") + return c +} + func SetOutputDirectory(dir string) *Client { return defaultClient.SetOutputDirectory(dir) } @@ -193,6 +203,16 @@ func (c *Client) tlsConfig() *tls.Config { return c.t.TLSClientConfig } +func (c *Client) defaultCheckRedirect(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + if c.DebugLog { + c.log.Debugf(" %s %s", req.Method, req.URL.String()) + } + return nil +} + func SetRedirectPolicy(policies ...RedirectPolicy) *Client { return defaultClient.SetRedirectPolicy(policies...) } @@ -212,6 +232,9 @@ func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { return err } } + if c.DebugLog { + c.log.Debugf(" %s %s", req.Method, req.URL.String()) + } return nil } return c @@ -306,11 +329,6 @@ func (c *Client) SetCommonCookies(cs []*http.Cookie) *Client { return c } -const ( - userAgentFirefox = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:95.0) Gecko/20100101 Firefox/95.0" - userAgentChrome = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36" -) - func EnableDebugLog(enable bool) *Client { return defaultClient.EnableDebugLog(enable) } @@ -320,18 +338,21 @@ func (c *Client) EnableDebugLog(enable bool) *Client { return c } +// DevMode is a global wrapper method for default client. func DevMode() *Client { return defaultClient.DevMode() } -// DevMode enables dump for requests and responses, and set user -// agent to pretend to be a web browser, Avoid returning abnormal -// data from some sites. +// DevMode enables: +// 1. Dump content of all requests and responses to see details. +// 2. Output debug log for deeper insights. +// 3. Trace all requests, so you can get trace info to analyze performance. +// 4. Set User-Agent to pretend to be a web browser, avoid returning abnormal data from some sites. func (c *Client) DevMode() *Client { return c.EnableDumpAll(). EnableDebugLog(true). EnableTraceAll(true). - SetUserAgent(userAgentChrome) + SetUserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36") } func SetScheme(scheme string) *Client { @@ -361,31 +382,13 @@ func (c *Client) SetLogger(log Logger) *Client { return c } -func GetResponseOptions() *ResponseOptions { - return defaultClient.GetResponseOptions() -} - -func (c *Client) GetResponseOptions() *ResponseOptions { +func (c *Client) getResponseOptions() *ResponseOptions { if c.t.ResponseOptions == nil { c.t.ResponseOptions = &ResponseOptions{} } return c.t.ResponseOptions } -func SetResponseOptions(opt *ResponseOptions) *Client { - return defaultClient.SetResponseOptions(opt) -} - -// SetResponseOptions set the ResponseOptions for the underlying Transport. -func (c *Client) SetResponseOptions(opt *ResponseOptions) *Client { - if opt == nil { - c.log.Warnf("ignore nil *ResponseOptions") - return c - } - c.t.ResponseOptions = opt - return c -} - func SetTimeout(d time.Duration) *Client { return defaultClient.SetTimeout(d) } @@ -396,11 +399,7 @@ func (c *Client) SetTimeout(d time.Duration) *Client { return c } -func GetDumpOptions() *DumpOptions { - return defaultClient.GetDumpOptions() -} - -func (c *Client) GetDumpOptions() *DumpOptions { +func (c *Client) getDumpOptions() *DumpOptions { if c.dumpOptions == nil { c.dumpOptions = newDefaultDumpOptions() } @@ -411,7 +410,7 @@ func (c *Client) enableDump() { if c.t.dump != nil { // dump already started return } - c.t.EnableDump(c.GetDumpOptions()) + c.t.EnableDump(c.getDumpOptions()) } func EnableDumpToFile(filename string) *Client { @@ -425,7 +424,7 @@ func (c *Client) EnableDumpToFile(filename string) *Client { c.log.Errorf("create dump file error: %v", err) return c } - c.GetDumpOptions().Output = file + c.getDumpOptions().Output = file return c } @@ -435,7 +434,7 @@ func EnableDumpTo(output io.Writer) *Client { // EnableDumpTo indicates that the content should dump to the specified destination. func (c *Client) EnableDumpTo(output io.Writer) *Client { - c.GetDumpOptions().Output = output + c.getDumpOptions().Output = output c.enableDump() return c } @@ -448,7 +447,7 @@ func EnableDumpAsync() *Client { // can be used for debugging in production environment without // affecting performance. func (c *Client) EnableDumpAsync() *Client { - o := c.GetDumpOptions() + o := c.getDumpOptions() o.Async = true c.enableDump() return c @@ -459,7 +458,7 @@ func EnableDumpNoRequestBody() *Client { } func (c *Client) EnableDumpNoRequestBody() *Client { - o := c.GetDumpOptions() + o := c.getDumpOptions() o.ResponseHeader = true o.ResponseBody = true o.RequestBody = false @@ -473,7 +472,7 @@ func EnableDumpNoResponseBody() *Client { } func (c *Client) EnableDumpNoResponseBody() *Client { - o := c.GetDumpOptions() + o := c.getDumpOptions() o.ResponseHeader = true o.ResponseBody = false o.RequestBody = true @@ -488,7 +487,7 @@ func EnableDumpOnlyResponse() *Client { // EnableDumpOnlyResponse indicates that should dump the responses' head and response. func (c *Client) EnableDumpOnlyResponse() *Client { - o := c.GetDumpOptions() + o := c.getDumpOptions() o.ResponseHeader = true o.ResponseBody = true o.RequestBody = false @@ -503,7 +502,7 @@ func EnableDumpOnlyRequest() *Client { // EnableDumpOnlyRequest indicates that should dump the requests' head and response. func (c *Client) EnableDumpOnlyRequest() *Client { - o := c.GetDumpOptions() + o := c.getDumpOptions() o.RequestHeader = true o.RequestBody = true o.ResponseBody = false @@ -518,7 +517,7 @@ func EnableDumpOnlyBody() *Client { // EnableDumpOnlyBody indicates that should dump the body of requests and responses. func (c *Client) EnableDumpOnlyBody() *Client { - o := c.GetDumpOptions() + o := c.getDumpOptions() o.RequestBody = true o.ResponseBody = true o.RequestHeader = false @@ -533,7 +532,7 @@ func EnableDumpOnlyHeader() *Client { // EnableDumpOnlyHeader indicates that should dump the head of requests and responses. func (c *Client) EnableDumpOnlyHeader() *Client { - o := c.GetDumpOptions() + o := c.getDumpOptions() o.RequestHeader = true o.ResponseHeader = true o.RequestBody = false @@ -548,7 +547,7 @@ func EnableDumpAll() *Client { // EnableDumpAll indicates that should dump both requests and responses' head and body. func (c *Client) EnableDumpAll() *Client { - o := c.GetDumpOptions() + o := c.getDumpOptions() o.RequestHeader = true o.RequestBody = true o.ResponseHeader = true @@ -575,17 +574,36 @@ func (c *Client) DisableAutoReadResponse(disable bool) *Client { return c } -func EnableAutoDecodeAllType() *Client { - return defaultClient.EnableAutoDecodeAllType() +func SetAutoDecodeContentType(contentTypes ...string) *Client { + return defaultClient.SetAutoDecodeContentType(contentTypes...) +} + +func (c *Client) SetAutoDecodeContentType(contentTypes ...string) *Client { + opt := c.getResponseOptions() + opt.AutoDecodeContentType = autoDecodeContentTypeFunc(contentTypes...) + return c +} + +func SetAutoDecodeAllTypeFunc(fn func(contentType string) bool) *Client { + return defaultClient.SetAutoDecodeAllTypeFunc(fn) +} + +func (c *Client) SetAutoDecodeAllTypeFunc(fn func(contentType string) bool) *Client { + opt := c.getResponseOptions() + opt.AutoDecodeContentType = fn + return c +} + +func SetAutoDecodeAllType() *Client { + return defaultClient.SetAutoDecodeAllType() } -// EnableAutoDecodeAllType indicates that try autodetect and decode all content type. -func (c *Client) EnableAutoDecodeAllType() *Client { - opt := c.GetResponseOptions() +// SetAutoDecodeAllType indicates that try autodetect and decode all content type. +func (c *Client) SetAutoDecodeAllType() *Client { + opt := c.getResponseOptions() opt.AutoDecodeContentType = func(contentType string) bool { return true } - opt.DisableAutoDecode = false return c } @@ -595,7 +613,7 @@ func DisableAutoDecode(disable bool) *Client { // DisableAutoDecode disable auto detect charset and decode to utf-8 func (c *Client) DisableAutoDecode(disable bool) *Client { - c.GetResponseOptions().DisableAutoDecode = disable + c.getResponseOptions().DisableAutoDecode = disable return c } @@ -752,7 +770,7 @@ func (c *Client) Clone() *Client { Headers: cloneHeaders(c.Headers), PathParams: cloneMap(c.PathParams), QueryParams: cloneUrlValues(c.QueryParams), - HostURL: c.HostURL, + BaseURL: c.BaseURL, scheme: c.scheme, log: c.log, beforeRequest: c.beforeRequest, @@ -801,6 +819,7 @@ func C() *Client { XMLMarshal: xml.Marshal, XMLUnmarshal: xml.Unmarshal, } + httpClient.CheckRedirect = c.defaultCheckRedirect t.Debugf = func(format string, v ...interface{}) { if c.DebugLog { diff --git a/decode.go b/decode.go index 82522f35..33f68c2f 100644 --- a/decode.go +++ b/decode.go @@ -32,10 +32,13 @@ func (a *autoDecodeReadCloser) peekRead(p []byte) (n int, err error) { return } a.detected = true - enc := charsetutil.FindEncoding(p, a.t.Debugf) + enc, name := charsetutil.FindEncoding(p) if enc == nil { return } + if a.t.Debugf != nil { + a.t.Debugf("charset %s found in body's meta, auto-decode to utf-8", name) + } dc := enc.NewDecoder() a.decodeReader = dc.Reader(a.ReadCloser) var pp []byte diff --git a/internal/charsetutil/charsetutil.go b/internal/charsetutil/charsetutil.go index 1ee956e1..5acdec96 100644 --- a/internal/charsetutil/charsetutil.go +++ b/internal/charsetutil/charsetutil.go @@ -17,29 +17,26 @@ var boms = []struct { {[]byte{0xef, 0xbb, 0xbf}, "utf-8"}, } -func FindEncoding(content []byte, debugf func(format string, v ...interface{})) encoding.Encoding { +func FindEncoding(content []byte) (enc encoding.Encoding, name string) { if len(content) == 0 { - return nil + return } for _, b := range boms { if bytes.HasPrefix(content, b.bom) { - e, _ := htmlcharset.Lookup(b.enc) - if e != nil { - return e + enc, name = htmlcharset.Lookup(b.enc) + if enc != nil { + if strings.ToLower(name) == "utf-8" { + enc = nil + } + return } } } - e, name := prescan(content) + enc, name = prescan(content) if strings.ToLower(name) == "utf-8" { - if debugf != nil { - debugf("%s charset found in the meta tag content, no need to decode", name) - } - return nil - } - if e != nil { - return e + enc = nil } - return nil + return } func prescan(content []byte) (e encoding.Encoding, name string) { diff --git a/middleware.go b/middleware.go index caf4bbf9..d2b39a8c 100644 --- a/middleware.go +++ b/middleware.go @@ -241,7 +241,7 @@ func parseRequestURL(c *Client, r *Request) error { return err } - // If Request.URL is relative path then added c.HostURL into + // If Request.URL is relative path then added c.BaseURL into // the request URL otherwise Request.URL will be used as-is if !reqURL.IsAbs() { r.URL = reqURL.String() @@ -249,13 +249,12 @@ func parseRequestURL(c *Client, r *Request) error { r.URL = "/" + r.URL } - reqURL, err = url.Parse(c.HostURL + r.URL) + reqURL, err = url.Parse(c.BaseURL + r.URL) if err != nil { return err } } - // GH #407 && #318 if reqURL.Scheme == "" && len(c.scheme) > 0 { reqURL.Scheme = c.scheme } diff --git a/response.go b/response.go index eafebfd0..ae43aed1 100644 --- a/response.go +++ b/response.go @@ -2,7 +2,6 @@ package req import ( "net/http" - "strings" "time" ) @@ -23,13 +22,12 @@ var textContentTypes = []string{"text", "json", "xml", "html", "java"} var autoDecodeText = autoDecodeContentTypeFunc(textContentTypes...) func autoDecodeContentTypeFunc(contentTypes ...string) func(contentType string) bool { + m := make(map[string]bool) + for _, ct := range contentTypes { + m[ct] = true + } return func(contentType string) bool { - for _, t := range contentTypes { - if strings.Contains(contentType, t) { - return true - } - } - return false + return m[contentType] } } diff --git a/transport.go b/transport.go index e8e26e7d..04109af1 100644 --- a/transport.go +++ b/transport.go @@ -324,6 +324,9 @@ func (t *Transport) autoDecodeResponseBody(res *http.Response) { if enc == nil { return } + if t.Debugf != nil { + t.Debugf("charset %s detected in Content-Type, auto-decode to utf-8", charset) + } decodeReader := enc.NewDecoder().Reader(res.Body) res.Body = &decodeReaderCloser{res.Body, decodeReader} return From 3477c4a48855b6400bee9c2200f6aa093212e3be Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 16:57:03 +0800 Subject: [PATCH 141/843] update README --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 336551d2..a0e7d7b1 100644 --- a/README.md +++ b/README.md @@ -19,10 +19,11 @@ Simplified golang http client library with magic, happy sending requests, less c ## Features * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. -* Powerful [Debugging](#Debugging) capabilities, including debug logs, performance traces, and even dump complete request and response content. -* [Testing with Global Wrapper Methods](#Global) with minimal code. -* Detect the charset of response body and decode it to UTF-8 automatically to avoid garbled characters by default. -* Exportable `Transport`, just replace the Transport of existing http.Client with `*req.Transport`, then you can dump the content as `req` does to debug APIs with minimal code change. +* Powerful and convenient debug utilites, including debug logs, performance traces, dump complete request and response content, even provide global wrapper methods to test with minimal code. +* Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decoding](#AutoDecode)). +* Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. +* Easy upload and download. +* Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token, timeout, proxy, redirect policy and so on for requests or clients. ## Quick Start From 1827837f53d214e328937c8dbb038b619d386c7f Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 17:00:18 +0800 Subject: [PATCH 142/843] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a0e7d7b1..b79f1a51 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Simplified golang http client library with magic, happy sending requests, less c * Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decoding](#AutoDecode)). * Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. * Easy upload and download. -* Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token, timeout, proxy, redirect policy and so on for requests or clients. +* Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token, timeout, proxy, certs, redirect policy and so on for requests or clients. ## Quick Start @@ -185,7 +185,7 @@ client := req.C().DevMode() client.R().Get("https://imroc.cc") ``` -**Testing with Use Global Methods** +**Test with Global Wrapper Methods** `req` wrap methods of both `Client` and `Request` with global methods, which is delegated to default client, it's very convenient when making API test. From af8b5494560076e9f78af697aa88ec660a729385 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 17:08:12 +0800 Subject: [PATCH 143/843] update example --- README.md | 10 ++++----- examples/find-popular-repo/go.mod | 2 +- examples/find-popular-repo/go.sum | 36 ++++++++++++++++++++++++++++-- examples/find-popular-repo/main.go | 4 +--- 4 files changed, 41 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index b79f1a51..5d4303c8 100644 --- a/README.md +++ b/README.md @@ -13,16 +13,16 @@ Simplified golang http client library with magic, happy sending requests, less c * [Header and Cookie](#Header-Cookie) * [Custom Client and Root Certificates](#Cert) * [Basic Auth and Bearer Token](#Auth) -* [Download and Upload](#Download) +* [Download and Upload](#Download-Upload) * [Auto-Decoding](#AutoDecode) ## Features * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. -* Powerful and convenient debug utilites, including debug logs, performance traces, dump complete request and response content, even provide global wrapper methods to test with minimal code. +* Powerful and convenient debug utilites, including debug logs, performance traces, dump complete request and response content, even provide global wrapper methods to test with minimal code (see [Debugging](#Debugging). * Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decoding](#AutoDecode)). * Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. -* Easy upload and download. +* Easy [Download and Upload](#Download-Upload). * Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token, timeout, proxy, certs, redirect policy and so on for requests or clients. ## Quick Start @@ -30,7 +30,7 @@ Simplified golang http client library with magic, happy sending requests, less c **Install** ``` sh -go get github.com/imroc/req/v2@v2.0.0-beta.0 +go get github.com/imroc/req/v2@v2.0.0-beta.1 ``` **Import** @@ -400,7 +400,7 @@ client.R().SetBasicAuth("myusername", "mypassword").Get("https://api.example.com client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.example.com/profile") ``` -## Download and Upload +## Download and Upload **Download** diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index 4ef8552d..2fca2d04 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -2,4 +2,4 @@ module find-popular-repo go 1.13 -require github.com/imroc/req/v2 v2.0.0-alpha.10 +require github.com/imroc/req/v2 v2.0.0-beta.0 diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index 64e8d105..2d0551d3 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -1,15 +1,47 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.0.0-alpha.10 h1:twsgj8MfXqMXg1bzzIeHM5j8lXhpBlHYj5PW7Lj7gW8= -github.com/imroc/req/v2 v2.0.0-alpha.10/go.mod h1:3POMCRC7mUbCcscEp9wpihSyZLUVYWqvmHnwTdL6kJY= +github.com/imroc/req/v2 v2.0.0-beta.0 h1:ZM1nIZehQQ16E8ecaIfFWj8+TCn4TbjLAu5RqoOCgBE= +github.com/imroc/req/v2 v2.0.0-beta.0/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/examples/find-popular-repo/main.go b/examples/find-popular-repo/main.go index 9cc4a9e1..44a23126 100644 --- a/examples/find-popular-repo/main.go +++ b/examples/find-popular-repo/main.go @@ -17,9 +17,7 @@ type ErrorMessage struct { } func init() { - req.EnableDebugLog(true) - // Uncomment DevMode below if you want to see more details - // req.DevMode() + req.DevMode().EnableDumpOnlyHeader() } // Change the name if you want From 1ee63760c958bb4266a8e4980192d56b9f91ea19 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 17:14:49 +0800 Subject: [PATCH 144/843] update example --- examples/find-popular-repo/go.mod | 2 +- examples/find-popular-repo/go.sum | 4 ++-- examples/find-popular-repo/main.go | 29 +++++++++++++++++------------ 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index 2fca2d04..60430de2 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -2,4 +2,4 @@ module find-popular-repo go 1.13 -require github.com/imroc/req/v2 v2.0.0-beta.0 +require github.com/imroc/req/v2 v2.0.0-beta.1 diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index 2d0551d3..0af9d86a 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -12,8 +12,8 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.0.0-beta.0 h1:ZM1nIZehQQ16E8ecaIfFWj8+TCn4TbjLAu5RqoOCgBE= -github.com/imroc/req/v2 v2.0.0-beta.0/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= +github.com/imroc/req/v2 v2.0.0-beta.1 h1:eVlSdoboOJIz5tYeTw9Ik1Onb8KO02EJ0Ab/GmSli7U= +github.com/imroc/req/v2 v2.0.0-beta.1/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= diff --git a/examples/find-popular-repo/main.go b/examples/find-popular-repo/main.go index 44a23126..5cd98ee4 100644 --- a/examples/find-popular-repo/main.go +++ b/examples/find-popular-repo/main.go @@ -8,18 +8,6 @@ import ( "github.com/imroc/req/v2" ) -type Repo struct { - Name string `json:"name"` - Star int `json:"stargazers_count"` -} -type ErrorMessage struct { - Message string `json:"message"` -} - -func init() { - req.DevMode().EnableDumpOnlyHeader() -} - // Change the name if you want var username = "imroc" @@ -32,6 +20,18 @@ func main() { fmt.Printf("The most popular repo of %s is %s, which have %d stars\n", username, repo, star) } +func init() { + req.EnableDumpOnlyHeader().EnableDebugLog(true).EnableTraceAll(true) +} + +type Repo struct { + Name string `json:"name"` + Star int `json:"stargazers_count"` +} +type ErrorMessage struct { + Message string `json:"message"` +} + func findTheMostPopularRepo(username string) (repo string, star int, err error) { var popularRepo Repo @@ -53,6 +53,11 @@ func findTheMostPopularRepo(username string) (repo string, star int, err error) SetError(&errMsg). Get("https://api.github.com/users/{username}/repos") + fmt.Println("TraceInfo:") + fmt.Println("----------") + fmt.Println(resp.TraceInfo()) + fmt.Println() + if err != nil { return } From 6cc33db0b6392e7f90983cb4ece7e94ed87228d4 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 17:28:30 +0800 Subject: [PATCH 145/843] update example --- examples/upload/README.md | 19 +++++++++++++++++++ examples/upload/uploadclient/go.mod | 8 +------- examples/upload/uploadclient/go.sum | 4 ++-- 3 files changed, 22 insertions(+), 9 deletions(-) create mode 100644 examples/upload/README.md diff --git a/examples/upload/README.md b/examples/upload/README.md new file mode 100644 index 00000000..6c9dbd7c --- /dev/null +++ b/examples/upload/README.md @@ -0,0 +1,19 @@ +# upload + +This is a upload exmaple for `req` + +## How to Run + +Run `uploadserver`: + +```go +cd uploadserver +go run . +``` + +Run `uploadclient`: + +```go +cd uploadclient +go run . +``` \ No newline at end of file diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod index bd99106d..9ab14e93 100644 --- a/examples/upload/uploadclient/go.mod +++ b/examples/upload/uploadclient/go.mod @@ -2,10 +2,4 @@ module uploadclient go 1.13 -require ( - github.com/hashicorp/errwrap v1.0.0 // indirect - github.com/hashicorp/go-multierror v1.1.1 // indirect - github.com/imroc/req/v2 v2.0.0-alpha.11 // indirect - golang.org/x/net v0.0.0-20220111093109-d55c255bac03 // indirect - golang.org/x/text v0.3.7 // indirect -) +require github.com/imroc/req/v2 v2.0.0-beta.1 diff --git a/examples/upload/uploadclient/go.sum b/examples/upload/uploadclient/go.sum index 3875b817..0af9d86a 100644 --- a/examples/upload/uploadclient/go.sum +++ b/examples/upload/uploadclient/go.sum @@ -12,8 +12,8 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.0.0-alpha.11 h1:VJuYAcVcPVJUd9YQ6Vv5S2djHhl9mQS+XKoOkfjAMHw= -github.com/imroc/req/v2 v2.0.0-alpha.11/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= +github.com/imroc/req/v2 v2.0.0-beta.1 h1:eVlSdoboOJIz5tYeTw9Ik1Onb8KO02EJ0Ab/GmSli7U= +github.com/imroc/req/v2 v2.0.0-beta.1/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= From 3ae70bf319c088d66f057a4e51c971a6305016fe Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 20:24:49 +0800 Subject: [PATCH 146/843] some refactor --- README.md | 180 +++++++++++++++++++++++++++++++++++++++++++++++++- client.go | 18 +++++ middleware.go | 58 ++++++++++------ request.go | 33 +++++++-- 4 files changed, 261 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 5d4303c8..63f59cfc 100644 --- a/README.md +++ b/README.md @@ -11,11 +11,13 @@ Simplified golang http client library with magic, happy sending requests, less c * [Debugging](#Debugging) * [Path Parameter and Query Parameter](#Param) * [Header and Cookie](#Header-Cookie) +* [Body and Marshal/Unmarshal](#Header-Cookie) * [Custom Client and Root Certificates](#Cert) * [Basic Auth and Bearer Token](#Auth) * [Download and Upload](#Download-Upload) * [Auto-Decoding](#AutoDecode) - +* [Request and Response Middleware](#Middleware) + ## Features * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. @@ -24,6 +26,7 @@ Simplified golang http client library with magic, happy sending requests, less c * Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. * Easy [Download and Upload](#Download-Upload). * Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token, timeout, proxy, certs, redirect policy and so on for requests or clients. +* Support middleware before request sent and after got response ## Quick Start @@ -359,6 +362,159 @@ resp1, err := client.R().Get(url1) resp2, err := client.R().Get(url2) ``` +You can also customize the CookieJar: +```go +// Set your own http.CookieJar implementation +client.SetCookieJar(jar) + +// Set to nil to disable CookieJar +client.SetCookieJar(nil) +``` + +## Body and Marshal/Unmarshal + +**Request Body** + +```go +// Create a client that dump request +client := req.C().EnableDumpOnlyRequest() +// SetBody accepts string, []byte, io.Reader, use type assertion to +// determine the data type of body automatically. +client.R().SetBody("test").Post("https://httpbin.org/post") +/* Output +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +test +*/ + +// If it cannot determine, like map and struct, then it will wait +// and marshal to json or xml automatically according to the `Content-Type` +// header that have been set before or after, default to json if not set. +type User struct { + Name string `json:"name"` + Email string `json:"email"` +} +user := &User{Name: "imroc", Email: "roc@imroc.cc"} +client.R().SetBody(user).Post("https://httpbin.org/post") +/* Output +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +content-type: application/json; charset=utf-8 +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +{"name":"imroc","email":"roc@imroc.cc"} +*/ + + +// You can use more specific methods to avoid type assertions and improves performance, +client.R().SetBodyJsonString(`{"username": "imroc"}`).Post("https://httpbin.org/post") +/* +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +content-type: application/json; charset=utf-8 +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +{"username": "imroc"} +*/ + +// Marshal body and set `Content-Type` automatically without any guess +cient.R().SetBodyXmlMarshal(user).Post("https://httpbin.org/post") +/* Output +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +content-type: text/xml; charset=utf-8 +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +imrocroc@imroc.cc +*/ +``` + +**Response Body** + +```go +// Define success body struct +type User struct { + Name string `json:"name"` + Blog string `json:"blog"` +} +// Define error body struct +type ErrorMessage struct { + Message string `json:"message"` +} +// Create a client and dump body to see details +client := req.C().EnableDumpOnlyBody() + +// Send a request and unmarshal result automatically according to +// response `Content-Type` +user := &User{} +errMsg := &ErrorMessage{} +resp, err := client.R(). + SetResult(user). // Set success result + SetError(errMsg). // Set error result + Get("https://api.github.com/users/imroc") +if err != nil { + log.Fatal(err) +} +fmt.Println("----------") + +if resp.IsSuccess() { // status `code >= 200 and <= 299` is considered as success + // Must have been marshaled to user if no error returned before + fmt.Printf("%s's blog is %s\n", user.Name, user.Blog) +} else if resp.IsError() { // status `code >= 400` is considered as error + // Must have been marshaled to errMsg if no error returned before + fmt.Println("got error:", errMsg.Message) +} else { + log.Fatal("unknown http status:", resp.Status) +} +/* Output +{"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":129,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2022-01-24T23:32:53Z"} +---------- +roc's blog is https://imroc.cc +*/ + +// Or you can also unmarshal response later +if resp.IsSuccess() { + err = resp.Unmarshal(user) + if err != nil { + log.Fatal(err) + } + fmt.Printf("%s's blog is %s\n", user.Name, user.Blog) +} else { + fmt.Println("bad response:", resp) +} + +// Also, you can get the raw response and Unmarshal by yourself +yaml.Unmarshal(resp.Bytes()) +``` + +**Disable Auto-Read Response Body** + +Response body will be read into memory if it's not a download request by default, you can disable it if you want (normally you don't need to do this). + +```go +client.DisableAutoReadResponse(true) + +resp, err := client.R().Get(url) +if err != nil { + log.Fatal(err) +} +io.Copy(dst, resp.Body) +``` + ## Custom Client and Root Certificates ```go @@ -477,6 +633,28 @@ fn := func(contentType string) bool { client.SetAutoDecodeAllTypeFunc(fn) ``` +## Request and Response Middleware + +```go +client := req.C() + +// Registering Request Middleware +client.OnBeforeRequest(func(c *req.Client, r *req.Request) error { + // You can access Client and current Request object to do something + // you need + + return nil // return nil if it is success + }) + +// Registering Response Middleware +client.OnAfterResponse(func(c *req.Client, r *req.Response) error { + // You can access Client and current Response object to do something + // you need + + return nil // return nil if it is success + }) +``` + ## License `Req` released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file diff --git a/client.go b/client.go index a19201e4..8d5eec75 100644 --- a/client.go +++ b/client.go @@ -667,6 +667,15 @@ func (c *Client) SetCommonHeader(key, value string) *Client { return c } +func SetCommonContentType(ct string) *Client { + return defaultClient.SetCommonContentType(ct) +} + +func (c *Client) SetCommonContentType(ct string) *Client { + c.SetCommonHeader(hdrContentTypeKey, ct) + return c +} + func EnableDump(enable bool) *Client { return defaultClient.EnableDump(enable) } @@ -750,6 +759,15 @@ func (c *Client) EnableTraceAll(enable bool) *Client { return c } +func SetCookieJar(jar http.CookieJar) *Client { + return defaultClient.SetCookieJar(jar) +} + +func (c *Client) SetCookieJar(jar http.CookieJar) *Client { + c.httpClient.Jar = jar + return c +} + // NewClient is the alias of C func NewClient() *Client { return C() diff --git a/middleware.go b/middleware.go index d2b39a8c..6ba8a4c6 100644 --- a/middleware.go +++ b/middleware.go @@ -53,23 +53,6 @@ func closeq(v interface{}) { } } -type multipartBody struct { - *io.PipeReader - closed bool -} - -func (r *multipartBody) Read(p []byte) (n int, err error) { - if r.closed { - return 0, io.EOF - } - n, err = r.PipeReader.Read(p) - if err != nil { - r.closed = true - err = nil - } - return -} - func writeMultipartFormFile(w *multipart.Writer, fieldName, fileName string, r io.Reader) error { defer closeq(r) // Auto detect actual multipart content type @@ -92,7 +75,7 @@ func writeMultipartFormFile(w *multipart.Writer, fieldName, fileName string, r i return err } -func writeMultiPart(c *Client, r *Request, w *multipart.Writer, pw *io.PipeWriter) { +func writeMultiPart(r *Request, w *multipart.Writer, pw *io.PipeWriter) { for k, vs := range r.FormData { for _, v := range vs { w.WriteField(k, v) @@ -110,7 +93,7 @@ func handleMultiPart(c *Client, r *Request) (err error) { r.RawRequest.Body = pr w := multipart.NewWriter(pw) r.RawRequest.Header.Set(hdrContentTypeKey, w.FormDataContentType()) - go writeMultiPart(c, r, w, pw) + go writeMultiPart(r, w, pw) return } @@ -118,6 +101,38 @@ func handleFormData(r *Request) { r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(r.FormData.Encode())) } +func handleMarshalBody(c *Client, r *Request) error { + ct := "" + if r.Headers != nil { + ct = r.Headers.Get(hdrContentTypeKey) + } + if ct == "" { + ct = c.Headers.Get(hdrContentTypeKey) + } + if ct != "" { + if util.IsXMLType(ct) { + body, err := c.XMLMarshal(r.marshalBody) + if err != nil { + return err + } + r.SetBodyBytes(body) + } else { + body, err := c.JSONMarshal(r.marshalBody) + if err != nil { + return err + } + r.SetBodyBytes(body) + } + return nil + } + body, err := c.JSONMarshal(r.marshalBody) + if err != nil { + return err + } + r.SetBodyJsonBytes(body) + return nil +} + func parseRequestBody(c *Client, r *Request) (err error) { if r.isMultiPart { return handleMultiPart(c, r) @@ -126,6 +141,9 @@ func parseRequestBody(c *Client, r *Request) (err error) { handleFormData(r) return } + if r.marshalBody != nil { + handleMarshalBody(c, r) + } return } @@ -134,7 +152,7 @@ func unmarshalBody(c *Client, r *Response, v interface{}) (err error) { if err != nil { return } - ct := util.FirstNonEmpty(r.GetContentType()) + ct := r.GetContentType() if util.IsJSONType(ct) { return c.JSONUnmarshal(body, v) } else if util.IsXMLType(ct) { diff --git a/request.go b/request.go index 40e5233b..ecfbf0d5 100644 --- a/request.go +++ b/request.go @@ -29,6 +29,7 @@ type Request struct { RawRequest *http.Request StartTime time.Time + marshalBody interface{} ctx context.Context isMultiPart bool uploadFiles []*uploadFile @@ -506,6 +507,8 @@ func (r *Request) SetBody(body interface{}) *Request { r.SetBodyBytes(b) case string: r.SetBodyString(b) + default: + r.marshalBody = body } return r } @@ -538,8 +541,7 @@ func SetBodyJsonString(body string) *Request { // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonString(body string) *Request { r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(body)) - r.SetContentType(jsonContentType) - return r + return r.SetContentType(jsonContentType) } func SetBodyJsonBytes(body []byte) *Request { @@ -550,8 +552,7 @@ func SetBodyJsonBytes(body []byte) *Request { // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonBytes(body []byte) *Request { r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) - r.SetContentType(jsonContentType) - return r + return r.SetContentType(jsonContentType) } func SetBodyJsonMarshal(v interface{}) *Request { @@ -566,7 +567,25 @@ func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { r.appendError(err) return r } - return r.SetBodyBytes(b) + return r.SetContentType(jsonContentType).SetBodyBytes(b) +} + +func SetBodyXmlString(body string) *Request { + return defaultClient.R().SetBodyXmlString(body) +} + +func (r *Request) SetBodyXmlString(body string) *Request { + r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(body)) + return r.SetContentType(xmlContentType) +} + +func SetBodyXmlBytes(body []byte) *Request { + return defaultClient.R().SetBodyXmlBytes(body) +} + +func (r *Request) SetBodyXmlBytes(body []byte) *Request { + r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) + return r.SetContentType(xmlContentType) } func SetBodyXmlMarshal(v interface{}) *Request { @@ -579,7 +598,7 @@ func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { r.appendError(err) return r } - return r.SetBodyBytes(b) + return r.SetContentType(xmlContentType).SetBodyBytes(b) } func SetContentType(contentType string) *Request { @@ -587,7 +606,7 @@ func SetContentType(contentType string) *Request { } func (r *Request) SetContentType(contentType string) *Request { - r.RawRequest.Header.Set(hdrContentTypeKey, contentType) + r.SetHeader(hdrContentTypeKey, contentType) return r } From 9c992a9b69fd8470fef1e540fb6266ea8dfbb952 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 20:33:54 +0800 Subject: [PATCH 147/843] add FormData to README --- README.md | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 63f59cfc..d6605605 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,8 @@ Simplified golang http client library with magic, happy sending requests, less c * [Features](#Features) * [Quick Start](#Quick-Start) * [Debugging](#Debugging) -* [Path Parameter and Query Parameter](#Param) +* [Path and Query Parameter](#Param) +* [Form Data](#From) * [Header and Cookie](#Header-Cookie) * [Body and Marshal/Unmarshal](#Header-Cookie) * [Custom Client and Root Certificates](#Cert) @@ -210,7 +211,7 @@ req.SetQueryParam("page", "2"). Get("https://api.example.com/repos") ``` -## Path Parameter and Query Parameter +## Path and Query Parameter **Set Path Parameter** @@ -270,6 +271,26 @@ resp2, err := client.Get(url2) ... ``` +## Form Data + +```go +client := req.C().EnableDumpOnlyRequest() +client.R().SetFormData(map[string]string{ + "username": "imroc", + "blog": "https://imroc.cc", +}).Post("https://httpbin.org/post") +/* Output +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +blog=https%3A%2F%2Fimroc.cc&username=imroc +*/ +``` + ## Header and Cookie **Set Header** From 8cf8ba3bdf1c8d99ece4a9d3627b4d3ae22b840c Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 20:34:57 +0800 Subject: [PATCH 148/843] fix README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d6605605..d4c8adce 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Simplified golang http client library with magic, happy sending requests, less c * [Quick Start](#Quick-Start) * [Debugging](#Debugging) * [Path and Query Parameter](#Param) -* [Form Data](#From) +* [Form Data](#Form) * [Header and Cookie](#Header-Cookie) * [Body and Marshal/Unmarshal](#Header-Cookie) * [Custom Client and Root Certificates](#Cert) From 9f26bbe83b43c3707f8a777fa48c8d65c2338f3c Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 21:00:38 +0800 Subject: [PATCH 149/843] update README: add RedirectPolicy --- README.md | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d4c8adce..0b87c1b3 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Simplified golang http client library with magic, happy sending requests, less c * [Features](#Features) * [Quick Start](#Quick-Start) * [Debugging](#Debugging) -* [Path and Query Parameter](#Param) +* [URL Path and Query Parameter](#Param) * [Form Data](#Form) * [Header and Cookie](#Header-Cookie) * [Body and Marshal/Unmarshal](#Header-Cookie) @@ -18,6 +18,7 @@ Simplified golang http client library with magic, happy sending requests, less c * [Download and Upload](#Download-Upload) * [Auto-Decoding](#AutoDecode) * [Request and Response Middleware](#Middleware) +* [Redirect Policy](#Redirect) ## Features @@ -211,7 +212,7 @@ req.SetQueryParam("page", "2"). Get("https://api.example.com/repos") ``` -## Path and Query Parameter +## URL Path and Query Parameter **Set Path Parameter** @@ -289,6 +290,22 @@ user-agent: req/v2 (https://github.com/imroc/req) blog=https%3A%2F%2Fimroc.cc&username=imroc */ + +// Multi value form data +criteria := url.Values{ + "multi": []string{"a", "b", "c"}, +} +client.R().SetFormDataFromValues(criteria).Post("https://httpbin.org/post") +/* Output +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +multi=a&multi=b&multi=c +*/ ``` ## Header and Cookie @@ -662,7 +679,7 @@ client := req.C() // Registering Request Middleware client.OnBeforeRequest(func(c *req.Client, r *req.Request) error { // You can access Client and current Request object to do something - // you need + // as you need return nil // return nil if it is success }) @@ -670,12 +687,43 @@ client.OnBeforeRequest(func(c *req.Client, r *req.Request) error { // Registering Response Middleware client.OnAfterResponse(func(c *req.Client, r *req.Response) error { // You can access Client and current Response object to do something - // you need + // as you need return nil // return nil if it is success }) ``` +## Redirect Policy + +```go +client := req.C().EnableDumpOnlyRequest() + +client.SetRedirectPolicy( + // Only allow up to 5 redirects + req.MaxRedirectPolicy(5), + // Only allow redirect to same domain. + // e.g. redirect "www.imroc.cc" to "imroc.cc" is allowed, but "google.com" is not + req.SameDomainRedirectPolicy(), +) + +client.SetRedirectPolicy( + // Only *.google.com/google.com and *.imroc.cc/imroc.cc is allowed to redirect + req.AllowedDomainRedirectPolicy("google.com", "imroc.cc"), + // Only allow redirect to same host. + // e.g. redirect "www.imroc.cc" to "imroc.cc" is not allowed, only "www.imroc.cc" is allowed + req.SameHostRedirectPolicy(), +) + +// All redirect is not allowd +client.SetRedirectPolicy(req.NoRedirectPolicy()) + +// Or customize the redirect with your own implementation +client.SetRedirectPolicy(func(req *http.Request, via []*http.Request) error { + // ... +}) + +``` + ## License `Req` released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file From 05a1661d7f51d1beb12da4ab8b1bacdf8ce46c90 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 21:10:21 +0800 Subject: [PATCH 150/843] update README: add proxy --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index 0b87c1b3..2a7830a8 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Simplified golang http client library with magic, happy sending requests, less c * [Auto-Decoding](#AutoDecode) * [Request and Response Middleware](#Middleware) * [Redirect Policy](#Redirect) +* [Proxy](#Proxy) ## Features @@ -721,7 +722,23 @@ client.SetRedirectPolicy(req.NoRedirectPolicy()) client.SetRedirectPolicy(func(req *http.Request, via []*http.Request) error { // ... }) +``` + +## Proxy + +`Req` use proxy `http.ProxyFromEnvironment` by default, which will read the `HTTP_PROXY/HTTPS_PROXY/http_proxy/https_proxy` environment variable, and setup proxy if environment variable is been set. You can customize it if you need: + +```go +// Set proxy from proxy url +client.SetProxyURL("http://myproxy:8080") + +// Custmize the proxy function with your own implementation +client.SetProxy(func(request *http.Request) (*url.URL, error) { + //... +}) +// Disable proxy +client.SetProxy(nil) ``` ## License From 436d1ca6d2d13f861c0bcf917171b86902ccec88 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 21:20:12 +0800 Subject: [PATCH 151/843] update README: Marshal/Unmarshal --- README.md | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2a7830a8..4f94c6b6 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Simplified golang http client library with magic, happy sending requests, less c * [Custom Client and Root Certificates](#Cert) * [Basic Auth and Bearer Token](#Auth) * [Download and Upload](#Download-Upload) -* [Auto-Decoding](#AutoDecode) +* [Auto-Decode](#AutoDecode) * [Request and Response Middleware](#Middleware) * [Redirect Policy](#Redirect) * [Proxy](#Proxy) @@ -540,6 +540,22 @@ if resp.IsSuccess() { yaml.Unmarshal(resp.Bytes()) ``` +**Customize Body Marshal/Unmarshal** +```go +// Example of registering json-iterator +import jsoniter "github.com/json-iterator/go" + +json := jsoniter.ConfigCompatibleWithStandardLibrary + +client := req.R() +client.JSONMarshal = json.Marshal +client.JSONUnmarshal = json.Unmarshal + +// Similarly, XML functions can also be customized +client.XMLMarshal +client.XMLUnmarshal +``` + **Disable Auto-Read Response Body** Response body will be read into memory if it's not a download request by default, you can disable it if you want (normally you don't need to do this). @@ -637,11 +653,14 @@ client.R().SetFile("pic", "test.jpg"). // Set form param name and filename }). SetFromDataFromValues(values). // You can also set form data using `url.Values` Post("http://127.0.0.1:8888/upload") -*/ +// You can also use io.Reader to upload +avatarImgFile, _ := os.Open("avatar.png") +client.R().SetFileReader("avatar", "avatar.png", avatarImgFile).Post(url) +*/ ``` -## Auto-Decoding +## Auto-Decode `Req` detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default. From 4de69cc42064a9fac53adf56a647d7338f2c83e9 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 21:22:35 +0800 Subject: [PATCH 152/843] update README: v2.0.0-beta.2 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4f94c6b6..89b5a9ff 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Simplified golang http client library with magic, happy sending requests, less c **Install** ``` sh -go get github.com/imroc/req/v2@v2.0.0-beta.1 +go get github.com/imroc/req/v2@v2.0.0-beta.2 ``` **Import** From e7eba5cd4cceece01bd1c8fda13e1e9f86616302 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 21:25:16 +0800 Subject: [PATCH 153/843] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 89b5a9ff..5512399f 100644 --- a/README.md +++ b/README.md @@ -293,10 +293,10 @@ blog=https%3A%2F%2Fimroc.cc&username=imroc */ // Multi value form data -criteria := url.Values{ +v := url.Values{ "multi": []string{"a", "b", "c"}, } -client.R().SetFormDataFromValues(criteria).Post("https://httpbin.org/post") +client.R().SetFormDataFromValues(v).Post("https://httpbin.org/post") /* Output :authority: httpbin.org :method: POST From 62a36eabdf669886e76f42eaff3ddd1bdd60e1a2 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 21:26:47 +0800 Subject: [PATCH 154/843] update README --- README.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/README.md b/README.md index 5512399f..94e9c663 100644 --- a/README.md +++ b/README.md @@ -332,8 +332,6 @@ Accept: application/json My-Custom-Header: My Custom Value User: imroc Accept-Encoding: gzip - -... */ // You can also set the common header and cookie for every request on client. @@ -389,8 +387,6 @@ User-Agent: req/v2 (https://github.com/imroc/req) Accept: application/json Cookie: imroc/req="This is my custome cookie value"; testcookie1="testcookie1 value"; testcookie2="testcookie2 value" Accept-Encoding: gzip - -... */ // You can also set the common cookie for every request on client. From 894b555929ecacf64c81b5d8f1484cc8a1475061 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 23:00:42 +0800 Subject: [PATCH 155/843] unexport marshal&unmarshal --- README.md | 12 ++++----- client.go | 66 +++++++++++++++++++++++++++++++++++++----------- middleware.go | 10 ++++---- request.go | 4 +-- response_body.go | 4 +-- 5 files changed, 66 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 94e9c663..b60a5f8f 100644 --- a/README.md +++ b/README.md @@ -536,20 +536,20 @@ if resp.IsSuccess() { yaml.Unmarshal(resp.Bytes()) ``` -**Customize Body Marshal/Unmarshal** +**Customize JSON&XML Marshal/Unmarshal** + ```go // Example of registering json-iterator import jsoniter "github.com/json-iterator/go" json := jsoniter.ConfigCompatibleWithStandardLibrary -client := req.R() -client.JSONMarshal = json.Marshal -client.JSONUnmarshal = json.Unmarshal +client := req.C(). + SetJsonMarshal(json.Marshal). + SetJsonUnmarshal(json.Unmarshal) // Similarly, XML functions can also be customized -client.XMLMarshal -client.XMLUnmarshal +client.SetXmlMarshal(xmlMarshalFunc).SetXmlUnmarshal(xmlUnmarshalFunc) ``` **Disable Auto-Read Response Body** diff --git a/client.go b/client.go index 8d5eec75..93095c3c 100644 --- a/client.go +++ b/client.go @@ -34,17 +34,17 @@ var defaultClient *Client = C() // Client is the req's http client. type Client struct { - BaseURL string - PathParams map[string]string - QueryParams urlpkg.Values - Headers http.Header - Cookies []*http.Cookie - JSONMarshal func(v interface{}) ([]byte, error) - JSONUnmarshal func(data []byte, v interface{}) error - XMLMarshal func(v interface{}) ([]byte, error) - XMLUnmarshal func(data []byte, v interface{}) error - DebugLog bool - + BaseURL string + PathParams map[string]string + QueryParams urlpkg.Values + Headers http.Header + Cookies []*http.Cookie + DebugLog bool + + jsonMarshal func(v interface{}) ([]byte, error) + jsonUnmarshal func(data []byte, v interface{}) error + xmlMarshal func(v interface{}) ([]byte, error) + xmlUnmarshal func(data []byte, v interface{}) error trace bool outputDirectory string disableAutoReadResponse bool @@ -768,6 +768,42 @@ func (c *Client) SetCookieJar(jar http.CookieJar) *Client { return c } +func SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { + return defaultClient.SetJsonMarshal(fn) +} + +func (c *Client) SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { + c.jsonMarshal = fn + return c +} + +func SetJsonUnmarshal(fn func(data []byte, v interface{}) error) *Client { + return defaultClient.SetJsonUnmarshal(fn) +} + +func (c *Client) SetJsonUnmarshal(fn func(data []byte, v interface{}) error) *Client { + c.jsonUnmarshal = fn + return c +} + +func SetXmlMarshal(fn func(v interface{}) ([]byte, error)) *Client { + return defaultClient.SetXmlMarshal(fn) +} + +func (c *Client) SetXmlMarshal(fn func(v interface{}) ([]byte, error)) *Client { + c.xmlMarshal = fn + return c +} + +func SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Client { + return defaultClient.SetXmlUnmarshal(fn) +} + +func (c *Client) SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Client { + c.xmlUnmarshal = fn + return c +} + // NewClient is the alias of C func NewClient() *Client { return C() @@ -832,10 +868,10 @@ func C() *Client { httpClient: httpClient, t: t, t2: t2, - JSONMarshal: json.Marshal, - JSONUnmarshal: json.Unmarshal, - XMLMarshal: xml.Marshal, - XMLUnmarshal: xml.Unmarshal, + jsonMarshal: json.Marshal, + jsonUnmarshal: json.Unmarshal, + xmlMarshal: xml.Marshal, + xmlUnmarshal: xml.Unmarshal, } httpClient.CheckRedirect = c.defaultCheckRedirect diff --git a/middleware.go b/middleware.go index 6ba8a4c6..94695eef 100644 --- a/middleware.go +++ b/middleware.go @@ -111,13 +111,13 @@ func handleMarshalBody(c *Client, r *Request) error { } if ct != "" { if util.IsXMLType(ct) { - body, err := c.XMLMarshal(r.marshalBody) + body, err := c.xmlMarshal(r.marshalBody) if err != nil { return err } r.SetBodyBytes(body) } else { - body, err := c.JSONMarshal(r.marshalBody) + body, err := c.jsonMarshal(r.marshalBody) if err != nil { return err } @@ -125,7 +125,7 @@ func handleMarshalBody(c *Client, r *Request) error { } return nil } - body, err := c.JSONMarshal(r.marshalBody) + body, err := c.jsonMarshal(r.marshalBody) if err != nil { return err } @@ -154,9 +154,9 @@ func unmarshalBody(c *Client, r *Response, v interface{}) (err error) { } ct := r.GetContentType() if util.IsJSONType(ct) { - return c.JSONUnmarshal(body, v) + return c.jsonUnmarshal(body, v) } else if util.IsXMLType(ct) { - return c.XMLUnmarshal(body, v) + return c.xmlUnmarshal(body, v) } return } diff --git a/request.go b/request.go index ecfbf0d5..e3e81af2 100644 --- a/request.go +++ b/request.go @@ -562,7 +562,7 @@ func SetBodyJsonMarshal(v interface{}) *Request { // SetBodyJsonMarshal set the request body that marshaled from object, and // set Content-Type header as "application/json; charset=utf-8" func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { - b, err := r.client.JSONMarshal(v) + b, err := r.client.jsonMarshal(v) if err != nil { r.appendError(err) return r @@ -593,7 +593,7 @@ func SetBodyXmlMarshal(v interface{}) *Request { } func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { - b, err := r.client.XMLMarshal(v) + b, err := r.client.xmlMarshal(v) if err != nil { r.appendError(err) return r diff --git a/response_body.go b/response_body.go index 3c1e069a..9c60e9cd 100644 --- a/response_body.go +++ b/response_body.go @@ -10,7 +10,7 @@ func (r *Response) UnmarshalJson(v interface{}) error { if err != nil { return err } - return r.Request.client.JSONUnmarshal(b, v) + return r.Request.client.jsonUnmarshal(b, v) } func (r *Response) UnmarshalXml(v interface{}) error { @@ -18,7 +18,7 @@ func (r *Response) UnmarshalXml(v interface{}) error { if err != nil { return err } - return r.Request.client.XMLUnmarshal(b, v) + return r.Request.client.xmlUnmarshal(b, v) } func (r *Response) Unmarshal(v interface{}) error { From b6e1a18337397a5b021dcac8caeb78d232665922 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 23:01:26 +0800 Subject: [PATCH 156/843] update README: v2.0.0-beta.3 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b60a5f8f..3c090518 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Simplified golang http client library with magic, happy sending requests, less c **Install** ``` sh -go get github.com/imroc/req/v2@v2.0.0-beta.2 +go get github.com/imroc/req/v2@v2.0.0-beta.3 ``` **Import** From ee5ab26989347e7e9b0814e23a8988d34b8089c1 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 23:09:30 +0800 Subject: [PATCH 157/843] update README --- README.md | 2 ++ examples/README.md | 4 ++++ 2 files changed, 6 insertions(+) create mode 100644 examples/README.md diff --git a/README.md b/README.md index 3c090518..cf460f6d 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,8 @@ resp, err := client.R(). // Use R() to create a request Get("https://api.github.com/users/{username}/repos") ``` +Checkout more runnable examples in the [examples](examples) direcotry. + ## Debugging **Dump the Content** diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..5ec6097e --- /dev/null +++ b/examples/README.md @@ -0,0 +1,4 @@ +# examples + +* [find-popular-repo](find-popular-repo): Invoke github api to find someone's most popular repo. +* [upload](upload): Use `req` to upload files. Contains a server written with `gin` and a client written with `req` \ No newline at end of file From 3844f9b36cf6f82ebc24fc19e08bdaf9e6ba3c43 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 23:20:18 +0800 Subject: [PATCH 158/843] update README --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index cf460f6d..5e782990 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# req - -[![GoDoc](https://pkg.go.dev/badge/github.com/imroc/req.svg)](https://pkg.go.dev/github.com/imroc/req) - -Simplified golang http client library with magic, happy sending requests, less code and more efficiency. +

+

Req

+

Simplified Golang HTTP client library with Black Magic, Less Code and More Efficiency.

+

+

**Table of Contents** From b7446f78758081f6d97735a454bcdbcf7ae8c7ec Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 23:23:27 +0800 Subject: [PATCH 159/843] update README --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 5e782990..b4672d64 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ -

-

Req

-

Simplified Golang HTTP client library with Black Magic, Less Code and More Efficiency.

-

+

+

Req

+

Simplified Golang HTTP client library with Black Magic, Less Code and More Efficiency.

+

**Table of Contents** From 91db98599c7df972d7d28b313fbc81d40d122aab Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 26 Jan 2022 23:24:37 +0800 Subject: [PATCH 160/843] update README --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b4672d64..5fc6b043 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ -

-

Req

-

Simplified Golang HTTP client library with Black Magic, Less Code and More Efficiency.

-

+

+

Req

+

Simplified Golang HTTP client library with Black Magic, Less Code and More Efficiency.

+

**Table of Contents** From 97dba0e4b038b90d7d7b56e0863ef0a277a2d86f Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 11:11:35 +0800 Subject: [PATCH 161/843] client add common form data --- README.md | 4 +++- client.go | 57 +++++++++++++++++++++++++++++++++++++++++++++------ middleware.go | 13 +++++++++++- 3 files changed, 66 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 5fc6b043..7b83b213 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,11 @@ * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. * Powerful and convenient debug utilites, including debug logs, performance traces, dump complete request and response content, even provide global wrapper methods to test with minimal code (see [Debugging](#Debugging). * Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decoding](#AutoDecode)). +* Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support. * Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. * Easy [Download and Upload](#Download-Upload). -* Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token, timeout, proxy, certs, redirect policy and so on for requests or clients. +* Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token for both client and request level. +* Easy set timeout, proxy, certs, redirect policy for client. * Support middleware before request sent and after got response ## Quick Start diff --git a/client.go b/client.go index 93095c3c..ea948fff 100644 --- a/client.go +++ b/client.go @@ -34,12 +34,14 @@ var defaultClient *Client = C() // Client is the req's http client. type Client struct { - BaseURL string - PathParams map[string]string - QueryParams urlpkg.Values - Headers http.Header - Cookies []*http.Cookie - DebugLog bool + BaseURL string + PathParams map[string]string + QueryParams urlpkg.Values + Headers http.Header + Cookies []*http.Cookie + FormData urlpkg.Values + DebugLog bool + AllowGetMethodPayload bool jsonMarshal func(v interface{}) ([]byte, error) jsonUnmarshal func(data []byte, v interface{}) error @@ -115,6 +117,36 @@ func (c *Client) R() *Request { } } +func SetCommonFormDataFromValues(data urlpkg.Values) *Client { + return defaultClient.SetCommonFormDataFromValues(data) +} + +func (c *Client) SetCommonFormDataFromValues(data urlpkg.Values) *Client { + if c.FormData == nil { + c.FormData = urlpkg.Values{} + } + for k, v := range data { + for _, kv := range v { + c.FormData.Add(k, kv) + } + } + return c +} + +func SetCommonFormData(data map[string]string) *Client { + return defaultClient.SetCommonFormData(data) +} + +func (c *Client) SetCommonFormData(data map[string]string) *Client { + if c.FormData == nil { + c.FormData = urlpkg.Values{} + } + for k, v := range data { + c.FormData.Set(k, v) + } + return c +} + func SetBaseURL(u string) *Client { return defaultClient.SetBaseURL(u) } @@ -804,6 +836,19 @@ func (c *Client) SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Cli return c } +func EnableAllowGetMethodPayload(a bool) *Client { + return defaultClient.EnableAllowGetMethodPayload(a) +} + +func (c *Client) EnableAllowGetMethodPayload(a bool) *Client { + c.AllowGetMethodPayload = a + return c +} + +func (c *Client) isPayloadForbid(m string) bool { + return (m == http.MethodGet && !c.AllowGetMethodPayload) || m == http.MethodHead || m == http.MethodOptions +} + // NewClient is the alias of C func NewClient() *Client { return C() diff --git a/middleware.go b/middleware.go index 94695eef..402336bd 100644 --- a/middleware.go +++ b/middleware.go @@ -134,13 +134,24 @@ func handleMarshalBody(c *Client, r *Request) error { } func parseRequestBody(c *Client, r *Request) (err error) { - if r.isMultiPart { + if c.isPayloadForbid(r.RawRequest.Method) { + return + } + // handle multipart + if r.isMultiPart && (r.RawRequest.Method != http.MethodPatch) { return handleMultiPart(c, r) } + + // handle form data + if len(c.FormData) > 0 { + r.SetFormDataFromValues(c.FormData) + } if len(r.FormData) > 0 { handleFormData(r) return } + + // handle marshal body if r.marshalBody != nil { handleMarshalBody(c, r) } From 0d216fc1c7f5043f59d20ad924dd9ae1ad87035f Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 11:31:48 +0800 Subject: [PATCH 162/843] update README --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7b83b213..9cf0ac6e 100644 --- a/README.md +++ b/README.md @@ -24,14 +24,14 @@ ## Features * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. -* Powerful and convenient debug utilites, including debug logs, performance traces, dump complete request and response content, even provide global wrapper methods to test with minimal code (see [Debugging](#Debugging). +* Powerful and convenient debug utilites, including debug logs, performance traces, dump complete request and response content, and even provide global wrapper methods to test with minimal code (see [Debugging](#Debugging). * Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decoding](#AutoDecode)). * Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support. * Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. * Easy [Download and Upload](#Download-Upload). * Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token for both client and request level. * Easy set timeout, proxy, certs, redirect policy for client. -* Support middleware before request sent and after got response +* Support middleware before request sent and after got response. ## Quick Start @@ -279,6 +279,8 @@ resp2, err := client.Get(url2) ## Form Data +Use `SetFormData` or `SetFormDataFromValues` to set form data (`GET`, `HEAD`, and `OPTIONS` requests ignores form data by default). + ```go client := req.C().EnableDumpOnlyRequest() client.R().SetFormData(map[string]string{ @@ -311,6 +313,10 @@ user-agent: req/v2 (https://github.com/imroc/req) multi=a&multi=b&multi=c */ + +// You can also set form data in client level +client.SetCommonFormData(m) +client.SetCommonFormDataFromValues(v) ``` ## Header and Cookie From cd032bfcb6922e1fdcadc1c0e11fce3e528365fd Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 11:32:30 +0800 Subject: [PATCH 163/843] update README: v2.0.0-beta.4 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9cf0ac6e..888fe9a2 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ **Install** ``` sh -go get github.com/imroc/req/v2@v2.0.0-beta.3 +go get github.com/imroc/req/v2@v2.0.0-beta.4 ``` **Import** From 490240b1723cbb864f14a044c755d397617533b3 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 13:27:07 +0800 Subject: [PATCH 164/843] rename SetCert to SetCerts --- README.md | 27 ++++++++++++++------------- client.go | 16 ++++++++-------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 888fe9a2..f1e1a913 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ * [Form Data](#Form) * [Header and Cookie](#Header-Cookie) * [Body and Marshal/Unmarshal](#Header-Cookie) -* [Custom Client and Root Certificates](#Cert) +* [Custom Certificates](#Cert) * [Basic Auth and Bearer Token](#Auth) * [Download and Upload](#Download-Upload) * [Auto-Decode](#AutoDecode) @@ -238,7 +238,7 @@ client.R(). */ // You can also set the common PathParam for every request on client -client.SetPathParam(k1, v1).SetPathParams(pathParams) +client.SetCommonPathParam(k1, v1).SetCommonPathParams(pathParams) resp1, err := client.Get(url1) ... @@ -267,9 +267,9 @@ client.R(). */ // You can also set the common QueryParam for every request on client -client.SetQueryParam(k, v). - SetQueryParams(queryParams). - SetQueryString(queryString). +client.SetCommonQueryParam(k, v). + SetCommonQueryParams(queryParams). + SetCommonQueryString(queryString). resp1, err := client.Get(url1) ... @@ -279,8 +279,6 @@ resp2, err := client.Get(url2) ## Form Data -Use `SetFormData` or `SetFormDataFromValues` to set form data (`GET`, `HEAD`, and `OPTIONS` requests ignores form data by default). - ```go client := req.C().EnableDumpOnlyRequest() client.R().SetFormData(map[string]string{ @@ -319,9 +317,12 @@ client.SetCommonFormData(m) client.SetCommonFormDataFromValues(v) ``` +> `GET`, `HEAD`, and `OPTIONS` requests ignores form data by default + ## Header and Cookie **Set Header** + ```go // Let's dump the header to see what's going on client := req.C().EnableDumpOnlyHeader() @@ -345,7 +346,7 @@ Accept-Encoding: gzip */ // You can also set the common header and cookie for every request on client. -client.SetHeader(header).SetHeaders(headers) +client.SetCommonHeader(header).SetCommonHeaders(headers) resp1, err := client.R().Get(url1) ... @@ -400,7 +401,7 @@ Accept-Encoding: gzip */ // You can also set the common cookie for every request on client. -client.SetCookie(cookie).SetCookies(cookies) +client.SetCommonCookie(cookie).SetCommonCookies(cookies) resp1, err := client.R().Get(url1) ... @@ -438,7 +439,7 @@ test */ // If it cannot determine, like map and struct, then it will wait -// and marshal to json or xml automatically according to the `Content-Type` +// and marshal to JSON or XML automatically according to the `Content-Type` // header that have been set before or after, default to json if not set. type User struct { Name string `json:"name"` @@ -576,13 +577,13 @@ if err != nil { io.Copy(dst, resp.Body) ``` -## Custom Client and Root Certificates +## Custom Certificates ```go client := req.R() // Set root cert and client cert from file path -client.SetRootCertFromFile("/path/to/root/certs/pemFile1.pem", "/path/to/root/certs/pemFile2.pem", "/path/to/root/certs/pemFile3.pem"). // Set root cert from one or more pem files +client.SetRootCertsFromFile("/path/to/root/certs/pemFile1.pem", "/path/to/root/certs/pemFile2.pem", "/path/to/root/certs/pemFile3.pem"). // Set root cert from one or more pem files SetCertFromFile("/path/to/client/certs/client.pem", "/path/to/client/certs/client.key") // Set client cert and key cert file // You can also set root cert from string @@ -596,7 +597,7 @@ if err != nil { // ... // you can add more certs if you want -client.SetCert(cert1, cert2, cert3) +client.SetCerts(cert1, cert2, cert3) ``` ## Basic Auth and Bearer Token diff --git a/client.go b/client.go index ea948fff..788b18ad 100644 --- a/client.go +++ b/client.go @@ -181,12 +181,12 @@ func (c *Client) SetCertFromFile(certFile, keyFile string) *Client { return c } -func SetCert(certs ...tls.Certificate) *Client { - return defaultClient.SetCert(certs...) +func SetCerts(certs ...tls.Certificate) *Client { + return defaultClient.SetCerts(certs...) } -// SetCert helps to set client certificates -func (c *Client) SetCert(certs ...tls.Certificate) *Client { +// SetCerts helps to set client certificates +func (c *Client) SetCerts(certs ...tls.Certificate) *Client { config := c.tlsConfig() config.Certificates = append(config.Certificates, certs...) return c @@ -211,12 +211,12 @@ func (c *Client) SetRootCertFromString(pemContent string) *Client { return c } -func SetRootCertFromFile(pemFiles ...string) *Client { - return defaultClient.SetRootCertFromFile(pemFiles...) +func SetRootCertsFromFile(pemFiles ...string) *Client { + return defaultClient.SetRootCertsFromFile(pemFiles...) } -// SetRootCertFromFile helps to set root cert from files -func (c *Client) SetRootCertFromFile(pemFiles ...string) *Client { +// SetRootCertsFromFile helps to set root certs from files +func (c *Client) SetRootCertsFromFile(pemFiles ...string) *Client { for _, pemFile := range pemFiles { rootPemData, err := ioutil.ReadFile(pemFile) if err != nil { From 74ee944c92990dc7307f72eaf171aa6722840784 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 14:30:35 +0800 Subject: [PATCH 165/843] enhance trace --- README.md | 32 ++++++++++++++++---------------- trace.go | 25 +++++++++++++++---------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index f1e1a913..7114407c 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,7 @@ opt.ResponseBody = false client.R().Get("https://www.baidu.com/") ``` -**EnableDebugLog for Deeper Insights** +**Enable DebugLog for Deeper Insights** ```go // Logging is enabled by default, but only output the warning and error message. @@ -152,7 +152,7 @@ client.SetLogger(nil) client.SetLogger(logger) ``` -**EnableTrace to Analyze Performance** +**Enable Trace to Analyze Performance** ```go // Enable trace at request level @@ -161,23 +161,23 @@ resp, err := client.R().EnableTrace(true).Get("https://api.github.com/users/imro if err != nil { log.Fatal(err) } -ti := resp.TraceInfo() // Use `resp.Request.TraceInfo()` to avoid unnecessary copy in production -fmt.Println(ti) -fmt.Println("--------") -k, v := ti.MaxTime() -fmt.Printf("Max time is %s which tooks %v\n", k, v) +trace := resp.TraceInfo() // Use `resp.Request.TraceInfo()` to avoid unnecessary struct copy in production. +fmt.Println(trace.Blame()) // Print out exactly where the http request is slowing down. +fmt.Println("----------") +fmt.Println(trace) // Print details /* Output -TotalTime : 1.342805875s -DNSLookupTime : 7.549292ms -TCPConnectTime : 567.833µs -TLSHandshakeTime : 536.604041ms -FirstResponseTime : 797.466708ms -ResponseTime : 374.875µs +request total time is 1.962598667s, and server respond frist byte since connection ready costs 1.311601416s +---------- +TotalTime : 1.962598667s +DNSLookupTime : 3.604917ms +TCPConnectTime : 610µs +TLSHandshakeTime : 644.718542ms +FirstResponseTime : 1.311601416s +ResponseTime : 2.002209ms IsConnReused: : false -RemoteAddr : 192.30.255.117:443 --------- -Max time is FirstResponseTime which tooks 797.466708ms +RemoteAddr : 98.126.155.187:443 + */ // Enable trace at client level diff --git a/trace.go b/trace.go index 6b1a8aa5..b01fe25a 100644 --- a/trace.go +++ b/trace.go @@ -25,21 +25,26 @@ IsConnReused: : true RemoteAddr : %v` ) -func (t TraceInfo) MaxTime() (maxName string, maxValue time.Duration) { +func (t TraceInfo) Blame() string { + var mk string + var mv time.Duration m := map[string]time.Duration{ - "DNSLookupTime": t.DNSLookupTime, - "TCPConnectTime": t.TCPConnectTime, - "TLSHandshakeTime": t.TLSHandshakeTime, - "FirstResponseTime": t.FirstResponseTime, - "ResponseTime": t.ResponseTime, + "dns lookup": t.DNSLookupTime, + "tcp connect": t.TCPConnectTime, + "tls handshake": t.TLSHandshakeTime, + "server respond frist byte since connection ready": t.FirstResponseTime, + "server respond header and body": t.ResponseTime, } for k, v := range m { - if v > maxValue { - maxName = k - maxValue = v + if v > mv { + mk = k + mv = v } } - return + if mk == "" { + return "" + } + return fmt.Sprintf("request total time is %v, and %s costs %v", t.TotalTime, mk, mv) } func (t TraceInfo) String() string { From 0384044fe2b8ac764c4ff04581ac084b9a1ccb8f Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 14:43:30 +0800 Subject: [PATCH 166/843] improve trace --- README.md | 17 ++++++++--------- trace.go | 22 +++++++++++++--------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 7114407c..57507614 100644 --- a/README.md +++ b/README.md @@ -167,17 +167,16 @@ fmt.Println("----------") fmt.Println(trace) // Print details /* Output -request total time is 1.962598667s, and server respond frist byte since connection ready costs 1.311601416s ----------- -TotalTime : 1.962598667s -DNSLookupTime : 3.604917ms -TCPConnectTime : 610µs -TLSHandshakeTime : 644.718542ms -FirstResponseTime : 1.311601416s -ResponseTime : 2.002209ms +the request total time is 2.562416041s, and costs 1.289082208s from connection ready to server respond frist byte +-------- +TotalTime : 2.562416041s +DNSLookupTime : 445.246375ms +TCPConnectTime : 428.458µs +TLSHandshakeTime : 825.888208ms +FirstResponseTime : 1.289082208s +ResponseTime : 1.712375ms IsConnReused: : false RemoteAddr : 98.126.155.187:443 - */ // Enable trace at client level diff --git a/trace.go b/trace.go index b01fe25a..c5849294 100644 --- a/trace.go +++ b/trace.go @@ -26,14 +26,17 @@ RemoteAddr : %v` ) func (t TraceInfo) Blame() string { + if t.RemoteAddr == nil { + return "trace is not enabled" + } var mk string var mv time.Duration m := map[string]time.Duration{ - "dns lookup": t.DNSLookupTime, - "tcp connect": t.TCPConnectTime, - "tls handshake": t.TLSHandshakeTime, - "server respond frist byte since connection ready": t.FirstResponseTime, - "server respond header and body": t.ResponseTime, + "on dns lookup": t.DNSLookupTime, + "on tcp connect": t.TCPConnectTime, + "on tls handshake": t.TLSHandshakeTime, + "from connection ready to server respond frist byte": t.FirstResponseTime, + "from server respond frist byte to request completion": t.ResponseTime, } for k, v := range m { if v > mv { @@ -42,14 +45,14 @@ func (t TraceInfo) Blame() string { } } if mk == "" { - return "" + return "nothing to blame" } - return fmt.Sprintf("request total time is %v, and %s costs %v", t.TotalTime, mk, mv) + return fmt.Sprintf("the request total time is %v, and costs %v %s", t.TotalTime, mv, mk) } func (t TraceInfo) String() string { if t.RemoteAddr == nil { - return "uncompleted request" + return "trace is not enabled" } if t.IsConnReused { return fmt.Sprintf(traceReusedFmt, t.TotalTime, t.FirstResponseTime, t.ResponseTime, t.RemoteAddr) @@ -72,7 +75,8 @@ type TraceInfo struct { // TLSHandshakeTime is a duration that TLS handshake took place. TLSHandshakeTime time.Duration - // FirstResponseTime is a duration that server took to respond first byte. + // FirstResponseTime is a duration that server took to respond first byte since + // connection ready (after tls handshake if it's tls and not a reused connection). FirstResponseTime time.Duration // ResponseTime is a duration since first response byte from server to From 05a39ecdc145b8338bc9d2f1f075b2ecaed8c23a Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 16:38:54 +0800 Subject: [PATCH 167/843] small refactor and add comments for client --- README.md | 23 ++--- client.go | 288 +++++++++++++++++++++++++++++++++++++++++++---------- request.go | 17 +--- 3 files changed, 249 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index 57507614..a6ab0e49 100644 --- a/README.md +++ b/README.md @@ -25,12 +25,13 @@ * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. * Powerful and convenient debug utilites, including debug logs, performance traces, dump complete request and response content, and even provide global wrapper methods to test with minimal code (see [Debugging](#Debugging). -* Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decoding](#AutoDecode)). +* Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decode](#AutoDecode)). +* Automatic marshal and unmarshal for JSON and XML content type and fully customizable. * Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support. * Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. * Easy [Download and Upload](#Download-Upload). * Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token for both client and request level. -* Easy set timeout, proxy, certs, redirect policy for client. +* Easy set timeout, proxy, certs, redirect policy, cookie jar, compression, keepalives etc for client. * Support middleware before request sent and after got response. ## Quick Start @@ -361,16 +362,8 @@ client := req.C().EnableDumpOnlyHeader() // Send a request with multiple headers and cookies client.R(). - SetCookie(&http.Cookie{ // Set one cookie - Name: "imroc/req", - Value: "This is my custome cookie value", - Path: "/", - Domain: "baidu.com", - MaxAge: 36000, - HttpOnly: false, - Secure: true, - }).SetCookies([]*http.Cookie{ // Set multiple cookies at once - &http.Cookie{ + SetCookies( + &http.Cookie{ Name: "testcookie1", Value: "testcookie1 value", Path: "/", @@ -388,19 +381,19 @@ client.R(). HttpOnly: false, Secure: true, }, - }).Get("https://www.baidu.com/") + ).Get("https://www.baidu.com/") /* Output GET / HTTP/1.1 Host: www.baidu.com User-Agent: req/v2 (https://github.com/imroc/req) Accept: application/json -Cookie: imroc/req="This is my custome cookie value"; testcookie1="testcookie1 value"; testcookie2="testcookie2 value" +Cookie: testcookie1="testcookie1 value"; testcookie2="testcookie2 value" Accept-Encoding: gzip */ // You can also set the common cookie for every request on client. -client.SetCommonCookie(cookie).SetCommonCookies(cookies) +client.SetCommonCookies(cookie1, cookie2, cookie3) resp1, err := client.R().Get(url1) ... diff --git a/client.go b/client.go index 788b18ad..08bdb58f 100644 --- a/client.go +++ b/client.go @@ -56,12 +56,20 @@ type Client struct { t2 *http2Transport dumpOptions *DumpOptions httpClient *http.Client - jsonDecoder *json.Decoder beforeRequest []RequestMiddleware udBeforeRequest []RequestMiddleware afterResponse []ResponseMiddleware } +func cloneCookies(cookies []*http.Cookie) []*http.Cookie { + if len(cookies) == 0 { + return nil + } + c := make([]*http.Cookie, len(cookies)) + copy(c, cookies) + return c +} + func cloneHeaders(hdrs http.Header) http.Header { if hdrs == nil { return nil @@ -75,6 +83,25 @@ func cloneHeaders(hdrs http.Header) http.Header { return h } +// TODO: change to generics function when generics are commonly used. +func cloneRequestMiddleware(m []RequestMiddleware) []RequestMiddleware { + if len(m) == 0 { + return nil + } + mm := make([]RequestMiddleware, len(m)) + copy(mm, m) + return mm +} + +func cloneResponseMiddleware(m []ResponseMiddleware) []ResponseMiddleware { + if len(m) == 0 { + return nil + } + mm := make([]ResponseMiddleware, len(m)) + copy(mm, m) + return mm +} + func cloneUrlValues(v urlpkg.Values) urlpkg.Values { if v == nil { return nil @@ -99,6 +126,8 @@ func cloneMap(h map[string]string) map[string]string { return m } +// R is a global wrapper methods which delegated +// to the default client's R(). func R() *Request { return defaultClient.R() } @@ -117,10 +146,13 @@ func (c *Client) R() *Request { } } +// SetCommonFormDataFromValues is a global wrapper methods which delegated +// to the default client's SetCommonFormDataFromValues. func SetCommonFormDataFromValues(data urlpkg.Values) *Client { return defaultClient.SetCommonFormDataFromValues(data) } +// SetCommonFormDataFromValues set the form data from url.Values for all requests which method allows payload. func (c *Client) SetCommonFormDataFromValues(data urlpkg.Values) *Client { if c.FormData == nil { c.FormData = urlpkg.Values{} @@ -133,10 +165,13 @@ func (c *Client) SetCommonFormDataFromValues(data urlpkg.Values) *Client { return c } +// SetCommonFormData is a global wrapper methods which delegated +// to the default client's SetCommonFormData. func SetCommonFormData(data map[string]string) *Client { return defaultClient.SetCommonFormData(data) } +// SetCommonFormData set the form data from map for all requests which method allows payload. func (c *Client) SetCommonFormData(data map[string]string) *Client { if c.FormData == nil { c.FormData = urlpkg.Values{} @@ -147,24 +182,33 @@ func (c *Client) SetCommonFormData(data map[string]string) *Client { return c } +// SetBaseURL is a global wrapper methods which delegated +// to the default client's SetBaseURL. func SetBaseURL(u string) *Client { return defaultClient.SetBaseURL(u) } +// SetBaseURL set the default base url, will be used if request url is +// a relative url. func (c *Client) SetBaseURL(u string) *Client { c.BaseURL = strings.TrimRight(u, "/") return c } +// SetOutputDirectory is a global wrapper methods which delegated +// to the default client's SetOutputDirectory. func SetOutputDirectory(dir string) *Client { return defaultClient.SetOutputDirectory(dir) } +// SetOutputDirectory set output directory that response will be downloaded to. func (c *Client) SetOutputDirectory(dir string) *Client { c.outputDirectory = dir return c } +// SetCertFromFile is a global wrapper methods which delegated +// to the default client's SetCertFromFile. func SetCertFromFile(certFile, keyFile string) *Client { return defaultClient.SetCertFromFile(certFile, keyFile) } @@ -181,6 +225,8 @@ func (c *Client) SetCertFromFile(certFile, keyFile string) *Client { return c } +// SetCerts is a global wrapper methods which delegated +// to the default client's SetCerts. func SetCerts(certs ...tls.Certificate) *Client { return defaultClient.SetCerts(certs...) } @@ -201,6 +247,8 @@ func (c *Client) appendRootCertData(data []byte) { return } +// SetRootCertFromString is a global wrapper methods which delegated +// to the default client's SetRootCertFromString. func SetRootCertFromString(pemContent string) *Client { return defaultClient.SetRootCertFromString(pemContent) } @@ -211,6 +259,8 @@ func (c *Client) SetRootCertFromString(pemContent string) *Client { return c } +// SetRootCertsFromFile is a global wrapper methods which delegated +// to the default client's SetRootCertsFromFile. func SetRootCertsFromFile(pemFiles ...string) *Client { return defaultClient.SetRootCertsFromFile(pemFiles...) } @@ -245,11 +295,13 @@ func (c *Client) defaultCheckRedirect(req *http.Request, via []*http.Request) er return nil } +// SetRedirectPolicy is a global wrapper methods which delegated +// to the default client's SetRedirectPolicy. func SetRedirectPolicy(policies ...RedirectPolicy) *Client { return defaultClient.SetRedirectPolicy(policies...) } -// SetRedirectPolicy helps to set the RedirectPolicy +// SetRedirectPolicy helps to set the RedirectPolicy. func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { if len(policies) == 0 { return c @@ -272,37 +324,60 @@ func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { return c } +// DisableKeepAlives is a global wrapper methods which delegated +// to the default client's DisableKeepAlives. func DisableKeepAlives(disable bool) *Client { return defaultClient.DisableKeepAlives(disable) } +// DisableKeepAlives set to true disables HTTP keep-alives and +// will only use the connection to the server for a single +// HTTP request. +// +// This is unrelated to the similarly named TCP keep-alives. func (c *Client) DisableKeepAlives(disable bool) *Client { c.t.DisableKeepAlives = disable return c } +// DisableCompression is a global wrapper methods which delegated +// to the default client's DisableCompression. func DisableCompression(disable bool) *Client { return defaultClient.DisableCompression(disable) } +// DisableCompression set to true prevents the Transport from +// requesting compression with an "Accept-Encoding: gzip" +// request header when the Request contains no existing +// Accept-Encoding value. If the Transport requests gzip on +// its own and gets a gzipped response, it's transparently +// decoded in the Response.Body. However, if the user +// explicitly requested gzip it is not automatically +// uncompressed. func (c *Client) DisableCompression(disable bool) *Client { c.t.DisableCompression = disable return c } +// SetTLSClientConfig is a global wrapper methods which delegated +// to the default client's SetTLSClientConfig. func SetTLSClientConfig(conf *tls.Config) *Client { return defaultClient.SetTLSClientConfig(conf) } +// SetTLSClientConfig sets the client tls config. func (c *Client) SetTLSClientConfig(conf *tls.Config) *Client { c.t.TLSClientConfig = conf return c } +// SetCommonQueryParams is a global wrapper methods which delegated +// to the default client's SetCommonQueryParams. func SetCommonQueryParams(params map[string]string) *Client { return defaultClient.SetCommonQueryParams(params) } +// SetCommonQueryParams sets the URL query parameters with a map at client level. func (c *Client) SetCommonQueryParams(params map[string]string) *Client { for k, v := range params { c.SetCommonQueryParam(k, v) @@ -310,10 +385,14 @@ func (c *Client) SetCommonQueryParams(params map[string]string) *Client { return c } +// SetCommonQueryParam is a global wrapper methods which delegated +// to the default client's SetCommonQueryParam. func SetCommonQueryParam(key, value string) *Client { return defaultClient.SetCommonQueryParam(key, value) } +// SetCommonQueryParam set an URL query parameter with a key-value +// pair at client level. func (c *Client) SetCommonQueryParam(key, value string) *Client { if c.QueryParams == nil { c.QueryParams = make(urlpkg.Values) @@ -322,10 +401,13 @@ func (c *Client) SetCommonQueryParam(key, value string) *Client { return c } +// SetCommonQueryString is a global wrapper methods which delegated +// to the default client's SetCommonQueryString. func SetCommonQueryString(query string) *Client { return defaultClient.SetCommonQueryString(query) } +// SetCommonQueryString set URL query parameters using the raw query string. func (c *Client) SetCommonQueryString(query string) *Client { params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) if err == nil { @@ -343,34 +425,32 @@ func (c *Client) SetCommonQueryString(query string) *Client { return c } -func SetCommonCookie(hc *http.Cookie) *Client { - return defaultClient.SetCommonCookie(hc) +// SetCommonCookies is a global wrapper methods which delegated +// to the default client's SetCommonCookies. +func SetCommonCookies(cookies ...*http.Cookie) *Client { + return defaultClient.SetCommonCookies(cookies...) } -func (c *Client) SetCommonCookie(hc *http.Cookie) *Client { - c.Cookies = append(c.Cookies, hc) - return c -} - -func SetCommonCookies(cs []*http.Cookie) *Client { - return defaultClient.SetCommonCookies(cs) -} - -func (c *Client) SetCommonCookies(cs []*http.Cookie) *Client { - c.Cookies = append(c.Cookies, cs...) +// SetCommonCookies set cookies at client level. +func (c *Client) SetCommonCookies(cookies ...*http.Cookie) *Client { + c.Cookies = append(c.Cookies, cookies...) return c } +// EnableDebugLog is a global wrapper methods which delegated +// to the default client's EnableDebugLog. func EnableDebugLog(enable bool) *Client { return defaultClient.EnableDebugLog(enable) } +// EnableDebugLog enables debug level log if set to true. func (c *Client) EnableDebugLog(enable bool) *Client { c.DebugLog = enable return c } -// DevMode is a global wrapper method for default client. +// DevMode is a global wrapper methods which delegated +// to the default client's DevMode. func DevMode() *Client { return defaultClient.DevMode() } @@ -387,12 +467,14 @@ func (c *Client) DevMode() *Client { SetUserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36") } +// SetScheme is a global wrapper methods which delegated +// to the default client's SetScheme. func SetScheme(scheme string) *Client { return defaultClient.SetScheme(scheme) } -// SetScheme method sets custom scheme in the Resty client. It's way to override default. -// client.SetScheme("http") +// SetScheme sets custom default scheme in the client, will be used when +// there is no scheme in the request url. func (c *Client) SetScheme(scheme string) *Client { if !util.IsStringEmpty(scheme) { c.scheme = strings.TrimSpace(scheme) @@ -400,6 +482,8 @@ func (c *Client) SetScheme(scheme string) *Client { return c } +// SetLogger is a global wrapper methods which delegated +// to the default client's SetLogger. func SetLogger(log Logger) *Client { return defaultClient.SetLogger(log) } @@ -421,6 +505,8 @@ func (c *Client) getResponseOptions() *ResponseOptions { return c.t.ResponseOptions } +// SetTimeout is a global wrapper methods which delegated +// to the default client's SetTimeout. func SetTimeout(d time.Duration) *Client { return defaultClient.SetTimeout(d) } @@ -445,11 +531,13 @@ func (c *Client) enableDump() { c.t.EnableDump(c.getDumpOptions()) } +// EnableDumpToFile is a global wrapper methods which delegated +// to the default client's EnableDumpToFile. func EnableDumpToFile(filename string) *Client { return defaultClient.EnableDumpToFile(filename) } -// EnableDumpToFile indicates that the content should dump to the specified filename. +// EnableDumpToFile enables dump and save to the specified filename. func (c *Client) EnableDumpToFile(filename string) *Client { file, err := os.Create(filename) if err != nil { @@ -457,25 +545,30 @@ func (c *Client) EnableDumpToFile(filename string) *Client { return c } c.getDumpOptions().Output = file + c.enableDump() return c } +// EnableDumpTo is a global wrapper methods which delegated +// to the default client's EnableDumpTo. func EnableDumpTo(output io.Writer) *Client { return defaultClient.EnableDumpTo(output) } -// EnableDumpTo indicates that the content should dump to the specified destination. +// EnableDumpTo enables dump and save to the specified io.Writer. func (c *Client) EnableDumpTo(output io.Writer) *Client { c.getDumpOptions().Output = output c.enableDump() return c } +// EnableDumpAsync is a global wrapper methods which delegated +// to the default client's EnableDumpAsync. func EnableDumpAsync() *Client { return defaultClient.EnableDumpAsync() } -// EnableDumpAsync indicates that the dump should be done asynchronously, +// EnableDumpAsync enables dump and output asynchronously, // can be used for debugging in production environment without // affecting performance. func (c *Client) EnableDumpAsync() *Client { @@ -485,10 +578,14 @@ func (c *Client) EnableDumpAsync() *Client { return c } +// EnableDumpNoRequestBody is a global wrapper methods which delegated +// to the default client's EnableDumpNoRequestBody. func EnableDumpNoRequestBody() *Client { return defaultClient.EnableDumpNoRequestBody() } +// EnableDumpNoRequestBody enables dump with request body excluded, can be +// used in upload request to avoid dump the unreadable binary content. func (c *Client) EnableDumpNoRequestBody() *Client { o := c.getDumpOptions() o.ResponseHeader = true @@ -499,10 +596,14 @@ func (c *Client) EnableDumpNoRequestBody() *Client { return c } +// EnableDumpNoResponseBody is a global wrapper methods which delegated +// to the default client's EnableDumpNoResponseBody. func EnableDumpNoResponseBody() *Client { return defaultClient.EnableDumpNoResponseBody() } +// EnableDumpNoResponseBody enables dump with response body excluded, can be +// used in download request to avoid dump the unreadable binary content. func (c *Client) EnableDumpNoResponseBody() *Client { o := c.getDumpOptions() o.ResponseHeader = true @@ -513,11 +614,13 @@ func (c *Client) EnableDumpNoResponseBody() *Client { return c } +// EnableDumpOnlyResponse is a global wrapper methods which delegated +// to the default client's EnableDumpOnlyResponse. func EnableDumpOnlyResponse() *Client { return defaultClient.EnableDumpOnlyResponse() } -// EnableDumpOnlyResponse indicates that should dump the responses' head and response. +// EnableDumpOnlyResponse enables dump with only response included. func (c *Client) EnableDumpOnlyResponse() *Client { o := c.getDumpOptions() o.ResponseHeader = true @@ -528,11 +631,13 @@ func (c *Client) EnableDumpOnlyResponse() *Client { return c } +// EnableDumpOnlyRequest is a global wrapper methods which delegated +// to the default client's EnableDumpOnlyRequest. func EnableDumpOnlyRequest() *Client { return defaultClient.EnableDumpOnlyRequest() } -// EnableDumpOnlyRequest indicates that should dump the requests' head and response. +// EnableDumpOnlyRequest enables dump with only request included. func (c *Client) EnableDumpOnlyRequest() *Client { o := c.getDumpOptions() o.RequestHeader = true @@ -543,11 +648,13 @@ func (c *Client) EnableDumpOnlyRequest() *Client { return c } +// EnableDumpOnlyBody is a global wrapper methods which delegated +// to the default client's EnableDumpOnlyBody. func EnableDumpOnlyBody() *Client { return defaultClient.EnableDumpOnlyBody() } -// EnableDumpOnlyBody indicates that should dump the body of requests and responses. +// EnableDumpOnlyBody enables dump with only body included. func (c *Client) EnableDumpOnlyBody() *Client { o := c.getDumpOptions() o.RequestBody = true @@ -558,11 +665,13 @@ func (c *Client) EnableDumpOnlyBody() *Client { return c } +// EnableDumpOnlyHeader is a global wrapper methods which delegated +// to the default client's EnableDumpOnlyHeader. func EnableDumpOnlyHeader() *Client { return defaultClient.EnableDumpOnlyHeader() } -// EnableDumpOnlyHeader indicates that should dump the head of requests and responses. +// EnableDumpOnlyHeader enables dump with only header included. func (c *Client) EnableDumpOnlyHeader() *Client { o := c.getDumpOptions() o.RequestHeader = true @@ -573,11 +682,14 @@ func (c *Client) EnableDumpOnlyHeader() *Client { return c } +// EnableDumpAll is a global wrapper methods which delegated +// to the default client's EnableDumpAll. func EnableDumpAll() *Client { return defaultClient.EnableDumpAll() } -// EnableDumpAll indicates that should dump both requests and responses' head and body. +// EnableDumpAll enables dump with all content included, +// including both requests and responses' header and body func (c *Client) EnableDumpAll() *Client { o := c.getDumpOptions() o.RequestHeader = true @@ -588,6 +700,8 @@ func (c *Client) EnableDumpAll() *Client { return c } +// NewRequest is a global wrapper methods which delegated +// to the default client's NewRequest. func NewRequest() *Request { return defaultClient.R() } @@ -597,40 +711,53 @@ func (c *Client) NewRequest() *Request { return c.R() } +// DisableAutoReadResponse is a global wrapper methods which delegated +// to the default client's DisableAutoReadResponse. func DisableAutoReadResponse(disable bool) *Client { return defaultClient.DisableAutoReadResponse(disable) } +// DisableAutoReadResponse disable read response body automatically if set to true. func (c *Client) DisableAutoReadResponse(disable bool) *Client { c.disableAutoReadResponse = disable return c } +// SetAutoDecodeContentType is a global wrapper methods which delegated +// to the default client's SetAutoDecodeContentType. func SetAutoDecodeContentType(contentTypes ...string) *Client { return defaultClient.SetAutoDecodeContentType(contentTypes...) } +// SetAutoDecodeContentType set the content types that will be auto-detected and decode +// to utf-8 func (c *Client) SetAutoDecodeContentType(contentTypes ...string) *Client { opt := c.getResponseOptions() opt.AutoDecodeContentType = autoDecodeContentTypeFunc(contentTypes...) return c } +// SetAutoDecodeAllTypeFunc is a global wrapper methods which delegated +// to the default client's SetAutoDecodeAllTypeFunc. func SetAutoDecodeAllTypeFunc(fn func(contentType string) bool) *Client { return defaultClient.SetAutoDecodeAllTypeFunc(fn) } +// SetAutoDecodeAllTypeFunc set the custmize function that determins the content-type +// whether if should be auto-detected and decode to utf-8 func (c *Client) SetAutoDecodeAllTypeFunc(fn func(contentType string) bool) *Client { opt := c.getResponseOptions() opt.AutoDecodeContentType = fn return c } +// SetAutoDecodeAllType is a global wrapper methods which delegated +// to the default client's SetAutoDecodeAllType. func SetAutoDecodeAllType() *Client { return defaultClient.SetAutoDecodeAllType() } -// SetAutoDecodeAllType indicates that try autodetect and decode all content type. +// SetAutoDecodeAllType enables to try auto-detect and decode all content type to utf-8. func (c *Client) SetAutoDecodeAllType() *Client { opt := c.getResponseOptions() opt.AutoDecodeContentType = func(contentType string) bool { @@ -639,6 +766,8 @@ func (c *Client) SetAutoDecodeAllType() *Client { return c } +// DisableAutoDecode is a global wrapper methods which delegated +// to the default client's DisableAutoDecode. func DisableAutoDecode(disable bool) *Client { return defaultClient.DisableAutoDecode(disable) } @@ -649,6 +778,8 @@ func (c *Client) DisableAutoDecode(disable bool) *Client { return c } +// SetUserAgent is a global wrapper methods which delegated +// to the default client's SetUserAgent. func SetUserAgent(userAgent string) *Client { return defaultClient.SetUserAgent(userAgent) } @@ -658,27 +789,36 @@ func (c *Client) SetUserAgent(userAgent string) *Client { return c.SetCommonHeader(hdrUserAgentKey, userAgent) } +// SetCommonBearerAuthToken is a global wrapper methods which delegated +// to the default client's SetCommonBearerAuthToken. func SetCommonBearerAuthToken(token string) *Client { return defaultClient.SetCommonBearerAuthToken(token) } +// SetCommonBearerAuthToken set the bearer auth token for all requests. func (c *Client) SetCommonBearerAuthToken(token string) *Client { return c.SetCommonHeader("Authorization", "Bearer "+token) } +// SetCommonBasicAuth is a global wrapper methods which delegated +// to the default client's SetCommonBasicAuth. func SetCommonBasicAuth(username, password string) *Client { return defaultClient.SetCommonBasicAuth(username, password) } +// SetCommonBasicAuth set the basic auth for all requests. func (c *Client) SetCommonBasicAuth(username, password string) *Client { c.SetCommonHeader("Authorization", util.BasicAuthHeaderValue(username, password)) return c } +// SetCommonHeaders is a global wrapper methods which delegated +// to the default client's SetCommonHeaders. func SetCommonHeaders(hdrs map[string]string) *Client { return defaultClient.SetCommonHeaders(hdrs) } +// SetCommonHeaders set headers for all requests. func (c *Client) SetCommonHeaders(hdrs map[string]string) *Client { for k, v := range hdrs { c.SetCommonHeader(k, v) @@ -686,11 +826,13 @@ func (c *Client) SetCommonHeaders(hdrs map[string]string) *Client { return c } +// SetCommonHeader is a global wrapper methods which delegated +// to the default client's SetCommonHeader. func SetCommonHeader(key, value string) *Client { return defaultClient.SetCommonHeader(key, value) } -// SetCommonHeader set the common header for all requests. +// SetCommonHeader set a header for all requests. func (c *Client) SetCommonHeader(key, value string) *Client { if c.Headers == nil { c.Headers = make(http.Header) @@ -699,22 +841,27 @@ func (c *Client) SetCommonHeader(key, value string) *Client { return c } +// SetCommonContentType is a global wrapper methods which delegated +// to the default client's SetCommonContentType. func SetCommonContentType(ct string) *Client { return defaultClient.SetCommonContentType(ct) } +// SetCommonContentType set the `Content-Type` header for all requests. func (c *Client) SetCommonContentType(ct string) *Client { c.SetCommonHeader(hdrContentTypeKey, ct) return c } +// EnableDump is a global wrapper methods which delegated +// to the default client's EnableDump. func EnableDump(enable bool) *Client { return defaultClient.EnableDump(enable) } -// EnableDump enables dump requests and responses, allowing you -// to clearly see the content of all requests and responses,which -// is very convenient for debugging APIs. +// EnableDump enables dump if set to true, will use a default options if +// not been set before, which dumps all the content of requests and +// responses to stdout. func (c *Client) EnableDump(enable bool) *Client { if !enable { c.t.DisableDump() @@ -724,6 +871,8 @@ func (c *Client) EnableDump(enable bool) *Client { return c } +// SetDumpOptions is a global wrapper methods which delegated +// to the default client's SetDumpOptions. func SetDumpOptions(opt *DumpOptions) *Client { return defaultClient.SetDumpOptions(opt) } @@ -740,6 +889,8 @@ func (c *Client) SetDumpOptions(opt *DumpOptions) *Client { return c } +// SetProxy is a global wrapper methods which delegated +// to the default client's SetProxy. func SetProxy(proxy func(*http.Request) (*urlpkg.URL, error)) *Client { return defaultClient.SetProxy(proxy) } @@ -750,28 +901,37 @@ func (c *Client) SetProxy(proxy func(*http.Request) (*urlpkg.URL, error)) *Clien return c } +// OnBeforeRequest is a global wrapper methods which delegated +// to the default client's OnBeforeRequest. func OnBeforeRequest(m RequestMiddleware) *Client { return defaultClient.OnBeforeRequest(m) } +// OnBeforeRequest add a request middleware which hooks before request sent. func (c *Client) OnBeforeRequest(m RequestMiddleware) *Client { c.udBeforeRequest = append(c.udBeforeRequest, m) return c } +// OnAfterResponse is a global wrapper methods which delegated +// to the default client's OnAfterResponse. func OnAfterResponse(m ResponseMiddleware) *Client { return defaultClient.OnAfterResponse(m) } +// OnAfterResponse add a response middleware which hooks after response received. func (c *Client) OnAfterResponse(m ResponseMiddleware) *Client { c.afterResponse = append(c.afterResponse, m) return c } +// SetProxyURL is a global wrapper methods which delegated +// to the default client's SetProxyURL. func SetProxyURL(proxyUrl string) *Client { return defaultClient.SetProxyURL(proxyUrl) } +// SetProxyURL set a proxy from the proxy url. func (c *Client) SetProxyURL(proxyUrl string) *Client { u, err := urlpkg.Parse(proxyUrl) if err != nil { @@ -782,64 +942,85 @@ func (c *Client) SetProxyURL(proxyUrl string) *Client { return c } +// EnableTraceAll is a global wrapper methods which delegated +// to the default client's EnableTraceAll. func EnableTraceAll(enable bool) *Client { return defaultClient.EnableTraceAll(enable) } +// EnableTraceAll enables the trace at client level if set to true. func (c *Client) EnableTraceAll(enable bool) *Client { c.trace = enable return c } +// SetCookieJar is a global wrapper methods which delegated +// to the default client's SetCookieJar. func SetCookieJar(jar http.CookieJar) *Client { return defaultClient.SetCookieJar(jar) } +// SetCookieJar set the `CookeJar` to the underlying `http.Client` func (c *Client) SetCookieJar(jar http.CookieJar) *Client { c.httpClient.Jar = jar return c } +// SetJsonMarshal is a global wrapper methods which delegated +// to the default client's SetJsonMarshal. func SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { return defaultClient.SetJsonMarshal(fn) } +// SetJsonMarshal set json marshal function which will be used to marshal request body. func (c *Client) SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { c.jsonMarshal = fn return c } +// SetJsonUnmarshal is a global wrapper methods which delegated +// to the default client's SetJsonUnmarshal. func SetJsonUnmarshal(fn func(data []byte, v interface{}) error) *Client { return defaultClient.SetJsonUnmarshal(fn) } +// SetJsonUnmarshal set the JSON unmarshal function which will be used to unmarshal response body. func (c *Client) SetJsonUnmarshal(fn func(data []byte, v interface{}) error) *Client { c.jsonUnmarshal = fn return c } +// SetXmlMarshal is a global wrapper methods which delegated +// to the default client's SetXmlMarshal. func SetXmlMarshal(fn func(v interface{}) ([]byte, error)) *Client { return defaultClient.SetXmlMarshal(fn) } +// SetXmlMarshal set the XML marshal function which will be used to marshal request body. func (c *Client) SetXmlMarshal(fn func(v interface{}) ([]byte, error)) *Client { c.xmlMarshal = fn return c } +// SetXmlUnmarshal is a global wrapper methods which delegated +// to the default client's SetXmlUnmarshal. func SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Client { return defaultClient.SetXmlUnmarshal(fn) } +// SetXmlUnmarshal set the XML unmarshal function which will be used to unmarshal response body. func (c *Client) SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Client { c.xmlUnmarshal = fn return c } +// EnableAllowGetMethodPayload is a global wrapper methods which delegated +// to the default client's EnableAllowGetMethodPayload. func EnableAllowGetMethodPayload(a bool) *Client { return defaultClient.EnableAllowGetMethodPayload(a) } +// EnableAllowGetMethodPayload allows sending GET method requests with body if set to true. func (c *Client) EnableAllowGetMethodPayload(a bool) *Client { c.AllowGetMethodPayload = a return c @@ -858,24 +1039,31 @@ func NewClient() *Client { func (c *Client) Clone() *Client { t := c.t.Clone() t2, _ := http2ConfigureTransports(t) - cc := *c.httpClient - cc.Transport = t - return &Client{ - httpClient: &cc, - t: t, - t2: t2, - dumpOptions: c.dumpOptions.Clone(), - jsonDecoder: c.jsonDecoder, - Headers: cloneHeaders(c.Headers), - PathParams: cloneMap(c.PathParams), - QueryParams: cloneUrlValues(c.QueryParams), - BaseURL: c.BaseURL, - scheme: c.scheme, - log: c.log, - beforeRequest: c.beforeRequest, - udBeforeRequest: c.udBeforeRequest, - disableAutoReadResponse: c.disableAutoReadResponse, - } + client := *c.httpClient + client.Transport = t + + cc := *c + cc.httpClient = &client + cc.t = t + cc.t2 = t2 + + cc.Headers = cloneHeaders(c.Headers) + cc.Cookies = cloneCookies(c.Cookies) + cc.PathParams = cloneMap(c.PathParams) + cc.QueryParams = cloneUrlValues(c.QueryParams) + cc.FormData = cloneUrlValues(c.FormData) + cc.beforeRequest = cloneRequestMiddleware(c.beforeRequest) + cc.udBeforeRequest = cloneRequestMiddleware(c.udBeforeRequest) + cc.afterResponse = cloneResponseMiddleware(c.afterResponse) + cc.dumpOptions = c.dumpOptions.Clone() + + cc.log = c.log + cc.jsonUnmarshal = c.jsonUnmarshal + cc.jsonMarshal = c.jsonMarshal + cc.xmlMarshal = c.xmlMarshal + cc.xmlUnmarshal = c.xmlUnmarshal + + return &cc } // C create a new client. @@ -935,7 +1123,6 @@ func setupRequest(r *Request) { } func (c *Client) do(r *Request) (resp *Response, err error) { - resp = &Response{} for _, f := range r.client.udBeforeRequest { @@ -943,7 +1130,6 @@ func (c *Client) do(r *Request) (resp *Response, err error) { return } } - for _, f := range r.client.beforeRequest { if err = f(r.client, r); err != nil { return diff --git a/request.go b/request.go index e3e81af2..dd402d6f 100644 --- a/request.go +++ b/request.go @@ -117,21 +117,12 @@ func (r *Request) SetFormData(data map[string]string) *Request { return r } -func SetCookie(hc *http.Cookie) *Request { - return defaultClient.R().SetCookie(hc) +func SetCookies(cookies ...*http.Cookie) *Request { + return defaultClient.R().SetCookies(cookies...) } -func (r *Request) SetCookie(hc *http.Cookie) *Request { - r.Cookies = append(r.Cookies, hc) - return r -} - -func SetCookies(rs []*http.Cookie) *Request { - return defaultClient.R().SetCookies(rs) -} - -func (r *Request) SetCookies(rs []*http.Cookie) *Request { - r.Cookies = append(r.Cookies, rs...) +func (r *Request) SetCookies(cookies ...*http.Cookie) *Request { + r.Cookies = append(r.Cookies, cookies...) return r } From 0f77c398a937f6fe72b18864c2989799d102ae9c Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 17:41:47 +0800 Subject: [PATCH 168/843] complete comments --- middleware.go | 4 +- request.go | 144 ++++++++++++++++++++++++++++++++++++++++++++++---- response.go | 18 ++----- transport.go | 12 +++++ 4 files changed, 151 insertions(+), 27 deletions(-) diff --git a/middleware.go b/middleware.go index 402336bd..71c41745 100644 --- a/middleware.go +++ b/middleware.go @@ -198,7 +198,7 @@ func handleDownload(c *Client, r *Response) (err error) { body = r.Body } - var output io.WriteCloser + var output io.Writer if r.Request.outputFile != "" { file := r.Request.outputFile if c.outputDirectory != "" && !filepath.IsAbs(file) { @@ -220,7 +220,7 @@ func handleDownload(c *Client, r *Response) (err error) { defer func() { body.Close() - output.Close() + closeq(output) }() _, err = io.Copy(output, body) r.setReceivedAt() diff --git a/request.go b/request.go index dd402d6f..ff89022a 100644 --- a/request.go +++ b/request.go @@ -36,10 +36,11 @@ type Request struct { uploadReader []io.ReadCloser outputFile string isSaveResponse bool - output io.WriteCloser + output io.Writer trace *clientTrace } +// TraceInfo returns the trace information, only available when trace is enabled. func (r *Request) TraceInfo() TraceInfo { ct := r.trace @@ -87,10 +88,13 @@ func (r *Request) TraceInfo() TraceInfo { return ti } +// SetFormDataFromValues is a global wrapper methods which delegated +// to the default client, create a request and SetFormDataFromValues for request. func SetFormDataFromValues(data urlpkg.Values) *Request { return defaultClient.R().SetFormDataFromValues(data) } +// SetFormDataFromValues set the form data from url.Values, not used if method not allow payload. func (r *Request) SetFormDataFromValues(data urlpkg.Values) *Request { if r.FormData == nil { r.FormData = urlpkg.Values{} @@ -103,10 +107,13 @@ func (r *Request) SetFormDataFromValues(data urlpkg.Values) *Request { return r } +// SetFormData is a global wrapper methods which delegated +// to the default client, create a request and SetFormData for request. func SetFormData(data map[string]string) *Request { return defaultClient.R().SetFormData(data) } +// SetFormData set the form data from map, not used if method not allow payload. func (r *Request) SetFormData(data map[string]string) *Request { if r.FormData == nil { r.FormData = urlpkg.Values{} @@ -117,19 +124,25 @@ func (r *Request) SetFormData(data map[string]string) *Request { return r } +// SetCookies is a global wrapper methods which delegated +// to the default client, create a request and SetCookies for request. func SetCookies(cookies ...*http.Cookie) *Request { return defaultClient.R().SetCookies(cookies...) } +// SetCookies set cookies at request level. func (r *Request) SetCookies(cookies ...*http.Cookie) *Request { r.Cookies = append(r.Cookies, cookies...) return r } +// SetQueryString is a global wrapper methods which delegated +// to the default client, create a request and SetQueryString for request. func SetQueryString(query string) *Request { return defaultClient.R().SetQueryString(query) } +// SetQueryString set URL query parameters using the raw query string. func (r *Request) SetQueryString(query string) *Request { params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) if err == nil { @@ -144,10 +157,13 @@ func (r *Request) SetQueryString(query string) *Request { return r } +// SetFileReader is a global wrapper methods which delegated +// to the default client, create a request and SetFileReader for request. func SetFileReader(paramName, filePath string, reader io.Reader) *Request { return defaultClient.R().SetFileReader(paramName, filePath, reader) } +// SetFileReader sets up a multipart form with a reader to upload file. func (r *Request) SetFileReader(paramName, filePath string, reader io.Reader) *Request { r.isMultiPart = true r.uploadFiles = append(r.uploadFiles, &uploadFile{ @@ -158,10 +174,14 @@ func (r *Request) SetFileReader(paramName, filePath string, reader io.Reader) *R return r } +// SetFiles is a global wrapper methods which delegated +// to the default client, create a request and SetFiles for request. func SetFiles(files map[string]string) *Request { return defaultClient.R().SetFiles(files) } +// SetFiles sets up a multipart form from a map, which key is the param +// name, value is the file path. func (r *Request) SetFiles(files map[string]string) *Request { for k, v := range files { r.SetFile(k, v) @@ -169,10 +189,13 @@ func (r *Request) SetFiles(files map[string]string) *Request { return r } +// SetFile is a global wrapper methods which delegated +// to the default client, create a request and SetFile for request. func SetFile(paramName, filePath string) *Request { return defaultClient.R().SetFile(paramName, filePath) } +// SetFile sets up a multipart form, read file from filePath automatically to upload. func (r *Request) SetFile(paramName, filePath string) *Request { r.isMultiPart = true file, err := os.Open(filePath) @@ -189,44 +212,61 @@ func (r *Request) SetFile(paramName, filePath string) *Request { return r } +// SetResult is a global wrapper methods which delegated +// to the default client, create a request and SetResult for request. func SetResult(result interface{}) *Request { return defaultClient.R().SetResult(result) } +// SetResult set the result that response body will be unmarshaled to if +// request is success (status `code >= 200 and <= 299`). func (r *Request) SetResult(result interface{}) *Request { r.Result = util.GetPointer(result) return r } +// SetError is a global wrapper methods which delegated +// to the default client, create a request and SetError for request. func SetError(error interface{}) *Request { return defaultClient.R().SetError(error) } +// SetError set the result that response body will be unmarshaled to if +// request is error ( status `code >= 400`). func (r *Request) SetError(error interface{}) *Request { r.Error = util.GetPointer(error) return r } +// SetBearerAuthToken is a global wrapper methods which delegated +// to the default client, create a request and SetBearerAuthToken for request. func SetBearerAuthToken(token string) *Request { return defaultClient.R().SetBearerAuthToken(token) } +// SetBearerAuthToken set the bearer auth token at request level. func (r *Request) SetBearerAuthToken(token string) *Request { return r.SetHeader("Authorization", "Bearer "+token) } +// SetBasicAuth is a global wrapper methods which delegated +// to the default client, create a request and SetBasicAuth for request. func SetBasicAuth(username, password string) *Request { return defaultClient.R().SetBasicAuth(username, password) } +// SetBasicAuth set the basic auth at request level. func (r *Request) SetBasicAuth(username, password string) *Request { return r.SetHeader("Authorization", util.BasicAuthHeaderValue(username, password)) } +// SetHeaders is a global wrapper methods which delegated +// to the default client, create a request and SetHeaders for request. func SetHeaders(hdrs map[string]string) *Request { return defaultClient.R().SetHeaders(hdrs) } +// SetHeaders set the header at request level. func (r *Request) SetHeaders(hdrs map[string]string) *Request { for k, v := range hdrs { r.SetHeader(k, v) @@ -234,11 +274,13 @@ func (r *Request) SetHeaders(hdrs map[string]string) *Request { return r } +// SetHeader is a global wrapper methods which delegated +// to the default client, create a request and SetHeader for request. func SetHeader(key, value string) *Request { return defaultClient.R().SetHeader(key, value) } -// SetHeader set the common header for all requests. +// SetHeader set a header at request level. func (r *Request) SetHeader(key, value string) *Request { if r.Headers == nil { r.Headers = make(http.Header) @@ -247,30 +289,39 @@ func (r *Request) SetHeader(key, value string) *Request { return r } +// SetOutputFile is a global wrapper methods which delegated +// to the default client, create a request and SetOutputFile for request. func SetOutputFile(file string) *Request { return defaultClient.R().SetOutputFile(file) } +// SetOutputFile the file that response body will be downloaded to. func (r *Request) SetOutputFile(file string) *Request { r.isSaveResponse = true r.outputFile = file return r } -func SetOutput(output io.WriteCloser) *Request { +// SetOutput is a global wrapper methods which delegated +// to the default client, create a request and SetOutput for request. +func SetOutput(output io.Writer) *Request { return defaultClient.R().SetOutput(output) } -func (r *Request) SetOutput(output io.WriteCloser) *Request { +// SetOutput the io.Writer that response body will be downloaded to. +func (r *Request) SetOutput(output io.Writer) *Request { r.output = output r.isSaveResponse = true return r } +// SetQueryParams is a global wrapper methods which delegated +// to the default client, create a request and SetQueryParams for request. func SetQueryParams(params map[string]string) *Request { return defaultClient.R().SetQueryParams(params) } +// SetQueryParams sets the URL query parameters with a map at client level. func (r *Request) SetQueryParams(params map[string]string) *Request { for k, v := range params { r.SetQueryParam(k, v) @@ -278,10 +329,14 @@ func (r *Request) SetQueryParams(params map[string]string) *Request { return r } +// SetQueryParam is a global wrapper methods which delegated +// to the default client, create a request and SetQueryParam for request. func SetQueryParam(key, value string) *Request { return defaultClient.R().SetQueryParam(key, value) } +// SetQueryParam set an URL query parameter with a key-value +// pair at request level. func (r *Request) SetQueryParam(key, value string) *Request { if r.QueryParams == nil { r.QueryParams = make(urlpkg.Values) @@ -290,10 +345,13 @@ func (r *Request) SetQueryParam(key, value string) *Request { return r } +// SetPathParams is a global wrapper methods which delegated +// to the default client, create a request and SetPathParams for request. func SetPathParams(params map[string]string) *Request { return defaultClient.R().SetPathParams(params) } +// SetPathParams sets the URL path parameters from a map at request level. func (r *Request) SetPathParams(params map[string]string) *Request { for key, value := range params { r.SetPathParam(key, value) @@ -301,10 +359,13 @@ func (r *Request) SetPathParams(params map[string]string) *Request { return r } +// SetPathParam is a global wrapper methods which delegated +// to the default client, create a request and SetPathParam for request. func SetPathParam(key, value string) *Request { return defaultClient.R().SetPathParam(key, value) } +// SetPathParam sets the URL path parameters from a key-value paire at request level. func (r *Request) SetPathParam(key, value string) *Request { if r.PathParams == nil { r.PathParams = make(map[string]string) @@ -317,6 +378,7 @@ func (r *Request) appendError(err error) { r.error = multierror.Append(r.error, err) } +// Send sends the http request. func (r *Request) Send(method, url string) (*Response, error) { if r.error != nil { return &Response{}, r.error @@ -326,11 +388,13 @@ func (r *Request) Send(method, url string) (*Response, error) { return r.client.do(r) } +// MustGet is a global wrapper methods which delegated +// to the default client, create a request and MustGet for request. func MustGet(url string) *Response { return defaultClient.R().MustGet(url) } -// MustGet like Get, panic if error happens. +// MustGet like Get, panic if error happens, should only be used to test without error handling. func (r *Request) MustGet(url string) *Response { resp, err := r.Get(url) if err != nil { @@ -339,6 +403,8 @@ func (r *Request) MustGet(url string) *Response { return resp } +// Get is a global wrapper methods which delegated +// to the default client, create a request and Get for request. func Get(url string) (*Response, error) { return defaultClient.R().Get(url) } @@ -348,6 +414,8 @@ func (r *Request) Get(url string) (*Response, error) { return r.Send(http.MethodGet, url) } +// MustPost is a global wrapper methods which delegated +// to the default client, create a request and Get for request. func MustPost(url string) *Response { return defaultClient.R().MustPost(url) } @@ -361,6 +429,8 @@ func (r *Request) MustPost(url string) *Response { return resp } +// Post is a global wrapper methods which delegated +// to the default client, create a request and Post for request. func Post(url string) (*Response, error) { return defaultClient.R().Post(url) } @@ -370,11 +440,13 @@ func (r *Request) Post(url string) (*Response, error) { return r.Send(http.MethodPost, url) } +// MustPut is a global wrapper methods which delegated +// to the default client, create a request and MustPut for request. func MustPut(url string) *Response { return defaultClient.R().MustPut(url) } -// MustPut like Put, panic if error happens. +// MustPut like Put, panic if error happens, should only be used to test without error handling. func (r *Request) MustPut(url string) *Response { resp, err := r.Put(url) if err != nil { @@ -383,6 +455,8 @@ func (r *Request) MustPut(url string) *Response { return resp } +// Put is a global wrapper methods which delegated +// to the default client, create a request and Put for request. func Put(url string) (*Response, error) { return defaultClient.R().Put(url) } @@ -392,11 +466,13 @@ func (r *Request) Put(url string) (*Response, error) { return r.Send(http.MethodPut, url) } +// MustPatch is a global wrapper methods which delegated +// to the default client, create a request and MustPatch for request. func MustPatch(url string) *Response { return defaultClient.R().MustPatch(url) } -// MustPatch like Patch, panic if error happens. +// MustPatch like Patch, panic if error happens, should only be used to test without error handling. func (r *Request) MustPatch(url string) *Response { resp, err := r.Patch(url) if err != nil { @@ -405,6 +481,8 @@ func (r *Request) MustPatch(url string) *Response { return resp } +// Patch is a global wrapper methods which delegated +// to the default client, create a request and Patch for request. func Patch(url string) (*Response, error) { return defaultClient.R().Patch(url) } @@ -414,11 +492,13 @@ func (r *Request) Patch(url string) (*Response, error) { return r.Send(http.MethodPatch, url) } +// MustDelete is a global wrapper methods which delegated +// to the default client, create a request and MustDelete for request. func MustDelete(url string) *Response { return defaultClient.R().MustDelete(url) } -// MustDelete like Delete, panic if error happens. +// MustDelete like Delete, panic if error happens, should only be used to test without error handling. func (r *Request) MustDelete(url string) *Response { resp, err := r.Delete(url) if err != nil { @@ -427,6 +507,8 @@ func (r *Request) MustDelete(url string) *Response { return resp } +// Delete is a global wrapper methods which delegated +// to the default client, create a request and Delete for request. func Delete(url string) (*Response, error) { return defaultClient.R().Delete(url) } @@ -436,11 +518,13 @@ func (r *Request) Delete(url string) (*Response, error) { return r.Send(http.MethodDelete, url) } +// MustOptions is a global wrapper methods which delegated +// to the default client, create a request and MustOptions for request. func MustOptions(url string) *Response { return defaultClient.R().MustOptions(url) } -// MustOptions like Options, panic if error happens. +// MustOptions like Options, panic if error happens, should only be used to test without error handling. func (r *Request) MustOptions(url string) *Response { resp, err := r.Options(url) if err != nil { @@ -449,6 +533,8 @@ func (r *Request) MustOptions(url string) *Response { return resp } +// Options is a global wrapper methods which delegated +// to the default client, create a request and Options for request. func Options(url string) (*Response, error) { return defaultClient.R().Options(url) } @@ -458,11 +544,13 @@ func (r *Request) Options(url string) (*Response, error) { return r.Send(http.MethodOptions, url) } +// MustHead is a global wrapper methods which delegated +// to the default client, create a request and MustHead for request. func MustHead(url string) *Response { return defaultClient.R().MustHead(url) } -// MustHead like Head, panic if error happens. +// MustHead like Head, panic if error happens, should only be used to test without error handling. func (r *Request) MustHead(url string) *Response { resp, err := r.Send(http.MethodHead, url) if err != nil { @@ -471,6 +559,8 @@ func (r *Request) MustHead(url string) *Response { return resp } +// Head is a global wrapper methods which delegated +// to the default client, create a request and Head for request. func Head(url string) (*Response, error) { return defaultClient.R().Head(url) } @@ -480,11 +570,13 @@ func (r *Request) Head(url string) (*Response, error) { return r.Send(http.MethodHead, url) } +// SetBody is a global wrapper methods which delegated +// to the default client, create a request and SetBody for request. func SetBody(body interface{}) *Request { return defaultClient.R().SetBody(body) } -// SetBody set the request body. +// SetBody set the request body, accepts string, []byte, io.Reader, map and struct. func (r *Request) SetBody(body interface{}) *Request { if body == nil { return r @@ -504,6 +596,8 @@ func (r *Request) SetBody(body interface{}) *Request { return r } +// SetBodyBytes is a global wrapper methods which delegated +// to the default client, create a request and SetBodyBytes for request. func SetBodyBytes(body []byte) *Request { return defaultClient.R().SetBodyBytes(body) } @@ -514,6 +608,8 @@ func (r *Request) SetBodyBytes(body []byte) *Request { return r } +// SetBodyString is a global wrapper methods which delegated +// to the default client, create a request and SetBodyString for request. func SetBodyString(body string) *Request { return defaultClient.R().SetBodyString(body) } @@ -524,6 +620,8 @@ func (r *Request) SetBodyString(body string) *Request { return r } +// SetBodyJsonString is a global wrapper methods which delegated +// to the default client, create a request and SetBodyJsonString for request. func SetBodyJsonString(body string) *Request { return defaultClient.R().SetBodyJsonString(body) } @@ -535,6 +633,8 @@ func (r *Request) SetBodyJsonString(body string) *Request { return r.SetContentType(jsonContentType) } +// SetBodyJsonBytes is a global wrapper methods which delegated +// to the default client, create a request and SetBodyJsonBytes for request. func SetBodyJsonBytes(body []byte) *Request { return defaultClient.R().SetBodyJsonBytes(body) } @@ -546,6 +646,8 @@ func (r *Request) SetBodyJsonBytes(body []byte) *Request { return r.SetContentType(jsonContentType) } +// SetBodyJsonMarshal is a global wrapper methods which delegated +// to the default client, create a request and SetBodyJsonMarshal for request. func SetBodyJsonMarshal(v interface{}) *Request { return defaultClient.R().SetBodyJsonMarshal(v) } @@ -561,28 +663,40 @@ func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { return r.SetContentType(jsonContentType).SetBodyBytes(b) } +// SetBodyXmlString is a global wrapper methods which delegated +// to the default client, create a request and SetBodyXmlString for request. func SetBodyXmlString(body string) *Request { return defaultClient.R().SetBodyXmlString(body) } +// SetBodyXmlString set the request body as string and set Content-Type header +// as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlString(body string) *Request { r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(body)) return r.SetContentType(xmlContentType) } +// SetBodyXmlBytes is a global wrapper methods which delegated +// to the default client, create a request and SetBodyXmlBytes for request. func SetBodyXmlBytes(body []byte) *Request { return defaultClient.R().SetBodyXmlBytes(body) } +// SetBodyXmlBytes set the request body as []byte and set Content-Type header +// as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlBytes(body []byte) *Request { r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) return r.SetContentType(xmlContentType) } +// SetBodyXmlMarshal is a global wrapper methods which delegated +// to the default client, create a request and SetBodyXmlMarshal for request. func SetBodyXmlMarshal(v interface{}) *Request { return defaultClient.R().SetBodyXmlMarshal(v) } +// SetBodyXmlMarshal set the request body that marshaled from object, and +// set Content-Type header as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { b, err := r.client.xmlMarshal(v) if err != nil { @@ -592,10 +706,13 @@ func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { return r.SetContentType(xmlContentType).SetBodyBytes(b) } +// SetContentType is a global wrapper methods which delegated +// to the default client, create a request and SetContentType for request. func SetContentType(contentType string) *Request { return defaultClient.R().SetContentType(contentType) } +// SetContentType set the `Content-Type` for the request. func (r *Request) SetContentType(contentType string) *Request { r.SetHeader(hdrContentTypeKey, contentType) return r @@ -610,6 +727,8 @@ func (r *Request) Context() context.Context { return r.ctx } +// SetContext is a global wrapper methods which delegated +// to the default client, create a request and SetContext for request. func SetContext(ctx context.Context) *Request { return defaultClient.R().SetContext(ctx) } @@ -623,10 +742,13 @@ func (r *Request) SetContext(ctx context.Context) *Request { return r } +// EnableTrace is a global wrapper methods which delegated +// to the default client, create a request and EnableTrace for request. func EnableTrace(enable bool) *Request { return defaultClient.R().EnableTrace(enable) } +// EnableTrace enables trace if set to true. func (r *Request) EnableTrace(enable bool) *Request { if enable { if r.trace == nil { diff --git a/response.go b/response.go index ae43aed1..31f7e6a6 100644 --- a/response.go +++ b/response.go @@ -5,18 +5,6 @@ import ( "time" ) -// ResponseOptions determines that how should the response been processed. -type ResponseOptions struct { - // DisableAutoDecode, if true, prevents auto detect response - // body's charset and decode it to utf-8 - DisableAutoDecode bool - - // AutoDecodeContentType specifies an optional function for determine - // whether the response body should been auto decode to utf-8. - // Only valid when DisableAutoDecode is true. - AutoDecodeContentType func(contentType string) bool -} - var textContentTypes = []string{"text", "json", "xml", "html", "java"} var autoDecodeText = autoDecodeContentTypeFunc(textContentTypes...) @@ -49,20 +37,22 @@ func (r *Response) IsError() bool { return r.StatusCode > 399 } +// GetContentType return the `Content-Type` header value. func (r *Response) GetContentType() string { return r.Header.Get(hdrContentTypeKey) } -// Result method returns the response value as an object if it has one +// Result returns the response value as an object if it has one func (r *Response) Result() interface{} { return r.Request.Result } -// Error method returns the error object if it has one +// Error returns the error object if it has one. func (r *Response) Error() interface{} { return r.Request.Error } +// TraceInfo returns the TraceInfo from Request. func (r *Response) TraceInfo() TraceInfo { return r.Request.TraceInfo() } diff --git a/transport.go b/transport.go index 04109af1..a01a3ce6 100644 --- a/transport.go +++ b/transport.go @@ -44,6 +44,18 @@ import ( // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 +// ResponseOptions determines that how should the response been processed. +type ResponseOptions struct { + // DisableAutoDecode, if true, prevents auto detect response + // body's charset and decode it to utf-8 + DisableAutoDecode bool + + // AutoDecodeContentType specifies an optional function for determine + // whether the response body should been auto decode to utf-8. + // Only valid when DisableAutoDecode is true. + AutoDecodeContentType func(contentType string) bool +} + // Transport is an implementation of http.RoundTripper that supports HTTP, // HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT). // From 3243c1b45e1cb379261f0c35040e8d89c207add1 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 17:58:32 +0800 Subject: [PATCH 169/843] bump to v2.0.0 --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a6ab0e49..bcc30eea 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,13 @@

-**Table of Contents** +## Big News + +Brand new v2 version is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) + +If you want to use the older version, check it out in [v1](https://github.com/imroc/req/tree/v1) branch. + +## Table of Contents * [Features](#Features) * [Quick Start](#Quick-Start) @@ -39,7 +45,7 @@ **Install** ``` sh -go get github.com/imroc/req/v2@v2.0.0-beta.4 +go get github.com/imroc/req/v2 ``` **Import** From 7d206d96c61f342e4e4d3603911f9901c0be161b Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 19:18:02 +0800 Subject: [PATCH 170/843] update examples --- README.md | 2 +- examples/find-popular-repo/go.mod | 2 +- examples/find-popular-repo/go.sum | 4 ++-- examples/upload/uploadclient/go.mod | 2 +- examples/upload/uploadclient/go.sum | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index bcc30eea..1cbb6578 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@

-## Big News +##:warning Big News Brand new v2 version is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index 60430de2..c0d66c81 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -2,4 +2,4 @@ module find-popular-repo go 1.13 -require github.com/imroc/req/v2 v2.0.0-beta.1 +require github.com/imroc/req/v2 v2.0.0 diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index 0af9d86a..aa42d120 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -12,8 +12,8 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.0.0-beta.1 h1:eVlSdoboOJIz5tYeTw9Ik1Onb8KO02EJ0Ab/GmSli7U= -github.com/imroc/req/v2 v2.0.0-beta.1/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= +github.com/imroc/req/v2 v2.0.0 h1:KcRQ+V4e34cS/3sGn0KMtN9WOKqSXtCLSrZv6WnlU4k= +github.com/imroc/req/v2 v2.0.0/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod index 9ab14e93..f5b7a42a 100644 --- a/examples/upload/uploadclient/go.mod +++ b/examples/upload/uploadclient/go.mod @@ -2,4 +2,4 @@ module uploadclient go 1.13 -require github.com/imroc/req/v2 v2.0.0-beta.1 +require github.com/imroc/req/v2 v2.0.0 diff --git a/examples/upload/uploadclient/go.sum b/examples/upload/uploadclient/go.sum index 0af9d86a..aa42d120 100644 --- a/examples/upload/uploadclient/go.sum +++ b/examples/upload/uploadclient/go.sum @@ -12,8 +12,8 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.0.0-beta.1 h1:eVlSdoboOJIz5tYeTw9Ik1Onb8KO02EJ0Ab/GmSli7U= -github.com/imroc/req/v2 v2.0.0-beta.1/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= +github.com/imroc/req/v2 v2.0.0 h1:KcRQ+V4e34cS/3sGn0KMtN9WOKqSXtCLSrZv6WnlU4k= +github.com/imroc/req/v2 v2.0.0/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= From b3d3edf98cd27edb72a0e2962a0e474e0025fc66 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 19:18:42 +0800 Subject: [PATCH 171/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1cbb6578..ac3f2611 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@

-##:warning Big News +## :warning Big News Brand new v2 version is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) From eab748c47b95be934334f93492744d69d6719230 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 19:19:37 +0800 Subject: [PATCH 172/843] update README: news --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ac3f2611..a25834af 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@

-## :warning Big News +##:warning: Big News Brand new v2 version is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) From a1f4372e17ff453cb125d2436fb0a1e744929737 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 19:20:04 +0800 Subject: [PATCH 173/843] update README: news --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a25834af..a0095774 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@

-##:warning: Big News +## :warning: Big News Brand new v2 version is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) From c078587bc410f0c90ddf14efd81ae0c1c930bf5b Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 19:20:55 +0800 Subject: [PATCH 174/843] update README: remove warning --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a0095774..bcc30eea 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@

-## :warning: Big News +## Big News Brand new v2 version is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) From c35a0a44a1ab19c29c92b180197c34eea135c13a Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 19:21:54 +0800 Subject: [PATCH 175/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index bcc30eea..2ac39b16 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Brand new v2 version is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) -If you want to use the older version, check it out in [v1](https://github.com/imroc/req/tree/v1) branch. +If you want to use the older version, check it out on [v1 branch](https://github.com/imroc/req/tree/v1). ## Table of Contents From fe8b8aaf8e70c2a101f54324ee6581d23dcbdafc Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 19:29:55 +0800 Subject: [PATCH 176/843] update README: add MustXXX --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 2ac39b16..5a060979 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,14 @@ req.SetQueryParam("page", "2"). Get("https://api.example.com/repos") ``` +**Test with MustXXX** + +Use `MustXXX` to ignore error handling when test, make it possible to complete a complex test with just one line of code: + +```go +fmt.Println(req.DevMode().MustGet("https://imroc.cc").TraceInfo()) +``` + ## URL Path and Query Parameter **Set Path Parameter** From 76f98fd15de424d74362948bfe03147a4114f2d0 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 20:06:56 +0800 Subject: [PATCH 177/843] fix auto decode --- response.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/response.go b/response.go index 31f7e6a6..b146d2f3 100644 --- a/response.go +++ b/response.go @@ -2,6 +2,7 @@ package req import ( "net/http" + "strings" "time" ) @@ -10,12 +11,13 @@ var textContentTypes = []string{"text", "json", "xml", "html", "java"} var autoDecodeText = autoDecodeContentTypeFunc(textContentTypes...) func autoDecodeContentTypeFunc(contentTypes ...string) func(contentType string) bool { - m := make(map[string]bool) - for _, ct := range contentTypes { - m[ct] = true - } return func(contentType string) bool { - return m[contentType] + for _, ct := range contentTypes { + if strings.Contains(contentType, ct) { + return true + } + } + return false } } From e3cf275e021233f5ad88fd071e62fb78b7ee2cc1 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 27 Jan 2022 23:04:42 +0800 Subject: [PATCH 178/843] update README: GoDoc to v2 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5a060979..2a10de32 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@

Req

Simplified Golang HTTP client library with Black Magic, Less Code and More Efficiency.

-

+

## Big News From 00e85260ef40db2b91aef0d80bfd744afc705159 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 28 Jan 2022 09:18:28 +0800 Subject: [PATCH 179/843] update README: add TODO List --- README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2a10de32..c12f3263 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,8 @@ If you want to use the older version, check it out on [v1 branch](https://github * [Request and Response Middleware](#Middleware) * [Redirect Policy](#Redirect) * [Proxy](#Proxy) +* [TODO List](#TODO) +* [License](#License) ## Features @@ -773,6 +775,14 @@ client.SetProxy(func(request *http.Request) (*url.URL, error) { client.SetProxy(nil) ``` -## License +## TODO List + +* [ ] Add tests. +* [ ] Wrap more transport settings into client. +* [ ] Support retry. +* [ ] Support unix socket. +* [ ] Support h2c. + +## License `Req` released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file From bde29778f7b0b24857c4eefdbadbb5c8331d651d Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 28 Jan 2022 19:40:02 +0800 Subject: [PATCH 180/843] add transport_test.go --- internal/testcert/testcert.go | 46 ++++++ transport_test.go | 268 ++++++++++++++++++++++++++++++++++ 2 files changed, 314 insertions(+) create mode 100644 internal/testcert/testcert.go create mode 100644 transport_test.go diff --git a/internal/testcert/testcert.go b/internal/testcert/testcert.go new file mode 100644 index 00000000..5f94704e --- /dev/null +++ b/internal/testcert/testcert.go @@ -0,0 +1,46 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package testcert contains a test-only localhost certificate. +package testcert + +import "strings" + +// LocalhostCert is a PEM-encoded TLS cert with SAN IPs +// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. +// generated from src/crypto/tls: +// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var LocalhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB +iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 +iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul +rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO +BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw +AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA +AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 +tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs +h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM +fblo6RBxUQ== +-----END CERTIFICATE-----`) + +// LocalhostKey is the private key for LocalhostCert. +var LocalhostKey = []byte(testingKey(`-----BEGIN RSA TESTING KEY----- +MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 +SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB +l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB +AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet +3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb +uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H +qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp +jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY +fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U +fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU +y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX +qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo +f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== +-----END RSA TESTING KEY-----`)) + +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } diff --git a/transport_test.go b/transport_test.go new file mode 100644 index 00000000..ef7e9f45 --- /dev/null +++ b/transport_test.go @@ -0,0 +1,268 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// White-box tests for transport.go (in package http instead of http_test). + +package req + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "github.com/imroc/req/v2/internal/testcert" + "io" + "net" + "net/http" + "strings" + "testing" +) + +func withT(r *http.Request, t *testing.T) *http.Request { + return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) +} + +// Issue 15446: incorrect wrapping of errors when server closes an idle connection. +func TestTransportPersistConnReadLoopEOF(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + connc := make(chan net.Conn, 1) + go func() { + defer close(connc) + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + connc <- c + }() + + tr := new(Transport) + req, _ := http.NewRequest("GET", "http://"+ln.Addr().String(), nil) + req = withT(req, t) + treq := &transportRequest{Request: req} + cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()} + pc, err := tr.getConn(treq, cm) + if err != nil { + t.Fatal(err) + } + defer pc.close(errors.New("test over")) + + conn := <-connc + if conn == nil { + // Already called t.Error in the accept goroutine. + return + } + conn.Close() // simulate the server hanging up on the client + + _, err = pc.roundTrip(treq) + if !isTransportReadFromServerError(err) && err != errServerClosedIdle { + t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err) + } + + <-pc.closech + err = pc.closed + if !isTransportReadFromServerError(err) && err != errServerClosedIdle { + t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err) + } +} + +func isTransportReadFromServerError(err error) bool { + _, ok := err.(transportReadFromServerError) + return ok +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} + +func dummyRequest(method string) *http.Request { + req, err := http.NewRequest(method, "http://fake.tld/", nil) + if err != nil { + panic(err) + } + return req +} +func dummyRequestWithBody(method string) *http.Request { + req, err := http.NewRequest(method, "http://fake.tld/", strings.NewReader("foo")) + if err != nil { + panic(err) + } + return req +} + +func dummyRequestWithBodyNoGetBody(method string) *http.Request { + req := dummyRequestWithBody(method) + req.GetBody = nil + return req +} + +// issue22091Error acts like a golang.org/x/net/http2.ErrNoCachedConn. +type issue22091Error struct{} + +func (issue22091Error) IsHTTP2NoCachedConnError() {} +func (issue22091Error) Error() string { return "issue22091Error" } + +func TestTransportShouldRetryRequest(t *testing.T) { + tests := []struct { + pc *persistConn + req *http.Request + + err error + want bool + }{ + 0: { + pc: &persistConn{reused: false}, + req: dummyRequest("POST"), + err: nothingWrittenError{}, + want: false, + }, + 1: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: nothingWrittenError{}, + want: true, + }, + 2: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: http2ErrNoCachedConn, + want: true, + }, + 3: { + pc: nil, + req: nil, + err: issue22091Error{}, // like an external http2ErrNoCachedConn + want: true, + }, + 4: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: errMissingHost, + want: false, + }, + 5: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: transportReadFromServerError{}, + want: false, + }, + 6: { + pc: &persistConn{reused: true}, + req: dummyRequest("GET"), + err: transportReadFromServerError{}, + want: true, + }, + 7: { + pc: &persistConn{reused: true}, + req: dummyRequest("GET"), + err: errServerClosedIdle, + want: true, + }, + 8: { + pc: &persistConn{reused: true}, + req: dummyRequestWithBody("POST"), + err: nothingWrittenError{}, + want: true, + }, + 9: { + pc: &persistConn{reused: true}, + req: dummyRequestWithBodyNoGetBody("POST"), + err: nothingWrittenError{}, + want: false, + }, + } + for i, tt := range tests { + got := tt.pc.shouldRetryRequest(tt.req, tt.err) + if got != tt.want { + t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want) + } + } +} + +type roundTripFunc func(r *http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +// Issue 25009 +func TestTransportBodyAltRewind(t *testing.T) { + cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) + if err != nil { + t.Fatal(err) + } + ln := newLocalListener(t) + defer ln.Close() + + go func() { + tln := tls.NewListener(ln, &tls.Config{ + NextProtos: []string{"foo"}, + Certificates: []tls.Certificate{cert}, + }) + for i := 0; i < 2; i++ { + sc, err := tln.Accept() + if err != nil { + t.Error(err) + return + } + if err := sc.(*tls.Conn).Handshake(); err != nil { + t.Error(err) + return + } + sc.Close() + } + }() + + addr := ln.Addr().String() + req, _ := http.NewRequest("POST", "https://example.org/", bytes.NewBufferString("request")) + roundTripped := false + tr := &Transport{ + DisableKeepAlives: true, + TLSNextProto: map[string]func(string, *tls.Conn) http.RoundTripper{ + "foo": func(authority string, c *tls.Conn) http.RoundTripper { + return roundTripFunc(func(r *http.Request) (*http.Response, error) { + n, _ := io.Copy(io.Discard, r.Body) + if n == 0 { + t.Error("body length is zero") + } + if roundTripped { + return &http.Response{ + Body: NoBody, + StatusCode: 200, + }, nil + } + roundTripped = true + return nil, http2noCachedConnError{} + }) + }, + }, + DialTLS: func(_, _ string) (net.Conn, error) { + tc, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"foo"}, + }) + if err != nil { + return nil, err + } + if err := tc.Handshake(); err != nil { + return nil, err + } + return tc, nil + }, + } + c := &http.Client{Transport: tr} + _, err = c.Do(req) + if err != nil { + t.Error(err) + } +} From d70c95147316fd8acf67f6af7ba0f56cd7d49494 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 28 Jan 2022 20:17:17 +0800 Subject: [PATCH 181/843] add transfer_test.go --- transfer_test.go | 364 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 364 insertions(+) create mode 100644 transfer_test.go diff --git a/transfer_test.go b/transfer_test.go new file mode 100644 index 00000000..056d920d --- /dev/null +++ b/transfer_test.go @@ -0,0 +1,364 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "bufio" + "bytes" + "crypto/rand" + "fmt" + "io" + "net/http" + "os" + "reflect" + "strings" + "testing" +) + +func TestBodyReadBadTrailer(t *testing.T) { + b := &body{ + src: strings.NewReader("foobar"), + hdr: true, // force reading the trailer + r: bufio.NewReader(strings.NewReader("")), + } + buf := make([]byte, 7) + n, err := b.Read(buf[:3]) + got := string(buf[:n]) + if got != "foo" || err != nil { + t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err) + } + + n, err = b.Read(buf[:]) + got = string(buf[:n]) + if got != "bar" || err != nil { + t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err) + } + + n, err = b.Read(buf[:]) + got = string(buf[:n]) + if err == nil { + t.Errorf("final Read was successful (%q), expected error from trailer read", got) + } +} + +func TestFinalChunkedBodyReadEOF(t *testing.T) { + res, err := http.ReadResponse(bufio.NewReader(strings.NewReader( + "HTTP/1.1 200 OK\r\n"+ + "Transfer-Encoding: chunked\r\n"+ + "\r\n"+ + "0a\r\n"+ + "Body here\n\r\n"+ + "09\r\n"+ + "continued\r\n"+ + "0\r\n"+ + "\r\n")), nil) + if err != nil { + t.Fatal(err) + } + want := "Body here\ncontinued" + buf := make([]byte, len(want)) + n, err := res.Body.Read(buf) + if n != len(want) || err != io.EOF { + t.Logf("body = %#v", res.Body) + t.Errorf("Read = %v, %v; want %d, EOF", n, err, len(want)) + } + if string(buf) != want { + t.Errorf("buf = %q; want %q", buf, want) + } +} + +func TestDetectInMemoryReaders(t *testing.T) { + pr, _ := io.Pipe() + tests := []struct { + r io.Reader + want bool + }{ + {pr, false}, + + {bytes.NewReader(nil), true}, + {bytes.NewBuffer(nil), true}, + {strings.NewReader(""), true}, + + {io.NopCloser(pr), false}, + + {io.NopCloser(bytes.NewReader(nil)), true}, + {io.NopCloser(bytes.NewBuffer(nil)), true}, + {io.NopCloser(strings.NewReader("")), true}, + } + for i, tt := range tests { + got := isKnownInMemoryReader(tt.r) + if got != tt.want { + t.Errorf("%d: got = %v; want %v", i, got, tt.want) + } + } +} + +type mockTransferWriter struct { + CalledReader io.Reader + WriteCalled bool +} + +var _ io.ReaderFrom = (*mockTransferWriter)(nil) + +func (w *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) { + w.CalledReader = r + return io.Copy(io.Discard, r) +} + +func (w *mockTransferWriter) Write(p []byte) (int, error) { + w.WriteCalled = true + return io.Discard.Write(p) +} + +func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { + fileType := reflect.TypeOf(&os.File{}) + bufferType := reflect.TypeOf(&bytes.Buffer{}) + + nBytes := int64(1 << 10) + newFileFunc := func() (r io.Reader, done func(), err error) { + f, err := os.CreateTemp("", "net-http-newfilefunc") + if err != nil { + return nil, nil, err + } + + // Write some bytes to the file to enable reading. + if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { + return nil, nil, fmt.Errorf("failed to write data to file: %v", err) + } + if _, err := f.Seek(0, 0); err != nil { + return nil, nil, fmt.Errorf("failed to seek to front: %v", err) + } + + done = func() { + f.Close() + os.Remove(f.Name()) + } + + return f, done, nil + } + + newBufferFunc := func() (io.Reader, func(), error) { + return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil + } + + cases := []struct { + name string + bodyFunc func() (io.Reader, func(), error) + method string + contentLength int64 + transferEncoding []string + limitedReader bool + expectedReader reflect.Type + expectedWrite bool + }{ + { + name: "file, non-chunked, size set", + bodyFunc: newFileFunc, + method: "PUT", + contentLength: nBytes, + limitedReader: true, + expectedReader: fileType, + }, + { + name: "file, non-chunked, size set, nopCloser wrapped", + method: "PUT", + bodyFunc: func() (io.Reader, func(), error) { + r, cleanup, err := newFileFunc() + return io.NopCloser(r), cleanup, err + }, + contentLength: nBytes, + limitedReader: true, + expectedReader: fileType, + }, + { + name: "file, non-chunked, negative size", + method: "PUT", + bodyFunc: newFileFunc, + contentLength: -1, + expectedReader: fileType, + }, + { + name: "file, non-chunked, CONNECT, negative size", + method: "CONNECT", + bodyFunc: newFileFunc, + contentLength: -1, + expectedReader: fileType, + }, + { + name: "file, chunked", + method: "PUT", + bodyFunc: newFileFunc, + transferEncoding: []string{"chunked"}, + expectedWrite: true, + }, + { + name: "buffer, non-chunked, size set", + bodyFunc: newBufferFunc, + method: "PUT", + contentLength: nBytes, + limitedReader: true, + expectedReader: bufferType, + }, + { + name: "buffer, non-chunked, size set, nopCloser wrapped", + method: "PUT", + bodyFunc: func() (io.Reader, func(), error) { + r, cleanup, err := newBufferFunc() + return io.NopCloser(r), cleanup, err + }, + contentLength: nBytes, + limitedReader: true, + expectedReader: bufferType, + }, + { + name: "buffer, non-chunked, negative size", + method: "PUT", + bodyFunc: newBufferFunc, + contentLength: -1, + expectedWrite: true, + }, + { + name: "buffer, non-chunked, CONNECT, negative size", + method: "CONNECT", + bodyFunc: newBufferFunc, + contentLength: -1, + expectedWrite: true, + }, + { + name: "buffer, chunked", + method: "PUT", + bodyFunc: newBufferFunc, + transferEncoding: []string{"chunked"}, + expectedWrite: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + body, cleanup, err := tc.bodyFunc() + if err != nil { + t.Fatal(err) + } + defer cleanup() + + mw := &mockTransferWriter{} + tw := &transferWriter{ + Body: body, + ContentLength: tc.contentLength, + TransferEncoding: tc.transferEncoding, + } + + if err := tw.writeBody(mw, nil); err != nil { + t.Fatal(err) + } + + if tc.expectedReader != nil { + if mw.CalledReader == nil { + t.Fatal("did not call ReadFrom") + } + + var actualReader reflect.Type + lr, ok := mw.CalledReader.(*io.LimitedReader) + if ok && tc.limitedReader { + actualReader = reflect.TypeOf(lr.R) + } else { + actualReader = reflect.TypeOf(mw.CalledReader) + } + + if tc.expectedReader != actualReader { + t.Fatalf("got reader %T want %T", actualReader, tc.expectedReader) + } + } + + if tc.expectedWrite && !mw.WriteCalled { + t.Fatal("did not invoke Write") + } + }) + } +} + +func TestParseTransferEncoding(t *testing.T) { + tests := []struct { + hdr http.Header + wantErr error + }{ + { + hdr: http.Header{"Transfer-Encoding": {"fugazi"}}, + wantErr: &unsupportedTEError{`unsupported transfer encoding: "fugazi"`}, + }, + { + hdr: http.Header{"Transfer-Encoding": {"chunked, chunked", "identity", "chunked"}}, + wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked, chunked" "identity" "chunked"]`}, + }, + { + hdr: http.Header{"Transfer-Encoding": {""}}, + wantErr: &unsupportedTEError{`unsupported transfer encoding: ""`}, + }, + { + hdr: http.Header{"Transfer-Encoding": {"chunked, identity"}}, + wantErr: &unsupportedTEError{`unsupported transfer encoding: "chunked, identity"`}, + }, + { + hdr: http.Header{"Transfer-Encoding": {"chunked", "identity"}}, + wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked" "identity"]`}, + }, + { + hdr: http.Header{"Transfer-Encoding": {"\x0bchunked"}}, + wantErr: &unsupportedTEError{`unsupported transfer encoding: "\vchunked"`}, + }, + { + hdr: http.Header{"Transfer-Encoding": {"chunked"}}, + wantErr: nil, + }, + } + + for i, tt := range tests { + tr := &transferReader{ + Header: tt.hdr, + ProtoMajor: 1, + ProtoMinor: 1, + } + gotErr := tr.parseTransferEncoding() + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("%d.\ngot error:\n%v\nwant error:\n%v\n\n", i, gotErr, tt.wantErr) + } + } +} + +// issue 39017 - disallow Content-Length values such as "+3" +func TestParseContentLength(t *testing.T) { + tests := []struct { + cl string + wantErr error + }{ + { + cl: "3", + wantErr: nil, + }, + { + cl: "+3", + wantErr: badStringError("bad Content-Length", "+3"), + }, + { + cl: "-3", + wantErr: badStringError("bad Content-Length", "-3"), + }, + { + // max int64, for safe conversion before returning + cl: "9223372036854775807", + wantErr: nil, + }, + { + cl: "9223372036854775808", + wantErr: badStringError("bad Content-Length", "9223372036854775808"), + }, + } + + for _, tt := range tests { + if _, gotErr := parseContentLength(tt.cl); !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("%q:\n\tgot=%v\n\twant=%v", tt.cl, gotErr, tt.wantErr) + } + } +} From f189f9ef44289f2737c579ac427854c6cd4b912e Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 28 Jan 2022 21:12:01 +0800 Subject: [PATCH 182/843] remove unused func in textproto_reader.go --- textproto_reader.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/textproto_reader.go b/textproto_reader.go index d31c14db..ae2b4711 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -26,10 +26,6 @@ func (e *codeError) Error() string { return fmt.Sprintf("%03d %s", e.Code, e.Msg) } -func isASCIISpace(b byte) bool { - return b == ' ' || b == '\t' || b == '\n' || b == '\r' -} - func isASCIILetter(b byte) bool { b |= 0x20 // make lower case return 'a' <= b && b <= 'z' From 2d2c603622a5cb79656bf0dba69ae82daf9c7d74 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 28 Jan 2022 21:54:14 +0800 Subject: [PATCH 183/843] ajust code position --- decode.go | 16 ++++++++++ response.go | 78 ++++++++++++++++++++++++++++++++++++++---------- response_body.go | 69 ------------------------------------------ 3 files changed, 79 insertions(+), 84 deletions(-) delete mode 100644 response_body.go diff --git a/decode.go b/decode.go index 33f68c2f..5205f83d 100644 --- a/decode.go +++ b/decode.go @@ -3,8 +3,24 @@ package req import ( "github.com/imroc/req/v2/internal/charsetutil" "io" + "strings" ) +var textContentTypes = []string{"text", "json", "xml", "html", "java"} + +var autoDecodeText = autoDecodeContentTypeFunc(textContentTypes...) + +func autoDecodeContentTypeFunc(contentTypes ...string) func(contentType string) bool { + return func(contentType string) bool { + for _, ct := range contentTypes { + if strings.Contains(contentType, ct) { + return true + } + } + return false + } +} + type decodeReaderCloser struct { io.ReadCloser decodeReader io.Reader diff --git a/response.go b/response.go index b146d2f3..068f20b6 100644 --- a/response.go +++ b/response.go @@ -1,26 +1,12 @@ package req import ( + "io/ioutil" "net/http" "strings" "time" ) -var textContentTypes = []string{"text", "json", "xml", "html", "java"} - -var autoDecodeText = autoDecodeContentTypeFunc(textContentTypes...) - -func autoDecodeContentTypeFunc(contentTypes ...string) func(contentType string) bool { - return func(contentType string) bool { - for _, ct := range contentTypes { - if strings.Contains(contentType, ct) { - return true - } - } - return false - } -} - // Response is the http response. type Response struct { *http.Response @@ -76,3 +62,65 @@ func (r *Response) setReceivedAt() { r.Request.trace.endTime = r.receivedAt } } +func (r *Response) UnmarshalJson(v interface{}) error { + b, err := r.ToBytes() + if err != nil { + return err + } + return r.Request.client.jsonUnmarshal(b, v) +} + +func (r *Response) UnmarshalXml(v interface{}) error { + b, err := r.ToBytes() + if err != nil { + return err + } + return r.Request.client.xmlUnmarshal(b, v) +} + +func (r *Response) Unmarshal(v interface{}) error { + contentType := r.Header.Get("Content-Type") + if strings.Contains(contentType, "json") { + return r.UnmarshalJson(v) + } else if strings.Contains(contentType, "xml") { + return r.UnmarshalXml(v) + } + return r.UnmarshalJson(v) +} + +// Bytes return the response body as []bytes that hava already been read, could be +// nil if not read, the following cases are already read: +// 1. `Request.SetResult` or `Request.SetError` is called. +// 2. `Client.DisableAutoReadResponse(false)` is not called, +// also `Request.SetOutput` and `Request.SetOutputFile` is not called. +func (r *Response) Bytes() []byte { + return r.body +} + +// String return the response body as string that hava already been read, could be +// nil if not read, the following cases are already read: +// 1. `Request.SetResult` or `Request.SetError` is called. +// 2. `Client.DisableAutoReadResponse(false)` is not called, +// also `Request.SetOutput` and `Request.SetOutputFile` is not called. +func (r *Response) String() string { + return string(r.body) +} + +func (r *Response) ToString() (string, error) { + b, err := r.ToBytes() + return string(b), err +} + +func (r *Response) ToBytes() ([]byte, error) { + if r.body != nil { + return r.body, nil + } + defer r.Body.Close() + body, err := ioutil.ReadAll(r.Body) + r.setReceivedAt() + if err != nil { + return nil, err + } + r.body = body + return body, nil +} diff --git a/response_body.go b/response_body.go deleted file mode 100644 index 9c60e9cd..00000000 --- a/response_body.go +++ /dev/null @@ -1,69 +0,0 @@ -package req - -import ( - "io/ioutil" - "strings" -) - -func (r *Response) UnmarshalJson(v interface{}) error { - b, err := r.ToBytes() - if err != nil { - return err - } - return r.Request.client.jsonUnmarshal(b, v) -} - -func (r *Response) UnmarshalXml(v interface{}) error { - b, err := r.ToBytes() - if err != nil { - return err - } - return r.Request.client.xmlUnmarshal(b, v) -} - -func (r *Response) Unmarshal(v interface{}) error { - contentType := r.Header.Get("Content-Type") - if strings.Contains(contentType, "json") { - return r.UnmarshalJson(v) - } else if strings.Contains(contentType, "xml") { - return r.UnmarshalXml(v) - } - return r.UnmarshalJson(v) -} - -// Bytes return the response body as []bytes that hava already been read, could be -// nil if not read, the following cases are already read: -// 1. `Request.SetResult` or `Request.SetError` is called. -// 2. `Client.DisableAutoReadResponse(false)` is not called, -// also `Request.SetOutput` and `Request.SetOutputFile` is not called. -func (r *Response) Bytes() []byte { - return r.body -} - -// String return the response body as string that hava already been read, could be -// nil if not read, the following cases are already read: -// 1. `Request.SetResult` or `Request.SetError` is called. -// 2. `Client.DisableAutoReadResponse(false)` is not called, -// also `Request.SetOutput` and `Request.SetOutputFile` is not called. -func (r *Response) String() string { - return string(r.body) -} - -func (r *Response) ToString() (string, error) { - b, err := r.ToBytes() - return string(b), err -} - -func (r *Response) ToBytes() ([]byte, error) { - if r.body != nil { - return r.body, nil - } - defer r.Body.Close() - body, err := ioutil.ReadAll(r.Body) - r.setReceivedAt() - if err != nil { - return nil, err - } - r.body = body - return body, nil -} From 1d3ab3c9e568004bafd9029e4ca76250339e62aa Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 29 Jan 2022 17:52:49 +0800 Subject: [PATCH 184/843] add comments on response --- response.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/response.go b/response.go index 068f20b6..c10eac86 100644 --- a/response.go +++ b/response.go @@ -45,6 +45,7 @@ func (r *Response) TraceInfo() TraceInfo { return r.Request.TraceInfo() } +// TotalTime returns the total time of the request, from request we sent to response we received. func (r *Response) TotalTime() time.Duration { if r.Request.trace != nil { return r.Request.TraceInfo().TotalTime @@ -52,6 +53,7 @@ func (r *Response) TotalTime() time.Duration { return r.receivedAt.Sub(r.Request.StartTime) } +// ReceivedAt returns the timestamp that response we received. func (r *Response) ReceivedAt() time.Time { return r.receivedAt } @@ -62,6 +64,8 @@ func (r *Response) setReceivedAt() { r.Request.trace.endTime = r.receivedAt } } + +// UnmarshalJson unmarshals JSON response body into the specified object. func (r *Response) UnmarshalJson(v interface{}) error { b, err := r.ToBytes() if err != nil { @@ -70,6 +74,7 @@ func (r *Response) UnmarshalJson(v interface{}) error { return r.Request.client.jsonUnmarshal(b, v) } +// UnmarshalXml unmarshals XML response body into the specified object. func (r *Response) UnmarshalXml(v interface{}) error { b, err := r.ToBytes() if err != nil { @@ -78,6 +83,8 @@ func (r *Response) UnmarshalXml(v interface{}) error { return r.Request.client.xmlUnmarshal(b, v) } +// Unmarshal unmarshals response body into the specified object according +// to response `Content-Type`. func (r *Response) Unmarshal(v interface{}) error { contentType := r.Header.Get("Content-Type") if strings.Contains(contentType, "json") { @@ -97,7 +104,7 @@ func (r *Response) Bytes() []byte { return r.body } -// String return the response body as string that hava already been read, could be +// String returns the response body as string that hava already been read, could be // nil if not read, the following cases are already read: // 1. `Request.SetResult` or `Request.SetError` is called. // 2. `Client.DisableAutoReadResponse(false)` is not called, @@ -106,11 +113,13 @@ func (r *Response) String() string { return string(r.body) } +// ToString returns the response body as string, read body if not have been read. func (r *Response) ToString() (string, error) { b, err := r.ToBytes() return string(b), err } +// ToBytes returns the response body as []byte, read body if not have been read. func (r *Response) ToBytes() ([]byte, error) { if r.body != nil { return r.body, nil From 4c18d7e0a01dffdb9fbcac0b13de03126ffbd106 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 29 Jan 2022 18:25:23 +0800 Subject: [PATCH 185/843] fix content-type missing when SetFormData --- README.md | 2 ++ client.go | 2 +- middleware.go | 3 ++- common.go => req.go | 0 req_test.go | 0 5 files changed, 5 insertions(+), 2 deletions(-) rename common.go => req.go (100%) create mode 100644 req_test.go diff --git a/README.md b/README.md index c12f3263..47bdc8f0 100644 --- a/README.md +++ b/README.md @@ -306,6 +306,7 @@ client.R().SetFormData(map[string]string{ :method: POST :path: /post :scheme: https +content-type: application/x-www-form-urlencoded accept-encoding: gzip user-agent: req/v2 (https://github.com/imroc/req) @@ -322,6 +323,7 @@ client.R().SetFormDataFromValues(v).Post("https://httpbin.org/post") :method: POST :path: /post :scheme: https +content-type: application/x-www-form-urlencoded accept-encoding: gzip user-agent: req/v2 (https://github.com/imroc/req) diff --git a/client.go b/client.go index 08bdb58f..1fe0460a 100644 --- a/client.go +++ b/client.go @@ -1086,9 +1086,9 @@ func C() *Client { } beforeRequest := []RequestMiddleware{ parseRequestURL, + parseRequestBody, parseRequestHeader, parseRequestCookie, - parseRequestBody, } afterResponse := []ResponseMiddleware{ parseResponseBody, diff --git a/middleware.go b/middleware.go index 71c41745..8c8eb5e7 100644 --- a/middleware.go +++ b/middleware.go @@ -92,12 +92,13 @@ func handleMultiPart(c *Client, r *Request) (err error) { pr, pw := io.Pipe() r.RawRequest.Body = pr w := multipart.NewWriter(pw) - r.RawRequest.Header.Set(hdrContentTypeKey, w.FormDataContentType()) + r.SetContentType(w.FormDataContentType()) go writeMultiPart(r, w, pw) return } func handleFormData(r *Request) { + r.SetContentType(formContentType) r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(r.FormData.Encode())) } diff --git a/common.go b/req.go similarity index 100% rename from common.go rename to req.go diff --git a/req_test.go b/req_test.go new file mode 100644 index 00000000..e69de29b From 0bd933e89cc1bf50f4159c21391ff677fe0686e5 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 29 Jan 2022 20:30:11 +0800 Subject: [PATCH 186/843] fix #83 --- req_test.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/req_test.go b/req_test.go index e69de29b..a0563ed2 100644 --- a/req_test.go +++ b/req_test.go @@ -0,0 +1,53 @@ +package req + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func getTestDataPath() string { + pwd, _ := os.Getwd() + return filepath.Join(pwd, ".testdata") +} + +func createTestServer(fn func(w http.ResponseWriter, r *http.Request)) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(fn)) +} + +func createGetServer(t *testing.T) *httptest.Server { + ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { + t.Logf("Method: %v", r.Method) + t.Logf("Path: %v", r.URL.Path) + + if r.Method == http.MethodGet { + switch r.URL.Path { + case "/": + _, _ = w.Write([]byte("TestGet: text response")) + case "/no-content": + _, _ = w.Write([]byte("")) + case "/json": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"TestGet": "JSON response"}`)) + case "/json-invalid": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("TestGet: Invalid JSON")) + case "/long-text": + _, _ = w.Write([]byte("TestGet: text response with size > 30")) + case "/long-json": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"TestGet": "JSON response with size > 30"}`)) + case "/mypage": + w.WriteHeader(http.StatusBadRequest) + case "/mypage2": + _, _ = w.Write([]byte("TestGet: text response from mypage2")) + case "/host-header": + _, _ = w.Write([]byte(r.Host)) + } + } + }) + + return ts +} From 79241cf32c2ae56a60aeda1aeb97fcf21700862f Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 30 Jan 2022 18:34:27 +0800 Subject: [PATCH 187/843] support lower go version --- textproto_reader.go | 3 ++- transport.go | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/textproto_reader.go b/textproto_reader.go index ae2b4711..293830d0 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -10,6 +10,7 @@ import ( "fmt" "github.com/imroc/req/v2/internal/util" "io" + "io/ioutil" "net/textproto" "strconv" "strings" @@ -476,7 +477,7 @@ func (r *textprotoReader) closeDot() { // // See the documentation for the DotReader method for details about dot-encoding. func (r *textprotoReader) ReadDotBytes() ([]byte, error) { - return io.ReadAll(r.DotReader()) + return ioutil.ReadAll(r.DotReader()) } // ReadDotLines reads a dot-encoding and returns a slice diff --git a/transport.go b/transport.go index a01a3ce6..f706f272 100644 --- a/transport.go +++ b/transport.go @@ -1594,7 +1594,7 @@ func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptr if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - err := tlsConn.HandshakeContext(ctx) + err := tlsConn.Handshake() if timer != nil { timer.Stop() } @@ -1650,7 +1650,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - if err := tc.HandshakeContext(ctx); err != nil { + if err := tc.Handshake(); err != nil { go pconn.conn.Close() if trace != nil && trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(tls.ConnectionState{}, err) From 91c7906e350d00b5af56db4ddb11bb38c29bd1e2 Mon Sep 17 00:00:00 2001 From: James Kokou GAGLO Date: Sun, 30 Jan 2022 17:42:59 +0100 Subject: [PATCH 188/843] Update SetCommonBearerToken to SetCommonBearerAuthToken in the wiki --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 47bdc8f0..31614417 100644 --- a/README.md +++ b/README.md @@ -619,7 +619,7 @@ client := req.C() client.SetCommonBasicAuth("imroc", "123456") // Set bearer token for all request -client.SetCommonBearerToken("MDc0ZTg5YmU4Yzc5MjAzZGJjM2ZiMzkz") +client.SetCommonBearerAuthToken("MDc0ZTg5YmU4Yzc5MjAzZGJjM2ZiMzkz") // Set basic auth for a request, will override client's basic auth setting. client.R().SetBasicAuth("myusername", "mypassword").Get("https://api.example.com/profile") @@ -787,4 +787,4 @@ client.SetProxy(nil) ## License -`Req` released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file +`Req` released under MIT license, refer [LICENSE](LICENSE) file. From 7353d964509104ee9bba2fb5ba0801e86e214d28 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jan 2022 09:49:59 +0800 Subject: [PATCH 189/843] fix v2 build error when go version < 1.17 #84 --- roundtrip.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/roundtrip.go b/roundtrip.go index a8a01996..c2e9696a 100644 --- a/roundtrip.go +++ b/roundtrip.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js || !wasm - package req import "net/http" From 960df605422a876ff76d79fb0326b3f70a1dc0bb Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jan 2022 09:56:19 +0800 Subject: [PATCH 190/843] fix all build tag --- h2_bundle.go | 1 - roundtrip.go | 2 ++ transport_default_js.go | 1 - transport_default_other.go | 1 - 4 files changed, 2 insertions(+), 3 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index 698d13ef..58d38ade 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -1,4 +1,3 @@ -//go:build !nethttpomithttp2 // +build !nethttpomithttp2 // Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. diff --git a/roundtrip.go b/roundtrip.go index c2e9696a..fe9c8e09 100644 --- a/roundtrip.go +++ b/roundtrip.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//+build !js !wasm + package req import "net/http" diff --git a/transport_default_js.go b/transport_default_js.go index 7cd8e335..af5df819 100644 --- a/transport_default_js.go +++ b/transport_default_js.go @@ -2,7 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build js && wasm // +build js,wasm package req diff --git a/transport_default_other.go b/transport_default_other.go index 8191ea79..7f6b27f6 100644 --- a/transport_default_other.go +++ b/transport_default_other.go @@ -2,7 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !(js && wasm) // +build !js !wasm package req From ca1710ebfe2382b9dec8a1c72a44096612dbdd7b Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jan 2022 18:08:53 +0800 Subject: [PATCH 191/843] go mod tidy; add request_test.go --- go.mod | 1 - go.sum | 48 --------------------------- req_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++---- request_test.go | 33 +++++++++++++++++++ 4 files changed, 115 insertions(+), 55 deletions(-) create mode 100644 request_test.go diff --git a/go.mod b/go.mod index b27ffece..8c7b6bc0 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/imroc/req/v2 go 1.13 require ( - github.com/gin-gonic/gin v1.7.7 // indirect github.com/hashicorp/go-multierror v1.1.1 golang.org/x/net v0.0.0-20220111093109-d55c255bac03 golang.org/x/text v0.3.7 diff --git a/go.sum b/go.sum index 80bdb59e..59f8d4e3 100644 --- a/go.sum +++ b/go.sum @@ -1,61 +1,13 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= -github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= -github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= -github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= -github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= -github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= -github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= -github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= -github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= -github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= -github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/req_test.go b/req_test.go index a0563ed2..a927a0d8 100644 --- a/req_test.go +++ b/req_test.go @@ -5,9 +5,14 @@ import ( "net/http/httptest" "os" "path/filepath" + "reflect" "testing" ) +func tc() *Client { + return C().EnableDebugLog(true) +} + func getTestDataPath() string { pwd, _ := os.Getwd() return filepath.Join(pwd, ".testdata") @@ -19,9 +24,6 @@ func createTestServer(fn func(w http.ResponseWriter, r *http.Request)) *httptest func createGetServer(t *testing.T) *httptest.Server { ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { - t.Logf("Method: %v", r.Method) - t.Logf("Path: %v", r.URL.Path) - if r.Method == http.MethodGet { switch r.URL.Path { case "/": @@ -39,10 +41,8 @@ func createGetServer(t *testing.T) *httptest.Server { case "/long-json": w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"TestGet": "JSON response with size > 30"}`)) - case "/mypage": + case "/bad-request": w.WriteHeader(http.StatusBadRequest) - case "/mypage2": - _, _ = w.Write([]byte("TestGet: text response from mypage2")) case "/host-header": _, _ = w.Write([]byte(r.Host)) } @@ -51,3 +51,79 @@ func createGetServer(t *testing.T) *httptest.Server { return ts } + +func assertStatus(t *testing.T, resp *Response, err error, statusCode int, status string) { + assertError(t, err) + assertNotNil(t, resp) + assertNotNil(t, resp.Body) + assertEqual(t, statusCode, resp.StatusCode) + assertEqual(t, status, resp.Status) +} + +func assertResponse(t *testing.T, resp *Response, err error) { + assertError(t, err) + assertNotNil(t, resp) + assertNotNil(t, resp.Body) + assertEqual(t, http.StatusOK, resp.StatusCode) + assertEqual(t, "200 OK", resp.Status) + assertEqual(t, "HTTP/1.1", resp.Proto) +} + +func assertNil(t *testing.T, v interface{}) { + if !isNil(v) { + t.Errorf("[%v] was expected to be nil", v) + } +} + +func assertNotNil(t *testing.T, v interface{}) { + if isNil(v) { + t.Errorf("[%v] was expected to be non-nil", v) + } +} + +func assertType(t *testing.T, typ, v interface{}) { + if reflect.DeepEqual(reflect.TypeOf(typ), reflect.TypeOf(v)) { + t.Errorf("Expected type %t, got %t", typ, v) + } +} + +func assertError(t *testing.T, err error) { + if err != nil { + t.Errorf("Error occurred [%v]", err) + } +} + +func assertEqual(t *testing.T, e, g interface{}) (r bool) { + if !equal(e, g) { + t.Errorf("Expected [%v], got [%v]", e, g) + } + + return +} + +func assertNotEqual(t *testing.T, e, g interface{}) (r bool) { + if equal(e, g) { + t.Errorf("Expected [%v], got [%v]", e, g) + } else { + r = true + } + + return +} + +func equal(expected, got interface{}) bool { + return reflect.DeepEqual(expected, got) +} + +func isNil(v interface{}) bool { + if v == nil { + return true + } + + rv := reflect.ValueOf(v) + kind := rv.Kind() + if kind >= reflect.Chan && kind <= reflect.Slice && rv.IsNil() { + return true + } + return false +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 00000000..47e6c3c4 --- /dev/null +++ b/request_test.go @@ -0,0 +1,33 @@ +package req + +import ( + "net/http" + "testing" +) + +func TestGet(t *testing.T) { + ts := createGetServer(t) + defer ts.Close() + + c := tc() + resp, err := c.R().Get(ts.URL) + assertResponse(t, resp, err) + assertEqual(t, "TestGet: text response", resp.String()) + + resp, err = c.R().Get(ts.URL + "/no-content") + assertResponse(t, resp, err) + assertEqual(t, "", resp.String()) + + resp, err = c.R().Get(ts.URL + "/json") + assertResponse(t, resp, err) + assertEqual(t, `{"TestGet": "JSON response"}`, resp.String()) + assertEqual(t, resp.GetContentType(), "application/json") + + resp, err = c.R().Get(ts.URL + "/json-invalid") + assertResponse(t, resp, err) + assertEqual(t, `TestGet: Invalid JSON`, resp.String()) + assertEqual(t, resp.GetContentType(), "application/json") + + resp, err = c.R().Get(ts.URL + "/bad-request") + assertStatus(t, resp, err, http.StatusBadRequest, "400 Bad Request") +} From e834d36a362d61e094885ac6846f25baa59a32b2 Mon Sep 17 00:00:00 2001 From: Brian Leishman Date: Mon, 31 Jan 2022 16:30:59 -0500 Subject: [PATCH 192/843] Request AddQueryParam func A very similar function to SetQueryParam, but allowing multiple query params of the same name to be set. Use case currently, before this change, using the HubSpot API ```go r := client.R() r.QueryParams.Add("properties", "budget") r.QueryParams.Add("properties", "quantity") r.QueryParams.Add("properties", "project_name") r.QueryParams.Add("properties", "smg_company") r.QueryParams.Add("properties", "date_needed") r.QueryParams.Add("properties", "file_uploads") r.QueryParams.Add("properties", "description") r.QueryParams.Add("properties", "project_size") r.QueryParams.Add("properties", "product_type") r.QueryParams.Add("associations", "contacts") r.QueryParams.Add("associations", "companies") r, err := hubspot.Get(c, r, fmt.Sprintf("/crm/v3/objects/deals/%d", e.Event.ObjectID)) ``` Which can't be chained like the `SetQueryParam`, but could look like ```go resp, err := hubspot.Get(c, c.Client.R(). AddQueryParam("properties", "budget"). AddQueryParam("properties", "quantity"). AddQueryParam("properties", "project_name"). AddQueryParam("properties", "smg_company"). AddQueryParam("properties", "date_needed"). AddQueryParam("properties", "file_uploads"). AddQueryParam("properties", "description"). AddQueryParam("properties", "project_size"). AddQueryParam("properties", "product_type"). AddQueryParam("associations", "contacts"). AddQueryParam("associations", "companies"), fmt.Sprintf("/crm/v3/objects/deals/%d", e.Event.ObjectID)) ``` Which IMO is cleaner --- request.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/request.go b/request.go index ff89022a..b6c2c9f5 100644 --- a/request.go +++ b/request.go @@ -345,6 +345,16 @@ func (r *Request) SetQueryParam(key, value string) *Request { return r } +// AddQueryParam add an URL query parameter with a key-value +// pair at request level. +func (r *Request) AddQueryParam(key, value string) *Request { + if r.QueryParams == nil { + r.QueryParams = make(urlpkg.Values) + } + r.QueryParams.Add(key, value) + return r +} + // SetPathParams is a global wrapper methods which delegated // to the default client, create a request and SetPathParams for request. func SetPathParams(params map[string]string) *Request { From 844d7845df34387643e895b17a236f9c3670d22c Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Feb 2022 09:59:02 +0800 Subject: [PATCH 193/843] AddQueryParam both for client and request --- client.go | 18 +++++++++++++++++- request.go | 8 +++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 1fe0460a..54334951 100644 --- a/client.go +++ b/client.go @@ -385,13 +385,29 @@ func (c *Client) SetCommonQueryParams(params map[string]string) *Client { return c } +// AddCommonQueryParam is a global wrapper methods which delegated +// to the default client's AddCommonQueryParam. +func AddCommonQueryParam(key, value string) *Client { + return defaultClient.AddCommonQueryParam(key, value) +} + +// AddCommonQueryParam add a URL query parameter with a key-value +// pair at client level +func (c *Client) AddCommonQueryParam(key, value string) *Client { + if c.QueryParams == nil { + c.QueryParams = make(urlpkg.Values) + } + c.QueryParams.Add(key, value) + return c +} + // SetCommonQueryParam is a global wrapper methods which delegated // to the default client's SetCommonQueryParam. func SetCommonQueryParam(key, value string) *Client { return defaultClient.SetCommonQueryParam(key, value) } -// SetCommonQueryParam set an URL query parameter with a key-value +// SetCommonQueryParam set a URL query parameter with a key-value // pair at client level. func (c *Client) SetCommonQueryParam(key, value string) *Client { if c.QueryParams == nil { diff --git a/request.go b/request.go index b6c2c9f5..fe6880a1 100644 --- a/request.go +++ b/request.go @@ -345,7 +345,13 @@ func (r *Request) SetQueryParam(key, value string) *Request { return r } -// AddQueryParam add an URL query parameter with a key-value +// AddQueryParam is a global wrapper methods which delegated +// to the default client, create a request and AddQueryParam for request. +func AddQueryParam(key, value string) *Request { + return defaultClient.R().AddQueryParam(key, value) +} + +// AddQueryParam add a URL query parameter with a key-value // pair at request level. func (r *Request) AddQueryParam(key, value string) *Request { if r.QueryParams == nil { From 1ff65e83f5c83e7ffc8147148c9c17670a26be04 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Feb 2022 11:04:42 +0800 Subject: [PATCH 194/843] update README: AddQueryParam --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 31614417..a5dd0acd 100644 --- a/README.md +++ b/README.md @@ -235,7 +235,7 @@ fmt.Println(req.DevMode().MustGet("https://imroc.cc").TraceInfo()) ## URL Path and Query Parameter -**Set Path Parameter** +**Path Parameter** Use `SetPathParam` or `SetPathParams` to replace variable in the url path: @@ -263,7 +263,7 @@ resp2, err := client.Get(url2) ... ``` -**Set Query Parameter** +**Query Parameter** Use `SetQueryParam`, `SetQueryParams` or `SetQueryString` to append url query parameter: @@ -291,6 +291,12 @@ resp1, err := client.Get(url1) ... resp2, err := client.Get(url2) ... + +// And you can add query parameter with multiple values +client.R().AddQueryParam("key", "value1").AddQueryParam("key", "value2").Get("https://httpbin.org/get") + +// Same as client level settings +client.AddCommonQueryParam("key", "value1").AddCommonQueryParam("key", "value2") ``` ## Form Data From 62c5744237c85c563d992e7b8b6e6180d1b8c0f3 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Feb 2022 11:08:18 +0800 Subject: [PATCH 195/843] update go mod for examples --- examples/find-popular-repo/go.mod | 2 +- examples/find-popular-repo/go.sum | 36 ++--------------------------- examples/upload/uploadclient/go.mod | 2 +- examples/upload/uploadclient/go.sum | 36 ++--------------------------- 4 files changed, 6 insertions(+), 70 deletions(-) diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index c0d66c81..8863c5d7 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -2,4 +2,4 @@ module find-popular-repo go 1.13 -require github.com/imroc/req/v2 v2.0.0 +require github.com/imroc/req/v2 v2.1.0 diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index aa42d120..21bcb4e1 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -1,47 +1,15 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= -github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.0.0 h1:KcRQ+V4e34cS/3sGn0KMtN9WOKqSXtCLSrZv6WnlU4k= -github.com/imroc/req/v2 v2.0.0/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= -github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +github.com/imroc/req/v2 v2.1.0 h1:zs14o2Pv/3RwAF11HBmjzJJ5ZItOgLk9yABTypbl8nk= +github.com/imroc/req/v2 v2.1.0/go.mod h1:3POMCRC7mUbCcscEp9wpihSyZLUVYWqvmHnwTdL6kJY= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod index f5b7a42a..d0bfc686 100644 --- a/examples/upload/uploadclient/go.mod +++ b/examples/upload/uploadclient/go.mod @@ -2,4 +2,4 @@ module uploadclient go 1.13 -require github.com/imroc/req/v2 v2.0.0 +require github.com/imroc/req/v2 v2.1.0 diff --git a/examples/upload/uploadclient/go.sum b/examples/upload/uploadclient/go.sum index aa42d120..21bcb4e1 100644 --- a/examples/upload/uploadclient/go.sum +++ b/examples/upload/uploadclient/go.sum @@ -1,47 +1,15 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= -github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.0.0 h1:KcRQ+V4e34cS/3sGn0KMtN9WOKqSXtCLSrZv6WnlU4k= -github.com/imroc/req/v2 v2.0.0/go.mod h1:Tn6STXYRyagrmRswbWYuiiSE3FBpnzUcqBHjtJ3+7gI= -github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +github.com/imroc/req/v2 v2.1.0 h1:zs14o2Pv/3RwAF11HBmjzJJ5ZItOgLk9yABTypbl8nk= +github.com/imroc/req/v2 v2.1.0/go.mod h1:3POMCRC7mUbCcscEp9wpihSyZLUVYWqvmHnwTdL6kJY= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From f0c1554bc2d14187ffe7689468205a961dd5157d Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Feb 2022 15:20:12 +0800 Subject: [PATCH 196/843] fix README: #Body --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a5dd0acd..28188b45 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ If you want to use the older version, check it out on [v1 branch](https://github * [URL Path and Query Parameter](#Param) * [Form Data](#Form) * [Header and Cookie](#Header-Cookie) -* [Body and Marshal/Unmarshal](#Header-Cookie) +* [Body and Marshal/Unmarshal](#Body) * [Custom Certificates](#Cert) * [Basic Auth and Bearer Token](#Auth) * [Download and Upload](#Download-Upload) From 6c450248ff6da3e926c16da4c7dccde9593f221a Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 4 Feb 2022 20:01:50 +0800 Subject: [PATCH 197/843] support dump single request --- README.md | 4 ++ client.go | 20 +++++---- dump.go | 8 ++++ h2_bundle.go | 112 ++++++++++++++++++++++++++--------------------- http_response.go | 6 ++- request.go | 81 ++++++++++++++++++++++++++++++++++ response.go | 6 +++ roundtrip.go | 2 +- transport.go | 21 ++++++--- 9 files changed, 194 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index 28188b45..da1f0854 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,10 @@ client.R().Get("https://www.baidu.com/") // Change settings dynamiclly opt.ResponseBody = false client.R().Get("https://www.baidu.com/") + +// Dump single request +resp, err := client.R().DumpAll().SetBody("test body").Post("https://httpbin.org/post") +fmt.Println(resp.Dump()) ``` **Enable DebugLog for Deeper Insights** diff --git a/client.go b/client.go index 54334951..27598f6f 100644 --- a/client.go +++ b/client.go @@ -1136,6 +1136,7 @@ func setupRequest(r *Request) { setRequestURL(r.RawRequest, r.URL) setRequestHeaderAndCookie(r) setTrace(r) + setContext(r) } func (c *Client) do(r *Request) (resp *Response, err error) { @@ -1182,16 +1183,19 @@ func (c *Client) do(r *Request) (resp *Response, err error) { return } +func setContext(r *Request) { + if r.ctx != nil { + r.RawRequest = r.RawRequest.WithContext(r.ctx) + } +} + func setTrace(r *Request) { - if r.trace == nil { - if r.client.trace { - r.trace = &clientTrace{} - } else { - return - } + if r.trace == nil && r.client.trace { + r.trace = &clientTrace{} + } + if r.trace != nil { + r.ctx = r.trace.createContext(r.Context()) } - r.ctx = r.trace.createContext(r.Context()) - r.RawRequest = r.RawRequest.WithContext(r.ctx) } func setRequestHeaderAndCookie(r *Request) { diff --git a/dump.go b/dump.go index 69d60ce9..192a48ee 100644 --- a/dump.go +++ b/dump.go @@ -1,6 +1,7 @@ package req import ( + "context" "io" "os" ) @@ -169,3 +170,10 @@ func (t *Transport) DisableDump() { t.dump = nil } } + +func getDumperOverride(dump *dumper, ctx context.Context) *dumper { + if d, ok := ctx.Value("dumper").(*dumper); ok { + return d + } + return dump +} diff --git a/h2_bundle.go b/h2_bundle.go index 58d38ade..182d711b 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -1592,7 +1592,7 @@ type http2Frame interface { // A Framer reads and writes Frames. type http2Framer struct { - WriteData func(streamID uint32, endStream bool, data []byte) error + cc *http2ClientConn dump *dumper r io.Reader lastFrame http2Frame @@ -1709,7 +1709,7 @@ func (f *http2Framer) endWrite() error { func (f *http2Framer) logWrite() { if f.debugFramer == nil { f.debugFramerBuf = new(bytes.Buffer) - f.debugFramer = http2NewFramer(nil, f.debugFramerBuf, f.dump) + f.debugFramer = http2NewFramer(nil, f.debugFramerBuf) f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below // Let us read anything, even if we accidentally wrote it // in the wrong order: @@ -1761,9 +1761,8 @@ func (fc *http2frameCache) getDataFrame() *http2DataFrame { } // NewFramer returns a Framer that writes frames to w and reads them from r. -func http2NewFramer(w io.Writer, r io.Reader, dump *dumper) *http2Framer { +func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { fr := &http2Framer{ - dump: dump, w: w, r: r, countError: func(string) {}, @@ -1780,14 +1779,6 @@ func http2NewFramer(w io.Writer, r io.Reader, dump *dumper) *http2Framer { return fr.readBuf } fr.SetMaxReadFrameSize(http2maxFrameSize) - if dump != nil && dump.RequestBody { - fr.WriteData = func(streamID uint32, endStream bool, data []byte) error { - dump.dump(data) - return fr.writeData(streamID, endStream, data) - } - } else { - fr.WriteData = fr.writeData - } return fr } @@ -1863,9 +1854,10 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) } if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { - hr, err := fr.readMetaFrame(f.(*http2HeadersFrame)) - if err == nil && fr.dump != nil && fr.dump.ResponseHeader { - fr.dump.dump([]byte("\r\n")) + dump := getDumperOverride(fr.cc.t.t1.dump, fr.cc.currentRequest.Context()) + hr, err := fr.readMetaFrame(f.(*http2HeadersFrame), dump) + if err == nil && dump != nil && dump.ResponseHeader { + dump.dump([]byte("\r\n")) } return hr, err } @@ -1995,7 +1987,7 @@ func http2validStreamID(streamID uint32) bool { // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility not to violate the maximum frame size // and to not call other Write methods concurrently. -func (f *http2Framer) writeData(streamID uint32, endStream bool, data []byte) error { +func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error { return f.WriteDataPadded(streamID, endStream, data, nil) } @@ -2860,7 +2852,7 @@ func (fr *http2Framer) maxHeaderStringLen() int { // readMetaFrame returns 0 or more CONTINUATION frames from fr and // merge them into the provided hf and returns a MetaHeadersFrame // with the decoded hpack values. -func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFrame, error) { +func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame, dump *dumper) (*http2MetaHeadersFrame, error) { if fr.AllowIllegalReads { return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders") } @@ -2909,9 +2901,9 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr mh.Fields = append(mh.Fields, hf) } emitFunc := rawEmitFunc - if fr.dump != nil && fr.dump.ResponseHeader { + if dump != nil && dump.ResponseHeader { emitFunc = func(hf hpack.HeaderField) { - fr.dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) + dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) rawEmitFunc(hf) } } @@ -6762,13 +6754,13 @@ func (t *http2Transport) initConnPool() { // ClientConn is the state of a single HTTP/2 client connection to an // HTTP/2 server. type http2ClientConn struct { - writeHeader func(name, value string) - t *http2Transport - tconn net.Conn // usually *tls.Conn, except specialized impls - tlsState *tls.ConnectionState // nil only for specialized impls - reused uint32 // whether conn is being reused; atomic - singleUse bool // whether being used for a single http.Request - getConnCalled bool // used by clientConnPool + currentRequest *http.Request + t *http2Transport + tconn net.Conn // usually *tls.Conn, except specialized impls + tlsState *tls.ConnectionState // nil only for specialized impls + reused uint32 // whether conn is being reused; atomic + singleUse bool // whether being used for a single http.Request + getConnCalled bool // used by clientConnPool // readLoop goroutine fields: readerDone chan struct{} // closed on error @@ -7181,14 +7173,6 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), } - if t.t1.dump != nil && t.t1.dump.RequestHeader { - cc.writeHeader = func(name, value string) { - t.t1.dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) - cc._writeHeader(name, value) - } - } else { - cc.writeHeader = cc._writeHeader - } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) @@ -7208,7 +7192,8 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client err: &cc.werr, }) cc.br = bufio.NewReader(c) - cc.fr = http2NewFramer(cc.bw, cc.br, cc.t.t1.dump) + cc.fr = http2NewFramer(cc.bw, cc.br) + cc.fr.cc = cc // for dump single request if t.CountError != nil { cc.fr.countError = t.CountError } @@ -7621,6 +7606,7 @@ func (cc *http2ClientConn) decrStreamReservationsLocked() { } func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + cc.currentRequest = req ctx := req.Context() cs := &http2clientStream{ cc: cc, @@ -7784,11 +7770,13 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { } } + dump := getDumperOverride(cs.cc.t.t1.dump, req.Context()) + // Past this point (where we send request headers), it is possible for // RoundTrip to return successfully. Since the RoundTrip contract permits // the caller to "mutate or reuse" the Request after closing the Response's Body, // we must take care when referencing the Request from here on. - err = cs.encodeAndWriteHeaders(req) + err = cs.encodeAndWriteHeaders(req, dump) <-cc.reqHeaderMu if err != nil { return err @@ -7820,14 +7808,14 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { } } - if err = cs.writeRequestBody(req); err != nil { + if err = cs.writeRequestBody(req, dump); err != nil { if err != http2errStopReqBodyWrite { http2traceWroteRequest(cs.trace, err) return err } } else { cs.sentEndStream = true - if dump := cs.cc.t.t1.dump; dump != nil && dump.RequestBody { + if dump != nil && dump.RequestBody { dump.dump([]byte("\r\n\r\n")) } } @@ -7865,7 +7853,7 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { } } -func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request) error { +func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dump *dumper) error { cc := cs.cc ctx := cs.ctx @@ -7895,7 +7883,7 @@ func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request) error { hasTrailers := trailers != "" contentLen := http2actualContentLength(req) hasBody := contentLen != 0 - hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen) + hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen, dump) if err != nil { return err } @@ -8060,7 +8048,7 @@ func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int { var http2bufPool sync.Pool // of *[]byte -func (cs *http2clientStream) writeRequestBody(req *http.Request) (err error) { +func (cs *http2clientStream) writeRequestBody(req *http.Request, dump *dumper) (err error) { cc := cs.cc body := cs.reqBody sentEnd := false // whether we sent the final DATA frame w/ END_STREAM @@ -8084,6 +8072,14 @@ func (cs *http2clientStream) writeRequestBody(req *http.Request) (err error) { defer http2bufPool.Put(&buf) } + writeData := cc.fr.WriteData + if dump != nil && dump.RequestBody { + writeData = func(streamID uint32, endStream bool, data []byte) error { + dump.dump(data) + return cc.fr.WriteData(streamID, endStream, data) + } + } + var sawEOF bool for !sawEOF { n, err := body.Read(buf[:len(buf)]) @@ -8133,7 +8129,7 @@ func (cs *http2clientStream) writeRequestBody(req *http.Request) (err error) { data := remain[:allowed] remain = remain[allowed:] sentEnd = sawEOF && len(remain) == 0 && !hasTrailers - err = cc.fr.WriteData(cs.ID, sentEnd, data) + err = writeData(cs.ID, sentEnd, data) if err == nil { // TODO(bradfitz): this flush is for latency, not bandwidth. // Most requests won't need this. Make this opt-in or @@ -8172,7 +8168,7 @@ func (cs *http2clientStream) writeRequestBody(req *http.Request) (err error) { defer cc.wmu.Unlock() var trls []byte if len(trailer) > 0 { - trls, err = cc.encodeTrailers(trailer) + trls, err = cc.encodeTrailers(trailer, dump) if err != nil { return err } @@ -8235,7 +8231,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er var http2errNilRequestURL = errors.New("http2: Request.URI is nil") // requires cc.wmu be held. -func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { +func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64, dump *dumper) ([]byte, error) { cc.hbuf.Reset() if req.URL == nil { return nil, http2errNilRequestURL @@ -8386,6 +8382,14 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trace := httptrace.ContextClientTrace(req.Context()) traceHeaders := http2traceHasWroteHeaderField(trace) + writeHeader := cc.writeHeader + if dump != nil && dump.RequestHeader { + writeHeader = func(name, value string) { + dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + cc.writeHeader(name, value) + } + } + // Header list size is ok. Write the headers. enumerateHeaders(func(name, value string) { name, ascii := http2asciiToLower(name) @@ -8394,14 +8398,14 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, // field names have to be ASCII characters (just as in HTTP/1.x). return } - cc.writeHeader(name, value) + writeHeader(name, value) if traceHeaders { http2traceWroteHeaderField(trace, name, value) } }) - if cc.t.t1.dump != nil && cc.t.t1.dump.RequestHeader { - cc.t.t1.dump.dump([]byte("\r\n")) + if dump != nil && dump.RequestHeader { + dump.dump([]byte("\r\n")) } return cc.hbuf.Bytes(), nil @@ -8430,7 +8434,7 @@ func http2shouldSendReqContentLength(method string, contentLength int64) bool { } // requires cc.wmu be held. -func (cc *http2ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) { +func (cc *http2ClientConn) encodeTrailers(trailer http.Header, dump *dumper) ([]byte, error) { cc.hbuf.Reset() hlSize := uint64(0) @@ -8444,6 +8448,14 @@ func (cc *http2ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) { return nil, http2errRequestHeaderListSize } + writeHeader := cc.writeHeader + if dump != nil && dump.RequestBody { + writeHeader = func(name, value string) { + dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + cc.writeHeader(name, value) + } + } + for k, vv := range trailer { lowKey, ascii := http2asciiToLower(k) if !ascii { @@ -8454,13 +8466,13 @@ func (cc *http2ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) { // Transfer-Encoding, etc.. have already been filtered at the // start of RoundTrip for _, v := range vv { - cc.writeHeader(lowKey, v) + writeHeader(lowKey, v) } } return cc.hbuf.Bytes(), nil } -func (cc *http2ClientConn) _writeHeader(name, value string) { +func (cc *http2ClientConn) writeHeader(name, value string) { if http2VerboseLogs { log.Printf("http2: Transport encoding header %q = %q", name, value) } diff --git a/http_response.go b/http_response.go index 6b0d2620..b0eda117 100644 --- a/http_response.go +++ b/http_response.go @@ -28,7 +28,11 @@ var respExcludeHeader = map[string]bool{ // pairs included in the response trailer. func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) { //var tp headReader - tp := newTextprotoReader(pc.br, pc.t.dump) + dump := pc.t.dump + if d, ok := req.Context().Value("dumper").(*dumper); ok { + dump = d + } + tp := newTextprotoReader(pc.br, dump) resp := &http.Response{ Request: req, } diff --git a/request.go b/request.go index fe6880a1..9f87f1fb 100644 --- a/request.go +++ b/request.go @@ -38,6 +38,7 @@ type Request struct { isSaveResponse bool output io.Writer trace *clientTrace + dumpBuffer *bytes.Buffer } // TraceInfo returns the trace information, only available when trace is enabled. @@ -775,3 +776,83 @@ func (r *Request) EnableTrace(enable bool) *Request { } return r } + +func (r *Request) getDumpBuffer() *bytes.Buffer { + if r.dumpBuffer == nil { + r.dumpBuffer = new(bytes.Buffer) + } + return r.dumpBuffer +} + +// Dump is a global wrapper methods which delegated +// to the default client, create a request and Dump for request. +func Dump(reqHeader, reqBody, respHeader, respBody bool) *Request { + return defaultClient.R().Dump(reqHeader, reqBody, respHeader, respBody) +} + +// Dump enables the dump for the request and response. +func (r *Request) Dump(reqHeader, reqBody, respHeader, respBody bool) *Request { + opt := &DumpOptions{ + RequestHeader: reqHeader, + RequestBody: reqBody, + ResponseHeader: respHeader, + ResponseBody: respBody, + Output: r.getDumpBuffer(), + } + return r.SetContext(context.WithValue(r.Context(), "dumper", newDumper(opt))) +} + +// DumpAll is a global wrapper methods which delegated +// to the default client, create a request and DumpAll for request. +func DumpAll() *Request { + return defaultClient.R().DumpAll() +} + +// DumpAll enables dump all content for the request and response. +func (r *Request) DumpAll() *Request { + return r.Dump(true, true, true, true) +} + +// DumpOnlyHeader is a global wrapper methods which delegated +// to the default client, create a request and DumpOnlyHeader for request. +func DumpOnlyHeader() *Request { + return defaultClient.R().DumpOnlyHeader() +} + +// DumpOnlyHeader enables dump only header for the request and response. +func (r *Request) DumpOnlyHeader() *Request { + return r.Dump(true, false, true, false) +} + +// DumpOnlyBody is a global wrapper methods which delegated +// to the default client, create a request and DumpOnlyBody for request. +func DumpOnlyBody() *Request { + return defaultClient.R().DumpOnlyBody() +} + +// DumpOnlyBody enables dump only body for the request and response. +func (r *Request) DumpOnlyBody() *Request { + return r.Dump(false, true, false, true) +} + +// DumpOnlyRequest is a global wrapper methods which delegated +// to the default client, create a request and DumpOnlyRequest for request. +func DumpOnlyRequest() *Request { + return defaultClient.R().DumpOnlyRequest() +} + +// DumpOnlyRequest enables dump only request. +func (r *Request) DumpOnlyRequest() *Request { + return r.Dump(true, true, false, false) +} + +// DumpOnlyResponse is a global wrapper methods which delegated +// to the default client, create a request and DumpOnlyResponse for request. +func DumpOnlyResponse() *Request { + return defaultClient.R().DumpOnlyResponse() +} + +// DumpOnlyResponse enables dump only response. +func (r *Request) DumpOnlyResponse() *Request { + return r.Dump(false, false, true, true) +} diff --git a/response.go b/response.go index c10eac86..b3f31f49 100644 --- a/response.go +++ b/response.go @@ -133,3 +133,9 @@ func (r *Response) ToBytes() ([]byte, error) { r.body = body return body, nil } + +// Dump return the string content that have been dumped for the request. +// `Request.Dump` or `Request.DumpXXX` MUST have been called. +func (r *Response) Dump() string { + return r.Request.getDumpBuffer().String() +} diff --git a/roundtrip.go b/roundtrip.go index fe9c8e09..c96b3390 100644 --- a/roundtrip.go +++ b/roundtrip.go @@ -20,6 +20,6 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error if err != nil { return } - t.handleResponseBody(resp) + t.handleResponseBody(resp, req) return } diff --git a/transport.go b/transport.go index f706f272..5b781a40 100644 --- a/transport.go +++ b/transport.go @@ -287,16 +287,20 @@ type Transport struct { Debugf func(format string, v ...interface{}) } -func (t *Transport) handleResponseBody(res *http.Response) { +func (t *Transport) handleResponseBody(res *http.Response, req *http.Request) { t.autoDecodeResponseBody(res) - t.dumpResponseBody(res) + t.dumpResponseBody(res, req) } -func (t *Transport) dumpResponseBody(res *http.Response) { - if t.dump == nil || !t.dump.ResponseBody { +func (t *Transport) dumpResponseBody(res *http.Response, req *http.Request) { + dump := t.dump + if d, ok := req.Context().Value("dumper").(*dumper); ok { + dump = d + } + if dump == nil || !dump.ResponseBody { return } - res.Body = t.dump.WrapReadCloser(res.Body) + res.Body = dump.WrapReadCloser(res.Body) } func (t *Transport) autoDecodeResponseBody(res *http.Response) { @@ -2450,7 +2454,12 @@ func (pc *persistConn) writeLoop() { select { case wr := <-pc.writech: startBytesWritten := pc.nwrite - err := requestWrite(wr.req.Request, pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh), pc.t.dump) + ctx := wr.req.Request.Context() + dump := pc.t.dump + if d, ok := ctx.Value("dumper").(*dumper); ok { + dump = d + } + err := requestWrite(wr.req.Request, pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh), dump) if bre, ok := err.(requestBodyReadError); ok { err = bre.error // Errors reading from the user's From 1b519353748cd92da1a5adade65e3215fd63c161 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 4 Feb 2022 21:18:35 +0800 Subject: [PATCH 198/843] remove unused http2 server code --- h2_bundle.go | 9662 ++++++++++++++------------------------------------ 1 file changed, 2627 insertions(+), 7035 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index 182d711b..f4676542 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -37,9 +37,7 @@ import ( "net/http" "net/http/httptrace" "net/textproto" - "net/url" "os" - "reflect" "runtime" "sort" "strconv" @@ -102,639 +100,6 @@ func http2asciiToLower(s string) (lower string, ok bool) { // A list of the possible cipher suite ids. Taken from // https://www.iana.org/assignments/tls-parameters/tls-parameters.txt -const ( - http2cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000 - http2cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001 - http2cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002 - http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003 - http2cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004 - http2cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 - http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006 - http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007 - http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008 - http2cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009 - http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A - http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B - http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C - http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D - http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E - http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F - http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010 - http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011 - http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012 - http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013 - http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014 - http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015 - http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016 - http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017 - http2cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018 - http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019 - http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A - http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B - // Reserved uint16 = 0x001C-1D - http2cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E - http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F - http2cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020 - http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021 - http2cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022 - http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023 - http2cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024 - http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025 - http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026 - http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027 - http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028 - http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029 - http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A - http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B - http2cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E - http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F - http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030 - http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031 - http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032 - http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033 - http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034 - http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 - http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036 - http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037 - http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038 - http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039 - http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A - http2cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B - http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C - http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D - http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E - http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F - http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040 - http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042 - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043 - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044 - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046 - // Reserved uint16 = 0x0047-4F - // Reserved uint16 = 0x0050-58 - // Reserved uint16 = 0x0059-5C - // Unassigned uint16 = 0x005D-5F - // Reserved uint16 = 0x0060-66 - http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067 - http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068 - http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069 - http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A - http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B - http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C - http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D - // Unassigned uint16 = 0x006E-83 - http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085 - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086 - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087 - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089 - http2cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A - http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B - http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C - http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D - http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E - http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F - http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090 - http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091 - http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092 - http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093 - http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094 - http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095 - http2cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096 - http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097 - http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098 - http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099 - http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A - http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B - http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C - http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D - http2cipher_TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009E - http2cipher_TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009F - http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0 - http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1 - http2cipher_TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A2 - http2cipher_TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A3 - http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4 - http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5 - http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6 - http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7 - http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8 - http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9 - http2cipher_TLS_DHE_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AA - http2cipher_TLS_DHE_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AB - http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC - http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD - http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE - http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF - http2cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0 - http2cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1 - http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2 - http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3 - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4 - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5 - http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6 - http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7 - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8 - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9 - http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF - http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1 - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2 - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3 - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5 - // Unassigned uint16 = 0x00C6-FE - http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF - // Unassigned uint16 = 0x01-55,* - http2cipher_TLS_FALLBACK_SCSV uint16 = 0x5600 - // Unassigned uint16 = 0x5601 - 0xC000 - http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA uint16 = 0xC001 - http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA uint16 = 0xC002 - http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC003 - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC004 - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC005 - http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA uint16 = 0xC006 - http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xC007 - http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC008 - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC009 - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC00A - http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA uint16 = 0xC00B - http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA uint16 = 0xC00C - http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC00D - http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC00E - http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC00F - http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA uint16 = 0xC010 - http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xC011 - http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC012 - http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC013 - http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC014 - http2cipher_TLS_ECDH_anon_WITH_NULL_SHA uint16 = 0xC015 - http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA uint16 = 0xC016 - http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0xC017 - http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA uint16 = 0xC018 - http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA uint16 = 0xC019 - http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01A - http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01B - http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01C - http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA uint16 = 0xC01D - http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC01E - http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA uint16 = 0xC01F - http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA uint16 = 0xC020 - http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC021 - http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA uint16 = 0xC022 - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC023 - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC024 - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC025 - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC026 - http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC027 - http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC028 - http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC029 - http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC02A - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02B - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02C - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02D - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02E - http2cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02F - http2cipher_TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC030 - http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC031 - http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC032 - http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA uint16 = 0xC033 - http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0xC034 - http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0xC035 - http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0xC036 - http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0xC037 - http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0xC038 - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA uint16 = 0xC039 - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256 uint16 = 0xC03A - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384 uint16 = 0xC03B - http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03C - http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03D - http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03E - http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03F - http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC040 - http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC041 - http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC042 - http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC043 - http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC044 - http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC045 - http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC046 - http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC047 - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC048 - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC049 - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04A - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04B - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04C - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04D - http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04E - http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04F - http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC050 - http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC051 - http2cipher_TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC052 - http2cipher_TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC053 - http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC054 - http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC055 - http2cipher_TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC056 - http2cipher_TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC057 - http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC058 - http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC059 - http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05A - http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05B - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05C - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05D - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05E - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05F - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC060 - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC061 - http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC062 - http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC063 - http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC064 - http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC065 - http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC066 - http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC067 - http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC068 - http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC069 - http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06A - http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06B - http2cipher_TLS_DHE_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06C - http2cipher_TLS_DHE_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06D - http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06E - http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06F - http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC070 - http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC071 - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC072 - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC073 - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC074 - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC075 - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC076 - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC077 - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC078 - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC079 - http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07A - http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07B - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07C - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07D - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07E - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07F - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC080 - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC081 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC082 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC083 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC084 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC085 - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC086 - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC087 - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC088 - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC089 - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08A - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08B - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08C - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08D - http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08E - http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08F - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC090 - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC091 - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC092 - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC093 - http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC094 - http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC095 - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC096 - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC097 - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC098 - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC099 - http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC09A - http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC09B - http2cipher_TLS_RSA_WITH_AES_128_CCM uint16 = 0xC09C - http2cipher_TLS_RSA_WITH_AES_256_CCM uint16 = 0xC09D - http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM uint16 = 0xC09E - http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM uint16 = 0xC09F - http2cipher_TLS_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A0 - http2cipher_TLS_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A1 - http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A2 - http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A3 - http2cipher_TLS_PSK_WITH_AES_128_CCM uint16 = 0xC0A4 - http2cipher_TLS_PSK_WITH_AES_256_CCM uint16 = 0xC0A5 - http2cipher_TLS_DHE_PSK_WITH_AES_128_CCM uint16 = 0xC0A6 - http2cipher_TLS_DHE_PSK_WITH_AES_256_CCM uint16 = 0xC0A7 - http2cipher_TLS_PSK_WITH_AES_128_CCM_8 uint16 = 0xC0A8 - http2cipher_TLS_PSK_WITH_AES_256_CCM_8 uint16 = 0xC0A9 - http2cipher_TLS_PSK_DHE_WITH_AES_128_CCM_8 uint16 = 0xC0AA - http2cipher_TLS_PSK_DHE_WITH_AES_256_CCM_8 uint16 = 0xC0AB - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM uint16 = 0xC0AC - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM uint16 = 0xC0AD - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 uint16 = 0xC0AE - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8 uint16 = 0xC0AF - // Unassigned uint16 = 0xC0B0-FF - // Unassigned uint16 = 0xC1-CB,* - // Unassigned uint16 = 0xCC00-A7 - http2cipher_TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA8 - http2cipher_TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA9 - http2cipher_TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAA - http2cipher_TLS_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAB - http2cipher_TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAC - http2cipher_TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAD - http2cipher_TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAE -) - -// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. -// References: -// https://tools.ietf.org/html/rfc7540#appendix-A -// Reject cipher suites from Appendix A. -// "This list includes those cipher suites that do not -// offer an ephemeral key exchange and those that are -// based on the TLS null, stream or block cipher type" -func http2isBadCipher(cipher uint16) bool { - switch cipher { - case http2cipher_TLS_NULL_WITH_NULL_NULL, - http2cipher_TLS_RSA_WITH_NULL_MD5, - http2cipher_TLS_RSA_WITH_NULL_SHA, - http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5, - http2cipher_TLS_RSA_WITH_RC4_128_MD5, - http2cipher_TLS_RSA_WITH_RC4_128_SHA, - http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5, - http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA, - http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_RSA_WITH_DES_CBC_SHA, - http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5, - http2cipher_TLS_DH_anon_WITH_RC4_128_MD5, - http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_KRB5_WITH_DES_CBC_SHA, - http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_KRB5_WITH_RC4_128_SHA, - http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA, - http2cipher_TLS_KRB5_WITH_DES_CBC_MD5, - http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5, - http2cipher_TLS_KRB5_WITH_RC4_128_MD5, - http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5, - http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA, - http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA, - http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA, - http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5, - http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5, - http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5, - http2cipher_TLS_PSK_WITH_NULL_SHA, - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA, - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA, - http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA, - http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA, - http2cipher_TLS_RSA_WITH_NULL_SHA256, - http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_PSK_WITH_RC4_128_SHA, - http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA, - http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA, - http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA, - http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA, - http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA, - http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA, - http2cipher_TLS_RSA_WITH_SEED_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA, - http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_PSK_WITH_NULL_SHA256, - http2cipher_TLS_PSK_WITH_NULL_SHA384, - http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256, - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384, - http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256, - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384, - http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV, - http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA, - http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA, - http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA, - http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA, - http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA, - http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA, - http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDH_anon_WITH_NULL_SHA, - http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA, - http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA, - http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA, - http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA, - http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA, - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256, - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384, - http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_RSA_WITH_AES_128_CCM, - http2cipher_TLS_RSA_WITH_AES_256_CCM, - http2cipher_TLS_RSA_WITH_AES_128_CCM_8, - http2cipher_TLS_RSA_WITH_AES_256_CCM_8, - http2cipher_TLS_PSK_WITH_AES_128_CCM, - http2cipher_TLS_PSK_WITH_AES_256_CCM, - http2cipher_TLS_PSK_WITH_AES_128_CCM_8, - http2cipher_TLS_PSK_WITH_AES_256_CCM_8: - return true - default: - return false - } -} - // ClientConnPool manages a pool of HTTP/2 client connections. type http2ClientConnPool interface { // GetClientConn returns a specific HTTP/2 connection (usually @@ -3278,10 +2643,6 @@ const ( // connections from clients. http2ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - // SETTINGS_MAX_FRAME_SIZE default - // http://http2.github.io/http2-spec/#rfc.section.6.5.2 - http2initialMaxFrameSize = 16384 - // NextProtoTLS is the NPN/ALPN protocol negotiated during // HTTP/2's TLS setup. http2NextProtoTLS = "h2" @@ -3290,48 +2651,12 @@ const ( http2initialHeaderTableSize = 4096 http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size - - http2defaultMaxReadFrameSize = 1 << 20 ) var ( http2clientPreface = []byte(http2ClientPreface) ) -type http2streamState int - -// HTTP/2 stream states. -// -// See http://tools.ietf.org/html/rfc7540#section-5.1. -// -// For simplicity, the server code merges "reserved (local)" into -// "half-closed (remote)". This is one less state transition to track. -// The only downside is that we send PUSH_PROMISEs slightly less -// liberally than allowable. More discussion here: -// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html -// -// "reserved (remote)" is omitted since the client code does not -// support server push. -const ( - http2stateIdle http2streamState = iota - http2stateOpen - http2stateHalfClosedLocal - http2stateHalfClosedRemote - http2stateClosed -) - -var http2stateName = [...]string{ - http2stateIdle: "Idle", - http2stateOpen: "Open", - http2stateHalfClosedLocal: "HalfClosedLocal", - http2stateHalfClosedRemote: "HalfClosedRemote", - http2stateClosed: "Closed", -} - -func (st http2streamState) String() string { - return http2stateName[st] -} - // Setting is a setting parameter: which setting it is, and its value. type http2Setting struct { // ID is which setting is being set. @@ -3428,52 +2753,6 @@ func http2httpCodeString(code int) string { return strconv.Itoa(code) } -// from pkg io -type http2stringWriter interface { - WriteString(s string) (n int, err error) -} - -// A gate lets two goroutines coordinate their activities. -type http2gate chan struct{} - -func (g http2gate) Done() { g <- struct{}{} } - -func (g http2gate) Wait() { <-g } - -// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). -type http2closeWaiter chan struct{} - -// Init makes a closeWaiter usable. -// It exists because so a closeWaiter value can be placed inside a -// larger struct and have the Mutex and Cond's memory in the same -// allocation. -func (cw *http2closeWaiter) Init() { - *cw = make(chan struct{}) -} - -// Close marks the closeWaiter as closed and unblocks any waiters. -func (cw http2closeWaiter) Close() { - close(cw) -} - -// Wait waits for the closeWaiter to become closed. -func (cw http2closeWaiter) Wait() { - <-cw -} - -// bufferedWriter is a buffered writer that writes to w. -// Its buffered writer is lazily allocated as needed, to minimize -// idle memory usage with many connections. -type http2bufferedWriter struct { - _ http2incomparable - w io.Writer // immutable - bw *bufio.Writer // non-nil when data is buffered -} - -func http2newBufferedWriter(w io.Writer) *http2bufferedWriter { - return &http2bufferedWriter{w: w} -} - // bufWriterPoolBufferSize is the size of bufio.Writer's // buffers created using bufWriterPool. // @@ -3488,34 +2767,6 @@ var http2bufWriterPool = sync.Pool{ }, } -func (w *http2bufferedWriter) Available() int { - if w.bw == nil { - return http2bufWriterPoolBufferSize - } - return w.bw.Available() -} - -func (w *http2bufferedWriter) Write(p []byte) (n int, err error) { - if w.bw == nil { - bw := http2bufWriterPool.Get().(*bufio.Writer) - bw.Reset(w.w) - w.bw = bw - } - return w.bw.Write(p) -} - -func (w *http2bufferedWriter) Flush() error { - bw := w.bw - if bw == nil { - return nil - } - err := bw.Flush() - bw.Reset(nil) - http2bufWriterPool.Put(bw) - w.bw = nil - return err -} - func http2mustUint31(v int32) uint32 { if v < 0 || v > 2147483647 { panic("out of range") @@ -3788,6903 +3039,3244 @@ const ( http2maxQueuedControlFrames = 10000 ) -var ( - http2errClientDisconnected = errors.New("client disconnected") - http2errClosedBody = errors.New("body closed by handler") - http2errHandlerComplete = errors.New("http2: request body closed due to handler exiting") - http2errStreamClosed = errors.New("http2: stream closed") -) +type http2readFrameResult struct { + f http2Frame // valid until readMore is called + err error -var http2responseWriterStatePool = sync.Pool{ - New: func() interface{} { - rws := &http2responseWriterState{} - rws.bw = bufio.NewWriterSize(http2chunkWriter{rws}, http2handlerChunkWriteSize) - return rws - }, + // readMore should be called once the consumer no longer needs or + // retains f. After readMore, f is invalid and more frames can be + // read. + readMore func() } -// Test hooks. +type http2serverMessage int + +// Message values sent to serveMsgCh. var ( - http2testHookOnConn func() - http2testHookGetServerConn func(*http2serverConn) - http2testHookOnPanicMu *sync.Mutex // nil except in tests - http2testHookOnPanic func(sc *http2serverConn, panicVal interface{}) (rePanic bool) + http2settingsTimerMsg = new(http2serverMessage) + http2idleTimerMsg = new(http2serverMessage) + http2shutdownTimerMsg = new(http2serverMessage) + http2gracefulShutdownMsg = new(http2serverMessage) ) -// Server is an HTTP/2 server. -type http2Server struct { - // MaxHandlers limits the number of http.Handler ServeHTTP goroutines - // which may run at a time over all connections. - // Negative or zero no limit. - // TODO: implement - MaxHandlers int - - // MaxConcurrentStreams optionally specifies the number of - // concurrent streams that each client may have open at a - // time. This is unrelated to the number of http.Handler goroutines - // which may be active globally, which is MaxHandlers. - // If zero, MaxConcurrentStreams defaults to at least 100, per - // the HTTP/2 spec's recommendations. - MaxConcurrentStreams uint32 - - // MaxReadFrameSize optionally specifies the largest frame - // this server is willing to read. A valid value is between - // 16k and 16M, inclusive. If zero or otherwise invalid, a - // default value is used. - MaxReadFrameSize uint32 - - // PermitProhibitedCipherSuites, if true, permits the use of - // cipher suites prohibited by the HTTP/2 spec. - PermitProhibitedCipherSuites bool - - // IdleTimeout specifies how long until idle clients should be - // closed with a GOAWAY frame. PING frames are not considered - // activity for the purposes of IdleTimeout. - IdleTimeout time.Duration - - // MaxUploadBufferPerConnection is the size of the initial flow - // control window for each connections. The HTTP/2 spec does not - // allow this to be smaller than 65535 or larger than 2^32-1. - // If the value is outside this range, a default value will be - // used instead. - MaxUploadBufferPerConnection int32 - - // MaxUploadBufferPerStream is the size of the initial flow control - // window for each stream. The HTTP/2 spec does not allow this to - // be larger than 2^32-1. If the value is zero or larger than the - // maximum, a default value will be used instead. - MaxUploadBufferPerStream int32 - - // NewWriteScheduler constructs a write scheduler for a connection. - // If nil, a default scheduler is chosen. - NewWriteScheduler func() http2WriteScheduler - - // CountError, if non-nil, is called on HTTP/2 server errors. - // It's intended to increment a metric for monitoring, such - // as an expvar or Prometheus metric. - // The errType consists of only ASCII word characters. - CountError func(errType string) +var http2errPrefaceTimeout = errors.New("timeout waiting for client preface") - // Internal state. This is a pointer (rather than embedded directly) - // so that we don't embed a Mutex in this struct, which will make the - // struct non-copyable, which might break some callers. - state *http2serverInternalState +var http2errChanPool = sync.Pool{ + New: func() interface{} { return make(chan error, 1) }, } -func (s *http2Server) initialConnRecvWindowSize() int32 { - if s.MaxUploadBufferPerConnection > http2initialWindowSize { - return s.MaxUploadBufferPerConnection - } - return 1 << 20 +type http2requestParam struct { + method string + scheme, authority, path string + header http.Header } -func (s *http2Server) initialStreamRecvWindowSize() int32 { - if s.MaxUploadBufferPerStream > 0 { - return s.MaxUploadBufferPerStream - } - return 1 << 20 -} +// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys +// that, if present, signals that the map entry is actually for +// the response trailers, and not the response headers. The prefix +// is stripped after the ServeHTTP call finishes and the values are +// sent in the trailers. +// +// This mechanism is intended only for trailers that are not known +// prior to the headers being written. If the set of trailers is fixed +// or known before the header is written, the normal Go trailers mechanism +// is preferred: +// https://golang.org/pkg/net/http/#ResponseWriter +// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +const http2TrailerPrefix = "Trailer:" -func (s *http2Server) maxReadFrameSize() uint32 { - if v := s.MaxReadFrameSize; v >= http2minMaxFrameSize && v <= http2maxFrameSize { - return v +// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode. +func http2checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at http://httpwg.org/specs/rfc7231.html#status.codes) + // and we might block under 200 (once we have more mature 1xx support). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/2, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) } - return http2defaultMaxReadFrameSize } -func (s *http2Server) maxConcurrentStreams() uint32 { - if v := s.MaxConcurrentStreams; v > 0 { - return v +func http2cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 } - return http2defaultMaxStreams -} - -// maxQueuedControlFrames is the maximum number of control frames like -// SETTINGS, PING and RST_STREAM that will be queued for writing before -// the connection is closed to prevent memory exhaustion attacks. -func (s *http2Server) maxQueuedControlFrames() int { - // TODO: if anybody asks, add a Server field, and remember to define the - // behavior of negative values. - return http2maxQueuedControlFrames -} - -type http2serverInternalState struct { - mu sync.Mutex - activeConns map[*http2serverConn]struct{} + return h2 } -func (s *http2serverInternalState) registerConn(sc *http2serverConn) { - if s == nil { - return // if the Server was used without calling ConfigureServer - } - s.mu.Lock() - s.activeConns[sc] = struct{}{} - s.mu.Unlock() -} +// Push errors. +var ( + http2ErrRecursivePush = errors.New("http2: recursive push not allowed") + http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") +) -func (s *http2serverInternalState) unregisterConn(sc *http2serverConn) { - if s == nil { - return // if the Server was used without calling ConfigureServer +// foreachHeaderElement splits v according to the "#rule" construction +// in RFC 7230 section 7 and calls fn for each non-empty element. +func http2foreachHeaderElement(v string, fn func(string)) { + v = textproto.TrimString(v) + if v == "" { + return } - s.mu.Lock() - delete(s.activeConns, sc) - s.mu.Unlock() -} - -func (s *http2serverInternalState) startGracefulShutdown() { - if s == nil { - return // if the Server was used without calling ConfigureServer + if !strings.Contains(v, ",") { + fn(v) + return } - s.mu.Lock() - for sc := range s.activeConns { - sc.startGracefulShutdown() + for _, f := range strings.Split(v, ",") { + if f = textproto.TrimString(f); f != "" { + fn(f) + } } - s.mu.Unlock() } -// ServeConnOpts are options for the Server.ServeConn method. -type http2ServeConnOpts struct { - // Context is the base context to use. - // If nil, context.Background is used. - Context context.Context - - // BaseConfig optionally sets the base configuration - // for values. If nil, defaults are used. - BaseConfig *http.Server - - // Handler specifies which handler to use for processing - // requests. If nil, BaseConfig.Handler is used. If BaseConfig - // or BaseConfig.Handler is nil, http.DefaultServeMux is used. - Handler http.Handler +// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 +var http2connHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Connection", + "Transfer-Encoding", + "Upgrade", } -func (o *http2ServeConnOpts) context() context.Context { - if o != nil && o.Context != nil { - return o.Context +// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, +// per RFC 7540 Section 8.1.2.2. +// The returned error is reported to users. +func http2checkValidHTTP2RequestHeaders(h http.Header) error { + for _, k := range http2connHeaders { + if _, ok := h[k]; ok { + return fmt.Errorf("request header %q is not valid in HTTP/2", k) + } } - return context.Background() -} - -func (o *http2ServeConnOpts) baseConfig() *http.Server { - if o != nil && o.BaseConfig != nil { - return o.BaseConfig + te := h["Te"] + if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { + return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) } - return new(http.Server) + return nil } -func (o *http2ServeConnOpts) handler() http.Handler { - if o != nil { - if o.Handler != nil { - return o.Handler - } - if o.BaseConfig != nil && o.BaseConfig.Handler != nil { - return o.BaseConfig.Handler - } +func http2new400Handler(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + http.Error(w, err.Error(), http.StatusBadRequest) } - return http.DefaultServeMux } -func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx context.Context, cancel func()) { - ctx, cancel = context.WithCancel(opts.context()) - ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr()) - if hs := opts.baseConfig(); hs != nil { - ctx = context.WithValue(ctx, http.ServerContextKey, hs) +// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives +// disabled. See comments on h1ServerShutdownChan above for why +// the code is written this way. +func http2h1ServerKeepAlivesDisabled(hs *http.Server) bool { + var x interface{} = hs + type I interface { + doKeepAlives() bool } - return + if hs, ok := x.(I); ok { + return !hs.doKeepAlives() + } + return false } -func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) { - sc.vlogf("http2: server rejecting conn: %v, %s", err, debug) - // ignoring errors. hanging up anyway. - sc.framer.WriteGoAway(0, err, []byte(debug)) - sc.bw.Flush() - sc.conn.Close() -} - -type http2serverConn struct { - // Immutable: - srv *http2Server - hs *http.Server - conn net.Conn - bw *http2bufferedWriter // writing to conn - handler http.Handler - baseCtx context.Context - framer *http2Framer - doneServing chan struct{} // closed when serverConn.serve ends - readFrameCh chan http2readFrameResult // written by serverConn.readFrames - wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve - wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes - bodyReadCh chan http2bodyReadMsg // from handlers -> serve - serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop - flow http2flow // conn-wide (not stream-specific) outbound flow control - inflow http2flow // conn-wide inbound flow control - tlsState *tls.ConnectionState // shared by all handlers, like net/http - remoteAddrStr string - writeSched http2WriteScheduler - - // Everything following is owned by the serve loop; use serveG.check(): - serveG http2goroutineLock // used to verify funcs are on serve() - pushEnabled bool - sawFirstSettings bool // got the initial SETTINGS frame after the preface - needToSendSettingsAck bool - unackedSettings int // how many SETTINGS have we sent without ACKs? - queuedControlFrames int // control frames in the writeSched queue - clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) - advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client - curClientStreams uint32 // number of open streams initiated by the client - curPushedStreams uint32 // number of open streams initiated by server push - maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests - maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes - streams map[uint32]*http2stream - initialStreamSendWindowSize int32 - maxFrameSize int32 - headerTableSize uint32 - peerMaxHeaderListSize uint32 // zero means unknown (default) - canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case - writingFrame bool // started writing a frame (on serve goroutine or separate) - writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh - needsFrameFlush bool // last frame write wasn't a flush - inGoAway bool // we've started to or sent GOAWAY - inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop - needToSendGoAway bool // we need to schedule a GOAWAY frame write - goAwayCode http2ErrCode - shutdownTimer *time.Timer // nil until used - idleTimer *time.Timer // nil if unused - - // Owned by the writeFrameAsync goroutine: - headerWriteBuf bytes.Buffer - hpackEncoder *hpack.Encoder - - // Used by startGracefulShutdown. - shutdownOnce sync.Once -} - -func (sc *http2serverConn) maxHeaderListSize() uint32 { - n := sc.hs.MaxHeaderBytes - if n <= 0 { - n = http.DefaultMaxHeaderBytes - } - // http2's count is in a slightly different unit and includes 32 bytes per pair. - // So, take the net/http.Server value and pad it up a bit, assuming 10 headers. - const perFieldOverhead = 32 // per http2 spec - const typicalHeaders = 10 // conservative - return uint32(n + typicalHeaders*perFieldOverhead) -} - -func (sc *http2serverConn) curOpenStreams() uint32 { - sc.serveG.check() - return sc.curClientStreams + sc.curPushedStreams -} - -// stream represents a stream. This is the minimal metadata needed by -// the serve goroutine. Most of the actual stream state is owned by -// the http.Handler's goroutine in the responseWriter. Because the -// responseWriter's responseWriterState is recycled at the end of a -// handler, this struct intentionally has no pointer to the -// *responseWriter{,State} itself, as the Handler ending nils out the -// responseWriter's state field. -type http2stream struct { - // immutable: - sc *http2serverConn - id uint32 - body *http2pipe // non-nil if expecting DATA frames - cw http2closeWaiter // closed wait stream transitions to closed state - ctx context.Context - cancelCtx func() - - // owned by serverConn's serve loop: - bodyBytes int64 // body bytes seen so far - declBodyBytes int64 // or -1 if undeclared - flow http2flow // limits writing from Handler to client - inflow http2flow // what the client is allowed to POST/etc to us - state http2streamState - resetQueued bool // RST_STREAM queued for write; set by sc.resetStream - gotTrailerHeader bool // HEADER frame for trailers was seen - wroteHeaders bool // whether we wrote headers (not status 100) - writeDeadline *time.Timer // nil if unused +const ( + // transportDefaultConnFlow is how many connection-level flow control + // tokens we give the server at start-up, past the default 64k. + http2transportDefaultConnFlow = 1 << 30 - trailer http.Header // accumulated trailers - reqTrailer http.Header // handler's Request.Trailer -} + // transportDefaultStreamFlow is how many stream-level flow + // control tokens we announce to the peer, and how many bytes + // we buffer per stream. + http2transportDefaultStreamFlow = 4 << 20 -func (sc *http2serverConn) Framer() *http2Framer { return sc.framer } + // transportDefaultStreamMinRefresh is the minimum number of bytes we'll send + // a stream-level WINDOW_UPDATE for at a time. + http2transportDefaultStreamMinRefresh = 4 << 10 -func (sc *http2serverConn) CloseConn() error { return sc.conn.Close() } + // initialMaxConcurrentStreams is a connections maxConcurrentStreams until + // it's received servers initial SETTINGS frame, which corresponds with the + // spec's minimum recommended value. + http2initialMaxConcurrentStreams = 100 -func (sc *http2serverConn) Flush() error { return sc.bw.Flush() } + // defaultMaxConcurrentStreams is a connections default maxConcurrentStreams + // if the server doesn't include one in its initial SETTINGS frame. + http2defaultMaxConcurrentStreams = 1000 +) -func (sc *http2serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) { - return sc.hpackEncoder, &sc.headerWriteBuf -} +// Transport is an HTTP/2 Transport. +// +// A Transport internally caches connections to servers. It is safe +// for concurrent use by multiple goroutines. +type http2Transport struct { + // DialTLS specifies an optional dial function for creating + // TLS connections for requests. + // + // If DialTLS is nil, tls.Dial is used. + // + // If the returned net.Conn has a ConnectionState method like tls.Conn, + // it will be used to set http.Response.TLS. + DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error) -func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2stream) { - sc.serveG.check() - // http://tools.ietf.org/html/rfc7540#section-5.1 - if st, ok := sc.streams[streamID]; ok { - return st.state, st - } - // "The first use of a new stream identifier implicitly closes all - // streams in the "idle" state that might have been initiated by - // that peer with a lower-valued stream identifier. For example, if - // a client sends a HEADERS frame on stream 7 without ever sending a - // frame on stream 5, then stream 5 transitions to the "closed" - // state when the first frame for stream 7 is sent or received." - if streamID%2 == 1 { - if streamID <= sc.maxClientStreamID { - return http2stateClosed, nil - } - } else { - if streamID <= sc.maxPushPromiseID { - return http2stateClosed, nil - } - } - return http2stateIdle, nil -} + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config -// setConnState calls the net/http ConnState hook for this connection, if configured. -// Note that the net/http package does StateNew and StateClosed for us. -// There is currently no plan for StateHijacked or hijacking HTTP/2 connections. -func (sc *http2serverConn) setConnState(state http.ConnState) { - if sc.hs.ConnState != nil { - sc.hs.ConnState(sc.conn, state) - } -} + // ConnPool optionally specifies an alternate connection pool to use. + // If nil, the default is used. + ConnPool http2ClientConnPool -func (sc *http2serverConn) vlogf(format string, args ...interface{}) { - if http2VerboseLogs { - sc.logf(format, args...) - } -} + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool -func (sc *http2serverConn) logf(format string, args ...interface{}) { - if lg := sc.hs.ErrorLog; lg != nil { - lg.Printf(format, args...) - } else { - log.Printf(format, args...) - } -} + // AllowHTTP, if true, permits HTTP/2 requests using the insecure, + // plain-text "http" scheme. Note that this does not enable h2c support. + AllowHTTP bool -// errno returns v's underlying uintptr, else 0. -// -// TODO: remove this helper function once http2 can use build -// tags. See comment in isClosedConnError. -func http2errno(v error) uintptr { - if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr { - return uintptr(rv.Uint()) - } - return 0 -} + // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to + // send in the initial settings frame. It is how many bytes + // of response headers are allowed. Unlike the http2 spec, zero here + // means to use a default limit (currently 10MB). If you actually + // want to advertise an unlimited value to the peer, Transport + // interprets the highest possible value here (0xffffffff or 1<<32-1) + // to mean no limit. + MaxHeaderListSize uint32 -// isClosedConnError reports whether err is an error from use of a closed -// network connection. -func http2isClosedConnError(err error) bool { - if err == nil { - return false - } + // StrictMaxConcurrentStreams controls whether the server's + // SETTINGS_MAX_CONCURRENT_STREAMS should be respected + // globally. If false, new TCP connections are created to the + // server as needed to keep each under the per-connection + // SETTINGS_MAX_CONCURRENT_STREAMS limit. If true, the + // server's SETTINGS_MAX_CONCURRENT_STREAMS is interpreted as + // a global limit and callers of RoundTrip block when needed, + // waiting for their turn. + StrictMaxConcurrentStreams bool - // TODO: remove this string search and be more like the Windows - // case below. That might involve modifying the standard library - // to return better error types. - str := err.Error() - if strings.Contains(str, "use of closed network connection") { - return true - } + // ReadIdleTimeout is the timeout after which a health check using ping + // frame will be carried out if no frame is received on the connection. + // Note that a ping response will is considered a received frame, so if + // there is no other traffic on the connection, the health check will + // be performed every ReadIdleTimeout interval. + // If zero, no health check is performed. + ReadIdleTimeout time.Duration - // TODO(bradfitz): x/tools/cmd/bundle doesn't really support - // build tags, so I can't make an http2_windows.go file with - // Windows-specific stuff. Fix that and move this, once we - // have a way to bundle this into std's net/http somehow. - if runtime.GOOS == "windows" { - if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { - if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" { - const WSAECONNABORTED = 10053 - const WSAECONNRESET = 10054 - if n := http2errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED { - return true - } - } - } - } - return false -} + // PingTimeout is the timeout after which the connection will be closed + // if a response to Ping is not received. + // Defaults to 15s. + PingTimeout time.Duration -func (sc *http2serverConn) condlogf(err error, format string, args ...interface{}) { - if err == nil { - return - } - if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) || err == http2errPrefaceTimeout { - // Boring, expected errors. - sc.vlogf(format, args...) - } else { - sc.logf(format, args...) - } -} + // WriteByteTimeout is the timeout after which the connection will be + // closed no data can be written to it. The timeout begins when data is + // available to write, and is extended whenever any bytes are written. + WriteByteTimeout time.Duration -func (sc *http2serverConn) canonicalHeader(v string) string { - sc.serveG.check() - http2buildCommonHeaderMapsOnce() - cv, ok := http2commonCanonHeader[v] - if ok { - return cv - } - cv, ok = sc.canonHeader[v] - if ok { - return cv - } - if sc.canonHeader == nil { - sc.canonHeader = make(map[string]string) - } - cv = http.CanonicalHeaderKey(v) - // maxCachedCanonicalHeaders is an arbitrarily-chosen limit on the number of - // entries in the canonHeader cache. This should be larger than the number - // of unique, uncommon header keys likely to be sent by the peer, while not - // so high as to permit unreasonable memory usage if the peer sends an unbounded - // number of unique header keys. - const maxCachedCanonicalHeaders = 32 - if len(sc.canonHeader) < maxCachedCanonicalHeaders { - sc.canonHeader[v] = cv - } - return cv -} + // CountError, if non-nil, is called on HTTP/2 transport errors. + // It's intended to increment a metric for monitoring, such + // as an expvar or Prometheus metric. + // The errType consists of only ASCII word characters. + CountError func(errType string) -type http2readFrameResult struct { - f http2Frame // valid until readMore is called - err error + // t1, if non-nil, is the standard library Transport using + // this transport. Its settings are used (but not its + // RoundTrip method, etc). + t1 *Transport - // readMore should be called once the consumer no longer needs or - // retains f. After readMore, f is invalid and more frames can be - // read. - readMore func() + connPoolOnce sync.Once + connPoolOrDef http2ClientConnPool // non-nil version of ConnPool } -// readFrames is the loop that reads incoming frames. -// It takes care to only read one frame at a time, blocking until the -// consumer is done with the frame. -// It's run on its own goroutine. -func (sc *http2serverConn) readFrames() { - gate := make(http2gate) - gateDone := gate.Done - for { - f, err := sc.framer.ReadFrame() - select { - case sc.readFrameCh <- http2readFrameResult{f, err, gateDone}: - case <-sc.doneServing: - return - } - select { - case <-gate: - case <-sc.doneServing: - return - } - if http2terminalReadFrameError(err) { - return - } +func (t *http2Transport) maxHeaderListSize() uint32 { + if t.MaxHeaderListSize == 0 { + return 10 << 20 } + if t.MaxHeaderListSize == 0xffffffff { + return 0 + } + return t.MaxHeaderListSize } -// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. -type http2frameWriteResult struct { - _ http2incomparable - wr http2FrameWriteRequest // what was written (or attempted) - err error // result of the writeFrame call -} - -// writeFrameAsync runs in its own goroutine and writes a single frame -// and then reports when it's done. -// At most one goroutine can be running writeFrameAsync at a time per -// serverConn. -func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { - err := wr.write.writeFrame(sc) - sc.wroteFrameCh <- http2frameWriteResult{wr: wr, err: err} +func (t *http2Transport) disableCompression() bool { + return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) } -func (sc *http2serverConn) closeAllStreamsOnConnClose() { - sc.serveG.check() - for _, st := range sc.streams { - sc.closeStream(st, http2errClientDisconnected) +func (t *http2Transport) pingTimeout() time.Duration { + if t.PingTimeout == 0 { + return 15 * time.Second } -} + return t.PingTimeout -func (sc *http2serverConn) stopShutdownTimer() { - sc.serveG.check() - if t := sc.shutdownTimer; t != nil { - t.Stop() - } } -func (sc *http2serverConn) notePanic() { - // Note: this is for serverConn.serve panicking, not http.Handler code. - if http2testHookOnPanicMu != nil { - http2testHookOnPanicMu.Lock() - defer http2testHookOnPanicMu.Unlock() - } - if http2testHookOnPanic != nil { - if e := recover(); e != nil { - if http2testHookOnPanic(sc, e) { - panic(e) - } - } - } +// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. +// It returns an error if t1 has already been HTTP/2-enabled. +// +// Use ConfigureTransports instead to configure the HTTP/2 Transport. +func http2ConfigureTransport(t1 *Transport) error { + _, err := http2ConfigureTransports(t1) + return err } -func (sc *http2serverConn) serve() { - sc.serveG.check() - defer sc.notePanic() - defer sc.conn.Close() - defer sc.closeAllStreamsOnConnClose() - defer sc.stopShutdownTimer() - defer close(sc.doneServing) // unblocks handlers trying to send - - if http2VerboseLogs { - sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) +// ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2. +// It returns a new HTTP/2 Transport for further configuration. +// It returns an error if t1 has already been HTTP/2-enabled. +func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { + connPool := new(http2clientConnPool) + t2 := &http2Transport{ + ConnPool: http2noDialClientConnPool{connPool}, + t1: t1, } - - sc.writeFrame(http2FrameWriteRequest{ - write: http2writeSettings{ - {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, - {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, - {http2SettingMaxHeaderListSize, sc.maxHeaderListSize()}, - {http2SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, - }, - }) - sc.unackedSettings++ - - // Each connection starts with initialWindowSize inflow tokens. - // If a higher value is configured, we add more tokens. - if diff := sc.srv.initialConnRecvWindowSize() - http2initialWindowSize; diff > 0 { - sc.sendWindowUpdate(nil, int(diff)) + connPool.t = t2 + if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil { + return nil, err } - - if err := sc.readPreface(); err != nil { - sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err) - return + if t1.TLSClientConfig == nil { + t1.TLSClientConfig = new(tls.Config) } - // Now that we've got the preface, get us out of the - // "StateNew" state. We can't go directly to idle, though. - // Active means we read some data and anticipate a request. We'll - // do another Active when we get a HEADERS frame. - sc.setConnState(http.StateActive) - sc.setConnState(http.StateIdle) - - if sc.srv.IdleTimeout != 0 { - sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) - defer sc.idleTimer.Stop() + if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { + t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) } - - go sc.readFrames() // closed by defer sc.conn.Close above - - settingsTimer := time.AfterFunc(http2firstSettingsTimeout, sc.onSettingsTimer) - defer settingsTimer.Stop() - - loopNum := 0 - for { - loopNum++ - select { - case wr := <-sc.wantWriteFrameCh: - if se, ok := wr.write.(http2StreamError); ok { - sc.resetStream(se) - break - } - sc.writeFrame(wr) - case res := <-sc.wroteFrameCh: - sc.wroteFrame(res) - case res := <-sc.readFrameCh: - // Process any written frames before reading new frames from the client since a - // written frame could have triggered a new stream to be started. - if sc.writingFrameAsync { - select { - case wroteRes := <-sc.wroteFrameCh: - sc.wroteFrame(wroteRes) - default: - } - } - if !sc.processFrameFromReader(res) { - return - } - res.readMore() - if settingsTimer != nil { - settingsTimer.Stop() - settingsTimer = nil - } - case m := <-sc.bodyReadCh: - sc.noteBodyRead(m.st, m.n) - case msg := <-sc.serveMsgCh: - switch v := msg.(type) { - case func(int): - v(loopNum) // for testing - case *http2serverMessage: - switch v { - case http2settingsTimerMsg: - sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) - return - case http2idleTimerMsg: - sc.vlogf("connection is idle") - sc.goAway(http2ErrCodeNo) - case http2shutdownTimerMsg: - sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) - return - case http2gracefulShutdownMsg: - sc.startGracefulShutdownInternal() - default: - panic("unknown timer") - } - case *http2startPushRequest: - sc.startPush(v) - default: - panic(fmt.Sprintf("unexpected type %T", v)) - } - } - - // If the peer is causing us to generate a lot of control frames, - // but not reading them from us, assume they are trying to make us - // run out of memory. - if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() { - sc.vlogf("http2: too many control frames in send queue, closing connection") - return + if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { + t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") + } + upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper { + addr := http2authorityAddr("https", authority) + if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { + go c.Close() + return http2erringRoundTripper{err} + } else if !used { + // Turns out we don't need this c. + // For example, two goroutines made requests to the same host + // at the same time, both kicking off TCP dials. (since protocol + // was unknown) + go c.Close() } - - // Start the shutdown timer after sending a GOAWAY. When sending GOAWAY - // with no error code (graceful shutdown), don't start the timer until - // all open streams have been completed. - sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame - gracefulShutdownComplete := sc.goAwayCode == http2ErrCodeNo && sc.curOpenStreams() == 0 - if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != http2ErrCodeNo || gracefulShutdownComplete) { - sc.shutDownIn(http2goAwayTimeout) + return t2 + } + if m := t1.TLSNextProto; len(m) == 0 { + t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{ + "h2": upgradeFn, } + } else { + m["h2"] = upgradeFn } + return t2, nil } -func (sc *http2serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) { - select { - case <-sc.doneServing: - case <-sharedCh: - close(privateCh) +func (t *http2Transport) connPool() http2ClientConnPool { + t.connPoolOnce.Do(t.initConnPool) + return t.connPoolOrDef +} + +func (t *http2Transport) initConnPool() { + if t.ConnPool != nil { + t.connPoolOrDef = t.ConnPool + } else { + t.connPoolOrDef = &http2clientConnPool{t: t} } } -type http2serverMessage int +// ClientConn is the state of a single HTTP/2 client connection to an +// HTTP/2 server. +type http2ClientConn struct { + currentRequest *http.Request + t *http2Transport + tconn net.Conn // usually *tls.Conn, except specialized impls + tlsState *tls.ConnectionState // nil only for specialized impls + reused uint32 // whether conn is being reused; atomic + singleUse bool // whether being used for a single http.Request + getConnCalled bool // used by clientConnPool -// Message values sent to serveMsgCh. -var ( - http2settingsTimerMsg = new(http2serverMessage) - http2idleTimerMsg = new(http2serverMessage) - http2shutdownTimerMsg = new(http2serverMessage) - http2gracefulShutdownMsg = new(http2serverMessage) -) + // readLoop goroutine fields: + readerDone chan struct{} // closed on error + readerErr error // set before readerDone is closed -func (sc *http2serverConn) onSettingsTimer() { sc.sendServeMsg(http2settingsTimerMsg) } + idleTimeout time.Duration // or 0 for never + idleTimer *time.Timer -func (sc *http2serverConn) onIdleTimer() { sc.sendServeMsg(http2idleTimerMsg) } + mu sync.Mutex // guards following + cond *sync.Cond // hold mu; broadcast on flow/closed changes + flow http2flow // our conn-level flow control quota (cs.flow is per stream) + inflow http2flow // peer's conn-level flow control + doNotReuse bool // whether conn is marked to not be reused for any future requests + closing bool + closed bool + seenSettings bool // true if we've seen a settings frame, false otherwise + wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back + goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received + goAwayDebug string // goAway frame's debug data, retained as a string + streams map[uint32]*http2clientStream // client-initiated + streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip + nextStreamID uint32 + pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams + pings map[[8]byte]chan struct{} // in flight ping data to notification channel + br *bufio.Reader + lastActive time.Time + lastIdle time.Time // time last idle + // Settings from peer: (also guarded by wmu) + maxFrameSize uint32 + maxConcurrentStreams uint32 + peerMaxHeaderListSize uint64 + initialWindowSize uint32 -func (sc *http2serverConn) onShutdownTimer() { sc.sendServeMsg(http2shutdownTimerMsg) } + // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests. + // Write to reqHeaderMu to lock it, read from it to unlock. + // Lock reqmu BEFORE mu or wmu. + reqHeaderMu chan struct{} -func (sc *http2serverConn) sendServeMsg(msg interface{}) { - sc.serveG.checkNotOn() // NOT - select { - case sc.serveMsgCh <- msg: - case <-sc.doneServing: - } + // wmu is held while writing. + // Acquire BEFORE mu when holding both, to avoid blocking mu on network writes. + // Only acquire both at the same time when changing peer settings. + wmu sync.Mutex + bw *bufio.Writer + fr *http2Framer + werr error // first write error that has occurred + hbuf bytes.Buffer // HPACK encoder writes into this + henc *hpack.Encoder } -var http2errPrefaceTimeout = errors.New("timeout waiting for client preface") - -// readPreface reads the ClientPreface greeting from the peer or -// returns errPrefaceTimeout on timeout, or an error if the greeting -// is invalid. -func (sc *http2serverConn) readPreface() error { - errc := make(chan error, 1) - go func() { - // Read the client preface - buf := make([]byte, len(http2ClientPreface)) - if _, err := io.ReadFull(sc.conn, buf); err != nil { - errc <- err - } else if !bytes.Equal(buf, http2clientPreface) { - errc <- fmt.Errorf("bogus greeting %q", buf) - } else { - errc <- nil - } - }() - timer := time.NewTimer(http2prefaceTimeout) // TODO: configurable on *Server? - defer timer.Stop() - select { - case <-timer.C: - return http2errPrefaceTimeout - case err := <-errc: - if err == nil { - if http2VerboseLogs { - sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr()) - } - } - return err - } -} +// clientStream is the state for a single HTTP/2 stream. One of these +// is created for each Transport.RoundTrip call. +type http2clientStream struct { + cc *http2ClientConn -var http2errChanPool = sync.Pool{ - New: func() interface{} { return make(chan error, 1) }, -} + // Fields of Request that we may access even after the response body is closed. + ctx context.Context + reqCancel <-chan struct{} -var http2writeDataPool = sync.Pool{ - New: func() interface{} { return new(http2writeData) }, -} + trace *httptrace.ClientTrace // or nil + ID uint32 + bufPipe http2pipe // buffered pipe with the flow-controlled response payload + requestedGzip bool + isHead bool -// writeDataFromHandler writes DATA response frames from a handler on -// the given stream. -func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte, endStream bool) error { - ch := http2errChanPool.Get().(chan error) - writeArg := http2writeDataPool.Get().(*http2writeData) - *writeArg = http2writeData{stream.id, data, endStream} - err := sc.writeFrameFromHandler(http2FrameWriteRequest{ - write: writeArg, - stream: stream, - done: ch, - }) - if err != nil { - return err - } - var frameWriteDone bool // the frame write is done (successfully or not) - select { - case err = <-ch: - frameWriteDone = true - case <-sc.doneServing: - return http2errClientDisconnected - case <-stream.cw: - // If both ch and stream.cw were ready (as might - // happen on the final Write after an http.Handler - // ends), prefer the write result. Otherwise this - // might just be us successfully closing the stream. - // The writeFrameAsync and serve goroutines guarantee - // that the ch send will happen before the stream.cw - // close. - select { - case err = <-ch: - frameWriteDone = true - default: - return http2errStreamClosed - } - } - http2errChanPool.Put(ch) - if frameWriteDone { - http2writeDataPool.Put(writeArg) - } - return err -} + abortOnce sync.Once + abort chan struct{} // closed to signal stream should end immediately + abortErr error // set if abort is closed -// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts -// if the connection has gone away. -// -// This must not be run from the serve goroutine itself, else it might -// deadlock writing to sc.wantWriteFrameCh (which is only mildly -// buffered and is read by serve itself). If you're on the serve -// goroutine, call writeFrame instead. -func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error { - sc.serveG.checkNotOn() // NOT - select { - case sc.wantWriteFrameCh <- wr: - return nil - case <-sc.doneServing: - // Serve loop is gone. - // Client has closed their connection to the server. - return http2errClientDisconnected - } -} + peerClosed chan struct{} // closed when the peer sends an END_STREAM flag + donec chan struct{} // closed after the stream is in the closed state + on100 chan struct{} // buffered; written to if a 100 is received -// writeFrame schedules a frame to write and sends it if there's nothing -// already being written. -// -// There is no pushback here (the serve goroutine never blocks). It's -// the http.Handlers that block, waiting for their previous frames to -// make it onto the wire -// -// If you're not on the serve goroutine, use writeFrameFromHandler instead. -func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { - sc.serveG.check() + respHeaderRecv chan struct{} // closed when headers are received + res *http.Response // set if respHeaderRecv is closed - // If true, wr will not be written and wr.done will not be signaled. - var ignoreWrite bool + flow http2flow // guarded by cc.mu + inflow http2flow // guarded by cc.mu + bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read + readErr error // sticky read error; owned by transportResponseBody.Read - // We are not allowed to write frames on closed streams. RFC 7540 Section - // 5.1.1 says: "An endpoint MUST NOT send frames other than PRIORITY on - // a closed stream." Our server never sends PRIORITY, so that exception - // does not apply. - // - // The serverConn might close an open stream while the stream's handler - // is still running. For example, the server might close a stream when it - // receives bad data from the client. If this happens, the handler might - // attempt to write a frame after the stream has been closed (since the - // handler hasn't yet been notified of the close). In this case, we simply - // ignore the frame. The handler will notice that the stream is closed when - // it waits for the frame to be written. - // - // As an exception to this rule, we allow sending RST_STREAM after close. - // This allows us to immediately reject new streams without tracking any - // state for those streams (except for the queued RST_STREAM frame). This - // may result in duplicate RST_STREAMs in some cases, but the client should - // ignore those. - if wr.StreamID() != 0 { - _, isReset := wr.write.(http2StreamError) - if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset { - ignoreWrite = true - } - } - - // Don't send a 100-continue response if we've already sent headers. - // See golang.org/issue/14030. - switch wr.write.(type) { - case *http2writeResHeaders: - wr.stream.wroteHeaders = true - case http2write100ContinueHeadersFrame: - if wr.stream.wroteHeaders { - // We do not need to notify wr.done because this frame is - // never written with wr.done != nil. - if wr.done != nil { - panic("wr.done != nil for write100ContinueHeadersFrame") - } - ignoreWrite = true - } - } + reqBody io.ReadCloser + reqBodyContentLength int64 // -1 means unknown + reqBodyClosed bool // body has been closed; guarded by cc.mu - if !ignoreWrite { - if wr.isControl() { - sc.queuedControlFrames++ - // For extra safety, detect wraparounds, which should not happen, - // and pull the plug. - if sc.queuedControlFrames < 0 { - sc.conn.Close() - } - } - sc.writeSched.Push(wr) - } - sc.scheduleFrameWrite() + // owned by writeRequest: + sentEndStream bool // sent an END_STREAM flag to the peer + sentHeaders bool + + // owned by clientConnReadLoop: + firstByte bool // got the first response byte + pastHeaders bool // got first MetaHeadersFrame (actual headers) + pastTrailers bool // got optional second MetaHeadersFrame (trailers) + num1xx uint8 // number of 1xx responses seen + readClosed bool // peer sent an END_STREAM flag + readAborted bool // read loop reset the stream + + trailer http.Header // accumulated trailers + resTrailer *http.Header // client's Response.Trailer } -// startFrameWrite starts a goroutine to write wr (in a separate -// goroutine since that might block on the network), and updates the -// serve goroutine's state about the world, updated from info in wr. -func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { - sc.serveG.check() - if sc.writingFrame { - panic("internal error: can only be writing one frame at a time") +var http2got1xxFuncForTests func(int, textproto.MIMEHeader) error + +// get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func, +// if any. It returns nil if not set or if the Go version is too old. +func (cs *http2clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error { + if fn := http2got1xxFuncForTests; fn != nil { + return fn } + return http2traceGot1xxResponseFunc(cs.trace) +} - st := wr.stream - if st != nil { - switch st.state { - case http2stateHalfClosedLocal: - switch wr.write.(type) { - case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate: - // RFC 7540 Section 5.1 allows sending RST_STREAM, PRIORITY, and WINDOW_UPDATE - // in this state. (We never send PRIORITY from the server, so that is not checked.) - default: - panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr)) - } - case http2stateClosed: - panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr)) - } +func (cs *http2clientStream) abortStream(err error) { + cs.cc.mu.Lock() + defer cs.cc.mu.Unlock() + cs.abortStreamLocked(err) +} + +func (cs *http2clientStream) abortStreamLocked(err error) { + cs.abortOnce.Do(func() { + cs.abortErr = err + close(cs.abort) + }) + if cs.reqBody != nil && !cs.reqBodyClosed { + cs.reqBody.Close() + cs.reqBodyClosed = true } - if wpp, ok := wr.write.(*http2writePushPromise); ok { - var err error - wpp.promisedID, err = wpp.allocatePromisedID() - if err != nil { - sc.writingFrameAsync = false - wr.replyToWriter(err) - return - } + // TODO(dneil): Clean up tests where cs.cc.cond is nil. + if cs.cc.cond != nil { + // Wake up writeRequestBody if it is waiting on flow control. + cs.cc.cond.Broadcast() } +} - sc.writingFrame = true - sc.needsFrameFlush = true - if wr.write.staysWithinBuffer(sc.bw.Available()) { - sc.writingFrameAsync = false - err := wr.write.writeFrame(sc) - sc.wroteFrame(http2frameWriteResult{wr: wr, err: err}) - } else { - sc.writingFrameAsync = true - go sc.writeFrameAsync(wr) - } -} - -// errHandlerPanicked is the error given to any callers blocked in a read from -// Request.Body when the main goroutine panics. Since most handlers read in the -// main ServeHTTP goroutine, this will show up rarely. -var http2errHandlerPanicked = errors.New("http2: handler panicked") - -// wroteFrame is called on the serve goroutine with the result of -// whatever happened on writeFrameAsync. -func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { - sc.serveG.check() - if !sc.writingFrame { - panic("internal error: expected to be already writing a frame") - } - sc.writingFrame = false - sc.writingFrameAsync = false - - wr := res.wr - - if http2writeEndsStream(wr.write) { - st := wr.stream - if st == nil { - panic("internal error: expecting non-nil stream") - } - switch st.state { - case http2stateOpen: - // Here we would go to stateHalfClosedLocal in - // theory, but since our handler is done and - // the net/http package provides no mechanism - // for closing a ResponseWriter while still - // reading data (see possible TODO at top of - // this file), we go into closed state here - // anyway, after telling the peer we're - // hanging up on them. We'll transition to - // stateClosed after the RST_STREAM frame is - // written. - st.state = http2stateHalfClosedLocal - // Section 8.1: a server MAY request that the client abort - // transmission of a request without error by sending a - // RST_STREAM with an error code of NO_ERROR after sending - // a complete response. - sc.resetStream(http2streamError(st.id, http2ErrCodeNo)) - case http2stateHalfClosedRemote: - sc.closeStream(st, http2errHandlerComplete) - } - } else { - switch v := wr.write.(type) { - case http2StreamError: - // st may be unknown if the RST_STREAM was generated to reject bad input. - if st, ok := sc.streams[v.StreamID]; ok { - sc.closeStream(st, v) - } - case http2handlerPanicRST: - sc.closeStream(wr.stream, http2errHandlerPanicked) - } +func (cs *http2clientStream) abortRequestBodyWrite() { + cc := cs.cc + cc.mu.Lock() + defer cc.mu.Unlock() + if cs.reqBody != nil && !cs.reqBodyClosed { + cs.reqBody.Close() + cs.reqBodyClosed = true + cc.cond.Broadcast() } +} - // Reply (if requested) to unblock the ServeHTTP goroutine. - wr.replyToWriter(res.err) - - sc.scheduleFrameWrite() +type http2stickyErrWriter struct { + conn net.Conn + timeout time.Duration + err *error } -// scheduleFrameWrite tickles the frame writing scheduler. -// -// If a frame is already being written, nothing happens. This will be called again -// when the frame is done being written. -// -// If a frame isn't being written and we need to send one, the best frame -// to send is selected by writeSched. -// -// If a frame isn't being written and there's nothing else to send, we -// flush the write buffer. -func (sc *http2serverConn) scheduleFrameWrite() { - sc.serveG.check() - if sc.writingFrame || sc.inFrameScheduleLoop { - return +func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) { + if *sew.err != nil { + return 0, *sew.err } - sc.inFrameScheduleLoop = true - for !sc.writingFrameAsync { - if sc.needToSendGoAway { - sc.needToSendGoAway = false - sc.startFrameWrite(http2FrameWriteRequest{ - write: &http2writeGoAway{ - maxStreamID: sc.maxClientStreamID, - code: sc.goAwayCode, - }, - }) - continue + for { + if sew.timeout != 0 { + sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout)) } - if sc.needToSendSettingsAck { - sc.needToSendSettingsAck = false - sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}}) + nn, err := sew.conn.Write(p[n:]) + n += nn + if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) { + // Keep extending the deadline so long as we're making progress. continue } - if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo { - if wr, ok := sc.writeSched.Pop(); ok { - if wr.isControl() { - sc.queuedControlFrames-- - } - sc.startFrameWrite(wr) - continue - } - } - if sc.needsFrameFlush { - sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}}) - sc.needsFrameFlush = false // after startFrameWrite, since it sets this true - continue + if sew.timeout != 0 { + sew.conn.SetWriteDeadline(time.Time{}) } - break + *sew.err = err + return n, err } - sc.inFrameScheduleLoop = false } -// startGracefulShutdown gracefully shuts down a connection. This -// sends GOAWAY with ErrCodeNo to tell the client we're gracefully -// shutting down. The connection isn't closed until all current -// streams are done. -// -// startGracefulShutdown returns immediately; it does not wait until -// the connection has shut down. -func (sc *http2serverConn) startGracefulShutdown() { - sc.serveG.checkNotOn() // NOT - sc.shutdownOnce.Do(func() { sc.sendServeMsg(http2gracefulShutdownMsg) }) -} +// noCachedConnError is the concrete type of ErrNoCachedConn, which +// needs to be detected by net/http regardless of whether it's its +// bundled version (in h2_bundle.go with a rewritten type name) or +// from a user's x/net/http2. As such, as it has a unique method name +// (IsHTTP2NoCachedConnError) that net/http sniffs for via func +// isNoCachedConnError. +type http2noCachedConnError struct{} -// After sending GOAWAY with an error code (non-graceful shutdown), the -// connection will close after goAwayTimeout. -// -// If we close the connection immediately after sending GOAWAY, there may -// be unsent data in our kernel receive buffer, which will cause the kernel -// to send a TCP RST on close() instead of a FIN. This RST will abort the -// connection immediately, whether or not the client had received the GOAWAY. -// -// Ideally we should delay for at least 1 RTT + epsilon so the client has -// a chance to read the GOAWAY and stop sending messages. Measuring RTT -// is hard, so we approximate with 1 second. See golang.org/issue/18701. -// -// This is a var so it can be shorter in tests, where all requests uses the -// loopback interface making the expected RTT very small. -// -// TODO: configurable? -var http2goAwayTimeout = 1 * time.Second +func (http2noCachedConnError) IsHTTP2NoCachedConnError() {} -func (sc *http2serverConn) startGracefulShutdownInternal() { - sc.goAway(http2ErrCodeNo) -} +func (http2noCachedConnError) Error() string { return "http2: no cached connection was available" } -func (sc *http2serverConn) goAway(code http2ErrCode) { - sc.serveG.check() - if sc.inGoAway { - return - } - sc.inGoAway = true - sc.needToSendGoAway = true - sc.goAwayCode = code - sc.scheduleFrameWrite() +// isNoCachedConnError reports whether err is of type noCachedConnError +// or its equivalent renamed type in net/http2's h2_bundle.go. Both types +// may coexist in the same running program. +func http2isNoCachedConnError(err error) bool { + _, ok := err.(interface{ IsHTTP2NoCachedConnError() }) + return ok } -func (sc *http2serverConn) shutDownIn(d time.Duration) { - sc.serveG.check() - sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) -} +var http2ErrNoCachedConn error = http2noCachedConnError{} -func (sc *http2serverConn) resetStream(se http2StreamError) { - sc.serveG.check() - sc.writeFrame(http2FrameWriteRequest{write: se}) - if st, ok := sc.streams[se.StreamID]; ok { - st.resetQueued = true - } +// RoundTripOpt are options for the Transport.RoundTripOpt method. +type http2RoundTripOpt struct { + // OnlyCachedConn controls whether RoundTripOpt may + // create a new TCP connection. If set true and + // no cached connection is available, RoundTripOpt + // will return ErrNoCachedConn. + OnlyCachedConn bool } -// processFrameFromReader processes the serve loop's read from readFrameCh from the -// frame-reading goroutine. -// processFrameFromReader returns whether the connection should be kept open. -func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool { - sc.serveG.check() - err := res.err - if err != nil { - if err == http2ErrFrameTooLarge { - sc.goAway(http2ErrCodeFrameSize) - return true // goAway will close the loop - } - clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) - if clientGone { - // TODO: could we also get into this state if - // the peer does a half close - // (e.g. CloseWrite) because they're done - // sending frames but they're still wanting - // our open replies? Investigate. - // TODO: add CloseWrite to crypto/tls.Conn first - // so we have a way to test this? I suppose - // just for testing we could have a non-TLS mode. - return false - } - } else { - f := res.f - if http2VerboseLogs { - sc.vlogf("http2: server read frame %v", http2summarizeFrame(f)) - } - err = sc.processFrame(f) - if err == nil { - return true - } - } - - switch ev := err.(type) { - case http2StreamError: - sc.resetStream(ev) - return true - case http2goAwayFlowError: - sc.goAway(http2ErrCodeFlowControl) - return true - case http2ConnectionError: - sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev) - sc.goAway(http2ErrCode(ev)) - return true // goAway will handle shutdown - default: - if res.err != nil { - sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err) - } else { - sc.logf("http2: server closing client connection: %v", err) - } - return false - } +func (t *http2Transport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripOpt(req, http2RoundTripOpt{}) } -func (sc *http2serverConn) processFrame(f http2Frame) error { - sc.serveG.check() - - // First frame received must be SETTINGS. - if !sc.sawFirstSettings { - if _, ok := f.(*http2SettingsFrame); !ok { - return sc.countError("first_settings", http2ConnectionError(http2ErrCodeProtocol)) +// authorityAddr returns a given authority (a host/IP, or host:port / ip:port) +// and returns a host:port. The port 443 is added if needed. +func http2authorityAddr(scheme string, authority string) (addr string) { + host, port, err := net.SplitHostPort(authority) + if err != nil { // authority didn't have a port + port = "443" + if scheme == "http" { + port = "80" } - sc.sawFirstSettings = true - } - - switch f := f.(type) { - case *http2SettingsFrame: - return sc.processSettings(f) - case *http2MetaHeadersFrame: - return sc.processHeaders(f) - case *http2WindowUpdateFrame: - return sc.processWindowUpdate(f) - case *http2PingFrame: - return sc.processPing(f) - case *http2DataFrame: - return sc.processData(f) - case *http2RSTStreamFrame: - return sc.processResetStream(f) - case *http2PriorityFrame: - return sc.processPriority(f) - case *http2GoAwayFrame: - return sc.processGoAway(f) - case *http2PushPromiseFrame: - // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE - // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. - return sc.countError("push_promise", http2ConnectionError(http2ErrCodeProtocol)) - default: - sc.vlogf("http2: server ignoring frame: %v", f.Header()) - return nil - } -} - -func (sc *http2serverConn) processPing(f *http2PingFrame) error { - sc.serveG.check() - if f.IsAck() { - // 6.7 PING: " An endpoint MUST NOT respond to PING frames - // containing this flag." - return nil + host = authority } - if f.StreamID != 0 { - // "PING frames are not associated with any individual - // stream. If a PING frame is received with a stream - // identifier field value other than 0x0, the recipient MUST - // respond with a connection error (Section 5.4.1) of type - // PROTOCOL_ERROR." - return sc.countError("ping_on_stream", http2ConnectionError(http2ErrCodeProtocol)) + if a, err := idna.ToASCII(host); err == nil { + host = a } - if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { - return nil + // IPv6 address literal, without a port: + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port } - sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}}) - return nil + return net.JoinHostPort(host, port) } -func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error { - sc.serveG.check() - switch { - case f.StreamID != 0: // stream-level flow control - state, st := sc.state(f.StreamID) - if state == http2stateIdle { - // Section 5.1: "Receiving any frame other than HEADERS - // or PRIORITY on a stream in this state MUST be - // treated as a connection error (Section 5.4.1) of - // type PROTOCOL_ERROR." - return sc.countError("stream_idle", http2ConnectionError(http2ErrCodeProtocol)) - } - if st == nil { - // "WINDOW_UPDATE can be sent by a peer that has sent a - // frame bearing the END_STREAM flag. This means that a - // receiver could receive a WINDOW_UPDATE frame on a "half - // closed (remote)" or "closed" stream. A receiver MUST - // NOT treat this as an error, see Section 5.1." - return nil - } - if !st.flow.add(int32(f.Increment)) { - return sc.countError("bad_flow", http2streamError(f.StreamID, http2ErrCodeFlowControl)) +// RoundTripOpt is like RoundTrip, but takes options. +func (t *http2Transport) RoundTripOpt(req *http.Request, opt http2RoundTripOpt) (*http.Response, error) { + if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { + return nil, errors.New("http2: unsupported scheme") + } + + addr := http2authorityAddr(req.URL.Scheme, req.URL.Host) + for retry := 0; ; retry++ { + cc, err := t.connPool().GetClientConn(req, addr) + if err != nil { + t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) + return nil, err + } + reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1) + http2traceGotConn(req, cc, reused) + res, err := cc.RoundTrip(req) + if err != nil && retry <= 6 { + if req, err = http2shouldRetryRequest(req, err); err == nil { + // After the first retry, do exponential backoff with 10% jitter. + if retry == 0 { + continue + } + backoff := float64(uint(1) << (uint(retry) - 1)) + backoff += backoff * (0.1 * mathrand.Float64()) + select { + case <-time.After(time.Second * time.Duration(backoff)): + continue + case <-req.Context().Done(): + err = req.Context().Err() + } + } } - default: // connection-level flow control - if !sc.flow.add(int32(f.Increment)) { - return http2goAwayFlowError{} + if err != nil { + t.vlogf("RoundTrip failure: %v", err) + return nil, err } + return res, nil } - sc.scheduleFrameWrite() - return nil } -func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { - sc.serveG.check() - - state, st := sc.state(f.StreamID) - if state == http2stateIdle { - // 6.4 "RST_STREAM frames MUST NOT be sent for a - // stream in the "idle" state. If a RST_STREAM frame - // identifying an idle stream is received, the - // recipient MUST treat this as a connection error - // (Section 5.4.1) of type PROTOCOL_ERROR. - return sc.countError("reset_idle_stream", http2ConnectionError(http2ErrCodeProtocol)) - } - if st != nil { - st.cancelCtx() - sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode)) +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle. +// It does not interrupt any connections currently in use. +func (t *http2Transport) CloseIdleConnections() { + if cp, ok := t.connPool().(http2clientConnPoolIdleCloser); ok { + cp.closeIdleConnections() } - return nil } -func (sc *http2serverConn) closeStream(st *http2stream, err error) { - sc.serveG.check() - if st.state == http2stateIdle || st.state == http2stateClosed { - panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) - } - st.state = http2stateClosed - if st.writeDeadline != nil { - st.writeDeadline.Stop() +var ( + http2errClientConnClosed = errors.New("http2: client conn is closed") + http2errClientConnUnusable = errors.New("http2: client conn not usable") + http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") +) + +// shouldRetryRequest is called by RoundTrip when a request fails to get +// response headers. It is always called with a non-nil error. +// It returns either a request to retry (either the same request, or a +// modified clone), or an error if the request can't be replayed. +func http2shouldRetryRequest(req *http.Request, err error) (*http.Request, error) { + if !http2canRetryError(err) { + return nil, err } - if st.isPushed() { - sc.curPushedStreams-- - } else { - sc.curClientStreams-- + // If the Body is nil (or http.NoBody), it's safe to reuse + // this request and its Body. + if req.Body == nil || req.Body == http.NoBody { + return req, nil } - delete(sc.streams, st.id) - if len(sc.streams) == 0 { - sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { - sc.idleTimer.Reset(sc.srv.IdleTimeout) - } - if http2h1ServerKeepAlivesDisabled(sc.hs) { - sc.startGracefulShutdownInternal() + + // If the request body can be reset back to its original + // state via the optional req.GetBody, do that. + if req.GetBody != nil { + body, err := req.GetBody() + if err != nil { + return nil, err } + newReq := *req + newReq.Body = body + return &newReq, nil } - if p := st.body; p != nil { - // Return any buffered unread bytes worth of conn-level flow control. - // See golang.org/issue/16481 - sc.sendWindowUpdate(nil, p.Len()) - p.CloseWithError(err) + // The Request.Body can't reset back to the beginning, but we + // don't seem to have started to read from it yet, so reuse + // the request directly. + if err == http2errClientConnUnusable { + return req, nil } - st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc - sc.writeSched.CloseStream(st.id) -} -func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { - sc.serveG.check() - if f.IsAck() { - sc.unackedSettings-- - if sc.unackedSettings < 0 { - // Why is the peer ACKing settings we never sent? - // The spec doesn't mention this case, but - // hang up on them anyway. - return sc.countError("ack_mystery", http2ConnectionError(http2ErrCodeProtocol)) - } - return nil - } - if f.NumSettings() > 100 || f.HasDuplicates() { - // This isn't actually in the spec, but hang up on - // suspiciously large settings frames or those with - // duplicate entries. - return sc.countError("settings_big_or_dups", http2ConnectionError(http2ErrCodeProtocol)) - } - if err := f.ForeachSetting(sc.processSetting); err != nil { - return err - } - // TODO: judging by RFC 7540, Section 6.5.3 each SETTINGS frame should be - // acknowledged individually, even if multiple are received before the ACK. - sc.needToSendSettingsAck = true - sc.scheduleFrameWrite() - return nil + return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err) } -func (sc *http2serverConn) processSetting(s http2Setting) error { - sc.serveG.check() - if err := s.Valid(); err != nil { - return err - } - if http2VerboseLogs { - sc.vlogf("http2: server processing setting %v", s) +func http2canRetryError(err error) bool { + if err == http2errClientConnUnusable || err == http2errClientConnGotGoAway { + return true } - switch s.ID { - case http2SettingHeaderTableSize: - sc.headerTableSize = s.Val - sc.hpackEncoder.SetMaxDynamicTableSize(s.Val) - case http2SettingEnablePush: - sc.pushEnabled = s.Val != 0 - case http2SettingMaxConcurrentStreams: - sc.clientMaxStreams = s.Val - case http2SettingInitialWindowSize: - return sc.processSettingInitialWindowSize(s.Val) - case http2SettingMaxFrameSize: - sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31 - case http2SettingMaxHeaderListSize: - sc.peerMaxHeaderListSize = s.Val - default: - // Unknown setting: "An endpoint that receives a SETTINGS - // frame with any unknown or unsupported identifier MUST - // ignore that setting." - if http2VerboseLogs { - sc.vlogf("http2: server ignoring unknown setting %v", s) + if se, ok := err.(http2StreamError); ok { + if se.Code == http2ErrCodeProtocol && se.Cause == http2errFromPeer { + // See golang/go#47635, golang/go#42777 + return true } + return se.Code == http2ErrCodeRefusedStream } - return nil + return false } -func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { - sc.serveG.check() - // Note: val already validated to be within range by - // processSetting's Valid call. - - // "A SETTINGS frame can alter the initial flow control window - // size for all current streams. When the value of - // SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST - // adjust the size of all stream flow control windows that it - // maintains by the difference between the new value and the - // old value." - old := sc.initialStreamSendWindowSize - sc.initialStreamSendWindowSize = int32(val) - growth := int32(val) - old // may be negative - for _, st := range sc.streams { - if !st.flow.add(growth) { - // 6.9.2 Initial Flow Control Window Size - // "An endpoint MUST treat a change to - // SETTINGS_INITIAL_WINDOW_SIZE that causes any flow - // control window to exceed the maximum size as a - // connection error (Section 5.4.1) of type - // FLOW_CONTROL_ERROR." - return sc.countError("setting_win_size", http2ConnectionError(http2ErrCodeFlowControl)) - } +func (t *http2Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*http2ClientConn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err } - return nil + tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host)) + if err != nil { + return nil, err + } + return t.newClientConn(tconn, singleUse) } -func (sc *http2serverConn) processData(f *http2DataFrame) error { - sc.serveG.check() - id := f.Header().StreamID - if sc.inGoAway && (sc.goAwayCode != http2ErrCodeNo || id > sc.maxClientStreamID) { - // Discard all DATA frames if the GOAWAY is due to an - // error, or: - // - // Section 6.8: After sending a GOAWAY frame, the sender - // can discard frames for streams initiated by the - // receiver with identifiers higher than the identified - // last stream. - return nil +func (t *http2Transport) newTLSConfig(host string) *tls.Config { + cfg := new(tls.Config) + if t.TLSClientConfig != nil { + *cfg = *t.TLSClientConfig.Clone() } - - data := f.Data() - state, st := sc.state(id) - if id == 0 || state == http2stateIdle { - // Section 6.1: "DATA frames MUST be associated with a - // stream. If a DATA frame is received whose stream - // identifier field is 0x0, the recipient MUST respond - // with a connection error (Section 5.4.1) of type - // PROTOCOL_ERROR." - // - // Section 5.1: "Receiving any frame other than HEADERS - // or PRIORITY on a stream in this state MUST be - // treated as a connection error (Section 5.4.1) of - // type PROTOCOL_ERROR." - return sc.countError("data_on_idle", http2ConnectionError(http2ErrCodeProtocol)) - } - - // "If a DATA frame is received whose stream is not in "open" - // or "half closed (local)" state, the recipient MUST respond - // with a stream error (Section 5.4.2) of type STREAM_CLOSED." - if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued { - // This includes sending a RST_STREAM if the stream is - // in stateHalfClosedLocal (which currently means that - // the http.Handler returned, so it's done reading & - // done writing). Try to stop the client from sending - // more DATA. - - // But still enforce their connection-level flow control, - // and return any flow control bytes since we're not going - // to consume them. - if sc.inflow.available() < int32(f.Length) { - return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl)) - } - // Deduct the flow control from inflow, since we're - // going to immediately add it back in - // sendWindowUpdate, which also schedules sending the - // frames. - sc.inflow.take(int32(f.Length)) - sc.sendWindowUpdate(nil, int(f.Length)) // conn-level - - if st != nil && st.resetQueued { - // Already have a stream error in flight. Don't send another. - return nil - } - return sc.countError("closed", http2streamError(id, http2ErrCodeStreamClosed)) + if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { + cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) } - if st.body == nil { - panic("internal error: should have a body in this state") + if cfg.ServerName == "" { + cfg.ServerName = host } + return cfg +} - // Sender sending more than they'd declared? - if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { - st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) - // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the - // value of a content-length header field does not equal the sum of the - // DATA frame payload lengths that form the body. - return sc.countError("send_too_much", http2streamError(id, http2ErrCodeProtocol)) +func (t *http2Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) { + if t.DialTLS != nil { + return t.DialTLS } - if f.Length > 0 { - // Check whether the client has flow control quota. - if st.inflow.available() < int32(f.Length) { - return sc.countError("flow_on_data_length", http2streamError(id, http2ErrCodeFlowControl)) + return func(network, addr string, cfg *tls.Config) (net.Conn, error) { + tlsCn, err := t.dialTLSWithContext(ctx, network, addr, cfg) + if err != nil { + return nil, err } - st.inflow.take(int32(f.Length)) - - if len(data) > 0 { - wrote, err := st.body.Write(data) - if err != nil { - sc.sendWindowUpdate(nil, int(f.Length)-wrote) - return sc.countError("body_write_err", http2streamError(id, http2ErrCodeStreamClosed)) - } - if wrote != len(data) { - panic("internal error: bad Writer") - } - st.bodyBytes += int64(len(data)) + state := tlsCn.ConnectionState() + if p := state.NegotiatedProtocol; p != http2NextProtoTLS { + return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS) } - - // Return any padded flow control now, since we won't - // refund it later on body reads. - if pad := int32(f.Length) - int32(len(data)); pad > 0 { - sc.sendWindowUpdate32(nil, pad) - sc.sendWindowUpdate32(st, pad) + if !state.NegotiatedProtocolIsMutual { + return nil, errors.New("http2: could not negotiate protocol mutually") } + return tlsCn, nil } - if f.StreamEnded() { - st.endStream() - } - return nil -} - -func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error { - sc.serveG.check() - if f.ErrCode != http2ErrCodeNo { - sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f) - } else { - sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) - } - sc.startGracefulShutdownInternal() - // http://tools.ietf.org/html/rfc7540#section-6.8 - // We should not create any new streams, which means we should disable push. - sc.pushEnabled = false - return nil -} - -// isPushed reports whether the stream is server-initiated. -func (st *http2stream) isPushed() bool { - return st.id%2 == 0 } -// endStream closes a Request.Body's pipe. It is called when a DATA -// frame says a request body is over (or after trailers). -func (st *http2stream) endStream() { - sc := st.sc - sc.serveG.check() - - if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes { - st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes", - st.declBodyBytes, st.bodyBytes)) - } else { - st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest) - st.body.CloseWithError(io.EOF) - } - st.state = http2stateHalfClosedRemote +// disableKeepAlives reports whether connections should be closed as +// soon as possible after handling the first request. +func (t *http2Transport) disableKeepAlives() bool { + return t.t1 != nil && t.t1.DisableKeepAlives } -// copyTrailersToHandlerRequest is run in the Handler's goroutine in -// its Request.Body.Read just before it gets io.EOF. -func (st *http2stream) copyTrailersToHandlerRequest() { - for k, vv := range st.trailer { - if _, ok := st.reqTrailer[k]; ok { - // Only copy it over it was pre-declared. - st.reqTrailer[k] = vv - } +func (t *http2Transport) expectContinueTimeout() time.Duration { + if t.t1 == nil { + return 0 } + return t.t1.ExpectContinueTimeout } -// onWriteTimeout is run on its own goroutine (from time.AfterFunc) -// when the stream's WriteTimeout has fired. -func (st *http2stream) onWriteTimeout() { - st.sc.writeFrameFromHandler(http2FrameWriteRequest{write: http2streamError(st.id, http2ErrCodeInternal)}) +func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { + return t.newClientConn(c, t.disableKeepAlives()) } -func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { - sc.serveG.check() - id := f.StreamID - if sc.inGoAway { - // Ignore. - return nil +func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) { + cc := &http2ClientConn{ + t: t, + tconn: c, + readerDone: make(chan struct{}), + nextStreamID: 1, + maxFrameSize: 16 << 10, // spec default + initialWindowSize: 65535, // spec default + maxConcurrentStreams: http2initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings. + peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. + streams: make(map[uint32]*http2clientStream), + singleUse: singleUse, + wantSettingsAck: true, + pings: make(map[[8]byte]chan struct{}), + reqHeaderMu: make(chan struct{}, 1), } - // http://tools.ietf.org/html/rfc7540#section-5.1.1 - // Streams initiated by a client MUST use odd-numbered stream - // identifiers. [...] An endpoint that receives an unexpected - // stream identifier MUST respond with a connection error - // (Section 5.4.1) of type PROTOCOL_ERROR. - if id%2 != 1 { - return sc.countError("headers_even", http2ConnectionError(http2ErrCodeProtocol)) - } - // A HEADERS frame can be used to create a new stream or - // send a trailer for an open one. If we already have a stream - // open, let it process its own HEADERS frame (trailers at this - // point, if it's valid). - if st := sc.streams[f.StreamID]; st != nil { - if st.resetQueued { - // We're sending RST_STREAM to close the stream, so don't bother - // processing this frame. - return nil - } - // RFC 7540, sec 5.1: If an endpoint receives additional frames, other than - // WINDOW_UPDATE, PRIORITY, or RST_STREAM, for a stream that is in - // this state, it MUST respond with a stream error (Section 5.4.2) of - // type STREAM_CLOSED. - if st.state == http2stateHalfClosedRemote { - return sc.countError("headers_half_closed", http2streamError(id, http2ErrCodeStreamClosed)) - } - return st.processTrailerHeaders(f) + if d := t.idleConnTimeout(); d != 0 { + cc.idleTimeout = d + cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) } - - // [...] The identifier of a newly established stream MUST be - // numerically greater than all streams that the initiating - // endpoint has opened or reserved. [...] An endpoint that - // receives an unexpected stream identifier MUST respond with - // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. - if id <= sc.maxClientStreamID { - return sc.countError("stream_went_down", http2ConnectionError(http2ErrCodeProtocol)) + if http2VerboseLogs { + t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) } - sc.maxClientStreamID = id - if sc.idleTimer != nil { - sc.idleTimer.Stop() - } + cc.cond = sync.NewCond(&cc.mu) + cc.flow.add(int32(http2initialWindowSize)) - // http://tools.ietf.org/html/rfc7540#section-5.1.2 - // [...] Endpoints MUST NOT exceed the limit set by their peer. An - // endpoint that receives a HEADERS frame that causes their - // advertised concurrent stream limit to be exceeded MUST treat - // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR - // or REFUSED_STREAM. - if sc.curClientStreams+1 > sc.advMaxStreams { - if sc.unackedSettings == 0 { - // They should know better. - return sc.countError("over_max_streams", http2streamError(id, http2ErrCodeProtocol)) - } - // Assume it's a network race, where they just haven't - // received our last SETTINGS update. But actually - // this can't happen yet, because we don't yet provide - // a way for users to adjust server parameters at - // runtime. - return sc.countError("over_max_streams_race", http2streamError(id, http2ErrCodeRefusedStream)) + // TODO: adjust this writer size to account for frame size + + // MTU + crypto/tls record padding. + cc.bw = bufio.NewWriter(http2stickyErrWriter{ + conn: c, + timeout: t.WriteByteTimeout, + err: &cc.werr, + }) + cc.br = bufio.NewReader(c) + cc.fr = http2NewFramer(cc.bw, cc.br) + cc.fr.cc = cc // for dump single request + if t.CountError != nil { + cc.fr.countError = t.CountError } + cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + cc.fr.MaxHeaderListSize = t.maxHeaderListSize() - initialState := http2stateOpen - if f.StreamEnded() { - initialState = http2stateHalfClosedRemote + // TODO: SetMaxDynamicTableSize, SetMaxDynamicTableSizeLimit on + // henc in response to SETTINGS frames? + cc.henc = hpack.NewEncoder(&cc.hbuf) + + if t.AllowHTTP { + cc.nextStreamID = 3 } - st := sc.newStream(id, 0, initialState) - if f.HasPriority() { - if err := sc.checkPriority(f.StreamID, f.Priority); err != nil { - return err - } - sc.writeSched.AdjustStream(st.id, f.Priority) + if cs, ok := c.(http2connectionStater); ok { + state := cs.ConnectionState() + cc.tlsState = &state } - rw, req, err := sc.newWriterAndRequest(st, f) - if err != nil { - return err + initialSettings := []http2Setting{ + {ID: http2SettingEnablePush, Val: 0}, + {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, } - st.reqTrailer = req.Trailer - if st.reqTrailer != nil { - st.trailer = make(http.Header) + if max := t.maxHeaderListSize(); max != 0 { + initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) } - st.body = req.Body.(*http2requestBody).pipe // may be nil - st.declBodyBytes = req.ContentLength - handler := sc.handler.ServeHTTP - if f.Truncated { - // Their header list was too long. Send a 431 error. - handler = http2handleHeaderListTooLong - } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil { - handler = http2new400Handler(err) + cc.bw.Write(http2clientPreface) + cc.fr.WriteSettings(initialSettings...) + cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow) + cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize) + cc.bw.Flush() + if cc.werr != nil { + cc.Close() + return nil, cc.werr } - // The net/http package sets the read deadline from the - // http.Server.ReadTimeout during the TLS handshake, but then - // passes the connection off to us with the deadline already - // set. Disarm it here after the request headers are read, - // similar to how the http1 server works. Here it's - // technically more like the http1 Server's ReadHeaderTimeout - // (in Go 1.8), though. That's a more sane option anyway. - if sc.hs.ReadTimeout != 0 { - sc.conn.SetReadDeadline(time.Time{}) + go cc.readLoop() + return cc, nil +} + +func (cc *http2ClientConn) healthCheck() { + pingTimeout := cc.t.pingTimeout() + // We don't need to periodically ping in the health check, because the readLoop of ClientConn will + // trigger the healthCheck again if there is no frame received. + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + err := cc.Ping(ctx) + if err != nil { + cc.closeForLostPing() + cc.t.connPool().MarkDead(cc) + return } +} - go sc.runHandler(rw, req, handler) - return nil +// SetDoNotReuse marks cc as not reusable for future HTTP requests. +func (cc *http2ClientConn) SetDoNotReuse() { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.doNotReuse = true } -func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { - sc := st.sc - sc.serveG.check() - if st.gotTrailerHeader { - return sc.countError("dup_trailers", http2ConnectionError(http2ErrCodeProtocol)) +func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { + cc.mu.Lock() + defer cc.mu.Unlock() + + old := cc.goAway + cc.goAway = f + + // Merge the previous and current GoAway error frames. + if cc.goAwayDebug == "" { + cc.goAwayDebug = string(f.DebugData()) } - st.gotTrailerHeader = true - if !f.StreamEnded() { - return sc.countError("trailers_not_ended", http2streamError(st.id, http2ErrCodeProtocol)) + if old != nil && old.ErrCode != http2ErrCodeNo { + cc.goAway.ErrCode = old.ErrCode } - - if len(f.PseudoFields()) > 0 { - return sc.countError("trailers_pseudo", http2streamError(st.id, http2ErrCodeProtocol)) - } - if st.trailer != nil { - for _, hf := range f.RegularFields() { - key := sc.canonicalHeader(hf.Name) - if !httpguts.ValidTrailerHeader(key) { - // TODO: send more details to the peer somehow. But http2 has - // no way to send debug data at a stream level. Discuss with - // HTTP folk. - return sc.countError("trailers_bogus", http2streamError(st.id, http2ErrCodeProtocol)) - } - st.trailer[key] = append(st.trailer[key], hf.Value) + last := f.LastStreamID + for streamID, cs := range cc.streams { + if streamID > last { + cs.abortStreamLocked(http2errClientConnGotGoAway) } } - st.endStream() - return nil } -func (sc *http2serverConn) checkPriority(streamID uint32, p http2PriorityParam) error { - if streamID == p.StreamDep { - // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat - // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR." - // Section 5.3.3 says that a stream can depend on one of its dependencies, - // so it's only self-dependencies that are forbidden. - return sc.countError("priority", http2streamError(streamID, http2ErrCodeProtocol)) - } - return nil +// CanTakeNewRequest reports whether the connection can take a new request, +// meaning it has not been closed or received or sent a GOAWAY. +// +// If the caller is going to immediately make a new request on this +// connection, use ReserveNewRequest instead. +func (cc *http2ClientConn) CanTakeNewRequest() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.canTakeNewRequestLocked() } -func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { - if sc.inGoAway { - return nil - } - if err := sc.checkPriority(f.StreamID, f.http2PriorityParam); err != nil { - return err +// ReserveNewRequest is like CanTakeNewRequest but also reserves a +// concurrent stream in cc. The reservation is decremented on the +// next call to RoundTrip. +func (cc *http2ClientConn) ReserveNewRequest() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + if st := cc.idleStateLocked(); !st.canTakeNewRequest { + return false } - sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam) - return nil + cc.streamsReserved++ + return true } -func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream { - sc.serveG.check() - if id == 0 { - panic("internal error: cannot create stream with id 0") - } - - ctx, cancelCtx := context.WithCancel(sc.baseCtx) - st := &http2stream{ - sc: sc, - id: id, - state: state, - ctx: ctx, - cancelCtx: cancelCtx, - } - st.cw.Init() - st.flow.conn = &sc.flow // link to conn-level counter - st.flow.add(sc.initialStreamSendWindowSize) - st.inflow.conn = &sc.inflow // link to conn-level counter - st.inflow.add(sc.srv.initialStreamRecvWindowSize()) - if sc.hs.WriteTimeout != 0 { - st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) - } +// ClientConnState describes the state of a ClientConn. +type http2ClientConnState struct { + // Closed is whether the connection is closed. + Closed bool - sc.streams[id] = st - sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID}) - if st.isPushed() { - sc.curPushedStreams++ - } else { - sc.curClientStreams++ - } - if sc.curOpenStreams() == 1 { - sc.setConnState(http.StateActive) - } + // Closing is whether the connection is in the process of + // closing. It may be closing due to shutdown, being a + // single-use connection, being marked as DoNotReuse, or + // having received a GOAWAY frame. + Closing bool - return st -} + // StreamsActive is how many streams are active. + StreamsActive int -func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *http.Request, error) { - sc.serveG.check() + // StreamsReserved is how many streams have been reserved via + // ClientConn.ReserveNewRequest. + StreamsReserved int - rp := http2requestParam{ - method: f.PseudoValue("method"), - scheme: f.PseudoValue("scheme"), - authority: f.PseudoValue("authority"), - path: f.PseudoValue("path"), - } + // StreamsPending is how many requests have been sent in excess + // of the peer's advertised MaxConcurrentStreams setting and + // are waiting for other streams to complete. + StreamsPending int - isConnect := rp.method == "CONNECT" - if isConnect { - if rp.path != "" || rp.scheme != "" || rp.authority == "" { - return nil, nil, sc.countError("bad_connect", http2streamError(f.StreamID, http2ErrCodeProtocol)) - } - } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { - // See 8.1.2.6 Malformed Requests and Responses: - // - // Malformed requests or responses that are detected - // MUST be treated as a stream error (Section 5.4.2) - // of type PROTOCOL_ERROR." - // - // 8.1.2.3 Request Pseudo-Header Fields - // "All HTTP/2 requests MUST include exactly one valid - // value for the :method, :scheme, and :path - // pseudo-header fields" - return nil, nil, sc.countError("bad_path_method", http2streamError(f.StreamID, http2ErrCodeProtocol)) - } + // MaxConcurrentStreams is how many concurrent streams the + // peer advertised as acceptable. Zero means no SETTINGS + // frame has been received yet. + MaxConcurrentStreams uint32 - bodyOpen := !f.StreamEnded() - if rp.method == "HEAD" && bodyOpen { - // HEAD requests can't have bodies - return nil, nil, sc.countError("head_body", http2streamError(f.StreamID, http2ErrCodeProtocol)) - } + // LastIdle, if non-zero, is when the connection last + // transitioned to idle state. + LastIdle time.Time +} - rp.header = make(http.Header) - for _, hf := range f.RegularFields() { - rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) - } - if rp.authority == "" { - rp.authority = rp.header.Get("Host") +// State returns a snapshot of cc's state. +func (cc *http2ClientConn) State() http2ClientConnState { + cc.wmu.Lock() + maxConcurrent := cc.maxConcurrentStreams + if !cc.seenSettings { + maxConcurrent = 0 } + cc.wmu.Unlock() - rw, req, err := sc.newWriterAndRequestNoBody(st, rp) - if err != nil { - return nil, nil, err - } - if bodyOpen { - if vv, ok := rp.header["Content-Length"]; ok { - if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil { - req.ContentLength = int64(cl) - } else { - req.ContentLength = 0 - } - } else { - req.ContentLength = -1 - } - req.Body.(*http2requestBody).pipe = &http2pipe{ - b: &http2dataBuffer{expected: req.ContentLength}, - } + cc.mu.Lock() + defer cc.mu.Unlock() + return http2ClientConnState{ + Closed: cc.closed, + Closing: cc.closing || cc.singleUse || cc.doNotReuse || cc.goAway != nil, + StreamsActive: len(cc.streams), + StreamsReserved: cc.streamsReserved, + StreamsPending: cc.pendingRequests, + LastIdle: cc.lastIdle, + MaxConcurrentStreams: maxConcurrent, } - return rw, req, nil } -type http2requestParam struct { - method string - scheme, authority, path string - header http.Header +// clientConnIdleState describes the suitability of a client +// connection to initiate a new RoundTrip request. +type http2clientConnIdleState struct { + canTakeNewRequest bool } -func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *http.Request, error) { - sc.serveG.check() - - var tlsState *tls.ConnectionState // nil if not scheme https - if rp.scheme == "https" { - tlsState = sc.tlsState - } - - needsContinue := rp.header.Get("Expect") == "100-continue" - if needsContinue { - rp.header.Del("Expect") - } - // Merge Cookie headers into one "; "-delimited value. - if cookies := rp.header["Cookie"]; len(cookies) > 1 { - rp.header.Set("Cookie", strings.Join(cookies, "; ")) - } +func (cc *http2ClientConn) idleState() http2clientConnIdleState { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.idleStateLocked() +} - // Setup Trailers - var trailer http.Header - for _, v := range rp.header["Trailer"] { - for _, key := range strings.Split(v, ",") { - key = http.CanonicalHeaderKey(textproto.TrimString(key)) - switch key { - case "Transfer-Encoding", "Trailer", "Content-Length": - // Bogus. (copy of http1 rules) - // Ignore. - default: - if trailer == nil { - trailer = make(http.Header) - } - trailer[key] = nil - } - } +func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) { + if cc.singleUse && cc.nextStreamID > 1 { + return } - delete(rp.header, "Trailer") - - var url_ *url.URL - var requestURI string - if rp.method == "CONNECT" { - url_ = &url.URL{Host: rp.authority} - requestURI = rp.authority // mimic HTTP/1 server behavior + var maxConcurrentOkay bool + if cc.t.StrictMaxConcurrentStreams { + // We'll tell the caller we can take a new request to + // prevent the caller from dialing a new TCP + // connection, but then we'll block later before + // writing it. + maxConcurrentOkay = true } else { - var err error - url_, err = url.ParseRequestURI(rp.path) - if err != nil { - return nil, nil, sc.countError("bad_path", http2streamError(st.id, http2ErrCodeProtocol)) - } - requestURI = rp.path - } - - body := &http2requestBody{ - conn: sc, - stream: st, - needsContinue: needsContinue, - } - req := &http.Request{ - Method: rp.method, - URL: url_, - RemoteAddr: sc.remoteAddrStr, - Header: rp.header, - RequestURI: requestURI, - Proto: "HTTP/2.0", - ProtoMajor: 2, - ProtoMinor: 0, - TLS: tlsState, - Host: rp.authority, - Body: body, - Trailer: trailer, - } - req = req.WithContext(st.ctx) - - rws := http2responseWriterStatePool.Get().(*http2responseWriterState) - bwSave := rws.bw - *rws = http2responseWriterState{} // zero all the fields - rws.conn = sc - rws.bw = bwSave - rws.bw.Reset(http2chunkWriter{rws}) - rws.stream = st - rws.req = req - rws.body = body - - rw := &http2responseWriter{rws: rws} - return rw, req, nil -} - -// Run on its own goroutine. -func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { - didPanic := true - defer func() { - rw.rws.stream.cancelCtx() - if didPanic { - e := recover() - sc.writeFrameFromHandler(http2FrameWriteRequest{ - write: http2handlerPanicRST{rw.rws.stream.id}, - stream: rw.rws.stream, - }) - // Same as net/http: - if e != nil && e != http.ErrAbortHandler { - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) - } - return - } - rw.handlerDone() - }() - handler(rw, req) - didPanic = false -} - -func http2handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) { - // 10.5.1 Limits on Header Block Size: - // .. "A server that receives a larger header block than it is - // willing to handle can send an HTTP 431 (Request Header Fields Too - // Large) status code" - const statusRequestHeaderFieldsTooLarge = 431 // only in Go 1.6+ - w.WriteHeader(statusRequestHeaderFieldsTooLarge) - io.WriteString(w, "

HTTP Error 431

Request Header Field(s) Too Large

") -} - -// called from handler goroutines. -// h may be nil. -func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeResHeaders) error { - sc.serveG.checkNotOn() // NOT on - var errc chan error - if headerData.h != nil { - // If there's a header map (which we don't own), so we have to block on - // waiting for this frame to be written, so an http.Flush mid-handler - // writes out the correct value of keys, before a handler later potentially - // mutates it. - errc = http2errChanPool.Get().(chan error) - } - if err := sc.writeFrameFromHandler(http2FrameWriteRequest{ - write: headerData, - stream: st, - done: errc, - }); err != nil { - return err - } - if errc != nil { - select { - case err := <-errc: - http2errChanPool.Put(errc) - return err - case <-sc.doneServing: - return http2errClientDisconnected - case <-st.cw: - return http2errStreamClosed - } + maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams) } - return nil -} - -// called from handler goroutines. -func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) { - sc.writeFrameFromHandler(http2FrameWriteRequest{ - write: http2write100ContinueHeadersFrame{st.id}, - stream: st, - }) -} -// A bodyReadMsg tells the server loop that the http.Handler read n -// bytes of the DATA from the client on the given stream. -type http2bodyReadMsg struct { - st *http2stream - n int + st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay && + !cc.doNotReuse && + int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 && + !cc.tooIdleLocked() + return } -// called from handler goroutines. -// Notes that the handler for the given stream ID read n bytes of its body -// and schedules flow control tokens to be sent. -func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) { - sc.serveG.checkNotOn() // NOT on - if n > 0 { - select { - case sc.bodyReadCh <- http2bodyReadMsg{st, n}: - case <-sc.doneServing: - } - } +func (cc *http2ClientConn) canTakeNewRequestLocked() bool { + st := cc.idleStateLocked() + return st.canTakeNewRequest } -func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) { - sc.serveG.check() - sc.sendWindowUpdate(nil, n) // conn-level - if st.state != http2stateHalfClosedRemote && st.state != http2stateClosed { - // Don't send this WINDOW_UPDATE if the stream is closed - // remotely. - sc.sendWindowUpdate(st, n) - } +// tooIdleLocked reports whether this connection has been been sitting idle +// for too much wall time. +func (cc *http2ClientConn) tooIdleLocked() bool { + // The Round(0) strips the monontonic clock reading so the + // times are compared based on their wall time. We don't want + // to reuse a connection that's been sitting idle during + // VM/laptop suspend if monotonic time was also frozen. + return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && time.Since(cc.lastIdle.Round(0)) > cc.idleTimeout } -// st may be nil for conn-level -func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) { - sc.serveG.check() - // "The legal range for the increment to the flow control - // window is 1 to 2^31-1 (2,147,483,647) octets." - // A Go Read call on 64-bit machines could in theory read - // a larger Read than this. Very unlikely, but we handle it here - // rather than elsewhere for now. - const maxUint31 = 1<<31 - 1 - for n >= maxUint31 { - sc.sendWindowUpdate32(st, maxUint31) - n -= maxUint31 - } - sc.sendWindowUpdate32(st, int32(n)) +// onIdleTimeout is called from a time.AfterFunc goroutine. It will +// only be called when we're idle, but because we're coming from a new +// goroutine, there could be a new request coming in at the same time, +// so this simply calls the synchronized closeIfIdle to shut down this +// connection. The timer could just call closeIfIdle, but this is more +// clear. +func (cc *http2ClientConn) onIdleTimeout() { + cc.closeIfIdle() } -// st may be nil for conn-level -func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { - sc.serveG.check() - if n == 0 { +func (cc *http2ClientConn) closeIfIdle() { + cc.mu.Lock() + if len(cc.streams) > 0 || cc.streamsReserved > 0 { + cc.mu.Unlock() return } - if n < 0 { - panic("negative update") - } - var streamID uint32 - if st != nil { - streamID = st.id - } - sc.writeFrame(http2FrameWriteRequest{ - write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)}, - stream: st, - }) - var ok bool - if st == nil { - ok = sc.inflow.add(n) - } else { - ok = st.inflow.add(n) - } - if !ok { - panic("internal error; sent too many window updates without decrements?") + cc.closed = true + nextID := cc.nextStreamID + // TODO: do clients send GOAWAY too? maybe? Just Close: + cc.mu.Unlock() + + if http2VerboseLogs { + cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, nextID-2) } + cc.tconn.Close() } -// requestBody is the Handler's Request.Body type. -// Read and Close may be called concurrently. -type http2requestBody struct { - _ http2incomparable - stream *http2stream - conn *http2serverConn - closed bool // for use by Close only - sawEOF bool // for use by Read only - pipe *http2pipe // non-nil if we have a HTTP entity message body - needsContinue bool // need to send a 100-continue +func (cc *http2ClientConn) isDoNotReuseAndIdle() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.doNotReuse && len(cc.streams) == 0 } -func (b *http2requestBody) Close() error { - if b.pipe != nil && !b.closed { - b.pipe.BreakWithError(http2errClosedBody) - } - b.closed = true - return nil -} +var http2shutdownEnterWaitStateHook = func() {} -func (b *http2requestBody) Read(p []byte) (n int, err error) { - if b.needsContinue { - b.needsContinue = false - b.conn.write100ContinueHeaders(b.stream) - } - if b.pipe == nil || b.sawEOF { - return 0, io.EOF - } - n, err = b.pipe.Read(p) - if err == io.EOF { - b.sawEOF = true +// Shutdown gracefully closes the client connection, waiting for running streams to complete. +func (cc *http2ClientConn) Shutdown(ctx context.Context) error { + if err := cc.sendGoAway(); err != nil { + return err } - if b.conn == nil && http2inTests { - return - } - b.conn.noteBodyReadFromHandler(b.stream, n, err) - return -} - -// responseWriter is the http.ResponseWriter implementation. It's -// intentionally small (1 pointer wide) to minimize garbage. The -// responseWriterState pointer inside is zeroed at the end of a -// request (in handlerDone) and calls on the responseWriter thereafter -// simply crash (caller's mistake), but the much larger responseWriterState -// and buffers are reused between multiple requests. -type http2responseWriter struct { - rws *http2responseWriterState -} - -// Optional http.ResponseWriter interfaces implemented. -var ( - _ http.CloseNotifier = (*http2responseWriter)(nil) - _ http.Flusher = (*http2responseWriter)(nil) - _ http2stringWriter = (*http2responseWriter)(nil) -) - -type http2responseWriterState struct { - // immutable within a request: - stream *http2stream - req *http.Request - body *http2requestBody // to close at end of request, if DATA frames didn't - conn *http2serverConn - - // TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc - bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState} - - // mutated by http.Handler goroutine: - handlerHeader http.Header // nil until called - snapHeader http.Header // snapshot of handlerHeader at WriteHeader time - trailers []string // set in writeChunk - status int // status code passed to WriteHeader - wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. - sentHeader bool // have we sent the header frame? - handlerDone bool // handler has finished - dirty bool // a Write failed; don't reuse this responseWriterState - - sentContentLen int64 // non-zero if handler set a Content-Length header - wroteBytes int64 - - closeNotifierMu sync.Mutex // guards closeNotifierCh - closeNotifierCh chan bool // nil until first used -} - -type http2chunkWriter struct{ rws *http2responseWriterState } - -func (cw http2chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } - -func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 } - -func (rws *http2responseWriterState) hasNonemptyTrailers() bool { - for _, trailer := range rws.trailers { - if _, ok := rws.handlerHeader[trailer]; ok { - return true + // Wait for all in-flight streams to complete or connection to close + done := make(chan error, 1) + cancelled := false // guarded by cc.mu + go func() { + cc.mu.Lock() + defer cc.mu.Unlock() + for { + if len(cc.streams) == 0 || cc.closed { + cc.closed = true + done <- cc.tconn.Close() + break + } + if cancelled { + break + } + cc.cond.Wait() } + }() + http2shutdownEnterWaitStateHook() + select { + case err := <-done: + return err + case <-ctx.Done(): + cc.mu.Lock() + // Free the goroutine above + cancelled = true + cc.cond.Broadcast() + cc.mu.Unlock() + return ctx.Err() } - return false } -// declareTrailer is called for each Trailer header when the -// response header is written. It notes that a header will need to be -// written in the trailers at the end of the response. -func (rws *http2responseWriterState) declareTrailer(k string) { - k = http.CanonicalHeaderKey(k) - if !httpguts.ValidTrailerHeader(k) { - // Forbidden by RFC 7230, section 4.1.2. - rws.conn.logf("ignoring invalid trailer %q", k) - return - } - if !http2strSliceContains(rws.trailers, k) { - rws.trailers = append(rws.trailers, k) +func (cc *http2ClientConn) sendGoAway() error { + cc.mu.Lock() + closing := cc.closing + cc.closing = true + maxStreamID := cc.nextStreamID + cc.mu.Unlock() + if closing { + // GOAWAY sent already + return nil } -} - -// writeChunk writes chunks from the bufio.Writer. But because -// bufio.Writer may bypass its chunking, sometimes p may be -// arbitrarily large. -// -// writeChunk is also responsible (on the first chunk) for sending the -// HEADER response. -func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { - if !rws.wroteHeader { - rws.writeHeader(200) - } - - isHeadResp := rws.req.Method == "HEAD" - if !rws.sentHeader { - rws.sentHeader = true - var ctype, clen string - if clen = rws.snapHeader.Get("Content-Length"); clen != "" { - rws.snapHeader.Del("Content-Length") - if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { - rws.sentContentLen = int64(cl) - } else { - clen = "" - } - } - if clen == "" && rws.handlerDone && http2bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) { - clen = strconv.Itoa(len(p)) - } - _, hasContentType := rws.snapHeader["Content-Type"] - // If the Content-Encoding is non-blank, we shouldn't - // sniff the body. See Issue golang.org/issue/31753. - ce := rws.snapHeader.Get("Content-Encoding") - hasCE := len(ce) > 0 - if !hasCE && !hasContentType && http2bodyAllowedForStatus(rws.status) && len(p) > 0 { - ctype = http.DetectContentType(p) - } - var date string - if _, ok := rws.snapHeader["Date"]; !ok { - // TODO(bradfitz): be faster here, like net/http? measure. - date = time.Now().UTC().Format(http.TimeFormat) - } - - for _, v := range rws.snapHeader["Trailer"] { - http2foreachHeaderElement(v, rws.declareTrailer) - } - - // "Connection" headers aren't allowed in HTTP/2 (RFC 7540, 8.1.2.2), - // but respect "Connection" == "close" to mean sending a GOAWAY and tearing - // down the TCP connection when idle, like we do for HTTP/1. - // TODO: remove more Connection-specific header fields here, in addition - // to "Connection". - if _, ok := rws.snapHeader["Connection"]; ok { - v := rws.snapHeader.Get("Connection") - delete(rws.snapHeader, "Connection") - if v == "close" { - rws.conn.startGracefulShutdown() - } - } - endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp - err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ - streamID: rws.stream.id, - httpResCode: rws.status, - h: rws.snapHeader, - endStream: endStream, - contentType: ctype, - contentLength: clen, - date: date, - }) - if err != nil { - rws.dirty = true - return 0, err - } - if endStream { - return 0, nil - } - } - if isHeadResp { - return len(p), nil + cc.wmu.Lock() + defer cc.wmu.Unlock() + // Send a graceful shutdown frame to server + if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil { + return err } - if len(p) == 0 && !rws.handlerDone { - return 0, nil + if err := cc.bw.Flush(); err != nil { + return err } + // Prevent new requests + return nil +} - if rws.handlerDone { - rws.promoteUndeclaredTrailers() +// closes the client connection immediately. In-flight requests are interrupted. +// err is sent to streams. +func (cc *http2ClientConn) closeForError(err error) error { + cc.mu.Lock() + cc.closed = true + for _, cs := range cc.streams { + cs.abortStreamLocked(err) } + defer cc.cond.Broadcast() + defer cc.mu.Unlock() + return cc.tconn.Close() +} - // only send trailers if they have actually been defined by the - // server handler. - hasNonemptyTrailers := rws.hasNonemptyTrailers() - endStream := rws.handlerDone && !hasNonemptyTrailers - if len(p) > 0 || endStream { - // only send a 0 byte DATA frame if we're ending the stream. - if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil { - rws.dirty = true - return 0, err - } - } +// Close closes the client connection immediately. +// +// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. +func (cc *http2ClientConn) Close() error { + err := errors.New("http2: client connection force closed via ClientConn.Close") + return cc.closeForError(err) +} - if rws.handlerDone && hasNonemptyTrailers { - err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ - streamID: rws.stream.id, - h: rws.handlerHeader, - trailers: rws.trailers, - endStream: true, - }) - if err != nil { - rws.dirty = true - } - return len(p), err +// closes the client connection immediately. In-flight requests are interrupted. +func (cc *http2ClientConn) closeForLostPing() error { + err := errors.New("http2: client connection lost") + if f := cc.t.CountError; f != nil { + f("conn_close_lost_ping") } - return len(p), nil + return cc.closeForError(err) } -// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys -// that, if present, signals that the map entry is actually for -// the response trailers, and not the response headers. The prefix -// is stripped after the ServeHTTP call finishes and the values are -// sent in the trailers. -// -// This mechanism is intended only for trailers that are not known -// prior to the headers being written. If the set of trailers is fixed -// or known before the header is written, the normal Go trailers mechanism -// is preferred: -// https://golang.org/pkg/net/http/#ResponseWriter -// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers -const http2TrailerPrefix = "Trailer:" +// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not +// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. +var http2errRequestCanceled = errors.New("net/http: request canceled") -// promoteUndeclaredTrailers permits http.Handlers to set trailers -// after the header has already been flushed. Because the Go -// ResponseWriter interface has no way to set Trailers (only the -// Header), and because we didn't want to expand the ResponseWriter -// interface, and because nobody used trailers, and because RFC 7230 -// says you SHOULD (but not must) predeclare any trailers in the -// header, the official ResponseWriter rules said trailers in Go must -// be predeclared, and then we reuse the same ResponseWriter.Header() -// map to mean both Headers and Trailers. When it's time to write the -// Trailers, we pick out the fields of Headers that were declared as -// trailers. That worked for a while, until we found the first major -// user of Trailers in the wild: gRPC (using them only over http2), -// and gRPC libraries permit setting trailers mid-stream without -// predeclaring them. So: change of plans. We still permit the old -// way, but we also permit this hack: if a Header() key begins with -// "Trailer:", the suffix of that key is a Trailer. Because ':' is an -// invalid token byte anyway, there is no ambiguity. (And it's already -// filtered out) It's mildly hacky, but not terrible. -// -// This method runs after the Handler is done and promotes any Header -// fields to be trailers. -func (rws *http2responseWriterState) promoteUndeclaredTrailers() { - for k, vv := range rws.handlerHeader { - if !strings.HasPrefix(k, http2TrailerPrefix) { - continue +func http2commaSeparatedTrailers(req *http.Request) (string, error) { + keys := make([]string, 0, len(req.Trailer)) + for k := range req.Trailer { + k = http.CanonicalHeaderKey(k) + switch k { + case "Transfer-Encoding", "Trailer", "Content-Length": + return "", fmt.Errorf("invalid Trailer key %q", k) } - trailerKey := strings.TrimPrefix(k, http2TrailerPrefix) - rws.declareTrailer(trailerKey) - rws.handlerHeader[http.CanonicalHeaderKey(trailerKey)] = vv + keys = append(keys, k) } - - if len(rws.trailers) > 1 { - sorter := http2sorterPool.Get().(*http2sorter) - sorter.SortStrings(rws.trailers) - http2sorterPool.Put(sorter) + if len(keys) > 0 { + sort.Strings(keys) + return strings.Join(keys, ","), nil } + return "", nil } -func (w *http2responseWriter) Flush() { - rws := w.rws - if rws == nil { - panic("Header called after Handler finished") - } - if rws.bw.Buffered() > 0 { - if err := rws.bw.Flush(); err != nil { - // Ignore the error. The frame writer already knows. - return - } - } else { - // The bufio.Writer won't call chunkWriter.Write - // (writeChunk with zero bytes, so we have to do it - // ourselves to force the HTTP response header and/or - // final DATA frame (with END_STREAM) to be sent. - rws.writeChunk(nil) +func (cc *http2ClientConn) responseHeaderTimeout() time.Duration { + if cc.t.t1 != nil { + return cc.t.t1.ResponseHeaderTimeout } + // No way to do this (yet?) with just an http2.Transport. Probably + // no need. Request.Cancel this is the new way. We only need to support + // this for compatibility with the old http.Transport fields when + // we're doing transparent http2. + return 0 } -func (w *http2responseWriter) CloseNotify() <-chan bool { - rws := w.rws - if rws == nil { - panic("CloseNotify called after Handler finished") - } - rws.closeNotifierMu.Lock() - ch := rws.closeNotifierCh - if ch == nil { - ch = make(chan bool, 1) - rws.closeNotifierCh = ch - cw := rws.stream.cw - go func() { - cw.Wait() // wait for close - ch <- true - }() +// checkConnHeaders checks whether req has any invalid connection-level headers. +// per RFC 7540 section 8.1.2.2: Connection-Specific Header Fields. +// Certain headers are special-cased as okay but not transmitted later. +func http2checkConnHeaders(req *http.Request) error { + if v := req.Header.Get("Upgrade"); v != "" { + return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"]) } - rws.closeNotifierMu.Unlock() - return ch -} - -func (w *http2responseWriter) Header() http.Header { - rws := w.rws - if rws == nil { - panic("Header called after Handler finished") + if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { + return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv) } - if rws.handlerHeader == nil { - rws.handlerHeader = make(http.Header) + if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !http2asciiEqualFold(vv[0], "close") && !http2asciiEqualFold(vv[0], "keep-alive")) { + return fmt.Errorf("http2: invalid Connection request header: %q", vv) } - return rws.handlerHeader + return nil } -// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode. -func http2checkWriteHeaderCode(code int) { - // Issue 22880: require valid WriteHeader status codes. - // For now we only enforce that it's three digits. - // In the future we might block things over 599 (600 and above aren't defined - // at http://httpwg.org/specs/rfc7231.html#status.codes) - // and we might block under 200 (once we have more mature 1xx support). - // But for now any three digits. - // - // We used to send "HTTP/1.1 000 0" on the wire in responses but there's - // no equivalent bogus thing we can realistically send in HTTP/2, - // so we'll consistently panic instead and help people find their bugs - // early. (We can't return an error from WriteHeader even if we wanted to.) - if code < 100 || code > 999 { - panic(fmt.Sprintf("invalid WriteHeader code %v", code)) +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func http2actualContentLength(req *http.Request) int64 { + if req.Body == nil || req.Body == http.NoBody { + return 0 } -} - -func (w *http2responseWriter) WriteHeader(code int) { - rws := w.rws - if rws == nil { - panic("WriteHeader called after Handler finished") + if req.ContentLength != 0 { + return req.ContentLength } - rws.writeHeader(code) + return -1 } -func (rws *http2responseWriterState) writeHeader(code int) { - if !rws.wroteHeader { - http2checkWriteHeaderCode(code) - rws.wroteHeader = true - rws.status = code - if len(rws.handlerHeader) > 0 { - rws.snapHeader = http2cloneHeader(rws.handlerHeader) - } - } +func (cc *http2ClientConn) decrStreamReservations() { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.decrStreamReservationsLocked() } -func http2cloneHeader(h http.Header) http.Header { - h2 := make(http.Header, len(h)) - for k, vv := range h { - vv2 := make([]string, len(vv)) - copy(vv2, vv) - h2[k] = vv2 +func (cc *http2ClientConn) decrStreamReservationsLocked() { + if cc.streamsReserved > 0 { + cc.streamsReserved-- } - return h2 -} - -// The Life Of A Write is like this: -// -// * Handler calls w.Write or w.WriteString -> -// * -> rws.bw (*bufio.Writer) -> -// * (Handler might call Flush) -// * -> chunkWriter{rws} -// * -> responseWriterState.writeChunk(p []byte) -// * -> responseWriterState.writeChunk (most of the magic; see comment there) -func (w *http2responseWriter) Write(p []byte) (n int, err error) { - return w.write(len(p), p, "") } -func (w *http2responseWriter) WriteString(s string) (n int, err error) { - return w.write(len(s), nil, s) -} - -// either dataB or dataS is non-zero. -func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) { - rws := w.rws - if rws == nil { - panic("Write called after Handler finished") - } - if !rws.wroteHeader { - w.WriteHeader(200) - } - if !http2bodyAllowedForStatus(rws.status) { - return 0, http.ErrBodyNotAllowed - } - rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) // only one can be set - if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen { - // TODO: send a RST_STREAM - return 0, errors.New("http2: handler wrote more than declared Content-Length") - } - - if dataB != nil { - return rws.bw.Write(dataB) - } else { - return rws.bw.WriteString(dataS) +func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + cc.currentRequest = req + ctx := req.Context() + cs := &http2clientStream{ + cc: cc, + ctx: ctx, + reqCancel: req.Cancel, + isHead: req.Method == "HEAD", + reqBody: req.Body, + reqBodyContentLength: http2actualContentLength(req), + trace: httptrace.ContextClientTrace(ctx), + peerClosed: make(chan struct{}), + abort: make(chan struct{}), + respHeaderRecv: make(chan struct{}), + donec: make(chan struct{}), } -} + go cs.doRequest(req) -func (w *http2responseWriter) handlerDone() { - rws := w.rws - dirty := rws.dirty - rws.handlerDone = true - w.Flush() - w.rws = nil - if !dirty { - // Only recycle the pool if all prior Write calls to - // the serverConn goroutine completed successfully. If - // they returned earlier due to resets from the peer - // there might still be write goroutines outstanding - // from the serverConn referencing the rws memory. See - // issue 20704. - http2responseWriterStatePool.Put(rws) + waitDone := func() error { + select { + case <-cs.donec: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled + } } -} -// Push errors. -var ( - http2ErrRecursivePush = errors.New("http2: recursive push not allowed") - http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") -) + handleResponseHeaders := func() (*http.Response, error) { + res := cs.res + if res.StatusCode > 299 { + // On error or status code 3xx, 4xx, 5xx, etc abort any + // ongoing write, assuming that the server doesn't care + // about our request body. If the server replied with 1xx or + // 2xx, however, then assume the server DOES potentially + // want our body (e.g. full-duplex streaming: + // golang.org/issue/13444). If it turns out the server + // doesn't, they'll RST_STREAM us soon enough. This is a + // heuristic to avoid adding knobs to Transport. Hopefully + // we can keep it. + cs.abortRequestBodyWrite() + } + res.Request = req + res.TLS = cc.tlsState + if res.Body == http2noBody && http2actualContentLength(req) == 0 { + // If there isn't a request or response body still being + // written, then wait for the stream to be closed before + // RoundTrip returns. + if err := waitDone(); err != nil { + return nil, err + } + } + return res, nil + } + + for { + select { + case <-cs.respHeaderRecv: + return handleResponseHeaders() + case <-cs.abort: + select { + case <-cs.respHeaderRecv: + // If both cs.respHeaderRecv and cs.abort are signaling, + // pick respHeaderRecv. The server probably wrote the + // response and immediately reset the stream. + // golang.org/issue/49645 + return handleResponseHeaders() + default: + waitDone() + return nil, cs.abortErr + } + case <-ctx.Done(): + err := ctx.Err() + cs.abortStream(err) + return nil, err + case <-cs.reqCancel: + cs.abortStream(http2errRequestCanceled) + return nil, http2errRequestCanceled + } + } +} -var _ http.Pusher = (*http2responseWriter)(nil) +// doRequest runs for the duration of the request lifetime. +// +// It sends the request and performs post-request cleanup (closing Request.Body, etc.). +func (cs *http2clientStream) doRequest(req *http.Request) { + err := cs.writeRequest(req) + cs.cleanupWriteRequest(err) +} -func (w *http2responseWriter) Push(target string, opts *http.PushOptions) error { - st := w.rws.stream - sc := st.sc - sc.serveG.checkNotOn() +// writeRequest sends a request. +// +// It returns nil after the request is written, the response read, +// and the request stream is half-closed by the peer. +// +// It returns non-nil if the request ends otherwise. +// If the returned error is StreamError, the error Code may be used in resetting the stream. +func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { + cc := cs.cc + ctx := cs.ctx - // No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream." - // http://tools.ietf.org/html/rfc7540#section-6.6 - if st.isPushed() { - return http2ErrRecursivePush + if err := http2checkConnHeaders(req); err != nil { + return err } - if opts == nil { - opts = new(http.PushOptions) + // Acquire the new-request lock by writing to reqHeaderMu. + // This lock guards the critical section covering allocating a new stream ID + // (requires mu) and creating the stream (requires wmu). + if cc.reqHeaderMu == nil { + panic("RoundTrip on uninitialized ClientConn") // for tests + } + select { + case cc.reqHeaderMu <- struct{}{}: + case <-cs.reqCancel: + return http2errRequestCanceled + case <-ctx.Done(): + return ctx.Err() } - // Default options. - if opts.Method == "" { - opts.Method = "GET" + cc.mu.Lock() + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } + cc.decrStreamReservationsLocked() + if err := cc.awaitOpenSlotForStreamLocked(cs); err != nil { + cc.mu.Unlock() + <-cc.reqHeaderMu + return err + } + cc.addStreamLocked(cs) // assigns stream ID + if http2isConnectionCloseRequest(req) { + cc.doNotReuse = true } - if opts.Header == nil { - opts.Header = http.Header{} + cc.mu.Unlock() + + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? + if !cc.t.disableCompression() && + req.Header.Get("Accept-Encoding") == "" && + req.Header.Get("Range") == "" && + !cs.isHead { + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: https://zlib.net/zlib_faq.html#faq39 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // http://trac.nginx.org/nginx/ticket/358 + // https://golang.org/issue/5522 + // + // We don't request gzip if the request is for a range, since + // auto-decoding a portion of a gzipped document will just fail + // anyway. See https://golang.org/issue/8923 + cs.requestedGzip = true } - wantScheme := "http" - if w.rws.req.TLS != nil { - wantScheme = "https" + + continueTimeout := cc.t.expectContinueTimeout() + if continueTimeout != 0 { + if !httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") { + continueTimeout = 0 + } else { + cs.on100 = make(chan struct{}, 1) + } } - // Validate the request. - u, err := url.Parse(target) + dump := getDumperOverride(cs.cc.t.t1.dump, req.Context()) + + // Past this point (where we send request headers), it is possible for + // RoundTrip to return successfully. Since the RoundTrip contract permits + // the caller to "mutate or reuse" the Request after closing the Response's Body, + // we must take care when referencing the Request from here on. + err = cs.encodeAndWriteHeaders(req, dump) + <-cc.reqHeaderMu if err != nil { return err } - if u.Scheme == "" { - if !strings.HasPrefix(target, "/") { - return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target) - } - u.Scheme = wantScheme - u.Host = w.rws.req.Host + + hasBody := cs.reqBodyContentLength != 0 + if !hasBody { + cs.sentEndStream = true } else { - if u.Scheme != wantScheme { - return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme) + if continueTimeout != 0 { + http2traceWait100Continue(cs.trace) + timer := time.NewTimer(continueTimeout) + select { + case <-timer.C: + err = nil + case <-cs.on100: + err = nil + case <-cs.abort: + err = cs.abortErr + case <-ctx.Done(): + err = ctx.Err() + case <-cs.reqCancel: + err = http2errRequestCanceled + } + timer.Stop() + if err != nil { + http2traceWroteRequest(cs.trace, err) + return err + } } - if u.Host == "" { - return errors.New("URL must have a host") + + if err = cs.writeRequestBody(req, dump); err != nil { + if err != http2errStopReqBodyWrite { + http2traceWroteRequest(cs.trace, err) + return err + } + } else { + cs.sentEndStream = true + if dump != nil && dump.RequestBody { + dump.dump([]byte("\r\n\r\n")) + } } } - for k := range opts.Header { - if strings.HasPrefix(k, ":") { - return fmt.Errorf("promised request headers cannot include pseudo header %q", k) - } - // These headers are meaningful only if the request has a body, - // but PUSH_PROMISE requests cannot have a body. - // http://tools.ietf.org/html/rfc7540#section-8.2 - // Also disallow Host, since the promised URL must be absolute. - if http2asciiEqualFold(k, "content-length") || - http2asciiEqualFold(k, "content-encoding") || - http2asciiEqualFold(k, "trailer") || - http2asciiEqualFold(k, "te") || - http2asciiEqualFold(k, "expect") || - http2asciiEqualFold(k, "host") { - return fmt.Errorf("promised request headers cannot include %q", k) - } + + http2traceWroteRequest(cs.trace, err) + + var respHeaderTimer <-chan time.Time + var respHeaderRecv chan struct{} + if d := cc.responseHeaderTimeout(); d != 0 { + timer := time.NewTimer(d) + defer timer.Stop() + respHeaderTimer = timer.C + respHeaderRecv = cs.respHeaderRecv } - if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil { - return err + // Wait until the peer half-closes its end of the stream, + // or until the request is aborted (via context, error, or otherwise), + // whichever comes first. + for { + select { + case <-cs.peerClosed: + return nil + case <-respHeaderTimer: + return http2errTimeout + case <-respHeaderRecv: + respHeaderRecv = nil + respHeaderTimer = nil // keep waiting for END_STREAM + case <-cs.abort: + return cs.abortErr + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled + } } +} - // The RFC effectively limits promised requests to GET and HEAD: - // "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]" - // http://tools.ietf.org/html/rfc7540#section-8.2 - if opts.Method != "GET" && opts.Method != "HEAD" { - return fmt.Errorf("method %q must be GET or HEAD", opts.Method) - } +func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dump *dumper) error { + cc := cs.cc + ctx := cs.ctx - msg := &http2startPushRequest{ - parent: st, - method: opts.Method, - url: u, - header: http2cloneHeader(opts.Header), - done: http2errChanPool.Get().(chan error), - } + cc.wmu.Lock() + defer cc.wmu.Unlock() + // If the request was canceled while waiting for cc.mu, just quit. select { - case <-sc.doneServing: - return http2errClientDisconnected - case <-st.cw: - return http2errStreamClosed - case sc.serveMsgCh <- msg: + case <-cs.abort: + return cs.abortErr + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled + default: } - select { - case <-sc.doneServing: - return http2errClientDisconnected - case <-st.cw: - return http2errStreamClosed - case err := <-msg.done: - http2errChanPool.Put(msg.done) + // Encode headers. + // + // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is + // sent by writeRequestBody below, along with any Trailers, + // again in form HEADERS{1}, CONTINUATION{0,}) + trailers, err := http2commaSeparatedTrailers(req) + if err != nil { + return err + } + hasTrailers := trailers != "" + contentLen := http2actualContentLength(req) + hasBody := contentLen != 0 + hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen, dump) + if err != nil { return err } -} -type http2startPushRequest struct { - parent *http2stream - method string - url *url.URL - header http.Header - done chan error + // Write the request. + endStream := !hasBody && !hasTrailers + cs.sentHeaders = true + err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) + http2traceWroteHeaders(cs.trace) + return err } -func (sc *http2serverConn) startPush(msg *http2startPushRequest) { - sc.serveG.check() - - // http://tools.ietf.org/html/rfc7540#section-6.6. - // PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that - // is in either the "open" or "half-closed (remote)" state. - if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote { - // responseWriter.Push checks that the stream is peer-initiated. - msg.done <- http2errStreamClosed - return - } +// cleanupWriteRequest performs post-request tasks. +// +// If err (the result of writeRequest) is non-nil and the stream is not closed, +// cleanupWriteRequest will send a reset to the peer. +func (cs *http2clientStream) cleanupWriteRequest(err error) { + cc := cs.cc - // http://tools.ietf.org/html/rfc7540#section-6.6. - if !sc.pushEnabled { - msg.done <- http.ErrNotSupported - return + if cs.ID == 0 { + // We were canceled before creating the stream, so return our reservation. + cc.decrStreamReservations() } - // PUSH_PROMISE frames must be sent in increasing order by stream ID, so - // we allocate an ID for the promised stream lazily, when the PUSH_PROMISE - // is written. Once the ID is allocated, we start the request handler. - allocatePromisedID := func() (uint32, error) { - sc.serveG.check() - - // Check this again, just in case. Technically, we might have received - // an updated SETTINGS by the time we got around to writing this frame. - if !sc.pushEnabled { - return 0, http.ErrNotSupported - } - // http://tools.ietf.org/html/rfc7540#section-6.5.2. - if sc.curPushedStreams+1 > sc.clientMaxStreams { - return 0, http2ErrPushLimitReached - } - - // http://tools.ietf.org/html/rfc7540#section-5.1.1. - // Streams initiated by the server MUST use even-numbered identifiers. - // A server that is unable to establish a new stream identifier can send a GOAWAY - // frame so that the client is forced to open a new connection for new streams. - if sc.maxPushPromiseID+2 >= 1<<31 { - sc.startGracefulShutdownInternal() - return 0, http2ErrPushLimitReached - } - sc.maxPushPromiseID += 2 - promisedID := sc.maxPushPromiseID - - // http://tools.ietf.org/html/rfc7540#section-8.2. - // Strictly speaking, the new stream should start in "reserved (local)", then - // transition to "half closed (remote)" after sending the initial HEADERS, but - // we start in "half closed (remote)" for simplicity. - // See further comments at the definition of stateHalfClosedRemote. - promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote) - rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{ - method: msg.method, - scheme: msg.url.Scheme, - authority: msg.url.Host, - path: msg.url.RequestURI(), - header: http2cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE - }) - if err != nil { - // Should not happen, since we've already validated msg.url. - panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) - } - - go sc.runHandler(rw, req, sc.handler.ServeHTTP) - return promisedID, nil + // TODO: write h12Compare test showing whether + // Request.Body is closed by the Transport, + // and in multiple cases: server replies <=299 and >299 + // while still writing request body + cc.mu.Lock() + bodyClosed := cs.reqBodyClosed + cs.reqBodyClosed = true + cc.mu.Unlock() + if !bodyClosed && cs.reqBody != nil { + cs.reqBody.Close() } - sc.writeFrame(http2FrameWriteRequest{ - write: &http2writePushPromise{ - streamID: msg.parent.id, - method: msg.method, - url: msg.url, - h: msg.header, - allocatePromisedID: allocatePromisedID, - }, - stream: msg.parent, - done: msg.done, - }) -} - -// foreachHeaderElement splits v according to the "#rule" construction -// in RFC 7230 section 7 and calls fn for each non-empty element. -func http2foreachHeaderElement(v string, fn func(string)) { - v = textproto.TrimString(v) - if v == "" { - return - } - if !strings.Contains(v, ",") { - fn(v) - return - } - for _, f := range strings.Split(v, ",") { - if f = textproto.TrimString(f); f != "" { - fn(f) + if err != nil && cs.sentEndStream { + // If the connection is closed immediately after the response is read, + // we may be aborted before finishing up here. If the stream was closed + // cleanly on both sides, there is no error. + select { + case <-cs.peerClosed: + err = nil + default: } } -} - -// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 -var http2connHeaders = []string{ - "Connection", - "Keep-Alive", - "Proxy-Connection", - "Transfer-Encoding", - "Upgrade", -} - -// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, -// per RFC 7540 Section 8.1.2.2. -// The returned error is reported to users. -func http2checkValidHTTP2RequestHeaders(h http.Header) error { - for _, k := range http2connHeaders { - if _, ok := h[k]; ok { - return fmt.Errorf("request header %q is not valid in HTTP/2", k) + if err != nil { + cs.abortStream(err) // possibly redundant, but harmless + if cs.sentHeaders { + if se, ok := err.(http2StreamError); ok { + if se.Cause != http2errFromPeer { + cc.writeStreamReset(cs.ID, se.Code, err) + } + } else { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) + } } + cs.bufPipe.CloseWithError(err) // no-op if already closed + } else { + if cs.sentHeaders && !cs.sentEndStream { + cc.writeStreamReset(cs.ID, http2ErrCodeNo, nil) + } + cs.bufPipe.CloseWithError(http2errRequestCanceled) } - te := h["Te"] - if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { - return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) + if cs.ID != 0 { + cc.forgetStreamID(cs.ID) } - return nil -} -func http2new400Handler(err error) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - http.Error(w, err.Error(), http.StatusBadRequest) + cc.wmu.Lock() + werr := cc.werr + cc.wmu.Unlock() + if werr != nil { + cc.Close() } + + close(cs.donec) } -// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives -// disabled. See comments on h1ServerShutdownChan above for why -// the code is written this way. -func http2h1ServerKeepAlivesDisabled(hs *http.Server) bool { - var x interface{} = hs - type I interface { - doKeepAlives() bool - } - if hs, ok := x.(I); ok { - return !hs.doKeepAlives() +// awaitOpenSlotForStream waits until len(streams) < maxConcurrentStreams. +// Must hold cc.mu. +func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) error { + for { + cc.lastActive = time.Now() + if cc.closed || !cc.canTakeNewRequestLocked() { + return http2errClientConnUnusable + } + cc.lastIdle = time.Time{} + if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) { + return nil + } + cc.pendingRequests++ + cc.cond.Wait() + cc.pendingRequests-- + select { + case <-cs.abort: + return cs.abortErr + default: + } } - return false } -func (sc *http2serverConn) countError(name string, err error) error { - if sc == nil || sc.srv == nil { - return err - } - f := sc.srv.CountError - if f == nil { - return err - } - var typ string - var code http2ErrCode - switch e := err.(type) { - case http2ConnectionError: - typ = "conn" - code = http2ErrCode(e) - case http2StreamError: - typ = "stream" - code = http2ErrCode(e.Code) - default: - return err - } - codeStr := http2errCodeName[code] - if codeStr == "" { - codeStr = strconv.Itoa(int(code)) +// requires cc.wmu be held +func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize int, hdrs []byte) error { + first := true // first frame written (HEADERS is first, then CONTINUATION) + for len(hdrs) > 0 && cc.werr == nil { + chunk := hdrs + if len(chunk) > maxFrameSize { + chunk = chunk[:maxFrameSize] + } + hdrs = hdrs[len(chunk):] + endHeaders := len(hdrs) == 0 + if first { + cc.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: streamID, + BlockFragment: chunk, + EndStream: endStream, + EndHeaders: endHeaders, + }) + first = false + } else { + cc.fr.WriteContinuation(streamID, endHeaders, chunk) + } } - f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name)) - return err + cc.bw.Flush() + return cc.werr } -const ( - // transportDefaultConnFlow is how many connection-level flow control - // tokens we give the server at start-up, past the default 64k. - http2transportDefaultConnFlow = 1 << 30 - - // transportDefaultStreamFlow is how many stream-level flow - // control tokens we announce to the peer, and how many bytes - // we buffer per stream. - http2transportDefaultStreamFlow = 4 << 20 - - // transportDefaultStreamMinRefresh is the minimum number of bytes we'll send - // a stream-level WINDOW_UPDATE for at a time. - http2transportDefaultStreamMinRefresh = 4 << 10 +// internal error values; they don't escape to callers +var ( + // abort request body write; don't send cancel + http2errStopReqBodyWrite = errors.New("http2: aborting request body write") - // initialMaxConcurrentStreams is a connections maxConcurrentStreams until - // it's received servers initial SETTINGS frame, which corresponds with the - // spec's minimum recommended value. - http2initialMaxConcurrentStreams = 100 + // abort request body write, but send stream reset of cancel. + http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request") - // defaultMaxConcurrentStreams is a connections default maxConcurrentStreams - // if the server doesn't include one in its initial SETTINGS frame. - http2defaultMaxConcurrentStreams = 1000 + http2errReqBodyTooLong = errors.New("http2: request body larger than specified content length") ) -// Transport is an HTTP/2 Transport. +// frameScratchBufferLen returns the length of a buffer to use for +// outgoing request bodies to read/write to/from. // -// A Transport internally caches connections to servers. It is safe -// for concurrent use by multiple goroutines. -type http2Transport struct { - // DialTLS specifies an optional dial function for creating - // TLS connections for requests. - // - // If DialTLS is nil, tls.Dial is used. - // - // If the returned net.Conn has a ConnectionState method like tls.Conn, - // it will be used to set http.Response.TLS. - DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error) - - // TLSClientConfig specifies the TLS configuration to use with - // tls.Client. If nil, the default configuration is used. - TLSClientConfig *tls.Config +// It returns max(1, min(peer's advertised max frame size, +// Request.ContentLength+1, 512KB)). +func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int { + const max = 512 << 10 + n := int64(maxFrameSize) + if n > max { + n = max + } + if cl := cs.reqBodyContentLength; cl != -1 && cl+1 < n { + // Add an extra byte past the declared content-length to + // give the caller's Request.Body io.textprotoReader a chance to + // give us more bytes than they declared, so we can catch it + // early. + n = cl + 1 + } + if n < 1 { + return 1 + } + return int(n) // doesn't truncate; max is 512K +} - // ConnPool optionally specifies an alternate connection pool to use. - // If nil, the default is used. - ConnPool http2ClientConnPool +var http2bufPool sync.Pool // of *[]byte - // DisableCompression, if true, prevents the Transport from - // requesting compression with an "Accept-Encoding: gzip" - // request header when the Request contains no existing - // Accept-Encoding value. If the Transport requests gzip on - // its own and gets a gzipped response, it's transparently - // decoded in the Response.Body. However, if the user - // explicitly requested gzip it is not automatically - // uncompressed. - DisableCompression bool +func (cs *http2clientStream) writeRequestBody(req *http.Request, dump *dumper) (err error) { + cc := cs.cc + body := cs.reqBody + sentEnd := false // whether we sent the final DATA frame w/ END_STREAM - // AllowHTTP, if true, permits HTTP/2 requests using the insecure, - // plain-text "http" scheme. Note that this does not enable h2c support. - AllowHTTP bool + hasTrailers := req.Trailer != nil + remainLen := cs.reqBodyContentLength + hasContentLen := remainLen != -1 - // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to - // send in the initial settings frame. It is how many bytes - // of response headers are allowed. Unlike the http2 spec, zero here - // means to use a default limit (currently 10MB). If you actually - // want to advertise an unlimited value to the peer, Transport - // interprets the highest possible value here (0xffffffff or 1<<32-1) - // to mean no limit. - MaxHeaderListSize uint32 + cc.mu.Lock() + maxFrameSize := int(cc.maxFrameSize) + cc.mu.Unlock() - // StrictMaxConcurrentStreams controls whether the server's - // SETTINGS_MAX_CONCURRENT_STREAMS should be respected - // globally. If false, new TCP connections are created to the - // server as needed to keep each under the per-connection - // SETTINGS_MAX_CONCURRENT_STREAMS limit. If true, the - // server's SETTINGS_MAX_CONCURRENT_STREAMS is interpreted as - // a global limit and callers of RoundTrip block when needed, - // waiting for their turn. - StrictMaxConcurrentStreams bool - - // ReadIdleTimeout is the timeout after which a health check using ping - // frame will be carried out if no frame is received on the connection. - // Note that a ping response will is considered a received frame, so if - // there is no other traffic on the connection, the health check will - // be performed every ReadIdleTimeout interval. - // If zero, no health check is performed. - ReadIdleTimeout time.Duration - - // PingTimeout is the timeout after which the connection will be closed - // if a response to Ping is not received. - // Defaults to 15s. - PingTimeout time.Duration - - // WriteByteTimeout is the timeout after which the connection will be - // closed no data can be written to it. The timeout begins when data is - // available to write, and is extended whenever any bytes are written. - WriteByteTimeout time.Duration - - // CountError, if non-nil, is called on HTTP/2 transport errors. - // It's intended to increment a metric for monitoring, such - // as an expvar or Prometheus metric. - // The errType consists of only ASCII word characters. - CountError func(errType string) - - // t1, if non-nil, is the standard library Transport using - // this transport. Its settings are used (but not its - // RoundTrip method, etc). - t1 *Transport - - connPoolOnce sync.Once - connPoolOrDef http2ClientConnPool // non-nil version of ConnPool -} - -func (t *http2Transport) maxHeaderListSize() uint32 { - if t.MaxHeaderListSize == 0 { - return 10 << 20 - } - if t.MaxHeaderListSize == 0xffffffff { - return 0 - } - return t.MaxHeaderListSize -} - -func (t *http2Transport) disableCompression() bool { - return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) -} - -func (t *http2Transport) pingTimeout() time.Duration { - if t.PingTimeout == 0 { - return 15 * time.Second - } - return t.PingTimeout - -} - -// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. -// It returns an error if t1 has already been HTTP/2-enabled. -// -// Use ConfigureTransports instead to configure the HTTP/2 Transport. -func http2ConfigureTransport(t1 *Transport) error { - _, err := http2ConfigureTransports(t1) - return err -} - -// ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2. -// It returns a new HTTP/2 Transport for further configuration. -// It returns an error if t1 has already been HTTP/2-enabled. -func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { - connPool := new(http2clientConnPool) - t2 := &http2Transport{ - ConnPool: http2noDialClientConnPool{connPool}, - t1: t1, - } - connPool.t = t2 - if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil { - return nil, err - } - if t1.TLSClientConfig == nil { - t1.TLSClientConfig = new(tls.Config) - } - if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { - t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) - } - if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { - t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") - } - upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper { - addr := http2authorityAddr("https", authority) - if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { - go c.Close() - return http2erringRoundTripper{err} - } else if !used { - // Turns out we don't need this c. - // For example, two goroutines made requests to the same host - // at the same time, both kicking off TCP dials. (since protocol - // was unknown) - go c.Close() - } - return t2 - } - if m := t1.TLSNextProto; len(m) == 0 { - t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{ - "h2": upgradeFn, - } - } else { - m["h2"] = upgradeFn - } - return t2, nil -} - -func (t *http2Transport) connPool() http2ClientConnPool { - t.connPoolOnce.Do(t.initConnPool) - return t.connPoolOrDef -} - -func (t *http2Transport) initConnPool() { - if t.ConnPool != nil { - t.connPoolOrDef = t.ConnPool - } else { - t.connPoolOrDef = &http2clientConnPool{t: t} - } -} - -// ClientConn is the state of a single HTTP/2 client connection to an -// HTTP/2 server. -type http2ClientConn struct { - currentRequest *http.Request - t *http2Transport - tconn net.Conn // usually *tls.Conn, except specialized impls - tlsState *tls.ConnectionState // nil only for specialized impls - reused uint32 // whether conn is being reused; atomic - singleUse bool // whether being used for a single http.Request - getConnCalled bool // used by clientConnPool - - // readLoop goroutine fields: - readerDone chan struct{} // closed on error - readerErr error // set before readerDone is closed - - idleTimeout time.Duration // or 0 for never - idleTimer *time.Timer - - mu sync.Mutex // guards following - cond *sync.Cond // hold mu; broadcast on flow/closed changes - flow http2flow // our conn-level flow control quota (cs.flow is per stream) - inflow http2flow // peer's conn-level flow control - doNotReuse bool // whether conn is marked to not be reused for any future requests - closing bool - closed bool - seenSettings bool // true if we've seen a settings frame, false otherwise - wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back - goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received - goAwayDebug string // goAway frame's debug data, retained as a string - streams map[uint32]*http2clientStream // client-initiated - streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip - nextStreamID uint32 - pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams - pings map[[8]byte]chan struct{} // in flight ping data to notification channel - br *bufio.Reader - lastActive time.Time - lastIdle time.Time // time last idle - // Settings from peer: (also guarded by wmu) - maxFrameSize uint32 - maxConcurrentStreams uint32 - peerMaxHeaderListSize uint64 - initialWindowSize uint32 - - // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests. - // Write to reqHeaderMu to lock it, read from it to unlock. - // Lock reqmu BEFORE mu or wmu. - reqHeaderMu chan struct{} - - // wmu is held while writing. - // Acquire BEFORE mu when holding both, to avoid blocking mu on network writes. - // Only acquire both at the same time when changing peer settings. - wmu sync.Mutex - bw *bufio.Writer - fr *http2Framer - werr error // first write error that has occurred - hbuf bytes.Buffer // HPACK encoder writes into this - henc *hpack.Encoder -} - -// clientStream is the state for a single HTTP/2 stream. One of these -// is created for each Transport.RoundTrip call. -type http2clientStream struct { - cc *http2ClientConn - - // Fields of Request that we may access even after the response body is closed. - ctx context.Context - reqCancel <-chan struct{} - - trace *httptrace.ClientTrace // or nil - ID uint32 - bufPipe http2pipe // buffered pipe with the flow-controlled response payload - requestedGzip bool - isHead bool - - abortOnce sync.Once - abort chan struct{} // closed to signal stream should end immediately - abortErr error // set if abort is closed - - peerClosed chan struct{} // closed when the peer sends an END_STREAM flag - donec chan struct{} // closed after the stream is in the closed state - on100 chan struct{} // buffered; written to if a 100 is received - - respHeaderRecv chan struct{} // closed when headers are received - res *http.Response // set if respHeaderRecv is closed - - flow http2flow // guarded by cc.mu - inflow http2flow // guarded by cc.mu - bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read - readErr error // sticky read error; owned by transportResponseBody.Read - - reqBody io.ReadCloser - reqBodyContentLength int64 // -1 means unknown - reqBodyClosed bool // body has been closed; guarded by cc.mu - - // owned by writeRequest: - sentEndStream bool // sent an END_STREAM flag to the peer - sentHeaders bool - - // owned by clientConnReadLoop: - firstByte bool // got the first response byte - pastHeaders bool // got first MetaHeadersFrame (actual headers) - pastTrailers bool // got optional second MetaHeadersFrame (trailers) - num1xx uint8 // number of 1xx responses seen - readClosed bool // peer sent an END_STREAM flag - readAborted bool // read loop reset the stream - - trailer http.Header // accumulated trailers - resTrailer *http.Header // client's Response.Trailer -} - -var http2got1xxFuncForTests func(int, textproto.MIMEHeader) error - -// get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func, -// if any. It returns nil if not set or if the Go version is too old. -func (cs *http2clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error { - if fn := http2got1xxFuncForTests; fn != nil { - return fn - } - return http2traceGot1xxResponseFunc(cs.trace) -} - -func (cs *http2clientStream) abortStream(err error) { - cs.cc.mu.Lock() - defer cs.cc.mu.Unlock() - cs.abortStreamLocked(err) -} - -func (cs *http2clientStream) abortStreamLocked(err error) { - cs.abortOnce.Do(func() { - cs.abortErr = err - close(cs.abort) - }) - if cs.reqBody != nil && !cs.reqBodyClosed { - cs.reqBody.Close() - cs.reqBodyClosed = true - } - // TODO(dneil): Clean up tests where cs.cc.cond is nil. - if cs.cc.cond != nil { - // Wake up writeRequestBody if it is waiting on flow control. - cs.cc.cond.Broadcast() - } -} - -func (cs *http2clientStream) abortRequestBodyWrite() { - cc := cs.cc - cc.mu.Lock() - defer cc.mu.Unlock() - if cs.reqBody != nil && !cs.reqBodyClosed { - cs.reqBody.Close() - cs.reqBodyClosed = true - cc.cond.Broadcast() - } -} - -type http2stickyErrWriter struct { - conn net.Conn - timeout time.Duration - err *error -} - -func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) { - if *sew.err != nil { - return 0, *sew.err - } - for { - if sew.timeout != 0 { - sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout)) - } - nn, err := sew.conn.Write(p[n:]) - n += nn - if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) { - // Keep extending the deadline so long as we're making progress. - continue - } - if sew.timeout != 0 { - sew.conn.SetWriteDeadline(time.Time{}) - } - *sew.err = err - return n, err - } -} - -// noCachedConnError is the concrete type of ErrNoCachedConn, which -// needs to be detected by net/http regardless of whether it's its -// bundled version (in h2_bundle.go with a rewritten type name) or -// from a user's x/net/http2. As such, as it has a unique method name -// (IsHTTP2NoCachedConnError) that net/http sniffs for via func -// isNoCachedConnError. -type http2noCachedConnError struct{} - -func (http2noCachedConnError) IsHTTP2NoCachedConnError() {} - -func (http2noCachedConnError) Error() string { return "http2: no cached connection was available" } - -// isNoCachedConnError reports whether err is of type noCachedConnError -// or its equivalent renamed type in net/http2's h2_bundle.go. Both types -// may coexist in the same running program. -func http2isNoCachedConnError(err error) bool { - _, ok := err.(interface{ IsHTTP2NoCachedConnError() }) - return ok -} - -var http2ErrNoCachedConn error = http2noCachedConnError{} - -// RoundTripOpt are options for the Transport.RoundTripOpt method. -type http2RoundTripOpt struct { - // OnlyCachedConn controls whether RoundTripOpt may - // create a new TCP connection. If set true and - // no cached connection is available, RoundTripOpt - // will return ErrNoCachedConn. - OnlyCachedConn bool -} - -func (t *http2Transport) RoundTrip(req *http.Request) (*http.Response, error) { - return t.RoundTripOpt(req, http2RoundTripOpt{}) -} - -// authorityAddr returns a given authority (a host/IP, or host:port / ip:port) -// and returns a host:port. The port 443 is added if needed. -func http2authorityAddr(scheme string, authority string) (addr string) { - host, port, err := net.SplitHostPort(authority) - if err != nil { // authority didn't have a port - port = "443" - if scheme == "http" { - port = "80" - } - host = authority - } - if a, err := idna.ToASCII(host); err == nil { - host = a - } - // IPv6 address literal, without a port: - if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { - return host + ":" + port - } - return net.JoinHostPort(host, port) -} - -// RoundTripOpt is like RoundTrip, but takes options. -func (t *http2Transport) RoundTripOpt(req *http.Request, opt http2RoundTripOpt) (*http.Response, error) { - if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { - return nil, errors.New("http2: unsupported scheme") - } - - addr := http2authorityAddr(req.URL.Scheme, req.URL.Host) - for retry := 0; ; retry++ { - cc, err := t.connPool().GetClientConn(req, addr) - if err != nil { - t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) - return nil, err - } - reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1) - http2traceGotConn(req, cc, reused) - res, err := cc.RoundTrip(req) - if err != nil && retry <= 6 { - if req, err = http2shouldRetryRequest(req, err); err == nil { - // After the first retry, do exponential backoff with 10% jitter. - if retry == 0 { - continue - } - backoff := float64(uint(1) << (uint(retry) - 1)) - backoff += backoff * (0.1 * mathrand.Float64()) - select { - case <-time.After(time.Second * time.Duration(backoff)): - continue - case <-req.Context().Done(): - err = req.Context().Err() - } - } - } - if err != nil { - t.vlogf("RoundTrip failure: %v", err) - return nil, err - } - return res, nil - } -} - -// CloseIdleConnections closes any connections which were previously -// connected from previous requests but are now sitting idle. -// It does not interrupt any connections currently in use. -func (t *http2Transport) CloseIdleConnections() { - if cp, ok := t.connPool().(http2clientConnPoolIdleCloser); ok { - cp.closeIdleConnections() - } -} - -var ( - http2errClientConnClosed = errors.New("http2: client conn is closed") - http2errClientConnUnusable = errors.New("http2: client conn not usable") - http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") -) - -// shouldRetryRequest is called by RoundTrip when a request fails to get -// response headers. It is always called with a non-nil error. -// It returns either a request to retry (either the same request, or a -// modified clone), or an error if the request can't be replayed. -func http2shouldRetryRequest(req *http.Request, err error) (*http.Request, error) { - if !http2canRetryError(err) { - return nil, err - } - // If the Body is nil (or http.NoBody), it's safe to reuse - // this request and its Body. - if req.Body == nil || req.Body == http.NoBody { - return req, nil - } - - // If the request body can be reset back to its original - // state via the optional req.GetBody, do that. - if req.GetBody != nil { - body, err := req.GetBody() - if err != nil { - return nil, err - } - newReq := *req - newReq.Body = body - return &newReq, nil - } - - // The Request.Body can't reset back to the beginning, but we - // don't seem to have started to read from it yet, so reuse - // the request directly. - if err == http2errClientConnUnusable { - return req, nil - } - - return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err) -} - -func http2canRetryError(err error) bool { - if err == http2errClientConnUnusable || err == http2errClientConnGotGoAway { - return true - } - if se, ok := err.(http2StreamError); ok { - if se.Code == http2ErrCodeProtocol && se.Cause == http2errFromPeer { - // See golang/go#47635, golang/go#42777 - return true - } - return se.Code == http2ErrCodeRefusedStream - } - return false -} - -func (t *http2Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*http2ClientConn, error) { - host, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host)) - if err != nil { - return nil, err - } - return t.newClientConn(tconn, singleUse) -} - -func (t *http2Transport) newTLSConfig(host string) *tls.Config { - cfg := new(tls.Config) - if t.TLSClientConfig != nil { - *cfg = *t.TLSClientConfig.Clone() - } - if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { - cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) - } - if cfg.ServerName == "" { - cfg.ServerName = host - } - return cfg -} - -func (t *http2Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) { - if t.DialTLS != nil { - return t.DialTLS - } - return func(network, addr string, cfg *tls.Config) (net.Conn, error) { - tlsCn, err := t.dialTLSWithContext(ctx, network, addr, cfg) - if err != nil { - return nil, err - } - state := tlsCn.ConnectionState() - if p := state.NegotiatedProtocol; p != http2NextProtoTLS { - return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS) - } - if !state.NegotiatedProtocolIsMutual { - return nil, errors.New("http2: could not negotiate protocol mutually") - } - return tlsCn, nil - } -} - -// disableKeepAlives reports whether connections should be closed as -// soon as possible after handling the first request. -func (t *http2Transport) disableKeepAlives() bool { - return t.t1 != nil && t.t1.DisableKeepAlives -} - -func (t *http2Transport) expectContinueTimeout() time.Duration { - if t.t1 == nil { - return 0 - } - return t.t1.ExpectContinueTimeout -} - -func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { - return t.newClientConn(c, t.disableKeepAlives()) -} - -func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) { - cc := &http2ClientConn{ - t: t, - tconn: c, - readerDone: make(chan struct{}), - nextStreamID: 1, - maxFrameSize: 16 << 10, // spec default - initialWindowSize: 65535, // spec default - maxConcurrentStreams: http2initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings. - peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. - streams: make(map[uint32]*http2clientStream), - singleUse: singleUse, - wantSettingsAck: true, - pings: make(map[[8]byte]chan struct{}), - reqHeaderMu: make(chan struct{}, 1), - } - if d := t.idleConnTimeout(); d != 0 { - cc.idleTimeout = d - cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) - } - if http2VerboseLogs { - t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) - } - - cc.cond = sync.NewCond(&cc.mu) - cc.flow.add(int32(http2initialWindowSize)) - - // TODO: adjust this writer size to account for frame size + - // MTU + crypto/tls record padding. - cc.bw = bufio.NewWriter(http2stickyErrWriter{ - conn: c, - timeout: t.WriteByteTimeout, - err: &cc.werr, - }) - cc.br = bufio.NewReader(c) - cc.fr = http2NewFramer(cc.bw, cc.br) - cc.fr.cc = cc // for dump single request - if t.CountError != nil { - cc.fr.countError = t.CountError - } - cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) - cc.fr.MaxHeaderListSize = t.maxHeaderListSize() - - // TODO: SetMaxDynamicTableSize, SetMaxDynamicTableSizeLimit on - // henc in response to SETTINGS frames? - cc.henc = hpack.NewEncoder(&cc.hbuf) - - if t.AllowHTTP { - cc.nextStreamID = 3 - } - - if cs, ok := c.(http2connectionStater); ok { - state := cs.ConnectionState() - cc.tlsState = &state - } - - initialSettings := []http2Setting{ - {ID: http2SettingEnablePush, Val: 0}, - {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, - } - if max := t.maxHeaderListSize(); max != 0 { - initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) - } - - cc.bw.Write(http2clientPreface) - cc.fr.WriteSettings(initialSettings...) - cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow) - cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize) - cc.bw.Flush() - if cc.werr != nil { - cc.Close() - return nil, cc.werr - } - - go cc.readLoop() - return cc, nil -} - -func (cc *http2ClientConn) healthCheck() { - pingTimeout := cc.t.pingTimeout() - // We don't need to periodically ping in the health check, because the readLoop of ClientConn will - // trigger the healthCheck again if there is no frame received. - ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) - defer cancel() - err := cc.Ping(ctx) - if err != nil { - cc.closeForLostPing() - cc.t.connPool().MarkDead(cc) - return - } -} - -// SetDoNotReuse marks cc as not reusable for future HTTP requests. -func (cc *http2ClientConn) SetDoNotReuse() { - cc.mu.Lock() - defer cc.mu.Unlock() - cc.doNotReuse = true -} - -func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { - cc.mu.Lock() - defer cc.mu.Unlock() - - old := cc.goAway - cc.goAway = f - - // Merge the previous and current GoAway error frames. - if cc.goAwayDebug == "" { - cc.goAwayDebug = string(f.DebugData()) - } - if old != nil && old.ErrCode != http2ErrCodeNo { - cc.goAway.ErrCode = old.ErrCode - } - last := f.LastStreamID - for streamID, cs := range cc.streams { - if streamID > last { - cs.abortStreamLocked(http2errClientConnGotGoAway) - } - } -} - -// CanTakeNewRequest reports whether the connection can take a new request, -// meaning it has not been closed or received or sent a GOAWAY. -// -// If the caller is going to immediately make a new request on this -// connection, use ReserveNewRequest instead. -func (cc *http2ClientConn) CanTakeNewRequest() bool { - cc.mu.Lock() - defer cc.mu.Unlock() - return cc.canTakeNewRequestLocked() -} - -// ReserveNewRequest is like CanTakeNewRequest but also reserves a -// concurrent stream in cc. The reservation is decremented on the -// next call to RoundTrip. -func (cc *http2ClientConn) ReserveNewRequest() bool { - cc.mu.Lock() - defer cc.mu.Unlock() - if st := cc.idleStateLocked(); !st.canTakeNewRequest { - return false - } - cc.streamsReserved++ - return true -} - -// ClientConnState describes the state of a ClientConn. -type http2ClientConnState struct { - // Closed is whether the connection is closed. - Closed bool - - // Closing is whether the connection is in the process of - // closing. It may be closing due to shutdown, being a - // single-use connection, being marked as DoNotReuse, or - // having received a GOAWAY frame. - Closing bool - - // StreamsActive is how many streams are active. - StreamsActive int - - // StreamsReserved is how many streams have been reserved via - // ClientConn.ReserveNewRequest. - StreamsReserved int - - // StreamsPending is how many requests have been sent in excess - // of the peer's advertised MaxConcurrentStreams setting and - // are waiting for other streams to complete. - StreamsPending int - - // MaxConcurrentStreams is how many concurrent streams the - // peer advertised as acceptable. Zero means no SETTINGS - // frame has been received yet. - MaxConcurrentStreams uint32 - - // LastIdle, if non-zero, is when the connection last - // transitioned to idle state. - LastIdle time.Time -} - -// State returns a snapshot of cc's state. -func (cc *http2ClientConn) State() http2ClientConnState { - cc.wmu.Lock() - maxConcurrent := cc.maxConcurrentStreams - if !cc.seenSettings { - maxConcurrent = 0 - } - cc.wmu.Unlock() - - cc.mu.Lock() - defer cc.mu.Unlock() - return http2ClientConnState{ - Closed: cc.closed, - Closing: cc.closing || cc.singleUse || cc.doNotReuse || cc.goAway != nil, - StreamsActive: len(cc.streams), - StreamsReserved: cc.streamsReserved, - StreamsPending: cc.pendingRequests, - LastIdle: cc.lastIdle, - MaxConcurrentStreams: maxConcurrent, - } -} - -// clientConnIdleState describes the suitability of a client -// connection to initiate a new RoundTrip request. -type http2clientConnIdleState struct { - canTakeNewRequest bool -} - -func (cc *http2ClientConn) idleState() http2clientConnIdleState { - cc.mu.Lock() - defer cc.mu.Unlock() - return cc.idleStateLocked() -} - -func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) { - if cc.singleUse && cc.nextStreamID > 1 { - return - } - var maxConcurrentOkay bool - if cc.t.StrictMaxConcurrentStreams { - // We'll tell the caller we can take a new request to - // prevent the caller from dialing a new TCP - // connection, but then we'll block later before - // writing it. - maxConcurrentOkay = true - } else { - maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams) - } - - st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay && - !cc.doNotReuse && - int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 && - !cc.tooIdleLocked() - return -} - -func (cc *http2ClientConn) canTakeNewRequestLocked() bool { - st := cc.idleStateLocked() - return st.canTakeNewRequest -} - -// tooIdleLocked reports whether this connection has been been sitting idle -// for too much wall time. -func (cc *http2ClientConn) tooIdleLocked() bool { - // The Round(0) strips the monontonic clock reading so the - // times are compared based on their wall time. We don't want - // to reuse a connection that's been sitting idle during - // VM/laptop suspend if monotonic time was also frozen. - return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && time.Since(cc.lastIdle.Round(0)) > cc.idleTimeout -} - -// onIdleTimeout is called from a time.AfterFunc goroutine. It will -// only be called when we're idle, but because we're coming from a new -// goroutine, there could be a new request coming in at the same time, -// so this simply calls the synchronized closeIfIdle to shut down this -// connection. The timer could just call closeIfIdle, but this is more -// clear. -func (cc *http2ClientConn) onIdleTimeout() { - cc.closeIfIdle() -} - -func (cc *http2ClientConn) closeIfIdle() { - cc.mu.Lock() - if len(cc.streams) > 0 || cc.streamsReserved > 0 { - cc.mu.Unlock() - return - } - cc.closed = true - nextID := cc.nextStreamID - // TODO: do clients send GOAWAY too? maybe? Just Close: - cc.mu.Unlock() - - if http2VerboseLogs { - cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, nextID-2) - } - cc.tconn.Close() -} - -func (cc *http2ClientConn) isDoNotReuseAndIdle() bool { - cc.mu.Lock() - defer cc.mu.Unlock() - return cc.doNotReuse && len(cc.streams) == 0 -} - -var http2shutdownEnterWaitStateHook = func() {} - -// Shutdown gracefully closes the client connection, waiting for running streams to complete. -func (cc *http2ClientConn) Shutdown(ctx context.Context) error { - if err := cc.sendGoAway(); err != nil { - return err - } - // Wait for all in-flight streams to complete or connection to close - done := make(chan error, 1) - cancelled := false // guarded by cc.mu - go func() { - cc.mu.Lock() - defer cc.mu.Unlock() - for { - if len(cc.streams) == 0 || cc.closed { - cc.closed = true - done <- cc.tconn.Close() - break - } - if cancelled { - break - } - cc.cond.Wait() - } - }() - http2shutdownEnterWaitStateHook() - select { - case err := <-done: - return err - case <-ctx.Done(): - cc.mu.Lock() - // Free the goroutine above - cancelled = true - cc.cond.Broadcast() - cc.mu.Unlock() - return ctx.Err() - } -} - -func (cc *http2ClientConn) sendGoAway() error { - cc.mu.Lock() - closing := cc.closing - cc.closing = true - maxStreamID := cc.nextStreamID - cc.mu.Unlock() - if closing { - // GOAWAY sent already - return nil - } - - cc.wmu.Lock() - defer cc.wmu.Unlock() - // Send a graceful shutdown frame to server - if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil { - return err - } - if err := cc.bw.Flush(); err != nil { - return err - } - // Prevent new requests - return nil -} - -// closes the client connection immediately. In-flight requests are interrupted. -// err is sent to streams. -func (cc *http2ClientConn) closeForError(err error) error { - cc.mu.Lock() - cc.closed = true - for _, cs := range cc.streams { - cs.abortStreamLocked(err) - } - defer cc.cond.Broadcast() - defer cc.mu.Unlock() - return cc.tconn.Close() -} - -// Close closes the client connection immediately. -// -// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. -func (cc *http2ClientConn) Close() error { - err := errors.New("http2: client connection force closed via ClientConn.Close") - return cc.closeForError(err) -} - -// closes the client connection immediately. In-flight requests are interrupted. -func (cc *http2ClientConn) closeForLostPing() error { - err := errors.New("http2: client connection lost") - if f := cc.t.CountError; f != nil { - f("conn_close_lost_ping") - } - return cc.closeForError(err) -} - -// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not -// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. -var http2errRequestCanceled = errors.New("net/http: request canceled") - -func http2commaSeparatedTrailers(req *http.Request) (string, error) { - keys := make([]string, 0, len(req.Trailer)) - for k := range req.Trailer { - k = http.CanonicalHeaderKey(k) - switch k { - case "Transfer-Encoding", "Trailer", "Content-Length": - return "", fmt.Errorf("invalid Trailer key %q", k) - } - keys = append(keys, k) - } - if len(keys) > 0 { - sort.Strings(keys) - return strings.Join(keys, ","), nil - } - return "", nil -} - -func (cc *http2ClientConn) responseHeaderTimeout() time.Duration { - if cc.t.t1 != nil { - return cc.t.t1.ResponseHeaderTimeout - } - // No way to do this (yet?) with just an http2.Transport. Probably - // no need. Request.Cancel this is the new way. We only need to support - // this for compatibility with the old http.Transport fields when - // we're doing transparent http2. - return 0 -} - -// checkConnHeaders checks whether req has any invalid connection-level headers. -// per RFC 7540 section 8.1.2.2: Connection-Specific Header Fields. -// Certain headers are special-cased as okay but not transmitted later. -func http2checkConnHeaders(req *http.Request) error { - if v := req.Header.Get("Upgrade"); v != "" { - return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"]) - } - if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { - return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv) - } - if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !http2asciiEqualFold(vv[0], "close") && !http2asciiEqualFold(vv[0], "keep-alive")) { - return fmt.Errorf("http2: invalid Connection request header: %q", vv) - } - return nil -} - -// actualContentLength returns a sanitized version of -// req.ContentLength, where 0 actually means zero (not unknown) and -1 -// means unknown. -func http2actualContentLength(req *http.Request) int64 { - if req.Body == nil || req.Body == http.NoBody { - return 0 - } - if req.ContentLength != 0 { - return req.ContentLength - } - return -1 -} - -func (cc *http2ClientConn) decrStreamReservations() { - cc.mu.Lock() - defer cc.mu.Unlock() - cc.decrStreamReservationsLocked() -} - -func (cc *http2ClientConn) decrStreamReservationsLocked() { - if cc.streamsReserved > 0 { - cc.streamsReserved-- - } -} - -func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { - cc.currentRequest = req - ctx := req.Context() - cs := &http2clientStream{ - cc: cc, - ctx: ctx, - reqCancel: req.Cancel, - isHead: req.Method == "HEAD", - reqBody: req.Body, - reqBodyContentLength: http2actualContentLength(req), - trace: httptrace.ContextClientTrace(ctx), - peerClosed: make(chan struct{}), - abort: make(chan struct{}), - respHeaderRecv: make(chan struct{}), - donec: make(chan struct{}), - } - go cs.doRequest(req) - - waitDone := func() error { - select { - case <-cs.donec: - return nil - case <-ctx.Done(): - return ctx.Err() - case <-cs.reqCancel: - return http2errRequestCanceled - } - } - - handleResponseHeaders := func() (*http.Response, error) { - res := cs.res - if res.StatusCode > 299 { - // On error or status code 3xx, 4xx, 5xx, etc abort any - // ongoing write, assuming that the server doesn't care - // about our request body. If the server replied with 1xx or - // 2xx, however, then assume the server DOES potentially - // want our body (e.g. full-duplex streaming: - // golang.org/issue/13444). If it turns out the server - // doesn't, they'll RST_STREAM us soon enough. This is a - // heuristic to avoid adding knobs to Transport. Hopefully - // we can keep it. - cs.abortRequestBodyWrite() - } - res.Request = req - res.TLS = cc.tlsState - if res.Body == http2noBody && http2actualContentLength(req) == 0 { - // If there isn't a request or response body still being - // written, then wait for the stream to be closed before - // RoundTrip returns. - if err := waitDone(); err != nil { - return nil, err - } - } - return res, nil - } - - for { - select { - case <-cs.respHeaderRecv: - return handleResponseHeaders() - case <-cs.abort: - select { - case <-cs.respHeaderRecv: - // If both cs.respHeaderRecv and cs.abort are signaling, - // pick respHeaderRecv. The server probably wrote the - // response and immediately reset the stream. - // golang.org/issue/49645 - return handleResponseHeaders() - default: - waitDone() - return nil, cs.abortErr - } - case <-ctx.Done(): - err := ctx.Err() - cs.abortStream(err) - return nil, err - case <-cs.reqCancel: - cs.abortStream(http2errRequestCanceled) - return nil, http2errRequestCanceled - } - } -} - -// doRequest runs for the duration of the request lifetime. -// -// It sends the request and performs post-request cleanup (closing Request.Body, etc.). -func (cs *http2clientStream) doRequest(req *http.Request) { - err := cs.writeRequest(req) - cs.cleanupWriteRequest(err) -} - -// writeRequest sends a request. -// -// It returns nil after the request is written, the response read, -// and the request stream is half-closed by the peer. -// -// It returns non-nil if the request ends otherwise. -// If the returned error is StreamError, the error Code may be used in resetting the stream. -func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { - cc := cs.cc - ctx := cs.ctx - - if err := http2checkConnHeaders(req); err != nil { - return err - } - - // Acquire the new-request lock by writing to reqHeaderMu. - // This lock guards the critical section covering allocating a new stream ID - // (requires mu) and creating the stream (requires wmu). - if cc.reqHeaderMu == nil { - panic("RoundTrip on uninitialized ClientConn") // for tests - } - select { - case cc.reqHeaderMu <- struct{}{}: - case <-cs.reqCancel: - return http2errRequestCanceled - case <-ctx.Done(): - return ctx.Err() - } - - cc.mu.Lock() - if cc.idleTimer != nil { - cc.idleTimer.Stop() - } - cc.decrStreamReservationsLocked() - if err := cc.awaitOpenSlotForStreamLocked(cs); err != nil { - cc.mu.Unlock() - <-cc.reqHeaderMu - return err - } - cc.addStreamLocked(cs) // assigns stream ID - if http2isConnectionCloseRequest(req) { - cc.doNotReuse = true - } - cc.mu.Unlock() - - // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? - if !cc.t.disableCompression() && - req.Header.Get("Accept-Encoding") == "" && - req.Header.Get("Range") == "" && - !cs.isHead { - // Request gzip only, not deflate. Deflate is ambiguous and - // not as universally supported anyway. - // See: https://zlib.net/zlib_faq.html#faq39 - // - // Note that we don't request this for HEAD requests, - // due to a bug in nginx: - // http://trac.nginx.org/nginx/ticket/358 - // https://golang.org/issue/5522 - // - // We don't request gzip if the request is for a range, since - // auto-decoding a portion of a gzipped document will just fail - // anyway. See https://golang.org/issue/8923 - cs.requestedGzip = true - } - - continueTimeout := cc.t.expectContinueTimeout() - if continueTimeout != 0 { - if !httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") { - continueTimeout = 0 - } else { - cs.on100 = make(chan struct{}, 1) - } - } - - dump := getDumperOverride(cs.cc.t.t1.dump, req.Context()) - - // Past this point (where we send request headers), it is possible for - // RoundTrip to return successfully. Since the RoundTrip contract permits - // the caller to "mutate or reuse" the Request after closing the Response's Body, - // we must take care when referencing the Request from here on. - err = cs.encodeAndWriteHeaders(req, dump) - <-cc.reqHeaderMu - if err != nil { - return err - } - - hasBody := cs.reqBodyContentLength != 0 - if !hasBody { - cs.sentEndStream = true - } else { - if continueTimeout != 0 { - http2traceWait100Continue(cs.trace) - timer := time.NewTimer(continueTimeout) - select { - case <-timer.C: - err = nil - case <-cs.on100: - err = nil - case <-cs.abort: - err = cs.abortErr - case <-ctx.Done(): - err = ctx.Err() - case <-cs.reqCancel: - err = http2errRequestCanceled - } - timer.Stop() - if err != nil { - http2traceWroteRequest(cs.trace, err) - return err - } - } - - if err = cs.writeRequestBody(req, dump); err != nil { - if err != http2errStopReqBodyWrite { - http2traceWroteRequest(cs.trace, err) - return err - } - } else { - cs.sentEndStream = true - if dump != nil && dump.RequestBody { - dump.dump([]byte("\r\n\r\n")) - } - } - } - - http2traceWroteRequest(cs.trace, err) - - var respHeaderTimer <-chan time.Time - var respHeaderRecv chan struct{} - if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) - defer timer.Stop() - respHeaderTimer = timer.C - respHeaderRecv = cs.respHeaderRecv - } - // Wait until the peer half-closes its end of the stream, - // or until the request is aborted (via context, error, or otherwise), - // whichever comes first. - for { - select { - case <-cs.peerClosed: - return nil - case <-respHeaderTimer: - return http2errTimeout - case <-respHeaderRecv: - respHeaderRecv = nil - respHeaderTimer = nil // keep waiting for END_STREAM - case <-cs.abort: - return cs.abortErr - case <-ctx.Done(): - return ctx.Err() - case <-cs.reqCancel: - return http2errRequestCanceled - } - } -} - -func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dump *dumper) error { - cc := cs.cc - ctx := cs.ctx - - cc.wmu.Lock() - defer cc.wmu.Unlock() - - // If the request was canceled while waiting for cc.mu, just quit. - select { - case <-cs.abort: - return cs.abortErr - case <-ctx.Done(): - return ctx.Err() - case <-cs.reqCancel: - return http2errRequestCanceled - default: - } - - // Encode headers. - // - // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is - // sent by writeRequestBody below, along with any Trailers, - // again in form HEADERS{1}, CONTINUATION{0,}) - trailers, err := http2commaSeparatedTrailers(req) - if err != nil { - return err - } - hasTrailers := trailers != "" - contentLen := http2actualContentLength(req) - hasBody := contentLen != 0 - hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen, dump) - if err != nil { - return err - } - - // Write the request. - endStream := !hasBody && !hasTrailers - cs.sentHeaders = true - err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) - http2traceWroteHeaders(cs.trace) - return err -} - -// cleanupWriteRequest performs post-request tasks. -// -// If err (the result of writeRequest) is non-nil and the stream is not closed, -// cleanupWriteRequest will send a reset to the peer. -func (cs *http2clientStream) cleanupWriteRequest(err error) { - cc := cs.cc - - if cs.ID == 0 { - // We were canceled before creating the stream, so return our reservation. - cc.decrStreamReservations() - } - - // TODO: write h12Compare test showing whether - // Request.Body is closed by the Transport, - // and in multiple cases: server replies <=299 and >299 - // while still writing request body - cc.mu.Lock() - bodyClosed := cs.reqBodyClosed - cs.reqBodyClosed = true - cc.mu.Unlock() - if !bodyClosed && cs.reqBody != nil { - cs.reqBody.Close() - } - - if err != nil && cs.sentEndStream { - // If the connection is closed immediately after the response is read, - // we may be aborted before finishing up here. If the stream was closed - // cleanly on both sides, there is no error. - select { - case <-cs.peerClosed: - err = nil - default: - } - } - if err != nil { - cs.abortStream(err) // possibly redundant, but harmless - if cs.sentHeaders { - if se, ok := err.(http2StreamError); ok { - if se.Cause != http2errFromPeer { - cc.writeStreamReset(cs.ID, se.Code, err) - } - } else { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) - } - } - cs.bufPipe.CloseWithError(err) // no-op if already closed - } else { - if cs.sentHeaders && !cs.sentEndStream { - cc.writeStreamReset(cs.ID, http2ErrCodeNo, nil) - } - cs.bufPipe.CloseWithError(http2errRequestCanceled) - } - if cs.ID != 0 { - cc.forgetStreamID(cs.ID) - } - - cc.wmu.Lock() - werr := cc.werr - cc.wmu.Unlock() - if werr != nil { - cc.Close() - } - - close(cs.donec) -} - -// awaitOpenSlotForStream waits until len(streams) < maxConcurrentStreams. -// Must hold cc.mu. -func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) error { - for { - cc.lastActive = time.Now() - if cc.closed || !cc.canTakeNewRequestLocked() { - return http2errClientConnUnusable - } - cc.lastIdle = time.Time{} - if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) { - return nil - } - cc.pendingRequests++ - cc.cond.Wait() - cc.pendingRequests-- - select { - case <-cs.abort: - return cs.abortErr - default: - } - } -} - -// requires cc.wmu be held -func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize int, hdrs []byte) error { - first := true // first frame written (HEADERS is first, then CONTINUATION) - for len(hdrs) > 0 && cc.werr == nil { - chunk := hdrs - if len(chunk) > maxFrameSize { - chunk = chunk[:maxFrameSize] - } - hdrs = hdrs[len(chunk):] - endHeaders := len(hdrs) == 0 - if first { - cc.fr.WriteHeaders(http2HeadersFrameParam{ - StreamID: streamID, - BlockFragment: chunk, - EndStream: endStream, - EndHeaders: endHeaders, - }) - first = false - } else { - cc.fr.WriteContinuation(streamID, endHeaders, chunk) - } - } - cc.bw.Flush() - return cc.werr -} - -// internal error values; they don't escape to callers -var ( - // abort request body write; don't send cancel - http2errStopReqBodyWrite = errors.New("http2: aborting request body write") - - // abort request body write, but send stream reset of cancel. - http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request") - - http2errReqBodyTooLong = errors.New("http2: request body larger than specified content length") -) - -// frameScratchBufferLen returns the length of a buffer to use for -// outgoing request bodies to read/write to/from. -// -// It returns max(1, min(peer's advertised max frame size, -// Request.ContentLength+1, 512KB)). -func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int { - const max = 512 << 10 - n := int64(maxFrameSize) - if n > max { - n = max - } - if cl := cs.reqBodyContentLength; cl != -1 && cl+1 < n { - // Add an extra byte past the declared content-length to - // give the caller's Request.Body io.textprotoReader a chance to - // give us more bytes than they declared, so we can catch it - // early. - n = cl + 1 - } - if n < 1 { - return 1 - } - return int(n) // doesn't truncate; max is 512K -} - -var http2bufPool sync.Pool // of *[]byte - -func (cs *http2clientStream) writeRequestBody(req *http.Request, dump *dumper) (err error) { - cc := cs.cc - body := cs.reqBody - sentEnd := false // whether we sent the final DATA frame w/ END_STREAM - - hasTrailers := req.Trailer != nil - remainLen := cs.reqBodyContentLength - hasContentLen := remainLen != -1 - - cc.mu.Lock() - maxFrameSize := int(cc.maxFrameSize) - cc.mu.Unlock() - - // Scratch buffer for reading into & writing from. - scratchLen := cs.frameScratchBufferLen(maxFrameSize) - var buf []byte - if bp, ok := http2bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen { - defer http2bufPool.Put(bp) - buf = *bp - } else { - buf = make([]byte, scratchLen) - defer http2bufPool.Put(&buf) - } - - writeData := cc.fr.WriteData - if dump != nil && dump.RequestBody { - writeData = func(streamID uint32, endStream bool, data []byte) error { - dump.dump(data) - return cc.fr.WriteData(streamID, endStream, data) - } - } - - var sawEOF bool - for !sawEOF { - n, err := body.Read(buf[:len(buf)]) - if hasContentLen { - remainLen -= int64(n) - if remainLen == 0 && err == nil { - // The request body's Content-Length was predeclared and - // we just finished reading it all, but the underlying io.textprotoReader - // returned the final chunk with a nil error (which is one of - // the two valid things a textprotoReader can do at EOF). Because we'd prefer - // to send the END_STREAM bit early, double-check that we're actually - // at EOF. Subsequent reads should return (0, EOF) at this point. - // If either value is different, we return an error in one of two ways below. - var scratch [1]byte - var n1 int - n1, err = body.Read(scratch[:]) - remainLen -= int64(n1) - } - if remainLen < 0 { - err = http2errReqBodyTooLong - return err - } - } - if err != nil { - cc.mu.Lock() - bodyClosed := cs.reqBodyClosed - cc.mu.Unlock() - switch { - case bodyClosed: - return http2errStopReqBodyWrite - case err == io.EOF: - sawEOF = true - err = nil - default: - return err - } - } - - remain := buf[:n] - for len(remain) > 0 && err == nil { - var allowed int32 - allowed, err = cs.awaitFlowControl(len(remain)) - if err != nil { - return err - } - cc.wmu.Lock() - data := remain[:allowed] - remain = remain[allowed:] - sentEnd = sawEOF && len(remain) == 0 && !hasTrailers - err = writeData(cs.ID, sentEnd, data) - if err == nil { - // TODO(bradfitz): this flush is for latency, not bandwidth. - // Most requests won't need this. Make this opt-in or - // opt-out? Use some heuristic on the body type? Nagel-like - // timers? Based on 'n'? Only last chunk of this for loop, - // unless flow control tokens are low? For now, always. - // If we change this, see comment below. - err = cc.bw.Flush() - } - cc.wmu.Unlock() - } - if err != nil { - return err - } - } - - if sentEnd { - // Already sent END_STREAM (which implies we have no - // trailers) and flushed, because currently all - // WriteData frames above get a flush. So we're done. - return nil - } - - // Since the RoundTrip contract permits the caller to "mutate or reuse" - // a request after the Response's Body is closed, verify that this hasn't - // happened before accessing the trailers. - cc.mu.Lock() - trailer := req.Trailer - err = cs.abortErr - cc.mu.Unlock() - if err != nil { - return err - } - - cc.wmu.Lock() - defer cc.wmu.Unlock() - var trls []byte - if len(trailer) > 0 { - trls, err = cc.encodeTrailers(trailer, dump) - if err != nil { - return err - } - } - - // Two ways to send END_STREAM: either with trailers, or - // with an empty DATA frame. - if len(trls) > 0 { - err = cc.writeHeaders(cs.ID, true, maxFrameSize, trls) - } else { - err = cc.fr.WriteData(cs.ID, true, nil) - } - if ferr := cc.bw.Flush(); ferr != nil && err == nil { - err = ferr - } - return err -} - -// awaitFlowControl waits for [1, min(maxBytes, cc.cs.maxFrameSize)] flow -// control tokens from the server. -// It returns either the non-zero number of tokens taken or an error -// if the stream is dead. -func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) { - cc := cs.cc - ctx := cs.ctx - cc.mu.Lock() - defer cc.mu.Unlock() - for { - if cc.closed { - return 0, http2errClientConnClosed - } - if cs.reqBodyClosed { - return 0, http2errStopReqBodyWrite - } - select { - case <-cs.abort: - return 0, cs.abortErr - case <-ctx.Done(): - return 0, ctx.Err() - case <-cs.reqCancel: - return 0, http2errRequestCanceled - default: - } - if a := cs.flow.available(); a > 0 { - take := a - if int(take) > maxBytes { - - take = int32(maxBytes) // can't truncate int; take is int32 - } - if take > int32(cc.maxFrameSize) { - take = int32(cc.maxFrameSize) - } - cs.flow.take(take) - return take, nil - } - cc.cond.Wait() - } -} - -var http2errNilRequestURL = errors.New("http2: Request.URI is nil") - -// requires cc.wmu be held. -func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64, dump *dumper) ([]byte, error) { - cc.hbuf.Reset() - if req.URL == nil { - return nil, http2errNilRequestURL - } - - host := req.Host - if host == "" { - host = req.URL.Host - } - host, err := httpguts.PunycodeHostPort(host) - if err != nil { - return nil, err - } - - var path string - if req.Method != "CONNECT" { - path = req.URL.RequestURI() - if !http2validPseudoPath(path) { - orig := path - path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) - if !http2validPseudoPath(path) { - if req.URL.Opaque != "" { - return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) - } else { - return nil, fmt.Errorf("invalid request :path %q", orig) - } - } - } - } - - // Check for any invalid headers and return an error before we - // potentially pollute our hpack state. (We want to be able to - // continue to reuse the hpack encoder for future requests) - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - return nil, fmt.Errorf("invalid HTTP header name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) - } - } - } - - enumerateHeaders := func(f func(name, value string)) { - // 8.1.2.3 Request Pseudo-Header Fields - // The :path pseudo-header field includes the path and query parts of the - // target URI (the path-absolute production and optionally a '?' character - // followed by the query production (see Sections 3.3 and 3.4 of - // [RFC3986]). - f(":authority", host) - m := req.Method - if m == "" { - m = http.MethodGet - } - f(":method", m) - if req.Method != "CONNECT" { - f(":path", path) - f(":scheme", req.URL.Scheme) - } - if trailers != "" { - f("trailer", trailers) - } - - var didUA bool - for k, vv := range req.Header { - if http2asciiEqualFold(k, "host") || http2asciiEqualFold(k, "content-length") { - // Host is :authority, already sent. - // Content-Length is automatic, set below. - continue - } else if http2asciiEqualFold(k, "connection") || - http2asciiEqualFold(k, "proxy-connection") || - http2asciiEqualFold(k, "transfer-encoding") || - http2asciiEqualFold(k, "upgrade") || - http2asciiEqualFold(k, "keep-alive") { - // Per 8.1.2.2 Connection-Specific Header - // Fields, don't send connection-specific - // fields. We have already checked if any - // are error-worthy so just ignore the rest. - continue - } else if http2asciiEqualFold(k, "user-agent") { - // Match Go's http1 behavior: at most one - // User-Agent. If set to nil or empty string, - // then omit it. Otherwise if not mentioned, - // include the default (below). - didUA = true - if len(vv) < 1 { - continue - } - vv = vv[:1] - if vv[0] == "" { - continue - } - } else if http2asciiEqualFold(k, "cookie") { - // Per 8.1.2.5 To allow for better compression efficiency, the - // Cookie header field MAY be split into separate header fields, - // each with one or more cookie-pairs. - for _, v := range vv { - for { - p := strings.IndexByte(v, ';') - if p < 0 { - break - } - f("cookie", v[:p]) - p++ - // strip space after semicolon if any. - for p+1 <= len(v) && v[p] == ' ' { - p++ - } - v = v[p:] - } - if len(v) > 0 { - f("cookie", v) - } - } - continue - } - - for _, v := range vv { - f(k, v) - } - } - if http2shouldSendReqContentLength(req.Method, contentLength) { - f("content-length", strconv.FormatInt(contentLength, 10)) - } - if addGzipHeader { - f("accept-encoding", "gzip") - } - if !didUA { - f("user-agent", hdrUserAgentValue) - } - } - - // Do a first pass over the headers counting bytes to ensure - // we don't exceed cc.peerMaxHeaderListSize. This is done as a - // separate pass before encoding the headers to prevent - // modifying the hpack state. - hlSize := uint64(0) - enumerateHeaders(func(name, value string) { - hf := hpack.HeaderField{Name: name, Value: value} - hlSize += uint64(hf.Size()) - }) - - if hlSize > cc.peerMaxHeaderListSize { - return nil, http2errRequestHeaderListSize - } - - trace := httptrace.ContextClientTrace(req.Context()) - traceHeaders := http2traceHasWroteHeaderField(trace) - - writeHeader := cc.writeHeader - if dump != nil && dump.RequestHeader { - writeHeader = func(name, value string) { - dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) - cc.writeHeader(name, value) - } - } - - // Header list size is ok. Write the headers. - enumerateHeaders(func(name, value string) { - name, ascii := http2asciiToLower(name) - if !ascii { - // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header - // field names have to be ASCII characters (just as in HTTP/1.x). - return - } - writeHeader(name, value) - if traceHeaders { - http2traceWroteHeaderField(trace, name, value) - } - }) - - if dump != nil && dump.RequestHeader { - dump.dump([]byte("\r\n")) - } - - return cc.hbuf.Bytes(), nil -} - -// shouldSendReqContentLength reports whether the http2.Transport should send -// a "content-length" request header. This logic is basically a copy of the net/http -// transferWriter.shouldSendContentLength. -// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). -// -1 means unknown. -func http2shouldSendReqContentLength(method string, contentLength int64) bool { - if contentLength > 0 { - return true - } - if contentLength < 0 { - return false - } - // For zero bodies, whether we send a content-length depends on the method. - // It also kinda doesn't matter for http2 either way, with END_STREAM. - switch method { - case "POST", "PUT", "PATCH": - return true - default: - return false - } -} - -// requires cc.wmu be held. -func (cc *http2ClientConn) encodeTrailers(trailer http.Header, dump *dumper) ([]byte, error) { - cc.hbuf.Reset() - - hlSize := uint64(0) - for k, vv := range trailer { - for _, v := range vv { - hf := hpack.HeaderField{Name: k, Value: v} - hlSize += uint64(hf.Size()) - } - } - if hlSize > cc.peerMaxHeaderListSize { - return nil, http2errRequestHeaderListSize - } - - writeHeader := cc.writeHeader - if dump != nil && dump.RequestBody { - writeHeader = func(name, value string) { - dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) - cc.writeHeader(name, value) - } - } - - for k, vv := range trailer { - lowKey, ascii := http2asciiToLower(k) - if !ascii { - // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header - // field names have to be ASCII characters (just as in HTTP/1.x). - continue - } - // Transfer-Encoding, etc.. have already been filtered at the - // start of RoundTrip - for _, v := range vv { - writeHeader(lowKey, v) - } - } - return cc.hbuf.Bytes(), nil -} - -func (cc *http2ClientConn) writeHeader(name, value string) { - if http2VerboseLogs { - log.Printf("http2: Transport encoding header %q = %q", name, value) - } - cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) -} - -type http2resAndError struct { - _ http2incomparable - res *http.Response - err error -} - -// requires cc.mu be held. -func (cc *http2ClientConn) addStreamLocked(cs *http2clientStream) { - cs.flow.add(int32(cc.initialWindowSize)) - cs.flow.setConnFlow(&cc.flow) - cs.inflow.add(http2transportDefaultStreamFlow) - cs.inflow.setConnFlow(&cc.inflow) - cs.ID = cc.nextStreamID - cc.nextStreamID += 2 - cc.streams[cs.ID] = cs - if cs.ID == 0 { - panic("assigned stream ID 0") - } -} - -func (cc *http2ClientConn) forgetStreamID(id uint32) { - cc.mu.Lock() - slen := len(cc.streams) - delete(cc.streams, id) - if len(cc.streams) != slen-1 { - panic("forgetting unknown stream id") - } - cc.lastActive = time.Now() - if len(cc.streams) == 0 && cc.idleTimer != nil { - cc.idleTimer.Reset(cc.idleTimeout) - cc.lastIdle = time.Now() - } - // Wake up writeRequestBody via clientStream.awaitFlowControl and - // wake up RoundTrip if there is a pending request. - cc.cond.Broadcast() - - closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() - if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { - if http2VerboseLogs { - cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) - } - cc.closed = true - defer cc.tconn.Close() - } - - cc.mu.Unlock() -} - -// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. -type http2clientConnReadLoop struct { - _ http2incomparable - cc *http2ClientConn -} - -// readLoop runs in its own goroutine and reads and dispatches frames. -func (cc *http2ClientConn) readLoop() { - rl := &http2clientConnReadLoop{cc: cc} - defer rl.cleanup() - cc.readerErr = rl.run() - if ce, ok := cc.readerErr.(http2ConnectionError); ok { - cc.wmu.Lock() - cc.fr.WriteGoAway(0, http2ErrCode(ce), nil) - cc.wmu.Unlock() - } -} - -// GoAwayError is returned by the Transport when the server closes the -// TCP connection after sending a GOAWAY frame. -type http2GoAwayError struct { - LastStreamID uint32 - ErrCode http2ErrCode - DebugData string -} - -func (e http2GoAwayError) Error() string { - return fmt.Sprintf("http2: server sent GOAWAY and closed the connection; LastStreamID=%v, ErrCode=%v, debug=%q", - e.LastStreamID, e.ErrCode, e.DebugData) -} - -func http2isEOFOrNetReadError(err error) bool { - if err == io.EOF { - return true - } - ne, ok := err.(*net.OpError) - return ok && ne.Op == "read" -} - -func (rl *http2clientConnReadLoop) cleanup() { - cc := rl.cc - defer cc.tconn.Close() - defer cc.t.connPool().MarkDead(cc) - defer close(cc.readerDone) - - if cc.idleTimer != nil { - cc.idleTimer.Stop() - } - - // Close any response bodies if the server closes prematurely. - // TODO: also do this if we've written the headers but not - // gotten a response yet. - err := cc.readerErr - cc.mu.Lock() - if cc.goAway != nil && http2isEOFOrNetReadError(err) { - err = http2GoAwayError{ - LastStreamID: cc.goAway.LastStreamID, - ErrCode: cc.goAway.ErrCode, - DebugData: cc.goAwayDebug, - } - } else if err == io.EOF { - err = io.ErrUnexpectedEOF - } - cc.closed = true - for _, cs := range cc.streams { - select { - case <-cs.peerClosed: - // The server closed the stream before closing the conn, - // so no need to interrupt it. - default: - cs.abortStreamLocked(err) - } - } - cc.cond.Broadcast() - cc.mu.Unlock() -} - -// countReadFrameError calls Transport.CountError with a string -// representing err. -func (cc *http2ClientConn) countReadFrameError(err error) { - f := cc.t.CountError - if f == nil || err == nil { - return - } - if ce, ok := err.(http2ConnectionError); ok { - errCode := http2ErrCode(ce) - f(fmt.Sprintf("read_frame_conn_error_%s", errCode.stringToken())) - return - } - if errors.Is(err, io.EOF) { - f("read_frame_eof") - return - } - if errors.Is(err, io.ErrUnexpectedEOF) { - f("read_frame_unexpected_eof") - return - } - if errors.Is(err, http2ErrFrameTooLarge) { - f("read_frame_too_large") - return - } - f("read_frame_other") -} - -func (rl *http2clientConnReadLoop) run() error { - cc := rl.cc - gotSettings := false - readIdleTimeout := cc.t.ReadIdleTimeout - var t *time.Timer - if readIdleTimeout != 0 { - t = time.AfterFunc(readIdleTimeout, cc.healthCheck) - defer t.Stop() - } - for { - f, err := cc.fr.ReadFrame() - if t != nil { - t.Reset(readIdleTimeout) - } - if err != nil { - cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) - } - if se, ok := err.(http2StreamError); ok { - if cs := rl.streamByID(se.StreamID); cs != nil { - if se.Cause == nil { - se.Cause = cc.fr.errDetail - } - rl.endStreamError(cs, se) - } - continue - } else if err != nil { - cc.countReadFrameError(err) - return err - } - if http2VerboseLogs { - cc.vlogf("http2: Transport received %s", http2summarizeFrame(f)) - } - if !gotSettings { - if _, ok := f.(*http2SettingsFrame); !ok { - cc.logf("protocol error: received %T before a SETTINGS frame", f) - return http2ConnectionError(http2ErrCodeProtocol) - } - gotSettings = true - } - - switch f := f.(type) { - case *http2MetaHeadersFrame: - err = rl.processHeaders(f) - case *http2DataFrame: - err = rl.processData(f) - case *http2GoAwayFrame: - err = rl.processGoAway(f) - case *http2RSTStreamFrame: - err = rl.processResetStream(f) - case *http2SettingsFrame: - err = rl.processSettings(f) - case *http2PushPromiseFrame: - err = rl.processPushPromise(f) - case *http2WindowUpdateFrame: - err = rl.processWindowUpdate(f) - case *http2PingFrame: - err = rl.processPing(f) - default: - cc.logf("Transport: unhandled response frame type %T", f) - } - if err != nil { - if http2VerboseLogs { - cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, http2summarizeFrame(f), err) - } - return err - } - } -} - -func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error { - cs := rl.streamByID(f.StreamID) - if cs == nil { - // We'd get here if we canceled a request while the - // server had its response still in flight. So if this - // was just something we canceled, ignore it. - return nil - } - if cs.readClosed { - rl.endStreamError(cs, http2StreamError{ - StreamID: f.StreamID, - Code: http2ErrCodeProtocol, - Cause: errors.New("protocol error: headers after END_STREAM"), - }) - return nil - } - if !cs.firstByte { - if cs.trace != nil { - // TODO(bradfitz): move first response byte earlier, - // when we first read the 9 byte header, not waiting - // until all the HEADERS+CONTINUATION frames have been - // merged. This works for now. - http2traceFirstResponseByte(cs.trace) - } - cs.firstByte = true - } - if !cs.pastHeaders { - cs.pastHeaders = true - } else { - return rl.processTrailers(cs, f) - } - - res, err := rl.handleResponse(cs, f) - if err != nil { - if _, ok := err.(http2ConnectionError); ok { - return err - } - // Any other error type is a stream error. - rl.endStreamError(cs, http2StreamError{ - StreamID: f.StreamID, - Code: http2ErrCodeProtocol, - Cause: err, - }) - return nil // return nil from process* funcs to keep conn alive - } - if res == nil { - // (nil, nil) special case. See handleResponse docs. - return nil - } - cs.resTrailer = &res.Trailer - cs.res = res - close(cs.respHeaderRecv) - if f.StreamEnded() { - rl.endStream(cs) - } - return nil -} - -// may return error types nil, or ConnectionError. Any other error value -// is a StreamError of type ErrCodeProtocol. The returned error in that case -// is the detail. -// -// As a special case, handleResponse may return (nil, nil) to skip the -// frame (currently only used for 1xx responses). -func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http2MetaHeadersFrame) (*http.Response, error) { - if f.Truncated { - return nil, http2errResponseHeaderListSize - } - - status := f.PseudoValue("status") - if status == "" { - return nil, errors.New("malformed response from server: missing status pseudo header") - } - statusCode, err := strconv.Atoi(status) - if err != nil { - return nil, errors.New("malformed response from server: malformed non-numeric status pseudo header") - } - - regularFields := f.RegularFields() - strs := make([]string, len(regularFields)) - header := make(http.Header, len(regularFields)) - res := &http.Response{ - Proto: "HTTP/2.0", - ProtoMajor: 2, - Header: header, - StatusCode: statusCode, - Status: status + " " + http.StatusText(statusCode), - } - for _, hf := range regularFields { - key := http.CanonicalHeaderKey(hf.Name) - if key == "Trailer" { - t := res.Trailer - if t == nil { - t = make(http.Header) - res.Trailer = t - } - http2foreachHeaderElement(hf.Value, func(v string) { - t[http.CanonicalHeaderKey(v)] = nil - }) - } else { - vv := header[key] - if vv == nil && len(strs) > 0 { - // More than likely this will be a single-element key. - // Most headers aren't multi-valued. - // Set the capacity on strs[0] to 1, so any future append - // won't extend the slice into the other strings. - vv, strs = strs[:1:1], strs[1:] - vv[0] = hf.Value - header[key] = vv - } else { - header[key] = append(vv, hf.Value) - } - } - } - - if statusCode >= 100 && statusCode <= 199 { - if f.StreamEnded() { - return nil, errors.New("1xx informational response with END_STREAM flag") - } - cs.num1xx++ - const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http - if cs.num1xx > max1xxResponses { - return nil, errors.New("http2: too many 1xx informational responses") - } - if fn := cs.get1xxTraceFunc(); fn != nil { - if err := fn(statusCode, textproto.MIMEHeader(header)); err != nil { - return nil, err - } - } - if statusCode == 100 { - http2traceGot100Continue(cs.trace) - select { - case cs.on100 <- struct{}{}: - default: - } - } - cs.pastHeaders = false // do it all again - return nil, nil - } - - res.ContentLength = -1 - if clens := res.Header["Content-Length"]; len(clens) == 1 { - if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil { - res.ContentLength = int64(cl) - } else { - // TODO: care? unlike http/1, it won't mess up our framing, so it's - // more safe smuggling-wise to ignore. - } - } else if len(clens) > 1 { - // TODO: care? unlike http/1, it won't mess up our framing, so it's - // more safe smuggling-wise to ignore. - } else if f.StreamEnded() && !cs.isHead { - res.ContentLength = 0 - } - - if cs.isHead { - res.Body = http2noBody - return res, nil - } - - if f.StreamEnded() { - if res.ContentLength > 0 { - res.Body = http2missingBody{} - } else { - res.Body = http2noBody - } - return res, nil - } - - cs.bufPipe.setBuffer(&http2dataBuffer{expected: res.ContentLength}) - cs.bytesRemain = res.ContentLength - res.Body = http2transportResponseBody{cs} - - if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") { - res.Header.Del("Content-Encoding") - res.Header.Del("Content-Length") - res.ContentLength = -1 - res.Body = &http2gzipReader{body: res.Body} - res.Uncompressed = true - } - - return res, nil -} - -func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *http2MetaHeadersFrame) error { - if cs.pastTrailers { - // Too many HEADERS frames for this stream. - return http2ConnectionError(http2ErrCodeProtocol) - } - cs.pastTrailers = true - if !f.StreamEnded() { - // We expect that any headers for trailers also - // has END_STREAM. - return http2ConnectionError(http2ErrCodeProtocol) - } - if len(f.PseudoFields()) > 0 { - // No pseudo header fields are defined for trailers. - // TODO: ConnectionError might be overly harsh? Check. - return http2ConnectionError(http2ErrCodeProtocol) - } - - trailer := make(http.Header) - for _, hf := range f.RegularFields() { - key := http.CanonicalHeaderKey(hf.Name) - trailer[key] = append(trailer[key], hf.Value) - } - cs.trailer = trailer - - rl.endStream(cs) - return nil -} - -// transportResponseBody is the concrete type of Transport.RoundTrip's -// Response.Body. It is an io.ReadCloser. -type http2transportResponseBody struct { - cs *http2clientStream -} - -func (b http2transportResponseBody) Read(p []byte) (n int, err error) { - cs := b.cs - cc := cs.cc - - if cs.readErr != nil { - return 0, cs.readErr - } - n, err = b.cs.bufPipe.Read(p) - if cs.bytesRemain != -1 { - if int64(n) > cs.bytesRemain { - n = int(cs.bytesRemain) - if err == nil { - err = errors.New("net/http: server replied with more than declared Content-Length; truncated") - cs.abortStream(err) - } - cs.readErr = err - return int(cs.bytesRemain), err - } - cs.bytesRemain -= int64(n) - if err == io.EOF && cs.bytesRemain > 0 { - err = io.ErrUnexpectedEOF - cs.readErr = err - return n, err - } - } - if n == 0 { - // No flow control tokens to send back. - return - } - - cc.mu.Lock() - var connAdd, streamAdd int32 - // Check the conn-level first, before the stream-level. - if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 { - connAdd = http2transportDefaultConnFlow - v - cc.inflow.add(connAdd) - } - if err == nil { // No need to refresh if the stream is over or failed. - // Consider any buffered body data (read from the conn but not - // consumed by the client) when computing flow control for this - // stream. - v := int(cs.inflow.available()) + cs.bufPipe.Len() - if v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { - streamAdd = int32(http2transportDefaultStreamFlow - v) - cs.inflow.add(streamAdd) - } - } - cc.mu.Unlock() - - if connAdd != 0 || streamAdd != 0 { - cc.wmu.Lock() - defer cc.wmu.Unlock() - if connAdd != 0 { - cc.fr.WriteWindowUpdate(0, http2mustUint31(connAdd)) - } - if streamAdd != 0 { - cc.fr.WriteWindowUpdate(cs.ID, http2mustUint31(streamAdd)) - } - cc.bw.Flush() + // Scratch buffer for reading into & writing from. + scratchLen := cs.frameScratchBufferLen(maxFrameSize) + var buf []byte + if bp, ok := http2bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen { + defer http2bufPool.Put(bp) + buf = *bp + } else { + buf = make([]byte, scratchLen) + defer http2bufPool.Put(&buf) } - return -} - -var http2errClosedResponseBody = errors.New("http2: response body closed") - -func (b http2transportResponseBody) Close() error { - cs := b.cs - cc := cs.cc - - unread := cs.bufPipe.Len() - if unread > 0 { - cc.mu.Lock() - // Return connection-level flow control. - if unread > 0 { - cc.inflow.add(int32(unread)) - } - cc.mu.Unlock() - // TODO(dneil): Acquiring this mutex can block indefinitely. - // Move flow control return to a goroutine? - cc.wmu.Lock() - // Return connection-level flow control. - if unread > 0 { - cc.fr.WriteWindowUpdate(0, uint32(unread)) + writeData := cc.fr.WriteData + if dump != nil && dump.RequestBody { + writeData = func(streamID uint32, endStream bool, data []byte) error { + dump.dump(data) + return cc.fr.WriteData(streamID, endStream, data) } - cc.bw.Flush() - cc.wmu.Unlock() - } - - cs.bufPipe.BreakWithError(http2errClosedResponseBody) - cs.abortStream(http2errClosedResponseBody) - - select { - case <-cs.donec: - case <-cs.ctx.Done(): - // See golang/go#49366: The net/http package can cancel the - // request context after the response body is fully read. - // Don't treat this as an error. - return nil - case <-cs.reqCancel: - return http2errRequestCanceled } - return nil -} -func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { - cc := rl.cc - cs := rl.streamByID(f.StreamID) - data := f.Data() - if cs == nil { - cc.mu.Lock() - neverSent := cc.nextStreamID - cc.mu.Unlock() - if f.StreamID >= neverSent { - // We never asked for this. - cc.logf("http2: Transport received unsolicited DATA frame; closing connection") - return http2ConnectionError(http2ErrCodeProtocol) + var sawEOF bool + for !sawEOF { + n, err := body.Read(buf[:len(buf)]) + if hasContentLen { + remainLen -= int64(n) + if remainLen == 0 && err == nil { + // The request body's Content-Length was predeclared and + // we just finished reading it all, but the underlying io.textprotoReader + // returned the final chunk with a nil error (which is one of + // the two valid things a textprotoReader can do at EOF). Because we'd prefer + // to send the END_STREAM bit early, double-check that we're actually + // at EOF. Subsequent reads should return (0, EOF) at this point. + // If either value is different, we return an error in one of two ways below. + var scratch [1]byte + var n1 int + n1, err = body.Read(scratch[:]) + remainLen -= int64(n1) + } + if remainLen < 0 { + err = http2errReqBodyTooLong + return err + } } - // We probably did ask for this, but canceled. Just ignore it. - // TODO: be stricter here? only silently ignore things which - // we canceled, but not things which were closed normally - // by the peer? Tough without accumulating too much state. - - // But at least return their flow control: - if f.Length > 0 { + if err != nil { cc.mu.Lock() - cc.inflow.add(int32(f.Length)) - cc.mu.Unlock() - - cc.wmu.Lock() - cc.fr.WriteWindowUpdate(0, uint32(f.Length)) - cc.bw.Flush() - cc.wmu.Unlock() - } - return nil - } - if cs.readClosed { - cc.logf("protocol error: received DATA after END_STREAM") - rl.endStreamError(cs, http2StreamError{ - StreamID: f.StreamID, - Code: http2ErrCodeProtocol, - }) - return nil - } - if !cs.firstByte { - cc.logf("protocol error: received DATA before a HEADERS frame") - rl.endStreamError(cs, http2StreamError{ - StreamID: f.StreamID, - Code: http2ErrCodeProtocol, - }) - return nil - } - if f.Length > 0 { - if cs.isHead && len(data) > 0 { - cc.logf("protocol error: received DATA on a HEAD request") - rl.endStreamError(cs, http2StreamError{ - StreamID: f.StreamID, - Code: http2ErrCodeProtocol, - }) - return nil - } - // Check connection-level flow control. - cc.mu.Lock() - if cs.inflow.available() >= int32(f.Length) { - cs.inflow.take(int32(f.Length)) - } else { + bodyClosed := cs.reqBodyClosed cc.mu.Unlock() - return http2ConnectionError(http2ErrCodeFlowControl) - } - // Return any padded flow control now, since we won't - // refund it later on body reads. - var refund int - if pad := int(f.Length) - len(data); pad > 0 { - refund += pad - } - - didReset := false - var err error - if len(data) > 0 { - if _, err = cs.bufPipe.Write(data); err != nil { - // Return len(data) now if the stream is already closed, - // since data will never be read. - didReset = true - refund += len(data) + switch { + case bodyClosed: + return http2errStopReqBodyWrite + case err == io.EOF: + sawEOF = true + err = nil + default: + return err } } - if refund > 0 { - cc.inflow.add(int32(refund)) - if !didReset { - cs.inflow.add(int32(refund)) + remain := buf[:n] + for len(remain) > 0 && err == nil { + var allowed int32 + allowed, err = cs.awaitFlowControl(len(remain)) + if err != nil { + return err } - } - cc.mu.Unlock() - - if refund > 0 { cc.wmu.Lock() - cc.fr.WriteWindowUpdate(0, uint32(refund)) - if !didReset { - cc.fr.WriteWindowUpdate(cs.ID, uint32(refund)) - } - cc.bw.Flush() - cc.wmu.Unlock() - } - - if err != nil { - rl.endStreamError(cs, err) - return nil - } - } - - if f.StreamEnded() { - rl.endStream(cs) - } - return nil -} - -func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { - // TODO: check that any declared content-length matches, like - // server.go's (*stream).endStream method. - if !cs.readClosed { - cs.readClosed = true - // Close cs.bufPipe and cs.peerClosed with cc.mu held to avoid a - // race condition: The caller can read io.EOF from Response.Body - // and close the body before we close cs.peerClosed, causing - // cleanupWriteRequest to send a RST_STREAM. - rl.cc.mu.Lock() - defer rl.cc.mu.Unlock() - cs.bufPipe.closeWithErrorAndCode(io.EOF, cs.copyTrailers) - close(cs.peerClosed) - } -} - -func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) { - cs.readAborted = true - cs.abortStream(err) -} - -func (rl *http2clientConnReadLoop) streamByID(id uint32) *http2clientStream { - rl.cc.mu.Lock() - defer rl.cc.mu.Unlock() - cs := rl.cc.streams[id] - if cs != nil && !cs.readAborted { - return cs - } - return nil -} - -func (cs *http2clientStream) copyTrailers() { - for k, vv := range cs.trailer { - t := cs.resTrailer - if *t == nil { - *t = make(http.Header) + data := remain[:allowed] + remain = remain[allowed:] + sentEnd = sawEOF && len(remain) == 0 && !hasTrailers + err = writeData(cs.ID, sentEnd, data) + if err == nil { + // TODO(bradfitz): this flush is for latency, not bandwidth. + // Most requests won't need this. Make this opt-in or + // opt-out? Use some heuristic on the body type? Nagel-like + // timers? Based on 'n'? Only last chunk of this for loop, + // unless flow control tokens are low? For now, always. + // If we change this, see comment below. + err = cc.bw.Flush() + } + cc.wmu.Unlock() + } + if err != nil { + return err } - (*t)[k] = vv } -} -func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { - cc := rl.cc - cc.t.connPool().MarkDead(cc) - if f.ErrCode != 0 { - // TODO: deal with GOAWAY more. particularly the error code - cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) - if fn := cc.t.CountError; fn != nil { - fn("recv_goaway_" + f.ErrCode.stringToken()) - } + if sentEnd { + // Already sent END_STREAM (which implies we have no + // trailers) and flushed, because currently all + // WriteData frames above get a flush. So we're done. + return nil + } + // Since the RoundTrip contract permits the caller to "mutate or reuse" + // a request after the Response's Body is closed, verify that this hasn't + // happened before accessing the trailers. + cc.mu.Lock() + trailer := req.Trailer + err = cs.abortErr + cc.mu.Unlock() + if err != nil { + return err } - cc.setGoAway(f) - return nil -} -func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error { - cc := rl.cc - // Locking both mu and wmu here allows frame encoding to read settings with only wmu held. - // Acquiring wmu when f.IsAck() is unnecessary, but convenient and mostly harmless. cc.wmu.Lock() defer cc.wmu.Unlock() + var trls []byte + if len(trailer) > 0 { + trls, err = cc.encodeTrailers(trailer, dump) + if err != nil { + return err + } + } - if err := rl.processSettingsNoWrite(f); err != nil { - return err + // Two ways to send END_STREAM: either with trailers, or + // with an empty DATA frame. + if len(trls) > 0 { + err = cc.writeHeaders(cs.ID, true, maxFrameSize, trls) + } else { + err = cc.fr.WriteData(cs.ID, true, nil) } - if !f.IsAck() { - cc.fr.WriteSettingsAck() - cc.bw.Flush() + if ferr := cc.bw.Flush(); ferr != nil && err == nil { + err = ferr } - return nil + return err } -func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) error { - cc := rl.cc +// awaitFlowControl waits for [1, min(maxBytes, cc.cs.maxFrameSize)] flow +// control tokens from the server. +// It returns either the non-zero number of tokens taken or an error +// if the stream is dead. +func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) { + cc := cs.cc + ctx := cs.ctx cc.mu.Lock() defer cc.mu.Unlock() + for { + if cc.closed { + return 0, http2errClientConnClosed + } + if cs.reqBodyClosed { + return 0, http2errStopReqBodyWrite + } + select { + case <-cs.abort: + return 0, cs.abortErr + case <-ctx.Done(): + return 0, ctx.Err() + case <-cs.reqCancel: + return 0, http2errRequestCanceled + default: + } + if a := cs.flow.available(); a > 0 { + take := a + if int(take) > maxBytes { - if f.IsAck() { - if cc.wantSettingsAck { - cc.wantSettingsAck = false - return nil + take = int32(maxBytes) // can't truncate int; take is int32 + } + if take > int32(cc.maxFrameSize) { + take = int32(cc.maxFrameSize) + } + cs.flow.take(take) + return take, nil } - return http2ConnectionError(http2ErrCodeProtocol) + cc.cond.Wait() } +} - var seenMaxConcurrentStreams bool - err := f.ForeachSetting(func(s http2Setting) error { - switch s.ID { - case http2SettingMaxFrameSize: - cc.maxFrameSize = s.Val - case http2SettingMaxConcurrentStreams: - cc.maxConcurrentStreams = s.Val - seenMaxConcurrentStreams = true - case http2SettingMaxHeaderListSize: - cc.peerMaxHeaderListSize = uint64(s.Val) - case http2SettingInitialWindowSize: - // Values above the maximum flow-control - // window size of 2^31-1 MUST be treated as a - // connection error (Section 5.4.1) of type - // FLOW_CONTROL_ERROR. - if s.Val > math.MaxInt32 { - return http2ConnectionError(http2ErrCodeFlowControl) - } +var http2errNilRequestURL = errors.New("http2: Request.URI is nil") - // Adjust flow control of currently-open - // frames by the difference of the old initial - // window size and this one. - delta := int32(s.Val) - int32(cc.initialWindowSize) - for _, cs := range cc.streams { - cs.flow.add(delta) +// requires cc.wmu be held. +func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64, dump *dumper) ([]byte, error) { + cc.hbuf.Reset() + if req.URL == nil { + return nil, http2errNilRequestURL + } + + host := req.Host + if host == "" { + host = req.URL.Host + } + host, err := httpguts.PunycodeHostPort(host) + if err != nil { + return nil, err + } + + var path string + if req.Method != "CONNECT" { + path = req.URL.RequestURI() + if !http2validPseudoPath(path) { + orig := path + path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) + if !http2validPseudoPath(path) { + if req.URL.Opaque != "" { + return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) + } else { + return nil, fmt.Errorf("invalid request :path %q", orig) + } } - cc.cond.Broadcast() + } + } - cc.initialWindowSize = s.Val - default: - // TODO(bradfitz): handle more settings? SETTINGS_HEADER_TABLE_SIZE probably. - cc.vlogf("Unhandled Setting: %v", s) + // Check for any invalid headers and return an error before we + // potentially pollute our hpack state. (We want to be able to + // continue to reuse the hpack encoder for future requests) + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("invalid HTTP header name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) + } } - return nil - }) - if err != nil { - return err } - if !cc.seenSettings { - if !seenMaxConcurrentStreams { - // This was the servers initial SETTINGS frame and it - // didn't contain a MAX_CONCURRENT_STREAMS field so - // increase the number of concurrent streams this - // connection can establish to our default. - cc.maxConcurrentStreams = http2defaultMaxConcurrentStreams + enumerateHeaders := func(f func(name, value string)) { + // 8.1.2.3 Request Pseudo-Header Fields + // The :path pseudo-header field includes the path and query parts of the + // target URI (the path-absolute production and optionally a '?' character + // followed by the query production (see Sections 3.3 and 3.4 of + // [RFC3986]). + f(":authority", host) + m := req.Method + if m == "" { + m = http.MethodGet + } + f(":method", m) + if req.Method != "CONNECT" { + f(":path", path) + f(":scheme", req.URL.Scheme) + } + if trailers != "" { + f("trailer", trailers) + } + + var didUA bool + for k, vv := range req.Header { + if http2asciiEqualFold(k, "host") || http2asciiEqualFold(k, "content-length") { + // Host is :authority, already sent. + // Content-Length is automatic, set below. + continue + } else if http2asciiEqualFold(k, "connection") || + http2asciiEqualFold(k, "proxy-connection") || + http2asciiEqualFold(k, "transfer-encoding") || + http2asciiEqualFold(k, "upgrade") || + http2asciiEqualFold(k, "keep-alive") { + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + continue + } else if http2asciiEqualFold(k, "user-agent") { + // Match Go's http1 behavior: at most one + // User-Agent. If set to nil or empty string, + // then omit it. Otherwise if not mentioned, + // include the default (below). + didUA = true + if len(vv) < 1 { + continue + } + vv = vv[:1] + if vv[0] == "" { + continue + } + } else if http2asciiEqualFold(k, "cookie") { + // Per 8.1.2.5 To allow for better compression efficiency, the + // Cookie header field MAY be split into separate header fields, + // each with one or more cookie-pairs. + for _, v := range vv { + for { + p := strings.IndexByte(v, ';') + if p < 0 { + break + } + f("cookie", v[:p]) + p++ + // strip space after semicolon if any. + for p+1 <= len(v) && v[p] == ' ' { + p++ + } + v = v[p:] + } + if len(v) > 0 { + f("cookie", v) + } + } + continue + } + + for _, v := range vv { + f(k, v) + } + } + if http2shouldSendReqContentLength(req.Method, contentLength) { + f("content-length", strconv.FormatInt(contentLength, 10)) + } + if addGzipHeader { + f("accept-encoding", "gzip") + } + if !didUA { + f("user-agent", hdrUserAgentValue) } - cc.seenSettings = true } - return nil -} + // Do a first pass over the headers counting bytes to ensure + // we don't exceed cc.peerMaxHeaderListSize. This is done as a + // separate pass before encoding the headers to prevent + // modifying the hpack state. + hlSize := uint64(0) + enumerateHeaders(func(name, value string) { + hf := hpack.HeaderField{Name: name, Value: value} + hlSize += uint64(hf.Size()) + }) -func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error { - cc := rl.cc - cs := rl.streamByID(f.StreamID) - if f.StreamID != 0 && cs == nil { - return nil + if hlSize > cc.peerMaxHeaderListSize { + return nil, http2errRequestHeaderListSize } - cc.mu.Lock() - defer cc.mu.Unlock() + trace := httptrace.ContextClientTrace(req.Context()) + traceHeaders := http2traceHasWroteHeaderField(trace) - fl := &cc.flow - if cs != nil { - fl = &cs.flow + writeHeader := cc.writeHeader + if dump != nil && dump.RequestHeader { + writeHeader = func(name, value string) { + dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + cc.writeHeader(name, value) + } } - if !fl.add(int32(f.Increment)) { - return http2ConnectionError(http2ErrCodeFlowControl) + + // Header list size is ok. Write the headers. + enumerateHeaders(func(name, value string) { + name, ascii := http2asciiToLower(name) + if !ascii { + // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header + // field names have to be ASCII characters (just as in HTTP/1.x). + return + } + writeHeader(name, value) + if traceHeaders { + http2traceWroteHeaderField(trace, name, value) + } + }) + + if dump != nil && dump.RequestHeader { + dump.dump([]byte("\r\n")) } - cc.cond.Broadcast() - return nil + + return cc.hbuf.Bytes(), nil } -func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error { - cs := rl.streamByID(f.StreamID) - if cs == nil { - // TODO: return error if server tries to RST_STREAM an idle stream - return nil +// shouldSendReqContentLength reports whether the http2.Transport should send +// a "content-length" request header. This logic is basically a copy of the net/http +// transferWriter.shouldSendContentLength. +// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). +// -1 means unknown. +func http2shouldSendReqContentLength(method string, contentLength int64) bool { + if contentLength > 0 { + return true } - serr := http2streamError(cs.ID, f.ErrCode) - serr.Cause = http2errFromPeer - if f.ErrCode == http2ErrCodeProtocol { - rl.cc.SetDoNotReuse() + if contentLength < 0 { + return false } - if fn := cs.cc.t.CountError; fn != nil { - fn("recv_rststream_" + f.ErrCode.stringToken()) + // For zero bodies, whether we send a content-length depends on the method. + // It also kinda doesn't matter for http2 either way, with END_STREAM. + switch method { + case "POST", "PUT", "PATCH": + return true + default: + return false } - cs.abortStream(serr) - - cs.bufPipe.CloseWithError(serr) - return nil } -// Ping sends a PING frame to the server and waits for the ack. -func (cc *http2ClientConn) Ping(ctx context.Context) error { - c := make(chan struct{}) - // Generate a random payload - var p [8]byte - for { - if _, err := rand.Read(p[:]); err != nil { - return err +// requires cc.wmu be held. +func (cc *http2ClientConn) encodeTrailers(trailer http.Header, dump *dumper) ([]byte, error) { + cc.hbuf.Reset() + + hlSize := uint64(0) + for k, vv := range trailer { + for _, v := range vv { + hf := hpack.HeaderField{Name: k, Value: v} + hlSize += uint64(hf.Size()) } - cc.mu.Lock() - // check for dup before insert - if _, found := cc.pings[p]; !found { - cc.pings[p] = c - cc.mu.Unlock() - break + } + if hlSize > cc.peerMaxHeaderListSize { + return nil, http2errRequestHeaderListSize + } + + writeHeader := cc.writeHeader + if dump != nil && dump.RequestBody { + writeHeader = func(name, value string) { + dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + cc.writeHeader(name, value) } - cc.mu.Unlock() } - errc := make(chan error, 1) - go func() { - cc.wmu.Lock() - defer cc.wmu.Unlock() - if err := cc.fr.WritePing(false, p); err != nil { - errc <- err - return + + for k, vv := range trailer { + lowKey, ascii := http2asciiToLower(k) + if !ascii { + // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header + // field names have to be ASCII characters (just as in HTTP/1.x). + continue } - if err := cc.bw.Flush(); err != nil { - errc <- err - return + // Transfer-Encoding, etc.. have already been filtered at the + // start of RoundTrip + for _, v := range vv { + writeHeader(lowKey, v) } - }() - select { - case <-c: - return nil - case err := <-errc: - return err - case <-ctx.Done(): - return ctx.Err() - case <-cc.readerDone: - // connection closed - return cc.readerErr } + return cc.hbuf.Bytes(), nil } -func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { - if f.IsAck() { - cc := rl.cc - cc.mu.Lock() - defer cc.mu.Unlock() - // If ack, notify listener if any - if c, ok := cc.pings[f.Data]; ok { - close(c) - delete(cc.pings, f.Data) - } - return nil - } - cc := rl.cc - cc.wmu.Lock() - defer cc.wmu.Unlock() - if err := cc.fr.WritePing(true, f.Data); err != nil { - return err +func (cc *http2ClientConn) writeHeader(name, value string) { + if http2VerboseLogs { + log.Printf("http2: Transport encoding header %q = %q", name, value) } - return cc.bw.Flush() + cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) } -func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) error { - // We told the peer we don't want them. - // Spec says: - // "PUSH_PROMISE MUST NOT be sent if the SETTINGS_ENABLE_PUSH - // setting of the peer endpoint is set to 0. An endpoint that - // has set this setting and has received acknowledgement MUST - // treat the receipt of a PUSH_PROMISE frame as a connection - // error (Section 5.4.1) of type PROTOCOL_ERROR." - return http2ConnectionError(http2ErrCodeProtocol) +type http2resAndError struct { + _ http2incomparable + res *http.Response + err error } -func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) { - // TODO: map err to more interesting error codes, once the - // HTTP community comes up with some. But currently for - // RST_STREAM there's no equivalent to GOAWAY frame's debug - // data, and the error codes are all pretty vague ("cancel"). - cc.wmu.Lock() - cc.fr.WriteRSTStream(streamID, code) - cc.bw.Flush() - cc.wmu.Unlock() +// requires cc.mu be held. +func (cc *http2ClientConn) addStreamLocked(cs *http2clientStream) { + cs.flow.add(int32(cc.initialWindowSize)) + cs.flow.setConnFlow(&cc.flow) + cs.inflow.add(http2transportDefaultStreamFlow) + cs.inflow.setConnFlow(&cc.inflow) + cs.ID = cc.nextStreamID + cc.nextStreamID += 2 + cc.streams[cs.ID] = cs + if cs.ID == 0 { + panic("assigned stream ID 0") + } } -var ( - http2errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") - http2errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit") -) +func (cc *http2ClientConn) forgetStreamID(id uint32) { + cc.mu.Lock() + slen := len(cc.streams) + delete(cc.streams, id) + if len(cc.streams) != slen-1 { + panic("forgetting unknown stream id") + } + cc.lastActive = time.Now() + if len(cc.streams) == 0 && cc.idleTimer != nil { + cc.idleTimer.Reset(cc.idleTimeout) + cc.lastIdle = time.Now() + } + // Wake up writeRequestBody via clientStream.awaitFlowControl and + // wake up RoundTrip if there is a pending request. + cc.cond.Broadcast() + + closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() + if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { + if http2VerboseLogs { + cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) + } + cc.closed = true + defer cc.tconn.Close() + } -func (cc *http2ClientConn) logf(format string, args ...interface{}) { - cc.t.logf(format, args...) + cc.mu.Unlock() } -func (cc *http2ClientConn) vlogf(format string, args ...interface{}) { - cc.t.vlogf(format, args...) +// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. +type http2clientConnReadLoop struct { + _ http2incomparable + cc *http2ClientConn } -func (t *http2Transport) vlogf(format string, args ...interface{}) { - if http2VerboseLogs { - t.logf(format, args...) +// readLoop runs in its own goroutine and reads and dispatches frames. +func (cc *http2ClientConn) readLoop() { + rl := &http2clientConnReadLoop{cc: cc} + defer rl.cleanup() + cc.readerErr = rl.run() + if ce, ok := cc.readerErr.(http2ConnectionError); ok { + cc.wmu.Lock() + cc.fr.WriteGoAway(0, http2ErrCode(ce), nil) + cc.wmu.Unlock() } } -func (t *http2Transport) logf(format string, args ...interface{}) { - log.Printf(format, args...) +// GoAwayError is returned by the Transport when the server closes the +// TCP connection after sending a GOAWAY frame. +type http2GoAwayError struct { + LastStreamID uint32 + ErrCode http2ErrCode + DebugData string } -var http2noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) - -type http2missingBody struct{} - -func (http2missingBody) Close() error { return nil } - -func (http2missingBody) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF } +func (e http2GoAwayError) Error() string { + return fmt.Sprintf("http2: server sent GOAWAY and closed the connection; LastStreamID=%v, ErrCode=%v, debug=%q", + e.LastStreamID, e.ErrCode, e.DebugData) +} -func http2strSliceContains(ss []string, s string) bool { - for _, v := range ss { - if v == s { - return true - } +func http2isEOFOrNetReadError(err error) bool { + if err == io.EOF { + return true } - return false + ne, ok := err.(*net.OpError) + return ok && ne.Op == "read" } -type http2erringRoundTripper struct{ err error } - -func (rt http2erringRoundTripper) RoundTripErr() error { return rt.err } - -func (rt http2erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { - return nil, rt.err -} +func (rl *http2clientConnReadLoop) cleanup() { + cc := rl.cc + defer cc.tconn.Close() + defer cc.t.connPool().MarkDead(cc) + defer close(cc.readerDone) -// gzipReader wraps a response body so it can lazily -// call gzip.NewReader on the first call to Read -type http2gzipReader struct { - _ http2incomparable - body io.ReadCloser // underlying Response.Body - zr *gzip.Reader // lazily-initialized gzip reader - zerr error // sticky error -} + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } -func (gz *http2gzipReader) Read(p []byte) (n int, err error) { - if gz.zerr != nil { - return 0, gz.zerr + // Close any response bodies if the server closes prematurely. + // TODO: also do this if we've written the headers but not + // gotten a response yet. + err := cc.readerErr + cc.mu.Lock() + if cc.goAway != nil && http2isEOFOrNetReadError(err) { + err = http2GoAwayError{ + LastStreamID: cc.goAway.LastStreamID, + ErrCode: cc.goAway.ErrCode, + DebugData: cc.goAwayDebug, + } + } else if err == io.EOF { + err = io.ErrUnexpectedEOF } - if gz.zr == nil { - gz.zr, err = gzip.NewReader(gz.body) - if err != nil { - gz.zerr = err - return 0, err + cc.closed = true + for _, cs := range cc.streams { + select { + case <-cs.peerClosed: + // The server closed the stream before closing the conn, + // so no need to interrupt it. + default: + cs.abortStreamLocked(err) } } - return gz.zr.Read(p) + cc.cond.Broadcast() + cc.mu.Unlock() } -func (gz *http2gzipReader) Close() error { - return gz.body.Close() +// countReadFrameError calls Transport.CountError with a string +// representing err. +func (cc *http2ClientConn) countReadFrameError(err error) { + f := cc.t.CountError + if f == nil || err == nil { + return + } + if ce, ok := err.(http2ConnectionError); ok { + errCode := http2ErrCode(ce) + f(fmt.Sprintf("read_frame_conn_error_%s", errCode.stringToken())) + return + } + if errors.Is(err, io.EOF) { + f("read_frame_eof") + return + } + if errors.Is(err, io.ErrUnexpectedEOF) { + f("read_frame_unexpected_eof") + return + } + if errors.Is(err, http2ErrFrameTooLarge) { + f("read_frame_too_large") + return + } + f("read_frame_other") } -type http2errorReader struct{ err error } - -func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } - -// isConnectionCloseRequest reports whether req should use its own -// connection for a single request and then close the connection. -func http2isConnectionCloseRequest(req *http.Request) bool { - return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close") -} +func (rl *http2clientConnReadLoop) run() error { + cc := rl.cc + gotSettings := false + readIdleTimeout := cc.t.ReadIdleTimeout + var t *time.Timer + if readIdleTimeout != 0 { + t = time.AfterFunc(readIdleTimeout, cc.healthCheck) + defer t.Stop() + } + for { + f, err := cc.fr.ReadFrame() + if t != nil { + t.Reset(readIdleTimeout) + } + if err != nil { + cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) + } + if se, ok := err.(http2StreamError); ok { + if cs := rl.streamByID(se.StreamID); cs != nil { + if se.Cause == nil { + se.Cause = cc.fr.errDetail + } + rl.endStreamError(cs, se) + } + continue + } else if err != nil { + cc.countReadFrameError(err) + return err + } + if http2VerboseLogs { + cc.vlogf("http2: Transport received %s", http2summarizeFrame(f)) + } + if !gotSettings { + if _, ok := f.(*http2SettingsFrame); !ok { + cc.logf("protocol error: received %T before a SETTINGS frame", f) + return http2ConnectionError(http2ErrCodeProtocol) + } + gotSettings = true + } -// registerHTTPSProtocol calls Transport.RegisterProtocol but -// converting panics into errors. -func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err error) { - defer func() { - if e := recover(); e != nil { - err = fmt.Errorf("%v", e) + switch f := f.(type) { + case *http2MetaHeadersFrame: + err = rl.processHeaders(f) + case *http2DataFrame: + err = rl.processData(f) + case *http2GoAwayFrame: + err = rl.processGoAway(f) + case *http2RSTStreamFrame: + err = rl.processResetStream(f) + case *http2SettingsFrame: + err = rl.processSettings(f) + case *http2PushPromiseFrame: + err = rl.processPushPromise(f) + case *http2WindowUpdateFrame: + err = rl.processWindowUpdate(f) + case *http2PingFrame: + err = rl.processPing(f) + default: + cc.logf("Transport: unhandled response frame type %T", f) } - }() - t.RegisterProtocol("https", rt) - return nil + if err != nil { + if http2VerboseLogs { + cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, http2summarizeFrame(f), err) + } + return err + } + } } -// noDialH2RoundTripper is a RoundTripper which only tries to complete the request -// if there's already has a cached connection to the host. -// (The field is exported so it can be accessed via reflect from net/http; tested -// by TestNoDialH2RoundTripperType) -type http2noDialH2RoundTripper struct{ *http2Transport } - -func (rt http2noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - res, err := rt.http2Transport.RoundTrip(req) - if http2isNoCachedConnError(err) { - return nil, http.ErrSkipAltProtocol +func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error { + cs := rl.streamByID(f.StreamID) + if cs == nil { + // We'd get here if we canceled a request while the + // server had its response still in flight. So if this + // was just something we canceled, ignore it. + return nil + } + if cs.readClosed { + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + Cause: errors.New("protocol error: headers after END_STREAM"), + }) + return nil + } + if !cs.firstByte { + if cs.trace != nil { + // TODO(bradfitz): move first response byte earlier, + // when we first read the 9 byte header, not waiting + // until all the HEADERS+CONTINUATION frames have been + // merged. This works for now. + http2traceFirstResponseByte(cs.trace) + } + cs.firstByte = true + } + if !cs.pastHeaders { + cs.pastHeaders = true + } else { + return rl.processTrailers(cs, f) } - return res, err -} -func (t *http2Transport) idleConnTimeout() time.Duration { - if t.t1 != nil { - return t.t1.IdleConnTimeout + res, err := rl.handleResponse(cs, f) + if err != nil { + if _, ok := err.(http2ConnectionError); ok { + return err + } + // Any other error type is a stream error. + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + Cause: err, + }) + return nil // return nil from process* funcs to keep conn alive } - return 0 + if res == nil { + // (nil, nil) special case. See handleResponse docs. + return nil + } + cs.resTrailer = &res.Trailer + cs.res = res + close(cs.respHeaderRecv) + if f.StreamEnded() { + rl.endStream(cs) + } + return nil } -func http2traceGetConn(req *http.Request, hostPort string) { - trace := httptrace.ContextClientTrace(req.Context()) - if trace == nil || trace.GetConn == nil { - return +// may return error types nil, or ConnectionError. Any other error value +// is a StreamError of type ErrCodeProtocol. The returned error in that case +// is the detail. +// +// As a special case, handleResponse may return (nil, nil) to skip the +// frame (currently only used for 1xx responses). +func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http2MetaHeadersFrame) (*http.Response, error) { + if f.Truncated { + return nil, http2errResponseHeaderListSize } - trace.GetConn(hostPort) -} -func http2traceGotConn(req *http.Request, cc *http2ClientConn, reused bool) { - trace := httptrace.ContextClientTrace(req.Context()) - if trace == nil || trace.GotConn == nil { - return + status := f.PseudoValue("status") + if status == "" { + return nil, errors.New("malformed response from server: missing status pseudo header") } - ci := httptrace.GotConnInfo{Conn: cc.tconn} - ci.Reused = reused - cc.mu.Lock() - ci.WasIdle = len(cc.streams) == 0 && reused - if ci.WasIdle && !cc.lastActive.IsZero() { - ci.IdleTime = time.Now().Sub(cc.lastActive) + statusCode, err := strconv.Atoi(status) + if err != nil { + return nil, errors.New("malformed response from server: malformed non-numeric status pseudo header") } - cc.mu.Unlock() - - trace.GotConn(ci) -} -func http2traceWroteHeaders(trace *httptrace.ClientTrace) { - if trace != nil && trace.WroteHeaders != nil { - trace.WroteHeaders() + regularFields := f.RegularFields() + strs := make([]string, len(regularFields)) + header := make(http.Header, len(regularFields)) + res := &http.Response{ + Proto: "HTTP/2.0", + ProtoMajor: 2, + Header: header, + StatusCode: statusCode, + Status: status + " " + http.StatusText(statusCode), } -} - -func http2traceGot100Continue(trace *httptrace.ClientTrace) { - if trace != nil && trace.Got100Continue != nil { - trace.Got100Continue() + for _, hf := range regularFields { + key := http.CanonicalHeaderKey(hf.Name) + if key == "Trailer" { + t := res.Trailer + if t == nil { + t = make(http.Header) + res.Trailer = t + } + http2foreachHeaderElement(hf.Value, func(v string) { + t[http.CanonicalHeaderKey(v)] = nil + }) + } else { + vv := header[key] + if vv == nil && len(strs) > 0 { + // More than likely this will be a single-element key. + // Most headers aren't multi-valued. + // Set the capacity on strs[0] to 1, so any future append + // won't extend the slice into the other strings. + vv, strs = strs[:1:1], strs[1:] + vv[0] = hf.Value + header[key] = vv + } else { + header[key] = append(vv, hf.Value) + } + } } -} -func http2traceWait100Continue(trace *httptrace.ClientTrace) { - if trace != nil && trace.Wait100Continue != nil { - trace.Wait100Continue() + if statusCode >= 100 && statusCode <= 199 { + if f.StreamEnded() { + return nil, errors.New("1xx informational response with END_STREAM flag") + } + cs.num1xx++ + const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http + if cs.num1xx > max1xxResponses { + return nil, errors.New("http2: too many 1xx informational responses") + } + if fn := cs.get1xxTraceFunc(); fn != nil { + if err := fn(statusCode, textproto.MIMEHeader(header)); err != nil { + return nil, err + } + } + if statusCode == 100 { + http2traceGot100Continue(cs.trace) + select { + case cs.on100 <- struct{}{}: + default: + } + } + cs.pastHeaders = false // do it all again + return nil, nil } -} -func http2traceWroteRequest(trace *httptrace.ClientTrace, err error) { - if trace != nil && trace.WroteRequest != nil { - trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) + res.ContentLength = -1 + if clens := res.Header["Content-Length"]; len(clens) == 1 { + if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil { + res.ContentLength = int64(cl) + } else { + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. + } + } else if len(clens) > 1 { + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. + } else if f.StreamEnded() && !cs.isHead { + res.ContentLength = 0 } -} -func http2traceFirstResponseByte(trace *httptrace.ClientTrace) { - if trace != nil && trace.GotFirstResponseByte != nil { - trace.GotFirstResponseByte() + if cs.isHead { + res.Body = http2noBody + return res, nil } -} -// writeFramer is implemented by any type that is used to write frames. -type http2writeFramer interface { - writeFrame(http2writeContext) error - - // staysWithinBuffer reports whether this writer promises that - // it will only write less than or equal to size bytes, and it - // won't Flush the write context. - staysWithinBuffer(size int) bool -} - -// writeContext is the interface needed by the various frame writer -// types below. All the writeFrame methods below are scheduled via the -// frame writing scheduler (see writeScheduler in writesched.go). -// -// This interface is implemented by *serverConn. -// -// TODO: decide whether to a) use this in the client code (which didn't -// end up using this yet, because it has a simpler design, not -// currently implementing priorities), or b) delete this and -// make the server code a bit more concrete. -type http2writeContext interface { - Framer() *http2Framer - Flush() error - CloseConn() error - // HeaderEncoder returns an HPACK encoder that writes to the - // returned buffer. - HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) -} - -// writeEndsStream reports whether w writes a frame that will transition -// the stream to a half-closed local state. This returns false for RST_STREAM, -// which closes the entire stream (not just the local half). -func http2writeEndsStream(w http2writeFramer) bool { - switch v := w.(type) { - case *http2writeData: - return v.endStream - case *http2writeResHeaders: - return v.endStream - case nil: - // This can only happen if the caller reuses w after it's - // been intentionally nil'ed out to prevent use. Keep this - // here to catch future refactoring breaking it. - panic("writeEndsStream called on nil writeFramer") + if f.StreamEnded() { + if res.ContentLength > 0 { + res.Body = http2missingBody{} + } else { + res.Body = http2noBody + } + return res, nil } - return false -} - -type http2flushFrameWriter struct{} -func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error { - return ctx.Flush() -} - -func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false } - -type http2writeSettings []http2Setting - -func (s http2writeSettings) staysWithinBuffer(max int) bool { - const settingSize = 6 // uint16 + uint32 - return http2frameHeaderLen+settingSize*len(s) <= max - -} - -func (s http2writeSettings) writeFrame(ctx http2writeContext) error { - return ctx.Framer().WriteSettings([]http2Setting(s)...) -} - -type http2writeGoAway struct { - maxStreamID uint32 - code http2ErrCode -} - -func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { - err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil) - ctx.Flush() // ignore error: we're hanging up on them anyway - return err -} - -func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } // flushes - -type http2writeData struct { - streamID uint32 - p []byte - endStream bool -} - -func (w *http2writeData) String() string { - return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream) -} - -func (w *http2writeData) writeFrame(ctx http2writeContext) error { - return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) -} - -func (w *http2writeData) staysWithinBuffer(max int) bool { - return http2frameHeaderLen+len(w.p) <= max -} - -// handlerPanicRST is the message sent from handler goroutines when -// the handler panics. -type http2handlerPanicRST struct { - StreamID uint32 -} - -func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error { - return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal) -} - -func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } - -func (se http2StreamError) writeFrame(ctx http2writeContext) error { - return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) -} - -func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } - -type http2writePingAck struct{ pf *http2PingFrame } - -func (w http2writePingAck) writeFrame(ctx http2writeContext) error { - return ctx.Framer().WritePing(true, w.pf.Data) -} - -func (w http2writePingAck) staysWithinBuffer(max int) bool { - return http2frameHeaderLen+len(w.pf.Data) <= max -} + cs.bufPipe.setBuffer(&http2dataBuffer{expected: res.ContentLength}) + cs.bytesRemain = res.ContentLength + res.Body = http2transportResponseBody{cs} -type http2writeSettingsAck struct{} + if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = &http2gzipReader{body: res.Body} + res.Uncompressed = true + } -func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error { - return ctx.Framer().WriteSettingsAck() + return res, nil } -func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max } - -// splitHeaderBlock splits headerBlock into fragments so that each fragment fits -// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true -// for the first/last fragment, respectively. -func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error { - // For now we're lazy and just pick the minimum MAX_FRAME_SIZE - // that all peers must support (16KB). Later we could care - // more and send larger frames if the peer advertised it, but - // there's little point. Most headers are small anyway (so we - // generally won't have CONTINUATION frames), and extra frames - // only waste 9 bytes anyway. - const maxFrameSize = 16384 - - first := true - for len(headerBlock) > 0 { - frag := headerBlock - if len(frag) > maxFrameSize { - frag = frag[:maxFrameSize] - } - headerBlock = headerBlock[len(frag):] - if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil { - return err - } - first = false +func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *http2MetaHeadersFrame) error { + if cs.pastTrailers { + // Too many HEADERS frames for this stream. + return http2ConnectionError(http2ErrCodeProtocol) + } + cs.pastTrailers = true + if !f.StreamEnded() { + // We expect that any headers for trailers also + // has END_STREAM. + return http2ConnectionError(http2ErrCodeProtocol) + } + if len(f.PseudoFields()) > 0 { + // No pseudo header fields are defined for trailers. + // TODO: ConnectionError might be overly harsh? Check. + return http2ConnectionError(http2ErrCodeProtocol) } - return nil -} - -// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames -// for HTTP response headers or trailers from a server handler. -type http2writeResHeaders struct { - streamID uint32 - httpResCode int // 0 means no ":status" line - h http.Header // may be nil - trailers []string // if non-nil, which keys of h to write. nil means all. - endStream bool - - date string - contentType string - contentLength string -} -func http2encKV(enc *hpack.Encoder, k, v string) { - if http2VerboseLogs { - log.Printf("http2: server encoding header %q = %q", k, v) + trailer := make(http.Header) + for _, hf := range f.RegularFields() { + key := http.CanonicalHeaderKey(hf.Name) + trailer[key] = append(trailer[key], hf.Value) } - enc.WriteField(hpack.HeaderField{Name: k, Value: v}) -} + cs.trailer = trailer -func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { - // TODO: this is a common one. It'd be nice to return true - // here and get into the fast path if we could be clever and - // calculate the size fast enough, or at least a conservative - // upper bound that usually fires. (Maybe if w.h and - // w.trailers are nil, so we don't need to enumerate it.) - // Otherwise I'm afraid that just calculating the length to - // answer this question would be slower than the ~2µs benefit. - return false + rl.endStream(cs) + return nil } -func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { - enc, buf := ctx.HeaderEncoder() - buf.Reset() - - if w.httpResCode != 0 { - http2encKV(enc, ":status", http2httpCodeString(w.httpResCode)) - } +// transportResponseBody is the concrete type of Transport.RoundTrip's +// Response.Body. It is an io.ReadCloser. +type http2transportResponseBody struct { + cs *http2clientStream +} - http2encodeHeaders(enc, w.h, w.trailers) +func (b http2transportResponseBody) Read(p []byte) (n int, err error) { + cs := b.cs + cc := cs.cc - if w.contentType != "" { - http2encKV(enc, "content-type", w.contentType) + if cs.readErr != nil { + return 0, cs.readErr } - if w.contentLength != "" { - http2encKV(enc, "content-length", w.contentLength) + n, err = b.cs.bufPipe.Read(p) + if cs.bytesRemain != -1 { + if int64(n) > cs.bytesRemain { + n = int(cs.bytesRemain) + if err == nil { + err = errors.New("net/http: server replied with more than declared Content-Length; truncated") + cs.abortStream(err) + } + cs.readErr = err + return int(cs.bytesRemain), err + } + cs.bytesRemain -= int64(n) + if err == io.EOF && cs.bytesRemain > 0 { + err = io.ErrUnexpectedEOF + cs.readErr = err + return n, err + } } - if w.date != "" { - http2encKV(enc, "date", w.date) + if n == 0 { + // No flow control tokens to send back. + return } - headerBlock := buf.Bytes() - if len(headerBlock) == 0 && w.trailers == nil { - panic("unexpected empty hpack") + cc.mu.Lock() + var connAdd, streamAdd int32 + // Check the conn-level first, before the stream-level. + if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 { + connAdd = http2transportDefaultConnFlow - v + cc.inflow.add(connAdd) } - - return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) -} - -func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { - if firstFrag { - return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ - StreamID: w.streamID, - BlockFragment: frag, - EndStream: w.endStream, - EndHeaders: lastFrag, - }) - } else { - return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + if err == nil { // No need to refresh if the stream is over or failed. + // Consider any buffered body data (read from the conn but not + // consumed by the client) when computing flow control for this + // stream. + v := int(cs.inflow.available()) + cs.bufPipe.Len() + if v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { + streamAdd = int32(http2transportDefaultStreamFlow - v) + cs.inflow.add(streamAdd) + } } -} - -// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. -type http2writePushPromise struct { - streamID uint32 // pusher stream - method string // for :method - url *url.URL // for :scheme, :authority, :path - h http.Header + cc.mu.Unlock() - // Creates an ID for a pushed stream. This runs on serveG just before - // the frame is written. The returned ID is copied to promisedID. - allocatePromisedID func() (uint32, error) - promisedID uint32 + if connAdd != 0 || streamAdd != 0 { + cc.wmu.Lock() + defer cc.wmu.Unlock() + if connAdd != 0 { + cc.fr.WriteWindowUpdate(0, http2mustUint31(connAdd)) + } + if streamAdd != 0 { + cc.fr.WriteWindowUpdate(cs.ID, http2mustUint31(streamAdd)) + } + cc.bw.Flush() + } + return } -func (w *http2writePushPromise) staysWithinBuffer(max int) bool { - // TODO: see writeResHeaders.staysWithinBuffer - return false -} +var http2errClosedResponseBody = errors.New("http2: response body closed") -func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error { - enc, buf := ctx.HeaderEncoder() - buf.Reset() +func (b http2transportResponseBody) Close() error { + cs := b.cs + cc := cs.cc - http2encKV(enc, ":method", w.method) - http2encKV(enc, ":scheme", w.url.Scheme) - http2encKV(enc, ":authority", w.url.Host) - http2encKV(enc, ":path", w.url.RequestURI()) - http2encodeHeaders(enc, w.h, nil) + unread := cs.bufPipe.Len() + if unread > 0 { + cc.mu.Lock() + // Return connection-level flow control. + if unread > 0 { + cc.inflow.add(int32(unread)) + } + cc.mu.Unlock() - headerBlock := buf.Bytes() - if len(headerBlock) == 0 { - panic("unexpected empty hpack") + // TODO(dneil): Acquiring this mutex can block indefinitely. + // Move flow control return to a goroutine? + cc.wmu.Lock() + // Return connection-level flow control. + if unread > 0 { + cc.fr.WriteWindowUpdate(0, uint32(unread)) + } + cc.bw.Flush() + cc.wmu.Unlock() } - return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) -} + cs.bufPipe.BreakWithError(http2errClosedResponseBody) + cs.abortStream(http2errClosedResponseBody) -func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { - if firstFrag { - return ctx.Framer().WritePushPromise(http2PushPromiseParam{ - StreamID: w.streamID, - PromiseID: w.promisedID, - BlockFragment: frag, - EndHeaders: lastFrag, - }) - } else { - return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + select { + case <-cs.donec: + case <-cs.ctx.Done(): + // See golang/go#49366: The net/http package can cancel the + // request context after the response body is fully read. + // Don't treat this as an error. + return nil + case <-cs.reqCancel: + return http2errRequestCanceled } + return nil } -type http2write100ContinueHeadersFrame struct { - streamID uint32 -} - -func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) error { - enc, buf := ctx.HeaderEncoder() - buf.Reset() - http2encKV(enc, ":status", "100") - return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ - StreamID: w.streamID, - BlockFragment: buf.Bytes(), - EndStream: false, - EndHeaders: true, - }) -} - -func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { - // Sloppy but conservative: - return 9+2*(len(":status")+len("100")) <= max -} - -type http2writeWindowUpdate struct { - streamID uint32 // or 0 for conn-level - n uint32 -} - -func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } +func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { + cc := rl.cc + cs := rl.streamByID(f.StreamID) + data := f.Data() + if cs == nil { + cc.mu.Lock() + neverSent := cc.nextStreamID + cc.mu.Unlock() + if f.StreamID >= neverSent { + // We never asked for this. + cc.logf("http2: Transport received unsolicited DATA frame; closing connection") + return http2ConnectionError(http2ErrCodeProtocol) + } + // We probably did ask for this, but canceled. Just ignore it. + // TODO: be stricter here? only silently ignore things which + // we canceled, but not things which were closed normally + // by the peer? Tough without accumulating too much state. -func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { - return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) -} + // But at least return their flow control: + if f.Length > 0 { + cc.mu.Lock() + cc.inflow.add(int32(f.Length)) + cc.mu.Unlock() -// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) -// is encoded only if k is in keys. -func http2encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { - if keys == nil { - sorter := http2sorterPool.Get().(*http2sorter) - // Using defer here, since the returned keys from the - // sorter.Keys method is only valid until the sorter - // is returned: - defer http2sorterPool.Put(sorter) - keys = sorter.Keys(h) + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(f.Length)) + cc.bw.Flush() + cc.wmu.Unlock() + } + return nil } - for _, k := range keys { - vv := h[k] - k, ascii := http2lowerHeader(k) - if !ascii { - // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header - // field names have to be ASCII characters (just as in HTTP/1.x). - continue + if cs.readClosed { + cc.logf("protocol error: received DATA after END_STREAM") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil + } + if !cs.firstByte { + cc.logf("protocol error: received DATA before a HEADERS frame") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil + } + if f.Length > 0 { + if cs.isHead && len(data) > 0 { + cc.logf("protocol error: received DATA on a HEAD request") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil } - if !http2validWireHeaderFieldName(k) { - // Skip it as backup paranoia. Per - // golang.org/issue/14048, these should - // already be rejected at a higher level. - continue + // Check connection-level flow control. + cc.mu.Lock() + if cs.inflow.available() >= int32(f.Length) { + cs.inflow.take(int32(f.Length)) + } else { + cc.mu.Unlock() + return http2ConnectionError(http2ErrCodeFlowControl) } - isTE := k == "transfer-encoding" - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - // TODO: return an error? golang.org/issue/14048 - // For now just omit it. - continue + // Return any padded flow control now, since we won't + // refund it later on body reads. + var refund int + if pad := int(f.Length) - len(data); pad > 0 { + refund += pad + } + + didReset := false + var err error + if len(data) > 0 { + if _, err = cs.bufPipe.Write(data); err != nil { + // Return len(data) now if the stream is already closed, + // since data will never be read. + didReset = true + refund += len(data) } - // TODO: more of "8.1.2.2 Connection-Specific Header Fields" - if isTE && v != "trailers" { - continue + } + + if refund > 0 { + cc.inflow.add(int32(refund)) + if !didReset { + cs.inflow.add(int32(refund)) + } + } + cc.mu.Unlock() + + if refund > 0 { + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(refund)) + if !didReset { + cc.fr.WriteWindowUpdate(cs.ID, uint32(refund)) } - http2encKV(enc, k, v) + cc.bw.Flush() + cc.wmu.Unlock() } - } -} -// WriteScheduler is the interface implemented by HTTP/2 write schedulers. -// Methods are never called concurrently. -type http2WriteScheduler interface { - // OpenStream opens a new stream in the write scheduler. - // It is illegal to call this with streamID=0 or with a streamID that is - // already open -- the call may panic. - OpenStream(streamID uint32, options http2OpenStreamOptions) - - // CloseStream closes a stream in the write scheduler. Any frames queued on - // this stream should be discarded. It is illegal to call this on a stream - // that is not open -- the call may panic. - CloseStream(streamID uint32) - - // AdjustStream adjusts the priority of the given stream. This may be called - // on a stream that has not yet been opened or has been closed. Note that - // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See: - // https://tools.ietf.org/html/rfc7540#section-5.1 - AdjustStream(streamID uint32, priority http2PriorityParam) - - // Push queues a frame in the scheduler. In most cases, this will not be - // called with wr.StreamID()!=0 unless that stream is currently open. The one - // exception is RST_STREAM frames, which may be sent on idle or closed streams. - Push(wr http2FrameWriteRequest) + if err != nil { + rl.endStreamError(cs, err) + return nil + } + } - // Pop dequeues the next frame to write. Returns false if no frames can - // be written. Frames with a given wr.StreamID() are Pop'd in the same - // order they are Push'd, except RST_STREAM frames. No frames should be - // discarded except by CloseStream. - Pop() (wr http2FrameWriteRequest, ok bool) + if f.StreamEnded() { + rl.endStream(cs) + } + return nil } -// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream. -type http2OpenStreamOptions struct { - // PusherID is zero if the stream was initiated by the client. Otherwise, - // PusherID names the stream that pushed the newly opened stream. - PusherID uint32 +func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { + // TODO: check that any declared content-length matches, like + // server.go's (*stream).endStream method. + if !cs.readClosed { + cs.readClosed = true + // Close cs.bufPipe and cs.peerClosed with cc.mu held to avoid a + // race condition: The caller can read io.EOF from Response.Body + // and close the body before we close cs.peerClosed, causing + // cleanupWriteRequest to send a RST_STREAM. + rl.cc.mu.Lock() + defer rl.cc.mu.Unlock() + cs.bufPipe.closeWithErrorAndCode(io.EOF, cs.copyTrailers) + close(cs.peerClosed) + } } -// FrameWriteRequest is a request to write a frame. -type http2FrameWriteRequest struct { - // write is the interface value that does the writing, once the - // WriteScheduler has selected this frame to write. The write - // functions are all defined in write.go. - write http2writeFramer - - // stream is the stream on which this frame will be written. - // nil for non-stream frames like PING and SETTINGS. - // nil for RST_STREAM streams, which use the StreamError.StreamID field instead. - stream *http2stream +func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) { + cs.readAborted = true + cs.abortStream(err) +} - // done, if non-nil, must be a buffered channel with space for - // 1 message and is sent the return value from write (or an - // earlier error) when the frame has been written. - done chan error +func (rl *http2clientConnReadLoop) streamByID(id uint32) *http2clientStream { + rl.cc.mu.Lock() + defer rl.cc.mu.Unlock() + cs := rl.cc.streams[id] + if cs != nil && !cs.readAborted { + return cs + } + return nil } -// StreamID returns the id of the stream this frame will be written to. -// 0 is used for non-stream frames such as PING and SETTINGS. -func (wr http2FrameWriteRequest) StreamID() uint32 { - if wr.stream == nil { - if se, ok := wr.write.(http2StreamError); ok { - // (*serverConn).resetStream doesn't set - // stream because it doesn't necessarily have - // one. So special case this type of write - // message. - return se.StreamID +func (cs *http2clientStream) copyTrailers() { + for k, vv := range cs.trailer { + t := cs.resTrailer + if *t == nil { + *t = make(http.Header) } - return 0 + (*t)[k] = vv } - return wr.stream.id } -// isControl reports whether wr is a control frame for MaxQueuedControlFrames -// purposes. That includes non-stream frames and RST_STREAM frames. -func (wr http2FrameWriteRequest) isControl() bool { - return wr.stream == nil -} +func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { + cc := rl.cc + cc.t.connPool().MarkDead(cc) + if f.ErrCode != 0 { + // TODO: deal with GOAWAY more. particularly the error code + cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) + if fn := cc.t.CountError; fn != nil { + fn("recv_goaway_" + f.ErrCode.stringToken()) + } -// DataSize returns the number of flow control bytes that must be consumed -// to write this entire frame. This is 0 for non-DATA frames. -func (wr http2FrameWriteRequest) DataSize() int { - if wd, ok := wr.write.(*http2writeData); ok { - return len(wd.p) } - return 0 + cc.setGoAway(f) + return nil } -// Consume consumes min(n, available) bytes from this frame, where available -// is the number of flow control bytes available on the stream. Consume returns -// 0, 1, or 2 frames, where the integer return value gives the number of frames -// returned. -// -// If flow control prevents consuming any bytes, this returns (_, _, 0). If -// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this -// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and -// 'rest' contains the remaining bytes. The consumed bytes are deducted from the -// underlying stream's flow control budget. -func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) { - var empty http2FrameWriteRequest - - // Non-DATA frames are always consumed whole. - wd, ok := wr.write.(*http2writeData) - if !ok || len(wd.p) == 0 { - return wr, empty, 1 - } - - // Might need to split after applying limits. - allowed := wr.stream.flow.available() - if n < allowed { - allowed = n - } - if wr.stream.sc.maxFrameSize < allowed { - allowed = wr.stream.sc.maxFrameSize - } - if allowed <= 0 { - return empty, empty, 0 - } - if len(wd.p) > int(allowed) { - wr.stream.flow.take(allowed) - consumed := http2FrameWriteRequest{ - stream: wr.stream, - write: &http2writeData{ - streamID: wd.streamID, - p: wd.p[:allowed], - // Even if the original had endStream set, there - // are bytes remaining because len(wd.p) > allowed, - // so we know endStream is false. - endStream: false, - }, - // Our caller is blocking on the final DATA frame, not - // this intermediate frame, so no need to wait. - done: nil, - } - rest := http2FrameWriteRequest{ - stream: wr.stream, - write: &http2writeData{ - streamID: wd.streamID, - p: wd.p[allowed:], - endStream: wd.endStream, - }, - done: wr.done, - } - return consumed, rest, 2 - } - - // The frame is consumed whole. - // NB: This cast cannot overflow because allowed is <= math.MaxInt32. - wr.stream.flow.take(int32(len(wd.p))) - return wr, empty, 1 -} - -// String is for debugging only. -func (wr http2FrameWriteRequest) String() string { - var des string - if s, ok := wr.write.(fmt.Stringer); ok { - des = s.String() - } else { - des = fmt.Sprintf("%T", wr.write) - } - return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des) -} +func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error { + cc := rl.cc + // Locking both mu and wmu here allows frame encoding to read settings with only wmu held. + // Acquiring wmu when f.IsAck() is unnecessary, but convenient and mostly harmless. + cc.wmu.Lock() + defer cc.wmu.Unlock() -// replyToWriter sends err to wr.done and panics if the send must block -// This does nothing if wr.done is nil. -func (wr *http2FrameWriteRequest) replyToWriter(err error) { - if wr.done == nil { - return + if err := rl.processSettingsNoWrite(f); err != nil { + return err } - select { - case wr.done <- err: - default: - panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write)) + if !f.IsAck() { + cc.fr.WriteSettingsAck() + cc.bw.Flush() } - wr.write = nil // prevent use (assume it's tainted after wr.done send) + return nil } -// writeQueue is used by implementations of WriteScheduler. -type http2writeQueue struct { - s []http2FrameWriteRequest -} +func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) error { + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() -func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } + if f.IsAck() { + if cc.wantSettingsAck { + cc.wantSettingsAck = false + return nil + } + return http2ConnectionError(http2ErrCodeProtocol) + } -func (q *http2writeQueue) push(wr http2FrameWriteRequest) { - q.s = append(q.s, wr) -} + var seenMaxConcurrentStreams bool + err := f.ForeachSetting(func(s http2Setting) error { + switch s.ID { + case http2SettingMaxFrameSize: + cc.maxFrameSize = s.Val + case http2SettingMaxConcurrentStreams: + cc.maxConcurrentStreams = s.Val + seenMaxConcurrentStreams = true + case http2SettingMaxHeaderListSize: + cc.peerMaxHeaderListSize = uint64(s.Val) + case http2SettingInitialWindowSize: + // Values above the maximum flow-control + // window size of 2^31-1 MUST be treated as a + // connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR. + if s.Val > math.MaxInt32 { + return http2ConnectionError(http2ErrCodeFlowControl) + } -func (q *http2writeQueue) shift() http2FrameWriteRequest { - if len(q.s) == 0 { - panic("invalid use of queue") - } - wr := q.s[0] - // TODO: less copy-happy queue. - copy(q.s, q.s[1:]) - q.s[len(q.s)-1] = http2FrameWriteRequest{} - q.s = q.s[:len(q.s)-1] - return wr -} + // Adjust flow control of currently-open + // frames by the difference of the old initial + // window size and this one. + delta := int32(s.Val) - int32(cc.initialWindowSize) + for _, cs := range cc.streams { + cs.flow.add(delta) + } + cc.cond.Broadcast() -// consume consumes up to n bytes from q.s[0]. If the frame is -// entirely consumed, it is removed from the queue. If the frame -// is partially consumed, the frame is kept with the consumed -// bytes removed. Returns true iff any bytes were consumed. -func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) { - if len(q.s) == 0 { - return http2FrameWriteRequest{}, false + cc.initialWindowSize = s.Val + default: + // TODO(bradfitz): handle more settings? SETTINGS_HEADER_TABLE_SIZE probably. + cc.vlogf("Unhandled Setting: %v", s) + } + return nil + }) + if err != nil { + return err } - consumed, rest, numresult := q.s[0].Consume(n) - switch numresult { - case 0: - return http2FrameWriteRequest{}, false - case 1: - q.shift() - case 2: - q.s[0] = rest + + if !cc.seenSettings { + if !seenMaxConcurrentStreams { + // This was the servers initial SETTINGS frame and it + // didn't contain a MAX_CONCURRENT_STREAMS field so + // increase the number of concurrent streams this + // connection can establish to our default. + cc.maxConcurrentStreams = http2defaultMaxConcurrentStreams + } + cc.seenSettings = true } - return consumed, true + + return nil } -type http2writeQueuePool []*http2writeQueue +func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error { + cc := rl.cc + cs := rl.streamByID(f.StreamID) + if f.StreamID != 0 && cs == nil { + return nil + } -// put inserts an unused writeQueue into the pool. + cc.mu.Lock() + defer cc.mu.Unlock() -// put inserts an unused writeQueue into the pool. -func (p *http2writeQueuePool) put(q *http2writeQueue) { - for i := range q.s { - q.s[i] = http2FrameWriteRequest{} + fl := &cc.flow + if cs != nil { + fl = &cs.flow + } + if !fl.add(int32(f.Increment)) { + return http2ConnectionError(http2ErrCodeFlowControl) } - q.s = q.s[:0] - *p = append(*p, q) + cc.cond.Broadcast() + return nil } -// get returns an empty writeQueue. -func (p *http2writeQueuePool) get() *http2writeQueue { - ln := len(*p) - if ln == 0 { - return new(http2writeQueue) +func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error { + cs := rl.streamByID(f.StreamID) + if cs == nil { + // TODO: return error if server tries to RST_STREAM an idle stream + return nil + } + serr := http2streamError(cs.ID, f.ErrCode) + serr.Cause = http2errFromPeer + if f.ErrCode == http2ErrCodeProtocol { + rl.cc.SetDoNotReuse() + } + if fn := cs.cc.t.CountError; fn != nil { + fn("recv_rststream_" + f.ErrCode.stringToken()) } - x := ln - 1 - q := (*p)[x] - (*p)[x] = nil - *p = (*p)[:x] - return q + cs.abortStream(serr) + + cs.bufPipe.CloseWithError(serr) + return nil } -// RFC 7540, Section 5.3.5: the default weight is 16. -const http2priorityDefaultWeight = 15 // 16 = 15 + 1 +// Ping sends a PING frame to the server and waits for the ack. +func (cc *http2ClientConn) Ping(ctx context.Context) error { + c := make(chan struct{}) + // Generate a random payload + var p [8]byte + for { + if _, err := rand.Read(p[:]); err != nil { + return err + } + cc.mu.Lock() + // check for dup before insert + if _, found := cc.pings[p]; !found { + cc.pings[p] = c + cc.mu.Unlock() + break + } + cc.mu.Unlock() + } + errc := make(chan error, 1) + go func() { + cc.wmu.Lock() + defer cc.wmu.Unlock() + if err := cc.fr.WritePing(false, p); err != nil { + errc <- err + return + } + if err := cc.bw.Flush(); err != nil { + errc <- err + return + } + }() + select { + case <-c: + return nil + case err := <-errc: + return err + case <-ctx.Done(): + return ctx.Err() + case <-cc.readerDone: + // connection closed + return cc.readerErr + } +} -// PriorityWriteSchedulerConfig configures a priorityWriteScheduler. -type http2PriorityWriteSchedulerConfig struct { - // MaxClosedNodesInTree controls the maximum number of closed streams to - // retain in the priority tree. Setting this to zero saves a small amount - // of memory at the cost of performance. - // - // See RFC 7540, Section 5.3.4: - // "It is possible for a stream to become closed while prioritization - // information ... is in transit. ... This potentially creates suboptimal - // prioritization, since the stream could be given a priority that is - // different from what is intended. To avoid these problems, an endpoint - // SHOULD retain stream prioritization state for a period after streams - // become closed. The longer state is retained, the lower the chance that - // streams are assigned incorrect or default priority values." - MaxClosedNodesInTree int - - // MaxIdleNodesInTree controls the maximum number of idle streams to - // retain in the priority tree. Setting this to zero saves a small amount - // of memory at the cost of performance. - // - // See RFC 7540, Section 5.3.4: - // Similarly, streams that are in the "idle" state can be assigned - // priority or become a parent of other streams. This allows for the - // creation of a grouping node in the dependency tree, which enables - // more flexible expressions of priority. Idle streams begin with a - // default priority (Section 5.3.5). - MaxIdleNodesInTree int - - // ThrottleOutOfOrderWrites enables write throttling to help ensure that - // data is delivered in priority order. This works around a race where - // stream B depends on stream A and both streams are about to call Write - // to queue DATA frames. If B wins the race, a naive scheduler would eagerly - // write as much data from B as possible, but this is suboptimal because A - // is a higher-priority stream. With throttling enabled, we write a small - // amount of data from B to minimize the amount of bandwidth that B can - // steal from A. - ThrottleOutOfOrderWrites bool -} - -// NewPriorityWriteScheduler constructs a WriteScheduler that schedules -// frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3. -// If cfg is nil, default options are used. -func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler { - if cfg == nil { - // For justification of these defaults, see: - // https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY - cfg = &http2PriorityWriteSchedulerConfig{ - MaxClosedNodesInTree: 10, - MaxIdleNodesInTree: 10, - ThrottleOutOfOrderWrites: false, - } - } - - ws := &http2priorityWriteScheduler{ - nodes: make(map[uint32]*http2priorityNode), - maxClosedNodesInTree: cfg.MaxClosedNodesInTree, - maxIdleNodesInTree: cfg.MaxIdleNodesInTree, - enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, - } - ws.nodes[0] = &ws.root - if cfg.ThrottleOutOfOrderWrites { - ws.writeThrottleLimit = 1024 - } else { - ws.writeThrottleLimit = math.MaxInt32 +func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { + if f.IsAck() { + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() + // If ack, notify listener if any + if c, ok := cc.pings[f.Data]; ok { + close(c) + delete(cc.pings, f.Data) + } + return nil + } + cc := rl.cc + cc.wmu.Lock() + defer cc.wmu.Unlock() + if err := cc.fr.WritePing(true, f.Data); err != nil { + return err } - return ws + return cc.bw.Flush() +} + +func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) error { + // We told the peer we don't want them. + // Spec says: + // "PUSH_PROMISE MUST NOT be sent if the SETTINGS_ENABLE_PUSH + // setting of the peer endpoint is set to 0. An endpoint that + // has set this setting and has received acknowledgement MUST + // treat the receipt of a PUSH_PROMISE frame as a connection + // error (Section 5.4.1) of type PROTOCOL_ERROR." + return http2ConnectionError(http2ErrCodeProtocol) } -type http2priorityNodeState int +func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) { + // TODO: map err to more interesting error codes, once the + // HTTP community comes up with some. But currently for + // RST_STREAM there's no equivalent to GOAWAY frame's debug + // data, and the error codes are all pretty vague ("cancel"). + cc.wmu.Lock() + cc.fr.WriteRSTStream(streamID, code) + cc.bw.Flush() + cc.wmu.Unlock() +} -const ( - http2priorityNodeOpen http2priorityNodeState = iota - http2priorityNodeClosed - http2priorityNodeIdle +var ( + http2errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") + http2errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit") ) -// priorityNode is a node in an HTTP/2 priority tree. -// Each node is associated with a single stream ID. -// See RFC 7540, Section 5.3. -type http2priorityNode struct { - q http2writeQueue // queue of pending frames to write - id uint32 // id of the stream, or 0 for the root of the tree - weight uint8 // the actual weight is weight+1, so the value is in [1,256] - state http2priorityNodeState // open | closed | idle - bytes int64 // number of bytes written by this node, or 0 if closed - subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree +func (cc *http2ClientConn) logf(format string, args ...interface{}) { + cc.t.logf(format, args...) +} - // These links form the priority tree. - parent *http2priorityNode - kids *http2priorityNode // start of the kids list - prev, next *http2priorityNode // doubly-linked list of siblings +func (cc *http2ClientConn) vlogf(format string, args ...interface{}) { + cc.t.vlogf(format, args...) } -func (n *http2priorityNode) setParent(parent *http2priorityNode) { - if n == parent { - panic("setParent to self") - } - if n.parent == parent { - return - } - // Unlink from current parent. - if parent := n.parent; parent != nil { - if n.prev == nil { - parent.kids = n.next - } else { - n.prev.next = n.next - } - if n.next != nil { - n.next.prev = n.prev - } - } - // Link to new parent. - // If parent=nil, remove n from the tree. - // Always insert at the head of parent.kids (this is assumed by walkReadyInOrder). - n.parent = parent - if parent == nil { - n.next = nil - n.prev = nil - } else { - n.next = parent.kids - n.prev = nil - if n.next != nil { - n.next.prev = n - } - parent.kids = n +func (t *http2Transport) vlogf(format string, args ...interface{}) { + if http2VerboseLogs { + t.logf(format, args...) } } -func (n *http2priorityNode) addBytes(b int64) { - n.bytes += b - for ; n != nil; n = n.parent { - n.subtreeBytes += b - } +func (t *http2Transport) logf(format string, args ...interface{}) { + log.Printf(format, args...) } -// walkReadyInOrder iterates over the tree in priority order, calling f for each node -// with a non-empty write queue. When f returns true, this function returns true and the -// walk halts. tmp is used as scratch space for sorting. -// -// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true -// if any ancestor p of n is still open (ignoring the root node). -func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool { - if !n.q.empty() && f(n, openParent) { - return true - } - if n.kids == nil { - return false - } +var http2noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) - // Don't consider the root "open" when updating openParent since - // we can't send data frames on the root stream (only control frames). - if n.id != 0 { - openParent = openParent || (n.state == http2priorityNodeOpen) - } +type http2missingBody struct{} - // Common case: only one kid or all kids have the same weight. - // Some clients don't use weights; other clients (like web browsers) - // use mostly-linear priority trees. - w := n.kids.weight - needSort := false - for k := n.kids.next; k != nil; k = k.next { - if k.weight != w { - needSort = true - break - } - } - if !needSort { - for k := n.kids; k != nil; k = k.next { - if k.walkReadyInOrder(openParent, tmp, f) { - return true - } - } - return false - } +func (http2missingBody) Close() error { return nil } - // Uncommon case: sort the child nodes. We remove the kids from the parent, - // then re-insert after sorting so we can reuse tmp for future sort calls. - *tmp = (*tmp)[:0] - for n.kids != nil { - *tmp = append(*tmp, n.kids) - n.kids.setParent(nil) - } - sort.Sort(http2sortPriorityNodeSiblings(*tmp)) - for i := len(*tmp) - 1; i >= 0; i-- { - (*tmp)[i].setParent(n) // setParent inserts at the head of n.kids - } - for k := n.kids; k != nil; k = k.next { - if k.walkReadyInOrder(openParent, tmp, f) { +func (http2missingBody) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF } + +func http2strSliceContains(ss []string, s string) bool { + for _, v := range ss { + if v == s { return true } } return false } -type http2sortPriorityNodeSiblings []*http2priorityNode +type http2erringRoundTripper struct{ err error } + +func (rt http2erringRoundTripper) RoundTripErr() error { return rt.err } -func (z http2sortPriorityNodeSiblings) Len() int { return len(z) } +func (rt http2erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return nil, rt.err +} -func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type http2gzipReader struct { + _ http2incomparable + body io.ReadCloser // underlying Response.Body + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // sticky error +} -func (z http2sortPriorityNodeSiblings) Less(i, k int) bool { - // Prefer the subtree that has sent fewer bytes relative to its weight. - // See sections 5.3.2 and 5.3.4. - wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) - wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) - if bi == 0 && bk == 0 { - return wi >= wk +func (gz *http2gzipReader) Read(p []byte) (n int, err error) { + if gz.zerr != nil { + return 0, gz.zerr } - if bk == 0 { - return false + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + gz.zerr = err + return 0, err + } } - return bi/bk <= wi/wk + return gz.zr.Read(p) } -type http2priorityWriteScheduler struct { - // root is the root of the priority tree, where root.id = 0. - // The root queues control frames that are not associated with any stream. - root http2priorityNode - - // nodes maps stream ids to priority tree nodes. - nodes map[uint32]*http2priorityNode - - // maxID is the maximum stream id in nodes. - maxID uint32 - - // lists of nodes that have been closed or are idle, but are kept in - // the tree for improved prioritization. When the lengths exceed either - // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. - closedNodes, idleNodes []*http2priorityNode +func (gz *http2gzipReader) Close() error { + return gz.body.Close() +} - // From the config. - maxClosedNodesInTree int - maxIdleNodesInTree int - writeThrottleLimit int32 - enableWriteThrottle bool +type http2errorReader struct{ err error } - // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. - tmp []*http2priorityNode +func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } - // pool of empty queues for reuse. - queuePool http2writeQueuePool +// isConnectionCloseRequest reports whether req should use its own +// connection for a single request and then close the connection. +func http2isConnectionCloseRequest(req *http.Request) bool { + return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close") } -func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { - // The stream may be currently idle but cannot be opened or closed. - if curr := ws.nodes[streamID]; curr != nil { - if curr.state != http2priorityNodeIdle { - panic(fmt.Sprintf("stream %d already opened", streamID)) +// registerHTTPSProtocol calls Transport.RegisterProtocol but +// converting panics into errors. +func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("%v", e) } - curr.state = http2priorityNodeOpen - return - } - - // RFC 7540, Section 5.3.5: - // "All streams are initially assigned a non-exclusive dependency on stream 0x0. - // Pushed streams initially depend on their associated stream. In both cases, - // streams are assigned a default weight of 16." - parent := ws.nodes[options.PusherID] - if parent == nil { - parent = &ws.root - } - n := &http2priorityNode{ - q: *ws.queuePool.get(), - id: streamID, - weight: http2priorityDefaultWeight, - state: http2priorityNodeOpen, - } - n.setParent(parent) - ws.nodes[streamID] = n - if streamID > ws.maxID { - ws.maxID = streamID - } + }() + t.RegisterProtocol("https", rt) + return nil } -func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) { - if streamID == 0 { - panic("violation of WriteScheduler interface: cannot close stream 0") - } - if ws.nodes[streamID] == nil { - panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) - } - if ws.nodes[streamID].state != http2priorityNodeOpen { - panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) - } - - n := ws.nodes[streamID] - n.state = http2priorityNodeClosed - n.addBytes(-n.bytes) +// noDialH2RoundTripper is a RoundTripper which only tries to complete the request +// if there's already has a cached connection to the host. +// (The field is exported so it can be accessed via reflect from net/http; tested +// by TestNoDialH2RoundTripperType) +type http2noDialH2RoundTripper struct{ *http2Transport } - q := n.q - ws.queuePool.put(&q) - n.q.s = nil - if ws.maxClosedNodesInTree > 0 { - ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) - } else { - ws.removeNode(n) +func (rt http2noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + res, err := rt.http2Transport.RoundTrip(req) + if http2isNoCachedConnError(err) { + return nil, http.ErrSkipAltProtocol } + return res, err } -func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { - if streamID == 0 { - panic("adjustPriority on root") - } - - // If streamID does not exist, there are two cases: - // - A closed stream that has been removed (this will have ID <= maxID) - // - An idle stream that is being used for "grouping" (this will have ID > maxID) - n := ws.nodes[streamID] - if n == nil { - if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 { - return - } - ws.maxID = streamID - n = &http2priorityNode{ - q: *ws.queuePool.get(), - id: streamID, - weight: http2priorityDefaultWeight, - state: http2priorityNodeIdle, - } - n.setParent(&ws.root) - ws.nodes[streamID] = n - ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n) +func (t *http2Transport) idleConnTimeout() time.Duration { + if t.t1 != nil { + return t.t1.IdleConnTimeout } + return 0 +} - // Section 5.3.1: A dependency on a stream that is not currently in the tree - // results in that stream being given a default priority (Section 5.3.5). - parent := ws.nodes[priority.StreamDep] - if parent == nil { - n.setParent(&ws.root) - n.weight = http2priorityDefaultWeight +func http2traceGetConn(req *http.Request, hostPort string) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GetConn == nil { return } + trace.GetConn(hostPort) +} - // Ignore if the client tries to make a node its own parent. - if n == parent { +func http2traceGotConn(req *http.Request, cc *http2ClientConn, reused bool) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GotConn == nil { return } - - // Section 5.3.3: - // "If a stream is made dependent on one of its own dependencies, the - // formerly dependent stream is first moved to be dependent on the - // reprioritized stream's previous parent. The moved dependency retains - // its weight." - // - // That is: if parent depends on n, move parent to depend on n.parent. - for x := parent.parent; x != nil; x = x.parent { - if x == n { - parent.setParent(n.parent) - break - } - } - - // Section 5.3.3: The exclusive flag causes the stream to become the sole - // dependency of its parent stream, causing other dependencies to become - // dependent on the exclusive stream. - if priority.Exclusive { - k := parent.kids - for k != nil { - next := k.next - if k != n { - k.setParent(n) - } - k = next - } + ci := httptrace.GotConnInfo{Conn: cc.tconn} + ci.Reused = reused + cc.mu.Lock() + ci.WasIdle = len(cc.streams) == 0 && reused + if ci.WasIdle && !cc.lastActive.IsZero() { + ci.IdleTime = time.Now().Sub(cc.lastActive) } + cc.mu.Unlock() - n.setParent(parent) - n.weight = priority.Weight + trace.GotConn(ci) } -func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { - var n *http2priorityNode - if id := wr.StreamID(); id == 0 { - n = &ws.root - } else { - n = ws.nodes[id] - if n == nil { - // id is an idle or closed stream. wr should not be a HEADERS or - // DATA frame. However, wr can be a RST_STREAM. In this case, we - // push wr onto the root, rather than creating a new priorityNode, - // since RST_STREAM is tiny and the stream's priority is unknown - // anyway. See issue #17919. - if wr.DataSize() > 0 { - panic("add DATA on non-open stream") - } - n = &ws.root - } +func http2traceWroteHeaders(trace *httptrace.ClientTrace) { + if trace != nil && trace.WroteHeaders != nil { + trace.WroteHeaders() } - n.q.push(wr) -} - -func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) { - ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool { - limit := int32(math.MaxInt32) - if openParent { - limit = ws.writeThrottleLimit - } - wr, ok = n.q.consume(limit) - if !ok { - return false - } - n.addBytes(int64(wr.DataSize())) - // If B depends on A and B continuously has data available but A - // does not, gradually increase the throttling limit to allow B to - // steal more and more bandwidth from A. - if openParent { - ws.writeThrottleLimit += 1024 - if ws.writeThrottleLimit < 0 { - ws.writeThrottleLimit = math.MaxInt32 - } - } else if ws.enableWriteThrottle { - ws.writeThrottleLimit = 1024 - } - return true - }) - return wr, ok } -func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorityNode, maxSize int, n *http2priorityNode) { - if maxSize == 0 { - return - } - if len(*list) == maxSize { - // Remove the oldest node, then shift left. - ws.removeNode((*list)[0]) - x := (*list)[1:] - copy(*list, x) - *list = (*list)[:len(x)] +func http2traceGot100Continue(trace *httptrace.ClientTrace) { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() } - *list = append(*list, n) } -func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) { - for k := n.kids; k != nil; k = k.next { - k.setParent(n.parent) +func http2traceWait100Continue(trace *httptrace.ClientTrace) { + if trace != nil && trace.Wait100Continue != nil { + trace.Wait100Continue() } - n.setParent(nil) - delete(ws.nodes, n.id) } -// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2 -// priorities. Control frames like SETTINGS and PING are written before DATA -// frames, but if no control frames are queued and multiple streams have queued -// HEADERS or DATA frames, Pop selects a ready stream arbitrarily. -func http2NewRandomWriteScheduler() http2WriteScheduler { - return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)} +func http2traceWroteRequest(trace *httptrace.ClientTrace, err error) { + if trace != nil && trace.WroteRequest != nil { + trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) + } } -type http2randomWriteScheduler struct { - // zero are frames not associated with a specific stream. - zero http2writeQueue - - // sq contains the stream-specific queues, keyed by stream ID. - // When a stream is idle, closed, or emptied, it's deleted - // from the map. - sq map[uint32]*http2writeQueue - - // pool of empty queues for reuse. - queuePool http2writeQueuePool +func http2traceFirstResponseByte(trace *httptrace.ClientTrace) { + if trace != nil && trace.GotFirstResponseByte != nil { + trace.GotFirstResponseByte() + } } -func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { - // no-op: idle streams are not tracked +// writeContext is the interface needed by the various frame writer +// types below. All the writeFrame methods below are scheduled via the +// frame writing scheduler (see writeScheduler in writesched.go). +// +// This interface is implemented by *serverConn. +// +// TODO: decide whether to a) use this in the client code (which didn't +// end up using this yet, because it has a simpler design, not +// currently implementing priorities), or b) delete this and +// make the server code a bit more concrete. +type http2writeContext interface { + Framer() *http2Framer + Flush() error + CloseConn() error + // HeaderEncoder returns an HPACK encoder that writes to the + // returned buffer. + HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) } -func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { - q, ok := ws.sq[streamID] - if !ok { - return - } - delete(ws.sq, streamID) - ws.queuePool.put(q) +func (se http2StreamError) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) } -func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { - // no-op: priorities are ignored -} +func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } -func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { - if wr.isControl() { - ws.zero.push(wr) - return - } - id := wr.StreamID() - q, ok := ws.sq[id] - if !ok { - q = ws.queuePool.get() - ws.sq[id] = q +func http2encKV(enc *hpack.Encoder, k, v string) { + if http2VerboseLogs { + log.Printf("http2: server encoding header %q = %q", k, v) } - q.push(wr) + enc.WriteField(hpack.HeaderField{Name: k, Value: v}) } -func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { - // Control and RST_STREAM frames first. - if !ws.zero.empty() { - return ws.zero.shift(), true +// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) +// is encoded only if k is in keys. +func http2encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { + if keys == nil { + sorter := http2sorterPool.Get().(*http2sorter) + // Using defer here, since the returned keys from the + // sorter.Keys method is only valid until the sorter + // is returned: + defer http2sorterPool.Put(sorter) + keys = sorter.Keys(h) } - // Iterate over all non-idle streams until finding one that can be consumed. - for streamID, q := range ws.sq { - if wr, ok := q.consume(math.MaxInt32); ok { - if q.empty() { - delete(ws.sq, streamID) - ws.queuePool.put(q) + for _, k := range keys { + vv := h[k] + k, ascii := http2lowerHeader(k) + if !ascii { + // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header + // field names have to be ASCII characters (just as in HTTP/1.x). + continue + } + if !http2validWireHeaderFieldName(k) { + // Skip it as backup paranoia. Per + // golang.org/issue/14048, these should + // already be rejected at a higher level. + continue + } + isTE := k == "transfer-encoding" + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // TODO: return an error? golang.org/issue/14048 + // For now just omit it. + continue + } + // TODO: more of "8.1.2.2 Connection-Specific Header Fields" + if isTE && v != "trailers" { + continue } - return wr, true + http2encKV(enc, k, v) } } - return http2FrameWriteRequest{}, false } From 0c1988b8ee4b085b78b81bf69722ef46810b27c4 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 4 Feb 2022 21:19:35 +0800 Subject: [PATCH 199/843] remove unused dump in http2Framer --- h2_bundle.go | 1 - 1 file changed, 1 deletion(-) diff --git a/h2_bundle.go b/h2_bundle.go index f4676542..404adac4 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -958,7 +958,6 @@ type http2Frame interface { // A Framer reads and writes Frames. type http2Framer struct { cc *http2ClientConn - dump *dumper r io.Reader lastFrame http2Frame errDetail error From 35ff901fc4609ef5030e59b93423ff450543a7f7 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 00:47:25 +0800 Subject: [PATCH 200/843] API refactoring and bump to v3 --- README.md | 51 ++++--- client.go | 379 ++++++++++++++++++++++++++++++---------------------- req_test.go | 2 +- request.go | 174 ++++++++++++++++-------- 4 files changed, 368 insertions(+), 238 deletions(-) diff --git a/README.md b/README.md index da1f0854..f339cb49 100644 --- a/README.md +++ b/README.md @@ -6,10 +6,12 @@ ## Big News -Brand new v2 version is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) +Brand new v3 version is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) If you want to use the older version, check it out on [v1 branch](https://github.com/imroc/req/tree/v1). +> v2 is a transitional version. the latest version is v3 cuz some API-incompatible changes was involved during v2 refactoring, checkout [v2 branch](https://github.com/imroc/req/tree/v2) if you want. + ## Table of Contents * [Features](#Features) @@ -47,13 +49,13 @@ If you want to use the older version, check it out on [v1 branch](https://github **Install** ``` sh -go get github.com/imroc/req/v2 +go get github.com/imroc/req/v3 ``` **Import** ```go -import "github.com/imroc/req/v2" +import "github.com/imroc/req/v3" ``` ```go @@ -82,9 +84,9 @@ Checkout more runnable examples in the [examples](examples) direcotry. **Dump the Content** ```go -// Set EnableDump to true, dump all content to stdout by default, +// Enable dump for all requests, including all content to stdout by default, // including both the header and body of all request and response -client := req.C().EnableDump(true) +client := req.C().EnableDumpAll() client.R().Get("https://httpbin.org/get") /* Output @@ -117,9 +119,9 @@ access-control-allow-credentials: true */ // Customize dump settings with predefined convenience settings. -client.EnableDumpOnlyHeader(). // Only dump the header of request and response - EnableDumpAsync(). // Dump asynchronously to improve performance - EnableDumpToFile("reqdump.log") // Dump to file without printing it out +client.EnableDumpAllWithoutBody(). // Only dump the header of request and response + EnableDumpAllAsync(). // Dump asynchronously to improve performance + EnableDumpAllToFile("reqdump.log") // Dump to file without printing it out // Send request to see the content that have been dumpped client.R().Get(url) @@ -132,24 +134,39 @@ opt := &req.DumpOptions{ ResponseHeader: false, Async: false, } -client.SetDumpOptions(opt).EnableDump(true) +client.SetCommonDumpOptions(opt).EnableDumpAll() client.R().Get("https://www.baidu.com/") // Change settings dynamiclly opt.ResponseBody = false client.R().Get("https://www.baidu.com/") -// Dump single request -resp, err := client.R().DumpAll().SetBody("test body").Post("https://httpbin.org/post") -fmt.Println(resp.Dump()) +// You can also enable dump at request level, dump to memory and do not print out +// by default, you can call `Response.Dump()` to get the dump result and print +// only if you want to. +resp, err := client.R().EnableDump().SetBody("test body").Post("https://httpbin.org/post") +if err != nil { + fmt.Println("err:", err) + fmt.Println("dump:", resp.Dump()) + return +} +if resp.StatusCode > 299 { + fmt.Println("bad status:", resp.Status) + fmt.Println("dump:", resp.Dump()) +} + +// And also support customize dump settings with predefined convenience settings like client level. +resp, err = client.R().EnableDumpWithoutRequest().SetBody("test body").Post("https://httpbin.org/post") +// ... +resp, err = client.R().SetDumpOptions(opt).SetBody("test body").Post("https://httpbin.org/post") ``` **Enable DebugLog for Deeper Insights** ```go // Logging is enabled by default, but only output the warning and error message. -// set `EnableDebugLog` to true to enable debug level logging. -client := req.C().EnableDebugLog(true) +// Use `EnableDebugLog` to enable debug level logging. +client := req.C().EnableDebugLog() client.R().Get("http://baidu.com/s?wd=req") /* Output 2022/01/26 15:46:29.279368 DEBUG [req] GET http://baidu.com/s?wd=req @@ -170,7 +187,7 @@ client.SetLogger(logger) ```go // Enable trace at request level client := req.C() -resp, err := client.R().EnableTrace(true).Get("https://api.github.com/users/imroc") +resp, err := client.R().EnableTrace().Get("https://api.github.com/users/imroc") if err != nil { log.Fatal(err) } @@ -588,7 +605,7 @@ client.SetXmlMarshal(xmlMarshalFunc).SetXmlUnmarshal(xmlUnmarshalFunc) Response body will be read into memory if it's not a download request by default, you can disable it if you want (normally you don't need to do this). ```go -client.DisableAutoReadResponse(true) +client.DisableAutoReadResponse() resp, err := client.R().Get(url) if err != nil { @@ -696,7 +713,7 @@ Its principle is to detect whether `Content-Type` header at first, if it's not t You can also disable if you don't need or care a lot about performance: ```go -client.DisableAutoDecode(true) +client.DisableAutoDecode() ``` Also you can make some customization: diff --git a/client.go b/client.go index 27598f6f..0ed64f01 100644 --- a/client.go +++ b/client.go @@ -326,27 +326,64 @@ func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { // DisableKeepAlives is a global wrapper methods which delegated // to the default client's DisableKeepAlives. -func DisableKeepAlives(disable bool) *Client { - return defaultClient.DisableKeepAlives(disable) +func DisableKeepAlives() *Client { + return defaultClient.DisableKeepAlives() } -// DisableKeepAlives set to true disables HTTP keep-alives and -// will only use the connection to the server for a single -// HTTP request. +// DisableKeepAlives disables HTTP keep-alives (enabled by default) +// and will only use the connection to the server for a single +// HTTP request . // // This is unrelated to the similarly named TCP keep-alives. -func (c *Client) DisableKeepAlives(disable bool) *Client { - c.t.DisableKeepAlives = disable +func (c *Client) DisableKeepAlives() *Client { + c.t.DisableKeepAlives = true + return c +} + +// EnableKeepAlives is a global wrapper methods which delegated +// to the default client's EnableKeepAlives. +func EnableKeepAlives() *Client { + return defaultClient.EnableKeepAlives() +} + +// EnableKeepAlives enables HTTP keep-alives (enabled by default) +// and will only use the connection to the server for a single +// HTTP request . +// +// This is unrelated to the similarly named TCP keep-alives. +func (c *Client) EnableKeepAlives() *Client { + c.t.DisableKeepAlives = false return c } // DisableCompression is a global wrapper methods which delegated // to the default client's DisableCompression. -func DisableCompression(disable bool) *Client { - return defaultClient.DisableCompression(disable) +func DisableCompression() *Client { + return defaultClient.DisableCompression() +} + +// DisableCompression disables the compression (enabled by default), +// which prevents the Transport from +// requesting compression with an "Accept-Encoding: gzip" +// request header when the Request contains no existing +// Accept-Encoding value. If the Transport requests gzip on +// its own and gets a gzipped response, it's transparently +// decoded in the Response.Body. However, if the user +// explicitly requested gzip it is not automatically +// uncompressed. +func (c *Client) DisableCompression() *Client { + c.t.DisableCompression = true + return c +} + +// EnableCompression is a global wrapper methods which delegated +// to the default client's EnableCompression. +func EnableCompression() *Client { + return defaultClient.EnableCompression() } -// DisableCompression set to true prevents the Transport from +// EnableCompression enables the compression (enabled by default), +// which prevents the Transport from // requesting compression with an "Accept-Encoding: gzip" // request header when the Request contains no existing // Accept-Encoding value. If the Transport requests gzip on @@ -354,8 +391,8 @@ func DisableCompression(disable bool) *Client { // decoded in the Response.Body. However, if the user // explicitly requested gzip it is not automatically // uncompressed. -func (c *Client) DisableCompression(disable bool) *Client { - c.t.DisableCompression = disable +func (c *Client) EnableCompression() *Client { + c.t.DisableCompression = false return c } @@ -453,15 +490,27 @@ func (c *Client) SetCommonCookies(cookies ...*http.Cookie) *Client { return c } +// DisableDebugLog is a global wrapper methods which delegated +// to the default client's DisableDebugLog. +func DisableDebugLog() *Client { + return defaultClient.DisableDebugLog() +} + +// DisableDebugLog disables debug level log (disabled by default). +func (c *Client) DisableDebugLog() *Client { + c.DebugLog = false + return c +} + // EnableDebugLog is a global wrapper methods which delegated // to the default client's EnableDebugLog. -func EnableDebugLog(enable bool) *Client { - return defaultClient.EnableDebugLog(enable) +func EnableDebugLog() *Client { + return defaultClient.EnableDebugLog() } -// EnableDebugLog enables debug level log if set to true. -func (c *Client) EnableDebugLog(enable bool) *Client { - c.DebugLog = enable +// EnableDebugLog enables debug level log (disabled by default). +func (c *Client) EnableDebugLog() *Client { + c.DebugLog = true return c } @@ -478,8 +527,8 @@ func DevMode() *Client { // 4. Set User-Agent to pretend to be a web browser, avoid returning abnormal data from some sites. func (c *Client) DevMode() *Client { return c.EnableDumpAll(). - EnableDebugLog(true). - EnableTraceAll(true). + EnableDebugLog(). + EnableTraceAll(). SetUserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36") } @@ -540,179 +589,141 @@ func (c *Client) getDumpOptions() *DumpOptions { return c.dumpOptions } -func (c *Client) enableDump() { +// EnableDumpAll is a global wrapper methods which delegated +// to the default client's EnableDumpAll. +func EnableDumpAll() *Client { + return defaultClient.EnableDumpAll() +} + +// EnableDumpAll enables dump for all requests, including +// all content for the request and response by default. +func (c *Client) EnableDumpAll() *Client { if c.t.dump != nil { // dump already started - return + return c } c.t.EnableDump(c.getDumpOptions()) + return c } -// EnableDumpToFile is a global wrapper methods which delegated -// to the default client's EnableDumpToFile. -func EnableDumpToFile(filename string) *Client { - return defaultClient.EnableDumpToFile(filename) +// EnableDumpAllToFile is a global wrapper methods which delegated +// to the default client's EnableDumpAllToFile. +func EnableDumpAllToFile(filename string) *Client { + return defaultClient.EnableDumpAllToFile(filename) } -// EnableDumpToFile enables dump and save to the specified filename. -func (c *Client) EnableDumpToFile(filename string) *Client { +// EnableDumpAllToFile enables dump and save to the specified filename. +func (c *Client) EnableDumpAllToFile(filename string) *Client { file, err := os.Create(filename) if err != nil { c.log.Errorf("create dump file error: %v", err) return c } c.getDumpOptions().Output = file - c.enableDump() + c.EnableDumpAll() return c } -// EnableDumpTo is a global wrapper methods which delegated -// to the default client's EnableDumpTo. -func EnableDumpTo(output io.Writer) *Client { - return defaultClient.EnableDumpTo(output) +// EnableDumpAllTo is a global wrapper methods which delegated +// to the default client's EnableDumpAllTo. +func EnableDumpAllTo(output io.Writer) *Client { + return defaultClient.EnableDumpAllTo(output) } -// EnableDumpTo enables dump and save to the specified io.Writer. -func (c *Client) EnableDumpTo(output io.Writer) *Client { +// EnableDumpAllTo enables dump and save to the specified io.Writer. +func (c *Client) EnableDumpAllTo(output io.Writer) *Client { c.getDumpOptions().Output = output - c.enableDump() + c.EnableDumpAll() return c } -// EnableDumpAsync is a global wrapper methods which delegated -// to the default client's EnableDumpAsync. -func EnableDumpAsync() *Client { - return defaultClient.EnableDumpAsync() +// EnableDumpAllAsync is a global wrapper methods which delegated +// to the default client's EnableDumpAllAsync. +func EnableDumpAllAsync() *Client { + return defaultClient.EnableDumpAllAsync() } -// EnableDumpAsync enables dump and output asynchronously, +// EnableDumpAllAsync enables dump and output asynchronously, // can be used for debugging in production environment without // affecting performance. -func (c *Client) EnableDumpAsync() *Client { +func (c *Client) EnableDumpAllAsync() *Client { o := c.getDumpOptions() o.Async = true - c.enableDump() + c.EnableDumpAll() return c } -// EnableDumpNoRequestBody is a global wrapper methods which delegated -// to the default client's EnableDumpNoRequestBody. -func EnableDumpNoRequestBody() *Client { - return defaultClient.EnableDumpNoRequestBody() +// EnableDumpAllWithoutRequestBody is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutRequestBody. +func EnableDumpAllWithoutRequestBody() *Client { + return defaultClient.EnableDumpAllWithoutRequestBody() } -// EnableDumpNoRequestBody enables dump with request body excluded, can be +// EnableDumpAllWithoutRequestBody enables dump with request body excluded, can be // used in upload request to avoid dump the unreadable binary content. -func (c *Client) EnableDumpNoRequestBody() *Client { +func (c *Client) EnableDumpAllWithoutRequestBody() *Client { o := c.getDumpOptions() - o.ResponseHeader = true - o.ResponseBody = true o.RequestBody = false - o.RequestHeader = true - c.enableDump() + c.EnableDumpAll() return c } -// EnableDumpNoResponseBody is a global wrapper methods which delegated -// to the default client's EnableDumpNoResponseBody. -func EnableDumpNoResponseBody() *Client { - return defaultClient.EnableDumpNoResponseBody() +// EnableDumpAllWithoutResponseBody is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutResponseBody. +func EnableDumpAllWithoutResponseBody() *Client { + return defaultClient.EnableDumpAllWithoutResponseBody() } -// EnableDumpNoResponseBody enables dump with response body excluded, can be +// EnableDumpAllWithoutResponseBody enables dump with response body excluded, can be // used in download request to avoid dump the unreadable binary content. -func (c *Client) EnableDumpNoResponseBody() *Client { +func (c *Client) EnableDumpAllWithoutResponseBody() *Client { o := c.getDumpOptions() - o.ResponseHeader = true o.ResponseBody = false - o.RequestBody = true - o.RequestHeader = true - c.enableDump() + c.EnableDumpAll() return c } -// EnableDumpOnlyResponse is a global wrapper methods which delegated -// to the default client's EnableDumpOnlyResponse. -func EnableDumpOnlyResponse() *Client { - return defaultClient.EnableDumpOnlyResponse() +// EnableDumpAllWithoutResponse is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutResponse. +func EnableDumpAllWithoutResponse() *Client { + return defaultClient.EnableDumpAllWithoutResponse() } -// EnableDumpOnlyResponse enables dump with only response included. -func (c *Client) EnableDumpOnlyResponse() *Client { +// EnableDumpAllWithoutResponse enables dump with only request included. +func (c *Client) EnableDumpAllWithoutResponse() *Client { o := c.getDumpOptions() - o.ResponseHeader = true - o.ResponseBody = true - o.RequestBody = false - o.RequestHeader = false - c.enableDump() - return c -} - -// EnableDumpOnlyRequest is a global wrapper methods which delegated -// to the default client's EnableDumpOnlyRequest. -func EnableDumpOnlyRequest() *Client { - return defaultClient.EnableDumpOnlyRequest() -} - -// EnableDumpOnlyRequest enables dump with only request included. -func (c *Client) EnableDumpOnlyRequest() *Client { - o := c.getDumpOptions() - o.RequestHeader = true - o.RequestBody = true o.ResponseBody = false o.ResponseHeader = false - c.enableDump() + c.EnableDumpAll() return c } -// EnableDumpOnlyBody is a global wrapper methods which delegated -// to the default client's EnableDumpOnlyBody. -func EnableDumpOnlyBody() *Client { - return defaultClient.EnableDumpOnlyBody() +// EnableDumpAllWithoutHeader is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutHeader. +func EnableDumpAllWithoutHeader() *Client { + return defaultClient.EnableDumpAllWithoutHeader() } -// EnableDumpOnlyBody enables dump with only body included. -func (c *Client) EnableDumpOnlyBody() *Client { +// EnableDumpAllWithoutHeader enables dump with only body included. +func (c *Client) EnableDumpAllWithoutHeader() *Client { o := c.getDumpOptions() - o.RequestBody = true - o.ResponseBody = true o.RequestHeader = false o.ResponseHeader = false - c.enableDump() + c.EnableDumpAll() return c } -// EnableDumpOnlyHeader is a global wrapper methods which delegated -// to the default client's EnableDumpOnlyHeader. -func EnableDumpOnlyHeader() *Client { - return defaultClient.EnableDumpOnlyHeader() +// EnableDumpAllWithoutBody is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutBody. +func EnableDumpAllWithoutBody() *Client { + return defaultClient.EnableDumpAllWithoutBody() } -// EnableDumpOnlyHeader enables dump with only header included. -func (c *Client) EnableDumpOnlyHeader() *Client { +// EnableDumpAllWithoutBody enables dump with only header included. +func (c *Client) EnableDumpAllWithoutBody() *Client { o := c.getDumpOptions() - o.RequestHeader = true - o.ResponseHeader = true o.RequestBody = false o.ResponseBody = false - c.enableDump() - return c -} - -// EnableDumpAll is a global wrapper methods which delegated -// to the default client's EnableDumpAll. -func EnableDumpAll() *Client { - return defaultClient.EnableDumpAll() -} - -// EnableDumpAll enables dump with all content included, -// including both requests and responses' header and body -func (c *Client) EnableDumpAll() *Client { - o := c.getDumpOptions() - o.RequestHeader = true - o.RequestBody = true - o.ResponseHeader = true - o.ResponseBody = true - c.enableDump() + c.EnableDumpAll() return c } @@ -729,13 +740,25 @@ func (c *Client) NewRequest() *Request { // DisableAutoReadResponse is a global wrapper methods which delegated // to the default client's DisableAutoReadResponse. -func DisableAutoReadResponse(disable bool) *Client { - return defaultClient.DisableAutoReadResponse(disable) +func DisableAutoReadResponse() *Client { + return defaultClient.DisableAutoReadResponse() } -// DisableAutoReadResponse disable read response body automatically if set to true. -func (c *Client) DisableAutoReadResponse(disable bool) *Client { - c.disableAutoReadResponse = disable +// DisableAutoReadResponse disable read response body automatically (enabled by default). +func (c *Client) DisableAutoReadResponse() *Client { + c.disableAutoReadResponse = true + return c +} + +// EnableAutoReadResponse is a global wrapper methods which delegated +// to the default client's EnableAutoReadResponse. +func EnableAutoReadResponse() *Client { + return defaultClient.EnableAutoReadResponse() +} + +// EnableAutoReadResponse enable read response body automatically (enabled by default). +func (c *Client) EnableAutoReadResponse() *Client { + c.disableAutoReadResponse = false return c } @@ -784,13 +807,27 @@ func (c *Client) SetAutoDecodeAllType() *Client { // DisableAutoDecode is a global wrapper methods which delegated // to the default client's DisableAutoDecode. -func DisableAutoDecode(disable bool) *Client { - return defaultClient.DisableAutoDecode(disable) +func DisableAutoDecode() *Client { + return defaultClient.DisableAutoDecode() } // DisableAutoDecode disable auto detect charset and decode to utf-8 -func (c *Client) DisableAutoDecode(disable bool) *Client { - c.getResponseOptions().DisableAutoDecode = disable +// (enabled by default) +func (c *Client) DisableAutoDecode() *Client { + c.getResponseOptions().DisableAutoDecode = true + return c +} + +// EnableAutoDecode is a global wrapper methods which delegated +// to the default client's EnableAutoDecode. +func EnableAutoDecode() *Client { + return defaultClient.EnableAutoDecode() +} + +// EnableAutoDecode enables auto detect charset and decode to utf-8 +// (enabled by default) +func (c *Client) EnableAutoDecode() *Client { + c.getResponseOptions().DisableAutoDecode = true return c } @@ -869,32 +906,26 @@ func (c *Client) SetCommonContentType(ct string) *Client { return c } -// EnableDump is a global wrapper methods which delegated -// to the default client's EnableDump. -func EnableDump(enable bool) *Client { - return defaultClient.EnableDump(enable) +// DisableDumpAll is a global wrapper methods which delegated +// to the default client's DisableDumpAll. +func DisableDumpAll() *Client { + return defaultClient.DisableDumpAll() } -// EnableDump enables dump if set to true, will use a default options if -// not been set before, which dumps all the content of requests and -// responses to stdout. -func (c *Client) EnableDump(enable bool) *Client { - if !enable { - c.t.DisableDump() - return c - } - c.enableDump() +// DisableDumpAll disables the dump. +func (c *Client) DisableDumpAll() *Client { + c.t.DisableDump() return c } -// SetDumpOptions is a global wrapper methods which delegated -// to the default client's SetDumpOptions. -func SetDumpOptions(opt *DumpOptions) *Client { - return defaultClient.SetDumpOptions(opt) +// SetCommonDumpOptions is a global wrapper methods which delegated +// to the default client's SetCommonDumpOptions. +func SetCommonDumpOptions(opt *DumpOptions) *Client { + return defaultClient.SetCommonDumpOptions(opt) } -// SetDumpOptions configures the underlying Transport's DumpOptions -func (c *Client) SetDumpOptions(opt *DumpOptions) *Client { +// SetCommonDumpOptions configures the underlying Transport's DumpOptions +func (c *Client) SetCommonDumpOptions(opt *DumpOptions) *Client { if opt == nil { return c } @@ -958,15 +989,27 @@ func (c *Client) SetProxyURL(proxyUrl string) *Client { return c } +// DisableTraceAll is a global wrapper methods which delegated +// to the default client's DisableTraceAll. +func DisableTraceAll() *Client { + return defaultClient.DisableTraceAll() +} + +// DisableTraceAll disables the trace at client level. +func (c *Client) DisableTraceAll() *Client { + c.trace = false + return c +} + // EnableTraceAll is a global wrapper methods which delegated // to the default client's EnableTraceAll. -func EnableTraceAll(enable bool) *Client { - return defaultClient.EnableTraceAll(enable) +func EnableTraceAll() *Client { + return defaultClient.EnableTraceAll() } -// EnableTraceAll enables the trace at client level if set to true. -func (c *Client) EnableTraceAll(enable bool) *Client { - c.trace = enable +// EnableTraceAll enables the trace at client level. +func (c *Client) EnableTraceAll() *Client { + c.trace = true return c } @@ -1030,15 +1073,27 @@ func (c *Client) SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Cli return c } +// DisableAllowGetMethodPayload is a global wrapper methods which delegated +// to the default client's DisableAllowGetMethodPayload. +func DisableAllowGetMethodPayload() *Client { + return defaultClient.DisableAllowGetMethodPayload() +} + +// DisableAllowGetMethodPayload disables sending GET method requests with body. +func (c *Client) DisableAllowGetMethodPayload() *Client { + c.AllowGetMethodPayload = false + return c +} + // EnableAllowGetMethodPayload is a global wrapper methods which delegated // to the default client's EnableAllowGetMethodPayload. -func EnableAllowGetMethodPayload(a bool) *Client { - return defaultClient.EnableAllowGetMethodPayload(a) +func EnableAllowGetMethodPayload() *Client { + return defaultClient.EnableAllowGetMethodPayload() } -// EnableAllowGetMethodPayload allows sending GET method requests with body if set to true. -func (c *Client) EnableAllowGetMethodPayload(a bool) *Client { - c.AllowGetMethodPayload = a +// EnableAllowGetMethodPayload allows sending GET method requests with body. +func (c *Client) EnableAllowGetMethodPayload() *Client { + c.AllowGetMethodPayload = true return c } diff --git a/req_test.go b/req_test.go index a927a0d8..135115f8 100644 --- a/req_test.go +++ b/req_test.go @@ -10,7 +10,7 @@ import ( ) func tc() *Client { - return C().EnableDebugLog(true) + return C().EnableDebugLog() } func getTestDataPath() string { diff --git a/request.go b/request.go index 9f87f1fb..84b2c7ea 100644 --- a/request.go +++ b/request.go @@ -29,6 +29,7 @@ type Request struct { RawRequest *http.Request StartTime time.Time + dumpOptions *DumpOptions marshalBody interface{} ctx context.Context isMultiPart bool @@ -759,20 +760,28 @@ func (r *Request) SetContext(ctx context.Context) *Request { return r } +// DisableTrace is a global wrapper methods which delegated +// to the default client, create a request and DisableTrace for request. +func DisableTrace() *Request { + return defaultClient.R().DisableTrace() +} + +// DisableTrace disables trace. +func (r *Request) DisableTrace() *Request { + r.trace = nil + return r +} + // EnableTrace is a global wrapper methods which delegated // to the default client, create a request and EnableTrace for request. -func EnableTrace(enable bool) *Request { - return defaultClient.R().EnableTrace(enable) +func EnableTrace() *Request { + return defaultClient.R().EnableTrace() } -// EnableTrace enables trace if set to true. -func (r *Request) EnableTrace(enable bool) *Request { - if enable { - if r.trace == nil { - r.trace = &clientTrace{} - } - } else { - r.trace = nil +// EnableTrace enables trace. +func (r *Request) EnableTrace() *Request { + if r.trace == nil { + r.trace = &clientTrace{} } return r } @@ -784,75 +793,124 @@ func (r *Request) getDumpBuffer() *bytes.Buffer { return r.dumpBuffer } -// Dump is a global wrapper methods which delegated -// to the default client, create a request and Dump for request. -func Dump(reqHeader, reqBody, respHeader, respBody bool) *Request { - return defaultClient.R().Dump(reqHeader, reqBody, respHeader, respBody) +func (r *Request) getDumpOptions() *DumpOptions { + if r.dumpOptions == nil { + r.dumpOptions = &DumpOptions{ + RequestHeader: true, + RequestBody: true, + ResponseHeader: true, + ResponseBody: true, + Output: r.getDumpBuffer(), + } + } + return r.dumpOptions +} + +// EnableDumpTo is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpTo for request. +func EnableDumpTo(output io.Writer) *Request { + return defaultClient.R().EnableDumpTo(output) +} + +// EnableDumpTo enables dump and save to the specified io.Writer. +func (r *Request) EnableDumpTo(output io.Writer) *Request { + r.getDumpOptions().Output = output + return r.EnableDump() } -// Dump enables the dump for the request and response. -func (r *Request) Dump(reqHeader, reqBody, respHeader, respBody bool) *Request { - opt := &DumpOptions{ - RequestHeader: reqHeader, - RequestBody: reqBody, - ResponseHeader: respHeader, - ResponseBody: respBody, - Output: r.getDumpBuffer(), +// EnableDumpToFile is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpToFile for request. +func EnableDumpToFile(filename string) *Request { + return defaultClient.R().EnableDumpToFile(filename) +} + +// EnableDumpToFile enables dump and save to the specified filename. +func (r *Request) EnableDumpToFile(filename string) *Request { + file, err := os.Create(filename) + if err != nil { + r.appendError(err) + return r + } + r.getDumpOptions().Output = file + return r.EnableDump() +} + +func SetDumpOptions(opt *DumpOptions) *Request { + return defaultClient.R().SetDumpOptions(opt) +} + +// SetDumpOptions sets DumpOptions at request level. +func (r *Request) SetDumpOptions(opt *DumpOptions) *Request { + if opt == nil { + return r } - return r.SetContext(context.WithValue(r.Context(), "dumper", newDumper(opt))) + r.dumpOptions = opt + return r } -// DumpAll is a global wrapper methods which delegated -// to the default client, create a request and DumpAll for request. -func DumpAll() *Request { - return defaultClient.R().DumpAll() +// EnableDump is a global wrapper methods which delegated +// to the default client, create a request and EnableDump for request. +func EnableDump() *Request { + return defaultClient.R().EnableDump() } -// DumpAll enables dump all content for the request and response. -func (r *Request) DumpAll() *Request { - return r.Dump(true, true, true, true) +// EnableDump enables dump, including all content for the request and response by default. +func (r *Request) EnableDump() *Request { + return r.SetContext(context.WithValue(r.Context(), "dumper", newDumper(r.getDumpOptions()))) } -// DumpOnlyHeader is a global wrapper methods which delegated -// to the default client, create a request and DumpOnlyHeader for request. -func DumpOnlyHeader() *Request { - return defaultClient.R().DumpOnlyHeader() +// EnableDumpWithoutBody is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutBody for request. +func EnableDumpWithoutBody() *Request { + return defaultClient.R().EnableDumpWithoutBody() } -// DumpOnlyHeader enables dump only header for the request and response. -func (r *Request) DumpOnlyHeader() *Request { - return r.Dump(true, false, true, false) +// EnableDumpWithoutBody enables dump only header for the request and response. +func (r *Request) EnableDumpWithoutBody() *Request { + o := r.getDumpOptions() + o.RequestBody = false + o.ResponseBody = false + return r.EnableDump() } -// DumpOnlyBody is a global wrapper methods which delegated -// to the default client, create a request and DumpOnlyBody for request. -func DumpOnlyBody() *Request { - return defaultClient.R().DumpOnlyBody() +// EnableDumpWithoutHeader is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutHeader for request. +func EnableDumpWithoutHeader() *Request { + return defaultClient.R().EnableDumpWithoutHeader() } -// DumpOnlyBody enables dump only body for the request and response. -func (r *Request) DumpOnlyBody() *Request { - return r.Dump(false, true, false, true) +// EnableDumpWithoutHeader enables dump only body for the request and response. +func (r *Request) EnableDumpWithoutHeader() *Request { + o := r.getDumpOptions() + o.RequestHeader = false + o.ResponseHeader = false + return r.EnableDump() } -// DumpOnlyRequest is a global wrapper methods which delegated -// to the default client, create a request and DumpOnlyRequest for request. -func DumpOnlyRequest() *Request { - return defaultClient.R().DumpOnlyRequest() +// EnableDumpWithoutResponse is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutResponse for request. +func EnableDumpWithoutResponse() *Request { + return defaultClient.R().EnableDumpWithoutResponse() } -// DumpOnlyRequest enables dump only request. -func (r *Request) DumpOnlyRequest() *Request { - return r.Dump(true, true, false, false) +// EnableDumpWithoutResponse enables dump only request. +func (r *Request) EnableDumpWithoutResponse() *Request { + o := r.getDumpOptions() + o.ResponseHeader = false + o.ResponseBody = false + return r.EnableDump() } -// DumpOnlyResponse is a global wrapper methods which delegated -// to the default client, create a request and DumpOnlyResponse for request. -func DumpOnlyResponse() *Request { - return defaultClient.R().DumpOnlyResponse() +// EnableDumpWithoutRequest is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutRequest for request. +func EnableDumpWithoutRequest() *Request { + return defaultClient.R().EnableDumpWithoutRequest() } -// DumpOnlyResponse enables dump only response. -func (r *Request) DumpOnlyResponse() *Request { - return r.Dump(false, false, true, true) +// EnableDumpWithoutRequest enables dump only response. +func (r *Request) EnableDumpWithoutRequest() *Request { + o := r.getDumpOptions() + o.RequestHeader = false + o.RequestBody = false + return r.EnableDump() } From 942395798fd21639a600d538280b95d030d4540c Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 00:56:18 +0800 Subject: [PATCH 201/843] rename import path to v3 --- client.go | 2 +- decode.go | 2 +- examples/find-popular-repo/go.mod | 2 +- examples/find-popular-repo/main.go | 4 ++-- examples/upload/uploadclient/main.go | 2 +- go.mod | 2 +- http.go | 2 +- http_request.go | 2 +- http_response.go | 2 +- internal/chunked.go | 2 +- middleware.go | 2 +- request.go | 2 +- textproto_reader.go | 2 +- transfer.go | 4 ++-- transport.go | 6 +++--- transport_test.go | 2 +- 16 files changed, 20 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 0ed64f01..b313f858 100644 --- a/client.go +++ b/client.go @@ -6,7 +6,7 @@ import ( "encoding/json" "encoding/xml" "errors" - "github.com/imroc/req/v2/internal/util" + "github.com/imroc/req/v3/internal/util" "golang.org/x/net/publicsuffix" "io" "io/ioutil" diff --git a/decode.go b/decode.go index 5205f83d..7d4d273b 100644 --- a/decode.go +++ b/decode.go @@ -1,7 +1,7 @@ package req import ( - "github.com/imroc/req/v2/internal/charsetutil" + "github.com/imroc/req/v3/internal/charsetutil" "io" "strings" ) diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index 8863c5d7..7b3e9d26 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -2,4 +2,4 @@ module find-popular-repo go 1.13 -require github.com/imroc/req/v2 v2.1.0 +require github.com/imroc/req/v3 v3.0.0 diff --git a/examples/find-popular-repo/main.go b/examples/find-popular-repo/main.go index 5cd98ee4..b5194582 100644 --- a/examples/find-popular-repo/main.go +++ b/examples/find-popular-repo/main.go @@ -5,7 +5,7 @@ import ( "fmt" "strconv" - "github.com/imroc/req/v2" + "github.com/imroc/req/v3" ) // Change the name if you want @@ -21,7 +21,7 @@ func main() { } func init() { - req.EnableDumpOnlyHeader().EnableDebugLog(true).EnableTraceAll(true) + req.EnableDumpWithoutBody().EnableDebugLog().EnableTraceAll() } type Repo struct { diff --git a/examples/upload/uploadclient/main.go b/examples/upload/uploadclient/main.go index 69690cc5..09916de6 100644 --- a/examples/upload/uploadclient/main.go +++ b/examples/upload/uploadclient/main.go @@ -1,6 +1,6 @@ package main -import "github.com/imroc/req/v2" +import "github.com/imroc/req/v3" func main() { req.EnableDumpNoRequestBody() diff --git a/go.mod b/go.mod index 8c7b6bc0..354cd3f8 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/imroc/req/v2 +module github.com/imroc/req/v3 go 1.13 diff --git a/http.go b/http.go index 361baae7..ff99699f 100644 --- a/http.go +++ b/http.go @@ -3,7 +3,7 @@ package req import ( "encoding/base64" "fmt" - "github.com/imroc/req/v2/internal/ascii" + "github.com/imroc/req/v3/internal/ascii" "golang.org/x/net/http/httpguts" "golang.org/x/net/idna" "io" diff --git a/http_request.go b/http_request.go index a7597abf..f9c1c170 100644 --- a/http_request.go +++ b/http_request.go @@ -4,7 +4,7 @@ import ( "bufio" "errors" "fmt" - "github.com/imroc/req/v2/internal/ascii" + "github.com/imroc/req/v3/internal/ascii" "golang.org/x/net/http/httpguts" "io" "net/http" diff --git a/http_response.go b/http_response.go index b0eda117..c72eb4fa 100644 --- a/http_response.go +++ b/http_response.go @@ -7,7 +7,7 @@ package req import ( - "github.com/imroc/req/v2/internal/util" + "github.com/imroc/req/v3/internal/util" "io" "net/http" "strconv" diff --git a/internal/chunked.go b/internal/chunked.go index 64bcb4d3..bd414ebb 100644 --- a/internal/chunked.go +++ b/internal/chunked.go @@ -13,7 +13,7 @@ import ( "bytes" "errors" "fmt" - "github.com/imroc/req/v2/internal/util" + "github.com/imroc/req/v3/internal/util" "io" ) diff --git a/middleware.go b/middleware.go index 8c8eb5e7..35cfe57e 100644 --- a/middleware.go +++ b/middleware.go @@ -3,7 +3,7 @@ package req import ( "bytes" "fmt" - "github.com/imroc/req/v2/internal/util" + "github.com/imroc/req/v3/internal/util" "io" "io/ioutil" "mime/multipart" diff --git a/request.go b/request.go index 84b2c7ea..27609c83 100644 --- a/request.go +++ b/request.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "github.com/hashicorp/go-multierror" - "github.com/imroc/req/v2/internal/util" + "github.com/imroc/req/v3/internal/util" "io" "io/ioutil" "net/http" diff --git a/textproto_reader.go b/textproto_reader.go index 293830d0..520159cb 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -8,7 +8,7 @@ import ( "bufio" "bytes" "fmt" - "github.com/imroc/req/v2/internal/util" + "github.com/imroc/req/v3/internal/util" "io" "io/ioutil" "net/textproto" diff --git a/transfer.go b/transfer.go index cf35a702..c0ff2512 100644 --- a/transfer.go +++ b/transfer.go @@ -9,8 +9,8 @@ import ( "bytes" "errors" "fmt" - "github.com/imroc/req/v2/internal" - "github.com/imroc/req/v2/internal/ascii" + "github.com/imroc/req/v3/internal" + "github.com/imroc/req/v3/internal/ascii" "io" "io/ioutil" "net/http" diff --git a/transport.go b/transport.go index 5b781a40..1bb65663 100644 --- a/transport.go +++ b/transport.go @@ -17,9 +17,9 @@ import ( "crypto/tls" "errors" "fmt" - "github.com/imroc/req/v2/internal/ascii" - "github.com/imroc/req/v2/internal/godebug" - "github.com/imroc/req/v2/internal/util" + "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/godebug" + "github.com/imroc/req/v3/internal/util" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" "io" diff --git a/transport_test.go b/transport_test.go index ef7e9f45..dc708a80 100644 --- a/transport_test.go +++ b/transport_test.go @@ -11,7 +11,7 @@ import ( "context" "crypto/tls" "errors" - "github.com/imroc/req/v2/internal/testcert" + "github.com/imroc/req/v3/internal/testcert" "io" "net" "net/http" From 6d5b56d5b9842276427193c92c7876029a8a352d Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 01:01:05 +0800 Subject: [PATCH 202/843] update examples --- examples/find-popular-repo/go.mod | 4 ++++ examples/find-popular-repo/main.go | 2 +- examples/upload/uploadclient/go.mod | 6 +++++- examples/upload/uploadclient/main.go | 6 ++++-- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index 7b3e9d26..ccd9a56d 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -2,4 +2,8 @@ module find-popular-repo go 1.13 +replace ( + github.com/imroc/req/v3 => ../../ +) + require github.com/imroc/req/v3 v3.0.0 diff --git a/examples/find-popular-repo/main.go b/examples/find-popular-repo/main.go index b5194582..6c17e86f 100644 --- a/examples/find-popular-repo/main.go +++ b/examples/find-popular-repo/main.go @@ -21,7 +21,7 @@ func main() { } func init() { - req.EnableDumpWithoutBody().EnableDebugLog().EnableTraceAll() + req.EnableDumpAllWithoutBody().EnableDebugLog().EnableTraceAll() } type Repo struct { diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod index d0bfc686..9d257447 100644 --- a/examples/upload/uploadclient/go.mod +++ b/examples/upload/uploadclient/go.mod @@ -2,4 +2,8 @@ module uploadclient go 1.13 -require github.com/imroc/req/v2 v2.1.0 +replace ( + github.com/imroc/req/v3 => ../../../ +) + +require github.com/imroc/req/v3 v3.0.0 diff --git a/examples/upload/uploadclient/main.go b/examples/upload/uploadclient/main.go index 09916de6..d9c8fd34 100644 --- a/examples/upload/uploadclient/main.go +++ b/examples/upload/uploadclient/main.go @@ -1,9 +1,11 @@ package main -import "github.com/imroc/req/v3" +import ( + "github.com/imroc/req/v3" +) func main() { - req.EnableDumpNoRequestBody() + req.EnableDumpAllWithoutRequestBody() req.SetFile("files", "../../../README.md"). SetFile("files", "../../../LICENSE"). SetFormData(map[string]string{ From de279c8e1d43ea40badb0fd9818df708007fab72 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 01:05:00 +0800 Subject: [PATCH 203/843] go mod tidy for examples --- examples/find-popular-repo/go.mod | 4 +--- examples/find-popular-repo/go.sum | 2 -- examples/upload/uploadclient/go.mod | 4 +--- examples/upload/uploadclient/go.sum | 2 -- 4 files changed, 2 insertions(+), 10 deletions(-) diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index ccd9a56d..19199729 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -2,8 +2,6 @@ module find-popular-repo go 1.13 -replace ( - github.com/imroc/req/v3 => ../../ -) +replace github.com/imroc/req/v3 => ../../ require github.com/imroc/req/v3 v3.0.0 diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index 21bcb4e1..59f8d4e3 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -2,8 +2,6 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.1.0 h1:zs14o2Pv/3RwAF11HBmjzJJ5ZItOgLk9yABTypbl8nk= -github.com/imroc/req/v2 v2.1.0/go.mod h1:3POMCRC7mUbCcscEp9wpihSyZLUVYWqvmHnwTdL6kJY= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod index 9d257447..c95a1ec4 100644 --- a/examples/upload/uploadclient/go.mod +++ b/examples/upload/uploadclient/go.mod @@ -2,8 +2,6 @@ module uploadclient go 1.13 -replace ( - github.com/imroc/req/v3 => ../../../ -) +replace github.com/imroc/req/v3 => ../../../ require github.com/imroc/req/v3 v3.0.0 diff --git a/examples/upload/uploadclient/go.sum b/examples/upload/uploadclient/go.sum index 21bcb4e1..59f8d4e3 100644 --- a/examples/upload/uploadclient/go.sum +++ b/examples/upload/uploadclient/go.sum @@ -2,8 +2,6 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/imroc/req/v2 v2.1.0 h1:zs14o2Pv/3RwAF11HBmjzJJ5ZItOgLk9yABTypbl8nk= -github.com/imroc/req/v2 v2.1.0/go.mod h1:3POMCRC7mUbCcscEp9wpihSyZLUVYWqvmHnwTdL6kJY= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= From 9784e32795eac84186c754ea187ac136f55419ab Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 01:06:16 +0800 Subject: [PATCH 204/843] update default UserAgent to v3 --- req.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/req.go b/req.go index 6bcb147a..095ff67a 100644 --- a/req.go +++ b/req.go @@ -2,7 +2,7 @@ package req const ( hdrUserAgentKey = "User-Agent" - hdrUserAgentValue = "req/v2 (https://github.com/imroc/req)" + hdrUserAgentValue = "req/v3 (https://github.com/imroc/req)" hdrContentTypeKey = "Content-Type" plainTextType = "text/plain; charset=utf-8" jsonContentType = "application/json; charset=utf-8" From 64e749e99caa82f4f5de57c2d1fd8d5614d81405 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 01:16:40 +0800 Subject: [PATCH 205/843] update README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f339cb49..3f4a4611 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ client.R().Get("https://www.baidu.com/") opt.ResponseBody = false client.R().Get("https://www.baidu.com/") -// You can also enable dump at request level, dump to memory and do not print out +// You can also enable dump at request level, dump to memory and do not print it out // by default, you can call `Response.Dump()` to get the dump result and print // only if you want to. resp, err := client.R().EnableDump().SetBody("test body").Post("https://httpbin.org/post") @@ -155,10 +155,10 @@ if resp.StatusCode > 299 { fmt.Println("dump:", resp.Dump()) } -// And also support customize dump settings with predefined convenience settings like client level. +// Similarly, also support to customize dump settings with predefined convenience settings at request level. resp, err = client.R().EnableDumpWithoutRequest().SetBody("test body").Post("https://httpbin.org/post") // ... -resp, err = client.R().SetDumpOptions(opt).SetBody("test body").Post("https://httpbin.org/post") +resp, err = client.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post("https://httpbin.org/post") ``` **Enable DebugLog for Deeper Insights** From 81bcb15fc0fc94b682d4f4da246da5ef41173410 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 01:20:43 +0800 Subject: [PATCH 206/843] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3f4a4611..3f651c14 100644 --- a/README.md +++ b/README.md @@ -147,12 +147,12 @@ client.R().Get("https://www.baidu.com/") resp, err := client.R().EnableDump().SetBody("test body").Post("https://httpbin.org/post") if err != nil { fmt.Println("err:", err) - fmt.Println("dump:", resp.Dump()) + fmt.Println("raw content:\n", resp.Dump()) return } if resp.StatusCode > 299 { fmt.Println("bad status:", resp.Status) - fmt.Println("dump:", resp.Dump()) + fmt.Println("raw content:\n", resp.Dump()) } // Similarly, also support to customize dump settings with predefined convenience settings at request level. From 14ed8e6a49b1d9f91a8324a5e44f80b7a0259de9 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 01:29:02 +0800 Subject: [PATCH 207/843] update README --- README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 3f651c14..287fdb05 100644 --- a/README.md +++ b/README.md @@ -84,8 +84,9 @@ Checkout more runnable examples in the [examples](examples) direcotry. **Dump the Content** ```go -// Enable dump for all requests, including all content to stdout by default, -// including both the header and body of all request and response +// Enable dump at client level, which will dump for all requests, +// including all content of request and response and output +// to stdout by default. client := req.C().EnableDumpAll() client.R().Get("https://httpbin.org/get") @@ -118,14 +119,14 @@ access-control-allow-credentials: true } */ -// Customize dump settings with predefined convenience settings. +// Customize client level dump settings with predefined convenience settings. client.EnableDumpAllWithoutBody(). // Only dump the header of request and response EnableDumpAllAsync(). // Dump asynchronously to improve performance EnableDumpAllToFile("reqdump.log") // Dump to file without printing it out // Send request to see the content that have been dumpped client.R().Get(url) -// Enable dump with fully customized settings +// Enable dump with fully customized settings at client level. opt := &req.DumpOptions{ Output: os.Stdout, RequestHeader: true, @@ -141,7 +142,7 @@ client.R().Get("https://www.baidu.com/") opt.ResponseBody = false client.R().Get("https://www.baidu.com/") -// You can also enable dump at request level, dump to memory and do not print it out +// You can also enable dump at request level, dump to memory and will not print it out // by default, you can call `Response.Dump()` to get the dump result and print // only if you want to. resp, err := client.R().EnableDump().SetBody("test body").Post("https://httpbin.org/post") @@ -153,6 +154,7 @@ if err != nil { if resp.StatusCode > 299 { fmt.Println("bad status:", resp.Status) fmt.Println("raw content:\n", resp.Dump()) + return } // Similarly, also support to customize dump settings with predefined convenience settings at request level. From a4bd140b1e9fe0d6a96c4c3900cc9eefee5cb37f Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 08:29:06 +0800 Subject: [PATCH 208/843] update README: GoDoc to v3 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 287fdb05..59ca7c95 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@

Req

Simplified Golang HTTP client library with Black Magic, Less Code and More Efficiency.

-

+

## Big News From 6dddaa7478268e8a7fb77bf3f8d7c251355993df Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 08:32:24 +0800 Subject: [PATCH 209/843] update README: v2 explanation --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 59ca7c95..24274e6f 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Brand new v3 version is out, which is completely rewritten, bringing revolutiona If you want to use the older version, check it out on [v1 branch](https://github.com/imroc/req/tree/v1). -> v2 is a transitional version. the latest version is v3 cuz some API-incompatible changes was involved during v2 refactoring, checkout [v2 branch](https://github.com/imroc/req/tree/v2) if you want. +> v2 is a transitional version, cuz some breaking changes were introduced during v2 refactoring, checkout [v2 branch](https://github.com/imroc/req/tree/v2) if you want. ## Table of Contents From 51c4c688215f0f54a56b754b7628968c04db1e7b Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 08:52:44 +0800 Subject: [PATCH 210/843] update README: add multiple query parameter --- README.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 24274e6f..eb7b6c2f 100644 --- a/README.md +++ b/README.md @@ -293,6 +293,7 @@ Use `SetQueryParam`, `SetQueryParams` or `SetQueryString` to append url query pa ```go client := req.C().DevMode() +// Set query parameter at request level. client.R(). SetQueryParam("a", "a"). // Set a query param, which will be encoded as query parameter in url SetQueryParams(map[string]string{ // Set multiple query params at once @@ -305,7 +306,7 @@ client.R(). ... */ -// You can also set the common QueryParam for every request on client +// You can also set the query parameter at client level. client.SetCommonQueryParam(k, v). SetCommonQueryParams(queryParams). SetCommonQueryString(queryString). @@ -315,10 +316,15 @@ resp1, err := client.Get(url1) resp2, err := client.Get(url2) ... -// And you can add query parameter with multiple values +// Add query parameter with multiple values at request level. client.R().AddQueryParam("key", "value1").AddQueryParam("key", "value2").Get("https://httpbin.org/get") +/* Output +2022/02/05 08:49:26.260780 DEBUG [req] GET https://httpbin.org/get?key=value1&key=value2 +... + */ + -// Same as client level settings +// Multiple values also supported at client level. client.AddCommonQueryParam("key", "value1").AddCommonQueryParam("key", "value2") ``` From 1d2de4ab74e9ef3355f2808f40e1944f164c2d4f Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 09:05:26 +0800 Subject: [PATCH 211/843] update comments of client --- client.go | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index b313f858..4bfa9743 100644 --- a/client.go +++ b/client.go @@ -611,7 +611,8 @@ func EnableDumpAllToFile(filename string) *Client { return defaultClient.EnableDumpAllToFile(filename) } -// EnableDumpAllToFile enables dump and save to the specified filename. +// EnableDumpAllToFile enables dump for all requests and save +// to the specified filename. func (c *Client) EnableDumpAllToFile(filename string) *Client { file, err := os.Create(filename) if err != nil { @@ -629,7 +630,8 @@ func EnableDumpAllTo(output io.Writer) *Client { return defaultClient.EnableDumpAllTo(output) } -// EnableDumpAllTo enables dump and save to the specified io.Writer. +// EnableDumpAllTo enables dump for all requests and save +// to the specified io.Writer. func (c *Client) EnableDumpAllTo(output io.Writer) *Client { c.getDumpOptions().Output = output c.EnableDumpAll() @@ -642,9 +644,9 @@ func EnableDumpAllAsync() *Client { return defaultClient.EnableDumpAllAsync() } -// EnableDumpAllAsync enables dump and output asynchronously, -// can be used for debugging in production environment without -// affecting performance. +// EnableDumpAllAsync enables dump for all requests and output +// asynchronously, can be used for debugging in production +// environment without affecting performance. func (c *Client) EnableDumpAllAsync() *Client { o := c.getDumpOptions() o.Async = true @@ -658,8 +660,9 @@ func EnableDumpAllWithoutRequestBody() *Client { return defaultClient.EnableDumpAllWithoutRequestBody() } -// EnableDumpAllWithoutRequestBody enables dump with request body excluded, can be -// used in upload request to avoid dump the unreadable binary content. +// EnableDumpAllWithoutRequestBody enables dump for all requests, with +// request body excluded, can be used in upload request to avoid dump +// the unreadable binary content. func (c *Client) EnableDumpAllWithoutRequestBody() *Client { o := c.getDumpOptions() o.RequestBody = false @@ -673,8 +676,9 @@ func EnableDumpAllWithoutResponseBody() *Client { return defaultClient.EnableDumpAllWithoutResponseBody() } -// EnableDumpAllWithoutResponseBody enables dump with response body excluded, can be -// used in download request to avoid dump the unreadable binary content. +// EnableDumpAllWithoutResponseBody enables dump for all requests, with +// response body excluded, can be used in download request to avoid dump +// the unreadable binary content. func (c *Client) EnableDumpAllWithoutResponseBody() *Client { o := c.getDumpOptions() o.ResponseBody = false @@ -688,7 +692,8 @@ func EnableDumpAllWithoutResponse() *Client { return defaultClient.EnableDumpAllWithoutResponse() } -// EnableDumpAllWithoutResponse enables dump with only request included. +// EnableDumpAllWithoutResponse enables dump for all requests with only +// request header and body included. func (c *Client) EnableDumpAllWithoutResponse() *Client { o := c.getDumpOptions() o.ResponseBody = false @@ -703,7 +708,8 @@ func EnableDumpAllWithoutHeader() *Client { return defaultClient.EnableDumpAllWithoutHeader() } -// EnableDumpAllWithoutHeader enables dump with only body included. +// EnableDumpAllWithoutHeader enables dump for all requests with only +// body of request and response included. func (c *Client) EnableDumpAllWithoutHeader() *Client { o := c.getDumpOptions() o.RequestHeader = false @@ -718,7 +724,8 @@ func EnableDumpAllWithoutBody() *Client { return defaultClient.EnableDumpAllWithoutBody() } -// EnableDumpAllWithoutBody enables dump with only header included. +// EnableDumpAllWithoutBody enables dump for all requests with only header +// of request and response included. func (c *Client) EnableDumpAllWithoutBody() *Client { o := c.getDumpOptions() o.RequestBody = false From ebb4e3526ecc05ffc79de368306b8682e932e2d2 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 11:09:05 +0800 Subject: [PATCH 212/843] use getDumperOverride in http/1.1 --- http_response.go | 6 +----- transport.go | 15 ++++----------- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/http_response.go b/http_response.go index c72eb4fa..c8980680 100644 --- a/http_response.go +++ b/http_response.go @@ -27,11 +27,7 @@ var respExcludeHeader = map[string]bool{ // After that call, clients can inspect resp.Trailer to find key/value // pairs included in the response trailer. func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) { - //var tp headReader - dump := pc.t.dump - if d, ok := req.Context().Value("dumper").(*dumper); ok { - dump = d - } + dump := getDumperOverride(pc.t.dump, req.Context()) tp := newTextprotoReader(pc.br, dump) resp := &http.Response{ Request: req, diff --git a/transport.go b/transport.go index 1bb65663..176e4479 100644 --- a/transport.go +++ b/transport.go @@ -293,14 +293,10 @@ func (t *Transport) handleResponseBody(res *http.Response, req *http.Request) { } func (t *Transport) dumpResponseBody(res *http.Response, req *http.Request) { - dump := t.dump - if d, ok := req.Context().Value("dumper").(*dumper); ok { - dump = d + dump := getDumperOverride(t.dump, req.Context()) + if dump != nil && dump.ResponseBody { + res.Body = dump.WrapReadCloser(res.Body) } - if dump == nil || !dump.ResponseBody { - return - } - res.Body = dump.WrapReadCloser(res.Body) } func (t *Transport) autoDecodeResponseBody(res *http.Response) { @@ -2455,10 +2451,7 @@ func (pc *persistConn) writeLoop() { case wr := <-pc.writech: startBytesWritten := pc.nwrite ctx := wr.req.Request.Context() - dump := pc.t.dump - if d, ok := ctx.Value("dumper").(*dumper); ok { - dump = d - } + dump := getDumperOverride(pc.t.dump, ctx) err := requestWrite(wr.req.Request, pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh), dump) if bre, ok := err.(requestBodyReadError); ok { err = bre.error From 8f11b79d7ddbd6dce0a4a85bffd5473dce71ff0e Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 13:04:59 +0800 Subject: [PATCH 213/843] improve test infra --- req_test.go | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/req_test.go b/req_test.go index 135115f8..fb771039 100644 --- a/req_test.go +++ b/req_test.go @@ -22,29 +22,38 @@ func createTestServer(fn func(w http.ResponseWriter, r *http.Request)) *httptest return httptest.NewServer(http.HandlerFunc(fn)) } +func createPostServer(t *testing.T) *httptest.Server { + ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + w.Write([]byte("TestPost: text response")) + } + }) + return ts +} + func createGetServer(t *testing.T) *httptest.Server { ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodGet { switch r.URL.Path { case "/": - _, _ = w.Write([]byte("TestGet: text response")) + w.Write([]byte("TestGet: text response")) case "/no-content": - _, _ = w.Write([]byte("")) + w.Write([]byte("")) case "/json": w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"TestGet": "JSON response"}`)) + w.Write([]byte(`{"TestGet": "JSON response"}`)) case "/json-invalid": w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte("TestGet: Invalid JSON")) + w.Write([]byte("TestGet: Invalid JSON")) case "/long-text": - _, _ = w.Write([]byte("TestGet: text response with size > 30")) + w.Write([]byte("TestGet: text response with size > 30")) case "/long-json": w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"TestGet": "JSON response with size > 30"}`)) + w.Write([]byte(`{"TestGet": "JSON response with size > 30"}`)) case "/bad-request": w.WriteHeader(http.StatusBadRequest) case "/host-header": - _, _ = w.Write([]byte(r.Host)) + w.Write([]byte(r.Host)) } } }) From 0bea1b5debc81f5481e0b6942845a88a7be1e68f Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 13:05:41 +0800 Subject: [PATCH 214/843] default output to buffer in Request.SetDumpOptions --- request.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/request.go b/request.go index 27609c83..c55343d0 100644 --- a/request.go +++ b/request.go @@ -844,6 +844,9 @@ func (r *Request) SetDumpOptions(opt *DumpOptions) *Request { if opt == nil { return r } + if opt.Output == nil { + opt.Output = r.getDumpBuffer() + } r.dumpOptions = opt return r } From ed97e94db660a590e572c6fb90333905e4c25517 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 13:13:51 +0800 Subject: [PATCH 215/843] case insensitive charset name in auto-decode Content-Type --- transport.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport.go b/transport.go index 176e4479..373af7b6 100644 --- a/transport.go +++ b/transport.go @@ -321,7 +321,7 @@ func (t *Transport) autoDecodeResponseBody(res *http.Response) { panic(err) } if charset, ok := params["charset"]; ok { - // TODO: log charset + charset = strings.ToLower(charset) if strings.Contains(charset, "utf-8") || strings.Contains(charset, "utf8") { // do not decode utf-8 return } From 945f57e7e9515170b275d9049cdff627587b47bf Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Feb 2022 16:09:51 +0800 Subject: [PATCH 216/843] set tls client config to t2 --- client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client.go b/client.go index 4bfa9743..8bdb6517 100644 --- a/client.go +++ b/client.go @@ -405,6 +405,7 @@ func SetTLSClientConfig(conf *tls.Config) *Client { // SetTLSClientConfig sets the client tls config. func (c *Client) SetTLSClientConfig(conf *tls.Config) *Client { c.t.TLSClientConfig = conf + c.t2.TLSClientConfig = conf return c } From e8eca916c769f722ff9b1833c3d954d1834de509 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 6 Feb 2022 13:10:54 +0800 Subject: [PATCH 217/843] move uploadFile to req.go --- file.go | 9 --------- req.go | 8 ++++++++ 2 files changed, 8 insertions(+), 9 deletions(-) delete mode 100644 file.go diff --git a/file.go b/file.go deleted file mode 100644 index 52026d4d..00000000 --- a/file.go +++ /dev/null @@ -1,9 +0,0 @@ -package req - -import "io" - -type uploadFile struct { - ParamName string - FilePath string - io.Reader -} diff --git a/req.go b/req.go index 095ff67a..5760308d 100644 --- a/req.go +++ b/req.go @@ -1,5 +1,7 @@ package req +import "io" + const ( hdrUserAgentKey = "User-Agent" hdrUserAgentValue = "req/v3 (https://github.com/imroc/req)" @@ -9,3 +11,9 @@ const ( xmlContentType = "text/xml; charset=utf-8" formContentType = "application/x-www-form-urlencoded" ) + +type uploadFile struct { + ParamName string + FilePath string + io.Reader +} From 7e9b31dccabd0d694ee67acc7b71d94f5318baee Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 6 Feb 2022 13:13:29 +0800 Subject: [PATCH 218/843] move transport dump methods to transport.go --- dump.go | 13 ------------- transport.go | 13 +++++++++++++ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/dump.go b/dump.go index 192a48ee..3caddc5e 100644 --- a/dump.go +++ b/dump.go @@ -158,19 +158,6 @@ func (d *dumper) Start() { } } -func (t *Transport) EnableDump(opt *DumpOptions) { - dump := newDumper(opt) - t.dump = dump - go dump.Start() -} - -func (t *Transport) DisableDump() { - if t.dump != nil { - t.dump.Stop() - t.dump = nil - } -} - func getDumperOverride(dump *dumper, ctx context.Context) *dumper { if d, ok := ctx.Value("dumper").(*dumper); ok { return d diff --git a/transport.go b/transport.go index 373af7b6..747182c2 100644 --- a/transport.go +++ b/transport.go @@ -411,6 +411,19 @@ func (t *Transport) Clone() *Transport { return t2 } +func (t *Transport) EnableDump(opt *DumpOptions) { + dump := newDumper(opt) + t.dump = dump + go dump.Start() +} + +func (t *Transport) DisableDump() { + if t.dump != nil { + t.dump.Stop() + t.dump = nil + } +} + // h2Transport is the interface we expect to be able to call from // net/http against an *http2.Transport that's either bundled into // h2_bundle.go or supplied by the user via x/net/http2. From d77e7148961ecf25e43b20167a95e0f95e0b9513 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 6 Feb 2022 13:15:07 +0800 Subject: [PATCH 219/843] add comments to transport's dump methods --- transport.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transport.go b/transport.go index 747182c2..021ee158 100644 --- a/transport.go +++ b/transport.go @@ -411,12 +411,14 @@ func (t *Transport) Clone() *Transport { return t2 } +// EnableDump enables the dump for all requests with specified dump options. func (t *Transport) EnableDump(opt *DumpOptions) { dump := newDumper(opt) t.dump = dump go dump.Start() } +// DisableDump disables the dump. func (t *Transport) DisableDump() { if t.dump != nil { t.dump.Stop() From c58cea8b4af269f7a45caa12a01bfe9e45a1fea3 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 6 Feb 2022 13:53:04 +0800 Subject: [PATCH 220/843] add tests for request level dump --- req_test.go | 22 ++++++++++++++++-- request_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/req_test.go b/req_test.go index fb771039..ea620511 100644 --- a/req_test.go +++ b/req_test.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "reflect" + "strings" "testing" ) @@ -96,20 +97,37 @@ func assertType(t *testing.T, typ, v interface{}) { } } +func assertNotContains(t *testing.T, s, substr string) { + if strings.Contains(s, substr) { + t.Errorf("%q is included in %s", substr, s) + } +} + +func assertContains(t *testing.T, s, substr string) { + if !strings.Contains(s, substr) { + t.Errorf("%q is not included in %s", substr, s) + } +} + func assertError(t *testing.T, err error) { if err != nil { t.Errorf("Error occurred [%v]", err) } } -func assertEqual(t *testing.T, e, g interface{}) (r bool) { +func assertEqual(t *testing.T, e, g interface{}) { if !equal(e, g) { t.Errorf("Expected [%v], got [%v]", e, g) } - return } +func removeEmptyString(s string) string { + s = strings.ReplaceAll(s, "\r", "") + s = strings.ReplaceAll(s, "\n", "") + return s +} + func assertNotEqual(t *testing.T, e, g interface{}) (r bool) { if equal(e, g) { t.Errorf("Expected [%v], got [%v]", e, g) diff --git a/request_test.go b/request_test.go index 47e6c3c4..166410fa 100644 --- a/request_test.go +++ b/request_test.go @@ -5,6 +5,66 @@ import ( "testing" ) +func TestRequestDump(t *testing.T) { + ts := createPostServer(t) + defer ts.Close() + + c := tc() + resp, err := c.R().EnableDump().SetBody(`test body`).Post(ts.URL) + assertResponse(t, resp, err) + dump := resp.Dump() + assertContains(t, dump, "POST / HTTP/1.1") + assertContains(t, dump, "test body") + assertContains(t, dump, "HTTP/1.1 200 OK") + assertContains(t, dump, "TestPost: text response") + + resp, err = c.R().EnableDumpWithoutRequest().SetBody(`test body`).Post(ts.URL) + assertResponse(t, resp, err) + dump = resp.Dump() + assertNotContains(t, dump, "POST / HTTP/1.1") + assertNotContains(t, dump, "test body") + assertContains(t, dump, "HTTP/1.1 200 OK") + assertContains(t, dump, "TestPost: text response") + + resp, err = c.R().EnableDumpWithoutResponse().SetBody(`test body`).Post(ts.URL) + assertResponse(t, resp, err) + dump = resp.Dump() + assertContains(t, dump, "POST / HTTP/1.1") + assertContains(t, dump, "test body") + assertNotContains(t, dump, "HTTP/1.1 200 OK") + assertNotContains(t, dump, "TestPost: text response") + + resp, err = c.R().EnableDumpWithoutHeader().SetBody(`test body`).Post(ts.URL) + assertResponse(t, resp, err) + dump = resp.Dump() + assertNotContains(t, dump, "POST / HTTP/1.1") + assertContains(t, dump, "test body") + assertNotContains(t, dump, "HTTP/1.1 200 OK") + assertContains(t, dump, "TestPost: text response") + + resp, err = c.R().EnableDumpWithoutBody().SetBody(`test body`).Post(ts.URL) + assertResponse(t, resp, err) + dump = resp.Dump() + assertContains(t, dump, "POST / HTTP/1.1") + assertNotContains(t, dump, "test body") + assertContains(t, dump, "HTTP/1.1 200 OK") + assertNotContains(t, dump, "TestPost: text response") + + opt := &DumpOptions{ + RequestHeader: true, + RequestBody: false, + ResponseHeader: false, + ResponseBody: true, + } + resp, err = c.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(ts.URL) + assertResponse(t, resp, err) + dump = resp.Dump() + assertContains(t, dump, "POST / HTTP/1.1") + assertNotContains(t, dump, "test body") + assertNotContains(t, dump, "HTTP/1.1 200 OK") + assertContains(t, dump, "TestPost: text response") +} + func TestGet(t *testing.T) { ts := createGetServer(t) defer ts.Close() From e91b5a62fd004242db4115cdab91dac9ea99aad0 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 6 Feb 2022 13:58:50 +0800 Subject: [PATCH 221/843] update README: split debug and test --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index eb7b6c2f..3973b397 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,8 @@ If you want to use the older version, check it out on [v1 branch](https://github * [Features](#Features) * [Quick Start](#Quick-Start) -* [Debugging](#Debugging) +* [Debugging - Log/Trace/Dump](#Debugging) +* [Quick HTTP Test](#Test) * [URL Path and Query Parameter](#Param) * [Form Data](#Form) * [Header and Cookie](#Header-Cookie) @@ -79,7 +80,7 @@ resp, err := client.R(). // Use R() to create a request Checkout more runnable examples in the [examples](examples) direcotry. -## Debugging +## Debugging - Log/Trace/Dump **Dump the Content** @@ -226,6 +227,8 @@ client := req.C().DevMode() client.R().Get("https://imroc.cc") ``` +## Quick HTTP Test + **Test with Global Wrapper Methods** `req` wrap methods of both `Client` and `Request` with global methods, which is delegated to default client, it's very convenient when making API test. From 3c3cc0f68141f85595c0115209a4e2edbd0b9e81 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 6 Feb 2022 14:11:34 +0800 Subject: [PATCH 222/843] update README --- README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 3973b397..93e74a3d 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,8 @@ If you want to use the older version, check it out on [v1 branch](https://github ## Features * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. -* Powerful and convenient debug utilites, including debug logs, performance traces, dump complete request and response content, and even provide global wrapper methods to test with minimal code (see [Debugging](#Debugging). +* Powerful and convenient debug utilites, including debug logs, performance traces, dump complete request and response content, and even provide global wrapper methods to test with minimal code (see [Debugging - Log/Trace/Dump](#Debugging). +* Easy making HTTP test with code instead of tools like curl or postman, `req` provide global wrapper methods and `MustXXX` to test API with minimal code (see [Quick HTTP Test](#Test)). * Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decode](#AutoDecode)). * Automatic marshal and unmarshal for JSON and XML content type and fully customizable. * Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support. @@ -239,6 +240,7 @@ client.R().Get("https://imroc.cc") // you don't need to create any client explicitly. req.SetTimeout(5 * time.Second). SetCommonBasicAuth("imroc", "123456"). + SetCommonHeader("Accept", "application/json"). SetUserAgent("my api client"). DevMode() @@ -247,16 +249,16 @@ req.SetTimeout(5 * time.Second). // client, so you can treat package name `req` as a Request, // and you don't need to create request explicitly. req.SetQueryParam("page", "2"). - SetHeader("Accept", "application/json"). + SetHeader("Accept", "text/xml"). // Override client level settings at request level. Get("https://api.example.com/repos") ``` **Test with MustXXX** -Use `MustXXX` to ignore error handling when test, make it possible to complete a complex test with just one line of code: +Use `MustXXX` to ignore error handling during test, make it possible to complete a complex test with just one line of code: ```go -fmt.Println(req.DevMode().MustGet("https://imroc.cc").TraceInfo()) +fmt.Println(req.DevMode().R().MustGet("https://imroc.cc").TraceInfo()) ``` ## URL Path and Query Parameter From 97b6be51c18a4ad2a6ee952b3247f5d4f83fb465 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 6 Feb 2022 14:14:22 +0800 Subject: [PATCH 223/843] update README: ajust Debugging title --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 93e74a3d..ee1cd2ff 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ If you want to use the older version, check it out on [v1 branch](https://github * [Features](#Features) * [Quick Start](#Quick-Start) -* [Debugging - Log/Trace/Dump](#Debugging) +* [Debugging - Dump/Log/Trace](#Debugging) * [Quick HTTP Test](#Test) * [URL Path and Query Parameter](#Param) * [Form Data](#Form) @@ -81,7 +81,7 @@ resp, err := client.R(). // Use R() to create a request Checkout more runnable examples in the [examples](examples) direcotry. -## Debugging - Log/Trace/Dump +## Debugging - Dump/Log/Trace **Dump the Content** From 747e815dfdab87b9a9d2d02b6ad8464f150d60f1 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 6 Feb 2022 14:29:02 +0800 Subject: [PATCH 224/843] add tests for client level dump --- client_test.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 client_test.go diff --git a/client_test.go b/client_test.go new file mode 100644 index 00000000..b959f3ab --- /dev/null +++ b/client_test.go @@ -0,0 +1,71 @@ +package req + +import ( + "bytes" + "testing" +) + +func TestClientDump(t *testing.T) { + ts := createPostServer(t) + defer ts.Close() + + c := tc() + buf := new(bytes.Buffer) + c.EnableDumpAllTo(buf) + resp, err := c.R().SetBody("test body").Post(ts.URL) + assertResponse(t, resp, err) + assertContains(t, buf.String(), "POST / HTTP/1.1") + assertContains(t, buf.String(), "test body") + assertContains(t, buf.String(), "HTTP/1.1 200 OK") + assertContains(t, buf.String(), "TestPost: text response") + + c = tc() + buf = new(bytes.Buffer) + c.EnableDumpAllWithoutHeader().EnableDumpAllTo(buf) + resp, err = c.R().SetBody("test body").Post(ts.URL) + assertResponse(t, resp, err) + assertNotContains(t, buf.String(), "POST / HTTP/1.1") + assertContains(t, buf.String(), "test body") + assertNotContains(t, buf.String(), "HTTP/1.1 200 OK") + assertContains(t, buf.String(), "TestPost: text response") + + c = tc() + buf = new(bytes.Buffer) + c.EnableDumpAllWithoutBody().EnableDumpAllTo(buf) + resp, err = c.R().SetBody("test body").Post(ts.URL) + assertResponse(t, resp, err) + assertContains(t, buf.String(), "POST / HTTP/1.1") + assertNotContains(t, buf.String(), "test body") + assertContains(t, buf.String(), "HTTP/1.1 200 OK") + assertNotContains(t, buf.String(), "TestPost: text response") + + c = tc() + buf = new(bytes.Buffer) + c.EnableDumpAllWithoutRequestBody().EnableDumpAllTo(buf) + resp, err = c.R().SetBody("test body").Post(ts.URL) + assertResponse(t, resp, err) + assertContains(t, buf.String(), "POST / HTTP/1.1") + assertNotContains(t, buf.String(), "test body") + assertContains(t, buf.String(), "HTTP/1.1 200 OK") + assertContains(t, buf.String(), "TestPost: text response") + + c = tc() + buf = new(bytes.Buffer) + c.EnableDumpAllWithoutResponse().EnableDumpAllTo(buf) + resp, err = c.R().SetBody("test body").Post(ts.URL) + assertResponse(t, resp, err) + assertContains(t, buf.String(), "POST / HTTP/1.1") + assertContains(t, buf.String(), "test body") + assertNotContains(t, buf.String(), "HTTP/1.1 200 OK") + assertNotContains(t, buf.String(), "TestPost: text response") + + c = tc() + buf = new(bytes.Buffer) + c.EnableDumpAllWithoutResponseBody().EnableDumpAllTo(buf) + resp, err = c.R().SetBody("test body").Post(ts.URL) + assertResponse(t, resp, err) + assertContains(t, buf.String(), "POST / HTTP/1.1") + assertContains(t, buf.String(), "test body") + assertContains(t, buf.String(), "HTTP/1.1 200 OK") + assertNotContains(t, buf.String(), "TestPost: text response") +} From 2457ccf7899037033cd46d244b639ecda84f6ebd Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 6 Feb 2022 14:45:25 +0800 Subject: [PATCH 225/843] add missing dump methods for client and requests --- client.go | 16 ++++++++++++++++ client_test.go | 27 +++++++++++++++++++++++++++ request.go | 28 ++++++++++++++++++++++++++++ request_test.go | 16 ++++++++++++++++ 4 files changed, 87 insertions(+) diff --git a/client.go b/client.go index 8bdb6517..c538c27a 100644 --- a/client.go +++ b/client.go @@ -703,6 +703,22 @@ func (c *Client) EnableDumpAllWithoutResponse() *Client { return c } +// EnableDumpAllWithoutRequest is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutRequest. +func EnableDumpAllWithoutRequest() *Client { + return defaultClient.EnableDumpAllWithoutRequest() +} + +// EnableDumpAllWithoutRequest enables dump for all requests with only +// request header and body included. +func (c *Client) EnableDumpAllWithoutRequest() *Client { + o := c.getDumpOptions() + o.RequestHeader = false + o.RequestBody = false + c.EnableDumpAll() + return c +} + // EnableDumpAllWithoutHeader is a global wrapper methods which delegated // to the default client's EnableDumpAllWithoutHeader. func EnableDumpAllWithoutHeader() *Client { diff --git a/client_test.go b/client_test.go index b959f3ab..02a35adf 100644 --- a/client_test.go +++ b/client_test.go @@ -39,6 +39,16 @@ func TestClientDump(t *testing.T) { assertContains(t, buf.String(), "HTTP/1.1 200 OK") assertNotContains(t, buf.String(), "TestPost: text response") + c = tc() + buf = new(bytes.Buffer) + c.EnableDumpAllWithoutRequest().EnableDumpAllTo(buf) + resp, err = c.R().SetBody("test body").Post(ts.URL) + assertResponse(t, resp, err) + assertNotContains(t, buf.String(), "POST / HTTP/1.1") + assertNotContains(t, buf.String(), "test body") + assertContains(t, buf.String(), "HTTP/1.1 200 OK") + assertContains(t, buf.String(), "TestPost: text response") + c = tc() buf = new(bytes.Buffer) c.EnableDumpAllWithoutRequestBody().EnableDumpAllTo(buf) @@ -68,4 +78,21 @@ func TestClientDump(t *testing.T) { assertContains(t, buf.String(), "test body") assertContains(t, buf.String(), "HTTP/1.1 200 OK") assertNotContains(t, buf.String(), "TestPost: text response") + + c = tc() + buf = new(bytes.Buffer) + opt := &DumpOptions{ + RequestHeader: true, + RequestBody: false, + ResponseHeader: false, + ResponseBody: true, + Output: buf, + } + c.SetCommonDumpOptions(opt).EnableDumpAll() + resp, err = c.R().SetBody("test body").Post(ts.URL) + assertResponse(t, resp, err) + assertContains(t, buf.String(), "POST / HTTP/1.1") + assertNotContains(t, buf.String(), "test body") + assertNotContains(t, buf.String(), "HTTP/1.1 200 OK") + assertContains(t, buf.String(), "TestPost: text response") } diff --git a/request.go b/request.go index c55343d0..f9719ab8 100644 --- a/request.go +++ b/request.go @@ -917,3 +917,31 @@ func (r *Request) EnableDumpWithoutRequest() *Request { o.RequestBody = false return r.EnableDump() } + +// EnableDumpWithoutRequestBody is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutRequestBody for request. +func EnableDumpWithoutRequestBody() *Request { + return defaultClient.R().EnableDumpWithoutRequestBody() +} + +// EnableDumpWithoutRequestBody enables dump with request body excluded, +// can be used in upload request to avoid dump the unreadable binary content. +func (r *Request) EnableDumpWithoutRequestBody() *Request { + o := r.getDumpOptions() + o.RequestBody = false + return r.EnableDump() +} + +// EnableDumpWithoutResponseBody is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutResponseBody for request. +func EnableDumpWithoutResponseBody() *Request { + return defaultClient.R().EnableDumpWithoutResponseBody() +} + +// EnableDumpWithoutResponseBody enables dump with response body excluded, +// can be used in download request to avoid dump the unreadable binary content. +func (r *Request) EnableDumpWithoutResponseBody() *Request { + o := r.getDumpOptions() + o.ResponseBody = false + return r.EnableDump() +} diff --git a/request_test.go b/request_test.go index 166410fa..6ed499a9 100644 --- a/request_test.go +++ b/request_test.go @@ -26,6 +26,14 @@ func TestRequestDump(t *testing.T) { assertContains(t, dump, "HTTP/1.1 200 OK") assertContains(t, dump, "TestPost: text response") + resp, err = c.R().EnableDumpWithoutRequestBody().SetBody(`test body`).Post(ts.URL) + assertResponse(t, resp, err) + dump = resp.Dump() + assertContains(t, dump, "POST / HTTP/1.1") + assertNotContains(t, dump, "test body") + assertContains(t, dump, "HTTP/1.1 200 OK") + assertContains(t, dump, "TestPost: text response") + resp, err = c.R().EnableDumpWithoutResponse().SetBody(`test body`).Post(ts.URL) assertResponse(t, resp, err) dump = resp.Dump() @@ -34,6 +42,14 @@ func TestRequestDump(t *testing.T) { assertNotContains(t, dump, "HTTP/1.1 200 OK") assertNotContains(t, dump, "TestPost: text response") + resp, err = c.R().EnableDumpWithoutResponseBody().SetBody(`test body`).Post(ts.URL) + assertResponse(t, resp, err) + dump = resp.Dump() + assertContains(t, dump, "POST / HTTP/1.1") + assertContains(t, dump, "test body") + assertContains(t, dump, "HTTP/1.1 200 OK") + assertNotContains(t, dump, "TestPost: text response") + resp, err = c.R().EnableDumpWithoutHeader().SetBody(`test body`).Post(ts.URL) assertResponse(t, resp, err) dump = resp.Dump() From 9bd5bbc78cf28fdc6412bfd480b1d1d1842c3824 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 6 Feb 2022 20:53:30 +0800 Subject: [PATCH 226/843] extract tlsclient interface --- client.go | 7 +++++++ h2_bundle.go | 15 ++++++++------- pkg/tlsclient/conn.go | 12 ++++++++++++ transport.go | 9 +++++---- transport_test.go | 7 ++++--- 5 files changed, 36 insertions(+), 14 deletions(-) create mode 100644 pkg/tlsclient/conn.go diff --git a/client.go b/client.go index c538c27a..4f21efd5 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package req import ( + "context" "crypto/tls" "crypto/x509" "encoding/json" @@ -10,6 +11,7 @@ import ( "golang.org/x/net/publicsuffix" "io" "io/ioutil" + "net" "net/http" "net/http/cookiejar" urlpkg "net/url" @@ -1097,6 +1099,11 @@ func (c *Client) SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Cli return c } +func (c *Client) SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { + c.t.DialTLSContext = fn + return c +} + // DisableAllowGetMethodPayload is a global wrapper methods which delegated // to the default client's DisableAllowGetMethodPayload. func DisableAllowGetMethodPayload() *Client { diff --git a/h2_bundle.go b/h2_bundle.go index 404adac4..543e2d26 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -28,6 +28,7 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/imroc/req/v3/pkg/tlsclient" "io" "io/ioutil" "log" @@ -243,7 +244,7 @@ func (c *http2dialCall) dial(ctx context.Context, addr string) { // This code decides which ones live or die. // The return value used is whether c was used. // c is never closed. -func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c *tls.Conn) (used bool, err error) { +func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c net.Conn) (used bool, err error) { p.mu.Lock() for _, cc := range p.conns[key] { if cc.CanTakeNewRequest() { @@ -279,7 +280,7 @@ type http2addConnCall struct { err error } -func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) { +func (c *http2addConnCall) run(t *http2Transport, key string, tc net.Conn) { cc, err := t.NewClientConn(tc) p := c.p @@ -2377,7 +2378,7 @@ func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textpr // dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS // connection. -func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) { +func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (tlsclient.Conn, error) { dialer := &tls.Dialer{ Config: cfg, } @@ -2385,7 +2386,7 @@ func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr s if err != nil { return nil, err } - tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed + tlsCn := cn.(tlsclient.Conn) // DialContext comment promises this will always succeed return tlsCn, nil } @@ -3346,7 +3347,7 @@ func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") } - upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper { + upgradeFn := func(authority string, c tlsclient.Conn) http.RoundTripper { addr := http2authorityAddr("https", authority) if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { go c.Close() @@ -3361,7 +3362,7 @@ func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { return t2 } if m := t1.TLSNextProto; len(m) == 0 { - t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{ + t1.TLSNextProto = map[string]func(string, tlsclient.Conn) http.RoundTripper{ "h2": upgradeFn, } } else { @@ -3388,7 +3389,7 @@ func (t *http2Transport) initConnPool() { type http2ClientConn struct { currentRequest *http.Request t *http2Transport - tconn net.Conn // usually *tls.Conn, except specialized impls + tconn net.Conn // usually tlsclient.Conn, except specialized impls tlsState *tls.ConnectionState // nil only for specialized impls reused uint32 // whether conn is being reused; atomic singleUse bool // whether being used for a single http.Request diff --git a/pkg/tlsclient/conn.go b/pkg/tlsclient/conn.go new file mode 100644 index 00000000..514f10a9 --- /dev/null +++ b/pkg/tlsclient/conn.go @@ -0,0 +1,12 @@ +package tlsclient + +import ( + "crypto/tls" + "net" +) + +type Conn interface { + net.Conn + ConnectionState() tls.ConnectionState + Handshake() error +} diff --git a/transport.go b/transport.go index 021ee158..d8471fde 100644 --- a/transport.go +++ b/transport.go @@ -20,6 +20,7 @@ import ( "github.com/imroc/req/v3/internal/ascii" "github.com/imroc/req/v3/internal/godebug" "github.com/imroc/req/v3/internal/util" + "github.com/imroc/req/v3/pkg/tlsclient" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" "io" @@ -235,7 +236,7 @@ type Transport struct { // must return a http.RoundTripper that then handles the request. // If TLSNextProto is not nil, HTTP/2 support is not enabled // automatically. - TLSNextProto map[string]func(authority string, c *tls.Conn) http.RoundTripper + TLSNextProto map[string]func(authority string, c tlsclient.Conn) http.RoundTripper // ProxyConnectHeader optionally specifies headers to send to // proxies during CONNECT requests. @@ -402,7 +403,7 @@ func (t *Transport) Clone() *Transport { t2.TLSClientConfig = t.TLSClientConfig.Clone() } if !t.tlsNextProtoWasNil { - npm := map[string]func(authority string, c *tls.Conn) http.RoundTripper{} + npm := map[string]func(authority string, c tlsclient.Conn) http.RoundTripper{} for k, v := range t.TLSNextProto { npm[k] = v } @@ -1659,7 +1660,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if err != nil { return nil, wrapErr(err) } - if tc, ok := pconn.conn.(*tls.Conn); ok { + if tc, ok := pconn.conn.(tlsclient.Conn); ok { // Handshake here, in case DialTLS didn't. TLSNextProto below // depends on it for knowing the connection state. if trace != nil && trace.TLSHandshakeStart != nil { @@ -1810,7 +1811,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { - alt := next(cm.targetAddr, pconn.conn.(*tls.Conn)) + alt := next(cm.targetAddr, pconn.conn.(tlsclient.Conn)) if e, ok := alt.(erringRoundTripper); ok { // pconn.conn was closed by next (http2configureTransports.upgradeFn). return nil, e.RoundTripErr() diff --git a/transport_test.go b/transport_test.go index dc708a80..12c8d822 100644 --- a/transport_test.go +++ b/transport_test.go @@ -12,6 +12,7 @@ import ( "crypto/tls" "errors" "github.com/imroc/req/v3/internal/testcert" + "github.com/imroc/req/v3/pkg/tlsclient" "io" "net" "net/http" @@ -215,7 +216,7 @@ func TestTransportBodyAltRewind(t *testing.T) { t.Error(err) return } - if err := sc.(*tls.Conn).Handshake(); err != nil { + if err := sc.(tlsclient.Conn).Handshake(); err != nil { t.Error(err) return } @@ -228,8 +229,8 @@ func TestTransportBodyAltRewind(t *testing.T) { roundTripped := false tr := &Transport{ DisableKeepAlives: true, - TLSNextProto: map[string]func(string, *tls.Conn) http.RoundTripper{ - "foo": func(authority string, c *tls.Conn) http.RoundTripper { + TLSNextProto: map[string]func(string, tlsclient.Conn) http.RoundTripper{ + "foo": func(authority string, c tlsclient.Conn) http.RoundTripper { return roundTripFunc(func(r *http.Request) (*http.Response, error) { n, _ := io.Copy(io.Discard, r.Body) if n == 0 { From db82f38afff3e8cc2f17f202641c7ec8f01ecfef Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Feb 2022 07:05:23 +0800 Subject: [PATCH 227/843] move tlsclient.Conn to TLSConn --- client.go | 5 +++++ h2_bundle.go | 11 +++++------ pkg/tlsclient/conn.go | 12 ------------ tls.go | 17 +++++++++++++++++ transport.go | 9 ++++----- transport_test.go | 7 +++---- 6 files changed, 34 insertions(+), 27 deletions(-) delete mode 100644 pkg/tlsclient/conn.go create mode 100644 tls.go diff --git a/client.go b/client.go index 4f21efd5..c9127e13 100644 --- a/client.go +++ b/client.go @@ -1104,6 +1104,11 @@ func (c *Client) SetDialTLS(fn func(ctx context.Context, network, addr string) ( return c } +func (c *Client) SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { + c.t.DialContext = fn + return c +} + // DisableAllowGetMethodPayload is a global wrapper methods which delegated // to the default client's DisableAllowGetMethodPayload. func DisableAllowGetMethodPayload() *Client { diff --git a/h2_bundle.go b/h2_bundle.go index 543e2d26..b726d157 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -28,7 +28,6 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/imroc/req/v3/pkg/tlsclient" "io" "io/ioutil" "log" @@ -2378,7 +2377,7 @@ func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textpr // dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS // connection. -func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (tlsclient.Conn, error) { +func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (TLSConn, error) { dialer := &tls.Dialer{ Config: cfg, } @@ -2386,7 +2385,7 @@ func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr s if err != nil { return nil, err } - tlsCn := cn.(tlsclient.Conn) // DialContext comment promises this will always succeed + tlsCn := cn.(TLSConn) // DialContext comment promises this will always succeed return tlsCn, nil } @@ -3347,7 +3346,7 @@ func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") } - upgradeFn := func(authority string, c tlsclient.Conn) http.RoundTripper { + upgradeFn := func(authority string, c TLSConn) http.RoundTripper { addr := http2authorityAddr("https", authority) if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { go c.Close() @@ -3362,7 +3361,7 @@ func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { return t2 } if m := t1.TLSNextProto; len(m) == 0 { - t1.TLSNextProto = map[string]func(string, tlsclient.Conn) http.RoundTripper{ + t1.TLSNextProto = map[string]func(string, TLSConn) http.RoundTripper{ "h2": upgradeFn, } } else { @@ -3389,7 +3388,7 @@ func (t *http2Transport) initConnPool() { type http2ClientConn struct { currentRequest *http.Request t *http2Transport - tconn net.Conn // usually tlsclient.Conn, except specialized impls + tconn net.Conn // usually TLSConn, except specialized impls tlsState *tls.ConnectionState // nil only for specialized impls reused uint32 // whether conn is being reused; atomic singleUse bool // whether being used for a single http.Request diff --git a/pkg/tlsclient/conn.go b/pkg/tlsclient/conn.go deleted file mode 100644 index 514f10a9..00000000 --- a/pkg/tlsclient/conn.go +++ /dev/null @@ -1,12 +0,0 @@ -package tlsclient - -import ( - "crypto/tls" - "net" -) - -type Conn interface { - net.Conn - ConnectionState() tls.ConnectionState - Handshake() error -} diff --git a/tls.go b/tls.go new file mode 100644 index 00000000..121425d5 --- /dev/null +++ b/tls.go @@ -0,0 +1,17 @@ +package req + +import ( + "crypto/tls" + "net" +) + +// TLSConn is the recommended interface for the connection +// returned by the DailTLS function (Client.SetDialTLS, +// Transport.DialTLSContext), so that the TLS handshake negotiation +// can automatically decide whether to use HTTP2 or HTTP1 (ALPN). +// If this interface is not implemented, HTTP1 will be used by default. +type TLSConn interface { + net.Conn + ConnectionState() tls.ConnectionState + Handshake() error +} diff --git a/transport.go b/transport.go index d8471fde..e9dc83f7 100644 --- a/transport.go +++ b/transport.go @@ -20,7 +20,6 @@ import ( "github.com/imroc/req/v3/internal/ascii" "github.com/imroc/req/v3/internal/godebug" "github.com/imroc/req/v3/internal/util" - "github.com/imroc/req/v3/pkg/tlsclient" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" "io" @@ -236,7 +235,7 @@ type Transport struct { // must return a http.RoundTripper that then handles the request. // If TLSNextProto is not nil, HTTP/2 support is not enabled // automatically. - TLSNextProto map[string]func(authority string, c tlsclient.Conn) http.RoundTripper + TLSNextProto map[string]func(authority string, c TLSConn) http.RoundTripper // ProxyConnectHeader optionally specifies headers to send to // proxies during CONNECT requests. @@ -403,7 +402,7 @@ func (t *Transport) Clone() *Transport { t2.TLSClientConfig = t.TLSClientConfig.Clone() } if !t.tlsNextProtoWasNil { - npm := map[string]func(authority string, c tlsclient.Conn) http.RoundTripper{} + npm := map[string]func(authority string, c TLSConn) http.RoundTripper{} for k, v := range t.TLSNextProto { npm[k] = v } @@ -1660,7 +1659,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if err != nil { return nil, wrapErr(err) } - if tc, ok := pconn.conn.(tlsclient.Conn); ok { + if tc, ok := pconn.conn.(TLSConn); ok { // Handshake here, in case DialTLS didn't. TLSNextProto below // depends on it for knowing the connection state. if trace != nil && trace.TLSHandshakeStart != nil { @@ -1811,7 +1810,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { - alt := next(cm.targetAddr, pconn.conn.(tlsclient.Conn)) + alt := next(cm.targetAddr, pconn.conn.(TLSConn)) if e, ok := alt.(erringRoundTripper); ok { // pconn.conn was closed by next (http2configureTransports.upgradeFn). return nil, e.RoundTripErr() diff --git a/transport_test.go b/transport_test.go index 12c8d822..563af04e 100644 --- a/transport_test.go +++ b/transport_test.go @@ -12,7 +12,6 @@ import ( "crypto/tls" "errors" "github.com/imroc/req/v3/internal/testcert" - "github.com/imroc/req/v3/pkg/tlsclient" "io" "net" "net/http" @@ -216,7 +215,7 @@ func TestTransportBodyAltRewind(t *testing.T) { t.Error(err) return } - if err := sc.(tlsclient.Conn).Handshake(); err != nil { + if err := sc.(TLSConn).Handshake(); err != nil { t.Error(err) return } @@ -229,8 +228,8 @@ func TestTransportBodyAltRewind(t *testing.T) { roundTripped := false tr := &Transport{ DisableKeepAlives: true, - TLSNextProto: map[string]func(string, tlsclient.Conn) http.RoundTripper{ - "foo": func(authority string, c tlsclient.Conn) http.RoundTripper { + TLSNextProto: map[string]func(string, TLSConn) http.RoundTripper{ + "foo": func(authority string, c TLSConn) http.RoundTripper { return roundTripFunc(func(r *http.Request) (*http.Response, error) { n, _ := io.Copy(io.Discard, r.Body) if n == 0 { From b094ed5c074538698af8604a43d30381f8f1d852 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Feb 2022 07:11:52 +0800 Subject: [PATCH 228/843] remove deperecated Dial and DialTLS in Transport --- transport.go | 40 ++++------------------------------------ transport_test.go | 2 +- 2 files changed, 5 insertions(+), 37 deletions(-) diff --git a/transport.go b/transport.go index e9dc83f7..6f7dc318 100644 --- a/transport.go +++ b/transport.go @@ -129,18 +129,6 @@ type Transport struct { // becomes idle before the later DialContext completes. DialContext func(ctx context.Context, network, addr string) (net.Conn, error) - // Dial specifies the dial function for creating unencrypted TCP connections. - // - // Dial runs concurrently with calls to RoundTrip. - // A RoundTrip call that initiates a dial may end up using - // a connection dialed previously when the earlier connection - // becomes idle before the later Dial completes. - // - // Deprecated: Use DialContext instead, which allows the transport - // to cancel dials as soon as they are no longer needed. - // If both are set, DialContext takes priority. - Dial func(network, addr string) (net.Conn, error) - // DialTLSContext specifies an optional dial function for creating // TLS connections for non-proxied HTTPS requests. // @@ -153,14 +141,6 @@ type Transport struct { // past the TLS handshake. DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) - // DialTLS specifies an optional dial function for creating - // TLS connections for non-proxied HTTPS requests. - // - // Deprecated: Use DialTLSContext instead, which allows the transport - // to cancel dials as soon as they are no longer needed. - // If both are set, DialTLSContext takes priority. - DialTLS func(network, addr string) (net.Conn, error) - // TLSClientConfig specifies the TLS configuration to use with // tls.Client. // If nil, the default configuration is used. @@ -374,8 +354,6 @@ func (t *Transport) Clone() *Transport { t2 := &Transport{ Proxy: t.Proxy, DialContext: t.DialContext, - Dial: t.Dial, - DialTLS: t.DialTLS, DialTLSContext: t.DialTLSContext, TLSHandshakeTimeout: t.TLSHandshakeTimeout, DisableKeepAlives: t.DisableKeepAlives, @@ -437,7 +415,7 @@ type h2Transport interface { } func (t *Transport) hasCustomTLSDialer() bool { - return t.DialTLS != nil || t.DialTLSContext != nil + return t.DialTLSContext != nil } // onceSetNextProtoDefaults initializes TLSNextProto. @@ -468,7 +446,7 @@ func (t *Transport) onceSetNextProtoDefaults() { // Transport. return } - if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()) { + if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.DialContext != nil || t.hasCustomTLSDialer()) { // Be conservative and don't automatically enable // http2 if they've specified a custom TLS config or // custom dialers. Let them opt-in themselves via @@ -1243,13 +1221,6 @@ func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, e if t.DialContext != nil { return t.DialContext(ctx, network, addr) } - if t.Dial != nil { - c, err := t.Dial(network, addr) - if c == nil && err == nil { - err = errors.New("net/http: Transport.Dial hook returned (nil, nil)") - } - return c, err - } return zeroDialer.DialContext(ctx, network, addr) } @@ -1388,11 +1359,8 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) { } func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) { - if t.DialTLSContext != nil { - conn, err = t.DialTLSContext(ctx, network, addr) - } else { - conn, err = t.DialTLS(network, addr) - } + conn, err = t.DialTLSContext(ctx, network, addr) + if conn == nil && err == nil { err = errors.New("net/http: Transport.DialTLS or DialTLSContext returned (nil, nil)") } diff --git a/transport_test.go b/transport_test.go index 563af04e..f0b9edd6 100644 --- a/transport_test.go +++ b/transport_test.go @@ -246,7 +246,7 @@ func TestTransportBodyAltRewind(t *testing.T) { }) }, }, - DialTLS: func(_, _ string) (net.Conn, error) { + DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) { tc, err := tls.Dial("tcp", addr, &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{"foo"}, From c29d6e0c3348c08b779d865c597309b96b881924 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Feb 2022 07:13:39 +0800 Subject: [PATCH 229/843] improve comments --- transport.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transport.go b/transport.go index 6f7dc318..244b109a 100644 --- a/transport.go +++ b/transport.go @@ -120,8 +120,7 @@ type Transport struct { Proxy func(*http.Request) (*url.URL, error) // DialContext specifies the dial function for creating unencrypted TCP connections. - // If DialContext is nil (and the deprecated Dial below is also nil), - // then the transport dials using package net. + // If DialContext is nil, then the transport dials using package net. // // DialContext runs concurrently with calls to RoundTrip. // A RoundTrip call that initiates a dial may end up using @@ -132,8 +131,7 @@ type Transport struct { // DialTLSContext specifies an optional dial function for creating // TLS connections for non-proxied HTTPS requests. // - // If DialTLSContext is nil (and the deprecated DialTLS below is also nil), - // DialContext and TLSClientConfig are used. + // If DialTLSContext is nil, DialContext and TLSClientConfig are used. // // If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS // requests and the TLSClientConfig and TLSHandshakeTimeout From 9a3cc55c9fae6c5b47a68de5fd45fe32c47a8444 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Feb 2022 07:46:24 +0800 Subject: [PATCH 230/843] rename context name dumper to _dumper --- dump.go | 2 +- request.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dump.go b/dump.go index 3caddc5e..b0afbfbc 100644 --- a/dump.go +++ b/dump.go @@ -159,7 +159,7 @@ func (d *dumper) Start() { } func getDumperOverride(dump *dumper, ctx context.Context) *dumper { - if d, ok := ctx.Value("dumper").(*dumper); ok { + if d, ok := ctx.Value("_dumper").(*dumper); ok { return d } return dump diff --git a/request.go b/request.go index f9719ab8..427388bb 100644 --- a/request.go +++ b/request.go @@ -859,7 +859,7 @@ func EnableDump() *Request { // EnableDump enables dump, including all content for the request and response by default. func (r *Request) EnableDump() *Request { - return r.SetContext(context.WithValue(r.Context(), "dumper", newDumper(r.getDumpOptions()))) + return r.SetContext(context.WithValue(r.Context(), "_dumper", newDumper(r.getDumpOptions()))) } // EnableDumpWithoutBody is a global wrapper methods which delegated From 7f1b33962ff0265330c0cd35d4f5dae673d1676b Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Feb 2022 11:05:07 +0800 Subject: [PATCH 231/843] add some abilities 1. support ForceHTTP1 2. support customize Dail and DailTLS 3. unexport some fields which is unnecessary --- client.go | 38 +++++++++++++++ h2_bundle.go | 15 ++++-- transfer.go | 14 +----- transport.go | 130 ++++++--------------------------------------------- 4 files changed, 64 insertions(+), 133 deletions(-) diff --git a/client.go b/client.go index c9127e13..1b981054 100644 --- a/client.go +++ b/client.go @@ -1099,16 +1099,54 @@ func (c *Client) SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Cli return c } +// SetDialTLS is a global wrapper methods which delegated +// to the default client's SetDialTLS. +func SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { + return defaultClient.SetDialTLS(fn) +} + +// SetDialTLS set the customized DialTLSContext func (c *Client) SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { c.t.DialTLSContext = fn return c } +// SetDial is a global wrapper methods which delegated +// to the default client's SetDial. +func SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { + return defaultClient.SetDial(fn) +} + +// SetDial set the customized Dial function. func (c *Client) SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { c.t.DialContext = fn return c } +// EnableForceHTTP1 is a global wrapper methods which delegated +// to the default client's EnableForceHTTP1. +func EnableForceHTTP1() *Client { + return defaultClient.EnableForceHTTP1() +} + +// EnableForceHTTP1 enables force using HTTP1 (disabled by default) +func (c *Client) EnableForceHTTP1() *Client { + c.t.ForceHTTP1 = true + return c +} + +// DisableForceHTTP1 is a global wrapper methods which delegated +// to the default client's DisableForceHTTP1. +func DisableForceHTTP1() *Client { + return defaultClient.DisableForceHTTP1() +} + +// DisableForceHTTP1 disable force using HTTP1 (disabled by default) +func (c *Client) DisableForceHTTP1() *Client { + c.t.ForceHTTP1 = false + return c +} + // DisableAllowGetMethodPayload is a global wrapper methods which delegated // to the default client's DisableAllowGetMethodPayload. func DisableAllowGetMethodPayload() *Client { diff --git a/h2_bundle.go b/h2_bundle.go index b726d157..a0b4b825 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -3293,13 +3293,20 @@ type http2Transport struct { connPoolOrDef http2ClientConnPool // non-nil version of ConnPool } +const h2max = 1<<32 - 1 + func (t *http2Transport) maxHeaderListSize() uint32 { - if t.MaxHeaderListSize == 0 { - return 10 << 20 - } if t.MaxHeaderListSize == 0xffffffff { return 0 } + if t.MaxHeaderListSize > 0 { + return t.MaxHeaderListSize + } + if limit := t.t1.MaxResponseHeaderBytes; limit > 0 && limit < h2max { + t.MaxHeaderListSize = h2max + } else { + t.MaxHeaderListSize = 10 << 20 + } return t.MaxHeaderListSize } @@ -3340,7 +3347,7 @@ func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { if t1.TLSClientConfig == nil { t1.TLSClientConfig = new(tls.Config) } - if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { + if !t1.ForceHTTP1 && !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) } if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { diff --git a/transfer.go b/transfer.go index c0ff2512..79829fe9 100644 --- a/transfer.go +++ b/transfer.go @@ -26,10 +26,6 @@ import ( "golang.org/x/net/http/httpguts" ) -// ErrLineTooLong is returned when reading request or response bodies -// with malformed chunked encoding. -var ErrLineTooLong = internal.ErrLineTooLong - type errorReader struct { err error } @@ -834,17 +830,11 @@ type body struct { onHitEOF func() // if non-nil, func to call when EOF is Read } -// ErrBodyReadAfterClose is returned when reading a Request or Response -// Body after the body has been closed. This typically happens when the body is -// read after an HTTP Handler calls WriteHeader or Write on its -// ResponseWriter. -var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") - func (b *body) Read(p []byte) (n int, err error) { b.mu.Lock() defer b.mu.Unlock() if b.closed { - return 0, ErrBodyReadAfterClose + return 0, http.ErrBodyReadAfterClose } return b.readLocked(p) } @@ -1051,7 +1041,7 @@ type bodyLocked struct { func (bl bodyLocked) Read(p []byte) (n int, err error) { if bl.b.closed { - return 0, ErrBodyReadAfterClose + return 0, http.ErrBodyReadAfterClose } return bl.b.readLocked(p) } diff --git a/transport.go b/transport.go index 244b109a..9a701a17 100644 --- a/transport.go +++ b/transport.go @@ -18,7 +18,6 @@ import ( "errors" "fmt" "github.com/imroc/req/v3/internal/ascii" - "github.com/imroc/req/v3/internal/godebug" "github.com/imroc/req/v3/internal/util" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" @@ -30,7 +29,6 @@ import ( "net/http/httptrace" "net/textproto" "net/url" - "reflect" "strings" "sync" "sync/atomic" @@ -40,9 +38,9 @@ import ( "golang.org/x/net/http/httpproxy" ) -// DefaultMaxIdleConnsPerHost is the default value of Transport's +// defaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. -const DefaultMaxIdleConnsPerHost = 2 +const defaultMaxIdleConnsPerHost = 2 // ResponseOptions determines that how should the response been processed. type ResponseOptions struct { @@ -108,6 +106,9 @@ type Transport struct { connsPerHost map[connectMethodKey]int connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns + // ForceHTTP1 force using HTTP/1.1 + ForceHTTP1 bool + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -172,7 +173,7 @@ type Transport struct { // MaxIdleConnsPerHost, if non-zero, controls the maximum idle // (keep-alive) connections to keep per-host. If zero, - // DefaultMaxIdleConnsPerHost is used. + // defaultMaxIdleConnsPerHost is used. MaxIdleConnsPerHost int // MaxConnsPerHost optionally limits the total number of @@ -246,11 +247,7 @@ type Transport struct { // If zero, a default (currently 4KB) is used. ReadBufferSize int - // nextProtoOnce guards initialization of TLSNextProto and - // h2transport (via onceSetNextProtoDefaults) - nextProtoOnce sync.Once - h2transport h2Transport // non-nil if http2 wired up - tlsNextProtoWasNil bool // whether TLSNextProto was nil when the Once fired + t2 *http2Transport // non-nil if http2 wired up // ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero // Dial, DialTLS, or DialContext func or TLSClientConfig is provided. @@ -347,7 +344,6 @@ func (t *Transport) readBufferSize() int { // Clone returns a deep copy of t's exported fields. func (t *Transport) Clone() *Transport { - t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) t2 := &Transport{ Proxy: t.Proxy, @@ -369,6 +365,7 @@ func (t *Transport) Clone() *Transport { WriteBufferSize: t.WriteBufferSize, ReadBufferSize: t.ReadBufferSize, ResponseOptions: t.ResponseOptions, + ForceHTTP1: t.ForceHTTP1, dump: t.dump.Clone(), } if t.dump != nil { @@ -377,7 +374,7 @@ func (t *Transport) Clone() *Transport { if t.TLSClientConfig != nil { t2.TLSClientConfig = t.TLSClientConfig.Clone() } - if !t.tlsNextProtoWasNil { + if t.TLSNextProto != nil { npm := map[string]func(authority string, c TLSConn) http.RoundTripper{} for k, v := range t.TLSNextProto { npm[k] = v @@ -402,109 +399,10 @@ func (t *Transport) DisableDump() { } } -// h2Transport is the interface we expect to be able to call from -// net/http against an *http2.Transport that's either bundled into -// h2_bundle.go or supplied by the user via x/net/http2. -// -// We name it with the "h2" prefix to stay out of the "http2" prefix -// namespace used by x/tools/cmd/bundle for h2_bundle.go. -type h2Transport interface { - CloseIdleConnections() -} - func (t *Transport) hasCustomTLSDialer() bool { return t.DialTLSContext != nil } -// onceSetNextProtoDefaults initializes TLSNextProto. -// It must be called via t.nextProtoOnce.Do. -func (t *Transport) onceSetNextProtoDefaults() { - t.tlsNextProtoWasNil = (t.TLSNextProto == nil) - if godebug.Get("http2client") == "0" { - return - } - - // If they've already configured http2 with - // golang.org/x/net/http2 instead of the bundled copy, try to - // get at its http2.Transport value (via the "https" - // altproto map) so we can call CloseIdleConnections on it if - // requested. (Issue 22891) - altProto, _ := t.altProto.Load().(map[string]http.RoundTripper) - if rv := reflect.ValueOf(altProto["https"]); rv.IsValid() && rv.Type().Kind() == reflect.Struct && rv.Type().NumField() == 1 { - if v := rv.Field(0); v.CanInterface() { - if h2i, ok := v.Interface().(h2Transport); ok { - t.h2transport = h2i - return - } - } - } - - if t.TLSNextProto != nil { - // This is the documented way to disable http2 on a - // Transport. - return - } - if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.DialContext != nil || t.hasCustomTLSDialer()) { - // Be conservative and don't automatically enable - // http2 if they've specified a custom TLS config or - // custom dialers. Let them opt-in themselves via - // http2.ConfigureTransport so we don't surprise them - // by modifying their tls.Config. Issue 14275. - // However, if ForceAttemptHTTP2 is true, it overrides the above checks. - return - } - t2, err := http2ConfigureTransports(t) - if err != nil { - log.Printf("Error enabling Transport HTTP/2 support: %v", err) - return - } - t.h2transport = t2 - - // Auto-configure the http2.Transport's MaxHeaderListSize from - // the http.Transport's MaxResponseHeaderBytes. They don't - // exactly mean the same thing, but they're close. - // - // TODO: also add this to x/net/http2.Configure Transport, behind - // a +build go1.7 build tag: - if limit1 := t.MaxResponseHeaderBytes; limit1 != 0 && t2.MaxHeaderListSize == 0 { - const h2max = 1<<32 - 1 - if limit1 >= h2max { - t2.MaxHeaderListSize = h2max - } else { - t2.MaxHeaderListSize = uint32(limit1) - } - } -} - -// ProxyFromEnvironment returns the URL of the proxy to use for a -// given request, as indicated by the environment variables -// HTTP_PROXY, HTTPS_PROXY and NO_PROXY (or the lowercase versions -// thereof). HTTPS_PROXY takes precedence over HTTP_PROXY for https -// requests. -// -// The environment values may be either a complete URL or a -// "host[:port]", in which case the "http" scheme is assumed. -// The schemes "http", "https", and "socks5" are supported. -// An error is returned if the value is a different form. -// -// A nil URL and nil error are returned if no proxy is defined in the -// environment, or a proxy should not be used for the given request, -// as defined by NO_PROXY. -// -// As a special case, if req.URL.Host is "localhost" (with or without -// a port number), then a nil URL and nil error will be returned. -func ProxyFromEnvironment(req *http.Request) (*url.URL, error) { - return envProxyFunc()(req.URL) -} - -// ProxyURL returns a proxy function (for use in a Transport) -// that always returns the same URL. -func ProxyURL(fixedURL *url.URL) func(*http.Request) (*url.URL, error) { - return func(*http.Request) (*url.URL, error) { - return fixedURL, nil - } -} - // transportRequest is a wrapper around a *http.Request that adds // optional extra headers to write and stores any error to return // from roundTrip. @@ -559,7 +457,6 @@ func (t *Transport) alternateRoundTripper(req *http.Request) http.RoundTripper { // roundTrip implements a http.RoundTripper over HTTP. func (t *Transport) roundTrip(req *http.Request) (*http.Response, error) { - t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) ctx := req.Context() trace := httptrace.ContextClientTrace(ctx) @@ -812,7 +709,6 @@ func (t *Transport) RegisterProtocol(scheme string, rt http.RoundTripper) { // a "keep-alive" state. It does not interrupt any connections currently // in use. func (t *Transport) CloseIdleConnections() { - t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) t.idleMu.Lock() m := t.idleConn t.idleConn = nil @@ -824,7 +720,7 @@ func (t *Transport) CloseIdleConnections() { pconn.close(errCloseIdleConns) } } - if t2 := t.h2transport; t2 != nil { + if t2 := t.t2; t2 != nil { t2.CloseIdleConnections() } } @@ -887,7 +783,7 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) } - cm.onlyH1 = requestRequiresHTTP1(treq.Request) + cm.onlyH1 = t.ForceHTTP1 || requestRequiresHTTP1(treq.Request) return cm, err } @@ -951,7 +847,7 @@ func (t *Transport) maxIdleConnsPerHost() int { if v := t.MaxIdleConnsPerHost; v != 0 { return v } - return DefaultMaxIdleConnsPerHost + return defaultMaxIdleConnsPerHost } // tryPutIdleConn adds pconn to the list of idle persistent connections awaiting @@ -1774,7 +1670,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } } - if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { + if s := pconn.tlsState; !t.ForceHTTP1 && s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { alt := next(cm.targetAddr, pconn.conn.(TLSConn)) if e, ok := alt.(erringRoundTripper); ok { From 567057decb432da760fbaef4fa0f69158f054b1e Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Feb 2022 11:12:50 +0800 Subject: [PATCH 232/843] remove unnecessary ForceHTTP1 judge --- h2_bundle.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/h2_bundle.go b/h2_bundle.go index a0b4b825..2fbf7c60 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -3347,7 +3347,7 @@ func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { if t1.TLSClientConfig == nil { t1.TLSClientConfig = new(tls.Config) } - if !t1.ForceHTTP1 && !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { + if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) } if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { From 80cd8a3719a33ed3bd4c9f49460c4ae5cfcaeaa3 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Feb 2022 11:16:36 +0800 Subject: [PATCH 233/843] add SetTLSHandshakeTimeout --- client.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/client.go b/client.go index 1b981054..5345970b 100644 --- a/client.go +++ b/client.go @@ -1123,6 +1123,18 @@ func (c *Client) SetDial(fn func(ctx context.Context, network, addr string) (net return c } +// SetTLSHandshakeTimeout is a global wrapper methods which delegated +// to the default client's SetTLSHandshakeTimeout. +func SetTLSHandshakeTimeout(timeout time.Duration) *Client { + return defaultClient.SetTLSHandshakeTimeout(timeout) +} + +// SetTLSHandshakeTimeout set the TLS handshake timeout. +func (c *Client) SetTLSHandshakeTimeout(timeout time.Duration) *Client { + c.t.TLSHandshakeTimeout = timeout + return c +} + // EnableForceHTTP1 is a global wrapper methods which delegated // to the default client's EnableForceHTTP1. func EnableForceHTTP1() *Client { From 48e2904e9560abbe0d3549bd8f21dad8b13c4fdc Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Feb 2022 15:39:30 +0800 Subject: [PATCH 234/843] fix panic in SetQueryString #90 --- request.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/request.go b/request.go index 427388bb..53f0bfe7 100644 --- a/request.go +++ b/request.go @@ -148,6 +148,9 @@ func SetQueryString(query string) *Request { func (r *Request) SetQueryString(query string) *Request { params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) if err == nil { + if r.QueryParams == nil { + r.QueryParams = make(urlpkg.Values) + } for p, v := range params { for _, pv := range v { r.QueryParams.Add(p, pv) From 24df0640bb2d83f1af14c0926bfadfa92645f255 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Feb 2022 15:44:36 +0800 Subject: [PATCH 235/843] optimize set query string --- client.go | 20 ++++++++++---------- request.go | 20 ++++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 5345970b..8fb8a370 100644 --- a/client.go +++ b/client.go @@ -466,17 +466,17 @@ func SetCommonQueryString(query string) *Client { // SetCommonQueryString set URL query parameters using the raw query string. func (c *Client) SetCommonQueryString(query string) *Client { params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) - if err == nil { - if c.QueryParams == nil { - c.QueryParams = make(urlpkg.Values) - } - for p, v := range params { - for _, pv := range v { - c.QueryParams.Add(p, pv) - } - } - } else { + if err != nil { c.log.Warnf("failed to parse query string (%s): %v", query, err) + return c + } + if c.QueryParams == nil { + c.QueryParams = make(urlpkg.Values) + } + for p, v := range params { + for _, pv := range v { + c.QueryParams.Add(p, pv) + } } return c } diff --git a/request.go b/request.go index 53f0bfe7..714cfa14 100644 --- a/request.go +++ b/request.go @@ -147,17 +147,17 @@ func SetQueryString(query string) *Request { // SetQueryString set URL query parameters using the raw query string. func (r *Request) SetQueryString(query string) *Request { params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) - if err == nil { - if r.QueryParams == nil { - r.QueryParams = make(urlpkg.Values) - } - for p, v := range params { - for _, pv := range v { - r.QueryParams.Add(p, pv) - } + if err != nil { + r.client.log.Warnf("failed to parse query string (%s): %v", query, err) + return r + } + if r.QueryParams == nil { + r.QueryParams = make(urlpkg.Values) + } + for p, v := range params { + for _, pv := range v { + r.QueryParams.Add(p, pv) } - } else { - r.client.log.Errorf("%v", err) } return r } From 77f1ac30a5f6b8ade709e0ab92e47b2207abf010 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 09:54:04 +0800 Subject: [PATCH 236/843] update README: optimize Big News --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ee1cd2ff..cd10fe52 100644 --- a/README.md +++ b/README.md @@ -6,11 +6,11 @@ ## Big News -Brand new v3 version is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) +Brand-new version v3 is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) If you want to use the older version, check it out on [v1 branch](https://github.com/imroc/req/tree/v1). -> v2 is a transitional version, cuz some breaking changes were introduced during v2 refactoring, checkout [v2 branch](https://github.com/imroc/req/tree/v2) if you want. +> v2 is a transitional version, due to some breaking changes were introduced during optmize user experience, checkout [v2 branch](https://github.com/imroc/req/tree/v2) if you want. ## Table of Contents From 6404d4ac1389f42a760bc47ff8a2f596735557e2 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 10:10:39 +0800 Subject: [PATCH 237/843] update README: add section "HTTP2 and HTTP1" --- README.md | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cd10fe52..7ff312cf 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ If you want to use the older version, check it out on [v1 branch](https://github * [Quick Start](#Quick-Start) * [Debugging - Dump/Log/Trace](#Debugging) * [Quick HTTP Test](#Test) +* [HTTP2 and HTTP1](#HTTP2-HTTP1) * [URL Path and Query Parameter](#Param) * [Form Data](#Form) * [Header and Cookie](#Header-Cookie) @@ -37,9 +38,9 @@ If you want to use the older version, check it out on [v1 branch](https://github * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. * Powerful and convenient debug utilites, including debug logs, performance traces, dump complete request and response content, and even provide global wrapper methods to test with minimal code (see [Debugging - Log/Trace/Dump](#Debugging). * Easy making HTTP test with code instead of tools like curl or postman, `req` provide global wrapper methods and `MustXXX` to test API with minimal code (see [Quick HTTP Test](#Test)). +* Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support, and you can also force `HTTP/1.1` if you want (see [HTTP2 and HTTP1](#HTTP2-HTTP1)). * Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decode](#AutoDecode)). -* Automatic marshal and unmarshal for JSON and XML content type and fully customizable. -* Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support. +* Automatic marshal and unmarshal for JSON and XML content type and fully customizable (see [Body and Marshal/Unmarshal](#Body)). * Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. * Easy [Download and Upload](#Download-Upload). * Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token for both client and request level. @@ -261,6 +262,16 @@ Use `MustXXX` to ignore error handling during test, make it possible to complete fmt.Println(req.DevMode().R().MustGet("https://imroc.cc").TraceInfo()) ``` +## HTTP2 and HTTP1 + +Req works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support, which is negotiated by TLS handshake. + +You can force using `HTTP/1.1` if you want. + +```go +client.EnableForceHTTP1() +``` + ## URL Path and Query Parameter **Path Parameter** From ab0efb0343a469e3c64f3b8090592122c9e7f6e8 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 10:15:45 +0800 Subject: [PATCH 238/843] update README: Quick HTTP Test --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7ff312cf..81eef49c 100644 --- a/README.md +++ b/README.md @@ -248,7 +248,7 @@ req.SetTimeout(5 * time.Second). // Call the global method just like the Request's method, // which will create request automatically using the default // client, so you can treat package name `req` as a Request, -// and you don't need to create request explicitly. +// and you don't need to create any request and client explicitly. req.SetQueryParam("page", "2"). SetHeader("Accept", "text/xml"). // Override client level settings at request level. Get("https://api.example.com/repos") From 96b3414300a78d896ccd5252fffa6b34f96c4f77 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 10:18:02 +0800 Subject: [PATCH 239/843] update README: Features --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 81eef49c..4b6957ff 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ If you want to use the older version, check it out on [v1 branch](https://github * Easy [Download and Upload](#Download-Upload). * Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token for both client and request level. * Easy set timeout, proxy, certs, redirect policy, cookie jar, compression, keepalives etc for client. -* Support middleware before request sent and after got response. +* Support middleware before request sent and after got response (see [Request and Response Middleware](#Middleware)). ## Quick Start From 36ad2dd0c200bda39068ab70460618baea7426e5 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 10:21:09 +0800 Subject: [PATCH 240/843] update README: fix dump methods --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 4b6957ff..9cb0d55b 100644 --- a/README.md +++ b/README.md @@ -347,7 +347,7 @@ client.AddCommonQueryParam("key", "value1").AddCommonQueryParam("key", "value2") ## Form Data ```go -client := req.C().EnableDumpOnlyRequest() +client := req.C().EnableDumpWithoutResponse() client.R().SetFormData(map[string]string{ "username": "imroc", "blog": "https://imroc.cc", @@ -394,7 +394,7 @@ client.SetCommonFormDataFromValues(v) ```go // Let's dump the header to see what's going on -client := req.C().EnableDumpOnlyHeader() +client := req.C().EnableDumpWithoutBody() // Send a request with multiple headers and cookies client.R(). @@ -427,7 +427,7 @@ resp2, err := client.R().Get(url2) ```go // Let's dump the header to see what's going on -client := req.C().EnableDumpOnlyHeader() +client := req.C().EnableDumpWithoutBody() // Send a request with multiple headers and cookies client.R(). @@ -484,7 +484,7 @@ client.SetCookieJar(nil) ```go // Create a client that dump request -client := req.C().EnableDumpOnlyRequest() +client := req.C().EnableDumpWithoutResponse() // SetBody accepts string, []byte, io.Reader, use type assertion to // determine the data type of body automatically. client.R().SetBody("test").Post("https://httpbin.org/post") @@ -563,7 +563,7 @@ type ErrorMessage struct { Message string `json:"message"` } // Create a client and dump body to see details -client := req.C().EnableDumpOnlyBody() +client := req.C().EnableDumpWithoutHeader() // Send a request and unmarshal result automatically according to // response `Content-Type` @@ -784,7 +784,7 @@ client.OnAfterResponse(func(c *req.Client, r *req.Response) error { ## Redirect Policy ```go -client := req.C().EnableDumpOnlyRequest() +client := req.C().EnableDumpWithoutResponse() client.SetRedirectPolicy( // Only allow up to 5 redirects From cfef5342a8e83244d83f97760897884cb8d3517d Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 10:32:05 +0800 Subject: [PATCH 241/843] update README: force http1 --- README.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9cb0d55b..88375bb2 100644 --- a/README.md +++ b/README.md @@ -269,7 +269,23 @@ Req works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by defau You can force using `HTTP/1.1` if you want. ```go -client.EnableForceHTTP1() +client := req.C().EnableForceHTTP1().EnableDumpWithoutBody() +client.R().MustGet("https://httpbin.org/get") +/* Output +GET /get HTTP/1.1 +Host: httpbin.org +User-Agent: req/v3 (https://github.com/imroc/req) +Accept-Encoding: gzip + +HTTP/1.1 200 OK +Date: Tue, 08 Feb 2022 02:30:18 GMT +Content-Type: application/json +Content-Length: 289 +Connection: keep-alive +Server: gunicorn/19.9.0 +Access-Control-Allow-Origin: * +Access-Control-Allow-Credentials: true +*/ ``` ## URL Path and Query Parameter From 33413e46e2f2909fe3f91c18ab6e35d67ae73cbe Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 10:42:48 +0800 Subject: [PATCH 242/843] update README --- README.md | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 88375bb2..217ad824 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ if resp.StatusCode > 299 { } // Similarly, also support to customize dump settings with predefined convenience settings at request level. -resp, err = client.R().EnableDumpWithoutRequest().SetBody("test body").Post("https://httpbin.org/post") +resp, err = client.R().EnableDumpAllWithoutRequest().SetBody("test body").Post("https://httpbin.org/post") // ... resp, err = client.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post("https://httpbin.org/post") ``` @@ -269,7 +269,7 @@ Req works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by defau You can force using `HTTP/1.1` if you want. ```go -client := req.C().EnableForceHTTP1().EnableDumpWithoutBody() +client := req.C().EnableForceHTTP1().EnableDumpAllWithoutBody() client.R().MustGet("https://httpbin.org/get") /* Output GET /get HTTP/1.1 @@ -363,7 +363,7 @@ client.AddCommonQueryParam("key", "value1").AddCommonQueryParam("key", "value2") ## Form Data ```go -client := req.C().EnableDumpWithoutResponse() +client := req.C().EnableDumpAllWithoutResponse() client.R().SetFormData(map[string]string{ "username": "imroc", "blog": "https://imroc.cc", @@ -410,7 +410,7 @@ client.SetCommonFormDataFromValues(v) ```go // Let's dump the header to see what's going on -client := req.C().EnableDumpWithoutBody() +client := req.C().EnableForceHTTP1().EnableDumpAllWithoutResponse() // Send a request with multiple headers and cookies client.R(). @@ -418,12 +418,12 @@ client.R(). SetHeaders(map[string]string{ // Set multiple headers at once "My-Custom-Header": "My Custom Value", "User": "imroc", - }).Get("https://www.baidu.com/") + }).Get("https://httpbin.org/get") /* Output -GET / HTTP/1.1 -Host: www.baidu.com -User-Agent: req/v2 (https://github.com/imroc/req) +GET /get HTTP/1.1 +Host: httpbin.org +User-Agent: req/v3 (https://github.com/imroc/req) Accept: application/json My-Custom-Header: My Custom Value User: imroc @@ -443,7 +443,7 @@ resp2, err := client.R().Get(url2) ```go // Let's dump the header to see what's going on -client := req.C().EnableDumpWithoutBody() +client := req.C().EnableForceHTTP1().EnableDumpAllWithoutResponse() // Send a request with multiple headers and cookies client.R(). @@ -466,13 +466,12 @@ client.R(). HttpOnly: false, Secure: true, }, - ).Get("https://www.baidu.com/") + ).Get("https://httpbin.org/get") /* Output -GET / HTTP/1.1 -Host: www.baidu.com -User-Agent: req/v2 (https://github.com/imroc/req) -Accept: application/json +GET /get HTTP/1.1 +Host: httpbin.org +User-Agent: req/v3 (https://github.com/imroc/req) Cookie: testcookie1="testcookie1 value"; testcookie2="testcookie2 value" Accept-Encoding: gzip */ @@ -500,7 +499,7 @@ client.SetCookieJar(nil) ```go // Create a client that dump request -client := req.C().EnableDumpWithoutResponse() +client := req.C().EnableDumpAllWithoutResponse() // SetBody accepts string, []byte, io.Reader, use type assertion to // determine the data type of body automatically. client.R().SetBody("test").Post("https://httpbin.org/post") @@ -579,7 +578,7 @@ type ErrorMessage struct { Message string `json:"message"` } // Create a client and dump body to see details -client := req.C().EnableDumpWithoutHeader() +client := req.C().EnableDumpAllWithoutHeader() // Send a request and unmarshal result automatically according to // response `Content-Type` @@ -800,7 +799,7 @@ client.OnAfterResponse(func(c *req.Client, r *req.Response) error { ## Redirect Policy ```go -client := req.C().EnableDumpWithoutResponse() +client := req.C().EnableDumpAllWithoutResponse() client.SetRedirectPolicy( // Only allow up to 5 redirects From add7ea1a66e36c01346b0163312caa0ec2963cd6 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 10:43:50 +0800 Subject: [PATCH 243/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 217ad824..6789d16c 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ if resp.StatusCode > 299 { } // Similarly, also support to customize dump settings with predefined convenience settings at request level. -resp, err = client.R().EnableDumpAllWithoutRequest().SetBody("test body").Post("https://httpbin.org/post") +resp, err = client.R().EnableDumpWithoutRequest().SetBody("test body").Post("https://httpbin.org/post") // ... resp, err = client.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post("https://httpbin.org/post") ``` From 2dd79a901e8b0fa752f977c9ccd758d9d41bf974 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 13:21:31 +0800 Subject: [PATCH 244/843] fix AllowGetMethodPayload --- middleware.go | 1 + 1 file changed, 1 insertion(+) diff --git a/middleware.go b/middleware.go index 35cfe57e..12b0e66e 100644 --- a/middleware.go +++ b/middleware.go @@ -136,6 +136,7 @@ func handleMarshalBody(c *Client, r *Request) error { func parseRequestBody(c *Client, r *Request) (err error) { if c.isPayloadForbid(r.RawRequest.Method) { + r.RawRequest.Body = nil return } // handle multipart From 9f36cf37c8b2c0fc137b5aaadfc0b35d68898d34 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 14:25:52 +0800 Subject: [PATCH 245/843] request-level dumps no longer override client-level dumps --- dump.go | 10 +++-- h2_bundle.go | 92 ++++++++++++++++++++++++++++++++------------- http_request.go | 10 +++-- http_response.go | 4 +- textproto_reader.go | 20 ++++++++-- transfer.go | 20 ++++++---- transport.go | 12 +++--- 7 files changed, 118 insertions(+), 50 deletions(-) diff --git a/dump.go b/dump.go index b0afbfbc..816c7b4e 100644 --- a/dump.go +++ b/dump.go @@ -158,9 +158,13 @@ func (d *dumper) Start() { } } -func getDumperOverride(dump *dumper, ctx context.Context) *dumper { +func getDumpers(dump *dumper, ctx context.Context) []*dumper { + dumps := []*dumper{} + if dump != nil { + dumps = append(dumps, dump) + } if d, ok := ctx.Value("_dumper").(*dumper); ok { - return d + dumps = append(dumps, d) } - return dump + return dumps } diff --git a/h2_bundle.go b/h2_bundle.go index 2fbf7c60..73f6a60a 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -1218,10 +1218,21 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) } if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { - dump := getDumperOverride(fr.cc.t.t1.dump, fr.cc.currentRequest.Context()) - hr, err := fr.readMetaFrame(f.(*http2HeadersFrame), dump) - if err == nil && dump != nil && dump.ResponseHeader { - dump.dump([]byte("\r\n")) + dumps := getDumpers(fr.cc.t.t1.dump, fr.cc.currentRequest.Context()) + if len(dumps) > 0 { + dd := []*dumper{} + for _, dump := range dumps { + if dump.ResponseHeader { + dd = append(dd, dump) + } + } + dumps = dd + } + hr, err := fr.readMetaFrame(f.(*http2HeadersFrame), dumps) + if err == nil && len(dumps) > 0 { + for _, dump := range dumps { + dump.dump([]byte("\r\n")) + } } return hr, err } @@ -2216,7 +2227,7 @@ func (fr *http2Framer) maxHeaderStringLen() int { // readMetaFrame returns 0 or more CONTINUATION frames from fr and // merge them into the provided hf and returns a MetaHeadersFrame // with the decoded hpack values. -func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame, dump *dumper) (*http2MetaHeadersFrame, error) { +func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (*http2MetaHeadersFrame, error) { if fr.AllowIllegalReads { return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders") } @@ -2265,12 +2276,16 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame, dump *dumper) (*http mh.Fields = append(mh.Fields, hf) } emitFunc := rawEmitFunc - if dump != nil && dump.ResponseHeader { + + if len(dumps) > 0 { emitFunc = func(hf hpack.HeaderField) { - dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) + for _, dump := range dumps { + dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) + } rawEmitFunc(hf) } } + hdec.SetEmitFunc(emitFunc) // Lose reference to MetaHeadersFrame: defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {}) @@ -4409,13 +4424,13 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { } } - dump := getDumperOverride(cs.cc.t.t1.dump, req.Context()) + dumps := getDumpers(cs.cc.t.t1.dump, req.Context()) // Past this point (where we send request headers), it is possible for // RoundTrip to return successfully. Since the RoundTrip contract permits // the caller to "mutate or reuse" the Request after closing the Response's Body, // we must take care when referencing the Request from here on. - err = cs.encodeAndWriteHeaders(req, dump) + err = cs.encodeAndWriteHeaders(req, dumps) <-cc.reqHeaderMu if err != nil { return err @@ -4447,14 +4462,24 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { } } - if err = cs.writeRequestBody(req, dump); err != nil { + if len(dumps) > 0 { + dd := []*dumper{} + for _, dump := range dumps { + if dump.RequestBody { + dd = append(dd, dump) + } + } + dumps = dd + } + + if err = cs.writeRequestBody(req, dumps); err != nil { if err != http2errStopReqBodyWrite { http2traceWroteRequest(cs.trace, err) return err } } else { cs.sentEndStream = true - if dump != nil && dump.RequestBody { + for _, dump := range dumps { dump.dump([]byte("\r\n\r\n")) } } @@ -4492,7 +4517,7 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { } } -func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dump *dumper) error { +func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dumps []*dumper) error { cc := cs.cc ctx := cs.ctx @@ -4522,7 +4547,7 @@ func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dump *dump hasTrailers := trailers != "" contentLen := http2actualContentLength(req) hasBody := contentLen != 0 - hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen, dump) + hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen, dumps) if err != nil { return err } @@ -4687,7 +4712,7 @@ func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int { var http2bufPool sync.Pool // of *[]byte -func (cs *http2clientStream) writeRequestBody(req *http.Request, dump *dumper) (err error) { +func (cs *http2clientStream) writeRequestBody(req *http.Request, dumps []*dumper) (err error) { cc := cs.cc body := cs.reqBody sentEnd := false // whether we sent the final DATA frame w/ END_STREAM @@ -4712,9 +4737,11 @@ func (cs *http2clientStream) writeRequestBody(req *http.Request, dump *dumper) ( } writeData := cc.fr.WriteData - if dump != nil && dump.RequestBody { + if len(dumps) > 0 { writeData = func(streamID uint32, endStream bool, data []byte) error { - dump.dump(data) + for _, dump := range dumps { + dump.dump(data) + } return cc.fr.WriteData(streamID, endStream, data) } } @@ -4807,7 +4834,7 @@ func (cs *http2clientStream) writeRequestBody(req *http.Request, dump *dumper) ( defer cc.wmu.Unlock() var trls []byte if len(trailer) > 0 { - trls, err = cc.encodeTrailers(trailer, dump) + trls, err = cc.encodeTrailers(trailer, dumps) if err != nil { return err } @@ -4870,7 +4897,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er var http2errNilRequestURL = errors.New("http2: Request.URI is nil") // requires cc.wmu be held. -func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64, dump *dumper) ([]byte, error) { +func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64, dumps []*dumper) ([]byte, error) { cc.hbuf.Reset() if req.URL == nil { return nil, http2errNilRequestURL @@ -5022,10 +5049,21 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, traceHeaders := http2traceHasWroteHeaderField(trace) writeHeader := cc.writeHeader - if dump != nil && dump.RequestHeader { - writeHeader = func(name, value string) { - dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) - cc.writeHeader(name, value) + if len(dumps) > 0 { + dd := []*dumper{} + for _, dump := range dumps { + if dump.RequestHeader { + dd = append(dd, dump) + } + } + dumps = dd + if len(dumps) > 0 { + writeHeader = func(name, value string) { + for _, dump := range dumps { + dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + } + cc.writeHeader(name, value) + } } } @@ -5043,7 +5081,7 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, } }) - if dump != nil && dump.RequestHeader { + for _, dump := range dumps { dump.dump([]byte("\r\n")) } @@ -5073,7 +5111,7 @@ func http2shouldSendReqContentLength(method string, contentLength int64) bool { } // requires cc.wmu be held. -func (cc *http2ClientConn) encodeTrailers(trailer http.Header, dump *dumper) ([]byte, error) { +func (cc *http2ClientConn) encodeTrailers(trailer http.Header, dumps []*dumper) ([]byte, error) { cc.hbuf.Reset() hlSize := uint64(0) @@ -5088,9 +5126,11 @@ func (cc *http2ClientConn) encodeTrailers(trailer http.Header, dump *dumper) ([] } writeHeader := cc.writeHeader - if dump != nil && dump.RequestBody { + if len(dumps) > 0 { writeHeader = func(name, value string) { - dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + for _, dump := range dumps { + dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + } cc.writeHeader(name, value) } } diff --git a/http_request.go b/http_request.go index f9c1c170..7fad63da 100644 --- a/http_request.go +++ b/http_request.go @@ -85,7 +85,7 @@ var errMissingHost = errors.New("http: Request.Write on Request with no Host or // extraHeaders may be nil // waitForContinue may be nil // always closes body -func requestWrite(r *http.Request, w io.Writer, usingProxy bool, extraHeaders http.Header, waitForContinue func() bool, dump *dumper) (err error) { +func requestWrite(r *http.Request, w io.Writer, usingProxy bool, extraHeaders http.Header, waitForContinue func() bool, dumps []*dumper) (err error) { trace := httptrace.ContextClientTrace(r.Context()) if trace != nil && trace.WroteRequest != nil { defer func() { @@ -149,8 +149,10 @@ func requestWrite(r *http.Request, w io.Writer, usingProxy bool, extraHeaders ht } rw := w // raw writer - if dump != nil && dump.RequestHeader { - w = dump.WrapWriter(w) + for _, dump := range dumps { + if dump.RequestHeader { + w = dump.WrapWriter(w) + } } _, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(r.Method, "GET"), ruri) @@ -240,7 +242,7 @@ func requestWrite(r *http.Request, w io.Writer, usingProxy bool, extraHeaders ht // Write body and trailer closed = true - err = tw.writeBody(rw, dump) + err = tw.writeBody(rw, dumps) if err != nil { if tw.bodyReadError == err { err = requestBodyReadError{err} diff --git a/http_response.go b/http_response.go index c8980680..83ee1ab1 100644 --- a/http_response.go +++ b/http_response.go @@ -27,8 +27,8 @@ var respExcludeHeader = map[string]bool{ // After that call, clients can inspect resp.Trailer to find key/value // pairs included in the response trailer. func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) { - dump := getDumperOverride(pc.t.dump, req.Context()) - tp := newTextprotoReader(pc.br, dump) + dumps := getDumpers(pc.t.dump, req.Context()) + tp := newTextprotoReader(pc.br, dumps) resp := &http.Response{ Request: req, } diff --git a/textproto_reader.go b/textproto_reader.go index 520159cb..f0982ead 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -46,10 +46,21 @@ type textprotoReader struct { // To avoid denial of service attacks, the provided bufio.Reader // should be reading from an io.LimitReader or similar textprotoReader to bound // the size of responses. -func newTextprotoReader(r *bufio.Reader, dump *dumper) *textprotoReader { +func newTextprotoReader(r *bufio.Reader, dumps []*dumper) *textprotoReader { commonHeaderOnce.Do(initCommonHeader) t := &textprotoReader{R: r} - if dump != nil && dump.ResponseHeader { + if len(dumps) > 0 { + dd := []*dumper{} + for _, dump := range dumps { + if dump.ResponseHeader { + dd = append(dd, dump) + } + } + dumps = dd + + } + + if len(dumps) > 0 { t.readLine = func() (line []byte, isPrefix bool, err error) { line, err = t.R.ReadSlice('\n') if len(line) == 0 { @@ -59,7 +70,9 @@ func newTextprotoReader(r *bufio.Reader, dump *dumper) *textprotoReader { return } err = nil - dump.dump(line) + for _, dump := range dumps { + dump.dump(line) + } if line[len(line)-1] == '\n' { drop := 1 if len(line) > 1 && line[len(line)-2] == '\r' { @@ -72,6 +85,7 @@ func newTextprotoReader(r *bufio.Reader, dump *dumper) *textprotoReader { } else { t.readLine = t.R.ReadLine } + return t } diff --git a/transfer.go b/transfer.go index 79829fe9..bdee2245 100644 --- a/transfer.go +++ b/transfer.go @@ -330,7 +330,7 @@ func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) } // always closes t.BodyCloser -func (t *transferWriter) writeBody(w io.Writer, dump *dumper) (err error) { +func (t *transferWriter) writeBody(w io.Writer, dumps []*dumper) (err error) { var ncopy int64 closed := false defer func() { @@ -343,8 +343,10 @@ func (t *transferWriter) writeBody(w io.Writer, dump *dumper) (err error) { }() rw := w // raw writer - if dump != nil && dump.RequestBody { - w = dump.WrapWriter(w) + for _, dump := range dumps { + if dump.RequestBody { + w = dump.WrapWriter(w) + } } // Write body. We "unwrap" the body first if it was wrapped in a @@ -358,8 +360,10 @@ func (t *transferWriter) writeBody(w io.Writer, dump *dumper) (err error) { rw = &internal.FlushAfterChunkWriter{Writer: bw} } cw := internal.NewChunkedWriter(rw) - if dump != nil && dump.RequestBody { - cw = dump.WrapWriteCloser(cw) + for _, dump := range dumps { + if dump.RequestBody { + cw = dump.WrapWriteCloser(cw) + } } _, err = t.doBodyCopy(cw, body) if err == nil { @@ -383,8 +387,10 @@ func (t *transferWriter) writeBody(w io.Writer, dump *dumper) (err error) { if err != nil { return err } - if dump != nil && dump.RequestBody { - dump.dump([]byte("\r\n")) + for _, dump := range dumps { + if dump.RequestBody { + dump.dump([]byte("\r\n")) + } } } if t.BodyCloser != nil { diff --git a/transport.go b/transport.go index 9a701a17..276d9285 100644 --- a/transport.go +++ b/transport.go @@ -268,9 +268,11 @@ func (t *Transport) handleResponseBody(res *http.Response, req *http.Request) { } func (t *Transport) dumpResponseBody(res *http.Response, req *http.Request) { - dump := getDumperOverride(t.dump, req.Context()) - if dump != nil && dump.ResponseBody { - res.Body = dump.WrapReadCloser(res.Body) + dumps := getDumpers(t.dump, req.Context()) + for _, dump := range dumps { + if dump.ResponseBody { + res.Body = dump.WrapReadCloser(res.Body) + } } } @@ -2328,8 +2330,8 @@ func (pc *persistConn) writeLoop() { case wr := <-pc.writech: startBytesWritten := pc.nwrite ctx := wr.req.Request.Context() - dump := getDumperOverride(pc.t.dump, ctx) - err := requestWrite(wr.req.Request, pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh), dump) + dumps := getDumpers(pc.t.dump, ctx) + err := requestWrite(wr.req.Request, pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh), dumps) if bre, ok := err.(requestBodyReadError); ok { err = bre.error // Errors reading from the user's From 1286a971b6dddfb11256160feea09c54b88080b3 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 14:37:36 +0800 Subject: [PATCH 246/843] update README --- README.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6789d16c..09c6ecaa 100644 --- a/README.md +++ b/README.md @@ -145,16 +145,17 @@ client.R().Get("https://www.baidu.com/") opt.ResponseBody = false client.R().Get("https://www.baidu.com/") -// You can also enable dump at request level, dump to memory and will not print it out -// by default, you can call `Response.Dump()` to get the dump result and print -// only if you want to. +// You can also enable dump at request level, which will not override client-level dumpping, +// dump to memory and will not print it out by default, you can call `Response.Dump()` to get +// the dump result and print only if you want to, typically used in production, only record +// the content of the request when the request is abnormal to help us troubleshoot problems. resp, err := client.R().EnableDump().SetBody("test body").Post("https://httpbin.org/post") if err != nil { fmt.Println("err:", err) fmt.Println("raw content:\n", resp.Dump()) return } -if resp.StatusCode > 299 { +if !resp.IsSuccess() { // Status code not beetween 200 and 299 fmt.Println("bad status:", resp.Status) fmt.Println("raw content:\n", resp.Dump()) return @@ -166,6 +167,8 @@ resp, err = client.R().EnableDumpWithoutRequest().SetBody("test body").Post("htt resp, err = client.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post("https://httpbin.org/post") ``` +> Request-level dumpping will not override client-level dumpping, cuz + **Enable DebugLog for Deeper Insights** ```go From 8967b755a280b5023bcfed87bfdb6fbe9d4fd52b Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 14:44:32 +0800 Subject: [PATCH 247/843] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 09c6ecaa..e767f9fe 100644 --- a/README.md +++ b/README.md @@ -139,11 +139,11 @@ opt := &req.DumpOptions{ Async: false, } client.SetCommonDumpOptions(opt).EnableDumpAll() -client.R().Get("https://www.baidu.com/") +client.R().Get("https://httpbin.org/get") // Change settings dynamiclly opt.ResponseBody = false -client.R().Get("https://www.baidu.com/") +client.R().Get("https://httpbin.org/get") // You can also enable dump at request level, which will not override client-level dumpping, // dump to memory and will not print it out by default, you can call `Response.Dump()` to get From 59c5ef077abc81c8234f064b70e13a8c1a5736a7 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 18:56:02 +0800 Subject: [PATCH 248/843] fix autodecode method names --- client.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/client.go b/client.go index 8fb8a370..949d271f 100644 --- a/client.go +++ b/client.go @@ -802,28 +802,28 @@ func (c *Client) SetAutoDecodeContentType(contentTypes ...string) *Client { return c } -// SetAutoDecodeAllTypeFunc is a global wrapper methods which delegated +// SetAutoDecodeContentTypeFunc is a global wrapper methods which delegated // to the default client's SetAutoDecodeAllTypeFunc. -func SetAutoDecodeAllTypeFunc(fn func(contentType string) bool) *Client { - return defaultClient.SetAutoDecodeAllTypeFunc(fn) +func SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client { + return defaultClient.SetAutoDecodeContentTypeFunc(fn) } -// SetAutoDecodeAllTypeFunc set the custmize function that determins the content-type +// SetAutoDecodeContentTypeFunc set the custmize function that determins the content-type // whether if should be auto-detected and decode to utf-8 -func (c *Client) SetAutoDecodeAllTypeFunc(fn func(contentType string) bool) *Client { +func (c *Client) SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client { opt := c.getResponseOptions() opt.AutoDecodeContentType = fn return c } -// SetAutoDecodeAllType is a global wrapper methods which delegated -// to the default client's SetAutoDecodeAllType. -func SetAutoDecodeAllType() *Client { - return defaultClient.SetAutoDecodeAllType() +// SetAutoDecodeAllContentType is a global wrapper methods which delegated +// to the default client's SetAutoDecodeAllContentType. +func SetAutoDecodeAllContentType() *Client { + return defaultClient.SetAutoDecodeAllContentType() } -// SetAutoDecodeAllType enables to try auto-detect and decode all content type to utf-8. -func (c *Client) SetAutoDecodeAllType() *Client { +// SetAutoDecodeAllContentType enables to try auto-detect and decode all content type to utf-8. +func (c *Client) SetAutoDecodeAllContentType() *Client { opt := c.getResponseOptions() opt.AutoDecodeContentType = func(contentType string) bool { return true From b216f720989f0ba488f2f657fda232fdaf1862c4 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 20:02:33 +0800 Subject: [PATCH 249/843] improve doc --- README.md | 5 ++- client.go | 24 +++++------- docs/api.md | 107 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 15 deletions(-) create mode 100644 docs/api.md diff --git a/README.md b/README.md index e767f9fe..06ac3d50 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ If you want to use the older version, check it out on [v1 branch](https://github * [Redirect Policy](#Redirect) * [Proxy](#Proxy) * [TODO List](#TODO) +* [API Reference](#API) * [License](#License) ## Features @@ -82,6 +83,8 @@ resp, err := client.R(). // Use R() to create a request Checkout more runnable examples in the [examples](examples) direcotry. +Checkout [Req API Reference](docs/api.md) for a brief list of some core APIs, which is convenient to get started quickly. For a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). + ## Debugging - Dump/Log/Trace **Dump the Content** @@ -856,4 +859,4 @@ client.SetProxy(nil) ## License -`Req` released under MIT license, refer [LICENSE](LICENSE) file. +`Req` released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file diff --git a/client.go b/client.go index 949d271f..83d16536 100644 --- a/client.go +++ b/client.go @@ -663,9 +663,9 @@ func EnableDumpAllWithoutRequestBody() *Client { return defaultClient.EnableDumpAllWithoutRequestBody() } -// EnableDumpAllWithoutRequestBody enables dump for all requests, with -// request body excluded, can be used in upload request to avoid dump -// the unreadable binary content. +// EnableDumpAllWithoutRequestBody enables dump for all requests, without +// request body, can be used in upload request to avoid dump the unreadable +// binary content. func (c *Client) EnableDumpAllWithoutRequestBody() *Client { o := c.getDumpOptions() o.RequestBody = false @@ -679,9 +679,9 @@ func EnableDumpAllWithoutResponseBody() *Client { return defaultClient.EnableDumpAllWithoutResponseBody() } -// EnableDumpAllWithoutResponseBody enables dump for all requests, with -// response body excluded, can be used in download request to avoid dump -// the unreadable binary content. +// EnableDumpAllWithoutResponseBody enables dump for all requests, without +// response body, can be used in download request to avoid dump the unreadable +// binary content. func (c *Client) EnableDumpAllWithoutResponseBody() *Client { o := c.getDumpOptions() o.ResponseBody = false @@ -695,8 +695,7 @@ func EnableDumpAllWithoutResponse() *Client { return defaultClient.EnableDumpAllWithoutResponse() } -// EnableDumpAllWithoutResponse enables dump for all requests with only -// request header and body included. +// EnableDumpAllWithoutResponse enables dump for all requests without response. func (c *Client) EnableDumpAllWithoutResponse() *Client { o := c.getDumpOptions() o.ResponseBody = false @@ -711,8 +710,7 @@ func EnableDumpAllWithoutRequest() *Client { return defaultClient.EnableDumpAllWithoutRequest() } -// EnableDumpAllWithoutRequest enables dump for all requests with only -// request header and body included. +// EnableDumpAllWithoutRequest enables dump for all requests without request. func (c *Client) EnableDumpAllWithoutRequest() *Client { o := c.getDumpOptions() o.RequestHeader = false @@ -727,8 +725,7 @@ func EnableDumpAllWithoutHeader() *Client { return defaultClient.EnableDumpAllWithoutHeader() } -// EnableDumpAllWithoutHeader enables dump for all requests with only -// body of request and response included. +// EnableDumpAllWithoutHeader enables dump for all requests without header. func (c *Client) EnableDumpAllWithoutHeader() *Client { o := c.getDumpOptions() o.RequestHeader = false @@ -743,8 +740,7 @@ func EnableDumpAllWithoutBody() *Client { return defaultClient.EnableDumpAllWithoutBody() } -// EnableDumpAllWithoutBody enables dump for all requests with only header -// of request and response included. +// EnableDumpAllWithoutBody enables dump for all requests without body. func (c *Client) EnableDumpAllWithoutBody() *Client { o := c.getDumpOptions() o.RequestBody = false diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 00000000..9808ce62 --- /dev/null +++ b/docs/api.md @@ -0,0 +1,107 @@ +# Req API Reference + +Here is a brief introduction to some core APIs, which is convenient to get started quickly. For a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). + +## Client Settings + +**Debug Features** + +* [DevMode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DevMode) - Enable all debug features (Dump, DebugLog and Trace). + +* [EnableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDebugLog) - Enable debug level log (disabled by default). +* [DisableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDebugLog) - Disable debug level log (disabled by default). +* [SetLogger(log Logger)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetLogger) - Set the customized logger, set to nil to disable logger. + +* [SetCommonDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonDumpOptions) - Configures the underlying Transport's `DumpOptions` (need to call `EnableDumpAll()` if you want to enable dump). +* [EnableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAll) - Enable dump for all requests, including all content for the request and response by default. +* [DisableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDumpAll) - Disable dump for all requests. +* [EnableDumpAllWithoutResponseBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponseBody) - Enable dump for all requests without response body, can be used in download request to avoid dump the unreadable binary content. +* [EnableDumpAllWithoutResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponse) - Enable dump for all requests without response (only request header and body). +* [EnableDumpAllWithoutRequestBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequestBody) - Enable dump for all requests without request body, can be used in upload request to avoid dump the unreadable binary content. +* [EnableDumpAllWithoutRequest()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequest) - Enable dump for all requests without request (only response header and body). +* [EnableDumpAllWithoutHeader()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutHeader) - Enable dump for all requests without header (only body of request and response). +* [EnableDumpAllWithoutBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutBody) - Enable dump for all requests without body (only header of request and response). +* [EnableDumpAllToFile(filename string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllToFile) - Enable dump for all requests and save to the specified filename. +* [EnableDumpAllTo(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllTo) - Enables dump for all requests and save to the specified `io.Writer`. +* [EnableDumpAllAsync()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllAsync) - Enable dump for all requests and output asynchronously, can be used for debugging in production environment without affecting performance. + +* [EnableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableTraceAll) - Enable trace for all requests (disabled by default). +* [DisableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableTraceAll) - Disable trace at client level (disabled by default). + +**Common Settings for HTTP Requests** + +* [SetCommonQueryString(query string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryString) - Set a URL query parameters for all requests using the raw query string. +* [SetCommonHeaders(hdrs map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeaders) - Set headers for all requests from a map. +* [SetCommonHeader(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeader) - Set a header with key-value pair for all requests. +* [SetCommonContentType(ct string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonContentType) - Set the `Content-Type` header for all requests. +* [SetCommonBearerAuthToken(token string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonBearerAuthToken) - Set the bearer auth token for all requests. +* [SetCommonBasicAuth(username, password string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonBasicAuth) - Set the basic auth for all requests. +* [SetCommonCookies(cookies ...*http.Cookie)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonCookies) - Set cookies for all requests. +* [AddCommonQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.AddCommonQueryParam) - Add a URL query parameter with key-value pair for all requests which will not override if same key exists. +* [SetCommonQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryParam) - Set a URL query parameter with key-value pair for all requests. +* [SetCommonQueryParams(params map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryParams) - Set the URL query parameters with a map for all requests. +* [SetCommonFormDataFromValues(data url.Values)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormDataFromValues) - Set the form data from `url.Values` for all requests. +* [SetCommonFormData(data map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormData) - Set the form data from map for all requests. +* [SetUserAgent(userAgent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetUserAgent) - Set the "User-Agent" header for all requests. + +**Auto-Decode** + +* [EnableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAutoDecode) - Enable auto-detect charset and decode to utf-8 (enabled by default). +* [DisableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoDecode) - Disable auto-detect charset and decode to utf-8 (enabled by default) +* [SetAutoDecodeContentType(contentTypes ...string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentType) - Set the content types that will be auto-detected and decode to utf-8. +* [SetAutoDecodeAllContentType()](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeAllContentType) - Set try to auto-detect and decode all content type to utf-8. +* [SetAutoDecodeContentTypeFunc(fn func(contentType string) bool)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentTypeFunc) - Custmize the function that determines the content type whether it should be auto-detected and decode to utf-8. + +**Certificates** + +* [SetCerts(certs ...tls.Certificate) ](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCerts) - Set client certificates from on one more `tls.Certificate`. +* [SetCertFromFile(certFile, keyFile string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCertFromFile) - Set client certificates from cert and key file. +* [SetRootCertsFromFile(pemFiles ...string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertsFromFile) - Set root certificates from pem files. +* [SetRootCertFromString(pemContent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertFromString) - Set root certificates from string. + +**Marshal&Unmarshal** + +* [SetJsonUnmarshal(fn func(data []byte, v interface{}) error)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonUnmarshal) - Set the JSON Unmarshal function which will be used to unmarshal response body. +* [SetJsonMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonMarshal) - Set JSON Marshal function which will be used to marshal request body. +* [SetXmlMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#SetXmlUnmarshal) - Set the XML Unmarshal function which will be used to unmarshal response body. +* [SetXmlMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetXmlMarshal) - Set the XML Marshal function which will be used to marshal request body. + +**Other Settings** + +* [OnBeforeRequest(m RequestMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnBeforeRequest) - Add a request middleware which hooks before request sent. +* [OnAfterResponse(m ResponseMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnAfterResponse) - Add a response middleware which hooks after response received. + +* [SetTLSClientConfig(conf *tls.Config)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSClientConfig) - Set the client tls config. +* [SetTLSHandshakeTimeout(timeout time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSHandshakeTimeout) - Set the TLS handshake timeout. + +* [EnableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableForceHTTP1) - Enable force using HTTP1 (disabled by default). +* [DisableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableForceHTTP1) - Disable force using HTTP1. + +* [EnableKeepAlives()](EnableKeepAlives()) - Enable HTTP keep-alives. +* [DisableKeepAlives()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableKeepAlives) - Disable HTTP keep-alives (enabled by default) + +* [SetTimeout(d time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTimeout) - Set the request timeout. + +* [SetScheme(scheme string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetScheme) - Set the default scheme in the client, will be used when there is no scheme in the request url (e.g. "github.com/imroc/req"). +* [SetBaseURL(u string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetBaseURL) - Set the default base url, will be used if request url is a relative url. + +* [SetProxyURL(proxyUrl string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetProxyURL) - Set proxy from the proxy URL. +* [SetProxy(proxy func(*http.Request) (*urlpkg.URL, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetProxy) - Set proxy from proxy function (e.g. [http.ProxyFromEnvironment](https://pkg.go.dev/net/http@go1.17.6#ProxyFromEnvironment)). + +* [SetOutputDirectory(dir string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetOutputDirectory) - Set output directory that response body will be downloaded to. + +* [SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetDialTLS) - Set the customized `DialTLSContext` function to Transport (make sure the returned `conn` implements [TLSConn](https://pkg.go.dev/github.com/imroc/req/v3#TLSConn) if you want your customized `conn` supports HTTP2). +* [SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetDial) - Set the customized `DialContext` function to Transport. + +* [SetCookieJar(jar http.CookieJar)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCookieJar) - Set the `CookeJar` to the underlying `http.Client`. + +* [SetRedirectPolicy(policies ...RedirectPolicy)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRedirectPolicy) - Set the RedirectPolicy (see the predefined [AllowedDomainRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#AllowedDomainRedirectPolicy), [AllowedHostRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#AllowedHostRedirectPolicy), [MaxRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#MaxRedirectPolicy), [NoRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#NoRedirectPolicy), [SameDomainRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#SameDomainRedirectPolicy), [SameHostRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#SameDomainRedirectPolicy)). + +* [EnableCompression()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableCompression) - Enable the compression (enabled by default). +* [DisableCompression()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableCompression) - Disable the compression. + +* [EnableAutoReadResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAutoReadResponse) - Enable read response body automatically (enabled by default). +* [DisableAutoReadResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoReadResponse) - Disable read response body automatically. + +* [EnableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAllowGetMethodPayload) - Enable allow sending GET method requests with body (disabled by default). +* [DisableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAllowGetMethodPayload) - Disable allow sending GET method requests with body. \ No newline at end of file From ac5c35f57b6174f1c180b8677bcfd309c7c60849 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 20:17:44 +0800 Subject: [PATCH 250/843] update api.md --- docs/api.md | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/docs/api.md b/docs/api.md index 9808ce62..477943c3 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1,10 +1,21 @@ -# Req API Reference +

+

Req API Reference

+

Here is a brief list of some core APIs, which is convenient to get started quickly. For a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3).

+

-Here is a brief introduction to some core APIs, which is convenient to get started quickly. For a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). +## Table of Contents + +* [Client Settings](#Client) + * [Debug Features](#Debug) + * [Common Settings for HTTP Requests](#Common) + * [Auto-Decode](#Decode) + * [Certificates](#Certs) + * [Marshal&Unmarshal](#Marshal) + * [Other Settings](#Other) ## Client Settings -**Debug Features** +### Debug Features * [DevMode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DevMode) - Enable all debug features (Dump, DebugLog and Trace). @@ -12,7 +23,6 @@ Here is a brief introduction to some core APIs, which is convenient to get start * [DisableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDebugLog) - Disable debug level log (disabled by default). * [SetLogger(log Logger)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetLogger) - Set the customized logger, set to nil to disable logger. -* [SetCommonDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonDumpOptions) - Configures the underlying Transport's `DumpOptions` (need to call `EnableDumpAll()` if you want to enable dump). * [EnableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAll) - Enable dump for all requests, including all content for the request and response by default. * [DisableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDumpAll) - Disable dump for all requests. * [EnableDumpAllWithoutResponseBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponseBody) - Enable dump for all requests without response body, can be used in download request to avoid dump the unreadable binary content. @@ -24,11 +34,12 @@ Here is a brief introduction to some core APIs, which is convenient to get start * [EnableDumpAllToFile(filename string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllToFile) - Enable dump for all requests and save to the specified filename. * [EnableDumpAllTo(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllTo) - Enables dump for all requests and save to the specified `io.Writer`. * [EnableDumpAllAsync()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllAsync) - Enable dump for all requests and output asynchronously, can be used for debugging in production environment without affecting performance. +* [SetCommonDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonDumpOptions) - Configures the underlying Transport's `DumpOptions` (need to call `EnableDumpAll()` if you want to enable dump). * [EnableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableTraceAll) - Enable trace for all requests (disabled by default). * [DisableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableTraceAll) - Disable trace at client level (disabled by default). -**Common Settings for HTTP Requests** +### Common Settings for HTTP Requests * [SetCommonQueryString(query string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryString) - Set a URL query parameters for all requests using the raw query string. * [SetCommonHeaders(hdrs map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeaders) - Set headers for all requests from a map. @@ -44,7 +55,7 @@ Here is a brief introduction to some core APIs, which is convenient to get start * [SetCommonFormData(data map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormData) - Set the form data from map for all requests. * [SetUserAgent(userAgent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetUserAgent) - Set the "User-Agent" header for all requests. -**Auto-Decode** +### Auto-Decode * [EnableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAutoDecode) - Enable auto-detect charset and decode to utf-8 (enabled by default). * [DisableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoDecode) - Disable auto-detect charset and decode to utf-8 (enabled by default) @@ -52,21 +63,21 @@ Here is a brief introduction to some core APIs, which is convenient to get start * [SetAutoDecodeAllContentType()](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeAllContentType) - Set try to auto-detect and decode all content type to utf-8. * [SetAutoDecodeContentTypeFunc(fn func(contentType string) bool)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentTypeFunc) - Custmize the function that determines the content type whether it should be auto-detected and decode to utf-8. -**Certificates** +### Certificates * [SetCerts(certs ...tls.Certificate) ](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCerts) - Set client certificates from on one more `tls.Certificate`. * [SetCertFromFile(certFile, keyFile string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCertFromFile) - Set client certificates from cert and key file. * [SetRootCertsFromFile(pemFiles ...string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertsFromFile) - Set root certificates from pem files. * [SetRootCertFromString(pemContent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertFromString) - Set root certificates from string. -**Marshal&Unmarshal** +### Marshal&Unmarshal * [SetJsonUnmarshal(fn func(data []byte, v interface{}) error)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonUnmarshal) - Set the JSON Unmarshal function which will be used to unmarshal response body. * [SetJsonMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonMarshal) - Set JSON Marshal function which will be used to marshal request body. * [SetXmlMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#SetXmlUnmarshal) - Set the XML Unmarshal function which will be used to unmarshal response body. * [SetXmlMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetXmlMarshal) - Set the XML Marshal function which will be used to marshal request body. -**Other Settings** +### Other Settings * [OnBeforeRequest(m RequestMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnBeforeRequest) - Add a request middleware which hooks before request sent. * [OnAfterResponse(m ResponseMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnAfterResponse) - Add a response middleware which hooks after response received. From dc0071f8dd6c91162cdca22d1c6770681deb48c2 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 20:26:26 +0800 Subject: [PATCH 251/843] update api.md --- docs/api.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/api.md b/docs/api.md index 477943c3..946c3025 100644 --- a/docs/api.md +++ b/docs/api.md @@ -15,6 +15,8 @@ ## Client Settings +The following are the chainable settings Client, all of which have corresponding global wrappers. + ### Debug Features * [DevMode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DevMode) - Enable all debug features (Dump, DebugLog and Trace). From 0fe8a4a97c141cd30026bca8874f243660e37b7a Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 20:27:15 +0800 Subject: [PATCH 252/843] update api.md --- docs/api.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 946c3025..d57a4849 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1,8 +1,9 @@

Req API Reference

-

Here is a brief list of some core APIs, which is convenient to get started quickly. For a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3).

+Here is a brief list of some core APIs, which is convenient to get started quickly. For a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). + ## Table of Contents * [Client Settings](#Client) From 26da51a64636a3c531680fe660bb106a448ff369 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 20:45:25 +0800 Subject: [PATCH 253/843] optimize doc --- client.go | 10 ++-- docs/api.md | 136 ++++++++++++++++++++++++++++------------------------ 2 files changed, 80 insertions(+), 66 deletions(-) diff --git a/client.go b/client.go index 83d16536..9f7f91e1 100644 --- a/client.go +++ b/client.go @@ -303,7 +303,9 @@ func SetRedirectPolicy(policies ...RedirectPolicy) *Client { return defaultClient.SetRedirectPolicy(policies...) } -// SetRedirectPolicy helps to set the RedirectPolicy. +// SetRedirectPolicy set the RedirectPolicy, see the predefined AllowedDomainRedirectPolicy, +// AllowedHostRedirectPolicy, MaxRedirectPolicy, NoRedirectPolicy, SameDomainRedirectPolicy +// and SameHostRedirectPolicy. func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { if len(policies) == 0 { return c @@ -1101,7 +1103,9 @@ func SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, er return defaultClient.SetDialTLS(fn) } -// SetDialTLS set the customized DialTLSContext +// SetDialTLS set the customized `DialTLSContext` function to Transport (make sure the returned +// `conn` implements [TLSConn](https://pkg.go.dev/github.com/imroc/req/v3#TLSConn) if you want +// your customized `conn` supports HTTP2). func (c *Client) SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { c.t.DialTLSContext = fn return c @@ -1113,7 +1117,7 @@ func SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error return defaultClient.SetDial(fn) } -// SetDial set the customized Dial function. +// SetDial set the customized DialContext function to Transport. func (c *Client) SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { c.t.DialContext = fn return c diff --git a/docs/api.md b/docs/api.md index d57a4849..4257cbaf 100644 --- a/docs/api.md +++ b/docs/api.md @@ -13,6 +13,7 @@ Here is a brief list of some core APIs, which is convenient to get started quick * [Certificates](#Certs) * [Marshal&Unmarshal](#Marshal) * [Other Settings](#Other) +* [Request Settings](#Request) ## Client Settings @@ -23,99 +24,108 @@ The following are the chainable settings Client, all of which have corresponding * [DevMode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DevMode) - Enable all debug features (Dump, DebugLog and Trace). * [EnableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDebugLog) - Enable debug level log (disabled by default). -* [DisableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDebugLog) - Disable debug level log (disabled by default). +* [DisableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDebugLog) * [SetLogger(log Logger)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetLogger) - Set the customized logger, set to nil to disable logger. -* [EnableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAll) - Enable dump for all requests, including all content for the request and response by default. -* [DisableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDumpAll) - Disable dump for all requests. -* [EnableDumpAllWithoutResponseBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponseBody) - Enable dump for all requests without response body, can be used in download request to avoid dump the unreadable binary content. -* [EnableDumpAllWithoutResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponse) - Enable dump for all requests without response (only request header and body). -* [EnableDumpAllWithoutRequestBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequestBody) - Enable dump for all requests without request body, can be used in upload request to avoid dump the unreadable binary content. -* [EnableDumpAllWithoutRequest()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequest) - Enable dump for all requests without request (only response header and body). -* [EnableDumpAllWithoutHeader()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutHeader) - Enable dump for all requests without header (only body of request and response). -* [EnableDumpAllWithoutBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutBody) - Enable dump for all requests without body (only header of request and response). -* [EnableDumpAllToFile(filename string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllToFile) - Enable dump for all requests and save to the specified filename. -* [EnableDumpAllTo(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllTo) - Enables dump for all requests and save to the specified `io.Writer`. -* [EnableDumpAllAsync()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllAsync) - Enable dump for all requests and output asynchronously, can be used for debugging in production environment without affecting performance. -* [SetCommonDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonDumpOptions) - Configures the underlying Transport's `DumpOptions` (need to call `EnableDumpAll()` if you want to enable dump). +* [EnableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAll) - Enable dump for all requests. +* [DisableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDumpAll) +* [EnableDumpAllWithoutResponseBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponseBody) +* [EnableDumpAllWithoutResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponse) +* [EnableDumpAllWithoutRequestBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequestBody) +* [EnableDumpAllWithoutRequest()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequest) +* [EnableDumpAllWithoutHeader()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutHeader) +* [EnableDumpAllWithoutBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutBody) +* [EnableDumpAllToFile(filename string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllToFile) +* [EnableDumpAllTo(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllTo) +* [EnableDumpAllAsync()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllAsync) +* [SetCommonDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonDumpOptions) - need to call `EnableDumpAll()` if you want to enable dump. * [EnableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableTraceAll) - Enable trace for all requests (disabled by default). -* [DisableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableTraceAll) - Disable trace at client level (disabled by default). +* [DisableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableTraceAll) ### Common Settings for HTTP Requests -* [SetCommonQueryString(query string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryString) - Set a URL query parameters for all requests using the raw query string. -* [SetCommonHeaders(hdrs map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeaders) - Set headers for all requests from a map. -* [SetCommonHeader(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeader) - Set a header with key-value pair for all requests. -* [SetCommonContentType(ct string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonContentType) - Set the `Content-Type` header for all requests. -* [SetCommonBearerAuthToken(token string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonBearerAuthToken) - Set the bearer auth token for all requests. -* [SetCommonBasicAuth(username, password string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonBasicAuth) - Set the basic auth for all requests. -* [SetCommonCookies(cookies ...*http.Cookie)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonCookies) - Set cookies for all requests. -* [AddCommonQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.AddCommonQueryParam) - Add a URL query parameter with key-value pair for all requests which will not override if same key exists. -* [SetCommonQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryParam) - Set a URL query parameter with key-value pair for all requests. -* [SetCommonQueryParams(params map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryParams) - Set the URL query parameters with a map for all requests. -* [SetCommonFormDataFromValues(data url.Values)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormDataFromValues) - Set the form data from `url.Values` for all requests. -* [SetCommonFormData(data map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormData) - Set the form data from map for all requests. -* [SetUserAgent(userAgent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetUserAgent) - Set the "User-Agent" header for all requests. +* [SetCommonQueryString(query string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryString) +* [SetCommonHeaders(hdrs map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeaders) +* [SetCommonHeader(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeader) +* [SetCommonContentType(ct string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonContentType) +* [SetCommonBearerAuthToken(token string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonBearerAuthToken) +* [SetCommonBasicAuth(username, password string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonBasicAuth) +* [SetCommonCookies(cookies ...*http.Cookie)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonCookies) +* [AddCommonQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.AddCommonQueryParam) +* [SetCommonQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryParam) +* [SetCommonQueryParams(params map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryParams) +* [SetCommonFormDataFromValues(data url.Values)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormDataFromValues) +* [SetCommonFormData(data map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormData) +* [SetUserAgent(userAgent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetUserAgent) ### Auto-Decode -* [EnableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAutoDecode) - Enable auto-detect charset and decode to utf-8 (enabled by default). +* [EnableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAutoDecode) * [DisableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoDecode) - Disable auto-detect charset and decode to utf-8 (enabled by default) -* [SetAutoDecodeContentType(contentTypes ...string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentType) - Set the content types that will be auto-detected and decode to utf-8. -* [SetAutoDecodeAllContentType()](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeAllContentType) - Set try to auto-detect and decode all content type to utf-8. -* [SetAutoDecodeContentTypeFunc(fn func(contentType string) bool)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentTypeFunc) - Custmize the function that determines the content type whether it should be auto-detected and decode to utf-8. +* [SetAutoDecodeContentType(contentTypes ...string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentType) +* [SetAutoDecodeAllContentType()](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeAllContentType) +* [SetAutoDecodeContentTypeFunc(fn func(contentType string) bool)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentTypeFunc) ### Certificates -* [SetCerts(certs ...tls.Certificate) ](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCerts) - Set client certificates from on one more `tls.Certificate`. -* [SetCertFromFile(certFile, keyFile string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCertFromFile) - Set client certificates from cert and key file. -* [SetRootCertsFromFile(pemFiles ...string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertsFromFile) - Set root certificates from pem files. -* [SetRootCertFromString(pemContent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertFromString) - Set root certificates from string. +* [SetCerts(certs ...tls.Certificate) ](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCerts) +* [SetCertFromFile(certFile, keyFile string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCertFromFile) +* [SetRootCertsFromFile(pemFiles ...string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertsFromFile) +* [SetRootCertFromString(pemContent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertFromString) ### Marshal&Unmarshal -* [SetJsonUnmarshal(fn func(data []byte, v interface{}) error)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonUnmarshal) - Set the JSON Unmarshal function which will be used to unmarshal response body. -* [SetJsonMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonMarshal) - Set JSON Marshal function which will be used to marshal request body. -* [SetXmlMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#SetXmlUnmarshal) - Set the XML Unmarshal function which will be used to unmarshal response body. -* [SetXmlMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetXmlMarshal) - Set the XML Marshal function which will be used to marshal request body. +* [SetJsonUnmarshal(fn func(data []byte, v interface{}) error)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonUnmarshal) +* [SetJsonMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonMarshal) +* [SetXmlMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#SetXmlUnmarshal) +* [SetXmlMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetXmlMarshal) + +### Middleware + +* [OnBeforeRequest(m RequestMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnBeforeRequest) +* [OnAfterResponse(m ResponseMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnAfterResponse) ### Other Settings -* [OnBeforeRequest(m RequestMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnBeforeRequest) - Add a request middleware which hooks before request sent. -* [OnAfterResponse(m ResponseMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnAfterResponse) - Add a response middleware which hooks after response received. +* [SetTimeout(d time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTimeout) +* [SetTLSHandshakeTimeout(timeout time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSHandshakeTimeout) +* [SetTLSClientConfig(conf *tls.Config)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSClientConfig) + +* [EnableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableForceHTTP1) - Disabled by default. +* [DisableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableForceHTTP1) -* [SetTLSClientConfig(conf *tls.Config)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSClientConfig) - Set the client tls config. -* [SetTLSHandshakeTimeout(timeout time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSHandshakeTimeout) - Set the TLS handshake timeout. +* [EnableKeepAlives()](EnableKeepAlives()) +* [DisableKeepAlives()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableKeepAlives) Enabled by default. -* [EnableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableForceHTTP1) - Enable force using HTTP1 (disabled by default). -* [DisableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableForceHTTP1) - Disable force using HTTP1. +* [SetScheme(scheme string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetScheme) +* [SetBaseURL(u string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetBaseURL) -* [EnableKeepAlives()](EnableKeepAlives()) - Enable HTTP keep-alives. -* [DisableKeepAlives()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableKeepAlives) - Disable HTTP keep-alives (enabled by default) +* [SetProxyURL(proxyUrl string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetProxyURL) +* [SetProxy(proxy func(*http.Request) (*urlpkg.URL, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetProxy) -* [SetTimeout(d time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTimeout) - Set the request timeout. +* [SetOutputDirectory(dir string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetOutputDirectory) -* [SetScheme(scheme string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetScheme) - Set the default scheme in the client, will be used when there is no scheme in the request url (e.g. "github.com/imroc/req"). -* [SetBaseURL(u string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetBaseURL) - Set the default base url, will be used if request url is a relative url. +* [SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetDialTLS) +* [SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetDial) -* [SetProxyURL(proxyUrl string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetProxyURL) - Set proxy from the proxy URL. -* [SetProxy(proxy func(*http.Request) (*urlpkg.URL, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetProxy) - Set proxy from proxy function (e.g. [http.ProxyFromEnvironment](https://pkg.go.dev/net/http@go1.17.6#ProxyFromEnvironment)). +* [SetCookieJar(jar http.CookieJar)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCookieJar) -* [SetOutputDirectory(dir string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetOutputDirectory) - Set output directory that response body will be downloaded to. +* [SetRedirectPolicy(policies ...RedirectPolicy)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRedirectPolicy) -* [SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetDialTLS) - Set the customized `DialTLSContext` function to Transport (make sure the returned `conn` implements [TLSConn](https://pkg.go.dev/github.com/imroc/req/v3#TLSConn) if you want your customized `conn` supports HTTP2). -* [SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetDial) - Set the customized `DialContext` function to Transport. +* [EnableCompression()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableCompression) +* [DisableCompression()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableCompression) - Enabled by default -* [SetCookieJar(jar http.CookieJar)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCookieJar) - Set the `CookeJar` to the underlying `http.Client`. +* [EnableAutoReadResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAutoReadResponse) +* [DisableAutoReadResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoReadResponse) - Enabled by default -* [SetRedirectPolicy(policies ...RedirectPolicy)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRedirectPolicy) - Set the RedirectPolicy (see the predefined [AllowedDomainRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#AllowedDomainRedirectPolicy), [AllowedHostRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#AllowedHostRedirectPolicy), [MaxRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#MaxRedirectPolicy), [NoRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#NoRedirectPolicy), [SameDomainRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#SameDomainRedirectPolicy), [SameHostRedirectPolicy](https://pkg.go.dev/github.com/imroc/req/v3#SameDomainRedirectPolicy)). +* [EnableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAllowGetMethodPayload) - Disabled by default. +* [DisableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAllowGetMethodPayload) -* [EnableCompression()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableCompression) - Enable the compression (enabled by default). -* [DisableCompression()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableCompression) - Disable the compression. +## Request Settings -* [EnableAutoReadResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAutoReadResponse) - Enable read response body automatically (enabled by default). -* [DisableAutoReadResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoReadResponse) - Disable read response body automatically. +The following are the chainable settings Request, all of which have corresponding global wrappers. -* [EnableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAllowGetMethodPayload) - Enable allow sending GET method requests with body (disabled by default). -* [DisableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAllowGetMethodPayload) - Disable allow sending GET method requests with body. \ No newline at end of file +[AddQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.AddQueryParam) +[DisableTrace()](https://pkg.go.dev/github.com/imroc/req/v3#Request.DisableTrace) +[EnableDump()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDump) \ No newline at end of file From b76ac3e9c9ede3cd6a5c08178e8770acbca06b5a Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 21:09:33 +0800 Subject: [PATCH 254/843] optimize doc --- docs/api.md | 79 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 73 insertions(+), 6 deletions(-) diff --git a/docs/api.md b/docs/api.md index 4257cbaf..d4cc29bd 100644 --- a/docs/api.md +++ b/docs/api.md @@ -14,6 +14,13 @@ Here is a brief list of some core APIs, which is convenient to get started quick * [Marshal&Unmarshal](#Marshal) * [Other Settings](#Other) * [Request Settings](#Request) + * [URL Query and Path Parameter](#Query) + * [Header and Cookie](#Header) + * [Body and Marshal&Unmarshal](#Body) + * [Request Level Debug](#Debug-Request) + * [Multipart & Form & Upload](#Multipart) + * [Download](#Download) + * [Other Settings](#Other-Request) ## Client Settings @@ -25,7 +32,7 @@ The following are the chainable settings Client, all of which have corresponding * [EnableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDebugLog) - Enable debug level log (disabled by default). * [DisableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDebugLog) -* [SetLogger(log Logger)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetLogger) - Set the customized logger, set to nil to disable logger. +* [SetLogger(log Logger)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetLogger) - Set the customized logger. * [EnableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAll) - Enable dump for all requests. * [DisableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDumpAll) @@ -38,7 +45,7 @@ The following are the chainable settings Client, all of which have corresponding * [EnableDumpAllToFile(filename string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllToFile) * [EnableDumpAllTo(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllTo) * [EnableDumpAllAsync()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllAsync) -* [SetCommonDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonDumpOptions) - need to call `EnableDumpAll()` if you want to enable dump. +* [SetCommonDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonDumpOptions) * [EnableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableTraceAll) - Enable trace for all requests (disabled by default). * [DisableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableTraceAll) @@ -62,7 +69,7 @@ The following are the chainable settings Client, all of which have corresponding ### Auto-Decode * [EnableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAutoDecode) -* [DisableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoDecode) - Disable auto-detect charset and decode to utf-8 (enabled by default) +* [DisableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoDecode) - Disable auto-detect charset and decode to utf-8 (enabled by default). * [SetAutoDecodeContentType(contentTypes ...string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentType) * [SetAutoDecodeAllContentType()](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeAllContentType) * [SetAutoDecodeContentTypeFunc(fn func(contentType string) bool)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentTypeFunc) @@ -126,6 +133,66 @@ The following are the chainable settings Client, all of which have corresponding The following are the chainable settings Request, all of which have corresponding global wrappers. -[AddQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.AddQueryParam) -[DisableTrace()](https://pkg.go.dev/github.com/imroc/req/v3#Request.DisableTrace) -[EnableDump()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDump) \ No newline at end of file +### URL Query and Path Parameter + +* [AddQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.AddQueryParam) +* [SetQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetQueryParam) +* [SetQueryParams(params map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetQueryParams) +* [SetQueryString(query string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetQueryString) +* [SetPathParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetPathParam) +* [SetPathParams(params map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetPathParams) + +### Header and Cookie + +* [SetHeader(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetHeader) +* [SetHeaders(hdrs map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetHeaders) +* [SetBasicAuth(username, password string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBasicAuth) +* [SetBearerAuthToken(token string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBearerAuthToken) +* [SetContentType(contentType string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetContentType) +* [SetCookies(cookies ...*http.Cookie)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetCookies) + +### Body and Marshal&Unmarshal + +* [SetBody(body interface{})](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBody) +* [SetBodyBytes(body []byte)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyBytes) +* [SetBodyJsonBytes(body []byte)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyJsonBytes) +* [SetBodyJsonMarshal(v interface{})](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyJsonMarshal) +* [SetBodyJsonString(body string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyJsonString) +* [SetBodyString(body string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyString) +* [SetBodyXmlBytes(body []byte)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyXmlBytes) +* [SetBodyXmlMarshal(v interface{})](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyXmlMarshal) +* [SetBodyXmlString(body string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyXmlString) +* [SetResult(result interface{})](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetResult) +* [SetError(error interface{})](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetError) + +### Request Level Debug + +* [EnableTrace()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableTrace) - Disabled by default. +* [DisableTrace()](https://pkg.go.dev/github.com/imroc/req/v3#Request.DisableTrace) +* [EnableDump()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDump) +* [EnableDumpTo(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpTo) +* [EnableDumpToFile(filename string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpToFile) +* [EnableDumpWithoutBody()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutBody) +* [EnableDumpWithoutHeader()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutHeader) +* [EnableDumpWithoutRequest()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutRequest) +* [EnableDumpWithoutRequestBody()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutRequestBody) +* [EnableDumpWithoutResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutResponse) +* [EnableDumpWithoutResponseBody()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutResponseBody) +* [SetDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDumpOptions) + +### Multipart & Form & Upload + +* [SetFormData(data map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFormData) +* [SetFormDataFromValues(data url.Values)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFormDataFromValues) +* [SetFile(paramName, filePath string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFile) +* [SetFileReader(paramName, filePath string, reader io.Reader)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFileReader) +* [SetFiles(files map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFiles) + +### Download + +* [SetOutput(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutput) +* [SetOutputFile(file string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutputFile) + +### Other Settings + +* [SetContext(ctx context.Context)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetContext) From d719299449dbe9463ad8736696b68ef22f3b2458 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 21:10:57 +0800 Subject: [PATCH 255/843] update api.md --- docs/api.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/api.md b/docs/api.md index d4cc29bd..ac0f2901 100644 --- a/docs/api.md +++ b/docs/api.md @@ -24,7 +24,7 @@ Here is a brief list of some core APIs, which is convenient to get started quick ## Client Settings -The following are the chainable settings Client, all of which have corresponding global wrappers. +The following are the chainable settings of Client, all of which have corresponding global wrappers. ### Debug Features @@ -131,7 +131,7 @@ The following are the chainable settings Client, all of which have corresponding ## Request Settings -The following are the chainable settings Request, all of which have corresponding global wrappers. +The following are the chainable settings of Request, all of which have corresponding global wrappers. ### URL Query and Path Parameter From cdcc7999ac828aaf8d1f45eb3b991218b2663872 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 21:16:29 +0800 Subject: [PATCH 256/843] update api.md --- docs/api.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/api.md b/docs/api.md index ac0f2901..33f92ad3 100644 --- a/docs/api.md +++ b/docs/api.md @@ -26,6 +26,8 @@ Here is a brief list of some core APIs, which is convenient to get started quick The following are the chainable settings of Client, all of which have corresponding global wrappers. +Basically, you can know the meaning of most settings directly from the method name. + ### Debug Features * [DevMode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DevMode) - Enable all debug features (Dump, DebugLog and Trace). @@ -133,6 +135,8 @@ The following are the chainable settings of Client, all of which have correspond The following are the chainable settings of Request, all of which have corresponding global wrappers. +Basically, you can know the meaning of most settings directly from the method name. + ### URL Query and Path Parameter * [AddQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.AddQueryParam) From e580d6cbc09cbe9a948229f8fda409a0a85ccf13 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 21:18:00 +0800 Subject: [PATCH 257/843] update api.md --- docs/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 33f92ad3..18253a6a 100644 --- a/docs/api.md +++ b/docs/api.md @@ -105,7 +105,7 @@ Basically, you can know the meaning of most settings directly from the method na * [DisableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableForceHTTP1) * [EnableKeepAlives()](EnableKeepAlives()) -* [DisableKeepAlives()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableKeepAlives) Enabled by default. +* [DisableKeepAlives()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableKeepAlives) - Enabled by default. * [SetScheme(scheme string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetScheme) * [SetBaseURL(u string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetBaseURL) From 00a4255cd46c55aa5bd5ebca89a86883eebc4fe3 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Feb 2022 21:24:18 +0800 Subject: [PATCH 258/843] optimize doc --- README.md | 5 ++++- docs/api.md | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 06ac3d50..bb56fac4 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ If you want to use the older version, check it out on [v1 branch](https://github * [Features](#Features) * [Quick Start](#Quick-Start) +* [API Reference](#API) * [Debugging - Dump/Log/Trace](#Debugging) * [Quick HTTP Test](#Test) * [HTTP2 and HTTP1](#HTTP2-HTTP1) @@ -83,7 +84,9 @@ resp, err := client.R(). // Use R() to create a request Checkout more runnable examples in the [examples](examples) direcotry. -Checkout [Req API Reference](docs/api.md) for a brief list of some core APIs, which is convenient to get started quickly. For a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). +## API Reference + +Checkout [Req API Reference](docs/api.md) for a brief and categorized list of the core API, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). ## Debugging - Dump/Log/Trace diff --git a/docs/api.md b/docs/api.md index 18253a6a..d7cca3ec 100644 --- a/docs/api.md +++ b/docs/api.md @@ -2,7 +2,7 @@

Req API Reference

-Here is a brief list of some core APIs, which is convenient to get started quickly. For a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). +Here is a brief and categorized list of the core API, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). ## Table of Contents From c3779f89b4bc3632990041250e95c81f4a56a9ec Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Feb 2022 08:16:47 +0800 Subject: [PATCH 259/843] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index bb56fac4..cb125a43 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ If you want to use the older version, check it out on [v1 branch](https://github ## Features * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. -* Powerful and convenient debug utilites, including debug logs, performance traces, dump complete request and response content, and even provide global wrapper methods to test with minimal code (see [Debugging - Log/Trace/Dump](#Debugging). +* Powerful and convenient debug utilites, including debug logs, performance traces, and even dump complete request and response content (see [Debugging - Log/Trace/Dump](#Debugging). * Easy making HTTP test with code instead of tools like curl or postman, `req` provide global wrapper methods and `MustXXX` to test API with minimal code (see [Quick HTTP Test](#Test)). * Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support, and you can also force `HTTP/1.1` if you want (see [HTTP2 and HTTP1](#HTTP2-HTTP1)). * Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decode](#AutoDecode)). @@ -862,4 +862,4 @@ client.SetProxy(nil) ## License -`Req` released under MIT license, refer [LICENSE](LICENSE) file. \ No newline at end of file +`Req` released under MIT license, refer [LICENSE](LICENSE) file. From 01fad0e77ecb211cd8f21025539555cff3dbb587 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Feb 2022 10:38:18 +0800 Subject: [PATCH 260/843] optimize api doc --- README.md | 4 ++-- docs/api.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index cb125a43..240a71d3 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ If you want to use the older version, check it out on [v1 branch](https://github ## Features * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. -* Powerful and convenient debug utilites, including debug logs, performance traces, and even dump complete request and response content (see [Debugging - Log/Trace/Dump](#Debugging). +* Powerful and convenient debug utilites, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging - Log/Trace/Dump](#Debugging). * Easy making HTTP test with code instead of tools like curl or postman, `req` provide global wrapper methods and `MustXXX` to test API with minimal code (see [Quick HTTP Test](#Test)). * Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support, and you can also force `HTTP/1.1` if you want (see [HTTP2 and HTTP1](#HTTP2-HTTP1)). * Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decode](#AutoDecode)). @@ -242,7 +242,7 @@ client.R().Get("https://imroc.cc") **Test with Global Wrapper Methods** -`req` wrap methods of both `Client` and `Request` with global methods, which is delegated to default client, it's very convenient when making API test. +`req` wrap methods of both `Client` and `Request` with global methods, which is delegated to the default client behind the scenes, so you can just treat the package name `req` as a Client or Request to test quickly without create one explicitly. ```go // Call the global methods just like the Client's methods, diff --git a/docs/api.md b/docs/api.md index d7cca3ec..fa668865 100644 --- a/docs/api.md +++ b/docs/api.md @@ -2,7 +2,7 @@

Req API Reference

-Here is a brief and categorized list of the core API, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). +Here is a brief and categorized list of the core APIs, for a more detailed and complete list, please refer to the [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). ## Table of Contents @@ -24,7 +24,7 @@ Here is a brief and categorized list of the core API, for a more detailed and co ## Client Settings -The following are the chainable settings of Client, all of which have corresponding global wrappers. +The following are the chainable settings of Client, all of which have corresponding global wrappers (Just treat the package name `req` as a Client to test, set up the Client without create any Client explicitly). Basically, you can know the meaning of most settings directly from the method name. From 593460235002cf6a31b921d3d42c913a27262741 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Feb 2022 11:01:44 +0800 Subject: [PATCH 261/843] update README --- README.md | 44 +++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 240a71d3..034368df 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,9 @@

-## Big News +## News -Brand-new version v3 is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) +Brand-New version v3 is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) If you want to use the older version, check it out on [v1 branch](https://github.com/imroc/req/tree/v1). @@ -38,7 +38,7 @@ If you want to use the older version, check it out on [v1 branch](https://github ## Features * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. -* Powerful and convenient debug utilites, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging - Log/Trace/Dump](#Debugging). +* Powerful and convenient debug utilites, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging - Dump/Log/Trace](#Debugging). * Easy making HTTP test with code instead of tools like curl or postman, `req` provide global wrapper methods and `MustXXX` to test API with minimal code (see [Quick HTTP Test](#Test)). * Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support, and you can also force `HTTP/1.1` if you want (see [HTTP2 and HTTP1](#HTTP2-HTTP1)). * Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decode](#AutoDecode)). @@ -86,7 +86,7 @@ Checkout more runnable examples in the [examples](examples) direcotry. ## API Reference -Checkout [Req API Reference](docs/api.md) for a brief and categorized list of the core API, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). +Checkout [Req API Reference](docs/api.md) for a brief and categorized list of the core APIs, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). ## Debugging - Dump/Log/Trace @@ -128,11 +128,11 @@ access-control-allow-credentials: true } */ -// Customize client level dump settings with predefined convenience settings. +// Customize dump settings with predefined and convenient settings at client level. client.EnableDumpAllWithoutBody(). // Only dump the header of request and response EnableDumpAllAsync(). // Dump asynchronously to improve performance EnableDumpAllToFile("reqdump.log") // Dump to file without printing it out -// Send request to see the content that have been dumpped +// Send request to see the content that have been dumped client.R().Get(url) // Enable dump with fully customized settings at client level. @@ -151,30 +151,32 @@ client.R().Get("https://httpbin.org/get") opt.ResponseBody = false client.R().Get("https://httpbin.org/get") -// You can also enable dump at request level, which will not override client-level dumpping, -// dump to memory and will not print it out by default, you can call `Response.Dump()` to get -// the dump result and print only if you want to, typically used in production, only record +// You can also enable dump at request level, which will not override client-level dumping, +// dump to the internal buffer and will not print it out by default, you can call `Response.Dump()` +// to get the dump result and print only if you want to, typically used in production, only record // the content of the request when the request is abnormal to help us troubleshoot problems. resp, err := client.R().EnableDump().SetBody("test body").Post("https://httpbin.org/post") if err != nil { fmt.Println("err:", err) - fmt.Println("raw content:\n", resp.Dump()) + if resp.Dump() != "" { + fmt.Println("raw content:") + fmt.Println(resp.Dump()) + } return } if !resp.IsSuccess() { // Status code not beetween 200 and 299 fmt.Println("bad status:", resp.Status) - fmt.Println("raw content:\n", resp.Dump()) - return + fmt.Println("raw content:") + fmt.Println(resp.Dump()) + return } -// Similarly, also support to customize dump settings with predefined convenience settings at request level. +// Similarly, also support to customize dump settings with the predefined and convenient settings at request level. resp, err = client.R().EnableDumpWithoutRequest().SetBody("test body").Post("https://httpbin.org/post") // ... resp, err = client.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post("https://httpbin.org/post") ``` -> Request-level dumpping will not override client-level dumpping, cuz - **Enable DebugLog for Deeper Insights** ```go @@ -245,7 +247,7 @@ client.R().Get("https://imroc.cc") `req` wrap methods of both `Client` and `Request` with global methods, which is delegated to the default client behind the scenes, so you can just treat the package name `req` as a Client or Request to test quickly without create one explicitly. ```go -// Call the global methods just like the Client's methods, +// Call the global methods just like the Client's method, // so you can treat package name `req` as a Client, and // you don't need to create any client explicitly. req.SetTimeout(5 * time.Second). @@ -756,19 +758,19 @@ client.R().SetFileReader("avatar", "avatar.png", avatarImgFile).Post(url) `Req` detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default. -Its principle is to detect whether `Content-Type` header at first, if it's not the text content type (json, xml, html and so on), `req` will not try to decode. If it is, then `req` will try to find the charset information, if it's not included in the header, it will try to sniff the body's content to determine the charset, if found and is not utf-8, then decode it to utf-8 automatically, if the charset is not sure, it will not decode, and leave the body untouched. +Its principle is to detect `Content-Type` header at first, if it's not the text content type (json, xml, html and so on), `req` will not try to decode. If it is, then `req` will try to find the charset information. And `req` also will try to sniff the body's content to determine the charset if the charset information is not included in the header, if sniffed out and not utf-8, then decode it to utf-8 automatically, and `req` will not try to decode if the charset is not sure, just leave the body untouched. -You can also disable if you don't need or care a lot about performance: +You can also disable it if you don't need or care a lot about performance: ```go client.DisableAutoDecode() ``` -Also you can make some customization: +And also you can make some customization: ```go // Try to auto-detect and decode all content types (some server may return incorrect Content-Type header) -client.SetAutoDecodeAllType() +client.SetAutoDecodeAllContentType() // Only auto-detect and decode content which `Content-Type` header contains "html" or "json" client.SetAutoDecodeContentType("html", "json") @@ -780,7 +782,7 @@ fn := func(contentType string) bool { } return false } -client.SetAutoDecodeAllTypeFunc(fn) +client.SetAutoDecodeContentTypeFunc(fn) ``` ## Request and Response Middleware From 806ae0b429bde6622d48100f5f0157e1615d5816 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Feb 2022 13:02:48 +0800 Subject: [PATCH 262/843] some optimize --- README.md | 6 +- client.go | 182 ++++++++++++++++++++++++++------------------------- h2_bundle.go | 8 +-- 3 files changed, 97 insertions(+), 99 deletions(-) diff --git a/README.md b/README.md index 034368df..8bed0686 100644 --- a/README.md +++ b/README.md @@ -252,7 +252,7 @@ client.R().Get("https://imroc.cc") // you don't need to create any client explicitly. req.SetTimeout(5 * time.Second). SetCommonBasicAuth("imroc", "123456"). - SetCommonHeader("Accept", "application/json"). + SetCommonHeader("Accept", "text/xml"). SetUserAgent("my api client"). DevMode() @@ -261,8 +261,8 @@ req.SetTimeout(5 * time.Second). // client, so you can treat package name `req` as a Request, // and you don't need to create any request and client explicitly. req.SetQueryParam("page", "2"). - SetHeader("Accept", "text/xml"). // Override client level settings at request level. - Get("https://api.example.com/repos") + SetHeader("Accept", "application/json"). // Override client level settings at request level. + Get("https://httpbin.org/get") ``` **Test with MustXXX** diff --git a/client.go b/client.go index 9f7f91e1..4b85f818 100644 --- a/client.go +++ b/client.go @@ -154,7 +154,8 @@ func SetCommonFormDataFromValues(data urlpkg.Values) *Client { return defaultClient.SetCommonFormDataFromValues(data) } -// SetCommonFormDataFromValues set the form data from url.Values for all requests which method allows payload. +// SetCommonFormDataFromValues set the form data from url.Values for all requests +// which request method allows payload. func (c *Client) SetCommonFormDataFromValues(data urlpkg.Values) *Client { if c.FormData == nil { c.FormData = urlpkg.Values{} @@ -173,7 +174,8 @@ func SetCommonFormData(data map[string]string) *Client { return defaultClient.SetCommonFormData(data) } -// SetCommonFormData set the form data from map for all requests which method allows payload. +// SetCommonFormData set the form data from map for all requests +// which request method allows payload. func (c *Client) SetCommonFormData(data map[string]string) *Client { if c.FormData == nil { c.FormData = urlpkg.Values{} @@ -190,8 +192,8 @@ func SetBaseURL(u string) *Client { return defaultClient.SetBaseURL(u) } -// SetBaseURL set the default base url, will be used if request url is -// a relative url. +// SetBaseURL set the default base URL, will be used if request URL is +// a relative URL. func (c *Client) SetBaseURL(u string) *Client { c.BaseURL = strings.TrimRight(u, "/") return c @@ -203,7 +205,8 @@ func SetOutputDirectory(dir string) *Client { return defaultClient.SetOutputDirectory(dir) } -// SetOutputDirectory set output directory that response will be downloaded to. +// SetOutputDirectory set output directory that response will +// be downloaded to. func (c *Client) SetOutputDirectory(dir string) *Client { c.outputDirectory = dir return c @@ -215,7 +218,7 @@ func SetCertFromFile(certFile, keyFile string) *Client { return defaultClient.SetCertFromFile(certFile, keyFile) } -// SetCertFromFile helps to set client certificates from cert and key file +// SetCertFromFile helps to set client certificates from cert and key file. func (c *Client) SetCertFromFile(certFile, keyFile string) *Client { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { @@ -233,7 +236,7 @@ func SetCerts(certs ...tls.Certificate) *Client { return defaultClient.SetCerts(certs...) } -// SetCerts helps to set client certificates +// SetCerts set client certificates. func (c *Client) SetCerts(certs ...tls.Certificate) *Client { config := c.tlsConfig() config.Certificates = append(config.Certificates, certs...) @@ -255,7 +258,7 @@ func SetRootCertFromString(pemContent string) *Client { return defaultClient.SetRootCertFromString(pemContent) } -// SetRootCertFromString helps to set root CA cert from string +// SetRootCertFromString set root certificates from string. func (c *Client) SetRootCertFromString(pemContent string) *Client { c.appendRootCertData([]byte(pemContent)) return c @@ -267,7 +270,7 @@ func SetRootCertsFromFile(pemFiles ...string) *Client { return defaultClient.SetRootCertsFromFile(pemFiles...) } -// SetRootCertsFromFile helps to set root certs from files +// SetRootCertsFromFile set root certificates from files. func (c *Client) SetRootCertsFromFile(pemFiles ...string) *Client { for _, pemFile := range pemFiles { rootPemData, err := ioutil.ReadFile(pemFile) @@ -303,9 +306,10 @@ func SetRedirectPolicy(policies ...RedirectPolicy) *Client { return defaultClient.SetRedirectPolicy(policies...) } -// SetRedirectPolicy set the RedirectPolicy, see the predefined AllowedDomainRedirectPolicy, -// AllowedHostRedirectPolicy, MaxRedirectPolicy, NoRedirectPolicy, SameDomainRedirectPolicy -// and SameHostRedirectPolicy. +// SetRedirectPolicy set the RedirectPolicy which controls the behavior of receiving redirect +// responses (usually responses with 301 and 302 status code), see the predefined +// AllowedDomainRedirectPolicy, AllowedHostRedirectPolicy, MaxRedirectPolicy, NoRedirectPolicy, +// SameDomainRedirectPolicy and SameHostRedirectPolicy. func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { if len(policies) == 0 { return c @@ -334,9 +338,9 @@ func DisableKeepAlives() *Client { return defaultClient.DisableKeepAlives() } -// DisableKeepAlives disables HTTP keep-alives (enabled by default) +// DisableKeepAlives disable the HTTP keep-alives (enabled by default) // and will only use the connection to the server for a single -// HTTP request . +// HTTP request. // // This is unrelated to the similarly named TCP keep-alives. func (c *Client) DisableKeepAlives() *Client { @@ -350,11 +354,7 @@ func EnableKeepAlives() *Client { return defaultClient.EnableKeepAlives() } -// EnableKeepAlives enables HTTP keep-alives (enabled by default) -// and will only use the connection to the server for a single -// HTTP request . -// -// This is unrelated to the similarly named TCP keep-alives. +// EnableKeepAlives enables HTTP keep-alives (enabled by default). func (c *Client) EnableKeepAlives() *Client { c.t.DisableKeepAlives = false return c @@ -367,14 +367,13 @@ func DisableCompression() *Client { } // DisableCompression disables the compression (enabled by default), -// which prevents the Transport from -// requesting compression with an "Accept-Encoding: gzip" -// request header when the Request contains no existing -// Accept-Encoding value. If the Transport requests gzip on -// its own and gets a gzipped response, it's transparently -// decoded in the Response.Body. However, if the user -// explicitly requested gzip it is not automatically -// uncompressed. +// which prevents the Transport from requesting compression +// with an "Accept-Encoding: gzip" request header when the +// Request contains no existing Accept-Encoding value. If +// the Transport requests gzip on its own and gets a gzipped +// response, it's transparently decoded in the Response.Body. +// However, if the user explicitly requested gzip it is not +// automatically uncompressed. func (c *Client) DisableCompression() *Client { c.t.DisableCompression = true return c @@ -386,15 +385,7 @@ func EnableCompression() *Client { return defaultClient.EnableCompression() } -// EnableCompression enables the compression (enabled by default), -// which prevents the Transport from -// requesting compression with an "Accept-Encoding: gzip" -// request header when the Request contains no existing -// Accept-Encoding value. If the Transport requests gzip on -// its own and gets a gzipped response, it's transparently -// decoded in the Response.Body. However, if the user -// explicitly requested gzip it is not automatically -// uncompressed. +// EnableCompression enables the compression (enabled by default). func (c *Client) EnableCompression() *Client { c.t.DisableCompression = false return c @@ -406,10 +397,9 @@ func SetTLSClientConfig(conf *tls.Config) *Client { return defaultClient.SetTLSClientConfig(conf) } -// SetTLSClientConfig sets the client tls config. +// SetTLSClientConfig set the TLS client config. func (c *Client) SetTLSClientConfig(conf *tls.Config) *Client { c.t.TLSClientConfig = conf - c.t2.TLSClientConfig = conf return c } @@ -419,7 +409,8 @@ func SetCommonQueryParams(params map[string]string) *Client { return defaultClient.SetCommonQueryParams(params) } -// SetCommonQueryParams sets the URL query parameters with a map at client level. +// SetCommonQueryParams set URL query parameters with a map +// for all requests. func (c *Client) SetCommonQueryParams(params map[string]string) *Client { for k, v := range params { c.SetCommonQueryParam(k, v) @@ -434,7 +425,7 @@ func AddCommonQueryParam(key, value string) *Client { } // AddCommonQueryParam add a URL query parameter with a key-value -// pair at client level +// pair for all requests. func (c *Client) AddCommonQueryParam(key, value string) *Client { if c.QueryParams == nil { c.QueryParams = make(urlpkg.Values) @@ -450,7 +441,7 @@ func SetCommonQueryParam(key, value string) *Client { } // SetCommonQueryParam set a URL query parameter with a key-value -// pair at client level. +// pair for all requests. func (c *Client) SetCommonQueryParam(key, value string) *Client { if c.QueryParams == nil { c.QueryParams = make(urlpkg.Values) @@ -465,7 +456,8 @@ func SetCommonQueryString(query string) *Client { return defaultClient.SetCommonQueryString(query) } -// SetCommonQueryString set URL query parameters using the raw query string. +// SetCommonQueryString set URL query parameters with a raw query string +// for all requests. func (c *Client) SetCommonQueryString(query string) *Client { params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) if err != nil { @@ -489,7 +481,7 @@ func SetCommonCookies(cookies ...*http.Cookie) *Client { return defaultClient.SetCommonCookies(cookies...) } -// SetCommonCookies set cookies at client level. +// SetCommonCookies set HTTP cookies for all requests. func (c *Client) SetCommonCookies(cookies ...*http.Cookie) *Client { c.Cookies = append(c.Cookies, cookies...) return c @@ -501,7 +493,7 @@ func DisableDebugLog() *Client { return defaultClient.DisableDebugLog() } -// DisableDebugLog disables debug level log (disabled by default). +// DisableDebugLog disable debug level log (disabled by default). func (c *Client) DisableDebugLog() *Client { c.DebugLog = false return c @@ -513,7 +505,7 @@ func EnableDebugLog() *Client { return defaultClient.EnableDebugLog() } -// EnableDebugLog enables debug level log (disabled by default). +// EnableDebugLog enable debug level log (disabled by default). func (c *Client) EnableDebugLog() *Client { c.DebugLog = true return c @@ -527,7 +519,7 @@ func DevMode() *Client { // DevMode enables: // 1. Dump content of all requests and responses to see details. -// 2. Output debug log for deeper insights. +// 2. Output debug level log for deeper insights. // 3. Trace all requests, so you can get trace info to analyze performance. // 4. Set User-Agent to pretend to be a web browser, avoid returning abnormal data from some sites. func (c *Client) DevMode() *Client { @@ -543,8 +535,8 @@ func SetScheme(scheme string) *Client { return defaultClient.SetScheme(scheme) } -// SetScheme sets custom default scheme in the client, will be used when -// there is no scheme in the request url. +// SetScheme set the default scheme for client, will be used when +// there is no scheme in the request URL (e.g. "github.com/imroc/req"). func (c *Client) SetScheme(scheme string) *Client { if !util.IsStringEmpty(scheme) { c.scheme = strings.TrimSpace(scheme) @@ -558,7 +550,7 @@ func SetLogger(log Logger) *Client { return defaultClient.SetLogger(log) } -// SetLogger set the logger for req, set to nil to disable logger. +// SetLogger set the customized logger for client, will disable log if set to nil. func (c *Client) SetLogger(log Logger) *Client { if log == nil { c.log = &disableLogger{} @@ -581,7 +573,7 @@ func SetTimeout(d time.Duration) *Client { return defaultClient.SetTimeout(d) } -// SetTimeout set the timeout for all requests. +// SetTimeout set timeout for all requests. func (c *Client) SetTimeout(d time.Duration) *Client { c.httpClient.Timeout = d return c @@ -600,7 +592,7 @@ func EnableDumpAll() *Client { return defaultClient.EnableDumpAll() } -// EnableDumpAll enables dump for all requests, including +// EnableDumpAll enable dump for all requests, including // all content for the request and response by default. func (c *Client) EnableDumpAll() *Client { if c.t.dump != nil { // dump already started @@ -616,8 +608,8 @@ func EnableDumpAllToFile(filename string) *Client { return defaultClient.EnableDumpAllToFile(filename) } -// EnableDumpAllToFile enables dump for all requests and save -// to the specified filename. +// EnableDumpAllToFile enable dump for all requests and output +// to the specified file. func (c *Client) EnableDumpAllToFile(filename string) *Client { file, err := os.Create(filename) if err != nil { @@ -635,8 +627,8 @@ func EnableDumpAllTo(output io.Writer) *Client { return defaultClient.EnableDumpAllTo(output) } -// EnableDumpAllTo enables dump for all requests and save -// to the specified io.Writer. +// EnableDumpAllTo enable dump for all requests and output to +// the specified io.Writer. func (c *Client) EnableDumpAllTo(output io.Writer) *Client { c.getDumpOptions().Output = output c.EnableDumpAll() @@ -649,7 +641,7 @@ func EnableDumpAllAsync() *Client { return defaultClient.EnableDumpAllAsync() } -// EnableDumpAllAsync enables dump for all requests and output +// EnableDumpAllAsync enable dump for all requests and output // asynchronously, can be used for debugging in production // environment without affecting performance. func (c *Client) EnableDumpAllAsync() *Client { @@ -665,9 +657,9 @@ func EnableDumpAllWithoutRequestBody() *Client { return defaultClient.EnableDumpAllWithoutRequestBody() } -// EnableDumpAllWithoutRequestBody enables dump for all requests, without -// request body, can be used in upload request to avoid dump the unreadable -// binary content. +// EnableDumpAllWithoutRequestBody enable dump for all requests without +// request body, can be used in the upload request to avoid dumping the +// unreadable binary content. func (c *Client) EnableDumpAllWithoutRequestBody() *Client { o := c.getDumpOptions() o.RequestBody = false @@ -681,9 +673,9 @@ func EnableDumpAllWithoutResponseBody() *Client { return defaultClient.EnableDumpAllWithoutResponseBody() } -// EnableDumpAllWithoutResponseBody enables dump for all requests, without -// response body, can be used in download request to avoid dump the unreadable -// binary content. +// EnableDumpAllWithoutResponseBody enable dump for all requests without +// response body, can be used in the download request to avoid dumping the +// unreadable binary content. func (c *Client) EnableDumpAllWithoutResponseBody() *Client { o := c.getDumpOptions() o.ResponseBody = false @@ -697,7 +689,8 @@ func EnableDumpAllWithoutResponse() *Client { return defaultClient.EnableDumpAllWithoutResponse() } -// EnableDumpAllWithoutResponse enables dump for all requests without response. +// EnableDumpAllWithoutResponse enable dump for all requests without response, +// can be used if you only care about the request. func (c *Client) EnableDumpAllWithoutResponse() *Client { o := c.getDumpOptions() o.ResponseBody = false @@ -712,7 +705,8 @@ func EnableDumpAllWithoutRequest() *Client { return defaultClient.EnableDumpAllWithoutRequest() } -// EnableDumpAllWithoutRequest enables dump for all requests without request. +// EnableDumpAllWithoutRequest enables dump for all requests without request, +// can be used if you only care about the response. func (c *Client) EnableDumpAllWithoutRequest() *Client { o := c.getDumpOptions() o.RequestHeader = false @@ -727,7 +721,8 @@ func EnableDumpAllWithoutHeader() *Client { return defaultClient.EnableDumpAllWithoutHeader() } -// EnableDumpAllWithoutHeader enables dump for all requests without header. +// EnableDumpAllWithoutHeader enable dump for all requests without header, +// can be used if you only care about the body. func (c *Client) EnableDumpAllWithoutHeader() *Client { o := c.getDumpOptions() o.RequestHeader = false @@ -742,7 +737,8 @@ func EnableDumpAllWithoutBody() *Client { return defaultClient.EnableDumpAllWithoutBody() } -// EnableDumpAllWithoutBody enables dump for all requests without body. +// EnableDumpAllWithoutBody enable dump for all requests without body, +// can be used if you only care about the header. func (c *Client) EnableDumpAllWithoutBody() *Client { o := c.getDumpOptions() o.RequestBody = false @@ -793,7 +789,7 @@ func SetAutoDecodeContentType(contentTypes ...string) *Client { } // SetAutoDecodeContentType set the content types that will be auto-detected and decode -// to utf-8 +// to utf-8 (e.g. "json", "xml", "html", "text"). func (c *Client) SetAutoDecodeContentType(contentTypes ...string) *Client { opt := c.getResponseOptions() opt.AutoDecodeContentType = autoDecodeContentTypeFunc(contentTypes...) @@ -806,8 +802,8 @@ func SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client { return defaultClient.SetAutoDecodeContentTypeFunc(fn) } -// SetAutoDecodeContentTypeFunc set the custmize function that determins the content-type -// whether if should be auto-detected and decode to utf-8 +// SetAutoDecodeContentTypeFunc set the function that determines whether the +// specified `Content-Type` should be auto-detected and decode to utf-8. func (c *Client) SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client { opt := c.getResponseOptions() opt.AutoDecodeContentType = fn @@ -820,7 +816,8 @@ func SetAutoDecodeAllContentType() *Client { return defaultClient.SetAutoDecodeAllContentType() } -// SetAutoDecodeAllContentType enables to try auto-detect and decode all content type to utf-8. +// SetAutoDecodeAllContentType enable try auto-detect charset and decode all +// content type to utf-8. func (c *Client) SetAutoDecodeAllContentType() *Client { opt := c.getResponseOptions() opt.AutoDecodeContentType = func(contentType string) bool { @@ -835,8 +832,8 @@ func DisableAutoDecode() *Client { return defaultClient.DisableAutoDecode() } -// DisableAutoDecode disable auto detect charset and decode to utf-8 -// (enabled by default) +// DisableAutoDecode disable auto-detect charset and decode to utf-8 +// (enabled by default). func (c *Client) DisableAutoDecode() *Client { c.getResponseOptions().DisableAutoDecode = true return c @@ -848,8 +845,8 @@ func EnableAutoDecode() *Client { return defaultClient.EnableAutoDecode() } -// EnableAutoDecode enables auto detect charset and decode to utf-8 -// (enabled by default) +// EnableAutoDecode enable auto-detect charset and decode to utf-8 +// (enabled by default). func (c *Client) EnableAutoDecode() *Client { c.getResponseOptions().DisableAutoDecode = true return c @@ -936,7 +933,7 @@ func DisableDumpAll() *Client { return defaultClient.DisableDumpAll() } -// DisableDumpAll disables the dump. +// DisableDumpAll disable dump for all requests. func (c *Client) DisableDumpAll() *Client { c.t.DisableDump() return c @@ -949,6 +946,7 @@ func SetCommonDumpOptions(opt *DumpOptions) *Client { } // SetCommonDumpOptions configures the underlying Transport's DumpOptions +// for all requests. func (c *Client) SetCommonDumpOptions(opt *DumpOptions) *Client { if opt == nil { return c @@ -1002,7 +1000,7 @@ func SetProxyURL(proxyUrl string) *Client { return defaultClient.SetProxyURL(proxyUrl) } -// SetProxyURL set a proxy from the proxy url. +// SetProxyURL set proxy from the proxy URL. func (c *Client) SetProxyURL(proxyUrl string) *Client { u, err := urlpkg.Parse(proxyUrl) if err != nil { @@ -1019,7 +1017,7 @@ func DisableTraceAll() *Client { return defaultClient.DisableTraceAll() } -// DisableTraceAll disables the trace at client level. +// DisableTraceAll disable trace for all requests. func (c *Client) DisableTraceAll() *Client { c.trace = false return c @@ -1031,7 +1029,7 @@ func EnableTraceAll() *Client { return defaultClient.EnableTraceAll() } -// EnableTraceAll enables the trace at client level. +// EnableTraceAll enable trace for all requests. func (c *Client) EnableTraceAll() *Client { c.trace = true return c @@ -1043,7 +1041,7 @@ func SetCookieJar(jar http.CookieJar) *Client { return defaultClient.SetCookieJar(jar) } -// SetCookieJar set the `CookeJar` to the underlying `http.Client` +// SetCookieJar set the `CookeJar` to the underlying `http.Client`. func (c *Client) SetCookieJar(jar http.CookieJar) *Client { c.httpClient.Jar = jar return c @@ -1055,7 +1053,8 @@ func SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { return defaultClient.SetJsonMarshal(fn) } -// SetJsonMarshal set json marshal function which will be used to marshal request body. +// SetJsonMarshal set the JSON marshal function which will be used +// to marshal request body. func (c *Client) SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { c.jsonMarshal = fn return c @@ -1067,7 +1066,8 @@ func SetJsonUnmarshal(fn func(data []byte, v interface{}) error) *Client { return defaultClient.SetJsonUnmarshal(fn) } -// SetJsonUnmarshal set the JSON unmarshal function which will be used to unmarshal response body. +// SetJsonUnmarshal set the JSON unmarshal function which will be used +// to unmarshal response body. func (c *Client) SetJsonUnmarshal(fn func(data []byte, v interface{}) error) *Client { c.jsonUnmarshal = fn return c @@ -1079,7 +1079,8 @@ func SetXmlMarshal(fn func(v interface{}) ([]byte, error)) *Client { return defaultClient.SetXmlMarshal(fn) } -// SetXmlMarshal set the XML marshal function which will be used to marshal request body. +// SetXmlMarshal set the XML marshal function which will be used +// to marshal request body. func (c *Client) SetXmlMarshal(fn func(v interface{}) ([]byte, error)) *Client { c.xmlMarshal = fn return c @@ -1091,7 +1092,8 @@ func SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Client { return defaultClient.SetXmlUnmarshal(fn) } -// SetXmlUnmarshal set the XML unmarshal function which will be used to unmarshal response body. +// SetXmlUnmarshal set the XML unmarshal function which will be used +// to unmarshal response body. func (c *Client) SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Client { c.xmlUnmarshal = fn return c @@ -1103,9 +1105,9 @@ func SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, er return defaultClient.SetDialTLS(fn) } -// SetDialTLS set the customized `DialTLSContext` function to Transport (make sure the returned -// `conn` implements [TLSConn](https://pkg.go.dev/github.com/imroc/req/v3#TLSConn) if you want -// your customized `conn` supports HTTP2). +// SetDialTLS set the customized `DialTLSContext` function to Transport. +// Make sure the returned `conn` implements TLSConn if you want your +// customized `conn` supports HTTP2. func (c *Client) SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { c.t.DialTLSContext = fn return c @@ -1117,7 +1119,7 @@ func SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error return defaultClient.SetDial(fn) } -// SetDial set the customized DialContext function to Transport. +// SetDial set the customized `DialContext` function to Transport. func (c *Client) SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { c.t.DialContext = fn return c @@ -1141,7 +1143,7 @@ func EnableForceHTTP1() *Client { return defaultClient.EnableForceHTTP1() } -// EnableForceHTTP1 enables force using HTTP1 (disabled by default) +// EnableForceHTTP1 enable force using HTTP1 (disabled by default). func (c *Client) EnableForceHTTP1() *Client { c.t.ForceHTTP1 = true return c @@ -1153,7 +1155,7 @@ func DisableForceHTTP1() *Client { return defaultClient.DisableForceHTTP1() } -// DisableForceHTTP1 disable force using HTTP1 (disabled by default) +// DisableForceHTTP1 disable force using HTTP1 (disabled by default). func (c *Client) DisableForceHTTP1() *Client { c.t.ForceHTTP1 = false return c @@ -1165,7 +1167,7 @@ func DisableAllowGetMethodPayload() *Client { return defaultClient.DisableAllowGetMethodPayload() } -// DisableAllowGetMethodPayload disables sending GET method requests with body. +// DisableAllowGetMethodPayload disable sending GET method requests with body. func (c *Client) DisableAllowGetMethodPayload() *Client { c.AllowGetMethodPayload = false return c diff --git a/h2_bundle.go b/h2_bundle.go index 73f6a60a..4e8cae81 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -3234,10 +3234,6 @@ type http2Transport struct { // it will be used to set http.Response.TLS. DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error) - // TLSClientConfig specifies the TLS configuration to use with - // tls.Client. If nil, the default configuration is used. - TLSClientConfig *tls.Config - // ConnPool optionally specifies an alternate connection pool to use. // If nil, the default is used. ConnPool http2ClientConnPool @@ -3762,8 +3758,8 @@ func (t *http2Transport) dialClientConn(ctx context.Context, addr string, single func (t *http2Transport) newTLSConfig(host string) *tls.Config { cfg := new(tls.Config) - if t.TLSClientConfig != nil { - *cfg = *t.TLSClientConfig.Clone() + if t.t1.TLSClientConfig != nil { + *cfg = *t.t1.TLSClientConfig.Clone() } if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) From 352c1cb8d5339ec8a9958e0404a80655bef7cd57 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Feb 2022 14:03:17 +0800 Subject: [PATCH 263/843] support force http2 --- client.go | 28 ++++++++++++++++++++-------- transport.go | 23 ++++++++++++++++++----- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 4b85f818..d3947ed5 100644 --- a/client.go +++ b/client.go @@ -1145,19 +1145,31 @@ func EnableForceHTTP1() *Client { // EnableForceHTTP1 enable force using HTTP1 (disabled by default). func (c *Client) EnableForceHTTP1() *Client { - c.t.ForceHTTP1 = true + c.t.ForceHttpVersion = HTTP1 return c } -// DisableForceHTTP1 is a global wrapper methods which delegated -// to the default client's DisableForceHTTP1. -func DisableForceHTTP1() *Client { - return defaultClient.DisableForceHTTP1() +// EnableForceHTTP2 is a global wrapper methods which delegated +// to the default client's EnableForceHTTP2. +func EnableForceHTTP2() *Client { + return defaultClient.EnableForceHTTP2() } -// DisableForceHTTP1 disable force using HTTP1 (disabled by default). -func (c *Client) DisableForceHTTP1() *Client { - c.t.ForceHTTP1 = false +// EnableForceHTTP2 enable force using HTTP2 (disabled by default). +func (c *Client) EnableForceHTTP2() *Client { + c.t.ForceHttpVersion = HTTP2 + return c +} + +// DisableForceHttpVersion is a global wrapper methods which delegated +// to the default client's DisableForceHttpVersion. +func DisableForceHttpVersion() *Client { + return defaultClient.DisableForceHttpVersion() +} + +// DisableForceHttpVersion disable force using HTTP1 (disabled by default). +func (c *Client) DisableForceHttpVersion() *Client { + c.t.ForceHttpVersion = "" return c } diff --git a/transport.go b/transport.go index 276d9285..28616462 100644 --- a/transport.go +++ b/transport.go @@ -38,6 +38,13 @@ import ( "golang.org/x/net/http/httpproxy" ) +type HttpVersion string + +const ( + HTTP1 HttpVersion = "1.1" + HTTP2 HttpVersion = "2" +) + // defaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. const defaultMaxIdleConnsPerHost = 2 @@ -106,8 +113,8 @@ type Transport struct { connsPerHost map[connectMethodKey]int connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns - // ForceHTTP1 force using HTTP/1.1 - ForceHTTP1 bool + // Force using specific http version + ForceHttpVersion HttpVersion // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the @@ -367,7 +374,7 @@ func (t *Transport) Clone() *Transport { WriteBufferSize: t.WriteBufferSize, ReadBufferSize: t.ReadBufferSize, ResponseOptions: t.ResponseOptions, - ForceHTTP1: t.ForceHTTP1, + ForceHttpVersion: t.ForceHttpVersion, dump: t.dump.Clone(), } if t.dump != nil { @@ -785,7 +792,7 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) } - cm.onlyH1 = t.ForceHTTP1 || requestRequiresHTTP1(treq.Request) + cm.onlyH1 = t.ForceHttpVersion == HTTP1 || requestRequiresHTTP1(treq.Request) return cm, err } @@ -1492,6 +1499,9 @@ func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptr } pconn.tlsState = &cs pconn.conn = tlsConn + if pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { + return fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", cs.NegotiatedProtocol, http2NextProtoTLS) + } return nil } @@ -1541,6 +1551,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers trace.TLSHandshakeDone(cs, nil) } pconn.tlsState = &cs + if pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { + return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", cs.NegotiatedProtocol, http2NextProtoTLS) + } } } else { conn, err := t.dial(ctx, "tcp", cm.addr()) @@ -1672,7 +1685,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } } - if s := pconn.tlsState; !t.ForceHTTP1 && s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { + if s := pconn.tlsState; t.ForceHttpVersion != HTTP1 && s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { alt := next(cm.targetAddr, pconn.conn.(TLSConn)) if e, ok := alt.(erringRoundTripper); ok { From 0eb1a5e1b32d4b3521c922932010b712da992693 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Feb 2022 14:17:03 +0800 Subject: [PATCH 264/843] improve error message --- README.md | 10 ++++++++++ transport.go | 12 ++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8bed0686..5bc311a6 100644 --- a/README.md +++ b/README.md @@ -299,6 +299,16 @@ Access-Control-Allow-Credentials: true */ ``` +And also you can force using `HTTP/2` if you want, will return error if server does not support: + +```go +client := req.C().EnableForceHTTP2() +client.R().MustGet("https://baidu.com") +/* Output +panic: Get "https://baidu.com": server does not support http2, you can use http/1.1 which is supported +*/ +``` + ## URL Path and Query Parameter **Path Parameter** diff --git a/transport.go b/transport.go index 28616462..4e9ce4eb 100644 --- a/transport.go +++ b/transport.go @@ -1500,11 +1500,19 @@ func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptr pconn.tlsState = &cs pconn.conn = tlsConn if pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { - return fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", cs.NegotiatedProtocol, http2NextProtoTLS) + return newHttp2NotSupportedError(cs.NegotiatedProtocol) } return nil } +func newHttp2NotSupportedError(negotiatedProtocol string) error { + errMsg := "server does not support http2" + if negotiatedProtocol != "" { + errMsg += fmt.Sprintf(", you can use %s which is supported", negotiatedProtocol) + } + return errors.New(errMsg) +} + type erringRoundTripper interface { RoundTripErr() error } @@ -1552,7 +1560,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } pconn.tlsState = &cs if pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { - return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", cs.NegotiatedProtocol, http2NextProtoTLS) + return nil, newHttp2NotSupportedError(cs.NegotiatedProtocol) } } } else { From f1b0f625661e944914b9356d17f069eeaa5fb713 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Feb 2022 17:21:14 +0800 Subject: [PATCH 265/843] fully customized file upload --- client.go | 3 +- examples/upload/uploadserver/LICENSE | 21 + examples/upload/uploadserver/README.md | 877 +++++++++++++++++++++++++ logger.go | 4 +- middleware.go | 23 +- req.go | 9 +- request.go | 133 ++-- response.go | 3 + 8 files changed, 1004 insertions(+), 69 deletions(-) create mode 100644 examples/upload/uploadserver/LICENSE create mode 100644 examples/upload/uploadserver/README.md diff --git a/client.go b/client.go index d3947ed5..1fc52a94 100644 --- a/client.go +++ b/client.go @@ -1155,7 +1155,8 @@ func EnableForceHTTP2() *Client { return defaultClient.EnableForceHTTP2() } -// EnableForceHTTP2 enable force using HTTP2 (disabled by default). +// EnableForceHTTP2 enable force using HTTP2 for https requests +// (disabled by default). func (c *Client) EnableForceHTTP2() *Client { c.t.ForceHttpVersion = HTTP2 return c diff --git a/examples/upload/uploadserver/LICENSE b/examples/upload/uploadserver/LICENSE new file mode 100644 index 00000000..70f3d40a --- /dev/null +++ b/examples/upload/uploadserver/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017-2022 roc + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/examples/upload/uploadserver/README.md b/examples/upload/uploadserver/README.md new file mode 100644 index 00000000..5bc311a6 --- /dev/null +++ b/examples/upload/uploadserver/README.md @@ -0,0 +1,877 @@ +

+

Req

+

Simplified Golang HTTP client library with Black Magic, Less Code and More Efficiency.

+

+

+ +## News + +Brand-New version v3 is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) + +If you want to use the older version, check it out on [v1 branch](https://github.com/imroc/req/tree/v1). + +> v2 is a transitional version, due to some breaking changes were introduced during optmize user experience, checkout [v2 branch](https://github.com/imroc/req/tree/v2) if you want. + +## Table of Contents + +* [Features](#Features) +* [Quick Start](#Quick-Start) +* [API Reference](#API) +* [Debugging - Dump/Log/Trace](#Debugging) +* [Quick HTTP Test](#Test) +* [HTTP2 and HTTP1](#HTTP2-HTTP1) +* [URL Path and Query Parameter](#Param) +* [Form Data](#Form) +* [Header and Cookie](#Header-Cookie) +* [Body and Marshal/Unmarshal](#Body) +* [Custom Certificates](#Cert) +* [Basic Auth and Bearer Token](#Auth) +* [Download and Upload](#Download-Upload) +* [Auto-Decode](#AutoDecode) +* [Request and Response Middleware](#Middleware) +* [Redirect Policy](#Redirect) +* [Proxy](#Proxy) +* [TODO List](#TODO) +* [API Reference](#API) +* [License](#License) + +## Features + +* Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. +* Powerful and convenient debug utilites, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging - Dump/Log/Trace](#Debugging). +* Easy making HTTP test with code instead of tools like curl or postman, `req` provide global wrapper methods and `MustXXX` to test API with minimal code (see [Quick HTTP Test](#Test)). +* Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support, and you can also force `HTTP/1.1` if you want (see [HTTP2 and HTTP1](#HTTP2-HTTP1)). +* Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decode](#AutoDecode)). +* Automatic marshal and unmarshal for JSON and XML content type and fully customizable (see [Body and Marshal/Unmarshal](#Body)). +* Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. +* Easy [Download and Upload](#Download-Upload). +* Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token for both client and request level. +* Easy set timeout, proxy, certs, redirect policy, cookie jar, compression, keepalives etc for client. +* Support middleware before request sent and after got response (see [Request and Response Middleware](#Middleware)). + +## Quick Start + +**Install** + +``` sh +go get github.com/imroc/req/v3 +``` + +**Import** + +```go +import "github.com/imroc/req/v3" +``` + +```go +// For test, you can create and send a request with the global default +// client, use DevMode to see all details, try and suprise :) +req.DevMode() +req.Get("https://api.github.com/users/imroc") + +// Create and send a request with the custom client and settings +client := req.C(). // Use C() to create a client + SetUserAgent("my-custom-client"). // Chainable client settings + SetTimeout(5 * time.Second). + DevMode() +resp, err := client.R(). // Use R() to create a request + SetHeader("Accept", "application/vnd.github.v3+json"). // Chainable request settings + SetPathParam("username", "imroc"). + SetQueryParam("page", "1"). + SetResult(&result). + Get("https://api.github.com/users/{username}/repos") +``` + +Checkout more runnable examples in the [examples](examples) direcotry. + +## API Reference + +Checkout [Req API Reference](docs/api.md) for a brief and categorized list of the core APIs, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). + +## Debugging - Dump/Log/Trace + +**Dump the Content** + +```go +// Enable dump at client level, which will dump for all requests, +// including all content of request and response and output +// to stdout by default. +client := req.C().EnableDumpAll() +client.R().Get("https://httpbin.org/get") + +/* Output +:authority: httpbin.org +:method: GET +:path: /get +:scheme: https +user-agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36 +accept-encoding: gzip + +:status: 200 +date: Wed, 26 Jan 2022 06:39:20 GMT +content-type: application/json +content-length: 372 +server: gunicorn/19.9.0 +access-control-allow-origin: * +access-control-allow-credentials: true + +{ + "args": {}, + "headers": { + "Accept-Encoding": "gzip", + "Host": "httpbin.org", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36", + "X-Amzn-Trace-Id": "Root=1-61f0ec98-5958c02662de26e458b7672b" + }, + "origin": "103.7.29.30", + "url": "https://httpbin.org/get" +} +*/ + +// Customize dump settings with predefined and convenient settings at client level. +client.EnableDumpAllWithoutBody(). // Only dump the header of request and response + EnableDumpAllAsync(). // Dump asynchronously to improve performance + EnableDumpAllToFile("reqdump.log") // Dump to file without printing it out +// Send request to see the content that have been dumped +client.R().Get(url) + +// Enable dump with fully customized settings at client level. +opt := &req.DumpOptions{ + Output: os.Stdout, + RequestHeader: true, + ResponseBody: true, + RequestBody: false, + ResponseHeader: false, + Async: false, + } +client.SetCommonDumpOptions(opt).EnableDumpAll() +client.R().Get("https://httpbin.org/get") + +// Change settings dynamiclly +opt.ResponseBody = false +client.R().Get("https://httpbin.org/get") + +// You can also enable dump at request level, which will not override client-level dumping, +// dump to the internal buffer and will not print it out by default, you can call `Response.Dump()` +// to get the dump result and print only if you want to, typically used in production, only record +// the content of the request when the request is abnormal to help us troubleshoot problems. +resp, err := client.R().EnableDump().SetBody("test body").Post("https://httpbin.org/post") +if err != nil { + fmt.Println("err:", err) + if resp.Dump() != "" { + fmt.Println("raw content:") + fmt.Println(resp.Dump()) + } + return +} +if !resp.IsSuccess() { // Status code not beetween 200 and 299 + fmt.Println("bad status:", resp.Status) + fmt.Println("raw content:") + fmt.Println(resp.Dump()) + return +} + +// Similarly, also support to customize dump settings with the predefined and convenient settings at request level. +resp, err = client.R().EnableDumpWithoutRequest().SetBody("test body").Post("https://httpbin.org/post") +// ... +resp, err = client.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post("https://httpbin.org/post") +``` + +**Enable DebugLog for Deeper Insights** + +```go +// Logging is enabled by default, but only output the warning and error message. +// Use `EnableDebugLog` to enable debug level logging. +client := req.C().EnableDebugLog() +client.R().Get("http://baidu.com/s?wd=req") +/* Output +2022/01/26 15:46:29.279368 DEBUG [req] GET http://baidu.com/s?wd=req +2022/01/26 15:46:29.469653 DEBUG [req] charset iso-8859-1 detected in Content-Type, auto-decode to utf-8 +2022/01/26 15:46:29.469713 DEBUG [req] GET http://www.baidu.com/s?wd=req +... +*/ + +// SetLogger with nil to disable all log +client.SetLogger(nil) + +// Or customize the logger with your own implementation. +client.SetLogger(logger) +``` + +**Enable Trace to Analyze Performance** + +```go +// Enable trace at request level +client := req.C() +resp, err := client.R().EnableTrace().Get("https://api.github.com/users/imroc") +if err != nil { + log.Fatal(err) +} +trace := resp.TraceInfo() // Use `resp.Request.TraceInfo()` to avoid unnecessary struct copy in production. +fmt.Println(trace.Blame()) // Print out exactly where the http request is slowing down. +fmt.Println("----------") +fmt.Println(trace) // Print details + +/* Output +the request total time is 2.562416041s, and costs 1.289082208s from connection ready to server respond frist byte +-------- +TotalTime : 2.562416041s +DNSLookupTime : 445.246375ms +TCPConnectTime : 428.458µs +TLSHandshakeTime : 825.888208ms +FirstResponseTime : 1.289082208s +ResponseTime : 1.712375ms +IsConnReused: : false +RemoteAddr : 98.126.155.187:443 +*/ + +// Enable trace at client level +client.EnableTraceAll() +resp, err = client.R().Get(url) +// ... +``` + +**DevMode** + +If you want to enable all debug features (dump, debug log and tracing), just call `DevMode()`: + +```go +client := req.C().DevMode() +client.R().Get("https://imroc.cc") +``` + +## Quick HTTP Test + +**Test with Global Wrapper Methods** + +`req` wrap methods of both `Client` and `Request` with global methods, which is delegated to the default client behind the scenes, so you can just treat the package name `req` as a Client or Request to test quickly without create one explicitly. + +```go +// Call the global methods just like the Client's method, +// so you can treat package name `req` as a Client, and +// you don't need to create any client explicitly. +req.SetTimeout(5 * time.Second). + SetCommonBasicAuth("imroc", "123456"). + SetCommonHeader("Accept", "text/xml"). + SetUserAgent("my api client"). + DevMode() + +// Call the global method just like the Request's method, +// which will create request automatically using the default +// client, so you can treat package name `req` as a Request, +// and you don't need to create any request and client explicitly. +req.SetQueryParam("page", "2"). + SetHeader("Accept", "application/json"). // Override client level settings at request level. + Get("https://httpbin.org/get") +``` + +**Test with MustXXX** + +Use `MustXXX` to ignore error handling during test, make it possible to complete a complex test with just one line of code: + +```go +fmt.Println(req.DevMode().R().MustGet("https://imroc.cc").TraceInfo()) +``` + +## HTTP2 and HTTP1 + +Req works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support, which is negotiated by TLS handshake. + +You can force using `HTTP/1.1` if you want. + +```go +client := req.C().EnableForceHTTP1().EnableDumpAllWithoutBody() +client.R().MustGet("https://httpbin.org/get") +/* Output +GET /get HTTP/1.1 +Host: httpbin.org +User-Agent: req/v3 (https://github.com/imroc/req) +Accept-Encoding: gzip + +HTTP/1.1 200 OK +Date: Tue, 08 Feb 2022 02:30:18 GMT +Content-Type: application/json +Content-Length: 289 +Connection: keep-alive +Server: gunicorn/19.9.0 +Access-Control-Allow-Origin: * +Access-Control-Allow-Credentials: true +*/ +``` + +And also you can force using `HTTP/2` if you want, will return error if server does not support: + +```go +client := req.C().EnableForceHTTP2() +client.R().MustGet("https://baidu.com") +/* Output +panic: Get "https://baidu.com": server does not support http2, you can use http/1.1 which is supported +*/ +``` + +## URL Path and Query Parameter + +**Path Parameter** + +Use `SetPathParam` or `SetPathParams` to replace variable in the url path: + +```go +client := req.C().DevMode() + +client.R(). + SetPathParam("owner", "imroc"). // Set a path param, which will replace the vairable in url path + SetPathParams(map[string]string{ // Set multiple path params at once + "repo": "req", + "path": "README.md", + }).Get("https://api.github.com/repos/{owner}/{repo}/contents/{path}") // path parameter will replace path variable in the url +/* Output +2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md +... +*/ + +// You can also set the common PathParam for every request on client +client.SetCommonPathParam(k1, v1).SetCommonPathParams(pathParams) + +resp1, err := client.Get(url1) +... + +resp2, err := client.Get(url2) +... +``` + +**Query Parameter** + +Use `SetQueryParam`, `SetQueryParams` or `SetQueryString` to append url query parameter: + +```go +client := req.C().DevMode() + +// Set query parameter at request level. +client.R(). + SetQueryParam("a", "a"). // Set a query param, which will be encoded as query parameter in url + SetQueryParams(map[string]string{ // Set multiple query params at once + "b": "b", + "c": "c", + }).SetQueryString("d=d&e=e"). // Set query params as a raw query string + Get("https://api.github.com/repos/imroc/req/contents/README.md?x=x") +/* Output +2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md?x=x&a=a&b=b&c=c&d=d&e=e +... +*/ + +// You can also set the query parameter at client level. +client.SetCommonQueryParam(k, v). + SetCommonQueryParams(queryParams). + SetCommonQueryString(queryString). + +resp1, err := client.Get(url1) +... +resp2, err := client.Get(url2) +... + +// Add query parameter with multiple values at request level. +client.R().AddQueryParam("key", "value1").AddQueryParam("key", "value2").Get("https://httpbin.org/get") +/* Output +2022/02/05 08:49:26.260780 DEBUG [req] GET https://httpbin.org/get?key=value1&key=value2 +... + */ + + +// Multiple values also supported at client level. +client.AddCommonQueryParam("key", "value1").AddCommonQueryParam("key", "value2") +``` + +## Form Data + +```go +client := req.C().EnableDumpAllWithoutResponse() +client.R().SetFormData(map[string]string{ + "username": "imroc", + "blog": "https://imroc.cc", +}).Post("https://httpbin.org/post") +/* Output +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +content-type: application/x-www-form-urlencoded +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +blog=https%3A%2F%2Fimroc.cc&username=imroc +*/ + +// Multi value form data +v := url.Values{ + "multi": []string{"a", "b", "c"}, +} +client.R().SetFormDataFromValues(v).Post("https://httpbin.org/post") +/* Output +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +content-type: application/x-www-form-urlencoded +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +multi=a&multi=b&multi=c +*/ + +// You can also set form data in client level +client.SetCommonFormData(m) +client.SetCommonFormDataFromValues(v) +``` + +> `GET`, `HEAD`, and `OPTIONS` requests ignores form data by default + +## Header and Cookie + +**Set Header** + +```go +// Let's dump the header to see what's going on +client := req.C().EnableForceHTTP1().EnableDumpAllWithoutResponse() + +// Send a request with multiple headers and cookies +client.R(). + SetHeader("Accept", "application/json"). // Set one header + SetHeaders(map[string]string{ // Set multiple headers at once + "My-Custom-Header": "My Custom Value", + "User": "imroc", + }).Get("https://httpbin.org/get") + +/* Output +GET /get HTTP/1.1 +Host: httpbin.org +User-Agent: req/v3 (https://github.com/imroc/req) +Accept: application/json +My-Custom-Header: My Custom Value +User: imroc +Accept-Encoding: gzip +*/ + +// You can also set the common header and cookie for every request on client. +client.SetCommonHeader(header).SetCommonHeaders(headers) + +resp1, err := client.R().Get(url1) +... +resp2, err := client.R().Get(url2) +... +``` + +**Set Cookie** + +```go +// Let's dump the header to see what's going on +client := req.C().EnableForceHTTP1().EnableDumpAllWithoutResponse() + +// Send a request with multiple headers and cookies +client.R(). + SetCookies( + &http.Cookie{ + Name: "testcookie1", + Value: "testcookie1 value", + Path: "/", + Domain: "baidu.com", + MaxAge: 36000, + HttpOnly: false, + Secure: true, + }, + &http.Cookie{ + Name: "testcookie2", + Value: "testcookie2 value", + Path: "/", + Domain: "baidu.com", + MaxAge: 36000, + HttpOnly: false, + Secure: true, + }, + ).Get("https://httpbin.org/get") + +/* Output +GET /get HTTP/1.1 +Host: httpbin.org +User-Agent: req/v3 (https://github.com/imroc/req) +Cookie: testcookie1="testcookie1 value"; testcookie2="testcookie2 value" +Accept-Encoding: gzip +*/ + +// You can also set the common cookie for every request on client. +client.SetCommonCookies(cookie1, cookie2, cookie3) + +resp1, err := client.R().Get(url1) +... +resp2, err := client.R().Get(url2) +``` + +You can also customize the CookieJar: +```go +// Set your own http.CookieJar implementation +client.SetCookieJar(jar) + +// Set to nil to disable CookieJar +client.SetCookieJar(nil) +``` + +## Body and Marshal/Unmarshal + +**Request Body** + +```go +// Create a client that dump request +client := req.C().EnableDumpAllWithoutResponse() +// SetBody accepts string, []byte, io.Reader, use type assertion to +// determine the data type of body automatically. +client.R().SetBody("test").Post("https://httpbin.org/post") +/* Output +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +test +*/ + +// If it cannot determine, like map and struct, then it will wait +// and marshal to JSON or XML automatically according to the `Content-Type` +// header that have been set before or after, default to json if not set. +type User struct { + Name string `json:"name"` + Email string `json:"email"` +} +user := &User{Name: "imroc", Email: "roc@imroc.cc"} +client.R().SetBody(user).Post("https://httpbin.org/post") +/* Output +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +content-type: application/json; charset=utf-8 +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +{"name":"imroc","email":"roc@imroc.cc"} +*/ + + +// You can use more specific methods to avoid type assertions and improves performance, +client.R().SetBodyJsonString(`{"username": "imroc"}`).Post("https://httpbin.org/post") +/* +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +content-type: application/json; charset=utf-8 +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +{"username": "imroc"} +*/ + +// Marshal body and set `Content-Type` automatically without any guess +cient.R().SetBodyXmlMarshal(user).Post("https://httpbin.org/post") +/* Output +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +content-type: text/xml; charset=utf-8 +accept-encoding: gzip +user-agent: req/v2 (https://github.com/imroc/req) + +imrocroc@imroc.cc +*/ +``` + +**Response Body** + +```go +// Define success body struct +type User struct { + Name string `json:"name"` + Blog string `json:"blog"` +} +// Define error body struct +type ErrorMessage struct { + Message string `json:"message"` +} +// Create a client and dump body to see details +client := req.C().EnableDumpAllWithoutHeader() + +// Send a request and unmarshal result automatically according to +// response `Content-Type` +user := &User{} +errMsg := &ErrorMessage{} +resp, err := client.R(). + SetResult(user). // Set success result + SetError(errMsg). // Set error result + Get("https://api.github.com/users/imroc") +if err != nil { + log.Fatal(err) +} +fmt.Println("----------") + +if resp.IsSuccess() { // status `code >= 200 and <= 299` is considered as success + // Must have been marshaled to user if no error returned before + fmt.Printf("%s's blog is %s\n", user.Name, user.Blog) +} else if resp.IsError() { // status `code >= 400` is considered as error + // Must have been marshaled to errMsg if no error returned before + fmt.Println("got error:", errMsg.Message) +} else { + log.Fatal("unknown http status:", resp.Status) +} +/* Output +{"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":129,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2022-01-24T23:32:53Z"} +---------- +roc's blog is https://imroc.cc +*/ + +// Or you can also unmarshal response later +if resp.IsSuccess() { + err = resp.Unmarshal(user) + if err != nil { + log.Fatal(err) + } + fmt.Printf("%s's blog is %s\n", user.Name, user.Blog) +} else { + fmt.Println("bad response:", resp) +} + +// Also, you can get the raw response and Unmarshal by yourself +yaml.Unmarshal(resp.Bytes()) +``` + +**Customize JSON&XML Marshal/Unmarshal** + +```go +// Example of registering json-iterator +import jsoniter "github.com/json-iterator/go" + +json := jsoniter.ConfigCompatibleWithStandardLibrary + +client := req.C(). + SetJsonMarshal(json.Marshal). + SetJsonUnmarshal(json.Unmarshal) + +// Similarly, XML functions can also be customized +client.SetXmlMarshal(xmlMarshalFunc).SetXmlUnmarshal(xmlUnmarshalFunc) +``` + +**Disable Auto-Read Response Body** + +Response body will be read into memory if it's not a download request by default, you can disable it if you want (normally you don't need to do this). + +```go +client.DisableAutoReadResponse() + +resp, err := client.R().Get(url) +if err != nil { + log.Fatal(err) +} +io.Copy(dst, resp.Body) +``` + +## Custom Certificates + +```go +client := req.R() + +// Set root cert and client cert from file path +client.SetRootCertsFromFile("/path/to/root/certs/pemFile1.pem", "/path/to/root/certs/pemFile2.pem", "/path/to/root/certs/pemFile3.pem"). // Set root cert from one or more pem files + SetCertFromFile("/path/to/client/certs/client.pem", "/path/to/client/certs/client.key") // Set client cert and key cert file + +// You can also set root cert from string +client.SetRootCertFromString("-----BEGIN CERTIFICATE-----XXXXXX-----END CERTIFICATE-----") + +// And set client cert with +cert1, err := tls.LoadX509KeyPair("/path/to/client/certs/client.pem", "/path/to/client/certs/client.key") +if err != nil { + log.Fatalf("ERROR client certificate: %s", err) +} +// ... + +// you can add more certs if you want +client.SetCerts(cert1, cert2, cert3) +``` + +## Basic Auth and Bearer Token + +```go +client := req.C() + +// Set basic auth for all request +client.SetCommonBasicAuth("imroc", "123456") + +// Set bearer token for all request +client.SetCommonBearerAuthToken("MDc0ZTg5YmU4Yzc5MjAzZGJjM2ZiMzkz") + +// Set basic auth for a request, will override client's basic auth setting. +client.R().SetBasicAuth("myusername", "mypassword").Get("https://api.example.com/profile") + +// Set bearer token for a request, will override client's bearer token setting. +client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.example.com/profile") +``` + +## Download and Upload + +**Download** + +```go +// Create a client with default download direcotry +client := req.C().SetOutputDirectory("/path/to/download").EnableDumpNoResponseBody() + +// Download to relative file path, this will be downloaded +// to /path/to/download/test.jpg +client.R().SetOutputFile("test.jpg").Get(url) + +// Download to absolute file path, ignore the output directory +// setting from Client +client.R().SetOutputFile("/tmp/test.jpg").Get(url) + +// You can also save file to any `io.WriteCloser` +file, err := os.Create("/tmp/test.jpg") +if err != nil { + fmt.Println(err) + return +} +client.R().SetOutput(file).Get(url) +``` + +**Multipart Upload** + +```go +client := req.().EnableDumpNoRequestBody() // Request body contains unreadable binary, do not dump + +client.R().SetFile("pic", "test.jpg"). // Set form param name and filename + SetFile("pic", "/path/to/roc.png"). // Multiple files using the same form param name + SetFiles(map[string]string{ // Set multiple files using map + "exe": "test.exe", + "src": "main.go", + }). + SetFormData(map[string]string{ // Set from param using map + "name": "imroc", + "email": "roc@imroc.cc", + }). + SetFromDataFromValues(values). // You can also set form data using `url.Values` + Post("http://127.0.0.1:8888/upload") + +// You can also use io.Reader to upload +avatarImgFile, _ := os.Open("avatar.png") +client.R().SetFileReader("avatar", "avatar.png", avatarImgFile).Post(url) +*/ +``` + +## Auto-Decode + +`Req` detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default. + +Its principle is to detect `Content-Type` header at first, if it's not the text content type (json, xml, html and so on), `req` will not try to decode. If it is, then `req` will try to find the charset information. And `req` also will try to sniff the body's content to determine the charset if the charset information is not included in the header, if sniffed out and not utf-8, then decode it to utf-8 automatically, and `req` will not try to decode if the charset is not sure, just leave the body untouched. + +You can also disable it if you don't need or care a lot about performance: + +```go +client.DisableAutoDecode() +``` + +And also you can make some customization: + +```go +// Try to auto-detect and decode all content types (some server may return incorrect Content-Type header) +client.SetAutoDecodeAllContentType() + +// Only auto-detect and decode content which `Content-Type` header contains "html" or "json" +client.SetAutoDecodeContentType("html", "json") + +// Or you can customize the function to determine whether to decode +fn := func(contentType string) bool { + if regexContentType.MatchString(contentType) { + return true + } + return false +} +client.SetAutoDecodeContentTypeFunc(fn) +``` + +## Request and Response Middleware + +```go +client := req.C() + +// Registering Request Middleware +client.OnBeforeRequest(func(c *req.Client, r *req.Request) error { + // You can access Client and current Request object to do something + // as you need + + return nil // return nil if it is success + }) + +// Registering Response Middleware +client.OnAfterResponse(func(c *req.Client, r *req.Response) error { + // You can access Client and current Response object to do something + // as you need + + return nil // return nil if it is success + }) +``` + +## Redirect Policy + +```go +client := req.C().EnableDumpAllWithoutResponse() + +client.SetRedirectPolicy( + // Only allow up to 5 redirects + req.MaxRedirectPolicy(5), + // Only allow redirect to same domain. + // e.g. redirect "www.imroc.cc" to "imroc.cc" is allowed, but "google.com" is not + req.SameDomainRedirectPolicy(), +) + +client.SetRedirectPolicy( + // Only *.google.com/google.com and *.imroc.cc/imroc.cc is allowed to redirect + req.AllowedDomainRedirectPolicy("google.com", "imroc.cc"), + // Only allow redirect to same host. + // e.g. redirect "www.imroc.cc" to "imroc.cc" is not allowed, only "www.imroc.cc" is allowed + req.SameHostRedirectPolicy(), +) + +// All redirect is not allowd +client.SetRedirectPolicy(req.NoRedirectPolicy()) + +// Or customize the redirect with your own implementation +client.SetRedirectPolicy(func(req *http.Request, via []*http.Request) error { + // ... +}) +``` + +## Proxy + +`Req` use proxy `http.ProxyFromEnvironment` by default, which will read the `HTTP_PROXY/HTTPS_PROXY/http_proxy/https_proxy` environment variable, and setup proxy if environment variable is been set. You can customize it if you need: + +```go +// Set proxy from proxy url +client.SetProxyURL("http://myproxy:8080") + +// Custmize the proxy function with your own implementation +client.SetProxy(func(request *http.Request) (*url.URL, error) { + //... +}) + +// Disable proxy +client.SetProxy(nil) +``` + +## TODO List + +* [ ] Add tests. +* [ ] Wrap more transport settings into client. +* [ ] Support retry. +* [ ] Support unix socket. +* [ ] Support h2c. + +## License + +`Req` released under MIT license, refer [LICENSE](LICENSE) file. diff --git a/logger.go b/logger.go index e87e61a4..af461064 100644 --- a/logger.go +++ b/logger.go @@ -6,8 +6,8 @@ import ( "os" ) -// Logger interface is to abstract the logging from Resty. Gives control to -// the Resty users, choice of the logger. +// Logger is the abstract logging interface, gives control to +// the Req users, choice of the logger. type Logger interface { Errorf(format string, v ...interface{}) Warnf(format string, v ...interface{}) diff --git a/middleware.go b/middleware.go index 12b0e66e..e6be4f37 100644 --- a/middleware.go +++ b/middleware.go @@ -29,15 +29,12 @@ func escapeQuotes(s string) string { return quoteEscaper.Replace(s) } -func createMultipartHeader(param, fileName, contentType string) textproto.MIMEHeader { +func createMultipartHeader(file *FileUpload, contentType string) textproto.MIMEHeader { hdr := make(textproto.MIMEHeader) - var contentDispositionValue string - if util.IsStringEmpty(fileName) { - contentDispositionValue = fmt.Sprintf(`form-data; name="%s"`, param) - } else { - contentDispositionValue = fmt.Sprintf(`form-data; name="%s"; filename="%s"`, - param, escapeQuotes(fileName)) + contentDispositionValue := "form-data" + for k, v := range file.ContentDisposition { + contentDispositionValue += fmt.Sprintf(`; %s="%v"`, k, v) } hdr.Set("Content-Disposition", contentDispositionValue) @@ -53,16 +50,16 @@ func closeq(v interface{}) { } } -func writeMultipartFormFile(w *multipart.Writer, fieldName, fileName string, r io.Reader) error { - defer closeq(r) +func writeMultipartFormFile(w *multipart.Writer, file *FileUpload) error { + defer closeq(file.File) // Auto detect actual multipart content type cbuf := make([]byte, 512) - size, err := r.Read(cbuf) + size, err := file.File.Read(cbuf) if err != nil && err != io.EOF { return err } - pw, err := w.CreatePart(createMultipartHeader(fieldName, fileName, http.DetectContentType(cbuf))) + pw, err := w.CreatePart(createMultipartHeader(file, http.DetectContentType(cbuf))) if err != nil { return err } @@ -71,7 +68,7 @@ func writeMultipartFormFile(w *multipart.Writer, fieldName, fileName string, r i return err } - _, err = io.Copy(pw, r) + _, err = io.Copy(pw, file.File) return err } @@ -82,7 +79,7 @@ func writeMultiPart(r *Request, w *multipart.Writer, pw *io.PipeWriter) { } } for _, file := range r.uploadFiles { - writeMultipartFormFile(w, file.ParamName, file.FilePath, file.Reader) + writeMultipartFormFile(w, file) } w.Close() // close multipart to write tailer boundary pw.Close() // close pipe writer so that pipe reader could get EOF, and stop upload diff --git a/req.go b/req.go index 5760308d..11a6467e 100644 --- a/req.go +++ b/req.go @@ -12,8 +12,9 @@ const ( formContentType = "application/x-www-form-urlencoded" ) -type uploadFile struct { - ParamName string - FilePath string - io.Reader +type FileUpload struct { + ParamName string + FileName string + ContentDisposition map[string]interface{} + File io.Reader } diff --git a/request.go b/request.go index 714cfa14..2b7d4c08 100644 --- a/request.go +++ b/request.go @@ -10,11 +10,14 @@ import ( "net/http" urlpkg "net/url" "os" + "path/filepath" "strings" "time" ) -// Request is the http request +// Request struct is used to compose and fire individual request from +// req client. Request provides lots of chainable settings which can +// override client level settings. type Request struct { URL string PathParams map[string]string @@ -33,7 +36,7 @@ type Request struct { marshalBody interface{} ctx context.Context isMultiPart bool - uploadFiles []*uploadFile + uploadFiles []*FileUpload uploadReader []io.ReadCloser outputFile string isSaveResponse bool @@ -42,7 +45,8 @@ type Request struct { dumpBuffer *bytes.Buffer } -// TraceInfo returns the trace information, only available when trace is enabled. +// TraceInfo returns the trace information, only available if trace is enabled +// (see Request.EnableTrace and Client.EnableTraceAll). func (r *Request) TraceInfo() TraceInfo { ct := r.trace @@ -96,7 +100,8 @@ func SetFormDataFromValues(data urlpkg.Values) *Request { return defaultClient.R().SetFormDataFromValues(data) } -// SetFormDataFromValues set the form data from url.Values, not used if method not allow payload. +// SetFormDataFromValues set the form data from url.Values, will not +// been used if request method does not allow payload. func (r *Request) SetFormDataFromValues(data urlpkg.Values) *Request { if r.FormData == nil { r.FormData = urlpkg.Values{} @@ -115,7 +120,8 @@ func SetFormData(data map[string]string) *Request { return defaultClient.R().SetFormData(data) } -// SetFormData set the form data from map, not used if method not allow payload. +// SetFormData set the form data from a map, will not been used +// if request method does not allow payload. func (r *Request) SetFormData(data map[string]string) *Request { if r.FormData == nil { r.FormData = urlpkg.Values{} @@ -132,7 +138,7 @@ func SetCookies(cookies ...*http.Cookie) *Request { return defaultClient.R().SetCookies(cookies...) } -// SetCookies set cookies at request level. +// SetCookies set http cookies for the request. func (r *Request) SetCookies(cookies ...*http.Cookie) *Request { r.Cookies = append(r.Cookies, cookies...) return r @@ -144,7 +150,8 @@ func SetQueryString(query string) *Request { return defaultClient.R().SetQueryString(query) } -// SetQueryString set URL query parameters using the raw query string. +// SetQueryString set URL query parameters for the request using +// raw query string. func (r *Request) SetQueryString(query string) *Request { params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) if err != nil { @@ -168,25 +175,39 @@ func SetFileReader(paramName, filePath string, reader io.Reader) *Request { return defaultClient.R().SetFileReader(paramName, filePath, reader) } -// SetFileReader sets up a multipart form with a reader to upload file. -func (r *Request) SetFileReader(paramName, filePath string, reader io.Reader) *Request { +// SetFileReader set up a multipart form with a reader to upload file. +func (r *Request) SetFileReader(paramName, filename string, reader io.Reader) *Request { r.isMultiPart = true - r.uploadFiles = append(r.uploadFiles, &uploadFile{ - ParamName: paramName, - FilePath: filePath, - Reader: reader, + contentDisposition := map[string]interface{}{ + "name": paramName, + "filename": filename, + } + r.uploadFiles = append(r.uploadFiles, &FileUpload{ + ContentDisposition: contentDisposition, + File: reader, }) return r } +// SetFileBytes is a global wrapper methods which delegated +// to the default client, create a request and SetFileBytes for request. +func SetFileBytes(paramName, filename string, content []byte) *Request { + return defaultClient.R().SetFileBytes(paramName, filename, content) +} + +// SetFileBytes set up a multipart form with given []byte to upload. +func (r *Request) SetFileBytes(paramName, filename string, content []byte) *Request { + return r.SetFileReader(paramName, filename, bytes.NewReader(content)) +} + // SetFiles is a global wrapper methods which delegated // to the default client, create a request and SetFiles for request. func SetFiles(files map[string]string) *Request { return defaultClient.R().SetFiles(files) } -// SetFiles sets up a multipart form from a map, which key is the param -// name, value is the file path. +// SetFiles set up a multipart form from a map to upload, which +// key is the parameter name, and value is the file path. func (r *Request) SetFiles(files map[string]string) *Request { for k, v := range files { r.SetFile(k, v) @@ -200,7 +221,8 @@ func SetFile(paramName, filePath string) *Request { return defaultClient.R().SetFile(paramName, filePath) } -// SetFile sets up a multipart form, read file from filePath automatically to upload. +// SetFile set up a multipart form from file path to upload, +// which read file from filePath automatically to upload. func (r *Request) SetFile(paramName, filePath string) *Request { r.isMultiPart = true file, err := os.Open(filePath) @@ -209,11 +231,18 @@ func (r *Request) SetFile(paramName, filePath string) *Request { r.appendError(err) return r } - r.uploadFiles = append(r.uploadFiles, &uploadFile{ - ParamName: paramName, - FilePath: filePath, - Reader: file, - }) + return r.SetFileReader(paramName, filepath.Base(filePath), file) +} + +// SetFileUpload is a global wrapper methods which delegated +// to the default client, create a request and SetFileUpload for request. +func SetFileUpload(f FileUpload) *Request { + return defaultClient.R().SetFileUpload(f) +} + +// SetFileUpload set the fully custimized multipart file upload options. +func (r *Request) SetFileUpload(f FileUpload) *Request { + r.uploadFiles = append(r.uploadFiles, &f) return r } @@ -249,7 +278,7 @@ func SetBearerAuthToken(token string) *Request { return defaultClient.R().SetBearerAuthToken(token) } -// SetBearerAuthToken set the bearer auth token at request level. +// SetBearerAuthToken set bearer auth token for the request. func (r *Request) SetBearerAuthToken(token string) *Request { return r.SetHeader("Authorization", "Bearer "+token) } @@ -260,7 +289,7 @@ func SetBasicAuth(username, password string) *Request { return defaultClient.R().SetBasicAuth(username, password) } -// SetBasicAuth set the basic auth at request level. +// SetBasicAuth set basic auth for the request. func (r *Request) SetBasicAuth(username, password string) *Request { return r.SetHeader("Authorization", util.BasicAuthHeaderValue(username, password)) } @@ -271,7 +300,7 @@ func SetHeaders(hdrs map[string]string) *Request { return defaultClient.R().SetHeaders(hdrs) } -// SetHeaders set the header at request level. +// SetHeaders set headers from a map for the request. func (r *Request) SetHeaders(hdrs map[string]string) *Request { for k, v := range hdrs { r.SetHeader(k, v) @@ -285,7 +314,7 @@ func SetHeader(key, value string) *Request { return defaultClient.R().SetHeader(key, value) } -// SetHeader set a header at request level. +// SetHeader set a header for the request. func (r *Request) SetHeader(key, value string) *Request { if r.Headers == nil { r.Headers = make(http.Header) @@ -300,7 +329,7 @@ func SetOutputFile(file string) *Request { return defaultClient.R().SetOutputFile(file) } -// SetOutputFile the file that response body will be downloaded to. +// SetOutputFile set the file that response body will be downloaded to. func (r *Request) SetOutputFile(file string) *Request { r.isSaveResponse = true r.outputFile = file @@ -313,7 +342,7 @@ func SetOutput(output io.Writer) *Request { return defaultClient.R().SetOutput(output) } -// SetOutput the io.Writer that response body will be downloaded to. +// SetOutput set the io.Writer that response body will be downloaded to. func (r *Request) SetOutput(output io.Writer) *Request { r.output = output r.isSaveResponse = true @@ -326,7 +355,7 @@ func SetQueryParams(params map[string]string) *Request { return defaultClient.R().SetQueryParams(params) } -// SetQueryParams sets the URL query parameters with a map at client level. +// SetQueryParams set URL query parameters from a map for the request. func (r *Request) SetQueryParams(params map[string]string) *Request { for k, v := range params { r.SetQueryParam(k, v) @@ -340,8 +369,7 @@ func SetQueryParam(key, value string) *Request { return defaultClient.R().SetQueryParam(key, value) } -// SetQueryParam set an URL query parameter with a key-value -// pair at request level. +// SetQueryParam set an URL query parameter for the request. func (r *Request) SetQueryParam(key, value string) *Request { if r.QueryParams == nil { r.QueryParams = make(urlpkg.Values) @@ -356,8 +384,7 @@ func AddQueryParam(key, value string) *Request { return defaultClient.R().AddQueryParam(key, value) } -// AddQueryParam add a URL query parameter with a key-value -// pair at request level. +// AddQueryParam add a URL query parameter for the request. func (r *Request) AddQueryParam(key, value string) *Request { if r.QueryParams == nil { r.QueryParams = make(urlpkg.Values) @@ -372,7 +399,7 @@ func SetPathParams(params map[string]string) *Request { return defaultClient.R().SetPathParams(params) } -// SetPathParams sets the URL path parameters from a map at request level. +// SetPathParams set URL path parameters from a map for the request. func (r *Request) SetPathParams(params map[string]string) *Request { for key, value := range params { r.SetPathParam(key, value) @@ -386,7 +413,7 @@ func SetPathParam(key, value string) *Request { return defaultClient.R().SetPathParam(key, value) } -// SetPathParam sets the URL path parameters from a key-value paire at request level. +// SetPathParam set a URL path parameter for the request. func (r *Request) SetPathParam(key, value string) *Request { if r.PathParams == nil { r.PathParams = make(map[string]string) @@ -399,7 +426,8 @@ func (r *Request) appendError(err error) { r.error = multierror.Append(r.error, err) } -// Send sends the http request. +// Send fires http request and return the *Response which is always +// not nil, and the error is not nil if some error happens. func (r *Request) Send(method, url string) (*Response, error) { if r.error != nil { return &Response{}, r.error @@ -415,7 +443,8 @@ func MustGet(url string) *Response { return defaultClient.R().MustGet(url) } -// MustGet like Get, panic if error happens, should only be used to test without error handling. +// MustGet like Get, panic if error happens, should only be used to +// test without error handling. func (r *Request) MustGet(url string) *Response { resp, err := r.Get(url) if err != nil { @@ -430,7 +459,7 @@ func Get(url string) (*Response, error) { return defaultClient.R().Get(url) } -// Get Send the request with GET method and specified url. +// Get fires http request with GET method and the specified URL. func (r *Request) Get(url string) (*Response, error) { return r.Send(http.MethodGet, url) } @@ -441,7 +470,8 @@ func MustPost(url string) *Response { return defaultClient.R().MustPost(url) } -// MustPost like Post, panic if error happens. +// MustPost like Post, panic if error happens. should only be used to +// test without error handling. func (r *Request) MustPost(url string) *Response { resp, err := r.Post(url) if err != nil { @@ -456,7 +486,7 @@ func Post(url string) (*Response, error) { return defaultClient.R().Post(url) } -// Post Send the request with POST method and specified url. +// Post fires http request with POST method and the specified URL. func (r *Request) Post(url string) (*Response, error) { return r.Send(http.MethodPost, url) } @@ -467,7 +497,8 @@ func MustPut(url string) *Response { return defaultClient.R().MustPut(url) } -// MustPut like Put, panic if error happens, should only be used to test without error handling. +// MustPut like Put, panic if error happens, should only be used to +// test without error handling. func (r *Request) MustPut(url string) *Response { resp, err := r.Put(url) if err != nil { @@ -482,7 +513,7 @@ func Put(url string) (*Response, error) { return defaultClient.R().Put(url) } -// Put Send the request with Put method and specified url. +// Put fires http request with PUT method and the specified URL. func (r *Request) Put(url string) (*Response, error) { return r.Send(http.MethodPut, url) } @@ -493,7 +524,8 @@ func MustPatch(url string) *Response { return defaultClient.R().MustPatch(url) } -// MustPatch like Patch, panic if error happens, should only be used to test without error handling. +// MustPatch like Patch, panic if error happens, should only be used +// to test without error handling. func (r *Request) MustPatch(url string) *Response { resp, err := r.Patch(url) if err != nil { @@ -508,7 +540,7 @@ func Patch(url string) (*Response, error) { return defaultClient.R().Patch(url) } -// Patch Send the request with PATCH method and specified url. +// Patch fires http request with PATCH method and the specified URL. func (r *Request) Patch(url string) (*Response, error) { return r.Send(http.MethodPatch, url) } @@ -519,7 +551,8 @@ func MustDelete(url string) *Response { return defaultClient.R().MustDelete(url) } -// MustDelete like Delete, panic if error happens, should only be used to test without error handling. +// MustDelete like Delete, panic if error happens, should only be used +// to test without error handling. func (r *Request) MustDelete(url string) *Response { resp, err := r.Delete(url) if err != nil { @@ -534,7 +567,7 @@ func Delete(url string) (*Response, error) { return defaultClient.R().Delete(url) } -// Delete Send the request with DELETE method and specified url. +// Delete fires http request with DELETE method and the specified URL. func (r *Request) Delete(url string) (*Response, error) { return r.Send(http.MethodDelete, url) } @@ -545,7 +578,8 @@ func MustOptions(url string) *Response { return defaultClient.R().MustOptions(url) } -// MustOptions like Options, panic if error happens, should only be used to test without error handling. +// MustOptions like Options, panic if error happens, should only be +// used to test without error handling. func (r *Request) MustOptions(url string) *Response { resp, err := r.Options(url) if err != nil { @@ -560,7 +594,7 @@ func Options(url string) (*Response, error) { return defaultClient.R().Options(url) } -// Options Send the request with OPTIONS method and specified url. +// Options fires http request with OPTIONS method and the specified URL. func (r *Request) Options(url string) (*Response, error) { return r.Send(http.MethodOptions, url) } @@ -571,7 +605,8 @@ func MustHead(url string) *Response { return defaultClient.R().MustHead(url) } -// MustHead like Head, panic if error happens, should only be used to test without error handling. +// MustHead like Head, panic if error happens, should only be used +// to test without error handling. func (r *Request) MustHead(url string) *Response { resp, err := r.Send(http.MethodHead, url) if err != nil { @@ -586,7 +621,7 @@ func Head(url string) (*Response, error) { return defaultClient.R().Head(url) } -// Head Send the request with HEAD method and specified url. +// Head fires http request with HEAD method and the specified URL. func (r *Request) Head(url string) (*Response, error) { return r.Send(http.MethodHead, url) } diff --git a/response.go b/response.go index b3f31f49..e0d27d44 100644 --- a/response.go +++ b/response.go @@ -124,6 +124,9 @@ func (r *Response) ToBytes() ([]byte, error) { if r.body != nil { return r.body, nil } + if r.Response == nil || r.Response.Body == nil { + return []byte{}, nil + } defer r.Body.Close() body, err := ioutil.ReadAll(r.Body) r.setReceivedAt() From 38337e090f72a5c22429274eaa259fab4d3f2246 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Feb 2022 17:21:45 +0800 Subject: [PATCH 266/843] remove trash --- examples/upload/uploadserver/LICENSE | 21 - examples/upload/uploadserver/README.md | 877 ------------------------- 2 files changed, 898 deletions(-) delete mode 100644 examples/upload/uploadserver/LICENSE delete mode 100644 examples/upload/uploadserver/README.md diff --git a/examples/upload/uploadserver/LICENSE b/examples/upload/uploadserver/LICENSE deleted file mode 100644 index 70f3d40a..00000000 --- a/examples/upload/uploadserver/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2017-2022 roc - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file diff --git a/examples/upload/uploadserver/README.md b/examples/upload/uploadserver/README.md deleted file mode 100644 index 5bc311a6..00000000 --- a/examples/upload/uploadserver/README.md +++ /dev/null @@ -1,877 +0,0 @@ -

-

Req

-

Simplified Golang HTTP client library with Black Magic, Less Code and More Efficiency.

-

-

- -## News - -Brand-New version v3 is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) - -If you want to use the older version, check it out on [v1 branch](https://github.com/imroc/req/tree/v1). - -> v2 is a transitional version, due to some breaking changes were introduced during optmize user experience, checkout [v2 branch](https://github.com/imroc/req/tree/v2) if you want. - -## Table of Contents - -* [Features](#Features) -* [Quick Start](#Quick-Start) -* [API Reference](#API) -* [Debugging - Dump/Log/Trace](#Debugging) -* [Quick HTTP Test](#Test) -* [HTTP2 and HTTP1](#HTTP2-HTTP1) -* [URL Path and Query Parameter](#Param) -* [Form Data](#Form) -* [Header and Cookie](#Header-Cookie) -* [Body and Marshal/Unmarshal](#Body) -* [Custom Certificates](#Cert) -* [Basic Auth and Bearer Token](#Auth) -* [Download and Upload](#Download-Upload) -* [Auto-Decode](#AutoDecode) -* [Request and Response Middleware](#Middleware) -* [Redirect Policy](#Redirect) -* [Proxy](#Proxy) -* [TODO List](#TODO) -* [API Reference](#API) -* [License](#License) - -## Features - -* Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. -* Powerful and convenient debug utilites, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging - Dump/Log/Trace](#Debugging). -* Easy making HTTP test with code instead of tools like curl or postman, `req` provide global wrapper methods and `MustXXX` to test API with minimal code (see [Quick HTTP Test](#Test)). -* Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support, and you can also force `HTTP/1.1` if you want (see [HTTP2 and HTTP1](#HTTP2-HTTP1)). -* Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decode](#AutoDecode)). -* Automatic marshal and unmarshal for JSON and XML content type and fully customizable (see [Body and Marshal/Unmarshal](#Body)). -* Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. -* Easy [Download and Upload](#Download-Upload). -* Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token for both client and request level. -* Easy set timeout, proxy, certs, redirect policy, cookie jar, compression, keepalives etc for client. -* Support middleware before request sent and after got response (see [Request and Response Middleware](#Middleware)). - -## Quick Start - -**Install** - -``` sh -go get github.com/imroc/req/v3 -``` - -**Import** - -```go -import "github.com/imroc/req/v3" -``` - -```go -// For test, you can create and send a request with the global default -// client, use DevMode to see all details, try and suprise :) -req.DevMode() -req.Get("https://api.github.com/users/imroc") - -// Create and send a request with the custom client and settings -client := req.C(). // Use C() to create a client - SetUserAgent("my-custom-client"). // Chainable client settings - SetTimeout(5 * time.Second). - DevMode() -resp, err := client.R(). // Use R() to create a request - SetHeader("Accept", "application/vnd.github.v3+json"). // Chainable request settings - SetPathParam("username", "imroc"). - SetQueryParam("page", "1"). - SetResult(&result). - Get("https://api.github.com/users/{username}/repos") -``` - -Checkout more runnable examples in the [examples](examples) direcotry. - -## API Reference - -Checkout [Req API Reference](docs/api.md) for a brief and categorized list of the core APIs, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). - -## Debugging - Dump/Log/Trace - -**Dump the Content** - -```go -// Enable dump at client level, which will dump for all requests, -// including all content of request and response and output -// to stdout by default. -client := req.C().EnableDumpAll() -client.R().Get("https://httpbin.org/get") - -/* Output -:authority: httpbin.org -:method: GET -:path: /get -:scheme: https -user-agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36 -accept-encoding: gzip - -:status: 200 -date: Wed, 26 Jan 2022 06:39:20 GMT -content-type: application/json -content-length: 372 -server: gunicorn/19.9.0 -access-control-allow-origin: * -access-control-allow-credentials: true - -{ - "args": {}, - "headers": { - "Accept-Encoding": "gzip", - "Host": "httpbin.org", - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36", - "X-Amzn-Trace-Id": "Root=1-61f0ec98-5958c02662de26e458b7672b" - }, - "origin": "103.7.29.30", - "url": "https://httpbin.org/get" -} -*/ - -// Customize dump settings with predefined and convenient settings at client level. -client.EnableDumpAllWithoutBody(). // Only dump the header of request and response - EnableDumpAllAsync(). // Dump asynchronously to improve performance - EnableDumpAllToFile("reqdump.log") // Dump to file without printing it out -// Send request to see the content that have been dumped -client.R().Get(url) - -// Enable dump with fully customized settings at client level. -opt := &req.DumpOptions{ - Output: os.Stdout, - RequestHeader: true, - ResponseBody: true, - RequestBody: false, - ResponseHeader: false, - Async: false, - } -client.SetCommonDumpOptions(opt).EnableDumpAll() -client.R().Get("https://httpbin.org/get") - -// Change settings dynamiclly -opt.ResponseBody = false -client.R().Get("https://httpbin.org/get") - -// You can also enable dump at request level, which will not override client-level dumping, -// dump to the internal buffer and will not print it out by default, you can call `Response.Dump()` -// to get the dump result and print only if you want to, typically used in production, only record -// the content of the request when the request is abnormal to help us troubleshoot problems. -resp, err := client.R().EnableDump().SetBody("test body").Post("https://httpbin.org/post") -if err != nil { - fmt.Println("err:", err) - if resp.Dump() != "" { - fmt.Println("raw content:") - fmt.Println(resp.Dump()) - } - return -} -if !resp.IsSuccess() { // Status code not beetween 200 and 299 - fmt.Println("bad status:", resp.Status) - fmt.Println("raw content:") - fmt.Println(resp.Dump()) - return -} - -// Similarly, also support to customize dump settings with the predefined and convenient settings at request level. -resp, err = client.R().EnableDumpWithoutRequest().SetBody("test body").Post("https://httpbin.org/post") -// ... -resp, err = client.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post("https://httpbin.org/post") -``` - -**Enable DebugLog for Deeper Insights** - -```go -// Logging is enabled by default, but only output the warning and error message. -// Use `EnableDebugLog` to enable debug level logging. -client := req.C().EnableDebugLog() -client.R().Get("http://baidu.com/s?wd=req") -/* Output -2022/01/26 15:46:29.279368 DEBUG [req] GET http://baidu.com/s?wd=req -2022/01/26 15:46:29.469653 DEBUG [req] charset iso-8859-1 detected in Content-Type, auto-decode to utf-8 -2022/01/26 15:46:29.469713 DEBUG [req] GET http://www.baidu.com/s?wd=req -... -*/ - -// SetLogger with nil to disable all log -client.SetLogger(nil) - -// Or customize the logger with your own implementation. -client.SetLogger(logger) -``` - -**Enable Trace to Analyze Performance** - -```go -// Enable trace at request level -client := req.C() -resp, err := client.R().EnableTrace().Get("https://api.github.com/users/imroc") -if err != nil { - log.Fatal(err) -} -trace := resp.TraceInfo() // Use `resp.Request.TraceInfo()` to avoid unnecessary struct copy in production. -fmt.Println(trace.Blame()) // Print out exactly where the http request is slowing down. -fmt.Println("----------") -fmt.Println(trace) // Print details - -/* Output -the request total time is 2.562416041s, and costs 1.289082208s from connection ready to server respond frist byte --------- -TotalTime : 2.562416041s -DNSLookupTime : 445.246375ms -TCPConnectTime : 428.458µs -TLSHandshakeTime : 825.888208ms -FirstResponseTime : 1.289082208s -ResponseTime : 1.712375ms -IsConnReused: : false -RemoteAddr : 98.126.155.187:443 -*/ - -// Enable trace at client level -client.EnableTraceAll() -resp, err = client.R().Get(url) -// ... -``` - -**DevMode** - -If you want to enable all debug features (dump, debug log and tracing), just call `DevMode()`: - -```go -client := req.C().DevMode() -client.R().Get("https://imroc.cc") -``` - -## Quick HTTP Test - -**Test with Global Wrapper Methods** - -`req` wrap methods of both `Client` and `Request` with global methods, which is delegated to the default client behind the scenes, so you can just treat the package name `req` as a Client or Request to test quickly without create one explicitly. - -```go -// Call the global methods just like the Client's method, -// so you can treat package name `req` as a Client, and -// you don't need to create any client explicitly. -req.SetTimeout(5 * time.Second). - SetCommonBasicAuth("imroc", "123456"). - SetCommonHeader("Accept", "text/xml"). - SetUserAgent("my api client"). - DevMode() - -// Call the global method just like the Request's method, -// which will create request automatically using the default -// client, so you can treat package name `req` as a Request, -// and you don't need to create any request and client explicitly. -req.SetQueryParam("page", "2"). - SetHeader("Accept", "application/json"). // Override client level settings at request level. - Get("https://httpbin.org/get") -``` - -**Test with MustXXX** - -Use `MustXXX` to ignore error handling during test, make it possible to complete a complex test with just one line of code: - -```go -fmt.Println(req.DevMode().R().MustGet("https://imroc.cc").TraceInfo()) -``` - -## HTTP2 and HTTP1 - -Req works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support, which is negotiated by TLS handshake. - -You can force using `HTTP/1.1` if you want. - -```go -client := req.C().EnableForceHTTP1().EnableDumpAllWithoutBody() -client.R().MustGet("https://httpbin.org/get") -/* Output -GET /get HTTP/1.1 -Host: httpbin.org -User-Agent: req/v3 (https://github.com/imroc/req) -Accept-Encoding: gzip - -HTTP/1.1 200 OK -Date: Tue, 08 Feb 2022 02:30:18 GMT -Content-Type: application/json -Content-Length: 289 -Connection: keep-alive -Server: gunicorn/19.9.0 -Access-Control-Allow-Origin: * -Access-Control-Allow-Credentials: true -*/ -``` - -And also you can force using `HTTP/2` if you want, will return error if server does not support: - -```go -client := req.C().EnableForceHTTP2() -client.R().MustGet("https://baidu.com") -/* Output -panic: Get "https://baidu.com": server does not support http2, you can use http/1.1 which is supported -*/ -``` - -## URL Path and Query Parameter - -**Path Parameter** - -Use `SetPathParam` or `SetPathParams` to replace variable in the url path: - -```go -client := req.C().DevMode() - -client.R(). - SetPathParam("owner", "imroc"). // Set a path param, which will replace the vairable in url path - SetPathParams(map[string]string{ // Set multiple path params at once - "repo": "req", - "path": "README.md", - }).Get("https://api.github.com/repos/{owner}/{repo}/contents/{path}") // path parameter will replace path variable in the url -/* Output -2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md -... -*/ - -// You can also set the common PathParam for every request on client -client.SetCommonPathParam(k1, v1).SetCommonPathParams(pathParams) - -resp1, err := client.Get(url1) -... - -resp2, err := client.Get(url2) -... -``` - -**Query Parameter** - -Use `SetQueryParam`, `SetQueryParams` or `SetQueryString` to append url query parameter: - -```go -client := req.C().DevMode() - -// Set query parameter at request level. -client.R(). - SetQueryParam("a", "a"). // Set a query param, which will be encoded as query parameter in url - SetQueryParams(map[string]string{ // Set multiple query params at once - "b": "b", - "c": "c", - }).SetQueryString("d=d&e=e"). // Set query params as a raw query string - Get("https://api.github.com/repos/imroc/req/contents/README.md?x=x") -/* Output -2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md?x=x&a=a&b=b&c=c&d=d&e=e -... -*/ - -// You can also set the query parameter at client level. -client.SetCommonQueryParam(k, v). - SetCommonQueryParams(queryParams). - SetCommonQueryString(queryString). - -resp1, err := client.Get(url1) -... -resp2, err := client.Get(url2) -... - -// Add query parameter with multiple values at request level. -client.R().AddQueryParam("key", "value1").AddQueryParam("key", "value2").Get("https://httpbin.org/get") -/* Output -2022/02/05 08:49:26.260780 DEBUG [req] GET https://httpbin.org/get?key=value1&key=value2 -... - */ - - -// Multiple values also supported at client level. -client.AddCommonQueryParam("key", "value1").AddCommonQueryParam("key", "value2") -``` - -## Form Data - -```go -client := req.C().EnableDumpAllWithoutResponse() -client.R().SetFormData(map[string]string{ - "username": "imroc", - "blog": "https://imroc.cc", -}).Post("https://httpbin.org/post") -/* Output -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -content-type: application/x-www-form-urlencoded -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -blog=https%3A%2F%2Fimroc.cc&username=imroc -*/ - -// Multi value form data -v := url.Values{ - "multi": []string{"a", "b", "c"}, -} -client.R().SetFormDataFromValues(v).Post("https://httpbin.org/post") -/* Output -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -content-type: application/x-www-form-urlencoded -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -multi=a&multi=b&multi=c -*/ - -// You can also set form data in client level -client.SetCommonFormData(m) -client.SetCommonFormDataFromValues(v) -``` - -> `GET`, `HEAD`, and `OPTIONS` requests ignores form data by default - -## Header and Cookie - -**Set Header** - -```go -// Let's dump the header to see what's going on -client := req.C().EnableForceHTTP1().EnableDumpAllWithoutResponse() - -// Send a request with multiple headers and cookies -client.R(). - SetHeader("Accept", "application/json"). // Set one header - SetHeaders(map[string]string{ // Set multiple headers at once - "My-Custom-Header": "My Custom Value", - "User": "imroc", - }).Get("https://httpbin.org/get") - -/* Output -GET /get HTTP/1.1 -Host: httpbin.org -User-Agent: req/v3 (https://github.com/imroc/req) -Accept: application/json -My-Custom-Header: My Custom Value -User: imroc -Accept-Encoding: gzip -*/ - -// You can also set the common header and cookie for every request on client. -client.SetCommonHeader(header).SetCommonHeaders(headers) - -resp1, err := client.R().Get(url1) -... -resp2, err := client.R().Get(url2) -... -``` - -**Set Cookie** - -```go -// Let's dump the header to see what's going on -client := req.C().EnableForceHTTP1().EnableDumpAllWithoutResponse() - -// Send a request with multiple headers and cookies -client.R(). - SetCookies( - &http.Cookie{ - Name: "testcookie1", - Value: "testcookie1 value", - Path: "/", - Domain: "baidu.com", - MaxAge: 36000, - HttpOnly: false, - Secure: true, - }, - &http.Cookie{ - Name: "testcookie2", - Value: "testcookie2 value", - Path: "/", - Domain: "baidu.com", - MaxAge: 36000, - HttpOnly: false, - Secure: true, - }, - ).Get("https://httpbin.org/get") - -/* Output -GET /get HTTP/1.1 -Host: httpbin.org -User-Agent: req/v3 (https://github.com/imroc/req) -Cookie: testcookie1="testcookie1 value"; testcookie2="testcookie2 value" -Accept-Encoding: gzip -*/ - -// You can also set the common cookie for every request on client. -client.SetCommonCookies(cookie1, cookie2, cookie3) - -resp1, err := client.R().Get(url1) -... -resp2, err := client.R().Get(url2) -``` - -You can also customize the CookieJar: -```go -// Set your own http.CookieJar implementation -client.SetCookieJar(jar) - -// Set to nil to disable CookieJar -client.SetCookieJar(nil) -``` - -## Body and Marshal/Unmarshal - -**Request Body** - -```go -// Create a client that dump request -client := req.C().EnableDumpAllWithoutResponse() -// SetBody accepts string, []byte, io.Reader, use type assertion to -// determine the data type of body automatically. -client.R().SetBody("test").Post("https://httpbin.org/post") -/* Output -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -test -*/ - -// If it cannot determine, like map and struct, then it will wait -// and marshal to JSON or XML automatically according to the `Content-Type` -// header that have been set before or after, default to json if not set. -type User struct { - Name string `json:"name"` - Email string `json:"email"` -} -user := &User{Name: "imroc", Email: "roc@imroc.cc"} -client.R().SetBody(user).Post("https://httpbin.org/post") -/* Output -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -content-type: application/json; charset=utf-8 -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -{"name":"imroc","email":"roc@imroc.cc"} -*/ - - -// You can use more specific methods to avoid type assertions and improves performance, -client.R().SetBodyJsonString(`{"username": "imroc"}`).Post("https://httpbin.org/post") -/* -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -content-type: application/json; charset=utf-8 -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -{"username": "imroc"} -*/ - -// Marshal body and set `Content-Type` automatically without any guess -cient.R().SetBodyXmlMarshal(user).Post("https://httpbin.org/post") -/* Output -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -content-type: text/xml; charset=utf-8 -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -imrocroc@imroc.cc -*/ -``` - -**Response Body** - -```go -// Define success body struct -type User struct { - Name string `json:"name"` - Blog string `json:"blog"` -} -// Define error body struct -type ErrorMessage struct { - Message string `json:"message"` -} -// Create a client and dump body to see details -client := req.C().EnableDumpAllWithoutHeader() - -// Send a request and unmarshal result automatically according to -// response `Content-Type` -user := &User{} -errMsg := &ErrorMessage{} -resp, err := client.R(). - SetResult(user). // Set success result - SetError(errMsg). // Set error result - Get("https://api.github.com/users/imroc") -if err != nil { - log.Fatal(err) -} -fmt.Println("----------") - -if resp.IsSuccess() { // status `code >= 200 and <= 299` is considered as success - // Must have been marshaled to user if no error returned before - fmt.Printf("%s's blog is %s\n", user.Name, user.Blog) -} else if resp.IsError() { // status `code >= 400` is considered as error - // Must have been marshaled to errMsg if no error returned before - fmt.Println("got error:", errMsg.Message) -} else { - log.Fatal("unknown http status:", resp.Status) -} -/* Output -{"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":129,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2022-01-24T23:32:53Z"} ----------- -roc's blog is https://imroc.cc -*/ - -// Or you can also unmarshal response later -if resp.IsSuccess() { - err = resp.Unmarshal(user) - if err != nil { - log.Fatal(err) - } - fmt.Printf("%s's blog is %s\n", user.Name, user.Blog) -} else { - fmt.Println("bad response:", resp) -} - -// Also, you can get the raw response and Unmarshal by yourself -yaml.Unmarshal(resp.Bytes()) -``` - -**Customize JSON&XML Marshal/Unmarshal** - -```go -// Example of registering json-iterator -import jsoniter "github.com/json-iterator/go" - -json := jsoniter.ConfigCompatibleWithStandardLibrary - -client := req.C(). - SetJsonMarshal(json.Marshal). - SetJsonUnmarshal(json.Unmarshal) - -// Similarly, XML functions can also be customized -client.SetXmlMarshal(xmlMarshalFunc).SetXmlUnmarshal(xmlUnmarshalFunc) -``` - -**Disable Auto-Read Response Body** - -Response body will be read into memory if it's not a download request by default, you can disable it if you want (normally you don't need to do this). - -```go -client.DisableAutoReadResponse() - -resp, err := client.R().Get(url) -if err != nil { - log.Fatal(err) -} -io.Copy(dst, resp.Body) -``` - -## Custom Certificates - -```go -client := req.R() - -// Set root cert and client cert from file path -client.SetRootCertsFromFile("/path/to/root/certs/pemFile1.pem", "/path/to/root/certs/pemFile2.pem", "/path/to/root/certs/pemFile3.pem"). // Set root cert from one or more pem files - SetCertFromFile("/path/to/client/certs/client.pem", "/path/to/client/certs/client.key") // Set client cert and key cert file - -// You can also set root cert from string -client.SetRootCertFromString("-----BEGIN CERTIFICATE-----XXXXXX-----END CERTIFICATE-----") - -// And set client cert with -cert1, err := tls.LoadX509KeyPair("/path/to/client/certs/client.pem", "/path/to/client/certs/client.key") -if err != nil { - log.Fatalf("ERROR client certificate: %s", err) -} -// ... - -// you can add more certs if you want -client.SetCerts(cert1, cert2, cert3) -``` - -## Basic Auth and Bearer Token - -```go -client := req.C() - -// Set basic auth for all request -client.SetCommonBasicAuth("imroc", "123456") - -// Set bearer token for all request -client.SetCommonBearerAuthToken("MDc0ZTg5YmU4Yzc5MjAzZGJjM2ZiMzkz") - -// Set basic auth for a request, will override client's basic auth setting. -client.R().SetBasicAuth("myusername", "mypassword").Get("https://api.example.com/profile") - -// Set bearer token for a request, will override client's bearer token setting. -client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.example.com/profile") -``` - -## Download and Upload - -**Download** - -```go -// Create a client with default download direcotry -client := req.C().SetOutputDirectory("/path/to/download").EnableDumpNoResponseBody() - -// Download to relative file path, this will be downloaded -// to /path/to/download/test.jpg -client.R().SetOutputFile("test.jpg").Get(url) - -// Download to absolute file path, ignore the output directory -// setting from Client -client.R().SetOutputFile("/tmp/test.jpg").Get(url) - -// You can also save file to any `io.WriteCloser` -file, err := os.Create("/tmp/test.jpg") -if err != nil { - fmt.Println(err) - return -} -client.R().SetOutput(file).Get(url) -``` - -**Multipart Upload** - -```go -client := req.().EnableDumpNoRequestBody() // Request body contains unreadable binary, do not dump - -client.R().SetFile("pic", "test.jpg"). // Set form param name and filename - SetFile("pic", "/path/to/roc.png"). // Multiple files using the same form param name - SetFiles(map[string]string{ // Set multiple files using map - "exe": "test.exe", - "src": "main.go", - }). - SetFormData(map[string]string{ // Set from param using map - "name": "imroc", - "email": "roc@imroc.cc", - }). - SetFromDataFromValues(values). // You can also set form data using `url.Values` - Post("http://127.0.0.1:8888/upload") - -// You can also use io.Reader to upload -avatarImgFile, _ := os.Open("avatar.png") -client.R().SetFileReader("avatar", "avatar.png", avatarImgFile).Post(url) -*/ -``` - -## Auto-Decode - -`Req` detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default. - -Its principle is to detect `Content-Type` header at first, if it's not the text content type (json, xml, html and so on), `req` will not try to decode. If it is, then `req` will try to find the charset information. And `req` also will try to sniff the body's content to determine the charset if the charset information is not included in the header, if sniffed out and not utf-8, then decode it to utf-8 automatically, and `req` will not try to decode if the charset is not sure, just leave the body untouched. - -You can also disable it if you don't need or care a lot about performance: - -```go -client.DisableAutoDecode() -``` - -And also you can make some customization: - -```go -// Try to auto-detect and decode all content types (some server may return incorrect Content-Type header) -client.SetAutoDecodeAllContentType() - -// Only auto-detect and decode content which `Content-Type` header contains "html" or "json" -client.SetAutoDecodeContentType("html", "json") - -// Or you can customize the function to determine whether to decode -fn := func(contentType string) bool { - if regexContentType.MatchString(contentType) { - return true - } - return false -} -client.SetAutoDecodeContentTypeFunc(fn) -``` - -## Request and Response Middleware - -```go -client := req.C() - -// Registering Request Middleware -client.OnBeforeRequest(func(c *req.Client, r *req.Request) error { - // You can access Client and current Request object to do something - // as you need - - return nil // return nil if it is success - }) - -// Registering Response Middleware -client.OnAfterResponse(func(c *req.Client, r *req.Response) error { - // You can access Client and current Response object to do something - // as you need - - return nil // return nil if it is success - }) -``` - -## Redirect Policy - -```go -client := req.C().EnableDumpAllWithoutResponse() - -client.SetRedirectPolicy( - // Only allow up to 5 redirects - req.MaxRedirectPolicy(5), - // Only allow redirect to same domain. - // e.g. redirect "www.imroc.cc" to "imroc.cc" is allowed, but "google.com" is not - req.SameDomainRedirectPolicy(), -) - -client.SetRedirectPolicy( - // Only *.google.com/google.com and *.imroc.cc/imroc.cc is allowed to redirect - req.AllowedDomainRedirectPolicy("google.com", "imroc.cc"), - // Only allow redirect to same host. - // e.g. redirect "www.imroc.cc" to "imroc.cc" is not allowed, only "www.imroc.cc" is allowed - req.SameHostRedirectPolicy(), -) - -// All redirect is not allowd -client.SetRedirectPolicy(req.NoRedirectPolicy()) - -// Or customize the redirect with your own implementation -client.SetRedirectPolicy(func(req *http.Request, via []*http.Request) error { - // ... -}) -``` - -## Proxy - -`Req` use proxy `http.ProxyFromEnvironment` by default, which will read the `HTTP_PROXY/HTTPS_PROXY/http_proxy/https_proxy` environment variable, and setup proxy if environment variable is been set. You can customize it if you need: - -```go -// Set proxy from proxy url -client.SetProxyURL("http://myproxy:8080") - -// Custmize the proxy function with your own implementation -client.SetProxy(func(request *http.Request) (*url.URL, error) { - //... -}) - -// Disable proxy -client.SetProxy(nil) -``` - -## TODO List - -* [ ] Add tests. -* [ ] Wrap more transport settings into client. -* [ ] Support retry. -* [ ] Support unix socket. -* [ ] Support h2c. - -## License - -`Req` released under MIT license, refer [LICENSE](LICENSE) file. From 10c1189f457899dba7aa1f184b6f8e7ae5b05a6b Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Feb 2022 20:57:20 +0800 Subject: [PATCH 267/843] fix SetFileUpload --- middleware.go | 17 ++++++++++++++--- req.go | 49 ++++++++++++++++++++++++++++++++++++++++++++----- request.go | 10 ++++------ 3 files changed, 62 insertions(+), 14 deletions(-) diff --git a/middleware.go b/middleware.go index e6be4f37..10819d3c 100644 --- a/middleware.go +++ b/middleware.go @@ -2,7 +2,6 @@ package req import ( "bytes" - "fmt" "github.com/imroc/req/v3/internal/util" "io" "io/ioutil" @@ -33,8 +32,20 @@ func createMultipartHeader(file *FileUpload, contentType string) textproto.MIMEH hdr := make(textproto.MIMEHeader) contentDispositionValue := "form-data" - for k, v := range file.ContentDisposition { - contentDispositionValue += fmt.Sprintf(`; %s="%v"`, k, v) + cd := new(ContentDisposition) + if file.ParamName != "" { + cd.Add("name", file.ParamName) + } + if file.FileName != "" { + cd.Add("filename", file.FileName) + } + if file.ExtraContentDisposition != nil { + for _, kv := range file.ExtraContentDisposition.kv { + cd.Add(kv.Key, kv.Value) + } + } + if c := cd.String(); c != "" { + contentDispositionValue += c } hdr.Set("Content-Disposition", contentDispositionValue) diff --git a/req.go b/req.go index 11a6467e..be42f474 100644 --- a/req.go +++ b/req.go @@ -1,6 +1,9 @@ package req -import "io" +import ( + "fmt" + "io" +) const ( hdrUserAgentKey = "User-Agent" @@ -12,9 +15,45 @@ const ( formContentType = "application/x-www-form-urlencoded" ) +type kv struct { + Key string + Value string +} + +// ContentDisposition represents parameters in `Content-Disposition` +// MIME header of multipart request. +type ContentDisposition struct { + kv []kv +} + +func (c *ContentDisposition) Add(key, value string) *ContentDisposition { + c.kv = append(c.kv, kv{Key: key, Value: value}) + return c +} + +func (c *ContentDisposition) String() string { + if c == nil { + return "" + } + s := "" + for _, kv := range c.kv { + s += fmt.Sprintf("; %s=%q", kv.Key, kv.Value) + } + return s +} + +// FileUpload represents a "form-data" multipart type FileUpload struct { - ParamName string - FileName string - ContentDisposition map[string]interface{} - File io.Reader + // "name" parameter in `Content-Disposition` + ParamName string + // "filename" parameter in `Content-Disposition` + FileName string + // The file to be uploaded. + File io.Reader + + // According to the HTTP specification, this should be nil, + // but some servers may not follow the specification and + // requires `Content-Disposition` parameters more than just + // "name" and "filename". + ExtraContentDisposition *ContentDisposition // Usually } diff --git a/request.go b/request.go index 2b7d4c08..7ced8b1b 100644 --- a/request.go +++ b/request.go @@ -178,13 +178,10 @@ func SetFileReader(paramName, filePath string, reader io.Reader) *Request { // SetFileReader set up a multipart form with a reader to upload file. func (r *Request) SetFileReader(paramName, filename string, reader io.Reader) *Request { r.isMultiPart = true - contentDisposition := map[string]interface{}{ - "name": paramName, - "filename": filename, - } r.uploadFiles = append(r.uploadFiles, &FileUpload{ - ContentDisposition: contentDisposition, - File: reader, + ParamName: paramName, + FileName: filename, + File: reader, }) return r } @@ -242,6 +239,7 @@ func SetFileUpload(f FileUpload) *Request { // SetFileUpload set the fully custimized multipart file upload options. func (r *Request) SetFileUpload(f FileUpload) *Request { + r.isMultiPart = true r.uploadFiles = append(r.uploadFiles, &f) return r } From 2d870a302da1b164be5c03831daafb7fe8bec299 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 10:31:15 +0800 Subject: [PATCH 268/843] not force http version of proxy --- transport.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transport.go b/transport.go index 4e9ce4eb..f8fd40b3 100644 --- a/transport.go +++ b/transport.go @@ -1458,7 +1458,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { // Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS // tunnel, this function establishes a nested TLS session inside the encrypted channel. // The remote endpoint's name may be overridden by TLSClientConfig.ServerName. -func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace) error { +func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace, forProxy bool) error { // Initiate TLS and check remote host name against certificate. cfg := cloneTLSConfig(pconn.t.TLSClientConfig) if cfg.ServerName == "" { @@ -1499,7 +1499,7 @@ func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptr } pconn.tlsState = &cs pconn.conn = tlsConn - if pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { + if !forProxy && pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { return newHttp2NotSupportedError(cs.NegotiatedProtocol) } return nil @@ -1559,7 +1559,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers trace.TLSHandshakeDone(cs, nil) } pconn.tlsState = &cs - if pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { + if cm.proxyURL == nil && pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { return nil, newHttp2NotSupportedError(cs.NegotiatedProtocol) } } @@ -1574,7 +1574,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { return nil, wrapErr(err) } - if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil { + if err = pconn.addTLS(ctx, firstTLSHost, trace, cm.proxyURL != nil); err != nil { return nil, wrapErr(err) } } @@ -1688,7 +1688,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } if cm.proxyURL != nil && cm.targetScheme == "https" { - if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil { + if err := pconn.addTLS(ctx, cm.tlsHost(), trace, false); err != nil { return nil, err } } From f96c44655bac6d6c2448063ad452d8dcfbd89ea2 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 11:21:15 +0800 Subject: [PATCH 269/843] remove useless comments --- req.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/req.go b/req.go index be42f474..9710cae4 100644 --- a/req.go +++ b/req.go @@ -55,5 +55,5 @@ type FileUpload struct { // but some servers may not follow the specification and // requires `Content-Disposition` parameters more than just // "name" and "filename". - ExtraContentDisposition *ContentDisposition // Usually + ExtraContentDisposition *ContentDisposition } From e77722cbdc3f2b291f1bfa9c9c7536ade1c5d6ff Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 13:50:04 +0800 Subject: [PATCH 270/843] refactor test infra --- .testdata/sample-client-key.pem | 27 ++++++++++ .testdata/sample-client.pem | 23 ++++++++ .testdata/sample-file.txt | 3 ++ .testdata/sample-image.png | Bin 0 -> 199703 bytes .testdata/sample-root.pem | 21 ++++++++ .testdata/sample-server-key.pem | 27 ++++++++++ .testdata/sample-server.pem | 23 ++++++++ client_test.go | 19 +++---- req_test.go | 92 ++++++++++++++++++-------------- request_test.go | 31 +++++------ 10 files changed, 197 insertions(+), 69 deletions(-) create mode 100644 .testdata/sample-client-key.pem create mode 100644 .testdata/sample-client.pem create mode 100644 .testdata/sample-file.txt create mode 100644 .testdata/sample-image.png create mode 100644 .testdata/sample-root.pem create mode 100644 .testdata/sample-server-key.pem create mode 100644 .testdata/sample-server.pem diff --git a/.testdata/sample-client-key.pem b/.testdata/sample-client-key.pem new file mode 100644 index 00000000..47c6e050 --- /dev/null +++ b/.testdata/sample-client-key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEA779bgekwF6sVauTdq12ZB+5Say3pq0Aq43fQ8MniYPGDdhPZ +PnI8pMp08tuiSAZIatyX/NozobMs/U/U6aN8Vmbf8ERfXmN5kYLN1585RObyQkwR ++PS9jLUw98fuYnumPExb8hMDPB2QZIBu7oZRrzr3hdyw4/aqv9msyYi5rw/BdwA0 +mt9HqXZAyD23x5oKnEly5Uf/CQXF3BM8Zus05b4Nlig8oKmDPNccew3ly6JLpXHg +THMyKzNmzm2Sko4JsutMEL/TV5GmOjpM0SCG62F85K+X5l0bARNwc9qtoKpzeMcz +CM3g4BCdyDylup7nXyi4VAv4dJ0aPHw8Nc1rpQIDAQABAoIBAC8NQSpH15ZtjzCB +ZjfBkM0LqsU6J4fieghWdX0sQe+Atqovzw0AYoJ88WLQVBMKmJ/QV0vajxOHFKdK +SaDo4vgaDI0c/hKKN0ulfjx5FUY+hQEZ6NURQzogPVIDvPc7CS8AVXM25AWiT7pJ +snvBhLp9OiLdYyH6QRyR3eVXngmLDtAJVz86HW0GIjWLNjApigdMjAQDCE1AlW/H +HuLZbfhUZ2EtcFI98OZfWHPwLchGimIeo/WxciFrzkkkbCyvVyv8QygK0FfhMsh0 +SIWFZxbxx21WsDeDaX75qr+sLH2IzCjazkgDbZTdY9GsLqt8CD9YFzBxazmvH8R0 +de/dFGkCgYEA+er0+wF266kgdA0Js2u4szh0Nz6ljaBqzCLTcKj/mdNnc99oRJJO +1HZET4/hoIN0OTZvIH7/jx7SJSIDVOgSGpAF1wnoHhT/kjnXC6t5MDCRuceZIV76 +EbTAGCMwDz1yAFmTMr6qMXUcOCsBBAtlNKDtiqXshnUfctCKXlw8NP8CgYEA9ZUJ +hw1/g6Bomp7VCtwk+FBQ40SRjvrMrlJWJ39/Ezbp5I9I4DfJzTjQryg5ff+Ggq0/ +sUQrJsr6XF2AyQx1kV0vb5tdIlFjVNCErE/ZzVklss9r8BiSMBAIyqeMFVZWT6kd +pY1/uj1Q+Y5stP9iEWq/r5mmW8N+lJnq5H3Aa1sCgYEAzySfySx9lPapn4bu83fl +ryarrN6P+cNswaZb+pUYxjcjGDekBLIABLnCBPAM4y4RtxoXIaghyk6Rf5WhjU6N +MtcNAB+F9OkSq/Ck/VczK24WWxXFJpPCUcqvLVJ9EySqyP91sim2hye6LBP405Fe +YTDBspm0Yf3SAyg2h9+LR6ECgYBENAz+VfBZBP6oGn5+Up9t2xhr1co7FEouC63j +sFQBaRnSIT0TEEtaVHIYgypcZM/dkPIEcDMvxeV8K3et3mj0YxXegB6AfmwAzRxb +op2RmzWOEG8gsiI/eOSIK7oK3vx/iS8zoDWd6pOHi1eDeP2qaqQrx5ddGtEXwhtr +M8VxywKBgQDUVxNDU5U+Fplr5XyxsmemnwbvzkJW0Iz04yxSyH+WRUtGnJoL6d5v +fYhwL60gFFh3FFWTOiQxlvEOnhpfdqufCcO4PtHYxRMG37faBCb1ewQlcQSZ0n7i +jQlLzfStlPRP8QEVBW/oc4aMDO7CVP77j5g0Wzt7Kuyh0mFYx20alg== +-----END RSA PRIVATE KEY----- diff --git a/.testdata/sample-client.pem b/.testdata/sample-client.pem new file mode 100644 index 00000000..fb029bba --- /dev/null +++ b/.testdata/sample-client.pem @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIIDxzCCAq+gAwIBAgIUGDq1keCeZFfm+Pp7nYCpfUyTltQwDQYJKoZIhvcNAQEL +BQAwWjELMAkGA1UEBhMCQ04xEDAOBgNVBAgTB1NpQ2h1YW4xEDAOBgNVBAcTB0No +ZW5nRHUxDDAKBgNVBAoTA3JlcTELMAkGA1UECxMCQ0ExDDAKBgNVBAMTA3JlcTAg +Fw0yMjAyMTAwNTEwMDBaGA8yMTIyMDExNzA1MTAwMFowXjELMAkGA1UEBhMCQ04x +EDAOBgNVBAgTB1NpQ2h1YW4xEDAOBgNVBAcTB0NoZW5nZHUxDDAKBgNVBAoTA3Jl +cTEMMAoGA1UECxMDcmVxMQ8wDQYDVQQDEwZjbGllbnQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQDvv1uB6TAXqxVq5N2rXZkH7lJrLemrQCrjd9DwyeJg +8YN2E9k+cjykynTy26JIBkhq3Jf82jOhsyz9T9Tpo3xWZt/wRF9eY3mRgs3XnzlE +5vJCTBH49L2MtTD3x+5ie6Y8TFvyEwM8HZBkgG7uhlGvOveF3LDj9qq/2azJiLmv +D8F3ADSa30epdkDIPbfHmgqcSXLlR/8JBcXcEzxm6zTlvg2WKDygqYM81xx7DeXL +okulceBMczIrM2bObZKSjgmy60wQv9NXkaY6OkzRIIbrYXzkr5fmXRsBE3Bz2q2g +qnN4xzMIzeDgEJ3IPKW6nudfKLhUC/h0nRo8fDw1zWulAgMBAAGjfzB9MA4GA1Ud +DwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0T +AQH/BAIwADAdBgNVHQ4EFgQUoa8OEGUk7FwiwKrw0NpjiSv9WCwwHwYDVR0jBBgw +FoAUpNtNypX4RfRXBAXxI2MtrodFNMswDQYJKoZIhvcNAQELBQADggEBAAO6nZtS +MPb1PSUAL7pe9ZiCL4VH+ED/RkW+JSi/nJXJisvQJqEqUCw3nkTQik6yT7+Dr8sp +0rUNwxzBNUi7ceHRsThtUcjDXN7vgZfMDU+hC3DssgfmrqtdHefkD7m4MdAJjg9F ++kKOkKmCCzJ2sGMnFhVW3gQDf/PHl4VoZeBEATfqGqxmQROTJuLCpus3/yFxmsu1 +nT1x4K7HLzdWztxCQ/Nq/DpjD0nIMXicBFPnQamv+PHePS5NT+UEBvvbz/747aYI +LG6Jczhl3oK3zjuLAwW5QSsID7CERKclZCy6BuMekAEkQeqL79T0joUUVb5ywhQz +qHRB9DxZiGxazDU= +-----END CERTIFICATE----- diff --git a/.testdata/sample-file.txt b/.testdata/sample-file.txt new file mode 100644 index 00000000..032120aa --- /dev/null +++ b/.testdata/sample-file.txt @@ -0,0 +1,3 @@ +THIS IS A SAMPLE FILE FOR TEST + +https://github.com/imroc/req diff --git a/.testdata/sample-image.png b/.testdata/sample-image.png new file mode 100644 index 0000000000000000000000000000000000000000..5eee2c93b26e37a1ad18e0697a72f0c88866e7ee GIT binary patch literal 199703 zcmeGDcUM!}7d8xY?1duIy9$DIL3&jIL8(eF0i|~cy(34dp@>N6K{^NoLvINR1dtYb zuZa+P=mbJ}Hs`*_`#%3X-{7~#fIT3|-fPV@=QXc+&6Q{!Emev;%y-Dh$S7W`DZe8l zyG{%IxZk<~{73%Na6TE?N3z$-FZBE}H_mQn+W3NaPD3%s8be>h@v)l2ufY#xZ)bhE z_Vvs2*8BG#C{y3KetoF|MayPK3%>Ez?%OdW{%SnwX+qnx*cfa_eQ%!J11TBjnCGp~ zFLOumH8^rj4)6Pe&%yG(R*~D$=2r83vsQmeQ|2DB#dH1O$N9FY<3GSe$aE5)Q(s<- zHNM}vyy|i5T)VvT(LA|wc}>1ed-=YP-Pd9+|9oz9MdkAPT{P^`Hg<7 zm$C4_kog~60B-OsVokuyQsHZDZEe9dCY$IQ6CYC; zp2e}IYgg`Y$^Q_6#VbxCtUCL#-go-X%bfGqcC$)z>0h%zp zC;AWqfgCsaY);h(%X`K2TqBd(yku`u>jy_i2|hlxXIjuvnCa4iKYKRF%q#%cpCzkE z_fm!IqnnPK55V%=MdM;)$*XH>PCr0L;T0z^0m!1|<^B4#HMH1ggGJY`l1UBUy1b>+ z!QtUybzR-52XxfjtC`SO>6}k=3sqO2&eqn}#=v4v$jJJY4g7DAEtHMS&$rYzHSJ{? z8#G<4!Z>>>o0=M@EC#w7f%kY_^2vUg;Nal(#>T007pTp=ucckLOw6V~`h}^|9pH&~ z|9fI-{!>I#)7d<2=o+f#MNzYdy}fwAF|9{u_EuwB{C zP9mYFa>GK!n#X1En*gK@*o!-kL%>TV3b7#jp*njRU$=D!JX~zoxr`Z73sceUD z1{n1FaNyU+XWbJM17mU(7tZBS`3qXm?(_D~(tx!MMlKidYx2Fdwe|Oj8i}U(s#zj& z<;d3f>}t_djaWI1uiG?z05Hq40VQg(g;fU*78d_tU}vB<$51RjerIPeJjz)WCLA2( zz#HO%@6WPu^GURW&`#7=Tep0}C(^$+4&bGCPKP$%ehkx>aAo!PhuO8HQbGPOD= zsC8}aCGUU8wCp0}u9E4c#3dx4WP2Q2GHoW&+GDzKzh)$nU=PD9nl+zmYk33HL7`CI zDx@oG>py;zOzP~?jbJ7cqM{y}c6QSow#G9_ zHk07K7ocX8c{5=P*g#2G*%i#ANg}Zkg+L6>7mrsd3QoDIJGLxCzlAygF=y!M1Vk*- z=-YL&g+Dc}Gg7-)IZTNBBbF}M^2Tt9suir0iawISUmO04-mjLCmDOx`Zf>_SwHPhlB32eAriSu=@Y2)O z3Gwmf6N7^%O@PmY4vrORrlvALp^qqR+T$Di+Gih zP>~*N_bZqbCJ$C<<@{F~qX(eUCjgaB0_5jF)qni^;CSAk5Lef&MO%h{$lM#L0e7Q4 zIXM~g_NkrcNJL8ZX!Xyl`J<}+MEe5^895UAqZWG}IV!?zCTIKLMgLZeIEvq`BqNiH z2VC8+gN2O^S!|AJQOrGrCV9sRK+xEhChf6fZC_4KPK=Pk1%H(QY0!53oeFA$!I>M5!wftG}L25s6)pm|59Jj;7kn-n(+PdqJy7JbTPVd5epS>T3rlS;o+db%7oU zu_hSqrzEBdBiVF|?Bnodt?Tz{ZFSYS#T9hKvMU#82UQv8t^YgH+k0xK_1jzPF0ks1 z^@ELYV<@u0ettw2G5MG+C3jg^NJuE8VbU9zLi~SI&|(fTt*}$Pmxz>*wUIVF+&(`e z`DXqo6;}uF&&}8OsK=?sPZMh63rwMiAA=ii@+E}i|G#d~y|S`WH}1GICeS0Ks6XDj z7lfMVKV0lDx3c0qV%QAOMj$IE(V0{rsI-7&Y6l$7w*wPT^* zrGnI#x$;w2AviURoS_1p;u}nU$l=krWX_x z1pnP$OwCA(_6`ny%kR27G&FQOTL!bI?J`|U`h7$1HkrG@B~;p#Um_i~Q%J7VX{|ud zjr6xl=3SMX(IuDblKLkXfdcrQE%RuLCH@|PveH1D&o$Xd^;1P zql+0ZSy5s$(mlJ{w2__^y!2!X0$i7KeuKktItaDYy|R{+Z1@%~m(cR$kb*5~n%v5& zt)jbfiY2Daj=`lVhe^D;wzgwrFc0B$G`C_pPcwfstVp+4Rkee<*oIpwG6>Gj^2nSX z{|*m7bz{~S2Jlr4kU%o^hxpo@#^J|lQKzsd=laP5uDf^DIb3}xcr0-pgKoS&^DsPE zA8Mo2SN!lmu6M6sq9Itdx4$1vQ75N}I+{muc3>1GL$Go4Of)yCRJAP#w=o5&HtN|E zHBRFx%z6@Jzl>jBK1lqPs;UL36g~ZDx2I}!R3`n2@NoIY9*}QkI2nkHEG!`)=;&M~ zeiE(XbSPP%!jZ)fess8{c+U#Q37+?DdFe(FC&+pc6j>XBz*bFQK4t z0<}B>Q6Dc0^o&WCei`8hgFig14c^T0DBhc$KP=z*AsJ#v?l#2;AnZ=Kb}--h-@17V zRAEs?#Xm?I9HeF{vBh5J7^7E%K;*1fmca^{>nn+rakPT6rmTuz@#-A&sNnINF}O+Y zabWom0QC7Usru^z%VbYXcx7Eb8j5fvTMFGU0HIqAFXPi|-Ir7BjG zWVdFPUi*o3)cN=CVEyit6Yjh&NL!8sHMR5!kfG#|VmuhDC9W%;?&^1jv)|^7ITOdb$(p411& zmc+&##Mv2lCeTJ!@YSMn*`_n3jpUZJJ%@)8ZWt6JJ4@X7*@7u9yduQg%f}o2TRzaz z$Cf)n(k)&Ub-F~{n%~NE*G(D(T$?vwolCTt)bwZnghrK5 zFwAFJDV=qq)X=!qc{Zn%x($$(Ri5w!D!7v$Cty@|iwqWhiBxG^>%3*4>9Yz5j^jyf zt(1IH`#V@CfwqZ8tC3lt#?ZytdAm};>1+o}?6+L+tt^`x$*mNir>DpMf`bMJ2dR&m z&%JjlLwDE^$E|mvgAr~YoQVM$%EG#TS`JF%PViFE&Lb+q@ zWJWqFlfrT^@MHB3!{=CMHL9h{%6b?FR# zsCVt-v&;M-rF)tB=H8at0CYiQZo?9l2H;@S9^(d#76qi8V@hJCuKJ^P3FF8kuC02T{ynU;^M?_PTidOAV7WN^n?>p zU2O!P$;==48EXH8a&wxYh2(T0O}RPG(1x?3xJ3iu+a<0XZZ=)OsW~>uN5kvB87Qi~W-;EA|#|JE7R^K4ba{;o4g6(OsjrD%AcvEd^+FYIauK zpXfh~lTz#z0ehvm@S|CVt8JVJI})eJvuCF8Z8KAIOdczCrp8vYy1E7(c|p2h%RbJ~ zZa)b-UkW$Uda5JfY--Xy@7v0;)%$jxKp+TB)i%}C9ApagxEZg8`|WkhMS}V4sMYc& zOikB?g(*W1zJX}Tq&@)E!a{XbXn^2)28HIBpM!jF$>e()b?X&gBM44OkehbC)z z=*Ti81XWa3HP=%84U}Yaz(7kiUn}SRd{`A(!Mj%f5C}o{SC=`vC=LdL&51k7z<-Rc z5u&EJyuFS73P891A!FsZ zggF^|HVzIHAI51ORBGd9?<2y)eX%<-vXM=DlShvJD8rtEon4`>;BgZ-rpnO9*a!0( zWp2}XBYJrFnMmNinpc@A2PSlK7K}eT>a-bm5sXwI1QBzLj8cX18bOa7iamX`wWXJ^ zH|I!m#QphD(N__Y0!MykGKj&yCBBvjQdO3$Z^-UvEwA(yfUV8X&dPtYf#jjHiNM1` zbiq6@3o-rBDyP!#Rg>4r$T~0mS!{HBd%NZ!;BY8vL)EFKY<2H6t5^$;yf`2L`zZ_ZuuFKuTeO zR?<}1IksslTSpla6Vq%l4KouHK0|k{Qkjd@P?Fou&5MJUR~DGYm6bJ_#hiZr(GVLj zcy*RS>TQ)5IVEL$P2G5B494RH0%RfQqXeyyb2?rf{;R&DN!udqzf&_gw;=UY#g|E4^CD#w&KqK&lqV}W|Ap< zy7MRBpJsR)U;SgR=wjIoTNiSU&#OVoj~yM*kEfcVLV4c%Z?$g(%ldAeVsB@y(5ip% zfcZLiSDwxYl!A1FzOKh~)N_|ZXEvr_aI8YFfS{oK!2fR@S|tG3WKNHbdDy`^S*ps& zm9(SiHAhJqhF*GyJ#x z?53rwtxRKJAg)1jo=!Mml)sxj+QEDI6X_6*%98ETPrI0ui%v*tZTB*Kmr6XS2zorX z7w_AigPqB!b&fd}bF;MCnk-_VhlUhT^T?#d0I6tfOsdGhY4V*|Pcm8wV8*c8WU}+# zm2je73`U#|v7h+=_x?Be;YxsUFYzN3YEuVWTQ5rf#6Zu0TR4k%Q{wYTzgDry62BG?3f@mQjFZD5(Umwh7e$+E3ouJ3!L z8L#myaL=r7X8{H!G`3#`iT-rcYTgj_4t1V!J^J+>n&IZ?Gfqws{}r8P{H9w%h1~Xg z=NTyj$}rc@Ks_Qgs=YPS;Ax?xHSruMm9DCof^=p`dojVxjX@5BQz^8hyILt9HhF%Z zudW_fS+c8>`Eg(3508Fm=*3ig{Mfj)@#dC5k9p9KitN@hC={A1<@Mx{O^SBh87b7! zFU_NHFqojw=%;4&x24(+1FKNX@+}-KABS%?i|0aH+nm-9>Pw`y>~%g+RLyWA*gdPN z!0z+sMw!_OVS|bF%Y_-Zc*EvLLe0BB>utfsL#smUK#o}E_lh}%rBQ!t3>_HrB#HcK3O<~JQ&F^gv za8)Kd4GfTqE*>gYsi%V*CjU2fw4lz;&O}XPgHQ+MbFA$(E?zZ_yM7yf)`7hYq!iQ% z%5OAdjDhiyU^r8B3OdlN`dkj^=08fP4;W&vKdPusi3f%h2Gi{D4jLmkk$59`x=2O zY&fIboaKn*PG~RmtZff=3-Jt^J_-#zP01!T5T{#CAuFrvP78D?AwGmTBvo*gVO)YL zerRZA(-^AQV$dTb%pPFS4Pr4;Q%Y7!UOz~c2@s>Cbuos>Nr3!5n1zzqz{BcWF>&}T z2EiJr2Q^eo+maa`)cnHo>yj8epgPC5HC`#VZSE zB?;12ZVVMKB+Y{z9h~^D!e^TfmF-LeXU@v!?+g3SI^q0%kzD5|R2NUrdz$9iHnXyp zGDG~Y8UzUS@AwW6dzPe(Vrk^FBZ7Q=YLO56B8^V{Y>=?lf~l3;7DO%+3hM+`SdXnnz+x=w2_Bi2YPS>EpERkxk?xfGNVO_AjN2Q$B3(G^X*7lK4UhdXhAhB~?!kIl2 zH)Y{dL}_Wx2kg+hbJBDjAt1C(%P+QiKnH2ASYa+4DOj|je;Y&Iq%=>TPu3o*$$;MXh zu$U$rb68*IH|mE_+zc#8c1g7n@!x*VDka9`)F^{n%W#dwBmbhon|p3?urMHUGaw~@ zqEpjgLXi?h$r;&N8qjp1>k*tRlIYV@v(wgu(;lKl%SFz_Q!%3(D2H2KxW5*ED$@uM zp4-8tZ;on^$Xe-KpBK>4{%DW>wuyJnNR_s7es~qC9^;f^ktJ^bVy$ZI{2?C zlfNW!4lEoTXLFWEY-k@>n8; z9z3dd?x*Y)OtjefvZkfSozKDQm(ac)dln~^^+D@nD{EWv$CC|G0YvwP2R)8ohvllB0FmbXf{yDX1Si#af`iPX*!gD5*v-66JfYSqxGw}ZYL5h+mi`a<28-j!d)ZuB!8B(Q+oHpUW*4U-* zcJ?W)5y;-=gR79w<_w$X-Mb>SU;)?ArE>5d$QK=MXl@)4Q8hd);!kv!7V|#{)~4rz zAIfNupg;DV{2k$5N&cj2+D;n$|jf zw`PtJVnrXyYk!7-5hxUKYM<^gCH-{7h-dW#6)h>8z4!&2d8_}VnB!)d&GD05@8693wFXajV73|8!h59Wr`>4dkM0kKaDMtV@j=xj{ z#%9Y*pplcyh)@hFH60dRS`n(%eA3jy14peEreQoyjmawl6TRKp&@o=p~!C!@*~4b)7`0XEz-} zmsD^M?7gys3uIz<5?9lxa9H0E2o!^_u11^e?7^~Rt!nEHTx`&rHE7%?E32JEs^BLf z$61O@HxDBNIxUP)`=CO6)!$EKE0s>xlttBk7`;+&L&#eD&hJcIUZ}xv`^6 zsf~!@GnPmyb^#6+M)R%F>0**@CSC`IF+&BD!dxh%ACkFuAGh5m^Lqmnsw^fZju({% zJ?2UH!E77jz~PwJ$Ti0rsQ|x6l&o^J8JJ1dBh*eDcwPPoy0^tP)uKp`$M>*<+gKz! z2BV{ie+7GehW4qi=eT*eC3vx*(Uux0>HLH+{2#r~ zvN)?zTlXmApG(?O&ueQ>KIu-nb2od+kyLcgKbIKNDW%&?hb;rof(k*3Q%Y zR7%i~md7tUW<6i1xTV2mI-cODaX|L`8RITFnOqLw2oO^ftm>o6Kqx-G$0VJpC;tl{ zh1*=6om&|J#8KhjbJVVGL7EcXRzuFb(}BLMqIEwg-2N$bKAL@ALToKk1l=%RJ>%i5 zmYFTUWL+6+yGaR^r;^MHbZuTacZ+6mZQ&GZT>DX6Gh(^BkmDiLhI7b9CrjU#ZEDIZ zs*{u%FXR1$9?co7LW>XUtOh;K++bnmo8N=ioj*Sv? zo%;-~p{dpJcX0Jdh^%E>myR5M%=A+5gsi(FAH#zVy?5>^?%;qLI)37*j7RGSY>e2C zGi00;HQ9h10GRLf?^u|0c1V~d*$BE}D(eL?dGbB!EIs+^!a{=RF!?}ECxcT|K5PyC`IDMq4146(cLKBL-I^6fZo2HWDC~EN{QU5}nFrJUv5xnheN&}& z&S_%wpHxt{m8M;wvkalRp19_Ott>aieRYEK;yKuW>g&wYa|eFZ%rke99r`#TLgJg! z5X+9N-nVTOo2*aJ;b-k@WMavHMr6MU48G>4Jy;jC&(a!==H`5|!4;ueF=zjf4<)QH z(yrS@eElKRw0qkT-)13fkrc2X*b?Ob@XPN(1hML3Cye+j^nBFg?U#6XznyZbQG_S& zhN)R>$sbOY48iu5TD|kcZyjJ$3W@S)f@&VCNyA=m;A6nQO}F3O9}L3-x6X zClf9-PscFGQolY8a&i|br%$DwQ;rDokMG=mjo9*vHDe=Kx^qem=l!f=vjWX0^PcwE z`2C!4JeUGw9ybJJ)yBuKu(1kMAQ}VqBuA?6ZxucjDQC}ue_e*p>P3DOkH)Yty|aLhHRC9Nd{7D+VNNNzYsAl&$Si3SeHrGiiV5x`c0T&`>$=C zjoF50hlDBp!;k!J?R3Z8?)?;bc{>B2s5%1A`?T9pV)*F}gesmOYg$&TdnW-*BhMmh0;rR18lt=VBCmNX{<@FS(6asst z_ja2^L?zg>FE;G3k%S8wGe-XtM}+FhFn`tx|nH;2DeWlbr-59{tcjdGD@HBZBN zc>nn5RWTgQ%)}HYe2POXMlrdQVKJ{B7kw|$U zSQ)CTM_l_V{dARmLaLP7t-Yy}MZc$~MK_d`7)?vw6K)}(>T$^xF?Zqh&EyO%=S#vV zSv;mDFrqfpBYCmE-y`e`)p>_o`^9-_f^ul6MRjdWa!ya0V@>dGPzm|pWjuqW16Y9` z5{#en1`p~Rc2v0qn46)zotzuD|6N&Ii;q?U9IsZjd!dBfBB4ss^6OMp=F!g@@hMpu zKvpS<5YszR8!Y5C+F03&U(HiZVF;04?)e5!`**PHe?1@td$ex!%t<6hYKO|q1kj!=l zRoN$Z_;Y+80nM^Y9hLqP+?3LVql=NvMEvP`Oz42GcvGQugRdUFXH`kjv=@Rc+WRUC zwzkQFy>8ai%`9NEs^@|Gv|)?M4d!_PTJ@Z8B9zT8c;sbp*1a6jBMEG>t z*ckb()j*fkoPi?Muv_cJ^2NDDDANVeH`Gw*?^jsUb~is(scq?Z^-i2pCWWt1`Jo&r>neh^)~xz?<}}EEz6WX=W52l;AGWwCDy4;_ z5^luqN)d+a&;!ewgH^RWeTMgj(Y`=GRtfz62&Rgs(Zq)i5nh_SFY*KV8XA&QkawtF zoK#?4Dnf%dCV-w~XD0{So7r1ve{On)y}IzHfPeA<2}3Hdnu3bTVsPl)vnxL*idpg{*O@1$Mkgu*3 z)Iu!LrX$R<&pM4~zU7xuq<^$*<5r!vx@wea5w)-8q-qzNUH4dM{v@;Ba{U0jRe*{4 zJ-C!J19V6`;v?Dl7#5`l=b}m&J~)|@b_Z4h1i7iS+wBz3cZjDK$0X5W%^btEazFcv zb3fwmn-C|CnDL7NsuWMo!bMD}lD%$Ik$FMsqV-gfJHlym%1Cj-!puzsO3wf#OgZ{zyv>mTrn|J{Bu&gxu zndj>#1#cnQfY$-e@z~IUX47R8;dTTE1R1Wi1E`yk`%+2%0Ovt^Mn)DdgVk%Lm$EWV z-QPF74k_fRd#>BZceKs~S=T&K&^C0Ib5l1%95mKDua8w4Vi#8IkEQJrSa}%=UlU%t zq+iL$B~YidmF=G&!r>mpeg?Fnvho16a)tQIc`)&X~4* za)Tb}mqk`SKB$ebj3sp&zv-bdYUEWG%lCdtlj~xApZv}rr?dkFmgReAfw9BGS5Dx< z5l3BJRV%lhz^_wc77dY;;D|pqsYcF9EE38FLHN17XNq)$;7-{Xf+amJqZ2a4ERVN8 z&vBXdIVBCq0xm6cX=hkLyz-&$bp&!3aCcTyj+tlk?+$`L%g1%+n+eRx;#*RuC{S51upY%jkRY zaOo1msgb|dOpnG)*9MV3AaV=8x6+x!-E|lZUVeRuk&!+6u)lw|rbH6!ymmV^;Cq2j zIb*%ToMS~5T-+(Q9El_^9Fhy9QrvzXbZ|F2L4LZfGZAZ^F9A|~s&gY6!q0!3XSKy) zba@P@yIp%mbxLdR$M>_!o1hLSj)jJCyq`6G?+g(AfEHof;GnljEK2c1XWLK?!-lroKn% zy9s+}th_E-`j9XOy3th%SttuhKp6|S^CYdYurtR`VihiCrsvr?h{+alHCb#+d~0{( zEPmMsIPj7Z;uLtv+ygF&0+ThMpw^d{hp3iAXTRXVrcfIO0Gh3m!vG1!&eoCumq$F6 zWKd%1#6>VPKX7fr!$hCIRx;OB^8r9?iQPc^*@M95VLmmDP50il-eJHC9Cy{2LTY`g zJEveU#%T{Nv!H`f@T4~;8UW*a>j9)(WL2KoM_kup^h90Y(Jq#oXX4c=EdS?|(ZQ;a z;9lY0&a2sH8i188qBbo^CjP=hmVCU{8hfIJq&hE+FMb<#Z}4vtipJsaHCFz%tYLX! zysAca;NxbE#lOFU1Fbe@JaLDHGb}A9giwj8qN2ROzVry86L6(mqgr>j`w}fn;d_bT z#u3}4FH6p}DJoucgjwepG$Jinm%q6joLuAzS|7ol(PN`NJXb9s1!%Pk5Cz_NdsWTHO5ocGiOYv|Ca|2aPFi=z@3! z!jNxW}JPZsgAWZ^W^|-Jp3>~W6`AH zUO!C`@37}8tmW^tIm^ZVDyH4yVl@&Hur@e5YYg;9#)4ux;l-yD;M(S|8VEr_S{K6+ zn88aJAO%gXrtPAH1mV5%lkfN=V|~j(G;PSMHkh`>#KZ}|xrQYqo8nO>pe)7%W%5!` z(yBsTKH4#6uQ1&3T|$2w`pg^@7gSSZB^Vd%rV&9*&5hG7Uk94PM`uIA9UQ02ZYpk= zDI>#~qokUTgS#(vY8c8`Q%Yp^f3L)f03Na4RiaG4O{XmxmlOn=M9uE&HMb9ZKhZ_z!Gz86wS zFWar~?SUuP1tX()Y0Yj&58Yc5YnLj&pBWHwj-fg|B5v>~P|F9OExqYqa_lYPpX|OO zm|%DquYUIgTMlPVL<7!Y;j+?EBDAYX>S@-QBh(=sDSZ{b)CG=X)4hfkvO=* z4_kj?K0G!_EqDOM?$&rbObs$twp$?t*GcF@t*4Hce8;bPTFsfnC7La ze-@Zt0v(N_lri8CP2l>$X}=r>{v-sIOH?I-wS@e3%Wp(GFcQL$q`aEPkh?;G1~Wn9 z5L|~WK_C!UFk!vcWzEFRU1Ng#loXCW_6YqWKHk1f?C@LpeKQw0aG0@X`{Pn1%J%R@8`-h#*(e%jQZ|Tz62C zc2)7xo{$jK)N#)XPWz~Ep|WbCj&AxSHzM?e)RG_^vOJ8ioozJAh-eTJkXwzJuo*Ct zlSG7Cu*%EOAFwa>Bb=LBSH2s^kW-l7()3uZ04e}Cw~y<2}DKYzIp?7u4Sb>(w{;=P6N0VU>#59PHo=QiqPw`QRR z$Srt2XmjhVKMiU4Hg`l|xubw^ZRQ8{Z! zjay&v?(X#GJXr5%0wj%++HlB`ZWFy}>)WVfWD@%*+u{qNAVj8En{?&Hi}uM93eEKC zx)&|WdT+jbo6Q0-p=~X8bxNz|B5%-bNVxrFjKA);7mj~h<)#SA{&a&&G)%)~>RrD!L+ZMN9oq*-+l9RvuH~emLPc*f6&HFw!2q9sciAQx{~+3!jY9QHOxo7cT@3J}J}2$Z}H;Go^gs zo)R$N@6~uD!}@cmbZ)EcBd*rov6#WNnZER_-nB2FmC9$`#l<;){8~(~La#|9`JQ3f zcPBY@-LQ@?K*vR2nD1B9SL;SuVIHOrl7w~FtQSTD6B`|=iF~}sXhS&AwF7{E-p@*ShI*N%JMo{9`(gqEU9$H@VZ`ryv5rQ%0@H=uywFG!s{Z6_2TmJh_*$| z26*Y`%h9eD!A1L+D}p)(0s&R~frnZ2!*$8zIw#2nvZ0Q-h%helRro-{&btA|55MXB{WApWuSx`I=wUd zUjEmtwXU>p+6A$wb}AYr4KwtrjEu2nsz%=Pum}06@2%>VVh~p0{Lg5j$~vki*_bE4 z-H@e%ShF#|;8v47V?BG2`uR~pw6dT<&I{c~qN@chm1|mmtlPRyU1bE#VqNwpae}l< zKS_d}XEfJ%ssr!ds;oxVgm*uA{$7u+us18x*iJa-C=ZQMk=X1{?J_IdZ>W0-DlR!l zHG4Kh!l6}0lCw4xTo99QrCkO1IumHlBK}C{Y~(pXe7$4#&?=Ql%EHlo2l2k0fwOtX zf!sHZkNIvHhzoX$)1@lQruNb2e7DXOZ@bR?`O4SOjx(iu32(TeMT5PQrF}0EUGT+I zwv@CXwccx^j|)nkYmbEPhRMmX6J-Rvw29{uV-YG%OyZ(`Pd-kKY6`nLn#X7VRDL^h zH27F7{AC#TNyxL=qw#=%p^7fX*!%yDPj_EGd80-y^7NkFDaDr& zq=)*N>>At@q_y>RtA-}a+tD9x(!}BT9X8BWaL(Lu^ByW9Ql6j^ny~1b1{O6EVh@HU z5~r#i$1<1np+bE1PKIj(K4J zBJ-}IAES|f?ZN5C<@=@&dq|d9rOcfUERG$e_xgSpK$(Y zX>X!(|32rnf`WoF3A67{oS*U!J7_+Oix6ZQ`jH&s%xyK*yDqgfF$z-rZfjuE%xt67 z9NC}v-ED2v+gGL>1%I-NFD+e{e9pXf|9(-pPH|Af4`z6g_nEXhhKP`l|A z6T~Wc%NxdRzxr98C2I* zH@6|13`+u?x|6&n*&$FJ6;7r!M2cAXk+4cW$9-xgTbY+$|zGG{15rxA9VzrWrU zZpG@CE^vsy`4W}cXQX}Qt2S?ux5wvX*G#|s!)@xa(4FfWlZW&&Ch01{XeJq>~w z%go!crl3XJx^u5V=+wNvc#SLfl?05QDZk-`EJz3p3aoWe(gn%OjnHR4Y~s3i&wGJX z$}dd%8c_5Or8R-*PN*{M#O;v(Q)=^3;cb_KiXo?Cs9RV5MIV<9^_^C$D5b5UQ(5LP z`vEm|dG3N2x{ueoTdNPJ>X7k|??=kk&vv1HffUbfhaB{|{noN-U#`}*U`4rnucVUu z#m@M)akQyKT@I#E$Wc*RI-R)Q_*(69kl*}^;d}nl@>*VXwO1MX-LL$%(JpJ5eEmWl z_T@9iKE{nVb-BgP-{+Im@hWfH8UGrl`5_30?Y5A8^KXG8*o&Sc@US>1m= z`nmn}mm}-&&+Sjcf8+Ni@7y6z8JsqyQ^`XA;oTEwYFqtC?gJIrEb^YUXlkPI1-=0` zc2!BiI_YzJ^aH7T%osFx(jsz%;HE=+HO_8SPOsw`;kNSh)ZiqYs`0N#{~io_k^vO; zrwV_|17bzvF%%nC-Zs6an5EWj@R9*ES=h;>m!RZK@y$eYlV9PQe6zB!H9 zSxOoFbz8L9H3wmZSo={bvRqa*=flAMo;BmM?XUi}3x}g!dn!4Tk^w1K?|~Ey6X8B&M)@m!KbXNr7n7-`xUu^@J`^%Ke7+3gx`<{ zKb5sSJPvZLnymkKY?UJq(?h2sf^W1ex~BOD9sTT)7}q3|*Xg%Ez=65ss6;8p)@zA_ zwa5nJ;`AG$zhk!E^x~fkn0?$7N{O29jDG2L-@SN&t^G-j2tS$hYE%y{tf=^wcqAou zvJ^Vz`Dx4@g!_kGx0w@&)E89?8oQ! zg5*I4&-&t2db2d`#RM%f|i< z7Ym+yJ`a+zR#X0-()+${$*7u!z2fFG2H}p}yTZR>!y!6)CK2|kO*&jZ(%(JJ7`S~` ze(_e4OTbrZ11$ve*W1e6bh3GGN7xiH=Z(EdT>C41e7(2zQW@#SZ6ZHSrT%HZ5wL)4 zJaH%sUA$f`A#ukeu|k)JJA$M`vN6dYOL{_-`zb>P2682+e*evY>Q z5Ur-KuP@PT{Pjz$@)JHjkBq+Mmr1YFG&$zAl}_$6cgCBIuj_rcf6Z%@-2R!4hNW8S z6*~28YUox&YUrGQ>Qk4-2Quj_EGn_Tog10&+UI9%9hP>AS4*=#0>?zZrq^Uk-b&I; zRg@eXd61~~H`Y&@eMEsN*<120ceFsJ6x#MvGQ6+#kI>4Av+dnA$2KFvAo|Bj<7viS<@;>hO9|{! zAKLm3|g>gnw;qVsckIaKkn7Az8 zPHd!+V{ag2y0|*tR$r4MsKw^l&G*wvt&BgN_^w`N@NNIB3=z}N@Ut+C@ zir2UR@4NM1oZfr+(3g6l{v-EN_S1^g^^?Dxdg5pHGUca_=pio7ZfMQKfg%u;3FbG zf8uvS2Fu##_42MS87dx1n;9dY*SO5m(v`Z&L<~E?%5%H8chux0@ zDz2yf|5ZIvnQf_v-C3Qe-wTg ze)D3R+^zZHNIslOD9nSz_n({u^uvWZf!uQI<0TKeK?d-k+CTY0VG$D~Pf$t@ndm!q z^@5>`1Ud-trA@X5ctU&ERW~xOq(cwpIvV>Ka}5iTXd~EqK-T^Cm+b!zsM${j7$HWjK#pXcI)YG$K8Z7a!OLX=urQ5cE-OkuH#nxbV2lFod{*LuR zK+S+m7Vb%rlFPZ3&OS{#4@CW+U^(B-PLwC#rpp`r9N|teCV#;y(Y$|u zjglj*&mpp?*yx+T0JwYLHKWg`xZ1i8$2KXr7#26>n5~EWZ>m;9e$c#cyrOp@06rP$ zjE@R?eUbXyjI_bLG1;hpuWbW`;Zisv^>SEaKi7(l=!WwC;%V__;qFh5E}WXm{MsQu zf?*X?=y~(zx5LM^l8aX>#{+{yd>t01R`Sg2%-4A5^bGkegKNda9*#{$(;i^FeyHBt zxWUPG#}41$j2HPZ(h;6(^S58D%^k-uGO8iO_r?*Vs_N_g|FCop>~(fe+i#P`w#~-b zv9;qJHD;s6w$s?Q8?&+17;S8`u^aon?*H@V1LQc8HP@_}bAIP6ZvhmbVO#4JZH{g! z9p2~8m>~`7<%}d%PqHwKEj5yBMj2A}XC?}-mFH|W;hZ}udkMO66!2|OP@lE`Se~l( zdrm`BbC|j)I-iAt9hrhz79oPE@8W)S)UP`*bb(~0I&Z-8dvu&grPZ-J7Z=J&f@5KP zJU*L;V6J%StX?^NNyFhT17Sn+!$=A}Hzn&S@D;i6yrU9aj4`cu;p6!=vnd6Ruq486`h(L6W7S zBT99b{&E&Oh<{y+<9w;s zt34ev1l5Yz#oO-An(KW?;kaos4XEyANz z!S!~l)B36!zsJ`A1CM!bB(kq!0}@D?Z!srh>-c6Q3Yy`%lsbHkED%nW`z!#|Z+%uphEC+*0NF2%D)dGX9|h6)P(3woMT#vT;t@YqDi<#iQS zhBP~TraoUGQY|8jxBTjdh#a&Zcwu$Xup*AI5yv#tq6uQO<s0$sL(DarkY5N9T|Ak-%gK)>#P zj+pJeyyt+2liyUX@?!AoBBiAIFMmT~?0h8(*L-+$OTVfbAa=JYZkU-OHUanqeo` z!fD45v-n>?NN7PA`%p>vUn$KBrE>3SZo?SLo^&nfg+yC!Zvm&>843!@ZCiEHz@@mQ zr6r%x;_!^F1}RIyV)4yOZZ?qn5G{Nzexz~ToVVrX-fRGi403Mxlt=@C!qptgQbjr_52ag{BH~D$7DFMHoo|ToAD}#+MhZ~X~va#%v zEp0NN2K;W+8yKD>3<$V5@$1b$Pn!mi5-ya0yfUw@P0SNJ8~IVr)!@SdbQi z*oYse=iIIM#z&XPZppF=fjZmhM9BK;3G?iJ4leqM!d;4W-q+_&gQ;3xAFXRc@(Z4bP)%`jK=M3yU8$0tly1WZlb-(+aZnP7Og5w;PbDL}kH1wTO zm~a`5b;L;{l7bL;S{CK+J8^zJJuRv@Uy}Re$2&Z$D>-bsN-}OHr9OG_>ancd+}azy zG#*)0*ZlF~z76PEbp+2NZG`j}9F1Rf+J!DG~-%3!Ay zEotz;oC)Ui9&_KJEEc=@GU&fvHDJ0>l9-d0XjNASToAUg^i8`GLW@d>6+&6bJS<)0 zTpv}WP4yw24Z~DNBpgfv*>J2H-f!m@yUY1pJ|$0n%)q9wJ$gkvo}HsDkJ7yuALU$!A*Zw42`u4ohd*C#^>*nwK|S_dmSe-wfuzxmRJ;S-@%N zKl|x%N`Y`ZZRBgI!B>H*`2|-3lx3(LBNw?Y516g^1ANMGf$t3h~yONQash zP5wl=m2nx2hm`zafK6auoscuiQi#rXw}C}wSI%*tzM}8a&4N_0E-o(e{B2jnlz`5U zw-+`nk`T(uWx|`e?n^r1|3*ak*c}=hcBoQMON@Fk;hOVb#Y!Tv7?}`10&O}1?$>W5 z#Rn2@HnuMwlf*2(`SL`Kx7U|n38Ae65j{q5$gqEi)zwQ>OTUT=nfEG3MvN$P?xA7f zL3mw;_$5=-vt~oK4yV)GJTT1iy0F!;=ny4|--Q|fYN1fS?OZr$^#@&zpvNn)2F<3*i4pC)5 zVlD#lO@QHR8TQcua;*IFjB=Q-eOCy{3D4H$hEQPbw_0yTx@-))0y+ony=Eivhzk@f zIf$@h_CUx%Ufk7eDN;f3g}yoUObyej*X!fOSWjOn)%W0{SK8p%e4Zx3HRr$_GuGihs59nTklZt=r?P-FXt zW$Dr#X6H4ATHbT`n4= zMwm74F>c9c#Knq~Tz9>3xkOE~rIkB+B1WatH~HA!!LQ=HMZQq&iK}biuaEu%0;)f8 zZdXb~w_~g(qC;qVG2t=3sz3@BPrSGnU7Os7vgt}Pic=vGn^$IaAE@uK znx#8%s;#FH%C1^sI1`-#A45P@@9X?-gc+kh7RH$2L3wl5vL+0<n%!)i|-7? zYZL|#5cFK!PqGuHbiidn=GMzR9y7koQsUxvOx4M&QzrPVo|R2*;(Nx}j*uAl-fMWrsVL3R_Nn}VNCzS8c#E`c0e&k;QNMxN{d;~nL zpNAk4rPYscFh3q6TE-3f)yE8(qCt^7rFs1J|rhPp5G-X2 z3|2S<9dRdKR)ZO};4Ug-6T@-ECU9));^A@hFKLGc*>VfsbIXP3Z2)vv^{(D`1?LWI zt05E6{auKPsx_i`YX?zvA0;GLi0FD!2(FMW52j<50^X!MtitbI5o-+=k_eI={Y1AI zA&W0e$B5a|eZ}u)zsw}V{KkjX)BfcjiEFW&u_Ps3NZWbCdyKYu9`l;YJQJfr7}Y~} z9sk1--1#t-|C{rL&?x#C!*6k%^bQ-o5>~D7`pEc0>op~BFHFim{B(auHDE6&P-3*^ zhOt@@=-le3jm&P@NQBC%cg{$ReE{x_|QjAKpsjshZjIo4sfg#0t zpwiSqZ0tB1aIe21p1_{GUb%D@gBQ~o#0sir$QZO>!o)dufQbJS7n)g!McoE1=c$>Dzf}7|#(|%Hzt)kdwMxn@|KwT8tQ-mlmuxjYyB^ z8-WP0xM{H>jW@U;d(?T~@Qv@Fz-op%-vax0rG;8!-ZpZwrLk(*+k(Cxld(v>6UWR( zXy`E^@k#^Sv%8HnmA(We2HxUt>RN58wpbdCfyW~OT5`Iv^w_NKH)gST$C8bF;z|43 zp8KhN7v*Q2tAAK=Ca*}7gRe)9{bEnO zjGwqm9!W_<4OU~pfrqqN#ogS|zN;;}KW|TJeRs`>i@Vo!zJ}Qmmf!|KxhLScl~&$o zaAWHgE|X*$EKFZnF%)lT>gsl||5p_{;EPkdGjk?&hHbExzFAGUhWx`=S@H-Cy;|(s zzTE%6e3BEN3gnZcUQ9$7(n4APJ<+`IQBk+;2uXg6t0h>2SY9Tef=VmwbixiS9RA_7 zhW|+_MoOFfd*95LB@)F1aF0!MI82?E1JQF?`U6b{Zbg4?E@^)BajZKM9+O5RAQb24 zmq<~Ql?i*myGoViiW3|;n$O8my6MECxT3l!zWW!qx8H4NtBV`=Y~4Sb z&pvtMxNZ*4rAb#-fI>G`0ZnaI@}lHJmz(34z*O9+K#X&7G=^cNo<`4WHs{f0d@jep zC49qm)Q(&?S2sk;NOpX1upbL8ruOnsb$#T1w?G+q$MHu}Sl?^rN}sMCWRe#tQH@o| zG^n8B#zUThAPxyCY~I~CDR`wbXr;5R8aHN}jnKN!!mex*J={wMo*^s6)Qp*hdP6eh zXGccV^YRKq;=u!u4?ynS;YoPJ96Isj5KcGi4 z&$Dnh6J?dfy0d@0L8_gy99{;t%2Xbv{NLk7xQT5KJp3Oo9G4fr38ZZ`*DsTZ^NsCh z>{2*Yfzj5f>`XK;uBCc17*uFLhpTZMp9*m_2=7A!A&03hpb!~VtXN{<@H6}F2%i<0SHx9CT7b(l{iR7CKg&R8^P`@0mYR$CdyOzbKS(@+4zQmLLBPZ940dz8|}H zGo5qJw5Ya+`Ac+t$D@u?30K;r5%wb43HOB(Ee(kmNfPHA&V-DFFa zxfce%0@JmSM8I4=9oV!#f^oXqtUdr*xnFQ2UhYErnNUVwxpaa%08N!b)hFhR_ay7; z`as85{vE8*OMs$2?|^VSE;&ae{W(%ctY!hwy&>_kjlFhFRR-2l?gem)9XR=QyWrH# z40&9mzQzm(FQE*N;L#!((H~eCU5K3l!K8Ev>)e)g=bsrcBh4}39BOeQE}(??qOy*) z+;8Af&zEG+2J2Axrs13PHO-uhIB>B{i5SoMkF*Cbv}4ANC?A&TlcOR3ri^CRic-4k zaFo7xBRKQxtPYr=CDo+t4jDc5Nv79XBrIKy($9kK6IN?}MSfukO6n;}{QPou;BL6O zvIvvG0206Ea;PV8+g{MCkAq8~Ni)zT1f;mpqJwWdK5Friz${s1(}%WOE5~wbM<4Wg zph857RKSR{PKW0v$b9NWlhK3{e;Nf^@8YTrxmvlOJ4;emHhDALdE*r^^6cQmH8tnI@;UX zs*kKR?vYg_-UALOg+0G-%e5!Tm(ztFGLpq`$XB1Z(xUH%N5BxmRp~{#y7gOeM;0Mt znUD2VHKc_v_SMg9)Mcq4(RuA+q9S7S)}cQVh##17HO!viWI?|W?_A7!V)LlOwM4Dz zhP$!~hDpD_m~V^Q%Z^+XEQ?gYxN$_`xpcksoi=79d^S}(&l)-YL^qD!f74va zL2Hcjuc3MAV))_nY_RuJu%&89ee!DJe4x)^k+g>Bx0K{GlktAE-GI$ESd9*$Lelx) zBFkL(ja8eFie@gJ!p;IrK;J_|()@|ZGL50}$Lf6d5xxrzagS_Vnm9M;xpIA5v`7f6 zq@PVRV_c~&kifNS0q&zLH7YI0c7j2<5dO>|^5R#wm}z^a3ISKZ4Z_4`}Ze zJ27WnS}MP~VHERxxa`z(kNl7SkgOPATSRg+$c|4>-(0P&t-Y(QsU26?;DX*RX0Ny9 z?pjA&zlHK!bdoBo@9lDNaZT>Evj8{<$;@#=JTHs`CT@~!TUzK?t4*3MXJ$NvR#{h9 z3F(p3jMX3d9t9f4CMI-v392zgr73f0+eSe+g5mG6$3@f&TPKDK9C#{GGGA0tf!uE z8swjr63|}n&L#^l);-;1A#3f*8!Y#PCi2ove%88iu_QAM zS~!_x*559#9NtI~JcWL>=i8^m(yZcc>{u%!p(#G zgKL1HQ00-OD@W3|ihS{wfaQnd6ma);19vZ4TXht`@JadUBmis(jJ%|zAs2y@d( ziU73r&QcCIYQ9%t{?ue4cI@U`1#i_P8G}wG@(5r<*&EUlu!6Eb(ypgO9iPdtwh+J9X2Am{s%nAQ_%5ph-D zu93fgC_ZV8`%idpMlssIw8EeC-peOOV*2I}OK}l`?jpthT-tGGTG)VV-47HyAN#Jg z-glywnv9?`byihGe`%_z(GL+*Eg~E$eo}-;r!9cfT0=MdT1rYR?E&@RDi09;+>#JP zien%A_z|7;^S0rcWp_z{>3q%K|9z}DkCw7>GJi3|^Luy>_fD48lpHT#J~4+|oe*uK zJvjE`3CA%CKTz-@MNKWNQ0kF;Q=4}8H&dJwOCFwww0TAq80oiXv9U$7~MroH+{B{NJ#yR!(dd~ zQ^A>X{G7JxA58+pHYfVnmn;U#URW%(>`!oz<}a&k5xiDA)X{m|C@H(cNLBsc~DOUn(0!D2Mdq4S0!q84;MJ=1>#Z)-%4tB*! z753$j5Kbr77|wq-$GykJF}WtH*=2Qqfs~JJaAs8BI=*XubC#X+4i|EOi?}ukaVt>W*LH;v&XH`6&EN zyc!b?u0F+4SMcd7DRBfYE9wL>a&#^6^)p2;fe@SAS*0^p$KP9yqO&XZR#ts2`~N#Y zjG2Fl@w^CuVw-KIa4!kb0V>N3Nh#dKe=3A92Vpez0^T35;@C^9urv8rB+W+h5e?RK zPPvm?yDZ`)uWm%hdb)nS(71IsL6eq}uz<$klGL_MnJY5i)D#rfe;w^^!MPE)Da*;r zli)@jfxNt-eGMRCo@E`6fbW+qmzg+RC)wN<8!t5o2OVS`*iPO z?|ouTerqE#4i>H9&dm|NlC+~XLHiLQ--KcJ%^;STQfvQWFCrxN@!_|s(ssG`MgH?V zRdVX^1ldo8m~{6EM7`TlCB!{JM$omo<7Yr}dybqo4Ne+L8#l$xr4VSb>0L3;8(iVC zt~J?}av8#M=pc4z)29@+7G`K$76_^?&D!Xa?f(mK{SU5&GN(8w(&@nWlveKTYb!&Y ztUA9DPa-&+J)$dyc1CIZbJl09FWkp##cb&N(*Ypvh@_sh%g97?{+(?Q6{pCclX>I(U+tzy<_)`7M1|G)Y=DmZggdi(f&#neWaV8$S#N@@ zaEQnmn?IZp$S1*u#R*hGJ&hqz!SWm(>rSDX8cOUL3HC_i$t6;|RpaM7##*-ASNhfS ziRMrS1GNPfe@<$4<3!#PJYritlL}Z>N|9{z9=q@kg|P3af0 zN(KaRO^qLAjgaI&?AETBEMLz*Z}Mo!I^kYwlKK+^V>a|noYdq3Cj6pqUfka4LS;1r zmimn$etz)g_`RoOE4S*ao}?rtlc=vN7wY-Se}GwJ&}vJ+lyBO_;?>A*a|!L3b~mx6 z|Bh-Pc4BbBK!$iUme;q$jkBx3SGBCw)!Mz7#jE4vd99XxZQ?cW@9ySX3^$QByU{+U z_Eif7R<=bH*Ygg%FCgDRmi)&na-t1oAK|a+fLMWd^ETsC>wohoAIvqKSUV`tVL>XZ z?7x_pMqsJvvNTAW0*~B*&)B#@z59CpM73T!gnR>G2w6!-7Dr!De*3Zl^ z$>H)K7qUe4lHbnn!%-9uK!H>7{v7$iU_ViW5U7@^6SVr2`HZB@x%v5C38Yf{*E=4M zN$(m_eQPAjgm7}gFr3D8n{2-3fwNxVJ%Y=FytUZj3-9l^k3!%7pZI=VPk4Og+Izv&wvS zomvp50#r9k#RH>Iy#yDU4D2Xq9wB=z-@;1tP>>7Hpk|&IV%*y;>tTUpOn^G34iqK^ zV@wh|_Eg}I+6PAF$ZXO$2ctV&+=#Ii$#382S+4VI*WmDkvRY%lG%6dr@-pIHhlmQ2 zrmO8rlF@yeUO##is(ktqLF(e_8rd56=z}sGh+u%XUdRlfg9q=gY=nJAV1X4wZr!a* z5AWS%(r2D3!ad2eJ+fE)iMCUa*ar^BYxs-lj5J+n29WD*QrBL^L}9KbKxWwLR1u;_*=YgmhW`RiATgSIs^N8=%*CJTYWqK zi>FN1S8-!ya%NV9h~4i=(ppUWaF(Pw)0_1-qpkVXVguiwrt{u{Qx0nSwuhuJmfl-? zZ^IC7kgJOQM3@?TzFv>~x_Tz_A=Blt3mx7=|NEh3LX`D;hFa14dkavivBtwn{o1;I zBraF`ZQGEMR$1}S@pyyLNxzgMK+(e77-=Tql%hYYVZ%&jV6?%^`1l1TozodNuuuwV zCKOA=rLk%1qQpzK$|ta5P!^kks74e0=~U|Xrz2&`48>2|xbllbpXL&TyzA_rxf8fn zEkN*^;Hbe+s7h_7P-4{55ybKPm^==RO6H_n^*@<;KSDGp6VTbI6T$#oatceq0V1nm z8|l%(H-epua=By@=zd*UvxBzzp8GKIbFgaHTEIXYx!WI-$}=nOHj~}HuIl=fd5S+EzS8z(|7MKL=gD5Qh?U1MJT zL_2>#bktZcOpMBN(9YeDkdJQj{hdEE2v%se)@~Sn68gAOh_j`RB#9fh=fpAJX^l5J ziyy~?PDyHhK=x8y4PFTe)!%prh1WFvJ1Ik9m<{=c%U`}Lk4T5<7SH}zRE}{G7-E0? z4D!8o&b;QC%#!YQ*f6)Vf$E>SwcfL|r<-+HBqSJG(TGf(eh^v{Gg4ARv5zU^^X)%@ zK#=%fAoQ>|Rw*Gw^afC1Bl^VAcD|V-V`=jQgg!}O#GTyzdLcOY>Og_YJ%>ayTB}N% z`@sto9$!N|PGZ~W@gjWw*iVKL*BDjAJS4`}Q+^s=!SJ=Jyu1x*riWJ_UtUh`_n6sr zy08`4yqeLKe`*w0UG?~AIhBtu(TXjfNjVmnnHU~-wx#nrb2AJ@2h7_^B+CvgEI^3G zQ4jONiekZz@qyf9aT z1Lp4&=BF}3`*=w=UCM-wxhc`mGVm`z8I+Hf+Piki`^BksFlURNyP0@8lpS7J$%q;> z7^ytPTK_cp!=!Iay9}R4(c@Dcv~zasIhd!tpe{O5fAwT1+?gfAJ4 z6zo6Z7t(?U*x_O2&FSTOg*a6d&+WCVgl%lC#C6m!e{QZtoK(NFp-vxEYpKzPNfCOe zd=-z9su_uH^zNLpn%ZCPtZ^|= z`Xo1ga+byK^sE#99Mlal%DbBz%m|1aehdsf)ozyTz!4Ij@mizhrKU#YAbE$9TVrFm zhbzSOpz>ctjr;&xfhPCLY-rv%g9z{VHBw_Eo_4Z_b z_gxo#BU$!zONO)aOjVsgd6@0_ERpw29H?HX@>Mk_qCdUVM9ugi;0*sfsVbWJ&Lc$p zWb4h>!{pYg6ZpE9`YHzRxV4yI*cz->jB)>tM&xR@1w4WDhm>}%9zyr`_Z_OOIj~hU zG`LHC^xO#5#|X3GAIu)GW5B7XQWo??i%P;@r!n^jI%^uvthK>MlcIDOqj&6+B%opt zf?MEjVj2U|v2NeS!=}~9jW@k4qc2sv_*MeW?5hT&;)Emd@_Z1tsiY0L34*hlmi5T3 zxY135Y~R%6x`0Oy?oz7oSSTo_ku1N-%XDIu7*4%I>ty@n|AP5!D(W^1-k6yQjRWTV z565x-lz~1iP0f)^l~g#Ll~&C$iE;&HMZvf!u=zR!^zi{I zH5I>q0NGsjK@*q*!JC&aFigU`wgZ>vungE3w|;mavh z7EhE0aE}B%v)(V407>Pp**;%j;q%(;W`1hE1U2Rl_GOwrjIwm4sC`De7gn)cp5V&C zW@X_x^@wdG{iczu{wkfDbi^s}tIG+Hh>PjSh8dM=q{g?Ylr7HL1x{`Z zREs(IesdN9wrWsg%Hn~moLIkS4B z?N@6BsF(~R;TRXu1!(A))IzDm>81m8lyARkFtUTUviPL;`7N++edyTxqPOt=9ZUND zJY>)~Fv!n9beH}oP5&P!A=LYuqW%BK2h%1)KzcTi5myl}Km%rxOqu@LIlOfudk%s% zrVIel$Ujy(`nLM(y(^q*X=W*1MPM(2J%jPk&h||d9Y%`kyzP3i_4;mAMQO<0vxn7< zDGslnM%NI~lfT$k{QRjeWJpI~JTma{p8|CB_2KRC#g`TzdTqkS%LefaR)mjA`h#Hl#>9ka`RdAf+P@@TBa5VK| zJrA-O_v7}Tyb@|@zRve!QX#c>^-b;?x4^v|O}+U7ourlE^9YrP^(-zplMsBTM>H

<~TpNg2NdI3h}RToT+$oWbJa+F2iJW9IG`#9ZGSjLxgRs4@W`Wg9<7 zSiRCeACt;OqeFa@nZMw?Ar_q<9jSax{rkrs;9FU_2vbC68_dTLtt+)S zJ^`bwWkyZo(GWc2^M8Bv`Q55YuaQEz*v;OJV7fbhz7n3r?YLhL zFdFYCU)dZASO|HfcCV70Zg)GQFXrav{NOSKK>#(P4So{=Aaz)Ty}dnk+N3Lk%p|K3F0$p3OTFy}N|Ca1RPWSsmMEBeK;~mtHx=N}63L-^(qEmo22c;aN0c_?lH1 zdzJ0)gez*NWAnmj#* zNAJh+Vi|E`ddkA~&l0+-SpSh6^Khu)gJupl7HY91aZ~Bj4tTivbQpNQ`dJ*QXfiSe zrb4=ntT+i&4w(a2w?%n;{T=?`cHZ11dE8A8*Sh~POCawbn{=&r2PYOZMnVEtvKMmb zQr-*OUpRU2 zb}2Mbz;y&7DwDX?F8h|!ypmJ)lDTX26gP9x?#s`N6L7F>RO; zwH&f!t#AMG9tn%$&QAWDEtr{r*uMLcSVQE+-p1>Ut8rnW z*MhxYLd4wt&_|hR<}vH=0ZV-^At85;68}cOfPP41I$M`t`Cv;coA^&#TRDMZDHdZ* zE7ane4~LXeW^nVqSpxlp?iVPmu#hR6pnTL=`W%qG}=anxm6XaZjDYZ9^!kpm2Rz*Q0`CY8Oo=X5!-tNzLmPWzzcEx_ak*ZtbD zY4mTc4FL4Q$VGTVDH)@S#IGLEVL7{CfBUczpgfGn*$>xr_b-cl{OXx;se)gIJY{P7 z@)ej#;(7@%*~e~Pw5V=&Codc^mPZ*tNVw90QVlr}%t;T3!jh*ysx3;j;vHdkLH?rD zsu|4=ABPk&B4S^Z46lr47>E*-F($#F;S4kyL zqD5@O`@LODeV9yn$xT8$+!iy<21jgvz-Xq-!OXYsBxPm71(XB$=3i~7O0Oqkop)mm zWg8uM6~^CN;&Y!_H0^iN^a76^zV9!3&m84G1L4@b-w5LSkQ4zMbX0TGY(f5eF-eY# zo3kLxCLQMBXU}cRMq~(pzM-U0J>OF;Ob1M*vpaZoWu#9^j8^l()V74gEJ&xAzAqB< zMPK+_A$k1khEvv2gAZG!>4P_=y6em|IxCucbn}f6-F&*Vg?o_a& zFgV5*c=Wr@m^%sM4A+#h!yhp*!Tso(d?+<|;CsM))Ox-G46VmHsa6gP zwbu@Bt1Hoh;2VCf_o?IQW+%c`E;b?(M?}kvj>bFjqBm8{AAeWZ!S|)PIkl|W-^iFt zyTDr%@Yq%h7WBpX8yx5sCU z2={||ko-Vx3rol@?)dww=UXYN{iMtHe3>$x?W@t6)abNr$RI@<(*T#X%Ksk7>WQ zI`UqlWlFZu%xu~}QsNX=WOg@~l$=LY;R*lupC&Tl&H~mT9Zc0FKe-t^Ax3(H3JM?Q zasY|ly3~8J+Q{vbbPNW9dI@uS+;kG%?86lT3>Fne9Ma;C23W@<8wiorT!iv(! zj#BcutxWH>z{or~o<^_6*iJ|{LhkbkkKw@X8Px%79f2W|5BGpD&m(+=;Nl8a zCcTMTe+saaXiBlUl?1TR`w!T{TR*x5k<7Ia=2swnxIoDX= zpbDi-`j9m!Sq27mdiq{58|A6I=7 z>K+DWu6iLM86HzqAFOm>q)AX7d}JD%wy(+FSq1M8P2fZ8@yf#%&_46q0UzL3$Poj@IIOFyjUZV^vBH_kPOSx@VI;iTdSP^AqE&Xzg28`f& z{#H?Hf3hD-EdOU-&r<>fvIcSezjzynSp51+&Z@&*2wp_hIIz2_Z-L!17yl(!-&Z>) zozRFVn$oQ-Q;Hs4+}RetMekAAX1x|8KN7J>b(vy@PfT7@(?{0X*_oOVMouQ$5IJZ* zBN+S+j0=7lv*G(q<$=;aSXf*2;*fQGc%;jJ&?el zwI)I)s@sUg^3*+`S@9XyS&oSmh@GR@*om5^mX? zIF0sSJX`mgSJ~YDnd%5|Mxe70lghTZyXgL4!p4l8&&9xJkfmE>WMHWCUUc)>c|n?5 zSR%*HS`Md^^GzjsZ0KKi1%zI<_%z2RpwLF5p`r0h8g;=^*hvB+7K{9}!ehPE9w$j( zS6<%!6b!2r-N`Z95%cmYvuycSHqYem@=-eBY)k@w!5p2nknaEG-LR1_Xd za7PugHL1V&=@*^yakvp*pU6LZHc{XLYq)~Tx4;HeFTvceNoS9;Znh=BJDvQuFhsx; znPnZ7ND56&hZZN_JEo%e2O;6jZ>hrWv9C_YbGWe|i}vszAsjTV7AGEuUJd<2`s$ zPw)8Wlg?~Lyxu^>we6Mf<<6UvfznAlMx(YU&VObTHRwdMtT^P;l%p8Bvtyh#X#6Cg z|L31g+PJg7t}_Rr?urAjop+~U^PSWupFER@P!eKcLI%RfROsP`^5))KXcBTqKNGW^ z#)vW8Nq3|t;~-SpbkBPrVDxxqq*We3NOrVq`3Jm<*BV_+Pce-K6nVzr!m<(9R6u=M z%{y%&o@8Cvh}8!E{=}0SX=&>de~1f&9ML!iS$G$o9NV?!{EPnwNmFdeLKX0ej-=Vf zKM)xNF!)Tzs?x?FQ%Ze9YsVDQs`o&Mb|W(I^5Q}q@!-xtsyBF#&JVy*EVWjk#^lSq z#mc4tfeY?gVS~1J)dtI5fL%4E2H{NO8G{gkn1G>BTKuDq2D8atNiq+S>aIszy_bBn zNSkqD)9e@vTJ%TP)`W0mJ2)XJN;<9WkdKVXIiztOuti*w1^++TW8r)i|9S@6lB0f1 zEV}PiDRkI)tgJnpk+&3P?&&32l`JtVnsMOe^9cnVIwQ>|DKr^-bF+r$yGv8n zNd)Oc`2A>Ide&@@{FBT?NTG+F6rtdX5)L+L-EPS19qxYlhJjlCvs`=WL#WA)*tM_8 zwwRG;+mge$tJ{U1I2{gEPENXfvuXCqj%59-(7B&v*)enzJR8~3VR&aoF%%886m{tE zZw2ukWEB+*50PlQXPV)?5u z?g^UItt0b35(l=0@0!5I&{Dv@LFoJ?LKVz*`2KmB>U}n^si$sAi0in|^~MVD=w6}d zJ-nfjFH9>+OApB(@Y14JDy$C<(%(G>--m*SVAf>ak_XL(nbQmcn^?PpMd{<27Y|Sp zWKLi5Pb*j{60DJb&1VCZ2qku)g38M#G;-dF5+5YlGcxk>R=_`K#!$JHi&#DxJ}{u0 z(KsY=l{s0yAIGLMG>}5@`H<@2oK7SnbcmMM)(k(IR^eMZSX~{a`_d56c z%!FO4`X|vcj9cCUpmqg*w4sus?CKdwN}x_x8g$X4l-6F|IXPPjTBnf0R_6ua(6}3U zT}fPg@=nY0P`zLjKLQHW?9D%pddsn8235^e?z#EiZI^8hAityBPVnCu&lJET7#ZIO`z-xyq#X*P<+J z#^8C+zdtPf&5H)T)jwCauUgqf_y6&>tj^C(V`%fR4Zg!c03&9=;>nXh*E=FXF~wab zL6LJFc6-8D$NqSXLO8H8j5oZpx(yb9Rd72Go=(_}hL9b}coC42&a7f_K|vok=+R+C zZe?lFJWT>P+ME~643K47%>EbRA7RE>PZ3E@P6Jud2O^_bhA``6syN}!Jgo(ekB^Pm z2?n!+#qJQ5L579*g7t;fuDsc+1=AKNY|8s9mH|ArH6=&Su{NP`)9WroRwuyTz);q= zQ&c7c*N>T2{X4=rV+-HB1N#eWQP{IX;LM=_JFc4|F%PM%)y_e#2O_EFyFF3m7j<<< zDawSf_=5<1V671j(8l1`Iq-Mw@=DnGW9hgI^Vc08SF*cq2sU!_kS%H%LQ71Z|30SS zJ_;oo?OE&})kI{{csbPk-&QUFm-K-Q)4>j3ks^^ zY1!>sP(#(UpkHfdhiT^e5I2Rz^=vzu=X4hEUpYGw5cDY}QRrpO_iZ9r)2neWpWVGo z8!Gz&gSMH89Pe7FxL8H2>xEQqvD|UpZe-2%8w9=Q(X6_MKqT;oD+^ts0@F*Bd;QyF zT%u_(5@&3Lmev<2dq?)Fx%XSqCy+_Qf7@{e3*!wsK4<{3imO+PD8&^c zB;X$Li5v!o7GFWCeh@ZCi+Wb+bMeuYSK1%0UbBWS!l=Qe6Do78Z){r4@iwCZK9I>r2LBxl?|Dw1TE4)0K&10b>(A&fABDmuI7wqu9LuUDtPMCRvgU^4-q~%kEbp zct6U~*~37Mwo6k6gG^N$v)B zug=AOehRRumaA_$-)T3kga{!+OnjjOO@`n>jw^v7WdcrGRXE6WwPCK8QZ)9;_56;l z_P=vKgMfx385^hn{B~zKj5Ks!@3Pv~8*xvk{P0+Y`pGJ_f`Nj)fE+WV8Hr!T*n29iesTw2)gkzA0R{8!44^e|P#~}S+Jgg*D+g3>R z&n5LYN0JKBjfdXX0;l%`h5ozfDY`6EhUK91Y8Eg*<(%gKBk3Br<7~HR<21I>*tQ#+ zlg74fHg0U&w%yp6*lc6lxzq3df?4aG^BnB6_dZjkLC+jqo8+tb?=-`YXz4L%f;PY~n>rcth`R zq(2>Bd`cenQfr%@v;d4n$=(Vjb|ZyeX-XzDV^1Jnhs<6((e!NWm2g9fi_<5k=SsLq z+N$>jzRG2>K=5tpgxO|-sz@6s)(=?m-4P%T(p<;S2L8Y#fX#|}>A(%9g5zPM!zMem#`}@&IomO?ahX57W z$TA)2{_!5NE^_KRLxOwOv7KjACq8^CsM~9TV5;`RQb;8;{%snsx{y=_*svS903C*z z{z&?J`z+zvy@d?GBrY*ZH7}1ahtviBmSF1q<1)WjSjiTHmWxcs)Js78ayhei6YGnR zXg-70c^2K=+<&i$5;qVfrP@45eEH0Q`%$bA z_y8tYy!!X%ldE~+)1+FA!OTcjluQN!?CmnvKUmWwR_$_|=`0waJmc6UwiRs7N0uBi zIaCV2nS#2sfl_zuLAHMG6v?T5J3GLvy!|&$4JR? z*ZE0l-wM=VZV@sy*(0rhjV2-+Wy^_El%QsP?jMGNF?E|IFEjTiL_$Gt&ls*$6b!eZ z4;85`bsu>U`h7B(L%w&-O*xeX<h!o0eVb817&&=SPU5d2y02oW)ZgvNGJ8L>s;X=stbhzT@5^dyE`G2fP?g~UYah{j1T&|gGY3HBVoK|qH&JC) zl*GSJ4_apA|K#c)z2J2?%BvGnpCb_R|NE?Q``fqe^09Yh2@&5girmn-V=$DN{c*Do ze5z9)G_{T!)V;}8Sb2arExFtSyXm4N_-9T*!p~|;aDUtimJf!_+GJ&}+KE3^F7EMyy%408~!YwR?1iF^*c zxBEbMOqI8l{KrF=uYcRFZ>~x|E;4L5>wDd8msQpTl+Kb#ZJ&wI6n@%(yr;fq8GUz* z9RX-)X*FeiJRo&k=APP<=6!qzV!&6Sn5Oztww*Wq0!c7&Yr6lM>D-Pf8ca9*t<$Ko zKnMdaZO}`J_Xv5~8st#AV3IOrJXLU)Bwfr+UMuk|(daiy)rm-)IN((Nm-fjp*!^#_ zhBL$Ds~G~tG9Eu0-HEPjf0*O@lTW1BGHAEhjJdgiuO=c$fRpFP)p6P8pS7UCGKAEr zlYC#TOk#y=ifWM2exmL)t^E=Z9G~r>xVom7&~w|9u<^a#SX(^?v-k5ul_tjZqFWN~ zuzkjJZFYp&&n!qe|B*acz7w0*!u^r7Z?CAa%oHzoe0!Ix|g~ zT~Xs~z|596@{Hiq7yM%Kn2w;q#F^8f07=4dK*}ea>qUtMVGnu87sAxJ5=jK&@CFv+ z>os^LhXItt7aJ9nuN|3-RyCljPq~*nk1nRB(@>h4!iYO>yS%6drw)$xY^9qR2?!%*uTM;H(MdBFYY=G&pQz7P%2LFrem{; z&d+o4hsH}{uG5rQh0tc|o%S5=LR1K;BpF29h1?6(d>WLuAu7MCtH(Wp&*F3afF8K3vyPl!K{9!8eMT6x2@)>eaiqR0c5*jcfiwJ!_}(Gr{v z%~i66J$x?lsY%H^y;W6J?|;JlN}HN2XyY8KUe5JT?|Bo=U@pmC5ME-$+lQ&rx0ni;Dby?TY2#Hp}-vmzDyght#ow*V9-z%7PElXaJ5kwbKK; zRN0ffp!G^CFa_f4lx|ptzxt}fq+S~e{zW&F!w5MMcI7b2${U%#(XJwUXBV2c=x^I z?d`~lX#9BN@pDf3z(ItsHR1E-J4M@X(83(ps-H@gu-q3!n?PZxAZF{@?s~ubLyg84 z_?4;@3>-1uPOm>^COBIwJ#<+Ue`zNt8gnXo^y!k7_)|J?t$cR$giF4S0ctzbJ6YyI zg!tom$3K?lJ5P->!&%~yapp`vOUtD!hi%dqSY_Dpo$jN2RT z3T~i5NC9eRL_%adHEiHF1XMKFc!7Q9kRGPD+LA4 zvWMI6rpsmb9$k91e1k?TnKms|hy>Co9J~v}4SBuL^fXFv+zDXG> zbLwo3c2oQ<&%Zw~xCvH_KYXUmj-_mxka%52$ESqbx(V&yZ;7jJ6QviIhZq|+PVc+R zH+kUtYhwT9w?V7sN!8Qg>#?ZaM-LXiVT3c+9bp^B2PibuVGpW5E}R%+Q8d)_V)rT4 z9tW|Hju|$J_vpu35FcN#i2S%8`0kiJGBey(CQF2re7|QdhJFQ4sqWI}wm=oD7s9`_KY~sLkY{t=Qf;Bh&2vRROoH5Px zvj~0i^ir}-ZO;DIC9+g-D}MQwPfBXP^$+%5wOiwihHP*@;T1vf2NhayF*?VzO4&msGK!864PcXuyQ-Bc~* zn?L-arWLwK+!Y(CzymdPk4TNL24`GDQ}Rizk}x^tgP(zyS$8qP#Nz(>uEgoZE=hwV z5k2~deBz%}@wt0in_tous85ypAh*VHx~fD?uKNU9(;)QVmh z@DuMNwOr9}C}KUEF#=d;k?*nG1c?u6Q_7w?QQak>z5c2$zRV@{cj+En<6`bY9h4P8 zQu05Jwz0*{DDl39w95L%bh}=CH39t(!YjB+cZ8`0$}jXN2;YQ+h+#hel%{oW`*~a+ zmtHT3A{*_m)f$d-7SxTh_)v#h&1d78WUa8(_xqo!^pe6vJ;VA;B7&Ok1T z124L0{}m6B4puZpcgIcz z8ZSq~ds}f+ZDEhtUk4EsF&Ce0O3}5isjfK&bjBzjM}?I_KdXK1zT zAZ>2ET#z)~j-AU=pDdn|*AS&bvVf7)H+NL3XgPermKC1|Q4jR(QGy8*W?}8dU9P;8 zCU0<{@HBsc8JpMp4l9N+7$mB~;+F54 zi;3X1NfnZtQ?WPPg895Ci-#~u&%acpBvvh~?5o0OYeId()~08mAM%A^hn#d&tT2@P zp`f5CN=0tL_*ki+FohSKauR6wYcQH{`B(fXlgsbzxqknW>=`l0^pv=<1y#VnF$z_d zGviXNZltSz&Cr!HnF{mr`z6+v{#_`-(Lz!?KHly<%#H(w(Z?KxhRYw+tccv?1#X5=?vbkVcb5l=# zjy&jYSnLQCh^V%kK=Cxu{kXSio+IAa2}MDnFpY3@w!jEcM76-$t5lvRZE8vZAK149 z8sJ;qUwfJ8oRy^u6PtO7EDW%&mVedwNzGBPh4)NK*;)*v_@@(!jpnPJ_wG!?sL}l`)%i=@53u@flyjI_q zU$l2n(I_+Z0w3g{43|OmsD}yGSF{!@>a0j2uUS9V51+kBDKQ-D&Z}KkQ>#Re7R*Ey zJmCbBb7|=+wrxz#yo|#3aZWb6Z|Buv{1602BZoV^==lmYYyKDCJ)$p3OUs55DP2O4WbMC8yh$pgq`9yrsYzXG#rQ&6O#AdxKn|7{Xm1x^44X`yVaEI3*tyV(vS!%NA#qmtT66;)2-?Z zV$5_?ry(p*4wq9XJ{Q)he^OaKqNXuKkJ?f+DlMkj!UCpba?@Zi=hIv z05f+sf&oc`TnJ<%UBv*1pk`0PQaieL^O*@rL(8vg?6U#!QZ5Nztd#z4fz0eu+suT* zQsB*BjuSYpj^68~TK(VxJjR)e1rCP+ob$Bui7Ks)4WQ-pFRc4-sB-{v0$t2vy+7Dk zZXC0D+of6wcYfZE^=SOpy!_HWoH_vUDdXR#f7g;sHWJz`CdY;hOLHPz^{a-d?0X|w zT^R*9;&9oOTz-;3kobQ=ypLwibaL2HmZWv`249vU2c)>gN?*k*5sbnHB%yr1@h%4S zKy4U^zpv*h{Q3$J56FNiGfVsiY93)(?yg-lYq!Ictm)_#s8aO3Xs{RsaFDnC`1HJP zq}yEOsd3fTew2_emJKA)XgWyS**(lG+)`I}F|B#j(AEy#wKqDxP^31b6v^M2aHtg^ z@o^2y20$0{aIM-fo8dYhVx}6OhDJLiC!bw4=AByES%nH)R&rV>_hz zoZks=n|_{1wLnAIz@4}ig(q2zv*UqK>dp@EIXib%fv%~ktYI0Nu_W*V*)qN;f9Gyd`4OK?`xhFkY~WVh4S%0 zke4+r*|C-E4hy}3y)x=XgxPfwZTsCLo_VUR+}y$xwG{9YM^NKeFM8mnLocW=_v=}V ztg7pVWNqlMFD+RWHcAZ)0>%LOW>T`rtMFHE^NPF4ho3QlMlhWw0+v)AjJey$y>rut zI<6E7Z>uyGH%$)ihkbgSVKwTh#@|^?zmufA7p?c*KLmSe*8?ridZ@d;!ETduUkXCy z6RWBP5)>J*8pq*F1B08iL8>*fbQ?`gekAc@Ab49NAQxSvo+}pNKF`_S+BRAo^TFZ5 z1Xfx3Ql~NmacI!KAy37BSyeqOZe~O|_B|UShFjJ6dA}K=)9uH8=bt1bevsu1f$y;8 zFBcU6!DXJDZ_L*v>XXU}poV(v>g#zV-KU2TW5+&{+?p}6g5^WUrbVe|tc3fEJEi8) zYuwSD?byVAM(SPOGP%j|wwAhBL0UK{Py_4cr7ikmrWM+f)vn83-)7MjenZsu#;`GU z#7#_tSSE-?=Fpuq%>php!m-G^F>Y^r5p*nGd*p4WryT1|6qo~hU?9;AKX5*~FQ$13E-g~i(=lZ|x9s>M$ctiT5SyEA$dwPx zBPJt<#YOB%oa3xipo4@N-QG847agsb1(0|@AblFGFt@y)7d29kb27#ds4+~-_U!_D zzJEb4ksr#ZdPndkbV1?zoiC+!;*gb=>cWY3jE6P_psrpuX?2bk^CFBcWlxN!hA#LQ zun2l+cn#IvpdaSQvUTd?ST_0yfZ6>xtr_aH|A6ttLgxa&Bj3;y*qIfuTZty(3kkf^ zPYOL5u+!LPps_D;Rh2e07;u{GA^W#spY}dOLm^4V{`Ey$KCQ7I+=_f0vA65M&6Y{* zth6>{HnH;Zqah)Z`tEEwHgh;9lYM^9g5Lc3$HecGQUoQ~4pmecCqgO3c1~Qk00=Aj z!&12{%?)PE!FhSHzFjKb$7HT}oOutL78jTP%tAe(Ca812LXs4XzM(t+3E}`VT>wK~ z1EAS=pdZWEnp7c!8bOJUFA-0SlH{_zgST|Zfk0>saiyQ)O;pnZz=>)03dn`uJzei0Gck|@E z3s_!sl|Qy4AD+)`aCUANWh>?cpt`?YAc|Qf{g+m0z8|Bntsqj%b6fZIw7m{_dOQTY zUYBmaBVl6xaKPfzxI`4Q9s2mL)uNJpxD|=L2Lb>NtapF4sva>$j}&3t85Wyr+ylcr zE_NjiS?Fw=gAkoxGE)xB+J>Y0BDOaXI)@?{F#jGKQV#DramQ23ooqBZHSl;-rQ;o+ zRBN8R^0bMvK^dZufYK+%*UAvm@46YU@Rvs+WU%C!H(i9PR5gv&h|wGzIvQz^>fBq! z8i2BkXuMQ0G|beJkkrh+*=^fg!PobFKYPQ2B>P%y5)K#}l!v5?D2?buxuAeT)K2sS z8(P|3HT54*odbO|wLU$3D?myxspF}s=4PEB;OzW#5nk-cPEh_UGEoO`2xrFePa`GupA1Ma`JhBunNh zY)P*BaG}n>@mg3)AI(fb*A;V_h~IU!edYG&P}^n2_Fl688^kaI;hDI2{nUYQQpk7w zOxNUO&^#Q$#@L<2mH#e7S4xjBHegNlZq>XtTQcl(meUDTQ6J#-k_3k58^H34Voi5g zK$%UZg-4&TO-QwRZQKc{d`?VHDL>+-!@5Ez5Hga{(PjKAn|#PD7$GF#T4LSONI#{x zBjObKL|K3w=?H?X3lXxi#w7)zsM_Hf3u!X0cx!5SY(pH?u9+~=OkqSB^g-KU2!0P++_%`;bRa#5tO&%XnMdd>R?ZBXHTD$|nUGVG zVm7mn{}*vMJ`Q69_d5E3fxkX6LeChJnoT(VT?e#XR|{+#CNTw==$WN_0~&u26^co; zyZo}ce0UCM1SI4&NXxCfB=OH=h(kDJic1`C;NXD-aC_-6#3GrkP|qhb0>ZZhN-t|1 zUi5x`K=Ik~+x!IKV=dme9@BUS@*KhS zX0ZU?cVu!$!~k5^n=NPijj2d6>1cNx+S?w!Wk%o?{h8x7xzJd6${CgrTrvorRPgPZ zMk>xyx+9rtWYi;$C#Mo;rJS=s-eQ@IhLeblH1Hzd7k3sCz^=+f~j9PfaV#l)mQh$in#-fX6k9ozaKtQnc;inU< zNN>0VqOGcoVmJNCB~DU~oUZ$rQqlAz;km7p8|pd{D}6aY)7O;ZCNg3Ll3X-u-Y-!q zQOad`ii!F=pQ}JgM$kJZs&1rvEfT?c`GLjs7BOl3!}=m7H%Z*6gR0PWJ_(*;NZ$4Y z;C5wh_9sk@emFLfKe)s_+hHAR0tpQpx@KI(K1jyew|pkivKk?(@Wz4!imIp$U;d26 zXH}bxQD8EzC0=rx6j%5#{=EhoVP<3F^BPUHL{j>HC0kbm!t6*%9T1(Br9Cl1t-+>TTO|TH3luNl99+KDgn##yc4M zo_P{eV6*^YKCgo0?b%>LHbX-KcuDNF7`VLdx!sO;mZgRH9D~rMg3*ZLJzCQpGJ?i( zzMLIWcB~NQ_>c>WXDE_VCFcBnTE>p1P=)9t)n~ zP>~*#fq_J218YAjuXdujb|!x0=Ksz?*3Mz*5v}UBkk7K0^p><~OT|73XWCx7`=~`h&t)1J+Cf<&J31Xuv9c>?tg5w#X_;WDGZ2r(*W--_R8uS;rDs zNgG;S#gUS!Wk1}=B+tTwxs1f1)Lbt#YoT{`tn3xhThW5*JIjN#Yzy1sQ`|r{GHkG$N&M}AT4%j_A4^I*bqG`(%}`ovl0bT zT+87Not4N)ppxz;6PUPoUU?p8F$zF_*lyRNf5k{KJ_`ahTyi-vK0mBF?-v=3zQ-90 zzt?)Tgb5hT+4sx%3$~Qemw%jy(6n#S>4x(r@d`qH_i=0>aTjoE?7G~_D`AnM&cMd zk9SpASWSqVF0HQ^)`zlG6O;VyY_I6iUl=k!cMx&t#bh{7w@>NYsM$sgFFchv|Bti| zl(3|$Jkxm`E_AjhRC}?4<)Mr!ayV`GP~hrnc*_H)OH)h3m6X^N1(rUoi2U&y2FHqy zKJohbSrxMdF<0TTPt;@&miMSr9Wy%~NZ zv(q1kn6gtjWdO>F*@nzeH$J{0yQV`PTB&NH`9>0rOkPL5?ErMnr%SX(pbtrkI!;VlvqXi)*0gX{w)N-ucDnW7t>B*<=;P95^#fa9) zEwba08MhJSh^uHY6-G=lViX-cKfn*X|D$#5a5Pq$*WHboZ$O=LlGQIYtC~n%ilO>w za(pRTx9K}Pqa`s$T5w@v3=n4)GPe}i}A-OFtW;k7@N^Ag?(eQ@xP!Z~;aAMYbDjJkks?r$msfOV8&*h<*!O#uJ#(H#qi}I4+vo~$ zFPxu~DgxiTK2k@sRV4-{l=kSix7R1-4{jq4*X;h+Mj>cvpwsHA{rK@Uat>P-NeF!Y zaB2X8>m6x^|Ax9^oKE>qfyK~S%^7U<&JDel+;Sb7$HtV$NBu~>=cZSd=QW&Mh5(FS z@nKn$GA*mM6tc=Kc0?$>OS?wcGHMSZR?UFxN(eJr|I zNDvIL_@-vN76o$nMUs%x=3`OL)fMnLX)N%L>I&>92jl=N?tGkv7@!Pc<3rtB0`tC} z|Lmb%6)qb#wDt6?nr}pdE|2gkT#TN?mud9-Q1tJIj};0(IFX?UG@^hiA7D9GRS9c3 z(?PF#0AER(bwH#W7`^T^-YAl$#%DnbF&$We9cMLhaw%B8+%JTzu^vZdN5x_M25`H$~I(T_culBR*V4X5%_X#MlBNie5s_tc9YCMC?P#ouW*-t z32oZTP!#MNcb*vtKV%pI8zGp$KfHK(NwOfE3n@aC)>pn2U-Q`7Z5}#R%S>8&f4ub{ zSyL30s9D6T%wCaCBFpC5to(zGX?(2cHy!|0G(r;;Uhti)C>Ef@#fmvA=lIg)CH%;g zO^C`sDkBOlDj8vS1ZlWT_I_apB~0ehakXTjx}A@;>UT(&c$HT|c@YYYfq~sfEfrEP zVyxnI-1xX13uD&<)y}gz3NVU{;yx4?fTe|1%Ds1+j!tc7ENQMhx5i9Ei|UD%>EhP z8$cRr*wzapXbjJ+iEERUB#So);>U{=p~#ny&ov>RbB}gg?5v!j`Pc9T&VsgDjSVHx zY7uJuds1>4f`pITe)u_^m$7BkwLd5R{sAF{eB}PbbK$D`nJ&gbbHGO(N2nG*kVxC< z!ihT!SyJ=|rhSxkI=H-78Yy`>5g!zc($yEsmqD|0ADk{!7*n!0DpjH~i zOWJGeFI*sW@IZ%Gvma+!D#bft&?Gw=Ski+IuN6kAL{Yy137Ppjw*F4K&AE zaob+YI*aud|J&~Hw&7-~wUWV1P%`NY`!s?~l0A(Z=D@yP-!`!yBqr$YLzFC2pWS>( zVNXQ_n>Pa9lETa{b8g@Ylme9`f&>k`nC01}{ettk7#zmv)KKr|GqTUiCD-#~IjVG9 z`mcx6E171op>MaFW2 z%lf^cvo9>T6x%0qQjx_su|rW6iE)+2iOrAy%!u{>c%CV%EDQc|d}Tt4v8cUmX+8hW zhJz4MP7T}rty{z#Qdktk)7(f;50yXp+cOp=nt#8(R9krM7Du`v)W+vTJ=l!bc*iUD zLao;^ImAtl=fo|}Q!%IkOA-r{+w)HB^YP*BVYFaz_F4#0cb8_&Vms5Vq@@$pY~3s3 zEGp1TrOcfuDq5vs{uXbRU42(gLq(MRbdXu6F)5RY-@yTHd{B!0&<{@Fa0i5idfN{7 z`Xh%wSF^advC1o}i=EsgEoOz}c7$*&SYDnFCg03q<;R+6HHv0#EGjDM6>1IFukal~ z+;DX7+M`RY^0?Auct--WooD53iVK7?tEOVbyRQ$x3&DTYRX0C28Ies3Bu|p65+gNK zh^*Oez*XG3^?o~D?KS+^3*mmbn`d#Y^_|_b>mP2m2Ce~7Wou)4bFIgy!Au(0HP6n< zItkK$kteNbtv~siyxc!~E!2;Vl9R>D)sEQp5-M)`3-M}e@8;RLUV~0_k_byGhhLZ> zn@wg1q0jIk@0(ANV<*?H*ACiQGvln^%VV>}yOl1sCwj0`$6HOzgo*oxq34=vkfQwu z$gGUvBr)bk&#bNCX(T=HU7tCF#Gx1)eq%#VckjsY_H54X`qihV&V)9_$8Zq@SDsxS zXVZL2?Vjk_o&r-d!=h^z~d!H+(E zT`*9eRutERt(I&QLGUdU+y78h`EOxR#1Ai;UbbpOqjdTw)>XPH#e^t+m(%lsAc{yx z_$yL0eq&iZ_1|gm9GK5}o zDtwJke|y!lT0v7hr06qj`FqQFUV7Oj&k(>PU&vPm&qio`(H~M@GjV;l){{ji4%;CS zwIlTTQ9*b0mP>SSv9$EoRda{PAkD5@UR@?)ZVnT}vyYJF`WFP6vW0?hY;aK;T5_Xi zIe%#$3_}^^aG75jx3{Z#9Gdj?JS8t#XvzP$XR!U>0=R7Ya3x@yLz%ti$j{t&0#5AI zViiFJpQsku)3c=pjZsUqBN?H?Ug(oR-?syl_T zKv%tl1vg|BN*DdH#-_L*B1ycf1XjRqUyAhWXUw&awWo+e=6@7^Vbc$Q4I*ng=An&KMq+^tv`->RbB?mlGs z4Xah?l?wr%BLR@dH?}cfXDI_dr_e-`bLn94o~pl>(U`KET$&?tJQ|->KA-RCVp8x- zoSs(1T<$0%D+;0KdBEfPPFSZ=E#TI1pt@tmbO5@nT%Sng{QbMT)(`Ah)gtRVb`Hk~ zcV}JQR4{4O1d}Zam9EakuCKR_eCVv#uSw}*#z*X|K9-QvGAi34w)mU{fXyHm_1kXb z-~|G93SXx{=HNJ3aYHF%c05;HTwLdcJx-Ez$9}?WW)mD{>uq08$9+{jL_|tfuLvDR zEEeLf7dWj90p@fkW7ii8<&x~lC%Y1}W}T#;`9=$??L4s0K2KNG?pG|Xi4$YPgR710 zE$|#p`3lbD+=9^ijAX;-p$yx8k@%^(kxYV5 z-=`!|qgIEL#O56%8qxF-qwVPV?+o5fN-~)ZUtB!!@KIHkWk=*He+_c)dcc@BqLrAC zC-4xUk}3a30$~qeQz5MnB_+%@l-n@IdeHTPJ*lLvdr4G8J-0ZVy0$>Z-+T{|<6y;- z9`JiQ=xzTqefIu%2rJ>Qvust4D4mO&yi`J;pq}FgR|{IK*}puhwr1az*S7zg+__zc z#NW2ITpSg$_y`l51jO}k2wGd+O>2+#IDp6TsQwpTWzOTjgpS_0G5Kq5%fpe5&3!kz z%Uq}HBxvgqA(zbks(8Y2c#4!DxrkJ8iAIVc+Id-wwx0f&D2wu zmpEN@POXQOAp0TY0gwab%s#M-21Y#TT+o;Il6A>uGwNlW_UT0pmxY^=Xq#-BIR{IAp+#}v-f3Qaeh4F&_#?i zYIJq}H_EBCAnWa?ZoV2W>E*=q^r86h+_yvdH@j!2B>RnDfd!-{CYCBb&wS*;D)QnK z@HuRZS^)q*XRwB3B^ygCBjMtqE+lS@e#2hdF?@xiu*VYJA ze#R;`4)hbci6r2@iVHHNsoO8)+7fe?g||_&mQl$&UNNB!<>R{=1GS4@^AeL8y+5r9 z5`bl{=fc){g8yiIqFfuJk2nwg{W;cDw7%+q<)6hPgkcadJfU2pog8b_1Bc$rSn#yW zP7e`+Iabp`rQYB&VbkM(%$RM666!~1po{5Tb(?|Gw~50=a>Q$sB^l=%9+YUnKSbf| z^tjcTpj#Rv>^RltrahM#C%L zFTtGg*D3Q{!?hgml0ln-p|m-ahZT4IfA;F2%0N@PsnRVbS zcF7eW8EQs>7L(V~pdtXkFrNtwg!(lGd01!Yace)bn^RdU{CmBBsNQn~uwTaeLQsyv z8D8gz?lz*St}~WP#gQOAvTy)-C$Li{mKj~BBTepPHEmE8PWJ!VP!{dJQ5pt-{fEuZ zRgm(M1%ElWmEd&ACuGk^a^r2jWd@osk@QnazxU(Ozo4fAZ^FI}tZ?`;Cc_r9er*kR z4A?;k7>o>p7Gj7+YDdZj z>^Oo#NC>NZWvi!AQTD&rB@8;VQe9?SHnQRSI1t3t8@r6n78G{}KSrnEDR?CYQ z%StC=w@1`J>?=A9S(X~Eax5Ec6^L6@72^d5_+^>_J?7(eBB4N!UdUtrxwp@1`mgF< zPEFAC7&h$gY;JZ`{QgK`dr!_KfP0UbFKhx=8OEFy@i-^nYXXb&+c9u%0ano%y^CS0 z>elz}PC8qv@`GxM`FhBVr`dA(g>Kk;!RsCUB`xuKEO@JBDPwD207aDS}aZla&BAd=9z zs$J6nR27GUE-~{cH@UOAFEU(YB+p{-9 zAXi0XMC8XSAKqLs#VP*E3h!+}q{FB*RsHa}JuW$~q_rLq2ya!(0r6IAUN8nEkuF(y z$Ru}OV2B4dLL!xPHG!8q@Wu)?9h}5| zz6+>N4~5z-(cjPj%K&!~bg6Q^w3dUO`%j$XDN^8FK!(lT7Ut3>Ss>X6(>ytYZc`!9R zyl|GeBOhFi`Nkr4uR%GJ6I2fSrec|UCezBSS0k4_Gx!&P$E?ab z4OH_UvAyVbFM=SIY((f4IQZ*}iO>E4uFawgz@&}>&7A1D170e-TaOoUvgFIUDDv+4 zr#C9-ZZTAWs%OtpmUIs!R#rBclGSg?^5e`74^hkeLs2}=%wBK7dK?yS zf_JN03(ueLK*{Tm-cJb0_`;F)W5SNqg$*G>SV-Ok0x6VG3{Y1W*l)ubcIUN8WF?%r zvdv-ya?)@G5h(#}WNZvaL49dP`6yiH;Cf74arsF{;ofI`G%yR*?aj@Y$=WR>7+20> z`00^Z7nP1$u*LMPP@*{XuXbrwgJMeAihHmj;^G$ewgUp7ize8)e$YtXg3CS#RH^+} zglHW4M~#;1Yx^tn3VrR%)vDmL5gj)-D~u=Ko+e-5vbl;Yl*FtIrN82@nUFX9Mq~Sb z+>=1tSA`8Ycndi2W^+HK8MmWfpTw$Phq@`)#T@WvyI`lg@@S?%?aj~B8Qve_)Lwkm zA(L=`G&Jw}FmOM*$IU8lblv3LXxSr0QlNI>NiJD;wCs)Y@Exj{%6v|r6xH@4l2fBw zR$2f^lOmF_y~63bn_I(zpmez{Zzd3320mXYE_$WM;s>iR$OQ$sZnyva3vGDP8Hae` z?kk9gASh*pB8(IGE1wz`u6LWbNsEDOVIt&R-a95dhFt8kmG&U6p4*W0jR={R)BXqeFmdHrnshR16QBQ<@ZV6n=d^YM!5L-l zjjv&qNmG-2Or03D@3&AGJa7grTI>vY^2_pkO^A!lG6t*H6BA4Wu`@iWvFl0z^h#8= zUxJle{mx|vG$S?=%9iD5qwNWM^#6_GabaAGi;W}m{Zk0Y8gFD#I6_F)Li*7W4B7^h zlYgc(7)W#_j5DSvt3J&Ldpte`qVR-rh>Mn`3g2V5qRdZ}f=L!Ry8iaV*P9MNgY+v) zbZm~^)P!jWgDZ7>GvD6?U8XR8=RtBkXnQfc<#G~&w58o4q5%TrytaH?>8A2 zHlx(FR$0c0kRWq^^$DmRLm-6}MPXHgus;DhU@Czr^nt8D05{J_vX{`D;NFzmpE1I`urX+V@mmcD6Z|`kTBwa zjtUvO9;08SK1KLVw?r|^vMxzUSFSEC&C;p`|Ls)4KPo*T{DcB!w5QWq;GMD0V{JiM zFmqAAVK*&1k#JpTdDJ`7DZ(E5#zsvwMSX9dYwPF0@!$oU+0}DXbBo*J54j_2+8?3; zU)^;PVubOe>Qo@Ho|zukiI>gHSW?gh|9`uUnE*w+}z$=zmfU(uh(`X zNh-%@H`s;cRf&L$ivmj?Q5(FB2H1_?Yr@rgDeL11Rqfu+aIt_IJ{wj zAl|*UWZ9ZQJsT_%HPsV_UX;y;%rc8nAJz1E3>`*kL1(ct2O< zK&e$}1^F$8=){+lRW@N=a9OlJKl9bIacsRT>Zxk2gpL3R{*~|wq;Mn02#YK8Xj1r( z=U(Fd-B?vN0OO~A8(D;vgYK%_}+VrL?pYrPueBM?8 zysQ=-sCQ++jvGeru=mcqu%hVx`VETztl)Fqt_^^A(M5dhtaq-K$9 z#E~;IJRJ#a#?e0iE5%A5?6+3xe0q3m7}aC(1l8-<4m64c34zNxa_riN?=y419(HlrhN=VtWas;V~2(3?@Vm?$Y!o@7c ze2_uvx3y_j1bY*u;o&{H_==`OTzCcCpR9rTS^U*|n(>vMl>+I-t?g}d4;4}tB*Hkk zpKU@!P+)N)2388O5?|(f4Yyiqbi0j_QQ^SZP5(*XA6I=*dgzv^crm*_^* z+}xbX--uK3O!w^an)Y9b+dNhqD~HGH13qaD9XqTjm@78nDISFOZYhAdETchN0@yVXn7Cs z^|D;lyt5tguZz^XN55@ksVVa8D=CT;x4$^VQF_=uOhSQP&42?G`3WYN7Y|$>tT-1} zlPe;mK80|pAs7ZYVS$?IN*{^5y>@u*W3t8fl3j|>Wny@-4gfv#sCe`@Yiq;}+iUg! zIwmIMm^J63v_53EqqE4r=Jh)^Sxg$qCqn+r7_n{Nh3^?P0TqjD$-pW03lV8&sK*~U z99H))jBN)*lq6Q>W4>J2?;m{2RWR=`m;gX&OxxcepE7$0RZgqV|ux~_aG77LtU1` zBmA`LT?ly#-;8os7u+6TpHbXDFdd5;MULK<`*}wzmh9sQd_GU~pMIuwj3AE9k^mNd z!)X84`NkxK`73h5H#tW!(_FRfFePG|#MD6XWr*Va{AglI zhdCb{4PfSURUS6 z)_-~s!Q#*LG;+L`wjTx=?i9-EsW#}lybeI#3}^L?o-rtpnp9+Si5sJ%D?69+U(ZJu%s4Iu1-6u(G1H0lr1=EkO` zTTusDL*&4y>adV;~4d6bQ?7Bm)QP)nfdOISp|Mjg(qqzXWM}o? zh;30sXa}hxC?75XiDaT$(e+ODMu^|!B$%vZgvEm9L<-?2i-!ewS-ttAtDo0(mL zR41^GeU)i>!K72r$Q&n}akkI1L z4-;XPB4gsEhC`yd~@a8v;uO0+sAKwEkt@uvYW&dB_+SjvJEgGU#$H08LXmq6bAxCS@iI{d= zx+!dCNci@#_hcEBQ_is^m=_)f4XZ}F*$ir4uj_3Wp8U+()-W!N6vdJ#^9$^oztE$H zTDm%5H3zK#>S4{4wE7i?FdMGUFVng74pl>&ax2Ct$<+Wwy0@Fh)4Ppdw9 zLC%Q@TuxcVWiP+EblUwv;`ex@bY@KWi>? zmj{i|?@o*x6O_V5F3#mI9%GdSHEe)2&1(u){C%>JpUKcLmRil?FRjo?Ov1T zYupqb+wbLX4(y|xHvY>NQp+_USC-Ao?3@x@LH;|d44^Dt_b!m1&??Y4kwr5#tQ_b) zq`daXb_w|3s~*e~W;0MxQ88s^V%U>AU>A)S-K6IRO%#Lu@bEx8Z~Bq{*v4Emex|Lm zR*9bP#4e-S3OAbX)Jb;3(l$CtHlMj^D82nT`{&YIypTrJcin314Z$EGuBmX0bHyzW!yz;!hrIS}~XvYZuyOBUP0x3FU5$7?h+y2DccYr_Kq z5k;Oz7KbBz{OU~!xUKcKCqY>SkEZtAPq7md2?;FWqU7d?!3ru$;YI-2H8PzA7l8RM zXb3YTP2`<^$fwHZU(Pou+66MMV|tKm#t9<9S>PRAZX5Xu+qfiAsrt@&iOo5;pSiSk zjKomE`vdaFZc|N%r<{9>YL)Wu+~k+$+VtBqKE<} z0jb76ieGvdg;2dexPRaDzFpuS)qv|cVRouKkZ`_sS#|%{j~emmFflPc-@}He*}{%^Rra%@5Iy`ma(Ta9_Z;% z+mnnv*^x`#Sk*57$UmEa!O7Inf^=W``e69_{;e7ws{=WAtj=~p z?nP2Sn81n{V3Fy_$dDZ})AFNTDKBXG0KC#EEt7h~j$AIO!#pe3{E=z7MG2IxeNNyDuqn&|*KGI{S3h#7`j}b=3AP^BU5iB9HMOhiK`6 z;^@nxw;l$MF4yKSk?A>&7*&8@*YW7?A|LHin2r~6RBdKzY6w2i&wgID{pz{#V&IBp z?;^kLrSuUu9K##bxZE5$EyQL_I@_2%^>fga(*_uq>Xwf0uGi44#IDPvuBs}S-J{@7 z0#JX1A($NLC8$$OY7xH}N`B5kadR&(HG){Lk8aAnZ?{_MRNg2u<-W3@6H-#r+xWPi zIas~!TUMUuVgiKzG)GHIULW`60^l1dPwVK%-0Dur?mC%KF3?a4ELxp?DQEa!wsGQx zw!(@6PqJvyvt0m3C$|#(echWlv!X(}E0v8GD>dGDch406@-(fg+PaTNcEi(lDIx>@ z(Yq#LR|RCi^1aUhf(WZ&#j&shgPMb*IK+^*5{4`g>dSY7QO-ff86frB+|kVJ+M#Qs zODdq}y|5op%TLj@>%9RUe0=g1IdySiO8H`ftYe{e61*Hf9a!XmQ4{2Hr=^gr2qQm- zTRa%MV1Fic!)5D(#?39S730wkZ?ja+;3k_{4BV>%KA^n<_)X)iTQ0NJPX@4=tH@9qK#x`zeBlvgcwt#m8&i6mT+KjOZFU!S{squXxl?SgU>Xa za$y&X^Ql&bnEE^hF9#lJy*7U z_2>oKwMg}B1^Tox@(qyH6BpxyeFt#AtG+vC5F<6qHC{>>RFQ_8HXnT8iruLD!S?X1 zd`Md*l63;r?^uB+QezzD)1e>TF$yHbL`TcSo0)srX#R5rRTT}1#DteN-fT_RSJ90w zH)|1Mv`Q3~wq#f%p=@6+#Xu|oxbTNE&9y19xghHxMfN$3=z3Yx%AEjJiuOpY&10hJ zYUAU73`Y1h5mW8D8Mis}%L8Ei%cq1w{gU67Wu;s6_Ra;8)CdL8W~L@N^>ab?Uw z6L0K{?R#om5|N@ta?Ydg&|z!KSP4UNNv$=-pC5AIYR-X;7SPHZKR7$Q8XE%y8kU#U-|S?Ba;Wl^Xc5la+N9JS!B&zWwBX#p@dlOpA(Z zJyyI1=?NR0v{d^Ye^7jH0oY-u$Nq0$LlIa0@Q^fW@%^?W|6Gw{MS?CsGoy3nkZ z>KYS+Hti(fzLZ%oyAOWWaXYS{A>`*V{SwS>b6q@`D`!)>1)*Uu5s8nE{Q}YG6OZg} zWYD%-hzpp~&;{eKYw5(Mz0M~lsNI4Diw8WsYK?F|r(STh zZC$z9zk)f9E~h2>0LBrSLAC2OCF|F*lqj?Y$lJz_@8DXp?njYAXbw*zGE_t|^hAxn zPhJ=G_y5wW?Q0pvtx$oOw5mvG4s=@S>dw+ni0tgE13|4YFUX*EZZ?WI=X$m{yzU{c$j5>*J1Dj zf2~~0u$u40F59L~ij`Ul{dM)p99W*EKavVvoHzWU2ab3%yl%4hk*QNi(f(`XAblYF z7o)n45-yVkSslAEFbyI5rzik(Pqwi961QIjFB~#9;D|3Dn zCNBIl6Fk0~hn{zObjX>QIDcpE*yv9}f%B&_h~)FPfAiF453i9+2Y6UpVhwLF%IHde zscVE<5;qf*55}0yS}=?7#wQ49vp8S~cevPy{?Sc|!j^OaWu$YL$-{D43?H^2=@NR(%YooycuL=KAKihpPXcflDlYLyJ%$VTm17btEhku zVdn|IR4_mO^}uWmR6;!Ev1X!9!NK?7{5+|$5@hi@Jy;Ya-MVyIE+c8H+$i>2m!ANM zaW$`@GJl2*HV@S(gR%U@0+CVB5?8?cfDi*4vo|F@ iHkZhin0w%sTlWH>KCxUGY zu1Vftc21-CtfbgFmwJ}6kNsDA3wd%p^A}?O$o;QJ>xMT6Sz9L}NaE>}aE^Mvsf^05 zbc^;ur5?+bBYyrQPJo)7Y{|WU)`89V@Y2DI!?FpoF;zCL z?zz#NV=Yj#dOo7lxf|2Zwix5!3P}bT)=ujLej}Y*WZ#sKAo{6a(N6Rz3xii$=FX{x zU@!WvrLdg

mlc{#*(F6fHn88z-iuLjy91MskL*8liL9uqeF@8W;`QtxZj%QyZM@ zjttY$EG-xPCwmZy17LO)+tqNMf*?cu()ptnJh>ZZt4L=M*$rfLc$9E7Cc)Mz4}~1erSDOBe{q znnsu^`T% zlU9S2UC-Ql_g6+VG%h5{0aVOKFKo-$>S}5mx!~nb{?+{dS_^_Sz&xVgu{oY19re__ zSYhcC(Jm^gi(4tmzQJ{s&1tSaB0ZbgM~@jO@e>Apn;?tcQNVxTfY_&hAV+G7Vl^sd zKSwt#cO{MIq`{n2?T-t71Ru-uBYrB*#e+hPJ^7a=EAhQTIV%b!JfuuI!GtBW9lIYn zlW~#aXqVSkjmt|v2n3CSJ;p_sA=A8182?^egJl26zHmAGM5>j*qY7319Mb1kMy%* z-ky&2fhlAGVjyBUxdFHQ?S9Z~wLcl=P_&6-ioj(4Ic0IQp*DW492JLN32koD2I}iX z$xt%@i;c=KZF3G~`F&=$O?u{wiX4LshKE~xQDwjGT=Hjz`b^aH4P_>>wl?l~p3FN8 z%SiE44Zz(4gNZVH`);`I`{TJh7|WgQ_b9U0)K)qQI!1XJJKt9cyt#5kMEn>Zk=Meg zDZx7vm7|>ZPYseeYmNV)4Jgu|?>bUpN+*>&?P2gkFpY^y?_i&OMCaURC#dP! z{EQSUZIOIKDi%O4ko^9~2w66*1Fq}8X-PS*D%Q*X0K+k$;&>s?!Es#3%GTPXLL8YT zKyvnD2X(A{?9%6gkbj^n4((O0)iq9~fg~=U%FtJ6vrZaB-ao4T*+T)x0OQV#$P50K z4*{5l3KoFlVJbC|D%=hP&DxdBOasN4!!9iygyNj?-ba*QS-F#-PILAgxqt^78G;tp2@8x(;~k|dm% zlIDfWoBq~+JY(BT%CzGcKC&pf*F6PH(w#7}P*%9FhYbR)s36 z8C1$YP`s}eC#S0~(Th*!q$s~-msS4?jf_t6Hh;a|A)eLuy&)JF98eN97EK0=M!$D_ zUt3-rsKm+&Ioxil39u3#>Yq$N(QA}h(GQi7A6Q-+p*pLO3YkXQVVjq|J4Dy2JBu=_kyL?5;4iwRJ_N2Yio8 zOfR(U1j=p{OkTNuvoASvsm$0!?%avhC}cLoFy6|zNo$!uk67>p;{`B*%Y_WvEzc^J zl(_jQj7`P5U*9wN2-}CLmC`dPDm{)@{JhsBSvPI_1zu6rcuNmjkg21Zj+lH2_qm4{ zVKDkIZEd?2Z#&XO*HkwHuE)q-V-e6$0_Zs5P)^Y4V2*XGg(VdgmYJjhAx#kOGJ(tE z#1+$UF?I5xhi7dKFFjuKovxt+3xn?=<>Nf-MuAXWm?3&AqyP0f2PS2Kk_!LjnEuQ~ z7<$S5{pG|ftIeAY2TIK8#`htFPx)cPvoPFvW7{Pam^?YWecv*vx&K}Zo^8zf-ouX) zvIaW@BeFhD{oAW5MbE-Cq|+73CUVcakJc)r>?rB{9uF+cL`cDoxqh8B(TxQK)7=Jr zC6nFI^|vpW*U9<%~EOM3@S=W@*?7&i>Td{YS|TU5jmarG{2cU zcgQK^_K}9HRl|1eDrAlY3f6&2Kq_wAINsK;H%@;yg{A8$X)VPlCPF*`2lv-(1ur;&sq--qCH5BLKbS4{wG(<*c)S-yK>wtbD!5}PBK%-*Ee&Zp^*Q# z1AvDb1_8l+bRkGYrF_Ugz_!+8pn%P1-^pV;c~aUa2H(PSmU%iaKrJLd`qTFr*YlS2 zPsqfSo?78?VH3~iP78BNA|fzrPv9cZ){H0Ut}v?Jb4Q$~-IG%_SslLT4}};j3rbJH z5Bu>!LvvBCg1bWzcjdpJ;iVcD{C|;g!6{F~{mnBFi>KNfKQW6(OfD*_7LM$gbgE-h z$(Uc#rvyt-Jz!kLxUaT8<1f|seCn(h%nw08z0!u??J@H*i8Yf=zpVkIXewCDc&{**s+;Qij6xIAg(zzEE+ z$%WzlQ0$2v8?F76R^vwM{W2s-nmkKeMCaM3a>8Yc}JgWrSNMca&ye*F$yn^jc?0v^WPv%cXub431#1!@Wt4LY| zuwFl=I@c?zmDgzrC*w&yhlPdV^-wuPL~&xQ3nSygT(=||CZIp%k(6GSKJ13~E9)GBPZOt4!6xV?w7rjx^4?!rB<~&AQI5(S0|XY5LK1J@NFntZ97O zdL&_eS#z^7jk8BK6cn- zW;6Ldo+H2l5UIsvB`U9_vqL@9ZTpFjk7t=31oyxMa zl3~rn_L{@Bl{wKK0|S;P#k2S<)Xr{=T{FMxVfgSy(doCBeQuAv7$yak`k-&U-!Kz- z99;~DzFk_mNRJpo`AvQZCWEqoz#}5kF{d|J*;AjEnGVR4q5xx_pBioQ#FOI(Cx0_I z-ak;S7wnH_wXGrnjM3bseo>DrCZ=h9-z)8KOwLDUQk;Zq*TYTk`Y|F}E27%4(6DhK z_Gj+>$)vi%;P7ao0{(Q)=_F+J5qkaK3d5<9|2{b9avcc0y+VzXlBMN$RBXASpN|;M|sJ z|HA6;pP}^Mo6S264yBIHI&E(XuXMUF5dLY9$|uLbFD!#9f#sku!JA!j@7rc&KlmW{a{40r|-F*JO_v$b zT?3)XWP6g!jIq|B1b^SdcVH?MK0G`;S6%O({AKT;i^K*S6A!C!M&r##O6sudV+T6) z1s*d7v_j;prY{Q%ORn%AzM@hcPG)AwwD2}yMXW`1=GSu|3^JSE%2poA_$`gi;0S#F z^bkSc?TtX)yR%(0* zGkF{W4Yzk_wXA@3b|-ePqJ*L&DKae7^6^8J@G zUl)bOd?F9Rgp(87x~lr3Mv;Xu5S;!q?*po-Jg;%`t|^z_xZ~o+%?%78MWKp4fP2Tk@QYb%+8KE6%1XAjKt|`3j)PjI?1o@RB7+srKhLw z@+h6OSku$E(>zdDy>;BM-EUaQ$rsiOcms|}hip#IZh6QbfpZ(9tf?kV%;1}TYrQP1 zYGR_I9AzQ1xoRiW&~iiIeagF0s0ooG6#(9wzda% znNK z^{?miMmo5)$2{rDX`|Kg)+l3MIy&<6=-6b@Wu2*~6+9XcSZ~I_J6LbO_vs6to22|y zBL6Pb1^SYT)%g^pChH|(a6Cndk~&)E>brB6Bln4-`AYGUhR03;_Tk;e6LENSL}K6N z9HQajy%518VnIQH-4*l&%i0|X1&X+POOhZT1YHG6Taejn^{O+AFV4-6eEf_jiTrjE zlE7|!CtW@-eW0bSN>sbi`%RJ)foVNO zgY?(-_BO46fPClRz<{ExY~)hzh+$Y^Xy^sPJ_&vhpD-^8Np)Sa97$K22^N{}nr66R z5)~1?ORdwc*Y0f>Omh~d-|2I@{&Ud*`;x{k{8wL_KTT?vt~S=zvAEG-@S|ls8iDHCv8vTeyur#QQ{6`n(p_V2F+efvk62cRZctwWi_K zu2}!oQTlmL9|o6nTQe}Qr=3U-57kU&SMZ^`4=BA#zOA-@v6)c(vS^|lOCDx2K$-9> zK~n$e>}@I~m9^w-9JjQp>^P(N_JhCQnKCF(1VfRpX5it0E_ob6(U& z-SptTzLystCFL^Lv5HMt*vG$~)6+Gz7~B*TD(x;dfNzz95qc~q$C3c*0-+7;R>4MF zSLwGmdYvQIo_(tH%=SzBtHo!9DG)PD37L3T&t zosKYU#aNDGec>{L52g-0 zkK!Xa)naXtn9Io2J>xdHY~3-czkIM((%&H2s1x6MV|Z&K$bTtcVC2tP*>?EVWpw#8 zJVN*_&rG)H`b5A%GBbPwp^6X{Q|v%%6%v9hAf&xeRaAei!n0Ud`A*sksjCcRB1q3z zC~gw*4vI1TxA%t+-7;JRjRo6-&&oAor8-THwRqEq+e$2{Fp#7W;@k(#7L}|GLQu`4;_NxZW9W!;L+SIbZRmqRG{SH z#uO8)KU;Gx2Ff1p=7=%THuOx)&(;TYOe+)WtQqmCL4U$2IU4I8{k5C=`}@NeX*7d+ zs$yfYN`jmSdlGRx3f|rMEYN+`_g~F)y3f*D|xi$u2iF$kBVy7kxgVYJ?{)pEI4I-P+n( zWofMUP&!{<@7Un{vf|>l-L(E_XU(vEQf{f#TenvW&E+xU-_vJ}Idi-Is#)mX*eY#G ztf%wX!|8tZ(M(1v$M^Y>`>_w%M^vg&^b`E&zaL>_De1K|e)z<>FNtq$#U`Ysk-bvU zCL|bBhxEyZg{d_Ogv2V*NQjI3jIy2lxrY97D)7|Sx->q=3hgvq6=AH)=9w$nQdt>2 zk-<&=^tJW}69Yqz(NOg4be;rcDo4cW{=(qF)t;06r9>qysz0)mo6Sq9sAl-(CvY^J zkliE)R+`rH0f1vx%D#)IsF-h7@=qD_wSi5BR7!#|rZ$>pg;c_?Uo?$wAc|V7?ez&O z`dX|Rn36}!13Rv(54uj=s@P|(hY@UZMFG((K`xA0*pVk}UwH$PpSyN;EUfo@UCY>z zZsg9%!J1L!%JNy2`J(-~HX+UpySxaM1)} znlNyBmw3-lVdz<9na2y$cV}(MvT0$46(8poN-O6yXL+zNF%vs-jDeIxb=tcs5bB;t zcx2@EJZuqeKu)dYshji7qbOcpT}}C1z#9@RL7dZ0tDJAX%VO>`kBHCFcXD`GzmwN+ zomn75%S98~OOK`b?sO7jGUgGO_6E+-O-W0mYE%J{IKYx2m~tSuOacU-=ihb6n)rG; zGdpF4ITHiRzn(c?DYQu-{_qQ65%RoKD;70X=kS}qe?Ng3u`meoVc>41;$vCi?|a0U zT5jQLCJH9K9UMV(Kmua==#+;H-~vL}G`#L{3yk{maE;b^=x_*;0vTn6vwXXe9p6`) z`-iB9p}wg9a8ff`qG4nd#tY#ihz=+&Pi?+)G&gUimvceRe{~fLUWc2$(vd!fX#RTE zoAhz=?a%@DEl1tlFqVk8J>?(eiXSD`)+)7NVq=eCF`1Z{C|R~1(|aG=P%K&1vgB=4 z8?ces?(TtDFFC^_qtGXb=W7L@!ulj=*clSF(zO#>khdK#+GjQ2G#EG^VWDt>&jl-> zr=lb!B}3sqe!RDr2ytm@CZ$s+G$l1&39_T7qfLBwu)=$ueca{5JD_>Gi`F|#`5rA| z^nHj(TIBbLELE)JUzAiU2T^sC-bmk|Mbe741cV1bLK=|Llo6jdHV9)L`O&0MKZUn@~6&>>2$3vB$`-|>WwPn zV~&~dJEOo70oJXby)5@w~T*yTZ@ZR!XsL;oMH(yhK6#Fzv{+2!vXwuYW6bD8{koB zF&b3JzSs7`dVczZh>!==4Uc>1W01ACc1D4S5p18r;Ao}G*i8CY*bqzx;CYt*LGy0Y zmszAuZJEYxP+eUVdte^$U|I72}@YK(rG{o`?guxU27Oz&aKI!L@R{RNIL1(sR8^RhI#Tm?p z5nUnUuP>nceknGCGo9K{IBsAAJ7JL3AKV6>>aB0^zAUu_tY~RR3FJ?Tox_2(LcyX) zX}4#Ya?8t_N(lCX=0x<1%Wz+Xk6`b5C{Is24TxCIV97SR)}lDp_?$RVUl564`%=Rb z6ERiv^yD!yFZ6oe!GXGj#8X%3QJ0dAj;r``E|g!{;N9gRn#rnxiIbO%L@m&d-V(2!KN#9UlOvbB>{rbHqM4_U`90 z_4sZG=Sj>$K(pp2zXc9~VV$npmK0P0D`it=nq``FYrbgVpEu#gA{~d9vG}(OT?y@Mb#;ve2H|OOuB{LGF1X^GVK_<@4 zlZcUqQF_a~7fHuk>3M*q_OFUmSi7+qj=Ob=nnzrYtJ-ve0r>N6pIXZEU>d z-rL)QOJMi}1l-zHDWYE%fWH2NUoT+fSiMK$pQgujyfllQKbie*U!?E1&p-it_b;d? zu@n^)L{!uflOcTf-uxu7(J_zmEPYRZr7-!$347vfIhddFSOHT93rsovba!54z1ldt zP@xUX{S}qZpBu~d-uwB5g(odII!h?Nszc3`5y0av1OO|`f`1rXV}s>C3~oB)N*`ja z+Z>sy-*V@;(aBv)7BH65LQRI8y5si+|9eKwaC+4TWu^(Cxw;F1vJJ=Y!Jd-Xn3?&? zMd3ik{S8Dbjx0ZDQcX_Yb~9+xJ(;b63dZi=!7}^xyO4I&D)1o0WR0ByrwMX9o*P ziAV3CYXNT>TWJBg)?2Oc$Z(cNJ1D4Z2?bR>`Qr`ukj6?)xuR%-ua9=G2n!tL0bBtw zp2WAui&?dGbp^0T1T4u^+O1AVdWUSN-f{>n$0V*hQyux75Ht}!A?l*<>ci_1I%wd1 zGIe7N42)qG(8A~|OMc~xSCwepeOnm|`d;fPVO5LHjjLfPD^%x)&tgzse#OPbkm2Er zA$B>>8Hmv(8+~XC914OV*09d#cxov9Z&p^;f)X5D5cdXMv6#HkEhQsSQt#mA0Oc4q ze%LFD<{9J3Qcb4I*&nIZMr-s{HBBNy3=A4i%!hDzVgV4ET&iFe?0yv6qXMwd)!xCu z;Jb|?IpXq!J#c4WWzZ7McHG|U3!Fli&r3e+==gYUWTaDgXy`9NGyeQqE0qsf^KjE0 z#_GsG<)y>V_a{hBY(mS}=(79g%iqM1Zo72sSN^*M`tO(Ayz2`K0!{7h?cER~Xeci{ zojquvhpT4kCggWGMiqaEElwO^p<#DKMHWF4-FdLPE4h`Cxh9bo{_OqQvJNw>ZpVp| z8c~PLG$AUwpJUs-_YE)u#ahK$X>~4R!#O_I#Xv)&e+tEW`PC;k1|XyiMZ!ax6~ROZ zK#Iy4eYPrLncO9&E$|G0#Ef~!61`gf&hp;g{9+OkA75tG(?ZLCYGB(du=DZhn9|i^ z{cXr`IHSt_Z4c>u5V_6h+D6OTdU?8&5{cFBfb5eYdrz}sx0SWf2-qNso+4vmC-d80 z*FZtN;sd?3dn{;|Ht(M!qQ3KN;PXUii8*|H{DQ5)7hkS-lj3AuwSJ8n^w?BfPvLez z!Y4q)k;%w`xc2lW#PtOC$I||4UyjA-awkSGNzwOpwFaJZMPrr?GOB6;Tn&ldJ z9c&)Y@t*F6`_M$>OPL(7>o5Nh0)x_o`{xa@whVdUlJ?oN73AsjZtscOQs zn2`alrPFuQzhQmFUC!4T5C=Ov?`D3&z~%Q8BfD&MMyyF?H8#0E>&{Ad>FCHti4!A0 zZFcn#LO~Jt14xPq%ZK(a@`@qTLkSH;qG??<9o^f|0QC{2d=i|Suo92}uU`;LI zveG_2?>Ydtuex|oG-3VG&7%42Xregr=z~>1)^aN)4^Ntd!(%2VD%=cEj*M8yJ2E7h0h_5gkOqH`^hYGkQ^GYZ;%VTRsI;P@?Sk{l?DUI3$On0|mnC3EUEP1{srv`0 z#TeEtD}hZ`cU0Ckhvd!Cnr;fl&659=cW|SS-t)Kc)m585w%LJ zy=Vshlp`EvKS)K|rxh9{x*#GFf`nf65>&}NmgRU@$FyZbrdJ;1B!YS` z(B^WQP@$)>c(Pm zo~;jw6iAZE`A)sNzs09v((U(?&9EOHL;>JbV^*SAtm5A(oUq3pNfIN*T%P5uepdG7 zK6Tu!=K+nck-I;u;Zb6FFQ19LxPFm zfJi#8PJ*m9#gvcYaL|9Z5}~D~J^UtBH?0LRyhgL+uP((89TvT|O-sWA%@)elyI)~} zSpKSMX?=Msd~UkA@oVAwzS3Ve8515Z-AUPbT_FJcm+#O3L{Ci3o%$`zYB6H=Moj;( zte6w)fZ^*`MS0K!m{2xPJPM19M;$&1s^Nsebn>8cNpL<0qgv(hu>Zw9&2$UVnJnKm<#ld@cn@I4WkEtoUI& zgHegNVdb1$PiyL5Se-P`Yhs66B_{=p<@&vY49Hvr-R#Cj6X}fi$&w|m=FB;|xGksE z;aWl`AxGQWGhRNH{|>zX@xMb4N7c60sxUj=#hO}=lZW)}JZAVC8#lKY2Rbo_6;~9o zVD!u^E2A+T-kORIC+hNNhe;vh9s>VzrVDNY0k6>4mvi4S2XhWA?!5kDbDyj~I%$qq z3~a_3`cC}7!Kkup&-PR6y5y;F-$Cq)7A6K*7ueJdQ=zEtZ0>(Cj#fD(p37BV$ITgb zqdn8H@4COFOP9w8P5L!})7UM|j~6POE;gY{i@6x0jP<{Ou(M`!_xBgtM@L5;@?t(f zT@3sSsG&o8Poyj+`9+C9l<}|II2wVqHMt-vX*s#X7r()1LW!NBm}J9N$Kg+RFf5F} z+-fuy#&|SnE`bJSH;lHP_e~-KKi2B%BrXe>1!!khcI?QfR^Slf9?p9MXM)MTvLebQ z-14s_Q@c=6G#+b@CX;n^%=c^o@5d>~bROvOI(4p0ohh|sIGstA{di?q7yEJ$_dff; zVOx|G2v^?g%OU*T1W0c$dhKEhlkDjk#Qd$gL-)xP$PUCskl#buY5sRE?77h3Qf2o& zd3fpz#1nXr#dDXf;&8)ST13pZ`$uXsLOiJ)G>fMT_*ymgYY|;*#uZ(Gu4HAc;{l_` zUScv3xiKc1Y<;m_wg_j>$5llj1ngLH^1iaC4?r+w*On(NbUiV4JZwkInRW}Idy_eO zd9@R;4MET2vYIMNONX^&zwhgG(^zIQ5+(@BVliG4G7q$+on z34Dyf5q7lc)NJ4Hp@)OI#z8t=7dwA_A+CSPYtY-^OlEK_OD?8CxH-#YfGKs{AEyi< z*F2)oXGO#H?hYr?5i7&C<1W{23tsnmvRO*=5bpikqnCkqM=j-Ays zGyk3)QN&g^wN*=)a0X?2xnlVii~s#QAJ855GT#Z|w3R!H{4Sk)eAXT>zW==dDbs<6?JuLUmF#D%(e0@8d?NXUHZIO zE@uDWV2{VEtT$QPj+??0{{ENx_T{2k0u-Suo3^MzMhzjw8hgD}CMYPkh_oXMt(Qm3 z>#T#UyCWY>Te}~LFk-;tT)5R2GW-0nh@d-n=hbDt!GflVeOXhx)#ri8pZD`iGk8o8 z7`VGE^SD%#7mOFNU2g>e&V#1qaSQhMECKI7$H3LsOkti6?PvR=aoK7%komcz5H#D- zV$_i_WLPsS6%>KRP*~-)=Q9c>q!(jHfFFcsonDii8(=~HDGXv|Mq_lI7#qt|%o87V zlnB)>Eh&i=yb7K=LpeC<$>bfOy-g(}a_bvt0Wz$$pj@LNNW(!^5hhdD$^1*>?{lzu z-7;vqU&`rjP3ABd-zLtauRJKy<|a4KjKV=7@do9XY{8De?t>6!SXdaF$tVUMH%p8g zgzQevl_zsX!wm5ZeDLmf6U{B>SCZ~IPCmZe?mdS@ck6?H4SL*gD4fVOn;3M?5iu&| zp~(T-!v=9qtA$M7MXL9p(2PIheQHVcn)KAz4jUNcAr;N|VLD2qlwkTK#ZDNnSl4y; zkl((qG4*DXHbAzJxZl(i$Z@CI{_qp7paux8*U#VW1)43B+0a{TbDO5d(lyOK6J^LO z>NqmUWeMy&iypvM5e{MJAZr=yC@hM>SH83D!8;hh)=%ijmyr(NbMM1jdvZz$RlbTH z92~rx_|30v*$n_ke}n*L0jy!uYF!YASHJPhYsj#tat3eS2_b;l%A-@s=7$hH@_;kA zZjoP_U^gQG3Uuyr-NEoA(5nn%!@T*7_FY+1OFATQ z*frlwA!9!L3W(igwuQynebq9EZPz_FCL!-p!;3=_H@=h3db{xm;&nsS+V=oh#JE(A z3Rv0rSNcEH9F7IB$7CEBsKy`G`z_>hftyU}?G>iDvNHJS7{A~0G-$HuZi;@p1!VB! zhmeGZ#=TpNc?0R5K`3I1HLy-mhl9rS-JePviy-_!M>yGv5y zM~8HGHzFV*C6dyObhmU!cXxw?NO!kL3DR9k`E360`_)Um&)Ius&CHrLb4G*ya2@YY zI%F*^uCZn?eov8(h%uF-=Y$0v=$^&e>;&hp#7oZHX%C4dF10;{DCBl2QswUf4_D;5TpHgkU=m zuMo-e;r?BUi${JSCtSMRolO1h?dWLHO#N$r(pf_a?s%m(fQ6$Ds2&QSs~K464~&Is zPO+cZKRM6aQ@;|oS5`O^is)!)1fC3GT^GqU)8y`}jGtw?3>)Y;^!!>=d3)vIyHK$l z^gBEvj+lLZj}iYB{;s9E#!cAo>lrl*Ci-w4S4|wVDl2Gy-SNOoNj;xG5pbc1$k^CN z-1vR>EyIlM&prp{@AdVC$G>o7bv<3nnQ>>+M9QXIwxXGz!1V>Dk>-%HZo}*`fmX0= zEQQLaPc(af&ha?k)wQ)P|Dw*Ayf%1gpSv;jT5))+kWyI!3Hrkhy4qJ;YQB;gqB{G@ zkB+%vO-S0}V_9ApB2RyQC1nUnzqPET;MDWmGcR7FiFd`pG6;xA_D*}>8#>+b>C*I@ zI>qb30bzO$k4c?5>VDr87Z;uK{=Yr6Q(d!v@ck*C!(K-)==Wn`VtOmDUio%=PS4&S zl6`ejO)s@v;lF(APbrz;8*DmQj#CL0+!Q`nj(WAmd2>?2 zow}#%Hu!fw0Tn!dUM$ORM`ASAKu@qaN_h0D`BYM&ueG$2SwieAEw4WM$vC2pW52VM zm*2~~ySr;JTxNtI;MQxD930am_@;O9#c3zE15bipR3Gb!HUkl&_TuFzidI*@&rBJs z+nVBMBXFj2J6my~nZJn073$qx2}Z#y)vi%yVqw9TY=0wrBJ$WK*6b&yrLFn;6b6og zTX?nL9v1Q+4Z#Tgg@ifPpeCwN6hhxoFBZx(m?_}AzhCl=s(P@~bh#n^*l>089)^@N zQ((Ag@yjauLvrDNqy}EcuK0L(Du7U`$Jd*Dx*ljiYDvJIBu?Ah#0fKC(=Pgx;yAtu zvq$wD@p^Tt7NT==yVoN!R!!`WCOfsMQ!iQc8s$~(?JP_A=H$11E*im3aUD$Dti%b>)-%O^+D24_bgdZ`fAGL)A)7B{`kqxIN8HarL^c<=qnr?5@voHmqo(pPQFHo zs=|2hFy@*(HFk+s^OEd4md(L~1AUpTcU?FFP;SPvnrZUAqC(vbU)q=-PQGJ(iF>i! zW?S^1yF;R>U$L(eF$UYTsn`2Hv*TTVF!?^V9In4MrtSfGQa>!`2so$MNI5+G;aIi~ zKi_WCWefIR9@cmk|OPwi2>hy1fqTHY@KDc;qQhT0qF4+V0!Mb=on2gS;1ij%^Z zBE)m=HG20h%*m@4j5~|5^uIR|yYK;R6{DD6Iro)XHcL{`)LQA3ao=B{!t?uRYU-gY zExnDCI%?*P594o#3d~(; zW|J`Yb`R}cU0{V3Ln4k;E$0fm}pHkzFzMeTu{^~#Jk zzk~QndPzEm+lmu|gggA^bo~`GcgOVX;{*0sSwKKQ9t=zXpj}8aM7~Nvo|Mm`DJ5g4 zdfS-dw5ORgyHkNHd3SGbSXKj1te;+|J{O_KXzCslvfC&vd0*Jkx@k8rO$Va)>)|~| zQ&=hN^(WFIwRU!rQj%5TxAF*!siE$~=McNHV?S-=djop4Vb^gK`aApMqbVX(Dni*C zM`&kFDbGm`oAW>u(Um5@v&UM8bzz;oL zUEx=bZw^)Hj%&7V2=<4`>l`rA(er&D@20sXg{3|WM0dWHG{*OC`8@_z%rnuPeOW{A zChFCV02t0lQ$PfWmIxA3(h0wHF4*?`(r_jvCAK6nP=3)HbDanIeN+7BeIfuJ*A>fx z4)syy_UEnE*UO3azQtW{lQUF6P=rm_eUG>>hvd@_K?PqhHfGrPVbCqli==br(11mZ zpPk)7TwHwP$B-txIC*OWOs@a7DeC9&+I3178Vfo^&yQlB_@dO*3B_tTs_Ty@x2C#v z;JTce37`!iX)nS0OeUZhen?*LDDaCzqptN!IWtq~Z8r=)#s2gX5|oS7BZ8;halzSn z(Z2r4MXWMR=N~#Tr!;1MhM;SCCetAr?S$vs{BiTdv_!xhBxd~#t8V*EvV2n{X^GpZ z*qe9;zt`REKfNXJu?@^)9l0<-p2*RQ8xm};YC|6FH44scS=;VF8*9u@DkidBAv(8vd11HK~oLJ z1{WvmPqH;nPq!35v8?kK7I>}Xg)fE*I2kmIj83P03{M;#n-Hz3O=)4z{yLyR6#oq` zVz34L=0z}M2v8-8WljtVl3K&P=X%PA@y`irCxZ;SKQ@JCjIWH-CH_^*UR5h`uq%Rr zLqy$tKpqEat;6OKg^%1@AWW^*$v3@ml#yxUCKO4t($Le9w_swj+!#{F4~Kz){Mnp9 znZ0&5s&sWK_Vm5re!7CE)tx()e{f(x-Pt)&?)Fog!z*jM#~xV}Xyf9Ah-V|&Qa#9n zb&EEwG}QG)k6>Va6X)fQUx-`Q%F_lW?zc5z5J=X?Mt(2 z*#TDehVqgu7MyyjQG%DvP1)GM9+}Kxnq)nB)fexj7IBz>M{ioZ1tu@#d)v@KC$d#< zE9=|iJabwE(3f2u+xunwp2(svp}jOlqS^egrTm=lF5iXl-~xk@S5PPtaQ;05ga~D| zwNnK?1e4lHw2_x?svrYZOdZuajwgqig7eQq?w-fN#(DouLHNvzn=9VFI5XGfa=c{< z4og_PuI_wpW@c;Hz?Y3B{kpP`JxKv}gX7Pp%Ur2?y^Y;y%J|&s>XEueUGdj+r@@&1 zD|M)wpm$~Cr{6=Cq1PtDa(aIvXi&{y`JPD$j70{q$=)h$KRPGNg0i0*9|qFK5__lb z{?l7GMC-Lf9H72LQd%8+@0>xabsPf?&1lo9tz5XC1_)VzA+|bkhtcd40yE5~ckjv1 zlsO2(5|wZH>DIof?$LF45TreYhYU^j?yhcpZ%fL~Rs}3KRmto=&iiq!KC_UJIh}QJ zxQmOV8Z2mU?TmYu&(K?4)$In~TjXe-Innp3AP4LKO?gA+V%sHvSkPwDPZ{J?>J;_&pqGyL@fzqd9Cqe2@43;m>)!MZU+VC( ztFrgW0JOFGy9^6293KiBQztP>UOW+E0(i)qgQfD=EsP~#?VI1rEa6WxG`nWb8h8Xp zOR8{}vTrnPDle{Kpr3Fjop<-AKB*(BSyw}Ml@k290@e#Q0V-5%umzzP>$?a^TP z@4i3o9e8YY{LufmWcsV?ONHZKC^X-kUn_-ab)d(Dhp%F4;P{Wi0Vhrk=zX7M%9-ZJ zuy!*qj34t`VFFN~1@t-I>2EYNFqj?9J=V#5cAY4$l;FGhHk4xLKNf-X7m)5@;41<& zPSlMb(@99JPE*E#02MVPG>qBkCA6_Vw>UrjDqBZZW-#WHsn3FxK7r%oHQGKHzUA!D zrtbYZWKu^tgzBFPo)gj6&@kJk^J+8f<%k6yR(d4n`M=qTjL0ufAdEpJfF=|d?!LuB zX8)(zK~8nTH9}8a`}#MCYF1Z!1+hvrlMLr$_d9u5EMEkKrZp1hLC*lm+Mr9;0r3)m z#Qxc45I#w{s7CSko&19PTlKXhM*Q$F3#j|k%4AkSj>ruIxtK#Yja~>dkB;A}L+@vo zb)Q>e!D!X6e_1xBXQ7nL&_K8B)5iB8Mj@f7HLp{VPllconuUhx3-3Qiqpxhc@GdZ} zqd-5AR|#PLkR>Ft)|or++4ucm%zn|+gJaKo+*%k`aK@7V8E^-$NK-G)<|MBWV7vLI zxd@z{?q(wJLR`*_1LQ_dC$dFVCej<`?3dV4&%`HXA^(5pKPf-*H{S67v% zr#W0kSzBfd+%NOxMVMw>+RkOn zocddjGJJQ-^h+nN+lfLz^aAl$nAtaudHkp*)-4#ZB6oYv4+)Hr2;qdX_v55rK97Be zgl$#T$fCsuf0gzhfY!46_UHyu8v{$cUwYs6dbO{7Bc^1aKZu#MV?hTO7qZhUO%=mC z<+qx})c5~A1_r{iy1GudC!a$8cCaA*Ws^{hJYL)~cM4icTD*6h%NoTeV$l&-N?l`~ z{R56um!;qpHuQubw`0pIbbXMl3aayhtt}Jc!q+Gk>MvJ1qda%EUngLnfm|KQ2&vhp zG3Dv}s9|npCg0nK&DosX={fPOYFT4*Rq$$el-j@AWaiDl0H^7V%}YRziKOy{J$|9u zKna}P#eMmt7B4}6LQp|IOCESkEkz=v+zc>F;6K^DSu0XxRa~_77PrjPmT&?GC%)0+ z+~n*1iI*w-sytWxXyvb8zs%aWpzd$R8u2A!V7JH2bC12eT=8$)SkY%B_v6CMJSj#! zeP##}iNZrbxLW1QE-ns5+_84{PYy~0uPnv4?hVw4Y973(B10p@8_!78Ui<0A%N%iy zV&yy_J$voydL44gyr$vb%5v4BLoEc#-C@kk(~6&1js*ucQM6kw23giMGyO1U&KuOwhAd!H#2R$H)H_rb@Yg>tDQ{C) zpb@bpZKbI|G{Fs`!ypfnwX@{J^#9kp>!_!fepJ@#{nzQQ$b8_}ulbTkgUX#q=-t~E zkSD}5KYz7o^pJ^px>n_FyFiVB@yu_Dh({5Qu%b-=0Cp@B94Q zzkfBX?r&^0HC-Y-iX)Nq2&7TWlJSU%zxM7|zIxlq4?Urte1S6MR-^NQD#eJ1FPCiI zoI3T&q0wuzN9Gwa)QyefsDQ>qdHYBjtle;V>7v0QA`-Obd5GaSuKwY-(ulgEsb~Yly2AmL99BpubTrnMwCqrw%48z%&DO# z1!-w`l}!GKn^Q5$u=(@0$_vMnhlcl=#c7CP!NFZxuU8{YBr8PiTXZxdQk6oRII?3tU`*i_;$$jFvZ`2t9)a8_W}|_bXSGK0}~UEW24^ny~+Mv zHih+F>uu~CiSbALEFOz8 z?Qg#LjMemz{|SPco8S4MSpL@ri{zSPmxW>9h?Z3+G}{0Ed$A&y9@JSGWlRS<C#a^Y_ODXJlC{ZD%JWsl}7 z-?!oCq;{Q}ta2lfdb3F#SP1wn@mBovyWjTPJkZ8_8yBp4(-ip%K` z^s9`4DPQ~F8Ir)VE>C3uooc4^?V8uiWt(QSodMb0j0pWEnCDkSXmK?G8{ItuZ;qB} z!dT0K@m>cPZVgX1f%B?*eyYiCa{MdHH*YH8xW}*f96!{S>!-e4oy;sNKCCekcVGK! zmQ7}>V@@5E9RXbldklRK$DuJzF#%OT2Sq6-!E!DO0T`(jwYVBrucw2$btamcrhkAB zd)w;QNm~8RwIegPu$ZkkQ+w^?bP;k&3X`hS5I?knJevu#m+kP|QB@U_fH1pI#LErz zAo~}|>@>41>Voa7E1a`5+t{HecyyVm4>)`zUZC=`?}xj2oqTu!CuBjVnwFrZl*(3Q z*vVPsaoz0T#)ro-P7JsbKE{`#-Y)>EM3z@nB%5zJi1jxlCnu$6$<#$eJR91JUZ5Gy zZ~1rU=R_IB4MHQd@={GTum~@=^&}@uJ10qE#{)ZG0fWLZcjayVwGD^j|syz*t z7WrPfxDHe(qlbrw)=6xb{y6}bujGT{8iPX-#VGCITPi4eW*CA!0D--I)sva&%Q<2n z?C=|KqSPawpuIMb%xlybc3Sp^0;toX9nI6)1ab<*lBiIXFXD9i7SP!{%##(Ff1jRW zH4O}_T?}~7SP0R44&SZkNdzI1u5iWQj!O{QNx*@t%?|i!SF*j0FX&(__AHD&z(@`+ z8RAbIPO_@AJYAgUsfZeYHt5A4gD4b~mB*U4w*HZ;%qE4F^W(*b8~|RbAWupN@8xIg zNtTVaKH^RwsYaLE3knQ0TUG-j;vO#xO${*hHB(Y9ghq(VxtB?y4O*aJc^-=_2@-N= zc$fI0X?%PfpMpXH3rrlVWDCoD#4FgaCLl!tobX7s9|Wz3F*gFdHCkaVoGzXUw~EaXWxJgh(yJmvG&ri^KE{ytIU@3zKslKlyeYeab) zS@;%H!k!X>dKZh2jznGVuB@0C_$j@T{zm@(sW11qN zu64#k$HJ0H$fhsPx8{-fo^}HLG=mHGQ+eaT(SIPi^m@_*RNDH5w+AXAO3c7=8-#h? z<6;bSz?|5^n-%ouz>8m*U+pUHA(Kjl)@5`o-LV7>X}F8ijNH% zB^FS6W2Q?3b^qI}%zWldK;l}Kn=9^HRsKs#!>5xanM}k-@U97VL+(hYB_`$H4()Q6 zI02a{dj_0kPcRCnQ2cdxe9?rr%GOw4Z7q9m3=CwwKg@kaS?U34c#M`xQ(h)!MD)JxZ z3vhGKJ?lffNeO3wQmsDqIk8USVc0cB#Im%qKh5>+4Q15DgEn-Rq#q7n`Krpu#WESB zDrL19ZGDJkrKy-#WMv)Df}Ey;gDSm`;h-y``xVaIv&wIyq0yn;}G2%=(_Je-M>sK8jwx7b`ZdN z2?s&?sJ`91Z~8z35|G|S^)}{R6e(cTdN~N`!vPoM+dca25##vJ9*vpAYu?j?g98R3 zp$rP|GYKmuUT$s%-FwCNtOIw^NIec}R{fW^5af$$oegecl#snLXvtLblMlAFS>#Un z;X!6bxOD~0-fuuUiv)tX4a0zd-=~pL9ra}}Umb@;1hPF3Nb8qt)G-GCzV|N?km81e z*uokcxdq~Ve+{Tyr010@w<%QWs%uLBG9MpkLWW+NFAkFKjT0Fa4%~Ut9-jZ5`2GG( z!sqAiQ@Vh2bL+qh3%ZMQhbBY|`g{yHsGGcc;sw70J5VEM03F2kwPc<-3chwfN`S_caonc5x-|Dp9_^ZfUHmyvxSnb5iakvJn*yC`GwTa+q7!vqKd^D&u&QO` zlvrRst;TT`Hq=WGJXXI0uIxXc=N(Uyz7}3NZ)#0w{B_SlNo8NKLV${WlUV;hb zy_^5&e^le8-jwvIP|JcXPYpn1kPo<%yP4FJ+)tOkhm~K_Owkibg2$`I=fyd8BPHj5 zEe>SFNjJDm!Y<*VF{EnolC{7Zqf-6u-+__csu@C7{e#RLvE-93ud>X-k*`DHEN0b9 z8@RG|FhRaEpz}(@$QVI}i3EUM(XfG^AwLd3QNg;V!nqpv^mB$Hg8#E0eC6TIV4b&F z@MV#%s@Eir>2dz|8PNEso8Ga}-K6t;@m#n9TsVnnfH{FfqFK^0i_)Op6(F%G_h+(%!VA|zC_oEg!Lx@W6%8DL|um}4yXN&QIJ^bL%AoNRw=eqQS z@K!EprF-M7UM7b%f(E}SfYhaaODDwY8gpI zDrV_cI+Zyat|5Ew<;fLV@|bw$v8C&>%hA9gLk`JjuK&@}hNNe_vz#y28Yvd#Dhdfn zn&u0u93HkDg#59lSSV8bZnI7qSY6FkOzWhlQ@?rh?dWR56u|bs|8AC3_RAW_M{c?*~Gn{Fhf{^g3Zgi(yvz`GGMFb(%w(iIA-*nk^P2Gbr z5neDVv|nNL$b>hO$3$YPyzq5@h)g>zRCn8o!*c@dfW4j1T+h=KYT>6}+i!$oeR zE8_gz@tIG-s$Z@FO&|Ij`87(kg012JvbSY&JdJ83^v_qrD_lU&i&kisEEGegd! z@Sm}O&BTocH&37X@6z2paJ}Z5BDdhF7LhH0j{#$64(YrRA?m7lvZog{QfKfkT+2L zZTjk4NzP&3z*t{BQPzA44%&E(^eEZn zD+fQyn4Vr`c*r9#I9S^ApsXa>oPjgrZFOPbu~1Sq3iQXMTrH=V@9BZQ$hw|w>&@=| zti=f9Ias*zf&cY)vxv&5DOf}(7q+YTr|aTRZr=xZCqaXe1%gSt?_1}#>Yx!5S`XkA zCJw*_!4|boR;#I>lGb|P#G75Wuh$DdEY~_d_00=Eh}ieuZ=*wOu5TZ6u^>p$dw<)| z(qb|@uj*aP$WTY*C?i&|!~DGZ&4xZ|Y_z6BAs~?Qi`vq7J8mheymoEPteqUVn3&R2 zSiJ(lq)L$}hf}$;zJ8_+OprJ;Bu?OM0J5awpY87IpFXwKTfA~CLnmA;?+OTz_-A8d z!%c?^=WhZ+fSr_yN&dVI@4uYZ*Yu@2vHezA9#A$Xtr1D2f&yBmyT$5hRgc{VdG(d-X3jPAQsX*+y@USOVd z=eAS2RQ+X(xhvJHZEH&du*|Ak|AmutW_ahqcU?Nv-h|92R{AX#)UD*;@>DmNGF;j* zQfV`h1crl^mpuWR7q9IXf~kH*8Y6_vkfFuE)zv%~;5&+3)1_Sz-Yoh}-_7&O4Qs$_ z#VnJFU+%uj1*$4F)es5da~3Pyf4AY|=D#f7N^1wK!%5+|AT@4p_jLg&H;xusy^dL) z_cM3Z9S~I3lq}P5tK_NE(O|jDAbbZ019c38#Q95f(>18_lfeJbX!&z4%>hg+wd zxm@X^)`9W1o)SU?ZL~iklA3ygg0K|0_8+vWhA{>Qqlu+J*>9Fs2NDY6HL)mnFxbFr z^zW`f^0iQsIiuav8Jyr(`<(h~2+A4I)u5T=(zdNixL)}&IqdWe2ySNLsM;>_Rv4t> zZ$}A8ZS3-0fhc7b^rnKl?&fP4?9VByjnEPk6Y22ureMU7UK($le)(WxGLSJ|{7s_65OG~-idVS_{@|Fxc z=%lj0{UHQ76v)BiS0q8>)=uhvW}ve1f;3&CEMPR8U^}KggSm>U-d1=6X`)*1cYMss zw^~{=ppH($9Gv5>QhWXSECVni2LlRl8)m@S9@4kBm;M9Hi~LUds*vMmPk&P4tP!dO zP5xPj8sTVW7V88ErY%PVV@6U`hT!&Nd{ zAchTfYwOe+!+=iZ#e+W~k8UIws!Sw>#4(c1iW;3C1wU%wqwj1g2q$gC2%TuI$X$Qj zh{Un`4jI5h?6AJP8GNTC>D*Qc^&3ny)mCNcgbq@Yl9IvCct_2SAMmZGY*T}KoLjcv z#AfT9*L;j0Xd&~Xhe+ogGW$7GuX5*zp$(-~ptMZN%3}3mEcYLey*}-XikO5LwhDIa zXX_sClPmjoX|+E_Q;39yVvw3!x=|leSFdzyT0s8seSOlP(jPV52_4y$c8qA}<6jb1 zUJisja9%`qzXBP?qcYuBgLzMux1atD16i$*~>=2N5;euv+5oj2_ zad3#Mx8APvx!>Q2AqN(aOW$h?e`B*!TKO3wja**#*&86Wrv@5^Qo8pv@fb1?#E#pA z2|D}1Cmok;{mC|Z(H>77y29%NjQaRc4Jgz9&;-!o?Ll8Z0f+*)Ky~e;5FHr{Q+EikG2W(J5r%6inCdur^p8P1ORL1kqMpkoM z`{o*za3sAT?rpdG8#C|k<*tT`N>!dY#Vsou8XA|)KhW*s$b6`?Sxy8qT9rp_>iQNz zLW0VdRd22H6A)m|aG!4;$CB6DS<8(6qOcll+t6!uh-dl4i_5nEostmpui>&^YpQ2# zD`$D8O0J$p|GlXvghb+>MQ=MYGKvdZh7rTVjGPYGFLR7_Vl{l_cgryfnutNkhAvhy zp2ig>`nJ-_pjMe@2sFD@|xA1*~tU;+ggHR$e~FILG=WlPl<_Plm=y;>9dmuJcf zdE%ZvUgkX0Fs4<-#j>DKD+ZSNS~IDC>PVUwmm?ej+89L#3XqG-l9640SKY{!%#98I z0qfB{vydMqQ0$;fx7&RpDCIfH>!il80dGXjAbQBb= zgrcRHRb`=tDjEyhs=xxcT#Ihqh#w1;c>{NMY&HFK6fFV7e0D5A%%2S^U-VJcZAyDn z(?S+ncp%*Vofk$?oRoxb7MkCBA{#!sp?|(gS_9r+7#-jq-KZybLMid<~iAr5QOzyOAbgF6jd9Sz~k9f3ii{BUE7|D8r<{K z5HZQ*fsm#E?5`mR^XPZDiYzz-Wk%=`I7!X8G8SY~>eYF`-E?=JH3_yumH3mCd3;C#v8K4iP;GuT4)cnz3RQa6Fr>I2e;KkT@;S z!^Ug8&mc%*YEXAFoum+Sy$Z*bz!Z1io{Oh(aakArsLg?0RutlSJyHk=q~|I3JfEW8 zo}O)hoKIo0&bx+EAU#mLiqlswpc)=-{;LSQK?-<7mulP+;m36I;{(Zu+!4ez@EIV` zSLFx=jL0TG7%Kzn1tujQ&E`0r$GgKhlZ`7M$eV!o;ceGlBraFHm7N`s9$Tmb)t)2~ zabnic9r|(v1o`dW9oD2uyrr1PXHeoWc=#8x#aiqRn*&ISGb=L8!KQRsU8aStwdVnh zZPPbe`+M4vc63f|K$hu(1eOh%#f5s!mXCl*Oefm5gM*|>1abF#uqS5a;X6{jh`9FV z<-B>rG-y9FQc3fx3-0Q5$s|%txWmRDv6uR1>bl>71r-JjH3)cfJ- zUeSwwBAq9dEXMTXN6u;tbS#tS;E*fC?tUUn0^~Bvrs^>C3tyD4Xanuxyz#x+111v^ zOpfW<3P$LJu*mqsot(oGhD~`{ggDsQvGa+cspAdG;Y!^GQ)VQs#YI>9Z&KGoZ{Rmz z!6wJXldP8H$|5X@G%yZ9OQ&ZxHILl7nx>Q$50GgO^Q^(p_X%I1Q8l@Y;YD_4=91C8 z2~_&*)VeD_R9d1^bqK5-R|MpuT5Ceomm^rx1yEy;q#<$lU?n;tHvQ7EFZ_YF zT#4ZStycsO>@F>w>5-&8Y-((r8B1$09p*a3D}b$f(FY6d*(v9UCKf9XYcJIpDN&)e zDovX}49}OS709jfjVhPf#YEvBk6)hj_3^_JZ|x8O&G-{Rm2y)$WqaH7FWHJ# zUjdDrfqD3jWV^2mRQmjz3S=>bqSH-SR746vNZWxLw&QRcjO9N)2_DT7TN z1%TXS-bImB724{yu1){@QGFhs=Y3MN)^&8lV0s}bNs}|CvSeI#fs?91lA+uNY5&}ho`bx&^y7enM5S?EN|8_lfH1Et+o7JW<*xS zndUG81DX+J{B7q08K!j2sEPSNF{DZ%TDdU@2Q8`UPez2j?n$-tX)`e_u9J2_%G8x){lwcCZ9qtba7#+)IHVli3DsHQYa+Edk zebO8mJ5m`RH&7G!Msh_VVZ#85Lif;dn3$o#zCH%)(^}hOl5!?y4G3;i4D{2Nw#Lt+G5{o|qlvjRZ%$8tIs6y%280GCpsex!ES?x~ zf?s1`SK1qx9C87UnSqBVMDJZ^PS#MZj=5BPKR`4W_dlNQKfxj-9yOVitqHJ`LD{S* zkOr<>tc;1dsNnvFQ8$Wb_-NUgL0Gsa?u|w0?Pw=n9@t;`lh_W9>%)NZa>?82>xFzP z8254=WJn9Zkf@A}4GU?6Qt-xc{m3?5Dwh8^ zo~NC3T*HOhil{P0wb!|guoH|FLwJ`18;lR4|ZmFd!47mgZk+t;rvTlxkY(9x+(_2AyMxa`0A zyC)}BUKI0`ED=afeQfi^us8x;hW;_2{^ z_Osx)G+44dAllWfH4eZfat)8nSo}t%;_6EBT=h+_t+HeoepaS6 z{Gn{%i=2}a=OeY`P_*&W1A^0c1R{BR`gDHxZ-r8h;fQLfGg|PH>XHchMOd&l$h z^Alu*i9ES&GH9sJlQcx>_VCw()Rk$A+uW+TQ z=Z^lpx3{0e-9cMK)pfF7sMJE!*SN-0hYiij0zVm%pD!tJyLxGAY}|dLLos)bs~ew@ zA!Jm#gH6;O71tX0;2nX2cA5E}Sb^1$&dpy{=-6e>bgrJ1TX&i^mh~%{&y!697ink@nd?X z&&{$RpL0)To@jD%azFKBA3r!$<`*O*@)sT*se3QbVz@rS11&x2kB~tx>Y_mxJshgw zFR_>z{ zkdYixhUKlPt?jP=Uh)yeWjS@+T~|j(r}1$a#{>&CvX4L-NW{bcsLpW^5>}Vw_>wR! zkSMt$%DabNDkD>I5zJXj!6A|dZtV?)A61o3u`^4qVx`v|tvqBOJ5o}5cD_%@%Ax`D zE7R-^Wt->jxW>oQzuQ9lnTFnGg^#8NbxbU>I~f{iHj@iE)nEV1w2Il;4h(B zC*G#u%lEhUe|~9OD>Y=W1emiF&qg9s^>K$s>UDZHDLT(lF}*!qrBdPH0S6o2F5&s_ zMkT{07EX+Tipo%xe-r*Z6)ZV{m*F1%;6hDhN{5Ss!#GIc6`honWT~4&t;oSLc6t8r z(Dxt4ERE^N>(~DcpkQICwPdbdxSjHHp%hI3Q4LWm3txi)rY zdwZvczar9@GCluR9hq{WP)(x{X0R|>Nw(+LNK9@m`tWfbWUB2PN=cyQFmPROQBt!g znL9EE2D|3c)S@tdbKLy+;z2k+*kU}rD&NFgluMU1z{Nrm{0`_Tk0KE1m)bfH)>om# z$G@+rdX+{f7cC(X|ALmDo_^slSxfM%Q15#c1%+e+=C@4hlbT$FbE@cQ9|X`Okzug8 z311m9J=b@<=IasFKMl&@i6o1e!&NI+>VDqv>ogV@77l??-e}wU9zHbg4^#j~q8Fon zvxL+;?QpEXs#Q8pj8CeKPjSAW((n(Y=7?MrO0ek6eD6-iKeAWa;dZQ7Kn@T0GfUSc z6I=3W7=^wvIWyCSRXvg?x!vrv+JMAdCR`}Ma;-G)EyD>P@?ZKamjrYyEbA}J+sM-N zCI!j5q`}x*HP?Pmf1ebM1bICxsU|M}_RQK~Ne|)Y9Y{oPod@sgss2WB8${u+vdI2a zyz@GQ*4^s63Ssj*=tyWR=vT_7k zp%4*1zPzt*KuWQ^1`IK&vO8Tcwpi`xZnjrfb9(*49p6NKqvhu06~75Q@R0xFMRP3L zu0JwkA*ymw!#X=Vn>{^*$hN=W6pY7~6h$ZDOiCV4Y;A@A(LY4h$bGc@!RA*uHoe4b zUb?}}VKB8I2P>;e3Y)=4?w`cVH2la^!GDfT+}xVWUsY8|r#4+Zd@(z1;8u=}Q;^yT zdO<@&Q_#&>LBC8+UyJ4k$UFtpLfsOA_A})}a2;6AV;qB}mW zJR4{&`Z`OjoNAL0X1)O>L$*Ut1OQb^YU-b9|1;95Oph%Eut(>=prD}U!H?x-@gugU z)?WyyPV$?bTF`&6U|=)AB_;-qeg3QWF;&$JQ-_=DXBs#Y|Ewwcpk9^*G zw3a9jt|F(g0F+bso%^#h3kFt|D>5cS`4C3l-5j`XFj+bXL15BfOX$LFHrq>3NLMai zSx8V%uaWzMgu8__y`(Za+AJOn7cr~}HLsp7Vkr0>Wl6~nGWS&`0{ntr*>_4)O`zsG zx={|q%(>`8Zfk3^N{MZoJK<;VE*)>OS$?zAnhq;T4KE{!OciDU8dV9YsU`4OLftn7 z0;5yi{>5oHCfVi3KZ?ql7jF3F9VaFyQ$bL&j6lvtUE=rmlLdY|mFYOJKcae|;B}Kt zLda{F|6VUx85LLwGR5Q`+2M{(`!C9V0Z~y2Ao8>RuZSL+1)*kpj|J$jBBLy<%Jvtg zSIS9rusID%SG!d7NJ(vD;l7K$dSwJANk2CFQUkz777#AUbc)oII$&Jh$v8d_9-m>y zrb-M1{uXF>9+^rGiSY{<$L7qXs;}*A(`8y) zFMl?UpimJnEh1;3ZX!hmg?LcTl!Y?5jiM>1_!|XtixvYD8QKJ@k%a!JRC%y()b`zR z9pU1_KO(Rl#eqHyoy8_ONW*$q1&Z%XP5&Hf_>mmxc zb7E;m6Y>fg_5rRKCzx<_3cPDciYgxo~%rJ-PQ{Pth)ZkyWG`tJqC&1bTx=$weM zu_fsyR@A8i7V1meU)frO|M@+38(@}_x+y5}Y;Q00out5MpzdQy{?V+U?bTkzv5~qV{;Qy;XDjz!*2CVW zQC3#AzM7jVSlJ(RPQk}dXr`xc+C|>o@D1ywPfT<3NYcS0$l$uL0O_UYa0{0#Wf2|7 zoxnQgAVV64|JjPElQ{SeQx**z+~Y{-N?uL-yK;A%MQ$$1vAhD`t-ZZ3rlzKPw>o*! zva;}<_{!aUQi(9Yp7B0_{&7M=LNT+VqXZD*M21@7a8iMzG{&4s9f!_`F@t^nM(Xy(<=7zI$FD)&}kQB=d1?T3<2PzsGdd6SBv2iDR zv3hlduG`^Jx5JIKlCy*_E{J6fk0Uv>EdfuEkAJ-2AtkMCYgSaOi(?Wx0-Ua;{l^Xe zSy*ILRAFIZpUz`ncsQ~v0QIyNrJpjgvx{keP5e3^yY%!y<9@zA5cA*(V(Crtu1(r? zSjcujcTuWvEsa3bB5h{2mF!)CBw{!Uhqlnjiqn_SqM>lKXr&k|szME`}mx){gg z#0%_8U@5d>-^Ax%&rO^AF?Bnd>tTT_&jxG_0I>Ol6oIAnSM4o!^B{K^t1a4x_^KZHzzoKLA4(UxK zJ_7u{2I-9HL8V@k9_a8l+BE8Q&9l_Fd)EQ8aMFDCFtAl1-nNPy?I;Eghl7SjNv7pX zIZp7Nv>V5N!Et4Y85t$i1ji{={*=7~EdOaPzf-3Q@_B3?{O+TvqhreQlNg*CX>~C0 zFF8;jGa(hw_pEAfU51R$a&kz0Nw1T-78$l77n>ubHpG(21zH+wH5#Jf9?-}#m1b{nA%1lz-Z3*VmrR_gkD(#Rho z`@y^V>-;~DPEPIFMq~kMe0&A$bLy+FQxnKK{3b|bRGT`{>E$`dgJ}ln&+ku}P+nQ7 ztW?;hNNW3W4qQi|(oVB2=ajwL>)BcFcAuJYaN@sKke`o!8Mnk|5w)0PiwC+QR1c$!X=A4SY7;S>qYiQ)^aEzJA2|TA{ZMtzdLmI_wSD=U+O{`aL?JM*;yrH zkaHgXW7yu_9;-Ke1MWwTm{Fv^Uq;q#6Xp=OXgEbeVxrkX<#LK`i^XJESZHCfQw#7@ zT(WZalR0YW4xxWJAtmK~AwK#8wJbfZWVlbQKvF_N`LE@;lXH(3dGcgoz=Tl^YwY7E z>8;!bC4!CZ>|&bla5Vw~L8#xOVq^q6Rn336WyW*rwP%38%leUt$+XXWM2uRuQl=pC zMnzQ>^=s~n(j58!y|0Fzpj>3+5z>UtBtop1qPnoBx$ID09>ANx!CwZ(ou*5~aH>uw zCV{4AW=epm1SsIJvd@hNUWL+sEZf%THN1iCak&5*RbJ0pBr1V{HIhwZKC>m*qryak zZL=Q!zsLG;vMTT~NM6qBRgLyc<*ro9v$N$Jg4_0I`*D{te^q zv5)UbOCnDPRU^x=i;B{Djg0b)MaFHsZfKX}!LNrq$?EB$j)+xbeXeeoJj(dmy$3Cu* zC@lT zAzU@M)RJk37&A8u;nj17i%Uw%Y-V=3{dn85f`>miSk%_6`fCi?F{Yecs1ezJ$(V;A z5BVFiE-9D9>%P;Kg#|hTgRGh5Wkzyx@|n3g>hW=mh5T>BckOTYkB_ysOpJ_buG^w@ zn^_Tg`e>=C^@|D%`lR&q%C~oRlze@&3yX`(X+%VN1!vlDB$3s@pE9b~H8&5WDOGQf zv#BmnZ&N@(;EE3=?K?a*zR{;0GxFY|?Opofg$1{k!xzzUL7TIY@IknvX^|}rLqWd~ zenpb&x7%KoNaj+ZkqZOkMWefy+8 z+@=wqKcjiNTVRL3&QuMg?&}3x$;zs$6Bdp948S;o%5ul2hDg!imoNAMa@g7%Zd~>* zk@9xXN=fx5q@_`Ts!aXwkenzJvv#$n8xBr1s4Ec^ziXxmsG(qS)$-$JzkeovSGhBVYf%4YJ#&Lh-@o-;@1-IOv0P067#=uqX4 zj7QTuqLsUIk5RqJrY?|VkOdqdJEl8~BAWjohYvI{H*W^XOZiuBI%6q9aJX1+crP!} z8d0b%#}ZP~XxFSLUE=nKjs3nw7SB&_605UtaMYEf$g`(mV(5NtZ!a3CC#vbVh!j5W zd#P1OIbj7vL?59SHV$m+u&`HGE~Bq}&z;*&yNpI?Q}wBH{*IY3iIBv{2gmKJy?xt! z@mWwv=%`u@Pg>ad+0}voK9co24WasoswMqQcZGFdjHod?y^$JB3?J; zmfdm`pLB}&2K)Of26*ktzDEv$*Z%LiP4r@j=3-7mZJ4ObA;kMYBX#KE?`0Z?)DSn) zXL*Zprj(YokIl>s1A5sDHU^i}+s1#msbg@7_hrIe8-F_1UdFg278XkHpJMq+5sGsF zD9fZ`V$!>~a4M>(_&=V`f+5Q9i`FVAg9^-$lETmkBHi8HASFn5cM2liQqmpLjevmC z-Q6wSUH9<&-+RA+4m0n2&faT3Ywg){-=^}KiNHF+F5x#sL~$FOJo8<)DVu8sg)!5* zlk%@)d~HLXu))ChXfT+HcXfC0Hdc)2s`zJo;C}DFid|LWsTZ6s52tyxz4n85rE?CW z{dO~^c}yBnk*2SonSvy21A{dd%~h9}sB^4ksU3({xi;e-@^4EKbj3@LpGHDO;OZw# z#h~q^ZnyhH8p_^p(FF?c0UyXa=J+yAv0-B#deDuGy(6(QFwy0oNlIb~+o;4&Q5QuN zd9p}_i6}BAPD@W853m=_0lo)B*WrD z*iXdXV~)B(NUYj^ho~wMJ2LV$4>=LCk8gp{mGQq%eoUu(>1pa`?U`fpr6nYg0lu<_ zj{z&=pV1=3zNolZ1r~&(;}JFIIGMu8&6o+l|6qL?PdbsNAlbQd_A2Xg;orZ&W43uC zu@7&V%LR>Zx%MxAc--x<^#bZN^s2)Ty2;O@QsnK39BgZA`^w?9aMzy?Yr*WnomqlK zS+@!kN(kZT^|K&h`OVH@9|VY#VB_@J(o%Y`{~tJujcz$7Gczp@56{F-I(mQO!iQ=q z;raK0re*YHOo3ys9I#U)46zTnaImpWVCtNjn%b0z&}TNHZbrK z-R=>Mh93A{LKus)0^b@pD?4%^4(oBj;rUmYW31g>7lH6(qrW!Xb^8Ds z<(p#0MrY!P5P%z|?f@Z8v9$bo^qj;B;r=yZ`@jS?%F0#UaTpv9|GKb@Tc;xJr~Pfu zbGMwYOU}zKiF31{pkUi0LQ~96u>w<*^*m;DFy=d-TOzHN@_B~fA`4y0@y1SR&#S>Z zl?9Cl|Luz4!G0MpVz0N{Y!jtA^ozF}Y-<60w^2pFz?TUp8HL0cJ}z&;Cv?KE{J0U4 zk{Uz`IwaX=r4-})#P^VPX$kQ_HaHa!CjMo^1fj*xV6cJtjsB$dvq>GL&VdxJ@kNrq zAIhG>*wHaCBG_{!M=h=6Cni)~Z}!4^HkSF<`hv0EtpQcI8`n6Qq|&HgwZJXVBH5}% zK~x`+FzTCrRu<#XB&Dqw$AKSW((v-g9(D6IyW$u1Q60+XC3Et@|g6Sf(cm}~En z{^bVZa`}DxR%n8+{TGsnIcjfL6)F;ly3`yR9!kaYQw>?^yydsA-G?&VT)c|9?)w(e zhb=kp=(qQWeA&512S~dT$_`b@F2|i~H??=tj3elVF4Eb?Ylr$IEgaA_?godr+ew~_ zuB>Knt^1RK*_?7B!g|ZRw`K^BsA*|qfCFd- zud_leDz3O|l9^7B_X#+w$iBrDVtjU{#3|sWrZzPz4}ac>3}maWzGorKwDMeBTV!}}DQ5Q#a+K)%P?8F}@#>RfS={8@s53cW44o1BzZIr~O(=eIt zV4tM5%6o|i2Z)UMEVFNVpx{Q4rz?1HQib@upkte}DdxUCB$W&#_VZ&kY5aDXT+VP4rcP_0LU%915h9IITxVB#g*mNZ{(@{@xob-QIl!AM0s_2d2XuN>)DNXk&$g9k-r|JBwEO{UAlN=g!@1^ zpg_81pl!}=(YT(`_}2Obj`E&@VkS_(&7am!+S(=p4F#njF);1nc5^*5= zn6IRp=ly>C)PW8A^-8?=Mb z=+{Y4CA30Fd&L^R-*of70OViggFO~sJERgB*Z;Qu^$_8F0bk<>HCyN5755b<98-a` z3!E}asmMQ4b%%f1Z=(L%NGyE(;E#umM9zHTCn`DYDatYA*$Tsu+q8!h^d$1drH6;)>2+>ba8$^M zEEfOZ^O<;kkNy5Gt?`r*cWjsQZ`FHR>NxUV)qHVKT3k%ya%0G=9mi=lb&eSe3KGQJ zALsGRRS1!Gd7^`&ataYll0Oi7}uKQJ6vx0RiqT7f&Pr zYl_ktOHv;WqNAfbrhho52e|5nkZ4q^ny}ht^vEy8=y-D3-FGec+V+3D6`9TBUZeez zD{JNWb9N3$qwZJqjpDuMPc3>1OPk=EDb6K;tK9=HNY-wz!_%6V!$Br9cH=5);_GR>SE6mUJsL4w z8DnK8Yx@?0Qz&vDC?zGmG`HmHJZo9|bCZCCj9EYeX)JyM`Vr&vkLMHeWRI_F^0Nlw z8X*~7ZVYtv@%L2)oF8nGsj1M>uhSRCFPbiRs)^^DJl(~|#FJn8d|T_KZ>H6`9>K{e zF77{#p+MZF<;2_8O|OlwZVfb-rykVTr%|*`Z`SEeph9B$4#)i5tdP0_&|a1EJ)0W* zdHH3wfkbV%tR;1*2EyjLv6P}+#SUI20i>iMIQXUOQG0Vk#R4f9g@BH(uW#(LdE{|& zfxfIC0Yn9T1+s-r6ft;y_S_f(d#9if)uEW0m$k&-C9OYMV)tvw_hc<2z^iz~*H|(u zN9~iV`s1-Fp{ZH8RX`B?B|TpVBJ2lm;Kr!^&_utE8-G6Zg-8IfcEeLf#$}|KB>oxQ z5IiH2XE)*pC^s-%teiI9JlZGHQlXbhuz4$En=tOfmOB!xVp%XN1Jw&H3JWIB=ZTqQ zU-&R8xVa$1!wU|*d6j#Xe~}|1BH{xfm9SvH$RR}5wXM7~)oOHDtV?!R+!kba64Z>; z&G%>4%A0a*U%y#vX*I@`$E0^bL{`u}R>fWY57)kv$W+S=^S`X=YjY%$ZiWcb;u+no z;P9#GH1v2in7$8q>g3q!UuyuulqXIiSfe845e!|2sy~V8hMm=kPktc(-)#!f*Z(}7 zxqf|IpzWeWJ(#-46F z!}Gek4x_Tj=LKC{{D8smRsy)>q~*Rr870bI$VvigxK-nvh7pM0B0$TUMtl6ymX{wD zY)8*OHXWUmAN`qo#D|Z*jP?OZ(vK(~KTx(4c-Tt|&9$DtnyT=AxH#wmjJb^5dfBB? zDn@k}TgO0d+w**X+M6~@VVdn=`OvU1CJLkH?*w6lDZ|`_)s8FV_>B{DL+igwaz&F9 z4eMBJ?46rW&rg4Luv>w=A;DD7bJHD<`sUN;Bx_%n?q#H;#w$oHcp@4>V}koQuJFo_ zfr^#AR`vHro1A!KX1hgyZ%I)Rs0Sy>5Ce*mmiyl1T0vearczC-WlK2lopzOaou%*2 zH%y*7pdz=yv}M$vjHDtn$fUPc^6|K8^*U}X+Jvd_1_yQRv^svAKsF<90%zCp+IPA9 z=V!t`Q(!hpGP}6Aosb?l9dj^6Vrkirj=k?C087JaYopU-Ej*EgAk#&ES8|A@{jv;z zfZ>kF9|HAb^%=0@~7!i}L*8Z^kf@MFRg0Ge`^-ShRXnBrI5ETG?4vKq|u-n94tjf&C#7EI5|q{Zhqd?fX~+B ze$YrK)Y{tf>lHb^&k2l(h$4dfc(KCe2+nbTsgj#_N&<&xO3V2`i(A}r)iPJ=UYmg# zz%BKl(Qd@3Q)y+T(*0?kaZW{rous2ml>gk6Jf(MeM;)lG4CVDXR1DaiQNwzZtQ#rl z`uoxGaq_HfkOb^oLBh34)8RMl?t$D?esLJm&(z-q`z(>x4YUb zi5wZx>s=34Ov1f~5JvX&Rh^Tget#U_EN*Y#+k`jsozmu$lf@#1^WCXa z&?F`=42z^9$D@|}>M{$+#2*S~!mXaLOvz&-lJh`FG7v0e>?*=#2YGwB-2{zslKl|D zMNj&%iQwhuX>3fG%-1oj@6>~sn8>zl7jyOzp!m0rxnxrS4=<3W#NI#L#vYG09K^Cv z7RBXX3OhDu&EZN9WP~4i;;ASr7sd1^z2mFDqqN%#Ta@R+p=sZb`uW>(mNoW%gVkY1 zbH-}rCGA%RR@UtZJ0c=R9{gE$Labkgz1XfBKk~>ajLe3Eob> z|3TahX)_3&&QE(U?)PRZ>}PM#iBK6Bsw_hJByIB?92^o-Uy>N0>)Jn;+1V`#h~hE% z{qLQsOdx2~!Ml0-`D5L-4h-peny=e1}BT_N4Vl@3+aJ)0$Uk|oVo znu^?oE!VE4=;)8?6ag@Y9FLoqL=~1cNcaZ?2xfi`X^zvYYrNWUa|JeYX}L|If1CkA zU%xUE0U-a!0wODE%fG-TaU%npghPd7+V4Y?la!^+_ldjnnprv1$QIg~=7n?jS#P^q z(e7%6H6sEKC0<|&`uV+?o*JT^^f)@HI4Wj~(MA?|fI4{(rkZVSZ!_~?33icvng7;5 zx%5`uXjn`j%7pwh)t^lGkUXS+=M^gk*O)XIZ4(6Jd~*3j+HyVH(i(Xe82c zws7#H9&BUh6+jVru=*ye)1iE)M7cM+(SM-MYMLS}wzHEQdR0@y+)G&dAZ}HE|H9PN z?7JT%sw+#FcIC7~{Q=bJ7;@a?YF_7x(}&kPjI6BK(R%O}77*k}3G}YqJTT`*3za8L8QuY*2X~|9hwEw_;Q`PwUh>l~g$bQZD6(5Nd zcAo+l*oPHxU3>Fs_F8Q1M>F@kLGCNR-LhQ+-Abw$Pr>H&k8%Q0kx+z-_r4Ymkl=}^ z{LhkpyI>e-3rX8EeLgDvs6I3(S$G0bgQr_MVnAj7%95;=x z$hnwQM6kE@9HhzKa|*_lvWspEBvJOu$du@OS(yXr5Sk0%CZB`^f4DFqNZ0?v#f^E5 z@Lsakylm*GLG*>uH~cJ)(4 z_q|!l>%aFcUVh&K3TSa_XODxPdE-TGfrAT~PWLObh?S-s>RV7M7yB8(py)5B-hWw@ za}8%$i23GAL9+k!tZ={ifp@yW`5-&wJZp5@w_S>ityx#cQuVjJ5UrM|QEt&E_vqK= z!NPXv-p1RX3rZGN6HC%Ar8_7_2{$494uIWkLd)jf2~J?ax!YB~Yf0g97I(cGNEg;w zaR%&}vH~9zwSki>=XbFT=gh2%{u>X~s0nTmVw{(P7Py#~mp6IJ4A6}y({7-FtXE9c zXORwyyaT6Ji4iBKJ&@Z@hJ^By>a^+lD9!PSmFcvb1Xos8+k=FsY7{AT93W3vOyQC9Puc^ViWdYnM#wbC-va+@|5&ZT~pCtVSz+@ z5WgMFiFVEtav^jox0JAt-+ehRDmLG|ECTw1GO4l~C(9-=lKuC`GJ>dXGkydr=;}VV&*jb z@dL%r=iE_PEw$vg9oh2ucP`Gg8?WIxukc*#7N{z0DwXdT91pRgqidvV2`R_a>|YFw z#}YH4KR#-wm~K;>r_%92TkaYIQx0!RM@PpVB56y3weJaa?8rihp+Xeo+;gNy$U#4! z3Q>KbIt0?{bh&=gRsq%FG+Z($a;wE>Yk;T}I4aH%{iSP-gj7_t4o}x|GtPM{{dQj=nD_mr8}>7a23t4_)cw($%t%=1VAM)^@->L zDt_3rE26WsRlmsjFCRjb!GCMyXf^+2%6m z1qRDPAbZZc@X*Wb%0i*W)FIbKJ=^K|`7E=9AEL&_4`ntJOXT~1^QhNxGuC|yNO=~;1VJy+R@m&P8bo>5+s&n;8v8|Jd*zh=ffQ` zY^tq|mx=LxvTdCi{(*$d0rKiZ(!wGXTVdD@+w@(BIArO25*X8)`viXc^yxiVnO3QN zf;>!*zogmfcV&wJ=@9kPWvi3lx)deVhapbPWwf{(JcrB_t4v_Wb1v zdQAG85;6H?tM!jBkOFe&57tC&!iT3wSmT~b`uU-o^!QS^`PIR|!wMDD>Kz(NzdbAU zR*f=u^0BGO`zh&UwMc=`s-LN8?&}9UiL~^knmCH#Z~ah(RVT*B@luem8^HfM(-9IJ zZlv*+A=p|6MTGK$qvg)Mr+ll5NSM;CGl4}o(L96uzoz)pAjw`|%A_?*L`oIDzOfT_1N#gije9v-0_S#&Ug z;{@E9)~n>+s3-+}RHlWu4ZlYJe5NE*HZ?bo1*-?m@##4!MDLIO+8@tIJ;*Brb3Q{J z_9h?EcxCN%!E<%6$W3K?baak{0xxsS?*@OksRL12Ehj0a(4IZ~Gt2AY-2A3X*5x8uk zWs(*YO6wsOm~$s5Tl?>!N$S z|NYx@rsLHl)dG|eOL@|+x1}Y5pUft^!{|+*{AqJQ4GE>6l>h@HBTcs1fIlM?6NIXe zQ&Uf$Dg(G!8UwT?d2fao0tbz2r1Gpfe?S>+hg@Ai(Dn>YQn{^dh5pTr-A#lSv9-kv z?a^oPZ>b*FOxVXkFgTXG*;-!hJ^v?-3tut{a?=#^vM)=Y>+Oqle{?zt^-O4a80}PJ5 z498z|mAk)txYxPrO-ogc3d*3-`}KK-GpVAjt-};I07^$JIJ_Eb6E@sCDk6^4 z&Yd5t308sG+Ezt?SHNdOC1YDZdJDz=b)eE~je*a;ow@5mZK2&%BAj)82?!x+V7M$+ z@BNp#9qcdmXVZz-Y&t=iK3>SNo8Z{&8gVt;J`gKHaK!KM^zo+mWqlWkpgZ;xFLC|q z;V#MS@OVErOj5cHR?pg6X^qgHTnB!ivE>h9o$VFov2Ex;2W8tt=w^OoiY0h zoDfFDhZwC8c|0rcdl#jkz{qUk%`-VUX-p>mUEKLojRgl_ zn~~K$G*~_8)M9~H_7dnmm>OnOsHmvwsJXu&BBRuketiaRZf?pk6{B0~W+dixHJ{Wk zoL%>F92w1>f-vJ_BsJXX&lc{j5};3rsrHLY;yH5bo3D*jPF+=o}du;ujbuIi%&; zNYEz(fX(Au@g0%X+2vzA>c8O;_SncL|MhmV_HQdT+tLWx4FRQD}MOL zGs1)ga~G|{-L+W6LQpUBx<1#aRC&MTvgIzqr1pb#f~2{qp8^=3t)9GO{n!-{khCMu z1NbjXOFE4XNpZ;$3+~*ArPEK!TAwzrk>tNlcmseA0=sz<~epJNaAYh5X+o429&%E^1 zqZ@D2t+Tw~AYSADXNeJTm|oFjX^kP$DZ^rjoXxEyPeyjh|2jCBag@Ly_bGlrp!8^n zjJKa&2)27#LUhr{6f#2?C9y?Tl`ot0N_2Ep=MTk`yzGoRoC*AGbV{QoOhiUT_hX=$xO~Jvc*MP+-A_LC>fMDQw`u zHb*P%6}o=}vAzkN`fT{8U619FF(^bt%0Q5LAthV;`@`i#MkpeggpXGhZQifZ&66o? z%$pk$s&G6E%%i3cE1}TDJ5hOS_sSo-L zwP-ykfBI6Ghk6hQr5&tVB#K+-k2|^zhciWJg2SD3-j?m*=-k={jgP$!gra*ak)EBK z!7%Aox3`8EC?iQW=^qigJ88dSUNB}zBMgWqPf8J0Y z@EQ9M!@#z~_V&uqeiCA;ZJRifDL*k`6X149oO0GC6=aZA+NQqEtf^teqI-B?C!NCrK8&a1&B$nEUe@rJOzfYuWTRiQlOrXfA7-KrH07vq`dI& zKmgS<6i}8iydp8;-&~%qlyZ6C2-T4fee_D{+aJCPKC5AiIGpS1>(#GkvPB^}+HD}Z z0I+$?aVSGlQc@@+SoTVUn!2C@_^rCT#Ep#tQHUur{da$=6dLp=R-+RUS-DC2LGfFN zwxO7Gqq@cUzrlg`-^2C>|AaZj|NWrNM|bzTlvEj7C;|)S))YFrhy0D{LL=kLU{v34 z#=^pW)n5h~XP#pE>UnymgwHzz#tw!ydJ`zsLQa^E61>AiL!kj@o->D1CMEVkf?Eox z+5wW-%dweEIX~~l*)%AuI|%}+*L5DwG(;~k7*W2Y9PI9bt=*d^uD?U&$=otJWeEO& z4IOk;zGWXu8Pi)^g*xD}038m*)zi~6_~R?m_&|V|$!}&B4BS;LoM@ITn8lmIZceB` zcr9#Me*<;0vCK6vFnKeouO;swSk%+o``vzkLoN1m`d%70yD=4x-EsHZ*wt{Z)SAdh z(YoSZX`F;Db+nEp4_adh0Hn+Vz4GvD3L3WsASQIKM8je#!$ngLRv8AnGGCUjfL z0Cokq6=Me?Fc%V~g9UGKaQsvyB(Wl~(H8phRR=*pp|jdbQoikGIn|No+Oj$qFa zOEKp%s@rD}F>@E747#8eg36qpo%ObSiJPgIh*mX(p=Y8zwsQVOS(E$3Q$fg(U<$6W zAN$NQUDD*h3&dhXGVp`}{^$=>Wsxv8&O1GKGKYzd*jKdo<`b-7-*SK%Oco9-_(x5t zqj!&e)_Ph(ki6fXjaelY>qA8^?>kdd)6-^(r@Y@mop7~A$(ON>h9Q`Xz`OPDAJ{K) zQ!TF`)3-xtI`V;9@LWhBiajuI*x3GNX0g3Y$lt#&hZ|u+fvoK{47O}HB0>2hwWA7g z16K?%+?d$dxJgLZe_~@}X=Dc9-{S@U0>5SnA`$F1aV+!{EihE+k3$dBEhFX53CNu0 z2eDP=%0J(7a6+AJEeXAEkB+NLwawi)!8g*%OlFm8WAo>QFRH#gb! z^%^1R36aro1qOykA0$&cf~%cdeQNo+2z6rD$``q{Djg8UM80^lfJO!qFsq9qo>;g7 zbA^^&ca9w}0{%!vB@rh5h)`Qo0|sIScf$W2F#;}apz(7{{B3QerZGbd8H$cZS1;H5HSBtNel8)>ZZ+!*A4#?k32_aB;;br# z2ttB6J@BU5z<4jVJwqrm?X}M?3grP?V<%Wp`^cz7cZuJS!*WRvhElssmRZ}%bam7T{_do?Sc~t$NI>qx4ky1A4nb}$8I_I*K^(Qtab3k02e0z!C~T??jba~adqQ~;~7%&YfxUy>w|7!fE$ zziFk3-AUgRbWLg9gQI|=P_0xU6F!=Dmn0a-$hRJ&6q67bV=UNpieUH|gyFLYQ}yIK zj+y(Sd;3g9ZWuM6Nn2Ah53iEn;T0>i%GWSr z_~@bOWx7Q+$=LTC{DU1LDwT-C~J4{yBSEn?9NOI z%C7^19+Ubtt$ww0;uOj}ek|gt4HGa}F8>B(&fQH&rOn&w@+)v<*49pT zp>jjm2S5SEAolrl6i7o+1$W7t920~wO|JdARm1fjeQ7xnPEJmP$r5dvm9F>YM@P|- zG1VsO7JzN>P<(*}V041t zz3+IP-dnw*vXYug7Dg)`mFw~snX(J8DbjLc$ieIDT;ORHOz5UIaDdyst=sHCv-`TF zWU#Sr+Fj3&EhZ#1&-<48G9^`kmt-mv4lnY)>b+U3=zXuI5CX-@a}P6jBvfzS1pnOn zg8M@nf%h%Vs1n?d{q?T7r{~u=88u-l4OTG9_zz=gy`-Uww+GfKo5-XjSjLn-^E&PT zLGs4deiizPQ4c6!Ulpv5!=p1-){w_!+c+E$_V)HFz|h%}gbdE8$YMZ3rYzpu6J2J5 zQgA=DINIDgux3Eu{neJx-P6Q&=kmA_hFx>n=eX)e0; zAorhr`l|o8dm#v+pA`pp6@tv{a_l=bFc{;5g9yn<6;9>^)W2|E&8c0TyY0#2erqqO zaWo3QK7z=U$+0nm*&5Aro33l`>Qr^uvA|>Fi2L?JS(Cm7QKT7TUnyiTHI)dwv&2-1 zPRYn5a5bGf7IWCIk@~mPhjKAZO}xLrTWqj3l9QX4X<}&^)g#H{s|R{q%$PVcPy@U9 z*tk&HzFHI_Z2yoHMR_kXiM``xrUYn*4%$Ea=N?qVHQLVeD90c`p$AYCxlP85h3)UB zPJf+_$N}D-haOn{E4H?vg0F|h$ARJRjFUCwMNzCpEPw^r7byihvt+VF7Q3N=H#9_xMt!9~`@=tPkFjj~LK)4CVYADnAvp!%*n-xI%=+N~E zgs}keA1Q~{A)8yQvqN$@nIjRD0-rk2^pl`f32cQQlhOk_XrY*nV>5}bb2)k5hOwAW zbOB^Hfg8l++kva;=gErEXPCXv*d16qMis$#Usv?;bJGAFNgM6q_ekwbk{LG}k7NLs z;R53t@lxU8l}C;0jZX(Hh0`J$88KA_Xgl_owhK$6DhGMRVugqUUa*2&E_T9|)7jgT zVU!phDU1xNZ8oWpn=U@rclnug%_BNG^}TJ-Iq48_&%v}3ou#E^2OM+0ZaCQ97Stu( z2xzyGpS-THegWzAv2qfN1^aVSR#w(}vi(35WvYv0j=A?5>01^aI>EqKUm^n? zjAfF95k7wwihV>Fb$uJcB#daVCWg%kC6ow>p&3&P3sze4!o|e2zVkz4Um}0V+Syfj z?AxbRVvo8%(v~19+SpC9MBl_j-m3A4;>TX*`z}CS!~!`FOWQ2|4I~)YOtk{#99?*C zy=Z#!*RZi9z8W@*T8b9P@)Q3&}-IcSynA2|GwHtnkfIoYB@uOh1D>Y zYrAv-xpc)ix_Fj$_NNqFbd1&u+-FTt^r(SkV2I_-leZ!eg4m)8Cj;rv6djWX6C+3; zG_V3KYWpXy^>jM4*4z3nG5Y>h#=9+IqN2EklL&d}Tz%s19q*aILs&;iU;j^U9Ql3` zUI}xLrM}5$T{VCa0En z_FEPgk0A1X;S)LQ*~MuWQuX_^wzl?yL@$^2sF*|wYTT{puu$p3^Mh0+;x?6_^6Tf#kdt<%XGn=E z+AhoA?C!5N3kMOJg@Dfk9UFEwJv3BSx+a_Fpnd`K0O#eWv*1CT`5X+4Rh6u=>}&%p z%Y+|4a^Ub+FJCGfu;YDx-X;_r=oc9@G%^COu6C^1k=8W&x#%{XW31}u0iJ+1+v}SE zXW)P|H|rb8e>sMYGsvj%gy0~Ia97?Y8TkjY?&r(tK`&JntLr~#P4qy#%*`_(fE3m) zTi6MSMcYcmnsNrhhu zTD1KN39UGg(jDHy!8BJ%Y0(9SY&NT+EEi@}OdJ7_z}8o5@v>ed;vnP$&*2j!(G;RL ze_Rd)jzI8{N@7w{OsQZ0OZQ7RS|w3%Z@>7>yTT5N1Rh)4UxZGbN3VkWx2fn@)V}a7 zu?xwQ`PCt@YpZG5=746n2Ur@IB;@3rMV9Y$zb@j2tkQdK^**-j<-Z$>@v{%X%6s*b ziUVI_Gplot3x^V?;M4)xg5si>-wjR@?Gtn(WzMDOFr=HCubgPY!^Oob!`6DoZTEjI zC{g7mRzLc*D}*JB2Wj%k*u>yBWE@|!=T8;?@u8rj zKn8gnn}PGG$hK&8Y?YOj>$yo*F+iW013T(|Vw_47Ls3)|6=|I3&(iQl43a-@4%qEX z6__p0kn?aE&jc!gi@jns#o6HGn$2J`u6Sf(%*0KyMgY;Ue}8`41staeQNik0SBf!ubYemBh|q_Zl9I}tu6g~;hM-{_zbh;*R&c!a{ozXhRCI-?`SV9Z ze|ke@_;par>X9QGSaV{QfK3wM_h>a=U9D){sivg*t1hBUY(`a6TU$s)5a>n$fdnHU zAb^g*7S~Gw)-P_QS^x(*b9v~KbE;zUls@@eAeUqv^h0W4Pv>4H5&~{&s>B}G+_Of! z5XCUE=62t_w7a3SnF33*_65&3j;Zu>eNV?~&Iz8A&^Ahnp5%5ZgRMlCip$GYa{EM< zDO@!f^>706Kzl{`dO3Lu%jz?)j~;^Z!t=$AjUJ>dC`YM%3&a3o_$qGCJT-!mLFHZ) zzx1u}J1!E41;M06ve%TBHb|Pcy)_*S1Cavn$jK|iKr{HtWw&kBk@FRnpC66MbY?S7y@v*plGgm-CkT)E?)g)!!$v~ zW@~Rdy5G7a{%nwsewNKL@f^Mq^>1%K&Ka0C2>;RiK1|SGuLI#~QmVzO&osNTGBF|u zG9UsB))N*%QqmQ7PEgF{w|n!Nq?Bm{w_Wos9rPZgwW8v{|2$hz{EcuiE2Fitw6~&{ zi$w-Fy%~AbRNEZECQVi>%(a?XLZlFRVJa|R?P_#26xh>qp?YRID1AWEmO?y_JD&zj z$d3X;WKE6gdb&)W%ok9us@FRN&DnO&S6EB~L5>3s)|AHm{7Ew|G5%h{)Ksy;Cc|I- zIntQ02UzeLTl-0ijc(+JBNK3j;ph_tIQH)GyU9x;dTDV&w3;74r(5dO@ZL6XKWmhos%<5_|wwE!y@3ZNM(Nv zQ~B&#)ifX1U}_F!@0F23QBxfR8TN>Le*Ocd+cd#$Qio{KJJkMkfa4cgbGS; z>4MQ+bGMySo|YI4&MzgEG&V8OAP50(jvv(ds~s|E#G`fYkG-zOOT7$#-}8YU4|N(Z zzl`S}h9HYFH_t&CbacuND_vdrZ5=cAF$#R;Ci?_3O!ofHpFCCW&jS{F zYr+8IO>Tfogcp5A-$-nIe|J!E+lH~eE9WiRqeb+M3u`EW+}kOLEH@j>Cd#LSb7{eB zig8e0;z>4-P?v*#ysIF@mg(?2$l%xmX%HoPno3Gazy5zt15kei^y4_ohZD&8!5-q5 zNMf9;fy;qtw1~ts@Rz~R6|145^+fX~3d&>~-#Mq*i1l$b9hHmsiZ=6Q61L1xPegZw zhLhM*50S@(O5HsbmWp5h0?Lm%gOq#D%d)l)B)n)>eQYf{tz7RvJKhu{in?I4_`nlu zim=}=0}K=gzt0nkI8LO09iHeUsG?|f#H6!Jt+0$Kv~`$+Eq}WUzjyKrCS;%BV;zQo z77e7~0ODTWT#Jw}m9A{x$PCdoX2~7Tpo5W_xoX;uWoIuO1GZe66)2mL7`hkZvl8yuVG z+U4Wd*jEA@y%rxva2~JA+&m^ykoFT306czPZmlRIAC>goI5qeQ+kKRCc*#l#%<1Cb^Zzl54;)necG!~X zeX%b|_ezNnam2k2UryY_M>ZQQrTbp$eYbObv&cuGb2C#-bB7BH$;D|cqYOO$qxGQ^ zwCyBU%V^$YAd;{s?tS+k0{~&feb>tJLv6_3ToV^4;M%5)vtFR(tG_NhUPD0Zg6?d+ zAk2<`=!|bBv;KXVxm7S_WMLruaoti{|}q5OS9*VQ*U@ElIK^@&e>JQf@c9W z(UgUG%v@vf+N`saN0SA}=`h;>JRzgkyl*uo0}iS^`l+3v5cZ!kGBS!F5Cie4Fy@b3 zKq;O1LXZ`pd$n$7W?*T=caV-C{>g3Y?J?l4dmgRRZCrD+f3|7S-2G8rjsQ9NjXr zj!*O61TZ4Kr~3QLxlL&~`eoyE2Hr)s~!R>s=skdgMStdUuW}x+`1oc zG(m*Y#yjNKe z`)~$tZJjH{LP$2mWn`)?dT7RB*K_uRYSdrpn{O+MG;3vn1A+RJbprYu+kb&I2xV`f zXfRmNBna#RUbD(I@TW(KgkbmXPFHGUgz>1bTFs>+q7o*(|MiHrL}h$dRYKURX}<5a zl(u%l^%mZA>ZQNjMPW-!?ZG?>>fDLl1)(oNA@TPM3)B&QVQORXMYR*rT+aJ3?IAe$ z1#GxbtY6cTfQ_OAQ^vz9+Qo_T!O_tX823JNB<4J{PrGT{1&ujATNB+uQ-cIp*`pmB zynEzA&d4Y@PW`nd5aEIkj*Q^+#P{=3OD0jxGqLoJH*MV`Bk1Z)9yv3_>NOU>K=Mg} zK4+O>uhBX&LPHG$C#&kX95Ysdx{fy8g~vrhg|XVNsEf(&4ETiL-SpYqYDA&?u8vhb zppMR&e%J@QCD%0#|EYJ|l+`@IKm5MnYW$s6qpEA-qteq0Dhq;2Hjs9u2Ua0cCk1Wz zIF+RygiU01E(rNWB0t_kNesPxdo>db`f=ZR)(?qjXJ8qZV1cNnP6Orhvb;Q?jRK`r z2uwpbL6YLndZ8?TppU;|70(7~DUQn2e_lV+7ZXdWKmc?eV zi1j>Zzwi0;=Zvqm8Y+XwFL7~yY~NU!u;zRRiI1`(l-rlTL8p*o6aM2m2-JB!^!IyQ zpma4@yQ7rkr#e-C=oUy^_F0+m$#nk0U-jSx<36;%@nStZn5X4#I&8yOsrK+2{N(+0 ztKf35+;K@ub}c!BLF#^YI{HU^sffHVDW!$i;}D1D+xw|$0;n^>eCoVMKL7wthw#&X z{Suk!o+K1X6yV<7FNF_z6WE+!rd1BS5?co}Bv&&}4H9^6*4K z?My(>$~c2@TK>0oJ2kcV8HN3W10`?&$iT66?-{ZPFu^UP5*6NF;@?x-<=3axE*Q01 zub%uA8y=2I=VUvO*-qp8Zc5o=9_~EnQ7P%{gzz&5GWz0R&hd`_LOh_|h#U8M%b0RS9Rj(>Bq{<`GB6#A%5(R)x*7FLi`r z0Ka`)T^&2=fgG5hx6E2wM}qYQGD$s6!PnRRdw4)3%9lSQ7BDt2@DMY|c1~K~kyV#R z{744a*ItmQ@h&gLss0_JNZ-?Ndl21bj;#tD$q2TdeN&(|7@U;pL%iSBGR-gK~B5qRvTx@VgV zXsJ8Q*>-Gv(OJabh%Fxunxt>z=zCAAJm04Y5|DrM@t*AOTbXCnD*&xMAFu3>#Q=K# zdSJfgW5$i03#)~T6doQPGjee1{KP}wBmybZGjROEqE3Jqtf^h~R|#MP`9*w!FD;c1K}K?EVz|(|>KMy}up?tvxjz(OLie zc?QSE8&()b$IvGnW}pdUAx0&-W%5b;`UO(X01a}f^LA!FBy!;kbxfCsJhd>;IYANL zBGss!bo(|GNF_EMzmI3(0~4J1T6-_v^suuiE@)>rv~a;6xgTmWG%S&Qi;9LS@n?5h zk=Zz$RRHFQ`537q?%;5Ehwg#;e>l1dhN`+Ps30LBARUq)NQbm^cXxB?4(XQe2I;yq zNOwzzbV!4AcS-uT?;qTA&)I9wnl)?Ye8M4Xk_r4qG+4s$B=ZkbE5BGf`aiX`vB^cp z<%5EdYS{P5nWsa4J#5EN!GjP!G6R^Y(=e@X07&?{I!qV~;R6_=$-=whf2bv6J}7xx z51JR%&^<#Io3X!LosW8!)>bG>Uhs&{+x<>&VDxouGF|$lRh^T9LWBKEHYWT(c*EWt zDo${I2116O3;Ii0Xz}>u#6()7@LQJwo-XBCh*4BAPNqp9zXbE9ZgZSfrc3I_5LFcgPu! zFDRwxF2XE-zmLjNCfwZIm20P;{PLId4U1g!`7o7bvdTI!B(<0Qv2GX6Kvt!TgkVrk z8ta;d?DB7&8}Yd#w#KqRd1bPz#@yl9Hxp#emFk+;#?>rYidXMR3tqpaED8q93OcFa zE7-Rkf!%_n*EkW5-FU_=#ya1&i`JIQtP7ua)8<^R@o}ZxDsL;DU!HI()ON?$J~{2C zsca@B0=MFChkdLC43QJNZ8Us_A(QShv4moff7|M}6FWCF z4R^B;6m=fFD$=N*-YnY{Gt2Mb_1(J+NyRTcSab2uGzD2udZSdoG5V-^+LWLFlAjp= z2J=(#;@NF|xQAH8|630FMqdp}O;Y?fc}tp7EX)sbWF0VNw>+Vj>b-X?B%hDp)h^p! zcIX3(E~5i%1k<(tHGa1{9Ea|^6AnVtg5sH}M5?Kz;Y1ofcD0C@r87KWt{5rceoxuE zBtYcb;X#Ms`JwgmXKH@R-~rjxk}kB(m0q4}fVYw}verhdX`JjM<`vfxQF@s>e=Idv zgAEft9-wd0!FG;>JisAGL}`?K|Nb3^V@JH>G8D zMxWC}&UwzIqrw~%9ULa(ki%a=Sssu#kUTIQ#~Mbav`HGEjP9{LrfK;SGLVZB$1Rfv zC%La|0jP8u%1Zuj%F_2}M*a(PJJ8fW?$TqUvkb3pD{iAK-q~YU_qq(4(cykK#tXS8 zF_M!7GGQa%<5@ob2*<`ABS)tfQ8P)1Q2YgsQNtD^I;=r{wm9>-kw1)x2goanXf1+w zhe^%5Wb0)fRfK{3%kV`v&P_-Cg7`@b^P{qZWLUy&Okb_rdS1e68;Z5TE)N!KZm(TR zc9^>u7m;#MVs;oTI_yRH#s*}X4#V#(T8^HmsT~@?VYehoc)rve-)OrouQoRqisF{C z2~8#0h>k}mQAq5+$zt|x=S559MELUV0d`6EW=cYO`Iza?(+ljava`(DfRdEdc38yp zDuOOW*dv{~_s$%Q>%uFuwlMHn`jPUXaOhy|to@ZAb*HEPB!UE9+`L0eN!U=WzCBNR zUc1>CXVI`$NiZH6{%7!Hk}w&(e_!qhA4T+TU|29VazuojR4{#>4$DL8D?ARmDOecT zla&*cO|tv^-6@C{vd3_!dP6!kk`b$41*A4pFvVte7-8%hhlQhKbJ{PGuW#{lEp&Y4 zw+_}eGB&cfERbU3pp(JzP-5WSsnnD|v9$#cgEkb(m6}A)=t!)TN$GVt_3%hK&Q$pm zEMuE&`ki=*pbo$_$PMq$nDZcfQ)%XU3oDV1g8x=m6sag4|_ ztYAJq0_ojN!}L$9NFk`UneMck^YQ!_K^Me^cBlnYD3A|~3z3_=L-U-*Wq_ezxVEv}Wg zrTYX{tf?dF9Jk=BJZLtZat}F?BftqGt~#+@?o2t->>#W*(mio)V<AA(Qg?BX&F=9}fCk2EW%xhz{B z6%7^R5^jO{VG1^hZ*INZm=p7-Ps%tt;SQq5capO+R6(NUxc#DEtyg}nF>;PSy>skf zhtxpJq&?|`Wk2g`rd(?}w``k>AUaEbsZ~~Te0z{OCrOnu zpcUtR7vRM_#@gXx##}LanNgy(nm>Sv4`PgJJRdr{utVBuGh!}B+r(mCS) zhGje)fL>( z7Kt!E#g^pcq+7>CkAqH9d5LPRIo3fIi}Xjw5Aby>TiV-$eC)1U_j@{YEHBR$Yf!PR zv7DnJ@NfQAY+!m+hHD+7M!VEwM8RZ_pr)ab(c=Vb*d5gGl7aian-#nYN_O_S>xUy5 z8ZItPDKBct_-|PZ<*V=DJP0|cEfKEJ`b2|{LV`nRuHIbL6H0bNZROB)wML|KCeQwc7Cn(J`+n?8YdHoVNlD}Vidn_?^agmge zR$9Qsl~b->ky}! zFJo=s@(>6fg^Sg$q-=Aq5*J`=OOF^Y1i7q<6%sfAxB@{Zme_TRzer zDVUm4g5)4+G<@uA7Ff8uSM6Bbo`I0mc*NZyp!ZN1e@kmaHPx45*W{NLP{}g-_1~TI z*+?x24$VqR8O24aUOYdvd_;+*UT+C8v{2+9gN zD3tWn#Tn$}T0y~ssM^qjeakD^*{_JOXkv99E@UhCIc~bH9CCaQ79)}h0!E#Z+y39n zD!=voiWkm#ZRQph;VC5P)nvWTd8e3aS{)qhKrHfIePjr~>e%BeF3T6Ux%qjIG4+7v z$T9W0)d^6j#Al|bWq+pl720&r4oorq92rUC+G8FSWC`8uiL&SmfXXlY9FexRTWwA5 z%caF}vU2sg4&eXcg#ac0NM?>_*6e@Y57_3(d4sA1Vk_EI<}JJk8F}?-b-$AGGG%4v zeyax0FI2QRn0)=3PjYoiuB2RU=!#^jWW4wLQi(%2#PXHBnFyoNVE=Y@F*XcFQvP*> z@OlRX=XWbycbjUcyO|RgwfTsu>GNU>YwYEvt2AVS@D96@JEE>Sg}j%U;_YZMA%_k6 z(K-?(AuTi-=i=|!rfu=d8q4SpmSWS9C^lUqmfE?*zd{0<3x*Kn*Aw)SOk*-irZq6& zlVW|vR0l%O^bX+1X#_eYJ!??0$D$v44QrUImuhxi(7>p~R-o<*gaSqDb?uAL9U*;V^BUwxrI zg!4f|WnmOuroYnin{3B+8i>x1i;ELM68GKAR8v)##<+9{VDp8N+lJ-DOd78ao{r^n z;T43I?Q=$8_4f_Z9iUZCq_dY>PGs|4dV0`7Vf7<44M7QsdUJSmv}5*8k>&LET666! zVu>O6!U;N!vp)5vC%pf!>d-?&Ob0I$FrZRmVvn+FnPSV)=1gKCkPze&tjCK%lHsxU z*v1T^eIRFqc_^ZJO#%dJH6_0x#}+jU5+qSp3`#L_s>4ctN13|WDZI}X;lAH77$7W6 zJ!&>HH;{|rgEp&_M|e*Xuj4#!Sxk2r6AVO-mi-YswbRsP7HI70dO^4XoZwu!ReM~X zG_7B&FKS?(9`Efvv=#($DG6+`gIjkr;OEX4dqwP9lwSA6TGQFd$=}I}*@1Z;G)5PF z|38%ozUAT-on-5rWPH0zo;HCfE!MEbob_PJ|3^T(9{WyI-BJ$degZHrTtxChLo$< z#yhtd`p_~HgV2W9LWB?gBi6AodC#K?PMZHr9M%pI$U`Qpf)~5j!lN`mAz?x7845XC zQ^vLxQq=AA=AN5X&7hT)U5<96MXdUF!%>~vJpkcNNqf8S@BnIGLg^rh!iFpR=C+`s zqb#CW1-$x|L>`O58z%=N?JJ#cF>h75Ozu#?Yie3L@N80or zv%md8-O!Mdi59=gh`2Bq|3Vd5$0F-oqOx!uHc~@!za?SKG}ayT z3#dXp!5K98eN@#|6wCh*28EI*yHAi|-E2Ees7%)R!%j$c7-7GxEgth4@&D-FJaoM; zNFf&`cc%^s2_jNup5|Eko#$#e_5Q@1dIh?f`I(Slf3fYXNI>sJ#q+@b(NaM_`;-F# zMpO_xdD48$K(L$ap~K2iAn)tDcC4#QS9h@8*W6egshArK!879uki#rUwXjr5d_eL_ z&B|K+$+ePp!Lh5Tg*PDLrlqwAf5ARcVK^7GBA4U$(z4>C0r^o5kZ8|l9BTvwZvfXn zroRHA^nO?3)TDKw$rYQWLU{4mRg|wDKVRM-J ztmdWybQCDB&hs~KQcrbx}(2d5*5w*3y7;T+k*<#zw2b#KXfKW}|q*dv33;0qCIWQ_ELk3TAl*M=~XdOdT4 zcE$Ca#!pMYsTtLjQ#H*SenmhN1@$9!;|&L_+v#X=gY`V^&y)2T5hk#=Q=`MVLqIp-hVmNH4mVzUQ+O((_otv}a7QD5`RMy{?*~wWpl9|0T#u5K z8P$$#Mp3R?AZ&ozK(k}(^3-lt*`bvPN6qIY9&gw2YT|BIw0^CJY?f_xdK!jg^0Fzy z0Ip&}5FzhX*%FfdwERJ}(X9$Hh=how{99|kD4f+?X|_a3dK@)->*l#v1ef(#jvonENxmo zDGuBCq=q-r`rQt%e{PPQ&2rjgWN0W(pTg`F@+*2nQre!xX<4#GaCL`n0prtrU|c5F zsX@m3gXet3;3eVx%(P2WVd!lm$B15yt{V^Dy3VpT4K=f72uJDrS5u5VLeJ07vJFDN z3`p^WEcjFxT(z~+u^JsZ9SqpnIFvUk^+dFN@CF$f=?mC%o;7f~rATQ7XKds-J=I2{#&B_c6QDBEDGQh(7!0BdOz zT|Ywi-=Q?ft*P+`epgPkw6w~bSz(UMrlWZ2u06VN9Ww~Dd)mEdAHMORpSRGK)0*mY z&D~?^=qwj3?8q2_U~cp<<%GnY_E6l;(5i)f8Cglv&{f9272(otcgGjn<3l*biEEyR zQmv1iT}8G>zi8~&6RT9#_YRNu3(`L3(05>|eckie8|T$-bB(@!e;v|MuqcM;vkSuj z-X%gccOX#%PzI`CgP~u|H9Z*)CVrwJu}d)IVRCj>ci&;bm+BW`JWf{m+b3=um$Wpk zY1kec59AudfY;Fu%m&(|pHNsY-GNeU9|Xmp+$7D|3v)rC`}L7bMYG@O;W`L;E5_s> z9ERW1F0_gb!^cuZG`Q6#GOQHIs?NT_pkss)~ASO2!_k4J{ zM?@iPGs8TS4E9YizPPNvnX%wLpVWMNF zXm=NffSopmN;Z|WYDRIU=@mxAQH@?7a}-k`(nBoYE!SMcy=YbLf=<1&f|S%fLtg;7 z@%w)VFyg0OlWAC0Vh~Fy9Wh=YXS?M)1A)Te0|dV-yKv?T%7TZ z>2Jh8A5byE|GsgsPzjtkb1-R434|!U4&`MnDr{Rubvw1{II(D3ZD~285%4Rv?D!m4 zJs)NC5p>C-!62F{9b9p=QJW5wl?;(-gy(rLS!5`jHR{=@?p?^QOXQ;v|`@DyX71uWCO}5(BXkYAA z)e+L^{b);h5**QAxK?fMUQSh1Jd>I~kk}QV_&L+b#s;L&N)>S+qVjo@eaN5WbyOl9 zAkc$cKeLLyjY}ec>aU^xb4_Y-&86Amr9;)mt@*jVOQyz*wnn0|_b(>5o_tS?>;5 z7OVcF-#f^$7Ikk&!NW-y-Cyop13Nr)t0-FP0dky(Z)$~w$+5JD#J$wrO$>@=3lzu4 zg_|`uh4pgq(H-BO^#E>S`f1QE|fV7R)HUsc=|h?&sYImS<~xpXSSv9LYrM zovt&vO(Me-Cwu+@yN*Mz7?&iqR+q2LAtV0c-PvZtH);cLN!#SyDM4(BfR9)bt_|WHlqt|QXE`! zWG@~qdmZMLjeB%lT!};~xlguus8poxuV~8N3^1eQM}nBX=}0F2kC+kVBUhid`aGN; zC{q*_&nG=UcK)4?yzUWr5jtUD5b@L{<-=ZHD$yI|+kim6M9tHUdn$Xn(^5V{i=kMU zVmWRnYjTvPru^(@bE>M7MRVmePMdEch)ndww05-p8OX~mBE8C-C+sL&u*eR7&2tVt zy4)TwPE%c|Gm5?PLiPF0JA=~u2Z>Idz8Lss`Gd;_8{4=B>swONdYE6?0a}Z#76?dG4a(0&V5J zs=U}WCd~fnYFlYZv*vIMo)xW(^gc->juga^IY_%=x?HhJbDk!9#!;ncaz8&tdAAVA z8mTMC7`0oSOKF1I4-c&xBLV1s%dGBmHO5(QIRlUMRw;oB{%9jMoQ{W;s?R$Dsh8+$ zFQnqVOzc+wCWVlg7ce};A|#Xmks6;kV;6srBkM;jj0x|U8)y-MTj0rkCt1T}fr2+# z)wYMoZRw!?92!Y&5r$;K4e;#L@{8zE=Yy4GjN5#sUuVJ*|G%4 z_-kDjS=ZLqE&FK`Vwz7QrHO7OAxDG$%=Ni zP8#72xCyP*0*o)Oo~uDgMNWS;`javLy2d!U-tE6bK!aSI*%f7FaG(xfJVa8xB1c6i zBDJ=lzRyDp9_L`m=`P2z?dctjazf}M^}c<2@}^~A(qN?W+;t%`-{smBxI(xusVuZ&+M7F$0djYljEglh zypV3Lw=5R|e?4T1tR{KX!)w|lQ1*U~{G`iPofQjF`p2y+7&e8r^nK;kYd5-X+Ntl^10thiu+bn(9x@;rY()O;!Kf@i;eF~ z-(mPYbWbg@PAc0PG4fBDhGvl9&V{P>TMR{y3@o+%l>CB zLUHoo8~L6|tirFbkdn>sg+{+gtRC?l{p=X=5}?m^viSy8dPp(cO%W4z?rXac(Pl4m z)D#;V+a&wEE<+?~$#!PWbCely*ZywNmsdi=RE%KXDYct$h5ZmB*yL6(b4J%ethwa< znx#Z!)ZeHB(U!3$xD`1K3N^JY$Ce~j~Q zg({om$WSgT(<6;o{{XG^mEo+a>QH7lWt$tJClSSDPcplG|EV2r@Ee4q+n(t+WTO=5 zkX#Ma;OekP1Ys>(AHJQF6{%!3wd^x|Qt>>C$&0$8sfMn0&4>_2jix68YJT(ERJL@t z6_zCS&(USpN2Tqx(&NSA$Vxh$&)+sxlxaM!_bFZ&d$N~xbP7|TYez6l{rgJ!CZ<;( zj+~^#(k;fe8mg<~Ni-W%UfT5s#Ih8Nqk(T{-r!V^Wc2&(dX}!j1~EHt zmpf+_hqzAN)^yE-w}AsIdl<&`arDVKa3}YX8cx=9Crcd}u4juPCq|v9%+%Y>dfwTy zNum!|s2WPmbF?Bk?At4g8Jzgp5m&pTxy=g}OFFvZ>@KKKg@L!_)ijOXZQ>#PgH%+v zM|XN18KIt!CK!7p*f~0Vta-G@B+mh>K`;zJIhZ1x%^A65nK#iKI)xka+YoOcV|IRi z3M>eU^0*!=Zf5Pjbt|I8#!ewiErL%B8!|b3>u;9@`Mr#24y{9Wpx~k@99RrdTeo?7-k(i&3n;&g+hG zWSBbwq=|G8+ePtr^Qd8Bg9Fsw4-C1541DkVeG$;%zPq-y;qYWtB-0WYemwGNwzmIMLHBDU|2Y5lbyEHvZT8L@ak4$7nWkT@ONJnph*?FM^$@9nH zkTSW7^Ku>g6F0_|?#rHzyJF2T8;sh?F7KB_6o%oKL5h<~YiD^&=S%QV-V&zpdm<6nao>?ng&Q*Ef|&dj9uo zZvO0Vi_#u~5#58|mYP6e%bWHKaw1@Apq^w9kM1pW? z{#_8U{es}&Xodt?sFY5CQ*fMl*jCvC1`z@6cznj~&Y9|THay1%vx#gG3=9nYrVIz~ zF_m}v>lM(2>1D}Uu2l3 z2W3(Se)(!t4KCEMQq6D|XUT>;)+D88l6Ey&E5!e?@{UIvGsb|?*73@?ehU1t|Hb3W zDJ0O6c{7WO9%hsQ{afKYG!>bBqNlSf-$%jIAk3?~N|uQ0{L z+tZnt7{-Twh4Dh%dzxSw{UwHv}4?6NbGE%Y5eZ=X=VvnssI^T;ygi| z9u5@^wKN#NDEagKhs+cMKSg^0#>3O|l6`5piXXORMXZZU!=}@pnS~nTLWa$;Bo?2U z9)~vwz07OBfe$>is#BY4K8bc0ou$YYi*?RePHwDIMK1=wraE`%50SjNb;B4AzYJcy z#F4Fa*PetEYhkguqLER;!PU71lpM!*V`HP}jH#D-7rYui=fBT?bH_3`>2A*7q9W(h z*|kMOCHiTsuLQrOFcss9t|zNzb0MP7L&=G*^Jf%_xLISqA;Y;h1B#FBGhc8(0F1@_ znvsXQfWX$Ek=#lm@ogoem8E4edA7KjJW|*Z>N!=-_ox>eEKNfQo)yx z=vG&hqJ2L7{qh zNs)TboB5G;%>JPc znNitTU|)_)Z4m`Q2{qRl_*k=S`!P)D@eQ`ErAI-v#lL81CIc2x$mABB5)f6H<$B*P z`t6K=p5`mLJnT+Q=Z3&Mud~k+bC_omlaljaKkAf}6^ME)XFDtqr?WaCWjP)4)`o;g zLrxXDBG31frsC_Z4c<1mIlG&_>>ss+@w>RXj5wIMWA+4MkQ*mH9eQwQjeRm95RmA` zZ6_eiC=q%hF8lK<^1iw=Prs8OOa8)3o`%$_|0jHcxe|7Is{5~=miqmOpC&tIYL?Qu zO`8Pf1yob@!xB@sPkO3KS{2jFCIUmnT(+mWU?B4D!3pua={B`iX(gKVpPaRfj_Qqx~xA|=kw&W^O_dlyMbi9Hd< zxjWw5&)o4R#HxoG;kN^qFYU4DD%k|V61mNMR2yw zj?`7$%q0~uo-kihE)$|{=YFQjL94;jn+D?oi4rRZU!#1AgN1`cOE;yrCk%}Fr2SCx zzH<>?4qj+fx%P0emeqT4D;7CPZb7bR(aKHcX;pFpQpCyc@u9WAT!G8-Ec5@yw_sH8Q7qvSOSS?@$i?;j@%~LA$N-O`l_#$ zFvM9Sfu@2e0?6D{9_G7VXs&WfzaJ8r)lD(IoJ&Z9)Pn#_Pxr&|uBo*wD_dLl0?>*~ z2b(hi-1Ol?pc~2$(3DYo$rAL7ODH*NakDqod$PY-(h%zG`_ru1dhrETG3%F z$$9U4kS zBa4#AF5U5IUSE(>5+7Aodghp_mKufMz*>X^d<*gN($YaH4d%fuVHNbkcu>-DiV( zNpl~&KFSIR{9N3uXr+bm`LX?Vj7$~}(48D=f8_ms7~w$4y}*+@k-DuoBKSOANK7N^ z6gDR(#S3+{Y*k^sw~|#e5Be!%gzIw0zspG*B0;x^LZp7hhn5B6KRs0{2o=RAiQu0FQ4L_D@R1ZPR#u)lcf1 z89(?HvNAIx2Q(%Y(*;r>9K2OPAde~crcRGu|AP2m^`l4u+ld+3o`*C*VU>h^thTXt zrr?B=efZWzy-lYpr^0%E`Y4vKtX>Z5k?Y`3n)AwNsgaF_%EhfiuCk+hs+y#FwJpf#U6y~6wAe%(7a7CYaS>FJ*2Ec~1x&@IhZa}nMC44V2p zybs$=hry%tP9o^LZdlNo8EX_;=x!;Z5{3bU_c4MLz>d$2_UrQCuL!+v9cV84LWIX@ zilpazHHJliQ}lhTv$?P5Z_Tn@{y?j?sV#ZE$4CLu=acQNLdor-Jd8@xYC#H_u*kfN zty*ec`Xs9(f}zEg>7!Nl+JU`F#+2_%yrw-|53B55xplbN)cLTVKS$%bBOvoq&8Q?6 zaJ8GQQWb6-^>5zuZj9DZ|9*7_8I!5$Xiz(-Tb@}JaU33=RaBAL5qnGu;csecQUVRo zwPIh<_r>V}3vO(T?mRA=ct|Q_qJtCZRDSUOQ}e0L zsvNd!jfu3h)&+Z2NChdI_~y#zuaVW=si9vLt)1$oy?Pi1f|)a-dyd~NwXrJfTB-W~ zlEAVQ+SxmpDgJ3hZH8nlLBg(tu@PBTNe1L^u|OxJxSOYreX?_l5(4r)^|}Beu$*2O z3ArE!V39v~%#4r%rwTW`o7b)+<}HO2z|&X^I6@^4QT=skQJXKh;StQ4+-~T-Q7EVF zSMm=NerpL6?T33lznzOh>;)l=1dFr|Ej}qJ60>1OcV)0~;hZ8l&=Y-Zi zm;H;Hd#oRcpH>K@^JkPT`qOVNKbN`>DE9M`GJlOp=YK{JsJmVw{?O49ujc>AM8=++ z4!Rc$jl9}Pa|Omj;^Psj!q(z-$Myy)%R zChUP$VjdG!X$?Ds1@Rw%A}8QIn2BKG1O~#GO5XN4E+!>p3ch~v{e(r}^iH^tB_?Clx~7Vgyq%Hk_jm@Y z%IMlBFeS=1*Wri>gk8#h7YW~fS26t{ zAAiFvaHA&|Qr2$P*ShvEa3m{^Ndjpf6Yw4l!@uK~G ziy2D%GqfeqA)ZSVR`2`G$ccjqzn40;AhpPFh$~N)TYoa!MQU0I60AMn4c5s@y_A#= zx~~zG;vk?-9^9vCSJ#$3(EbVGd}u!PM=z8Q!6TC_N+uUgionrxlsyco_uP^>@CPD- zNE3hx^xZ83-f5ZW3amx4x%Yz;&@D6hy!?+*g;BHr)dq(SuXA3p;~}DZ44Ot4wfFF; zsw&I-Bkih$?W1El70Ew$Pi$Y&SGq~)WGuljbmUFkmQd5KwuPo29wj$ac)UJcZb~5s z0!Gko6-CDFAC-nhtK$In@^y~y0}CHWCj6qW-lSh!AE|U#n~<>U6G32tSM`qvQQ;zL zY%wC`AKVPa#zOfGjPcU0tj06Br_Zm&>Q#zQhKw(8OQ=hE0YPO79J8FCZz2mM;xHy| zv{LlMohEp%5Rxs_l)=8 z(@KlR)ZCCHz{-LBiaBbwEet$9@~o*DigB7ZLiE-3(d^~yNz)0KmP-W}Y;LzG|40Ej ze%g!ylLAF6Wf(Co8>qn=O@Q%#T)2Q@PKr+u{i#ufd8g5oV6_Sq32b-Ab=1uJpj>03aGxVG`A6k+PE3`zIJ*0C`f4wod7$$N%8SQZ#TJ!m8N*_V6zD`VHaH8t7z*1 zWwCBD81vy8Oq^$aRp^+!)5H$r&hIsH?9x|Oz6b&cW#G*}tcJyraY7FX8owE+fjp6W zqe~4K8yQ&wh%&phbm(#x*PZLp0O&Zv36!lzun|3l%1NIf&U{%=W(EYN&f^n)j2&3f z+gW=#D!W;THMW4eu{b?G_+*g71odDLW)iLJFZjbnkiwTvwQZsV0El#0WOBdR)6-Mo ztG7aLU?}wYR{64tf;)`1@@BA&h@JlW4bEQ5_aLKf{M?7-U3}`-MN|+W!X@APevoMD zn4G3UHT|c?n3|6=Uitl2--PiH*olEzcEEi97;qKv6vtqDQtx}kL8+Fb6h=+O?=QZ$ z{q);BwQ56ber>7fkqg0^ho4&u$9yp*r?_6PUGcsJ%vT0tNfc^orj$)MROGGXRD<+_ zeDPUgKA{-UKx0Sz$iqYA(F`=@v*FVfCB-g`1NT(si9(A$tTNxfEk==Q3BUd}bSOWqB3hoR% ztrD=F(QCtLr%>hO%gXI@mmb8e-k)XYESD5EsEw-WWp@p>Op1>1pNJ?2Deho=f7*VW z{B4MYT&LXZF1BEcg?$jm6zYD4h(=MaX(6;7G-%?kS%SsUP`#eLW*qhAH;v~D{*~oU z`i-ixI#~zKP3`fj+kPMLlD4SaLN8Ai7S<{0sK`?Fyl?5_!!kn6Uqn_YJiPq!@Y2>W zvILKaPftBmJV^YAyZYtl84VwwahJO-PW{*~uqnlDN7|1HkZvRzzqAc@hIPN74+{5rJ<(G$3wqXVt-F&E5$4@|85g8> zuF}|&gat(F1n9-sTF-JOXe{eWa=%^%Rj~LDLT_)p2?r*W&$Q9@yyLY8)X`BF`O97k;THR zSdTRMhX@^eBYYU3_tc!}Y2{Y13~~TO6Y;{a7!xOykvLPMiF=7%-Ef(T7Mm?zXL;(y zDsDRGzhtDOb7yu*&M;QR5&3+2llnC)-v(z3WBN}AqOS#?bBN{C4Tydm6ul*DWZ9u4cLi!qE5k`{x8e|0yXs znUeFezSCpc+u%v;lwRxEO&g+1`H9YTr5OGDf7j2j3uRjSuss305L@b}2IuUWY zv`NGW@}M)@lRrH|op7cb+l48~*Nf)XhuB_Wu&^LJw*Z;mu1=6fp0pB%zDMfckK>tA z4cOS%iy42_rz_!c4!8_R{GKotPu{hH#b^^dE1O^G@!!6Ed(Mu?)QE?KFxf}coIc;J zCyN!JU*s9(%hpd(rr%e8|9(sXr)PpU)2(A0H39hDFF_UFdCaL42v>J@!Pv~~7pQLs zz#_NBT6FJB}ea2Urd`5v`! zo{SfH%jnh%Oe1PG9}$0w^_#Xa$a>O|OOOxT|Ik-Jk$s&_z<5g1o=O2bMMFsz(|Agq zn{X+;dS#xH8k+cNn#SjNA@GIBiKC2PL@@|Ak5&ldMDRQ?bhD)3KdZ8;;!7wwrLFaE z-@q&-O-<6XZWiND{V@ExnIczwbzCIXk?8#!7q>fLWdPrfwCjlP?kNkbJYp}aeycD- z>gYaj{zYw*1d)XWgXIYzwvoHJs?+@sepkpA&G-T?a9JT97cB}Bv`@N?p2w_GEYGet zG;TP_t5VNQLLau^ZiaU?2|~vD`~JCZoI-Uf)OP|Dl*Qro?u(bHsmO#Z+d_?Un18PW zsqbnItRGrcJ1g#J;r>twEumr3Hi(g@rF2+7n5`HO>=ii%ELA!TyN9*nEJ}Z6aGvI@3 zgId_UZ9>bqA{Gq-)Ry3fiWk6!RWQ`n{e|oa-vA=x!D?{ZJ$Wf-O+3*&V0J{&%N&T% zQquMEaP)kCsYW`m4G;456af7h!uqeF3G9VrP{_6-(+Af7=9SuCre4{|wG-bLZ9jg?VTAI7__6j>60MwNy_%vwM@$vRFjl9u))A^-pR2q zj>;q|JRy!z_|_?XuhY11hd><8y=c7u$GG^u-3@vzdfeUMA`_3qji&N`i<*)~-}8gG z_=2g6X%I6-Us?y!*lW?_pk~dH)%L71>qF;aa*OZfh1|v2RA7_<7sRchQJh&9om$*$ z!x4dTEBEqEZkkY;5`7n}>zK%ksbZ1vaOn6FE6j$r*u49P5#D4>m$Wwk6%oTfYt*Pe zf}XHZ0*BS!q`$)NI5;_*X#af~|8~>Cm$9sXJFqew_-KuMjFyLhUetBdu;N)fRO9ci31708`W|AV7S5^VQbufRyC#Qmsyf~^Qv@fvyn3-F@ z+B-!vkdlHcP|}7MImi>aKbshHWqBqV8@E8wv8>p@&A^?loPbHqL96UZOst6*>>}di zZ~dOwEE&6De%tO_3+>HU)XDl|8z8j%ynUwm^i<%mxHv`Tw+?mB4q-t5fGe>PgSb=! zd46&5QWI1aB?#s7a{cLOd%WPd$&bctDpUv?x!}R;@4?;A+&S`ab94~H^Fphl30YHX zprHr}e>y$QUIy))E45`DL>6h{>;M0TaooVCQK4s+pc=*$@3F4m-bkno%3@MA3BHz? z3R+rey5}TLqb$=u*z(DE@g4{!&IkHTetKZzn^_2@u8N@XiGobsH|(5*TWMjsHha&P z8f<@`kdqp{{_8Yois^JeVpu&fbQJi^5GiFHrD`wwQ$oY{M=hlD-! zr9d#g{e-cv=v;((KpM&~1H9C2?C7p~Kxp0t|CP-n}AWI55!%G0&u=q^avkkR++^ zMVHb1v0C20Oo>QS0Qh*XIyQ;nDE#*tEsve)^`kuK6~@}q@_RlXom0DyC7?h`$`(ig6f5%W z?Sgn5yIOxLA_tkpQ~{m^8Wzd$qK_S8k|c*`-UKnn3;_oBWzQP1Ry(=L@AJw4h3_GV zG~Yct6YC-;ju2b^T9JF91Tn@^KMFM>wobA{dQuFb;C{m~olk-;=);FPgth4T<6GPz ze1>1+;~W4(8XXfWaB@C6W#Oc|#ODc|hH#c3lN0W82HIk4SheLRiVAh`?eX!k)nb)M z-SP>C`-!H?s&aGJubYJDfPnNFUh+%N>vWnvMh}2^} zLc-alQW0Z?nwwLf>Ftj{Brmd4%W?0Vj}F$hR0GB$WJEuAm9(0lKW5lVJr|j*D}Sl2 z2tI-*10y4|`K|$6JTXkWVJ+YMuTVRJNu!nkJ$^X(N`n>Uy6^S4^Xx((N}!L7?4R$$ zgU$5)ta703<+#g_C+<2_37dn)Q>2_rUv;yy*~T6!bM+N&M`&stU6gd45$zssb=Egn z{epU$Wbu|Sb~4Fg-bO~G@LxHyzCUyf^xT*Ol7%bDIyqsi19_}um zzTROTJM$VI8ags%@fl}wbTq*L;}d&UpJ4;KCElcz9O|yY&A=E7CB|p??&B)H!(R=Ez}) zPv)Bw*)zq;e#kL?Bk~A7bTl++b#--ej5Aj57`7BkSuK0ExU4Z=j=b5B9_YJSpyOBq zHA2*C)%e#kjrqNShPwJb?fbJp{ok{w=P7=CCrusT&7>)WdEAeN<#Xh&nOlVLme-vl zh==r@hb+r!=!$1LeUfe#O;kVq`*&KDmUi$Wap73@yp|gB4{ZnCG8~`ic(J+!I9L-m z{GaTifvPlHObT+wSf=(%s^95sgWqqTL=STS6&4cn!7X>N5SPcTtJc)O$qo@eW`AXV zQlfx%bo4zJdkjYN*XU!0H*?~~#>0wR@PvXb9@^e-QVgG&*NkS4>9Y6mFk+?9t~ zYU>c)l!QOH-L8C~iw~L%GFC-U3iV4!Pp9Rw8$T*T0Dh}qfy5^uK!-onZkT-qcwc9L ztawrTh4SIf{Gy^Aqs!viz#;shhKrM-dl($qcs;y$@{kuFgy^NR_{i9w-)z#xN=gCg zn#`NWdLvfukMg2wf+cweTid);eeWnBWQ2C`&$jiq9hY8D3D|>-41(4`K)Yh-do#}l z#1z{9yxR=O1u8(0`A~=2z|~KTvAKflH71aV=kZK8hq|rv(!ZQ0O+%T)lDFj1fWodek4TjiyM(*l_XSqSd+K=^^0SFZno&o zG;Q6!hgC`Ud|C5`w7rw6)iq|fW}U6tp4}u<41&~2dt;BxS8LjKw1q_R$(gXpi%R};9y_x z%R$_KXWo{UrFFOC;W$ULOG{E}YHD+9Yqr*CP~ogd2=PAq+0BPd0A2Ln9AT^=A;5=s z><>H+4zmd&(G^kzFJrkmB#|toxdBXGTm2x^Ma%*zK(ot=R%5Xv)vMBi+{!ka{CA|1 zhKb-}v!|+-CkEs3zpplWRZP5R4YLRLvyPf3hz&-V4nF)NU&f02VIFj_l${*3TYsYU5i+S#6f~FvPfU}ocMI}iyhw}^93msQq67T=@jkG^4-S6pEh81JlSThmd%p@wGh3)fz z@mL&Ue;OWu>Ca*UMFowT&#qtEd;V^t3A0jVpmN1#y#v%vdvXF1yfp6ZvFg>Y+5|oG z^J-WGfZoyp#}LwQD3OL1@yIgPTUJT~pi4eUfMbCPc$}Mjer`9w zw=T#PjbRqkQ4w&%H$Nvsy>l`Kv4gxO{ftn6mYB^Vt`rOHF+hD z?1n)-xdq0?|KtNorzr7!QyJQFez37YNY3C-fmW>PR0DtENx`SeN^Lv_>V ze0F_@^4a?ASAsRr6N%gje^L#mH^s-)ZVARlM}OZa^F9xvBF3J6d4Bn)x#jTvBD@y% z0m~@{cHcl~R)Jf1MPP{Nk78UpsUHMG;Q69IDf20t2!sJ;w>PfM2%{~F0{s*@3>(f0 z$8P|9T;$HzM;kWST$s4Jcb3icg90>UZ)jNnWu<{=X_sQt?7P}mX;^X4T?8mV%wPTb~I%Zl1hOuw)1h&yx0soRY&Lu%?MScQU=PN(kQ9MCks0R{g=f$o(i^=Df?U8mV z7$~#;3kZmgsX5HKBU?Cx6<6s)hAarT1Bq8y7{u}x2Cps|9LJiXafLyu7&^`Vif955 zGC)Je9#OXvh3A71<4(DRyaNXztxu;dGi1%7y#Q;DAi ziqylPSMi-K~IVBC^ zU92%d#)+g$4q;%b?M7iu0b&)t*~($H%C+zWy@u)hNzDpg_fXQ;&w&_|(B*Fg>mr*f5pE2z2nM#Yqd1|-f~Hn@e(U! ztFpR8;_~_UL#7x;?X;;VUfD-jaox?Ldrd@S#XyaTR ze!dZ(NPm3tE;l##QN&pg;Lv+Gw>$1)|LZ0GRx2Ii+aUEtNnM?Z%i%|&%1=A9A=U!s zngH~akBZ`|A>jyYYuC2GOoWB?_Z^B4atE}07#MNncaOlCz7qinL^gq15(Qp<0Rj66 zyaM=hk&WYgJPblv;a*{A?X|%#)U)`h`Z6EO%0dL6^8+o6<>fioz4kW`p3a*Ij((X+ zG^_!oU;mVK|H@2_ITlt_Bz8UQ+0@o_X6SpFMIk$&el7i74u>f7a?hoasV)bD;?_8K z#pb>K^&szxOQZz%-qSp;4GZvKtz+@fsWTnu^$8J1 zUdDVQj*MP1(}==33jQ|I{c?)B(#o#rKY2pVOE*37p)VKX3|ce;#nPPnKgC2~Seh{d z$jHpEc1il95)K4KW)`#Y)vYPODq?Zq1epOed198}xEL}JW?bsUV{}w>UQghEvr*YG zsg6F*)<^adm#3kuAHy<(kv9%{ikU(O-lsfb0UuOY;E3yG92B)(d)MRozddD|)#5DS ze1Q9yUf#z; z;~G31hcl^@v-6Nf(6=TavHO1uY};{F@*NvavxML*o@gBv5Fw7Er>L?w#tJOor4>~R z>u?k1Sbgauy1OvmA;zp2arx3vlWC;}U7ujl1{D%_0yGY1@sPw>*6!gjE;POySVPHkL1(G)wI9thV>_ zu1T^E5AYr&o?7FAkxkx>gmUuEbqu|rgXgIKBy0omwkZn=|5CWt>|F&TF%<#d)@)>g zm#Vv?gBY?Eau=@&`1mw9p$QBmW%Q-ZJp<7j0aP%ns`HjWOAkE^3W%QEzHym@971L@ z?TfR@R8Ry`=m-WW44BbRPns@z(GDcY_Xs_y2#8lE?s1kZ1M^Bf^d&v(FkqtGl}f+# z^ps!stY-%A?buWonkZ_O`X=5AK|zQ7T?8tle~@@?kviAI2ZvM_r5zss8p-^Hd@YK; z?RwsMH-dQX*(AReMI?=7_{w&TT!_Azrj(8W{(~K_iiBPW5~rM&8^N@ zzfl-!^Daq)Ot@Se^T^#>byYVH<^)P2LJ6pQ!a~D;5NDF@ymGSLQRgI7O(?WoEOw`4 zHY{8?_Q$~*s%VD2G+EEn;X`%$K4{FXoj{y5P2Yq`r!{Qr`u==M&ReE3#*0t8KPY@F zz%}qSdUG*`Di|#752j!*Fa)Qv@~N{*9?0lNcCHy3(eP1FN>_KUmJWzLPBRCHm!)g* z>hN+K85u!t7U>UbhhN2kr1G`%My{DDL9TLQky7EMPlzoS6Z|G$9pnA)*0tMSA2(+J z8L*Z=8d#(7TrI|GIM(SdnY{sD zlKA_-sxD56#mXyav528Ujl3Fz-TIROz^*R|=A8$J9Tdk<}N& zwEO`gMbdtv;2xo7dE|*UUBB;Sq=~VNY|0D}NW@gJnc1`=;Q6Ft&es6{L zd*e1pI3dlN{j-QJ=NXCLWOYgjix7^0(k!)`W@%vscx8o8PKzZR-y3S6Q(I3*GFE=P z4Yw0xx{a*|0-ma20jTiNE&@inChp8Z&09>~-I+f$#t0p4tmb)-$0qRtG68Q-dJX=5 zRb!jT(~BA>47#{Qi-FA){0&n9O;KqHFdeSItcIOVI&WU@)y@YGRT$$!!$yg?xs+le z!SGdQ$r3zx&SrEVj$Ho1$keT=S;f}5uB24PlE$s!Xwn1=lT;B%CTcJnC%|tjB=F?+ zdwanFV9Fw7r>v!=B@8$NJM>YqptT_+S7UMyc4=3z&>-82>0rA&F1fN7^<4HE@=`t*A2j+sUyU^uL z3SyUCjeWsed1TsM-o$?GMo6%SW4-(@p{+^mb{PdZ#2G@<7NLS34UOl_#l~3K#KcJg zP82)^g=MzqB-UyXtrE3toV7bR&IcZo&>c#R$M~{ z2)0s^5gZaC7Ngcx9fZRj!Dl4qt^-eG!_wZK>B=31XI^?SK99g|T(eAMGN^m0!fJ#} zhsiV6ce|ke8;>sqM;{W9qbre}Ekl*36eW#?jVq}G*qPViu+VC?`z?{vdYRNK;BHlHsd)dmz-C9kBp9EM3|3^B(bam(Ryq5*&E_k2ct{6j3#TpFtOglu-nion>cXat|YmcX!H zsmhpfWa5jM^(H$)Kw*s1U^)t(q5nLP7hM6r1IU7DP1vlaJMWZli=;)J(g$hsvbs~V`-dBzDLeCq8L?weRquvw_YF!82| z#b+NX>3HE-h0nQ0*S*NGqcCszKzIUx0>ykl&j_zB9y!&Hm$3a};Kshicf0$KiGLYO zX49)%e2`7F->=r#ej`2wh8iVGQaJKgz0J5Gv)x=;m~}sa1fHt zu%)6nKV~FZR7}**u}jQ>OYO3PmZk6GE6}}r=mRaBHtK&^1^M_g`}*yQA%TKe+AUv| z`?p|yMp8qS!g;XJG2ioD+n&za|D{=rsYV)9CrVxBbef_d&EBj<4qtQ|>(yx)-Qf+* z0pvj%LHl}>5b_Zc+|wOP)puyih2RGcmksH^u3yXMFu$ZNQH`Yu|7^ee29Rmw1r+Y% zfYNs0}Fe+CGr+-r3ftqx{u;~f}O7C`k(99S>V+_D6go>1L7D}{Xs$&BRfJu zgY!W0gH-q9?k_BUv5oJdy(yfVQybZQ@SFdI3tn0(j*S*vuvB0jk0D5qdKS64x0ZJ0 zNBvJ?J=4}58H|^2>V(f=`s_`=>Z70XvYR^hWwg#9uL1An>Q4Rfu=Zg78^L#C=uzWZvP=nbP|hYkuK<8Om*??GwLl92=1~} z3x}dLBcnb2me4IafpkI!wux*$u5BFJ*4d6|QS$E;iO`o*uJ&itpTDx8_#W$Q76FYn z20WDH+F#65e(A|xr&i}>J9J2=r|-dt-<0V1czbWgY;AuG{)r*1RKw;TPkk(4{jHmA zjkqR{n_rI5zB5+6QU@1X7=!G~CD^Kq%?uY@(jpFMU1=-g3i--xRF?f1uV(78|9;uP>x>9d>c zH^sXj*K=F&!OM5&4)$TWxa6UV}x#ai(4tvkeaO!N<{7I)eY_^+}!xyO;|d6DV?e3nKA3Y2O{kD-=Gg zlz9XUVkM8WC047l%s=QU`-?=UyMHw|JAfZ&R(hkz#3Glp7-5W}0ZqMmX_-9Cis=e1 z7xO0fGHXz@G9r?_d+zfJXm$~Fv*>DicUCIr%+98(s-q(cGL zXNf|FGKZN2`kLm}*X3JVTi3WnL&dj^w>qqBYzj#MvJcXE?k5Rb->*B&k+Suit*0js zpk+!KO8OYH`*GMg{4V{xRmQ&R*#ey$)In8NI7dVl@gAnGN#ROJ?!J0bxFKjQSfizx z7SCh%UPVR`rdtAYP!#?{dM5MX+R2V#7b({K#RsTQ{-AbIGo$16txh65q5=cK>yAg& zaS!RCO=kVsdxL`i{tNt1T1;5AkxoPsSI|*@pDek%Q{YHM@HXD1>v2X6f@|;y|9IHk zCH)Hpg;h&vs&6nyz~_LH^nnGS%{n0)cU|%PIR_lUkvNYXq6{!btkl#Ie7n&#G@98# zLBfwG79e@xa2;sDBSE=}j&2cj)0d|8n)4-RXlZHj&=t4qpp1e0*CQs6O?prF0H%2P8uy!rqK9e-O+<$8ueL{P-VPtGMlLSN zeI6p(xy?;UVA?K7!s27ps^g%WOgw-od9fU(2>$f3AS{Z^Gc<9xx!Vsoi=yIZ>nC5C z$}6__p;7bP7WE|qZkXTtg6!&M<`;&T{$uZ{O-;>|zpkLu(v=f~P{8y3cWR-M(~3iY zzyEW?S1@2;J*q44$q8s8=8DYh@25QZF<@F%9T4dK_*hoao|>IRtY@tfFVOkIQQ(g@ zTskP>+Ry~6+<4l1I6Ldw;bzftu-F}t;D6U=^H;_rW5d8n3xEA(8%?M03;e!e#_XuP zsxOAMvj^*1j;*%W1NNM*fn@ftX@`XZR$jp;O-)~5bhaFd!MCO_o|{`*sO6A#w89$@ z^VIu<_E8;{^@$qv@OHq}{V?un2eJH~ohq68y2rJgpO&00dc#S5$T<+8$-&OGaHerL zl^PxqQMfrSgU%rIOo=JXjh3FCE{+;;+DW&Sj(H~7VN#e>9Q=?;>ia#Kj;^%3g!z7o7Xw@i7z?4YnM97@?)f^8sC9w4Iv-BtZ7}rndg;yj|VQSc%x&w-{dgapZ-; zOJ~!2HbAo=GY3UQr9fMTT4*FEL)Xq8JPC583issW<}RZ0f@Wa*+Uzz})+SM2_N!#5 zSlylvFlXCvfh_?2&JdHMOq zclm~E>NkCV)$1px8qesqs>1$$hJ#^|G#RCR^bKgssMr$gxd$A+YcxU80d(e$3n(tT zbWR%_0BHKMrumDoS*T{Z&)6Df<`R}%8R%!MrN@Kn(l-Uek+{A1UsV_s^2~$peVT_w z$5Cz|Gt~j(#s$*CwL|9hpv`zBdGz8BKdu-waojMS`o7wmlWwpwo0;!l9DTJ`9{kOi zkO-(h!SQ~VVArdy-co*e`?^vz0Qgi?($V328Q*+H@%WMislSH4EG;Yy0g|e?-98cq zTNynD9sfO;rqn$kMA5kD{DZv;F-OT@OTEih-C>P~vN(BHQBDkNi#!j|;nAg>PO0mHsH%IWok> z#MWY>FiF#7<4KYyeelWCiU}cBl<0kSwIFDnnH@cKcbX+qEY5VUtEb_j&P2R+dAPnZ zEz9+)yT2pjrDz1q!Vc)4F(HC*K{#a-o3FiZ-z+q^y1S$tmr)iaO+g7T3gUC#r`8T7 z+xhuvk;>e6_THO5WgXV20x$h<1#u}Ur63CyP1}{GQ^u7pKf-mO@ecBO_lBRGl!SL_ z1{E}EEvL80c=UqAghGjM`4@mmA=B~xX}9N_xbqFZW7C57Bm0)Gb!J0oQ^+#wm>8)= zWvqm;`D1zZ2Czr>i6x0SWV$dISg?Nz>_0PgZ>HGo`3Fm*+oNmo@Nj&^(%I)Z%JsXy zHpqK>EUBKle0eIoB3kjg`%TZR;^Kno_j>&X0vIOwrBThw9v^_zCk(bk03RICR6u4! zhKArCZ=1WvX(!U;2J5zeiviJbf}~64MGp!GIxTm0$`OHUtCClH)SR zf65qGAzN|NsD)*SSVD8wVi{~1y$wwsrN0ta_;`7;TLTQghl%1``UZJNe&t;a4O2Nd9PlWjv<=sYSOUJU6y8;VP1hZC}<(SJnJ3iOp{D&Fc`koQ667 zM}f%iThkxx&EMdGov?QFZhodZ6WptS?gbYs_4&%h|4eN+`cJkr zJ{``nwcS74dPiXcfi&YWk_xiVi0^-&6hQn<$4LXKgvziyA8?p3q5YkV#C3M;ky{}J)>MK&S zU3vAeO-Q+nrCBp-xuIe;8W8ZT5gP?s%4B8o%D}u#k&*ZM1`g2H9pVB*ZA0u0cUwx{ z*VzJ)qlMO+C%Bdf?eIT?gfq3Vk3YgKt-cTi-Mu}5{C@bqtmU@Y z8yB7RdhkE9w5R<-)GzsSYCetYhHZl3Q&Xrbupj-0X@CD0$SWB4yU9mEm$}btb!JGZ z<3jz3<>jrrR)bM+Lr|uc6pcEInpmNhAvXU8aSco@D2VE(V@6r1+pk4Q3s9OjZmXOS zce(Z=t6$EV=gQ(J$8~kTN^#s)Y;3$9C_KaTQbgv_Xr_#fjfGkNxC*c}YHAYW<|+-Q z)Lf}#cmaF3T^bvX0f}K!5|{hRo;~v!#=hlTSKi2710O96(6t zutccEU>2rN=%@#9idx^Ey5HUo9o}wL_%HiEv*ZLn{nW6_0~~YWBaM=><^V$8xOYdpY1C%WYIm9T4m95P?=IKZ=0uij1O+9+Sg_S7s-o<;prdFK6?0$geq4Yc-VWi0ey|s7mQqibW>-kaL+eSby zlhtxpUl8F*BFV~f4-tI3yU+}|ZUN|v^=M$C*&1Su?EW;*8M_-^x`gY=oINoTwHu6% zjN=vnDI}fV7rfIF+%NFTAxHYV08q%sXPpi+m&+I-WZ{znjnH>GCtg7pb}T)zeL12CHT&nWN9hOA?>Rx z-vrZTaFql{BO#5PvaNka7nM-Dp2qbH!x#n=M2-JCg=CWo2s1lu;3!#8EVS{9i0rhJr(CnD0aOVoqTR-;3M~+D~6bZ)xs2G>eeL1nP9E`?> z1flWE0YTmS_w>^^d>7h|YmPUC9CGUFIKv|92W(8A;Waznert5ef!-e8B-Zb@R9T-H z(FDb~y6*fEqoZu;gnql2Q~k6sO{>Z;y-l-f26iirw67!C!Zi$ldR1f9_Eu-?q_Ei| za6#V%MbEE4mI!OmCK-DY%v6<(g!4_sjO!`*elQM)4`%-E=P{H2~X2 z%;qcQvi2uW1k=3=8k)r5l_n!>PaR*+{?Tcg2E7gMHu$_5Ik!^iMBpx?yxpC@32w_! z6!B&!W$YCJx}lu3VLDg>6T~5{BF-FcPnJE}DFOyvP7aP5lgm8bI3P0vkO~ukCd8mk zkO=;xK<%?(m!X#;`W&jLr$_wq&%i%8$M4qT*^grwf#L7Q>qS2#=jnXEniGD%+dD?Sh zlRnRFPEEeOY+7wLH5WYUriH#9Dl(T6dNBH(m;?xG82F*UheqK|4Pi1e6d3Ykfu+KR ztUFgvbD9pz)&EW9i$a!6Hhfa$4zMZ_ie9?2?MM*jwjazlnL~)JO zYcpr$%;LAQ2Cb$2KK}`vPKlX5WCG07;^LxW>@!|ZkMN~urkKP6PSMuZ$66esg`XtJ z?3NF+^r;YHS|^0;xYmAugTLz8-qnq`^kD9gNyo=mRZ^F^t}vfoTmHvkq0WPi`&Y@m zDMEQ#o_8zy2Chy0>BXXXbPIAK6zz^^&0U`Bb_b~~eQL+mFC_Lsy_#gxO~8Gz_qq_a zr%8?Eo_|V6RhAOS*5g>(96p;lFfQs89qBj;CoHp_AbNaxdyY{k_YvXmc?k8ix!zkm z%=XwQ2DDVuAxGC@u{Krv*@ojfj$1DdR|x`YBYWIqGL_St&o`(8}~ zHCJxsr?%~(D51OF?pGV{9lkeVJ}!s2BbV(xPsoAN;|^}O6hOy+7rY2UbCg*g8F8G; z;nJG^z`16x`Ad=xomaGwl44-c)CuMb1t2aie0!c8K8WSo*qW43T-bR@XJ}KErvTm1 zo=I32n;6#BUHzs%j8}YKgZO_7sIDzb&I!0dX}d0#XTbZi(ic_bVA;H^^_kD9JV$&W z5g6t=>%qe$+&2&^JZ!lbI2Ma)MsB=#qUwC;hI#4G2eUfzn1q;~1uDZdGO*dJ3l>&4k|@(Twg6U-tcXEtA5RJ@3*eDz>;K+&C`c}x-rN6O&tg)V(6cs-hZQ+)y&he-*xG9M4@@`VPgW; zJAtN>6(a&IK`CPy;&%a8)taVDw==RzeoRyA>*DdCSJ9R30Md4BrA2u9@u;tdO^CdH zI3GA+BqG6@Ws=`c5wVJK3y>z(EpoiV{hxP-4==X+89v9swY3aCy$mDX`MmAyHkj`Y z2|3cq!SOCm0Lri872eq=5ffxr)`>r(#TCp3;(DV>Unen2vuSjR3i*CzPr~~?T57r!|BOAFbb)s2Y5G-TO}V!+(tL|%SDqT3-l2(<5lHC_K6hFt= zp7ZDz-CghTj7A^`ySz z)8}g(ld}{oQ1B-8VI%HS=zbS@20mQ6xO-!o0i0Gbg-qo2p=iug55ZkFY}H(C+G;dz2s4v1k@a5Wxsz zfGc*iv3g!#o(?MvUZ}m*^a5%lXJ^=TT^HY8SW;Dm_2$3fnmdy0#vM;gM19XYh@gez zBB}U1`cY)=yQNyqX8`{SLydTOIoeafKNI}+oo_H*qFE5Kud1mm0^Qk}g6UD?{MF%5 zSwCN}&cZ`A7pYor@rK9q;y=69Tb$(5JF|%_IvmqV?n?!ZhL*@J0!(O_CG-x@f9zzR z5VkK@Y^rA#yKdFqS5R@FRcRY_qL+UCva88rB1fLo`!7b@?U@fva~_$*pI@*6e;y0w!nH8ArSE-hB(MLrhy!neK3Pcffa4Vm|3=d%sNm<+fPD7zhM{--xNA z{G+@q8k?1>vB_zlI#i^HpRjClsCYKdjj~G9k4pR}`(eMqy@)bzrjp{qLEh27xbj-Q zVY{R==S#d!({XKU2kZA8A5Y*el=9TCE3C&RpqCrHYtG5hFR!dD>gy9BLBo{S zglQA}Oq3eQxotRbjJ>!6fk{_-yFhyr@`1X}v>fiSSH$Uj+YOKgGj*!Q=4GAizY2Bw`{{zA6=s%rUn^|1svUZmYHgq!`qSX8W?3ACddq8CWQWUHH_4Q8R8W%Z3*3hNem(UXvtt246 zahADGIpsSXT-^L?PQqwM%#Tx4XK5CDtrbZc2`hMUs5qM?4*qegyH(74Iyay^GfLKb*^?~O^jwOObALBTTAIjKI5`@phUgcRZk~CbUDj|5IxV4` z&=dEjr&$$D1oc&9CF64ohEmjXRO{)P)cTmWb?pQ0LpTDKo$YBPaR~M)P=*|fD2j`V zl{_pYk=LV#+|F-3Te|G#cLpI>v)yslE*;;dP|JIJKk)iJVsUbDO`E&Psj0beGl}xL zYG+($g_}eLx<^pnP7@b6frghNomr%!aQ2O_E&&Q5-v1Izeeib_)3wHym#S{=%KAF-tst-pN=?}1iK#diaqhqVL& zUN8$c->-weV)9HapK$e^Kk}WxxIoX3W2)anTL36;~1xIaG z?`^+dm7_kU#xk%H!?Z-cF`RQ^oS*adOo{ctV)IRq;M$u96%DtseIE`!LU3(kV*~FN z*u;gukEqvtHZx=YTB8YhyJMc7ilY@6omSu%En?w~w+>=WSBGp5 zdEPEtv}$_Dj|0PM)p{milVMKJv26qR1_1T0E(red>CeaPsNYjP|7G)fCw1L;lG+jDdr9h2sT z51I2i?{H-E>q!!4uM-(LaT$WMhl3k6>kqZPH-|wJ3Ia(M!sQ#TP_@fK0iYWva`vK#}H!rQ;ij}p>_1!>^ zKhymhICJuK-k!-H`vUxPd9`(7ISOD3ZsaIm{vHF)FbCV%!|NkBMlCIp41soK*s8@D zpA8rG4_ zxSy?N0KIVyz?FdudL9DFs;Z%DoZk>o6Dc{0;f0=o7>wMSO*&-+CMkFEN-E_Qd*TCl`(H(VYm&u zUYu~tfPiZUE~sypa8)!q6DR42t@Y z?jCPV<<5nwkm>H|*EDx2dFqVpE-;h;Q&I~u3@AVNh()9_=NE8 zSF88_8~eSid#oDCFC`y5;_(p6HS>D9dU*fpr{qO)_LV>?sIo5(D+fYK4QL=(dQ6?%wt z^E0Z#Da#B@ytMjZm%Z77o)7W*bwd2bBKrJvlqYX=RABWu#1#i-OL+)Nv@%OUci+J_ zMmbv{qtlTbJ0-@}}j zmWT&EvyPwD=CL0kDE4;w&|T8|`dTVTN));x9_dooUXse&?me>4HQ}%t3WFM8=-WSe zdRtE+I51)gSIx==${l`$G^iFiUD`~3%f)?l$O{OBMebmyr0nPa7=j*%atC+nxwsqd zw7%+<@4$-`2dv4=o%^p(vEkT35`=YzMv<_14%<)>Iv7>Lk}-RK+Za%D0!A|$Ie>KP zIH0vAGvJM=@vC+raz?Es0zAj%jhcZ$^;xyW!AJrnWEk`s^!qm*_g9_fGdAoBM!aqu z@|iC(L`3{lmKIb(Dj7jRL6@3eTFQux(df;&G%*N-?G&`Vo*hxjo{+Gzhu7)#bvUrWc0$^|RvUJ$&~TTD35fA6|* zuY1%Nn}qlG3nk<0Oq?cj^mHwpfZ8WlpbW@qD|(L+!4b0z$*91kYl~L%qHAtiIuKQ_tqPQnSk%W^;CqL`e%?$UyGHC60|5WXl9mVpelxar0-}7ypb=h z7G)quP(7442HKV&*ruYmI2@H|NfOZavjhF1?=eJ0Y84Bmcy{CwJ>DOd&K0vh3H9ds zIQo{sN(k5=+(GL&+On;u(cr-a{KEBwP$*uQIChMHC5{L}Jv(g{QN7Pv_qk__l*U1h zRfF16Fh7&AUH^oT=6T7cI!8e<2tkb{eDnGG=?=APpJ$eqa&~<8zljbNV;$1~k?=sdj=n@Q7L6PW`@S6h&@ zDyMr1>|21saRvqw0q;w70vBTj}QXaWEz5aN8Q+Gu&%a8zggg@yLO z(!TglRd)`7V)jmwD7kM#yJSg4>$*K1GqZ*cEHRIBnk<^GZC|V(*57|h`U)XRb^v!i z!I^xFRj{7hf7P@t9bl%10vmR(+oUV_1ceTE4vp0$?7w^A32i5MNK-SCpV!{Jhj%b>;>cHQ?4T2)=mtYbk)+qZZala&?2a8rQ;_P1lA z!&ufq)yB1JM1N`;bq1Ln2G2C?pM3jA1x}1nu6i_>V5u<2ip+7#;36+wWSAzTd3kKk zaP!iS+t=Nm6imZGd6b;rB9NId=lVqH1^}Jp;V5-(QF%EJP=!mU-Tlf&N1w@arYT*N zlzAxQSW`tDxr8oO@O0H10T)rt&E(Ep3Rs z=+f~5o=Jmh!)C1}$H8$zRE5)*iVIKcn$kG>oc_#0@X2b6~#fs^blj{IO&JIAA~Q3W~v_}p`C`!abK3e@80LWfS0}a7kKn>_j3~mQ3 zN)}1NTiN`);@k9-IRxBFd6=tn^F<*l5+)9s_(4I>d@OZ*$vt4v(5OlV$-~ruDhOys z%#5<|(f93f@qF%7?Mge|nEh_yIz)|u?A~t$lz)H&xWj9ToFhoDFC!zPti3%Be(vb+ zYH>;SJ^pAKPxvdk|2KeAY;xT>-`|;-pb|0sdgQ8T;kILw#em?6J%xisCaA!uMrPjC z`1ATz?Y+(m1sj{D=D7u|V3|Ewdv(xMdb32Y3|6QIm4?oEY@;V4G{1020CmdYl)Zm> z!Dr&Ezw;R<3??|3m`u)`C-eINBP%!$1}VwjhH_b{m=5iAGifuyf$xmn{cs`QlxF!N ze0z&k=g!C1!fYKJ@_>>+@GInRCyVJrRuPbY+&x!Slh6ojHN?5|Fu9i^BeZ+7yiBQv z^OI>veqFkIx{B@T7faPO7F3eX*k)9jrEjEJ)_YyqEw)`@9HYd3<0g_6L1(QDJQeOn=`?WR_6~TOnCtp07;nY5qChr_gI}Of2 z-k_%wp3bUY|0&Y=o^*db9!cwJ!J(?M$0jC&m1$9X#ywjxLH}*_U%!Rd%B_BOZ#wLo zX*l;_ZBLHRacOjxOJ8crvy#54niu?0QjJ7tOfzzQIN~j6%>*94RW=VS90ONg3`mu# zs@9G0>w)kOdp~}y{L971{p)?P{XN$m^_?Ea#*B?2$bc|Zt-g~F8&wT<%2qN}lq&YA zHxk!QM2%5B+ai$YYCOj3t>MywNaGlB&?MLbph``@_TCp}XKCmRKq$z>Y42xB#>9B- zyl;7Oie8@YJ6@48n|?OQ)cWAwkgCGM5JWKR%1D+7mZi+C@T_BB9x4q%rJ;S^jDDZE zMmLq#Ray0VDl`Tp;p*a|f7&~~B1UWke)sNTx8wN$SyYZbDmUV$+-IxEZ2?1IBLCg@ z+s1GPTke0GVaT?uTBwBKCnD^GE%@PZjI)7P=DUYm-LGLbcxQ6^%I0Opv z)34nVAWakPt6BZ#0n&6@bwx{^=)ZTc8fGdo#cEf>Nts9ii{c&LkI?;-qu3kj*eW}- z>9IaQF>$eFPVVo|wj#Jc#F;L^ycr91JEf?b5RaMMpM=7n^CxPvS)-%5KDq7V2FV}M z^ddoF-G2k3?eD?heHUeATtK@VfrPQX4h=v$vYvX%$^VH#a3l(tN#)DnpQKMQmZ~FU zFNC;1qy9nscGVT=PeM$>@QX28{Bm#yD0|#9kb6BGc~KfbCeb(>+1Gc23trpkqG8(z z!U+efO{r+mw=^;D!}Ha4f+Hfg*#)mFK(f3MX_3+%Migif-}xa0{b`7YrNQ4EYuWQ| zv4z8YXIAN=H0SWxL+0y7znD_)D(zlLn_Qlo(`ldLWUcu~Sc@>1)aFi1 zO^x)`WAjSS%f3b1XGZiinmz6@L6)(zBPl6i6sE7UntXi?tw$+0YBM&&zWIZ^LnZ#b zG7ghOdxvX3<>_wKKRx?yGT!EZ@&$C#6V04jVLlaifO||Wh^_a-V3cn{lw$K!|BJdC zwNXqAx~D=%cn!g(MvKd9WEydHb~Rg zZ4N^@?|t!p+pV6S%~*H___8s+F54exwl}r3q zwUNqFxfUj(zbp3BljCUlI=QG>Iww97N-{i^Dg3Ty;SFfCS|8RF8&j?^%6g1dZ_brI zrM@el(?4z={Unrt2-na89g!u%sS9DQEvyGrvm6MI-2c%SDE^n~S8MhK$_3-I(d=6p zoFTop1frDGuzB6?xN;2B_k^h>h&)ghAtG|pe-)@M=NF0L6!D*KOMg^s*AbFJk$zp` z+fQ*i?k9s7Hz!<|JxQs1TtV(t9Cz@f0kdacqu#I+JwN2UWTqeNZ}=&k&55KiWfw;O}0A?Wqpk4w^+p=|=zyj>!eqx>|X6@y};FiC8FuI>}Sa%J?g-34> z=P*5pVBB=WT5^B<$jT^`1cwC7|M@cd~^8JJfR1(8x6>}wv+j? ztTgOKt*MMBRAks(sTx{~oX@wPR?d9Nv=PnS0M6GtJ-v7lrk;iHN-Wzj`j3A!9HprF z1ZK3W(rvIVZzi4B9^cSkk$y9Dgv&KTPS0OA`e4MTmla>M84J5sxPClJ2wOH5p=uG> zBMKtn)GaK(AXfcw@B1()!apn#PY=macW^E{Z<9Ct?)*Jl&^h3`{ffc8zP|U;YlO3g z8Wq9M%}k!nGDg*zP)5HuImftj%l_@mRRF^$5Gi)m`b&%!|4OG>HFg!fU32{`3Of%1 zEOrjzhS)yoU1?{f>n&G2XL8KT6q25b)pKKxvu|p7Ou$L%ylq06Vnp>%sCULG9 zWO$xIpX_sGj)F&~?((%MrwfG@nEkHrtkqatoo-$pc9;;4MSLf=bm7aB^`GH?u z1*4`$AF<9G5{>qPl|sw#_g2F2fu#|2-5WHI8*M%iH1y{4K@M^JtxY?dJZd!*K)NiKw5rho%@OM5DP+k`g@bp2+|;sfTT85lt0wwB8u;K8^Oys2fR zB9O9};WLu8Y31};zp-EbNV`Smj-~ni@p8`~E{&QZKLLM3)}pX8G@@fRPL}0Qn}@aK z)RsV7xGuy)Xu#f_Rho3nF(Oi*>9jOeTr7;<^O3*zTUuiNib!UX)v@?~_J-!g*O}|< z>%pZ)#IF3;vIr(`Id4n6raR%FjzGDaweXdPhulUr&Koys)~fRS(3FZow$Ha>M@y|j zE8%!JxO)BA7Bo-A?d|G_RqfCOkUK0juUQcpuDLl+d;fT~wbtqn3Ub#ZjTjgrTi!X`ge7-nh=R0craXJhtCWV zrqu7M$~-CeV?8z_Fy+et$26L{xMD_*?{!dFQ6yYufIeB1ifYB095{-TAu#U`IP(or zy!4MqyF$^MaQmr=O$JD$7lu{j%z@k=xE90Nhe#cMH;hq030%*xZmqKV*5|Q3m#_#< zRj6AtbpFun7ZhVoUu3r!{*ml8YCCC?-SF@^`W3}?4SIK8&HP8!+wFfAt|Lf!HI0p&RGelk z8>=#`RuR$S&kt)cxNWx?BM8<7nagY9D(gfw4GpnbVIYi}-8k07>uY`jtab6;us#Mu zj=xG}BRlYwXAK#ZMR&hbE~GmLIs5M=_3W<}dO&boTx^)qOiQ~zQCQ_lr~hoa#1_k~ zeL?7&QUKHHgIqD*^QOJ^f+8>==ze#{L0hZP_Z#KYX>x{LZCw4djtKad7`I}5R8l!q zDeba;qlvje0;kcjN~i$nw|skpAJDiA!dfOPB{8-}(6!BiP=a8&+6Tg_up8HRyuh3l z9hl+-kB}0QgLfWJpLeCpr+>tbm4|ti_QBY9ZFHWVf?hC#BsUMJSHjPlk-TumzpbLO z_pdoLM}*-*eMLX1aX@N21VZ@irT!P~-%?xW~b;$it!R|Fk<3%uKVG-aXgVyng zNWRj*GA3__*_bd_)e$3pnfFpq`6T727Ft7P5VEh+-#GpEFR2zY>LaEw1FvRoMvPc0`s>+U9kVAL`GNg{d&AHnDids`C9Xh>6&Q0vye ze^5Ep4I;zi{y}~&ZnaY$iy{*#qK$NF_I8mAuyv;XJFiSukBml`m>ZM@#!_gkJW9wd zN&KLk)E`um9VxKAt?d>UF_Lw`SvEBdgOaE|L|5%&)iPLlu3}4gu2D`tm>}*{NP6P8 z){b<=?^)-e3NFDJLs2zao^A@dY<2I?{GYey^~2IT^*fTwsZynYTY;V6qyLAWX{lJS zvQhUBtoSi>>U?dd%&g2s2ggk*+Z@JHE&NW!M^&vV?gQ}y57jDANB9KK7YVqpU0u`! z+d=SSb?o4_JHgao4U{=ryI*UQ`owBm19juJh2?u{Uw!xtUnoPQ#rF8M2<74ULi+n( zoVa&3{@G511=|R}Stm4GnJUT@NBSi{yFm{@F;;L~W0&4{_0^sh+KK_~VwE|lx$`}( zcD#rleSUE8x%xG(F`6fHCT4=Wwm8rYM?jO0+Asb~gr~5wktldAtUWcil<{17mM=g_khDqe*Z~@fg3J{7Bsk4f#yL?vl8T0TBkhWJUl1~twXW3$nhGx zld#BHd#QG|XC-OrpZ4(+r)@_Y6P;WQUK8X@MMM2+Jx)w3vQ^P?cX-ovmPvqtjYhI% zaD`e10NvV!awFYS)NGzAvVLe#V)Ke)ZnS zr7x-eeE9oqd+Y&--7MDsB*8K$v{yfcWkRxE&Aqse2S>VuI1fB(EoN?Bkl%o!7@zqs zXjb7}k9S-e$pghq5mmYf4#lcHnrE5@%?0$?i_*Ybdd=TbZ0b8ll75Px;v@P%O@)b$4!u^c zcE!%{cZftcn%kKDXxr2*UjX0CD8NBQm+hC+l$dZ2p`x?ut-c#>L76dD*WLm3qg%*V z8<1^_DLmOGPwz#?bH7F5dVPs$<$C=XwG#Msi6juvOx=Rp`%L+9?Y#Tv#`$mCaZ;y= z6$huuvYzlwKYggB895tud+s+5NJQQD@3v|U)O5oZpL~KNk$PniZm&`2-tPa!C_n}1 zzg)KOYdzT1Cu*EsoCh%NDt{RbJPbw55`_$rl1ZfM_lUXX?ezU4h|P|Py!>b4V)*Wz zY?7u;NC>JI9YWtFx}JjV--IA?&3#IS0q6OCX(j@86(31D=Vqk$`n~!%nL?nCYsrTP zd+V<7acqgl0k{wU@?_iq{)yv1Ezd`yXuF}f8W=R>m47esTrl}r9s(r9jO|dy{4*PS z&^l0ThOG268Gjo#Cl8S=_zm6n^0erYvTg0?hjV_5h-N4yKR7BKgJ`|sA^;-7keVEL zDao)Rmfp~iG$GqC7t7eTN8?=;X_1;A8Er7?!K+*Qn%?3j#P+@|9k`9jhZYqbP92>b z36R8x3+8w1sp>!#Pir>6Pn4Jt zApjLcYw7Py2q|A2L&nQTQ|uR1YnHJBfcyYB%u|-ows=kt1tagOI2osh#y4Dd zb;YyP_iX)dXmpdg>qH+LY@!IQ#`0xbe=O8TUOTPI0XLz1zkROv+KvbZ8X8*N$33+| z?{tJdi;o)DWO9pXl}o0Q@qP#>onKjUSvL;Ju=Gk5C{!-+#+u4oSK7C+&Fj}Li8;Eo z-Pc*Y@$l6A?8JfgH804(OHglJTv^%EW#>=Xqo1hI!G_{aVhdWJFIO+PaJOrg=9w5*mV3)fA|a(8 zqyConv~Jp=!fD24L(5S0BAI=&n5`t9qF~WnzU;Uc)vi=1ufOJxAepgQZl65$Pa^3J zoUs7uI^JG)6B`hk*G^m;M1Y6Q0isfmHJ+1`}+h$Mp^5PwgfJU8`*08VG3uC9yJ zeVXgG##T@$uQ6gy>L)a5CaQ^>f7@D~)&nhCIgIddT?!^zFnKVWWMsJYrts}h5TxGh>9e~#O}b}0uSl;|h#?{WE6IPFq7BNXk|G1k&jJuJ zv}bWLf#z#H=S6k_31!9gxG`HzR(Tk%=?YAn&HfL}*o1fu5fRPV%w^1<9M@VwT(rGAAm=p7AySiZ*NGke3;fQBr>O~#yiUZwT!E;ToQOBf!deKF z!p8P~K39g8q8duY7Q^-v9951aNP0HQ`?vt6iNFOap{BI3X4w*r;GZz1()r$7`aTlgn~lhJ|zclc5f(< zW&FxS++fhkukm(qX;K+2#~>{&$_qST>uv(5Q|bWo*GtNxxmXsh2GP>)wf$30uPlfL zv;2F~9z`{t;k)o)tlo@!avq_PRt1X`Ri%P_)0+#zZ-3WMqhvG{d@6kT(5Am8;yFc) zZgakKvw>*nWX!2Vr`WtB2Z)bKB*HIhRpclqlei2DSe)H7_~pel@eONVAa-f>v`vUW<3F z7W8cT?U}VMF1H31(^T*No1iQrCXxvk!SgMiro<*Q2&&CJ$1d> zn3+9VMiaNXM1+Su?qez{>eTSM@D2@h0$F~BrIQFZdvXW}+6cNwjb#d}`iW^dL4|{p z?A%T_-KGHK%|n?>7u@gJe(bNvDjI97hc*s;jBgo)7V&E#s*hAuNZMLjIA0}${M_8! z&QCTz&KOnb&Uvx+CPkM0Szl9to?3X62hoIp1+mY5?Ny?yjR$iQJ83FLhzsKzh!6-p zd3mMwxv`)`Z(b4jVnDHP)`++8En2cmuSST40Sul%K~UrTh^C2)eH-F z-n?}<(i@QUj!%djI3O2t|$jA^$W-BNfqv(idV z7V#aMk@PH!o|Zq*ESaY53G4J1cl!Zud~|-&Urr)IFK5{B@51aGJmiKNwNiydz8+!K z3#~%@8BKBesi%!g@4rUj!GBqda_IT-OY?6U4=XFc^a@C~`{_-7_Tm657BZGqeNC3Crm_tzHLUwr%;)*nbY;AB|w z5h83#88h`ngRP(j7xZKYQYOLLgnYsrA(cEam}9L}n`hNbIU1dLHg&Gc%P6+Tm4k&rIlOwtLKN(Jqq$qCvFFk90UHR?Xaofu8MVFK2p8uPmt;qCh@; zkTt32L?O}47%Six-6q7Z^U~Ml3#ZhqQv_8YF`Nq;uhVU^j{-|sZWnp?gcl(xKLn-f z-zvKSy7l|59q@Dc%Bt!08UH@5y$A~n53W~FcZZ_Udc+XI?-hZ#RnemM`OLI@YieF;P3{TxG_h=mP({-kD|{iEO8jc~x*xt)Qf%Z3_| zH+JUlPD)MA9dqP$3P#v>vjl@RbbfJEsF-Si)T_VLznGj1j@}8_epHQHnSw#pyBtx&no#+2Qmb)ZU{?BIjC1sQxS~v+EFEs zrT4*^b^l;Rn@+!!{Iwj|E>58*2hZnLkoG1PRfNH!QMcRS5RP*}9{#IBWE5!qOZq2j zE_M+Sd549W%Kja+CyKZqEiKVjWl(Cduy0^N8f2t{9)?B0wHv5yv`JA__+0NuJ5L~)iosn^hkG4 z_VO0Z2TI$w>8fIW2?*5GDIixq5i*^#s$ih75(koKc)3|y)xcw8=qOpeEB*1+N>?|^ zi-`^q6ud+cuX%wUq+?8}9w*4grBo@~$W61Z6qQ(qnvXP<=IY9^w2F=?7tfes@ImwD zx_do$?x3Ka*BTthq0>AgW~!iSH`>ktNQb$3zSHHcVmUkOOr;NZ}f2?Un1lnjV$R|^5qip z7_v<@D2lN__$4X03)*{cd#vwhl73U{P%J3`iGh}MjD47$NE(SWd0WqPFeGSHeX6T+ z?jL++%_-5~QT~K__FUP`Sir!7bR(vhi67grMU#<*ImiDMGdt70k**C!t*P`bW>i*C z4=F9oI?(;N0e8lV6KzsTDCm2A82wDK6n1y#?RL2v>PO^kfKhr=+0c*#5Ls0yjZ57; zJ>|g~BZH!!RV_2onBu>(nB{|DFH>@iFhx<03&@OnQ zblRi`pOE^c2m!&5$2PF>1mpYJ?}}WMb+`uXgkiEb^z=o9s&)Mq*UlrpLYmi$9F60#Q;+_wnlseEcBA(}7y_^9DD&lG!O zc|nz|k4T2;x$Z!psjPtT;SG&49cm!XlPxBWrDJJvY&t1&sylaHZmlMIkZaIH^%M+ay zITRDW2Jt-i9g5auI6hHG)|hLl&`13*bS+PP+K6g;e!VLr&0gKH{Dp!RCGPYOQ>~h8 zaZ>O>R8<~ZgF36=1ERJwCiP$+|HNNq_$R;3RT%|NxFu8kHPSN0`T7FHF1+?%hE$>1 zeIb%q0vXHwWv;9w(nKNJQmPTtnR~2Xime-CstnmZRH*q7Sq2rQc~PDSp~a)2!yP`=_q9vlLQ`t0wJbEw@}^C}}kgBpWd5);j8cJCi^G{4!lvJyMcaLHOjw|#`<)1+$ zMUMF(tj6t{#WgKnVKCO@{QPd?M{GJfl`b{8D+8OuC3H|Lg&Y1KcHtIsVp2@=L+wom zYh^U|pAFD!MJH=Q5t14v^zh}=T>?5*$i68VkJ4%e}II4G-SO3Z<;9&9jWP<0m6KC*@K-Scg77LMJ z_%OIna!VHd^|?o`R+e`TQ%r^mRzp)w zt2E*xuUP_tv;>%~rDyQEb3@J86YI*&qq=iY^?v3LrOlu46E5A&kAf0yD_rf{e&QWx z6!(4*b8<(GPw_9)=ByO%`$Fy6`cAWg1b=T&7U7AR?ssm{@1T#V7*F{Q4X%(8rb4__bPBifWKW5Hd|fu#a;Gbe~f!7}6mBwyDuGb=6KosZUe zf}_bXCnVvtQs+er3OGkD!N)L&U1o!<@2k;mZCCI*B=`nYmO}7FEAIEr=7XqEf;iuC z)=2x%ATSGBapq;!LzQ!totyU(twk^}Fp$!}NM>d(K1GMBHrAHR_qCOlrp8m6^@oDd z8Y6-0t{?^%AJR7!l|lgCc*mCp`_0Q!cU$NTW+^hYmoZp@mJUmBt`2VQB!ay!pMU0F zg&zu~ff#HDvCEK(Y=2QQ)Hkq8OsZMvmlex6vM(bJxQ}V}=jWf9p&%jRRa!!Lj_;G{ zh|h{&>AVfk!9dpf#pjH*IB`7ZL}qQ5t~uO~UFsg&M`!pz(t$?(<|cS1;QL8MfFdi- zy8-{m}Lf^+ZxaZQ1qb}e4SN#K$-n6sCMe1Vs zn>XCtOP{m}r8GeZQk?L|^oV_ZGn0$MxYI*mSry?m8(#S9JjA}%R?+nGE~7bhH&xAf zuz*d(7XYkEm3FbkbNxso?49vKq1NH-v`!@`jYYrtPCV0;%*KZoY9X)Y+{Ilg z$xKh*N8y1X3x;K7vLeQfh_prf-}SDZN&+b2qX=^2n>yefcOH` zfsK;|g=+daxD`|W+txIITq5_f36%t&)qw=FNSYz&esRt{l5X~D>N_%^ZC$HkWy@{x3?@Ri}HpB9)c8z1{U166i^y*uZz2L z-0b_Y(Ig)^5QfG;aL+20wYkSRweY(lJz3OBv_1{=({B44P3ZPWnzny-nalm9t;ff7 z)ham|c$BCDe;E@=1G6weXb8fVz4g7lktF3ty-h)0bR}_wn*kMEL9Q&=@BrJDwsCJX zc~`+F7p~k_e4kdnGxKY~is$5qkV3iFi3MP1_8Dw_IxIDse5TG#u^E<5@pzs1uLFf= zIN$AFm^?q;rn5qn4Upv&yy(V!7iPNMTWc23NVO5ugJ8Einr-7u&`9i z%&#o|;g1uIn}U^Sv~r8;>b{^$o=#vGQjuXmf9*b#KL4qdy3jQ4N%E zC!`M)B1%N+M7XFrQKW0T+#&GGA%SYK+hKJ3&Ga3b!4oP~23_|yj5Lsi=AL>Pt@M(( z<`UNe68We-dx4D6oU&_b+i__v+JYDp;zN&)rf!wDCf$o40POvzZJ)>2P1#*YwG6_* zsg}|SF-gkG_}%w1X~+wiP$Z6305la#Sk=0{H{8OkPs4M)YDgL420(b`)3KL>2|dLh zeNAbyjPXEtvl5=k@?^r@^^zrW@F+A(NBXg<^Q}^b;q)46ZZX?D&W3=>k7hGV)^Y@( zYXCo|(%hVrPtoByqHI;1=-)9~cDw~_x346`P=xPmrw>QPZ%9q62O|7_t~Dbl0Yj9w z=z5|T;fP;MPjBpDWS|82z!1ltm|PTju&XN?MwwrygCHOsJAmid*45pe|D~>}xuvDs zC$4N_c2*XkawL8=8=DE_b!${is13X6&$6t<{Y6Db1Dagdop!(AG}^12G^eG748ufp zrV{O5x;O4meN>0hWJ=wT-dXSO2?b0&hoYe!%SW=zjM(c;u8=8QX5$_G&Clf6p4aB_2R0B(!F z`KMm!-6V_SU&yLg) zcBgD#J6ZQy%Z*_G;Q-zV#2#V_7;z0QUDCrN*Wlej1k?L&vy{c4tahS{wfwNo!3LtU zQvd5?YUPcywWio*NAn%t=rFrXWld5NAH^~Tv@Y2@heOP*s7HpA9?WoK2s;qf01X0l zhOYbg=CA1biaL~5jw*PDVTR!BzMsZ70N9d2oUlv2wY7C+Ed{{AHlP1a*J2=!eCf8&{;3y0qYjTG zv+_tq@EM3^pjQ!)z{iU^BJ!J~qfvxZl(wljljM8+fX~mPuH5itt5@19LXh|KFa1Uq zsXl2{m**dV5(NzMbt<{Jxiw@a6LTYA=znKP?GZa-s#(ks0H=#tpa@I!bUwfHZ@0|W?`X6GWwV1q8DZ=XrF%kf=*73uRhGqiFd@anfNTB{D7+Dx^ml%BZWx# zu6>N@@Cf&VfWXrv1$#-spB*D2*4<|aaZoZ9u!p7fc8W7O|8EnlY3K;-u>#V7GLvaz zS^J;>$c!WP>jY!Nu*(+Sj2cu2tau;2(*|bu1b|>S{dJXye<0pZj8sVOM+#CtW0L1+CU70%n%tUcUccP_lyF0kQ(0na`5H(XAqQn+9{Me*qP*g zf5|)@1lL>$be4d>Mj_>^0{8F{;EkTBUJ_83CE!A;5!6cbMmz1()}O^7dqyWjE#QgN zK1!?ol-F}K1YYaz%Q-|x*U`rqg!2nv5|NXWQ==v-ypZxvqW8af+!j!;W3u~sP1Ni_ zn4B!PTM7&*WDjXODPh0$=;pyR6?;@x!QkIN0@lFQ@lB7gG{4-HH^TP$EP8ZyoT$K@ zb?XdruopKeC#SApA@ue4%lRBN=nd_;g`ZbyJdW@RZKSVPNeNwu%a;Fb;0-B&UXiFL z+5_l%DOVR2rTIS{l(45?qf_Ims)mJysYp0gNTpw>`=763W8?DF)!81KU0`ZTfe=iE zHyM4E-g;hN`;I0}YF+%zTwj-lGQClJt@pokPFsczdvN=ep~cCJ9i9*TKOkhF?`G3y znIdapytl{hvz3%$d$B#5m;ik)r*dYJPCYm}$~t$=eKsqte70$P0p$Nv>V0pHBML=i zMn;E+B0Du-*I*i-_r}fI61X%31KQ12hEa|U5|faSLRmG(E}Z$JcKf=AfZW{?=Y$@s z6hMsugh1^G@#w+8F54`L?gkv8C5izy@Ymn&&kX`gHxs><<+%lygM+2QD`@Je)t^Nr zYGG((KYU>EyU2JUsPU^@ban?{(*kR#;`HcGc{lWFxNbBo(K_4>4gV_va1^kZ+yI@8 zeMUP-CA!iU#S6e^0S0+y`aD&{tLKmaQ#NjBXvo?Mc&26Q>4XF<{Q1)#*Hae2yT>KS zxoZ|gQ*%nFG`FX5x3ojCK8J@Y#WlT|2Y}4MYyWLd+t}uHAua78Q_t#Z8#%4ZJNpH5 zb8|osmW~bhA>B^Ly1TlO!%oW5^xcs`s?Q;d|6Saq1?LYv7QF_tu>VpDiwW@YO?Eu} zo!i@;EcoZqGO!#?0<^Hbj{yhrsdkL;jlq^{0c@0{g@xj8sQq_~bL~nU_p?>E4d7RW ziEoSo0yGaeK%f+J0Kc4=wzckEr03^=GyUD7hw%Q9f=> zhOA6c693&T7{6Dxrvu56yyhvdp9l%rsA6{<_v}2cU&Fk$qa&9>w5Hb?duJ7styWhj# z)qDV*?z&0vNY+g@l>j`H-`Ar720pqZ%QYqY{z7l9~>w_&OENK zi6i^pWboLGoj-DXR>a_Ar2VNrqjUK{UT^u^3h43+82aBO)k8NK%1cX42x*tXoSAiX zb;IgA!ou#qK!O^_0kaCVJ1b_OMo4;A~#$EMC%pje@aNU1TkKU}WUvbcDP0^xKQM{l0vn}$Nqt!`jU-J{e zE8Ur#@{!7K2Q4Rl*B`nn<4O?W3YaOKys!RVSy?>$eQpBiI^MlViDFPT#t z7)6YYVFKX!i*?olfCY)XgRjs#Qgv7X`l(FXp_p?QgU`!X*vloaR}R%P(xf|w z$MMFOv>pa>V7}xKWT}sDKZ$279}n`rlT#GITW0srj3{IR#-ty}M=CJ2z4K>!<=0Bg zZTtFguvbLlmgkCh_^-@NGS->NEr6euV$>W(zz&fF2#b&x=UT^MzETq=s44}J z!Bz+8f29o#?w^9uu+`+<(j1+fcD6Zh4GVc`8-%_#;4q$NKTzh@UtX3TzHXOPt2Lt{ zej=S*`t%OZyXi;&4mb4w!aTH*ho9dH_!^g<%;XLagYGWD-bKRRn`U6{fYpxOIbS27`&kgUM{mEp=xpgsX~K8iCl?XTR$C9JG>oP90{ukQm1hD)T4C8^5M)R9)StYfRPGH@z3{S9eiT5HTR^~tgT;4^4w~0pHF5B zCHwAd(_{K#Iv{*l{prZZfxHFsjT;&<5UOA2=jE{nmkYNVNj${AJe|6J_M4@Lfmxo~@2#NF zv*O<7sn^qCq8UZOOK3cWk+=4yIJoZw`!ey9XiYhNJr~mcfdbRXdDlDsT+#YzKT-7Q zN5{&~ftm5y#%2wy7@ygy_h}F7!C0bHo}W*H6`hb^s2oOWimczBsk)n~v!FdXlh3qJ zHH2A@gGBz^dgTr<6*-<%aTV$s+F%vgv#Vpt!gNkdDm5G5#;$%LjU!X_(Ux1O^n3VM z*KvLD-{(Rv=chaxrX;N6SbSMISa)1j8c6ZsX`>{uBQKt~SWudCw%#KO{}U3`|NXCa zff-$;mwWkfz-BY)}iq5 zVsvjgJ8UKz=%`$S*ps-A}xVc_msyCr{{1jk`r6@TJnRF$Nj2d z@{<;cy-^sL!nCxbChx`I(=Rgz-ZQ+My|2@l&1gAU_SDv?!(`>UKw-5G2V+ugKakKO zaVm9J^m6-ag!OBqzr=-;1k>k(loz9^3|JTiAvTS&ITAl6HWjaHZj(;nGlFwJ4)9`| zfY~~kigg~^j{eG#e9uiv2m=$Gi?Py>|2+0^l4AU<{;W-$ld0nKO|E;H^6|^-P_6#& zi)v7a-g%Lk7i8gMo)RYof`>ZoS%9frZbYxQk?w#WeskOB2?+N>%zWIavo%bohAWx8 zOK8xN&q>=k*2~aI+Dg>?lCQ))*Bx@lg~-=Y_oroA7?`=y(e1+A(6?=ry3N@;R`=D4 zofYu4^%QtbFWM?0*;6lEH9Z@{+OOMZ0XO3Ji7JXeO59wWFBm`YrCHz*z;t?`bdTqBp6HSOXuL*LD9V;5%cblVIN;hVl99a_U2(6| zOjQV_fiXe&7cDzoEMuMUxa9QmQxrI4hXvYi5zF4aRdh1g@HqO*m~*$Znj*6L;?Eox zo8gQN6A*}o#rSmU#eM3VvEDTF%OS%bqQC}(nCiFQ%5ahIg5s>P@sZ{t2vINr3E1i* z+@H6iK5d*l8#e7-|J{1In|o0*5^{gL&Za!(M|l?^dU}z=muJSGlOi2?tql`!lqdgn zF}-)x!k3+`o|Od`!M}|1FCscE@B?h0avIK`^b`2U5n)VbY)VU!IrP|#5_=fc1|{wp zI~Z4AzIS|;Z-2D(-Wk~|6ma@W*0HI^^X+2DcZJQUz%(p0UR{N*?LABYIXGO{n)^L5 zOT4m#9o~a`1sXan4iVjG?NRQ`!58ak8R(}Gm0@W=V18zH>DUYs1bp|t62m<`$A5-6 zcASWNZ)B_{o`^kGXZl_+DmraQTGYN5^4~Gsz7g{t7u&Gn*7~knB+5r$baqxKfCW=6 zv&TiBbr<-Z7FgI%tHVNSCkKi3{kSDWSbn3SSv8)>Q-0D|dEk+t zf*cG8C5-f&7bYjr+FxWN|B(Unb*++bEx_ zGyRYFyw4t1Qz-pkn7t1eCAegQ!?P^kbGLcs9E!Lp#4Wi*leHZ23*A6r5`9+8(ku8nuN|KHt>rGwe8Q^6T*jRy5D zH8S0-`0{N>y=+uKEUq>z$9yi15&v2(&t^fq39IA@wHC0iVR$`(8pk1QTh4|gZ=5TmC_qaUGnMar)2_JubLhzvfs>XUF zeX?*PozvRle$Y}kIXyoAioXbIXNZw7(>LA@8Tb5_$A+JQp984SDU2Uh3JE(hJog~V zOyy{>cGbnf-l?}u2otGylBm7=i2BaMLfpmAz}5SD8WDg`V^Lz#FaAY33)=lMe&LADlPw~=M^A}na%)MT-O4MnRXcJf60r1dY=bSo+Abt80w z{FOko`CAzbTJF3&7&~B+oTAJct`TGp=a5W{{pTESBAgj;4b=@_c!@aQ13^=hu}X*n zhQsey?WQ$e?NpkTVV-u4O5FWoZXHL8v{P|}ti>$xiA{ye@~)*8OP5Jcv8i@kA_6A3 z+kKWDw3bIgKCV=4vn>*@)Kt!h%PcPhdyZvvT!8l0*YHpY1DG99=p+oF-8o!MhN9zoysEiu zEiyZ8)53!o5jbG;?WouN>pRY( zQ=K9k-Jf}EZ=S2cWjh)yt1B0)Wy&#&Fd-Lt<%uHK^Lz^MApqE?#IhSCqdWzI>?c`A zhF=tu>mP>_RZ{iln4suf8U5S$#lWFW?Cqr)3s4174glP+>flSIL@A|=b!FW>B*{qGP)qBym%Wz0%Y=^sX01rS zKMtZFs;>P}S@@>*luWM#$1#mH@bP)u5wX!XBn|}3S=^7WUM7rKxrtNa7*s7D?2Y)fn6Zp_2RJcqctl{IKjedQ{*S`3Aokix2(JUc?mU2ZGghGZV$$V1|K2i)T z`0$|}SN^pNnZ958X-yb|)KInSTMDT>_%}2wpqwj$p2{^sjAX~d3taFwvgi_D2mFa} z{+-IvV>Hu?Es%!~j>KyhC+NzbI=k!Yc}~DAJBo+)^3>VqKe&Cwvp!FOIGW?ybU%f6 z%z^=*hL9sNR|2jWsYXFK=XM3S6oxlWR&{vdtKm-p$XhQzxy|6I)t)HHx%%|oJ|EYf zoORMd=so`9bD;4VxZTCZu-Qvm85$5D!QW|tF*OXlwq-z6Edt~dXfWflA|tIoys+)Q zzP{I$WUr*)qr_sEs;+SuHgBq4nh?IAH_Q9W`)+E51P3#4HwlVHW8AhSrVXK-=LQWQwzloMt=sAz; zl1+3(UTZ7uJn=tyMb4=FslqCPCmjAtP`S>C6eCW^q0VHE&C|!ToRNsv^LH~oc}5XS zj{qM7pu@Kx|66m1y$L?Ef;$IMjwJl=&M_2(h~cM2JrN|b6ZoIa4PkTD-4{Ytj`F=j zmN+ag86Mw;Aek3@)dTsE-i1Vll%6oNJ1vV_lc@R)B7BZ9Mog8*r*U`;7u5*q9`e?( z!A;kK%vT_bdi;fe0xQo$cwN`lGH1V8%?{HWp8@v~vARBj1_WHYVL%D?VZDG-jea~x zF?w7Ufc^qO3XThLQzx<3qbH!wLK8odx-vRN_rE2<@dOm4%R}jH(h}8I!MJ^Widp@! z){ox|`U%I7;LhT+_txs7P7F96{WZTY9aL2@A0M;HfXmTQ*|{#C7p)><2zO+M6Dt6i z*nA2rv|y28Hd>U4v-&vMw%BA;A>Xd^ZNf_z0Tu%heq3^_5M+!BPJXr06^xaR@(vJz z$Nzos$n~$=DKb#6$JHXj3E&MqWZi(_ca-k{inxN;UQ}9OIiP+q_(; zZBxn*Y>RP(2P|shhPM7@qdboOGEI6+$~AiNDnkNS2Gr{!x%@+*l;h4^3Z1jNBqH!J zz?~`kBGIE7F+7~?Y~Ng=^3CQ`O}@u^xkQUTd_X7L3O(NA0Lc;CvLjHfx3%axd0iol z+t*9wV#XTt6$nJhQA5`@5#UigisIv^lu%Y$V$jj8eZj2+k8AV zHX57b7&!69g|Db>d%R9D&>{^pYJk5Mv|_3&uI%;<=!yh6>XCmSlhB*9=8RPx*+uk& z$^S;YKmx75Hu!f%A~iOJiQ)gU88AisHjrTjv8dN$H|nr^*F)H-kr;MRga=vx@LxF& zh$$BK58dUy^u5Ms7RTN`|B(d*126m${dzArEr+VBLFR`t+~Q{2N<##lj%d|?JXDZ% zM9OVEhw}jH5srx0A1&|Yc^^r?Qw0S}y{7%ZEje5RT#^~n&`wiejuy>z2I~{y{~>r! zwPV^xcbV59bVXavb_$_6mfVG-s}U4%z&gh$6(r$Ii#>5e@znJps|jxF2{?KEOrrdA z_qaO1Bo(!AH-lo#s z*=^1W$At=6I`-!|F#9_@gvqxSehvy~J!ZSjWo-I|jxsY2g>^qYeCDc|b(aXnV53e( z*8gyq`fg!%<-9()C!8Z^x@DP@nN>M=!H3psvJ1m&+oN0ulk@l``_>EPyR+?z8C_9< z-skT?f%>%m;KomkG69ZceP0dR%BfSoEo7Xgx0ycBxHg4&rk6l}Sfc|^|;VlkvIo59TRa?as z^~9CO9E7S6qP$-yL1}ZpGz{|?f$cX#vmfT@OZPl;S7V)O*gP0Zn&+<^|TAsB+BcG~{C0A{}}sqm6n~dU+o8OQp~DIc3_` zKHn#Ifr=Q3!OQDS|FrzH5-JPdP)pLt@-}kdz2Y^wzOF8Q#W*%y=7EO?_d>nzB`pr2 z5gfw6WgYzl+C3&jnItXj+3$V~n76^_CnGD3E@IYB&p!ky{($OWT>H5hy0;bmJ4sCl zG}1>QyjSqOY%tX{-Fjwp=SjI=ju+sJPrBPtRZ{k0KlTWt+&&cI(zz}9(cJoQwv|PQ zVu_r&U}8eg@LdUmTka508M5u0sYS;!!YV_($U3+n4DBq%rhioir-Q?hX4ZF^{tObx zLG2L;KkH2!Mob$g*l8MwPoX(k52i%Y?Eq*0zk<9QBoq~ zrndRL;a(6Gh;HoFtGr&q=nZ@_&)6vU+b#uu{X92qbx9^bDoF^~*Zdgeb)>f?$bhsE zMly6w?x_JF%vRegTG40|%`c@MT7_V9{9`537WDk|RS8$5*~wjhVVzH_h9Lj1=o~sZ zD&-S&uoT&?NntR!7S2rkLbeAk=5Z+#u9ryo;u=H|%%|5I)^MFt>M}-xJxVT96{fA$Dcyo(jXNxikw<9nyoX{dWl zE+~yo5EE8RwA_G?1#*?`dT+M5cgC;kceNNX?yO+HHlKLZDBHiArUbGkX{pwy5oGN5 zX*g*8YUh34mZOr@9jhiqcaH0}Dc88Z1eFnTkRRh9CBpUzMw;D-M|VjtPaB_3m^s#E zr3iIj2n0)Xw7HHU1s`4cFj6|4F&GGb97t6M#NP63@A`}!y6k^q_jFSu=Iov~#2{UO z7&aislyF+9@1vNsoLVn~iaN+q#RYK0M{BIgAP9Se2oi5FpDu(UH~e~AF}}k;OODpN zwKjP!tC&}2yfSav3k2cnorrYI3EmfQYT6%eTXlObyzPr@A(j+y%+AGZiVl%EXMO^} zRV%<`NWkYsZ&&2nNm-CeQtD-0Gyj=gv*gy9CugY+*7G~XYs{A~5rBFV$20G(LK#gV zVOuQX1y){(cNj$b3SM@Jhkjn^jY@9Gbv1f;bESSrfJq_zxPPTLNXIa;-9%){3+Uufj0!6o?jU^6fJd!NEqg&Q_5Mpct7oHm(#V!}2-^8v5Jp-0l9$J&iyNOS zC$57AH(_bGb{RLT_DIn(w3JRyw^hxG^%;;iN~?J+ZrEL+p>18u^I%)zKBdB*7WKBz zp+NU-;JB`Sc$)5;;4=GiUf_6=cpQm+w zz12IJvX`u((k^k-gp7zlnnSPvAQre=ozORWF`l1pS$AgVfSLaLQ(^hD$P^oCJ6=8D z^yytwte&45CA&B_S1CEMT-r@4Dx(mwdM;~iH5(ebz13SBLoS&}T7Fwh#qv@JWI|P$Nb$E=@?CUCD7Q6c03BQBR+1j{9fV?{bN8)hJodGZ5DW zv*Q>iv45mC{)d=`dYrK}t`8J}-Z&b(mUipL$(#gpr$&i*R@@#s70I;D3&Aidd#o;L zE{&#fbeuN5X@#JiuH>^VP+J<4z4#jZUUaa+U5hIEBkyc8bgJeX)zKM>mYf9`RN!ztOcZ1WVnKi1W&SUhKriq`!ZsG#c#oz#|= zp4)IuM+(fz^VI15Om@kgT-;HyD<-PydJhDH3NI4dlmb?loRFYvJ`w3mF44C%iEnEw zg9Q_@cAB8>>uBBHI_+QOW!7x<-X^%WzF0~SlXT4}s)@9X=!pr^3wi`Xw1PrYPOiwX zH)&&R8Qi+gymy!+v|JiUa8H#4!?N*-P|an+pj|&x6ZI6r^37VpX=SfEmLfX z0SM{A1W>QsXx=mrt*o9Ot>owOt~)T*uQNRf04E@HaWqClo*xo0n3s@t`93|^qhM91 zY|iq|SUQp+wT>DzcAlJSb0VDzT$5-C(#aNS=|Nu*f-3!<$v(BRyIGr$VQYT=Hcy#F z?f$gb_;)LTYYsb@^fiZ>m+pV&6wo}mr@762L&%dV&lT@FwL6AxS}qAcjv(goJ4Nt@ z`BsO-Qj{I08Cwxf? z0$?K8U8B&zKqMg`q=zW2-ecp#W;JZvw>R-AayFoKI~KHC2^9rs-hM#g+CH|GA6LF+ ze&6*eG+}0M(CEaqt@>3|!eMg)9<3w@N$1$uz^(h&o}L*GS``N^4!sns#$@}n)6k$7 zA${UFk0yX(8t=j$io2qZ|$=ksOC$uea^qV<1cQ`ZKLm@m_{ZyC(R)^%)Mnm{m5X!%RypZAC)q07l zvJ1_&3=oLWG#iTodZ3gBYf^Ijf&mWgc_yf`c~uEF&$yil1k z9XiW@UBsPqz9WHw$!!34mArSfr}G;(ozM~rai=x^B~_?{+oD2gqEK>3sQ&I3{|c-f zt1mTFntJH@)1vO@B!pDg7`l!w>;_yy23+Dfs2C;OIHWJh;RfPe?hQ;l|5ZTakLL6D z&>;4BcJbUIBPM*sewof>@{bf}Dz!fUN{INxvDrMyi@F8L*}C0lT|EMf?4C!^%0x0^ z)&=HKWK5jE_ia{M;BB#EvU^vx!bT1ec^Np93DXz#S0<8NKa{qIL%*$(VG_P1#?>Og z&*SO9&iI}7MJ=bes~Hs=aa7*t(eqWyUh`y^o@&Y2Sns9%Zys?oQVszD_j(QVq?b>q zgEUsiFl1yroRo@*$?cw+HJ)VGu%!)lZN33N9yD$Sk>N#@-Uz+9R_!DVjTGRTCnMo} zLR}+sfP_E&1#WhlkZdEpAtu+_+)@bI{+suPmap`SvGb|vIU6_u&}C)V*;q+`CZCi$ zmsf|prZeyL{NfC!P(`hg0CAO6NU}Y9yDh^&ll5CWFHk zjhkK*)+L^Y_wm*HbOo<%@71tEZVarnJFMM=gqig>B5hc3E7iRSk4=l=m0nwN3inD* zepjx%luMmYl?vM(0kaPqrdH;!1xgf|k8tCtlNMo=0A`7V_SCR9xp_m@2A-U|X` zCB-$uSKL}^8|kkyX7F{yV_&OCGJs4~QSaln{Zq5ck7YlzQ~bS6Rf!YAxoZ1l8`-?k zz;6W6B)?E!VgF(r1$n1mQ)6b$xe|xdqJv$i!_k$;+{gaB8RzmN-#nDV57Srp{yBoF z8a~|T<5_;VB#en|A0a=EKJ0cf0#unx#@(_zGv*HO|FX^&nb|2n<4MyAAbBL^+k`K#c@HFMPiDQNz|>NA9)mkZTh zg$1IcQI7hjBX7Z7d*#6;In~=UJ;$%x)rTWhSpYX&dj`0-m~X8RvqWZfdSE!HZQ7TwY_${#A%z^rT`}jHYSiGqFK>yHhJqd zyUS|*>d~N@*`XD)cm*l~pw41rQ;V-_Q`IvO$F;fD^SjZ^vYOVwosL#vtWiBoDs7fc zC}7xv!|LeHrq#8MACYGMId$RmnI`? zLQeSH24P+{wWS&gojNV=$2AxQHmWUX=@ei+>9{OA&UT@FsZodg<5Dq#zUc*=C2C-7mJ5vXMMa?5_CE|8EXXi_JmLeCTsH zxPz-oO5wRHX7gcEymaQ{=E_6WHyrcKLW!3Bq-r1=m!R`GiEe)i#D$N z&4z}CT#r_zyumEf^~EMP*S{05Q_|BDdU~Xav`Pwsr8Z8yQ!+9VM+44ls;x&eJ~a>n z&r11e_)uWT9y0F2ewnquMwQF-+zV-YFh})$n0SA!B*5+F`i3bpA(#YaPz#Uea?;>QYAwGwxhISZ7FN< z^t=m(ztF{izqtf8y|(M?-HM@;QBhW~qf%Te4mmcm;ki}7 z_20sckI3X~IK59gqbJb`+BWcxv>4O`FQR8dgHztjzfqQ4F{1_Tc_RYYMV=06Jm@_o z7Uw!~V`UmjMN@Wv7c3{0hxYLmWZ1^v_bEUqLX4)xkWhY5A))K>xr9Z}qhLC`0gOn0 z7xGpEQ|YJsn5EQyX9(LPbb`0@r(jo6)XJ$E`+^Bs*2Xi)05u~fK~mzd!}v_ohwe9c zJ?Y;$3I*Iw*}@`@P@a}}JIwofXJ-;%x_^BO5)H$t5)yE_6DE3VY78y>7|%k?LeB5n zjU)}`ZsR~LdF{U0z`C|Fe1D^d1)fc7k6>>F-2B7}5X(XV_ET}1xi$&9^$j>AveK*A zUFhh#e{T$%VXUy5PzGG`k8jJP+30F1UVoKGkbnrAq9+Vv2dab13($LA5O0GXj2Yt( zUi~(l6Ysp}f3f7c;MkBNYsA)CiU9w_jLjd+?%LTRY4$__K2B$RJ?66mApFuX!+#B|roR^y!xnMq5js64sK%op zUMs5=8r`m}UksU22MpZR2m~{ESC;EJG+BP`&=?A9x9Oxts|kcmQMLV4aoK+N6xfc_ z%7)YR7vb{{?I&cM<9#hKuvPM??=tmf{o6 zSb$3IGOF>FRc4md<#4^aA6+@zI<*aM_;-}4d1b66Xt&bFT1eX-15Xn@gHJw z^DsdYK!a@}0?i&^janXWwv_K*F#Q+y4Mg(+WRb%soF^3tok>;#@h%c!aiN6*w;wDm zd%oK}%`~b=aBtrBHvIP6XrC{-Xn4%AzKc5BYkUFD>RM-{MhVcK;-vN#9yy)|h5h*J zY5)21unTKwadAUkQT3Tbg`}gv) zac{nEh-uzftpd98r_FCu^k%Q-m>}?o!1a7@PUIf{F8AZ=-32W3)Cv$Q-j20FO`6Xi zQtRNbXTalZDHxGtA!a(Su~QXgXoEnLJan}>#aSLjSt#fU8N`rV&o zCJFv)>_Uo&kB`@H?$E-u)A0hiev|7)IFVM#*4h}`K*mS8l4JMrAaz6q3^iUH#Y(-60LzhLHellmcHP0%Yr3nU)8`Q)*vx z3R_+?ry51v73_3`kA}KOD$csUX(VO^Bq}*(Q>AVD%Fb$mH`<>=eB!G?E(*CGman)8UB1+!* zdlHv9Ol@V|1<4~e-Q8TIy>}uEt96!s=+zS!0IN<;3xLk(oEnLWZS?W^6W&`t54yo| z8W#QS>53?E_v$gfc5F@h^H~O(R^KM@lHMc-1PRbzg{o(qxPtCl&jN_qVNB+cLgl-< zMDN!`d9vp!wY6!bLb{>Gfx>@BW;w`$D>Q3>APaE%X}7ImGl+z;Qn)%(>iZ*~Z9+@O zweg4YE-|0jWlC{z%l*A8q9omnCxCt-&ezCS7#S!yleV6oJbx7d;L}c0MY2?wxOHk* zHn%}5amiU``}PIbRHq6hJcP!Wyw!#WjA4Cph3mdxaIZE0bxrbIc41*R5d}em#=Om+ z?|-xTcvSxl#crvLU-(;5NbP{p0rqz1r}nTGP9mc?@Q5c-<%3Lf{K0@&S6_uz{`EP3 z6^p^Q;q|=6jcUF%^MV7j;cBJ%XCquv&2xNiRPK(zhE7l4ij z&(Byo>NE@>cNp#A!!ejAcKHRzwJ-g{AeOP8YKH$metx&j(;29xuHVq7 zDN|)+iPvi!>qOwYwZom2E8M8S-z;NnGGSK@OqiFBQbwz$k=M~v?|W|y{80nwT}Y=H z;Zry}=n8!SXYlK;O?kXewpEz$DF^a!7pW$lHSL%1`Gw}_BMb^p#H7lT6*kfPY zgmT<|FzJL@Bmt^TZ^2B;ERs*Wggp6bZ{R3|9&~;u>baM!>(q3G&wsYZd3<{^8llu& zEnArkM$w5u@7<3vXB1#8s;a6A{z*1|{*+T%0C=`HqwmZ;gcWe;@BQy;Yh$X&I!~{x z{Ej_>JY0r)f`o?FPu+6a4-hl{HnahM5w3j*_6c7sGJaTeDJ;_SyZ%ejUsBP|kg?Un zdn1lLw)nhAVvZRMX53HA@7kDP0^844Fd$#f=TcKs<$OR(U_mCvQuOpE*1+i9;^F9Jd5Dty?Wry3FU`(TKSTsKCnW`xd{5Y-N%`a8n*#}3V1lWW zttMd^XKbw+=w<~s!k|viF};Cl=;&@(CpZle9m#h|W}+EN0rIdceF4;7u;oPDGWyjeT5l4xmp+)(MJu=iaX^2fC-5OtM7a+ul|m7VqWx;ID4Y{p8rf^VriMbR&A-P zj_QcU#zw;@qt>@OT5Zb+S_sQ~d!>FW&y8dA^Yim~pR%=|=CclfR>M~~Wc4XqAEMva zJ7%IuJ*p1Uv|YYTK$al!E3i+o@i(xw9ZlOCv9Qs=`;+4ddDho$M4(64LO;ESqum+$ zKk(SOX3kRp;E#{Zr~sLf&;9Kc_!*tP6QQeI(Q9h{5C06{Ynr7MAM;KDKl1_=^QF0N z;zB@5hBkN*es{aZN&m58fTympkv{M;EU>Y>5FemSuPt<=v)6lw@2&?zKr;jB>GZjL z`7Ui86%h0y|L(6+xr2d#`^52R1r^?TzzFp+eX3kYkNl+b8*x10RNu;H2z>Os| zzYGa{X?tv3^x;I5_wH|yYR;X=-$E3?7CI-%j^c3bOXUF`PckKnKzlu?#e#caYq^>) zp9CU{$q*MhlxQHNj_l?MP^4N>L3T{4fF_Lll>;;xh|#^D0R}PJ#{QZa-kTt@B5tR) zLpsjNoF8}iNfC6HEUC+Ui5(>RK~M(Bk~#lqa^NZ22OT4y?dlk?gilq#?DEq+)`w$O zOL*`IL;?pyFdL33eZn1X2gOZYLF7R8@`t$nX6axsfl(WTDa-o#^XIf0NI+iZ!Of>H zGGy?~k`n=IZ%+rKXdr0BcC?!=!wbY;;Cd6REz0;V?8^h6`7~v#B8@8=WF=%TNojhk zOQS!`p-X^LQ2PjM7%N|n1YoWv0=uVkH{j5)(M#|gCa;AHy^6XUGaqbn@vYhlIi z|D*cpWU3v}P)^vZ36S(;cf8W_!F@NOo)uXpgQd+J7XH!SST%A2SxY3&t)^gdkYLZ=9ut)SkQW4Uu=1*2u>G6ZeXz#L z)2ME*3|<1HYEolQQr>})Cjfd1)X0|^luGrQ*#8LGQC|R&BEE`MkxB%A_?(eZ9WmLk z;gsUIVAQy~Sghu52nI34BpGo=zw^g~+qm%W3~J2#6FeLDvN9EP*v#Gg1_f0;|KqD~k55DgUpA{+JEmRlmWe9ei*DW?N~(H&m`C^*PK=-`s~O?Lr$6<6`gmk5nkxojdM4C(^(m z*w1;R#zd(H$viA`1$g$Cpun_M&0HSOQ-DpPed?$@N4hg0w(lUZMq5FdaaVYf(wtAIPC)hSAM5kl_?s8E&LP9c7as7*Cc2^F_9L1S$4bMIVoGeP|>J;_&95#8cEk4yn z%ZzC{ys8{_BQ3Mnu*oX_DFEsSkw4Ci|BOq#D3S;c>Ha3CZ6~f`5Dv5FToHc;s=c`M zcEAtQ42Zk>nz`%gAM&QP*+|m4j(3;`FyU!I{EWiVWD^W(jxg+*CLbPpJpPAp$0Rzk<+E&!&U}3d;-x+ z=*YP+f5T}r{WG)M#b+*FbDk1h^Zdnj4jp@o-oL(U?Re;72Eyb|o@ql?>*Z;fL&yh0EFZ{PRL_-@^jAztxmdJ|%cus*vIjifepFc2 zi9B7S8WDGsaU;!py8lyOAF(%jtg4B;?2vJ@dgPG(n#f+rhpB}~(le47)57p`!`Jin zXsxZlVaUx$zQ16I*WE`*ItFTHboAMWdzKkH6`L{+yicI!BK-OYPo6D<4tjSln3!M1l4E0`}i;-cjAe_2;NexYEQ&JrM0+C0uCk>X)f zn-LShVx)BdEpKCRy~DP)4;j}@501tw)y*uWShCbr?U3;osjbK@(~O#jvHAToP~RX- z-`O93xwL%k?aOVj=$DoQOrnQ6apjmUt-!K~Sr}J#Wt4Z%AQoL7Ft=Ojrh`Ok*Yk-| zq1&m5<1zKofu1?3H=hDC#m^G=dgzgMfuN{ql3iL$P#uArq_$;q|wYQT#5#ZdY?YlZjP+!B$7Z<2;+C7JLd;zq0#bqZFOKik!yb&k{`8}4fc+d~P$ z>{NR)=-LT1|JiPDbGL45u49YvFzp`7E?d zWef?jQnbrMmsOP}o?T-aYgeZ!|IAU5xIwQk8z!r#)n3=;GWJ85nma4QSm#^ldtIZV z{(nbtw{6HG^b{QjJXe$NGQWj&JOq0OaM6&{c9EAN0_Db+-hDq=)T|rNzgObawKo&> z?ttO)bA3V1Rkkro_9TxD0?~W*^QIVw6@+a8aciW zibPOI-+NtCZZKw-dG0-;vUA=ycYCD8NGKTQTv5*1{Ll=`@>r6Nfhv=fsWVEsHI*tT zB`R`nO-xi9wEIbcHMEvKrIa1_fBM$j=wb#%A8Ou#hWY*j9~0yV_5@TyONHw(M02Y% zwq7ti3ND(ZlMYg2lYUM>?2P*JE*%D@hv%lT@8R0{)*e(55D0jr`R>CZKrYjT})p_Y8&(3+L5Zc+c$3Bkt12IKLvel%D^e~>78 zdE9l}jb+Rkzm1MqmWl*K7WkC8_19?8C(6Cnr&~mrFyj+cHSeorI?@_d2Q#z@%f(c? zx^TiU`wt!lLtMAMlH_fAgW(4Z*VRkal<)Luhe{!tswhw%oKX~ifB6ahg*Gp*2A0sb zN|+QGU)TC#K5y`GaBIar>?Pel9CzOQi_1TB7C{eG1{3N{+^2WyqK67{QDQkj!L7PN zV$bvY#S+!7Arde<$EDD^due)=ZhV=>kvJ+xn;CB&<`nyuO7xg&Rf%9U#}MpikJA@3 zpzoX+FHgWaR3Un)`X)4RmsR*YkLdpUdQ!UII2w>U^m4iDYXGS5>Fw@*0Ckz5>I&b> z{Xb0FcCH7^?=~7lZ{EDIFi83S{aZ>3qSXRncq+MoYSeF9f1AxL++%)z-}CRcgM|Tu zr@6PQ)XFjaDAaW(nJvK)e$5~W> zy5p=SPy$U?+w)+F8Qu{WHSwC4R4~?x^Os1OGo4)^W=kOXTJ{*|Ellh>5U_aLYkYl) zyv|9{PUpuwb<@xUW?z~6FNAku;Ni>>nuvj#4O>gx!6yHT{j_H zv?eF{-<%I(AYaI1U$$Yn%y@WV)Fe+lT^Kj6_x3mJzCa#_D6R<6) zyIy>@6xy!mEx^rB^)7nR#;W3e5s@M;X3@8#NK5hU#7_Yr%eQX6c@&c_X(;M-6?116 zTQBaH>QYJ)XvmG#>jsvjJ}?j_alEPgyI=JF^xgWe zgs0vsXOjhmtDa_=D1bB}49g4|^w8z}T0>-YQq-t`zD^o?-69SIbz9afkSx+0hqXZp zAGRkczNR_NQpc;SyP7;?EO$v(P-#^$%^~*nq>Ioe#?So;qwP+arAtBN{nEijclN<3 zzPD#`B`gnG+P<(->&of3T#XzpYjyCV$~8YK(8({U(bwr)mJvo~2XB7BzdCvy_xk#l zUGMg9Lu(vv+gacFPZDqELnsnpH?VdZw+Q)FC-YfFXIp^Kl*EtDMwi=6IuFg*|CQ`D z#OF5EW$+PV&1G`LC)Py}XLU@eoD+wb1J;nrBJvDO=m$2ndRRO;{h*XF2C zXG;a<2aao2l-;ye`iXW`B0g##ct~3I&P`f{Z(jhFZ69-_6$E@wI!?Y^3}p;dFTU>z znCz}sMuNE~j@6?RNZUsj$CpBzx(k((O9TzfXY}|rAf{692bVkrI zx4i~45dcUSanTmebp)oa5F03xE|y&-J@@A<$y<_4kWPvbY0`} z`tom!Gke#XebmoD)ImuPfGt&F0atV*6B@)e#2E)=XvhLBb1}S#9!Kx*a%C(O-WRe> zciU8Rb-vopI8HXN^=Vm~x*CQc>VoyQ8jQ#1lga~vqkhJTEZ?!P1pEdP=qxS<7S)30*TAzlNaY|9Nt240r>!*e+RRq=ujUxSKlpi|o^gW1uI2yvja%a_J z)CV&GP25$PLeeRFeDesoNw~_N0()ZMO)&*ArmWM!(3}<18IO(Ji$#W(4~l8U{TanL z{-fw6Z8J>W&i)Um-v1tMsZ;##xNE@EiNL#OMe(~DXdq9Z2>1)@KT^KacO5ExsBbDW z-0VcPH4tGG_mJvVf>Gl5W2PT1kzf)Sw|j|a+}u(O{yhkgCr*$%ZPiWm#)SFo_#Pia z+nJDICn8NJ1R`{6$WF7T>|DtnA4X-?Zk2yMhjx&Z9_6mha_|dVLiz*G^p%IM7~NI{ zLjG1%R0KkV5IeVzc}rDlz0BVg=Q%s$=F*^GJ=1$*mMab##Gh#`iQE&++ExatNX`e- z`>^D_s~6Q_l4J$;TYU;K?Il+$+U=CE`-+F6;r65EMzDGkP0Gd1Tiw!{C9K2E!&E%u zYAui|=Aq#sflEtE;GeT)b=;hZG8TaK@JDLAA;uZ5bf zF1?4+h7#BM|Moq|k>B0H5>8l!K1}!Z?}<7U<28uv)_^zky!M7-+`#oAxiXQ)ju55( zYQuzPu-e3Q0W5x`gq!d5Lvce#S7>mT$jhhUdii`Z85I5gycAzk?tF*W&Nvoz8Fv?# z!iJF;(!8%mJD3XZFO5^thXx#sh^$-6Eh416PZk1S!)`(XX#?6$hIq=BYutK%KuuCF z+xhZkU1%>n41@#EBmHhVu2?Nr_G{waB>cLAM13%<(_>d>8W>EFKTZ!k2X#TH(cEeF zo|*3VBLB^Udd+J-fzvC0@a*5Rm|wrg4AvTU+H%+eg6`+Nl-TWWtkMzdpMb^@9(iY- zQAnL#UBkU5So2)zs1GK{AlC=n zd}Q$woI1G|AZ;_OSWoL^(2?%F(wtQnXbZxuDJ-~^cpHmDPk&Nh2Ao~-y!x*l&+ z4Onolo(U{AaXA(*^Q=x6AD=!Pb;y+qy}7vCaqd4oO6zgvT0d=|h~MhCtWadDuh%wX zif8SYf4E2*{zR?*F1oJ#Ljk7T+OM=F=aWiXNHut^Wpqj4=~+bz+sri8Var8wd)ykb z)%q$U>mXy`Jsu~`;(S(EHfBis>3qqG+eKhXMmUMfpl|o@1ezbr6(nZ<%-0vGa~;}B z&pjYZZCES(4{f+sXBE_JNrIVJ?-B87swdU1LwBvi& ztt0>Wk@CRu!ofFvid%2*i>-iZp^LPEb+_YdYU8B*lLH5@mJO?4TQ0>Va>`{k>B;m% zEB-CR*xTZS7NIO)p&ZY(8vX4p;c(cqSZq)XzgdjA(EGS5ZMgylU2HG@Eg|YRoVT^y z%~BCjP%P}UlJ4r%Q~$o~T`7cURtTRGaMXiWH;B0GHZNc;`p3-9D<3hCWXIRH-kru* zAI{88_YJQr07;2FRK)1>=IUe;v7>$$`qf_d%x(dzyP(^VsU-ZxyXJ z>q)h>$_)~E@2|fqQb1Wg5YX?Q{DTBG2u4#R8qjjac63oPx<3Iz8V0NpYaFTmqv&eR zqpNw%Vz7i-Fya}*m0tkm?u>}n%?IVs-O)3?7B|Yocs$0ABP5O*(JPn6R+cC`&Ek`q zT_*^)FQxmwmT$G z*1&4qOj(!Tb9g?IDY)Nqd47?}CNa#ric80vzP@31N~y}Ie14F@6WA<%i&s!sh+L@! zUXU`kJv&n!9Cecmz%Rd%t@eZ2?Qiftdh{jT6Suy6`SPS}DNap<_4l*8i+Y;Fo#acSj112BYLFG1U)^`yW)8dg)`_WNDGB!4Bi`7~&t_`)G_pf)S^5Zms2~B|wFnNQ z8N$cE1xYSCcN4XeEIG>+l+Q&hbPg)IVl3kaI9WQ*i;LLYUHSQMy4R}TpVxu{mB&^? zdp+Gclf4FeA!V4GqR-O+C)}1T*y!(4NVS{w^0)b0gWV<< z@UD})OLN!zEA==slK8_$Sm`7;ZIP^Xt_NZDwknq3Xk&jT*{~mL#WmQg2Tb~wazlJ7t za?4Njck@!oSP-;gkDKnn6!JlXvazQ?o|2rEI^rwPJ^pQLPPgHTp%e=DbnAHh>3pJI zHm#o7h2gbw-70pizS1CTnW|kT=kMKQF5|@Y>Eilrw$QatVKS!H7EL=X!(<`}nqq69wJy-(G84$b1 z2x@R(_QZ2f$zNVV*=ZhPW66qL{Zf}M<)PQU=kS(}*q!~I>+)dE7N zl6!)f3tP>|6I9SYl=JIuL#s)ut`{ssJn~gE{~GEI^o z*jol5k^4(^*_h-DEiDJunP(1W=?3TIAGb*Ib<<`FjeR!137ln-KiU4sibR1z$i2{T z^8u(ATQZF>O#@V);F5VuI#X&Cy@mG@$~!wI-#}M`IB_2FY`zwdHjs27o>v;hyyg?noKK zCltv7`M4~U*oy-jlNT-v;T}EokItyWM>e;M0r@sFFYQhX^GCwuMj!r6ss5cF{=^0K zUw{%8Fd(FQ49ZqqY_;Wt&@*X%*h_tCA8#1fJ@O<=_#xoPQBt|wd4}F_o&Sg4)P>P; zhg|6f`zjI3P4}YmqsN7ss&{5^GG0DC3xUq6W`J0hz{z#3H2xB)e3=Sqky;DF7nZno zW;VKTUbSp|KCjmQwmJTb2l~e{yVx~|6N`56a)cqrh0C{h9Edppy#H5(dVg!?)|f>I)1WK-$O z@{;&7nCi7*_7yBaG_YpmWCy_yE@#wGljLpk0g9q%Dy4HJ$VO629G_2E+thl@eaa-`>$^>AwBTRePL8Er{- z?eN)w81J`agJQ?@4u?L(x_?wn%9GjKBUfn3n0>=y=z3>ur?QLo#{_n-O(wl#)@_AZ zYJUf3r?&0R*X}G`+&5UHL5pMS!f$(X=j-#SU@hi~d@l~$Eq{$IZ@K1!j;dR4-HVRD z<}WQFZ#;FGM;-YtQ^(iqp3l4d9y;sj0Ga%~2>AC7uC1MCG|W%dlaDf9l3Z(crmBkJ z>J{Mjm|^l%G?+mYgNYz2l`Q@VF?=lJonV9`1tvE1e8tNwRxSBuazkaj1pyHZ>6f2R zW`ADsUX0NKdJW38nhXU0G^WV+K=CA5SMcHz5hZH3I3+{y`~A89qYLHB&VBMUI_!B4 z4C7S?n6B%Ij&nyyu4iI=(J1To=GN5J_~_DRDEuw-?Pq)11J(&7d@l^HyL7M8)%nv2 zs;2?oJ;fvgI|~E6bGDw_m&;Qo#&USP!rc=d9jNW9ChO6ouaxq$-?kcFPA7)*6Ll8c zuuHqt?oU>3oq4HcSz9(Ela6|U@y^+0w{luK#lv{k>1;jwPrI?1&w2PIgu7bz5buc@MdE1rIvU zAQu+u>9iMDX_LB?qWEBUW)J?>|5rEAHPrp{lb6>*@1g0wx%(E?!OflzKX{mlnTtoFO&Kg z{%TeM5G{^->pU`{kf5*1 z+S`isQ(h@=h%IUJxF8*_x65ch%4$$O)YFvxQEzR{sZn(e;YWj!7a}n|^{}Q&U-9a>Ccesd*cJ6^RDYR&9fzv~2cZyC;-t)QFPb7DZ}@6j59ExLr<}7oiNzkR z7JYm9lVgT!$~o_Id#0np%ZH9hqy8sfYJiAv8?N{+>M-gqUl~HtH~cb#EESqy*VBBy zT>FSryhKa0nvQIeGPM;MIMJo3v{VVKJ8D6in0;8Uvh!%#;rlkPnc zwgneljp=1pj-dustLx#wE6EX?72yRFyg$d3kFB?y&I}T}XL^apN>LGBvs({%Ewt;F z6@U4G9a5w(7FLmrpRB8wOsySC^XzHRzrjA|rLrZpZD9C>CGO9b54y?Z4BN&_OIq@u z`Mt4J&t!}=cvm{fo1bxry%&)AXby1@xwVW@dWKR`pHfm&r)FpU0dB9TF}tK_ZtMDx zL=3<)*eWR{NB)eC`Vbsz)8D4kEIRrwJt{t)xcSd{NCYr)v(~;KQ(k}o4H{nfS!EEw znUqjQz|okHwdFxR@pd^v2ecS)ReRCQOis-%`S1kEZmgkFjI-&Ul zY$ukf(&`2uzaDvS>{dK8Eupl-C^3r4*bDzlr{*CsRA3o39`U|kSnwf+|D<|6<%`qL zm|_(Cb1X<-nTG$ja!4Gn?rrc}n4Sg)^4le9qb#H-Ju`7-(Id25CeaRIskIkf{o30F zyi|I?oZakG2YZyhysG4)4%gSkzf4L_w8`FTI)}UY0Oa& zxpDF%^8n?SDN9_~m1MjLf?19p5R-j^kG)Df=-Y^tXsDZhM9+HViNM)kT{OiNG3nov zutT%Ynbju)xa4>l&h~%ojq@whcX+Zy=_u6PMIJ9fMs>$rx)a40@ zoj4Kpma0J@E{X6?iLenNOR#Q1J+^xKDv8A!(d6CwK$+In&M_pb{A7IkoOfoSljPu4 zX!Ui+F*M-xua@n98%((?r#E}vtKmeAq(nf zp^s*xf@bfg-GAJ$u)JtA_VA{S+%SFng1A+@`eal87ridtSVP8LRU@QIpjYlx%yf-C2R3v*9Ua? z5!F;9MM*l!2QMpAP^K5UsK$@ljlt>jz`uzlfA2=kuG4to&9m0sd8FjqjlNT1()Gjz znaM&{DY>g87#36*(tqxJH59YmabSYs@({RB11~JTX7aA6~xJPm?e@ z8K87yX)&CAO}G@1noXdES0;B`t{FaOmL}-%H0FDRC8qo_-VW1MZ9HQ0_HtM1O;p_3 zVsQzUVtr^JNeq%xlZ}UNsIM>a_U>VFQA&O0l7b-stgCK5QTDZVe0v~crvzkV@mHhl zYim;CODJbwL=lvAi{h;RG^aTwol^%?wGNFMO_#@rU~J=Sfkqkk0ijxgUdW49AZ#}8 zPyv_H z;&$P!AWERVGKlp{7*=&(aV$2aVzR z2jRAq6S9Ec4}OI*P_5AU>+8>`|C4pdU(y(P93vT@I<{QbbvM3@km3u%)*RHp`TeJg z&GFRo?oK`CmVhA~#(9D311o=3tui;NUy_XIh66eQt8?7PUBW$2cnz5r-NiiCU5I2O zX#vPk@XP(#`??}8!})6yf@?FWFW=&+mX!hv`Gsr_3v8J%3Z>jM`(;n>2)jFH9jDMg z!9guF9Yfp_aRf{v#&K2rtdo?X&4fogkCl1VmL-cv`o0)mj1GGVZZ9pBYC{1C$AU1V zN{ca==oS!T3QP2#-K6Bp$#ejk%xo^jw)A7mayZ9H)c=l>ONemx5SGe z;wyKXG?qNIY<(G@_o^mf({92Mm=Pvk{*x@NubLdtvLXEV*`oF9siH9sAsW{OC)(OP z2Tpg6`GMuJVpk78JM)$aqZ!B2e55!SV)(cRjV>Lpx!Bi75t2a53H91W;-|CoU-qrS zyBlk%)XyfK~p?EtXz$g)1Q!_XCPC@Cmfc>uObJ#4%%Qc;l~v%rTrCSfE0INA>mvI@9dtV2TpJN zC9C5_U7?wCMN_sXB#wAHK4^}8oAGuG7il;!;@D(#{Kq{;<%=5 z-k|Rv5@j)^T0^S+n)$&`hF-a|fH^m5m|x+Sobgpj2(-=cQZR<-Z3LAqnj4~S-}b3Z zh?9~^Qtlro2*s?uZ+p~t*0YT$_NX~7z9x?VjTAPG%H*}sIC2@>!N9Qzh5-Co67Fln zIbxp}i*pH9l`P!d5p7Q&{7Rf6xl(P{2A#vv^{zK}0-gs4y<4ets8jxV!Wzi`aC^nlSplE0#!H9(0;KLbw;#y(>`@nn(9-DpoGll8Z#O~rd+T0`NTw3f zD^TBg_M!Yr!kQ}3Q{tL%Kd@X2S7U48Sde^R_^Oe@6{HdTopDt|GQFYw4)$dIp%sCO zlQDU@#p>Z7lDhk;u?9-nP)Mul+Inxem|h*e5>DC?%WP~#{IOfbru_03Pf@?BzRj;* zHSR=&?t2*ErxT`=v&L;?6*Mi-j;9!_D_!$R^~kJa48acyPSVDtrM$ahBTmVGwMP7a zcrdb(<=S$3ehmgaOWq2hXL_U_GkfPO_5mx9a+h&@R9y3hCU5j6Oq@ds|5Y+eI*?TS zvw~alkU*o+gZSh*(AulkgPLdI%4_m4=(P=~YK5EU&dsp0gaWo)Yb)H_bDP-3A>T6u zEQx}Vf&@>3RqN_DiGCyvh;fvrKVTGbKsKwk87*=Bf$+M{xhte!URft3<_G~pD=k!? zMrqo?wEv;n)|#iBi6++J{+LThT8NxW?EKC1*aYMb2NKCXDIkzK*9?RoqhOpRo72%B zN%(QZ=}Ht!BRW*V`yNIc08QD2ZbY%m+BRJ5 z;8M9JoWKp$C(~)L2n%mVBN75Zw}tSx#h{oj`&GyJVhCa;m;RUxOQQQuxbt+dhZs`K zpt1s`KONsSwkmD;fHt_@>DS};h=D?i2~rzMbLxr;1)rbr#N=i8aU603G#!DL=hIhU zZ@&AFI&0q8`F$c9JQ+r(5>CWqBL(?88uiNLxz8=}qsKi#DBBBFhAyuU`*Sy!8O}h5 zum0|Ee3KR&x0@CT2GHBsb7!H7Q^w8!)vDsH2vZeFKUGRgS< zql^qAoA;fwHx}Q07m6}ngEV$C?7-CdyzVP)$;gP!jEd$jx}qdQDm```gaYSuBWhX4 z^&Oc{4Q{24jUOZV-f-I=b5ALUt7kPHlh=U$ReVWC`0r4OmaENAuY05V5ut`y$yjyfG)JZo>Sccs%N#5o@N8t6z)Gu5d_B0Efol#1mX zf0j9(5nzFUfD#YzZQa_=tMMvkF6`L+a&Hj47k9dSxpHk~z9DWt&=ZI=olV*%Sp>$mYG@UpVFJHZF^6~**BL*g1$-)F`=R=u#y(J04=~-&CQ&-Pg38k zDL;4elZL!qKU8d*^F7uz<1=4%Xogu~A6WcWiA{@af97=i$)yKfaMwDj2T{1j{%6l` z!1iq6yYM4yj79{EAR~GzSMzyr#}mRWK4ZBy6$8 zH{k#7h* z#aj$g-krX!NOH=-ic>W_ls3=rQjKZQWTjVhHBo}yukX>UP~_w3rrqX|FR*?*o<~lp z)~xNOe34{0^pWA=b2#sFxP@A?8N~7!5a_Z` z6K+irD!+efYI*r?w@C+O%0eGVOJu02OVx9ZmpqH9X&d#Bj*gFUNT;f0t0tZtpJF8qV7)w7FzcX`Gpv302&+ z=J{)6h<(NPwMseSSgvAsTjsRA@0-8)r$k3Qy^O+y-MiQM&+=ce%k@%--G}SZ63MQtAtHzQqU%9S;Jcw(VaY=LMw-Dhh!`4R4a9 zmn52UkU@&|tq-fqpJ~6V8jte)Jzc1IY516!kX2OnrP-`Aq#yP0P=uT6=@*i5UHpPe z4~u%E6U>RrYl6cNDLCt{Va5F%n)eON*R{R2t_XWEm|$jlA0tIO|sKee;(eph$)WOBpDSs-5a*_OE<@qTsocrlUN7{@E`eX3|dEXAZ#ZhJ7eV7DML z`61b4-Cd>7QifV5g*UlA7usGgtFlDHgE?mNIX@nQQd2(M#;mHDcAXJ@_t|LjDK)hq z7!%y|;;{ee{?<@#bMWw2yzNiLIM?8sXEuic0_URQEn#6{gRN{DT=eFm|H~Q6&zf!O zuD%B|$hwjf=WtZITo|kXX3XW~X8vNuej>Y}fh~6c9tP^;w?Eq^KQB<$pBdUO&XKiR zZUk5C#`+pBbLxsJ%CD_<-JaMqWV?q-`EzL2;YSFYW1epFgPj23JJ z46+aFYxe;S`#wD=(dj?rxM|gy{V4CSAgM3c3<#Ku|7~C=$oO*N4~7gwV|VaKQ{cdP z(1b-y7Pl9nSg0{U z$(b@0jWot?J{qF1?N%DIWVb89S%15kF)b_YFHjk=c_?6MG+LF!uqZ30R7!Rzxn|$& zlq79^;y|Zhiy|h%agb5tK{XC($RoW)5w3$mi7I`o(0Kb1n)4r#?nT zGzXe`w*v{DR5BBIlAK_1_bB+von!;0y<8RtckMr7 zN2b(i2B?9zd60z+U9Pw60UG}zpx}^;Hnx(POsd_ahYDh4#P!Y{%<&>v>oyW6&Qx)1 z;sH)UMS17eI1?Y*(_uv3#*Q zvP0{crSv=Un~BgLFPyUEGX$yhijbCLoG0~`Sz&V-gmY{iZfcS{u5%>bYF^Q^?rPXKB=pU|5BMs2F zV-Z{YC?>WpuyoYs1B5>K`w7G)e#x#c(SpG=ry7x0)5vZw`J01i+}B5m0H-pkU{wLz zH;73e>%&P3e~YK|LAS~ zfg9;Q!4=^n{sl0y2M1>&sSS0`HvsJv&Qg1Ax17+yl`qugE>?z6eK-wCTMdZs=rZ9i zlYq|QUcj>ihU2m2)YdwP5+;H`dpP5eS+2f-CUb!K#gJJ}W_`jOy{1V~Q89SoM{lmL z^=xcDM{E~K^T2NnebGKNH%)ncD5`ciRID5N8A-$!0W|VS!Qe0vY%!)l4;tagRwHgF z1lXtBGB-FEPn)==(-YUAe8et5pp=Boz5Wz;*k7ELqlFwaZo&pD4V{O$ES+Fb2x%V} z_a(JoKFXF64Zqe>%rkzIG5E?y&}2bvq`uA3J=G&z6Vw~iCG=e-VV|U{H|dTAw?;Wi zo)(4fFrnh|S@AeF#>BNepXm#?bo4-+*?KCgWw!{Hwts(rKaQsdpikaf$o2)U-F%tT zW^0$oqiDv)WZG*DKp8WQhXTmCGmIW3aBS`EYf{emQm)H^nzfLpCr<*MZaL!saIBhL z<#bg@XlRvw4{V(T|MkvLES34i_2CT8oo3pgJU|NJt(93B(e|~OQr9x(C>iZUHJ-p8 zMw4E9eXVGTvG>Tg+3|RGUEYo4 zC>OE`a0baC11{WBS5WV2W^7!}%1y_O89JUCcm;^)loVZ?KLMAMv>P5r)9ln#c0JoV zFiBuVLf-Z&;fH#g+YSC-hu>(|ahK|gMpU)X;&N&1*^F2d>0E3k6hnc&_V$;{m!Yk+ zb8~Z4^liV!aLazyAviiZHtApjyI$nav>9iTV)^xY0o>0`PfBWOrbc%h2ywH$?sto!N>fnb|;{<4dQCcEOO6bka z&Ghv2QQIn*yhB!hf$Cgx5SnnD45PMC_6a%fd$=1^Amic|sllWy|E_m4j@m!T91HBe z}UrFqx5wFfH2Em&c<@7>KHv8`lPPdtl1 z@(%-SO*hA;jt-*q^mJ28OO@oX%Qg$9IX_$8m<@dy86=>Dmi0&^`B}j?2sopg#t?9@ zD|&SVW9aXWf+K=tPMic+NFq3id6xpf_({kLfhkp=>aEv#cO^)1_!WPu_!+0Nk*K%g z7}q)84->5v;ARkUyT}`J>3F}~{6?r*ADY!$G^&&eF`^lkPr1Fl&4w_fidBFNujg@+ zPi7su9`rE9iAiJRx#f_qbvksV%!J>mi{;Obqn7&2fWaQO$zz8{OiauN>d)EM0iSj> z)%KY0@x^I_QrpE{Vs2?NTMftF4`^`I^;Vr;D z4@=G~e+*l6<$Wl$XH;$tJZvC3T&@*VQ~{UHwy&)x;j3RAom}xt$&qWoF^G1=xqQ0_%^1FRcpV5j?KBwW$a<6> zm52EWaRP;eg?F6MGie-6eho!%0)6<|%bGv`ZQFD6m!h^g2q-)sE>Bd6I=x-CR{PNG zK><8#tm4{chW92+orHJ)X0q2d44N1;?!>xNYi239u&I11yJxDaH+!U5Z?ke^)8)=K zmZTUpbl{@>t_=l1f^L}Afc$Gm&Z0@+e|}8y@`OM>sRzXZ8^{w#yU^=Ri3SMZ2)u&Q*Skn?UjM<$=yaAimpsC-RjN%qSBK0x2RG%}XX`8Rxe8Ai zr0V3bg5yag-3jYzqB*7fsa)D^kG-uVuo(A2(MZoo{Die)H%%L@Qfi#rP(pKBgzn3c znyI;YPv+ZVCScD25AR{vw~hWRY*-Zi9^ik*{_59h?rqfjt=pB;GnqlM| zGYA>LMm2T7XHgbqE#{7@{`J7QmX*Be1YQfD@I14>FP-p!HAbXDLsiJALf;1O=TTWl zBnmQ@-V-rD_&SenB3rZQ_i0>83oD{u%3{)A9 zV5+p!eCVHo%%(TG7R#%^YL0#mv(mIxLHmK8Qyt*hA8+XXyR zO_vy#Sq@GofNl9~6Br2x>3-t@@SCDPFg+x`pr4LC{%pWp1%|VxeZtv)5lEcBQ_9Kt zMog~#F$TE2z$Evv)_j3?FMV(Fep-a$-4T4o`waU09!yfr08G}fmGC=6rwhg@b|pswbXrSAeIaxRkuD%wz?ff(IbPAolY zl_Tr-XG=tEYdLE+Uy&C1_Mt@H_B*ZTThL_Xo?z?MtK3Ec-;4S6-QD)ZK*PlKk%@|$ z8diqKrNoVAvyjZC3{@j#{t6nsj=C;iL9a%ECJ~1lvlxJn5KB%dxA_Jm0Dh4m%o-W$4RpMskK-?d=3FZFlK53_Gu+ew zqZ@Ma^NnFK`C1|fc|!c)Q1;w#I7oa%zPp$IEmqJhf&N+G`F49)Xy3~Deltk`J9*|0 ztPU$WjIr?d0Hcct3!_JUz<*p22ueXTNe;X@p5F%MTHxEaZLW5l_M$#7Q+N}tKV5rw z-@P40PR-b12I9y+#fuTT_Rc4JSFgL@E}Gc4g{4UuQ%fe~#0eDohVXeDeA9h<0k9-! zsSLCB|6cC`nMfT_K4#s8HFpp1{iM>>*K#wy-g?XB$%O^>M^rsnApmh3I1?l6F7~Vt zl#>yWC7rapMf3uuygk)B-1Y{eoBsYie8cD=O`|TB?hCo%!SHVLY@vH|0(3(73h7-e%AL#D zasnkJXFz~DAl-I(p0e3E^L=ohuh=AX(U|@?;|3B5#oPpZLgqCq{IcL~x9G}SpY>*B z#{>P@Lr}$rvN;#Jp};ZXyKXq`q-2wfQ55(Ho0Hi>B@;kk9<7`I)A2I~ zC-6#Gt!%5Wn~8=DzW`PS2KW4jyitzn%nKnc`!N=(1!mMC*|^bO5iBdwtRn2x^%t)d z;KWW)12TSZr1v!mA~N!?=N6M`U?i40F^kV^R>DyVE{ooNexmO94S=tK`pKTkgng5* znfcz;jbV#9;#z)39uo+sB`LpjhqJA*Z%&Nye>PcjgBlB36aM-|G_L}Mnyw8ulx`)g z;TGKE-0`HH2Asf+tj5jzi$YeG7hOEiKtuyaswWsUIKR*y`rp3L!3OG9CglU*!c48x z>Cd`wF*Za*C+Dp~@4H0ywl_!U(2%JaH%!md{VZ2L0C(&GisD6qo;$y(a%&&Xh2-rU zb8?8^Q&c=QE9Doe7k=N6X}PMpIu*7$z4sTt%GhRS(4HrdnNg17!VTSXlK{RE zvufOgWg&BIC2ajg8CJk{w~^{?hk~h7N!Vuj_kh#>r!gJEcNVY9Ap&-n4T$o#GtM%d zHaFRS_~D|qologH4hS@K&(~abDUA=rMUeIEJU4)d2`M zU)cB>aq9Jq*k(edq@=1W7OS&I(9}Jyr^MzqlC)yyiy9lT_mhAUt1xez zxP2$S07hQH>6mMJP;ID5S)mEK zdwp2XvU>RKkj5Ll?{$uZ@A`lfe@ZDZ%wqmeZ>~he6!0%ExHDu=)#c^1%xrAi_ZQoM zmUcHK{PGWxlCQHtMH~A;@9^fBg+s7%Ewxw^g*fcvsd5MIYF2dq(_D*O;Q!)C zJ+HTGthVj`cC*cX{cJtCxEPED%$Tzz>+4gm%=)){G+VnF5`r*iA}{N4*jx^2Q@tKj zJmu->d2n{7dIvr@kWcn3ehsU?<5JUhW3hUeHIZ*If_Ytkvg%p%tQcJ|ue4g_oWoM| z0!v3Dj$4$%`gUbeLE%-9^2FW@no9%0C0oRCm0U8P~3R2*^`mq5XY=b z0Vq&8P?#+73IPz_%**;;^#0owt$$*zTYbpzXa%{^1T}qu83loNl(~e zVc}3~wHz`m1n)xWx^IxN!~Pi2v=f-x6Tqp`lvJaT;VDX%DRofX*ne+pzv1)7q^4c- zE#G6kc6RJ!hGLrD+wT5E*Gu5M;zRUyKfr`5p9UXu_`Pclj?e`H73?^!DknFASQY|E zD2NfLEItU*{2~?nCO%{@bkwXRvkc(p&rSdqwRN&sv(VxU&TC+jM(5d2xmfLZ{#gEU zK2YUxybD6Q-k#HaKJbVXDbh2z$;oegnaLFDNeQ<>t53mvIi~YOKBb66T9x$;xO}PW zc(b)=eN}}H37Pabkm0@=qszWo$lf_+MuH#do=VK@+qy}ez2iCkONsDneNBjwoJ^16 zx$!LyaU2-twN<}t)dNh{^D85gljW4eq|(g&j2fNUh`bLm`3%&CAAFCSQ=IdR5j$wD z`c(R1DU1POZnkRL#Z+=9E#9Vl#(l&KZpU!v?H%{#y^n@ef9g$NIOA03Hr9DvD-hU! zPg76PpWmy!qQYf(LW(wG`|n%(>%*!{PDOB-^0_Kce-5-NJci|*$Pkn>=L*x4hoRbaGC@9(k@evA zwv$dLy{g}?X!rNz^3InmyrBuIj*E{kw%hmkI-x5$T|?GmWkdpG9I;bEV;}?^CcwMI zd=PQgb3qZwSp7zvV-c5-uvOl2)(wme&4Z1XmYnGy1I^2z1Ge)1%=@7TcSBKQ9d!1O ziT!mrbP~z_fcF<=*&~D33%axZmIPc#A!%3|J!F;-_zl{HVu=l z3autax@vT#e-*lsOG`g%8yZZWvF$qlTO(A3vv>Eyn6v4BEo$&ys52IF`?Lm(R=yvS z^U8Uo`XvbD@*O=;6DQbI+Q|MrLuh-b4WDv$&o$YvRI8m0schIEotP{Et5oxN&m=?6 zDw5Moj0pdz&(=Cy4p+8rx=94n8qGEefIeH7zWOyX1qxT{QsnSS#_=} zz~Zk$+NR_Z=V!n?=00Fj};9{L?++5Piln&t7S-@v0q{m^a4zgu-C0z$07 zt*x!Jze129i-gXD{`(#~r*|sy0NPnpNuMKr1s7BuMdt#iQtqt9LhSoL;QWW8)l&qJ zhvKnXbOOkJC4hRAg9uRSk>d(m)}zCvO@dktn%C~Ktg0M^7x<)jB{(;3Zb3^sJ-uH zdLTGFI63iWk%EImLoECIcdP_1Ck}a^8Yd3;2W9yM;pjb7AdYJIbqo;!cphT3EX{h0 z>NSxIN|_3yaL9l(Fks8>R;u(15*tcgSO z25Jdi8yg!Id@Soj`+C40Ou3MQBSuY0FfiW(wUm?%nxFotW2FW(@AVw~SVX}^K-w#4 zDbV1Ba~0|$W3qBQ>MS+2={|@KvBrJPxi&=%>Y@6i(4J`73XuX$xk#4;Yzw_~#ut(w z?H~j=AmGG`D9E8<6~s_vr5ZLxo728~%>Rk#^XyjMW9{ zqDI#20CN5Qx%9KFks%%<0yfmz6gO>ZULw~0fB!|?QKN1w_Wv)~yqa4PHy@byy6!b|9gOS z8*LKS$7Xfzekk1k-R83G`%kBf&V>IeCV1t?PT}q2A(~C<2>-nm>Qt=+@#q%}r?gAy z|J`Q5Hf(BMgB9vssesRpcUgJv!@Fw6`b>F}889Sp1G)Bpp9gqgE$|AGqa1n<&X2bk zkd43n@0M@&_~bqR^DJW^XT6c?|AMFYtF>b_Bz+dTdfyMcgnSBSBu|oxZ9K;RyIv;K zVVubiR+CTPQlQ?~>!3NqMmx+UHoX5nLkf89M0|=5JudFDDTkh!mUW$taO!3JUx7p~ z*DMyv&WR$F6>XP*Y5vn!$qQsvX4g-IUaMaW?|>|LkIp(V8Ji+@U19-;FAo|y-QSPh zp|Oj&aqX=vDcs+Z0_smJA*IZDk}PCbP7cHTadWxX`VO2f@W2m-A?P;J2mc|84}oU= zZMpF_iFc=$vyRj6#b;1|H8H9(E6uG_)`zdk&|`>#pP&_*Z5Yb@;7~sX%y>x;)PJTd z75126b#S|Lc0`)WAjB5Q7r%DX-F41y8COU2ISQ!6RVCj1t0ARM&w^z*8jE_1c*l?Hpz4+nNTzwBA}I(Mns95UgPy!l}Bn>_u9@R zjPwu(dq*L?@(ag$k}nzg-*@8@=I`xBzvHdOtnP{2xKR>gHq6qgCN^|l?7)W zMeb%>1XL(nr@A~G8>Wv{pI$$lelH2VAr!q)&DLbv=##{|B8;S8Lw`$!Rnp(2P`kv? zCQ4=*Aq1`-m*}gX)yR#*O;d5vP6rht*(#^g)nJYOv6IstcPVcXk(qGbF$T)4)sSgO zK5h&-F`_YCUu5o==##$aV~kI~3p6geFktNxdJWOlSj0olLu6*VfR3*gBXx0 zVP!IZ>rZbaOrb;p`excu5rH{p_D_MCjkfqfWPlN4cFrR*5cyHN)b|IVD@fY65f7uP zv7yoxE2h?W3VfaRlh@_)eMzT^1+>hx*56X@QYnF3N&nrI_0(JPeEtX{iZT;9u>CP$ zVIc!p$Y^L}ESkW2yH5U72!oHh-KZLZ6Vi^x?CO9y(n705Y4+oev1={zw zpPGd>nIbOKP3*txv7!HjiaJIKF!5OpwyEYw7}Sisv54^bN0)0Ap4z`Rl{ZJ?Ad(Pq z`H6GI@>9lntX)D+f25K8%*e|QN+5x&beQQI+%+2zuEd7=lwert#?b9)NqNLsS|AEA z37(l<8N)`raFpezpRXmEOIr=0Ek_4OO__&z-&KKU6#`dHuw{qb^q6+Sj9x*U{%`If z^wv~YEz^xrg();7WqFbR@s4h!^SrIRvO<1_5V^Zc)5Mhagli zc8s>re?7>FyYJ8d^T1A$W3>^swb}$>h_XCj?7a-!t$S8;Q4s$#{2Ha22$qMOS+$Jw z)HncQWC(=SVTy=tPf1;$sqip+BzAQ975-^8RVwhniB)q3gCO+h&W_$1R#KI z(NsMWR8t7mW>9}r2-RBnQ2EiC@%0@}0`btGQ84{lWYVMvOu(B!Y^J!{6$29Jy)cmz z7C+Cw+={dmp!P_WM49Zm`oqL_C%;fXsjC)D_o1nJ$xV6JmcOb*AAT5X)}+?!uw!A> z4%Yg_1=>g61yePNOh%|ShcKPGz9h!fdHl&92+e52tkzi&Qco2H-lee5QQPp^vC1Xv zU7F(vaqg>rGu8fOgHX9?(iy|{u|d^&nvxuo}4AttPOiHcTI5_A%@~UFX;omy)n;@3JzNuV!T^k(08w9B&fo-Sw=C z74^RP$s6H1+)al@eZ<^#`wmDHH97B4q1i!ZPr&D~cj(TfmI2(UxMB?K&cTMy9ul3r_5JgWNv`N z)t8UMqiT_bqPb=d$J%l1{5f)G;#!aCGwz+6D>ByVGuIKvko}0iehNs8j)hG*nvAAeq$$1kji4q~C9{e>Z|my2|5z-OFA}Gh_X84v9Ans5Y0O!)s#3?6)Kc&@P;bTR)sYKdM{;d#jlOa z%ZDw)1zOuAb=JsaaoSrPYs$xqX4Qe=45-@(n2kZsdE2PEbV#uw#;bhy;gzvhk+7`} zW50qZz({WtGw5EP7_pDpzkQi_W{^Gg#Mq3PuJuc4HWRc*I&s^1V!!WJTDZ(uuG(E~ z$hBQ1XOZ|hT=F6cgMpZVt^Z9i{hi`@F3Dm+vxVNhDlvg@r`bYmMwGPht5g?+)p~5D2~An7n;s-JQN^ z{mi)7(H~lh9U7PmP1tIFJ0ED=if~uav#OTn+ayIRxn`S*F^2!DI?GnFPi#E=+a|<~ zFdrv#Wd@PJk4GDHeV1KShHqu?42|S!m9J9QxDO`kERm|{x9+&1D{;^mP9QU6B#8No z%_z0Gz*^6D4g=lY*Lb@YDf{raD_Q38!Jw_~<4xWEyAH>}m|R-8{2G$~->NIC4Y2}f zP1=9HUs5H#VQ$|IgJh6C%2V!D5C%OINN9`CYhAhc;5Rg;1g!D+jb8*;?x)_uB0XOy z8u!llE|KNYN-UfiFP}(1nu)+>n_$bwvCLNdgxh{r-ZKg`_{buqTBK)8jtE~{UuTe2 zYB)?I;~H_PRnh*0GH0JSpOw=t&_;f_RHPGT2F*%VNqOvj=KhA|c}hT~G*tm>p5XEZ zo}iHg(q=*&O3vWtwY|S;37aj02O!2YB@~{7dChLb*^f6hOkWhDo zyC!3eb)CDQd$H=x|R@@M(I#MTDk>c1*D`w=}tvLm+lq;0TGsz7LeK{q!#Jlf$#hB z*Yde{&z&=8&fYW6JTv3Ha|COaLN1JB@rC@*oSIY<{hbi|Oli}H z!6h=Mf_cS=+819%woD2T37LF0qmm8Gsfg-t{X!C?;)+NGz&U*euL5!gzcQTl`WR~{ ztm!F2%BDm4NqhjjZuVmj_REakq4i@d0DiZ(E#^D7SalX5?+AZVhH3zU9Kr97jpZBfu5p3EnHypyWYv7I8c z4xKD$dt&%C%Bpj3-01bjYq=qNG@tArsh3Ee=*QNal&gWPD$(7ghZy@RDMPGI6m#4POzjWm9 zwx*wJ>J9s|L$b{hSMF;n#i>+vWV@Iz%Mb#WTie{0X-2LIlp{fG4*1VnroR(>Py0R zYMj~59QFHMDtLT%C{^=n#uQ)H`UL7`3Kw(xQ-uiLq&z*wS>-PX*j6OBlqq>2;5 zn8Si20j3G~(`)u*0u&&V0ZSmYWyeR^nu`bZTct~c3z4(z@U}1#t`4Uz8+y=ekb3omqLGeLuHexuN;6OJ|YoJ4s; zV>5gEjv24CrtDjtii>m`ms=4L-(`6nt1d42)pv^X+_;I4qii+^cjU~!y&tDetOCyY zYgiT4XCsn08hP#Eufe6b{pwEaOM-43Oa8GXd1*XD*-ChAxZ~`)<+UPjbC?I3r3KS? zI-&H^Rr?Ijzcnt0>x-Ia5F5~k%xqs$(dcQ!cnYwi&FWS>OkPf&M=0Zsz|Z#Ci~s!y zxs0l#Ql=n@{~?gPMja_`=MraW-l3iPDTW@7PMOp)Q}@*b88LYEp*ZeiE8PKSiB~5IaN-+tR*_X zF_;D-o}yr9%Cy8z_V~*I*`;n4f8x%9r5Y9~$>DIm`q`iX*GuR)n>%Yl@`pdvum1T;H7`F8;_kR1+D)9h8+!q zDNV@3kVp!{sn_{SmRYl>L~Vy_N6USer$ePrmaPYW_fl=`bfX5!yLZl3F7-UEJ}|<+ zuDz?EAh@+061d-9;k0iv?>|b$qt>B&A{~Gp+V0(*e6Cdd0e58k?`$I-W#ndXBrz?n zwy}Nbz%i|g#zVG^b^uuZA4uHPfEHm6^Ue-R$E#lgNfhg&IR8oa^?W!M* zePW&^_Xms4V67g$JG$vn&~v;T|8@D-%d+z!Aw%1W0@Manv7gk{s@MnI=A>LYpt*2J z_X^E&iex+|&J*AyEO%Z0D>w2PB#j`$kY;|ChEnQD1AdF^mXFn%SbGe@rTE7;r)f(y zTT5f(J1mrjT&w4r!8oRkW;Xx`Z&zw%(B-yYnEwDNeB}`?iCdVC&OT)7lek**x1P`9 z)@*KP;I>3Xlj<4Z^)Mj>nMNM$ZU%`}oL`>FIwg1-O=c& zLtcU9nT5!)VsNA_Z@ZvtWW>6&E>|beT-X)Lm~6z$sbJhAOumQ6t?A?go~n?)X-A`O zqZv$z{VxYrx7)Ap+?+fLMse+b)xY7x^`3vVzcicK+a8UNdb`rFsC!8~`TCD5a;}VR zgk_i6^8`y(bCNv%TgxeS>cXjE11RK3j52V2OT(nqrEQ&XSd$ZO|5ejSyb>K*Q*~b# z5c~FaOG|~rhyNu+Ct%Lbb^rpF!09Np#55ooXXQ1t=n-L?osPaiPWmFFXseYbzT~KK z>H0AKes&q2uRpHd^bC=jAR3vQoSG@_!=IyF;`)4sUL?kbCjrVm6*0n4*%ClMVG_tS zG3RzGstje`i)X>XSmTpsXGfrU>N9|)YiP)L?(e?i`>pa7T6Cf$g!TR5(ZROC)7x?a#cTfFmfCB^xUA4dGd!^5R#E4TMcw=G z7ck)~vif}l+F{-{PNtb{V>ux|&86BG0LSJ?`G#<_pt?m=jamSR#arl{Fp@aXACi4G zZNv)C4WP4iv))l8ydxP$+FLjNfG|0!Fkq1Oa|4Q&xA{ORX-xPIMSQ8#qkQyjJbdQ^ zjM~)SF`HjikJG=;&N!Bw63v`St=p2LUEA=>5WWw7Zu(>O<^+=P;d@(ElxcoZ6~p0a zD4yUMEjPuFG@=eFpxlw!ISO!PZM_8yV`68_8DlV&oTBT%<=Ih!2kxso*wFeHsMdB_!`YrpC+s{HD#PJ9^I51k+86Z zpDOkVZc6&o^_tDzbaa!0op>o2Vtp%!x~<3mG&T&)FW(iGn(wjeke1J?ALj%+>6Oko z+un8|=HS31u|g4!nVP1cZb-DK84FT=`TmB85z(*&Y>PNY?Z%;Wg))v zC#?v960?knPE-c5=+f7t09;ZAe%a?dP=CsQy!@j2-j*`GRI#scS65(FifDGDyN~@ zSl`loU?CsulQ8A8)qBZo7`RSGu>-~FeSBv-1uI`IX+rdl-czhOg5tJFkh$y1N;#0k z+jnshirNxKUKqHDiKl7NM^V*E13JQe$zm6ir!x^z=+xz4XP@HvSe>Hm2=7AhuB|G) z9a>6TSorKHu405?-6KE*RhwvT*A^i2&fIck!(V_mN0@_^2RB1z_rfH>_NxUrc3(qf zu*sK(qb6`K9HM9}7&z+uD8X@aCZ4|WY=z1nVJPGr^ac^^xv{;J6-rw2*sQ^ps??ER zck=MS$j+YmfSEU6@Pg9nMV6*YaM<~I9SSYJ>nwb5IOXs~ms8JAOznnH7zII|2a;!O zj5fcBvpBe5E}KGxNixylW#bd>r1q7cK`OcZ8p6Cc&%m36disMVwTRez?p@5kR?j-- zf+-NO1n!u42HIhDJ%m`@PaXhfnyZ`I4M;5DS>O2D4Um)?m3NOt?=5OdRjf9Zo=Gsf z>zNOZwe-62u~Tlyrp(GP`;bRV5QO)pP;gb#F)0@Q44sNhL{$h?nQZ1uxx)6nSP~~PMoZIiz9pf-sp=Z)cj)6Zfky3ev0ZV9kH1 z6WTB{6GV+6ELp6SB|Hi1!c(dzO~0_Rk5HO(2ereqt+{N~wjXfx3m(|W5Mg}sP=6c& z;+Kj3NDgv`(l_1oDfew48?#^D+F>gHq`$V=wUo%AfA;O$Rql@3v;Bt+*?hw(U+`}m zJTBaQ9-xl5b?`Z-C<&_GS0hte*}}-)H#Cy21kIX#G9kKz(a7kNwTB98 z%UA+xP0xhzm%xK$BsbS4V93(?N$JxI^PuzijEVIg@P3bh5pt>-SK>m1eEleQ5<4tB z?D}UqW&57~uygz2`}Y0juF*Ql{kLvDLQzdak8(dq$bvI2?bHjkOoVC$@Ol`*Lb$dZ zS@J!!?XMQu$OI4o^1v%M1ntaS9yHOwlU`|u=h0e9e|J~*p8n$-)Ix9)!y9DbiN6a> zxZ-o|-|0*wk?7ns&2&GmVNtcO#-;26-W59~#h@-i%KdFY{Qv`C>RoF~tf{DRb^YyL zip^lA;61SYa9Ww%T1jCa<>qwk4hP;WRhg9YI8qj>-`Ez6omAy4z&(4bjz=v-hTVVO zxc*4{N%ZZe0*Q38FMwF^b#+Vdvs_-X6w$Rav%ucMCwS8TD&{JXB23JB z-q9zvD9VIeipn47!zo3#K2hnEI6kEOppJ*K-F{f@&r8T1Clj7ZB4~?vlR%*_YFS)P znm;ePa|-B0;o!(@8HG!C|GnK`%-ccFTE+t7Ua&`Fm&Ml;&D6VKR@zswe&nVq`x{I( zn8~7LUK7@IaCfQj?Vi_{4Z@-w;tO=U5&cRehqY`1zi^xrszkxGb1f_TPA(kJ2M1p? z?&k#yF38?ZB%{acp@wn%k&BOrSWqgOgg}T>R2Aj*?D!S&gQ}Gm{W5tX5v3spa|C*0 zkRUCDq1caq_JLD)m#gr<0x&@e97;vTbqzcUrT|)?m0i`nXdzWJv8Lz8PX42c|6*mk z^D&3Je!?K)1VVsWNs91HIEId#NA0F(Nhv;$L(tOR*V2`2m{2QmYH#L7?-->hHc|!` zzlnlwGnjrWyl03RsFc*>tkj-Lj2Pq1S^TN_fflXB4CYuhO}Y}e^fGK%J#GB8 zjf1_3=~0~p*5;L;iJq#49r?oqqmd*r4d@?+&g^qlWiU(RP&MztS%ufhcnd~cJs72U)^R8w_ z-e4N;!#1*xU6+;h)nSDnFWnv-pH^Ld+!HVKS)_vAGJPW2aFio~{Wj!Xid-cvY&e$f zw-JF|1nIU*yA;wC>k#A7?1t%ApDo19W6wX;Uk7(e4 zUME7?`vDGjqO1zXV%Ul~@)`*bYgMAzHvP)OoIAaoLtA!gRc22cyCMgGukKiBNo|S_t}zIzK3D1rEn_HlCl$9f zvXNF>wR&`fNHvi))U6+nftc)I5UZx2(s}GQ9(}}R&EFvSb^9QQK$u>5*>s`>mR48q zdo(fhg1!{~YI0!vsUehsi4O5chcm%AWd5k>NI}8yIGLcwIVmen!N*BIiuc*a$YgK^z#xY-n zZfZg}^w6pSi@9t-uu0h#P3y|1H#l-@$BDtH0wu>hMw=&l-{f ziR(2%%+#Tq&I`-youqCJ{2$98BMI8~PM{YF|9=to${w0h^F=EdU-2#?BGxZo?6?@P z_w~!X>5cr6;V?f)EGaJZgSQpGT;7eZiQ1%*)}(3gW(!0fS6<8Ykp(rbJhl6B!CTai z?}(CkMCnx50eJ5PDMVfJDTa6 zTTV^A)jd)50^A`;0innBINi!p{mpqUv zqsW*BCcfa~B9mLc*33)^24Jjt`sHSJQ(F@eXJt&B+nsrL4jX=o^rx=L z(Ty?o6eIzeK}%&Fz^e5uSL^!eHn?=ZldZfQ&^jMJmlPab<=lw z-kXEk1Jwr#VXTF0TI^CbjZPy&EBX85U$3XUM-ip)1!Bp-T^eC=rB0NaRZ6!d=(z$U zW;hIzaEO65|1r+HdOCep-9vTLGCF(+Brk9)#RPqL0j16?GNQ(YDXHXr|NX=E!PeDW z$23B-9NUkqTKOVXbYFR^u>&Y?Q z_ZQE)qi+FC0U;60XD&Z)M1f0x(3=MbUNc8`eoI`fuC>P;HSE>jBB}EK1`W&{^jk9` zu*i|o>l9D}!dE|K`XuEp9k|}q{Rou&hR9uFL`Y*F=FozH!A3zaZ`OHudKQjewN|Dl zkW#jEvW1f1O5WA79{xG>@@fI{o?D1L;mzPQ@s_1re+w+Mj0dO_G{e-MHQi&h`aDq8 zxNfibcM}-l^wRi!P`_BZQf&Ov;(6l@$X}W*&;9NG#sh-6Z{ky1W7uq015fwxf*wJ@ z94W0TV;;_9pf_LfW5e*AW$Hl-3D6m1vwWSbeAGlA*dGVD9Z_EXpCDqIqnnJxcXK9R zwr*O84-*1C`vl})i+*lYeniWf!^FZ;-o#Ao1p~&V?|I6N4RjgrPZPB_s&XbT=v z{#38qyLk&v2cF*&g4xP9O4(i6*eD@_DXCw5DY=IiTsc^l3tLh>5zw}&0T%uVfyTfk zd^ZoH9j?th@7C#&Rt#4*8LlKdFvrBc5B{zD>!E+WeI5IRAktqj5}%jpA<|#lh`JBx z{Kt>w@)W>I=<1I<$GBgJUaoHLSL2AeLw@tsS}bvf{&r>p_s&77M_*$k{drI5Nxc;E zRN_VyIWI)!LRn1dwe=;TI32nwi3OPBX4QBLR}|!T;C?;1%jNi$8Gs+H5pRaeoFd^| z8O(_xF_?9S?!Ub-vr{b^@DXRach$oezGn2VyRUC$4N6J~7MESUdTNt>&??~K7rHEr z2<#zNP2-fUgtb*ZT4G3tRGIz7X^Y0t&vqx#%mo?@t1T$f`NyGoO1K>qb9pYax>Pq> z4HWg4-O?#TaxfUx41a_Lc|I<4e?j`ej2YzNglq;El#DoMu4V_T<1TpC2}qv!w~N&F zWqfozcZzV9f|eVUHOCrT8mE9VzO78yZAHZ8qo z{A|mi@!+_MJ7dQmGqpJscFeT{xl~|>v-lq%FEu$A@i!)OKdKI1va|{EkG7t%Y$#py z7H|hPo0VVc>5rp-;t_%0OS{)+yOT16i``E~r4Pc7OjB3$+ozoBWYY(2FZn|wTiAdB zcmFv*TDa`l-9`G0Ok!q0Pxu6v>CuBz@aVvz{dTi2b%CJJgY(c1-c2x~uK3B{F0yus z^w6DDn|}rw;Q_7xEKeJv4~|9)eqEscSkM1dojkuu)pVE(Z0hpz@b!C^`#B{i;2(%8 M?6D%^k$Ld{05+Z)+yDRo literal 0 HcmV?d00001 diff --git a/.testdata/sample-root.pem b/.testdata/sample-root.pem new file mode 100644 index 00000000..9a3a0648 --- /dev/null +++ b/.testdata/sample-root.pem @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDhDCCAmygAwIBAgIUfoFqa4MCUjMpXdMg7GQphA4aeDowDQYJKoZIhvcNAQEL +BQAwWjELMAkGA1UEBhMCQ04xEDAOBgNVBAgTB1NpQ2h1YW4xEDAOBgNVBAcTB0No +ZW5nRHUxDDAKBgNVBAoTA3JlcTELMAkGA1UECxMCQ0ExDDAKBgNVBAMTA3JlcTAe +Fw0yMjAyMTAwNDUwMDBaFw0yNzAyMDkwNDUwMDBaMFoxCzAJBgNVBAYTAkNOMRAw +DgYDVQQIEwdTaUNodWFuMRAwDgYDVQQHEwdDaGVuZ0R1MQwwCgYDVQQKEwNyZXEx +CzAJBgNVBAsTAkNBMQwwCgYDVQQDEwNyZXEwggEiMA0GCSqGSIb3DQEBAQUAA4IB +DwAwggEKAoIBAQDrHGzHhcx+3ZyxS0BwAo+jse9587uBpAo+DseSVFPShgDNBjkc +/VpdYIzXJJ5VJGv4+6zeidfh0XGElwi6J+7xJPrZu5Dx4UTD3buNIUDz7BVIhRFJ +fJr9IrsFn4oYPduRK07Ij4ccOWIszdnc6Tk/2r2iEKwtqA/SEOWV/34YE+72K4vD +FR/qepG6lraeCg1FvYlNEg2QDGVXc9Npc735vgh7IpXJAEOuE2hDALKOJg9233Bn +qE0iSk8tXJ5NMB1r4NRnEHGMlpcZf/2ZBC1Lb9clUS3qpDzNRn0RxANoANAQ8iVG +p8ysizgk9k4CnUrwPcNHkoTUvVHZbPFGbzzPAgMBAAGjQjBAMA4GA1UdDwEB/wQE +AwIBBjAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBSk203KlfhF9FcEBfEjYy2u +h0U0yzANBgkqhkiG9w0BAQsFAAOCAQEAeISMnE6KNJ+g5hORssVqeY2Y8rZmA8tm +NmJQX5OpmPj4M59BcmfEaurbCm4yprDHJdVJPotEUCnOmENHFkO+xhCSJHbXfT2J +Vb4yuIRBT5UrkKjO/RF3Ca0tYbPgDZZtdb/7VcNfR/qGar9AYEgA8GM8D6y6m1/p +L/523TV06kI9tiS6flmxJtu2ydvpimAkWwtt4sDEp7g+gB+AHt4NUnkzzH7gXB25 +G2cy9mtJ45ah+bX9niCOZOdSFSPKXSGCh6DqDrtKcHUK7noUu7SCH0WdX9A7KKGo +6PfnwBRh2eXJ35BgCcHE4IM/isT+v/QVt1W0hQPE1PVhSNnNIlO0og== +-----END CERTIFICATE----- diff --git a/.testdata/sample-server-key.pem b/.testdata/sample-server-key.pem new file mode 100644 index 00000000..6d76bb3a --- /dev/null +++ b/.testdata/sample-server-key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAtAz2shdfO212Pktb//RFHXRi70utH7Xnygz9WN6yQZUg6mbs +W1tFRxUo2JOKm/feTGJ6VohFEIdv7kMTVFCpwSDe31sz9yM1l2D85adJXQuOpiMV +ux264RnyTJIVLhkFVvHyGohWYAwrMGoeV+8lBzTnbhnk0OHujJCdw1iLohc3dodE +JNRXxOB6LZ/3kj1qo1IjQwlRpDXuGRRBNO8GBW8G72QaYTamRP0bvByJ09nuWlxZ +V2ouCLYXRAhbx3BXCj1zEiqEQBAEOhQrooPNHqgrsYapvyVHmWJFDOZaIkGNfHKn +qPtM8RhsD0yqKtmY38voQaowfHKKS6F52JYP+QIDAQABAoIBAHtGUPXYaLygko/O +OvxA+71R/ZcHgk4u1reRMzjQqM2cVEAJHhTipckoZKH8Sq/FAu/bkRWEEX1irbE9 +PZPB8qgnYFEe+bJg6gVuQ1jds65ABngbl3pYvaX3hN0GO/gm62//EZs286SpUDzC +u2nLc9e+UiIhGngl6JVXQp0IF/pusWsnKN34QK81qRjCEInlPLI04o4Imy/3a/zc +a4DwCwMqXbYHyv97nufApQB/1qYjiKcrhhk9AeRr3vZyOFrYec2avQKuWhLrGpRN +koNRIPKrgTKRq3gsSmCiPiSRUD40pmtJujhAJqvcZsqSqyv+J2yPdjjTAyqyV/29 +M98GwoECgYEAwf75kTv+k4INdnJVvMXdWFcxkgzq8/fLmBOGoTNzKj1fHTDZFryD +X41ymNBee26YaUMTMg7NahaaT8s+bk7LKjYFODhD6VVdvfagVchbkB5wK27Cwa0z +XzExLFtnUlE1o/ciZwqUlzLxRXQKlfTHCfiALqBmMIEDK0CRCtxY2rECgYEA7Zj1 +ynolm5jc84yUiHhCKUUVaL0mP9IFV5Dydo+UymzJmjMUx1H7baFbHo2rZ6plkRop +AjH5LZ8cJ06EgGu9rOlHwwKYIA2FKnzwCiFE9nH8BE+ki2mBNAwsj+NtCBTvEBsN +b3byhHMMdrJj6CPgaZZOCAHHp3kBKjNmL4Hoy8kCgYEArrqM5jb3MLzui0Sn3IMK +vkqqpzVjWaJSigLsO70veVgVlyEsJsJcQXARS3pB30LZm9WCMJAMjAUXr88LyCbH +7pkBUoW7BSqSaEr+VsVDUydXOIdmezMZFiAkfiNFiGsEuU4aelyZQSXtEfVWo4H4 +1A4yxcxKvl01EXvyJ6oXjcECgYA8JTJjNRR8FPApvvaCrV6iL9jBkNAz66hqiEi4 +dpRFwdAu9qtV4YzyLZxxWY+ASIQ5fRPQeHIJeHOaB6hHEf8L3GnMFcYIpyOEo+fn +yJA6ipQvSzHuEKEiWcqWCg45s4Lo4tA93TB7Etye132u8BYI5IGQSVMPM/R1iFlf +wVT68QKBgAxY1euChqV+Wio74IaNiHsnk6KzLiUaBOV5i1xA1Kz81mownYI1n9jk +LXTzgTrmJ8jtL0HmaJcl7plIre8h3WAQhHFXWjPhmE+YhwVfPU77JgA1o5Pn7KKm +NDoSb3GDRYHuBaAK5SvBxQ9re6rkueK+N5cdcB7ozt1h1FLFz0/D +-----END RSA PRIVATE KEY----- diff --git a/.testdata/sample-server.pem b/.testdata/sample-server.pem new file mode 100644 index 00000000..81c9effd --- /dev/null +++ b/.testdata/sample-server.pem @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID6DCCAtCgAwIBAgIUPpYORWb7lrYrrfmFRDrugeHzncUwDQYJKoZIhvcNAQEL +BQAwWjELMAkGA1UEBhMCQ04xEDAOBgNVBAgTB1NpQ2h1YW4xEDAOBgNVBAcTB0No +ZW5nRHUxDDAKBgNVBAoTA3JlcTELMAkGA1UECxMCQ0ExDDAKBgNVBAMTA3JlcTAg +Fw0yMjAyMTAwNTA0MDBaGA8yMTIyMDExNzA1MDQwMFowXjELMAkGA1UEBhMCQ04x +EDAOBgNVBAgTB1NpQ2h1YW4xEDAOBgNVBAcTB0NoZW5nZHUxDDAKBgNVBAoTA3Jl +cTEMMAoGA1UECxMDcmVxMQ8wDQYDVQQDEwZzZXJ2ZXIwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQC0DPayF187bXY+S1v/9EUddGLvS60ftefKDP1Y3rJB +lSDqZuxbW0VHFSjYk4qb995MYnpWiEUQh2/uQxNUUKnBIN7fWzP3IzWXYPzlp0ld +C46mIxW7HbrhGfJMkhUuGQVW8fIaiFZgDCswah5X7yUHNOduGeTQ4e6MkJ3DWIui +Fzd2h0Qk1FfE4Hotn/eSPWqjUiNDCVGkNe4ZFEE07wYFbwbvZBphNqZE/Ru8HInT +2e5aXFlXai4IthdECFvHcFcKPXMSKoRAEAQ6FCuig80eqCuxhqm/JUeZYkUM5loi +QY18cqeo+0zxGGwPTKoq2Zjfy+hBqjB8copLoXnYlg/5AgMBAAGjgZ8wgZwwDgYD +VR0PAQH/BAQDAgWgMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAMBgNV +HRMBAf8EAjAAMB0GA1UdDgQWBBSjaTM5PsU7WKjGqHjLG1ShTukoljAfBgNVHSME +GDAWgBSk203KlfhF9FcEBfEjYy2uh0U0yzAdBgNVHREEFjAUgglsb2NhbGhvc3SH +BH8AAAGGASowDQYJKoZIhvcNAQELBQADggEBAMr4rw7MAwCIJNuFxHukWIxIyq6B +g2P2kFIU+oWVhKd0VlHZ/lrjR3eB1GJ4lC8n+yslEYA0nipEBZm1zABUKNhRhmam +Oi8Gf09yOg5+NZM7BihgK+AF1Kc2a282XpQOHqqqr20QAeO9RLBwBjxc4koxgPog +HSVgbNceemxFrfT8kzjyjv9SRpeRjeAYLILHxPABRVEuO5rOMoRhZOGMxb65IAuT +aFbOPIdBsW1d2/cx5hQT1yfXXOjFXvKVL3pYkEDq+61E40G8Sfr1ZqLnQs6+Fhiy +vUsHXX7yq6hnyrhVy1wTL0mLqadK+umEybkalnCMHVlusNnLXuf43k9xlFU= +-----END CERTIFICATE----- diff --git a/client_test.go b/client_test.go index 02a35adf..ff722f4f 100644 --- a/client_test.go +++ b/client_test.go @@ -6,13 +6,10 @@ import ( ) func TestClientDump(t *testing.T) { - ts := createPostServer(t) - defer ts.Close() - c := tc() buf := new(bytes.Buffer) c.EnableDumpAllTo(buf) - resp, err := c.R().SetBody("test body").Post(ts.URL) + resp, err := c.R().SetBody("test body").Post(getTestServerURL()) assertResponse(t, resp, err) assertContains(t, buf.String(), "POST / HTTP/1.1") assertContains(t, buf.String(), "test body") @@ -22,7 +19,7 @@ func TestClientDump(t *testing.T) { c = tc() buf = new(bytes.Buffer) c.EnableDumpAllWithoutHeader().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(ts.URL) + resp, err = c.R().SetBody("test body").Post(getTestServerURL()) assertResponse(t, resp, err) assertNotContains(t, buf.String(), "POST / HTTP/1.1") assertContains(t, buf.String(), "test body") @@ -32,7 +29,7 @@ func TestClientDump(t *testing.T) { c = tc() buf = new(bytes.Buffer) c.EnableDumpAllWithoutBody().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(ts.URL) + resp, err = c.R().SetBody("test body").Post(getTestServerURL()) assertResponse(t, resp, err) assertContains(t, buf.String(), "POST / HTTP/1.1") assertNotContains(t, buf.String(), "test body") @@ -42,7 +39,7 @@ func TestClientDump(t *testing.T) { c = tc() buf = new(bytes.Buffer) c.EnableDumpAllWithoutRequest().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(ts.URL) + resp, err = c.R().SetBody("test body").Post(getTestServerURL()) assertResponse(t, resp, err) assertNotContains(t, buf.String(), "POST / HTTP/1.1") assertNotContains(t, buf.String(), "test body") @@ -52,7 +49,7 @@ func TestClientDump(t *testing.T) { c = tc() buf = new(bytes.Buffer) c.EnableDumpAllWithoutRequestBody().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(ts.URL) + resp, err = c.R().SetBody("test body").Post(getTestServerURL()) assertResponse(t, resp, err) assertContains(t, buf.String(), "POST / HTTP/1.1") assertNotContains(t, buf.String(), "test body") @@ -62,7 +59,7 @@ func TestClientDump(t *testing.T) { c = tc() buf = new(bytes.Buffer) c.EnableDumpAllWithoutResponse().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(ts.URL) + resp, err = c.R().SetBody("test body").Post(getTestServerURL()) assertResponse(t, resp, err) assertContains(t, buf.String(), "POST / HTTP/1.1") assertContains(t, buf.String(), "test body") @@ -72,7 +69,7 @@ func TestClientDump(t *testing.T) { c = tc() buf = new(bytes.Buffer) c.EnableDumpAllWithoutResponseBody().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(ts.URL) + resp, err = c.R().SetBody("test body").Post(getTestServerURL()) assertResponse(t, resp, err) assertContains(t, buf.String(), "POST / HTTP/1.1") assertContains(t, buf.String(), "test body") @@ -89,7 +86,7 @@ func TestClientDump(t *testing.T) { Output: buf, } c.SetCommonDumpOptions(opt).EnableDumpAll() - resp, err = c.R().SetBody("test body").Post(ts.URL) + resp, err = c.R().SetBody("test body").Post(getTestServerURL()) assertResponse(t, resp, err) assertContains(t, buf.String(), "POST / HTTP/1.1") assertNotContains(t, buf.String(), "test body") diff --git a/req_test.go b/req_test.go index ea620511..d04e9802 100644 --- a/req_test.go +++ b/req_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "reflect" "strings" + "sync" "testing" ) @@ -14,52 +15,63 @@ func tc() *Client { return C().EnableDebugLog() } -func getTestDataPath() string { +var testDataPath string + +func init() { pwd, _ := os.Getwd() - return filepath.Join(pwd, ".testdata") + testDataPath = filepath.Join(pwd, ".testdata") } -func createTestServer(fn func(w http.ResponseWriter, r *http.Request)) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(fn)) +var testServerMu sync.Mutex +var testServer *httptest.Server + +func getTestServerURL() string { + if testServer != nil { + return testServer.URL + } + testServerMu.Lock() + defer testServerMu.Unlock() + testServer = createTestServer() + return testServer.URL +} + +func handlePost(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("TestPost: text response")) +} + +func handleGet(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/": + w.Write([]byte("TestGet: text response")) + case "/no-content": + w.Write([]byte("")) + case "/json": + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"TestGet": "JSON response"}`)) + case "/json-invalid": + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("TestGet: Invalid JSON")) + case "/long-text": + w.Write([]byte("TestGet: text response with size > 30")) + case "/long-json": + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"TestGet": "JSON response with size > 30"}`)) + case "/bad-request": + w.WriteHeader(http.StatusBadRequest) + case "/host-header": + w.Write([]byte(r.Host)) + } } -func createPostServer(t *testing.T) *httptest.Server { - ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost { - w.Write([]byte("TestPost: text response")) - } - }) - return ts -} - -func createGetServer(t *testing.T) *httptest.Server { - ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodGet { - switch r.URL.Path { - case "/": - w.Write([]byte("TestGet: text response")) - case "/no-content": - w.Write([]byte("")) - case "/json": - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"TestGet": "JSON response"}`)) - case "/json-invalid": - w.Header().Set("Content-Type", "application/json") - w.Write([]byte("TestGet: Invalid JSON")) - case "/long-text": - w.Write([]byte("TestGet: text response with size > 30")) - case "/long-json": - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"TestGet": "JSON response with size > 30"}`)) - case "/bad-request": - w.WriteHeader(http.StatusBadRequest) - case "/host-header": - w.Write([]byte(r.Host)) - } +func createTestServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + handleGet(w, r) + case http.MethodPost: + handlePost(w, r) } - }) - - return ts + })) } func assertStatus(t *testing.T, resp *Response, err error, statusCode int, status string) { diff --git a/request_test.go b/request_test.go index 6ed499a9..6514d027 100644 --- a/request_test.go +++ b/request_test.go @@ -6,11 +6,8 @@ import ( ) func TestRequestDump(t *testing.T) { - ts := createPostServer(t) - defer ts.Close() - c := tc() - resp, err := c.R().EnableDump().SetBody(`test body`).Post(ts.URL) + resp, err := c.R().EnableDump().SetBody(`test body`).Post(getTestServerURL()) assertResponse(t, resp, err) dump := resp.Dump() assertContains(t, dump, "POST / HTTP/1.1") @@ -18,7 +15,7 @@ func TestRequestDump(t *testing.T) { assertContains(t, dump, "HTTP/1.1 200 OK") assertContains(t, dump, "TestPost: text response") - resp, err = c.R().EnableDumpWithoutRequest().SetBody(`test body`).Post(ts.URL) + resp, err = c.R().EnableDumpWithoutRequest().SetBody(`test body`).Post(getTestServerURL()) assertResponse(t, resp, err) dump = resp.Dump() assertNotContains(t, dump, "POST / HTTP/1.1") @@ -26,7 +23,7 @@ func TestRequestDump(t *testing.T) { assertContains(t, dump, "HTTP/1.1 200 OK") assertContains(t, dump, "TestPost: text response") - resp, err = c.R().EnableDumpWithoutRequestBody().SetBody(`test body`).Post(ts.URL) + resp, err = c.R().EnableDumpWithoutRequestBody().SetBody(`test body`).Post(getTestServerURL()) assertResponse(t, resp, err) dump = resp.Dump() assertContains(t, dump, "POST / HTTP/1.1") @@ -34,7 +31,7 @@ func TestRequestDump(t *testing.T) { assertContains(t, dump, "HTTP/1.1 200 OK") assertContains(t, dump, "TestPost: text response") - resp, err = c.R().EnableDumpWithoutResponse().SetBody(`test body`).Post(ts.URL) + resp, err = c.R().EnableDumpWithoutResponse().SetBody(`test body`).Post(getTestServerURL()) assertResponse(t, resp, err) dump = resp.Dump() assertContains(t, dump, "POST / HTTP/1.1") @@ -42,7 +39,7 @@ func TestRequestDump(t *testing.T) { assertNotContains(t, dump, "HTTP/1.1 200 OK") assertNotContains(t, dump, "TestPost: text response") - resp, err = c.R().EnableDumpWithoutResponseBody().SetBody(`test body`).Post(ts.URL) + resp, err = c.R().EnableDumpWithoutResponseBody().SetBody(`test body`).Post(getTestServerURL()) assertResponse(t, resp, err) dump = resp.Dump() assertContains(t, dump, "POST / HTTP/1.1") @@ -50,7 +47,7 @@ func TestRequestDump(t *testing.T) { assertContains(t, dump, "HTTP/1.1 200 OK") assertNotContains(t, dump, "TestPost: text response") - resp, err = c.R().EnableDumpWithoutHeader().SetBody(`test body`).Post(ts.URL) + resp, err = c.R().EnableDumpWithoutHeader().SetBody(`test body`).Post(getTestServerURL()) assertResponse(t, resp, err) dump = resp.Dump() assertNotContains(t, dump, "POST / HTTP/1.1") @@ -58,7 +55,7 @@ func TestRequestDump(t *testing.T) { assertNotContains(t, dump, "HTTP/1.1 200 OK") assertContains(t, dump, "TestPost: text response") - resp, err = c.R().EnableDumpWithoutBody().SetBody(`test body`).Post(ts.URL) + resp, err = c.R().EnableDumpWithoutBody().SetBody(`test body`).Post(getTestServerURL()) assertResponse(t, resp, err) dump = resp.Dump() assertContains(t, dump, "POST / HTTP/1.1") @@ -72,7 +69,7 @@ func TestRequestDump(t *testing.T) { ResponseHeader: false, ResponseBody: true, } - resp, err = c.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(ts.URL) + resp, err = c.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(getTestServerURL()) assertResponse(t, resp, err) dump = resp.Dump() assertContains(t, dump, "POST / HTTP/1.1") @@ -82,28 +79,26 @@ func TestRequestDump(t *testing.T) { } func TestGet(t *testing.T) { - ts := createGetServer(t) - defer ts.Close() c := tc() - resp, err := c.R().Get(ts.URL) + resp, err := c.R().Get(getTestServerURL()) assertResponse(t, resp, err) assertEqual(t, "TestGet: text response", resp.String()) - resp, err = c.R().Get(ts.URL + "/no-content") + resp, err = c.R().Get(getTestServerURL() + "/no-content") assertResponse(t, resp, err) assertEqual(t, "", resp.String()) - resp, err = c.R().Get(ts.URL + "/json") + resp, err = c.R().Get(getTestServerURL() + "/json") assertResponse(t, resp, err) assertEqual(t, `{"TestGet": "JSON response"}`, resp.String()) assertEqual(t, resp.GetContentType(), "application/json") - resp, err = c.R().Get(ts.URL + "/json-invalid") + resp, err = c.R().Get(getTestServerURL() + "/json-invalid") assertResponse(t, resp, err) assertEqual(t, `TestGet: Invalid JSON`, resp.String()) assertEqual(t, resp.GetContentType(), "application/json") - resp, err = c.R().Get(ts.URL + "/bad-request") + resp, err = c.R().Get(getTestServerURL() + "/bad-request") assertStatus(t, resp, err, http.StatusBadRequest, "400 Bad Request") } From 4ac8dd4f4f7c9804d0b98a367ef0f1ba8dde83c8 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 14:48:00 +0800 Subject: [PATCH 271/843] refactor dump test --- client_test.go | 147 +++++++++++++++++++++++----------------------- req_test.go | 38 ++++++------ request_test.go | 152 +++++++++++++++++++++++------------------------- 3 files changed, 164 insertions(+), 173 deletions(-) diff --git a/client_test.go b/client_test.go index ff722f4f..168c8151 100644 --- a/client_test.go +++ b/client_test.go @@ -6,78 +6,75 @@ import ( ) func TestClientDump(t *testing.T) { - c := tc() - buf := new(bytes.Buffer) - c.EnableDumpAllTo(buf) - resp, err := c.R().SetBody("test body").Post(getTestServerURL()) - assertResponse(t, resp, err) - assertContains(t, buf.String(), "POST / HTTP/1.1") - assertContains(t, buf.String(), "test body") - assertContains(t, buf.String(), "HTTP/1.1 200 OK") - assertContains(t, buf.String(), "TestPost: text response") - - c = tc() - buf = new(bytes.Buffer) - c.EnableDumpAllWithoutHeader().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(getTestServerURL()) - assertResponse(t, resp, err) - assertNotContains(t, buf.String(), "POST / HTTP/1.1") - assertContains(t, buf.String(), "test body") - assertNotContains(t, buf.String(), "HTTP/1.1 200 OK") - assertContains(t, buf.String(), "TestPost: text response") - - c = tc() - buf = new(bytes.Buffer) - c.EnableDumpAllWithoutBody().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(getTestServerURL()) - assertResponse(t, resp, err) - assertContains(t, buf.String(), "POST / HTTP/1.1") - assertNotContains(t, buf.String(), "test body") - assertContains(t, buf.String(), "HTTP/1.1 200 OK") - assertNotContains(t, buf.String(), "TestPost: text response") - - c = tc() - buf = new(bytes.Buffer) - c.EnableDumpAllWithoutRequest().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(getTestServerURL()) - assertResponse(t, resp, err) - assertNotContains(t, buf.String(), "POST / HTTP/1.1") - assertNotContains(t, buf.String(), "test body") - assertContains(t, buf.String(), "HTTP/1.1 200 OK") - assertContains(t, buf.String(), "TestPost: text response") - - c = tc() - buf = new(bytes.Buffer) - c.EnableDumpAllWithoutRequestBody().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(getTestServerURL()) - assertResponse(t, resp, err) - assertContains(t, buf.String(), "POST / HTTP/1.1") - assertNotContains(t, buf.String(), "test body") - assertContains(t, buf.String(), "HTTP/1.1 200 OK") - assertContains(t, buf.String(), "TestPost: text response") - - c = tc() - buf = new(bytes.Buffer) - c.EnableDumpAllWithoutResponse().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(getTestServerURL()) - assertResponse(t, resp, err) - assertContains(t, buf.String(), "POST / HTTP/1.1") - assertContains(t, buf.String(), "test body") - assertNotContains(t, buf.String(), "HTTP/1.1 200 OK") - assertNotContains(t, buf.String(), "TestPost: text response") + testCases := []func(r *Client, reqHeader, reqBody, respHeader, respBody *bool){ + func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpAll() + *reqHeader = true + *reqBody = true + *respHeader = true + *respBody = true + }, + func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpAllWithoutRequest() + *reqHeader = false + *reqBody = false + *respHeader = true + *respBody = true + }, + func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpAllWithoutRequestBody() + *reqHeader = true + *reqBody = false + *respHeader = true + *respBody = true + }, + func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpAllWithoutResponse() + *reqHeader = true + *reqBody = true + *respHeader = false + *respBody = false + }, + func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpAllWithoutResponseBody() + *reqHeader = true + *reqBody = true + *respHeader = true + *respBody = false + }, + func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpAllWithoutHeader() + *reqHeader = false + *reqBody = true + *respHeader = false + *respBody = true + }, + func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpAllWithoutBody() + *reqHeader = true + *reqBody = false + *respHeader = true + *respBody = false + }, + } - c = tc() - buf = new(bytes.Buffer) - c.EnableDumpAllWithoutResponseBody().EnableDumpAllTo(buf) - resp, err = c.R().SetBody("test body").Post(getTestServerURL()) - assertResponse(t, resp, err) - assertContains(t, buf.String(), "POST / HTTP/1.1") - assertContains(t, buf.String(), "test body") - assertContains(t, buf.String(), "HTTP/1.1 200 OK") - assertNotContains(t, buf.String(), "TestPost: text response") + for _, fn := range testCases { + c := tc() + buf := new(bytes.Buffer) + c.EnableDumpAllTo(buf) + var reqHeader, reqBody, respHeader, respBody bool + fn(c, &reqHeader, &reqBody, &respHeader, &respBody) + resp, err := c.R().SetBody(`test body`).Post("/") + assertResponse(t, resp, err) + dump := buf.String() + assertContains(t, dump, "POST / HTTP/1.1", reqHeader) + assertContains(t, dump, "test body", reqBody) + assertContains(t, dump, "HTTP/1.1 200 OK", respHeader) + assertContains(t, dump, "TestPost: text response", respBody) + } - c = tc() - buf = new(bytes.Buffer) + c := tc() + buf := new(bytes.Buffer) opt := &DumpOptions{ RequestHeader: true, RequestBody: false, @@ -86,10 +83,10 @@ func TestClientDump(t *testing.T) { Output: buf, } c.SetCommonDumpOptions(opt).EnableDumpAll() - resp, err = c.R().SetBody("test body").Post(getTestServerURL()) + resp, err := c.R().SetBody("test body").Post("/") assertResponse(t, resp, err) - assertContains(t, buf.String(), "POST / HTTP/1.1") - assertNotContains(t, buf.String(), "test body") - assertNotContains(t, buf.String(), "HTTP/1.1 200 OK") - assertContains(t, buf.String(), "TestPost: text response") + assertContains(t, buf.String(), "POST / HTTP/1.1", true) + assertContains(t, buf.String(), "test body", false) + assertContains(t, buf.String(), "HTTP/1.1 200 OK", false) + assertContains(t, buf.String(), "TestPost: text response", true) } diff --git a/req_test.go b/req_test.go index d04e9802..cbc2efd3 100644 --- a/req_test.go +++ b/req_test.go @@ -11,8 +11,12 @@ import ( "testing" ) +func tr() *Request { + return tc().R() +} + func tc() *Client { - return C().EnableDebugLog() + return C().SetBaseURL(getTestServerURL()) } var testDataPath string @@ -36,26 +40,16 @@ func getTestServerURL() string { } func handlePost(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("TestPost: text response")) + switch r.URL.Path { + case "/": + w.Write([]byte("TestPost: text response")) + } } func handleGet(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": w.Write([]byte("TestGet: text response")) - case "/no-content": - w.Write([]byte("")) - case "/json": - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"TestGet": "JSON response"}`)) - case "/json-invalid": - w.Header().Set("Content-Type", "application/json") - w.Write([]byte("TestGet: Invalid JSON")) - case "/long-text": - w.Write([]byte("TestGet: text response with size > 30")) - case "/long-json": - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"TestGet": "JSON response with size > 30"}`)) case "/bad-request": w.WriteHeader(http.StatusBadRequest) case "/host-header": @@ -115,9 +109,16 @@ func assertNotContains(t *testing.T, s, substr string) { } } -func assertContains(t *testing.T, s, substr string) { - if !strings.Contains(s, substr) { - t.Errorf("%q is not included in %s", substr, s) +func assertContains(t *testing.T, s, substr string, shouldContain bool) { + isContain := strings.Contains(s, substr) + if shouldContain { + if !isContain { + t.Errorf("%q is not included in %s", substr, s) + } + } else { + if isContain { + t.Errorf("%q is included in %s", substr, s) + } } } @@ -146,7 +147,6 @@ func assertNotEqual(t *testing.T, e, g interface{}) (r bool) { } else { r = true } - return } diff --git a/request_test.go b/request_test.go index 6514d027..c4a73350 100644 --- a/request_test.go +++ b/request_test.go @@ -6,62 +6,70 @@ import ( ) func TestRequestDump(t *testing.T) { - c := tc() - resp, err := c.R().EnableDump().SetBody(`test body`).Post(getTestServerURL()) - assertResponse(t, resp, err) - dump := resp.Dump() - assertContains(t, dump, "POST / HTTP/1.1") - assertContains(t, dump, "test body") - assertContains(t, dump, "HTTP/1.1 200 OK") - assertContains(t, dump, "TestPost: text response") - - resp, err = c.R().EnableDumpWithoutRequest().SetBody(`test body`).Post(getTestServerURL()) - assertResponse(t, resp, err) - dump = resp.Dump() - assertNotContains(t, dump, "POST / HTTP/1.1") - assertNotContains(t, dump, "test body") - assertContains(t, dump, "HTTP/1.1 200 OK") - assertContains(t, dump, "TestPost: text response") - - resp, err = c.R().EnableDumpWithoutRequestBody().SetBody(`test body`).Post(getTestServerURL()) - assertResponse(t, resp, err) - dump = resp.Dump() - assertContains(t, dump, "POST / HTTP/1.1") - assertNotContains(t, dump, "test body") - assertContains(t, dump, "HTTP/1.1 200 OK") - assertContains(t, dump, "TestPost: text response") - - resp, err = c.R().EnableDumpWithoutResponse().SetBody(`test body`).Post(getTestServerURL()) - assertResponse(t, resp, err) - dump = resp.Dump() - assertContains(t, dump, "POST / HTTP/1.1") - assertContains(t, dump, "test body") - assertNotContains(t, dump, "HTTP/1.1 200 OK") - assertNotContains(t, dump, "TestPost: text response") - - resp, err = c.R().EnableDumpWithoutResponseBody().SetBody(`test body`).Post(getTestServerURL()) - assertResponse(t, resp, err) - dump = resp.Dump() - assertContains(t, dump, "POST / HTTP/1.1") - assertContains(t, dump, "test body") - assertContains(t, dump, "HTTP/1.1 200 OK") - assertNotContains(t, dump, "TestPost: text response") - - resp, err = c.R().EnableDumpWithoutHeader().SetBody(`test body`).Post(getTestServerURL()) - assertResponse(t, resp, err) - dump = resp.Dump() - assertNotContains(t, dump, "POST / HTTP/1.1") - assertContains(t, dump, "test body") - assertNotContains(t, dump, "HTTP/1.1 200 OK") - assertContains(t, dump, "TestPost: text response") + testCases := []func(r *Request, reqHeader, reqBody, respHeader, respBody *bool){ + func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDump() + *reqHeader = true + *reqBody = true + *respHeader = true + *respBody = true + }, + func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutRequest() + *reqHeader = false + *reqBody = false + *respHeader = true + *respBody = true + }, + func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutRequestBody() + *reqHeader = true + *reqBody = false + *respHeader = true + *respBody = true + }, + func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutResponse() + *reqHeader = true + *reqBody = true + *respHeader = false + *respBody = false + }, + func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutResponseBody() + *reqHeader = true + *reqBody = true + *respHeader = true + *respBody = false + }, + func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutHeader() + *reqHeader = false + *reqBody = true + *respHeader = false + *respBody = true + }, + func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutBody() + *reqHeader = true + *reqBody = false + *respHeader = true + *respBody = false + }, + } - resp, err = c.R().EnableDumpWithoutBody().SetBody(`test body`).Post(getTestServerURL()) - assertResponse(t, resp, err) - dump = resp.Dump() - assertContains(t, dump, "POST / HTTP/1.1") - assertNotContains(t, dump, "test body") - assertContains(t, dump, "HTTP/1.1 200 OK") - assertNotContains(t, dump, "TestPost: text response") + for _, fn := range testCases { + r := tr() + var reqHeader, reqBody, respHeader, respBody bool + fn(r, &reqHeader, &reqBody, &respHeader, &respBody) + resp, err := r.SetBody(`test body`).Post("/") + assertResponse(t, resp, err) + dump := resp.Dump() + assertContains(t, dump, "POST / HTTP/1.1", reqHeader) + assertContains(t, dump, "test body", reqBody) + assertContains(t, dump, "HTTP/1.1 200 OK", respHeader) + assertContains(t, dump, "TestPost: text response", respBody) + } opt := &DumpOptions{ RequestHeader: true, @@ -69,36 +77,22 @@ func TestRequestDump(t *testing.T) { ResponseHeader: false, ResponseBody: true, } - resp, err = c.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(getTestServerURL()) + resp, err := tr().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(getTestServerURL()) assertResponse(t, resp, err) - dump = resp.Dump() - assertContains(t, dump, "POST / HTTP/1.1") - assertNotContains(t, dump, "test body") - assertNotContains(t, dump, "HTTP/1.1 200 OK") - assertContains(t, dump, "TestPost: text response") + dump := resp.Dump() + assertContains(t, dump, "POST / HTTP/1.1", true) + assertContains(t, dump, "test body", false) + assertContains(t, dump, "HTTP/1.1 200 OK", false) + assertContains(t, dump, "TestPost: text response", true) } func TestGet(t *testing.T) { - - c := tc() - resp, err := c.R().Get(getTestServerURL()) + resp, err := tr().Get("/") assertResponse(t, resp, err) assertEqual(t, "TestGet: text response", resp.String()) +} - resp, err = c.R().Get(getTestServerURL() + "/no-content") - assertResponse(t, resp, err) - assertEqual(t, "", resp.String()) - - resp, err = c.R().Get(getTestServerURL() + "/json") - assertResponse(t, resp, err) - assertEqual(t, `{"TestGet": "JSON response"}`, resp.String()) - assertEqual(t, resp.GetContentType(), "application/json") - - resp, err = c.R().Get(getTestServerURL() + "/json-invalid") - assertResponse(t, resp, err) - assertEqual(t, `TestGet: Invalid JSON`, resp.String()) - assertEqual(t, resp.GetContentType(), "application/json") - - resp, err = c.R().Get(getTestServerURL() + "/bad-request") +func TestBadRequest(t *testing.T) { + resp, err := tr().Get("/bad-request") assertStatus(t, resp, err, http.StatusBadRequest, "400 Bad Request") } From 8cfb118fe41d3ba87a967ce4b6e0f0d2092f9ba7 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 15:33:29 +0800 Subject: [PATCH 272/843] Add TestQueryParam and TestCustomUserAgent --- client_test.go | 4 +-- req_test.go | 15 ++++----- request_test.go | 86 +++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 92 insertions(+), 13 deletions(-) diff --git a/client_test.go b/client_test.go index 168c8151..cc0d3607 100644 --- a/client_test.go +++ b/client_test.go @@ -65,7 +65,7 @@ func TestClientDump(t *testing.T) { var reqHeader, reqBody, respHeader, respBody bool fn(c, &reqHeader, &reqBody, &respHeader, &respBody) resp, err := c.R().SetBody(`test body`).Post("/") - assertResponse(t, resp, err) + assertSucess(t, resp, err) dump := buf.String() assertContains(t, dump, "POST / HTTP/1.1", reqHeader) assertContains(t, dump, "test body", reqBody) @@ -84,7 +84,7 @@ func TestClientDump(t *testing.T) { } c.SetCommonDumpOptions(opt).EnableDumpAll() resp, err := c.R().SetBody("test body").Post("/") - assertResponse(t, resp, err) + assertSucess(t, resp, err) assertContains(t, buf.String(), "POST / HTTP/1.1", true) assertContains(t, buf.String(), "test body", false) assertContains(t, buf.String(), "HTTP/1.1 200 OK", false) diff --git a/req_test.go b/req_test.go index cbc2efd3..b1c12bcc 100644 --- a/req_test.go +++ b/req_test.go @@ -54,6 +54,12 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) case "/host-header": w.Write([]byte(r.Host)) + case "/user-agent": + w.Write([]byte(r.Header.Get(hdrUserAgentKey))) + case "/content-type": + w.Write([]byte(r.Header.Get(hdrContentTypeKey))) + case "/query-parameter": + w.Write([]byte(r.URL.RawQuery)) } } @@ -76,7 +82,7 @@ func assertStatus(t *testing.T, resp *Response, err error, statusCode int, statu assertEqual(t, status, resp.Status) } -func assertResponse(t *testing.T, resp *Response, err error) { +func assertSucess(t *testing.T, resp *Response, err error) { assertError(t, err) assertNotNil(t, resp) assertNotNil(t, resp.Body) @@ -103,12 +109,6 @@ func assertType(t *testing.T, typ, v interface{}) { } } -func assertNotContains(t *testing.T, s, substr string) { - if strings.Contains(s, substr) { - t.Errorf("%q is included in %s", substr, s) - } -} - func assertContains(t *testing.T, s, substr string, shouldContain bool) { isContain := strings.Contains(s, substr) if shouldContain { @@ -158,7 +158,6 @@ func isNil(v interface{}) bool { if v == nil { return true } - rv := reflect.ValueOf(v) kind := rv.Kind() if kind >= reflect.Chan && kind <= reflect.Slice && rv.IsNil() { diff --git a/request_test.go b/request_test.go index c4a73350..bbe177f3 100644 --- a/request_test.go +++ b/request_test.go @@ -63,7 +63,7 @@ func TestRequestDump(t *testing.T) { var reqHeader, reqBody, respHeader, respBody bool fn(r, &reqHeader, &reqBody, &respHeader, &respBody) resp, err := r.SetBody(`test body`).Post("/") - assertResponse(t, resp, err) + assertSucess(t, resp, err) dump := resp.Dump() assertContains(t, dump, "POST / HTTP/1.1", reqHeader) assertContains(t, dump, "test body", reqBody) @@ -78,7 +78,7 @@ func TestRequestDump(t *testing.T) { ResponseBody: true, } resp, err := tr().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(getTestServerURL()) - assertResponse(t, resp, err) + assertSucess(t, resp, err) dump := resp.Dump() assertContains(t, dump, "POST / HTTP/1.1", true) assertContains(t, dump, "test body", false) @@ -88,7 +88,7 @@ func TestRequestDump(t *testing.T) { func TestGet(t *testing.T) { resp, err := tr().Get("/") - assertResponse(t, resp, err) + assertSucess(t, resp, err) assertEqual(t, "TestGet: text response", resp.String()) } @@ -96,3 +96,83 @@ func TestBadRequest(t *testing.T) { resp, err := tr().Get("/bad-request") assertStatus(t, resp, err, http.StatusBadRequest, "400 Bad Request") } + +func TestCustomUserAgent(t *testing.T) { + customUserAgent := "My Custom User Agent" + resp, err := tr().SetHeader(hdrUserAgentKey, customUserAgent).Get("/user-agent") + assertSucess(t, resp, err) + assertEqual(t, customUserAgent, resp.String()) +} + +func TestQueryParam(t *testing.T) { + c := tc() + + // SetQueryParam + resp, err := c.R(). + SetQueryParam("key1", "value1"). + SetQueryParam("key2", "value2"). + SetQueryParam("key3", "value3"). + Get("/query-parameter") + assertSucess(t, resp, err) + assertEqual(t, "key1=value1&key2=value2&key3=value3", resp.String()) + + // SetQueryString + resp, err = c.R(). + SetQueryString("key1=value1&key2=value2&key3=value3"). + Get("/query-parameter") + assertSucess(t, resp, err) + assertEqual(t, "key1=value1&key2=value2&key3=value3", resp.String()) + + // SetQueryParams + resp, err = c.R(). + SetQueryParams(map[string]string{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + }). + Get("/query-parameter") + assertSucess(t, resp, err) + assertEqual(t, "key1=value1&key2=value2&key3=value3", resp.String()) + + // SetQueryParam & SetQueryParams & SetQueryString + resp, err = c.R(). + SetQueryParam("key1", "value1"). + SetQueryParams(map[string]string{ + "key2": "value2", + "key3": "value3", + }). + SetQueryString("key4=value4&key5=value5"). + Get("/query-parameter") + assertSucess(t, resp, err) + assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=value4&key5=value5", resp.String()) + + // Set same param to override + resp, err = c.R(). + SetQueryParam("key1", "value1"). + SetQueryParams(map[string]string{ + "key2": "value2", + "key3": "value3", + }). + SetQueryString("key4=value4&key5=value5"). + SetQueryParam("key1", "value11"). + SetQueryParam("key2", "value22"). + SetQueryParam("key4", "value44"). + Get("/query-parameter") + assertSucess(t, resp, err) + assertEqual(t, "key1=value11&key2=value22&key3=value3&key4=value44&key5=value5", resp.String()) + + // Add same param without override + resp, err = c.R(). + SetQueryParam("key1", "value1"). + SetQueryParams(map[string]string{ + "key2": "value2", + "key3": "value3", + }). + SetQueryString("key4=value4&key5=value5"). + AddQueryParam("key1", "value11"). + AddQueryParam("key2", "value22"). + AddQueryParam("key4", "value44"). + Get("/query-parameter") + assertSucess(t, resp, err) + assertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5", resp.String()) +} From 0fbfd819ec880e1ceb260f79c17df09a15948841 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 15:47:26 +0800 Subject: [PATCH 273/843] Add param at client level in TestQueryParam --- request_test.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/request_test.go b/request_test.go index bbe177f3..97537a95 100644 --- a/request_test.go +++ b/request_test.go @@ -107,6 +107,15 @@ func TestCustomUserAgent(t *testing.T) { func TestQueryParam(t *testing.T) { c := tc() + // Set query param at client level, should be overwritten at request level + c.SetCommonQueryParam("key1", "client"). + SetCommonQueryParams(map[string]string{ + "key2": "client", + "key3": "client", + }). + SetCommonQueryString("key4=client&key5=client"). + AddCommonQueryParam("key5", "extra") + // SetQueryParam resp, err := c.R(). SetQueryParam("key1", "value1"). @@ -114,14 +123,14 @@ func TestQueryParam(t *testing.T) { SetQueryParam("key3", "value3"). Get("/query-parameter") assertSucess(t, resp, err) - assertEqual(t, "key1=value1&key2=value2&key3=value3", resp.String()) + assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryString resp, err = c.R(). SetQueryString("key1=value1&key2=value2&key3=value3"). Get("/query-parameter") assertSucess(t, resp, err) - assertEqual(t, "key1=value1&key2=value2&key3=value3", resp.String()) + assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryParams resp, err = c.R(). @@ -132,7 +141,7 @@ func TestQueryParam(t *testing.T) { }). Get("/query-parameter") assertSucess(t, resp, err) - assertEqual(t, "key1=value1&key2=value2&key3=value3", resp.String()) + assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryParam & SetQueryParams & SetQueryString resp, err = c.R(). From 523f6a8c95230e540ca94c08bf8811ab1b13aa15 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 16:01:30 +0800 Subject: [PATCH 274/843] Add TestPathParam --- req_test.go | 11 +++++++++++ request_test.go | 10 ++++++++++ 2 files changed, 21 insertions(+) diff --git a/req_test.go b/req_test.go index b1c12bcc..47a7d7e1 100644 --- a/req_test.go +++ b/req_test.go @@ -1,6 +1,7 @@ package req import ( + "fmt" "net/http" "net/http/httptest" "os" @@ -46,6 +47,12 @@ func handlePost(w http.ResponseWriter, r *http.Request) { } } +func handleGetUserProfile(w http.ResponseWriter, r *http.Request) { + user := strings.TrimLeft(r.URL.Path, "/user") + user = strings.TrimSuffix(user, "/profile") + w.Write([]byte(fmt.Sprintf("%s's profile", user))) +} + func handleGet(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": @@ -60,6 +67,10 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.Header.Get(hdrContentTypeKey))) case "/query-parameter": w.Write([]byte(r.URL.RawQuery)) + default: + if strings.HasPrefix(r.URL.Path, "/user") { + handleGetUserProfile(w, r) + } } } diff --git a/request_test.go b/request_test.go index 97537a95..0a41418c 100644 --- a/request_test.go +++ b/request_test.go @@ -1,6 +1,7 @@ package req import ( + "fmt" "net/http" "testing" ) @@ -185,3 +186,12 @@ func TestQueryParam(t *testing.T) { assertSucess(t, resp, err) assertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5", resp.String()) } + +func TestPathParam(t *testing.T) { + username := "imroc" + resp, err := tr(). + SetPathParam("username", username). + Get("/user/{username}/profile") + assertSucess(t, resp, err) + assertEqual(t, fmt.Sprintf("%s's profile", username), resp.String()) +} From a460b32e8510cb26e16309f18de3bdf26f57cf64 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 17:12:03 +0800 Subject: [PATCH 275/843] Add TestHeader --- req_test.go | 6 ++++++ request_test.go | 17 ++++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/req_test.go b/req_test.go index 47a7d7e1..c479a5d6 100644 --- a/req_test.go +++ b/req_test.go @@ -1,6 +1,7 @@ package req import ( + "encoding/json" "fmt" "net/http" "net/http/httptest" @@ -61,6 +62,10 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) case "/host-header": w.Write([]byte(r.Host)) + case "/header": + b, _ := json.Marshal(r.Header) + w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Write(b) case "/user-agent": w.Write([]byte(r.Header.Get(hdrUserAgentKey))) case "/content-type": @@ -69,6 +74,7 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.URL.RawQuery)) default: if strings.HasPrefix(r.URL.Path, "/user") { + r.Cookies() handleGetUserProfile(w, r) } } diff --git a/request_test.go b/request_test.go index 0a41418c..d359066b 100644 --- a/request_test.go +++ b/request_test.go @@ -98,11 +98,26 @@ func TestBadRequest(t *testing.T) { assertStatus(t, resp, err, http.StatusBadRequest, "400 Bad Request") } -func TestCustomUserAgent(t *testing.T) { +func TestHeader(t *testing.T) { + // Set User-Agent customUserAgent := "My Custom User Agent" resp, err := tr().SetHeader(hdrUserAgentKey, customUserAgent).Get("/user-agent") assertSucess(t, resp, err) assertEqual(t, customUserAgent, resp.String()) + + // Set custom header + headers := make(http.Header) + resp, err = tr(). + SetHeader("header1", "value1"). + SetHeaders(map[string]string{ + "header2": "value2", + "header3": "value3", + }).SetResult(&headers). + Get("/header") + assertSucess(t, resp, err) + assertEqual(t, "value1", headers.Get("header1")) + assertEqual(t, "value2", headers.Get("header2")) + assertEqual(t, "value3", headers.Get("header3")) } func TestQueryParam(t *testing.T) { From 57079e0bec0fa059b02a25cb69d41b82881ce766 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 17:19:00 +0800 Subject: [PATCH 276/843] Add TestCookie --- req_test.go | 1 - request_test.go | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/req_test.go b/req_test.go index c479a5d6..cd6d897d 100644 --- a/req_test.go +++ b/req_test.go @@ -74,7 +74,6 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.URL.RawQuery)) default: if strings.HasPrefix(r.URL.Path, "/user") { - r.Cookies() handleGetUserProfile(w, r) } } diff --git a/request_test.go b/request_test.go index d359066b..a4bc85cd 100644 --- a/request_test.go +++ b/request_test.go @@ -98,6 +98,22 @@ func TestBadRequest(t *testing.T) { assertStatus(t, resp, err, http.StatusBadRequest, "400 Bad Request") } +func TestCookie(t *testing.T) { + headers := make(http.Header) + resp, err := tr().SetCookies( + &http.Cookie{ + Name: "cookie1", + Value: "value1", + }, + &http.Cookie{ + Name: "cookie2", + Value: "value2", + }, + ).SetResult(&headers).Get("/header") + assertSucess(t, resp, err) + assertEqual(t, "cookie1=value1; cookie2=value2", headers.Get("Cookie")) +} + func TestHeader(t *testing.T) { // Set User-Agent customUserAgent := "My Custom User Agent" From 024cd2773a41217dd6cef7738c567d40a9cc2419 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 17:36:19 +0800 Subject: [PATCH 277/843] Add TestAuth --- request_test.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/request_test.go b/request_test.go index a4bc85cd..0bfce4b4 100644 --- a/request_test.go +++ b/request_test.go @@ -114,6 +114,25 @@ func TestCookie(t *testing.T) { assertEqual(t, "cookie1=value1; cookie2=value2", headers.Get("Cookie")) } +func TestAuth(t *testing.T) { + headers := make(http.Header) + resp, err := tr(). + SetBasicAuth("imroc", "123456"). + SetResult(&headers). + Get("/header") + assertSucess(t, resp, err) + assertEqual(t, "Basic aW1yb2M6MTIzNDU2", headers.Get("Authorization")) + + token := "NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4" + headers = make(http.Header) + resp, err = tr(). + SetBearerAuthToken(token). + SetResult(&headers). + Get("/header") + assertSucess(t, resp, err) + assertEqual(t, "Bearer "+token, headers.Get("Authorization")) +} + func TestHeader(t *testing.T) { // Set User-Agent customUserAgent := "My Custom User Agent" From 378a0e9b342002e0059d379779bcc5244456af45 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 17:43:56 +0800 Subject: [PATCH 278/843] SetFileReader depend on SetFileUpload --- request.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/request.go b/request.go index 7ced8b1b..28bef443 100644 --- a/request.go +++ b/request.go @@ -177,8 +177,7 @@ func SetFileReader(paramName, filePath string, reader io.Reader) *Request { // SetFileReader set up a multipart form with a reader to upload file. func (r *Request) SetFileReader(paramName, filename string, reader io.Reader) *Request { - r.isMultiPart = true - r.uploadFiles = append(r.uploadFiles, &FileUpload{ + r.SetFileUpload(FileUpload{ ParamName: paramName, FileName: filename, File: reader, From 9163a8d02a356e5b8e24c59e790406de1639ad00 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 17:53:29 +0800 Subject: [PATCH 279/843] Set plain text Content-Type if SetBodyString or SetBodyBytes --- req.go | 14 +++++++------- request.go | 13 +++++++------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/req.go b/req.go index 9710cae4..84e7a908 100644 --- a/req.go +++ b/req.go @@ -6,13 +6,13 @@ import ( ) const ( - hdrUserAgentKey = "User-Agent" - hdrUserAgentValue = "req/v3 (https://github.com/imroc/req)" - hdrContentTypeKey = "Content-Type" - plainTextType = "text/plain; charset=utf-8" - jsonContentType = "application/json; charset=utf-8" - xmlContentType = "text/xml; charset=utf-8" - formContentType = "application/x-www-form-urlencoded" + hdrUserAgentKey = "User-Agent" + hdrUserAgentValue = "req/v3 (https://github.com/imroc/req)" + hdrContentTypeKey = "Content-Type" + plainTextContentType = "text/plain; charset=utf-8" + jsonContentType = "application/json; charset=utf-8" + xmlContentType = "text/xml; charset=utf-8" + formContentType = "application/x-www-form-urlencoded" ) type kv struct { diff --git a/request.go b/request.go index 28bef443..0a499b91 100644 --- a/request.go +++ b/request.go @@ -658,7 +658,7 @@ func SetBodyBytes(body []byte) *Request { // SetBodyBytes set the request body as []byte. func (r *Request) SetBodyBytes(body []byte) *Request { r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) - return r + return r.SetContentType(plainTextContentType) } // SetBodyString is a global wrapper methods which delegated @@ -670,7 +670,7 @@ func SetBodyString(body string) *Request { // SetBodyString set the request body as string. func (r *Request) SetBodyString(body string) *Request { r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(body)) - return r + return r.SetContentType(plainTextContentType) } // SetBodyJsonString is a global wrapper methods which delegated @@ -713,7 +713,8 @@ func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { r.appendError(err) return r } - return r.SetContentType(jsonContentType).SetBodyBytes(b) + r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(b)) + return r.SetContentType(jsonContentType) } // SetBodyXmlString is a global wrapper methods which delegated @@ -756,7 +757,8 @@ func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { r.appendError(err) return r } - return r.SetContentType(xmlContentType).SetBodyBytes(b) + r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(b)) + return r.SetContentType(xmlContentType) } // SetContentType is a global wrapper methods which delegated @@ -767,8 +769,7 @@ func SetContentType(contentType string) *Request { // SetContentType set the `Content-Type` for the request. func (r *Request) SetContentType(contentType string) *Request { - r.SetHeader(hdrContentTypeKey, contentType) - return r + return r.SetHeader(hdrContentTypeKey, contentType) } // Context method returns the Context if its already set in request From 8a92e80adf6f287ad70cb6cb5e30d94dec15ee7c Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 20:12:52 +0800 Subject: [PATCH 280/843] Add TestSetBodyContent and TestSetBodyJson --- req_test.go | 15 +++++++ request_test.go | 103 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/req_test.go b/req_test.go index cd6d897d..8715fe6b 100644 --- a/req_test.go +++ b/req_test.go @@ -3,6 +3,7 @@ package req import ( "encoding/json" "fmt" + "io/ioutil" "net/http" "net/http/httptest" "os" @@ -41,10 +42,24 @@ func getTestServerURL() string { return testServer.URL } +type echo struct { + Header http.Header `json:"header"` + Body string `json:"body"` +} + func handlePost(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": w.Write([]byte("TestPost: text response")) + case "/echo": + b, _ := ioutil.ReadAll(r.Body) + e := echo{ + Header: r.Header, + Body: string(b), + } + result, _ := json.Marshal(&e) + w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Write(result) } } diff --git a/request_test.go b/request_test.go index 0bfce4b4..41a8c9a4 100644 --- a/request_test.go +++ b/request_test.go @@ -1,8 +1,10 @@ package req import ( + "encoding/json" "fmt" "net/http" + "strings" "testing" ) @@ -98,6 +100,107 @@ func TestBadRequest(t *testing.T) { assertStatus(t, resp, err, http.StatusBadRequest, "400 Bad Request") } +func TestSetBodyJson(t *testing.T) { + type User struct { + Username string `json:"username"` + } + + assertUsername := func(username string) func(e *echo) { + return func(e *echo) { + var user User + err := json.Unmarshal([]byte(e.Body), &user) + assertError(t, err) + assertEqual(t, username, user.Username) + } + } + testCases := []struct { + Set func(r *Request) + Assert func(e *echo) + }{ + { // SetBody with map + Set: func(r *Request) { + m := map[string]interface{}{ + "username": "imroc", + } + r.SetBody(&m) + }, + Assert: assertUsername("imroc"), + }, + { // SetBodyJsonMarshal with map + Set: func(r *Request) { + m := map[string]interface{}{ + "username": "imroc", + } + r.SetBodyJsonMarshal(&m) + }, + Assert: assertUsername("imroc"), + }, + { // SetBody with struct + Set: func(r *Request) { + var user User + user.Username = "imroc" + r.SetBody(&user) + }, + Assert: assertUsername("imroc"), + }, + { // SetBodyJsonMarshal with struct + Set: func(r *Request) { + var user User + user.Username = "imroc" + r.SetBodyJsonMarshal(&user) + }, + Assert: assertUsername("imroc"), + }, + } + + for _, c := range testCases { + r := tr() + c.Set(r) + var e echo + resp, err := r.SetResult(&e).Post("/echo") + assertSucess(t, resp, err) + c.Assert(&e) + } +} + +func TestSetBodyContent(t *testing.T) { + var e echo + testBody := "test body" + + testCases := []func(r *Request){ + func(r *Request) { // SetBody with string + r.SetBody(testBody) + }, + func(r *Request) { // SetBody with []byte + r.SetBody([]byte(testBody)) + }, + func(r *Request) { // SetBodyString + r.SetBodyString(testBody) + }, + func(r *Request) { // SetBodyBytes + r.SetBodyBytes([]byte(testBody)) + }, + } + + for _, fn := range testCases { + r := tr() + fn(r) + var e echo + resp, err := r.SetResult(&e).Post("/echo") + assertSucess(t, resp, err) + assertEqual(t, plainTextContentType, e.Header.Get(hdrContentTypeKey)) + assertEqual(t, testBody, e.Body) + } + + // Set Reader + testBodyReader := strings.NewReader(testBody) + e = echo{} + resp, err := tr().SetBody(testBodyReader).SetResult(&e).Post("/echo") + assertSucess(t, resp, err) + assertEqual(t, testBody, e.Body) + assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) +} + func TestCookie(t *testing.T) { headers := make(http.Header) resp, err := tr().SetCookies( From 18181167309238ea4d04f2b00ece51aa8bd84453 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 21:05:40 +0800 Subject: [PATCH 281/843] Add TestSetBodyMarshal --- client_test.go | 4 +-- internal/util/util.go | 10 ++----- middleware.go | 4 +-- req_test.go | 8 +++--- request_test.go | 65 ++++++++++++++++++++++++++++++------------- 5 files changed, 55 insertions(+), 36 deletions(-) diff --git a/client_test.go b/client_test.go index cc0d3607..90dfe27e 100644 --- a/client_test.go +++ b/client_test.go @@ -65,7 +65,7 @@ func TestClientDump(t *testing.T) { var reqHeader, reqBody, respHeader, respBody bool fn(c, &reqHeader, &reqBody, &respHeader, &respBody) resp, err := c.R().SetBody(`test body`).Post("/") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) dump := buf.String() assertContains(t, dump, "POST / HTTP/1.1", reqHeader) assertContains(t, dump, "test body", reqBody) @@ -84,7 +84,7 @@ func TestClientDump(t *testing.T) { } c.SetCommonDumpOptions(opt).EnableDumpAll() resp, err := c.R().SetBody("test body").Post("/") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertContains(t, buf.String(), "POST / HTTP/1.1", true) assertContains(t, buf.String(), "test body", false) assertContains(t, buf.String(), "HTTP/1.1 200 OK", false) diff --git a/internal/util/util.go b/internal/util/util.go index 43d2582d..82d488f9 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -5,23 +5,17 @@ import ( "encoding/base64" "os" "reflect" - "regexp" "strings" ) -var ( - jsonCheck = regexp.MustCompile(`(?i:(application|text)/(json|.*\+json|json\-.*)(;|$))`) - xmlCheck = regexp.MustCompile(`(?i:(application|text)/(xml|.*\+xml)(;|$))`) -) - // IsJSONType method is to check JSON content type or not func IsJSONType(ct string) bool { - return jsonCheck.MatchString(ct) + return strings.Contains(ct, "json") } // IsXMLType method is to check XML content type or not func IsXMLType(ct string) bool { - return xmlCheck.MatchString(ct) + return strings.Contains(ct, "xml") } func GetPointer(v interface{}) interface{} { diff --git a/middleware.go b/middleware.go index 10819d3c..bdc0d86d 100644 --- a/middleware.go +++ b/middleware.go @@ -124,13 +124,13 @@ func handleMarshalBody(c *Client, r *Request) error { if err != nil { return err } - r.SetBodyBytes(body) + r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) } else { body, err := c.jsonMarshal(r.marshalBody) if err != nil { return err } - r.SetBodyBytes(body) + r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) } return nil } diff --git a/req_test.go b/req_test.go index 8715fe6b..fcd99fc0 100644 --- a/req_test.go +++ b/req_test.go @@ -43,8 +43,8 @@ func getTestServerURL() string { } type echo struct { - Header http.Header `json:"header"` - Body string `json:"body"` + Header http.Header `json:"header" xml:"header"` + Body string `json:"body" xml:"body"` } func handlePost(w http.ResponseWriter, r *http.Request) { @@ -57,8 +57,8 @@ func handlePost(w http.ResponseWriter, r *http.Request) { Header: r.Header, Body: string(b), } - result, _ := json.Marshal(&e) w.Header().Set(hdrContentTypeKey, jsonContentType) + result, _ := json.Marshal(&e) w.Write(result) } } @@ -113,7 +113,7 @@ func assertStatus(t *testing.T, resp *Response, err error, statusCode int, statu assertEqual(t, status, resp.Status) } -func assertSucess(t *testing.T, resp *Response, err error) { +func assertSuccess(t *testing.T, resp *Response, err error) { assertError(t, err) assertNotNil(t, resp) assertNotNil(t, resp.Body) diff --git a/request_test.go b/request_test.go index 41a8c9a4..7719769e 100644 --- a/request_test.go +++ b/request_test.go @@ -2,6 +2,7 @@ package req import ( "encoding/json" + "encoding/xml" "fmt" "net/http" "strings" @@ -66,7 +67,7 @@ func TestRequestDump(t *testing.T) { var reqHeader, reqBody, respHeader, respBody bool fn(r, &reqHeader, &reqBody, &respHeader, &respBody) resp, err := r.SetBody(`test body`).Post("/") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) dump := resp.Dump() assertContains(t, dump, "POST / HTTP/1.1", reqHeader) assertContains(t, dump, "test body", reqBody) @@ -81,7 +82,7 @@ func TestRequestDump(t *testing.T) { ResponseBody: true, } resp, err := tr().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(getTestServerURL()) - assertSucess(t, resp, err) + assertSuccess(t, resp, err) dump := resp.Dump() assertContains(t, dump, "POST / HTTP/1.1", true) assertContains(t, dump, "test body", false) @@ -91,7 +92,7 @@ func TestRequestDump(t *testing.T) { func TestGet(t *testing.T) { resp, err := tr().Get("/") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, "TestGet: text response", resp.String()) } @@ -100,9 +101,9 @@ func TestBadRequest(t *testing.T) { assertStatus(t, resp, err, http.StatusBadRequest, "400 Bad Request") } -func TestSetBodyJson(t *testing.T) { +func TestSetBodyMarshal(t *testing.T) { type User struct { - Username string `json:"username"` + Username string `json:"username" xml:"username"` } assertUsername := func(username string) func(e *echo) { @@ -113,6 +114,14 @@ func TestSetBodyJson(t *testing.T) { assertEqual(t, username, user.Username) } } + assertUsernameXml := func(username string) func(e *echo) { + return func(e *echo) { + var user User + err := xml.Unmarshal([]byte(e.Body), &user) + assertError(t, err) + assertEqual(t, username, user.Username) + } + } testCases := []struct { Set func(r *Request) Assert func(e *echo) @@ -143,6 +152,14 @@ func TestSetBodyJson(t *testing.T) { }, Assert: assertUsername("imroc"), }, + { // SetBody with struct use xml + Set: func(r *Request) { + var user User + user.Username = "imroc" + r.SetBody(&user).SetContentType(xmlContentType) + }, + Assert: assertUsernameXml("imroc"), + }, { // SetBodyJsonMarshal with struct Set: func(r *Request) { var user User @@ -151,6 +168,14 @@ func TestSetBodyJson(t *testing.T) { }, Assert: assertUsername("imroc"), }, + { // SetBodyXmlMarshal with struct + Set: func(r *Request) { + var user User + user.Username = "imroc" + r.SetBodyXmlMarshal(&user) + }, + Assert: assertUsernameXml("imroc"), + }, } for _, c := range testCases { @@ -158,7 +183,7 @@ func TestSetBodyJson(t *testing.T) { c.Set(r) var e echo resp, err := r.SetResult(&e).Post("/echo") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) c.Assert(&e) } } @@ -187,7 +212,7 @@ func TestSetBodyContent(t *testing.T) { fn(r) var e echo resp, err := r.SetResult(&e).Post("/echo") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, plainTextContentType, e.Header.Get(hdrContentTypeKey)) assertEqual(t, testBody, e.Body) } @@ -196,7 +221,7 @@ func TestSetBodyContent(t *testing.T) { testBodyReader := strings.NewReader(testBody) e = echo{} resp, err := tr().SetBody(testBodyReader).SetResult(&e).Post("/echo") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, testBody, e.Body) assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) } @@ -213,7 +238,7 @@ func TestCookie(t *testing.T) { Value: "value2", }, ).SetResult(&headers).Get("/header") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, "cookie1=value1; cookie2=value2", headers.Get("Cookie")) } @@ -223,7 +248,7 @@ func TestAuth(t *testing.T) { SetBasicAuth("imroc", "123456"). SetResult(&headers). Get("/header") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, "Basic aW1yb2M6MTIzNDU2", headers.Get("Authorization")) token := "NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4" @@ -232,7 +257,7 @@ func TestAuth(t *testing.T) { SetBearerAuthToken(token). SetResult(&headers). Get("/header") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, "Bearer "+token, headers.Get("Authorization")) } @@ -240,7 +265,7 @@ func TestHeader(t *testing.T) { // Set User-Agent customUserAgent := "My Custom User Agent" resp, err := tr().SetHeader(hdrUserAgentKey, customUserAgent).Get("/user-agent") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, customUserAgent, resp.String()) // Set custom header @@ -252,7 +277,7 @@ func TestHeader(t *testing.T) { "header3": "value3", }).SetResult(&headers). Get("/header") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, "value1", headers.Get("header1")) assertEqual(t, "value2", headers.Get("header2")) assertEqual(t, "value3", headers.Get("header3")) @@ -276,14 +301,14 @@ func TestQueryParam(t *testing.T) { SetQueryParam("key2", "value2"). SetQueryParam("key3", "value3"). Get("/query-parameter") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryString resp, err = c.R(). SetQueryString("key1=value1&key2=value2&key3=value3"). Get("/query-parameter") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryParams @@ -294,7 +319,7 @@ func TestQueryParam(t *testing.T) { "key3": "value3", }). Get("/query-parameter") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryParam & SetQueryParams & SetQueryString @@ -306,7 +331,7 @@ func TestQueryParam(t *testing.T) { }). SetQueryString("key4=value4&key5=value5"). Get("/query-parameter") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=value4&key5=value5", resp.String()) // Set same param to override @@ -321,7 +346,7 @@ func TestQueryParam(t *testing.T) { SetQueryParam("key2", "value22"). SetQueryParam("key4", "value44"). Get("/query-parameter") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, "key1=value11&key2=value22&key3=value3&key4=value44&key5=value5", resp.String()) // Add same param without override @@ -336,7 +361,7 @@ func TestQueryParam(t *testing.T) { AddQueryParam("key2", "value22"). AddQueryParam("key4", "value44"). Get("/query-parameter") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5", resp.String()) } @@ -345,6 +370,6 @@ func TestPathParam(t *testing.T) { resp, err := tr(). SetPathParam("username", username). Get("/user/{username}/profile") - assertSucess(t, resp, err) + assertSuccess(t, resp, err) assertEqual(t, fmt.Sprintf("%s's profile", username), resp.String()) } From faac5cd274ad734bf9ed3e9a362e400ae2282327 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 21:14:02 +0800 Subject: [PATCH 282/843] extract setBodyBytes --- middleware.go | 6 +++--- request.go | 21 +++++++++++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/middleware.go b/middleware.go index bdc0d86d..218dfc61 100644 --- a/middleware.go +++ b/middleware.go @@ -107,7 +107,7 @@ func handleMultiPart(c *Client, r *Request) (err error) { func handleFormData(r *Request) { r.SetContentType(formContentType) - r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(r.FormData.Encode())) + r.setBodyBytes([]byte(r.FormData.Encode())) } func handleMarshalBody(c *Client, r *Request) error { @@ -124,13 +124,13 @@ func handleMarshalBody(c *Client, r *Request) error { if err != nil { return err } - r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) + r.setBodyBytes(body) } else { body, err := c.jsonMarshal(r.marshalBody) if err != nil { return err } - r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) + r.setBodyBytes(body) } return nil } diff --git a/request.go b/request.go index 0a499b91..a4a681eb 100644 --- a/request.go +++ b/request.go @@ -657,10 +657,15 @@ func SetBodyBytes(body []byte) *Request { // SetBodyBytes set the request body as []byte. func (r *Request) SetBodyBytes(body []byte) *Request { - r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) + r.setBodyBytes(body) return r.SetContentType(plainTextContentType) } +func (r *Request) setBodyBytes(body []byte) *Request { + r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) + return r +} + // SetBodyString is a global wrapper methods which delegated // to the default client, create a request and SetBodyString for request. func SetBodyString(body string) *Request { @@ -669,7 +674,7 @@ func SetBodyString(body string) *Request { // SetBodyString set the request body as string. func (r *Request) SetBodyString(body string) *Request { - r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(body)) + r.setBodyBytes([]byte(body)) return r.SetContentType(plainTextContentType) } @@ -682,7 +687,7 @@ func SetBodyJsonString(body string) *Request { // SetBodyJsonString set the request body as string and set Content-Type header // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonString(body string) *Request { - r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(body)) + r.setBodyBytes([]byte(body)) return r.SetContentType(jsonContentType) } @@ -695,7 +700,7 @@ func SetBodyJsonBytes(body []byte) *Request { // SetBodyJsonBytes set the request body as []byte and set Content-Type header // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonBytes(body []byte) *Request { - r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) + r.setBodyBytes(body) return r.SetContentType(jsonContentType) } @@ -713,7 +718,7 @@ func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { r.appendError(err) return r } - r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(b)) + r.setBodyBytes(b) return r.SetContentType(jsonContentType) } @@ -726,7 +731,7 @@ func SetBodyXmlString(body string) *Request { // SetBodyXmlString set the request body as string and set Content-Type header // as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlString(body string) *Request { - r.RawRequest.Body = ioutil.NopCloser(strings.NewReader(body)) + r.setBodyBytes([]byte(body)) return r.SetContentType(xmlContentType) } @@ -739,7 +744,7 @@ func SetBodyXmlBytes(body []byte) *Request { // SetBodyXmlBytes set the request body as []byte and set Content-Type header // as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlBytes(body []byte) *Request { - r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) + r.setBodyBytes(body) return r.SetContentType(xmlContentType) } @@ -757,7 +762,7 @@ func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { r.appendError(err) return r } - r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(b)) + r.setBodyBytes(b) return r.SetContentType(xmlContentType) } From 97310bf12b06698d41f9bffe20685632ab2180cb Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 21:22:10 +0800 Subject: [PATCH 283/843] update README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5bc311a6..248828f2 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,16 @@

Req

-

Simplified Golang HTTP client library with Black Magic, Less Code and More Efficiency.

+

Simplified Golang HTTP client library with Black Magic, Less code and More efficiency.

## News -Brand-New version v3 is out, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) +Brand-New version v3 is released, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) If you want to use the older version, check it out on [v1 branch](https://github.com/imroc/req/tree/v1). -> v2 is a transitional version, due to some breaking changes were introduced during optmize user experience, checkout [v2 branch](https://github.com/imroc/req/tree/v2) if you want. +> v2 is a transitional version, due to some breaking changes were introduced during optmize user experience, checkout v2 branch if you want. ## Table of Contents From 508dde4b2f1d4134e748cef547d47d234b14d0a0 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Feb 2022 21:25:20 +0800 Subject: [PATCH 284/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 248828f2..857184e1 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Brand-New version v3 is released, which is completely rewritten, bringing revolu If you want to use the older version, check it out on [v1 branch](https://github.com/imroc/req/tree/v1). -> v2 is a transitional version, due to some breaking changes were introduced during optmize user experience, checkout v2 branch if you want. +> v2 is a transitional version, due to some breaking changes were introduced during optmize user experience ## Table of Contents From 1d7681413309f232830ae931ba1b881e8621ed0e Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 09:36:24 +0800 Subject: [PATCH 285/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 857184e1..303378d9 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ If you want to use the older version, check it out on [v1 branch](https://github * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. * Powerful and convenient debug utilites, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging - Dump/Log/Trace](#Debugging). * Easy making HTTP test with code instead of tools like curl or postman, `req` provide global wrapper methods and `MustXXX` to test API with minimal code (see [Quick HTTP Test](#Test)). -* Works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support, and you can also force `HTTP/1.1` if you want (see [HTTP2 and HTTP1](#HTTP2-HTTP1)). +* Works fine with both `HTTP/2` and `HTTP/1.1`, which `HTTP/2` is preferred by default if server support, and you can also force `HTTP/1.1` if you want (see [HTTP2 and HTTP1](#HTTP2-HTTP1)). * Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decode](#AutoDecode)). * Automatic marshal and unmarshal for JSON and XML content type and fully customizable (see [Body and Marshal/Unmarshal](#Body)). * Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. From 38200d496f1c12cbca3a29892dd7d47047fa1615 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 11:48:45 +0800 Subject: [PATCH 286/843] support GetTLSClientConfig and EnableInsecureSkipVerify --- client.go | 52 ++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 1fc52a94..d3540807 100644 --- a/client.go +++ b/client.go @@ -225,7 +225,7 @@ func (c *Client) SetCertFromFile(certFile, keyFile string) *Client { c.log.Errorf("failed to load client cert: %v", err) return c } - config := c.tlsConfig() + config := c.GetTLSClientConfig() config.Certificates = append(config.Certificates, cert) return c } @@ -238,13 +238,13 @@ func SetCerts(certs ...tls.Certificate) *Client { // SetCerts set client certificates. func (c *Client) SetCerts(certs ...tls.Certificate) *Client { - config := c.tlsConfig() + config := c.GetTLSClientConfig() config.Certificates = append(config.Certificates, certs...) return c } func (c *Client) appendRootCertData(data []byte) { - config := c.tlsConfig() + config := c.GetTLSClientConfig() if config.RootCAs == nil { config.RootCAs = x509.NewCertPool() } @@ -283,9 +283,18 @@ func (c *Client) SetRootCertsFromFile(pemFiles ...string) *Client { return c } -func (c *Client) tlsConfig() *tls.Config { +// GetTLSClientConfig is a global wrapper methods which delegated +// to the default client's GetTLSClientConfig. +func GetTLSClientConfig() *tls.Config { + return defaultClient.GetTLSClientConfig() +} + +// GetTLSClientConfig return the underlying tls.Config. +func (c *Client) GetTLSClientConfig() *tls.Config { if c.t.TLSClientConfig == nil { - c.t.TLSClientConfig = &tls.Config{} + c.t.TLSClientConfig = &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + } } return c.t.TLSClientConfig } @@ -397,12 +406,43 @@ func SetTLSClientConfig(conf *tls.Config) *Client { return defaultClient.SetTLSClientConfig(conf) } -// SetTLSClientConfig set the TLS client config. +// SetTLSClientConfig set the TLS client config. Be careful! Usually +// you don't need this, you can directly set the tls configuration with +// methods like EnableInsecureSkipVerify, SetCerts etc. Or you can call +// GetTLSClientConfig to get the current tls configuration to avoid +// overwriting some important configurations, such as not setting NextProtos +// will not use http2 by default. func (c *Client) SetTLSClientConfig(conf *tls.Config) *Client { c.t.TLSClientConfig = conf return c } +// EnableInsecureSkipVerify is a global wrapper methods which delegated +// to the default client's EnableInsecureSkipVerify. +func EnableInsecureSkipVerify() *Client { + return defaultClient.EnableInsecureSkipVerify() +} + +// EnableInsecureSkipVerify enable send https without verifing +// the server's certificates (disabled by default). +func (c *Client) EnableInsecureSkipVerify() *Client { + c.GetTLSClientConfig().InsecureSkipVerify = true + return c +} + +// DisableInsecureSkipVerify is a global wrapper methods which delegated +// to the default client's DisableInsecureSkipVerify. +func DisableInsecureSkipVerify() *Client { + return defaultClient.DisableInsecureSkipVerify() +} + +// DisableInsecureSkipVerify disable send https without verifing +// the server's certificates (disabled by default). +func (c *Client) DisableInsecureSkipVerify() *Client { + c.GetTLSClientConfig().InsecureSkipVerify = false + return c +} + // SetCommonQueryParams is a global wrapper methods which delegated // to the default client's SetCommonQueryParams. func SetCommonQueryParams(params map[string]string) *Client { From c03d9bba1dd613df6f8e5825f2ec2c007bb97edf Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 11:49:01 +0800 Subject: [PATCH 287/843] test support http2 --- client_test.go | 12 ++++++------ req_test.go | 36 ++++++++++++++++++++++++++---------- request_test.go | 12 ++++++------ 3 files changed, 38 insertions(+), 22 deletions(-) diff --git a/client_test.go b/client_test.go index 90dfe27e..5023e3ec 100644 --- a/client_test.go +++ b/client_test.go @@ -67,10 +67,10 @@ func TestClientDump(t *testing.T) { resp, err := c.R().SetBody(`test body`).Post("/") assertSuccess(t, resp, err) dump := buf.String() - assertContains(t, dump, "POST / HTTP/1.1", reqHeader) + assertContains(t, dump, "user-agent", reqHeader) assertContains(t, dump, "test body", reqBody) - assertContains(t, dump, "HTTP/1.1 200 OK", respHeader) - assertContains(t, dump, "TestPost: text response", respBody) + assertContains(t, dump, "date", respHeader) + assertContains(t, dump, "testpost: text response", respBody) } c := tc() @@ -85,8 +85,8 @@ func TestClientDump(t *testing.T) { c.SetCommonDumpOptions(opt).EnableDumpAll() resp, err := c.R().SetBody("test body").Post("/") assertSuccess(t, resp, err) - assertContains(t, buf.String(), "POST / HTTP/1.1", true) + assertContains(t, buf.String(), "user-agent", true) assertContains(t, buf.String(), "test body", false) - assertContains(t, buf.String(), "HTTP/1.1 200 OK", false) - assertContains(t, buf.String(), "TestPost: text response", true) + assertContains(t, buf.String(), "date", false) + assertContains(t, buf.String(), "testpost: text response", true) } diff --git a/req_test.go b/req_test.go index fcd99fc0..45a7e2ac 100644 --- a/req_test.go +++ b/req_test.go @@ -19,7 +19,9 @@ func tr() *Request { } func tc() *Client { - return C().SetBaseURL(getTestServerURL()) + return C(). + SetBaseURL(getTestServerURL()). + EnableInsecureSkipVerify() } var testDataPath string @@ -94,15 +96,29 @@ func handleGet(w http.ResponseWriter, r *http.Request) { } } +func handleHTTP(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + handleGet(w, r) + case http.MethodPost: + handlePost(w, r) + } +} + func createTestServer() *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - handleGet(w, r) - case http.MethodPost: - handlePost(w, r) - } - })) + server := httptest.NewUnstartedServer(http.HandlerFunc(handleHTTP)) + // certFile := filepath.Join(testDataPath, "sample-server.pem") + // keyFile := filepath.Join(testDataPath, "sample-server-key.pem") + // cert, err := tls.LoadX509KeyPair(certFile, keyFile) + // if err != nil { + // panic(fmt.Sprintf("failed to load client cert: %v", err)) + // } + // config := &tls.Config{} + // config.Certificates = append(config.Certificates, cert) + // server.TLS = config + server.EnableHTTP2 = true + server.StartTLS() + return server } func assertStatus(t *testing.T, resp *Response, err error, statusCode int, status string) { @@ -119,7 +135,6 @@ func assertSuccess(t *testing.T, resp *Response, err error) { assertNotNil(t, resp.Body) assertEqual(t, http.StatusOK, resp.StatusCode) assertEqual(t, "200 OK", resp.Status) - assertEqual(t, "HTTP/1.1", resp.Proto) } func assertNil(t *testing.T, v interface{}) { @@ -141,6 +156,7 @@ func assertType(t *testing.T, typ, v interface{}) { } func assertContains(t *testing.T, s, substr string, shouldContain bool) { + s = strings.ToLower(s) isContain := strings.Contains(s, substr) if shouldContain { if !isContain { diff --git a/request_test.go b/request_test.go index 7719769e..b1c92c90 100644 --- a/request_test.go +++ b/request_test.go @@ -69,10 +69,10 @@ func TestRequestDump(t *testing.T) { resp, err := r.SetBody(`test body`).Post("/") assertSuccess(t, resp, err) dump := resp.Dump() - assertContains(t, dump, "POST / HTTP/1.1", reqHeader) + assertContains(t, dump, "user-agent", reqHeader) assertContains(t, dump, "test body", reqBody) - assertContains(t, dump, "HTTP/1.1 200 OK", respHeader) - assertContains(t, dump, "TestPost: text response", respBody) + assertContains(t, dump, "date", respHeader) + assertContains(t, dump, "testpost: text response", respBody) } opt := &DumpOptions{ @@ -84,10 +84,10 @@ func TestRequestDump(t *testing.T) { resp, err := tr().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(getTestServerURL()) assertSuccess(t, resp, err) dump := resp.Dump() - assertContains(t, dump, "POST / HTTP/1.1", true) + assertContains(t, dump, "user-agent", true) assertContains(t, dump, "test body", false) - assertContains(t, dump, "HTTP/1.1 200 OK", false) - assertContains(t, dump, "TestPost: text response", true) + assertContains(t, dump, "date", false) + assertContains(t, dump, "testpost: text response", true) } func TestGet(t *testing.T) { From db24e861fdfddaef205dbf99cce1b891ef0cb2c7 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 12:52:31 +0800 Subject: [PATCH 288/843] tests both http2 and http1 for all --- client_test.go | 13 ++++++-- request_test.go | 82 +++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 76 insertions(+), 19 deletions(-) diff --git a/client_test.go b/client_test.go index 5023e3ec..2978f942 100644 --- a/client_test.go +++ b/client_test.go @@ -6,6 +6,15 @@ import ( ) func TestClientDump(t *testing.T) { + testClientDump(t, func() *Client { + return tc() + }) + testClientDump(t, func() *Client { + return tc().EnableForceHTTP1() + }) +} + +func testClientDump(t *testing.T, newClient func() *Client) { testCases := []func(r *Client, reqHeader, reqBody, respHeader, respBody *bool){ func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { r.EnableDumpAll() @@ -59,7 +68,7 @@ func TestClientDump(t *testing.T) { } for _, fn := range testCases { - c := tc() + c := newClient() buf := new(bytes.Buffer) c.EnableDumpAllTo(buf) var reqHeader, reqBody, respHeader, respBody bool @@ -73,7 +82,7 @@ func TestClientDump(t *testing.T) { assertContains(t, dump, "testpost: text response", respBody) } - c := tc() + c := newClient() buf := new(bytes.Buffer) opt := &DumpOptions{ RequestHeader: true, diff --git a/request_test.go b/request_test.go index b1c92c90..460d9537 100644 --- a/request_test.go +++ b/request_test.go @@ -10,6 +10,11 @@ import ( ) func TestRequestDump(t *testing.T) { + testRequestDump(t, tc()) + testRequestDump(t, tc().EnableForceHTTP1()) +} + +func testRequestDump(t *testing.T, c *Client) { testCases := []func(r *Request, reqHeader, reqBody, respHeader, respBody *bool){ func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { r.EnableDump() @@ -63,7 +68,7 @@ func TestRequestDump(t *testing.T) { } for _, fn := range testCases { - r := tr() + r := c.R() var reqHeader, reqBody, respHeader, respBody bool fn(r, &reqHeader, &reqBody, &respHeader, &respBody) resp, err := r.SetBody(`test body`).Post("/") @@ -81,7 +86,7 @@ func TestRequestDump(t *testing.T) { ResponseHeader: false, ResponseBody: true, } - resp, err := tr().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(getTestServerURL()) + resp, err := c.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(getTestServerURL()) assertSuccess(t, resp, err) dump := resp.Dump() assertContains(t, dump, "user-agent", true) @@ -91,17 +96,32 @@ func TestRequestDump(t *testing.T) { } func TestGet(t *testing.T) { - resp, err := tr().Get("/") + testGet(t, tc()) + testGet(t, tc().EnableForceHTTP1()) +} + +func testGet(t *testing.T, c *Client) { + resp, err := c.R().Get("/") assertSuccess(t, resp, err) assertEqual(t, "TestGet: text response", resp.String()) } func TestBadRequest(t *testing.T) { - resp, err := tr().Get("/bad-request") + testBadRequest(t, tc()) + testBadRequest(t, tc().EnableForceHTTP1()) +} + +func testBadRequest(t *testing.T, c *Client) { + resp, err := c.R().Get("/bad-request") assertStatus(t, resp, err, http.StatusBadRequest, "400 Bad Request") } func TestSetBodyMarshal(t *testing.T) { + testSetBodyMarshal(t, tc()) + testSetBodyMarshal(t, tc().EnableForceHTTP1()) +} + +func testSetBodyMarshal(t *testing.T, c *Client) { type User struct { Username string `json:"username" xml:"username"` } @@ -178,17 +198,22 @@ func TestSetBodyMarshal(t *testing.T) { }, } - for _, c := range testCases { - r := tr() - c.Set(r) + for _, cs := range testCases { + r := c.R() + cs.Set(r) var e echo resp, err := r.SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - c.Assert(&e) + cs.Assert(&e) } } func TestSetBodyContent(t *testing.T) { + testSetBodyContent(t, tc()) + testSetBodyContent(t, tc().EnableForceHTTP1()) +} + +func testSetBodyContent(t *testing.T, c *Client) { var e echo testBody := "test body" @@ -208,7 +233,7 @@ func TestSetBodyContent(t *testing.T) { } for _, fn := range testCases { - r := tr() + r := c.R() fn(r) var e echo resp, err := r.SetResult(&e).Post("/echo") @@ -220,15 +245,20 @@ func TestSetBodyContent(t *testing.T) { // Set Reader testBodyReader := strings.NewReader(testBody) e = echo{} - resp, err := tr().SetBody(testBodyReader).SetResult(&e).Post("/echo") + resp, err := c.R().SetBody(testBodyReader).SetResult(&e).Post("/echo") assertSuccess(t, resp, err) assertEqual(t, testBody, e.Body) assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) } func TestCookie(t *testing.T) { + testCookie(t, tc()) + testCookie(t, tc().EnableForceHTTP1()) +} + +func testCookie(t *testing.T, c *Client) { headers := make(http.Header) - resp, err := tr().SetCookies( + resp, err := c.R().SetCookies( &http.Cookie{ Name: "cookie1", Value: "value1", @@ -243,8 +273,13 @@ func TestCookie(t *testing.T) { } func TestAuth(t *testing.T) { + testAuth(t, tc()) + testAuth(t, tc().EnableForceHTTP1()) +} + +func testAuth(t *testing.T, c *Client) { headers := make(http.Header) - resp, err := tr(). + resp, err := c.R(). SetBasicAuth("imroc", "123456"). SetResult(&headers). Get("/header") @@ -253,7 +288,7 @@ func TestAuth(t *testing.T) { token := "NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4" headers = make(http.Header) - resp, err = tr(). + resp, err = c.R(). SetBearerAuthToken(token). SetResult(&headers). Get("/header") @@ -262,15 +297,20 @@ func TestAuth(t *testing.T) { } func TestHeader(t *testing.T) { + testHeader(t, tc()) + testHeader(t, tc().EnableForceHTTP1()) +} + +func testHeader(t *testing.T, c *Client) { // Set User-Agent customUserAgent := "My Custom User Agent" - resp, err := tr().SetHeader(hdrUserAgentKey, customUserAgent).Get("/user-agent") + resp, err := c.R().SetHeader(hdrUserAgentKey, customUserAgent).Get("/user-agent") assertSuccess(t, resp, err) assertEqual(t, customUserAgent, resp.String()) // Set custom header headers := make(http.Header) - resp, err = tr(). + resp, err = c.R(). SetHeader("header1", "value1"). SetHeaders(map[string]string{ "header2": "value2", @@ -284,8 +324,11 @@ func TestHeader(t *testing.T) { } func TestQueryParam(t *testing.T) { - c := tc() + testQueryParam(t, tc()) + testQueryParam(t, tc().EnableForceHTTP1()) +} +func testQueryParam(t *testing.T, c *Client) { // Set query param at client level, should be overwritten at request level c.SetCommonQueryParam("key1", "client"). SetCommonQueryParams(map[string]string{ @@ -366,8 +409,13 @@ func TestQueryParam(t *testing.T) { } func TestPathParam(t *testing.T) { + testPathParam(t, tc()) + testPathParam(t, tc().EnableForceHTTP1()) +} + +func testPathParam(t *testing.T, c *Client) { username := "imroc" - resp, err := tr(). + resp, err := c.R(). SetPathParam("username", username). Get("/user/{username}/profile") assertSuccess(t, resp, err) From 6def6f18281f22ba49fcdc0075ab0060884fe124 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 14:08:51 +0800 Subject: [PATCH 289/843] ajust tests --- req_test.go | 45 ++++++++++++++++----------------------------- 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/req_test.go b/req_test.go index 45a7e2ac..d00cea06 100644 --- a/req_test.go +++ b/req_test.go @@ -14,10 +14,6 @@ import ( "testing" ) -func tr() *Request { - return tc().R() -} - func tc() *Client { return C(). SetBaseURL(getTestServerURL()). @@ -31,6 +27,22 @@ func init() { testDataPath = filepath.Join(pwd, ".testdata") } +func createTestServer() *httptest.Server { + server := httptest.NewUnstartedServer(http.HandlerFunc(handleHTTP)) + server.EnableHTTP2 = true + server.StartTLS() + return server +} + +func handleHTTP(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + handleGet(w, r) + case http.MethodPost: + handlePost(w, r) + } +} + var testServerMu sync.Mutex var testServer *httptest.Server @@ -96,31 +108,6 @@ func handleGet(w http.ResponseWriter, r *http.Request) { } } -func handleHTTP(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - handleGet(w, r) - case http.MethodPost: - handlePost(w, r) - } -} - -func createTestServer() *httptest.Server { - server := httptest.NewUnstartedServer(http.HandlerFunc(handleHTTP)) - // certFile := filepath.Join(testDataPath, "sample-server.pem") - // keyFile := filepath.Join(testDataPath, "sample-server-key.pem") - // cert, err := tls.LoadX509KeyPair(certFile, keyFile) - // if err != nil { - // panic(fmt.Sprintf("failed to load client cert: %v", err)) - // } - // config := &tls.Config{} - // config.Certificates = append(config.Certificates, cert) - // server.TLS = config - server.EnableHTTP2 = true - server.StartTLS() - return server -} - func assertStatus(t *testing.T, resp *Response, err error, statusCode int, status string) { assertError(t, err) assertNotNil(t, resp) From aada1f4c9416e9864111c131decdca209fd35302 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 14:47:13 +0800 Subject: [PATCH 290/843] Add TestSuccess and TestError --- req_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++++++++ request_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) diff --git a/req_test.go b/req_test.go index d00cea06..aa1b7809 100644 --- a/req_test.go +++ b/req_test.go @@ -2,6 +2,7 @@ package req import ( "encoding/json" + "encoding/xml" "fmt" "io/ioutil" "net/http" @@ -65,6 +66,8 @@ func handlePost(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": w.Write([]byte("TestPost: text response")) + case "/search": + handleSearch(w, r) case "/echo": b, _ := ioutil.ReadAll(r.Body) e := echo{ @@ -83,6 +86,53 @@ func handleGetUserProfile(w http.ResponseWriter, r *http.Request) { w.Write([]byte(fmt.Sprintf("%s's profile", user))) } +type UserInfo struct { + Username string `json:"username" xml:"username"` + Email string `json:"email" xml:"email"` +} + +type ErrorMessage struct { + ErrorCode int `json:"error_code" xml:"ErrorCode"` + ErrorMessage string `json:"error_message" xml:"ErrorMessage"` +} + +func handleSearch(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + username := r.FormValue("username") + tp := r.FormValue("type") + var marshalFunc func(v interface{}) ([]byte, error) + if tp == "xml" { + w.Header().Set(hdrContentTypeKey, xmlContentType) + marshalFunc = xml.Marshal + } else { + w.Header().Set(hdrContentTypeKey, jsonContentType) + marshalFunc = json.Marshal + } + var result interface{} + switch username { + case "": + w.WriteHeader(http.StatusBadRequest) + result = &ErrorMessage{ + ErrorCode: 10000, + ErrorMessage: "need username", + } + case "imroc": + w.WriteHeader(http.StatusOK) + result = &UserInfo{ + Username: "imroc", + Email: "roc@imroc.cc", + } + default: + w.WriteHeader(http.StatusNotFound) + result = &ErrorMessage{ + ErrorCode: 10001, + ErrorMessage: "username not exists", + } + } + data, _ := marshalFunc(result) + w.Write(data) +} + func handleGet(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": @@ -101,6 +151,8 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.Header.Get(hdrContentTypeKey))) case "/query-parameter": w.Write([]byte(r.URL.RawQuery)) + case "/search": + handleSearch(w, r) default: if strings.HasPrefix(r.URL.Path, "/user") { handleGetUserProfile(w, r) @@ -122,6 +174,18 @@ func assertSuccess(t *testing.T, resp *Response, err error) { assertNotNil(t, resp.Body) assertEqual(t, http.StatusOK, resp.StatusCode) assertEqual(t, "200 OK", resp.Status) + if !resp.IsSuccess() { + t.Error("Response.IsSuccess should return true") + } +} + +func assertIsError(t *testing.T, resp *Response, err error) { + assertError(t, err) + assertNotNil(t, resp) + assertNotNil(t, resp.Body) + if !resp.IsError() { + t.Error("Response.IsError should return true") + } } func assertNil(t *testing.T, v interface{}) { diff --git a/request_test.go b/request_test.go index 460d9537..fa68606c 100644 --- a/request_test.go +++ b/request_test.go @@ -421,3 +421,59 @@ func testPathParam(t *testing.T, c *Client) { assertSuccess(t, resp, err) assertEqual(t, fmt.Sprintf("%s's profile", username), resp.String()) } + +func TestSuccess(t *testing.T) { + testSuccess(t, tc()) + testSuccess(t, tc().EnableForceHTTP1()) +} + +func testSuccess(t *testing.T, c *Client) { + var userInfo UserInfo + resp, err := c.R(). + SetQueryParam("username", "imroc"). + SetResult(&userInfo). + Get("/search") + assertSuccess(t, resp, err) + assertEqual(t, "roc@imroc.cc", userInfo.Email) + + userInfo = UserInfo{} + resp, err = c.R(). + SetQueryParam("username", "imroc"). + SetQueryParam("type", "xml"). // auto unmarshal to xml + SetResult(&userInfo).EnableDump(). + Get("/search") + assertSuccess(t, resp, err) + assertEqual(t, "roc@imroc.cc", userInfo.Email) +} + +func TestError(t *testing.T) { + testError(t, tc()) + testError(t, tc().EnableForceHTTP1()) +} + +func testError(t *testing.T, c *Client) { + var errMsg ErrorMessage + resp, err := c.R(). + SetQueryParam("username", ""). + SetError(&errMsg). + Get("/search") + assertIsError(t, resp, err) + assertEqual(t, 10000, errMsg.ErrorCode) + + errMsg = ErrorMessage{} + resp, err = c.R(). + SetQueryParam("username", "test"). + SetError(&errMsg). + Get("/search") + assertIsError(t, resp, err) + assertEqual(t, 10001, errMsg.ErrorCode) + + errMsg = ErrorMessage{} + resp, err = c.R(). + SetQueryParam("username", "test"). + SetQueryParam("type", "xml"). // auto unmarshal to xml + SetError(&errMsg). + Get("/search") + assertIsError(t, resp, err) + assertEqual(t, 10001, errMsg.ErrorCode) +} From 303efc5cdf46f32b8250bd95676fcf8737a2da96 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 14:51:31 +0800 Subject: [PATCH 291/843] Update TODO List --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 303378d9..317c62a8 100644 --- a/README.md +++ b/README.md @@ -866,11 +866,12 @@ client.SetProxy(nil) ## TODO List -* [ ] Add tests. +* [ ] Add more tests. * [ ] Wrap more transport settings into client. * [ ] Support retry. * [ ] Support unix socket. * [ ] Support h2c. +* [ ] Support HTTP3. ## License From 3c82e32813ec54988fb2972105fc90b2e256a147 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 14:52:16 +0800 Subject: [PATCH 292/843] Update ToC --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 317c62a8..137897c1 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,6 @@ If you want to use the older version, check it out on [v1 branch](https://github * [Redirect Policy](#Redirect) * [Proxy](#Proxy) * [TODO List](#TODO) -* [API Reference](#API) * [License](#License) ## Features From af83dddecd2f1f1c3e26cb1f3f6bf7af1d780a64 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 15:08:48 +0800 Subject: [PATCH 293/843] Add TestForm --- README.md | 2 +- request_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 137897c1..ac88085e 100644 --- a/README.md +++ b/README.md @@ -856,7 +856,7 @@ client.SetProxyURL("http://myproxy:8080") // Custmize the proxy function with your own implementation client.SetProxy(func(request *http.Request) (*url.URL, error) { - //... + // ... }) // Disable proxy diff --git a/request_test.go b/request_test.go index fa68606c..8483ffda 100644 --- a/request_test.go +++ b/request_test.go @@ -5,6 +5,7 @@ import ( "encoding/xml" "fmt" "net/http" + "net/url" "strings" "testing" ) @@ -477,3 +478,31 @@ func testError(t *testing.T, c *Client) { assertIsError(t, resp, err) assertEqual(t, 10001, errMsg.ErrorCode) } + +func TestForm(t *testing.T) { + testForm(t, tc()) + testForm(t, tc().EnableForceHTTP1()) +} + +func testForm(t *testing.T, c *Client) { + var userInfo UserInfo + resp, err := c.R(). + SetFormData(map[string]string{ + "username": "imroc", + "type": "xml", + }). + SetResult(&userInfo). + Post("/search") + assertSuccess(t, resp, err) + assertEqual(t, "roc@imroc.cc", userInfo.Email) + + v := make(url.Values) + v.Add("username", "imroc") + v.Add("type", "xml") + resp, err = c.R(). + SetFormDataFromValues(v). + SetResult(&userInfo). + Post("/search") + assertSuccess(t, resp, err) + assertEqual(t, "roc@imroc.cc", userInfo.Email) +} From 4215f03bff4132ceaa179da8ff8dc0b839317eb1 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 15:57:45 +0800 Subject: [PATCH 294/843] Fix Host header not override Host --- client.go | 14 +++++++++----- request.go | 7 +++++++ request_test.go | 11 +++++++++++ 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index d3540807..2c0b98a3 100644 --- a/client.go +++ b/client.go @@ -1297,9 +1297,9 @@ func C() *Client { Timeout: 2 * time.Minute, } beforeRequest := []RequestMiddleware{ + parseRequestHeader, parseRequestURL, parseRequestBody, - parseRequestHeader, parseRequestCookie, } afterResponse := []ResponseMiddleware{ @@ -1329,7 +1329,7 @@ func C() *Client { } func setupRequest(r *Request) { - setRequestURL(r.RawRequest, r.URL) + setRequestURL(r, r.URL) setRequestHeaderAndCookie(r) setTrace(r) setContext(r) @@ -1405,14 +1405,18 @@ func setRequestHeaderAndCookie(r *Request) { } } -func setRequestURL(r *http.Request, url string) error { +func setRequestURL(r *Request, url string) error { // The host's colon:port should be normalized. See Issue 14836. u, err := urlpkg.Parse(url) if err != nil { return err } u.Host = removeEmptyPort(u.Host) - r.URL = u - r.Host = u.Host + if host := r.getHeader("Host"); host != "" { + r.RawRequest.Host = host // Host header override + } else { + r.RawRequest.Host = u.Host + } + r.RawRequest.URL = u return nil } diff --git a/request.go b/request.go index a4a681eb..76006f3a 100644 --- a/request.go +++ b/request.go @@ -45,6 +45,13 @@ type Request struct { dumpBuffer *bytes.Buffer } +func (r *Request) getHeader(key string) string { + if r.Headers == nil { + return "" + } + return r.Headers.Get(key) +} + // TraceInfo returns the trace information, only available if trace is enabled // (see Request.EnableTrace and Client.EnableTraceAll). func (r *Request) TraceInfo() TraceInfo { diff --git a/request_test.go b/request_test.go index 8483ffda..4313a370 100644 --- a/request_test.go +++ b/request_test.go @@ -506,3 +506,14 @@ func testForm(t *testing.T, c *Client) { assertSuccess(t, resp, err) assertEqual(t, "roc@imroc.cc", userInfo.Email) } + +func TestHostHeaderOverride(t *testing.T) { + testHostHeaderOverride(t, tc()) + testHostHeaderOverride(t, tc().EnableForceHTTP1()) +} + +func testHostHeaderOverride(t *testing.T, c *Client) { + resp, err := c.R().SetHeader("Host", "testhostname").Get("/host-header") + assertSuccess(t, resp, err) + assertEqual(t, "testhostname", resp.String()) +} From 069267345d4d667a5f03e907857e028e1ae763cb Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 16:42:49 +0800 Subject: [PATCH 295/843] Add TestTraceInfo --- request_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/request_test.go b/request_test.go index 4313a370..9068e447 100644 --- a/request_test.go +++ b/request_test.go @@ -517,3 +517,37 @@ func testHostHeaderOverride(t *testing.T, c *Client) { assertSuccess(t, resp, err) assertEqual(t, "testhostname", resp.String()) } + +func TestTraceInfo(t *testing.T) { + testTraceInfo(t, tc()) + testTraceInfo(t, tc().EnableForceHTTP1()) +} + +func testTraceInfo(t *testing.T, c *Client) { + // enable trace at client level + c.EnableTraceAll() + resp, err := c.R().Get("/") + assertSuccess(t, resp, err) + ti := resp.TraceInfo() + assertEqual(t, true, ti.TotalTime > 0) + assertEqual(t, true, ti.TCPConnectTime > 0) + assertEqual(t, true, ti.ConnectTime > 0) + assertEqual(t, true, ti.FirstResponseTime > 0) + assertEqual(t, true, ti.ResponseTime > 0) + assertNotNil(t, ti.RemoteAddr) + + // disable trace at client level + c.DisableTraceAll() + resp, err = c.R().Get("/") + assertSuccess(t, resp, err) + ti = resp.TraceInfo() + assertEqual(t, false, ti.TotalTime > 0) + assertNil(t, ti.RemoteAddr) + + // enable trace at request level + resp, err = c.R().EnableTrace().Get("/") + assertSuccess(t, resp, err) + ti = resp.TraceInfo() + assertEqual(t, true, ti.TotalTime > 0) + assertNotNil(t, ti.RemoteAddr) +} From 0aff7d0ccfaef3c5dc16c20444dfc66396c9f7dc Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 19:36:03 +0800 Subject: [PATCH 296/843] optimize tracing --- client.go | 4 +++- request.go | 63 +++++++++++++++++++++++++++++++------------------ request_test.go | 25 ++++++++++++++++++++ 3 files changed, 68 insertions(+), 24 deletions(-) diff --git a/client.go b/client.go index 2c0b98a3..80fa2ff0 100644 --- a/client.go +++ b/client.go @@ -1336,7 +1336,9 @@ func setupRequest(r *Request) { } func (c *Client) do(r *Request) (resp *Response, err error) { - resp = &Response{} + resp = &Response{ + Request: r, + } for _, f := range r.client.udBeforeRequest { if err = f(r.client, r); err != nil { diff --git a/request.go b/request.go index 76006f3a..bef9a9b2 100644 --- a/request.go +++ b/request.go @@ -32,17 +32,18 @@ type Request struct { RawRequest *http.Request StartTime time.Time - dumpOptions *DumpOptions - marshalBody interface{} - ctx context.Context - isMultiPart bool - uploadFiles []*FileUpload - uploadReader []io.ReadCloser - outputFile string - isSaveResponse bool - output io.Writer - trace *clientTrace - dumpBuffer *bytes.Buffer + dumpOptions *DumpOptions + marshalBody interface{} + ctx context.Context + isMultiPart bool + uploadFiles []*FileUpload + uploadReader []io.ReadCloser + outputFile string + isSaveResponse bool + output io.Writer + trace *clientTrace + dumpBuffer *bytes.Buffer + responseReturnTime time.Time } func (r *Request) getHeader(key string) string { @@ -62,20 +63,32 @@ func (r *Request) TraceInfo() TraceInfo { } ti := TraceInfo{ - DNSLookupTime: ct.dnsDone.Sub(ct.dnsStart), - TLSHandshakeTime: ct.tlsHandshakeDone.Sub(ct.tlsHandshakeStart), - FirstResponseTime: ct.gotFirstResponseByte.Sub(ct.gotConn), - IsConnReused: ct.gotConnInfo.Reused, - IsConnWasIdle: ct.gotConnInfo.WasIdle, - ConnIdleTime: ct.gotConnInfo.IdleTime, + IsConnReused: ct.gotConnInfo.Reused, + IsConnWasIdle: ct.gotConnInfo.WasIdle, + ConnIdleTime: ct.gotConnInfo.IdleTime, + } + + endTime := ct.endTime + if endTime.IsZero() { // in case timeout + endTime = r.responseReturnTime + } + + if !ct.tlsHandshakeStart.IsZero() { + if !ct.tlsHandshakeDone.IsZero() { + ti.TLSHandshakeTime = ct.tlsHandshakeDone.Sub(ct.tlsHandshakeStart) + } else { + ti.TLSHandshakeTime = endTime.Sub(ct.tlsHandshakeStart) + } } - // Calculate the total time accordingly, - // when connection is reused if ct.gotConnInfo.Reused { - ti.TotalTime = ct.endTime.Sub(ct.getConn) + ti.TotalTime = endTime.Sub(ct.getConn) } else { - ti.TotalTime = ct.endTime.Sub(ct.dnsStart) + if ct.dnsStart.IsZero() { + ti.TotalTime = endTime.Sub(r.StartTime) + } else { + ti.TotalTime = endTime.Sub(ct.dnsStart) + } } // Only calculate on successful connections @@ -90,7 +103,8 @@ func (r *Request) TraceInfo() TraceInfo { // Only calculate on successful connections if !ct.gotFirstResponseByte.IsZero() { - ti.ResponseTime = ct.endTime.Sub(ct.gotFirstResponseByte) + ti.FirstResponseTime = ct.gotFirstResponseByte.Sub(ct.gotConn) + ti.ResponseTime = endTime.Sub(ct.gotFirstResponseByte) } // Capture remote address info when connection is non-nil @@ -433,8 +447,11 @@ func (r *Request) appendError(err error) { // Send fires http request and return the *Response which is always // not nil, and the error is not nil if some error happens. func (r *Request) Send(method, url string) (*Response, error) { + defer func() { + r.responseReturnTime = time.Now() + }() if r.error != nil { - return &Response{}, r.error + return &Response{Request: r}, r.error } r.RawRequest.Method = method r.URL = url diff --git a/request_test.go b/request_test.go index 9068e447..6dd3312a 100644 --- a/request_test.go +++ b/request_test.go @@ -8,6 +8,7 @@ import ( "net/url" "strings" "testing" + "time" ) func TestRequestDump(t *testing.T) { @@ -531,6 +532,7 @@ func testTraceInfo(t *testing.T, c *Client) { ti := resp.TraceInfo() assertEqual(t, true, ti.TotalTime > 0) assertEqual(t, true, ti.TCPConnectTime > 0) + assertEqual(t, true, ti.TLSHandshakeTime > 0) assertEqual(t, true, ti.ConnectTime > 0) assertEqual(t, true, ti.FirstResponseTime > 0) assertEqual(t, true, ti.ResponseTime > 0) @@ -551,3 +553,26 @@ func testTraceInfo(t *testing.T, c *Client) { assertEqual(t, true, ti.TotalTime > 0) assertNotNil(t, ti.RemoteAddr) } + +func TestTraceOnTimeout(t *testing.T) { + testTraceOnTimeout(t, C()) + testTraceOnTimeout(t, C().EnableForceHTTP1()) +} + +func testTraceOnTimeout(t *testing.T, c *Client) { + c.EnableTraceAll().SetTimeout(100 * time.Millisecond) + + resp, err := c.R().Get("http://req-nowhere.local") + assertNotNil(t, err) + assertNotNil(t, resp) + + tr := resp.TraceInfo() + assertEqual(t, true, tr.DNSLookupTime >= 0) + assertEqual(t, true, tr.ConnectTime == 0) + assertEqual(t, true, tr.TLSHandshakeTime == 0) + assertEqual(t, true, tr.TCPConnectTime == 0) + assertEqual(t, true, tr.FirstResponseTime == 0) + assertEqual(t, true, tr.ResponseTime == 0) + assertEqual(t, true, tr.TotalTime > 0) + assertEqual(t, true, tr.TotalTime == resp.TotalTime()) +} From b760a827f6729f50999c4819a53c72f0da449d93 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 20:13:25 +0800 Subject: [PATCH 297/843] Add TestRedirect --- req.go | 1 + req_test.go | 3 +++ request_test.go | 15 +++++++++++++++ 3 files changed, 19 insertions(+) diff --git a/req.go b/req.go index 84e7a908..269d6698 100644 --- a/req.go +++ b/req.go @@ -8,6 +8,7 @@ import ( const ( hdrUserAgentKey = "User-Agent" hdrUserAgentValue = "req/v3 (https://github.com/imroc/req)" + hdrLocationKey = "Location" hdrContentTypeKey = "Content-Type" plainTextContentType = "text/plain; charset=utf-8" jsonContentType = "application/json; charset=utf-8" diff --git a/req_test.go b/req_test.go index aa1b7809..8400edd5 100644 --- a/req_test.go +++ b/req_test.go @@ -68,6 +68,9 @@ func handlePost(w http.ResponseWriter, r *http.Request) { w.Write([]byte("TestPost: text response")) case "/search": handleSearch(w, r) + case "/redirect": + w.Header().Set(hdrLocationKey, "/") + w.WriteHeader(http.StatusMovedPermanently) case "/echo": b, _ := ioutil.ReadAll(r.Body) e := echo{ diff --git a/request_test.go b/request_test.go index 6dd3312a..291e7aa3 100644 --- a/request_test.go +++ b/request_test.go @@ -576,3 +576,18 @@ func testTraceOnTimeout(t *testing.T, c *Client) { assertEqual(t, true, tr.TotalTime > 0) assertEqual(t, true, tr.TotalTime == resp.TotalTime()) } + +func TestRedirect(t *testing.T) { + testRedirect(t, tc()) + testRedirect(t, tc().EnableForceHTTP1()) +} + +func testRedirect(t *testing.T, c *Client) { + resp, err := c.R().SetBody("test").Post("/redirect") + assertSuccess(t, resp, err) + + c.SetRedirectPolicy(NoRedirectPolicy()) + resp, err = c.R().SetBody("test").Post("/redirect") + assertNotNil(t, err) + assertContains(t, err.Error(), "redirect is disabled", true) +} From 7be69654647352413e216e77ac968d587bd7e1b8 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 20:31:52 +0800 Subject: [PATCH 298/843] Add TestDisableAutoReadResponse --- client_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/client_test.go b/client_test.go index 2978f942..5b6fae92 100644 --- a/client_test.go +++ b/client_test.go @@ -2,9 +2,30 @@ package req import ( "bytes" + "io/ioutil" "testing" ) +func TestDisableAutoReadResponse(t *testing.T) { + testDisableAutoReadResponse(t, tc()) + testDisableAutoReadResponse(t, tc().EnableForceHTTP1()) +} + +func testDisableAutoReadResponse(t *testing.T, c *Client) { + c.DisableAutoReadResponse() + resp, err := c.R().Get("/") + assertSuccess(t, resp, err) + assertEqual(t, "", resp.String()) + result, err := resp.ToString() + assertError(t, err) + assertEqual(t, "TestGet: text response", result) + + resp, err = c.R().Get("/") + assertSuccess(t, resp, err) + _, err = ioutil.ReadAll(resp.Body) + assertError(t, err) +} + func TestClientDump(t *testing.T) { testClientDump(t, func() *Client { return tc() From ba6575baf45e9961fa26ffb6991a492101c0f81a Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 21:12:23 +0800 Subject: [PATCH 299/843] Optimize upload raw bytes 1. Auto detect the content type of request body if []byte is set and no Content-Type specified. 2. Add TestAutoDetectRequestContentType --- middleware.go | 13 ++++++++++--- req_test.go | 9 +++++++++ request.go | 21 +++++++++------------ request_test.go | 6 ++++++ 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/middleware.go b/middleware.go index 218dfc61..e09adb5c 100644 --- a/middleware.go +++ b/middleware.go @@ -107,7 +107,7 @@ func handleMultiPart(c *Client, r *Request) (err error) { func handleFormData(r *Request) { r.SetContentType(formContentType) - r.setBodyBytes([]byte(r.FormData.Encode())) + r.SetBodyBytes([]byte(r.FormData.Encode())) } func handleMarshalBody(c *Client, r *Request) error { @@ -124,13 +124,13 @@ func handleMarshalBody(c *Client, r *Request) error { if err != nil { return err } - r.setBodyBytes(body) + r.SetBodyBytes(body) } else { body, err := c.jsonMarshal(r.marshalBody) if err != nil { return err } - r.setBodyBytes(body) + r.SetBodyBytes(body) } return nil } @@ -165,6 +165,13 @@ func parseRequestBody(c *Client, r *Request) (err error) { if r.marshalBody != nil { handleMarshalBody(c, r) } + + if r.getHeader(hdrContentTypeKey) == "" && r.body != nil { + ct := http.DetectContentType(r.body) + if ct != "application/octet-stream" { + r.SetContentType(ct) + } + } return } diff --git a/req_test.go b/req_test.go index 8400edd5..36c45969 100644 --- a/req_test.go +++ b/req_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "encoding/xml" "fmt" + "io" "io/ioutil" "net/http" "net/http/httptest" @@ -57,6 +58,12 @@ func getTestServerURL() string { return testServer.URL } +func getTestFileContent(t *testing.T, filename string) []byte { + b, err := ioutil.ReadFile(filepath.Join(testDataPath, filename)) + assertError(t, err) + return b +} + type echo struct { Header http.Header `json:"header" xml:"header"` Body string `json:"body" xml:"body"` @@ -66,6 +73,8 @@ func handlePost(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": w.Write([]byte("TestPost: text response")) + case "/raw-upload": + io.Copy(ioutil.Discard, r.Body) case "/search": handleSearch(w, r) case "/redirect": diff --git a/request.go b/request.go index bef9a9b2..5630c4c6 100644 --- a/request.go +++ b/request.go @@ -32,6 +32,7 @@ type Request struct { RawRequest *http.Request StartTime time.Time + body []byte dumpOptions *DumpOptions marshalBody interface{} ctx context.Context @@ -681,12 +682,8 @@ func SetBodyBytes(body []byte) *Request { // SetBodyBytes set the request body as []byte. func (r *Request) SetBodyBytes(body []byte) *Request { - r.setBodyBytes(body) - return r.SetContentType(plainTextContentType) -} - -func (r *Request) setBodyBytes(body []byte) *Request { r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) + r.body = body return r } @@ -698,7 +695,7 @@ func SetBodyString(body string) *Request { // SetBodyString set the request body as string. func (r *Request) SetBodyString(body string) *Request { - r.setBodyBytes([]byte(body)) + r.SetBodyBytes([]byte(body)) return r.SetContentType(plainTextContentType) } @@ -711,7 +708,7 @@ func SetBodyJsonString(body string) *Request { // SetBodyJsonString set the request body as string and set Content-Type header // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonString(body string) *Request { - r.setBodyBytes([]byte(body)) + r.SetBodyBytes([]byte(body)) return r.SetContentType(jsonContentType) } @@ -724,7 +721,7 @@ func SetBodyJsonBytes(body []byte) *Request { // SetBodyJsonBytes set the request body as []byte and set Content-Type header // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonBytes(body []byte) *Request { - r.setBodyBytes(body) + r.SetBodyBytes(body) return r.SetContentType(jsonContentType) } @@ -742,7 +739,7 @@ func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { r.appendError(err) return r } - r.setBodyBytes(b) + r.SetBodyBytes(b) return r.SetContentType(jsonContentType) } @@ -755,7 +752,7 @@ func SetBodyXmlString(body string) *Request { // SetBodyXmlString set the request body as string and set Content-Type header // as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlString(body string) *Request { - r.setBodyBytes([]byte(body)) + r.SetBodyBytes([]byte(body)) return r.SetContentType(xmlContentType) } @@ -768,7 +765,7 @@ func SetBodyXmlBytes(body []byte) *Request { // SetBodyXmlBytes set the request body as []byte and set Content-Type header // as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlBytes(body []byte) *Request { - r.setBodyBytes(body) + r.SetBodyBytes(body) return r.SetContentType(xmlContentType) } @@ -786,7 +783,7 @@ func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { r.appendError(err) return r } - r.setBodyBytes(b) + r.SetBodyBytes(b) return r.SetContentType(xmlContentType) } diff --git a/request_test.go b/request_test.go index 291e7aa3..65fe200d 100644 --- a/request_test.go +++ b/request_test.go @@ -591,3 +591,9 @@ func testRedirect(t *testing.T, c *Client) { assertNotNil(t, err) assertContains(t, err.Error(), "redirect is disabled", true) } + +func TestAutoDetectRequestContentType(t *testing.T) { + resp, err := tc().R().SetBody(getTestFileContent(t, "sample-image.png")).Post("/raw-upload") + assertSuccess(t, resp, err) + assertEqual(t, "image/png", resp.Request.Headers.Get(hdrContentTypeKey)) +} From a99ddfce138c6f0095a024fee103396d4eda053f Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 21:17:48 +0800 Subject: [PATCH 300/843] Update TODO List --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index ac88085e..8e8c694d 100644 --- a/README.md +++ b/README.md @@ -870,6 +870,8 @@ client.SetProxy(nil) * [ ] Support retry. * [ ] Support unix socket. * [ ] Support h2c. +* [ ] Make videos. +* [ ] Design a logo. * [ ] Support HTTP3. ## License From d5ad3c28f8b6f5c58350d2ebe225cdb043ba4b2a Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Feb 2022 22:06:38 +0800 Subject: [PATCH 301/843] Add TestUploadMultipart --- req_test.go | 14 +++++++++++++- request_test.go | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/req_test.go b/req_test.go index 36c45969..c2145986 100644 --- a/req_test.go +++ b/req_test.go @@ -59,11 +59,15 @@ func getTestServerURL() string { } func getTestFileContent(t *testing.T, filename string) []byte { - b, err := ioutil.ReadFile(filepath.Join(testDataPath, filename)) + b, err := ioutil.ReadFile(getTestFilePath(filename)) assertError(t, err) return b } +func getTestFilePath(filename string) string { + return filepath.Join(testDataPath, filename) +} + type echo struct { Header http.Header `json:"header" xml:"header"` Body string `json:"body" xml:"body"` @@ -75,6 +79,14 @@ func handlePost(w http.ResponseWriter, r *http.Request) { w.Write([]byte("TestPost: text response")) case "/raw-upload": io.Copy(ioutil.Discard, r.Body) + case "/multipart": + r.ParseMultipartForm(10e6) + m := make(map[string]interface{}) + m["values"] = r.MultipartForm.Value + m["files"] = r.MultipartForm.File + ret, _ := json.Marshal(&m) + w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Write(ret) case "/search": handleSearch(w, r) case "/redirect": diff --git a/request_test.go b/request_test.go index 65fe200d..e63b54e9 100644 --- a/request_test.go +++ b/request_test.go @@ -597,3 +597,21 @@ func TestAutoDetectRequestContentType(t *testing.T) { assertSuccess(t, resp, err) assertEqual(t, "image/png", resp.Request.Headers.Get(hdrContentTypeKey)) } + +func TestUploadMultipart(t *testing.T) { + m := make(map[string]interface{}) + resp, err := tc().R(). + SetFile("file", getTestFilePath("sample-image.png")). + SetFile("file", getTestFilePath("sample-file.txt")). + SetFormData(map[string]string{ + "param1": "value1", + "param2": "value2", + }). + SetResult(&m). + Post("/multipart") + assertSuccess(t, resp, err) + assertContains(t, resp.String(), "sample-image.png", true) + assertContains(t, resp.String(), "sample-file.txt", true) + assertContains(t, resp.String(), "value1", true) + assertContains(t, resp.String(), "value2", true) +} From 831992bda74e0d327889558da5531710326a0079 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 12 Feb 2022 11:52:54 +0800 Subject: [PATCH 302/843] fix DNSLookupTime in TraceInfo --- request.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/request.go b/request.go index 5630c4c6..3123b723 100644 --- a/request.go +++ b/request.go @@ -92,9 +92,18 @@ func (r *Request) TraceInfo() TraceInfo { } } + dnsDone := ct.dnsDone + if dnsDone.IsZero() { + dnsDone = endTime + } + + if !ct.dnsStart.IsZero() { + ti.DNSLookupTime = dnsDone.Sub(ct.dnsStart) + } + // Only calculate on successful connections if !ct.connectDone.IsZero() { - ti.TCPConnectTime = ct.connectDone.Sub(ct.dnsDone) + ti.TCPConnectTime = ct.connectDone.Sub(dnsDone) } // Only calculate on successful connections From 5f14a0203fdbf8ec18a1aa333c0e6febf83704dc Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 12 Feb 2022 19:34:43 +0800 Subject: [PATCH 303/843] update README --- README.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8e8c694d..6c6a954d 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,8 @@ If you want to use the older version, check it out on [v1 branch](https://github ## Table of Contents * [Features](#Features) -* [Quick Start](#Quick-Start) +* [Get Started](#Get-Started) +* [Videos](#Videos) * [API Reference](#API) * [Debugging - Dump/Log/Trace](#Debugging) * [Quick HTTP Test](#Test) @@ -48,7 +49,7 @@ If you want to use the older version, check it out on [v1 branch](https://github * Easy set timeout, proxy, certs, redirect policy, cookie jar, compression, keepalives etc for client. * Support middleware before request sent and after got response (see [Request and Response Middleware](#Middleware)). -## Quick Start +## Get Started **Install** @@ -87,6 +88,12 @@ Checkout more runnable examples in the [examples](examples) direcotry. Checkout [Req API Reference](docs/api.md) for a brief and categorized list of the core APIs, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). +## Videos + +Only the Chinese version is available for now, there will be more in the future. + +* [快速上手 req](https://www.bilibili.com/video/BV1Xq4y1b7UR) (Chinese, BiliBili) + ## Debugging - Dump/Log/Trace **Dump the Content** From 199393b1b0051fd89839165c7a225202f8cb952d Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 14:29:03 +0800 Subject: [PATCH 304/843] update README --- README.md | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 6c6a954d..a56cd37b 100644 --- a/README.md +++ b/README.md @@ -84,16 +84,15 @@ resp, err := client.R(). // Use R() to create a request Checkout more runnable examples in the [examples](examples) direcotry. -## API Reference - -Checkout [Req API Reference](docs/api.md) for a brief and categorized list of the core APIs, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). - ## Videos -Only the Chinese version is available for now, there will be more in the future. - +* [Get Started With Req](https://www.youtube.com/watch?v=k47i0CKBVrA) (English, Youtube) * [快速上手 req](https://www.bilibili.com/video/BV1Xq4y1b7UR) (Chinese, BiliBili) +## API Reference + +Checkout [Req API Reference](docs/api.md) for a brief and categorized list of the core APIs, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). + ## Debugging - Dump/Log/Trace **Dump the Content** From 6dfa9833a0b84d166d4a9d951f01e4ee92637aa1 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 14:49:51 +0800 Subject: [PATCH 305/843] update README --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a56cd37b..d3b58de9 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,6 @@ If you want to use the older version, check it out on [v1 branch](https://github * [Features](#Features) * [Get Started](#Get-Started) -* [Videos](#Videos) -* [API Reference](#API) * [Debugging - Dump/Log/Trace](#Debugging) * [Quick HTTP Test](#Test) * [HTTP2 and HTTP1](#HTTP2-HTTP1) @@ -63,6 +61,8 @@ go get github.com/imroc/req/v3 import "github.com/imroc/req/v3" ``` +**Basic Usage** + ```go // For test, you can create and send a request with the global default // client, use DevMode to see all details, try and suprise :) @@ -84,12 +84,12 @@ resp, err := client.R(). // Use R() to create a request Checkout more runnable examples in the [examples](examples) direcotry. -## Videos +**Videos** * [Get Started With Req](https://www.youtube.com/watch?v=k47i0CKBVrA) (English, Youtube) * [快速上手 req](https://www.bilibili.com/video/BV1Xq4y1b7UR) (Chinese, BiliBili) -## API Reference +**API Reference** Checkout [Req API Reference](docs/api.md) for a brief and categorized list of the core APIs, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). From 425bd8440069a233253f4d410963b035a32e9916 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 14:52:07 +0800 Subject: [PATCH 306/843] update README --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d3b58de9..3c1e76e5 100644 --- a/README.md +++ b/README.md @@ -82,8 +82,6 @@ resp, err := client.R(). // Use R() to create a request Get("https://api.github.com/users/{username}/repos") ``` -Checkout more runnable examples in the [examples](examples) direcotry. - **Videos** * [Get Started With Req](https://www.youtube.com/watch?v=k47i0CKBVrA) (English, Youtube) @@ -93,6 +91,10 @@ Checkout more runnable examples in the [examples](examples) direcotry. Checkout [Req API Reference](docs/api.md) for a brief and categorized list of the core APIs, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). +**Examples** + +Checkout more examples below or runnable examples in the [examples](examples) direcotry. + ## Debugging - Dump/Log/Trace **Dump the Content** From 96573e915edaa89bcbeeecec3a5f00c84592b615 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 17:58:45 +0800 Subject: [PATCH 307/843] extract h2_pipe.go --- h2_bundle.go | 169 +------------------------------------------------ h2_pipe.go | 175 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 176 insertions(+), 168 deletions(-) create mode 100644 h2_pipe.go diff --git a/h2_bundle.go b/h2_bundle.go index 4e8cae81..74f21234 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -1,3 +1,4 @@ +//go:build !nethttpomithttp2 // +build !nethttpomithttp2 // Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. @@ -2877,174 +2878,6 @@ func http2validPseudoPath(v string) bool { // any size (as long as it's first). type http2incomparable [0]func() -// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like -// io.Pipe except there are no PipeReader/PipeWriter halves, and the -// underlying buffer is an interface. (io.Pipe is always unbuffered) -type http2pipe struct { - mu sync.Mutex - c sync.Cond // c.L lazily initialized to &p.mu - b http2pipeBuffer // nil when done reading - unread int // bytes unread when done - err error // read error once empty. non-nil means closed. - breakErr error // immediate read error (caller doesn't see rest of b) - donec chan struct{} // closed on error - readFn func() // optional code to run in Read before error -} - -type http2pipeBuffer interface { - Len() int - io.Writer - io.Reader -} - -// setBuffer initializes the pipe buffer. -// It has no effect if the pipe is already closed. -func (p *http2pipe) setBuffer(b http2pipeBuffer) { - p.mu.Lock() - defer p.mu.Unlock() - if p.err != nil || p.breakErr != nil { - return - } - p.b = b -} - -func (p *http2pipe) Len() int { - p.mu.Lock() - defer p.mu.Unlock() - if p.b == nil { - return p.unread - } - return p.b.Len() -} - -// Read waits until data is available and copies bytes -// from the buffer into p. -func (p *http2pipe) Read(d []byte) (n int, err error) { - p.mu.Lock() - defer p.mu.Unlock() - if p.c.L == nil { - p.c.L = &p.mu - } - for { - if p.breakErr != nil { - return 0, p.breakErr - } - if p.b != nil && p.b.Len() > 0 { - return p.b.Read(d) - } - if p.err != nil { - if p.readFn != nil { - p.readFn() // e.g. copy trailers - p.readFn = nil // not sticky like p.err - } - p.b = nil - return 0, p.err - } - p.c.Wait() - } -} - -var http2errClosedPipeWrite = errors.New("write on closed buffer") - -// Write copies bytes from p into the buffer and wakes a reader. -// It is an error to write more data than the buffer can hold. -func (p *http2pipe) Write(d []byte) (n int, err error) { - p.mu.Lock() - defer p.mu.Unlock() - if p.c.L == nil { - p.c.L = &p.mu - } - defer p.c.Signal() - if p.err != nil { - return 0, http2errClosedPipeWrite - } - if p.breakErr != nil { - p.unread += len(d) - return len(d), nil // discard when there is no reader - } - return p.b.Write(d) -} - -// CloseWithError causes the next Read (waking up a current blocked -// Read if needed) to return the provided err after all data has been -// read. -// -// The error must be non-nil. -func (p *http2pipe) CloseWithError(err error) { p.closeWithError(&p.err, err, nil) } - -// BreakWithError causes the next Read (waking up a current blocked -// Read if needed) to return the provided err immediately, without -// waiting for unread data. -func (p *http2pipe) BreakWithError(err error) { p.closeWithError(&p.breakErr, err, nil) } - -// closeWithErrorAndCode is like CloseWithError but also sets some code to run -// in the caller's goroutine before returning the error. -func (p *http2pipe) closeWithErrorAndCode(err error, fn func()) { p.closeWithError(&p.err, err, fn) } - -func (p *http2pipe) closeWithError(dst *error, err error, fn func()) { - if err == nil { - panic("err must be non-nil") - } - p.mu.Lock() - defer p.mu.Unlock() - if p.c.L == nil { - p.c.L = &p.mu - } - defer p.c.Signal() - if *dst != nil { - // Already been done. - return - } - p.readFn = fn - if dst == &p.breakErr { - if p.b != nil { - p.unread += p.b.Len() - } - p.b = nil - } - *dst = err - p.closeDoneLocked() -} - -// requires p.mu be held. -func (p *http2pipe) closeDoneLocked() { - if p.donec == nil { - return - } - // Close if unclosed. This isn't racy since we always - // hold p.mu while closing. - select { - case <-p.donec: - default: - close(p.donec) - } -} - -// Err returns the error (if any) first set by BreakWithError or CloseWithError. -func (p *http2pipe) Err() error { - p.mu.Lock() - defer p.mu.Unlock() - if p.breakErr != nil { - return p.breakErr - } - return p.err -} - -// Done returns a channel which is closed if and when this pipe is closed -// with CloseWithError. -func (p *http2pipe) Done() <-chan struct{} { - p.mu.Lock() - defer p.mu.Unlock() - if p.donec == nil { - p.donec = make(chan struct{}) - if p.err != nil || p.breakErr != nil { - // Already hit an error. - p.closeDoneLocked() - } - } - return p.donec -} - const ( http2prefaceTimeout = 10 * time.Second http2firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway diff --git a/h2_pipe.go b/h2_pipe.go new file mode 100644 index 00000000..95c93769 --- /dev/null +++ b/h2_pipe.go @@ -0,0 +1,175 @@ +package req + +import ( + "errors" + "io" + "sync" +) + +// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like +// io.Pipe except there are no PipeReader/PipeWriter halves, and the +// underlying buffer is an interface. (io.Pipe is always unbuffered) +type http2pipe struct { + mu sync.Mutex + c sync.Cond // c.L lazily initialized to &p.mu + b http2pipeBuffer // nil when done reading + unread int // bytes unread when done + err error // read error once empty. non-nil means closed. + breakErr error // immediate read error (caller doesn't see rest of b) + donec chan struct{} // closed on error + readFn func() // optional code to run in Read before error +} + +type http2pipeBuffer interface { + Len() int + io.Writer + io.Reader +} + +// setBuffer initializes the pipe buffer. +// It has no effect if the pipe is already closed. +func (p *http2pipe) setBuffer(b http2pipeBuffer) { + p.mu.Lock() + defer p.mu.Unlock() + if p.err != nil || p.breakErr != nil { + return + } + p.b = b +} + +func (p *http2pipe) Len() int { + p.mu.Lock() + defer p.mu.Unlock() + if p.b == nil { + return p.unread + } + return p.b.Len() +} + +// Read waits until data is available and copies bytes +// from the buffer into p. +func (p *http2pipe) Read(d []byte) (n int, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + for { + if p.breakErr != nil { + return 0, p.breakErr + } + if p.b != nil && p.b.Len() > 0 { + return p.b.Read(d) + } + if p.err != nil { + if p.readFn != nil { + p.readFn() // e.g. copy trailers + p.readFn = nil // not sticky like p.err + } + p.b = nil + return 0, p.err + } + p.c.Wait() + } +} + +var http2errClosedPipeWrite = errors.New("write on closed buffer") + +// Write copies bytes from p into the buffer and wakes a reader. +// It is an error to write more data than the buffer can hold. +func (p *http2pipe) Write(d []byte) (n int, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + defer p.c.Signal() + if p.err != nil { + return 0, http2errClosedPipeWrite + } + if p.breakErr != nil { + p.unread += len(d) + return len(d), nil // discard when there is no reader + } + return p.b.Write(d) +} + +// CloseWithError causes the next Read (waking up a current blocked +// Read if needed) to return the provided err after all data has been +// read. +// +// The error must be non-nil. +func (p *http2pipe) CloseWithError(err error) { p.closeWithError(&p.err, err, nil) } + +// BreakWithError causes the next Read (waking up a current blocked +// Read if needed) to return the provided err immediately, without +// waiting for unread data. +func (p *http2pipe) BreakWithError(err error) { p.closeWithError(&p.breakErr, err, nil) } + +// closeWithErrorAndCode is like CloseWithError but also sets some code to run +// in the caller's goroutine before returning the error. +func (p *http2pipe) closeWithErrorAndCode(err error, fn func()) { p.closeWithError(&p.err, err, fn) } + +func (p *http2pipe) closeWithError(dst *error, err error, fn func()) { + if err == nil { + panic("err must be non-nil") + } + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + defer p.c.Signal() + if *dst != nil { + // Already been done. + return + } + p.readFn = fn + if dst == &p.breakErr { + if p.b != nil { + p.unread += p.b.Len() + } + p.b = nil + } + *dst = err + p.closeDoneLocked() +} + +// requires p.mu be held. +func (p *http2pipe) closeDoneLocked() { + if p.donec == nil { + return + } + // Close if unclosed. This isn't racy since we always + // hold p.mu while closing. + select { + case <-p.donec: + default: + close(p.donec) + } +} + +// Err returns the error (if any) first set by BreakWithError or CloseWithError. +func (p *http2pipe) Err() error { + p.mu.Lock() + defer p.mu.Unlock() + if p.breakErr != nil { + return p.breakErr + } + return p.err +} + +// Done returns a channel which is closed if and when this pipe is closed +// with CloseWithError. +func (p *http2pipe) Done() <-chan struct{} { + p.mu.Lock() + defer p.mu.Unlock() + if p.donec == nil { + p.donec = make(chan struct{}) + if p.err != nil || p.breakErr != nil { + // Already hit an error. + p.closeDoneLocked() + } + } + return p.donec +} From 8d43bb66861b4fd8dd3ca129caff869fc3ca7463 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 18:03:24 +0800 Subject: [PATCH 308/843] remove some unused code in h2_bundle.go --- h2_bundle.go | 40 ---------------------------------------- 1 file changed, 40 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index 74f21234..975a7dce 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -2878,46 +2878,6 @@ func http2validPseudoPath(v string) bool { // any size (as long as it's first). type http2incomparable [0]func() -const ( - http2prefaceTimeout = 10 * time.Second - http2firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway - http2handlerChunkWriteSize = 4 << 10 - http2defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? - http2maxQueuedControlFrames = 10000 -) - -type http2readFrameResult struct { - f http2Frame // valid until readMore is called - err error - - // readMore should be called once the consumer no longer needs or - // retains f. After readMore, f is invalid and more frames can be - // read. - readMore func() -} - -type http2serverMessage int - -// Message values sent to serveMsgCh. -var ( - http2settingsTimerMsg = new(http2serverMessage) - http2idleTimerMsg = new(http2serverMessage) - http2shutdownTimerMsg = new(http2serverMessage) - http2gracefulShutdownMsg = new(http2serverMessage) -) - -var http2errPrefaceTimeout = errors.New("timeout waiting for client preface") - -var http2errChanPool = sync.Pool{ - New: func() interface{} { return make(chan error, 1) }, -} - -type http2requestParam struct { - method string - scheme, authority, path string - header http.Header -} - // TrailerPrefix is a magic prefix for ResponseWriter.Header map keys // that, if present, signals that the map entry is actually for // the response trailers, and not the response headers. The prefix From f982cb57931f6b68e2e0244a4d5359ba5753a7a6 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 18:06:12 +0800 Subject: [PATCH 309/843] extract h2_ascii.go --- h2_ascii.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ h2_bundle.go | 46 ---------------------------------------------- 2 files changed, 49 insertions(+), 46 deletions(-) create mode 100644 h2_ascii.go diff --git a/h2_ascii.go b/h2_ascii.go new file mode 100644 index 00000000..be3f69e9 --- /dev/null +++ b/h2_ascii.go @@ -0,0 +1,49 @@ +package req + +import "strings" + +// The HTTP protocols are defined in terms of ASCII, not Unicode. This file +// contains helper functions which may use Unicode-aware functions which would +// otherwise be unsafe and could introduce vulnerabilities if used improperly. + +// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func http2asciiEqualFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if http2lower(s[i]) != http2lower(t[i]) { + return false + } + } + return true +} + +// lower returns the ASCII lowercase version of b. +func http2lower(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// isASCIIPrint returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func http2isASCIIPrint(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} + +// asciiToLower returns the lowercase version of s if s is ASCII and printable, +// and whether or not it was. +func http2asciiToLower(s string) (lower string, ok bool) { + if !http2isASCIIPrint(s) { + return "", false + } + return strings.ToLower(s), true +} diff --git a/h2_bundle.go b/h2_bundle.go index 975a7dce..b6b24931 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -52,52 +52,6 @@ import ( "golang.org/x/net/idna" ) -// The HTTP protocols are defined in terms of ASCII, not Unicode. This file -// contains helper functions which may use Unicode-aware functions which would -// otherwise be unsafe and could introduce vulnerabilities if used improperly. - -// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t -// are equal, ASCII-case-insensitively. -func http2asciiEqualFold(s, t string) bool { - if len(s) != len(t) { - return false - } - for i := 0; i < len(s); i++ { - if http2lower(s[i]) != http2lower(t[i]) { - return false - } - } - return true -} - -// lower returns the ASCII lowercase version of b. -func http2lower(b byte) byte { - if 'A' <= b && b <= 'Z' { - return b + ('a' - 'A') - } - return b -} - -// isASCIIPrint returns whether s is ASCII and printable according to -// https://tools.ietf.org/html/rfc20#section-4.2. -func http2isASCIIPrint(s string) bool { - for i := 0; i < len(s); i++ { - if s[i] < ' ' || s[i] > '~' { - return false - } - } - return true -} - -// asciiToLower returns the lowercase version of s if s is ASCII and printable, -// and whether or not it was. -func http2asciiToLower(s string) (lower string, ok bool) { - if !http2isASCIIPrint(s) { - return "", false - } - return strings.ToLower(s), true -} - // A list of the possible cipher suite ids. Taken from // https://www.iana.org/assignments/tls-parameters/tls-parameters.txt From fd17f6cebf6e18ffe821099a7f885e199b1a476b Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:24:33 +0800 Subject: [PATCH 310/843] replace h2_ascii.go with internal util --- h2_ascii.go | 49 ------------------------------------------------- h2_bundle.go | 33 +++++++++++++++++---------------- 2 files changed, 17 insertions(+), 65 deletions(-) delete mode 100644 h2_ascii.go diff --git a/h2_ascii.go b/h2_ascii.go deleted file mode 100644 index be3f69e9..00000000 --- a/h2_ascii.go +++ /dev/null @@ -1,49 +0,0 @@ -package req - -import "strings" - -// The HTTP protocols are defined in terms of ASCII, not Unicode. This file -// contains helper functions which may use Unicode-aware functions which would -// otherwise be unsafe and could introduce vulnerabilities if used improperly. - -// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t -// are equal, ASCII-case-insensitively. -func http2asciiEqualFold(s, t string) bool { - if len(s) != len(t) { - return false - } - for i := 0; i < len(s); i++ { - if http2lower(s[i]) != http2lower(t[i]) { - return false - } - } - return true -} - -// lower returns the ASCII lowercase version of b. -func http2lower(b byte) byte { - if 'A' <= b && b <= 'Z' { - return b + ('a' - 'A') - } - return b -} - -// isASCIIPrint returns whether s is ASCII and printable according to -// https://tools.ietf.org/html/rfc20#section-4.2. -func http2isASCIIPrint(s string) bool { - for i := 0; i < len(s); i++ { - if s[i] < ' ' || s[i] > '~' { - return false - } - } - return true -} - -// asciiToLower returns the lowercase version of s if s is ASCII and printable, -// and whether or not it was. -func http2asciiToLower(s string) (lower string, ok bool) { - if !http2isASCIIPrint(s) { - return "", false - } - return strings.ToLower(s), true -} diff --git a/h2_bundle.go b/h2_bundle.go index b6b24931..a1520003 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -29,6 +29,7 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/imroc/req/v3/internal/ascii" "io" "io/ioutil" "log" @@ -2580,12 +2581,12 @@ func http2buildCommonHeaderMaps() { } } -func http2lowerHeader(v string) (lower string, ascii bool) { +func http2lowerHeader(v string) (lower string, isAscii bool) { http2buildCommonHeaderMapsOnce() if s, ok := http2commonLowerHeader[v]; ok { return s, true } - return http2asciiToLower(v) + return ascii.ToLower(v) } var ( @@ -3971,7 +3972,7 @@ func http2checkConnHeaders(req *http.Request) error { if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv) } - if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !http2asciiEqualFold(vv[0], "close") && !http2asciiEqualFold(vv[0], "keep-alive")) { + if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !ascii.EqualFold(vv[0], "close") && !ascii.EqualFold(vv[0], "keep-alive")) { return fmt.Errorf("http2: invalid Connection request header: %q", vv) } return nil @@ -4707,21 +4708,21 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, var didUA bool for k, vv := range req.Header { - if http2asciiEqualFold(k, "host") || http2asciiEqualFold(k, "content-length") { + if ascii.EqualFold(k, "host") || ascii.EqualFold(k, "content-length") { // Host is :authority, already sent. // Content-Length is automatic, set below. continue - } else if http2asciiEqualFold(k, "connection") || - http2asciiEqualFold(k, "proxy-connection") || - http2asciiEqualFold(k, "transfer-encoding") || - http2asciiEqualFold(k, "upgrade") || - http2asciiEqualFold(k, "keep-alive") { + } else if ascii.EqualFold(k, "connection") || + ascii.EqualFold(k, "proxy-connection") || + ascii.EqualFold(k, "transfer-encoding") || + ascii.EqualFold(k, "upgrade") || + ascii.EqualFold(k, "keep-alive") { // Per 8.1.2.2 Connection-Specific Header // Fields, don't send connection-specific // fields. We have already checked if any // are error-worthy so just ignore the rest. continue - } else if http2asciiEqualFold(k, "user-agent") { + } else if ascii.EqualFold(k, "user-agent") { // Match Go's http1 behavior: at most one // User-Agent. If set to nil or empty string, // then omit it. Otherwise if not mentioned, @@ -4734,7 +4735,7 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, if vv[0] == "" { continue } - } else if http2asciiEqualFold(k, "cookie") { + } else if ascii.EqualFold(k, "cookie") { // Per 8.1.2.5 To allow for better compression efficiency, the // Cookie header field MAY be split into separate header fields, // each with one or more cookie-pairs. @@ -4812,7 +4813,7 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, // Header list size is ok. Write the headers. enumerateHeaders(func(name, value string) { - name, ascii := http2asciiToLower(name) + name, ascii := ascii.ToLower(name) if !ascii { // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // field names have to be ASCII characters (just as in HTTP/1.x). @@ -4879,7 +4880,7 @@ func (cc *http2ClientConn) encodeTrailers(trailer http.Header, dumps []*dumper) } for k, vv := range trailer { - lowKey, ascii := http2asciiToLower(k) + lowKey, ascii := ascii.ToLower(k) if !ascii { // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // field names have to be ASCII characters (just as in HTTP/1.x). @@ -5295,7 +5296,7 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http cs.bytesRemain = res.ContentLength res.Body = http2transportResponseBody{cs} - if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") { + if cs.requestedGzip && ascii.EqualFold(res.Header.Get("Content-Encoding"), "gzip") { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") res.ContentLength = -1 @@ -6042,8 +6043,8 @@ func http2encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { } for _, k := range keys { vv := h[k] - k, ascii := http2lowerHeader(k) - if !ascii { + k, isAscii := http2lowerHeader(k) + if !isAscii { // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // field names have to be ASCII characters (just as in HTTP/1.x). continue From fa3426d83c1afb3c477bd90ac4f419d203e8fa0d Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:25:29 +0800 Subject: [PATCH 311/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3c1e76e5..beee6696 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ If you want to use the older version, check it out on [v1 branch](https://github ## Features * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. -* Powerful and convenient debug utilites, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging - Dump/Log/Trace](#Debugging). +* Powerful and convenient debug utilites, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging - Dump/Log/Trace](#Debugging)). * Easy making HTTP test with code instead of tools like curl or postman, `req` provide global wrapper methods and `MustXXX` to test API with minimal code (see [Quick HTTP Test](#Test)). * Works fine with both `HTTP/2` and `HTTP/1.1`, which `HTTP/2` is preferred by default if server support, and you can also force `HTTP/1.1` if you want (see [HTTP2 and HTTP1](#HTTP2-HTTP1)). * Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decode](#AutoDecode)). From 505bcbcc7f61b5e00ef67573db98af0868460f89 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:31:33 +0800 Subject: [PATCH 312/843] rename charsetutil to charsets --- decode.go | 4 ++-- internal/{charsetutil/charsetutil.go => charsets/charsets.go} | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) rename internal/{charsetutil/charsetutil.go => charsets/charsets.go} (99%) diff --git a/decode.go b/decode.go index 7d4d273b..9d453c57 100644 --- a/decode.go +++ b/decode.go @@ -1,7 +1,7 @@ package req import ( - "github.com/imroc/req/v3/internal/charsetutil" + "github.com/imroc/req/v3/internal/charsets" "io" "strings" ) @@ -48,7 +48,7 @@ func (a *autoDecodeReadCloser) peekRead(p []byte) (n int, err error) { return } a.detected = true - enc, name := charsetutil.FindEncoding(p) + enc, name := charsets.FindEncoding(p) if enc == nil { return } diff --git a/internal/charsetutil/charsetutil.go b/internal/charsets/charsets.go similarity index 99% rename from internal/charsetutil/charsetutil.go rename to internal/charsets/charsets.go index 5acdec96..e612fb1e 100644 --- a/internal/charsetutil/charsetutil.go +++ b/internal/charsets/charsets.go @@ -1,4 +1,4 @@ -package charsetutil +package charsets import ( "bytes" From 831cfdd91895cf51447a505e4fe7ceb71665a947 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:38:36 +0800 Subject: [PATCH 313/843] add h2_pipe_test.go --- h2_pipe_test.go | 142 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 h2_pipe_test.go diff --git a/h2_pipe_test.go b/h2_pipe_test.go new file mode 100644 index 00000000..7e9b4657 --- /dev/null +++ b/h2_pipe_test.go @@ -0,0 +1,142 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "bytes" + "errors" + "io" + "io/ioutil" + "testing" +) + +func TestPipeClose(t *testing.T) { + var p http2pipe + p.b = new(bytes.Buffer) + a := errors.New("a") + b := errors.New("b") + p.CloseWithError(a) + p.CloseWithError(b) + _, err := p.Read(make([]byte, 1)) + if err != a { + t.Errorf("err = %v want %v", err, a) + } +} + +func TestPipeDoneChan(t *testing.T) { + var p http2pipe + done := p.Done() + select { + case <-done: + t.Fatal("done too soon") + default: + } + p.CloseWithError(io.EOF) + select { + case <-done: + default: + t.Fatal("should be done") + } +} + +func TestPipeDoneChan_ErrFirst(t *testing.T) { + var p http2pipe + p.CloseWithError(io.EOF) + done := p.Done() + select { + case <-done: + default: + t.Fatal("should be done") + } +} + +func TestPipeDoneChan_Break(t *testing.T) { + var p http2pipe + done := p.Done() + select { + case <-done: + t.Fatal("done too soon") + default: + } + p.BreakWithError(io.EOF) + select { + case <-done: + default: + t.Fatal("should be done") + } +} + +func TestPipeDoneChan_Break_ErrFirst(t *testing.T) { + var p http2pipe + p.BreakWithError(io.EOF) + done := p.Done() + select { + case <-done: + default: + t.Fatal("should be done") + } +} + +func TestPipeCloseWithError(t *testing.T) { + p := &http2pipe{b: new(bytes.Buffer)} + const body = "foo" + io.WriteString(p, body) + a := errors.New("test error") + p.CloseWithError(a) + all, err := ioutil.ReadAll(p) + if string(all) != body { + t.Errorf("read bytes = %q; want %q", all, body) + } + if err != a { + t.Logf("read error = %v, %v", err, a) + } + if p.Len() != 0 { + t.Errorf("pipe should have 0 unread bytes") + } + // Read and Write should fail. + if n, err := p.Write([]byte("abc")); err != http2errClosedPipeWrite || n != 0 { + t.Errorf("Write(abc) after close\ngot %v, %v\nwant 0, %v", n, err, http2errClosedPipeWrite) + } + if n, err := p.Read(make([]byte, 1)); err == nil || n != 0 { + t.Errorf("Read() after close\ngot %v, nil\nwant 0, %v", n, http2errClosedPipeWrite) + } + if p.Len() != 0 { + t.Errorf("pipe should have 0 unread bytes") + } +} + +func TestPipeBreakWithError(t *testing.T) { + p := &http2pipe{b: new(bytes.Buffer)} + io.WriteString(p, "foo") + a := errors.New("test err") + p.BreakWithError(a) + all, err := ioutil.ReadAll(p) + if string(all) != "" { + t.Errorf("read bytes = %q; want empty string", all) + } + if err != a { + t.Logf("read error = %v, %v", err, a) + } + if p.b != nil { + t.Errorf("buffer should be nil after BreakWithError") + } + if p.Len() != 3 { + t.Errorf("pipe should have 3 unread bytes") + } + // Write should succeed silently. + if n, err := p.Write([]byte("abc")); err != nil || n != 3 { + t.Errorf("Write(abc) after break\ngot %v, %v\nwant 0, nil", n, err) + } + if p.b != nil { + t.Errorf("buffer should be nil after Write") + } + if p.Len() != 6 { + t.Errorf("pipe should have 6 unread bytes") + } + // Read should fail. + if n, err := p.Read(make([]byte, 1)); err == nil || n != 0 { + t.Errorf("Read() after close\ngot %v, nil\nwant 0, not nil", n) + } +} From e3ecc49e6f333df0065f8196899335bbdff455ff Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:43:11 +0800 Subject: [PATCH 314/843] extract h2_client_conn_pool.go --- h2_bundle.go | 298 ---------------------------------------- h2_client_conn_pool.go | 304 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 304 insertions(+), 298 deletions(-) create mode 100644 h2_client_conn_pool.go diff --git a/h2_bundle.go b/h2_bundle.go index a1520003..8ea2c881 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -53,304 +53,6 @@ import ( "golang.org/x/net/idna" ) -// A list of the possible cipher suite ids. Taken from -// https://www.iana.org/assignments/tls-parameters/tls-parameters.txt - -// ClientConnPool manages a pool of HTTP/2 client connections. -type http2ClientConnPool interface { - // GetClientConn returns a specific HTTP/2 connection (usually - // a TLS-TCP connection) to an HTTP/2 server. On success, the - // returned ClientConn accounts for the upcoming RoundTrip - // call, so the caller should not omit it. If the caller needs - // to, ClientConn.RoundTrip can be called with a bogus - // new(http.Request) to release the stream reservation. - GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) - MarkDead(*http2ClientConn) -} - -// clientConnPoolIdleCloser is the interface implemented by ClientConnPool -// implementations which can close their idle connections. -type http2clientConnPoolIdleCloser interface { - http2ClientConnPool - closeIdleConnections() -} - -var ( - _ http2clientConnPoolIdleCloser = (*http2clientConnPool)(nil) - _ http2clientConnPoolIdleCloser = http2noDialClientConnPool{} -) - -// TODO: use singleflight for dialing and addConnCalls? -type http2clientConnPool struct { - t *http2Transport - - mu sync.Mutex // TODO: maybe switch to RWMutex - // TODO: add support for sharing conns based on cert names - // (e.g. share conn for googleapis.com and appspot.com) - conns map[string][]*http2ClientConn // key is host:port - dialing map[string]*http2dialCall // currently in-flight dials - keys map[*http2ClientConn][]string - addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeeded calls -} - -func (p *http2clientConnPool) GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) { - return p.getClientConn(req, addr, http2dialOnMiss) -} - -const ( - http2dialOnMiss = true - http2noDialOnMiss = false -) - -func (p *http2clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { - // TODO(dneil): Dial a new connection when t.DisableKeepAlives is set? - if http2isConnectionCloseRequest(req) && dialOnMiss { - // It gets its own connection. - http2traceGetConn(req, addr) - const singleUse = true - cc, err := p.t.dialClientConn(req.Context(), addr, singleUse) - if err != nil { - return nil, err - } - return cc, nil - } - for { - p.mu.Lock() - for _, cc := range p.conns[addr] { - if cc.ReserveNewRequest() { - // When a connection is presented to us by the net/http package, - // the GetConn hook has already been called. - // Don't call it a second time here. - if !cc.getConnCalled { - http2traceGetConn(req, addr) - } - cc.getConnCalled = false - p.mu.Unlock() - return cc, nil - } - } - if !dialOnMiss { - p.mu.Unlock() - return nil, http2ErrNoCachedConn - } - http2traceGetConn(req, addr) - call := p.getStartDialLocked(req.Context(), addr) - p.mu.Unlock() - <-call.done - if http2shouldRetryDial(call, req) { - continue - } - cc, err := call.res, call.err - if err != nil { - return nil, err - } - if cc.ReserveNewRequest() { - return cc, nil - } - } -} - -// dialCall is an in-flight Transport dial call to a host. -type http2dialCall struct { - _ http2incomparable - p *http2clientConnPool - // the context associated with the request - // that created this dialCall - ctx context.Context - done chan struct{} // closed when done - res *http2ClientConn // valid after done is closed - err error // valid after done is closed -} - -// requires p.mu is held. -func (p *http2clientConnPool) getStartDialLocked(ctx context.Context, addr string) *http2dialCall { - if call, ok := p.dialing[addr]; ok { - // A dial is already in-flight. Don't start another. - return call - } - call := &http2dialCall{p: p, done: make(chan struct{}), ctx: ctx} - if p.dialing == nil { - p.dialing = make(map[string]*http2dialCall) - } - p.dialing[addr] = call - go call.dial(call.ctx, addr) - return call -} - -// run in its own goroutine. -func (c *http2dialCall) dial(ctx context.Context, addr string) { - const singleUse = false // shared conn - c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse) - close(c.done) - - c.p.mu.Lock() - delete(c.p.dialing, addr) - if c.err == nil { - c.p.addConnLocked(addr, c.res) - } - c.p.mu.Unlock() -} - -// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't -// already exist. It coalesces concurrent calls with the same key. -// This is used by the http1 Transport code when it creates a new connection. Because -// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know -// the protocol), it can get into a situation where it has multiple TLS connections. -// This code decides which ones live or die. -// The return value used is whether c was used. -// c is never closed. -func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c net.Conn) (used bool, err error) { - p.mu.Lock() - for _, cc := range p.conns[key] { - if cc.CanTakeNewRequest() { - p.mu.Unlock() - return false, nil - } - } - call, dup := p.addConnCalls[key] - if !dup { - if p.addConnCalls == nil { - p.addConnCalls = make(map[string]*http2addConnCall) - } - call = &http2addConnCall{ - p: p, - done: make(chan struct{}), - } - p.addConnCalls[key] = call - go call.run(t, key, c) - } - p.mu.Unlock() - - <-call.done - if call.err != nil { - return false, call.err - } - return !dup, nil -} - -type http2addConnCall struct { - _ http2incomparable - p *http2clientConnPool - done chan struct{} // closed when done - err error -} - -func (c *http2addConnCall) run(t *http2Transport, key string, tc net.Conn) { - cc, err := t.NewClientConn(tc) - - p := c.p - p.mu.Lock() - if err != nil { - c.err = err - } else { - cc.getConnCalled = true // already called by the net/http package - p.addConnLocked(key, cc) - } - delete(p.addConnCalls, key) - p.mu.Unlock() - close(c.done) -} - -// p.mu must be held -func (p *http2clientConnPool) addConnLocked(key string, cc *http2ClientConn) { - for _, v := range p.conns[key] { - if v == cc { - return - } - } - if p.conns == nil { - p.conns = make(map[string][]*http2ClientConn) - } - if p.keys == nil { - p.keys = make(map[*http2ClientConn][]string) - } - p.conns[key] = append(p.conns[key], cc) - p.keys[cc] = append(p.keys[cc], key) -} - -func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) { - p.mu.Lock() - defer p.mu.Unlock() - for _, key := range p.keys[cc] { - vv, ok := p.conns[key] - if !ok { - continue - } - newList := http2filterOutClientConn(vv, cc) - if len(newList) > 0 { - p.conns[key] = newList - } else { - delete(p.conns, key) - } - } - delete(p.keys, cc) -} - -func (p *http2clientConnPool) closeIdleConnections() { - p.mu.Lock() - defer p.mu.Unlock() - // TODO: don't close a cc if it was just added to the pool - // milliseconds ago and has never been used. There's currently - // a small race window with the HTTP/1 Transport's integration - // where it can add an idle conn just before using it, and - // somebody else can concurrently call CloseIdleConns and - // break some caller's RoundTrip. - for _, vv := range p.conns { - for _, cc := range vv { - cc.closeIfIdle() - } - } -} - -func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn { - out := in[:0] - for _, v := range in { - if v != exclude { - out = append(out, v) - } - } - // If we filtered it out, zero out the last item to prevent - // the GC from seeing it. - if len(in) != len(out) { - in[len(in)-1] = nil - } - return out -} - -// noDialClientConnPool is an implementation of http2.ClientConnPool -// which never dials. We let the HTTP/1.1 client dial and use its TLS -// connection instead. -type http2noDialClientConnPool struct{ *http2clientConnPool } - -func (p http2noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) { - return p.getClientConn(req, addr, http2noDialOnMiss) -} - -// shouldRetryDial reports whether the current request should -// retry dialing after the call finished unsuccessfully, for example -// if the dial was canceled because of a context cancellation or -// deadline expiry. -func http2shouldRetryDial(call *http2dialCall, req *http.Request) bool { - if call.err == nil { - // No error, no need to retry - return false - } - if call.ctx == req.Context() { - // If the call has the same context as the request, the dial - // should not be retried, since any cancellation will have come - // from this request. - return false - } - if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) { - // If the call error is not because of a context cancellation or a deadline expiry, - // the dial should not be retried. - return false - } - // Only retry if the error is a context cancellation error or deadline expiry - // and the context associated with the call was canceled or expired. - return call.ctx.Err() != nil -} - // Buffer chunks are allocated from a pool to reduce pressure on GC. // The maximum wasted space per dataBuffer is 2x the largest size class, // which happens when the dataBuffer has multiple chunks and there is diff --git a/h2_client_conn_pool.go b/h2_client_conn_pool.go new file mode 100644 index 00000000..0b649748 --- /dev/null +++ b/h2_client_conn_pool.go @@ -0,0 +1,304 @@ +package req + +import ( + "context" + "errors" + "net" + "net/http" + "sync" +) + +// ClientConnPool manages a pool of HTTP/2 client connections. +type http2ClientConnPool interface { + // GetClientConn returns a specific HTTP/2 connection (usually + // a TLS-TCP connection) to an HTTP/2 server. On success, the + // returned ClientConn accounts for the upcoming RoundTrip + // call, so the caller should not omit it. If the caller needs + // to, ClientConn.RoundTrip can be called with a bogus + // new(http.Request) to release the stream reservation. + GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) + MarkDead(*http2ClientConn) +} + +// clientConnPoolIdleCloser is the interface implemented by ClientConnPool +// implementations which can close their idle connections. +type http2clientConnPoolIdleCloser interface { + http2ClientConnPool + closeIdleConnections() +} + +var ( + _ http2clientConnPoolIdleCloser = (*http2clientConnPool)(nil) + _ http2clientConnPoolIdleCloser = http2noDialClientConnPool{} +) + +// TODO: use singleflight for dialing and addConnCalls? +type http2clientConnPool struct { + t *http2Transport + + mu sync.Mutex // TODO: maybe switch to RWMutex + // TODO: add support for sharing conns based on cert names + // (e.g. share conn for googleapis.com and appspot.com) + conns map[string][]*http2ClientConn // key is host:port + dialing map[string]*http2dialCall // currently in-flight dials + keys map[*http2ClientConn][]string + addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeeded calls +} + +func (p *http2clientConnPool) GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) { + return p.getClientConn(req, addr, http2dialOnMiss) +} + +const ( + http2dialOnMiss = true + http2noDialOnMiss = false +) + +func (p *http2clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { + // TODO(dneil): Dial a new connection when t.DisableKeepAlives is set? + if http2isConnectionCloseRequest(req) && dialOnMiss { + // It gets its own connection. + http2traceGetConn(req, addr) + const singleUse = true + cc, err := p.t.dialClientConn(req.Context(), addr, singleUse) + if err != nil { + return nil, err + } + return cc, nil + } + for { + p.mu.Lock() + for _, cc := range p.conns[addr] { + if cc.ReserveNewRequest() { + // When a connection is presented to us by the net/http package, + // the GetConn hook has already been called. + // Don't call it a second time here. + if !cc.getConnCalled { + http2traceGetConn(req, addr) + } + cc.getConnCalled = false + p.mu.Unlock() + return cc, nil + } + } + if !dialOnMiss { + p.mu.Unlock() + return nil, http2ErrNoCachedConn + } + http2traceGetConn(req, addr) + call := p.getStartDialLocked(req.Context(), addr) + p.mu.Unlock() + <-call.done + if http2shouldRetryDial(call, req) { + continue + } + cc, err := call.res, call.err + if err != nil { + return nil, err + } + if cc.ReserveNewRequest() { + return cc, nil + } + } +} + +// dialCall is an in-flight Transport dial call to a host. +type http2dialCall struct { + _ http2incomparable + p *http2clientConnPool + // the context associated with the request + // that created this dialCall + ctx context.Context + done chan struct{} // closed when done + res *http2ClientConn // valid after done is closed + err error // valid after done is closed +} + +// requires p.mu is held. +func (p *http2clientConnPool) getStartDialLocked(ctx context.Context, addr string) *http2dialCall { + if call, ok := p.dialing[addr]; ok { + // A dial is already in-flight. Don't start another. + return call + } + call := &http2dialCall{p: p, done: make(chan struct{}), ctx: ctx} + if p.dialing == nil { + p.dialing = make(map[string]*http2dialCall) + } + p.dialing[addr] = call + go call.dial(call.ctx, addr) + return call +} + +// run in its own goroutine. +func (c *http2dialCall) dial(ctx context.Context, addr string) { + const singleUse = false // shared conn + c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse) + close(c.done) + + c.p.mu.Lock() + delete(c.p.dialing, addr) + if c.err == nil { + c.p.addConnLocked(addr, c.res) + } + c.p.mu.Unlock() +} + +// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't +// already exist. It coalesces concurrent calls with the same key. +// This is used by the http1 Transport code when it creates a new connection. Because +// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know +// the protocol), it can get into a situation where it has multiple TLS connections. +// This code decides which ones live or die. +// The return value used is whether c was used. +// c is never closed. +func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c net.Conn) (used bool, err error) { + p.mu.Lock() + for _, cc := range p.conns[key] { + if cc.CanTakeNewRequest() { + p.mu.Unlock() + return false, nil + } + } + call, dup := p.addConnCalls[key] + if !dup { + if p.addConnCalls == nil { + p.addConnCalls = make(map[string]*http2addConnCall) + } + call = &http2addConnCall{ + p: p, + done: make(chan struct{}), + } + p.addConnCalls[key] = call + go call.run(t, key, c) + } + p.mu.Unlock() + + <-call.done + if call.err != nil { + return false, call.err + } + return !dup, nil +} + +type http2addConnCall struct { + _ http2incomparable + p *http2clientConnPool + done chan struct{} // closed when done + err error +} + +func (c *http2addConnCall) run(t *http2Transport, key string, tc net.Conn) { + cc, err := t.NewClientConn(tc) + + p := c.p + p.mu.Lock() + if err != nil { + c.err = err + } else { + cc.getConnCalled = true // already called by the net/http package + p.addConnLocked(key, cc) + } + delete(p.addConnCalls, key) + p.mu.Unlock() + close(c.done) +} + +// p.mu must be held +func (p *http2clientConnPool) addConnLocked(key string, cc *http2ClientConn) { + for _, v := range p.conns[key] { + if v == cc { + return + } + } + if p.conns == nil { + p.conns = make(map[string][]*http2ClientConn) + } + if p.keys == nil { + p.keys = make(map[*http2ClientConn][]string) + } + p.conns[key] = append(p.conns[key], cc) + p.keys[cc] = append(p.keys[cc], key) +} + +func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) { + p.mu.Lock() + defer p.mu.Unlock() + for _, key := range p.keys[cc] { + vv, ok := p.conns[key] + if !ok { + continue + } + newList := http2filterOutClientConn(vv, cc) + if len(newList) > 0 { + p.conns[key] = newList + } else { + delete(p.conns, key) + } + } + delete(p.keys, cc) +} + +func (p *http2clientConnPool) closeIdleConnections() { + p.mu.Lock() + defer p.mu.Unlock() + // TODO: don't close a cc if it was just added to the pool + // milliseconds ago and has never been used. There's currently + // a small race window with the HTTP/1 Transport's integration + // where it can add an idle conn just before using it, and + // somebody else can concurrently call CloseIdleConns and + // break some caller's RoundTrip. + for _, vv := range p.conns { + for _, cc := range vv { + cc.closeIfIdle() + } + } +} + +func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn { + out := in[:0] + for _, v := range in { + if v != exclude { + out = append(out, v) + } + } + // If we filtered it out, zero out the last item to prevent + // the GC from seeing it. + if len(in) != len(out) { + in[len(in)-1] = nil + } + return out +} + +// noDialClientConnPool is an implementation of http2.ClientConnPool +// which never dials. We let the HTTP/1.1 client dial and use its TLS +// connection instead. +type http2noDialClientConnPool struct{ *http2clientConnPool } + +func (p http2noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) { + return p.getClientConn(req, addr, http2noDialOnMiss) +} + +// shouldRetryDial reports whether the current request should +// retry dialing after the call finished unsuccessfully, for example +// if the dial was canceled because of a context cancellation or +// deadline expiry. +func http2shouldRetryDial(call *http2dialCall, req *http.Request) bool { + if call.err == nil { + // No error, no need to retry + return false + } + if call.ctx == req.Context() { + // If the call has the same context as the request, the dial + // should not be retried, since any cancellation will have come + // from this request. + return false + } + if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) { + // If the call error is not because of a context cancellation or a deadline expiry, + // the dial should not be retried. + return false + } + // Only retry if the error is a context cancellation error or deadline expiry + // and the context associated with the call was canceled or expired. + return call.ctx.Err() != nil +} From 8a90b918c4cc23dde1833e1396da40189ad053aa Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:45:53 +0800 Subject: [PATCH 315/843] extract databuffer.go --- databuffer.go | 142 ++++++++++++++++++++++++++++++++++++++++++++++++++ h2_bundle.go | 135 ----------------------------------------------- 2 files changed, 142 insertions(+), 135 deletions(-) create mode 100644 databuffer.go diff --git a/databuffer.go b/databuffer.go new file mode 100644 index 00000000..d97452f4 --- /dev/null +++ b/databuffer.go @@ -0,0 +1,142 @@ +package req + +import ( + "errors" + "fmt" + "sync" +) + +// Buffer chunks are allocated from a pool to reduce pressure on GC. +// The maximum wasted space per dataBuffer is 2x the largest size class, +// which happens when the dataBuffer has multiple chunks and there is +// one unread byte in both the first and last chunks. We use a few size +// classes to minimize overheads for servers that typically receive very +// small request bodies. +// +// TODO: Benchmark to determine if the pools are necessary. The GC may have +// improved enough that we can instead allocate chunks like this: +// make([]byte, max(16<<10, expectedBytesRemaining)) +var ( + http2dataChunkSizeClasses = []int{ + 1 << 10, + 2 << 10, + 4 << 10, + 8 << 10, + 16 << 10, + } + http2dataChunkPools = [...]sync.Pool{ + {New: func() interface{} { return make([]byte, 1<<10) }}, + {New: func() interface{} { return make([]byte, 2<<10) }}, + {New: func() interface{} { return make([]byte, 4<<10) }}, + {New: func() interface{} { return make([]byte, 8<<10) }}, + {New: func() interface{} { return make([]byte, 16<<10) }}, + } +) + +func http2getDataBufferChunk(size int64) []byte { + i := 0 + for ; i < len(http2dataChunkSizeClasses)-1; i++ { + if size <= int64(http2dataChunkSizeClasses[i]) { + break + } + } + return http2dataChunkPools[i].Get().([]byte) +} + +func http2putDataBufferChunk(p []byte) { + for i, n := range http2dataChunkSizeClasses { + if len(p) == n { + http2dataChunkPools[i].Put(p) + return + } + } + panic(fmt.Sprintf("unexpected buffer len=%v", len(p))) +} + +// dataBuffer is an io.ReadWriter backed by a list of data chunks. +// Each dataBuffer is used to read DATA frames on a single stream. +// The buffer is divided into chunks so the server can limit the +// total memory used by a single connection without limiting the +// request body size on any single stream. +type http2dataBuffer struct { + chunks [][]byte + r int // next byte to read is chunks[0][r] + w int // next byte to write is chunks[len(chunks)-1][w] + size int // total buffered bytes + expected int64 // we expect at least this many bytes in future Write calls (ignored if <= 0) +} + +var http2errReadEmpty = errors.New("read from empty dataBuffer") + +// Read copies bytes from the buffer into p. +// It is an error to read when no data is available. +func (b *http2dataBuffer) Read(p []byte) (int, error) { + if b.size == 0 { + return 0, http2errReadEmpty + } + var ntotal int + for len(p) > 0 && b.size > 0 { + readFrom := b.bytesFromFirstChunk() + n := copy(p, readFrom) + p = p[n:] + ntotal += n + b.r += n + b.size -= n + // If the first chunk has been consumed, advance to the next chunk. + if b.r == len(b.chunks[0]) { + http2putDataBufferChunk(b.chunks[0]) + end := len(b.chunks) - 1 + copy(b.chunks[:end], b.chunks[1:]) + b.chunks[end] = nil + b.chunks = b.chunks[:end] + b.r = 0 + } + } + return ntotal, nil +} + +func (b *http2dataBuffer) bytesFromFirstChunk() []byte { + if len(b.chunks) == 1 { + return b.chunks[0][b.r:b.w] + } + return b.chunks[0][b.r:] +} + +// Len returns the number of bytes of the unread portion of the buffer. +func (b *http2dataBuffer) Len() int { + return b.size +} + +// Write appends p to the buffer. +func (b *http2dataBuffer) Write(p []byte) (int, error) { + ntotal := len(p) + for len(p) > 0 { + // If the last chunk is empty, allocate a new chunk. Try to allocate + // enough to fully copy p plus any additional bytes we expect to + // receive. However, this may allocate less than len(p). + want := int64(len(p)) + if b.expected > want { + want = b.expected + } + chunk := b.lastChunkOrAlloc(want) + n := copy(chunk[b.w:], p) + p = p[n:] + b.w += n + b.size += n + b.expected -= int64(n) + } + return ntotal, nil +} + +func (b *http2dataBuffer) lastChunkOrAlloc(want int64) []byte { + if len(b.chunks) != 0 { + last := b.chunks[len(b.chunks)-1] + if b.w < len(last) { + return last + } + } + chunk := http2getDataBufferChunk(want) + b.chunks = append(b.chunks, chunk) + b.w = 0 + return chunk +} diff --git a/h2_bundle.go b/h2_bundle.go index 8ea2c881..c50907a8 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -53,141 +53,6 @@ import ( "golang.org/x/net/idna" ) -// Buffer chunks are allocated from a pool to reduce pressure on GC. -// The maximum wasted space per dataBuffer is 2x the largest size class, -// which happens when the dataBuffer has multiple chunks and there is -// one unread byte in both the first and last chunks. We use a few size -// classes to minimize overheads for servers that typically receive very -// small request bodies. -// -// TODO: Benchmark to determine if the pools are necessary. The GC may have -// improved enough that we can instead allocate chunks like this: -// make([]byte, max(16<<10, expectedBytesRemaining)) -var ( - http2dataChunkSizeClasses = []int{ - 1 << 10, - 2 << 10, - 4 << 10, - 8 << 10, - 16 << 10, - } - http2dataChunkPools = [...]sync.Pool{ - {New: func() interface{} { return make([]byte, 1<<10) }}, - {New: func() interface{} { return make([]byte, 2<<10) }}, - {New: func() interface{} { return make([]byte, 4<<10) }}, - {New: func() interface{} { return make([]byte, 8<<10) }}, - {New: func() interface{} { return make([]byte, 16<<10) }}, - } -) - -func http2getDataBufferChunk(size int64) []byte { - i := 0 - for ; i < len(http2dataChunkSizeClasses)-1; i++ { - if size <= int64(http2dataChunkSizeClasses[i]) { - break - } - } - return http2dataChunkPools[i].Get().([]byte) -} - -func http2putDataBufferChunk(p []byte) { - for i, n := range http2dataChunkSizeClasses { - if len(p) == n { - http2dataChunkPools[i].Put(p) - return - } - } - panic(fmt.Sprintf("unexpected buffer len=%v", len(p))) -} - -// dataBuffer is an io.ReadWriter backed by a list of data chunks. -// Each dataBuffer is used to read DATA frames on a single stream. -// The buffer is divided into chunks so the server can limit the -// total memory used by a single connection without limiting the -// request body size on any single stream. -type http2dataBuffer struct { - chunks [][]byte - r int // next byte to read is chunks[0][r] - w int // next byte to write is chunks[len(chunks)-1][w] - size int // total buffered bytes - expected int64 // we expect at least this many bytes in future Write calls (ignored if <= 0) -} - -var http2errReadEmpty = errors.New("read from empty dataBuffer") - -// Read copies bytes from the buffer into p. -// It is an error to read when no data is available. -func (b *http2dataBuffer) Read(p []byte) (int, error) { - if b.size == 0 { - return 0, http2errReadEmpty - } - var ntotal int - for len(p) > 0 && b.size > 0 { - readFrom := b.bytesFromFirstChunk() - n := copy(p, readFrom) - p = p[n:] - ntotal += n - b.r += n - b.size -= n - // If the first chunk has been consumed, advance to the next chunk. - if b.r == len(b.chunks[0]) { - http2putDataBufferChunk(b.chunks[0]) - end := len(b.chunks) - 1 - copy(b.chunks[:end], b.chunks[1:]) - b.chunks[end] = nil - b.chunks = b.chunks[:end] - b.r = 0 - } - } - return ntotal, nil -} - -func (b *http2dataBuffer) bytesFromFirstChunk() []byte { - if len(b.chunks) == 1 { - return b.chunks[0][b.r:b.w] - } - return b.chunks[0][b.r:] -} - -// Len returns the number of bytes of the unread portion of the buffer. -func (b *http2dataBuffer) Len() int { - return b.size -} - -// Write appends p to the buffer. -func (b *http2dataBuffer) Write(p []byte) (int, error) { - ntotal := len(p) - for len(p) > 0 { - // If the last chunk is empty, allocate a new chunk. Try to allocate - // enough to fully copy p plus any additional bytes we expect to - // receive. However, this may allocate less than len(p). - want := int64(len(p)) - if b.expected > want { - want = b.expected - } - chunk := b.lastChunkOrAlloc(want) - n := copy(chunk[b.w:], p) - p = p[n:] - b.w += n - b.size += n - b.expected -= int64(n) - } - return ntotal, nil -} - -func (b *http2dataBuffer) lastChunkOrAlloc(want int64) []byte { - if len(b.chunks) != 0 { - last := b.chunks[len(b.chunks)-1] - if b.w < len(last) { - return last - } - } - chunk := http2getDataBufferChunk(want) - b.chunks = append(b.chunks, chunk) - b.w = 0 - return chunk -} - // An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. type http2ErrCode uint32 From a802385bf70385daf5acc278401ea95666a60270 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:47:37 +0800 Subject: [PATCH 316/843] databuffer.go --> h2_databuffer.go --- databuffer.go => h2_databuffer.go | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename databuffer.go => h2_databuffer.go (100%) diff --git a/databuffer.go b/h2_databuffer.go similarity index 100% rename from databuffer.go rename to h2_databuffer.go From 325e979564e209cf0118d22f43ff71f1e7a38635 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:49:57 +0800 Subject: [PATCH 317/843] h2_databuffer_test.go --- h2_databuffer_test.go | 155 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 h2_databuffer_test.go diff --git a/h2_databuffer_test.go b/h2_databuffer_test.go new file mode 100644 index 00000000..a9d4f09b --- /dev/null +++ b/h2_databuffer_test.go @@ -0,0 +1,155 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "bytes" + "fmt" + "reflect" + "testing" +) + +func fmtDataChunk(chunk []byte) string { + out := "" + var last byte + var count int + for _, c := range chunk { + if c != last { + if count > 0 { + out += fmt.Sprintf(" x %d ", count) + count = 0 + } + out += string([]byte{c}) + last = c + } + count++ + } + if count > 0 { + out += fmt.Sprintf(" x %d", count) + } + return out +} + +func fmtDataChunks(chunks [][]byte) string { + var out string + for _, chunk := range chunks { + out += fmt.Sprintf("{%q}", fmtDataChunk(chunk)) + } + return out +} + +func testDataBuffer(t *testing.T, wantBytes []byte, setup func(t *testing.T) *http2dataBuffer) { + // Run setup, then read the remaining bytes from the dataBuffer and check + // that they match wantBytes. We use different read sizes to check corner + // cases in Read. + for _, readSize := range []int{1, 2, 1 * 1024, 32 * 1024} { + t.Run(fmt.Sprintf("ReadSize=%d", readSize), func(t *testing.T) { + b := setup(t) + buf := make([]byte, readSize) + var gotRead bytes.Buffer + for { + n, err := b.Read(buf) + gotRead.Write(buf[:n]) + if err == http2errReadEmpty { + break + } + if err != nil { + t.Fatalf("error after %v bytes: %v", gotRead.Len(), err) + } + } + if got, want := gotRead.Bytes(), wantBytes; !bytes.Equal(got, want) { + t.Errorf("FinalRead=%q, want %q", fmtDataChunk(got), fmtDataChunk(want)) + } + }) + } +} + +func TestDataBufferAllocation(t *testing.T) { + writes := [][]byte{ + bytes.Repeat([]byte("a"), 1*1024-1), + []byte("a"), + bytes.Repeat([]byte("b"), 4*1024-1), + []byte("b"), + bytes.Repeat([]byte("c"), 8*1024-1), + []byte("c"), + bytes.Repeat([]byte("d"), 16*1024-1), + []byte("d"), + bytes.Repeat([]byte("e"), 32*1024), + } + var wantRead bytes.Buffer + for _, p := range writes { + wantRead.Write(p) + } + + testDataBuffer(t, wantRead.Bytes(), func(t *testing.T) *http2dataBuffer { + b := &http2dataBuffer{} + for _, p := range writes { + if n, err := b.Write(p); n != len(p) || err != nil { + t.Fatalf("Write(%q x %d)=%v,%v want %v,nil", p[:1], len(p), n, err, len(p)) + } + } + want := [][]byte{ + bytes.Repeat([]byte("a"), 1*1024), + bytes.Repeat([]byte("b"), 4*1024), + bytes.Repeat([]byte("c"), 8*1024), + bytes.Repeat([]byte("d"), 16*1024), + bytes.Repeat([]byte("e"), 16*1024), + bytes.Repeat([]byte("e"), 16*1024), + } + if !reflect.DeepEqual(b.chunks, want) { + t.Errorf("dataBuffer.chunks\ngot: %s\nwant: %s", fmtDataChunks(b.chunks), fmtDataChunks(want)) + } + return b + }) +} + +func TestDataBufferAllocationWithExpected(t *testing.T) { + writes := [][]byte{ + bytes.Repeat([]byte("a"), 1*1024), // allocates 16KB + bytes.Repeat([]byte("b"), 14*1024), + bytes.Repeat([]byte("c"), 15*1024), // allocates 16KB more + bytes.Repeat([]byte("d"), 2*1024), + bytes.Repeat([]byte("e"), 1*1024), // overflows 32KB expectation, allocates just 1KB + } + var wantRead bytes.Buffer + for _, p := range writes { + wantRead.Write(p) + } + + testDataBuffer(t, wantRead.Bytes(), func(t *testing.T) *http2dataBuffer { + b := &http2dataBuffer{expected: 32 * 1024} + for _, p := range writes { + if n, err := b.Write(p); n != len(p) || err != nil { + t.Fatalf("Write(%q x %d)=%v,%v want %v,nil", p[:1], len(p), n, err, len(p)) + } + } + want := [][]byte{ + append(bytes.Repeat([]byte("a"), 1*1024), append(bytes.Repeat([]byte("b"), 14*1024), bytes.Repeat([]byte("c"), 1*1024)...)...), + append(bytes.Repeat([]byte("c"), 14*1024), bytes.Repeat([]byte("d"), 2*1024)...), + bytes.Repeat([]byte("e"), 1*1024), + } + if !reflect.DeepEqual(b.chunks, want) { + t.Errorf("dataBuffer.chunks\ngot: %s\nwant: %s", fmtDataChunks(b.chunks), fmtDataChunks(want)) + } + return b + }) +} + +func TestDataBufferWriteAfterPartialRead(t *testing.T) { + testDataBuffer(t, []byte("cdxyz"), func(t *testing.T) *http2dataBuffer { + b := &http2dataBuffer{} + if n, err := b.Write([]byte("abcd")); n != 4 || err != nil { + t.Fatalf("Write(\"abcd\")=%v,%v want 4,nil", n, err) + } + p := make([]byte, 2) + if n, err := b.Read(p); n != 2 || err != nil || !bytes.Equal(p, []byte("ab")) { + t.Fatalf("Read()=%q,%v,%v want \"ab\",2,nil", p, n, err) + } + if n, err := b.Write([]byte("xyz")); n != 3 || err != nil { + t.Fatalf("Write(\"xyz\")=%v,%v want 3,nil", n, err) + } + return b + }) +} From 9837c4377cafb0f0449afc6056a48306d01ddf78 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:52:52 +0800 Subject: [PATCH 318/843] extract h2_errors.go --- h2_bundle.go | 137 ------------------------------------------------ h2_errors.go | 143 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+), 137 deletions(-) create mode 100644 h2_errors.go diff --git a/h2_bundle.go b/h2_bundle.go index c50907a8..053f757b 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -53,143 +53,6 @@ import ( "golang.org/x/net/idna" ) -// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. -type http2ErrCode uint32 - -const ( - http2ErrCodeNo http2ErrCode = 0x0 - http2ErrCodeProtocol http2ErrCode = 0x1 - http2ErrCodeInternal http2ErrCode = 0x2 - http2ErrCodeFlowControl http2ErrCode = 0x3 - http2ErrCodeSettingsTimeout http2ErrCode = 0x4 - http2ErrCodeStreamClosed http2ErrCode = 0x5 - http2ErrCodeFrameSize http2ErrCode = 0x6 - http2ErrCodeRefusedStream http2ErrCode = 0x7 - http2ErrCodeCancel http2ErrCode = 0x8 - http2ErrCodeCompression http2ErrCode = 0x9 - http2ErrCodeConnect http2ErrCode = 0xa - http2ErrCodeEnhanceYourCalm http2ErrCode = 0xb - http2ErrCodeInadequateSecurity http2ErrCode = 0xc - http2ErrCodeHTTP11Required http2ErrCode = 0xd -) - -var http2errCodeName = map[http2ErrCode]string{ - http2ErrCodeNo: "NO_ERROR", - http2ErrCodeProtocol: "PROTOCOL_ERROR", - http2ErrCodeInternal: "INTERNAL_ERROR", - http2ErrCodeFlowControl: "FLOW_CONTROL_ERROR", - http2ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT", - http2ErrCodeStreamClosed: "STREAM_CLOSED", - http2ErrCodeFrameSize: "FRAME_SIZE_ERROR", - http2ErrCodeRefusedStream: "REFUSED_STREAM", - http2ErrCodeCancel: "CANCEL", - http2ErrCodeCompression: "COMPRESSION_ERROR", - http2ErrCodeConnect: "CONNECT_ERROR", - http2ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM", - http2ErrCodeInadequateSecurity: "INADEQUATE_SECURITY", - http2ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED", -} - -func (e http2ErrCode) String() string { - if s, ok := http2errCodeName[e]; ok { - return s - } - return fmt.Sprintf("unknown error code 0x%x", uint32(e)) -} - -func (e http2ErrCode) stringToken() string { - if s, ok := http2errCodeName[e]; ok { - return s - } - return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e)) -} - -// ConnectionError is an error that results in the termination of the -// entire connection. -type http2ConnectionError http2ErrCode - -func (e http2ConnectionError) Error() string { - return fmt.Sprintf("connection error: %s", http2ErrCode(e)) -} - -// StreamError is an error that only affects one stream within an -// HTTP/2 connection. -type http2StreamError struct { - StreamID uint32 - Code http2ErrCode - Cause error // optional additional detail -} - -// errFromPeer is a sentinel error value for StreamError.Cause to -// indicate that the StreamError was sent from the peer over the wire -// and wasn't locally generated in the Transport. -var http2errFromPeer = errors.New("received from peer") - -func http2streamError(id uint32, code http2ErrCode) http2StreamError { - return http2StreamError{StreamID: id, Code: code} -} - -func (e http2StreamError) Error() string { - if e.Cause != nil { - return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause) - } - return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) -} - -// 6.9.1 The Flow Control Window -// "If a sender receives a WINDOW_UPDATE that causes a flow control -// window to exceed this maximum it MUST terminate either the stream -// or the connection, as appropriate. For streams, [...]; for the -// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code." -type http2goAwayFlowError struct{} - -func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" } - -// connError represents an HTTP/2 ConnectionError error code, along -// with a string (for debugging) explaining why. -// -// Errors of this type are only returned by the frame parser functions -// and converted into ConnectionError(Code), after stashing away -// the Reason into the Framer's errDetail field, accessible via -// the (*Framer).ErrorDetail method. -type http2connError struct { - Code http2ErrCode // the ConnectionError error code - Reason string // additional reason -} - -func (e http2connError) Error() string { - return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason) -} - -type http2pseudoHeaderError string - -func (e http2pseudoHeaderError) Error() string { - return fmt.Sprintf("invalid pseudo-header %q", string(e)) -} - -type http2duplicatePseudoHeaderError string - -func (e http2duplicatePseudoHeaderError) Error() string { - return fmt.Sprintf("duplicate pseudo-header %q", string(e)) -} - -type http2headerFieldNameError string - -func (e http2headerFieldNameError) Error() string { - return fmt.Sprintf("invalid header field name %q", string(e)) -} - -type http2headerFieldValueError string - -func (e http2headerFieldValueError) Error() string { - return fmt.Sprintf("invalid header field value %q", string(e)) -} - -var ( - http2errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers") - http2errPseudoAfterRegular = errors.New("pseudo header field after regular") -) - // flow is the flow control window's size. type http2flow struct { _ http2incomparable diff --git a/h2_errors.go b/h2_errors.go new file mode 100644 index 00000000..f226cfba --- /dev/null +++ b/h2_errors.go @@ -0,0 +1,143 @@ +package req + +import ( + "errors" + "fmt" +) + +// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. +type http2ErrCode uint32 + +const ( + http2ErrCodeNo http2ErrCode = 0x0 + http2ErrCodeProtocol http2ErrCode = 0x1 + http2ErrCodeInternal http2ErrCode = 0x2 + http2ErrCodeFlowControl http2ErrCode = 0x3 + http2ErrCodeSettingsTimeout http2ErrCode = 0x4 + http2ErrCodeStreamClosed http2ErrCode = 0x5 + http2ErrCodeFrameSize http2ErrCode = 0x6 + http2ErrCodeRefusedStream http2ErrCode = 0x7 + http2ErrCodeCancel http2ErrCode = 0x8 + http2ErrCodeCompression http2ErrCode = 0x9 + http2ErrCodeConnect http2ErrCode = 0xa + http2ErrCodeEnhanceYourCalm http2ErrCode = 0xb + http2ErrCodeInadequateSecurity http2ErrCode = 0xc + http2ErrCodeHTTP11Required http2ErrCode = 0xd +) + +var http2errCodeName = map[http2ErrCode]string{ + http2ErrCodeNo: "NO_ERROR", + http2ErrCodeProtocol: "PROTOCOL_ERROR", + http2ErrCodeInternal: "INTERNAL_ERROR", + http2ErrCodeFlowControl: "FLOW_CONTROL_ERROR", + http2ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT", + http2ErrCodeStreamClosed: "STREAM_CLOSED", + http2ErrCodeFrameSize: "FRAME_SIZE_ERROR", + http2ErrCodeRefusedStream: "REFUSED_STREAM", + http2ErrCodeCancel: "CANCEL", + http2ErrCodeCompression: "COMPRESSION_ERROR", + http2ErrCodeConnect: "CONNECT_ERROR", + http2ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM", + http2ErrCodeInadequateSecurity: "INADEQUATE_SECURITY", + http2ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED", +} + +func (e http2ErrCode) String() string { + if s, ok := http2errCodeName[e]; ok { + return s + } + return fmt.Sprintf("unknown error code 0x%x", uint32(e)) +} + +func (e http2ErrCode) stringToken() string { + if s, ok := http2errCodeName[e]; ok { + return s + } + return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e)) +} + +// ConnectionError is an error that results in the termination of the +// entire connection. +type http2ConnectionError http2ErrCode + +func (e http2ConnectionError) Error() string { + return fmt.Sprintf("connection error: %s", http2ErrCode(e)) +} + +// StreamError is an error that only affects one stream within an +// HTTP/2 connection. +type http2StreamError struct { + StreamID uint32 + Code http2ErrCode + Cause error // optional additional detail +} + +// errFromPeer is a sentinel error value for StreamError.Cause to +// indicate that the StreamError was sent from the peer over the wire +// and wasn't locally generated in the Transport. +var http2errFromPeer = errors.New("received from peer") + +func http2streamError(id uint32, code http2ErrCode) http2StreamError { + return http2StreamError{StreamID: id, Code: code} +} + +func (e http2StreamError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause) + } + return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) +} + +// 6.9.1 The Flow Control Window +// "If a sender receives a WINDOW_UPDATE that causes a flow control +// window to exceed this maximum it MUST terminate either the stream +// or the connection, as appropriate. For streams, [...]; for the +// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code." +type http2goAwayFlowError struct{} + +func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" } + +// connError represents an HTTP/2 ConnectionError error code, along +// with a string (for debugging) explaining why. +// +// Errors of this type are only returned by the frame parser functions +// and converted into ConnectionError(Code), after stashing away +// the Reason into the Framer's errDetail field, accessible via +// the (*Framer).ErrorDetail method. +type http2connError struct { + Code http2ErrCode // the ConnectionError error code + Reason string // additional reason +} + +func (e http2connError) Error() string { + return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason) +} + +type http2pseudoHeaderError string + +func (e http2pseudoHeaderError) Error() string { + return fmt.Sprintf("invalid pseudo-header %q", string(e)) +} + +type http2duplicatePseudoHeaderError string + +func (e http2duplicatePseudoHeaderError) Error() string { + return fmt.Sprintf("duplicate pseudo-header %q", string(e)) +} + +type http2headerFieldNameError string + +func (e http2headerFieldNameError) Error() string { + return fmt.Sprintf("invalid header field name %q", string(e)) +} + +type http2headerFieldValueError string + +func (e http2headerFieldValueError) Error() string { + return fmt.Sprintf("invalid header field value %q", string(e)) +} + +var ( + http2errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers") + http2errPseudoAfterRegular = errors.New("pseudo header field after regular") +) From 4db0a30c189ef6c6ebaae0203c59ee56bf557784 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:55:17 +0800 Subject: [PATCH 319/843] add headers to source file --- h2_client_conn_pool.go | 4 ++++ h2_databuffer.go | 4 ++++ h2_errors.go | 4 ++++ h2_pipe.go | 4 ++++ 4 files changed, 16 insertions(+) diff --git a/h2_client_conn_pool.go b/h2_client_conn_pool.go index 0b649748..8c2bcaa9 100644 --- a/h2_client_conn_pool.go +++ b/h2_client_conn_pool.go @@ -1,3 +1,7 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package req import ( diff --git a/h2_databuffer.go b/h2_databuffer.go index d97452f4..26f2f877 100644 --- a/h2_databuffer.go +++ b/h2_databuffer.go @@ -1,3 +1,7 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package req import ( diff --git a/h2_errors.go b/h2_errors.go index f226cfba..a6fbe07e 100644 --- a/h2_errors.go +++ b/h2_errors.go @@ -1,3 +1,7 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package req import ( diff --git a/h2_pipe.go b/h2_pipe.go index 95c93769..0777847b 100644 --- a/h2_pipe.go +++ b/h2_pipe.go @@ -1,3 +1,7 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package req import ( From 21511b8944ebc8c489c733b56ad6a5d2ee9436e1 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:57:17 +0800 Subject: [PATCH 320/843] add h2_errors_test.go --- h2_errors_test.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 h2_errors_test.go diff --git a/h2_errors_test.go b/h2_errors_test.go new file mode 100644 index 00000000..cc3970b2 --- /dev/null +++ b/h2_errors_test.go @@ -0,0 +1,24 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import "testing" + +func TestErrCodeString(t *testing.T) { + tests := []struct { + err http2ErrCode + want string + }{ + {http2ErrCodeProtocol, "PROTOCOL_ERROR"}, + {0xd, "HTTP_1_1_REQUIRED"}, + {0xf, "unknown error code 0xf"}, + } + for i, tt := range tests { + got := tt.err.String() + if got != tt.want { + t.Errorf("%d. Error = %q; want %q", i, got, tt.want) + } + } +} From 47843e8f1361f99da655c0a5376f40fc656df87e Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 19:58:57 +0800 Subject: [PATCH 321/843] extract h2_flow.go --- h2_bundle.go | 45 --------------------------------------------- h2_flow.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 45 deletions(-) create mode 100644 h2_flow.go diff --git a/h2_bundle.go b/h2_bundle.go index 053f757b..31bc0ed2 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -53,51 +53,6 @@ import ( "golang.org/x/net/idna" ) -// flow is the flow control window's size. -type http2flow struct { - _ http2incomparable - - // n is the number of DATA bytes we're allowed to send. - // A flow is kept both on a conn and a per-stream. - n int32 - - // conn points to the shared connection-level flow that is - // shared by all streams on that conn. It is nil for the flow - // that's on the conn directly. - conn *http2flow -} - -func (f *http2flow) setConnFlow(cf *http2flow) { f.conn = cf } - -func (f *http2flow) available() int32 { - n := f.n - if f.conn != nil && f.conn.n < n { - n = f.conn.n - } - return n -} - -func (f *http2flow) take(n int32) { - if n > f.available() { - panic("internal error: took too much") - } - f.n -= n - if f.conn != nil { - f.conn.n -= n - } -} - -// add adds n bytes (positive or negative) to the flow control window. -// It returns false if the sum would exceed 2^31-1. -func (f *http2flow) add(n int32) bool { - sum := f.n + n - if (sum > n) == (f.n > 0) { - f.n = sum - return true - } - return false -} - const http2frameHeaderLen = 9 var http2padZeros = make([]byte, 255) // zeros for padding diff --git a/h2_flow.go b/h2_flow.go new file mode 100644 index 00000000..a64d3e12 --- /dev/null +++ b/h2_flow.go @@ -0,0 +1,46 @@ +package req + +// flow is the flow control window's size. +type http2flow struct { + _ http2incomparable + + // n is the number of DATA bytes we're allowed to send. + // A flow is kept both on a conn and a per-stream. + n int32 + + // conn points to the shared connection-level flow that is + // shared by all streams on that conn. It is nil for the flow + // that's on the conn directly. + conn *http2flow +} + +func (f *http2flow) setConnFlow(cf *http2flow) { f.conn = cf } + +func (f *http2flow) available() int32 { + n := f.n + if f.conn != nil && f.conn.n < n { + n = f.conn.n + } + return n +} + +func (f *http2flow) take(n int32) { + if n > f.available() { + panic("internal error: took too much") + } + f.n -= n + if f.conn != nil { + f.conn.n -= n + } +} + +// add adds n bytes (positive or negative) to the flow control window. +// It returns false if the sum would exceed 2^31-1. +func (f *http2flow) add(n int32) bool { + sum := f.n + n + if (sum > n) == (f.n > 0) { + f.n = sum + return true + } + return false +} From 2511c97d6d44ded16ce524bf320ad0cc116d9304 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 20:00:22 +0800 Subject: [PATCH 322/843] add h2_flow_test.go --- h2_flow_test.go | 87 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 h2_flow_test.go diff --git a/h2_flow_test.go b/h2_flow_test.go new file mode 100644 index 00000000..2a229be6 --- /dev/null +++ b/h2_flow_test.go @@ -0,0 +1,87 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import "testing" + +func TestFlow(t *testing.T) { + var st http2flow + var conn http2flow + st.add(3) + conn.add(2) + + if got, want := st.available(), int32(3); got != want { + t.Errorf("available = %d; want %d", got, want) + } + st.setConnFlow(&conn) + if got, want := st.available(), int32(2); got != want { + t.Errorf("after parent setup, available = %d; want %d", got, want) + } + + st.take(2) + if got, want := conn.available(), int32(0); got != want { + t.Errorf("after taking 2, conn = %d; want %d", got, want) + } + if got, want := st.available(), int32(0); got != want { + t.Errorf("after taking 2, stream = %d; want %d", got, want) + } +} + +func TestFlowAdd(t *testing.T) { + var f http2flow + if !f.add(1) { + t.Fatal("failed to add 1") + } + if !f.add(-1) { + t.Fatal("failed to add -1") + } + if got, want := f.available(), int32(0); got != want { + t.Fatalf("size = %d; want %d", got, want) + } + if !f.add(1<<31 - 1) { + t.Fatal("failed to add 2^31-1") + } + if got, want := f.available(), int32(1<<31-1); got != want { + t.Fatalf("size = %d; want %d", got, want) + } + if f.add(1) { + t.Fatal("adding 1 to max shouldn't be allowed") + } +} + +func TestFlowAddOverflow(t *testing.T) { + var f http2flow + if !f.add(0) { + t.Fatal("failed to add 0") + } + if !f.add(-1) { + t.Fatal("failed to add -1") + } + if !f.add(0) { + t.Fatal("failed to add 0") + } + if !f.add(1) { + t.Fatal("failed to add 1") + } + if !f.add(1) { + t.Fatal("failed to add 1") + } + if !f.add(0) { + t.Fatal("failed to add 0") + } + if !f.add(-3) { + t.Fatal("failed to add -3") + } + if got, want := f.available(), int32(-2); got != want { + t.Fatalf("size = %d; want %d", got, want) + } + if !f.add(1<<31 - 1) { + t.Fatal("failed to add 2^31-1") + } + if got, want := f.available(), int32(1+-3+(1<<31-1)); got != want { + t.Fatalf("size = %d; want %d", got, want) + } + +} From d009191915b2ce1f8e77ab5cd43b06c4f7741bec Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 20:00:59 +0800 Subject: [PATCH 323/843] add header to h2_flow.go --- h2_flow.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/h2_flow.go b/h2_flow.go index a64d3e12..3259a197 100644 --- a/h2_flow.go +++ b/h2_flow.go @@ -1,3 +1,7 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package req // flow is the flow control window's size. From 71974b4b77dda392f52a4b3da01475d73a8dbc20 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 20:05:28 +0800 Subject: [PATCH 324/843] extract h2_frame.go --- h2_bundle.go | 1662 ------------------------------------------------- h2_frame.go | 1679 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1679 insertions(+), 1662 deletions(-) create mode 100644 h2_frame.go diff --git a/h2_bundle.go b/h2_bundle.go index 31bc0ed2..584ed1c7 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -26,7 +26,6 @@ import ( "context" "crypto/rand" "crypto/tls" - "encoding/binary" "errors" "fmt" "github.com/imroc/req/v3/internal/ascii" @@ -53,1667 +52,6 @@ import ( "golang.org/x/net/idna" ) -const http2frameHeaderLen = 9 - -var http2padZeros = make([]byte, 255) // zeros for padding - -// A FrameType is a registered frame type as defined in -// http://http2.github.io/http2-spec/#rfc.section.11.2 -type http2FrameType uint8 - -const ( - http2FrameData http2FrameType = 0x0 - http2FrameHeaders http2FrameType = 0x1 - http2FramePriority http2FrameType = 0x2 - http2FrameRSTStream http2FrameType = 0x3 - http2FrameSettings http2FrameType = 0x4 - http2FramePushPromise http2FrameType = 0x5 - http2FramePing http2FrameType = 0x6 - http2FrameGoAway http2FrameType = 0x7 - http2FrameWindowUpdate http2FrameType = 0x8 - http2FrameContinuation http2FrameType = 0x9 -) - -var http2frameName = map[http2FrameType]string{ - http2FrameData: "DATA", - http2FrameHeaders: "HEADERS", - http2FramePriority: "PRIORITY", - http2FrameRSTStream: "RST_STREAM", - http2FrameSettings: "SETTINGS", - http2FramePushPromise: "PUSH_PROMISE", - http2FramePing: "PING", - http2FrameGoAway: "GOAWAY", - http2FrameWindowUpdate: "WINDOW_UPDATE", - http2FrameContinuation: "CONTINUATION", -} - -func (t http2FrameType) String() string { - if s, ok := http2frameName[t]; ok { - return s - } - return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t)) -} - -// Flags is a bitmask of HTTP/2 flags. -// The meaning of flags varies depending on the frame type. -type http2Flags uint8 - -// Has reports whether f contains all (0 or more) flags in v. -func (f http2Flags) Has(v http2Flags) bool { - return (f & v) == v -} - -// Frame-specific FrameHeader flag bits. -const ( - // Data Frame - http2FlagDataEndStream http2Flags = 0x1 - http2FlagDataPadded http2Flags = 0x8 - - // Headers Frame - http2FlagHeadersEndStream http2Flags = 0x1 - http2FlagHeadersEndHeaders http2Flags = 0x4 - http2FlagHeadersPadded http2Flags = 0x8 - http2FlagHeadersPriority http2Flags = 0x20 - - // Settings Frame - http2FlagSettingsAck http2Flags = 0x1 - - // Ping Frame - http2FlagPingAck http2Flags = 0x1 - - // Continuation Frame - http2FlagContinuationEndHeaders http2Flags = 0x4 - - http2FlagPushPromiseEndHeaders http2Flags = 0x4 - http2FlagPushPromisePadded http2Flags = 0x8 -) - -var http2flagName = map[http2FrameType]map[http2Flags]string{ - http2FrameData: { - http2FlagDataEndStream: "END_STREAM", - http2FlagDataPadded: "PADDED", - }, - http2FrameHeaders: { - http2FlagHeadersEndStream: "END_STREAM", - http2FlagHeadersEndHeaders: "END_HEADERS", - http2FlagHeadersPadded: "PADDED", - http2FlagHeadersPriority: "PRIORITY", - }, - http2FrameSettings: { - http2FlagSettingsAck: "ACK", - }, - http2FramePing: { - http2FlagPingAck: "ACK", - }, - http2FrameContinuation: { - http2FlagContinuationEndHeaders: "END_HEADERS", - }, - http2FramePushPromise: { - http2FlagPushPromiseEndHeaders: "END_HEADERS", - http2FlagPushPromisePadded: "PADDED", - }, -} - -// a frameParser parses a frame given its FrameHeader and payload -// bytes. The length of payload will always equal fh.Length (which -// might be 0). -type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) - -var http2frameParsers = map[http2FrameType]http2frameParser{ - http2FrameData: http2parseDataFrame, - http2FrameHeaders: http2parseHeadersFrame, - http2FramePriority: http2parsePriorityFrame, - http2FrameRSTStream: http2parseRSTStreamFrame, - http2FrameSettings: http2parseSettingsFrame, - http2FramePushPromise: http2parsePushPromise, - http2FramePing: http2parsePingFrame, - http2FrameGoAway: http2parseGoAwayFrame, - http2FrameWindowUpdate: http2parseWindowUpdateFrame, - http2FrameContinuation: http2parseContinuationFrame, -} - -func http2typeFrameParser(t http2FrameType) http2frameParser { - if f := http2frameParsers[t]; f != nil { - return f - } - return http2parseUnknownFrame -} - -// A FrameHeader is the 9 byte header of all HTTP/2 frames. -// -// See http://http2.github.io/http2-spec/#FrameHeader -type http2FrameHeader struct { - valid bool // caller can access []byte fields in the Frame - - // Type is the 1 byte frame type. There are ten standard frame - // types, but extension frame types may be written by WriteRawFrame - // and will be returned by ReadFrame (as UnknownFrame). - Type http2FrameType - - // Flags are the 1 byte of 8 potential bit flags per frame. - // They are specific to the frame type. - Flags http2Flags - - // Length is the length of the frame, not including the 9 byte header. - // The maximum size is one byte less than 16MB (uint24), but only - // frames up to 16KB are allowed without peer agreement. - Length uint32 - - // StreamID is which stream this frame is for. Certain frames - // are not stream-specific, in which case this field is 0. - StreamID uint32 -} - -// Header returns h. It exists so FrameHeaders can be embedded in other -// specific frame types and implement the Frame interface. -func (h http2FrameHeader) Header() http2FrameHeader { return h } - -func (h http2FrameHeader) String() string { - var buf bytes.Buffer - buf.WriteString("[FrameHeader ") - h.writeDebug(&buf) - buf.WriteByte(']') - return buf.String() -} - -func (h http2FrameHeader) writeDebug(buf *bytes.Buffer) { - buf.WriteString(h.Type.String()) - if h.Flags != 0 { - buf.WriteString(" flags=") - set := 0 - for i := uint8(0); i < 8; i++ { - if h.Flags&(1< 1 { - buf.WriteByte('|') - } - name := http2flagName[h.Type][http2Flags(1<>24), - byte(streamID>>16), - byte(streamID>>8), - byte(streamID)) -} - -func (f *http2Framer) endWrite() error { - // Now that we know the final size, fill in the FrameHeader in - // the space previously reserved for it. Abuse append. - length := len(f.wbuf) - http2frameHeaderLen - if length >= (1 << 24) { - return http2ErrFrameTooLarge - } - _ = append(f.wbuf[:0], - byte(length>>16), - byte(length>>8), - byte(length)) - if f.logWrites { - f.logWrite() - } - - n, err := f.w.Write(f.wbuf) - if err == nil && n != len(f.wbuf) { - err = io.ErrShortWrite - } - return err -} - -func (f *http2Framer) logWrite() { - if f.debugFramer == nil { - f.debugFramerBuf = new(bytes.Buffer) - f.debugFramer = http2NewFramer(nil, f.debugFramerBuf) - f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below - // Let us read anything, even if we accidentally wrote it - // in the wrong order: - f.debugFramer.AllowIllegalReads = true - } - f.debugFramerBuf.Write(f.wbuf) - fr, err := f.debugFramer.ReadFrame() - if err != nil { - f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f) - return - } - f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) -} - -func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) } - -func (f *http2Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) } - -func (f *http2Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) } - -func (f *http2Framer) writeUint32(v uint32) { - f.wbuf = append(f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) -} - -const ( - http2minMaxFrameSize = 1 << 14 - http2maxFrameSize = 1<<24 - 1 -) - -// SetReuseFrames allows the Framer to reuse Frames. -// If called on a Framer, Frames returned by calls to ReadFrame are only -// valid until the next call to ReadFrame. -func (fr *http2Framer) SetReuseFrames() { - if fr.frameCache != nil { - return - } - fr.frameCache = &http2frameCache{} -} - -type http2frameCache struct { - dataFrame http2DataFrame -} - -func (fc *http2frameCache) getDataFrame() *http2DataFrame { - if fc == nil { - return &http2DataFrame{} - } - return &fc.dataFrame -} - -// NewFramer returns a Framer that writes frames to w and reads them from r. -func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { - fr := &http2Framer{ - w: w, - r: r, - countError: func(string) {}, - logReads: http2logFrameReads, - logWrites: http2logFrameWrites, - debugReadLoggerf: log.Printf, - debugWriteLoggerf: log.Printf, - } - fr.getReadBuf = func(size uint32) []byte { - if cap(fr.readBuf) >= int(size) { - return fr.readBuf[:size] - } - fr.readBuf = make([]byte, size) - return fr.readBuf - } - fr.SetMaxReadFrameSize(http2maxFrameSize) - return fr -} - -// SetMaxReadFrameSize sets the maximum size of a frame -// that will be read by a subsequent call to ReadFrame. -// It is the caller's responsibility to advertise this -// limit with a SETTINGS frame. -func (fr *http2Framer) SetMaxReadFrameSize(v uint32) { - if v > http2maxFrameSize { - v = http2maxFrameSize - } - fr.maxReadSize = v -} - -// ErrorDetail returns a more detailed error of the last error -// returned by Framer.ReadFrame. For instance, if ReadFrame -// returns a StreamError with code PROTOCOL_ERROR, ErrorDetail -// will say exactly what was invalid. ErrorDetail is not guaranteed -// to return a non-nil value and like the rest of the http2 package, -// its return value is not protected by an API compatibility promise. -// ErrorDetail is reset after the next call to ReadFrame. -func (fr *http2Framer) ErrorDetail() error { - return fr.errDetail -} - -// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer -// sends a frame that is larger than declared with SetMaxReadFrameSize. -var http2ErrFrameTooLarge = errors.New("http2: frame too large") - -// terminalReadFrameError reports whether err is an unrecoverable -// error from ReadFrame and no other frames should be read. -func http2terminalReadFrameError(err error) bool { - if _, ok := err.(http2StreamError); ok { - return false - } - return err != nil -} - -// ReadFrame reads a single frame. The returned Frame is only valid -// until the next call to ReadFrame. -// -// If the frame is larger than previously set with SetMaxReadFrameSize, the -// returned error is ErrFrameTooLarge. Other errors may be of type -// ConnectionError, StreamError, or anything else from the underlying -// reader. -func (fr *http2Framer) ReadFrame() (http2Frame, error) { - fr.errDetail = nil - if fr.lastFrame != nil { - fr.lastFrame.invalidate() - } - fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r) - if err != nil { - return nil, err - } - if fh.Length > fr.maxReadSize { - return nil, http2ErrFrameTooLarge - } - payload := fr.getReadBuf(fh.Length) - if _, err := io.ReadFull(fr.r, payload); err != nil { - return nil, err - } - f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload) - if err != nil { - if ce, ok := err.(http2connError); ok { - return nil, fr.connError(ce.Code, ce.Reason) - } - return nil, err - } - if err := fr.checkFrameOrder(f); err != nil { - return nil, err - } - if fr.logReads { - fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) - } - if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { - dumps := getDumpers(fr.cc.t.t1.dump, fr.cc.currentRequest.Context()) - if len(dumps) > 0 { - dd := []*dumper{} - for _, dump := range dumps { - if dump.ResponseHeader { - dd = append(dd, dump) - } - } - dumps = dd - } - hr, err := fr.readMetaFrame(f.(*http2HeadersFrame), dumps) - if err == nil && len(dumps) > 0 { - for _, dump := range dumps { - dump.dump([]byte("\r\n")) - } - } - return hr, err - } - return f, nil -} - -// connError returns ConnectionError(code) but first -// stashes away a public reason to the caller can optionally relay it -// to the peer before hanging up on them. This might help others debug -// their implementations. -func (fr *http2Framer) connError(code http2ErrCode, reason string) error { - fr.errDetail = errors.New(reason) - return http2ConnectionError(code) -} - -// checkFrameOrder reports an error if f is an invalid frame to return -// next from ReadFrame. Mostly it checks whether HEADERS and -// CONTINUATION frames are contiguous. -func (fr *http2Framer) checkFrameOrder(f http2Frame) error { - last := fr.lastFrame - fr.lastFrame = f - if fr.AllowIllegalReads { - return nil - } - - fh := f.Header() - if fr.lastHeaderStream != 0 { - if fh.Type != http2FrameContinuation { - return fr.connError(http2ErrCodeProtocol, - fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d", - fh.Type, fh.StreamID, - last.Header().Type, fr.lastHeaderStream)) - } - if fh.StreamID != fr.lastHeaderStream { - return fr.connError(http2ErrCodeProtocol, - fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d", - fh.StreamID, fr.lastHeaderStream)) - } - } else if fh.Type == http2FrameContinuation { - return fr.connError(http2ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID)) - } - - switch fh.Type { - case http2FrameHeaders, http2FrameContinuation: - if fh.Flags.Has(http2FlagHeadersEndHeaders) { - fr.lastHeaderStream = 0 - } else { - fr.lastHeaderStream = fh.StreamID - } - } - - return nil -} - -// A DataFrame conveys arbitrary, variable-length sequences of octets -// associated with a stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.1 -type http2DataFrame struct { - http2FrameHeader - data []byte -} - -func (f *http2DataFrame) StreamEnded() bool { - return f.http2FrameHeader.Flags.Has(http2FlagDataEndStream) -} - -// Data returns the frame's data octets, not including any padding -// size byte or padding suffix bytes. -// The caller must not retain the returned memory past the next -// call to ReadFrame. -func (f *http2DataFrame) Data() []byte { - f.checkValid() - return f.data -} - -func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { - if fh.StreamID == 0 { - // DATA frames MUST be associated with a stream. If a - // DATA frame is received whose stream identifier - // field is 0x0, the recipient MUST respond with a - // connection error (Section 5.4.1) of type - // PROTOCOL_ERROR. - countError("frame_data_stream_0") - return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"} - } - f := fc.getDataFrame() - f.http2FrameHeader = fh - - var padSize byte - if fh.Flags.Has(http2FlagDataPadded) { - var err error - payload, padSize, err = http2readByte(payload) - if err != nil { - countError("frame_data_pad_byte_short") - return nil, err - } - } - if int(padSize) > len(payload) { - // If the length of the padding is greater than the - // length of the frame payload, the recipient MUST - // treat this as a connection error. - // Filed: https://github.com/http2/http2-spec/issues/610 - countError("frame_data_pad_too_big") - return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"} - } - f.data = payload[:len(payload)-int(padSize)] - return f, nil -} - -var ( - http2errStreamID = errors.New("invalid stream ID") - http2errDepStreamID = errors.New("invalid dependent stream ID") - http2errPadLength = errors.New("pad length too large") - http2errPadBytes = errors.New("padding bytes must all be zeros unless AllowIllegalWrites is enabled") -) - -func http2validStreamIDOrZero(streamID uint32) bool { - return streamID&(1<<31) == 0 -} - -func http2validStreamID(streamID uint32) bool { - return streamID != 0 && streamID&(1<<31) == 0 -} - -// writeData writes a DATA frame. -// -// It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility not to violate the maximum frame size -// and to not call other Write methods concurrently. -func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error { - return f.WriteDataPadded(streamID, endStream, data, nil) -} - -// WriteDataPadded writes a DATA frame with optional padding. -// -// If pad is nil, the padding bit is not sent. -// The length of pad must not exceed 255 bytes. -// The bytes of pad must all be zero, unless f.AllowIllegalWrites is set. -// -// It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility not to violate the maximum frame size -// and to not call other Write methods concurrently. -func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { - if !http2validStreamID(streamID) && !f.AllowIllegalWrites { - return http2errStreamID - } - if len(pad) > 0 { - if len(pad) > 255 { - return http2errPadLength - } - if !f.AllowIllegalWrites { - for _, b := range pad { - if b != 0 { - // "Padding octets MUST be set to zero when sending." - return http2errPadBytes - } - } - } - } - var flags http2Flags - if endStream { - flags |= http2FlagDataEndStream - } - if pad != nil { - flags |= http2FlagDataPadded - } - f.startWrite(http2FrameData, flags, streamID) - if pad != nil { - f.wbuf = append(f.wbuf, byte(len(pad))) - } - f.wbuf = append(f.wbuf, data...) - f.wbuf = append(f.wbuf, pad...) - return f.endWrite() -} - -// A SettingsFrame conveys configuration parameters that affect how -// endpoints communicate, such as preferences and constraints on peer -// behavior. -// -// See http://http2.github.io/http2-spec/#SETTINGS -type http2SettingsFrame struct { - http2FrameHeader - p []byte -} - -func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { - if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 { - // When this (ACK 0x1) bit is set, the payload of the - // SETTINGS frame MUST be empty. Receipt of a - // SETTINGS frame with the ACK flag set and a length - // field value other than 0 MUST be treated as a - // connection error (Section 5.4.1) of type - // FRAME_SIZE_ERROR. - countError("frame_settings_ack_with_length") - return nil, http2ConnectionError(http2ErrCodeFrameSize) - } - if fh.StreamID != 0 { - // SETTINGS frames always apply to a connection, - // never a single stream. The stream identifier for a - // SETTINGS frame MUST be zero (0x0). If an endpoint - // receives a SETTINGS frame whose stream identifier - // field is anything other than 0x0, the endpoint MUST - // respond with a connection error (Section 5.4.1) of - // type PROTOCOL_ERROR. - countError("frame_settings_has_stream") - return nil, http2ConnectionError(http2ErrCodeProtocol) - } - if len(p)%6 != 0 { - countError("frame_settings_mod_6") - // Expecting even number of 6 byte settings. - return nil, http2ConnectionError(http2ErrCodeFrameSize) - } - f := &http2SettingsFrame{http2FrameHeader: fh, p: p} - if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 { - countError("frame_settings_window_size_too_big") - // Values above the maximum flow control window size of 2^31 - 1 MUST - // be treated as a connection error (Section 5.4.1) of type - // FLOW_CONTROL_ERROR. - return nil, http2ConnectionError(http2ErrCodeFlowControl) - } - return f, nil -} - -func (f *http2SettingsFrame) IsAck() bool { - return f.http2FrameHeader.Flags.Has(http2FlagSettingsAck) -} - -func (f *http2SettingsFrame) Value(id http2SettingID) (v uint32, ok bool) { - f.checkValid() - for i := 0; i < f.NumSettings(); i++ { - if s := f.Setting(i); s.ID == id { - return s.Val, true - } - } - return 0, false -} - -// Setting returns the setting from the frame at the given 0-based index. -// The index must be >= 0 and less than f.NumSettings(). -func (f *http2SettingsFrame) Setting(i int) http2Setting { - buf := f.p - return http2Setting{ - ID: http2SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])), - Val: binary.BigEndian.Uint32(buf[i*6+2 : i*6+6]), - } -} - -func (f *http2SettingsFrame) NumSettings() int { return len(f.p) / 6 } - -// HasDuplicates reports whether f contains any duplicate setting IDs. -func (f *http2SettingsFrame) HasDuplicates() bool { - num := f.NumSettings() - if num == 0 { - return false - } - // If it's small enough (the common case), just do the n^2 - // thing and avoid a map allocation. - if num < 10 { - for i := 0; i < num; i++ { - idi := f.Setting(i).ID - for j := i + 1; j < num; j++ { - idj := f.Setting(j).ID - if idi == idj { - return true - } - } - } - return false - } - seen := map[http2SettingID]bool{} - for i := 0; i < num; i++ { - id := f.Setting(i).ID - if seen[id] { - return true - } - seen[id] = true - } - return false -} - -// ForeachSetting runs fn for each setting. -// It stops and returns the first error. -func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error { - f.checkValid() - for i := 0; i < f.NumSettings(); i++ { - if err := fn(f.Setting(i)); err != nil { - return err - } - } - return nil -} - -// WriteSettings writes a SETTINGS frame with zero or more settings -// specified and the ACK bit not set. -// -// It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WriteSettings(settings ...http2Setting) error { - f.startWrite(http2FrameSettings, 0, 0) - for _, s := range settings { - f.writeUint16(uint16(s.ID)) - f.writeUint32(s.Val) - } - return f.endWrite() -} - -// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set. -// -// It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WriteSettingsAck() error { - f.startWrite(http2FrameSettings, http2FlagSettingsAck, 0) - return f.endWrite() -} - -// A PingFrame is a mechanism for measuring a minimal round trip time -// from the sender, as well as determining whether an idle connection -// is still functional. -// See http://http2.github.io/http2-spec/#rfc.section.6.7 -type http2PingFrame struct { - http2FrameHeader - Data [8]byte -} - -func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) } - -func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { - if len(payload) != 8 { - countError("frame_ping_length") - return nil, http2ConnectionError(http2ErrCodeFrameSize) - } - if fh.StreamID != 0 { - countError("frame_ping_has_stream") - return nil, http2ConnectionError(http2ErrCodeProtocol) - } - f := &http2PingFrame{http2FrameHeader: fh} - copy(f.Data[:], payload) - return f, nil -} - -func (f *http2Framer) WritePing(ack bool, data [8]byte) error { - var flags http2Flags - if ack { - flags = http2FlagPingAck - } - f.startWrite(http2FramePing, flags, 0) - f.writeBytes(data[:]) - return f.endWrite() -} - -// A GoAwayFrame informs the remote peer to stop creating streams on this connection. -// See http://http2.github.io/http2-spec/#rfc.section.6.8 -type http2GoAwayFrame struct { - http2FrameHeader - LastStreamID uint32 - ErrCode http2ErrCode - debugData []byte -} - -// DebugData returns any debug data in the GOAWAY frame. Its contents -// are not defined. -// The caller must not retain the returned memory past the next -// call to ReadFrame. -func (f *http2GoAwayFrame) DebugData() []byte { - f.checkValid() - return f.debugData -} - -func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { - if fh.StreamID != 0 { - countError("frame_goaway_has_stream") - return nil, http2ConnectionError(http2ErrCodeProtocol) - } - if len(p) < 8 { - countError("frame_goaway_short") - return nil, http2ConnectionError(http2ErrCodeFrameSize) - } - return &http2GoAwayFrame{ - http2FrameHeader: fh, - LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1), - ErrCode: http2ErrCode(binary.BigEndian.Uint32(p[4:8])), - debugData: p[8:], - }, nil -} - -func (f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debugData []byte) error { - f.startWrite(http2FrameGoAway, 0, 0) - f.writeUint32(maxStreamID & (1<<31 - 1)) - f.writeUint32(uint32(code)) - f.writeBytes(debugData) - return f.endWrite() -} - -// An UnknownFrame is the frame type returned when the frame type is unknown -// or no specific frame type parser exists. -type http2UnknownFrame struct { - http2FrameHeader - p []byte -} - -// Payload returns the frame's payload (after the header). It is not -// valid to call this method after a subsequent call to -// Framer.ReadFrame, nor is it valid to retain the returned slice. -// The memory is owned by the Framer and is invalidated when the next -// frame is read. -func (f *http2UnknownFrame) Payload() []byte { - f.checkValid() - return f.p -} - -func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { - return &http2UnknownFrame{fh, p}, nil -} - -// A WindowUpdateFrame is used to implement flow control. -// See http://http2.github.io/http2-spec/#rfc.section.6.9 -type http2WindowUpdateFrame struct { - http2FrameHeader - Increment uint32 // never read with high bit set -} - -func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { - if len(p) != 4 { - countError("frame_windowupdate_bad_len") - return nil, http2ConnectionError(http2ErrCodeFrameSize) - } - inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit - if inc == 0 { - // A receiver MUST treat the receipt of a - // WINDOW_UPDATE frame with an flow control window - // increment of 0 as a stream error (Section 5.4.2) of - // type PROTOCOL_ERROR; errors on the connection flow - // control window MUST be treated as a connection - // error (Section 5.4.1). - if fh.StreamID == 0 { - countError("frame_windowupdate_zero_inc_conn") - return nil, http2ConnectionError(http2ErrCodeProtocol) - } - countError("frame_windowupdate_zero_inc_stream") - return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) - } - return &http2WindowUpdateFrame{ - http2FrameHeader: fh, - Increment: inc, - }, nil -} - -// WriteWindowUpdate writes a WINDOW_UPDATE frame. -// The increment value must be between 1 and 2,147,483,647, inclusive. -// If the Stream ID is zero, the window update applies to the -// connection as a whole. -func (f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error { - // "The legal range for the increment to the flow control window is 1 to 2^31-1 (2,147,483,647) octets." - if (incr < 1 || incr > 2147483647) && !f.AllowIllegalWrites { - return errors.New("illegal window increment value") - } - f.startWrite(http2FrameWindowUpdate, 0, streamID) - f.writeUint32(incr) - return f.endWrite() -} - -// A HeadersFrame is used to open a stream and additionally carries a -// header block fragment. -type http2HeadersFrame struct { - http2FrameHeader - - // Priority is set if FlagHeadersPriority is set in the FrameHeader. - Priority http2PriorityParam - - headerFragBuf []byte // not owned -} - -func (f *http2HeadersFrame) HeaderBlockFragment() []byte { - f.checkValid() - return f.headerFragBuf -} - -func (f *http2HeadersFrame) HeadersEnded() bool { - return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndHeaders) -} - -func (f *http2HeadersFrame) StreamEnded() bool { - return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndStream) -} - -func (f *http2HeadersFrame) HasPriority() bool { - return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority) -} - -func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { - hf := &http2HeadersFrame{ - http2FrameHeader: fh, - } - if fh.StreamID == 0 { - // HEADERS frames MUST be associated with a stream. If a HEADERS frame - // is received whose stream identifier field is 0x0, the recipient MUST - // respond with a connection error (Section 5.4.1) of type - // PROTOCOL_ERROR. - countError("frame_headers_zero_stream") - return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"} - } - var padLength uint8 - if fh.Flags.Has(http2FlagHeadersPadded) { - if p, padLength, err = http2readByte(p); err != nil { - countError("frame_headers_pad_short") - return - } - } - if fh.Flags.Has(http2FlagHeadersPriority) { - var v uint32 - p, v, err = http2readUint32(p) - if err != nil { - countError("frame_headers_prio_short") - return nil, err - } - hf.Priority.StreamDep = v & 0x7fffffff - hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set - p, hf.Priority.Weight, err = http2readByte(p) - if err != nil { - countError("frame_headers_prio_weight_short") - return nil, err - } - } - if len(p)-int(padLength) < 0 { - countError("frame_headers_pad_too_big") - return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) - } - hf.headerFragBuf = p[:len(p)-int(padLength)] - return hf, nil -} - -// HeadersFrameParam are the parameters for writing a HEADERS frame. -type http2HeadersFrameParam struct { - // StreamID is the required Stream ID to initiate. - StreamID uint32 - // BlockFragment is part (or all) of a Header Block. - BlockFragment []byte - - // EndStream indicates that the header block is the last that - // the endpoint will send for the identified stream. Setting - // this flag causes the stream to enter one of "half closed" - // states. - EndStream bool - - // EndHeaders indicates that this frame contains an entire - // header block and is not followed by any - // CONTINUATION frames. - EndHeaders bool - - // PadLength is the optional number of bytes of zeros to add - // to this frame. - PadLength uint8 - - // Priority, if non-zero, includes stream priority information - // in the HEADER frame. - Priority http2PriorityParam -} - -// WriteHeaders writes a single HEADERS frame. -// -// This is a low-level header writing method. Encoding headers and -// splitting them into any necessary CONTINUATION frames is handled -// elsewhere. -// -// It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { - if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { - return http2errStreamID - } - var flags http2Flags - if p.PadLength != 0 { - flags |= http2FlagHeadersPadded - } - if p.EndStream { - flags |= http2FlagHeadersEndStream - } - if p.EndHeaders { - flags |= http2FlagHeadersEndHeaders - } - if !p.Priority.IsZero() { - flags |= http2FlagHeadersPriority - } - f.startWrite(http2FrameHeaders, flags, p.StreamID) - if p.PadLength != 0 { - f.writeByte(p.PadLength) - } - if !p.Priority.IsZero() { - v := p.Priority.StreamDep - if !http2validStreamIDOrZero(v) && !f.AllowIllegalWrites { - return http2errDepStreamID - } - if p.Priority.Exclusive { - v |= 1 << 31 - } - f.writeUint32(v) - f.writeByte(p.Priority.Weight) - } - f.wbuf = append(f.wbuf, p.BlockFragment...) - f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...) - return f.endWrite() -} - -// A PriorityFrame specifies the sender-advised priority of a stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.3 -type http2PriorityFrame struct { - http2FrameHeader - http2PriorityParam -} - -// PriorityParam are the stream prioritzation parameters. -type http2PriorityParam struct { - // StreamDep is a 31-bit stream identifier for the - // stream that this stream depends on. Zero means no - // dependency. - StreamDep uint32 - - // Exclusive is whether the dependency is exclusive. - Exclusive bool - - // Weight is the stream's zero-indexed weight. It should be - // set together with StreamDep, or neither should be set. Per - // the spec, "Add one to the value to obtain a weight between - // 1 and 256." - Weight uint8 -} - -func (p http2PriorityParam) IsZero() bool { - return p == http2PriorityParam{} -} - -func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { - if fh.StreamID == 0 { - countError("frame_priority_zero_stream") - return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"} - } - if len(payload) != 5 { - countError("frame_priority_bad_length") - return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))} - } - v := binary.BigEndian.Uint32(payload[:4]) - streamID := v & 0x7fffffff // mask off high bit - return &http2PriorityFrame{ - http2FrameHeader: fh, - http2PriorityParam: http2PriorityParam{ - Weight: payload[4], - StreamDep: streamID, - Exclusive: streamID != v, // was high bit set? - }, - }, nil -} - -// WritePriority writes a PRIORITY frame. -// -// It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error { - if !http2validStreamID(streamID) && !f.AllowIllegalWrites { - return http2errStreamID - } - if !http2validStreamIDOrZero(p.StreamDep) { - return http2errDepStreamID - } - f.startWrite(http2FramePriority, 0, streamID) - v := p.StreamDep - if p.Exclusive { - v |= 1 << 31 - } - f.writeUint32(v) - f.writeByte(p.Weight) - return f.endWrite() -} - -// A RSTStreamFrame allows for abnormal termination of a stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.4 -type http2RSTStreamFrame struct { - http2FrameHeader - ErrCode http2ErrCode -} - -func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { - if len(p) != 4 { - countError("frame_rststream_bad_len") - return nil, http2ConnectionError(http2ErrCodeFrameSize) - } - if fh.StreamID == 0 { - countError("frame_rststream_zero_stream") - return nil, http2ConnectionError(http2ErrCodeProtocol) - } - return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil -} - -// WriteRSTStream writes a RST_STREAM frame. -// -// It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error { - if !http2validStreamID(streamID) && !f.AllowIllegalWrites { - return http2errStreamID - } - f.startWrite(http2FrameRSTStream, 0, streamID) - f.writeUint32(uint32(code)) - return f.endWrite() -} - -// A ContinuationFrame is used to continue a sequence of header block fragments. -// See http://http2.github.io/http2-spec/#rfc.section.6.10 -type http2ContinuationFrame struct { - http2FrameHeader - headerFragBuf []byte -} - -func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { - if fh.StreamID == 0 { - countError("frame_continuation_zero_stream") - return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"} - } - return &http2ContinuationFrame{fh, p}, nil -} - -func (f *http2ContinuationFrame) HeaderBlockFragment() []byte { - f.checkValid() - return f.headerFragBuf -} - -func (f *http2ContinuationFrame) HeadersEnded() bool { - return f.http2FrameHeader.Flags.Has(http2FlagContinuationEndHeaders) -} - -// WriteContinuation writes a CONTINUATION frame. -// -// It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error { - if !http2validStreamID(streamID) && !f.AllowIllegalWrites { - return http2errStreamID - } - var flags http2Flags - if endHeaders { - flags |= http2FlagContinuationEndHeaders - } - f.startWrite(http2FrameContinuation, flags, streamID) - f.wbuf = append(f.wbuf, headerBlockFragment...) - return f.endWrite() -} - -// A PushPromiseFrame is used to initiate a server stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.6 -type http2PushPromiseFrame struct { - http2FrameHeader - PromiseID uint32 - headerFragBuf []byte // not owned -} - -func (f *http2PushPromiseFrame) HeaderBlockFragment() []byte { - f.checkValid() - return f.headerFragBuf -} - -func (f *http2PushPromiseFrame) HeadersEnded() bool { - return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders) -} - -func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { - pp := &http2PushPromiseFrame{ - http2FrameHeader: fh, - } - if pp.StreamID == 0 { - // PUSH_PROMISE frames MUST be associated with an existing, - // peer-initiated stream. The stream identifier of a - // PUSH_PROMISE frame indicates the stream it is associated - // with. If the stream identifier field specifies the value - // 0x0, a recipient MUST respond with a connection error - // (Section 5.4.1) of type PROTOCOL_ERROR. - countError("frame_pushpromise_zero_stream") - return nil, http2ConnectionError(http2ErrCodeProtocol) - } - // The PUSH_PROMISE frame includes optional padding. - // Padding fields and flags are identical to those defined for DATA frames - var padLength uint8 - if fh.Flags.Has(http2FlagPushPromisePadded) { - if p, padLength, err = http2readByte(p); err != nil { - countError("frame_pushpromise_pad_short") - return - } - } - - p, pp.PromiseID, err = http2readUint32(p) - if err != nil { - countError("frame_pushpromise_promiseid_short") - return - } - pp.PromiseID = pp.PromiseID & (1<<31 - 1) - - if int(padLength) > len(p) { - // like the DATA frame, error out if padding is longer than the body. - countError("frame_pushpromise_pad_too_big") - return nil, http2ConnectionError(http2ErrCodeProtocol) - } - pp.headerFragBuf = p[:len(p)-int(padLength)] - return pp, nil -} - -// PushPromiseParam are the parameters for writing a PUSH_PROMISE frame. -type http2PushPromiseParam struct { - // StreamID is the required Stream ID to initiate. - StreamID uint32 - - // PromiseID is the required Stream ID which this - // Push Promises - PromiseID uint32 - - // BlockFragment is part (or all) of a Header Block. - BlockFragment []byte - - // EndHeaders indicates that this frame contains an entire - // header block and is not followed by any - // CONTINUATION frames. - EndHeaders bool - - // PadLength is the optional number of bytes of zeros to add - // to this frame. - PadLength uint8 -} - -// WritePushPromise writes a single PushPromise Frame. -// -// As with Header Frames, This is the low level call for writing -// individual frames. Continuation frames are handled elsewhere. -// -// It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WritePushPromise(p http2PushPromiseParam) error { - if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { - return http2errStreamID - } - var flags http2Flags - if p.PadLength != 0 { - flags |= http2FlagPushPromisePadded - } - if p.EndHeaders { - flags |= http2FlagPushPromiseEndHeaders - } - f.startWrite(http2FramePushPromise, flags, p.StreamID) - if p.PadLength != 0 { - f.writeByte(p.PadLength) - } - if !http2validStreamID(p.PromiseID) && !f.AllowIllegalWrites { - return http2errStreamID - } - f.writeUint32(p.PromiseID) - f.wbuf = append(f.wbuf, p.BlockFragment...) - f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...) - return f.endWrite() -} - -// WriteRawFrame writes a raw frame. This can be used to write -// extension frames unknown to this package. -func (f *http2Framer) WriteRawFrame(t http2FrameType, flags http2Flags, streamID uint32, payload []byte) error { - f.startWrite(t, flags, streamID) - f.writeBytes(payload) - return f.endWrite() -} - -func http2readByte(p []byte) (remain []byte, b byte, err error) { - if len(p) == 0 { - return nil, 0, io.ErrUnexpectedEOF - } - return p[1:], p[0], nil -} - -func http2readUint32(p []byte) (remain []byte, v uint32, err error) { - if len(p) < 4 { - return nil, 0, io.ErrUnexpectedEOF - } - return p[4:], binary.BigEndian.Uint32(p[:4]), nil -} - -type http2streamEnder interface { - StreamEnded() bool -} - -type http2headersEnder interface { - HeadersEnded() bool -} - -type http2headersOrContinuation interface { - http2headersEnder - HeaderBlockFragment() []byte -} - -// A MetaHeadersFrame is the representation of one HEADERS frame and -// zero or more contiguous CONTINUATION frames and the decoding of -// their HPACK-encoded contents. -// -// This type of frame does not appear on the wire and is only returned -// by the Framer when Framer.ReadMetaHeaders is set. -type http2MetaHeadersFrame struct { - *http2HeadersFrame - - // Fields are the fields contained in the HEADERS and - // CONTINUATION frames. The underlying slice is owned by the - // Framer and must not be retained after the next call to - // ReadFrame. - // - // Fields are guaranteed to be in the correct http2 order and - // not have unknown pseudo header fields or invalid header - // field names or values. Required pseudo header fields may be - // missing, however. Use the MetaHeadersFrame.Pseudo accessor - // method access pseudo headers. - Fields []hpack.HeaderField - - // Truncated is whether the max header list size limit was hit - // and Fields is incomplete. The hpack decoder state is still - // valid, however. - Truncated bool -} - -// PseudoValue returns the given pseudo header field's value. -// The provided pseudo field should not contain the leading colon. -func (mh *http2MetaHeadersFrame) PseudoValue(pseudo string) string { - for _, hf := range mh.Fields { - if !hf.IsPseudo() { - return "" - } - if hf.Name[1:] == pseudo { - return hf.Value - } - } - return "" -} - -// RegularFields returns the regular (non-pseudo) header fields of mh. -// The caller does not own the returned slice. -func (mh *http2MetaHeadersFrame) RegularFields() []hpack.HeaderField { - for i, hf := range mh.Fields { - if !hf.IsPseudo() { - return mh.Fields[i:] - } - } - return nil -} - -// PseudoFields returns the pseudo header fields of mh. -// The caller does not own the returned slice. -func (mh *http2MetaHeadersFrame) PseudoFields() []hpack.HeaderField { - for i, hf := range mh.Fields { - if !hf.IsPseudo() { - return mh.Fields[:i] - } - } - return mh.Fields -} - -func (mh *http2MetaHeadersFrame) checkPseudos() error { - var isRequest, isResponse bool - pf := mh.PseudoFields() - for i, hf := range pf { - switch hf.Name { - case ":method", ":path", ":scheme", ":authority": - isRequest = true - case ":status": - isResponse = true - default: - return http2pseudoHeaderError(hf.Name) - } - // Check for duplicates. - // This would be a bad algorithm, but N is 4. - // And this doesn't allocate. - for _, hf2 := range pf[:i] { - if hf.Name == hf2.Name { - return http2duplicatePseudoHeaderError(hf.Name) - } - } - } - if isRequest && isResponse { - return http2errMixPseudoHeaderTypes - } - return nil -} - -func (fr *http2Framer) maxHeaderStringLen() int { - v := fr.maxHeaderListSize() - if uint32(int(v)) == v { - return int(v) - } - // They had a crazy big number for MaxHeaderBytes anyway, - // so give them unlimited header lengths: - return 0 -} - -// readMetaFrame returns 0 or more CONTINUATION frames from fr and -// merge them into the provided hf and returns a MetaHeadersFrame -// with the decoded hpack values. -func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (*http2MetaHeadersFrame, error) { - if fr.AllowIllegalReads { - return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders") - } - mh := &http2MetaHeadersFrame{ - http2HeadersFrame: hf, - } - var remainSize = fr.maxHeaderListSize() - var sawRegular bool - - var invalid error // pseudo header field errors - hdec := fr.ReadMetaHeaders - hdec.SetEmitEnabled(true) - hdec.SetMaxStringLength(fr.maxHeaderStringLen()) - rawEmitFunc := func(hf hpack.HeaderField) { - if http2VerboseLogs && fr.logReads { - fr.debugReadLoggerf("http2: decoded hpack field %+v", hf) - } - if !httpguts.ValidHeaderFieldValue(hf.Value) { - invalid = http2headerFieldValueError(hf.Value) - } - isPseudo := strings.HasPrefix(hf.Name, ":") - if isPseudo { - if sawRegular { - invalid = http2errPseudoAfterRegular - } - } else { - sawRegular = true - if !http2validWireHeaderFieldName(hf.Name) { - invalid = http2headerFieldNameError(hf.Name) - } - } - - if invalid != nil { - hdec.SetEmitEnabled(false) - return - } - - size := hf.Size() - if size > remainSize { - hdec.SetEmitEnabled(false) - mh.Truncated = true - return - } - remainSize -= size - - mh.Fields = append(mh.Fields, hf) - } - emitFunc := rawEmitFunc - - if len(dumps) > 0 { - emitFunc = func(hf hpack.HeaderField) { - for _, dump := range dumps { - dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) - } - rawEmitFunc(hf) - } - } - - hdec.SetEmitFunc(emitFunc) - // Lose reference to MetaHeadersFrame: - defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {}) - - var hc http2headersOrContinuation = hf - for { - frag := hc.HeaderBlockFragment() - if _, err := hdec.Write(frag); err != nil { - return nil, http2ConnectionError(http2ErrCodeCompression) - } - - if hc.HeadersEnded() { - break - } - if f, err := fr.ReadFrame(); err != nil { - return nil, err - } else { - hc = f.(*http2ContinuationFrame) // guaranteed by checkFrameOrder - } - } - - mh.http2HeadersFrame.headerFragBuf = nil - mh.http2HeadersFrame.invalidate() - - if err := hdec.Close(); err != nil { - return nil, http2ConnectionError(http2ErrCodeCompression) - } - if invalid != nil { - fr.errDetail = invalid - if http2VerboseLogs { - log.Printf("http2: invalid header: %v", invalid) - } - return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, invalid} - } - if err := mh.checkPseudos(); err != nil { - fr.errDetail = err - if http2VerboseLogs { - log.Printf("http2: invalid pseudo headers: %v", err) - } - return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, err} - } - return mh, nil -} - -func http2summarizeFrame(f http2Frame) string { - var buf bytes.Buffer - f.Header().writeDebug(&buf) - switch f := f.(type) { - case *http2SettingsFrame: - n := 0 - f.ForeachSetting(func(s http2Setting) error { - n++ - if n == 1 { - buf.WriteString(", settings:") - } - fmt.Fprintf(&buf, " %v=%v,", s.ID, s.Val) - return nil - }) - if n > 0 { - buf.Truncate(buf.Len() - 1) // remove trailing comma - } - case *http2DataFrame: - data := f.Data() - const max = 256 - if len(data) > max { - data = data[:max] - } - fmt.Fprintf(&buf, " data=%q", data) - if len(f.Data()) > max { - fmt.Fprintf(&buf, " (%d bytes omitted)", len(f.Data())-max) - } - case *http2WindowUpdateFrame: - if f.StreamID == 0 { - buf.WriteString(" (conn)") - } - fmt.Fprintf(&buf, " incr=%v", f.Increment) - case *http2PingFrame: - fmt.Fprintf(&buf, " ping=%q", f.Data[:]) - case *http2GoAwayFrame: - fmt.Fprintf(&buf, " LastStreamID=%v ErrCode=%v Debug=%q", - f.LastStreamID, f.ErrCode, f.debugData) - case *http2RSTStreamFrame: - fmt.Fprintf(&buf, " ErrCode=%v", f.ErrCode) - } - return buf.String() -} - func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { return trace != nil && trace.WroteHeaderField != nil } diff --git a/h2_frame.go b/h2_frame.go new file mode 100644 index 00000000..ed7c1f4c --- /dev/null +++ b/h2_frame.go @@ -0,0 +1,1679 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http2/hpack" + "io" + "log" + "strings" + "sync" +) + +const http2frameHeaderLen = 9 + +var http2padZeros = make([]byte, 255) // zeros for padding + +// A FrameType is a registered frame type as defined in +// http://http2.github.io/http2-spec/#rfc.section.11.2 +type http2FrameType uint8 + +const ( + http2FrameData http2FrameType = 0x0 + http2FrameHeaders http2FrameType = 0x1 + http2FramePriority http2FrameType = 0x2 + http2FrameRSTStream http2FrameType = 0x3 + http2FrameSettings http2FrameType = 0x4 + http2FramePushPromise http2FrameType = 0x5 + http2FramePing http2FrameType = 0x6 + http2FrameGoAway http2FrameType = 0x7 + http2FrameWindowUpdate http2FrameType = 0x8 + http2FrameContinuation http2FrameType = 0x9 +) + +var http2frameName = map[http2FrameType]string{ + http2FrameData: "DATA", + http2FrameHeaders: "HEADERS", + http2FramePriority: "PRIORITY", + http2FrameRSTStream: "RST_STREAM", + http2FrameSettings: "SETTINGS", + http2FramePushPromise: "PUSH_PROMISE", + http2FramePing: "PING", + http2FrameGoAway: "GOAWAY", + http2FrameWindowUpdate: "WINDOW_UPDATE", + http2FrameContinuation: "CONTINUATION", +} + +func (t http2FrameType) String() string { + if s, ok := http2frameName[t]; ok { + return s + } + return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t)) +} + +// Flags is a bitmask of HTTP/2 flags. +// The meaning of flags varies depending on the frame type. +type http2Flags uint8 + +// Has reports whether f contains all (0 or more) flags in v. +func (f http2Flags) Has(v http2Flags) bool { + return (f & v) == v +} + +// Frame-specific FrameHeader flag bits. +const ( + // Data Frame + http2FlagDataEndStream http2Flags = 0x1 + http2FlagDataPadded http2Flags = 0x8 + + // Headers Frame + http2FlagHeadersEndStream http2Flags = 0x1 + http2FlagHeadersEndHeaders http2Flags = 0x4 + http2FlagHeadersPadded http2Flags = 0x8 + http2FlagHeadersPriority http2Flags = 0x20 + + // Settings Frame + http2FlagSettingsAck http2Flags = 0x1 + + // Ping Frame + http2FlagPingAck http2Flags = 0x1 + + // Continuation Frame + http2FlagContinuationEndHeaders http2Flags = 0x4 + + http2FlagPushPromiseEndHeaders http2Flags = 0x4 + http2FlagPushPromisePadded http2Flags = 0x8 +) + +var http2flagName = map[http2FrameType]map[http2Flags]string{ + http2FrameData: { + http2FlagDataEndStream: "END_STREAM", + http2FlagDataPadded: "PADDED", + }, + http2FrameHeaders: { + http2FlagHeadersEndStream: "END_STREAM", + http2FlagHeadersEndHeaders: "END_HEADERS", + http2FlagHeadersPadded: "PADDED", + http2FlagHeadersPriority: "PRIORITY", + }, + http2FrameSettings: { + http2FlagSettingsAck: "ACK", + }, + http2FramePing: { + http2FlagPingAck: "ACK", + }, + http2FrameContinuation: { + http2FlagContinuationEndHeaders: "END_HEADERS", + }, + http2FramePushPromise: { + http2FlagPushPromiseEndHeaders: "END_HEADERS", + http2FlagPushPromisePadded: "PADDED", + }, +} + +// a frameParser parses a frame given its FrameHeader and payload +// bytes. The length of payload will always equal fh.Length (which +// might be 0). +type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) + +var http2frameParsers = map[http2FrameType]http2frameParser{ + http2FrameData: http2parseDataFrame, + http2FrameHeaders: http2parseHeadersFrame, + http2FramePriority: http2parsePriorityFrame, + http2FrameRSTStream: http2parseRSTStreamFrame, + http2FrameSettings: http2parseSettingsFrame, + http2FramePushPromise: http2parsePushPromise, + http2FramePing: http2parsePingFrame, + http2FrameGoAway: http2parseGoAwayFrame, + http2FrameWindowUpdate: http2parseWindowUpdateFrame, + http2FrameContinuation: http2parseContinuationFrame, +} + +func http2typeFrameParser(t http2FrameType) http2frameParser { + if f := http2frameParsers[t]; f != nil { + return f + } + return http2parseUnknownFrame +} + +// A FrameHeader is the 9 byte header of all HTTP/2 frames. +// +// See http://http2.github.io/http2-spec/#FrameHeader +type http2FrameHeader struct { + valid bool // caller can access []byte fields in the Frame + + // Type is the 1 byte frame type. There are ten standard frame + // types, but extension frame types may be written by WriteRawFrame + // and will be returned by ReadFrame (as UnknownFrame). + Type http2FrameType + + // Flags are the 1 byte of 8 potential bit flags per frame. + // They are specific to the frame type. + Flags http2Flags + + // Length is the length of the frame, not including the 9 byte header. + // The maximum size is one byte less than 16MB (uint24), but only + // frames up to 16KB are allowed without peer agreement. + Length uint32 + + // StreamID is which stream this frame is for. Certain frames + // are not stream-specific, in which case this field is 0. + StreamID uint32 +} + +// Header returns h. It exists so FrameHeaders can be embedded in other +// specific frame types and implement the Frame interface. +func (h http2FrameHeader) Header() http2FrameHeader { return h } + +func (h http2FrameHeader) String() string { + var buf bytes.Buffer + buf.WriteString("[FrameHeader ") + h.writeDebug(&buf) + buf.WriteByte(']') + return buf.String() +} + +func (h http2FrameHeader) writeDebug(buf *bytes.Buffer) { + buf.WriteString(h.Type.String()) + if h.Flags != 0 { + buf.WriteString(" flags=") + set := 0 + for i := uint8(0); i < 8; i++ { + if h.Flags&(1< 1 { + buf.WriteByte('|') + } + name := http2flagName[h.Type][http2Flags(1<>24), + byte(streamID>>16), + byte(streamID>>8), + byte(streamID)) +} + +func (f *http2Framer) endWrite() error { + // Now that we know the final size, fill in the FrameHeader in + // the space previously reserved for it. Abuse append. + length := len(f.wbuf) - http2frameHeaderLen + if length >= (1 << 24) { + return http2ErrFrameTooLarge + } + _ = append(f.wbuf[:0], + byte(length>>16), + byte(length>>8), + byte(length)) + if f.logWrites { + f.logWrite() + } + + n, err := f.w.Write(f.wbuf) + if err == nil && n != len(f.wbuf) { + err = io.ErrShortWrite + } + return err +} + +func (f *http2Framer) logWrite() { + if f.debugFramer == nil { + f.debugFramerBuf = new(bytes.Buffer) + f.debugFramer = http2NewFramer(nil, f.debugFramerBuf) + f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below + // Let us read anything, even if we accidentally wrote it + // in the wrong order: + f.debugFramer.AllowIllegalReads = true + } + f.debugFramerBuf.Write(f.wbuf) + fr, err := f.debugFramer.ReadFrame() + if err != nil { + f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f) + return + } + f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) +} + +func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) } + +func (f *http2Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) } + +func (f *http2Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) } + +func (f *http2Framer) writeUint32(v uint32) { + f.wbuf = append(f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +const ( + http2minMaxFrameSize = 1 << 14 + http2maxFrameSize = 1<<24 - 1 +) + +// SetReuseFrames allows the Framer to reuse Frames. +// If called on a Framer, Frames returned by calls to ReadFrame are only +// valid until the next call to ReadFrame. +func (fr *http2Framer) SetReuseFrames() { + if fr.frameCache != nil { + return + } + fr.frameCache = &http2frameCache{} +} + +type http2frameCache struct { + dataFrame http2DataFrame +} + +func (fc *http2frameCache) getDataFrame() *http2DataFrame { + if fc == nil { + return &http2DataFrame{} + } + return &fc.dataFrame +} + +// NewFramer returns a Framer that writes frames to w and reads them from r. +func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { + fr := &http2Framer{ + w: w, + r: r, + countError: func(string) {}, + logReads: http2logFrameReads, + logWrites: http2logFrameWrites, + debugReadLoggerf: log.Printf, + debugWriteLoggerf: log.Printf, + } + fr.getReadBuf = func(size uint32) []byte { + if cap(fr.readBuf) >= int(size) { + return fr.readBuf[:size] + } + fr.readBuf = make([]byte, size) + return fr.readBuf + } + fr.SetMaxReadFrameSize(http2maxFrameSize) + return fr +} + +// SetMaxReadFrameSize sets the maximum size of a frame +// that will be read by a subsequent call to ReadFrame. +// It is the caller's responsibility to advertise this +// limit with a SETTINGS frame. +func (fr *http2Framer) SetMaxReadFrameSize(v uint32) { + if v > http2maxFrameSize { + v = http2maxFrameSize + } + fr.maxReadSize = v +} + +// ErrorDetail returns a more detailed error of the last error +// returned by Framer.ReadFrame. For instance, if ReadFrame +// returns a StreamError with code PROTOCOL_ERROR, ErrorDetail +// will say exactly what was invalid. ErrorDetail is not guaranteed +// to return a non-nil value and like the rest of the http2 package, +// its return value is not protected by an API compatibility promise. +// ErrorDetail is reset after the next call to ReadFrame. +func (fr *http2Framer) ErrorDetail() error { + return fr.errDetail +} + +// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer +// sends a frame that is larger than declared with SetMaxReadFrameSize. +var http2ErrFrameTooLarge = errors.New("http2: frame too large") + +// terminalReadFrameError reports whether err is an unrecoverable +// error from ReadFrame and no other frames should be read. +func http2terminalReadFrameError(err error) bool { + if _, ok := err.(http2StreamError); ok { + return false + } + return err != nil +} + +// ReadFrame reads a single frame. The returned Frame is only valid +// until the next call to ReadFrame. +// +// If the frame is larger than previously set with SetMaxReadFrameSize, the +// returned error is ErrFrameTooLarge. Other errors may be of type +// ConnectionError, StreamError, or anything else from the underlying +// reader. +func (fr *http2Framer) ReadFrame() (http2Frame, error) { + fr.errDetail = nil + if fr.lastFrame != nil { + fr.lastFrame.invalidate() + } + fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r) + if err != nil { + return nil, err + } + if fh.Length > fr.maxReadSize { + return nil, http2ErrFrameTooLarge + } + payload := fr.getReadBuf(fh.Length) + if _, err := io.ReadFull(fr.r, payload); err != nil { + return nil, err + } + f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload) + if err != nil { + if ce, ok := err.(http2connError); ok { + return nil, fr.connError(ce.Code, ce.Reason) + } + return nil, err + } + if err := fr.checkFrameOrder(f); err != nil { + return nil, err + } + if fr.logReads { + fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) + } + if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { + dumps := getDumpers(fr.cc.t.t1.dump, fr.cc.currentRequest.Context()) + if len(dumps) > 0 { + dd := []*dumper{} + for _, dump := range dumps { + if dump.ResponseHeader { + dd = append(dd, dump) + } + } + dumps = dd + } + hr, err := fr.readMetaFrame(f.(*http2HeadersFrame), dumps) + if err == nil && len(dumps) > 0 { + for _, dump := range dumps { + dump.dump([]byte("\r\n")) + } + } + return hr, err + } + return f, nil +} + +// connError returns ConnectionError(code) but first +// stashes away a public reason to the caller can optionally relay it +// to the peer before hanging up on them. This might help others debug +// their implementations. +func (fr *http2Framer) connError(code http2ErrCode, reason string) error { + fr.errDetail = errors.New(reason) + return http2ConnectionError(code) +} + +// checkFrameOrder reports an error if f is an invalid frame to return +// next from ReadFrame. Mostly it checks whether HEADERS and +// CONTINUATION frames are contiguous. +func (fr *http2Framer) checkFrameOrder(f http2Frame) error { + last := fr.lastFrame + fr.lastFrame = f + if fr.AllowIllegalReads { + return nil + } + + fh := f.Header() + if fr.lastHeaderStream != 0 { + if fh.Type != http2FrameContinuation { + return fr.connError(http2ErrCodeProtocol, + fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d", + fh.Type, fh.StreamID, + last.Header().Type, fr.lastHeaderStream)) + } + if fh.StreamID != fr.lastHeaderStream { + return fr.connError(http2ErrCodeProtocol, + fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d", + fh.StreamID, fr.lastHeaderStream)) + } + } else if fh.Type == http2FrameContinuation { + return fr.connError(http2ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID)) + } + + switch fh.Type { + case http2FrameHeaders, http2FrameContinuation: + if fh.Flags.Has(http2FlagHeadersEndHeaders) { + fr.lastHeaderStream = 0 + } else { + fr.lastHeaderStream = fh.StreamID + } + } + + return nil +} + +// A DataFrame conveys arbitrary, variable-length sequences of octets +// associated with a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.1 +type http2DataFrame struct { + http2FrameHeader + data []byte +} + +func (f *http2DataFrame) StreamEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagDataEndStream) +} + +// Data returns the frame's data octets, not including any padding +// size byte or padding suffix bytes. +// The caller must not retain the returned memory past the next +// call to ReadFrame. +func (f *http2DataFrame) Data() []byte { + f.checkValid() + return f.data +} + +func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { + if fh.StreamID == 0 { + // DATA frames MUST be associated with a stream. If a + // DATA frame is received whose stream identifier + // field is 0x0, the recipient MUST respond with a + // connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. + countError("frame_data_stream_0") + return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"} + } + f := fc.getDataFrame() + f.http2FrameHeader = fh + + var padSize byte + if fh.Flags.Has(http2FlagDataPadded) { + var err error + payload, padSize, err = http2readByte(payload) + if err != nil { + countError("frame_data_pad_byte_short") + return nil, err + } + } + if int(padSize) > len(payload) { + // If the length of the padding is greater than the + // length of the frame payload, the recipient MUST + // treat this as a connection error. + // Filed: https://github.com/http2/http2-spec/issues/610 + countError("frame_data_pad_too_big") + return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"} + } + f.data = payload[:len(payload)-int(padSize)] + return f, nil +} + +var ( + http2errStreamID = errors.New("invalid stream ID") + http2errDepStreamID = errors.New("invalid dependent stream ID") + http2errPadLength = errors.New("pad length too large") + http2errPadBytes = errors.New("padding bytes must all be zeros unless AllowIllegalWrites is enabled") +) + +func http2validStreamIDOrZero(streamID uint32) bool { + return streamID&(1<<31) == 0 +} + +func http2validStreamID(streamID uint32) bool { + return streamID != 0 && streamID&(1<<31) == 0 +} + +// writeData writes a DATA frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility not to violate the maximum frame size +// and to not call other Write methods concurrently. +func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error { + return f.WriteDataPadded(streamID, endStream, data, nil) +} + +// WriteDataPadded writes a DATA frame with optional padding. +// +// If pad is nil, the padding bit is not sent. +// The length of pad must not exceed 255 bytes. +// The bytes of pad must all be zero, unless f.AllowIllegalWrites is set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility not to violate the maximum frame size +// and to not call other Write methods concurrently. +func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + if len(pad) > 0 { + if len(pad) > 255 { + return http2errPadLength + } + if !f.AllowIllegalWrites { + for _, b := range pad { + if b != 0 { + // "Padding octets MUST be set to zero when sending." + return http2errPadBytes + } + } + } + } + var flags http2Flags + if endStream { + flags |= http2FlagDataEndStream + } + if pad != nil { + flags |= http2FlagDataPadded + } + f.startWrite(http2FrameData, flags, streamID) + if pad != nil { + f.wbuf = append(f.wbuf, byte(len(pad))) + } + f.wbuf = append(f.wbuf, data...) + f.wbuf = append(f.wbuf, pad...) + return f.endWrite() +} + +// A SettingsFrame conveys configuration parameters that affect how +// endpoints communicate, such as preferences and constraints on peer +// behavior. +// +// See http://http2.github.io/http2-spec/#SETTINGS +type http2SettingsFrame struct { + http2FrameHeader + p []byte +} + +func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 { + // When this (ACK 0x1) bit is set, the payload of the + // SETTINGS frame MUST be empty. Receipt of a + // SETTINGS frame with the ACK flag set and a length + // field value other than 0 MUST be treated as a + // connection error (Section 5.4.1) of type + // FRAME_SIZE_ERROR. + countError("frame_settings_ack_with_length") + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + if fh.StreamID != 0 { + // SETTINGS frames always apply to a connection, + // never a single stream. The stream identifier for a + // SETTINGS frame MUST be zero (0x0). If an endpoint + // receives a SETTINGS frame whose stream identifier + // field is anything other than 0x0, the endpoint MUST + // respond with a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR. + countError("frame_settings_has_stream") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + if len(p)%6 != 0 { + countError("frame_settings_mod_6") + // Expecting even number of 6 byte settings. + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + f := &http2SettingsFrame{http2FrameHeader: fh, p: p} + if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 { + countError("frame_settings_window_size_too_big") + // Values above the maximum flow control window size of 2^31 - 1 MUST + // be treated as a connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR. + return nil, http2ConnectionError(http2ErrCodeFlowControl) + } + return f, nil +} + +func (f *http2SettingsFrame) IsAck() bool { + return f.http2FrameHeader.Flags.Has(http2FlagSettingsAck) +} + +func (f *http2SettingsFrame) Value(id http2SettingID) (v uint32, ok bool) { + f.checkValid() + for i := 0; i < f.NumSettings(); i++ { + if s := f.Setting(i); s.ID == id { + return s.Val, true + } + } + return 0, false +} + +// Setting returns the setting from the frame at the given 0-based index. +// The index must be >= 0 and less than f.NumSettings(). +func (f *http2SettingsFrame) Setting(i int) http2Setting { + buf := f.p + return http2Setting{ + ID: http2SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])), + Val: binary.BigEndian.Uint32(buf[i*6+2 : i*6+6]), + } +} + +func (f *http2SettingsFrame) NumSettings() int { return len(f.p) / 6 } + +// HasDuplicates reports whether f contains any duplicate setting IDs. +func (f *http2SettingsFrame) HasDuplicates() bool { + num := f.NumSettings() + if num == 0 { + return false + } + // If it's small enough (the common case), just do the n^2 + // thing and avoid a map allocation. + if num < 10 { + for i := 0; i < num; i++ { + idi := f.Setting(i).ID + for j := i + 1; j < num; j++ { + idj := f.Setting(j).ID + if idi == idj { + return true + } + } + } + return false + } + seen := map[http2SettingID]bool{} + for i := 0; i < num; i++ { + id := f.Setting(i).ID + if seen[id] { + return true + } + seen[id] = true + } + return false +} + +// ForeachSetting runs fn for each setting. +// It stops and returns the first error. +func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error { + f.checkValid() + for i := 0; i < f.NumSettings(); i++ { + if err := fn(f.Setting(i)); err != nil { + return err + } + } + return nil +} + +// WriteSettings writes a SETTINGS frame with zero or more settings +// specified and the ACK bit not set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteSettings(settings ...http2Setting) error { + f.startWrite(http2FrameSettings, 0, 0) + for _, s := range settings { + f.writeUint16(uint16(s.ID)) + f.writeUint32(s.Val) + } + return f.endWrite() +} + +// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteSettingsAck() error { + f.startWrite(http2FrameSettings, http2FlagSettingsAck, 0) + return f.endWrite() +} + +// A PingFrame is a mechanism for measuring a minimal round trip time +// from the sender, as well as determining whether an idle connection +// is still functional. +// See http://http2.github.io/http2-spec/#rfc.section.6.7 +type http2PingFrame struct { + http2FrameHeader + Data [8]byte +} + +func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) } + +func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { + if len(payload) != 8 { + countError("frame_ping_length") + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + if fh.StreamID != 0 { + countError("frame_ping_has_stream") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + f := &http2PingFrame{http2FrameHeader: fh} + copy(f.Data[:], payload) + return f, nil +} + +func (f *http2Framer) WritePing(ack bool, data [8]byte) error { + var flags http2Flags + if ack { + flags = http2FlagPingAck + } + f.startWrite(http2FramePing, flags, 0) + f.writeBytes(data[:]) + return f.endWrite() +} + +// A GoAwayFrame informs the remote peer to stop creating streams on this connection. +// See http://http2.github.io/http2-spec/#rfc.section.6.8 +type http2GoAwayFrame struct { + http2FrameHeader + LastStreamID uint32 + ErrCode http2ErrCode + debugData []byte +} + +// DebugData returns any debug data in the GOAWAY frame. Its contents +// are not defined. +// The caller must not retain the returned memory past the next +// call to ReadFrame. +func (f *http2GoAwayFrame) DebugData() []byte { + f.checkValid() + return f.debugData +} + +func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + if fh.StreamID != 0 { + countError("frame_goaway_has_stream") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + if len(p) < 8 { + countError("frame_goaway_short") + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + return &http2GoAwayFrame{ + http2FrameHeader: fh, + LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1), + ErrCode: http2ErrCode(binary.BigEndian.Uint32(p[4:8])), + debugData: p[8:], + }, nil +} + +func (f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debugData []byte) error { + f.startWrite(http2FrameGoAway, 0, 0) + f.writeUint32(maxStreamID & (1<<31 - 1)) + f.writeUint32(uint32(code)) + f.writeBytes(debugData) + return f.endWrite() +} + +// An UnknownFrame is the frame type returned when the frame type is unknown +// or no specific frame type parser exists. +type http2UnknownFrame struct { + http2FrameHeader + p []byte +} + +// Payload returns the frame's payload (after the header). It is not +// valid to call this method after a subsequent call to +// Framer.ReadFrame, nor is it valid to retain the returned slice. +// The memory is owned by the Framer and is invalidated when the next +// frame is read. +func (f *http2UnknownFrame) Payload() []byte { + f.checkValid() + return f.p +} + +func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + return &http2UnknownFrame{fh, p}, nil +} + +// A WindowUpdateFrame is used to implement flow control. +// See http://http2.github.io/http2-spec/#rfc.section.6.9 +type http2WindowUpdateFrame struct { + http2FrameHeader + Increment uint32 // never read with high bit set +} + +func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + if len(p) != 4 { + countError("frame_windowupdate_bad_len") + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit + if inc == 0 { + // A receiver MUST treat the receipt of a + // WINDOW_UPDATE frame with an flow control window + // increment of 0 as a stream error (Section 5.4.2) of + // type PROTOCOL_ERROR; errors on the connection flow + // control window MUST be treated as a connection + // error (Section 5.4.1). + if fh.StreamID == 0 { + countError("frame_windowupdate_zero_inc_conn") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + countError("frame_windowupdate_zero_inc_stream") + return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) + } + return &http2WindowUpdateFrame{ + http2FrameHeader: fh, + Increment: inc, + }, nil +} + +// WriteWindowUpdate writes a WINDOW_UPDATE frame. +// The increment value must be between 1 and 2,147,483,647, inclusive. +// If the Stream ID is zero, the window update applies to the +// connection as a whole. +func (f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error { + // "The legal range for the increment to the flow control window is 1 to 2^31-1 (2,147,483,647) octets." + if (incr < 1 || incr > 2147483647) && !f.AllowIllegalWrites { + return errors.New("illegal window increment value") + } + f.startWrite(http2FrameWindowUpdate, 0, streamID) + f.writeUint32(incr) + return f.endWrite() +} + +// A HeadersFrame is used to open a stream and additionally carries a +// header block fragment. +type http2HeadersFrame struct { + http2FrameHeader + + // Priority is set if FlagHeadersPriority is set in the FrameHeader. + Priority http2PriorityParam + + headerFragBuf []byte // not owned +} + +func (f *http2HeadersFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2HeadersFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndHeaders) +} + +func (f *http2HeadersFrame) StreamEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndStream) +} + +func (f *http2HeadersFrame) HasPriority() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority) +} + +func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { + hf := &http2HeadersFrame{ + http2FrameHeader: fh, + } + if fh.StreamID == 0 { + // HEADERS frames MUST be associated with a stream. If a HEADERS frame + // is received whose stream identifier field is 0x0, the recipient MUST + // respond with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. + countError("frame_headers_zero_stream") + return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"} + } + var padLength uint8 + if fh.Flags.Has(http2FlagHeadersPadded) { + if p, padLength, err = http2readByte(p); err != nil { + countError("frame_headers_pad_short") + return + } + } + if fh.Flags.Has(http2FlagHeadersPriority) { + var v uint32 + p, v, err = http2readUint32(p) + if err != nil { + countError("frame_headers_prio_short") + return nil, err + } + hf.Priority.StreamDep = v & 0x7fffffff + hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set + p, hf.Priority.Weight, err = http2readByte(p) + if err != nil { + countError("frame_headers_prio_weight_short") + return nil, err + } + } + if len(p)-int(padLength) < 0 { + countError("frame_headers_pad_too_big") + return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) + } + hf.headerFragBuf = p[:len(p)-int(padLength)] + return hf, nil +} + +// HeadersFrameParam are the parameters for writing a HEADERS frame. +type http2HeadersFrameParam struct { + // StreamID is the required Stream ID to initiate. + StreamID uint32 + // BlockFragment is part (or all) of a Header Block. + BlockFragment []byte + + // EndStream indicates that the header block is the last that + // the endpoint will send for the identified stream. Setting + // this flag causes the stream to enter one of "half closed" + // states. + EndStream bool + + // EndHeaders indicates that this frame contains an entire + // header block and is not followed by any + // CONTINUATION frames. + EndHeaders bool + + // PadLength is the optional number of bytes of zeros to add + // to this frame. + PadLength uint8 + + // Priority, if non-zero, includes stream priority information + // in the HEADER frame. + Priority http2PriorityParam +} + +// WriteHeaders writes a single HEADERS frame. +// +// This is a low-level header writing method. Encoding headers and +// splitting them into any necessary CONTINUATION frames is handled +// elsewhere. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { + if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if p.PadLength != 0 { + flags |= http2FlagHeadersPadded + } + if p.EndStream { + flags |= http2FlagHeadersEndStream + } + if p.EndHeaders { + flags |= http2FlagHeadersEndHeaders + } + if !p.Priority.IsZero() { + flags |= http2FlagHeadersPriority + } + f.startWrite(http2FrameHeaders, flags, p.StreamID) + if p.PadLength != 0 { + f.writeByte(p.PadLength) + } + if !p.Priority.IsZero() { + v := p.Priority.StreamDep + if !http2validStreamIDOrZero(v) && !f.AllowIllegalWrites { + return http2errDepStreamID + } + if p.Priority.Exclusive { + v |= 1 << 31 + } + f.writeUint32(v) + f.writeByte(p.Priority.Weight) + } + f.wbuf = append(f.wbuf, p.BlockFragment...) + f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...) + return f.endWrite() +} + +// A PriorityFrame specifies the sender-advised priority of a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.3 +type http2PriorityFrame struct { + http2FrameHeader + http2PriorityParam +} + +// PriorityParam are the stream prioritzation parameters. +type http2PriorityParam struct { + // StreamDep is a 31-bit stream identifier for the + // stream that this stream depends on. Zero means no + // dependency. + StreamDep uint32 + + // Exclusive is whether the dependency is exclusive. + Exclusive bool + + // Weight is the stream's zero-indexed weight. It should be + // set together with StreamDep, or neither should be set. Per + // the spec, "Add one to the value to obtain a weight between + // 1 and 256." + Weight uint8 +} + +func (p http2PriorityParam) IsZero() bool { + return p == http2PriorityParam{} +} + +func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { + if fh.StreamID == 0 { + countError("frame_priority_zero_stream") + return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"} + } + if len(payload) != 5 { + countError("frame_priority_bad_length") + return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))} + } + v := binary.BigEndian.Uint32(payload[:4]) + streamID := v & 0x7fffffff // mask off high bit + return &http2PriorityFrame{ + http2FrameHeader: fh, + http2PriorityParam: http2PriorityParam{ + Weight: payload[4], + StreamDep: streamID, + Exclusive: streamID != v, // was high bit set? + }, + }, nil +} + +// WritePriority writes a PRIORITY frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + if !http2validStreamIDOrZero(p.StreamDep) { + return http2errDepStreamID + } + f.startWrite(http2FramePriority, 0, streamID) + v := p.StreamDep + if p.Exclusive { + v |= 1 << 31 + } + f.writeUint32(v) + f.writeByte(p.Weight) + return f.endWrite() +} + +// A RSTStreamFrame allows for abnormal termination of a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.4 +type http2RSTStreamFrame struct { + http2FrameHeader + ErrCode http2ErrCode +} + +func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + if len(p) != 4 { + countError("frame_rststream_bad_len") + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + if fh.StreamID == 0 { + countError("frame_rststream_zero_stream") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil +} + +// WriteRSTStream writes a RST_STREAM frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + f.startWrite(http2FrameRSTStream, 0, streamID) + f.writeUint32(uint32(code)) + return f.endWrite() +} + +// A ContinuationFrame is used to continue a sequence of header block fragments. +// See http://http2.github.io/http2-spec/#rfc.section.6.10 +type http2ContinuationFrame struct { + http2FrameHeader + headerFragBuf []byte +} + +func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { + if fh.StreamID == 0 { + countError("frame_continuation_zero_stream") + return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"} + } + return &http2ContinuationFrame{fh, p}, nil +} + +func (f *http2ContinuationFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2ContinuationFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagContinuationEndHeaders) +} + +// WriteContinuation writes a CONTINUATION frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if endHeaders { + flags |= http2FlagContinuationEndHeaders + } + f.startWrite(http2FrameContinuation, flags, streamID) + f.wbuf = append(f.wbuf, headerBlockFragment...) + return f.endWrite() +} + +// A PushPromiseFrame is used to initiate a server stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.6 +type http2PushPromiseFrame struct { + http2FrameHeader + PromiseID uint32 + headerFragBuf []byte // not owned +} + +func (f *http2PushPromiseFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2PushPromiseFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders) +} + +func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { + pp := &http2PushPromiseFrame{ + http2FrameHeader: fh, + } + if pp.StreamID == 0 { + // PUSH_PROMISE frames MUST be associated with an existing, + // peer-initiated stream. The stream identifier of a + // PUSH_PROMISE frame indicates the stream it is associated + // with. If the stream identifier field specifies the value + // 0x0, a recipient MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + countError("frame_pushpromise_zero_stream") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + // The PUSH_PROMISE frame includes optional padding. + // Padding fields and flags are identical to those defined for DATA frames + var padLength uint8 + if fh.Flags.Has(http2FlagPushPromisePadded) { + if p, padLength, err = http2readByte(p); err != nil { + countError("frame_pushpromise_pad_short") + return + } + } + + p, pp.PromiseID, err = http2readUint32(p) + if err != nil { + countError("frame_pushpromise_promiseid_short") + return + } + pp.PromiseID = pp.PromiseID & (1<<31 - 1) + + if int(padLength) > len(p) { + // like the DATA frame, error out if padding is longer than the body. + countError("frame_pushpromise_pad_too_big") + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + pp.headerFragBuf = p[:len(p)-int(padLength)] + return pp, nil +} + +// PushPromiseParam are the parameters for writing a PUSH_PROMISE frame. +type http2PushPromiseParam struct { + // StreamID is the required Stream ID to initiate. + StreamID uint32 + + // PromiseID is the required Stream ID which this + // Push Promises + PromiseID uint32 + + // BlockFragment is part (or all) of a Header Block. + BlockFragment []byte + + // EndHeaders indicates that this frame contains an entire + // header block and is not followed by any + // CONTINUATION frames. + EndHeaders bool + + // PadLength is the optional number of bytes of zeros to add + // to this frame. + PadLength uint8 +} + +// WritePushPromise writes a single PushPromise Frame. +// +// As with Header Frames, This is the low level call for writing +// individual frames. Continuation frames are handled elsewhere. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WritePushPromise(p http2PushPromiseParam) error { + if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if p.PadLength != 0 { + flags |= http2FlagPushPromisePadded + } + if p.EndHeaders { + flags |= http2FlagPushPromiseEndHeaders + } + f.startWrite(http2FramePushPromise, flags, p.StreamID) + if p.PadLength != 0 { + f.writeByte(p.PadLength) + } + if !http2validStreamID(p.PromiseID) && !f.AllowIllegalWrites { + return http2errStreamID + } + f.writeUint32(p.PromiseID) + f.wbuf = append(f.wbuf, p.BlockFragment...) + f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...) + return f.endWrite() +} + +// WriteRawFrame writes a raw frame. This can be used to write +// extension frames unknown to this package. +func (f *http2Framer) WriteRawFrame(t http2FrameType, flags http2Flags, streamID uint32, payload []byte) error { + f.startWrite(t, flags, streamID) + f.writeBytes(payload) + return f.endWrite() +} + +func http2readByte(p []byte) (remain []byte, b byte, err error) { + if len(p) == 0 { + return nil, 0, io.ErrUnexpectedEOF + } + return p[1:], p[0], nil +} + +func http2readUint32(p []byte) (remain []byte, v uint32, err error) { + if len(p) < 4 { + return nil, 0, io.ErrUnexpectedEOF + } + return p[4:], binary.BigEndian.Uint32(p[:4]), nil +} + +type http2streamEnder interface { + StreamEnded() bool +} + +type http2headersEnder interface { + HeadersEnded() bool +} + +type http2headersOrContinuation interface { + http2headersEnder + HeaderBlockFragment() []byte +} + +// A MetaHeadersFrame is the representation of one HEADERS frame and +// zero or more contiguous CONTINUATION frames and the decoding of +// their HPACK-encoded contents. +// +// This type of frame does not appear on the wire and is only returned +// by the Framer when Framer.ReadMetaHeaders is set. +type http2MetaHeadersFrame struct { + *http2HeadersFrame + + // Fields are the fields contained in the HEADERS and + // CONTINUATION frames. The underlying slice is owned by the + // Framer and must not be retained after the next call to + // ReadFrame. + // + // Fields are guaranteed to be in the correct http2 order and + // not have unknown pseudo header fields or invalid header + // field names or values. Required pseudo header fields may be + // missing, however. Use the MetaHeadersFrame.Pseudo accessor + // method access pseudo headers. + Fields []hpack.HeaderField + + // Truncated is whether the max header list size limit was hit + // and Fields is incomplete. The hpack decoder state is still + // valid, however. + Truncated bool +} + +// PseudoValue returns the given pseudo header field's value. +// The provided pseudo field should not contain the leading colon. +func (mh *http2MetaHeadersFrame) PseudoValue(pseudo string) string { + for _, hf := range mh.Fields { + if !hf.IsPseudo() { + return "" + } + if hf.Name[1:] == pseudo { + return hf.Value + } + } + return "" +} + +// RegularFields returns the regular (non-pseudo) header fields of mh. +// The caller does not own the returned slice. +func (mh *http2MetaHeadersFrame) RegularFields() []hpack.HeaderField { + for i, hf := range mh.Fields { + if !hf.IsPseudo() { + return mh.Fields[i:] + } + } + return nil +} + +// PseudoFields returns the pseudo header fields of mh. +// The caller does not own the returned slice. +func (mh *http2MetaHeadersFrame) PseudoFields() []hpack.HeaderField { + for i, hf := range mh.Fields { + if !hf.IsPseudo() { + return mh.Fields[:i] + } + } + return mh.Fields +} + +func (mh *http2MetaHeadersFrame) checkPseudos() error { + var isRequest, isResponse bool + pf := mh.PseudoFields() + for i, hf := range pf { + switch hf.Name { + case ":method", ":path", ":scheme", ":authority": + isRequest = true + case ":status": + isResponse = true + default: + return http2pseudoHeaderError(hf.Name) + } + // Check for duplicates. + // This would be a bad algorithm, but N is 4. + // And this doesn't allocate. + for _, hf2 := range pf[:i] { + if hf.Name == hf2.Name { + return http2duplicatePseudoHeaderError(hf.Name) + } + } + } + if isRequest && isResponse { + return http2errMixPseudoHeaderTypes + } + return nil +} + +func (fr *http2Framer) maxHeaderStringLen() int { + v := fr.maxHeaderListSize() + if uint32(int(v)) == v { + return int(v) + } + // They had a crazy big number for MaxHeaderBytes anyway, + // so give them unlimited header lengths: + return 0 +} + +// readMetaFrame returns 0 or more CONTINUATION frames from fr and +// merge them into the provided hf and returns a MetaHeadersFrame +// with the decoded hpack values. +func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (*http2MetaHeadersFrame, error) { + if fr.AllowIllegalReads { + return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders") + } + mh := &http2MetaHeadersFrame{ + http2HeadersFrame: hf, + } + var remainSize = fr.maxHeaderListSize() + var sawRegular bool + + var invalid error // pseudo header field errors + hdec := fr.ReadMetaHeaders + hdec.SetEmitEnabled(true) + hdec.SetMaxStringLength(fr.maxHeaderStringLen()) + rawEmitFunc := func(hf hpack.HeaderField) { + if http2VerboseLogs && fr.logReads { + fr.debugReadLoggerf("http2: decoded hpack field %+v", hf) + } + if !httpguts.ValidHeaderFieldValue(hf.Value) { + invalid = http2headerFieldValueError(hf.Value) + } + isPseudo := strings.HasPrefix(hf.Name, ":") + if isPseudo { + if sawRegular { + invalid = http2errPseudoAfterRegular + } + } else { + sawRegular = true + if !http2validWireHeaderFieldName(hf.Name) { + invalid = http2headerFieldNameError(hf.Name) + } + } + + if invalid != nil { + hdec.SetEmitEnabled(false) + return + } + + size := hf.Size() + if size > remainSize { + hdec.SetEmitEnabled(false) + mh.Truncated = true + return + } + remainSize -= size + + mh.Fields = append(mh.Fields, hf) + } + emitFunc := rawEmitFunc + + if len(dumps) > 0 { + emitFunc = func(hf hpack.HeaderField) { + for _, dump := range dumps { + dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) + } + rawEmitFunc(hf) + } + } + + hdec.SetEmitFunc(emitFunc) + // Lose reference to MetaHeadersFrame: + defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {}) + + var hc http2headersOrContinuation = hf + for { + frag := hc.HeaderBlockFragment() + if _, err := hdec.Write(frag); err != nil { + return nil, http2ConnectionError(http2ErrCodeCompression) + } + + if hc.HeadersEnded() { + break + } + if f, err := fr.ReadFrame(); err != nil { + return nil, err + } else { + hc = f.(*http2ContinuationFrame) // guaranteed by checkFrameOrder + } + } + + mh.http2HeadersFrame.headerFragBuf = nil + mh.http2HeadersFrame.invalidate() + + if err := hdec.Close(); err != nil { + return nil, http2ConnectionError(http2ErrCodeCompression) + } + if invalid != nil { + fr.errDetail = invalid + if http2VerboseLogs { + log.Printf("http2: invalid header: %v", invalid) + } + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, invalid} + } + if err := mh.checkPseudos(); err != nil { + fr.errDetail = err + if http2VerboseLogs { + log.Printf("http2: invalid pseudo headers: %v", err) + } + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, err} + } + return mh, nil +} + +func http2summarizeFrame(f http2Frame) string { + var buf bytes.Buffer + f.Header().writeDebug(&buf) + switch f := f.(type) { + case *http2SettingsFrame: + n := 0 + f.ForeachSetting(func(s http2Setting) error { + n++ + if n == 1 { + buf.WriteString(", settings:") + } + fmt.Fprintf(&buf, " %v=%v,", s.ID, s.Val) + return nil + }) + if n > 0 { + buf.Truncate(buf.Len() - 1) // remove trailing comma + } + case *http2DataFrame: + data := f.Data() + const max = 256 + if len(data) > max { + data = data[:max] + } + fmt.Fprintf(&buf, " data=%q", data) + if len(f.Data()) > max { + fmt.Fprintf(&buf, " (%d bytes omitted)", len(f.Data())-max) + } + case *http2WindowUpdateFrame: + if f.StreamID == 0 { + buf.WriteString(" (conn)") + } + fmt.Fprintf(&buf, " incr=%v", f.Increment) + case *http2PingFrame: + fmt.Fprintf(&buf, " ping=%q", f.Data[:]) + case *http2GoAwayFrame: + fmt.Fprintf(&buf, " LastStreamID=%v ErrCode=%v Debug=%q", + f.LastStreamID, f.ErrCode, f.debugData) + case *http2RSTStreamFrame: + fmt.Fprintf(&buf, " ErrCode=%v", f.ErrCode) + } + return buf.String() +} From 9a94b2f4638aa32b2d6d3b86d57c8f566c7d2494 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 20:34:32 +0800 Subject: [PATCH 325/843] add h2_frame_test.go --- h2_frame.go | 5 +- h2_frame_test.go | 1284 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1288 insertions(+), 1 deletion(-) create mode 100644 h2_frame_test.go diff --git a/h2_frame.go b/h2_frame.go index ed7c1f4c..4a9ebf85 100644 --- a/h2_frame.go +++ b/h2_frame.go @@ -523,7 +523,10 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) } if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { - dumps := getDumpers(fr.cc.t.t1.dump, fr.cc.currentRequest.Context()) + var dumps []*dumper + if fr.cc != nil { + dumps = getDumpers(fr.cc.t.t1.dump, fr.cc.currentRequest.Context()) + } if len(dumps) > 0 { dd := []*dumper{} for _, dump := range dumps { diff --git a/h2_frame_test.go b/h2_frame_test.go new file mode 100644 index 00000000..8a8d2a98 --- /dev/null +++ b/h2_frame_test.go @@ -0,0 +1,1284 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "bytes" + "fmt" + "io" + "reflect" + "strings" + "testing" + "unsafe" + + "golang.org/x/net/http2/hpack" +) + +func testFramer() (*http2Framer, *bytes.Buffer) { + buf := new(bytes.Buffer) + return http2NewFramer(buf, buf), buf +} + +func TestFrameSizes(t *testing.T) { + // Catch people rearranging the FrameHeader fields. + if got, want := int(unsafe.Sizeof(http2FrameHeader{})), 12; got != want { + t.Errorf("FrameHeader size = %d; want %d", got, want) + } +} + +func TestFrameTypeString(t *testing.T) { + tests := []struct { + ft http2FrameType + want string + }{ + {http2FrameData, "DATA"}, + {http2FramePing, "PING"}, + {http2FrameGoAway, "GOAWAY"}, + {0xf, "UNKNOWN_FRAME_TYPE_15"}, + } + + for i, tt := range tests { + got := tt.ft.String() + if got != tt.want { + t.Errorf("%d. String(FrameType %d) = %q; want %q", i, int(tt.ft), got, tt.want) + } + } +} + +func TestWriteRST(t *testing.T) { + fr, buf := testFramer() + var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4 + var errCode uint32 = 7<<24 + 6<<16 + 5<<8 + 4 + fr.WriteRSTStream(streamID, http2ErrCode(errCode)) + const wantEnc = "\x00\x00\x04\x03\x00\x01\x02\x03\x04\x07\x06\x05\x04" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + want := &http2RSTStreamFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + Type: 0x3, + Flags: 0x0, + Length: 0x4, + StreamID: 0x1020304, + }, + ErrCode: 0x7060504, + } + if !reflect.DeepEqual(f, want) { + t.Errorf("parsed back %#v; want %#v", f, want) + } +} + +func TestWriteData(t *testing.T) { + fr, buf := testFramer() + var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4 + data := []byte("ABC") + fr.WriteData(streamID, true, data) + const wantEnc = "\x00\x00\x03\x00\x01\x01\x02\x03\x04ABC" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + df, ok := f.(*http2DataFrame) + if !ok { + t.Fatalf("got %T; want *http2DataFrame", f) + } + if !bytes.Equal(df.Data(), data) { + t.Errorf("got %q; want %q", df.Data(), data) + } + if f.Header().Flags&1 == 0 { + t.Errorf("didn't see END_STREAM flag") + } +} + +func TestWriteDataPadded(t *testing.T) { + tests := [...]struct { + streamID uint32 + endStream bool + data []byte + pad []byte + wantHeader http2FrameHeader + }{ + // Unpadded: + 0: { + streamID: 1, + endStream: true, + data: []byte("foo"), + pad: nil, + wantHeader: http2FrameHeader{ + Type: http2FrameData, + Flags: http2FlagDataEndStream, + Length: 3, + StreamID: 1, + }, + }, + + // Padded bit set, but no padding: + 1: { + streamID: 1, + endStream: true, + data: []byte("foo"), + pad: []byte{}, + wantHeader: http2FrameHeader{ + Type: http2FrameData, + Flags: http2FlagDataEndStream | http2FlagDataPadded, + Length: 4, + StreamID: 1, + }, + }, + + // Padded bit set, with padding: + 2: { + streamID: 1, + endStream: false, + data: []byte("foo"), + pad: []byte{0, 0, 0}, + wantHeader: http2FrameHeader{ + Type: http2FrameData, + Flags: http2FlagDataPadded, + Length: 7, + StreamID: 1, + }, + }, + } + for i, tt := range tests { + fr, _ := testFramer() + fr.WriteDataPadded(tt.streamID, tt.endStream, tt.data, tt.pad) + f, err := fr.ReadFrame() + if err != nil { + t.Errorf("%d. ReadFrame: %v", i, err) + continue + } + got := f.Header() + tt.wantHeader.valid = true + if !got.Equal(tt.wantHeader) { + t.Errorf("%d. read %+v; want %+v", i, got, tt.wantHeader) + continue + } + df := f.(*http2DataFrame) + if !bytes.Equal(df.Data(), tt.data) { + t.Errorf("%d. got %q; want %q", i, df.Data(), tt.data) + } + } +} + +func (fh http2FrameHeader) Equal(b http2FrameHeader) bool { + return fh.valid == b.valid && + fh.Type == b.Type && + fh.Flags == b.Flags && + fh.Length == b.Length && + fh.StreamID == b.StreamID +} + +func TestWriteHeaders(t *testing.T) { + tests := []struct { + name string + p http2HeadersFrameParam + wantEnc string + wantFrame *http2HeadersFrame + }{ + { + "basic", + http2HeadersFrameParam{ + StreamID: 42, + BlockFragment: []byte("abc"), + Priority: http2PriorityParam{}, + }, + "\x00\x00\x03\x01\x00\x00\x00\x00*abc", + &http2HeadersFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + StreamID: 42, + Type: http2FrameHeaders, + Length: uint32(len("abc")), + }, + Priority: http2PriorityParam{}, + headerFragBuf: []byte("abc"), + }, + }, + { + "basic + end flags", + http2HeadersFrameParam{ + StreamID: 42, + BlockFragment: []byte("abc"), + EndStream: true, + EndHeaders: true, + Priority: http2PriorityParam{}, + }, + "\x00\x00\x03\x01\x05\x00\x00\x00*abc", + &http2HeadersFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + StreamID: 42, + Type: http2FrameHeaders, + Flags: http2FlagHeadersEndStream | http2FlagHeadersEndHeaders, + Length: uint32(len("abc")), + }, + Priority: http2PriorityParam{}, + headerFragBuf: []byte("abc"), + }, + }, + { + "with padding", + http2HeadersFrameParam{ + StreamID: 42, + BlockFragment: []byte("abc"), + EndStream: true, + EndHeaders: true, + PadLength: 5, + Priority: http2PriorityParam{}, + }, + "\x00\x00\t\x01\r\x00\x00\x00*\x05abc\x00\x00\x00\x00\x00", + &http2HeadersFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + StreamID: 42, + Type: http2FrameHeaders, + Flags: http2FlagHeadersEndStream | http2FlagHeadersEndHeaders | http2FlagHeadersPadded, + Length: uint32(1 + len("abc") + 5), // pad length + contents + padding + }, + Priority: http2PriorityParam{}, + headerFragBuf: []byte("abc"), + }, + }, + { + "with priority", + http2HeadersFrameParam{ + StreamID: 42, + BlockFragment: []byte("abc"), + EndStream: true, + EndHeaders: true, + PadLength: 2, + Priority: http2PriorityParam{ + StreamDep: 15, + Exclusive: true, + Weight: 127, + }, + }, + "\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x0f\u007fabc\x00\x00", + &http2HeadersFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + StreamID: 42, + Type: http2FrameHeaders, + Flags: http2FlagHeadersEndStream | http2FlagHeadersEndHeaders | http2FlagHeadersPadded | http2FlagHeadersPriority, + Length: uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding + }, + Priority: http2PriorityParam{ + StreamDep: 15, + Exclusive: true, + Weight: 127, + }, + headerFragBuf: []byte("abc"), + }, + }, + { + "with priority stream dep zero", // golang.org/issue/15444 + http2HeadersFrameParam{ + StreamID: 42, + BlockFragment: []byte("abc"), + EndStream: true, + EndHeaders: true, + PadLength: 2, + Priority: http2PriorityParam{ + StreamDep: 0, + Exclusive: true, + Weight: 127, + }, + }, + "\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x00\u007fabc\x00\x00", + &http2HeadersFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + StreamID: 42, + Type: http2FrameHeaders, + Flags: http2FlagHeadersEndStream | http2FlagHeadersEndHeaders | http2FlagHeadersPadded | http2FlagHeadersPriority, + Length: uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding + }, + Priority: http2PriorityParam{ + StreamDep: 0, + Exclusive: true, + Weight: 127, + }, + headerFragBuf: []byte("abc"), + }, + }, + { + "zero length", + http2HeadersFrameParam{ + StreamID: 42, + Priority: http2PriorityParam{}, + }, + "\x00\x00\x00\x01\x00\x00\x00\x00*", + &http2HeadersFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + StreamID: 42, + Type: http2FrameHeaders, + Length: 0, + }, + Priority: http2PriorityParam{}, + }, + }, + } + for _, tt := range tests { + fr, buf := testFramer() + if err := fr.WriteHeaders(tt.p); err != nil { + t.Errorf("test %q: %v", tt.name, err) + continue + } + if buf.String() != tt.wantEnc { + t.Errorf("test %q: encoded %q; want %q", tt.name, buf.Bytes(), tt.wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Errorf("test %q: failed to read the frame back: %v", tt.name, err) + continue + } + if !reflect.DeepEqual(f, tt.wantFrame) { + t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame) + } + } +} + +func TestWriteInvalidStreamDep(t *testing.T) { + fr, _ := testFramer() + err := fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: 42, + Priority: http2PriorityParam{ + StreamDep: 1 << 31, + }, + }) + if err != http2errDepStreamID { + t.Errorf("header error = %v; want %q", err, http2errDepStreamID) + } + + err = fr.WritePriority(2, http2PriorityParam{StreamDep: 1 << 31}) + if err != http2errDepStreamID { + t.Errorf("priority error = %v; want %q", err, http2errDepStreamID) + } +} + +func TestWriteContinuation(t *testing.T) { + const streamID = 42 + tests := []struct { + name string + end bool + frag []byte + + wantFrame *http2ContinuationFrame + }{ + { + "not end", + false, + []byte("abc"), + &http2ContinuationFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + StreamID: streamID, + Type: http2FrameContinuation, + Length: uint32(len("abc")), + }, + headerFragBuf: []byte("abc"), + }, + }, + { + "end", + true, + []byte("def"), + &http2ContinuationFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + StreamID: streamID, + Type: http2FrameContinuation, + Flags: http2FlagContinuationEndHeaders, + Length: uint32(len("def")), + }, + headerFragBuf: []byte("def"), + }, + }, + } + for _, tt := range tests { + fr, _ := testFramer() + if err := fr.WriteContinuation(streamID, tt.end, tt.frag); err != nil { + t.Errorf("test %q: %v", tt.name, err) + continue + } + fr.AllowIllegalReads = true + f, err := fr.ReadFrame() + if err != nil { + t.Errorf("test %q: failed to read the frame back: %v", tt.name, err) + continue + } + if !reflect.DeepEqual(f, tt.wantFrame) { + t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame) + } + } +} + +func TestWritePriority(t *testing.T) { + const streamID = 42 + tests := []struct { + name string + priority http2PriorityParam + wantFrame *http2PriorityFrame + }{ + { + "not exclusive", + http2PriorityParam{ + StreamDep: 2, + Exclusive: false, + Weight: 127, + }, + &http2PriorityFrame{ + http2FrameHeader{ + valid: true, + StreamID: streamID, + Type: http2FramePriority, + Length: 5, + }, + http2PriorityParam{ + StreamDep: 2, + Exclusive: false, + Weight: 127, + }, + }, + }, + + { + "exclusive", + http2PriorityParam{ + StreamDep: 3, + Exclusive: true, + Weight: 77, + }, + &http2PriorityFrame{ + http2FrameHeader{ + valid: true, + StreamID: streamID, + Type: http2FramePriority, + Length: 5, + }, + http2PriorityParam{ + StreamDep: 3, + Exclusive: true, + Weight: 77, + }, + }, + }, + } + for _, tt := range tests { + fr, _ := testFramer() + if err := fr.WritePriority(streamID, tt.priority); err != nil { + t.Errorf("test %q: %v", tt.name, err) + continue + } + f, err := fr.ReadFrame() + if err != nil { + t.Errorf("test %q: failed to read the frame back: %v", tt.name, err) + continue + } + if !reflect.DeepEqual(f, tt.wantFrame) { + t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame) + } + } +} + +func TestWriteSettings(t *testing.T) { + fr, buf := testFramer() + settings := []http2Setting{{1, 2}, {3, 4}} + fr.WriteSettings(settings...) + const wantEnc = "\x00\x00\f\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x03\x00\x00\x00\x04" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + sf, ok := f.(*http2SettingsFrame) + if !ok { + t.Fatalf("Got a %T; want a SettingsFrame", f) + } + var got []http2Setting + sf.ForeachSetting(func(s http2Setting) error { + got = append(got, s) + valBack, ok := sf.Value(s.ID) + if !ok || valBack != s.Val { + t.Errorf("Value(%d) = %v, %v; want %v, true", s.ID, valBack, ok, s.Val) + } + return nil + }) + if !reflect.DeepEqual(settings, got) { + t.Errorf("Read settings %+v != written settings %+v", got, settings) + } +} + +func TestWriteSettingsAck(t *testing.T) { + fr, buf := testFramer() + fr.WriteSettingsAck() + const wantEnc = "\x00\x00\x00\x04\x01\x00\x00\x00\x00" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } +} + +func TestWriteWindowUpdate(t *testing.T) { + fr, buf := testFramer() + const streamID = 1<<24 + 2<<16 + 3<<8 + 4 + const incr = 7<<24 + 6<<16 + 5<<8 + 4 + if err := fr.WriteWindowUpdate(streamID, incr); err != nil { + t.Fatal(err) + } + const wantEnc = "\x00\x00\x04\x08\x00\x01\x02\x03\x04\x07\x06\x05\x04" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + want := &http2WindowUpdateFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + Type: 0x8, + Flags: 0x0, + Length: 0x4, + StreamID: 0x1020304, + }, + Increment: 0x7060504, + } + if !reflect.DeepEqual(f, want) { + t.Errorf("parsed back %#v; want %#v", f, want) + } +} + +func TestWritePing(t *testing.T) { testWritePing(t, false) } +func TestWritePingAck(t *testing.T) { testWritePing(t, true) } + +func testWritePing(t *testing.T, ack bool) { + fr, buf := testFramer() + if err := fr.WritePing(ack, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { + t.Fatal(err) + } + var wantFlags http2Flags + if ack { + wantFlags = http2FlagPingAck + } + var wantEnc = "\x00\x00\x08\x06" + string(wantFlags) + "\x00\x00\x00\x00" + "\x01\x02\x03\x04\x05\x06\x07\x08" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + want := &http2PingFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + Type: 0x6, + Flags: wantFlags, + Length: 0x8, + StreamID: 0, + }, + Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, + } + if !reflect.DeepEqual(f, want) { + t.Errorf("parsed back %#v; want %#v", f, want) + } +} + +func TestReadFrameHeader(t *testing.T) { + tests := []struct { + in string + want http2FrameHeader + }{ + {in: "\x00\x00\x00" + "\x00" + "\x00" + "\x00\x00\x00\x00", want: http2FrameHeader{}}, + {in: "\x01\x02\x03" + "\x04" + "\x05" + "\x06\x07\x08\x09", want: http2FrameHeader{ + Length: 66051, Type: 4, Flags: 5, StreamID: 101124105, + }}, + // Ignore high bit: + {in: "\xff\xff\xff" + "\xff" + "\xff" + "\xff\xff\xff\xff", want: http2FrameHeader{ + Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}}, + {in: "\xff\xff\xff" + "\xff" + "\xff" + "\x7f\xff\xff\xff", want: http2FrameHeader{ + Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}}, + } + for i, tt := range tests { + got, err := http2readFrameHeader(make([]byte, 9), strings.NewReader(tt.in)) + if err != nil { + t.Errorf("%d. readFrameHeader(%q) = %v", i, tt.in, err) + continue + } + tt.want.valid = true + if !got.Equal(tt.want) { + t.Errorf("%d. readFrameHeader(%q) = %+v; want %+v", i, tt.in, got, tt.want) + } + } +} + +func TestReadWriteFrameHeader(t *testing.T) { + tests := []struct { + len uint32 + typ http2FrameType + flags http2Flags + streamID uint32 + }{ + {len: 0, typ: 255, flags: 1, streamID: 0}, + {len: 0, typ: 255, flags: 1, streamID: 1}, + {len: 0, typ: 255, flags: 1, streamID: 255}, + {len: 0, typ: 255, flags: 1, streamID: 256}, + {len: 0, typ: 255, flags: 1, streamID: 65535}, + {len: 0, typ: 255, flags: 1, streamID: 65536}, + + {len: 0, typ: 1, flags: 255, streamID: 1}, + {len: 255, typ: 1, flags: 255, streamID: 1}, + {len: 256, typ: 1, flags: 255, streamID: 1}, + {len: 65535, typ: 1, flags: 255, streamID: 1}, + {len: 65536, typ: 1, flags: 255, streamID: 1}, + {len: 16777215, typ: 1, flags: 255, streamID: 1}, + } + for _, tt := range tests { + fr, buf := testFramer() + fr.startWrite(tt.typ, tt.flags, tt.streamID) + fr.writeBytes(make([]byte, tt.len)) + fr.endWrite() + fh, err := http2ReadFrameHeader(buf) + if err != nil { + t.Errorf("ReadFrameHeader(%+v) = %v", tt, err) + continue + } + if fh.Type != tt.typ || fh.Flags != tt.flags || fh.Length != tt.len || fh.StreamID != tt.streamID { + t.Errorf("ReadFrameHeader(%+v) = %+v; mismatch", tt, fh) + } + } + +} + +func TestWriteTooLargeFrame(t *testing.T) { + fr, _ := testFramer() + fr.startWrite(0, 1, 1) + fr.writeBytes(make([]byte, 1<<24)) + err := fr.endWrite() + if err != http2ErrFrameTooLarge { + t.Errorf("endWrite = %v; want errFrameTooLarge", err) + } +} + +func TestWriteGoAway(t *testing.T) { + const debug = "foo" + fr, buf := testFramer() + if err := fr.WriteGoAway(0x01020304, 0x05060708, []byte(debug)); err != nil { + t.Fatal(err) + } + const wantEnc = "\x00\x00\v\a\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08" + debug + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + want := &http2GoAwayFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + Type: 0x7, + Flags: 0, + Length: uint32(4 + 4 + len(debug)), + StreamID: 0, + }, + LastStreamID: 0x01020304, + ErrCode: 0x05060708, + debugData: []byte(debug), + } + if !reflect.DeepEqual(f, want) { + t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want) + } + if got := string(f.(*http2GoAwayFrame).DebugData()); got != debug { + t.Errorf("debug data = %q; want %q", got, debug) + } +} + +func TestWritePushPromise(t *testing.T) { + pp := http2PushPromiseParam{ + StreamID: 42, + PromiseID: 42, + BlockFragment: []byte("abc"), + } + fr, buf := testFramer() + if err := fr.WritePushPromise(pp); err != nil { + t.Fatal(err) + } + const wantEnc = "\x00\x00\x07\x05\x00\x00\x00\x00*\x00\x00\x00*abc" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + _, ok := f.(*http2PushPromiseFrame) + if !ok { + t.Fatalf("got %T; want *PushPromiseFrame", f) + } + want := &http2PushPromiseFrame{ + http2FrameHeader: http2FrameHeader{ + valid: true, + Type: 0x5, + Flags: 0x0, + Length: 0x7, + StreamID: 42, + }, + PromiseID: 42, + headerFragBuf: []byte("abc"), + } + if !reflect.DeepEqual(f, want) { + t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want) + } +} + +// test checkFrameOrder and that HEADERS and CONTINUATION frames can't be intermingled. +func TestReadFrameOrder(t *testing.T) { + head := func(f *http2Framer, id uint32, end bool) { + f.WriteHeaders(http2HeadersFrameParam{ + StreamID: id, + BlockFragment: []byte("foo"), // unused, but non-empty + EndHeaders: end, + }) + } + cont := func(f *http2Framer, id uint32, end bool) { + f.WriteContinuation(id, end, []byte("foo")) + } + + tests := [...]struct { + name string + w func(*http2Framer) + atLeast int + wantErr string + }{ + 0: { + w: func(f *http2Framer) { + head(f, 1, true) + }, + }, + 1: { + w: func(f *http2Framer) { + head(f, 1, true) + head(f, 2, true) + }, + }, + 2: { + wantErr: "got HEADERS for stream 2; expected CONTINUATION following HEADERS for stream 1", + w: func(f *http2Framer) { + head(f, 1, false) + head(f, 2, true) + }, + }, + 3: { + wantErr: "got DATA for stream 1; expected CONTINUATION following HEADERS for stream 1", + w: func(f *http2Framer) { + head(f, 1, false) + }, + }, + 4: { + w: func(f *http2Framer) { + head(f, 1, false) + cont(f, 1, true) + head(f, 2, true) + }, + }, + 5: { + wantErr: "got CONTINUATION for stream 2; expected stream 1", + w: func(f *http2Framer) { + head(f, 1, false) + cont(f, 2, true) + head(f, 2, true) + }, + }, + 6: { + wantErr: "unexpected CONTINUATION for stream 1", + w: func(f *http2Framer) { + cont(f, 1, true) + }, + }, + 7: { + wantErr: "unexpected CONTINUATION for stream 1", + w: func(f *http2Framer) { + cont(f, 1, false) + }, + }, + 8: { + wantErr: "HEADERS frame with stream ID 0", + w: func(f *http2Framer) { + head(f, 0, true) + }, + }, + 9: { + wantErr: "CONTINUATION frame with stream ID 0", + w: func(f *http2Framer) { + cont(f, 0, true) + }, + }, + 10: { + wantErr: "unexpected CONTINUATION for stream 1", + atLeast: 5, + w: func(f *http2Framer) { + head(f, 1, false) + cont(f, 1, false) + cont(f, 1, false) + cont(f, 1, false) + cont(f, 1, true) + cont(f, 1, false) + }, + }, + } + for i, tt := range tests { + buf := new(bytes.Buffer) + f := http2NewFramer(buf, buf) + f.AllowIllegalWrites = true + tt.w(f) + f.WriteData(1, true, nil) // to test transition away from last step + + var err error + n := 0 + var log bytes.Buffer + for { + var got http2Frame + got, err = f.ReadFrame() + fmt.Fprintf(&log, " read %v, %v\n", got, err) + if err != nil { + break + } + n++ + } + if err == io.EOF { + err = nil + } + ok := tt.wantErr == "" + if ok && err != nil { + t.Errorf("%d. after %d good frames, ReadFrame = %v; want success\n%s", i, n, err, log.Bytes()) + continue + } + if !ok && err != http2ConnectionError(http2ErrCodeProtocol) { + t.Errorf("%d. after %d good frames, ReadFrame = %v; want ConnectionError(ErrCodeProtocol)\n%s", i, n, err, log.Bytes()) + continue + } + if !((f.errDetail == nil && tt.wantErr == "") || (fmt.Sprint(f.errDetail) == tt.wantErr)) { + t.Errorf("%d. framer eror = %q; want %q\n%s", i, f.errDetail, tt.wantErr, log.Bytes()) + } + if n < tt.atLeast { + t.Errorf("%d. framer only read %d frames; want at least %d\n%s", i, n, tt.atLeast, log.Bytes()) + } + } +} + +type hpackEncoder struct { + enc *hpack.Encoder + buf bytes.Buffer +} + +func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte { + if len(headers)%2 == 1 { + panic("odd number of kv args") + } + he.buf.Reset() + if he.enc == nil { + he.enc = hpack.NewEncoder(&he.buf) + } + for len(headers) > 0 { + k, v := headers[0], headers[1] + err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v}) + if err != nil { + t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) + } + headers = headers[2:] + } + return he.buf.Bytes() +} + +func TestMetaFrameHeader(t *testing.T) { + write := func(f *http2Framer, frags ...[]byte) { + for i, frag := range frags { + end := (i == len(frags)-1) + if i == 0 { + f.WriteHeaders(http2HeadersFrameParam{ + StreamID: 1, + BlockFragment: frag, + EndHeaders: end, + }) + } else { + f.WriteContinuation(1, end, frag) + } + } + } + + want := func(flags http2Flags, length uint32, pairs ...string) *http2MetaHeadersFrame { + mh := &http2MetaHeadersFrame{ + http2HeadersFrame: &http2HeadersFrame{ + http2FrameHeader: http2FrameHeader{ + Type: http2FrameHeaders, + Flags: flags, + Length: length, + StreamID: 1, + }, + }, + Fields: []hpack.HeaderField(nil), + } + for len(pairs) > 0 { + mh.Fields = append(mh.Fields, hpack.HeaderField{ + Name: pairs[0], + Value: pairs[1], + }) + pairs = pairs[2:] + } + return mh + } + truncated := func(mh *http2MetaHeadersFrame) *http2MetaHeadersFrame { + mh.Truncated = true + return mh + } + + const noFlags http2Flags = 0 + + oneKBString := strings.Repeat("a", 1<<10) + + tests := [...]struct { + name string + w func(*http2Framer) + want interface{} // *MetaHeaderFrame or error + wantErrReason string + maxHeaderListSize uint32 + }{ + 0: { + name: "single_headers", + w: func(f *http2Framer) { + var he hpackEncoder + all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/") + write(f, all) + }, + want: want(http2FlagHeadersEndHeaders, 2, ":method", "GET", ":path", "/"), + }, + 1: { + name: "with_continuation", + w: func(f *http2Framer) { + var he hpackEncoder + all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar") + write(f, all[:1], all[1:]) + }, + want: want(noFlags, 1, ":method", "GET", ":path", "/", "foo", "bar"), + }, + 2: { + name: "with_two_continuation", + w: func(f *http2Framer) { + var he hpackEncoder + all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar") + write(f, all[:2], all[2:4], all[4:]) + }, + want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", "bar"), + }, + 3: { + name: "big_string_okay", + w: func(f *http2Framer) { + var he hpackEncoder + all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString) + write(f, all[:2], all[2:]) + }, + want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", oneKBString), + }, + 4: { + name: "big_string_error", + w: func(f *http2Framer) { + var he hpackEncoder + all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString) + write(f, all[:2], all[2:]) + }, + maxHeaderListSize: (1 << 10) / 2, + want: http2ConnectionError(http2ErrCodeCompression), + }, + 5: { + name: "max_header_list_truncated", + w: func(f *http2Framer) { + var he hpackEncoder + var pairs = []string{":method", "GET", ":path", "/"} + for i := 0; i < 100; i++ { + pairs = append(pairs, "foo", "bar") + } + all := he.encodeHeaderRaw(t, pairs...) + write(f, all[:2], all[2:]) + }, + maxHeaderListSize: (1 << 10) / 2, + want: truncated(want(noFlags, 2, + ":method", "GET", + ":path", "/", + "foo", "bar", + "foo", "bar", + "foo", "bar", + "foo", "bar", + "foo", "bar", + "foo", "bar", + "foo", "bar", + "foo", "bar", + "foo", "bar", + "foo", "bar", + "foo", "bar", // 11 + )), + }, + 6: { + name: "pseudo_order", + w: func(f *http2Framer) { + write(f, encodeHeaderRaw(t, + ":method", "GET", + "foo", "bar", + ":path", "/", // bogus + )) + }, + want: http2streamError(1, http2ErrCodeProtocol), + wantErrReason: "pseudo header field after regular", + }, + 7: { + name: "pseudo_unknown", + w: func(f *http2Framer) { + write(f, encodeHeaderRaw(t, + ":unknown", "foo", // bogus + "foo", "bar", + )) + }, + want: http2streamError(1, http2ErrCodeProtocol), + wantErrReason: "invalid pseudo-header \":unknown\"", + }, + 8: { + name: "pseudo_mix_request_response", + w: func(f *http2Framer) { + write(f, encodeHeaderRaw(t, + ":method", "GET", + ":status", "100", + )) + }, + want: http2streamError(1, http2ErrCodeProtocol), + wantErrReason: "mix of request and response pseudo headers", + }, + 9: { + name: "pseudo_dup", + w: func(f *http2Framer) { + write(f, encodeHeaderRaw(t, + ":method", "GET", + ":method", "POST", + )) + }, + want: http2streamError(1, http2ErrCodeProtocol), + wantErrReason: "duplicate pseudo-header \":method\"", + }, + 10: { + name: "trailer_okay_no_pseudo", + w: func(f *http2Framer) { write(f, encodeHeaderRaw(t, "foo", "bar")) }, + want: want(http2FlagHeadersEndHeaders, 8, "foo", "bar"), + }, + 11: { + name: "invalid_field_name", + w: func(f *http2Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) }, + want: http2streamError(1, http2ErrCodeProtocol), + wantErrReason: "invalid header field name \"CapitalBad\"", + }, + 12: { + name: "invalid_field_value", + w: func(f *http2Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) }, + want: http2streamError(1, http2ErrCodeProtocol), + wantErrReason: "invalid header field value \"bad_null\\x00\"", + }, + } + for i, tt := range tests { + buf := new(bytes.Buffer) + f := http2NewFramer(buf, buf) + f.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + f.MaxHeaderListSize = tt.maxHeaderListSize + tt.w(f) + + name := tt.name + if name == "" { + name = fmt.Sprintf("test index %d", i) + } + + var got interface{} + var err error + got, err = f.ReadFrame() + if err != nil { + got = err + + // Ignore the StreamError.Cause field, if it matches the wantErrReason. + // The test table above predates the Cause field. + if se, ok := err.(http2StreamError); ok && se.Cause != nil && se.Cause.Error() == tt.wantErrReason { + se.Cause = nil + got = se + } + } + if !reflect.DeepEqual(got, tt.want) { + if mhg, ok := got.(*http2MetaHeadersFrame); ok { + if mhw, ok := tt.want.(*http2MetaHeadersFrame); ok { + hg := mhg.http2HeadersFrame + hw := mhw.http2HeadersFrame + if hg != nil && hw != nil && !reflect.DeepEqual(*hg, *hw) { + t.Errorf("%s: headers differ:\n got: %+v\nwant: %+v\n", name, *hg, *hw) + } + } + } + str := func(v interface{}) string { + if _, ok := v.(error); ok { + return fmt.Sprintf("error %v", v) + } else { + return fmt.Sprintf("value %#v", v) + } + } + t.Errorf("%s:\n got: %v\nwant: %s", name, str(got), str(tt.want)) + } + if tt.wantErrReason != "" && tt.wantErrReason != fmt.Sprint(f.errDetail) { + t.Errorf("%s: got error reason %q; want %q", name, f.errDetail, tt.wantErrReason) + } + } +} + +func TestSetReuseFrames(t *testing.T) { + fr, buf := testFramer() + fr.SetReuseFrames() + + // Check that DataFrames are reused. Note that + // SetReuseFrames only currently implements reuse of DataFrames. + firstDf := readAndVerifyDataFrame("ABC", 3, fr, buf, t) + + for i := 0; i < 10; i++ { + df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t) + if df != firstDf { + t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf) + } + } + + for i := 0; i < 10; i++ { + df := readAndVerifyDataFrame("", 0, fr, buf, t) + if df != firstDf { + t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf) + } + } + + for i := 0; i < 10; i++ { + df := readAndVerifyDataFrame("HHH", 3, fr, buf, t) + if df != firstDf { + t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf) + } + } +} + +func TestSetReuseFramesMoreThanOnce(t *testing.T) { + fr, buf := testFramer() + fr.SetReuseFrames() + + firstDf := readAndVerifyDataFrame("ABC", 3, fr, buf, t) + fr.SetReuseFrames() + + for i := 0; i < 10; i++ { + df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t) + // SetReuseFrames should be idempotent + fr.SetReuseFrames() + if df != firstDf { + t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf) + } + } +} + +func TestNoSetReuseFrames(t *testing.T) { + fr, buf := testFramer() + const numNewDataFrames = 10 + dfSoFar := make([]interface{}, numNewDataFrames) + + // Check that DataFrames are not reused if SetReuseFrames wasn't called. + // SetReuseFrames only currently implements reuse of DataFrames. + for i := 0; i < numNewDataFrames; i++ { + df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t) + for _, item := range dfSoFar { + if df == item { + t.Errorf("Expected Framer to return new DataFrames since SetNoReuseFrames not set.") + } + } + dfSoFar[i] = df + } +} + +func readAndVerifyDataFrame(data string, length byte, fr *http2Framer, buf *bytes.Buffer, t *testing.T) *http2DataFrame { + var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4 + fr.WriteData(streamID, true, []byte(data)) + wantEnc := "\x00\x00" + string(length) + "\x00\x01\x01\x02\x03\x04" + data + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + df, ok := f.(*http2DataFrame) + if !ok { + t.Fatalf("got %T; want *http2DataFrame", f) + } + if !bytes.Equal(df.Data(), []byte(data)) { + t.Errorf("got %q; want %q", df.Data(), []byte(data)) + } + if f.Header().Flags&1 == 0 { + t.Errorf("didn't see END_STREAM flag") + } + return df +} + +func encodeHeaderRaw(t *testing.T, pairs ...string) []byte { + var he hpackEncoder + return he.encodeHeaderRaw(t, pairs...) +} + +func TestSettingsDuplicates(t *testing.T) { + tests := []struct { + settings []http2Setting + want bool + }{ + {nil, false}, + {[]http2Setting{{ID: 1}}, false}, + {[]http2Setting{{ID: 1}, {ID: 2}}, false}, + {[]http2Setting{{ID: 1}, {ID: 2}}, false}, + {[]http2Setting{{ID: 1}, {ID: 2}, {ID: 3}}, false}, + {[]http2Setting{{ID: 1}, {ID: 2}, {ID: 3}}, false}, + {[]http2Setting{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}}, false}, + + {[]http2Setting{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 2}}, true}, + {[]http2Setting{{ID: 4}, {ID: 2}, {ID: 3}, {ID: 4}}, true}, + + {[]http2Setting{ + {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}, + {ID: 5}, {ID: 6}, {ID: 7}, {ID: 8}, + {ID: 9}, {ID: 10}, {ID: 11}, {ID: 12}, + }, false}, + + {[]http2Setting{ + {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}, + {ID: 5}, {ID: 6}, {ID: 7}, {ID: 8}, + {ID: 9}, {ID: 10}, {ID: 11}, {ID: 11}, + }, true}, + } + for i, tt := range tests { + fr, _ := testFramer() + fr.WriteSettings(tt.settings...) + f, err := fr.ReadFrame() + if err != nil { + t.Fatalf("%d. ReadFrame: %v", i, err) + } + sf := f.(*http2SettingsFrame) + got := sf.HasDuplicates() + if got != tt.want { + t.Errorf("%d. HasDuplicates = %v; want %v", i, got, tt.want) + } + } + +} From 633099ae14f413da105aedbef087e5fb294e379f Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 20:45:55 +0800 Subject: [PATCH 326/843] extract h2_trace.go --- h2_bundle.go | 17 ----------------- h2_trace.go | 27 +++++++++++++++++++++++++++ textproto_reader.go | 6 +++--- 3 files changed, 30 insertions(+), 20 deletions(-) create mode 100644 h2_trace.go diff --git a/h2_bundle.go b/h2_bundle.go index 584ed1c7..4ba1dd9e 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -52,23 +52,6 @@ import ( "golang.org/x/net/idna" ) -func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { - return trace != nil && trace.WroteHeaderField != nil -} - -func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) { - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField(k, []string{v}) - } -} - -func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error { - if trace != nil { - return trace.Got1xxResponse - } - return nil -} - // dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS // connection. func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (TLSConn, error) { diff --git a/h2_trace.go b/h2_trace.go new file mode 100644 index 00000000..3f32f194 --- /dev/null +++ b/h2_trace.go @@ -0,0 +1,27 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "net/http/httptrace" + "net/textproto" +) + +func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { + return trace != nil && trace.WroteHeaderField != nil +} + +func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) { + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField(k, []string{v}) + } +} + +func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error { + if trace != nil { + return trace.Got1xxResponse + } + return nil +} diff --git a/textproto_reader.go b/textproto_reader.go index f0982ead..99e40019 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -531,7 +531,7 @@ var colon = []byte(":") // ReadMIMEHeader reads a MIME-style header from r. // The header is a sequence of possibly continued Key: Value lines // ending in a blank line. -// The returned map m maps Canonicaltextproto.MIMEHeaderKey(key) to a +// The returned map m maps CanonicalMIMEHeaderKey(key) to a // sequence of values in the same order encountered in the input. // // For example, consider this input: @@ -541,7 +541,7 @@ var colon = []byte(":") // Longer Value // My-Key: Value 2 // -// Given that input, Readtextproto.MIMEHeader returns the map: +// Given that input, ReadMIMEHeader returns the map: // // map[string][]string{ // "My-Key": {"Value 1", "Value 2"}, @@ -682,7 +682,7 @@ func validHeaderFieldByte(b byte) bool { return int(b) < len(isTokenTable) && isTokenTable[b] } -// canonicalMIMEHeaderKey is like Canonicaltextproto.MIMEHeaderKey but is +// canonicalMIMEHeaderKey is like CanonicalMIMEHeaderKey but is // allowed to mutate the provided byte slice before returning the // string. // From a4702ec7b06ad312aff424f9fb6dd82a50a69083 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 20:47:44 +0800 Subject: [PATCH 327/843] update h2_trace.go --- h2_bundle.go | 55 -------------------------------------------------- h2_trace.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 55 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index 4ba1dd9e..c2a66db0 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -3649,61 +3649,6 @@ func (t *http2Transport) idleConnTimeout() time.Duration { return 0 } -func http2traceGetConn(req *http.Request, hostPort string) { - trace := httptrace.ContextClientTrace(req.Context()) - if trace == nil || trace.GetConn == nil { - return - } - trace.GetConn(hostPort) -} - -func http2traceGotConn(req *http.Request, cc *http2ClientConn, reused bool) { - trace := httptrace.ContextClientTrace(req.Context()) - if trace == nil || trace.GotConn == nil { - return - } - ci := httptrace.GotConnInfo{Conn: cc.tconn} - ci.Reused = reused - cc.mu.Lock() - ci.WasIdle = len(cc.streams) == 0 && reused - if ci.WasIdle && !cc.lastActive.IsZero() { - ci.IdleTime = time.Now().Sub(cc.lastActive) - } - cc.mu.Unlock() - - trace.GotConn(ci) -} - -func http2traceWroteHeaders(trace *httptrace.ClientTrace) { - if trace != nil && trace.WroteHeaders != nil { - trace.WroteHeaders() - } -} - -func http2traceGot100Continue(trace *httptrace.ClientTrace) { - if trace != nil && trace.Got100Continue != nil { - trace.Got100Continue() - } -} - -func http2traceWait100Continue(trace *httptrace.ClientTrace) { - if trace != nil && trace.Wait100Continue != nil { - trace.Wait100Continue() - } -} - -func http2traceWroteRequest(trace *httptrace.ClientTrace, err error) { - if trace != nil && trace.WroteRequest != nil { - trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) - } -} - -func http2traceFirstResponseByte(trace *httptrace.ClientTrace) { - if trace != nil && trace.GotFirstResponseByte != nil { - trace.GotFirstResponseByte() - } -} - // writeContext is the interface needed by the various frame writer // types below. All the writeFrame methods below are scheduled via the // frame writing scheduler (see writeScheduler in writesched.go). diff --git a/h2_trace.go b/h2_trace.go index 3f32f194..cf40a785 100644 --- a/h2_trace.go +++ b/h2_trace.go @@ -5,8 +5,10 @@ package req import ( + "net/http" "net/http/httptrace" "net/textproto" + "time" ) func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { @@ -25,3 +27,58 @@ func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textpr } return nil } + +func http2traceGetConn(req *http.Request, hostPort string) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GetConn == nil { + return + } + trace.GetConn(hostPort) +} + +func http2traceGotConn(req *http.Request, cc *http2ClientConn, reused bool) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GotConn == nil { + return + } + ci := httptrace.GotConnInfo{Conn: cc.tconn} + ci.Reused = reused + cc.mu.Lock() + ci.WasIdle = len(cc.streams) == 0 && reused + if ci.WasIdle && !cc.lastActive.IsZero() { + ci.IdleTime = time.Now().Sub(cc.lastActive) + } + cc.mu.Unlock() + + trace.GotConn(ci) +} + +func http2traceWroteHeaders(trace *httptrace.ClientTrace) { + if trace != nil && trace.WroteHeaders != nil { + trace.WroteHeaders() + } +} + +func http2traceGot100Continue(trace *httptrace.ClientTrace) { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() + } +} + +func http2traceWait100Continue(trace *httptrace.ClientTrace) { + if trace != nil && trace.Wait100Continue != nil { + trace.Wait100Continue() + } +} + +func http2traceWroteRequest(trace *httptrace.ClientTrace, err error) { + if trace != nil && trace.WroteRequest != nil { + trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) + } +} + +func http2traceFirstResponseByte(trace *httptrace.ClientTrace) { + if trace != nil && trace.GotFirstResponseByte != nil { + trace.GotFirstResponseByte() + } +} From 9c24974853f0ed586d26a5a77e660ec32ef4e2d1 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 20:54:57 +0800 Subject: [PATCH 328/843] extract h2_gotrack.go --- h2_bundle.go | 153 --------------------------------------------- h2_gotrack.go | 170 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+), 153 deletions(-) create mode 100644 h2_gotrack.go diff --git a/h2_bundle.go b/h2_bundle.go index c2a66db0..0c9eec69 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -39,7 +39,6 @@ import ( "net/http/httptrace" "net/textproto" "os" - "runtime" "sort" "strconv" "strings" @@ -66,158 +65,6 @@ func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr s return tlsCn, nil } -var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" - -type http2goroutineLock uint64 - -func http2newGoroutineLock() http2goroutineLock { - if !http2DebugGoroutines { - return 0 - } - return http2goroutineLock(http2curGoroutineID()) -} - -func (g http2goroutineLock) check() { - if !http2DebugGoroutines { - return - } - if http2curGoroutineID() != uint64(g) { - panic("running on the wrong goroutine") - } -} - -func (g http2goroutineLock) checkNotOn() { - if !http2DebugGoroutines { - return - } - if http2curGoroutineID() == uint64(g) { - panic("running on the wrong goroutine") - } -} - -var http2goroutineSpace = []byte("goroutine ") - -func http2curGoroutineID() uint64 { - bp := http2littleBuf.Get().(*[]byte) - defer http2littleBuf.Put(bp) - b := *bp - b = b[:runtime.Stack(b, false)] - // Parse the 4707 out of "goroutine 4707 [" - b = bytes.TrimPrefix(b, http2goroutineSpace) - i := bytes.IndexByte(b, ' ') - if i < 0 { - panic(fmt.Sprintf("No space found in %q", b)) - } - b = b[:i] - n, err := http2parseUintBytes(b, 10, 64) - if err != nil { - panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err)) - } - return n -} - -var http2littleBuf = sync.Pool{ - New: func() interface{} { - buf := make([]byte, 64) - return &buf - }, -} - -// parseUintBytes is like strconv.ParseUint, but using a []byte. -func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) { - var cutoff, maxVal uint64 - - if bitSize == 0 { - bitSize = int(strconv.IntSize) - } - - s0 := s - switch { - case len(s) < 1: - err = strconv.ErrSyntax - goto Error - - case 2 <= base && base <= 36: - // valid base; nothing to do - - case base == 0: - // Look for octal, hex prefix. - switch { - case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'): - base = 16 - s = s[2:] - if len(s) < 1 { - err = strconv.ErrSyntax - goto Error - } - case s[0] == '0': - base = 8 - default: - base = 10 - } - - default: - err = errors.New("invalid base " + strconv.Itoa(base)) - goto Error - } - - n = 0 - cutoff = http2cutoff64(base) - maxVal = 1<= base { - n = 0 - err = strconv.ErrSyntax - goto Error - } - - if n >= cutoff { - // n*base overflows - n = 1<<64 - 1 - err = strconv.ErrRange - goto Error - } - n *= uint64(base) - - n1 := n + uint64(v) - if n1 < n || n1 > maxVal { - // n+v overflows - n = 1<<64 - 1 - err = strconv.ErrRange - goto Error - } - n = n1 - } - - return n, nil - -Error: - return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err} -} - -// Return the first number n such that n*base >= 1<<64. -func http2cutoff64(base int) uint64 { - if base < 2 { - return 0 - } - return (1<<64-1)/uint64(base) + 1 -} - var ( http2commonBuildOnce sync.Once http2commonLowerHeader map[string]string // Go-Canonical-Case -> lower-case diff --git a/h2_gotrack.go b/h2_gotrack.go new file mode 100644 index 00000000..1a4656b7 --- /dev/null +++ b/h2_gotrack.go @@ -0,0 +1,170 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Defensive debug-only utility to track that functions run on the +// goroutine that they're supposed to. + +package req + +import ( + "bytes" + "errors" + "fmt" + "os" + "runtime" + "strconv" + "sync" +) + +var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" + +type http2goroutineLock uint64 + +func http2newGoroutineLock() http2goroutineLock { + if !http2DebugGoroutines { + return 0 + } + return http2goroutineLock(http2curGoroutineID()) +} + +func (g http2goroutineLock) check() { + if !http2DebugGoroutines { + return + } + if http2curGoroutineID() != uint64(g) { + panic("running on the wrong goroutine") + } +} + +func (g http2goroutineLock) checkNotOn() { + if !http2DebugGoroutines { + return + } + if http2curGoroutineID() == uint64(g) { + panic("running on the wrong goroutine") + } +} + +var http2goroutineSpace = []byte("goroutine ") + +func http2curGoroutineID() uint64 { + bp := http2littleBuf.Get().(*[]byte) + defer http2littleBuf.Put(bp) + b := *bp + b = b[:runtime.Stack(b, false)] + // Parse the 4707 out of "goroutine 4707 [" + b = bytes.TrimPrefix(b, http2goroutineSpace) + i := bytes.IndexByte(b, ' ') + if i < 0 { + panic(fmt.Sprintf("No space found in %q", b)) + } + b = b[:i] + n, err := http2parseUintBytes(b, 10, 64) + if err != nil { + panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err)) + } + return n +} + +var http2littleBuf = sync.Pool{ + New: func() interface{} { + buf := make([]byte, 64) + return &buf + }, +} + +// parseUintBytes is like strconv.ParseUint, but using a []byte. +func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) { + var cutoff, maxVal uint64 + + if bitSize == 0 { + bitSize = int(strconv.IntSize) + } + + s0 := s + switch { + case len(s) < 1: + err = strconv.ErrSyntax + goto Error + + case 2 <= base && base <= 36: + // valid base; nothing to do + + case base == 0: + // Look for octal, hex prefix. + switch { + case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'): + base = 16 + s = s[2:] + if len(s) < 1 { + err = strconv.ErrSyntax + goto Error + } + case s[0] == '0': + base = 8 + default: + base = 10 + } + + default: + err = errors.New("invalid base " + strconv.Itoa(base)) + goto Error + } + + n = 0 + cutoff = http2cutoff64(base) + maxVal = 1<= base { + n = 0 + err = strconv.ErrSyntax + goto Error + } + + if n >= cutoff { + // n*base overflows + n = 1<<64 - 1 + err = strconv.ErrRange + goto Error + } + n *= uint64(base) + + n1 := n + uint64(v) + if n1 < n || n1 > maxVal { + // n+v overflows + n = 1<<64 - 1 + err = strconv.ErrRange + goto Error + } + n = n1 + } + + return n, nil + +Error: + return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err} +} + +// Return the first number n such that n*base >= 1<<64. +func http2cutoff64(base int) uint64 { + if base < 2 { + return 0 + } + return (1<<64-1)/uint64(base) + 1 +} From 709b317a03696f62c3f666a9d34fd2d576040efd Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 20:57:11 +0800 Subject: [PATCH 329/843] add h2_gotrack_test.go --- h2_gotrack_test.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 h2_gotrack_test.go diff --git a/h2_gotrack_test.go b/h2_gotrack_test.go new file mode 100644 index 00000000..ae6eb485 --- /dev/null +++ b/h2_gotrack_test.go @@ -0,0 +1,33 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "fmt" + "strings" + "testing" +) + +func TestGoroutineLock(t *testing.T) { + oldDebug := http2DebugGoroutines + http2DebugGoroutines = true + defer func() { http2DebugGoroutines = oldDebug }() + + g := http2newGoroutineLock() + g.check() + + sawPanic := make(chan interface{}) + go func() { + defer func() { sawPanic <- recover() }() + g.check() // should panic + }() + e := <-sawPanic + if e == nil { + t.Fatal("did not see panic from check in other goroutine") + } + if !strings.Contains(fmt.Sprint(e), "wrong goroutine") { + t.Errorf("expected on see panic about running on the wrong goroutine; got %v", e) + } +} From 04fe10654cb8524384c51cb845eb3aff56caba4a Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 21:00:27 +0800 Subject: [PATCH 330/843] extract h2_headermap.go --- h2_bundle.go | 77 ------------------------------------------- h2_headermap.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 77 deletions(-) create mode 100644 h2_headermap.go diff --git a/h2_bundle.go b/h2_bundle.go index 0c9eec69..2c618cfe 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -65,83 +65,6 @@ func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr s return tlsCn, nil } -var ( - http2commonBuildOnce sync.Once - http2commonLowerHeader map[string]string // Go-Canonical-Case -> lower-case - http2commonCanonHeader map[string]string // lower-case -> Go-Canonical-Case -) - -func http2buildCommonHeaderMapsOnce() { - http2commonBuildOnce.Do(http2buildCommonHeaderMaps) -} - -func http2buildCommonHeaderMaps() { - common := []string{ - "accept", - "accept-charset", - "accept-encoding", - "accept-language", - "accept-ranges", - "age", - "access-control-allow-origin", - "allow", - "authorization", - "cache-control", - "content-disposition", - "content-encoding", - "content-language", - "content-length", - "content-location", - "content-range", - "content-type", - "cookie", - "date", - "etag", - "expect", - "expires", - "from", - "host", - "if-match", - "if-modified-since", - "if-none-match", - "if-unmodified-since", - "last-modified", - "link", - "location", - "max-forwards", - "proxy-authenticate", - "proxy-authorization", - "range", - "referer", - "refresh", - "retry-after", - "server", - "set-cookie", - "strict-transport-security", - "trailer", - "transfer-encoding", - "user-agent", - "vary", - "via", - "www-authenticate", - } - http2commonLowerHeader = make(map[string]string, len(common)) - http2commonCanonHeader = make(map[string]string, len(common)) - for _, v := range common { - chk := http.CanonicalHeaderKey(v) - http2commonLowerHeader[chk] = v - http2commonCanonHeader[v] = chk - } -} - -func http2lowerHeader(v string) (lower string, isAscii bool) { - http2buildCommonHeaderMapsOnce() - if s, ok := http2commonLowerHeader[v]; ok { - return s, true - } - return ascii.ToLower(v) -} - var ( http2VerboseLogs bool http2logFrameWrites bool diff --git a/h2_headermap.go b/h2_headermap.go new file mode 100644 index 00000000..a78e405b --- /dev/null +++ b/h2_headermap.go @@ -0,0 +1,88 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "github.com/imroc/req/v3/internal/ascii" + "net/http" + "sync" +) + +var ( + http2commonBuildOnce sync.Once + http2commonLowerHeader map[string]string // Go-Canonical-Case -> lower-case + http2commonCanonHeader map[string]string // lower-case -> Go-Canonical-Case +) + +func http2buildCommonHeaderMapsOnce() { + http2commonBuildOnce.Do(http2buildCommonHeaderMaps) +} + +func http2buildCommonHeaderMaps() { + common := []string{ + "accept", + "accept-charset", + "accept-encoding", + "accept-language", + "accept-ranges", + "age", + "access-control-allow-origin", + "allow", + "authorization", + "cache-control", + "content-disposition", + "content-encoding", + "content-language", + "content-length", + "content-location", + "content-range", + "content-type", + "cookie", + "date", + "etag", + "expect", + "expires", + "from", + "host", + "if-match", + "if-modified-since", + "if-none-match", + "if-unmodified-since", + "last-modified", + "link", + "location", + "max-forwards", + "proxy-authenticate", + "proxy-authorization", + "range", + "referer", + "refresh", + "retry-after", + "server", + "set-cookie", + "strict-transport-security", + "trailer", + "transfer-encoding", + "user-agent", + "vary", + "via", + "www-authenticate", + } + http2commonLowerHeader = make(map[string]string, len(common)) + http2commonCanonHeader = make(map[string]string, len(common)) + for _, v := range common { + chk := http.CanonicalHeaderKey(v) + http2commonLowerHeader[chk] = v + http2commonCanonHeader[v] = chk + } +} + +func http2lowerHeader(v string) (lower string, isAscii bool) { + http2buildCommonHeaderMapsOnce() + if s, ok := http2commonLowerHeader[v]; ok { + return s, true + } + return ascii.ToLower(v) +} From 60099187f6119d4dea09755b8fea2e5cb1367555 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 21:02:59 +0800 Subject: [PATCH 331/843] extract h2.go --- h2.go | 262 +++++++++++++++++++++++++++++++++++++++++++++++++++ h2_bundle.go | 244 ----------------------------------------------- 2 files changed, 262 insertions(+), 244 deletions(-) create mode 100644 h2.go diff --git a/h2.go b/h2.go new file mode 100644 index 00000000..de9d83d5 --- /dev/null +++ b/h2.go @@ -0,0 +1,262 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "bufio" + "crypto/tls" + "fmt" + "golang.org/x/net/http/httpguts" + "net/http" + "os" + "sort" + "strconv" + "strings" + "sync" +) + +var ( + http2VerboseLogs bool + http2logFrameWrites bool + http2logFrameReads bool + http2inTests bool +) + +func init() { + e := os.Getenv("GODEBUG") + if strings.Contains(e, "http2debug=1") { + http2VerboseLogs = true + } + if strings.Contains(e, "http2debug=2") { + http2VerboseLogs = true + http2logFrameWrites = true + http2logFrameReads = true + } +} + +const ( + // ClientPreface is the string that must be sent by new + // connections from clients. + http2ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + + // NextProtoTLS is the NPN/ALPN protocol negotiated during + // HTTP/2's TLS setup. + http2NextProtoTLS = "h2" + + // http://http2.github.io/http2-spec/#SettingValues + http2initialHeaderTableSize = 4096 + + http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size +) + +var ( + http2clientPreface = []byte(http2ClientPreface) +) + +// Setting is a setting parameter: which setting it is, and its value. +type http2Setting struct { + // ID is which setting is being set. + // See http://http2.github.io/http2-spec/#SettingValues + ID http2SettingID + + // Val is the value. + Val uint32 +} + +func (s http2Setting) String() string { + return fmt.Sprintf("[%v = %d]", s.ID, s.Val) +} + +// Valid reports whether the setting is valid. +func (s http2Setting) Valid() error { + // Limits and error codes from 6.5.2 Defined SETTINGS Parameters + switch s.ID { + case http2SettingEnablePush: + if s.Val != 1 && s.Val != 0 { + return http2ConnectionError(http2ErrCodeProtocol) + } + case http2SettingInitialWindowSize: + if s.Val > 1<<31-1 { + return http2ConnectionError(http2ErrCodeFlowControl) + } + case http2SettingMaxFrameSize: + if s.Val < 16384 || s.Val > 1<<24-1 { + return http2ConnectionError(http2ErrCodeProtocol) + } + } + return nil +} + +// A SettingID is an HTTP/2 setting as defined in +// http://http2.github.io/http2-spec/#iana-settings +type http2SettingID uint16 + +const ( + http2SettingHeaderTableSize http2SettingID = 0x1 + http2SettingEnablePush http2SettingID = 0x2 + http2SettingMaxConcurrentStreams http2SettingID = 0x3 + http2SettingInitialWindowSize http2SettingID = 0x4 + http2SettingMaxFrameSize http2SettingID = 0x5 + http2SettingMaxHeaderListSize http2SettingID = 0x6 +) + +var http2settingName = map[http2SettingID]string{ + http2SettingHeaderTableSize: "HEADER_TABLE_SIZE", + http2SettingEnablePush: "ENABLE_PUSH", + http2SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", + http2SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", + http2SettingMaxFrameSize: "MAX_FRAME_SIZE", + http2SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", +} + +func (s http2SettingID) String() string { + if v, ok := http2settingName[s]; ok { + return v + } + return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) +} + +// validWireHeaderFieldName reports whether v is a valid header field +// name (key). See httpguts.ValidHeaderName for the base rules. +// +// Further, http2 says: +// "Just as in HTTP/1.x, header field names are strings of ASCII +// characters that are compared in a case-insensitive +// fashion. However, header field names MUST be converted to +// lowercase prior to their encoding in HTTP/2. " +func http2validWireHeaderFieldName(v string) bool { + if len(v) == 0 { + return false + } + for _, r := range v { + if !httpguts.IsTokenRune(r) { + return false + } + if 'A' <= r && r <= 'Z' { + return false + } + } + return true +} + +func http2httpCodeString(code int) string { + switch code { + case 200: + return "200" + case 404: + return "404" + } + return strconv.Itoa(code) +} + +// bufWriterPoolBufferSize is the size of bufio.Writer's +// buffers created using bufWriterPool. +// +// TODO: pick a less arbitrary value? this is a bit under +// (3 x typical 1500 byte MTU) at least. Other than that, +// not much thought went into it. +const http2bufWriterPoolBufferSize = 4 << 10 + +var http2bufWriterPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize) + }, +} + +func http2mustUint31(v int32) uint32 { + if v < 0 || v > 2147483647 { + panic("out of range") + } + return uint32(v) +} + +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 7230, section 3.3. +func http2bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +type http2httpError struct { + _ http2incomparable + msg string + timeout bool +} + +func (e *http2httpError) Error() string { return e.msg } + +func (e *http2httpError) Timeout() bool { return e.timeout } + +func (e *http2httpError) Temporary() bool { return true } + +var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true} + +type http2connectionStater interface { + ConnectionState() tls.ConnectionState +} + +var http2sorterPool = sync.Pool{New: func() interface{} { return new(http2sorter) }} + +type http2sorter struct { + v []string // owned by sorter +} + +func (s *http2sorter) Len() int { return len(s.v) } + +func (s *http2sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] } + +func (s *http2sorter) Less(i, j int) bool { return s.v[i] < s.v[j] } + +// Keys returns the sorted keys of h. +// +// The returned slice is only valid until s used again or returned to +// its pool. +func (s *http2sorter) Keys(h http.Header) []string { + keys := s.v[:0] + for k := range h { + keys = append(keys, k) + } + s.v = keys + sort.Sort(s) + return keys +} + +func (s *http2sorter) SortStrings(ss []string) { + // Our sorter works on s.v, which sorter owns, so + // stash it away while we sort the user's buffer. + save := s.v + s.v = ss + sort.Sort(s) + s.v = save +} + +// validPseudoPath reports whether v is a valid :path pseudo-header +// value. It must be either: +// +// *) a non-empty string starting with '/' +// *) the string '*', for OPTIONS requests. +// +// For now this is only used a quick check for deciding when to clean +// up Opaque URLs before sending requests from the Transport. +// See golang.org/issue/16847 +// +// We used to enforce that the path also didn't start with "//", but +// Google's GFE accepts such paths and Chrome sends them, so ignore +// that part of the spec. See golang.org/issue/19103. +func http2validPseudoPath(v string) bool { + return (len(v) > 0 && v[0] == '/') || v == "*" +} + +// incomparable is a zero-width, non-comparable type. Adding it to a struct +// makes that struct also non-comparable, and generally doesn't add +// any size (as long as it's first). +type http2incomparable [0]func() diff --git a/h2_bundle.go b/h2_bundle.go index 2c618cfe..d66850a9 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -65,250 +65,6 @@ func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr s return tlsCn, nil } -var ( - http2VerboseLogs bool - http2logFrameWrites bool - http2logFrameReads bool - http2inTests bool -) - -func init() { - e := os.Getenv("GODEBUG") - if strings.Contains(e, "http2debug=1") { - http2VerboseLogs = true - } - if strings.Contains(e, "http2debug=2") { - http2VerboseLogs = true - http2logFrameWrites = true - http2logFrameReads = true - } -} - -const ( - // ClientPreface is the string that must be sent by new - // connections from clients. - http2ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - - // NextProtoTLS is the NPN/ALPN protocol negotiated during - // HTTP/2's TLS setup. - http2NextProtoTLS = "h2" - - // http://http2.github.io/http2-spec/#SettingValues - http2initialHeaderTableSize = 4096 - - http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size -) - -var ( - http2clientPreface = []byte(http2ClientPreface) -) - -// Setting is a setting parameter: which setting it is, and its value. -type http2Setting struct { - // ID is which setting is being set. - // See http://http2.github.io/http2-spec/#SettingValues - ID http2SettingID - - // Val is the value. - Val uint32 -} - -func (s http2Setting) String() string { - return fmt.Sprintf("[%v = %d]", s.ID, s.Val) -} - -// Valid reports whether the setting is valid. -func (s http2Setting) Valid() error { - // Limits and error codes from 6.5.2 Defined SETTINGS Parameters - switch s.ID { - case http2SettingEnablePush: - if s.Val != 1 && s.Val != 0 { - return http2ConnectionError(http2ErrCodeProtocol) - } - case http2SettingInitialWindowSize: - if s.Val > 1<<31-1 { - return http2ConnectionError(http2ErrCodeFlowControl) - } - case http2SettingMaxFrameSize: - if s.Val < 16384 || s.Val > 1<<24-1 { - return http2ConnectionError(http2ErrCodeProtocol) - } - } - return nil -} - -// A SettingID is an HTTP/2 setting as defined in -// http://http2.github.io/http2-spec/#iana-settings -type http2SettingID uint16 - -const ( - http2SettingHeaderTableSize http2SettingID = 0x1 - http2SettingEnablePush http2SettingID = 0x2 - http2SettingMaxConcurrentStreams http2SettingID = 0x3 - http2SettingInitialWindowSize http2SettingID = 0x4 - http2SettingMaxFrameSize http2SettingID = 0x5 - http2SettingMaxHeaderListSize http2SettingID = 0x6 -) - -var http2settingName = map[http2SettingID]string{ - http2SettingHeaderTableSize: "HEADER_TABLE_SIZE", - http2SettingEnablePush: "ENABLE_PUSH", - http2SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", - http2SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", - http2SettingMaxFrameSize: "MAX_FRAME_SIZE", - http2SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", -} - -func (s http2SettingID) String() string { - if v, ok := http2settingName[s]; ok { - return v - } - return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) -} - -// validWireHeaderFieldName reports whether v is a valid header field -// name (key). See httpguts.ValidHeaderName for the base rules. -// -// Further, http2 says: -// "Just as in HTTP/1.x, header field names are strings of ASCII -// characters that are compared in a case-insensitive -// fashion. However, header field names MUST be converted to -// lowercase prior to their encoding in HTTP/2. " -func http2validWireHeaderFieldName(v string) bool { - if len(v) == 0 { - return false - } - for _, r := range v { - if !httpguts.IsTokenRune(r) { - return false - } - if 'A' <= r && r <= 'Z' { - return false - } - } - return true -} - -func http2httpCodeString(code int) string { - switch code { - case 200: - return "200" - case 404: - return "404" - } - return strconv.Itoa(code) -} - -// bufWriterPoolBufferSize is the size of bufio.Writer's -// buffers created using bufWriterPool. -// -// TODO: pick a less arbitrary value? this is a bit under -// (3 x typical 1500 byte MTU) at least. Other than that, -// not much thought went into it. -const http2bufWriterPoolBufferSize = 4 << 10 - -var http2bufWriterPool = sync.Pool{ - New: func() interface{} { - return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize) - }, -} - -func http2mustUint31(v int32) uint32 { - if v < 0 || v > 2147483647 { - panic("out of range") - } - return uint32(v) -} - -// bodyAllowedForStatus reports whether a given response status code -// permits a body. See RFC 7230, section 3.3. -func http2bodyAllowedForStatus(status int) bool { - switch { - case status >= 100 && status <= 199: - return false - case status == 204: - return false - case status == 304: - return false - } - return true -} - -type http2httpError struct { - _ http2incomparable - msg string - timeout bool -} - -func (e *http2httpError) Error() string { return e.msg } - -func (e *http2httpError) Timeout() bool { return e.timeout } - -func (e *http2httpError) Temporary() bool { return true } - -var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true} - -type http2connectionStater interface { - ConnectionState() tls.ConnectionState -} - -var http2sorterPool = sync.Pool{New: func() interface{} { return new(http2sorter) }} - -type http2sorter struct { - v []string // owned by sorter -} - -func (s *http2sorter) Len() int { return len(s.v) } - -func (s *http2sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] } - -func (s *http2sorter) Less(i, j int) bool { return s.v[i] < s.v[j] } - -// Keys returns the sorted keys of h. -// -// The returned slice is only valid until s used again or returned to -// its pool. -func (s *http2sorter) Keys(h http.Header) []string { - keys := s.v[:0] - for k := range h { - keys = append(keys, k) - } - s.v = keys - sort.Sort(s) - return keys -} - -func (s *http2sorter) SortStrings(ss []string) { - // Our sorter works on s.v, which sorter owns, so - // stash it away while we sort the user's buffer. - save := s.v - s.v = ss - sort.Sort(s) - s.v = save -} - -// validPseudoPath reports whether v is a valid :path pseudo-header -// value. It must be either: -// -// *) a non-empty string starting with '/' -// *) the string '*', for OPTIONS requests. -// -// For now this is only used a quick check for deciding when to clean -// up Opaque URLs before sending requests from the Transport. -// See golang.org/issue/16847 -// -// We used to enforce that the path also didn't start with "//", but -// Google's GFE accepts such paths and Chrome sends them, so ignore -// that part of the spec. See golang.org/issue/19103. -func http2validPseudoPath(v string) bool { - return (len(v) > 0 && v[0] == '/') || v == "*" -} - -// incomparable is a zero-width, non-comparable type. Adding it to a struct -// makes that struct also non-comparable, and generally doesn't add -// any size (as long as it's first). -type http2incomparable [0]func() - // TrailerPrefix is a magic prefix for ResponseWriter.Header map keys // that, if present, signals that the map entry is actually for // the response trailers, and not the response headers. The prefix From cc01a0517bda1255d003b604ed6cdf9bf1bf1442 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 21:22:16 +0800 Subject: [PATCH 332/843] add h2_test.go --- h2_test.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 h2_test.go diff --git a/h2_test.go b/h2_test.go new file mode 100644 index 00000000..0d2ca500 --- /dev/null +++ b/h2_test.go @@ -0,0 +1,92 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "flag" + "fmt" + "net/http" + "testing" + "time" +) + +func init() { + http2inTests = true + http2DebugGoroutines = true + flag.BoolVar(&http2VerboseLogs, "verboseh2", http2VerboseLogs, "Verbose HTTP/2 debug logging") +} + +func TestSettingString(t *testing.T) { + tests := []struct { + s http2Setting + want string + }{ + {http2Setting{http2SettingMaxFrameSize, 123}, "[MAX_FRAME_SIZE = 123]"}, + {http2Setting{1<<16 - 1, 123}, "[UNKNOWN_SETTING_65535 = 123]"}, + } + for i, tt := range tests { + got := fmt.Sprint(tt.s) + if got != tt.want { + t.Errorf("%d. for %#v, string = %q; want %q", i, tt.s, got, tt.want) + } + } +} + +func cleanDate(res *http.Response) { + if d := res.Header["Date"]; len(d) == 1 { + d[0] = "XXX" + } +} + +func TestSorterPoolAllocs(t *testing.T) { + ss := []string{"a", "b", "c"} + h := http.Header{ + "a": nil, + "b": nil, + "c": nil, + } + sorter := new(http2sorter) + + if allocs := testing.AllocsPerRun(100, func() { + sorter.SortStrings(ss) + }); allocs >= 1 { + t.Logf("SortStrings allocs = %v; want <1", allocs) + } + + if allocs := testing.AllocsPerRun(5, func() { + if len(sorter.Keys(h)) != 3 { + t.Fatal("wrong result") + } + }); allocs > 0 { + t.Logf("Keys allocs = %v; want <1", allocs) + } +} + +// waitCondition reports whether fn eventually returned true, +// checking immediately and then every checkEvery amount, +// until waitFor has elapsed, at which point it returns false. +func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { + deadline := time.Now().Add(waitFor) + for time.Now().Before(deadline) { + if fn() { + return true + } + time.Sleep(checkEvery) + } + return false +} + +// waitErrCondition is like waitCondition but with errors instead of bools. +func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error { + deadline := time.Now().Add(waitFor) + var err error + for time.Now().Before(deadline) { + if err = fn(); err == nil { + return nil + } + time.Sleep(checkEvery) + } + return err +} From be60f560bd5e25b1b4ff238cb3cb214836a8304f Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 21:31:02 +0800 Subject: [PATCH 333/843] remove more unused code --- h2_bundle.go | 166 --------------------------------------------------- h2_errors.go | 9 --- 2 files changed, 175 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index d66850a9..fcfa2026 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -65,54 +65,6 @@ func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr s return tlsCn, nil } -// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys -// that, if present, signals that the map entry is actually for -// the response trailers, and not the response headers. The prefix -// is stripped after the ServeHTTP call finishes and the values are -// sent in the trailers. -// -// This mechanism is intended only for trailers that are not known -// prior to the headers being written. If the set of trailers is fixed -// or known before the header is written, the normal Go trailers mechanism -// is preferred: -// https://golang.org/pkg/net/http/#ResponseWriter -// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers -const http2TrailerPrefix = "Trailer:" - -// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode. -func http2checkWriteHeaderCode(code int) { - // Issue 22880: require valid WriteHeader status codes. - // For now we only enforce that it's three digits. - // In the future we might block things over 599 (600 and above aren't defined - // at http://httpwg.org/specs/rfc7231.html#status.codes) - // and we might block under 200 (once we have more mature 1xx support). - // But for now any three digits. - // - // We used to send "HTTP/1.1 000 0" on the wire in responses but there's - // no equivalent bogus thing we can realistically send in HTTP/2, - // so we'll consistently panic instead and help people find their bugs - // early. (We can't return an error from WriteHeader even if we wanted to.) - if code < 100 || code > 999 { - panic(fmt.Sprintf("invalid WriteHeader code %v", code)) - } -} - -func http2cloneHeader(h http.Header) http.Header { - h2 := make(http.Header, len(h)) - for k, vv := range h { - vv2 := make([]string, len(vv)) - copy(vv2, vv) - h2[k] = vv2 - } - return h2 -} - -// Push errors. -var ( - http2ErrRecursivePush = errors.New("http2: recursive push not allowed") - http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") -) - // foreachHeaderElement splits v according to the "#rule" construction // in RFC 7230 section 7 and calls fn for each non-empty element. func http2foreachHeaderElement(v string, fn func(string)) { @@ -131,51 +83,6 @@ func http2foreachHeaderElement(v string, fn func(string)) { } } -// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 -var http2connHeaders = []string{ - "Connection", - "Keep-Alive", - "Proxy-Connection", - "Transfer-Encoding", - "Upgrade", -} - -// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, -// per RFC 7540 Section 8.1.2.2. -// The returned error is reported to users. -func http2checkValidHTTP2RequestHeaders(h http.Header) error { - for _, k := range http2connHeaders { - if _, ok := h[k]; ok { - return fmt.Errorf("request header %q is not valid in HTTP/2", k) - } - } - te := h["Te"] - if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { - return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) - } - return nil -} - -func http2new400Handler(err error) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - http.Error(w, err.Error(), http.StatusBadRequest) - } -} - -// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives -// disabled. See comments on h1ServerShutdownChan above for why -// the code is written this way. -func http2h1ServerKeepAlivesDisabled(hs *http.Server) bool { - var x interface{} = hs - type I interface { - doKeepAlives() bool - } - if hs, ok := x.(I); ok { - return !hs.doKeepAlives() - } - return false -} - const ( // transportDefaultConnFlow is how many connection-level flow control // tokens we give the server at start-up, past the default 64k. @@ -3174,76 +3081,3 @@ func (t *http2Transport) idleConnTimeout() time.Duration { } return 0 } - -// writeContext is the interface needed by the various frame writer -// types below. All the writeFrame methods below are scheduled via the -// frame writing scheduler (see writeScheduler in writesched.go). -// -// This interface is implemented by *serverConn. -// -// TODO: decide whether to a) use this in the client code (which didn't -// end up using this yet, because it has a simpler design, not -// currently implementing priorities), or b) delete this and -// make the server code a bit more concrete. -type http2writeContext interface { - Framer() *http2Framer - Flush() error - CloseConn() error - // HeaderEncoder returns an HPACK encoder that writes to the - // returned buffer. - HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) -} - -func (se http2StreamError) writeFrame(ctx http2writeContext) error { - return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) -} - -func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } - -func http2encKV(enc *hpack.Encoder, k, v string) { - if http2VerboseLogs { - log.Printf("http2: server encoding header %q = %q", k, v) - } - enc.WriteField(hpack.HeaderField{Name: k, Value: v}) -} - -// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) -// is encoded only if k is in keys. -func http2encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { - if keys == nil { - sorter := http2sorterPool.Get().(*http2sorter) - // Using defer here, since the returned keys from the - // sorter.Keys method is only valid until the sorter - // is returned: - defer http2sorterPool.Put(sorter) - keys = sorter.Keys(h) - } - for _, k := range keys { - vv := h[k] - k, isAscii := http2lowerHeader(k) - if !isAscii { - // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header - // field names have to be ASCII characters (just as in HTTP/1.x). - continue - } - if !http2validWireHeaderFieldName(k) { - // Skip it as backup paranoia. Per - // golang.org/issue/14048, these should - // already be rejected at a higher level. - continue - } - isTE := k == "transfer-encoding" - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - // TODO: return an error? golang.org/issue/14048 - // For now just omit it. - continue - } - // TODO: more of "8.1.2.2 Connection-Specific Header Fields" - if isTE && v != "trailers" { - continue - } - http2encKV(enc, k, v) - } - } -} diff --git a/h2_errors.go b/h2_errors.go index a6fbe07e..e81daf2e 100644 --- a/h2_errors.go +++ b/h2_errors.go @@ -92,15 +92,6 @@ func (e http2StreamError) Error() string { return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) } -// 6.9.1 The Flow Control Window -// "If a sender receives a WINDOW_UPDATE that causes a flow control -// window to exceed this maximum it MUST terminate either the stream -// or the connection, as appropriate. For streams, [...]; for the -// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code." -type http2goAwayFlowError struct{} - -func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" } - // connError represents an HTTP/2 ConnectionError error code, along // with a string (for debugging) explaining why. // From b2816fccec0d88caffe7c07180eae4207ee7b075 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 21:39:00 +0800 Subject: [PATCH 334/843] add h2_go115.go and h2_not_go115.go --- h2_bundle.go | 14 -------------- h2_go115.go | 27 +++++++++++++++++++++++++++ h2_not_go115.go | 26 ++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 14 deletions(-) create mode 100644 h2_go115.go create mode 100644 h2_not_go115.go diff --git a/h2_bundle.go b/h2_bundle.go index fcfa2026..dde6f73e 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -51,20 +51,6 @@ import ( "golang.org/x/net/idna" ) -// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS -// connection. -func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (TLSConn, error) { - dialer := &tls.Dialer{ - Config: cfg, - } - cn, err := dialer.DialContext(ctx, network, addr) - if err != nil { - return nil, err - } - tlsCn := cn.(TLSConn) // DialContext comment promises this will always succeed - return tlsCn, nil -} - // foreachHeaderElement splits v according to the "#rule" construction // in RFC 7230 section 7 and calls fn for each non-empty element. func http2foreachHeaderElement(v string, fn func(string)) { diff --git a/h2_go115.go b/h2_go115.go new file mode 100644 index 00000000..66d1cfbb --- /dev/null +++ b/h2_go115.go @@ -0,0 +1,27 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.15 +// +build go1.15 + +package req + +import ( + "context" + "crypto/tls" +) + +// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS +// connection. +func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (TLSConn, error) { + dialer := &tls.Dialer{ + Config: cfg, + } + cn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + tlsCn := cn.(TLSConn) // DialContext comment promises this will always succeed + return tlsCn, nil +} diff --git a/h2_not_go115.go b/h2_not_go115.go new file mode 100644 index 00000000..cebaf62b --- /dev/null +++ b/h2_not_go115.go @@ -0,0 +1,26 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.15 +// +build !go1.15 + +package req + +// dialTLSWithContext opens a TLS connection. +func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (TLSConn, error) { + cn, err := tls.Dial(network, addr, cfg) + if err != nil { + return nil, err + } + if err := cn.Handshake(); err != nil { + return nil, err + } + if cfg.InsecureSkipVerify { + return cn, nil + } + if err := cn.VerifyHostname(cfg.ServerName); err != nil { + return nil, err + } + return cn, nil +} From 105395d82d6967bec6bc9ffeb6175aa3febc0168 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 21:42:06 +0800 Subject: [PATCH 335/843] ajust function's position --- h2_bundle.go | 57 ++++++++++++++++++++-------------------------------- 1 file changed, 22 insertions(+), 35 deletions(-) diff --git a/h2_bundle.go b/h2_bundle.go index dde6f73e..316283fe 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -1,21 +1,8 @@ -//go:build !nethttpomithttp2 -// +build !nethttpomithttp2 +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. -// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. -// $ bundle -o=h2_bundle.go -prefix=http2 -tags=!nethttpomithttp2 golang.org/x/net/http2 - -// Package http2 implements the HTTP/2 protocol. -// -// This package is low-level and intended to be used directly by very -// few people. Most users will use it indirectly through the automatic -// use by the net/http package (from Go 1.6 and later). -// For use in earlier Go versions see ConfigureServer. (Transport support -// requires Go 1.6 or later) -// -// See https://http2.github.io/ for more information on HTTP/2. -// -// See https://http2.golang.org/ for a test server running this code. -// +// Transport code. package req @@ -51,24 +38,6 @@ import ( "golang.org/x/net/idna" ) -// foreachHeaderElement splits v according to the "#rule" construction -// in RFC 7230 section 7 and calls fn for each non-empty element. -func http2foreachHeaderElement(v string, fn func(string)) { - v = textproto.TrimString(v) - if v == "" { - return - } - if !strings.Contains(v, ",") { - fn(v) - return - } - for _, f := range strings.Split(v, ",") { - if f = textproto.TrimString(f); f != "" { - fn(f) - } - } -} - const ( // transportDefaultConnFlow is how many connection-level flow control // tokens we give the server at start-up, past the default 64k. @@ -2306,6 +2275,24 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro return nil } +// foreachHeaderElement splits v according to the "#rule" construction +// in RFC 7230 section 7 and calls fn for each non-empty element. +func http2foreachHeaderElement(v string, fn func(string)) { + v = textproto.TrimString(v) + if v == "" { + return + } + if !strings.Contains(v, ",") { + fn(v) + return + } + for _, f := range strings.Split(v, ",") { + if f = textproto.TrimString(f); f != "" { + fn(f) + } + } +} + // may return error types nil, or ConnectionError. Any other error value // is a StreamError of type ErrCodeProtocol. The returned error in that case // is the detail. From af9658dc2f85745db734081546f501a6809c7a82 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 13 Feb 2022 21:42:46 +0800 Subject: [PATCH 336/843] h2_bundle.go -> h2_transport.go --- h2_bundle.go => h2_transport.go | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename h2_bundle.go => h2_transport.go (100%) diff --git a/h2_bundle.go b/h2_transport.go similarity index 100% rename from h2_bundle.go rename to h2_transport.go From 545564377ad55d5ff66a747806bbbf332d460379 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 14 Feb 2022 09:40:56 +0800 Subject: [PATCH 337/843] add h2_transport_test.go --- h2_frame.go | 2 +- h2_server_test.go | 5747 +++++++++++++++++++++++++++++++++++++++++ h2_transport.go | 20 +- h2_transport_test.go | 5855 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 11611 insertions(+), 13 deletions(-) create mode 100644 h2_server_test.go create mode 100644 h2_transport_test.go diff --git a/h2_frame.go b/h2_frame.go index 4a9ebf85..2c153cb0 100644 --- a/h2_frame.go +++ b/h2_frame.go @@ -524,7 +524,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { } if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { var dumps []*dumper - if fr.cc != nil { + if fr.cc != nil && fr.cc.t.t1 != nil { dumps = getDumpers(fr.cc.t.t1.dump, fr.cc.currentRequest.Context()) } if len(dumps) > 0 { diff --git a/h2_server_test.go b/h2_server_test.go new file mode 100644 index 00000000..1f9f8465 --- /dev/null +++ b/h2_server_test.go @@ -0,0 +1,5747 @@ +package req + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "encoding/xml" + "errors" + "flag" + "fmt" + "github.com/imroc/req/v3/internal/ascii" + "io" + "io/ioutil" + "log" + "math" + "net" + "net/http" + "net/http/httptest" + "net/textproto" + "net/url" + "os" + "reflect" + "regexp" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "testing" + "time" + + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http2/hpack" +) + +// A list of the possible cipher suite ids. Taken from +// https://www.iana.org/assignments/tls-parameters/tls-parameters.txt + +const ( + http2cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000 + http2cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001 + http2cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002 + http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003 + http2cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004 + http2cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006 + http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007 + http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008 + http2cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009 + http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A + http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B + http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C + http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D + http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E + http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F + http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010 + http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011 + http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012 + http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013 + http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014 + http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015 + http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016 + http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017 + http2cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018 + http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019 + http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A + http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B + // Reserved uint16 = 0x001C-1D + http2cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F + http2cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020 + http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021 + http2cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022 + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023 + http2cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024 + http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025 + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026 + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027 + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028 + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029 + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B + http2cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030 + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031 + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033 + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034 + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036 + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037 + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038 + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039 + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A + http2cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046 + // Reserved uint16 = 0x0047-4F + // Reserved uint16 = 0x0050-58 + // Reserved uint16 = 0x0059-5C + // Unassigned uint16 = 0x005D-5F + // Reserved uint16 = 0x0060-66 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067 + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068 + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069 + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D + // Unassigned uint16 = 0x006E-83 + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089 + http2cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A + http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D + http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E + http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091 + http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092 + http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093 + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094 + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095 + http2cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096 + http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097 + http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098 + http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099 + http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A + http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B + http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C + http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D + http2cipher_TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009E + http2cipher_TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009F + http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0 + http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1 + http2cipher_TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A2 + http2cipher_TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A3 + http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4 + http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5 + http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6 + http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7 + http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8 + http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9 + http2cipher_TLS_DHE_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AA + http2cipher_TLS_DHE_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AB + http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC + http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF + http2cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0 + http2cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1 + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3 + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4 + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5 + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6 + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7 + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8 + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5 + // Unassigned uint16 = 0x00C6-FE + http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF + // Unassigned uint16 = 0x01-55,* + http2cipher_TLS_FALLBACK_SCSV uint16 = 0x5600 + // Unassigned uint16 = 0x5601 - 0xC000 + http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA uint16 = 0xC001 + http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA uint16 = 0xC002 + http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC003 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC004 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC005 + http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA uint16 = 0xC006 + http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xC007 + http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC008 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC009 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC00A + http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA uint16 = 0xC00B + http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA uint16 = 0xC00C + http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC00D + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC00E + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC00F + http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA uint16 = 0xC010 + http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xC011 + http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC012 + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC013 + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC014 + http2cipher_TLS_ECDH_anon_WITH_NULL_SHA uint16 = 0xC015 + http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA uint16 = 0xC016 + http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0xC017 + http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA uint16 = 0xC018 + http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA uint16 = 0xC019 + http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01A + http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01B + http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01C + http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA uint16 = 0xC01D + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC01E + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA uint16 = 0xC01F + http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA uint16 = 0xC020 + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC021 + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA uint16 = 0xC022 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC023 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC024 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC025 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC026 + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC027 + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC028 + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC029 + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC02A + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02B + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02C + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02D + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02E + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02F + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC030 + http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC031 + http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC032 + http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA uint16 = 0xC033 + http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0xC034 + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0xC035 + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0xC036 + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0xC037 + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0xC038 + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA uint16 = 0xC039 + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256 uint16 = 0xC03A + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384 uint16 = 0xC03B + http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03C + http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03D + http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03E + http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03F + http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC040 + http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC041 + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC042 + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC043 + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC044 + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC045 + http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC046 + http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC047 + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC048 + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC049 + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04A + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04B + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04C + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04D + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04E + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04F + http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC050 + http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC051 + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC052 + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC053 + http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC054 + http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC055 + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC056 + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC057 + http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC058 + http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC059 + http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05A + http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05B + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05C + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05D + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05E + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05F + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC060 + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC061 + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC062 + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC063 + http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC064 + http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC065 + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC066 + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC067 + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC068 + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC069 + http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06A + http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06B + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06C + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06D + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06E + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06F + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC070 + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC071 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC072 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC073 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC074 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC075 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC076 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC077 + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC078 + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC079 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07A + http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07B + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07C + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07D + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07E + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07F + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC080 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC081 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC082 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC083 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC084 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC085 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC086 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC087 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC088 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC089 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08A + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08B + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08C + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08D + http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08E + http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08F + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC090 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC091 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC092 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC093 + http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC094 + http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC095 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC096 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC097 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC098 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC099 + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC09A + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC09B + http2cipher_TLS_RSA_WITH_AES_128_CCM uint16 = 0xC09C + http2cipher_TLS_RSA_WITH_AES_256_CCM uint16 = 0xC09D + http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM uint16 = 0xC09E + http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM uint16 = 0xC09F + http2cipher_TLS_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A0 + http2cipher_TLS_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A1 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A2 + http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A3 + http2cipher_TLS_PSK_WITH_AES_128_CCM uint16 = 0xC0A4 + http2cipher_TLS_PSK_WITH_AES_256_CCM uint16 = 0xC0A5 + http2cipher_TLS_DHE_PSK_WITH_AES_128_CCM uint16 = 0xC0A6 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CCM uint16 = 0xC0A7 + http2cipher_TLS_PSK_WITH_AES_128_CCM_8 uint16 = 0xC0A8 + http2cipher_TLS_PSK_WITH_AES_256_CCM_8 uint16 = 0xC0A9 + http2cipher_TLS_PSK_DHE_WITH_AES_128_CCM_8 uint16 = 0xC0AA + http2cipher_TLS_PSK_DHE_WITH_AES_256_CCM_8 uint16 = 0xC0AB + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM uint16 = 0xC0AC + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM uint16 = 0xC0AD + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 uint16 = 0xC0AE + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8 uint16 = 0xC0AF + // Unassigned uint16 = 0xC0B0-FF + // Unassigned uint16 = 0xC1-CB,* + // Unassigned uint16 = 0xCC00-A7 + http2cipher_TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA8 + http2cipher_TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA9 + http2cipher_TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAA + http2cipher_TLS_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAB + http2cipher_TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAC + http2cipher_TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAD + http2cipher_TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAE +) + +// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. +// References: +// https://tools.ietf.org/html/rfc7540#appendix-A +// Reject cipher suites from Appendix A. +// "This list includes those cipher suites that do not +// offer an ephemeral key exchange and those that are +// based on the TLS null, stream or block cipher type" +func http2isBadCipher(cipher uint16) bool { + switch cipher { + case http2cipher_TLS_NULL_WITH_NULL_NULL, + http2cipher_TLS_RSA_WITH_NULL_MD5, + http2cipher_TLS_RSA_WITH_NULL_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_RSA_WITH_RC4_128_MD5, + http2cipher_TLS_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5, + http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_DH_anon_WITH_RC4_128_MD5, + http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_KRB5_WITH_DES_CBC_SHA, + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_KRB5_WITH_RC4_128_SHA, + http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA, + http2cipher_TLS_KRB5_WITH_DES_CBC_MD5, + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5, + http2cipher_TLS_KRB5_WITH_RC4_128_MD5, + http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_PSK_WITH_NULL_SHA, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA, + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA, + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_WITH_NULL_SHA256, + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA, + http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_PSK_WITH_NULL_SHA256, + http2cipher_TLS_PSK_WITH_NULL_SHA384, + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV, + http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA, + http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_NULL_SHA, + http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_AES_128_CCM, + http2cipher_TLS_RSA_WITH_AES_256_CCM, + http2cipher_TLS_RSA_WITH_AES_128_CCM_8, + http2cipher_TLS_RSA_WITH_AES_256_CCM_8, + http2cipher_TLS_PSK_WITH_AES_128_CCM, + http2cipher_TLS_PSK_WITH_AES_256_CCM, + http2cipher_TLS_PSK_WITH_AES_128_CCM_8, + http2cipher_TLS_PSK_WITH_AES_256_CCM_8: + return true + default: + return false + } +} + +const ( + http2prefaceTimeout = 10 * time.Second + http2firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway + http2handlerChunkWriteSize = 4 << 10 + http2defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? + http2maxQueuedControlFrames = 10000 +) + +var ( + http2errClientDisconnected = errors.New("client disconnected") + http2errClosedBody = errors.New("body closed by handler") + http2errHandlerComplete = errors.New("http2: request body closed due to handler exiting") + http2errStreamClosed = errors.New("http2: stream closed") +) + +var http2responseWriterStatePool = sync.Pool{ + New: func() interface{} { + rws := &http2responseWriterState{} + rws.bw = bufio.NewWriterSize(http2chunkWriter{rws}, http2handlerChunkWriteSize) + return rws + }, +} + +// Test hooks. +var ( + http2testHookOnConn func() + http2testHookGetServerConn func(*http2serverConn) + http2testHookOnPanicMu *sync.Mutex // nil except in tests + http2testHookOnPanic func(sc *http2serverConn, panicVal interface{}) (rePanic bool) +) + +// Server is an HTTP/2 server. +type http2Server struct { + // MaxHandlers limits the number of http.Handler ServeHTTP goroutines + // which may run at a time over all connections. + // Negative or zero no limit. + // TODO: implement + MaxHandlers int + + // MaxConcurrentStreams optionally specifies the number of + // concurrent streams that each client may have open at a + // time. This is unrelated to the number of http.Handler goroutines + // which may be active globally, which is MaxHandlers. + // If zero, MaxConcurrentStreams defaults to at least 100, per + // the HTTP/2 spec's recommendations. + MaxConcurrentStreams uint32 + + // MaxReadFrameSize optionally specifies the largest frame + // this server is willing to read. A valid value is between + // 16k and 16M, inclusive. If zero or otherwise invalid, a + // default value is used. + MaxReadFrameSize uint32 + + // PermitProhibitedCipherSuites, if true, permits the use of + // cipher suites prohibited by the HTTP/2 spec. + PermitProhibitedCipherSuites bool + + // IdleTimeout specifies how long until idle clients should be + // closed with a GOAWAY frame. PING frames are not considered + // activity for the purposes of IdleTimeout. + IdleTimeout time.Duration + + // MaxUploadBufferPerConnection is the size of the initial flow + // control window for each connections. The HTTP/2 spec does not + // allow this to be smaller than 65535 or larger than 2^32-1. + // If the value is outside this range, a default value will be + // used instead. + MaxUploadBufferPerConnection int32 + + // MaxUploadBufferPerStream is the size of the initial flow control + // window for each stream. The HTTP/2 spec does not allow this to + // be larger than 2^32-1. If the value is zero or larger than the + // maximum, a default value will be used instead. + MaxUploadBufferPerStream int32 + + // NewWriteScheduler constructs a write scheduler for a connection. + // If nil, a default scheduler is chosen. + NewWriteScheduler func() http2WriteScheduler + + // CountError, if non-nil, is called on HTTP/2 server errors. + // It's intended to increment a metric for monitoring, such + // as an expvar or Prometheus metric. + // The errType consists of only ASCII word characters. + CountError func(errType string) + + // Internal state. This is a pointer (rather than embedded directly) + // so that we don't embed a Mutex in this struct, which will make the + // struct non-copyable, which might break some callers. + state *http2serverInternalState +} + +func (s *http2Server) initialConnRecvWindowSize() int32 { + if s.MaxUploadBufferPerConnection > http2initialWindowSize { + return s.MaxUploadBufferPerConnection + } + return 1 << 20 +} + +func (s *http2Server) initialStreamRecvWindowSize() int32 { + if s.MaxUploadBufferPerStream > 0 { + return s.MaxUploadBufferPerStream + } + return 1 << 20 +} + +func (s *http2Server) maxReadFrameSize() uint32 { + if v := s.MaxReadFrameSize; v >= http2minMaxFrameSize && v <= http2maxFrameSize { + return v + } + return http2defaultMaxReadFrameSize +} + +func (s *http2Server) maxConcurrentStreams() uint32 { + if v := s.MaxConcurrentStreams; v > 0 { + return v + } + return http2defaultMaxStreams +} + +// maxQueuedControlFrames is the maximum number of control frames like +// SETTINGS, PING and RST_STREAM that will be queued for writing before +// the connection is closed to prevent memory exhaustion attacks. +func (s *http2Server) maxQueuedControlFrames() int { + // TODO: if anybody asks, add a Server field, and remember to define the + // behavior of negative values. + return http2maxQueuedControlFrames +} + +// ServeConn serves HTTP/2 requests on the provided connection and +// blocks until the connection is no longer readable. +// +// ServeConn starts speaking HTTP/2 assuming that c has not had any +// reads or writes. It writes its initial settings frame and expects +// to be able to read the preface and settings frame from the +// client. If c has a ConnectionState method like a *tls.Conn, the +// ConnectionState is used to verify the TLS ciphersuite and to set +// the Request.TLS field in Handlers. +// +// ServeConn does not support h2c by itself. Any h2c support must be +// implemented in terms of providing a suitably-behaving net.Conn. +// +// The opts parameter is optional. If nil, default values are used. +func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { + baseCtx, cancel := http2serverConnBaseContext(c, opts) + defer cancel() + + sc := &http2serverConn{ + srv: s, + hs: opts.baseConfig(), + conn: c, + baseCtx: baseCtx, + remoteAddrStr: c.RemoteAddr().String(), + bw: http2newBufferedWriter(c), + handler: opts.handler(), + streams: make(map[uint32]*http2stream), + readFrameCh: make(chan http2readFrameResult), + wantWriteFrameCh: make(chan http2FrameWriteRequest, 8), + serveMsgCh: make(chan interface{}, 8), + wroteFrameCh: make(chan http2frameWriteResult, 1), // buffered; one send in writeFrameAsync + bodyReadCh: make(chan http2bodyReadMsg), // buffering doesn't matter either way + doneServing: make(chan struct{}), + clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value" + advMaxStreams: s.maxConcurrentStreams(), + initialStreamSendWindowSize: http2initialWindowSize, + maxFrameSize: http2initialMaxFrameSize, + headerTableSize: http2initialHeaderTableSize, + serveG: http2newGoroutineLock(), + pushEnabled: true, + } + + s.state.registerConn(sc) + defer s.state.unregisterConn(sc) + + // The net/http package sets the write deadline from the + // http.Server.WriteTimeout during the TLS handshake, but then + // passes the connection off to us with the deadline already set. + // Write deadlines are set per stream in serverConn.newStream. + // Disarm the net.Conn write deadline here. + if sc.hs.WriteTimeout != 0 { + sc.conn.SetWriteDeadline(time.Time{}) + } + + if s.NewWriteScheduler != nil { + sc.writeSched = s.NewWriteScheduler() + } else { + sc.writeSched = http2NewRandomWriteScheduler() + } + + // These start at the RFC-specified defaults. If there is a higher + // configured value for inflow, that will be updated when we send a + // WINDOW_UPDATE shortly after sending SETTINGS. + sc.flow.add(http2initialWindowSize) + sc.inflow.add(http2initialWindowSize) + sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) + + fr := http2NewFramer(sc.bw, c) + fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + fr.MaxHeaderListSize = sc.maxHeaderListSize() + fr.SetMaxReadFrameSize(s.maxReadFrameSize()) + sc.framer = fr + + if tc, ok := c.(http2connectionStater); ok { + sc.tlsState = new(tls.ConnectionState) + *sc.tlsState = tc.ConnectionState() + // 9.2 Use of TLS Features + // An implementation of HTTP/2 over TLS MUST use TLS + // 1.2 or higher with the restrictions on feature set + // and cipher suite described in this section. Due to + // implementation limitations, it might not be + // possible to fail TLS negotiation. An endpoint MUST + // immediately terminate an HTTP/2 connection that + // does not meet the TLS requirements described in + // this section with a connection error (Section + // 5.4.1) of type INADEQUATE_SECURITY. + if sc.tlsState.Version < tls.VersionTLS12 { + sc.rejectConn(http2ErrCodeInadequateSecurity, "TLS version too low") + return + } + + if sc.tlsState.ServerName == "" { + // Client must use SNI, but we don't enforce that anymore, + // since it was causing problems when connecting to bare IP + // addresses during development. + // + // TODO: optionally enforce? Or enforce at the time we receive + // a new request, and verify the ServerName matches the :authority? + // But that precludes proxy situations, perhaps. + // + // So for now, do nothing here again. + } + + if !s.PermitProhibitedCipherSuites && http2isBadCipher(sc.tlsState.CipherSuite) { + // "Endpoints MAY choose to generate a connection error + // (Section 5.4.1) of type INADEQUATE_SECURITY if one of + // the prohibited cipher suites are negotiated." + // + // We choose that. In my opinion, the spec is weak + // here. It also says both parties must support at least + // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 so there's no + // excuses here. If we really must, we could allow an + // "AllowInsecureWeakCiphers" option on the server later. + // Let's see how it plays out first. + sc.rejectConn(http2ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite)) + return + } + } + + if hook := http2testHookGetServerConn; hook != nil { + hook(sc) + } + sc.serve() +} + +type http2serverInternalState struct { + mu sync.Mutex + activeConns map[*http2serverConn]struct{} +} + +func (s *http2serverInternalState) registerConn(sc *http2serverConn) { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + s.activeConns[sc] = struct{}{} + s.mu.Unlock() +} + +func (s *http2serverInternalState) unregisterConn(sc *http2serverConn) { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + delete(s.activeConns, sc) + s.mu.Unlock() +} + +func (s *http2serverInternalState) startGracefulShutdown() { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + for sc := range s.activeConns { + sc.startGracefulShutdown() + } + s.mu.Unlock() +} + +// ServeConnOpts are options for the Server.ServeConn method. +type http2ServeConnOpts struct { + // Context is the base context to use. + // If nil, context.Background is used. + Context context.Context + + // BaseConfig optionally sets the base configuration + // for values. If nil, defaults are used. + BaseConfig *http.Server + + // Handler specifies which handler to use for processing + // requests. If nil, BaseConfig.Handler is used. If BaseConfig + // or BaseConfig.Handler is nil, http.DefaultServeMux is used. + Handler http.Handler +} + +func (o *http2ServeConnOpts) context() context.Context { + if o != nil && o.Context != nil { + return o.Context + } + return context.Background() +} + +func (o *http2ServeConnOpts) baseConfig() *http.Server { + if o != nil && o.BaseConfig != nil { + return o.BaseConfig + } + return new(http.Server) +} + +func (o *http2ServeConnOpts) handler() http.Handler { + if o != nil { + if o.Handler != nil { + return o.Handler + } + if o.BaseConfig != nil && o.BaseConfig.Handler != nil { + return o.BaseConfig.Handler + } + } + return http.DefaultServeMux +} + +func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx context.Context, cancel func()) { + ctx, cancel = context.WithCancel(opts.context()) + ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr()) + if hs := opts.baseConfig(); hs != nil { + ctx = context.WithValue(ctx, http.ServerContextKey, hs) + } + return +} + +// bufferedWriter is a buffered writer that writes to w. +// Its buffered writer is lazily allocated as needed, to minimize +// idle memory usage with many connections. +type http2bufferedWriter struct { + _ http2incomparable + w io.Writer // immutable + bw *bufio.Writer // non-nil when data is buffered +} + +func http2newBufferedWriter(w io.Writer) *http2bufferedWriter { + return &http2bufferedWriter{w: w} +} + +func (w *http2bufferedWriter) Available() int { + if w.bw == nil { + return http2bufWriterPoolBufferSize + } + return w.bw.Available() +} + +func (w *http2bufferedWriter) Write(p []byte) (n int, err error) { + if w.bw == nil { + bw := http2bufWriterPool.Get().(*bufio.Writer) + bw.Reset(w.w) + w.bw = bw + } + return w.bw.Write(p) +} + +func (w *http2bufferedWriter) Flush() error { + bw := w.bw + if bw == nil { + return nil + } + err := bw.Flush() + bw.Reset(nil) + http2bufWriterPool.Put(bw) + w.bw = nil + return err +} + +func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) { + sc.vlogf("http2: server rejecting conn: %v, %s", err, debug) + // ignoring errors. hanging up anyway. + sc.framer.WriteGoAway(0, err, []byte(debug)) + sc.bw.Flush() + sc.conn.Close() +} + +type http2serverConn struct { + // Immutable: + srv *http2Server + hs *http.Server + conn net.Conn + bw *http2bufferedWriter // writing to conn + handler http.Handler + baseCtx context.Context + framer *http2Framer + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan http2readFrameResult // written by serverConn.readFrames + wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve + wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes + bodyReadCh chan http2bodyReadMsg // from handlers -> serve + serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop + flow http2flow // conn-wide (not stream-specific) outbound flow control + inflow http2flow // conn-wide inbound flow control + tlsState *tls.ConnectionState // shared by all handlers, like net/http + remoteAddrStr string + writeSched http2WriteScheduler + + // Everything following is owned by the serve loop; use serveG.check(): + serveG http2goroutineLock // used to verify funcs are on serve() + pushEnabled bool + sawFirstSettings bool // got the initial SETTINGS frame after the preface + needToSendSettingsAck bool + unackedSettings int // how many SETTINGS have we sent without ACKs? + queuedControlFrames int // control frames in the writeSched queue + clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) + advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client + curClientStreams uint32 // number of open streams initiated by the client + curPushedStreams uint32 // number of open streams initiated by server push + maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests + maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes + streams map[uint32]*http2stream + initialStreamSendWindowSize int32 + maxFrameSize int32 + headerTableSize uint32 + peerMaxHeaderListSize uint32 // zero means unknown (default) + canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case + writingFrame bool // started writing a frame (on serve goroutine or separate) + writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh + needsFrameFlush bool // last frame write wasn't a flush + inGoAway bool // we've started to or sent GOAWAY + inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop + needToSendGoAway bool // we need to schedule a GOAWAY frame write + goAwayCode http2ErrCode + shutdownTimer *time.Timer // nil until used + idleTimer *time.Timer // nil if unused + + // Owned by the writeFrameAsync goroutine: + headerWriteBuf bytes.Buffer + hpackEncoder *hpack.Encoder + + // Used by startGracefulShutdown. + shutdownOnce sync.Once +} + +func (sc *http2serverConn) maxHeaderListSize() uint32 { + n := sc.hs.MaxHeaderBytes + if n <= 0 { + n = http.DefaultMaxHeaderBytes + } + // http2's count is in a slightly different unit and includes 32 bytes per pair. + // So, take the net/http.Server value and pad it up a bit, assuming 10 headers. + const perFieldOverhead = 32 // per http2 spec + const typicalHeaders = 10 // conservative + return uint32(n + typicalHeaders*perFieldOverhead) +} + +func (sc *http2serverConn) curOpenStreams() uint32 { + sc.serveG.check() + return sc.curClientStreams + sc.curPushedStreams +} + +// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). +type http2closeWaiter chan struct{} + +// Init makes a closeWaiter usable. +// It exists because so a closeWaiter value can be placed inside a +// larger struct and have the Mutex and Cond's memory in the same +// allocation. +func (cw *http2closeWaiter) Init() { + *cw = make(chan struct{}) +} + +// Close marks the closeWaiter as closed and unblocks any waiters. +func (cw http2closeWaiter) Close() { + close(cw) +} + +// Wait waits for the closeWaiter to become closed. +func (cw http2closeWaiter) Wait() { + <-cw +} + +// stream represents a stream. This is the minimal metadata needed by +// the serve goroutine. Most of the actual stream state is owned by +// the http.Handler's goroutine in the responseWriter. Because the +// responseWriter's responseWriterState is recycled at the end of a +// handler, this struct intentionally has no pointer to the +// *responseWriter{,State} itself, as the Handler ending nils out the +// responseWriter's state field. +type http2stream struct { + // immutable: + sc *http2serverConn + id uint32 + body *http2pipe // non-nil if expecting DATA frames + cw http2closeWaiter // closed wait stream transitions to closed state + ctx context.Context + cancelCtx func() + + // owned by serverConn's serve loop: + bodyBytes int64 // body bytes seen so far + declBodyBytes int64 // or -1 if undeclared + flow http2flow // limits writing from Handler to client + inflow http2flow // what the client is allowed to POST/etc to us + state http2streamState + resetQueued bool // RST_STREAM queued for write; set by sc.resetStream + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + writeDeadline *time.Timer // nil if unused + + trailer http.Header // accumulated trailers + reqTrailer http.Header // handler's Request.Trailer +} + +func (sc *http2serverConn) Framer() *http2Framer { return sc.framer } + +func (sc *http2serverConn) CloseConn() error { return sc.conn.Close() } + +func (sc *http2serverConn) Flush() error { return sc.bw.Flush() } + +func (sc *http2serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) { + return sc.hpackEncoder, &sc.headerWriteBuf +} + +const ( + // SETTINGS_MAX_FRAME_SIZE default + // http://http2.github.io/http2-spec/#rfc.section.6.5.2 + http2initialMaxFrameSize = 16384 + + http2defaultMaxReadFrameSize = 1 << 20 +) + +type http2streamState int + +// HTTP/2 stream states. +// +// See http://tools.ietf.org/html/rfc7540#section-5.1. +// +// For simplicity, the server code merges "reserved (local)" into +// "half-closed (remote)". This is one less state transition to track. +// The only downside is that we send PUSH_PROMISEs slightly less +// liberally than allowable. More discussion here: +// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html +// +// "reserved (remote)" is omitted since the client code does not +// support server push. +const ( + http2stateIdle http2streamState = iota + http2stateOpen + http2stateHalfClosedLocal + http2stateHalfClosedRemote + http2stateClosed +) + +var http2stateName = [...]string{ + http2stateIdle: "Idle", + http2stateOpen: "Open", + http2stateHalfClosedLocal: "HalfClosedLocal", + http2stateHalfClosedRemote: "HalfClosedRemote", + http2stateClosed: "Closed", +} + +func (st http2streamState) String() string { + return http2stateName[st] +} + +func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2stream) { + sc.serveG.check() + // http://tools.ietf.org/html/rfc7540#section-5.1 + if st, ok := sc.streams[streamID]; ok { + return st.state, st + } + // "The first use of a new stream identifier implicitly closes all + // streams in the "idle" state that might have been initiated by + // that peer with a lower-valued stream identifier. For example, if + // a client sends a HEADERS frame on stream 7 without ever sending a + // frame on stream 5, then stream 5 transitions to the "closed" + // state when the first frame for stream 7 is sent or received." + if streamID%2 == 1 { + if streamID <= sc.maxClientStreamID { + return http2stateClosed, nil + } + } else { + if streamID <= sc.maxPushPromiseID { + return http2stateClosed, nil + } + } + return http2stateIdle, nil +} + +// setConnState calls the net/http ConnState hook for this connection, if configured. +// Note that the net/http package does StateNew and StateClosed for us. +// There is currently no plan for StateHijacked or hijacking HTTP/2 connections. +func (sc *http2serverConn) setConnState(state http.ConnState) { + if sc.hs.ConnState != nil { + sc.hs.ConnState(sc.conn, state) + } +} + +func (sc *http2serverConn) vlogf(format string, args ...interface{}) { + if http2VerboseLogs { + sc.logf(format, args...) + } +} + +func (sc *http2serverConn) logf(format string, args ...interface{}) { + if lg := sc.hs.ErrorLog; lg != nil { + lg.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +// errno returns v's underlying uintptr, else 0. +// +// TODO: remove this helper function once http2 can use build +// tags. See comment in isClosedConnError. +func http2errno(v error) uintptr { + if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr { + return uintptr(rv.Uint()) + } + return 0 +} + +// isClosedConnError reports whether err is an error from use of a closed +// network connection. +func http2isClosedConnError(err error) bool { + if err == nil { + return false + } + + // TODO: remove this string search and be more like the Windows + // case below. That might involve modifying the standard library + // to return better error types. + str := err.Error() + if strings.Contains(str, "use of closed network connection") { + return true + } + + // TODO(bradfitz): x/tools/cmd/bundle doesn't really support + // build tags, so I can't make an http2_windows.go file with + // Windows-specific stuff. Fix that and move this, once we + // have a way to bundle this into std's net/http somehow. + if runtime.GOOS == "windows" { + if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { + if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" { + const WSAECONNABORTED = 10053 + const WSAECONNRESET = 10054 + if n := http2errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED { + return true + } + } + } + } + return false +} + +func (sc *http2serverConn) condlogf(err error, format string, args ...interface{}) { + if err == nil { + return + } + if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) || err == http2errPrefaceTimeout { + // Boring, expected errors. + sc.vlogf(format, args...) + } else { + sc.logf(format, args...) + } +} + +func (sc *http2serverConn) canonicalHeader(v string) string { + sc.serveG.check() + http2buildCommonHeaderMapsOnce() + cv, ok := http2commonCanonHeader[v] + if ok { + return cv + } + cv, ok = sc.canonHeader[v] + if ok { + return cv + } + if sc.canonHeader == nil { + sc.canonHeader = make(map[string]string) + } + cv = http.CanonicalHeaderKey(v) + // maxCachedCanonicalHeaders is an arbitrarily-chosen limit on the number of + // entries in the canonHeader cache. This should be larger than the number + // of unique, uncommon header keys likely to be sent by the peer, while not + // so high as to permit unreasonable memory usage if the peer sends an unbounded + // number of unique header keys. + const maxCachedCanonicalHeaders = 32 + if len(sc.canonHeader) < maxCachedCanonicalHeaders { + sc.canonHeader[v] = cv + } + return cv +} + +type http2readFrameResult struct { + f http2Frame // valid until readMore is called + err error + + // readMore should be called once the consumer no longer needs or + // retains f. After readMore, f is invalid and more frames can be + // read. + readMore func() +} + +// A gate lets two goroutines coordinate their activities. +type http2gate chan struct{} + +func (g http2gate) Done() { g <- struct{}{} } + +func (g http2gate) Wait() { <-g } + +// readFrames is the loop that reads incoming frames. +// It takes care to only read one frame at a time, blocking until the +// consumer is done with the frame. +// It's run on its own goroutine. +func (sc *http2serverConn) readFrames() { + gate := make(http2gate) + gateDone := gate.Done + for { + f, err := sc.framer.ReadFrame() + select { + case sc.readFrameCh <- http2readFrameResult{f, err, gateDone}: + case <-sc.doneServing: + return + } + select { + case <-gate: + case <-sc.doneServing: + return + } + if http2terminalReadFrameError(err) { + return + } + } +} + +// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. +type http2frameWriteResult struct { + _ http2incomparable + wr http2FrameWriteRequest // what was written (or attempted) + err error // result of the writeFrame call +} + +// writeFrameAsync runs in its own goroutine and writes a single frame +// and then reports when it's done. +// At most one goroutine can be running writeFrameAsync at a time per +// serverConn. +func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { + err := wr.write.writeFrame(sc) + sc.wroteFrameCh <- http2frameWriteResult{wr: wr, err: err} +} + +func (sc *http2serverConn) closeAllStreamsOnConnClose() { + sc.serveG.check() + for _, st := range sc.streams { + sc.closeStream(st, http2errClientDisconnected) + } +} + +func (sc *http2serverConn) stopShutdownTimer() { + sc.serveG.check() + if t := sc.shutdownTimer; t != nil { + t.Stop() + } +} + +func (sc *http2serverConn) notePanic() { + // Note: this is for serverConn.serve panicking, not http.Handler code. + if http2testHookOnPanicMu != nil { + http2testHookOnPanicMu.Lock() + defer http2testHookOnPanicMu.Unlock() + } + if http2testHookOnPanic != nil { + if e := recover(); e != nil { + if http2testHookOnPanic(sc, e) { + panic(e) + } + } + } +} + +func (sc *http2serverConn) serve() { + sc.serveG.check() + defer sc.notePanic() + defer sc.conn.Close() + defer sc.closeAllStreamsOnConnClose() + defer sc.stopShutdownTimer() + defer close(sc.doneServing) // unblocks handlers trying to send + + if http2VerboseLogs { + sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) + } + + sc.writeFrame(http2FrameWriteRequest{ + write: http2writeSettings{ + {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, + {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, + {http2SettingMaxHeaderListSize, sc.maxHeaderListSize()}, + {http2SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, + }, + }) + sc.unackedSettings++ + + // Each connection starts with initialWindowSize inflow tokens. + // If a higher value is configured, we add more tokens. + if diff := sc.srv.initialConnRecvWindowSize() - http2initialWindowSize; diff > 0 { + sc.sendWindowUpdate(nil, int(diff)) + } + + if err := sc.readPreface(); err != nil { + sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err) + return + } + // Now that we've got the preface, get us out of the + // "StateNew" state. We can't go directly to idle, though. + // Active means we read some data and anticipate a request. We'll + // do another Active when we get a HEADERS frame. + sc.setConnState(http.StateActive) + sc.setConnState(http.StateIdle) + + if sc.srv.IdleTimeout != 0 { + sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) + defer sc.idleTimer.Stop() + } + + go sc.readFrames() // closed by defer sc.conn.Close above + + settingsTimer := time.AfterFunc(http2firstSettingsTimeout, sc.onSettingsTimer) + defer settingsTimer.Stop() + + loopNum := 0 + for { + loopNum++ + select { + case wr := <-sc.wantWriteFrameCh: + if se, ok := wr.write.(http2StreamError); ok { + sc.resetStream(se) + break + } + sc.writeFrame(wr) + case res := <-sc.wroteFrameCh: + sc.wroteFrame(res) + case res := <-sc.readFrameCh: + // Process any written frames before reading new frames from the client since a + // written frame could have triggered a new stream to be started. + if sc.writingFrameAsync { + select { + case wroteRes := <-sc.wroteFrameCh: + sc.wroteFrame(wroteRes) + default: + } + } + if !sc.processFrameFromReader(res) { + return + } + res.readMore() + if settingsTimer != nil { + settingsTimer.Stop() + settingsTimer = nil + } + case m := <-sc.bodyReadCh: + sc.noteBodyRead(m.st, m.n) + case msg := <-sc.serveMsgCh: + switch v := msg.(type) { + case func(int): + v(loopNum) // for testing + case *http2serverMessage: + switch v { + case http2settingsTimerMsg: + sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) + return + case http2idleTimerMsg: + sc.vlogf("connection is idle") + sc.goAway(http2ErrCodeNo) + case http2shutdownTimerMsg: + sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) + return + case http2gracefulShutdownMsg: + sc.startGracefulShutdownInternal() + default: + panic("unknown timer") + } + case *http2startPushRequest: + sc.startPush(v) + default: + panic(fmt.Sprintf("unexpected type %T", v)) + } + } + + // If the peer is causing us to generate a lot of control frames, + // but not reading them from us, assume they are trying to make us + // run out of memory. + if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() { + sc.vlogf("http2: too many control frames in send queue, closing connection") + return + } + + // Start the shutdown timer after sending a GOAWAY. When sending GOAWAY + // with no error code (graceful shutdown), don't start the timer until + // all open streams have been completed. + sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame + gracefulShutdownComplete := sc.goAwayCode == http2ErrCodeNo && sc.curOpenStreams() == 0 + if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != http2ErrCodeNo || gracefulShutdownComplete) { + sc.shutDownIn(http2goAwayTimeout) + } + } +} + +func (sc *http2serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) { + select { + case <-sc.doneServing: + case <-sharedCh: + close(privateCh) + } +} + +type http2serverMessage int + +// Message values sent to serveMsgCh. +var ( + http2settingsTimerMsg = new(http2serverMessage) + http2idleTimerMsg = new(http2serverMessage) + http2shutdownTimerMsg = new(http2serverMessage) + http2gracefulShutdownMsg = new(http2serverMessage) +) + +func (sc *http2serverConn) onSettingsTimer() { sc.sendServeMsg(http2settingsTimerMsg) } + +func (sc *http2serverConn) onIdleTimer() { sc.sendServeMsg(http2idleTimerMsg) } + +func (sc *http2serverConn) onShutdownTimer() { sc.sendServeMsg(http2shutdownTimerMsg) } + +func (sc *http2serverConn) sendServeMsg(msg interface{}) { + sc.serveG.checkNotOn() // NOT + select { + case sc.serveMsgCh <- msg: + case <-sc.doneServing: + } +} + +var http2errPrefaceTimeout = errors.New("timeout waiting for client preface") + +// readPreface reads the ClientPreface greeting from the peer or +// returns errPrefaceTimeout on timeout, or an error if the greeting +// is invalid. +func (sc *http2serverConn) readPreface() error { + errc := make(chan error, 1) + go func() { + // Read the client preface + buf := make([]byte, len(http2ClientPreface)) + if _, err := io.ReadFull(sc.conn, buf); err != nil { + errc <- err + } else if !bytes.Equal(buf, http2clientPreface) { + errc <- fmt.Errorf("bogus greeting %q", buf) + } else { + errc <- nil + } + }() + timer := time.NewTimer(http2prefaceTimeout) // TODO: configurable on *http2Server? + defer timer.Stop() + select { + case <-timer.C: + return http2errPrefaceTimeout + case err := <-errc: + if err == nil { + if http2VerboseLogs { + sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr()) + } + } + return err + } +} + +var http2errChanPool = sync.Pool{ + New: func() interface{} { return make(chan error, 1) }, +} + +var http2writeDataPool = sync.Pool{ + New: func() interface{} { return new(http2writeData) }, +} + +// writeDataFromHandler writes DATA response frames from a handler on +// the given stream. +func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte, endStream bool) error { + ch := http2errChanPool.Get().(chan error) + writeArg := http2writeDataPool.Get().(*http2writeData) + *writeArg = http2writeData{stream.id, data, endStream} + err := sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: writeArg, + stream: stream, + done: ch, + }) + if err != nil { + return err + } + var frameWriteDone bool // the frame write is done (successfully or not) + select { + case err = <-ch: + frameWriteDone = true + case <-sc.doneServing: + return http2errClientDisconnected + case <-stream.cw: + // If both ch and stream.cw were ready (as might + // happen on the final Write after an http.Handler + // ends), prefer the write result. Otherwise this + // might just be us successfully closing the stream. + // The writeFrameAsync and serve goroutines guarantee + // that the ch send will happen before the stream.cw + // close. + select { + case err = <-ch: + frameWriteDone = true + default: + return http2errStreamClosed + } + } + http2errChanPool.Put(ch) + if frameWriteDone { + http2writeDataPool.Put(writeArg) + } + return err +} + +// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts +// if the connection has gone away. +// +// This must not be run from the serve goroutine itself, else it might +// deadlock writing to sc.wantWriteFrameCh (which is only mildly +// buffered and is read by serve itself). If you're on the serve +// goroutine, call writeFrame instead. +func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error { + sc.serveG.checkNotOn() // NOT + select { + case sc.wantWriteFrameCh <- wr: + return nil + case <-sc.doneServing: + // Serve loop is gone. + // Client has closed their connection to the server. + return http2errClientDisconnected + } +} + +// writeFrame schedules a frame to write and sends it if there's nothing +// already being written. +// +// There is no pushback here (the serve goroutine never blocks). It's +// the http.Handlers that block, waiting for their previous frames to +// make it onto the wire +// +// If you're not on the serve goroutine, use writeFrameFromHandler instead. +func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { + sc.serveG.check() + + // If true, wr will not be written and wr.done will not be signaled. + var ignoreWrite bool + + // We are not allowed to write frames on closed streams. RFC 7540 Section + // 5.1.1 says: "An endpoint MUST NOT send frames other than PRIORITY on + // a closed stream." Our server never sends PRIORITY, so that exception + // does not apply. + // + // The serverConn might close an open stream while the stream's handler + // is still running. For example, the server might close a stream when it + // receives bad data from the client. If this happens, the handler might + // attempt to write a frame after the stream has been closed (since the + // handler hasn't yet been notified of the close). In this case, we simply + // ignore the frame. The handler will notice that the stream is closed when + // it waits for the frame to be written. + // + // As an exception to this rule, we allow sending RST_STREAM after close. + // This allows us to immediately reject new streams without tracking any + // state for those streams (except for the queued RST_STREAM frame). This + // may result in duplicate RST_STREAMs in some cases, but the client should + // ignore those. + if wr.StreamID() != 0 { + _, isReset := wr.write.(http2StreamError) + if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset { + ignoreWrite = true + } + } + + // Don't send a 100-continue response if we've already sent headers. + // See golang.org/issue/14030. + switch wr.write.(type) { + case *http2writeResHeaders: + wr.stream.wroteHeaders = true + case http2write100ContinueHeadersFrame: + if wr.stream.wroteHeaders { + // We do not need to notify wr.done because this frame is + // never written with wr.done != nil. + if wr.done != nil { + panic("wr.done != nil for write100ContinueHeadersFrame") + } + ignoreWrite = true + } + } + + if !ignoreWrite { + if wr.isControl() { + sc.queuedControlFrames++ + // For extra safety, detect wraparounds, which should not happen, + // and pull the plug. + if sc.queuedControlFrames < 0 { + sc.conn.Close() + } + } + sc.writeSched.Push(wr) + } + sc.scheduleFrameWrite() +} + +// startFrameWrite starts a goroutine to write wr (in a separate +// goroutine since that might block on the network), and updates the +// serve goroutine's state about the world, updated from info in wr. +func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { + sc.serveG.check() + if sc.writingFrame { + panic("internal error: can only be writing one frame at a time") + } + + st := wr.stream + if st != nil { + switch st.state { + case http2stateHalfClosedLocal: + switch wr.write.(type) { + case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate: + // RFC 7540 Section 5.1 allows sending RST_STREAM, PRIORITY, and WINDOW_UPDATE + // in this state. (We never send PRIORITY from the server, so that is not checked.) + default: + panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr)) + } + case http2stateClosed: + panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr)) + } + } + if wpp, ok := wr.write.(*http2writePushPromise); ok { + var err error + wpp.promisedID, err = wpp.allocatePromisedID() + if err != nil { + sc.writingFrameAsync = false + wr.replyToWriter(err) + return + } + } + + sc.writingFrame = true + sc.needsFrameFlush = true + if wr.write.staysWithinBuffer(sc.bw.Available()) { + sc.writingFrameAsync = false + err := wr.write.writeFrame(sc) + sc.wroteFrame(http2frameWriteResult{wr: wr, err: err}) + } else { + sc.writingFrameAsync = true + go sc.writeFrameAsync(wr) + } +} + +// errHandlerPanicked is the error given to any callers blocked in a read from +// Request.Body when the main goroutine panics. Since most handlers read in the +// main ServeHTTP goroutine, this will show up rarely. +var http2errHandlerPanicked = errors.New("http2: handler panicked") + +// wroteFrame is called on the serve goroutine with the result of +// whatever happened on writeFrameAsync. +func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { + sc.serveG.check() + if !sc.writingFrame { + panic("internal error: expected to be already writing a frame") + } + sc.writingFrame = false + sc.writingFrameAsync = false + + wr := res.wr + + if http2writeEndsStream(wr.write) { + st := wr.stream + if st == nil { + panic("internal error: expecting non-nil stream") + } + switch st.state { + case http2stateOpen: + // Here we would go to stateHalfClosedLocal in + // theory, but since our handler is done and + // the net/http package provides no mechanism + // for closing a ResponseWriter while still + // reading data (see possible TODO at top of + // this file), we go into closed state here + // anyway, after telling the peer we're + // hanging up on them. We'll transition to + // stateClosed after the RST_STREAM frame is + // written. + st.state = http2stateHalfClosedLocal + // Section 8.1: a server MAY request that the client abort + // transmission of a request without error by sending a + // RST_STREAM with an error code of NO_ERROR after sending + // a complete response. + sc.resetStream(http2streamError(st.id, http2ErrCodeNo)) + case http2stateHalfClosedRemote: + sc.closeStream(st, http2errHandlerComplete) + } + } else { + switch v := wr.write.(type) { + case http2StreamError: + // st may be unknown if the RST_STREAM was generated to reject bad input. + if st, ok := sc.streams[v.StreamID]; ok { + sc.closeStream(st, v) + } + case http2handlerPanicRST: + sc.closeStream(wr.stream, http2errHandlerPanicked) + } + } + + // Reply (if requested) to unblock the ServeHTTP goroutine. + wr.replyToWriter(res.err) + + sc.scheduleFrameWrite() +} + +// scheduleFrameWrite tickles the frame writing scheduler. +// +// If a frame is already being written, nothing happens. This will be called again +// when the frame is done being written. +// +// If a frame isn't being written and we need to send one, the best frame +// to send is selected by writeSched. +// +// If a frame isn't being written and there's nothing else to send, we +// flush the write buffer. +func (sc *http2serverConn) scheduleFrameWrite() { + sc.serveG.check() + if sc.writingFrame || sc.inFrameScheduleLoop { + return + } + sc.inFrameScheduleLoop = true + for !sc.writingFrameAsync { + if sc.needToSendGoAway { + sc.needToSendGoAway = false + sc.startFrameWrite(http2FrameWriteRequest{ + write: &http2writeGoAway{ + maxStreamID: sc.maxClientStreamID, + code: sc.goAwayCode, + }, + }) + continue + } + if sc.needToSendSettingsAck { + sc.needToSendSettingsAck = false + sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}}) + continue + } + if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo { + if wr, ok := sc.writeSched.Pop(); ok { + if wr.isControl() { + sc.queuedControlFrames-- + } + sc.startFrameWrite(wr) + continue + } + } + if sc.needsFrameFlush { + sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}}) + sc.needsFrameFlush = false // after startFrameWrite, since it sets this true + continue + } + break + } + sc.inFrameScheduleLoop = false +} + +// startGracefulShutdown gracefully shuts down a connection. This +// sends GOAWAY with http2ErrCodeNo to tell the client we're gracefully +// shutting down. The connection isn't closed until all current +// streams are done. +// +// startGracefulShutdown returns immediately; it does not wait until +// the connection has shut down. +func (sc *http2serverConn) startGracefulShutdown() { + sc.serveG.checkNotOn() // NOT + sc.shutdownOnce.Do(func() { sc.sendServeMsg(http2gracefulShutdownMsg) }) +} + +// After sending GOAWAY with an error code (non-graceful shutdown), the +// connection will close after goAwayTimeout. +// +// If we close the connection immediately after sending GOAWAY, there may +// be unsent data in our kernel receive buffer, which will cause the kernel +// to send a TCP RST on close() instead of a FIN. This RST will abort the +// connection immediately, whether or not the client had received the GOAWAY. +// +// Ideally we should delay for at least 1 RTT + epsilon so the client has +// a chance to read the GOAWAY and stop sending messages. Measuring RTT +// is hard, so we approximate with 1 second. See golang.org/issue/18701. +// +// This is a var so it can be shorter in tests, where all requests uses the +// loopback interface making the expected RTT very small. +// +// TODO: configurable? +var http2goAwayTimeout = 1 * time.Second + +func (sc *http2serverConn) startGracefulShutdownInternal() { + sc.goAway(http2ErrCodeNo) +} + +func (sc *http2serverConn) goAway(code http2ErrCode) { + sc.serveG.check() + if sc.inGoAway { + return + } + sc.inGoAway = true + sc.needToSendGoAway = true + sc.goAwayCode = code + sc.scheduleFrameWrite() +} + +func (sc *http2serverConn) shutDownIn(d time.Duration) { + sc.serveG.check() + sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) +} + +func (sc *http2serverConn) resetStream(se http2StreamError) { + sc.serveG.check() + sc.writeFrame(http2FrameWriteRequest{write: se}) + if st, ok := sc.streams[se.StreamID]; ok { + st.resetQueued = true + } +} + +// 6.9.1 The Flow Control Window +// "If a sender receives a WINDOW_UPDATE that causes a flow control +// window to exceed this maximum it MUST terminate either the stream +// or the connection, as appropriate. For streams, [...]; for the +// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code." +type http2goAwayFlowError struct{} + +func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" } + +// processFrameFromReader processes the serve loop's read from readFrameCh from the +// frame-reading goroutine. +// processFrameFromReader returns whether the connection should be kept open. +func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool { + sc.serveG.check() + err := res.err + if err != nil { + if err == http2ErrFrameTooLarge { + sc.goAway(http2ErrCodeFrameSize) + return true // goAway will close the loop + } + clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) + if clientGone { + // TODO: could we also get into this state if + // the peer does a half close + // (e.g. CloseWrite) because they're done + // sending frames but they're still wanting + // our open replies? Investigate. + // TODO: add CloseWrite to crypto/tls.Conn first + // so we have a way to test this? I suppose + // just for testing we could have a non-TLS mode. + return false + } + } else { + f := res.f + if http2VerboseLogs { + sc.vlogf("http2: server read frame %v", http2summarizeFrame(f)) + } + err = sc.processFrame(f) + if err == nil { + return true + } + } + + switch ev := err.(type) { + case http2StreamError: + sc.resetStream(ev) + return true + case http2goAwayFlowError: + sc.goAway(http2ErrCodeFlowControl) + return true + case http2ConnectionError: + sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev) + sc.goAway(http2ErrCode(ev)) + return true // goAway will handle shutdown + default: + if res.err != nil { + sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err) + } else { + sc.logf("http2: server closing client connection: %v", err) + } + return false + } +} + +func (sc *http2serverConn) processFrame(f http2Frame) error { + sc.serveG.check() + + // First frame received must be SETTINGS. + if !sc.sawFirstSettings { + if _, ok := f.(*http2SettingsFrame); !ok { + return sc.countError("first_settings", http2ConnectionError(http2ErrCodeProtocol)) + } + sc.sawFirstSettings = true + } + + switch f := f.(type) { + case *http2SettingsFrame: + return sc.processSettings(f) + case *http2MetaHeadersFrame: + return sc.processHeaders(f) + case *http2WindowUpdateFrame: + return sc.processWindowUpdate(f) + case *http2PingFrame: + return sc.processPing(f) + case *http2DataFrame: + return sc.processData(f) + case *http2RSTStreamFrame: + return sc.processResetStream(f) + case *http2PriorityFrame: + return sc.processPriority(f) + case *http2GoAwayFrame: + return sc.processGoAway(f) + case *http2PushPromiseFrame: + // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE + // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. + return sc.countError("push_promise", http2ConnectionError(http2ErrCodeProtocol)) + default: + sc.vlogf("http2: server ignoring frame: %v", f.Header()) + return nil + } +} + +func (sc *http2serverConn) processPing(f *http2PingFrame) error { + sc.serveG.check() + if f.IsAck() { + // 6.7 PING: " An endpoint MUST NOT respond to PING frames + // containing this flag." + return nil + } + if f.StreamID != 0 { + // "PING frames are not associated with any individual + // stream. If a PING frame is received with a stream + // identifier field value other than 0x0, the recipient MUST + // respond with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR." + return sc.countError("ping_on_stream", http2ConnectionError(http2ErrCodeProtocol)) + } + if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { + return nil + } + sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}}) + return nil +} + +func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error { + sc.serveG.check() + switch { + case f.StreamID != 0: // stream-level flow control + state, st := sc.state(f.StreamID) + if state == http2stateIdle { + // Section 5.1: "Receiving any frame other than HEADERS + // or PRIORITY on a stream in this state MUST be + // treated as a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR." + return sc.countError("stream_idle", http2ConnectionError(http2ErrCodeProtocol)) + } + if st == nil { + // "WINDOW_UPDATE can be sent by a peer that has sent a + // frame bearing the END_STREAM flag. This means that a + // receiver could receive a WINDOW_UPDATE frame on a "half + // closed (remote)" or "closed" stream. A receiver MUST + // NOT treat this as an error, see Section 5.1." + return nil + } + if !st.flow.add(int32(f.Increment)) { + return sc.countError("bad_flow", http2streamError(f.StreamID, http2ErrCodeFlowControl)) + } + default: // connection-level flow control + if !sc.flow.add(int32(f.Increment)) { + return http2goAwayFlowError{} + } + } + sc.scheduleFrameWrite() + return nil +} + +func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { + sc.serveG.check() + + state, st := sc.state(f.StreamID) + if state == http2stateIdle { + // 6.4 "RST_STREAM frames MUST NOT be sent for a + // stream in the "idle" state. If a RST_STREAM frame + // identifying an idle stream is received, the + // recipient MUST treat this as a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + return sc.countError("reset_idle_stream", http2ConnectionError(http2ErrCodeProtocol)) + } + if st != nil { + st.cancelCtx() + sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode)) + } + return nil +} + +func (sc *http2serverConn) closeStream(st *http2stream, err error) { + sc.serveG.check() + if st.state == http2stateIdle || st.state == http2stateClosed { + panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) + } + st.state = http2stateClosed + if st.writeDeadline != nil { + st.writeDeadline.Stop() + } + if st.isPushed() { + sc.curPushedStreams-- + } else { + sc.curClientStreams-- + } + delete(sc.streams, st.id) + if len(sc.streams) == 0 { + sc.setConnState(http.StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer.Reset(sc.srv.IdleTimeout) + } + if http2h1ServerKeepAlivesDisabled(sc.hs) { + sc.startGracefulShutdownInternal() + } + } + if p := st.body; p != nil { + // Return any buffered unread bytes worth of conn-level flow control. + // See golang.org/issue/16481 + sc.sendWindowUpdate(nil, p.Len()) + + p.CloseWithError(err) + } + st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc + sc.writeSched.CloseStream(st.id) +} + +func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { + sc.serveG.check() + if f.IsAck() { + sc.unackedSettings-- + if sc.unackedSettings < 0 { + // Why is the peer ACKing settings we never sent? + // The spec doesn't mention this case, but + // hang up on them anyway. + return sc.countError("ack_mystery", http2ConnectionError(http2ErrCodeProtocol)) + } + return nil + } + if f.NumSettings() > 100 || f.HasDuplicates() { + // This isn't actually in the spec, but hang up on + // suspiciously large settings frames or those with + // duplicate entries. + return sc.countError("settings_big_or_dups", http2ConnectionError(http2ErrCodeProtocol)) + } + if err := f.ForeachSetting(sc.processSetting); err != nil { + return err + } + // TODO: judging by RFC 7540, Section 6.5.3 each SETTINGS frame should be + // acknowledged individually, even if multiple are received before the ACK. + sc.needToSendSettingsAck = true + sc.scheduleFrameWrite() + return nil +} + +func (sc *http2serverConn) processSetting(s http2Setting) error { + sc.serveG.check() + if err := s.Valid(); err != nil { + return err + } + if http2VerboseLogs { + sc.vlogf("http2: server processing setting %v", s) + } + switch s.ID { + case http2SettingHeaderTableSize: + sc.headerTableSize = s.Val + sc.hpackEncoder.SetMaxDynamicTableSize(s.Val) + case http2SettingEnablePush: + sc.pushEnabled = s.Val != 0 + case http2SettingMaxConcurrentStreams: + sc.clientMaxStreams = s.Val + case http2SettingInitialWindowSize: + return sc.processSettingInitialWindowSize(s.Val) + case http2SettingMaxFrameSize: + sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31 + case http2SettingMaxHeaderListSize: + sc.peerMaxHeaderListSize = s.Val + default: + // Unknown setting: "An endpoint that receives a SETTINGS + // frame with any unknown or unsupported identifier MUST + // ignore that setting." + if http2VerboseLogs { + sc.vlogf("http2: server ignoring unknown setting %v", s) + } + } + return nil +} + +func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { + sc.serveG.check() + // Note: val already validated to be within range by + // processSetting's Valid call. + + // "A SETTINGS frame can alter the initial flow control window + // size for all current streams. When the value of + // SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST + // adjust the size of all stream flow control windows that it + // maintains by the difference between the new value and the + // old value." + old := sc.initialStreamSendWindowSize + sc.initialStreamSendWindowSize = int32(val) + growth := int32(val) - old // may be negative + for _, st := range sc.streams { + if !st.flow.add(growth) { + // 6.9.2 Initial Flow Control Window Size + // "An endpoint MUST treat a change to + // SETTINGS_INITIAL_WINDOW_SIZE that causes any flow + // control window to exceed the maximum size as a + // connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR." + return sc.countError("setting_win_size", http2ConnectionError(http2ErrCodeFlowControl)) + } + } + return nil +} + +func (sc *http2serverConn) processData(f *http2DataFrame) error { + sc.serveG.check() + id := f.Header().StreamID + if sc.inGoAway && (sc.goAwayCode != http2ErrCodeNo || id > sc.maxClientStreamID) { + // Discard all DATA frames if the GOAWAY is due to an + // error, or: + // + // Section 6.8: After sending a GOAWAY frame, the sender + // can discard frames for streams initiated by the + // receiver with identifiers higher than the identified + // last stream. + return nil + } + + data := f.Data() + state, st := sc.state(id) + if id == 0 || state == http2stateIdle { + // Section 6.1: "DATA frames MUST be associated with a + // stream. If a DATA frame is received whose stream + // identifier field is 0x0, the recipient MUST respond + // with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR." + // + // Section 5.1: "Receiving any frame other than HEADERS + // or PRIORITY on a stream in this state MUST be + // treated as a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR." + return sc.countError("data_on_idle", http2ConnectionError(http2ErrCodeProtocol)) + } + + // "If a DATA frame is received whose stream is not in "open" + // or "half closed (local)" state, the recipient MUST respond + // with a stream error (Section 5.4.2) of type STREAM_CLOSED." + if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued { + // This includes sending a RST_STREAM if the stream is + // in stateHalfClosedLocal (which currently means that + // the http.Handler returned, so it's done reading & + // done writing). Try to stop the client from sending + // more DATA. + + // But still enforce their connection-level flow control, + // and return any flow control bytes since we're not going + // to consume them. + if sc.inflow.available() < int32(f.Length) { + return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl)) + } + // Deduct the flow control from inflow, since we're + // going to immediately add it back in + // sendWindowUpdate, which also schedules sending the + // frames. + sc.inflow.take(int32(f.Length)) + sc.sendWindowUpdate(nil, int(f.Length)) // conn-level + + if st != nil && st.resetQueued { + // Already have a stream error in flight. Don't send another. + return nil + } + return sc.countError("closed", http2streamError(id, http2ErrCodeStreamClosed)) + } + if st.body == nil { + panic("internal error: should have a body in this state") + } + + // Sender sending more than they'd declared? + if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { + st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) + // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the + // value of a content-length header field does not equal the sum of the + // DATA frame payload lengths that form the body. + return sc.countError("send_too_much", http2streamError(id, http2ErrCodeProtocol)) + } + if f.Length > 0 { + // Check whether the client has flow control quota. + if st.inflow.available() < int32(f.Length) { + return sc.countError("flow_on_data_length", http2streamError(id, http2ErrCodeFlowControl)) + } + st.inflow.take(int32(f.Length)) + + if len(data) > 0 { + wrote, err := st.body.Write(data) + if err != nil { + sc.sendWindowUpdate(nil, int(f.Length)-wrote) + return sc.countError("body_write_err", http2streamError(id, http2ErrCodeStreamClosed)) + } + if wrote != len(data) { + panic("internal error: bad Writer") + } + st.bodyBytes += int64(len(data)) + } + + // Return any padded flow control now, since we won't + // refund it later on body reads. + if pad := int32(f.Length) - int32(len(data)); pad > 0 { + sc.sendWindowUpdate32(nil, pad) + sc.sendWindowUpdate32(st, pad) + } + } + if f.StreamEnded() { + st.endStream() + } + return nil +} + +func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error { + sc.serveG.check() + if f.ErrCode != http2ErrCodeNo { + sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } else { + sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } + sc.startGracefulShutdownInternal() + // http://tools.ietf.org/html/rfc7540#section-6.8 + // We should not create any new streams, which means we should disable push. + sc.pushEnabled = false + return nil +} + +// isPushed reports whether the stream is server-initiated. +func (st *http2stream) isPushed() bool { + return st.id%2 == 0 +} + +// endStream closes a Request.Body's pipe. It is called when a DATA +// frame says a request body is over (or after trailers). +func (st *http2stream) endStream() { + sc := st.sc + sc.serveG.check() + + if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes { + st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes", + st.declBodyBytes, st.bodyBytes)) + } else { + st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest) + st.body.CloseWithError(io.EOF) + } + st.state = http2stateHalfClosedRemote +} + +// copyTrailersToHandlerRequest is run in the Handler's goroutine in +// its Request.Body.Read just before it gets io.EOF. +func (st *http2stream) copyTrailersToHandlerRequest() { + for k, vv := range st.trailer { + if _, ok := st.reqTrailer[k]; ok { + // Only copy it over it was pre-declared. + st.reqTrailer[k] = vv + } + } +} + +// onWriteTimeout is run on its own goroutine (from time.AfterFunc) +// when the stream's WriteTimeout has fired. +func (st *http2stream) onWriteTimeout() { + st.sc.writeFrameFromHandler(http2FrameWriteRequest{write: http2streamError(st.id, http2ErrCodeInternal)}) +} + +func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { + sc.serveG.check() + id := f.StreamID + if sc.inGoAway { + // Ignore. + return nil + } + // http://tools.ietf.org/html/rfc7540#section-5.1.1 + // Streams initiated by a client MUST use odd-numbered stream + // identifiers. [...] An endpoint that receives an unexpected + // stream identifier MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + if id%2 != 1 { + return sc.countError("headers_even", http2ConnectionError(http2ErrCodeProtocol)) + } + // A HEADERS frame can be used to create a new stream or + // send a trailer for an open one. If we already have a stream + // open, let it process its own HEADERS frame (trailers at this + // point, if it's valid). + if st := sc.streams[f.StreamID]; st != nil { + if st.resetQueued { + // We're sending RST_STREAM to close the stream, so don't bother + // processing this frame. + return nil + } + // RFC 7540, sec 5.1: If an endpoint receives additional frames, other than + // WINDOW_UPDATE, PRIORITY, or RST_STREAM, for a stream that is in + // this state, it MUST respond with a stream error (Section 5.4.2) of + // type STREAM_CLOSED. + if st.state == http2stateHalfClosedRemote { + return sc.countError("headers_half_closed", http2streamError(id, http2ErrCodeStreamClosed)) + } + return st.processTrailerHeaders(f) + } + + // [...] The identifier of a newly established stream MUST be + // numerically greater than all streams that the initiating + // endpoint has opened or reserved. [...] An endpoint that + // receives an unexpected stream identifier MUST respond with + // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. + if id <= sc.maxClientStreamID { + return sc.countError("stream_went_down", http2ConnectionError(http2ErrCodeProtocol)) + } + sc.maxClientStreamID = id + + if sc.idleTimer != nil { + sc.idleTimer.Stop() + } + + // http://tools.ietf.org/html/rfc7540#section-5.1.2 + // [...] Endpoints MUST NOT exceed the limit set by their peer. An + // endpoint that receives a HEADERS frame that causes their + // advertised concurrent stream limit to be exceeded MUST treat + // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR + // or REFUSED_STREAM. + if sc.curClientStreams+1 > sc.advMaxStreams { + if sc.unackedSettings == 0 { + // They should know better. + return sc.countError("over_max_streams", http2streamError(id, http2ErrCodeProtocol)) + } + // Assume it's a network race, where they just haven't + // received our last SETTINGS update. But actually + // this can't happen yet, because we don't yet provide + // a way for users to adjust server parameters at + // runtime. + return sc.countError("over_max_streams_race", http2streamError(id, http2ErrCodeRefusedStream)) + } + + initialState := http2stateOpen + if f.StreamEnded() { + initialState = http2stateHalfClosedRemote + } + st := sc.newStream(id, 0, initialState) + + if f.HasPriority() { + if err := sc.checkPriority(f.StreamID, f.Priority); err != nil { + return err + } + sc.writeSched.AdjustStream(st.id, f.Priority) + } + + rw, req, err := sc.newWriterAndRequest(st, f) + if err != nil { + return err + } + st.reqTrailer = req.Trailer + if st.reqTrailer != nil { + st.trailer = make(http.Header) + } + st.body = req.Body.(*http2requestBody).pipe // may be nil + st.declBodyBytes = req.ContentLength + + handler := sc.handler.ServeHTTP + if f.Truncated { + // Their header list was too long. Send a 431 error. + handler = http2handleHeaderListTooLong + } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil { + handler = http2new400Handler(err) + } + + // The net/http package sets the read deadline from the + // http.Server.ReadTimeout during the TLS handshake, but then + // passes the connection off to us with the deadline already + // set. Disarm it here after the request headers are read, + // similar to how the http1 server works. Here it's + // technically more like the http1 Server's ReadHeaderTimeout + // (in Go 1.8), though. That's a more sane option anyway. + if sc.hs.ReadTimeout != 0 { + sc.conn.SetReadDeadline(time.Time{}) + } + + go sc.runHandler(rw, req, handler) + return nil +} + +func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { + sc := st.sc + sc.serveG.check() + if st.gotTrailerHeader { + return sc.countError("dup_trailers", http2ConnectionError(http2ErrCodeProtocol)) + } + st.gotTrailerHeader = true + if !f.StreamEnded() { + return sc.countError("trailers_not_ended", http2streamError(st.id, http2ErrCodeProtocol)) + } + + if len(f.PseudoFields()) > 0 { + return sc.countError("trailers_pseudo", http2streamError(st.id, http2ErrCodeProtocol)) + } + if st.trailer != nil { + for _, hf := range f.RegularFields() { + key := sc.canonicalHeader(hf.Name) + if !httpguts.ValidTrailerHeader(key) { + // TODO: send more details to the peer somehow. But http2 has + // no way to send debug data at a stream level. Discuss with + // HTTP folk. + return sc.countError("trailers_bogus", http2streamError(st.id, http2ErrCodeProtocol)) + } + st.trailer[key] = append(st.trailer[key], hf.Value) + } + } + st.endStream() + return nil +} + +func (sc *http2serverConn) checkPriority(streamID uint32, p http2PriorityParam) error { + if streamID == p.StreamDep { + // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat + // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR." + // Section 5.3.3 says that a stream can depend on one of its dependencies, + // so it's only self-dependencies that are forbidden. + return sc.countError("priority", http2streamError(streamID, http2ErrCodeProtocol)) + } + return nil +} + +func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { + if sc.inGoAway { + return nil + } + if err := sc.checkPriority(f.StreamID, f.http2PriorityParam); err != nil { + return err + } + sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam) + return nil +} + +func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream { + sc.serveG.check() + if id == 0 { + panic("internal error: cannot create stream with id 0") + } + + ctx, cancelCtx := context.WithCancel(sc.baseCtx) + st := &http2stream{ + sc: sc, + id: id, + state: state, + ctx: ctx, + cancelCtx: cancelCtx, + } + st.cw.Init() + st.flow.conn = &sc.flow // link to conn-level counter + st.flow.add(sc.initialStreamSendWindowSize) + st.inflow.conn = &sc.inflow // link to conn-level counter + st.inflow.add(sc.srv.initialStreamRecvWindowSize()) + if sc.hs.WriteTimeout != 0 { + st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) + } + + sc.streams[id] = st + sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID}) + if st.isPushed() { + sc.curPushedStreams++ + } else { + sc.curClientStreams++ + } + if sc.curOpenStreams() == 1 { + sc.setConnState(http.StateActive) + } + + return st +} + +func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *http.Request, error) { + sc.serveG.check() + + rp := http2requestParam{ + method: f.PseudoValue("method"), + scheme: f.PseudoValue("scheme"), + authority: f.PseudoValue("authority"), + path: f.PseudoValue("path"), + } + + isConnect := rp.method == "CONNECT" + if isConnect { + if rp.path != "" || rp.scheme != "" || rp.authority == "" { + return nil, nil, sc.countError("bad_connect", http2streamError(f.StreamID, http2ErrCodeProtocol)) + } + } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { + // See 8.1.2.6 Malformed Requests and Responses: + // + // Malformed requests or responses that are detected + // MUST be treated as a stream error (Section 5.4.2) + // of type PROTOCOL_ERROR." + // + // 8.1.2.3 Request Pseudo-Header Fields + // "All HTTP/2 requests MUST include exactly one valid + // value for the :method, :scheme, and :path + // pseudo-header fields" + return nil, nil, sc.countError("bad_path_method", http2streamError(f.StreamID, http2ErrCodeProtocol)) + } + + bodyOpen := !f.StreamEnded() + if rp.method == "HEAD" && bodyOpen { + // HEAD requests can't have bodies + return nil, nil, sc.countError("head_body", http2streamError(f.StreamID, http2ErrCodeProtocol)) + } + + rp.header = make(http.Header) + for _, hf := range f.RegularFields() { + rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) + } + if rp.authority == "" { + rp.authority = rp.header.Get("Host") + } + + rw, req, err := sc.newWriterAndRequestNoBody(st, rp) + if err != nil { + return nil, nil, err + } + if bodyOpen { + if vv, ok := rp.header["Content-Length"]; ok { + if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil { + req.ContentLength = int64(cl) + } else { + req.ContentLength = 0 + } + } else { + req.ContentLength = -1 + } + req.Body.(*http2requestBody).pipe = &http2pipe{ + b: &http2dataBuffer{expected: req.ContentLength}, + } + } + return rw, req, nil +} + +type http2requestParam struct { + method string + scheme, authority, path string + header http.Header +} + +func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *http.Request, error) { + sc.serveG.check() + + var tlsState *tls.ConnectionState // nil if not scheme https + if rp.scheme == "https" { + tlsState = sc.tlsState + } + + needsContinue := rp.header.Get("Expect") == "100-continue" + if needsContinue { + rp.header.Del("Expect") + } + // Merge Cookie headers into one "; "-delimited value. + if cookies := rp.header["Cookie"]; len(cookies) > 1 { + rp.header.Set("Cookie", strings.Join(cookies, "; ")) + } + + // Setup Trailers + var trailer http.Header + for _, v := range rp.header["Trailer"] { + for _, key := range strings.Split(v, ",") { + key = http.CanonicalHeaderKey(textproto.TrimString(key)) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + // Bogus. (copy of http1 rules) + // Ignore. + default: + if trailer == nil { + trailer = make(http.Header) + } + trailer[key] = nil + } + } + } + delete(rp.header, "Trailer") + + var url_ *url.URL + var requestURI string + if rp.method == "CONNECT" { + url_ = &url.URL{Host: rp.authority} + requestURI = rp.authority // mimic HTTP/1 server behavior + } else { + var err error + url_, err = url.ParseRequestURI(rp.path) + if err != nil { + return nil, nil, sc.countError("bad_path", http2streamError(st.id, http2ErrCodeProtocol)) + } + requestURI = rp.path + } + + body := &http2requestBody{ + conn: sc, + stream: st, + needsContinue: needsContinue, + } + req := &http.Request{ + Method: rp.method, + URL: url_, + RemoteAddr: sc.remoteAddrStr, + Header: rp.header, + RequestURI: requestURI, + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + TLS: tlsState, + Host: rp.authority, + Body: body, + Trailer: trailer, + } + req = req.WithContext(st.ctx) + + rws := http2responseWriterStatePool.Get().(*http2responseWriterState) + bwSave := rws.bw + *rws = http2responseWriterState{} // zero all the fields + rws.conn = sc + rws.bw = bwSave + rws.bw.Reset(http2chunkWriter{rws}) + rws.stream = st + rws.req = req + rws.body = body + + rw := &http2responseWriter{rws: rws} + return rw, req, nil +} + +// Run on its own goroutine. +func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { + didPanic := true + defer func() { + rw.rws.stream.cancelCtx() + if didPanic { + e := recover() + sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: http2handlerPanicRST{rw.rws.stream.id}, + stream: rw.rws.stream, + }) + // Same as net/http: + if e != nil && e != http.ErrAbortHandler { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) + } + return + } + rw.handlerDone() + }() + handler(rw, req) + didPanic = false +} + +func http2handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) { + // 10.5.1 Limits on Header Block Size: + // .. "A server that receives a larger header block than it is + // willing to handle can send an HTTP 431 (Request Header Fields Too + // Large) status code" + const statusRequestHeaderFieldsTooLarge = 431 // only in Go 1.6+ + w.WriteHeader(statusRequestHeaderFieldsTooLarge) + io.WriteString(w, "

HTTP Error 431

Request Header Field(s) Too Large

") +} + +// called from handler goroutines. +// h may be nil. +func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeResHeaders) error { + sc.serveG.checkNotOn() // NOT on + var errc chan error + if headerData.h != nil { + // If there's a header map (which we don't own), so we have to block on + // waiting for this frame to be written, so an http.Flush mid-handler + // writes out the correct value of keys, before a handler later potentially + // mutates it. + errc = http2errChanPool.Get().(chan error) + } + if err := sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: headerData, + stream: st, + done: errc, + }); err != nil { + return err + } + if errc != nil { + select { + case err := <-errc: + http2errChanPool.Put(errc) + return err + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + } + } + return nil +} + +// called from handler goroutines. +func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) { + sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: http2write100ContinueHeadersFrame{st.id}, + stream: st, + }) +} + +// A bodyReadMsg tells the server loop that the http.Handler read n +// bytes of the DATA from the client on the given stream. +type http2bodyReadMsg struct { + st *http2stream + n int +} + +// called from handler goroutines. +// Notes that the handler for the given stream ID read n bytes of its body +// and schedules flow control tokens to be sent. +func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) { + sc.serveG.checkNotOn() // NOT on + if n > 0 { + select { + case sc.bodyReadCh <- http2bodyReadMsg{st, n}: + case <-sc.doneServing: + } + } +} + +func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) { + sc.serveG.check() + sc.sendWindowUpdate(nil, n) // conn-level + if st.state != http2stateHalfClosedRemote && st.state != http2stateClosed { + // Don't send this WINDOW_UPDATE if the stream is closed + // remotely. + sc.sendWindowUpdate(st, n) + } +} + +// st may be nil for conn-level +func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) { + sc.serveG.check() + // "The legal range for the increment to the flow control + // window is 1 to 2^31-1 (2,147,483,647) octets." + // A Go Read call on 64-bit machines could in theory read + // a larger Read than this. Very unlikely, but we handle it here + // rather than elsewhere for now. + const maxUint31 = 1<<31 - 1 + for n >= maxUint31 { + sc.sendWindowUpdate32(st, maxUint31) + n -= maxUint31 + } + sc.sendWindowUpdate32(st, int32(n)) +} + +// st may be nil for conn-level +func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { + sc.serveG.check() + if n == 0 { + return + } + if n < 0 { + panic("negative update") + } + var streamID uint32 + if st != nil { + streamID = st.id + } + sc.writeFrame(http2FrameWriteRequest{ + write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)}, + stream: st, + }) + var ok bool + if st == nil { + ok = sc.inflow.add(n) + } else { + ok = st.inflow.add(n) + } + if !ok { + panic("internal error; sent too many window updates without decrements?") + } +} + +// requestBody is the Handler's Request.Body type. +// Read and Close may be called concurrently. +type http2requestBody struct { + _ http2incomparable + stream *http2stream + conn *http2serverConn + closed bool // for use by Close only + sawEOF bool // for use by Read only + pipe *http2pipe // non-nil if we have a HTTP entity message body + needsContinue bool // need to send a 100-continue +} + +func (b *http2requestBody) Close() error { + if b.pipe != nil && !b.closed { + b.pipe.BreakWithError(http2errClosedBody) + } + b.closed = true + return nil +} + +func (b *http2requestBody) Read(p []byte) (n int, err error) { + if b.needsContinue { + b.needsContinue = false + b.conn.write100ContinueHeaders(b.stream) + } + if b.pipe == nil || b.sawEOF { + return 0, io.EOF + } + n, err = b.pipe.Read(p) + if err == io.EOF { + b.sawEOF = true + } + if b.conn == nil && http2inTests { + return + } + b.conn.noteBodyReadFromHandler(b.stream, n, err) + return +} + +// responseWriter is the http.ResponseWriter implementation. It's +// intentionally small (1 pointer wide) to minimize garbage. The +// responseWriterState pointer inside is zeroed at the end of a +// request (in handlerDone) and calls on the responseWriter thereafter +// simply crash (caller's mistake), but the much larger responseWriterState +// and buffers are reused between multiple requests. +type http2responseWriter struct { + rws *http2responseWriterState +} + +// from pkg io +type http2stringWriter interface { + WriteString(s string) (n int, err error) +} + +// Optional http.ResponseWriter interfaces implemented. +var ( + _ http.CloseNotifier = (*http2responseWriter)(nil) + _ http.Flusher = (*http2responseWriter)(nil) + _ http2stringWriter = (*http2responseWriter)(nil) +) + +type http2responseWriterState struct { + // immutable within a request: + stream *http2stream + req *http.Request + body *http2requestBody // to close at end of request, if DATA frames didn't + conn *http2serverConn + + // TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc + bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState} + + // mutated by http.Handler goroutine: + handlerHeader http.Header // nil until called + snapHeader http.Header // snapshot of handlerHeader at WriteHeader time + trailers []string // set in writeChunk + status int // status code passed to WriteHeader + wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. + sentHeader bool // have we sent the header frame? + handlerDone bool // handler has finished + dirty bool // a Write failed; don't reuse this responseWriterState + + sentContentLen int64 // non-zero if handler set a Content-Length header + wroteBytes int64 + + closeNotifierMu sync.Mutex // guards closeNotifierCh + closeNotifierCh chan bool // nil until first used +} + +type http2chunkWriter struct{ rws *http2responseWriterState } + +func (cw http2chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } + +func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 } + +func (rws *http2responseWriterState) hasNonemptyTrailers() bool { + for _, trailer := range rws.trailers { + if _, ok := rws.handlerHeader[trailer]; ok { + return true + } + } + return false +} + +// declareTrailer is called for each Trailer header when the +// response header is written. It notes that a header will need to be +// written in the trailers at the end of the response. +func (rws *http2responseWriterState) declareTrailer(k string) { + k = http.CanonicalHeaderKey(k) + if !httpguts.ValidTrailerHeader(k) { + // Forbidden by RFC 7230, section 4.1.2. + rws.conn.logf("ignoring invalid trailer %q", k) + return + } + if !http2strSliceContains(rws.trailers, k) { + rws.trailers = append(rws.trailers, k) + } +} + +// writeChunk writes chunks from the bufio.Writer. But because +// bufio.Writer may bypass its chunking, sometimes p may be +// arbitrarily large. +// +// writeChunk is also responsible (on the first chunk) for sending the +// HEADER response. +func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { + if !rws.wroteHeader { + rws.writeHeader(200) + } + + isHeadResp := rws.req.Method == "HEAD" + if !rws.sentHeader { + rws.sentHeader = true + var ctype, clen string + if clen = rws.snapHeader.Get("Content-Length"); clen != "" { + rws.snapHeader.Del("Content-Length") + if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { + rws.sentContentLen = int64(cl) + } else { + clen = "" + } + } + if clen == "" && rws.handlerDone && http2bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) { + clen = strconv.Itoa(len(p)) + } + _, hasContentType := rws.snapHeader["Content-Type"] + // If the Content-Encoding is non-blank, we shouldn't + // sniff the body. See Issue golang.org/issue/31753. + ce := rws.snapHeader.Get("Content-Encoding") + hasCE := len(ce) > 0 + if !hasCE && !hasContentType && http2bodyAllowedForStatus(rws.status) && len(p) > 0 { + ctype = http.DetectContentType(p) + } + var date string + if _, ok := rws.snapHeader["Date"]; !ok { + // TODO(bradfitz): be faster here, like net/http? measure. + date = time.Now().UTC().Format(http.TimeFormat) + } + + for _, v := range rws.snapHeader["Trailer"] { + http2foreachHeaderElement(v, rws.declareTrailer) + } + + // "Connection" headers aren't allowed in HTTP/2 (RFC 7540, 8.1.2.2), + // but respect "Connection" == "close" to mean sending a GOAWAY and tearing + // down the TCP connection when idle, like we do for HTTP/1. + // TODO: remove more Connection-specific header fields here, in addition + // to "Connection". + if _, ok := rws.snapHeader["Connection"]; ok { + v := rws.snapHeader.Get("Connection") + delete(rws.snapHeader, "Connection") + if v == "close" { + rws.conn.startGracefulShutdown() + } + } + + endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp + err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ + streamID: rws.stream.id, + httpResCode: rws.status, + h: rws.snapHeader, + endStream: endStream, + contentType: ctype, + contentLength: clen, + date: date, + }) + if err != nil { + rws.dirty = true + return 0, err + } + if endStream { + return 0, nil + } + } + if isHeadResp { + return len(p), nil + } + if len(p) == 0 && !rws.handlerDone { + return 0, nil + } + + if rws.handlerDone { + rws.promoteUndeclaredTrailers() + } + + // only send trailers if they have actually been defined by the + // server handler. + hasNonemptyTrailers := rws.hasNonemptyTrailers() + endStream := rws.handlerDone && !hasNonemptyTrailers + if len(p) > 0 || endStream { + // only send a 0 byte DATA frame if we're ending the stream. + if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil { + rws.dirty = true + return 0, err + } + } + + if rws.handlerDone && hasNonemptyTrailers { + err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ + streamID: rws.stream.id, + h: rws.handlerHeader, + trailers: rws.trailers, + endStream: true, + }) + if err != nil { + rws.dirty = true + } + return len(p), err + } + return len(p), nil +} + +// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys +// that, if present, signals that the map entry is actually for +// the response trailers, and not the response headers. The prefix +// is stripped after the ServeHTTP call finishes and the values are +// sent in the trailers. +// +// This mechanism is intended only for trailers that are not known +// prior to the headers being written. If the set of trailers is fixed +// or known before the header is written, the normal Go trailers mechanism +// is preferred: +// https://golang.org/pkg/net/http/#ResponseWriter +// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +const http2TrailerPrefix = "Trailer:" + +// promoteUndeclaredTrailers permits http.Handlers to set trailers +// after the header has already been flushed. Because the Go +// ResponseWriter interface has no way to set Trailers (only the +// Header), and because we didn't want to expand the ResponseWriter +// interface, and because nobody used trailers, and because RFC 7230 +// says you SHOULD (but not must) predeclare any trailers in the +// header, the official ResponseWriter rules said trailers in Go must +// be predeclared, and then we reuse the same ResponseWriter.Header() +// map to mean both Headers and Trailers. When it's time to write the +// Trailers, we pick out the fields of Headers that were declared as +// trailers. That worked for a while, until we found the first major +// user of Trailers in the wild: gRPC (using them only over http2), +// and gRPC libraries permit setting trailers mid-stream without +// predeclaring them. So: change of plans. We still permit the old +// way, but we also permit this hack: if a Header() key begins with +// "Trailer:", the suffix of that key is a Trailer. Because ':' is an +// invalid token byte anyway, there is no ambiguity. (And it's already +// filtered out) It's mildly hacky, but not terrible. +// +// This method runs after the Handler is done and promotes any Header +// fields to be trailers. +func (rws *http2responseWriterState) promoteUndeclaredTrailers() { + for k, vv := range rws.handlerHeader { + if !strings.HasPrefix(k, http2TrailerPrefix) { + continue + } + trailerKey := strings.TrimPrefix(k, http2TrailerPrefix) + rws.declareTrailer(trailerKey) + rws.handlerHeader[http.CanonicalHeaderKey(trailerKey)] = vv + } + + if len(rws.trailers) > 1 { + sorter := http2sorterPool.Get().(*http2sorter) + sorter.SortStrings(rws.trailers) + http2sorterPool.Put(sorter) + } +} + +func (w *http2responseWriter) Flush() { + rws := w.rws + if rws == nil { + panic("Header called after Handler finished") + } + if rws.bw.Buffered() > 0 { + if err := rws.bw.Flush(); err != nil { + // Ignore the error. The frame writer already knows. + return + } + } else { + // The bufio.Writer won't call chunkWriter.Write + // (writeChunk with zero bytes, so we have to do it + // ourselves to force the HTTP response header and/or + // final DATA frame (with END_STREAM) to be sent. + rws.writeChunk(nil) + } +} + +func (w *http2responseWriter) CloseNotify() <-chan bool { + rws := w.rws + if rws == nil { + panic("CloseNotify called after Handler finished") + } + rws.closeNotifierMu.Lock() + ch := rws.closeNotifierCh + if ch == nil { + ch = make(chan bool, 1) + rws.closeNotifierCh = ch + cw := rws.stream.cw + go func() { + cw.Wait() // wait for close + ch <- true + }() + } + rws.closeNotifierMu.Unlock() + return ch +} + +func (w *http2responseWriter) Header() http.Header { + rws := w.rws + if rws == nil { + panic("Header called after Handler finished") + } + if rws.handlerHeader == nil { + rws.handlerHeader = make(http.Header) + } + return rws.handlerHeader +} + +// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode. +func http2checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at http://httpwg.org/specs/rfc7231.html#status.codes) + // and we might block under 200 (once we have more mature 1xx support). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/2, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + +func (w *http2responseWriter) WriteHeader(code int) { + rws := w.rws + if rws == nil { + panic("WriteHeader called after Handler finished") + } + rws.writeHeader(code) +} + +func (rws *http2responseWriterState) writeHeader(code int) { + if !rws.wroteHeader { + http2checkWriteHeaderCode(code) + rws.wroteHeader = true + rws.status = code + if len(rws.handlerHeader) > 0 { + rws.snapHeader = http2cloneHeader(rws.handlerHeader) + } + } +} + +func http2cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +// The Life Of A Write is like this: +// +// * Handler calls w.Write or w.WriteString -> +// * -> rws.bw (*bufio.Writer) -> +// * (Handler might call Flush) +// * -> chunkWriter{rws} +// * -> responseWriterState.writeChunk(p []byte) +// * -> responseWriterState.writeChunk (most of the magic; see comment there) +func (w *http2responseWriter) Write(p []byte) (n int, err error) { + return w.write(len(p), p, "") +} + +func (w *http2responseWriter) WriteString(s string) (n int, err error) { + return w.write(len(s), nil, s) +} + +// either dataB or dataS is non-zero. +func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) { + rws := w.rws + if rws == nil { + panic("Write called after Handler finished") + } + if !rws.wroteHeader { + w.WriteHeader(200) + } + if !http2bodyAllowedForStatus(rws.status) { + return 0, http.ErrBodyNotAllowed + } + rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) // only one can be set + if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen { + // TODO: send a RST_STREAM + return 0, errors.New("http2: handler wrote more than declared Content-Length") + } + + if dataB != nil { + return rws.bw.Write(dataB) + } else { + return rws.bw.WriteString(dataS) + } +} + +func (w *http2responseWriter) handlerDone() { + rws := w.rws + dirty := rws.dirty + rws.handlerDone = true + w.Flush() + w.rws = nil + if !dirty { + // Only recycle the pool if all prior Write calls to + // the serverConn goroutine completed successfully. If + // they returned earlier due to resets from the peer + // there might still be write goroutines outstanding + // from the serverConn referencing the rws memory. See + // issue 20704. + http2responseWriterStatePool.Put(rws) + } +} + +// Push errors. +var ( + http2ErrRecursivePush = errors.New("http2: recursive push not allowed") + http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") +) + +var _ http.Pusher = (*http2responseWriter)(nil) + +func (w *http2responseWriter) Push(target string, opts *http.PushOptions) error { + st := w.rws.stream + sc := st.sc + sc.serveG.checkNotOn() + + // No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream." + // http://tools.ietf.org/html/rfc7540#section-6.6 + if st.isPushed() { + return http2ErrRecursivePush + } + + if opts == nil { + opts = new(http.PushOptions) + } + + // Default options. + if opts.Method == "" { + opts.Method = "GET" + } + if opts.Header == nil { + opts.Header = http.Header{} + } + wantScheme := "http" + if w.rws.req.TLS != nil { + wantScheme = "https" + } + + // Validate the request. + u, err := url.Parse(target) + if err != nil { + return err + } + if u.Scheme == "" { + if !strings.HasPrefix(target, "/") { + return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target) + } + u.Scheme = wantScheme + u.Host = w.rws.req.Host + } else { + if u.Scheme != wantScheme { + return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme) + } + if u.Host == "" { + return errors.New("URL must have a host") + } + } + for k := range opts.Header { + if strings.HasPrefix(k, ":") { + return fmt.Errorf("promised request headers cannot include pseudo header %q", k) + } + // These headers are meaningful only if the request has a body, + // but PUSH_PROMISE requests cannot have a body. + // http://tools.ietf.org/html/rfc7540#section-8.2 + // Also disallow Host, since the promised URL must be absolute. + if ascii.EqualFold(k, "content-length") || + ascii.EqualFold(k, "content-encoding") || + ascii.EqualFold(k, "trailer") || + ascii.EqualFold(k, "te") || + ascii.EqualFold(k, "expect") || + ascii.EqualFold(k, "host") { + return fmt.Errorf("promised request headers cannot include %q", k) + } + } + if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil { + return err + } + + // The RFC effectively limits promised requests to GET and HEAD: + // "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]" + // http://tools.ietf.org/html/rfc7540#section-8.2 + if opts.Method != "GET" && opts.Method != "HEAD" { + return fmt.Errorf("method %q must be GET or HEAD", opts.Method) + } + + msg := &http2startPushRequest{ + parent: st, + method: opts.Method, + url: u, + header: http2cloneHeader(opts.Header), + done: http2errChanPool.Get().(chan error), + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case sc.serveMsgCh <- msg: + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case err := <-msg.done: + http2errChanPool.Put(msg.done) + return err + } +} + +type http2startPushRequest struct { + parent *http2stream + method string + url *url.URL + header http.Header + done chan error +} + +func (sc *http2serverConn) startPush(msg *http2startPushRequest) { + sc.serveG.check() + + // http://tools.ietf.org/html/rfc7540#section-6.6. + // PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that + // is in either the "open" or "half-closed (remote)" state. + if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote { + // responseWriter.Push checks that the stream is peer-initiated. + msg.done <- http2errStreamClosed + return + } + + // http://tools.ietf.org/html/rfc7540#section-6.6. + if !sc.pushEnabled { + msg.done <- http.ErrNotSupported + return + } + + // PUSH_PROMISE frames must be sent in increasing order by stream ID, so + // we allocate an ID for the promised stream lazily, when the PUSH_PROMISE + // is written. Once the ID is allocated, we start the request handler. + allocatePromisedID := func() (uint32, error) { + sc.serveG.check() + + // Check this again, just in case. Technically, we might have received + // an updated SETTINGS by the time we got around to writing this frame. + if !sc.pushEnabled { + return 0, http.ErrNotSupported + } + // http://tools.ietf.org/html/rfc7540#section-6.5.2. + if sc.curPushedStreams+1 > sc.clientMaxStreams { + return 0, http2ErrPushLimitReached + } + + // http://tools.ietf.org/html/rfc7540#section-5.1.1. + // Streams initiated by the server MUST use even-numbered identifiers. + // A server that is unable to establish a new stream identifier can send a GOAWAY + // frame so that the client is forced to open a new connection for new streams. + if sc.maxPushPromiseID+2 >= 1<<31 { + sc.startGracefulShutdownInternal() + return 0, http2ErrPushLimitReached + } + sc.maxPushPromiseID += 2 + promisedID := sc.maxPushPromiseID + + // http://tools.ietf.org/html/rfc7540#section-8.2. + // Strictly speaking, the new stream should start in "reserved (local)", then + // transition to "half closed (remote)" after sending the initial HEADERS, but + // we start in "half closed (remote)" for simplicity. + // See further comments at the definition of stateHalfClosedRemote. + promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote) + rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{ + method: msg.method, + scheme: msg.url.Scheme, + authority: msg.url.Host, + path: msg.url.RequestURI(), + header: http2cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE + }) + if err != nil { + // Should not happen, since we've already validated msg.url. + panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) + } + + go sc.runHandler(rw, req, sc.handler.ServeHTTP) + return promisedID, nil + } + + sc.writeFrame(http2FrameWriteRequest{ + write: &http2writePushPromise{ + streamID: msg.parent.id, + method: msg.method, + url: msg.url, + h: msg.header, + allocatePromisedID: allocatePromisedID, + }, + stream: msg.parent, + done: msg.done, + }) +} + +// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 +var http2connHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Connection", + "Transfer-Encoding", + "Upgrade", +} + +// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, +// per RFC 7540 Section 8.1.2.2. +// The returned error is reported to users. +func http2checkValidHTTP2RequestHeaders(h http.Header) error { + for _, k := range http2connHeaders { + if _, ok := h[k]; ok { + return fmt.Errorf("request header %q is not valid in HTTP/2", k) + } + } + te := h["Te"] + if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { + return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) + } + return nil +} + +func http2new400Handler(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + http.Error(w, err.Error(), http.StatusBadRequest) + } +} + +// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives +// disabled. See comments on h1ServerShutdownChan above for why +// the code is written this way. +func http2h1ServerKeepAlivesDisabled(hs *http.Server) bool { + var x interface{} = hs + type I interface { + doKeepAlives() bool + } + if hs, ok := x.(I); ok { + return !hs.doKeepAlives() + } + return false +} + +func (sc *http2serverConn) countError(name string, err error) error { + if sc == nil || sc.srv == nil { + return err + } + f := sc.srv.CountError + if f == nil { + return err + } + var typ string + var code http2ErrCode + switch e := err.(type) { + case http2ConnectionError: + typ = "conn" + code = http2ErrCode(e) + case http2StreamError: + typ = "stream" + code = http2ErrCode(e.Code) + default: + return err + } + codeStr := http2errCodeName[code] + if codeStr == "" { + codeStr = strconv.Itoa(int(code)) + } + f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name)) + return err +} + +// writeFramer is implemented by any type that is used to write frames. +type http2writeFramer interface { + writeFrame(http2writeContext) error + + // staysWithinBuffer reports whether this writer promises that + // it will only write less than or equal to size bytes, and it + // won't Flush the write context. + staysWithinBuffer(size int) bool +} + +// writeContext is the interface needed by the various frame writer +// types below. All the writeFrame methods below are scheduled via the +// frame writing scheduler (see writeScheduler in writesched.go). +// +// This interface is implemented by *http2serverConn. +// +// TODO: decide whether to a) use this in the client code (which didn't +// end up using this yet, because it has a simpler design, not +// currently implementing priorities), or b) delete this and +// make the server code a bit more concrete. +type http2writeContext interface { + Framer() *http2Framer + Flush() error + CloseConn() error + // HeaderEncoder returns an HPACK encoder that writes to the + // returned buffer. + HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) +} + +// writeEndsStream reports whether w writes a frame that will transition +// the stream to a half-closed local state. This returns false for RST_STREAM, +// which closes the entire stream (not just the local half). +func http2writeEndsStream(w http2writeFramer) bool { + switch v := w.(type) { + case *http2writeData: + return v.endStream + case *http2writeResHeaders: + return v.endStream + case nil: + // This can only happen if the caller reuses w after it's + // been intentionally nil'ed out to prevent use. Keep this + // here to catch future refactoring breaking it. + panic("writeEndsStream called on nil writeFramer") + } + return false +} + +type http2flushFrameWriter struct{} + +func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error { + return ctx.Flush() +} + +func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false } + +type http2writeSettings []http2Setting + +func (s http2writeSettings) staysWithinBuffer(max int) bool { + const settingSize = 6 // uint16 + uint32 + return http2frameHeaderLen+settingSize*len(s) <= max + +} + +func (s http2writeSettings) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteSettings([]http2Setting(s)...) +} + +type http2writeGoAway struct { + maxStreamID uint32 + code http2ErrCode +} + +func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { + err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil) + ctx.Flush() // ignore error: we're hanging up on them anyway + return err +} + +func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } // flushes + +type http2writeData struct { + streamID uint32 + p []byte + endStream bool +} + +func (w *http2writeData) String() string { + return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream) +} + +func (w *http2writeData) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) +} + +func (w *http2writeData) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.p) <= max +} + +// handlerPanicRST is the message sent from handler goroutines when +// the handler panics. +type http2handlerPanicRST struct { + StreamID uint32 +} + +func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal) +} + +func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + +func (se http2StreamError) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) +} + +func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + +type http2writePingAck struct{ pf *http2PingFrame } + +func (w http2writePingAck) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WritePing(true, w.pf.Data) +} + +func (w http2writePingAck) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.pf.Data) <= max +} + +type http2writeSettingsAck struct{} + +func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteSettingsAck() +} + +func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max } + +// splitHeaderBlock splits headerBlock into fragments so that each fragment fits +// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true +// for the first/last fragment, respectively. +func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error { + // For now we're lazy and just pick the minimum MAX_FRAME_SIZE + // that all peers must support (16KB). Later we could care + // more and send larger frames if the peer advertised it, but + // there's little point. Most headers are small anyway (so we + // generally won't have CONTINUATION frames), and extra frames + // only waste 9 bytes anyway. + const http2maxFrameSize = 16384 + + first := true + for len(headerBlock) > 0 { + frag := headerBlock + if len(frag) > http2maxFrameSize { + frag = frag[:http2maxFrameSize] + } + headerBlock = headerBlock[len(frag):] + if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil { + return err + } + first = false + } + return nil +} + +// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames +// for HTTP response headers or trailers from a server handler. +type http2writeResHeaders struct { + streamID uint32 + httpResCode int // 0 means no ":status" line + h http.Header // may be nil + trailers []string // if non-nil, which keys of h to write. nil means all. + endStream bool + + date string + contentType string + contentLength string +} + +func http2encKV(enc *hpack.Encoder, k, v string) { + if http2VerboseLogs { + log.Printf("http2: server encoding header %q = %q", k, v) + } + enc.WriteField(hpack.HeaderField{Name: k, Value: v}) +} + +func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { + // TODO: this is a common one. It'd be nice to return true + // here and get into the fast path if we could be clever and + // calculate the size fast enough, or at least a conservative + // upper bound that usually fires. (Maybe if w.h and + // w.trailers are nil, so we don't need to enumerate it.) + // Otherwise I'm afraid that just calculating the length to + // answer this question would be slower than the ~2µs benefit. + return false +} + +func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + if w.httpResCode != 0 { + http2encKV(enc, ":status", http2httpCodeString(w.httpResCode)) + } + + http2encodeHeaders(enc, w.h, w.trailers) + + if w.contentType != "" { + http2encKV(enc, "content-type", w.contentType) + } + if w.contentLength != "" { + http2encKV(enc, "content-length", w.contentLength) + } + if w.date != "" { + http2encKV(enc, "date", w.date) + } + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 && w.trailers == nil { + panic("unexpected empty hpack") + } + + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} + +func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: frag, + EndStream: w.endStream, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + } +} + +// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. +type http2writePushPromise struct { + streamID uint32 // pusher stream + method string // for :method + url *url.URL // for :scheme, :authority, :path + h http.Header + + // Creates an ID for a pushed stream. This runs on serveG just before + // the frame is written. The returned ID is copied to promisedID. + allocatePromisedID func() (uint32, error) + promisedID uint32 +} + +func (w *http2writePushPromise) staysWithinBuffer(max int) bool { + // TODO: see writeResHeaders.staysWithinBuffer + return false +} + +func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + http2encKV(enc, ":method", w.method) + http2encKV(enc, ":scheme", w.url.Scheme) + http2encKV(enc, ":authority", w.url.Host) + http2encKV(enc, ":path", w.url.RequestURI()) + http2encodeHeaders(enc, w.h, nil) + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 { + panic("unexpected empty hpack") + } + + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} + +func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WritePushPromise(http2PushPromiseParam{ + StreamID: w.streamID, + PromiseID: w.promisedID, + BlockFragment: frag, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + } +} + +type http2write100ContinueHeadersFrame struct { + streamID uint32 +} + +func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + http2encKV(enc, ":status", "100") + return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: buf.Bytes(), + EndStream: false, + EndHeaders: true, + }) +} + +func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { + // Sloppy but conservative: + return 9+2*(len(":status")+len("100")) <= max +} + +type http2writeWindowUpdate struct { + streamID uint32 // or 0 for conn-level + n uint32 +} + +func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + +func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) +} + +// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) +// is encoded only if k is in keys. +func http2encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { + if keys == nil { + sorter := http2sorterPool.Get().(*http2sorter) + // Using defer here, since the returned keys from the + // sorter.Keys method is only valid until the sorter + // is returned: + defer http2sorterPool.Put(sorter) + keys = sorter.Keys(h) + } + for _, k := range keys { + vv := h[k] + k, ascii := http2lowerHeader(k) + if !ascii { + // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header + // field names have to be ASCII characters (just as in HTTP/1.x). + continue + } + if !http2validWireHeaderFieldName(k) { + // Skip it as backup paranoia. Per + // golang.org/issue/14048, these should + // already be rejected at a higher level. + continue + } + isTE := k == "transfer-encoding" + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // TODO: return an error? golang.org/issue/14048 + // For now just omit it. + continue + } + // TODO: more of "8.1.2.2 Connection-Specific Header Fields" + if isTE && v != "trailers" { + continue + } + http2encKV(enc, k, v) + } + } +} + +// WriteScheduler is the interface implemented by HTTP/2 write schedulers. +// Methods are never called concurrently. +type http2WriteScheduler interface { + // OpenStream opens a new stream in the write scheduler. + // It is illegal to call this with streamID=0 or with a streamID that is + // already open -- the call may panic. + OpenStream(streamID uint32, options http2OpenStreamOptions) + + // CloseStream closes a stream in the write scheduler. Any frames queued on + // this stream should be discarded. It is illegal to call this on a stream + // that is not open -- the call may panic. + CloseStream(streamID uint32) + + // AdjustStream adjusts the priority of the given stream. This may be called + // on a stream that has not yet been opened or has been closed. Note that + // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See: + // https://tools.ietf.org/html/rfc7540#section-5.1 + AdjustStream(streamID uint32, priority http2PriorityParam) + + // Push queues a frame in the scheduler. In most cases, this will not be + // called with wr.StreamID()!=0 unless that stream is currently open. The one + // exception is RST_STREAM frames, which may be sent on idle or closed streams. + Push(wr http2FrameWriteRequest) + + // Pop dequeues the next frame to write. Returns false if no frames can + // be written. Frames with a given wr.StreamID() are Pop'd in the same + // order they are Push'd, except RST_STREAM frames. No frames should be + // discarded except by CloseStream. + Pop() (wr http2FrameWriteRequest, ok bool) +} + +// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream. +type http2OpenStreamOptions struct { + // PusherID is zero if the stream was initiated by the client. Otherwise, + // PusherID names the stream that pushed the newly opened stream. + PusherID uint32 +} + +// FrameWriteRequest is a request to write a frame. +type http2FrameWriteRequest struct { + // write is the interface value that does the writing, once the + // WriteScheduler has selected this frame to write. The write + // functions are all defined in write.go. + write http2writeFramer + + // stream is the stream on which this frame will be written. + // nil for non-stream frames like PING and SETTINGS. + // nil for RST_STREAM streams, which use the StreamError.StreamID field instead. + stream *http2stream + + // done, if non-nil, must be a buffered channel with space for + // 1 message and is sent the return value from write (or an + // earlier error) when the frame has been written. + done chan error +} + +// StreamID returns the id of the stream this frame will be written to. +// 0 is used for non-stream frames such as PING and SETTINGS. +func (wr http2FrameWriteRequest) StreamID() uint32 { + if wr.stream == nil { + if se, ok := wr.write.(http2StreamError); ok { + // (*http2serverConn).resetStream doesn't set + // stream because it doesn't necessarily have + // one. So special case this type of write + // message. + return se.StreamID + } + return 0 + } + return wr.stream.id +} + +// isControl reports whether wr is a control frame for MaxQueuedControlFrames +// purposes. That includes non-stream frames and RST_STREAM frames. +func (wr http2FrameWriteRequest) isControl() bool { + return wr.stream == nil +} + +// DataSize returns the number of flow control bytes that must be consumed +// to write this entire frame. This is 0 for non-DATA frames. +func (wr http2FrameWriteRequest) DataSize() int { + if wd, ok := wr.write.(*http2writeData); ok { + return len(wd.p) + } + return 0 +} + +// Consume consumes min(n, available) bytes from this frame, where available +// is the number of flow control bytes available on the stream. Consume returns +// 0, 1, or 2 frames, where the integer return value gives the number of frames +// returned. +// +// If flow control prevents consuming any bytes, this returns (_, _, 0). If +// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this +// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and +// 'rest' contains the remaining bytes. The consumed bytes are deducted from the +// underlying stream's flow control budget. +func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) { + var empty http2FrameWriteRequest + + // Non-DATA frames are always consumed whole. + wd, ok := wr.write.(*http2writeData) + if !ok || len(wd.p) == 0 { + return wr, empty, 1 + } + + // Might need to split after applying limits. + allowed := wr.stream.flow.available() + if n < allowed { + allowed = n + } + if wr.stream.sc.maxFrameSize < allowed { + allowed = wr.stream.sc.maxFrameSize + } + if allowed <= 0 { + return empty, empty, 0 + } + if len(wd.p) > int(allowed) { + wr.stream.flow.take(allowed) + consumed := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[:allowed], + // Even if the original had endStream set, there + // are bytes remaining because len(wd.p) > allowed, + // so we know endStream is false. + endStream: false, + }, + // Our caller is blocking on the final DATA frame, not + // this intermediate frame, so no need to wait. + done: nil, + } + rest := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[allowed:], + endStream: wd.endStream, + }, + done: wr.done, + } + return consumed, rest, 2 + } + + // The frame is consumed whole. + // NB: This cast cannot overflow because allowed is <= math.MaxInt32. + wr.stream.flow.take(int32(len(wd.p))) + return wr, empty, 1 +} + +// String is for debugging only. +func (wr http2FrameWriteRequest) String() string { + var des string + if s, ok := wr.write.(fmt.Stringer); ok { + des = s.String() + } else { + des = fmt.Sprintf("%T", wr.write) + } + return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des) +} + +// replyToWriter sends err to wr.done and panics if the send must block +// This does nothing if wr.done is nil. +func (wr *http2FrameWriteRequest) replyToWriter(err error) { + if wr.done == nil { + return + } + select { + case wr.done <- err: + default: + panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write)) + } + wr.write = nil // prevent use (assume it's tainted after wr.done send) +} + +// writeQueue is used by implementations of WriteScheduler. +type http2writeQueue struct { + s []http2FrameWriteRequest +} + +func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } + +func (q *http2writeQueue) push(wr http2FrameWriteRequest) { + q.s = append(q.s, wr) +} + +func (q *http2writeQueue) shift() http2FrameWriteRequest { + if len(q.s) == 0 { + panic("invalid use of queue") + } + wr := q.s[0] + // TODO: less copy-happy queue. + copy(q.s, q.s[1:]) + q.s[len(q.s)-1] = http2FrameWriteRequest{} + q.s = q.s[:len(q.s)-1] + return wr +} + +// consume consumes up to n bytes from q.s[0]. If the frame is +// entirely consumed, it is removed from the queue. If the frame +// is partially consumed, the frame is kept with the consumed +// bytes removed. Returns true iff any bytes were consumed. +func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) { + if len(q.s) == 0 { + return http2FrameWriteRequest{}, false + } + consumed, rest, numresult := q.s[0].Consume(n) + switch numresult { + case 0: + return http2FrameWriteRequest{}, false + case 1: + q.shift() + case 2: + q.s[0] = rest + } + return consumed, true +} + +type http2writeQueuePool []*http2writeQueue + +// put inserts an unused writeQueue into the pool. + +// put inserts an unused writeQueue into the pool. +func (p *http2writeQueuePool) put(q *http2writeQueue) { + for i := range q.s { + q.s[i] = http2FrameWriteRequest{} + } + q.s = q.s[:0] + *p = append(*p, q) +} + +// get returns an empty writeQueue. +func (p *http2writeQueuePool) get() *http2writeQueue { + ln := len(*p) + if ln == 0 { + return new(http2writeQueue) + } + x := ln - 1 + q := (*p)[x] + (*p)[x] = nil + *p = (*p)[:x] + return q +} + +// RFC 7540, Section 5.3.5: the default weight is 16. +const http2priorityDefaultWeight = 15 // 16 = 15 + 1 + +// PriorityWriteSchedulerConfig configures a priorityWriteScheduler. +type http2PriorityWriteSchedulerConfig struct { + // MaxClosedNodesInTree controls the maximum number of closed streams to + // retain in the priority tree. http2Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // "It is possible for a stream to become closed while prioritization + // information ... is in transit. ... This potentially creates suboptimal + // prioritization, since the stream could be given a priority that is + // different from what is intended. To avoid these problems, an endpoint + // SHOULD retain stream prioritization state for a period after streams + // become closed. The longer state is retained, the lower the chance that + // streams are assigned incorrect or default priority values." + MaxClosedNodesInTree int + + // MaxIdleNodesInTree controls the maximum number of idle streams to + // retain in the priority tree. http2Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // Similarly, streams that are in the "idle" state can be assigned + // priority or become a parent of other streams. This allows for the + // creation of a grouping node in the dependency tree, which enables + // more flexible expressions of priority. Idle streams begin with a + // default priority (Section 5.3.5). + MaxIdleNodesInTree int + + // ThrottleOutOfOrderWrites enables write throttling to help ensure that + // data is delivered in priority order. This works around a race where + // stream B depends on stream A and both streams are about to call Write + // to queue DATA frames. If B wins the race, a naive scheduler would eagerly + // write as much data from B as possible, but this is suboptimal because A + // is a higher-priority stream. With throttling enabled, we write a small + // amount of data from B to minimize the amount of bandwidth that B can + // steal from A. + ThrottleOutOfOrderWrites bool +} + +// NewPriorityWriteScheduler constructs a WriteScheduler that schedules +// frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3. +// If cfg is nil, default options are used. +func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler { + if cfg == nil { + // For justification of these defaults, see: + // https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY + cfg = &http2PriorityWriteSchedulerConfig{ + MaxClosedNodesInTree: 10, + MaxIdleNodesInTree: 10, + ThrottleOutOfOrderWrites: false, + } + } + + ws := &http2priorityWriteScheduler{ + nodes: make(map[uint32]*http2priorityNode), + maxClosedNodesInTree: cfg.MaxClosedNodesInTree, + maxIdleNodesInTree: cfg.MaxIdleNodesInTree, + enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, + } + ws.nodes[0] = &ws.root + if cfg.ThrottleOutOfOrderWrites { + ws.writeThrottleLimit = 1024 + } else { + ws.writeThrottleLimit = math.MaxInt32 + } + return ws +} + +type http2priorityNodeState int + +const ( + http2priorityNodeOpen http2priorityNodeState = iota + http2priorityNodeClosed + http2priorityNodeIdle +) + +// priorityNode is a node in an HTTP/2 priority tree. +// Each node is associated with a single stream ID. +// See RFC 7540, Section 5.3. +type http2priorityNode struct { + q http2writeQueue // queue of pending frames to write + id uint32 // id of the stream, or 0 for the root of the tree + weight uint8 // the actual weight is weight+1, so the value is in [1,256] + state http2priorityNodeState // open | closed | idle + bytes int64 // number of bytes written by this node, or 0 if closed + subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree + + // These links form the priority tree. + parent *http2priorityNode + kids *http2priorityNode // start of the kids list + prev, next *http2priorityNode // doubly-linked list of siblings +} + +func (n *http2priorityNode) setParent(parent *http2priorityNode) { + if n == parent { + panic("setParent to self") + } + if n.parent == parent { + return + } + // Unlink from current parent. + if parent := n.parent; parent != nil { + if n.prev == nil { + parent.kids = n.next + } else { + n.prev.next = n.next + } + if n.next != nil { + n.next.prev = n.prev + } + } + // Link to new parent. + // If parent=nil, remove n from the tree. + // Always insert at the head of parent.kids (this is assumed by walkReadyInOrder). + n.parent = parent + if parent == nil { + n.next = nil + n.prev = nil + } else { + n.next = parent.kids + n.prev = nil + if n.next != nil { + n.next.prev = n + } + parent.kids = n + } +} + +func (n *http2priorityNode) addBytes(b int64) { + n.bytes += b + for ; n != nil; n = n.parent { + n.subtreeBytes += b + } +} + +// walkReadyInOrder iterates over the tree in priority order, calling f for each node +// with a non-empty write queue. When f returns true, this function returns true and the +// walk halts. tmp is used as scratch space for sorting. +// +// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true +// if any ancestor p of n is still open (ignoring the root node). +func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool { + if !n.q.empty() && f(n, openParent) { + return true + } + if n.kids == nil { + return false + } + + // Don't consider the root "open" when updating openParent since + // we can't send data frames on the root stream (only control frames). + if n.id != 0 { + openParent = openParent || (n.state == http2priorityNodeOpen) + } + + // Common case: only one kid or all kids have the same weight. + // Some clients don't use weights; other clients (like web browsers) + // use mostly-linear priority trees. + w := n.kids.weight + needSort := false + for k := n.kids.next; k != nil; k = k.next { + if k.weight != w { + needSort = true + break + } + } + if !needSort { + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } + } + return false + } + + // Uncommon case: sort the child nodes. We remove the kids from the parent, + // then re-insert after sorting so we can reuse tmp for future sort calls. + *tmp = (*tmp)[:0] + for n.kids != nil { + *tmp = append(*tmp, n.kids) + n.kids.setParent(nil) + } + sort.Sort(http2sortPriorityNodeSiblings(*tmp)) + for i := len(*tmp) - 1; i >= 0; i-- { + (*tmp)[i].setParent(n) // setParent inserts at the head of n.kids + } + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } + } + return false +} + +type http2sortPriorityNodeSiblings []*http2priorityNode + +func (z http2sortPriorityNodeSiblings) Len() int { return len(z) } + +func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } + +func (z http2sortPriorityNodeSiblings) Less(i, k int) bool { + // Prefer the subtree that has sent fewer bytes relative to its weight. + // See sections 5.3.2 and 5.3.4. + wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) + wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) + if bi == 0 && bk == 0 { + return wi >= wk + } + if bk == 0 { + return false + } + return bi/bk <= wi/wk +} + +type http2priorityWriteScheduler struct { + // root is the root of the priority tree, where root.id = 0. + // The root queues control frames that are not associated with any stream. + root http2priorityNode + + // nodes maps stream ids to priority tree nodes. + nodes map[uint32]*http2priorityNode + + // maxID is the maximum stream id in nodes. + maxID uint32 + + // lists of nodes that have been closed or are idle, but are kept in + // the tree for improved prioritization. When the lengths exceed either + // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. + closedNodes, idleNodes []*http2priorityNode + + // From the config. + maxClosedNodesInTree int + maxIdleNodesInTree int + writeThrottleLimit int32 + enableWriteThrottle bool + + // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. + tmp []*http2priorityNode + + // pool of empty queues for reuse. + queuePool http2writeQueuePool +} + +func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + // The stream may be currently idle but cannot be opened or closed. + if curr := ws.nodes[streamID]; curr != nil { + if curr.state != http2priorityNodeIdle { + panic(fmt.Sprintf("stream %d already opened", streamID)) + } + curr.state = http2priorityNodeOpen + return + } + + // RFC 7540, Section 5.3.5: + // "All streams are initially assigned a non-exclusive dependency on stream 0x0. + // Pushed streams initially depend on their associated stream. In both cases, + // streams are assigned a default weight of 16." + parent := ws.nodes[options.PusherID] + if parent == nil { + parent = &ws.root + } + n := &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeOpen, + } + n.setParent(parent) + ws.nodes[streamID] = n + if streamID > ws.maxID { + ws.maxID = streamID + } +} + +func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) { + if streamID == 0 { + panic("violation of WriteScheduler interface: cannot close stream 0") + } + if ws.nodes[streamID] == nil { + panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) + } + if ws.nodes[streamID].state != http2priorityNodeOpen { + panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) + } + + n := ws.nodes[streamID] + n.state = http2priorityNodeClosed + n.addBytes(-n.bytes) + + q := n.q + ws.queuePool.put(&q) + n.q.s = nil + if ws.maxClosedNodesInTree > 0 { + ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) + } else { + ws.removeNode(n) + } +} + +func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { + if streamID == 0 { + panic("adjustPriority on root") + } + + // If streamID does not exist, there are two cases: + // - A closed stream that has been removed (this will have ID <= maxID) + // - An idle stream that is being used for "grouping" (this will have ID > maxID) + n := ws.nodes[streamID] + if n == nil { + if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 { + return + } + ws.maxID = streamID + n = &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeIdle, + } + n.setParent(&ws.root) + ws.nodes[streamID] = n + ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n) + } + + // Section 5.3.1: A dependency on a stream that is not currently in the tree + // results in that stream being given a default priority (Section 5.3.5). + parent := ws.nodes[priority.StreamDep] + if parent == nil { + n.setParent(&ws.root) + n.weight = http2priorityDefaultWeight + return + } + + // Ignore if the client tries to make a node its own parent. + if n == parent { + return + } + + // Section 5.3.3: + // "If a stream is made dependent on one of its own dependencies, the + // formerly dependent stream is first moved to be dependent on the + // reprioritized stream's previous parent. The moved dependency retains + // its weight." + // + // That is: if parent depends on n, move parent to depend on n.parent. + for x := parent.parent; x != nil; x = x.parent { + if x == n { + parent.setParent(n.parent) + break + } + } + + // Section 5.3.3: The exclusive flag causes the stream to become the sole + // dependency of its parent stream, causing other dependencies to become + // dependent on the exclusive stream. + if priority.Exclusive { + k := parent.kids + for k != nil { + next := k.next + if k != n { + k.setParent(n) + } + k = next + } + } + + n.setParent(parent) + n.weight = priority.Weight +} + +func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { + var n *http2priorityNode + if id := wr.StreamID(); id == 0 { + n = &ws.root + } else { + n = ws.nodes[id] + if n == nil { + // id is an idle or closed stream. wr should not be a HEADERS or + // DATA frame. However, wr can be a RST_STREAM. In this case, we + // push wr onto the root, rather than creating a new priorityNode, + // since RST_STREAM is tiny and the stream's priority is unknown + // anyway. See issue #17919. + if wr.DataSize() > 0 { + panic("add DATA on non-open stream") + } + n = &ws.root + } + } + n.q.push(wr) +} + +func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) { + ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool { + limit := int32(math.MaxInt32) + if openParent { + limit = ws.writeThrottleLimit + } + wr, ok = n.q.consume(limit) + if !ok { + return false + } + n.addBytes(int64(wr.DataSize())) + // If B depends on A and B continuously has data available but A + // does not, gradually increase the throttling limit to allow B to + // steal more and more bandwidth from A. + if openParent { + ws.writeThrottleLimit += 1024 + if ws.writeThrottleLimit < 0 { + ws.writeThrottleLimit = math.MaxInt32 + } + } else if ws.enableWriteThrottle { + ws.writeThrottleLimit = 1024 + } + return true + }) + return wr, ok +} + +func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorityNode, maxSize int, n *http2priorityNode) { + if maxSize == 0 { + return + } + if len(*list) == maxSize { + // Remove the oldest node, then shift left. + ws.removeNode((*list)[0]) + x := (*list)[1:] + copy(*list, x) + *list = (*list)[:len(x)] + } + *list = append(*list, n) +} + +func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) { + for k := n.kids; k != nil; k = k.next { + k.setParent(n.parent) + } + n.setParent(nil) + delete(ws.nodes, n.id) +} + +// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2 +// priorities. Control frames like SETTINGS and PING are written before DATA +// frames, but if no control frames are queued and multiple streams have queued +// HEADERS or DATA frames, Pop selects a ready stream arbitrarily. +func http2NewRandomWriteScheduler() http2WriteScheduler { + return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)} +} + +type http2randomWriteScheduler struct { + // zero are frames not associated with a specific stream. + zero http2writeQueue + + // sq contains the stream-specific queues, keyed by stream ID. + // When a stream is idle, closed, or emptied, it's deleted + // from the map. + sq map[uint32]*http2writeQueue + + // pool of empty queues for reuse. + queuePool http2writeQueuePool +} + +func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + // no-op: idle streams are not tracked +} + +func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { + q, ok := ws.sq[streamID] + if !ok { + return + } + delete(ws.sq, streamID) + ws.queuePool.put(q) +} + +func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { + // no-op: priorities are ignored +} + +func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { + if wr.isControl() { + ws.zero.push(wr) + return + } + id := wr.StreamID() + q, ok := ws.sq[id] + if !ok { + q = ws.queuePool.get() + ws.sq[id] = q + } + q.push(wr) +} + +func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { + // Control and RST_STREAM frames first. + if !ws.zero.empty() { + return ws.zero.shift(), true + } + // Iterate over all non-idle streams until finding one that can be consumed. + for streamID, q := range ws.sq { + if wr, ok := q.consume(math.MaxInt32); ok { + if q.empty() { + delete(ws.sq, streamID) + ws.queuePool.put(q) + } + return wr, true + } + } + return http2FrameWriteRequest{}, false +} + +var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered") + +func stderrv() io.Writer { + if *stderrVerbose { + return os.Stderr + } + + return ioutil.Discard +} + +type safeBuffer struct { + b bytes.Buffer + m sync.Mutex +} + +func (sb *safeBuffer) Write(d []byte) (int, error) { + sb.m.Lock() + defer sb.m.Unlock() + return sb.b.Write(d) +} + +func (sb *safeBuffer) Bytes() []byte { + sb.m.Lock() + defer sb.m.Unlock() + return sb.b.Bytes() +} + +func (sb *safeBuffer) Len() int { + sb.m.Lock() + defer sb.m.Unlock() + return sb.b.Len() +} + +type serverTester struct { + cc net.Conn // client conn + t testing.TB + ts *httptest.Server + fr *http2Framer + serverLogBuf safeBuffer // logger for httptest.Server + logFilter []string // substrings to filter out + scMu sync.Mutex // guards sc + sc *http2serverConn + hpackDec *hpack.Decoder + decodedHeaders [][2]string + + // If http2debug!=2, then we capture Frame debug logs that will be written + // to t.Log after a test fails. The read and write logs use separate locks + // and buffers so we don't accidentally introduce synchronization between + // the read and write goroutines, which may hide data races. + frameReadLogMu sync.Mutex + frameReadLogBuf bytes.Buffer + frameWriteLogMu sync.Mutex + frameWriteLogBuf bytes.Buffer + + // writing headers: + headerBuf bytes.Buffer + hpackEnc *hpack.Encoder +} + +func (st *serverTester) onHeaderField(f hpack.HeaderField) { + if f.Name == "date" { + return + } + st.decodedHeaders = append(st.decodedHeaders, [2]string{f.Name, f.Value}) +} + +func (st *serverTester) decodeHeader(headerBlock []byte) (pairs [][2]string) { + st.decodedHeaders = nil + if _, err := st.hpackDec.Write(headerBlock); err != nil { + st.t.Fatalf("hpack decoding error: %v", err) + } + if err := st.hpackDec.Close(); err != nil { + st.t.Fatalf("hpack decoding error: %v", err) + } + return st.decodedHeaders +} + +func init() { + http2testHookOnPanicMu = new(sync.Mutex) + http2goAwayTimeout = 25 * time.Millisecond +} + +func resetHooks() { + http2testHookOnPanicMu.Lock() + http2testHookOnPanic = nil + http2testHookOnPanicMu.Unlock() +} + +// ConfigureServer adds HTTP/2 support to a net/http Server. +// +// The configuration conf may be nil. +// +// ConfigureServer must be called before s begins serving. +func http2ConfigureServer(s *http.Server, conf *http2Server) error { + if s == nil { + panic("nil *http.Server") + } + if conf == nil { + conf = new(http2Server) + } + conf.state = &http2serverInternalState{activeConns: make(map[*http2serverConn]struct{})} + if h1, h2 := s, conf; h2.IdleTimeout == 0 { + if h1.IdleTimeout != 0 { + h2.IdleTimeout = h1.IdleTimeout + } else { + h2.IdleTimeout = h1.ReadTimeout + } + } + s.RegisterOnShutdown(conf.state.startGracefulShutdown) + + if s.TLSConfig == nil { + s.TLSConfig = new(tls.Config) + } else if s.TLSConfig.CipherSuites != nil && s.TLSConfig.MinVersion < tls.VersionTLS13 { + // If they already provided a TLS 1.0–1.2 CipherSuite list, return an + // error if it is missing ECDHE_RSA_WITH_AES_128_GCM_SHA256 or + // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256. + haveRequired := false + for _, cs := range s.TLSConfig.CipherSuites { + switch cs { + case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + // Alternative MTI cipher to not discourage ECDSA-only servers. + // See http://golang.org/cl/30721 for further information. + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: + haveRequired = true + } + } + if !haveRequired { + return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher (need at least one of TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 or TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)") + } + } + + // Note: not setting MinVersion to tls.VersionTLS12, + // as we don't want to interfere with HTTP/1.1 traffic + // on the user's server. We enforce TLS 1.2 later once + // we accept a connection. Ideally this should be done + // during next-proto selection, but using TLS <1.2 with + // HTTP/2 is still the client's bug. + + s.TLSConfig.PreferServerCipherSuites = true + + if !http2strSliceContains(s.TLSConfig.NextProtos, http2NextProtoTLS) { + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS) + } + if !http2strSliceContains(s.TLSConfig.NextProtos, "http/1.1") { + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "http/1.1") + } + + if s.TLSNextProto == nil { + s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} + } + protoHandler := func(hs *http.Server, c *tls.Conn, h http.Handler) { + if http2testHookOnConn != nil { + http2testHookOnConn() + } + // The TLSNextProto interface predates contexts, so + // the net/http package passes down its per-connection + // base context via an exported but unadvertised + // method on the Handler. This is for internal + // net/http<=>http2 use only. + var ctx context.Context + type baseContexter interface { + BaseContext() context.Context + } + if bc, ok := h.(baseContexter); ok { + ctx = bc.BaseContext() + } + conf.ServeConn(c, &http2ServeConnOpts{ + Context: ctx, + Handler: h, + BaseConfig: hs, + }) + } + s.TLSNextProto[http2NextProtoTLS] = protoHandler + return nil +} + +type twriter struct { + t testing.TB + st *serverTester // optional +} + +func (w twriter) Write(p []byte) (n int, err error) { + if w.st != nil { + ps := string(p) + for _, phrase := range w.st.logFilter { + if strings.Contains(ps, phrase) { + return len(p), nil // no logging + } + } + } + w.t.Logf("%s", p) + return len(p), nil +} + +type serverTesterOpt string + +var optOnlyServer = serverTesterOpt("only_server") +var optQuiet = serverTesterOpt("quiet_logging") +var optFramerReuseFrames = serverTesterOpt("frame_reuse_frames") + +func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester { + resetHooks() + + ts := httptest.NewUnstartedServer(handler) + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{http2NextProtoTLS}, + } + + var onlyServer, quiet, framerReuseFrames bool + h2server := new(http2Server) + for _, opt := range opts { + switch v := opt.(type) { + case func(*tls.Config): + v(tlsConfig) + case func(*httptest.Server): + v(ts) + case func(*http2Server): + v(h2server) + case serverTesterOpt: + switch v { + case optOnlyServer: + onlyServer = true + case optQuiet: + quiet = true + case optFramerReuseFrames: + framerReuseFrames = true + } + case func(net.Conn, http.ConnState): + ts.Config.ConnState = v + default: + t.Fatalf("unknown newServerTester option type %T", v) + } + } + + http2ConfigureServer(ts.Config, h2server) + + st := &serverTester{ + t: t, + ts: ts, + } + st.hpackEnc = hpack.NewEncoder(&st.headerBuf) + st.hpackDec = hpack.NewDecoder(http2initialHeaderTableSize, st.onHeaderField) + + ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config + if quiet { + ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + } else { + ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags) + } + ts.StartTLS() + + if http2VerboseLogs { + t.Logf("Running test server at: %s", ts.URL) + } + http2testHookGetServerConn = func(v *http2serverConn) { + st.scMu.Lock() + defer st.scMu.Unlock() + st.sc = v + } + log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st})) + if !onlyServer { + cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) + if err != nil { + t.Fatal(err) + } + st.cc = cc + st.fr = http2NewFramer(cc, cc) + if framerReuseFrames { + st.fr.SetReuseFrames() + } + if !http2logFrameReads && !http2logFrameWrites { + st.fr.debugReadLoggerf = func(m string, v ...interface{}) { + m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n" + st.frameReadLogMu.Lock() + fmt.Fprintf(&st.frameReadLogBuf, m, v...) + st.frameReadLogMu.Unlock() + } + st.fr.debugWriteLoggerf = func(m string, v ...interface{}) { + m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n" + st.frameWriteLogMu.Lock() + fmt.Fprintf(&st.frameWriteLogBuf, m, v...) + st.frameWriteLogMu.Unlock() + } + st.fr.logReads = true + st.fr.logWrites = true + } + } + return st +} + +func (st *serverTester) closeConn() { + st.scMu.Lock() + defer st.scMu.Unlock() + st.sc.conn.Close() +} + +func (st *serverTester) addLogFilter(phrase string) { + st.logFilter = append(st.logFilter, phrase) +} + +func (st *serverTester) stream(id uint32) *http2stream { + ch := make(chan *http2stream, 1) + st.sc.serveMsgCh <- func(int) { + ch <- st.sc.streams[id] + } + return <-ch +} + +func (st *serverTester) http2streamState(id uint32) http2streamState { + ch := make(chan http2streamState, 1) + st.sc.serveMsgCh <- func(int) { + state, _ := st.sc.state(id) + ch <- state + } + return <-ch +} + +// loopNum reports how many times this conn's select loop has gone around. +func (st *serverTester) loopNum() int { + lastc := make(chan int, 1) + st.sc.serveMsgCh <- func(loopNum int) { + lastc <- loopNum + } + return <-lastc +} + +// awaitIdle heuristically awaits for the server conn's select loop to be idle. +// The heuristic is that the server connection's serve loop must schedule +// 50 times in a row without any channel sends or receives occurring. +func (st *serverTester) awaitIdle() { + remain := 50 + last := st.loopNum() + for remain > 0 { + n := st.loopNum() + if n == last+1 { + remain-- + } else { + remain = 50 + } + last = n + } +} + +func (st *serverTester) Close() { + if st.t.Failed() { + st.frameReadLogMu.Lock() + if st.frameReadLogBuf.Len() > 0 { + st.t.Logf("Framer read log:\n%s", st.frameReadLogBuf.String()) + } + st.frameReadLogMu.Unlock() + + st.frameWriteLogMu.Lock() + if st.frameWriteLogBuf.Len() > 0 { + st.t.Logf("Framer write log:\n%s", st.frameWriteLogBuf.String()) + } + st.frameWriteLogMu.Unlock() + + // If we failed already (and are likely in a Fatal, + // unwindowing), force close the connection, so the + // httptest.Server doesn't wait forever for the conn + // to close. + if st.cc != nil { + st.cc.Close() + } + } + st.ts.Close() + if st.cc != nil { + st.cc.Close() + } + log.SetOutput(os.Stderr) +} + +// greet initiates the client's HTTP/2 connection into a state where +// frames may be sent. +func (st *serverTester) greet() { + st.greetAndCheckSettings(func(http2Setting) error { return nil }) +} + +func (st *serverTester) greetAndCheckSettings(checkSetting func(s http2Setting) error) { + st.writePreface() + st.writeInitialSettings() + st.wantSettings().ForeachSetting(checkSetting) + st.writeSettingsAck() + + // The initial WINDOW_UPDATE and SETTINGS ACK can come in any order. + var gotSettingsAck bool + var gotWindowUpdate bool + + for i := 0; i < 2; i++ { + f, err := st.readFrame() + if err != nil { + st.t.Fatal(err) + } + switch f := f.(type) { + case *http2SettingsFrame: + if !f.Header().Flags.Has(http2FlagSettingsAck) { + st.t.Fatal("Settings Frame didn't have ACK set") + } + gotSettingsAck = true + + case *http2WindowUpdateFrame: + if f.http2FrameHeader.StreamID != 0 { + st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.http2FrameHeader.StreamID) + } + incr := uint32((&http2Server{}).initialConnRecvWindowSize() - http2initialWindowSize) + if f.Increment != incr { + st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr) + } + gotWindowUpdate = true + + default: + st.t.Fatalf("Wanting a settings ACK or window update, received a %T", f) + } + } + + if !gotSettingsAck { + st.t.Fatalf("Didn't get a settings ACK") + } + if !gotWindowUpdate { + st.t.Fatalf("Didn't get a window update") + } +} + +func (st *serverTester) writePreface() { + n, err := st.cc.Write(http2clientPreface) + if err != nil { + st.t.Fatalf("Error writing client preface: %v", err) + } + if n != len(http2clientPreface) { + st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(http2clientPreface)) + } +} + +func (st *serverTester) writeInitialSettings() { + if err := st.fr.WriteSettings(); err != nil { + st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err) + } +} + +func (st *serverTester) writeSettingsAck() { + if err := st.fr.WriteSettingsAck(); err != nil { + st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err) + } +} + +func (st *serverTester) writeHeaders(p http2HeadersFrameParam) { + if err := st.fr.WriteHeaders(p); err != nil { + st.t.Fatalf("Error writing HEADERS: %v", err) + } +} + +func (st *serverTester) writePriority(id uint32, p http2PriorityParam) { + if err := st.fr.WritePriority(id, p); err != nil { + st.t.Fatalf("Error writing PRIORITY: %v", err) + } +} + +func (st *serverTester) encodeHeaderField(k, v string) { + err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) + if err != nil { + st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) + } +} + +// encodeHeaderRaw is the magic-free version of encodeHeader. +// It takes 0 or more (k, v) pairs and encodes them. +func (st *serverTester) encodeHeaderRaw(headers ...string) []byte { + if len(headers)%2 == 1 { + panic("odd number of kv args") + } + st.headerBuf.Reset() + for len(headers) > 0 { + k, v := headers[0], headers[1] + st.encodeHeaderField(k, v) + headers = headers[2:] + } + return st.headerBuf.Bytes() +} + +// encodeHeader encodes headers and returns their HPACK bytes. headers +// must contain an even number of key/value pairs. There may be +// multiple pairs for keys (e.g. "cookie"). The :method, :path, and +// :scheme headers default to GET, / and https. The :authority header +// defaults to st.ts.Listener.Addr(). +func (st *serverTester) encodeHeader(headers ...string) []byte { + if len(headers)%2 == 1 { + panic("odd number of kv args") + } + + st.headerBuf.Reset() + defaultAuthority := st.ts.Listener.Addr().String() + + if len(headers) == 0 { + // Fast path, mostly for benchmarks, so test code doesn't pollute + // profiles when we're looking to improve server allocations. + st.encodeHeaderField(":method", "GET") + st.encodeHeaderField(":scheme", "https") + st.encodeHeaderField(":authority", defaultAuthority) + st.encodeHeaderField(":path", "/") + return st.headerBuf.Bytes() + } + + if len(headers) == 2 && headers[0] == ":method" { + // Another fast path for benchmarks. + st.encodeHeaderField(":method", headers[1]) + st.encodeHeaderField(":scheme", "https") + st.encodeHeaderField(":authority", defaultAuthority) + st.encodeHeaderField(":path", "/") + return st.headerBuf.Bytes() + } + + pseudoCount := map[string]int{} + keys := []string{":method", ":scheme", ":authority", ":path"} + vals := map[string][]string{ + ":method": {"GET"}, + ":scheme": {"https"}, + ":authority": {defaultAuthority}, + ":path": {"/"}, + } + for len(headers) > 0 { + k, v := headers[0], headers[1] + headers = headers[2:] + if _, ok := vals[k]; !ok { + keys = append(keys, k) + } + if strings.HasPrefix(k, ":") { + pseudoCount[k]++ + if pseudoCount[k] == 1 { + vals[k] = []string{v} + } else { + // Allows testing of invalid headers w/ dup pseudo fields. + vals[k] = append(vals[k], v) + } + } else { + vals[k] = append(vals[k], v) + } + } + for _, k := range keys { + for _, v := range vals[k] { + st.encodeHeaderField(k, v) + } + } + return st.headerBuf.Bytes() +} + +// bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set. +func (st *serverTester) bodylessReq1(headers ...string) { + st.writeHeaders(http2HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader(headers...), + EndStream: true, + EndHeaders: true, + }) +} + +func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) { + if err := st.fr.WriteData(streamID, endStream, data); err != nil { + st.t.Fatalf("Error writing DATA: %v", err) + } +} + +func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) { + if err := st.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil { + st.t.Fatalf("Error writing DATA: %v", err) + } +} + +func (st *serverTester) readFrame() (http2Frame, error) { + return st.fr.ReadFrame() +} + +func (st *serverTester) wantHeaders() *http2HeadersFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a HEADERS frame: %v", err) + } + hf, ok := f.(*http2HeadersFrame) + if !ok { + st.t.Fatalf("got a %T; want *http2HeadersFrame", f) + } + return hf +} + +func (st *serverTester) wantContinuation() *http2ContinuationFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a CONTINUATION frame: %v", err) + } + cf, ok := f.(*http2ContinuationFrame) + if !ok { + st.t.Fatalf("got a %T; want *http2ContinuationFrame", f) + } + return cf +} + +func (st *serverTester) wantData() *http2DataFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a DATA frame: %v", err) + } + df, ok := f.(*http2DataFrame) + if !ok { + st.t.Fatalf("got a %T; want *http2DataFrame", f) + } + return df +} + +func (st *serverTester) wantSettings() *http2SettingsFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err) + } + sf, ok := f.(*http2SettingsFrame) + if !ok { + st.t.Fatalf("got a %T; want *http2SettingsFrame", f) + } + return sf +} + +func (st *serverTester) wantPing() *http2PingFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a PING frame: %v", err) + } + pf, ok := f.(*http2PingFrame) + if !ok { + st.t.Fatalf("got a %T; want *http2PingFrame", f) + } + return pf +} + +func (st *serverTester) wantGoAway() *http2GoAwayFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a GOAWAY frame: %v", err) + } + gf, ok := f.(*http2GoAwayFrame) + if !ok { + st.t.Fatalf("got a %T; want *http2GoAwayFrame", f) + } + return gf +} + +func (st *serverTester) wantRSTStream(streamID uint32, errCode http2ErrCode) { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting an RSTStream frame: %v", err) + } + rs, ok := f.(*http2RSTStreamFrame) + if !ok { + st.t.Fatalf("got a %T; want *http2RSTStreamFrame", f) + } + if rs.http2FrameHeader.StreamID != streamID { + st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.http2FrameHeader.StreamID, streamID) + } + if rs.ErrCode != errCode { + st.t.Fatalf("RSTStream http2ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode) + } +} + +func (st *serverTester) wantWindowUpdate(streamID, incr uint32) { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a WINDOW_UPDATE frame: %v", err) + } + wu, ok := f.(*http2WindowUpdateFrame) + if !ok { + st.t.Fatalf("got a %T; want *http2WindowUpdateFrame", f) + } + if wu.http2FrameHeader.StreamID != streamID { + st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.http2FrameHeader.StreamID, streamID) + } + if wu.Increment != incr { + st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr) + } +} + +func (st *serverTester) wantSettingsAck() { + f, err := st.readFrame() + if err != nil { + st.t.Fatal(err) + } + sf, ok := f.(*http2SettingsFrame) + if !ok { + st.t.Fatalf("Wanting a settings ACK, received a %T", f) + } + if !sf.Header().Flags.Has(http2FlagSettingsAck) { + st.t.Fatal("Settings Frame didn't have ACK set") + } +} + +func (st *serverTester) wantPushPromise() *http2PushPromiseFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatal(err) + } + ppf, ok := f.(*http2PushPromiseFrame) + if !ok { + st.t.Fatalf("Wanted PushPromise, received %T", ppf) + } + return ppf +} + +type specCoverage struct { + coverage map[specPart]bool + d *xml.Decoder +} + +func joinSection(sec []int) string { + s := fmt.Sprintf("%d", sec[0]) + for _, n := range sec[1:] { + s = fmt.Sprintf("%s.%d", s, n) + } + return s +} + +func (sc specCoverage) readSection(sec []int) { + var ( + buf = new(bytes.Buffer) + sub = 0 + ) + for { + tk, err := sc.d.Token() + if err != nil { + if err == io.EOF { + return + } + panic(err) + } + switch v := tk.(type) { + case xml.StartElement: + if skipElement(v) { + if err := sc.d.Skip(); err != nil { + panic(err) + } + if v.Name.Local == "section" { + sub++ + } + break + } + switch v.Name.Local { + case "section": + sub++ + sc.readSection(append(sec, sub)) + case "xref": + buf.Write(sc.readXRef(v)) + } + case xml.CharData: + if len(sec) == 0 { + break + } + buf.Write(v) + case xml.EndElement: + if v.Name.Local == "section" { + sc.addSentences(joinSection(sec), buf.String()) + return + } + } + } +} + +func attrSig(se xml.StartElement) string { + var names []string + for _, attr := range se.Attr { + if attr.Name.Local == "fmt" { + names = append(names, "fmt-"+attr.Value) + } else { + names = append(names, attr.Name.Local) + } + } + sort.Strings(names) + return strings.Join(names, ",") +} + +func attrValue(se xml.StartElement, attr string) string { + for _, a := range se.Attr { + if a.Name.Local == attr { + return a.Value + } + } + panic("unknown attribute " + attr) +} + +func (sc specCoverage) readXRef(se xml.StartElement) []byte { + var b []byte + for { + tk, err := sc.d.Token() + if err != nil { + panic(err) + } + switch v := tk.(type) { + case xml.CharData: + if b != nil { + panic("unexpected CharData") + } + b = []byte(string(v)) + case xml.EndElement: + if v.Name.Local != "xref" { + panic("expected ") + } + if b != nil { + return b + } + sig := attrSig(se) + switch sig { + case "target": + return []byte(fmt.Sprintf("[%s]", attrValue(se, "target"))) + case "fmt-of,rel,target", "fmt-,,rel,target": + return []byte(fmt.Sprintf("[%s, %s]", attrValue(se, "target"), attrValue(se, "rel"))) + case "fmt-of,sec,target", "fmt-,,sec,target": + return []byte(fmt.Sprintf("[section %s of %s]", attrValue(se, "sec"), attrValue(se, "target"))) + case "fmt-of,rel,sec,target": + return []byte(fmt.Sprintf("[section %s of %s, %s]", attrValue(se, "sec"), attrValue(se, "target"), attrValue(se, "rel"))) + default: + panic(fmt.Sprintf("unknown attribute signature %q in %#v", sig, fmt.Sprintf("%#v", se))) + } + default: + panic(fmt.Sprintf("unexpected tag %q", v)) + } + } +} + +var skipAnchor = map[string]bool{ + "intro": true, + "Overview": true, +} + +var skipTitle = map[string]bool{ + "Acknowledgements": true, + "Change Log": true, + "Document Organization": true, + "Conventions and Terminology": true, +} + +func skipElement(s xml.StartElement) bool { + switch s.Name.Local { + case "artwork": + return true + case "section": + for _, attr := range s.Attr { + switch attr.Name.Local { + case "anchor": + if skipAnchor[attr.Value] || strings.HasPrefix(attr.Value, "changes.since.") { + return true + } + case "title": + if skipTitle[attr.Value] { + return true + } + } + } + } + return false +} + +type specPart struct { + section string + sentence string +} + +func (ss specPart) Less(oo specPart) bool { + atoi := func(s string) int { + n, err := strconv.Atoi(s) + if err != nil { + panic(err) + } + return n + } + a := strings.Split(ss.section, ".") + b := strings.Split(oo.section, ".") + for len(a) > 0 { + if len(b) == 0 { + return false + } + x, y := atoi(a[0]), atoi(b[0]) + if x == y { + a, b = a[1:], b[1:] + continue + } + return x < y + } + if len(b) > 0 { + return true + } + return false +} + +type bySpecSection []specPart + +func (a bySpecSection) Len() int { return len(a) } +func (a bySpecSection) Less(i, j int) bool { return a[i].Less(a[j]) } +func (a bySpecSection) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + +func readSpecCov(r io.Reader) specCoverage { + sc := specCoverage{ + coverage: map[specPart]bool{}, + d: xml.NewDecoder(r)} + sc.readSection(nil) + return sc +} + +var whitespaceRx = regexp.MustCompile(`\s+`) + +func parseSentences(sens string) []string { + sens = strings.TrimSpace(sens) + if sens == "" { + return nil + } + ss := strings.Split(whitespaceRx.ReplaceAllString(sens, " "), ". ") + for i, s := range ss { + s = strings.TrimSpace(s) + if !strings.HasSuffix(s, ".") { + s += "." + } + ss[i] = s + } + return ss +} + +func (sc specCoverage) addSentences(sec string, sentence string) { + for _, s := range parseSentences(sentence) { + sc.coverage[specPart{sec, s}] = false + } +} + +func (sc specCoverage) cover(sec string, sentence string) { + for _, s := range parseSentences(sentence) { + p := specPart{sec, s} + if _, ok := sc.coverage[p]; !ok { + panic(fmt.Sprintf("Not found in spec: %q, %q", sec, s)) + } + sc.coverage[specPart{sec, s}] = true + } + +} + +var coverSpec = flag.Bool("coverspec", false, "Run spec coverage tests") + +// The global map of sentence coverage for the http2 spec. +var defaultSpecCoverage specCoverage + +var loadSpecOnce sync.Once + +func loadSpec() { + if f, err := os.Open("testdata/draft-ietf-httpbis-http2.xml"); err != nil { + panic(err) + } else { + defaultSpecCoverage = readSpecCov(f) + f.Close() + } +} + +// covers marks all sentences for section sec in defaultSpecCoverage. Sentences not +// "covered" will be included in report outputted by TestSpecCoverage. +func covers(sec, sentences string) { + loadSpecOnce.Do(loadSpec) + defaultSpecCoverage.cover(sec, sentences) +} diff --git a/h2_transport.go b/h2_transport.go index 316283fe..9aa936a2 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -146,20 +146,13 @@ type http2Transport struct { connPoolOrDef http2ClientConnPool // non-nil version of ConnPool } -const h2max = 1<<32 - 1 - func (t *http2Transport) maxHeaderListSize() uint32 { + if t.MaxHeaderListSize == 0 { + return 10 << 20 + } if t.MaxHeaderListSize == 0xffffffff { return 0 } - if t.MaxHeaderListSize > 0 { - return t.MaxHeaderListSize - } - if limit := t.t1.MaxResponseHeaderBytes; limit > 0 && limit < h2max { - t.MaxHeaderListSize = h2max - } else { - t.MaxHeaderListSize = 10 << 20 - } return t.MaxHeaderListSize } @@ -600,7 +593,7 @@ func (t *http2Transport) dialClientConn(ctx context.Context, addr string, single func (t *http2Transport) newTLSConfig(host string) *tls.Config { cfg := new(tls.Config) - if t.t1.TLSClientConfig != nil { + if t.t1 != nil && t.t1.TLSClientConfig != nil { *cfg = *t.t1.TLSClientConfig.Clone() } if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { @@ -1262,7 +1255,10 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { } } - dumps := getDumpers(cs.cc.t.t1.dump, req.Context()) + var dumps []*dumper + if t1 := cs.cc.t.t1; t1 != nil { + dumps = getDumpers(t1.dump, req.Context()) + } // Past this point (where we send request headers), it is possible for // RoundTrip to return successfully. Since the RoundTrip contract permits diff --git a/h2_transport_test.go b/h2_transport_test.go new file mode 100644 index 00000000..846a535f --- /dev/null +++ b/h2_transport_test.go @@ -0,0 +1,5855 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package req + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "encoding/hex" + "errors" + "flag" + "fmt" + "io" + "io/ioutil" + "log" + "math/rand" + "net" + "net/http" + "net/http/httptest" + "net/http/httptrace" + "net/textproto" + "net/url" + "os" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/net/http2/hpack" +) + +var ( + extNet = flag.Bool("extnet", false, "do external network tests") + transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport") + insecure = flag.Bool("insecure", false, "insecure TLS dials") // TODO: dead code. remove? +) + +var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true} + +var canceledCtx context.Context + +func init() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + canceledCtx = ctx +} + +func TestTransportExternal(t *testing.T) { + if !*extNet { + t.Skip("skipping external network test") + } + req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil) + rt := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + res, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("%v", err) + } + res.Write(os.Stdout) +} + +type fakeTLSConn struct { + net.Conn +} + +func (c *fakeTLSConn) ConnectionState() tls.ConnectionState { + return tls.ConnectionState{ + Version: tls.VersionTLS12, + CipherSuite: http2cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + } +} + +func startH2cServer(t *testing.T) net.Listener { + h2Server := &http2Server{} + l := newLocalListener(t) + go func() { + conn, err := l.Accept() + if err != nil { + t.Error(err) + return + } + h2Server.ServeConn(&fakeTLSConn{conn}, &http2ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil) + })}) + }() + return l +} + +func TestTransportH2c(t *testing.T) { + l := startH2cServer(t) + defer l.Close() + req, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/foobar", nil) + if err != nil { + t.Fatal(err) + } + var gotConnCnt int32 + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + if !connInfo.Reused { + atomic.AddInt32(&gotConnCnt, 1) + } + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + tr := &http2Transport{ + AllowHTTP: true, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + return net.Dial(network, addr) + }, + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if res.ProtoMajor != 2 { + t.Fatal("proto not h2c") + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if got, want := string(body), "Hello, /foobar, http: true"; got != want { + t.Fatalf("response got %v, want %v", got, want) + } + if got, want := gotConnCnt, int32(1); got != want { + t.Errorf("Too many got connections: %d", gotConnCnt) + } +} + +func TestTransport(t *testing.T) { + const body = "sup" + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, body) + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + u, err := url.Parse(st.ts.URL) + if err != nil { + t.Fatal(err) + } + for i, m := range []string{"GET", ""} { + req := &http.Request{ + Method: m, + URL: u, + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatalf("%d: %s", i, err) + } + + t.Logf("%d: Got res: %+v", i, res) + if g, w := res.StatusCode, 200; g != w { + t.Errorf("%d: StatusCode = %v; want %v", i, g, w) + } + if g, w := res.Status, "200 OK"; g != w { + t.Errorf("%d: Status = %q; want %q", i, g, w) + } + wantHeader := http.Header{ + "Content-Length": []string{"3"}, + "Content-Type": []string{"text/plain; charset=utf-8"}, + "Date": []string{"XXX"}, // see cleanDate + } + cleanDate(res) + if !reflect.DeepEqual(res.Header, wantHeader) { + t.Errorf("%d: res Header = %v; want %v", i, res.Header, wantHeader) + } + if res.Request != req { + t.Errorf("%d: Response.Request = %p; want %p", i, res.Request, req) + } + if res.TLS == nil { + t.Errorf("%d: Response.TLS = nil; want non-nil", i) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Errorf("%d: Body read: %v", i, err) + } else if string(slurp) != body { + t.Errorf("%d: Body = %q; want %q", i, slurp, body) + } + res.Body.Close() + } +} + +func testTransportReusesConns(t *testing.T, wantSame bool, modReq func(*http.Request)) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.RemoteAddr) + }, optOnlyServer, func(c net.Conn, st http.ConnState) { + t.Logf("conn %v is now state %v", c.RemoteAddr(), st) + }) + defer st.Close() + tr := &http2Transport{ + t1: &Transport{ + TLSClientConfig: tlsConfigInsecure, + }, + } + defer tr.CloseIdleConnections() + get := func() string { + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + modReq(req) + var res *http.Response + + res, err = tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("Body read: %v", err) + } + addr := strings.TrimSpace(string(slurp)) + if addr == "" { + t.Fatalf("didn't get an addr in response") + } + return addr + } + first := get() + second := get() + if got := first == second; got != wantSame { + t.Errorf("first and second responses on same connection: %v; want %v", got, wantSame) + } +} + +func TestTransportReusesConns(t *testing.T) { + for _, test := range []struct { + name string + modReq func(*http.Request) + wantSame bool + }{{ + name: "ReuseConn", + modReq: func(*http.Request) {}, + wantSame: true, + }, { + name: "RequestClose", + modReq: func(r *http.Request) { r.Close = true }, + wantSame: false, + }, { + name: "ConnClose", + modReq: func(r *http.Request) { r.Header.Set("Connection", "close") }, + wantSame: false, + }} { + t.Run(test.name, func(t *testing.T) { + t.Run("Transport", func(t *testing.T) { + const useClient = false + testTransportReusesConns(t, test.wantSame, test.modReq) + }) + t.Run("Client", func(t *testing.T) { + const useClient = true + testTransportReusesConns(t, test.wantSame, test.modReq) + }) + }) + } +} + +func TestTransportGetGotConnHooks_HTTP2Transport(t *testing.T) { + testTransportGetGotConnHooks(t) +} + +func testTransportGetGotConnHooks(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.RemoteAddr) + }, func(s *httptest.Server) { + s.EnableHTTP2 = true + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{ + t1: &Transport{ + TLSClientConfig: tlsConfigInsecure, + }, + } + + var ( + getConns int32 + gotConns int32 + ) + for i := 0; i < 2; i++ { + trace := &httptrace.ClientTrace{ + GetConn: func(hostport string) { + atomic.AddInt32(&getConns, 1) + }, + GotConn: func(connInfo httptrace.GotConnInfo) { + got := atomic.AddInt32(&gotConns, 1) + wantReused, wantWasIdle := false, false + if got > 1 { + wantReused, wantWasIdle = true, true + } + if connInfo.Reused != wantReused || connInfo.WasIdle != wantWasIdle { + t.Errorf("GotConn %v: Reused=%v (want %v), WasIdle=%v (want %v)", i, connInfo.Reused, wantReused, connInfo.WasIdle, wantWasIdle) + } + }, + } + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + var res *http.Response + res, err = tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if get := atomic.LoadInt32(&getConns); get != int32(i+1) { + t.Errorf("after request %v, %v calls to GetConns: want %v", i, get, i+1) + } + if got := atomic.LoadInt32(&gotConns); got != int32(i+1) { + t.Errorf("after request %v, %v calls to GotConns: want %v", i, got, i+1) + } + } +} + +type testNetConn struct { + net.Conn + closed bool + onClose func() +} + +func (c *testNetConn) Close() error { + if !c.closed { + // We can call Close multiple times on the same net.Conn. + c.onClose() + } + c.closed = true + return c.Conn.Close() +} + +// Tests that the Transport only keeps one pending dial open per destination address. +// https://golang.org/issue/13397 +func TestTransportGroupsPendingDials(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + }, optOnlyServer) + defer st.Close() + var ( + mu sync.Mutex + dialCount int + closeCount int + ) + tr := &http2Transport{ + t1: &Transport{ + TLSClientConfig: tlsConfigInsecure, + }, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + mu.Lock() + dialCount++ + mu.Unlock() + c, err := tls.Dial(network, addr, cfg) + return &testNetConn{ + Conn: c, + onClose: func() { + mu.Lock() + closeCount++ + mu.Unlock() + }, + }, err + }, + } + defer tr.CloseIdleConnections() + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Error(err) + return + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Error(err) + return + } + res.Body.Close() + }() + } + wg.Wait() + tr.CloseIdleConnections() + if dialCount != 1 { + t.Errorf("saw %d dials; want 1", dialCount) + } + if closeCount != 1 { + t.Errorf("saw %d closes; want 1", closeCount) + } +} + +func retry(tries int, delay time.Duration, fn func() error) error { + var err error + for i := 0; i < tries; i++ { + err = fn() + if err == nil { + return nil + } + time.Sleep(delay) + } + return err +} + +func TestTransportAbortClosesPipes(t *testing.T) { + shutdown := make(chan struct{}) + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() + <-shutdown + }, + optOnlyServer, + ) + defer st.Close() + defer close(shutdown) // we must shutdown before st.Close() to avoid hanging + + errCh := make(chan error) + go func() { + defer close(errCh) + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + errCh <- err + return + } + res, err := tr.RoundTrip(req) + if err != nil { + errCh <- err + return + } + defer res.Body.Close() + st.closeConn() + _, err = ioutil.ReadAll(res.Body) + if err == nil { + errCh <- errors.New("expected error from res.Body.Read") + return + } + }() + + select { + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + // deadlock? that's a bug. + case <-time.After(3 * time.Second): + t.Fatal("timeout") + } +} + +// TODO: merge this with TestTransportBody to make TestTransportRequest? This +// could be a table-driven test with extra goodies. +func TestTransportPath(t *testing.T) { + gotc := make(chan *url.URL, 1) + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + gotc <- r.URL + }, + optOnlyServer, + ) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + const ( + path = "/testpath" + query = "q=1" + ) + surl := st.ts.URL + path + "?" + query + req, err := http.NewRequest("POST", surl, nil) + if err != nil { + t.Fatal(err) + } + c := &http.Client{Transport: tr} + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + got := <-gotc + if got.Path != path { + t.Errorf("Read Path = %q; want %q", got.Path, path) + } + if got.RawQuery != query { + t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query) + } +} + +func randString(n int) string { + rnd := rand.New(rand.NewSource(int64(n))) + b := make([]byte, n) + for i := range b { + b[i] = byte(rnd.Intn(256)) + } + return string(b) +} + +type panicReader struct{} + +func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") } +func (panicReader) Close() error { panic("unexpected Close") } + +func TestActualContentLength(t *testing.T) { + tests := []struct { + req *http.Request + want int64 + }{ + // Verify we don't read from Body: + 0: { + req: &http.Request{Body: panicReader{}}, + want: -1, + }, + // nil Body means 0, regardless of ContentLength: + 1: { + req: &http.Request{Body: nil, ContentLength: 5}, + want: 0, + }, + // ContentLength is used if set. + 2: { + req: &http.Request{Body: panicReader{}, ContentLength: 5}, + want: 5, + }, + // http.NoBody means 0, not -1. + 3: { + req: &http.Request{Body: http.NoBody}, + want: 0, + }, + } + for i, tt := range tests { + got := http2actualContentLength(tt.req) + if got != tt.want { + t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) + } + } +} + +func TestTransportBody(t *testing.T) { + bodyTests := []struct { + body string + noContentLen bool + }{ + {body: "some message"}, + {body: "some message", noContentLen: true}, + {body: strings.Repeat("a", 1<<20), noContentLen: true}, + {body: strings.Repeat("a", 1<<20)}, + {body: randString(16<<10 - 1)}, + {body: randString(16 << 10)}, + {body: randString(16<<10 + 1)}, + {body: randString(512<<10 - 1)}, + {body: randString(512 << 10)}, + {body: randString(512<<10 + 1)}, + {body: randString(1<<20 - 1)}, + {body: randString(1 << 20)}, + {body: randString(1<<20 + 2)}, + } + + type reqInfo struct { + req *http.Request + slurp []byte + err error + } + gotc := make(chan reqInfo, 1) + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + slurp, err := ioutil.ReadAll(r.Body) + if err != nil { + gotc <- reqInfo{err: err} + } else { + gotc <- reqInfo{req: r, slurp: slurp} + } + }, + optOnlyServer, + ) + defer st.Close() + + for i, tt := range bodyTests { + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + var body io.Reader = strings.NewReader(tt.body) + if tt.noContentLen { + body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods + } + req, err := http.NewRequest("POST", st.ts.URL, body) + if err != nil { + t.Fatalf("#%d: %v", i, err) + } + c := &http.Client{Transport: tr} + res, err := c.Do(req) + if err != nil { + t.Fatalf("#%d: %v", i, err) + } + defer res.Body.Close() + ri := <-gotc + if ri.err != nil { + t.Errorf("#%d: read error: %v", i, ri.err) + continue + } + if got := string(ri.slurp); got != tt.body { + t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body)) + } + wantLen := int64(len(tt.body)) + if tt.noContentLen && tt.body != "" { + wantLen = -1 + } + if ri.req.ContentLength != wantLen { + t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen) + } + } +} + +func shortString(v string) string { + const maxLen = 100 + if len(v) <= maxLen { + return v + } + return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:maxLen/2], len(v)-maxLen, v[len(v)-maxLen/2:]) +} + +func TestTransportDialTLS(t *testing.T) { + var mu sync.Mutex // guards following + var gotReq, didDial bool + + ts := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + gotReq = true + mu.Unlock() + }, + optOnlyServer, + ) + defer ts.Close() + tr := &http2Transport{ + DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) { + mu.Lock() + didDial = true + mu.Unlock() + cfg.InsecureSkipVerify = true + c, err := tls.Dial(netw, addr, cfg) + if err != nil { + return nil, err + } + return c, c.Handshake() + }, + } + defer tr.CloseIdleConnections() + client := &http.Client{Transport: tr} + res, err := client.Get(ts.ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + mu.Lock() + if !gotReq { + t.Error("didn't get request") + } + if !didDial { + t.Error("didn't use dial hook") + } +} + +func TestConfigureTransport(t *testing.T) { + t1 := &Transport{} + err := http2ConfigureTransport(t1) + if err != nil { + t.Fatal(err) + } + if got := fmt.Sprintf("%#v", t1); !strings.Contains(got, `"h2"`) { + // Laziness, to avoid buildtags. + t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got) + } + wantNextProtos := []string{"h2", "http/1.1"} + if t1.TLSClientConfig == nil { + t.Errorf("nil t1.TLSClientConfig") + } else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, wantNextProtos) { + t.Errorf("TLSClientConfig.NextProtos = %q; want %q", t1.TLSClientConfig.NextProtos, wantNextProtos) + } + if err := http2ConfigureTransport(t1); err == nil { + t.Error("unexpected success on second call to http2ConfigureTransport") + } + + // And does it work? + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.Proto) + }, optOnlyServer) + defer st.Close() + + t1.TLSClientConfig.InsecureSkipVerify = true + c := &http.Client{Transport: t1} + res, err := c.Get(st.ts.URL) + if err != nil { + t.Fatal(err) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if got, want := string(slurp), "HTTP/2.0"; got != want { + t.Errorf("body = %q; want %q", got, want) + } +} + +type capitalizeReader struct { + r io.Reader +} + +func (cr capitalizeReader) Read(p []byte) (n int, err error) { + n, err = cr.r.Read(p) + for i, b := range p[:n] { + if b >= 'a' && b <= 'z' { + p[i] = b - ('a' - 'A') + } + } + return +} + +type flushWriter struct { + w io.Writer +} + +func (fw flushWriter) Write(p []byte) (n int, err error) { + n, err = fw.w.Write(p) + if f, ok := fw.w.(http.Flusher); ok { + f.Flush() + } + return +} + +type clientTester struct { + t *testing.T + tr *http2Transport + sc, cc net.Conn // server and client conn + fr *http2Framer // server's framer + client func() error + server func() error +} + +func newClientTester(t *testing.T) *clientTester { + var dialOnce struct { + sync.Mutex + dialed bool + } + ct := &clientTester{ + t: t, + } + ct.tr = &http2Transport{ + t1: &Transport{ + TLSClientConfig: tlsConfigInsecure, + }, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + dialOnce.Lock() + defer dialOnce.Unlock() + if dialOnce.dialed { + return nil, errors.New("only one dial allowed in test mode") + } + dialOnce.dialed = true + return ct.cc, nil + }, + } + + ln := newLocalListener(t) + cc, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + + } + sc, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + ln.Close() + ct.cc = cc + ct.sc = sc + ct.fr = http2NewFramer(sc, sc) + return ct +} + +func (ct *clientTester) greet(settings ...http2Setting) { + buf := make([]byte, len(http2ClientPreface)) + _, err := io.ReadFull(ct.sc, buf) + if err != nil { + ct.t.Fatalf("reading client preface: %v", err) + } + f, err := ct.fr.ReadFrame() + if err != nil { + ct.t.Fatalf("Reading client settings frame: %v", err) + } + if sf, ok := f.(*http2SettingsFrame); !ok { + ct.t.Fatalf("Wanted client settings frame; got %v", f) + _ = sf // stash it away? + } + if err := ct.fr.WriteSettings(settings...); err != nil { + ct.t.Fatal(err) + } + if err := ct.fr.WriteSettingsAck(); err != nil { + ct.t.Fatal(err) + } +} + +func (ct *clientTester) readNonSettingsFrame() (http2Frame, error) { + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return nil, err + } + if _, ok := f.(*http2SettingsFrame); ok { + continue + } + return f, nil + } +} + +func (ct *clientTester) cleanup() { + ct.tr.CloseIdleConnections() + + // close both connections, ignore the error if its already closed + ct.sc.Close() + ct.cc.Close() +} + +func (ct *clientTester) run() { + var errOnce sync.Once + var wg sync.WaitGroup + + run := func(which string, fn func() error) { + defer wg.Done() + if err := fn(); err != nil { + errOnce.Do(func() { + ct.t.Errorf("%s: %v", which, err) + ct.cleanup() + }) + } + } + + wg.Add(2) + go run("client", ct.client) + go run("server", ct.server) + wg.Wait() + + errOnce.Do(ct.cleanup) // clean up if no error +} + +func (ct *clientTester) readFrame() (http2Frame, error) { + return ct.fr.ReadFrame() +} + +func (ct *clientTester) firstHeaders() (*http2HeadersFrame, error) { + for { + f, err := ct.readFrame() + if err != nil { + return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + switch f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame: + continue + } + hf, ok := f.(*http2HeadersFrame) + if !ok { + return nil, fmt.Errorf("Got %T; want HeadersFrame", f) + } + return hf, nil + } +} + +type countingReader struct { + n *int64 +} + +func (r countingReader) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(i) + } + atomic.AddInt64(r.n, int64(len(p))) + return len(p), err +} + +func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } +func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } + +func testTransportReqBodyAfterResponse(t *testing.T, status int) { + const bodySize = 10 << 20 + clientDone := make(chan struct{}) + ct := newClientTester(t) + recvLen := make(chan int64, 1) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + if runtime.GOOS == "plan9" { + // CloseWrite not supported on Plan 9; Issue 17906 + defer ct.cc.(*net.TCPConn).Close() + } + defer close(clientDone) + + body := &http2pipe{b: new(bytes.Buffer)} + io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) + req, err := http.NewRequest("PUT", "https://dummy.tld/", body) + if err != nil { + return err + } + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + if res.StatusCode != status { + return fmt.Errorf("status code = %v; want %v", res.StatusCode, status) + } + io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) + body.CloseWithError(io.EOF) + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("Slurp: %v", err) + } + if len(slurp) > 0 { + return fmt.Errorf("unexpected body: %q", slurp) + } + res.Body.Close() + if status == 200 { + if got := <-recvLen; got != bodySize { + return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize) + } + } else { + if got := <-recvLen; got == 0 || got >= bodySize { + return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize) + } + } + return nil + } + ct.server = func() error { + ct.greet() + defer close(recvLen) + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + var dataRecv int64 + var closed bool + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + return nil + default: + return err + } + } + // println(fmt.Sprintf("server got frame: %v", f)) + ended := false + switch f := f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame: + case *http2HeadersFrame: + if !f.HeadersEnded() { + return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + } + if f.StreamEnded() { + return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f) + } + case *http2DataFrame: + dataLen := len(f.Data()) + if dataLen > 0 { + if dataRecv == 0 { + enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + } + if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { + return err + } + if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { + return err + } + } + dataRecv += int64(dataLen) + + if !closed && ((status != 200 && dataRecv > 0) || + (status == 200 && f.StreamEnded())) { + closed = true + if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { + return err + } + } + + if f.StreamEnded() { + ended = true + } + case *http2RSTStreamFrame: + if status == 200 { + return fmt.Errorf("Unexpected client frame %v", f) + } + ended = true + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + if ended { + select { + case recvLen <- dataRecv: + default: + } + } + } + } + ct.run() +} + +// See golang.org/issue/13444 +func TestTransportFullDuplex(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) // redundant but for clarity + w.(http.Flusher).Flush() + io.Copy(flushWriter{w}, capitalizeReader{r.Body}) + fmt.Fprintf(w, "bye.\n") + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + pr, pw := io.Pipe() + req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr)) + if err != nil { + t.Fatal(err) + } + req.ContentLength = -1 + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200) + } + bs := bufio.NewScanner(res.Body) + want := func(v string) { + if !bs.Scan() { + t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err()) + } + } + write := func(v string) { + _, err := io.WriteString(pw, v) + if err != nil { + t.Fatalf("pipe write: %v", err) + } + } + write("foo\n") + want("FOO") + write("bar\n") + want("BAR") + pw.Close() + want("bye.") + if err := bs.Err(); err != nil { + t.Fatal(err) + } +} + +func TestTransportConnectRequest(t *testing.T) { + gotc := make(chan *http.Request, 1) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + gotc <- r + }, optOnlyServer) + defer st.Close() + + u, err := url.Parse(st.ts.URL) + if err != nil { + t.Fatal(err) + } + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + tests := []struct { + req *http.Request + want string + }{ + { + req: &http.Request{ + Method: "CONNECT", + Header: http.Header{}, + URL: u, + }, + want: u.Host, + }, + { + req: &http.Request{ + Method: "CONNECT", + Header: http.Header{}, + URL: u, + Host: "example.com:123", + }, + want: "example.com:123", + }, + } + + for i, tt := range tests { + res, err := c.Do(tt.req) + if err != nil { + t.Errorf("%d. RoundTrip = %v", i, err) + continue + } + res.Body.Close() + req := <-gotc + if req.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", req.Method) + } + if req.Host != tt.want { + t.Errorf("Host = %q; want %q", req.Host, tt.want) + } + if req.URL.Host != tt.want { + t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want) + } + } +} + +type headerType int + +const ( + noHeader headerType = iota // omitted + oneHeader + splitHeader // broken into continuation on purpose +) + +const ( + f0 = noHeader + f1 = oneHeader + f2 = splitHeader + d0 = false + d1 = true +) + +// Test all 36 combinations of response frame orders: +// (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers):func TestTransportResponsePattern_00f0(t *testing.T) { testTransportResponsePattern(h0, h1, false, h0) } +// Generated by http://play.golang.org/p/SScqYKJYXd +func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) } +func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) } +func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) } +func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) } +func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) } +func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) } +func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) } +func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) } +func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) } +func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) } +func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) } +func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) } +func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) } +func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) } +func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) } +func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) } +func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) } +func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) } +func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) } +func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) } +func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) } +func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) } +func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) } +func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) } +func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) } +func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) } +func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) } +func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) } +func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) } +func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) } +func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) } +func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) } +func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) } +func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) } +func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) } +func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) } + +func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) { + const reqBody = "some request body" + const resBody = "some response body" + + if resHeader == noHeader { + // TODO: test 100-continue followed by immediate + // server stream reset, without headers in the middle? + panic("invalid combination") + } + + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody)) + if expect100Continue != noHeader { + req.Header.Set("Expect", "100-continue") + } + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return fmt.Errorf("status code = %v; want 200", res.StatusCode) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("Slurp: %v", err) + } + wantBody := resBody + if !withData { + wantBody = "" + } + if string(slurp) != wantBody { + return fmt.Errorf("body = %q; want %q", slurp, wantBody) + } + if trailers == noHeader { + if len(res.Trailer) > 0 { + t.Errorf("Trailer = %v; want none", res.Trailer) + } + } else { + want := http.Header{"Some-Trailer": {"some-value"}} + if !reflect.DeepEqual(res.Trailer, want) { + t.Errorf("Trailer = %v; want %v", res.Trailer, want) + } + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + endStream := false + send := func(mode headerType) { + hbf := buf.Bytes() + switch mode { + case oneHeader: + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.Header().StreamID, + EndHeaders: true, + EndStream: endStream, + BlockFragment: hbf, + }) + case splitHeader: + if len(hbf) < 2 { + panic("too small") + } + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.Header().StreamID, + EndHeaders: false, + EndStream: endStream, + BlockFragment: hbf[:1], + }) + ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:]) + default: + panic("bogus mode") + } + } + switch f := f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame: + case *http2DataFrame: + if !f.StreamEnded() { + // No need to send flow control tokens. The test request body is tiny. + continue + } + // Response headers (1+ frames; 1 or 2 in this test, but never 0) + { + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"}) + enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"}) + if trailers != noHeader { + enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"}) + } + endStream = withData == false && trailers == noHeader + send(resHeader) + } + if withData { + endStream = trailers == noHeader + ct.fr.WriteData(f.StreamID, endStream, []byte(resBody)) + } + if trailers != noHeader { + endStream = true + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"}) + send(trailers) + } + if endStream { + return nil + } + case *http2HeadersFrame: + if expect100Continue != noHeader { + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"}) + send(expect100Continue) + } + } + } + } + ct.run() +} + +// Issue 26189, Issue 17739: ignore unknown 1xx responses +func TestTransportUnknown1xx(t *testing.T) { + var buf bytes.Buffer + defer func() { http2got1xxFuncForTests = nil }() + http2got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error { + fmt.Fprintf(&buf, "code=%d header=%v\n", code, header) + return nil + } + + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 204 { + return fmt.Errorf("status code = %v; want 204", res.StatusCode) + } + want := `code=110 header=map[Foo-Bar:[110]] +code=111 header=map[Foo-Bar:[111]] +code=112 header=map[Foo-Bar:[112]] +code=113 header=map[Foo-Bar:[113]] +code=114 header=map[Foo-Bar:[114]] +` + if got := buf.String(); got != want { + t.Errorf("Got trace:\n%s\nWant:\n%s", got, want) + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + switch f := f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame: + case *http2HeadersFrame: + for i := 110; i <= 114; i++ { + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)}) + enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + } + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + return nil + } + } + } + ct.run() + +} + +func TestTransportReceiveUndeclaredTrailer(t *testing.T) { + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return fmt.Errorf("status code = %v; want 200", res.StatusCode) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil) + } + if len(slurp) > 0 { + return fmt.Errorf("body = %q; want nothing", slurp) + } + if _, ok := res.Trailer["Some-Trailer"]; !ok { + return fmt.Errorf("expected Some-Trailer") + } + return nil + } + ct.server = func() error { + ct.greet() + + var n int + var hf *http2HeadersFrame + for hf == nil && n < 10 { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + hf, _ = f.(*http2HeadersFrame) + n++ + } + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + // send headers without Trailer header + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + + // send trailers + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + return nil + } + ct.run() +} + +func TestTransportInvalidTrailer_Pseudo1(t *testing.T) { + testTransportInvalidTrailer_Pseudo(t, oneHeader) +} +func TestTransportInvalidTrailer_Pseudo2(t *testing.T) { + testTransportInvalidTrailer_Pseudo(t, splitHeader) +} +func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) { + testInvalidTrailer(t, trailers, http2pseudoHeaderError(":colon"), func(enc *hpack.Encoder) { + enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"}) + enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) + }) +} + +func TestTransportInvalidTrailer_Capital1(t *testing.T) { + testTransportInvalidTrailer_Capital(t, oneHeader) +} +func TestTransportInvalidTrailer_Capital2(t *testing.T) { + testTransportInvalidTrailer_Capital(t, splitHeader) +} +func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) { + testInvalidTrailer(t, trailers, http2headerFieldNameError("Capital"), func(enc *hpack.Encoder) { + enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) + enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"}) + }) +} +func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) { + testInvalidTrailer(t, oneHeader, http2headerFieldNameError(""), func(enc *hpack.Encoder) { + enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"}) + }) +} +func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) { + testInvalidTrailer(t, oneHeader, http2headerFieldValueError("has\nnewline"), func(enc *hpack.Encoder) { + enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"}) + }) +} + +func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) { + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return fmt.Errorf("status code = %v; want 200", res.StatusCode) + } + slurp, err := ioutil.ReadAll(res.Body) + se, ok := err.(http2StreamError) + if !ok || se.Cause != wantErr { + return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr) + } + if len(slurp) > 0 { + return fmt.Errorf("body = %q; want nothing", slurp) + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + switch f := f.(type) { + case *http2HeadersFrame: + var endStream bool + send := func(mode headerType) { + hbf := buf.Bytes() + switch mode { + case oneHeader: + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: endStream, + BlockFragment: hbf, + }) + case splitHeader: + if len(hbf) < 2 { + panic("too small") + } + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: false, + EndStream: endStream, + BlockFragment: hbf[:1], + }) + ct.fr.WriteContinuation(f.StreamID, true, hbf[1:]) + default: + panic("bogus mode") + } + } + // Response headers (1+ frames; 1 or 2 in this test, but never 0) + { + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"}) + endStream = false + send(oneHeader) + } + // Trailers: + { + endStream = true + buf.Reset() + writeTrailer(enc) + send(trailers) + } + return nil + } + } + } + ct.run() +} + +// headerListSize returns the HTTP2 header list size of h. +// http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE +// http://httpwg.org/specs/rfc7540.html#MaxHeaderBlock +func headerListSize(h http.Header) (size uint32) { + for k, vv := range h { + for _, v := range vv { + hf := hpack.HeaderField{Name: k, Value: v} + size += hf.Size() + } + } + return size +} + +// padHeaders adds data to an http.Header until headerListSize(h) == +// limit. Due to the way header list sizes are calculated, padHeaders +// cannot add fewer than len("Pad-Headers") + 32 bytes to h, and will +// call t.Fatal if asked to do so. PadHeaders first reserves enough +// space for an empty "Pad-Headers" key, then adds as many copies of +// filler as possible. Any remaining bytes necessary to push the +// header list size up to limit are added to h["Pad-Headers"]. +func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) { + if limit > 0xffffffff { + t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit) + } + hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""} + minPadding := uint64(hf.Size()) + size := uint64(headerListSize(h)) + + minlimit := size + minPadding + if limit < minlimit { + t.Fatalf("padHeaders: limit %v < %v", limit, minlimit) + } + + // Use a fixed-width format for name so that fieldSize + // remains constant. + nameFmt := "Pad-Headers-%06d" + hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler} + fieldSize := uint64(hf.Size()) + + // Add as many complete filler values as possible, leaving + // room for at least one empty "Pad-Headers" key. + limit = limit - minPadding + for i := 0; size+fieldSize < limit; i++ { + name := fmt.Sprintf(nameFmt, i) + h.Add(name, filler) + size += fieldSize + } + + // Add enough bytes to reach limit. + remain := limit - size + lastValue := strings.Repeat("*", int(remain)) + h.Add("Pad-Headers", lastValue) +} + +func TestPadHeaders(t *testing.T) { + check := func(h http.Header, limit uint32, fillerLen int) { + if h == nil { + h = make(http.Header) + } + filler := strings.Repeat("f", fillerLen) + padHeaders(t, h, uint64(limit), filler) + gotSize := headerListSize(h) + if gotSize != limit { + t.Errorf("Got size = %v; want %v", gotSize, limit) + } + } + // Try all possible combinations for small fillerLen and limit. + hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""} + minLimit := hf.Size() + for limit := minLimit; limit <= 128; limit++ { + for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ { + check(nil, limit, fillerLen) + } + } + + // Try a few tests with larger limits, plus cumulative + // tests. Since these tests are cumulative, tests[i+1].limit + // must be >= tests[i].limit + minLimit. See the comment on + // padHeaders for more info on why the limit arg has this + // restriction. + tests := []struct { + fillerLen int + limit uint32 + }{ + { + fillerLen: 64, + limit: 1024, + }, + { + fillerLen: 1024, + limit: 1286, + }, + { + fillerLen: 256, + limit: 2048, + }, + { + fillerLen: 1024, + limit: 10 * 1024, + }, + { + fillerLen: 1023, + limit: 11 * 1024, + }, + } + h := make(http.Header) + for _, tc := range tests { + check(nil, tc.limit, tc.fillerLen) + check(h, tc.limit, tc.fillerLen) + } +} + +func TestTransportChecksRequestHeaderListSize(t *testing.T) { + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + // Consume body & force client to send + // trailers before writing response. + // ioutil.ReadAll returns non-nil err for + // requests that attempt to send greater than + // maxHeaderListSize bytes of trailers, since + // those requests generate a stream reset. + ioutil.ReadAll(r.Body) + r.Body.Close() + }, + func(ts *httptest.Server) { + ts.Config.MaxHeaderBytes = 16 << 10 + }, + optOnlyServer, + optQuiet, + ) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + checkRoundTrip := func(req *http.Request, wantErr error, desc string) { + res, err := tr.RoundTrip(req) + if err != wantErr { + if res != nil { + res.Body.Close() + } + t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr) + return + } + if err == nil { + if res == nil { + t.Errorf("%v: response nil; want non-nil.", desc) + return + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK) + } + return + } + if res != nil { + t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err) + } + } + headerListSizeForRequest := func(req *http.Request) (size uint64) { + contentLen := http2actualContentLength(req) + trailers, err := http2commaSeparatedTrailers(req) + if err != nil { + t.Fatalf("headerListSizeForRequest: %v", err) + } + cc := &http2ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff} + cc.henc = hpack.NewEncoder(&cc.hbuf) + cc.mu.Lock() + hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen, nil) + cc.mu.Unlock() + if err != nil { + t.Fatalf("headerListSizeForRequest: %v", err) + } + hpackDec := hpack.NewDecoder(http2initialHeaderTableSize, func(hf hpack.HeaderField) { + size += uint64(hf.Size()) + }) + if len(hdrs) > 0 { + if _, err := hpackDec.Write(hdrs); err != nil { + t.Fatalf("headerListSizeForRequest: %v", err) + } + } + return size + } + // Create a new Request for each test, rather than reusing the + // same Request, to avoid a race when modifying req.Headers. + // See https://github.com/golang/go/issues/21316 + newRequest := func() *http.Request { + // Body must be non-nil to enable writing trailers. + body := strings.NewReader("hello") + req, err := http.NewRequest("POST", st.ts.URL, body) + if err != nil { + t.Fatalf("newRequest: NewRequest: %v", err) + } + return req + } + + // Make an arbitrary request to ensure we get the server's + // settings frame and initialize peerMaxHeaderListSize. + req := newRequest() + checkRoundTrip(req, nil, "Initial request") + + // Get the ClientConn associated with the request and validate + // peerMaxHeaderListSize. + addr := http2authorityAddr(req.URL.Scheme, req.URL.Host) + cc, err := tr.connPool().GetClientConn(req, addr) + if err != nil { + t.Fatalf("GetClientConn: %v", err) + } + cc.mu.Lock() + peerSize := cc.peerMaxHeaderListSize + cc.mu.Unlock() + st.scMu.Lock() + wantSize := uint64(st.sc.maxHeaderListSize()) + st.scMu.Unlock() + if peerSize != wantSize { + t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize) + } + + // Sanity check peerSize. (*serverConn) maxHeaderListSize adds + // 320 bytes of padding. + wantHeaderBytes := uint64(st.ts.Config.MaxHeaderBytes) + 320 + if peerSize != wantHeaderBytes { + t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes) + } + + // Pad headers & trailers, but stay under peerSize. + req = newRequest() + req.Header = make(http.Header) + req.Trailer = make(http.Header) + filler := strings.Repeat("*", 1024) + padHeaders(t, req.Trailer, peerSize, filler) + // cc.encodeHeaders adds some default headers to the request, + // so we need to leave room for those. + defaultBytes := headerListSizeForRequest(req) + padHeaders(t, req.Header, peerSize-defaultBytes, filler) + checkRoundTrip(req, nil, "Headers & Trailers under limit") + + // Add enough header bytes to push us over peerSize. + req = newRequest() + req.Header = make(http.Header) + padHeaders(t, req.Header, peerSize, filler) + checkRoundTrip(req, http2errRequestHeaderListSize, "Headers over limit") + + // Push trailers over the limit. + req = newRequest() + req.Trailer = make(http.Header) + padHeaders(t, req.Trailer, peerSize+1, filler) + checkRoundTrip(req, http2errRequestHeaderListSize, "Trailers over limit") + + // Send headers with a single large value. + req = newRequest() + filler = strings.Repeat("*", int(peerSize)) + req.Header = make(http.Header) + req.Header.Set("Big", filler) + checkRoundTrip(req, http2errRequestHeaderListSize, "Single large header") + + // Send trailers with a single large value. + req = newRequest() + req.Trailer = make(http.Header) + req.Trailer.Set("Big", filler) + checkRoundTrip(req, http2errRequestHeaderListSize, "Single large trailer") +} + +func TestTransportChecksResponseHeaderListSize(t *testing.T) { + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if e, ok := err.(http2StreamError); ok { + err = e.Cause + } + if err != http2errResponseHeaderListSize { + size := int64(0) + if res != nil { + res.Body.Close() + for k, vv := range res.Header { + for _, v := range vv { + size += int64(len(k)) + int64(len(v)) + 32 + } + } + } + return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want http2errResponseHeaderListSize", err, size) + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + switch f := f.(type) { + case *http2HeadersFrame: + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + large := strings.Repeat("a", 1<<10) + for i := 0; i < 5042; i++ { + enc.WriteField(hpack.HeaderField{Name: large, Value: large}) + } + if size, want := buf.Len(), 6329; size != want { + // Note: this number might change if + // our hpack implementation + // changes. That's fine. This is + // just a sanity check that our + // response can fit in a single + // header block fragment frame. + return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) + } + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + return nil + } + } + } + ct.run() +} + +func TestTransportCookieHeaderSplit(t *testing.T) { + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + req.Header.Add("Cookie", "a=b;c=d; e=f;") + req.Header.Add("Cookie", "e=f;g=h; ") + req.Header.Add("Cookie", "i=j") + _, err := ct.tr.RoundTrip(req) + return err + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + switch f := f.(type) { + case *http2HeadersFrame: + dec := hpack.NewDecoder(http2initialHeaderTableSize, nil) + hfs, err := dec.DecodeFull(f.HeaderBlockFragment()) + if err != nil { + return err + } + got := []string{} + want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"} + for _, hf := range hfs { + if hf.Name == "cookie" { + got = append(got, hf.Value) + } + } + if !reflect.DeepEqual(got, want) { + t.Errorf("Cookies = %#v, want %#v", got, want) + } + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + return nil + } + } + } + ct.run() +} + +// Test that the Transport returns a typed error from Response.Body.Read calls +// when the server sends an error. (here we use a panic, since that should generate +// a stream error, but others like cancel should be similar) +func TestTransportBodyReadErrorType(t *testing.T) { + doPanic := make(chan bool, 1) + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() // force headers out + <-doPanic + panic("boom") + }, + optOnlyServer, + optQuiet, + ) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + res, err := c.Get(st.ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + doPanic <- true + buf := make([]byte, 100) + n, err := res.Body.Read(buf) + got, ok := err.(http2StreamError) + want := http2StreamError{StreamID: 0x1, Code: 0x2} + if !ok || got.StreamID != want.StreamID || got.Code != want.Code { + t.Errorf("Read = %v, %#v; want error %#v", n, err, want) + } +} + +// golang.org/issue/13924 +// This used to fail after many iterations, especially with -race: +// go test -v -run=TestTransportDoubleCloseOnWriteError -count=500 -race +func TestTransportDoubleCloseOnWriteError(t *testing.T) { + var ( + mu sync.Mutex + conn net.Conn // to close if set + ) + + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + defer mu.Unlock() + if conn != nil { + conn.Close() + } + }, + optOnlyServer, + ) + defer st.Close() + + tr := &http2Transport{ + t1: &Transport{ + TLSClientConfig: tlsConfigInsecure, + }, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + tc, err := tls.Dial(network, addr, cfg) + if err != nil { + return nil, err + } + mu.Lock() + defer mu.Unlock() + conn = tc + return tc, nil + }, + } + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + c.Get(st.ts.URL) +} + +// Test that the http1 Transport.DisableKeepAlives option is respected +// and connections are closed as soon as idle. +// See golang.org/issue/14008 +func TestTransportDisableKeepAlives(t *testing.T) { + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "hi") + }, + optOnlyServer, + ) + defer st.Close() + + connClosed := make(chan struct{}) // closed on tls.Conn.Close + tr := &http2Transport{ + t1: &Transport{ + DisableKeepAlives: true, + TLSClientConfig: tlsConfigInsecure, + }, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + tc, err := tls.Dial(network, addr, cfg) + if err != nil { + return nil, err + } + return ¬eCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil + }, + } + c := &http.Client{Transport: tr} + res, err := c.Get(st.ts.URL) + if err != nil { + t.Fatal(err) + } + if _, err := ioutil.ReadAll(res.Body); err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + select { + case <-connClosed: + case <-time.After(1 * time.Second): + t.Errorf("timeout") + } + +} + +// Test concurrent requests with Transport.DisableKeepAlives. We can share connections, +// but when things are totally idle, it still needs to close. +func TestTransportDisableKeepAlives_Concurrency(t *testing.T) { + const D = 25 * time.Millisecond + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + time.Sleep(D) + io.WriteString(w, "hi") + }, + optOnlyServer, + ) + defer st.Close() + + var dials int32 + var conns sync.WaitGroup + tr := &http2Transport{ + t1: &Transport{ + DisableKeepAlives: true, + TLSClientConfig: tlsConfigInsecure, + }, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + tc, err := tls.Dial(network, addr, cfg) + if err != nil { + return nil, err + } + atomic.AddInt32(&dials, 1) + conns.Add(1) + return ¬eCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil + }, + } + c := &http.Client{Transport: tr} + var reqs sync.WaitGroup + const N = 20 + for i := 0; i < N; i++ { + reqs.Add(1) + if i == N-1 { + // For the final request, try to make all the + // others close. This isn't verified in the + // count, other than the Log statement, since + // it's so timing dependent. This test is + // really to make sure we don't interrupt a + // valid request. + time.Sleep(D * 2) + } + go func() { + defer reqs.Done() + res, err := c.Get(st.ts.URL) + if err != nil { + t.Error(err) + return + } + if _, err := ioutil.ReadAll(res.Body); err != nil { + t.Error(err) + return + } + res.Body.Close() + }() + } + reqs.Wait() + conns.Wait() + t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N) +} + +type noteCloseConn struct { + net.Conn + onceClose sync.Once + closefn func() +} + +func (c *noteCloseConn) Close() error { + c.onceClose.Do(c.closefn) + return c.Conn.Close() +} + +func isTimeout(err error) bool { + switch err := err.(type) { + case nil: + return false + case *url.Error: + return isTimeout(err.Err) + case net.Error: + return err.Timeout() + } + return false +} + +// Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent. +func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) { + testTransportResponseHeaderTimeout(t, false) +} +func TestTransportResponseHeaderTimeout_Body(t *testing.T) { + testTransportResponseHeaderTimeout(t, true) +} + +func testTransportResponseHeaderTimeout(t *testing.T, body bool) { + ct := newClientTester(t) + ct.tr.t1 = &Transport{ + ResponseHeaderTimeout: 5 * time.Millisecond, + } + ct.client = func() error { + c := &http.Client{Transport: ct.tr} + var err error + var n int64 + const bodySize = 4 << 20 + if body { + _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize)) + } else { + _, err = c.Get("https://dummy.tld/") + } + if !isTimeout(err) { + t.Errorf("client expected timeout error; got %#v", err) + } + if body && n != bodySize { + t.Errorf("only read %d bytes of body; want %d", n, bodySize) + } + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err != nil { + t.Logf("ReadFrame: %v", err) + return nil + } + switch f := f.(type) { + case *http2DataFrame: + dataLen := len(f.Data()) + if dataLen > 0 { + if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { + return err + } + if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { + return err + } + } + case *http2RSTStreamFrame: + if f.StreamID == 1 && f.ErrCode == http2ErrCodeCancel { + return nil + } + } + } + } + ct.run() +} + +func TestTransportDisableCompression(t *testing.T) { + const body = "sup" + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + want := http.Header{ + "User-Agent": []string{hdrUserAgentValue}, + } + if !reflect.DeepEqual(r.Header, want) { + t.Errorf("request headers = %v; want %v", r.Header, want) + } + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{ + t1: &Transport{ + DisableCompression: true, + TLSClientConfig: tlsConfigInsecure, + }, + } + defer tr.CloseIdleConnections() + + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() +} + +// RFC 7540 section 8.1.2.2 +func TestTransportRejectsConnHeaders(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + var got []string + for k := range r.Header { + got = append(got, k) + } + sort.Strings(got) + w.Header().Set("Got-Header", strings.Join(got, ",")) + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + tests := []struct { + key string + value []string + want string + }{ + { + key: "Upgrade", + value: []string{"anything"}, + want: "ERROR: http2: invalid Upgrade request header: [\"anything\"]", + }, + { + key: "Connection", + value: []string{"foo"}, + want: "ERROR: http2: invalid Connection request header: [\"foo\"]", + }, + { + key: "Connection", + value: []string{"close"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Connection", + value: []string{"CLoSe"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Connection", + value: []string{"close", "something-else"}, + want: "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]", + }, + { + key: "Connection", + value: []string{"keep-alive"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Connection", + value: []string{"Keep-ALIVE"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Proxy-Connection", // just deleted and ignored + value: []string{"keep-alive"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Transfer-Encoding", + value: []string{""}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Transfer-Encoding", + value: []string{"foo"}, + want: "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]", + }, + { + key: "Transfer-Encoding", + value: []string{"chunked"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Transfer-Encoding", + value: []string{"chunKed"}, // Kelvin sign + want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunKed\"]", + }, + { + key: "Transfer-Encoding", + value: []string{"chunked", "other"}, + want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]", + }, + { + key: "Content-Length", + value: []string{"123"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Keep-Alive", + value: []string{"doop"}, + want: "Accept-Encoding,User-Agent", + }, + } + + for _, tt := range tests { + req, _ := http.NewRequest("GET", st.ts.URL, nil) + req.Header[tt.key] = tt.value + res, err := tr.RoundTrip(req) + var got string + if err != nil { + got = fmt.Sprintf("ERROR: %v", err) + } else { + got = res.Header.Get("Got-Header") + res.Body.Close() + } + if got != tt.want { + t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want) + } + } +} + +// Reject content-length headers containing a sign. +// See https://golang.org/issue/39017 +func TestTransportRejectsContentLengthWithSign(t *testing.T) { + tests := []struct { + name string + cl []string + wantCL string + }{ + { + name: "proper content-length", + cl: []string{"3"}, + wantCL: "3", + }, + { + name: "ignore cl with plus sign", + cl: []string{"+3"}, + wantCL: "", + }, + { + name: "ignore cl with minus sign", + cl: []string{"-3"}, + wantCL: "", + }, + { + name: "max int64, for safe uint64->int64 conversion", + cl: []string{"9223372036854775807"}, + wantCL: "9223372036854775807", + }, + { + name: "overflows int64, so ignored", + cl: []string{"9223372036854775808"}, + wantCL: "", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", tt.cl[0]) + }, optOnlyServer) + defer st.Close() + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + req, _ := http.NewRequest("HEAD", st.ts.URL, nil) + res, err := tr.RoundTrip(req) + + var got string + if err != nil { + got = fmt.Sprintf("ERROR: %v", err) + } else { + got = res.Header.Get("Content-Length") + res.Body.Close() + } + + if got != tt.wantCL { + t.Fatalf("Got: %q\nWant: %q", got, tt.wantCL) + } + }) + } +} + +// golang.org/issue/14048 +func TestTransportFailsOnInvalidHeaders(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + var got []string + for k := range r.Header { + got = append(got, k) + } + sort.Strings(got) + w.Header().Set("Got-Header", strings.Join(got, ",")) + }, optOnlyServer) + defer st.Close() + + tests := [...]struct { + h http.Header + wantErr string + }{ + 0: { + h: http.Header{"with space": {"foo"}}, + wantErr: `invalid HTTP header name "with space"`, + }, + 1: { + h: http.Header{"name": {"Брэд"}}, + wantErr: "", // okay + }, + 2: { + h: http.Header{"имя": {"Brad"}}, + wantErr: `invalid HTTP header name "имя"`, + }, + 3: { + h: http.Header{"foo": {"foo\x01bar"}}, + wantErr: `invalid HTTP header value "foo\x01bar" for header "foo"`, + }, + } + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + for i, tt := range tests { + req, _ := http.NewRequest("GET", st.ts.URL, nil) + req.Header = tt.h + res, err := tr.RoundTrip(req) + var bad bool + if tt.wantErr == "" { + if err != nil { + bad = true + t.Errorf("case %d: error = %v; want no error", i, err) + } + } else { + if !strings.Contains(fmt.Sprint(err), tt.wantErr) { + bad = true + t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr) + } + } + if err == nil { + if bad { + t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header")) + } + res.Body.Close() + } + } +} + +// Tests that gzipReader doesn't crash on a second Read call following +// the first Read call's gzip.NewReader returning an error. +func TestGzipReader_DoubleReadCrash(t *testing.T) { + gz := &http2gzipReader{ + body: ioutil.NopCloser(strings.NewReader("0123456789")), + } + var buf [1]byte + n, err1 := gz.Read(buf[:]) + if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") { + t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1) + } + n, err2 := gz.Read(buf[:]) + if n != 0 || err2 != err1 { + t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1) + } +} + +func TestTransportNewTLSConfig(t *testing.T) { + tests := [...]struct { + conf *tls.Config + host string + want *tls.Config + }{ + // Normal case. + 0: { + conf: nil, + host: "foo.com", + want: &tls.Config{ + ServerName: "foo.com", + NextProtos: []string{http2NextProtoTLS}, + }, + }, + + // User-provided name (bar.com) takes precedence: + 1: { + conf: &tls.Config{ + ServerName: "bar.com", + }, + host: "foo.com", + want: &tls.Config{ + ServerName: "bar.com", + NextProtos: []string{http2NextProtoTLS}, + }, + }, + + // NextProto is prepended: + 2: { + conf: &tls.Config{ + NextProtos: []string{"foo", "bar"}, + }, + host: "example.com", + want: &tls.Config{ + ServerName: "example.com", + NextProtos: []string{http2NextProtoTLS, "foo", "bar"}, + }, + }, + + // NextProto is not duplicated: + 3: { + conf: &tls.Config{ + NextProtos: []string{"foo", "bar", http2NextProtoTLS}, + }, + host: "example.com", + want: &tls.Config{ + ServerName: "example.com", + NextProtos: []string{"foo", "bar", http2NextProtoTLS}, + }, + }, + } + for i, tt := range tests { + // Ignore the session ticket keys part, which ends up populating + // unexported fields in the Config: + if tt.conf != nil { + tt.conf.SessionTicketsDisabled = true + } + + tr := &http2Transport{ + t1: &Transport{ + TLSClientConfig: tt.conf, + }, + } + got := tr.newTLSConfig(tt.host) + + got.SessionTicketsDisabled = false + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("%d. got %#v; want %#v", i, got, tt.want) + } + } +} + +// The Google GFE responds to HEAD requests with a HEADERS frame +// without END_STREAM, followed by a 0-length DATA frame with +// END_STREAM. Make sure we don't get confused by that. (We did.) +func TestTransportReadHeadResponse(t *testing.T) { + ct := newClientTester(t) + clientDone := make(chan struct{}) + ct.client = func() error { + defer close(clientDone) + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return err + } + if res.ContentLength != 123 { + return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("ReadAll: %v", err) + } + if len(slurp) > 0 { + return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) + } + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err != nil { + t.Logf("ReadFrame: %v", err) + return nil + } + hf, ok := f.(*http2HeadersFrame) + if !ok { + continue + } + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, // as the GFE does + BlockFragment: buf.Bytes(), + }) + ct.fr.WriteData(hf.StreamID, true, nil) + + <-clientDone + return nil + } + } + ct.run() +} + +func TestTransportReadHeadResponseWithBody(t *testing.T) { + // This test use not valid response format. + // Discarding logger output to not spam tests output. + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + response := "redirecting to /elsewhere" + ct := newClientTester(t) + clientDone := make(chan struct{}) + ct.client = func() error { + defer close(clientDone) + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return err + } + if res.ContentLength != int64(len(response)) { + return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response)) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("ReadAll: %v", err) + } + if len(slurp) > 0 { + return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) + } + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err != nil { + t.Logf("ReadFrame: %v", err) + return nil + } + hf, ok := f.(*http2HeadersFrame) + if !ok { + continue + } + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + ct.fr.WriteData(hf.StreamID, true, []byte(response)) + + <-clientDone + return nil + } + } + ct.run() +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (int, error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +// golang.org/issue/15425: test that a handler closing the request +// body doesn't terminate the stream to the peer. (It just stops +// readability from the handler's side, and eventually the client +// runs out of flow control tokens) +func TestTransportHandlerBodyClose(t *testing.T) { + const bodySize = 10 << 20 + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + io.Copy(w, io.LimitReader(neverEnding('A'), bodySize)) + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + g0 := runtime.NumGoroutine() + + const numReq = 10 + for i := 0; i < numReq; i++ { + req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + if n != bodySize || err != nil { + t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize) + } + } + tr.CloseIdleConnections() + + if !waitCondition(5*time.Second, 100*time.Millisecond, func() bool { + gd := runtime.NumGoroutine() - g0 + return gd < numReq/2 + }) { + t.Errorf("appeared to leak goroutines") + } +} + +// https://golang.org/issue/15930 +func TestTransportFlowControl(t *testing.T) { + const bufLen = 64 << 10 + var total int64 = 100 << 20 // 100MB + if testing.Short() { + total = 10 << 20 + } + + var wrote int64 // updated atomically + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + b := make([]byte, bufLen) + for wrote < total { + n, err := w.Write(b) + atomic.AddInt64(&wrote, int64(n)) + if err != nil { + t.Errorf("ResponseWriter.Write error: %v", err) + break + } + w.(http.Flusher).Flush() + } + }, optOnlyServer) + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal("NewRequest error:", err) + } + resp, err := tr.RoundTrip(req) + if err != nil { + t.Fatal("RoundTrip error:", err) + } + defer resp.Body.Close() + + var read int64 + b := make([]byte, bufLen) + for { + n, err := resp.Body.Read(b) + if err == io.EOF { + break + } + if err != nil { + t.Fatal("Read error:", err) + } + read += int64(n) + + const max = http2transportDefaultStreamFlow + if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max { + t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read) + } + + // Let the server get ahead of the client. + time.Sleep(1 * time.Millisecond) + } +} + +// golang.org/issue/14627 -- if the server sends a GOAWAY frame, make +// the Transport remember it and return it back to users (via +// RoundTrip or request body reads) if needed (e.g. if the server +// proceeds to close the TCP connection before the client gets its +// response) +func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) { + testTransportUsesGoAwayDebugError(t, false) +} + +func TestTransportUsesGoAwayDebugError_Body(t *testing.T) { + testTransportUsesGoAwayDebugError(t, true) +} + +func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { + ct := newClientTester(t) + clientDone := make(chan struct{}) + + const goAwayErrCode = http2ErrCodeHTTP11Required // arbitrary + const goAwayDebugData = "some debug data" + + ct.client = func() error { + defer close(clientDone) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if failMidBody { + if err != nil { + return fmt.Errorf("unexpected client RoundTrip error: %v", err) + } + _, err = io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + } + want := http2GoAwayError{ + LastStreamID: 5, + ErrCode: goAwayErrCode, + DebugData: goAwayDebugData, + } + if !reflect.DeepEqual(err, want) { + t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want) + } + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err != nil { + t.Logf("ReadFrame: %v", err) + return nil + } + hf, ok := f.(*http2HeadersFrame) + if !ok { + continue + } + if failMidBody { + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + } + // Write two GOAWAY frames, to test that the Transport takes + // the interesting parts of both. + ct.fr.WriteGoAway(5, http2ErrCodeNo, []byte(goAwayDebugData)) + ct.fr.WriteGoAway(5, goAwayErrCode, nil) + ct.sc.(*net.TCPConn).CloseWrite() + if runtime.GOOS == "plan9" { + // CloseWrite not supported on Plan 9; Issue 17906 + ct.sc.(*net.TCPConn).Close() + } + <-clientDone + return nil + } + } + ct.run() +} + +func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { + ct := newClientTester(t) + + clientClosed := make(chan struct{}) + serverWroteFirstByte := make(chan struct{}) + + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return err + } + <-serverWroteFirstByte + + if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { + return fmt.Errorf("body read = %v, %v; want 1, nil", n, err) + } + res.Body.Close() // leaving 4999 bytes unread + close(clientClosed) + + return nil + } + ct.server = func() error { + ct.greet() + + var hf *http2HeadersFrame + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + switch f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame: + continue + } + var ok bool + hf, ok = f.(*http2HeadersFrame) + if !ok { + return fmt.Errorf("Got %T; want HeadersFrame", f) + } + break + } + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + + // Two cases: + // - Send one DATA frame with 5000 bytes. + // - Send two DATA frames with 1 and 4999 bytes each. + // + // In both cases, the client should consume one byte of data, + // refund that byte, then refund the following 4999 bytes. + // + // In the second case, the server waits for the client connection to + // close before seconding the second DATA frame. This tests the case + // where the client receives a DATA frame after it has reset the stream. + if oneDataFrame { + ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000)) + close(serverWroteFirstByte) + <-clientClosed + } else { + ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1)) + close(serverWroteFirstByte) + <-clientClosed + ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) + } + + waitingFor := "RSTStreamFrame" + sawRST := false + sawWUF := false + for !sawRST && !sawWUF { + f, err := ct.fr.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for %s: %v", waitingFor, err) + } + switch f := f.(type) { + case *http2SettingsFrame: + case *http2RSTStreamFrame: + if sawRST { + return fmt.Errorf("saw second RSTStreamFrame: %v", http2summarizeFrame(f)) + } + if f.ErrCode != http2ErrCodeCancel { + return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", http2summarizeFrame(f)) + } + sawRST = true + case *http2WindowUpdateFrame: + if sawWUF { + return fmt.Errorf("saw second WindowUpdateFrame: %v", http2summarizeFrame(f)) + } + if f.Increment != 4999 { + return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", http2summarizeFrame(f)) + } + sawWUF = true + default: + return fmt.Errorf("Unexpected frame: %v", http2summarizeFrame(f)) + } + } + return nil + } + ct.run() +} + +// See golang.org/issue/16481 +func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) { + testTransportReturnsUnusedFlowControl(t, true) +} + +// See golang.org/issue/20469 +func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) { + testTransportReturnsUnusedFlowControl(t, false) +} + +// Issue 16612: adjust flow control on open streams when transport +// receives SETTINGS with INITIAL_WINDOW_SIZE from server. +func TestTransportAdjustsFlowControl(t *testing.T) { + ct := newClientTester(t) + clientDone := make(chan struct{}) + + const bodySize = 1 << 20 + + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + if runtime.GOOS == "plan9" { + // CloseWrite not supported on Plan 9; Issue 17906 + defer ct.cc.(*net.TCPConn).Close() + } + defer close(clientDone) + + req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return err + } + res.Body.Close() + return nil + } + ct.server = func() error { + _, err := io.ReadFull(ct.sc, make([]byte, len(http2ClientPreface))) + if err != nil { + return fmt.Errorf("reading client preface: %v", err) + } + + var gotBytes int64 + var sentSettings bool + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + return nil + default: + return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + } + switch f := f.(type) { + case *http2DataFrame: + gotBytes += int64(len(f.Data())) + // After we've got half the client's + // initial flow control window's worth + // of request body data, give it just + // enough flow control to finish. + if gotBytes >= http2initialWindowSize/2 && !sentSettings { + sentSettings = true + + ct.fr.WriteSettings(http2Setting{ID: http2SettingInitialWindowSize, Val: bodySize}) + ct.fr.WriteWindowUpdate(0, bodySize) + ct.fr.WriteSettingsAck() + } + + if f.StreamEnded() { + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + } + } + } + } + ct.run() +} + +// See golang.org/issue/16556 +func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { + ct := newClientTester(t) + + unblockClient := make(chan bool, 1) + + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return err + } + defer res.Body.Close() + <-unblockClient + return nil + } + ct.server = func() error { + ct.greet() + + var hf *http2HeadersFrame + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + switch f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame: + continue + } + var ok bool + hf, ok = f.(*http2HeadersFrame) + if !ok { + return fmt.Errorf("Got %T; want HeadersFrame", f) + } + break + } + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + pad := make([]byte, 5) + ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream + + f, err := ct.readNonSettingsFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for first WindowUpdateFrame: %v", err) + } + wantBack := uint32(len(pad)) + 1 // one byte for the length of the padding + if wuf, ok := f.(*http2WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID != 0 { + return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, http2summarizeFrame(f)) + } + + f, err = ct.readNonSettingsFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for second WindowUpdateFrame: %v", err) + } + if wuf, ok := f.(*http2WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID == 0 { + return fmt.Errorf("Expected stream WindowUpdateFrame for %d bytes; got %v", wantBack, http2summarizeFrame(f)) + } + unblockClient <- true + return nil + } + ct.run() +} + +// golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a +// StreamError as a result of the response HEADERS +func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { + ct := newClientTester(t) + + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err == nil { + res.Body.Close() + return errors.New("unexpected successful GET") + } + want := http2StreamError{1, http2ErrCodeProtocol, http2headerFieldNameError(" content-type")} + if !reflect.DeepEqual(want, err) { + t.Errorf("RoundTrip error = %#v; want %#v", err, want) + } + return nil + } + ct.server = func() error { + ct.greet() + + hf, err := ct.firstHeaders() + if err != nil { + return err + } + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + + for { + fr, err := ct.readFrame() + if err != nil { + return fmt.Errorf("error waiting for RST_STREAM from client: %v", err) + } + if _, ok := fr.(*http2SettingsFrame); ok { + continue + } + if rst, ok := fr.(*http2RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != http2ErrCodeProtocol { + t.Errorf("Frame = %v; want RST_STREAM for stream 1 with http2ErrCodeProtocol", http2summarizeFrame(fr)) + } + break + } + + return nil + } + ct.run() +} + +// byteAndEOFReader returns is in an io.Reader which reads one byte +// (the underlying byte) and io.EOF at once in its Read call. +type byteAndEOFReader byte + +func (b byteAndEOFReader) Read(p []byte) (n int, err error) { + if len(p) == 0 { + panic("unexpected useless call") + } + p[0] = byte(b) + return 1, io.EOF +} + +// Issue 16788: the Transport had a regression where it started +// sending a spurious DATA frame with a duplicate END_STREAM bit after +// the request body writer goroutine had already read an EOF from the +// Request.Body and included the END_STREAM on a data-carrying DATA +// frame. +// +// Notably, to trigger this, the requests need to use a Request.Body +// which returns (non-0, io.EOF) and also needs to set the ContentLength +// explicitly. +func TestTransportBodyDoubleEndStream(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + // Nothing. + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + for i := 0; i < 2; i++ { + req, _ := http.NewRequest("POST", st.ts.URL, byteAndEOFReader('a')) + req.ContentLength = 1 + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatalf("failure on req %d: %v", i+1, err) + } + defer res.Body.Close() + } +} + +// golang.org/issue/16847, golang.org/issue/19103 +func TestTransportRequestPathPseudo(t *testing.T) { + type result struct { + path string + err string + } + tests := []struct { + req *http.Request + want result + }{ + 0: { + req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Host: "foo.com", + Path: "/foo", + }, + }, + want: result{path: "/foo"}, + }, + // In Go 1.7, we accepted paths of "//foo". + // In Go 1.8, we rejected it (issue 16847). + // In Go 1.9, we accepted it again (issue 19103). + 1: { + req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Host: "foo.com", + Path: "//foo", + }, + }, + want: result{path: "//foo"}, + }, + + // Opaque with //$Matching_Hostname/path + 2: { + req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "https", + Opaque: "//foo.com/path", + Host: "foo.com", + Path: "/ignored", + }, + }, + want: result{path: "/path"}, + }, + + // Opaque with some other Request.Host instead: + 3: { + req: &http.Request{ + Method: "GET", + Host: "bar.com", + URL: &url.URL{ + Scheme: "https", + Opaque: "//bar.com/path", + Host: "foo.com", + Path: "/ignored", + }, + }, + want: result{path: "/path"}, + }, + + // Opaque without the leading "//": + 4: { + req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Opaque: "/path", + Host: "foo.com", + Path: "/ignored", + }, + }, + want: result{path: "/path"}, + }, + + // Opaque we can't handle: + 5: { + req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "https", + Opaque: "//unknown_host/path", + Host: "foo.com", + Path: "/ignored", + }, + }, + want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`}, + }, + + // A CONNECT request: + 6: { + req: &http.Request{ + Method: "CONNECT", + URL: &url.URL{ + Host: "foo.com", + }, + }, + want: result{}, + }, + } + for i, tt := range tests { + cc := &http2ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff} + cc.henc = hpack.NewEncoder(&cc.hbuf) + cc.mu.Lock() + hdrs, err := cc.encodeHeaders(tt.req, false, "", -1, nil) + cc.mu.Unlock() + var got result + hpackDec := hpack.NewDecoder(http2initialHeaderTableSize, func(f hpack.HeaderField) { + if f.Name == ":path" { + got.path = f.Value + } + }) + if err != nil { + got.err = err.Error() + } else if len(hdrs) > 0 { + if _, err := hpackDec.Write(hdrs); err != nil { + t.Errorf("%d. bogus hpack: %v", i, err) + continue + } + } + if got != tt.want { + t.Errorf("%d. got %+v; want %+v", i, got, tt.want) + } + + } + +} + +// golang.org/issue/17071 -- don't sniff the first byte of the request body +// before we've determined that the ClientConn is usable. +func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { + const body = "foo" + req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body))) + cc := &http2ClientConn{ + closed: true, + reqHeaderMu: make(chan struct{}, 1), + } + _, err := cc.RoundTrip(req) + if err != http2errClientConnUnusable { + t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err) + } + slurp, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("ReadAll = %v", err) + } + if string(slurp) != body { + t.Errorf("Body = %q; want %q", slurp, body) + } +} + +func TestClientConnPing(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) + defer st.Close() + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + ctx := context.Background() + cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) + if err != nil { + t.Fatal(err) + } + if err = cc.Ping(context.Background()); err != nil { + t.Fatal(err) + } +} + +// Issue 16974: if the server sent a DATA frame after the user +// canceled the Transport's Request, the Transport previously wrote to a +// closed pipe, got an error, and ended up closing the whole TCP +// connection. +func TestTransportCancelDataResponseRace(t *testing.T) { + cancel := make(chan struct{}) + clientGotError := make(chan bool, 1) + + const msg = "Hello." + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/hello") { + time.Sleep(50 * time.Millisecond) + io.WriteString(w, msg) + return + } + for i := 0; i < 50; i++ { + io.WriteString(w, "Some data.") + w.(http.Flusher).Flush() + if i == 2 { + close(cancel) + <-clientGotError + } + time.Sleep(10 * time.Millisecond) + } + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + c := &http.Client{Transport: tr} + req, _ := http.NewRequest("GET", st.ts.URL, nil) + req.Cancel = cancel + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + if _, err = io.Copy(ioutil.Discard, res.Body); err == nil { + t.Fatal("unexpected success") + } + clientGotError <- true + + res, err = c.Get(st.ts.URL + "/hello") + if err != nil { + t.Fatal(err) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != msg { + t.Errorf("Got = %q; want %q", slurp, msg) + } +} + +// Issue 21316: It should be safe to reuse an http.Request after the +// request has completed. +func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + io.WriteString(w, "body") + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + req, _ := http.NewRequest("GET", st.ts.URL, nil) + resp, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil { + t.Fatalf("error reading response body: %v", err) + } + if err := resp.Body.Close(); err != nil { + t.Fatalf("error closing response body: %v", err) + } + + // This access of req.Header should not race with code in the transport. + req.Header = http.Header{} +} + +func TestTransportCloseAfterLostPing(t *testing.T) { + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.tr.PingTimeout = 1 * time.Second + ct.tr.ReadIdleTimeout = 1 * time.Second + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + _, err := ct.tr.RoundTrip(req) + if err == nil || !strings.Contains(err.Error(), "client connection lost") { + return fmt.Errorf("expected to get error about \"connection lost\", got %v", err) + } + return nil + } + ct.server = func() error { + ct.greet() + <-clientDone + return nil + } + ct.run() +} + +func TestTransportPingWriteBlocks(t *testing.T) { + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) {}, + optOnlyServer, + ) + defer st.Close() + tr := &http2Transport{ + t1: &Transport{ + TLSClientConfig: tlsConfigInsecure, + }, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + s, c := net.Pipe() // unbuffered, unlike a TCP conn + go func() { + // Read initial handshake frames. + // Without this, we block indefinitely in newClientConn, + // and never get to the point of sending a PING. + var buf [1024]byte + s.Read(buf[:]) + }() + return c, nil + }, + PingTimeout: 1 * time.Millisecond, + ReadIdleTimeout: 1 * time.Millisecond, + } + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + _, err := c.Get(st.ts.URL) + if err == nil { + t.Fatalf("Get = nil, want error") + } +} + +func TestTransportPingWhenReading(t *testing.T) { + testCases := []struct { + name string + readIdleTimeout time.Duration + deadline time.Duration + expectedPingCount int + }{ + { + name: "two pings", + readIdleTimeout: 100 * time.Millisecond, + deadline: time.Second, + expectedPingCount: 2, + }, + { + name: "zero ping", + readIdleTimeout: time.Second, + deadline: 200 * time.Millisecond, + expectedPingCount: 0, + }, + { + name: "0 readIdleTimeout means no ping", + readIdleTimeout: 0 * time.Millisecond, + deadline: 500 * time.Millisecond, + expectedPingCount: 0, + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount) + }) + } +} + +func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) { + var pingCount int + ct := newClientTester(t) + ct.tr.ReadIdleTimeout = readIdleTimeout + + ctx, cancel := context.WithTimeout(context.Background(), deadline) + defer cancel() + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + if runtime.GOOS == "plan9" { + // CloseWrite not supported on Plan 9; Issue 17906 + defer ct.cc.(*net.TCPConn).Close() + } + req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200) + } + _, err = ioutil.ReadAll(res.Body) + if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) { + return nil + } + + cancel() + return err + } + + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + var streamID uint32 + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-ctx.Done(): + // If the client's done, it + // will have reported any + // errors on its side. + return nil + default: + return err + } + } + switch f := f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame: + case *http2HeadersFrame: + if !f.HeadersEnded() { + return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + } + enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + streamID = f.StreamID + case *http2PingFrame: + pingCount++ + if pingCount == expectedPingCount { + if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil { + return err + } + } + if err := ct.fr.WritePing(true, f.Data); err != nil { + return err + } + case *http2RSTStreamFrame: + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} + +func TestTransportRetryAfterGOAWAY(t *testing.T) { + var dialer struct { + sync.Mutex + count int + } + ct1 := make(chan *clientTester) + ct2 := make(chan *clientTester) + + ln := newLocalListener(t) + defer ln.Close() + + tr := &http2Transport{ + t1: &Transport{ + TLSClientConfig: tlsConfigInsecure, + }, + } + tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { + dialer.Lock() + defer dialer.Unlock() + dialer.count++ + if dialer.count == 3 { + return nil, errors.New("unexpected number of dials") + } + cc, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + return nil, fmt.Errorf("dial error: %v", err) + } + sc, err := ln.Accept() + if err != nil { + return nil, fmt.Errorf("accept error: %v", err) + } + ct := &clientTester{ + t: t, + tr: tr, + cc: cc, + sc: sc, + fr: http2NewFramer(sc, sc), + } + switch dialer.count { + case 1: + ct1 <- ct + case 2: + ct2 <- ct + } + return cc, nil + } + + errs := make(chan error, 3) + + // Client. + go func() { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := tr.RoundTrip(req) + if res != nil { + res.Body.Close() + if got := res.Header.Get("Foo"); got != "bar" { + err = fmt.Errorf("foo header = %q; want bar", got) + } + } + if err != nil { + err = fmt.Errorf("RoundTrip: %v", err) + } + errs <- err + }() + + connToClose := make(chan io.Closer, 2) + + // Server for the first request. + go func() { + ct := <-ct1 + + connToClose <- ct.cc + ct.greet() + hf, err := ct.firstHeaders() + if err != nil { + errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err) + return + } + t.Logf("server1 got %v", hf) + if err := ct.fr.WriteGoAway(0 /*max id*/, http2ErrCodeNo, nil); err != nil { + errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err) + return + } + errs <- nil + }() + + // Server for the second request. + go func() { + ct := <-ct2 + + connToClose <- ct.cc + ct.greet() + hf, err := ct.firstHeaders() + if err != nil { + errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err) + return + } + t.Logf("server2 got %v", hf) + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) + err = ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + if err != nil { + errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err) + } else { + errs <- nil + } + }() + + for k := 0; k < 3; k++ { + err := <-errs + if err != nil { + t.Error(err) + } + } + + close(connToClose) + for c := range connToClose { + c.Close() + } +} + +func TestTransportRetryAfterRefusedStream(t *testing.T) { + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + if runtime.GOOS == "plan9" { + // CloseWrite not supported on Plan 9; Issue 17906 + defer ct.cc.(*net.TCPConn).Close() + } + defer close(clientDone) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + resp, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + resp.Body.Close() + if resp.StatusCode != 204 { + return fmt.Errorf("Status = %v; want 204", resp.StatusCode) + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + nreq := 0 + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + return nil + default: + return err + } + } + switch f := f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame: + case *http2HeadersFrame: + if !f.HeadersEnded() { + return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + } + nreq++ + if nreq == 1 { + ct.fr.WriteRSTStream(f.StreamID, http2ErrCodeRefusedStream) + } else { + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + } + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} + +func TestTransportResponseDataBeforeHeaders(t *testing.T) { + // This test use not valid response format. + // Discarding logger output to not spam tests output. + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + if runtime.GOOS == "plan9" { + // CloseWrite not supported on Plan 9; Issue 17906 + defer ct.cc.(*net.TCPConn).Close() + } + req := httptest.NewRequest("GET", "https://dummy.tld/", nil) + // First request is normal to ensure the check is per stream and not per connection. + _, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip expected no error, got: %v", err) + } + // Second request returns a DATA frame with no HEADERS. + resp, err := ct.tr.RoundTrip(req) + if err == nil { + return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) + } + if err, ok := err.(http2StreamError); !ok || err.Code != http2ErrCodeProtocol { + return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err) + } + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + switch f := f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame, *http2RSTStreamFrame: + case *http2HeadersFrame: + switch f.StreamID { + case 1: + // Send a valid response to first request. + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + case 3: + ct.fr.WriteData(f.StreamID, true, []byte("payload")) + } + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} + +func TestTransportRequestsLowServerLimit(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + }, optOnlyServer, func(s *http2Server) { + s.MaxConcurrentStreams = 1 + }) + defer st.Close() + + var ( + connCountMu sync.Mutex + connCount int + ) + tr := &http2Transport{ + t1: &Transport{ + TLSClientConfig: tlsConfigInsecure, + }, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + connCountMu.Lock() + defer connCountMu.Unlock() + connCount++ + return tls.Dial(network, addr, cfg) + }, + } + defer tr.CloseIdleConnections() + + const reqCount = 3 + for i := 0; i < reqCount; i++ { + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if got, want := res.StatusCode, 200; got != want { + t.Errorf("StatusCode = %v; want %v", got, want) + } + if res != nil && res.Body != nil { + res.Body.Close() + } + } + + if connCount != 1 { + t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount) + } +} + +// tests Transport.StrictMaxConcurrentStreams +func TestTransportRequestsStallAtServerLimit(t *testing.T) { + const maxConcurrent = 2 + + greet := make(chan struct{}) // server sends initial SETTINGS frame + gotRequest := make(chan struct{}) // server received a request + clientDone := make(chan struct{}) + + // Collect errors from goroutines. + var wg sync.WaitGroup + errs := make(chan error, 100) + defer func() { + wg.Wait() + close(errs) + for err := range errs { + t.Error(err) + } + }() + + // We will send maxConcurrent+2 requests. This checker goroutine waits for the + // following stages: + // 1. The first maxConcurrent requests are received by the server. + // 2. The client will cancel the next request + // 3. The server is unblocked so it can service the first maxConcurrent requests + // 4. The client will send the final request + wg.Add(1) + unblockClient := make(chan struct{}) + clientRequestCancelled := make(chan struct{}) + unblockServer := make(chan struct{}) + go func() { + defer wg.Done() + // Stage 1. + for k := 0; k < maxConcurrent; k++ { + <-gotRequest + } + // Stage 2. + close(unblockClient) + <-clientRequestCancelled + // Stage 3: give some time for the final RoundTrip call to be scheduled and + // verify that the final request is not sent. + time.Sleep(50 * time.Millisecond) + select { + case <-gotRequest: + errs <- errors.New("last request did not stall") + close(unblockServer) + return + default: + } + close(unblockServer) + // Stage 4. + <-gotRequest + }() + + ct := newClientTester(t) + ct.tr.StrictMaxConcurrentStreams = true + ct.client = func() error { + var wg sync.WaitGroup + defer func() { + wg.Wait() + close(clientDone) + ct.cc.(*net.TCPConn).CloseWrite() + if runtime.GOOS == "plan9" { + // CloseWrite not supported on Plan 9; Issue 17906 + ct.cc.(*net.TCPConn).Close() + } + }() + for k := 0; k < maxConcurrent+2; k++ { + wg.Add(1) + go func(k int) { + defer wg.Done() + // Don't send the second request until after receiving SETTINGS from the server + // to avoid a race where we use the default SettingMaxConcurrentStreams, which + // is much larger than maxConcurrent. We have to send the first request before + // waiting because the first request triggers the dial and greet. + if k > 0 { + <-greet + } + // Block until maxConcurrent requests are sent before sending any more. + if k >= maxConcurrent { + <-unblockClient + } + body := newStaticCloseChecker("") + req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body) + if k == maxConcurrent { + // This request will be canceled. + cancel := make(chan struct{}) + req.Cancel = cancel + close(cancel) + _, err := ct.tr.RoundTrip(req) + close(clientRequestCancelled) + if err == nil { + errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k) + return + } + } else { + resp, err := ct.tr.RoundTrip(req) + if err != nil { + errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) + return + } + ioutil.ReadAll(resp.Body) + resp.Body.Close() + if resp.StatusCode != 204 { + errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode) + return + } + } + if err := body.isClosed(); err != nil { + errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) + } + }(k) + } + return nil + } + + ct.server = func() error { + var wg sync.WaitGroup + defer wg.Wait() + + ct.greet(http2Setting{http2SettingMaxConcurrentStreams, maxConcurrent}) + + // Server write loop. + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + writeResp := make(chan uint32, maxConcurrent+1) + + wg.Add(1) + go func() { + defer wg.Done() + <-unblockServer + for id := range writeResp { + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: id, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + } + }() + + // Server read loop. + var nreq int + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it will have reported any errors on its side. + return nil + default: + return err + } + } + switch f := f.(type) { + case *http2WindowUpdateFrame: + case *http2SettingsFrame: + // Wait for the client SETTINGS ack until ending the greet. + close(greet) + case *http2HeadersFrame: + if !f.HeadersEnded() { + return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + } + gotRequest <- struct{}{} + nreq++ + writeResp <- f.StreamID + if nreq == maxConcurrent+1 { + close(writeResp) + } + case *http2DataFrame: + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + + ct.run() +} + +func TestAuthorityAddr(t *testing.T) { + tests := []struct { + scheme, authority string + want string + }{ + {"http", "foo.com", "foo.com:80"}, + {"https", "foo.com", "foo.com:443"}, + {"https", "foo.com:1234", "foo.com:1234"}, + {"https", "1.2.3.4:1234", "1.2.3.4:1234"}, + {"https", "1.2.3.4", "1.2.3.4:443"}, + {"https", "[::1]:1234", "[::1]:1234"}, + {"https", "[::1]", "[::1]:443"}, + } + for _, tt := range tests { + got := http2authorityAddr(tt.scheme, tt.authority) + if got != tt.want { + t.Errorf("http2authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want) + } + } +} + +// Issue 20448: stop allocating for DATA frames' payload after +// Response.Body.Close is called. +func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { + megabyteZero := make([]byte, 1<<20) + + writeErr := make(chan error, 1) + + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() + var sum int64 + for i := 0; i < 100; i++ { + n, err := w.Write(megabyteZero) + sum += int64(n) + if err != nil { + writeErr <- err + return + } + } + t.Logf("wrote all %d bytes", sum) + writeErr <- nil + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + res, err := c.Get(st.ts.URL) + if err != nil { + t.Fatal(err) + } + var buf [1]byte + if _, err := res.Body.Read(buf[:]); err != nil { + t.Error(err) + } + if err := res.Body.Close(); err != nil { + t.Error(err) + } + + trb, ok := res.Body.(http2transportResponseBody) + if !ok { + t.Fatalf("res.Body = %T; want transportResponseBody", res.Body) + } + if trb.cs.bufPipe.b != nil { + t.Errorf("response body pipe is still open") + } + + gotErr := <-writeErr + if gotErr == nil { + t.Errorf("Handler unexpectedly managed to write its entire response without getting an error") + } else if gotErr != http2errStreamClosed { + t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr) + } +} + +// Issue 18891: make sure Request.Body == NoBody means no DATA frame +// is ever sent, even if empty. +func TestTransportNoBodyMeansNoDATA(t *testing.T) { + ct := newClientTester(t) + + unblockClient := make(chan bool) + + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) + ct.tr.RoundTrip(req) + <-unblockClient + return nil + } + ct.server = func() error { + defer close(unblockClient) + defer ct.cc.(*net.TCPConn).Close() + ct.greet() + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + switch f := f.(type) { + default: + return fmt.Errorf("Got %T; want HeadersFrame", f) + case *http2WindowUpdateFrame, *http2SettingsFrame: + continue + case *http2HeadersFrame: + if !f.StreamEnded() { + return fmt.Errorf("got headers frame without END_STREAM") + } + return nil + } + } + } + ct.run() +} + +func disableGoroutineTracking() (restore func()) { + old := http2DebugGoroutines + http2DebugGoroutines = false + return func() { http2DebugGoroutines = old } +} + +func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { + defer disableGoroutineTracking()() + b.ReportAllocs() + st := newServerTester(b, + func(w http.ResponseWriter, r *http.Request) { + for i := 0; i < nResHeader; i++ { + name := fmt.Sprint("A-", i) + w.Header().Set(name, "*") + } + }, + optOnlyServer, + optQuiet, + ) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + b.Fatal(err) + } + + for i := 0; i < nReqHeaders; i++ { + name := fmt.Sprint("A-", i) + req.Header.Set(name, "*") + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + res, err := tr.RoundTrip(req) + if err != nil { + if res != nil { + res.Body.Close() + } + b.Fatalf("RoundTrip err = %v; want nil", err) + } + res.Body.Close() + if res.StatusCode != http.StatusOK { + b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK) + } + } +} + +type infiniteReader struct{} + +func (r infiniteReader) Read(b []byte) (int, error) { + return len(b), nil +} + +// Issue 20521: it is not an error to receive a response and end stream +// from the server without the body being consumed. +func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + // The request body needs to be big enough to trigger flow control. + req, _ := http.NewRequest("PUT", st.ts.URL, infiniteReader{}) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusOK { + t.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK) + } +} + +// Verify transport doesn't crash when receiving bogus response lacking a :status header. +// Issue 22880. +func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) { + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + _, err := ct.tr.RoundTrip(req) + const substr = "malformed response from server: missing status pseudo header" + if !strings.Contains(fmt.Sprint(err), substr) { + return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr) + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + switch f := f.(type) { + case *http2HeadersFrame: + enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, // we'll send some DATA to try to crash the transport + BlockFragment: buf.Bytes(), + }) + ct.fr.WriteData(f.StreamID, true, []byte("payload")) + return nil + } + } + } + ct.run() +} + +func BenchmarkClientRequestHeaders(b *testing.B) { + b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) }) + b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10, 0) }) + b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100, 0) }) + b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000, 0) }) +} + +func BenchmarkClientResponseHeaders(b *testing.B) { + b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) }) + b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 10) }) + b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 100) }) + b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) }) +} + +func activeStreams(cc *http2ClientConn) int { + count := 0 + cc.mu.Lock() + defer cc.mu.Unlock() + for _, cs := range cc.streams { + select { + case <-cs.abort: + default: + count++ + } + } + return count +} + +type closeMode int + +const ( + closeAtHeaders closeMode = iota + closeAtBody + shutdown + shutdownCancel +) + +// See golang.org/issue/17292 +func testClientConnClose(t *testing.T, closeMode closeMode) { + clientDone := make(chan struct{}) + defer close(clientDone) + handlerDone := make(chan struct{}) + closeDone := make(chan struct{}) + beforeHeader := func() {} + bodyWrite := func(w http.ResponseWriter) {} + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + defer close(handlerDone) + beforeHeader() + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + bodyWrite(w) + select { + case <-w.(http.CloseNotifier).CloseNotify(): + // client closed connection before completion + if closeMode == shutdown || closeMode == shutdownCancel { + t.Error("expected request to complete") + } + case <-clientDone: + if closeMode == closeAtHeaders || closeMode == closeAtBody { + t.Error("expected connection closed by client") + } + } + }, optOnlyServer) + defer st.Close() + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + ctx := context.Background() + cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + if closeMode == closeAtHeaders { + beforeHeader = func() { + if err := cc.Close(); err != nil { + t.Error(err) + } + close(closeDone) + } + } + var sendBody chan struct{} + if closeMode == closeAtBody { + sendBody = make(chan struct{}) + bodyWrite = func(w http.ResponseWriter) { + <-sendBody + b := make([]byte, 32) + w.Write(b) + w.(http.Flusher).Flush() + if err := cc.Close(); err != nil { + t.Errorf("unexpected ClientConn close error: %v", err) + } + close(closeDone) + w.Write(b) + w.(http.Flusher).Flush() + } + } + res, err := cc.RoundTrip(req) + if res != nil { + defer res.Body.Close() + } + if closeMode == closeAtHeaders { + got := fmt.Sprint(err) + want := "http2: client connection force closed via ClientConn.Close" + if got != want { + t.Fatalf("RoundTrip error = %v, want %v", got, want) + } + } else { + if err != nil { + t.Fatalf("RoundTrip: %v", err) + } + if got, want := activeStreams(cc), 1; got != want { + t.Errorf("got %d active streams, want %d", got, want) + } + } + switch closeMode { + case shutdownCancel: + if err = cc.Shutdown(canceledCtx); err != context.Canceled { + t.Errorf("got %v, want %v", err, context.Canceled) + } + if cc.closing == false { + t.Error("expected closing to be true") + } + if cc.CanTakeNewRequest() == true { + t.Error("CanTakeNewRequest to return false") + } + if v, want := len(cc.streams), 1; v != want { + t.Errorf("expected %d active streams, got %d", want, v) + } + clientDone <- struct{}{} + <-handlerDone + case shutdown: + wait := make(chan struct{}) + http2shutdownEnterWaitStateHook = func() { + close(wait) + http2shutdownEnterWaitStateHook = func() {} + } + defer func() { http2shutdownEnterWaitStateHook = func() {} }() + shutdown := make(chan struct{}, 1) + go func() { + if err = cc.Shutdown(context.Background()); err != nil { + t.Error(err) + } + close(shutdown) + }() + // Let the shutdown to enter wait state + <-wait + cc.mu.Lock() + if cc.closing == false { + t.Error("expected closing to be true") + } + cc.mu.Unlock() + if cc.CanTakeNewRequest() == true { + t.Error("CanTakeNewRequest to return false") + } + if got, want := activeStreams(cc), 1; got != want { + t.Errorf("got %d active streams, want %d", got, want) + } + // Let the active request finish + clientDone <- struct{}{} + // Wait for the shutdown to end + select { + case <-shutdown: + case <-time.After(2 * time.Second): + t.Fatal("expected server connection to close") + } + case closeAtHeaders, closeAtBody: + if closeMode == closeAtBody { + go close(sendBody) + if _, err := io.Copy(ioutil.Discard, res.Body); err == nil { + t.Error("expected a Copy error, got nil") + } + } + <-closeDone + if got, want := activeStreams(cc), 0; got != want { + t.Errorf("got %d active streams, want %d", got, want) + } + // wait for server to get the connection close notice + select { + case <-handlerDone: + case <-time.After(2 * time.Second): + t.Fatal("expected server connection to close") + } + } +} + +// The client closes the connection just after the server got the client's HEADERS +// frame, but before the server sends its HEADERS response back. The expected +// result is an error on RoundTrip explaining the client closed the connection. +func TestClientConnCloseAtHeaders(t *testing.T) { + testClientConnClose(t, closeAtHeaders) +} + +// The client closes the connection between two server's response DATA frames. +// The expected behavior is a response body io read error on the client. +func TestClientConnCloseAtBody(t *testing.T) { + testClientConnClose(t, closeAtBody) +} + +// The client sends a GOAWAY frame before the server finished processing a request. +// We expect the connection not to close until the request is completed. +func TestClientConnShutdown(t *testing.T) { + testClientConnClose(t, shutdown) +} + +// The client sends a GOAWAY frame before the server finishes processing a request, +// but cancels the passed context before the request is completed. The expected +// behavior is the client closing the connection after the context is canceled. +func TestClientConnShutdownCancel(t *testing.T) { + testClientConnClose(t, shutdownCancel) +} + +// Issue 25009: use Request.GetBody if present, even if it seems like +// we might not need it. Apparently something else can still read from +// the original request body. Data race? In any case, rewinding +// unconditionally on retry is a nicer model anyway and should +// simplify code in the future (after the Go 1.11 freeze) +func TestTransportUsesGetBodyWhenPresent(t *testing.T) { + calls := 0 + someBody := func() io.ReadCloser { + return struct{ io.ReadCloser }{ioutil.NopCloser(bytes.NewReader(nil))} + } + req := &http.Request{ + Body: someBody(), + GetBody: func() (io.ReadCloser, error) { + calls++ + return someBody(), nil + }, + } + + req2, err := http2shouldRetryRequest(req, http2errClientConnUnusable) + if err != nil { + t.Fatal(err) + } + if calls != 1 { + t.Errorf("Calls = %d; want 1", calls) + } + if req2 == req { + t.Error("req2 changed") + } + if req2 == nil { + t.Fatal("req2 is nil") + } + if req2.Body == nil { + t.Fatal("req2.Body is nil") + } + if req2.GetBody == nil { + t.Fatal("req2.GetBody is nil") + } + if req2.Body == req.Body { + t.Error("req2.Body unchanged") + } +} + +type errReader struct { + body []byte + err error +} + +func (r *errReader) Read(p []byte) (int, error) { + if len(r.body) > 0 { + n := copy(p, r.body) + r.body = r.body[n:] + return n, nil + } + return 0, r.err +} + +func testTransportBodyReadError(t *testing.T, body []byte) { + if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { + // So far we've only seen this be flaky on Windows and Plan 9, + // perhaps due to TCP behavior on shutdowns while + // unread data is in flight. This test should be + // fixed, but a skip is better than annoying people + // for now. + t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS) + } + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + if runtime.GOOS == "plan9" { + // CloseWrite not supported on Plan 9; Issue 17906 + defer ct.cc.(*net.TCPConn).Close() + } + defer close(clientDone) + + checkNoStreams := func() error { + cp, ok := ct.tr.connPool().(*http2clientConnPool) + if !ok { + return fmt.Errorf("conn pool is %T; want *http2clientConnPool", ct.tr.connPool()) + } + cp.mu.Lock() + defer cp.mu.Unlock() + conns, ok := cp.conns["dummy.tld:443"] + if !ok { + return fmt.Errorf("missing connection") + } + if len(conns) != 1 { + return fmt.Errorf("conn pool size: %v; expect 1", len(conns)) + } + if activeStreams(conns[0]) != 0 { + return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0])) + } + return nil + } + bodyReadError := errors.New("body read error") + body := &errReader{body, bodyReadError} + req, err := http.NewRequest("PUT", "https://dummy.tld/", body) + if err != nil { + return err + } + _, err = ct.tr.RoundTrip(req) + if err != bodyReadError { + return fmt.Errorf("err = %v; want %v", err, bodyReadError) + } + if err = checkNoStreams(); err != nil { + return err + } + return nil + } + ct.server = func() error { + ct.greet() + var receivedBody []byte + var resetCount int + for { + f, err := ct.fr.ReadFrame() + t.Logf("server: ReadFrame = %v, %v", f, err) + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + if bytes.Compare(receivedBody, body) != 0 { + return fmt.Errorf("body: %q; expected %q", receivedBody, body) + } + if resetCount != 1 { + return fmt.Errorf("stream reset count: %v; expected: 1", resetCount) + } + return nil + default: + return err + } + } + switch f := f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame: + case *http2HeadersFrame: + case *http2DataFrame: + receivedBody = append(receivedBody, f.Data()...) + case *http2RSTStreamFrame: + resetCount++ + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} + +func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) } +func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyReadError(t, []byte("123")) } + +// Issue 32254: verify that the client sends END_STREAM flag eagerly with the last +// (or in this test-case the only one) request body data frame, and does not send +// extra zero-len data frames. +func TestTransportBodyEagerEndStream(t *testing.T) { + const reqBody = "some request body" + const resBody = "some response body" + + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + if runtime.GOOS == "plan9" { + // CloseWrite not supported on Plan 9; Issue 17906 + defer ct.cc.(*net.TCPConn).Close() + } + body := strings.NewReader(reqBody) + req, err := http.NewRequest("PUT", "https://dummy.tld/", body) + if err != nil { + return err + } + _, err = ct.tr.RoundTrip(req) + if err != nil { + return err + } + return nil + } + ct.server = func() error { + ct.greet() + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + + switch f := f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame: + case *http2HeadersFrame: + case *http2DataFrame: + if !f.StreamEnded() { + ct.fr.WriteRSTStream(f.StreamID, http2ErrCodeRefusedStream) + return fmt.Errorf("data frame without END_STREAM %v", f) + } + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.Header().StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + ct.fr.WriteData(f.StreamID, true, []byte(resBody)) + return nil + case *http2RSTStreamFrame: + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} + +type chunkReader struct { + chunks [][]byte +} + +func (r *chunkReader) Read(p []byte) (int, error) { + if len(r.chunks) > 0 { + n := copy(p, r.chunks[0]) + r.chunks = r.chunks[1:] + return n, nil + } + panic("shouldn't read this many times") +} + +// Issue 32254: if the request body is larger than the specified +// content length, the client should refuse to send the extra part +// and abort the stream. +// +// In _len3 case, the first Read() matches the expected content length +// but the second read returns more data. +// +// In _len2 case, the first Read() exceeds the expected content length. +func TestTransportBodyLargerThanSpecifiedContentLength_len3(t *testing.T) { + body := &chunkReader{[][]byte{ + []byte("123"), + []byte("456"), + }} + testTransportBodyLargerThanSpecifiedContentLength(t, body, 3) +} + +func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) { + body := &chunkReader{[][]byte{ + []byte("123"), + }} + testTransportBodyLargerThanSpecifiedContentLength(t, body, 2) +} + +func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunkReader, contentLen int64) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + r.Body.Read(make([]byte, 6)) + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + req, _ := http.NewRequest("POST", st.ts.URL, body) + req.ContentLength = contentLen + _, err := tr.RoundTrip(req) + if err != http2errReqBodyTooLong { + t.Fatalf("expected %v, got %v", http2errReqBodyTooLong, err) + } +} + +func TestClientConnTooIdle(t *testing.T) { + tests := []struct { + cc func() *http2ClientConn + want bool + }{ + { + func() *http2ClientConn { + return &http2ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} + }, + true, + }, + { + func() *http2ClientConn { + return &http2ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Time{}} + }, + false, + }, + { + func() *http2ClientConn { + return &http2ClientConn{idleTimeout: 60 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} + }, + false, + }, + { + func() *http2ClientConn { + return &http2ClientConn{idleTimeout: 0, lastIdle: time.Now().Add(-10 * time.Second)} + }, + false, + }, + } + for i, tt := range tests { + got := tt.cc().tooIdleLocked() + if got != tt.want { + t.Errorf("%d. got %v; want %v", i, got, tt.want) + } + } +} + +type fakeConnErr struct { + net.Conn + writeErr error + closed bool +} + +func (fce *fakeConnErr) Write(b []byte) (n int, err error) { + return 0, fce.writeErr +} + +func (fce *fakeConnErr) Close() error { + fce.closed = true + return nil +} + +// issue 39337: close the connection on a failed write +func TestTransportNewClientConnCloseOnWriteError(t *testing.T) { + tr := &http2Transport{} + writeErr := errors.New("write error") + fakeConn := &fakeConnErr{writeErr: writeErr} + _, err := tr.NewClientConn(fakeConn) + if err != writeErr { + t.Fatalf("expected %v, got %v", writeErr, err) + } + if !fakeConn.closed { + t.Error("expected closed conn") + } +} + +func TestTransportRoundtripCloseOnWriteError(t *testing.T) { + req, err := http.NewRequest("GET", "https://dummy.tld/", nil) + if err != nil { + t.Fatal(err) + } + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + ctx := context.Background() + cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) + if err != nil { + t.Fatal(err) + } + + writeErr := errors.New("write error") + cc.wmu.Lock() + cc.werr = writeErr + cc.wmu.Unlock() + + _, err = cc.RoundTrip(req) + if err != writeErr { + t.Fatalf("expected %v, got %v", writeErr, err) + } + + cc.mu.Lock() + closed := cc.closed + cc.mu.Unlock() + if !closed { + t.Fatal("expected closed") + } +} + +// Issue 31192: A failed request may be retried if the body has not been read +// already. If the request body has started to be sent, one must wait until it +// is completed. +func TestTransportBodyRewindRace(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Connection", "close") + w.WriteHeader(http.StatusOK) + return + }, optOnlyServer) + defer st.Close() + + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + MaxConnsPerHost: 1, + } + err := http2ConfigureTransport(tr) + if err != nil { + t.Fatal(err) + } + client := &http.Client{ + Transport: tr, + } + + const clients = 50 + + var wg sync.WaitGroup + wg.Add(clients) + for i := 0; i < clients; i++ { + req, err := http.NewRequest("POST", st.ts.URL, bytes.NewBufferString("abcdef")) + if err != nil { + t.Fatalf("unexpect new request error: %v", err) + } + + go func() { + defer wg.Done() + res, err := client.Do(req) + if err == nil { + res.Body.Close() + } + }() + } + + wg.Wait() +} + +// Issue 42498: A request with a body will never be sent if the stream is +// reset prior to sending any data. +func TestTransportServerResetStreamAtHeaders(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + return + }, optOnlyServer) + defer st.Close() + + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + MaxConnsPerHost: 1, + ExpectContinueTimeout: 10 * time.Second, + } + + err := http2ConfigureTransport(tr) + if err != nil { + t.Fatal(err) + } + client := &http.Client{ + Transport: tr, + } + + req, err := http.NewRequest("POST", st.ts.URL, errorReader{io.EOF}) + if err != nil { + t.Fatalf("unexpect new request error: %v", err) + } + req.ContentLength = 0 // so transport is tempted to sniff it + req.Header.Set("Expect", "100-continue") + res, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() +} + +type trackingReader struct { + rdr io.Reader + wasRead uint32 +} + +func (tr *trackingReader) Read(p []byte) (int, error) { + atomic.StoreUint32(&tr.wasRead, 1) + return tr.rdr.Read(p) +} + +func (tr *trackingReader) WasRead() bool { + return atomic.LoadUint32(&tr.wasRead) != 0 +} + +func TestTransportExpectContinue(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/reject": + w.WriteHeader(403) + default: + io.Copy(io.Discard, r.Body) + } + }, optOnlyServer) + defer st.Close() + + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + MaxConnsPerHost: 1, + ExpectContinueTimeout: 10 * time.Second, + } + + err := http2ConfigureTransport(tr) + if err != nil { + t.Fatal(err) + } + client := &http.Client{ + Transport: tr, + } + + testCases := []struct { + Name string + Path string + Body *trackingReader + ExpectedCode int + ShouldRead bool + }{ + { + Name: "read-all", + Path: "/", + Body: &trackingReader{rdr: strings.NewReader("hello")}, + ExpectedCode: 200, + ShouldRead: true, + }, + { + Name: "reject", + Path: "/reject", + Body: &trackingReader{rdr: strings.NewReader("hello")}, + ExpectedCode: 403, + ShouldRead: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + startTime := time.Now() + + req, err := http.NewRequest("POST", st.ts.URL+tc.Path, tc.Body) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Expect", "100-continue") + res, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + if delta := time.Since(startTime); delta >= tr.ExpectContinueTimeout { + t.Error("Request didn't finish before expect continue timeout") + } + if res.StatusCode != tc.ExpectedCode { + t.Errorf("Unexpected status code, got %d, expected %d", res.StatusCode, tc.ExpectedCode) + } + if tc.Body.WasRead() != tc.ShouldRead { + t.Errorf("Unexpected read status, got %v, expected %v", tc.Body.WasRead(), tc.ShouldRead) + } + }) + } +} + +type closeChecker struct { + io.ReadCloser + closed chan struct{} +} + +func newCloseChecker(r io.ReadCloser) *closeChecker { + return &closeChecker{r, make(chan struct{})} +} + +func newStaticCloseChecker(body string) *closeChecker { + return newCloseChecker(io.NopCloser(strings.NewReader("body"))) +} + +func (rc *closeChecker) Read(b []byte) (n int, err error) { + select { + default: + case <-rc.closed: + // TODO(dneil): Consider restructuring the request write to avoid reading + // from the request body after closing it, and check for read-after-close here. + // Currently, abortRequestBodyWrite races with writeRequestBody. + return 0, errors.New("read after Body.Close") + } + return rc.ReadCloser.Read(b) +} + +func (rc *closeChecker) Close() error { + close(rc.closed) + return rc.ReadCloser.Close() +} + +func (rc *closeChecker) isClosed() error { + // The RoundTrip contract says that it will close the request body, + // but that it may do so in a separate goroutine. Wait a reasonable + // amount of time before concluding that the body isn't being closed. + timeout := time.Duration(10 * time.Second) + select { + case <-rc.closed: + case <-time.After(timeout): + return fmt.Errorf("body not closed after %v", timeout) + } + return nil +} + +// A blockingWriteConn is a net.Conn that blocks in Write after some number of bytes are written. +type blockingWriteConn struct { + net.Conn + writeOnce sync.Once + writec chan struct{} // closed after the write limit is reached + unblockc chan struct{} // closed to unblock writes + count, limit int +} + +func newBlockingWriteConn(conn net.Conn, limit int) *blockingWriteConn { + return &blockingWriteConn{ + Conn: conn, + limit: limit, + writec: make(chan struct{}), + unblockc: make(chan struct{}), + } +} + +// wait waits until the conn blocks writing the limit+1st byte. +func (c *blockingWriteConn) wait() { + <-c.writec +} + +// unblock unblocks writes to the conn. +func (c *blockingWriteConn) unblock() { + close(c.unblockc) +} + +func (c *blockingWriteConn) Write(b []byte) (n int, err error) { + if c.count+len(b) > c.limit { + c.writeOnce.Do(func() { + close(c.writec) + }) + <-c.unblockc + } + n, err = c.Conn.Write(b) + c.count += n + return n, err +} + +// Write several requests to a ClientConn at the same time, looking for race conditions. +// See golang.org/issue/48340 +func TestTransportFrameBufferReuse(t *testing.T) { + filler := hex.EncodeToString([]byte(randString(2048))) + + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("Big"), filler; got != want { + t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, want) + } + b, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Errorf("error reading request body: %v", err) + } + if got, want := string(b), filler; got != want { + t.Errorf("request body = %q, want %q", got, want) + } + if got, want := r.Trailer.Get("Big"), filler; got != want { + t.Errorf(`r.Trailer.Get("Big") = %q, want %q`, got, want) + } + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + var wg sync.WaitGroup + defer wg.Wait() + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader(filler)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Big", filler) + req.Trailer = make(http.Header) + req.Trailer.Set("Big", filler) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if got, want := res.StatusCode, 200; got != want { + t.Errorf("StatusCode = %v; want %v", got, want) + } + if res != nil && res.Body != nil { + res.Body.Close() + } + }() + } + +} + +// Ensure that a request blocking while being written to the underlying net.Conn doesn't +// block access to the ClientConn pool. Test requests blocking while writing headers, the body, +// and trailers. +// See golang.org/issue/32388 +func TestTransportBlockingRequestWrite(t *testing.T) { + filler := hex.EncodeToString([]byte(randString(2048))) + for _, test := range []struct { + name string + req func(url string) (*http.Request, error) + }{{ + name: "headers", + req: func(url string) (*http.Request, error) { + req, err := http.NewRequest("POST", url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Big", filler) + return req, err + }, + }, { + name: "body", + req: func(url string) (*http.Request, error) { + req, err := http.NewRequest("POST", url, strings.NewReader(filler)) + if err != nil { + return nil, err + } + return req, err + }, + }, { + name: "trailer", + req: func(url string) (*http.Request, error) { + req, err := http.NewRequest("POST", url, strings.NewReader("body")) + if err != nil { + return nil, err + } + req.Trailer = make(http.Header) + req.Trailer.Set("Big", filler) + return req, err + }, + }} { + test := test + t.Run(test.name, func(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + if v := r.Header.Get("Big"); v != "" && v != filler { + t.Errorf("request header mismatch") + } + if v, _ := io.ReadAll(r.Body); len(v) != 0 && string(v) != "body" && string(v) != filler { + t.Errorf("request body mismatch\ngot: %q\nwant: %q", string(v), filler) + } + if v := r.Trailer.Get("Big"); v != "" && v != filler { + t.Errorf("request trailer mismatch\ngot: %q\nwant: %q", string(v), filler) + } + }, optOnlyServer, func(s *http2Server) { + s.MaxConcurrentStreams = 1 + }) + defer st.Close() + + // This Transport creates connections that block on writes after 1024 bytes. + connc := make(chan *blockingWriteConn, 1) + connCount := 0 + tr := &http2Transport{ + t1: &Transport{ + TLSClientConfig: tlsConfigInsecure, + }, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + connCount++ + c, err := tls.Dial(network, addr, cfg) + wc := newBlockingWriteConn(c, 1024) + select { + case connc <- wc: + default: + } + return wc, err + }, + } + defer tr.CloseIdleConnections() + + // Request 1: A small request to ensure we read the server MaxConcurrentStreams. + { + req, err := http.NewRequest("POST", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if got, want := res.StatusCode, 200; got != want { + t.Errorf("StatusCode = %v; want %v", got, want) + } + if res != nil && res.Body != nil { + res.Body.Close() + } + } + + // Request 2: A large request that blocks while being written. + reqc := make(chan struct{}) + go func() { + defer close(reqc) + req, err := test.req(st.ts.URL) + if err != nil { + t.Error(err) + return + } + res, _ := tr.RoundTrip(req) + if res != nil && res.Body != nil { + res.Body.Close() + } + }() + conn := <-connc + conn.wait() // wait for the request to block + + // Request 3: A small request that is sent on a new connection, since request 2 + // is hogging the only available stream on the previous connection. + { + req, err := http.NewRequest("POST", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if got, want := res.StatusCode, 200; got != want { + t.Errorf("StatusCode = %v; want %v", got, want) + } + if res != nil && res.Body != nil { + res.Body.Close() + } + } + + // Request 2 should still be blocking at this point. + select { + case <-reqc: + t.Errorf("request 2 unexpectedly completed") + default: + } + + conn.unblock() + <-reqc + + if connCount != 2 { + t.Errorf("created %v connections, want 1", connCount) + } + }) + } +} + +func TestTransportCloseRequestBody(t *testing.T) { + var statusCode int + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(statusCode) + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + ctx := context.Background() + cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) + if err != nil { + t.Fatal(err) + } + + for _, status := range []int{200, 401} { + t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) { + statusCode = status + pr, pw := io.Pipe() + body := newCloseChecker(pr) + req, err := http.NewRequest("PUT", "https://dummy.tld/", body) + if err != nil { + t.Fatal(err) + } + res, err := cc.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + pw.Close() + if err := body.isClosed(); err != nil { + t.Fatal(err) + } + }) + } +} + +// collectClientsConnPool is a http2ClientConnPool that wraps lower and +// collects what calls were made on it. +type collectClientsConnPool struct { + lower http2ClientConnPool + + mu sync.Mutex + getErrs int + got []*http2ClientConn +} + +func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) { + cc, err := p.lower.GetClientConn(req, addr) + p.mu.Lock() + defer p.mu.Unlock() + if err != nil { + p.getErrs++ + return nil, err + } + p.got = append(p.got, cc) + return cc, nil +} + +func (p *collectClientsConnPool) MarkDead(cc *http2ClientConn) { + p.lower.MarkDead(cc) +} + +func TestTransportRetriesOnStreamProtocolError(t *testing.T) { + ct := newClientTester(t) + pool := &collectClientsConnPool{ + lower: &http2clientConnPool{t: ct.tr}, + } + ct.tr.ConnPool = pool + + gotProtoError := make(chan bool, 1) + ct.tr.CountError = func(errType string) { + if errType == "recv_rststream_PROTOCOL_ERROR" { + select { + case gotProtoError <- true: + default: + } + } + } + ct.client = func() error { + // Start two requests. The first is a long request + // that will finish after the second. The second one + // will result in the protocol error. We check that + // after the first one closes, the connection then + // shuts down. + + // The long, outer request. + req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil) + res1, err := ct.tr.RoundTrip(req1) + if err != nil { + return err + } + if got, want := res1.Header.Get("Is-Long"), "1"; got != want { + return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want) + } + + req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil) + res, err := ct.tr.RoundTrip(req) + const want = "only one dial allowed in test mode" + if got := fmt.Sprint(err); got != want { + t.Errorf("didn't dial again: got %#q; want %#q", got, want) + } + if res != nil { + res.Body.Close() + } + select { + case <-gotProtoError: + default: + t.Errorf("didn't get stream protocol error") + } + + if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 { + t.Errorf("unexpected body read %v, %v", n, err) + } + + pool.mu.Lock() + defer pool.mu.Unlock() + if pool.getErrs != 1 { + t.Errorf("pool get errors = %v; want 1", pool.getErrs) + } + if len(pool.got) == 2 { + if pool.got[0] != pool.got[1] { + t.Errorf("requests went on different connections") + } + cc := pool.got[0] + cc.mu.Lock() + if !cc.doNotReuse { + t.Error("ClientConn not marked doNotReuse") + } + cc.mu.Unlock() + + select { + case <-cc.readerDone: + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for reader to be done") + } + } else { + t.Errorf("pool get success = %v; want 2", len(pool.got)) + } + return nil + } + ct.server = func() error { + ct.greet() + var sentErr bool + var numHeaders int + var firstStreamID uint32 + + var hbuf bytes.Buffer + enc := hpack.NewEncoder(&hbuf) + + for { + f, err := ct.fr.ReadFrame() + if err == io.EOF { + // Client hung up on us, as it should at the end. + return nil + } + if err != nil { + return nil + } + switch f := f.(type) { + case *http2WindowUpdateFrame, *http2SettingsFrame: + case *http2HeadersFrame: + numHeaders++ + if numHeaders == 1 { + firstStreamID = f.StreamID + hbuf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"}) + ct.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: hbuf.Bytes(), + }) + continue + } + if !sentErr { + sentErr = true + ct.fr.WriteRSTStream(f.StreamID, http2ErrCodeProtocol) + ct.fr.WriteData(firstStreamID, true, nil) + continue + } + } + } + return nil + } + ct.run() +} + +func TestClientConnReservations(t *testing.T) { + cc := &http2ClientConn{ + reqHeaderMu: make(chan struct{}, 1), + streams: make(map[uint32]*http2clientStream), + maxConcurrentStreams: http2initialMaxConcurrentStreams, + nextStreamID: 1, + t: &http2Transport{}, + } + cc.cond = sync.NewCond(&cc.mu) + n := 0 + for n <= http2initialMaxConcurrentStreams && cc.ReserveNewRequest() { + n++ + } + if n != http2initialMaxConcurrentStreams { + t.Errorf("did %v reservations; want %v", n, http2initialMaxConcurrentStreams) + } + if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, http2errNilRequestURL) { + t.Fatalf("RoundTrip error = %v; want http2errNilRequestURL", err) + } + n2 := 0 + for n2 <= 5 && cc.ReserveNewRequest() { + n2++ + } + if n2 != 1 { + t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2) + } + + // Use up all the reservations + for i := 0; i < n; i++ { + cc.RoundTrip(new(http.Request)) + } + + n2 = 0 + for n2 <= http2initialMaxConcurrentStreams && cc.ReserveNewRequest() { + n2++ + } + if n2 != n { + t.Errorf("after reset, reservations = %v; want %v", n2, n) + } +} + +func TestTransportTimeoutServerHangs(t *testing.T) { + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + + req, err := http.NewRequest("PUT", "https://dummy.tld/", nil) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + req = req.WithContext(ctx) + req.Header.Add("Big", strings.Repeat("a", 1<<20)) + _, err = ct.tr.RoundTrip(req) + if err == nil { + return errors.New("error should not be nil") + } + if ne, ok := err.(net.Error); !ok || !ne.Timeout() { + return fmt.Errorf("error should be a net error timeout: %v", err) + } + return nil + } + ct.server = func() error { + ct.greet() + select { + case <-time.After(5 * time.Second): + case <-clientDone: + } + return nil + } + ct.run() +} + +func TestTransportContentLengthWithoutBody(t *testing.T) { + contentLength := "" + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", contentLength) + }, optOnlyServer) + defer st.Close() + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + for _, test := range []struct { + name string + contentLength string + wantBody string + wantErr error + wantContentLength int64 + }{ + { + name: "non-zero content length", + contentLength: "42", + wantErr: io.ErrUnexpectedEOF, + wantContentLength: 42, + }, + { + name: "zero content length", + contentLength: "0", + wantErr: nil, + wantContentLength: 0, + }, + } { + t.Run(test.name, func(t *testing.T) { + contentLength = test.contentLength + + req, _ := http.NewRequest("GET", st.ts.URL, nil) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + + if err != test.wantErr { + t.Errorf("Expected error %v, got: %v", test.wantErr, err) + } + if len(body) > 0 { + t.Errorf("Expected empty body, got: %v", body) + } + if res.ContentLength != test.wantContentLength { + t.Errorf("Expected content length %d, got: %d", test.wantContentLength, res.ContentLength) + } + }) + } +} + +func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.(http.Flusher).Flush() + io.Copy(io.Discard, r.Body) + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + pr, pw := net.Pipe() + req, err := http.NewRequest("GET", st.ts.URL, pr) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + // Closing the Response's Body interrupts the blocked body read. + res.Body.Close() + pw.Close() +} + +func TestTransport300ResponseBody(t *testing.T) { + reqc := make(chan struct{}) + body := []byte("response body") + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(300) + w.(http.Flusher).Flush() + <-reqc + w.Write(body) + }, optOnlyServer) + defer st.Close() + + tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + pr, pw := net.Pipe() + req, err := http.NewRequest("GET", st.ts.URL, pr) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + close(reqc) + got, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("error reading response body: %v", err) + } + if !bytes.Equal(got, body) { + t.Errorf("got response body %q, want %q", string(got), string(body)) + } + res.Body.Close() + pw.Close() +} + +func TestTransportWriteByteTimeout(t *testing.T) { + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) {}, + optOnlyServer, + ) + defer st.Close() + tr := &http2Transport{ + t1: &Transport{ + TLSClientConfig: tlsConfigInsecure, + }, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + _, c := net.Pipe() + return c, nil + }, + WriteByteTimeout: 1 * time.Millisecond, + } + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + _, err := c.Get(st.ts.URL) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err) + } +} + +type slowWriteConn struct { + net.Conn + hasWriteDeadline bool +} + +func (c *slowWriteConn) SetWriteDeadline(t time.Time) error { + c.hasWriteDeadline = !t.IsZero() + return nil +} + +func (c *slowWriteConn) Write(b []byte) (n int, err error) { + if c.hasWriteDeadline && len(b) > 1 { + n, err = c.Conn.Write(b[:1]) + if err != nil { + return n, err + } + return n, fmt.Errorf("slow write: %w", os.ErrDeadlineExceeded) + } + return c.Conn.Write(b) +} + +func TestTransportSlowWrites(t *testing.T) { + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) {}, + optOnlyServer, + ) + defer st.Close() + tr := &http2Transport{ + t1: &Transport{ + TLSClientConfig: tlsConfigInsecure, + }, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + cfg.InsecureSkipVerify = true + c, err := tls.Dial(network, addr, cfg) + return &slowWriteConn{Conn: c}, err + }, + WriteByteTimeout: 1 * time.Millisecond, + } + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + const bodySize = 1 << 20 + resp, err := c.Post(st.ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize)) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() +} From 088a707391da632a579da517e73b25090af27e99 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 14 Feb 2022 09:46:19 +0800 Subject: [PATCH 338/843] add h2_transport_go117_test.go --- h2_transport_go117_test.go | 169 +++++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 h2_transport_go117_test.go diff --git a/h2_transport_go117_test.go b/h2_transport_go117_test.go new file mode 100644 index 00000000..2b462f4d --- /dev/null +++ b/h2_transport_go117_test.go @@ -0,0 +1,169 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.17 +// +build go1.17 + +package req + +import ( + "context" + "crypto/tls" + "errors" + "net/http" + "net/http/httptest" + + "testing" +) + +func TestTransportDialTLSContext(t *testing.T) { + blockCh := make(chan struct{}) + serverTLSConfigFunc := func(ts *httptest.Server) { + ts.Config.TLSConfig = &tls.Config{ + // Triggers the server to request the clients certificate + // during TLS handshake. + ClientAuth: tls.RequestClientCert, + } + } + ts := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) {}, + optOnlyServer, + serverTLSConfigFunc, + ) + defer ts.Close() + tr := &Transport{ + TLSClientConfig: &tls.Config{ + GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + // Tests that the context provided to `req` is + // passed into this function. + close(blockCh) + <-cri.Context().Done() + return nil, cri.Context().Err() + }, + InsecureSkipVerify: true, + }, + } + defer tr.CloseIdleConnections() + req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + req = req.WithContext(ctx) + errCh := make(chan error) + go func() { + defer close(errCh) + res, err := tr.RoundTrip(req) + if err != nil { + errCh <- err + return + } + res.Body.Close() + }() + // Wait for GetClientCertificate handler to be called + <-blockCh + // Cancel the context + cancel() + // Expect the cancellation error here + err = <-errCh + if err == nil { + t.Fatal("cancelling context during client certificate fetch did not error as expected") + return + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("unexpected error returned after cancellation: %v", err) + } +} + +// TestDialRaceResumesDial tests that, given two concurrent requests +// to the same address, when the first Dial is interrupted because +// the first request's context is cancelled, the second request +// resumes the dial automatically. +func TestDialRaceResumesDial(t *testing.T) { + blockCh := make(chan struct{}) + serverTLSConfigFunc := func(ts *httptest.Server) { + ts.Config.TLSConfig = &tls.Config{ + // Triggers the server to request the clients certificate + // during TLS handshake. + ClientAuth: tls.RequestClientCert, + } + } + ts := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) {}, + optOnlyServer, + serverTLSConfigFunc, + ) + defer ts.Close() + tr := &Transport{ + TLSClientConfig: &tls.Config{ + GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + select { + case <-blockCh: + // If we already errored, return without error. + return &tls.Certificate{}, nil + default: + } + close(blockCh) + <-cri.Context().Done() + return nil, cri.Context().Err() + }, + InsecureSkipVerify: true, + }, + } + defer tr.CloseIdleConnections() + req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + // Create two requests with independent cancellation. + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + req1 := req.WithContext(ctx1) + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + req2 := req.WithContext(ctx2) + errCh := make(chan error) + go func() { + res, err := tr.RoundTrip(req1) + if err != nil { + errCh <- err + return + } + res.Body.Close() + }() + successCh := make(chan struct{}) + go func() { + // Don't start request until first request + // has initiated the handshake. + <-blockCh + res, err := tr.RoundTrip(req2) + if err != nil { + errCh <- err + return + } + res.Body.Close() + // Close successCh to indicate that the second request + // made it to the server successfully. + close(successCh) + }() + // Wait for GetClientCertificate handler to be called + <-blockCh + // Cancel the context first + cancel1() + // Expect the cancellation error here + err = <-errCh + if err == nil { + t.Fatal("cancelling context during client certificate fetch did not error as expected") + return + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("unexpected error returned after cancellation: %v", err) + } + select { + case err := <-errCh: + t.Fatalf("unexpected second error: %v", err) + case <-successCh: + } +} From 4103836c175fbee946f5d4646269144e3bf1e03e Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 14 Feb 2022 09:50:41 +0800 Subject: [PATCH 339/843] transport_test.go -> transport_internal_test.go --- transport_test.go => transport_internal_test.go | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename transport_test.go => transport_internal_test.go (100%) diff --git a/transport_test.go b/transport_internal_test.go similarity index 100% rename from transport_test.go rename to transport_internal_test.go From 7127a79f0bcee85ab19b7fcfee7b6368c0966a8f Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 14 Feb 2022 12:00:20 +0800 Subject: [PATCH 340/843] add transport_test.go --- h2_transport_go117_test.go | 2 +- h2_transport_test.go | 2 +- transport.go | 1 + transport_test.go | 6224 ++++++++++++++++++++++++++++++++++++ 4 files changed, 6227 insertions(+), 2 deletions(-) create mode 100644 transport_test.go diff --git a/h2_transport_go117_test.go b/h2_transport_go117_test.go index 2b462f4d..b27237b8 100644 --- a/h2_transport_go117_test.go +++ b/h2_transport_go117_test.go @@ -17,7 +17,7 @@ import ( "testing" ) -func TestTransportDialTLSContext(t *testing.T) { +func TestTransportDialTLSContexth2(t *testing.T) { blockCh := make(chan struct{}) serverTLSConfigFunc := func(ts *httptest.Server) { ts.Config.TLSConfig = &tls.Config{ diff --git a/h2_transport_test.go b/h2_transport_test.go index 846a535f..78bca258 100644 --- a/h2_transport_test.go +++ b/h2_transport_test.go @@ -624,7 +624,7 @@ func shortString(v string) string { return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:maxLen/2], len(v)-maxLen, v[len(v)-maxLen/2:]) } -func TestTransportDialTLS(t *testing.T) { +func TestTransportDialTLSh2(t *testing.T) { var mu sync.Mutex // guards following var gotReq, didDial bool diff --git a/transport.go b/transport.go index f8fd40b3..d7716e37 100644 --- a/transport.go +++ b/transport.go @@ -375,6 +375,7 @@ func (t *Transport) Clone() *Transport { ReadBufferSize: t.ReadBufferSize, ResponseOptions: t.ResponseOptions, ForceHttpVersion: t.ForceHttpVersion, + Debugf: t.Debugf, dump: t.dump.Clone(), } if t.dump != nil { diff --git a/transport_test.go b/transport_test.go new file mode 100644 index 00000000..551b134d --- /dev/null +++ b/transport_test.go @@ -0,0 +1,6224 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for transport.go. +// +// More tests are in clientserver_test.go (for things testing both client & server for both +// HTTP/1 and HTTP/2). This + +package req + +import ( + "bufio" + "bytes" + "compress/gzip" + "context" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/binary" + "errors" + "fmt" + "github.com/imroc/req/v3/internal/testcert" + "go/token" + "golang.org/x/net/http/httpproxy" + "io" + "log" + mrand "math/rand" + "net" + "net/http" + "net/http/httptest" + "net/http/httptrace" + "net/http/httputil" + "net/textproto" + "net/url" + "os" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "testing/iotest" + "time" + + "golang.org/x/net/http/httpguts" +) + +func (t *Transport) NumPendingRequestsForTesting() int { + t.reqMu.Lock() + defer t.reqMu.Unlock() + return len(t.reqCanceler) +} + +func (t *Transport) IdleConnKeysForTesting() (keys []string) { + keys = make([]string, 0) + t.idleMu.Lock() + defer t.idleMu.Unlock() + for key := range t.idleConn { + keys = append(keys, key.String()) + } + sort.Strings(keys) + return +} + +func (t *Transport) IdleConnKeyCountForTesting() int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return len(t.idleConn) +} + +func (t *Transport) IdleConnStrsForTesting() []string { + var ret []string + t.idleMu.Lock() + defer t.idleMu.Unlock() + for _, conns := range t.idleConn { + for _, pc := range conns { + ret = append(ret, pc.conn.LocalAddr().String()+"/"+pc.conn.RemoteAddr().String()) + } + } + sort.Strings(ret) + return ret +} + +func (t *Transport) IdleConnStrsForTesting_h2() []string { + var ret []string + noDialPool := t.t2.ConnPool.(http2noDialClientConnPool) + pool := noDialPool.http2clientConnPool + + pool.mu.Lock() + defer pool.mu.Unlock() + + for k, cc := range pool.conns { + for range cc { + ret = append(ret, k) + } + } + + sort.Strings(ret) + return ret +} + +func (t *Transport) IdleConnCountForTesting(scheme, addr string) int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + key := connectMethodKey{"", scheme, addr, false} + cacheKey := key.String() + for k, conns := range t.idleConn { + if k.String() == cacheKey { + return len(conns) + } + } + return 0 +} + +func (t *Transport) IdleConnWaitMapSizeForTesting() int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return len(t.idleConnWait) +} + +func (t *Transport) IsIdleForTesting() bool { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return t.closeIdle +} + +func (t *Transport) QueueForIdleConnForTesting() { + t.queueForIdleConn(nil) +} + +// PutIdleTestConn reports whether it was able to insert a fresh +// persistConn for scheme, addr into the idle connection pool. +func (t *Transport) PutIdleTestConn(scheme, addr string) bool { + c, _ := net.Pipe() + key := connectMethodKey{"", scheme, addr, false} + + if t.MaxConnsPerHost > 0 { + // Transport is tracking conns-per-host. + // Increment connection count to account + // for new persistConn created below. + t.connsPerHostMu.Lock() + if t.connsPerHost == nil { + t.connsPerHost = make(map[connectMethodKey]int) + } + t.connsPerHost[key]++ + t.connsPerHostMu.Unlock() + } + + return t.tryPutIdleConn(&persistConn{ + t: t, + conn: c, // dummy + closech: make(chan struct{}), // so it can be closed + cacheKey: key, + }) == nil +} + +// PutIdleTestConnH2 reports whether it was able to insert a fresh +// HTTP/2 persistConn for scheme, addr into the idle connection pool. +func (t *Transport) PutIdleTestConnH2(scheme, addr string, alt http.RoundTripper) bool { + key := connectMethodKey{"", scheme, addr, false} + + if t.MaxConnsPerHost > 0 { + // Transport is tracking conns-per-host. + // Increment connection count to account + // for new persistConn created below. + t.connsPerHostMu.Lock() + if t.connsPerHost == nil { + t.connsPerHost = make(map[connectMethodKey]int) + } + t.connsPerHost[key]++ + t.connsPerHostMu.Unlock() + } + + return t.tryPutIdleConn(&persistConn{ + t: t, + alt: alt, + cacheKey: key, + }) == nil +} + +// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close +// and then verify that the final 2 responses get errors back. + +// hostPortHandler writes back the client's "host:port". +var hostPortHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.FormValue("close") == "true" { + w.Header().Set("Connection", "close") + } + w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close)) + w.Write([]byte(r.RemoteAddr)) +}) + +// testCloseConn is a net.Conn tracked by a testConnSet. +type testCloseConn struct { + net.Conn + set *testConnSet +} + +func (c *testCloseConn) Close() error { + c.set.remove(c) + return c.Conn.Close() +} + +// testConnSet tracks a set of TCP connections and whether they've +// been closed. +type testConnSet struct { + t *testing.T + mu sync.Mutex // guards closed and list + closed map[net.Conn]bool + list []net.Conn // in order created +} + +func (tcs *testConnSet) insert(c net.Conn) { + tcs.mu.Lock() + defer tcs.mu.Unlock() + tcs.closed[c] = false + tcs.list = append(tcs.list, c) +} + +func (tcs *testConnSet) remove(c net.Conn) { + tcs.mu.Lock() + defer tcs.mu.Unlock() + tcs.closed[c] = true +} + +// some tests use this to manage raw tcp connections for later inspection +func makeTestDial(t *testing.T) (*testConnSet, func(ctx context.Context, n, addr string) (net.Conn, error)) { + connSet := &testConnSet{ + t: t, + closed: make(map[net.Conn]bool), + } + dial := func(_ context.Context, n, addr string) (net.Conn, error) { + c, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + tc := &testCloseConn{c, connSet} + connSet.insert(tc) + return tc, nil + } + return connSet, dial +} + +func (tcs *testConnSet) check(t *testing.T) { + tcs.mu.Lock() + defer tcs.mu.Unlock() + for i := 4; i >= 0; i-- { + for i, c := range tcs.list { + if tcs.closed[c] { + continue + } + if i != 0 { + tcs.mu.Unlock() + time.Sleep(50 * time.Millisecond) + tcs.mu.Lock() + continue + } + t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) + } + } +} + +func TestReuseRequest(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("{}")) + })) + defer ts.Close() + + c := tc().httpClient + req, _ := http.NewRequest("GET", ts.URL, nil) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + err = res.Body.Close() + if err != nil { + t.Fatal(err) + } + + res, err = c.Do(req) + if err != nil { + t.Fatal(err) + } + err = res.Body.Close() + if err != nil { + t.Fatal(err) + } +} + +// Two subsequent requests and verify their response is the same. +// The response from the server is our own IP:port +func TestTransportKeepAlives(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + c := tc().httpClient + for _, disableKeepAlive := range []bool{false, true} { + c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive + fetch := func(n int) string { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + + bodiesDiffer := body1 != body2 + if bodiesDiffer != disableKeepAlive { + t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + disableKeepAlive, bodiesDiffer, body1, body2) + } + } +} + +func interestingGoroutines() (gs []string) { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + for _, g := range strings.Split(string(buf), "\n\n") { + sl := strings.SplitN(g, "\n", 2) + if len(sl) != 2 { + continue + } + stack := strings.TrimSpace(sl[1]) + if stack == "" || + strings.Contains(stack, "testing.(*M).before.func1") || + strings.Contains(stack, "os/signal.signal_recv") || + strings.Contains(stack, "created by net.startServer") || + strings.Contains(stack, "created by testing.RunTests") || + strings.Contains(stack, "closeWriteAndWait") || + strings.Contains(stack, "testing.Main(") || + // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) + strings.Contains(stack, "runtime.goexit") || + strings.Contains(stack, "created by runtime.gc") || + strings.Contains(stack, "net/http_test.interestingGoroutines") || + strings.Contains(stack, "runtime.MHeap_Scavenger") { + continue + } + gs = append(gs, stack) + } + sort.Strings(gs) + return +} + +func afterTest(t testing.TB) { + http.DefaultTransport.(*http.Transport).CloseIdleConnections() + if testing.Short() { + return + } + // var bad string + // badSubstring := map[string]string{ + // ").readLoop(": "a Transport", + // ").writeLoop(": "a Transport", + // "created by net/http/httptest.(*Server).Start": "an httptest.Server", + // "timeoutHandler": "a TimeoutHandler", + // "net.(*netFD).connect(": "a timing out dial", + // ").noteClientGone(": "a closenotifier sender", + // } + // var stacks string + // for i := 0; i < 10; i++ { + // bad = "" + // stacks = strings.Join(interestingGoroutines(), "\n\n") + // for substr, what := range badSubstring { + // if strings.Contains(stacks, substr) { + // bad = what + // } + // } + // if bad == "" { + // return + // } + // // Bad stuff found, but goroutines might just still be + // // shutting down, so give it some time. + // time.Sleep(250 * time.Millisecond) + // } + // t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks) +} + +func TestTransportConnectionCloseOnResponse(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + connSet, testDial := makeTestDial(t) + + c := tc().httpClient + tr := c.Transport.(*Transport) + tr.DialContext = testDial + + for _, connectionClose := range []bool{false, true} { + fetch := func(n int) string { + req := new(http.Request) + var err error + req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose)) + if err != nil { + t.Fatalf("URL parse error: %v", err) + } + req.Method = "GET" + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + + res, err := c.Do(req) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) + } + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + bodiesDiffer := body1 != body2 + if bodiesDiffer != connectionClose { + t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + connectionClose, bodiesDiffer, body1, body2) + } + + tr.CloseIdleConnections() + } + + connSet.check(t) +} + +func TestTransportConnectionCloseOnRequest(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + connSet, testDial := makeTestDial(t) + + c := tc().httpClient + tr := c.Transport.(*Transport) + tr.DialContext = testDial + for _, connectionClose := range []bool{false, true} { + fetch := func(n int) string { + req := new(http.Request) + var err error + req.URL, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("URL parse error: %v", err) + } + req.Method = "GET" + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + req.Close = connectionClose + + res, err := c.Do(req) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) + } + if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(connectionClose); got != want { + t.Errorf("For connectionClose = %v; handler's X-Saw-Close was %v; want %v", + connectionClose, got, !connectionClose) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + bodiesDiffer := body1 != body2 + if bodiesDiffer != connectionClose { + t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + connectionClose, bodiesDiffer, body1, body2) + } + + tr.CloseIdleConnections() + } + + connSet.check(t) +} + +// if the Transport's DisableKeepAlives is set, all requests should +// send Connection: close. +// HTTP/1-only (Connection: close doesn't exist in h2) +func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + c := tc().httpClient + c.Transport.(*Transport).DisableKeepAlives = true + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.Header.Get("X-Saw-Close") != "true" { + t.Errorf("handler didn't see Connection: close ") + } +} + +// Test that Transport only sends one "Connection: close", regardless of +// how "close" was indicated. +func TestTransportRespectRequestWantsClose(t *testing.T) { + tests := []struct { + disableKeepAlives bool + close bool + }{ + {disableKeepAlives: false, close: false}, + {disableKeepAlives: false, close: true}, + {disableKeepAlives: true, close: false}, + {disableKeepAlives: true, close: true}, + } + + for _, testCase := range tests { + t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", testCase.disableKeepAlives, testCase.close), + func(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + c := tc().httpClient + c.Transport.(*Transport).DisableKeepAlives = testCase.disableKeepAlives + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + count := 0 + trace := &httptrace.ClientTrace{ + WroteHeaderField: func(key string, field []string) { + if key != "Connection" { + return + } + if httpguts.HeaderValuesContainsToken(field, "close") { + count += 1 + } + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + req.Close = testCase.close + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if want := testCase.disableKeepAlives || testCase.close; count > 1 || (count == 1) != want { + t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count) + } + }) + } + +} + +func TestTransportIdleCacheKeys(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + c := tc().httpClient + tr := c.Transport.(*Transport) + + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) + } + + resp, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + } + io.ReadAll(resp.Body) + + keys := tr.IdleConnKeysForTesting() + if e, g := 1, len(keys); e != g { + t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g) + } + + if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e { + t.Errorf("Expected idle cache key %q; got %q", e, keys[0]) + } + + tr.CloseIdleConnections() + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) + } +} + +// Tests that the HTTP transport re-uses connections when a client +// reads to the end of a response Body without closing it. +func TestTransportReadToEndReusesConn(t *testing.T) { + defer afterTest(t) + const msg = "foobar" + + var addrSeen map[string]int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + addrSeen[r.RemoteAddr]++ + if r.URL.Path == "/chunked/" { + w.WriteHeader(200) + w.(http.Flusher).Flush() + } else { + w.Header().Set("Content-Length", strconv.Itoa(len(msg))) + w.WriteHeader(200) + } + w.Write([]byte(msg)) + })) + defer ts.Close() + + buf := make([]byte, len(msg)) + + for pi, path := range []string{"/content-length/", "/chunked/"} { + wantLen := []int{len(msg), -1}[pi] + addrSeen = make(map[string]int) + for i := 0; i < 3; i++ { + res, err := http.Get(ts.URL + path) + if err != nil { + t.Errorf("Get %s: %v", path, err) + continue + } + // We want to close this body eventually (before the + // defer afterTest at top runs), but not before the + // len(addrSeen) check at the bottom of this test, + // since Closing this early in the loop would risk + // making connections be re-used for the wrong reason. + defer res.Body.Close() + + if res.ContentLength != int64(wantLen) { + t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) + } + n, err := res.Body.Read(buf) + if n != len(msg) || err != io.EOF { + t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg)) + } + } + if len(addrSeen) != 1 { + t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen)) + } + } +} + +func TestTransportMaxPerHostIdleConns(t *testing.T) { + defer afterTest(t) + stop := make(chan struct{}) // stop marks the exit of main Test goroutine + defer close(stop) + + resch := make(chan string) + gotReq := make(chan bool) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotReq <- true + var msg string + select { + case <-stop: + return + case msg = <-resch: + } + _, err := w.Write([]byte(msg)) + if err != nil { + t.Errorf("Write: %v", err) + return + } + })) + defer ts.Close() + + c := tc().httpClient + tr := c.Transport.(*Transport) + maxIdleConnsPerHost := 2 + tr.MaxIdleConnsPerHost = maxIdleConnsPerHost + + // Start 3 outstanding requests and wait for the server to get them. + // Their responses will hang until we write to resch, though. + donech := make(chan bool) + doReq := func() { + defer func() { + select { + case <-stop: + return + case donech <- t.Failed(): + } + }() + resp, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + return + } + if _, err := io.ReadAll(resp.Body); err != nil { + t.Errorf("ReadAll: %v", err) + return + } + } + go doReq() + <-gotReq + go doReq() + <-gotReq + go doReq() + <-gotReq + + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) + } + + resch <- "res1" + <-donech + keys := tr.IdleConnKeysForTesting() + if e, g := 1, len(keys); e != g { + t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g) + } + addr := ts.Listener.Addr().String() + cacheKey := "|http|" + addr + if keys[0] != cacheKey { + t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0]) + } + if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g { + t.Errorf("after first response, expected %d idle conns; got %d", e, g) + } + + resch <- "res2" + <-donech + if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w { + t.Errorf("after second response, idle conns = %d; want %d", g, w) + } + + resch <- "res3" + <-donech + if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w { + t.Errorf("after third response, idle conns = %d; want %d", g, w) + } +} + +func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + })) + defer ts.Close() + c := tc().httpClient + tr := c.Transport.(*Transport) + dialStarted := make(chan struct{}) + stallDial := make(chan struct{}) + tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { + dialStarted <- struct{}{} + <-stallDial + return net.Dial(network, addr) + } + + tr.DisableKeepAlives = true + tr.MaxConnsPerHost = 1 + + preDial := make(chan struct{}) + reqComplete := make(chan struct{}) + doReq := func(reqId string) { + req, _ := http.NewRequest("GET", ts.URL, nil) + trace := &httptrace.ClientTrace{ + GetConn: func(hostPort string) { + preDial <- struct{}{} + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + resp, err := tr.RoundTrip(req) + if err != nil { + t.Errorf("unexpected error for request %s: %v", reqId, err) + } + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Errorf("unexpected error for request %s: %v", reqId, err) + } + reqComplete <- struct{}{} + } + // get req1 to dial-in-progress + go doReq("req1") + <-preDial + <-dialStarted + + // get req2 to waiting on conns per host to go down below max + go doReq("req2") + <-preDial + select { + case <-dialStarted: + t.Error("req2 dial started while req1 dial in progress") + return + default: + } + + // let req1 complete + stallDial <- struct{}{} + <-reqComplete + + // let req2 complete + <-dialStarted + stallDial <- struct{}{} + <-reqComplete +} + +func TestTransportMaxConnsPerHost(t *testing.T) { + defer afterTest(t) + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + }) + + testMaxConns := func(scheme string, ts *httptest.Server) { + defer ts.Close() + + c := tc().httpClient + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 + + mu := sync.Mutex{} + var conns []net.Conn + var dialCnt, gotConnCnt, tlsHandshakeCnt int32 + tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { + atomic.AddInt32(&dialCnt, 1) + c, err := net.Dial(network, addr) + mu.Lock() + defer mu.Unlock() + conns = append(conns, c) + return c, err + } + + doReq := func() { + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + if !connInfo.Reused { + atomic.AddInt32(&gotConnCnt, 1) + } + }, + TLSHandshakeStart: func() { + atomic.AddInt32(&tlsHandshakeCnt, 1) + }, + } + req, _ := http.NewRequest("GET", ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + resp, err := c.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body failed: %v", err) + } + } + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + doReq() + }() + } + wg.Wait() + + expected := int32(tr.MaxConnsPerHost) + if dialCnt != expected { + t.Errorf("round 1: too many dials (%s): %d != %d", scheme, dialCnt, expected) + } + if gotConnCnt != expected { + t.Errorf("round 1: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("round 1: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) + } + + if t.Failed() { + t.FailNow() + } + + mu.Lock() + for _, c := range conns { + c.Close() + } + conns = nil + mu.Unlock() + tr.CloseIdleConnections() + + doReq() + expected++ + if dialCnt != expected { + t.Errorf("round 2: too many dials (%s): %d", scheme, dialCnt) + } + if gotConnCnt != expected { + t.Errorf("round 2: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("round 2: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) + } + } + + testMaxConns("http", httptest.NewServer(h)) + testMaxConns("https", httptest.NewTLSServer(h)) + + ts := httptest.NewUnstartedServer(h) + ts.TLS = &tls.Config{NextProtos: []string{"h2"}} + ts.StartTLS() + testMaxConns("http2", ts) +} + +func TestTransportRemovesDeadIdleConnections(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.RemoteAddr) + })) + defer ts.Close() + + c := tc().httpClient + tr := c.Transport.(*Transport) + + doReq := func(name string) string { + // Do a POST instead of a GET to prevent the Transport's + // idempotent request retry logic from kicking in... + res, err := c.Post(ts.URL, "", nil) + if err != nil { + t.Fatalf("%s: %v", name, err) + } + if res.StatusCode != 200 { + t.Fatalf("%s: %v", name, res.Status) + } + defer res.Body.Close() + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("%s: %v", name, err) + } + return string(slurp) + } + + first := doReq("first") + keys1 := tr.IdleConnKeysForTesting() + + ts.CloseClientConnections() + + var keys2 []string + if !waitCondition(3*time.Second, 50*time.Millisecond, func() bool { + keys2 = tr.IdleConnKeysForTesting() + return len(keys2) == 0 + }) { + t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keys1, keys2) + } + + second := doReq("second") + if first == second { + t.Errorf("expected a different connection between requests. got %q both times", first) + } +} + +// ExportCloseTransportConnsAbruptly closes all idle connections from +// tr in an abrupt way, just reaching into the underlying Conns and +// closing them, without telling the Transport or its persistConns +// that it's doing so. This is to simulate the server closing connections +// on the Transport. +func ExportCloseTransportConnsAbruptly(tr *Transport) { + tr.idleMu.Lock() + for _, pcs := range tr.idleConn { + for _, pc := range pcs { + pc.conn.Close() + } + } + tr.idleMu.Unlock() +} + +// Test that the Transport notices when a server hangs up on its +// unexpectedly (a keep-alive connection is closed). +func TestTransportServerClosingUnexpectedly(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + c := tc().httpClient + + fetch := func(n, retries int) string { + condFatalf := func(format string, arg ...interface{}) { + if retries <= 0 { + t.Fatalf(format, arg...) + } + t.Logf("retrying shortly after expected error: "+format, arg...) + time.Sleep(time.Second / time.Duration(retries)) + } + for retries >= 0 { + retries-- + res, err := c.Get(ts.URL) + if err != nil { + condFatalf("error in req #%d, GET: %v", n, err) + continue + } + body, err := io.ReadAll(res.Body) + if err != nil { + condFatalf("error in req #%d, ReadAll: %v", n, err) + continue + } + res.Body.Close() + return string(body) + } + panic("unreachable") + } + + body1 := fetch(1, 0) + body2 := fetch(2, 0) + + // Close all the idle connections in a way that's similar to + // the server hanging up on us. We don't use + // httptest.Server.CloseClientConnections because it's + // best-effort and stops blocking after 5 seconds. On a loaded + // machine running many tests concurrently it's possible for + // that method to be async and cause the body3 fetch below to + // run on an old connection. This function is synchronous. + ExportCloseTransportConnsAbruptly(c.Transport.(*Transport)) + + body3 := fetch(3, 5) + + if body1 != body2 { + t.Errorf("expected body1 and body2 to be equal") + } + if body2 == body3 { + t.Errorf("expected body2 and body3 to be different") + } +} + +// Test for https://golang.org/issue/2616 (appropriate issue number) +// This fails pretty reliably with GOMAXPROCS=100 or something high. +func TestStressSurpriseServerCloses(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in short mode") + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "5") + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte("Hello")) + w.(http.Flusher).Flush() + conn, buf, _ := w.(http.Hijacker).Hijack() + buf.Flush() + conn.Close() + })) + defer ts.Close() + c := tc().httpClient + + // Do a bunch of traffic from different goroutines. Send to activityc + // after each request completes, regardless of whether it failed. + // If these are too high, OS X exhausts its ephemeral ports + // and hangs waiting for them to transition TCP states. That's + // not what we want to test. TODO(bradfitz): use an io.Pipe + // dialer for this test instead? + const ( + numClients = 20 + reqsPerClient = 25 + ) + activityc := make(chan bool) + for i := 0; i < numClients; i++ { + go func() { + for i := 0; i < reqsPerClient; i++ { + res, err := c.Get(ts.URL) + if err == nil { + // We expect errors since the server is + // hanging up on us after telling us to + // send more requests, so we don't + // actually care what the error is. + // But we want to close the body in cases + // where we won the race. + res.Body.Close() + } + if !<-activityc { // Receives false when close(activityc) is executed + return + } + } + }() + } + + // Make sure all the request come back, one way or another. + for i := 0; i < numClients*reqsPerClient; i++ { + select { + case activityc <- true: + case <-time.After(5 * time.Second): + close(activityc) + t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile") + } + } +} + +// TestTransportHeadResponses verifies that we deal with Content-Lengths +// with no bodies properly +func TestTransportHeadResponses(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "HEAD" { + panic("expected HEAD; got " + r.Method) + } + w.Header().Set("Content-Length", "123") + w.WriteHeader(200) + })) + defer ts.Close() + c := tc().httpClient + + for i := 0; i < 2; i++ { + res, err := c.Head(ts.URL) + if err != nil { + t.Errorf("error on loop %d: %v", i, err) + continue + } + if e, g := "123", res.Header.Get("Content-Length"); e != g { + t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) + } + if e, g := int64(123), res.ContentLength; e != g { + t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) + } + if all, err := io.ReadAll(res.Body); err != nil { + t.Errorf("loop %d: Body ReadAll: %v", i, err) + } else if len(all) != 0 { + t.Errorf("Bogus body %q", all) + } + } +} + +// All test hooks must be non-nil so they can be called directly, +// but the tests use nil to mean hook disabled. +func unnilTestHook(f *func()) { + if *f == nil { + *f = nop + } +} + +func SetReadLoopBeforeNextReadHook(f func()) { + testHookMu.Lock() + defer testHookMu.Unlock() + unnilTestHook(&f) + testHookReadLoopBeforeNextRead = f +} + +// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding +// on responses to HEAD requests. +func TestTransportHeadChunkedResponse(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "HEAD" { + panic("expected HEAD; got " + r.Method) + } + w.Header().Set("Transfer-Encoding", "chunked") // client should ignore + w.Header().Set("x-client-ipport", r.RemoteAddr) + w.WriteHeader(200) + })) + defer ts.Close() + c := tc().httpClient + + // Ensure that we wait for the readLoop to complete before + // calling Head again + didRead := make(chan bool) + SetReadLoopBeforeNextReadHook(func() { didRead <- true }) + defer SetReadLoopBeforeNextReadHook(nil) + + res1, err := c.Head(ts.URL) + <-didRead + + if err != nil { + t.Fatalf("request 1 error: %v", err) + } + + res2, err := c.Head(ts.URL) + <-didRead + + if err != nil { + t.Fatalf("request 2 error: %v", err) + } + if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 { + t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2) + } +} + +var roundTripTests = []struct { + accept string + expectAccept string + compressed bool +}{ + // Requests with no accept-encoding header use transparent compression + {"", "gzip", false}, + // Requests with other accept-encoding should pass through unmodified + {"foo", "foo", false}, + // Requests with accept-encoding == gzip should be passed through + {"gzip", "gzip", true}, +} + +// Test that the modification made to the Request by the http.RoundTripper is cleaned up +func TestRoundTripGzip(t *testing.T) { + setParallel(t) + defer afterTest(t) + const responseBody = "test response body" + ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + accept := req.Header.Get("Accept-Encoding") + if expect := req.FormValue("expect_accept"); accept != expect { + t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", + req.FormValue("testnum"), accept, expect) + } + if accept == "gzip" { + rw.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(rw) + gz.Write([]byte(responseBody)) + gz.Close() + } else { + rw.Header().Set("Content-Encoding", accept) + rw.Write([]byte(responseBody)) + } + })) + defer ts.Close() + tr := tc().t + + for i, test := range roundTripTests { + // Test basic request (no accept-encoding) + req, _ := http.NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) + if test.accept != "" { + req.Header.Set("Accept-Encoding", test.accept) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Errorf("%d. RoundTrip: %v", i, err) + continue + } + var body []byte + if test.compressed { + var r *gzip.Reader + r, err = gzip.NewReader(res.Body) + if err != nil { + t.Errorf("%d. gzip NewReader: %v", i, err) + continue + } + body, err = io.ReadAll(r) + res.Body.Close() + } else { + body, err = io.ReadAll(res.Body) + } + if err != nil { + t.Errorf("%d. Error: %q", i, err) + continue + } + if g, e := string(body), responseBody; g != e { + t.Errorf("%d. body = %q; want %q", i, g, e) + } + if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { + t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) + } + if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { + t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) + } + } + +} + +func TestTransportGzip(t *testing.T) { + setParallel(t) + defer afterTest(t) + const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + const nRandBytes = 1024 * 1024 + ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Method == "HEAD" { + if g := req.Header.Get("Accept-Encoding"); g != "" { + t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) + } + return + } + if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { + t.Errorf("Accept-Encoding = %q, want %q", g, e) + } + rw.Header().Set("Content-Encoding", "gzip") + + var w io.Writer = rw + var buf bytes.Buffer + if req.FormValue("chunked") == "0" { + w = &buf + defer io.Copy(rw, &buf) + defer func() { + rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) + }() + } + gz := gzip.NewWriter(w) + gz.Write([]byte(testString)) + if req.FormValue("body") == "large" { + io.CopyN(gz, rand.Reader, nRandBytes) + } + gz.Close() + })) + defer ts.Close() + c := tc().httpClient + + for _, chunked := range []string{"1", "0"} { + // First fetch something large, but only read some of it. + res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) + if err != nil { + t.Fatalf("large get: %v", err) + } + buf := make([]byte, len(testString)) + n, err := io.ReadFull(res.Body, buf) + if err != nil { + t.Fatalf("partial read of large response: size=%d, %v", n, err) + } + if e, g := testString, string(buf); e != g { + t.Errorf("partial read got %q, expected %q", g, e) + } + res.Body.Close() + // Read on the body, even though it's closed + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) + } + + // Then something small. + res, err = c.Get(ts.URL + "/?chunked=" + chunked) + if err != nil { + t.Fatal(err) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if g, e := string(body), testString; g != e { + t.Fatalf("body = %q; want %q", g, e) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } + + // Read on the body after it's been fully read: + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) + } + res.Body.Close() + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after Close; got %d, %v", n, err) + } + } + + // And a HEAD request too, because they're always weird. + res, err := c.Head(ts.URL) + if err != nil { + t.Fatalf("Head: %v", err) + } + if res.StatusCode != 200 { + t.Errorf("Head status=%d; want=200", res.StatusCode) + } +} + +// setParallel marks t as a parallel test if we're in short mode +// (all.bash), but as a serial test otherwise. Using t.Parallel isn't +// compatible with the afterTest func in non-short mode. +func setParallel(t *testing.T) { + if testing.Short() { + t.Parallel() + } +} + +// If a request has Expect:100-continue header, the request blocks sending body until the first response. +// Premature consumption of the request body should not be occurred. +func TestTransportExpect100Continue(t *testing.T) { + setParallel(t) + defer afterTest(t) + + ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + switch req.URL.Path { + case "/100": + // This endpoint implicitly responds 100 Continue and reads body. + if _, err := io.Copy(io.Discard, req.Body); err != nil { + t.Error("Failed to read Body", err) + } + rw.WriteHeader(http.StatusOK) + case "/200": + // Go 1.5 adds Connection: close header if the client expect + // continue but not entire request body is consumed. + rw.WriteHeader(http.StatusOK) + case "/500": + rw.WriteHeader(http.StatusInternalServerError) + case "/keepalive": + // This hijacked endpoint responds error without Connection:close. + _, bufrw, err := rw.(http.Hijacker).Hijack() + if err != nil { + log.Fatal(err) + } + bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n") + bufrw.WriteString("Content-Length: 0\r\n\r\n") + bufrw.Flush() + case "/timeout": + // This endpoint tries to read body without 100 (Continue) response. + // After ExpectContinueTimeout, the reading will be started. + conn, bufrw, err := rw.(http.Hijacker).Hijack() + if err != nil { + log.Fatal(err) + } + if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil { + t.Error("Failed to read Body", err) + } + bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") + bufrw.Flush() + conn.Close() + } + + })) + defer ts.Close() + + tests := []struct { + path string + body []byte + sent int + status int + }{ + {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. + {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. + {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. + {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent. + {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. + } + + c := tc().httpClient + for i, v := range tests { + tr := &Transport{ + ExpectContinueTimeout: 2 * time.Second, + } + defer tr.CloseIdleConnections() + c.Transport = tr + body := bytes.NewReader(v.body) + req, err := http.NewRequest("PUT", ts.URL+v.path, body) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Expect", "100-continue") + req.ContentLength = int64(len(v.body)) + + resp, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + + sent := len(v.body) - body.Len() + if v.status != resp.StatusCode { + t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path) + } + if v.sent != sent { + t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path) + } + } +} + +func TestSOCKS5Proxy(t *testing.T) { + defer afterTest(t) + ch := make(chan string, 1) + l := newLocalListener(t) + defer l.Close() + defer close(ch) + proxy := func(t *testing.T) { + s, err := l.Accept() + if err != nil { + t.Errorf("socks5 proxy Accept(): %v", err) + return + } + defer s.Close() + var buf [22]byte + if _, err := io.ReadFull(s, buf[:3]); err != nil { + t.Errorf("socks5 proxy initial read: %v", err) + return + } + if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { + t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want) + return + } + if _, err := s.Write([]byte{5, 0}); err != nil { + t.Errorf("socks5 proxy initial write: %v", err) + return + } + if _, err := io.ReadFull(s, buf[:4]); err != nil { + t.Errorf("socks5 proxy second read: %v", err) + return + } + if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { + t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want) + return + } + var ipLen int + switch buf[3] { + case 1: + ipLen = net.IPv4len + case 4: + ipLen = net.IPv6len + default: + t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4]) + return + } + if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil { + t.Errorf("socks5 proxy address read: %v", err) + return + } + ip := net.IP(buf[4 : ipLen+4]) + port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6]) + copy(buf[:3], []byte{5, 0, 0}) + if _, err := s.Write(buf[:ipLen+6]); err != nil { + t.Errorf("socks5 proxy connect write: %v", err) + return + } + ch <- fmt.Sprintf("proxy for %s:%d", ip, port) + + // Implement proxying. + targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) + targetConn, err := net.Dial("tcp", targetHost) + if err != nil { + t.Errorf("net.Dial failed") + return + } + go io.Copy(targetConn, s) + io.Copy(s, targetConn) // Wait for the client to close the socket. + targetConn.Close() + } + + pu, err := url.Parse("socks5://" + l.Addr().String()) + if err != nil { + t.Fatal(err) + } + + sentinelHeader := "X-Sentinel" + sentinelValue := "12345" + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(sentinelHeader, sentinelValue) + }) + for _, useTLS := range []bool{false, true} { + t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { + var ts *httptest.Server + if useTLS { + ts = httptest.NewTLSServer(h) + } else { + ts = httptest.NewServer(h) + } + go proxy(t) + c := tc().httpClient + c.Transport.(*Transport).Proxy = http.ProxyURL(pu) + r, err := c.Head(ts.URL) + if err != nil { + t.Fatal(err) + } + if r.Header.Get(sentinelHeader) != sentinelValue { + t.Errorf("Failed to retrieve sentinel value") + } + var got string + select { + case got = <-ch: + case <-time.After(5 * time.Second): + t.Fatal("timeout connecting to socks5 proxy") + } + ts.Close() + tsu, err := url.Parse(ts.URL) + if err != nil { + t.Fatal(err) + } + want := "proxy for " + tsu.Host + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + } +} + +func TestTransportProxy(t *testing.T) { + defer afterTest(t) + testCases := []struct{ httpsSite, httpsProxy bool }{ + {false, false}, + {false, true}, + {true, false}, + {true, true}, + } + for _, testCase := range testCases { + httpsSite := testCase.httpsSite + httpsProxy := testCase.httpsProxy + t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) { + siteCh := make(chan *http.Request, 1) + h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siteCh <- r + }) + proxyCh := make(chan *http.Request, 1) + h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyCh <- r + // Implement an entire CONNECT proxy + if r.Method == "CONNECT" { + hijacker, ok := w.(http.Hijacker) + if !ok { + t.Errorf("hijack not allowed") + return + } + clientConn, _, err := hijacker.Hijack() + if err != nil { + t.Errorf("hijacking failed") + return + } + res := &http.Response{ + StatusCode: http.StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + } + + targetConn, err := net.Dial("tcp", r.URL.Host) + if err != nil { + t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) + return + } + + if err := res.Write(clientConn); err != nil { + t.Errorf("Writing 200 OK failed: %v", err) + return + } + + go io.Copy(targetConn, clientConn) + go func() { + io.Copy(clientConn, targetConn) + targetConn.Close() + }() + } + }) + var ts *httptest.Server + if httpsSite { + ts = httptest.NewTLSServer(h1) + } else { + ts = httptest.NewServer(h1) + } + var proxy *httptest.Server + if httpsProxy { + proxy = httptest.NewTLSServer(h2) + } else { + proxy = httptest.NewServer(h2) + } + + pu, err := url.Parse(proxy.URL) + if err != nil { + t.Fatal(err) + } + + // If neither server is HTTPS or both are, then c may be derived from either. + // If only one server is HTTPS, c must be derived from that server in order + // to ensure that it is configured to use the fake root CA from testcert.go. + c := tc().httpClient + + c.Transport.(*Transport).Proxy = http.ProxyURL(pu) + if _, err := c.Head(ts.URL); err != nil { + t.Error(err) + } + var got *http.Request + select { + case got = <-proxyCh: + case <-time.After(5 * time.Second): + t.Fatal("timeout connecting to http proxy") + } + c.Transport.(*Transport).CloseIdleConnections() + ts.Close() + proxy.Close() + if httpsSite { + // First message should be a CONNECT, asking for a socket to the real server, + if got.Method != "CONNECT" { + t.Errorf("Wrong method for secure proxying: %q", got.Method) + } + gotHost := got.URL.Host + pu, err := url.Parse(ts.URL) + if err != nil { + t.Fatal("Invalid site URL") + } + if wantHost := pu.Host; gotHost != wantHost { + t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost) + } + + // The next message on the channel should be from the site's server. + next := <-siteCh + if next.Method != "HEAD" { + t.Errorf("Wrong method at destination: %s", next.Method) + } + if nextURL := next.URL.String(); nextURL != "/" { + t.Errorf("Wrong URL at destination: %s", nextURL) + } + } else { + if got.Method != "HEAD" { + t.Errorf("Wrong method for destination: %q", got.Method) + } + gotURL := got.URL.String() + wantURL := ts.URL + "/" + if gotURL != wantURL { + t.Errorf("Got URL %q, want %q", gotURL, wantURL) + } + } + }) + } +} + +// Issue 28012: verify that the Transport closes its TCP connection to http proxies +// when they're slow to reply to HTTPS CONNECT responses. +func TestTransportProxyHTTPSConnectLeak(t *testing.T) { + setParallel(t) + defer afterTest(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := newLocalListener(t) + defer ln.Close() + listenerDone := make(chan struct{}) + go func() { + defer close(listenerDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Accept: %v", err) + return + } + defer c.Close() + // Read the CONNECT request + br := bufio.NewReader(c) + cr, err := http.ReadRequest(br) + if err != nil { + t.Errorf("proxy server failed to read CONNECT request") + return + } + if cr.Method != "CONNECT" { + t.Errorf("unexpected method %q", cr.Method) + return + } + + // Now hang and never write a response; instead, cancel the request and wait + // for the client to close. + // (Prior to Issue 28012 being fixed, we never closed.) + cancel() + var buf [1]byte + _, err = br.Read(buf[:]) + if err != io.EOF { + t.Errorf("proxy server Read err = %v; want EOF", err) + } + return + }() + + c := &http.Client{ + Transport: &Transport{ + Proxy: func(*http.Request) (*url.URL, error) { + return url.Parse("http://" + ln.Addr().String()) + }, + }, + } + req, err := http.NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil) + if err != nil { + t.Fatal(err) + } + _, err = c.Do(req) + if err == nil { + t.Errorf("unexpected Get success") + } + + // Wait unconditionally for the listener goroutine to exit: this should never + // hang, so if it does we want a full goroutine dump — and that's exactly what + // the testing package will give us when the test run times out. + <-listenerDone +} + +// Issue 16997: test transport dial preserves typed errors +func TestTransportDialPreservesNetOpProxyError(t *testing.T) { + defer afterTest(t) + + var errDial = errors.New("some dial error") + + tr := &Transport{ + Proxy: func(*http.Request) (*url.URL, error) { + return url.Parse("http://proxy.fake.tld/") + }, + DialContext: func(context.Context, string, string) (net.Conn, error) { + return nil, errDial + }, + } + defer tr.CloseIdleConnections() + + c := &http.Client{Transport: tr} + req, _ := http.NewRequest("GET", "http://fake.tld", nil) + res, err := c.Do(req) + if err == nil { + res.Body.Close() + t.Fatal("wanted a non-nil error") + } + + uerr, ok := err.(*url.Error) + if !ok { + t.Fatalf("got %T, want *url.Error", err) + } + oe, ok := uerr.Err.(*net.OpError) + if !ok { + t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err) + } + want := &net.OpError{ + Op: "proxyconnect", + Net: "tcp", + Err: errDial, // original error, unwrapped. + } + if !reflect.DeepEqual(oe, want) { + t.Errorf("Got error %#v; want %#v", oe, want) + } +} + +// Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader. +// +// (A bug caused dialConn to instead write the per-request Proxy-Authorization +// header through to the shared Header instance, introducing a data race.) +func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { + setParallel(t) + defer afterTest(t) + + proxy := httptest.NewTLSServer(http.NotFoundHandler()) + defer proxy.Close() + c := tc().httpClient + + tr := c.Transport.(*Transport) + tr.Proxy = func(*http.Request) (*url.URL, error) { + u, _ := url.Parse(proxy.URL) + u.User = url.UserPassword("aladdin", "opensesame") + return u, nil + } + h := tr.ProxyConnectHeader + if h == nil { + h = make(http.Header) + } + tr.ProxyConnectHeader = h.Clone() + + req, err := http.NewRequest("GET", "https://golang.fake.tld/", nil) + if err != nil { + t.Fatal(err) + } + _, err = c.Do(req) + if err == nil { + t.Errorf("unexpected Get success") + } + + if !reflect.DeepEqual(tr.ProxyConnectHeader, h) { + t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h) + } +} + +// TestTransportGzipRecursive sends a gzip quine and checks that the +// client gets the same value back. This is more cute than anything, +// but checks that we don't recurse forever, and checks that +// Content-Encoding is removed. +func TestTransportGzipRecursive(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write(rgz) + })) + defer ts.Close() + + c := tc().httpClient + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(body, rgz) { + t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", + body, rgz) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } +} + +// golang.org/issue/7750: request fails when server replies with +// a short gzip body +func TestTransportGzipShort(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write([]byte{0x1f, 0x8b}) + })) + defer ts.Close() + + c := tc().httpClient + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + _, err = io.ReadAll(res.Body) + if err == nil { + t.Fatal("Expect an error from reading a body.") + } + if err != io.ErrUnexpectedEOF { + t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err) + } +} + +// Wait until number of goroutines is no greater than nmax, or time out. +func waitNumGoroutine(nmax int) int { + nfinal := runtime.NumGoroutine() + for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- { + time.Sleep(50 * time.Millisecond) + runtime.GC() + nfinal = runtime.NumGoroutine() + } + return nfinal +} + +// tests that persistent goroutine connections shut down when no longer desired. +func TestTransportPersistConnLeak(t *testing.T) { + // Not parallel: counts goroutines + defer afterTest(t) + + const numReq = 25 + gotReqCh := make(chan bool, numReq) + unblockCh := make(chan bool, numReq) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotReqCh <- true + <-unblockCh + w.Header().Set("Content-Length", "0") + w.WriteHeader(204) + })) + defer ts.Close() + c := tc().httpClient + tr := c.Transport.(*Transport) + + n0 := runtime.NumGoroutine() + + didReqCh := make(chan bool, numReq) + failed := make(chan bool, numReq) + for i := 0; i < numReq; i++ { + go func() { + res, err := c.Get(ts.URL) + didReqCh <- true + if err != nil { + t.Logf("client fetch error: %v", err) + failed <- true + return + } + res.Body.Close() + }() + } + + // Wait for all goroutines to be stuck in the Handler. + for i := 0; i < numReq; i++ { + select { + case <-gotReqCh: + // ok + case <-failed: + // Not great but not what we are testing: + // sometimes an overloaded system will fail to make all the connections. + } + } + + nhigh := runtime.NumGoroutine() + + // Tell all handlers to unblock and reply. + close(unblockCh) + + // Wait for all HTTP clients to be done. + for i := 0; i < numReq; i++ { + <-didReqCh + } + + tr.CloseIdleConnections() + nfinal := waitNumGoroutine(n0 + 5) + + growth := nfinal - n0 + + // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. + // Previously we were leaking one per numReq. + if int(growth) > 5 { + t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) + t.Error("too many new goroutines") + } +} + +// golang.org/issue/4531: Transport leaks goroutines when +// request.ContentLength is explicitly short +func TestTransportPersistConnLeakShortBody(t *testing.T) { + // Not parallel: measures goroutines. + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + c := tc().httpClient + tr := c.Transport.(*Transport) + + n0 := runtime.NumGoroutine() + body := []byte("Hello") + for i := 0; i < 20; i++ { + req, err := http.NewRequest("POST", ts.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.ContentLength = int64(len(body) - 2) // explicitly short + _, err = c.Do(req) + if err == nil { + t.Fatal("Expect an error from writing too long of a body.") + } + } + nhigh := runtime.NumGoroutine() + tr.CloseIdleConnections() + nfinal := waitNumGoroutine(n0 + 5) + + growth := nfinal - n0 + + // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. + // Previously we were leaking one per numReq. + t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) + if int(growth) > 5 { + t.Error("too many new goroutines") + } +} + +// A countedConn is a net.Conn that decrements an atomic counter when finalized. +type countedConn struct { + net.Conn +} + +// A countingDialer dials connections and counts the number that remain reachable. +type countingDialer struct { + dialer net.Dialer + mu sync.Mutex + total, live int64 +} + +func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := d.dialer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + + counted := new(countedConn) + counted.Conn = conn + + d.mu.Lock() + defer d.mu.Unlock() + d.total++ + d.live++ + + runtime.SetFinalizer(counted, d.decrement) + return counted, nil +} + +func (d *countingDialer) decrement(*countedConn) { + d.mu.Lock() + defer d.mu.Unlock() + d.live-- +} + +func (d *countingDialer) Read() (total, live int64) { + d.mu.Lock() + defer d.mu.Unlock() + return d.total, d.live +} + +func TestTransportPersistConnLeakNeverIdle(t *testing.T) { + defer afterTest(t) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Close every connection so that it cannot be kept alive. + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack failed unexpectedly: %v", err) + return + } + conn.Close() + })) + defer ts.Close() + + var d countingDialer + c := tc().httpClient + c.Transport.(*Transport).DialContext = d.DialContext + + body := []byte("Hello") + for i := 0; ; i++ { + total, live := d.Read() + if live < total { + break + } + if i >= 1<<12 { + t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i) + } + + req, err := http.NewRequest("POST", ts.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + _, err = c.Do(req) + if err == nil { + t.Fatal("expected broken connection") + } + + runtime.GC() + } +} + +type countedContext struct { + context.Context +} + +type contextCounter struct { + mu sync.Mutex + live int64 +} + +func (cc *contextCounter) Track(ctx context.Context) context.Context { + counted := new(countedContext) + counted.Context = ctx + cc.mu.Lock() + defer cc.mu.Unlock() + cc.live++ + runtime.SetFinalizer(counted, cc.decrement) + return counted +} + +func (cc *contextCounter) decrement(*countedContext) { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.live-- +} + +func (cc *contextCounter) Read() (live int64) { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.live +} + +// This used to crash; https://golang.org/issue/3266 +func TestTransportIdleConnCrash(t *testing.T) { + defer afterTest(t) + var tr *Transport + + unblockCh := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-unblockCh + tr.CloseIdleConnections() + })) + defer ts.Close() + c := tc().httpClient + tr = c.Transport.(*Transport) + + didreq := make(chan bool) + go func() { + res, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + } else { + res.Body.Close() // returns idle conn + } + didreq <- true + }() + unblockCh <- true + <-didreq +} + +// Test that the transport doesn't close the TCP connection early, +// before the response body has been read. This was a regression +// which sadly lacked a triggering test. The large response body made +// the old race easier to trigger. +func TestIssue3644(t *testing.T) { + defer afterTest(t) + const numFoos = 5000 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Connection", "close") + for i := 0; i < numFoos; i++ { + w.Write([]byte("foo ")) + } + })) + defer ts.Close() + c := tc().httpClient + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + bs, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if len(bs) != numFoos*len("foo ") { + t.Errorf("unexpected response length") + } +} + +// Test that a client receives a server's reply, even if the server doesn't read +// the entire request body. +func TestIssue3595(t *testing.T) { + setParallel(t) + defer afterTest(t) + const deniedMsg = "sorry, denied." + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, deniedMsg, http.StatusUnauthorized) + })) + defer ts.Close() + c := tc().httpClient + res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) + if err != nil { + t.Errorf("Post: %v", err) + return + } + got, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Body ReadAll: %v", err) + } + if !strings.Contains(string(got), deniedMsg) { + t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) + } +} + +// From https://golang.org/issue/4454 , +// "client fails to handle requests with no body and chunked encoding" +func TestChunkedNoContent(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer ts.Close() + + c := tc().httpClient + for _, closeBody := range []bool{true, false} { + const n = 4 + for i := 1; i <= n; i++ { + res, err := c.Get(ts.URL) + if err != nil { + t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err) + } else { + if closeBody { + res.Body.Close() + } + } + } + } +} + +// SetPendingDialHooks sets the hooks that run before and after handling +// pending dials. +func SetPendingDialHooks(before, after func()) { + unnilTestHook(&before) + unnilTestHook(&after) + testHookPrePendingDial, testHookPostPendingDial = before, after +} + +func TestTransportConcurrency(t *testing.T) { + // Not parallel: uses global test hooks. + defer afterTest(t) + maxProcs, numReqs := 16, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "%v", r.FormValue("echo")) + })) + defer ts.Close() + + var wg sync.WaitGroup + wg.Add(numReqs) + + // Due to the Transport's "socket late binding" (see + // idleConnCh in transport.go), the numReqs HTTP requests + // below can finish with a dial still outstanding. To keep + // the leak checker happy, keep track of pending dials and + // wait for them to finish (and be closed or returned to the + // idle pool) before we close idle connections. + SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) + defer SetPendingDialHooks(nil, nil) + + c := tc().httpClient + reqs := make(chan string) + defer close(reqs) + + for i := 0; i < maxProcs*2; i++ { + go func() { + for req := range reqs { + res, err := c.Get(ts.URL + "/?echo=" + req) + if err != nil { + t.Errorf("error on req %s: %v", req, err) + wg.Done() + continue + } + all, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("read error on req %s: %v", req, err) + wg.Done() + continue + } + if string(all) != req { + t.Errorf("body of req %s = %q; want %q", req, all, req) + } + res.Body.Close() + wg.Done() + } + }() + } + for i := 0; i < numReqs; i++ { + reqs <- fmt.Sprintf("request-%d", i) + } + wg.Wait() +} + +// loggingConn is used for debugging. +type loggingConn struct { + name string + net.Conn +} + +var ( + uniqNameMu sync.Mutex + uniqNameNext = make(map[string]int) +) + +func newLoggingConn(baseName string, c net.Conn) net.Conn { + uniqNameMu.Lock() + defer uniqNameMu.Unlock() + uniqNameNext[baseName]++ + return &loggingConn{ + name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]), + Conn: c, + } +} + +func (c *loggingConn) Write(p []byte) (n int, err error) { + log.Printf("%s.Write(%d) = ....", c.name, len(p)) + n, err = c.Conn.Write(p) + log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Read(p []byte) (n int, err error) { + log.Printf("%s.Read(%d) = ....", c.name, len(p)) + n, err = c.Conn.Read(p) + log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Close() (err error) { + log.Printf("%s.Close() = ...", c.name) + err = c.Conn.Close() + log.Printf("%s.Close() = %v", c.name, err) + return +} + +func TestIssue4191_InfiniteGetTimeout(t *testing.T) { + setParallel(t) + defer afterTest(t) + const debug = false + mux := http.NewServeMux() + mux.HandleFunc("/get", func(w http.ResponseWriter, r *http.Request) { + io.Copy(w, neverEnding('a')) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + timeout := 100 * time.Millisecond + + c := tc().httpClient + c.Transport.(*Transport).DialContext = func(_ context.Context, n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = newLoggingConn("client", conn) + } + return conn, nil + } + + getFailed := false + nRuns := 5 + if testing.Short() { + nRuns = 1 + } + for i := 0; i < nRuns; i++ { + if debug { + println("run", i+1, "of", nRuns) + } + sres, err := c.Get(ts.URL + "/get") + if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } + t.Errorf("Error issuing GET: %v", err) + break + } + _, err = io.Copy(io.Discard, sres.Body) + if err == nil { + t.Errorf("Unexpected successful copy") + break + } + } + if debug { + println("tests complete; waiting for handlers to finish") + } +} + +func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { + setParallel(t) + defer afterTest(t) + const debug = false + mux := http.NewServeMux() + mux.HandleFunc("/get", func(w http.ResponseWriter, r *http.Request) { + io.Copy(w, neverEnding('a')) + }) + mux.HandleFunc("/put", func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + io.Copy(io.Discard, r.Body) + }) + ts := httptest.NewServer(mux) + timeout := 100 * time.Millisecond + + c := tc().httpClient + c.Transport.(*Transport).DialContext = func(_ context.Context, n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = newLoggingConn("client", conn) + } + return conn, nil + } + + getFailed := false + nRuns := 5 + if testing.Short() { + nRuns = 1 + } + for i := 0; i < nRuns; i++ { + if debug { + println("run", i+1, "of", nRuns) + } + sres, err := c.Get(ts.URL + "/get") + if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } + t.Errorf("Error issuing GET: %v", err) + break + } + req, _ := http.NewRequest("PUT", ts.URL+"/put", sres.Body) + _, err = c.Do(req) + if err == nil { + sres.Body.Close() + t.Errorf("Unexpected successful PUT") + break + } + sres.Body.Close() + } + if debug { + println("tests complete; waiting for handlers to finish") + } + ts.Close() +} + +func reqWithT(r *http.Request, t *testing.T) *http.Request { + return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) +} + +func TestTransportResponseHeaderTimeout(t *testing.T) { + setParallel(t) + defer afterTest(t) + if testing.Short() { + t.Skip("skipping timeout test in -short mode") + } + inHandler := make(chan bool, 1) + mux := http.NewServeMux() + mux.HandleFunc("/fast", func(w http.ResponseWriter, r *http.Request) { + inHandler <- true + }) + mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) { + inHandler <- true + time.Sleep(2 * time.Second) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + c := tc().httpClient + c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond + + tests := []struct { + path string + want int + wantErr string + }{ + {path: "/fast", want: 200}, + {path: "/slow", wantErr: "timeout awaiting response headers"}, + {path: "/fast", want: 200}, + } + for i, tt := range tests { + req, _ := http.NewRequest("GET", ts.URL+tt.path, nil) + req = reqWithT(req, t) + res, err := c.Do(req) + select { + case <-inHandler: + case <-time.After(5 * time.Second): + t.Errorf("never entered handler for test index %d, %s", i, tt.path) + continue + } + if err != nil { + uerr, ok := err.(*url.Error) + if !ok { + t.Errorf("error is not an url.Error; got: %#v", err) + continue + } + nerr, ok := uerr.Err.(net.Error) + if !ok { + t.Errorf("error does not satisfy net.Error interface; got: %#v", err) + continue + } + if !nerr.Timeout() { + t.Errorf("want timeout error; got: %q", nerr) + continue + } + if strings.Contains(err.Error(), tt.wantErr) { + continue + } + t.Errorf("%d. unexpected error: %v", i, err) + continue + } + if tt.wantErr != "" { + t.Errorf("%d. no error. expected error: %v", i, tt.wantErr) + continue + } + if res.StatusCode != tt.want { + t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) + } + } +} + +func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { + setParallel(t) + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + c := tc().httpClient + tr := c.Transport.(*Transport) + + donec := make(chan bool) + req, _ := http.NewRequest("GET", ts.URL, body) + go func() { + defer close(donec) + c.Do(req) + }() + start := time.Now() + timeout := 10 * time.Second + for time.Since(start) < timeout { + time.Sleep(100 * time.Millisecond) + tr.CancelRequest(req) + select { + case <-donec: + return + default: + } + } + t.Errorf("Do of canceled request has not returned after %v", timeout) +} + +func TestCancelRequestWithChannel(t *testing.T) { + setParallel(t) + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello") + w.(http.Flusher).Flush() // send headers and some body + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + c := tc().httpClient + tr := c.Transport.(*Transport) + + req, _ := http.NewRequest("GET", ts.URL, nil) + ch := make(chan struct{}) + req.Cancel = ch + + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + go func() { + time.Sleep(1 * time.Second) + close(ch) + }() + t0 := time.Now() + body, err := io.ReadAll(res.Body) + d := time.Since(t0) + + if err != http2errRequestCanceled { + t.Errorf("Body.Read error = %v; want errRequestCanceled", err) + } + if string(body) != "Hello" { + t.Errorf("Body = %q; want Hello", body) + } + if d < 500*time.Millisecond { + t.Errorf("expected ~1 second delay; got %v", d) + } + // Verify no outstanding requests after readLoop/writeLoop + // goroutines shut down. + for tries := 5; tries > 0; tries-- { + n := tr.NumPendingRequestsForTesting() + if n == 0 { + break + } + time.Sleep(100 * time.Millisecond) + if tries == 1 { + t.Errorf("pending requests = %d; want 0", n) + } + } +} + +func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { + testCancelRequestWithChannelBeforeDo(t, false) +} +func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) { + testCancelRequestWithChannelBeforeDo(t, true) +} +func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { + setParallel(t) + defer afterTest(t) + unblockc := make(chan bool) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + c := tc().httpClient + + req, _ := http.NewRequest("GET", ts.URL, nil) + if withCtx { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + req = req.WithContext(ctx) + } else { + ch := make(chan struct{}) + req.Cancel = ch + close(ch) + } + + _, err := c.Do(req) + if ue, ok := err.(*url.Error); ok { + err = ue.Err + } + if withCtx { + if err != context.Canceled { + t.Errorf("Do error = %v; want %v", err, context.Canceled) + } + } else { + if err == nil || !strings.Contains(err.Error(), "canceled") { + t.Errorf("Do error = %v; want cancellation", err) + } + } +} + +// Issue 11020. The returned error message should be errRequestCanceled +func TestTransportCancelBeforeResponseHeaders(t *testing.T) { + defer afterTest(t) + + serverConnCh := make(chan net.Conn, 1) + tr := &Transport{ + DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + cc, sc := net.Pipe() + serverConnCh <- sc + return cc, nil + }, + } + defer tr.CloseIdleConnections() + errc := make(chan error, 1) + req, _ := http.NewRequest("GET", "http://example.com/", nil) + go func() { + _, err := tr.RoundTrip(req) + errc <- err + }() + + sc := <-serverConnCh + verb := make([]byte, 3) + if _, err := io.ReadFull(sc, verb); err != nil { + t.Errorf("Error reading HTTP verb from server: %v", err) + } + if string(verb) != "GET" { + t.Errorf("server received %q; want GET", verb) + } + defer sc.Close() + + tr.CancelRequest(req) + + err := <-errc + if err == nil { + t.Fatalf("unexpected success from RoundTrip") + } + if err != http2errRequestCanceled { + t.Errorf("RoundTrip error = %v; want http2errRequestCanceled", err) + } +} + +// golang.org/issue/3672 -- Client can't close HTTP stream +// Calling Close on a Response.Body used to just read until EOF. +// Now it actually closes the TCP connection. +func TestTransportCloseResponseBody(t *testing.T) { + defer afterTest(t) + writeErr := make(chan error, 1) + msg := []byte("young\n") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for { + _, err := w.Write(msg) + if err != nil { + writeErr <- err + return + } + w.(http.Flusher).Flush() + } + })) + defer ts.Close() + + c := tc().httpClient + tr := c.Transport.(*Transport) + + req, _ := http.NewRequest("GET", ts.URL, nil) + defer tr.CancelRequest(req) + + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + + const repeats = 3 + buf := make([]byte, len(msg)*repeats) + want := bytes.Repeat(msg, repeats) + + _, err = io.ReadFull(res.Body, buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, want) { + t.Fatalf("read %q; want %q", buf, want) + } + didClose := make(chan error, 1) + go func() { + didClose <- res.Body.Close() + }() + select { + case err := <-didClose: + if err != nil { + t.Errorf("Close = %v", err) + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for close") + } + select { + case err := <-writeErr: + if err == nil { + t.Errorf("expected non-nil write error") + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for write error") + } +} + +type fooProto struct{} + +func (fooProto) RoundTrip(req *http.Request) (*http.Response, error) { + res := &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())), + } + return res, nil +} + +func TestTransportAltProto(t *testing.T) { + defer afterTest(t) + tr := &Transport{} + c := &http.Client{Transport: tr} + tr.RegisterProtocol("foo", fooProto{}) + res, err := c.Get("foo://bar.com/path") + if err != nil { + t.Fatal(err) + } + bodyb, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + body := string(bodyb) + if e := "You wanted foo://bar.com/path"; body != e { + t.Errorf("got response %q, want %q", body, e) + } +} + +func TestTransportNoHost(t *testing.T) { + defer afterTest(t) + tr := &Transport{} + _, err := tr.RoundTrip(&http.Request{ + Header: make(http.Header), + URL: &url.URL{ + Scheme: "http", + }, + }) + want := "http: no Host in request URL" + if got := fmt.Sprint(err); got != want { + t.Errorf("error = %v; want %q", err, want) + } +} + +// Issue 13311 +func TestTransportEmptyMethod(t *testing.T) { + req, _ := http.NewRequest("GET", "http://foo.com/", nil) + req.Method = "" // docs say "For client requests an empty string means GET" + got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(got), "GET ") { + t.Fatalf("expected substring 'GET '; got: %s", got) + } +} + +func TestTransportSocketLateBinding(t *testing.T) { + setParallel(t) + defer afterTest(t) + + mux := http.NewServeMux() + fooGate := make(chan bool, 1) + mux.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("foo-ipport", r.RemoteAddr) + w.(http.Flusher).Flush() + <-fooGate + }) + mux.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("bar-ipport", r.RemoteAddr) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + dialGate := make(chan bool, 1) + c := tc().httpClient + c.Transport.(*Transport).DialContext = func(_ context.Context, n, addr string) (net.Conn, error) { + if <-dialGate { + return net.Dial(n, addr) + } + return nil, errors.New("manually closed") + } + + dialGate <- true // only allow one dial + fooRes, err := c.Get(ts.URL + "/foo") + if err != nil { + t.Fatal(err) + } + fooAddr := fooRes.Header.Get("foo-ipport") + if fooAddr == "" { + t.Fatal("No addr on /foo request") + } + time.AfterFunc(200*time.Millisecond, func() { + // let the foo response finish so we can use its + // connection for /bar + fooGate <- true + io.Copy(io.Discard, fooRes.Body) + fooRes.Body.Close() + }) + + barRes, err := c.Get(ts.URL + "/bar") + if err != nil { + t.Fatal(err) + } + barAddr := barRes.Header.Get("bar-ipport") + if barAddr != fooAddr { + t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) + } + barRes.Body.Close() + dialGate <- false +} + +type dummyAddr string +type oneConnListener struct { + conn net.Conn +} + +func (l *oneConnListener) Accept() (c net.Conn, err error) { + c = l.conn + if c == nil { + err = io.EOF + return + } + err = nil + l.conn = nil + return +} + +func (l *oneConnListener) Close() error { + return nil +} + +func (l *oneConnListener) Addr() net.Addr { + return dummyAddr("test-address") +} + +func (a dummyAddr) Network() string { + return string(a) +} + +func (a dummyAddr) String() string { + return string(a) +} + +type noopConn struct{} + +func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") } +func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") } +func (noopConn) SetDeadline(t time.Time) error { return nil } +func (noopConn) SetReadDeadline(t time.Time) error { return nil } +func (noopConn) SetWriteDeadline(t time.Time) error { return nil } + +type rwTestConn struct { + io.Reader + io.Writer + noopConn + + closeFunc func() error // called if non-nil + closec chan bool // else, if non-nil, send value to it on close +} + +func (c *rwTestConn) Close() error { + if c.closeFunc != nil { + return c.closeFunc() + } + select { + case c.closec <- true: + default: + } + return nil +} + +// Issue 2184 +func TestTransportReading100Continue(t *testing.T) { + defer afterTest(t) + + const numReqs = 5 + reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) } + reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) } + + send100Response := func(w *io.PipeWriter, r *io.PipeReader) { + defer w.Close() + defer r.Close() + br := bufio.NewReader(r) + n := 0 + for { + n++ + req, err := http.ReadRequest(br) + if err == io.EOF { + return + } + if err != nil { + t.Error(err) + return + } + slurp, err := io.ReadAll(req.Body) + if err != nil { + t.Errorf("Server request body slurp: %v", err) + return + } + id := req.Header.Get("Request-Id") + resCode := req.Header.Get("X-Want-Response-Code") + if resCode == "" { + resCode = "100 Continue" + if string(slurp) != reqBody(n) { + t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n)) + } + } + body := fmt.Sprintf("Response number %d", n) + v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s +Date: Thu, 28 Feb 2013 17:55:41 GMT + +HTTP/1.1 200 OK +Content-Type: text/html +Echo-Request-Id: %s +Content-Length: %d + +%s`, resCode, id, len(body), body), "\n", "\r\n", -1)) + w.Write(v) + if id == reqID(numReqs) { + return + } + } + + } + + tr := &Transport{ + DialContext: func(_ context.Context, n, addr string) (net.Conn, error) { + sr, sw := io.Pipe() // server read/write + cr, cw := io.Pipe() // client read/write + conn := &rwTestConn{ + Reader: cr, + Writer: sw, + closeFunc: func() error { + sw.Close() + cw.Close() + return nil + }, + } + go send100Response(cw, sr) + return conn, nil + }, + DisableKeepAlives: false, + } + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + testResponse := func(req *http.Request, name string, wantCode int) { + t.Helper() + res, err := c.Do(req) + if err != nil { + t.Fatalf("%s: Do: %v", name, err) + } + if res.StatusCode != wantCode { + t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode) + } + if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { + t.Errorf("%s: response id %q != request id %q", name, idBack, id) + } + _, err = io.ReadAll(res.Body) + if err != nil { + t.Fatalf("%s: Slurp error: %v", name, err) + } + } + + // Few 100 responses, making sure we're not off-by-one. + for i := 1; i <= numReqs; i++ { + req, _ := http.NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i))) + req.Header.Set("Request-Id", reqID(i)) + testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) + } +} + +var ( + ExportHttp2ConfigureServer = http2ConfigureServer +) + +type clientServerTest struct { + t *testing.T + h2 bool + h http.Handler + ts *httptest.Server + tr *Transport + c *http.Client +} + +func (t *clientServerTest) close() { + t.tr.CloseIdleConnections() + t.ts.Close() +} + +func (t *clientServerTest) getURL(u string) string { + res, err := t.c.Get(u) + if err != nil { + t.t.Fatal(err) + } + defer res.Body.Close() + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.t.Fatal(err) + } + return string(slurp) +} + +func (t *clientServerTest) scheme() string { + if t.h2 { + return "https" + } + return "http" +} + +const ( + h1Mode = false + h2Mode = true +) + +var quietLog = log.New(io.Discard, "", 0) + +var optQuietLog = func(ts *httptest.Server) { + ts.Config.ErrorLog = quietLog +} + +func optWithServerLog(lg *log.Logger) func(*httptest.Server) { + return func(ts *httptest.Server) { + ts.Config.ErrorLog = lg + } +} + +func newClientServerTest(t *testing.T, h2 bool, h http.Handler, opts ...interface{}) *clientServerTest { + cst := &clientServerTest{ + t: t, + h2: h2, + h: h, + tr: &Transport{}, + } + cst.c = &http.Client{Transport: cst.tr} + cst.ts = httptest.NewUnstartedServer(h) + + for _, opt := range opts { + switch opt := opt.(type) { + case func(*Transport): + opt(cst.tr) + case func(*httptest.Server): + opt(cst.ts) + default: + t.Fatalf("unhandled option type %T", opt) + } + } + + if !h2 { + cst.ts.Start() + return cst + } + http2ConfigureServer(cst.ts.Config, nil) + cst.ts.TLS = cst.ts.Config.TLSConfig + cst.ts.StartTLS() + + cst.tr.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + if _, err := http2ConfigureTransports(cst.tr); err != nil { + t.Fatal(err) + } + return cst +} + +// Issue 17739: the HTTP client must ignore any unknown 1xx +// informational responses before the actual response. +func TestTransportIgnore1xxResponses(t *testing.T) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, buf, _ := w.(http.Hijacker).Hijack() + buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) + buf.Flush() + conn.Close() + })) + defer cst.close() + cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway + + var got bytes.Buffer + + req, _ := http.NewRequest("GET", cst.ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header) + return nil + }, + })) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + res.Write(&got) + want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello" + if got.String() != want { + t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want) + } +} + +func TestTransportLimits1xxResponses(t *testing.T) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, buf, _ := w.(http.Hijacker).Hijack() + for i := 0; i < 10; i++ { + buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) + } + buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) + buf.Flush() + conn.Close() + })) + defer cst.close() + cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway + + res, err := cst.c.Get(cst.ts.URL) + if res != nil { + defer res.Body.Close() + } + got := fmt.Sprint(err) + wantSub := "too many 1xx informational responses" + if !strings.Contains(got, wantSub) { + t.Errorf("Get error = %v; want substring %q", err, wantSub) + } +} + +// Issue 26161: the HTTP client must treat 101 responses +// as the final response. +func TestTransportTreat101Terminal(t *testing.T) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, buf, _ := w.(http.Hijacker).Hijack() + buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n")) + buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) + buf.Flush() + conn.Close() + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusSwitchingProtocols { + t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode) + } +} + +type proxyFromEnvTest struct { + req string // URL to fetch; blank means "http://example.com" + + env string // HTTP_PROXY + httpsenv string // HTTPS_PROXY + noenv string // NO_PROXY + reqmeth string // REQUEST_METHOD + + want string + wanterr error +} + +func (t proxyFromEnvTest) String() string { + var buf bytes.Buffer + space := func() { + if buf.Len() > 0 { + buf.WriteByte(' ') + } + } + if t.env != "" { + fmt.Fprintf(&buf, "http_proxy=%q", t.env) + } + if t.httpsenv != "" { + space() + fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv) + } + if t.noenv != "" { + space() + fmt.Fprintf(&buf, "no_proxy=%q", t.noenv) + } + if t.reqmeth != "" { + space() + fmt.Fprintf(&buf, "request_method=%q", t.reqmeth) + } + req := "http://example.com" + if t.req != "" { + req = t.req + } + space() + fmt.Fprintf(&buf, "req=%q", req) + return strings.TrimSpace(buf.String()) +} + +var proxyFromEnvTests = []proxyFromEnvTest{ + {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"}, + {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"}, + {env: "cache.corp.example.com", want: "http://cache.corp.example.com"}, + {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"}, + {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"}, + {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"}, + {env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"}, + + // Don't use secure for http + {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"}, + // Use secure for https. + {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"}, + {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"}, + + // Issue 16405: don't use HTTP_PROXY in a CGI environment, + // where HTTP_PROXY can be attacker-controlled. + {env: "http://10.1.2.3:8080", reqmeth: "POST", + want: "", + wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")}, + + {want: ""}, + + {noenv: "example.com", req: "http://example.com/", env: "proxy", want: ""}, + {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, + {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, + {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: ""}, + {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, +} + +func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *http.Request) (*url.URL, error)) { + t.Helper() + reqURL := tt.req + if reqURL == "" { + reqURL = "http://example.com" + } + req, _ := http.NewRequest("GET", reqURL, nil) + url, err := proxyForRequest(req) + if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e { + t.Errorf("%v: got error = %q, want %q", tt, g, e) + return + } + if got := fmt.Sprintf("%s", url); got != tt.want { + t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want) + } +} + +func ResetCachedEnvironment() { + resetProxyConfig() +} + +func ResetProxyEnv() { + for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy", "REQUEST_METHOD"} { + os.Unsetenv(v) + } + ResetCachedEnvironment() +} + +func TestProxyFromEnvironment(t *testing.T) { + ResetProxyEnv() + defer ResetProxyEnv() + for _, tt := range proxyFromEnvTests { + testProxyForRequest(t, tt, func(req *http.Request) (*url.URL, error) { + os.Setenv("HTTP_PROXY", tt.env) + os.Setenv("HTTPS_PROXY", tt.httpsenv) + os.Setenv("NO_PROXY", tt.noenv) + os.Setenv("REQUEST_METHOD", tt.reqmeth) + ResetCachedEnvironment() + return httpproxy.FromEnvironment().ProxyFunc()(req.URL) + }) + } +} + +func TestProxyFromEnvironmentLowerCase(t *testing.T) { + ResetProxyEnv() + defer ResetProxyEnv() + for _, tt := range proxyFromEnvTests { + testProxyForRequest(t, tt, func(req *http.Request) (*url.URL, error) { + os.Setenv("http_proxy", tt.env) + os.Setenv("https_proxy", tt.httpsenv) + os.Setenv("no_proxy", tt.noenv) + os.Setenv("REQUEST_METHOD", tt.reqmeth) + ResetCachedEnvironment() + return httpproxy.FromEnvironment().ProxyFunc()(req.URL) + }) + } +} + +func TestIdleConnChannelLeak(t *testing.T) { + // Not parallel: uses global test hooks. + var mu sync.Mutex + var n int + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + n++ + mu.Unlock() + })) + defer ts.Close() + + const nReqs = 5 + didRead := make(chan bool, nReqs) + SetReadLoopBeforeNextReadHook(func() { didRead <- true }) + defer SetReadLoopBeforeNextReadHook(nil) + + c := tc().httpClient + tr := c.Transport.(*Transport) + tr.DialContext = func(_ context.Context, netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) + } + + // First, without keep-alives. + for _, disableKeep := range []bool{true, false} { + tr.DisableKeepAlives = disableKeep + for i := 0; i < nReqs; i++ { + _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i)) + if err != nil { + t.Fatal(err) + } + // Note: no res.Body.Close is needed here, since the + // response Content-Length is zero. Perhaps the test + // should be more explicit and use a HEAD, but tests + // elsewhere guarantee that zero byte responses generate + // a "Content-Length: 0" instead of chunking. + } + + // At this point, each of the 5 Transport.readLoop goroutines + // are scheduling noting that there are no response bodies (see + // earlier comment), and are then calling putIdleConn, which + // decrements this count. Usually that happens quickly, which is + // why this test has seemed to work for ages. But it's still + // racey: we have wait for them to finish first. See Issue 10427 + for i := 0; i < nReqs; i++ { + <-didRead + } + + if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 { + t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got) + } + } +} + +// Verify the status quo: that the Client.Post function coerces its +// body into a ReadCloser if it's a Closer, and that the Transport +// then closes it. +func TestTransportClosesRequestBody(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(io.Discard, r.Body) + })) + defer ts.Close() + + c := tc().httpClient + + closes := 0 + + res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if closes != 1 { + t.Errorf("closes = %d; want 1", closes) + } +} + +func TestTransportTLSHandshakeTimeout(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping in short mode") + } + ln := newLocalListener(t) + defer ln.Close() + testdonec := make(chan struct{}) + defer close(testdonec) + + go func() { + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + <-testdonec + c.Close() + }() + + getdonec := make(chan struct{}) + go func() { + defer close(getdonec) + tr := &Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("tcp", ln.Addr().String()) + }, + TLSHandshakeTimeout: 250 * time.Millisecond, + } + cl := &http.Client{Transport: tr} + _, err := cl.Get("https://dummy.tld/") + if err == nil { + t.Error("expected error") + return + } + ue, ok := err.(*url.Error) + if !ok { + t.Errorf("expected url.Error; got %#v", err) + return + } + ne, ok := ue.Err.(net.Error) + if !ok { + t.Errorf("expected net.Error; got %#v", err) + return + } + if !ne.Timeout() { + t.Errorf("expected timeout error; got %v", err) + } + if !strings.Contains(err.Error(), "handshake timeout") { + t.Errorf("expected 'handshake timeout' in error; got %v", err) + } + }() + select { + case <-getdonec: + case <-time.After(5 * time.Second): + t.Error("test timeout; TLS handshake hung?") + } +} + +// Trying to repro golang.org/issue/3514 +func TestTLSServerClosesConnection(t *testing.T) { + defer afterTest(t) + + closedc := make(chan bool, 1) + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/keep-alive-then-die") { + conn, _, _ := w.(http.Hijacker).Hijack() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) + conn.Close() + closedc <- true + return + } + fmt.Fprintf(w, "hello") + })) + defer ts.Close() + + c := tc().httpClient + tr := c.Transport.(*Transport) + + var nSuccess = 0 + var errs []error + const trials = 20 + for i := 0; i < trials; i++ { + tr.CloseIdleConnections() + res, err := c.Get(ts.URL + "/keep-alive-then-die") + if err != nil { + t.Fatal(err) + } + <-closedc + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != "foo" { + t.Errorf("Got %q, want foo", slurp) + } + + // Now try again and see if we successfully + // pick a new connection. + res, err = c.Get(ts.URL + "/") + if err != nil { + errs = append(errs, err) + continue + } + slurp, err = io.ReadAll(res.Body) + if err != nil { + errs = append(errs, err) + continue + } + nSuccess++ + } + if nSuccess > 0 { + t.Logf("successes = %d of %d", nSuccess, trials) + } else { + t.Errorf("All runs failed:") + } + for _, err := range errs { + t.Logf(" err: %v", err) + } +} + +// byteFromChanReader is an io.Reader that reads a single byte at a +// time from the channel. When the channel is closed, the reader +// returns io.EOF. +type byteFromChanReader chan byte + +func (c byteFromChanReader) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return + } + b, ok := <-c + if !ok { + return 0, io.EOF + } + p[0] = b + return 1, nil +} + +// Verifies that the Transport doesn't reuse a connection in the case +// where the server replies before the request has been fully +// written. We still honor that reply (see TestIssue3595), but don't +// send future requests on the connection because it's then in a +// questionable state. +// golang.org/issue/7569 +func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { + setParallel(t) + defer afterTest(t) + var sconn struct { + sync.Mutex + c net.Conn + } + var getOkay bool + closeConn := func() { + sconn.Lock() + defer sconn.Unlock() + if sconn.c != nil { + sconn.c.Close() + sconn.c = nil + if !getOkay { + t.Logf("Closed server connection") + } + } + } + defer closeConn() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + io.WriteString(w, "bar") + return + } + conn, _, _ := w.(http.Hijacker).Hijack() + sconn.Lock() + sconn.c = conn + sconn.Unlock() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive + go io.Copy(io.Discard, conn) + })) + defer ts.Close() + c := tc().httpClient + + const bodySize = 256 << 10 + finalBit := make(byteFromChanReader, 1) + req, _ := http.NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit)) + req.ContentLength = bodySize + res, err := c.Do(req) + if err := wantBody(res, err, "foo"); err != nil { + t.Errorf("POST response: %v", err) + } + donec := make(chan bool) + go func() { + defer close(donec) + res, err = c.Get(ts.URL) + if err := wantBody(res, err, "bar"); err != nil { + t.Errorf("GET response: %v", err) + return + } + getOkay = true // suppress test noise + }() + time.AfterFunc(5*time.Second, closeConn) + select { + case <-donec: + finalBit <- 'x' // unblock the writeloop of the first Post + close(finalBit) + case <-time.After(7 * time.Second): + t.Fatal("timeout waiting for GET request to finish") + } +} + +// Tests that we don't leak Transport persistConn.readLoop goroutines +// when a server hangs up immediately after saying it would keep-alive. +func TestTransportIssue10457(t *testing.T) { + defer afterTest(t) // used to fail in goroutine leak check + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Send a response with no body, keep-alive + // (implicit), and then lie and immediately close the + // connection. This forces the Transport's readLoop to + // immediately Peek an io.EOF and get to the point + // that used to hang. + conn, _, _ := w.(http.Hijacker).Hijack() + conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive + conn.Close() + })) + defer ts.Close() + c := tc().httpClient + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + // Just a sanity check that we at least get the response. The real + // test here is that the "defer afterTest" above doesn't find any + // leaked goroutines. + if got, want := res.Header.Get("Foo"), "Bar"; got != want { + t.Errorf("Foo header = %q; want %q", got, want) + } +} + +type closerFunc func() error + +func (f closerFunc) Close() error { return f() } + +type writerFuncConn struct { + net.Conn + write func(p []byte) (n int, err error) +} + +func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) } + +func hookSetter(dst *func()) func(func()) { + return func(fn func()) { + unnilTestHook(&fn) + *dst = fn + } +} + +var ( + SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip) + SetRoundTripRetried = hookSetter(&testHookRoundTripRetried) +) + +// Issue 6981 +func TestTransportClosesBodyOnError(t *testing.T) { + setParallel(t) + defer afterTest(t) + readBody := make(chan error, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := io.ReadAll(r.Body) + readBody <- err + })) + defer ts.Close() + c := tc().httpClient + fakeErr := errors.New("fake error") + didClose := make(chan bool, 1) + req, _ := http.NewRequest("POST", ts.URL, struct { + io.Reader + io.Closer + }{ + io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)), + closerFunc(func() error { + select { + case didClose <- true: + default: + } + return nil + }), + }) + res, err := c.Do(req) + if res != nil { + defer res.Body.Close() + } + if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { + t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) + } + select { + case err := <-readBody: + if err == nil { + t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") + } + case <-time.After(5 * time.Second): + t.Error("timeout waiting for server handler to complete") + } + select { + case <-didClose: + default: + t.Errorf("didn't see Body.Close") + } +} + +func TestTransportDialTLS(t *testing.T) { + setParallel(t) + defer afterTest(t) + var mu sync.Mutex // guards following + var gotReq, didDial bool + + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + gotReq = true + mu.Unlock() + })) + defer ts.Close() + c := tc().httpClient + c.Transport.(*Transport).DialTLSContext = func(_ context.Context, netw, addr string) (net.Conn, error) { + mu.Lock() + didDial = true + mu.Unlock() + c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) + if err != nil { + return nil, err + } + return c, c.Handshake() + } + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + mu.Lock() + if !gotReq { + t.Error("didn't get request") + } + if !didDial { + t.Error("didn't use dial hook") + } +} + +// Test for issue 8755 +// Ensure that if a proxy returns an error, it is exposed by RoundTrip +func TestRoundTripReturnsProxyError(t *testing.T) { + badProxy := func(*http.Request) (*url.URL, error) { + return nil, errors.New("errorMessage") + } + + tr := &Transport{Proxy: badProxy} + + req, _ := http.NewRequest("GET", "http://example.com", nil) + + _, err := tr.RoundTrip(req) + + if err == nil { + t.Error("Expected proxy error to be returned by RoundTrip") + } +} + +// tests that putting an idle conn after a call to CloseIdleConns does return it +func TestTransportCloseIdleConnsThenReturn(t *testing.T) { + tr := &Transport{} + wantIdle := func(when string, n int) bool { + got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn + if got == n { + return true + } + t.Errorf("%s: idle conns = %d; want %d", when, got, n) + return false + } + wantIdle("start", 0) + if !tr.PutIdleTestConn("http", "example.com") { + t.Fatal("put failed") + } + if !tr.PutIdleTestConn("http", "example.com") { + t.Fatal("second put failed") + } + wantIdle("after put", 2) + tr.CloseIdleConnections() + if !tr.IsIdleForTesting() { + t.Error("should be idle after CloseIdleConnections") + } + wantIdle("after close idle", 0) + if tr.PutIdleTestConn("http", "example.com") { + t.Fatal("put didn't fail") + } + wantIdle("after second put", 0) + + tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode + if tr.IsIdleForTesting() { + t.Error("shouldn't be idle after QueueForIdleConnForTesting") + } + if !tr.PutIdleTestConn("http", "example.com") { + t.Fatal("after re-activation") + } + wantIdle("after final put", 1) +} + +// Test for issue 34282 +// Ensure that getConn doesn't call the GotConn trace hook on a HTTP/2 idle conn +func TestTransportTraceGotConnH2IdleConns(t *testing.T) { + tr := &Transport{} + wantIdle := func(when string, n int) bool { + got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2 + if got == n { + return true + } + t.Errorf("%s: idle conns = %d; want %d", when, got, n) + return false + } + wantIdle("start", 0) + alt := funcRoundTripper(func() {}) + if !tr.PutIdleTestConnH2("https", "example.com:443", alt) { + t.Fatal("put failed") + } + wantIdle("after put", 1) + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + GotConn: func(httptrace.GotConnInfo) { + // tr.getConn should leave it for the HTTP/2 alt to call GotConn. + t.Error("GotConn called") + }, + }) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "https://example.com", nil) + _, err := tr.RoundTrip(req) + if err != errFakeRoundTrip { + t.Errorf("got error: %v; want %q", err, errFakeRoundTrip) + } + wantIdle("after round trip", 1) +} + +func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + trFunc := func(tr *Transport) { + tr.MaxConnsPerHost = 1 + tr.MaxIdleConnsPerHost = 1 + tr.IdleConnTimeout = 10 * time.Millisecond + } + cst := newClientServerTest(t, h2Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), trFunc) + defer cst.close() + + if _, err := cst.c.Get(cst.ts.URL); err != nil { + t.Fatalf("got error: %s", err) + } + + time.Sleep(100 * time.Millisecond) + got := make(chan error) + go func() { + if _, err := cst.c.Get(cst.ts.URL); err != nil { + got <- err + } + close(got) + }() + + timeout := time.NewTimer(5 * time.Second) + defer timeout.Stop() + select { + case err := <-got: + if err != nil { + t.Fatalf("got error: %s", err) + } + case <-timeout.C: + t.Fatal("request never completed") + } +} + +// This tests that a client requesting a content range won't also +// implicitly ask for gzip support. If they want that, they need to do it +// on their own. +// golang.org/issue/8923 +func TestTransportRangeAndGzip(t *testing.T) { + defer afterTest(t) + reqc := make(chan *http.Request, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqc <- r + })) + defer ts.Close() + c := tc().httpClient + + req, _ := http.NewRequest("GET", ts.URL, nil) + req.Header.Set("Range", "bytes=7-11") + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + + select { + case r := <-reqc: + if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + t.Error("Transport advertised gzip support in the Accept header") + } + if r.Header.Get("Range") == "" { + t.Error("no Range in request") + } + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } + res.Body.Close() +} + +// Test for issue 10474 +func TestTransportResponseCancelRace(t *testing.T) { + defer afterTest(t) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // important that this response has a body. + var b [1024]byte + w.Write(b[:]) + })) + defer ts.Close() + tr := tc().t + + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + // If we do an early close, Transport just throws the connection away and + // doesn't reuse it. In order to trigger the bug, it has to reuse the connection + // so read the body + if _, err := io.Copy(io.Discard, res.Body); err != nil { + t.Fatal(err) + } + + req2, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + tr.CancelRequest(req) + res, err = tr.RoundTrip(req2) + if err != nil { + t.Fatal(err) + } + res.Body.Close() +} + +// Test for issue 19248: Content-Encoding's value is case insensitive. +func TestTransportContentEncodingCaseInsensitive(t *testing.T) { + setParallel(t) + defer afterTest(t) + for _, ce := range []string{"gzip", "GZIP"} { + ce := ce + t.Run(ce, func(t *testing.T) { + const encodedString = "Hello Gopher" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Encoding", ce) + gz := gzip.NewWriter(w) + gz.Write([]byte(encodedString)) + gz.Close() + })) + defer ts.Close() + + res, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + + if string(body) != encodedString { + t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body)) + } + }) + } +} + +func TestTransportDialCancelRace(t *testing.T) { + defer afterTest(t) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer ts.Close() + tr := tc().t + + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + SetEnterRoundTripHook(func() { + tr.CancelRequest(req) + }) + defer SetEnterRoundTripHook(nil) + res, err := tr.RoundTrip(req) + if err != http2errRequestCanceled { + t.Errorf("expected canceled request error; got %v", err) + if err == nil { + res.Body.Close() + } + } +} + +// logWritesConn is a net.Conn that logs each Write call to writes +// and then proxies to w. +// It proxies Read calls to a reader it receives from rch. +type logWritesConn struct { + net.Conn // nil. crash on use. + + w io.Writer + + rch <-chan io.Reader + r io.Reader // nil until received by rch + + mu sync.Mutex + writes []string +} + +func (c *logWritesConn) Write(p []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() + c.writes = append(c.writes, string(p)) + return c.w.Write(p) +} + +func (c *logWritesConn) Read(p []byte) (n int, err error) { + if c.r == nil { + c.r = <-c.rch + } + return c.r.Read(p) +} + +func (c *logWritesConn) Close() error { return nil } + +// Issue 6574 +func TestTransportFlushesBodyChunks(t *testing.T) { + defer afterTest(t) + resBody := make(chan io.Reader, 1) + connr, connw := io.Pipe() // connection pipe pair + lw := &logWritesConn{ + rch: resBody, + w: connw, + } + tr := &Transport{ + DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return lw, nil + }, + } + bodyr, bodyw := io.Pipe() // body pipe pair + go func() { + defer bodyw.Close() + for i := 0; i < 3; i++ { + fmt.Fprintf(bodyw, "num%d\n", i) + } + }() + resc := make(chan *http.Response) + go func() { + req, _ := http.NewRequest("POST", "http://localhost:8080", bodyr) + req.Header.Set("User-Agent", "x") // known value for test + res, err := tr.RoundTrip(req) + if err != nil { + t.Errorf("RoundTrip: %v", err) + close(resc) + return + } + resc <- res + + }() + // Fully consume the request before checking the Write log vs. want. + req, err := http.ReadRequest(bufio.NewReader(connr)) + if err != nil { + t.Fatal(err) + } + io.Copy(io.Discard, req.Body) + + // Unblock the transport's roundTrip goroutine. + resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") + res, ok := <-resc + if !ok { + return + } + defer res.Body.Close() + + want := []string{ + "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n", + "5\r\nnum0\n\r\n", + "5\r\nnum1\n\r\n", + "5\r\nnum2\n\r\n", + "0\r\n\r\n", + } + if !reflect.DeepEqual(lw.writes, want) { + t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want) + } +} + +// Issue 22088: flush Transport request headers if we're not sure the body won't block on read. +func TestTransportFlushesRequestHeader(t *testing.T) { + defer afterTest(t) + gotReq := make(chan struct{}) + cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(gotReq) + })) + defer cst.close() + + pr, pw := io.Pipe() + req, err := http.NewRequest("POST", cst.ts.URL, pr) + if err != nil { + t.Fatal(err) + } + gotRes := make(chan struct{}) + go func() { + defer close(gotRes) + res, err := cst.tr.RoundTrip(req) + if err != nil { + t.Error(err) + return + } + res.Body.Close() + }() + + select { + case <-gotReq: + pw.Close() + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for handler to get request") + } + <-gotRes +} + +// Issue 11745. +func TestTransportPrefersResponseOverWriteError(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer afterTest(t) + const contentLengthLimit = 1024 * 1024 // 1MB + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ContentLength >= contentLengthLimit { + w.WriteHeader(http.StatusBadRequest) + r.Body.Close() + return + } + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + c := tc().httpClient + + fail := 0 + count := 100 + bigBody := strings.Repeat("a", contentLengthLimit*2) + for i := 0; i < count; i++ { + req, err := http.NewRequest("PUT", ts.URL, strings.NewReader(bigBody)) + if err != nil { + t.Fatal(err) + } + resp, err := c.Do(req) + if err != nil { + fail++ + t.Logf("%d = %#v", i, err) + if ue, ok := err.(*url.Error); ok { + t.Logf("urlErr = %#v", ue.Err) + if ne, ok := ue.Err.(*net.OpError); ok { + t.Logf("netOpError = %#v", ne.Err) + } + } + } else { + resp.Body.Close() + if resp.StatusCode != 400 { + t.Errorf("Expected status code 400, got %v", resp.Status) + } + } + } + if fail > 0 { + t.Errorf("Failed %v out of %v\n", fail, count) + } +} + +func TestTransportAutomaticHTTP2(t *testing.T) { + tr := tc().t + testTransportAutoHTTP(t, tr, true) +} + +func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) { + tr := tc().t + tr.TLSNextProto = make(map[string]func(string, TLSConn) http.RoundTripper) + testTransportAutoHTTP(t, tr, false) +} + +func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) { + _, err := tr.RoundTrip(new(http.Request)) + if err == nil { + t.Error("expected error from RoundTrip") + } + if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 { + t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2) + } +} + +// Issue 13633: there was a race where we returned bodyless responses +// to callers before recycling the persistent connection, which meant +// a client doing two subsequent requests could end up on different +// connections. It's somewhat harmless but enough tests assume it's +// not true in order to test other things that it's worth fixing. +// Plus it's nice to be consistent and not have timing-dependent +// behavior. +func TestTransportReuseConnEmptyResponseBody(t *testing.T) { + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Addr", r.RemoteAddr) + // Empty response body. + })) + defer cst.close() + n := 100 + if testing.Short() { + n = 10 + } + var firstAddr string + for i := 0; i < n; i++ { + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + log.Fatal(err) + } + addr := res.Header.Get("X-Addr") + if i == 0 { + firstAddr = addr + } else if addr != firstAddr { + t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr) + } + res.Body.Close() + } +} + +// Issue 13839 +func TestNoCrashReturningTransportAltConn(t *testing.T) { + cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) + if err != nil { + t.Fatal(err) + } + ln := newLocalListener(t) + defer ln.Close() + + var wg sync.WaitGroup + SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) + defer SetPendingDialHooks(nil, nil) + + testDone := make(chan struct{}) + defer close(testDone) + go func() { + tln := tls.NewListener(ln, &tls.Config{ + NextProtos: []string{"foo"}, + Certificates: []tls.Certificate{cert}, + }) + sc, err := tln.Accept() + if err != nil { + t.Error(err) + return + } + if err := sc.(*tls.Conn).Handshake(); err != nil { + t.Error(err) + return + } + <-testDone + sc.Close() + }() + + addr := ln.Addr().String() + + req, _ := http.NewRequest("GET", "https://fake.tld/", nil) + cancel := make(chan struct{}) + req.Cancel = cancel + + doReturned := make(chan bool, 1) + madeRoundTripper := make(chan bool, 1) + + tr := &Transport{ + DisableKeepAlives: true, + TLSNextProto: map[string]func(string, TLSConn) http.RoundTripper{ + "foo": func(authority string, c TLSConn) http.RoundTripper { + madeRoundTripper <- true + return funcRoundTripper(func() { + t.Error("foo http.RoundTripper should not be called") + }) + }, + }, + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + panic("shouldn't be called") + }, + DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) { + tc, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"foo"}, + }) + if err != nil { + return nil, err + } + if err := tc.Handshake(); err != nil { + return nil, err + } + close(cancel) + <-doReturned + return tc, nil + }, + } + c := &http.Client{Transport: tr} + + _, err = c.Do(req) + if ue, ok := err.(*url.Error); !ok || ue.Err != errRequestCanceledConn { + t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err) + } + + doReturned <- true + <-madeRoundTripper + wg.Wait() +} + +func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) { + testTransportReuseConnection_Gzip(t, true) +} + +func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { + testTransportReuseConnection_Gzip(t, false) +} + +// Make sure we re-use underlying TCP connection for gzipped responses too. +func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { + setParallel(t) + defer afterTest(t) + addr := make(chan string, 2) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + addr <- r.RemoteAddr + w.Header().Set("Content-Encoding", "gzip") + if chunked { + w.(http.Flusher).Flush() + } + w.Write(rgz) // arbitrary gzip response + })) + defer ts.Close() + c := tc().httpClient + + for i := 0; i < 2; i++ { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, len(rgz)) + if n, err := io.ReadFull(res.Body, buf); err != nil { + t.Errorf("%d. ReadFull = %v, %v", i, n, err) + } + // Note: no res.Body.Close call. It should work without it, + // since the flate.Reader's internal buffering will hit EOF + // and that should be sufficient. + } + a1, a2 := <-addr, <-addr + if a1 != a2 { + t.Fatalf("didn't reuse connection") + } +} + +func TestTransportResponseHeaderLength(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/long" { + w.Header().Set("Long", strings.Repeat("a", 1<<20)) + } + })) + defer ts.Close() + c := tc().httpClient + c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 + + if res, err := c.Get(ts.URL); err != nil { + t.Fatal(err) + } else { + res.Body.Close() + } + + res, err := c.Get(ts.URL + "/long") + if err == nil { + defer res.Body.Close() + var n int64 + for k, vv := range res.Header { + for _, v := range vv { + n += int64(len(k)) + int64(len(v)) + } + } + t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n) + } + if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) { + t.Errorf("got error: %v; want %q", err, want) + } +} + +type lookupIPAltResolverKey struct{} + +func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { + defer afterTest(t) + const resBody = "some body" + gotWroteReqEvent := make(chan struct{}, 500) + cst := newClientServerTest(t, h2, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + // Do nothing for the second request. + return + } + if _, err := io.ReadAll(r.Body); err != nil { + t.Error(err) + } + if !noHooks { + select { + case <-gotWroteReqEvent: + case <-time.After(5 * time.Second): + t.Error("timeout waiting for WroteRequest event") + } + } + io.WriteString(w, resBody) + })) + defer cst.close() + + cst.tr.ExpectContinueTimeout = 1 * time.Second + + var mu sync.Mutex // guards buf + var buf bytes.Buffer + logf := func(format string, args ...interface{}) { + mu.Lock() + defer mu.Unlock() + fmt.Fprintf(&buf, format, args...) + buf.WriteByte('\n') + } + + addrStr := cst.ts.Listener.Addr().String() + ip, port, err := net.SplitHostPort(addrStr) + if err != nil { + t.Fatal(err) + } + + // Install a fake DNS server. + ctx := context.WithValue(context.Background(), lookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { + if host != "dns-is-faked.golang" { + t.Errorf("unexpected DNS host lookup for %q/%q", network, host) + return nil, nil + } + return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil + }) + + body := "some body" + req, _ := http.NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body)) + req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"} + trace := &httptrace.ClientTrace{ + GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) }, + GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) }, + GotFirstResponseByte: func() { logf("first response byte") }, + PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) }, + DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) }, + DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) }, + ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) }, + ConnectDone: func(network, addr string, err error) { + if err != nil { + t.Errorf("ConnectDone: %v", err) + } + logf("ConnectDone: connected to %s %s = %v", network, addr, err) + }, + WroteHeaderField: func(key string, value []string) { + logf("WroteHeaderField: %s: %v", key, value) + }, + WroteHeaders: func() { + logf("WroteHeaders") + }, + Wait100Continue: func() { logf("Wait100Continue") }, + Got100Continue: func() { logf("Got100Continue") }, + WroteRequest: func(e httptrace.WroteRequestInfo) { + logf("WroteRequest: %+v", e) + gotWroteReqEvent <- struct{}{} + }, + } + if h2 { + trace.TLSHandshakeStart = func() { logf("tls handshake start") } + trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) { + logf("tls handshake done. ConnectionState = %v \n err = %v", s, err) + } + } + if noHooks { + // zero out all func pointers, trying to get some path to crash + *trace = httptrace.ClientTrace{} + } + req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) + + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + logf("got roundtrip.response") + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + logf("consumed body") + if string(slurp) != resBody || res.StatusCode != 200 { + t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody) + } + res.Body.Close() + + if noHooks { + // Done at this point. Just testing a full HTTP + // requests can happen with a trace pointing to a zero + // ClientTrace, full of nil func pointers. + return + } + + mu.Lock() + got := buf.String() + mu.Unlock() + + wantOnce := func(sub string) { + if strings.Count(got, sub) != 1 { + t.Errorf("expected substring %q exactly once in output.", sub) + } + } + wantOnceOrMore := func(sub string) { + if strings.Count(got, sub) == 0 { + t.Errorf("expected substring %q at least once in output.", sub) + } + } + wantOnce("Getting conn for dns-is-faked.golang:" + port) + wantOnce("DNS start: {Host:dns-is-faked.golang}") + wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err: Coalesced:false}") + wantOnce("got conn: {") + wantOnceOrMore("Connecting to tcp " + addrStr) + wantOnceOrMore("connected to tcp " + addrStr + " = ") + wantOnce("Reused:false WasIdle:false IdleTime:0s") + wantOnce("first response byte") + if h2 { + wantOnce("tls handshake start") + wantOnce("tls handshake done") + } else { + wantOnce("PutIdleConn = ") + wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]") + // TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the + // WroteHeaderField hook is not yet implemented in h2.) + wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port)) + wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body))) + wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]") + wantOnce("WroteHeaderField: Accept-Encoding: [gzip]") + } + wantOnce("WroteHeaders") + wantOnce("Wait100Continue") + wantOnce("Got100Continue") + wantOnce("WroteRequest: {Err:}") + if strings.Contains(got, " to udp ") { + t.Errorf("should not see UDP (DNS) connections") + } + if t.Failed() { + t.Errorf("Output:\n%s", got) + } + + // And do a second request: + req, _ = http.NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil) + req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) + res, err = cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 200 { + t.Fatal(res.Status) + } + res.Body.Close() + + mu.Lock() + got = buf.String() + mu.Unlock() + + sub := "Getting conn for dns-is-faked.golang:" + if gotn, want := strings.Count(got, sub), 2; gotn != want { + t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got) + } + +} + +func TestTransportEventTraceTLSVerify(t *testing.T) { + var mu sync.Mutex + var buf bytes.Buffer + logf := func(format string, args ...interface{}) { + mu.Lock() + defer mu.Unlock() + fmt.Fprintf(&buf, format, args...) + buf.WriteByte('\n') + } + + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("Unexpected request") + })) + defer ts.Close() + ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { + logf("%s", p) + return len(p), nil + }), "", 0) + + certpool := x509.NewCertPool() + certpool.AddCert(ts.Certificate()) + + c := &http.Client{Transport: &Transport{ + TLSClientConfig: &tls.Config{ + ServerName: "dns-is-faked.golang", + RootCAs: certpool, + }, + }} + + trace := &httptrace.ClientTrace{ + TLSHandshakeStart: func() { logf("TLSHandshakeStart") }, + TLSHandshakeDone: func(s tls.ConnectionState, err error) { + logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err) + }, + } + + req, _ := http.NewRequest("GET", ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) + _, err := c.Do(req) + if err == nil { + t.Error("Expected request to fail TLS verification") + } + + mu.Lock() + got := buf.String() + mu.Unlock() + + wantOnce := func(sub string) { + if strings.Count(got, sub) != 1 { + t.Errorf("expected substring %q exactly once in output.", sub) + } + } + + wantOnce("TLSHandshakeStart") + wantOnce("TLSHandshakeDone") + wantOnce("err = x509: certificate is valid for example.com") + + if t.Failed() { + t.Errorf("Output:\n%s", got) + } +} + +var ( + isDNSHijackedOnce sync.Once + isDNSHijacked bool +) + +func skipIfDNSHijacked(t *testing.T) { + // Skip this test if the user is using a shady/ISP + // DNS server hijacking queries. + // See issues 16732, 16716. + isDNSHijackedOnce.Do(func() { + addrs, _ := net.LookupHost("dns-should-not-resolve.golang") + isDNSHijacked = len(addrs) != 0 + }) + if isDNSHijacked { + t.Skip("skipping; test requires non-hijacking DNS server") + } +} + +func TestTransportEventTraceRealDNS(t *testing.T) { + skipIfDNSHijacked(t) + defer afterTest(t) + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + var mu sync.Mutex // guards buf + var buf bytes.Buffer + logf := func(format string, args ...interface{}) { + mu.Lock() + defer mu.Unlock() + fmt.Fprintf(&buf, format, args...) + buf.WriteByte('\n') + } + + req, _ := http.NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil) + trace := &httptrace.ClientTrace{ + DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) }, + DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) }, + ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) }, + ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) }, + } + req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) + + resp, err := c.Do(req) + if err == nil { + resp.Body.Close() + t.Fatal("expected error during DNS lookup") + } + + mu.Lock() + got := buf.String() + mu.Unlock() + + wantSub := func(sub string) { + if !strings.Contains(got, sub) { + t.Errorf("expected substring %q in output.", sub) + } + } + wantSub("DNSStart: {Host:dns-should-not-resolve.golang}") + wantSub("DNSDone: {Addrs:[] Err:") + if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") { + t.Errorf("should not see Connect events") + } + if t.Failed() { + t.Errorf("Output:\n%s", got) + } +} + +// Issue 14353: port can only contain digits. +func TestTransportRejectsAlphaPort(t *testing.T) { + res, err := http.Get("http://dummy.tld:123foo/bar") + if err == nil { + res.Body.Close() + t.Fatal("unexpected success") + } + ue, ok := err.(*url.Error) + if !ok { + t.Fatalf("got %#v; want *url.Error", err) + } + got := ue.Err.Error() + want := `invalid port ":123foo" after host` + if got != want { + t.Errorf("got error %q; want %q", got, want) + } +} + +// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 +// connections. The http2 test is done in TestTransportEventTrace_h2 +func TestTLSHandshakeTrace(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer ts.Close() + + var mu sync.Mutex + var start, done bool + trace := &httptrace.ClientTrace{ + TLSHandshakeStart: func() { + mu.Lock() + defer mu.Unlock() + start = true + }, + TLSHandshakeDone: func(s tls.ConnectionState, err error) { + mu.Lock() + defer mu.Unlock() + done = true + if err != nil { + t.Fatal("Expected error to be nil but was:", err) + } + }, + } + + c := tc().httpClient + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal("Unable to construct test request:", err) + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + r, err := c.Do(req) + if err != nil { + t.Fatal("Unexpected error making request:", err) + } + r.Body.Close() + mu.Lock() + defer mu.Unlock() + if !start { + t.Fatal("Expected TLSHandshakeStart to be called, but wasn't") + } + if !done { + t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't") + } +} + +// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an +// HTTP/2 connection was established but its caller no longer +// wanted it. (Assuming the connection cache was enabled, which it is +// by default) +// +// This test reproduced the crash by setting the IdleConnTimeout low +// (to make the test reasonable) and then making a request which is +// canceled by the DialTLS hook, which then also waits to return the +// real connection until after the RoundTrip saw the error. Then we +// know the successful tls.Dial from DialTLS will need to go into the +// idle pool. Then we give it a of time to explode. +func TestIdleConnH2Crash(t *testing.T) { + setParallel(t) + cst := newClientServerTest(t, h2Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // nothing + })) + defer cst.close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sawDoErr := make(chan bool, 1) + testDone := make(chan struct{}) + defer close(testDone) + + cst.tr.IdleConnTimeout = 5 * time.Millisecond + cst.tr.DialTLSContext = func(_ context.Context, network, addr string) (net.Conn, error) { + c, err := tls.Dial(network, addr, &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + }) + if err != nil { + t.Error(err) + return nil, err + } + if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" { + t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2") + c.Close() + return nil, errors.New("bogus") + } + + cancel() + + failTimer := time.NewTimer(5 * time.Second) + defer failTimer.Stop() + select { + case <-sawDoErr: + case <-testDone: + case <-failTimer.C: + t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail") + } + return c, nil + } + + req, _ := http.NewRequest("GET", cst.ts.URL, nil) + req = req.WithContext(ctx) + res, err := cst.c.Do(req) + if err == nil { + res.Body.Close() + t.Fatal("unexpected success") + } + sawDoErr <- true + + // Wait for the explosion. + time.Sleep(cst.tr.IdleConnTimeout * 10) +} + +type funcConn struct { + net.Conn + read func([]byte) (int, error) + write func([]byte) (int, error) +} + +func (c funcConn) Read(p []byte) (int, error) { return c.read(p) } +func (c funcConn) Write(p []byte) (int, error) { return c.write(p) } +func (c funcConn) Close() error { return nil } + +// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek +// back to the caller. +func TestTransportReturnsPeekError(t *testing.T) { + errValue := errors.New("specific error value") + + wrote := make(chan struct{}) + var wroteOnce sync.Once + + tr := &Transport{ + DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + c := funcConn{ + read: func([]byte) (int, error) { + <-wrote + return 0, errValue + }, + write: func(p []byte) (int, error) { + wroteOnce.Do(func() { close(wrote) }) + return len(p), nil + }, + } + return c, nil + }, + } + _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)) + if err != errValue { + t.Errorf("error = %#v; want %v", err, errValue) + } +} + +// Issue 13290: send User-Agent in proxy CONNECT +func TestTransportProxyConnectHeader(t *testing.T) { + defer afterTest(t) + reqc := make(chan *http.Request, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", r.Method) + } + reqc <- r + c, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + c.Close() + })) + defer ts.Close() + + c := tc().httpClient + c.Transport.(*Transport).Proxy = func(r *http.Request) (*url.URL, error) { + return url.Parse(ts.URL) + } + c.Transport.(*Transport).ProxyConnectHeader = http.Header{ + "User-Agent": {"foo"}, + "Other": {"bar"}, + } + + res, err := c.Get("https://dummy.tld/") // https to force a CONNECT + if err == nil { + res.Body.Close() + t.Errorf("unexpected success") + } + select { + case <-time.After(3 * time.Second): + t.Fatal("timeout") + case r := <-reqc: + if got, want := r.Header.Get("User-Agent"), "foo"; got != want { + t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) + } + if got, want := r.Header.Get("Other"), "bar"; got != want { + t.Errorf("CONNECT request Other = %q; want %q", got, want) + } + } +} + +func TestTransportProxyGetConnectHeader(t *testing.T) { + defer afterTest(t) + reqc := make(chan *http.Request, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", r.Method) + } + reqc <- r + c, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + c.Close() + })) + defer ts.Close() + + c := tc().httpClient + c.Transport.(*Transport).Proxy = func(r *http.Request) (*url.URL, error) { + return url.Parse(ts.URL) + } + // These should be ignored: + c.Transport.(*Transport).ProxyConnectHeader = http.Header{ + "User-Agent": {"foo"}, + "Other": {"bar"}, + } + c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) { + return http.Header{ + "User-Agent": {"foo2"}, + "Other": {"bar2"}, + }, nil + } + + res, err := c.Get("https://dummy.tld/") // https to force a CONNECT + if err == nil { + res.Body.Close() + t.Errorf("unexpected success") + } + select { + case <-time.After(3 * time.Second): + t.Fatal("timeout") + case r := <-reqc: + if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { + t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) + } + if got, want := r.Header.Get("Other"), "bar2"; got != want { + t.Errorf("CONNECT request Other = %q; want %q", got, want) + } + } +} + +var errFakeRoundTrip = errors.New("fake roundtrip") + +type funcRoundTripper func() + +func (fn funcRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + fn() + return nil, errFakeRoundTrip +} + +func wantBody(res *http.Response, err error, want string) error { + if err != nil { + return err + } + slurp, err := io.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("error reading body: %v", err) + } + if string(slurp) != want { + return fmt.Errorf("body = %q; want %q", slurp, want) + } + if err := res.Body.Close(); err != nil { + return fmt.Errorf("body Close = %v", err) + } + return nil +} + +type countCloseReader struct { + n *int + io.Reader +} + +func (cr countCloseReader) Close() error { + (*cr.n)++ + return nil +} + +// rgz is a gzip quine that uncompresses to itself. +var rgz = []byte{ + 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, + 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, + 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, + 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, + 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, + 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, + 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, + 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, + 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, + 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, + 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, + 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, + 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, + 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, + 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, + 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, + 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, + 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, +} + +// Ensure that a missing status doesn't make the server panic +// See Issue https://golang.org/issues/21701 +func TestMissingStatusNoPanic(t *testing.T) { + t.Parallel() + + const want = "unknown status code" + + ln := newLocalListener(t) + addr := ln.Addr().String() + done := make(chan bool) + fullAddrURL := fmt.Sprintf("http://%s", addr) + raw := "HTTP/1.1 400\r\n" + + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + + "Content-Type: text/html; charset=utf-8\r\n" + + "Content-Length: 10\r\n" + + "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" + + "Vary: Accept-Encoding\r\n\r\n" + + "Aloha Olaa" + + go func() { + defer close(done) + + conn, _ := ln.Accept() + if conn != nil { + io.WriteString(conn, raw) + io.ReadAll(conn) + conn.Close() + } + }() + + proxyURL, err := url.Parse(fullAddrURL) + if err != nil { + t.Fatalf("proxyURL: %v", err) + } + + tr := &Transport{Proxy: http.ProxyURL(proxyURL)} + + req, _ := http.NewRequest("GET", "https://golang.org/", nil) + res, err, panicked := doFetchCheckPanic(tr, req) + if panicked { + t.Error("panicked, expecting an error") + } + if res != nil && res.Body != nil { + io.Copy(io.Discard, res.Body) + res.Body.Close() + } + + if err == nil || !strings.Contains(err.Error(), want) { + t.Errorf("got=%v want=%q", err, want) + } + + ln.Close() + <-done +} + +func doFetchCheckPanic(tr *Transport, req *http.Request) (res *http.Response, err error, panicked bool) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + res, err = tr.RoundTrip(req) + return +} + +// Issue 22330: do not allow the response body to be read when the status code +// forbids a response body. +func TestNoBodyOnChunked304Response(t *testing.T) { + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, buf, _ := w.(http.Hijacker).Hijack() + buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) + buf.Flush() + conn.Close() + })) + defer cst.close() + + // Our test server above is sending back bogus data after the + // response (the "0\r\n\r\n" part), which causes the Transport + // code to log spam. Disable keep-alives so we never even try + // to reuse the connection. + cst.tr.DisableKeepAlives = true + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + + if res.Body != NoBody { + t.Errorf("Unexpected body on 304 response") + } +} + +type funcWriter func([]byte) (int, error) + +func (f funcWriter) Write(p []byte) (int, error) { return f(p) } + +type doneContext struct { + context.Context + err error +} + +func (doneContext) Done() <-chan struct{} { + c := make(chan struct{}) + close(c) + return c +} + +func (d doneContext) Err() error { return d.err } + +// Issue 25852: Transport should check whether Context is done early. +func TestTransportCheckContextDoneEarly(t *testing.T) { + tr := &Transport{} + req, _ := http.NewRequest("GET", "http://fake.example/", nil) + wantErr := errors.New("some error") + req = req.WithContext(doneContext{context.Background(), wantErr}) + _, err := tr.RoundTrip(req) + if err != wantErr { + t.Errorf("error = %v; want %v", err, wantErr) + } +} + +// Issue 23399: verify that if a client request times out, the Transport's +// conn is closed so that it's not reused. +// +// This is the test variant that times out before the server replies with +// any response headers. +func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { + setParallel(t) + defer afterTest(t) + inHandler := make(chan net.Conn, 1) + handlerReadReturned := make(chan bool, 1) + cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + inHandler <- conn + n, err := conn.Read([]byte{0}) + if n != 0 || err != io.EOF { + t.Errorf("unexpected Read result: %v, %v", n, err) + } + handlerReadReturned <- true + })) + defer cst.close() + + const timeout = 50 * time.Millisecond + cst.c.Timeout = timeout + + _, err := cst.c.Get(cst.ts.URL) + if err == nil { + t.Fatal("unexpected Get succeess") + } + + select { + case c := <-inHandler: + select { + case <-handlerReadReturned: + // Success. + return + case <-time.After(5 * time.Second): + t.Error("Handler's conn.Read seems to be stuck in Read") + c.Close() // close it to unblock Handler + } + case <-time.After(timeout * 10): + // If we didn't get into the Handler in 50ms, that probably means + // the builder was just slow and the Get failed in that time + // but never made it to the server. That's fine. We'll usually + // test the part above on faster machines. + t.Skip("skipping test on slow builder") + } +} + +// Issue 23399: verify that if a client request times out, the Transport's +// conn is closed so that it's not reused. +// +// This is the test variant that has the server send response headers +// first, and time out during the write of the response body. +func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { + setParallel(t) + defer afterTest(t) + inHandler := make(chan net.Conn, 1) + handlerResult := make(chan error, 1) + cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "100") + w.(http.Flusher).Flush() + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + conn.Write([]byte("foo")) + inHandler <- conn + n, err := conn.Read([]byte{0}) + // The error should be io.EOF or "read tcp + // 127.0.0.1:35827->127.0.0.1:40290: read: connection + // reset by peer" depending on timing. Really we just + // care that it returns at all. But if it returns with + // data, that's weird. + if n != 0 || err == nil { + handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err) + return + } + handlerResult <- nil + })) + defer cst.close() + + // Set Timeout to something very long but non-zero to exercise + // the codepaths that check for it. But rather than wait for it to fire + // (which would make the test slow), we send on the req.Cancel channel instead, + // which happens to exercise the same code paths. + cst.c.Timeout = time.Minute // just to be non-zero, not to hit it. + req, _ := http.NewRequest("GET", cst.ts.URL, nil) + cancel := make(chan struct{}) + req.Cancel = cancel + + res, err := cst.c.Do(req) + if err != nil { + select { + case <-inHandler: + t.Fatalf("Get error: %v", err) + default: + // Failed before entering handler. Ignore result. + t.Skip("skipping test on slow builder") + } + } + + close(cancel) + got, err := io.ReadAll(res.Body) + if err == nil { + t.Fatalf("unexpected success; read %q, nil", got) + } + + select { + case c := <-inHandler: + select { + case err := <-handlerResult: + if err != nil { + t.Errorf("handler: %v", err) + } + return + case <-time.After(5 * time.Second): + t.Error("Handler's conn.Read seems to be stuck in Read") + c.Close() // close it to unblock Handler + } + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } +} + +func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { + setParallel(t) + defer afterTest(t) + done := make(chan struct{}) + defer close(done) + cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer conn.Close() + io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n") + bs := bufio.NewScanner(conn) + bs.Scan() + fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) + <-done + })) + defer cst.close() + + req, _ := http.NewRequest("GET", cst.ts.URL, nil) + req.Header.Set("Upgrade", "foo") + req.Header.Set("Connection", "upgrade") + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 101 { + t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header) + } + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body) + } + defer rwc.Close() + bs := bufio.NewScanner(rwc) + if !bs.Scan() { + t.Fatalf("expected readable input") + } + if got, want := bs.Text(), "Some buffered data"; got != want { + t.Errorf("read %q; want %q", got, want) + } + io.WriteString(rwc, "echo\n") + if !bs.Scan() { + t.Fatalf("expected another line") + } + if got, want := bs.Text(), "ECHO"; got != want { + t.Errorf("read %q; want %q", got, want) + } +} + +func TestTransportRequestReplayable(t *testing.T) { + someBody := io.NopCloser(strings.NewReader("")) + tests := []struct { + name string + req *http.Request + want bool + }{ + { + name: "GET", + req: &http.Request{Method: "GET"}, + want: true, + }, + { + name: "GET_http.NoBody", + req: &http.Request{Method: "GET", Body: NoBody}, + want: true, + }, + { + name: "GET_body", + req: &http.Request{Method: "GET", Body: someBody}, + want: false, + }, + { + name: "POST", + req: &http.Request{Method: "POST"}, + want: false, + }, + { + name: "POST_idempotency-key", + req: &http.Request{Method: "POST", Header: http.Header{"Idempotency-Key": {"x"}}}, + want: true, + }, + { + name: "POST_x-idempotency-key", + req: &http.Request{Method: "POST", Header: http.Header{"X-Idempotency-Key": {"x"}}}, + want: true, + }, + { + name: "POST_body", + req: &http.Request{Method: "POST", Header: http.Header{"Idempotency-Key": {"x"}}, Body: someBody}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isReplayable(tt.req) + if got != tt.want { + t.Errorf("replyable = %v; want %v", got, tt.want) + } + }) + } +} + +// testMockTCPConn is a mock TCP connection used to test that +// ReadFrom is called when sending the request body. +type testMockTCPConn struct { + *net.TCPConn + + ReadFromCalled bool +} + +func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { + c.ReadFromCalled = true + return c.TCPConn.ReadFrom(r) +} + +func TestTransportRequestWriteRoundTrip(t *testing.T) { + nBytes := int64(1 << 10) + newFileFunc := func() (r io.Reader, done func(), err error) { + f, err := os.CreateTemp("", "net-http-newfilefunc") + if err != nil { + return nil, nil, err + } + + // Write some bytes to the file to enable reading. + if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { + return nil, nil, fmt.Errorf("failed to write data to file: %v", err) + } + if _, err := f.Seek(0, 0); err != nil { + return nil, nil, fmt.Errorf("failed to seek to front: %v", err) + } + + done = func() { + f.Close() + os.Remove(f.Name()) + } + + return f, done, nil + } + + newBufferFunc := func() (io.Reader, func(), error) { + return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil + } + + cases := []struct { + name string + readerFunc func() (io.Reader, func(), error) + contentLength int64 + expectedReadFrom bool + }{ + { + name: "file, length", + readerFunc: newFileFunc, + contentLength: nBytes, + expectedReadFrom: true, + }, + { + name: "file, no length", + readerFunc: newFileFunc, + }, + { + name: "file, negative length", + readerFunc: newFileFunc, + contentLength: -1, + }, + { + name: "buffer", + contentLength: nBytes, + readerFunc: newBufferFunc, + }, + { + name: "buffer, no length", + readerFunc: newBufferFunc, + }, + { + name: "buffer, length -1", + contentLength: -1, + readerFunc: newBufferFunc, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r, cleanup, err := tc.readerFunc() + if err != nil { + t.Fatal(err) + } + defer cleanup() + + tConn := &testMockTCPConn{} + trFunc := func(tr *Transport) { + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr) + } + + tConn.TCPConn = tcpConn + return tConn, nil + } + } + + cst := newClientServerTest( + t, + h1Mode, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(io.Discard, r.Body) + r.Body.Close() + w.WriteHeader(200) + }), + trFunc, + ) + defer cst.close() + + req, err := http.NewRequest("PUT", cst.ts.URL, r) + if err != nil { + t.Fatal(err) + } + req.ContentLength = tc.contentLength + req.Header.Set("Content-Type", "application/octet-stream") + resp, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("status code = %d; want 200", resp.StatusCode) + } + + if !tConn.ReadFromCalled && tc.expectedReadFrom { + t.Fatalf("did not call ReadFrom") + } + + if tConn.ReadFromCalled && !tc.expectedReadFrom { + t.Fatalf("ReadFrom was unexpectedly invoked") + } + }) + } +} + +func TestTransportClone(t *testing.T) { + tr := &Transport{ + Proxy: func(*http.Request) (*url.URL, error) { panic("") }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, + TLSClientConfig: new(tls.Config), + TLSHandshakeTimeout: time.Second, + DisableKeepAlives: true, + DisableCompression: true, + MaxIdleConns: 1, + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + IdleConnTimeout: time.Second, + ResponseHeaderTimeout: time.Second, + ExpectContinueTimeout: time.Second, + ProxyConnectHeader: http.Header{}, + GetProxyConnectHeader: func(context.Context, *url.URL, string) (http.Header, error) { return nil, nil }, + MaxResponseHeaderBytes: 1, + ForceAttemptHTTP2: true, + TLSNextProto: map[string]func(authority string, c TLSConn) http.RoundTripper{ + "foo": func(authority string, c TLSConn) http.RoundTripper { panic("") }, + }, + ReadBufferSize: 1, + WriteBufferSize: 1, + ForceHttpVersion: HTTP1, + ResponseOptions: &ResponseOptions{}, + Debugf: func(format string, v ...interface{}) {}, + } + tr2 := tr.Clone() + rv := reflect.ValueOf(tr2).Elem() + rt := rv.Type() + for i := 0; i < rt.NumField(); i++ { + sf := rt.Field(i) + if !token.IsExported(sf.Name) { + continue + } + if rv.Field(i).IsZero() { + t.Errorf("cloned field t2.%s is zero", sf.Name) + } + } + + if _, ok := tr2.TLSNextProto["foo"]; !ok { + t.Errorf("cloned Transport lacked TLSNextProto 'foo' key") + } + + // But test that a nil TLSNextProto is kept nil: + tr = new(Transport) + tr2 = tr.Clone() + if tr2.TLSNextProto != nil { + t.Errorf("Transport.TLSNextProto unexpected non-nil") + } +} + +func TestIs408(t *testing.T) { + tests := []struct { + in string + want bool + }{ + {"HTTP/1.0 408", true}, + {"HTTP/1.1 408", true}, + {"HTTP/1.8 408", true}, + {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now. + {"HTTP/1.1 408 ", true}, + {"HTTP/1.1 40", false}, + {"http/1.0 408", false}, + {"HTTP/1-1 408", false}, + } + for _, tt := range tests { + if got := is408Message([]byte(tt.in)); got != tt.want { + t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want) + } + } +} + +func TestTransportIgnores408(t *testing.T) { + // Not parallel. Relies on mutating the log package's global Output. + defer log.SetOutput(log.Writer()) + + var logout bytes.Buffer + log.SetOutput(&logout) + + defer afterTest(t) + const target = "backend:443" + + cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nc, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer nc.Close() + nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) + nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail + })) + defer cst.close() + req, err := http.NewRequest("GET", cst.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if err != nil { + t.Fatal(err) + } + if string(slurp) != "ok" { + t.Fatalf("got %q; want ok", slurp) + } + + t0 := time.Now() + for i := 0; i < 50; i++ { + time.Sleep(time.Duration(i) * 5 * time.Millisecond) + if cst.tr.IdleConnKeyCountForTesting() == 0 { + if got := logout.String(); got != "" { + t.Fatalf("expected no log output; got: %s", got) + } + return + } + } + t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0)) +} + +func TestInvalidHeaderResponse(t *testing.T) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, buf, _ := w.(http.Hijacker).Hijack() + buf.Write([]byte("HTTP/1.1 200 OK\r\n" + + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + + "Content-Type: text/html; charset=utf-8\r\n" + + "Content-Length: 0\r\n" + + "Foo : bar\r\n\r\n")) + buf.Flush() + conn.Close() + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if v := res.Header.Get("Foo"); v != "" { + t.Errorf(`unexpected "Foo" header: %q`, v) + } + if v := res.Header.Get("Foo "); v != "bar" { + t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar") + } +} + +type bodyCloser bool + +func (bc *bodyCloser) Close() error { + *bc = true + return nil +} +func (bc *bodyCloser) Read(b []byte) (n int, err error) { + return 0, io.EOF +} + +// Issue 35015: ensure that Transport closes the body on any error +// with an invalid request, as promised by Client.Do docs. +func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { + cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Errorf("Should not have been invoked") + })) + defer cst.Close() + + u, _ := url.Parse(cst.URL) + + tests := []struct { + name string + req *http.Request + wantErr string + }{ + { + name: "invalid method", + req: &http.Request{ + Method: " ", + URL: u, + }, + wantErr: "invalid method", + }, + { + name: "nil URL", + req: &http.Request{ + Method: "GET", + }, + wantErr: "nil Request.URL", + }, + { + name: "invalid header key", + req: &http.Request{ + Method: "GET", + Header: http.Header{"💡": {"emoji"}}, + URL: u, + }, + wantErr: "invalid header field name", + }, + { + name: "invalid header value", + req: &http.Request{ + Method: "POST", + Header: http.Header{"key": {"\x19"}}, + URL: u, + }, + wantErr: "invalid header field value", + }, + { + name: "non HTTP(s) scheme", + req: &http.Request{ + Method: "POST", + URL: &url.URL{Scheme: "faux"}, + }, + wantErr: "unsupported protocol scheme", + }, + { + name: "no Host in URL", + req: &http.Request{ + Method: "POST", + URL: &url.URL{Scheme: "http"}, + }, + wantErr: "no Host", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var bc bodyCloser + req := tt.req + req.Body = &bc + _, err := DefaultClient().httpClient.Do(tt.req) + if err == nil { + t.Fatal("Expected an error") + } + if !bc { + t.Fatal("Expected body to have been closed") + } + if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) { + t.Fatalf("Error mismatch\n\t%q\ndoes not contain\n\t%q", g, w) + } + }) + } +} + +// breakableConn is a net.Conn wrapper with a Write method +// that will fail when its brokenState is true. +type breakableConn struct { + net.Conn + *brokenState +} + +type brokenState struct { + sync.Mutex + broken bool +} + +func (w *breakableConn) Write(b []byte) (n int, err error) { + w.Lock() + defer w.Unlock() + if w.broken { + return 0, errors.New("some write error") + } + return w.Conn.Write(b) +} + +// Issue 34978: don't cache a broken HTTP/2 connection +func TestDontCacheBrokenHTTP2Conn(t *testing.T) { + cst := newClientServerTest(t, h2Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), optQuietLog) + defer cst.close() + + var brokenState brokenState + + const numReqs = 5 + var numDials, gotConns uint32 // atomic + + cst.tr.DialContext = func(_ context.Context, netw, addr string) (net.Conn, error) { + atomic.AddUint32(&numDials, 1) + c, err := net.Dial(netw, addr) + if err != nil { + t.Errorf("unexpected Dial error: %v", err) + return nil, err + } + return &breakableConn{c, &brokenState}, err + } + + for i := 1; i <= numReqs; i++ { + brokenState.Lock() + brokenState.broken = false + brokenState.Unlock() + + // doBreak controls whether we break the TCP connection after the TLS + // handshake (before the HTTP/2 handshake). We test a few failures + // in a row followed by a final success. + doBreak := i != numReqs + + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + GotConn: func(info httptrace.GotConnInfo) { + t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime) + atomic.AddUint32(&gotConns, 1) + }, + TLSHandshakeDone: func(cfg tls.ConnectionState, err error) { + brokenState.Lock() + defer brokenState.Unlock() + if doBreak { + brokenState.broken = true + } + }, + }) + req, err := http.NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + _, err = cst.c.Do(req) + if doBreak != (err != nil) { + t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err) + } + } + if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want { + t.Errorf("GotConn calls = %v; want %v", got, want) + } + if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want { + t.Errorf("Dials = %v; want %v", got, want) + } +} + +// Issue 34941 +// When the client has too many concurrent requests on a single connection, +// http.http2noCachedConnError is reported on multiple requests. There should +// only be one decrement regardless of the number of failures. +func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { + defer afterTest(t) + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + }) + + ts := httptest.NewUnstartedServer(h) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + c := tc().httpClient + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 + + errCh := make(chan error, 300) + doReq := func() { + resp, err := c.Get(ts.URL) + if err != nil { + errCh <- fmt.Errorf("request failed: %v", err) + return + } + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + if err != nil { + errCh <- fmt.Errorf("read body failed: %v", err) + } + } + + var wg sync.WaitGroup + for i := 0; i < 300; i++ { + wg.Add(1) + go func() { + defer wg.Done() + doReq() + }() + } + wg.Wait() + close(errCh) + + for err := range errCh { + t.Errorf("error occurred: %v", err) + } +} + +// Issue 36820 +// Test that we use the older backward compatible cancellation protocol +// when a http.RoundTripper is registered via RegisterProtocol. +func TestAltProtoCancellation(t *testing.T) { + defer afterTest(t) + tr := &Transport{} + c := &http.Client{ + Transport: tr, + Timeout: time.Millisecond, + } + tr.RegisterProtocol("timeout", timeoutProto{}) + _, err := c.Get("timeout://bar.com/path") + if err == nil { + t.Error("request unexpectedly succeeded") + } else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) { + t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr) + } +} + +var timeoutProtoErr = errors.New("canceled as expected") + +type timeoutProto struct{} + +func (timeoutProto) RoundTrip(req *http.Request) (*http.Response, error) { + select { + case <-req.Cancel: + return nil, timeoutProtoErr + case <-time.After(5 * time.Second): + return nil, errors.New("request was not canceled") + } +} + +// Issue 32441: body is not reset after ErrSkipAltProtocol +func TestIssue32441(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if n, _ := io.Copy(io.Discard, r.Body); n == 0 { + t.Error("body length is zero") + } + })) + defer ts.Close() + c := tc().httpClient + c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *http.Request) (*http.Response, error) { + // Draining body to trigger failure condition on actual request to server. + if n, _ := io.Copy(io.Discard, r.Body); n == 0 { + t.Error("body length is zero during round trip") + } + return nil, http.ErrSkipAltProtocol + })) + if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil { + t.Error(err) + } +} + +// Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers +// that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. +func TestTransportRejectsSignInContentLength(t *testing.T) { + cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "+3") + w.Write([]byte("abc")) + })) + defer cst.Close() + + c := cst.Client() + res, err := c.Get(cst.URL) + if err == nil || res != nil { + t.Fatal("Expected a non-nil error and a nil http.Response") + } + if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) { + t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want) + } +} + +// dumpConn is a net.Conn which writes to Writer and reads from Reader +type dumpConn struct { + io.Writer + io.Reader +} + +func (c *dumpConn) Close() error { return nil } +func (c *dumpConn) LocalAddr() net.Addr { return nil } +func (c *dumpConn) RemoteAddr() net.Addr { return nil } +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } + +// delegateReader is a reader that delegates to another reader, +// once it arrives on a channel. +type delegateReader struct { + c chan io.Reader + r io.Reader // nil until received from c +} + +func (r *delegateReader) Read(p []byte) (int, error) { + if r.r == nil { + var ok bool + if r.r, ok = <-r.c; !ok { + return 0, errors.New("delegate closed") + } + } + return r.r.Read(p) +} + +func testTransportRace(req *http.Request) { + save := req.Body + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + dr := &delegateReader{c: make(chan io.Reader)} + + t := &Transport{ + DialContext: func(_ context.Context, net, addr string) (net.Conn, error) { + return &dumpConn{pw, dr}, nil + }, + } + defer t.CloseIdleConnections() + + quitReadCh := make(chan struct{}) + // Wait for the request before replying with a dummy response: + go func() { + defer close(quitReadCh) + + req, err := http.ReadRequest(bufio.NewReader(pr)) + if err == nil { + // Ensure all the body is read; otherwise + // we'll get a partial dump. + io.Copy(io.Discard, req.Body) + req.Body.Close() + } + select { + case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"): + case quitReadCh <- struct{}{}: + // Ensure delegate is closed so Read doesn't block forever. + close(dr.c) + } + }() + + t.RoundTrip(req) + + // Ensure the reader returns before we reset req.Body to prevent + // a data race on req.Body. + pw.Close() + <-quitReadCh + + req.Body = save +} + +// Issue 37669 +// Test that a cancellation doesn't result in a data race due to the writeLoop +// goroutine being left running, if the caller mutates the processed Request +// upon completion. +func TestErrorWriteLoopRace(t *testing.T) { + if testing.Short() { + return + } + t.Parallel() + for i := 0; i < 1000; i++ { + delay := time.Duration(mrand.Intn(5)) * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), delay) + defer cancel() + + r := bytes.NewBuffer(make([]byte, 10000)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://example.com", r) + if err != nil { + t.Fatal(err) + } + + testTransportRace(req) + } +} + +// Issue 41600 +// Test that a new request which uses the connection of an active request +// cannot cause it to be canceled as well. +func TestCancelRequestWhenSharingConnection(t *testing.T) { + reqc := make(chan chan struct{}, 2) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ch := make(chan struct{}, 1) + reqc <- ch + <-ch + w.Header().Add("Content-Length", "0") + })) + defer ts.Close() + + client := tc().httpClient + transport := client.Transport.(*Transport) + transport.MaxIdleConns = 1 + transport.MaxConnsPerHost = 1 + + var wg sync.WaitGroup + + wg.Add(1) + putidlec := make(chan chan struct{}) + go func() { + defer wg.Done() + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + PutIdleConn: func(error) { + // Signal that the idle conn has been returned to the pool, + // and wait for the order to proceed. + ch := make(chan struct{}) + putidlec <- ch + <-ch + }, + }) + req, _ := http.NewRequestWithContext(ctx, "GET", ts.URL, nil) + res, err := client.Do(req) + if err == nil { + res.Body.Close() + } + if err != nil { + t.Errorf("request 1: got err %v, want nil", err) + } + }() + + // Wait for the first request to receive a response and return the + // connection to the idle pool. + r1c := <-reqc + close(r1c) + idlec := <-putidlec + + wg.Add(1) + cancelctx, cancel := context.WithCancel(context.Background()) + go func() { + defer wg.Done() + req, _ := http.NewRequestWithContext(cancelctx, "GET", ts.URL, nil) + res, err := client.Do(req) + if err == nil { + res.Body.Close() + } + if !errors.Is(err, context.Canceled) { + t.Errorf("request 2: got err %v, want Canceled", err) + } + }() + + // Wait for the second request to arrive at the server, and then cancel + // the request context. + r2c := <-reqc + cancel() + + // Give the cancelation a moment to take effect, and then unblock the first request. + time.Sleep(1 * time.Millisecond) + close(idlec) + + close(r2c) + wg.Wait() +} From 514b188e9d42823c14f8f633e3ffbc47b7e16dd1 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 14 Feb 2022 17:37:13 +0800 Subject: [PATCH 341/843] extract socks internal package --- internal/socks/client.go | 168 +++++++++++++++ internal/socks/dial_test.go | 395 ++++++++++++++++++++++++++++++++++++ internal/socks/socks.go | 288 ++++++++++++++++++++++++++ transport.go | 11 +- 4 files changed, 857 insertions(+), 5 deletions(-) create mode 100644 internal/socks/client.go create mode 100644 internal/socks/dial_test.go create mode 100644 internal/socks/socks.go diff --git a/internal/socks/client.go b/internal/socks/client.go new file mode 100644 index 00000000..3d6f516a --- /dev/null +++ b/internal/socks/client.go @@ -0,0 +1,168 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package socks + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "time" +) + +var ( + noDeadline = time.Time{} + aLongTimeAgo = time.Unix(1, 0) +) + +func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) { + host, port, err := splitHostPort(address) + if err != nil { + return nil, err + } + if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { + c.SetDeadline(deadline) + defer c.SetDeadline(noDeadline) + } + if ctx != context.Background() { + errCh := make(chan error, 1) + done := make(chan struct{}) + defer func() { + close(done) + if ctxErr == nil { + ctxErr = <-errCh + } + }() + go func() { + select { + case <-ctx.Done(): + c.SetDeadline(aLongTimeAgo) + errCh <- ctx.Err() + case <-done: + errCh <- nil + } + }() + } + + b := make([]byte, 0, 6+len(host)) // the size here is just an estimate + b = append(b, Version5) + if len(d.AuthMethods) == 0 || d.Authenticate == nil { + b = append(b, 1, byte(AuthMethodNotRequired)) + } else { + ams := d.AuthMethods + if len(ams) > 255 { + return nil, errors.New("too many authentication methods") + } + b = append(b, byte(len(ams))) + for _, am := range ams { + b = append(b, byte(am)) + } + } + if _, ctxErr = c.Write(b); ctxErr != nil { + return + } + + if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil { + return + } + if b[0] != Version5 { + return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + } + am := AuthMethod(b[1]) + if am == AuthMethodNoAcceptableMethods { + return nil, errors.New("no acceptable authentication methods") + } + if d.Authenticate != nil { + if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil { + return + } + } + + b = b[:0] + b = append(b, Version5, byte(d.cmd), 0) + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + b = append(b, AddrTypeIPv4) + b = append(b, ip4...) + } else if ip6 := ip.To16(); ip6 != nil { + b = append(b, AddrTypeIPv6) + b = append(b, ip6...) + } else { + return nil, errors.New("unknown address type") + } + } else { + if len(host) > 255 { + return nil, errors.New("FQDN too long") + } + b = append(b, AddrTypeFQDN) + b = append(b, byte(len(host))) + b = append(b, host...) + } + b = append(b, byte(port>>8), byte(port)) + if _, ctxErr = c.Write(b); ctxErr != nil { + return + } + + if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil { + return + } + if b[0] != Version5 { + return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + } + if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded { + return nil, errors.New("unknown error " + cmdErr.String()) + } + if b[2] != 0 { + return nil, errors.New("non-zero reserved field") + } + l := 2 + var a Addr + switch b[3] { + case AddrTypeIPv4: + l += net.IPv4len + a.IP = make(net.IP, net.IPv4len) + case AddrTypeIPv6: + l += net.IPv6len + a.IP = make(net.IP, net.IPv6len) + case AddrTypeFQDN: + if _, err := io.ReadFull(c, b[:1]); err != nil { + return nil, err + } + l += int(b[0]) + default: + return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3]))) + } + if cap(b) < l { + b = make([]byte, l) + } else { + b = b[:l] + } + if _, ctxErr = io.ReadFull(c, b); ctxErr != nil { + return + } + if a.IP != nil { + copy(a.IP, b) + } else { + a.Name = string(b[:len(b)-2]) + } + a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1]) + return &a, nil +} + +func splitHostPort(address string) (string, int, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return "", 0, err + } + portnum, err := strconv.Atoi(port) + if err != nil { + return "", 0, err + } + if 1 > portnum || portnum > 0xffff { + return "", 0, errors.New("port number out of range " + port) + } + return host, portnum, nil +} diff --git a/internal/socks/dial_test.go b/internal/socks/dial_test.go new file mode 100644 index 00000000..7a10a57d --- /dev/null +++ b/internal/socks/dial_test.go @@ -0,0 +1,395 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package socks + +import ( + "context" + "errors" + "golang.org/x/net/nettest" + "io" + "math/rand" + "net" + "os" + "testing" + "time" +) + +// An AuthRequest represents an authentication request. +type AuthRequest struct { + Version int + Methods []AuthMethod +} + +// ParseAuthRequest parses an authentication request. +func ParseAuthRequest(b []byte) (*AuthRequest, error) { + if len(b) < 2 { + return nil, errors.New("short auth request") + } + if b[0] != Version5 { + return nil, errors.New("unexpected protocol version") + } + if len(b)-2 < int(b[1]) { + return nil, errors.New("short auth request") + } + req := &AuthRequest{Version: int(b[0])} + if b[1] > 0 { + req.Methods = make([]AuthMethod, b[1]) + for i, m := range b[2 : 2+b[1]] { + req.Methods[i] = AuthMethod(m) + } + } + return req, nil +} + +// MarshalAuthReply returns an authentication reply in wire format. +func MarshalAuthReply(ver int, m AuthMethod) ([]byte, error) { + return []byte{byte(ver), byte(m)}, nil +} + +// A CmdRequest repesents a command request. +type CmdRequest struct { + Version int + Cmd Command + Addr Addr +} + +// ParseCmdRequest parses a command request. +func ParseCmdRequest(b []byte) (*CmdRequest, error) { + if len(b) < 7 { + return nil, errors.New("short cmd request") + } + if b[0] != Version5 { + return nil, errors.New("unexpected protocol version") + } + if Command(b[1]) != CmdConnect { + return nil, errors.New("unexpected command") + } + if b[2] != 0 { + return nil, errors.New("non-zero reserved field") + } + req := &CmdRequest{Version: int(b[0]), Cmd: Command(b[1])} + l := 2 + off := 4 + switch b[3] { + case AddrTypeIPv4: + l += net.IPv4len + req.Addr.IP = make(net.IP, net.IPv4len) + case AddrTypeIPv6: + l += net.IPv6len + req.Addr.IP = make(net.IP, net.IPv6len) + case AddrTypeFQDN: + l += int(b[4]) + off = 5 + default: + return nil, errors.New("unknown address type") + } + if len(b[off:]) < l { + return nil, errors.New("short cmd request") + } + if req.Addr.IP != nil { + copy(req.Addr.IP, b[off:]) + } else { + req.Addr.Name = string(b[off : off+l-2]) + } + req.Addr.Port = int(b[off+l-2])<<8 | int(b[off+l-1]) + return req, nil +} + +// MarshalCmdReply returns a command reply in wire format. +func MarshalCmdReply(ver int, reply Reply, a *Addr) ([]byte, error) { + b := make([]byte, 4) + b[0] = byte(ver) + b[1] = byte(reply) + if a.Name != "" { + if len(a.Name) > 255 { + return nil, errors.New("fqdn too long") + } + b[3] = AddrTypeFQDN + b = append(b, byte(len(a.Name))) + b = append(b, a.Name...) + } else if ip4 := a.IP.To4(); ip4 != nil { + b[3] = AddrTypeIPv4 + b = append(b, ip4...) + } else if ip6 := a.IP.To16(); ip6 != nil { + b[3] = AddrTypeIPv6 + b = append(b, ip6...) + } else { + return nil, errors.New("unknown address type") + } + b = append(b, byte(a.Port>>8), byte(a.Port)) + return b, nil +} + +// A Server repesents a server for handshake testing. +type Server struct { + ln net.Listener +} + +// Addr rerurns a server address. +func (s *Server) Addr() net.Addr { + return s.ln.Addr() +} + +// TargetAddr returns a fake final destination address. +// +// The returned address is only valid for testing with Server. +func (s *Server) TargetAddr() net.Addr { + a := s.ln.Addr() + switch a := a.(type) { + case *net.TCPAddr: + if a.IP.To4() != nil { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963} + } + if a.IP.To16() != nil && a.IP.To4() == nil { + return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963} + } + } + return nil +} + +// Close closes the server. +func (s *Server) Close() error { + return s.ln.Close() +} + +func (s *Server) serve(authFunc, cmdFunc func(io.ReadWriter, []byte) error) { + c, err := s.ln.Accept() + if err != nil { + return + } + defer c.Close() + go s.serve(authFunc, cmdFunc) + b := make([]byte, 512) + n, err := c.Read(b) + if err != nil { + return + } + if err := authFunc(c, b[:n]); err != nil { + return + } + n, err = c.Read(b) + if err != nil { + return + } + if err := cmdFunc(c, b[:n]); err != nil { + return + } +} + +// NewServer returns a new server. +// +// The provided authFunc and cmdFunc must parse requests and return +// appropriate replies to clients. +func NewServer(authFunc, cmdFunc func(io.ReadWriter, []byte) error) (*Server, error) { + var err error + s := new(Server) + s.ln, err = nettest.NewLocalListener("tcp") + if err != nil { + return nil, err + } + go s.serve(authFunc, cmdFunc) + return s, nil +} + +// NoAuthRequired handles a no-authentication-required signaling. +func NoAuthRequired(rw io.ReadWriter, b []byte) error { + req, err := ParseAuthRequest(b) + if err != nil { + return err + } + b, err = MarshalAuthReply(req.Version, AuthMethodNotRequired) + if err != nil { + return err + } + n, err := rw.Write(b) + if err != nil { + return err + } + if n != len(b) { + return errors.New("short write") + } + return nil +} + +// NoProxyRequired handles a command signaling without constructing a +// proxy connection to the final destination. +func NoProxyRequired(rw io.ReadWriter, b []byte) error { + req, err := ParseCmdRequest(b) + if err != nil { + return err + } + req.Addr.Port += 1 + if req.Addr.Name != "" { + req.Addr.Name = "boundaddr.doesnotexist" + } else if req.Addr.IP.To4() != nil { + req.Addr.IP = net.IPv4(127, 0, 0, 1) + } else { + req.Addr.IP = net.IPv6loopback + } + b, err = MarshalCmdReply(Version5, StatusSucceeded, &req.Addr) + if err != nil { + return err + } + n, err := rw.Write(b) + if err != nil { + return err + } + if n != len(b) { + return errors.New("short write") + } + return nil +} + +func TestDial(t *testing.T) { + t.Run("Connect", func(t *testing.T) { + ss, err := NewServer(NoAuthRequired, NoProxyRequired) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + d := NewDialer(ss.Addr().Network(), ss.Addr().String()) + d.AuthMethods = []AuthMethod{ + AuthMethodNotRequired, + AuthMethodUsernamePassword, + } + d.Authenticate = (&UsernamePassword{ + Username: "username", + Password: "password", + }).Authenticate + c, err := d.DialContext(context.Background(), ss.TargetAddr().Network(), ss.TargetAddr().String()) + if err != nil { + t.Fatal(err) + } + c.(*Conn).BoundAddr() + c.Close() + }) + t.Run("ConnectWithConn", func(t *testing.T) { + ss, err := NewServer(NoAuthRequired, NoProxyRequired) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + c, err := net.Dial(ss.Addr().Network(), ss.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + d := NewDialer(ss.Addr().Network(), ss.Addr().String()) + d.AuthMethods = []AuthMethod{ + AuthMethodNotRequired, + AuthMethodUsernamePassword, + } + d.Authenticate = (&UsernamePassword{ + Username: "username", + Password: "password", + }).Authenticate + a, err := d.DialWithConn(context.Background(), c, ss.TargetAddr().Network(), ss.TargetAddr().String()) + if err != nil { + t.Fatal(err) + } + if _, ok := a.(*Addr); !ok { + t.Fatalf("got %+v; want Addr", a) + } + }) + t.Run("Cancel", func(t *testing.T) { + ss, err := NewServer(NoAuthRequired, blackholeCmdFunc) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + d := NewDialer(ss.Addr().Network(), ss.Addr().String()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dialErr := make(chan error) + go func() { + c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String()) + if err == nil { + c.Close() + } + dialErr <- err + }() + time.Sleep(100 * time.Millisecond) + cancel() + err = <-dialErr + if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil { + t.Fatalf("got %v; want context.Canceled or equivalent", err) + } + }) + t.Run("Deadline", func(t *testing.T) { + ss, err := NewServer(NoAuthRequired, blackholeCmdFunc) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + d := NewDialer(ss.Addr().Network(), ss.Addr().String()) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) + defer cancel() + c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String()) + if err == nil { + c.Close() + } + if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil { + t.Fatalf("got %v; want context.DeadlineExceeded or equivalent", err) + } + }) + t.Run("WithRogueServer", func(t *testing.T) { + ss, err := NewServer(NoAuthRequired, rogueCmdFunc) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + d := NewDialer(ss.Addr().Network(), ss.Addr().String()) + for i := 0; i < 2*len(rogueCmdList); i++ { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) + defer cancel() + c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String()) + if err == nil { + t.Log(c.(*Conn).BoundAddr()) + c.Close() + t.Error("should fail") + } + } + }) +} + +func blackholeCmdFunc(rw io.ReadWriter, b []byte) error { + if _, err := ParseCmdRequest(b); err != nil { + return err + } + var bb [1]byte + for { + if _, err := rw.Read(bb[:]); err != nil { + return err + } + } +} + +func rogueCmdFunc(rw io.ReadWriter, b []byte) error { + if _, err := ParseCmdRequest(b); err != nil { + return err + } + rw.Write(rogueCmdList[rand.Intn(len(rogueCmdList))]) + return nil +} + +var rogueCmdList = [][]byte{ + {0x05}, + {0x06, 0x00, 0x00, 0x01, 192, 0, 2, 1, 0x17, 0x4b}, + {0x05, 0x00, 0xff, 0x01, 192, 0, 2, 2, 0x17, 0x4b}, + {0x05, 0x00, 0x00, 0x01, 192, 0, 2, 3}, + {0x05, 0x00, 0x00, 0x03, 0x04, 'F', 'Q', 'D', 'N'}, +} + +func parseDialError(err error) (perr, nerr error) { + if e, ok := err.(*net.OpError); ok { + err = e.Err + nerr = e + } + if e, ok := err.(*os.SyscallError); ok { + err = e.Err + } + perr = err + return +} diff --git a/internal/socks/socks.go b/internal/socks/socks.go new file mode 100644 index 00000000..55afd6be --- /dev/null +++ b/internal/socks/socks.go @@ -0,0 +1,288 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package socks provides a SOCKS version 5 client implementation. +// +// SOCKS protocol version 5 is defined in RFC 1928. +// Username/Password authentication for SOCKS version 5 is defined in +// RFC 1929. +package socks + +import ( + "context" + "errors" + "io" + "net" + "strconv" +) + +// A Command represents a SOCKS command. +type Command int + +func (cmd Command) String() string { + switch cmd { + case CmdConnect: + return "socks connect" + case cmdBind: + return "socks bind" + default: + return "socks " + strconv.Itoa(int(cmd)) + } +} + +// An AuthMethod represents a SOCKS authentication method. +type AuthMethod int + +// A Reply represents a SOCKS command reply code. +type Reply int + +func (code Reply) String() string { + switch code { + case StatusSucceeded: + return "succeeded" + case 0x01: + return "general SOCKS server failure" + case 0x02: + return "connection not allowed by ruleset" + case 0x03: + return "network unreachable" + case 0x04: + return "host unreachable" + case 0x05: + return "connection refused" + case 0x06: + return "TTL expired" + case 0x07: + return "command not supported" + case 0x08: + return "address type not supported" + default: + return "unknown code: " + strconv.Itoa(int(code)) + } +} + +// Wire protocol constants. +const ( + Version5 = 0x05 + + AddrTypeIPv4 = 0x01 + AddrTypeFQDN = 0x03 + AddrTypeIPv6 = 0x04 + + CmdConnect Command = 0x01 // establishes an active-open forward proxy connection + cmdBind Command = 0x02 // establishes a passive-open forward proxy connection + + AuthMethodNotRequired AuthMethod = 0x00 // no authentication required + AuthMethodUsernamePassword AuthMethod = 0x02 // use username/password + AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authentication methods + + StatusSucceeded Reply = 0x00 +) + +// An Addr represents a SOCKS-specific address. +// Either Name or IP is used exclusively. +type Addr struct { + Name string // fully-qualified domain name + IP net.IP + Port int +} + +func (a *Addr) Network() string { return "socks" } + +func (a *Addr) String() string { + if a == nil { + return "" + } + port := strconv.Itoa(a.Port) + if a.IP == nil { + return net.JoinHostPort(a.Name, port) + } + return net.JoinHostPort(a.IP.String(), port) +} + +// A Conn represents a forward proxy connection. +type Conn struct { + net.Conn + + boundAddr net.Addr +} + +// BoundAddr returns the address assigned by the proxy server for +// connecting to the command target address from the proxy server. +func (c *Conn) BoundAddr() net.Addr { + if c == nil { + return nil + } + return c.boundAddr +} + +// A Dialer holds SOCKS-specific options. +type Dialer struct { + cmd Command // either CmdConnect or cmdBind + proxyNetwork string // network between a proxy server and a client + proxyAddress string // proxy server address + + // ProxyDial specifies the optional dial function for + // establishing the transport connection. + ProxyDial func(context.Context, string, string) (net.Conn, error) + + // AuthMethods specifies the list of request authentication + // methods. + // If empty, SOCKS client requests only AuthMethodNotRequired. + AuthMethods []AuthMethod + + // Authenticate specifies the optional authentication + // function. It must be non-nil when AuthMethods is not empty. + // It must return an error when the authentication is failed. + Authenticate func(context.Context, io.ReadWriter, AuthMethod) error +} + +// DialContext connects to the provided address on the provided +// network. +// +// The returned error value may be a net.OpError. When the Op field of +// net.OpError contains "socks", the Source field contains a proxy +// server address and the Addr field contains a command target +// address. +// +// See func Dial of the net package of standard library for a +// description of the network and address parameters. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if ctx == nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} + } + var err error + var c net.Conn + if d.ProxyDial != nil { + c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress) + } else { + var dd net.Dialer + c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress) + } + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + a, err := d.connect(ctx, c, address) + if err != nil { + c.Close() + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + return &Conn{Conn: c, boundAddr: a}, nil +} + +// DialWithConn initiates a connection from SOCKS server to the target +// network and address using the connection c that is already +// connected to the SOCKS server. +// +// It returns the connection's local address assigned by the SOCKS +// server. +func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if ctx == nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} + } + a, err := d.connect(ctx, c, address) + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + return a, nil +} + +func (d *Dialer) validateTarget(network, address string) error { + switch network { + case "tcp", "tcp6", "tcp4": + default: + return errors.New("network not implemented") + } + switch d.cmd { + case CmdConnect, cmdBind: + default: + return errors.New("command not implemented") + } + return nil +} + +func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) { + for i, s := range []string{d.proxyAddress, address} { + host, port, err := splitHostPort(s) + if err != nil { + return nil, nil, err + } + a := &Addr{Port: port} + a.IP = net.ParseIP(host) + if a.IP == nil { + a.Name = host + } + if i == 0 { + proxy = a + } else { + dst = a + } + } + return +} + +// NewDialer returns a new Dialer that dials through the provided +// proxy server's network and address. +func NewDialer(network, address string) *Dialer { + return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect} +} + +const ( + authUsernamePasswordVersion = 0x01 + authStatusSucceeded = 0x00 +) + +// UsernamePassword are the credentials for the username/password +// authentication method. +type UsernamePassword struct { + Username string + Password string +} + +// Authenticate authenticates a pair of username and password with the +// proxy server. +func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error { + switch auth { + case AuthMethodNotRequired: + return nil + case AuthMethodUsernamePassword: + if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) == 0 || len(up.Password) > 255 { + return errors.New("invalid username/password") + } + b := []byte{authUsernamePasswordVersion} + b = append(b, byte(len(up.Username))) + b = append(b, up.Username...) + b = append(b, byte(len(up.Password))) + b = append(b, up.Password...) + // TODO(mikio): handle IO deadlines and cancelation if + // necessary + if _, err := rw.Write(b); err != nil { + return err + } + if _, err := io.ReadFull(rw, b[:2]); err != nil { + return err + } + if b[0] != authUsernamePasswordVersion { + return errors.New("invalid username/password version") + } + if b[1] != authStatusSucceeded { + return errors.New("username/password authentication failed") + } + return nil + } + return errors.New("unsupported authentication method " + strconv.Itoa(int(auth))) +} diff --git a/transport.go b/transport.go index d7716e37..f6c830f0 100644 --- a/transport.go +++ b/transport.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/socks" "github.com/imroc/req/v3/internal/util" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" @@ -1587,15 +1588,15 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // Do nothing. Not using a proxy. case cm.proxyURL.Scheme == "socks5": conn := pconn.conn - d := socksNewDialer("tcp", conn.RemoteAddr().String()) + d := socks.NewDialer("tcp", conn.RemoteAddr().String()) if u := cm.proxyURL.User; u != nil { - auth := &socksUsernamePassword{ + auth := &socks.UsernamePassword{ Username: u.Username(), } auth.Password, _ = u.Password() - d.AuthMethods = []socksAuthMethod{ - socksAuthMethodNotRequired, - socksAuthMethodUsernamePassword, + d.AuthMethods = []socks.AuthMethod{ + socks.AuthMethodNotRequired, + socks.AuthMethodUsernamePassword, } d.Authenticate = auth.Authenticate } From a739fbe84fd436df758d904fff8bb45fb510c0dd Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 14 Feb 2022 17:44:40 +0800 Subject: [PATCH 342/843] remove socks_bundle.go --- socks_bundle.go | 473 ------------------------------------------------ 1 file changed, 473 deletions(-) delete mode 100644 socks_bundle.go diff --git a/socks_bundle.go b/socks_bundle.go deleted file mode 100644 index 940e0bb3..00000000 --- a/socks_bundle.go +++ /dev/null @@ -1,473 +0,0 @@ -// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. -//go:generate bundle -o socks_bundle.go -prefix socks golang.org/x/net/internal/socks - -// Package socks provides a SOCKS version 5 client implementation. -// -// SOCKS protocol version 5 is defined in RFC 1928. -// Username/Password authentication for SOCKS version 5 is defined in -// RFC 1929. -// - -package req - -import ( - "context" - "errors" - "io" - "net" - "strconv" - "time" -) - -var ( - socksnoDeadline = time.Time{} - socksaLongTimeAgo = time.Unix(1, 0) -) - -func (d *socksDialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) { - host, port, err := sockssplitHostPort(address) - if err != nil { - return nil, err - } - if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { - c.SetDeadline(deadline) - defer c.SetDeadline(socksnoDeadline) - } - if ctx != context.Background() { - errCh := make(chan error, 1) - done := make(chan struct{}) - defer func() { - close(done) - if ctxErr == nil { - ctxErr = <-errCh - } - }() - go func() { - select { - case <-ctx.Done(): - c.SetDeadline(socksaLongTimeAgo) - errCh <- ctx.Err() - case <-done: - errCh <- nil - } - }() - } - - b := make([]byte, 0, 6+len(host)) // the size here is just an estimate - b = append(b, socksVersion5) - if len(d.AuthMethods) == 0 || d.Authenticate == nil { - b = append(b, 1, byte(socksAuthMethodNotRequired)) - } else { - ams := d.AuthMethods - if len(ams) > 255 { - return nil, errors.New("too many authentication methods") - } - b = append(b, byte(len(ams))) - for _, am := range ams { - b = append(b, byte(am)) - } - } - if _, ctxErr = c.Write(b); ctxErr != nil { - return - } - - if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil { - return - } - if b[0] != socksVersion5 { - return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) - } - am := socksAuthMethod(b[1]) - if am == socksAuthMethodNoAcceptableMethods { - return nil, errors.New("no acceptable authentication methods") - } - if d.Authenticate != nil { - if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil { - return - } - } - - b = b[:0] - b = append(b, socksVersion5, byte(d.cmd), 0) - if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil { - b = append(b, socksAddrTypeIPv4) - b = append(b, ip4...) - } else if ip6 := ip.To16(); ip6 != nil { - b = append(b, socksAddrTypeIPv6) - b = append(b, ip6...) - } else { - return nil, errors.New("unknown address type") - } - } else { - if len(host) > 255 { - return nil, errors.New("FQDN too long") - } - b = append(b, socksAddrTypeFQDN) - b = append(b, byte(len(host))) - b = append(b, host...) - } - b = append(b, byte(port>>8), byte(port)) - if _, ctxErr = c.Write(b); ctxErr != nil { - return - } - - if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil { - return - } - if b[0] != socksVersion5 { - return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) - } - if cmdErr := socksReply(b[1]); cmdErr != socksStatusSucceeded { - return nil, errors.New("unknown error " + cmdErr.String()) - } - if b[2] != 0 { - return nil, errors.New("non-zero reserved field") - } - l := 2 - var a socksAddr - switch b[3] { - case socksAddrTypeIPv4: - l += net.IPv4len - a.IP = make(net.IP, net.IPv4len) - case socksAddrTypeIPv6: - l += net.IPv6len - a.IP = make(net.IP, net.IPv6len) - case socksAddrTypeFQDN: - if _, err := io.ReadFull(c, b[:1]); err != nil { - return nil, err - } - l += int(b[0]) - default: - return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3]))) - } - if cap(b) < l { - b = make([]byte, l) - } else { - b = b[:l] - } - if _, ctxErr = io.ReadFull(c, b); ctxErr != nil { - return - } - if a.IP != nil { - copy(a.IP, b) - } else { - a.Name = string(b[:len(b)-2]) - } - a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1]) - return &a, nil -} - -func sockssplitHostPort(address string) (string, int, error) { - host, port, err := net.SplitHostPort(address) - if err != nil { - return "", 0, err - } - portnum, err := strconv.Atoi(port) - if err != nil { - return "", 0, err - } - if 1 > portnum || portnum > 0xffff { - return "", 0, errors.New("port number out of range " + port) - } - return host, portnum, nil -} - -// A Command represents a SOCKS command. -type socksCommand int - -func (cmd socksCommand) String() string { - switch cmd { - case socksCmdConnect: - return "socks connect" - case sockscmdBind: - return "socks bind" - default: - return "socks " + strconv.Itoa(int(cmd)) - } -} - -// An AuthMethod represents a SOCKS authentication method. -type socksAuthMethod int - -// A Reply represents a SOCKS command reply code. -type socksReply int - -func (code socksReply) String() string { - switch code { - case socksStatusSucceeded: - return "succeeded" - case 0x01: - return "general SOCKS server failure" - case 0x02: - return "connection not allowed by ruleset" - case 0x03: - return "network unreachable" - case 0x04: - return "host unreachable" - case 0x05: - return "connection refused" - case 0x06: - return "TTL expired" - case 0x07: - return "command not supported" - case 0x08: - return "address type not supported" - default: - return "unknown code: " + strconv.Itoa(int(code)) - } -} - -// Wire protocol constants. -const ( - socksVersion5 = 0x05 - - socksAddrTypeIPv4 = 0x01 - socksAddrTypeFQDN = 0x03 - socksAddrTypeIPv6 = 0x04 - - socksCmdConnect socksCommand = 0x01 // establishes an active-open forward proxy connection - sockscmdBind socksCommand = 0x02 // establishes a passive-open forward proxy connection - - socksAuthMethodNotRequired socksAuthMethod = 0x00 // no authentication required - socksAuthMethodUsernamePassword socksAuthMethod = 0x02 // use username/password - socksAuthMethodNoAcceptableMethods socksAuthMethod = 0xff // no acceptable authentication methods - - socksStatusSucceeded socksReply = 0x00 -) - -// An Addr represents a SOCKS-specific address. -// Either Name or IP is used exclusively. -type socksAddr struct { - Name string // fully-qualified domain name - IP net.IP - Port int -} - -func (a *socksAddr) Network() string { return "socks" } - -func (a *socksAddr) String() string { - if a == nil { - return "" - } - port := strconv.Itoa(a.Port) - if a.IP == nil { - return net.JoinHostPort(a.Name, port) - } - return net.JoinHostPort(a.IP.String(), port) -} - -// A Conn represents a forward proxy connection. -type socksConn struct { - net.Conn - - boundAddr net.Addr -} - -// BoundAddr returns the address assigned by the proxy server for -// connecting to the command target address from the proxy server. -func (c *socksConn) BoundAddr() net.Addr { - if c == nil { - return nil - } - return c.boundAddr -} - -// A Dialer holds SOCKS-specific options. -type socksDialer struct { - cmd socksCommand // either CmdConnect or cmdBind - proxyNetwork string // network between a proxy server and a client - proxyAddress string // proxy server address - - // ProxyDial specifies the optional dial function for - // establishing the transport connection. - ProxyDial func(context.Context, string, string) (net.Conn, error) - - // AuthMethods specifies the list of request authentication - // methods. - // If empty, SOCKS client requests only AuthMethodNotRequired. - AuthMethods []socksAuthMethod - - // Authenticate specifies the optional authentication - // function. It must be non-nil when AuthMethods is not empty. - // It must return an error when the authentication is failed. - Authenticate func(context.Context, io.ReadWriter, socksAuthMethod) error -} - -// DialContext connects to the provided address on the provided -// network. -// -// The returned error value may be a net.OpError. When the Op field of -// net.OpError contains "socks", the Source field contains a proxy -// server address and the Addr field contains a command target -// address. -// -// See func Dial of the net package of standard library for a -// description of the network and address parameters. -func (d *socksDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - if err := d.validateTarget(network, address); err != nil { - proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} - } - if ctx == nil { - proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} - } - var err error - var c net.Conn - if d.ProxyDial != nil { - c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress) - } else { - var dd net.Dialer - c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress) - } - if err != nil { - proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} - } - a, err := d.connect(ctx, c, address) - if err != nil { - c.Close() - proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} - } - return &socksConn{Conn: c, boundAddr: a}, nil -} - -// DialWithConn initiates a connection from SOCKS server to the target -// network and address using the connection c that is already -// connected to the SOCKS server. -// -// It returns the connection's local address assigned by the SOCKS -// server. -func (d *socksDialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) { - if err := d.validateTarget(network, address); err != nil { - proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} - } - if ctx == nil { - proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} - } - a, err := d.connect(ctx, c, address) - if err != nil { - proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} - } - return a, nil -} - -// Dial connects to the provided address on the provided network. -// -// Unlike DialContext, it returns a raw transport connection instead -// of a forward proxy connection. -// -// Deprecated: Use DialContext or DialWithConn instead. -func (d *socksDialer) Dial(network, address string) (net.Conn, error) { - if err := d.validateTarget(network, address); err != nil { - proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} - } - var err error - var c net.Conn - if d.ProxyDial != nil { - c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress) - } else { - c, err = net.Dial(d.proxyNetwork, d.proxyAddress) - } - if err != nil { - proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} - } - if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil { - c.Close() - return nil, err - } - return c, nil -} - -func (d *socksDialer) validateTarget(network, address string) error { - switch network { - case "tcp", "tcp6", "tcp4": - default: - return errors.New("network not implemented") - } - switch d.cmd { - case socksCmdConnect, sockscmdBind: - default: - return errors.New("command not implemented") - } - return nil -} - -func (d *socksDialer) pathAddrs(address string) (proxy, dst net.Addr, err error) { - for i, s := range []string{d.proxyAddress, address} { - host, port, err := sockssplitHostPort(s) - if err != nil { - return nil, nil, err - } - a := &socksAddr{Port: port} - a.IP = net.ParseIP(host) - if a.IP == nil { - a.Name = host - } - if i == 0 { - proxy = a - } else { - dst = a - } - } - return -} - -// NewDialer returns a new Dialer that dials through the provided -// proxy server's network and address. -func socksNewDialer(network, address string) *socksDialer { - return &socksDialer{proxyNetwork: network, proxyAddress: address, cmd: socksCmdConnect} -} - -const ( - socksauthUsernamePasswordVersion = 0x01 - socksauthStatusSucceeded = 0x00 -) - -// UsernamePassword are the credentials for the username/password -// authentication method. -type socksUsernamePassword struct { - Username string - Password string -} - -// Authenticate authenticates a pair of username and password with the -// proxy server. -func (up *socksUsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth socksAuthMethod) error { - switch auth { - case socksAuthMethodNotRequired: - return nil - case socksAuthMethodUsernamePassword: - if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) == 0 || len(up.Password) > 255 { - return errors.New("invalid username/password") - } - b := []byte{socksauthUsernamePasswordVersion} - b = append(b, byte(len(up.Username))) - b = append(b, up.Username...) - b = append(b, byte(len(up.Password))) - b = append(b, up.Password...) - // TODO(mikio): handle IO deadlines and cancelation if - // necessary - if _, err := rw.Write(b); err != nil { - return err - } - if _, err := io.ReadFull(rw, b[:2]); err != nil { - return err - } - if b[0] != socksauthUsernamePasswordVersion { - return errors.New("invalid username/password version") - } - if b[1] != socksauthStatusSucceeded { - return errors.New("username/password authentication failed") - } - return nil - } - return errors.New("unsupported authentication method " + strconv.Itoa(int(auth))) -} From cd718021e47db5c26fe18c67120e8afa25bf70c7 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 14 Feb 2022 18:01:50 +0800 Subject: [PATCH 343/843] add http_test.go --- http_test.go | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 http_test.go diff --git a/http_test.go b/http_test.go new file mode 100644 index 00000000..f432e859 --- /dev/null +++ b/http_test.go @@ -0,0 +1,32 @@ +package req + +import "testing" + +func TestCleanHost(t *testing.T) { + tests := []struct { + in, want string + }{ + {"www.google.com", "www.google.com"}, + {"www.google.com foo", "www.google.com"}, + {"www.google.com/foo", "www.google.com"}, + {" first character is a space", ""}, + {"[1::6]:8080", "[1::6]:8080"}, + + // Punycode: + {"гофер.рф/foo", "xn--c1ae0ajs.xn--p1ai"}, + {"bücher.de", "xn--bcher-kva.de"}, + {"bücher.de:8080", "xn--bcher-kva.de:8080"}, + // Verify we convert to lowercase before punycode: + {"BÜCHER.de", "xn--bcher-kva.de"}, + {"BÜCHER.de:8080", "xn--bcher-kva.de:8080"}, + // Verify we normalize to NFC before punycode: + {"gophér.nfc", "xn--gophr-esa.nfc"}, // NFC input; no work needed + {"goph\u0065\u0301r.nfd", "xn--gophr-esa.nfd"}, // NFD input + } + for _, tt := range tests { + got := cleanHost(tt.in) + if tt.want != got { + t.Errorf("cleanHost(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} From 7b6780f67c250cba550335261b58c296afd84a40 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 14 Feb 2022 18:26:33 +0800 Subject: [PATCH 344/843] add TestFixPragmaCache --- req_test.go | 2 ++ request_test.go | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/req_test.go b/req_test.go index c2145986..c4a9d8a4 100644 --- a/req_test.go +++ b/req_test.go @@ -165,6 +165,8 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) case "/host-header": w.Write([]byte(r.Host)) + case "/pragma": + w.Header().Add("Pragma", "no-cache") case "/header": b, _ := json.Marshal(r.Header) w.Header().Set(hdrContentTypeKey, jsonContentType) diff --git a/request_test.go b/request_test.go index e63b54e9..32fce396 100644 --- a/request_test.go +++ b/request_test.go @@ -615,3 +615,9 @@ func TestUploadMultipart(t *testing.T) { assertContains(t, resp.String(), "value1", true) assertContains(t, resp.String(), "value2", true) } + +func TestFixPragmaCache(t *testing.T) { + resp, err := tc().EnableForceHTTP1().R().Get("/pragma") + assertSuccess(t, resp, err) + assertEqual(t, "no-cache", resp.Header.Get("Cache-Control")) +} From 7ecdae9c0ab25048ef073a99012739eb17129f70 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 14 Feb 2022 20:56:53 +0800 Subject: [PATCH 345/843] merge http_response.go into transport.go --- http_response.go | 93 ------------------------------------------------ transport.go | 70 ++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 93 deletions(-) delete mode 100644 http_response.go diff --git a/http_response.go b/http_response.go deleted file mode 100644 index 83ee1ab1..00000000 --- a/http_response.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// HTTP Response reading and parsing. - -package req - -import ( - "github.com/imroc/req/v3/internal/util" - "io" - "net/http" - "strconv" - "strings" -) - -var respExcludeHeader = map[string]bool{ - "Content-Length": true, - "Transfer-Encoding": true, - "Trailer": true, -} - -// ReadResponse reads and returns an HTTP response from r. -// The req parameter optionally specifies the Request that corresponds -// to this Response. If nil, a GET request is assumed. -// Clients must call resp.Body.Close when finished reading resp.Body. -// After that call, clients can inspect resp.Trailer to find key/value -// pairs included in the response trailer. -func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) { - dumps := getDumpers(pc.t.dump, req.Context()) - tp := newTextprotoReader(pc.br, dumps) - resp := &http.Response{ - Request: req, - } - - // Parse the first line of the response. - line, err := tp.ReadLine() - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return nil, err - } - proto, status, ok := util.CutString(line, " ") - if !ok { - return nil, badStringError("malformed HTTP response", line) - } - resp.Proto = proto - resp.Status = strings.TrimLeft(status, " ") - - statusCode, _, _ := util.CutString(resp.Status, " ") - if len(statusCode) != 3 { - return nil, badStringError("malformed HTTP status code", statusCode) - } - resp.StatusCode, err = strconv.Atoi(statusCode) - if err != nil || resp.StatusCode < 0 { - return nil, badStringError("malformed HTTP status code", statusCode) - } - if resp.ProtoMajor, resp.ProtoMinor, ok = http.ParseHTTPVersion(resp.Proto); !ok { - return nil, badStringError("malformed HTTP version", resp.Proto) - } - - // Parse the response headers. - mimeHeader, err := tp.ReadMIMEHeader() - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return nil, err - } - resp.Header = http.Header(mimeHeader) - - fixPragmaCacheControl(resp.Header) - - err = readTransfer(resp, pc.br) - if err != nil { - return nil, err - } - - return resp, nil -} - -// RFC 7234, section 5.4: Should treat -// Pragma: no-cache -// like -// Cache-Control: no-cache -func fixPragmaCacheControl(header http.Header) { - if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { - if _, presentcc := header["Cache-Control"]; !presentcc { - header["Cache-Control"] = []string{"no-cache"} - } - } -} diff --git a/transport.go b/transport.go index f6c830f0..d247487d 100644 --- a/transport.go +++ b/transport.go @@ -30,6 +30,7 @@ import ( "net/http/httptrace" "net/textproto" "net/url" + "strconv" "strings" "sync" "sync/atomic" @@ -1874,6 +1875,75 @@ type persistConn struct { mutateHeaderFunc func(http.Header) } +// RFC 7234, section 5.4: Should treat +// Pragma: no-cache +// like +// Cache-Control: no-cache +func fixPragmaCacheControl(header http.Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { + if _, presentcc := header["Cache-Control"]; !presentcc { + header["Cache-Control"] = []string{"no-cache"} + } + } +} + +// readResponse reads an HTTP response (or two, in the case of "Expect: +// 100-continue") from the server. It returns the final non-100 one. +// trace is optional. +func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) { + dumps := getDumpers(pc.t.dump, req.Context()) + tp := newTextprotoReader(pc.br, dumps) + resp := &http.Response{ + Request: req, + } + + // Parse the first line of the response. + line, err := tp.ReadLine() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + proto, status, ok := util.CutString(line, " ") + if !ok { + return nil, badStringError("malformed HTTP response", line) + } + resp.Proto = proto + resp.Status = strings.TrimLeft(status, " ") + + statusCode, _, _ := util.CutString(resp.Status, " ") + if len(statusCode) != 3 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + resp.StatusCode, err = strconv.Atoi(statusCode) + if err != nil || resp.StatusCode < 0 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + if resp.ProtoMajor, resp.ProtoMinor, ok = http.ParseHTTPVersion(resp.Proto); !ok { + return nil, badStringError("malformed HTTP version", resp.Proto) + } + + // Parse the response headers. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + resp.Header = http.Header(mimeHeader) + + fixPragmaCacheControl(resp.Header) + + err = readTransfer(resp, pc.br) + if err != nil { + return nil, err + } + + return resp, nil +} + func (pc *persistConn) maxHeaderResponseSize() int64 { if v := pc.t.MaxResponseHeaderBytes; v != 0 { return v From 3584f9126b1ed8542f3aa4061417fe99f06ff153 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 14 Feb 2022 21:33:09 +0800 Subject: [PATCH 346/843] merge requestWrite into persistConn.writeRequest --- http_request.go | 178 ----------------------------------------------- transport.go | 179 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 176 insertions(+), 181 deletions(-) diff --git a/http_request.go b/http_request.go index 7fad63da..c94add39 100644 --- a/http_request.go +++ b/http_request.go @@ -1,14 +1,10 @@ package req import ( - "bufio" "errors" - "fmt" "github.com/imroc/req/v3/internal/ascii" "golang.org/x/net/http/httpguts" - "io" "net/http" - "net/http/httptrace" "strings" ) @@ -82,180 +78,6 @@ func outgoingLength(r *http.Request) int64 { // the Request. var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") -// extraHeaders may be nil -// waitForContinue may be nil -// always closes body -func requestWrite(r *http.Request, w io.Writer, usingProxy bool, extraHeaders http.Header, waitForContinue func() bool, dumps []*dumper) (err error) { - trace := httptrace.ContextClientTrace(r.Context()) - if trace != nil && trace.WroteRequest != nil { - defer func() { - trace.WroteRequest(httptrace.WroteRequestInfo{ - Err: err, - }) - }() - } - closed := false - defer func() { - if closed { - return - } - if closeErr := closeRequestBody(r); closeErr != nil && err == nil { - err = closeErr - } - }() - - // Find the target host. Prefer the Host: header, but if that - // is not given, use the host from the request URL. - // - // Clean the host, in case it arrives with unexpected stuff in it. - host := cleanHost(r.Host) - if host == "" { - if r.URL == nil { - return errMissingHost - } - host = cleanHost(r.URL.Host) - } - - // According to RFC 6874, an HTTP client, proxy, or other - // intermediary must remove any IPv6 zone identifier attached - // to an outgoing URI. - host = removeZone(host) - - ruri := r.URL.RequestURI() - if usingProxy && r.URL.Scheme != "" && r.URL.Opaque == "" { - ruri = r.URL.Scheme + "://" + host + ruri - } else if r.Method == "CONNECT" && r.URL.Path == "" { - // CONNECT requests normally give just the host and port, not a full URL. - ruri = host - if r.URL.Opaque != "" { - ruri = r.URL.Opaque - } - } - if stringContainsCTLByte(ruri) { - return errors.New("net/http: can't write control character in Request.URL") - } - // TODO: validate r.Method too? At least it's less likely to - // come from an attacker (more likely to be a constant in - // code). - - // Wrap the writer in a bufio Writer if it's not already buffered. - // Don't always call NewWriter, as that forces a bytes.Buffer - // and other small bufio Writers to have a minimum 4k buffer - // size. - var bw *bufio.Writer - if _, ok := w.(io.ByteWriter); !ok { - bw = bufio.NewWriter(w) - w = bw - } - - rw := w // raw writer - for _, dump := range dumps { - if dump.RequestHeader { - w = dump.WrapWriter(w) - } - } - - _, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(r.Method, "GET"), ruri) - if err != nil { - return err - } - - // Header lines - _, err = fmt.Fprintf(w, "Host: %s\r\n", host) - if err != nil { - return err - } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Host", []string{host}) - } - - // Use the defaultUserAgent unless the Header contains one, which - // may be blank to not send the header. - userAgent := hdrUserAgentValue - if headerHas(r.Header, "User-Agent") { - userAgent = r.Header.Get("User-Agent") - } - if userAgent != "" { - _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) - if err != nil { - return err - } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("User-Agent", []string{userAgent}) - } - } - - // Process Body,ContentLength,Close,Trailer - tw, err := newTransferWriter(r) - if err != nil { - return err - } - err = tw.writeHeader(w, trace) - if err != nil { - return err - } - - err = headerWriteSubset(r.Header, w, reqWriteExcludeHeader, trace) - if err != nil { - return err - } - - if extraHeaders != nil { - err = headerWrite(extraHeaders, w, trace) - if err != nil { - return err - } - } - - _, err = io.WriteString(w, "\r\n") - if err != nil { - return err - } - - if trace != nil && trace.WroteHeaders != nil { - trace.WroteHeaders() - } - - // Flush and wait for 100-continue if expected. - if waitForContinue != nil { - if bw, ok := w.(*bufio.Writer); ok { - err = bw.Flush() - if err != nil { - return err - } - } - if trace != nil && trace.Wait100Continue != nil { - trace.Wait100Continue() - } - if !waitForContinue() { - closed = true - closeRequestBody(r) - return nil - } - } - - if bw, ok := w.(*bufio.Writer); ok && tw.FlushHeaders { - if err := bw.Flush(); err != nil { - return err - } - } - - // Write body and trailer - closed = true - err = tw.writeBody(rw, dumps) - if err != nil { - if tw.bodyReadError == err { - err = requestBodyReadError{err} - } - return err - } - - if bw != nil { - return bw.Flush() - } - return nil -} - func closeRequestBody(r *http.Request) error { if r.Body == nil { return nil diff --git a/transport.go b/transport.go index d247487d..8839600a 100644 --- a/transport.go +++ b/transport.go @@ -2422,9 +2422,7 @@ func (pc *persistConn) writeLoop() { select { case wr := <-pc.writech: startBytesWritten := pc.nwrite - ctx := wr.req.Request.Context() - dumps := getDumpers(pc.t.dump, ctx) - err := requestWrite(wr.req.Request, pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh), dumps) + err := pc.writeRequest(wr.req.Request, pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh)) if bre, ok := err.(requestBodyReadError); ok { err = bre.error // Errors reading from the user's @@ -2456,6 +2454,181 @@ func (pc *persistConn) writeLoop() { } } +// extraHeaders may be nil +// waitForContinue may be nil +// always closes body +func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy bool, extraHeaders http.Header, waitForContinue func() bool) (err error) { + trace := httptrace.ContextClientTrace(r.Context()) + if trace != nil && trace.WroteRequest != nil { + defer func() { + trace.WroteRequest(httptrace.WroteRequestInfo{ + Err: err, + }) + }() + } + closed := false + defer func() { + if closed { + return + } + if closeErr := closeRequestBody(r); closeErr != nil && err == nil { + err = closeErr + } + }() + + // Find the target host. Prefer the Host: header, but if that + // is not given, use the host from the request URL. + // + // Clean the host, in case it arrives with unexpected stuff in it. + host := cleanHost(r.Host) + if host == "" { + if r.URL == nil { + return errMissingHost + } + host = cleanHost(r.URL.Host) + } + + // According to RFC 6874, an HTTP client, proxy, or other + // intermediary must remove any IPv6 zone identifier attached + // to an outgoing URI. + host = removeZone(host) + + ruri := r.URL.RequestURI() + if usingProxy && r.URL.Scheme != "" && r.URL.Opaque == "" { + ruri = r.URL.Scheme + "://" + host + ruri + } else if r.Method == "CONNECT" && r.URL.Path == "" { + // CONNECT requests normally give just the host and port, not a full URL. + ruri = host + if r.URL.Opaque != "" { + ruri = r.URL.Opaque + } + } + if stringContainsCTLByte(ruri) { + return errors.New("net/http: can't write control character in Request.URL") + } + // TODO: validate r.Method too? At least it's less likely to + // come from an attacker (more likely to be a constant in + // code). + + // Wrap the writer in a bufio Writer if it's not already buffered. + // Don't always call NewWriter, as that forces a bytes.Buffer + // and other small bufio Writers to have a minimum 4k buffer + // size. + var bw *bufio.Writer + if _, ok := w.(io.ByteWriter); !ok { + bw = bufio.NewWriter(w) + w = bw + } + + rw := w // raw writer + dumps := getDumpers(pc.t.dump, r.Context()) + for _, dump := range dumps { + if dump.RequestHeader { + w = dump.WrapWriter(w) + } + } + + _, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(r.Method, "GET"), ruri) + if err != nil { + return err + } + + // Header lines + _, err = fmt.Fprintf(w, "Host: %s\r\n", host) + if err != nil { + return err + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("Host", []string{host}) + } + + // Use the defaultUserAgent unless the Header contains one, which + // may be blank to not send the header. + userAgent := hdrUserAgentValue + if headerHas(r.Header, "User-Agent") { + userAgent = r.Header.Get("User-Agent") + } + if userAgent != "" { + _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) + if err != nil { + return err + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField("User-Agent", []string{userAgent}) + } + } + + // Process Body,ContentLength,Close,Trailer + tw, err := newTransferWriter(r) + if err != nil { + return err + } + err = tw.writeHeader(w, trace) + if err != nil { + return err + } + + err = headerWriteSubset(r.Header, w, reqWriteExcludeHeader, trace) + if err != nil { + return err + } + + if extraHeaders != nil { + err = headerWrite(extraHeaders, w, trace) + if err != nil { + return err + } + } + + _, err = io.WriteString(w, "\r\n") + if err != nil { + return err + } + + if trace != nil && trace.WroteHeaders != nil { + trace.WroteHeaders() + } + + // Flush and wait for 100-continue if expected. + if waitForContinue != nil { + if bw, ok := w.(*bufio.Writer); ok { + err = bw.Flush() + if err != nil { + return err + } + } + if trace != nil && trace.Wait100Continue != nil { + trace.Wait100Continue() + } + if !waitForContinue() { + closed = true + closeRequestBody(r) + return nil + } + } + + if bw, ok := w.(*bufio.Writer); ok && tw.FlushHeaders { + if err := bw.Flush(); err != nil { + return err + } + } + + // Write body and trailer + closed = true + err = tw.writeBody(rw, dumps) + if err != nil { + if tw.bodyReadError == err { + err = requestBodyReadError{err} + } + return err + } + + if bw != nil { + return bw.Flush() + } + return nil +} + // maxWriteWaitBeforeConnReuse is how long the a Transport RoundTrip // will wait to see the Request's Body.Write result after getting a // response from the server. See comments in (*persistConn).wroteRequest. From 2f438504d72187b8218dcd4e5be82ec4c23bb8b8 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 15 Feb 2022 10:56:27 +0800 Subject: [PATCH 347/843] move clone functions into req.go --- client.go | 65 ----------------------------------------------------- req.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 65 deletions(-) diff --git a/client.go b/client.go index 80fa2ff0..879e7a97 100644 --- a/client.go +++ b/client.go @@ -63,71 +63,6 @@ type Client struct { afterResponse []ResponseMiddleware } -func cloneCookies(cookies []*http.Cookie) []*http.Cookie { - if len(cookies) == 0 { - return nil - } - c := make([]*http.Cookie, len(cookies)) - copy(c, cookies) - return c -} - -func cloneHeaders(hdrs http.Header) http.Header { - if hdrs == nil { - return nil - } - h := make(http.Header) - for k, vs := range hdrs { - for _, v := range vs { - h.Add(k, v) - } - } - return h -} - -// TODO: change to generics function when generics are commonly used. -func cloneRequestMiddleware(m []RequestMiddleware) []RequestMiddleware { - if len(m) == 0 { - return nil - } - mm := make([]RequestMiddleware, len(m)) - copy(mm, m) - return mm -} - -func cloneResponseMiddleware(m []ResponseMiddleware) []ResponseMiddleware { - if len(m) == 0 { - return nil - } - mm := make([]ResponseMiddleware, len(m)) - copy(mm, m) - return mm -} - -func cloneUrlValues(v urlpkg.Values) urlpkg.Values { - if v == nil { - return nil - } - vv := make(urlpkg.Values) - for key, values := range v { - for _, value := range values { - vv.Add(key, value) - } - } - return vv -} - -func cloneMap(h map[string]string) map[string]string { - if h == nil { - return nil - } - m := make(map[string]string) - for k, v := range h { - m[k] = v - } - return m -} - // R is a global wrapper methods which delegated // to the default client's R(). func R() *Request { diff --git a/req.go b/req.go index 269d6698..ef6f7a70 100644 --- a/req.go +++ b/req.go @@ -3,6 +3,8 @@ package req import ( "fmt" "io" + "net/http" + "net/url" ) const ( @@ -58,3 +60,68 @@ type FileUpload struct { // "name" and "filename". ExtraContentDisposition *ContentDisposition } + +func cloneCookies(cookies []*http.Cookie) []*http.Cookie { + if len(cookies) == 0 { + return nil + } + c := make([]*http.Cookie, len(cookies)) + copy(c, cookies) + return c +} + +func cloneHeaders(hdrs http.Header) http.Header { + if hdrs == nil { + return nil + } + h := make(http.Header) + for k, vs := range hdrs { + for _, v := range vs { + h.Add(k, v) + } + } + return h +} + +// TODO: change to generics function when generics are commonly used. +func cloneRequestMiddleware(m []RequestMiddleware) []RequestMiddleware { + if len(m) == 0 { + return nil + } + mm := make([]RequestMiddleware, len(m)) + copy(mm, m) + return mm +} + +func cloneResponseMiddleware(m []ResponseMiddleware) []ResponseMiddleware { + if len(m) == 0 { + return nil + } + mm := make([]ResponseMiddleware, len(m)) + copy(mm, m) + return mm +} + +func cloneUrlValues(v url.Values) url.Values { + if v == nil { + return nil + } + vv := make(url.Values) + for key, values := range v { + for _, value := range values { + vv.Add(key, value) + } + } + return vv +} + +func cloneMap(h map[string]string) map[string]string { + if h == nil { + return nil + } + m := make(map[string]string) + for k, v := range h { + m[k] = v + } + return m +} From 3e2cca9fe0f0908a1951e9edcb1b7cb59ffc2f7e Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 15 Feb 2022 11:41:42 +0800 Subject: [PATCH 348/843] add TestClientClone and Client.SetCommonPathParam --- client.go | 34 ++++++++++++++++++++++++++++++++++ client_test.go | 19 +++++++++++++++++++ req_test.go | 2 +- transport.go | 2 +- 4 files changed, 55 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 879e7a97..e6298f13 100644 --- a/client.go +++ b/client.go @@ -409,6 +409,40 @@ func (c *Client) AddCommonQueryParam(key, value string) *Client { return c } +func (c *Client) pathParams() map[string]string { + if c.PathParams == nil { + c.PathParams = make(map[string]string) + } + return c.PathParams +} + +// SetCommonPathParam is a global wrapper methods which delegated +// to the default client's SetCommonPathParam. +func SetCommonPathParam(key, value string) *Client { + return defaultClient.SetCommonPathParam(key, value) +} + +// SetCommonPathParam set a path parameter for all requests. +func (c *Client) SetCommonPathParam(key, value string) *Client { + c.pathParams()[key] = value + return c +} + +// SetCommonPathParams is a global wrapper methods which delegated +// to the default client's SetCommonPathParams. +func SetCommonPathParams(pathParams map[string]string) *Client { + return defaultClient.SetCommonPathParams(pathParams) +} + +// SetCommonPathParams set path parameters for all requests. +func (c *Client) SetCommonPathParams(pathParams map[string]string) *Client { + m := c.pathParams() + for k, v := range pathParams { + m[k] = v + } + return c +} + // SetCommonQueryParam is a global wrapper methods which delegated // to the default client's SetCommonQueryParam. func SetCommonQueryParam(key, value string) *Client { diff --git a/client_test.go b/client_test.go index 5b6fae92..aea4286e 100644 --- a/client_test.go +++ b/client_test.go @@ -3,9 +3,28 @@ package req import ( "bytes" "io/ioutil" + "net/http" "testing" ) +func TestClientClone(t *testing.T) { + c1 := tc().DevMode(). + SetCommonHeader("test", "test"). + SetCommonCookies(&http.Cookie{ + Name: "test", + Value: "test", + }).SetCommonQueryParam("test", "test"). + SetCommonPathParam("test", "test") + + c2 := c1.Clone() + assertEqual(t, c1.Headers, c2.Headers) + assertEqual(t, c1.Cookies, c2.Cookies) + assertEqual(t, c1.BaseURL, c2.BaseURL) + assertEqual(t, c1.DebugLog, c2.DebugLog) + assertEqual(t, c1.PathParams, c2.PathParams) + assertEqual(t, c1.QueryParams, c2.QueryParams) +} + func TestDisableAutoReadResponse(t *testing.T) { testDisableAutoReadResponse(t, tc()) testDisableAutoReadResponse(t, tc().EnableForceHTTP1()) diff --git a/req_test.go b/req_test.go index c4a9d8a4..ae027a08 100644 --- a/req_test.go +++ b/req_test.go @@ -254,7 +254,7 @@ func assertError(t *testing.T, err error) { func assertEqual(t *testing.T, e, g interface{}) { if !equal(e, g) { - t.Errorf("Expected [%v], got [%v]", e, g) + t.Errorf("Expected [%+v], got [%+v]", e, g) } return } diff --git a/transport.go b/transport.go index 8839600a..f761db61 100644 --- a/transport.go +++ b/transport.go @@ -381,7 +381,7 @@ func (t *Transport) Clone() *Transport { dump: t.dump.Clone(), } if t.dump != nil { - t.dump.Start() + go t.dump.Start() } if t.TLSClientConfig != nil { t2.TLSClientConfig = t.TLSClientConfig.Clone() From c7c75688e70af29fe559e3ba806e012692e88d97 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 15 Feb 2022 14:37:03 +0800 Subject: [PATCH 349/843] add assertEqualStruct --- client_test.go | 7 +---- req_test.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/client_test.go b/client_test.go index aea4286e..468379bf 100644 --- a/client_test.go +++ b/client_test.go @@ -17,12 +17,7 @@ func TestClientClone(t *testing.T) { SetCommonPathParam("test", "test") c2 := c1.Clone() - assertEqual(t, c1.Headers, c2.Headers) - assertEqual(t, c1.Cookies, c2.Cookies) - assertEqual(t, c1.BaseURL, c2.BaseURL) - assertEqual(t, c1.DebugLog, c2.DebugLog) - assertEqual(t, c1.PathParams, c2.PathParams) - assertEqual(t, c1.QueryParams, c2.QueryParams) + assertEqualStruct(t, c1, c2, false, "t", "t2", "httpClient") } func TestDisableAutoReadResponse(t *testing.T) { diff --git a/req_test.go b/req_test.go index ae027a08..04520b9d 100644 --- a/req_test.go +++ b/req_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "encoding/xml" "fmt" + "go/token" "io" "io/ioutil" "net/http" @@ -14,6 +15,7 @@ import ( "strings" "sync" "testing" + "unsafe" ) func tc() *Client { @@ -252,6 +254,49 @@ func assertError(t *testing.T, err error) { } } +func assertEqualStruct(t *testing.T, e, g interface{}, onlyExported bool, excludes ...string) { + ev := reflect.ValueOf(e).Elem() + gv := reflect.ValueOf(g).Elem() + et := ev.Type() + gt := gv.Type() + m := map[string]bool{} + for _, exclude := range excludes { + m[exclude] = true + } + if et.Kind() != gt.Kind() { + t.Fatalf("Expected kind [%s], got [%s]", et.Kind().String(), gt.Kind().String()) + } + if et.Name() != gt.Name() { + t.Fatalf("Expected type [%s], got [%s]", et.Name(), gt.Name()) + } + + for i := 0; i < ev.NumField(); i++ { + sf := ev.Field(i) + if sf.Kind() == reflect.Func || sf.Kind() == reflect.Slice { + continue + } + st := et.Field(i) + if m[st.Name] { + continue + } + if onlyExported && !token.IsExported(st.Name) { + continue + } + var ee, gg interface{} + if !token.IsExported(st.Name) { + ee = reflect.NewAt(sf.Type(), unsafe.Pointer(sf.UnsafeAddr())).Elem().Interface() + gg = reflect.NewAt(sf.Type(), unsafe.Pointer(gv.Field(i).UnsafeAddr())).Elem().Interface() + } else { + ee = sf.Interface() + gg = gv.Field(i).Interface() + } + if !reflect.DeepEqual(ee, gg) { + t.Errorf("Field %s.%s is not equal, expected [%v], got [%v]", et.Name(), et.Field(i).Name, ee, gg) + } + } + +} + func assertEqual(t *testing.T, e, g interface{}) { if !equal(e, g) { t.Errorf("Expected [%+v], got [%+v]", e, g) @@ -289,3 +334,29 @@ func isNil(v interface{}) bool { } return false } + +// func equalClient(t *testing.T, c1, c2 *Client) bool { +// if !notZero(t, c2) { +// return false +// } +// } + +// func equalTransport(t1, t2 *Transport) bool { +// if !equal(t1.TLSClientConfig, t2.TLSClientConfig) { +// return false +// } +// } + +// func notZero(t *testing.T, v interface{}) bool { +// rv := reflect.ValueOf(v).Elem() +// rt := rv.Type() +// for i := 0; i < rt.NumField(); i++ { +// sf := rt.Field(i) +// if !token.IsExported(sf.Name) { +// continue +// } +// if rv.Field(i).IsZero() { +// t.Errorf("cloned field %s.%s is zero", reflect.TypeOf(v).Name(), sf.Name) +// } +// } +// } From 76250411657490a9f14d7a36012aa9fb439755b2 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 15 Feb 2022 14:39:52 +0800 Subject: [PATCH 350/843] improve assertEqualStruct --- req_test.go | 35 +++++++---------------------------- 1 file changed, 7 insertions(+), 28 deletions(-) diff --git a/req_test.go b/req_test.go index 04520b9d..65a55bdc 100644 --- a/req_test.go +++ b/req_test.go @@ -263,9 +263,14 @@ func assertEqualStruct(t *testing.T, e, g interface{}, onlyExported bool, exclud for _, exclude := range excludes { m[exclude] = true } - if et.Kind() != gt.Kind() { - t.Fatalf("Expected kind [%s], got [%s]", et.Kind().String(), gt.Kind().String()) + if et.Kind() != reflect.Struct { + t.Fatalf("expect object should be struct instead of %v", et.Kind().String()) } + + if gt.Kind() != reflect.Struct { + t.Fatalf("got object should be struct instead of %v", gt.Kind().String()) + } + if et.Name() != gt.Name() { t.Fatalf("Expected type [%s], got [%s]", et.Name(), gt.Name()) } @@ -334,29 +339,3 @@ func isNil(v interface{}) bool { } return false } - -// func equalClient(t *testing.T, c1, c2 *Client) bool { -// if !notZero(t, c2) { -// return false -// } -// } - -// func equalTransport(t1, t2 *Transport) bool { -// if !equal(t1.TLSClientConfig, t2.TLSClientConfig) { -// return false -// } -// } - -// func notZero(t *testing.T, v interface{}) bool { -// rv := reflect.ValueOf(v).Elem() -// rt := rv.Type() -// for i := 0; i < rt.NumField(); i++ { -// sf := rt.Field(i) -// if !token.IsExported(sf.Name) { -// continue -// } -// if rv.Field(i).IsZero() { -// t.Errorf("cloned field %s.%s is zero", reflect.TypeOf(v).Name(), sf.Name) -// } -// } -// } From addfc5e4066371fec476901de23e98b87544c7b7 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 15 Feb 2022 17:21:56 +0800 Subject: [PATCH 351/843] add many client tests and fix EnableAutoDecode --- client.go | 2 +- client_test.go | 242 ++++++++++++++++++++++++++++++++++++++++++++++++ redirect.go | 12 ++- req_test.go | 25 +++++ request_test.go | 15 --- 5 files changed, 275 insertions(+), 21 deletions(-) diff --git a/client.go b/client.go index e6298f13..6d84193d 100644 --- a/client.go +++ b/client.go @@ -857,7 +857,7 @@ func EnableAutoDecode() *Client { // EnableAutoDecode enable auto-detect charset and decode to utf-8 // (enabled by default). func (c *Client) EnableAutoDecode() *Client { - c.getResponseOptions().DisableAutoDecode = true + c.getResponseOptions().DisableAutoDecode = false return c } diff --git a/client_test.go b/client_test.go index 468379bf..9efb7877 100644 --- a/client_test.go +++ b/client_test.go @@ -2,11 +2,232 @@ package req import ( "bytes" + "crypto/tls" "io/ioutil" "net/http" + "net/url" + "os" "testing" + "time" ) +func TestAutoDecode(t *testing.T) { + c := tc().DisableAutoDecode() + resp, err := c.R().Get("/gbk") + assertSuccess(t, resp, err) + assertEqual(t, toGbk("我是roc"), resp.Bytes()) + + resp, err = c.EnableAutoDecode().R().Get("/gbk") + assertSuccess(t, resp, err) + assertEqual(t, "我是roc", resp.String()) +} + +func TestSetTimeout(t *testing.T) { + timeout := 100 * time.Second + c := tc().SetTimeout(timeout) + assertEqual(t, timeout, c.httpClient.Timeout) +} + +func TestSetLogger(t *testing.T) { + l := createDefaultLogger() + c := tc().SetLogger(l) + assertEqual(t, l, c.log) + + c.SetLogger(nil) + assertEqual(t, &disableLogger{}, c.log) +} + +func TestSetScheme(t *testing.T) { + c := tc().SetScheme("https") + assertEqual(t, "https", c.scheme) +} + +func TestDebugLog(t *testing.T) { + c := tc().EnableDebugLog() + assertEqual(t, true, c.DebugLog) + + c.DisableDebugLog() + assertEqual(t, false, c.DebugLog) +} + +func TestSetCommonCookies(t *testing.T) { + headers := make(http.Header) + resp, err := tc().SetCommonCookies(&http.Cookie{ + Name: "test", + Value: "test", + }).R().SetResult(&headers).Get("/header") + assertSuccess(t, resp, err) + assertEqual(t, "test=test", headers.Get("Cookie")) +} + +func TestSetCommonQueryString(t *testing.T) { + resp, err := tc().SetCommonQueryString("test=test").R().Get("/query-parameter") + assertSuccess(t, resp, err) + assertEqual(t, "test=test", resp.String()) +} + +func TestSetCommonPathParams(t *testing.T) { + c := tc().SetCommonPathParams(map[string]string{"test": "test"}) + assertNotNil(t, c.PathParams) + assertEqual(t, "test", c.PathParams["test"]) +} + +func TestSetCommonPathParam(t *testing.T) { + c := tc().SetCommonPathParam("test", "test") + assertNotNil(t, c.PathParams) + assertEqual(t, "test", c.PathParams["test"]) +} + +func TestAddCommonQueryParam(t *testing.T) { + resp, err := tc(). + AddCommonQueryParam("test", "1"). + AddCommonQueryParam("test", "2"). + R().Get("/query-parameter") + assertSuccess(t, resp, err) + assertEqual(t, "test=1&test=2", resp.String()) +} + +func TestSetCommonQueryParam(t *testing.T) { + resp, err := tc().SetCommonQueryParam("test", "test").R().Get("/query-parameter") + assertSuccess(t, resp, err) + assertEqual(t, "test=test", resp.String()) +} + +func TestSetCommonQueryParams(t *testing.T) { + resp, err := tc().SetCommonQueryParams(map[string]string{"test": "test"}).R().Get("/query-parameter") + assertSuccess(t, resp, err) + assertEqual(t, "test=test", resp.String()) +} + +func TestInsecureSkipVerify(t *testing.T) { + c := tc().EnableInsecureSkipVerify() + assertEqual(t, true, c.t.TLSClientConfig.InsecureSkipVerify) + + c.DisableInsecureSkipVerify() + assertEqual(t, false, c.t.TLSClientConfig.InsecureSkipVerify) +} + +func TestSetTLSClientConfig(t *testing.T) { + config := &tls.Config{InsecureSkipVerify: true} + c := tc().SetTLSClientConfig(config) + assertEqual(t, config, c.t.TLSClientConfig) +} + +func TestCompression(t *testing.T) { + c := tc().DisableCompression() + assertEqual(t, true, c.t.DisableCompression) + + c.EnableCompression() + assertEqual(t, false, c.t.DisableCompression) +} + +func TestKeepAlives(t *testing.T) { + c := tc().DisableKeepAlives() + assertEqual(t, true, c.t.DisableKeepAlives) + + c.EnableKeepAlives() + assertEqual(t, false, c.t.DisableKeepAlives) +} + +func TestRedirect(t *testing.T) { + _, err := tc().SetRedirectPolicy(NoRedirectPolicy()).R().Get("/unlimited-redirect") + assertNotNil(t, err) + assertContains(t, err.Error(), "redirect is disabled", true) + + _, err = tc().SetRedirectPolicy(MaxRedirectPolicy(3)).R().Get("/unlimited-redirect") + assertNotNil(t, err) + assertContains(t, err.Error(), "stopped after 3 redirects", true) + + _, err = tc().SetRedirectPolicy(SameDomainRedirectPolicy()).R().Get("/redirect-to-other") + assertNotNil(t, err) + assertContains(t, err.Error(), "different domain name is not allowed", true) + + _, err = tc().SetRedirectPolicy(SameHostRedirectPolicy()).R().Get("/redirect-to-other") + assertNotNil(t, err) + assertContains(t, err.Error(), "different host name is not allowed", true) + + _, err = tc().SetRedirectPolicy(AllowedHostRedirectPolicy("localhost", "127.0.0.1")).R().Get("/redirect-to-other") + assertNotNil(t, err) + assertContains(t, err.Error(), "redirect host [dummy.local] is not allowed", true) + + _, err = tc().SetRedirectPolicy(AllowedDomainRedirectPolicy("localhost", "127.0.0.1")).R().Get("/redirect-to-other") + assertNotNil(t, err) + assertContains(t, err.Error(), "redirect domain [dummy.local] is not allowed", true) +} + +func TestGetTLSClientConfig(t *testing.T) { + c := tc() + config := c.GetTLSClientConfig() + assertEqual(t, true, c.t.TLSClientConfig != nil) + assertEqual(t, config, c.t.TLSClientConfig) +} + +func TestSetRootCertFromFile(t *testing.T) { + c := tc().SetRootCertsFromFile(getTestFilePath("sample-root.pem")) + assertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) +} + +func TestSetRootCertFromString(t *testing.T) { + c := tc().SetRootCertFromString(string(getTestFileContent(t, "sample-root.pem"))) + assertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) +} + +func TestSetCerts(t *testing.T) { + c := tc().SetCerts(tls.Certificate{}, tls.Certificate{}) + assertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 2) +} + +func TestSetCertFromFile(t *testing.T) { + c := tc().SetCertFromFile( + getTestFilePath("sample-client.pem"), + getTestFilePath("sample-client-key.pem"), + ) + assertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 1) +} + +func TestSetOutputDirectory(t *testing.T) { + outFile := "test_output_dir" + resp, err := tc(). + SetOutputDirectory(testDataPath). + R().SetOutputFile(outFile). + Get("/") + assertSuccess(t, resp, err) + content := string(getTestFileContent(t, outFile)) + os.Remove(getTestFilePath(outFile)) + assertEqual(t, "TestGet: text response", content) +} + +func TestSetBaseURL(t *testing.T) { + baseURL := "http://dummy-req.local/test" + resp, _ := tc().SetTimeout(time.Nanosecond).SetBaseURL(baseURL).R().Get("/req") + assertEqual(t, baseURL+"/req", resp.Request.RawRequest.URL.String()) +} + +func TestSetCommonFormDataFromValues(t *testing.T) { + expectedForm := make(url.Values) + gotForm := make(url.Values) + expectedForm.Set("test", "test") + resp, err := tc(). + SetCommonFormDataFromValues(expectedForm). + R().SetResult(&gotForm). + Post("/form") + assertSuccess(t, resp, err) + assertEqual(t, "test", gotForm.Get("test")) +} + +func TestSetCommonFormData(t *testing.T) { + form := make(url.Values) + resp, err := tc(). + SetCommonFormData( + map[string]string{ + "test": "test", + }).R(). + SetResult(&form). + Post("/form") + assertSuccess(t, resp, err) + assertEqual(t, "test", form.Get("test")) +} + func TestClientClone(t *testing.T) { c1 := tc().DevMode(). SetCommonHeader("test", "test"). @@ -133,4 +354,25 @@ func testClientDump(t *testing.T, newClient func() *Client) { assertContains(t, buf.String(), "test body", false) assertContains(t, buf.String(), "date", false) assertContains(t, buf.String(), "testpost: text response", true) + + dumpFile := "tmp_test_dump_file" + c.EnableDumpAllToFile(getTestFilePath(dumpFile)) + resp, err = c.R().SetBody("test body").Post("/") + assertSuccess(t, resp, err) + dump := string(getTestFileContent(t, dumpFile)) + os.Remove(getTestFilePath(dumpFile)) + assertContains(t, dump, "user-agent", true) + assertContains(t, dump, "test body", false) + assertContains(t, dump, "date", false) + assertContains(t, dump, "testpost: text response", true) + + buf = new(bytes.Buffer) + c.EnableDumpAllTo(buf).EnableDumpAllAsync() + resp, err = c.R().SetBody("test body").Post("/") + assertSuccess(t, resp, err) + time.Sleep(10 * time.Millisecond) + assertContains(t, buf.String(), "user-agent", true) + assertContains(t, buf.String(), "test body", false) + assertContains(t, buf.String(), "date", false) + assertContains(t, buf.String(), "testpost: text response", true) } diff --git a/redirect.go b/redirect.go index ac1c2535..9daa3d56 100644 --- a/redirect.go +++ b/redirect.go @@ -23,7 +23,7 @@ func MaxRedirectPolicy(noOfRedirect int) RedirectPolicy { // NoRedirectPolicy disable redirect behaviour func NoRedirectPolicy() RedirectPolicy { return func(req *http.Request, via []*http.Request) error { - return errors.New("auto redirect is disabled") + return errors.New("redirect is disabled") } } @@ -57,8 +57,9 @@ func AllowedHostRedirectPolicy(hosts ...string) RedirectPolicy { } return func(req *http.Request, via []*http.Request) error { - if _, ok := m[getHostname(req.URL.Host)]; !ok { - return errors.New("redirect host is not allowed") + h := getHostname(req.URL.Host) + if _, ok := m[h]; !ok { + return fmt.Errorf("redirect host [%s] is not allowed", h) } return nil } @@ -73,8 +74,9 @@ func AllowedDomainRedirectPolicy(hosts ...string) RedirectPolicy { } return func(req *http.Request, via []*http.Request) error { - if _, ok := domains[getDomain(req.URL.Host)]; !ok { - return errors.New("redirect domain is not allowed") + domain := getDomain(req.URL.Host) + if _, ok := domains[domain]; !ok { + return fmt.Errorf("redirect domain [%s] is not allowed", domain) } return nil } diff --git a/req_test.go b/req_test.go index 65a55bdc..f94ae63d 100644 --- a/req_test.go +++ b/req_test.go @@ -5,6 +5,8 @@ import ( "encoding/xml" "fmt" "go/token" + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" "io" "io/ioutil" "net/http" @@ -81,6 +83,11 @@ func handlePost(w http.ResponseWriter, r *http.Request) { w.Write([]byte("TestPost: text response")) case "/raw-upload": io.Copy(ioutil.Discard, r.Body) + case "/form": + r.ParseForm() + ret, _ := json.Marshal(&r.Form) + w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Write(ret) case "/multipart": r.ParseMultipartForm(10e6) m := make(map[string]interface{}) @@ -159,6 +166,15 @@ func handleSearch(w http.ResponseWriter, r *http.Request) { w.Write(data) } +func toGbk(s string) []byte { + reader := transform.NewReader(strings.NewReader(s), simplifiedchinese.GBK.NewEncoder()) + d, e := ioutil.ReadAll(reader) + if e != nil { + panic(e) + } + return d +} + func handleGet(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": @@ -167,8 +183,17 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) case "/host-header": w.Write([]byte(r.Host)) + case "/unlimited-redirect": + w.Header().Set("Location", "/unlimited-redirect") + w.WriteHeader(http.StatusMovedPermanently) + case "/redirect-to-other": + w.Header().Set("Location", "http://dummy.local/test") + w.WriteHeader(http.StatusMovedPermanently) case "/pragma": w.Header().Add("Pragma", "no-cache") + case "/gbk": + w.Header().Set(hdrContentTypeKey, "text/plain; charset=gbk") + w.Write(toGbk("我是roc")) case "/header": b, _ := json.Marshal(r.Header) w.Header().Set(hdrContentTypeKey, jsonContentType) diff --git a/request_test.go b/request_test.go index 32fce396..6b4d005a 100644 --- a/request_test.go +++ b/request_test.go @@ -577,21 +577,6 @@ func testTraceOnTimeout(t *testing.T, c *Client) { assertEqual(t, true, tr.TotalTime == resp.TotalTime()) } -func TestRedirect(t *testing.T) { - testRedirect(t, tc()) - testRedirect(t, tc().EnableForceHTTP1()) -} - -func testRedirect(t *testing.T, c *Client) { - resp, err := c.R().SetBody("test").Post("/redirect") - assertSuccess(t, resp, err) - - c.SetRedirectPolicy(NoRedirectPolicy()) - resp, err = c.R().SetBody("test").Post("/redirect") - assertNotNil(t, err) - assertContains(t, err.Error(), "redirect is disabled", true) -} - func TestAutoDetectRequestContentType(t *testing.T) { resp, err := tc().R().SetBody(getTestFileContent(t, "sample-image.png")).Post("/raw-upload") assertSuccess(t, resp, err) From e51d674b6597730ccb431675f2585419bb270307 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 15 Feb 2022 19:18:16 +0800 Subject: [PATCH 352/843] complete TestAutoDecode --- .testdata/sample-gbk.html | 10 ++++++++++ client_test.go | 20 ++++++++++++++++++++ req_test.go | 7 +++++++ 3 files changed, 37 insertions(+) create mode 100644 .testdata/sample-gbk.html diff --git a/.testdata/sample-gbk.html b/.testdata/sample-gbk.html new file mode 100644 index 00000000..356ff919 --- /dev/null +++ b/.testdata/sample-gbk.html @@ -0,0 +1,10 @@ + + + + + roc + + + Һãroc + + diff --git a/client_test.go b/client_test.go index 9efb7877..3c3b9c08 100644 --- a/client_test.go +++ b/client_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "os" + "strings" "testing" "time" ) @@ -20,6 +21,25 @@ func TestAutoDecode(t *testing.T) { resp, err = c.EnableAutoDecode().R().Get("/gbk") assertSuccess(t, resp, err) assertEqual(t, "我是roc", resp.String()) + + resp, err = c.SetAutoDecodeContentType("html").R().Get("/gbk") + assertSuccess(t, resp, err) + assertEqual(t, toGbk("我是roc"), resp.Bytes()) + resp, err = c.SetAutoDecodeContentType("text").R().Get("/gbk") + assertSuccess(t, resp, err) + assertEqual(t, "我是roc", resp.String()) + resp, err = c.SetAutoDecodeContentTypeFunc(func(contentType string) bool { + if strings.Contains(contentType, "text") { + return true + } + return false + }).R().Get("/gbk") + assertSuccess(t, resp, err) + assertEqual(t, "我是roc", resp.String()) + + resp, err = c.SetAutoDecodeAllContentType().R().Get("/gbk-no-charset") + assertSuccess(t, resp, err) + assertContains(t, resp.String(), "我是roc", true) } func TestSetTimeout(t *testing.T) { diff --git a/req_test.go b/req_test.go index f94ae63d..04d79609 100644 --- a/req_test.go +++ b/req_test.go @@ -194,6 +194,13 @@ func handleGet(w http.ResponseWriter, r *http.Request) { case "/gbk": w.Header().Set(hdrContentTypeKey, "text/plain; charset=gbk") w.Write(toGbk("我是roc")) + case "/gbk-no-charset": + b, err := ioutil.ReadFile(getTestFilePath("sample-gbk.html")) + if err != nil { + panic(err) + } + w.Header().Set(hdrContentTypeKey, "text/html") + w.Write(b) case "/header": b, _ := json.Marshal(r.Header) w.Header().Set(hdrContentTypeKey, jsonContentType) From c1a5cd9bd5ef9d3532afea44863d64f5fe01e081 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 15 Feb 2022 20:15:32 +0800 Subject: [PATCH 353/843] add more tests for client --- client_test.go | 155 +++++++++++++++++++++++++++++++++++++++++++++++++ req_test.go | 3 + 2 files changed, 158 insertions(+) diff --git a/client_test.go b/client_test.go index 3c3b9c08..e8e439f2 100644 --- a/client_test.go +++ b/client_test.go @@ -2,8 +2,11 @@ package req import ( "bytes" + "context" "crypto/tls" + "errors" "io/ioutil" + "net" "net/http" "net/url" "os" @@ -12,6 +15,158 @@ import ( "time" ) +func TestAllowGetMethodPayload(t *testing.T) { + c := tc() + resp, err := c.R().SetBody("test").Get("/payload") + assertSuccess(t, resp, err) + assertEqual(t, "", resp.String()) + + c.EnableAllowGetMethodPayload() + resp, err = c.R().SetBody("test").Get("/payload") + assertSuccess(t, resp, err) + assertEqual(t, "test", resp.String()) + + c.DisableAllowGetMethodPayload() + resp, err = c.R().SetBody("test").Get("/payload") + assertSuccess(t, resp, err) + assertEqual(t, "", resp.String()) +} + +func TestSetTLSHandshakeTimeout(t *testing.T) { + timeout := 2 * time.Second + c := tc().SetTLSHandshakeTimeout(timeout) + assertEqual(t, timeout, c.t.TLSHandshakeTimeout) +} + +func TestSetDial(t *testing.T) { + testErr := errors.New("test") + testDial := func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, testErr + } + c := tc().SetDial(testDial) + _, err := c.t.DialContext(nil, "", "") + assertEqual(t, testErr, err) +} + +func TestSetDialTLS(t *testing.T) { + testErr := errors.New("test") + testDialTLS := func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, testErr + } + c := tc().SetDialTLS(testDialTLS) + _, err := c.t.DialTLSContext(nil, "", "") + assertEqual(t, testErr, err) +} + +func TestSetFuncs(t *testing.T) { + testErr := errors.New("test") + marshalFunc := func(v interface{}) ([]byte, error) { + return nil, testErr + } + unmarshalFunc := func(data []byte, v interface{}) error { + return testErr + } + c := tc(). + SetJsonMarshal(marshalFunc). + SetJsonUnmarshal(unmarshalFunc). + SetXmlMarshal(marshalFunc). + SetXmlUnmarshal(unmarshalFunc) + + _, err := c.jsonMarshal(nil) + assertEqual(t, testErr, err) + err = c.jsonUnmarshal(nil, nil) + assertEqual(t, testErr, err) + + _, err = c.xmlMarshal(nil) + assertEqual(t, testErr, err) + err = c.xmlUnmarshal(nil, nil) + assertEqual(t, testErr, err) +} + +func TestSetCookieJar(t *testing.T) { + c := tc().SetCookieJar(nil) + assertEqual(t, nil, c.httpClient.Jar) +} + +func TestTraceAll(t *testing.T) { + c := tc().EnableTraceAll() + resp, err := c.R().Get("/") + assertSuccess(t, resp, err) + assertEqual(t, true, resp.TraceInfo().TotalTime > 0) + + c.DisableTraceAll() + resp, err = c.R().Get("/") + assertSuccess(t, resp, err) + assertEqual(t, true, resp.TraceInfo().TotalTime == 0) +} + +func TestOnAfterResponse(t *testing.T) { + c := tc() + len1 := len(c.afterResponse) + c.OnAfterResponse(func(client *Client, response *Response) error { + return nil + }) + len2 := len(c.afterResponse) + assertEqual(t, true, len1+1 == len2) +} + +func TestOnBeforeRequest(t *testing.T) { + c := tc().OnBeforeRequest(func(client *Client, request *Request) error { + return nil + }) + assertEqual(t, true, len(c.udBeforeRequest) == 1) +} + +func TestSetProxyURL(t *testing.T) { + c := tc().SetProxyURL("http://dummy.proxy.local") + u, err := c.t.Proxy(nil) + assertError(t, err) + assertEqual(t, "http://dummy.proxy.local", u.String()) +} + +func TestSetProxy(t *testing.T) { + u, _ := url.Parse("http://dummy.proxy.local") + proxy := http.ProxyURL(u) + c := tc().SetProxy(proxy) + uu, err := c.t.Proxy(nil) + assertError(t, err) + assertEqual(t, u.String(), uu.String()) +} + +func TestSetCommonContentType(t *testing.T) { + c := tc().SetCommonContentType(jsonContentType) + assertEqual(t, jsonContentType, c.Headers.Get(hdrContentTypeKey)) +} + +func TestSetCommonHeader(t *testing.T) { + c := tc().SetCommonHeader("my-header", "my-value") + assertEqual(t, "my-value", c.Headers.Get("my-header")) +} + +func TestSetCommonHeaders(t *testing.T) { + c := tc().SetCommonHeaders(map[string]string{ + "header1": "value1", + "header2": "value2", + }) + assertEqual(t, "value1", c.Headers.Get("header1")) + assertEqual(t, "value2", c.Headers.Get("header2")) +} + +func TestSetCommonBasicAuth(t *testing.T) { + c := tc().SetCommonBasicAuth("imroc", "123456") + assertEqual(t, "Basic aW1yb2M6MTIzNDU2", c.Headers.Get("Authorization")) +} + +func TestSetCommonBearerAuthToken(t *testing.T) { + c := tc().SetCommonBearerAuthToken("123456") + assertEqual(t, "Bearer 123456", c.Headers.Get("Authorization")) +} + +func TestSetUserAgent(t *testing.T) { + c := tc().SetUserAgent("test") + assertEqual(t, "test", c.Headers.Get(hdrUserAgentKey)) +} + func TestAutoDecode(t *testing.T) { c := tc().DisableAutoDecode() resp, err := c.R().Get("/gbk") diff --git a/req_test.go b/req_test.go index 04d79609..f6fee941 100644 --- a/req_test.go +++ b/req_test.go @@ -191,6 +191,9 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusMovedPermanently) case "/pragma": w.Header().Add("Pragma", "no-cache") + case "/payload": + b, _ := ioutil.ReadAll(r.Body) + w.Write(b) case "/gbk": w.Header().Set(hdrContentTypeKey, "text/plain; charset=gbk") w.Write(toGbk("我是roc")) From e3c80c165db65c17398c216279262c16b918504b Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 15 Feb 2022 21:15:15 +0800 Subject: [PATCH 354/843] add TestGlobalWrapper --- req_test.go | 221 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 221 insertions(+) diff --git a/req_test.go b/req_test.go index f6fee941..c8e4d42c 100644 --- a/req_test.go +++ b/req_test.go @@ -1,22 +1,28 @@ package req import ( + "context" + "crypto/tls" "encoding/json" "encoding/xml" + "errors" "fmt" "go/token" "golang.org/x/text/encoding/simplifiedchinese" "golang.org/x/text/transform" "io" "io/ioutil" + "net" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" "reflect" "strings" "sync" "testing" + "time" "unsafe" ) @@ -374,3 +380,218 @@ func isNil(v interface{}) bool { } return false } + +func TestGlobalWrapper(t *testing.T) { + + SetCookieJar(nil) + assertEqual(t, nil, DefaultClient().httpClient.Jar) + + testErr := errors.New("test") + testDial := func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, testErr + } + testDialTLS := func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, testErr + } + SetDialTLS(testDialTLS) + SetDial(testDial) + _, err := DefaultClient().t.DialTLSContext(nil, "", "") + assertEqual(t, testErr, err) + _, err = DefaultClient().t.DialContext(nil, "", "") + assertEqual(t, testErr, err) + + timeout := 2 * time.Second + SetTLSHandshakeTimeout(timeout) + assertEqual(t, timeout, DefaultClient().t.TLSHandshakeTimeout) + + EnableAllowGetMethodPayload() + assertEqual(t, true, DefaultClient().AllowGetMethodPayload) + + marshalFunc := func(v interface{}) ([]byte, error) { + return nil, testErr + } + unmarshalFunc := func(data []byte, v interface{}) error { + return testErr + } + SetJsonMarshal(marshalFunc) + SetJsonUnmarshal(unmarshalFunc) + SetXmlMarshal(marshalFunc) + SetXmlUnmarshal(unmarshalFunc) + _, err = DefaultClient().jsonMarshal(nil) + assertEqual(t, testErr, err) + err = DefaultClient().jsonUnmarshal(nil, nil) + assertEqual(t, testErr, err) + _, err = DefaultClient().xmlMarshal(nil) + assertEqual(t, testErr, err) + err = DefaultClient().xmlUnmarshal(nil, nil) + assertEqual(t, testErr, err) + + EnableTraceAll() + assertEqual(t, true, DefaultClient().trace) + DisableTraceAll() + assertEqual(t, false, DefaultClient().trace) + + len1 := len(DefaultClient().afterResponse) + OnAfterResponse(func(client *Client, response *Response) error { + return nil + }) + len2 := len(DefaultClient().afterResponse) + assertEqual(t, true, len1+1 == len2) + + OnBeforeRequest(func(client *Client, request *Request) error { + return nil + }) + assertEqual(t, true, len(DefaultClient().udBeforeRequest) == 1) + + SetProxyURL("http://dummy.proxy.local") + u, err := DefaultClient().t.Proxy(nil) + assertError(t, err) + assertEqual(t, "http://dummy.proxy.local", u.String()) + + u, _ = url.Parse("http://dummy.proxy.local") + proxy := http.ProxyURL(u) + SetProxy(proxy) + uu, err := DefaultClient().t.Proxy(nil) + assertError(t, err) + assertEqual(t, u.String(), uu.String()) + + SetCommonContentType(jsonContentType) + assertEqual(t, jsonContentType, DefaultClient().Headers.Get(hdrContentTypeKey)) + + SetCommonHeader("my-header", "my-value") + assertEqual(t, "my-value", DefaultClient().Headers.Get("my-header")) + + SetCommonHeaders(map[string]string{ + "header1": "value1", + "header2": "value2", + }) + assertEqual(t, "value1", DefaultClient().Headers.Get("header1")) + assertEqual(t, "value2", DefaultClient().Headers.Get("header2")) + + SetCommonBasicAuth("imroc", "123456") + assertEqual(t, "Basic aW1yb2M6MTIzNDU2", DefaultClient().Headers.Get("Authorization")) + + SetCommonBearerAuthToken("123456") + assertEqual(t, "Bearer 123456", DefaultClient().Headers.Get("Authorization")) + + SetUserAgent("test") + assertEqual(t, "test", DefaultClient().Headers.Get(hdrUserAgentKey)) + + SetTimeout(timeout) + assertEqual(t, timeout, DefaultClient().httpClient.Timeout) + + l := createDefaultLogger() + SetLogger(l) + assertEqual(t, l, DefaultClient().log) + + SetScheme("https") + assertEqual(t, "https", DefaultClient().scheme) + + EnableDebugLog() + assertEqual(t, true, DefaultClient().DebugLog) + + DisableDebugLog() + assertEqual(t, false, DefaultClient().DebugLog) + + SetCommonCookies(&http.Cookie{Name: "test", Value: "test"}) + assertEqual(t, "test", DefaultClient().Cookies[0].Name) + + SetCommonQueryString("test1=test1") + assertEqual(t, "test1", DefaultClient().QueryParams.Get("test1")) + + SetCommonPathParams(map[string]string{"test1": "test1"}) + assertEqual(t, "test1", DefaultClient().PathParams["test1"]) + + SetCommonPathParam("test2", "test2") + assertEqual(t, "test2", DefaultClient().PathParams["test2"]) + + AddCommonQueryParam("test1", "test11") + assertEqual(t, []string{"test1", "test11"}, DefaultClient().QueryParams["test1"]) + + SetCommonQueryParam("test1", "test111") + assertEqual(t, "test111", DefaultClient().QueryParams.Get("test1")) + + SetCommonQueryParams(map[string]string{"test1": "test1"}) + assertEqual(t, "test1", DefaultClient().QueryParams.Get("test1")) + + EnableInsecureSkipVerify() + assertEqual(t, true, DefaultClient().t.TLSClientConfig.InsecureSkipVerify) + + DisableInsecureSkipVerify() + assertEqual(t, false, DefaultClient().t.TLSClientConfig.InsecureSkipVerify) + + DisableCompression() + assertEqual(t, true, DefaultClient().t.DisableCompression) + + EnableCompression() + assertEqual(t, false, DefaultClient().t.DisableCompression) + + DisableKeepAlives() + assertEqual(t, true, DefaultClient().t.DisableKeepAlives) + + EnableKeepAlives() + assertEqual(t, false, DefaultClient().t.DisableKeepAlives) + + config := GetTLSClientConfig() + assertEqual(t, config, DefaultClient().t.TLSClientConfig) + + SetRootCertsFromFile(getTestFilePath("sample-root.pem")) + assertEqual(t, true, DefaultClient().t.TLSClientConfig.RootCAs != nil) + + SetRootCertFromString(string(getTestFileContent(t, "sample-root.pem"))) + assertEqual(t, true, DefaultClient().t.TLSClientConfig.RootCAs != nil) + + SetCerts(tls.Certificate{}, tls.Certificate{}) + assertEqual(t, true, len(DefaultClient().t.TLSClientConfig.Certificates) == 2) + + SetCertFromFile( + getTestFilePath("sample-client.pem"), + getTestFilePath("sample-client-key.pem"), + ) + assertEqual(t, true, len(DefaultClient().t.TLSClientConfig.Certificates) == 3) + + SetOutputDirectory(testDataPath) + assertEqual(t, testDataPath, DefaultClient().outputDirectory) + + baseURL := "http://dummy-req.local/test" + SetBaseURL(baseURL) + assertEqual(t, baseURL, DefaultClient().BaseURL) + + form := make(url.Values) + form.Add("test", "test") + SetCommonFormDataFromValues(form) + assertEqual(t, form, DefaultClient().FormData) + + SetCommonFormData(map[string]string{"test2": "test2"}) + assertEqual(t, "test2", DefaultClient().FormData.Get("test2")) + + DisableAutoReadResponse() + assertEqual(t, true, DefaultClient().disableAutoReadResponse) + EnableAutoReadResponse() + assertEqual(t, false, DefaultClient().disableAutoReadResponse) + + EnableDumpAll() + opt := DefaultClient().getDumpOptions() + assertEqual(t, true, opt.RequestHeader == true && opt.RequestBody == true && opt.ResponseHeader == true && opt.ResponseBody == true) + EnableDumpAllAsync() + assertEqual(t, true, opt.Async) + EnableDumpAllWithoutBody() + assertEqual(t, true, opt.ResponseBody == false && opt.RequestBody == false) + opt.ResponseBody = true + opt.RequestBody = true + EnableDumpAllWithoutResponse() + assertEqual(t, true, opt.ResponseBody == false && opt.ResponseHeader == false) + opt.ResponseBody = true + opt.ResponseHeader = true + EnableDumpAllWithoutRequest() + assertEqual(t, true, opt.RequestHeader == false && opt.RequestBody == false) + opt.RequestHeader = true + opt.RequestBody = true + EnableDumpAllWithoutHeader() + assertEqual(t, true, opt.RequestHeader == false && opt.ResponseHeader == false) + SetCommonDumpOptions(&DumpOptions{ + RequestHeader: true, + }) + opt = DefaultClient().getDumpOptions() + assertEqual(t, true, opt.RequestHeader == true && opt.ResponseHeader == false) +} From 3c61417b60f7dcf0f63c2f4aee86a908cd840443 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 15 Feb 2022 21:35:09 +0800 Subject: [PATCH 355/843] optimize tests --- .testdata/sample-gbk.html | 343 ++++++++++++++++++++++++++++++++++++++ http_test.go | 34 +++- transfer.go | 14 -- 3 files changed, 376 insertions(+), 15 deletions(-) diff --git a/.testdata/sample-gbk.html b/.testdata/sample-gbk.html index 356ff919..cc7b22fa 100644 --- a/.testdata/sample-gbk.html +++ b/.testdata/sample-gbk.html @@ -6,5 +6,348 @@ Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc + Һãroc diff --git a/http_test.go b/http_test.go index f432e859..7dca3a45 100644 --- a/http_test.go +++ b/http_test.go @@ -1,6 +1,38 @@ package req -import "testing" +import ( + "reflect" + "testing" +) + +func TestForeachHeaderElement(t *testing.T) { + tests := []struct { + in string + want []string + }{ + {"Foo", []string{"Foo"}}, + {" Foo", []string{"Foo"}}, + {"Foo ", []string{"Foo"}}, + {" Foo ", []string{"Foo"}}, + + {"foo", []string{"foo"}}, + {"anY-cAsE", []string{"anY-cAsE"}}, + + {"", nil}, + {",,,, , ,, ,,, ,", nil}, + + {" Foo,Bar, Baz,lower,,Quux ", []string{"Foo", "Bar", "Baz", "lower", "Quux"}}, + } + for _, tt := range tests { + var got []string + foreachHeaderElement(tt.in, func(v string) { + got = append(got, v) + }) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("foreachHeaderElement(%q) = %q; want %q", tt.in, got, tt.want) + } + } +} func TestCleanHost(t *testing.T) { tests := []struct { diff --git a/transfer.go b/transfer.go index bdee2245..3e7f9f94 100644 --- a/transfer.go +++ b/transfer.go @@ -104,20 +104,6 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { } atLeastHTTP11 = true // Transport requests are always 1.1 or 2.0 - case *http.Response: - t.IsResponse = true - if rr.Request != nil { - t.Method = rr.Request.Method - } - t.Body = rr.Body - t.BodyCloser = rr.Body - t.ContentLength = rr.ContentLength - t.Close = rr.Close - t.TransferEncoding = rr.TransferEncoding - t.Header = rr.Header - t.Trailer = rr.Trailer - atLeastHTTP11 = rr.ProtoAtLeast(1, 1) - t.ResponseToHEAD = noResponseBodyExpected(t.Method) } // Sanitize Body,ContentLength,TransferEncoding From e1966cc0257ff692d008a0b397802055789e6a4a Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 15 Feb 2022 21:44:03 +0800 Subject: [PATCH 356/843] add TestSetFileBytes --- req_test.go | 6 ++++++ request_test.go | 8 +++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/req_test.go b/req_test.go index c8e4d42c..a352968e 100644 --- a/req_test.go +++ b/req_test.go @@ -89,6 +89,12 @@ func handlePost(w http.ResponseWriter, r *http.Request) { w.Write([]byte("TestPost: text response")) case "/raw-upload": io.Copy(ioutil.Discard, r.Body) + case "/file-text": + r.ParseMultipartForm(10e6) + files := r.MultipartForm.File["file"] + file, _ := files[0].Open() + b, _ := ioutil.ReadAll(file) + w.Write(b) case "/form": r.ParseForm() ret, _ := json.Marshal(&r.Form) diff --git a/request_test.go b/request_test.go index 6b4d005a..3a620672 100644 --- a/request_test.go +++ b/request_test.go @@ -587,7 +587,7 @@ func TestUploadMultipart(t *testing.T) { m := make(map[string]interface{}) resp, err := tc().R(). SetFile("file", getTestFilePath("sample-image.png")). - SetFile("file", getTestFilePath("sample-file.txt")). + SetFiles(map[string]string{"file": getTestFilePath("sample-file.txt")}). SetFormData(map[string]string{ "param1": "value1", "param2": "value2", @@ -606,3 +606,9 @@ func TestFixPragmaCache(t *testing.T) { assertSuccess(t, resp, err) assertEqual(t, "no-cache", resp.Header.Get("Cache-Control")) } + +func TestSetFileBytes(t *testing.T) { + resp, err := tc().R().SetFileBytes("file", "file.txt", []byte("test")).Post("/file-text") + assertSuccess(t, resp, err) + assertEqual(t, "test", resp.String()) +} From fd76286b9bc60a2bdb6924a7e9c6c57af41afe8c Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 09:48:17 +0800 Subject: [PATCH 357/843] split enable dump tests into seperate test funcs --- request_test.go | 140 ++++++++++++++++++++++++++++-------------------- 1 file changed, 81 insertions(+), 59 deletions(-) diff --git a/request_test.go b/request_test.go index 3a620672..36d2e315 100644 --- a/request_test.go +++ b/request_test.go @@ -11,65 +11,78 @@ import ( "time" ) -func TestRequestDump(t *testing.T) { - testRequestDump(t, tc()) - testRequestDump(t, tc().EnableForceHTTP1()) -} - -func testRequestDump(t *testing.T, c *Client) { - testCases := []func(r *Request, reqHeader, reqBody, respHeader, respBody *bool){ - func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDump() - *reqHeader = true - *reqBody = true - *respHeader = true - *respBody = true - }, - func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutRequest() - *reqHeader = false - *reqBody = false - *respHeader = true - *respBody = true - }, - func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutRequestBody() - *reqHeader = true - *reqBody = false - *respHeader = true - *respBody = true - }, - func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutResponse() - *reqHeader = true - *reqBody = true - *respHeader = false - *respBody = false - }, - func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutResponseBody() - *reqHeader = true - *reqBody = true - *respHeader = true - *respBody = false - }, - func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutHeader() - *reqHeader = false - *reqBody = true - *respHeader = false - *respBody = true - }, - func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutBody() - *reqHeader = true - *reqBody = false - *respHeader = true - *respBody = false - }, - } - - for _, fn := range testCases { +func TestEnableDump(t *testing.T) { + testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDump() + *reqHeader = true + *reqBody = true + *respHeader = true + *respBody = true + }) +} + +func TestEnableDumpWithoutRequest(t *testing.T) { + testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutRequest() + *reqHeader = false + *reqBody = false + *respHeader = true + *respBody = true + }) +} + +func TestEnableDumpWithoutRequestBody(t *testing.T) { + testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutRequestBody() + *reqHeader = true + *reqBody = false + *respHeader = true + *respBody = true + }) +} + +func TestEnableDumpWithoutResponse(t *testing.T) { + testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutResponse() + *reqHeader = true + *reqBody = true + *respHeader = false + *respBody = false + }) +} + +func TestEnableDumpWithoutResponseBody(t *testing.T) { + testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutResponseBody() + *reqHeader = true + *reqBody = true + *respHeader = true + *respBody = false + }) +} + +func TestEnableDumpWithoutHeader(t *testing.T) { + testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutHeader() + *reqHeader = false + *reqBody = true + *respHeader = false + *respBody = true + }) +} + +func TestEnableDumpWithoutBody(t *testing.T) { + testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { + r.EnableDumpWithoutBody() + *reqHeader = true + *reqBody = false + *respHeader = true + *respBody = false + }) +} + +func testEnableDump(t *testing.T, fn func(r *Request, reqHeader, reqBody, respHeader, respBody *bool)) { + testDump := func(c *Client) { r := c.R() var reqHeader, reqBody, respHeader, respBody bool fn(r, &reqHeader, &reqBody, &respHeader, &respBody) @@ -81,7 +94,16 @@ func testRequestDump(t *testing.T, c *Client) { assertContains(t, dump, "date", respHeader) assertContains(t, dump, "testpost: text response", respBody) } + testDump(tc()) + testDump(tc().EnableForceHTTP1()) +} + +func TestSetDumpOptions(t *testing.T) { + testSetDumpOptions(t, tc()) + testSetDumpOptions(t, tc().EnableForceHTTP1()) +} +func testSetDumpOptions(t *testing.T, c *Client) { opt := &DumpOptions{ RequestHeader: true, RequestBody: false, From d6a14e68437f5dd080b53711099e9f010f8f8184 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 09:59:32 +0800 Subject: [PATCH 358/843] add TestMethods --- req_test.go | 1 + request_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/req_test.go b/req_test.go index a352968e..0cffcade 100644 --- a/req_test.go +++ b/req_test.go @@ -47,6 +47,7 @@ func createTestServer() *httptest.Server { } func handleHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Method", r.Method) switch r.Method { case http.MethodGet: handleGet(w, r) diff --git a/request_test.go b/request_test.go index 36d2e315..63128f57 100644 --- a/request_test.go +++ b/request_test.go @@ -11,6 +11,55 @@ import ( "time" ) +func TestMethods(t *testing.T) { + testMethods(t, tc()) + testMethods(t, tc().EnableForceHTTP1()) +} + +func testMethods(t *testing.T, c *Client) { + resp, err := c.R().Put("/") + assertSuccess(t, resp, err) + assertEqual(t, "PUT", resp.Header.Get("Method")) + resp = c.R().MustPut("/") + assertEqual(t, "PUT", resp.Header.Get("Method")) + + resp, err = c.R().Patch("/") + assertSuccess(t, resp, err) + assertEqual(t, "PATCH", resp.Header.Get("Method")) + resp = c.R().MustPatch("/") + assertEqual(t, "PATCH", resp.Header.Get("Method")) + + resp, err = c.R().Delete("/") + assertSuccess(t, resp, err) + assertEqual(t, "DELETE", resp.Header.Get("Method")) + resp = c.R().MustDelete("/") + assertEqual(t, "DELETE", resp.Header.Get("Method")) + + resp, err = c.R().Options("/") + assertSuccess(t, resp, err) + assertEqual(t, "OPTIONS", resp.Header.Get("Method")) + resp = c.R().MustOptions("/") + assertEqual(t, "OPTIONS", resp.Header.Get("Method")) + + resp, err = c.R().Head("/") + assertSuccess(t, resp, err) + assertEqual(t, "HEAD", resp.Header.Get("Method")) + resp = c.R().MustHead("/") + assertEqual(t, "HEAD", resp.Header.Get("Method")) + + resp, err = c.R().Get("/") + assertSuccess(t, resp, err) + assertEqual(t, "GET", resp.Header.Get("Method")) + resp = c.R().MustGet("/") + assertEqual(t, "GET", resp.Header.Get("Method")) + + resp, err = c.R().Post("/") + assertSuccess(t, resp, err) + assertEqual(t, "POST", resp.Header.Get("Method")) + resp = c.R().MustPost("/") + assertEqual(t, "POST", resp.Header.Get("Method")) +} + func TestEnableDump(t *testing.T) { testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { r.EnableDump() From edb94300f4f1c5c13dc92215fee7328f84a3645a Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 10:05:38 +0800 Subject: [PATCH 359/843] add tests for global wrapper --- req_test.go | 48 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/req_test.go b/req_test.go index 0cffcade..77374007 100644 --- a/req_test.go +++ b/req_test.go @@ -389,6 +389,52 @@ func isNil(v interface{}) bool { } func TestGlobalWrapper(t *testing.T) { + EnableInsecureSkipVerify() + testURL := getTestServerURL() + "/" + + resp, err := Put(testURL) + assertSuccess(t, resp, err) + assertEqual(t, "PUT", resp.Header.Get("Method")) + resp = MustPut(testURL) + assertEqual(t, "PUT", resp.Header.Get("Method")) + + resp, err = Patch(testURL) + assertSuccess(t, resp, err) + assertEqual(t, "PATCH", resp.Header.Get("Method")) + resp = MustPatch(testURL) + assertEqual(t, "PATCH", resp.Header.Get("Method")) + + resp, err = Delete(testURL) + assertSuccess(t, resp, err) + assertEqual(t, "DELETE", resp.Header.Get("Method")) + resp = MustDelete(testURL) + assertEqual(t, "DELETE", resp.Header.Get("Method")) + + resp, err = Options(testURL) + assertSuccess(t, resp, err) + assertEqual(t, "OPTIONS", resp.Header.Get("Method")) + resp = MustOptions(testURL) + assertEqual(t, "OPTIONS", resp.Header.Get("Method")) + + resp, err = Head(testURL) + assertSuccess(t, resp, err) + assertEqual(t, "HEAD", resp.Header.Get("Method")) + resp = MustHead(testURL) + assertEqual(t, "HEAD", resp.Header.Get("Method")) + + resp, err = Get(testURL) + assertSuccess(t, resp, err) + assertEqual(t, "GET", resp.Header.Get("Method")) + resp = MustGet(testURL) + assertEqual(t, "GET", resp.Header.Get("Method")) + + resp, err = Post(testURL) + assertSuccess(t, resp, err) + assertEqual(t, "POST", resp.Header.Get("Method")) + resp = MustPost(testURL) + assertEqual(t, "POST", resp.Header.Get("Method")) + + DisableInsecureSkipVerify() SetCookieJar(nil) assertEqual(t, nil, DefaultClient().httpClient.Jar) @@ -402,7 +448,7 @@ func TestGlobalWrapper(t *testing.T) { } SetDialTLS(testDialTLS) SetDial(testDial) - _, err := DefaultClient().t.DialTLSContext(nil, "", "") + _, err = DefaultClient().t.DialTLSContext(nil, "", "") assertEqual(t, testErr, err) _, err = DefaultClient().t.DialContext(nil, "", "") assertEqual(t, testErr, err) From ff9267b122fc9d75d12f2d49c873e24d2861ef21 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 10:27:04 +0800 Subject: [PATCH 360/843] add testGlobalWrapperEnableDumps --- req_test.go | 80 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 3 deletions(-) diff --git a/req_test.go b/req_test.go index 77374007..0320ba21 100644 --- a/req_test.go +++ b/req_test.go @@ -388,8 +388,77 @@ func isNil(v interface{}) bool { return false } -func TestGlobalWrapper(t *testing.T) { - EnableInsecureSkipVerify() +func testGlobalWrapperEnableDumps(t *testing.T) { + testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { + *reqHeader = true + *reqBody = true + *respHeader = true + *respBody = true + return EnableDump() + }) + + testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { + *reqHeader = false + *reqBody = false + *respHeader = true + *respBody = true + return EnableDumpWithoutRequest() + }) + + testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { + *reqHeader = true + *reqBody = false + *respHeader = true + *respBody = true + return EnableDumpWithoutRequestBody() + }) + + testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { + *reqHeader = true + *reqBody = true + *respHeader = false + *respBody = false + return EnableDumpWithoutResponse() + }) + + testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { + *reqHeader = true + *reqBody = true + *respHeader = true + *respBody = false + return EnableDumpWithoutResponseBody() + }) + + testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { + *reqHeader = false + *reqBody = true + *respHeader = false + *respBody = true + return EnableDumpWithoutHeader() + }) + + testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { + *reqHeader = true + *reqBody = false + *respHeader = true + *respBody = false + return EnableDumpWithoutBody() + }) +} + +func testGlobalWrapperEnableDump(t *testing.T, fn func(reqHeader, reqBody, respHeader, respBody *bool) *Request) { + var reqHeader, reqBody, respHeader, respBody bool + r := fn(&reqHeader, &reqBody, &respHeader, &respBody) + resp, err := r.SetBody(`test body`).Post(getTestServerURL() + "/") + assertSuccess(t, resp, err) + dump := resp.Dump() + assertContains(t, dump, "user-agent", reqHeader) + assertContains(t, dump, "test body", reqBody) + assertContains(t, dump, "date", respHeader) + assertContains(t, dump, "testpost: text response", respBody) +} + +func testGlobalWrapperSendRequest(t *testing.T) { testURL := getTestServerURL() + "/" resp, err := Put(testURL) @@ -433,7 +502,12 @@ func TestGlobalWrapper(t *testing.T) { assertEqual(t, "POST", resp.Header.Get("Method")) resp = MustPost(testURL) assertEqual(t, "POST", resp.Header.Get("Method")) +} +func TestGlobalWrapper(t *testing.T) { + EnableInsecureSkipVerify() + testGlobalWrapperSendRequest(t) + testGlobalWrapperEnableDumps(t) DisableInsecureSkipVerify() SetCookieJar(nil) @@ -448,7 +522,7 @@ func TestGlobalWrapper(t *testing.T) { } SetDialTLS(testDialTLS) SetDial(testDial) - _, err = DefaultClient().t.DialTLSContext(nil, "", "") + _, err := DefaultClient().t.DialTLSContext(nil, "", "") assertEqual(t, testErr, err) _, err = DefaultClient().t.DialContext(nil, "", "") assertEqual(t, testErr, err) From a54dbcdcc1b155126802239a01ba21a093405448 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 13:07:04 +0800 Subject: [PATCH 361/843] add more global wrapper tests --- req_test.go | 111 ++++++++++++++++++++++++++++++++++++++++++++++++++++ request.go | 6 ++- 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/req_test.go b/req_test.go index 0320ba21..595fa3ad 100644 --- a/req_test.go +++ b/req_test.go @@ -1,6 +1,7 @@ package req import ( + "bytes" "context" "crypto/tls" "encoding/json" @@ -504,6 +505,98 @@ func testGlobalWrapperSendRequest(t *testing.T) { assertEqual(t, "POST", resp.Header.Get("Method")) } +func TestGlobalWrapperSetRequest(t *testing.T) { + testFilePath := getTestFilePath("sample-file.txt") + r := SetFiles(map[string]string{"test": testFilePath}) + assertEqual(t, 1, len(r.uploadFiles)) + assertEqual(t, true, r.isMultiPart) + + r = SetFile("test", getTestFilePath("sample-file.txt")) + assertEqual(t, 1, len(r.uploadFiles)) + assertEqual(t, true, r.isMultiPart) + + SetLogger(nil) + r = SetFile("test", getTestFilePath("file-not-exists.txt")) + assertEqual(t, 0, len(r.uploadFiles)) + assertEqual(t, false, r.isMultiPart) + assertNotNil(t, r.error) + + r = SetFileReader("test", "test.txt", bytes.NewBufferString("test")) + assertEqual(t, 1, len(r.uploadFiles)) + assertEqual(t, true, r.isMultiPart) + + r = SetFileBytes("test", "test.txt", []byte("test")) + assertEqual(t, 1, len(r.uploadFiles)) + assertEqual(t, true, r.isMultiPart) + + r = SetFileUpload(FileUpload{}) + assertEqual(t, 1, len(r.uploadFiles)) + assertEqual(t, true, r.isMultiPart) + + var result string + r = SetError(&result) + assertEqual(t, true, r.Error != nil) + + r = SetResult(&result) + assertEqual(t, true, r.Result != nil) + + r = SetOutput(nil) + assertEqual(t, false, r.isSaveResponse) + + r = SetOutput(bytes.NewBufferString("test")) + assertEqual(t, true, r.isSaveResponse) + + r = SetHeader("test", "test") + assertEqual(t, "test", r.Headers.Get("test")) + + r = SetHeaders(map[string]string{"test": "test"}) + assertEqual(t, "test", r.Headers.Get("test")) + + r = SetCookies(&http.Cookie{ + Name: "test", + Value: "test", + }) + assertEqual(t, 1, len(r.Cookies)) + + r = SetBasicAuth("imroc", "123456") + assertEqual(t, "Basic aW1yb2M6MTIzNDU2", r.Headers.Get("Authorization")) + + r = SetBearerAuthToken("123456") + assertEqual(t, "Bearer 123456", r.Headers.Get("Authorization")) + + r = SetQueryString("test=test") + assertEqual(t, "test", r.QueryParams.Get("test")) + + r = SetQueryString("ksjlfjk?") + assertEqual(t, "", r.QueryParams.Get("test")) + + r = SetQueryParam("test", "test") + assertEqual(t, "test", r.QueryParams.Get("test")) + + r = AddQueryParam("test", "test") + assertEqual(t, "test", r.QueryParams.Get("test")) + + r = SetQueryParams(map[string]string{"test": "test"}) + assertEqual(t, "test", r.QueryParams.Get("test")) + + r = SetPathParam("test", "test") + assertEqual(t, "test", r.PathParams["test"]) + + r = SetPathParams(map[string]string{"test": "test"}) + assertEqual(t, "test", r.PathParams["test"]) + + r = SetFormData(map[string]string{"test": "test"}) + assertEqual(t, "test", r.FormData.Get("test")) + + values := make(url.Values) + values.Add("test", "test") + r = SetFormDataFromValues(values) + assertEqual(t, "test", r.FormData.Get("test")) + + r = SetContentType(jsonContentType) + assertEqual(t, jsonContentType, r.Headers.Get(hdrContentTypeKey)) +} + func TestGlobalWrapper(t *testing.T) { EnableInsecureSkipVerify() testGlobalWrapperSendRequest(t) @@ -575,6 +668,11 @@ func TestGlobalWrapper(t *testing.T) { assertError(t, err) assertEqual(t, "http://dummy.proxy.local", u.String()) + SetProxyURL("bad url") + u, err = DefaultClient().t.Proxy(nil) + assertError(t, err) + assertNotEqual(t, "bad url", u.String()) + u, _ = url.Parse("http://dummy.proxy.local") proxy := http.ProxyURL(u) SetProxy(proxy) @@ -716,9 +814,22 @@ func TestGlobalWrapper(t *testing.T) { opt.RequestBody = true EnableDumpAllWithoutHeader() assertEqual(t, true, opt.RequestHeader == false && opt.ResponseHeader == false) + + DefaultClient().getDumpOptions().Output = nil + SetLogger(nil) + EnableDumpAllToFile(filepath.Join(testDataPath, "path-not-exists", "dump.out")) + assertEqual(t, true, DefaultClient().getDumpOptions().Output == nil) + + dumpFile := getTestFilePath("tmpdump.out") + EnableDumpAllToFile(dumpFile) + assertEqual(t, true, DefaultClient().getDumpOptions().Output != nil) + os.Remove(dumpFile) + SetCommonDumpOptions(&DumpOptions{ RequestHeader: true, }) opt = DefaultClient().getDumpOptions() assertEqual(t, true, opt.RequestHeader == true && opt.ResponseHeader == false) + DisableDumpAll() + assertEqual(t, true, DefaultClient().t.dump == nil) } diff --git a/request.go b/request.go index 3123b723..a1991bfe 100644 --- a/request.go +++ b/request.go @@ -251,13 +251,13 @@ func SetFile(paramName, filePath string) *Request { // SetFile set up a multipart form from file path to upload, // which read file from filePath automatically to upload. func (r *Request) SetFile(paramName, filePath string) *Request { - r.isMultiPart = true file, err := os.Open(filePath) if err != nil { r.client.log.Errorf("failed to open %s: %v", filePath, err) r.appendError(err) return r } + r.isMultiPart = true return r.SetFileReader(paramName, filepath.Base(filePath), file) } @@ -372,6 +372,10 @@ func SetOutput(output io.Writer) *Request { // SetOutput set the io.Writer that response body will be downloaded to. func (r *Request) SetOutput(output io.Writer) *Request { + if output == nil { + r.client.log.Warnf("nil io.Writer is not allowed in SetOutput") + return r + } r.output = output r.isSaveResponse = true return r From 579434ca002dce5a5e1fb082ebed246c9fab3244 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 13:35:18 +0800 Subject: [PATCH 362/843] update README: slogan --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index beee6696..ddaab9b4 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@

Req

-

Simplified Golang HTTP client library with Black Magic, Less code and More efficiency.

+

Simple Go HTTP client with Black Magic (Less code and More efficiency).

From d02c879a788f0cb650ca6a4456079102189ade64 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 14:34:23 +0800 Subject: [PATCH 363/843] setup github actions --- .github/workflows/ci.yml | 46 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..9256b8d3 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,46 @@ +name: CI + +on: + push: + branches: + - master + paths-ignore: + - '**.md' + pull_request: + branches: + - master + paths-ignore: + - '**.md' + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + build: + name: Build + strategy: + matrix: + go: [ '1.17.x', '1.16.x' ] + os: [ ubuntu-latest ] + + runs-on: ${{ matrix.os }} + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Setup Go + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + + - name: Test + run: go test ./... -coverprofile=coverage.txt -covermode=atomic + + - name: Coverage + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + run: | + curl -Os https://uploader.codecov.io/latest/linux/codecov + chmod +x codecov + ./codecov -t ${CODECOV_TOKEN} From 7349811a61a9c3429f9459dc7b68c8fa29eae192 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 14:47:42 +0800 Subject: [PATCH 364/843] add badages --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ddaab9b4..f86986d6 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,13 @@

Req

Simple Go HTTP client with Black Magic (Less code and More efficiency).

-

+

+ Build Status + Code Coverage + Go Report Card + + License +

## News From 2b632d00dbb1280bc5c555e5cb6e756a036863e0 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 15:07:44 +0800 Subject: [PATCH 365/843] split client dump tests --- client_test.go | 159 ++++++++++++++++++++++++++----------------------- 1 file changed, 86 insertions(+), 73 deletions(-) diff --git a/client_test.go b/client_test.go index e8e439f2..32099342 100644 --- a/client_test.go +++ b/client_test.go @@ -436,72 +436,80 @@ func testDisableAutoReadResponse(t *testing.T, c *Client) { assertError(t, err) } -func TestClientDump(t *testing.T) { - testClientDump(t, func() *Client { - return tc() +func TestEnableDumpAll(t *testing.T) { + testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { + c.EnableDumpAll() + *reqHeader = true + *reqBody = true + *respHeader = true + *respBody = true }) - testClientDump(t, func() *Client { - return tc().EnableForceHTTP1() +} + +func TestEnableDumpAllWithoutRequest(t *testing.T) { + testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { + c.EnableDumpAllWithoutRequest() + *reqHeader = false + *reqBody = false + *respHeader = true + *respBody = true }) } -func testClientDump(t *testing.T, newClient func() *Client) { - testCases := []func(r *Client, reqHeader, reqBody, respHeader, respBody *bool){ - func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpAll() - *reqHeader = true - *reqBody = true - *respHeader = true - *respBody = true - }, - func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpAllWithoutRequest() - *reqHeader = false - *reqBody = false - *respHeader = true - *respBody = true - }, - func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpAllWithoutRequestBody() - *reqHeader = true - *reqBody = false - *respHeader = true - *respBody = true - }, - func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpAllWithoutResponse() - *reqHeader = true - *reqBody = true - *respHeader = false - *respBody = false - }, - func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpAllWithoutResponseBody() - *reqHeader = true - *reqBody = true - *respHeader = true - *respBody = false - }, - func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpAllWithoutHeader() - *reqHeader = false - *reqBody = true - *respHeader = false - *respBody = true - }, - func(r *Client, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpAllWithoutBody() - *reqHeader = true - *reqBody = false - *respHeader = true - *respBody = false - }, - } +func TestEnableDumpAllWithoutRequestBody(t *testing.T) { + testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { + c.EnableDumpAllWithoutRequestBody() + *reqHeader = true + *reqBody = false + *respHeader = true + *respBody = true + }) +} + +func TestEnableDumpAllWithoutResponse(t *testing.T) { + testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { + c.EnableDumpAllWithoutResponse() + *reqHeader = true + *reqBody = true + *respHeader = false + *respBody = false + }) +} - for _, fn := range testCases { - c := newClient() - buf := new(bytes.Buffer) - c.EnableDumpAllTo(buf) +func TestEnableDumpAllWithoutResponseBody(t *testing.T) { + testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { + c.EnableDumpAllWithoutResponseBody() + *reqHeader = true + *reqBody = true + *respHeader = true + *respBody = false + }) +} + +func TestEnableDumpAllWithoutHeader(t *testing.T) { + testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { + c.EnableDumpAllWithoutHeader() + *reqHeader = false + *reqBody = true + *respHeader = false + *respBody = true + }) +} + +func TestEnableDumpAllWithoutBody(t *testing.T) { + testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { + c.EnableDumpAllWithoutBody() + *reqHeader = true + *reqBody = false + *respHeader = true + *respBody = false + }) +} + +func testEnableDumpAll(t *testing.T, fn func(c *Client, reqHeader, reqBody, respHeader, respBody *bool)) { + buf := new(bytes.Buffer) + c := tc().EnableDumpAllTo(buf) + testDump := func(c *Client) { var reqHeader, reqBody, respHeader, respBody bool fn(c, &reqHeader, &reqBody, &respHeader, &respBody) resp, err := c.R().SetBody(`test body`).Post("/") @@ -512,8 +520,13 @@ func testClientDump(t *testing.T, newClient func() *Client) { assertContains(t, dump, "date", respHeader) assertContains(t, dump, "testpost: text response", respBody) } + testDump(c) + c.EnableForceHTTP1() + testDump(c) +} - c := newClient() +func TestSetCommonDumpOptions(t *testing.T) { + c := tc() buf := new(bytes.Buffer) opt := &DumpOptions{ RequestHeader: true, @@ -529,25 +542,25 @@ func testClientDump(t *testing.T, newClient func() *Client) { assertContains(t, buf.String(), "test body", false) assertContains(t, buf.String(), "date", false) assertContains(t, buf.String(), "testpost: text response", true) +} +func TestEnableDumpAllToFile(t *testing.T) { + c := tc() dumpFile := "tmp_test_dump_file" c.EnableDumpAllToFile(getTestFilePath(dumpFile)) - resp, err = c.R().SetBody("test body").Post("/") + resp, err := c.R().SetBody("test body").Post("/") assertSuccess(t, resp, err) dump := string(getTestFileContent(t, dumpFile)) os.Remove(getTestFilePath(dumpFile)) assertContains(t, dump, "user-agent", true) - assertContains(t, dump, "test body", false) - assertContains(t, dump, "date", false) + assertContains(t, dump, "test body", true) + assertContains(t, dump, "date", true) assertContains(t, dump, "testpost: text response", true) +} - buf = new(bytes.Buffer) +func TestEnableDumpAllAsync(t *testing.T) { + c := tc() + buf := new(bytes.Buffer) c.EnableDumpAllTo(buf).EnableDumpAllAsync() - resp, err = c.R().SetBody("test body").Post("/") - assertSuccess(t, resp, err) - time.Sleep(10 * time.Millisecond) - assertContains(t, buf.String(), "user-agent", true) - assertContains(t, buf.String(), "test body", false) - assertContains(t, buf.String(), "date", false) - assertContains(t, buf.String(), "testpost: text response", true) + assertEqual(t, true, c.getDumpOptions().Async) } From ca23a732e6c63fcae7d8c6784f986ce04afac8c4 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 15:11:05 +0800 Subject: [PATCH 366/843] fix testEnableDumpAll --- client_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client_test.go b/client_test.go index 32099342..83d7c299 100644 --- a/client_test.go +++ b/client_test.go @@ -521,7 +521,8 @@ func testEnableDumpAll(t *testing.T, fn func(c *Client, reqHeader, reqBody, resp assertContains(t, dump, "testpost: text response", respBody) } testDump(c) - c.EnableForceHTTP1() + buf = new(bytes.Buffer) + c = tc().EnableDumpAllTo(buf).EnableDumpAllTo(buf).EnableForceHTTP1() testDump(c) } From fec2b9aff73c3f6253443e64a0de57c19bad9db0 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 15:29:58 +0800 Subject: [PATCH 367/843] remove covermode in ci --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9256b8d3..e76f3f84 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,7 +35,7 @@ jobs: go-version: ${{ matrix.go }} - name: Test - run: go test ./... -coverprofile=coverage.txt -covermode=atomic + run: go test ./... -coverprofile=coverage.txt - name: Coverage env: From cced048b8c1211283d41741506b4d6f704c2d868 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 16:32:59 +0800 Subject: [PATCH 368/843] add TestAuthenticate in socks --- internal/socks/socks_test.go | 48 ++++++++++++++++++++++++++++++++++++ internal/tests/tests.go | 22 +++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 internal/socks/socks_test.go create mode 100644 internal/tests/tests.go diff --git a/internal/socks/socks_test.go b/internal/socks/socks_test.go new file mode 100644 index 00000000..d96c44b6 --- /dev/null +++ b/internal/socks/socks_test.go @@ -0,0 +1,48 @@ +package socks + +import ( + "bytes" + "context" + "github.com/imroc/req/v3/internal/tests" + "strings" + "testing" +) + +func TestReply(t *testing.T) { + for i := 0; i < 9; i++ { + s := Reply(i).String() + if strings.Contains(s, "unknown") { + t.Errorf("resply code [%d] should not unkown", i) + } + } + s := Reply(9).String() + if !strings.Contains(s, "unknown") { + t.Errorf("resply code [%d] should unkown", 9) + } +} + +func TestAuthenticate(t *testing.T) { + auth := &UsernamePassword{ + Username: "imroc", + Password: "123456", + } + buf := bytes.NewBuffer([]byte{byte(0x01), byte(0x00)}) + err := auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) + tests.AssertNoError(t, err) + auth.Username = "this is a very long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long name" + err = auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) + tests.AssertErrorContains(t, err, "invalid") + + auth.Username = "imroc" + buf = bytes.NewBuffer([]byte{byte(0x03), byte(0x00)}) + err = auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) + tests.AssertErrorContains(t, err, "invalid username/password version") + + buf = bytes.NewBuffer([]byte{byte(0x01), byte(0x02)}) + err = auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) + tests.AssertErrorContains(t, err, "authentication failed") + + err = auth.Authenticate(context.Background(), buf, AuthMethodNoAcceptableMethods) + tests.AssertErrorContains(t, err, "unsupported authentication method") + +} diff --git a/internal/tests/tests.go b/internal/tests/tests.go new file mode 100644 index 00000000..d7d59dac --- /dev/null +++ b/internal/tests/tests.go @@ -0,0 +1,22 @@ +package tests + +import ( + "strings" + "testing" +) + +func AssertNoError(t *testing.T, err error) { + if err != nil { + t.Errorf("Error occurred [%v]", err) + } +} + +func AssertErrorContains(t *testing.T, err error, s string) { + if err == nil { + t.Error("err is nil") + return + } + if !strings.Contains(err.Error(), s) { + t.Errorf("%q is not included in error %q", s, err.Error()) + } +} From 844b30aa167fc7d8bff5d6e49efe137a3e761d78 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 17:25:56 +0800 Subject: [PATCH 369/843] improve dump tests --- h2_transport.go | 33 ++++++++++++++------------------- req_test.go | 23 ++++++++++++++++------- 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/h2_transport.go b/h2_transport.go index 9aa936a2..833bfc5e 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -1270,6 +1270,13 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { return err } + bodyDumps := []*dumper{} + for _, dump := range dumps { + if dump.RequestBody { + bodyDumps = append(bodyDumps, dump) + } + } + hasBody := cs.reqBodyContentLength != 0 if !hasBody { cs.sentEndStream = true @@ -1295,25 +1302,14 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { return err } } - - if len(dumps) > 0 { - dd := []*dumper{} - for _, dump := range dumps { - if dump.RequestBody { - dd = append(dd, dump) - } - } - dumps = dd - } - - if err = cs.writeRequestBody(req, dumps); err != nil { + if err = cs.writeRequestBody(req, bodyDumps); err != nil { if err != http2errStopReqBodyWrite { http2traceWroteRequest(cs.trace, err) return err } } else { cs.sentEndStream = true - for _, dump := range dumps { + for _, dump := range bodyDumps { dump.dump([]byte("\r\n\r\n")) } } @@ -1883,17 +1879,16 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, traceHeaders := http2traceHasWroteHeaderField(trace) writeHeader := cc.writeHeader + headerDumps := []*dumper{} if len(dumps) > 0 { - dd := []*dumper{} for _, dump := range dumps { if dump.RequestHeader { - dd = append(dd, dump) + headerDumps = append(headerDumps, dump) } } - dumps = dd - if len(dumps) > 0 { + if len(headerDumps) > 0 { writeHeader = func(name, value string) { - for _, dump := range dumps { + for _, dump := range headerDumps { dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) } cc.writeHeader(name, value) @@ -1915,7 +1910,7 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, } }) - for _, dump := range dumps { + for _, dump := range headerDumps { dump.dump([]byte("\r\n")) } diff --git a/req_test.go b/req_test.go index 595fa3ad..8684e996 100644 --- a/req_test.go +++ b/req_test.go @@ -450,13 +450,22 @@ func testGlobalWrapperEnableDumps(t *testing.T) { func testGlobalWrapperEnableDump(t *testing.T, fn func(reqHeader, reqBody, respHeader, respBody *bool) *Request) { var reqHeader, reqBody, respHeader, respBody bool r := fn(&reqHeader, &reqBody, &respHeader, &respBody) - resp, err := r.SetBody(`test body`).Post(getTestServerURL() + "/") - assertSuccess(t, resp, err) - dump := resp.Dump() - assertContains(t, dump, "user-agent", reqHeader) - assertContains(t, dump, "test body", reqBody) - assertContains(t, dump, "date", respHeader) - assertContains(t, dump, "testpost: text response", respBody) + dump, ok := r.Context().Value("_dumper").(*dumper) + if !ok { + t.Fatal("no dumper found in request context") + } + if reqHeader != dump.DumpOptions.RequestHeader { + t.Errorf("Unexpected RequestHeader dump option, expected [%v], got [%v]", reqHeader, dump.DumpOptions.RequestHeader) + } + if reqBody != dump.DumpOptions.RequestBody { + t.Errorf("Unexpected RequestBody dump option, expected [%v], got [%v]", reqBody, dump.DumpOptions.RequestBody) + } + if respHeader != dump.DumpOptions.ResponseHeader { + t.Errorf("Unexpected RequestHeader dump option, expected [%v], got [%v]", respHeader, dump.DumpOptions.ResponseHeader) + } + if respBody != dump.DumpOptions.ResponseBody { + t.Errorf("Unexpected RequestHeader dump option, expected [%v], got [%v]", respBody, dump.DumpOptions.ResponseBody) + } } func testGlobalWrapperSendRequest(t *testing.T) { From cbd4f70f3cf8e5e3f84d5313a9e3af4f1232e9de Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 21:01:54 +0800 Subject: [PATCH 370/843] disable concurrency for tests avoid returning RST frame from http2 server when test in high concurrent (default GOMAXPROCS) --- .github/workflows/ci.yml | 2 +- client_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e76f3f84..833e746a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,7 +35,7 @@ jobs: go-version: ${{ matrix.go }} - name: Test - run: go test ./... -coverprofile=coverage.txt + run: go test -p=1 ./... -coverprofile=coverage.txt - name: Coverage env: diff --git a/client_test.go b/client_test.go index 83d7c299..41bee366 100644 --- a/client_test.go +++ b/client_test.go @@ -522,7 +522,7 @@ func testEnableDumpAll(t *testing.T, fn func(c *Client, reqHeader, reqBody, resp } testDump(c) buf = new(bytes.Buffer) - c = tc().EnableDumpAllTo(buf).EnableDumpAllTo(buf).EnableForceHTTP1() + c = tc().EnableDumpAllTo(buf).EnableForceHTTP1() testDump(c) } From 3dc4687ef3a2fc89bf9c995733c24957b98569d8 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 16 Feb 2022 21:37:18 +0800 Subject: [PATCH 371/843] add tests for response --- req_test.go | 12 ++++++++++++ response_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 response_test.go diff --git a/req_test.go b/req_test.go index 8684e996..2009355c 100644 --- a/req_test.go +++ b/req_test.go @@ -197,6 +197,18 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) case "/host-header": w.Write([]byte(r.Host)) + case "/json": + r.ParseForm() + if r.FormValue("type") != "no" { + w.Header().Set(hdrContentTypeKey, jsonContentType) + } + w.Write([]byte(`{"name": "roc"}`)) + case "/xml": + r.ParseForm() + if r.FormValue("type") != "no" { + w.Header().Set(hdrContentTypeKey, xmlContentType) + } + w.Write([]byte(`roc`)) case "/unlimited-redirect": w.Header().Set("Location", "/unlimited-redirect") w.WriteHeader(http.StatusMovedPermanently) diff --git a/response_test.go b/response_test.go new file mode 100644 index 00000000..08e700bf --- /dev/null +++ b/response_test.go @@ -0,0 +1,34 @@ +package req + +import "testing" + +type User struct { + Name string `json:"name" xml:"name"` +} + +func TestUnmarshalJson(t *testing.T) { + var user User + resp, err := tc().R().Get("/json") + assertSuccess(t, resp, err) + err = resp.UnmarshalJson(&user) + assertError(t, err) + assertEqual(t, "roc", user.Name) +} + +func TestUnmarshalXml(t *testing.T) { + var user User + resp, err := tc().R().Get("/xml") + assertSuccess(t, resp, err) + err = resp.UnmarshalXml(&user) + assertError(t, err) + assertEqual(t, "roc", user.Name) +} + +func TestUnmarshal(t *testing.T) { + var user User + resp, err := tc().R().Get("/xml") + assertSuccess(t, resp, err) + err = resp.Unmarshal(&user) + assertError(t, err) + assertEqual(t, "roc", user.Name) +} From 915cbda6a1681e763efa81ce666a3b6c5bc4a8e4 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 17 Feb 2022 13:14:13 +0800 Subject: [PATCH 372/843] remove unused code in transfer.go --- transfer.go | 62 +---------------------------------------------------- 1 file changed, 1 insertion(+), 61 deletions(-) diff --git a/transfer.go b/transfer.go index 3e7f9f94..d78bdace 100644 --- a/transfer.go +++ b/transfer.go @@ -465,22 +465,6 @@ func bodyAllowedForStatus(status int) bool { return true } -var ( - suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"} - suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"} -) - -func suppressedHeaders(status int) []string { - switch { - case status == 304: - // RFC 7232 section 4.1 - return suppressedHeaders304 - case !bodyAllowedForStatus(status): - return suppressedHeadersNoBody - } - return nil -} - // msg is *http.Request or *http.Response. func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t := &transferReader{RequestMethod: "GET"} @@ -498,15 +482,6 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { if rr.Request != nil { t.RequestMethod = rr.Request.Method } - case *http.Request: - t.Header = rr.Header - t.RequestMethod = rr.Method - t.ProtoMajor = rr.ProtoMajor - t.ProtoMinor = rr.ProtoMinor - // Transfer semantics for Requests are exactly like those for - // Responses with status code 200, responding to a GET method - t.StatusCode = 200 - t.Close = rr.Close default: panic("unexpected type") } @@ -545,7 +520,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // and the status is not 1xx, 204 or 304, then the body is unbounded. // See RFC 7230, section 3.3. switch msg.(type) { - case *Response: + case *http.Response: if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) { // Unbounded body. t.Close = true @@ -578,14 +553,6 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // Unify output switch rr := msg.(type) { - case *http.Request: - rr.Body = t.Body - rr.ContentLength = t.ContentLength - if t.Chunked { - rr.TransferEncoding = []string{"chunked"} - } - rr.Close = t.Close - rr.Trailer = t.Trailer case *http.Response: rr.Body = t.Body rr.ContentLength = t.ContentLength @@ -614,13 +581,6 @@ func (uste *unsupportedTEError) Error() string { return uste.err } -// isUnsupportedTEError checks if the error is of type -// unsupportedTEError. It is usually invoked with a non-nil err. -func isUnsupportedTEError(err error) bool { - _, ok := err.(*unsupportedTEError) - return ok -} - // parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. func (t *transferReader) parseTransferEncoding() error { raw, present := t.Header["Transfer-Encoding"] @@ -666,7 +626,6 @@ func (t *transferReader) parseTransferEncoding() error { // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. func fixLength(isResponse bool, status int, requestMethod string, header http.Header, chunked bool) (int64, error) { - isRequest := !isResponse contentLens := header["Content-Length"] // Hardening against HTTP request smuggling @@ -691,13 +650,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header http.He // Logic based on response type or status if noResponseBodyExpected(requestMethod) { - // For HTTP requests, as part of hardening against request - // smuggling (RFC 7230), don't allow a Content-Length header for - // methods which don't permit bodies. As an exception, allow - // exactly one Content-Length header if its value is "0". - if isRequest && len(contentLens) > 0 && !(len(contentLens) == 1 && contentLens[0] == "0") { - return 0, fmt.Errorf("http: method cannot contain a Content-Length; got %q", contentLens) - } return 0, nil } if status/100 == 1 { @@ -727,18 +679,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header http.He } header.Del("Content-Length") - if isRequest { - // RFC 7230 neither explicitly permits nor forbids an - // entity-body on a GET request so we permit one if - // declared, but we default to 0 here (not -1 below) - // if there's no mention of a body. - // Likewise, all other request methods are assumed to have - // no body if neither Transfer-Encoding chunked nor a - // Content-Length are set. - return 0, nil - } - - // Body-EOF logic based on other methods (like closing, or chunked coding) return -1, nil } From 95a30d79d76eb47c7268e0ff308135b0cf62577a Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 17 Feb 2022 16:17:51 +0800 Subject: [PATCH 373/843] optimize tests --- req_test.go | 32 +++++++++++++++++--------------- transport.go | 23 ----------------------- 2 files changed, 17 insertions(+), 38 deletions(-) diff --git a/req_test.go b/req_test.go index 2009355c..1af687e6 100644 --- a/req_test.go +++ b/req_test.go @@ -88,6 +88,7 @@ type echo struct { func handlePost(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": + io.Copy(ioutil.Discard, r.Body) w.Write([]byte("TestPost: text response")) case "/raw-upload": io.Copy(ioutil.Discard, r.Body) @@ -113,6 +114,7 @@ func handlePost(w http.ResponseWriter, r *http.Request) { case "/search": handleSearch(w, r) case "/redirect": + io.Copy(ioutil.Discard, r.Body) w.Header().Set(hdrLocationKey, "/") w.WriteHeader(http.StatusMovedPermanently) case "/echo": @@ -195,6 +197,9 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.Write([]byte("TestGet: text response")) case "/bad-request": w.WriteHeader(http.StatusBadRequest) + case "/chunked": + w.Header().Add("Trailer", "Expires") + w.Write([]byte(`This is a chunked body`)) case "/host-header": w.Write([]byte(r.Host)) case "/json": @@ -259,8 +264,8 @@ func assertStatus(t *testing.T, resp *Response, err error, statusCode int, statu func assertSuccess(t *testing.T, resp *Response, err error) { assertError(t, err) - assertNotNil(t, resp) - assertNotNil(t, resp.Body) + assertNotNil(t, resp.Response) + assertNotNil(t, resp.Response.Body) assertEqual(t, http.StatusOK, resp.StatusCode) assertEqual(t, "200 OK", resp.Status) if !resp.IsSuccess() { @@ -285,13 +290,7 @@ func assertNil(t *testing.T, v interface{}) { func assertNotNil(t *testing.T, v interface{}) { if isNil(v) { - t.Errorf("[%v] was expected to be non-nil", v) - } -} - -func assertType(t *testing.T, typ, v interface{}) { - if reflect.DeepEqual(reflect.TypeOf(typ), reflect.TypeOf(v)) { - t.Errorf("Expected type %t, got %t", typ, v) + t.Fatalf("[%v] was expected to be non-nil", v) } } @@ -370,12 +369,6 @@ func assertEqual(t *testing.T, e, g interface{}) { return } -func removeEmptyString(s string) string { - s = strings.ReplaceAll(s, "\r", "") - s = strings.ReplaceAll(s, "\n", "") - return s -} - func assertNotEqual(t *testing.T, e, g interface{}) (r bool) { if equal(e, g) { t.Errorf("Expected [%v], got [%v]", e, g) @@ -854,3 +847,12 @@ func TestGlobalWrapper(t *testing.T) { DisableDumpAll() assertEqual(t, true, DefaultClient().t.dump == nil) } + +func TestTrailer(t *testing.T) { + resp, err := tc().EnableForceHTTP1().R().Get("/chunked") + assertSuccess(t, resp, err) + _, ok := resp.Trailer["Expires"] + if !ok { + t.Error("trailer not exists") + } +} diff --git a/transport.go b/transport.go index f761db61..e706bc97 100644 --- a/transport.go +++ b/transport.go @@ -37,7 +37,6 @@ import ( "time" "golang.org/x/net/http/httpguts" - "golang.org/x/net/http/httpproxy" ) type HttpVersion string @@ -763,30 +762,8 @@ func (t *Transport) cancelRequest(key cancelKey, err error) bool { return cancel != nil } -// -// Private implementation past this point. -// - -var ( - // proxyConfigOnce guards proxyConfig - envProxyOnce sync.Once - envProxyFuncValue func(*url.URL) (*url.URL, error) -) - -// defaultProxyConfig returns a ProxyConfig value looked up -// from the environment. This mitigates expensive lookups -// on some platforms (e.g. Windows). -func envProxyFunc() func(*url.URL) (*url.URL, error) { - envProxyOnce.Do(func() { - envProxyFuncValue = httpproxy.FromEnvironment().ProxyFunc() - }) - return envProxyFuncValue -} - // resetProxyConfig is used by tests. func resetProxyConfig() { - envProxyOnce = sync.Once{} - envProxyFuncValue = nil } func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { From 11a54b4e9206b98140aff8da3b38d4a800dafc75 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 17 Feb 2022 16:23:37 +0800 Subject: [PATCH 374/843] remove unused dump reader --- dump.go | 18 ------------------ transport_test.go | 7 ------- 2 files changed, 25 deletions(-) diff --git a/dump.go b/dump.go index 816c7b4e..4c2ba354 100644 --- a/dump.go +++ b/dump.go @@ -57,24 +57,6 @@ func (w *dumpWriteCloser) Write(p []byte) (n int, err error) { return } -func (d *dumper) WrapReader(r io.Reader) io.Reader { - return &dumpReader{ - r: r, - dump: d, - } -} - -type dumpReader struct { - r io.Reader - dump *dumper -} - -func (r *dumpReader) Read(p []byte) (n int, err error) { - n, err = r.r.Read(p) - r.dump.dump(p[:n]) - return -} - type dumpWriter struct { w io.Writer dump *dumper diff --git a/transport_test.go b/transport_test.go index 551b134d..38b84a46 100644 --- a/transport_test.go +++ b/transport_test.go @@ -3327,15 +3327,10 @@ func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func } } -func ResetCachedEnvironment() { - resetProxyConfig() -} - func ResetProxyEnv() { for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy", "REQUEST_METHOD"} { os.Unsetenv(v) } - ResetCachedEnvironment() } func TestProxyFromEnvironment(t *testing.T) { @@ -3347,7 +3342,6 @@ func TestProxyFromEnvironment(t *testing.T) { os.Setenv("HTTPS_PROXY", tt.httpsenv) os.Setenv("NO_PROXY", tt.noenv) os.Setenv("REQUEST_METHOD", tt.reqmeth) - ResetCachedEnvironment() return httpproxy.FromEnvironment().ProxyFunc()(req.URL) }) } @@ -3362,7 +3356,6 @@ func TestProxyFromEnvironmentLowerCase(t *testing.T) { os.Setenv("https_proxy", tt.httpsenv) os.Setenv("no_proxy", tt.noenv) os.Setenv("REQUEST_METHOD", tt.reqmeth) - ResetCachedEnvironment() return httpproxy.FromEnvironment().ProxyFunc()(req.URL) }) } From 09c8eceabefd6c8717c055efb524e82d8f7f8009 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 17 Feb 2022 16:45:04 +0800 Subject: [PATCH 375/843] remove unused State() of http2Transport --- h2_transport.go | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/h2_transport.go b/h2_transport.go index 833bfc5e..fbba49ad 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -820,28 +820,6 @@ type http2ClientConnState struct { LastIdle time.Time } -// State returns a snapshot of cc's state. -func (cc *http2ClientConn) State() http2ClientConnState { - cc.wmu.Lock() - maxConcurrent := cc.maxConcurrentStreams - if !cc.seenSettings { - maxConcurrent = 0 - } - cc.wmu.Unlock() - - cc.mu.Lock() - defer cc.mu.Unlock() - return http2ClientConnState{ - Closed: cc.closed, - Closing: cc.closing || cc.singleUse || cc.doNotReuse || cc.goAway != nil, - StreamsActive: len(cc.streams), - StreamsReserved: cc.streamsReserved, - StreamsPending: cc.pendingRequests, - LastIdle: cc.lastIdle, - MaxConcurrentStreams: maxConcurrent, - } -} - // clientConnIdleState describes the suitability of a client // connection to initiate a new RoundTrip request. type http2clientConnIdleState struct { From e804632384aed5baac33c5bf1aafec933f0f7a7a Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 17 Feb 2022 16:51:58 +0800 Subject: [PATCH 376/843] add String() and Blame() tests for TestTraceInfo --- request_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/request_test.go b/request_test.go index 63128f57..c84de65c 100644 --- a/request_test.go +++ b/request_test.go @@ -593,6 +593,17 @@ func testHostHeaderOverride(t *testing.T, c *Client) { func TestTraceInfo(t *testing.T) { testTraceInfo(t, tc()) testTraceInfo(t, tc().EnableForceHTTP1()) + resp, err := tc().R().Get("/") + assertSuccess(t, resp, err) + ti := resp.TraceInfo() + assertContains(t, ti.String(), "not enabled", true) + assertContains(t, ti.Blame(), "not enabled", true) + + resp, err = tc().EnableTraceAll().R().Get("/") + assertSuccess(t, resp, err) + ti = resp.TraceInfo() + assertContains(t, ti.String(), "not enabled", false) + assertContains(t, ti.Blame(), "not enabled", false) } func testTraceInfo(t *testing.T, c *Client) { From c4df28357f242ffb1c7d61be44518e19eed072c8 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 17 Feb 2022 16:58:44 +0800 Subject: [PATCH 377/843] remove IsSpace in ascii package --- internal/ascii/print.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/internal/ascii/print.go b/internal/ascii/print.go index 69f32262..585e5bab 100644 --- a/internal/ascii/print.go +++ b/internal/ascii/print.go @@ -59,7 +59,3 @@ func ToLower(s string) (lower string, ok bool) { } return strings.ToLower(s), true } - -func IsSpace(b byte) bool { - return b == ' ' || b == '\t' || b == '\n' || b == '\r' -} From 5674d5fca22c8c6d37a189bc00eb3c63a50feb52 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 17 Feb 2022 17:26:48 +0800 Subject: [PATCH 378/843] add tests for charsets detect --- client_test.go | 19 ++--- .../charsets/.testdata/HTTP-vs-UTF-8-BOM.html | 48 +++++++++++++ internal/charsets/.testdata/UTF-16BE-BOM.html | Bin 0 -> 2670 bytes internal/charsets/.testdata/UTF-16LE-BOM.html | Bin 0 -> 2682 bytes .../.testdata/UTF-8-BOM-vs-meta-charset.html | 49 +++++++++++++ .../.testdata/UTF-8-BOM-vs-meta-content.html | 48 +++++++++++++ .../.testdata/meta-charset-attribute.html | 48 +++++++++++++ .../.testdata/meta-content-attribute.html | 48 +++++++++++++ internal/charsets/charsets_test.go | 66 ++++++++++++++++++ internal/tests/file.go | 25 +++++++ req_test.go | 29 +++----- request_test.go | 7 +- 12 files changed, 356 insertions(+), 31 deletions(-) create mode 100644 internal/charsets/.testdata/HTTP-vs-UTF-8-BOM.html create mode 100644 internal/charsets/.testdata/UTF-16BE-BOM.html create mode 100644 internal/charsets/.testdata/UTF-16LE-BOM.html create mode 100644 internal/charsets/.testdata/UTF-8-BOM-vs-meta-charset.html create mode 100644 internal/charsets/.testdata/UTF-8-BOM-vs-meta-content.html create mode 100644 internal/charsets/.testdata/meta-charset-attribute.html create mode 100644 internal/charsets/.testdata/meta-content-attribute.html create mode 100644 internal/charsets/charsets_test.go create mode 100644 internal/tests/file.go diff --git a/client_test.go b/client_test.go index 41bee366..1a813651 100644 --- a/client_test.go +++ b/client_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "errors" + "github.com/imroc/req/v3/internal/tests" "io/ioutil" "net" "net/http" @@ -338,12 +339,12 @@ func TestGetTLSClientConfig(t *testing.T) { } func TestSetRootCertFromFile(t *testing.T) { - c := tc().SetRootCertsFromFile(getTestFilePath("sample-root.pem")) + c := tc().SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")) assertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) } func TestSetRootCertFromString(t *testing.T) { - c := tc().SetRootCertFromString(string(getTestFileContent(t, "sample-root.pem"))) + c := tc().SetRootCertFromString(string(tests.GetTestFileContent(t, "sample-root.pem"))) assertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) } @@ -354,8 +355,8 @@ func TestSetCerts(t *testing.T) { func TestSetCertFromFile(t *testing.T) { c := tc().SetCertFromFile( - getTestFilePath("sample-client.pem"), - getTestFilePath("sample-client-key.pem"), + tests.GetTestFilePath("sample-client.pem"), + tests.GetTestFilePath("sample-client-key.pem"), ) assertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 1) } @@ -367,8 +368,8 @@ func TestSetOutputDirectory(t *testing.T) { R().SetOutputFile(outFile). Get("/") assertSuccess(t, resp, err) - content := string(getTestFileContent(t, outFile)) - os.Remove(getTestFilePath(outFile)) + content := string(tests.GetTestFileContent(t, outFile)) + os.Remove(tests.GetTestFilePath(outFile)) assertEqual(t, "TestGet: text response", content) } @@ -548,11 +549,11 @@ func TestSetCommonDumpOptions(t *testing.T) { func TestEnableDumpAllToFile(t *testing.T) { c := tc() dumpFile := "tmp_test_dump_file" - c.EnableDumpAllToFile(getTestFilePath(dumpFile)) + c.EnableDumpAllToFile(tests.GetTestFilePath(dumpFile)) resp, err := c.R().SetBody("test body").Post("/") assertSuccess(t, resp, err) - dump := string(getTestFileContent(t, dumpFile)) - os.Remove(getTestFilePath(dumpFile)) + dump := string(tests.GetTestFileContent(t, dumpFile)) + os.Remove(tests.GetTestFilePath(dumpFile)) assertContains(t, dump, "user-agent", true) assertContains(t, dump, "test body", true) assertContains(t, dump, "date", true) diff --git a/internal/charsets/.testdata/HTTP-vs-UTF-8-BOM.html b/internal/charsets/.testdata/HTTP-vs-UTF-8-BOM.html new file mode 100644 index 00000000..26e5d8b4 --- /dev/null +++ b/internal/charsets/.testdata/HTTP-vs-UTF-8-BOM.html @@ -0,0 +1,48 @@ + + + + HTTP vs UTF-8 BOM + + + + + + + + + + + +

HTTP vs UTF-8 BOM

+ + +
+ + +
 
+ + + + + +
+

A character encoding set in the HTTP header has lower precedence than the UTF-8 signature.

+

The HTTP header attempts to set the character encoding to ISO 8859-15. The page starts with a UTF-8 signature.

The test contains a div with a class name that contains the following sequence of bytes: 0xC3 0xBD 0xC3 0xA4 0xC3 0xA8. These represent different sequences of characters in ISO 8859-15, ISO 8859-1 and UTF-8. The external, UTF-8-encoded stylesheet contains a selector .test div.ýäè. This matches the sequence of bytes above when they are interpreted as UTF-8. If the class name matches the selector then the test will pass.

If the test is unsuccessful, the characters  should appear at the top of the page. These represent the bytes that make up the UTF-8 signature when encountered in the ISO 8859-15 encoding.

+
+
+
HTML5
+

the-input-byte-stream-034
Result summary & related tests
Detailed results for this test
Link to spec

+
Assumptions:
  • The default encoding for the browser you are testing is not set to ISO 8859-15.
  • +
  • The test is read from a server that supports HTTP.
+
+ + + + + + diff --git a/internal/charsets/.testdata/UTF-16BE-BOM.html b/internal/charsets/.testdata/UTF-16BE-BOM.html new file mode 100644 index 0000000000000000000000000000000000000000..3abf7a9343c20518e57dfea58b374fb0f4fb58a1 GIT binary patch literal 2670 zcmcJR?QRoS5Qc}JAoU&=BQ-(7b^;2j8i*i3RV1JlO@;VXIsPurV!WHiDdLW}i`*CO z^UnC>tih=KsVr;H&Y7?C&O3AV(?534uG?e##U9y_y|!QNi4``n+D>d{2lky^LnFNx z?9HrarH$>rwQR_$g)Hk0*&STI*EYq|47~&U9sfUB+ji})9eR{QqCUra7oDsZ5obtB zdxP%<)-$4Q;rSHJiM>U(#ZI=;?n^BC?Dp6lu=~_1-lnX3u03&2BlmQIY>L+!Uq7XoytKw^Q#oZSM?3*J?)&ojG&yzQRkC!Ml5JE?ax;lp_NYEcdUht`ZswOviB~L5hmJ|pXI71nn20w;>vG! zQGB$EE9&wC``&J#_Ym~PgRu-Bd>1!pOp0||k`kr=VJ zfH6I6rmRaeHA7U-A^OTsT+|d2a^i(>DePzZ{)ibXoCBvJnuYrd-3kkN$uy{qQK;=*Y;S87ro12aTgu^i*%f8zC3>a}9DIe4cfxOzsCw&(cqvP9{ud{N6f` z#TNDY(B6@Gpr|uN+%&x^XZjBHdc@2vsM(Tyc2=vshHQ5w+obmp>tuWT(t4BTUGAQw zxeI$UGSLUBg=WFbF;4f@4=^P2AgY@CFn8A`bcC=_&~)fiDe)#cUARRBzJ^k|%X)69 z+{Cb`wq}Rsg%B62CC_tK!AV(W{(MV?#mndR46CU#BUN<{8e?*oT+!pE5wF#O#TR#a z$9qRT)tpbw8zAI~QQJg2C3|6$I%(T(;`zOMy6SO+&;pG=c#2P|P-WZn$$DpWJlC3U z3*nvmz zwP{u~r$L?-m3uqp9I1+#3yE|3M$(s-BEtih=LQ>`qYoiktOop(wi%!;yh%+Rm z{e|xntY<{q!1F1Z6MKtngPm-p-4|H&+3m4AVE3_AyiHm6Tzlf4M(*ht*%YrezJ6kr zHGj45pc?64*$Cm%-zseWMA`x;)v*~jA=i}szqts9xmQkS`M11|(H7bTXAycsXU53+ zJ?120SRZeyiFjW7enPN`bxk$IaWV3o48oJF7D&2ysoY;6(s6%6vVfaYd&mC=erK!) zNGI^7upQgN)53OHe_VE<@J+G8*Y|p*)zB2Thdi}+YR<5QWHm!|a_*AoZXuv7)$xe| zm3Q$D7{|#}{m4X&UY!6(ZhyYi2(5JLzGE$H)W6BQklnjPMwn<Yvv7Z*TVWwD*=E3QpH37* z#lqXJA0A~J9T_<^W5smspmDg2p6ac5Bjn+~LAoow%1TCdZ*$K8`O zw_$HaCi+0N&@7la#_7KL5r$+QL{)Pi=I&aDjt~|Knht#`CEi4*3%97i_fSfASlwUz0=3V0GCxY}z81UC-nP=CGt2OqYV$ zoRCo+qM9YX*3FFORLC=E3B~S@+KROyk4r5 yX7?DaslDfIebqXgC!KKp4IYy+W~X?ddE6o=`A+x#x0AK&6MF#W&AXxbRrv+SX}PNa literal 0 HcmV?d00001 diff --git a/internal/charsets/.testdata/UTF-8-BOM-vs-meta-charset.html b/internal/charsets/.testdata/UTF-8-BOM-vs-meta-charset.html new file mode 100644 index 00000000..83de4333 --- /dev/null +++ b/internal/charsets/.testdata/UTF-8-BOM-vs-meta-charset.html @@ -0,0 +1,49 @@ + + + + UTF-8 BOM vs meta charset + + + + + + + + + + + +

UTF-8 BOM vs meta charset

+ + +
+ + +
 
+ + + + + +
+

A page with a UTF-8 BOM will be recognized as UTF-8 even if the meta charset attribute declares a different encoding.

+

The page contains an encoding declaration in a meta charset attribute that attempts to set the character encoding to ISO 8859-15, but the file starts with a UTF-8 signature.

The test contains a div with a class name that contains the following sequence of bytes: 0xC3 0xBD 0xC3 0xA4 0xC3 0xA8. These represent different sequences of characters in ISO 8859-15, ISO 8859-1 and UTF-8. The external, UTF-8-encoded stylesheet contains a selector .test div.ýäè. This matches the sequence of bytes above when they are interpreted as UTF-8. If the class name matches the selector then the test will pass.

+
+
+
HTML5
+

the-input-byte-stream-038
Result summary & related tests
Detailed results for this test
Link to spec

+
Assumptions:
  • The default encoding for the browser you are testing is not set to ISO 8859-15.
  • +
  • The test is read from a server that supports HTTP.
+
+ + + + + + diff --git a/internal/charsets/.testdata/UTF-8-BOM-vs-meta-content.html b/internal/charsets/.testdata/UTF-8-BOM-vs-meta-content.html new file mode 100644 index 00000000..501aac2d --- /dev/null +++ b/internal/charsets/.testdata/UTF-8-BOM-vs-meta-content.html @@ -0,0 +1,48 @@ + + + + UTF-8 BOM vs meta content + + + + + + + + + + + +

UTF-8 BOM vs meta content

+ + +
+ + +
 
+ + + + + +
+

A page with a UTF-8 BOM will be recognized as UTF-8 even if the meta content attribute declares a different encoding.

+

The page contains an encoding declaration in a meta content attribute that attempts to set the character encoding to ISO 8859-15, but the file starts with a UTF-8 signature.

The test contains a div with a class name that contains the following sequence of bytes: 0xC3 0xBD 0xC3 0xA4 0xC3 0xA8. These represent different sequences of characters in ISO 8859-15, ISO 8859-1 and UTF-8. The external, UTF-8-encoded stylesheet contains a selector .test div.ýäè. This matches the sequence of bytes above when they are interpreted as UTF-8. If the class name matches the selector then the test will pass.

+
+
+
HTML5
+

the-input-byte-stream-037
Result summary & related tests
Detailed results for this test
Link to spec

+
Assumptions:
  • The default encoding for the browser you are testing is not set to ISO 8859-15.
  • +
  • The test is read from a server that supports HTTP.
+
+ + + + + + diff --git a/internal/charsets/.testdata/meta-charset-attribute.html b/internal/charsets/.testdata/meta-charset-attribute.html new file mode 100644 index 00000000..2d7d25ab --- /dev/null +++ b/internal/charsets/.testdata/meta-charset-attribute.html @@ -0,0 +1,48 @@ + + + + meta charset attribute + + + + + + + + + + + +

meta charset attribute

+ + +
+ + +
 
+ + + + + +
+

The character encoding of the page can be set by a meta element with charset attribute.

+

The only character encoding declaration for this HTML file is in the charset attribute of the meta element, which declares the encoding to be ISO 8859-15.

The test contains a div with a class name that contains the following sequence of bytes: 0xC3 0xBD 0xC3 0xA4 0xC3 0xA8. These represent different sequences of characters in ISO 8859-15, ISO 8859-1 and UTF-8. The external, UTF-8-encoded stylesheet contains a selector .test div.ÜÀÚ. This matches the sequence of bytes above when they are interpreted as ISO 8859-15. If the class name matches the selector then the test will pass.

+
+
+
HTML5
+

the-input-byte-stream-009
Result summary & related tests
Detailed results for this test
Link to spec

+
Assumptions:
  • The default encoding for the browser you are testing is not set to ISO 8859-15.
  • +
  • The test is read from a server that supports HTTP.
+
+ + + + + + diff --git a/internal/charsets/.testdata/meta-content-attribute.html b/internal/charsets/.testdata/meta-content-attribute.html new file mode 100644 index 00000000..1c3f228e --- /dev/null +++ b/internal/charsets/.testdata/meta-content-attribute.html @@ -0,0 +1,48 @@ + + + + meta content attribute + + + + + + + + + + + +

meta content attribute

+ + +
+ + +
 
+ + + + + +
+

The character encoding of the page can be set by a meta element with http-equiv and content attributes.

+

The only character encoding declaration for this HTML file is in the content attribute of the meta element, which declares the encoding to be ISO 8859-15.

The test contains a div with a class name that contains the following sequence of bytes: 0xC3 0xBD 0xC3 0xA4 0xC3 0xA8. These represent different sequences of characters in ISO 8859-15, ISO 8859-1 and UTF-8. The external, UTF-8-encoded stylesheet contains a selector .test div.ÜÀÚ. This matches the sequence of bytes above when they are interpreted as ISO 8859-15. If the class name matches the selector then the test will pass.

+
+
+
HTML5
+

the-input-byte-stream-007
Result summary & related tests
Detailed results for this test
Link to spec

+
Assumptions:
  • The default encoding for the browser you are testing is not set to ISO 8859-15.
  • +
  • The test is read from a server that supports HTTP.
+
+ + + + + + diff --git a/internal/charsets/charsets_test.go b/internal/charsets/charsets_test.go new file mode 100644 index 00000000..239e371f --- /dev/null +++ b/internal/charsets/charsets_test.go @@ -0,0 +1,66 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package charsets + +import ( + "github.com/imroc/req/v3/internal/tests" + "io/ioutil" + "runtime" + "testing" +) + +var sniffTestCases = []struct { + filename, want string +}{ + {"UTF-16LE-BOM.html", "utf-16le"}, + {"UTF-16BE-BOM.html", "utf-16be"}, + {"meta-content-attribute.html", "iso-8859-15"}, + {"meta-charset-attribute.html", "iso-8859-15"}, + {"HTTP-vs-UTF-8-BOM.html", "utf-8"}, + {"UTF-8-BOM-vs-meta-content.html", "utf-8"}, + {"UTF-8-BOM-vs-meta-charset.html", "utf-8"}, +} + +func TestSniff(t *testing.T) { + switch runtime.GOOS { + case "nacl": // platforms that don't permit direct file system access + t.Skipf("not supported on %q", runtime.GOOS) + } + + for _, tc := range sniffTestCases { + content, err := ioutil.ReadFile(tests.GetTestFilePath(tc.filename)) + if err != nil { + t.Errorf("%s: error reading file: %v", tc.filename, err) + continue + } + + _, name := FindEncoding(content) + if name != tc.want { + t.Errorf("%s: got %q, want %q", tc.filename, name, tc.want) + continue + } + } +} + +var metaTestCases = []struct { + meta, want string +}{ + {"", ""}, + {"text/html", ""}, + {"text/html; charset utf-8", ""}, + {"text/html; charset=latin-2", "latin-2"}, + {"text/html; charset; charset = utf-8", "utf-8"}, + {`charset="big5"`, "big5"}, + {"charset='shift_jis'", "shift_jis"}, +} + +func TestFromMeta(t *testing.T) { + for _, tc := range metaTestCases { + got := fromMetaElement(tc.meta) + if got != tc.want { + t.Errorf("%q: got %q, want %q", tc.meta, got, tc.want) + } + } +} diff --git a/internal/tests/file.go b/internal/tests/file.go new file mode 100644 index 00000000..9dd0a391 --- /dev/null +++ b/internal/tests/file.go @@ -0,0 +1,25 @@ +package tests + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +var testDataPath string + +func init() { + pwd, _ := os.Getwd() + testDataPath = filepath.Join(pwd, ".testdata") +} + +func GetTestFileContent(t *testing.T, filename string) []byte { + b, err := ioutil.ReadFile(GetTestFilePath(filename)) + AssertNoError(t, err) + return b +} + +func GetTestFilePath(filename string) string { + return filepath.Join(testDataPath, filename) +} diff --git a/req_test.go b/req_test.go index 1af687e6..2bce5fcc 100644 --- a/req_test.go +++ b/req_test.go @@ -8,6 +8,7 @@ import ( "encoding/xml" "errors" "fmt" + "github.com/imroc/req/v3/internal/tests" "go/token" "golang.org/x/text/encoding/simplifiedchinese" "golang.org/x/text/transform" @@ -70,16 +71,6 @@ func getTestServerURL() string { return testServer.URL } -func getTestFileContent(t *testing.T, filename string) []byte { - b, err := ioutil.ReadFile(getTestFilePath(filename)) - assertError(t, err) - return b -} - -func getTestFilePath(filename string) string { - return filepath.Join(testDataPath, filename) -} - type echo struct { Header http.Header `json:"header" xml:"header"` Body string `json:"body" xml:"body"` @@ -229,7 +220,7 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.Header().Set(hdrContentTypeKey, "text/plain; charset=gbk") w.Write(toGbk("我是roc")) case "/gbk-no-charset": - b, err := ioutil.ReadFile(getTestFilePath("sample-gbk.html")) + b, err := ioutil.ReadFile(tests.GetTestFilePath("sample-gbk.html")) if err != nil { panic(err) } @@ -520,17 +511,17 @@ func testGlobalWrapperSendRequest(t *testing.T) { } func TestGlobalWrapperSetRequest(t *testing.T) { - testFilePath := getTestFilePath("sample-file.txt") + testFilePath := tests.GetTestFilePath("sample-file.txt") r := SetFiles(map[string]string{"test": testFilePath}) assertEqual(t, 1, len(r.uploadFiles)) assertEqual(t, true, r.isMultiPart) - r = SetFile("test", getTestFilePath("sample-file.txt")) + r = SetFile("test", tests.GetTestFilePath("sample-file.txt")) assertEqual(t, 1, len(r.uploadFiles)) assertEqual(t, true, r.isMultiPart) SetLogger(nil) - r = SetFile("test", getTestFilePath("file-not-exists.txt")) + r = SetFile("test", tests.GetTestFilePath("file-not-exists.txt")) assertEqual(t, 0, len(r.uploadFiles)) assertEqual(t, false, r.isMultiPart) assertNotNil(t, r.error) @@ -774,18 +765,18 @@ func TestGlobalWrapper(t *testing.T) { config := GetTLSClientConfig() assertEqual(t, config, DefaultClient().t.TLSClientConfig) - SetRootCertsFromFile(getTestFilePath("sample-root.pem")) + SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")) assertEqual(t, true, DefaultClient().t.TLSClientConfig.RootCAs != nil) - SetRootCertFromString(string(getTestFileContent(t, "sample-root.pem"))) + SetRootCertFromString(string(tests.GetTestFileContent(t, "sample-root.pem"))) assertEqual(t, true, DefaultClient().t.TLSClientConfig.RootCAs != nil) SetCerts(tls.Certificate{}, tls.Certificate{}) assertEqual(t, true, len(DefaultClient().t.TLSClientConfig.Certificates) == 2) SetCertFromFile( - getTestFilePath("sample-client.pem"), - getTestFilePath("sample-client-key.pem"), + tests.GetTestFilePath("sample-client.pem"), + tests.GetTestFilePath("sample-client-key.pem"), ) assertEqual(t, true, len(DefaultClient().t.TLSClientConfig.Certificates) == 3) @@ -834,7 +825,7 @@ func TestGlobalWrapper(t *testing.T) { EnableDumpAllToFile(filepath.Join(testDataPath, "path-not-exists", "dump.out")) assertEqual(t, true, DefaultClient().getDumpOptions().Output == nil) - dumpFile := getTestFilePath("tmpdump.out") + dumpFile := tests.GetTestFilePath("tmpdump.out") EnableDumpAllToFile(dumpFile) assertEqual(t, true, DefaultClient().getDumpOptions().Output != nil) os.Remove(dumpFile) diff --git a/request_test.go b/request_test.go index c84de65c..ff6466e7 100644 --- a/request_test.go +++ b/request_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "encoding/xml" "fmt" + "github.com/imroc/req/v3/internal/tests" "net/http" "net/url" "strings" @@ -660,7 +661,7 @@ func testTraceOnTimeout(t *testing.T, c *Client) { } func TestAutoDetectRequestContentType(t *testing.T) { - resp, err := tc().R().SetBody(getTestFileContent(t, "sample-image.png")).Post("/raw-upload") + resp, err := tc().R().SetBody(tests.GetTestFileContent(t, "sample-image.png")).Post("/raw-upload") assertSuccess(t, resp, err) assertEqual(t, "image/png", resp.Request.Headers.Get(hdrContentTypeKey)) } @@ -668,8 +669,8 @@ func TestAutoDetectRequestContentType(t *testing.T) { func TestUploadMultipart(t *testing.T) { m := make(map[string]interface{}) resp, err := tc().R(). - SetFile("file", getTestFilePath("sample-image.png")). - SetFiles(map[string]string{"file": getTestFilePath("sample-file.txt")}). + SetFile("file", tests.GetTestFilePath("sample-image.png")). + SetFiles(map[string]string{"file": tests.GetTestFilePath("sample-file.txt")}). SetFormData(map[string]string{ "param1": "value1", "param2": "value2", From eeaf0c618b261ddbaa4854558686facb628610f5 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 17 Feb 2022 18:52:18 +0800 Subject: [PATCH 379/843] add TestLogger --- client_test.go | 38 +++++++++++++++++++------------------- internal/tests/tests.go | 14 ++++++++++++++ logger_test.go | 19 +++++++++++++++++++ req_test.go | 14 -------------- request_test.go | 32 ++++++++++++++++---------------- 5 files changed, 68 insertions(+), 49 deletions(-) create mode 100644 logger_test.go diff --git a/client_test.go b/client_test.go index 1a813651..3e869a96 100644 --- a/client_test.go +++ b/client_test.go @@ -195,7 +195,7 @@ func TestAutoDecode(t *testing.T) { resp, err = c.SetAutoDecodeAllContentType().R().Get("/gbk-no-charset") assertSuccess(t, resp, err) - assertContains(t, resp.String(), "我是roc", true) + tests.AssertContains(t, resp.String(), "我是roc", true) } func TestSetTimeout(t *testing.T) { @@ -308,27 +308,27 @@ func TestKeepAlives(t *testing.T) { func TestRedirect(t *testing.T) { _, err := tc().SetRedirectPolicy(NoRedirectPolicy()).R().Get("/unlimited-redirect") assertNotNil(t, err) - assertContains(t, err.Error(), "redirect is disabled", true) + tests.AssertContains(t, err.Error(), "redirect is disabled", true) _, err = tc().SetRedirectPolicy(MaxRedirectPolicy(3)).R().Get("/unlimited-redirect") assertNotNil(t, err) - assertContains(t, err.Error(), "stopped after 3 redirects", true) + tests.AssertContains(t, err.Error(), "stopped after 3 redirects", true) _, err = tc().SetRedirectPolicy(SameDomainRedirectPolicy()).R().Get("/redirect-to-other") assertNotNil(t, err) - assertContains(t, err.Error(), "different domain name is not allowed", true) + tests.AssertContains(t, err.Error(), "different domain name is not allowed", true) _, err = tc().SetRedirectPolicy(SameHostRedirectPolicy()).R().Get("/redirect-to-other") assertNotNil(t, err) - assertContains(t, err.Error(), "different host name is not allowed", true) + tests.AssertContains(t, err.Error(), "different host name is not allowed", true) _, err = tc().SetRedirectPolicy(AllowedHostRedirectPolicy("localhost", "127.0.0.1")).R().Get("/redirect-to-other") assertNotNil(t, err) - assertContains(t, err.Error(), "redirect host [dummy.local] is not allowed", true) + tests.AssertContains(t, err.Error(), "redirect host [dummy.local] is not allowed", true) _, err = tc().SetRedirectPolicy(AllowedDomainRedirectPolicy("localhost", "127.0.0.1")).R().Get("/redirect-to-other") assertNotNil(t, err) - assertContains(t, err.Error(), "redirect domain [dummy.local] is not allowed", true) + tests.AssertContains(t, err.Error(), "redirect domain [dummy.local] is not allowed", true) } func TestGetTLSClientConfig(t *testing.T) { @@ -516,10 +516,10 @@ func testEnableDumpAll(t *testing.T, fn func(c *Client, reqHeader, reqBody, resp resp, err := c.R().SetBody(`test body`).Post("/") assertSuccess(t, resp, err) dump := buf.String() - assertContains(t, dump, "user-agent", reqHeader) - assertContains(t, dump, "test body", reqBody) - assertContains(t, dump, "date", respHeader) - assertContains(t, dump, "testpost: text response", respBody) + tests.AssertContains(t, dump, "user-agent", reqHeader) + tests.AssertContains(t, dump, "test body", reqBody) + tests.AssertContains(t, dump, "date", respHeader) + tests.AssertContains(t, dump, "testpost: text response", respBody) } testDump(c) buf = new(bytes.Buffer) @@ -540,10 +540,10 @@ func TestSetCommonDumpOptions(t *testing.T) { c.SetCommonDumpOptions(opt).EnableDumpAll() resp, err := c.R().SetBody("test body").Post("/") assertSuccess(t, resp, err) - assertContains(t, buf.String(), "user-agent", true) - assertContains(t, buf.String(), "test body", false) - assertContains(t, buf.String(), "date", false) - assertContains(t, buf.String(), "testpost: text response", true) + tests.AssertContains(t, buf.String(), "user-agent", true) + tests.AssertContains(t, buf.String(), "test body", false) + tests.AssertContains(t, buf.String(), "date", false) + tests.AssertContains(t, buf.String(), "testpost: text response", true) } func TestEnableDumpAllToFile(t *testing.T) { @@ -554,10 +554,10 @@ func TestEnableDumpAllToFile(t *testing.T) { assertSuccess(t, resp, err) dump := string(tests.GetTestFileContent(t, dumpFile)) os.Remove(tests.GetTestFilePath(dumpFile)) - assertContains(t, dump, "user-agent", true) - assertContains(t, dump, "test body", true) - assertContains(t, dump, "date", true) - assertContains(t, dump, "testpost: text response", true) + tests.AssertContains(t, dump, "user-agent", true) + tests.AssertContains(t, dump, "test body", true) + tests.AssertContains(t, dump, "date", true) + tests.AssertContains(t, dump, "testpost: text response", true) } func TestEnableDumpAllAsync(t *testing.T) { diff --git a/internal/tests/tests.go b/internal/tests/tests.go index d7d59dac..6ff4d3b4 100644 --- a/internal/tests/tests.go +++ b/internal/tests/tests.go @@ -20,3 +20,17 @@ func AssertErrorContains(t *testing.T, err error, s string) { t.Errorf("%q is not included in error %q", s, err.Error()) } } + +func AssertContains(t *testing.T, s, substr string, shouldContain bool) { + s = strings.ToLower(s) + isContain := strings.Contains(s, substr) + if shouldContain { + if !isContain { + t.Errorf("%q is not included in %s", substr, s) + } + } else { + if isContain { + t.Errorf("%q is included in %s", substr, s) + } + } +} diff --git a/logger_test.go b/logger_test.go new file mode 100644 index 00000000..c15dd6f4 --- /dev/null +++ b/logger_test.go @@ -0,0 +1,19 @@ +package req + +import ( + "bytes" + "github.com/imroc/req/v3/internal/tests" + "log" + "testing" +) + +func TestLogger(t *testing.T) { + buf := new(bytes.Buffer) + l := NewLogger(buf, "", log.Ldate|log.Lmicroseconds) + c := tc().SetLogger(l) + c.SetProxyURL(":=\\<>ksfj&*&sf") + tests.AssertContains(t, buf.String(), "error", true) + buf.Reset() + c.R().SetOutput(nil) + tests.AssertContains(t, buf.String(), "warn", true) +} diff --git a/req_test.go b/req_test.go index 2bce5fcc..1ffddea4 100644 --- a/req_test.go +++ b/req_test.go @@ -285,20 +285,6 @@ func assertNotNil(t *testing.T, v interface{}) { } } -func assertContains(t *testing.T, s, substr string, shouldContain bool) { - s = strings.ToLower(s) - isContain := strings.Contains(s, substr) - if shouldContain { - if !isContain { - t.Errorf("%q is not included in %s", substr, s) - } - } else { - if isContain { - t.Errorf("%q is included in %s", substr, s) - } - } -} - func assertError(t *testing.T, err error) { if err != nil { t.Errorf("Error occurred [%v]", err) diff --git a/request_test.go b/request_test.go index ff6466e7..6395aeda 100644 --- a/request_test.go +++ b/request_test.go @@ -139,10 +139,10 @@ func testEnableDump(t *testing.T, fn func(r *Request, reqHeader, reqBody, respHe resp, err := r.SetBody(`test body`).Post("/") assertSuccess(t, resp, err) dump := resp.Dump() - assertContains(t, dump, "user-agent", reqHeader) - assertContains(t, dump, "test body", reqBody) - assertContains(t, dump, "date", respHeader) - assertContains(t, dump, "testpost: text response", respBody) + tests.AssertContains(t, dump, "user-agent", reqHeader) + tests.AssertContains(t, dump, "test body", reqBody) + tests.AssertContains(t, dump, "date", respHeader) + tests.AssertContains(t, dump, "testpost: text response", respBody) } testDump(tc()) testDump(tc().EnableForceHTTP1()) @@ -163,10 +163,10 @@ func testSetDumpOptions(t *testing.T, c *Client) { resp, err := c.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(getTestServerURL()) assertSuccess(t, resp, err) dump := resp.Dump() - assertContains(t, dump, "user-agent", true) - assertContains(t, dump, "test body", false) - assertContains(t, dump, "date", false) - assertContains(t, dump, "testpost: text response", true) + tests.AssertContains(t, dump, "user-agent", true) + tests.AssertContains(t, dump, "test body", false) + tests.AssertContains(t, dump, "date", false) + tests.AssertContains(t, dump, "testpost: text response", true) } func TestGet(t *testing.T) { @@ -597,14 +597,14 @@ func TestTraceInfo(t *testing.T) { resp, err := tc().R().Get("/") assertSuccess(t, resp, err) ti := resp.TraceInfo() - assertContains(t, ti.String(), "not enabled", true) - assertContains(t, ti.Blame(), "not enabled", true) + tests.AssertContains(t, ti.String(), "not enabled", true) + tests.AssertContains(t, ti.Blame(), "not enabled", true) resp, err = tc().EnableTraceAll().R().Get("/") assertSuccess(t, resp, err) ti = resp.TraceInfo() - assertContains(t, ti.String(), "not enabled", false) - assertContains(t, ti.Blame(), "not enabled", false) + tests.AssertContains(t, ti.String(), "not enabled", false) + tests.AssertContains(t, ti.Blame(), "not enabled", false) } func testTraceInfo(t *testing.T, c *Client) { @@ -678,10 +678,10 @@ func TestUploadMultipart(t *testing.T) { SetResult(&m). Post("/multipart") assertSuccess(t, resp, err) - assertContains(t, resp.String(), "sample-image.png", true) - assertContains(t, resp.String(), "sample-file.txt", true) - assertContains(t, resp.String(), "value1", true) - assertContains(t, resp.String(), "value2", true) + tests.AssertContains(t, resp.String(), "sample-image.png", true) + tests.AssertContains(t, resp.String(), "sample-file.txt", true) + tests.AssertContains(t, resp.String(), "value1", true) + tests.AssertContains(t, resp.String(), "value2", true) } func TestFixPragmaCache(t *testing.T) { From 84dd25b992cc0ec48d4403ab0a97c24e81066e95 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 17 Feb 2022 19:06:20 +0800 Subject: [PATCH 380/843] remove unused code in textproto_reader.go --- textproto_reader.go | 367 -------------------------------------------- 1 file changed, 367 deletions(-) diff --git a/textproto_reader.go b/textproto_reader.go index 99e40019..0254ee86 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -9,24 +9,11 @@ import ( "bytes" "fmt" "github.com/imroc/req/v3/internal/util" - "io" - "io/ioutil" "net/textproto" - "strconv" "strings" "sync" ) -// An Error represents a numeric error response from a server. -type codeError struct { - Code int - Msg string -} - -func (e *codeError) Error() string { - return fmt.Sprintf("%03d %s", e.Code, e.Msg) -} - func isASCIILetter(b byte) bool { b |= 0x20 // make lower case return 'a' <= b && b <= 'z' @@ -36,7 +23,6 @@ func isASCIILetter(b byte) bool { // or responses from a text protocol network connection. type textprotoReader struct { R *bufio.Reader - dot *dotReader buf []byte // a re-usable buffer for readContinuedLineSlice readLine func() (line []byte, isPrefix bool, err error) } @@ -96,19 +82,7 @@ func (r *textprotoReader) ReadLine() (string, error) { return string(line), err } -// ReadLineBytes is like ReadLine but returns a []byte instead of a string. -func (r *textprotoReader) ReadLineBytes() ([]byte, error) { - line, err := r.readLineSlice() - if line != nil { - buf := make([]byte, len(line)) - copy(buf, line) - line = buf - } - return line, err -} - func (r *textprotoReader) readLineSlice() ([]byte, error) { - r.closeDot() var line []byte for { @@ -128,30 +102,6 @@ func (r *textprotoReader) readLineSlice() ([]byte, error) { return line, nil } -// ReadContinuedLine reads a possibly continued line from r, -// eliding the final trailing ASCII white space. -// Lines after the first are considered continuations if they -// begin with a space or tab character. In the returned data, -// continuation lines are separated from the previous line -// only by a single space: the newline and leading white space -// are removed. -// -// For example, consider this input: -// -// Line 1 -// continued... -// Line 2 -// -// The first call to ReadContinuedLine will return "Line 1 continued..." -// and the second will return "Line 2". -// -// Empty lines are never continued. -// -func (r *textprotoReader) ReadContinuedLine() (string, error) { - line, err := r.readContinuedLineSlice(noValidation) - return string(line), err -} - // trim returns s with leading and trailing spaces and tabs removed. // It does not assume Unicode or UTF-8. func trim(s []byte) []byte { @@ -166,18 +116,6 @@ func trim(s []byte) []byte { return s[i:n] } -// ReadContinuedLineBytes is like ReadContinuedLine but -// returns a []byte instead of a string. -func (r *textprotoReader) ReadContinuedLineBytes() ([]byte, error) { - line, err := r.readContinuedLineSlice(noValidation) - if line != nil { - buf := make([]byte, len(line)) - copy(buf, line) - line = buf - } - return line, err -} - // readContinuedLineSlice reads continued lines from the reader buffer, // returning a byte slice with all lines. The validateFirstLine function // is run on the first read line, and if it returns an error then this @@ -246,14 +184,6 @@ func (r *textprotoReader) skipSpace() int { return n } -func (r *textprotoReader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) { - line, err := r.ReadLine() - if err != nil { - return - } - return parseCodeLine(line, expectCode) -} - // A protocolError describes a protocol violation such // as an invalid response or a hung-up connection. type protocolError string @@ -262,270 +192,6 @@ func (p protocolError) Error() string { return string(p) } -func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) { - if len(line) < 4 || line[3] != ' ' && line[3] != '-' { - err = protocolError("short response: " + line) - return - } - continued = line[3] == '-' - code, err = strconv.Atoi(line[0:3]) - if err != nil || code < 100 { - err = protocolError("invalid response code: " + line) - return - } - message = line[4:] - if 1 <= expectCode && expectCode < 10 && code/100 != expectCode || - 10 <= expectCode && expectCode < 100 && code/10 != expectCode || - 100 <= expectCode && expectCode < 1000 && code != expectCode { - err = &codeError{code, message} - } - return -} - -// ReadCodeLine reads a response code line of the form -// code message -// where code is a three-digit status code and the message -// extends to the rest of the line. An example of such a line is: -// 220 plan9.bell-labs.com ESMTP -// -// If the prefix of the status does not match the digits in expectCode, -// ReadCodeLine returns with err set to &codeError{code, message}. -// For example, if expectCode is 31, an error will be returned if -// the status is not in the range [310,319]. -// -// If the response is multi-line, ReadCodeLine returns an error. -// -// An expectCode <= 0 disables the check of the status code. -// -func (r *textprotoReader) ReadCodeLine(expectCode int) (code int, message string, err error) { - code, continued, message, err := r.readCodeLine(expectCode) - if err == nil && continued { - err = protocolError("unexpected multi-line response: " + message) - } - return -} - -// ReadResponse reads a multi-line response of the form: -// -// code-message line 1 -// code-message line 2 -// ... -// code message line n -// -// where code is a three-digit status code. The first line starts with the -// code and a hyphen. The response is terminated by a line that starts -// with the same code followed by a space. Each line in message is -// separated by a newline (\n). -// -// See page 36 of RFC 959 (https://www.ietf.org/rfc/rfc959.txt) for -// details of another form of response accepted: -// -// code-message line 1 -// message line 2 -// ... -// code message line n -// -// If the prefix of the status does not match the digits in expectCode, -// ReadResponse returns with err set to &codeError{code, message}. -// For example, if expectCode is 31, an error will be returned if -// the status is not in the range [310,319]. -// -// An expectCode <= 0 disables the check of the status code. -// -func (r *textprotoReader) ReadResponse(expectCode int) (code int, message string, err error) { - code, continued, message, err := r.readCodeLine(expectCode) - multi := continued - for continued { - line, err := r.ReadLine() - if err != nil { - return 0, "", err - } - - var code2 int - var moreMessage string - code2, continued, moreMessage, err = parseCodeLine(line, 0) - if err != nil || code2 != code { - message += "\n" + strings.TrimRight(line, "\r\n") - continued = true - continue - } - message += "\n" + moreMessage - } - if err != nil && multi && message != "" { - // replace one line error message with all lines (full message) - err = &codeError{code, message} - } - return -} - -// DotReader returns a new textprotoReader that satisfies Reads using the -// decoded text of a dot-encoded block read from r. -// The returned textprotoReader is only valid until the next call -// to a method on r. -// -// Dot encoding is a common framing used for data blocks -// in text protocols such as SMTP. The data consists of a sequence -// of lines, each of which ends in "\r\n". The sequence itself -// ends at a line containing just a dot: ".\r\n". Lines beginning -// with a dot are escaped with an additional dot to avoid -// looking like the end of the sequence. -// -// The decoded form returned by the textprotoReader's Read method -// rewrites the "\r\n" line endings into the simpler "\n", -// removes leading dot escapes if present, and stops with error io.EOF -// after consuming (and discarding) the end-of-sequence line. -func (r *textprotoReader) DotReader() io.Reader { - r.closeDot() - r.dot = &dotReader{r: r} - return r.dot -} - -type dotReader struct { - r *textprotoReader - state int -} - -// Read satisfies reads by decoding dot-encoded data read from d.r. -func (d *dotReader) Read(b []byte) (n int, err error) { - // Run data through a simple state machine to - // elide leading dots, rewrite trailing \r\n into \n, - // and detect ending .\r\n line. - const ( - stateBeginLine = iota // beginning of line; initial state; must be zero - stateDot // read . at beginning of line - stateDotCR // read .\r at beginning of line - stateCR // read \r (possibly at end of line) - stateData // reading data in middle of line - stateEOF // reached .\r\n end marker line - ) - br := d.r.R - for n < len(b) && d.state != stateEOF { - var c byte - c, err = br.ReadByte() - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - break - } - switch d.state { - case stateBeginLine: - if c == '.' { - d.state = stateDot - continue - } - if c == '\r' { - d.state = stateCR - continue - } - d.state = stateData - - case stateDot: - if c == '\r' { - d.state = stateDotCR - continue - } - if c == '\n' { - d.state = stateEOF - continue - } - d.state = stateData - - case stateDotCR: - if c == '\n' { - d.state = stateEOF - continue - } - // Not part of .\r\n. - // Consume leading dot and emit saved \r. - br.UnreadByte() - c = '\r' - d.state = stateData - - case stateCR: - if c == '\n' { - d.state = stateBeginLine - break - } - // Not part of \r\n. Emit saved \r - br.UnreadByte() - c = '\r' - d.state = stateData - - case stateData: - if c == '\r' { - d.state = stateCR - continue - } - if c == '\n' { - d.state = stateBeginLine - } - } - b[n] = c - n++ - } - if err == nil && d.state == stateEOF { - err = io.EOF - } - if err != nil && d.r.dot == d { - d.r.dot = nil - } - return -} - -// closeDot drains the current DotReader if any, -// making sure that it reads until the ending dot line. -func (r *textprotoReader) closeDot() { - if r.dot == nil { - return - } - buf := make([]byte, 128) - for r.dot != nil { - // When Read reaches EOF or an error, - // it will set r.dot == nil. - r.dot.Read(buf) - } -} - -// ReadDotBytes reads a dot-encoding and returns the decoded data. -// -// See the documentation for the DotReader method for details about dot-encoding. -func (r *textprotoReader) ReadDotBytes() ([]byte, error) { - return ioutil.ReadAll(r.DotReader()) -} - -// ReadDotLines reads a dot-encoding and returns a slice -// containing the decoded lines, with the final \r\n or \n elided from each. -// -// See the documentation for the DotReader method for details about dot-encoding. -func (r *textprotoReader) ReadDotLines() ([]string, error) { - // We could use ReadDotBytes and then Split it, - // but reading a line at a time avoids needing a - // large contiguous block of memory and is simpler. - var v []string - var err error - for { - var line string - line, err = r.ReadLine() - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - break - } - - // Dot by itself marks end; otherwise cut one dot. - if len(line) > 0 && line[0] == '.' { - if len(line) == 1 { - break - } - line = line[1:] - } - v = append(v, line) - } - return v, err -} - var colon = []byte(":") // ReadMIMEHeader reads a MIME-style header from r. @@ -611,10 +277,6 @@ func (r *textprotoReader) ReadMIMEHeader() (textproto.MIMEHeader, error) { } } -// noValidation is a no-op validation func for readContinuedLineSlice -// that permits any lines. -func noValidation(_ []byte) error { return nil } - // mustHaveFieldNameColon ensures that, per RFC 7230, the // field-name is on a single line, so the first line must // contain a colon. @@ -640,35 +302,6 @@ func (r *textprotoReader) upcomingHeaderNewlines() (n int) { return bytes.Count(peek, nl) } -// CanonicalMIMEHeaderKey returns the canonical format of the -// MIME header key s. The canonicalization converts the first -// letter and any letter following a hyphen to upper case; -// the rest are converted to lowercase. For example, the -// canonical key for "accept-encoding" is "Accept-Encoding". -// MIME header keys are assumed to be ASCII only. -// If s contains a space or invalid header field bytes, it is -// returned without modifications. -func CanonicalMIMEHeaderKey(s string) string { - commonHeaderOnce.Do(initCommonHeader) - - // Quick check for canonical encoding. - upper := true - for i := 0; i < len(s); i++ { - c := s[i] - if !validHeaderFieldByte(c) { - return s - } - if upper && 'a' <= c && c <= 'z' { - return canonicalMIMEHeaderKey([]byte(s)) - } - if !upper && 'A' <= c && c <= 'Z' { - return canonicalMIMEHeaderKey([]byte(s)) - } - upper = c == '-' - } - return s -} - const toLower = 'a' - 'A' // validHeaderFieldByte reports whether b is a valid byte in a header From e7b0cf92c4d054477409f514005f6df7fcc4625c Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 17 Feb 2022 20:58:01 +0800 Subject: [PATCH 381/843] add trailer test for request --- req_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/req_test.go b/req_test.go index 1ffddea4..20f4b7b6 100644 --- a/req_test.go +++ b/req_test.go @@ -832,4 +832,9 @@ func TestTrailer(t *testing.T) { if !ok { t.Error("trailer not exists") } + r := tc().EnableForceHTTP1().R() + r.RawRequest.Trailer = make(http.Header) + r.RawRequest.Trailer.Add("test", "") + resp, err = r.SetBody("test").Post("/") + assertSuccess(t, resp, err) } From 60fe9b84f79c01f1f7474432af4eaa301ddec392 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 17 Feb 2022 21:25:12 +0800 Subject: [PATCH 382/843] add TestParseSettingsFrame --- h2_frame_test.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/h2_frame_test.go b/h2_frame_test.go index 8a8d2a98..2f75abd8 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -7,6 +7,7 @@ package req import ( "bytes" "fmt" + "github.com/imroc/req/v3/internal/tests" "io" "reflect" "strings" @@ -1282,3 +1283,24 @@ func TestSettingsDuplicates(t *testing.T) { } } + +func TestParseSettingsFrame(t *testing.T) { + fh := http2FrameHeader{} + fh.Flags = http2FlagSettingsAck + fh.Length = 1 + countErr := func(s string) {} + _, err := http2parseSettingsFrame(nil, fh, countErr, nil) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") + + fh = http2FrameHeader{StreamID: 1} + _, err = http2parseSettingsFrame(nil, fh, countErr, nil) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + + fh = http2FrameHeader{} + _, err = http2parseSettingsFrame(nil, fh, countErr, []byte("roc")) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") + + fh = http2FrameHeader{valid: true} + _, err = http2parseSettingsFrame(nil, fh, countErr, []byte("rocroc")) + tests.AssertNoError(t, err) +} From ba1cb690b94d5cf8099c33a56d13b70a5d66958a Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 10:18:48 +0800 Subject: [PATCH 383/843] add more response tests --- req_test.go | 8 +++++++- request_test.go | 1 + response_test.go | 25 +++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/req_test.go b/req_test.go index 20f4b7b6..6980e8d4 100644 --- a/req_test.go +++ b/req_test.go @@ -198,7 +198,13 @@ func handleGet(w http.ResponseWriter, r *http.Request) { if r.FormValue("type") != "no" { w.Header().Set(hdrContentTypeKey, jsonContentType) } - w.Write([]byte(`{"name": "roc"}`)) + w.Header().Set(hdrContentTypeKey, jsonContentType) + if r.FormValue("error") == "yes" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"message": "not allowed"}`)) + } else { + w.Write([]byte(`{"name": "roc"}`)) + } case "/xml": r.ParseForm() if r.FormValue("type") != "no" { diff --git a/request_test.go b/request_test.go index 6395aeda..8f10146d 100644 --- a/request_test.go +++ b/request_test.go @@ -605,6 +605,7 @@ func TestTraceInfo(t *testing.T) { ti = resp.TraceInfo() tests.AssertContains(t, ti.String(), "not enabled", false) tests.AssertContains(t, ti.Blame(), "not enabled", false) + assertEqual(t, true, resp.TotalTime() > 0) } func testTraceInfo(t *testing.T, c *Client) { diff --git a/response_test.go b/response_test.go index 08e700bf..a80a6d85 100644 --- a/response_test.go +++ b/response_test.go @@ -6,6 +6,10 @@ type User struct { Name string `json:"name" xml:"name"` } +type Message struct { + Message string `json:"message"` +} + func TestUnmarshalJson(t *testing.T) { var user User resp, err := tc().R().Get("/json") @@ -32,3 +36,24 @@ func TestUnmarshal(t *testing.T) { assertError(t, err) assertEqual(t, "roc", user.Name) } + +func TestResponseResult(t *testing.T) { + resp, _ := tc().R().SetResult(&User{}).Get("/json") + user, ok := resp.Result().(*User) + if !ok { + t.Fatal("Response.Result() should return *User") + } + assertEqual(t, "roc", user.Name) + + assertEqual(t, true, resp.TotalTime() > 0) + assertEqual(t, false, resp.ReceivedAt().IsZero()) +} + +func TestResponseError(t *testing.T) { + resp, _ := tc().R().SetError(&Message{}).Get("/json?error=yes") + msg, ok := resp.Error().(*Message) + if !ok { + t.Fatal("Response.Error() should return *Message") + } + assertEqual(t, "not allowed", msg.Message) +} From 2a04cdb4436dae9f7243dfbea5c5dc6e3f7d5454 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 10:32:30 +0800 Subject: [PATCH 384/843] add TestSummarizeFrame --- h2_frame_test.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/h2_frame_test.go b/h2_frame_test.go index 2f75abd8..c7aae139 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1304,3 +1304,31 @@ func TestParseSettingsFrame(t *testing.T) { _, err = http2parseSettingsFrame(nil, fh, countErr, []byte("rocroc")) tests.AssertNoError(t, err) } + +func TestSummarizeFrame(t *testing.T) { + fh := http2FrameHeader{valid: true} + var f http2Frame + f = &http2SettingsFrame{http2FrameHeader: fh} + s := http2summarizeFrame(f) + tests.AssertContains(t, s, "len=0", true) + + f = &http2DataFrame{http2FrameHeader: fh} + s = http2summarizeFrame(f) + tests.AssertContains(t, s, `data=""`, true) + + f = &http2WindowUpdateFrame{http2FrameHeader: fh} + s = http2summarizeFrame(f) + tests.AssertContains(t, s, "onn", true) + + f = &http2PingFrame{http2FrameHeader: fh} + s = http2summarizeFrame(f) + tests.AssertContains(t, s, "ping", true) + + f = &http2GoAwayFrame{http2FrameHeader: fh} + s = http2summarizeFrame(f) + tests.AssertContains(t, s, "laststreamid", true) + + f = &http2RSTStreamFrame{http2FrameHeader: fh} + s = http2summarizeFrame(f) + tests.AssertContains(t, s, "no_error", true) +} From cbe63fee358c7083107d09f4446d982c01bc4322 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 11:31:31 +0800 Subject: [PATCH 385/843] add TestParsePushPromise --- h2_frame_test.go | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/h2_frame_test.go b/h2_frame_test.go index c7aae139..13623ced 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1305,10 +1305,29 @@ func TestParseSettingsFrame(t *testing.T) { tests.AssertNoError(t, err) } +func TestParsePushPromise(t *testing.T) { + fh := http2FrameHeader{} + countError := func(string) {} + _, err := http2parsePushPromise(nil, fh, countError, nil) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + + fh.StreamID = 1 + fh.Flags = http2FlagPushPromisePadded + _, err = http2parsePushPromise(nil, fh, countError, nil) + tests.AssertErrorContains(t, err, "EOF") + + fh.Flags = 0 + _, err = http2parsePushPromise(nil, fh, countError, nil) + tests.AssertErrorContains(t, err, "EOF") + + _, err = http2parsePushPromise(nil, fh, countError, []byte("ksjfksjksjflskk")) + tests.AssertNoError(t, err) +} + func TestSummarizeFrame(t *testing.T) { fh := http2FrameHeader{valid: true} var f http2Frame - f = &http2SettingsFrame{http2FrameHeader: fh} + f = &http2SettingsFrame{http2FrameHeader: fh, p: []byte{0x09, 0x01, 0x80, 0x20, 0x00, 0x11}} s := http2summarizeFrame(f) tests.AssertContains(t, s, "len=0", true) @@ -1318,7 +1337,7 @@ func TestSummarizeFrame(t *testing.T) { f = &http2WindowUpdateFrame{http2FrameHeader: fh} s = http2summarizeFrame(f) - tests.AssertContains(t, s, "onn", true) + tests.AssertContains(t, s, "conn", true) f = &http2PingFrame{http2FrameHeader: fh} s = http2summarizeFrame(f) From da57192112ff1bcd19976f793387ea94d6ef2cc3 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 11:56:35 +0800 Subject: [PATCH 386/843] add TestH2Framer --- h2_frame_test.go | 7 +++++++ middleware.go | 6 ------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/h2_frame_test.go b/h2_frame_test.go index 13623ced..f94012ce 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1351,3 +1351,10 @@ func TestSummarizeFrame(t *testing.T) { s = http2summarizeFrame(f) tests.AssertContains(t, s, "no_error", true) } + +func TestH2Framer(t *testing.T) { + f := &http2Framer{} + f.debugWriteLoggerf = func(s string, i ...interface{}) {} + f.logWrite() + assertNotNil(t, f.debugFramer) +} diff --git a/middleware.go b/middleware.go index e09adb5c..d092108a 100644 --- a/middleware.go +++ b/middleware.go @@ -22,12 +22,6 @@ type ( ResponseMiddleware func(*Client, *Response) error ) -var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") - -func escapeQuotes(s string) string { - return quoteEscaper.Replace(s) -} - func createMultipartHeader(file *FileUpload, contentType string) textproto.MIMEHeader { hdr := make(textproto.MIMEHeader) From e630abf72bbf19d32b187cef17b3bbf5a487ec96 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 13:06:49 +0800 Subject: [PATCH 387/843] add TestParseDataFrame --- h2_frame_test.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/h2_frame_test.go b/h2_frame_test.go index f94012ce..d1d5ac55 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1352,9 +1352,28 @@ func TestSummarizeFrame(t *testing.T) { tests.AssertContains(t, s, "no_error", true) } +func TestParseDataFrame(t *testing.T) { + fh := http2FrameHeader{valid: true} + countError := func(string) {} + _, err := http2parseDataFrame(nil, fh, countError, nil) + tests.AssertErrorContains(t, err, "DATA frame with stream ID 0") + + fh.StreamID = 1 + fh.Flags = http2FlagDataPadded + fc := &http2frameCache{} + payload := []byte{0x09, 0x00, 0x00, 0x98, 0x11, 0x12} + _, err = http2parseDataFrame(fc, fh, countError, payload) + tests.AssertErrorContains(t, err, "pad size larger than data payload") + + payload = []byte{0x02, 0x00, 0x00, 0x98, 0x11, 0x12} + _, err = http2parseDataFrame(fc, fh, countError, payload) + tests.AssertNoError(t, err) +} + func TestH2Framer(t *testing.T) { f := &http2Framer{} f.debugWriteLoggerf = func(s string, i ...interface{}) {} f.logWrite() assertNotNil(t, f.debugFramer) + assertNil(t, f.ErrorDetail()) } From 8ffbc936dd10b27bb7906d41b476a48ded66b2d9 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 13:13:22 +0800 Subject: [PATCH 388/843] add TestParseWindowUpdateFrame --- h2_frame_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/h2_frame_test.go b/h2_frame_test.go index d1d5ac55..ce596396 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1370,6 +1370,23 @@ func TestParseDataFrame(t *testing.T) { tests.AssertNoError(t, err) } +func TestParseWindowUpdateFrame(t *testing.T) { + fh := http2FrameHeader{valid: true} + countError := func(string) {} + _, err := http2parseWindowUpdateFrame(nil, fh, countError, nil) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") + + p := []byte{0x00, 0x00, 0x00, 0x00} + _, err = http2parseWindowUpdateFrame(nil, fh, countError, p) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + + fh.StreamID = 255 + p[0] = 0x01 + p[3] = 0x01 + _, err = http2parseWindowUpdateFrame(nil, fh, countError, p) + tests.AssertNoError(t, err) +} + func TestH2Framer(t *testing.T) { f := &http2Framer{} f.debugWriteLoggerf = func(s string, i ...interface{}) {} From cb5c31fb4208b514987c8db58b31a6fbb826632b Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 13:21:16 +0800 Subject: [PATCH 389/843] add TestParseUnknownFrame --- h2_frame_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/h2_frame_test.go b/h2_frame_test.go index ce596396..c92e44cd 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1387,6 +1387,19 @@ func TestParseWindowUpdateFrame(t *testing.T) { tests.AssertNoError(t, err) } +func TestParseUnknownFrame(t *testing.T) { + fh := http2FrameHeader{valid: true} + countError := func(string) {} + p := []byte("test") + f, err := http2parseUnknownFrame(nil, fh, countError, p) + tests.AssertNoError(t, err) + uf, ok := f.(*http2UnknownFrame) + if !ok { + t.Fatalf("not http2UnknownFrame type: %#+v", f) + } + assertEqual(t, p, uf.Payload()) +} + func TestH2Framer(t *testing.T) { f := &http2Framer{} f.debugWriteLoggerf = func(s string, i ...interface{}) {} From 23404045772ecd3f36ca4eefb3eb7f83c0b166a0 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 13:24:53 +0800 Subject: [PATCH 390/843] add TestPushPromiseFrame --- h2_frame_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/h2_frame_test.go b/h2_frame_test.go index c92e44cd..cc3c52a5 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1400,6 +1400,14 @@ func TestParseUnknownFrame(t *testing.T) { assertEqual(t, p, uf.Payload()) } +func TestPushPromiseFrame(t *testing.T) { + fh := http2FrameHeader{valid: true} + buf := []byte("test") + f := &http2PushPromiseFrame{http2FrameHeader: fh, headerFragBuf: buf} + assertEqual(t, buf, f.HeaderBlockFragment()) + assertEqual(t, false, f.HeadersEnded()) +} + func TestH2Framer(t *testing.T) { f := &http2Framer{} f.debugWriteLoggerf = func(s string, i ...interface{}) {} From 5c46b21b815a7d91d8be7d314f6165d150f89135 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 13:27:52 +0800 Subject: [PATCH 391/843] add TestParseRSTStreamFrame --- h2_frame_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/h2_frame_test.go b/h2_frame_test.go index cc3c52a5..6ccca666 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1400,6 +1400,22 @@ func TestParseUnknownFrame(t *testing.T) { assertEqual(t, p, uf.Payload()) } +func TestParseRSTStreamFrame(t *testing.T) { + fh := http2FrameHeader{valid: true} + countError := func(string) {} + p := []byte("test.") + _, err := http2parseRSTStreamFrame(nil, fh, countError, p) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") + + p = []byte("test") + _, err = http2parseRSTStreamFrame(nil, fh, countError, p) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + + fh.StreamID = 1 + _, err = http2parseRSTStreamFrame(nil, fh, countError, p) + tests.AssertNoError(t, err) +} + func TestPushPromiseFrame(t *testing.T) { fh := http2FrameHeader{valid: true} buf := []byte("test") From f1765abfb6e2c076ebd382ec605e884d1736fe87 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 13:36:08 +0800 Subject: [PATCH 392/843] add TestH2Framer --- h2_frame_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/h2_frame_test.go b/h2_frame_test.go index 6ccca666..27fa3a9d 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1430,4 +1430,19 @@ func TestH2Framer(t *testing.T) { f.logWrite() assertNotNil(t, f.debugFramer) assertNil(t, f.ErrorDetail()) + + f.w = new(bytes.Buffer) + err := f.WriteRawFrame(http2FrameData, http2FlagDataEndStream, 1, nil) + tests.AssertNoError(t, err) + + param := http2PushPromiseParam{} + err = f.WritePushPromise(param) + tests.AssertErrorContains(t, err, "invalid stream ID") + + param.StreamID = 1 + param.EndHeaders = true + param.PadLength = 2 + f.AllowIllegalWrites = true + err = f.WritePushPromise(param) + tests.AssertNoError(t, err) } From 9f8b4f155fa3da458eb08026e66a4a6b2cb9a213 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 13:41:12 +0800 Subject: [PATCH 393/843] add TestParsePingFrame --- h2_frame_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/h2_frame_test.go b/h2_frame_test.go index 27fa3a9d..498f62bb 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1416,6 +1416,23 @@ func TestParseRSTStreamFrame(t *testing.T) { tests.AssertNoError(t, err) } +func TestParsePingFrame(t *testing.T) { + fh := http2FrameHeader{valid: true} + countError := func(string) {} + payload := []byte("") + _, err := http2parsePingFrame(nil, fh, countError, payload) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") + + payload = []byte("testtest") + fh.StreamID = 1 + _, err = http2parsePingFrame(nil, fh, countError, payload) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + + fh.StreamID = 0 + _, err = http2parsePingFrame(nil, fh, countError, payload) + tests.AssertNoError(t, err) +} + func TestPushPromiseFrame(t *testing.T) { fh := http2FrameHeader{valid: true} buf := []byte("test") From b6821fb7105080a02bf5cf5577c8f63d02bd9731 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 13:44:26 +0800 Subject: [PATCH 394/843] add TestParseGoAwayFrame --- h2_frame_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/h2_frame_test.go b/h2_frame_test.go index 498f62bb..c5f2ed6c 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1433,6 +1433,20 @@ func TestParsePingFrame(t *testing.T) { tests.AssertNoError(t, err) } +func TestParseGoAwayFrame(t *testing.T) { + fh := http2FrameHeader{valid: true} + countError := func(string) {} + payload := []byte("") + + fh.StreamID = 1 + _, err := http2parseGoAwayFrame(nil, fh, countError, payload) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + + fh.StreamID = 0 + _, err = http2parseGoAwayFrame(nil, fh, countError, payload) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") +} + func TestPushPromiseFrame(t *testing.T) { fh := http2FrameHeader{valid: true} buf := []byte("test") From 8b6ba92fe02be8326581a8647e37fd350a968133 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 14:01:37 +0800 Subject: [PATCH 395/843] add TestCountReadFrameError --- h2_transport_test.go | 34 ++++++++++++++++++++++++++++++++++ req_test.go | 2 +- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/h2_transport_test.go b/h2_transport_test.go index 78bca258..37003e4e 100644 --- a/h2_transport_test.go +++ b/h2_transport_test.go @@ -13,6 +13,7 @@ import ( "errors" "flag" "fmt" + "github.com/imroc/req/v3/internal/tests" "io" "io/ioutil" "log" @@ -5853,3 +5854,36 @@ func TestTransportSlowWrites(t *testing.T) { } resp.Body.Close() } + +func TestCountReadFrameError(t *testing.T) { + cc := &http2ClientConn{} + errMsg := "" + countError := func(errType string) { + errMsg = errType + } + cc.t = &http2Transport{CountError: countError} + + var err error + cc.countReadFrameError(err) + assertEqual(t, "", errMsg) + + err = http2ConnectionError(http2ErrCodeInternal) + cc.countReadFrameError(err) + tests.AssertContains(t, errMsg, "read_frame_conn_error", true) + + err = io.EOF + cc.countReadFrameError(err) + tests.AssertContains(t, errMsg, "read_frame_eof", true) + + err = io.ErrUnexpectedEOF + cc.countReadFrameError(err) + tests.AssertContains(t, errMsg, "read_frame_unexpected_eof", true) + + err = http2ErrFrameTooLarge + cc.countReadFrameError(err) + tests.AssertContains(t, errMsg, "read_frame_too_large", true) + + err = errors.New("other") + cc.countReadFrameError(err) + tests.AssertContains(t, errMsg, "read_frame_other", true) +} diff --git a/req_test.go b/req_test.go index 6980e8d4..540b20ac 100644 --- a/req_test.go +++ b/req_test.go @@ -832,7 +832,7 @@ func TestGlobalWrapper(t *testing.T) { } func TestTrailer(t *testing.T) { - resp, err := tc().EnableForceHTTP1().R().Get("/chunked") + resp, err := tc().EnableForceHTTP1().R().EnableDump().Get("/chunked") assertSuccess(t, resp, err) _, ok := resp.Trailer["Expires"] if !ok { From d9d595edd7c5493e7ba7c78c318421f2a3ff8b9b Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 14:16:13 +0800 Subject: [PATCH 396/843] add TestProcessHeaders --- h2_transport_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/h2_transport_test.go b/h2_transport_test.go index 37003e4e..2a62d601 100644 --- a/h2_transport_test.go +++ b/h2_transport_test.go @@ -5887,3 +5887,19 @@ func TestCountReadFrameError(t *testing.T) { cc.countReadFrameError(err) tests.AssertContains(t, errMsg, "read_frame_other", true) } + +func TestProcessHeaders(t *testing.T) { + rl := &http2clientConnReadLoop{} + cc := &http2ClientConn{streams: map[uint32]*http2clientStream{}} + cc.streams[1] = &http2clientStream{cc: cc, abort: make(chan struct{})} + rl.cc = cc + f := &http2MetaHeadersFrame{http2HeadersFrame: &http2HeadersFrame{ + http2FrameHeader: http2FrameHeader{StreamID: 1}, + }} + err := rl.processHeaders(f) + tests.AssertNoError(t, err) + + f.StreamID = 0 + err = rl.processHeaders(f) + tests.AssertNoError(t, err) +} From 4b755bd02c09116623ae1c698ff7b50979928d18 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 14:52:22 +0800 Subject: [PATCH 397/843] add TestSettingValid --- h2_test.go | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/h2_test.go b/h2_test.go index 0d2ca500..526d06cc 100644 --- a/h2_test.go +++ b/h2_test.go @@ -78,15 +78,28 @@ func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { return false } -// waitErrCondition is like waitCondition but with errors instead of bools. -func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error { - deadline := time.Now().Add(waitFor) - var err error - for time.Now().Before(deadline) { - if err = fn(); err == nil { - return nil - } - time.Sleep(checkEvery) +func TestSettingValid(t *testing.T) { + cases := []struct { + id http2SettingID + val uint32 + }{ + { + id: http2SettingEnablePush, + val: 2, + }, + { + id: http2SettingInitialWindowSize, + val: 1 << 31, + }, + { + id: http2SettingMaxFrameSize, + val: 0, + }, + } + for _, c := range cases { + s := &http2Setting{ID: c.id, Val: c.val} + assertEqual(t, true, s.Valid() != nil) } - return err + s := &http2Setting{ID: http2SettingMaxHeaderListSize} + assertEqual(t, true, s.Valid() == nil) } From c775e9cc8d26ffe3c4102dfb84ec345e70e26e30 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 14:59:19 +0800 Subject: [PATCH 398/843] add TestBodyAllowedForStatus and TestHttpError --- h2_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/h2_test.go b/h2_test.go index 526d06cc..f96928ac 100644 --- a/h2_test.go +++ b/h2_test.go @@ -103,3 +103,16 @@ func TestSettingValid(t *testing.T) { s := &http2Setting{ID: http2SettingMaxHeaderListSize} assertEqual(t, true, s.Valid() == nil) } + +func TestBodyAllowedForStatus(t *testing.T) { + assertEqual(t, false, http2bodyAllowedForStatus(101)) + assertEqual(t, false, http2bodyAllowedForStatus(204)) + assertEqual(t, false, http2bodyAllowedForStatus(304)) + assertEqual(t, true, http2bodyAllowedForStatus(900)) +} + +func TestHttpError(t *testing.T) { + e := &http2httpError{msg: "test"} + assertEqual(t, "test", e.Error()) + assertEqual(t, true, e.Temporary()) +} From e840660cbcc196b3de06ab3f84bd61b574b744fb Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 15:12:10 +0800 Subject: [PATCH 399/843] add TestSetBodyWrapper --- request_test.go | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/request_test.go b/request_test.go index 8f10146d..33616595 100644 --- a/request_test.go +++ b/request_test.go @@ -696,3 +696,40 @@ func TestSetFileBytes(t *testing.T) { assertSuccess(t, resp, err) assertEqual(t, "test", resp.String()) } + +func TestSetBodyWrapper(t *testing.T) { + b := []byte("test") + s := string(b) + c := tc() + r := c.R().SetBodyXmlString(s) + assertEqual(t, true, len(r.body) > 0) + r = SetBodyXmlString(s) + assertEqual(t, true, len(r.body) > 0) + + r = c.R().SetBodyXmlBytes(b) + assertEqual(t, true, len(r.body) > 0) + r = SetBodyXmlBytes(b) + assertEqual(t, true, len(r.body) > 0) + + r = c.R().SetBodyJsonString(s) + assertEqual(t, true, len(r.body) > 0) + r = SetBodyJsonString(s) + assertEqual(t, true, len(r.body) > 0) + + r = c.R().SetBodyJsonBytes(b) + assertEqual(t, true, len(r.body) > 0) + r = SetBodyJsonBytes(b) + assertEqual(t, true, len(r.body) > 0) + + r = SetBodyXmlMarshal(0) + assertEqual(t, true, len(r.body) > 0) + + r = SetBodyString(s) + assertEqual(t, true, len(r.body) > 0) + + r = SetBodyBytes(b) + assertEqual(t, true, len(r.body) > 0) + + r = SetBody(nil) + assertEqual(t, true, r.RawRequest.Body == nil) +} From 3293ef11d843be97e0e59c3d7af64af721240bbc Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 15:23:53 +0800 Subject: [PATCH 400/843] move global test into TestGlobalWrapper --- req_test.go | 26 ++++++++++++++++++++++++++ request_test.go | 21 +-------------------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/req_test.go b/req_test.go index 540b20ac..6d93e5c9 100644 --- a/req_test.go +++ b/req_test.go @@ -624,6 +624,32 @@ func TestGlobalWrapper(t *testing.T) { EnableAllowGetMethodPayload() assertEqual(t, true, DefaultClient().AllowGetMethodPayload) + b := []byte("test") + s := string(b) + r := SetBodyXmlMarshal(0) + assertEqual(t, true, len(r.body) > 0) + + r = SetBodyString(s) + assertEqual(t, true, len(r.body) > 0) + + r = SetBodyBytes(b) + assertEqual(t, true, len(r.body) > 0) + + r = SetBodyJsonBytes(b) + assertEqual(t, true, len(r.body) > 0) + + r = SetBodyJsonString(s) + assertEqual(t, true, len(r.body) > 0) + + r = SetBodyXmlBytes(b) + assertEqual(t, true, len(r.body) > 0) + + r = SetBodyXmlString(s) + assertEqual(t, true, len(r.body) > 0) + + r = SetBody(nil) + assertEqual(t, true, r.RawRequest.Body == nil) + marshalFunc := func(v interface{}) ([]byte, error) { return nil, testErr } diff --git a/request_test.go b/request_test.go index 33616595..c0385b3c 100644 --- a/request_test.go +++ b/request_test.go @@ -701,35 +701,16 @@ func TestSetBodyWrapper(t *testing.T) { b := []byte("test") s := string(b) c := tc() + r := c.R().SetBodyXmlString(s) assertEqual(t, true, len(r.body) > 0) - r = SetBodyXmlString(s) - assertEqual(t, true, len(r.body) > 0) r = c.R().SetBodyXmlBytes(b) assertEqual(t, true, len(r.body) > 0) - r = SetBodyXmlBytes(b) - assertEqual(t, true, len(r.body) > 0) r = c.R().SetBodyJsonString(s) assertEqual(t, true, len(r.body) > 0) - r = SetBodyJsonString(s) - assertEqual(t, true, len(r.body) > 0) r = c.R().SetBodyJsonBytes(b) assertEqual(t, true, len(r.body) > 0) - r = SetBodyJsonBytes(b) - assertEqual(t, true, len(r.body) > 0) - - r = SetBodyXmlMarshal(0) - assertEqual(t, true, len(r.body) > 0) - - r = SetBodyString(s) - assertEqual(t, true, len(r.body) > 0) - - r = SetBodyBytes(b) - assertEqual(t, true, len(r.body) > 0) - - r = SetBody(nil) - assertEqual(t, true, r.RawRequest.Body == nil) } From 2f22a161f01a0df696aa35bae9d0807178ac39e3 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 16:02:11 +0800 Subject: [PATCH 401/843] add more global wrapper tests --- req_test.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/req_test.go b/req_test.go index 6d93e5c9..908a5ab2 100644 --- a/req_test.go +++ b/req_test.go @@ -433,6 +433,21 @@ func testGlobalWrapperEnableDumps(t *testing.T) { *respBody = false return EnableDumpWithoutBody() }) + + buf := new(bytes.Buffer) + r := EnableDumpTo(buf) + assertEqual(t, true, r.getDumpOptions().Output != nil) + + dumpFile := tests.GetTestFilePath("req_tmp_dump.out") + r = EnableDumpToFile(tests.GetTestFilePath(dumpFile)) + assertEqual(t, true, r.getDumpOptions().Output != nil) + os.Remove(dumpFile) + + r = SetDumpOptions(&DumpOptions{ + RequestHeader: true, + }) + assertEqual(t, true, r.getDumpOptions().RequestHeader) + } func testGlobalWrapperEnableDump(t *testing.T, fn func(reqHeader, reqBody, respHeader, respBody *bool) *Request) { @@ -650,6 +665,21 @@ func TestGlobalWrapper(t *testing.T) { r = SetBody(nil) assertEqual(t, true, r.RawRequest.Body == nil) + r = SetBodyJsonMarshal(User{ + Name: "roc", + }) + assertEqual(t, true, r.RawRequest.Body != nil) + + r = EnableTrace() + assertEqual(t, true, r.trace != nil) + r = DisableTrace() + assertEqual(t, true, r.trace == nil) + + ctx := context.Background() + ctx, _ = context.WithTimeout(ctx, time.Second) + r = SetContext(ctx) + assertEqual(t, ctx, r.Context()) + marshalFunc := func(v interface{}) ([]byte, error) { return nil, testErr } From ac6cf57a148024d85d0389cd6de580e77151d544 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 16:22:48 +0800 Subject: [PATCH 402/843] add TestPeekDrain --- decode_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 decode_test.go diff --git a/decode_test.go b/decode_test.go new file mode 100644 index 00000000..209958ce --- /dev/null +++ b/decode_test.go @@ -0,0 +1,14 @@ +package req + +import "testing" + +func TestPeekDrain(t *testing.T) { + a := autoDecodeReadCloser{peek: []byte("test")} + p := make([]byte, 2) + n, _ := a.peekDrain(p) + assertEqual(t, 2, n) + assertEqual(t, true, a.peek != nil) + n, _ = a.peekDrain(p) + assertEqual(t, 2, n) + assertEqual(t, true, a.peek == nil) +} From a8d182f8b3cbaa63dc2930590346df53b1989880 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 16:38:00 +0800 Subject: [PATCH 403/843] add more global wrapper tests --- req_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/req_test.go b/req_test.go index 908a5ab2..078bd72e 100644 --- a/req_test.go +++ b/req_test.go @@ -638,6 +638,8 @@ func TestGlobalWrapper(t *testing.T) { EnableAllowGetMethodPayload() assertEqual(t, true, DefaultClient().AllowGetMethodPayload) + DisableAllowGetMethodPayload() + assertEqual(t, false, DefaultClient().AllowGetMethodPayload) b := []byte("test") s := string(b) @@ -885,6 +887,48 @@ func TestGlobalWrapper(t *testing.T) { assertEqual(t, true, opt.RequestHeader == true && opt.ResponseHeader == false) DisableDumpAll() assertEqual(t, true, DefaultClient().t.dump == nil) + + r = R() + assertEqual(t, true, r != nil) + c := C() + c.SetTimeout(10 * time.Second) + SetDefaultClient(c) + assertEqual(t, true, DefaultClient().httpClient.Timeout == 10*time.Second) + + SetRedirectPolicy(NoRedirectPolicy()) + assertEqual(t, true, DefaultClient().httpClient.CheckRedirect != nil) + + EnableForceHTTP1() + assertEqual(t, HTTP1, DefaultClient().t.ForceHttpVersion) + + EnableForceHTTP2() + assertEqual(t, HTTP2, DefaultClient().t.ForceHttpVersion) + + DisableForceHttpVersion() + assertEqual(t, true, DefaultClient().t.ForceHttpVersion == "") + + r = NewRequest() + assertEqual(t, true, r != nil) + c = NewClient() + assertEqual(t, true, c != nil) + + DefaultClient().getResponseOptions().AutoDecodeContentType = nil + SetAutoDecodeContentType("json") + assertEqual(t, true, DefaultClient().getResponseOptions().AutoDecodeContentType != nil) + + DefaultClient().getResponseOptions().AutoDecodeContentType = nil + SetAutoDecodeContentTypeFunc(func(contentType string) bool { return true }) + assertEqual(t, true, DefaultClient().getResponseOptions().AutoDecodeContentType != nil) + + DefaultClient().getResponseOptions().AutoDecodeContentType = nil + SetAutoDecodeAllContentType() + assertEqual(t, true, DefaultClient().getResponseOptions().AutoDecodeContentType != nil) + + DisableAutoDecode() + assertEqual(t, true, DefaultClient().getResponseOptions().DisableAutoDecode) + + EnableAutoDecode() + assertEqual(t, false, DefaultClient().getResponseOptions().DisableAutoDecode) } func TestTrailer(t *testing.T) { From 9b9c07c6416d81e842518403e608dc40fa857007 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 17:39:16 +0800 Subject: [PATCH 404/843] add TestParseUintBytes --- h2_gotrack_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/h2_gotrack_test.go b/h2_gotrack_test.go index ae6eb485..fc47c139 100644 --- a/h2_gotrack_test.go +++ b/h2_gotrack_test.go @@ -6,6 +6,7 @@ package req import ( "fmt" + "github.com/imroc/req/v3/internal/tests" "strings" "testing" ) @@ -31,3 +32,25 @@ func TestGoroutineLock(t *testing.T) { t.Errorf("expected on see panic about running on the wrong goroutine; got %v", e) } } + +func TestParseUintBytes(t *testing.T) { + s := []byte{} + _, err := http2parseUintBytes(s, 0, 0) + tests.AssertErrorContains(t, err, "invalid syntax") + + s = []byte("0x") + _, err = http2parseUintBytes(s, 0, 0) + tests.AssertErrorContains(t, err, "invalid syntax") + + s = []byte("0x01") + _, err = http2parseUintBytes(s, 0, 0) + tests.AssertNoError(t, err) + + s = []byte("0xa1") + _, err = http2parseUintBytes(s, 0, 0) + tests.AssertNoError(t, err) + + s = []byte("0xA1") + _, err = http2parseUintBytes(s, 0, 0) + tests.AssertNoError(t, err) +} From 90bb7bd56da891db951738987093e4f445aa625d Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 19:35:11 +0800 Subject: [PATCH 405/843] add comments for RedirectPolicy --- redirect.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/redirect.go b/redirect.go index 9daa3d56..06ce5a6e 100644 --- a/redirect.go +++ b/redirect.go @@ -8,6 +8,7 @@ import ( "strings" ) +// RedirectPolicy represents the redirect policy for Client. type RedirectPolicy func(req *http.Request, via []*http.Request) error // MaxRedirectPolicy specifies the max number of redirect @@ -27,6 +28,9 @@ func NoRedirectPolicy() RedirectPolicy { } } +// SameDomainRedirectPolicy allows redirect only if the redirected domain +// is the same as original domain, e.g. redirect to "www.imroc.cc" from +// "imroc.cc" is allowed, but redirect to "google.com" is not allowed. func SameDomainRedirectPolicy() RedirectPolicy { return func(req *http.Request, via []*http.Request) error { if getDomain(req.URL.Host) != getDomain(via[0].URL.Host) { From c6f88a64ef322ba857c4ee6d4ef3186d8f247939 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 19:36:38 +0800 Subject: [PATCH 406/843] cancel test context func --- req_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/req_test.go b/req_test.go index 078bd72e..caf896ad 100644 --- a/req_test.go +++ b/req_test.go @@ -678,9 +678,10 @@ func TestGlobalWrapper(t *testing.T) { assertEqual(t, true, r.trace == nil) ctx := context.Background() - ctx, _ = context.WithTimeout(ctx, time.Second) + ctx, cancel := context.WithTimeout(ctx, time.Second) r = SetContext(ctx) assertEqual(t, ctx, r.Context()) + cancel() marshalFunc := func(v interface{}) ([]byte, error) { return nil, testErr From 6207ae5a5040416c9b0d0855e5bb1b9479fa773e Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 20:19:05 +0800 Subject: [PATCH 407/843] fix lint --- examples/find-popular-repo/main.go | 2 +- h2_transport.go | 2 +- internal/charsets/charsets.go | 1 + internal/chunked.go | 1 + internal/socks/socks.go | 1 + internal/socks/socks_test.go | 4 ++-- internal/tests/file.go | 2 ++ internal/tests/tests.go | 3 +++ internal/util/util.go | 12 +++--------- middleware.go | 2 +- req.go | 3 ++- roundtrip.go | 3 ++- trace.go | 6 ++++-- transport_default_js.go | 1 + transport_default_other.go | 1 + 15 files changed, 26 insertions(+), 18 deletions(-) diff --git a/examples/find-popular-repo/main.go b/examples/find-popular-repo/main.go index 6c17e86f..c9e43097 100644 --- a/examples/find-popular-repo/main.go +++ b/examples/find-popular-repo/main.go @@ -84,7 +84,7 @@ func findTheMostPopularRepo(username string) (repo string, star int, err error) // Unkown http status code, record and return error, here we can use // String() to get response body, cuz response body have already been read // and no error returned, do not need to use ToString(). - err = fmt.Errorf("unkown error. status code %d; body: %s", resp.StatusCode, resp.String()) + err = fmt.Errorf("unknown error. status code %d; body: %s", resp.StatusCode, resp.String()) return } } diff --git a/h2_transport.go b/h2_transport.go index fbba49ad..57d18d5b 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -1556,7 +1556,7 @@ func (cs *http2clientStream) writeRequestBody(req *http.Request, dumps []*dumper var sawEOF bool for !sawEOF { - n, err := body.Read(buf[:len(buf)]) + n, err := body.Read(buf[:]) if hasContentLen { remainLen -= int64(n) if remainLen == 0 && err == nil { diff --git a/internal/charsets/charsets.go b/internal/charsets/charsets.go index e612fb1e..cb910660 100644 --- a/internal/charsets/charsets.go +++ b/internal/charsets/charsets.go @@ -17,6 +17,7 @@ var boms = []struct { {[]byte{0xef, 0xbb, 0xbf}, "utf-8"}, } +// FindEncoding sniff and find the encoding of the content. func FindEncoding(content []byte) (enc encoding.Encoding, name string) { if len(content) == 0 { return diff --git a/internal/chunked.go b/internal/chunked.go index bd414ebb..dec4ddca 100644 --- a/internal/chunked.go +++ b/internal/chunked.go @@ -19,6 +19,7 @@ import ( const maxLineLength = 4096 // assumed <= bufio.defaultBufSize +// ErrLineTooLong is the error that header line too long. var ErrLineTooLong = errors.New("header line too long") // NewChunkedReader returns a new chunkedReader that translates the data read from r diff --git a/internal/socks/socks.go b/internal/socks/socks.go index 55afd6be..cddef90f 100644 --- a/internal/socks/socks.go +++ b/internal/socks/socks.go @@ -88,6 +88,7 @@ type Addr struct { Port int } +// Network return "socks" func (a *Addr) Network() string { return "socks" } func (a *Addr) String() string { diff --git a/internal/socks/socks_test.go b/internal/socks/socks_test.go index d96c44b6..824a09d7 100644 --- a/internal/socks/socks_test.go +++ b/internal/socks/socks_test.go @@ -12,12 +12,12 @@ func TestReply(t *testing.T) { for i := 0; i < 9; i++ { s := Reply(i).String() if strings.Contains(s, "unknown") { - t.Errorf("resply code [%d] should not unkown", i) + t.Errorf("resply code [%d] should not unknown", i) } } s := Reply(9).String() if !strings.Contains(s, "unknown") { - t.Errorf("resply code [%d] should unkown", 9) + t.Errorf("resply code [%d] should unknown", 9) } } diff --git a/internal/tests/file.go b/internal/tests/file.go index 9dd0a391..f4fc2341 100644 --- a/internal/tests/file.go +++ b/internal/tests/file.go @@ -14,12 +14,14 @@ func init() { testDataPath = filepath.Join(pwd, ".testdata") } +// GetTestFileContent return test file content. func GetTestFileContent(t *testing.T, filename string) []byte { b, err := ioutil.ReadFile(GetTestFilePath(filename)) AssertNoError(t, err) return b } +// GetTestFilePath return test file absolute path. func GetTestFilePath(filename string) string { return filepath.Join(testDataPath, filename) } diff --git a/internal/tests/tests.go b/internal/tests/tests.go index 6ff4d3b4..0a7d8062 100644 --- a/internal/tests/tests.go +++ b/internal/tests/tests.go @@ -5,12 +5,14 @@ import ( "testing" ) +// AssertNoError asserts no error. func AssertNoError(t *testing.T, err error) { if err != nil { t.Errorf("Error occurred [%v]", err) } } +// AssertErrorContains asserts error is not nil and contains specified error string. func AssertErrorContains(t *testing.T, err error, s string) { if err == nil { t.Error("err is nil") @@ -21,6 +23,7 @@ func AssertErrorContains(t *testing.T, err error, s string) { } } +// AssertContains asserts substring is contained in the given string. func AssertContains(t *testing.T, s, substr string, shouldContain bool) { s = strings.ToLower(s) isContain := strings.Contains(s, substr) diff --git a/internal/util/util.go b/internal/util/util.go index 82d488f9..cd816d13 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -18,6 +18,7 @@ func IsXMLType(ct string) bool { return strings.Contains(ct, "xml") } +// GetPointer return the pointer of the interface. func GetPointer(v interface{}) interface{} { vv := reflect.ValueOf(v) if vv.Kind() == reflect.Ptr { @@ -55,15 +56,6 @@ func IsStringEmpty(str string) bool { return len(strings.TrimSpace(str)) == 0 } -func FirstNonEmpty(v ...string) string { - for _, s := range v { - if !IsStringEmpty(s) { - return s - } - } - return "" -} - // See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt // "To receive authorization, the client sends the userid and password, // separated by a single colon (":") character, within a base64 @@ -74,10 +66,12 @@ func basicAuth(username, password string) string { return base64.StdEncoding.EncodeToString([]byte(auth)) } +// BasicAuthHeaderValue return the header of basic auth. func BasicAuthHeaderValue(username, password string) string { return "Basic " + basicAuth(username, password) } +// CreateDirectory create the directory. func CreateDirectory(dir string) (err error) { if _, err = os.Stat(dir); err != nil { if os.IsNotExist(err) { diff --git a/middleware.go b/middleware.go index d092108a..054ebcfd 100644 --- a/middleware.go +++ b/middleware.go @@ -38,7 +38,7 @@ func createMultipartHeader(file *FileUpload, contentType string) textproto.MIMEH cd.Add(kv.Key, kv.Value) } } - if c := cd.String(); c != "" { + if c := cd.string(); c != "" { contentDispositionValue += c } hdr.Set("Content-Disposition", contentDispositionValue) diff --git a/req.go b/req.go index ef6f7a70..225d254b 100644 --- a/req.go +++ b/req.go @@ -29,12 +29,13 @@ type ContentDisposition struct { kv []kv } +// Add adds a new key-value pair of Content-Disposition func (c *ContentDisposition) Add(key, value string) *ContentDisposition { c.kv = append(c.kv, kv{Key: key, Value: value}) return c } -func (c *ContentDisposition) String() string { +func (c *ContentDisposition) string() string { if c == nil { return "" } diff --git a/roundtrip.go b/roundtrip.go index c96b3390..f73d3f36 100644 --- a/roundtrip.go +++ b/roundtrip.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//+build !js !wasm +//go:build !js || !wasm +// +build !js !wasm package req diff --git a/trace.go b/trace.go index c5849294..c2324b42 100644 --- a/trace.go +++ b/trace.go @@ -25,6 +25,7 @@ IsConnReused: : true RemoteAddr : %v` ) +// Blame return the human-readable reason of why request is slowing. func (t TraceInfo) Blame() string { if t.RemoteAddr == nil { return "trace is not enabled" @@ -50,17 +51,18 @@ func (t TraceInfo) Blame() string { return fmt.Sprintf("the request total time is %v, and costs %v %s", t.TotalTime, mv, mk) } +// String return the details of trace information. func (t TraceInfo) String() string { if t.RemoteAddr == nil { return "trace is not enabled" } if t.IsConnReused { return fmt.Sprintf(traceReusedFmt, t.TotalTime, t.FirstResponseTime, t.ResponseTime, t.RemoteAddr) - } else { - return fmt.Sprintf(traceFmt, t.TotalTime, t.DNSLookupTime, t.TCPConnectTime, t.TLSHandshakeTime, t.FirstResponseTime, t.ResponseTime, t.RemoteAddr) } + return fmt.Sprintf(traceFmt, t.TotalTime, t.DNSLookupTime, t.TCPConnectTime, t.TLSHandshakeTime, t.FirstResponseTime, t.ResponseTime, t.RemoteAddr) } +// TraceInfo represents the trace information. type TraceInfo struct { // DNSLookupTime is a duration that transport took to perform // DNS lookup. diff --git a/transport_default_js.go b/transport_default_js.go index af5df819..7cd8e335 100644 --- a/transport_default_js.go +++ b/transport_default_js.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build js && wasm // +build js,wasm package req diff --git a/transport_default_other.go b/transport_default_other.go index 7f6b27f6..a18e66b5 100644 --- a/transport_default_other.go +++ b/transport_default_other.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build !js || !wasm // +build !js !wasm package req From dae2cd0b79b7fbb878ebb99f5f553ce9aca1955a Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 20:25:09 +0800 Subject: [PATCH 408/843] fix report url --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f86986d6..bf209deb 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@

Build Status Code Coverage - Go Report Card - + Go Report Card + License

From d395e7e7e340efaf8a773363f8035c43537d49ff Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 20:33:03 +0800 Subject: [PATCH 409/843] rename receiver name of http2Framer --- h2_frame.go | 300 ++++++++++++++++++++++++++-------------------------- 1 file changed, 150 insertions(+), 150 deletions(-) diff --git a/h2_frame.go b/h2_frame.go index 2c153cb0..feb0decc 100644 --- a/h2_frame.go +++ b/h2_frame.go @@ -332,16 +332,16 @@ type http2Framer struct { frameCache *http2frameCache // nil if frames aren't reused (default) } -func (fr *http2Framer) maxHeaderListSize() uint32 { - if fr.MaxHeaderListSize == 0 { +func (h2f *http2Framer) maxHeaderListSize() uint32 { + if h2f.MaxHeaderListSize == 0 { return 16 << 20 // sane default, per docs } - return fr.MaxHeaderListSize + return h2f.MaxHeaderListSize } -func (f *http2Framer) startWrite(ftype http2FrameType, flags http2Flags, streamID uint32) { +func (h2f *http2Framer) startWrite(ftype http2FrameType, flags http2Flags, streamID uint32) { // Write the FrameHeader. - f.wbuf = append(f.wbuf[:0], + h2f.wbuf = append(h2f.wbuf[:0], 0, // 3 bytes of length, filled in in endWrite 0, 0, @@ -353,54 +353,54 @@ func (f *http2Framer) startWrite(ftype http2FrameType, flags http2Flags, streamI byte(streamID)) } -func (f *http2Framer) endWrite() error { +func (h2f *http2Framer) endWrite() error { // Now that we know the final size, fill in the FrameHeader in // the space previously reserved for it. Abuse append. - length := len(f.wbuf) - http2frameHeaderLen + length := len(h2f.wbuf) - http2frameHeaderLen if length >= (1 << 24) { return http2ErrFrameTooLarge } - _ = append(f.wbuf[:0], + _ = append(h2f.wbuf[:0], byte(length>>16), byte(length>>8), byte(length)) - if f.logWrites { - f.logWrite() + if h2f.logWrites { + h2f.logWrite() } - n, err := f.w.Write(f.wbuf) - if err == nil && n != len(f.wbuf) { + n, err := h2f.w.Write(h2f.wbuf) + if err == nil && n != len(h2f.wbuf) { err = io.ErrShortWrite } return err } -func (f *http2Framer) logWrite() { - if f.debugFramer == nil { - f.debugFramerBuf = new(bytes.Buffer) - f.debugFramer = http2NewFramer(nil, f.debugFramerBuf) - f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below +func (h2f *http2Framer) logWrite() { + if h2f.debugFramer == nil { + h2f.debugFramerBuf = new(bytes.Buffer) + h2f.debugFramer = http2NewFramer(nil, h2f.debugFramerBuf) + h2f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below // Let us read anything, even if we accidentally wrote it // in the wrong order: - f.debugFramer.AllowIllegalReads = true + h2f.debugFramer.AllowIllegalReads = true } - f.debugFramerBuf.Write(f.wbuf) - fr, err := f.debugFramer.ReadFrame() + h2f.debugFramerBuf.Write(h2f.wbuf) + fr, err := h2f.debugFramer.ReadFrame() if err != nil { - f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f) + h2f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", h2f) return } - f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) + h2f.debugWriteLoggerf("http2: Framer %p: wrote %v", h2f, http2summarizeFrame(fr)) } -func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) } +func (h2f *http2Framer) writeByte(v byte) { h2f.wbuf = append(h2f.wbuf, v) } -func (f *http2Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) } +func (h2f *http2Framer) writeBytes(v []byte) { h2f.wbuf = append(h2f.wbuf, v...) } -func (f *http2Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) } +func (h2f *http2Framer) writeUint16(v uint16) { h2f.wbuf = append(h2f.wbuf, byte(v>>8), byte(v)) } -func (f *http2Framer) writeUint32(v uint32) { - f.wbuf = append(f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +func (h2f *http2Framer) writeUint32(v uint32) { + h2f.wbuf = append(h2f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) } const ( @@ -411,11 +411,11 @@ const ( // SetReuseFrames allows the Framer to reuse Frames. // If called on a Framer, Frames returned by calls to ReadFrame are only // valid until the next call to ReadFrame. -func (fr *http2Framer) SetReuseFrames() { - if fr.frameCache != nil { +func (h2f *http2Framer) SetReuseFrames() { + if h2f.frameCache != nil { return } - fr.frameCache = &http2frameCache{} + h2f.frameCache = &http2frameCache{} } type http2frameCache struct { @@ -455,11 +455,11 @@ func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { // that will be read by a subsequent call to ReadFrame. // It is the caller's responsibility to advertise this // limit with a SETTINGS frame. -func (fr *http2Framer) SetMaxReadFrameSize(v uint32) { +func (h2f *http2Framer) SetMaxReadFrameSize(v uint32) { if v > http2maxFrameSize { v = http2maxFrameSize } - fr.maxReadSize = v + h2f.maxReadSize = v } // ErrorDetail returns a more detailed error of the last error @@ -469,8 +469,8 @@ func (fr *http2Framer) SetMaxReadFrameSize(v uint32) { // to return a non-nil value and like the rest of the http2 package, // its return value is not protected by an API compatibility promise. // ErrorDetail is reset after the next call to ReadFrame. -func (fr *http2Framer) ErrorDetail() error { - return fr.errDetail +func (h2f *http2Framer) ErrorDetail() error { + return h2f.errDetail } // ErrFrameTooLarge is returned from Framer.ReadFrame when the peer @@ -493,39 +493,39 @@ func http2terminalReadFrameError(err error) bool { // returned error is ErrFrameTooLarge. Other errors may be of type // ConnectionError, StreamError, or anything else from the underlying // reader. -func (fr *http2Framer) ReadFrame() (http2Frame, error) { - fr.errDetail = nil - if fr.lastFrame != nil { - fr.lastFrame.invalidate() +func (h2f *http2Framer) ReadFrame() (http2Frame, error) { + h2f.errDetail = nil + if h2f.lastFrame != nil { + h2f.lastFrame.invalidate() } - fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r) + fh, err := http2readFrameHeader(h2f.headerBuf[:], h2f.r) if err != nil { return nil, err } - if fh.Length > fr.maxReadSize { + if fh.Length > h2f.maxReadSize { return nil, http2ErrFrameTooLarge } - payload := fr.getReadBuf(fh.Length) - if _, err := io.ReadFull(fr.r, payload); err != nil { + payload := h2f.getReadBuf(fh.Length) + if _, err := io.ReadFull(h2f.r, payload); err != nil { return nil, err } - f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload) + f, err := http2typeFrameParser(fh.Type)(h2f.frameCache, fh, h2f.countError, payload) if err != nil { if ce, ok := err.(http2connError); ok { - return nil, fr.connError(ce.Code, ce.Reason) + return nil, h2f.connError(ce.Code, ce.Reason) } return nil, err } - if err := fr.checkFrameOrder(f); err != nil { + if err := h2f.checkFrameOrder(f); err != nil { return nil, err } - if fr.logReads { - fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) + if h2f.logReads { + h2f.debugReadLoggerf("http2: Framer %p: read %v", h2f, http2summarizeFrame(f)) } - if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { + if fh.Type == http2FrameHeaders && h2f.ReadMetaHeaders != nil { var dumps []*dumper - if fr.cc != nil && fr.cc.t.t1 != nil { - dumps = getDumpers(fr.cc.t.t1.dump, fr.cc.currentRequest.Context()) + if h2f.cc != nil && h2f.cc.t.t1 != nil { + dumps = getDumpers(h2f.cc.t.t1.dump, h2f.cc.currentRequest.Context()) } if len(dumps) > 0 { dd := []*dumper{} @@ -536,7 +536,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { } dumps = dd } - hr, err := fr.readMetaFrame(f.(*http2HeadersFrame), dumps) + hr, err := h2f.readMetaFrame(f.(*http2HeadersFrame), dumps) if err == nil && len(dumps) > 0 { for _, dump := range dumps { dump.dump([]byte("\r\n")) @@ -551,44 +551,44 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { // stashes away a public reason to the caller can optionally relay it // to the peer before hanging up on them. This might help others debug // their implementations. -func (fr *http2Framer) connError(code http2ErrCode, reason string) error { - fr.errDetail = errors.New(reason) +func (h2f *http2Framer) connError(code http2ErrCode, reason string) error { + h2f.errDetail = errors.New(reason) return http2ConnectionError(code) } // checkFrameOrder reports an error if f is an invalid frame to return // next from ReadFrame. Mostly it checks whether HEADERS and // CONTINUATION frames are contiguous. -func (fr *http2Framer) checkFrameOrder(f http2Frame) error { - last := fr.lastFrame - fr.lastFrame = f - if fr.AllowIllegalReads { +func (h2f *http2Framer) checkFrameOrder(f http2Frame) error { + last := h2f.lastFrame + h2f.lastFrame = f + if h2f.AllowIllegalReads { return nil } fh := f.Header() - if fr.lastHeaderStream != 0 { + if h2f.lastHeaderStream != 0 { if fh.Type != http2FrameContinuation { - return fr.connError(http2ErrCodeProtocol, + return h2f.connError(http2ErrCodeProtocol, fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d", fh.Type, fh.StreamID, - last.Header().Type, fr.lastHeaderStream)) + last.Header().Type, h2f.lastHeaderStream)) } - if fh.StreamID != fr.lastHeaderStream { - return fr.connError(http2ErrCodeProtocol, + if fh.StreamID != h2f.lastHeaderStream { + return h2f.connError(http2ErrCodeProtocol, fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d", - fh.StreamID, fr.lastHeaderStream)) + fh.StreamID, h2f.lastHeaderStream)) } } else if fh.Type == http2FrameContinuation { - return fr.connError(http2ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID)) + return h2f.connError(http2ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID)) } switch fh.Type { case http2FrameHeaders, http2FrameContinuation: if fh.Flags.Has(http2FlagHeadersEndHeaders) { - fr.lastHeaderStream = 0 + h2f.lastHeaderStream = 0 } else { - fr.lastHeaderStream = fh.StreamID + h2f.lastHeaderStream = fh.StreamID } } @@ -670,8 +670,8 @@ func http2validStreamID(streamID uint32) bool { // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility not to violate the maximum frame size // and to not call other Write methods concurrently. -func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error { - return f.WriteDataPadded(streamID, endStream, data, nil) +func (h2f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error { + return h2f.WriteDataPadded(streamID, endStream, data, nil) } // WriteDataPadded writes a DATA frame with optional padding. @@ -683,15 +683,15 @@ func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) er // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility not to violate the maximum frame size // and to not call other Write methods concurrently. -func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { - if !http2validStreamID(streamID) && !f.AllowIllegalWrites { +func (h2f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { + if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { return http2errStreamID } if len(pad) > 0 { if len(pad) > 255 { return http2errPadLength } - if !f.AllowIllegalWrites { + if !h2f.AllowIllegalWrites { for _, b := range pad { if b != 0 { // "Padding octets MUST be set to zero when sending." @@ -707,13 +707,13 @@ func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad if pad != nil { flags |= http2FlagDataPadded } - f.startWrite(http2FrameData, flags, streamID) + h2f.startWrite(http2FrameData, flags, streamID) if pad != nil { - f.wbuf = append(f.wbuf, byte(len(pad))) + h2f.wbuf = append(h2f.wbuf, byte(len(pad))) } - f.wbuf = append(f.wbuf, data...) - f.wbuf = append(f.wbuf, pad...) - return f.endWrite() + h2f.wbuf = append(h2f.wbuf, data...) + h2f.wbuf = append(h2f.wbuf, pad...) + return h2f.endWrite() } // A SettingsFrame conveys configuration parameters that affect how @@ -838,22 +838,22 @@ func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error { // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WriteSettings(settings ...http2Setting) error { - f.startWrite(http2FrameSettings, 0, 0) +func (h2f *http2Framer) WriteSettings(settings ...http2Setting) error { + h2f.startWrite(http2FrameSettings, 0, 0) for _, s := range settings { - f.writeUint16(uint16(s.ID)) - f.writeUint32(s.Val) + h2f.writeUint16(uint16(s.ID)) + h2f.writeUint32(s.Val) } - return f.endWrite() + return h2f.endWrite() } // WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set. // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WriteSettingsAck() error { - f.startWrite(http2FrameSettings, http2FlagSettingsAck, 0) - return f.endWrite() +func (h2f *http2Framer) WriteSettingsAck() error { + h2f.startWrite(http2FrameSettings, http2FlagSettingsAck, 0) + return h2f.endWrite() } // A PingFrame is a mechanism for measuring a minimal round trip time @@ -881,14 +881,14 @@ func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, countError fun return f, nil } -func (f *http2Framer) WritePing(ack bool, data [8]byte) error { +func (h2f *http2Framer) WritePing(ack bool, data [8]byte) error { var flags http2Flags if ack { flags = http2FlagPingAck } - f.startWrite(http2FramePing, flags, 0) - f.writeBytes(data[:]) - return f.endWrite() + h2f.startWrite(http2FramePing, flags, 0) + h2f.writeBytes(data[:]) + return h2f.endWrite() } // A GoAwayFrame informs the remote peer to stop creating streams on this connection. @@ -926,12 +926,12 @@ func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, countError f }, nil } -func (f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debugData []byte) error { - f.startWrite(http2FrameGoAway, 0, 0) - f.writeUint32(maxStreamID & (1<<31 - 1)) - f.writeUint32(uint32(code)) - f.writeBytes(debugData) - return f.endWrite() +func (h2f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debugData []byte) error { + h2f.startWrite(http2FrameGoAway, 0, 0) + h2f.writeUint32(maxStreamID & (1<<31 - 1)) + h2f.writeUint32(uint32(code)) + h2f.writeBytes(debugData) + return h2f.endWrite() } // An UnknownFrame is the frame type returned when the frame type is unknown @@ -992,14 +992,14 @@ func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countE // The increment value must be between 1 and 2,147,483,647, inclusive. // If the Stream ID is zero, the window update applies to the // connection as a whole. -func (f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error { +func (h2f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error { // "The legal range for the increment to the flow control window is 1 to 2^31-1 (2,147,483,647) octets." - if (incr < 1 || incr > 2147483647) && !f.AllowIllegalWrites { + if (incr < 1 || incr > 2147483647) && !h2f.AllowIllegalWrites { return errors.New("illegal window increment value") } - f.startWrite(http2FrameWindowUpdate, 0, streamID) - f.writeUint32(incr) - return f.endWrite() + h2f.startWrite(http2FrameWindowUpdate, 0, streamID) + h2f.writeUint32(incr) + return h2f.endWrite() } // A HeadersFrame is used to open a stream and additionally carries a @@ -1107,8 +1107,8 @@ type http2HeadersFrameParam struct { // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { - if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { +func (h2f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { + if !http2validStreamID(p.StreamID) && !h2f.AllowIllegalWrites { return http2errStreamID } var flags http2Flags @@ -1124,24 +1124,24 @@ func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { if !p.Priority.IsZero() { flags |= http2FlagHeadersPriority } - f.startWrite(http2FrameHeaders, flags, p.StreamID) + h2f.startWrite(http2FrameHeaders, flags, p.StreamID) if p.PadLength != 0 { - f.writeByte(p.PadLength) + h2f.writeByte(p.PadLength) } if !p.Priority.IsZero() { v := p.Priority.StreamDep - if !http2validStreamIDOrZero(v) && !f.AllowIllegalWrites { + if !http2validStreamIDOrZero(v) && !h2f.AllowIllegalWrites { return http2errDepStreamID } if p.Priority.Exclusive { v |= 1 << 31 } - f.writeUint32(v) - f.writeByte(p.Priority.Weight) + h2f.writeUint32(v) + h2f.writeByte(p.Priority.Weight) } - f.wbuf = append(f.wbuf, p.BlockFragment...) - f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...) - return f.endWrite() + h2f.wbuf = append(h2f.wbuf, p.BlockFragment...) + h2f.wbuf = append(h2f.wbuf, http2padZeros[:p.PadLength]...) + return h2f.endWrite() } // A PriorityFrame specifies the sender-advised priority of a stream. @@ -1197,21 +1197,21 @@ func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error { - if !http2validStreamID(streamID) && !f.AllowIllegalWrites { +func (h2f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error { + if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { return http2errStreamID } if !http2validStreamIDOrZero(p.StreamDep) { return http2errDepStreamID } - f.startWrite(http2FramePriority, 0, streamID) + h2f.startWrite(http2FramePriority, 0, streamID) v := p.StreamDep if p.Exclusive { v |= 1 << 31 } - f.writeUint32(v) - f.writeByte(p.Weight) - return f.endWrite() + h2f.writeUint32(v) + h2f.writeByte(p.Weight) + return h2f.endWrite() } // A RSTStreamFrame allows for abnormal termination of a stream. @@ -1237,13 +1237,13 @@ func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, countErro // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error { - if !http2validStreamID(streamID) && !f.AllowIllegalWrites { +func (h2f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error { + if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { return http2errStreamID } - f.startWrite(http2FrameRSTStream, 0, streamID) - f.writeUint32(uint32(code)) - return f.endWrite() + h2f.startWrite(http2FrameRSTStream, 0, streamID) + h2f.writeUint32(uint32(code)) + return h2f.endWrite() } // A ContinuationFrame is used to continue a sequence of header block fragments. @@ -1274,17 +1274,17 @@ func (f *http2ContinuationFrame) HeadersEnded() bool { // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error { - if !http2validStreamID(streamID) && !f.AllowIllegalWrites { +func (h2f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error { + if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { return http2errStreamID } var flags http2Flags if endHeaders { flags |= http2FlagContinuationEndHeaders } - f.startWrite(http2FrameContinuation, flags, streamID) - f.wbuf = append(f.wbuf, headerBlockFragment...) - return f.endWrite() + h2f.startWrite(http2FrameContinuation, flags, streamID) + h2f.wbuf = append(h2f.wbuf, headerBlockFragment...) + return h2f.endWrite() } // A PushPromiseFrame is used to initiate a server stream. @@ -1373,8 +1373,8 @@ type http2PushPromiseParam struct { // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (f *http2Framer) WritePushPromise(p http2PushPromiseParam) error { - if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { +func (h2f *http2Framer) WritePushPromise(p http2PushPromiseParam) error { + if !http2validStreamID(p.StreamID) && !h2f.AllowIllegalWrites { return http2errStreamID } var flags http2Flags @@ -1384,25 +1384,25 @@ func (f *http2Framer) WritePushPromise(p http2PushPromiseParam) error { if p.EndHeaders { flags |= http2FlagPushPromiseEndHeaders } - f.startWrite(http2FramePushPromise, flags, p.StreamID) + h2f.startWrite(http2FramePushPromise, flags, p.StreamID) if p.PadLength != 0 { - f.writeByte(p.PadLength) + h2f.writeByte(p.PadLength) } - if !http2validStreamID(p.PromiseID) && !f.AllowIllegalWrites { + if !http2validStreamID(p.PromiseID) && !h2f.AllowIllegalWrites { return http2errStreamID } - f.writeUint32(p.PromiseID) - f.wbuf = append(f.wbuf, p.BlockFragment...) - f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...) - return f.endWrite() + h2f.writeUint32(p.PromiseID) + h2f.wbuf = append(h2f.wbuf, p.BlockFragment...) + h2f.wbuf = append(h2f.wbuf, http2padZeros[:p.PadLength]...) + return h2f.endWrite() } // WriteRawFrame writes a raw frame. This can be used to write // extension frames unknown to this package. -func (f *http2Framer) WriteRawFrame(t http2FrameType, flags http2Flags, streamID uint32, payload []byte) error { - f.startWrite(t, flags, streamID) - f.writeBytes(payload) - return f.endWrite() +func (h2f *http2Framer) WriteRawFrame(t http2FrameType, flags http2Flags, streamID uint32, payload []byte) error { + h2f.startWrite(t, flags, streamID) + h2f.writeBytes(payload) + return h2f.endWrite() } func http2readByte(p []byte) (remain []byte, b byte, err error) { @@ -1522,8 +1522,8 @@ func (mh *http2MetaHeadersFrame) checkPseudos() error { return nil } -func (fr *http2Framer) maxHeaderStringLen() int { - v := fr.maxHeaderListSize() +func (h2f *http2Framer) maxHeaderStringLen() int { + v := h2f.maxHeaderListSize() if uint32(int(v)) == v { return int(v) } @@ -1535,23 +1535,23 @@ func (fr *http2Framer) maxHeaderStringLen() int { // readMetaFrame returns 0 or more CONTINUATION frames from fr and // merge them into the provided hf and returns a MetaHeadersFrame // with the decoded hpack values. -func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (*http2MetaHeadersFrame, error) { - if fr.AllowIllegalReads { +func (h2f *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (*http2MetaHeadersFrame, error) { + if h2f.AllowIllegalReads { return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders") } mh := &http2MetaHeadersFrame{ http2HeadersFrame: hf, } - var remainSize = fr.maxHeaderListSize() + var remainSize = h2f.maxHeaderListSize() var sawRegular bool var invalid error // pseudo header field errors - hdec := fr.ReadMetaHeaders + hdec := h2f.ReadMetaHeaders hdec.SetEmitEnabled(true) - hdec.SetMaxStringLength(fr.maxHeaderStringLen()) + hdec.SetMaxStringLength(h2f.maxHeaderStringLen()) rawEmitFunc := func(hf hpack.HeaderField) { - if http2VerboseLogs && fr.logReads { - fr.debugReadLoggerf("http2: decoded hpack field %+v", hf) + if http2VerboseLogs && h2f.logReads { + h2f.debugReadLoggerf("http2: decoded hpack field %+v", hf) } if !httpguts.ValidHeaderFieldValue(hf.Value) { invalid = http2headerFieldValueError(hf.Value) @@ -1608,7 +1608,7 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (*h if hc.HeadersEnded() { break } - if f, err := fr.ReadFrame(); err != nil { + if f, err := h2f.ReadFrame(); err != nil { return nil, err } else { hc = f.(*http2ContinuationFrame) // guaranteed by checkFrameOrder @@ -1622,14 +1622,14 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (*h return nil, http2ConnectionError(http2ErrCodeCompression) } if invalid != nil { - fr.errDetail = invalid + h2f.errDetail = invalid if http2VerboseLogs { log.Printf("http2: invalid header: %v", invalid) } return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, invalid} } if err := mh.checkPseudos(); err != nil { - fr.errDetail = err + h2f.errDetail = err if http2VerboseLogs { log.Printf("http2: invalid pseudo headers: %v", err) } From b920a4bfb65e72be4bd88a1a20695c43469c2900 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 20:33:53 +0800 Subject: [PATCH 410/843] rename errClosedPipeWrite --- h2_pipe.go | 4 ++-- h2_pipe_test.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/h2_pipe.go b/h2_pipe.go index 0777847b..56bd8e9c 100644 --- a/h2_pipe.go +++ b/h2_pipe.go @@ -77,7 +77,7 @@ func (p *http2pipe) Read(d []byte) (n int, err error) { } } -var http2errClosedPipeWrite = errors.New("write on closed buffer") +var errClosedPipeWrite = errors.New("write on closed buffer") // Write copies bytes from p into the buffer and wakes a reader. // It is an error to write more data than the buffer can hold. @@ -89,7 +89,7 @@ func (p *http2pipe) Write(d []byte) (n int, err error) { } defer p.c.Signal() if p.err != nil { - return 0, http2errClosedPipeWrite + return 0, errClosedPipeWrite } if p.breakErr != nil { p.unread += len(d) diff --git a/h2_pipe_test.go b/h2_pipe_test.go index 7e9b4657..434c7c15 100644 --- a/h2_pipe_test.go +++ b/h2_pipe_test.go @@ -96,11 +96,11 @@ func TestPipeCloseWithError(t *testing.T) { t.Errorf("pipe should have 0 unread bytes") } // Read and Write should fail. - if n, err := p.Write([]byte("abc")); err != http2errClosedPipeWrite || n != 0 { - t.Errorf("Write(abc) after close\ngot %v, %v\nwant 0, %v", n, err, http2errClosedPipeWrite) + if n, err := p.Write([]byte("abc")); err != errClosedPipeWrite || n != 0 { + t.Errorf("Write(abc) after close\ngot %v, %v\nwant 0, %v", n, err, errClosedPipeWrite) } if n, err := p.Read(make([]byte, 1)); err == nil || n != 0 { - t.Errorf("Read() after close\ngot %v, nil\nwant 0, %v", n, http2errClosedPipeWrite) + t.Errorf("Read() after close\ngot %v, nil\nwant 0, %v", n, errClosedPipeWrite) } if p.Len() != 0 { t.Errorf("pipe should have 0 unread bytes") From 17f175a9c4a7fd27010030a265eee7022f3dc526 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 20:55:14 +0800 Subject: [PATCH 411/843] fix golint --- h2.go | 2 +- h2_errors.go | 6 +-- h2_frame.go | 4 +- h2_server_test.go | 13 +++---- h2_transport.go | 87 +++++++++++++++++++++----------------------- h2_transport_test.go | 48 ++++++++++++------------ request.go | 2 + transport.go | 22 +++++------ transport_test.go | 20 +++++----- 9 files changed, 100 insertions(+), 104 deletions(-) diff --git a/h2.go b/h2.go index de9d83d5..2665d7f8 100644 --- a/h2.go +++ b/h2.go @@ -198,7 +198,7 @@ func (e *http2httpError) Timeout() bool { return e.timeout } func (e *http2httpError) Temporary() bool { return true } -var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true} +var errH2Timeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true} type http2connectionStater interface { ConnectionState() tls.ConnectionState diff --git a/h2_errors.go b/h2_errors.go index e81daf2e..da23201b 100644 --- a/h2_errors.go +++ b/h2_errors.go @@ -79,7 +79,7 @@ type http2StreamError struct { // errFromPeer is a sentinel error value for StreamError.Cause to // indicate that the StreamError was sent from the peer over the wire // and wasn't locally generated in the Transport. -var http2errFromPeer = errors.New("received from peer") +var errFromPeer = errors.New("received from peer") func http2streamError(id uint32, code http2ErrCode) http2StreamError { return http2StreamError{StreamID: id, Code: code} @@ -133,6 +133,6 @@ func (e http2headerFieldValueError) Error() string { } var ( - http2errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers") - http2errPseudoAfterRegular = errors.New("pseudo header field after regular") + errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers") + errPseudoAfterRegular = errors.New("pseudo header field after regular") ) diff --git a/h2_frame.go b/h2_frame.go index feb0decc..2d60ea9b 100644 --- a/h2_frame.go +++ b/h2_frame.go @@ -1517,7 +1517,7 @@ func (mh *http2MetaHeadersFrame) checkPseudos() error { } } if isRequest && isResponse { - return http2errMixPseudoHeaderTypes + return errMixPseudoHeaderTypes } return nil } @@ -1559,7 +1559,7 @@ func (h2f *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (* isPseudo := strings.HasPrefix(hf.Name, ":") if isPseudo { if sawRegular { - invalid = http2errPseudoAfterRegular + invalid = errPseudoAfterRegular } } else { sawRegular = true diff --git a/h2_server_test.go b/h2_server_test.go index 1f9f8465..071cc5e4 100644 --- a/h2_server_test.go +++ b/h2_server_test.go @@ -3403,8 +3403,8 @@ func (w *http2responseWriter) handlerDone() { // Push errors. var ( - http2ErrRecursivePush = errors.New("http2: recursive push not allowed") - http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") + errRecursivePush = errors.New("http2: recursive push not allowed") + errPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") ) var _ http.Pusher = (*http2responseWriter)(nil) @@ -3417,7 +3417,7 @@ func (w *http2responseWriter) Push(target string, opts *http.PushOptions) error // No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream." // http://tools.ietf.org/html/rfc7540#section-6.6 if st.isPushed() { - return http2ErrRecursivePush + return errRecursivePush } if opts == nil { @@ -3549,7 +3549,7 @@ func (sc *http2serverConn) startPush(msg *http2startPushRequest) { } // http://tools.ietf.org/html/rfc7540#section-6.5.2. if sc.curPushedStreams+1 > sc.clientMaxStreams { - return 0, http2ErrPushLimitReached + return 0, errPushLimitReached } // http://tools.ietf.org/html/rfc7540#section-5.1.1. @@ -3558,7 +3558,7 @@ func (sc *http2serverConn) startPush(msg *http2startPushRequest) { // frame so that the client is forced to open a new connection for new streams. if sc.maxPushPromiseID+2 >= 1<<31 { sc.startGracefulShutdownInternal() - return 0, http2ErrPushLimitReached + return 0, errPushLimitReached } sc.maxPushPromiseID += 2 promisedID := sc.maxPushPromiseID @@ -3900,9 +3900,8 @@ func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []by EndStream: w.endStream, EndHeaders: lastFrag, }) - } else { - return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) } + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) } // writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. diff --git a/h2_transport.go b/h2_transport.go index 57d18d5b..06d8052d 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -524,9 +524,9 @@ func (t *http2Transport) CloseIdleConnections() { } var ( - http2errClientConnClosed = errors.New("http2: client conn is closed") - http2errClientConnUnusable = errors.New("http2: client conn not usable") - http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") + errClientConnClosed = errors.New("http2: client conn is closed") + errClientConnUnusable = errors.New("http2: client conn not usable") + errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") ) // shouldRetryRequest is called by RoundTrip when a request fails to get @@ -558,7 +558,7 @@ func http2shouldRetryRequest(req *http.Request, err error) (*http.Request, error // The Request.Body can't reset back to the beginning, but we // don't seem to have started to read from it yet, so reuse // the request directly. - if err == http2errClientConnUnusable { + if err == errClientConnUnusable { return req, nil } @@ -566,11 +566,11 @@ func http2shouldRetryRequest(req *http.Request, err error) (*http.Request, error } func http2canRetryError(err error) bool { - if err == http2errClientConnUnusable || err == http2errClientConnGotGoAway { + if err == errClientConnUnusable || err == errClientConnGotGoAway { return true } if se, ok := err.(http2StreamError); ok { - if se.Code == http2ErrCodeProtocol && se.Cause == http2errFromPeer { + if se.Code == http2ErrCodeProtocol && se.Cause == errFromPeer { // See golang/go#47635, golang/go#42777 return true } @@ -758,7 +758,7 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { last := f.LastStreamID for streamID, cs := range cc.streams { if streamID > last { - cs.abortStreamLocked(http2errClientConnGotGoAway) + cs.abortStreamLocked(errClientConnGotGoAway) } } } @@ -997,7 +997,7 @@ func (cc *http2ClientConn) closeForLostPing() error { // errRequestCanceled is a copy of net/http's errRequestCanceled because it's not // exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. -var http2errRequestCanceled = errors.New("net/http: request canceled") +var errRequestCanceled = errors.New("net/http: request canceled") func http2commaSeparatedTrailers(req *http.Request) (string, error) { keys := make([]string, 0, len(req.Trailer)) @@ -1093,7 +1093,7 @@ func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) case <-ctx.Done(): return ctx.Err() case <-cs.reqCancel: - return http2errRequestCanceled + return errRequestCanceled } } @@ -1145,8 +1145,8 @@ func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) cs.abortStream(err) return nil, err case <-cs.reqCancel: - cs.abortStream(http2errRequestCanceled) - return nil, http2errRequestCanceled + cs.abortStream(errRequestCanceled) + return nil, errRequestCanceled } } } @@ -1183,7 +1183,7 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { select { case cc.reqHeaderMu <- struct{}{}: case <-cs.reqCancel: - return http2errRequestCanceled + return errRequestCanceled case <-ctx.Done(): return ctx.Err() } @@ -1272,7 +1272,7 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { case <-ctx.Done(): err = ctx.Err() case <-cs.reqCancel: - err = http2errRequestCanceled + err = errRequestCanceled } timer.Stop() if err != nil { @@ -1281,7 +1281,7 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { } } if err = cs.writeRequestBody(req, bodyDumps); err != nil { - if err != http2errStopReqBodyWrite { + if err != errStopReqBodyWrite { http2traceWroteRequest(cs.trace, err) return err } @@ -1311,7 +1311,7 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { case <-cs.peerClosed: return nil case <-respHeaderTimer: - return http2errTimeout + return errH2Timeout case <-respHeaderRecv: respHeaderRecv = nil respHeaderTimer = nil // keep waiting for END_STREAM @@ -1320,7 +1320,7 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { case <-ctx.Done(): return ctx.Err() case <-cs.reqCancel: - return http2errRequestCanceled + return errRequestCanceled } } } @@ -1339,7 +1339,7 @@ func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dumps []*d case <-ctx.Done(): return ctx.Err() case <-cs.reqCancel: - return http2errRequestCanceled + return errRequestCanceled default: } @@ -1406,7 +1406,7 @@ func (cs *http2clientStream) cleanupWriteRequest(err error) { cs.abortStream(err) // possibly redundant, but harmless if cs.sentHeaders { if se, ok := err.(http2StreamError); ok { - if se.Cause != http2errFromPeer { + if se.Cause != errFromPeer { cc.writeStreamReset(cs.ID, se.Code, err) } } else { @@ -1418,7 +1418,7 @@ func (cs *http2clientStream) cleanupWriteRequest(err error) { if cs.sentHeaders && !cs.sentEndStream { cc.writeStreamReset(cs.ID, http2ErrCodeNo, nil) } - cs.bufPipe.CloseWithError(http2errRequestCanceled) + cs.bufPipe.CloseWithError(errRequestCanceled) } if cs.ID != 0 { cc.forgetStreamID(cs.ID) @@ -1440,7 +1440,7 @@ func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) e for { cc.lastActive = time.Now() if cc.closed || !cc.canTakeNewRequestLocked() { - return http2errClientConnUnusable + return errClientConnUnusable } cc.lastIdle = time.Time{} if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) { @@ -1486,12 +1486,12 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFram // internal error values; they don't escape to callers var ( // abort request body write; don't send cancel - http2errStopReqBodyWrite = errors.New("http2: aborting request body write") + errStopReqBodyWrite = errors.New("http2: aborting request body write") // abort request body write, but send stream reset of cancel. - http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request") + errStopReqBodyWriteAndCancel = errors.New("http2: canceling request") - http2errReqBodyTooLong = errors.New("http2: request body larger than specified content length") + errReqBodyTooLong = errors.New("http2: request body larger than specified content length") ) // frameScratchBufferLen returns the length of a buffer to use for @@ -1573,7 +1573,7 @@ func (cs *http2clientStream) writeRequestBody(req *http.Request, dumps []*dumper remainLen -= int64(n1) } if remainLen < 0 { - err = http2errReqBodyTooLong + err = errReqBodyTooLong return err } } @@ -1583,7 +1583,7 @@ func (cs *http2clientStream) writeRequestBody(req *http.Request, dumps []*dumper cc.mu.Unlock() switch { case bodyClosed: - return http2errStopReqBodyWrite + return errStopReqBodyWrite case err == io.EOF: sawEOF = true err = nil @@ -1672,10 +1672,10 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er defer cc.mu.Unlock() for { if cc.closed { - return 0, http2errClientConnClosed + return 0, errClientConnClosed } if cs.reqBodyClosed { - return 0, http2errStopReqBodyWrite + return 0, errStopReqBodyWrite } select { case <-cs.abort: @@ -1683,7 +1683,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er case <-ctx.Done(): return 0, ctx.Err() case <-cs.reqCancel: - return 0, http2errRequestCanceled + return 0, errRequestCanceled default: } if a := cs.flow.available(); a > 0 { @@ -1702,13 +1702,13 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er } } -var http2errNilRequestURL = errors.New("http2: Request.URI is nil") +var errNilRequestURL = errors.New("http2: Request.URI is nil") // requires cc.wmu be held. func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64, dumps []*dumper) ([]byte, error) { cc.hbuf.Reset() if req.URL == nil { - return nil, http2errNilRequestURL + return nil, errNilRequestURL } host := req.Host @@ -1729,9 +1729,8 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, if !http2validPseudoPath(path) { if req.URL.Opaque != "" { return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) - } else { - return nil, fmt.Errorf("invalid request :path %q", orig) } + return nil, fmt.Errorf("invalid request :path %q", orig) } } } @@ -1850,7 +1849,7 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, }) if hlSize > cc.peerMaxHeaderListSize { - return nil, http2errRequestHeaderListSize + return nil, errRequestHeaderListSize } trace := httptrace.ContextClientTrace(req.Context()) @@ -1929,7 +1928,7 @@ func (cc *http2ClientConn) encodeTrailers(trailer http.Header, dumps []*dumper) } } if hlSize > cc.peerMaxHeaderListSize { - return nil, http2errRequestHeaderListSize + return nil, errRequestHeaderListSize } writeHeader := cc.writeHeader @@ -2270,7 +2269,7 @@ func http2foreachHeaderElement(v string, fn func(string)) { // frame (currently only used for 1xx responses). func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http2MetaHeadersFrame) (*http.Response, error) { if f.Truncated { - return nil, http2errResponseHeaderListSize + return nil, errResponseHeaderListSize } status := f.PseudoValue("status") @@ -2485,7 +2484,7 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { return } -var http2errClosedResponseBody = errors.New("http2: response body closed") +var errClosedResponseBody = errors.New("http2: response body closed") func (b http2transportResponseBody) Close() error { cs := b.cs @@ -2511,8 +2510,8 @@ func (b http2transportResponseBody) Close() error { cc.wmu.Unlock() } - cs.bufPipe.BreakWithError(http2errClosedResponseBody) - cs.abortStream(http2errClosedResponseBody) + cs.bufPipe.BreakWithError(errClosedResponseBody) + cs.abortStream(errClosedResponseBody) select { case <-cs.donec: @@ -2522,7 +2521,7 @@ func (b http2transportResponseBody) Close() error { // Don't treat this as an error. return nil case <-cs.reqCancel: - return http2errRequestCanceled + return errRequestCanceled } return nil } @@ -2806,7 +2805,7 @@ func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) er return nil } serr := http2streamError(cs.ID, f.ErrCode) - serr.Cause = http2errFromPeer + serr.Cause = errFromPeer if f.ErrCode == http2ErrCodeProtocol { rl.cc.SetDoNotReuse() } @@ -2907,8 +2906,8 @@ func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, } var ( - http2errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") - http2errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit") + errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") + errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit") ) func (cc *http2ClientConn) logf(format string, args ...interface{}) { @@ -2981,10 +2980,6 @@ func (gz *http2gzipReader) Close() error { return gz.body.Close() } -type http2errorReader struct{ err error } - -func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } - // isConnectionCloseRequest reports whether req should use its own // connection for a single request and then close the connection. func http2isConnectionCloseRequest(req *http.Request) bool { diff --git a/h2_transport_test.go b/h2_transport_test.go index 2a62d601..252d355f 100644 --- a/h2_transport_test.go +++ b/h2_transport_test.go @@ -1442,37 +1442,37 @@ func TestTransportReceiveUndeclaredTrailer(t *testing.T) { ct.run() } -func TestTransportInvalidTrailer_Pseudo1(t *testing.T) { - testTransportInvalidTrailer_Pseudo(t, oneHeader) +func TestTransportInvalidTrailerPseudo1(t *testing.T) { + testTransportInvalidTrailerPseudo(t, oneHeader) } -func TestTransportInvalidTrailer_Pseudo2(t *testing.T) { - testTransportInvalidTrailer_Pseudo(t, splitHeader) +func TestTransportInvalidTrailerPseudo2(t *testing.T) { + testTransportInvalidTrailerPseudo(t, splitHeader) } -func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) { +func testTransportInvalidTrailerPseudo(t *testing.T, trailers headerType) { testInvalidTrailer(t, trailers, http2pseudoHeaderError(":colon"), func(enc *hpack.Encoder) { enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"}) enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) }) } -func TestTransportInvalidTrailer_Capital1(t *testing.T) { - testTransportInvalidTrailer_Capital(t, oneHeader) +func TestTransportInvalidTrailerCapital1(t *testing.T) { + testTransportInvalidTrailerCapital(t, oneHeader) } -func TestTransportInvalidTrailer_Capital2(t *testing.T) { - testTransportInvalidTrailer_Capital(t, splitHeader) +func TestTransportInvalidTrailerCapital2(t *testing.T) { + testTransportInvalidTrailerCapital(t, splitHeader) } -func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) { +func testTransportInvalidTrailerCapital(t *testing.T, trailers headerType) { testInvalidTrailer(t, trailers, http2headerFieldNameError("Capital"), func(enc *hpack.Encoder) { enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"}) }) } -func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) { +func TestTransportInvalidTrailerEmptyFieldName(t *testing.T) { testInvalidTrailer(t, oneHeader, http2headerFieldNameError(""), func(enc *hpack.Encoder) { enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"}) }) } -func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) { +func TestTransportInvalidTrailerBinaryFieldValue(t *testing.T) { testInvalidTrailer(t, oneHeader, http2headerFieldValueError("has\nnewline"), func(enc *hpack.Encoder) { enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"}) }) @@ -1801,26 +1801,26 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { req = newRequest() req.Header = make(http.Header) padHeaders(t, req.Header, peerSize, filler) - checkRoundTrip(req, http2errRequestHeaderListSize, "Headers over limit") + checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit") // Push trailers over the limit. req = newRequest() req.Trailer = make(http.Header) padHeaders(t, req.Trailer, peerSize+1, filler) - checkRoundTrip(req, http2errRequestHeaderListSize, "Trailers over limit") + checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit") // Send headers with a single large value. req = newRequest() filler = strings.Repeat("*", int(peerSize)) req.Header = make(http.Header) req.Header.Set("Big", filler) - checkRoundTrip(req, http2errRequestHeaderListSize, "Single large header") + checkRoundTrip(req, errRequestHeaderListSize, "Single large header") // Send trailers with a single large value. req = newRequest() req.Trailer = make(http.Header) req.Trailer.Set("Big", filler) - checkRoundTrip(req, http2errRequestHeaderListSize, "Single large trailer") + checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer") } func TestTransportChecksResponseHeaderListSize(t *testing.T) { @@ -1831,7 +1831,7 @@ func TestTransportChecksResponseHeaderListSize(t *testing.T) { if e, ok := err.(http2StreamError); ok { err = e.Cause } - if err != http2errResponseHeaderListSize { + if err != errResponseHeaderListSize { size := int64(0) if res != nil { res.Body.Close() @@ -1841,7 +1841,7 @@ func TestTransportChecksResponseHeaderListSize(t *testing.T) { } } } - return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want http2errResponseHeaderListSize", err, size) + return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size) } return nil } @@ -3368,7 +3368,7 @@ func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { reqHeaderMu: make(chan struct{}, 1), } _, err := cc.RoundTrip(req) - if err != http2errClientConnUnusable { + if err != errClientConnUnusable { t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err) } slurp, err := ioutil.ReadAll(req.Body) @@ -4598,7 +4598,7 @@ func TestTransportUsesGetBodyWhenPresent(t *testing.T) { }, } - req2, err := http2shouldRetryRequest(req, http2errClientConnUnusable) + req2, err := http2shouldRetryRequest(req, errClientConnUnusable) if err != nil { t.Fatal(err) } @@ -4841,8 +4841,8 @@ func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunk req, _ := http.NewRequest("POST", st.ts.URL, body) req.ContentLength = contentLen _, err := tr.RoundTrip(req) - if err != http2errReqBodyTooLong { - t.Fatalf("expected %v, got %v", http2errReqBodyTooLong, err) + if err != errReqBodyTooLong { + t.Fatalf("expected %v, got %v", errReqBodyTooLong, err) } } @@ -5607,8 +5607,8 @@ func TestClientConnReservations(t *testing.T) { if n != http2initialMaxConcurrentStreams { t.Errorf("did %v reservations; want %v", n, http2initialMaxConcurrentStreams) } - if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, http2errNilRequestURL) { - t.Fatalf("RoundTrip error = %v; want http2errNilRequestURL", err) + if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) { + t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err) } n2 := 0 for n2 <= 5 && cc.ReserveNewRequest() { diff --git a/request.go b/request.go index a1991bfe..9ee90d76 100644 --- a/request.go +++ b/request.go @@ -910,6 +910,8 @@ func (r *Request) EnableDumpToFile(filename string) *Request { return r.EnableDump() } +// SetDumpOptions is a global wrapper methods which delegated +// to the default client, create a request and SetDumpOptions for request. func SetDumpOptions(opt *DumpOptions) *Request { return defaultClient.R().SetDumpOptions(opt) } diff --git a/transport.go b/transport.go index e706bc97..902523e7 100644 --- a/transport.go +++ b/transport.go @@ -39,10 +39,13 @@ import ( "golang.org/x/net/http/httpguts" ) +// HttpVersion represents http version. type HttpVersion string const ( + // HTTP1 represents "HTTP/1.1" HTTP1 HttpVersion = "1.1" + // HTTP2 represents "HTTP/2.0" HTTP2 HttpVersion = "2" ) @@ -1438,20 +1441,20 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { // Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS // tunnel, this function establishes a nested TLS session inside the encrypted channel. // The remote endpoint's name may be overridden by TLSClientConfig.ServerName. -func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace, forProxy bool) error { +func (pc *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace, forProxy bool) error { // Initiate TLS and check remote host name against certificate. - cfg := cloneTLSConfig(pconn.t.TLSClientConfig) + cfg := cloneTLSConfig(pc.t.TLSClientConfig) if cfg.ServerName == "" { cfg.ServerName = name } - if pconn.cacheKey.onlyH1 { + if pc.cacheKey.onlyH1 { cfg.NextProtos = nil } - plainConn := pconn.conn + plainConn := pc.conn tlsConn := tls.Client(plainConn, cfg) errc := make(chan error, 2) var timer *time.Timer // for canceling TLS handshake - if d := pconn.t.TLSHandshakeTimeout; d != 0 { + if d := pc.t.TLSHandshakeTimeout; d != 0 { timer = time.AfterFunc(d, func() { errc <- tlsHandshakeTimeoutError{} }) @@ -1477,9 +1480,9 @@ func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptr if trace != nil && trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(cs, nil) } - pconn.tlsState = &cs - pconn.conn = tlsConn - if !forProxy && pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { + pc.tlsState = &cs + pc.conn = tlsConn + if !forProxy && pc.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { return newHttp2NotSupportedError(cs.NegotiatedProtocol) } return nil @@ -2694,9 +2697,6 @@ func (e *httpError) Temporary() bool { return true } var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} -// errRequestCanceled is set to be identical to the one from h2 to facilitate -// testing. -var errRequestCanceled = http2errRequestCanceled var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? func nop() {} diff --git a/transport_test.go b/transport_test.go index 38b84a46..47d921b6 100644 --- a/transport_test.go +++ b/transport_test.go @@ -84,7 +84,7 @@ func (t *Transport) IdleConnStrsForTesting() []string { return ret } -func (t *Transport) IdleConnStrsForTesting_h2() []string { +func (t *Transport) IdleConnStrsForTestingH2() []string { var ret []string noDialPool := t.t2.ConnPool.(http2noDialClientConnPool) pool := noDialPool.http2clientConnPool @@ -2613,7 +2613,7 @@ func TestCancelRequestWithChannel(t *testing.T) { body, err := io.ReadAll(res.Body) d := time.Since(t0) - if err != http2errRequestCanceled { + if err != errRequestCanceled { t.Errorf("Body.Read error = %v; want errRequestCanceled", err) } if string(body) != "Hello" { @@ -2716,8 +2716,8 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) { if err == nil { t.Fatalf("unexpected success from RoundTrip") } - if err != http2errRequestCanceled { - t.Errorf("RoundTrip error = %v; want http2errRequestCanceled", err) + if err != errRequestCanceled { + t.Errorf("RoundTrip error = %v; want errRequestCanceled", err) } } @@ -4037,7 +4037,7 @@ func TestTransportDialCancelRace(t *testing.T) { }) defer SetEnterRoundTripHook(nil) res, err := tr.RoundTrip(req) - if err != http2errRequestCanceled { + if err != errRequestCanceled { t.Errorf("expected canceled request error; got %v", err) if err == nil { res.Body.Close() @@ -4357,16 +4357,16 @@ func TestNoCrashReturningTransportAltConn(t *testing.T) { wg.Wait() } -func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) { - testTransportReuseConnection_Gzip(t, true) +func TestTransportReuseConnectionGzipChunked(t *testing.T) { + testTransportReuseConnectionGzip(t, true) } -func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { - testTransportReuseConnection_Gzip(t, false) +func TestTransportReuseConnectionGzipContentLength(t *testing.T) { + testTransportReuseConnectionGzip(t, false) } // Make sure we re-use underlying TCP connection for gzipped responses too. -func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { +func testTransportReuseConnectionGzip(t *testing.T, chunked bool) { setParallel(t) defer afterTest(t) addr := make(chan string, 2) From fbb227269b5745bbc4cd050662828cbbf0f406ae Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Feb 2022 21:39:52 +0800 Subject: [PATCH 412/843] fix more golint --- dump.go | 9 +- h2_databuffer.go | 4 +- h2_databuffer_test.go | 2 +- h2_frame.go | 42 ++--- h2_frame_test.go | 10 +- h2_server_test.go | 392 ++++++++++++++++++------------------------ h2_transport.go | 4 +- h2_transport_test.go | 4 +- req_test.go | 2 +- request.go | 2 +- transport.go | 9 +- 11 files changed, 210 insertions(+), 270 deletions(-) diff --git a/dump.go b/dump.go index 4c2ba354..250b3f89 100644 --- a/dump.go +++ b/dump.go @@ -16,6 +16,7 @@ type DumpOptions struct { Async bool } +// Clone return a copy of DumpOptions func (do *DumpOptions) Clone() *DumpOptions { if do == nil { return nil @@ -140,12 +141,16 @@ func (d *dumper) Start() { } } -func getDumpers(dump *dumper, ctx context.Context) []*dumper { +type dumperKeyType int + +const dumperKey dumperKeyType = iota + +func getDumpers(ctx context.Context, dump *dumper) []*dumper { dumps := []*dumper{} if dump != nil { dumps = append(dumps, dump) } - if d, ok := ctx.Value("_dumper").(*dumper); ok { + if d, ok := ctx.Value(dumperKey).(*dumper); ok { dumps = append(dumps, d) } return dumps diff --git a/h2_databuffer.go b/h2_databuffer.go index 26f2f877..be110504 100644 --- a/h2_databuffer.go +++ b/h2_databuffer.go @@ -70,13 +70,13 @@ type http2dataBuffer struct { expected int64 // we expect at least this many bytes in future Write calls (ignored if <= 0) } -var http2errReadEmpty = errors.New("read from empty dataBuffer") +var errReadEmpty = errors.New("read from empty dataBuffer") // Read copies bytes from the buffer into p. // It is an error to read when no data is available. func (b *http2dataBuffer) Read(p []byte) (int, error) { if b.size == 0 { - return 0, http2errReadEmpty + return 0, errReadEmpty } var ntotal int for len(p) > 0 && b.size > 0 { diff --git a/h2_databuffer_test.go b/h2_databuffer_test.go index a9d4f09b..b2da7459 100644 --- a/h2_databuffer_test.go +++ b/h2_databuffer_test.go @@ -52,7 +52,7 @@ func testDataBuffer(t *testing.T, wantBytes []byte, setup func(t *testing.T) *ht for { n, err := b.Read(buf) gotRead.Write(buf[:n]) - if err == http2errReadEmpty { + if err == errReadEmpty { break } if err != nil { diff --git a/h2_frame.go b/h2_frame.go index 2d60ea9b..364a770c 100644 --- a/h2_frame.go +++ b/h2_frame.go @@ -358,7 +358,7 @@ func (h2f *http2Framer) endWrite() error { // the space previously reserved for it. Abuse append. length := len(h2f.wbuf) - http2frameHeaderLen if length >= (1 << 24) { - return http2ErrFrameTooLarge + return errFrameTooLarge } _ = append(h2f.wbuf[:0], byte(length>>16), @@ -473,9 +473,9 @@ func (h2f *http2Framer) ErrorDetail() error { return h2f.errDetail } -// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer +// errFrameTooLarge is returned from Framer.ReadFrame when the peer // sends a frame that is larger than declared with SetMaxReadFrameSize. -var http2ErrFrameTooLarge = errors.New("http2: frame too large") +var errFrameTooLarge = errors.New("http2: frame too large") // terminalReadFrameError reports whether err is an unrecoverable // error from ReadFrame and no other frames should be read. @@ -490,7 +490,7 @@ func http2terminalReadFrameError(err error) bool { // until the next call to ReadFrame. // // If the frame is larger than previously set with SetMaxReadFrameSize, the -// returned error is ErrFrameTooLarge. Other errors may be of type +// returned error is errFrameTooLarge. Other errors may be of type // ConnectionError, StreamError, or anything else from the underlying // reader. func (h2f *http2Framer) ReadFrame() (http2Frame, error) { @@ -503,7 +503,7 @@ func (h2f *http2Framer) ReadFrame() (http2Frame, error) { return nil, err } if fh.Length > h2f.maxReadSize { - return nil, http2ErrFrameTooLarge + return nil, errFrameTooLarge } payload := h2f.getReadBuf(fh.Length) if _, err := io.ReadFull(h2f.r, payload); err != nil { @@ -525,7 +525,7 @@ func (h2f *http2Framer) ReadFrame() (http2Frame, error) { if fh.Type == http2FrameHeaders && h2f.ReadMetaHeaders != nil { var dumps []*dumper if h2f.cc != nil && h2f.cc.t.t1 != nil { - dumps = getDumpers(h2f.cc.t.t1.dump, h2f.cc.currentRequest.Context()) + dumps = getDumpers(h2f.cc.currentRequest.Context(), h2f.cc.t.t1.dump) } if len(dumps) > 0 { dd := []*dumper{} @@ -651,10 +651,10 @@ func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError fu } var ( - http2errStreamID = errors.New("invalid stream ID") - http2errDepStreamID = errors.New("invalid dependent stream ID") - http2errPadLength = errors.New("pad length too large") - http2errPadBytes = errors.New("padding bytes must all be zeros unless AllowIllegalWrites is enabled") + errStreamID = errors.New("invalid stream ID") + errDepStreamID = errors.New("invalid dependent stream ID") + errPadLength = errors.New("pad length too large") + errPadBytes = errors.New("padding bytes must all be zeros unless AllowIllegalWrites is enabled") ) func http2validStreamIDOrZero(streamID uint32) bool { @@ -685,17 +685,17 @@ func (h2f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) // and to not call other Write methods concurrently. func (h2f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { - return http2errStreamID + return errStreamID } if len(pad) > 0 { if len(pad) > 255 { - return http2errPadLength + return errPadLength } if !h2f.AllowIllegalWrites { for _, b := range pad { if b != 0 { // "Padding octets MUST be set to zero when sending." - return http2errPadBytes + return errPadBytes } } } @@ -1109,7 +1109,7 @@ type http2HeadersFrameParam struct { // It is the caller's responsibility to not call other Write methods concurrently. func (h2f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { if !http2validStreamID(p.StreamID) && !h2f.AllowIllegalWrites { - return http2errStreamID + return errStreamID } var flags http2Flags if p.PadLength != 0 { @@ -1131,7 +1131,7 @@ func (h2f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { if !p.Priority.IsZero() { v := p.Priority.StreamDep if !http2validStreamIDOrZero(v) && !h2f.AllowIllegalWrites { - return http2errDepStreamID + return errDepStreamID } if p.Priority.Exclusive { v |= 1 << 31 @@ -1199,10 +1199,10 @@ func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError // It is the caller's responsibility to not call other Write methods concurrently. func (h2f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error { if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { - return http2errStreamID + return errStreamID } if !http2validStreamIDOrZero(p.StreamDep) { - return http2errDepStreamID + return errDepStreamID } h2f.startWrite(http2FramePriority, 0, streamID) v := p.StreamDep @@ -1239,7 +1239,7 @@ func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, countErro // It is the caller's responsibility to not call other Write methods concurrently. func (h2f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error { if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { - return http2errStreamID + return errStreamID } h2f.startWrite(http2FrameRSTStream, 0, streamID) h2f.writeUint32(uint32(code)) @@ -1276,7 +1276,7 @@ func (f *http2ContinuationFrame) HeadersEnded() bool { // It is the caller's responsibility to not call other Write methods concurrently. func (h2f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error { if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { - return http2errStreamID + return errStreamID } var flags http2Flags if endHeaders { @@ -1375,7 +1375,7 @@ type http2PushPromiseParam struct { // It is the caller's responsibility to not call other Write methods concurrently. func (h2f *http2Framer) WritePushPromise(p http2PushPromiseParam) error { if !http2validStreamID(p.StreamID) && !h2f.AllowIllegalWrites { - return http2errStreamID + return errStreamID } var flags http2Flags if p.PadLength != 0 { @@ -1389,7 +1389,7 @@ func (h2f *http2Framer) WritePushPromise(p http2PushPromiseParam) error { h2f.writeByte(p.PadLength) } if !http2validStreamID(p.PromiseID) && !h2f.AllowIllegalWrites { - return http2errStreamID + return errStreamID } h2f.writeUint32(p.PromiseID) h2f.wbuf = append(h2f.wbuf, p.BlockFragment...) diff --git a/h2_frame_test.go b/h2_frame_test.go index c5f2ed6c..36874ff8 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -359,13 +359,13 @@ func TestWriteInvalidStreamDep(t *testing.T) { StreamDep: 1 << 31, }, }) - if err != http2errDepStreamID { - t.Errorf("header error = %v; want %q", err, http2errDepStreamID) + if err != errDepStreamID { + t.Errorf("header error = %v; want %q", err, errDepStreamID) } err = fr.WritePriority(2, http2PriorityParam{StreamDep: 1 << 31}) - if err != http2errDepStreamID { - t.Errorf("priority error = %v; want %q", err, http2errDepStreamID) + if err != errDepStreamID { + t.Errorf("priority error = %v; want %q", err, errDepStreamID) } } @@ -670,7 +670,7 @@ func TestWriteTooLargeFrame(t *testing.T) { fr.startWrite(0, 1, 1) fr.writeBytes(make([]byte, 1<<24)) err := fr.endWrite() - if err != http2ErrFrameTooLarge { + if err != errFrameTooLarge { t.Errorf("endWrite = %v; want errFrameTooLarge", err) } } diff --git a/h2_server_test.go b/h2_server_test.go index 071cc5e4..45a27f30 100644 --- a/h2_server_test.go +++ b/h2_server_test.go @@ -38,160 +38,143 @@ import ( // https://www.iana.org/assignments/tls-parameters/tls-parameters.txt const ( - http2cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000 - http2cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001 - http2cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002 - http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003 - http2cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004 - http2cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 - http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006 - http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007 - http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008 - http2cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009 - http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A - http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B - http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C - http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D - http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E - http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F - http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010 - http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011 - http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012 - http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013 - http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014 - http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015 - http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016 - http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017 - http2cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018 - http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019 - http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A - http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B - // Reserved uint16 = 0x001C-1D - http2cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E - http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F - http2cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020 - http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021 - http2cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022 - http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023 - http2cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024 - http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025 - http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026 - http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027 - http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028 - http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029 - http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A - http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B - http2cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E - http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F - http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030 - http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031 - http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032 - http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033 - http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034 - http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 - http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036 - http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037 - http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038 - http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039 - http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A - http2cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B - http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C - http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D - http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E - http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F - http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040 - http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042 - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043 - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044 - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046 - // Reserved uint16 = 0x0047-4F - // Reserved uint16 = 0x0050-58 - // Reserved uint16 = 0x0059-5C - // Unassigned uint16 = 0x005D-5F - // Reserved uint16 = 0x0060-66 - http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067 - http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068 - http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069 - http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A - http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B - http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C - http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D - // Unassigned uint16 = 0x006E-83 - http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085 - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086 - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087 - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089 - http2cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A - http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B - http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C - http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D - http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E - http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F - http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090 - http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091 - http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092 - http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093 - http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094 - http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095 - http2cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096 - http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097 - http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098 - http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099 - http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A - http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B - http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C - http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D - http2cipher_TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009E - http2cipher_TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009F - http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0 - http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1 - http2cipher_TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A2 - http2cipher_TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A3 - http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4 - http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5 - http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6 - http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7 - http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8 - http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9 - http2cipher_TLS_DHE_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AA - http2cipher_TLS_DHE_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AB - http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC - http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD - http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE - http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF - http2cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0 - http2cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1 - http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2 - http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3 - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4 - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5 - http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6 - http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7 - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8 - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9 - http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF - http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1 - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2 - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3 - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5 - // Unassigned uint16 = 0x00C6-FE - http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF - // Unassigned uint16 = 0x01-55,* - http2cipher_TLS_FALLBACK_SCSV uint16 = 0x5600 - // Unassigned uint16 = 0x5601 - 0xC000 + http2cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000 + http2cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001 + http2cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002 + http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003 + http2cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004 + http2cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006 + http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007 + http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008 + http2cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009 + http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A + http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B + http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C + http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D + http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E + http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F + http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010 + http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011 + http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012 + http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013 + http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014 + http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015 + http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016 + http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017 + http2cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018 + http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019 + http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A + http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B + http2cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F + http2cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020 + http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021 + http2cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022 + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023 + http2cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024 + http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025 + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026 + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027 + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028 + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029 + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B + http2cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030 + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031 + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033 + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034 + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036 + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037 + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038 + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039 + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A + http2cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067 + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068 + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069 + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089 + http2cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A + http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D + http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E + http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091 + http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092 + http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093 + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094 + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095 + http2cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096 + http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097 + http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098 + http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099 + http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A + http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B + http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C + http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D + http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0 + http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1 + http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4 + http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5 + http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6 + http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7 + http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8 + http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9 + http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC + http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF + http2cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0 + http2cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1 + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3 + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4 + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5 + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6 + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7 + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8 + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5 + http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA uint16 = 0xC001 http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA uint16 = 0xC002 http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC003 @@ -234,12 +217,9 @@ const ( http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC028 http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC029 http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC02A - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02B - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02C http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02D http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02E http2cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02F - http2cipher_TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC030 http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC031 http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC032 http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA uint16 = 0xC033 @@ -273,22 +253,14 @@ const ( http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04F http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC050 http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC051 - http2cipher_TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC052 - http2cipher_TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC053 http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC054 http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC055 - http2cipher_TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC056 - http2cipher_TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC057 http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC058 http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC059 http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05A http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05B - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05C - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05D http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05E http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05F - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC060 - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC061 http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC062 http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC063 http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC064 @@ -299,8 +271,6 @@ const ( http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC069 http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06A http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06B - http2cipher_TLS_DHE_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06C - http2cipher_TLS_DHE_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06D http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06E http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06F http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC070 @@ -315,28 +285,18 @@ const ( http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC079 http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07A http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07B - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07C - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07D http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07E http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07F - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC080 - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC081 http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC082 http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC083 http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC084 http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC085 - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC086 - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC087 http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC088 http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC089 - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08A - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08B http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08C http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08D http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08E http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08F - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC090 - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC091 http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC092 http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC093 http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC094 @@ -349,34 +309,12 @@ const ( http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC09B http2cipher_TLS_RSA_WITH_AES_128_CCM uint16 = 0xC09C http2cipher_TLS_RSA_WITH_AES_256_CCM uint16 = 0xC09D - http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM uint16 = 0xC09E - http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM uint16 = 0xC09F http2cipher_TLS_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A0 http2cipher_TLS_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A1 - http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A2 - http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A3 http2cipher_TLS_PSK_WITH_AES_128_CCM uint16 = 0xC0A4 http2cipher_TLS_PSK_WITH_AES_256_CCM uint16 = 0xC0A5 - http2cipher_TLS_DHE_PSK_WITH_AES_128_CCM uint16 = 0xC0A6 - http2cipher_TLS_DHE_PSK_WITH_AES_256_CCM uint16 = 0xC0A7 http2cipher_TLS_PSK_WITH_AES_128_CCM_8 uint16 = 0xC0A8 http2cipher_TLS_PSK_WITH_AES_256_CCM_8 uint16 = 0xC0A9 - http2cipher_TLS_PSK_DHE_WITH_AES_128_CCM_8 uint16 = 0xC0AA - http2cipher_TLS_PSK_DHE_WITH_AES_256_CCM_8 uint16 = 0xC0AB - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM uint16 = 0xC0AC - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM uint16 = 0xC0AD - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 uint16 = 0xC0AE - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8 uint16 = 0xC0AF - // Unassigned uint16 = 0xC0B0-FF - // Unassigned uint16 = 0xC1-CB,* - // Unassigned uint16 = 0xCC00-A7 - http2cipher_TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA8 - http2cipher_TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA9 - http2cipher_TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAA - http2cipher_TLS_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAB - http2cipher_TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAC - http2cipher_TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAD - http2cipher_TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAE ) // isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. @@ -679,10 +617,10 @@ const ( ) var ( - http2errClientDisconnected = errors.New("client disconnected") - http2errClosedBody = errors.New("body closed by handler") - http2errHandlerComplete = errors.New("http2: request body closed due to handler exiting") - http2errStreamClosed = errors.New("http2: stream closed") + errClientDisconnected = errors.New("client disconnected") + errClosedBody = errors.New("body closed by handler") + errHandlerComplete = errors.New("http2: request body closed due to handler exiting") + errStreamClosed = errors.New("http2: stream closed") ) var http2responseWriterStatePool = sync.Pool{ @@ -1331,7 +1269,7 @@ func (sc *http2serverConn) condlogf(err error, format string, args ...interface{ if err == nil { return } - if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) || err == http2errPrefaceTimeout { + if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) || err == errPrefaceTimeout { // Boring, expected errors. sc.vlogf(format, args...) } else { @@ -1427,7 +1365,7 @@ func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { func (sc *http2serverConn) closeAllStreamsOnConnClose() { sc.serveG.check() for _, st := range sc.streams { - sc.closeStream(st, http2errClientDisconnected) + sc.closeStream(st, errClientDisconnected) } } @@ -1612,7 +1550,7 @@ func (sc *http2serverConn) sendServeMsg(msg interface{}) { } } -var http2errPrefaceTimeout = errors.New("timeout waiting for client preface") +var errPrefaceTimeout = errors.New("timeout waiting for client preface") // readPreface reads the ClientPreface greeting from the peer or // returns errPrefaceTimeout on timeout, or an error if the greeting @@ -1634,7 +1572,7 @@ func (sc *http2serverConn) readPreface() error { defer timer.Stop() select { case <-timer.C: - return http2errPrefaceTimeout + return errPrefaceTimeout case err := <-errc: if err == nil { if http2VerboseLogs { @@ -1672,7 +1610,7 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte case err = <-ch: frameWriteDone = true case <-sc.doneServing: - return http2errClientDisconnected + return errClientDisconnected case <-stream.cw: // If both ch and stream.cw were ready (as might // happen on the final Write after an http.Handler @@ -1685,7 +1623,7 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte case err = <-ch: frameWriteDone = true default: - return http2errStreamClosed + return errStreamClosed } } http2errChanPool.Put(ch) @@ -1710,7 +1648,7 @@ func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) erro case <-sc.doneServing: // Serve loop is gone. // Client has closed their connection to the server. - return http2errClientDisconnected + return errClientDisconnected } } @@ -1870,7 +1808,7 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { // a complete response. sc.resetStream(http2streamError(st.id, http2ErrCodeNo)) case http2stateHalfClosedRemote: - sc.closeStream(st, http2errHandlerComplete) + sc.closeStream(st, errHandlerComplete) } } else { switch v := wr.write.(type) { @@ -2015,7 +1953,7 @@ func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool sc.serveG.check() err := res.err if err != nil { - if err == http2ErrFrameTooLarge { + if err == errFrameTooLarge { sc.goAway(http2ErrCodeFrameSize) return true // goAway will close the loop } @@ -2763,14 +2701,14 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re } delete(rp.header, "Trailer") - var url_ *url.URL + var u *url.URL var requestURI string if rp.method == "CONNECT" { - url_ = &url.URL{Host: rp.authority} + u = &url.URL{Host: rp.authority} requestURI = rp.authority // mimic HTTP/1 server behavior } else { var err error - url_, err = url.ParseRequestURI(rp.path) + u, err = url.ParseRequestURI(rp.path) if err != nil { return nil, nil, sc.countError("bad_path", http2streamError(st.id, http2ErrCodeProtocol)) } @@ -2784,7 +2722,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re } req := &http.Request{ Method: rp.method, - URL: url_, + URL: u, RemoteAddr: sc.remoteAddrStr, Header: rp.header, RequestURI: requestURI, @@ -2873,9 +2811,9 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR http2errChanPool.Put(errc) return err case <-sc.doneServing: - return http2errClientDisconnected + return errClientDisconnected case <-st.cw: - return http2errStreamClosed + return errStreamClosed } } return nil @@ -2977,7 +2915,7 @@ type http2requestBody struct { func (b *http2requestBody) Close() error { if b.pipe != nil && !b.closed { - b.pipe.BreakWithError(http2errClosedBody) + b.pipe.BreakWithError(errClosedBody) } b.closed = true return nil @@ -3379,9 +3317,8 @@ func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n if dataB != nil { return rws.bw.Write(dataB) - } else { - return rws.bw.WriteString(dataS) } + return rws.bw.WriteString(dataS) } func (w *http2responseWriter) handlerDone() { @@ -3493,17 +3430,17 @@ func (w *http2responseWriter) Push(target string, opts *http.PushOptions) error select { case <-sc.doneServing: - return http2errClientDisconnected + return errClientDisconnected case <-st.cw: - return http2errStreamClosed + return errStreamClosed case sc.serveMsgCh <- msg: } select { case <-sc.doneServing: - return http2errClientDisconnected + return errClientDisconnected case <-st.cw: - return http2errStreamClosed + return errStreamClosed case err := <-msg.done: http2errChanPool.Put(msg.done) return err @@ -3526,7 +3463,7 @@ func (sc *http2serverConn) startPush(msg *http2startPushRequest) { // is in either the "open" or "half-closed (remote)" state. if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote { // responseWriter.Push checks that the stream is peer-initiated. - msg.done <- http2errStreamClosed + msg.done <- errStreamClosed return } @@ -3948,9 +3885,8 @@ func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []b BlockFragment: frag, EndHeaders: lastFrag, }) - } else { - return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) } + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) } type http2write100ContinueHeadersFrame struct { diff --git a/h2_transport.go b/h2_transport.go index 06d8052d..0a1ef6a8 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -1235,7 +1235,7 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { var dumps []*dumper if t1 := cs.cc.t.t1; t1 != nil { - dumps = getDumpers(t1.dump, req.Context()) + dumps = getDumpers(req.Context(), t1.dump) } // Past this point (where we send request headers), it is possible for @@ -2109,7 +2109,7 @@ func (cc *http2ClientConn) countReadFrameError(err error) { f("read_frame_unexpected_eof") return } - if errors.Is(err, http2ErrFrameTooLarge) { + if errors.Is(err, errFrameTooLarge) { f("read_frame_too_large") return } diff --git a/h2_transport_test.go b/h2_transport_test.go index 252d355f..7e2db8a7 100644 --- a/h2_transport_test.go +++ b/h2_transport_test.go @@ -4206,7 +4206,7 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { gotErr := <-writeErr if gotErr == nil { t.Errorf("Handler unexpectedly managed to write its entire response without getting an error") - } else if gotErr != http2errStreamClosed { + } else if gotErr != errStreamClosed { t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr) } } @@ -5879,7 +5879,7 @@ func TestCountReadFrameError(t *testing.T) { cc.countReadFrameError(err) tests.AssertContains(t, errMsg, "read_frame_unexpected_eof", true) - err = http2ErrFrameTooLarge + err = errFrameTooLarge cc.countReadFrameError(err) tests.AssertContains(t, errMsg, "read_frame_too_large", true) diff --git a/req_test.go b/req_test.go index caf896ad..6b6d3114 100644 --- a/req_test.go +++ b/req_test.go @@ -453,7 +453,7 @@ func testGlobalWrapperEnableDumps(t *testing.T) { func testGlobalWrapperEnableDump(t *testing.T, fn func(reqHeader, reqBody, respHeader, respBody *bool) *Request) { var reqHeader, reqBody, respHeader, respBody bool r := fn(&reqHeader, &reqBody, &respHeader, &respBody) - dump, ok := r.Context().Value("_dumper").(*dumper) + dump, ok := r.Context().Value(dumperKey).(*dumper) if !ok { t.Fatal("no dumper found in request context") } diff --git a/request.go b/request.go index 9ee90d76..d832b5f5 100644 --- a/request.go +++ b/request.go @@ -936,7 +936,7 @@ func EnableDump() *Request { // EnableDump enables dump, including all content for the request and response by default. func (r *Request) EnableDump() *Request { - return r.SetContext(context.WithValue(r.Context(), "_dumper", newDumper(r.getDumpOptions()))) + return r.SetContext(context.WithValue(r.Context(), dumperKey, newDumper(r.getDumpOptions()))) } // EnableDumpWithoutBody is a global wrapper methods which delegated diff --git a/transport.go b/transport.go index 902523e7..585b5461 100644 --- a/transport.go +++ b/transport.go @@ -279,7 +279,7 @@ func (t *Transport) handleResponseBody(res *http.Response, req *http.Request) { } func (t *Transport) dumpResponseBody(res *http.Response, req *http.Request) { - dumps := getDumpers(t.dump, req.Context()) + dumps := getDumpers(req.Context(), t.dump) for _, dump := range dumps { if dump.ResponseBody { res.Body = dump.WrapReadCloser(res.Body) @@ -1871,7 +1871,7 @@ func fixPragmaCacheControl(header http.Header) { // 100-continue") from the server. It returns the final non-100 one. // trace is optional. func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) { - dumps := getDumpers(pc.t.dump, req.Context()) + dumps := getDumpers(req.Context(), pc.t.dump) tp := newTextprotoReader(pc.br, dumps) resp := &http.Response{ Request: req, @@ -2253,9 +2253,8 @@ func (pc *persistConn) readLoopPeekFailLocked(peekErr error) { if is408Message(buf) { pc.closeLocked(errServerClosedIdle) return - } else { - log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", buf, peekErr) } + log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", buf, peekErr) } if peekErr == io.EOF { // common case. @@ -2501,7 +2500,7 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo } rw := w // raw writer - dumps := getDumpers(pc.t.dump, r.Context()) + dumps := getDumpers(r.Context(), pc.t.dump) for _, dump := range dumps { if dump.RequestHeader { w = dump.WrapWriter(w) From 9a15ac151ecf5ad600e5ab176f0c087f446ab3fd Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 21 Feb 2022 08:56:17 +0800 Subject: [PATCH 413/843] add awesome-go badge --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index bf209deb..a73a251c 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ Go Report Card License + Mentioned in Awesome Go

From 83355911e34337014e8585b2ac910cbb658f1324 Mon Sep 17 00:00:00 2001 From: rockerchen Date: Sat, 26 Feb 2022 22:42:56 +0800 Subject: [PATCH 414/843] fix #94 --- transport.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transport.go b/transport.go index 585b5461..5f991050 100644 --- a/transport.go +++ b/transport.go @@ -306,7 +306,10 @@ func (t *Transport) autoDecodeResponseBody(res *http.Response) { } _, params, err := mime.ParseMediaType(contentType) if err != nil { - panic(err) + if t.Debugf != nil { + t.Debugf("failed to parse content type %q: %v", contentType, err) + } + return } if charset, ok := params["charset"]; ok { charset = strings.ToLower(charset) From d345ea9c445aa8d3d6123a5482f2c691424788cf Mon Sep 17 00:00:00 2001 From: rockerchen Date: Sat, 26 Feb 2022 22:59:14 +0800 Subject: [PATCH 415/843] expose http.Client --- client.go | 11 +++++++++++ req_test.go | 2 ++ 2 files changed, 13 insertions(+) diff --git a/client.go b/client.go index 6d84193d..87e0808a 100644 --- a/client.go +++ b/client.go @@ -1211,6 +1211,17 @@ func (c *Client) isPayloadForbid(m string) bool { return (m == http.MethodGet && !c.AllowGetMethodPayload) || m == http.MethodHead || m == http.MethodOptions } +// GetClient is a global wrapper methods which delegated +// to the default client's GetClient. +func GetClient() *http.Client { + return defaultClient.GetClient() +} + +// GetClient returns the underlying `http.Client`. +func (c *Client) GetClient() *http.Client { + return c.httpClient +} + // NewClient is the alias of C func NewClient() *Client { return C() diff --git a/req_test.go b/req_test.go index 6b6d3114..324cecfd 100644 --- a/req_test.go +++ b/req_test.go @@ -908,6 +908,8 @@ func TestGlobalWrapper(t *testing.T) { DisableForceHttpVersion() assertEqual(t, true, DefaultClient().t.ForceHttpVersion == "") + assertEqual(t, GetClient(), DefaultClient().httpClient) + r = NewRequest() assertEqual(t, true, r != nil) c = NewClient() From 740fb2ef5d895fa8384bc992fc4b1861da85726b Mon Sep 17 00:00:00 2001 From: rockerchen Date: Sat, 26 Feb 2022 23:23:21 +0800 Subject: [PATCH 416/843] update README: add release badge --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a73a251c..c121bb81 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ Go Report Card License + GitHub Releases Mentioned in Awesome Go

From 51e7e9fa241e6c7d0034159747bfd412e47e1cc9 Mon Sep 17 00:00:00 2001 From: rockerchen Date: Mon, 28 Feb 2022 10:46:12 +0800 Subject: [PATCH 417/843] try sniff and auto-decode when Content-Type is malformed --- transport.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/transport.go b/transport.go index 5f991050..302398c5 100644 --- a/transport.go +++ b/transport.go @@ -309,9 +309,7 @@ func (t *Transport) autoDecodeResponseBody(res *http.Response) { if t.Debugf != nil { t.Debugf("failed to parse content type %q: %v", contentType, err) } - return - } - if charset, ok := params["charset"]; ok { + } else if charset, ok := params["charset"]; ok { charset = strings.ToLower(charset) if strings.Contains(charset, "utf-8") || strings.Contains(charset, "utf8") { // do not decode utf-8 return @@ -319,14 +317,13 @@ func (t *Transport) autoDecodeResponseBody(res *http.Response) { enc, _ := htmlcharset.Lookup(charset) if enc == nil { enc, err = ianaindex.MIME.Encoding(charset) - if err != nil { - // TODO: log charset not supported + if err != nil || enc == nil { + if t.Debugf != nil { + t.Debugf("ignore charset %s which is detected in Content-Type but not supported", charset) + } return } } - if enc == nil { - return - } if t.Debugf != nil { t.Debugf("charset %s detected in Content-Type, auto-decode to utf-8", charset) } From 163d61c689a2c3c8f20f6e3882cecb362da3e626 Mon Sep 17 00:00:00 2001 From: Lu Chang Date: Mon, 28 Feb 2022 16:33:56 +0800 Subject: [PATCH 418/843] fix typo - Add a missing *C* as client. - It seems that there are no *EnableDumpNoRequestBody* and *EnableDumpNoResponseBody* in code. I change it to *EnableDumpAllWithoutRequestBody* and *EnableDumpAllWithoutResponseBody*. Hope it's correct. --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c121bb81..a377027e 100644 --- a/README.md +++ b/README.md @@ -736,7 +736,7 @@ client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.e ```go // Create a client with default download direcotry -client := req.C().SetOutputDirectory("/path/to/download").EnableDumpNoResponseBody() +client := req.C().SetOutputDirectory("/path/to/download").EnableDumpAllWithoutResponseBody() // Download to relative file path, this will be downloaded // to /path/to/download/test.jpg @@ -758,7 +758,7 @@ client.R().SetOutput(file).Get(url) **Multipart Upload** ```go -client := req.().EnableDumpNoRequestBody() // Request body contains unreadable binary, do not dump +client := req.C().EnableDumpAllWithoutRequestBody() // Request body contains unreadable binary, do not dump client.R().SetFile("pic", "test.jpg"). // Set form param name and filename SetFile("pic", "/path/to/roc.png"). // Multiple files using the same form param name From c2cb3214280a8d24fc19ba1905e319297f202aeb Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Mar 2022 16:01:58 +0800 Subject: [PATCH 419/843] Set ContentLength and guess ContentType if body is in-memory []byte. --- middleware.go | 13 ++++++++----- req_test.go | 3 +++ request.go | 23 +++++++++-------------- request_test.go | 21 +++++++++++++++++++-- 4 files changed, 39 insertions(+), 21 deletions(-) diff --git a/middleware.go b/middleware.go index 054ebcfd..fef73e1b 100644 --- a/middleware.go +++ b/middleware.go @@ -160,11 +160,14 @@ func parseRequestBody(c *Client, r *Request) (err error) { handleMarshalBody(c, r) } - if r.getHeader(hdrContentTypeKey) == "" && r.body != nil { - ct := http.DetectContentType(r.body) - if ct != "application/octet-stream" { - r.SetContentType(ct) - } + if r.body == nil { + return + } + // body is in-memory []byte, so we can set content length + // and guess content type + r.RawRequest.ContentLength = int64(len(r.body)) + if r.getHeader(hdrContentTypeKey) == "" { + r.SetContentType(http.DetectContentType(r.body)) } return } diff --git a/req_test.go b/req_test.go index 324cecfd..5e1f0da5 100644 --- a/req_test.go +++ b/req_test.go @@ -108,6 +108,9 @@ func handlePost(w http.ResponseWriter, r *http.Request) { io.Copy(ioutil.Discard, r.Body) w.Header().Set(hdrLocationKey, "/") w.WriteHeader(http.StatusMovedPermanently) + case "/content-type": + io.Copy(ioutil.Discard, r.Body) + w.Write([]byte(r.Header.Get(hdrContentTypeKey))) case "/echo": b, _ := ioutil.ReadAll(r.Body) e := echo{ diff --git a/request.go b/request.go index d832b5f5..edcea35f 100644 --- a/request.go +++ b/request.go @@ -708,8 +708,7 @@ func SetBodyString(body string) *Request { // SetBodyString set the request body as string. func (r *Request) SetBodyString(body string) *Request { - r.SetBodyBytes([]byte(body)) - return r.SetContentType(plainTextContentType) + return r.SetBodyBytes([]byte(body)) } // SetBodyJsonString is a global wrapper methods which delegated @@ -721,8 +720,7 @@ func SetBodyJsonString(body string) *Request { // SetBodyJsonString set the request body as string and set Content-Type header // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonString(body string) *Request { - r.SetBodyBytes([]byte(body)) - return r.SetContentType(jsonContentType) + return r.SetBodyJsonBytes([]byte(body)) } // SetBodyJsonBytes is a global wrapper methods which delegated @@ -734,8 +732,8 @@ func SetBodyJsonBytes(body []byte) *Request { // SetBodyJsonBytes set the request body as []byte and set Content-Type header // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonBytes(body []byte) *Request { - r.SetBodyBytes(body) - return r.SetContentType(jsonContentType) + r.SetContentType(jsonContentType) + return r.SetBodyBytes(body) } // SetBodyJsonMarshal is a global wrapper methods which delegated @@ -752,8 +750,7 @@ func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { r.appendError(err) return r } - r.SetBodyBytes(b) - return r.SetContentType(jsonContentType) + return r.SetBodyJsonBytes(b) } // SetBodyXmlString is a global wrapper methods which delegated @@ -765,8 +762,7 @@ func SetBodyXmlString(body string) *Request { // SetBodyXmlString set the request body as string and set Content-Type header // as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlString(body string) *Request { - r.SetBodyBytes([]byte(body)) - return r.SetContentType(xmlContentType) + return r.SetBodyXmlBytes([]byte(body)) } // SetBodyXmlBytes is a global wrapper methods which delegated @@ -778,8 +774,8 @@ func SetBodyXmlBytes(body []byte) *Request { // SetBodyXmlBytes set the request body as []byte and set Content-Type header // as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlBytes(body []byte) *Request { - r.SetBodyBytes(body) - return r.SetContentType(xmlContentType) + r.SetContentType(xmlContentType) + return r.SetBodyBytes(body) } // SetBodyXmlMarshal is a global wrapper methods which delegated @@ -796,8 +792,7 @@ func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { r.appendError(err) return r } - r.SetBodyBytes(b) - return r.SetContentType(xmlContentType) + return r.SetBodyXmlBytes(b) } // SetContentType is a global wrapper methods which delegated diff --git a/request_test.go b/request_test.go index c0385b3c..17757b14 100644 --- a/request_test.go +++ b/request_test.go @@ -662,9 +662,26 @@ func testTraceOnTimeout(t *testing.T, c *Client) { } func TestAutoDetectRequestContentType(t *testing.T) { - resp, err := tc().R().SetBody(tests.GetTestFileContent(t, "sample-image.png")).Post("/raw-upload") + c := tc() + resp, err := c.R().SetBody(tests.GetTestFileContent(t, "sample-image.png")).Post("/content-type") + assertSuccess(t, resp, err) + assertEqual(t, "image/png", resp.String()) + + resp, err = c.R().SetBodyJsonString(`{"msg": "test"}`).Post("/content-type") + assertSuccess(t, resp, err) + assertEqual(t, jsonContentType, resp.String()) + + resp, err = c.R().SetContentType(xmlContentType).SetBody(`{"msg": "test"}`).Post("/content-type") + assertSuccess(t, resp, err) + assertEqual(t, xmlContentType, resp.String()) + + resp, err = c.R().SetBody(`

hello

`).Post("/content-type") + assertSuccess(t, resp, err) + assertEqual(t, "text/html; charset=utf-8", resp.String()) + + resp, err = c.R().SetBody(`hello world`).Post("/content-type") assertSuccess(t, resp, err) - assertEqual(t, "image/png", resp.Request.Headers.Get(hdrContentTypeKey)) + assertEqual(t, plainTextContentType, resp.String()) } func TestUploadMultipart(t *testing.T) { From 4c70ddb623f741fee3cdfc0803cd9d684aa8a70e Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Mar 2022 17:40:41 +0800 Subject: [PATCH 420/843] use variadic parameter in SetFileUpload --- request.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/request.go b/request.go index edcea35f..526de6dd 100644 --- a/request.go +++ b/request.go @@ -263,14 +263,16 @@ func (r *Request) SetFile(paramName, filePath string) *Request { // SetFileUpload is a global wrapper methods which delegated // to the default client, create a request and SetFileUpload for request. -func SetFileUpload(f FileUpload) *Request { - return defaultClient.R().SetFileUpload(f) +func SetFileUpload(f ...FileUpload) *Request { + return defaultClient.R().SetFileUpload(f...) } // SetFileUpload set the fully custimized multipart file upload options. -func (r *Request) SetFileUpload(f FileUpload) *Request { +func (r *Request) SetFileUpload(uploads ...FileUpload) *Request { r.isMultiPart = true - r.uploadFiles = append(r.uploadFiles, &f) + for _, upload := range uploads { + r.uploadFiles = append(r.uploadFiles, &upload) + } return r } From cc2142d164dcb190382f5dc3aee6189b7847167f Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Mar 2022 19:06:23 +0800 Subject: [PATCH 421/843] improve quick api reference --- README.md | 2 +- docs/api.md | 90 ++++++++++++++++++++++++++++++++++++----------------- 2 files changed, 62 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index a377027e..7d87d223 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ resp, err := client.R(). // Use R() to create a request **API Reference** -Checkout [Req API Reference](docs/api.md) for a brief and categorized list of the core APIs, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). +Checkout [Quick API Reference](docs/api.md) for a brief and categorized list of the core APIs, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). **Examples** diff --git a/docs/api.md b/docs/api.md index fa668865..8a0529d6 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1,5 +1,5 @@

-

Req API Reference

+

Quick API Reference

Here is a brief and categorized list of the core APIs, for a more detailed and complete list, please refer to the [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). @@ -8,10 +8,11 @@ Here is a brief and categorized list of the core APIs, for a more detailed and c * [Client Settings](#Client) * [Debug Features](#Debug) - * [Common Settings for HTTP Requests](#Common) + * [Common Settings for constructing HTTP Requests](#Common) * [Auto-Decode](#Decode) - * [Certificates](#Certs) + * [TLS and Certificates](#Certs) * [Marshal&Unmarshal](#Marshal) + * [HTTP Version](#Version) * [Other Settings](#Other) * [Request Settings](#Request) * [URL Query and Path Parameter](#Query) @@ -21,6 +22,7 @@ Here is a brief and categorized list of the core APIs, for a more detailed and c * [Multipart & Form & Upload](#Multipart) * [Download](#Download) * [Other Settings](#Other-Request) +* [Sending Request](#Send-Request) ## Client Settings @@ -38,34 +40,36 @@ Basically, you can know the meaning of most settings directly from the method na * [EnableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAll) - Enable dump for all requests. * [DisableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDumpAll) -* [EnableDumpAllWithoutResponseBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponseBody) -* [EnableDumpAllWithoutResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponse) -* [EnableDumpAllWithoutRequestBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequestBody) -* [EnableDumpAllWithoutRequest()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequest) -* [EnableDumpAllWithoutHeader()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutHeader) -* [EnableDumpAllWithoutBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutBody) -* [EnableDumpAllToFile(filename string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllToFile) -* [EnableDumpAllTo(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllTo) -* [EnableDumpAllAsync()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllAsync) * [SetCommonDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonDumpOptions) +* [EnableDumpAllAsync()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllAsync) +* [EnableDumpAllTo(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllTo) +* [EnableDumpAllToFile(filename string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllToFile) +* [EnableDumpAllWithoutBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutBody) +* [EnableDumpAllWithoutHeader()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutHeader) +* [EnableDumpAllWithoutRequest()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequest) +* [EnableDumpAllWithoutRequestBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequestBody) +* [EnableDumpAllWithoutResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponse) +* [EnableDumpAllWithoutResponseBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponseBody) * [EnableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableTraceAll) - Enable trace for all requests (disabled by default). * [DisableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableTraceAll) -### Common Settings for HTTP Requests +### Common Settings for constructing HTTP Requests -* [SetCommonQueryString(query string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryString) -* [SetCommonHeaders(hdrs map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeaders) -* [SetCommonHeader(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeader) -* [SetCommonContentType(ct string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonContentType) -* [SetCommonBearerAuthToken(token string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonBearerAuthToken) * [SetCommonBasicAuth(username, password string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonBasicAuth) +* [SetCommonBearerAuthToken(token string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonBearerAuthToken) +* [SetCommonContentType(ct string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonContentType) * [SetCommonCookies(cookies ...*http.Cookie)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonCookies) -* [AddCommonQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.AddCommonQueryParam) +* [SetCommonFormData(data map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormData) +* [SetCommonFormDataFromValues(data url.Values)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormDataFromValues) +* [SetCommonHeader(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeader) +* [SetCommonHeaders(hdrs map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeaders) +* [SetCommonPathParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonPathParam) +* [SetCommonPathParams(pathParams map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonPathParams) * [SetCommonQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryParam) * [SetCommonQueryParams(params map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryParams) -* [SetCommonFormDataFromValues(data url.Values)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormDataFromValues) -* [SetCommonFormData(data map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormData) +* [SetCommonQueryString(query string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryString) +* [AddCommonQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.AddCommonQueryParam) * [SetUserAgent(userAgent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetUserAgent) ### Auto-Decode @@ -76,12 +80,16 @@ Basically, you can know the meaning of most settings directly from the method na * [SetAutoDecodeAllContentType()](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeAllContentType) * [SetAutoDecodeContentTypeFunc(fn func(contentType string) bool)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentTypeFunc) -### Certificates +### TLS and Certificates * [SetCerts(certs ...tls.Certificate) ](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCerts) * [SetCertFromFile(certFile, keyFile string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCertFromFile) * [SetRootCertsFromFile(pemFiles ...string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertsFromFile) * [SetRootCertFromString(pemContent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertFromString) +* [EnableInsecureSkipVerify()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableInsecureSkipVerify) - Disabled by default. +* [DisableInsecureSkipVerify](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableInsecureSkipVerify) +* [SetTLSHandshakeTimeout(timeout time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSHandshakeTimeout) +* [SetTLSClientConfig(conf *tls.Config)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSClientConfig) ### Marshal&Unmarshal @@ -95,16 +103,17 @@ Basically, you can know the meaning of most settings directly from the method na * [OnBeforeRequest(m RequestMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnBeforeRequest) * [OnAfterResponse(m ResponseMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnAfterResponse) +### HTTP Version + +* [DisableForceHttpVersion()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableForceHttpVersion) +* [EnableForceHTTP2()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableForceHTTP2) +* [EnableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableForceHTTP1) + ### Other Settings * [SetTimeout(d time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTimeout) -* [SetTLSHandshakeTimeout(timeout time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSHandshakeTimeout) -* [SetTLSClientConfig(conf *tls.Config)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSClientConfig) -* [EnableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableForceHTTP1) - Disabled by default. -* [DisableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableForceHTTP1) - -* [EnableKeepAlives()](EnableKeepAlives()) +* [EnableKeepAlives()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableKeepAlives) * [DisableKeepAlives()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableKeepAlives) - Enabled by default. * [SetScheme(scheme string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetScheme) @@ -131,6 +140,7 @@ Basically, you can know the meaning of most settings directly from the method na * [EnableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAllowGetMethodPayload) - Disabled by default. * [DisableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAllowGetMethodPayload) + ## Request Settings The following are the chainable settings of Request, all of which have corresponding global wrappers. @@ -189,8 +199,10 @@ Basically, you can know the meaning of most settings directly from the method na * [SetFormData(data map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFormData) * [SetFormDataFromValues(data url.Values)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFormDataFromValues) * [SetFile(paramName, filePath string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFile) -* [SetFileReader(paramName, filePath string, reader io.Reader)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFileReader) * [SetFiles(files map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFiles) +* [SetFileBytes(paramName, filename string, content []byte)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFileBytes) +* [SetFileReader(paramName, filePath string, reader io.Reader)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFileReader) +* [SetFileUpload(uploads ...FileUpload)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFileUpload) - Set the fully custimized multipart file upload options. ### Download @@ -200,3 +212,23 @@ Basically, you can know the meaning of most settings directly from the method na ### Other Settings * [SetContext(ctx context.Context)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetContext) + +## Sending Request + +These methods will fire the http request and get response, `MustXXX` will not return any error, panic if error happens. + +* [Get(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Get) +* [Head(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Head) +* [Post(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Post) +* [Delete(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Delete) +* [Patch(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Patch) +* [Options(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Options) +* [Put(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Put) +* [Send(method, url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Put) - Send request with given method name and url. +* [MustGet(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustGet) +* [MustHead(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustHead) +* [MustPost(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustPost) +* [MustDelete(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustDelete) +* [MustPatch(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustPatch) +* [MustOptions(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustOptions) +* [MustPut(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustPut) From 05ba24da837e562664407af2a56ca8f9fe6909f3 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Mar 2022 19:54:35 +0800 Subject: [PATCH 422/843] upload README --- README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index 7d87d223..ee079ec3 100644 --- a/README.md +++ b/README.md @@ -766,7 +766,7 @@ client.R().SetFile("pic", "test.jpg"). // Set form param name and filename "exe": "test.exe", "src": "main.go", }). - SetFormData(map[string]string{ // Set from param using map + SetFormData(map[string]string{ // Set form data while uploading "name": "imroc", "email": "roc@imroc.cc", }). @@ -881,12 +881,10 @@ client.SetProxy(nil) ## TODO List -* [ ] Add more tests. * [ ] Wrap more transport settings into client. * [ ] Support retry. * [ ] Support unix socket. * [ ] Support h2c. -* [ ] Make videos. * [ ] Design a logo. * [ ] Support HTTP3. From dddb43c875f6754428e76bc89eee8d94e6d2df95 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Mar 2022 17:44:34 +0800 Subject: [PATCH 423/843] Support retry. 1. SetRetryCount 2. SetRetryInterval 3. SetRetryFixedInterval 4. SetRetryBackoffInterval 5. SetRetryHook 6. AddRetryHook 7. SetRetryCondition 8. AddRetryCondition --- client.go | 285 ++++++++++++++++++------ middleware.go | 46 ++-- req.go | 3 +- req_test.go | 559 ++++++++++++++++-------------------------------- request.go | 211 ++++++++++++++++-- request_test.go | 48 +++++ retry.go | 63 ++++++ retry_test.go | 129 +++++++++++ 8 files changed, 855 insertions(+), 489 deletions(-) create mode 100644 retry.go create mode 100644 retry_test.go diff --git a/client.go b/client.go index 87e0808a..758bbc3b 100644 --- a/client.go +++ b/client.go @@ -45,6 +45,7 @@ type Client struct { DebugLog bool AllowGetMethodPayload bool + retryOption *retryOption jsonMarshal func(v interface{}) ([]byte, error) jsonUnmarshal func(data []byte, v interface{}) error xmlMarshal func(v interface{}) ([]byte, error) @@ -71,15 +72,9 @@ func R() *Request { // R create a new request. func (c *Client) R() *Request { - req := &http.Request{ - Header: make(http.Header), - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - } return &Request{ - client: c, - RawRequest: req, + client: c, + retryOption: c.retryOption.Clone(), } } @@ -1222,6 +1217,126 @@ func (c *Client) GetClient() *http.Client { return c.httpClient } +func (c *Client) getRetryOption() *retryOption { + if c.retryOption == nil { + c.retryOption = newDefaultRetryOption() + } + return c.retryOption +} + +// SetCommonRetryCount is a global wrapper methods which delegated +// to the default client, create a request and SetCommonRetryCount for request. +func SetCommonRetryCount(count int) *Client { + return defaultClient.SetCommonRetryCount(count) +} + +// SetCommonRetryCount enables retry and set the maximum retry count for all requests. +func (c *Client) SetCommonRetryCount(count int) *Client { + c.getRetryOption().MaxRetries = count + return c +} + +// SetCommonRetryInterval is a global wrapper methods which delegated +// to the default client, create a request and SetCommonRetryInterval for request. +func SetCommonRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Client { + return defaultClient.SetCommonRetryInterval(getRetryIntervalFunc) +} + +// SetCommonRetryInterval sets the custom GetRetryIntervalFunc for all requests, +// you can use this to implement your own backoff retry algorithm. +// For example: +// req.SetCommonRetryInterval(func(attempt int) time.Duration { +// sleep := 0.01 * math.Exp2(float64(attempt)) +// return time.Duration(math.Min(2, sleep)) * time.Second +// }) +func (c *Client) SetCommonRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Client { + c.getRetryOption().GetRetryInterval = getRetryIntervalFunc + return c +} + +// SetCommonRetryFixedInterval is a global wrapper methods which delegated +// to the default client, create a request and SetCommonRetryFixedInterval for request. +func SetCommonRetryFixedInterval(interval time.Duration) *Client { + return defaultClient.SetCommonRetryFixedInterval(interval) +} + +// SetCommonRetryFixedInterval set retry to use a fixed interval for all requests. +func (c *Client) SetCommonRetryFixedInterval(interval time.Duration) *Client { + c.getRetryOption().GetRetryInterval = func(attempt int) time.Duration { + return interval + } + return c +} + +// SetCommonRetryBackoffInterval is a global wrapper methods which delegated +// to the default client, create a request and SetCommonRetryBackoffInterval for request. +func SetCommonRetryBackoffInterval(min, max time.Duration) *Client { + return defaultClient.SetCommonRetryBackoffInterval(min, max) +} + +// SetCommonRetryBackoffInterval set retry to use a capped exponential backoff with jitter +// for all requests. +// https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ +func (c *Client) SetCommonRetryBackoffInterval(min, max time.Duration) *Client { + c.getRetryOption().GetRetryInterval = backoffInterval(min, max) + return c +} + +// SetCommonRetryHook is a global wrapper methods which delegated +// to the default client, create a request and SetRetryHook for request. +func SetCommonRetryHook(hook RetryHookFunc) *Client { + return defaultClient.SetCommonRetryHook(hook) +} + +// SetCommonRetryHook set the retry hook which will be executed before a retry. +// It will override other retry hooks if any been added before. +func (c *Client) SetCommonRetryHook(hook RetryHookFunc) *Client { + c.getRetryOption().RetryHooks = []RetryHookFunc{hook} + return c +} + +// AddCommonRetryHook is a global wrapper methods which delegated +// to the default client, create a request and AddCommonRetryHook for request. +func AddCommonRetryHook(hook RetryHookFunc) *Client { + return defaultClient.AddCommonRetryHook(hook) +} + +// AddCommonRetryHook adds a retry hook for all requests, which will be +// executed before a retry. +func (c *Client) AddCommonRetryHook(hook RetryHookFunc) *Client { + ro := c.getRetryOption() + ro.RetryHooks = append(ro.RetryHooks, hook) + return c +} + +// SetCommonRetryCondition is a global wrapper methods which delegated +// to the default client, create a request and SetCommonRetryCondition for request. +func SetCommonRetryCondition(condition RetryConditionFunc) *Client { + return defaultClient.SetCommonRetryCondition(condition) +} + +// SetCommonRetryCondition sets the retry condition, which determines whether the +// request should retry. +// It will override other retry conditions if any been added before. +func (c *Client) SetCommonRetryCondition(condition RetryConditionFunc) *Client { + c.getRetryOption().RetryConditions = []RetryConditionFunc{condition} + return c +} + +// AddCommonRetryCondition is a global wrapper methods which delegated +// to the default client, create a request and AddCommonRetryCondition for request. +func AddCommonRetryCondition(condition RetryConditionFunc) *Client { + return defaultClient.AddCommonRetryCondition(condition) +} + +// AddCommonRetryCondition adds a retry condition, which determines whether the +// request should retry. +func (c *Client) AddCommonRetryCondition(condition RetryConditionFunc) *Client { + ro := c.getRetryOption() + ro.RetryConditions = append(ro.RetryConditions, condition) + return c +} + // NewClient is the alias of C func NewClient() *Client { return C() @@ -1308,13 +1423,6 @@ func C() *Client { return c } -func setupRequest(r *Request) { - setRequestURL(r, r.URL) - setRequestHeaderAndCookie(r) - setTrace(r) - setContext(r) -} - func (c *Client) do(r *Request) (resp *Response, err error) { resp = &Response{ Request: r, @@ -1331,74 +1439,113 @@ func (c *Client) do(r *Request) (resp *Response, err error) { } } - setupRequest(r) + // setup trace + if r.trace == nil && r.client.trace { + r.trace = &clientTrace{} + } + if r.trace != nil { + r.ctx = r.trace.createContext(r.Context()) + } - if c.DebugLog { - c.log.Debugf("%s %s", r.RawRequest.Method, r.RawRequest.URL.String()) + // setup url and host + var host string + if h := r.getHeader("Host"); h != "" { + host = h // Host header override + } else { + host = r.URL.Host } - r.StartTime = time.Now() - httpResponse, err := c.httpClient.Do(r.RawRequest) - if err != nil { - return + // setup header + var header http.Header + if r.Headers == nil { + header = make(http.Header) + } else { + header = r.Headers.Clone() } + contentLength := int64(len(r.body)) - resp.Request = r - resp.Response = httpResponse + for { + var reqBody io.ReadCloser + if r.getBody != nil { + reqBody, err = r.getBody() + if err != nil { + return + } + } + req := &http.Request{ + Method: r.method, + Header: header, + URL: r.URL, + Host: host, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: contentLength, + Body: reqBody, + GetBody: r.getBody, + } + for _, cookie := range r.Cookies { + req.AddCookie(cookie) + } + if r.ctx != nil { + req = req.WithContext(r.ctx) + } + r.RawRequest = req - if !c.disableAutoReadResponse && !r.isSaveResponse { // auto read response body - _, err = resp.ToBytes() - if err != nil { - return + if c.DebugLog { + c.log.Debugf("%s %s", req.Method, req.URL.String()) } - } - for _, f := range r.client.afterResponse { - if err = f(r.client, resp); err != nil { - return + r.StartTime = time.Now() + var httpResponse *http.Response + httpResponse, err = c.httpClient.Do(req) + resp.Response = httpResponse + + // auto-read response body if possible + if err == nil && !c.disableAutoReadResponse && !r.isSaveResponse { + _, err = resp.ToBytes() + if err != nil { + return + } } - } - return -} -func setContext(r *Request) { - if r.ctx != nil { - r.RawRequest = r.RawRequest.WithContext(r.ctx) - } -} + if r.retryOption == nil || r.RetryAttempt >= r.retryOption.MaxRetries { // absolutely cannot retry. + if err != nil { // return immediately if error occurs. + return + } + break // jump out to execute the ResponseMiddlewares if possible. + } -func setTrace(r *Request) { - if r.trace == nil && r.client.trace { - r.trace = &clientTrace{} - } - if r.trace != nil { - r.ctx = r.trace.createContext(r.Context()) - } -} + // check retry whether is needed. + needRetry := err != nil // default behaviour: retry if error occurs + for _, condition := range r.retryOption.RetryConditions { // override default behaviour if custom RetryConditions has been set. + needRetry = condition(resp, err) + if needRetry { + break + } + } + if !needRetry { // no retry is needed. + return + } -func setRequestHeaderAndCookie(r *Request) { - for k, vs := range r.Headers { - for _, v := range vs { - r.RawRequest.Header.Add(k, v) + // need retry, attempt to retry + r.RetryAttempt++ + for _, hook := range r.retryOption.RetryHooks { // run retry hooks + hook(resp, err) } - } - for _, cookie := range r.Cookies { - r.RawRequest.AddCookie(cookie) - } -} + time.Sleep(r.retryOption.GetRetryInterval(r.RetryAttempt)) -func setRequestURL(r *Request, url string) error { - // The host's colon:port should be normalized. See Issue 14836. - u, err := urlpkg.Parse(url) - if err != nil { - return err + // clean buffers + if r.dumpBuffer != nil { + r.dumpBuffer.Reset() + } + resp.body = nil } - u.Host = removeEmptyPort(u.Host) - if host := r.getHeader("Host"); host != "" { - r.RawRequest.Host = host // Host header override - } else { - r.RawRequest.Host = u.Host + + for _, f := range r.client.afterResponse { + if err = f(r.client, resp); err != nil { + return + } } - r.RawRequest.URL = u - return nil + return } diff --git a/middleware.go b/middleware.go index fef73e1b..b542313f 100644 --- a/middleware.go +++ b/middleware.go @@ -56,10 +56,14 @@ func closeq(v interface{}) { } func writeMultipartFormFile(w *multipart.Writer, file *FileUpload) error { - defer closeq(file.File) + content, err := file.GetFileContent() + if err != nil { + return err + } + defer content.Close() // Auto detect actual multipart content type cbuf := make([]byte, 512) - size, err := file.File.Read(cbuf) + size, err := content.Read(cbuf) if err != nil && err != io.EOF { return err } @@ -73,7 +77,7 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload) error { return err } - _, err = io.Copy(pw, file.File) + _, err = io.Copy(pw, content) return err } @@ -92,7 +96,9 @@ func writeMultiPart(r *Request, w *multipart.Writer, pw *io.PipeWriter) { func handleMultiPart(c *Client, r *Request) (err error) { pr, pw := io.Pipe() - r.RawRequest.Body = pr + r.getBody = func() (io.ReadCloser, error) { + return pr, nil + } w := multipart.NewWriter(pw) r.SetContentType(w.FormDataContentType()) go writeMultiPart(r, w, pw) @@ -137,12 +143,12 @@ func handleMarshalBody(c *Client, r *Request) error { } func parseRequestBody(c *Client, r *Request) (err error) { - if c.isPayloadForbid(r.RawRequest.Method) { - r.RawRequest.Body = nil + if c.isPayloadForbid(r.method) { + r.getBody = nil return } // handle multipart - if r.isMultiPart && (r.RawRequest.Method != http.MethodPatch) { + if r.isMultiPart && (r.method != http.MethodPatch) { return handleMultiPart(c, r) } @@ -163,9 +169,7 @@ func parseRequestBody(c *Client, r *Request) (err error) { if r.body == nil { return } - // body is in-memory []byte, so we can set content length - // and guess content type - r.RawRequest.ContentLength = int64(len(r.body)) + // body is in-memory []byte, so we can guess content type if r.getHeader(hdrContentTypeKey) == "" { r.SetContentType(http.DetectContentType(r.body)) } @@ -266,33 +270,35 @@ func parseRequestCookie(c *Client, r *Request) error { return nil } +// generate URL func parseRequestURL(c *Client, r *Request) error { + tempURL := r.RawURL if len(r.PathParams) > 0 { for p, v := range r.PathParams { - r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1) + tempURL = strings.Replace(tempURL, "{"+p+"}", url.PathEscape(v), -1) } } if len(c.PathParams) > 0 { for p, v := range c.PathParams { - r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1) + tempURL = strings.Replace(tempURL, "{"+p+"}", url.PathEscape(v), -1) } } // Parsing request URL - reqURL, err := url.Parse(r.URL) + reqURL, err := url.Parse(tempURL) if err != nil { return err } - // If Request.URL is relative path then added c.BaseURL into + // If RawURL is relative path then added c.BaseURL into // the request URL otherwise Request.URL will be used as-is if !reqURL.IsAbs() { - r.URL = reqURL.String() - if len(r.URL) > 0 && r.URL[0] != '/' { - r.URL = "/" + r.URL + tempURL = reqURL.String() + if len(tempURL) > 0 && tempURL[0] != '/' { + tempURL = "/" + tempURL } - reqURL, err = url.Parse(c.BaseURL + r.URL) + reqURL, err = url.Parse(c.BaseURL + tempURL) if err != nil { return err } @@ -332,7 +338,7 @@ func parseRequestURL(c *Client, r *Request) error { } } - r.URL = reqURL.String() - + reqURL.Host = removeEmptyPort(reqURL.Host) + r.URL = reqURL return nil } diff --git a/req.go b/req.go index 225d254b..8a7970a9 100644 --- a/req.go +++ b/req.go @@ -2,7 +2,6 @@ package req import ( "fmt" - "io" "net/http" "net/url" ) @@ -53,7 +52,7 @@ type FileUpload struct { // "filename" parameter in `Content-Disposition` FileName string // The file to be uploaded. - File io.Reader + GetFileContent GetContentFunc // According to the HTTP specification, this should be nil, // but some servers may not follow the specification and diff --git a/req_test.go b/req_test.go index 5e1f0da5..57e46bf6 100644 --- a/req_test.go +++ b/req_test.go @@ -191,6 +191,13 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.Write([]byte("TestGet: text response")) case "/bad-request": w.WriteHeader(http.StatusBadRequest) + case "/too-many": + w.WriteHeader(http.StatusTooManyRequests) + w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Write([]byte(`{"errMsg":"too many requests"}`)) + case "/retry": + r.ParseForm() + r.Form.Get("attempt") case "/chunked": w.Header().Add("Trailer", "Expires") w.Write([]byte(`This is a chunked body`)) @@ -450,7 +457,6 @@ func testGlobalWrapperEnableDumps(t *testing.T) { RequestHeader: true, }) assertEqual(t, true, r.getDumpOptions().RequestHeader) - } func testGlobalWrapperEnableDump(t *testing.T, fn func(reqHeader, reqBody, respHeader, respBody *bool) *Request) { @@ -520,96 +526,83 @@ func testGlobalWrapperSendRequest(t *testing.T) { assertEqual(t, "POST", resp.Header.Get("Method")) } -func TestGlobalWrapperSetRequest(t *testing.T) { - testFilePath := tests.GetTestFilePath("sample-file.txt") - r := SetFiles(map[string]string{"test": testFilePath}) - assertEqual(t, 1, len(r.uploadFiles)) - assertEqual(t, true, r.isMultiPart) - - r = SetFile("test", tests.GetTestFilePath("sample-file.txt")) - assertEqual(t, 1, len(r.uploadFiles)) - assertEqual(t, true, r.isMultiPart) +func testGlobalWrapperSetRequest(t *testing.T, rs ...*Request) { + for _, r := range rs { + assertNotNil(t, r) + } +} +func TestGlobalWrapperSetRequest(t *testing.T) { SetLogger(nil) - r = SetFile("test", tests.GetTestFilePath("file-not-exists.txt")) - assertEqual(t, 0, len(r.uploadFiles)) - assertEqual(t, false, r.isMultiPart) - assertNotNil(t, r.error) - - r = SetFileReader("test", "test.txt", bytes.NewBufferString("test")) - assertEqual(t, 1, len(r.uploadFiles)) - assertEqual(t, true, r.isMultiPart) - - r = SetFileBytes("test", "test.txt", []byte("test")) - assertEqual(t, 1, len(r.uploadFiles)) - assertEqual(t, true, r.isMultiPart) - - r = SetFileUpload(FileUpload{}) - assertEqual(t, 1, len(r.uploadFiles)) - assertEqual(t, true, r.isMultiPart) - - var result string - r = SetError(&result) - assertEqual(t, true, r.Error != nil) - - r = SetResult(&result) - assertEqual(t, true, r.Result != nil) - - r = SetOutput(nil) - assertEqual(t, false, r.isSaveResponse) - - r = SetOutput(bytes.NewBufferString("test")) - assertEqual(t, true, r.isSaveResponse) - - r = SetHeader("test", "test") - assertEqual(t, "test", r.Headers.Get("test")) - - r = SetHeaders(map[string]string{"test": "test"}) - assertEqual(t, "test", r.Headers.Get("test")) - - r = SetCookies(&http.Cookie{ - Name: "test", - Value: "test", - }) - assertEqual(t, 1, len(r.Cookies)) - - r = SetBasicAuth("imroc", "123456") - assertEqual(t, "Basic aW1yb2M6MTIzNDU2", r.Headers.Get("Authorization")) - - r = SetBearerAuthToken("123456") - assertEqual(t, "Bearer 123456", r.Headers.Get("Authorization")) - - r = SetQueryString("test=test") - assertEqual(t, "test", r.QueryParams.Get("test")) - - r = SetQueryString("ksjlfjk?") - assertEqual(t, "", r.QueryParams.Get("test")) - - r = SetQueryParam("test", "test") - assertEqual(t, "test", r.QueryParams.Get("test")) - - r = AddQueryParam("test", "test") - assertEqual(t, "test", r.QueryParams.Get("test")) - - r = SetQueryParams(map[string]string{"test": "test"}) - assertEqual(t, "test", r.QueryParams.Get("test")) - - r = SetPathParam("test", "test") - assertEqual(t, "test", r.PathParams["test"]) - - r = SetPathParams(map[string]string{"test": "test"}) - assertEqual(t, "test", r.PathParams["test"]) - - r = SetFormData(map[string]string{"test": "test"}) - assertEqual(t, "test", r.FormData.Get("test")) - + testFilePath := tests.GetTestFilePath("sample-file.txt") values := make(url.Values) values.Add("test", "test") - r = SetFormDataFromValues(values) - assertEqual(t, "test", r.FormData.Get("test")) + testGlobalWrapperSetRequest(t, + SetFiles(map[string]string{"test": testFilePath}), + SetFile("test", tests.GetTestFilePath("sample-file.txt")), + SetFile("test", tests.GetTestFilePath("file-not-exists.txt")), + SetFileReader("test", "test.txt", bytes.NewBufferString("test")), + SetFileBytes("test", "test.txt", []byte("test")), + SetFileUpload(FileUpload{ParamName: "test", FileName: "test.txt", GetFileContent: func() (io.ReadCloser, error) { + return nil, nil + }}), + SetError(&ErrorMessage{}), + SetResult(&UserInfo{}), + SetOutput(nil), + SetOutput(bytes.NewBufferString("test")), + SetHeader("test", "test"), + SetHeaders(map[string]string{"test": "test"}), + SetCookies(&http.Cookie{ + Name: "test", + Value: "test", + }), + SetBasicAuth("imroc", "123456"), + SetBearerAuthToken("123456"), + SetQueryString("test=test"), + SetQueryString("ksjlfjk?"), + SetQueryParam("test", "test"), + AddQueryParam("test", "test"), + SetQueryParams(map[string]string{"test": "test"}), + SetPathParam("test", "test"), + SetPathParams(map[string]string{"test": "test"}), + SetFormData(map[string]string{"test": "test"}), + SetFormDataFromValues(values), + SetContentType(jsonContentType), + AddRetryCondition(func(rep *Response, err error) bool { + return err != nil + }), + SetRetryCondition(func(rep *Response, err error) bool { + return err != nil + }), + AddRetryHook(func(resp *Response, err error) {}), + SetRetryHook(func(resp *Response, err error) {}), + SetRetryBackoffInterval(1*time.Millisecond, 500*time.Millisecond), + SetRetryFixedInterval(1*time.Millisecond), + SetRetryInterval(func(attempt int) time.Duration { + return 1 * time.Millisecond + }), + SetRetryCount(3), + SetBodyXmlMarshal(0), + SetBodyString("test"), + SetBodyBytes([]byte("test")), + SetBodyJsonBytes([]byte(`{"user":"roc"}`)), + SetBodyJsonString(`{"user":"roc"}`), + SetBodyXmlBytes([]byte("test")), + SetBodyXmlString("test"), + SetBody("test"), + SetBodyJsonMarshal(User{ + Name: "roc", + }), + EnableTrace(), + DisableTrace(), + SetContext(context.Background()), + ) +} - r = SetContentType(jsonContentType) - assertEqual(t, jsonContentType, r.Headers.Get(hdrContentTypeKey)) +func testGlobalClientSettingWrapper(t *testing.T, cs ...*Client) { + for _, c := range cs { + assertNotNil(t, c) + } } func TestGlobalWrapper(t *testing.T) { @@ -618,9 +611,6 @@ func TestGlobalWrapper(t *testing.T) { testGlobalWrapperEnableDumps(t) DisableInsecureSkipVerify() - SetCookieJar(nil) - assertEqual(t, nil, DefaultClient().httpClient.Jar) - testErr := errors.New("test") testDial := func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, testErr @@ -628,63 +618,6 @@ func TestGlobalWrapper(t *testing.T) { testDialTLS := func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, testErr } - SetDialTLS(testDialTLS) - SetDial(testDial) - _, err := DefaultClient().t.DialTLSContext(nil, "", "") - assertEqual(t, testErr, err) - _, err = DefaultClient().t.DialContext(nil, "", "") - assertEqual(t, testErr, err) - - timeout := 2 * time.Second - SetTLSHandshakeTimeout(timeout) - assertEqual(t, timeout, DefaultClient().t.TLSHandshakeTimeout) - - EnableAllowGetMethodPayload() - assertEqual(t, true, DefaultClient().AllowGetMethodPayload) - DisableAllowGetMethodPayload() - assertEqual(t, false, DefaultClient().AllowGetMethodPayload) - - b := []byte("test") - s := string(b) - r := SetBodyXmlMarshal(0) - assertEqual(t, true, len(r.body) > 0) - - r = SetBodyString(s) - assertEqual(t, true, len(r.body) > 0) - - r = SetBodyBytes(b) - assertEqual(t, true, len(r.body) > 0) - - r = SetBodyJsonBytes(b) - assertEqual(t, true, len(r.body) > 0) - - r = SetBodyJsonString(s) - assertEqual(t, true, len(r.body) > 0) - - r = SetBodyXmlBytes(b) - assertEqual(t, true, len(r.body) > 0) - - r = SetBodyXmlString(s) - assertEqual(t, true, len(r.body) > 0) - - r = SetBody(nil) - assertEqual(t, true, r.RawRequest.Body == nil) - - r = SetBodyJsonMarshal(User{ - Name: "roc", - }) - assertEqual(t, true, r.RawRequest.Body != nil) - - r = EnableTrace() - assertEqual(t, true, r.trace != nil) - r = DisableTrace() - assertEqual(t, true, r.trace == nil) - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Second) - r = SetContext(ctx) - assertEqual(t, ctx, r.Context()) - cancel() marshalFunc := func(v interface{}) ([]byte, error) { return nil, testErr @@ -692,261 +625,131 @@ func TestGlobalWrapper(t *testing.T) { unmarshalFunc := func(data []byte, v interface{}) error { return testErr } - SetJsonMarshal(marshalFunc) - SetJsonUnmarshal(unmarshalFunc) - SetXmlMarshal(marshalFunc) - SetXmlUnmarshal(unmarshalFunc) - _, err = DefaultClient().jsonMarshal(nil) - assertEqual(t, testErr, err) - err = DefaultClient().jsonUnmarshal(nil, nil) - assertEqual(t, testErr, err) - _, err = DefaultClient().xmlMarshal(nil) - assertEqual(t, testErr, err) - err = DefaultClient().xmlUnmarshal(nil, nil) - assertEqual(t, testErr, err) - - EnableTraceAll() - assertEqual(t, true, DefaultClient().trace) - DisableTraceAll() - assertEqual(t, false, DefaultClient().trace) - - len1 := len(DefaultClient().afterResponse) - OnAfterResponse(func(client *Client, response *Response) error { - return nil - }) - len2 := len(DefaultClient().afterResponse) - assertEqual(t, true, len1+1 == len2) - - OnBeforeRequest(func(client *Client, request *Request) error { - return nil - }) - assertEqual(t, true, len(DefaultClient().udBeforeRequest) == 1) - - SetProxyURL("http://dummy.proxy.local") - u, err := DefaultClient().t.Proxy(nil) - assertError(t, err) - assertEqual(t, "http://dummy.proxy.local", u.String()) - - SetProxyURL("bad url") - u, err = DefaultClient().t.Proxy(nil) - assertError(t, err) - assertNotEqual(t, "bad url", u.String()) - - u, _ = url.Parse("http://dummy.proxy.local") + u, _ := url.Parse("http://dummy.proxy.local") proxy := http.ProxyURL(u) - SetProxy(proxy) - uu, err := DefaultClient().t.Proxy(nil) - assertError(t, err) - assertEqual(t, u.String(), uu.String()) - - SetCommonContentType(jsonContentType) - assertEqual(t, jsonContentType, DefaultClient().Headers.Get(hdrContentTypeKey)) - - SetCommonHeader("my-header", "my-value") - assertEqual(t, "my-value", DefaultClient().Headers.Get("my-header")) - - SetCommonHeaders(map[string]string{ - "header1": "value1", - "header2": "value2", - }) - assertEqual(t, "value1", DefaultClient().Headers.Get("header1")) - assertEqual(t, "value2", DefaultClient().Headers.Get("header2")) - - SetCommonBasicAuth("imroc", "123456") - assertEqual(t, "Basic aW1yb2M6MTIzNDU2", DefaultClient().Headers.Get("Authorization")) - - SetCommonBearerAuthToken("123456") - assertEqual(t, "Bearer 123456", DefaultClient().Headers.Get("Authorization")) - - SetUserAgent("test") - assertEqual(t, "test", DefaultClient().Headers.Get(hdrUserAgentKey)) - - SetTimeout(timeout) - assertEqual(t, timeout, DefaultClient().httpClient.Timeout) - - l := createDefaultLogger() - SetLogger(l) - assertEqual(t, l, DefaultClient().log) - - SetScheme("https") - assertEqual(t, "https", DefaultClient().scheme) - - EnableDebugLog() - assertEqual(t, true, DefaultClient().DebugLog) - - DisableDebugLog() - assertEqual(t, false, DefaultClient().DebugLog) - - SetCommonCookies(&http.Cookie{Name: "test", Value: "test"}) - assertEqual(t, "test", DefaultClient().Cookies[0].Name) - - SetCommonQueryString("test1=test1") - assertEqual(t, "test1", DefaultClient().QueryParams.Get("test1")) - - SetCommonPathParams(map[string]string{"test1": "test1"}) - assertEqual(t, "test1", DefaultClient().PathParams["test1"]) - - SetCommonPathParam("test2", "test2") - assertEqual(t, "test2", DefaultClient().PathParams["test2"]) - - AddCommonQueryParam("test1", "test11") - assertEqual(t, []string{"test1", "test11"}, DefaultClient().QueryParams["test1"]) - - SetCommonQueryParam("test1", "test111") - assertEqual(t, "test111", DefaultClient().QueryParams.Get("test1")) - - SetCommonQueryParams(map[string]string{"test1": "test1"}) - assertEqual(t, "test1", DefaultClient().QueryParams.Get("test1")) - - EnableInsecureSkipVerify() - assertEqual(t, true, DefaultClient().t.TLSClientConfig.InsecureSkipVerify) - - DisableInsecureSkipVerify() - assertEqual(t, false, DefaultClient().t.TLSClientConfig.InsecureSkipVerify) - - DisableCompression() - assertEqual(t, true, DefaultClient().t.DisableCompression) - - EnableCompression() - assertEqual(t, false, DefaultClient().t.DisableCompression) - - DisableKeepAlives() - assertEqual(t, true, DefaultClient().t.DisableKeepAlives) - - EnableKeepAlives() - assertEqual(t, false, DefaultClient().t.DisableKeepAlives) - - config := GetTLSClientConfig() - assertEqual(t, config, DefaultClient().t.TLSClientConfig) - - SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")) - assertEqual(t, true, DefaultClient().t.TLSClientConfig.RootCAs != nil) - - SetRootCertFromString(string(tests.GetTestFileContent(t, "sample-root.pem"))) - assertEqual(t, true, DefaultClient().t.TLSClientConfig.RootCAs != nil) - - SetCerts(tls.Certificate{}, tls.Certificate{}) - assertEqual(t, true, len(DefaultClient().t.TLSClientConfig.Certificates) == 2) - - SetCertFromFile( - tests.GetTestFilePath("sample-client.pem"), - tests.GetTestFilePath("sample-client-key.pem"), - ) - assertEqual(t, true, len(DefaultClient().t.TLSClientConfig.Certificates) == 3) - - SetOutputDirectory(testDataPath) - assertEqual(t, testDataPath, DefaultClient().outputDirectory) - - baseURL := "http://dummy-req.local/test" - SetBaseURL(baseURL) - assertEqual(t, baseURL, DefaultClient().BaseURL) - form := make(url.Values) form.Add("test", "test") - SetCommonFormDataFromValues(form) - assertEqual(t, form, DefaultClient().FormData) - - SetCommonFormData(map[string]string{"test2": "test2"}) - assertEqual(t, "test2", DefaultClient().FormData.Get("test2")) - - DisableAutoReadResponse() - assertEqual(t, true, DefaultClient().disableAutoReadResponse) - EnableAutoReadResponse() - assertEqual(t, false, DefaultClient().disableAutoReadResponse) - - EnableDumpAll() - opt := DefaultClient().getDumpOptions() - assertEqual(t, true, opt.RequestHeader == true && opt.RequestBody == true && opt.ResponseHeader == true && opt.ResponseBody == true) - EnableDumpAllAsync() - assertEqual(t, true, opt.Async) - EnableDumpAllWithoutBody() - assertEqual(t, true, opt.ResponseBody == false && opt.RequestBody == false) - opt.ResponseBody = true - opt.RequestBody = true - EnableDumpAllWithoutResponse() - assertEqual(t, true, opt.ResponseBody == false && opt.ResponseHeader == false) - opt.ResponseBody = true - opt.ResponseHeader = true - EnableDumpAllWithoutRequest() - assertEqual(t, true, opt.RequestHeader == false && opt.RequestBody == false) - opt.RequestHeader = true - opt.RequestBody = true - EnableDumpAllWithoutHeader() - assertEqual(t, true, opt.RequestHeader == false && opt.ResponseHeader == false) - - DefaultClient().getDumpOptions().Output = nil - SetLogger(nil) - EnableDumpAllToFile(filepath.Join(testDataPath, "path-not-exists", "dump.out")) - assertEqual(t, true, DefaultClient().getDumpOptions().Output == nil) - dumpFile := tests.GetTestFilePath("tmpdump.out") - EnableDumpAllToFile(dumpFile) - assertEqual(t, true, DefaultClient().getDumpOptions().Output != nil) - os.Remove(dumpFile) + testGlobalClientSettingWrapper(t, + SetCookieJar(nil), + SetDialTLS(testDialTLS), + SetDial(testDial), + SetTLSHandshakeTimeout(time.Second), + EnableAllowGetMethodPayload(), + DisableAllowGetMethodPayload(), + SetJsonMarshal(marshalFunc), + SetJsonUnmarshal(unmarshalFunc), + SetXmlMarshal(marshalFunc), + SetXmlUnmarshal(unmarshalFunc), + EnableTraceAll(), + DisableTraceAll(), + OnAfterResponse(func(client *Client, response *Response) error { + return nil + }), + OnBeforeRequest(func(client *Client, request *Request) error { + return nil + }), + SetProxyURL("http://dummy.proxy.local"), + SetProxyURL("bad url"), + SetProxy(proxy), + SetCommonContentType(jsonContentType), + SetCommonHeader("my-header", "my-value"), + SetCommonHeaders(map[string]string{ + "header1": "value1", + "header2": "value2", + }), + SetCommonBasicAuth("imroc", "123456"), + SetCommonBearerAuthToken("123456"), + SetUserAgent("test"), + SetTimeout(1*time.Second), + SetLogger(createDefaultLogger()), + SetScheme("https"), + EnableDebugLog(), + DisableDebugLog(), + SetCommonCookies(&http.Cookie{Name: "test", Value: "test"}), + SetCommonQueryString("test1=test1"), + SetCommonPathParams(map[string]string{"test1": "test1"}), + SetCommonPathParam("test2", "test2"), + AddCommonQueryParam("test1", "test11"), + SetCommonQueryParam("test1", "test111"), + SetCommonQueryParams(map[string]string{"test1": "test1"}), + EnableInsecureSkipVerify(), + DisableInsecureSkipVerify(), + DisableCompression(), + EnableCompression(), + DisableKeepAlives(), + EnableKeepAlives(), + SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")), + SetRootCertFromString(string(tests.GetTestFileContent(t, "sample-root.pem"))), + SetCerts(tls.Certificate{}, tls.Certificate{}), + SetCertFromFile( + tests.GetTestFilePath("sample-client.pem"), + tests.GetTestFilePath("sample-client-key.pem"), + ), + SetOutputDirectory(testDataPath), + SetBaseURL("http://dummy-req.local/test"), + SetCommonFormDataFromValues(form), + SetCommonFormData(map[string]string{"test2": "test2"}), + DisableAutoReadResponse(), + EnableAutoReadResponse(), + EnableDumpAll(), + EnableDumpAllAsync(), + EnableDumpAllWithoutBody(), + EnableDumpAllWithoutResponse(), + EnableDumpAllWithoutRequest(), + EnableDumpAllWithoutHeader(), + SetLogger(nil), + EnableDumpAllToFile(filepath.Join(testDataPath, "path-not-exists", "dump.out")), + EnableDumpAllToFile(tests.GetTestFilePath("tmpdump.out")), + SetCommonDumpOptions(&DumpOptions{ + RequestHeader: true, + }), + DisableDumpAll(), + SetRedirectPolicy(NoRedirectPolicy()), + EnableForceHTTP1(), + EnableForceHTTP2(), + DisableForceHttpVersion(), + SetAutoDecodeContentType("json"), + SetAutoDecodeContentTypeFunc(func(contentType string) bool { return true }), + SetAutoDecodeAllContentType(), + DisableAutoDecode(), + EnableAutoDecode(), + AddCommonRetryCondition(func(resp *Response, err error) bool { return true }), + SetCommonRetryCondition(func(resp *Response, err error) bool { return true }), + AddCommonRetryHook(func(resp *Response, err error) {}), + SetCommonRetryHook(func(resp *Response, err error) {}), + SetCommonRetryCount(2), + SetCommonRetryInterval(func(attempt int) time.Duration { + return 1 * time.Second + }), + SetCommonRetryBackoffInterval(1*time.Millisecond, 2*time.Second), + SetCommonRetryFixedInterval(1*time.Second), + ) + os.Remove(tests.GetTestFilePath("tmpdump.out")) - SetCommonDumpOptions(&DumpOptions{ - RequestHeader: true, - }) - opt = DefaultClient().getDumpOptions() - assertEqual(t, true, opt.RequestHeader == true && opt.ResponseHeader == false) - DisableDumpAll() - assertEqual(t, true, DefaultClient().t.dump == nil) + config := GetTLSClientConfig() + assertEqual(t, config, DefaultClient().t.TLSClientConfig) - r = R() + r := R() assertEqual(t, true, r != nil) c := C() + c.SetTimeout(10 * time.Second) SetDefaultClient(c) assertEqual(t, true, DefaultClient().httpClient.Timeout == 10*time.Second) - - SetRedirectPolicy(NoRedirectPolicy()) - assertEqual(t, true, DefaultClient().httpClient.CheckRedirect != nil) - - EnableForceHTTP1() - assertEqual(t, HTTP1, DefaultClient().t.ForceHttpVersion) - - EnableForceHTTP2() - assertEqual(t, HTTP2, DefaultClient().t.ForceHttpVersion) - - DisableForceHttpVersion() - assertEqual(t, true, DefaultClient().t.ForceHttpVersion == "") - assertEqual(t, GetClient(), DefaultClient().httpClient) r = NewRequest() assertEqual(t, true, r != nil) c = NewClient() assertEqual(t, true, c != nil) - - DefaultClient().getResponseOptions().AutoDecodeContentType = nil - SetAutoDecodeContentType("json") - assertEqual(t, true, DefaultClient().getResponseOptions().AutoDecodeContentType != nil) - - DefaultClient().getResponseOptions().AutoDecodeContentType = nil - SetAutoDecodeContentTypeFunc(func(contentType string) bool { return true }) - assertEqual(t, true, DefaultClient().getResponseOptions().AutoDecodeContentType != nil) - - DefaultClient().getResponseOptions().AutoDecodeContentType = nil - SetAutoDecodeAllContentType() - assertEqual(t, true, DefaultClient().getResponseOptions().AutoDecodeContentType != nil) - - DisableAutoDecode() - assertEqual(t, true, DefaultClient().getResponseOptions().DisableAutoDecode) - - EnableAutoDecode() - assertEqual(t, false, DefaultClient().getResponseOptions().DisableAutoDecode) } func TestTrailer(t *testing.T) { - resp, err := tc().EnableForceHTTP1().R().EnableDump().Get("/chunked") + resp, err := tc().EnableForceHTTP1().R().Get("/chunked") assertSuccess(t, resp, err) _, ok := resp.Trailer["Expires"] if !ok { t.Error("trailer not exists") } - r := tc().EnableForceHTTP1().R() - r.RawRequest.Trailer = make(http.Header) - r.RawRequest.Trailer.Add("test", "") - resp, err = r.SetBody("test").Post("/") - assertSuccess(t, resp, err) } diff --git a/request.go b/request.go index 526de6dd..5ba4fda0 100644 --- a/request.go +++ b/request.go @@ -3,6 +3,7 @@ package req import ( "bytes" "context" + "errors" "github.com/hashicorp/go-multierror" "github.com/imroc/req/v3/internal/util" "io" @@ -19,19 +20,26 @@ import ( // req client. Request provides lots of chainable settings which can // override client level settings. type Request struct { - URL string - PathParams map[string]string - QueryParams urlpkg.Values - FormData urlpkg.Values - Headers http.Header - Cookies []*http.Cookie - Result interface{} - Error interface{} - error error - client *Client - RawRequest *http.Request - StartTime time.Time - + PathParams map[string]string + QueryParams urlpkg.Values + FormData urlpkg.Values + Headers http.Header + Cookies []*http.Cookie + Result interface{} + Error interface{} + error error + client *Client + RawRequest *http.Request + StartTime time.Time + RetryAttempt int + + RawURL string // read only + method string + URL *urlpkg.URL + getBody GetContentFunc + unReplayableBody io.ReadCloser + retryOption *retryOption + bodyReadCloser io.ReadCloser body []byte dumpOptions *DumpOptions marshalBody interface{} @@ -47,6 +55,8 @@ type Request struct { responseReturnTime time.Time } +type GetContentFunc func() (io.ReadCloser, error) + func (r *Request) getHeader(key string) string { if r.Headers == nil { return "" @@ -211,7 +221,12 @@ func (r *Request) SetFileReader(paramName, filename string, reader io.Reader) *R r.SetFileUpload(FileUpload{ ParamName: paramName, FileName: filename, - File: reader, + GetFileContent: func() (io.ReadCloser, error) { + if rc, ok := reader.(io.ReadCloser); ok { + return rc, nil + } + return ioutil.NopCloser(reader), nil + }, }) return r } @@ -267,11 +282,30 @@ func SetFileUpload(f ...FileUpload) *Request { return defaultClient.R().SetFileUpload(f...) } +var errMissingParamName = errors.New("missing param name in multipart file upload") +var errMissingFileName = errors.New("missing filename in multipart file upload") +var errMissingFileContent = errors.New("missing file content in multipart file upload") + // SetFileUpload set the fully custimized multipart file upload options. func (r *Request) SetFileUpload(uploads ...FileUpload) *Request { r.isMultiPart = true for _, upload := range uploads { - r.uploadFiles = append(r.uploadFiles, &upload) + shouldAppend := true + if upload.ParamName == "" { + r.appendError(errMissingParamName) + shouldAppend = false + } + if upload.FileName == "" { + r.appendError(errMissingFileName) + shouldAppend = false + } + if upload.GetFileContent == nil { + r.appendError(errMissingFileContent) + shouldAppend = false + } + if shouldAppend { + r.uploadFiles = append(r.uploadFiles, &upload) + } } return r } @@ -460,6 +494,8 @@ func (r *Request) appendError(err error) { r.error = multierror.Append(r.error, err) } +var errRetryableWithUnReplayableBody = errors.New("retryable request should not have unreplayable body (io.Reader)") + // Send fires http request and return the *Response which is always // not nil, and the error is not nil if some error happens. func (r *Request) Send(method, url string) (*Response, error) { @@ -469,8 +505,11 @@ func (r *Request) Send(method, url string) (*Response, error) { if r.error != nil { return &Response{Request: r}, r.error } - r.RawRequest.Method = method - r.URL = url + if r.retryOption != nil && r.retryOption.MaxRetries > 0 && r.unReplayableBody != nil { // retryable request should not have unreplayable body + return &Response{Request: r}, errRetryableWithUnReplayableBody + } + r.method = method + r.RawURL = url return r.client.do(r) } @@ -676,13 +715,23 @@ func (r *Request) SetBody(body interface{}) *Request { } switch b := body.(type) { case io.ReadCloser: - r.RawRequest.Body = b + r.unReplayableBody = b + r.getBody = func() (io.ReadCloser, error) { + return r.unReplayableBody, nil + } case io.Reader: - r.RawRequest.Body = ioutil.NopCloser(b) + r.unReplayableBody = ioutil.NopCloser(b) + r.getBody = func() (io.ReadCloser, error) { + return r.unReplayableBody, nil + } case []byte: r.SetBodyBytes(b) case string: r.SetBodyString(b) + case func() (io.ReadCloser, error): + r.getBody = b + case GetContentFunc: + r.getBody = b default: r.marshalBody = body } @@ -697,8 +746,10 @@ func SetBodyBytes(body []byte) *Request { // SetBodyBytes set the request body as []byte. func (r *Request) SetBodyBytes(body []byte) *Request { - r.RawRequest.Body = ioutil.NopCloser(bytes.NewReader(body)) r.body = body + r.getBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewReader(body)), nil + } return r } @@ -1019,3 +1070,123 @@ func (r *Request) EnableDumpWithoutResponseBody() *Request { o.ResponseBody = false return r.EnableDump() } + +func (r *Request) getRetryOption() *retryOption { + if r.retryOption == nil { + r.retryOption = newDefaultRetryOption() + } + return r.retryOption +} + +// SetRetryCount is a global wrapper methods which delegated +// to the default client, create a request and SetRetryCount for request. +func SetRetryCount(count int) *Request { + return defaultClient.R().SetRetryCount(count) +} + +// SetRetryCount enables retry and set the maximum retry count. +func (r *Request) SetRetryCount(count int) *Request { + r.getRetryOption().MaxRetries = count + return r +} + +// SetRetryInterval is a global wrapper methods which delegated +// to the default client, create a request and SetRetryInterval for request. +func SetRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Request { + return defaultClient.R().SetRetryInterval(getRetryIntervalFunc) +} + +// SetRetryInterval sets the custom GetRetryIntervalFunc, you can use this to +// implement your own backoff retry algorithm. +// For example: +// req.SetRetryInterval(func(attempt int) time.Duration { +// sleep := 0.01 * math.Exp2(float64(attempt)) +// return time.Duration(math.Min(2, sleep)) * time.Second +// }) +func (r *Request) SetRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Request { + r.getRetryOption().GetRetryInterval = getRetryIntervalFunc + return r +} + +// SetRetryFixedInterval is a global wrapper methods which delegated +// to the default client, create a request and SetRetryFixedInterval for request. +func SetRetryFixedInterval(interval time.Duration) *Request { + return defaultClient.R().SetRetryFixedInterval(interval) +} + +// SetRetryFixedInterval set retry to use a fixed interval. +func (r *Request) SetRetryFixedInterval(interval time.Duration) *Request { + r.getRetryOption().GetRetryInterval = func(attempt int) time.Duration { + return interval + } + return r +} + +// SetRetryBackoffInterval is a global wrapper methods which delegated +// to the default client, create a request and SetRetryBackoffInterval for request. +func SetRetryBackoffInterval(min, max time.Duration) *Request { + return defaultClient.R().SetRetryBackoffInterval(min, max) +} + +// SetRetryBackoffInterval set retry to use a capped exponential backoff with jitter. +// https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ +func (r *Request) SetRetryBackoffInterval(min, max time.Duration) *Request { + r.getRetryOption().GetRetryInterval = backoffInterval(min, max) + return r +} + +// SetRetryHook is a global wrapper methods which delegated +// to the default client, create a request and SetRetryHook for request. +func SetRetryHook(hook RetryHookFunc) *Request { + return defaultClient.R().SetRetryHook(hook) +} + +// SetRetryHook set the retry hook which will be executed before a retry. +// It will override other retry hooks if any been added before (including +// client-level retry hooks). +func (r *Request) SetRetryHook(hook RetryHookFunc) *Request { + r.getRetryOption().RetryHooks = []RetryHookFunc{hook} + return r +} + +// AddRetryHook is a global wrapper methods which delegated +// to the default client, create a request and AddRetryHook for request. +func AddRetryHook(hook RetryHookFunc) *Request { + return defaultClient.R().AddRetryHook(hook) +} + +// AddRetryHook adds a retry hook which will be executed before a retry. +func (r *Request) AddRetryHook(hook RetryHookFunc) *Request { + ro := r.getRetryOption() + ro.RetryHooks = append(ro.RetryHooks, hook) + return r +} + +// SetRetryCondition is a global wrapper methods which delegated +// to the default client, create a request and SetRetryCondition for request. +func SetRetryCondition(condition RetryConditionFunc) *Request { + return defaultClient.R().SetRetryCondition(condition) +} + +// SetRetryCondition sets the retry condition, which determines whether the +// request should retry. +// It will override other retry conditions if any been added before (including +// client-level retry conditions). +func (r *Request) SetRetryCondition(condition RetryConditionFunc) *Request { + r.getRetryOption().RetryConditions = []RetryConditionFunc{condition} + return r +} + +// AddRetryCondition is a global wrapper methods which delegated +// to the default client, create a request and AddRetryCondition for request. +func AddRetryCondition(condition RetryConditionFunc) *Request { + return defaultClient.R().AddRetryCondition(condition) +} + +// AddRetryCondition adds a retry condition, which determines whether the +// request should retry. +func (r *Request) AddRetryCondition(condition RetryConditionFunc) *Request { + ro := r.getRetryOption() + ro.RetryConditions = append(ro.RetryConditions, condition) + return r +} diff --git a/request_test.go b/request_test.go index 17757b14..4b16661d 100644 --- a/request_test.go +++ b/request_test.go @@ -1,12 +1,16 @@ package req import ( + "bytes" "encoding/json" "encoding/xml" "fmt" "github.com/imroc/req/v3/internal/tests" + "io" + "io/ioutil" "net/http" "net/url" + "os" "strings" "testing" "time" @@ -71,6 +75,14 @@ func TestEnableDump(t *testing.T) { }) } +func TestEnableDumpToFIle(t *testing.T) { + tmpFile := "tmp_dumpfile_req" + resp, err := tc().R().EnableDumpToFile(tests.GetTestFilePath(tmpFile)).Get("/") + assertSuccess(t, resp, err) + assertEqual(t, true, len(tests.GetTestFileContent(t, tmpFile)) > 0) + os.Remove(tests.GetTestFilePath(tmpFile)) +} + func TestEnableDumpWithoutRequest(t *testing.T) { testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { r.EnableDumpWithoutRequest() @@ -282,6 +294,33 @@ func testSetBodyMarshal(t *testing.T, c *Client) { } } +func TestSetBodyReader(t *testing.T) { + var e echo + resp, err := tc().R().SetBody(ioutil.NopCloser(bytes.NewBufferString("hello"))).SetResult(&e).Post("/echo") + assertSuccess(t, resp, err) + assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) + assertEqual(t, "hello", e.Body) +} + +func TestSetBodyGetContentFunc(t *testing.T) { + var e echo + resp, err := tc().R().SetBody(func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewBufferString("hello")), nil + }).SetResult(&e).Post("/echo") + assertSuccess(t, resp, err) + assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) + assertEqual(t, "hello", e.Body) + + e = echo{} + var fn GetContentFunc = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewBufferString("hello")), nil + } + resp, err = tc().R().SetBody(fn).SetResult(&e).Post("/echo") + assertSuccess(t, resp, err) + assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) + assertEqual(t, "hello", e.Body) +} + func TestSetBodyContent(t *testing.T) { testSetBodyContent(t, tc()) testSetBodyContent(t, tc().EnableForceHTTP1()) @@ -684,6 +723,15 @@ func TestAutoDetectRequestContentType(t *testing.T) { assertEqual(t, plainTextContentType, resp.String()) } +func TestSetFileUploadCheck(t *testing.T) { + c := tc() + resp, err := c.R().SetFileUpload(FileUpload{}).Post("/multipart") + tests.AssertErrorContains(t, err, "missing param name") + tests.AssertErrorContains(t, err, "missing filename") + tests.AssertErrorContains(t, err, "missing file content") + assertEqual(t, 0, len(resp.Request.uploadFiles)) +} + func TestUploadMultipart(t *testing.T) { m := make(map[string]interface{}) resp, err := tc().R(). diff --git a/retry.go b/retry.go new file mode 100644 index 00000000..92f2ee8f --- /dev/null +++ b/retry.go @@ -0,0 +1,63 @@ +package req + +import ( + "math" + "math/rand" + "time" +) + +func defaultGetRetryInterval(attempt int) time.Duration { + return 100 * time.Millisecond +} + +// RetryConditionFunc is a retry condition, which determines +// whether the request should retry. +type RetryConditionFunc func(*Response, error) bool + +// RetryHookFunc is a retry hook which will be executed before a retry. +type RetryHookFunc func(*Response, error) + +// GetRetryIntervalFunc is a function that determines how long should +// sleep between retry attempts. +type GetRetryIntervalFunc func(attempt int) time.Duration + +func backoffInterval(min, max time.Duration) GetRetryIntervalFunc { + base := float64(min) + capLevel := float64(max) + return func(attempt int) time.Duration { + temp := math.Min(capLevel, base*math.Exp2(float64(attempt))) + halfTemp := int64(temp / 2) + sleep := halfTemp + rand.Int63n(halfTemp) + return time.Duration(sleep) + } +} + +func newDefaultRetryOption() *retryOption { + return &retryOption{ + GetRetryInterval: defaultGetRetryInterval, + } +} + +type retryOption struct { + MaxRetries int + GetRetryInterval GetRetryIntervalFunc + RetryConditions []RetryConditionFunc + RetryHooks []RetryHookFunc +} + +func (ro *retryOption) Clone() *retryOption { + if ro == nil { + return nil + } + o := &retryOption{ + MaxRetries: ro.MaxRetries, + GetRetryInterval: ro.GetRetryInterval, + } + for _, c := range ro.RetryConditions { + o.RetryConditions = append(o.RetryConditions, c) + } + for _, h := range ro.RetryHooks { + o.RetryHooks = append(o.RetryHooks, h) + } + return o +} diff --git a/retry_test.go b/retry_test.go new file mode 100644 index 00000000..d97ffc37 --- /dev/null +++ b/retry_test.go @@ -0,0 +1,129 @@ +package req + +import ( + "bytes" + "github.com/imroc/req/v3/internal/tests" + "io/ioutil" + "math" + "net/http" + "testing" + "time" +) + +func TestRetryBackOff(t *testing.T) { + testRetry(t, func(r *Request) { + r.SetRetryBackoffInterval(10*time.Millisecond, 1*time.Second) + }) +} + +func testRetry(t *testing.T, setFunc func(r *Request)) { + attempt := 0 + r := tc().R(). + SetRetryCount(3). + SetRetryCondition(func(resp *Response, err error) bool { + return (err != nil) || (resp.StatusCode == http.StatusTooManyRequests) + }). + SetRetryHook(func(resp *Response, err error) { + attempt++ + }) + setFunc(r) + resp, err := r.Get("/too-many") + tests.AssertNoError(t, err) + assertEqual(t, 3, resp.Request.RetryAttempt) + assertEqual(t, 3, attempt) +} + +func TestRetryInterval(t *testing.T) { + testRetry(t, func(r *Request) { + r.SetRetryInterval(func(attempt int) time.Duration { + sleep := 0.01 * math.Exp2(float64(attempt)) + return time.Duration(math.Min(2, sleep)) * time.Second + }) + }) +} + +func TestRetryFixedInterval(t *testing.T) { + testRetry(t, func(r *Request) { + r.SetRetryFixedInterval(1 * time.Millisecond) + }) +} + +func TestAddRetryHook(t *testing.T) { + test := "test1" + testRetry(t, func(r *Request) { + r.AddRetryHook(func(resp *Response, err error) { + test = "test2" + }) + }) + assertEqual(t, "test2", test) +} + +func TestRetryOverride(t *testing.T) { + c := tc(). + SetCommonRetryCount(3). + SetCommonRetryHook(func(resp *Response, err error) {}). + AddCommonRetryHook(func(resp *Response, err error) {}). + SetCommonRetryCondition(func(resp *Response, err error) bool { + return false + }).SetCommonRetryBackoffInterval(1*time.Millisecond, 10*time.Millisecond) + test := "test" + resp, err := c.R().SetRetryFixedInterval(2 * time.Millisecond). + SetRetryCount(2). + SetRetryHook(func(resp *Response, err error) { + test = "test1" + }).SetRetryCondition(func(resp *Response, err error) bool { + return err != nil || resp.StatusCode == http.StatusTooManyRequests + }).Get("/too-many") + tests.AssertNoError(t, err) + assertEqual(t, "test1", test) + assertEqual(t, 2, resp.Request.RetryAttempt) +} + +func TestAddRetryCondition(t *testing.T) { + attempt := 0 + resp, err := tc().R(). + SetRetryCount(3). + AddRetryCondition(func(resp *Response, err error) bool { + return err != nil + }). + AddRetryCondition(func(resp *Response, err error) bool { + return resp.StatusCode == http.StatusServiceUnavailable + }). + SetRetryHook(func(resp *Response, err error) { + attempt++ + }).Get("/too-many") + tests.AssertNoError(t, err) + assertEqual(t, 0, attempt) + assertEqual(t, 0, resp.Request.RetryAttempt) + + attempt = 0 + resp, err = tc(). + SetCommonRetryCount(3). + AddCommonRetryCondition(func(resp *Response, err error) bool { + return err != nil + }). + AddCommonRetryCondition(func(resp *Response, err error) bool { + return resp.StatusCode == http.StatusServiceUnavailable + }). + SetCommonRetryHook(func(resp *Response, err error) { + attempt++ + }).R().Get("/too-many") + tests.AssertNoError(t, err) + assertEqual(t, 0, attempt) + assertEqual(t, 0, resp.Request.RetryAttempt) + +} + +func TestRetryWithUnreplayableBody(t *testing.T) { + _, err := tc().R(). + SetRetryCount(1). + SetBody(bytes.NewBufferString("test")). + Post("/") + assertEqual(t, errRetryableWithUnReplayableBody, err) + + _, err = tc().R(). + SetRetryCount(1). + SetBody(ioutil.NopCloser(bytes.NewBufferString("test"))). + Post("/") + assertEqual(t, errRetryableWithUnReplayableBody, err) +} From db6a663f23e19e189027c4b904aef603d5d1168a Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Mar 2022 17:59:36 +0800 Subject: [PATCH 424/843] api doc add retry --- docs/api.md | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 8a0529d6..6e8e5d06 100644 --- a/docs/api.md +++ b/docs/api.md @@ -13,6 +13,7 @@ Here is a brief and categorized list of the core APIs, for a more detailed and c * [TLS and Certificates](#Certs) * [Marshal&Unmarshal](#Marshal) * [HTTP Version](#Version) + * [Retry](#Retry-Client) * [Other Settings](#Other) * [Request Settings](#Request) * [URL Query and Path Parameter](#Query) @@ -21,6 +22,7 @@ Here is a brief and categorized list of the core APIs, for a more detailed and c * [Request Level Debug](#Debug-Request) * [Multipart & Form & Upload](#Multipart) * [Download](#Download) + * [Retry](#Retry) * [Other Settings](#Other-Request) * [Sending Request](#Send-Request) @@ -109,6 +111,17 @@ Basically, you can know the meaning of most settings directly from the method na * [EnableForceHTTP2()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableForceHTTP2) * [EnableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableForceHTTP1) +### Retry + +* [SetCommonRetryCount(count int)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryCount) +* [SetCommonRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryInterval) +* [SetCommonRetryFixedInterval(interval time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryFixedInterval) +* [SetCommonRetryBackoffInterval(min, max time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryBackoffInterval) +* [SetCommonRetryHook(hook RetryHookFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryHook) +* [AddCommonRetryHook(hook RetryHookFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Client.AddCommonRetryHook) +* [SetCommonRetryCondition(condition RetryConditionFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryCondition) +* [AddCommonRetryCondition(condition RetryConditionFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Client.AddCommonRetryCondition) + ### Other Settings * [SetTimeout(d time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTimeout) @@ -140,7 +153,6 @@ Basically, you can know the meaning of most settings directly from the method na * [EnableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAllowGetMethodPayload) - Disabled by default. * [DisableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAllowGetMethodPayload) - ## Request Settings The following are the chainable settings of Request, all of which have corresponding global wrappers. @@ -209,6 +221,17 @@ Basically, you can know the meaning of most settings directly from the method na * [SetOutput(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutput) * [SetOutputFile(file string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutputFile) +### Retry + +* [SetRetryCount(count int)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryCount) +* [SetRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryInterval) +* [SetRetryFixedInterval(interval time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryFixedInterval) +* [SetRetryBackoffInterval(min, max time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryBackoffInterval) +* [SetRetryHook(hook RetryHookFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryHook) +* [AddRetryHook(hook RetryHookFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Request.AddRetryHook) +* [SetRetryCondition(condition RetryConditionFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryCondition) +* [AddRetryCondition(condition RetryConditionFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Request.AddRetryCondition) + ### Other Settings * [SetContext(ctx context.Context)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetContext) From 6b5e595f6cf0284fb21356cf357a3a72d347d6c5 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Mar 2022 18:01:27 +0800 Subject: [PATCH 425/843] remove blank line in api doc --- docs/api.md | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/docs/api.md b/docs/api.md index 6e8e5d06..abc8ac2d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -35,11 +35,9 @@ Basically, you can know the meaning of most settings directly from the method na ### Debug Features * [DevMode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DevMode) - Enable all debug features (Dump, DebugLog and Trace). - * [EnableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDebugLog) - Enable debug level log (disabled by default). * [DisableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDebugLog) * [SetLogger(log Logger)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetLogger) - Set the customized logger. - * [EnableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAll) - Enable dump for all requests. * [DisableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDumpAll) * [SetCommonDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonDumpOptions) @@ -52,7 +50,6 @@ Basically, you can know the meaning of most settings directly from the method na * [EnableDumpAllWithoutRequestBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequestBody) * [EnableDumpAllWithoutResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponse) * [EnableDumpAllWithoutResponseBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponseBody) - * [EnableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableTraceAll) - Enable trace for all requests (disabled by default). * [DisableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableTraceAll) @@ -125,31 +122,21 @@ Basically, you can know the meaning of most settings directly from the method na ### Other Settings * [SetTimeout(d time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTimeout) - * [EnableKeepAlives()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableKeepAlives) * [DisableKeepAlives()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableKeepAlives) - Enabled by default. - * [SetScheme(scheme string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetScheme) * [SetBaseURL(u string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetBaseURL) - * [SetProxyURL(proxyUrl string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetProxyURL) * [SetProxy(proxy func(*http.Request) (*urlpkg.URL, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetProxy) - * [SetOutputDirectory(dir string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetOutputDirectory) - * [SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetDialTLS) * [SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetDial) - * [SetCookieJar(jar http.CookieJar)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCookieJar) - * [SetRedirectPolicy(policies ...RedirectPolicy)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRedirectPolicy) - * [EnableCompression()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableCompression) * [DisableCompression()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableCompression) - Enabled by default - * [EnableAutoReadResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAutoReadResponse) * [DisableAutoReadResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoReadResponse) - Enabled by default - * [EnableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAllowGetMethodPayload) - Disabled by default. * [DisableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAllowGetMethodPayload) From 85c453455ba0ce028b657fed622abe1343410899 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Mar 2022 19:53:01 +0800 Subject: [PATCH 426/843] optimize retry --- README.md | 40 ++++++++++++++++++++++++++++++++++++++++ client.go | 6 +++--- docs/api.md | 4 ++-- request.go | 4 ++-- retry.go | 6 +++--- 5 files changed, 50 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index ee079ec3..bae0b850 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ If you want to use the older version, check it out on [v1 branch](https://github * [Request and Response Middleware](#Middleware) * [Redirect Policy](#Redirect) * [Proxy](#Proxy) +* [Retry](#Retry) * [TODO List](#TODO) * [License](#License) @@ -879,6 +880,45 @@ client.SetProxy(func(request *http.Request) (*url.URL, error) { client.SetProxy(nil) ``` +## Retry + +You can enable retry for all requests at client-level (check the full list of client-level retry settings around [here](./docs/api.md#Retry-Client)): + +```bash +client := req.C() +client.SetCommonRetryCount(3). // enable retry and set the maximum retry count. + SetCommonRetryBackoffInterval(1 * time.Second, 5 * time.Second). // set the retry sleep interval with a commonly used algorithm: capped exponential backoff with jitter (https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/). + AddCommonRetryHook(func(resp *req.Response, err error){ // add a retry hook which will be executed before a retry. + req := resp.Request.RawRequest + fmt.Println("Retry request:", req.Method, req.URL) + }).AddCommonRetryCondition(func(resp *req.Response, err error) bool { // add a retry condition which determines whether the request should retry. + return err != nil + }).AddCommonRetryCondition(func(resp *req.Response, err error) bool { // add another retry condition + return resp.StatusCode == http.StatusTooManyRequests + }) +``` + +You can also override retry settings at request-level (check the full list of request-level retry settings around [here](./docs/api.md#Retry-Request)): + +```bash +client.R(). + SetRetryCount(3). + SetRetryInterval(func(resp *req.Response, attempt int) time.Duration { // use a custom retry interval algorithm. + // sleep seconds from "Retry-After" response header if given and correct, otherwise sleep 2 seconds (https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html). + if resp.Response != nil { + ra := resp.Header.Get("Retry-After") + if ra != "" { + after, err := strconv.Atoi(ra) + if err == nil { + return time.Duration(after) * time.Second + } + } + } + return 2 * time.Second + }).SetRetryHook(hook). // unlike add, set will remove all other retry hooks which is added before at both request and client level. + SetRetryCondition(condition) // similarly, this will remove all other retry conditions which is added before at both request and client level. +``` + ## TODO List * [ ] Wrap more transport settings into client. diff --git a/client.go b/client.go index 758bbc3b..0ef306d3 100644 --- a/client.go +++ b/client.go @@ -1245,7 +1245,7 @@ func SetCommonRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Client { // SetCommonRetryInterval sets the custom GetRetryIntervalFunc for all requests, // you can use this to implement your own backoff retry algorithm. // For example: -// req.SetCommonRetryInterval(func(attempt int) time.Duration { +// req.SetCommonRetryInterval(func(resp *req.Response, attempt int) time.Duration { // sleep := 0.01 * math.Exp2(float64(attempt)) // return time.Duration(math.Min(2, sleep)) * time.Second // }) @@ -1262,7 +1262,7 @@ func SetCommonRetryFixedInterval(interval time.Duration) *Client { // SetCommonRetryFixedInterval set retry to use a fixed interval for all requests. func (c *Client) SetCommonRetryFixedInterval(interval time.Duration) *Client { - c.getRetryOption().GetRetryInterval = func(attempt int) time.Duration { + c.getRetryOption().GetRetryInterval = func(resp *Response, attempt int) time.Duration { return interval } return c @@ -1533,7 +1533,7 @@ func (c *Client) do(r *Request) (resp *Response, err error) { for _, hook := range r.retryOption.RetryHooks { // run retry hooks hook(resp, err) } - time.Sleep(r.retryOption.GetRetryInterval(r.RetryAttempt)) + time.Sleep(r.retryOption.GetRetryInterval(resp, r.RetryAttempt)) // clean buffers if r.dumpBuffer != nil { diff --git a/docs/api.md b/docs/api.md index abc8ac2d..e424f298 100644 --- a/docs/api.md +++ b/docs/api.md @@ -22,7 +22,7 @@ Here is a brief and categorized list of the core APIs, for a more detailed and c * [Request Level Debug](#Debug-Request) * [Multipart & Form & Upload](#Multipart) * [Download](#Download) - * [Retry](#Retry) + * [Retry](#Retry-Request) * [Other Settings](#Other-Request) * [Sending Request](#Send-Request) @@ -208,7 +208,7 @@ Basically, you can know the meaning of most settings directly from the method na * [SetOutput(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutput) * [SetOutputFile(file string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutputFile) -### Retry +### Retry * [SetRetryCount(count int)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryCount) * [SetRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryInterval) diff --git a/request.go b/request.go index 5ba4fda0..7cf45c16 100644 --- a/request.go +++ b/request.go @@ -1099,7 +1099,7 @@ func SetRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Request { // SetRetryInterval sets the custom GetRetryIntervalFunc, you can use this to // implement your own backoff retry algorithm. // For example: -// req.SetRetryInterval(func(attempt int) time.Duration { +// req.SetRetryInterval(func(resp *req.Response, attempt int) time.Duration { // sleep := 0.01 * math.Exp2(float64(attempt)) // return time.Duration(math.Min(2, sleep)) * time.Second // }) @@ -1116,7 +1116,7 @@ func SetRetryFixedInterval(interval time.Duration) *Request { // SetRetryFixedInterval set retry to use a fixed interval. func (r *Request) SetRetryFixedInterval(interval time.Duration) *Request { - r.getRetryOption().GetRetryInterval = func(attempt int) time.Duration { + r.getRetryOption().GetRetryInterval = func(resp *Response, attempt int) time.Duration { return interval } return r diff --git a/retry.go b/retry.go index 92f2ee8f..055d0df6 100644 --- a/retry.go +++ b/retry.go @@ -6,7 +6,7 @@ import ( "time" ) -func defaultGetRetryInterval(attempt int) time.Duration { +func defaultGetRetryInterval(resp *Response, attempt int) time.Duration { return 100 * time.Millisecond } @@ -19,12 +19,12 @@ type RetryHookFunc func(*Response, error) // GetRetryIntervalFunc is a function that determines how long should // sleep between retry attempts. -type GetRetryIntervalFunc func(attempt int) time.Duration +type GetRetryIntervalFunc func(resp *Response, attempt int) time.Duration func backoffInterval(min, max time.Duration) GetRetryIntervalFunc { base := float64(min) capLevel := float64(max) - return func(attempt int) time.Duration { + return func(resp *Response, attempt int) time.Duration { temp := math.Min(capLevel, base*math.Exp2(float64(attempt))) halfTemp := int64(temp / 2) sleep := halfTemp + rand.Int63n(halfTemp) From 2c106f42a7b1f2d9ea053642271cc86ec69594c0 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Mar 2022 19:58:42 +0800 Subject: [PATCH 427/843] fix retry tests --- README.md | 25 +++++++++++++++---------- req_test.go | 4 ++-- retry_test.go | 2 +- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index bae0b850..9fb4b036 100644 --- a/README.md +++ b/README.md @@ -886,16 +886,21 @@ You can enable retry for all requests at client-level (check the full list of cl ```bash client := req.C() -client.SetCommonRetryCount(3). // enable retry and set the maximum retry count. - SetCommonRetryBackoffInterval(1 * time.Second, 5 * time.Second). // set the retry sleep interval with a commonly used algorithm: capped exponential backoff with jitter (https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/). - AddCommonRetryHook(func(resp *req.Response, err error){ // add a retry hook which will be executed before a retry. - req := resp.Request.RawRequest - fmt.Println("Retry request:", req.Method, req.URL) - }).AddCommonRetryCondition(func(resp *req.Response, err error) bool { // add a retry condition which determines whether the request should retry. - return err != nil - }).AddCommonRetryCondition(func(resp *req.Response, err error) bool { // add another retry condition - return resp.StatusCode == http.StatusTooManyRequests - }) + // enable retry and set the maximum retry count. + client.SetCommonRetryCount(3). + // set the retry sleep interval with a commonly used algorithm: capped exponential backoff with jitter (https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/). + SetCommonRetryBackoffInterval(1 * time.Second, 5 * time.Second). + // add a retry hook which will be executed before a retry. + AddCommonRetryHook(func(resp *req.Response, err error){ + req := resp.Request.RawRequest + fmt.Println("Retry request:", req.Method, req.URL) + }). + // add a retry condition which determines whether the request should retry. + AddCommonRetryCondition(func(resp *req.Response, err error) bool { + return err != nil + }).AddCommonRetryCondition(func(resp *req.Response, err error) bool { // add another retry condition + return resp.StatusCode == http.StatusTooManyRequests + }) ``` You can also override retry settings at request-level (check the full list of request-level retry settings around [here](./docs/api.md#Retry-Request)): diff --git a/req_test.go b/req_test.go index 57e46bf6..4790144f 100644 --- a/req_test.go +++ b/req_test.go @@ -578,7 +578,7 @@ func TestGlobalWrapperSetRequest(t *testing.T) { SetRetryHook(func(resp *Response, err error) {}), SetRetryBackoffInterval(1*time.Millisecond, 500*time.Millisecond), SetRetryFixedInterval(1*time.Millisecond), - SetRetryInterval(func(attempt int) time.Duration { + SetRetryInterval(func(resp *Response, attempt int) time.Duration { return 1 * time.Millisecond }), SetRetryCount(3), @@ -719,7 +719,7 @@ func TestGlobalWrapper(t *testing.T) { AddCommonRetryHook(func(resp *Response, err error) {}), SetCommonRetryHook(func(resp *Response, err error) {}), SetCommonRetryCount(2), - SetCommonRetryInterval(func(attempt int) time.Duration { + SetCommonRetryInterval(func(resp *Response, attempt int) time.Duration { return 1 * time.Second }), SetCommonRetryBackoffInterval(1*time.Millisecond, 2*time.Second), diff --git a/retry_test.go b/retry_test.go index d97ffc37..7181d125 100644 --- a/retry_test.go +++ b/retry_test.go @@ -35,7 +35,7 @@ func testRetry(t *testing.T, setFunc func(r *Request)) { func TestRetryInterval(t *testing.T) { testRetry(t, func(r *Request) { - r.SetRetryInterval(func(attempt int) time.Duration { + r.SetRetryInterval(func(resp *Response, attempt int) time.Duration { sleep := 0.01 * math.Exp2(float64(attempt)) return time.Duration(math.Min(2, sleep)) * time.Second }) From fc792d29ca107de9f7e59a345eb43c439490677f Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Mar 2022 20:11:03 +0800 Subject: [PATCH 428/843] update README: improve retry doc --- README.md | 77 +++++++++++++++++++++++++++++++++---------------------- 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 9fb4b036..b557acc4 100644 --- a/README.md +++ b/README.md @@ -886,48 +886,63 @@ You can enable retry for all requests at client-level (check the full list of cl ```bash client := req.C() - // enable retry and set the maximum retry count. - client.SetCommonRetryCount(3). - // set the retry sleep interval with a commonly used algorithm: capped exponential backoff with jitter (https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/). - SetCommonRetryBackoffInterval(1 * time.Second, 5 * time.Second). - // add a retry hook which will be executed before a retry. - AddCommonRetryHook(func(resp *req.Response, err error){ - req := resp.Request.RawRequest - fmt.Println("Retry request:", req.Method, req.URL) - }). - // add a retry condition which determines whether the request should retry. - AddCommonRetryCondition(func(resp *req.Response, err error) bool { - return err != nil - }).AddCommonRetryCondition(func(resp *req.Response, err error) bool { // add another retry condition - return resp.StatusCode == http.StatusTooManyRequests - }) + +// Enable retry and set the maximum retry count. +client.SetCommonRetryCount(3). + +// Set the retry sleep interval with a commonly used algorithm: capped exponential backoff with jitter (https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/). +client.SetCommonRetryBackoffInterval(1 * time.Second, 5 * time.Second) + +// Set the retry to sleep fixed interval of 2 seconds. +client.SetCommonRetryFixedInterval(2 * time.Seconds) + +// Set the retry to use a custom retry interval algorithm. +client.SetCommonRetryFixedInterval(func(resp *req.Response, attempt int) time.Duration { + // Sleep seconds from "Retry-After" response header if it is present and correct (https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html). + if resp.Response != nil { + ra := resp.Header.Get("Retry-After") + if ra != "" { + after, err := strconv.Atoi(ra) + if err == nil { + return time.Duration(after) * time.Second + } + } + } + return 2 * time.Second // Otherwise sleep 2 seconds +}) + +// Add a retry hook which will be executed before a retry. +client.AddCommonRetryHook(func(resp *req.Response, err error){ + req := resp.Request.RawRequest + fmt.Println("Retry request:", req.Method, req.URL) +}) + +// Add a retry condition which determines whether the request should retry. +client.AddCommonRetryCondition(func(resp *req.Response, err error) bool { + return err != nil +}) + +// Add another retry condition +client.AddCommonRetryCondition(func(resp *req.Response, err error) bool { + return resp.StatusCode == http.StatusTooManyRequests +}) ``` You can also override retry settings at request-level (check the full list of request-level retry settings around [here](./docs/api.md#Retry-Request)): ```bash client.R(). - SetRetryCount(3). - SetRetryInterval(func(resp *req.Response, attempt int) time.Duration { // use a custom retry interval algorithm. - // sleep seconds from "Retry-After" response header if given and correct, otherwise sleep 2 seconds (https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html). - if resp.Response != nil { - ra := resp.Header.Get("Retry-After") - if ra != "" { - after, err := strconv.Atoi(ra) - if err == nil { - return time.Duration(after) * time.Second - } - } - } - return 2 * time.Second - }).SetRetryHook(hook). // unlike add, set will remove all other retry hooks which is added before at both request and client level. - SetRetryCondition(condition) // similarly, this will remove all other retry conditions which is added before at both request and client level. + SetRetryCount(2). + SetRetryInterval(intervalFunc). + SetRetryHook(hookFunc1). // Unlike add, set will remove all other retry hooks which is added before at both request and client level. + AddRetryHook(hookFunc2). + SetRetryCondition(conditionFunc1). // Similarly, this will remove all other retry conditions which is added before at both request and client level. + AddRetryCondition(conditionFunc2) ``` ## TODO List * [ ] Wrap more transport settings into client. -* [ ] Support retry. * [ ] Support unix socket. * [ ] Support h2c. * [ ] Design a logo. From dfee29ece77a7d3740f22c4ac744b80e03933a44 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Mar 2022 20:12:48 +0800 Subject: [PATCH 429/843] update README: fix typo --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b557acc4..eaff2e7e 100644 --- a/README.md +++ b/README.md @@ -884,7 +884,7 @@ client.SetProxy(nil) You can enable retry for all requests at client-level (check the full list of client-level retry settings around [here](./docs/api.md#Retry-Client)): -```bash +```go client := req.C() // Enable retry and set the maximum retry count. @@ -897,7 +897,7 @@ client.SetCommonRetryBackoffInterval(1 * time.Second, 5 * time.Second) client.SetCommonRetryFixedInterval(2 * time.Seconds) // Set the retry to use a custom retry interval algorithm. -client.SetCommonRetryFixedInterval(func(resp *req.Response, attempt int) time.Duration { +client.SetCommonRetryInterval(func(resp *req.Response, attempt int) time.Duration { // Sleep seconds from "Retry-After" response header if it is present and correct (https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html). if resp.Response != nil { ra := resp.Header.Get("Retry-After") @@ -930,7 +930,7 @@ client.AddCommonRetryCondition(func(resp *req.Response, err error) bool { You can also override retry settings at request-level (check the full list of request-level retry settings around [here](./docs/api.md#Retry-Request)): -```bash +```go client.R(). SetRetryCount(2). SetRetryInterval(intervalFunc). From 4600224cd63d7cf5eca338ff605374f82beac96d Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Mar 2022 20:14:00 +0800 Subject: [PATCH 430/843] update README: fix blank --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index eaff2e7e..b4b24d1a 100644 --- a/README.md +++ b/README.md @@ -913,18 +913,18 @@ client.SetCommonRetryInterval(func(resp *req.Response, attempt int) time.Duratio // Add a retry hook which will be executed before a retry. client.AddCommonRetryHook(func(resp *req.Response, err error){ - req := resp.Request.RawRequest - fmt.Println("Retry request:", req.Method, req.URL) + req := resp.Request.RawRequest + fmt.Println("Retry request:", req.Method, req.URL) }) // Add a retry condition which determines whether the request should retry. client.AddCommonRetryCondition(func(resp *req.Response, err error) bool { - return err != nil + return err != nil }) // Add another retry condition client.AddCommonRetryCondition(func(resp *req.Response, err error) bool { - return resp.StatusCode == http.StatusTooManyRequests + return resp.StatusCode == http.StatusTooManyRequests }) ``` From 6938dfd0072bcbf710783165a1d050dee410a4c1 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Mar 2022 21:21:33 +0800 Subject: [PATCH 431/843] support unix socket --- README.md | 11 +++++++++++ client.go | 16 ++++++++++++++++ docs/api.md | 1 + middleware.go | 4 ++++ req_test.go | 1 + 5 files changed, 33 insertions(+) diff --git a/README.md b/README.md index b4b24d1a..31ef62f2 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ If you want to use the older version, check it out on [v1 branch](https://github * [Request and Response Middleware](#Middleware) * [Redirect Policy](#Redirect) * [Proxy](#Proxy) +* [Unix Socket](#Unix) * [Retry](#Retry) * [TODO List](#TODO) * [License](#License) @@ -880,6 +881,16 @@ client.SetProxy(func(request *http.Request) (*url.URL, error) { client.SetProxy(nil) ``` +## Unix Socket + +```go +client := req.C() +client.SetUnixSocket("/var/run/custom.sock") +client.SetBaseURL("http://example.local") + +resp, err := client.R().Get("/index.html") +``` + ## Retry You can enable retry for all requests at client-level (check the full list of client-level retry settings around [here](./docs/api.md#Retry-Client)): diff --git a/client.go b/client.go index 0ef306d3..c1f94946 100644 --- a/client.go +++ b/client.go @@ -1337,6 +1337,22 @@ func (c *Client) AddCommonRetryCondition(condition RetryConditionFunc) *Client { return c } +// SetUnixSocket is a global wrapper methods which delegated +// to the default client, create a request and SetUnixSocket for request. +func SetUnixSocket(file string) *Client { + return defaultClient.SetUnixSocket(file) +} + +// SetUnixSocket set client to dial connection use unix socket. +// For example: +// client.SetUnixSocket("/var/run/custom.sock") +func (c *Client) SetUnixSocket(file string) *Client { + return c.SetDial(func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", file) + }) +} + // NewClient is the alias of C func NewClient() *Client { return C() diff --git a/docs/api.md b/docs/api.md index e424f298..65861129 100644 --- a/docs/api.md +++ b/docs/api.md @@ -139,6 +139,7 @@ Basically, you can know the meaning of most settings directly from the method na * [DisableAutoReadResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoReadResponse) - Enabled by default * [EnableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAllowGetMethodPayload) - Disabled by default. * [DisableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAllowGetMethodPayload) +* [SetUnixSocket(file string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetUnixSocket) ## Request Settings diff --git a/middleware.go b/middleware.go index b542313f..7976de88 100644 --- a/middleware.go +++ b/middleware.go @@ -306,6 +306,10 @@ func parseRequestURL(c *Client, r *Request) error { if reqURL.Scheme == "" && len(c.scheme) > 0 { reqURL.Scheme = c.scheme + reqURL, err = url.Parse(reqURL.String()) // prevent empty URL.Host + if err != nil { + return err + } } // Adding Query Param diff --git a/req_test.go b/req_test.go index 4790144f..01c938d6 100644 --- a/req_test.go +++ b/req_test.go @@ -724,6 +724,7 @@ func TestGlobalWrapper(t *testing.T) { }), SetCommonRetryBackoffInterval(1*time.Millisecond, 2*time.Second), SetCommonRetryFixedInterval(1*time.Second), + SetUnixSocket("/var/run/custom.sock"), ) os.Remove(tests.GetTestFilePath("tmpdump.out")) From e2800d6d57aa6577479031a32b34c5481b9041a9 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 7 Mar 2022 21:23:45 +0800 Subject: [PATCH 432/843] update README: update TODO --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 31ef62f2..c68beeee 100644 --- a/README.md +++ b/README.md @@ -954,7 +954,6 @@ client.R(). ## TODO List * [ ] Wrap more transport settings into client. -* [ ] Support unix socket. * [ ] Support h2c. * [ ] Design a logo. * [ ] Support HTTP3. From 214b4955a5d77c4e1849b317b12f369b3ebe3b61 Mon Sep 17 00:00:00 2001 From: Fufu Date: Tue, 8 Mar 2022 16:09:09 +0800 Subject: [PATCH 433/843] fix miss executing ResponseMiddlewares --- client.go | 2 +- retry_test.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index c1f94946..18894321 100644 --- a/client.go +++ b/client.go @@ -1541,7 +1541,7 @@ func (c *Client) do(r *Request) (resp *Response, err error) { } } if !needRetry { // no retry is needed. - return + break // jump out to execute the ResponseMiddlewares. } // need retry, attempt to retry diff --git a/retry_test.go b/retry_test.go index 7181d125..9714529b 100644 --- a/retry_test.go +++ b/retry_test.go @@ -127,3 +127,16 @@ func TestRetryWithUnreplayableBody(t *testing.T) { Post("/") assertEqual(t, errRetryableWithUnReplayableBody, err) } + +func TestRetryWithSetResult(t *testing.T) { + headers := make(http.Header) + resp, err := tc().SetCommonCookies(&http.Cookie{ + Name: "test", + Value: "test", + }).R(). + SetRetryCount(1). + SetResult(&headers). + Get("/header") + assertSuccess(t, resp, err) + assertEqual(t, "test=test", headers.Get("Cookie")) +} From 71ce788cb7cc87b113f59efec9b2b1b7f1c1ba49 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Mar 2022 19:49:30 +0800 Subject: [PATCH 434/843] extract assertXXX into tests package --- client_test.go | 148 ++++++++++----------- decode_test.go | 13 +- h2_frame_test.go | 10 +- h2_test.go | 17 +-- h2_transport_test.go | 2 +- internal/tests/{tests.go => assert.go} | 39 ++++++ req_test.go | 101 ++++++--------- request_test.go | 172 ++++++++++++------------- response_test.go | 25 ++-- retry_test.go | 24 ++-- 10 files changed, 286 insertions(+), 265 deletions(-) rename internal/tests/{tests.go => assert.go} (52%) diff --git a/client_test.go b/client_test.go index 3e869a96..36f47366 100644 --- a/client_test.go +++ b/client_test.go @@ -20,23 +20,23 @@ func TestAllowGetMethodPayload(t *testing.T) { c := tc() resp, err := c.R().SetBody("test").Get("/payload") assertSuccess(t, resp, err) - assertEqual(t, "", resp.String()) + tests.AssertEqual(t, "", resp.String()) c.EnableAllowGetMethodPayload() resp, err = c.R().SetBody("test").Get("/payload") assertSuccess(t, resp, err) - assertEqual(t, "test", resp.String()) + tests.AssertEqual(t, "test", resp.String()) c.DisableAllowGetMethodPayload() resp, err = c.R().SetBody("test").Get("/payload") assertSuccess(t, resp, err) - assertEqual(t, "", resp.String()) + tests.AssertEqual(t, "", resp.String()) } func TestSetTLSHandshakeTimeout(t *testing.T) { timeout := 2 * time.Second c := tc().SetTLSHandshakeTimeout(timeout) - assertEqual(t, timeout, c.t.TLSHandshakeTimeout) + tests.AssertEqual(t, timeout, c.t.TLSHandshakeTimeout) } func TestSetDial(t *testing.T) { @@ -46,7 +46,7 @@ func TestSetDial(t *testing.T) { } c := tc().SetDial(testDial) _, err := c.t.DialContext(nil, "", "") - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) } func TestSetDialTLS(t *testing.T) { @@ -56,7 +56,7 @@ func TestSetDialTLS(t *testing.T) { } c := tc().SetDialTLS(testDialTLS) _, err := c.t.DialTLSContext(nil, "", "") - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) } func TestSetFuncs(t *testing.T) { @@ -74,31 +74,31 @@ func TestSetFuncs(t *testing.T) { SetXmlUnmarshal(unmarshalFunc) _, err := c.jsonMarshal(nil) - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) err = c.jsonUnmarshal(nil, nil) - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) _, err = c.xmlMarshal(nil) - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) err = c.xmlUnmarshal(nil, nil) - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) } func TestSetCookieJar(t *testing.T) { c := tc().SetCookieJar(nil) - assertEqual(t, nil, c.httpClient.Jar) + tests.AssertEqual(t, nil, c.httpClient.Jar) } func TestTraceAll(t *testing.T) { c := tc().EnableTraceAll() resp, err := c.R().Get("/") assertSuccess(t, resp, err) - assertEqual(t, true, resp.TraceInfo().TotalTime > 0) + tests.AssertEqual(t, true, resp.TraceInfo().TotalTime > 0) c.DisableTraceAll() resp, err = c.R().Get("/") assertSuccess(t, resp, err) - assertEqual(t, true, resp.TraceInfo().TotalTime == 0) + tests.AssertEqual(t, true, resp.TraceInfo().TotalTime == 0) } func TestOnAfterResponse(t *testing.T) { @@ -108,21 +108,21 @@ func TestOnAfterResponse(t *testing.T) { return nil }) len2 := len(c.afterResponse) - assertEqual(t, true, len1+1 == len2) + tests.AssertEqual(t, true, len1+1 == len2) } func TestOnBeforeRequest(t *testing.T) { c := tc().OnBeforeRequest(func(client *Client, request *Request) error { return nil }) - assertEqual(t, true, len(c.udBeforeRequest) == 1) + tests.AssertEqual(t, true, len(c.udBeforeRequest) == 1) } func TestSetProxyURL(t *testing.T) { c := tc().SetProxyURL("http://dummy.proxy.local") u, err := c.t.Proxy(nil) - assertError(t, err) - assertEqual(t, "http://dummy.proxy.local", u.String()) + tests.AssertNoError(t, err) + tests.AssertEqual(t, "http://dummy.proxy.local", u.String()) } func TestSetProxy(t *testing.T) { @@ -130,18 +130,18 @@ func TestSetProxy(t *testing.T) { proxy := http.ProxyURL(u) c := tc().SetProxy(proxy) uu, err := c.t.Proxy(nil) - assertError(t, err) - assertEqual(t, u.String(), uu.String()) + tests.AssertNoError(t, err) + tests.AssertEqual(t, u.String(), uu.String()) } func TestSetCommonContentType(t *testing.T) { c := tc().SetCommonContentType(jsonContentType) - assertEqual(t, jsonContentType, c.Headers.Get(hdrContentTypeKey)) + tests.AssertEqual(t, jsonContentType, c.Headers.Get(hdrContentTypeKey)) } func TestSetCommonHeader(t *testing.T) { c := tc().SetCommonHeader("my-header", "my-value") - assertEqual(t, "my-value", c.Headers.Get("my-header")) + tests.AssertEqual(t, "my-value", c.Headers.Get("my-header")) } func TestSetCommonHeaders(t *testing.T) { @@ -149,41 +149,41 @@ func TestSetCommonHeaders(t *testing.T) { "header1": "value1", "header2": "value2", }) - assertEqual(t, "value1", c.Headers.Get("header1")) - assertEqual(t, "value2", c.Headers.Get("header2")) + tests.AssertEqual(t, "value1", c.Headers.Get("header1")) + tests.AssertEqual(t, "value2", c.Headers.Get("header2")) } func TestSetCommonBasicAuth(t *testing.T) { c := tc().SetCommonBasicAuth("imroc", "123456") - assertEqual(t, "Basic aW1yb2M6MTIzNDU2", c.Headers.Get("Authorization")) + tests.AssertEqual(t, "Basic aW1yb2M6MTIzNDU2", c.Headers.Get("Authorization")) } func TestSetCommonBearerAuthToken(t *testing.T) { c := tc().SetCommonBearerAuthToken("123456") - assertEqual(t, "Bearer 123456", c.Headers.Get("Authorization")) + tests.AssertEqual(t, "Bearer 123456", c.Headers.Get("Authorization")) } func TestSetUserAgent(t *testing.T) { c := tc().SetUserAgent("test") - assertEqual(t, "test", c.Headers.Get(hdrUserAgentKey)) + tests.AssertEqual(t, "test", c.Headers.Get(hdrUserAgentKey)) } func TestAutoDecode(t *testing.T) { c := tc().DisableAutoDecode() resp, err := c.R().Get("/gbk") assertSuccess(t, resp, err) - assertEqual(t, toGbk("我是roc"), resp.Bytes()) + tests.AssertEqual(t, toGbk("我是roc"), resp.Bytes()) resp, err = c.EnableAutoDecode().R().Get("/gbk") assertSuccess(t, resp, err) - assertEqual(t, "我是roc", resp.String()) + tests.AssertEqual(t, "我是roc", resp.String()) resp, err = c.SetAutoDecodeContentType("html").R().Get("/gbk") assertSuccess(t, resp, err) - assertEqual(t, toGbk("我是roc"), resp.Bytes()) + tests.AssertEqual(t, toGbk("我是roc"), resp.Bytes()) resp, err = c.SetAutoDecodeContentType("text").R().Get("/gbk") assertSuccess(t, resp, err) - assertEqual(t, "我是roc", resp.String()) + tests.AssertEqual(t, "我是roc", resp.String()) resp, err = c.SetAutoDecodeContentTypeFunc(func(contentType string) bool { if strings.Contains(contentType, "text") { return true @@ -191,7 +191,7 @@ func TestAutoDecode(t *testing.T) { return false }).R().Get("/gbk") assertSuccess(t, resp, err) - assertEqual(t, "我是roc", resp.String()) + tests.AssertEqual(t, "我是roc", resp.String()) resp, err = c.SetAutoDecodeAllContentType().R().Get("/gbk-no-charset") assertSuccess(t, resp, err) @@ -201,29 +201,29 @@ func TestAutoDecode(t *testing.T) { func TestSetTimeout(t *testing.T) { timeout := 100 * time.Second c := tc().SetTimeout(timeout) - assertEqual(t, timeout, c.httpClient.Timeout) + tests.AssertEqual(t, timeout, c.httpClient.Timeout) } func TestSetLogger(t *testing.T) { l := createDefaultLogger() c := tc().SetLogger(l) - assertEqual(t, l, c.log) + tests.AssertEqual(t, l, c.log) c.SetLogger(nil) - assertEqual(t, &disableLogger{}, c.log) + tests.AssertEqual(t, &disableLogger{}, c.log) } func TestSetScheme(t *testing.T) { c := tc().SetScheme("https") - assertEqual(t, "https", c.scheme) + tests.AssertEqual(t, "https", c.scheme) } func TestDebugLog(t *testing.T) { c := tc().EnableDebugLog() - assertEqual(t, true, c.DebugLog) + tests.AssertEqual(t, true, c.DebugLog) c.DisableDebugLog() - assertEqual(t, false, c.DebugLog) + tests.AssertEqual(t, false, c.DebugLog) } func TestSetCommonCookies(t *testing.T) { @@ -233,25 +233,25 @@ func TestSetCommonCookies(t *testing.T) { Value: "test", }).R().SetResult(&headers).Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "test=test", headers.Get("Cookie")) + tests.AssertEqual(t, "test=test", headers.Get("Cookie")) } func TestSetCommonQueryString(t *testing.T) { resp, err := tc().SetCommonQueryString("test=test").R().Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "test=test", resp.String()) + tests.AssertEqual(t, "test=test", resp.String()) } func TestSetCommonPathParams(t *testing.T) { c := tc().SetCommonPathParams(map[string]string{"test": "test"}) - assertNotNil(t, c.PathParams) - assertEqual(t, "test", c.PathParams["test"]) + tests.AssertNotNil(t, c.PathParams) + tests.AssertEqual(t, "test", c.PathParams["test"]) } func TestSetCommonPathParam(t *testing.T) { c := tc().SetCommonPathParam("test", "test") - assertNotNil(t, c.PathParams) - assertEqual(t, "test", c.PathParams["test"]) + tests.AssertNotNil(t, c.PathParams) + tests.AssertEqual(t, "test", c.PathParams["test"]) } func TestAddCommonQueryParam(t *testing.T) { @@ -260,97 +260,97 @@ func TestAddCommonQueryParam(t *testing.T) { AddCommonQueryParam("test", "2"). R().Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "test=1&test=2", resp.String()) + tests.AssertEqual(t, "test=1&test=2", resp.String()) } func TestSetCommonQueryParam(t *testing.T) { resp, err := tc().SetCommonQueryParam("test", "test").R().Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "test=test", resp.String()) + tests.AssertEqual(t, "test=test", resp.String()) } func TestSetCommonQueryParams(t *testing.T) { resp, err := tc().SetCommonQueryParams(map[string]string{"test": "test"}).R().Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "test=test", resp.String()) + tests.AssertEqual(t, "test=test", resp.String()) } func TestInsecureSkipVerify(t *testing.T) { c := tc().EnableInsecureSkipVerify() - assertEqual(t, true, c.t.TLSClientConfig.InsecureSkipVerify) + tests.AssertEqual(t, true, c.t.TLSClientConfig.InsecureSkipVerify) c.DisableInsecureSkipVerify() - assertEqual(t, false, c.t.TLSClientConfig.InsecureSkipVerify) + tests.AssertEqual(t, false, c.t.TLSClientConfig.InsecureSkipVerify) } func TestSetTLSClientConfig(t *testing.T) { config := &tls.Config{InsecureSkipVerify: true} c := tc().SetTLSClientConfig(config) - assertEqual(t, config, c.t.TLSClientConfig) + tests.AssertEqual(t, config, c.t.TLSClientConfig) } func TestCompression(t *testing.T) { c := tc().DisableCompression() - assertEqual(t, true, c.t.DisableCompression) + tests.AssertEqual(t, true, c.t.DisableCompression) c.EnableCompression() - assertEqual(t, false, c.t.DisableCompression) + tests.AssertEqual(t, false, c.t.DisableCompression) } func TestKeepAlives(t *testing.T) { c := tc().DisableKeepAlives() - assertEqual(t, true, c.t.DisableKeepAlives) + tests.AssertEqual(t, true, c.t.DisableKeepAlives) c.EnableKeepAlives() - assertEqual(t, false, c.t.DisableKeepAlives) + tests.AssertEqual(t, false, c.t.DisableKeepAlives) } func TestRedirect(t *testing.T) { _, err := tc().SetRedirectPolicy(NoRedirectPolicy()).R().Get("/unlimited-redirect") - assertNotNil(t, err) + tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "redirect is disabled", true) _, err = tc().SetRedirectPolicy(MaxRedirectPolicy(3)).R().Get("/unlimited-redirect") - assertNotNil(t, err) + tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "stopped after 3 redirects", true) _, err = tc().SetRedirectPolicy(SameDomainRedirectPolicy()).R().Get("/redirect-to-other") - assertNotNil(t, err) + tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "different domain name is not allowed", true) _, err = tc().SetRedirectPolicy(SameHostRedirectPolicy()).R().Get("/redirect-to-other") - assertNotNil(t, err) + tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "different host name is not allowed", true) _, err = tc().SetRedirectPolicy(AllowedHostRedirectPolicy("localhost", "127.0.0.1")).R().Get("/redirect-to-other") - assertNotNil(t, err) + tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "redirect host [dummy.local] is not allowed", true) _, err = tc().SetRedirectPolicy(AllowedDomainRedirectPolicy("localhost", "127.0.0.1")).R().Get("/redirect-to-other") - assertNotNil(t, err) + tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "redirect domain [dummy.local] is not allowed", true) } func TestGetTLSClientConfig(t *testing.T) { c := tc() config := c.GetTLSClientConfig() - assertEqual(t, true, c.t.TLSClientConfig != nil) - assertEqual(t, config, c.t.TLSClientConfig) + tests.AssertEqual(t, true, c.t.TLSClientConfig != nil) + tests.AssertEqual(t, config, c.t.TLSClientConfig) } func TestSetRootCertFromFile(t *testing.T) { c := tc().SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")) - assertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) + tests.AssertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) } func TestSetRootCertFromString(t *testing.T) { c := tc().SetRootCertFromString(string(tests.GetTestFileContent(t, "sample-root.pem"))) - assertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) + tests.AssertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) } func TestSetCerts(t *testing.T) { c := tc().SetCerts(tls.Certificate{}, tls.Certificate{}) - assertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 2) + tests.AssertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 2) } func TestSetCertFromFile(t *testing.T) { @@ -358,7 +358,7 @@ func TestSetCertFromFile(t *testing.T) { tests.GetTestFilePath("sample-client.pem"), tests.GetTestFilePath("sample-client-key.pem"), ) - assertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 1) + tests.AssertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 1) } func TestSetOutputDirectory(t *testing.T) { @@ -370,13 +370,13 @@ func TestSetOutputDirectory(t *testing.T) { assertSuccess(t, resp, err) content := string(tests.GetTestFileContent(t, outFile)) os.Remove(tests.GetTestFilePath(outFile)) - assertEqual(t, "TestGet: text response", content) + tests.AssertEqual(t, "TestGet: text response", content) } func TestSetBaseURL(t *testing.T) { baseURL := "http://dummy-req.local/test" resp, _ := tc().SetTimeout(time.Nanosecond).SetBaseURL(baseURL).R().Get("/req") - assertEqual(t, baseURL+"/req", resp.Request.RawRequest.URL.String()) + tests.AssertEqual(t, baseURL+"/req", resp.Request.RawRequest.URL.String()) } func TestSetCommonFormDataFromValues(t *testing.T) { @@ -388,7 +388,7 @@ func TestSetCommonFormDataFromValues(t *testing.T) { R().SetResult(&gotForm). Post("/form") assertSuccess(t, resp, err) - assertEqual(t, "test", gotForm.Get("test")) + tests.AssertEqual(t, "test", gotForm.Get("test")) } func TestSetCommonFormData(t *testing.T) { @@ -401,7 +401,7 @@ func TestSetCommonFormData(t *testing.T) { SetResult(&form). Post("/form") assertSuccess(t, resp, err) - assertEqual(t, "test", form.Get("test")) + tests.AssertEqual(t, "test", form.Get("test")) } func TestClientClone(t *testing.T) { @@ -426,15 +426,15 @@ func testDisableAutoReadResponse(t *testing.T, c *Client) { c.DisableAutoReadResponse() resp, err := c.R().Get("/") assertSuccess(t, resp, err) - assertEqual(t, "", resp.String()) + tests.AssertEqual(t, "", resp.String()) result, err := resp.ToString() - assertError(t, err) - assertEqual(t, "TestGet: text response", result) + tests.AssertNoError(t, err) + tests.AssertEqual(t, "TestGet: text response", result) resp, err = c.R().Get("/") assertSuccess(t, resp, err) _, err = ioutil.ReadAll(resp.Body) - assertError(t, err) + tests.AssertNoError(t, err) } func TestEnableDumpAll(t *testing.T) { @@ -564,5 +564,5 @@ func TestEnableDumpAllAsync(t *testing.T) { c := tc() buf := new(bytes.Buffer) c.EnableDumpAllTo(buf).EnableDumpAllAsync() - assertEqual(t, true, c.getDumpOptions().Async) + tests.AssertEqual(t, true, c.getDumpOptions().Async) } diff --git a/decode_test.go b/decode_test.go index 209958ce..e65a8ea6 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1,14 +1,17 @@ package req -import "testing" +import ( + "github.com/imroc/req/v3/internal/tests" + "testing" +) func TestPeekDrain(t *testing.T) { a := autoDecodeReadCloser{peek: []byte("test")} p := make([]byte, 2) n, _ := a.peekDrain(p) - assertEqual(t, 2, n) - assertEqual(t, true, a.peek != nil) + tests.AssertEqual(t, 2, n) + tests.AssertEqual(t, true, a.peek != nil) n, _ = a.peekDrain(p) - assertEqual(t, 2, n) - assertEqual(t, true, a.peek == nil) + tests.AssertEqual(t, 2, n) + tests.AssertEqual(t, true, a.peek == nil) } diff --git a/h2_frame_test.go b/h2_frame_test.go index 36874ff8..136a9bd7 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1397,7 +1397,7 @@ func TestParseUnknownFrame(t *testing.T) { if !ok { t.Fatalf("not http2UnknownFrame type: %#+v", f) } - assertEqual(t, p, uf.Payload()) + tests.AssertEqual(t, p, uf.Payload()) } func TestParseRSTStreamFrame(t *testing.T) { @@ -1451,16 +1451,16 @@ func TestPushPromiseFrame(t *testing.T) { fh := http2FrameHeader{valid: true} buf := []byte("test") f := &http2PushPromiseFrame{http2FrameHeader: fh, headerFragBuf: buf} - assertEqual(t, buf, f.HeaderBlockFragment()) - assertEqual(t, false, f.HeadersEnded()) + tests.AssertEqual(t, buf, f.HeaderBlockFragment()) + tests.AssertEqual(t, false, f.HeadersEnded()) } func TestH2Framer(t *testing.T) { f := &http2Framer{} f.debugWriteLoggerf = func(s string, i ...interface{}) {} f.logWrite() - assertNotNil(t, f.debugFramer) - assertNil(t, f.ErrorDetail()) + tests.AssertNotNil(t, f.debugFramer) + tests.AssertIsNil(t, f.ErrorDetail()) f.w = new(bytes.Buffer) err := f.WriteRawFrame(http2FrameData, http2FlagDataEndStream, 1, nil) diff --git a/h2_test.go b/h2_test.go index f96928ac..d1249f1f 100644 --- a/h2_test.go +++ b/h2_test.go @@ -7,6 +7,7 @@ package req import ( "flag" "fmt" + "github.com/imroc/req/v3/internal/tests" "net/http" "testing" "time" @@ -98,21 +99,21 @@ func TestSettingValid(t *testing.T) { } for _, c := range cases { s := &http2Setting{ID: c.id, Val: c.val} - assertEqual(t, true, s.Valid() != nil) + tests.AssertEqual(t, true, s.Valid() != nil) } s := &http2Setting{ID: http2SettingMaxHeaderListSize} - assertEqual(t, true, s.Valid() == nil) + tests.AssertEqual(t, true, s.Valid() == nil) } func TestBodyAllowedForStatus(t *testing.T) { - assertEqual(t, false, http2bodyAllowedForStatus(101)) - assertEqual(t, false, http2bodyAllowedForStatus(204)) - assertEqual(t, false, http2bodyAllowedForStatus(304)) - assertEqual(t, true, http2bodyAllowedForStatus(900)) + tests.AssertEqual(t, false, http2bodyAllowedForStatus(101)) + tests.AssertEqual(t, false, http2bodyAllowedForStatus(204)) + tests.AssertEqual(t, false, http2bodyAllowedForStatus(304)) + tests.AssertEqual(t, true, http2bodyAllowedForStatus(900)) } func TestHttpError(t *testing.T) { e := &http2httpError{msg: "test"} - assertEqual(t, "test", e.Error()) - assertEqual(t, true, e.Temporary()) + tests.AssertEqual(t, "test", e.Error()) + tests.AssertEqual(t, true, e.Temporary()) } diff --git a/h2_transport_test.go b/h2_transport_test.go index 7e2db8a7..7eaabe25 100644 --- a/h2_transport_test.go +++ b/h2_transport_test.go @@ -5865,7 +5865,7 @@ func TestCountReadFrameError(t *testing.T) { var err error cc.countReadFrameError(err) - assertEqual(t, "", errMsg) + tests.AssertEqual(t, "", errMsg) err = http2ConnectionError(http2ErrCodeInternal) cc.countReadFrameError(err) diff --git a/internal/tests/tests.go b/internal/tests/assert.go similarity index 52% rename from internal/tests/tests.go rename to internal/tests/assert.go index 0a7d8062..9d41e2b6 100644 --- a/internal/tests/tests.go +++ b/internal/tests/assert.go @@ -1,10 +1,33 @@ package tests import ( + "reflect" "strings" "testing" ) +// AssertIsNil asserts is nil. +func AssertIsNil(t *testing.T, v interface{}) { + if !isNil(v) { + t.Errorf("[%v] was expected to be nil", v) + } +} + +// AssertNotNil asserts is not nil. +func AssertNotNil(t *testing.T, v interface{}) { + if isNil(v) { + t.Fatalf("[%v] was expected to be non-nil", v) + } +} + +// AssertEqual asserts e (expected) is equal with g (got). +func AssertEqual(t *testing.T, e, g interface{}) { + if !equal(e, g) { + t.Errorf("Expected [%+v], got [%+v]", e, g) + } + return +} + // AssertNoError asserts no error. func AssertNoError(t *testing.T, err error) { if err != nil { @@ -37,3 +60,19 @@ func AssertContains(t *testing.T, s, substr string, shouldContain bool) { } } } + +func equal(expected, got interface{}) bool { + return reflect.DeepEqual(expected, got) +} + +func isNil(v interface{}) bool { + if v == nil { + return true + } + rv := reflect.ValueOf(v) + kind := rv.Kind() + if kind >= reflect.Chan && kind <= reflect.Slice && rv.IsNil() { + return true + } + return false +} diff --git a/req_test.go b/req_test.go index 01c938d6..91a37992 100644 --- a/req_test.go +++ b/req_test.go @@ -262,51 +262,33 @@ func handleGet(w http.ResponseWriter, r *http.Request) { } func assertStatus(t *testing.T, resp *Response, err error, statusCode int, status string) { - assertError(t, err) - assertNotNil(t, resp) - assertNotNil(t, resp.Body) - assertEqual(t, statusCode, resp.StatusCode) - assertEqual(t, status, resp.Status) + tests.AssertNoError(t, err) + tests.AssertNotNil(t, resp) + tests.AssertNotNil(t, resp.Body) + tests.AssertEqual(t, statusCode, resp.StatusCode) + tests.AssertEqual(t, status, resp.Status) } func assertSuccess(t *testing.T, resp *Response, err error) { - assertError(t, err) - assertNotNil(t, resp.Response) - assertNotNil(t, resp.Response.Body) - assertEqual(t, http.StatusOK, resp.StatusCode) - assertEqual(t, "200 OK", resp.Status) + tests.AssertNoError(t, err) + tests.AssertNotNil(t, resp.Response) + tests.AssertNotNil(t, resp.Response.Body) + tests.AssertEqual(t, http.StatusOK, resp.StatusCode) + tests.AssertEqual(t, "200 OK", resp.Status) if !resp.IsSuccess() { t.Error("Response.IsSuccess should return true") } } func assertIsError(t *testing.T, resp *Response, err error) { - assertError(t, err) - assertNotNil(t, resp) - assertNotNil(t, resp.Body) + tests.AssertNoError(t, err) + tests.AssertNotNil(t, resp) + tests.AssertNotNil(t, resp.Body) if !resp.IsError() { t.Error("Response.IsError should return true") } } -func assertNil(t *testing.T, v interface{}) { - if !isNil(v) { - t.Errorf("[%v] was expected to be nil", v) - } -} - -func assertNotNil(t *testing.T, v interface{}) { - if isNil(v) { - t.Fatalf("[%v] was expected to be non-nil", v) - } -} - -func assertError(t *testing.T, err error) { - if err != nil { - t.Errorf("Error occurred [%v]", err) - } -} - func assertEqualStruct(t *testing.T, e, g interface{}, onlyExported bool, excludes ...string) { ev := reflect.ValueOf(e).Elem() gv := reflect.ValueOf(g).Elem() @@ -355,13 +337,6 @@ func assertEqualStruct(t *testing.T, e, g interface{}, onlyExported bool, exclud } -func assertEqual(t *testing.T, e, g interface{}) { - if !equal(e, g) { - t.Errorf("Expected [%+v], got [%+v]", e, g) - } - return -} - func assertNotEqual(t *testing.T, e, g interface{}) (r bool) { if equal(e, g) { t.Errorf("Expected [%v], got [%v]", e, g) @@ -446,17 +421,17 @@ func testGlobalWrapperEnableDumps(t *testing.T) { buf := new(bytes.Buffer) r := EnableDumpTo(buf) - assertEqual(t, true, r.getDumpOptions().Output != nil) + tests.AssertEqual(t, true, r.getDumpOptions().Output != nil) dumpFile := tests.GetTestFilePath("req_tmp_dump.out") r = EnableDumpToFile(tests.GetTestFilePath(dumpFile)) - assertEqual(t, true, r.getDumpOptions().Output != nil) + tests.AssertEqual(t, true, r.getDumpOptions().Output != nil) os.Remove(dumpFile) r = SetDumpOptions(&DumpOptions{ RequestHeader: true, }) - assertEqual(t, true, r.getDumpOptions().RequestHeader) + tests.AssertEqual(t, true, r.getDumpOptions().RequestHeader) } func testGlobalWrapperEnableDump(t *testing.T, fn func(reqHeader, reqBody, respHeader, respBody *bool) *Request) { @@ -485,50 +460,50 @@ func testGlobalWrapperSendRequest(t *testing.T) { resp, err := Put(testURL) assertSuccess(t, resp, err) - assertEqual(t, "PUT", resp.Header.Get("Method")) + tests.AssertEqual(t, "PUT", resp.Header.Get("Method")) resp = MustPut(testURL) - assertEqual(t, "PUT", resp.Header.Get("Method")) + tests.AssertEqual(t, "PUT", resp.Header.Get("Method")) resp, err = Patch(testURL) assertSuccess(t, resp, err) - assertEqual(t, "PATCH", resp.Header.Get("Method")) + tests.AssertEqual(t, "PATCH", resp.Header.Get("Method")) resp = MustPatch(testURL) - assertEqual(t, "PATCH", resp.Header.Get("Method")) + tests.AssertEqual(t, "PATCH", resp.Header.Get("Method")) resp, err = Delete(testURL) assertSuccess(t, resp, err) - assertEqual(t, "DELETE", resp.Header.Get("Method")) + tests.AssertEqual(t, "DELETE", resp.Header.Get("Method")) resp = MustDelete(testURL) - assertEqual(t, "DELETE", resp.Header.Get("Method")) + tests.AssertEqual(t, "DELETE", resp.Header.Get("Method")) resp, err = Options(testURL) assertSuccess(t, resp, err) - assertEqual(t, "OPTIONS", resp.Header.Get("Method")) + tests.AssertEqual(t, "OPTIONS", resp.Header.Get("Method")) resp = MustOptions(testURL) - assertEqual(t, "OPTIONS", resp.Header.Get("Method")) + tests.AssertEqual(t, "OPTIONS", resp.Header.Get("Method")) resp, err = Head(testURL) assertSuccess(t, resp, err) - assertEqual(t, "HEAD", resp.Header.Get("Method")) + tests.AssertEqual(t, "HEAD", resp.Header.Get("Method")) resp = MustHead(testURL) - assertEqual(t, "HEAD", resp.Header.Get("Method")) + tests.AssertEqual(t, "HEAD", resp.Header.Get("Method")) resp, err = Get(testURL) assertSuccess(t, resp, err) - assertEqual(t, "GET", resp.Header.Get("Method")) + tests.AssertEqual(t, "GET", resp.Header.Get("Method")) resp = MustGet(testURL) - assertEqual(t, "GET", resp.Header.Get("Method")) + tests.AssertEqual(t, "GET", resp.Header.Get("Method")) resp, err = Post(testURL) assertSuccess(t, resp, err) - assertEqual(t, "POST", resp.Header.Get("Method")) + tests.AssertEqual(t, "POST", resp.Header.Get("Method")) resp = MustPost(testURL) - assertEqual(t, "POST", resp.Header.Get("Method")) + tests.AssertEqual(t, "POST", resp.Header.Get("Method")) } func testGlobalWrapperSetRequest(t *testing.T, rs ...*Request) { for _, r := range rs { - assertNotNil(t, r) + tests.AssertNotNil(t, r) } } @@ -601,7 +576,7 @@ func TestGlobalWrapperSetRequest(t *testing.T) { func testGlobalClientSettingWrapper(t *testing.T, cs ...*Client) { for _, c := range cs { - assertNotNil(t, c) + tests.AssertNotNil(t, c) } } @@ -729,21 +704,21 @@ func TestGlobalWrapper(t *testing.T) { os.Remove(tests.GetTestFilePath("tmpdump.out")) config := GetTLSClientConfig() - assertEqual(t, config, DefaultClient().t.TLSClientConfig) + tests.AssertEqual(t, config, DefaultClient().t.TLSClientConfig) r := R() - assertEqual(t, true, r != nil) + tests.AssertEqual(t, true, r != nil) c := C() c.SetTimeout(10 * time.Second) SetDefaultClient(c) - assertEqual(t, true, DefaultClient().httpClient.Timeout == 10*time.Second) - assertEqual(t, GetClient(), DefaultClient().httpClient) + tests.AssertEqual(t, true, DefaultClient().httpClient.Timeout == 10*time.Second) + tests.AssertEqual(t, GetClient(), DefaultClient().httpClient) r = NewRequest() - assertEqual(t, true, r != nil) + tests.AssertEqual(t, true, r != nil) c = NewClient() - assertEqual(t, true, c != nil) + tests.AssertEqual(t, true, c != nil) } func TestTrailer(t *testing.T) { diff --git a/request_test.go b/request_test.go index 4b16661d..e82fc96a 100644 --- a/request_test.go +++ b/request_test.go @@ -24,45 +24,45 @@ func TestMethods(t *testing.T) { func testMethods(t *testing.T, c *Client) { resp, err := c.R().Put("/") assertSuccess(t, resp, err) - assertEqual(t, "PUT", resp.Header.Get("Method")) + tests.AssertEqual(t, "PUT", resp.Header.Get("Method")) resp = c.R().MustPut("/") - assertEqual(t, "PUT", resp.Header.Get("Method")) + tests.AssertEqual(t, "PUT", resp.Header.Get("Method")) resp, err = c.R().Patch("/") assertSuccess(t, resp, err) - assertEqual(t, "PATCH", resp.Header.Get("Method")) + tests.AssertEqual(t, "PATCH", resp.Header.Get("Method")) resp = c.R().MustPatch("/") - assertEqual(t, "PATCH", resp.Header.Get("Method")) + tests.AssertEqual(t, "PATCH", resp.Header.Get("Method")) resp, err = c.R().Delete("/") assertSuccess(t, resp, err) - assertEqual(t, "DELETE", resp.Header.Get("Method")) + tests.AssertEqual(t, "DELETE", resp.Header.Get("Method")) resp = c.R().MustDelete("/") - assertEqual(t, "DELETE", resp.Header.Get("Method")) + tests.AssertEqual(t, "DELETE", resp.Header.Get("Method")) resp, err = c.R().Options("/") assertSuccess(t, resp, err) - assertEqual(t, "OPTIONS", resp.Header.Get("Method")) + tests.AssertEqual(t, "OPTIONS", resp.Header.Get("Method")) resp = c.R().MustOptions("/") - assertEqual(t, "OPTIONS", resp.Header.Get("Method")) + tests.AssertEqual(t, "OPTIONS", resp.Header.Get("Method")) resp, err = c.R().Head("/") assertSuccess(t, resp, err) - assertEqual(t, "HEAD", resp.Header.Get("Method")) + tests.AssertEqual(t, "HEAD", resp.Header.Get("Method")) resp = c.R().MustHead("/") - assertEqual(t, "HEAD", resp.Header.Get("Method")) + tests.AssertEqual(t, "HEAD", resp.Header.Get("Method")) resp, err = c.R().Get("/") assertSuccess(t, resp, err) - assertEqual(t, "GET", resp.Header.Get("Method")) + tests.AssertEqual(t, "GET", resp.Header.Get("Method")) resp = c.R().MustGet("/") - assertEqual(t, "GET", resp.Header.Get("Method")) + tests.AssertEqual(t, "GET", resp.Header.Get("Method")) resp, err = c.R().Post("/") assertSuccess(t, resp, err) - assertEqual(t, "POST", resp.Header.Get("Method")) + tests.AssertEqual(t, "POST", resp.Header.Get("Method")) resp = c.R().MustPost("/") - assertEqual(t, "POST", resp.Header.Get("Method")) + tests.AssertEqual(t, "POST", resp.Header.Get("Method")) } func TestEnableDump(t *testing.T) { @@ -79,7 +79,7 @@ func TestEnableDumpToFIle(t *testing.T) { tmpFile := "tmp_dumpfile_req" resp, err := tc().R().EnableDumpToFile(tests.GetTestFilePath(tmpFile)).Get("/") assertSuccess(t, resp, err) - assertEqual(t, true, len(tests.GetTestFileContent(t, tmpFile)) > 0) + tests.AssertEqual(t, true, len(tests.GetTestFileContent(t, tmpFile)) > 0) os.Remove(tests.GetTestFilePath(tmpFile)) } @@ -189,7 +189,7 @@ func TestGet(t *testing.T) { func testGet(t *testing.T, c *Client) { resp, err := c.R().Get("/") assertSuccess(t, resp, err) - assertEqual(t, "TestGet: text response", resp.String()) + tests.AssertEqual(t, "TestGet: text response", resp.String()) } func TestBadRequest(t *testing.T) { @@ -216,16 +216,16 @@ func testSetBodyMarshal(t *testing.T, c *Client) { return func(e *echo) { var user User err := json.Unmarshal([]byte(e.Body), &user) - assertError(t, err) - assertEqual(t, username, user.Username) + tests.AssertNoError(t, err) + tests.AssertEqual(t, username, user.Username) } } assertUsernameXml := func(username string) func(e *echo) { return func(e *echo) { var user User err := xml.Unmarshal([]byte(e.Body), &user) - assertError(t, err) - assertEqual(t, username, user.Username) + tests.AssertNoError(t, err) + tests.AssertEqual(t, username, user.Username) } } testCases := []struct { @@ -298,8 +298,8 @@ func TestSetBodyReader(t *testing.T) { var e echo resp, err := tc().R().SetBody(ioutil.NopCloser(bytes.NewBufferString("hello"))).SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) - assertEqual(t, "hello", e.Body) + tests.AssertEqual(t, "", e.Header.Get(hdrContentTypeKey)) + tests.AssertEqual(t, "hello", e.Body) } func TestSetBodyGetContentFunc(t *testing.T) { @@ -308,8 +308,8 @@ func TestSetBodyGetContentFunc(t *testing.T) { return ioutil.NopCloser(bytes.NewBufferString("hello")), nil }).SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) - assertEqual(t, "hello", e.Body) + tests.AssertEqual(t, "", e.Header.Get(hdrContentTypeKey)) + tests.AssertEqual(t, "hello", e.Body) e = echo{} var fn GetContentFunc = func() (io.ReadCloser, error) { @@ -317,8 +317,8 @@ func TestSetBodyGetContentFunc(t *testing.T) { } resp, err = tc().R().SetBody(fn).SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) - assertEqual(t, "hello", e.Body) + tests.AssertEqual(t, "", e.Header.Get(hdrContentTypeKey)) + tests.AssertEqual(t, "hello", e.Body) } func TestSetBodyContent(t *testing.T) { @@ -351,8 +351,8 @@ func testSetBodyContent(t *testing.T, c *Client) { var e echo resp, err := r.SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - assertEqual(t, plainTextContentType, e.Header.Get(hdrContentTypeKey)) - assertEqual(t, testBody, e.Body) + tests.AssertEqual(t, plainTextContentType, e.Header.Get(hdrContentTypeKey)) + tests.AssertEqual(t, testBody, e.Body) } // Set Reader @@ -360,8 +360,8 @@ func testSetBodyContent(t *testing.T, c *Client) { e = echo{} resp, err := c.R().SetBody(testBodyReader).SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - assertEqual(t, testBody, e.Body) - assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) + tests.AssertEqual(t, testBody, e.Body) + tests.AssertEqual(t, "", e.Header.Get(hdrContentTypeKey)) } func TestCookie(t *testing.T) { @@ -382,7 +382,7 @@ func testCookie(t *testing.T, c *Client) { }, ).SetResult(&headers).Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "cookie1=value1; cookie2=value2", headers.Get("Cookie")) + tests.AssertEqual(t, "cookie1=value1; cookie2=value2", headers.Get("Cookie")) } func TestAuth(t *testing.T) { @@ -397,7 +397,7 @@ func testAuth(t *testing.T, c *Client) { SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "Basic aW1yb2M6MTIzNDU2", headers.Get("Authorization")) + tests.AssertEqual(t, "Basic aW1yb2M6MTIzNDU2", headers.Get("Authorization")) token := "NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4" headers = make(http.Header) @@ -406,7 +406,7 @@ func testAuth(t *testing.T, c *Client) { SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "Bearer "+token, headers.Get("Authorization")) + tests.AssertEqual(t, "Bearer "+token, headers.Get("Authorization")) } func TestHeader(t *testing.T) { @@ -419,7 +419,7 @@ func testHeader(t *testing.T, c *Client) { customUserAgent := "My Custom User Agent" resp, err := c.R().SetHeader(hdrUserAgentKey, customUserAgent).Get("/user-agent") assertSuccess(t, resp, err) - assertEqual(t, customUserAgent, resp.String()) + tests.AssertEqual(t, customUserAgent, resp.String()) // Set custom header headers := make(http.Header) @@ -431,9 +431,9 @@ func testHeader(t *testing.T, c *Client) { }).SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "value1", headers.Get("header1")) - assertEqual(t, "value2", headers.Get("header2")) - assertEqual(t, "value3", headers.Get("header3")) + tests.AssertEqual(t, "value1", headers.Get("header1")) + tests.AssertEqual(t, "value2", headers.Get("header2")) + tests.AssertEqual(t, "value3", headers.Get("header3")) } func TestQueryParam(t *testing.T) { @@ -458,14 +458,14 @@ func testQueryParam(t *testing.T, c *Client) { SetQueryParam("key3", "value3"). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) + tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryString resp, err = c.R(). SetQueryString("key1=value1&key2=value2&key3=value3"). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) + tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryParams resp, err = c.R(). @@ -476,7 +476,7 @@ func testQueryParam(t *testing.T, c *Client) { }). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) + tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryParam & SetQueryParams & SetQueryString resp, err = c.R(). @@ -488,7 +488,7 @@ func testQueryParam(t *testing.T, c *Client) { SetQueryString("key4=value4&key5=value5"). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=value4&key5=value5", resp.String()) + tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=value4&key5=value5", resp.String()) // Set same param to override resp, err = c.R(). @@ -503,7 +503,7 @@ func testQueryParam(t *testing.T, c *Client) { SetQueryParam("key4", "value44"). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value11&key2=value22&key3=value3&key4=value44&key5=value5", resp.String()) + tests.AssertEqual(t, "key1=value11&key2=value22&key3=value3&key4=value44&key5=value5", resp.String()) // Add same param without override resp, err = c.R(). @@ -518,7 +518,7 @@ func testQueryParam(t *testing.T, c *Client) { AddQueryParam("key4", "value44"). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5", resp.String()) + tests.AssertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5", resp.String()) } func TestPathParam(t *testing.T) { @@ -532,7 +532,7 @@ func testPathParam(t *testing.T, c *Client) { SetPathParam("username", username). Get("/user/{username}/profile") assertSuccess(t, resp, err) - assertEqual(t, fmt.Sprintf("%s's profile", username), resp.String()) + tests.AssertEqual(t, fmt.Sprintf("%s's profile", username), resp.String()) } func TestSuccess(t *testing.T) { @@ -547,7 +547,7 @@ func testSuccess(t *testing.T, c *Client) { SetResult(&userInfo). Get("/search") assertSuccess(t, resp, err) - assertEqual(t, "roc@imroc.cc", userInfo.Email) + tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) userInfo = UserInfo{} resp, err = c.R(). @@ -556,7 +556,7 @@ func testSuccess(t *testing.T, c *Client) { SetResult(&userInfo).EnableDump(). Get("/search") assertSuccess(t, resp, err) - assertEqual(t, "roc@imroc.cc", userInfo.Email) + tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) } func TestError(t *testing.T) { @@ -571,7 +571,7 @@ func testError(t *testing.T, c *Client) { SetError(&errMsg). Get("/search") assertIsError(t, resp, err) - assertEqual(t, 10000, errMsg.ErrorCode) + tests.AssertEqual(t, 10000, errMsg.ErrorCode) errMsg = ErrorMessage{} resp, err = c.R(). @@ -579,7 +579,7 @@ func testError(t *testing.T, c *Client) { SetError(&errMsg). Get("/search") assertIsError(t, resp, err) - assertEqual(t, 10001, errMsg.ErrorCode) + tests.AssertEqual(t, 10001, errMsg.ErrorCode) errMsg = ErrorMessage{} resp, err = c.R(). @@ -588,7 +588,7 @@ func testError(t *testing.T, c *Client) { SetError(&errMsg). Get("/search") assertIsError(t, resp, err) - assertEqual(t, 10001, errMsg.ErrorCode) + tests.AssertEqual(t, 10001, errMsg.ErrorCode) } func TestForm(t *testing.T) { @@ -606,7 +606,7 @@ func testForm(t *testing.T, c *Client) { SetResult(&userInfo). Post("/search") assertSuccess(t, resp, err) - assertEqual(t, "roc@imroc.cc", userInfo.Email) + tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) v := make(url.Values) v.Add("username", "imroc") @@ -616,7 +616,7 @@ func testForm(t *testing.T, c *Client) { SetResult(&userInfo). Post("/search") assertSuccess(t, resp, err) - assertEqual(t, "roc@imroc.cc", userInfo.Email) + tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) } func TestHostHeaderOverride(t *testing.T) { @@ -627,7 +627,7 @@ func TestHostHeaderOverride(t *testing.T) { func testHostHeaderOverride(t *testing.T, c *Client) { resp, err := c.R().SetHeader("Host", "testhostname").Get("/host-header") assertSuccess(t, resp, err) - assertEqual(t, "testhostname", resp.String()) + tests.AssertEqual(t, "testhostname", resp.String()) } func TestTraceInfo(t *testing.T) { @@ -644,7 +644,7 @@ func TestTraceInfo(t *testing.T) { ti = resp.TraceInfo() tests.AssertContains(t, ti.String(), "not enabled", false) tests.AssertContains(t, ti.Blame(), "not enabled", false) - assertEqual(t, true, resp.TotalTime() > 0) + tests.AssertEqual(t, true, resp.TotalTime() > 0) } func testTraceInfo(t *testing.T, c *Client) { @@ -653,28 +653,28 @@ func testTraceInfo(t *testing.T, c *Client) { resp, err := c.R().Get("/") assertSuccess(t, resp, err) ti := resp.TraceInfo() - assertEqual(t, true, ti.TotalTime > 0) - assertEqual(t, true, ti.TCPConnectTime > 0) - assertEqual(t, true, ti.TLSHandshakeTime > 0) - assertEqual(t, true, ti.ConnectTime > 0) - assertEqual(t, true, ti.FirstResponseTime > 0) - assertEqual(t, true, ti.ResponseTime > 0) - assertNotNil(t, ti.RemoteAddr) + tests.AssertEqual(t, true, ti.TotalTime > 0) + tests.AssertEqual(t, true, ti.TCPConnectTime > 0) + tests.AssertEqual(t, true, ti.TLSHandshakeTime > 0) + tests.AssertEqual(t, true, ti.ConnectTime > 0) + tests.AssertEqual(t, true, ti.FirstResponseTime > 0) + tests.AssertEqual(t, true, ti.ResponseTime > 0) + tests.AssertNotNil(t, ti.RemoteAddr) // disable trace at client level c.DisableTraceAll() resp, err = c.R().Get("/") assertSuccess(t, resp, err) ti = resp.TraceInfo() - assertEqual(t, false, ti.TotalTime > 0) - assertNil(t, ti.RemoteAddr) + tests.AssertEqual(t, false, ti.TotalTime > 0) + tests.AssertIsNil(t, ti.RemoteAddr) // enable trace at request level resp, err = c.R().EnableTrace().Get("/") assertSuccess(t, resp, err) ti = resp.TraceInfo() - assertEqual(t, true, ti.TotalTime > 0) - assertNotNil(t, ti.RemoteAddr) + tests.AssertEqual(t, true, ti.TotalTime > 0) + tests.AssertNotNil(t, ti.RemoteAddr) } func TestTraceOnTimeout(t *testing.T) { @@ -686,41 +686,41 @@ func testTraceOnTimeout(t *testing.T, c *Client) { c.EnableTraceAll().SetTimeout(100 * time.Millisecond) resp, err := c.R().Get("http://req-nowhere.local") - assertNotNil(t, err) - assertNotNil(t, resp) + tests.AssertNotNil(t, err) + tests.AssertNotNil(t, resp) tr := resp.TraceInfo() - assertEqual(t, true, tr.DNSLookupTime >= 0) - assertEqual(t, true, tr.ConnectTime == 0) - assertEqual(t, true, tr.TLSHandshakeTime == 0) - assertEqual(t, true, tr.TCPConnectTime == 0) - assertEqual(t, true, tr.FirstResponseTime == 0) - assertEqual(t, true, tr.ResponseTime == 0) - assertEqual(t, true, tr.TotalTime > 0) - assertEqual(t, true, tr.TotalTime == resp.TotalTime()) + tests.AssertEqual(t, true, tr.DNSLookupTime >= 0) + tests.AssertEqual(t, true, tr.ConnectTime == 0) + tests.AssertEqual(t, true, tr.TLSHandshakeTime == 0) + tests.AssertEqual(t, true, tr.TCPConnectTime == 0) + tests.AssertEqual(t, true, tr.FirstResponseTime == 0) + tests.AssertEqual(t, true, tr.ResponseTime == 0) + tests.AssertEqual(t, true, tr.TotalTime > 0) + tests.AssertEqual(t, true, tr.TotalTime == resp.TotalTime()) } func TestAutoDetectRequestContentType(t *testing.T) { c := tc() resp, err := c.R().SetBody(tests.GetTestFileContent(t, "sample-image.png")).Post("/content-type") assertSuccess(t, resp, err) - assertEqual(t, "image/png", resp.String()) + tests.AssertEqual(t, "image/png", resp.String()) resp, err = c.R().SetBodyJsonString(`{"msg": "test"}`).Post("/content-type") assertSuccess(t, resp, err) - assertEqual(t, jsonContentType, resp.String()) + tests.AssertEqual(t, jsonContentType, resp.String()) resp, err = c.R().SetContentType(xmlContentType).SetBody(`{"msg": "test"}`).Post("/content-type") assertSuccess(t, resp, err) - assertEqual(t, xmlContentType, resp.String()) + tests.AssertEqual(t, xmlContentType, resp.String()) resp, err = c.R().SetBody(`

hello

`).Post("/content-type") assertSuccess(t, resp, err) - assertEqual(t, "text/html; charset=utf-8", resp.String()) + tests.AssertEqual(t, "text/html; charset=utf-8", resp.String()) resp, err = c.R().SetBody(`hello world`).Post("/content-type") assertSuccess(t, resp, err) - assertEqual(t, plainTextContentType, resp.String()) + tests.AssertEqual(t, plainTextContentType, resp.String()) } func TestSetFileUploadCheck(t *testing.T) { @@ -729,7 +729,7 @@ func TestSetFileUploadCheck(t *testing.T) { tests.AssertErrorContains(t, err, "missing param name") tests.AssertErrorContains(t, err, "missing filename") tests.AssertErrorContains(t, err, "missing file content") - assertEqual(t, 0, len(resp.Request.uploadFiles)) + tests.AssertEqual(t, 0, len(resp.Request.uploadFiles)) } func TestUploadMultipart(t *testing.T) { @@ -753,13 +753,13 @@ func TestUploadMultipart(t *testing.T) { func TestFixPragmaCache(t *testing.T) { resp, err := tc().EnableForceHTTP1().R().Get("/pragma") assertSuccess(t, resp, err) - assertEqual(t, "no-cache", resp.Header.Get("Cache-Control")) + tests.AssertEqual(t, "no-cache", resp.Header.Get("Cache-Control")) } func TestSetFileBytes(t *testing.T) { resp, err := tc().R().SetFileBytes("file", "file.txt", []byte("test")).Post("/file-text") assertSuccess(t, resp, err) - assertEqual(t, "test", resp.String()) + tests.AssertEqual(t, "test", resp.String()) } func TestSetBodyWrapper(t *testing.T) { @@ -768,14 +768,14 @@ func TestSetBodyWrapper(t *testing.T) { c := tc() r := c.R().SetBodyXmlString(s) - assertEqual(t, true, len(r.body) > 0) + tests.AssertEqual(t, true, len(r.body) > 0) r = c.R().SetBodyXmlBytes(b) - assertEqual(t, true, len(r.body) > 0) + tests.AssertEqual(t, true, len(r.body) > 0) r = c.R().SetBodyJsonString(s) - assertEqual(t, true, len(r.body) > 0) + tests.AssertEqual(t, true, len(r.body) > 0) r = c.R().SetBodyJsonBytes(b) - assertEqual(t, true, len(r.body) > 0) + tests.AssertEqual(t, true, len(r.body) > 0) } diff --git a/response_test.go b/response_test.go index a80a6d85..141d509d 100644 --- a/response_test.go +++ b/response_test.go @@ -1,6 +1,9 @@ package req -import "testing" +import ( + "github.com/imroc/req/v3/internal/tests" + "testing" +) type User struct { Name string `json:"name" xml:"name"` @@ -15,8 +18,8 @@ func TestUnmarshalJson(t *testing.T) { resp, err := tc().R().Get("/json") assertSuccess(t, resp, err) err = resp.UnmarshalJson(&user) - assertError(t, err) - assertEqual(t, "roc", user.Name) + tests.AssertNoError(t, err) + tests.AssertEqual(t, "roc", user.Name) } func TestUnmarshalXml(t *testing.T) { @@ -24,8 +27,8 @@ func TestUnmarshalXml(t *testing.T) { resp, err := tc().R().Get("/xml") assertSuccess(t, resp, err) err = resp.UnmarshalXml(&user) - assertError(t, err) - assertEqual(t, "roc", user.Name) + tests.AssertNoError(t, err) + tests.AssertEqual(t, "roc", user.Name) } func TestUnmarshal(t *testing.T) { @@ -33,8 +36,8 @@ func TestUnmarshal(t *testing.T) { resp, err := tc().R().Get("/xml") assertSuccess(t, resp, err) err = resp.Unmarshal(&user) - assertError(t, err) - assertEqual(t, "roc", user.Name) + tests.AssertNoError(t, err) + tests.AssertEqual(t, "roc", user.Name) } func TestResponseResult(t *testing.T) { @@ -43,10 +46,10 @@ func TestResponseResult(t *testing.T) { if !ok { t.Fatal("Response.Result() should return *User") } - assertEqual(t, "roc", user.Name) + tests.AssertEqual(t, "roc", user.Name) - assertEqual(t, true, resp.TotalTime() > 0) - assertEqual(t, false, resp.ReceivedAt().IsZero()) + tests.AssertEqual(t, true, resp.TotalTime() > 0) + tests.AssertEqual(t, false, resp.ReceivedAt().IsZero()) } func TestResponseError(t *testing.T) { @@ -55,5 +58,5 @@ func TestResponseError(t *testing.T) { if !ok { t.Fatal("Response.Error() should return *Message") } - assertEqual(t, "not allowed", msg.Message) + tests.AssertEqual(t, "not allowed", msg.Message) } diff --git a/retry_test.go b/retry_test.go index 9714529b..74aa2665 100644 --- a/retry_test.go +++ b/retry_test.go @@ -29,8 +29,8 @@ func testRetry(t *testing.T, setFunc func(r *Request)) { setFunc(r) resp, err := r.Get("/too-many") tests.AssertNoError(t, err) - assertEqual(t, 3, resp.Request.RetryAttempt) - assertEqual(t, 3, attempt) + tests.AssertEqual(t, 3, resp.Request.RetryAttempt) + tests.AssertEqual(t, 3, attempt) } func TestRetryInterval(t *testing.T) { @@ -55,7 +55,7 @@ func TestAddRetryHook(t *testing.T) { test = "test2" }) }) - assertEqual(t, "test2", test) + tests.AssertEqual(t, "test2", test) } func TestRetryOverride(t *testing.T) { @@ -75,8 +75,8 @@ func TestRetryOverride(t *testing.T) { return err != nil || resp.StatusCode == http.StatusTooManyRequests }).Get("/too-many") tests.AssertNoError(t, err) - assertEqual(t, "test1", test) - assertEqual(t, 2, resp.Request.RetryAttempt) + tests.AssertEqual(t, "test1", test) + tests.AssertEqual(t, 2, resp.Request.RetryAttempt) } func TestAddRetryCondition(t *testing.T) { @@ -93,8 +93,8 @@ func TestAddRetryCondition(t *testing.T) { attempt++ }).Get("/too-many") tests.AssertNoError(t, err) - assertEqual(t, 0, attempt) - assertEqual(t, 0, resp.Request.RetryAttempt) + tests.AssertEqual(t, 0, attempt) + tests.AssertEqual(t, 0, resp.Request.RetryAttempt) attempt = 0 resp, err = tc(). @@ -109,8 +109,8 @@ func TestAddRetryCondition(t *testing.T) { attempt++ }).R().Get("/too-many") tests.AssertNoError(t, err) - assertEqual(t, 0, attempt) - assertEqual(t, 0, resp.Request.RetryAttempt) + tests.AssertEqual(t, 0, attempt) + tests.AssertEqual(t, 0, resp.Request.RetryAttempt) } @@ -119,13 +119,13 @@ func TestRetryWithUnreplayableBody(t *testing.T) { SetRetryCount(1). SetBody(bytes.NewBufferString("test")). Post("/") - assertEqual(t, errRetryableWithUnReplayableBody, err) + tests.AssertEqual(t, errRetryableWithUnReplayableBody, err) _, err = tc().R(). SetRetryCount(1). SetBody(ioutil.NopCloser(bytes.NewBufferString("test"))). Post("/") - assertEqual(t, errRetryableWithUnReplayableBody, err) + tests.AssertEqual(t, errRetryableWithUnReplayableBody, err) } func TestRetryWithSetResult(t *testing.T) { @@ -138,5 +138,5 @@ func TestRetryWithSetResult(t *testing.T) { SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "test=test", headers.Get("Cookie")) + tests.AssertEqual(t, "test=test", headers.Get("Cookie")) } From ca13a83b5d361d054dc5f1be392c6388d7945112 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Mar 2022 19:59:06 +0800 Subject: [PATCH 435/843] clone retryOption in Client.Clone() --- client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client.go b/client.go index 18894321..07f378f2 100644 --- a/client.go +++ b/client.go @@ -1379,6 +1379,7 @@ func (c *Client) Clone() *Client { cc.udBeforeRequest = cloneRequestMiddleware(c.udBeforeRequest) cc.afterResponse = cloneResponseMiddleware(c.afterResponse) cc.dumpOptions = c.dumpOptions.Clone() + cc.retryOption = c.retryOption.Clone() cc.log = c.log cc.jsonUnmarshal = c.jsonUnmarshal From 7ec4061c1cee5de2deb0b0eb2abcb627de2aeb0a Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Mar 2022 20:55:03 +0800 Subject: [PATCH 436/843] refactor tests --- client_test.go | 199 +++++++++++++-------------- decode_test.go | 9 +- h2_frame_test.go | 75 +++++------ h2_gotrack_test.go | 11 +- h2_test.go | 17 ++- h2_transport_test.go | 17 ++- internal/socks/socks_test.go | 27 +++- internal/tests/assert.go | 78 ----------- internal/tests/file.go | 9 -- logger_test.go | 5 +- req_test.go | 253 +++++++++++++++++++---------------- request_test.go | 212 ++++++++++++++--------------- response_test.go | 21 ++- retry_test.go | 33 +++-- 14 files changed, 459 insertions(+), 507 deletions(-) delete mode 100644 internal/tests/assert.go diff --git a/client_test.go b/client_test.go index 36f47366..b77ae782 100644 --- a/client_test.go +++ b/client_test.go @@ -20,23 +20,23 @@ func TestAllowGetMethodPayload(t *testing.T) { c := tc() resp, err := c.R().SetBody("test").Get("/payload") assertSuccess(t, resp, err) - tests.AssertEqual(t, "", resp.String()) + assertEqual(t, "", resp.String()) c.EnableAllowGetMethodPayload() resp, err = c.R().SetBody("test").Get("/payload") assertSuccess(t, resp, err) - tests.AssertEqual(t, "test", resp.String()) + assertEqual(t, "test", resp.String()) c.DisableAllowGetMethodPayload() resp, err = c.R().SetBody("test").Get("/payload") assertSuccess(t, resp, err) - tests.AssertEqual(t, "", resp.String()) + assertEqual(t, "", resp.String()) } func TestSetTLSHandshakeTimeout(t *testing.T) { timeout := 2 * time.Second c := tc().SetTLSHandshakeTimeout(timeout) - tests.AssertEqual(t, timeout, c.t.TLSHandshakeTimeout) + assertEqual(t, timeout, c.t.TLSHandshakeTimeout) } func TestSetDial(t *testing.T) { @@ -46,7 +46,7 @@ func TestSetDial(t *testing.T) { } c := tc().SetDial(testDial) _, err := c.t.DialContext(nil, "", "") - tests.AssertEqual(t, testErr, err) + assertEqual(t, testErr, err) } func TestSetDialTLS(t *testing.T) { @@ -56,7 +56,7 @@ func TestSetDialTLS(t *testing.T) { } c := tc().SetDialTLS(testDialTLS) _, err := c.t.DialTLSContext(nil, "", "") - tests.AssertEqual(t, testErr, err) + assertEqual(t, testErr, err) } func TestSetFuncs(t *testing.T) { @@ -74,31 +74,31 @@ func TestSetFuncs(t *testing.T) { SetXmlUnmarshal(unmarshalFunc) _, err := c.jsonMarshal(nil) - tests.AssertEqual(t, testErr, err) + assertEqual(t, testErr, err) err = c.jsonUnmarshal(nil, nil) - tests.AssertEqual(t, testErr, err) + assertEqual(t, testErr, err) _, err = c.xmlMarshal(nil) - tests.AssertEqual(t, testErr, err) + assertEqual(t, testErr, err) err = c.xmlUnmarshal(nil, nil) - tests.AssertEqual(t, testErr, err) + assertEqual(t, testErr, err) } func TestSetCookieJar(t *testing.T) { c := tc().SetCookieJar(nil) - tests.AssertEqual(t, nil, c.httpClient.Jar) + assertEqual(t, nil, c.httpClient.Jar) } func TestTraceAll(t *testing.T) { c := tc().EnableTraceAll() resp, err := c.R().Get("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, true, resp.TraceInfo().TotalTime > 0) + assertEqual(t, true, resp.TraceInfo().TotalTime > 0) c.DisableTraceAll() resp, err = c.R().Get("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, true, resp.TraceInfo().TotalTime == 0) + assertEqual(t, true, resp.TraceInfo().TotalTime == 0) } func TestOnAfterResponse(t *testing.T) { @@ -108,21 +108,21 @@ func TestOnAfterResponse(t *testing.T) { return nil }) len2 := len(c.afterResponse) - tests.AssertEqual(t, true, len1+1 == len2) + assertEqual(t, true, len1+1 == len2) } func TestOnBeforeRequest(t *testing.T) { c := tc().OnBeforeRequest(func(client *Client, request *Request) error { return nil }) - tests.AssertEqual(t, true, len(c.udBeforeRequest) == 1) + assertEqual(t, true, len(c.udBeforeRequest) == 1) } func TestSetProxyURL(t *testing.T) { c := tc().SetProxyURL("http://dummy.proxy.local") u, err := c.t.Proxy(nil) - tests.AssertNoError(t, err) - tests.AssertEqual(t, "http://dummy.proxy.local", u.String()) + assertNoError(t, err) + assertEqual(t, "http://dummy.proxy.local", u.String()) } func TestSetProxy(t *testing.T) { @@ -130,18 +130,18 @@ func TestSetProxy(t *testing.T) { proxy := http.ProxyURL(u) c := tc().SetProxy(proxy) uu, err := c.t.Proxy(nil) - tests.AssertNoError(t, err) - tests.AssertEqual(t, u.String(), uu.String()) + assertNoError(t, err) + assertEqual(t, u.String(), uu.String()) } func TestSetCommonContentType(t *testing.T) { c := tc().SetCommonContentType(jsonContentType) - tests.AssertEqual(t, jsonContentType, c.Headers.Get(hdrContentTypeKey)) + assertEqual(t, jsonContentType, c.Headers.Get(hdrContentTypeKey)) } func TestSetCommonHeader(t *testing.T) { c := tc().SetCommonHeader("my-header", "my-value") - tests.AssertEqual(t, "my-value", c.Headers.Get("my-header")) + assertEqual(t, "my-value", c.Headers.Get("my-header")) } func TestSetCommonHeaders(t *testing.T) { @@ -149,41 +149,41 @@ func TestSetCommonHeaders(t *testing.T) { "header1": "value1", "header2": "value2", }) - tests.AssertEqual(t, "value1", c.Headers.Get("header1")) - tests.AssertEqual(t, "value2", c.Headers.Get("header2")) + assertEqual(t, "value1", c.Headers.Get("header1")) + assertEqual(t, "value2", c.Headers.Get("header2")) } func TestSetCommonBasicAuth(t *testing.T) { c := tc().SetCommonBasicAuth("imroc", "123456") - tests.AssertEqual(t, "Basic aW1yb2M6MTIzNDU2", c.Headers.Get("Authorization")) + assertEqual(t, "Basic aW1yb2M6MTIzNDU2", c.Headers.Get("Authorization")) } func TestSetCommonBearerAuthToken(t *testing.T) { c := tc().SetCommonBearerAuthToken("123456") - tests.AssertEqual(t, "Bearer 123456", c.Headers.Get("Authorization")) + assertEqual(t, "Bearer 123456", c.Headers.Get("Authorization")) } func TestSetUserAgent(t *testing.T) { c := tc().SetUserAgent("test") - tests.AssertEqual(t, "test", c.Headers.Get(hdrUserAgentKey)) + assertEqual(t, "test", c.Headers.Get(hdrUserAgentKey)) } func TestAutoDecode(t *testing.T) { c := tc().DisableAutoDecode() resp, err := c.R().Get("/gbk") assertSuccess(t, resp, err) - tests.AssertEqual(t, toGbk("我是roc"), resp.Bytes()) + assertEqual(t, toGbk("我是roc"), resp.Bytes()) resp, err = c.EnableAutoDecode().R().Get("/gbk") assertSuccess(t, resp, err) - tests.AssertEqual(t, "我是roc", resp.String()) + assertEqual(t, "我是roc", resp.String()) resp, err = c.SetAutoDecodeContentType("html").R().Get("/gbk") assertSuccess(t, resp, err) - tests.AssertEqual(t, toGbk("我是roc"), resp.Bytes()) + assertEqual(t, toGbk("我是roc"), resp.Bytes()) resp, err = c.SetAutoDecodeContentType("text").R().Get("/gbk") assertSuccess(t, resp, err) - tests.AssertEqual(t, "我是roc", resp.String()) + assertEqual(t, "我是roc", resp.String()) resp, err = c.SetAutoDecodeContentTypeFunc(func(contentType string) bool { if strings.Contains(contentType, "text") { return true @@ -191,39 +191,39 @@ func TestAutoDecode(t *testing.T) { return false }).R().Get("/gbk") assertSuccess(t, resp, err) - tests.AssertEqual(t, "我是roc", resp.String()) + assertEqual(t, "我是roc", resp.String()) resp, err = c.SetAutoDecodeAllContentType().R().Get("/gbk-no-charset") assertSuccess(t, resp, err) - tests.AssertContains(t, resp.String(), "我是roc", true) + assertContains(t, resp.String(), "我是roc", true) } func TestSetTimeout(t *testing.T) { timeout := 100 * time.Second c := tc().SetTimeout(timeout) - tests.AssertEqual(t, timeout, c.httpClient.Timeout) + assertEqual(t, timeout, c.httpClient.Timeout) } func TestSetLogger(t *testing.T) { l := createDefaultLogger() c := tc().SetLogger(l) - tests.AssertEqual(t, l, c.log) + assertEqual(t, l, c.log) c.SetLogger(nil) - tests.AssertEqual(t, &disableLogger{}, c.log) + assertEqual(t, &disableLogger{}, c.log) } func TestSetScheme(t *testing.T) { c := tc().SetScheme("https") - tests.AssertEqual(t, "https", c.scheme) + assertEqual(t, "https", c.scheme) } func TestDebugLog(t *testing.T) { c := tc().EnableDebugLog() - tests.AssertEqual(t, true, c.DebugLog) + assertEqual(t, true, c.DebugLog) c.DisableDebugLog() - tests.AssertEqual(t, false, c.DebugLog) + assertEqual(t, false, c.DebugLog) } func TestSetCommonCookies(t *testing.T) { @@ -233,25 +233,25 @@ func TestSetCommonCookies(t *testing.T) { Value: "test", }).R().SetResult(&headers).Get("/header") assertSuccess(t, resp, err) - tests.AssertEqual(t, "test=test", headers.Get("Cookie")) + assertEqual(t, "test=test", headers.Get("Cookie")) } func TestSetCommonQueryString(t *testing.T) { resp, err := tc().SetCommonQueryString("test=test").R().Get("/query-parameter") assertSuccess(t, resp, err) - tests.AssertEqual(t, "test=test", resp.String()) + assertEqual(t, "test=test", resp.String()) } func TestSetCommonPathParams(t *testing.T) { c := tc().SetCommonPathParams(map[string]string{"test": "test"}) - tests.AssertNotNil(t, c.PathParams) - tests.AssertEqual(t, "test", c.PathParams["test"]) + assertNotNil(t, c.PathParams) + assertEqual(t, "test", c.PathParams["test"]) } func TestSetCommonPathParam(t *testing.T) { c := tc().SetCommonPathParam("test", "test") - tests.AssertNotNil(t, c.PathParams) - tests.AssertEqual(t, "test", c.PathParams["test"]) + assertNotNil(t, c.PathParams) + assertEqual(t, "test", c.PathParams["test"]) } func TestAddCommonQueryParam(t *testing.T) { @@ -260,97 +260,97 @@ func TestAddCommonQueryParam(t *testing.T) { AddCommonQueryParam("test", "2"). R().Get("/query-parameter") assertSuccess(t, resp, err) - tests.AssertEqual(t, "test=1&test=2", resp.String()) + assertEqual(t, "test=1&test=2", resp.String()) } func TestSetCommonQueryParam(t *testing.T) { resp, err := tc().SetCommonQueryParam("test", "test").R().Get("/query-parameter") assertSuccess(t, resp, err) - tests.AssertEqual(t, "test=test", resp.String()) + assertEqual(t, "test=test", resp.String()) } func TestSetCommonQueryParams(t *testing.T) { resp, err := tc().SetCommonQueryParams(map[string]string{"test": "test"}).R().Get("/query-parameter") assertSuccess(t, resp, err) - tests.AssertEqual(t, "test=test", resp.String()) + assertEqual(t, "test=test", resp.String()) } func TestInsecureSkipVerify(t *testing.T) { c := tc().EnableInsecureSkipVerify() - tests.AssertEqual(t, true, c.t.TLSClientConfig.InsecureSkipVerify) + assertEqual(t, true, c.t.TLSClientConfig.InsecureSkipVerify) c.DisableInsecureSkipVerify() - tests.AssertEqual(t, false, c.t.TLSClientConfig.InsecureSkipVerify) + assertEqual(t, false, c.t.TLSClientConfig.InsecureSkipVerify) } func TestSetTLSClientConfig(t *testing.T) { config := &tls.Config{InsecureSkipVerify: true} c := tc().SetTLSClientConfig(config) - tests.AssertEqual(t, config, c.t.TLSClientConfig) + assertEqual(t, config, c.t.TLSClientConfig) } func TestCompression(t *testing.T) { c := tc().DisableCompression() - tests.AssertEqual(t, true, c.t.DisableCompression) + assertEqual(t, true, c.t.DisableCompression) c.EnableCompression() - tests.AssertEqual(t, false, c.t.DisableCompression) + assertEqual(t, false, c.t.DisableCompression) } func TestKeepAlives(t *testing.T) { c := tc().DisableKeepAlives() - tests.AssertEqual(t, true, c.t.DisableKeepAlives) + assertEqual(t, true, c.t.DisableKeepAlives) c.EnableKeepAlives() - tests.AssertEqual(t, false, c.t.DisableKeepAlives) + assertEqual(t, false, c.t.DisableKeepAlives) } func TestRedirect(t *testing.T) { _, err := tc().SetRedirectPolicy(NoRedirectPolicy()).R().Get("/unlimited-redirect") - tests.AssertNotNil(t, err) - tests.AssertContains(t, err.Error(), "redirect is disabled", true) + assertNotNil(t, err) + assertContains(t, err.Error(), "redirect is disabled", true) _, err = tc().SetRedirectPolicy(MaxRedirectPolicy(3)).R().Get("/unlimited-redirect") - tests.AssertNotNil(t, err) - tests.AssertContains(t, err.Error(), "stopped after 3 redirects", true) + assertNotNil(t, err) + assertContains(t, err.Error(), "stopped after 3 redirects", true) _, err = tc().SetRedirectPolicy(SameDomainRedirectPolicy()).R().Get("/redirect-to-other") - tests.AssertNotNil(t, err) - tests.AssertContains(t, err.Error(), "different domain name is not allowed", true) + assertNotNil(t, err) + assertContains(t, err.Error(), "different domain name is not allowed", true) _, err = tc().SetRedirectPolicy(SameHostRedirectPolicy()).R().Get("/redirect-to-other") - tests.AssertNotNil(t, err) - tests.AssertContains(t, err.Error(), "different host name is not allowed", true) + assertNotNil(t, err) + assertContains(t, err.Error(), "different host name is not allowed", true) _, err = tc().SetRedirectPolicy(AllowedHostRedirectPolicy("localhost", "127.0.0.1")).R().Get("/redirect-to-other") - tests.AssertNotNil(t, err) - tests.AssertContains(t, err.Error(), "redirect host [dummy.local] is not allowed", true) + assertNotNil(t, err) + assertContains(t, err.Error(), "redirect host [dummy.local] is not allowed", true) _, err = tc().SetRedirectPolicy(AllowedDomainRedirectPolicy("localhost", "127.0.0.1")).R().Get("/redirect-to-other") - tests.AssertNotNil(t, err) - tests.AssertContains(t, err.Error(), "redirect domain [dummy.local] is not allowed", true) + assertNotNil(t, err) + assertContains(t, err.Error(), "redirect domain [dummy.local] is not allowed", true) } func TestGetTLSClientConfig(t *testing.T) { c := tc() config := c.GetTLSClientConfig() - tests.AssertEqual(t, true, c.t.TLSClientConfig != nil) - tests.AssertEqual(t, config, c.t.TLSClientConfig) + assertEqual(t, true, c.t.TLSClientConfig != nil) + assertEqual(t, config, c.t.TLSClientConfig) } func TestSetRootCertFromFile(t *testing.T) { c := tc().SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")) - tests.AssertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) + assertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) } func TestSetRootCertFromString(t *testing.T) { - c := tc().SetRootCertFromString(string(tests.GetTestFileContent(t, "sample-root.pem"))) - tests.AssertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) + c := tc().SetRootCertFromString(string(getTestFileContent(t, "sample-root.pem"))) + assertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) } func TestSetCerts(t *testing.T) { c := tc().SetCerts(tls.Certificate{}, tls.Certificate{}) - tests.AssertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 2) + assertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 2) } func TestSetCertFromFile(t *testing.T) { @@ -358,7 +358,7 @@ func TestSetCertFromFile(t *testing.T) { tests.GetTestFilePath("sample-client.pem"), tests.GetTestFilePath("sample-client-key.pem"), ) - tests.AssertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 1) + assertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 1) } func TestSetOutputDirectory(t *testing.T) { @@ -368,15 +368,15 @@ func TestSetOutputDirectory(t *testing.T) { R().SetOutputFile(outFile). Get("/") assertSuccess(t, resp, err) - content := string(tests.GetTestFileContent(t, outFile)) + content := string(getTestFileContent(t, outFile)) os.Remove(tests.GetTestFilePath(outFile)) - tests.AssertEqual(t, "TestGet: text response", content) + assertEqual(t, "TestGet: text response", content) } func TestSetBaseURL(t *testing.T) { baseURL := "http://dummy-req.local/test" resp, _ := tc().SetTimeout(time.Nanosecond).SetBaseURL(baseURL).R().Get("/req") - tests.AssertEqual(t, baseURL+"/req", resp.Request.RawRequest.URL.String()) + assertEqual(t, baseURL+"/req", resp.Request.RawRequest.URL.String()) } func TestSetCommonFormDataFromValues(t *testing.T) { @@ -388,7 +388,7 @@ func TestSetCommonFormDataFromValues(t *testing.T) { R().SetResult(&gotForm). Post("/form") assertSuccess(t, resp, err) - tests.AssertEqual(t, "test", gotForm.Get("test")) + assertEqual(t, "test", gotForm.Get("test")) } func TestSetCommonFormData(t *testing.T) { @@ -401,7 +401,7 @@ func TestSetCommonFormData(t *testing.T) { SetResult(&form). Post("/form") assertSuccess(t, resp, err) - tests.AssertEqual(t, "test", form.Get("test")) + assertEqual(t, "test", form.Get("test")) } func TestClientClone(t *testing.T) { @@ -411,10 +411,13 @@ func TestClientClone(t *testing.T) { Name: "test", Value: "test", }).SetCommonQueryParam("test", "test"). - SetCommonPathParam("test", "test") + SetCommonPathParam("test", "test"). + SetCommonRetryCount(2). + SetCommonFormData(map[string]string{"test": "test"}). + OnBeforeRequest(func(c *Client, r *Request) error { return nil }) c2 := c1.Clone() - assertEqualStruct(t, c1, c2, false, "t", "t2", "httpClient") + assertClone(t, c1, c2) } func TestDisableAutoReadResponse(t *testing.T) { @@ -426,15 +429,15 @@ func testDisableAutoReadResponse(t *testing.T, c *Client) { c.DisableAutoReadResponse() resp, err := c.R().Get("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, "", resp.String()) + assertEqual(t, "", resp.String()) result, err := resp.ToString() - tests.AssertNoError(t, err) - tests.AssertEqual(t, "TestGet: text response", result) + assertNoError(t, err) + assertEqual(t, "TestGet: text response", result) resp, err = c.R().Get("/") assertSuccess(t, resp, err) _, err = ioutil.ReadAll(resp.Body) - tests.AssertNoError(t, err) + assertNoError(t, err) } func TestEnableDumpAll(t *testing.T) { @@ -516,10 +519,10 @@ func testEnableDumpAll(t *testing.T, fn func(c *Client, reqHeader, reqBody, resp resp, err := c.R().SetBody(`test body`).Post("/") assertSuccess(t, resp, err) dump := buf.String() - tests.AssertContains(t, dump, "user-agent", reqHeader) - tests.AssertContains(t, dump, "test body", reqBody) - tests.AssertContains(t, dump, "date", respHeader) - tests.AssertContains(t, dump, "testpost: text response", respBody) + assertContains(t, dump, "user-agent", reqHeader) + assertContains(t, dump, "test body", reqBody) + assertContains(t, dump, "date", respHeader) + assertContains(t, dump, "testpost: text response", respBody) } testDump(c) buf = new(bytes.Buffer) @@ -540,10 +543,10 @@ func TestSetCommonDumpOptions(t *testing.T) { c.SetCommonDumpOptions(opt).EnableDumpAll() resp, err := c.R().SetBody("test body").Post("/") assertSuccess(t, resp, err) - tests.AssertContains(t, buf.String(), "user-agent", true) - tests.AssertContains(t, buf.String(), "test body", false) - tests.AssertContains(t, buf.String(), "date", false) - tests.AssertContains(t, buf.String(), "testpost: text response", true) + assertContains(t, buf.String(), "user-agent", true) + assertContains(t, buf.String(), "test body", false) + assertContains(t, buf.String(), "date", false) + assertContains(t, buf.String(), "testpost: text response", true) } func TestEnableDumpAllToFile(t *testing.T) { @@ -552,17 +555,17 @@ func TestEnableDumpAllToFile(t *testing.T) { c.EnableDumpAllToFile(tests.GetTestFilePath(dumpFile)) resp, err := c.R().SetBody("test body").Post("/") assertSuccess(t, resp, err) - dump := string(tests.GetTestFileContent(t, dumpFile)) + dump := string(getTestFileContent(t, dumpFile)) os.Remove(tests.GetTestFilePath(dumpFile)) - tests.AssertContains(t, dump, "user-agent", true) - tests.AssertContains(t, dump, "test body", true) - tests.AssertContains(t, dump, "date", true) - tests.AssertContains(t, dump, "testpost: text response", true) + assertContains(t, dump, "user-agent", true) + assertContains(t, dump, "test body", true) + assertContains(t, dump, "date", true) + assertContains(t, dump, "testpost: text response", true) } func TestEnableDumpAllAsync(t *testing.T) { c := tc() buf := new(bytes.Buffer) c.EnableDumpAllTo(buf).EnableDumpAllAsync() - tests.AssertEqual(t, true, c.getDumpOptions().Async) + assertEqual(t, true, c.getDumpOptions().Async) } diff --git a/decode_test.go b/decode_test.go index e65a8ea6..fc810e83 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1,7 +1,6 @@ package req import ( - "github.com/imroc/req/v3/internal/tests" "testing" ) @@ -9,9 +8,9 @@ func TestPeekDrain(t *testing.T) { a := autoDecodeReadCloser{peek: []byte("test")} p := make([]byte, 2) n, _ := a.peekDrain(p) - tests.AssertEqual(t, 2, n) - tests.AssertEqual(t, true, a.peek != nil) + assertEqual(t, 2, n) + assertEqual(t, true, a.peek != nil) n, _ = a.peekDrain(p) - tests.AssertEqual(t, 2, n) - tests.AssertEqual(t, true, a.peek == nil) + assertEqual(t, 2, n) + assertEqual(t, true, a.peek == nil) } diff --git a/h2_frame_test.go b/h2_frame_test.go index 136a9bd7..573dd99a 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -7,7 +7,6 @@ package req import ( "bytes" "fmt" - "github.com/imroc/req/v3/internal/tests" "io" "reflect" "strings" @@ -1290,38 +1289,38 @@ func TestParseSettingsFrame(t *testing.T) { fh.Length = 1 countErr := func(s string) {} _, err := http2parseSettingsFrame(nil, fh, countErr, nil) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") + assertErrorContains(t, err, "FRAME_SIZE_ERROR") fh = http2FrameHeader{StreamID: 1} _, err = http2parseSettingsFrame(nil, fh, countErr, nil) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + assertErrorContains(t, err, "PROTOCOL_ERROR") fh = http2FrameHeader{} _, err = http2parseSettingsFrame(nil, fh, countErr, []byte("roc")) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") + assertErrorContains(t, err, "FRAME_SIZE_ERROR") fh = http2FrameHeader{valid: true} _, err = http2parseSettingsFrame(nil, fh, countErr, []byte("rocroc")) - tests.AssertNoError(t, err) + assertNoError(t, err) } func TestParsePushPromise(t *testing.T) { fh := http2FrameHeader{} countError := func(string) {} _, err := http2parsePushPromise(nil, fh, countError, nil) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + assertErrorContains(t, err, "PROTOCOL_ERROR") fh.StreamID = 1 fh.Flags = http2FlagPushPromisePadded _, err = http2parsePushPromise(nil, fh, countError, nil) - tests.AssertErrorContains(t, err, "EOF") + assertErrorContains(t, err, "EOF") fh.Flags = 0 _, err = http2parsePushPromise(nil, fh, countError, nil) - tests.AssertErrorContains(t, err, "EOF") + assertErrorContains(t, err, "EOF") _, err = http2parsePushPromise(nil, fh, countError, []byte("ksjfksjksjflskk")) - tests.AssertNoError(t, err) + assertNoError(t, err) } func TestSummarizeFrame(t *testing.T) { @@ -1329,62 +1328,62 @@ func TestSummarizeFrame(t *testing.T) { var f http2Frame f = &http2SettingsFrame{http2FrameHeader: fh, p: []byte{0x09, 0x01, 0x80, 0x20, 0x00, 0x11}} s := http2summarizeFrame(f) - tests.AssertContains(t, s, "len=0", true) + assertContains(t, s, "len=0", true) f = &http2DataFrame{http2FrameHeader: fh} s = http2summarizeFrame(f) - tests.AssertContains(t, s, `data=""`, true) + assertContains(t, s, `data=""`, true) f = &http2WindowUpdateFrame{http2FrameHeader: fh} s = http2summarizeFrame(f) - tests.AssertContains(t, s, "conn", true) + assertContains(t, s, "conn", true) f = &http2PingFrame{http2FrameHeader: fh} s = http2summarizeFrame(f) - tests.AssertContains(t, s, "ping", true) + assertContains(t, s, "ping", true) f = &http2GoAwayFrame{http2FrameHeader: fh} s = http2summarizeFrame(f) - tests.AssertContains(t, s, "laststreamid", true) + assertContains(t, s, "laststreamid", true) f = &http2RSTStreamFrame{http2FrameHeader: fh} s = http2summarizeFrame(f) - tests.AssertContains(t, s, "no_error", true) + assertContains(t, s, "no_error", true) } func TestParseDataFrame(t *testing.T) { fh := http2FrameHeader{valid: true} countError := func(string) {} _, err := http2parseDataFrame(nil, fh, countError, nil) - tests.AssertErrorContains(t, err, "DATA frame with stream ID 0") + assertErrorContains(t, err, "DATA frame with stream ID 0") fh.StreamID = 1 fh.Flags = http2FlagDataPadded fc := &http2frameCache{} payload := []byte{0x09, 0x00, 0x00, 0x98, 0x11, 0x12} _, err = http2parseDataFrame(fc, fh, countError, payload) - tests.AssertErrorContains(t, err, "pad size larger than data payload") + assertErrorContains(t, err, "pad size larger than data payload") payload = []byte{0x02, 0x00, 0x00, 0x98, 0x11, 0x12} _, err = http2parseDataFrame(fc, fh, countError, payload) - tests.AssertNoError(t, err) + assertNoError(t, err) } func TestParseWindowUpdateFrame(t *testing.T) { fh := http2FrameHeader{valid: true} countError := func(string) {} _, err := http2parseWindowUpdateFrame(nil, fh, countError, nil) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") + assertErrorContains(t, err, "FRAME_SIZE_ERROR") p := []byte{0x00, 0x00, 0x00, 0x00} _, err = http2parseWindowUpdateFrame(nil, fh, countError, p) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + assertErrorContains(t, err, "PROTOCOL_ERROR") fh.StreamID = 255 p[0] = 0x01 p[3] = 0x01 _, err = http2parseWindowUpdateFrame(nil, fh, countError, p) - tests.AssertNoError(t, err) + assertNoError(t, err) } func TestParseUnknownFrame(t *testing.T) { @@ -1392,12 +1391,12 @@ func TestParseUnknownFrame(t *testing.T) { countError := func(string) {} p := []byte("test") f, err := http2parseUnknownFrame(nil, fh, countError, p) - tests.AssertNoError(t, err) + assertNoError(t, err) uf, ok := f.(*http2UnknownFrame) if !ok { t.Fatalf("not http2UnknownFrame type: %#+v", f) } - tests.AssertEqual(t, p, uf.Payload()) + assertEqual(t, p, uf.Payload()) } func TestParseRSTStreamFrame(t *testing.T) { @@ -1405,15 +1404,15 @@ func TestParseRSTStreamFrame(t *testing.T) { countError := func(string) {} p := []byte("test.") _, err := http2parseRSTStreamFrame(nil, fh, countError, p) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") + assertErrorContains(t, err, "FRAME_SIZE_ERROR") p = []byte("test") _, err = http2parseRSTStreamFrame(nil, fh, countError, p) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + assertErrorContains(t, err, "PROTOCOL_ERROR") fh.StreamID = 1 _, err = http2parseRSTStreamFrame(nil, fh, countError, p) - tests.AssertNoError(t, err) + assertNoError(t, err) } func TestParsePingFrame(t *testing.T) { @@ -1421,16 +1420,16 @@ func TestParsePingFrame(t *testing.T) { countError := func(string) {} payload := []byte("") _, err := http2parsePingFrame(nil, fh, countError, payload) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") + assertErrorContains(t, err, "FRAME_SIZE_ERROR") payload = []byte("testtest") fh.StreamID = 1 _, err = http2parsePingFrame(nil, fh, countError, payload) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + assertErrorContains(t, err, "PROTOCOL_ERROR") fh.StreamID = 0 _, err = http2parsePingFrame(nil, fh, countError, payload) - tests.AssertNoError(t, err) + assertNoError(t, err) } func TestParseGoAwayFrame(t *testing.T) { @@ -1440,40 +1439,40 @@ func TestParseGoAwayFrame(t *testing.T) { fh.StreamID = 1 _, err := http2parseGoAwayFrame(nil, fh, countError, payload) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") + assertErrorContains(t, err, "PROTOCOL_ERROR") fh.StreamID = 0 _, err = http2parseGoAwayFrame(nil, fh, countError, payload) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") + assertErrorContains(t, err, "FRAME_SIZE_ERROR") } func TestPushPromiseFrame(t *testing.T) { fh := http2FrameHeader{valid: true} buf := []byte("test") f := &http2PushPromiseFrame{http2FrameHeader: fh, headerFragBuf: buf} - tests.AssertEqual(t, buf, f.HeaderBlockFragment()) - tests.AssertEqual(t, false, f.HeadersEnded()) + assertEqual(t, buf, f.HeaderBlockFragment()) + assertEqual(t, false, f.HeadersEnded()) } func TestH2Framer(t *testing.T) { f := &http2Framer{} f.debugWriteLoggerf = func(s string, i ...interface{}) {} f.logWrite() - tests.AssertNotNil(t, f.debugFramer) - tests.AssertIsNil(t, f.ErrorDetail()) + assertNotNil(t, f.debugFramer) + assertIsNil(t, f.ErrorDetail()) f.w = new(bytes.Buffer) err := f.WriteRawFrame(http2FrameData, http2FlagDataEndStream, 1, nil) - tests.AssertNoError(t, err) + assertNoError(t, err) param := http2PushPromiseParam{} err = f.WritePushPromise(param) - tests.AssertErrorContains(t, err, "invalid stream ID") + assertErrorContains(t, err, "invalid stream ID") param.StreamID = 1 param.EndHeaders = true param.PadLength = 2 f.AllowIllegalWrites = true err = f.WritePushPromise(param) - tests.AssertNoError(t, err) + assertNoError(t, err) } diff --git a/h2_gotrack_test.go b/h2_gotrack_test.go index fc47c139..3d0ccdc1 100644 --- a/h2_gotrack_test.go +++ b/h2_gotrack_test.go @@ -6,7 +6,6 @@ package req import ( "fmt" - "github.com/imroc/req/v3/internal/tests" "strings" "testing" ) @@ -36,21 +35,21 @@ func TestGoroutineLock(t *testing.T) { func TestParseUintBytes(t *testing.T) { s := []byte{} _, err := http2parseUintBytes(s, 0, 0) - tests.AssertErrorContains(t, err, "invalid syntax") + assertErrorContains(t, err, "invalid syntax") s = []byte("0x") _, err = http2parseUintBytes(s, 0, 0) - tests.AssertErrorContains(t, err, "invalid syntax") + assertErrorContains(t, err, "invalid syntax") s = []byte("0x01") _, err = http2parseUintBytes(s, 0, 0) - tests.AssertNoError(t, err) + assertNoError(t, err) s = []byte("0xa1") _, err = http2parseUintBytes(s, 0, 0) - tests.AssertNoError(t, err) + assertNoError(t, err) s = []byte("0xA1") _, err = http2parseUintBytes(s, 0, 0) - tests.AssertNoError(t, err) + assertNoError(t, err) } diff --git a/h2_test.go b/h2_test.go index d1249f1f..f96928ac 100644 --- a/h2_test.go +++ b/h2_test.go @@ -7,7 +7,6 @@ package req import ( "flag" "fmt" - "github.com/imroc/req/v3/internal/tests" "net/http" "testing" "time" @@ -99,21 +98,21 @@ func TestSettingValid(t *testing.T) { } for _, c := range cases { s := &http2Setting{ID: c.id, Val: c.val} - tests.AssertEqual(t, true, s.Valid() != nil) + assertEqual(t, true, s.Valid() != nil) } s := &http2Setting{ID: http2SettingMaxHeaderListSize} - tests.AssertEqual(t, true, s.Valid() == nil) + assertEqual(t, true, s.Valid() == nil) } func TestBodyAllowedForStatus(t *testing.T) { - tests.AssertEqual(t, false, http2bodyAllowedForStatus(101)) - tests.AssertEqual(t, false, http2bodyAllowedForStatus(204)) - tests.AssertEqual(t, false, http2bodyAllowedForStatus(304)) - tests.AssertEqual(t, true, http2bodyAllowedForStatus(900)) + assertEqual(t, false, http2bodyAllowedForStatus(101)) + assertEqual(t, false, http2bodyAllowedForStatus(204)) + assertEqual(t, false, http2bodyAllowedForStatus(304)) + assertEqual(t, true, http2bodyAllowedForStatus(900)) } func TestHttpError(t *testing.T) { e := &http2httpError{msg: "test"} - tests.AssertEqual(t, "test", e.Error()) - tests.AssertEqual(t, true, e.Temporary()) + assertEqual(t, "test", e.Error()) + assertEqual(t, true, e.Temporary()) } diff --git a/h2_transport_test.go b/h2_transport_test.go index 7eaabe25..360f2273 100644 --- a/h2_transport_test.go +++ b/h2_transport_test.go @@ -13,7 +13,6 @@ import ( "errors" "flag" "fmt" - "github.com/imroc/req/v3/internal/tests" "io" "io/ioutil" "log" @@ -5865,27 +5864,27 @@ func TestCountReadFrameError(t *testing.T) { var err error cc.countReadFrameError(err) - tests.AssertEqual(t, "", errMsg) + assertEqual(t, "", errMsg) err = http2ConnectionError(http2ErrCodeInternal) cc.countReadFrameError(err) - tests.AssertContains(t, errMsg, "read_frame_conn_error", true) + assertContains(t, errMsg, "read_frame_conn_error", true) err = io.EOF cc.countReadFrameError(err) - tests.AssertContains(t, errMsg, "read_frame_eof", true) + assertContains(t, errMsg, "read_frame_eof", true) err = io.ErrUnexpectedEOF cc.countReadFrameError(err) - tests.AssertContains(t, errMsg, "read_frame_unexpected_eof", true) + assertContains(t, errMsg, "read_frame_unexpected_eof", true) err = errFrameTooLarge cc.countReadFrameError(err) - tests.AssertContains(t, errMsg, "read_frame_too_large", true) + assertContains(t, errMsg, "read_frame_too_large", true) err = errors.New("other") cc.countReadFrameError(err) - tests.AssertContains(t, errMsg, "read_frame_other", true) + assertContains(t, errMsg, "read_frame_other", true) } func TestProcessHeaders(t *testing.T) { @@ -5897,9 +5896,9 @@ func TestProcessHeaders(t *testing.T) { http2FrameHeader: http2FrameHeader{StreamID: 1}, }} err := rl.processHeaders(f) - tests.AssertNoError(t, err) + assertNoError(t, err) f.StreamID = 0 err = rl.processHeaders(f) - tests.AssertNoError(t, err) + assertNoError(t, err) } diff --git a/internal/socks/socks_test.go b/internal/socks/socks_test.go index 824a09d7..cc3af621 100644 --- a/internal/socks/socks_test.go +++ b/internal/socks/socks_test.go @@ -3,7 +3,6 @@ package socks import ( "bytes" "context" - "github.com/imroc/req/v3/internal/tests" "strings" "testing" ) @@ -21,6 +20,22 @@ func TestReply(t *testing.T) { } } +func assertNoError(t *testing.T, err error) { + if err != nil { + t.Errorf("Error occurred [%v]", err) + } +} + +func assertErrorContains(t *testing.T, err error, s string) { + if err == nil { + t.Error("err is nil") + return + } + if !strings.Contains(err.Error(), s) { + t.Errorf("%q is not included in error %q", s, err.Error()) + } +} + func TestAuthenticate(t *testing.T) { auth := &UsernamePassword{ Username: "imroc", @@ -28,21 +43,21 @@ func TestAuthenticate(t *testing.T) { } buf := bytes.NewBuffer([]byte{byte(0x01), byte(0x00)}) err := auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) - tests.AssertNoError(t, err) + assertNoError(t, err) auth.Username = "this is a very long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long name" err = auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) - tests.AssertErrorContains(t, err, "invalid") + assertErrorContains(t, err, "invalid") auth.Username = "imroc" buf = bytes.NewBuffer([]byte{byte(0x03), byte(0x00)}) err = auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) - tests.AssertErrorContains(t, err, "invalid username/password version") + assertErrorContains(t, err, "invalid username/password version") buf = bytes.NewBuffer([]byte{byte(0x01), byte(0x02)}) err = auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) - tests.AssertErrorContains(t, err, "authentication failed") + assertErrorContains(t, err, "authentication failed") err = auth.Authenticate(context.Background(), buf, AuthMethodNoAcceptableMethods) - tests.AssertErrorContains(t, err, "unsupported authentication method") + assertErrorContains(t, err, "unsupported authentication method") } diff --git a/internal/tests/assert.go b/internal/tests/assert.go deleted file mode 100644 index 9d41e2b6..00000000 --- a/internal/tests/assert.go +++ /dev/null @@ -1,78 +0,0 @@ -package tests - -import ( - "reflect" - "strings" - "testing" -) - -// AssertIsNil asserts is nil. -func AssertIsNil(t *testing.T, v interface{}) { - if !isNil(v) { - t.Errorf("[%v] was expected to be nil", v) - } -} - -// AssertNotNil asserts is not nil. -func AssertNotNil(t *testing.T, v interface{}) { - if isNil(v) { - t.Fatalf("[%v] was expected to be non-nil", v) - } -} - -// AssertEqual asserts e (expected) is equal with g (got). -func AssertEqual(t *testing.T, e, g interface{}) { - if !equal(e, g) { - t.Errorf("Expected [%+v], got [%+v]", e, g) - } - return -} - -// AssertNoError asserts no error. -func AssertNoError(t *testing.T, err error) { - if err != nil { - t.Errorf("Error occurred [%v]", err) - } -} - -// AssertErrorContains asserts error is not nil and contains specified error string. -func AssertErrorContains(t *testing.T, err error, s string) { - if err == nil { - t.Error("err is nil") - return - } - if !strings.Contains(err.Error(), s) { - t.Errorf("%q is not included in error %q", s, err.Error()) - } -} - -// AssertContains asserts substring is contained in the given string. -func AssertContains(t *testing.T, s, substr string, shouldContain bool) { - s = strings.ToLower(s) - isContain := strings.Contains(s, substr) - if shouldContain { - if !isContain { - t.Errorf("%q is not included in %s", substr, s) - } - } else { - if isContain { - t.Errorf("%q is included in %s", substr, s) - } - } -} - -func equal(expected, got interface{}) bool { - return reflect.DeepEqual(expected, got) -} - -func isNil(v interface{}) bool { - if v == nil { - return true - } - rv := reflect.ValueOf(v) - kind := rv.Kind() - if kind >= reflect.Chan && kind <= reflect.Slice && rv.IsNil() { - return true - } - return false -} diff --git a/internal/tests/file.go b/internal/tests/file.go index f4fc2341..fc7753a5 100644 --- a/internal/tests/file.go +++ b/internal/tests/file.go @@ -1,10 +1,8 @@ package tests import ( - "io/ioutil" "os" "path/filepath" - "testing" ) var testDataPath string @@ -14,13 +12,6 @@ func init() { testDataPath = filepath.Join(pwd, ".testdata") } -// GetTestFileContent return test file content. -func GetTestFileContent(t *testing.T, filename string) []byte { - b, err := ioutil.ReadFile(GetTestFilePath(filename)) - AssertNoError(t, err) - return b -} - // GetTestFilePath return test file absolute path. func GetTestFilePath(filename string) string { return filepath.Join(testDataPath, filename) diff --git a/logger_test.go b/logger_test.go index c15dd6f4..2718392d 100644 --- a/logger_test.go +++ b/logger_test.go @@ -2,7 +2,6 @@ package req import ( "bytes" - "github.com/imroc/req/v3/internal/tests" "log" "testing" ) @@ -12,8 +11,8 @@ func TestLogger(t *testing.T) { l := NewLogger(buf, "", log.Ldate|log.Lmicroseconds) c := tc().SetLogger(l) c.SetProxyURL(":=\\<>ksfj&*&sf") - tests.AssertContains(t, buf.String(), "error", true) + assertContains(t, buf.String(), "error", true) buf.Reset() c.R().SetOutput(nil) - tests.AssertContains(t, buf.String(), "warn", true) + assertContains(t, buf.String(), "warn", true) } diff --git a/req_test.go b/req_test.go index 91a37992..078c325a 100644 --- a/req_test.go +++ b/req_test.go @@ -71,6 +71,108 @@ func getTestServerURL() string { return testServer.URL } +func getTestFileContent(t *testing.T, filename string) []byte { + b, err := ioutil.ReadFile(tests.GetTestFilePath(filename)) + assertNoError(t, err) + return b +} + +func assertIsNil(t *testing.T, v interface{}) { + if !isNil(v) { + t.Errorf("[%v] was expected to be nil", v) + } +} + +func assertNotNil(t *testing.T, v interface{}) { + if isNil(v) { + t.Fatalf("[%v] was expected to be non-nil", v) + } +} + +func assertEqual(t *testing.T, e, g interface{}) { + if !equal(e, g) { + t.Errorf("Expected [%+v], got [%+v]", e, g) + } + return +} + +func assertNoError(t *testing.T, err error) { + if err != nil { + t.Errorf("Error occurred [%v]", err) + } +} + +func assertErrorContains(t *testing.T, err error, s string) { + if err == nil { + t.Error("err is nil") + return + } + if !strings.Contains(err.Error(), s) { + t.Errorf("%q is not included in error %q", s, err.Error()) + } +} + +func assertContains(t *testing.T, s, substr string, shouldContain bool) { + s = strings.ToLower(s) + isContain := strings.Contains(s, substr) + if shouldContain { + if !isContain { + t.Errorf("%q is not included in %s", substr, s) + } + } else { + if isContain { + t.Errorf("%q is included in %s", substr, s) + } + } +} + +func assertClone(t *testing.T, e, g interface{}) { + ev := reflect.ValueOf(e).Elem() + gv := reflect.ValueOf(g).Elem() + et := ev.Type() + + for i := 0; i < ev.NumField(); i++ { + sf := ev.Field(i) + st := et.Field(i) + + var ee, gg interface{} + if !token.IsExported(st.Name) { + ee = reflect.NewAt(sf.Type(), unsafe.Pointer(sf.UnsafeAddr())).Elem().Interface() + gg = reflect.NewAt(sf.Type(), unsafe.Pointer(gv.Field(i).UnsafeAddr())).Elem().Interface() + } else { + ee = sf.Interface() + gg = gv.Field(i).Interface() + } + if sf.Kind() == reflect.Func || sf.Kind() == reflect.Slice || sf.Kind() == reflect.Ptr { + if ee != nil { + if gg == nil { + t.Errorf("Field %s.%s is nil", et.Name(), et.Field(i).Name) + } + } + continue + } + if !reflect.DeepEqual(ee, gg) { + t.Errorf("Field %s.%s is not equal, expected [%v], got [%v]", et.Name(), et.Field(i).Name, ee, gg) + } + } +} + +func equal(expected, got interface{}) bool { + return reflect.DeepEqual(expected, got) +} + +func isNil(v interface{}) bool { + if v == nil { + return true + } + rv := reflect.ValueOf(v) + kind := rv.Kind() + if kind >= reflect.Chan && kind <= reflect.Slice && rv.IsNil() { + return true + } + return false +} + type echo struct { Header http.Header `json:"header" xml:"header"` Body string `json:"body" xml:"body"` @@ -262,106 +364,33 @@ func handleGet(w http.ResponseWriter, r *http.Request) { } func assertStatus(t *testing.T, resp *Response, err error, statusCode int, status string) { - tests.AssertNoError(t, err) - tests.AssertNotNil(t, resp) - tests.AssertNotNil(t, resp.Body) - tests.AssertEqual(t, statusCode, resp.StatusCode) - tests.AssertEqual(t, status, resp.Status) + assertNoError(t, err) + assertNotNil(t, resp) + assertNotNil(t, resp.Body) + assertEqual(t, statusCode, resp.StatusCode) + assertEqual(t, status, resp.Status) } func assertSuccess(t *testing.T, resp *Response, err error) { - tests.AssertNoError(t, err) - tests.AssertNotNil(t, resp.Response) - tests.AssertNotNil(t, resp.Response.Body) - tests.AssertEqual(t, http.StatusOK, resp.StatusCode) - tests.AssertEqual(t, "200 OK", resp.Status) + assertNoError(t, err) + assertNotNil(t, resp.Response) + assertNotNil(t, resp.Response.Body) + assertEqual(t, http.StatusOK, resp.StatusCode) + assertEqual(t, "200 OK", resp.Status) if !resp.IsSuccess() { t.Error("Response.IsSuccess should return true") } } func assertIsError(t *testing.T, resp *Response, err error) { - tests.AssertNoError(t, err) - tests.AssertNotNil(t, resp) - tests.AssertNotNil(t, resp.Body) + assertNoError(t, err) + assertNotNil(t, resp) + assertNotNil(t, resp.Body) if !resp.IsError() { t.Error("Response.IsError should return true") } } -func assertEqualStruct(t *testing.T, e, g interface{}, onlyExported bool, excludes ...string) { - ev := reflect.ValueOf(e).Elem() - gv := reflect.ValueOf(g).Elem() - et := ev.Type() - gt := gv.Type() - m := map[string]bool{} - for _, exclude := range excludes { - m[exclude] = true - } - if et.Kind() != reflect.Struct { - t.Fatalf("expect object should be struct instead of %v", et.Kind().String()) - } - - if gt.Kind() != reflect.Struct { - t.Fatalf("got object should be struct instead of %v", gt.Kind().String()) - } - - if et.Name() != gt.Name() { - t.Fatalf("Expected type [%s], got [%s]", et.Name(), gt.Name()) - } - - for i := 0; i < ev.NumField(); i++ { - sf := ev.Field(i) - if sf.Kind() == reflect.Func || sf.Kind() == reflect.Slice { - continue - } - st := et.Field(i) - if m[st.Name] { - continue - } - if onlyExported && !token.IsExported(st.Name) { - continue - } - var ee, gg interface{} - if !token.IsExported(st.Name) { - ee = reflect.NewAt(sf.Type(), unsafe.Pointer(sf.UnsafeAddr())).Elem().Interface() - gg = reflect.NewAt(sf.Type(), unsafe.Pointer(gv.Field(i).UnsafeAddr())).Elem().Interface() - } else { - ee = sf.Interface() - gg = gv.Field(i).Interface() - } - if !reflect.DeepEqual(ee, gg) { - t.Errorf("Field %s.%s is not equal, expected [%v], got [%v]", et.Name(), et.Field(i).Name, ee, gg) - } - } - -} - -func assertNotEqual(t *testing.T, e, g interface{}) (r bool) { - if equal(e, g) { - t.Errorf("Expected [%v], got [%v]", e, g) - } else { - r = true - } - return -} - -func equal(expected, got interface{}) bool { - return reflect.DeepEqual(expected, got) -} - -func isNil(v interface{}) bool { - if v == nil { - return true - } - rv := reflect.ValueOf(v) - kind := rv.Kind() - if kind >= reflect.Chan && kind <= reflect.Slice && rv.IsNil() { - return true - } - return false -} - func testGlobalWrapperEnableDumps(t *testing.T) { testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { *reqHeader = true @@ -421,17 +450,17 @@ func testGlobalWrapperEnableDumps(t *testing.T) { buf := new(bytes.Buffer) r := EnableDumpTo(buf) - tests.AssertEqual(t, true, r.getDumpOptions().Output != nil) + assertEqual(t, true, r.getDumpOptions().Output != nil) dumpFile := tests.GetTestFilePath("req_tmp_dump.out") r = EnableDumpToFile(tests.GetTestFilePath(dumpFile)) - tests.AssertEqual(t, true, r.getDumpOptions().Output != nil) + assertEqual(t, true, r.getDumpOptions().Output != nil) os.Remove(dumpFile) r = SetDumpOptions(&DumpOptions{ RequestHeader: true, }) - tests.AssertEqual(t, true, r.getDumpOptions().RequestHeader) + assertEqual(t, true, r.getDumpOptions().RequestHeader) } func testGlobalWrapperEnableDump(t *testing.T, fn func(reqHeader, reqBody, respHeader, respBody *bool) *Request) { @@ -460,50 +489,50 @@ func testGlobalWrapperSendRequest(t *testing.T) { resp, err := Put(testURL) assertSuccess(t, resp, err) - tests.AssertEqual(t, "PUT", resp.Header.Get("Method")) + assertEqual(t, "PUT", resp.Header.Get("Method")) resp = MustPut(testURL) - tests.AssertEqual(t, "PUT", resp.Header.Get("Method")) + assertEqual(t, "PUT", resp.Header.Get("Method")) resp, err = Patch(testURL) assertSuccess(t, resp, err) - tests.AssertEqual(t, "PATCH", resp.Header.Get("Method")) + assertEqual(t, "PATCH", resp.Header.Get("Method")) resp = MustPatch(testURL) - tests.AssertEqual(t, "PATCH", resp.Header.Get("Method")) + assertEqual(t, "PATCH", resp.Header.Get("Method")) resp, err = Delete(testURL) assertSuccess(t, resp, err) - tests.AssertEqual(t, "DELETE", resp.Header.Get("Method")) + assertEqual(t, "DELETE", resp.Header.Get("Method")) resp = MustDelete(testURL) - tests.AssertEqual(t, "DELETE", resp.Header.Get("Method")) + assertEqual(t, "DELETE", resp.Header.Get("Method")) resp, err = Options(testURL) assertSuccess(t, resp, err) - tests.AssertEqual(t, "OPTIONS", resp.Header.Get("Method")) + assertEqual(t, "OPTIONS", resp.Header.Get("Method")) resp = MustOptions(testURL) - tests.AssertEqual(t, "OPTIONS", resp.Header.Get("Method")) + assertEqual(t, "OPTIONS", resp.Header.Get("Method")) resp, err = Head(testURL) assertSuccess(t, resp, err) - tests.AssertEqual(t, "HEAD", resp.Header.Get("Method")) + assertEqual(t, "HEAD", resp.Header.Get("Method")) resp = MustHead(testURL) - tests.AssertEqual(t, "HEAD", resp.Header.Get("Method")) + assertEqual(t, "HEAD", resp.Header.Get("Method")) resp, err = Get(testURL) assertSuccess(t, resp, err) - tests.AssertEqual(t, "GET", resp.Header.Get("Method")) + assertEqual(t, "GET", resp.Header.Get("Method")) resp = MustGet(testURL) - tests.AssertEqual(t, "GET", resp.Header.Get("Method")) + assertEqual(t, "GET", resp.Header.Get("Method")) resp, err = Post(testURL) assertSuccess(t, resp, err) - tests.AssertEqual(t, "POST", resp.Header.Get("Method")) + assertEqual(t, "POST", resp.Header.Get("Method")) resp = MustPost(testURL) - tests.AssertEqual(t, "POST", resp.Header.Get("Method")) + assertEqual(t, "POST", resp.Header.Get("Method")) } func testGlobalWrapperSetRequest(t *testing.T, rs ...*Request) { for _, r := range rs { - tests.AssertNotNil(t, r) + assertNotNil(t, r) } } @@ -576,7 +605,7 @@ func TestGlobalWrapperSetRequest(t *testing.T) { func testGlobalClientSettingWrapper(t *testing.T, cs ...*Client) { for _, c := range cs { - tests.AssertNotNil(t, c) + assertNotNil(t, c) } } @@ -655,7 +684,7 @@ func TestGlobalWrapper(t *testing.T) { DisableKeepAlives(), EnableKeepAlives(), SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")), - SetRootCertFromString(string(tests.GetTestFileContent(t, "sample-root.pem"))), + SetRootCertFromString(string(getTestFileContent(t, "sample-root.pem"))), SetCerts(tls.Certificate{}, tls.Certificate{}), SetCertFromFile( tests.GetTestFilePath("sample-client.pem"), @@ -704,21 +733,21 @@ func TestGlobalWrapper(t *testing.T) { os.Remove(tests.GetTestFilePath("tmpdump.out")) config := GetTLSClientConfig() - tests.AssertEqual(t, config, DefaultClient().t.TLSClientConfig) + assertEqual(t, config, DefaultClient().t.TLSClientConfig) r := R() - tests.AssertEqual(t, true, r != nil) + assertEqual(t, true, r != nil) c := C() c.SetTimeout(10 * time.Second) SetDefaultClient(c) - tests.AssertEqual(t, true, DefaultClient().httpClient.Timeout == 10*time.Second) - tests.AssertEqual(t, GetClient(), DefaultClient().httpClient) + assertEqual(t, true, DefaultClient().httpClient.Timeout == 10*time.Second) + assertEqual(t, GetClient(), DefaultClient().httpClient) r = NewRequest() - tests.AssertEqual(t, true, r != nil) + assertEqual(t, true, r != nil) c = NewClient() - tests.AssertEqual(t, true, c != nil) + assertEqual(t, true, c != nil) } func TestTrailer(t *testing.T) { diff --git a/request_test.go b/request_test.go index e82fc96a..f0bcf397 100644 --- a/request_test.go +++ b/request_test.go @@ -24,45 +24,45 @@ func TestMethods(t *testing.T) { func testMethods(t *testing.T, c *Client) { resp, err := c.R().Put("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, "PUT", resp.Header.Get("Method")) + assertEqual(t, "PUT", resp.Header.Get("Method")) resp = c.R().MustPut("/") - tests.AssertEqual(t, "PUT", resp.Header.Get("Method")) + assertEqual(t, "PUT", resp.Header.Get("Method")) resp, err = c.R().Patch("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, "PATCH", resp.Header.Get("Method")) + assertEqual(t, "PATCH", resp.Header.Get("Method")) resp = c.R().MustPatch("/") - tests.AssertEqual(t, "PATCH", resp.Header.Get("Method")) + assertEqual(t, "PATCH", resp.Header.Get("Method")) resp, err = c.R().Delete("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, "DELETE", resp.Header.Get("Method")) + assertEqual(t, "DELETE", resp.Header.Get("Method")) resp = c.R().MustDelete("/") - tests.AssertEqual(t, "DELETE", resp.Header.Get("Method")) + assertEqual(t, "DELETE", resp.Header.Get("Method")) resp, err = c.R().Options("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, "OPTIONS", resp.Header.Get("Method")) + assertEqual(t, "OPTIONS", resp.Header.Get("Method")) resp = c.R().MustOptions("/") - tests.AssertEqual(t, "OPTIONS", resp.Header.Get("Method")) + assertEqual(t, "OPTIONS", resp.Header.Get("Method")) resp, err = c.R().Head("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, "HEAD", resp.Header.Get("Method")) + assertEqual(t, "HEAD", resp.Header.Get("Method")) resp = c.R().MustHead("/") - tests.AssertEqual(t, "HEAD", resp.Header.Get("Method")) + assertEqual(t, "HEAD", resp.Header.Get("Method")) resp, err = c.R().Get("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, "GET", resp.Header.Get("Method")) + assertEqual(t, "GET", resp.Header.Get("Method")) resp = c.R().MustGet("/") - tests.AssertEqual(t, "GET", resp.Header.Get("Method")) + assertEqual(t, "GET", resp.Header.Get("Method")) resp, err = c.R().Post("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, "POST", resp.Header.Get("Method")) + assertEqual(t, "POST", resp.Header.Get("Method")) resp = c.R().MustPost("/") - tests.AssertEqual(t, "POST", resp.Header.Get("Method")) + assertEqual(t, "POST", resp.Header.Get("Method")) } func TestEnableDump(t *testing.T) { @@ -79,7 +79,7 @@ func TestEnableDumpToFIle(t *testing.T) { tmpFile := "tmp_dumpfile_req" resp, err := tc().R().EnableDumpToFile(tests.GetTestFilePath(tmpFile)).Get("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, true, len(tests.GetTestFileContent(t, tmpFile)) > 0) + assertEqual(t, true, len(getTestFileContent(t, tmpFile)) > 0) os.Remove(tests.GetTestFilePath(tmpFile)) } @@ -151,10 +151,10 @@ func testEnableDump(t *testing.T, fn func(r *Request, reqHeader, reqBody, respHe resp, err := r.SetBody(`test body`).Post("/") assertSuccess(t, resp, err) dump := resp.Dump() - tests.AssertContains(t, dump, "user-agent", reqHeader) - tests.AssertContains(t, dump, "test body", reqBody) - tests.AssertContains(t, dump, "date", respHeader) - tests.AssertContains(t, dump, "testpost: text response", respBody) + assertContains(t, dump, "user-agent", reqHeader) + assertContains(t, dump, "test body", reqBody) + assertContains(t, dump, "date", respHeader) + assertContains(t, dump, "testpost: text response", respBody) } testDump(tc()) testDump(tc().EnableForceHTTP1()) @@ -175,10 +175,10 @@ func testSetDumpOptions(t *testing.T, c *Client) { resp, err := c.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(getTestServerURL()) assertSuccess(t, resp, err) dump := resp.Dump() - tests.AssertContains(t, dump, "user-agent", true) - tests.AssertContains(t, dump, "test body", false) - tests.AssertContains(t, dump, "date", false) - tests.AssertContains(t, dump, "testpost: text response", true) + assertContains(t, dump, "user-agent", true) + assertContains(t, dump, "test body", false) + assertContains(t, dump, "date", false) + assertContains(t, dump, "testpost: text response", true) } func TestGet(t *testing.T) { @@ -189,7 +189,7 @@ func TestGet(t *testing.T) { func testGet(t *testing.T, c *Client) { resp, err := c.R().Get("/") assertSuccess(t, resp, err) - tests.AssertEqual(t, "TestGet: text response", resp.String()) + assertEqual(t, "TestGet: text response", resp.String()) } func TestBadRequest(t *testing.T) { @@ -216,16 +216,16 @@ func testSetBodyMarshal(t *testing.T, c *Client) { return func(e *echo) { var user User err := json.Unmarshal([]byte(e.Body), &user) - tests.AssertNoError(t, err) - tests.AssertEqual(t, username, user.Username) + assertNoError(t, err) + assertEqual(t, username, user.Username) } } assertUsernameXml := func(username string) func(e *echo) { return func(e *echo) { var user User err := xml.Unmarshal([]byte(e.Body), &user) - tests.AssertNoError(t, err) - tests.AssertEqual(t, username, user.Username) + assertNoError(t, err) + assertEqual(t, username, user.Username) } } testCases := []struct { @@ -298,8 +298,8 @@ func TestSetBodyReader(t *testing.T) { var e echo resp, err := tc().R().SetBody(ioutil.NopCloser(bytes.NewBufferString("hello"))).SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - tests.AssertEqual(t, "", e.Header.Get(hdrContentTypeKey)) - tests.AssertEqual(t, "hello", e.Body) + assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) + assertEqual(t, "hello", e.Body) } func TestSetBodyGetContentFunc(t *testing.T) { @@ -308,8 +308,8 @@ func TestSetBodyGetContentFunc(t *testing.T) { return ioutil.NopCloser(bytes.NewBufferString("hello")), nil }).SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - tests.AssertEqual(t, "", e.Header.Get(hdrContentTypeKey)) - tests.AssertEqual(t, "hello", e.Body) + assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) + assertEqual(t, "hello", e.Body) e = echo{} var fn GetContentFunc = func() (io.ReadCloser, error) { @@ -317,8 +317,8 @@ func TestSetBodyGetContentFunc(t *testing.T) { } resp, err = tc().R().SetBody(fn).SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - tests.AssertEqual(t, "", e.Header.Get(hdrContentTypeKey)) - tests.AssertEqual(t, "hello", e.Body) + assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) + assertEqual(t, "hello", e.Body) } func TestSetBodyContent(t *testing.T) { @@ -351,8 +351,8 @@ func testSetBodyContent(t *testing.T, c *Client) { var e echo resp, err := r.SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - tests.AssertEqual(t, plainTextContentType, e.Header.Get(hdrContentTypeKey)) - tests.AssertEqual(t, testBody, e.Body) + assertEqual(t, plainTextContentType, e.Header.Get(hdrContentTypeKey)) + assertEqual(t, testBody, e.Body) } // Set Reader @@ -360,8 +360,8 @@ func testSetBodyContent(t *testing.T, c *Client) { e = echo{} resp, err := c.R().SetBody(testBodyReader).SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - tests.AssertEqual(t, testBody, e.Body) - tests.AssertEqual(t, "", e.Header.Get(hdrContentTypeKey)) + assertEqual(t, testBody, e.Body) + assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) } func TestCookie(t *testing.T) { @@ -382,7 +382,7 @@ func testCookie(t *testing.T, c *Client) { }, ).SetResult(&headers).Get("/header") assertSuccess(t, resp, err) - tests.AssertEqual(t, "cookie1=value1; cookie2=value2", headers.Get("Cookie")) + assertEqual(t, "cookie1=value1; cookie2=value2", headers.Get("Cookie")) } func TestAuth(t *testing.T) { @@ -397,7 +397,7 @@ func testAuth(t *testing.T, c *Client) { SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - tests.AssertEqual(t, "Basic aW1yb2M6MTIzNDU2", headers.Get("Authorization")) + assertEqual(t, "Basic aW1yb2M6MTIzNDU2", headers.Get("Authorization")) token := "NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4" headers = make(http.Header) @@ -406,7 +406,7 @@ func testAuth(t *testing.T, c *Client) { SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - tests.AssertEqual(t, "Bearer "+token, headers.Get("Authorization")) + assertEqual(t, "Bearer "+token, headers.Get("Authorization")) } func TestHeader(t *testing.T) { @@ -419,7 +419,7 @@ func testHeader(t *testing.T, c *Client) { customUserAgent := "My Custom User Agent" resp, err := c.R().SetHeader(hdrUserAgentKey, customUserAgent).Get("/user-agent") assertSuccess(t, resp, err) - tests.AssertEqual(t, customUserAgent, resp.String()) + assertEqual(t, customUserAgent, resp.String()) // Set custom header headers := make(http.Header) @@ -431,9 +431,9 @@ func testHeader(t *testing.T, c *Client) { }).SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - tests.AssertEqual(t, "value1", headers.Get("header1")) - tests.AssertEqual(t, "value2", headers.Get("header2")) - tests.AssertEqual(t, "value3", headers.Get("header3")) + assertEqual(t, "value1", headers.Get("header1")) + assertEqual(t, "value2", headers.Get("header2")) + assertEqual(t, "value3", headers.Get("header3")) } func TestQueryParam(t *testing.T) { @@ -458,14 +458,14 @@ func testQueryParam(t *testing.T, c *Client) { SetQueryParam("key3", "value3"). Get("/query-parameter") assertSuccess(t, resp, err) - tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) + assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryString resp, err = c.R(). SetQueryString("key1=value1&key2=value2&key3=value3"). Get("/query-parameter") assertSuccess(t, resp, err) - tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) + assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryParams resp, err = c.R(). @@ -476,7 +476,7 @@ func testQueryParam(t *testing.T, c *Client) { }). Get("/query-parameter") assertSuccess(t, resp, err) - tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) + assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryParam & SetQueryParams & SetQueryString resp, err = c.R(). @@ -488,7 +488,7 @@ func testQueryParam(t *testing.T, c *Client) { SetQueryString("key4=value4&key5=value5"). Get("/query-parameter") assertSuccess(t, resp, err) - tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=value4&key5=value5", resp.String()) + assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=value4&key5=value5", resp.String()) // Set same param to override resp, err = c.R(). @@ -503,7 +503,7 @@ func testQueryParam(t *testing.T, c *Client) { SetQueryParam("key4", "value44"). Get("/query-parameter") assertSuccess(t, resp, err) - tests.AssertEqual(t, "key1=value11&key2=value22&key3=value3&key4=value44&key5=value5", resp.String()) + assertEqual(t, "key1=value11&key2=value22&key3=value3&key4=value44&key5=value5", resp.String()) // Add same param without override resp, err = c.R(). @@ -518,7 +518,7 @@ func testQueryParam(t *testing.T, c *Client) { AddQueryParam("key4", "value44"). Get("/query-parameter") assertSuccess(t, resp, err) - tests.AssertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5", resp.String()) + assertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5", resp.String()) } func TestPathParam(t *testing.T) { @@ -532,7 +532,7 @@ func testPathParam(t *testing.T, c *Client) { SetPathParam("username", username). Get("/user/{username}/profile") assertSuccess(t, resp, err) - tests.AssertEqual(t, fmt.Sprintf("%s's profile", username), resp.String()) + assertEqual(t, fmt.Sprintf("%s's profile", username), resp.String()) } func TestSuccess(t *testing.T) { @@ -547,7 +547,7 @@ func testSuccess(t *testing.T, c *Client) { SetResult(&userInfo). Get("/search") assertSuccess(t, resp, err) - tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) + assertEqual(t, "roc@imroc.cc", userInfo.Email) userInfo = UserInfo{} resp, err = c.R(). @@ -556,7 +556,7 @@ func testSuccess(t *testing.T, c *Client) { SetResult(&userInfo).EnableDump(). Get("/search") assertSuccess(t, resp, err) - tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) + assertEqual(t, "roc@imroc.cc", userInfo.Email) } func TestError(t *testing.T) { @@ -571,7 +571,7 @@ func testError(t *testing.T, c *Client) { SetError(&errMsg). Get("/search") assertIsError(t, resp, err) - tests.AssertEqual(t, 10000, errMsg.ErrorCode) + assertEqual(t, 10000, errMsg.ErrorCode) errMsg = ErrorMessage{} resp, err = c.R(). @@ -579,7 +579,7 @@ func testError(t *testing.T, c *Client) { SetError(&errMsg). Get("/search") assertIsError(t, resp, err) - tests.AssertEqual(t, 10001, errMsg.ErrorCode) + assertEqual(t, 10001, errMsg.ErrorCode) errMsg = ErrorMessage{} resp, err = c.R(). @@ -588,7 +588,7 @@ func testError(t *testing.T, c *Client) { SetError(&errMsg). Get("/search") assertIsError(t, resp, err) - tests.AssertEqual(t, 10001, errMsg.ErrorCode) + assertEqual(t, 10001, errMsg.ErrorCode) } func TestForm(t *testing.T) { @@ -606,7 +606,7 @@ func testForm(t *testing.T, c *Client) { SetResult(&userInfo). Post("/search") assertSuccess(t, resp, err) - tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) + assertEqual(t, "roc@imroc.cc", userInfo.Email) v := make(url.Values) v.Add("username", "imroc") @@ -616,7 +616,7 @@ func testForm(t *testing.T, c *Client) { SetResult(&userInfo). Post("/search") assertSuccess(t, resp, err) - tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) + assertEqual(t, "roc@imroc.cc", userInfo.Email) } func TestHostHeaderOverride(t *testing.T) { @@ -627,7 +627,7 @@ func TestHostHeaderOverride(t *testing.T) { func testHostHeaderOverride(t *testing.T, c *Client) { resp, err := c.R().SetHeader("Host", "testhostname").Get("/host-header") assertSuccess(t, resp, err) - tests.AssertEqual(t, "testhostname", resp.String()) + assertEqual(t, "testhostname", resp.String()) } func TestTraceInfo(t *testing.T) { @@ -636,15 +636,15 @@ func TestTraceInfo(t *testing.T) { resp, err := tc().R().Get("/") assertSuccess(t, resp, err) ti := resp.TraceInfo() - tests.AssertContains(t, ti.String(), "not enabled", true) - tests.AssertContains(t, ti.Blame(), "not enabled", true) + assertContains(t, ti.String(), "not enabled", true) + assertContains(t, ti.Blame(), "not enabled", true) resp, err = tc().EnableTraceAll().R().Get("/") assertSuccess(t, resp, err) ti = resp.TraceInfo() - tests.AssertContains(t, ti.String(), "not enabled", false) - tests.AssertContains(t, ti.Blame(), "not enabled", false) - tests.AssertEqual(t, true, resp.TotalTime() > 0) + assertContains(t, ti.String(), "not enabled", false) + assertContains(t, ti.Blame(), "not enabled", false) + assertEqual(t, true, resp.TotalTime() > 0) } func testTraceInfo(t *testing.T, c *Client) { @@ -653,28 +653,28 @@ func testTraceInfo(t *testing.T, c *Client) { resp, err := c.R().Get("/") assertSuccess(t, resp, err) ti := resp.TraceInfo() - tests.AssertEqual(t, true, ti.TotalTime > 0) - tests.AssertEqual(t, true, ti.TCPConnectTime > 0) - tests.AssertEqual(t, true, ti.TLSHandshakeTime > 0) - tests.AssertEqual(t, true, ti.ConnectTime > 0) - tests.AssertEqual(t, true, ti.FirstResponseTime > 0) - tests.AssertEqual(t, true, ti.ResponseTime > 0) - tests.AssertNotNil(t, ti.RemoteAddr) + assertEqual(t, true, ti.TotalTime > 0) + assertEqual(t, true, ti.TCPConnectTime > 0) + assertEqual(t, true, ti.TLSHandshakeTime > 0) + assertEqual(t, true, ti.ConnectTime > 0) + assertEqual(t, true, ti.FirstResponseTime > 0) + assertEqual(t, true, ti.ResponseTime > 0) + assertNotNil(t, ti.RemoteAddr) // disable trace at client level c.DisableTraceAll() resp, err = c.R().Get("/") assertSuccess(t, resp, err) ti = resp.TraceInfo() - tests.AssertEqual(t, false, ti.TotalTime > 0) - tests.AssertIsNil(t, ti.RemoteAddr) + assertEqual(t, false, ti.TotalTime > 0) + assertIsNil(t, ti.RemoteAddr) // enable trace at request level resp, err = c.R().EnableTrace().Get("/") assertSuccess(t, resp, err) ti = resp.TraceInfo() - tests.AssertEqual(t, true, ti.TotalTime > 0) - tests.AssertNotNil(t, ti.RemoteAddr) + assertEqual(t, true, ti.TotalTime > 0) + assertNotNil(t, ti.RemoteAddr) } func TestTraceOnTimeout(t *testing.T) { @@ -686,50 +686,50 @@ func testTraceOnTimeout(t *testing.T, c *Client) { c.EnableTraceAll().SetTimeout(100 * time.Millisecond) resp, err := c.R().Get("http://req-nowhere.local") - tests.AssertNotNil(t, err) - tests.AssertNotNil(t, resp) + assertNotNil(t, err) + assertNotNil(t, resp) tr := resp.TraceInfo() - tests.AssertEqual(t, true, tr.DNSLookupTime >= 0) - tests.AssertEqual(t, true, tr.ConnectTime == 0) - tests.AssertEqual(t, true, tr.TLSHandshakeTime == 0) - tests.AssertEqual(t, true, tr.TCPConnectTime == 0) - tests.AssertEqual(t, true, tr.FirstResponseTime == 0) - tests.AssertEqual(t, true, tr.ResponseTime == 0) - tests.AssertEqual(t, true, tr.TotalTime > 0) - tests.AssertEqual(t, true, tr.TotalTime == resp.TotalTime()) + assertEqual(t, true, tr.DNSLookupTime >= 0) + assertEqual(t, true, tr.ConnectTime == 0) + assertEqual(t, true, tr.TLSHandshakeTime == 0) + assertEqual(t, true, tr.TCPConnectTime == 0) + assertEqual(t, true, tr.FirstResponseTime == 0) + assertEqual(t, true, tr.ResponseTime == 0) + assertEqual(t, true, tr.TotalTime > 0) + assertEqual(t, true, tr.TotalTime == resp.TotalTime()) } func TestAutoDetectRequestContentType(t *testing.T) { c := tc() - resp, err := c.R().SetBody(tests.GetTestFileContent(t, "sample-image.png")).Post("/content-type") + resp, err := c.R().SetBody(getTestFileContent(t, "sample-image.png")).Post("/content-type") assertSuccess(t, resp, err) - tests.AssertEqual(t, "image/png", resp.String()) + assertEqual(t, "image/png", resp.String()) resp, err = c.R().SetBodyJsonString(`{"msg": "test"}`).Post("/content-type") assertSuccess(t, resp, err) - tests.AssertEqual(t, jsonContentType, resp.String()) + assertEqual(t, jsonContentType, resp.String()) resp, err = c.R().SetContentType(xmlContentType).SetBody(`{"msg": "test"}`).Post("/content-type") assertSuccess(t, resp, err) - tests.AssertEqual(t, xmlContentType, resp.String()) + assertEqual(t, xmlContentType, resp.String()) resp, err = c.R().SetBody(`

hello

`).Post("/content-type") assertSuccess(t, resp, err) - tests.AssertEqual(t, "text/html; charset=utf-8", resp.String()) + assertEqual(t, "text/html; charset=utf-8", resp.String()) resp, err = c.R().SetBody(`hello world`).Post("/content-type") assertSuccess(t, resp, err) - tests.AssertEqual(t, plainTextContentType, resp.String()) + assertEqual(t, plainTextContentType, resp.String()) } func TestSetFileUploadCheck(t *testing.T) { c := tc() resp, err := c.R().SetFileUpload(FileUpload{}).Post("/multipart") - tests.AssertErrorContains(t, err, "missing param name") - tests.AssertErrorContains(t, err, "missing filename") - tests.AssertErrorContains(t, err, "missing file content") - tests.AssertEqual(t, 0, len(resp.Request.uploadFiles)) + assertErrorContains(t, err, "missing param name") + assertErrorContains(t, err, "missing filename") + assertErrorContains(t, err, "missing file content") + assertEqual(t, 0, len(resp.Request.uploadFiles)) } func TestUploadMultipart(t *testing.T) { @@ -744,22 +744,22 @@ func TestUploadMultipart(t *testing.T) { SetResult(&m). Post("/multipart") assertSuccess(t, resp, err) - tests.AssertContains(t, resp.String(), "sample-image.png", true) - tests.AssertContains(t, resp.String(), "sample-file.txt", true) - tests.AssertContains(t, resp.String(), "value1", true) - tests.AssertContains(t, resp.String(), "value2", true) + assertContains(t, resp.String(), "sample-image.png", true) + assertContains(t, resp.String(), "sample-file.txt", true) + assertContains(t, resp.String(), "value1", true) + assertContains(t, resp.String(), "value2", true) } func TestFixPragmaCache(t *testing.T) { resp, err := tc().EnableForceHTTP1().R().Get("/pragma") assertSuccess(t, resp, err) - tests.AssertEqual(t, "no-cache", resp.Header.Get("Cache-Control")) + assertEqual(t, "no-cache", resp.Header.Get("Cache-Control")) } func TestSetFileBytes(t *testing.T) { resp, err := tc().R().SetFileBytes("file", "file.txt", []byte("test")).Post("/file-text") assertSuccess(t, resp, err) - tests.AssertEqual(t, "test", resp.String()) + assertEqual(t, "test", resp.String()) } func TestSetBodyWrapper(t *testing.T) { @@ -768,14 +768,14 @@ func TestSetBodyWrapper(t *testing.T) { c := tc() r := c.R().SetBodyXmlString(s) - tests.AssertEqual(t, true, len(r.body) > 0) + assertEqual(t, true, len(r.body) > 0) r = c.R().SetBodyXmlBytes(b) - tests.AssertEqual(t, true, len(r.body) > 0) + assertEqual(t, true, len(r.body) > 0) r = c.R().SetBodyJsonString(s) - tests.AssertEqual(t, true, len(r.body) > 0) + assertEqual(t, true, len(r.body) > 0) r = c.R().SetBodyJsonBytes(b) - tests.AssertEqual(t, true, len(r.body) > 0) + assertEqual(t, true, len(r.body) > 0) } diff --git a/response_test.go b/response_test.go index 141d509d..af7aa979 100644 --- a/response_test.go +++ b/response_test.go @@ -1,7 +1,6 @@ package req import ( - "github.com/imroc/req/v3/internal/tests" "testing" ) @@ -18,8 +17,8 @@ func TestUnmarshalJson(t *testing.T) { resp, err := tc().R().Get("/json") assertSuccess(t, resp, err) err = resp.UnmarshalJson(&user) - tests.AssertNoError(t, err) - tests.AssertEqual(t, "roc", user.Name) + assertNoError(t, err) + assertEqual(t, "roc", user.Name) } func TestUnmarshalXml(t *testing.T) { @@ -27,8 +26,8 @@ func TestUnmarshalXml(t *testing.T) { resp, err := tc().R().Get("/xml") assertSuccess(t, resp, err) err = resp.UnmarshalXml(&user) - tests.AssertNoError(t, err) - tests.AssertEqual(t, "roc", user.Name) + assertNoError(t, err) + assertEqual(t, "roc", user.Name) } func TestUnmarshal(t *testing.T) { @@ -36,8 +35,8 @@ func TestUnmarshal(t *testing.T) { resp, err := tc().R().Get("/xml") assertSuccess(t, resp, err) err = resp.Unmarshal(&user) - tests.AssertNoError(t, err) - tests.AssertEqual(t, "roc", user.Name) + assertNoError(t, err) + assertEqual(t, "roc", user.Name) } func TestResponseResult(t *testing.T) { @@ -46,10 +45,10 @@ func TestResponseResult(t *testing.T) { if !ok { t.Fatal("Response.Result() should return *User") } - tests.AssertEqual(t, "roc", user.Name) + assertEqual(t, "roc", user.Name) - tests.AssertEqual(t, true, resp.TotalTime() > 0) - tests.AssertEqual(t, false, resp.ReceivedAt().IsZero()) + assertEqual(t, true, resp.TotalTime() > 0) + assertEqual(t, false, resp.ReceivedAt().IsZero()) } func TestResponseError(t *testing.T) { @@ -58,5 +57,5 @@ func TestResponseError(t *testing.T) { if !ok { t.Fatal("Response.Error() should return *Message") } - tests.AssertEqual(t, "not allowed", msg.Message) + assertEqual(t, "not allowed", msg.Message) } diff --git a/retry_test.go b/retry_test.go index 74aa2665..afd69895 100644 --- a/retry_test.go +++ b/retry_test.go @@ -2,7 +2,6 @@ package req import ( "bytes" - "github.com/imroc/req/v3/internal/tests" "io/ioutil" "math" "net/http" @@ -28,9 +27,9 @@ func testRetry(t *testing.T, setFunc func(r *Request)) { }) setFunc(r) resp, err := r.Get("/too-many") - tests.AssertNoError(t, err) - tests.AssertEqual(t, 3, resp.Request.RetryAttempt) - tests.AssertEqual(t, 3, attempt) + assertNoError(t, err) + assertEqual(t, 3, resp.Request.RetryAttempt) + assertEqual(t, 3, attempt) } func TestRetryInterval(t *testing.T) { @@ -55,7 +54,7 @@ func TestAddRetryHook(t *testing.T) { test = "test2" }) }) - tests.AssertEqual(t, "test2", test) + assertEqual(t, "test2", test) } func TestRetryOverride(t *testing.T) { @@ -74,9 +73,9 @@ func TestRetryOverride(t *testing.T) { }).SetRetryCondition(func(resp *Response, err error) bool { return err != nil || resp.StatusCode == http.StatusTooManyRequests }).Get("/too-many") - tests.AssertNoError(t, err) - tests.AssertEqual(t, "test1", test) - tests.AssertEqual(t, 2, resp.Request.RetryAttempt) + assertNoError(t, err) + assertEqual(t, "test1", test) + assertEqual(t, 2, resp.Request.RetryAttempt) } func TestAddRetryCondition(t *testing.T) { @@ -92,9 +91,9 @@ func TestAddRetryCondition(t *testing.T) { SetRetryHook(func(resp *Response, err error) { attempt++ }).Get("/too-many") - tests.AssertNoError(t, err) - tests.AssertEqual(t, 0, attempt) - tests.AssertEqual(t, 0, resp.Request.RetryAttempt) + assertNoError(t, err) + assertEqual(t, 0, attempt) + assertEqual(t, 0, resp.Request.RetryAttempt) attempt = 0 resp, err = tc(). @@ -108,9 +107,9 @@ func TestAddRetryCondition(t *testing.T) { SetCommonRetryHook(func(resp *Response, err error) { attempt++ }).R().Get("/too-many") - tests.AssertNoError(t, err) - tests.AssertEqual(t, 0, attempt) - tests.AssertEqual(t, 0, resp.Request.RetryAttempt) + assertNoError(t, err) + assertEqual(t, 0, attempt) + assertEqual(t, 0, resp.Request.RetryAttempt) } @@ -119,13 +118,13 @@ func TestRetryWithUnreplayableBody(t *testing.T) { SetRetryCount(1). SetBody(bytes.NewBufferString("test")). Post("/") - tests.AssertEqual(t, errRetryableWithUnReplayableBody, err) + assertEqual(t, errRetryableWithUnReplayableBody, err) _, err = tc().R(). SetRetryCount(1). SetBody(ioutil.NopCloser(bytes.NewBufferString("test"))). Post("/") - tests.AssertEqual(t, errRetryableWithUnReplayableBody, err) + assertEqual(t, errRetryableWithUnReplayableBody, err) } func TestRetryWithSetResult(t *testing.T) { @@ -138,5 +137,5 @@ func TestRetryWithSetResult(t *testing.T) { SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - tests.AssertEqual(t, "test=test", headers.Get("Cookie")) + assertEqual(t, "test=test", headers.Get("Cookie")) } From eb84de7329ca828b21de01abbf93907ce432e463 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Mar 2022 21:19:34 +0800 Subject: [PATCH 437/843] extract request_wrapper.go --- request.go | 402 -------------------------------------------- request_wrapper.go | 411 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 411 insertions(+), 402 deletions(-) create mode 100644 request_wrapper.go diff --git a/request.go b/request.go index 7cf45c16..cb38674b 100644 --- a/request.go +++ b/request.go @@ -135,12 +135,6 @@ func (r *Request) TraceInfo() TraceInfo { return ti } -// SetFormDataFromValues is a global wrapper methods which delegated -// to the default client, create a request and SetFormDataFromValues for request. -func SetFormDataFromValues(data urlpkg.Values) *Request { - return defaultClient.R().SetFormDataFromValues(data) -} - // SetFormDataFromValues set the form data from url.Values, will not // been used if request method does not allow payload. func (r *Request) SetFormDataFromValues(data urlpkg.Values) *Request { @@ -155,12 +149,6 @@ func (r *Request) SetFormDataFromValues(data urlpkg.Values) *Request { return r } -// SetFormData is a global wrapper methods which delegated -// to the default client, create a request and SetFormData for request. -func SetFormData(data map[string]string) *Request { - return defaultClient.R().SetFormData(data) -} - // SetFormData set the form data from a map, will not been used // if request method does not allow payload. func (r *Request) SetFormData(data map[string]string) *Request { @@ -173,24 +161,12 @@ func (r *Request) SetFormData(data map[string]string) *Request { return r } -// SetCookies is a global wrapper methods which delegated -// to the default client, create a request and SetCookies for request. -func SetCookies(cookies ...*http.Cookie) *Request { - return defaultClient.R().SetCookies(cookies...) -} - // SetCookies set http cookies for the request. func (r *Request) SetCookies(cookies ...*http.Cookie) *Request { r.Cookies = append(r.Cookies, cookies...) return r } -// SetQueryString is a global wrapper methods which delegated -// to the default client, create a request and SetQueryString for request. -func SetQueryString(query string) *Request { - return defaultClient.R().SetQueryString(query) -} - // SetQueryString set URL query parameters for the request using // raw query string. func (r *Request) SetQueryString(query string) *Request { @@ -210,12 +186,6 @@ func (r *Request) SetQueryString(query string) *Request { return r } -// SetFileReader is a global wrapper methods which delegated -// to the default client, create a request and SetFileReader for request. -func SetFileReader(paramName, filePath string, reader io.Reader) *Request { - return defaultClient.R().SetFileReader(paramName, filePath, reader) -} - // SetFileReader set up a multipart form with a reader to upload file. func (r *Request) SetFileReader(paramName, filename string, reader io.Reader) *Request { r.SetFileUpload(FileUpload{ @@ -231,23 +201,11 @@ func (r *Request) SetFileReader(paramName, filename string, reader io.Reader) *R return r } -// SetFileBytes is a global wrapper methods which delegated -// to the default client, create a request and SetFileBytes for request. -func SetFileBytes(paramName, filename string, content []byte) *Request { - return defaultClient.R().SetFileBytes(paramName, filename, content) -} - // SetFileBytes set up a multipart form with given []byte to upload. func (r *Request) SetFileBytes(paramName, filename string, content []byte) *Request { return r.SetFileReader(paramName, filename, bytes.NewReader(content)) } -// SetFiles is a global wrapper methods which delegated -// to the default client, create a request and SetFiles for request. -func SetFiles(files map[string]string) *Request { - return defaultClient.R().SetFiles(files) -} - // SetFiles set up a multipart form from a map to upload, which // key is the parameter name, and value is the file path. func (r *Request) SetFiles(files map[string]string) *Request { @@ -257,12 +215,6 @@ func (r *Request) SetFiles(files map[string]string) *Request { return r } -// SetFile is a global wrapper methods which delegated -// to the default client, create a request and SetFile for request. -func SetFile(paramName, filePath string) *Request { - return defaultClient.R().SetFile(paramName, filePath) -} - // SetFile set up a multipart form from file path to upload, // which read file from filePath automatically to upload. func (r *Request) SetFile(paramName, filePath string) *Request { @@ -276,12 +228,6 @@ func (r *Request) SetFile(paramName, filePath string) *Request { return r.SetFileReader(paramName, filepath.Base(filePath), file) } -// SetFileUpload is a global wrapper methods which delegated -// to the default client, create a request and SetFileUpload for request. -func SetFileUpload(f ...FileUpload) *Request { - return defaultClient.R().SetFileUpload(f...) -} - var errMissingParamName = errors.New("missing param name in multipart file upload") var errMissingFileName = errors.New("missing filename in multipart file upload") var errMissingFileContent = errors.New("missing file content in multipart file upload") @@ -310,12 +256,6 @@ func (r *Request) SetFileUpload(uploads ...FileUpload) *Request { return r } -// SetResult is a global wrapper methods which delegated -// to the default client, create a request and SetResult for request. -func SetResult(result interface{}) *Request { - return defaultClient.R().SetResult(result) -} - // SetResult set the result that response body will be unmarshaled to if // request is success (status `code >= 200 and <= 299`). func (r *Request) SetResult(result interface{}) *Request { @@ -323,12 +263,6 @@ func (r *Request) SetResult(result interface{}) *Request { return r } -// SetError is a global wrapper methods which delegated -// to the default client, create a request and SetError for request. -func SetError(error interface{}) *Request { - return defaultClient.R().SetError(error) -} - // SetError set the result that response body will be unmarshaled to if // request is error ( status `code >= 400`). func (r *Request) SetError(error interface{}) *Request { @@ -336,34 +270,16 @@ func (r *Request) SetError(error interface{}) *Request { return r } -// SetBearerAuthToken is a global wrapper methods which delegated -// to the default client, create a request and SetBearerAuthToken for request. -func SetBearerAuthToken(token string) *Request { - return defaultClient.R().SetBearerAuthToken(token) -} - // SetBearerAuthToken set bearer auth token for the request. func (r *Request) SetBearerAuthToken(token string) *Request { return r.SetHeader("Authorization", "Bearer "+token) } -// SetBasicAuth is a global wrapper methods which delegated -// to the default client, create a request and SetBasicAuth for request. -func SetBasicAuth(username, password string) *Request { - return defaultClient.R().SetBasicAuth(username, password) -} - // SetBasicAuth set basic auth for the request. func (r *Request) SetBasicAuth(username, password string) *Request { return r.SetHeader("Authorization", util.BasicAuthHeaderValue(username, password)) } -// SetHeaders is a global wrapper methods which delegated -// to the default client, create a request and SetHeaders for request. -func SetHeaders(hdrs map[string]string) *Request { - return defaultClient.R().SetHeaders(hdrs) -} - // SetHeaders set headers from a map for the request. func (r *Request) SetHeaders(hdrs map[string]string) *Request { for k, v := range hdrs { @@ -372,12 +288,6 @@ func (r *Request) SetHeaders(hdrs map[string]string) *Request { return r } -// SetHeader is a global wrapper methods which delegated -// to the default client, create a request and SetHeader for request. -func SetHeader(key, value string) *Request { - return defaultClient.R().SetHeader(key, value) -} - // SetHeader set a header for the request. func (r *Request) SetHeader(key, value string) *Request { if r.Headers == nil { @@ -387,12 +297,6 @@ func (r *Request) SetHeader(key, value string) *Request { return r } -// SetOutputFile is a global wrapper methods which delegated -// to the default client, create a request and SetOutputFile for request. -func SetOutputFile(file string) *Request { - return defaultClient.R().SetOutputFile(file) -} - // SetOutputFile set the file that response body will be downloaded to. func (r *Request) SetOutputFile(file string) *Request { r.isSaveResponse = true @@ -400,12 +304,6 @@ func (r *Request) SetOutputFile(file string) *Request { return r } -// SetOutput is a global wrapper methods which delegated -// to the default client, create a request and SetOutput for request. -func SetOutput(output io.Writer) *Request { - return defaultClient.R().SetOutput(output) -} - // SetOutput set the io.Writer that response body will be downloaded to. func (r *Request) SetOutput(output io.Writer) *Request { if output == nil { @@ -417,12 +315,6 @@ func (r *Request) SetOutput(output io.Writer) *Request { return r } -// SetQueryParams is a global wrapper methods which delegated -// to the default client, create a request and SetQueryParams for request. -func SetQueryParams(params map[string]string) *Request { - return defaultClient.R().SetQueryParams(params) -} - // SetQueryParams set URL query parameters from a map for the request. func (r *Request) SetQueryParams(params map[string]string) *Request { for k, v := range params { @@ -431,12 +323,6 @@ func (r *Request) SetQueryParams(params map[string]string) *Request { return r } -// SetQueryParam is a global wrapper methods which delegated -// to the default client, create a request and SetQueryParam for request. -func SetQueryParam(key, value string) *Request { - return defaultClient.R().SetQueryParam(key, value) -} - // SetQueryParam set an URL query parameter for the request. func (r *Request) SetQueryParam(key, value string) *Request { if r.QueryParams == nil { @@ -446,12 +332,6 @@ func (r *Request) SetQueryParam(key, value string) *Request { return r } -// AddQueryParam is a global wrapper methods which delegated -// to the default client, create a request and AddQueryParam for request. -func AddQueryParam(key, value string) *Request { - return defaultClient.R().AddQueryParam(key, value) -} - // AddQueryParam add a URL query parameter for the request. func (r *Request) AddQueryParam(key, value string) *Request { if r.QueryParams == nil { @@ -461,12 +341,6 @@ func (r *Request) AddQueryParam(key, value string) *Request { return r } -// SetPathParams is a global wrapper methods which delegated -// to the default client, create a request and SetPathParams for request. -func SetPathParams(params map[string]string) *Request { - return defaultClient.R().SetPathParams(params) -} - // SetPathParams set URL path parameters from a map for the request. func (r *Request) SetPathParams(params map[string]string) *Request { for key, value := range params { @@ -475,12 +349,6 @@ func (r *Request) SetPathParams(params map[string]string) *Request { return r } -// SetPathParam is a global wrapper methods which delegated -// to the default client, create a request and SetPathParam for request. -func SetPathParam(key, value string) *Request { - return defaultClient.R().SetPathParam(key, value) -} - // SetPathParam set a URL path parameter for the request. func (r *Request) SetPathParam(key, value string) *Request { if r.PathParams == nil { @@ -513,12 +381,6 @@ func (r *Request) Send(method, url string) (*Response, error) { return r.client.do(r) } -// MustGet is a global wrapper methods which delegated -// to the default client, create a request and MustGet for request. -func MustGet(url string) *Response { - return defaultClient.R().MustGet(url) -} - // MustGet like Get, panic if error happens, should only be used to // test without error handling. func (r *Request) MustGet(url string) *Response { @@ -529,23 +391,11 @@ func (r *Request) MustGet(url string) *Response { return resp } -// Get is a global wrapper methods which delegated -// to the default client, create a request and Get for request. -func Get(url string) (*Response, error) { - return defaultClient.R().Get(url) -} - // Get fires http request with GET method and the specified URL. func (r *Request) Get(url string) (*Response, error) { return r.Send(http.MethodGet, url) } -// MustPost is a global wrapper methods which delegated -// to the default client, create a request and Get for request. -func MustPost(url string) *Response { - return defaultClient.R().MustPost(url) -} - // MustPost like Post, panic if error happens. should only be used to // test without error handling. func (r *Request) MustPost(url string) *Response { @@ -556,23 +406,11 @@ func (r *Request) MustPost(url string) *Response { return resp } -// Post is a global wrapper methods which delegated -// to the default client, create a request and Post for request. -func Post(url string) (*Response, error) { - return defaultClient.R().Post(url) -} - // Post fires http request with POST method and the specified URL. func (r *Request) Post(url string) (*Response, error) { return r.Send(http.MethodPost, url) } -// MustPut is a global wrapper methods which delegated -// to the default client, create a request and MustPut for request. -func MustPut(url string) *Response { - return defaultClient.R().MustPut(url) -} - // MustPut like Put, panic if error happens, should only be used to // test without error handling. func (r *Request) MustPut(url string) *Response { @@ -583,23 +421,11 @@ func (r *Request) MustPut(url string) *Response { return resp } -// Put is a global wrapper methods which delegated -// to the default client, create a request and Put for request. -func Put(url string) (*Response, error) { - return defaultClient.R().Put(url) -} - // Put fires http request with PUT method and the specified URL. func (r *Request) Put(url string) (*Response, error) { return r.Send(http.MethodPut, url) } -// MustPatch is a global wrapper methods which delegated -// to the default client, create a request and MustPatch for request. -func MustPatch(url string) *Response { - return defaultClient.R().MustPatch(url) -} - // MustPatch like Patch, panic if error happens, should only be used // to test without error handling. func (r *Request) MustPatch(url string) *Response { @@ -610,23 +436,11 @@ func (r *Request) MustPatch(url string) *Response { return resp } -// Patch is a global wrapper methods which delegated -// to the default client, create a request and Patch for request. -func Patch(url string) (*Response, error) { - return defaultClient.R().Patch(url) -} - // Patch fires http request with PATCH method and the specified URL. func (r *Request) Patch(url string) (*Response, error) { return r.Send(http.MethodPatch, url) } -// MustDelete is a global wrapper methods which delegated -// to the default client, create a request and MustDelete for request. -func MustDelete(url string) *Response { - return defaultClient.R().MustDelete(url) -} - // MustDelete like Delete, panic if error happens, should only be used // to test without error handling. func (r *Request) MustDelete(url string) *Response { @@ -637,23 +451,11 @@ func (r *Request) MustDelete(url string) *Response { return resp } -// Delete is a global wrapper methods which delegated -// to the default client, create a request and Delete for request. -func Delete(url string) (*Response, error) { - return defaultClient.R().Delete(url) -} - // Delete fires http request with DELETE method and the specified URL. func (r *Request) Delete(url string) (*Response, error) { return r.Send(http.MethodDelete, url) } -// MustOptions is a global wrapper methods which delegated -// to the default client, create a request and MustOptions for request. -func MustOptions(url string) *Response { - return defaultClient.R().MustOptions(url) -} - // MustOptions like Options, panic if error happens, should only be // used to test without error handling. func (r *Request) MustOptions(url string) *Response { @@ -664,23 +466,11 @@ func (r *Request) MustOptions(url string) *Response { return resp } -// Options is a global wrapper methods which delegated -// to the default client, create a request and Options for request. -func Options(url string) (*Response, error) { - return defaultClient.R().Options(url) -} - // Options fires http request with OPTIONS method and the specified URL. func (r *Request) Options(url string) (*Response, error) { return r.Send(http.MethodOptions, url) } -// MustHead is a global wrapper methods which delegated -// to the default client, create a request and MustHead for request. -func MustHead(url string) *Response { - return defaultClient.R().MustHead(url) -} - // MustHead like Head, panic if error happens, should only be used // to test without error handling. func (r *Request) MustHead(url string) *Response { @@ -691,23 +481,11 @@ func (r *Request) MustHead(url string) *Response { return resp } -// Head is a global wrapper methods which delegated -// to the default client, create a request and Head for request. -func Head(url string) (*Response, error) { - return defaultClient.R().Head(url) -} - // Head fires http request with HEAD method and the specified URL. func (r *Request) Head(url string) (*Response, error) { return r.Send(http.MethodHead, url) } -// SetBody is a global wrapper methods which delegated -// to the default client, create a request and SetBody for request. -func SetBody(body interface{}) *Request { - return defaultClient.R().SetBody(body) -} - // SetBody set the request body, accepts string, []byte, io.Reader, map and struct. func (r *Request) SetBody(body interface{}) *Request { if body == nil { @@ -738,12 +516,6 @@ func (r *Request) SetBody(body interface{}) *Request { return r } -// SetBodyBytes is a global wrapper methods which delegated -// to the default client, create a request and SetBodyBytes for request. -func SetBodyBytes(body []byte) *Request { - return defaultClient.R().SetBodyBytes(body) -} - // SetBodyBytes set the request body as []byte. func (r *Request) SetBodyBytes(body []byte) *Request { r.body = body @@ -753,35 +525,17 @@ func (r *Request) SetBodyBytes(body []byte) *Request { return r } -// SetBodyString is a global wrapper methods which delegated -// to the default client, create a request and SetBodyString for request. -func SetBodyString(body string) *Request { - return defaultClient.R().SetBodyString(body) -} - // SetBodyString set the request body as string. func (r *Request) SetBodyString(body string) *Request { return r.SetBodyBytes([]byte(body)) } -// SetBodyJsonString is a global wrapper methods which delegated -// to the default client, create a request and SetBodyJsonString for request. -func SetBodyJsonString(body string) *Request { - return defaultClient.R().SetBodyJsonString(body) -} - // SetBodyJsonString set the request body as string and set Content-Type header // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonString(body string) *Request { return r.SetBodyJsonBytes([]byte(body)) } -// SetBodyJsonBytes is a global wrapper methods which delegated -// to the default client, create a request and SetBodyJsonBytes for request. -func SetBodyJsonBytes(body []byte) *Request { - return defaultClient.R().SetBodyJsonBytes(body) -} - // SetBodyJsonBytes set the request body as []byte and set Content-Type header // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonBytes(body []byte) *Request { @@ -789,12 +543,6 @@ func (r *Request) SetBodyJsonBytes(body []byte) *Request { return r.SetBodyBytes(body) } -// SetBodyJsonMarshal is a global wrapper methods which delegated -// to the default client, create a request and SetBodyJsonMarshal for request. -func SetBodyJsonMarshal(v interface{}) *Request { - return defaultClient.R().SetBodyJsonMarshal(v) -} - // SetBodyJsonMarshal set the request body that marshaled from object, and // set Content-Type header as "application/json; charset=utf-8" func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { @@ -806,24 +554,12 @@ func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { return r.SetBodyJsonBytes(b) } -// SetBodyXmlString is a global wrapper methods which delegated -// to the default client, create a request and SetBodyXmlString for request. -func SetBodyXmlString(body string) *Request { - return defaultClient.R().SetBodyXmlString(body) -} - // SetBodyXmlString set the request body as string and set Content-Type header // as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlString(body string) *Request { return r.SetBodyXmlBytes([]byte(body)) } -// SetBodyXmlBytes is a global wrapper methods which delegated -// to the default client, create a request and SetBodyXmlBytes for request. -func SetBodyXmlBytes(body []byte) *Request { - return defaultClient.R().SetBodyXmlBytes(body) -} - // SetBodyXmlBytes set the request body as []byte and set Content-Type header // as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlBytes(body []byte) *Request { @@ -831,12 +567,6 @@ func (r *Request) SetBodyXmlBytes(body []byte) *Request { return r.SetBodyBytes(body) } -// SetBodyXmlMarshal is a global wrapper methods which delegated -// to the default client, create a request and SetBodyXmlMarshal for request. -func SetBodyXmlMarshal(v interface{}) *Request { - return defaultClient.R().SetBodyXmlMarshal(v) -} - // SetBodyXmlMarshal set the request body that marshaled from object, and // set Content-Type header as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { @@ -848,12 +578,6 @@ func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { return r.SetBodyXmlBytes(b) } -// SetContentType is a global wrapper methods which delegated -// to the default client, create a request and SetContentType for request. -func SetContentType(contentType string) *Request { - return defaultClient.R().SetContentType(contentType) -} - // SetContentType set the `Content-Type` for the request. func (r *Request) SetContentType(contentType string) *Request { return r.SetHeader(hdrContentTypeKey, contentType) @@ -868,12 +592,6 @@ func (r *Request) Context() context.Context { return r.ctx } -// SetContext is a global wrapper methods which delegated -// to the default client, create a request and SetContext for request. -func SetContext(ctx context.Context) *Request { - return defaultClient.R().SetContext(ctx) -} - // SetContext method sets the context.Context for current Request. It allows // to interrupt the request execution if ctx.Done() channel is closed. // See https://blog.golang.org/context article and the "context" package @@ -883,24 +601,12 @@ func (r *Request) SetContext(ctx context.Context) *Request { return r } -// DisableTrace is a global wrapper methods which delegated -// to the default client, create a request and DisableTrace for request. -func DisableTrace() *Request { - return defaultClient.R().DisableTrace() -} - // DisableTrace disables trace. func (r *Request) DisableTrace() *Request { r.trace = nil return r } -// EnableTrace is a global wrapper methods which delegated -// to the default client, create a request and EnableTrace for request. -func EnableTrace() *Request { - return defaultClient.R().EnableTrace() -} - // EnableTrace enables trace. func (r *Request) EnableTrace() *Request { if r.trace == nil { @@ -929,24 +635,12 @@ func (r *Request) getDumpOptions() *DumpOptions { return r.dumpOptions } -// EnableDumpTo is a global wrapper methods which delegated -// to the default client, create a request and EnableDumpTo for request. -func EnableDumpTo(output io.Writer) *Request { - return defaultClient.R().EnableDumpTo(output) -} - // EnableDumpTo enables dump and save to the specified io.Writer. func (r *Request) EnableDumpTo(output io.Writer) *Request { r.getDumpOptions().Output = output return r.EnableDump() } -// EnableDumpToFile is a global wrapper methods which delegated -// to the default client, create a request and EnableDumpToFile for request. -func EnableDumpToFile(filename string) *Request { - return defaultClient.R().EnableDumpToFile(filename) -} - // EnableDumpToFile enables dump and save to the specified filename. func (r *Request) EnableDumpToFile(filename string) *Request { file, err := os.Create(filename) @@ -958,12 +652,6 @@ func (r *Request) EnableDumpToFile(filename string) *Request { return r.EnableDump() } -// SetDumpOptions is a global wrapper methods which delegated -// to the default client, create a request and SetDumpOptions for request. -func SetDumpOptions(opt *DumpOptions) *Request { - return defaultClient.R().SetDumpOptions(opt) -} - // SetDumpOptions sets DumpOptions at request level. func (r *Request) SetDumpOptions(opt *DumpOptions) *Request { if opt == nil { @@ -976,23 +664,11 @@ func (r *Request) SetDumpOptions(opt *DumpOptions) *Request { return r } -// EnableDump is a global wrapper methods which delegated -// to the default client, create a request and EnableDump for request. -func EnableDump() *Request { - return defaultClient.R().EnableDump() -} - // EnableDump enables dump, including all content for the request and response by default. func (r *Request) EnableDump() *Request { return r.SetContext(context.WithValue(r.Context(), dumperKey, newDumper(r.getDumpOptions()))) } -// EnableDumpWithoutBody is a global wrapper methods which delegated -// to the default client, create a request and EnableDumpWithoutBody for request. -func EnableDumpWithoutBody() *Request { - return defaultClient.R().EnableDumpWithoutBody() -} - // EnableDumpWithoutBody enables dump only header for the request and response. func (r *Request) EnableDumpWithoutBody() *Request { o := r.getDumpOptions() @@ -1001,12 +677,6 @@ func (r *Request) EnableDumpWithoutBody() *Request { return r.EnableDump() } -// EnableDumpWithoutHeader is a global wrapper methods which delegated -// to the default client, create a request and EnableDumpWithoutHeader for request. -func EnableDumpWithoutHeader() *Request { - return defaultClient.R().EnableDumpWithoutHeader() -} - // EnableDumpWithoutHeader enables dump only body for the request and response. func (r *Request) EnableDumpWithoutHeader() *Request { o := r.getDumpOptions() @@ -1015,12 +685,6 @@ func (r *Request) EnableDumpWithoutHeader() *Request { return r.EnableDump() } -// EnableDumpWithoutResponse is a global wrapper methods which delegated -// to the default client, create a request and EnableDumpWithoutResponse for request. -func EnableDumpWithoutResponse() *Request { - return defaultClient.R().EnableDumpWithoutResponse() -} - // EnableDumpWithoutResponse enables dump only request. func (r *Request) EnableDumpWithoutResponse() *Request { o := r.getDumpOptions() @@ -1029,12 +693,6 @@ func (r *Request) EnableDumpWithoutResponse() *Request { return r.EnableDump() } -// EnableDumpWithoutRequest is a global wrapper methods which delegated -// to the default client, create a request and EnableDumpWithoutRequest for request. -func EnableDumpWithoutRequest() *Request { - return defaultClient.R().EnableDumpWithoutRequest() -} - // EnableDumpWithoutRequest enables dump only response. func (r *Request) EnableDumpWithoutRequest() *Request { o := r.getDumpOptions() @@ -1043,12 +701,6 @@ func (r *Request) EnableDumpWithoutRequest() *Request { return r.EnableDump() } -// EnableDumpWithoutRequestBody is a global wrapper methods which delegated -// to the default client, create a request and EnableDumpWithoutRequestBody for request. -func EnableDumpWithoutRequestBody() *Request { - return defaultClient.R().EnableDumpWithoutRequestBody() -} - // EnableDumpWithoutRequestBody enables dump with request body excluded, // can be used in upload request to avoid dump the unreadable binary content. func (r *Request) EnableDumpWithoutRequestBody() *Request { @@ -1057,12 +709,6 @@ func (r *Request) EnableDumpWithoutRequestBody() *Request { return r.EnableDump() } -// EnableDumpWithoutResponseBody is a global wrapper methods which delegated -// to the default client, create a request and EnableDumpWithoutResponseBody for request. -func EnableDumpWithoutResponseBody() *Request { - return defaultClient.R().EnableDumpWithoutResponseBody() -} - // EnableDumpWithoutResponseBody enables dump with response body excluded, // can be used in download request to avoid dump the unreadable binary content. func (r *Request) EnableDumpWithoutResponseBody() *Request { @@ -1078,24 +724,12 @@ func (r *Request) getRetryOption() *retryOption { return r.retryOption } -// SetRetryCount is a global wrapper methods which delegated -// to the default client, create a request and SetRetryCount for request. -func SetRetryCount(count int) *Request { - return defaultClient.R().SetRetryCount(count) -} - // SetRetryCount enables retry and set the maximum retry count. func (r *Request) SetRetryCount(count int) *Request { r.getRetryOption().MaxRetries = count return r } -// SetRetryInterval is a global wrapper methods which delegated -// to the default client, create a request and SetRetryInterval for request. -func SetRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Request { - return defaultClient.R().SetRetryInterval(getRetryIntervalFunc) -} - // SetRetryInterval sets the custom GetRetryIntervalFunc, you can use this to // implement your own backoff retry algorithm. // For example: @@ -1108,12 +742,6 @@ func (r *Request) SetRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *R return r } -// SetRetryFixedInterval is a global wrapper methods which delegated -// to the default client, create a request and SetRetryFixedInterval for request. -func SetRetryFixedInterval(interval time.Duration) *Request { - return defaultClient.R().SetRetryFixedInterval(interval) -} - // SetRetryFixedInterval set retry to use a fixed interval. func (r *Request) SetRetryFixedInterval(interval time.Duration) *Request { r.getRetryOption().GetRetryInterval = func(resp *Response, attempt int) time.Duration { @@ -1122,12 +750,6 @@ func (r *Request) SetRetryFixedInterval(interval time.Duration) *Request { return r } -// SetRetryBackoffInterval is a global wrapper methods which delegated -// to the default client, create a request and SetRetryBackoffInterval for request. -func SetRetryBackoffInterval(min, max time.Duration) *Request { - return defaultClient.R().SetRetryBackoffInterval(min, max) -} - // SetRetryBackoffInterval set retry to use a capped exponential backoff with jitter. // https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ func (r *Request) SetRetryBackoffInterval(min, max time.Duration) *Request { @@ -1135,12 +757,6 @@ func (r *Request) SetRetryBackoffInterval(min, max time.Duration) *Request { return r } -// SetRetryHook is a global wrapper methods which delegated -// to the default client, create a request and SetRetryHook for request. -func SetRetryHook(hook RetryHookFunc) *Request { - return defaultClient.R().SetRetryHook(hook) -} - // SetRetryHook set the retry hook which will be executed before a retry. // It will override other retry hooks if any been added before (including // client-level retry hooks). @@ -1149,12 +765,6 @@ func (r *Request) SetRetryHook(hook RetryHookFunc) *Request { return r } -// AddRetryHook is a global wrapper methods which delegated -// to the default client, create a request and AddRetryHook for request. -func AddRetryHook(hook RetryHookFunc) *Request { - return defaultClient.R().AddRetryHook(hook) -} - // AddRetryHook adds a retry hook which will be executed before a retry. func (r *Request) AddRetryHook(hook RetryHookFunc) *Request { ro := r.getRetryOption() @@ -1162,12 +772,6 @@ func (r *Request) AddRetryHook(hook RetryHookFunc) *Request { return r } -// SetRetryCondition is a global wrapper methods which delegated -// to the default client, create a request and SetRetryCondition for request. -func SetRetryCondition(condition RetryConditionFunc) *Request { - return defaultClient.R().SetRetryCondition(condition) -} - // SetRetryCondition sets the retry condition, which determines whether the // request should retry. // It will override other retry conditions if any been added before (including @@ -1177,12 +781,6 @@ func (r *Request) SetRetryCondition(condition RetryConditionFunc) *Request { return r } -// AddRetryCondition is a global wrapper methods which delegated -// to the default client, create a request and AddRetryCondition for request. -func AddRetryCondition(condition RetryConditionFunc) *Request { - return defaultClient.R().AddRetryCondition(condition) -} - // AddRetryCondition adds a retry condition, which determines whether the // request should retry. func (r *Request) AddRetryCondition(condition RetryConditionFunc) *Request { diff --git a/request_wrapper.go b/request_wrapper.go new file mode 100644 index 00000000..b6fef7bc --- /dev/null +++ b/request_wrapper.go @@ -0,0 +1,411 @@ +package req + +import ( + "context" + "io" + "net/http" + "net/url" + "time" +) + +// SetFormDataFromValues is a global wrapper methods which delegated +// to the default client, create a request and SetFormDataFromValues for request. +func SetFormDataFromValues(data url.Values) *Request { + return defaultClient.R().SetFormDataFromValues(data) +} + +// SetFormData is a global wrapper methods which delegated +// to the default client, create a request and SetFormData for request. +func SetFormData(data map[string]string) *Request { + return defaultClient.R().SetFormData(data) +} + +// SetCookies is a global wrapper methods which delegated +// to the default client, create a request and SetCookies for request. +func SetCookies(cookies ...*http.Cookie) *Request { + return defaultClient.R().SetCookies(cookies...) +} + +// SetQueryString is a global wrapper methods which delegated +// to the default client, create a request and SetQueryString for request. +func SetQueryString(query string) *Request { + return defaultClient.R().SetQueryString(query) +} + +// SetFileReader is a global wrapper methods which delegated +// to the default client, create a request and SetFileReader for request. +func SetFileReader(paramName, filePath string, reader io.Reader) *Request { + return defaultClient.R().SetFileReader(paramName, filePath, reader) +} + +// SetFileBytes is a global wrapper methods which delegated +// to the default client, create a request and SetFileBytes for request. +func SetFileBytes(paramName, filename string, content []byte) *Request { + return defaultClient.R().SetFileBytes(paramName, filename, content) +} + +// SetFiles is a global wrapper methods which delegated +// to the default client, create a request and SetFiles for request. +func SetFiles(files map[string]string) *Request { + return defaultClient.R().SetFiles(files) +} + +// SetFile is a global wrapper methods which delegated +// to the default client, create a request and SetFile for request. +func SetFile(paramName, filePath string) *Request { + return defaultClient.R().SetFile(paramName, filePath) +} + +// SetFileUpload is a global wrapper methods which delegated +// to the default client, create a request and SetFileUpload for request. +func SetFileUpload(f ...FileUpload) *Request { + return defaultClient.R().SetFileUpload(f...) +} + +// SetResult is a global wrapper methods which delegated +// to the default client, create a request and SetResult for request. +func SetResult(result interface{}) *Request { + return defaultClient.R().SetResult(result) +} + +// SetError is a global wrapper methods which delegated +// to the default client, create a request and SetError for request. +func SetError(error interface{}) *Request { + return defaultClient.R().SetError(error) +} + +// SetBearerAuthToken is a global wrapper methods which delegated +// to the default client, create a request and SetBearerAuthToken for request. +func SetBearerAuthToken(token string) *Request { + return defaultClient.R().SetBearerAuthToken(token) +} + +// SetBasicAuth is a global wrapper methods which delegated +// to the default client, create a request and SetBasicAuth for request. +func SetBasicAuth(username, password string) *Request { + return defaultClient.R().SetBasicAuth(username, password) +} + +// SetHeaders is a global wrapper methods which delegated +// to the default client, create a request and SetHeaders for request. +func SetHeaders(hdrs map[string]string) *Request { + return defaultClient.R().SetHeaders(hdrs) +} + +// SetHeader is a global wrapper methods which delegated +// to the default client, create a request and SetHeader for request. +func SetHeader(key, value string) *Request { + return defaultClient.R().SetHeader(key, value) +} + +// SetOutputFile is a global wrapper methods which delegated +// to the default client, create a request and SetOutputFile for request. +func SetOutputFile(file string) *Request { + return defaultClient.R().SetOutputFile(file) +} + +// SetOutput is a global wrapper methods which delegated +// to the default client, create a request and SetOutput for request. +func SetOutput(output io.Writer) *Request { + return defaultClient.R().SetOutput(output) +} + +// SetQueryParams is a global wrapper methods which delegated +// to the default client, create a request and SetQueryParams for request. +func SetQueryParams(params map[string]string) *Request { + return defaultClient.R().SetQueryParams(params) +} + +// SetQueryParam is a global wrapper methods which delegated +// to the default client, create a request and SetQueryParam for request. +func SetQueryParam(key, value string) *Request { + return defaultClient.R().SetQueryParam(key, value) +} + +// AddQueryParam is a global wrapper methods which delegated +// to the default client, create a request and AddQueryParam for request. +func AddQueryParam(key, value string) *Request { + return defaultClient.R().AddQueryParam(key, value) +} + +// SetPathParams is a global wrapper methods which delegated +// to the default client, create a request and SetPathParams for request. +func SetPathParams(params map[string]string) *Request { + return defaultClient.R().SetPathParams(params) +} + +// SetPathParam is a global wrapper methods which delegated +// to the default client, create a request and SetPathParam for request. +func SetPathParam(key, value string) *Request { + return defaultClient.R().SetPathParam(key, value) +} + +// MustGet is a global wrapper methods which delegated +// to the default client, create a request and MustGet for request. +func MustGet(url string) *Response { + return defaultClient.R().MustGet(url) +} + +// Get is a global wrapper methods which delegated +// to the default client, create a request and Get for request. +func Get(url string) (*Response, error) { + return defaultClient.R().Get(url) +} + +// MustPost is a global wrapper methods which delegated +// to the default client, create a request and Get for request. +func MustPost(url string) *Response { + return defaultClient.R().MustPost(url) +} + +// Post is a global wrapper methods which delegated +// to the default client, create a request and Post for request. +func Post(url string) (*Response, error) { + return defaultClient.R().Post(url) +} + +// MustPut is a global wrapper methods which delegated +// to the default client, create a request and MustPut for request. +func MustPut(url string) *Response { + return defaultClient.R().MustPut(url) +} + +// Put is a global wrapper methods which delegated +// to the default client, create a request and Put for request. +func Put(url string) (*Response, error) { + return defaultClient.R().Put(url) +} + +// MustPatch is a global wrapper methods which delegated +// to the default client, create a request and MustPatch for request. +func MustPatch(url string) *Response { + return defaultClient.R().MustPatch(url) +} + +// Patch is a global wrapper methods which delegated +// to the default client, create a request and Patch for request. +func Patch(url string) (*Response, error) { + return defaultClient.R().Patch(url) +} + +// MustDelete is a global wrapper methods which delegated +// to the default client, create a request and MustDelete for request. +func MustDelete(url string) *Response { + return defaultClient.R().MustDelete(url) +} + +// Delete is a global wrapper methods which delegated +// to the default client, create a request and Delete for request. +func Delete(url string) (*Response, error) { + return defaultClient.R().Delete(url) +} + +// MustOptions is a global wrapper methods which delegated +// to the default client, create a request and MustOptions for request. +func MustOptions(url string) *Response { + return defaultClient.R().MustOptions(url) +} + +// Options is a global wrapper methods which delegated +// to the default client, create a request and Options for request. +func Options(url string) (*Response, error) { + return defaultClient.R().Options(url) +} + +// MustHead is a global wrapper methods which delegated +// to the default client, create a request and MustHead for request. +func MustHead(url string) *Response { + return defaultClient.R().MustHead(url) +} + +// Head is a global wrapper methods which delegated +// to the default client, create a request and Head for request. +func Head(url string) (*Response, error) { + return defaultClient.R().Head(url) +} + +// SetBody is a global wrapper methods which delegated +// to the default client, create a request and SetBody for request. +func SetBody(body interface{}) *Request { + return defaultClient.R().SetBody(body) +} + +// SetBodyBytes is a global wrapper methods which delegated +// to the default client, create a request and SetBodyBytes for request. +func SetBodyBytes(body []byte) *Request { + return defaultClient.R().SetBodyBytes(body) +} + +// SetBodyString is a global wrapper methods which delegated +// to the default client, create a request and SetBodyString for request. +func SetBodyString(body string) *Request { + return defaultClient.R().SetBodyString(body) +} + +// SetBodyJsonString is a global wrapper methods which delegated +// to the default client, create a request and SetBodyJsonString for request. +func SetBodyJsonString(body string) *Request { + return defaultClient.R().SetBodyJsonString(body) +} + +// SetBodyJsonBytes is a global wrapper methods which delegated +// to the default client, create a request and SetBodyJsonBytes for request. +func SetBodyJsonBytes(body []byte) *Request { + return defaultClient.R().SetBodyJsonBytes(body) +} + +// SetBodyJsonMarshal is a global wrapper methods which delegated +// to the default client, create a request and SetBodyJsonMarshal for request. +func SetBodyJsonMarshal(v interface{}) *Request { + return defaultClient.R().SetBodyJsonMarshal(v) +} + +// SetBodyXmlString is a global wrapper methods which delegated +// to the default client, create a request and SetBodyXmlString for request. +func SetBodyXmlString(body string) *Request { + return defaultClient.R().SetBodyXmlString(body) +} + +// SetBodyXmlBytes is a global wrapper methods which delegated +// to the default client, create a request and SetBodyXmlBytes for request. +func SetBodyXmlBytes(body []byte) *Request { + return defaultClient.R().SetBodyXmlBytes(body) +} + +// SetBodyXmlMarshal is a global wrapper methods which delegated +// to the default client, create a request and SetBodyXmlMarshal for request. +func SetBodyXmlMarshal(v interface{}) *Request { + return defaultClient.R().SetBodyXmlMarshal(v) +} + +// SetContentType is a global wrapper methods which delegated +// to the default client, create a request and SetContentType for request. +func SetContentType(contentType string) *Request { + return defaultClient.R().SetContentType(contentType) +} + +// SetContext is a global wrapper methods which delegated +// to the default client, create a request and SetContext for request. +func SetContext(ctx context.Context) *Request { + return defaultClient.R().SetContext(ctx) +} + +// DisableTrace is a global wrapper methods which delegated +// to the default client, create a request and DisableTrace for request. +func DisableTrace() *Request { + return defaultClient.R().DisableTrace() +} + +// EnableTrace is a global wrapper methods which delegated +// to the default client, create a request and EnableTrace for request. +func EnableTrace() *Request { + return defaultClient.R().EnableTrace() +} + +// EnableDumpTo is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpTo for request. +func EnableDumpTo(output io.Writer) *Request { + return defaultClient.R().EnableDumpTo(output) +} + +// EnableDumpToFile is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpToFile for request. +func EnableDumpToFile(filename string) *Request { + return defaultClient.R().EnableDumpToFile(filename) +} + +// SetDumpOptions is a global wrapper methods which delegated +// to the default client, create a request and SetDumpOptions for request. +func SetDumpOptions(opt *DumpOptions) *Request { + return defaultClient.R().SetDumpOptions(opt) +} + +// EnableDump is a global wrapper methods which delegated +// to the default client, create a request and EnableDump for request. +func EnableDump() *Request { + return defaultClient.R().EnableDump() +} + +// EnableDumpWithoutBody is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutBody for request. +func EnableDumpWithoutBody() *Request { + return defaultClient.R().EnableDumpWithoutBody() +} + +// EnableDumpWithoutHeader is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutHeader for request. +func EnableDumpWithoutHeader() *Request { + return defaultClient.R().EnableDumpWithoutHeader() +} + +// EnableDumpWithoutResponse is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutResponse for request. +func EnableDumpWithoutResponse() *Request { + return defaultClient.R().EnableDumpWithoutResponse() +} + +// EnableDumpWithoutRequest is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutRequest for request. +func EnableDumpWithoutRequest() *Request { + return defaultClient.R().EnableDumpWithoutRequest() +} + +// EnableDumpWithoutRequestBody is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutRequestBody for request. +func EnableDumpWithoutRequestBody() *Request { + return defaultClient.R().EnableDumpWithoutRequestBody() +} + +// EnableDumpWithoutResponseBody is a global wrapper methods which delegated +// to the default client, create a request and EnableDumpWithoutResponseBody for request. +func EnableDumpWithoutResponseBody() *Request { + return defaultClient.R().EnableDumpWithoutResponseBody() +} + +// SetRetryCount is a global wrapper methods which delegated +// to the default client, create a request and SetRetryCount for request. +func SetRetryCount(count int) *Request { + return defaultClient.R().SetRetryCount(count) +} + +// SetRetryInterval is a global wrapper methods which delegated +// to the default client, create a request and SetRetryInterval for request. +func SetRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Request { + return defaultClient.R().SetRetryInterval(getRetryIntervalFunc) +} + +// SetRetryFixedInterval is a global wrapper methods which delegated +// to the default client, create a request and SetRetryFixedInterval for request. +func SetRetryFixedInterval(interval time.Duration) *Request { + return defaultClient.R().SetRetryFixedInterval(interval) +} + +// SetRetryBackoffInterval is a global wrapper methods which delegated +// to the default client, create a request and SetRetryBackoffInterval for request. +func SetRetryBackoffInterval(min, max time.Duration) *Request { + return defaultClient.R().SetRetryBackoffInterval(min, max) +} + +// SetRetryHook is a global wrapper methods which delegated +// to the default client, create a request and SetRetryHook for request. +func SetRetryHook(hook RetryHookFunc) *Request { + return defaultClient.R().SetRetryHook(hook) +} + +// AddRetryHook is a global wrapper methods which delegated +// to the default client, create a request and AddRetryHook for request. +func AddRetryHook(hook RetryHookFunc) *Request { + return defaultClient.R().AddRetryHook(hook) +} + +// SetRetryCondition is a global wrapper methods which delegated +// to the default client, create a request and SetRetryCondition for request. +func SetRetryCondition(condition RetryConditionFunc) *Request { + return defaultClient.R().SetRetryCondition(condition) +} + +// AddRetryCondition is a global wrapper methods which delegated +// to the default client, create a request and AddRetryCondition for request. +func AddRetryCondition(condition RetryConditionFunc) *Request { + return defaultClient.R().AddRetryCondition(condition) +} From aea2dab2a983c3c0ef175a9ba24f1e78cc6d4709 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Mar 2022 21:30:49 +0800 Subject: [PATCH 438/843] extract client_wrapper.go --- client.go | 516 --------------------------------------------- client_wrapper.go | 527 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 527 insertions(+), 516 deletions(-) create mode 100644 client_wrapper.go diff --git a/client.go b/client.go index 07f378f2..faa09c68 100644 --- a/client.go +++ b/client.go @@ -64,12 +64,6 @@ type Client struct { afterResponse []ResponseMiddleware } -// R is a global wrapper methods which delegated -// to the default client's R(). -func R() *Request { - return defaultClient.R() -} - // R create a new request. func (c *Client) R() *Request { return &Request{ @@ -78,12 +72,6 @@ func (c *Client) R() *Request { } } -// SetCommonFormDataFromValues is a global wrapper methods which delegated -// to the default client's SetCommonFormDataFromValues. -func SetCommonFormDataFromValues(data urlpkg.Values) *Client { - return defaultClient.SetCommonFormDataFromValues(data) -} - // SetCommonFormDataFromValues set the form data from url.Values for all requests // which request method allows payload. func (c *Client) SetCommonFormDataFromValues(data urlpkg.Values) *Client { @@ -98,12 +86,6 @@ func (c *Client) SetCommonFormDataFromValues(data urlpkg.Values) *Client { return c } -// SetCommonFormData is a global wrapper methods which delegated -// to the default client's SetCommonFormData. -func SetCommonFormData(data map[string]string) *Client { - return defaultClient.SetCommonFormData(data) -} - // SetCommonFormData set the form data from map for all requests // which request method allows payload. func (c *Client) SetCommonFormData(data map[string]string) *Client { @@ -116,12 +98,6 @@ func (c *Client) SetCommonFormData(data map[string]string) *Client { return c } -// SetBaseURL is a global wrapper methods which delegated -// to the default client's SetBaseURL. -func SetBaseURL(u string) *Client { - return defaultClient.SetBaseURL(u) -} - // SetBaseURL set the default base URL, will be used if request URL is // a relative URL. func (c *Client) SetBaseURL(u string) *Client { @@ -129,12 +105,6 @@ func (c *Client) SetBaseURL(u string) *Client { return c } -// SetOutputDirectory is a global wrapper methods which delegated -// to the default client's SetOutputDirectory. -func SetOutputDirectory(dir string) *Client { - return defaultClient.SetOutputDirectory(dir) -} - // SetOutputDirectory set output directory that response will // be downloaded to. func (c *Client) SetOutputDirectory(dir string) *Client { @@ -142,12 +112,6 @@ func (c *Client) SetOutputDirectory(dir string) *Client { return c } -// SetCertFromFile is a global wrapper methods which delegated -// to the default client's SetCertFromFile. -func SetCertFromFile(certFile, keyFile string) *Client { - return defaultClient.SetCertFromFile(certFile, keyFile) -} - // SetCertFromFile helps to set client certificates from cert and key file. func (c *Client) SetCertFromFile(certFile, keyFile string) *Client { cert, err := tls.LoadX509KeyPair(certFile, keyFile) @@ -160,12 +124,6 @@ func (c *Client) SetCertFromFile(certFile, keyFile string) *Client { return c } -// SetCerts is a global wrapper methods which delegated -// to the default client's SetCerts. -func SetCerts(certs ...tls.Certificate) *Client { - return defaultClient.SetCerts(certs...) -} - // SetCerts set client certificates. func (c *Client) SetCerts(certs ...tls.Certificate) *Client { config := c.GetTLSClientConfig() @@ -182,24 +140,12 @@ func (c *Client) appendRootCertData(data []byte) { return } -// SetRootCertFromString is a global wrapper methods which delegated -// to the default client's SetRootCertFromString. -func SetRootCertFromString(pemContent string) *Client { - return defaultClient.SetRootCertFromString(pemContent) -} - // SetRootCertFromString set root certificates from string. func (c *Client) SetRootCertFromString(pemContent string) *Client { c.appendRootCertData([]byte(pemContent)) return c } -// SetRootCertsFromFile is a global wrapper methods which delegated -// to the default client's SetRootCertsFromFile. -func SetRootCertsFromFile(pemFiles ...string) *Client { - return defaultClient.SetRootCertsFromFile(pemFiles...) -} - // SetRootCertsFromFile set root certificates from files. func (c *Client) SetRootCertsFromFile(pemFiles ...string) *Client { for _, pemFile := range pemFiles { @@ -213,12 +159,6 @@ func (c *Client) SetRootCertsFromFile(pemFiles ...string) *Client { return c } -// GetTLSClientConfig is a global wrapper methods which delegated -// to the default client's GetTLSClientConfig. -func GetTLSClientConfig() *tls.Config { - return defaultClient.GetTLSClientConfig() -} - // GetTLSClientConfig return the underlying tls.Config. func (c *Client) GetTLSClientConfig() *tls.Config { if c.t.TLSClientConfig == nil { @@ -239,12 +179,6 @@ func (c *Client) defaultCheckRedirect(req *http.Request, via []*http.Request) er return nil } -// SetRedirectPolicy is a global wrapper methods which delegated -// to the default client's SetRedirectPolicy. -func SetRedirectPolicy(policies ...RedirectPolicy) *Client { - return defaultClient.SetRedirectPolicy(policies...) -} - // SetRedirectPolicy set the RedirectPolicy which controls the behavior of receiving redirect // responses (usually responses with 301 and 302 status code), see the predefined // AllowedDomainRedirectPolicy, AllowedHostRedirectPolicy, MaxRedirectPolicy, NoRedirectPolicy, @@ -271,12 +205,6 @@ func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { return c } -// DisableKeepAlives is a global wrapper methods which delegated -// to the default client's DisableKeepAlives. -func DisableKeepAlives() *Client { - return defaultClient.DisableKeepAlives() -} - // DisableKeepAlives disable the HTTP keep-alives (enabled by default) // and will only use the connection to the server for a single // HTTP request. @@ -287,24 +215,12 @@ func (c *Client) DisableKeepAlives() *Client { return c } -// EnableKeepAlives is a global wrapper methods which delegated -// to the default client's EnableKeepAlives. -func EnableKeepAlives() *Client { - return defaultClient.EnableKeepAlives() -} - // EnableKeepAlives enables HTTP keep-alives (enabled by default). func (c *Client) EnableKeepAlives() *Client { c.t.DisableKeepAlives = false return c } -// DisableCompression is a global wrapper methods which delegated -// to the default client's DisableCompression. -func DisableCompression() *Client { - return defaultClient.DisableCompression() -} - // DisableCompression disables the compression (enabled by default), // which prevents the Transport from requesting compression // with an "Accept-Encoding: gzip" request header when the @@ -318,24 +234,12 @@ func (c *Client) DisableCompression() *Client { return c } -// EnableCompression is a global wrapper methods which delegated -// to the default client's EnableCompression. -func EnableCompression() *Client { - return defaultClient.EnableCompression() -} - // EnableCompression enables the compression (enabled by default). func (c *Client) EnableCompression() *Client { c.t.DisableCompression = false return c } -// SetTLSClientConfig is a global wrapper methods which delegated -// to the default client's SetTLSClientConfig. -func SetTLSClientConfig(conf *tls.Config) *Client { - return defaultClient.SetTLSClientConfig(conf) -} - // SetTLSClientConfig set the TLS client config. Be careful! Usually // you don't need this, you can directly set the tls configuration with // methods like EnableInsecureSkipVerify, SetCerts etc. Or you can call @@ -347,12 +251,6 @@ func (c *Client) SetTLSClientConfig(conf *tls.Config) *Client { return c } -// EnableInsecureSkipVerify is a global wrapper methods which delegated -// to the default client's EnableInsecureSkipVerify. -func EnableInsecureSkipVerify() *Client { - return defaultClient.EnableInsecureSkipVerify() -} - // EnableInsecureSkipVerify enable send https without verifing // the server's certificates (disabled by default). func (c *Client) EnableInsecureSkipVerify() *Client { @@ -360,12 +258,6 @@ func (c *Client) EnableInsecureSkipVerify() *Client { return c } -// DisableInsecureSkipVerify is a global wrapper methods which delegated -// to the default client's DisableInsecureSkipVerify. -func DisableInsecureSkipVerify() *Client { - return defaultClient.DisableInsecureSkipVerify() -} - // DisableInsecureSkipVerify disable send https without verifing // the server's certificates (disabled by default). func (c *Client) DisableInsecureSkipVerify() *Client { @@ -373,12 +265,6 @@ func (c *Client) DisableInsecureSkipVerify() *Client { return c } -// SetCommonQueryParams is a global wrapper methods which delegated -// to the default client's SetCommonQueryParams. -func SetCommonQueryParams(params map[string]string) *Client { - return defaultClient.SetCommonQueryParams(params) -} - // SetCommonQueryParams set URL query parameters with a map // for all requests. func (c *Client) SetCommonQueryParams(params map[string]string) *Client { @@ -388,12 +274,6 @@ func (c *Client) SetCommonQueryParams(params map[string]string) *Client { return c } -// AddCommonQueryParam is a global wrapper methods which delegated -// to the default client's AddCommonQueryParam. -func AddCommonQueryParam(key, value string) *Client { - return defaultClient.AddCommonQueryParam(key, value) -} - // AddCommonQueryParam add a URL query parameter with a key-value // pair for all requests. func (c *Client) AddCommonQueryParam(key, value string) *Client { @@ -411,24 +291,12 @@ func (c *Client) pathParams() map[string]string { return c.PathParams } -// SetCommonPathParam is a global wrapper methods which delegated -// to the default client's SetCommonPathParam. -func SetCommonPathParam(key, value string) *Client { - return defaultClient.SetCommonPathParam(key, value) -} - // SetCommonPathParam set a path parameter for all requests. func (c *Client) SetCommonPathParam(key, value string) *Client { c.pathParams()[key] = value return c } -// SetCommonPathParams is a global wrapper methods which delegated -// to the default client's SetCommonPathParams. -func SetCommonPathParams(pathParams map[string]string) *Client { - return defaultClient.SetCommonPathParams(pathParams) -} - // SetCommonPathParams set path parameters for all requests. func (c *Client) SetCommonPathParams(pathParams map[string]string) *Client { m := c.pathParams() @@ -438,12 +306,6 @@ func (c *Client) SetCommonPathParams(pathParams map[string]string) *Client { return c } -// SetCommonQueryParam is a global wrapper methods which delegated -// to the default client's SetCommonQueryParam. -func SetCommonQueryParam(key, value string) *Client { - return defaultClient.SetCommonQueryParam(key, value) -} - // SetCommonQueryParam set a URL query parameter with a key-value // pair for all requests. func (c *Client) SetCommonQueryParam(key, value string) *Client { @@ -454,12 +316,6 @@ func (c *Client) SetCommonQueryParam(key, value string) *Client { return c } -// SetCommonQueryString is a global wrapper methods which delegated -// to the default client's SetCommonQueryString. -func SetCommonQueryString(query string) *Client { - return defaultClient.SetCommonQueryString(query) -} - // SetCommonQueryString set URL query parameters with a raw query string // for all requests. func (c *Client) SetCommonQueryString(query string) *Client { @@ -479,48 +335,24 @@ func (c *Client) SetCommonQueryString(query string) *Client { return c } -// SetCommonCookies is a global wrapper methods which delegated -// to the default client's SetCommonCookies. -func SetCommonCookies(cookies ...*http.Cookie) *Client { - return defaultClient.SetCommonCookies(cookies...) -} - // SetCommonCookies set HTTP cookies for all requests. func (c *Client) SetCommonCookies(cookies ...*http.Cookie) *Client { c.Cookies = append(c.Cookies, cookies...) return c } -// DisableDebugLog is a global wrapper methods which delegated -// to the default client's DisableDebugLog. -func DisableDebugLog() *Client { - return defaultClient.DisableDebugLog() -} - // DisableDebugLog disable debug level log (disabled by default). func (c *Client) DisableDebugLog() *Client { c.DebugLog = false return c } -// EnableDebugLog is a global wrapper methods which delegated -// to the default client's EnableDebugLog. -func EnableDebugLog() *Client { - return defaultClient.EnableDebugLog() -} - // EnableDebugLog enable debug level log (disabled by default). func (c *Client) EnableDebugLog() *Client { c.DebugLog = true return c } -// DevMode is a global wrapper methods which delegated -// to the default client's DevMode. -func DevMode() *Client { - return defaultClient.DevMode() -} - // DevMode enables: // 1. Dump content of all requests and responses to see details. // 2. Output debug level log for deeper insights. @@ -533,12 +365,6 @@ func (c *Client) DevMode() *Client { SetUserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36") } -// SetScheme is a global wrapper methods which delegated -// to the default client's SetScheme. -func SetScheme(scheme string) *Client { - return defaultClient.SetScheme(scheme) -} - // SetScheme set the default scheme for client, will be used when // there is no scheme in the request URL (e.g. "github.com/imroc/req"). func (c *Client) SetScheme(scheme string) *Client { @@ -548,12 +374,6 @@ func (c *Client) SetScheme(scheme string) *Client { return c } -// SetLogger is a global wrapper methods which delegated -// to the default client's SetLogger. -func SetLogger(log Logger) *Client { - return defaultClient.SetLogger(log) -} - // SetLogger set the customized logger for client, will disable log if set to nil. func (c *Client) SetLogger(log Logger) *Client { if log == nil { @@ -571,12 +391,6 @@ func (c *Client) getResponseOptions() *ResponseOptions { return c.t.ResponseOptions } -// SetTimeout is a global wrapper methods which delegated -// to the default client's SetTimeout. -func SetTimeout(d time.Duration) *Client { - return defaultClient.SetTimeout(d) -} - // SetTimeout set timeout for all requests. func (c *Client) SetTimeout(d time.Duration) *Client { c.httpClient.Timeout = d @@ -590,12 +404,6 @@ func (c *Client) getDumpOptions() *DumpOptions { return c.dumpOptions } -// EnableDumpAll is a global wrapper methods which delegated -// to the default client's EnableDumpAll. -func EnableDumpAll() *Client { - return defaultClient.EnableDumpAll() -} - // EnableDumpAll enable dump for all requests, including // all content for the request and response by default. func (c *Client) EnableDumpAll() *Client { @@ -606,12 +414,6 @@ func (c *Client) EnableDumpAll() *Client { return c } -// EnableDumpAllToFile is a global wrapper methods which delegated -// to the default client's EnableDumpAllToFile. -func EnableDumpAllToFile(filename string) *Client { - return defaultClient.EnableDumpAllToFile(filename) -} - // EnableDumpAllToFile enable dump for all requests and output // to the specified file. func (c *Client) EnableDumpAllToFile(filename string) *Client { @@ -625,12 +427,6 @@ func (c *Client) EnableDumpAllToFile(filename string) *Client { return c } -// EnableDumpAllTo is a global wrapper methods which delegated -// to the default client's EnableDumpAllTo. -func EnableDumpAllTo(output io.Writer) *Client { - return defaultClient.EnableDumpAllTo(output) -} - // EnableDumpAllTo enable dump for all requests and output to // the specified io.Writer. func (c *Client) EnableDumpAllTo(output io.Writer) *Client { @@ -639,12 +435,6 @@ func (c *Client) EnableDumpAllTo(output io.Writer) *Client { return c } -// EnableDumpAllAsync is a global wrapper methods which delegated -// to the default client's EnableDumpAllAsync. -func EnableDumpAllAsync() *Client { - return defaultClient.EnableDumpAllAsync() -} - // EnableDumpAllAsync enable dump for all requests and output // asynchronously, can be used for debugging in production // environment without affecting performance. @@ -655,12 +445,6 @@ func (c *Client) EnableDumpAllAsync() *Client { return c } -// EnableDumpAllWithoutRequestBody is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutRequestBody. -func EnableDumpAllWithoutRequestBody() *Client { - return defaultClient.EnableDumpAllWithoutRequestBody() -} - // EnableDumpAllWithoutRequestBody enable dump for all requests without // request body, can be used in the upload request to avoid dumping the // unreadable binary content. @@ -671,12 +455,6 @@ func (c *Client) EnableDumpAllWithoutRequestBody() *Client { return c } -// EnableDumpAllWithoutResponseBody is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutResponseBody. -func EnableDumpAllWithoutResponseBody() *Client { - return defaultClient.EnableDumpAllWithoutResponseBody() -} - // EnableDumpAllWithoutResponseBody enable dump for all requests without // response body, can be used in the download request to avoid dumping the // unreadable binary content. @@ -687,12 +465,6 @@ func (c *Client) EnableDumpAllWithoutResponseBody() *Client { return c } -// EnableDumpAllWithoutResponse is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutResponse. -func EnableDumpAllWithoutResponse() *Client { - return defaultClient.EnableDumpAllWithoutResponse() -} - // EnableDumpAllWithoutResponse enable dump for all requests without response, // can be used if you only care about the request. func (c *Client) EnableDumpAllWithoutResponse() *Client { @@ -703,12 +475,6 @@ func (c *Client) EnableDumpAllWithoutResponse() *Client { return c } -// EnableDumpAllWithoutRequest is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutRequest. -func EnableDumpAllWithoutRequest() *Client { - return defaultClient.EnableDumpAllWithoutRequest() -} - // EnableDumpAllWithoutRequest enables dump for all requests without request, // can be used if you only care about the response. func (c *Client) EnableDumpAllWithoutRequest() *Client { @@ -719,12 +485,6 @@ func (c *Client) EnableDumpAllWithoutRequest() *Client { return c } -// EnableDumpAllWithoutHeader is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutHeader. -func EnableDumpAllWithoutHeader() *Client { - return defaultClient.EnableDumpAllWithoutHeader() -} - // EnableDumpAllWithoutHeader enable dump for all requests without header, // can be used if you only care about the body. func (c *Client) EnableDumpAllWithoutHeader() *Client { @@ -735,12 +495,6 @@ func (c *Client) EnableDumpAllWithoutHeader() *Client { return c } -// EnableDumpAllWithoutBody is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutBody. -func EnableDumpAllWithoutBody() *Client { - return defaultClient.EnableDumpAllWithoutBody() -} - // EnableDumpAllWithoutBody enable dump for all requests without body, // can be used if you only care about the header. func (c *Client) EnableDumpAllWithoutBody() *Client { @@ -751,47 +505,23 @@ func (c *Client) EnableDumpAllWithoutBody() *Client { return c } -// NewRequest is a global wrapper methods which delegated -// to the default client's NewRequest. -func NewRequest() *Request { - return defaultClient.R() -} - // NewRequest is the alias of R() func (c *Client) NewRequest() *Request { return c.R() } -// DisableAutoReadResponse is a global wrapper methods which delegated -// to the default client's DisableAutoReadResponse. -func DisableAutoReadResponse() *Client { - return defaultClient.DisableAutoReadResponse() -} - // DisableAutoReadResponse disable read response body automatically (enabled by default). func (c *Client) DisableAutoReadResponse() *Client { c.disableAutoReadResponse = true return c } -// EnableAutoReadResponse is a global wrapper methods which delegated -// to the default client's EnableAutoReadResponse. -func EnableAutoReadResponse() *Client { - return defaultClient.EnableAutoReadResponse() -} - // EnableAutoReadResponse enable read response body automatically (enabled by default). func (c *Client) EnableAutoReadResponse() *Client { c.disableAutoReadResponse = false return c } -// SetAutoDecodeContentType is a global wrapper methods which delegated -// to the default client's SetAutoDecodeContentType. -func SetAutoDecodeContentType(contentTypes ...string) *Client { - return defaultClient.SetAutoDecodeContentType(contentTypes...) -} - // SetAutoDecodeContentType set the content types that will be auto-detected and decode // to utf-8 (e.g. "json", "xml", "html", "text"). func (c *Client) SetAutoDecodeContentType(contentTypes ...string) *Client { @@ -800,12 +530,6 @@ func (c *Client) SetAutoDecodeContentType(contentTypes ...string) *Client { return c } -// SetAutoDecodeContentTypeFunc is a global wrapper methods which delegated -// to the default client's SetAutoDecodeAllTypeFunc. -func SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client { - return defaultClient.SetAutoDecodeContentTypeFunc(fn) -} - // SetAutoDecodeContentTypeFunc set the function that determines whether the // specified `Content-Type` should be auto-detected and decode to utf-8. func (c *Client) SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client { @@ -814,12 +538,6 @@ func (c *Client) SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) return c } -// SetAutoDecodeAllContentType is a global wrapper methods which delegated -// to the default client's SetAutoDecodeAllContentType. -func SetAutoDecodeAllContentType() *Client { - return defaultClient.SetAutoDecodeAllContentType() -} - // SetAutoDecodeAllContentType enable try auto-detect charset and decode all // content type to utf-8. func (c *Client) SetAutoDecodeAllContentType() *Client { @@ -830,12 +548,6 @@ func (c *Client) SetAutoDecodeAllContentType() *Client { return c } -// DisableAutoDecode is a global wrapper methods which delegated -// to the default client's DisableAutoDecode. -func DisableAutoDecode() *Client { - return defaultClient.DisableAutoDecode() -} - // DisableAutoDecode disable auto-detect charset and decode to utf-8 // (enabled by default). func (c *Client) DisableAutoDecode() *Client { @@ -843,12 +555,6 @@ func (c *Client) DisableAutoDecode() *Client { return c } -// EnableAutoDecode is a global wrapper methods which delegated -// to the default client's EnableAutoDecode. -func EnableAutoDecode() *Client { - return defaultClient.EnableAutoDecode() -} - // EnableAutoDecode enable auto-detect charset and decode to utf-8 // (enabled by default). func (c *Client) EnableAutoDecode() *Client { @@ -856,46 +562,22 @@ func (c *Client) EnableAutoDecode() *Client { return c } -// SetUserAgent is a global wrapper methods which delegated -// to the default client's SetUserAgent. -func SetUserAgent(userAgent string) *Client { - return defaultClient.SetUserAgent(userAgent) -} - // SetUserAgent set the "User-Agent" header for all requests. func (c *Client) SetUserAgent(userAgent string) *Client { return c.SetCommonHeader(hdrUserAgentKey, userAgent) } -// SetCommonBearerAuthToken is a global wrapper methods which delegated -// to the default client's SetCommonBearerAuthToken. -func SetCommonBearerAuthToken(token string) *Client { - return defaultClient.SetCommonBearerAuthToken(token) -} - // SetCommonBearerAuthToken set the bearer auth token for all requests. func (c *Client) SetCommonBearerAuthToken(token string) *Client { return c.SetCommonHeader("Authorization", "Bearer "+token) } -// SetCommonBasicAuth is a global wrapper methods which delegated -// to the default client's SetCommonBasicAuth. -func SetCommonBasicAuth(username, password string) *Client { - return defaultClient.SetCommonBasicAuth(username, password) -} - // SetCommonBasicAuth set the basic auth for all requests. func (c *Client) SetCommonBasicAuth(username, password string) *Client { c.SetCommonHeader("Authorization", util.BasicAuthHeaderValue(username, password)) return c } -// SetCommonHeaders is a global wrapper methods which delegated -// to the default client's SetCommonHeaders. -func SetCommonHeaders(hdrs map[string]string) *Client { - return defaultClient.SetCommonHeaders(hdrs) -} - // SetCommonHeaders set headers for all requests. func (c *Client) SetCommonHeaders(hdrs map[string]string) *Client { for k, v := range hdrs { @@ -904,12 +586,6 @@ func (c *Client) SetCommonHeaders(hdrs map[string]string) *Client { return c } -// SetCommonHeader is a global wrapper methods which delegated -// to the default client's SetCommonHeader. -func SetCommonHeader(key, value string) *Client { - return defaultClient.SetCommonHeader(key, value) -} - // SetCommonHeader set a header for all requests. func (c *Client) SetCommonHeader(key, value string) *Client { if c.Headers == nil { @@ -919,36 +595,18 @@ func (c *Client) SetCommonHeader(key, value string) *Client { return c } -// SetCommonContentType is a global wrapper methods which delegated -// to the default client's SetCommonContentType. -func SetCommonContentType(ct string) *Client { - return defaultClient.SetCommonContentType(ct) -} - // SetCommonContentType set the `Content-Type` header for all requests. func (c *Client) SetCommonContentType(ct string) *Client { c.SetCommonHeader(hdrContentTypeKey, ct) return c } -// DisableDumpAll is a global wrapper methods which delegated -// to the default client's DisableDumpAll. -func DisableDumpAll() *Client { - return defaultClient.DisableDumpAll() -} - // DisableDumpAll disable dump for all requests. func (c *Client) DisableDumpAll() *Client { c.t.DisableDump() return c } -// SetCommonDumpOptions is a global wrapper methods which delegated -// to the default client's SetCommonDumpOptions. -func SetCommonDumpOptions(opt *DumpOptions) *Client { - return defaultClient.SetCommonDumpOptions(opt) -} - // SetCommonDumpOptions configures the underlying Transport's DumpOptions // for all requests. func (c *Client) SetCommonDumpOptions(opt *DumpOptions) *Client { @@ -962,48 +620,24 @@ func (c *Client) SetCommonDumpOptions(opt *DumpOptions) *Client { return c } -// SetProxy is a global wrapper methods which delegated -// to the default client's SetProxy. -func SetProxy(proxy func(*http.Request) (*urlpkg.URL, error)) *Client { - return defaultClient.SetProxy(proxy) -} - // SetProxy set the proxy function. func (c *Client) SetProxy(proxy func(*http.Request) (*urlpkg.URL, error)) *Client { c.t.Proxy = proxy return c } -// OnBeforeRequest is a global wrapper methods which delegated -// to the default client's OnBeforeRequest. -func OnBeforeRequest(m RequestMiddleware) *Client { - return defaultClient.OnBeforeRequest(m) -} - // OnBeforeRequest add a request middleware which hooks before request sent. func (c *Client) OnBeforeRequest(m RequestMiddleware) *Client { c.udBeforeRequest = append(c.udBeforeRequest, m) return c } -// OnAfterResponse is a global wrapper methods which delegated -// to the default client's OnAfterResponse. -func OnAfterResponse(m ResponseMiddleware) *Client { - return defaultClient.OnAfterResponse(m) -} - // OnAfterResponse add a response middleware which hooks after response received. func (c *Client) OnAfterResponse(m ResponseMiddleware) *Client { c.afterResponse = append(c.afterResponse, m) return c } -// SetProxyURL is a global wrapper methods which delegated -// to the default client's SetProxyURL. -func SetProxyURL(proxyUrl string) *Client { - return defaultClient.SetProxyURL(proxyUrl) -} - // SetProxyURL set proxy from the proxy URL. func (c *Client) SetProxyURL(proxyUrl string) *Client { u, err := urlpkg.Parse(proxyUrl) @@ -1015,48 +649,24 @@ func (c *Client) SetProxyURL(proxyUrl string) *Client { return c } -// DisableTraceAll is a global wrapper methods which delegated -// to the default client's DisableTraceAll. -func DisableTraceAll() *Client { - return defaultClient.DisableTraceAll() -} - // DisableTraceAll disable trace for all requests. func (c *Client) DisableTraceAll() *Client { c.trace = false return c } -// EnableTraceAll is a global wrapper methods which delegated -// to the default client's EnableTraceAll. -func EnableTraceAll() *Client { - return defaultClient.EnableTraceAll() -} - // EnableTraceAll enable trace for all requests. func (c *Client) EnableTraceAll() *Client { c.trace = true return c } -// SetCookieJar is a global wrapper methods which delegated -// to the default client's SetCookieJar. -func SetCookieJar(jar http.CookieJar) *Client { - return defaultClient.SetCookieJar(jar) -} - // SetCookieJar set the `CookeJar` to the underlying `http.Client`. func (c *Client) SetCookieJar(jar http.CookieJar) *Client { c.httpClient.Jar = jar return c } -// SetJsonMarshal is a global wrapper methods which delegated -// to the default client's SetJsonMarshal. -func SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { - return defaultClient.SetJsonMarshal(fn) -} - // SetJsonMarshal set the JSON marshal function which will be used // to marshal request body. func (c *Client) SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { @@ -1064,12 +674,6 @@ func (c *Client) SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client return c } -// SetJsonUnmarshal is a global wrapper methods which delegated -// to the default client's SetJsonUnmarshal. -func SetJsonUnmarshal(fn func(data []byte, v interface{}) error) *Client { - return defaultClient.SetJsonUnmarshal(fn) -} - // SetJsonUnmarshal set the JSON unmarshal function which will be used // to unmarshal response body. func (c *Client) SetJsonUnmarshal(fn func(data []byte, v interface{}) error) *Client { @@ -1077,12 +681,6 @@ func (c *Client) SetJsonUnmarshal(fn func(data []byte, v interface{}) error) *Cl return c } -// SetXmlMarshal is a global wrapper methods which delegated -// to the default client's SetXmlMarshal. -func SetXmlMarshal(fn func(v interface{}) ([]byte, error)) *Client { - return defaultClient.SetXmlMarshal(fn) -} - // SetXmlMarshal set the XML marshal function which will be used // to marshal request body. func (c *Client) SetXmlMarshal(fn func(v interface{}) ([]byte, error)) *Client { @@ -1090,12 +688,6 @@ func (c *Client) SetXmlMarshal(fn func(v interface{}) ([]byte, error)) *Client { return c } -// SetXmlUnmarshal is a global wrapper methods which delegated -// to the default client's SetXmlUnmarshal. -func SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Client { - return defaultClient.SetXmlUnmarshal(fn) -} - // SetXmlUnmarshal set the XML unmarshal function which will be used // to unmarshal response body. func (c *Client) SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Client { @@ -1103,12 +695,6 @@ func (c *Client) SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Cli return c } -// SetDialTLS is a global wrapper methods which delegated -// to the default client's SetDialTLS. -func SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { - return defaultClient.SetDialTLS(fn) -} - // SetDialTLS set the customized `DialTLSContext` function to Transport. // Make sure the returned `conn` implements TLSConn if you want your // customized `conn` supports HTTP2. @@ -1117,48 +703,24 @@ func (c *Client) SetDialTLS(fn func(ctx context.Context, network, addr string) ( return c } -// SetDial is a global wrapper methods which delegated -// to the default client's SetDial. -func SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { - return defaultClient.SetDial(fn) -} - // SetDial set the customized `DialContext` function to Transport. func (c *Client) SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { c.t.DialContext = fn return c } -// SetTLSHandshakeTimeout is a global wrapper methods which delegated -// to the default client's SetTLSHandshakeTimeout. -func SetTLSHandshakeTimeout(timeout time.Duration) *Client { - return defaultClient.SetTLSHandshakeTimeout(timeout) -} - // SetTLSHandshakeTimeout set the TLS handshake timeout. func (c *Client) SetTLSHandshakeTimeout(timeout time.Duration) *Client { c.t.TLSHandshakeTimeout = timeout return c } -// EnableForceHTTP1 is a global wrapper methods which delegated -// to the default client's EnableForceHTTP1. -func EnableForceHTTP1() *Client { - return defaultClient.EnableForceHTTP1() -} - // EnableForceHTTP1 enable force using HTTP1 (disabled by default). func (c *Client) EnableForceHTTP1() *Client { c.t.ForceHttpVersion = HTTP1 return c } -// EnableForceHTTP2 is a global wrapper methods which delegated -// to the default client's EnableForceHTTP2. -func EnableForceHTTP2() *Client { - return defaultClient.EnableForceHTTP2() -} - // EnableForceHTTP2 enable force using HTTP2 for https requests // (disabled by default). func (c *Client) EnableForceHTTP2() *Client { @@ -1166,36 +728,18 @@ func (c *Client) EnableForceHTTP2() *Client { return c } -// DisableForceHttpVersion is a global wrapper methods which delegated -// to the default client's DisableForceHttpVersion. -func DisableForceHttpVersion() *Client { - return defaultClient.DisableForceHttpVersion() -} - // DisableForceHttpVersion disable force using HTTP1 (disabled by default). func (c *Client) DisableForceHttpVersion() *Client { c.t.ForceHttpVersion = "" return c } -// DisableAllowGetMethodPayload is a global wrapper methods which delegated -// to the default client's DisableAllowGetMethodPayload. -func DisableAllowGetMethodPayload() *Client { - return defaultClient.DisableAllowGetMethodPayload() -} - // DisableAllowGetMethodPayload disable sending GET method requests with body. func (c *Client) DisableAllowGetMethodPayload() *Client { c.AllowGetMethodPayload = false return c } -// EnableAllowGetMethodPayload is a global wrapper methods which delegated -// to the default client's EnableAllowGetMethodPayload. -func EnableAllowGetMethodPayload() *Client { - return defaultClient.EnableAllowGetMethodPayload() -} - // EnableAllowGetMethodPayload allows sending GET method requests with body. func (c *Client) EnableAllowGetMethodPayload() *Client { c.AllowGetMethodPayload = true @@ -1206,12 +750,6 @@ func (c *Client) isPayloadForbid(m string) bool { return (m == http.MethodGet && !c.AllowGetMethodPayload) || m == http.MethodHead || m == http.MethodOptions } -// GetClient is a global wrapper methods which delegated -// to the default client's GetClient. -func GetClient() *http.Client { - return defaultClient.GetClient() -} - // GetClient returns the underlying `http.Client`. func (c *Client) GetClient() *http.Client { return c.httpClient @@ -1224,24 +762,12 @@ func (c *Client) getRetryOption() *retryOption { return c.retryOption } -// SetCommonRetryCount is a global wrapper methods which delegated -// to the default client, create a request and SetCommonRetryCount for request. -func SetCommonRetryCount(count int) *Client { - return defaultClient.SetCommonRetryCount(count) -} - // SetCommonRetryCount enables retry and set the maximum retry count for all requests. func (c *Client) SetCommonRetryCount(count int) *Client { c.getRetryOption().MaxRetries = count return c } -// SetCommonRetryInterval is a global wrapper methods which delegated -// to the default client, create a request and SetCommonRetryInterval for request. -func SetCommonRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Client { - return defaultClient.SetCommonRetryInterval(getRetryIntervalFunc) -} - // SetCommonRetryInterval sets the custom GetRetryIntervalFunc for all requests, // you can use this to implement your own backoff retry algorithm. // For example: @@ -1254,12 +780,6 @@ func (c *Client) SetCommonRetryInterval(getRetryIntervalFunc GetRetryIntervalFun return c } -// SetCommonRetryFixedInterval is a global wrapper methods which delegated -// to the default client, create a request and SetCommonRetryFixedInterval for request. -func SetCommonRetryFixedInterval(interval time.Duration) *Client { - return defaultClient.SetCommonRetryFixedInterval(interval) -} - // SetCommonRetryFixedInterval set retry to use a fixed interval for all requests. func (c *Client) SetCommonRetryFixedInterval(interval time.Duration) *Client { c.getRetryOption().GetRetryInterval = func(resp *Response, attempt int) time.Duration { @@ -1268,12 +788,6 @@ func (c *Client) SetCommonRetryFixedInterval(interval time.Duration) *Client { return c } -// SetCommonRetryBackoffInterval is a global wrapper methods which delegated -// to the default client, create a request and SetCommonRetryBackoffInterval for request. -func SetCommonRetryBackoffInterval(min, max time.Duration) *Client { - return defaultClient.SetCommonRetryBackoffInterval(min, max) -} - // SetCommonRetryBackoffInterval set retry to use a capped exponential backoff with jitter // for all requests. // https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ @@ -1282,12 +796,6 @@ func (c *Client) SetCommonRetryBackoffInterval(min, max time.Duration) *Client { return c } -// SetCommonRetryHook is a global wrapper methods which delegated -// to the default client, create a request and SetRetryHook for request. -func SetCommonRetryHook(hook RetryHookFunc) *Client { - return defaultClient.SetCommonRetryHook(hook) -} - // SetCommonRetryHook set the retry hook which will be executed before a retry. // It will override other retry hooks if any been added before. func (c *Client) SetCommonRetryHook(hook RetryHookFunc) *Client { @@ -1295,12 +803,6 @@ func (c *Client) SetCommonRetryHook(hook RetryHookFunc) *Client { return c } -// AddCommonRetryHook is a global wrapper methods which delegated -// to the default client, create a request and AddCommonRetryHook for request. -func AddCommonRetryHook(hook RetryHookFunc) *Client { - return defaultClient.AddCommonRetryHook(hook) -} - // AddCommonRetryHook adds a retry hook for all requests, which will be // executed before a retry. func (c *Client) AddCommonRetryHook(hook RetryHookFunc) *Client { @@ -1309,12 +811,6 @@ func (c *Client) AddCommonRetryHook(hook RetryHookFunc) *Client { return c } -// SetCommonRetryCondition is a global wrapper methods which delegated -// to the default client, create a request and SetCommonRetryCondition for request. -func SetCommonRetryCondition(condition RetryConditionFunc) *Client { - return defaultClient.SetCommonRetryCondition(condition) -} - // SetCommonRetryCondition sets the retry condition, which determines whether the // request should retry. // It will override other retry conditions if any been added before. @@ -1323,12 +819,6 @@ func (c *Client) SetCommonRetryCondition(condition RetryConditionFunc) *Client { return c } -// AddCommonRetryCondition is a global wrapper methods which delegated -// to the default client, create a request and AddCommonRetryCondition for request. -func AddCommonRetryCondition(condition RetryConditionFunc) *Client { - return defaultClient.AddCommonRetryCondition(condition) -} - // AddCommonRetryCondition adds a retry condition, which determines whether the // request should retry. func (c *Client) AddCommonRetryCondition(condition RetryConditionFunc) *Client { @@ -1337,12 +827,6 @@ func (c *Client) AddCommonRetryCondition(condition RetryConditionFunc) *Client { return c } -// SetUnixSocket is a global wrapper methods which delegated -// to the default client, create a request and SetUnixSocket for request. -func SetUnixSocket(file string) *Client { - return defaultClient.SetUnixSocket(file) -} - // SetUnixSocket set client to dial connection use unix socket. // For example: // client.SetUnixSocket("/var/run/custom.sock") diff --git a/client_wrapper.go b/client_wrapper.go new file mode 100644 index 00000000..404abf02 --- /dev/null +++ b/client_wrapper.go @@ -0,0 +1,527 @@ +package req + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "net/url" + "time" +) + +// SetCommonFormDataFromValues is a global wrapper methods which delegated +// to the default client's SetCommonFormDataFromValues. +func SetCommonFormDataFromValues(data url.Values) *Client { + return defaultClient.SetCommonFormDataFromValues(data) +} + +// SetCommonFormData is a global wrapper methods which delegated +// to the default client's SetCommonFormData. +func SetCommonFormData(data map[string]string) *Client { + return defaultClient.SetCommonFormData(data) +} + +// SetBaseURL is a global wrapper methods which delegated +// to the default client's SetBaseURL. +func SetBaseURL(u string) *Client { + return defaultClient.SetBaseURL(u) +} + +// SetOutputDirectory is a global wrapper methods which delegated +// to the default client's SetOutputDirectory. +func SetOutputDirectory(dir string) *Client { + return defaultClient.SetOutputDirectory(dir) +} + +// SetCertFromFile is a global wrapper methods which delegated +// to the default client's SetCertFromFile. +func SetCertFromFile(certFile, keyFile string) *Client { + return defaultClient.SetCertFromFile(certFile, keyFile) +} + +// SetCerts is a global wrapper methods which delegated +// to the default client's SetCerts. +func SetCerts(certs ...tls.Certificate) *Client { + return defaultClient.SetCerts(certs...) +} + +// SetRootCertFromString is a global wrapper methods which delegated +// to the default client's SetRootCertFromString. +func SetRootCertFromString(pemContent string) *Client { + return defaultClient.SetRootCertFromString(pemContent) +} + +// SetRootCertsFromFile is a global wrapper methods which delegated +// to the default client's SetRootCertsFromFile. +func SetRootCertsFromFile(pemFiles ...string) *Client { + return defaultClient.SetRootCertsFromFile(pemFiles...) +} + +// GetTLSClientConfig is a global wrapper methods which delegated +// to the default client's GetTLSClientConfig. +func GetTLSClientConfig() *tls.Config { + return defaultClient.GetTLSClientConfig() +} + +// SetRedirectPolicy is a global wrapper methods which delegated +// to the default client's SetRedirectPolicy. +func SetRedirectPolicy(policies ...RedirectPolicy) *Client { + return defaultClient.SetRedirectPolicy(policies...) +} + +// DisableKeepAlives is a global wrapper methods which delegated +// to the default client's DisableKeepAlives. +func DisableKeepAlives() *Client { + return defaultClient.DisableKeepAlives() +} + +// EnableKeepAlives is a global wrapper methods which delegated +// to the default client's EnableKeepAlives. +func EnableKeepAlives() *Client { + return defaultClient.EnableKeepAlives() +} + +// DisableCompression is a global wrapper methods which delegated +// to the default client's DisableCompression. +func DisableCompression() *Client { + return defaultClient.DisableCompression() +} + +// EnableCompression is a global wrapper methods which delegated +// to the default client's EnableCompression. +func EnableCompression() *Client { + return defaultClient.EnableCompression() +} + +// SetTLSClientConfig is a global wrapper methods which delegated +// to the default client's SetTLSClientConfig. +func SetTLSClientConfig(conf *tls.Config) *Client { + return defaultClient.SetTLSClientConfig(conf) +} + +// EnableInsecureSkipVerify is a global wrapper methods which delegated +// to the default client's EnableInsecureSkipVerify. +func EnableInsecureSkipVerify() *Client { + return defaultClient.EnableInsecureSkipVerify() +} + +// DisableInsecureSkipVerify is a global wrapper methods which delegated +// to the default client's DisableInsecureSkipVerify. +func DisableInsecureSkipVerify() *Client { + return defaultClient.DisableInsecureSkipVerify() +} + +// SetCommonQueryParams is a global wrapper methods which delegated +// to the default client's SetCommonQueryParams. +func SetCommonQueryParams(params map[string]string) *Client { + return defaultClient.SetCommonQueryParams(params) +} + +// AddCommonQueryParam is a global wrapper methods which delegated +// to the default client's AddCommonQueryParam. +func AddCommonQueryParam(key, value string) *Client { + return defaultClient.AddCommonQueryParam(key, value) +} + +// SetCommonPathParam is a global wrapper methods which delegated +// to the default client's SetCommonPathParam. +func SetCommonPathParam(key, value string) *Client { + return defaultClient.SetCommonPathParam(key, value) +} + +// SetCommonPathParams is a global wrapper methods which delegated +// to the default client's SetCommonPathParams. +func SetCommonPathParams(pathParams map[string]string) *Client { + return defaultClient.SetCommonPathParams(pathParams) +} + +// SetCommonQueryParam is a global wrapper methods which delegated +// to the default client's SetCommonQueryParam. +func SetCommonQueryParam(key, value string) *Client { + return defaultClient.SetCommonQueryParam(key, value) +} + +// SetCommonQueryString is a global wrapper methods which delegated +// to the default client's SetCommonQueryString. +func SetCommonQueryString(query string) *Client { + return defaultClient.SetCommonQueryString(query) +} + +// SetCommonCookies is a global wrapper methods which delegated +// to the default client's SetCommonCookies. +func SetCommonCookies(cookies ...*http.Cookie) *Client { + return defaultClient.SetCommonCookies(cookies...) +} + +// DisableDebugLog is a global wrapper methods which delegated +// to the default client's DisableDebugLog. +func DisableDebugLog() *Client { + return defaultClient.DisableDebugLog() +} + +// EnableDebugLog is a global wrapper methods which delegated +// to the default client's EnableDebugLog. +func EnableDebugLog() *Client { + return defaultClient.EnableDebugLog() +} + +// DevMode is a global wrapper methods which delegated +// to the default client's DevMode. +func DevMode() *Client { + return defaultClient.DevMode() +} + +// SetScheme is a global wrapper methods which delegated +// to the default client's SetScheme. +func SetScheme(scheme string) *Client { + return defaultClient.SetScheme(scheme) +} + +// SetLogger is a global wrapper methods which delegated +// to the default client's SetLogger. +func SetLogger(log Logger) *Client { + return defaultClient.SetLogger(log) +} + +// SetTimeout is a global wrapper methods which delegated +// to the default client's SetTimeout. +func SetTimeout(d time.Duration) *Client { + return defaultClient.SetTimeout(d) +} + +// EnableDumpAll is a global wrapper methods which delegated +// to the default client's EnableDumpAll. +func EnableDumpAll() *Client { + return defaultClient.EnableDumpAll() +} + +// EnableDumpAllToFile is a global wrapper methods which delegated +// to the default client's EnableDumpAllToFile. +func EnableDumpAllToFile(filename string) *Client { + return defaultClient.EnableDumpAllToFile(filename) +} + +// EnableDumpAllTo is a global wrapper methods which delegated +// to the default client's EnableDumpAllTo. +func EnableDumpAllTo(output io.Writer) *Client { + return defaultClient.EnableDumpAllTo(output) +} + +// EnableDumpAllAsync is a global wrapper methods which delegated +// to the default client's EnableDumpAllAsync. +func EnableDumpAllAsync() *Client { + return defaultClient.EnableDumpAllAsync() +} + +// EnableDumpAllWithoutRequestBody is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutRequestBody. +func EnableDumpAllWithoutRequestBody() *Client { + return defaultClient.EnableDumpAllWithoutRequestBody() +} + +// EnableDumpAllWithoutResponseBody is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutResponseBody. +func EnableDumpAllWithoutResponseBody() *Client { + return defaultClient.EnableDumpAllWithoutResponseBody() +} + +// EnableDumpAllWithoutResponse is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutResponse. +func EnableDumpAllWithoutResponse() *Client { + return defaultClient.EnableDumpAllWithoutResponse() +} + +// EnableDumpAllWithoutRequest is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutRequest. +func EnableDumpAllWithoutRequest() *Client { + return defaultClient.EnableDumpAllWithoutRequest() +} + +// EnableDumpAllWithoutHeader is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutHeader. +func EnableDumpAllWithoutHeader() *Client { + return defaultClient.EnableDumpAllWithoutHeader() +} + +// EnableDumpAllWithoutBody is a global wrapper methods which delegated +// to the default client's EnableDumpAllWithoutBody. +func EnableDumpAllWithoutBody() *Client { + return defaultClient.EnableDumpAllWithoutBody() +} + +// DisableAutoReadResponse is a global wrapper methods which delegated +// to the default client's DisableAutoReadResponse. +func DisableAutoReadResponse() *Client { + return defaultClient.DisableAutoReadResponse() +} + +// EnableAutoReadResponse is a global wrapper methods which delegated +// to the default client's EnableAutoReadResponse. +func EnableAutoReadResponse() *Client { + return defaultClient.EnableAutoReadResponse() +} + +// SetAutoDecodeContentType is a global wrapper methods which delegated +// to the default client's SetAutoDecodeContentType. +func SetAutoDecodeContentType(contentTypes ...string) *Client { + return defaultClient.SetAutoDecodeContentType(contentTypes...) +} + +// SetAutoDecodeContentTypeFunc is a global wrapper methods which delegated +// to the default client's SetAutoDecodeAllTypeFunc. +func SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client { + return defaultClient.SetAutoDecodeContentTypeFunc(fn) +} + +// SetAutoDecodeAllContentType is a global wrapper methods which delegated +// to the default client's SetAutoDecodeAllContentType. +func SetAutoDecodeAllContentType() *Client { + return defaultClient.SetAutoDecodeAllContentType() +} + +// DisableAutoDecode is a global wrapper methods which delegated +// to the default client's DisableAutoDecode. +func DisableAutoDecode() *Client { + return defaultClient.DisableAutoDecode() +} + +// EnableAutoDecode is a global wrapper methods which delegated +// to the default client's EnableAutoDecode. +func EnableAutoDecode() *Client { + return defaultClient.EnableAutoDecode() +} + +// SetUserAgent is a global wrapper methods which delegated +// to the default client's SetUserAgent. +func SetUserAgent(userAgent string) *Client { + return defaultClient.SetUserAgent(userAgent) +} + +// SetCommonBearerAuthToken is a global wrapper methods which delegated +// to the default client's SetCommonBearerAuthToken. +func SetCommonBearerAuthToken(token string) *Client { + return defaultClient.SetCommonBearerAuthToken(token) +} + +// SetCommonBasicAuth is a global wrapper methods which delegated +// to the default client's SetCommonBasicAuth. +func SetCommonBasicAuth(username, password string) *Client { + return defaultClient.SetCommonBasicAuth(username, password) +} + +// SetCommonHeaders is a global wrapper methods which delegated +// to the default client's SetCommonHeaders. +func SetCommonHeaders(hdrs map[string]string) *Client { + return defaultClient.SetCommonHeaders(hdrs) +} + +// SetCommonHeader is a global wrapper methods which delegated +// to the default client's SetCommonHeader. +func SetCommonHeader(key, value string) *Client { + return defaultClient.SetCommonHeader(key, value) +} + +// SetCommonContentType is a global wrapper methods which delegated +// to the default client's SetCommonContentType. +func SetCommonContentType(ct string) *Client { + return defaultClient.SetCommonContentType(ct) +} + +// DisableDumpAll is a global wrapper methods which delegated +// to the default client's DisableDumpAll. +func DisableDumpAll() *Client { + return defaultClient.DisableDumpAll() +} + +// SetCommonDumpOptions is a global wrapper methods which delegated +// to the default client's SetCommonDumpOptions. +func SetCommonDumpOptions(opt *DumpOptions) *Client { + return defaultClient.SetCommonDumpOptions(opt) +} + +// SetProxy is a global wrapper methods which delegated +// to the default client's SetProxy. +func SetProxy(proxy func(*http.Request) (*url.URL, error)) *Client { + return defaultClient.SetProxy(proxy) +} + +// OnBeforeRequest is a global wrapper methods which delegated +// to the default client's OnBeforeRequest. +func OnBeforeRequest(m RequestMiddleware) *Client { + return defaultClient.OnBeforeRequest(m) +} + +// OnAfterResponse is a global wrapper methods which delegated +// to the default client's OnAfterResponse. +func OnAfterResponse(m ResponseMiddleware) *Client { + return defaultClient.OnAfterResponse(m) +} + +// SetProxyURL is a global wrapper methods which delegated +// to the default client's SetProxyURL. +func SetProxyURL(proxyUrl string) *Client { + return defaultClient.SetProxyURL(proxyUrl) +} + +// DisableTraceAll is a global wrapper methods which delegated +// to the default client's DisableTraceAll. +func DisableTraceAll() *Client { + return defaultClient.DisableTraceAll() +} + +// EnableTraceAll is a global wrapper methods which delegated +// to the default client's EnableTraceAll. +func EnableTraceAll() *Client { + return defaultClient.EnableTraceAll() +} + +// SetCookieJar is a global wrapper methods which delegated +// to the default client's SetCookieJar. +func SetCookieJar(jar http.CookieJar) *Client { + return defaultClient.SetCookieJar(jar) +} + +// SetJsonMarshal is a global wrapper methods which delegated +// to the default client's SetJsonMarshal. +func SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { + return defaultClient.SetJsonMarshal(fn) +} + +// SetJsonUnmarshal is a global wrapper methods which delegated +// to the default client's SetJsonUnmarshal. +func SetJsonUnmarshal(fn func(data []byte, v interface{}) error) *Client { + return defaultClient.SetJsonUnmarshal(fn) +} + +// SetXmlMarshal is a global wrapper methods which delegated +// to the default client's SetXmlMarshal. +func SetXmlMarshal(fn func(v interface{}) ([]byte, error)) *Client { + return defaultClient.SetXmlMarshal(fn) +} + +// SetXmlUnmarshal is a global wrapper methods which delegated +// to the default client's SetXmlUnmarshal. +func SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Client { + return defaultClient.SetXmlUnmarshal(fn) +} + +// SetDialTLS is a global wrapper methods which delegated +// to the default client's SetDialTLS. +func SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { + return defaultClient.SetDialTLS(fn) +} + +// SetDial is a global wrapper methods which delegated +// to the default client's SetDial. +func SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { + return defaultClient.SetDial(fn) +} + +// SetTLSHandshakeTimeout is a global wrapper methods which delegated +// to the default client's SetTLSHandshakeTimeout. +func SetTLSHandshakeTimeout(timeout time.Duration) *Client { + return defaultClient.SetTLSHandshakeTimeout(timeout) +} + +// EnableForceHTTP1 is a global wrapper methods which delegated +// to the default client's EnableForceHTTP1. +func EnableForceHTTP1() *Client { + return defaultClient.EnableForceHTTP1() +} + +// EnableForceHTTP2 is a global wrapper methods which delegated +// to the default client's EnableForceHTTP2. +func EnableForceHTTP2() *Client { + return defaultClient.EnableForceHTTP2() +} + +// DisableForceHttpVersion is a global wrapper methods which delegated +// to the default client's DisableForceHttpVersion. +func DisableForceHttpVersion() *Client { + return defaultClient.DisableForceHttpVersion() +} + +// DisableAllowGetMethodPayload is a global wrapper methods which delegated +// to the default client's DisableAllowGetMethodPayload. +func DisableAllowGetMethodPayload() *Client { + return defaultClient.DisableAllowGetMethodPayload() +} + +// EnableAllowGetMethodPayload is a global wrapper methods which delegated +// to the default client's EnableAllowGetMethodPayload. +func EnableAllowGetMethodPayload() *Client { + return defaultClient.EnableAllowGetMethodPayload() +} + +// SetCommonRetryCount is a global wrapper methods which delegated +// to the default client, create a request and SetCommonRetryCount for request. +func SetCommonRetryCount(count int) *Client { + return defaultClient.SetCommonRetryCount(count) +} + +// SetCommonRetryInterval is a global wrapper methods which delegated +// to the default client, create a request and SetCommonRetryInterval for request. +func SetCommonRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Client { + return defaultClient.SetCommonRetryInterval(getRetryIntervalFunc) +} + +// SetCommonRetryFixedInterval is a global wrapper methods which delegated +// to the default client, create a request and SetCommonRetryFixedInterval for request. +func SetCommonRetryFixedInterval(interval time.Duration) *Client { + return defaultClient.SetCommonRetryFixedInterval(interval) +} + +// SetCommonRetryBackoffInterval is a global wrapper methods which delegated +// to the default client, create a request and SetCommonRetryBackoffInterval for request. +func SetCommonRetryBackoffInterval(min, max time.Duration) *Client { + return defaultClient.SetCommonRetryBackoffInterval(min, max) +} + +// SetCommonRetryHook is a global wrapper methods which delegated +// to the default client, create a request and SetRetryHook for request. +func SetCommonRetryHook(hook RetryHookFunc) *Client { + return defaultClient.SetCommonRetryHook(hook) +} + +// AddCommonRetryHook is a global wrapper methods which delegated +// to the default client, create a request and AddCommonRetryHook for request. +func AddCommonRetryHook(hook RetryHookFunc) *Client { + return defaultClient.AddCommonRetryHook(hook) +} + +// SetCommonRetryCondition is a global wrapper methods which delegated +// to the default client, create a request and SetCommonRetryCondition for request. +func SetCommonRetryCondition(condition RetryConditionFunc) *Client { + return defaultClient.SetCommonRetryCondition(condition) +} + +// AddCommonRetryCondition is a global wrapper methods which delegated +// to the default client, create a request and AddCommonRetryCondition for request. +func AddCommonRetryCondition(condition RetryConditionFunc) *Client { + return defaultClient.AddCommonRetryCondition(condition) +} + +// SetUnixSocket is a global wrapper methods which delegated +// to the default client, create a request and SetUnixSocket for request. +func SetUnixSocket(file string) *Client { + return defaultClient.SetUnixSocket(file) +} + +// GetClient is a global wrapper methods which delegated +// to the default client's GetClient. +func GetClient() *http.Client { + return defaultClient.GetClient() +} + +// NewRequest is a global wrapper methods which delegated +// to the default client's NewRequest. +func NewRequest() *Request { + return defaultClient.R() +} + +// R is a global wrapper methods which delegated +// to the default client's R(). +func R() *Request { + return defaultClient.R() +} From 0e11a197680fc27ac9aabf11e189340aaa958d43 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Mar 2022 10:22:51 +0800 Subject: [PATCH 439/843] support update request in retry hook (#98) --- client.go | 66 ++++++++++++++++++++++++++------------------------- req_test.go | 11 ++++++--- retry_test.go | 19 +++++++++++++++ 3 files changed, 61 insertions(+), 35 deletions(-) diff --git a/client.go b/client.go index faa09c68..a2194abc 100644 --- a/client.go +++ b/client.go @@ -929,43 +929,43 @@ func (c *Client) do(r *Request) (resp *Response, err error) { Request: r, } - for _, f := range r.client.udBeforeRequest { - if err = f(r.client, r); err != nil { - return + for { + for _, f := range r.client.udBeforeRequest { + if err = f(r.client, r); err != nil { + return + } } - } - for _, f := range r.client.beforeRequest { - if err = f(r.client, r); err != nil { - return + for _, f := range r.client.beforeRequest { + if err = f(r.client, r); err != nil { + return + } } - } - // setup trace - if r.trace == nil && r.client.trace { - r.trace = &clientTrace{} - } - if r.trace != nil { - r.ctx = r.trace.createContext(r.Context()) - } + // setup trace + if r.trace == nil && r.client.trace { + r.trace = &clientTrace{} + } + if r.trace != nil { + r.ctx = r.trace.createContext(r.Context()) + } - // setup url and host - var host string - if h := r.getHeader("Host"); h != "" { - host = h // Host header override - } else { - host = r.URL.Host - } + // setup url and host + var host string + if h := r.getHeader("Host"); h != "" { + host = h // Host header override + } else { + host = r.URL.Host + } - // setup header - var header http.Header - if r.Headers == nil { - header = make(http.Header) - } else { - header = r.Headers.Clone() - } - contentLength := int64(len(r.body)) + // setup header + var header http.Header + if r.Headers == nil { + header = make(http.Header) + } else { + header = r.Headers.Clone() + } + contentLength := int64(len(r.body)) - for { var reqBody io.ReadCloser if r.getBody != nil { reqBody, err = r.getBody() @@ -1036,10 +1036,12 @@ func (c *Client) do(r *Request) (resp *Response, err error) { } time.Sleep(r.retryOption.GetRetryInterval(resp, r.RetryAttempt)) - // clean buffers + // clean up before retry if r.dumpBuffer != nil { r.dumpBuffer.Reset() } + r.trace = nil + r.ctx = nil resp.body = nil } diff --git a/req_test.go b/req_test.go index 078c325a..06db0da9 100644 --- a/req_test.go +++ b/req_test.go @@ -297,9 +297,6 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTooManyRequests) w.Header().Set(hdrContentTypeKey, jsonContentType) w.Write([]byte(`{"errMsg":"too many requests"}`)) - case "/retry": - r.ParseForm() - r.Form.Get("attempt") case "/chunked": w.Header().Add("Trailer", "Expires") w.Write([]byte(`This is a chunked body`)) @@ -356,6 +353,14 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.URL.RawQuery)) case "/search": handleSearch(w, r) + case "/protected": + auth := r.Header.Get("Authorization") + if auth == "Bearer goodtoken" { + w.Write([]byte("good")) + } else { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`bad`)) + } default: if strings.HasPrefix(r.URL.Path, "/user") { handleGetUserProfile(w, r) diff --git a/retry_test.go b/retry_test.go index afd69895..4111ba11 100644 --- a/retry_test.go +++ b/retry_test.go @@ -139,3 +139,22 @@ func TestRetryWithSetResult(t *testing.T) { assertSuccess(t, resp, err) assertEqual(t, "test=test", headers.Get("Cookie")) } + +func TestRetryWithModify(t *testing.T) { + tokens := []string{"badtoken1", "badtoken2", "goodtoken"} + tokenIndex := 0 + c := tc().EnableDumpAll(). + SetCommonRetryCount(2). + SetCommonRetryHook(func(resp *Response, err error) { + tokenIndex++ + resp.Request.SetBearerAuthToken(tokens[tokenIndex]) + }).SetCommonRetryCondition(func(resp *Response, err error) bool { + return err != nil || resp.StatusCode == http.StatusUnauthorized + }) + + resp, err := c.R(). + SetBearerAuthToken(tokens[tokenIndex]). + Get("/protected") + assertSuccess(t, resp, err) + assertEqual(t, 2, resp.Request.RetryAttempt) +} From 52a12a6f3584b62c55c935313bb4a9986bfa0873 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Mar 2022 12:03:37 +0800 Subject: [PATCH 440/843] update README: improve retry describe --- README.md | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index c68beeee..9771515d 100644 --- a/README.md +++ b/README.md @@ -909,23 +909,17 @@ client.SetCommonRetryFixedInterval(2 * time.Seconds) // Set the retry to use a custom retry interval algorithm. client.SetCommonRetryInterval(func(resp *req.Response, attempt int) time.Duration { - // Sleep seconds from "Retry-After" response header if it is present and correct (https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html). + // Sleep seconds from "Retry-After" response header if it is present and correct. + // https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html if resp.Response != nil { - ra := resp.Header.Get("Retry-After") - if ra != "" { + if ra := resp.Header.Get("Retry-After"); ra != "" { after, err := strconv.Atoi(ra) if err == nil { return time.Duration(after) * time.Second } } } - return 2 * time.Second // Otherwise sleep 2 seconds -}) - -// Add a retry hook which will be executed before a retry. -client.AddCommonRetryHook(func(resp *req.Response, err error){ - req := resp.Request.RawRequest - fmt.Println("Retry request:", req.Method, req.URL) + return 2 * time.Second // Otherwise, sleep 2 seconds }) // Add a retry condition which determines whether the request should retry. @@ -935,7 +929,15 @@ client.AddCommonRetryCondition(func(resp *req.Response, err error) bool { // Add another retry condition client.AddCommonRetryCondition(func(resp *req.Response, err error) bool { - return resp.StatusCode == http.StatusTooManyRequests + return resp.StatusCode == http.StatusUnauthorized +}) + +// Add a retry hook which will be executed before a retry. +client.AddCommonRetryHook(func(resp *req.Response, err error){ + req := resp.Request.RawRequest + fmt.Println("Retry request:", req.Method, req.URL) + // Modify request settings in the retry hook. + resp.Request.SetBearerAuthToken(token) }) ``` @@ -945,10 +947,10 @@ You can also override retry settings at request-level (check the full list of re client.R(). SetRetryCount(2). SetRetryInterval(intervalFunc). - SetRetryHook(hookFunc1). // Unlike add, set will remove all other retry hooks which is added before at both request and client level. AddRetryHook(hookFunc2). - SetRetryCondition(conditionFunc1). // Similarly, this will remove all other retry conditions which is added before at both request and client level. - AddRetryCondition(conditionFunc2) + SetRetryHook(hookFunc1). // Unlike add, set will remove all other retry hooks which is added before at both request and client level. + AddRetryCondition(conditionFunc2). + SetRetryCondition(conditionFunc1) // Similarly, this will remove all other retry conditions which is added before at both request and client level. ``` ## TODO List From 2d930bebf9550aa916c4228bb468db4bc78c58d7 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Mar 2022 16:03:20 +0800 Subject: [PATCH 441/843] TestMustSendMethods and TestSendMethods --- request_test.go | 183 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 140 insertions(+), 43 deletions(-) diff --git a/request_test.go b/request_test.go index f0bcf397..6d2b5aa3 100644 --- a/request_test.go +++ b/request_test.go @@ -16,53 +16,150 @@ import ( "time" ) -func TestMethods(t *testing.T) { - testMethods(t, tc()) - testMethods(t, tc().EnableForceHTTP1()) -} - -func testMethods(t *testing.T, c *Client) { - resp, err := c.R().Put("/") - assertSuccess(t, resp, err) - assertEqual(t, "PUT", resp.Header.Get("Method")) - resp = c.R().MustPut("/") - assertEqual(t, "PUT", resp.Header.Get("Method")) - - resp, err = c.R().Patch("/") - assertSuccess(t, resp, err) - assertEqual(t, "PATCH", resp.Header.Get("Method")) - resp = c.R().MustPatch("/") - assertEqual(t, "PATCH", resp.Header.Get("Method")) - - resp, err = c.R().Delete("/") - assertSuccess(t, resp, err) - assertEqual(t, "DELETE", resp.Header.Get("Method")) - resp = c.R().MustDelete("/") - assertEqual(t, "DELETE", resp.Header.Get("Method")) +func TestMustSendMethods(t *testing.T) { + c := tc() + testCases := []struct { + SendReq func(req *Request, url string) *Response + ExpectMethod string + }{ + { + SendReq: func(req *Request, url string) *Response { + return req.MustGet(url) + }, + ExpectMethod: "GET", + }, + { + SendReq: func(req *Request, url string) *Response { + return req.MustPost(url) + }, + ExpectMethod: "POST", + }, + { + SendReq: func(req *Request, url string) *Response { + return req.MustPatch(url) + }, + ExpectMethod: "PATCH", + }, + { + SendReq: func(req *Request, url string) *Response { + return req.MustDelete(url) + }, + ExpectMethod: "DELETE", + }, + { + SendReq: func(req *Request, url string) *Response { + return req.MustOptions(url) + }, + ExpectMethod: "OPTIONS", + }, + { + SendReq: func(req *Request, url string) *Response { + return req.MustPut(url) + }, + ExpectMethod: "PUT", + }, + { + SendReq: func(req *Request, url string) *Response { + return req.MustHead(url) + }, + ExpectMethod: "HEAD", + }, + } - resp, err = c.R().Options("/") - assertSuccess(t, resp, err) - assertEqual(t, "OPTIONS", resp.Header.Get("Method")) - resp = c.R().MustOptions("/") - assertEqual(t, "OPTIONS", resp.Header.Get("Method")) + for _, tc := range testCases { + testMethod(t, c, func(req *Request) *Response { + return tc.SendReq(req, "/") + }, tc.ExpectMethod, false) + } - resp, err = c.R().Head("/") - assertSuccess(t, resp, err) - assertEqual(t, "HEAD", resp.Header.Get("Method")) - resp = c.R().MustHead("/") - assertEqual(t, "HEAD", resp.Header.Get("Method")) + // test panic + for _, tc := range testCases { + testMethod(t, c, func(req *Request) *Response { + return tc.SendReq(req, "/\r\n") + }, tc.ExpectMethod, true) + } +} - resp, err = c.R().Get("/") - assertSuccess(t, resp, err) - assertEqual(t, "GET", resp.Header.Get("Method")) - resp = c.R().MustGet("/") - assertEqual(t, "GET", resp.Header.Get("Method")) +func TestSendMethods(t *testing.T) { + c := tc() + testCases := []struct { + SendReq func(req *Request) (resp *Response, err error) + ExpectMethod string + }{ + { + SendReq: func(req *Request) (resp *Response, err error) { + return req.Get("/") + }, + ExpectMethod: "GET", + }, + { + SendReq: func(req *Request) (resp *Response, err error) { + return req.Post("/") + }, + ExpectMethod: "POST", + }, + { + SendReq: func(req *Request) (resp *Response, err error) { + return req.Put("/") + }, + ExpectMethod: "PUT", + }, + { + SendReq: func(req *Request) (resp *Response, err error) { + return req.Patch("/") + }, + ExpectMethod: "PATCH", + }, + { + SendReq: func(req *Request) (resp *Response, err error) { + return req.Delete("/") + }, + ExpectMethod: "DELETE", + }, + { + SendReq: func(req *Request) (resp *Response, err error) { + return req.Options("/") + }, + ExpectMethod: "OPTIONS", + }, + { + SendReq: func(req *Request) (resp *Response, err error) { + return req.Head("/") + }, + ExpectMethod: "HEAD", + }, + { + SendReq: func(req *Request) (resp *Response, err error) { + return req.Send("GET", "/") + }, + ExpectMethod: "GET", + }, + } + for _, tc := range testCases { + testMethod(t, c, func(req *Request) *Response { + resp, err := tc.SendReq(req) + if err != nil { + t.Errorf("%s %s: %s", req.method, req.RawURL, err.Error()) + } + return resp + }, tc.ExpectMethod, false) + } +} - resp, err = c.R().Post("/") - assertSuccess(t, resp, err) - assertEqual(t, "POST", resp.Header.Get("Method")) - resp = c.R().MustPost("/") - assertEqual(t, "POST", resp.Header.Get("Method")) +func testMethod(t *testing.T, c *Client, sendReq func(*Request) *Response, expectMethod string, expectPanic bool) { + r := c.R() + if expectPanic { + defer func() { + if err := recover(); err == nil { + t.Errorf("Must mehod %s should panic", expectMethod) + } + }() + } + resp := sendReq(r) + method := resp.Header.Get("Method") + if expectMethod != method { + t.Errorf("Expect method %s, got method %s", expectMethod, method) + } } func TestEnableDump(t *testing.T) { From 8d6a53bf5c6d359cc7d5259e1677e4a2559dae33 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Mar 2022 17:04:00 +0800 Subject: [PATCH 442/843] refactor TestEnableDump --- request_test.go | 195 +++++++++++++++++++++--------------------------- retry_test.go | 2 +- 2 files changed, 87 insertions(+), 110 deletions(-) diff --git a/request_test.go b/request_test.go index 6d2b5aa3..c10f6c19 100644 --- a/request_test.go +++ b/request_test.go @@ -162,131 +162,108 @@ func testMethod(t *testing.T, c *Client, sendReq func(*Request) *Response, expec } } -func TestEnableDump(t *testing.T) { - testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDump() - *reqHeader = true - *reqBody = true - *respHeader = true - *respBody = true - }) -} - -func TestEnableDumpToFIle(t *testing.T) { - tmpFile := "tmp_dumpfile_req" - resp, err := tc().R().EnableDumpToFile(tests.GetTestFilePath(tmpFile)).Get("/") - assertSuccess(t, resp, err) - assertEqual(t, true, len(getTestFileContent(t, tmpFile)) > 0) - os.Remove(tests.GetTestFilePath(tmpFile)) +type dumpExpected struct { + ReqHeader bool + ReqBody bool + RespHeader bool + RespBody bool } -func TestEnableDumpWithoutRequest(t *testing.T) { - testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutRequest() - *reqHeader = false - *reqBody = false - *respHeader = true - *respBody = true - }) -} - -func TestEnableDumpWithoutRequestBody(t *testing.T) { - testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutRequestBody() - *reqHeader = true - *reqBody = false - *respHeader = true - *respBody = true - }) -} - -func TestEnableDumpWithoutResponse(t *testing.T) { - testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutResponse() - *reqHeader = true - *reqBody = true - *respHeader = false - *respBody = false - }) -} - -func TestEnableDumpWithoutResponseBody(t *testing.T) { - testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutResponseBody() - *reqHeader = true - *reqBody = true - *respHeader = true - *respBody = false - }) -} - -func TestEnableDumpWithoutHeader(t *testing.T) { - testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutHeader() - *reqHeader = false - *reqBody = true - *respHeader = false - *respBody = true - }) -} - -func TestEnableDumpWithoutBody(t *testing.T) { - testEnableDump(t, func(r *Request, reqHeader, reqBody, respHeader, respBody *bool) { - r.EnableDumpWithoutBody() - *reqHeader = true - *reqBody = false - *respHeader = true - *respBody = false - }) -} - -func testEnableDump(t *testing.T, fn func(r *Request, reqHeader, reqBody, respHeader, respBody *bool)) { +func testEnableDump(t *testing.T, fn func(r *Request) (de dumpExpected)) { testDump := func(c *Client) { r := c.R() - var reqHeader, reqBody, respHeader, respBody bool - fn(r, &reqHeader, &reqBody, &respHeader, &respBody) + de := fn(r) resp, err := r.SetBody(`test body`).Post("/") assertSuccess(t, resp, err) dump := resp.Dump() - assertContains(t, dump, "user-agent", reqHeader) - assertContains(t, dump, "test body", reqBody) - assertContains(t, dump, "date", respHeader) - assertContains(t, dump, "testpost: text response", respBody) + assertContains(t, dump, "user-agent", de.ReqHeader) + assertContains(t, dump, "test body", de.ReqBody) + assertContains(t, dump, "date", de.RespHeader) + assertContains(t, dump, "testpost: text response", de.RespBody) } - testDump(tc()) - testDump(tc().EnableForceHTTP1()) -} - -func TestSetDumpOptions(t *testing.T) { - testSetDumpOptions(t, tc()) - testSetDumpOptions(t, tc().EnableForceHTTP1()) + c := tc() + testDump(c) + testDump(c.EnableForceHTTP1()) } -func testSetDumpOptions(t *testing.T, c *Client) { - opt := &DumpOptions{ - RequestHeader: true, - RequestBody: false, - ResponseHeader: false, - ResponseBody: true, +func TestEnableDump(t *testing.T) { + testCases := []func(r *Request) (d dumpExpected){ + func(r *Request) (de dumpExpected) { + r.EnableDump() + de.ReqHeader = true + de.ReqBody = true + de.RespHeader = true + de.RespBody = true + return + }, + func(r *Request) (de dumpExpected) { + r.EnableDumpWithoutHeader() + de.ReqBody = true + de.RespBody = true + return + }, + func(r *Request) (de dumpExpected) { + r.EnableDumpWithoutBody() + de.ReqHeader = true + de.RespHeader = true + return + }, + func(r *Request) (de dumpExpected) { + r.EnableDumpWithoutRequest() + de.RespHeader = true + de.RespBody = true + return + }, + func(r *Request) (de dumpExpected) { + r.EnableDumpWithoutRequestBody() + de.ReqHeader = true + de.RespHeader = true + de.RespBody = true + return + }, + func(r *Request) (de dumpExpected) { + r.EnableDumpWithoutResponse() + de.ReqHeader = true + de.ReqBody = true + return + }, + func(r *Request) (de dumpExpected) { + r.EnableDumpWithoutResponseBody() + de.ReqHeader = true + de.ReqBody = true + de.RespHeader = true + return + }, + func(r *Request) (de dumpExpected) { + r.SetDumpOptions(&DumpOptions{ + RequestHeader: true, + RequestBody: true, + ResponseBody: true, + }).EnableDump() + de.ReqHeader = true + de.ReqBody = true + de.RespBody = true + return + }, + } + for _, fn := range testCases { + testEnableDump(t, fn) } - resp, err := c.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post(getTestServerURL()) - assertSuccess(t, resp, err) - dump := resp.Dump() - assertContains(t, dump, "user-agent", true) - assertContains(t, dump, "test body", false) - assertContains(t, dump, "date", false) - assertContains(t, dump, "testpost: text response", true) } -func TestGet(t *testing.T) { - testGet(t, tc()) - testGet(t, tc().EnableForceHTTP1()) +func TestEnableDumpTo(t *testing.T) { + buff := new(bytes.Buffer) + resp, err := tc().R().EnableDumpTo(buff).Get("/") + assertSuccess(t, resp, err) + assertEqual(t, true, buff.Len() > 0) } -func testGet(t *testing.T, c *Client) { - resp, err := c.R().Get("/") +func TestEnableDumpToFIle(t *testing.T) { + tmpFile := "tmp_dumpfile_req" + resp, err := tc().R().EnableDumpToFile(tests.GetTestFilePath(tmpFile)).Get("/") assertSuccess(t, resp, err) - assertEqual(t, "TestGet: text response", resp.String()) + assertEqual(t, true, len(getTestFileContent(t, tmpFile)) > 0) + os.Remove(tests.GetTestFilePath(tmpFile)) } func TestBadRequest(t *testing.T) { diff --git a/retry_test.go b/retry_test.go index 4111ba11..cfd86933 100644 --- a/retry_test.go +++ b/retry_test.go @@ -143,7 +143,7 @@ func TestRetryWithSetResult(t *testing.T) { func TestRetryWithModify(t *testing.T) { tokens := []string{"badtoken1", "badtoken2", "goodtoken"} tokenIndex := 0 - c := tc().EnableDumpAll(). + c := tc(). SetCommonRetryCount(2). SetCommonRetryHook(func(resp *Response, err error) { tokenIndex++ From 142e37edf72b07dfda203c8f6d7cc27126b7d27b Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Mar 2022 18:04:11 +0800 Subject: [PATCH 443/843] refactor TestSetBody --- req_test.go | 5 +- request_test.go | 224 +++++++++++++++++++++++++----------------------- 2 files changed, 119 insertions(+), 110 deletions(-) diff --git a/req_test.go b/req_test.go index 06db0da9..5e5d080f 100644 --- a/req_test.go +++ b/req_test.go @@ -173,7 +173,8 @@ func isNil(v interface{}) bool { return false } -type echo struct { +// Echo is used in "/echo" API. +type Echo struct { Header http.Header `json:"header" xml:"header"` Body string `json:"body" xml:"body"` } @@ -215,7 +216,7 @@ func handlePost(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.Header.Get(hdrContentTypeKey))) case "/echo": b, _ := ioutil.ReadAll(r.Body) - e := echo{ + e := Echo{ Header: r.Header, Body: string(b), } diff --git a/request_test.go b/request_test.go index c10f6c19..be5209b1 100644 --- a/request_test.go +++ b/request_test.go @@ -11,7 +11,6 @@ import ( "net/http" "net/url" "os" - "strings" "testing" "time" ) @@ -267,175 +266,184 @@ func TestEnableDumpToFIle(t *testing.T) { } func TestBadRequest(t *testing.T) { - testBadRequest(t, tc()) - testBadRequest(t, tc().EnableForceHTTP1()) -} - -func testBadRequest(t *testing.T, c *Client) { - resp, err := c.R().Get("/bad-request") + resp, err := tc().R().Get("/bad-request") assertStatus(t, resp, err, http.StatusBadRequest, "400 Bad Request") } func TestSetBodyMarshal(t *testing.T) { - testSetBodyMarshal(t, tc()) - testSetBodyMarshal(t, tc().EnableForceHTTP1()) -} - -func testSetBodyMarshal(t *testing.T, c *Client) { + username := "imroc" type User struct { Username string `json:"username" xml:"username"` } - assertUsername := func(username string) func(e *echo) { - return func(e *echo) { - var user User - err := json.Unmarshal([]byte(e.Body), &user) - assertNoError(t, err) - assertEqual(t, username, user.Username) - } + assertUsernameJson := func(body []byte) { + var user User + err := json.Unmarshal(body, &user) + assertNoError(t, err) + assertEqual(t, username, user.Username) } - assertUsernameXml := func(username string) func(e *echo) { - return func(e *echo) { - var user User - err := xml.Unmarshal([]byte(e.Body), &user) - assertNoError(t, err) - assertEqual(t, username, user.Username) - } + assertUsernameXml := func(body []byte) { + var user User + err := xml.Unmarshal(body, &user) + assertNoError(t, err) + assertEqual(t, username, user.Username) } + testCases := []struct { Set func(r *Request) - Assert func(e *echo) + Assert func(body []byte) }{ { // SetBody with map Set: func(r *Request) { m := map[string]interface{}{ - "username": "imroc", + "username": username, } r.SetBody(&m) }, - Assert: assertUsername("imroc"), - }, - { // SetBodyJsonMarshal with map - Set: func(r *Request) { - m := map[string]interface{}{ - "username": "imroc", - } - r.SetBodyJsonMarshal(&m) - }, - Assert: assertUsername("imroc"), + Assert: assertUsernameJson, }, { // SetBody with struct Set: func(r *Request) { var user User - user.Username = "imroc" + user.Username = username r.SetBody(&user) }, - Assert: assertUsername("imroc"), + Assert: assertUsernameJson, }, { // SetBody with struct use xml Set: func(r *Request) { var user User - user.Username = "imroc" + user.Username = username r.SetBody(&user).SetContentType(xmlContentType) }, - Assert: assertUsernameXml("imroc"), + Assert: assertUsernameXml, + }, + { // SetBodyJsonMarshal with map + Set: func(r *Request) { + m := map[string]interface{}{ + "username": username, + } + r.SetBodyJsonMarshal(&m) + }, + Assert: assertUsernameJson, }, { // SetBodyJsonMarshal with struct Set: func(r *Request) { var user User - user.Username = "imroc" + user.Username = username r.SetBodyJsonMarshal(&user) }, - Assert: assertUsername("imroc"), + Assert: assertUsernameJson, }, { // SetBodyXmlMarshal with struct Set: func(r *Request) { var user User - user.Username = "imroc" + user.Username = username r.SetBodyXmlMarshal(&user) }, - Assert: assertUsernameXml("imroc"), + Assert: assertUsernameXml, }, } - for _, cs := range testCases { + c := tc() + for _, tc := range testCases { r := c.R() - cs.Set(r) - var e echo + tc.Set(r) + var e Echo resp, err := r.SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - cs.Assert(&e) + tc.Assert([]byte(e.Body)) } } -func TestSetBodyReader(t *testing.T) { - var e echo - resp, err := tc().R().SetBody(ioutil.NopCloser(bytes.NewBufferString("hello"))).SetResult(&e).Post("/echo") - assertSuccess(t, resp, err) - assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) - assertEqual(t, "hello", e.Body) -} - -func TestSetBodyGetContentFunc(t *testing.T) { - var e echo - resp, err := tc().R().SetBody(func() (io.ReadCloser, error) { - return ioutil.NopCloser(bytes.NewBufferString("hello")), nil - }).SetResult(&e).Post("/echo") - assertSuccess(t, resp, err) - assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) - assertEqual(t, "hello", e.Body) - - e = echo{} - var fn GetContentFunc = func() (io.ReadCloser, error) { - return ioutil.NopCloser(bytes.NewBufferString("hello")), nil +func TestSetBody(t *testing.T) { + body := "hello" + fn := func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewBufferString(body)), nil } - resp, err = tc().R().SetBody(fn).SetResult(&e).Post("/echo") - assertSuccess(t, resp, err) - assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) - assertEqual(t, "hello", e.Body) -} - -func TestSetBodyContent(t *testing.T) { - testSetBodyContent(t, tc()) - testSetBodyContent(t, tc().EnableForceHTTP1()) -} - -func testSetBodyContent(t *testing.T, c *Client) { - var e echo - testBody := "test body" - - testCases := []func(r *Request){ - func(r *Request) { // SetBody with string - r.SetBody(testBody) + c := tc() + testCases := []struct { + SetBody func(r *Request) + ContentType string + }{ + { + SetBody: func(r *Request) { // SetBody with `func() (io.ReadCloser, error)` + r.SetBody(fn) + }, + }, + { + SetBody: func(r *Request) { // SetBody with GetContentFunc + r.SetBody(GetContentFunc(fn)) + }, + }, + { + SetBody: func(r *Request) { // SetBody with io.ReadCloser + r.SetBody(ioutil.NopCloser(bytes.NewBufferString(body))) + }, + }, + { + SetBody: func(r *Request) { // SetBody with io.Reader + r.SetBody(bytes.NewBufferString(body)) + }, + }, + { + SetBody: func(r *Request) { // SetBody with string + r.SetBody(body) + }, + ContentType: plainTextContentType, + }, + { + SetBody: func(r *Request) { // SetBody with []byte + r.SetBody([]byte(body)) + }, + ContentType: plainTextContentType, + }, + { + SetBody: func(r *Request) { // SetBodyString + r.SetBodyString(body) + }, + ContentType: plainTextContentType, }, - func(r *Request) { // SetBody with []byte - r.SetBody([]byte(testBody)) + { + SetBody: func(r *Request) { // SetBodyBytes + r.SetBodyBytes([]byte(body)) + }, + ContentType: plainTextContentType, + }, + { + SetBody: func(r *Request) { // SetBodyJsonString + r.SetBodyJsonString(body) + }, + ContentType: jsonContentType, + }, + { + SetBody: func(r *Request) { // SetBodyJsonBytes + r.SetBodyJsonBytes([]byte(body)) + }, + ContentType: jsonContentType, }, - func(r *Request) { // SetBodyString - r.SetBodyString(testBody) + { + SetBody: func(r *Request) { // SetBodyXmlString + r.SetBodyXmlString(body) + }, + ContentType: xmlContentType, }, - func(r *Request) { // SetBodyBytes - r.SetBodyBytes([]byte(testBody)) + { + SetBody: func(r *Request) { // SetBodyXmlBytes + r.SetBodyXmlBytes([]byte(body)) + }, + ContentType: xmlContentType, }, } - - for _, fn := range testCases { + for _, tc := range testCases { r := c.R() - fn(r) - var e echo + tc.SetBody(r) + var e Echo resp, err := r.SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - assertEqual(t, plainTextContentType, e.Header.Get(hdrContentTypeKey)) - assertEqual(t, testBody, e.Body) + assertEqual(t, tc.ContentType, e.Header.Get(hdrContentTypeKey)) + assertEqual(t, body, e.Body) } - - // Set Reader - testBodyReader := strings.NewReader(testBody) - e = echo{} - resp, err := c.R().SetBody(testBodyReader).SetResult(&e).Post("/echo") - assertSuccess(t, resp, err) - assertEqual(t, testBody, e.Body) - assertEqual(t, "", e.Header.Get(hdrContentTypeKey)) } func TestCookie(t *testing.T) { From 4610a4557de6c1ca1a5f21bb981b8dea99783178 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 9 Mar 2022 19:47:50 +0800 Subject: [PATCH 444/843] refactor TestCookie, TestSetBasicAuth and TestSetBearerAuthToken --- request_test.go | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/request_test.go b/request_test.go index be5209b1..99172dbf 100644 --- a/request_test.go +++ b/request_test.go @@ -447,13 +447,8 @@ func TestSetBody(t *testing.T) { } func TestCookie(t *testing.T) { - testCookie(t, tc()) - testCookie(t, tc().EnableForceHTTP1()) -} - -func testCookie(t *testing.T, c *Client) { headers := make(http.Header) - resp, err := c.R().SetCookies( + resp, err := tc().R().SetCookies( &http.Cookie{ Name: "cookie1", Value: "value1", @@ -467,23 +462,20 @@ func testCookie(t *testing.T, c *Client) { assertEqual(t, "cookie1=value1; cookie2=value2", headers.Get("Cookie")) } -func TestAuth(t *testing.T) { - testAuth(t, tc()) - testAuth(t, tc().EnableForceHTTP1()) -} - -func testAuth(t *testing.T, c *Client) { +func TestSetBasicAuth(t *testing.T) { headers := make(http.Header) - resp, err := c.R(). + resp, err := tc().R(). SetBasicAuth("imroc", "123456"). SetResult(&headers). Get("/header") assertSuccess(t, resp, err) assertEqual(t, "Basic aW1yb2M6MTIzNDU2", headers.Get("Authorization")) +} +func TestSetBearerAuthToken(t *testing.T) { token := "NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4" - headers = make(http.Header) - resp, err = c.R(). + headers := make(http.Header) + resp, err := tc().R(). SetBearerAuthToken(token). SetResult(&headers). Get("/header") From 5f4f3455af122a7c2088570f8fbf112eeb747d92 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Mar 2022 10:23:30 +0800 Subject: [PATCH 445/843] refactor TestTraceInfo --- request_test.go | 62 +++++++++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/request_test.go b/request_test.go index 99172dbf..af2a91a3 100644 --- a/request_test.go +++ b/request_test.go @@ -704,21 +704,44 @@ func testHostHeaderOverride(t *testing.T, c *Client) { assertEqual(t, "testhostname", resp.String()) } -func TestTraceInfo(t *testing.T) { - testTraceInfo(t, tc()) - testTraceInfo(t, tc().EnableForceHTTP1()) - resp, err := tc().R().Get("/") - assertSuccess(t, resp, err) +func assertTraceInfo(t *testing.T, resp *Response, enable bool) { ti := resp.TraceInfo() - assertContains(t, ti.String(), "not enabled", true) - assertContains(t, ti.Blame(), "not enabled", true) + assertEqual(t, true, resp.TotalTime() > 0) + if !enable { + assertEqual(t, false, ti.TotalTime > 0) + assertIsNil(t, ti.RemoteAddr) + assertContains(t, ti.String(), "not enabled", true) + assertContains(t, ti.Blame(), "not enabled", true) + return + } - resp, err = tc().EnableTraceAll().R().Get("/") - assertSuccess(t, resp, err) - ti = resp.TraceInfo() assertContains(t, ti.String(), "not enabled", false) assertContains(t, ti.Blame(), "not enabled", false) - assertEqual(t, true, resp.TotalTime() > 0) + assertEqual(t, true, ti.TotalTime > 0) + assertEqual(t, true, ti.ConnectTime > 0) + assertEqual(t, true, ti.FirstResponseTime > 0) + assertEqual(t, true, ti.ResponseTime > 0) + assertNotNil(t, ti.RemoteAddr) + if ti.IsConnReused { + assertEqual(t, true, ti.TCPConnectTime == 0) + assertEqual(t, true, ti.TLSHandshakeTime == 0) + } else { + assertEqual(t, true, ti.TCPConnectTime > 0) + assertEqual(t, true, ti.TLSHandshakeTime > 0) + } +} + +func assertEnableTraceInfo(t *testing.T, resp *Response) { + assertTraceInfo(t, resp, true) +} + +func assertDisableTraceInfo(t *testing.T, resp *Response) { + assertTraceInfo(t, resp, false) +} + +func TestTraceInfo(t *testing.T) { + testTraceInfo(t, tc()) + testTraceInfo(t, tc().EnableForceHTTP1()) } func testTraceInfo(t *testing.T, c *Client) { @@ -726,29 +749,18 @@ func testTraceInfo(t *testing.T, c *Client) { c.EnableTraceAll() resp, err := c.R().Get("/") assertSuccess(t, resp, err) - ti := resp.TraceInfo() - assertEqual(t, true, ti.TotalTime > 0) - assertEqual(t, true, ti.TCPConnectTime > 0) - assertEqual(t, true, ti.TLSHandshakeTime > 0) - assertEqual(t, true, ti.ConnectTime > 0) - assertEqual(t, true, ti.FirstResponseTime > 0) - assertEqual(t, true, ti.ResponseTime > 0) - assertNotNil(t, ti.RemoteAddr) + assertEnableTraceInfo(t, resp) // disable trace at client level c.DisableTraceAll() resp, err = c.R().Get("/") assertSuccess(t, resp, err) - ti = resp.TraceInfo() - assertEqual(t, false, ti.TotalTime > 0) - assertIsNil(t, ti.RemoteAddr) + assertDisableTraceInfo(t, resp) // enable trace at request level resp, err = c.R().EnableTrace().Get("/") assertSuccess(t, resp, err) - ti = resp.TraceInfo() - assertEqual(t, true, ti.TotalTime > 0) - assertNotNil(t, ti.RemoteAddr) + assertEnableTraceInfo(t, resp) } func TestTraceOnTimeout(t *testing.T) { From d5e5b3737627f261f20836161a73cc76d95d98b3 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Mar 2022 10:43:16 +0800 Subject: [PATCH 446/843] refactor TestTraceInfo --- req_test.go | 5 +++++ request_test.go | 37 +++++++++++++++++-------------------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/req_test.go b/req_test.go index 5e5d080f..f25d79b4 100644 --- a/req_test.go +++ b/req_test.go @@ -764,3 +764,8 @@ func TestTrailer(t *testing.T) { t.Error("trailer not exists") } } + +func testWithAllTransport(t *testing.T, testFunc func(t *testing.T, c *Client)) { + testFunc(t, tc()) + testFunc(t, tc().EnableForceHTTP1()) +} diff --git a/request_test.go b/request_test.go index af2a91a3..a85af8c4 100644 --- a/request_test.go +++ b/request_test.go @@ -764,26 +764,23 @@ func testTraceInfo(t *testing.T, c *Client) { } func TestTraceOnTimeout(t *testing.T) { - testTraceOnTimeout(t, C()) - testTraceOnTimeout(t, C().EnableForceHTTP1()) -} - -func testTraceOnTimeout(t *testing.T, c *Client) { - c.EnableTraceAll().SetTimeout(100 * time.Millisecond) - - resp, err := c.R().Get("http://req-nowhere.local") - assertNotNil(t, err) - assertNotNil(t, resp) - - tr := resp.TraceInfo() - assertEqual(t, true, tr.DNSLookupTime >= 0) - assertEqual(t, true, tr.ConnectTime == 0) - assertEqual(t, true, tr.TLSHandshakeTime == 0) - assertEqual(t, true, tr.TCPConnectTime == 0) - assertEqual(t, true, tr.FirstResponseTime == 0) - assertEqual(t, true, tr.ResponseTime == 0) - assertEqual(t, true, tr.TotalTime > 0) - assertEqual(t, true, tr.TotalTime == resp.TotalTime()) + testWithAllTransport(t, func(t *testing.T, c *Client) { + c.EnableTraceAll().SetTimeout(100 * time.Millisecond) + + resp, err := c.R().Get("http://req-nowhere.local") + assertNotNil(t, err) + assertNotNil(t, resp) + + ti := resp.TraceInfo() + assertEqual(t, true, ti.DNSLookupTime >= 0) + assertEqual(t, true, ti.ConnectTime == 0) + assertEqual(t, true, ti.TLSHandshakeTime == 0) + assertEqual(t, true, ti.TCPConnectTime == 0) + assertEqual(t, true, ti.FirstResponseTime == 0) + assertEqual(t, true, ti.ResponseTime == 0) + assertEqual(t, true, ti.TotalTime > 0) + assertEqual(t, true, ti.TotalTime == resp.TotalTime()) + }) } func TestAutoDetectRequestContentType(t *testing.T) { From 4f9645d118e912c3fc07a237c195d7d9e8b8b0b2 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Mar 2022 14:29:05 +0800 Subject: [PATCH 447/843] refactor all request tests --- request_test.go | 72 ++++++++++++++++++++++++++++++------------------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/request_test.go b/request_test.go index a85af8c4..a1cd5f6b 100644 --- a/request_test.go +++ b/request_test.go @@ -484,8 +484,7 @@ func TestSetBearerAuthToken(t *testing.T) { } func TestHeader(t *testing.T) { - testHeader(t, tc()) - testHeader(t, tc().EnableForceHTTP1()) + testWithAllTransport(t, testHeader) } func testHeader(t *testing.T, c *Client) { @@ -511,8 +510,7 @@ func testHeader(t *testing.T, c *Client) { } func TestQueryParam(t *testing.T) { - testQueryParam(t, tc()) - testQueryParam(t, tc().EnableForceHTTP1()) + testWithAllTransport(t, testQueryParam) } func testQueryParam(t *testing.T, c *Client) { @@ -610,8 +608,7 @@ func testPathParam(t *testing.T, c *Client) { } func TestSuccess(t *testing.T) { - testSuccess(t, tc()) - testSuccess(t, tc().EnableForceHTTP1()) + testWithAllTransport(t, testSuccess) } func testSuccess(t *testing.T, c *Client) { @@ -634,8 +631,7 @@ func testSuccess(t *testing.T, c *Client) { } func TestError(t *testing.T) { - testError(t, tc()) - testError(t, tc().EnableForceHTTP1()) + testWithAllTransport(t, testError) } func testError(t *testing.T, c *Client) { @@ -666,8 +662,7 @@ func testError(t *testing.T, c *Client) { } func TestForm(t *testing.T) { - testForm(t, tc()) - testForm(t, tc().EnableForceHTTP1()) + testWithAllTransport(t, testForm) } func testForm(t *testing.T, c *Client) { @@ -694,8 +689,7 @@ func testForm(t *testing.T, c *Client) { } func TestHostHeaderOverride(t *testing.T) { - testHostHeaderOverride(t, tc()) - testHostHeaderOverride(t, tc().EnableForceHTTP1()) + testWithAllTransport(t, testHostHeaderOverride) } func testHostHeaderOverride(t *testing.T, c *Client) { @@ -740,8 +734,7 @@ func assertDisableTraceInfo(t *testing.T, resp *Response) { } func TestTraceInfo(t *testing.T) { - testTraceInfo(t, tc()) - testTraceInfo(t, tc().EnableForceHTTP1()) + testWithAllTransport(t, testTraceInfo) } func testTraceInfo(t *testing.T, c *Client) { @@ -840,25 +833,48 @@ func TestFixPragmaCache(t *testing.T) { } func TestSetFileBytes(t *testing.T) { - resp, err := tc().R().SetFileBytes("file", "file.txt", []byte("test")).Post("/file-text") - assertSuccess(t, resp, err) + resp := uploadTextFile(t, func(r *Request) { + r.SetFileBytes("file", "file.txt", []byte("test")) + }) assertEqual(t, "test", resp.String()) } -func TestSetBodyWrapper(t *testing.T) { - b := []byte("test") - s := string(b) - c := tc() +func TestSetFileReader(t *testing.T) { + buff := bytes.NewBufferString("test") + resp := uploadTextFile(t, func(r *Request) { + r.SetFileReader("file", "file.txt", buff) + }) + assertEqual(t, "test", resp.String()) - r := c.R().SetBodyXmlString(s) - assertEqual(t, true, len(r.body) > 0) + buff = bytes.NewBufferString("test") + resp = uploadTextFile(t, func(r *Request) { + r.SetFileReader("file", "file.txt", ioutil.NopCloser(buff)) + }) + assertEqual(t, "test", resp.String()) +} - r = c.R().SetBodyXmlBytes(b) - assertEqual(t, true, len(r.body) > 0) +func TestSetFile(t *testing.T) { + filename := "sample-file.txt" + resp := uploadTextFile(t, func(r *Request) { + r.SetFile("file", tests.GetTestFilePath(filename)) + }) + assertEqual(t, getTestFileContent(t, filename), resp.Bytes()) +} - r = c.R().SetBodyJsonString(s) - assertEqual(t, true, len(r.body) > 0) +func TestSetFiles(t *testing.T) { + filename := "sample-file.txt" + resp := uploadTextFile(t, func(r *Request) { + r.SetFiles(map[string]string{ + "file": tests.GetTestFilePath(filename), + }) + }) + assertEqual(t, getTestFileContent(t, filename), resp.Bytes()) +} - r = c.R().SetBodyJsonBytes(b) - assertEqual(t, true, len(r.body) > 0) +func uploadTextFile(t *testing.T, setReq func(r *Request)) *Response { + r := tc().R() + setReq(r) + resp, err := r.Post("/file-text") + assertSuccess(t, resp, err) + return resp } From e8e7f095516867e854c59f07b4a44cc182762e7b Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Mar 2022 15:04:30 +0800 Subject: [PATCH 448/843] SetFile support retry and refactor global wrapper tests for request --- req_test.go | 73 -------------------------------------- request.go | 14 +++++++- request_test.go | 3 ++ request_wrapper_test.go | 78 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 74 deletions(-) create mode 100644 request_wrapper_test.go diff --git a/req_test.go b/req_test.go index f25d79b4..8127e756 100644 --- a/req_test.go +++ b/req_test.go @@ -536,79 +536,6 @@ func testGlobalWrapperSendRequest(t *testing.T) { assertEqual(t, "POST", resp.Header.Get("Method")) } -func testGlobalWrapperSetRequest(t *testing.T, rs ...*Request) { - for _, r := range rs { - assertNotNil(t, r) - } -} - -func TestGlobalWrapperSetRequest(t *testing.T) { - SetLogger(nil) - testFilePath := tests.GetTestFilePath("sample-file.txt") - values := make(url.Values) - values.Add("test", "test") - testGlobalWrapperSetRequest(t, - SetFiles(map[string]string{"test": testFilePath}), - SetFile("test", tests.GetTestFilePath("sample-file.txt")), - SetFile("test", tests.GetTestFilePath("file-not-exists.txt")), - SetFileReader("test", "test.txt", bytes.NewBufferString("test")), - SetFileBytes("test", "test.txt", []byte("test")), - SetFileUpload(FileUpload{ParamName: "test", FileName: "test.txt", GetFileContent: func() (io.ReadCloser, error) { - return nil, nil - }}), - SetError(&ErrorMessage{}), - SetResult(&UserInfo{}), - SetOutput(nil), - SetOutput(bytes.NewBufferString("test")), - SetHeader("test", "test"), - SetHeaders(map[string]string{"test": "test"}), - SetCookies(&http.Cookie{ - Name: "test", - Value: "test", - }), - SetBasicAuth("imroc", "123456"), - SetBearerAuthToken("123456"), - SetQueryString("test=test"), - SetQueryString("ksjlfjk?"), - SetQueryParam("test", "test"), - AddQueryParam("test", "test"), - SetQueryParams(map[string]string{"test": "test"}), - SetPathParam("test", "test"), - SetPathParams(map[string]string{"test": "test"}), - SetFormData(map[string]string{"test": "test"}), - SetFormDataFromValues(values), - SetContentType(jsonContentType), - AddRetryCondition(func(rep *Response, err error) bool { - return err != nil - }), - SetRetryCondition(func(rep *Response, err error) bool { - return err != nil - }), - AddRetryHook(func(resp *Response, err error) {}), - SetRetryHook(func(resp *Response, err error) {}), - SetRetryBackoffInterval(1*time.Millisecond, 500*time.Millisecond), - SetRetryFixedInterval(1*time.Millisecond), - SetRetryInterval(func(resp *Response, attempt int) time.Duration { - return 1 * time.Millisecond - }), - SetRetryCount(3), - SetBodyXmlMarshal(0), - SetBodyString("test"), - SetBodyBytes([]byte("test")), - SetBodyJsonBytes([]byte(`{"user":"roc"}`)), - SetBodyJsonString(`{"user":"roc"}`), - SetBodyXmlBytes([]byte("test")), - SetBodyXmlString("test"), - SetBody("test"), - SetBodyJsonMarshal(User{ - Name: "roc", - }), - EnableTrace(), - DisableTrace(), - SetContext(context.Background()), - ) -} - func testGlobalClientSettingWrapper(t *testing.T, cs ...*Client) { for _, c := range cs { assertNotNil(t, c) diff --git a/request.go b/request.go index cb38674b..36a1ff9c 100644 --- a/request.go +++ b/request.go @@ -225,7 +225,19 @@ func (r *Request) SetFile(paramName, filePath string) *Request { return r } r.isMultiPart = true - return r.SetFileReader(paramName, filepath.Base(filePath), file) + return r.SetFileUpload(FileUpload{ + ParamName: paramName, + FileName: filepath.Base(filePath), + GetFileContent: func() (io.ReadCloser, error) { + if r.RetryAttempt > 0 { + file, err = os.Open(filePath) + if err != nil { + return nil, err + } + } + return file, nil + }, + }) } var errMissingParamName = errors.New("missing param name in multipart file upload") diff --git a/request_test.go b/request_test.go index a1cd5f6b..c388a7fa 100644 --- a/request_test.go +++ b/request_test.go @@ -859,6 +859,9 @@ func TestSetFile(t *testing.T) { r.SetFile("file", tests.GetTestFilePath(filename)) }) assertEqual(t, getTestFileContent(t, filename), resp.Bytes()) + + resp, err := tc().R().SetFile("file", "file-not-exists.txt").Post("/file-text") + assertErrorContains(t, err, "no such file") } func TestSetFiles(t *testing.T) { diff --git a/request_wrapper_test.go b/request_wrapper_test.go new file mode 100644 index 00000000..cbe4e7d5 --- /dev/null +++ b/request_wrapper_test.go @@ -0,0 +1,78 @@ +package req + +import ( + "bytes" + "context" + "net/http" + "testing" + "time" +) + +func init() { + SetLogger(nil) // disable log +} + +func assertRequestNotNil(t *testing.T, rs ...*Request) { + for _, r := range rs { + assertNotNil(t, r) + } +} + +func TestGlobalWrapperForRequestSettings(t *testing.T) { + assertRequestNotNil(t, + SetFiles(map[string]string{"test": "test"}), + SetFile("test", "test"), + SetFileReader("test", "test.txt", bytes.NewBufferString("test")), + SetFileBytes("test", "test.txt", []byte("test")), + SetFileUpload(FileUpload{}), + SetError(&ErrorMessage{}), + SetResult(&UserInfo{}), + SetOutput(nil), + SetHeader("test", "test"), + SetHeaders(map[string]string{"test": "test"}), + SetCookies(&http.Cookie{ + Name: "test", + Value: "test", + }), + SetBasicAuth("imroc", "123456"), + SetBearerAuthToken("123456"), + SetQueryString("test=test"), + SetQueryString("ksjlfjk?"), + SetQueryParam("test", "test"), + AddQueryParam("test", "test"), + SetQueryParams(map[string]string{"test": "test"}), + SetPathParam("test", "test"), + SetPathParams(map[string]string{"test": "test"}), + SetFormData(map[string]string{"test": "test"}), + SetFormDataFromValues(nil), + SetContentType(jsonContentType), + AddRetryCondition(func(rep *Response, err error) bool { + return err != nil + }), + SetRetryCondition(func(rep *Response, err error) bool { + return err != nil + }), + AddRetryHook(func(resp *Response, err error) {}), + SetRetryHook(func(resp *Response, err error) {}), + SetRetryBackoffInterval(0, 0), + SetRetryFixedInterval(0), + SetRetryInterval(func(resp *Response, attempt int) time.Duration { + return 1 * time.Millisecond + }), + SetRetryCount(3), + SetBodyXmlMarshal(0), + SetBodyString("test"), + SetBodyBytes([]byte("test")), + SetBodyJsonBytes([]byte(`{"user":"roc"}`)), + SetBodyJsonString(`{"user":"roc"}`), + SetBodyXmlBytes([]byte("test")), + SetBodyXmlString("test"), + SetBody("test"), + SetBodyJsonMarshal(User{ + Name: "roc", + }), + EnableTrace(), + DisableTrace(), + SetContext(context.Background()), + ) +} From 910cc4ec46f0ed15ddc5705428694b2f16480dc0 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Mar 2022 15:21:15 +0800 Subject: [PATCH 449/843] add TestSetFileWithRetry --- req_test.go | 4 ++++ request_test.go | 17 +++++++++++++++++ retry.go | 4 ++-- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/req_test.go b/req_test.go index 8127e756..ed9f8e3a 100644 --- a/req_test.go +++ b/req_test.go @@ -191,6 +191,10 @@ func handlePost(w http.ResponseWriter, r *http.Request) { files := r.MultipartForm.File["file"] file, _ := files[0].Open() b, _ := ioutil.ReadAll(file) + r.ParseForm() + if a := r.FormValue("attempt"); a != "" && a != "2" { + w.WriteHeader(http.StatusInternalServerError) + } w.Write(b) case "/form": r.ParseForm() diff --git a/request_test.go b/request_test.go index c388a7fa..64f42948 100644 --- a/request_test.go +++ b/request_test.go @@ -11,6 +11,7 @@ import ( "net/http" "net/url" "os" + "strconv" "testing" "time" ) @@ -853,6 +854,22 @@ func TestSetFileReader(t *testing.T) { assertEqual(t, "test", resp.String()) } +func TestSetFileWithRetry(t *testing.T) { + resp, err := tc().R(). + SetRetryCount(3). + SetRetryCondition(func(resp *Response, err error) bool { + return err != nil || resp.StatusCode > 499 + }). + SetRetryHook(func(resp *Response, err error) { + resp.Request.SetQueryParam("attempt", strconv.Itoa(resp.Request.RetryAttempt)) + }). + SetFile("file", tests.GetTestFilePath("sample-file.txt")). + SetQueryParam("attempt", "0"). + Post("/file-text") + assertSuccess(t, resp, err) + assertEqual(t, 2, resp.Request.RetryAttempt) +} + func TestSetFile(t *testing.T) { filename := "sample-file.txt" resp := uploadTextFile(t, func(r *Request) { diff --git a/retry.go b/retry.go index 055d0df6..b79e0323 100644 --- a/retry.go +++ b/retry.go @@ -12,10 +12,10 @@ func defaultGetRetryInterval(resp *Response, attempt int) time.Duration { // RetryConditionFunc is a retry condition, which determines // whether the request should retry. -type RetryConditionFunc func(*Response, error) bool +type RetryConditionFunc func(resp *Response, err error) bool // RetryHookFunc is a retry hook which will be executed before a retry. -type RetryHookFunc func(*Response, error) +type RetryHookFunc func(resp *Response, err error) // GetRetryIntervalFunc is a function that determines how long should // sleep between retry attempts. From 06394d63902c00332d1e8d21f26c4773c1aee4f3 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Mar 2022 17:17:42 +0800 Subject: [PATCH 450/843] refactor TestGlobalWrapper --- client_wrapper_test.go | 139 +++++++++++++++++++ req_test.go | 299 +--------------------------------------- request_wrapper_test.go | 92 ++++++++++++- 3 files changed, 230 insertions(+), 300 deletions(-) create mode 100644 client_wrapper_test.go diff --git a/client_wrapper_test.go b/client_wrapper_test.go new file mode 100644 index 00000000..e8bf2edb --- /dev/null +++ b/client_wrapper_test.go @@ -0,0 +1,139 @@ +package req + +import ( + "crypto/tls" + "github.com/imroc/req/v3/internal/tests" + "net/http" + "net/url" + "os" + "path/filepath" + "testing" + "time" +) + +func TestGlobalWrapper(t *testing.T) { + EnableInsecureSkipVerify() + testGlobalWrapperSendMethods(t) + testGlobalWrapperMustSendMethods(t) + DisableInsecureSkipVerify() + + u, _ := url.Parse("http://dummy.proxy.local") + proxy := http.ProxyURL(u) + form := make(url.Values) + form.Add("test", "test") + + assertAllNotNil(t, + SetCookieJar(nil), + SetDialTLS(nil), + SetDial(nil), + SetTLSHandshakeTimeout(time.Second), + EnableAllowGetMethodPayload(), + DisableAllowGetMethodPayload(), + SetJsonMarshal(nil), + SetJsonUnmarshal(nil), + SetXmlMarshal(nil), + SetXmlUnmarshal(nil), + EnableTraceAll(), + DisableTraceAll(), + OnAfterResponse(func(client *Client, response *Response) error { + return nil + }), + OnBeforeRequest(func(client *Client, request *Request) error { + return nil + }), + SetProxyURL("http://dummy.proxy.local"), + SetProxyURL("bad url"), + SetProxy(proxy), + SetCommonContentType(jsonContentType), + SetCommonHeader("my-header", "my-value"), + SetCommonHeaders(map[string]string{ + "header1": "value1", + "header2": "value2", + }), + SetCommonBasicAuth("imroc", "123456"), + SetCommonBearerAuthToken("123456"), + SetUserAgent("test"), + SetTimeout(1*time.Second), + SetLogger(createDefaultLogger()), + SetScheme("https"), + EnableDebugLog(), + DisableDebugLog(), + SetCommonCookies(&http.Cookie{Name: "test", Value: "test"}), + SetCommonQueryString("test1=test1"), + SetCommonPathParams(map[string]string{"test1": "test1"}), + SetCommonPathParam("test2", "test2"), + AddCommonQueryParam("test1", "test11"), + SetCommonQueryParam("test1", "test111"), + SetCommonQueryParams(map[string]string{"test1": "test1"}), + EnableInsecureSkipVerify(), + DisableInsecureSkipVerify(), + DisableCompression(), + EnableCompression(), + DisableKeepAlives(), + EnableKeepAlives(), + SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")), + SetRootCertFromString(string(getTestFileContent(t, "sample-root.pem"))), + SetCerts(tls.Certificate{}, tls.Certificate{}), + SetCertFromFile( + tests.GetTestFilePath("sample-client.pem"), + tests.GetTestFilePath("sample-client-key.pem"), + ), + SetOutputDirectory(testDataPath), + SetBaseURL("http://dummy-req.local/test"), + SetCommonFormDataFromValues(form), + SetCommonFormData(map[string]string{"test2": "test2"}), + DisableAutoReadResponse(), + EnableAutoReadResponse(), + EnableDumpAll(), + EnableDumpAllAsync(), + EnableDumpAllWithoutBody(), + EnableDumpAllWithoutResponse(), + EnableDumpAllWithoutRequest(), + EnableDumpAllWithoutHeader(), + SetLogger(nil), + EnableDumpAllToFile(filepath.Join(testDataPath, "path-not-exists", "dump.out")), + EnableDumpAllToFile(tests.GetTestFilePath("tmpdump.out")), + SetCommonDumpOptions(&DumpOptions{ + RequestHeader: true, + }), + DisableDumpAll(), + SetRedirectPolicy(NoRedirectPolicy()), + EnableForceHTTP1(), + EnableForceHTTP2(), + DisableForceHttpVersion(), + SetAutoDecodeContentType("json"), + SetAutoDecodeContentTypeFunc(func(contentType string) bool { return true }), + SetAutoDecodeAllContentType(), + DisableAutoDecode(), + EnableAutoDecode(), + AddCommonRetryCondition(func(resp *Response, err error) bool { return true }), + SetCommonRetryCondition(func(resp *Response, err error) bool { return true }), + AddCommonRetryHook(func(resp *Response, err error) {}), + SetCommonRetryHook(func(resp *Response, err error) {}), + SetCommonRetryCount(2), + SetCommonRetryInterval(func(resp *Response, attempt int) time.Duration { + return 1 * time.Second + }), + SetCommonRetryBackoffInterval(1*time.Millisecond, 2*time.Second), + SetCommonRetryFixedInterval(1*time.Second), + SetUnixSocket("/var/run/custom.sock"), + ) + os.Remove(tests.GetTestFilePath("tmpdump.out")) + + config := GetTLSClientConfig() + assertEqual(t, config, DefaultClient().t.TLSClientConfig) + + r := R() + assertEqual(t, true, r != nil) + c := C() + + c.SetTimeout(10 * time.Second) + SetDefaultClient(c) + assertEqual(t, true, DefaultClient().httpClient.Timeout == 10*time.Second) + assertEqual(t, GetClient(), DefaultClient().httpClient) + + r = NewRequest() + assertEqual(t, true, r != nil) + c = NewClient() + assertEqual(t, true, c != nil) +} diff --git a/req_test.go b/req_test.go index ed9f8e3a..50bddd1c 100644 --- a/req_test.go +++ b/req_test.go @@ -1,12 +1,8 @@ package req import ( - "bytes" - "context" - "crypto/tls" "encoding/json" "encoding/xml" - "errors" "fmt" "github.com/imroc/req/v3/internal/tests" "go/token" @@ -14,17 +10,14 @@ import ( "golang.org/x/text/transform" "io" "io/ioutil" - "net" "net/http" "net/http/httptest" - "net/url" "os" "path/filepath" "reflect" "strings" "sync" "testing" - "time" "unsafe" ) @@ -83,6 +76,12 @@ func assertIsNil(t *testing.T, v interface{}) { } } +func assertAllNotNil(t *testing.T, vv ...interface{}) { + for _, v := range vv { + assertNotNil(t, v) + } +} + func assertNotNil(t *testing.T, v interface{}) { if isNil(v) { t.Fatalf("[%v] was expected to be non-nil", v) @@ -401,292 +400,6 @@ func assertIsError(t *testing.T, resp *Response, err error) { } } -func testGlobalWrapperEnableDumps(t *testing.T) { - testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { - *reqHeader = true - *reqBody = true - *respHeader = true - *respBody = true - return EnableDump() - }) - - testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { - *reqHeader = false - *reqBody = false - *respHeader = true - *respBody = true - return EnableDumpWithoutRequest() - }) - - testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { - *reqHeader = true - *reqBody = false - *respHeader = true - *respBody = true - return EnableDumpWithoutRequestBody() - }) - - testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { - *reqHeader = true - *reqBody = true - *respHeader = false - *respBody = false - return EnableDumpWithoutResponse() - }) - - testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { - *reqHeader = true - *reqBody = true - *respHeader = true - *respBody = false - return EnableDumpWithoutResponseBody() - }) - - testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { - *reqHeader = false - *reqBody = true - *respHeader = false - *respBody = true - return EnableDumpWithoutHeader() - }) - - testGlobalWrapperEnableDump(t, func(reqHeader, reqBody, respHeader, respBody *bool) *Request { - *reqHeader = true - *reqBody = false - *respHeader = true - *respBody = false - return EnableDumpWithoutBody() - }) - - buf := new(bytes.Buffer) - r := EnableDumpTo(buf) - assertEqual(t, true, r.getDumpOptions().Output != nil) - - dumpFile := tests.GetTestFilePath("req_tmp_dump.out") - r = EnableDumpToFile(tests.GetTestFilePath(dumpFile)) - assertEqual(t, true, r.getDumpOptions().Output != nil) - os.Remove(dumpFile) - - r = SetDumpOptions(&DumpOptions{ - RequestHeader: true, - }) - assertEqual(t, true, r.getDumpOptions().RequestHeader) -} - -func testGlobalWrapperEnableDump(t *testing.T, fn func(reqHeader, reqBody, respHeader, respBody *bool) *Request) { - var reqHeader, reqBody, respHeader, respBody bool - r := fn(&reqHeader, &reqBody, &respHeader, &respBody) - dump, ok := r.Context().Value(dumperKey).(*dumper) - if !ok { - t.Fatal("no dumper found in request context") - } - if reqHeader != dump.DumpOptions.RequestHeader { - t.Errorf("Unexpected RequestHeader dump option, expected [%v], got [%v]", reqHeader, dump.DumpOptions.RequestHeader) - } - if reqBody != dump.DumpOptions.RequestBody { - t.Errorf("Unexpected RequestBody dump option, expected [%v], got [%v]", reqBody, dump.DumpOptions.RequestBody) - } - if respHeader != dump.DumpOptions.ResponseHeader { - t.Errorf("Unexpected RequestHeader dump option, expected [%v], got [%v]", respHeader, dump.DumpOptions.ResponseHeader) - } - if respBody != dump.DumpOptions.ResponseBody { - t.Errorf("Unexpected RequestHeader dump option, expected [%v], got [%v]", respBody, dump.DumpOptions.ResponseBody) - } -} - -func testGlobalWrapperSendRequest(t *testing.T) { - testURL := getTestServerURL() + "/" - - resp, err := Put(testURL) - assertSuccess(t, resp, err) - assertEqual(t, "PUT", resp.Header.Get("Method")) - resp = MustPut(testURL) - assertEqual(t, "PUT", resp.Header.Get("Method")) - - resp, err = Patch(testURL) - assertSuccess(t, resp, err) - assertEqual(t, "PATCH", resp.Header.Get("Method")) - resp = MustPatch(testURL) - assertEqual(t, "PATCH", resp.Header.Get("Method")) - - resp, err = Delete(testURL) - assertSuccess(t, resp, err) - assertEqual(t, "DELETE", resp.Header.Get("Method")) - resp = MustDelete(testURL) - assertEqual(t, "DELETE", resp.Header.Get("Method")) - - resp, err = Options(testURL) - assertSuccess(t, resp, err) - assertEqual(t, "OPTIONS", resp.Header.Get("Method")) - resp = MustOptions(testURL) - assertEqual(t, "OPTIONS", resp.Header.Get("Method")) - - resp, err = Head(testURL) - assertSuccess(t, resp, err) - assertEqual(t, "HEAD", resp.Header.Get("Method")) - resp = MustHead(testURL) - assertEqual(t, "HEAD", resp.Header.Get("Method")) - - resp, err = Get(testURL) - assertSuccess(t, resp, err) - assertEqual(t, "GET", resp.Header.Get("Method")) - resp = MustGet(testURL) - assertEqual(t, "GET", resp.Header.Get("Method")) - - resp, err = Post(testURL) - assertSuccess(t, resp, err) - assertEqual(t, "POST", resp.Header.Get("Method")) - resp = MustPost(testURL) - assertEqual(t, "POST", resp.Header.Get("Method")) -} - -func testGlobalClientSettingWrapper(t *testing.T, cs ...*Client) { - for _, c := range cs { - assertNotNil(t, c) - } -} - -func TestGlobalWrapper(t *testing.T) { - EnableInsecureSkipVerify() - testGlobalWrapperSendRequest(t) - testGlobalWrapperEnableDumps(t) - DisableInsecureSkipVerify() - - testErr := errors.New("test") - testDial := func(ctx context.Context, network, addr string) (net.Conn, error) { - return nil, testErr - } - testDialTLS := func(ctx context.Context, network, addr string) (net.Conn, error) { - return nil, testErr - } - - marshalFunc := func(v interface{}) ([]byte, error) { - return nil, testErr - } - unmarshalFunc := func(data []byte, v interface{}) error { - return testErr - } - u, _ := url.Parse("http://dummy.proxy.local") - proxy := http.ProxyURL(u) - form := make(url.Values) - form.Add("test", "test") - - testGlobalClientSettingWrapper(t, - SetCookieJar(nil), - SetDialTLS(testDialTLS), - SetDial(testDial), - SetTLSHandshakeTimeout(time.Second), - EnableAllowGetMethodPayload(), - DisableAllowGetMethodPayload(), - SetJsonMarshal(marshalFunc), - SetJsonUnmarshal(unmarshalFunc), - SetXmlMarshal(marshalFunc), - SetXmlUnmarshal(unmarshalFunc), - EnableTraceAll(), - DisableTraceAll(), - OnAfterResponse(func(client *Client, response *Response) error { - return nil - }), - OnBeforeRequest(func(client *Client, request *Request) error { - return nil - }), - SetProxyURL("http://dummy.proxy.local"), - SetProxyURL("bad url"), - SetProxy(proxy), - SetCommonContentType(jsonContentType), - SetCommonHeader("my-header", "my-value"), - SetCommonHeaders(map[string]string{ - "header1": "value1", - "header2": "value2", - }), - SetCommonBasicAuth("imroc", "123456"), - SetCommonBearerAuthToken("123456"), - SetUserAgent("test"), - SetTimeout(1*time.Second), - SetLogger(createDefaultLogger()), - SetScheme("https"), - EnableDebugLog(), - DisableDebugLog(), - SetCommonCookies(&http.Cookie{Name: "test", Value: "test"}), - SetCommonQueryString("test1=test1"), - SetCommonPathParams(map[string]string{"test1": "test1"}), - SetCommonPathParam("test2", "test2"), - AddCommonQueryParam("test1", "test11"), - SetCommonQueryParam("test1", "test111"), - SetCommonQueryParams(map[string]string{"test1": "test1"}), - EnableInsecureSkipVerify(), - DisableInsecureSkipVerify(), - DisableCompression(), - EnableCompression(), - DisableKeepAlives(), - EnableKeepAlives(), - SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")), - SetRootCertFromString(string(getTestFileContent(t, "sample-root.pem"))), - SetCerts(tls.Certificate{}, tls.Certificate{}), - SetCertFromFile( - tests.GetTestFilePath("sample-client.pem"), - tests.GetTestFilePath("sample-client-key.pem"), - ), - SetOutputDirectory(testDataPath), - SetBaseURL("http://dummy-req.local/test"), - SetCommonFormDataFromValues(form), - SetCommonFormData(map[string]string{"test2": "test2"}), - DisableAutoReadResponse(), - EnableAutoReadResponse(), - EnableDumpAll(), - EnableDumpAllAsync(), - EnableDumpAllWithoutBody(), - EnableDumpAllWithoutResponse(), - EnableDumpAllWithoutRequest(), - EnableDumpAllWithoutHeader(), - SetLogger(nil), - EnableDumpAllToFile(filepath.Join(testDataPath, "path-not-exists", "dump.out")), - EnableDumpAllToFile(tests.GetTestFilePath("tmpdump.out")), - SetCommonDumpOptions(&DumpOptions{ - RequestHeader: true, - }), - DisableDumpAll(), - SetRedirectPolicy(NoRedirectPolicy()), - EnableForceHTTP1(), - EnableForceHTTP2(), - DisableForceHttpVersion(), - SetAutoDecodeContentType("json"), - SetAutoDecodeContentTypeFunc(func(contentType string) bool { return true }), - SetAutoDecodeAllContentType(), - DisableAutoDecode(), - EnableAutoDecode(), - AddCommonRetryCondition(func(resp *Response, err error) bool { return true }), - SetCommonRetryCondition(func(resp *Response, err error) bool { return true }), - AddCommonRetryHook(func(resp *Response, err error) {}), - SetCommonRetryHook(func(resp *Response, err error) {}), - SetCommonRetryCount(2), - SetCommonRetryInterval(func(resp *Response, attempt int) time.Duration { - return 1 * time.Second - }), - SetCommonRetryBackoffInterval(1*time.Millisecond, 2*time.Second), - SetCommonRetryFixedInterval(1*time.Second), - SetUnixSocket("/var/run/custom.sock"), - ) - os.Remove(tests.GetTestFilePath("tmpdump.out")) - - config := GetTLSClientConfig() - assertEqual(t, config, DefaultClient().t.TLSClientConfig) - - r := R() - assertEqual(t, true, r != nil) - c := C() - - c.SetTimeout(10 * time.Second) - SetDefaultClient(c) - assertEqual(t, true, DefaultClient().httpClient.Timeout == 10*time.Second) - assertEqual(t, GetClient(), DefaultClient().httpClient) - - r = NewRequest() - assertEqual(t, true, r != nil) - c = NewClient() - assertEqual(t, true, c != nil) -} - func TestTrailer(t *testing.T) { resp, err := tc().EnableForceHTTP1().R().Get("/chunked") assertSuccess(t, resp, err) diff --git a/request_wrapper_test.go b/request_wrapper_test.go index cbe4e7d5..8f0e6a5c 100644 --- a/request_wrapper_test.go +++ b/request_wrapper_test.go @@ -12,14 +12,8 @@ func init() { SetLogger(nil) // disable log } -func assertRequestNotNil(t *testing.T, rs ...*Request) { - for _, r := range rs { - assertNotNil(t, r) - } -} - func TestGlobalWrapperForRequestSettings(t *testing.T) { - assertRequestNotNil(t, + assertAllNotNil(t, SetFiles(map[string]string{"test": "test"}), SetFile("test", "test"), SetFileReader("test", "test.txt", bytes.NewBufferString("test")), @@ -76,3 +70,87 @@ func TestGlobalWrapperForRequestSettings(t *testing.T) { SetContext(context.Background()), ) } + +func testGlobalWrapperMustSendMethods(t *testing.T) { + testCases := []struct { + SendReq func(string) *Response + ExpectMethod string + }{ + { + SendReq: MustGet, + ExpectMethod: "GET", + }, + { + SendReq: MustPost, + ExpectMethod: "POST", + }, + { + SendReq: MustPatch, + ExpectMethod: "PATCH", + }, + { + SendReq: MustPut, + ExpectMethod: "PUT", + }, + { + SendReq: MustDelete, + ExpectMethod: "DELETE", + }, + { + SendReq: MustOptions, + ExpectMethod: "OPTIONS", + }, + { + SendReq: MustHead, + ExpectMethod: "HEAD", + }, + } + url := getTestServerURL() + "/" + for _, tc := range testCases { + resp := tc.SendReq(url) + assertNotNil(t, resp.Response) + assertEqual(t, tc.ExpectMethod, resp.Header.Get("Method")) + } +} + +func testGlobalWrapperSendMethods(t *testing.T) { + testCases := []struct { + SendReq func(string) (*Response, error) + ExpectMethod string + }{ + { + SendReq: Get, + ExpectMethod: "GET", + }, + { + SendReq: Post, + ExpectMethod: "POST", + }, + { + SendReq: Patch, + ExpectMethod: "PATCH", + }, + { + SendReq: Put, + ExpectMethod: "PUT", + }, + { + SendReq: Delete, + ExpectMethod: "DELETE", + }, + { + SendReq: Options, + ExpectMethod: "OPTIONS", + }, + { + SendReq: Head, + ExpectMethod: "HEAD", + }, + } + url := getTestServerURL() + "/" + for _, tc := range testCases { + resp, err := tc.SendReq(url) + assertSuccess(t, resp, err) + assertEqual(t, tc.ExpectMethod, resp.Header.Get("Method")) + } +} From ca6f36d6cd8b7bfaafd2ddc8b18c449ef12a6457 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Mar 2022 17:53:21 +0800 Subject: [PATCH 451/843] reuse dump.Output if missing in SetCommonDumpOptions, and refactor TestEnableDumpAll --- client.go | 7 ++ client_test.go | 175 +++++++++++++++++++++---------------------------- 2 files changed, 82 insertions(+), 100 deletions(-) diff --git a/client.go b/client.go index a2194abc..875d39e6 100644 --- a/client.go +++ b/client.go @@ -613,6 +613,13 @@ func (c *Client) SetCommonDumpOptions(opt *DumpOptions) *Client { if opt == nil { return c } + if opt.Output == nil { + if c.dumpOptions != nil { + opt.Output = c.dumpOptions.Output + } else { + opt.Output = os.Stdout + } + } c.dumpOptions = opt if c.t.dump != nil { c.t.dump.DumpOptions = opt diff --git a/client_test.go b/client_test.go index b77ae782..61642186 100644 --- a/client_test.go +++ b/client_test.go @@ -440,113 +440,88 @@ func testDisableAutoReadResponse(t *testing.T, c *Client) { assertNoError(t, err) } -func TestEnableDumpAll(t *testing.T) { - testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { - c.EnableDumpAll() - *reqHeader = true - *reqBody = true - *respHeader = true - *respBody = true - }) -} - -func TestEnableDumpAllWithoutRequest(t *testing.T) { - testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { - c.EnableDumpAllWithoutRequest() - *reqHeader = false - *reqBody = false - *respHeader = true - *respBody = true - }) -} - -func TestEnableDumpAllWithoutRequestBody(t *testing.T) { - testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { - c.EnableDumpAllWithoutRequestBody() - *reqHeader = true - *reqBody = false - *respHeader = true - *respBody = true - }) -} - -func TestEnableDumpAllWithoutResponse(t *testing.T) { - testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { - c.EnableDumpAllWithoutResponse() - *reqHeader = true - *reqBody = true - *respHeader = false - *respBody = false - }) -} - -func TestEnableDumpAllWithoutResponseBody(t *testing.T) { - testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { - c.EnableDumpAllWithoutResponseBody() - *reqHeader = true - *reqBody = true - *respHeader = true - *respBody = false - }) -} - -func TestEnableDumpAllWithoutHeader(t *testing.T) { - testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { - c.EnableDumpAllWithoutHeader() - *reqHeader = false - *reqBody = true - *respHeader = false - *respBody = true - }) -} - -func TestEnableDumpAllWithoutBody(t *testing.T) { - testEnableDumpAll(t, func(c *Client, reqHeader, reqBody, respHeader, respBody *bool) { - c.EnableDumpAllWithoutBody() - *reqHeader = true - *reqBody = false - *respHeader = true - *respBody = false - }) -} - -func testEnableDumpAll(t *testing.T, fn func(c *Client, reqHeader, reqBody, respHeader, respBody *bool)) { - buf := new(bytes.Buffer) - c := tc().EnableDumpAllTo(buf) +func testEnableDumpAll(t *testing.T, fn func(c *Client) (de dumpExpected)) { testDump := func(c *Client) { - var reqHeader, reqBody, respHeader, respBody bool - fn(c, &reqHeader, &reqBody, &respHeader, &respBody) - resp, err := c.R().SetBody(`test body`).Post("/") + buff := new(bytes.Buffer) + c.EnableDumpAllTo(buff) + r := c.R() + de := fn(c) + resp, err := r.SetBody(`test body`).Post("/") assertSuccess(t, resp, err) - dump := buf.String() - assertContains(t, dump, "user-agent", reqHeader) - assertContains(t, dump, "test body", reqBody) - assertContains(t, dump, "date", respHeader) - assertContains(t, dump, "testpost: text response", respBody) + dump := buff.String() + assertContains(t, dump, "user-agent", de.ReqHeader) + assertContains(t, dump, "test body", de.ReqBody) + assertContains(t, dump, "date", de.RespHeader) + assertContains(t, dump, "testpost: text response", de.RespBody) } + c := tc() testDump(c) - buf = new(bytes.Buffer) - c = tc().EnableDumpAllTo(buf).EnableForceHTTP1() - testDump(c) + testDump(c.EnableForceHTTP1()) } -func TestSetCommonDumpOptions(t *testing.T) { - c := tc() - buf := new(bytes.Buffer) - opt := &DumpOptions{ - RequestHeader: true, - RequestBody: false, - ResponseHeader: false, - ResponseBody: true, - Output: buf, +func TestEnableDumpAll(t *testing.T) { + testCases := []func(c *Client) (d dumpExpected){ + func(c *Client) (de dumpExpected) { + c.EnableDumpAll() + de.ReqHeader = true + de.ReqBody = true + de.RespHeader = true + de.RespBody = true + return + }, + func(c *Client) (de dumpExpected) { + c.EnableDumpAllWithoutHeader() + de.ReqBody = true + de.RespBody = true + return + }, + func(c *Client) (de dumpExpected) { + c.EnableDumpAllWithoutBody() + de.ReqHeader = true + de.RespHeader = true + return + }, + func(c *Client) (de dumpExpected) { + c.EnableDumpAllWithoutRequest() + de.RespHeader = true + de.RespBody = true + return + }, + func(c *Client) (de dumpExpected) { + c.EnableDumpAllWithoutRequestBody() + de.ReqHeader = true + de.RespHeader = true + de.RespBody = true + return + }, + func(c *Client) (de dumpExpected) { + c.EnableDumpAllWithoutResponse() + de.ReqHeader = true + de.ReqBody = true + return + }, + func(c *Client) (de dumpExpected) { + c.EnableDumpAllWithoutResponseBody() + de.ReqHeader = true + de.ReqBody = true + de.RespHeader = true + return + }, + func(c *Client) (de dumpExpected) { + c.SetCommonDumpOptions(&DumpOptions{ + RequestHeader: true, + RequestBody: true, + ResponseBody: true, + }).EnableDumpAll() + de.ReqHeader = true + de.ReqBody = true + de.RespBody = true + return + }, + } + for _, fn := range testCases { + testEnableDumpAll(t, fn) } - c.SetCommonDumpOptions(opt).EnableDumpAll() - resp, err := c.R().SetBody("test body").Post("/") - assertSuccess(t, resp, err) - assertContains(t, buf.String(), "user-agent", true) - assertContains(t, buf.String(), "test body", false) - assertContains(t, buf.String(), "date", false) - assertContains(t, buf.String(), "testpost: text response", true) } func TestEnableDumpAllToFile(t *testing.T) { From e9153b0a97b03d7645104a48e1469678e46de203 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Mar 2022 18:02:16 +0800 Subject: [PATCH 452/843] improve TestDisableAutoReadResponse --- client_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/client_test.go b/client_test.go index 61642186..df00b2b2 100644 --- a/client_test.go +++ b/client_test.go @@ -421,8 +421,7 @@ func TestClientClone(t *testing.T) { } func TestDisableAutoReadResponse(t *testing.T) { - testDisableAutoReadResponse(t, tc()) - testDisableAutoReadResponse(t, tc().EnableForceHTTP1()) + testWithAllTransport(t, testDisableAutoReadResponse) } func testDisableAutoReadResponse(t *testing.T, c *Client) { From 2279bb9c15f5d695fc5f7d2661b722004ab2a498 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Mar 2022 21:06:28 +0800 Subject: [PATCH 453/843] remove unused code in transferWriter --- transfer.go | 82 +++++++++++++++++++++++------------------------------ 1 file changed, 35 insertions(+), 47 deletions(-) diff --git a/transfer.go b/transfer.go index d78bdace..05b8b3f7 100644 --- a/transfer.go +++ b/transfer.go @@ -58,69 +58,57 @@ type transferWriter struct { Method string Body io.Reader BodyCloser io.Closer - ResponseToHEAD bool ContentLength int64 // -1 means unknown, 0 means exactly none Close bool TransferEncoding []string Header http.Header Trailer http.Header - IsResponse bool bodyReadError error // any non-EOF error from reading Body FlushHeaders bool // flush headers to network before body ByteReadCh chan readResult // non-nil if probeRequestBody called } -func newTransferWriter(r interface{}) (t *transferWriter, err error) { +func newTransferWriter(r *http.Request) (t *transferWriter, err error) { t = &transferWriter{} // Extract relevant fields atLeastHTTP11 := false - switch rr := r.(type) { - case *http.Request: - if rr.ContentLength != 0 && rr.Body == nil { - return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength) - } - t.Method = valueOrDefault(rr.Method, "GET") - t.Close = rr.Close - t.TransferEncoding = rr.TransferEncoding - t.Header = rr.Header - t.Trailer = rr.Trailer - t.Body = rr.Body - t.BodyCloser = rr.Body - t.ContentLength = outgoingLength(rr) - if t.ContentLength < 0 && len(t.TransferEncoding) == 0 && t.shouldSendChunkedRequestBody() { - t.TransferEncoding = []string{"chunked"} - } - // If there's a body, conservatively flush the headers - // to any bufio.Writer we're writing to, just in case - // the server needs the headers early, before we copy - // the body and possibly block. We make an exception - // for the common standard library in-memory types, - // though, to avoid unnecessary TCP packets on the - // wire. (Issue 22088.) - if t.ContentLength != 0 && !isKnownInMemoryReader(t.Body) { - t.FlushHeaders = true - } - - atLeastHTTP11 = true // Transport requests are always 1.1 or 2.0 + if r.ContentLength != 0 && r.Body == nil { + return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", r.ContentLength) + } + t.Method = valueOrDefault(r.Method, "GET") + t.Close = r.Close + t.TransferEncoding = r.TransferEncoding + t.Header = r.Header + t.Trailer = r.Trailer + t.Body = r.Body + t.BodyCloser = r.Body + t.ContentLength = outgoingLength(r) + if t.ContentLength < 0 && len(t.TransferEncoding) == 0 && t.shouldSendChunkedRequestBody() { + t.TransferEncoding = []string{"chunked"} + } + // If there's a body, conservatively flush the headers + // to any bufio.Writer we're writing to, just in case + // the server needs the headers early, before we copy + // the body and possibly block. We make an exception + // for the common standard library in-memory types, + // though, to avoid unnecessary TCP packets on the + // wire. (Issue 22088.) + if t.ContentLength != 0 && !isKnownInMemoryReader(t.Body) { + t.FlushHeaders = true } + atLeastHTTP11 = true // Transport requests are always 1.1 or 2.0 + // Sanitize Body,ContentLength,TransferEncoding - if t.ResponseToHEAD { - t.Body = nil - if chunked(t.TransferEncoding) { - t.ContentLength = -1 - } - } else { - if !atLeastHTTP11 || t.Body == nil { - t.TransferEncoding = nil - } - if chunked(t.TransferEncoding) { - t.ContentLength = -1 - } else if t.Body == nil { // no chunking, no body - t.ContentLength = 0 - } + if !atLeastHTTP11 || t.Body == nil { + t.TransferEncoding = nil + } + if chunked(t.TransferEncoding) { + t.ContentLength = -1 + } else if t.Body == nil { // no chunking, no body + t.ContentLength = 0 } // Sanitize Trailer @@ -342,7 +330,7 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dumper) (err error) { if t.Body != nil { var body = t.unwrapBody() if chunked(t.TransferEncoding) { - if bw, ok := rw.(*bufio.Writer); ok && !t.IsResponse { + if bw, ok := rw.(*bufio.Writer); ok { rw = &internal.FlushAfterChunkWriter{Writer: bw} } cw := internal.NewChunkedWriter(rw) @@ -386,7 +374,7 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dumper) (err error) { } } - if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy { + if t.ContentLength != -1 && t.ContentLength != ncopy { return fmt.Errorf("http: ContentLength=%d with Body length %d", t.ContentLength, ncopy) } From 69fc4bbb230a6295e6c7e7a4f0d4812a5c781c80 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 21 Mar 2022 14:24:36 +0800 Subject: [PATCH 454/843] update go.mod: require go1.15 --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 354cd3f8..45abb226 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/imroc/req/v3 -go 1.13 +go 1.15 require ( github.com/hashicorp/go-multierror v1.1.1 From f172c24f15bfe3fe356528e9350c98ea9d519030 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 22 Mar 2022 19:58:55 +0800 Subject: [PATCH 455/843] support js && wasm --- roundtrip_js.go | 334 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 334 insertions(+) create mode 100644 roundtrip_js.go diff --git a/roundtrip_js.go b/roundtrip_js.go new file mode 100644 index 00000000..9d8f3e4a --- /dev/null +++ b/roundtrip_js.go @@ -0,0 +1,334 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build js && wasm +// +build js,wasm + +package req + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "strconv" + "syscall/js" +) + +var uint8Array = js.Global().Get("Uint8Array") + +// jsFetchMode is a Request.Header map key that, if present, +// signals that the map entry is actually an option to the Fetch API mode setting. +// Valid values are: "cors", "no-cors", "same-origin", "navigate" +// The default is "same-origin". +// +// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters +const jsFetchMode = "js.fetch:mode" + +// jsFetchCreds is a Request.Header map key that, if present, +// signals that the map entry is actually an option to the Fetch API credentials setting. +// Valid values are: "omit", "same-origin", "include" +// The default is "same-origin". +// +// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters +const jsFetchCreds = "js.fetch:credentials" + +// jsFetchRedirect is a Request.Header map key that, if present, +// signals that the map entry is actually an option to the Fetch API redirect setting. +// Valid values are: "follow", "error", "manual" +// The default is "follow". +// +// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters +const jsFetchRedirect = "js.fetch:redirect" + +// jsFetchMissing will be true if the Fetch API is not present in +// the browser globals. +var jsFetchMissing = js.Global().Get("fetch").IsUndefined() + +// RoundTrip implements the RoundTripper interface using the WHATWG Fetch API. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + // The Transport has a documented contract that states that if the DialContext or + // DialTLSContext functions are set, they will be used to set up the connections. + // If they aren't set then the documented contract is to use Dial or DialTLS, even + // though they are deprecated. Therefore, if any of these are set, we should obey + // the contract and dial using the regular round-trip instead. Otherwise, we'll try + // to fall back on the Fetch API, unless it's not available. + if t.DialContext != nil || t.DialTLSContext != nil || jsFetchMissing { + return t.roundTrip(req) + } + + ac := js.Global().Get("AbortController") + if !ac.IsUndefined() { + // Some browsers that support WASM don't necessarily support + // the AbortController. See + // https://developer.mozilla.org/en-US/docs/Web/API/AbortController#Browser_compatibility. + ac = ac.New() + } + + opt := js.Global().Get("Object").New() + // See https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch + // for options available. + opt.Set("method", req.Method) + opt.Set("credentials", "same-origin") + if h := req.Header.Get(jsFetchCreds); h != "" { + opt.Set("credentials", h) + req.Header.Del(jsFetchCreds) + } + if h := req.Header.Get(jsFetchMode); h != "" { + opt.Set("mode", h) + req.Header.Del(jsFetchMode) + } + if h := req.Header.Get(jsFetchRedirect); h != "" { + opt.Set("redirect", h) + req.Header.Del(jsFetchRedirect) + } + if !ac.IsUndefined() { + opt.Set("signal", ac.Get("signal")) + } + headers := js.Global().Get("Headers").New() + for key, values := range req.Header { + for _, value := range values { + headers.Call("append", key, value) + } + } + opt.Set("headers", headers) + + if req.Body != nil { + // TODO(johanbrandhorst): Stream request body when possible. + // See https://bugs.chromium.org/p/chromium/issues/detail?id=688906 for Blink issue. + // See https://bugzilla.mozilla.org/show_bug.cgi?id=1387483 for Firefox issue. + // See https://github.com/web-platform-tests/wpt/issues/7693 for WHATWG tests issue. + // See https://developer.mozilla.org/en-US/docs/Web/API/Streams_API for more details on the Streams API + // and browser support. + body, err := ioutil.ReadAll(req.Body) + if err != nil { + req.Body.Close() // RoundTrip must always close the body, including on errors. + return nil, err + } + req.Body.Close() + if len(body) != 0 { + buf := uint8Array.New(len(body)) + js.CopyBytesToJS(buf, body) + opt.Set("body", buf) + } + } + + fetchPromise := js.Global().Call("fetch", req.URL.String(), opt) + var ( + respCh = make(chan *http.Response, 1) + errCh = make(chan error, 1) + success, failure js.Func + ) + success = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success.Release() + failure.Release() + + result := args[0] + header := http.Header{} + // https://developer.mozilla.org/en-US/docs/Web/API/Headers/entries + headersIt := result.Get("headers").Call("entries") + for { + n := headersIt.Call("next") + if n.Get("done").Bool() { + break + } + pair := n.Get("value") + key, value := pair.Index(0).String(), pair.Index(1).String() + ck := http.CanonicalHeaderKey(key) + header[ck] = append(header[ck], value) + } + + contentLength := int64(0) + clHeader := header.Get("Content-Length") + switch { + case clHeader != "": + cl, err := strconv.ParseInt(clHeader, 10, 64) + if err != nil { + errCh <- fmt.Errorf("net/http: ill-formed Content-Length header: %v", err) + return nil + } + if cl < 0 { + // Content-Length values less than 0 are invalid. + // See: https://datatracker.ietf.org/doc/html/rfc2616/#section-14.13 + errCh <- fmt.Errorf("net/http: invalid Content-Length header: %q", clHeader) + return nil + } + contentLength = cl + default: + // If the response length is not declared, set it to -1. + contentLength = -1 + } + + b := result.Get("body") + var body io.ReadCloser + // The body is undefined when the browser does not support streaming response bodies (Firefox), + // and null in certain error cases, i.e. when the request is blocked because of CORS settings. + if !b.IsUndefined() && !b.IsNull() { + body = &streamReader{stream: b.Call("getReader")} + } else { + // Fall back to using ArrayBuffer + // https://developer.mozilla.org/en-US/docs/Web/API/Body/arrayBuffer + body = &arrayReader{arrayPromise: result.Call("arrayBuffer")} + } + + code := result.Get("status").Int() + respCh <- &http.Response{ + Status: fmt.Sprintf("%d %s", code, http.StatusText(code)), + StatusCode: code, + Header: header, + ContentLength: contentLength, + Body: body, + Request: req, + } + + return nil + }) + failure = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success.Release() + failure.Release() + errCh <- fmt.Errorf("net/http: fetch() failed: %s", args[0].Get("message").String()) + return nil + }) + + fetchPromise.Call("then", success, failure) + select { + case <-req.Context().Done(): + if !ac.IsUndefined() { + // Abort the Fetch request. + ac.Call("abort") + } + return nil, req.Context().Err() + case resp := <-respCh: + return resp, nil + case err := <-errCh: + return nil, err + } +} + +var errClosed = errors.New("net/http: reader is closed") + +// streamReader implements an io.ReadCloser wrapper for ReadableStream. +// See https://fetch.spec.whatwg.org/#readablestream for more information. +type streamReader struct { + pending []byte + stream js.Value + err error // sticky read error +} + +func (r *streamReader) Read(p []byte) (n int, err error) { + if r.err != nil { + return 0, r.err + } + if len(r.pending) == 0 { + var ( + bCh = make(chan []byte, 1) + errCh = make(chan error, 1) + ) + success := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + result := args[0] + if result.Get("done").Bool() { + errCh <- io.EOF + return nil + } + value := make([]byte, result.Get("value").Get("byteLength").Int()) + js.CopyBytesToGo(value, result.Get("value")) + bCh <- value + return nil + }) + defer success.Release() + failure := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + // Assumes it's a TypeError. See + // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError + // for more information on this type. See + // https://streams.spec.whatwg.org/#byob-reader-read for the spec on + // the read method. + errCh <- errors.New(args[0].Get("message").String()) + return nil + }) + defer failure.Release() + r.stream.Call("read").Call("then", success, failure) + select { + case b := <-bCh: + r.pending = b + case err := <-errCh: + r.err = err + return 0, err + } + } + n = copy(p, r.pending) + r.pending = r.pending[n:] + return n, nil +} + +func (r *streamReader) Close() error { + // This ignores any error returned from cancel method. So far, I did not encounter any concrete + // situation where reporting the error is meaningful. Most users ignore error from resp.Body.Close(). + // If there's a need to report error here, it can be implemented and tested when that need comes up. + r.stream.Call("cancel") + if r.err == nil { + r.err = errClosed + } + return nil +} + +// arrayReader implements an io.ReadCloser wrapper for ArrayBuffer. +// https://developer.mozilla.org/en-US/docs/Web/API/Body/arrayBuffer. +type arrayReader struct { + arrayPromise js.Value + pending []byte + read bool + err error // sticky read error +} + +func (r *arrayReader) Read(p []byte) (n int, err error) { + if r.err != nil { + return 0, r.err + } + if !r.read { + r.read = true + var ( + bCh = make(chan []byte, 1) + errCh = make(chan error, 1) + ) + success := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + // Wrap the input ArrayBuffer with a Uint8Array + uint8arrayWrapper := uint8Array.New(args[0]) + value := make([]byte, uint8arrayWrapper.Get("byteLength").Int()) + js.CopyBytesToGo(value, uint8arrayWrapper) + bCh <- value + return nil + }) + defer success.Release() + failure := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + // Assumes it's a TypeError. See + // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError + // for more information on this type. + // See https://fetch.spec.whatwg.org/#concept-body-consume-body for reasons this might error. + errCh <- errors.New(args[0].Get("message").String()) + return nil + }) + defer failure.Release() + r.arrayPromise.Call("then", success, failure) + select { + case b := <-bCh: + r.pending = b + case err := <-errCh: + return 0, err + } + } + if len(r.pending) == 0 { + return 0, io.EOF + } + n = copy(p, r.pending) + r.pending = r.pending[n:] + return n, nil +} + +func (r *arrayReader) Close() error { + if r.err == nil { + r.err = errClosed + } + return nil +} From e967d3f29f33442f907478d07f9895ae0a2d8e21 Mon Sep 17 00:00:00 2001 From: Fuad Olatunji <65264054+fuadop@users.noreply.github.com> Date: Wed, 23 Mar 2022 10:59:33 +0100 Subject: [PATCH 456/843] fix typo on readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9771515d..7d2aef28 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,7 @@ Checkout [Quick API Reference](docs/api.md) for a brief and categorized list of **Examples** -Checkout more examples below or runnable examples in the [examples](examples) direcotry. +Checkout more examples below or runnable examples in the [examples](examples) directory. ## Debugging - Dump/Log/Trace From 780943c24a57291847c677c5aaaceb20fd582c5d Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 25 Mar 2022 21:50:53 +0800 Subject: [PATCH 457/843] support AlwaysCopyHeaderRedirectPolicy --- client_test.go | 11 +++++++++++ redirect.go | 23 +++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/client_test.go b/client_test.go index df00b2b2..83123093 100644 --- a/client_test.go +++ b/client_test.go @@ -329,6 +329,17 @@ func TestRedirect(t *testing.T) { _, err = tc().SetRedirectPolicy(AllowedDomainRedirectPolicy("localhost", "127.0.0.1")).R().Get("/redirect-to-other") assertNotNil(t, err) assertContains(t, err.Error(), "redirect domain [dummy.local] is not allowed", true) + + c := tc().SetRedirectPolicy(AlwaysCopyHeaderRedirectPolicy("Authorization")) + newHeader := make(http.Header) + oldHeader := make(http.Header) + oldHeader.Set("Authorization", "test") + c.GetClient().CheckRedirect(&http.Request{ + Header: newHeader, + }, []*http.Request{&http.Request{ + Header: oldHeader, + }}) + assertEqual(t, "test", newHeader.Get("Authorization")) } func TestGetTLSClientConfig(t *testing.T) { diff --git a/redirect.go b/redirect.go index 06ce5a6e..4f3b71e9 100644 --- a/redirect.go +++ b/redirect.go @@ -103,3 +103,26 @@ func getDomain(host string) string { ss = ss[1:] return strings.Join(ss, ".") } + +// AlwaysCopyHeaderRedirectPolicy ensures that the given sensitive headers will +// always be copied on redirect. +// By default, golang will copy all of the original request's headers on redirect, +// unless they're sensitive, like "Authorization" or "Www-Authenticate". Only send +// sensitive ones to the same origin, or subdomains thereof (https://go-review.googlesource.com/c/go/+/28930/) +// Check discussion: https://github.com/golang/go/issues/4800 +// For example: +// client.SetRedirectPolicy(req.AlwaysCopyHeaderRedirectPolicy("Authorization")) +func AlwaysCopyHeaderRedirectPolicy(headers ...string) RedirectPolicy { + return func(req *http.Request, via []*http.Request) error { + for _, header := range headers { + if len(req.Header.Values(header)) > 0 { + continue + } + vals := via[0].Header.Values(header) + for _, val := range vals { + req.Header.Add(header, val) + } + } + return nil + } +} From 2d22ff02c08b9be1f7bb82d07fb440b3caad2bbd Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 27 Mar 2022 15:21:55 +0800 Subject: [PATCH 458/843] support upload callback(#104) --- README.md | 23 ++++++ docs/api.md | 2 + examples/uploadcallback/README.md | 19 +++++ examples/uploadcallback/uploadclient/go.mod | 7 ++ examples/uploadcallback/uploadclient/go.sum | 13 ++++ examples/uploadcallback/uploadclient/main.go | 47 ++++++++++++ examples/uploadcallback/uploadserver/go.mod | 5 ++ examples/uploadcallback/uploadserver/go.sum | 54 ++++++++++++++ examples/uploadcallback/uploadserver/main.go | 18 +++++ middleware.go | 76 ++++++++++++++++++-- req.go | 19 +++++ request.go | 66 +++++++++++------ request_test.go | 36 +++++++++- request_wrapper.go | 12 ++++ request_wrapper_test.go | 8 ++- 15 files changed, 374 insertions(+), 31 deletions(-) create mode 100644 examples/uploadcallback/README.md create mode 100644 examples/uploadcallback/uploadclient/go.mod create mode 100644 examples/uploadcallback/uploadclient/go.sum create mode 100644 examples/uploadcallback/uploadclient/main.go create mode 100644 examples/uploadcallback/uploadserver/go.mod create mode 100644 examples/uploadcallback/uploadserver/go.sum create mode 100644 examples/uploadcallback/uploadserver/main.go diff --git a/README.md b/README.md index 7d2aef28..76758473 100644 --- a/README.md +++ b/README.md @@ -781,6 +781,29 @@ client.R().SetFileReader("avatar", "avatar.png", avatarImgFile).Post(url) */ ``` +**Upload Callback** + +You can set `UploadCallback` if you want to show upload progress: + +```go +client := req.C() +client.R(). + SetFile("excel", "test.xlsx"). + SetUploadCallback(func(info req.UploadInfo) { + fmt.Printf("%q uploaded %.2f%%\n", info.FileName, float64(info.UploadedSize)/float64(info.FileSize)*100.0) + }).Post("https://exmaple.com/upload") +/* Output +"test.xlsx" uploaded 7.44% +"test.xlsx" uploaded 29.78% +"test.xlsx" uploaded 52.08% +"test.xlsx" uploaded 74.47% +"test.xlsx" uploaded 96.87% +"test.xlsx" uploaded 100.00% +*/ +``` + +> `UploadCallback` will be invoked at least every 200ms by default, you can customize the minimal invoke interval using `SetUploadCallbackWithInterval`. + ## Auto-Decode `Req` detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default. diff --git a/docs/api.md b/docs/api.md index 65861129..491329b7 100644 --- a/docs/api.md +++ b/docs/api.md @@ -203,6 +203,8 @@ Basically, you can know the meaning of most settings directly from the method na * [SetFileBytes(paramName, filename string, content []byte)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFileBytes) * [SetFileReader(paramName, filePath string, reader io.Reader)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFileReader) * [SetFileUpload(uploads ...FileUpload)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFileUpload) - Set the fully custimized multipart file upload options. +* [SetUploadCallback(callback UploadCallback)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetUploadCallback) +* [SetUploadCallbackWithInterval(callback UploadCallback, minInterval time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetUploadCallbackWithInterval) ### Download diff --git a/examples/uploadcallback/README.md b/examples/uploadcallback/README.md new file mode 100644 index 00000000..1d5574b1 --- /dev/null +++ b/examples/uploadcallback/README.md @@ -0,0 +1,19 @@ +# uploadcallback + +This is a upload callback exmaple for `req` + +## How to Run + +Run `uploadserver`: + +```go +cd uploadserver +go run . +``` + +Run `uploadclient`: + +```go +cd uploadclient +go run . +``` \ No newline at end of file diff --git a/examples/uploadcallback/uploadclient/go.mod b/examples/uploadcallback/uploadclient/go.mod new file mode 100644 index 00000000..c95a1ec4 --- /dev/null +++ b/examples/uploadcallback/uploadclient/go.mod @@ -0,0 +1,7 @@ +module uploadclient + +go 1.13 + +replace github.com/imroc/req/v3 => ../../../ + +require github.com/imroc/req/v3 v3.0.0 diff --git a/examples/uploadcallback/uploadclient/go.sum b/examples/uploadcallback/uploadclient/go.sum new file mode 100644 index 00000000..59f8d4e3 --- /dev/null +++ b/examples/uploadcallback/uploadclient/go.sum @@ -0,0 +1,13 @@ +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= +golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/examples/uploadcallback/uploadclient/main.go b/examples/uploadcallback/uploadclient/main.go new file mode 100644 index 00000000..b617e8f7 --- /dev/null +++ b/examples/uploadcallback/uploadclient/main.go @@ -0,0 +1,47 @@ +package main + +import ( + "fmt" + "io" + "time" + "github.com/imroc/req/v3" +) + +type SlowReader struct { + Size int + n int +} + +func (r *SlowReader) Close() error { + return nil +} + +func (r *SlowReader) Read(p []byte) (int, error) { + if r.n >= r.Size { + return 0, io.EOF + } + time.Sleep(1 * time.Millisecond) + n := len(p) + if r.n+n >= r.Size { + n = r.Size - r.n + } + for i := 0; i < n; i++ { + p[i] = 'h' + } + r.n += n + return n, nil +} + +func main() { + size := 10 * 1024 * 1024 + req.SetFileUpload(req.FileUpload{ + ParamName: "file", + FileName: "test.txt", + GetFileContent: func() (io.ReadCloser, error) { + return &SlowReader{Size: size}, nil + }, + FileSize: int64(size), + }).SetUploadCallbackWithInterval(func(info req.UploadInfo) { + fmt.Printf("%s: %.2f%%\n", info.FileName, float64(info.UploadedSize)/float64(info.FileSize)*100.0) + }, 1*time.Second).Post("http://127.0.0.1:8888/upload") +} diff --git a/examples/uploadcallback/uploadserver/go.mod b/examples/uploadcallback/uploadserver/go.mod new file mode 100644 index 00000000..94acf8b9 --- /dev/null +++ b/examples/uploadcallback/uploadserver/go.mod @@ -0,0 +1,5 @@ +module uploadserver + +go 1.13 + +require github.com/gin-gonic/gin v1.7.7 diff --git a/examples/uploadcallback/uploadserver/go.sum b/examples/uploadcallback/uploadserver/go.sum new file mode 100644 index 00000000..5ee9be12 --- /dev/null +++ b/examples/uploadcallback/uploadserver/go.sum @@ -0,0 +1,54 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= +github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= +github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= +github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42 h1:vEOn+mP2zCOVzKckCZy6YsCtDblrpj/w7B9nxGNELpg= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/examples/uploadcallback/uploadserver/main.go b/examples/uploadcallback/uploadserver/main.go new file mode 100644 index 00000000..f1149207 --- /dev/null +++ b/examples/uploadcallback/uploadserver/main.go @@ -0,0 +1,18 @@ +package main + +import ( + "github.com/gin-gonic/gin" + "io" + "io/ioutil" + "net/http" +) + +func main() { + router := gin.Default() + router.POST("/upload", func(c *gin.Context) { + body := c.Request.Body + io.Copy(ioutil.Discard, body) + c.String(http.StatusOK, "ok") + }) + router.Run(":8888") +} diff --git a/middleware.go b/middleware.go index 7976de88..21315f24 100644 --- a/middleware.go +++ b/middleware.go @@ -2,6 +2,7 @@ package req import ( "bytes" + "errors" "github.com/imroc/req/v3/internal/util" "io" "io/ioutil" @@ -12,6 +13,7 @@ import ( "os" "path/filepath" "strings" + "time" ) type ( @@ -55,7 +57,7 @@ func closeq(v interface{}) { } } -func writeMultipartFormFile(w *multipart.Writer, file *FileUpload) error { +func writeMultipartFormFile(w *multipart.Writer, file *FileUpload, r *Request) error { content, err := file.GetFileContent() if err != nil { return err @@ -63,22 +65,82 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload) error { defer content.Close() // Auto detect actual multipart content type cbuf := make([]byte, 512) + seeEOF := false + lastTime := time.Now() size, err := content.Read(cbuf) - if err != nil && err != io.EOF { - return err + if err != nil { + if err == io.EOF { + seeEOF = true + } else { + return err + } } pw, err := w.CreatePart(createMultipartHeader(file, http.DetectContentType(cbuf))) if err != nil { return err } - if _, err = pw.Write(cbuf[:size]); err != nil { return err } + if seeEOF { + return nil + } + if r.uploadCallback == nil { + _, err = io.Copy(pw, content) + return err + } - _, err = io.Copy(pw, content) - return err + uploadedBytes := int64(size) + progressCallback := func() { + r.uploadCallback(UploadInfo{ + ParamName: file.ParamName, + FileName: file.FileName, + FileSize: file.FileSize, + UploadedSize: uploadedBytes, + }) + } + if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { + lastTime = now + progressCallback() + } + buf := make([]byte, 1024) + for { + callback := false + nr, er := content.Read(buf) + if nr > 0 { + nw, ew := pw.Write(buf[:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = errors.New("invalid write result") + } + } + uploadedBytes += int64(nw) + if ew != nil { + return ew + } + if nr != nw { + return io.ErrShortWrite + } + if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { + lastTime = now + progressCallback() + callback = true + } + } + if er != nil { + if er == io.EOF { + if !callback { + progressCallback() + } + break + } else { + return er + } + } + } + return nil } func writeMultiPart(r *Request, w *multipart.Writer, pw *io.PipeWriter) { @@ -88,7 +150,7 @@ func writeMultiPart(r *Request, w *multipart.Writer, pw *io.PipeWriter) { } } for _, file := range r.uploadFiles { - writeMultipartFormFile(w, file) + writeMultipartFormFile(w, file, r) } w.Close() // close multipart to write tailer boundary pw.Close() // close pipe writer so that pipe reader could get EOF, and stop upload diff --git a/req.go b/req.go index 8a7970a9..249f4bea 100644 --- a/req.go +++ b/req.go @@ -53,7 +53,10 @@ type FileUpload struct { FileName string // The file to be uploaded. GetFileContent GetContentFunc + // Optional file length in bytes. + FileSize int64 + // Optional extra ContentDisposition parameters. // According to the HTTP specification, this should be nil, // but some servers may not follow the specification and // requires `Content-Disposition` parameters more than just @@ -61,6 +64,22 @@ type FileUpload struct { ExtraContentDisposition *ContentDisposition } +// UploadInfo is the information for each UploadCallback call. +type UploadInfo struct { + // parameter name in multipart upload + ParamName string + // filename in multipart upload + FileName string + // total file length in bytes. + FileSize int64 + // uploaded file length in bytes. + UploadedSize int64 +} + +// UploadCallback is the callback which will be invoked during +// multipart upload. +type UploadCallback func(info UploadInfo) + func cloneCookies(cookies []*http.Cookie) []*http.Cookie { if len(cookies) == 0 { return nil diff --git a/request.go b/request.go index 36a1ff9c..a14af1ce 100644 --- a/request.go +++ b/request.go @@ -33,26 +33,28 @@ type Request struct { StartTime time.Time RetryAttempt int - RawURL string // read only - method string - URL *urlpkg.URL - getBody GetContentFunc - unReplayableBody io.ReadCloser - retryOption *retryOption - bodyReadCloser io.ReadCloser - body []byte - dumpOptions *DumpOptions - marshalBody interface{} - ctx context.Context - isMultiPart bool - uploadFiles []*FileUpload - uploadReader []io.ReadCloser - outputFile string - isSaveResponse bool - output io.Writer - trace *clientTrace - dumpBuffer *bytes.Buffer - responseReturnTime time.Time + RawURL string // read only + method string + URL *urlpkg.URL + getBody GetContentFunc + uploadCallback UploadCallback + uploadCallbackInterval time.Duration + unReplayableBody io.ReadCloser + retryOption *retryOption + bodyReadCloser io.ReadCloser + body []byte + dumpOptions *DumpOptions + marshalBody interface{} + ctx context.Context + isMultiPart bool + uploadFiles []*FileUpload + uploadReader []io.ReadCloser + outputFile string + isSaveResponse bool + output io.Writer + trace *clientTrace + dumpBuffer *bytes.Buffer + responseReturnTime time.Time } type GetContentFunc func() (io.ReadCloser, error) @@ -224,6 +226,12 @@ func (r *Request) SetFile(paramName, filePath string) *Request { r.appendError(err) return r } + fileInfo, err := os.Stat(filePath) + if err != nil { + r.client.log.Errorf("failed to stat file %s: %v", filePath, err) + r.appendError(err) + return r + } r.isMultiPart = true return r.SetFileUpload(FileUpload{ ParamName: paramName, @@ -237,6 +245,7 @@ func (r *Request) SetFile(paramName, filePath string) *Request { } return file, nil }, + FileSize: fileInfo.Size(), }) } @@ -268,6 +277,23 @@ func (r *Request) SetFileUpload(uploads ...FileUpload) *Request { return r } +// SetUploadCallback set the UploadCallback which will be invoked at least +// every 200ms during file upload, usually used to show upload progress. +func (r *Request) SetUploadCallback(callback UploadCallback) *Request { + return r.SetUploadCallbackWithInterval(callback, 200*time.Millisecond) +} + +// SetUploadCallbackWithInterval set the UploadCallback which will be invoked at least +// every `minInterval` during file upload, usually used to show upload progress. +func (r *Request) SetUploadCallbackWithInterval(callback UploadCallback, minInterval time.Duration) *Request { + if callback == nil { + return r + } + r.uploadCallback = callback + r.uploadCallbackInterval = minInterval + return r +} + // SetResult set the result that response body will be unmarshaled to if // request is success (status `code >= 200 and <= 299`). func (r *Request) SetResult(result interface{}) *Request { diff --git a/request_test.go b/request_test.go index 64f42948..50a75044 100644 --- a/request_test.go +++ b/request_test.go @@ -877,7 +877,7 @@ func TestSetFile(t *testing.T) { }) assertEqual(t, getTestFileContent(t, filename), resp.Bytes()) - resp, err := tc().R().SetFile("file", "file-not-exists.txt").Post("/file-text") + resp, err := tc().SetLogger(nil).R().SetFile("file", "file-not-exists.txt").Post("/file-text") assertErrorContains(t, err, "no such file") } @@ -898,3 +898,37 @@ func uploadTextFile(t *testing.T, setReq func(r *Request)) *Response { assertSuccess(t, resp, err) return resp } + +type SlowReader struct { + io.ReadCloser +} + +func (r *SlowReader) Read(p []byte) (int, error) { + time.Sleep(10 * time.Millisecond) + return r.ReadCloser.Read(p) +} + +func TestUploadCallback(t *testing.T) { + r := tc().R() + file := "transport_test.go" + fileInfo, err := os.Stat(file) + if err != nil { + t.Fatal(err) + } + r.SetFile("file", file) + r.uploadFiles[0].FileSize = fileInfo.Size() + content, err := r.uploadFiles[0].GetFileContent() + if err != nil { + t.Fatal(err) + } + r.uploadFiles[0].GetFileContent = func() (io.ReadCloser, error) { + return &SlowReader{content}, nil + } + n := 0 + r.SetUploadCallback(func(info UploadInfo) { + n++ + }) + resp, err := r.Post("/raw-upload") + assertSuccess(t, resp, err) + assertEqual(t, true, n > 1) +} diff --git a/request_wrapper.go b/request_wrapper.go index b6fef7bc..0006aecc 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -409,3 +409,15 @@ func SetRetryCondition(condition RetryConditionFunc) *Request { func AddRetryCondition(condition RetryConditionFunc) *Request { return defaultClient.R().AddRetryCondition(condition) } + +// SetUploadCallback is a global wrapper methods which delegated +// to the default client, create a request and SetUploadCallback for request. +func SetUploadCallback(callback UploadCallback) *Request { + return defaultClient.R().SetUploadCallback(callback) +} + +// SetUploadCallbackWithInterval is a global wrapper methods which delegated +// to the default client, create a request and SetUploadCallbackWithInterval for request. +func SetUploadCallbackWithInterval(callback UploadCallback, minInterval time.Duration) *Request { + return defaultClient.R().SetUploadCallbackWithInterval(callback, minInterval) +} diff --git a/request_wrapper_test.go b/request_wrapper_test.go index 8f0e6a5c..14d43b45 100644 --- a/request_wrapper_test.go +++ b/request_wrapper_test.go @@ -14,14 +14,14 @@ func init() { func TestGlobalWrapperForRequestSettings(t *testing.T) { assertAllNotNil(t, - SetFiles(map[string]string{"test": "test"}), - SetFile("test", "test"), + SetFiles(map[string]string{"test": "req.go"}), + SetFile("test", "req.go"), SetFileReader("test", "test.txt", bytes.NewBufferString("test")), SetFileBytes("test", "test.txt", []byte("test")), SetFileUpload(FileUpload{}), SetError(&ErrorMessage{}), SetResult(&UserInfo{}), - SetOutput(nil), + SetOutput(new(bytes.Buffer)), SetHeader("test", "test"), SetHeaders(map[string]string{"test": "test"}), SetCookies(&http.Cookie{ @@ -68,6 +68,8 @@ func TestGlobalWrapperForRequestSettings(t *testing.T) { EnableTrace(), DisableTrace(), SetContext(context.Background()), + SetUploadCallback(nil), + SetUploadCallbackWithInterval(nil, 0), ) } From e78b0a63ca86748f5380e075a9f7b84282179dca Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 27 Mar 2022 20:35:21 +0800 Subject: [PATCH 459/843] support download callback --- README.md | 23 +++ client.go | 24 ++- docs/api.md | 2 + examples/uploadcallback/uploadclient/main.go | 2 +- middleware.go | 192 +++++++++++++------ req.go | 12 ++ req_test.go | 19 ++ request.go | 63 +++--- request_test.go | 13 +- request_wrapper.go | 12 ++ request_wrapper_test.go | 2 + transport.go | 24 ++- 12 files changed, 307 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index 76758473..126e0038 100644 --- a/README.md +++ b/README.md @@ -757,6 +757,29 @@ if err != nil { client.R().SetOutput(file).Get(url) ``` +**Download Callback** + +You can set `DownloadCallback` if you want to show download progress: + +```go +client := req.C() +client.R(). + SetOutputFile("test.gz"). + SetUploadCallback(func(info req.UploadInfo) { + fmt.Printf("downloaded %.2f%%\n", float64(info.DownloadedSize)/float64(info.Response.ContentLength)*100.0) + }).Post("https://exmaple.com/upload") +/* Output +downloaded 17.92% +downloaded 41.77% +downloaded 67.71% +downloaded 98.89% +downloaded 100.00% +*/ +``` + +> `info.Response.ContentLength` could be 0 or -1 when the total size is unknown. +> `DownloadCallback` will be invoked at least every 200ms by default, you can customize the minimal invoke interval using `SetDownloadCallbackWithInterval`. + **Multipart Upload** ```go diff --git a/client.go b/client.go index 875d39e6..85f74235 100644 --- a/client.go +++ b/client.go @@ -995,8 +995,28 @@ func (c *Client) do(r *Request) (resp *Response, err error) { for _, cookie := range r.Cookies { req.AddCookie(cookie) } - if r.ctx != nil { - req = req.WithContext(r.ctx) + ctx := r.ctx + if r.isSaveResponse && r.downloadCallback != nil { + var wrap wrapResponseBodyFunc = func(rc io.ReadCloser) io.ReadCloser { + return &callbackReader{ + ReadCloser: rc, + callback: func(read int64) { + r.downloadCallback(DownloadInfo{ + Response: resp, + DownloadedSize: read, + }) + }, + lastTime: time.Now(), + interval: r.downloadCallbackInterval, + } + } + if ctx == nil { + ctx = context.Background() + } + ctx = context.WithValue(ctx, wrapResponseBodyKey, wrap) + } + if ctx != nil { + req = req.WithContext(ctx) } r.RawRequest = req diff --git a/docs/api.md b/docs/api.md index 491329b7..821ecf6a 100644 --- a/docs/api.md +++ b/docs/api.md @@ -210,6 +210,8 @@ Basically, you can know the meaning of most settings directly from the method na * [SetOutput(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutput) * [SetOutputFile(file string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutputFile) +* [SetDownloadCallback(callback DownloadCallback)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDownloadCallback) +* [SetDownloadCallbackWithInterval(callback DownloadCallback, minInterval time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDownloadCallbackWithInterval) ### Retry diff --git a/examples/uploadcallback/uploadclient/main.go b/examples/uploadcallback/uploadclient/main.go index b617e8f7..19dd217e 100644 --- a/examples/uploadcallback/uploadclient/main.go +++ b/examples/uploadcallback/uploadclient/main.go @@ -43,5 +43,5 @@ func main() { FileSize: int64(size), }).SetUploadCallbackWithInterval(func(info req.UploadInfo) { fmt.Printf("%s: %.2f%%\n", info.FileName, float64(info.UploadedSize)/float64(info.FileSize)*100.0) - }, 1*time.Second).Post("http://127.0.0.1:8888/upload") + }, 30*time.Millisecond).Post("http://127.0.0.1:8888/upload") } diff --git a/middleware.go b/middleware.go index 21315f24..e6d2f86f 100644 --- a/middleware.go +++ b/middleware.go @@ -2,7 +2,6 @@ package req import ( "bytes" - "errors" "github.com/imroc/req/v3/internal/util" "io" "io/ioutil" @@ -80,66 +79,82 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload, r *Request) e if err != nil { return err } + + if r.uploadCallback != nil { + pw = &callbackWriter{ + Writer: pw, + lastTime: lastTime, + interval: r.uploadCallbackInterval, + totalSize: file.FileSize, + callback: func(written int64) { + r.uploadCallback(UploadInfo{ + ParamName: file.ParamName, + FileName: file.FileName, + FileSize: file.FileSize, + UploadedSize: written, + }) + }, + } + } + if _, err = pw.Write(cbuf[:size]); err != nil { return err } if seeEOF { return nil } - if r.uploadCallback == nil { - _, err = io.Copy(pw, content) - return err - } - uploadedBytes := int64(size) - progressCallback := func() { - r.uploadCallback(UploadInfo{ - ParamName: file.ParamName, - FileName: file.FileName, - FileSize: file.FileSize, - UploadedSize: uploadedBytes, - }) - } - if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { - lastTime = now - progressCallback() - } - buf := make([]byte, 1024) - for { - callback := false - nr, er := content.Read(buf) - if nr > 0 { - nw, ew := pw.Write(buf[:nr]) - if nw < 0 || nr < nw { - nw = 0 - if ew == nil { - ew = errors.New("invalid write result") - } - } - uploadedBytes += int64(nw) - if ew != nil { - return ew - } - if nr != nw { - return io.ErrShortWrite - } - if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { - lastTime = now - progressCallback() - callback = true - } - } - if er != nil { - if er == io.EOF { - if !callback { - progressCallback() - } - break - } else { - return er - } - } - } + _, err = io.Copy(pw, content) + return err + // uploadedBytes := int64(size) + // progressCallback := func() { + // r.uploadCallback(UploadInfo{ + // ParamName: file.ParamName, + // FileName: file.FileName, + // FileSize: file.FileSize, + // UploadedSize: uploadedBytes, + // }) + // } + // if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { + // lastTime = now + // progressCallback() + // } + // buf := make([]byte, 1024) + // for { + // callback := false + // nr, er := content.Read(buf) + // if nr > 0 { + // nw, ew := pw.Write(buf[:nr]) + // if nw < 0 || nr < nw { + // nw = 0 + // if ew == nil { + // ew = errors.New("invalid write result") + // } + // } + // uploadedBytes += int64(nw) + // if ew != nil { + // return ew + // } + // if nr != nw { + // return io.ErrShortWrite + // } + // if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { + // lastTime = now + // progressCallback() + // callback = true + // } + // } + // if er != nil { + // if er == io.EOF { + // if !callback { + // progressCallback() + // } + // break + // } else { + // return er + // } + // } + // } return nil } @@ -266,6 +281,60 @@ func parseResponseBody(c *Client, r *Response) (err error) { return } +type callbackWriter struct { + io.Writer + written int64 + totalSize int64 + lastTime time.Time + interval time.Duration + callback func(written int64) +} + +func (w *callbackWriter) Write(p []byte) (n int, err error) { + n, err = w.Writer.Write(p) + if n <= 0 { + return + } + w.written += int64(n) + if w.written == w.totalSize { + w.callback(w.written) + } else if now := time.Now(); now.Sub(w.lastTime) >= w.interval { + w.lastTime = now + w.callback(w.written) + } + return +} + +type callbackReader struct { + io.ReadCloser + read int64 + lastRead int64 + callback func(read int64) + lastTime time.Time + interval time.Duration +} + +func (r *callbackReader) Read(p []byte) (n int, err error) { + n, err = r.ReadCloser.Read(p) + if n <= 0 { + if err == io.EOF && r.read > r.lastRead { + r.callback(r.read) + r.lastRead = r.read + } + return + } + r.read += int64(n) + if err == io.EOF { + r.callback(r.read) + r.lastRead = r.read + } else if now := time.Now(); now.Sub(r.lastTime) >= r.interval { + r.lastTime = now + r.callback(r.read) + r.lastRead = r.read + } + return +} + func handleDownload(c *Client, r *Response) (err error) { if !r.Request.isSaveResponse { return nil @@ -302,6 +371,21 @@ func handleDownload(c *Client, r *Response) (err error) { body.Close() closeq(output) }() + + // if r.Request.downloadCallback != nil { + // output = &callbackWriter{ + // Writer: output, + // lastTime: time.Now(), + // interval: r.Request.downloadCallbackInterval, + // callback: func(written int64) { + // r.Request.downloadCallback(DownloadInfo{ + // Response: r, + // DownloadedSize: written, + // }) + // }, + // } + // } + _, err = io.Copy(output, body) r.setReceivedAt() return diff --git a/req.go b/req.go index 249f4bea..e3e4caee 100644 --- a/req.go +++ b/req.go @@ -80,6 +80,18 @@ type UploadInfo struct { // multipart upload. type UploadCallback func(info UploadInfo) +// DownloadInfo is the information for each DownloadCallback call. +type DownloadInfo struct { + // Response is the corresponding Response during download. + Response *Response + // downloaded body length in bytes. + DownloadedSize int64 +} + +// DownloadCallback is the callback which will be invoked during +// response body download. +type DownloadCallback func(info DownloadInfo) + func cloneCookies(cookies []*http.Cookie) []*http.Cookie { if len(cookies) == 0 { return nil diff --git a/req_test.go b/req_test.go index 50bddd1c..9687f311 100644 --- a/req_test.go +++ b/req_test.go @@ -15,6 +15,7 @@ import ( "os" "path/filepath" "reflect" + "strconv" "strings" "sync" "testing" @@ -357,6 +358,24 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.URL.RawQuery)) case "/search": handleSearch(w, r) + case "/download": + size := 100 * 1024 * 1024 + w.Header().Set("Content-Length", strconv.Itoa(size)) + buf := make([]byte, 1024) + for i := 0; i < 1024; i++ { + buf[i] = 'h' + } + for i := 0; i < size; { + wbuf := buf + if size-i < 1024 { + wbuf = buf[:size-i] + } + n, err := w.Write(wbuf) + if err != nil { + break + } + i += n + } case "/protected": auth := r.Header.Get("Authorization") if auth == "Bearer goodtoken" { diff --git a/request.go b/request.go index a14af1ce..b5b873de 100644 --- a/request.go +++ b/request.go @@ -33,28 +33,30 @@ type Request struct { StartTime time.Time RetryAttempt int - RawURL string // read only - method string - URL *urlpkg.URL - getBody GetContentFunc - uploadCallback UploadCallback - uploadCallbackInterval time.Duration - unReplayableBody io.ReadCloser - retryOption *retryOption - bodyReadCloser io.ReadCloser - body []byte - dumpOptions *DumpOptions - marshalBody interface{} - ctx context.Context - isMultiPart bool - uploadFiles []*FileUpload - uploadReader []io.ReadCloser - outputFile string - isSaveResponse bool - output io.Writer - trace *clientTrace - dumpBuffer *bytes.Buffer - responseReturnTime time.Time + RawURL string // read only + method string + URL *urlpkg.URL + getBody GetContentFunc + uploadCallback UploadCallback + uploadCallbackInterval time.Duration + downloadCallback DownloadCallback + downloadCallbackInterval time.Duration + unReplayableBody io.ReadCloser + retryOption *retryOption + bodyReadCloser io.ReadCloser + body []byte + dumpOptions *DumpOptions + marshalBody interface{} + ctx context.Context + isMultiPart bool + uploadFiles []*FileUpload + uploadReader []io.ReadCloser + outputFile string + isSaveResponse bool + output io.Writer + trace *clientTrace + dumpBuffer *bytes.Buffer + responseReturnTime time.Time } type GetContentFunc func() (io.ReadCloser, error) @@ -294,6 +296,23 @@ func (r *Request) SetUploadCallbackWithInterval(callback UploadCallback, minInte return r } +// SetDownloadCallback set the DownloadCallback which will be invoked at least +// every 200ms during file upload, usually used to show download progress. +func (r *Request) SetDownloadCallback(callback DownloadCallback) *Request { + return r.SetDownloadCallbackWithInterval(callback, 200*time.Millisecond) +} + +// SetDownloadCallbackWithInterval set the DownloadCallback which will be invoked at least +// every `minInterval` during file upload, usually used to show download progress. +func (r *Request) SetDownloadCallbackWithInterval(callback DownloadCallback, minInterval time.Duration) *Request { + if callback == nil { + return r + } + r.downloadCallback = callback + r.downloadCallbackInterval = minInterval + return r +} + // SetResult set the result that response body will be unmarshaled to if // request is success (status `code >= 200 and <= 299`). func (r *Request) SetResult(result interface{}) *Request { diff --git a/request_test.go b/request_test.go index 50a75044..3d30054e 100644 --- a/request_test.go +++ b/request_test.go @@ -904,7 +904,7 @@ type SlowReader struct { } func (r *SlowReader) Read(p []byte) (int, error) { - time.Sleep(10 * time.Millisecond) + time.Sleep(100 * time.Millisecond) return r.ReadCloser.Read(p) } @@ -932,3 +932,14 @@ func TestUploadCallback(t *testing.T) { assertSuccess(t, resp, err) assertEqual(t, true, n > 1) } + +func TestDownloadCallback(t *testing.T) { + n := 0 + resp, err := tc().R(). + SetOutput(ioutil.Discard). + SetDownloadCallback(func(info DownloadInfo) { + n++ + }).Get("/download") + assertSuccess(t, resp, err) + assertEqual(t, true, n > 0) +} diff --git a/request_wrapper.go b/request_wrapper.go index 0006aecc..e496cdc4 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -421,3 +421,15 @@ func SetUploadCallback(callback UploadCallback) *Request { func SetUploadCallbackWithInterval(callback UploadCallback, minInterval time.Duration) *Request { return defaultClient.R().SetUploadCallbackWithInterval(callback, minInterval) } + +// SetDownloadCallback is a global wrapper methods which delegated +// to the default client, create a request and SetDownloadCallback for request. +func SetDownloadCallback(callback DownloadCallback) *Request { + return defaultClient.R().SetDownloadCallback(callback) +} + +// SetDownloadCallbackWithInterval is a global wrapper methods which delegated +// to the default client, create a request and SetDownloadCallbackWithInterval for request. +func SetDownloadCallbackWithInterval(callback DownloadCallback, minInterval time.Duration) *Request { + return defaultClient.R().SetDownloadCallbackWithInterval(callback, minInterval) +} diff --git a/request_wrapper_test.go b/request_wrapper_test.go index 14d43b45..fe080646 100644 --- a/request_wrapper_test.go +++ b/request_wrapper_test.go @@ -70,6 +70,8 @@ func TestGlobalWrapperForRequestSettings(t *testing.T) { SetContext(context.Background()), SetUploadCallback(nil), SetUploadCallbackWithInterval(nil, 0), + SetDownloadCallback(nil), + SetDownloadCallbackWithInterval(nil, 0), ) } diff --git a/transport.go b/transport.go index 302398c5..6d63e2bd 100644 --- a/transport.go +++ b/transport.go @@ -269,15 +269,37 @@ type Transport struct { *ResponseOptions - dump *dumper + dump *dumper + + // Debugf is the optional debug function. Debugf func(format string, v ...interface{}) } +type wrapResponseBodyKeyType int + +const wrapResponseBodyKey wrapResponseBodyKeyType = iota + +type wrapResponseBodyFunc func(rc io.ReadCloser) io.ReadCloser + func (t *Transport) handleResponseBody(res *http.Response, req *http.Request) { + if wrap, ok := req.Context().Value(wrapResponseBodyKey).(wrapResponseBodyFunc); ok { + t.wrapResponseBody(res, wrap) + } t.autoDecodeResponseBody(res) t.dumpResponseBody(res, req) } +func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFunc) { + switch b := res.Body.(type) { + case *gzipReader: + b.body.body = wrap(b.body.body) + case *http2gzipReader: + b.body = wrap(b.body) + default: + res.Body = wrap(res.Body) + } +} + func (t *Transport) dumpResponseBody(res *http.Response, req *http.Request) { dumps := getDumpers(req.Context(), t.dump) for _, dump := range dumps { From 2c745f6afed0983bbd0940247cb807174a89003e Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 27 Mar 2022 20:47:09 +0800 Subject: [PATCH 460/843] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 126e0038..5bf9fa4f 100644 --- a/README.md +++ b/README.md @@ -777,8 +777,8 @@ downloaded 100.00% */ ``` -> `info.Response.ContentLength` could be 0 or -1 when the total size is unknown. -> `DownloadCallback` will be invoked at least every 200ms by default, you can customize the minimal invoke interval using `SetDownloadCallbackWithInterval`. +> 1. `info.Response.ContentLength` could be 0 or -1 when the total size is unknown. +> 2. `DownloadCallback` will be invoked at least every 200ms by default, you can customize the minimal invoke interval using `SetDownloadCallbackWithInterval`. **Multipart Upload** From 37c4693df9a992730c00c38434e3c9938873f7a3 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 14 Apr 2022 12:09:26 +0800 Subject: [PATCH 461/843] remove unused code in middleware --- middleware.go | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/middleware.go b/middleware.go index e6d2f86f..6eb99879 100644 --- a/middleware.go +++ b/middleware.go @@ -372,20 +372,6 @@ func handleDownload(c *Client, r *Response) (err error) { closeq(output) }() - // if r.Request.downloadCallback != nil { - // output = &callbackWriter{ - // Writer: output, - // lastTime: time.Now(), - // interval: r.Request.downloadCallbackInterval, - // callback: func(written int64) { - // r.Request.downloadCallback(DownloadInfo{ - // Response: r, - // DownloadedSize: written, - // }) - // }, - // } - // } - _, err = io.Copy(output, body) r.setReceivedAt() return From f51cc1315db2057bab111d22913e5f22a1f1d431 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 14 Apr 2022 12:16:17 +0800 Subject: [PATCH 462/843] default unmarshal to json if Content-Type is not sure(#107) --- middleware.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/middleware.go b/middleware.go index 6eb99879..4434e5da 100644 --- a/middleware.go +++ b/middleware.go @@ -263,6 +263,9 @@ func unmarshalBody(c *Client, r *Response, v interface{}) (err error) { return c.jsonUnmarshal(body, v) } else if util.IsXMLType(ct) { return c.xmlUnmarshal(body, v) + } else { + c.log.Warnf("cannot determine the unmarshal function with %q Content-Type, default to json", ct) + return c.jsonUnmarshal(body, v) } return } From 787c86b4b358f2cbc17c937b3fa1a2e157119484 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 24 Apr 2022 14:18:48 +0800 Subject: [PATCH 463/843] avoid concurrent map iteration and map write(fix #111) --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 85f74235..998c3a87 100644 --- a/client.go +++ b/client.go @@ -969,7 +969,7 @@ func (c *Client) do(r *Request) (resp *Response, err error) { if r.Headers == nil { header = make(http.Header) } else { - header = r.Headers.Clone() + header = r.Headers } contentLength := int64(len(r.body)) From df618e96918cf9d6259e0f59b9efbac1605d8b6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=8F=E5=A4=A9?= Date: Mon, 25 Apr 2022 16:32:50 +0800 Subject: [PATCH 464/843] To use non-canonical keys, assign to the map directly. --- request.go | 29 +++++++++++++++++++++++++++++ request_test.go | 13 +++++++++++++ 2 files changed, 42 insertions(+) diff --git a/request.go b/request.go index b5b873de..fa305db9 100644 --- a/request.go +++ b/request.go @@ -351,6 +351,35 @@ func (r *Request) SetHeader(key, value string) *Request { r.Headers = make(http.Header) } r.Headers.Set(key, value) + + return r +} + +// SetHeadersMap set headers from a map for the request. +// To use non-canonical keys, assign to the map directly. +func (r *Request) SetHeadersMap(hdrs map[string]string) *Request { + for k, v := range hdrs { + r.SetHeaderMap(k, v) + } + return r +} + +// SetHeaderMap set a header for the request. +// To use non-canonical keys, assign to the map directly. +func (r *Request) SetHeaderMap(key, value string) *Request { + if r.Headers == nil { + r.Headers = make(http.Header) + } + r.Headers[key] = []string{value} + return r +} + +// SetHeaderMaps set headers from a map for the request. +// To use non-canonical keys, assign to the map directly. +func (r *Request) SetHeaderMaps(hdrs map[string][]string) *Request { + if r.Headers == nil { + r.Headers = hdrs + } return r } diff --git a/request_test.go b/request_test.go index 3d30054e..41c98224 100644 --- a/request_test.go +++ b/request_test.go @@ -510,6 +510,19 @@ func testHeader(t *testing.T, c *Client) { assertEqual(t, "value3", headers.Get("header3")) } +func TestSetHeaderMaps(t *testing.T) { + // set headers + headers := map[string][]string{ + "header1": {"value1"}, + "header2": {"value2", "value3"}, + } + resp, err := tc().R(). + SetHeaderMaps(headers). + Get("/headers") + assertSuccess(t, resp, err) + +} + func TestQueryParam(t *testing.T) { testWithAllTransport(t, testQueryParam) } From 8ea12465efd3b30fad1e6cf971b37a00122baa36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=8F=E5=A4=A9?= Date: Mon, 25 Apr 2022 16:39:21 +0800 Subject: [PATCH 465/843] testHeadersMaps and testHeadersMap --- request_test.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/request_test.go b/request_test.go index 41c98224..baea3b3a 100644 --- a/request_test.go +++ b/request_test.go @@ -518,7 +518,21 @@ func TestSetHeaderMaps(t *testing.T) { } resp, err := tc().R(). SetHeaderMaps(headers). - Get("/headers") + Get("/headersMaps") + assertSuccess(t, resp, err) + +} + +func TestSetHeadersMap(t *testing.T) { + // set headers + headers := map[string]string{ + "header1": "value1", + "header2": "value2", + } + + resp, err := tc().R(). + SetHeadersMap(headers). + Get("/headersMap") assertSuccess(t, resp, err) } From 9789245fffb2cf3ff74f25ab5961d72132e15b1e Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 26 Apr 2022 15:08:12 +0800 Subject: [PATCH 466/843] support SetHeaderNonCanonical and SetHeadersNonCanonical --- request.go | 25 ++++++++----------------- request_test.go | 35 ++++++++++++++--------------------- 2 files changed, 22 insertions(+), 38 deletions(-) diff --git a/request.go b/request.go index fa305db9..0de45ac5 100644 --- a/request.go +++ b/request.go @@ -355,31 +355,22 @@ func (r *Request) SetHeader(key, value string) *Request { return r } -// SetHeadersMap set headers from a map for the request. -// To use non-canonical keys, assign to the map directly. -func (r *Request) SetHeadersMap(hdrs map[string]string) *Request { +// SetHeadersNonCanonical set headers from a map for the request which key is a +// non-canonical key (keep case unchanged), only valid for HTTP/1.1. +func (r *Request) SetHeadersNonCanonical(hdrs map[string]string) *Request { for k, v := range hdrs { - r.SetHeaderMap(k, v) + r.SetHeaderNonCanonical(k, v) } return r } -// SetHeaderMap set a header for the request. -// To use non-canonical keys, assign to the map directly. -func (r *Request) SetHeaderMap(key, value string) *Request { +// SetHeaderNonCanonical set a header for the request which key is a +// non-canonical key (keep case unchanged), only valid for HTTP/1.1. +func (r *Request) SetHeaderNonCanonical(key, value string) *Request { if r.Headers == nil { r.Headers = make(http.Header) } - r.Headers[key] = []string{value} - return r -} - -// SetHeaderMaps set headers from a map for the request. -// To use non-canonical keys, assign to the map directly. -func (r *Request) SetHeaderMaps(hdrs map[string][]string) *Request { - if r.Headers == nil { - r.Headers = hdrs - } + r.Headers[key] = append(r.Headers[key], value) return r } diff --git a/request_test.go b/request_test.go index baea3b3a..b0ca95aa 100644 --- a/request_test.go +++ b/request_test.go @@ -12,6 +12,7 @@ import ( "net/url" "os" "strconv" + "strings" "testing" "time" ) @@ -510,31 +511,23 @@ func testHeader(t *testing.T, c *Client) { assertEqual(t, "value3", headers.Get("header3")) } -func TestSetHeaderMaps(t *testing.T) { +func TestSetHeaderNonCanonical(t *testing.T) { // set headers - headers := map[string][]string{ - "header1": {"value1"}, - "header2": {"value2", "value3"}, - } - resp, err := tc().R(). - SetHeaderMaps(headers). - Get("/headersMaps") + key := "spring.cloud.function.Routing-expression" + c := tc().EnableForceHTTP1() + resp, err := c.R().EnableDumpWithoutResponse(). + SetHeadersNonCanonical(map[string]string{ + key: "test", + }).Get("/header") assertSuccess(t, resp, err) + assertEqual(t, true, strings.Contains(resp.Dump(), key)) -} - -func TestSetHeadersMap(t *testing.T) { - // set headers - headers := map[string]string{ - "header1": "value1", - "header2": "value2", - } - - resp, err := tc().R(). - SetHeadersMap(headers). - Get("/headersMap") + resp, err = c.R(). + EnableDumpWithoutResponse(). + SetHeaderNonCanonical(key, "test"). + Get("/header") assertSuccess(t, resp, err) - + assertEqual(t, true, strings.Contains(resp.Dump(), key)) } func TestQueryParam(t *testing.T) { From deaf86c52b9d6936a7b2edd1877b70a888abdf65 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 26 Apr 2022 15:30:20 +0800 Subject: [PATCH 467/843] support SetCommonHeaderNonCanonical and SetCommonHeadersNonCanonical --- client.go | 19 +++++++++++++++++++ client_test.go | 12 ++++++++++++ middleware.go | 8 +++++--- request_test.go | 7 +++++++ 4 files changed, 43 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 998c3a87..744b9682 100644 --- a/client.go +++ b/client.go @@ -595,6 +595,25 @@ func (c *Client) SetCommonHeader(key, value string) *Client { return c } +// SetCommonHeaderNonCanonical set a header for all requests which key is a +// non-canonical key (keep case unchanged), only valid for HTTP/1.1. +func (c *Client) SetCommonHeaderNonCanonical(key, value string) *Client { + if c.Headers == nil { + c.Headers = make(http.Header) + } + c.Headers[key] = append(c.Headers[key], value) + return c +} + +// SetCommonHeadersNonCanonical set headers for all requests which key is a +// non-canonical key (keep case unchanged), only valid for HTTP/1.1. +func (c *Client) SetCommonHeadersNonCanonical(hdrs map[string]string) *Client { + for k, v := range hdrs { + c.SetCommonHeaderNonCanonical(k, v) + } + return c +} + // SetCommonContentType set the `Content-Type` header for all requests. func (c *Client) SetCommonContentType(ct string) *Client { c.SetCommonHeader(hdrContentTypeKey, ct) diff --git a/client_test.go b/client_test.go index 83123093..e4bce5b8 100644 --- a/client_test.go +++ b/client_test.go @@ -144,6 +144,11 @@ func TestSetCommonHeader(t *testing.T) { assertEqual(t, "my-value", c.Headers.Get("my-header")) } +func TestSetCommonHeaderNonCanonical(t *testing.T) { + c := tc().SetCommonHeaderNonCanonical("my-Header", "my-value") + assertEqual(t, "my-value", c.Headers["my-Header"][0]) +} + func TestSetCommonHeaders(t *testing.T) { c := tc().SetCommonHeaders(map[string]string{ "header1": "value1", @@ -153,6 +158,13 @@ func TestSetCommonHeaders(t *testing.T) { assertEqual(t, "value2", c.Headers.Get("header2")) } +func TestSetCommonHeadersNonCanonical(t *testing.T) { + c := tc().SetCommonHeadersNonCanonical(map[string]string{ + "my-Header": "my-value", + }) + assertEqual(t, "my-value", c.Headers["my-Header"][0]) +} + func TestSetCommonBasicAuth(t *testing.T) { c := tc().SetCommonBasicAuth("imroc", "123456") assertEqual(t, "Basic aW1yb2M6MTIzNDU2", c.Headers.Get("Authorization")) diff --git a/middleware.go b/middleware.go index 4434e5da..1bce9f6a 100644 --- a/middleware.go +++ b/middleware.go @@ -387,9 +387,11 @@ func parseRequestHeader(c *Client, r *Request) error { if r.Headers == nil { r.Headers = make(http.Header) } - for k := range c.Headers { - if r.Headers.Get(k) == "" { - r.Headers.Add(k, c.Headers.Get(k)) + for k, vs := range c.Headers { + for _, v := range vs { + if len(r.Headers[k]) == 0 { + r.Headers[k] = append(r.Headers[k], v) + } } } return nil diff --git a/request_test.go b/request_test.go index b0ca95aa..f88532d1 100644 --- a/request_test.go +++ b/request_test.go @@ -528,6 +528,13 @@ func TestSetHeaderNonCanonical(t *testing.T) { Get("/header") assertSuccess(t, resp, err) assertEqual(t, true, strings.Contains(resp.Dump(), key)) + + c.SetCommonHeaderNonCanonical(key, "test") + resp, err = c.R(). + EnableDumpWithoutResponse(). + Get("/header") + assertSuccess(t, resp, err) + assertEqual(t, true, strings.Contains(resp.Dump(), key)) } func TestQueryParam(t *testing.T) { From ea655ca65561b2dcade96657679338452a440961 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 12 May 2022 09:51:26 +0800 Subject: [PATCH 468/843] http2: omit invalid header value from error message https://github.com/golang/net/commit/749bd193bc2bcebc5f1a048da8af0392cfb2fa5d --- h2_errors.go | 2 +- h2_frame.go | 3 ++- h2_frame_test.go | 2 +- h2_transport.go | 3 ++- h2_transport_test.go | 4 ++-- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/h2_errors.go b/h2_errors.go index da23201b..24cc07d1 100644 --- a/h2_errors.go +++ b/h2_errors.go @@ -129,7 +129,7 @@ func (e http2headerFieldNameError) Error() string { type http2headerFieldValueError string func (e http2headerFieldValueError) Error() string { - return fmt.Sprintf("invalid header field value %q", string(e)) + return fmt.Sprintf("invalid header field value for %q", string(e)) } var ( diff --git a/h2_frame.go b/h2_frame.go index 364a770c..f910294f 100644 --- a/h2_frame.go +++ b/h2_frame.go @@ -1554,7 +1554,8 @@ func (h2f *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (* h2f.debugReadLoggerf("http2: decoded hpack field %+v", hf) } if !httpguts.ValidHeaderFieldValue(hf.Value) { - invalid = http2headerFieldValueError(hf.Value) + // Don't include the value in the error, because it may be sensitive. + invalid = http2headerFieldValueError(hf.Name) } isPseudo := strings.HasPrefix(hf.Name, ":") if isPseudo { diff --git a/h2_frame_test.go b/h2_frame_test.go index 573dd99a..b2f75aa4 100644 --- a/h2_frame_test.go +++ b/h2_frame_test.go @@ -1092,7 +1092,7 @@ func TestMetaFrameHeader(t *testing.T) { name: "invalid_field_value", w: func(f *http2Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) }, want: http2streamError(1, http2ErrCodeProtocol), - wantErrReason: "invalid header field value \"bad_null\\x00\"", + wantErrReason: `invalid header field value for "key"`, }, } for i, tt := range tests { diff --git a/h2_transport.go b/h2_transport.go index 0a1ef6a8..fece52c2 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -1744,7 +1744,8 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, } for _, v := range vv { if !httpguts.ValidHeaderFieldValue(v) { - return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) + // Don't include the value in the error, because it may be sensitive. + return nil, fmt.Errorf("invalid HTTP header value for header %q", k) } } } diff --git a/h2_transport_test.go b/h2_transport_test.go index 360f2273..3addbe14 100644 --- a/h2_transport_test.go +++ b/h2_transport_test.go @@ -1472,7 +1472,7 @@ func TestTransportInvalidTrailerEmptyFieldName(t *testing.T) { }) } func TestTransportInvalidTrailerBinaryFieldValue(t *testing.T) { - testInvalidTrailer(t, oneHeader, http2headerFieldValueError("has\nnewline"), func(enc *hpack.Encoder) { + testInvalidTrailer(t, oneHeader, http2headerFieldValueError("x"), func(enc *hpack.Encoder) { enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"}) }) } @@ -2438,7 +2438,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { }, 3: { h: http.Header{"foo": {"foo\x01bar"}}, - wantErr: `invalid HTTP header value "foo\x01bar" for header "foo"`, + wantErr: `invalid HTTP header value for header "foo"`, }, } From 52b1c8a88d0af5c09c461991cb7036bb9942074f Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 12 May 2022 09:59:55 +0800 Subject: [PATCH 469/843] http2: log pings and RoundTrip retries when http2debug=1 https://github.com/golang/net/commit/543a649e0bddcda61dd54a96c5e90417ba219c63 --- h2_transport.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/h2_transport.go b/h2_transport.go index fece52c2..40e29c9f 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -494,12 +494,14 @@ func (t *http2Transport) RoundTripOpt(req *http.Request, opt http2RoundTripOpt) if req, err = http2shouldRetryRequest(req, err); err == nil { // After the first retry, do exponential backoff with 10% jitter. if retry == 0 { + t.vlogf("RoundTrip retrying after failure: %v", err) continue } backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) select { case <-time.After(time.Second * time.Duration(backoff)): + t.vlogf("RoundTrip retrying after failure: %v", err) continue case <-req.Context().Done(): err = req.Context().Err() @@ -726,11 +728,15 @@ func (cc *http2ClientConn) healthCheck() { // trigger the healthCheck again if there is no frame received. ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() + cc.vlogf("http2: Transport sending health check") err := cc.Ping(ctx) if err != nil { + cc.vlogf("http2: Transport health check failure: %v", err) cc.closeForLostPing() cc.t.connPool().MarkDead(cc) return + } else { + cc.vlogf("http2: Transport health check success") } } From e75d6fdb288680715c7b0714c13085e77aa710ff Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 12 May 2022 13:54:40 +0800 Subject: [PATCH 470/843] http2: improve handling of slow-closing net.Conns https://github.com/golang/net/commit/27dd8689420fcde088514397d015e4fea5174e0e --- h2_transport.go | 41 +++++++++++++++++++++++++++++------------ tls.go | 18 ++++++++++++++++++ 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/h2_transport.go b/h2_transport.go index 40e29c9f..525ea0a3 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -733,7 +733,6 @@ func (cc *http2ClientConn) healthCheck() { if err != nil { cc.vlogf("http2: Transport health check failure: %v", err) cc.closeForLostPing() - cc.t.connPool().MarkDead(cc) return } else { cc.vlogf("http2: Transport health check success") @@ -885,6 +884,24 @@ func (cc *http2ClientConn) onIdleTimeout() { cc.closeIfIdle() } +func (cc *http2ClientConn) closeConn() error { + t := time.AfterFunc(250*time.Millisecond, cc.forceCloseConn) + defer t.Stop() + return cc.tconn.Close() +} + +// A tls.Conn.Close can hang for a long time if the peer is unresponsive. +// Try to shut it down more aggressively. +func (cc *http2ClientConn) forceCloseConn() { + tc, ok := cc.tconn.(NetConnWrapper) + if !ok { + return + } + if nc := tc.NetConn(); nc != nil { + nc.Close() + } +} + func (cc *http2ClientConn) closeIfIdle() { cc.mu.Lock() if len(cc.streams) > 0 || cc.streamsReserved > 0 { @@ -899,7 +916,7 @@ func (cc *http2ClientConn) closeIfIdle() { if http2VerboseLogs { cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, nextID-2) } - cc.tconn.Close() + cc.closeConn() } func (cc *http2ClientConn) isDoNotReuseAndIdle() bool { @@ -916,7 +933,7 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error { return err } // Wait for all in-flight streams to complete or connection to close - done := make(chan error, 1) + done := make(chan struct{}) cancelled := false // guarded by cc.mu go func() { cc.mu.Lock() @@ -924,7 +941,7 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error { for { if len(cc.streams) == 0 || cc.closed { cc.closed = true - done <- cc.tconn.Close() + close(done) break } if cancelled { @@ -935,8 +952,8 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error { }() http2shutdownEnterWaitStateHook() select { - case err := <-done: - return err + case <-done: + return cc.closeConn() case <-ctx.Done(): cc.mu.Lock() // Free the goroutine above @@ -979,9 +996,9 @@ func (cc *http2ClientConn) closeForError(err error) error { for _, cs := range cc.streams { cs.abortStreamLocked(err) } - defer cc.cond.Broadcast() - defer cc.mu.Unlock() - return cc.tconn.Close() + cc.cond.Broadcast() + cc.mu.Unlock() + return cc.closeConn() } // Close closes the client connection immediately. @@ -2013,7 +2030,7 @@ func (cc *http2ClientConn) forgetStreamID(id uint32) { cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) } cc.closed = true - defer cc.tconn.Close() + defer cc.closeConn() } cc.mu.Unlock() @@ -2060,8 +2077,8 @@ func http2isEOFOrNetReadError(err error) bool { func (rl *http2clientConnReadLoop) cleanup() { cc := rl.cc - defer cc.tconn.Close() - defer cc.t.connPool().MarkDead(cc) + cc.t.connPool().MarkDead(cc) + defer cc.closeConn() defer close(cc.readerDone) if cc.idleTimer != nil { diff --git a/tls.go b/tls.go index 121425d5..66f8dbb8 100644 --- a/tls.go +++ b/tls.go @@ -12,6 +12,24 @@ import ( // If this interface is not implemented, HTTP1 will be used by default. type TLSConn interface { net.Conn + // ConnectionState returns basic TLS details about the connection. ConnectionState() tls.ConnectionState + // Handshake runs the client or server handshake + // protocol if it has not yet been run. + // + // Most uses of this package need not call Handshake explicitly: the + // first Read or Write will call it automatically. + // + // For control over canceling or setting a timeout on a handshake, use + // HandshakeContext or the Dialer's DialContext method instead. Handshake() error } + +// NetConnWrapper is the interface to get underlying connection, which is +// introduced in go1.18 for *tls.Conn. +type NetConnWrapper interface { + // NetConn returns the underlying connection that is wrapped by c. + // Note that writing to or reading from this connection directly will corrupt the + // TLS session. + NetConn() net.Conn +} From 88db662f02463c1f05e25cd4c8af778de51ce6a3 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 12 May 2022 14:00:39 +0800 Subject: [PATCH 471/843] http2: use custom concurrent safe noBodyReader type when no body is present https://github.com/golang/net/commit/1d1ef9303861d099ec7e69ccb17377e0c443542d --- h2_transport.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/h2_transport.go b/h2_transport.go index 525ea0a3..04113c31 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -17,7 +17,6 @@ import ( "fmt" "github.com/imroc/req/v3/internal/ascii" "io" - "io/ioutil" "log" "math" mathrand "math/rand" @@ -2952,7 +2951,12 @@ func (t *http2Transport) logf(format string, args ...interface{}) { log.Printf(format, args...) } -var http2noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) +var http2noBody io.ReadCloser = noBodyReader{} + +type noBodyReader struct{} + +func (noBodyReader) Close() error { return nil } +func (noBodyReader) Read([]byte) (int, error) { return 0, io.EOF } type http2missingBody struct{} From e1c29b79d61a77a8ba765da38bd9ff52008772cd Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 12 May 2022 14:03:53 +0800 Subject: [PATCH 472/843] test with go1.18 in github action --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 833e746a..2e17af74 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: name: Build strategy: matrix: - go: [ '1.17.x', '1.16.x' ] + go: [ '1.18.x', '1.17.x' ] os: [ ubuntu-latest ] runs-on: ${{ matrix.os }} From 451dc9beae29295ea7ef00b0031d0d7e96aa2391 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 12 May 2022 14:31:56 +0800 Subject: [PATCH 473/843] net/http: deflake request-not-written path https://github.com/golang/go/commit/3d7f83612390d913e7e8bb4ffa3dc69c41b3078d --- transport.go | 6 ++++++ transport_internal_test.go | 9 +++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/transport.go b/transport.go index 6d63e2bd..abe85fde 100644 --- a/transport.go +++ b/transport.go @@ -596,6 +596,9 @@ func (t *Transport) roundTrip(req *http.Request) (*http.Response, error) { } else if !pconn.shouldRetryRequest(req, err) { // Issue 16465: return underlying net.Conn.Read error from peek, // as we've historically done. + if e, ok := err.(nothingWrittenError); ok { + err = e.error + } if e, ok := err.(transportReadFromServerError); ok { err = e.err } @@ -2069,6 +2072,9 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte } if _, ok := err.(transportReadFromServerError); ok { + if pc.nwrite == startBytesWritten { + return nothingWrittenError{err} + } // Don't decorate return err } diff --git a/transport_internal_test.go b/transport_internal_test.go index f0b9edd6..47c25a06 100644 --- a/transport_internal_test.go +++ b/transport_internal_test.go @@ -58,8 +58,8 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) { conn.Close() // simulate the server hanging up on the client _, err = pc.roundTrip(treq) - if !isTransportReadFromServerError(err) && err != errServerClosedIdle { - t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err) + if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle { + t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle, transportReadFromServerError, or nothingWrittenError", err, err) } <-pc.closech @@ -69,6 +69,11 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) { } } +func isNothingWrittenError(err error) bool { + _, ok := err.(nothingWrittenError) + return ok +} + func isTransportReadFromServerError(err error) bool { _, ok := err.(transportReadFromServerError) return ok From c307b32a6848048233e0bed6b54d1aa83d98154f Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 12 May 2022 14:39:08 +0800 Subject: [PATCH 474/843] net/http: ignore ECONNRESET errors in TestTransportConcurrency on netbsd https://github.com/golang/go/commit/81ae993e54547415ba674082801b05961e3f2aa3 --- transport_test.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/transport_test.go b/transport_test.go index 47d921b6..3452f08d 100644 --- a/transport_test.go +++ b/transport_test.go @@ -2279,17 +2279,21 @@ func TestTransportConcurrency(t *testing.T) { for req := range reqs { res, err := c.Get(ts.URL + "/?echo=" + req) if err != nil { - t.Errorf("error on req %s: %v", req, err) + if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") { + // https://go.dev/issue/52168: this test was observed to fail with + // ECONNRESET errors in Dial on various netbsd builders. + t.Logf("error on req %s: %v", req, err) + t.Logf("(see https://go.dev/issue/52168)") + } else { + t.Errorf("error on req %s: %v", req, err) + } wg.Done() continue } all, err := io.ReadAll(res.Body) if err != nil { t.Errorf("read error on req %s: %v", req, err) - wg.Done() - continue - } - if string(all) != req { + } else if string(all) != req { t.Errorf("body of req %s = %q; want %q", req, all, req) } res.Body.Close() From fd7448230bac11d29b008f4f84b58197dec4fd5e Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 12 May 2022 14:46:56 +0800 Subject: [PATCH 475/843] net/http: correctly show error types in transfer test https://github.com/golang/go/commit/740a490f71d026bb7d2d13cb8fa2d6d6e0572b70 --- transfer_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transfer_test.go b/transfer_test.go index 056d920d..0721aeed 100644 --- a/transfer_test.go +++ b/transfer_test.go @@ -268,7 +268,7 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { } if tc.expectedReader != actualReader { - t.Fatalf("got reader %T want %T", actualReader, tc.expectedReader) + t.Fatalf("got reader %s want %s", actualReader, tc.expectedReader) } } From f73e2f3dd76a304137b2e81fa362206b8fbf56e5 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 12 May 2022 15:05:28 +0800 Subject: [PATCH 476/843] net/http: deflake TestTransportConnectionCloseOnRequest https://github.com/golang/go/commit/a2f7d9d95a84dedb6909bf1907d6857c2c4a2ef5 --- transport_test.go | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/transport_test.go b/transport_test.go index 3452f08d..59c3f684 100644 --- a/transport_test.go +++ b/transport_test.go @@ -436,6 +436,12 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { connSet.check(t) } +// TestTransportConnectionCloseOnRequest tests that the Transport's doesn't reuse +// an underlying TCP connection after making an http.Request with Request.Close set. +// +// It tests the behavior by making an HTTP request to a server which +// describes the source source connection it got (remote port number + +// address of its net.Conn) func TestTransportConnectionCloseOnRequest(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) @@ -446,7 +452,7 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { c := tc().httpClient tr := c.Transport.(*Transport) tr.DialContext = testDial - for _, connectionClose := range []bool{false, true} { + for _, reqClose := range []bool{false, true} { fetch := func(n int) string { req := new(http.Request) var err error @@ -458,29 +464,37 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { req.Proto = "HTTP/1.1" req.ProtoMajor = 1 req.ProtoMinor = 1 - req.Close = connectionClose + req.Close = reqClose res, err := c.Do(req) if err != nil { - t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) + t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err) } - if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(connectionClose); got != want { - t.Errorf("For connectionClose = %v; handler's X-Saw-Close was %v; want %v", - connectionClose, got, !connectionClose) + if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want { + t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v", + reqClose, got, !reqClose) } body, err := io.ReadAll(res.Body) if err != nil { - t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) + t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err) } return string(body) } body1 := fetch(1) body2 := fetch(2) - bodiesDiffer := body1 != body2 - if bodiesDiffer != connectionClose { - t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", - connectionClose, bodiesDiffer, body1, body2) + + got := 1 + if body1 != body2 { + got++ + } + want := 1 + if reqClose { + want = 2 + } + if got != want { + t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q", + reqClose, got, want, body1, body2) } tr.CloseIdleConnections() From 64b3c41fee958bfbdfe01e6f374ed5169b995742 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 12 May 2022 20:52:58 +0800 Subject: [PATCH 477/843] avoid panic when invoke Response if error happened --- response.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/response.go b/response.go index e0d27d44..d57d8307 100644 --- a/response.go +++ b/response.go @@ -17,16 +17,25 @@ type Response struct { // IsSuccess method returns true if HTTP status `code >= 200 and <= 299` otherwise false. func (r *Response) IsSuccess() bool { + if r.Response == nil { + return false + } return r.StatusCode > 199 && r.StatusCode < 300 } // IsError method returns true if HTTP status `code >= 400` otherwise false. func (r *Response) IsError() bool { + if r.Response == nil { + return false + } return r.StatusCode > 399 } // GetContentType return the `Content-Type` header value. func (r *Response) GetContentType() string { + if r.Response == nil { + return "" + } return r.Header.Get(hdrContentTypeKey) } From 0a9ad4b22d4cdd94b04bdfab3b334ee9eb884d62 Mon Sep 17 00:00:00 2001 From: puz_zle Date: Sun, 15 May 2022 22:46:55 +0800 Subject: [PATCH 478/843] Update api.md https://pkg.go.dev/github.com/imroc/req/v3#Client.SetXmlUnmarshal --- docs/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 821ecf6a..fe0ba0d8 100644 --- a/docs/api.md +++ b/docs/api.md @@ -94,7 +94,7 @@ Basically, you can know the meaning of most settings directly from the method na * [SetJsonUnmarshal(fn func(data []byte, v interface{}) error)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonUnmarshal) * [SetJsonMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonMarshal) -* [SetXmlMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#SetXmlUnmarshal) +* [SetXmlUnmarshal(fn func(data []byte, v interface{}) error)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetXmlUnmarshal) * [SetXmlMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetXmlMarshal) ### Middleware From 6dd03866a0b09c2e2a5d4edbe4f2569d1eab235b Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 17 May 2022 14:34:58 +0800 Subject: [PATCH 479/843] improve debug log --- client.go | 5 ----- h2_transport.go | 3 +++ h2_transport_test.go | 2 ++ transport.go | 3 +++ 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 744b9682..1a9deed1 100644 --- a/client.go +++ b/client.go @@ -1038,11 +1038,6 @@ func (c *Client) do(r *Request) (resp *Response, err error) { req = req.WithContext(ctx) } r.RawRequest = req - - if c.DebugLog { - c.log.Debugf("%s %s", req.Method, req.URL.String()) - } - r.StartTime = time.Now() var httpResponse *http.Response httpResponse, err = c.httpClient.Do(req) diff --git a/h2_transport.go b/h2_transport.go index 04113c31..8c55563a 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -1091,6 +1091,9 @@ func (cc *http2ClientConn) decrStreamReservationsLocked() { } func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + if cc.t != nil && cc.t.t1 != nil && cc.t.t1.Debugf != nil { + cc.t.t1.Debugf("HTTP/2 %s %s", req.Method, req.URL.String()) + } cc.currentRequest = req ctx := req.Context() cs := &http2clientStream{ diff --git a/h2_transport_test.go b/h2_transport_test.go index 3addbe14..1b143d8a 100644 --- a/h2_transport_test.go +++ b/h2_transport_test.go @@ -110,6 +110,7 @@ func TestTransportH2c(t *testing.T) { } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) tr := &http2Transport{ + t1: &Transport{}, AllowHTTP: true, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { return net.Dial(network, addr) @@ -638,6 +639,7 @@ func TestTransportDialTLSh2(t *testing.T) { ) defer ts.Close() tr := &http2Transport{ + t1: &Transport{}, DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) { mu.Lock() didDial = true diff --git a/transport.go b/transport.go index abe85fde..2dd2006c 100644 --- a/transport.go +++ b/transport.go @@ -2741,6 +2741,9 @@ var ( ) func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, err error) { + if pc.t.Debugf != nil { + pc.t.Debugf("HTTP/1.1 %s %s", req.Method, req.URL.String()) + } testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { pc.t.putOrCloseIdleConn(pc) From 62efe1e6dccf2d6d19b837085a5b853519afba84 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 17 May 2022 15:17:26 +0800 Subject: [PATCH 480/843] Let EnableForceHTTP1 also take effect when called when there is already an http2 connection --- transport.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/transport.go b/transport.go index 2dd2006c..24d60420 100644 --- a/transport.go +++ b/transport.go @@ -525,14 +525,16 @@ func (t *Transport) roundTrip(req *http.Request) (*http.Response, error) { cancelKey := cancelKey{origReq} req = setupRewindBody(req) - if altRT := t.alternateRoundTripper(req); altRT != nil { - if resp, err := altRT.RoundTrip(req); err != http.ErrSkipAltProtocol { - return resp, err - } - var err error - req, err = rewindBody(req) - if err != nil { - return nil, err + if t.ForceHttpVersion != HTTP1 { + if altRT := t.alternateRoundTripper(req); altRT != nil { + if resp, err := altRT.RoundTrip(req); err != http.ErrSkipAltProtocol { + return resp, err + } + var err error + req, err = rewindBody(req) + if err != nil { + return nil, err + } } } if !isHTTP { @@ -576,7 +578,7 @@ func (t *Transport) roundTrip(req *http.Request) (*http.Response, error) { } var resp *http.Response - if pconn.alt != nil { + if t.ForceHttpVersion != HTTP1 && pconn.alt != nil { // HTTP/2 path. t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest resp, err = pconn.alt.RoundTrip(req) From 6f0074c7741a0b37470de381d7a76032ca23460d Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 19 May 2022 21:35:35 +0800 Subject: [PATCH 481/843] update README --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 5bf9fa4f..0c2b3d95 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,10 @@ If you want to use the older version, check it out on [v1 branch](https://github > v2 is a transitional version, due to some breaking changes were introduced during optmize user experience +## Documentation + +Full documentation is available on the [Req Official Website](https://req.cool/). + ## Table of Contents * [Features](#Features) From 59e60f07e81a18d054dcf59bc8480c73197701ae Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 27 May 2022 18:07:53 +0800 Subject: [PATCH 482/843] Expose method of Request --- client.go | 2 +- middleware.go | 4 ++-- request.go | 4 ++-- request_test.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 1a9deed1..1d61dbfd 100644 --- a/client.go +++ b/client.go @@ -1000,7 +1000,7 @@ func (c *Client) do(r *Request) (resp *Response, err error) { } } req := &http.Request{ - Method: r.method, + Method: r.Method, Header: header, URL: r.URL, Host: host, diff --git a/middleware.go b/middleware.go index 1bce9f6a..758579b6 100644 --- a/middleware.go +++ b/middleware.go @@ -220,12 +220,12 @@ func handleMarshalBody(c *Client, r *Request) error { } func parseRequestBody(c *Client, r *Request) (err error) { - if c.isPayloadForbid(r.method) { + if c.isPayloadForbid(r.Method) { r.getBody = nil return } // handle multipart - if r.isMultiPart && (r.method != http.MethodPatch) { + if r.isMultiPart && (r.Method != http.MethodPatch) { return handleMultiPart(c, r) } diff --git a/request.go b/request.go index 0de45ac5..7b886f75 100644 --- a/request.go +++ b/request.go @@ -34,7 +34,7 @@ type Request struct { RetryAttempt int RawURL string // read only - method string + Method string URL *urlpkg.URL getBody GetContentFunc uploadCallback UploadCallback @@ -453,7 +453,7 @@ func (r *Request) Send(method, url string) (*Response, error) { if r.retryOption != nil && r.retryOption.MaxRetries > 0 && r.unReplayableBody != nil { // retryable request should not have unreplayable body return &Response{Request: r}, errRetryableWithUnReplayableBody } - r.method = method + r.Method = method r.RawURL = url return r.client.do(r) } diff --git a/request_test.go b/request_test.go index f88532d1..5dd76597 100644 --- a/request_test.go +++ b/request_test.go @@ -140,7 +140,7 @@ func TestSendMethods(t *testing.T) { testMethod(t, c, func(req *Request) *Response { resp, err := tc.SendReq(req) if err != nil { - t.Errorf("%s %s: %s", req.method, req.RawURL, err.Error()) + t.Errorf("%s %s: %s", req.Method, req.RawURL, err.Error()) } return resp }, tc.ExpectMethod, false) From b478931a6bfce09d5ad39c3741aaeb3795a840f6 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 27 May 2022 20:12:38 +0800 Subject: [PATCH 483/843] update README --- README.md | 947 +----------------------------------------------------- 1 file changed, 10 insertions(+), 937 deletions(-) diff --git a/README.md b/README.md index 0c2b3d95..00be1e3b 100644 --- a/README.md +++ b/README.md @@ -24,29 +24,6 @@ If you want to use the older version, check it out on [v1 branch](https://github Full documentation is available on the [Req Official Website](https://req.cool/). -## Table of Contents - -* [Features](#Features) -* [Get Started](#Get-Started) -* [Debugging - Dump/Log/Trace](#Debugging) -* [Quick HTTP Test](#Test) -* [HTTP2 and HTTP1](#HTTP2-HTTP1) -* [URL Path and Query Parameter](#Param) -* [Form Data](#Form) -* [Header and Cookie](#Header-Cookie) -* [Body and Marshal/Unmarshal](#Body) -* [Custom Certificates](#Cert) -* [Basic Auth and Bearer Token](#Auth) -* [Download and Upload](#Download-Upload) -* [Auto-Decode](#AutoDecode) -* [Request and Response Middleware](#Middleware) -* [Redirect Policy](#Redirect) -* [Proxy](#Proxy) -* [Unix Socket](#Unix) -* [Retry](#Retry) -* [TODO List](#TODO) -* [License](#License) - ## Features * Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. @@ -81,18 +58,19 @@ import "github.com/imroc/req/v3" // For test, you can create and send a request with the global default // client, use DevMode to see all details, try and suprise :) req.DevMode() -req.Get("https://api.github.com/users/imroc") +req.Get("https://httpbin.org/get") -// Create and send a request with the custom client and settings -client := req.C(). // Use C() to create a client - SetUserAgent("my-custom-client"). // Chainable client settings +// In production, create a client explicitly and reuse it to send all requests +// Create and send a request with the custom client and settings. +client := req.C(). // Use C() to create a client and set with chainable client settings. + SetUserAgent("my-custom-client"). SetTimeout(5 * time.Second). DevMode() -resp, err := client.R(). // Use R() to create a request - SetHeader("Accept", "application/vnd.github.v3+json"). // Chainable request settings +resp, err := client.R(). // Use R() to create a request and set with chainable request settings. + SetHeader("Accept", "application/vnd.github.v3+json"). SetPathParam("username", "imroc"). SetQueryParam("page", "1"). - SetResult(&result). + SetResult(&result). // Unmarshal response into struct automatically. Get("https://api.github.com/users/{username}/repos") ``` @@ -101,914 +79,9 @@ resp, err := client.R(). // Use R() to create a request * [Get Started With Req](https://www.youtube.com/watch?v=k47i0CKBVrA) (English, Youtube) * [快速上手 req](https://www.bilibili.com/video/BV1Xq4y1b7UR) (Chinese, BiliBili) -**API Reference** - -Checkout [Quick API Reference](docs/api.md) for a brief and categorized list of the core APIs, for a more detailed and complete list of APIs, please refer to [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). - -**Examples** - -Checkout more examples below or runnable examples in the [examples](examples) directory. - -## Debugging - Dump/Log/Trace - -**Dump the Content** - -```go -// Enable dump at client level, which will dump for all requests, -// including all content of request and response and output -// to stdout by default. -client := req.C().EnableDumpAll() -client.R().Get("https://httpbin.org/get") - -/* Output -:authority: httpbin.org -:method: GET -:path: /get -:scheme: https -user-agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36 -accept-encoding: gzip - -:status: 200 -date: Wed, 26 Jan 2022 06:39:20 GMT -content-type: application/json -content-length: 372 -server: gunicorn/19.9.0 -access-control-allow-origin: * -access-control-allow-credentials: true - -{ - "args": {}, - "headers": { - "Accept-Encoding": "gzip", - "Host": "httpbin.org", - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36", - "X-Amzn-Trace-Id": "Root=1-61f0ec98-5958c02662de26e458b7672b" - }, - "origin": "103.7.29.30", - "url": "https://httpbin.org/get" -} -*/ - -// Customize dump settings with predefined and convenient settings at client level. -client.EnableDumpAllWithoutBody(). // Only dump the header of request and response - EnableDumpAllAsync(). // Dump asynchronously to improve performance - EnableDumpAllToFile("reqdump.log") // Dump to file without printing it out -// Send request to see the content that have been dumped -client.R().Get(url) - -// Enable dump with fully customized settings at client level. -opt := &req.DumpOptions{ - Output: os.Stdout, - RequestHeader: true, - ResponseBody: true, - RequestBody: false, - ResponseHeader: false, - Async: false, - } -client.SetCommonDumpOptions(opt).EnableDumpAll() -client.R().Get("https://httpbin.org/get") - -// Change settings dynamiclly -opt.ResponseBody = false -client.R().Get("https://httpbin.org/get") - -// You can also enable dump at request level, which will not override client-level dumping, -// dump to the internal buffer and will not print it out by default, you can call `Response.Dump()` -// to get the dump result and print only if you want to, typically used in production, only record -// the content of the request when the request is abnormal to help us troubleshoot problems. -resp, err := client.R().EnableDump().SetBody("test body").Post("https://httpbin.org/post") -if err != nil { - fmt.Println("err:", err) - if resp.Dump() != "" { - fmt.Println("raw content:") - fmt.Println(resp.Dump()) - } - return -} -if !resp.IsSuccess() { // Status code not beetween 200 and 299 - fmt.Println("bad status:", resp.Status) - fmt.Println("raw content:") - fmt.Println(resp.Dump()) - return -} - -// Similarly, also support to customize dump settings with the predefined and convenient settings at request level. -resp, err = client.R().EnableDumpWithoutRequest().SetBody("test body").Post("https://httpbin.org/post") -// ... -resp, err = client.R().SetDumpOptions(opt).EnableDump().SetBody("test body").Post("https://httpbin.org/post") -``` - -**Enable DebugLog for Deeper Insights** - -```go -// Logging is enabled by default, but only output the warning and error message. -// Use `EnableDebugLog` to enable debug level logging. -client := req.C().EnableDebugLog() -client.R().Get("http://baidu.com/s?wd=req") -/* Output -2022/01/26 15:46:29.279368 DEBUG [req] GET http://baidu.com/s?wd=req -2022/01/26 15:46:29.469653 DEBUG [req] charset iso-8859-1 detected in Content-Type, auto-decode to utf-8 -2022/01/26 15:46:29.469713 DEBUG [req] GET http://www.baidu.com/s?wd=req -... -*/ - -// SetLogger with nil to disable all log -client.SetLogger(nil) - -// Or customize the logger with your own implementation. -client.SetLogger(logger) -``` - -**Enable Trace to Analyze Performance** - -```go -// Enable trace at request level -client := req.C() -resp, err := client.R().EnableTrace().Get("https://api.github.com/users/imroc") -if err != nil { - log.Fatal(err) -} -trace := resp.TraceInfo() // Use `resp.Request.TraceInfo()` to avoid unnecessary struct copy in production. -fmt.Println(trace.Blame()) // Print out exactly where the http request is slowing down. -fmt.Println("----------") -fmt.Println(trace) // Print details - -/* Output -the request total time is 2.562416041s, and costs 1.289082208s from connection ready to server respond frist byte --------- -TotalTime : 2.562416041s -DNSLookupTime : 445.246375ms -TCPConnectTime : 428.458µs -TLSHandshakeTime : 825.888208ms -FirstResponseTime : 1.289082208s -ResponseTime : 1.712375ms -IsConnReused: : false -RemoteAddr : 98.126.155.187:443 -*/ - -// Enable trace at client level -client.EnableTraceAll() -resp, err = client.R().Get(url) -// ... -``` - -**DevMode** - -If you want to enable all debug features (dump, debug log and tracing), just call `DevMode()`: - -```go -client := req.C().DevMode() -client.R().Get("https://imroc.cc") -``` - -## Quick HTTP Test - -**Test with Global Wrapper Methods** - -`req` wrap methods of both `Client` and `Request` with global methods, which is delegated to the default client behind the scenes, so you can just treat the package name `req` as a Client or Request to test quickly without create one explicitly. - -```go -// Call the global methods just like the Client's method, -// so you can treat package name `req` as a Client, and -// you don't need to create any client explicitly. -req.SetTimeout(5 * time.Second). - SetCommonBasicAuth("imroc", "123456"). - SetCommonHeader("Accept", "text/xml"). - SetUserAgent("my api client"). - DevMode() - -// Call the global method just like the Request's method, -// which will create request automatically using the default -// client, so you can treat package name `req` as a Request, -// and you don't need to create any request and client explicitly. -req.SetQueryParam("page", "2"). - SetHeader("Accept", "application/json"). // Override client level settings at request level. - Get("https://httpbin.org/get") -``` - -**Test with MustXXX** - -Use `MustXXX` to ignore error handling during test, make it possible to complete a complex test with just one line of code: - -```go -fmt.Println(req.DevMode().R().MustGet("https://imroc.cc").TraceInfo()) -``` - -## HTTP2 and HTTP1 - -Req works fine both with `HTTP/2` and `HTTP/1.1`, `HTTP/2` is preferred by default if server support, which is negotiated by TLS handshake. - -You can force using `HTTP/1.1` if you want. - -```go -client := req.C().EnableForceHTTP1().EnableDumpAllWithoutBody() -client.R().MustGet("https://httpbin.org/get") -/* Output -GET /get HTTP/1.1 -Host: httpbin.org -User-Agent: req/v3 (https://github.com/imroc/req) -Accept-Encoding: gzip - -HTTP/1.1 200 OK -Date: Tue, 08 Feb 2022 02:30:18 GMT -Content-Type: application/json -Content-Length: 289 -Connection: keep-alive -Server: gunicorn/19.9.0 -Access-Control-Allow-Origin: * -Access-Control-Allow-Credentials: true -*/ -``` - -And also you can force using `HTTP/2` if you want, will return error if server does not support: - -```go -client := req.C().EnableForceHTTP2() -client.R().MustGet("https://baidu.com") -/* Output -panic: Get "https://baidu.com": server does not support http2, you can use http/1.1 which is supported -*/ -``` - -## URL Path and Query Parameter - -**Path Parameter** - -Use `SetPathParam` or `SetPathParams` to replace variable in the url path: - -```go -client := req.C().DevMode() - -client.R(). - SetPathParam("owner", "imroc"). // Set a path param, which will replace the vairable in url path - SetPathParams(map[string]string{ // Set multiple path params at once - "repo": "req", - "path": "README.md", - }).Get("https://api.github.com/repos/{owner}/{repo}/contents/{path}") // path parameter will replace path variable in the url -/* Output -2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md -... -*/ - -// You can also set the common PathParam for every request on client -client.SetCommonPathParam(k1, v1).SetCommonPathParams(pathParams) - -resp1, err := client.Get(url1) -... - -resp2, err := client.Get(url2) -... -``` - -**Query Parameter** - -Use `SetQueryParam`, `SetQueryParams` or `SetQueryString` to append url query parameter: - -```go -client := req.C().DevMode() - -// Set query parameter at request level. -client.R(). - SetQueryParam("a", "a"). // Set a query param, which will be encoded as query parameter in url - SetQueryParams(map[string]string{ // Set multiple query params at once - "b": "b", - "c": "c", - }).SetQueryString("d=d&e=e"). // Set query params as a raw query string - Get("https://api.github.com/repos/imroc/req/contents/README.md?x=x") -/* Output -2022/01/23 14:43:59.114592 DEBUG [req] GET https://api.github.com/repos/imroc/req/contents/README.md?x=x&a=a&b=b&c=c&d=d&e=e -... -*/ - -// You can also set the query parameter at client level. -client.SetCommonQueryParam(k, v). - SetCommonQueryParams(queryParams). - SetCommonQueryString(queryString). - -resp1, err := client.Get(url1) -... -resp2, err := client.Get(url2) -... - -// Add query parameter with multiple values at request level. -client.R().AddQueryParam("key", "value1").AddQueryParam("key", "value2").Get("https://httpbin.org/get") -/* Output -2022/02/05 08:49:26.260780 DEBUG [req] GET https://httpbin.org/get?key=value1&key=value2 -... - */ - - -// Multiple values also supported at client level. -client.AddCommonQueryParam("key", "value1").AddCommonQueryParam("key", "value2") -``` - -## Form Data - -```go -client := req.C().EnableDumpAllWithoutResponse() -client.R().SetFormData(map[string]string{ - "username": "imroc", - "blog": "https://imroc.cc", -}).Post("https://httpbin.org/post") -/* Output -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -content-type: application/x-www-form-urlencoded -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -blog=https%3A%2F%2Fimroc.cc&username=imroc -*/ - -// Multi value form data -v := url.Values{ - "multi": []string{"a", "b", "c"}, -} -client.R().SetFormDataFromValues(v).Post("https://httpbin.org/post") -/* Output -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -content-type: application/x-www-form-urlencoded -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -multi=a&multi=b&multi=c -*/ - -// You can also set form data in client level -client.SetCommonFormData(m) -client.SetCommonFormDataFromValues(v) -``` - -> `GET`, `HEAD`, and `OPTIONS` requests ignores form data by default - -## Header and Cookie - -**Set Header** - -```go -// Let's dump the header to see what's going on -client := req.C().EnableForceHTTP1().EnableDumpAllWithoutResponse() - -// Send a request with multiple headers and cookies -client.R(). - SetHeader("Accept", "application/json"). // Set one header - SetHeaders(map[string]string{ // Set multiple headers at once - "My-Custom-Header": "My Custom Value", - "User": "imroc", - }).Get("https://httpbin.org/get") - -/* Output -GET /get HTTP/1.1 -Host: httpbin.org -User-Agent: req/v3 (https://github.com/imroc/req) -Accept: application/json -My-Custom-Header: My Custom Value -User: imroc -Accept-Encoding: gzip -*/ - -// You can also set the common header and cookie for every request on client. -client.SetCommonHeader(header).SetCommonHeaders(headers) - -resp1, err := client.R().Get(url1) -... -resp2, err := client.R().Get(url2) -... -``` - -**Set Cookie** - -```go -// Let's dump the header to see what's going on -client := req.C().EnableForceHTTP1().EnableDumpAllWithoutResponse() - -// Send a request with multiple headers and cookies -client.R(). - SetCookies( - &http.Cookie{ - Name: "testcookie1", - Value: "testcookie1 value", - Path: "/", - Domain: "baidu.com", - MaxAge: 36000, - HttpOnly: false, - Secure: true, - }, - &http.Cookie{ - Name: "testcookie2", - Value: "testcookie2 value", - Path: "/", - Domain: "baidu.com", - MaxAge: 36000, - HttpOnly: false, - Secure: true, - }, - ).Get("https://httpbin.org/get") - -/* Output -GET /get HTTP/1.1 -Host: httpbin.org -User-Agent: req/v3 (https://github.com/imroc/req) -Cookie: testcookie1="testcookie1 value"; testcookie2="testcookie2 value" -Accept-Encoding: gzip -*/ - -// You can also set the common cookie for every request on client. -client.SetCommonCookies(cookie1, cookie2, cookie3) - -resp1, err := client.R().Get(url1) -... -resp2, err := client.R().Get(url2) -``` - -You can also customize the CookieJar: -```go -// Set your own http.CookieJar implementation -client.SetCookieJar(jar) - -// Set to nil to disable CookieJar -client.SetCookieJar(nil) -``` - -## Body and Marshal/Unmarshal - -**Request Body** - -```go -// Create a client that dump request -client := req.C().EnableDumpAllWithoutResponse() -// SetBody accepts string, []byte, io.Reader, use type assertion to -// determine the data type of body automatically. -client.R().SetBody("test").Post("https://httpbin.org/post") -/* Output -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -test -*/ - -// If it cannot determine, like map and struct, then it will wait -// and marshal to JSON or XML automatically according to the `Content-Type` -// header that have been set before or after, default to json if not set. -type User struct { - Name string `json:"name"` - Email string `json:"email"` -} -user := &User{Name: "imroc", Email: "roc@imroc.cc"} -client.R().SetBody(user).Post("https://httpbin.org/post") -/* Output -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -content-type: application/json; charset=utf-8 -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -{"name":"imroc","email":"roc@imroc.cc"} -*/ - - -// You can use more specific methods to avoid type assertions and improves performance, -client.R().SetBodyJsonString(`{"username": "imroc"}`).Post("https://httpbin.org/post") -/* -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -content-type: application/json; charset=utf-8 -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -{"username": "imroc"} -*/ - -// Marshal body and set `Content-Type` automatically without any guess -cient.R().SetBodyXmlMarshal(user).Post("https://httpbin.org/post") -/* Output -:authority: httpbin.org -:method: POST -:path: /post -:scheme: https -content-type: text/xml; charset=utf-8 -accept-encoding: gzip -user-agent: req/v2 (https://github.com/imroc/req) - -imrocroc@imroc.cc -*/ -``` - -**Response Body** - -```go -// Define success body struct -type User struct { - Name string `json:"name"` - Blog string `json:"blog"` -} -// Define error body struct -type ErrorMessage struct { - Message string `json:"message"` -} -// Create a client and dump body to see details -client := req.C().EnableDumpAllWithoutHeader() - -// Send a request and unmarshal result automatically according to -// response `Content-Type` -user := &User{} -errMsg := &ErrorMessage{} -resp, err := client.R(). - SetResult(user). // Set success result - SetError(errMsg). // Set error result - Get("https://api.github.com/users/imroc") -if err != nil { - log.Fatal(err) -} -fmt.Println("----------") - -if resp.IsSuccess() { // status `code >= 200 and <= 299` is considered as success - // Must have been marshaled to user if no error returned before - fmt.Printf("%s's blog is %s\n", user.Name, user.Blog) -} else if resp.IsError() { // status `code >= 400` is considered as error - // Must have been marshaled to errMsg if no error returned before - fmt.Println("got error:", errMsg.Message) -} else { - log.Fatal("unknown http status:", resp.Status) -} -/* Output -{"login":"imroc","id":7448852,"node_id":"MDQ6VXNlcjc0NDg4NTI=","avatar_url":"https://avatars.githubusercontent.com/u/7448852?v=4","gravatar_id":"","url":"https://api.github.com/users/imroc","html_url":"https://github.com/imroc","followers_url":"https://api.github.com/users/imroc/followers","following_url":"https://api.github.com/users/imroc/following{/other_user}","gists_url":"https://api.github.com/users/imroc/gists{/gist_id}","starred_url":"https://api.github.com/users/imroc/starred{/owner}{/repo}","subscriptions_url":"https://api.github.com/users/imroc/subscriptions","organizations_url":"https://api.github.com/users/imroc/orgs","repos_url":"https://api.github.com/users/imroc/repos","events_url":"https://api.github.com/users/imroc/events{/privacy}","received_events_url":"https://api.github.com/users/imroc/received_events","type":"User","site_admin":false,"name":"roc","company":"Tencent","blog":"https://imroc.cc","location":"China","email":null,"hireable":true,"bio":"I'm roc","twitter_username":"imrocchan","public_repos":129,"public_gists":0,"followers":362,"following":151,"created_at":"2014-04-30T10:50:46Z","updated_at":"2022-01-24T23:32:53Z"} ----------- -roc's blog is https://imroc.cc -*/ - -// Or you can also unmarshal response later -if resp.IsSuccess() { - err = resp.Unmarshal(user) - if err != nil { - log.Fatal(err) - } - fmt.Printf("%s's blog is %s\n", user.Name, user.Blog) -} else { - fmt.Println("bad response:", resp) -} - -// Also, you can get the raw response and Unmarshal by yourself -yaml.Unmarshal(resp.Bytes()) -``` - -**Customize JSON&XML Marshal/Unmarshal** - -```go -// Example of registering json-iterator -import jsoniter "github.com/json-iterator/go" - -json := jsoniter.ConfigCompatibleWithStandardLibrary - -client := req.C(). - SetJsonMarshal(json.Marshal). - SetJsonUnmarshal(json.Unmarshal) - -// Similarly, XML functions can also be customized -client.SetXmlMarshal(xmlMarshalFunc).SetXmlUnmarshal(xmlUnmarshalFunc) -``` - -**Disable Auto-Read Response Body** - -Response body will be read into memory if it's not a download request by default, you can disable it if you want (normally you don't need to do this). - -```go -client.DisableAutoReadResponse() - -resp, err := client.R().Get(url) -if err != nil { - log.Fatal(err) -} -io.Copy(dst, resp.Body) -``` - -## Custom Certificates - -```go -client := req.R() - -// Set root cert and client cert from file path -client.SetRootCertsFromFile("/path/to/root/certs/pemFile1.pem", "/path/to/root/certs/pemFile2.pem", "/path/to/root/certs/pemFile3.pem"). // Set root cert from one or more pem files - SetCertFromFile("/path/to/client/certs/client.pem", "/path/to/client/certs/client.key") // Set client cert and key cert file - -// You can also set root cert from string -client.SetRootCertFromString("-----BEGIN CERTIFICATE-----XXXXXX-----END CERTIFICATE-----") - -// And set client cert with -cert1, err := tls.LoadX509KeyPair("/path/to/client/certs/client.pem", "/path/to/client/certs/client.key") -if err != nil { - log.Fatalf("ERROR client certificate: %s", err) -} -// ... - -// you can add more certs if you want -client.SetCerts(cert1, cert2, cert3) -``` - -## Basic Auth and Bearer Token - -```go -client := req.C() - -// Set basic auth for all request -client.SetCommonBasicAuth("imroc", "123456") - -// Set bearer token for all request -client.SetCommonBearerAuthToken("MDc0ZTg5YmU4Yzc5MjAzZGJjM2ZiMzkz") - -// Set basic auth for a request, will override client's basic auth setting. -client.R().SetBasicAuth("myusername", "mypassword").Get("https://api.example.com/profile") - -// Set bearer token for a request, will override client's bearer token setting. -client.R().SetBearerToken("NGU1ZWYwZDJhNmZhZmJhODhmMjQ3ZDc4").Get("https://api.example.com/profile") -``` - -## Download and Upload - -**Download** - -```go -// Create a client with default download direcotry -client := req.C().SetOutputDirectory("/path/to/download").EnableDumpAllWithoutResponseBody() - -// Download to relative file path, this will be downloaded -// to /path/to/download/test.jpg -client.R().SetOutputFile("test.jpg").Get(url) - -// Download to absolute file path, ignore the output directory -// setting from Client -client.R().SetOutputFile("/tmp/test.jpg").Get(url) - -// You can also save file to any `io.WriteCloser` -file, err := os.Create("/tmp/test.jpg") -if err != nil { - fmt.Println(err) - return -} -client.R().SetOutput(file).Get(url) -``` - -**Download Callback** - -You can set `DownloadCallback` if you want to show download progress: - -```go -client := req.C() -client.R(). - SetOutputFile("test.gz"). - SetUploadCallback(func(info req.UploadInfo) { - fmt.Printf("downloaded %.2f%%\n", float64(info.DownloadedSize)/float64(info.Response.ContentLength)*100.0) - }).Post("https://exmaple.com/upload") -/* Output -downloaded 17.92% -downloaded 41.77% -downloaded 67.71% -downloaded 98.89% -downloaded 100.00% -*/ -``` - -> 1. `info.Response.ContentLength` could be 0 or -1 when the total size is unknown. -> 2. `DownloadCallback` will be invoked at least every 200ms by default, you can customize the minimal invoke interval using `SetDownloadCallbackWithInterval`. - -**Multipart Upload** - -```go -client := req.C().EnableDumpAllWithoutRequestBody() // Request body contains unreadable binary, do not dump - -client.R().SetFile("pic", "test.jpg"). // Set form param name and filename - SetFile("pic", "/path/to/roc.png"). // Multiple files using the same form param name - SetFiles(map[string]string{ // Set multiple files using map - "exe": "test.exe", - "src": "main.go", - }). - SetFormData(map[string]string{ // Set form data while uploading - "name": "imroc", - "email": "roc@imroc.cc", - }). - SetFromDataFromValues(values). // You can also set form data using `url.Values` - Post("http://127.0.0.1:8888/upload") - -// You can also use io.Reader to upload -avatarImgFile, _ := os.Open("avatar.png") -client.R().SetFileReader("avatar", "avatar.png", avatarImgFile).Post(url) -*/ -``` - -**Upload Callback** - -You can set `UploadCallback` if you want to show upload progress: - -```go -client := req.C() -client.R(). - SetFile("excel", "test.xlsx"). - SetUploadCallback(func(info req.UploadInfo) { - fmt.Printf("%q uploaded %.2f%%\n", info.FileName, float64(info.UploadedSize)/float64(info.FileSize)*100.0) - }).Post("https://exmaple.com/upload") -/* Output -"test.xlsx" uploaded 7.44% -"test.xlsx" uploaded 29.78% -"test.xlsx" uploaded 52.08% -"test.xlsx" uploaded 74.47% -"test.xlsx" uploaded 96.87% -"test.xlsx" uploaded 100.00% -*/ -``` - -> `UploadCallback` will be invoked at least every 200ms by default, you can customize the minimal invoke interval using `SetUploadCallbackWithInterval`. - -## Auto-Decode - -`Req` detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default. - -Its principle is to detect `Content-Type` header at first, if it's not the text content type (json, xml, html and so on), `req` will not try to decode. If it is, then `req` will try to find the charset information. And `req` also will try to sniff the body's content to determine the charset if the charset information is not included in the header, if sniffed out and not utf-8, then decode it to utf-8 automatically, and `req` will not try to decode if the charset is not sure, just leave the body untouched. - -You can also disable it if you don't need or care a lot about performance: - -```go -client.DisableAutoDecode() -``` - -And also you can make some customization: - -```go -// Try to auto-detect and decode all content types (some server may return incorrect Content-Type header) -client.SetAutoDecodeAllContentType() - -// Only auto-detect and decode content which `Content-Type` header contains "html" or "json" -client.SetAutoDecodeContentType("html", "json") - -// Or you can customize the function to determine whether to decode -fn := func(contentType string) bool { - if regexContentType.MatchString(contentType) { - return true - } - return false -} -client.SetAutoDecodeContentTypeFunc(fn) -``` - -## Request and Response Middleware - -```go -client := req.C() - -// Registering Request Middleware -client.OnBeforeRequest(func(c *req.Client, r *req.Request) error { - // You can access Client and current Request object to do something - // as you need - - return nil // return nil if it is success - }) - -// Registering Response Middleware -client.OnAfterResponse(func(c *req.Client, r *req.Response) error { - // You can access Client and current Response object to do something - // as you need - - return nil // return nil if it is success - }) -``` - -## Redirect Policy - -```go -client := req.C().EnableDumpAllWithoutResponse() - -client.SetRedirectPolicy( - // Only allow up to 5 redirects - req.MaxRedirectPolicy(5), - // Only allow redirect to same domain. - // e.g. redirect "www.imroc.cc" to "imroc.cc" is allowed, but "google.com" is not - req.SameDomainRedirectPolicy(), -) - -client.SetRedirectPolicy( - // Only *.google.com/google.com and *.imroc.cc/imroc.cc is allowed to redirect - req.AllowedDomainRedirectPolicy("google.com", "imroc.cc"), - // Only allow redirect to same host. - // e.g. redirect "www.imroc.cc" to "imroc.cc" is not allowed, only "www.imroc.cc" is allowed - req.SameHostRedirectPolicy(), -) - -// All redirect is not allowd -client.SetRedirectPolicy(req.NoRedirectPolicy()) - -// Or customize the redirect with your own implementation -client.SetRedirectPolicy(func(req *http.Request, via []*http.Request) error { - // ... -}) -``` - -## Proxy - -`Req` use proxy `http.ProxyFromEnvironment` by default, which will read the `HTTP_PROXY/HTTPS_PROXY/http_proxy/https_proxy` environment variable, and setup proxy if environment variable is been set. You can customize it if you need: - -```go -// Set proxy from proxy url -client.SetProxyURL("http://myproxy:8080") - -// Custmize the proxy function with your own implementation -client.SetProxy(func(request *http.Request) (*url.URL, error) { - // ... -}) - -// Disable proxy -client.SetProxy(nil) -``` - -## Unix Socket - -```go -client := req.C() -client.SetUnixSocket("/var/run/custom.sock") -client.SetBaseURL("http://example.local") - -resp, err := client.R().Get("/index.html") -``` - -## Retry - -You can enable retry for all requests at client-level (check the full list of client-level retry settings around [here](./docs/api.md#Retry-Client)): - -```go -client := req.C() - -// Enable retry and set the maximum retry count. -client.SetCommonRetryCount(3). - -// Set the retry sleep interval with a commonly used algorithm: capped exponential backoff with jitter (https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/). -client.SetCommonRetryBackoffInterval(1 * time.Second, 5 * time.Second) - -// Set the retry to sleep fixed interval of 2 seconds. -client.SetCommonRetryFixedInterval(2 * time.Seconds) - -// Set the retry to use a custom retry interval algorithm. -client.SetCommonRetryInterval(func(resp *req.Response, attempt int) time.Duration { - // Sleep seconds from "Retry-After" response header if it is present and correct. - // https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html - if resp.Response != nil { - if ra := resp.Header.Get("Retry-After"); ra != "" { - after, err := strconv.Atoi(ra) - if err == nil { - return time.Duration(after) * time.Second - } - } - } - return 2 * time.Second // Otherwise, sleep 2 seconds -}) - -// Add a retry condition which determines whether the request should retry. -client.AddCommonRetryCondition(func(resp *req.Response, err error) bool { - return err != nil -}) - -// Add another retry condition -client.AddCommonRetryCondition(func(resp *req.Response, err error) bool { - return resp.StatusCode == http.StatusUnauthorized -}) - -// Add a retry hook which will be executed before a retry. -client.AddCommonRetryHook(func(resp *req.Response, err error){ - req := resp.Request.RawRequest - fmt.Println("Retry request:", req.Method, req.URL) - // Modify request settings in the retry hook. - resp.Request.SetBearerAuthToken(token) -}) -``` - -You can also override retry settings at request-level (check the full list of request-level retry settings around [here](./docs/api.md#Retry-Request)): - -```go -client.R(). - SetRetryCount(2). - SetRetryInterval(intervalFunc). - AddRetryHook(hookFunc2). - SetRetryHook(hookFunc1). // Unlike add, set will remove all other retry hooks which is added before at both request and client level. - AddRetryCondition(conditionFunc2). - SetRetryCondition(conditionFunc1) // Similarly, this will remove all other retry conditions which is added before at both request and client level. -``` - -## TODO List +**More** -* [ ] Wrap more transport settings into client. -* [ ] Support h2c. -* [ ] Design a logo. -* [ ] Support HTTP3. +Check more introduction, tutorials, examples and API references on the [official website](https://req.cool/). ## License From 463be0555d2cdfe1ba8eb1018c0ccc94d74be91a Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 27 May 2022 20:31:40 +0800 Subject: [PATCH 484/843] update README --- README.md | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 00be1e3b..d45f8a48 100644 --- a/README.md +++ b/README.md @@ -26,17 +26,15 @@ Full documentation is available on the [Req Official Website](https://req.cool/) ## Features -* Simple and chainable methods for both client-level and request-level settings, and the request-level setting takes precedence if both are set. -* Powerful and convenient debug utilites, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging - Dump/Log/Trace](#Debugging)). -* Easy making HTTP test with code instead of tools like curl or postman, `req` provide global wrapper methods and `MustXXX` to test API with minimal code (see [Quick HTTP Test](#Test)). -* Works fine with both `HTTP/2` and `HTTP/1.1`, which `HTTP/2` is preferred by default if server support, and you can also force `HTTP/1.1` if you want (see [HTTP2 and HTTP1](#HTTP2-HTTP1)). -* Detect the charset of response body and decode it to utf-8 automatically to avoid garbled characters by default (see [Auto-Decode](#AutoDecode)). -* Automatic marshal and unmarshal for JSON and XML content type and fully customizable (see [Body and Marshal/Unmarshal](#Body)). -* Exportable `Transport`, easy to integrate with existing `http.Client`, debug APIs with minimal code change. -* Easy [Download and Upload](#Download-Upload). -* Easy set header, cookie, path parameter, query parameter, form data, basic auth, bearer token for both client and request level. -* Easy set timeout, proxy, certs, redirect policy, cookie jar, compression, keepalives etc for client. -* Support middleware before request sent and after got response (see [Request and Response Middleware](#Middleware)). +* Simple and Powerful: Providing rich client-level and request-level settings, all of which are intuitive and chainable methods, and the request-level setting takes precedence if both are set. +* Easy Debugging: Powerful and convenient debug utilities, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging](https://req.cool/docs/tutorial/debugging/)). +* Easy API Testing: API testing can be done with minimal code, no need to explicitly create any Requests and Clients, or even to handle errors (See [Quick HTTP Test](https://req.cool/docs/tutorial/quick-test/)) +* Smart by Default: Detect and decode to utf-8 automatically if possible to avoid garbled characters (See [Auto Decode](https://req.cool/docs/tutorial/auto-decode/)), marshal request body and unmarshal response body automatically according to the Content-Type. +* Works fine with HTTP2: Support both with HTTP/2 and HTTP/1.1, and HTTP/2 is preferred by default if server support, you can also force the protocol if you want (See [Force HTTP version](https://req.cool/docs/tutorial/force-http-version/)). +* Support Retry: Support automatic request retry and is fully customizable (See [Retry](https://req.cool/docs/tutorial/retry/)). +* Easy Download and Upload: You can download and upload files with simple request settings, and even set a callback to show real-time progress (See [Download](https://req.cool/docs/tutorial/download/) and [Upload](https://req.cool/docs/tutorial/upload/)). +* Exportable: `Transport` is exportable, which support dump requests, it's easy to integrate with existing http.Client, so you can debug APIs with minimal code change. +* Extensible: Support Middleware for Request and Response (See [Request and Response Middleware](https://req.cool/docs/tutorial/middleware/)). ## Get Started From a2dc4f91e58a07353244c5b534f04d2dc740e65f Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 27 May 2022 20:32:56 +0800 Subject: [PATCH 485/843] update README --- README.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/README.md b/README.md index d45f8a48..2421ffbb 100644 --- a/README.md +++ b/README.md @@ -12,14 +12,6 @@

-## News - -Brand-New version v3 is released, which is completely rewritten, bringing revolutionary innovations and many superpowers, try and enjoy :) - -If you want to use the older version, check it out on [v1 branch](https://github.com/imroc/req/tree/v1). - -> v2 is a transitional version, due to some breaking changes were introduced during optmize user experience - ## Documentation Full documentation is available on the [Req Official Website](https://req.cool/). From 4da8ed5fe0ff84e7685b239ca06e17a1ce0aad8e Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 28 May 2022 08:24:44 +0800 Subject: [PATCH 486/843] update README --- README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 2421ffbb..b806148d 100644 --- a/README.md +++ b/README.md @@ -45,10 +45,12 @@ import "github.com/imroc/req/v3" **Basic Usage** ```go -// For test, you can create and send a request with the global default -// client, use DevMode to see all details, try and suprise :) -req.DevMode() -req.Get("https://httpbin.org/get") +// For testing, you can create and send a request with the global wrapper methods +// that use the default client behind the scenes to initiate the request (you can +// just treat package name `req` as a Client or Request, no need to create any client +// or Request explicitly). +req.DevMode() // Use Client.DevMode to see all details, try and surprise :) +req.Get("https://httpbin.org/get") // Use Request.Get to send a GET request. // In production, create a client explicitly and reuse it to send all requests // Create and send a request with the custom client and settings. From 728eef189d2e1ca46a778666eb07139618bc82f7 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 28 May 2022 08:25:51 +0800 Subject: [PATCH 487/843] update README --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index b806148d..6cc23701 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,6 @@ req.DevMode() // Use Client.DevMode to see all details, try and surprise :) req.Get("https://httpbin.org/get") // Use Request.Get to send a GET request. // In production, create a client explicitly and reuse it to send all requests -// Create and send a request with the custom client and settings. client := req.C(). // Use C() to create a client and set with chainable client settings. SetUserAgent("my-custom-client"). SetTimeout(5 * time.Second). From 2e4b6f66c0bbcc57e7d027462313974df92b179c Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 28 May 2022 09:00:59 +0800 Subject: [PATCH 488/843] update README --- README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 6cc23701..856ea8cd 100644 --- a/README.md +++ b/README.md @@ -18,15 +18,15 @@ Full documentation is available on the [Req Official Website](https://req.cool/) ## Features -* Simple and Powerful: Providing rich client-level and request-level settings, all of which are intuitive and chainable methods, and the request-level setting takes precedence if both are set. -* Easy Debugging: Powerful and convenient debug utilities, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging](https://req.cool/docs/tutorial/debugging/)). -* Easy API Testing: API testing can be done with minimal code, no need to explicitly create any Requests and Clients, or even to handle errors (See [Quick HTTP Test](https://req.cool/docs/tutorial/quick-test/)) -* Smart by Default: Detect and decode to utf-8 automatically if possible to avoid garbled characters (See [Auto Decode](https://req.cool/docs/tutorial/auto-decode/)), marshal request body and unmarshal response body automatically according to the Content-Type. -* Works fine with HTTP2: Support both with HTTP/2 and HTTP/1.1, and HTTP/2 is preferred by default if server support, you can also force the protocol if you want (See [Force HTTP version](https://req.cool/docs/tutorial/force-http-version/)). -* Support Retry: Support automatic request retry and is fully customizable (See [Retry](https://req.cool/docs/tutorial/retry/)). -* Easy Download and Upload: You can download and upload files with simple request settings, and even set a callback to show real-time progress (See [Download](https://req.cool/docs/tutorial/download/) and [Upload](https://req.cool/docs/tutorial/upload/)). -* Exportable: `Transport` is exportable, which support dump requests, it's easy to integrate with existing http.Client, so you can debug APIs with minimal code change. -* Extensible: Support Middleware for Request and Response (See [Request and Response Middleware](https://req.cool/docs/tutorial/middleware/)). +* **Simple and Powerful**: Providing rich client-level and request-level settings, all of which are intuitive and chainable methods, and the request-level setting takes precedence if both are set. +* **Easy Debugging**: Powerful and convenient debug utilities, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging](https://req.cool/docs/tutorial/debugging/)). +* **Easy API Testing**: API testing can be done with minimal code, no need to explicitly create any Requests and Clients, or even to handle errors (See [Quick HTTP Test](https://req.cool/docs/tutorial/quick-test/)) +* **Smart by Default**: Detect and decode to utf-8 automatically if possible to avoid garbled characters (See [Auto Decode](https://req.cool/docs/tutorial/auto-decode/)), marshal request body and unmarshal response body automatically according to the Content-Type. +* **Works fine with HTTP2**: Support both with HTTP/2 and HTTP/1.1, and HTTP/2 is preferred by default if server support, you can also force the protocol if you want (See [Force HTTP version](https://req.cool/docs/tutorial/force-http-version/)). +* **Support Retry**: Support automatic request retry and is fully customizable (See [Retry](https://req.cool/docs/tutorial/retry/)). +* **Easy Download and Upload**: You can download and upload files with simple request settings, and even set a callback to show real-time progress (See [Download](https://req.cool/docs/tutorial/download/) and [Upload](https://req.cool/docs/tutorial/upload/)). +* **Exportable**: `Transport` is exportable, which support dump requests, it's easy to integrate with existing http.Client, so you can debug APIs with minimal code change. +* **Extensible**: Support Middleware for Request and Response (See [Request and Response Middleware](https://req.cool/docs/tutorial/middleware/)). ## Get Started From 029862c6ce49146e77f205d291a91c1c41015dc5 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 1 Jun 2022 21:25:45 +0800 Subject: [PATCH 489/843] improve response unmarshal --- middleware.go | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/middleware.go b/middleware.go index 758579b6..551e2397 100644 --- a/middleware.go +++ b/middleware.go @@ -275,11 +275,25 @@ func parseResponseBody(c *Client, r *Response) (err error) { return } // Handles only JSON or XML content type - if r.Request.Result != nil && r.IsSuccess() { - unmarshalBody(c, r, r.Request.Result) + if r.Request.Result != nil { + if r.IsSuccess() { + err = unmarshalBody(c, r, r.Request.Result) + if err != nil { + r.Request.Result = nil + } + } else { + r.Request.Result = nil + } } - if r.Request.Error != nil && r.IsError() { - unmarshalBody(c, r, r.Request.Error) + if r.Request.Error != nil { + if r.IsError() { + err = unmarshalBody(c, r, r.Request.Error) + if err != nil { + r.Request.Error = nil + } + } else { + r.Request.Error = nil + } } return } From accc5da695ea449e5612aa78ec0c01d5039ad444 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 2 Jun 2022 15:11:58 +0800 Subject: [PATCH 490/843] optimize SetError and SetResult --- client.go | 2 ++ middleware.go | 25 ++++++++----------------- request.go | 12 ++++++------ response.go | 6 ++++-- 4 files changed, 20 insertions(+), 25 deletions(-) diff --git a/client.go b/client.go index 1d61dbfd..14a94d06 100644 --- a/client.go +++ b/client.go @@ -1084,6 +1084,8 @@ func (c *Client) do(r *Request) (resp *Response, err error) { r.trace = nil r.ctx = nil resp.body = nil + resp.result = nil + resp.error = nil } for _, f := range r.client.afterResponse { diff --git a/middleware.go b/middleware.go index 551e2397..964e0bc8 100644 --- a/middleware.go +++ b/middleware.go @@ -274,25 +274,16 @@ func parseResponseBody(c *Client, r *Response) (err error) { if r.StatusCode == http.StatusNoContent { return } - // Handles only JSON or XML content type - if r.Request.Result != nil { - if r.IsSuccess() { - err = unmarshalBody(c, r, r.Request.Result) - if err != nil { - r.Request.Result = nil - } - } else { - r.Request.Result = nil + if r.Request.Result != nil && r.IsSuccess() { + err = unmarshalBody(c, r, r.Request.Result) + if err == nil { + r.result = r.Request.Result } } - if r.Request.Error != nil { - if r.IsError() { - err = unmarshalBody(c, r, r.Request.Error) - if err != nil { - r.Request.Error = nil - } - } else { - r.Request.Error = nil + if r.Request.Error != nil && r.IsError() { + err = unmarshalBody(c, r, r.Request.Error) + if err == nil { + r.error = r.Request.Error } } return diff --git a/request.go b/request.go index 7b886f75..04c8e079 100644 --- a/request.go +++ b/request.go @@ -32,9 +32,11 @@ type Request struct { RawRequest *http.Request StartTime time.Time RetryAttempt int + RawURL string // read only + Method string - RawURL string // read only - Method string + isMultiPart bool + isSaveResponse bool URL *urlpkg.URL getBody GetContentFunc uploadCallback UploadCallback @@ -48,11 +50,9 @@ type Request struct { dumpOptions *DumpOptions marshalBody interface{} ctx context.Context - isMultiPart bool uploadFiles []*FileUpload uploadReader []io.ReadCloser outputFile string - isSaveResponse bool output io.Writer trace *clientTrace dumpBuffer *bytes.Buffer @@ -313,14 +313,14 @@ func (r *Request) SetDownloadCallbackWithInterval(callback DownloadCallback, min return r } -// SetResult set the result that response body will be unmarshaled to if +// SetResult set the result that response body will be unmarshalled to if // request is success (status `code >= 200 and <= 299`). func (r *Request) SetResult(result interface{}) *Request { r.Result = util.GetPointer(result) return r } -// SetError set the result that response body will be unmarshaled to if +// SetError set the result that response body will be unmarshalled to if // request is error ( status `code >= 400`). func (r *Request) SetError(error interface{}) *Request { r.Error = util.GetPointer(error) diff --git a/response.go b/response.go index d57d8307..1646bd58 100644 --- a/response.go +++ b/response.go @@ -13,6 +13,8 @@ type Response struct { Request *Request body []byte receivedAt time.Time + error interface{} + result interface{} } // IsSuccess method returns true if HTTP status `code >= 200 and <= 299` otherwise false. @@ -41,12 +43,12 @@ func (r *Response) GetContentType() string { // Result returns the response value as an object if it has one func (r *Response) Result() interface{} { - return r.Request.Result + return r.result } // Error returns the error object if it has one. func (r *Response) Error() interface{} { - return r.Request.Error + return r.error } // TraceInfo returns the TraceInfo from Request. From f6a7b9152706edd47d547c04ef7f9d90f1eaa06d Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 2 Jun 2022 17:09:11 +0800 Subject: [PATCH 491/843] client support SetCommonError --- client.go | 32 +++++++++++++++++++++----------- client_wrapper.go | 6 ++++++ client_wrapper_test.go | 1 + internal/util/util.go | 5 +++++ middleware.go | 17 +++++++++++++---- request_test.go | 9 +++++++++ 6 files changed, 55 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index 14a94d06..3022ce93 100644 --- a/client.go +++ b/client.go @@ -16,6 +16,7 @@ import ( "net/http/cookiejar" urlpkg "net/url" "os" + "reflect" "strings" "time" ) @@ -36,23 +37,23 @@ var defaultClient *Client = C() // Client is the req's http client. type Client struct { - BaseURL string - PathParams map[string]string - QueryParams urlpkg.Values - Headers http.Header - Cookies []*http.Cookie - FormData urlpkg.Values - DebugLog bool - AllowGetMethodPayload bool - + BaseURL string + PathParams map[string]string + QueryParams urlpkg.Values + Headers http.Header + Cookies []*http.Cookie + FormData urlpkg.Values + DebugLog bool + AllowGetMethodPayload bool + trace bool + disableAutoReadResponse bool + commonErrorType reflect.Type retryOption *retryOption jsonMarshal func(v interface{}) ([]byte, error) jsonUnmarshal func(data []byte, v interface{}) error xmlMarshal func(v interface{}) ([]byte, error) xmlUnmarshal func(data []byte, v interface{}) error - trace bool outputDirectory string - disableAutoReadResponse bool scheme string log Logger t *Transport @@ -72,6 +73,15 @@ func (c *Client) R() *Request { } } +// SetCommonError set the common result that response body will be unmarshalled to +// if it is an error response ( status `code >= 400`). +func (c *Client) SetCommonError(err interface{}) *Client { + if err != nil { + c.commonErrorType = util.GetType(err) + } + return c +} + // SetCommonFormDataFromValues set the form data from url.Values for all requests // which request method allows payload. func (c *Client) SetCommonFormDataFromValues(data urlpkg.Values) *Client { diff --git a/client_wrapper.go b/client_wrapper.go index 404abf02..56c5e0bf 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -10,6 +10,12 @@ import ( "time" ) +// SetCommonError is a global wrapper methods which delegated +// to the default client's SetCommonError. +func SetCommonError(err interface{}) *Client { + return defaultClient.SetCommonError(err) +} + // SetCommonFormDataFromValues is a global wrapper methods which delegated // to the default client's SetCommonFormDataFromValues. func SetCommonFormDataFromValues(data url.Values) *Client { diff --git a/client_wrapper_test.go b/client_wrapper_test.go index e8bf2edb..8b830327 100644 --- a/client_wrapper_test.go +++ b/client_wrapper_test.go @@ -23,6 +23,7 @@ func TestGlobalWrapper(t *testing.T) { form.Add("test", "test") assertAllNotNil(t, + SetCommonError(nil), SetCookieJar(nil), SetDialTLS(nil), SetDial(nil), diff --git a/internal/util/util.go b/internal/util/util.go index cd816d13..a110a600 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -27,6 +27,11 @@ func GetPointer(v interface{}) interface{} { return reflect.New(vv.Type()).Interface() } +// GetType return the underlying type. +func GetType(v interface{}) reflect.Type { + return reflect.Indirect(reflect.ValueOf(v)).Type() +} + // CutString slices s around the first instance of sep, // returning the text before and after sep. // The found result reports whether sep appears in s. diff --git a/middleware.go b/middleware.go index 964e0bc8..e4a7b939 100644 --- a/middleware.go +++ b/middleware.go @@ -11,6 +11,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "strings" "time" ) @@ -280,10 +281,18 @@ func parseResponseBody(c *Client, r *Response) (err error) { r.result = r.Request.Result } } - if r.Request.Error != nil && r.IsError() { - err = unmarshalBody(c, r, r.Request.Error) - if err == nil { - r.error = r.Request.Error + if r.IsError() { + if r.Request.Error != nil { + err = unmarshalBody(c, r, r.Request.Error) + if err == nil { + r.error = r.Request.Error + } + } else if c.commonErrorType != nil { + e := reflect.New(c.commonErrorType).Interface() + err = unmarshalBody(c, r, e) + if err == nil { + r.error = e + } } } return diff --git a/request_test.go b/request_test.go index 5dd76597..beac1381 100644 --- a/request_test.go +++ b/request_test.go @@ -687,6 +687,15 @@ func testError(t *testing.T, c *Client) { Get("/search") assertIsError(t, resp, err) assertEqual(t, 10001, errMsg.ErrorCode) + + c.SetCommonError(&errMsg) + resp, err = c.R(). + SetQueryParam("username", ""). + Get("/search") + assertIsError(t, resp, err) + em, ok := resp.Error().(*ErrorMessage) + assertEqual(t, true, ok) + assertEqual(t, 10000, em.ErrorCode) } func TestForm(t *testing.T) { From 6780a3ce202515dce14f011433fd50fb681091e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=8F=E5=A4=A9?= Date: Tue, 7 Jun 2022 11:51:10 +0800 Subject: [PATCH 492/843] =?UTF-8?q?=E5=9C=A8=E4=B8=8D=E4=BD=BF=E7=94=A8dum?= =?UTF-8?q?p=E7=9A=84=E6=83=85=E5=86=B5=E4=B8=8B=EF=BC=8C=E9=80=9A?= =?UTF-8?q?=E8=BF=87header=E7=9A=84=E6=96=B9=E6=B3=95=E8=8E=B7=E5=8F=96?= =?UTF-8?q?=E8=BF=94=E5=9B=9Eheader=E7=9A=84=E5=80=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- response.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/response.go b/response.go index e0d27d44..b68f4ba9 100644 --- a/response.go +++ b/response.go @@ -142,3 +142,18 @@ func (r *Response) ToBytes() ([]byte, error) { func (r *Response) Dump() string { return r.Request.getDumpBuffer().String() } + +// GetStatusCode return the response status code. +func (r *Response) GetStatusCode() int { + return r.StatusCode +} + +// GetHeaderValue returns the response header value by key. +func (r *Response) GetHeaderValue(key string) string { + return r.Header.Get(key) +} + +// GetHeaderValues returns the response header values by key. +func (r *Response) GetHeaderValues(key string) []string { + return r.Header.Values(key) +} From 26944a376de572bd6ab3236c46c7db4433b43961 Mon Sep 17 00:00:00 2001 From: MJrocker <1725014728@qq.com> Date: Tue, 7 Jun 2022 18:00:54 +0800 Subject: [PATCH 493/843] Fix a syntax error in a map key --- trace.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trace.go b/trace.go index c2324b42..a182a4a7 100644 --- a/trace.go +++ b/trace.go @@ -36,8 +36,8 @@ func (t TraceInfo) Blame() string { "on dns lookup": t.DNSLookupTime, "on tcp connect": t.TCPConnectTime, "on tls handshake": t.TLSHandshakeTime, - "from connection ready to server respond frist byte": t.FirstResponseTime, - "from server respond frist byte to request completion": t.ResponseTime, + "from connection ready to server respond first byte": t.FirstResponseTime, + "from server respond first byte to request completion": t.ResponseTime, } for k, v := range m { if v > mv { From c65ababf705bec814f72674fd777c2cc679255d0 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 7 Jun 2022 20:26:28 +0800 Subject: [PATCH 494/843] fmt --- trace.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trace.go b/trace.go index a182a4a7..1d9926ca 100644 --- a/trace.go +++ b/trace.go @@ -36,8 +36,8 @@ func (t TraceInfo) Blame() string { "on dns lookup": t.DNSLookupTime, "on tcp connect": t.TCPConnectTime, "on tls handshake": t.TLSHandshakeTime, - "from connection ready to server respond first byte": t.FirstResponseTime, - "from server respond first byte to request completion": t.ResponseTime, + "from connection ready to server respond first byte": t.FirstResponseTime, + "from server respond first byte to request completion": t.ResponseTime, } for k, v := range m { if v > mv { From 4dde48e5499c5853f7683477fbe3b2eda7904c76 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 7 Jun 2022 20:49:28 +0800 Subject: [PATCH 495/843] wrap response --- response.go | 23 ++++++++++++++++++++--- response_test.go | 10 ++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/response.go b/response.go index 155efede..da9654ff 100644 --- a/response.go +++ b/response.go @@ -154,17 +154,34 @@ func (r *Response) Dump() string { return r.Request.getDumpBuffer().String() } -// GetStatusCode return the response status code. +// GetStatus returns the response status. +func (r *Response) GetStatus() string { + if r.Response == nil { + return "" + } + return r.Status +} + +// GetStatusCode returns the response status code. func (r *Response) GetStatusCode() int { + if r.Response == nil { + return 0 + } return r.StatusCode } -// GetHeaderValue returns the response header value by key. -func (r *Response) GetHeaderValue(key string) string { +// GetHeader returns the response header value by key. +func (r *Response) GetHeader(key string) string { + if r.Response == nil { + return "" + } return r.Header.Get(key) } // GetHeaderValues returns the response header values by key. func (r *Response) GetHeaderValues(key string) []string { + if r.Response == nil { + return nil + } return r.Header.Values(key) } diff --git a/response_test.go b/response_test.go index af7aa979..1baaa93e 100644 --- a/response_test.go +++ b/response_test.go @@ -1,6 +1,7 @@ package req import ( + "net/http" "testing" ) @@ -59,3 +60,12 @@ func TestResponseError(t *testing.T) { } assertEqual(t, "not allowed", msg.Message) } + +func TestResponseWrap(t *testing.T) { + resp, err := tc().R().Get("/json") + assertSuccess(t, resp, err) + assertEqual(t, true, resp.GetStatusCode() == http.StatusOK) + assertEqual(t, true, resp.GetStatus() == "200 OK") + assertEqual(t, true, resp.GetHeader(hdrContentTypeKey) == jsonContentType) + assertEqual(t, true, len(resp.GetHeaderValues(hdrContentTypeKey)) == 1) +} From 97a292b68b3988dd89efd24c6bdb41a8d92b00d1 Mon Sep 17 00:00:00 2001 From: MJrocker <1725014728@qq.com> Date: Tue, 7 Jun 2022 22:20:58 +0800 Subject: [PATCH 496/843] Fixed some typos in the code comment --- h2_transport.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/h2_transport.go b/h2_transport.go index 8c55563a..9b09ce4b 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -863,10 +863,10 @@ func (cc *http2ClientConn) canTakeNewRequestLocked() bool { return st.canTakeNewRequest } -// tooIdleLocked reports whether this connection has been been sitting idle +// tooIdleLocked reports whether this connection has been sitting idle // for too much wall time. func (cc *http2ClientConn) tooIdleLocked() bool { - // The Round(0) strips the monontonic clock reading so the + // The Round(0) strips the monotonic clock reading so the // times are compared based on their wall time. We don't want // to reuse a connection that's been sitting idle during // VM/laptop suspend if monotonic time was also frozen. From 3912e2a80b1ce3e1176bc8b7907d917ea0a0bc50 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 8 Jun 2022 10:10:34 +0800 Subject: [PATCH 497/843] remove extra blank space --- h2_transport.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/h2_transport.go b/h2_transport.go index 9b09ce4b..9fd48fb5 100644 --- a/h2_transport.go +++ b/h2_transport.go @@ -863,7 +863,7 @@ func (cc *http2ClientConn) canTakeNewRequestLocked() bool { return st.canTakeNewRequest } -// tooIdleLocked reports whether this connection has been sitting idle +// tooIdleLocked reports whether this connection has been sitting idle // for too much wall time. func (cc *http2ClientConn) tooIdleLocked() bool { // The Round(0) strips the monotonic clock reading so the From 42182b74746005cf8692f93f8da6bfdcbeb38e31 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 8 Jun 2022 11:08:37 +0800 Subject: [PATCH 498/843] Update README.md --- README.md | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 856ea8cd..09a502a8 100644 --- a/README.md +++ b/README.md @@ -61,8 +61,28 @@ resp, err := client.R(). // Use R() to create a request and set with chainable r SetHeader("Accept", "application/vnd.github.v3+json"). SetPathParam("username", "imroc"). SetQueryParam("page", "1"). - SetResult(&result). // Unmarshal response into struct automatically. + SetResult(&result). // Unmarshal response into struct automatically if status code >= 200 and <= 299. + SetError(&errMsg). // Unmarshal response into struct automatically if status code >= 400. + EnableDump(). // Enable dump at request level to help troubleshoot, log content only when an unexpected exception occurs. Get("https://api.github.com/users/{username}/repos") +if err != nil { + // Handle error. + // ... + return +} +if resp.IsSuccess() { + // Handle result. + // ... + return +} +if resp.IsError() { + // Handle errMsg. + // ... + return +} +// Handle unexpected response (corner case). +err = fmt.Errorf("got unexpected response, raw dump:\n%s", resp.Dump()) +return ``` **Videos** @@ -72,7 +92,7 @@ resp, err := client.R(). // Use R() to create a request and set with chainable r **More** -Check more introduction, tutorials, examples and API references on the [official website](https://req.cool/). +Check more introduction, tutorials, examples, best practices and API references on the [official website](https://req.cool/). ## License From b81cf85e6edc8dcf8503bd7f3591ffa539017b37 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 8 Jun 2022 11:20:40 +0800 Subject: [PATCH 499/843] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 09a502a8..31c17e4b 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ if resp.IsError() { } // Handle unexpected response (corner case). err = fmt.Errorf("got unexpected response, raw dump:\n%s", resp.Dump()) -return +// ... ``` **Videos** From b19f12d07a1a19c5ca3ebd7afc2050f1b561a66b Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 8 Jun 2022 20:06:38 +0800 Subject: [PATCH 500/843] Run user-defined request middleware after internal middleware Make it possible to read generated info in request middleware, e.g. record req.URL.Path in request middleware. --- client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 3022ce93..492772f4 100644 --- a/client.go +++ b/client.go @@ -966,12 +966,12 @@ func (c *Client) do(r *Request) (resp *Response, err error) { } for { - for _, f := range r.client.udBeforeRequest { + for _, f := range r.client.beforeRequest { if err = f(r.client, r); err != nil { return } } - for _, f := range r.client.beforeRequest { + for _, f := range r.client.udBeforeRequest { if err = f(r.client, r); err != nil { return } From 895cbd0b4cb27e60aac630fe2e52c742038a00bf Mon Sep 17 00:00:00 2001 From: Liang Ding Date: Wed, 29 Jun 2022 00:08:44 +0800 Subject: [PATCH 501/843] :bug: Fix non-pointer panic --- middleware.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/middleware.go b/middleware.go index e4a7b939..c68aa155 100644 --- a/middleware.go +++ b/middleware.go @@ -2,7 +2,6 @@ package req import ( "bytes" - "github.com/imroc/req/v3/internal/util" "io" "io/ioutil" "mime/multipart" @@ -14,6 +13,8 @@ import ( "reflect" "strings" "time" + + "github.com/imroc/req/v3/internal/util" ) type ( @@ -272,7 +273,7 @@ func unmarshalBody(c *Client, r *Response, v interface{}) (err error) { } func parseResponseBody(c *Client, r *Response) (err error) { - if r.StatusCode == http.StatusNoContent { + if nil == r.Response || r.StatusCode == http.StatusNoContent { return } if r.Request.Result != nil && r.IsSuccess() { From 4994a5989d1567cbb8c13fbd59168029081a0781 Mon Sep 17 00:00:00 2001 From: Liang Ding Date: Wed, 29 Jun 2022 00:30:24 +0800 Subject: [PATCH 502/843] :bug: Fix non-pointer panic --- client.go | 9 +++++++-- retry_test.go | 12 ++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 492772f4..920df112 100644 --- a/client.go +++ b/client.go @@ -7,8 +7,6 @@ import ( "encoding/json" "encoding/xml" "errors" - "github.com/imroc/req/v3/internal/util" - "golang.org/x/net/publicsuffix" "io" "io/ioutil" "net" @@ -19,6 +17,9 @@ import ( "reflect" "strings" "time" + + "github.com/imroc/req/v3/internal/util" + "golang.org/x/net/publicsuffix" ) // DefaultClient returns the global default Client. @@ -1098,6 +1099,10 @@ func (c *Client) do(r *Request) (resp *Response, err error) { resp.error = nil } + if nil != err { + return + } + for _, f := range r.client.afterResponse { if err = f(r.client, resp); err != nil { return diff --git a/retry_test.go b/retry_test.go index cfd86933..7513ee1a 100644 --- a/retry_test.go +++ b/retry_test.go @@ -158,3 +158,15 @@ func TestRetryWithModify(t *testing.T) { assertSuccess(t, resp, err) assertEqual(t, 2, resp.Request.RetryAttempt) } + +func TestRetryFalse(t *testing.T) { + + resp, err := tc().R(). + SetRetryCount(1). + SetRetryCondition(func(resp *Response, err error) bool { + return false + }).Get("https://non-exists-host.com.cn") + assertNotNil(t, err) + assertIsNil(t, resp.Response) + assertEqual(t, 0, resp.Request.RetryAttempt) +} From 501e793bfbc95789fbb1381a295daffb84a9fb18 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 29 Jun 2022 09:56:22 +0800 Subject: [PATCH 503/843] ajust code style --- client.go | 2 +- middleware.go | 2 +- retry_test.go | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 920df112..c05d4ab8 100644 --- a/client.go +++ b/client.go @@ -1099,7 +1099,7 @@ func (c *Client) do(r *Request) (resp *Response, err error) { resp.error = nil } - if nil != err { + if err != nil { return } diff --git a/middleware.go b/middleware.go index c68aa155..d83dcec2 100644 --- a/middleware.go +++ b/middleware.go @@ -273,7 +273,7 @@ func unmarshalBody(c *Client, r *Response, v interface{}) (err error) { } func parseResponseBody(c *Client, r *Response) (err error) { - if nil == r.Response || r.StatusCode == http.StatusNoContent { + if r.Response == nil || r.StatusCode == http.StatusNoContent { return } if r.Request.Result != nil && r.IsSuccess() { diff --git a/retry_test.go b/retry_test.go index 7513ee1a..137423ec 100644 --- a/retry_test.go +++ b/retry_test.go @@ -160,7 +160,6 @@ func TestRetryWithModify(t *testing.T) { } func TestRetryFalse(t *testing.T) { - resp, err := tc().R(). SetRetryCount(1). SetRetryCondition(func(resp *Response, err error) bool { From e1a787272471136c9b8ca5cc42a1caedae72f75a Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 17 Jun 2022 17:17:24 +0800 Subject: [PATCH 504/843] big internal refactor 1. Move http2 logic into internal http2 package. 2. Extract test functions into internal tests package. 3. Extract dump logic into internal dump package. --- client.go | 16 +- client_test.go | 184 +- client_wrapper_test.go | 14 +- decode_test.go | 9 +- dump.go | 117 +- h2_errors.go | 138 - http_request.go | 7 - internal/common/error.go | 7 + internal/dump/dump.go | 137 + internal/header/header.go | 3 + .../http2/client_conn_pool.go | 106 +- .../http2/databuffer.go | 36 +- .../http2/databuffer_test.go | 16 +- internal/http2/errors.go | 138 + .../http2/errors_test.go | 6 +- h2_flow.go => internal/http2/flow.go | 16 +- .../http2/flow_test.go | 10 +- h2_frame.go => internal/http2/frame.go | 761 ++--- .../http2/frame_test.go | 579 ++-- h2_go115.go => internal/http2/go115.go | 7 +- h2_gotrack.go => internal/http2/gotrack.go | 44 +- .../http2/gotrack_test.go | 31 +- .../http2/headermap.go | 28 +- h2.go => internal/http2/http2.go | 126 +- h2_test.go => internal/http2/http2_test.go | 47 +- .../http2/not_go115.go | 4 +- h2_pipe.go => internal/http2/pipe.go | 42 +- .../http2/pipe_test.go | 16 +- .../http2/server_test.go | 2566 ++++++++--------- h2_trace.go => internal/http2/trace.go | 22 +- .../http2/transport.go | 755 +++-- .../http2/transport_go117_test.go | 47 +- .../http2/transport_test.go | 830 ++---- internal/socks/socks_test.go | 27 +- internal/tests/assert.go | 111 + internal/tests/condition.go | 17 + internal/tests/net.go | 17 + internal/tests/reader.go | 10 + internal/tests/transport.go | 139 + internal/tls/conn.go | 35 + internal/transport/transport.go | 100 + logger_test.go | 5 +- req_test.go | 99 +- request.go | 3 +- request_test.go | 176 +- request_wrapper_test.go | 9 +- response_test.go | 29 +- retry_test.go | 41 +- textproto_reader.go | 9 +- transfer.go | 11 +- transport.go | 46 +- transport_internal_test.go | 26 +- transport_test.go | 75 +- transport_wrapper.go | 121 + 54 files changed, 4115 insertions(+), 3856 deletions(-) delete mode 100644 h2_errors.go create mode 100644 internal/common/error.go create mode 100644 internal/dump/dump.go create mode 100644 internal/header/header.go rename h2_client_conn_pool.go => internal/http2/client_conn_pool.go (68%) rename h2_databuffer.go => internal/http2/databuffer.go (80%) rename h2_databuffer_test.go => internal/http2/databuffer_test.go (91%) create mode 100644 internal/http2/errors.go rename h2_errors_test.go => internal/http2/errors_test.go (85%) rename h2_flow.go => internal/http2/flow.go (77%) rename h2_flow_test.go => internal/http2/flow_test.go (95%) rename h2_frame.go => internal/http2/frame.go (63%) rename h2_frame_test.go => internal/http2/frame_test.go (69%) rename h2_go115.go => internal/http2/go115.go (63%) rename h2_gotrack.go => internal/http2/gotrack.go (73%) rename h2_gotrack_test.go => internal/http2/gotrack_test.go (58%) rename h2_headermap.go => internal/http2/headermap.go (64%) rename h2.go => internal/http2/http2.go (58%) rename h2_test.go => internal/http2/http2_test.go (63%) rename h2_not_go115.go => internal/http2/not_go115.go (79%) rename h2_pipe.go => internal/http2/pipe.go (73%) rename h2_pipe_test.go => internal/http2/pipe_test.go (94%) rename h2_server_test.go => internal/http2/server_test.go (61%) rename h2_trace.go => internal/http2/trace.go (65%) rename h2_transport.go => internal/http2/transport.go (76%) rename h2_transport_go117_test.go => internal/http2/transport_go117_test.go (81%) rename h2_transport_test.go => internal/http2/transport_test.go (87%) create mode 100644 internal/tests/assert.go create mode 100644 internal/tests/condition.go create mode 100644 internal/tests/net.go create mode 100644 internal/tests/reader.go create mode 100644 internal/tests/transport.go create mode 100644 internal/tls/conn.go create mode 100644 internal/transport/transport.go create mode 100644 transport_wrapper.go diff --git a/client.go b/client.go index c05d4ab8..f791f538 100644 --- a/client.go +++ b/client.go @@ -18,6 +18,7 @@ import ( "strings" "time" + "github.com/imroc/req/v3/internal/http2" "github.com/imroc/req/v3/internal/util" "golang.org/x/net/publicsuffix" ) @@ -34,7 +35,7 @@ func SetDefaultClient(c *Client) { } } -var defaultClient *Client = C() +var defaultClient = C() // Client is the req's http client. type Client struct { @@ -58,7 +59,6 @@ type Client struct { scheme string log Logger t *Transport - t2 *http2Transport dumpOptions *DumpOptions httpClient *http.Client beforeRequest []RequestMiddleware @@ -652,7 +652,7 @@ func (c *Client) SetCommonDumpOptions(opt *DumpOptions) *Client { } c.dumpOptions = opt if c.t.dump != nil { - c.t.dump.DumpOptions = opt + c.t.dump.SetOptions(dumpOptions{opt}) } return c } @@ -882,14 +882,15 @@ func NewClient() *Client { // Clone copy and returns the Client func (c *Client) Clone() *Client { t := c.t.Clone() - t2, _ := http2ConfigureTransports(t) + t2, _ := http2.ConfigureTransports(transportImpl{t}) + t.t2 = t2 + client := *c.httpClient client.Transport = t cc := *c cc.httpClient = &client cc.t = t - cc.t2 = t2 cc.Headers = cloneHeaders(c.Headers) cc.Cookies = cloneCookies(c.Cookies) @@ -922,7 +923,9 @@ func C() *Client { TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } - t2, _ := http2ConfigureTransports(t) + t2, _ := http2.ConfigureTransports(transportImpl{t}) + t.t2 = t2 + jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) httpClient := &http.Client{ Transport: t, @@ -945,7 +948,6 @@ func C() *Client { log: createDefaultLogger(), httpClient: httpClient, t: t, - t2: t2, jsonMarshal: json.Marshal, jsonUnmarshal: json.Unmarshal, xmlMarshal: xml.Marshal, diff --git a/client_test.go b/client_test.go index e4bce5b8..f1ab64de 100644 --- a/client_test.go +++ b/client_test.go @@ -20,23 +20,23 @@ func TestAllowGetMethodPayload(t *testing.T) { c := tc() resp, err := c.R().SetBody("test").Get("/payload") assertSuccess(t, resp, err) - assertEqual(t, "", resp.String()) + tests.AssertEqual(t, "", resp.String()) c.EnableAllowGetMethodPayload() resp, err = c.R().SetBody("test").Get("/payload") assertSuccess(t, resp, err) - assertEqual(t, "test", resp.String()) + tests.AssertEqual(t, "test", resp.String()) c.DisableAllowGetMethodPayload() resp, err = c.R().SetBody("test").Get("/payload") assertSuccess(t, resp, err) - assertEqual(t, "", resp.String()) + tests.AssertEqual(t, "", resp.String()) } func TestSetTLSHandshakeTimeout(t *testing.T) { timeout := 2 * time.Second c := tc().SetTLSHandshakeTimeout(timeout) - assertEqual(t, timeout, c.t.TLSHandshakeTimeout) + tests.AssertEqual(t, timeout, c.t.TLSHandshakeTimeout) } func TestSetDial(t *testing.T) { @@ -46,7 +46,7 @@ func TestSetDial(t *testing.T) { } c := tc().SetDial(testDial) _, err := c.t.DialContext(nil, "", "") - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) } func TestSetDialTLS(t *testing.T) { @@ -56,7 +56,7 @@ func TestSetDialTLS(t *testing.T) { } c := tc().SetDialTLS(testDialTLS) _, err := c.t.DialTLSContext(nil, "", "") - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) } func TestSetFuncs(t *testing.T) { @@ -74,31 +74,31 @@ func TestSetFuncs(t *testing.T) { SetXmlUnmarshal(unmarshalFunc) _, err := c.jsonMarshal(nil) - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) err = c.jsonUnmarshal(nil, nil) - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) _, err = c.xmlMarshal(nil) - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) err = c.xmlUnmarshal(nil, nil) - assertEqual(t, testErr, err) + tests.AssertEqual(t, testErr, err) } func TestSetCookieJar(t *testing.T) { c := tc().SetCookieJar(nil) - assertEqual(t, nil, c.httpClient.Jar) + tests.AssertEqual(t, nil, c.httpClient.Jar) } func TestTraceAll(t *testing.T) { c := tc().EnableTraceAll() resp, err := c.R().Get("/") assertSuccess(t, resp, err) - assertEqual(t, true, resp.TraceInfo().TotalTime > 0) + tests.AssertEqual(t, true, resp.TraceInfo().TotalTime > 0) c.DisableTraceAll() resp, err = c.R().Get("/") assertSuccess(t, resp, err) - assertEqual(t, true, resp.TraceInfo().TotalTime == 0) + tests.AssertEqual(t, true, resp.TraceInfo().TotalTime == 0) } func TestOnAfterResponse(t *testing.T) { @@ -108,21 +108,21 @@ func TestOnAfterResponse(t *testing.T) { return nil }) len2 := len(c.afterResponse) - assertEqual(t, true, len1+1 == len2) + tests.AssertEqual(t, true, len1+1 == len2) } func TestOnBeforeRequest(t *testing.T) { c := tc().OnBeforeRequest(func(client *Client, request *Request) error { return nil }) - assertEqual(t, true, len(c.udBeforeRequest) == 1) + tests.AssertEqual(t, true, len(c.udBeforeRequest) == 1) } func TestSetProxyURL(t *testing.T) { c := tc().SetProxyURL("http://dummy.proxy.local") u, err := c.t.Proxy(nil) - assertNoError(t, err) - assertEqual(t, "http://dummy.proxy.local", u.String()) + tests.AssertNoError(t, err) + tests.AssertEqual(t, "http://dummy.proxy.local", u.String()) } func TestSetProxy(t *testing.T) { @@ -130,23 +130,23 @@ func TestSetProxy(t *testing.T) { proxy := http.ProxyURL(u) c := tc().SetProxy(proxy) uu, err := c.t.Proxy(nil) - assertNoError(t, err) - assertEqual(t, u.String(), uu.String()) + tests.AssertNoError(t, err) + tests.AssertEqual(t, u.String(), uu.String()) } func TestSetCommonContentType(t *testing.T) { c := tc().SetCommonContentType(jsonContentType) - assertEqual(t, jsonContentType, c.Headers.Get(hdrContentTypeKey)) + tests.AssertEqual(t, jsonContentType, c.Headers.Get(hdrContentTypeKey)) } func TestSetCommonHeader(t *testing.T) { c := tc().SetCommonHeader("my-header", "my-value") - assertEqual(t, "my-value", c.Headers.Get("my-header")) + tests.AssertEqual(t, "my-value", c.Headers.Get("my-header")) } func TestSetCommonHeaderNonCanonical(t *testing.T) { c := tc().SetCommonHeaderNonCanonical("my-Header", "my-value") - assertEqual(t, "my-value", c.Headers["my-Header"][0]) + tests.AssertEqual(t, "my-value", c.Headers["my-Header"][0]) } func TestSetCommonHeaders(t *testing.T) { @@ -154,48 +154,48 @@ func TestSetCommonHeaders(t *testing.T) { "header1": "value1", "header2": "value2", }) - assertEqual(t, "value1", c.Headers.Get("header1")) - assertEqual(t, "value2", c.Headers.Get("header2")) + tests.AssertEqual(t, "value1", c.Headers.Get("header1")) + tests.AssertEqual(t, "value2", c.Headers.Get("header2")) } func TestSetCommonHeadersNonCanonical(t *testing.T) { c := tc().SetCommonHeadersNonCanonical(map[string]string{ "my-Header": "my-value", }) - assertEqual(t, "my-value", c.Headers["my-Header"][0]) + tests.AssertEqual(t, "my-value", c.Headers["my-Header"][0]) } func TestSetCommonBasicAuth(t *testing.T) { c := tc().SetCommonBasicAuth("imroc", "123456") - assertEqual(t, "Basic aW1yb2M6MTIzNDU2", c.Headers.Get("Authorization")) + tests.AssertEqual(t, "Basic aW1yb2M6MTIzNDU2", c.Headers.Get("Authorization")) } func TestSetCommonBearerAuthToken(t *testing.T) { c := tc().SetCommonBearerAuthToken("123456") - assertEqual(t, "Bearer 123456", c.Headers.Get("Authorization")) + tests.AssertEqual(t, "Bearer 123456", c.Headers.Get("Authorization")) } func TestSetUserAgent(t *testing.T) { c := tc().SetUserAgent("test") - assertEqual(t, "test", c.Headers.Get(hdrUserAgentKey)) + tests.AssertEqual(t, "test", c.Headers.Get(hdrUserAgentKey)) } func TestAutoDecode(t *testing.T) { c := tc().DisableAutoDecode() resp, err := c.R().Get("/gbk") assertSuccess(t, resp, err) - assertEqual(t, toGbk("我是roc"), resp.Bytes()) + tests.AssertEqual(t, toGbk("我是roc"), resp.Bytes()) resp, err = c.EnableAutoDecode().R().Get("/gbk") assertSuccess(t, resp, err) - assertEqual(t, "我是roc", resp.String()) + tests.AssertEqual(t, "我是roc", resp.String()) resp, err = c.SetAutoDecodeContentType("html").R().Get("/gbk") assertSuccess(t, resp, err) - assertEqual(t, toGbk("我是roc"), resp.Bytes()) + tests.AssertEqual(t, toGbk("我是roc"), resp.Bytes()) resp, err = c.SetAutoDecodeContentType("text").R().Get("/gbk") assertSuccess(t, resp, err) - assertEqual(t, "我是roc", resp.String()) + tests.AssertEqual(t, "我是roc", resp.String()) resp, err = c.SetAutoDecodeContentTypeFunc(func(contentType string) bool { if strings.Contains(contentType, "text") { return true @@ -203,39 +203,39 @@ func TestAutoDecode(t *testing.T) { return false }).R().Get("/gbk") assertSuccess(t, resp, err) - assertEqual(t, "我是roc", resp.String()) + tests.AssertEqual(t, "我是roc", resp.String()) resp, err = c.SetAutoDecodeAllContentType().R().Get("/gbk-no-charset") assertSuccess(t, resp, err) - assertContains(t, resp.String(), "我是roc", true) + tests.AssertContains(t, resp.String(), "我是roc", true) } func TestSetTimeout(t *testing.T) { timeout := 100 * time.Second c := tc().SetTimeout(timeout) - assertEqual(t, timeout, c.httpClient.Timeout) + tests.AssertEqual(t, timeout, c.httpClient.Timeout) } func TestSetLogger(t *testing.T) { l := createDefaultLogger() c := tc().SetLogger(l) - assertEqual(t, l, c.log) + tests.AssertEqual(t, l, c.log) c.SetLogger(nil) - assertEqual(t, &disableLogger{}, c.log) + tests.AssertEqual(t, &disableLogger{}, c.log) } func TestSetScheme(t *testing.T) { c := tc().SetScheme("https") - assertEqual(t, "https", c.scheme) + tests.AssertEqual(t, "https", c.scheme) } func TestDebugLog(t *testing.T) { c := tc().EnableDebugLog() - assertEqual(t, true, c.DebugLog) + tests.AssertEqual(t, true, c.DebugLog) c.DisableDebugLog() - assertEqual(t, false, c.DebugLog) + tests.AssertEqual(t, false, c.DebugLog) } func TestSetCommonCookies(t *testing.T) { @@ -245,25 +245,25 @@ func TestSetCommonCookies(t *testing.T) { Value: "test", }).R().SetResult(&headers).Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "test=test", headers.Get("Cookie")) + tests.AssertEqual(t, "test=test", headers.Get("Cookie")) } func TestSetCommonQueryString(t *testing.T) { resp, err := tc().SetCommonQueryString("test=test").R().Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "test=test", resp.String()) + tests.AssertEqual(t, "test=test", resp.String()) } func TestSetCommonPathParams(t *testing.T) { c := tc().SetCommonPathParams(map[string]string{"test": "test"}) - assertNotNil(t, c.PathParams) - assertEqual(t, "test", c.PathParams["test"]) + tests.AssertNotNil(t, c.PathParams) + tests.AssertEqual(t, "test", c.PathParams["test"]) } func TestSetCommonPathParam(t *testing.T) { c := tc().SetCommonPathParam("test", "test") - assertNotNil(t, c.PathParams) - assertEqual(t, "test", c.PathParams["test"]) + tests.AssertNotNil(t, c.PathParams) + tests.AssertEqual(t, "test", c.PathParams["test"]) } func TestAddCommonQueryParam(t *testing.T) { @@ -272,75 +272,75 @@ func TestAddCommonQueryParam(t *testing.T) { AddCommonQueryParam("test", "2"). R().Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "test=1&test=2", resp.String()) + tests.AssertEqual(t, "test=1&test=2", resp.String()) } func TestSetCommonQueryParam(t *testing.T) { resp, err := tc().SetCommonQueryParam("test", "test").R().Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "test=test", resp.String()) + tests.AssertEqual(t, "test=test", resp.String()) } func TestSetCommonQueryParams(t *testing.T) { resp, err := tc().SetCommonQueryParams(map[string]string{"test": "test"}).R().Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "test=test", resp.String()) + tests.AssertEqual(t, "test=test", resp.String()) } func TestInsecureSkipVerify(t *testing.T) { c := tc().EnableInsecureSkipVerify() - assertEqual(t, true, c.t.TLSClientConfig.InsecureSkipVerify) + tests.AssertEqual(t, true, c.t.TLSClientConfig.InsecureSkipVerify) c.DisableInsecureSkipVerify() - assertEqual(t, false, c.t.TLSClientConfig.InsecureSkipVerify) + tests.AssertEqual(t, false, c.t.TLSClientConfig.InsecureSkipVerify) } func TestSetTLSClientConfig(t *testing.T) { config := &tls.Config{InsecureSkipVerify: true} c := tc().SetTLSClientConfig(config) - assertEqual(t, config, c.t.TLSClientConfig) + tests.AssertEqual(t, config, c.t.TLSClientConfig) } func TestCompression(t *testing.T) { c := tc().DisableCompression() - assertEqual(t, true, c.t.DisableCompression) + tests.AssertEqual(t, true, c.t.DisableCompression) c.EnableCompression() - assertEqual(t, false, c.t.DisableCompression) + tests.AssertEqual(t, false, c.t.DisableCompression) } func TestKeepAlives(t *testing.T) { c := tc().DisableKeepAlives() - assertEqual(t, true, c.t.DisableKeepAlives) + tests.AssertEqual(t, true, c.t.DisableKeepAlives) c.EnableKeepAlives() - assertEqual(t, false, c.t.DisableKeepAlives) + tests.AssertEqual(t, false, c.t.DisableKeepAlives) } func TestRedirect(t *testing.T) { _, err := tc().SetRedirectPolicy(NoRedirectPolicy()).R().Get("/unlimited-redirect") - assertNotNil(t, err) - assertContains(t, err.Error(), "redirect is disabled", true) + tests.AssertNotNil(t, err) + tests.AssertContains(t, err.Error(), "redirect is disabled", true) _, err = tc().SetRedirectPolicy(MaxRedirectPolicy(3)).R().Get("/unlimited-redirect") - assertNotNil(t, err) - assertContains(t, err.Error(), "stopped after 3 redirects", true) + tests.AssertNotNil(t, err) + tests.AssertContains(t, err.Error(), "stopped after 3 redirects", true) _, err = tc().SetRedirectPolicy(SameDomainRedirectPolicy()).R().Get("/redirect-to-other") - assertNotNil(t, err) - assertContains(t, err.Error(), "different domain name is not allowed", true) + tests.AssertNotNil(t, err) + tests.AssertContains(t, err.Error(), "different domain name is not allowed", true) _, err = tc().SetRedirectPolicy(SameHostRedirectPolicy()).R().Get("/redirect-to-other") - assertNotNil(t, err) - assertContains(t, err.Error(), "different host name is not allowed", true) + tests.AssertNotNil(t, err) + tests.AssertContains(t, err.Error(), "different host name is not allowed", true) _, err = tc().SetRedirectPolicy(AllowedHostRedirectPolicy("localhost", "127.0.0.1")).R().Get("/redirect-to-other") - assertNotNil(t, err) - assertContains(t, err.Error(), "redirect host [dummy.local] is not allowed", true) + tests.AssertNotNil(t, err) + tests.AssertContains(t, err.Error(), "redirect host [dummy.local] is not allowed", true) _, err = tc().SetRedirectPolicy(AllowedDomainRedirectPolicy("localhost", "127.0.0.1")).R().Get("/redirect-to-other") - assertNotNil(t, err) - assertContains(t, err.Error(), "redirect domain [dummy.local] is not allowed", true) + tests.AssertNotNil(t, err) + tests.AssertContains(t, err.Error(), "redirect domain [dummy.local] is not allowed", true) c := tc().SetRedirectPolicy(AlwaysCopyHeaderRedirectPolicy("Authorization")) newHeader := make(http.Header) @@ -351,29 +351,29 @@ func TestRedirect(t *testing.T) { }, []*http.Request{&http.Request{ Header: oldHeader, }}) - assertEqual(t, "test", newHeader.Get("Authorization")) + tests.AssertEqual(t, "test", newHeader.Get("Authorization")) } func TestGetTLSClientConfig(t *testing.T) { c := tc() config := c.GetTLSClientConfig() - assertEqual(t, true, c.t.TLSClientConfig != nil) - assertEqual(t, config, c.t.TLSClientConfig) + tests.AssertEqual(t, true, c.t.TLSClientConfig != nil) + tests.AssertEqual(t, config, c.t.TLSClientConfig) } func TestSetRootCertFromFile(t *testing.T) { c := tc().SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")) - assertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) + tests.AssertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) } func TestSetRootCertFromString(t *testing.T) { c := tc().SetRootCertFromString(string(getTestFileContent(t, "sample-root.pem"))) - assertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) + tests.AssertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) } func TestSetCerts(t *testing.T) { c := tc().SetCerts(tls.Certificate{}, tls.Certificate{}) - assertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 2) + tests.AssertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 2) } func TestSetCertFromFile(t *testing.T) { @@ -381,7 +381,7 @@ func TestSetCertFromFile(t *testing.T) { tests.GetTestFilePath("sample-client.pem"), tests.GetTestFilePath("sample-client-key.pem"), ) - assertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 1) + tests.AssertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 1) } func TestSetOutputDirectory(t *testing.T) { @@ -393,13 +393,13 @@ func TestSetOutputDirectory(t *testing.T) { assertSuccess(t, resp, err) content := string(getTestFileContent(t, outFile)) os.Remove(tests.GetTestFilePath(outFile)) - assertEqual(t, "TestGet: text response", content) + tests.AssertEqual(t, "TestGet: text response", content) } func TestSetBaseURL(t *testing.T) { baseURL := "http://dummy-req.local/test" resp, _ := tc().SetTimeout(time.Nanosecond).SetBaseURL(baseURL).R().Get("/req") - assertEqual(t, baseURL+"/req", resp.Request.RawRequest.URL.String()) + tests.AssertEqual(t, baseURL+"/req", resp.Request.RawRequest.URL.String()) } func TestSetCommonFormDataFromValues(t *testing.T) { @@ -411,7 +411,7 @@ func TestSetCommonFormDataFromValues(t *testing.T) { R().SetResult(&gotForm). Post("/form") assertSuccess(t, resp, err) - assertEqual(t, "test", gotForm.Get("test")) + tests.AssertEqual(t, "test", gotForm.Get("test")) } func TestSetCommonFormData(t *testing.T) { @@ -424,7 +424,7 @@ func TestSetCommonFormData(t *testing.T) { SetResult(&form). Post("/form") assertSuccess(t, resp, err) - assertEqual(t, "test", form.Get("test")) + tests.AssertEqual(t, "test", form.Get("test")) } func TestClientClone(t *testing.T) { @@ -451,15 +451,15 @@ func testDisableAutoReadResponse(t *testing.T, c *Client) { c.DisableAutoReadResponse() resp, err := c.R().Get("/") assertSuccess(t, resp, err) - assertEqual(t, "", resp.String()) + tests.AssertEqual(t, "", resp.String()) result, err := resp.ToString() - assertNoError(t, err) - assertEqual(t, "TestGet: text response", result) + tests.AssertNoError(t, err) + tests.AssertEqual(t, "TestGet: text response", result) resp, err = c.R().Get("/") assertSuccess(t, resp, err) _, err = ioutil.ReadAll(resp.Body) - assertNoError(t, err) + tests.AssertNoError(t, err) } func testEnableDumpAll(t *testing.T, fn func(c *Client) (de dumpExpected)) { @@ -471,10 +471,10 @@ func testEnableDumpAll(t *testing.T, fn func(c *Client) (de dumpExpected)) { resp, err := r.SetBody(`test body`).Post("/") assertSuccess(t, resp, err) dump := buff.String() - assertContains(t, dump, "user-agent", de.ReqHeader) - assertContains(t, dump, "test body", de.ReqBody) - assertContains(t, dump, "date", de.RespHeader) - assertContains(t, dump, "testpost: text response", de.RespBody) + tests.AssertContains(t, dump, "user-agent", de.ReqHeader) + tests.AssertContains(t, dump, "test body", de.ReqBody) + tests.AssertContains(t, dump, "date", de.RespHeader) + tests.AssertContains(t, dump, "testpost: text response", de.RespBody) } c := tc() testDump(c) @@ -554,15 +554,15 @@ func TestEnableDumpAllToFile(t *testing.T) { assertSuccess(t, resp, err) dump := string(getTestFileContent(t, dumpFile)) os.Remove(tests.GetTestFilePath(dumpFile)) - assertContains(t, dump, "user-agent", true) - assertContains(t, dump, "test body", true) - assertContains(t, dump, "date", true) - assertContains(t, dump, "testpost: text response", true) + tests.AssertContains(t, dump, "user-agent", true) + tests.AssertContains(t, dump, "test body", true) + tests.AssertContains(t, dump, "date", true) + tests.AssertContains(t, dump, "testpost: text response", true) } func TestEnableDumpAllAsync(t *testing.T) { c := tc() buf := new(bytes.Buffer) c.EnableDumpAllTo(buf).EnableDumpAllAsync() - assertEqual(t, true, c.getDumpOptions().Async) + tests.AssertEqual(t, true, c.getDumpOptions().Async) } diff --git a/client_wrapper_test.go b/client_wrapper_test.go index 8b830327..e7cf8623 100644 --- a/client_wrapper_test.go +++ b/client_wrapper_test.go @@ -22,7 +22,7 @@ func TestGlobalWrapper(t *testing.T) { form := make(url.Values) form.Add("test", "test") - assertAllNotNil(t, + tests.AssertAllNotNil(t, SetCommonError(nil), SetCookieJar(nil), SetDialTLS(nil), @@ -122,19 +122,19 @@ func TestGlobalWrapper(t *testing.T) { os.Remove(tests.GetTestFilePath("tmpdump.out")) config := GetTLSClientConfig() - assertEqual(t, config, DefaultClient().t.TLSClientConfig) + tests.AssertEqual(t, config, DefaultClient().t.TLSClientConfig) r := R() - assertEqual(t, true, r != nil) + tests.AssertEqual(t, true, r != nil) c := C() c.SetTimeout(10 * time.Second) SetDefaultClient(c) - assertEqual(t, true, DefaultClient().httpClient.Timeout == 10*time.Second) - assertEqual(t, GetClient(), DefaultClient().httpClient) + tests.AssertEqual(t, true, DefaultClient().httpClient.Timeout == 10*time.Second) + tests.AssertEqual(t, GetClient(), DefaultClient().httpClient) r = NewRequest() - assertEqual(t, true, r != nil) + tests.AssertEqual(t, true, r != nil) c = NewClient() - assertEqual(t, true, c != nil) + tests.AssertEqual(t, true, c != nil) } diff --git a/decode_test.go b/decode_test.go index fc810e83..e65a8ea6 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1,6 +1,7 @@ package req import ( + "github.com/imroc/req/v3/internal/tests" "testing" ) @@ -8,9 +9,9 @@ func TestPeekDrain(t *testing.T) { a := autoDecodeReadCloser{peek: []byte("test")} p := make([]byte, 2) n, _ := a.peekDrain(p) - assertEqual(t, 2, n) - assertEqual(t, true, a.peek != nil) + tests.AssertEqual(t, 2, n) + tests.AssertEqual(t, true, a.peek != nil) n, _ = a.peekDrain(p) - assertEqual(t, 2, n) - assertEqual(t, true, a.peek == nil) + tests.AssertEqual(t, 2, n) + tests.AssertEqual(t, true, a.peek == nil) } diff --git a/dump.go b/dump.go index 250b3f89..627f0b8b 100644 --- a/dump.go +++ b/dump.go @@ -1,7 +1,7 @@ package req import ( - "context" + "github.com/imroc/req/v3/internal/dump" "io" "os" ) @@ -25,60 +25,36 @@ func (do *DumpOptions) Clone() *DumpOptions { return &d } -func (d *dumper) WrapReadCloser(rc io.ReadCloser) io.ReadCloser { - return &dumpReadCloser{rc, d} -} - -type dumpReadCloser struct { - io.ReadCloser - dump *dumper -} - -func (r *dumpReadCloser) Read(p []byte) (n int, err error) { - n, err = r.ReadCloser.Read(p) - r.dump.dump(p[:n]) - if err == io.EOF { - r.dump.dump([]byte("\r\n")) - } - return +type dumpOptions struct { + *DumpOptions } -func (d *dumper) WrapWriteCloser(rc io.WriteCloser) io.WriteCloser { - return &dumpWriteCloser{rc, d} +func (o dumpOptions) Output() io.Writer { + return o.DumpOptions.Output } -type dumpWriteCloser struct { - io.WriteCloser - dump *dumper +func (o dumpOptions) RequestHeader() bool { + return o.DumpOptions.RequestHeader } -func (w *dumpWriteCloser) Write(p []byte) (n int, err error) { - n, err = w.WriteCloser.Write(p) - w.dump.dump(p[:n]) - return +func (o dumpOptions) RequestBody() bool { + return o.DumpOptions.RequestBody } -type dumpWriter struct { - w io.Writer - dump *dumper +func (o dumpOptions) ResponseHeader() bool { + return o.DumpOptions.ResponseHeader } -func (w *dumpWriter) Write(p []byte) (n int, err error) { - n, err = w.w.Write(p) - w.dump.dump(p[:n]) - return +func (o dumpOptions) ResponseBody() bool { + return o.DumpOptions.ResponseBody } -func (d *dumper) WrapWriter(w io.Writer) io.Writer { - return &dumpWriter{ - w: w, - dump: d, - } +func (o dumpOptions) Async() bool { + return o.DumpOptions.Async } -type dumper struct { - *DumpOptions - ch chan []byte +func (o dumpOptions) Clone() dump.Options { + return dumpOptions{o.DumpOptions.Clone()} } func newDefaultDumpOptions() *DumpOptions { @@ -91,67 +67,12 @@ func newDefaultDumpOptions() *DumpOptions { } } -func newDumper(opt *DumpOptions) *dumper { +func newDumper(opt *DumpOptions) *dump.Dumper { if opt == nil { opt = newDefaultDumpOptions() } if opt.Output == nil { opt.Output = os.Stderr } - d := &dumper{ - DumpOptions: opt, - ch: make(chan []byte, 20), - } - return d -} - -func (d *dumper) Clone() *dumper { - if d == nil { - return nil - } - return &dumper{ - DumpOptions: d.DumpOptions.Clone(), - ch: make(chan []byte, 20), - } -} - -func (d *dumper) dump(p []byte) { - if len(p) == 0 { - return - } - if d.Async { - b := make([]byte, len(p)) - copy(b, p) - d.ch <- b - return - } - d.Output.Write(p) -} - -func (d *dumper) Stop() { - d.ch <- nil -} - -func (d *dumper) Start() { - for b := range d.ch { - if b == nil { - return - } - d.Output.Write(b) - } -} - -type dumperKeyType int - -const dumperKey dumperKeyType = iota - -func getDumpers(ctx context.Context, dump *dumper) []*dumper { - dumps := []*dumper{} - if dump != nil { - dumps = append(dumps, dump) - } - if d, ok := ctx.Value(dumperKey).(*dumper); ok { - dumps = append(dumps, d) - } - return dumps + return dump.NewDumper(dumpOptions{opt}) } diff --git a/h2_errors.go b/h2_errors.go deleted file mode 100644 index 24cc07d1..00000000 --- a/h2_errors.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package req - -import ( - "errors" - "fmt" -) - -// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. -type http2ErrCode uint32 - -const ( - http2ErrCodeNo http2ErrCode = 0x0 - http2ErrCodeProtocol http2ErrCode = 0x1 - http2ErrCodeInternal http2ErrCode = 0x2 - http2ErrCodeFlowControl http2ErrCode = 0x3 - http2ErrCodeSettingsTimeout http2ErrCode = 0x4 - http2ErrCodeStreamClosed http2ErrCode = 0x5 - http2ErrCodeFrameSize http2ErrCode = 0x6 - http2ErrCodeRefusedStream http2ErrCode = 0x7 - http2ErrCodeCancel http2ErrCode = 0x8 - http2ErrCodeCompression http2ErrCode = 0x9 - http2ErrCodeConnect http2ErrCode = 0xa - http2ErrCodeEnhanceYourCalm http2ErrCode = 0xb - http2ErrCodeInadequateSecurity http2ErrCode = 0xc - http2ErrCodeHTTP11Required http2ErrCode = 0xd -) - -var http2errCodeName = map[http2ErrCode]string{ - http2ErrCodeNo: "NO_ERROR", - http2ErrCodeProtocol: "PROTOCOL_ERROR", - http2ErrCodeInternal: "INTERNAL_ERROR", - http2ErrCodeFlowControl: "FLOW_CONTROL_ERROR", - http2ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT", - http2ErrCodeStreamClosed: "STREAM_CLOSED", - http2ErrCodeFrameSize: "FRAME_SIZE_ERROR", - http2ErrCodeRefusedStream: "REFUSED_STREAM", - http2ErrCodeCancel: "CANCEL", - http2ErrCodeCompression: "COMPRESSION_ERROR", - http2ErrCodeConnect: "CONNECT_ERROR", - http2ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM", - http2ErrCodeInadequateSecurity: "INADEQUATE_SECURITY", - http2ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED", -} - -func (e http2ErrCode) String() string { - if s, ok := http2errCodeName[e]; ok { - return s - } - return fmt.Sprintf("unknown error code 0x%x", uint32(e)) -} - -func (e http2ErrCode) stringToken() string { - if s, ok := http2errCodeName[e]; ok { - return s - } - return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e)) -} - -// ConnectionError is an error that results in the termination of the -// entire connection. -type http2ConnectionError http2ErrCode - -func (e http2ConnectionError) Error() string { - return fmt.Sprintf("connection error: %s", http2ErrCode(e)) -} - -// StreamError is an error that only affects one stream within an -// HTTP/2 connection. -type http2StreamError struct { - StreamID uint32 - Code http2ErrCode - Cause error // optional additional detail -} - -// errFromPeer is a sentinel error value for StreamError.Cause to -// indicate that the StreamError was sent from the peer over the wire -// and wasn't locally generated in the Transport. -var errFromPeer = errors.New("received from peer") - -func http2streamError(id uint32, code http2ErrCode) http2StreamError { - return http2StreamError{StreamID: id, Code: code} -} - -func (e http2StreamError) Error() string { - if e.Cause != nil { - return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause) - } - return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) -} - -// connError represents an HTTP/2 ConnectionError error code, along -// with a string (for debugging) explaining why. -// -// Errors of this type are only returned by the frame parser functions -// and converted into ConnectionError(Code), after stashing away -// the Reason into the Framer's errDetail field, accessible via -// the (*Framer).ErrorDetail method. -type http2connError struct { - Code http2ErrCode // the ConnectionError error code - Reason string // additional reason -} - -func (e http2connError) Error() string { - return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason) -} - -type http2pseudoHeaderError string - -func (e http2pseudoHeaderError) Error() string { - return fmt.Sprintf("invalid pseudo-header %q", string(e)) -} - -type http2duplicatePseudoHeaderError string - -func (e http2duplicatePseudoHeaderError) Error() string { - return fmt.Sprintf("duplicate pseudo-header %q", string(e)) -} - -type http2headerFieldNameError string - -func (e http2headerFieldNameError) Error() string { - return fmt.Sprintf("invalid header field name %q", string(e)) -} - -type http2headerFieldValueError string - -func (e http2headerFieldValueError) Error() string { - return fmt.Sprintf("invalid header field value for %q", string(e)) -} - -var ( - errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers") - errPseudoAfterRegular = errors.New("pseudo header field after regular") -) diff --git a/http_request.go b/http_request.go index c94add39..2d471b4d 100644 --- a/http_request.go +++ b/http_request.go @@ -136,13 +136,6 @@ func reqExpectsContinue(r *http.Request) bool { return hasToken(headerGet(r.Header, "Expect"), "100-continue") } -func reqWantsHttp10KeepAlive(r *http.Request) bool { - if r.ProtoMajor != 1 || r.ProtoMinor != 0 { - return false - } - return hasToken(headerGet(r.Header, "Connection"), "keep-alive") -} - func reqWantsClose(r *http.Request) bool { if r.Close { return true diff --git a/internal/common/error.go b/internal/common/error.go new file mode 100644 index 00000000..c6bf3edd --- /dev/null +++ b/internal/common/error.go @@ -0,0 +1,7 @@ +package common + +import "errors" + +// ErrRequestCanceled is a copy of net/http's common.ErrRequestCanceled because it's not +// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. +var ErrRequestCanceled = errors.New("net/http: request canceled") diff --git a/internal/dump/dump.go b/internal/dump/dump.go new file mode 100644 index 00000000..f36630bf --- /dev/null +++ b/internal/dump/dump.go @@ -0,0 +1,137 @@ +package dump + +import ( + "context" + "io" +) + +// Options controls the dump behavior. +type Options interface { + Output() io.Writer + RequestHeader() bool + RequestBody() bool + ResponseHeader() bool + ResponseBody() bool + Async() bool + Clone() Options +} + +func (d *Dumper) WrapReadCloser(rc io.ReadCloser) io.ReadCloser { + return &dumpReadCloser{rc, d} +} + +type dumpReadCloser struct { + io.ReadCloser + dump *Dumper +} + +func (r *dumpReadCloser) Read(p []byte) (n int, err error) { + n, err = r.ReadCloser.Read(p) + r.dump.Dump(p[:n]) + if err == io.EOF { + r.dump.Dump([]byte("\r\n")) + } + return +} + +func (d *Dumper) WrapWriteCloser(rc io.WriteCloser) io.WriteCloser { + return &dumpWriteCloser{rc, d} +} + +type dumpWriteCloser struct { + io.WriteCloser + dump *Dumper +} + +func (w *dumpWriteCloser) Write(p []byte) (n int, err error) { + n, err = w.WriteCloser.Write(p) + w.dump.Dump(p[:n]) + return +} + +type dumpWriter struct { + w io.Writer + dump *Dumper +} + +func (w *dumpWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + w.dump.Dump(p[:n]) + return +} + +func (d *Dumper) WrapWriter(w io.Writer) io.Writer { + return &dumpWriter{ + w: w, + dump: d, + } +} + +type Dumper struct { + Options + ch chan []byte +} + +func NewDumper(opt Options) *Dumper { + d := &Dumper{ + Options: opt, + ch: make(chan []byte, 20), + } + return d +} + +func (d *Dumper) SetOptions(opt Options) { + d.Options = opt + return +} + +func (d *Dumper) Clone() *Dumper { + if d == nil { + return nil + } + return &Dumper{ + Options: d.Options.Clone(), + ch: make(chan []byte, 20), + } +} + +func (d *Dumper) Dump(p []byte) { + if len(p) == 0 { + return + } + if d.Async() { + b := make([]byte, len(p)) + copy(b, p) + d.ch <- b + return + } + d.Output().Write(p) +} + +func (d *Dumper) Stop() { + d.ch <- nil +} + +func (d *Dumper) Start() { + for b := range d.ch { + if b == nil { + return + } + d.Output().Write(b) + } +} + +type dumperKeyType int + +const DumperKey dumperKeyType = iota + +func GetDumpers(ctx context.Context, dump *Dumper) []*Dumper { + dumps := []*Dumper{} + if dump != nil { + dumps = append(dumps, dump) + } + if d, ok := ctx.Value(DumperKey).(*Dumper); ok { + dumps = append(dumps, d) + } + return dumps +} diff --git a/internal/header/header.go b/internal/header/header.go new file mode 100644 index 00000000..52135a51 --- /dev/null +++ b/internal/header/header.go @@ -0,0 +1,3 @@ +package header + +const DefaultUserAgent = "req/v3 (https://github.com/imroc/req)" diff --git a/h2_client_conn_pool.go b/internal/http2/client_conn_pool.go similarity index 68% rename from h2_client_conn_pool.go rename to internal/http2/client_conn_pool.go index 8c2bcaa9..5c8c9958 100644 --- a/h2_client_conn_pool.go +++ b/internal/http2/client_conn_pool.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "context" @@ -13,56 +13,56 @@ import ( ) // ClientConnPool manages a pool of HTTP/2 client connections. -type http2ClientConnPool interface { +type ClientConnPool interface { // GetClientConn returns a specific HTTP/2 connection (usually // a TLS-TCP connection) to an HTTP/2 server. On success, the // returned ClientConn accounts for the upcoming RoundTrip // call, so the caller should not omit it. If the caller needs // to, ClientConn.RoundTrip can be called with a bogus // new(http.Request) to release the stream reservation. - GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) - MarkDead(*http2ClientConn) + GetClientConn(req *http.Request, addr string) (*ClientConn, error) + MarkDead(*ClientConn) } // clientConnPoolIdleCloser is the interface implemented by ClientConnPool // implementations which can close their idle connections. -type http2clientConnPoolIdleCloser interface { - http2ClientConnPool +type clientConnPoolIdleCloser interface { + ClientConnPool closeIdleConnections() } var ( - _ http2clientConnPoolIdleCloser = (*http2clientConnPool)(nil) - _ http2clientConnPoolIdleCloser = http2noDialClientConnPool{} + _ clientConnPoolIdleCloser = (*clientConnPool)(nil) + _ clientConnPoolIdleCloser = noDialClientConnPool{} ) // TODO: use singleflight for dialing and addConnCalls? -type http2clientConnPool struct { - t *http2Transport +type clientConnPool struct { + t *Transport mu sync.Mutex // TODO: maybe switch to RWMutex // TODO: add support for sharing conns based on cert names // (e.g. share conn for googleapis.com and appspot.com) - conns map[string][]*http2ClientConn // key is host:port - dialing map[string]*http2dialCall // currently in-flight dials - keys map[*http2ClientConn][]string - addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeeded calls + conns map[string][]*ClientConn // key is host:port + dialing map[string]*dialCall // currently in-flight dials + keys map[*ClientConn][]string + addConnCalls map[string]*addConnCall // in-flight addConnIfNeeded calls } -func (p *http2clientConnPool) GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) { - return p.getClientConn(req, addr, http2dialOnMiss) +func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { + return p.getClientConn(req, addr, dialOnMiss) } const ( - http2dialOnMiss = true - http2noDialOnMiss = false + dialOnMiss = true + noDialOnMiss = false ) -func (p *http2clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { +func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) { // TODO(dneil): Dial a new connection when t.DisableKeepAlives is set? - if http2isConnectionCloseRequest(req) && dialOnMiss { + if isConnectionCloseRequest(req) && dialOnMiss { // It gets its own connection. - http2traceGetConn(req, addr) + traceGetConn(req, addr) const singleUse = true cc, err := p.t.dialClientConn(req.Context(), addr, singleUse) if err != nil { @@ -78,7 +78,7 @@ func (p *http2clientConnPool) getClientConn(req *http.Request, addr string, dial // the GetConn hook has already been called. // Don't call it a second time here. if !cc.getConnCalled { - http2traceGetConn(req, addr) + traceGetConn(req, addr) } cc.getConnCalled = false p.mu.Unlock() @@ -87,13 +87,13 @@ func (p *http2clientConnPool) getClientConn(req *http.Request, addr string, dial } if !dialOnMiss { p.mu.Unlock() - return nil, http2ErrNoCachedConn + return nil, ErrNoCachedConn } - http2traceGetConn(req, addr) + traceGetConn(req, addr) call := p.getStartDialLocked(req.Context(), addr) p.mu.Unlock() <-call.done - if http2shouldRetryDial(call, req) { + if shouldRetryDial(call, req) { continue } cc, err := call.res, call.err @@ -107,26 +107,26 @@ func (p *http2clientConnPool) getClientConn(req *http.Request, addr string, dial } // dialCall is an in-flight Transport dial call to a host. -type http2dialCall struct { - _ http2incomparable - p *http2clientConnPool +type dialCall struct { + _ incomparable + p *clientConnPool // the context associated with the request // that created this dialCall ctx context.Context - done chan struct{} // closed when done - res *http2ClientConn // valid after done is closed - err error // valid after done is closed + done chan struct{} // closed when done + res *ClientConn // valid after done is closed + err error // valid after done is closed } // requires p.mu is held. -func (p *http2clientConnPool) getStartDialLocked(ctx context.Context, addr string) *http2dialCall { +func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall { if call, ok := p.dialing[addr]; ok { // A dial is already in-flight. Don't start another. return call } - call := &http2dialCall{p: p, done: make(chan struct{}), ctx: ctx} + call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx} if p.dialing == nil { - p.dialing = make(map[string]*http2dialCall) + p.dialing = make(map[string]*dialCall) } p.dialing[addr] = call go call.dial(call.ctx, addr) @@ -134,7 +134,7 @@ func (p *http2clientConnPool) getStartDialLocked(ctx context.Context, addr strin } // run in its own goroutine. -func (c *http2dialCall) dial(ctx context.Context, addr string) { +func (c *dialCall) dial(ctx context.Context, addr string) { const singleUse = false // shared conn c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse) close(c.done) @@ -155,7 +155,7 @@ func (c *http2dialCall) dial(ctx context.Context, addr string) { // This code decides which ones live or die. // The return value used is whether c was used. // c is never closed. -func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c net.Conn) (used bool, err error) { +func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) { p.mu.Lock() for _, cc := range p.conns[key] { if cc.CanTakeNewRequest() { @@ -166,9 +166,9 @@ func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c n call, dup := p.addConnCalls[key] if !dup { if p.addConnCalls == nil { - p.addConnCalls = make(map[string]*http2addConnCall) + p.addConnCalls = make(map[string]*addConnCall) } - call = &http2addConnCall{ + call = &addConnCall{ p: p, done: make(chan struct{}), } @@ -184,14 +184,14 @@ func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c n return !dup, nil } -type http2addConnCall struct { - _ http2incomparable - p *http2clientConnPool +type addConnCall struct { + _ incomparable + p *clientConnPool done chan struct{} // closed when done err error } -func (c *http2addConnCall) run(t *http2Transport, key string, tc net.Conn) { +func (c *addConnCall) run(t *Transport, key string, tc net.Conn) { cc, err := t.NewClientConn(tc) p := c.p @@ -208,23 +208,23 @@ func (c *http2addConnCall) run(t *http2Transport, key string, tc net.Conn) { } // p.mu must be held -func (p *http2clientConnPool) addConnLocked(key string, cc *http2ClientConn) { +func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) { for _, v := range p.conns[key] { if v == cc { return } } if p.conns == nil { - p.conns = make(map[string][]*http2ClientConn) + p.conns = make(map[string][]*ClientConn) } if p.keys == nil { - p.keys = make(map[*http2ClientConn][]string) + p.keys = make(map[*ClientConn][]string) } p.conns[key] = append(p.conns[key], cc) p.keys[cc] = append(p.keys[cc], key) } -func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) { +func (p *clientConnPool) MarkDead(cc *ClientConn) { p.mu.Lock() defer p.mu.Unlock() for _, key := range p.keys[cc] { @@ -232,7 +232,7 @@ func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) { if !ok { continue } - newList := http2filterOutClientConn(vv, cc) + newList := filterOutClientConn(vv, cc) if len(newList) > 0 { p.conns[key] = newList } else { @@ -242,7 +242,7 @@ func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) { delete(p.keys, cc) } -func (p *http2clientConnPool) closeIdleConnections() { +func (p *clientConnPool) closeIdleConnections() { p.mu.Lock() defer p.mu.Unlock() // TODO: don't close a cc if it was just added to the pool @@ -258,7 +258,7 @@ func (p *http2clientConnPool) closeIdleConnections() { } } -func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn { +func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn { out := in[:0] for _, v := range in { if v != exclude { @@ -276,17 +276,17 @@ func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) [ // noDialClientConnPool is an implementation of http2.ClientConnPool // which never dials. We let the HTTP/1.1 client dial and use its TLS // connection instead. -type http2noDialClientConnPool struct{ *http2clientConnPool } +type noDialClientConnPool struct{ *clientConnPool } -func (p http2noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) { - return p.getClientConn(req, addr, http2noDialOnMiss) +func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { + return p.getClientConn(req, addr, noDialOnMiss) } // shouldRetryDial reports whether the current request should // retry dialing after the call finished unsuccessfully, for example // if the dial was canceled because of a context cancellation or // deadline expiry. -func http2shouldRetryDial(call *http2dialCall, req *http.Request) bool { +func shouldRetryDial(call *dialCall, req *http.Request) bool { if call.err == nil { // No error, no need to retry return false diff --git a/h2_databuffer.go b/internal/http2/databuffer.go similarity index 80% rename from h2_databuffer.go rename to internal/http2/databuffer.go index be110504..a3067f8d 100644 --- a/h2_databuffer.go +++ b/internal/http2/databuffer.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "errors" @@ -21,14 +21,14 @@ import ( // improved enough that we can instead allocate chunks like this: // make([]byte, max(16<<10, expectedBytesRemaining)) var ( - http2dataChunkSizeClasses = []int{ + dataChunkSizeClasses = []int{ 1 << 10, 2 << 10, 4 << 10, 8 << 10, 16 << 10, } - http2dataChunkPools = [...]sync.Pool{ + dataChunkPools = [...]sync.Pool{ {New: func() interface{} { return make([]byte, 1<<10) }}, {New: func() interface{} { return make([]byte, 2<<10) }}, {New: func() interface{} { return make([]byte, 4<<10) }}, @@ -37,20 +37,20 @@ var ( } ) -func http2getDataBufferChunk(size int64) []byte { +func getDataBufferChunk(size int64) []byte { i := 0 - for ; i < len(http2dataChunkSizeClasses)-1; i++ { - if size <= int64(http2dataChunkSizeClasses[i]) { + for ; i < len(dataChunkSizeClasses)-1; i++ { + if size <= int64(dataChunkSizeClasses[i]) { break } } - return http2dataChunkPools[i].Get().([]byte) + return dataChunkPools[i].Get().([]byte) } -func http2putDataBufferChunk(p []byte) { - for i, n := range http2dataChunkSizeClasses { +func putDataBufferChunk(p []byte) { + for i, n := range dataChunkSizeClasses { if len(p) == n { - http2dataChunkPools[i].Put(p) + dataChunkPools[i].Put(p) return } } @@ -62,7 +62,7 @@ func http2putDataBufferChunk(p []byte) { // The buffer is divided into chunks so the server can limit the // total memory used by a single connection without limiting the // request body size on any single stream. -type http2dataBuffer struct { +type dataBuffer struct { chunks [][]byte r int // next byte to read is chunks[0][r] w int // next byte to write is chunks[len(chunks)-1][w] @@ -74,7 +74,7 @@ var errReadEmpty = errors.New("read from empty dataBuffer") // Read copies bytes from the buffer into p. // It is an error to read when no data is available. -func (b *http2dataBuffer) Read(p []byte) (int, error) { +func (b *dataBuffer) Read(p []byte) (int, error) { if b.size == 0 { return 0, errReadEmpty } @@ -88,7 +88,7 @@ func (b *http2dataBuffer) Read(p []byte) (int, error) { b.size -= n // If the first chunk has been consumed, advance to the next chunk. if b.r == len(b.chunks[0]) { - http2putDataBufferChunk(b.chunks[0]) + putDataBufferChunk(b.chunks[0]) end := len(b.chunks) - 1 copy(b.chunks[:end], b.chunks[1:]) b.chunks[end] = nil @@ -99,7 +99,7 @@ func (b *http2dataBuffer) Read(p []byte) (int, error) { return ntotal, nil } -func (b *http2dataBuffer) bytesFromFirstChunk() []byte { +func (b *dataBuffer) bytesFromFirstChunk() []byte { if len(b.chunks) == 1 { return b.chunks[0][b.r:b.w] } @@ -107,12 +107,12 @@ func (b *http2dataBuffer) bytesFromFirstChunk() []byte { } // Len returns the number of bytes of the unread portion of the buffer. -func (b *http2dataBuffer) Len() int { +func (b *dataBuffer) Len() int { return b.size } // Write appends p to the buffer. -func (b *http2dataBuffer) Write(p []byte) (int, error) { +func (b *dataBuffer) Write(p []byte) (int, error) { ntotal := len(p) for len(p) > 0 { // If the last chunk is empty, allocate a new chunk. Try to allocate @@ -132,14 +132,14 @@ func (b *http2dataBuffer) Write(p []byte) (int, error) { return ntotal, nil } -func (b *http2dataBuffer) lastChunkOrAlloc(want int64) []byte { +func (b *dataBuffer) lastChunkOrAlloc(want int64) []byte { if len(b.chunks) != 0 { last := b.chunks[len(b.chunks)-1] if b.w < len(last) { return last } } - chunk := http2getDataBufferChunk(want) + chunk := getDataBufferChunk(want) b.chunks = append(b.chunks, chunk) b.w = 0 return chunk diff --git a/h2_databuffer_test.go b/internal/http2/databuffer_test.go similarity index 91% rename from h2_databuffer_test.go rename to internal/http2/databuffer_test.go index b2da7459..32cd5f38 100644 --- a/h2_databuffer_test.go +++ b/internal/http2/databuffer_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "bytes" @@ -40,7 +40,7 @@ func fmtDataChunks(chunks [][]byte) string { return out } -func testDataBuffer(t *testing.T, wantBytes []byte, setup func(t *testing.T) *http2dataBuffer) { +func testDataBuffer(t *testing.T, wantBytes []byte, setup func(t *testing.T) *dataBuffer) { // Run setup, then read the remaining bytes from the dataBuffer and check // that they match wantBytes. We use different read sizes to check corner // cases in Read. @@ -83,8 +83,8 @@ func TestDataBufferAllocation(t *testing.T) { wantRead.Write(p) } - testDataBuffer(t, wantRead.Bytes(), func(t *testing.T) *http2dataBuffer { - b := &http2dataBuffer{} + testDataBuffer(t, wantRead.Bytes(), func(t *testing.T) *dataBuffer { + b := &dataBuffer{} for _, p := range writes { if n, err := b.Write(p); n != len(p) || err != nil { t.Fatalf("Write(%q x %d)=%v,%v want %v,nil", p[:1], len(p), n, err, len(p)) @@ -118,8 +118,8 @@ func TestDataBufferAllocationWithExpected(t *testing.T) { wantRead.Write(p) } - testDataBuffer(t, wantRead.Bytes(), func(t *testing.T) *http2dataBuffer { - b := &http2dataBuffer{expected: 32 * 1024} + testDataBuffer(t, wantRead.Bytes(), func(t *testing.T) *dataBuffer { + b := &dataBuffer{expected: 32 * 1024} for _, p := range writes { if n, err := b.Write(p); n != len(p) || err != nil { t.Fatalf("Write(%q x %d)=%v,%v want %v,nil", p[:1], len(p), n, err, len(p)) @@ -138,8 +138,8 @@ func TestDataBufferAllocationWithExpected(t *testing.T) { } func TestDataBufferWriteAfterPartialRead(t *testing.T) { - testDataBuffer(t, []byte("cdxyz"), func(t *testing.T) *http2dataBuffer { - b := &http2dataBuffer{} + testDataBuffer(t, []byte("cdxyz"), func(t *testing.T) *dataBuffer { + b := &dataBuffer{} if n, err := b.Write([]byte("abcd")); n != 4 || err != nil { t.Fatalf("Write(\"abcd\")=%v,%v want 4,nil", n, err) } diff --git a/internal/http2/errors.go b/internal/http2/errors.go new file mode 100644 index 00000000..07bc7d6b --- /dev/null +++ b/internal/http2/errors.go @@ -0,0 +1,138 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "errors" + "fmt" +) + +// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. +type ErrCode uint32 + +const ( + ErrCodeNo ErrCode = 0x0 + ErrCodeProtocol ErrCode = 0x1 + ErrCodeInternal ErrCode = 0x2 + ErrCodeFlowControl ErrCode = 0x3 + ErrCodeSettingsTimeout ErrCode = 0x4 + ErrCodeStreamClosed ErrCode = 0x5 + ErrCodeFrameSize ErrCode = 0x6 + ErrCodeRefusedStream ErrCode = 0x7 + ErrCodeCancel ErrCode = 0x8 + ErrCodeCompression ErrCode = 0x9 + ErrCodeConnect ErrCode = 0xa + ErrCodeEnhanceYourCalm ErrCode = 0xb + ErrCodeInadequateSecurity ErrCode = 0xc + ErrCodeHTTP11Required ErrCode = 0xd +) + +var errCodeName = map[ErrCode]string{ + ErrCodeNo: "NO_ERROR", + ErrCodeProtocol: "PROTOCOL_ERROR", + ErrCodeInternal: "INTERNAL_ERROR", + ErrCodeFlowControl: "FLOW_CONTROL_ERROR", + ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT", + ErrCodeStreamClosed: "STREAM_CLOSED", + ErrCodeFrameSize: "FRAME_SIZE_ERROR", + ErrCodeRefusedStream: "REFUSED_STREAM", + ErrCodeCancel: "CANCEL", + ErrCodeCompression: "COMPRESSION_ERROR", + ErrCodeConnect: "CONNECT_ERROR", + ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM", + ErrCodeInadequateSecurity: "INADEQUATE_SECURITY", + ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED", +} + +func (e ErrCode) String() string { + if s, ok := errCodeName[e]; ok { + return s + } + return fmt.Sprintf("unknown error code 0x%x", uint32(e)) +} + +func (e ErrCode) stringToken() string { + if s, ok := errCodeName[e]; ok { + return s + } + return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e)) +} + +// ConnectionError is an error that results in the termination of the +// entire connection. +type ConnectionError ErrCode + +func (e ConnectionError) Error() string { + return fmt.Sprintf("connection error: %s", ErrCode(e)) +} + +// StreamError is an error that only affects one stream within an +// HTTP/2 connection. +type StreamError struct { + StreamID uint32 + Code ErrCode + Cause error // optional additional detail +} + +// errFromPeer is a sentinel error value for StreamError.Cause to +// indicate that the StreamError was sent from the peer over the wire +// and wasn't locally generated in the Transport. +var errFromPeer = errors.New("received from peer") + +func streamError(id uint32, code ErrCode) StreamError { + return StreamError{StreamID: id, Code: code} +} + +func (e StreamError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause) + } + return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) +} + +// connError represents an HTTP/2 ConnectionError error code, along +// with a string (for debugging) explaining why. +// +// Errors of this type are only returned by the frame parser functions +// and converted into ConnectionError(Code), after stashing away +// the Reason into the Framer's errDetail field, accessible via +// the (*Framer).ErrorDetail method. +type connError struct { + Code ErrCode // the ConnectionError error code + Reason string // additional reason +} + +func (e connError) Error() string { + return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason) +} + +type pseudoHeaderError string + +func (e pseudoHeaderError) Error() string { + return fmt.Sprintf("invalid pseudo-header %q", string(e)) +} + +type duplicatePseudoHeaderError string + +func (e duplicatePseudoHeaderError) Error() string { + return fmt.Sprintf("duplicate pseudo-header %q", string(e)) +} + +type headerFieldNameError string + +func (e headerFieldNameError) Error() string { + return fmt.Sprintf("invalid header field name %q", string(e)) +} + +type headerFieldValueError string + +func (e headerFieldValueError) Error() string { + return fmt.Sprintf("invalid header field value for %q", string(e)) +} + +var ( + errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers") + errPseudoAfterRegular = errors.New("pseudo header field after regular") +) diff --git a/h2_errors_test.go b/internal/http2/errors_test.go similarity index 85% rename from h2_errors_test.go rename to internal/http2/errors_test.go index cc3970b2..da5c58c3 100644 --- a/h2_errors_test.go +++ b/internal/http2/errors_test.go @@ -2,16 +2,16 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import "testing" func TestErrCodeString(t *testing.T) { tests := []struct { - err http2ErrCode + err ErrCode want string }{ - {http2ErrCodeProtocol, "PROTOCOL_ERROR"}, + {ErrCodeProtocol, "PROTOCOL_ERROR"}, {0xd, "HTTP_1_1_REQUIRED"}, {0xf, "unknown error code 0xf"}, } diff --git a/h2_flow.go b/internal/http2/flow.go similarity index 77% rename from h2_flow.go rename to internal/http2/flow.go index 3259a197..2689f132 100644 --- a/h2_flow.go +++ b/internal/http2/flow.go @@ -2,11 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 // flow is the flow control window's size. -type http2flow struct { - _ http2incomparable +type flow struct { + _ incomparable // n is the number of DATA bytes we're allowed to send. // A flow is kept both on a conn and a per-stream. @@ -15,12 +15,12 @@ type http2flow struct { // conn points to the shared connection-level flow that is // shared by all streams on that conn. It is nil for the flow // that's on the conn directly. - conn *http2flow + conn *flow } -func (f *http2flow) setConnFlow(cf *http2flow) { f.conn = cf } +func (f *flow) setConnFlow(cf *flow) { f.conn = cf } -func (f *http2flow) available() int32 { +func (f *flow) available() int32 { n := f.n if f.conn != nil && f.conn.n < n { n = f.conn.n @@ -28,7 +28,7 @@ func (f *http2flow) available() int32 { return n } -func (f *http2flow) take(n int32) { +func (f *flow) take(n int32) { if n > f.available() { panic("internal error: took too much") } @@ -40,7 +40,7 @@ func (f *http2flow) take(n int32) { // add adds n bytes (positive or negative) to the flow control window. // It returns false if the sum would exceed 2^31-1. -func (f *http2flow) add(n int32) bool { +func (f *flow) add(n int32) bool { sum := f.n + n if (sum > n) == (f.n > 0) { f.n = sum diff --git a/h2_flow_test.go b/internal/http2/flow_test.go similarity index 95% rename from h2_flow_test.go rename to internal/http2/flow_test.go index 2a229be6..7ae82c78 100644 --- a/h2_flow_test.go +++ b/internal/http2/flow_test.go @@ -2,13 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import "testing" func TestFlow(t *testing.T) { - var st http2flow - var conn http2flow + var st flow + var conn flow st.add(3) conn.add(2) @@ -30,7 +30,7 @@ func TestFlow(t *testing.T) { } func TestFlowAdd(t *testing.T) { - var f http2flow + var f flow if !f.add(1) { t.Fatal("failed to add 1") } @@ -52,7 +52,7 @@ func TestFlowAdd(t *testing.T) { } func TestFlowAddOverflow(t *testing.T) { - var f http2flow + var f flow if !f.add(0) { t.Fatal("failed to add 0") } diff --git a/h2_frame.go b/internal/http2/frame.go similarity index 63% rename from h2_frame.go rename to internal/http2/frame.go index f910294f..e91df29c 100644 --- a/h2_frame.go +++ b/internal/http2/frame.go @@ -2,13 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "bytes" "encoding/binary" "errors" "fmt" + "github.com/imroc/req/v3/internal/dump" "golang.org/x/net/http/httpguts" "golang.org/x/net/http2/hpack" "io" @@ -17,42 +18,42 @@ import ( "sync" ) -const http2frameHeaderLen = 9 +const frameHeaderLen = 9 -var http2padZeros = make([]byte, 255) // zeros for padding +var padZeros = make([]byte, 255) // zeros for padding // A FrameType is a registered frame type as defined in // http://http2.github.io/http2-spec/#rfc.section.11.2 -type http2FrameType uint8 +type FrameType uint8 const ( - http2FrameData http2FrameType = 0x0 - http2FrameHeaders http2FrameType = 0x1 - http2FramePriority http2FrameType = 0x2 - http2FrameRSTStream http2FrameType = 0x3 - http2FrameSettings http2FrameType = 0x4 - http2FramePushPromise http2FrameType = 0x5 - http2FramePing http2FrameType = 0x6 - http2FrameGoAway http2FrameType = 0x7 - http2FrameWindowUpdate http2FrameType = 0x8 - http2FrameContinuation http2FrameType = 0x9 + FrameData FrameType = 0x0 + FrameHeaders FrameType = 0x1 + FramePriority FrameType = 0x2 + FrameRSTStream FrameType = 0x3 + FrameSettings FrameType = 0x4 + FramePushPromise FrameType = 0x5 + FramePing FrameType = 0x6 + FrameGoAway FrameType = 0x7 + FrameWindowUpdate FrameType = 0x8 + FrameContinuation FrameType = 0x9 ) -var http2frameName = map[http2FrameType]string{ - http2FrameData: "DATA", - http2FrameHeaders: "HEADERS", - http2FramePriority: "PRIORITY", - http2FrameRSTStream: "RST_STREAM", - http2FrameSettings: "SETTINGS", - http2FramePushPromise: "PUSH_PROMISE", - http2FramePing: "PING", - http2FrameGoAway: "GOAWAY", - http2FrameWindowUpdate: "WINDOW_UPDATE", - http2FrameContinuation: "CONTINUATION", -} - -func (t http2FrameType) String() string { - if s, ok := http2frameName[t]; ok { +var frameName = map[FrameType]string{ + FrameData: "DATA", + FrameHeaders: "HEADERS", + FramePriority: "PRIORITY", + FrameRSTStream: "RST_STREAM", + FrameSettings: "SETTINGS", + FramePushPromise: "PUSH_PROMISE", + FramePing: "PING", + FrameGoAway: "GOAWAY", + FrameWindowUpdate: "WINDOW_UPDATE", + FrameContinuation: "CONTINUATION", +} + +func (t FrameType) String() string { + if s, ok := frameName[t]; ok { return s } return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t)) @@ -60,103 +61,103 @@ func (t http2FrameType) String() string { // Flags is a bitmask of HTTP/2 flags. // The meaning of flags varies depending on the frame type. -type http2Flags uint8 +type Flags uint8 // Has reports whether f contains all (0 or more) flags in v. -func (f http2Flags) Has(v http2Flags) bool { +func (f Flags) Has(v Flags) bool { return (f & v) == v } // Frame-specific FrameHeader flag bits. const ( // Data Frame - http2FlagDataEndStream http2Flags = 0x1 - http2FlagDataPadded http2Flags = 0x8 + FlagDataEndStream Flags = 0x1 + FlagDataPadded Flags = 0x8 // Headers Frame - http2FlagHeadersEndStream http2Flags = 0x1 - http2FlagHeadersEndHeaders http2Flags = 0x4 - http2FlagHeadersPadded http2Flags = 0x8 - http2FlagHeadersPriority http2Flags = 0x20 + FlagHeadersEndStream Flags = 0x1 + FlagHeadersEndHeaders Flags = 0x4 + FlagHeadersPadded Flags = 0x8 + FlagHeadersPriority Flags = 0x20 // Settings Frame - http2FlagSettingsAck http2Flags = 0x1 + FlagSettingsAck Flags = 0x1 // Ping Frame - http2FlagPingAck http2Flags = 0x1 + FlagPingAck Flags = 0x1 // Continuation Frame - http2FlagContinuationEndHeaders http2Flags = 0x4 + FlagContinuationEndHeaders Flags = 0x4 - http2FlagPushPromiseEndHeaders http2Flags = 0x4 - http2FlagPushPromisePadded http2Flags = 0x8 + FlagPushPromiseEndHeaders Flags = 0x4 + FlagPushPromisePadded Flags = 0x8 ) -var http2flagName = map[http2FrameType]map[http2Flags]string{ - http2FrameData: { - http2FlagDataEndStream: "END_STREAM", - http2FlagDataPadded: "PADDED", +var flagName = map[FrameType]map[Flags]string{ + FrameData: { + FlagDataEndStream: "END_STREAM", + FlagDataPadded: "PADDED", }, - http2FrameHeaders: { - http2FlagHeadersEndStream: "END_STREAM", - http2FlagHeadersEndHeaders: "END_HEADERS", - http2FlagHeadersPadded: "PADDED", - http2FlagHeadersPriority: "PRIORITY", + FrameHeaders: { + FlagHeadersEndStream: "END_STREAM", + FlagHeadersEndHeaders: "END_HEADERS", + FlagHeadersPadded: "PADDED", + FlagHeadersPriority: "PRIORITY", }, - http2FrameSettings: { - http2FlagSettingsAck: "ACK", + FrameSettings: { + FlagSettingsAck: "ACK", }, - http2FramePing: { - http2FlagPingAck: "ACK", + FramePing: { + FlagPingAck: "ACK", }, - http2FrameContinuation: { - http2FlagContinuationEndHeaders: "END_HEADERS", + FrameContinuation: { + FlagContinuationEndHeaders: "END_HEADERS", }, - http2FramePushPromise: { - http2FlagPushPromiseEndHeaders: "END_HEADERS", - http2FlagPushPromisePadded: "PADDED", + FramePushPromise: { + FlagPushPromiseEndHeaders: "END_HEADERS", + FlagPushPromisePadded: "PADDED", }, } // a frameParser parses a frame given its FrameHeader and payload // bytes. The length of payload will always equal fh.Length (which // might be 0). -type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) - -var http2frameParsers = map[http2FrameType]http2frameParser{ - http2FrameData: http2parseDataFrame, - http2FrameHeaders: http2parseHeadersFrame, - http2FramePriority: http2parsePriorityFrame, - http2FrameRSTStream: http2parseRSTStreamFrame, - http2FrameSettings: http2parseSettingsFrame, - http2FramePushPromise: http2parsePushPromise, - http2FramePing: http2parsePingFrame, - http2FrameGoAway: http2parseGoAwayFrame, - http2FrameWindowUpdate: http2parseWindowUpdateFrame, - http2FrameContinuation: http2parseContinuationFrame, -} - -func http2typeFrameParser(t http2FrameType) http2frameParser { - if f := http2frameParsers[t]; f != nil { +type frameParser func(fc *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) + +var frameParsers = map[FrameType]frameParser{ + FrameData: parseDataFrame, + FrameHeaders: parseHeadersFrame, + FramePriority: parsePriorityFrame, + FrameRSTStream: parseRSTStreamFrame, + FrameSettings: parseSettingsFrame, + FramePushPromise: parsePushPromise, + FramePing: parsePingFrame, + FrameGoAway: parseGoAwayFrame, + FrameWindowUpdate: parseWindowUpdateFrame, + FrameContinuation: parseContinuationFrame, +} + +func typeFrameParser(t FrameType) frameParser { + if f := frameParsers[t]; f != nil { return f } - return http2parseUnknownFrame + return parseUnknownFrame } // A FrameHeader is the 9 byte header of all HTTP/2 frames. // // See http://http2.github.io/http2-spec/#FrameHeader -type http2FrameHeader struct { +type FrameHeader struct { valid bool // caller can access []byte fields in the Frame // Type is the 1 byte frame type. There are ten standard frame // types, but extension frame types may be written by WriteRawFrame // and will be returned by ReadFrame (as UnknownFrame). - Type http2FrameType + Type FrameType // Flags are the 1 byte of 8 potential bit flags per frame. // They are specific to the frame type. - Flags http2Flags + Flags Flags // Length is the length of the frame, not including the 9 byte header. // The maximum size is one byte less than 16MB (uint24), but only @@ -170,9 +171,9 @@ type http2FrameHeader struct { // Header returns h. It exists so FrameHeaders can be embedded in other // specific frame types and implement the Frame interface. -func (h http2FrameHeader) Header() http2FrameHeader { return h } +func (h FrameHeader) Header() FrameHeader { return h } -func (h http2FrameHeader) String() string { +func (h FrameHeader) String() string { var buf bytes.Buffer buf.WriteString("[FrameHeader ") h.writeDebug(&buf) @@ -180,7 +181,7 @@ func (h http2FrameHeader) String() string { return buf.String() } -func (h http2FrameHeader) writeDebug(buf *bytes.Buffer) { +func (h FrameHeader) writeDebug(buf *bytes.Buffer) { buf.WriteString(h.Type.String()) if h.Flags != 0 { buf.WriteString(" flags=") @@ -193,7 +194,7 @@ func (h http2FrameHeader) writeDebug(buf *bytes.Buffer) { if set > 1 { buf.WriteByte('|') } - name := http2flagName[h.Type][http2Flags(1<= (1 << 24) { return errFrameTooLarge } @@ -375,10 +376,10 @@ func (h2f *http2Framer) endWrite() error { return err } -func (h2f *http2Framer) logWrite() { +func (h2f *Framer) logWrite() { if h2f.debugFramer == nil { h2f.debugFramerBuf = new(bytes.Buffer) - h2f.debugFramer = http2NewFramer(nil, h2f.debugFramerBuf) + h2f.debugFramer = NewFramer(nil, h2f.debugFramerBuf) h2f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below // Let us read anything, even if we accidentally wrote it // in the wrong order: @@ -390,53 +391,53 @@ func (h2f *http2Framer) logWrite() { h2f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", h2f) return } - h2f.debugWriteLoggerf("http2: Framer %p: wrote %v", h2f, http2summarizeFrame(fr)) + h2f.debugWriteLoggerf("http2: Framer %p: wrote %v", h2f, summarizeFrame(fr)) } -func (h2f *http2Framer) writeByte(v byte) { h2f.wbuf = append(h2f.wbuf, v) } +func (h2f *Framer) writeByte(v byte) { h2f.wbuf = append(h2f.wbuf, v) } -func (h2f *http2Framer) writeBytes(v []byte) { h2f.wbuf = append(h2f.wbuf, v...) } +func (h2f *Framer) writeBytes(v []byte) { h2f.wbuf = append(h2f.wbuf, v...) } -func (h2f *http2Framer) writeUint16(v uint16) { h2f.wbuf = append(h2f.wbuf, byte(v>>8), byte(v)) } +func (h2f *Framer) writeUint16(v uint16) { h2f.wbuf = append(h2f.wbuf, byte(v>>8), byte(v)) } -func (h2f *http2Framer) writeUint32(v uint32) { +func (h2f *Framer) writeUint32(v uint32) { h2f.wbuf = append(h2f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) } const ( - http2minMaxFrameSize = 1 << 14 - http2maxFrameSize = 1<<24 - 1 + minMaxFrameSize = 1 << 14 + maxFrameSize = 1<<24 - 1 ) // SetReuseFrames allows the Framer to reuse Frames. // If called on a Framer, Frames returned by calls to ReadFrame are only // valid until the next call to ReadFrame. -func (h2f *http2Framer) SetReuseFrames() { +func (h2f *Framer) SetReuseFrames() { if h2f.frameCache != nil { return } - h2f.frameCache = &http2frameCache{} + h2f.frameCache = &frameCache{} } -type http2frameCache struct { - dataFrame http2DataFrame +type frameCache struct { + dataFrame DataFrame } -func (fc *http2frameCache) getDataFrame() *http2DataFrame { +func (fc *frameCache) getDataFrame() *DataFrame { if fc == nil { - return &http2DataFrame{} + return &DataFrame{} } return &fc.dataFrame } // NewFramer returns a Framer that writes frames to w and reads them from r. -func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { - fr := &http2Framer{ +func NewFramer(w io.Writer, r io.Reader) *Framer { + fr := &Framer{ w: w, r: r, countError: func(string) {}, - logReads: http2logFrameReads, - logWrites: http2logFrameWrites, + logReads: logFrameReads, + logWrites: logFrameWrites, debugReadLoggerf: log.Printf, debugWriteLoggerf: log.Printf, } @@ -447,7 +448,7 @@ func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { fr.readBuf = make([]byte, size) return fr.readBuf } - fr.SetMaxReadFrameSize(http2maxFrameSize) + fr.SetMaxReadFrameSize(maxFrameSize) return fr } @@ -455,9 +456,9 @@ func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { // that will be read by a subsequent call to ReadFrame. // It is the caller's responsibility to advertise this // limit with a SETTINGS frame. -func (h2f *http2Framer) SetMaxReadFrameSize(v uint32) { - if v > http2maxFrameSize { - v = http2maxFrameSize +func (h2f *Framer) SetMaxReadFrameSize(v uint32) { + if v > maxFrameSize { + v = maxFrameSize } h2f.maxReadSize = v } @@ -469,7 +470,7 @@ func (h2f *http2Framer) SetMaxReadFrameSize(v uint32) { // to return a non-nil value and like the rest of the http2 package, // its return value is not protected by an API compatibility promise. // ErrorDetail is reset after the next call to ReadFrame. -func (h2f *http2Framer) ErrorDetail() error { +func (h2f *Framer) ErrorDetail() error { return h2f.errDetail } @@ -479,8 +480,8 @@ var errFrameTooLarge = errors.New("http2: frame too large") // terminalReadFrameError reports whether err is an unrecoverable // error from ReadFrame and no other frames should be read. -func http2terminalReadFrameError(err error) bool { - if _, ok := err.(http2StreamError); ok { +func terminalReadFrameError(err error) bool { + if _, ok := err.(StreamError); ok { return false } return err != nil @@ -493,12 +494,12 @@ func http2terminalReadFrameError(err error) bool { // returned error is errFrameTooLarge. Other errors may be of type // ConnectionError, StreamError, or anything else from the underlying // reader. -func (h2f *http2Framer) ReadFrame() (http2Frame, error) { +func (h2f *Framer) ReadFrame() (Frame, error) { h2f.errDetail = nil if h2f.lastFrame != nil { h2f.lastFrame.invalidate() } - fh, err := http2readFrameHeader(h2f.headerBuf[:], h2f.r) + fh, err := readFrameHeader(h2f.headerBuf[:], h2f.r) if err != nil { return nil, err } @@ -509,9 +510,9 @@ func (h2f *http2Framer) ReadFrame() (http2Frame, error) { if _, err := io.ReadFull(h2f.r, payload); err != nil { return nil, err } - f, err := http2typeFrameParser(fh.Type)(h2f.frameCache, fh, h2f.countError, payload) + f, err := typeFrameParser(fh.Type)(h2f.frameCache, fh, h2f.countError, payload) if err != nil { - if ce, ok := err.(http2connError); ok { + if ce, ok := err.(connError); ok { return nil, h2f.connError(ce.Code, ce.Reason) } return nil, err @@ -520,26 +521,26 @@ func (h2f *http2Framer) ReadFrame() (http2Frame, error) { return nil, err } if h2f.logReads { - h2f.debugReadLoggerf("http2: Framer %p: read %v", h2f, http2summarizeFrame(f)) + h2f.debugReadLoggerf("http2: Framer %p: read %v", h2f, summarizeFrame(f)) } - if fh.Type == http2FrameHeaders && h2f.ReadMetaHeaders != nil { - var dumps []*dumper - if h2f.cc != nil && h2f.cc.t.t1 != nil { - dumps = getDumpers(h2f.cc.currentRequest.Context(), h2f.cc.t.t1.dump) + if fh.Type == FrameHeaders && h2f.ReadMetaHeaders != nil { + var dumps []*dump.Dumper + if h2f.cc != nil { + dumps = dump.GetDumpers(h2f.cc.currentRequest.Context(), h2f.cc.t.Dump()) } if len(dumps) > 0 { - dd := []*dumper{} + dd := []*dump.Dumper{} for _, dump := range dumps { - if dump.ResponseHeader { + if dump.ResponseHeader() { dd = append(dd, dump) } } dumps = dd } - hr, err := h2f.readMetaFrame(f.(*http2HeadersFrame), dumps) + hr, err := h2f.readMetaFrame(f.(*HeadersFrame), dumps) if err == nil && len(dumps) > 0 { for _, dump := range dumps { - dump.dump([]byte("\r\n")) + dump.Dump([]byte("\r\n")) } } return hr, err @@ -551,15 +552,15 @@ func (h2f *http2Framer) ReadFrame() (http2Frame, error) { // stashes away a public reason to the caller can optionally relay it // to the peer before hanging up on them. This might help others debug // their implementations. -func (h2f *http2Framer) connError(code http2ErrCode, reason string) error { +func (h2f *Framer) connError(code ErrCode, reason string) error { h2f.errDetail = errors.New(reason) - return http2ConnectionError(code) + return ConnectionError(code) } // checkFrameOrder reports an error if f is an invalid frame to return // next from ReadFrame. Mostly it checks whether HEADERS and // CONTINUATION frames are contiguous. -func (h2f *http2Framer) checkFrameOrder(f http2Frame) error { +func (h2f *Framer) checkFrameOrder(f Frame) error { last := h2f.lastFrame h2f.lastFrame = f if h2f.AllowIllegalReads { @@ -568,24 +569,24 @@ func (h2f *http2Framer) checkFrameOrder(f http2Frame) error { fh := f.Header() if h2f.lastHeaderStream != 0 { - if fh.Type != http2FrameContinuation { - return h2f.connError(http2ErrCodeProtocol, + if fh.Type != FrameContinuation { + return h2f.connError(ErrCodeProtocol, fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d", fh.Type, fh.StreamID, last.Header().Type, h2f.lastHeaderStream)) } if fh.StreamID != h2f.lastHeaderStream { - return h2f.connError(http2ErrCodeProtocol, + return h2f.connError(ErrCodeProtocol, fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d", fh.StreamID, h2f.lastHeaderStream)) } - } else if fh.Type == http2FrameContinuation { - return h2f.connError(http2ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID)) + } else if fh.Type == FrameContinuation { + return h2f.connError(ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID)) } switch fh.Type { - case http2FrameHeaders, http2FrameContinuation: - if fh.Flags.Has(http2FlagHeadersEndHeaders) { + case FrameHeaders, FrameContinuation: + if fh.Flags.Has(FlagHeadersEndHeaders) { h2f.lastHeaderStream = 0 } else { h2f.lastHeaderStream = fh.StreamID @@ -598,25 +599,25 @@ func (h2f *http2Framer) checkFrameOrder(f http2Frame) error { // A DataFrame conveys arbitrary, variable-length sequences of octets // associated with a stream. // See http://http2.github.io/http2-spec/#rfc.section.6.1 -type http2DataFrame struct { - http2FrameHeader +type DataFrame struct { + FrameHeader data []byte } -func (f *http2DataFrame) StreamEnded() bool { - return f.http2FrameHeader.Flags.Has(http2FlagDataEndStream) +func (f *DataFrame) StreamEnded() bool { + return f.FrameHeader.Flags.Has(FlagDataEndStream) } // Data returns the frame's data octets, not including any padding // size byte or padding suffix bytes. // The caller must not retain the returned memory past the next // call to ReadFrame. -func (f *http2DataFrame) Data() []byte { +func (f *DataFrame) Data() []byte { f.checkValid() return f.data } -func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { +func parseDataFrame(fc *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) { if fh.StreamID == 0 { // DATA frames MUST be associated with a stream. If a // DATA frame is received whose stream identifier @@ -624,15 +625,15 @@ func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError fu // connection error (Section 5.4.1) of type // PROTOCOL_ERROR. countError("frame_data_stream_0") - return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"} + return nil, connError{ErrCodeProtocol, "DATA frame with stream ID 0"} } f := fc.getDataFrame() - f.http2FrameHeader = fh + f.FrameHeader = fh var padSize byte - if fh.Flags.Has(http2FlagDataPadded) { + if fh.Flags.Has(FlagDataPadded) { var err error - payload, padSize, err = http2readByte(payload) + payload, padSize, err = readByte(payload) if err != nil { countError("frame_data_pad_byte_short") return nil, err @@ -644,7 +645,7 @@ func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError fu // treat this as a connection error. // Filed: https://github.com/http2/http2-spec/issues/610 countError("frame_data_pad_too_big") - return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"} + return nil, connError{ErrCodeProtocol, "pad size larger than data payload"} } f.data = payload[:len(payload)-int(padSize)] return f, nil @@ -657,11 +658,11 @@ var ( errPadBytes = errors.New("padding bytes must all be zeros unless AllowIllegalWrites is enabled") ) -func http2validStreamIDOrZero(streamID uint32) bool { +func validStreamIDOrZero(streamID uint32) bool { return streamID&(1<<31) == 0 } -func http2validStreamID(streamID uint32) bool { +func validStreamID(streamID uint32) bool { return streamID != 0 && streamID&(1<<31) == 0 } @@ -670,7 +671,7 @@ func http2validStreamID(streamID uint32) bool { // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility not to violate the maximum frame size // and to not call other Write methods concurrently. -func (h2f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error { +func (h2f *Framer) WriteData(streamID uint32, endStream bool, data []byte) error { return h2f.WriteDataPadded(streamID, endStream, data, nil) } @@ -683,8 +684,8 @@ func (h2f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility not to violate the maximum frame size // and to not call other Write methods concurrently. -func (h2f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { - if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { +func (h2f *Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { + if !validStreamID(streamID) && !h2f.AllowIllegalWrites { return errStreamID } if len(pad) > 0 { @@ -700,14 +701,14 @@ func (h2f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, p } } } - var flags http2Flags + var flags Flags if endStream { - flags |= http2FlagDataEndStream + flags |= FlagDataEndStream } if pad != nil { - flags |= http2FlagDataPadded + flags |= FlagDataPadded } - h2f.startWrite(http2FrameData, flags, streamID) + h2f.startWrite(FrameData, flags, streamID) if pad != nil { h2f.wbuf = append(h2f.wbuf, byte(len(pad))) } @@ -721,13 +722,13 @@ func (h2f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, p // behavior. // // See http://http2.github.io/http2-spec/#SETTINGS -type http2SettingsFrame struct { - http2FrameHeader +type SettingsFrame struct { + FrameHeader p []byte } -func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { - if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 { +func parseSettingsFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) { + if fh.Flags.Has(FlagSettingsAck) && fh.Length > 0 { // When this (ACK 0x1) bit is set, the payload of the // SETTINGS frame MUST be empty. Receipt of a // SETTINGS frame with the ACK flag set and a length @@ -735,7 +736,7 @@ func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, countError // connection error (Section 5.4.1) of type // FRAME_SIZE_ERROR. countError("frame_settings_ack_with_length") - return nil, http2ConnectionError(http2ErrCodeFrameSize) + return nil, ConnectionError(ErrCodeFrameSize) } if fh.StreamID != 0 { // SETTINGS frames always apply to a connection, @@ -746,29 +747,29 @@ func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, countError // respond with a connection error (Section 5.4.1) of // type PROTOCOL_ERROR. countError("frame_settings_has_stream") - return nil, http2ConnectionError(http2ErrCodeProtocol) + return nil, ConnectionError(ErrCodeProtocol) } if len(p)%6 != 0 { countError("frame_settings_mod_6") // Expecting even number of 6 byte settings. - return nil, http2ConnectionError(http2ErrCodeFrameSize) + return nil, ConnectionError(ErrCodeFrameSize) } - f := &http2SettingsFrame{http2FrameHeader: fh, p: p} - if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 { + f := &SettingsFrame{FrameHeader: fh, p: p} + if v, ok := f.Value(SettingInitialWindowSize); ok && v > (1<<31)-1 { countError("frame_settings_window_size_too_big") // Values above the maximum flow control window size of 2^31 - 1 MUST // be treated as a connection error (Section 5.4.1) of type // FLOW_CONTROL_ERROR. - return nil, http2ConnectionError(http2ErrCodeFlowControl) + return nil, ConnectionError(ErrCodeFlowControl) } return f, nil } -func (f *http2SettingsFrame) IsAck() bool { - return f.http2FrameHeader.Flags.Has(http2FlagSettingsAck) +func (f *SettingsFrame) IsAck() bool { + return f.FrameHeader.Flags.Has(FlagSettingsAck) } -func (f *http2SettingsFrame) Value(id http2SettingID) (v uint32, ok bool) { +func (f *SettingsFrame) Value(id SettingID) (v uint32, ok bool) { f.checkValid() for i := 0; i < f.NumSettings(); i++ { if s := f.Setting(i); s.ID == id { @@ -780,18 +781,18 @@ func (f *http2SettingsFrame) Value(id http2SettingID) (v uint32, ok bool) { // Setting returns the setting from the frame at the given 0-based index. // The index must be >= 0 and less than f.NumSettings(). -func (f *http2SettingsFrame) Setting(i int) http2Setting { +func (f *SettingsFrame) Setting(i int) Setting { buf := f.p - return http2Setting{ - ID: http2SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])), + return Setting{ + ID: SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])), Val: binary.BigEndian.Uint32(buf[i*6+2 : i*6+6]), } } -func (f *http2SettingsFrame) NumSettings() int { return len(f.p) / 6 } +func (f *SettingsFrame) NumSettings() int { return len(f.p) / 6 } // HasDuplicates reports whether f contains any duplicate setting IDs. -func (f *http2SettingsFrame) HasDuplicates() bool { +func (f *SettingsFrame) HasDuplicates() bool { num := f.NumSettings() if num == 0 { return false @@ -810,7 +811,7 @@ func (f *http2SettingsFrame) HasDuplicates() bool { } return false } - seen := map[http2SettingID]bool{} + seen := map[SettingID]bool{} for i := 0; i < num; i++ { id := f.Setting(i).ID if seen[id] { @@ -823,7 +824,7 @@ func (f *http2SettingsFrame) HasDuplicates() bool { // ForeachSetting runs fn for each setting. // It stops and returns the first error. -func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error { +func (f *SettingsFrame) ForeachSetting(fn func(Setting) error) error { f.checkValid() for i := 0; i < f.NumSettings(); i++ { if err := fn(f.Setting(i)); err != nil { @@ -838,8 +839,8 @@ func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error { // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (h2f *http2Framer) WriteSettings(settings ...http2Setting) error { - h2f.startWrite(http2FrameSettings, 0, 0) +func (h2f *Framer) WriteSettings(settings ...Setting) error { + h2f.startWrite(FrameSettings, 0, 0) for _, s := range settings { h2f.writeUint16(uint16(s.ID)) h2f.writeUint32(s.Val) @@ -851,8 +852,8 @@ func (h2f *http2Framer) WriteSettings(settings ...http2Setting) error { // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (h2f *http2Framer) WriteSettingsAck() error { - h2f.startWrite(http2FrameSettings, http2FlagSettingsAck, 0) +func (h2f *Framer) WriteSettingsAck() error { + h2f.startWrite(FrameSettings, FlagSettingsAck, 0) return h2f.endWrite() } @@ -860,43 +861,43 @@ func (h2f *http2Framer) WriteSettingsAck() error { // from the sender, as well as determining whether an idle connection // is still functional. // See http://http2.github.io/http2-spec/#rfc.section.6.7 -type http2PingFrame struct { - http2FrameHeader +type PingFrame struct { + FrameHeader Data [8]byte } -func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) } +func (f *PingFrame) IsAck() bool { return f.Flags.Has(FlagPingAck) } -func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { +func parsePingFrame(_ *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) { if len(payload) != 8 { countError("frame_ping_length") - return nil, http2ConnectionError(http2ErrCodeFrameSize) + return nil, ConnectionError(ErrCodeFrameSize) } if fh.StreamID != 0 { countError("frame_ping_has_stream") - return nil, http2ConnectionError(http2ErrCodeProtocol) + return nil, ConnectionError(ErrCodeProtocol) } - f := &http2PingFrame{http2FrameHeader: fh} + f := &PingFrame{FrameHeader: fh} copy(f.Data[:], payload) return f, nil } -func (h2f *http2Framer) WritePing(ack bool, data [8]byte) error { - var flags http2Flags +func (h2f *Framer) WritePing(ack bool, data [8]byte) error { + var flags Flags if ack { - flags = http2FlagPingAck + flags = FlagPingAck } - h2f.startWrite(http2FramePing, flags, 0) + h2f.startWrite(FramePing, flags, 0) h2f.writeBytes(data[:]) return h2f.endWrite() } // A GoAwayFrame informs the remote peer to stop creating streams on this connection. // See http://http2.github.io/http2-spec/#rfc.section.6.8 -type http2GoAwayFrame struct { - http2FrameHeader +type GoAwayFrame struct { + FrameHeader LastStreamID uint32 - ErrCode http2ErrCode + ErrCode ErrCode debugData []byte } @@ -904,30 +905,30 @@ type http2GoAwayFrame struct { // are not defined. // The caller must not retain the returned memory past the next // call to ReadFrame. -func (f *http2GoAwayFrame) DebugData() []byte { +func (f *GoAwayFrame) DebugData() []byte { f.checkValid() return f.debugData } -func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { +func parseGoAwayFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) { if fh.StreamID != 0 { countError("frame_goaway_has_stream") - return nil, http2ConnectionError(http2ErrCodeProtocol) + return nil, ConnectionError(ErrCodeProtocol) } if len(p) < 8 { countError("frame_goaway_short") - return nil, http2ConnectionError(http2ErrCodeFrameSize) + return nil, ConnectionError(ErrCodeFrameSize) } - return &http2GoAwayFrame{ - http2FrameHeader: fh, - LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1), - ErrCode: http2ErrCode(binary.BigEndian.Uint32(p[4:8])), - debugData: p[8:], + return &GoAwayFrame{ + FrameHeader: fh, + LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1), + ErrCode: ErrCode(binary.BigEndian.Uint32(p[4:8])), + debugData: p[8:], }, nil } -func (h2f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debugData []byte) error { - h2f.startWrite(http2FrameGoAway, 0, 0) +func (h2f *Framer) WriteGoAway(maxStreamID uint32, code ErrCode, debugData []byte) error { + h2f.startWrite(FrameGoAway, 0, 0) h2f.writeUint32(maxStreamID & (1<<31 - 1)) h2f.writeUint32(uint32(code)) h2f.writeBytes(debugData) @@ -936,8 +937,8 @@ func (h2f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debug // An UnknownFrame is the frame type returned when the frame type is unknown // or no specific frame type parser exists. -type http2UnknownFrame struct { - http2FrameHeader +type UnknownFrame struct { + FrameHeader p []byte } @@ -946,26 +947,26 @@ type http2UnknownFrame struct { // Framer.ReadFrame, nor is it valid to retain the returned slice. // The memory is owned by the Framer and is invalidated when the next // frame is read. -func (f *http2UnknownFrame) Payload() []byte { +func (f *UnknownFrame) Payload() []byte { f.checkValid() return f.p } -func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { - return &http2UnknownFrame{fh, p}, nil +func parseUnknownFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) { + return &UnknownFrame{fh, p}, nil } // A WindowUpdateFrame is used to implement flow control. // See http://http2.github.io/http2-spec/#rfc.section.6.9 -type http2WindowUpdateFrame struct { - http2FrameHeader +type WindowUpdateFrame struct { + FrameHeader Increment uint32 // never read with high bit set } -func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { +func parseWindowUpdateFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) { if len(p) != 4 { countError("frame_windowupdate_bad_len") - return nil, http2ConnectionError(http2ErrCodeFrameSize) + return nil, ConnectionError(ErrCodeFrameSize) } inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit if inc == 0 { @@ -977,14 +978,14 @@ func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countE // error (Section 5.4.1). if fh.StreamID == 0 { countError("frame_windowupdate_zero_inc_conn") - return nil, http2ConnectionError(http2ErrCodeProtocol) + return nil, ConnectionError(ErrCodeProtocol) } countError("frame_windowupdate_zero_inc_stream") - return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) + return nil, streamError(fh.StreamID, ErrCodeProtocol) } - return &http2WindowUpdateFrame{ - http2FrameHeader: fh, - Increment: inc, + return &WindowUpdateFrame{ + FrameHeader: fh, + Increment: inc, }, nil } @@ -992,47 +993,47 @@ func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countE // The increment value must be between 1 and 2,147,483,647, inclusive. // If the Stream ID is zero, the window update applies to the // connection as a whole. -func (h2f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error { +func (h2f *Framer) WriteWindowUpdate(streamID, incr uint32) error { // "The legal range for the increment to the flow control window is 1 to 2^31-1 (2,147,483,647) octets." if (incr < 1 || incr > 2147483647) && !h2f.AllowIllegalWrites { return errors.New("illegal window increment value") } - h2f.startWrite(http2FrameWindowUpdate, 0, streamID) + h2f.startWrite(FrameWindowUpdate, 0, streamID) h2f.writeUint32(incr) return h2f.endWrite() } // A HeadersFrame is used to open a stream and additionally carries a // header block fragment. -type http2HeadersFrame struct { - http2FrameHeader +type HeadersFrame struct { + FrameHeader // Priority is set if FlagHeadersPriority is set in the FrameHeader. - Priority http2PriorityParam + Priority PriorityParam headerFragBuf []byte // not owned } -func (f *http2HeadersFrame) HeaderBlockFragment() []byte { +func (f *HeadersFrame) HeaderBlockFragment() []byte { f.checkValid() return f.headerFragBuf } -func (f *http2HeadersFrame) HeadersEnded() bool { - return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndHeaders) +func (f *HeadersFrame) HeadersEnded() bool { + return f.FrameHeader.Flags.Has(FlagHeadersEndHeaders) } -func (f *http2HeadersFrame) StreamEnded() bool { - return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndStream) +func (f *HeadersFrame) StreamEnded() bool { + return f.FrameHeader.Flags.Has(FlagHeadersEndStream) } -func (f *http2HeadersFrame) HasPriority() bool { - return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority) +func (f *HeadersFrame) HasPriority() bool { + return f.FrameHeader.Flags.Has(FlagHeadersPriority) } -func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { - hf := &http2HeadersFrame{ - http2FrameHeader: fh, +func parseHeadersFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (_ Frame, err error) { + hf := &HeadersFrame{ + FrameHeader: fh, } if fh.StreamID == 0 { // HEADERS frames MUST be associated with a stream. If a HEADERS frame @@ -1040,25 +1041,25 @@ func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, countError // respond with a connection error (Section 5.4.1) of type // PROTOCOL_ERROR. countError("frame_headers_zero_stream") - return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"} + return nil, connError{ErrCodeProtocol, "HEADERS frame with stream ID 0"} } var padLength uint8 - if fh.Flags.Has(http2FlagHeadersPadded) { - if p, padLength, err = http2readByte(p); err != nil { + if fh.Flags.Has(FlagHeadersPadded) { + if p, padLength, err = readByte(p); err != nil { countError("frame_headers_pad_short") return } } - if fh.Flags.Has(http2FlagHeadersPriority) { + if fh.Flags.Has(FlagHeadersPriority) { var v uint32 - p, v, err = http2readUint32(p) + p, v, err = readUint32(p) if err != nil { countError("frame_headers_prio_short") return nil, err } hf.Priority.StreamDep = v & 0x7fffffff hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set - p, hf.Priority.Weight, err = http2readByte(p) + p, hf.Priority.Weight, err = readByte(p) if err != nil { countError("frame_headers_prio_weight_short") return nil, err @@ -1066,14 +1067,14 @@ func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, countError } if len(p)-int(padLength) < 0 { countError("frame_headers_pad_too_big") - return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) + return nil, streamError(fh.StreamID, ErrCodeProtocol) } hf.headerFragBuf = p[:len(p)-int(padLength)] return hf, nil } // HeadersFrameParam are the parameters for writing a HEADERS frame. -type http2HeadersFrameParam struct { +type HeadersFrameParam struct { // StreamID is the required Stream ID to initiate. StreamID uint32 // BlockFragment is part (or all) of a Header Block. @@ -1096,7 +1097,7 @@ type http2HeadersFrameParam struct { // Priority, if non-zero, includes stream priority information // in the HEADER frame. - Priority http2PriorityParam + Priority PriorityParam } // WriteHeaders writes a single HEADERS frame. @@ -1107,30 +1108,30 @@ type http2HeadersFrameParam struct { // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (h2f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { - if !http2validStreamID(p.StreamID) && !h2f.AllowIllegalWrites { +func (h2f *Framer) WriteHeaders(p HeadersFrameParam) error { + if !validStreamID(p.StreamID) && !h2f.AllowIllegalWrites { return errStreamID } - var flags http2Flags + var flags Flags if p.PadLength != 0 { - flags |= http2FlagHeadersPadded + flags |= FlagHeadersPadded } if p.EndStream { - flags |= http2FlagHeadersEndStream + flags |= FlagHeadersEndStream } if p.EndHeaders { - flags |= http2FlagHeadersEndHeaders + flags |= FlagHeadersEndHeaders } if !p.Priority.IsZero() { - flags |= http2FlagHeadersPriority + flags |= FlagHeadersPriority } - h2f.startWrite(http2FrameHeaders, flags, p.StreamID) + h2f.startWrite(FrameHeaders, flags, p.StreamID) if p.PadLength != 0 { h2f.writeByte(p.PadLength) } if !p.Priority.IsZero() { v := p.Priority.StreamDep - if !http2validStreamIDOrZero(v) && !h2f.AllowIllegalWrites { + if !validStreamIDOrZero(v) && !h2f.AllowIllegalWrites { return errDepStreamID } if p.Priority.Exclusive { @@ -1140,19 +1141,19 @@ func (h2f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { h2f.writeByte(p.Priority.Weight) } h2f.wbuf = append(h2f.wbuf, p.BlockFragment...) - h2f.wbuf = append(h2f.wbuf, http2padZeros[:p.PadLength]...) + h2f.wbuf = append(h2f.wbuf, padZeros[:p.PadLength]...) return h2f.endWrite() } // A PriorityFrame specifies the sender-advised priority of a stream. // See http://http2.github.io/http2-spec/#rfc.section.6.3 -type http2PriorityFrame struct { - http2FrameHeader - http2PriorityParam +type PriorityFrame struct { + FrameHeader + PriorityParam } // PriorityParam are the stream prioritzation parameters. -type http2PriorityParam struct { +type PriorityParam struct { // StreamDep is a 31-bit stream identifier for the // stream that this stream depends on. Zero means no // dependency. @@ -1168,24 +1169,24 @@ type http2PriorityParam struct { Weight uint8 } -func (p http2PriorityParam) IsZero() bool { - return p == http2PriorityParam{} +func (p PriorityParam) IsZero() bool { + return p == PriorityParam{} } -func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { +func parsePriorityFrame(_ *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) { if fh.StreamID == 0 { countError("frame_priority_zero_stream") - return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"} + return nil, connError{ErrCodeProtocol, "PRIORITY frame with stream ID 0"} } if len(payload) != 5 { countError("frame_priority_bad_length") - return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))} + return nil, connError{ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))} } v := binary.BigEndian.Uint32(payload[:4]) streamID := v & 0x7fffffff // mask off high bit - return &http2PriorityFrame{ - http2FrameHeader: fh, - http2PriorityParam: http2PriorityParam{ + return &PriorityFrame{ + FrameHeader: fh, + PriorityParam: PriorityParam{ Weight: payload[4], StreamDep: streamID, Exclusive: streamID != v, // was high bit set? @@ -1197,14 +1198,14 @@ func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (h2f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error { - if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { +func (h2f *Framer) WritePriority(streamID uint32, p PriorityParam) error { + if !validStreamID(streamID) && !h2f.AllowIllegalWrites { return errStreamID } - if !http2validStreamIDOrZero(p.StreamDep) { + if !validStreamIDOrZero(p.StreamDep) { return errDepStreamID } - h2f.startWrite(http2FramePriority, 0, streamID) + h2f.startWrite(FramePriority, 0, streamID) v := p.StreamDep if p.Exclusive { v |= 1 << 31 @@ -1216,97 +1217,97 @@ func (h2f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) err // A RSTStreamFrame allows for abnormal termination of a stream. // See http://http2.github.io/http2-spec/#rfc.section.6.4 -type http2RSTStreamFrame struct { - http2FrameHeader - ErrCode http2ErrCode +type RSTStreamFrame struct { + FrameHeader + ErrCode ErrCode } -func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { +func parseRSTStreamFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) { if len(p) != 4 { countError("frame_rststream_bad_len") - return nil, http2ConnectionError(http2ErrCodeFrameSize) + return nil, ConnectionError(ErrCodeFrameSize) } if fh.StreamID == 0 { countError("frame_rststream_zero_stream") - return nil, http2ConnectionError(http2ErrCodeProtocol) + return nil, ConnectionError(ErrCodeProtocol) } - return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil + return &RSTStreamFrame{fh, ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil } // WriteRSTStream writes a RST_STREAM frame. // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (h2f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error { - if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { +func (h2f *Framer) WriteRSTStream(streamID uint32, code ErrCode) error { + if !validStreamID(streamID) && !h2f.AllowIllegalWrites { return errStreamID } - h2f.startWrite(http2FrameRSTStream, 0, streamID) + h2f.startWrite(FrameRSTStream, 0, streamID) h2f.writeUint32(uint32(code)) return h2f.endWrite() } // A ContinuationFrame is used to continue a sequence of header block fragments. // See http://http2.github.io/http2-spec/#rfc.section.6.10 -type http2ContinuationFrame struct { - http2FrameHeader +type ContinuationFrame struct { + FrameHeader headerFragBuf []byte } -func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { +func parseContinuationFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) { if fh.StreamID == 0 { countError("frame_continuation_zero_stream") - return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"} + return nil, connError{ErrCodeProtocol, "CONTINUATION frame with stream ID 0"} } - return &http2ContinuationFrame{fh, p}, nil + return &ContinuationFrame{fh, p}, nil } -func (f *http2ContinuationFrame) HeaderBlockFragment() []byte { +func (f *ContinuationFrame) HeaderBlockFragment() []byte { f.checkValid() return f.headerFragBuf } -func (f *http2ContinuationFrame) HeadersEnded() bool { - return f.http2FrameHeader.Flags.Has(http2FlagContinuationEndHeaders) +func (f *ContinuationFrame) HeadersEnded() bool { + return f.FrameHeader.Flags.Has(FlagContinuationEndHeaders) } // WriteContinuation writes a CONTINUATION frame. // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (h2f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error { - if !http2validStreamID(streamID) && !h2f.AllowIllegalWrites { +func (h2f *Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error { + if !validStreamID(streamID) && !h2f.AllowIllegalWrites { return errStreamID } - var flags http2Flags + var flags Flags if endHeaders { - flags |= http2FlagContinuationEndHeaders + flags |= FlagContinuationEndHeaders } - h2f.startWrite(http2FrameContinuation, flags, streamID) + h2f.startWrite(FrameContinuation, flags, streamID) h2f.wbuf = append(h2f.wbuf, headerBlockFragment...) return h2f.endWrite() } // A PushPromiseFrame is used to initiate a server stream. // See http://http2.github.io/http2-spec/#rfc.section.6.6 -type http2PushPromiseFrame struct { - http2FrameHeader +type PushPromiseFrame struct { + FrameHeader PromiseID uint32 headerFragBuf []byte // not owned } -func (f *http2PushPromiseFrame) HeaderBlockFragment() []byte { +func (f *PushPromiseFrame) HeaderBlockFragment() []byte { f.checkValid() return f.headerFragBuf } -func (f *http2PushPromiseFrame) HeadersEnded() bool { - return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders) +func (f *PushPromiseFrame) HeadersEnded() bool { + return f.FrameHeader.Flags.Has(FlagPushPromiseEndHeaders) } -func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { - pp := &http2PushPromiseFrame{ - http2FrameHeader: fh, +func parsePushPromise(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (_ Frame, err error) { + pp := &PushPromiseFrame{ + FrameHeader: fh, } if pp.StreamID == 0 { // PUSH_PROMISE frames MUST be associated with an existing, @@ -1316,19 +1317,19 @@ func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, countError f // 0x0, a recipient MUST respond with a connection error // (Section 5.4.1) of type PROTOCOL_ERROR. countError("frame_pushpromise_zero_stream") - return nil, http2ConnectionError(http2ErrCodeProtocol) + return nil, ConnectionError(ErrCodeProtocol) } // The PUSH_PROMISE frame includes optional padding. // Padding fields and flags are identical to those defined for DATA frames var padLength uint8 - if fh.Flags.Has(http2FlagPushPromisePadded) { - if p, padLength, err = http2readByte(p); err != nil { + if fh.Flags.Has(FlagPushPromisePadded) { + if p, padLength, err = readByte(p); err != nil { countError("frame_pushpromise_pad_short") return } } - p, pp.PromiseID, err = http2readUint32(p) + p, pp.PromiseID, err = readUint32(p) if err != nil { countError("frame_pushpromise_promiseid_short") return @@ -1338,14 +1339,14 @@ func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, countError f if int(padLength) > len(p) { // like the DATA frame, error out if padding is longer than the body. countError("frame_pushpromise_pad_too_big") - return nil, http2ConnectionError(http2ErrCodeProtocol) + return nil, ConnectionError(ErrCodeProtocol) } pp.headerFragBuf = p[:len(p)-int(padLength)] return pp, nil } // PushPromiseParam are the parameters for writing a PUSH_PROMISE frame. -type http2PushPromiseParam struct { +type PushPromiseParam struct { // StreamID is the required Stream ID to initiate. StreamID uint32 @@ -1373,62 +1374,62 @@ type http2PushPromiseParam struct { // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (h2f *http2Framer) WritePushPromise(p http2PushPromiseParam) error { - if !http2validStreamID(p.StreamID) && !h2f.AllowIllegalWrites { +func (h2f *Framer) WritePushPromise(p PushPromiseParam) error { + if !validStreamID(p.StreamID) && !h2f.AllowIllegalWrites { return errStreamID } - var flags http2Flags + var flags Flags if p.PadLength != 0 { - flags |= http2FlagPushPromisePadded + flags |= FlagPushPromisePadded } if p.EndHeaders { - flags |= http2FlagPushPromiseEndHeaders + flags |= FlagPushPromiseEndHeaders } - h2f.startWrite(http2FramePushPromise, flags, p.StreamID) + h2f.startWrite(FramePushPromise, flags, p.StreamID) if p.PadLength != 0 { h2f.writeByte(p.PadLength) } - if !http2validStreamID(p.PromiseID) && !h2f.AllowIllegalWrites { + if !validStreamID(p.PromiseID) && !h2f.AllowIllegalWrites { return errStreamID } h2f.writeUint32(p.PromiseID) h2f.wbuf = append(h2f.wbuf, p.BlockFragment...) - h2f.wbuf = append(h2f.wbuf, http2padZeros[:p.PadLength]...) + h2f.wbuf = append(h2f.wbuf, padZeros[:p.PadLength]...) return h2f.endWrite() } // WriteRawFrame writes a raw frame. This can be used to write // extension frames unknown to this package. -func (h2f *http2Framer) WriteRawFrame(t http2FrameType, flags http2Flags, streamID uint32, payload []byte) error { +func (h2f *Framer) WriteRawFrame(t FrameType, flags Flags, streamID uint32, payload []byte) error { h2f.startWrite(t, flags, streamID) h2f.writeBytes(payload) return h2f.endWrite() } -func http2readByte(p []byte) (remain []byte, b byte, err error) { +func readByte(p []byte) (remain []byte, b byte, err error) { if len(p) == 0 { return nil, 0, io.ErrUnexpectedEOF } return p[1:], p[0], nil } -func http2readUint32(p []byte) (remain []byte, v uint32, err error) { +func readUint32(p []byte) (remain []byte, v uint32, err error) { if len(p) < 4 { return nil, 0, io.ErrUnexpectedEOF } return p[4:], binary.BigEndian.Uint32(p[:4]), nil } -type http2streamEnder interface { +type streamEnder interface { StreamEnded() bool } -type http2headersEnder interface { +type headersEnder interface { HeadersEnded() bool } -type http2headersOrContinuation interface { - http2headersEnder +type headersOrContinuation interface { + headersEnder HeaderBlockFragment() []byte } @@ -1438,8 +1439,8 @@ type http2headersOrContinuation interface { // // This type of frame does not appear on the wire and is only returned // by the Framer when Framer.ReadMetaHeaders is set. -type http2MetaHeadersFrame struct { - *http2HeadersFrame +type MetaHeadersFrame struct { + *HeadersFrame // Fields are the fields contained in the HEADERS and // CONTINUATION frames. The underlying slice is owned by the @@ -1461,7 +1462,7 @@ type http2MetaHeadersFrame struct { // PseudoValue returns the given pseudo header field's value. // The provided pseudo field should not contain the leading colon. -func (mh *http2MetaHeadersFrame) PseudoValue(pseudo string) string { +func (mh *MetaHeadersFrame) PseudoValue(pseudo string) string { for _, hf := range mh.Fields { if !hf.IsPseudo() { return "" @@ -1475,7 +1476,7 @@ func (mh *http2MetaHeadersFrame) PseudoValue(pseudo string) string { // RegularFields returns the regular (non-pseudo) header fields of mh. // The caller does not own the returned slice. -func (mh *http2MetaHeadersFrame) RegularFields() []hpack.HeaderField { +func (mh *MetaHeadersFrame) RegularFields() []hpack.HeaderField { for i, hf := range mh.Fields { if !hf.IsPseudo() { return mh.Fields[i:] @@ -1486,7 +1487,7 @@ func (mh *http2MetaHeadersFrame) RegularFields() []hpack.HeaderField { // PseudoFields returns the pseudo header fields of mh. // The caller does not own the returned slice. -func (mh *http2MetaHeadersFrame) PseudoFields() []hpack.HeaderField { +func (mh *MetaHeadersFrame) PseudoFields() []hpack.HeaderField { for i, hf := range mh.Fields { if !hf.IsPseudo() { return mh.Fields[:i] @@ -1495,7 +1496,7 @@ func (mh *http2MetaHeadersFrame) PseudoFields() []hpack.HeaderField { return mh.Fields } -func (mh *http2MetaHeadersFrame) checkPseudos() error { +func (mh *MetaHeadersFrame) checkPseudos() error { var isRequest, isResponse bool pf := mh.PseudoFields() for i, hf := range pf { @@ -1505,14 +1506,14 @@ func (mh *http2MetaHeadersFrame) checkPseudos() error { case ":status": isResponse = true default: - return http2pseudoHeaderError(hf.Name) + return pseudoHeaderError(hf.Name) } // Check for duplicates. // This would be a bad algorithm, but N is 4. // And this doesn't allocate. for _, hf2 := range pf[:i] { if hf.Name == hf2.Name { - return http2duplicatePseudoHeaderError(hf.Name) + return duplicatePseudoHeaderError(hf.Name) } } } @@ -1522,7 +1523,7 @@ func (mh *http2MetaHeadersFrame) checkPseudos() error { return nil } -func (h2f *http2Framer) maxHeaderStringLen() int { +func (h2f *Framer) maxHeaderStringLen() int { v := h2f.maxHeaderListSize() if uint32(int(v)) == v { return int(v) @@ -1535,12 +1536,12 @@ func (h2f *http2Framer) maxHeaderStringLen() int { // readMetaFrame returns 0 or more CONTINUATION frames from fr and // merge them into the provided hf and returns a MetaHeadersFrame // with the decoded hpack values. -func (h2f *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (*http2MetaHeadersFrame, error) { +func (h2f *Framer) readMetaFrame(hf *HeadersFrame, dumps []*dump.Dumper) (*MetaHeadersFrame, error) { if h2f.AllowIllegalReads { return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders") } - mh := &http2MetaHeadersFrame{ - http2HeadersFrame: hf, + mh := &MetaHeadersFrame{ + HeadersFrame: hf, } var remainSize = h2f.maxHeaderListSize() var sawRegular bool @@ -1550,12 +1551,12 @@ func (h2f *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (* hdec.SetEmitEnabled(true) hdec.SetMaxStringLength(h2f.maxHeaderStringLen()) rawEmitFunc := func(hf hpack.HeaderField) { - if http2VerboseLogs && h2f.logReads { + if VerboseLogs && h2f.logReads { h2f.debugReadLoggerf("http2: decoded hpack field %+v", hf) } if !httpguts.ValidHeaderFieldValue(hf.Value) { // Don't include the value in the error, because it may be sensitive. - invalid = http2headerFieldValueError(hf.Name) + invalid = headerFieldValueError(hf.Name) } isPseudo := strings.HasPrefix(hf.Name, ":") if isPseudo { @@ -1564,8 +1565,8 @@ func (h2f *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (* } } else { sawRegular = true - if !http2validWireHeaderFieldName(hf.Name) { - invalid = http2headerFieldNameError(hf.Name) + if !validWireHeaderFieldName(hf.Name) { + invalid = headerFieldNameError(hf.Name) } } @@ -1589,7 +1590,7 @@ func (h2f *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (* if len(dumps) > 0 { emitFunc = func(hf hpack.HeaderField) { for _, dump := range dumps { - dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) + dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) } rawEmitFunc(hf) } @@ -1599,11 +1600,11 @@ func (h2f *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (* // Lose reference to MetaHeadersFrame: defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {}) - var hc http2headersOrContinuation = hf + var hc headersOrContinuation = hf for { frag := hc.HeaderBlockFragment() if _, err := hdec.Write(frag); err != nil { - return nil, http2ConnectionError(http2ErrCodeCompression) + return nil, ConnectionError(ErrCodeCompression) } if hc.HeadersEnded() { @@ -1612,40 +1613,40 @@ func (h2f *http2Framer) readMetaFrame(hf *http2HeadersFrame, dumps []*dumper) (* if f, err := h2f.ReadFrame(); err != nil { return nil, err } else { - hc = f.(*http2ContinuationFrame) // guaranteed by checkFrameOrder + hc = f.(*ContinuationFrame) // guaranteed by checkFrameOrder } } - mh.http2HeadersFrame.headerFragBuf = nil - mh.http2HeadersFrame.invalidate() + mh.HeadersFrame.headerFragBuf = nil + mh.HeadersFrame.invalidate() if err := hdec.Close(); err != nil { - return nil, http2ConnectionError(http2ErrCodeCompression) + return nil, ConnectionError(ErrCodeCompression) } if invalid != nil { h2f.errDetail = invalid - if http2VerboseLogs { + if VerboseLogs { log.Printf("http2: invalid header: %v", invalid) } - return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, invalid} + return nil, StreamError{mh.StreamID, ErrCodeProtocol, invalid} } if err := mh.checkPseudos(); err != nil { h2f.errDetail = err - if http2VerboseLogs { + if VerboseLogs { log.Printf("http2: invalid pseudo headers: %v", err) } - return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, err} + return nil, StreamError{mh.StreamID, ErrCodeProtocol, err} } return mh, nil } -func http2summarizeFrame(f http2Frame) string { +func summarizeFrame(f Frame) string { var buf bytes.Buffer f.Header().writeDebug(&buf) switch f := f.(type) { - case *http2SettingsFrame: + case *SettingsFrame: n := 0 - f.ForeachSetting(func(s http2Setting) error { + f.ForeachSetting(func(s Setting) error { n++ if n == 1 { buf.WriteString(", settings:") @@ -1656,7 +1657,7 @@ func http2summarizeFrame(f http2Frame) string { if n > 0 { buf.Truncate(buf.Len() - 1) // remove trailing comma } - case *http2DataFrame: + case *DataFrame: data := f.Data() const max = 256 if len(data) > max { @@ -1666,17 +1667,17 @@ func http2summarizeFrame(f http2Frame) string { if len(f.Data()) > max { fmt.Fprintf(&buf, " (%d bytes omitted)", len(f.Data())-max) } - case *http2WindowUpdateFrame: + case *WindowUpdateFrame: if f.StreamID == 0 { buf.WriteString(" (conn)") } fmt.Fprintf(&buf, " incr=%v", f.Increment) - case *http2PingFrame: + case *PingFrame: fmt.Fprintf(&buf, " ping=%q", f.Data[:]) - case *http2GoAwayFrame: + case *GoAwayFrame: fmt.Fprintf(&buf, " LastStreamID=%v ErrCode=%v Debug=%q", f.LastStreamID, f.ErrCode, f.debugData) - case *http2RSTStreamFrame: + case *RSTStreamFrame: fmt.Fprintf(&buf, " ErrCode=%v", f.ErrCode) } return buf.String() diff --git a/h2_frame_test.go b/internal/http2/frame_test.go similarity index 69% rename from h2_frame_test.go rename to internal/http2/frame_test.go index b2f75aa4..ed5ec9c7 100644 --- a/h2_frame_test.go +++ b/internal/http2/frame_test.go @@ -2,11 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "bytes" "fmt" + "github.com/imroc/req/v3/internal/tests" "io" "reflect" "strings" @@ -16,26 +17,26 @@ import ( "golang.org/x/net/http2/hpack" ) -func testFramer() (*http2Framer, *bytes.Buffer) { +func testFramer() (*Framer, *bytes.Buffer) { buf := new(bytes.Buffer) - return http2NewFramer(buf, buf), buf + return NewFramer(buf, buf), buf } func TestFrameSizes(t *testing.T) { // Catch people rearranging the FrameHeader fields. - if got, want := int(unsafe.Sizeof(http2FrameHeader{})), 12; got != want { + if got, want := int(unsafe.Sizeof(FrameHeader{})), 12; got != want { t.Errorf("FrameHeader size = %d; want %d", got, want) } } func TestFrameTypeString(t *testing.T) { tests := []struct { - ft http2FrameType + ft FrameType want string }{ - {http2FrameData, "DATA"}, - {http2FramePing, "PING"}, - {http2FrameGoAway, "GOAWAY"}, + {FrameData, "DATA"}, + {FramePing, "PING"}, + {FrameGoAway, "GOAWAY"}, {0xf, "UNKNOWN_FRAME_TYPE_15"}, } @@ -51,7 +52,7 @@ func TestWriteRST(t *testing.T) { fr, buf := testFramer() var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4 var errCode uint32 = 7<<24 + 6<<16 + 5<<8 + 4 - fr.WriteRSTStream(streamID, http2ErrCode(errCode)) + fr.WriteRSTStream(streamID, ErrCode(errCode)) const wantEnc = "\x00\x00\x04\x03\x00\x01\x02\x03\x04\x07\x06\x05\x04" if buf.String() != wantEnc { t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) @@ -60,8 +61,8 @@ func TestWriteRST(t *testing.T) { if err != nil { t.Fatal(err) } - want := &http2RSTStreamFrame{ - http2FrameHeader: http2FrameHeader{ + want := &RSTStreamFrame{ + FrameHeader: FrameHeader{ valid: true, Type: 0x3, Flags: 0x0, @@ -88,9 +89,9 @@ func TestWriteData(t *testing.T) { if err != nil { t.Fatal(err) } - df, ok := f.(*http2DataFrame) + df, ok := f.(*DataFrame) if !ok { - t.Fatalf("got %T; want *http2DataFrame", f) + t.Fatalf("got %T; want *DataFrame", f) } if !bytes.Equal(df.Data(), data) { t.Errorf("got %q; want %q", df.Data(), data) @@ -106,7 +107,7 @@ func TestWriteDataPadded(t *testing.T) { endStream bool data []byte pad []byte - wantHeader http2FrameHeader + wantHeader FrameHeader }{ // Unpadded: 0: { @@ -114,9 +115,9 @@ func TestWriteDataPadded(t *testing.T) { endStream: true, data: []byte("foo"), pad: nil, - wantHeader: http2FrameHeader{ - Type: http2FrameData, - Flags: http2FlagDataEndStream, + wantHeader: FrameHeader{ + Type: FrameData, + Flags: FlagDataEndStream, Length: 3, StreamID: 1, }, @@ -128,9 +129,9 @@ func TestWriteDataPadded(t *testing.T) { endStream: true, data: []byte("foo"), pad: []byte{}, - wantHeader: http2FrameHeader{ - Type: http2FrameData, - Flags: http2FlagDataEndStream | http2FlagDataPadded, + wantHeader: FrameHeader{ + Type: FrameData, + Flags: FlagDataEndStream | FlagDataPadded, Length: 4, StreamID: 1, }, @@ -142,9 +143,9 @@ func TestWriteDataPadded(t *testing.T) { endStream: false, data: []byte("foo"), pad: []byte{0, 0, 0}, - wantHeader: http2FrameHeader{ - Type: http2FrameData, - Flags: http2FlagDataPadded, + wantHeader: FrameHeader{ + Type: FrameData, + Flags: FlagDataPadded, Length: 7, StreamID: 1, }, @@ -164,14 +165,14 @@ func TestWriteDataPadded(t *testing.T) { t.Errorf("%d. read %+v; want %+v", i, got, tt.wantHeader) continue } - df := f.(*http2DataFrame) + df := f.(*DataFrame) if !bytes.Equal(df.Data(), tt.data) { t.Errorf("%d. got %q; want %q", i, df.Data(), tt.data) } } } -func (fh http2FrameHeader) Equal(b http2FrameHeader) bool { +func (fh FrameHeader) Equal(b FrameHeader) bool { return fh.valid == b.valid && fh.Type == b.Type && fh.Flags == b.Flags && @@ -182,98 +183,98 @@ func (fh http2FrameHeader) Equal(b http2FrameHeader) bool { func TestWriteHeaders(t *testing.T) { tests := []struct { name string - p http2HeadersFrameParam + p HeadersFrameParam wantEnc string - wantFrame *http2HeadersFrame + wantFrame *HeadersFrame }{ { "basic", - http2HeadersFrameParam{ + HeadersFrameParam{ StreamID: 42, BlockFragment: []byte("abc"), - Priority: http2PriorityParam{}, + Priority: PriorityParam{}, }, "\x00\x00\x03\x01\x00\x00\x00\x00*abc", - &http2HeadersFrame{ - http2FrameHeader: http2FrameHeader{ + &HeadersFrame{ + FrameHeader: FrameHeader{ valid: true, StreamID: 42, - Type: http2FrameHeaders, + Type: FrameHeaders, Length: uint32(len("abc")), }, - Priority: http2PriorityParam{}, + Priority: PriorityParam{}, headerFragBuf: []byte("abc"), }, }, { "basic + end flags", - http2HeadersFrameParam{ + HeadersFrameParam{ StreamID: 42, BlockFragment: []byte("abc"), EndStream: true, EndHeaders: true, - Priority: http2PriorityParam{}, + Priority: PriorityParam{}, }, "\x00\x00\x03\x01\x05\x00\x00\x00*abc", - &http2HeadersFrame{ - http2FrameHeader: http2FrameHeader{ + &HeadersFrame{ + FrameHeader: FrameHeader{ valid: true, StreamID: 42, - Type: http2FrameHeaders, - Flags: http2FlagHeadersEndStream | http2FlagHeadersEndHeaders, + Type: FrameHeaders, + Flags: FlagHeadersEndStream | FlagHeadersEndHeaders, Length: uint32(len("abc")), }, - Priority: http2PriorityParam{}, + Priority: PriorityParam{}, headerFragBuf: []byte("abc"), }, }, { "with padding", - http2HeadersFrameParam{ + HeadersFrameParam{ StreamID: 42, BlockFragment: []byte("abc"), EndStream: true, EndHeaders: true, PadLength: 5, - Priority: http2PriorityParam{}, + Priority: PriorityParam{}, }, "\x00\x00\t\x01\r\x00\x00\x00*\x05abc\x00\x00\x00\x00\x00", - &http2HeadersFrame{ - http2FrameHeader: http2FrameHeader{ + &HeadersFrame{ + FrameHeader: FrameHeader{ valid: true, StreamID: 42, - Type: http2FrameHeaders, - Flags: http2FlagHeadersEndStream | http2FlagHeadersEndHeaders | http2FlagHeadersPadded, + Type: FrameHeaders, + Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded, Length: uint32(1 + len("abc") + 5), // pad length + contents + padding }, - Priority: http2PriorityParam{}, + Priority: PriorityParam{}, headerFragBuf: []byte("abc"), }, }, { "with priority", - http2HeadersFrameParam{ + HeadersFrameParam{ StreamID: 42, BlockFragment: []byte("abc"), EndStream: true, EndHeaders: true, PadLength: 2, - Priority: http2PriorityParam{ + Priority: PriorityParam{ StreamDep: 15, Exclusive: true, Weight: 127, }, }, "\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x0f\u007fabc\x00\x00", - &http2HeadersFrame{ - http2FrameHeader: http2FrameHeader{ + &HeadersFrame{ + FrameHeader: FrameHeader{ valid: true, StreamID: 42, - Type: http2FrameHeaders, - Flags: http2FlagHeadersEndStream | http2FlagHeadersEndHeaders | http2FlagHeadersPadded | http2FlagHeadersPriority, + Type: FrameHeaders, + Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority, Length: uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding }, - Priority: http2PriorityParam{ + Priority: PriorityParam{ StreamDep: 15, Exclusive: true, Weight: 127, @@ -283,28 +284,28 @@ func TestWriteHeaders(t *testing.T) { }, { "with priority stream dep zero", // golang.org/issue/15444 - http2HeadersFrameParam{ + HeadersFrameParam{ StreamID: 42, BlockFragment: []byte("abc"), EndStream: true, EndHeaders: true, PadLength: 2, - Priority: http2PriorityParam{ + Priority: PriorityParam{ StreamDep: 0, Exclusive: true, Weight: 127, }, }, "\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x00\u007fabc\x00\x00", - &http2HeadersFrame{ - http2FrameHeader: http2FrameHeader{ + &HeadersFrame{ + FrameHeader: FrameHeader{ valid: true, StreamID: 42, - Type: http2FrameHeaders, - Flags: http2FlagHeadersEndStream | http2FlagHeadersEndHeaders | http2FlagHeadersPadded | http2FlagHeadersPriority, + Type: FrameHeaders, + Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority, Length: uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding }, - Priority: http2PriorityParam{ + Priority: PriorityParam{ StreamDep: 0, Exclusive: true, Weight: 127, @@ -314,19 +315,19 @@ func TestWriteHeaders(t *testing.T) { }, { "zero length", - http2HeadersFrameParam{ + HeadersFrameParam{ StreamID: 42, - Priority: http2PriorityParam{}, + Priority: PriorityParam{}, }, "\x00\x00\x00\x01\x00\x00\x00\x00*", - &http2HeadersFrame{ - http2FrameHeader: http2FrameHeader{ + &HeadersFrame{ + FrameHeader: FrameHeader{ valid: true, StreamID: 42, - Type: http2FrameHeaders, + Type: FrameHeaders, Length: 0, }, - Priority: http2PriorityParam{}, + Priority: PriorityParam{}, }, }, } @@ -352,9 +353,9 @@ func TestWriteHeaders(t *testing.T) { func TestWriteInvalidStreamDep(t *testing.T) { fr, _ := testFramer() - err := fr.WriteHeaders(http2HeadersFrameParam{ + err := fr.WriteHeaders(HeadersFrameParam{ StreamID: 42, - Priority: http2PriorityParam{ + Priority: PriorityParam{ StreamDep: 1 << 31, }, }) @@ -362,7 +363,7 @@ func TestWriteInvalidStreamDep(t *testing.T) { t.Errorf("header error = %v; want %q", err, errDepStreamID) } - err = fr.WritePriority(2, http2PriorityParam{StreamDep: 1 << 31}) + err = fr.WritePriority(2, PriorityParam{StreamDep: 1 << 31}) if err != errDepStreamID { t.Errorf("priority error = %v; want %q", err, errDepStreamID) } @@ -375,17 +376,17 @@ func TestWriteContinuation(t *testing.T) { end bool frag []byte - wantFrame *http2ContinuationFrame + wantFrame *ContinuationFrame }{ { "not end", false, []byte("abc"), - &http2ContinuationFrame{ - http2FrameHeader: http2FrameHeader{ + &ContinuationFrame{ + FrameHeader: FrameHeader{ valid: true, StreamID: streamID, - Type: http2FrameContinuation, + Type: FrameContinuation, Length: uint32(len("abc")), }, headerFragBuf: []byte("abc"), @@ -395,12 +396,12 @@ func TestWriteContinuation(t *testing.T) { "end", true, []byte("def"), - &http2ContinuationFrame{ - http2FrameHeader: http2FrameHeader{ + &ContinuationFrame{ + FrameHeader: FrameHeader{ valid: true, StreamID: streamID, - Type: http2FrameContinuation, - Flags: http2FlagContinuationEndHeaders, + Type: FrameContinuation, + Flags: FlagContinuationEndHeaders, Length: uint32(len("def")), }, headerFragBuf: []byte("def"), @@ -429,24 +430,24 @@ func TestWritePriority(t *testing.T) { const streamID = 42 tests := []struct { name string - priority http2PriorityParam - wantFrame *http2PriorityFrame + priority PriorityParam + wantFrame *PriorityFrame }{ { "not exclusive", - http2PriorityParam{ + PriorityParam{ StreamDep: 2, Exclusive: false, Weight: 127, }, - &http2PriorityFrame{ - http2FrameHeader{ + &PriorityFrame{ + FrameHeader{ valid: true, StreamID: streamID, - Type: http2FramePriority, + Type: FramePriority, Length: 5, }, - http2PriorityParam{ + PriorityParam{ StreamDep: 2, Exclusive: false, Weight: 127, @@ -456,19 +457,19 @@ func TestWritePriority(t *testing.T) { { "exclusive", - http2PriorityParam{ + PriorityParam{ StreamDep: 3, Exclusive: true, Weight: 77, }, - &http2PriorityFrame{ - http2FrameHeader{ + &PriorityFrame{ + FrameHeader{ valid: true, StreamID: streamID, - Type: http2FramePriority, + Type: FramePriority, Length: 5, }, - http2PriorityParam{ + PriorityParam{ StreamDep: 3, Exclusive: true, Weight: 77, @@ -495,7 +496,7 @@ func TestWritePriority(t *testing.T) { func TestWriteSettings(t *testing.T) { fr, buf := testFramer() - settings := []http2Setting{{1, 2}, {3, 4}} + settings := []Setting{{1, 2}, {3, 4}} fr.WriteSettings(settings...) const wantEnc = "\x00\x00\f\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x03\x00\x00\x00\x04" if buf.String() != wantEnc { @@ -505,12 +506,12 @@ func TestWriteSettings(t *testing.T) { if err != nil { t.Fatal(err) } - sf, ok := f.(*http2SettingsFrame) + sf, ok := f.(*SettingsFrame) if !ok { t.Fatalf("Got a %T; want a SettingsFrame", f) } - var got []http2Setting - sf.ForeachSetting(func(s http2Setting) error { + var got []Setting + sf.ForeachSetting(func(s Setting) error { got = append(got, s) valBack, ok := sf.Value(s.ID) if !ok || valBack != s.Val { @@ -547,8 +548,8 @@ func TestWriteWindowUpdate(t *testing.T) { if err != nil { t.Fatal(err) } - want := &http2WindowUpdateFrame{ - http2FrameHeader: http2FrameHeader{ + want := &WindowUpdateFrame{ + FrameHeader: FrameHeader{ valid: true, Type: 0x8, Flags: 0x0, @@ -570,9 +571,9 @@ func testWritePing(t *testing.T, ack bool) { if err := fr.WritePing(ack, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { t.Fatal(err) } - var wantFlags http2Flags + var wantFlags Flags if ack { - wantFlags = http2FlagPingAck + wantFlags = FlagPingAck } var wantEnc = "\x00\x00\x08\x06" + string(wantFlags) + "\x00\x00\x00\x00" + "\x01\x02\x03\x04\x05\x06\x07\x08" if buf.String() != wantEnc { @@ -583,8 +584,8 @@ func testWritePing(t *testing.T, ack bool) { if err != nil { t.Fatal(err) } - want := &http2PingFrame{ - http2FrameHeader: http2FrameHeader{ + want := &PingFrame{ + FrameHeader: FrameHeader{ valid: true, Type: 0x6, Flags: wantFlags, @@ -601,20 +602,20 @@ func testWritePing(t *testing.T, ack bool) { func TestReadFrameHeader(t *testing.T) { tests := []struct { in string - want http2FrameHeader + want FrameHeader }{ - {in: "\x00\x00\x00" + "\x00" + "\x00" + "\x00\x00\x00\x00", want: http2FrameHeader{}}, - {in: "\x01\x02\x03" + "\x04" + "\x05" + "\x06\x07\x08\x09", want: http2FrameHeader{ + {in: "\x00\x00\x00" + "\x00" + "\x00" + "\x00\x00\x00\x00", want: FrameHeader{}}, + {in: "\x01\x02\x03" + "\x04" + "\x05" + "\x06\x07\x08\x09", want: FrameHeader{ Length: 66051, Type: 4, Flags: 5, StreamID: 101124105, }}, // Ignore high bit: - {in: "\xff\xff\xff" + "\xff" + "\xff" + "\xff\xff\xff\xff", want: http2FrameHeader{ + {in: "\xff\xff\xff" + "\xff" + "\xff" + "\xff\xff\xff\xff", want: FrameHeader{ Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}}, - {in: "\xff\xff\xff" + "\xff" + "\xff" + "\x7f\xff\xff\xff", want: http2FrameHeader{ + {in: "\xff\xff\xff" + "\xff" + "\xff" + "\x7f\xff\xff\xff", want: FrameHeader{ Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}}, } for i, tt := range tests { - got, err := http2readFrameHeader(make([]byte, 9), strings.NewReader(tt.in)) + got, err := readFrameHeader(make([]byte, 9), strings.NewReader(tt.in)) if err != nil { t.Errorf("%d. readFrameHeader(%q) = %v", i, tt.in, err) continue @@ -629,8 +630,8 @@ func TestReadFrameHeader(t *testing.T) { func TestReadWriteFrameHeader(t *testing.T) { tests := []struct { len uint32 - typ http2FrameType - flags http2Flags + typ FrameType + flags Flags streamID uint32 }{ {len: 0, typ: 255, flags: 1, streamID: 0}, @@ -652,7 +653,7 @@ func TestReadWriteFrameHeader(t *testing.T) { fr.startWrite(tt.typ, tt.flags, tt.streamID) fr.writeBytes(make([]byte, tt.len)) fr.endWrite() - fh, err := http2ReadFrameHeader(buf) + fh, err := ReadFrameHeader(buf) if err != nil { t.Errorf("ReadFrameHeader(%+v) = %v", tt, err) continue @@ -688,8 +689,8 @@ func TestWriteGoAway(t *testing.T) { if err != nil { t.Fatal(err) } - want := &http2GoAwayFrame{ - http2FrameHeader: http2FrameHeader{ + want := &GoAwayFrame{ + FrameHeader: FrameHeader{ valid: true, Type: 0x7, Flags: 0, @@ -703,13 +704,13 @@ func TestWriteGoAway(t *testing.T) { if !reflect.DeepEqual(f, want) { t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want) } - if got := string(f.(*http2GoAwayFrame).DebugData()); got != debug { + if got := string(f.(*GoAwayFrame).DebugData()); got != debug { t.Errorf("debug data = %q; want %q", got, debug) } } func TestWritePushPromise(t *testing.T) { - pp := http2PushPromiseParam{ + pp := PushPromiseParam{ StreamID: 42, PromiseID: 42, BlockFragment: []byte("abc"), @@ -726,12 +727,12 @@ func TestWritePushPromise(t *testing.T) { if err != nil { t.Fatal(err) } - _, ok := f.(*http2PushPromiseFrame) + _, ok := f.(*PushPromiseFrame) if !ok { t.Fatalf("got %T; want *PushPromiseFrame", f) } - want := &http2PushPromiseFrame{ - http2FrameHeader: http2FrameHeader{ + want := &PushPromiseFrame{ + FrameHeader: FrameHeader{ valid: true, Type: 0x5, Flags: 0x0, @@ -748,49 +749,49 @@ func TestWritePushPromise(t *testing.T) { // test checkFrameOrder and that HEADERS and CONTINUATION frames can't be intermingled. func TestReadFrameOrder(t *testing.T) { - head := func(f *http2Framer, id uint32, end bool) { - f.WriteHeaders(http2HeadersFrameParam{ + head := func(f *Framer, id uint32, end bool) { + f.WriteHeaders(HeadersFrameParam{ StreamID: id, BlockFragment: []byte("foo"), // unused, but non-empty EndHeaders: end, }) } - cont := func(f *http2Framer, id uint32, end bool) { + cont := func(f *Framer, id uint32, end bool) { f.WriteContinuation(id, end, []byte("foo")) } tests := [...]struct { name string - w func(*http2Framer) + w func(*Framer) atLeast int wantErr string }{ 0: { - w: func(f *http2Framer) { + w: func(f *Framer) { head(f, 1, true) }, }, 1: { - w: func(f *http2Framer) { + w: func(f *Framer) { head(f, 1, true) head(f, 2, true) }, }, 2: { wantErr: "got HEADERS for stream 2; expected CONTINUATION following HEADERS for stream 1", - w: func(f *http2Framer) { + w: func(f *Framer) { head(f, 1, false) head(f, 2, true) }, }, 3: { wantErr: "got DATA for stream 1; expected CONTINUATION following HEADERS for stream 1", - w: func(f *http2Framer) { + w: func(f *Framer) { head(f, 1, false) }, }, 4: { - w: func(f *http2Framer) { + w: func(f *Framer) { head(f, 1, false) cont(f, 1, true) head(f, 2, true) @@ -798,7 +799,7 @@ func TestReadFrameOrder(t *testing.T) { }, 5: { wantErr: "got CONTINUATION for stream 2; expected stream 1", - w: func(f *http2Framer) { + w: func(f *Framer) { head(f, 1, false) cont(f, 2, true) head(f, 2, true) @@ -806,32 +807,32 @@ func TestReadFrameOrder(t *testing.T) { }, 6: { wantErr: "unexpected CONTINUATION for stream 1", - w: func(f *http2Framer) { + w: func(f *Framer) { cont(f, 1, true) }, }, 7: { wantErr: "unexpected CONTINUATION for stream 1", - w: func(f *http2Framer) { + w: func(f *Framer) { cont(f, 1, false) }, }, 8: { wantErr: "HEADERS frame with stream ID 0", - w: func(f *http2Framer) { + w: func(f *Framer) { head(f, 0, true) }, }, 9: { wantErr: "CONTINUATION frame with stream ID 0", - w: func(f *http2Framer) { + w: func(f *Framer) { cont(f, 0, true) }, }, 10: { wantErr: "unexpected CONTINUATION for stream 1", atLeast: 5, - w: func(f *http2Framer) { + w: func(f *Framer) { head(f, 1, false) cont(f, 1, false) cont(f, 1, false) @@ -843,7 +844,7 @@ func TestReadFrameOrder(t *testing.T) { } for i, tt := range tests { buf := new(bytes.Buffer) - f := http2NewFramer(buf, buf) + f := NewFramer(buf, buf) f.AllowIllegalWrites = true tt.w(f) f.WriteData(1, true, nil) // to test transition away from last step @@ -852,7 +853,7 @@ func TestReadFrameOrder(t *testing.T) { n := 0 var log bytes.Buffer for { - var got http2Frame + var got Frame got, err = f.ReadFrame() fmt.Fprintf(&log, " read %v, %v\n", got, err) if err != nil { @@ -868,7 +869,7 @@ func TestReadFrameOrder(t *testing.T) { t.Errorf("%d. after %d good frames, ReadFrame = %v; want success\n%s", i, n, err, log.Bytes()) continue } - if !ok && err != http2ConnectionError(http2ErrCodeProtocol) { + if !ok && err != ConnectionError(ErrCodeProtocol) { t.Errorf("%d. after %d good frames, ReadFrame = %v; want ConnectionError(ErrCodeProtocol)\n%s", i, n, err, log.Bytes()) continue } @@ -906,11 +907,11 @@ func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte } func TestMetaFrameHeader(t *testing.T) { - write := func(f *http2Framer, frags ...[]byte) { + write := func(f *Framer, frags ...[]byte) { for i, frag := range frags { end := (i == len(frags)-1) if i == 0 { - f.WriteHeaders(http2HeadersFrameParam{ + f.WriteHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: frag, EndHeaders: end, @@ -921,11 +922,11 @@ func TestMetaFrameHeader(t *testing.T) { } } - want := func(flags http2Flags, length uint32, pairs ...string) *http2MetaHeadersFrame { - mh := &http2MetaHeadersFrame{ - http2HeadersFrame: &http2HeadersFrame{ - http2FrameHeader: http2FrameHeader{ - Type: http2FrameHeaders, + want := func(flags Flags, length uint32, pairs ...string) *MetaHeadersFrame { + mh := &MetaHeadersFrame{ + HeadersFrame: &HeadersFrame{ + FrameHeader: FrameHeader{ + Type: FrameHeaders, Flags: flags, Length: length, StreamID: 1, @@ -942,34 +943,34 @@ func TestMetaFrameHeader(t *testing.T) { } return mh } - truncated := func(mh *http2MetaHeadersFrame) *http2MetaHeadersFrame { + truncated := func(mh *MetaHeadersFrame) *MetaHeadersFrame { mh.Truncated = true return mh } - const noFlags http2Flags = 0 + const noFlags Flags = 0 oneKBString := strings.Repeat("a", 1<<10) tests := [...]struct { name string - w func(*http2Framer) + w func(*Framer) want interface{} // *MetaHeaderFrame or error wantErrReason string maxHeaderListSize uint32 }{ 0: { name: "single_headers", - w: func(f *http2Framer) { + w: func(f *Framer) { var he hpackEncoder all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/") write(f, all) }, - want: want(http2FlagHeadersEndHeaders, 2, ":method", "GET", ":path", "/"), + want: want(FlagHeadersEndHeaders, 2, ":method", "GET", ":path", "/"), }, 1: { name: "with_continuation", - w: func(f *http2Framer) { + w: func(f *Framer) { var he hpackEncoder all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar") write(f, all[:1], all[1:]) @@ -978,7 +979,7 @@ func TestMetaFrameHeader(t *testing.T) { }, 2: { name: "with_two_continuation", - w: func(f *http2Framer) { + w: func(f *Framer) { var he hpackEncoder all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar") write(f, all[:2], all[2:4], all[4:]) @@ -987,7 +988,7 @@ func TestMetaFrameHeader(t *testing.T) { }, 3: { name: "big_string_okay", - w: func(f *http2Framer) { + w: func(f *Framer) { var he hpackEncoder all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString) write(f, all[:2], all[2:]) @@ -996,17 +997,17 @@ func TestMetaFrameHeader(t *testing.T) { }, 4: { name: "big_string_error", - w: func(f *http2Framer) { + w: func(f *Framer) { var he hpackEncoder all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString) write(f, all[:2], all[2:]) }, maxHeaderListSize: (1 << 10) / 2, - want: http2ConnectionError(http2ErrCodeCompression), + want: ConnectionError(ErrCodeCompression), }, 5: { name: "max_header_list_truncated", - w: func(f *http2Framer) { + w: func(f *Framer) { var he hpackEncoder var pairs = []string{":method", "GET", ":path", "/"} for i := 0; i < 100; i++ { @@ -1034,71 +1035,71 @@ func TestMetaFrameHeader(t *testing.T) { }, 6: { name: "pseudo_order", - w: func(f *http2Framer) { + w: func(f *Framer) { write(f, encodeHeaderRaw(t, ":method", "GET", "foo", "bar", ":path", "/", // bogus )) }, - want: http2streamError(1, http2ErrCodeProtocol), + want: streamError(1, ErrCodeProtocol), wantErrReason: "pseudo header field after regular", }, 7: { name: "pseudo_unknown", - w: func(f *http2Framer) { + w: func(f *Framer) { write(f, encodeHeaderRaw(t, ":unknown", "foo", // bogus "foo", "bar", )) }, - want: http2streamError(1, http2ErrCodeProtocol), + want: streamError(1, ErrCodeProtocol), wantErrReason: "invalid pseudo-header \":unknown\"", }, 8: { name: "pseudo_mix_request_response", - w: func(f *http2Framer) { + w: func(f *Framer) { write(f, encodeHeaderRaw(t, ":method", "GET", ":status", "100", )) }, - want: http2streamError(1, http2ErrCodeProtocol), + want: streamError(1, ErrCodeProtocol), wantErrReason: "mix of request and response pseudo headers", }, 9: { name: "pseudo_dup", - w: func(f *http2Framer) { + w: func(f *Framer) { write(f, encodeHeaderRaw(t, ":method", "GET", ":method", "POST", )) }, - want: http2streamError(1, http2ErrCodeProtocol), + want: streamError(1, ErrCodeProtocol), wantErrReason: "duplicate pseudo-header \":method\"", }, 10: { name: "trailer_okay_no_pseudo", - w: func(f *http2Framer) { write(f, encodeHeaderRaw(t, "foo", "bar")) }, - want: want(http2FlagHeadersEndHeaders, 8, "foo", "bar"), + w: func(f *Framer) { write(f, encodeHeaderRaw(t, "foo", "bar")) }, + want: want(FlagHeadersEndHeaders, 8, "foo", "bar"), }, 11: { name: "invalid_field_name", - w: func(f *http2Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) }, - want: http2streamError(1, http2ErrCodeProtocol), + w: func(f *Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) }, + want: streamError(1, ErrCodeProtocol), wantErrReason: "invalid header field name \"CapitalBad\"", }, 12: { name: "invalid_field_value", - w: func(f *http2Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) }, - want: http2streamError(1, http2ErrCodeProtocol), + w: func(f *Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) }, + want: streamError(1, ErrCodeProtocol), wantErrReason: `invalid header field value for "key"`, }, } for i, tt := range tests { buf := new(bytes.Buffer) - f := http2NewFramer(buf, buf) - f.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + f := NewFramer(buf, buf) + f.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) f.MaxHeaderListSize = tt.maxHeaderListSize tt.w(f) @@ -1115,16 +1116,16 @@ func TestMetaFrameHeader(t *testing.T) { // Ignore the StreamError.Cause field, if it matches the wantErrReason. // The test table above predates the Cause field. - if se, ok := err.(http2StreamError); ok && se.Cause != nil && se.Cause.Error() == tt.wantErrReason { + if se, ok := err.(StreamError); ok && se.Cause != nil && se.Cause.Error() == tt.wantErrReason { se.Cause = nil got = se } } if !reflect.DeepEqual(got, tt.want) { - if mhg, ok := got.(*http2MetaHeadersFrame); ok { - if mhw, ok := tt.want.(*http2MetaHeadersFrame); ok { - hg := mhg.http2HeadersFrame - hw := mhw.http2HeadersFrame + if mhg, ok := got.(*MetaHeadersFrame); ok { + if mhw, ok := tt.want.(*MetaHeadersFrame); ok { + hg := mhg.HeadersFrame + hw := mhw.HeadersFrame if hg != nil && hw != nil && !reflect.DeepEqual(*hg, *hw) { t.Errorf("%s: headers differ:\n got: %+v\nwant: %+v\n", name, *hg, *hw) } @@ -1210,7 +1211,7 @@ func TestNoSetReuseFrames(t *testing.T) { } } -func readAndVerifyDataFrame(data string, length byte, fr *http2Framer, buf *bytes.Buffer, t *testing.T) *http2DataFrame { +func readAndVerifyDataFrame(data string, length byte, fr *Framer, buf *bytes.Buffer, t *testing.T) *DataFrame { var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4 fr.WriteData(streamID, true, []byte(data)) wantEnc := "\x00\x00" + string(length) + "\x00\x01\x01\x02\x03\x04" + data @@ -1221,9 +1222,9 @@ func readAndVerifyDataFrame(data string, length byte, fr *http2Framer, buf *byte if err != nil { t.Fatal(err) } - df, ok := f.(*http2DataFrame) + df, ok := f.(*DataFrame) if !ok { - t.Fatalf("got %T; want *http2DataFrame", f) + t.Fatalf("got %T; want *DataFrame", f) } if !bytes.Equal(df.Data(), []byte(data)) { t.Errorf("got %q; want %q", df.Data(), []byte(data)) @@ -1241,27 +1242,27 @@ func encodeHeaderRaw(t *testing.T, pairs ...string) []byte { func TestSettingsDuplicates(t *testing.T) { tests := []struct { - settings []http2Setting + settings []Setting want bool }{ {nil, false}, - {[]http2Setting{{ID: 1}}, false}, - {[]http2Setting{{ID: 1}, {ID: 2}}, false}, - {[]http2Setting{{ID: 1}, {ID: 2}}, false}, - {[]http2Setting{{ID: 1}, {ID: 2}, {ID: 3}}, false}, - {[]http2Setting{{ID: 1}, {ID: 2}, {ID: 3}}, false}, - {[]http2Setting{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}}, false}, + {[]Setting{{ID: 1}}, false}, + {[]Setting{{ID: 1}, {ID: 2}}, false}, + {[]Setting{{ID: 1}, {ID: 2}}, false}, + {[]Setting{{ID: 1}, {ID: 2}, {ID: 3}}, false}, + {[]Setting{{ID: 1}, {ID: 2}, {ID: 3}}, false}, + {[]Setting{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}}, false}, - {[]http2Setting{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 2}}, true}, - {[]http2Setting{{ID: 4}, {ID: 2}, {ID: 3}, {ID: 4}}, true}, + {[]Setting{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 2}}, true}, + {[]Setting{{ID: 4}, {ID: 2}, {ID: 3}, {ID: 4}}, true}, - {[]http2Setting{ + {[]Setting{ {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}, {ID: 5}, {ID: 6}, {ID: 7}, {ID: 8}, {ID: 9}, {ID: 10}, {ID: 11}, {ID: 12}, }, false}, - {[]http2Setting{ + {[]Setting{ {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}, {ID: 5}, {ID: 6}, {ID: 7}, {ID: 8}, {ID: 9}, {ID: 10}, {ID: 11}, {ID: 11}, @@ -1274,7 +1275,7 @@ func TestSettingsDuplicates(t *testing.T) { if err != nil { t.Fatalf("%d. ReadFrame: %v", i, err) } - sf := f.(*http2SettingsFrame) + sf := f.(*SettingsFrame) got := sf.HasDuplicates() if got != tt.want { t.Errorf("%d. HasDuplicates = %v; want %v", i, got, tt.want) @@ -1284,195 +1285,195 @@ func TestSettingsDuplicates(t *testing.T) { } func TestParseSettingsFrame(t *testing.T) { - fh := http2FrameHeader{} - fh.Flags = http2FlagSettingsAck + fh := FrameHeader{} + fh.Flags = FlagSettingsAck fh.Length = 1 countErr := func(s string) {} - _, err := http2parseSettingsFrame(nil, fh, countErr, nil) - assertErrorContains(t, err, "FRAME_SIZE_ERROR") + _, err := parseSettingsFrame(nil, fh, countErr, nil) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") - fh = http2FrameHeader{StreamID: 1} - _, err = http2parseSettingsFrame(nil, fh, countErr, nil) - assertErrorContains(t, err, "PROTOCOL_ERROR") + fh = FrameHeader{StreamID: 1} + _, err = parseSettingsFrame(nil, fh, countErr, nil) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") - fh = http2FrameHeader{} - _, err = http2parseSettingsFrame(nil, fh, countErr, []byte("roc")) - assertErrorContains(t, err, "FRAME_SIZE_ERROR") + fh = FrameHeader{} + _, err = parseSettingsFrame(nil, fh, countErr, []byte("roc")) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") - fh = http2FrameHeader{valid: true} - _, err = http2parseSettingsFrame(nil, fh, countErr, []byte("rocroc")) - assertNoError(t, err) + fh = FrameHeader{valid: true} + _, err = parseSettingsFrame(nil, fh, countErr, []byte("rocroc")) + tests.AssertNoError(t, err) } func TestParsePushPromise(t *testing.T) { - fh := http2FrameHeader{} + fh := FrameHeader{} countError := func(string) {} - _, err := http2parsePushPromise(nil, fh, countError, nil) - assertErrorContains(t, err, "PROTOCOL_ERROR") + _, err := parsePushPromise(nil, fh, countError, nil) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") fh.StreamID = 1 - fh.Flags = http2FlagPushPromisePadded - _, err = http2parsePushPromise(nil, fh, countError, nil) - assertErrorContains(t, err, "EOF") + fh.Flags = FlagPushPromisePadded + _, err = parsePushPromise(nil, fh, countError, nil) + tests.AssertErrorContains(t, err, "EOF") fh.Flags = 0 - _, err = http2parsePushPromise(nil, fh, countError, nil) - assertErrorContains(t, err, "EOF") + _, err = parsePushPromise(nil, fh, countError, nil) + tests.AssertErrorContains(t, err, "EOF") - _, err = http2parsePushPromise(nil, fh, countError, []byte("ksjfksjksjflskk")) - assertNoError(t, err) + _, err = parsePushPromise(nil, fh, countError, []byte("ksjfksjksjflskk")) + tests.AssertNoError(t, err) } func TestSummarizeFrame(t *testing.T) { - fh := http2FrameHeader{valid: true} - var f http2Frame - f = &http2SettingsFrame{http2FrameHeader: fh, p: []byte{0x09, 0x01, 0x80, 0x20, 0x00, 0x11}} - s := http2summarizeFrame(f) - assertContains(t, s, "len=0", true) - - f = &http2DataFrame{http2FrameHeader: fh} - s = http2summarizeFrame(f) - assertContains(t, s, `data=""`, true) - - f = &http2WindowUpdateFrame{http2FrameHeader: fh} - s = http2summarizeFrame(f) - assertContains(t, s, "conn", true) - - f = &http2PingFrame{http2FrameHeader: fh} - s = http2summarizeFrame(f) - assertContains(t, s, "ping", true) - - f = &http2GoAwayFrame{http2FrameHeader: fh} - s = http2summarizeFrame(f) - assertContains(t, s, "laststreamid", true) - - f = &http2RSTStreamFrame{http2FrameHeader: fh} - s = http2summarizeFrame(f) - assertContains(t, s, "no_error", true) + fh := FrameHeader{valid: true} + var f Frame + f = &SettingsFrame{FrameHeader: fh, p: []byte{0x09, 0x01, 0x80, 0x20, 0x00, 0x11}} + s := summarizeFrame(f) + tests.AssertContains(t, s, "len=0", true) + + f = &DataFrame{FrameHeader: fh} + s = summarizeFrame(f) + tests.AssertContains(t, s, `data=""`, true) + + f = &WindowUpdateFrame{FrameHeader: fh} + s = summarizeFrame(f) + tests.AssertContains(t, s, "conn", true) + + f = &PingFrame{FrameHeader: fh} + s = summarizeFrame(f) + tests.AssertContains(t, s, "ping", true) + + f = &GoAwayFrame{FrameHeader: fh} + s = summarizeFrame(f) + tests.AssertContains(t, s, "laststreamid", true) + + f = &RSTStreamFrame{FrameHeader: fh} + s = summarizeFrame(f) + tests.AssertContains(t, s, "no_error", true) } func TestParseDataFrame(t *testing.T) { - fh := http2FrameHeader{valid: true} + fh := FrameHeader{valid: true} countError := func(string) {} - _, err := http2parseDataFrame(nil, fh, countError, nil) - assertErrorContains(t, err, "DATA frame with stream ID 0") + _, err := parseDataFrame(nil, fh, countError, nil) + tests.AssertErrorContains(t, err, "DATA frame with stream ID 0") fh.StreamID = 1 - fh.Flags = http2FlagDataPadded - fc := &http2frameCache{} + fh.Flags = FlagDataPadded + fc := &frameCache{} payload := []byte{0x09, 0x00, 0x00, 0x98, 0x11, 0x12} - _, err = http2parseDataFrame(fc, fh, countError, payload) - assertErrorContains(t, err, "pad size larger than data payload") + _, err = parseDataFrame(fc, fh, countError, payload) + tests.AssertErrorContains(t, err, "pad size larger than data payload") payload = []byte{0x02, 0x00, 0x00, 0x98, 0x11, 0x12} - _, err = http2parseDataFrame(fc, fh, countError, payload) - assertNoError(t, err) + _, err = parseDataFrame(fc, fh, countError, payload) + tests.AssertNoError(t, err) } func TestParseWindowUpdateFrame(t *testing.T) { - fh := http2FrameHeader{valid: true} + fh := FrameHeader{valid: true} countError := func(string) {} - _, err := http2parseWindowUpdateFrame(nil, fh, countError, nil) - assertErrorContains(t, err, "FRAME_SIZE_ERROR") + _, err := parseWindowUpdateFrame(nil, fh, countError, nil) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") p := []byte{0x00, 0x00, 0x00, 0x00} - _, err = http2parseWindowUpdateFrame(nil, fh, countError, p) - assertErrorContains(t, err, "PROTOCOL_ERROR") + _, err = parseWindowUpdateFrame(nil, fh, countError, p) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") fh.StreamID = 255 p[0] = 0x01 p[3] = 0x01 - _, err = http2parseWindowUpdateFrame(nil, fh, countError, p) - assertNoError(t, err) + _, err = parseWindowUpdateFrame(nil, fh, countError, p) + tests.AssertNoError(t, err) } func TestParseUnknownFrame(t *testing.T) { - fh := http2FrameHeader{valid: true} + fh := FrameHeader{valid: true} countError := func(string) {} p := []byte("test") - f, err := http2parseUnknownFrame(nil, fh, countError, p) - assertNoError(t, err) - uf, ok := f.(*http2UnknownFrame) + f, err := parseUnknownFrame(nil, fh, countError, p) + tests.AssertNoError(t, err) + uf, ok := f.(*UnknownFrame) if !ok { - t.Fatalf("not http2UnknownFrame type: %#+v", f) + t.Fatalf("not UnknownFrame type: %#+v", f) } - assertEqual(t, p, uf.Payload()) + tests.AssertEqual(t, p, uf.Payload()) } func TestParseRSTStreamFrame(t *testing.T) { - fh := http2FrameHeader{valid: true} + fh := FrameHeader{valid: true} countError := func(string) {} p := []byte("test.") - _, err := http2parseRSTStreamFrame(nil, fh, countError, p) - assertErrorContains(t, err, "FRAME_SIZE_ERROR") + _, err := parseRSTStreamFrame(nil, fh, countError, p) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") p = []byte("test") - _, err = http2parseRSTStreamFrame(nil, fh, countError, p) - assertErrorContains(t, err, "PROTOCOL_ERROR") + _, err = parseRSTStreamFrame(nil, fh, countError, p) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") fh.StreamID = 1 - _, err = http2parseRSTStreamFrame(nil, fh, countError, p) - assertNoError(t, err) + _, err = parseRSTStreamFrame(nil, fh, countError, p) + tests.AssertNoError(t, err) } func TestParsePingFrame(t *testing.T) { - fh := http2FrameHeader{valid: true} + fh := FrameHeader{valid: true} countError := func(string) {} payload := []byte("") - _, err := http2parsePingFrame(nil, fh, countError, payload) - assertErrorContains(t, err, "FRAME_SIZE_ERROR") + _, err := parsePingFrame(nil, fh, countError, payload) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") payload = []byte("testtest") fh.StreamID = 1 - _, err = http2parsePingFrame(nil, fh, countError, payload) - assertErrorContains(t, err, "PROTOCOL_ERROR") + _, err = parsePingFrame(nil, fh, countError, payload) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") fh.StreamID = 0 - _, err = http2parsePingFrame(nil, fh, countError, payload) - assertNoError(t, err) + _, err = parsePingFrame(nil, fh, countError, payload) + tests.AssertNoError(t, err) } func TestParseGoAwayFrame(t *testing.T) { - fh := http2FrameHeader{valid: true} + fh := FrameHeader{valid: true} countError := func(string) {} payload := []byte("") fh.StreamID = 1 - _, err := http2parseGoAwayFrame(nil, fh, countError, payload) - assertErrorContains(t, err, "PROTOCOL_ERROR") + _, err := parseGoAwayFrame(nil, fh, countError, payload) + tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") fh.StreamID = 0 - _, err = http2parseGoAwayFrame(nil, fh, countError, payload) - assertErrorContains(t, err, "FRAME_SIZE_ERROR") + _, err = parseGoAwayFrame(nil, fh, countError, payload) + tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") } func TestPushPromiseFrame(t *testing.T) { - fh := http2FrameHeader{valid: true} + fh := FrameHeader{valid: true} buf := []byte("test") - f := &http2PushPromiseFrame{http2FrameHeader: fh, headerFragBuf: buf} - assertEqual(t, buf, f.HeaderBlockFragment()) - assertEqual(t, false, f.HeadersEnded()) + f := &PushPromiseFrame{FrameHeader: fh, headerFragBuf: buf} + tests.AssertEqual(t, buf, f.HeaderBlockFragment()) + tests.AssertEqual(t, false, f.HeadersEnded()) } func TestH2Framer(t *testing.T) { - f := &http2Framer{} + f := &Framer{} f.debugWriteLoggerf = func(s string, i ...interface{}) {} f.logWrite() - assertNotNil(t, f.debugFramer) - assertIsNil(t, f.ErrorDetail()) + tests.AssertNotNil(t, f.debugFramer) + tests.AssertIsNil(t, f.ErrorDetail()) f.w = new(bytes.Buffer) - err := f.WriteRawFrame(http2FrameData, http2FlagDataEndStream, 1, nil) - assertNoError(t, err) + err := f.WriteRawFrame(FrameData, FlagDataEndStream, 1, nil) + tests.AssertNoError(t, err) - param := http2PushPromiseParam{} + param := PushPromiseParam{} err = f.WritePushPromise(param) - assertErrorContains(t, err, "invalid stream ID") + tests.AssertErrorContains(t, err, "invalid stream ID") param.StreamID = 1 param.EndHeaders = true param.PadLength = 2 f.AllowIllegalWrites = true err = f.WritePushPromise(param) - assertNoError(t, err) + tests.AssertNoError(t, err) } diff --git a/h2_go115.go b/internal/http2/go115.go similarity index 63% rename from h2_go115.go rename to internal/http2/go115.go index 66d1cfbb..69769559 100644 --- a/h2_go115.go +++ b/internal/http2/go115.go @@ -5,16 +5,17 @@ //go:build go1.15 // +build go1.15 -package req +package http2 import ( "context" "crypto/tls" + reqtls "github.com/imroc/req/v3/internal/tls" ) // dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS // connection. -func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (TLSConn, error) { +func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (reqtls.Conn, error) { dialer := &tls.Dialer{ Config: cfg, } @@ -22,6 +23,6 @@ func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr s if err != nil { return nil, err } - tlsCn := cn.(TLSConn) // DialContext comment promises this will always succeed + tlsCn := cn.(reqtls.Conn) // DialContext comment promises this will always succeed return tlsCn, nil } diff --git a/h2_gotrack.go b/internal/http2/gotrack.go similarity index 73% rename from h2_gotrack.go rename to internal/http2/gotrack.go index 1a4656b7..9933c9f8 100644 --- a/h2_gotrack.go +++ b/internal/http2/gotrack.go @@ -5,7 +5,7 @@ // Defensive debug-only utility to track that functions run on the // goroutine that they're supposed to. -package req +package http2 import ( "bytes" @@ -17,57 +17,57 @@ import ( "sync" ) -var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" +var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" -type http2goroutineLock uint64 +type goroutineLock uint64 -func http2newGoroutineLock() http2goroutineLock { - if !http2DebugGoroutines { +func newGoroutineLock() goroutineLock { + if !DebugGoroutines { return 0 } - return http2goroutineLock(http2curGoroutineID()) + return goroutineLock(curGoroutineID()) } -func (g http2goroutineLock) check() { - if !http2DebugGoroutines { +func (g goroutineLock) check() { + if !DebugGoroutines { return } - if http2curGoroutineID() != uint64(g) { + if curGoroutineID() != uint64(g) { panic("running on the wrong goroutine") } } -func (g http2goroutineLock) checkNotOn() { - if !http2DebugGoroutines { +func (g goroutineLock) checkNotOn() { + if !DebugGoroutines { return } - if http2curGoroutineID() == uint64(g) { + if curGoroutineID() == uint64(g) { panic("running on the wrong goroutine") } } -var http2goroutineSpace = []byte("goroutine ") +var goroutineSpace = []byte("goroutine ") -func http2curGoroutineID() uint64 { - bp := http2littleBuf.Get().(*[]byte) - defer http2littleBuf.Put(bp) +func curGoroutineID() uint64 { + bp := littleBuf.Get().(*[]byte) + defer littleBuf.Put(bp) b := *bp b = b[:runtime.Stack(b, false)] // Parse the 4707 out of "goroutine 4707 [" - b = bytes.TrimPrefix(b, http2goroutineSpace) + b = bytes.TrimPrefix(b, goroutineSpace) i := bytes.IndexByte(b, ' ') if i < 0 { panic(fmt.Sprintf("No space found in %q", b)) } b = b[:i] - n, err := http2parseUintBytes(b, 10, 64) + n, err := parseUintBytes(b, 10, 64) if err != nil { panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err)) } return n } -var http2littleBuf = sync.Pool{ +var littleBuf = sync.Pool{ New: func() interface{} { buf := make([]byte, 64) return &buf @@ -75,7 +75,7 @@ var http2littleBuf = sync.Pool{ } // parseUintBytes is like strconv.ParseUint, but using a []byte. -func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) { +func parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) { var cutoff, maxVal uint64 if bitSize == 0 { @@ -113,7 +113,7 @@ func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) } n = 0 - cutoff = http2cutoff64(base) + cutoff = cutoff64(base) maxVal = 1<= 1<<64. -func http2cutoff64(base int) uint64 { +func cutoff64(base int) uint64 { if base < 2 { return 0 } diff --git a/h2_gotrack_test.go b/internal/http2/gotrack_test.go similarity index 58% rename from h2_gotrack_test.go rename to internal/http2/gotrack_test.go index 3d0ccdc1..55d2d3a1 100644 --- a/h2_gotrack_test.go +++ b/internal/http2/gotrack_test.go @@ -2,20 +2,21 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "fmt" + "github.com/imroc/req/v3/internal/tests" "strings" "testing" ) func TestGoroutineLock(t *testing.T) { - oldDebug := http2DebugGoroutines - http2DebugGoroutines = true - defer func() { http2DebugGoroutines = oldDebug }() + oldDebug := DebugGoroutines + DebugGoroutines = true + defer func() { DebugGoroutines = oldDebug }() - g := http2newGoroutineLock() + g := newGoroutineLock() g.check() sawPanic := make(chan interface{}) @@ -34,22 +35,22 @@ func TestGoroutineLock(t *testing.T) { func TestParseUintBytes(t *testing.T) { s := []byte{} - _, err := http2parseUintBytes(s, 0, 0) - assertErrorContains(t, err, "invalid syntax") + _, err := parseUintBytes(s, 0, 0) + tests.AssertErrorContains(t, err, "invalid syntax") s = []byte("0x") - _, err = http2parseUintBytes(s, 0, 0) - assertErrorContains(t, err, "invalid syntax") + _, err = parseUintBytes(s, 0, 0) + tests.AssertErrorContains(t, err, "invalid syntax") s = []byte("0x01") - _, err = http2parseUintBytes(s, 0, 0) - assertNoError(t, err) + _, err = parseUintBytes(s, 0, 0) + tests.AssertNoError(t, err) s = []byte("0xa1") - _, err = http2parseUintBytes(s, 0, 0) - assertNoError(t, err) + _, err = parseUintBytes(s, 0, 0) + tests.AssertNoError(t, err) s = []byte("0xA1") - _, err = http2parseUintBytes(s, 0, 0) - assertNoError(t, err) + _, err = parseUintBytes(s, 0, 0) + tests.AssertNoError(t, err) } diff --git a/h2_headermap.go b/internal/http2/headermap.go similarity index 64% rename from h2_headermap.go rename to internal/http2/headermap.go index a78e405b..50431a6d 100644 --- a/h2_headermap.go +++ b/internal/http2/headermap.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "github.com/imroc/req/v3/internal/ascii" @@ -11,16 +11,16 @@ import ( ) var ( - http2commonBuildOnce sync.Once - http2commonLowerHeader map[string]string // Go-Canonical-Case -> lower-case - http2commonCanonHeader map[string]string // lower-case -> Go-Canonical-Case + commonBuildOnce sync.Once + commonLowerHeader map[string]string // Go-Canonical-Case -> lower-case + commonCanonHeader map[string]string // lower-case -> Go-Canonical-Case ) -func http2buildCommonHeaderMapsOnce() { - http2commonBuildOnce.Do(http2buildCommonHeaderMaps) +func buildCommonHeaderMapsOnce() { + commonBuildOnce.Do(buildCommonHeaderMaps) } -func http2buildCommonHeaderMaps() { +func buildCommonHeaderMaps() { common := []string{ "accept", "accept-charset", @@ -70,18 +70,18 @@ func http2buildCommonHeaderMaps() { "via", "www-authenticate", } - http2commonLowerHeader = make(map[string]string, len(common)) - http2commonCanonHeader = make(map[string]string, len(common)) + commonLowerHeader = make(map[string]string, len(common)) + commonCanonHeader = make(map[string]string, len(common)) for _, v := range common { chk := http.CanonicalHeaderKey(v) - http2commonLowerHeader[chk] = v - http2commonCanonHeader[v] = chk + commonLowerHeader[chk] = v + commonCanonHeader[v] = chk } } -func http2lowerHeader(v string) (lower string, isAscii bool) { - http2buildCommonHeaderMapsOnce() - if s, ok := http2commonLowerHeader[v]; ok { +func lowerHeader(v string) (lower string, isAscii bool) { + buildCommonHeaderMapsOnce() + if s, ok := commonLowerHeader[v]; ok { return s, true } return ascii.ToLower(v) diff --git a/h2.go b/internal/http2/http2.go similarity index 58% rename from h2.go rename to internal/http2/http2.go index 2665d7f8..253cc5a4 100644 --- a/h2.go +++ b/internal/http2/http2.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "bufio" @@ -18,72 +18,72 @@ import ( ) var ( - http2VerboseLogs bool - http2logFrameWrites bool - http2logFrameReads bool - http2inTests bool + VerboseLogs bool + logFrameWrites bool + logFrameReads bool + inTests bool ) func init() { e := os.Getenv("GODEBUG") if strings.Contains(e, "http2debug=1") { - http2VerboseLogs = true + VerboseLogs = true } if strings.Contains(e, "http2debug=2") { - http2VerboseLogs = true - http2logFrameWrites = true - http2logFrameReads = true + VerboseLogs = true + logFrameWrites = true + logFrameReads = true } } const ( // ClientPreface is the string that must be sent by new // connections from clients. - http2ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" // NextProtoTLS is the NPN/ALPN protocol negotiated during // HTTP/2's TLS setup. - http2NextProtoTLS = "h2" + NextProtoTLS = "h2" // http://http2.github.io/http2-spec/#SettingValues - http2initialHeaderTableSize = 4096 + initialHeaderTableSize = 4096 - http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size + initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size ) var ( - http2clientPreface = []byte(http2ClientPreface) + clientPreface = []byte(ClientPreface) ) // Setting is a setting parameter: which setting it is, and its value. -type http2Setting struct { +type Setting struct { // ID is which setting is being set. // See http://http2.github.io/http2-spec/#SettingValues - ID http2SettingID + ID SettingID // Val is the value. Val uint32 } -func (s http2Setting) String() string { +func (s Setting) String() string { return fmt.Sprintf("[%v = %d]", s.ID, s.Val) } // Valid reports whether the setting is valid. -func (s http2Setting) Valid() error { +func (s Setting) Valid() error { // Limits and error codes from 6.5.2 Defined SETTINGS Parameters switch s.ID { - case http2SettingEnablePush: + case SettingEnablePush: if s.Val != 1 && s.Val != 0 { - return http2ConnectionError(http2ErrCodeProtocol) + return ConnectionError(ErrCodeProtocol) } - case http2SettingInitialWindowSize: + case SettingInitialWindowSize: if s.Val > 1<<31-1 { - return http2ConnectionError(http2ErrCodeFlowControl) + return ConnectionError(ErrCodeFlowControl) } - case http2SettingMaxFrameSize: + case SettingMaxFrameSize: if s.Val < 16384 || s.Val > 1<<24-1 { - return http2ConnectionError(http2ErrCodeProtocol) + return ConnectionError(ErrCodeProtocol) } } return nil @@ -91,28 +91,28 @@ func (s http2Setting) Valid() error { // A SettingID is an HTTP/2 setting as defined in // http://http2.github.io/http2-spec/#iana-settings -type http2SettingID uint16 +type SettingID uint16 const ( - http2SettingHeaderTableSize http2SettingID = 0x1 - http2SettingEnablePush http2SettingID = 0x2 - http2SettingMaxConcurrentStreams http2SettingID = 0x3 - http2SettingInitialWindowSize http2SettingID = 0x4 - http2SettingMaxFrameSize http2SettingID = 0x5 - http2SettingMaxHeaderListSize http2SettingID = 0x6 + SettingHeaderTableSize SettingID = 0x1 + SettingEnablePush SettingID = 0x2 + SettingMaxConcurrentStreams SettingID = 0x3 + SettingInitialWindowSize SettingID = 0x4 + SettingMaxFrameSize SettingID = 0x5 + SettingMaxHeaderListSize SettingID = 0x6 ) -var http2settingName = map[http2SettingID]string{ - http2SettingHeaderTableSize: "HEADER_TABLE_SIZE", - http2SettingEnablePush: "ENABLE_PUSH", - http2SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", - http2SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", - http2SettingMaxFrameSize: "MAX_FRAME_SIZE", - http2SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", +var settingName = map[SettingID]string{ + SettingHeaderTableSize: "HEADER_TABLE_SIZE", + SettingEnablePush: "ENABLE_PUSH", + SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", + SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", + SettingMaxFrameSize: "MAX_FRAME_SIZE", + SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", } -func (s http2SettingID) String() string { - if v, ok := http2settingName[s]; ok { +func (s SettingID) String() string { + if v, ok := settingName[s]; ok { return v } return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) @@ -126,7 +126,7 @@ func (s http2SettingID) String() string { // characters that are compared in a case-insensitive // fashion. However, header field names MUST be converted to // lowercase prior to their encoding in HTTP/2. " -func http2validWireHeaderFieldName(v string) bool { +func validWireHeaderFieldName(v string) bool { if len(v) == 0 { return false } @@ -141,7 +141,7 @@ func http2validWireHeaderFieldName(v string) bool { return true } -func http2httpCodeString(code int) string { +func httpCodeString(code int) string { switch code { case 200: return "200" @@ -157,15 +157,15 @@ func http2httpCodeString(code int) string { // TODO: pick a less arbitrary value? this is a bit under // (3 x typical 1500 byte MTU) at least. Other than that, // not much thought went into it. -const http2bufWriterPoolBufferSize = 4 << 10 +const bufWriterPoolBufferSize = 4 << 10 -var http2bufWriterPool = sync.Pool{ +var bufWriterPool = sync.Pool{ New: func() interface{} { - return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize) + return bufio.NewWriterSize(nil, bufWriterPoolBufferSize) }, } -func http2mustUint31(v int32) uint32 { +func mustUint31(v int32) uint32 { if v < 0 || v > 2147483647 { panic("out of range") } @@ -174,7 +174,7 @@ func http2mustUint31(v int32) uint32 { // bodyAllowedForStatus reports whether a given response status code // permits a body. See RFC 7230, section 3.3. -func http2bodyAllowedForStatus(status int) bool { +func bodyAllowedForStatus(status int) bool { switch { case status >= 100 && status <= 199: return false @@ -186,41 +186,41 @@ func http2bodyAllowedForStatus(status int) bool { return true } -type http2httpError struct { - _ http2incomparable +type httpError struct { + _ incomparable msg string timeout bool } -func (e *http2httpError) Error() string { return e.msg } +func (e *httpError) Error() string { return e.msg } -func (e *http2httpError) Timeout() bool { return e.timeout } +func (e *httpError) Timeout() bool { return e.timeout } -func (e *http2httpError) Temporary() bool { return true } +func (e *httpError) Temporary() bool { return true } -var errH2Timeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true} +var errH2Timeout error = &httpError{msg: "http2: timeout awaiting response headers", timeout: true} -type http2connectionStater interface { +type connectionStater interface { ConnectionState() tls.ConnectionState } -var http2sorterPool = sync.Pool{New: func() interface{} { return new(http2sorter) }} +var sorterPool = sync.Pool{New: func() interface{} { return new(sorter) }} -type http2sorter struct { +type sorter struct { v []string // owned by sorter } -func (s *http2sorter) Len() int { return len(s.v) } +func (s *sorter) Len() int { return len(s.v) } -func (s *http2sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] } +func (s *sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] } -func (s *http2sorter) Less(i, j int) bool { return s.v[i] < s.v[j] } +func (s *sorter) Less(i, j int) bool { return s.v[i] < s.v[j] } // Keys returns the sorted keys of h. // // The returned slice is only valid until s used again or returned to // its pool. -func (s *http2sorter) Keys(h http.Header) []string { +func (s *sorter) Keys(h http.Header) []string { keys := s.v[:0] for k := range h { keys = append(keys, k) @@ -230,7 +230,7 @@ func (s *http2sorter) Keys(h http.Header) []string { return keys } -func (s *http2sorter) SortStrings(ss []string) { +func (s *sorter) SortStrings(ss []string) { // Our sorter works on s.v, which sorter owns, so // stash it away while we sort the user's buffer. save := s.v @@ -252,11 +252,11 @@ func (s *http2sorter) SortStrings(ss []string) { // We used to enforce that the path also didn't start with "//", but // Google's GFE accepts such paths and Chrome sends them, so ignore // that part of the spec. See golang.org/issue/19103. -func http2validPseudoPath(v string) bool { +func validPseudoPath(v string) bool { return (len(v) > 0 && v[0] == '/') || v == "*" } // incomparable is a zero-width, non-comparable type. Adding it to a struct // makes that struct also non-comparable, and generally doesn't add // any size (as long as it's first). -type http2incomparable [0]func() +type incomparable [0]func() diff --git a/h2_test.go b/internal/http2/http2_test.go similarity index 63% rename from h2_test.go rename to internal/http2/http2_test.go index f96928ac..c905758d 100644 --- a/h2_test.go +++ b/internal/http2/http2_test.go @@ -2,29 +2,30 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "flag" "fmt" + "github.com/imroc/req/v3/internal/tests" "net/http" "testing" "time" ) func init() { - http2inTests = true - http2DebugGoroutines = true - flag.BoolVar(&http2VerboseLogs, "verboseh2", http2VerboseLogs, "Verbose HTTP/2 debug logging") + inTests = true + DebugGoroutines = true + flag.BoolVar(&VerboseLogs, "verboseh2", VerboseLogs, "Verbose HTTP/2 debug logging") } func TestSettingString(t *testing.T) { tests := []struct { - s http2Setting + s Setting want string }{ - {http2Setting{http2SettingMaxFrameSize, 123}, "[MAX_FRAME_SIZE = 123]"}, - {http2Setting{1<<16 - 1, 123}, "[UNKNOWN_SETTING_65535 = 123]"}, + {Setting{SettingMaxFrameSize, 123}, "[MAX_FRAME_SIZE = 123]"}, + {Setting{1<<16 - 1, 123}, "[UNKNOWN_SETTING_65535 = 123]"}, } for i, tt := range tests { got := fmt.Sprint(tt.s) @@ -47,7 +48,7 @@ func TestSorterPoolAllocs(t *testing.T) { "b": nil, "c": nil, } - sorter := new(http2sorter) + sorter := new(sorter) if allocs := testing.AllocsPerRun(100, func() { sorter.SortStrings(ss) @@ -80,39 +81,39 @@ func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { func TestSettingValid(t *testing.T) { cases := []struct { - id http2SettingID + id SettingID val uint32 }{ { - id: http2SettingEnablePush, + id: SettingEnablePush, val: 2, }, { - id: http2SettingInitialWindowSize, + id: SettingInitialWindowSize, val: 1 << 31, }, { - id: http2SettingMaxFrameSize, + id: SettingMaxFrameSize, val: 0, }, } for _, c := range cases { - s := &http2Setting{ID: c.id, Val: c.val} - assertEqual(t, true, s.Valid() != nil) + s := &Setting{ID: c.id, Val: c.val} + tests.AssertEqual(t, true, s.Valid() != nil) } - s := &http2Setting{ID: http2SettingMaxHeaderListSize} - assertEqual(t, true, s.Valid() == nil) + s := &Setting{ID: SettingMaxHeaderListSize} + tests.AssertEqual(t, true, s.Valid() == nil) } func TestBodyAllowedForStatus(t *testing.T) { - assertEqual(t, false, http2bodyAllowedForStatus(101)) - assertEqual(t, false, http2bodyAllowedForStatus(204)) - assertEqual(t, false, http2bodyAllowedForStatus(304)) - assertEqual(t, true, http2bodyAllowedForStatus(900)) + tests.AssertEqual(t, false, bodyAllowedForStatus(101)) + tests.AssertEqual(t, false, bodyAllowedForStatus(204)) + tests.AssertEqual(t, false, bodyAllowedForStatus(304)) + tests.AssertEqual(t, true, bodyAllowedForStatus(900)) } func TestHttpError(t *testing.T) { - e := &http2httpError{msg: "test"} - assertEqual(t, "test", e.Error()) - assertEqual(t, true, e.Temporary()) + e := &httpError{msg: "test"} + tests.AssertEqual(t, "test", e.Error()) + tests.AssertEqual(t, true, e.Temporary()) } diff --git a/h2_not_go115.go b/internal/http2/not_go115.go similarity index 79% rename from h2_not_go115.go rename to internal/http2/not_go115.go index cebaf62b..47349d62 100644 --- a/h2_not_go115.go +++ b/internal/http2/not_go115.go @@ -5,10 +5,10 @@ //go:build !go1.15 // +build !go1.15 -package req +package http2 // dialTLSWithContext opens a TLS connection. -func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (TLSConn, error) { +func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (TLSConn, error) { cn, err := tls.Dial(network, addr, cfg) if err != nil { return nil, err diff --git a/h2_pipe.go b/internal/http2/pipe.go similarity index 73% rename from h2_pipe.go rename to internal/http2/pipe.go index 56bd8e9c..c15b8a77 100644 --- a/h2_pipe.go +++ b/internal/http2/pipe.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "errors" @@ -13,18 +13,18 @@ import ( // pipe is a goroutine-safe io.Reader/io.Writer pair. It's like // io.Pipe except there are no PipeReader/PipeWriter halves, and the // underlying buffer is an interface. (io.Pipe is always unbuffered) -type http2pipe struct { +type pipe struct { mu sync.Mutex - c sync.Cond // c.L lazily initialized to &p.mu - b http2pipeBuffer // nil when done reading - unread int // bytes unread when done - err error // read error once empty. non-nil means closed. - breakErr error // immediate read error (caller doesn't see rest of b) - donec chan struct{} // closed on error - readFn func() // optional code to run in Read before error + c sync.Cond // c.L lazily initialized to &p.mu + b pipeBuffer // nil when done reading + unread int // bytes unread when done + err error // read error once empty. non-nil means closed. + breakErr error // immediate read error (caller doesn't see rest of b) + donec chan struct{} // closed on error + readFn func() // optional code to run in Read before error } -type http2pipeBuffer interface { +type pipeBuffer interface { Len() int io.Writer io.Reader @@ -32,7 +32,7 @@ type http2pipeBuffer interface { // setBuffer initializes the pipe buffer. // It has no effect if the pipe is already closed. -func (p *http2pipe) setBuffer(b http2pipeBuffer) { +func (p *pipe) setBuffer(b pipeBuffer) { p.mu.Lock() defer p.mu.Unlock() if p.err != nil || p.breakErr != nil { @@ -41,7 +41,7 @@ func (p *http2pipe) setBuffer(b http2pipeBuffer) { p.b = b } -func (p *http2pipe) Len() int { +func (p *pipe) Len() int { p.mu.Lock() defer p.mu.Unlock() if p.b == nil { @@ -52,7 +52,7 @@ func (p *http2pipe) Len() int { // Read waits until data is available and copies bytes // from the buffer into p. -func (p *http2pipe) Read(d []byte) (n int, err error) { +func (p *pipe) Read(d []byte) (n int, err error) { p.mu.Lock() defer p.mu.Unlock() if p.c.L == nil { @@ -81,7 +81,7 @@ var errClosedPipeWrite = errors.New("write on closed buffer") // Write copies bytes from p into the buffer and wakes a reader. // It is an error to write more data than the buffer can hold. -func (p *http2pipe) Write(d []byte) (n int, err error) { +func (p *pipe) Write(d []byte) (n int, err error) { p.mu.Lock() defer p.mu.Unlock() if p.c.L == nil { @@ -103,18 +103,18 @@ func (p *http2pipe) Write(d []byte) (n int, err error) { // read. // // The error must be non-nil. -func (p *http2pipe) CloseWithError(err error) { p.closeWithError(&p.err, err, nil) } +func (p *pipe) CloseWithError(err error) { p.closeWithError(&p.err, err, nil) } // BreakWithError causes the next Read (waking up a current blocked // Read if needed) to return the provided err immediately, without // waiting for unread data. -func (p *http2pipe) BreakWithError(err error) { p.closeWithError(&p.breakErr, err, nil) } +func (p *pipe) BreakWithError(err error) { p.closeWithError(&p.breakErr, err, nil) } // closeWithErrorAndCode is like CloseWithError but also sets some code to run // in the caller's goroutine before returning the error. -func (p *http2pipe) closeWithErrorAndCode(err error, fn func()) { p.closeWithError(&p.err, err, fn) } +func (p *pipe) closeWithErrorAndCode(err error, fn func()) { p.closeWithError(&p.err, err, fn) } -func (p *http2pipe) closeWithError(dst *error, err error, fn func()) { +func (p *pipe) closeWithError(dst *error, err error, fn func()) { if err == nil { panic("err must be non-nil") } @@ -140,7 +140,7 @@ func (p *http2pipe) closeWithError(dst *error, err error, fn func()) { } // requires p.mu be held. -func (p *http2pipe) closeDoneLocked() { +func (p *pipe) closeDoneLocked() { if p.donec == nil { return } @@ -154,7 +154,7 @@ func (p *http2pipe) closeDoneLocked() { } // Err returns the error (if any) first set by BreakWithError or CloseWithError. -func (p *http2pipe) Err() error { +func (p *pipe) Err() error { p.mu.Lock() defer p.mu.Unlock() if p.breakErr != nil { @@ -165,7 +165,7 @@ func (p *http2pipe) Err() error { // Done returns a channel which is closed if and when this pipe is closed // with CloseWithError. -func (p *http2pipe) Done() <-chan struct{} { +func (p *pipe) Done() <-chan struct{} { p.mu.Lock() defer p.mu.Unlock() if p.donec == nil { diff --git a/h2_pipe_test.go b/internal/http2/pipe_test.go similarity index 94% rename from h2_pipe_test.go rename to internal/http2/pipe_test.go index 434c7c15..83d2dfd2 100644 --- a/h2_pipe_test.go +++ b/internal/http2/pipe_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "bytes" @@ -13,7 +13,7 @@ import ( ) func TestPipeClose(t *testing.T) { - var p http2pipe + var p pipe p.b = new(bytes.Buffer) a := errors.New("a") b := errors.New("b") @@ -26,7 +26,7 @@ func TestPipeClose(t *testing.T) { } func TestPipeDoneChan(t *testing.T) { - var p http2pipe + var p pipe done := p.Done() select { case <-done: @@ -42,7 +42,7 @@ func TestPipeDoneChan(t *testing.T) { } func TestPipeDoneChan_ErrFirst(t *testing.T) { - var p http2pipe + var p pipe p.CloseWithError(io.EOF) done := p.Done() select { @@ -53,7 +53,7 @@ func TestPipeDoneChan_ErrFirst(t *testing.T) { } func TestPipeDoneChan_Break(t *testing.T) { - var p http2pipe + var p pipe done := p.Done() select { case <-done: @@ -69,7 +69,7 @@ func TestPipeDoneChan_Break(t *testing.T) { } func TestPipeDoneChan_Break_ErrFirst(t *testing.T) { - var p http2pipe + var p pipe p.BreakWithError(io.EOF) done := p.Done() select { @@ -80,7 +80,7 @@ func TestPipeDoneChan_Break_ErrFirst(t *testing.T) { } func TestPipeCloseWithError(t *testing.T) { - p := &http2pipe{b: new(bytes.Buffer)} + p := &pipe{b: new(bytes.Buffer)} const body = "foo" io.WriteString(p, body) a := errors.New("test error") @@ -108,7 +108,7 @@ func TestPipeCloseWithError(t *testing.T) { } func TestPipeBreakWithError(t *testing.T) { - p := &http2pipe{b: new(bytes.Buffer)} + p := &pipe{b: new(bytes.Buffer)} io.WriteString(p, "foo") a := errors.New("test err") p.BreakWithError(a) diff --git a/h2_server_test.go b/internal/http2/server_test.go similarity index 61% rename from h2_server_test.go rename to internal/http2/server_test.go index 45a27f30..c0c3074b 100644 --- a/h2_server_test.go +++ b/internal/http2/server_test.go @@ -1,4 +1,4 @@ -package req +package http2 import ( "bufio" @@ -38,283 +38,283 @@ import ( // https://www.iana.org/assignments/tls-parameters/tls-parameters.txt const ( - http2cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000 - http2cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001 - http2cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002 - http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003 - http2cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004 - http2cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 - http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006 - http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007 - http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008 - http2cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009 - http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A - http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B - http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C - http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D - http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E - http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F - http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010 - http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011 - http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012 - http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013 - http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014 - http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015 - http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016 - http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017 - http2cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018 - http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019 - http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A - http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B - http2cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E - http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F - http2cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020 - http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021 - http2cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022 - http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023 - http2cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024 - http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025 - http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026 - http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027 - http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028 - http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029 - http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A - http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B - http2cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E - http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F - http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030 - http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031 - http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032 - http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033 - http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034 - http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 - http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036 - http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037 - http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038 - http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039 - http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A - http2cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B - http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C - http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D - http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E - http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F - http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040 - http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042 - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043 - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044 - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046 - http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067 - http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068 - http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069 - http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A - http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B - http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C - http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D - http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085 - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086 - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087 - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089 - http2cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A - http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B - http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C - http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D - http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E - http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F - http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090 - http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091 - http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092 - http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093 - http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094 - http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095 - http2cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096 - http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097 - http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098 - http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099 - http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A - http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B - http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C - http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D - http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0 - http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1 - http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4 - http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5 - http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6 - http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7 - http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8 - http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9 - http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC - http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD - http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE - http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF - http2cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0 - http2cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1 - http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2 - http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3 - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4 - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5 - http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6 - http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7 - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8 - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9 - http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF - http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1 - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2 - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3 - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5 - http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF - http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA uint16 = 0xC001 - http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA uint16 = 0xC002 - http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC003 - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC004 - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC005 - http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA uint16 = 0xC006 - http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xC007 - http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC008 - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC009 - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC00A - http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA uint16 = 0xC00B - http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA uint16 = 0xC00C - http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC00D - http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC00E - http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC00F - http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA uint16 = 0xC010 - http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xC011 - http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC012 - http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC013 - http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC014 - http2cipher_TLS_ECDH_anon_WITH_NULL_SHA uint16 = 0xC015 - http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA uint16 = 0xC016 - http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0xC017 - http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA uint16 = 0xC018 - http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA uint16 = 0xC019 - http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01A - http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01B - http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01C - http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA uint16 = 0xC01D - http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC01E - http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA uint16 = 0xC01F - http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA uint16 = 0xC020 - http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC021 - http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA uint16 = 0xC022 - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC023 - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC024 - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC025 - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC026 - http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC027 - http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC028 - http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC029 - http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC02A - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02D - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02E - http2cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02F - http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC031 - http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC032 - http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA uint16 = 0xC033 - http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0xC034 - http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0xC035 - http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0xC036 - http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0xC037 - http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0xC038 - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA uint16 = 0xC039 - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256 uint16 = 0xC03A - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384 uint16 = 0xC03B - http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03C - http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03D - http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03E - http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03F - http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC040 - http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC041 - http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC042 - http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC043 - http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC044 - http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC045 - http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC046 - http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC047 - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC048 - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC049 - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04A - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04B - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04C - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04D - http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04E - http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04F - http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC050 - http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC051 - http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC054 - http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC055 - http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC058 - http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC059 - http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05A - http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05B - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05E - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05F - http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC062 - http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC063 - http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC064 - http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC065 - http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC066 - http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC067 - http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC068 - http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC069 - http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06A - http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06B - http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06E - http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06F - http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC070 - http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC071 - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC072 - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC073 - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC074 - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC075 - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC076 - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC077 - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC078 - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC079 - http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07A - http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07B - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07E - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07F - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC082 - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC083 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC084 - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC085 - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC088 - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC089 - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08C - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08D - http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08E - http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08F - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC092 - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC093 - http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC094 - http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC095 - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC096 - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC097 - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC098 - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC099 - http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC09A - http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC09B - http2cipher_TLS_RSA_WITH_AES_128_CCM uint16 = 0xC09C - http2cipher_TLS_RSA_WITH_AES_256_CCM uint16 = 0xC09D - http2cipher_TLS_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A0 - http2cipher_TLS_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A1 - http2cipher_TLS_PSK_WITH_AES_128_CCM uint16 = 0xC0A4 - http2cipher_TLS_PSK_WITH_AES_256_CCM uint16 = 0xC0A5 - http2cipher_TLS_PSK_WITH_AES_128_CCM_8 uint16 = 0xC0A8 - http2cipher_TLS_PSK_WITH_AES_256_CCM_8 uint16 = 0xC0A9 + cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000 + cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001 + cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002 + cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003 + cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004 + cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006 + cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007 + cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008 + cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009 + cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A + cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B + cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C + cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D + cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E + cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F + cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010 + cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011 + cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012 + cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013 + cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014 + cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015 + cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016 + cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017 + cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018 + cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019 + cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A + cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B + cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E + cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F + cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020 + cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021 + cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022 + cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023 + cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024 + cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025 + cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026 + cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027 + cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028 + cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029 + cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A + cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B + cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C + cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D + cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E + cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F + cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030 + cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031 + cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032 + cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033 + cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034 + cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036 + cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037 + cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038 + cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039 + cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A + cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B + cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C + cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D + cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E + cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F + cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040 + cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041 + cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042 + cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043 + cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044 + cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045 + cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046 + cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067 + cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068 + cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069 + cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A + cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B + cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C + cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D + cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084 + cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085 + cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086 + cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087 + cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088 + cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089 + cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A + cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B + cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C + cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D + cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E + cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F + cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090 + cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091 + cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092 + cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093 + cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094 + cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095 + cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096 + cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097 + cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098 + cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099 + cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A + cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B + cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C + cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D + cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0 + cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1 + cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4 + cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5 + cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6 + cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7 + cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8 + cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9 + cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC + cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD + cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE + cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF + cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0 + cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1 + cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2 + cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3 + cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4 + cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5 + cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6 + cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7 + cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8 + cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9 + cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA + cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB + cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC + cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD + cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE + cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF + cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0 + cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1 + cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2 + cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3 + cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4 + cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5 + cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF + cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA uint16 = 0xC001 + cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA uint16 = 0xC002 + cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC003 + cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC004 + cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC005 + cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA uint16 = 0xC006 + cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xC007 + cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC008 + cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC009 + cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC00A + cipher_TLS_ECDH_RSA_WITH_NULL_SHA uint16 = 0xC00B + cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA uint16 = 0xC00C + cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC00D + cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC00E + cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC00F + cipher_TLS_ECDHE_RSA_WITH_NULL_SHA uint16 = 0xC010 + cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xC011 + cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC012 + cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC013 + cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC014 + cipher_TLS_ECDH_anon_WITH_NULL_SHA uint16 = 0xC015 + cipher_TLS_ECDH_anon_WITH_RC4_128_SHA uint16 = 0xC016 + cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0xC017 + cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA uint16 = 0xC018 + cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA uint16 = 0xC019 + cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01A + cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01B + cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01C + cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA uint16 = 0xC01D + cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC01E + cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA uint16 = 0xC01F + cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA uint16 = 0xC020 + cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC021 + cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA uint16 = 0xC022 + cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC023 + cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC024 + cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC025 + cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC026 + cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC027 + cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC028 + cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC029 + cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC02A + cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02D + cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02E + cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02F + cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC031 + cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC032 + cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA uint16 = 0xC033 + cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0xC034 + cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0xC035 + cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0xC036 + cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0xC037 + cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0xC038 + cipher_TLS_ECDHE_PSK_WITH_NULL_SHA uint16 = 0xC039 + cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256 uint16 = 0xC03A + cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384 uint16 = 0xC03B + cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03C + cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03D + cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03E + cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03F + cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC040 + cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC041 + cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC042 + cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC043 + cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC044 + cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC045 + cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC046 + cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC047 + cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC048 + cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC049 + cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04A + cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04B + cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04C + cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04D + cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04E + cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04F + cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC050 + cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC051 + cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC054 + cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC055 + cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC058 + cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC059 + cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05A + cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05B + cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05E + cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05F + cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC062 + cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC063 + cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC064 + cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC065 + cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC066 + cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC067 + cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC068 + cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC069 + cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06A + cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06B + cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06E + cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06F + cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC070 + cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC071 + cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC072 + cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC073 + cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC074 + cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC075 + cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC076 + cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC077 + cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC078 + cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC079 + cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07A + cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07B + cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07E + cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07F + cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC082 + cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC083 + cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC084 + cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC085 + cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC088 + cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC089 + cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08C + cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08D + cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08E + cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08F + cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC092 + cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC093 + cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC094 + cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC095 + cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC096 + cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC097 + cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC098 + cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC099 + cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC09A + cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC09B + cipher_TLS_RSA_WITH_AES_128_CCM uint16 = 0xC09C + cipher_TLS_RSA_WITH_AES_256_CCM uint16 = 0xC09D + cipher_TLS_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A0 + cipher_TLS_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A1 + cipher_TLS_PSK_WITH_AES_128_CCM uint16 = 0xC0A4 + cipher_TLS_PSK_WITH_AES_256_CCM uint16 = 0xC0A5 + cipher_TLS_PSK_WITH_AES_128_CCM_8 uint16 = 0xC0A8 + cipher_TLS_PSK_WITH_AES_256_CCM_8 uint16 = 0xC0A9 ) // isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. @@ -324,284 +324,284 @@ const ( // "This list includes those cipher suites that do not // offer an ephemeral key exchange and those that are // based on the TLS null, stream or block cipher type" -func http2isBadCipher(cipher uint16) bool { +func isBadCipher(cipher uint16) bool { switch cipher { - case http2cipher_TLS_NULL_WITH_NULL_NULL, - http2cipher_TLS_RSA_WITH_NULL_MD5, - http2cipher_TLS_RSA_WITH_NULL_SHA, - http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5, - http2cipher_TLS_RSA_WITH_RC4_128_MD5, - http2cipher_TLS_RSA_WITH_RC4_128_SHA, - http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5, - http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA, - http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_RSA_WITH_DES_CBC_SHA, - http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5, - http2cipher_TLS_DH_anon_WITH_RC4_128_MD5, - http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_KRB5_WITH_DES_CBC_SHA, - http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_KRB5_WITH_RC4_128_SHA, - http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA, - http2cipher_TLS_KRB5_WITH_DES_CBC_MD5, - http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5, - http2cipher_TLS_KRB5_WITH_RC4_128_MD5, - http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5, - http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA, - http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA, - http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA, - http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5, - http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5, - http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5, - http2cipher_TLS_PSK_WITH_NULL_SHA, - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA, - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA, - http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA, - http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA, - http2cipher_TLS_RSA_WITH_NULL_SHA256, - http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256, - http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA, - http2cipher_TLS_PSK_WITH_RC4_128_SHA, - http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA, - http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA, - http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA, - http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA, - http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA, - http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA, - http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA, - http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA, - http2cipher_TLS_RSA_WITH_SEED_CBC_SHA, - http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA, - http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA, - http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA, - http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA, - http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA, - http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_PSK_WITH_NULL_SHA256, - http2cipher_TLS_PSK_WITH_NULL_SHA384, - http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256, - http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384, - http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256, - http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384, - http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256, - http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV, - http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA, - http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA, - http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA, - http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA, - http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA, - http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA, - http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDH_anon_WITH_NULL_SHA, - http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA, - http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA, - http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA, - http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA, - http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA, - http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384, - http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA, - http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA, - http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA, - http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA, - http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, - http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA, - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256, - http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384, - http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256, - http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384, - http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384, - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256, - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384, - http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, - http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, - http2cipher_TLS_RSA_WITH_AES_128_CCM, - http2cipher_TLS_RSA_WITH_AES_256_CCM, - http2cipher_TLS_RSA_WITH_AES_128_CCM_8, - http2cipher_TLS_RSA_WITH_AES_256_CCM_8, - http2cipher_TLS_PSK_WITH_AES_128_CCM, - http2cipher_TLS_PSK_WITH_AES_256_CCM, - http2cipher_TLS_PSK_WITH_AES_128_CCM_8, - http2cipher_TLS_PSK_WITH_AES_256_CCM_8: + case cipher_TLS_NULL_WITH_NULL_NULL, + cipher_TLS_RSA_WITH_NULL_MD5, + cipher_TLS_RSA_WITH_NULL_SHA, + cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5, + cipher_TLS_RSA_WITH_RC4_128_MD5, + cipher_TLS_RSA_WITH_RC4_128_SHA, + cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5, + cipher_TLS_RSA_WITH_IDEA_CBC_SHA, + cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA, + cipher_TLS_RSA_WITH_DES_CBC_SHA, + cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA, + cipher_TLS_DH_DSS_WITH_DES_CBC_SHA, + cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA, + cipher_TLS_DH_RSA_WITH_DES_CBC_SHA, + cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA, + cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA, + cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA, + cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA, + cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5, + cipher_TLS_DH_anon_WITH_RC4_128_MD5, + cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA, + cipher_TLS_DH_anon_WITH_DES_CBC_SHA, + cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_KRB5_WITH_DES_CBC_SHA, + cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_KRB5_WITH_RC4_128_SHA, + cipher_TLS_KRB5_WITH_IDEA_CBC_SHA, + cipher_TLS_KRB5_WITH_DES_CBC_MD5, + cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5, + cipher_TLS_KRB5_WITH_RC4_128_MD5, + cipher_TLS_KRB5_WITH_IDEA_CBC_MD5, + cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA, + cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA, + cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA, + cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5, + cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5, + cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5, + cipher_TLS_PSK_WITH_NULL_SHA, + cipher_TLS_DHE_PSK_WITH_NULL_SHA, + cipher_TLS_RSA_PSK_WITH_NULL_SHA, + cipher_TLS_RSA_WITH_AES_128_CBC_SHA, + cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA, + cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA, + cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA, + cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA, + cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA, + cipher_TLS_RSA_WITH_AES_256_CBC_SHA, + cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA, + cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA, + cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA, + cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA, + cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA, + cipher_TLS_RSA_WITH_NULL_SHA256, + cipher_TLS_RSA_WITH_AES_128_CBC_SHA256, + cipher_TLS_RSA_WITH_AES_256_CBC_SHA256, + cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256, + cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256, + cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, + cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA, + cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA, + cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA, + cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA, + cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA, + cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA, + cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, + cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256, + cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256, + cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256, + cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, + cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256, + cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256, + cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA, + cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA, + cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA, + cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA, + cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA, + cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA, + cipher_TLS_PSK_WITH_RC4_128_SHA, + cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_PSK_WITH_AES_128_CBC_SHA, + cipher_TLS_PSK_WITH_AES_256_CBC_SHA, + cipher_TLS_DHE_PSK_WITH_RC4_128_SHA, + cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA, + cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA, + cipher_TLS_RSA_PSK_WITH_RC4_128_SHA, + cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA, + cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA, + cipher_TLS_RSA_WITH_SEED_CBC_SHA, + cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA, + cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA, + cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA, + cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA, + cipher_TLS_DH_anon_WITH_SEED_CBC_SHA, + cipher_TLS_RSA_WITH_AES_128_GCM_SHA256, + cipher_TLS_RSA_WITH_AES_256_GCM_SHA384, + cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256, + cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384, + cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256, + cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384, + cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256, + cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384, + cipher_TLS_PSK_WITH_AES_128_GCM_SHA256, + cipher_TLS_PSK_WITH_AES_256_GCM_SHA384, + cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256, + cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384, + cipher_TLS_PSK_WITH_AES_128_CBC_SHA256, + cipher_TLS_PSK_WITH_AES_256_CBC_SHA384, + cipher_TLS_PSK_WITH_NULL_SHA256, + cipher_TLS_PSK_WITH_NULL_SHA384, + cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256, + cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384, + cipher_TLS_DHE_PSK_WITH_NULL_SHA256, + cipher_TLS_DHE_PSK_WITH_NULL_SHA384, + cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256, + cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384, + cipher_TLS_RSA_PSK_WITH_NULL_SHA256, + cipher_TLS_RSA_PSK_WITH_NULL_SHA384, + cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256, + cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256, + cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256, + cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256, + cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256, + cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256, + cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV, + cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA, + cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA, + cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, + cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, + cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA, + cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + cipher_TLS_ECDH_RSA_WITH_NULL_SHA, + cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA, + cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, + cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, + cipher_TLS_ECDHE_RSA_WITH_NULL_SHA, + cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA, + cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + cipher_TLS_ECDH_anon_WITH_NULL_SHA, + cipher_TLS_ECDH_anon_WITH_RC4_128_SHA, + cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA, + cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA, + cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA, + cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA, + cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA, + cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA, + cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA, + cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA, + cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256, + cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384, + cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, + cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256, + cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384, + cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256, + cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384, + cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256, + cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384, + cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA, + cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA, + cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA, + cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA, + cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, + cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, + cipher_TLS_ECDHE_PSK_WITH_NULL_SHA, + cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256, + cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384, + cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256, + cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384, + cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256, + cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384, + cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256, + cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384, + cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256, + cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384, + cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256, + cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384, + cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256, + cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384, + cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256, + cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384, + cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256, + cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384, + cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256, + cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384, + cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, + cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, + cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384, + cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384, + cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256, + cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384, + cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256, + cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384, + cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256, + cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384, + cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256, + cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384, + cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256, + cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384, + cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256, + cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384, + cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256, + cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384, + cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256, + cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384, + cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384, + cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, + cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384, + cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, + cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, + cipher_TLS_RSA_WITH_AES_128_CCM, + cipher_TLS_RSA_WITH_AES_256_CCM, + cipher_TLS_RSA_WITH_AES_128_CCM_8, + cipher_TLS_RSA_WITH_AES_256_CCM_8, + cipher_TLS_PSK_WITH_AES_128_CCM, + cipher_TLS_PSK_WITH_AES_256_CCM, + cipher_TLS_PSK_WITH_AES_128_CCM_8, + cipher_TLS_PSK_WITH_AES_256_CCM_8: return true default: return false @@ -609,11 +609,11 @@ func http2isBadCipher(cipher uint16) bool { } const ( - http2prefaceTimeout = 10 * time.Second - http2firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway - http2handlerChunkWriteSize = 4 << 10 - http2defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? - http2maxQueuedControlFrames = 10000 + prefaceTimeout = 10 * time.Second + firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway + handlerChunkWriteSize = 4 << 10 + defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? + maxQueuedControlFrames = 10000 ) var ( @@ -623,24 +623,24 @@ var ( errStreamClosed = errors.New("http2: stream closed") ) -var http2responseWriterStatePool = sync.Pool{ +var responseWriterStatePool = sync.Pool{ New: func() interface{} { - rws := &http2responseWriterState{} - rws.bw = bufio.NewWriterSize(http2chunkWriter{rws}, http2handlerChunkWriteSize) + rws := &responseWriterState{} + rws.bw = bufio.NewWriterSize(chunkWriter{rws}, handlerChunkWriteSize) return rws }, } // Test hooks. var ( - http2testHookOnConn func() - http2testHookGetServerConn func(*http2serverConn) - http2testHookOnPanicMu *sync.Mutex // nil except in tests - http2testHookOnPanic func(sc *http2serverConn, panicVal interface{}) (rePanic bool) + testHookOnConn func() + testHookGetServerConn func(*serverConn) + testHookOnPanicMu *sync.Mutex // nil except in tests + testHookOnPanic func(sc *serverConn, panicVal interface{}) (rePanic bool) ) // Server is an HTTP/2 server. -type http2Server struct { +type Server struct { // MaxHandlers limits the number of http.Handler ServeHTTP goroutines // which may run at a time over all connections. // Negative or zero no limit. @@ -685,7 +685,7 @@ type http2Server struct { // NewWriteScheduler constructs a write scheduler for a connection. // If nil, a default scheduler is chosen. - NewWriteScheduler func() http2WriteScheduler + NewWriteScheduler func() WriteScheduler // CountError, if non-nil, is called on HTTP/2 server errors. // It's intended to increment a metric for monitoring, such @@ -696,44 +696,44 @@ type http2Server struct { // Internal state. This is a pointer (rather than embedded directly) // so that we don't embed a Mutex in this struct, which will make the // struct non-copyable, which might break some callers. - state *http2serverInternalState + state *serverInternalState } -func (s *http2Server) initialConnRecvWindowSize() int32 { - if s.MaxUploadBufferPerConnection > http2initialWindowSize { +func (s *Server) initialConnRecvWindowSize() int32 { + if s.MaxUploadBufferPerConnection > initialWindowSize { return s.MaxUploadBufferPerConnection } return 1 << 20 } -func (s *http2Server) initialStreamRecvWindowSize() int32 { +func (s *Server) initialStreamRecvWindowSize() int32 { if s.MaxUploadBufferPerStream > 0 { return s.MaxUploadBufferPerStream } return 1 << 20 } -func (s *http2Server) maxReadFrameSize() uint32 { - if v := s.MaxReadFrameSize; v >= http2minMaxFrameSize && v <= http2maxFrameSize { +func (s *Server) maxReadFrameSize() uint32 { + if v := s.MaxReadFrameSize; v >= minMaxFrameSize && v <= maxFrameSize { return v } - return http2defaultMaxReadFrameSize + return defaultMaxReadFrameSize } -func (s *http2Server) maxConcurrentStreams() uint32 { +func (s *Server) maxConcurrentStreams() uint32 { if v := s.MaxConcurrentStreams; v > 0 { return v } - return http2defaultMaxStreams + return defaultMaxStreams } // maxQueuedControlFrames is the maximum number of control frames like // SETTINGS, PING and RST_STREAM that will be queued for writing before // the connection is closed to prevent memory exhaustion attacks. -func (s *http2Server) maxQueuedControlFrames() int { +func (s *Server) maxQueuedControlFrames() int { // TODO: if anybody asks, add a Server field, and remember to define the // behavior of negative values. - return http2maxQueuedControlFrames + return maxQueuedControlFrames } // ServeConn serves HTTP/2 requests on the provided connection and @@ -750,31 +750,31 @@ func (s *http2Server) maxQueuedControlFrames() int { // implemented in terms of providing a suitably-behaving net.Conn. // // The opts parameter is optional. If nil, default values are used. -func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { - baseCtx, cancel := http2serverConnBaseContext(c, opts) +func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { + baseCtx, cancel := serverConnBaseContext(c, opts) defer cancel() - sc := &http2serverConn{ + sc := &serverConn{ srv: s, hs: opts.baseConfig(), conn: c, baseCtx: baseCtx, remoteAddrStr: c.RemoteAddr().String(), - bw: http2newBufferedWriter(c), + bw: newBufferedWriter(c), handler: opts.handler(), - streams: make(map[uint32]*http2stream), - readFrameCh: make(chan http2readFrameResult), - wantWriteFrameCh: make(chan http2FrameWriteRequest, 8), + streams: make(map[uint32]*stream), + readFrameCh: make(chan readFrameResult), + wantWriteFrameCh: make(chan FrameWriteRequest, 8), serveMsgCh: make(chan interface{}, 8), - wroteFrameCh: make(chan http2frameWriteResult, 1), // buffered; one send in writeFrameAsync - bodyReadCh: make(chan http2bodyReadMsg), // buffering doesn't matter either way + wroteFrameCh: make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync + bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way doneServing: make(chan struct{}), clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value" advMaxStreams: s.maxConcurrentStreams(), - initialStreamSendWindowSize: http2initialWindowSize, - maxFrameSize: http2initialMaxFrameSize, - headerTableSize: http2initialHeaderTableSize, - serveG: http2newGoroutineLock(), + initialStreamSendWindowSize: initialWindowSize, + maxFrameSize: initialMaxFrameSize, + headerTableSize: initialHeaderTableSize, + serveG: newGoroutineLock(), pushEnabled: true, } @@ -793,23 +793,23 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { if s.NewWriteScheduler != nil { sc.writeSched = s.NewWriteScheduler() } else { - sc.writeSched = http2NewRandomWriteScheduler() + sc.writeSched = NewRandomWriteScheduler() } // These start at the RFC-specified defaults. If there is a higher // configured value for inflow, that will be updated when we send a // WINDOW_UPDATE shortly after sending SETTINGS. - sc.flow.add(http2initialWindowSize) - sc.inflow.add(http2initialWindowSize) + sc.flow.add(initialWindowSize) + sc.inflow.add(initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) - fr := http2NewFramer(sc.bw, c) - fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + fr := NewFramer(sc.bw, c) + fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) fr.MaxHeaderListSize = sc.maxHeaderListSize() fr.SetMaxReadFrameSize(s.maxReadFrameSize()) sc.framer = fr - if tc, ok := c.(http2connectionStater); ok { + if tc, ok := c.(connectionStater); ok { sc.tlsState = new(tls.ConnectionState) *sc.tlsState = tc.ConnectionState() // 9.2 Use of TLS Features @@ -823,7 +823,7 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { // this section with a connection error (Section // 5.4.1) of type INADEQUATE_SECURITY. if sc.tlsState.Version < tls.VersionTLS12 { - sc.rejectConn(http2ErrCodeInadequateSecurity, "TLS version too low") + sc.rejectConn(ErrCodeInadequateSecurity, "TLS version too low") return } @@ -839,7 +839,7 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { // So for now, do nothing here again. } - if !s.PermitProhibitedCipherSuites && http2isBadCipher(sc.tlsState.CipherSuite) { + if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) { // "Endpoints MAY choose to generate a connection error // (Section 5.4.1) of type INADEQUATE_SECURITY if one of // the prohibited cipher suites are negotiated." @@ -850,23 +850,23 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { // excuses here. If we really must, we could allow an // "AllowInsecureWeakCiphers" option on the server later. // Let's see how it plays out first. - sc.rejectConn(http2ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite)) + sc.rejectConn(ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite)) return } } - if hook := http2testHookGetServerConn; hook != nil { + if hook := testHookGetServerConn; hook != nil { hook(sc) } sc.serve() } -type http2serverInternalState struct { +type serverInternalState struct { mu sync.Mutex - activeConns map[*http2serverConn]struct{} + activeConns map[*serverConn]struct{} } -func (s *http2serverInternalState) registerConn(sc *http2serverConn) { +func (s *serverInternalState) registerConn(sc *serverConn) { if s == nil { return // if the Server was used without calling ConfigureServer } @@ -875,7 +875,7 @@ func (s *http2serverInternalState) registerConn(sc *http2serverConn) { s.mu.Unlock() } -func (s *http2serverInternalState) unregisterConn(sc *http2serverConn) { +func (s *serverInternalState) unregisterConn(sc *serverConn) { if s == nil { return // if the Server was used without calling ConfigureServer } @@ -884,7 +884,7 @@ func (s *http2serverInternalState) unregisterConn(sc *http2serverConn) { s.mu.Unlock() } -func (s *http2serverInternalState) startGracefulShutdown() { +func (s *serverInternalState) startGracefulShutdown() { if s == nil { return // if the Server was used without calling ConfigureServer } @@ -896,7 +896,7 @@ func (s *http2serverInternalState) startGracefulShutdown() { } // ServeConnOpts are options for the Server.ServeConn method. -type http2ServeConnOpts struct { +type ServeConnOpts struct { // Context is the base context to use. // If nil, context.Background is used. Context context.Context @@ -911,21 +911,21 @@ type http2ServeConnOpts struct { Handler http.Handler } -func (o *http2ServeConnOpts) context() context.Context { +func (o *ServeConnOpts) context() context.Context { if o != nil && o.Context != nil { return o.Context } return context.Background() } -func (o *http2ServeConnOpts) baseConfig() *http.Server { +func (o *ServeConnOpts) baseConfig() *http.Server { if o != nil && o.BaseConfig != nil { return o.BaseConfig } return new(http.Server) } -func (o *http2ServeConnOpts) handler() http.Handler { +func (o *ServeConnOpts) handler() http.Handler { if o != nil { if o.Handler != nil { return o.Handler @@ -937,7 +937,7 @@ func (o *http2ServeConnOpts) handler() http.Handler { return http.DefaultServeMux } -func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx context.Context, cancel func()) { +func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) { ctx, cancel = context.WithCancel(opts.context()) ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr()) if hs := opts.baseConfig(); hs != nil { @@ -949,45 +949,45 @@ func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx conte // bufferedWriter is a buffered writer that writes to w. // Its buffered writer is lazily allocated as needed, to minimize // idle memory usage with many connections. -type http2bufferedWriter struct { - _ http2incomparable +type bufferedWriter struct { + _ incomparable w io.Writer // immutable bw *bufio.Writer // non-nil when data is buffered } -func http2newBufferedWriter(w io.Writer) *http2bufferedWriter { - return &http2bufferedWriter{w: w} +func newBufferedWriter(w io.Writer) *bufferedWriter { + return &bufferedWriter{w: w} } -func (w *http2bufferedWriter) Available() int { +func (w *bufferedWriter) Available() int { if w.bw == nil { - return http2bufWriterPoolBufferSize + return bufWriterPoolBufferSize } return w.bw.Available() } -func (w *http2bufferedWriter) Write(p []byte) (n int, err error) { +func (w *bufferedWriter) Write(p []byte) (n int, err error) { if w.bw == nil { - bw := http2bufWriterPool.Get().(*bufio.Writer) + bw := bufWriterPool.Get().(*bufio.Writer) bw.Reset(w.w) w.bw = bw } return w.bw.Write(p) } -func (w *http2bufferedWriter) Flush() error { +func (w *bufferedWriter) Flush() error { bw := w.bw if bw == nil { return nil } err := bw.Flush() bw.Reset(nil) - http2bufWriterPool.Put(bw) + bufWriterPool.Put(bw) w.bw = nil return err } -func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) { +func (sc *serverConn) rejectConn(err ErrCode, debug string) { sc.vlogf("http2: server rejecting conn: %v, %s", err, debug) // ignoring errors. hanging up anyway. sc.framer.WriteGoAway(0, err, []byte(debug)) @@ -995,29 +995,29 @@ func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) { sc.conn.Close() } -type http2serverConn struct { +type serverConn struct { // Immutable: - srv *http2Server + srv *Server hs *http.Server conn net.Conn - bw *http2bufferedWriter // writing to conn + bw *bufferedWriter // writing to conn handler http.Handler baseCtx context.Context - framer *http2Framer - doneServing chan struct{} // closed when serverConn.serve ends - readFrameCh chan http2readFrameResult // written by serverConn.readFrames - wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve - wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes - bodyReadCh chan http2bodyReadMsg // from handlers -> serve - serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop - flow http2flow // conn-wide (not stream-specific) outbound flow control - inflow http2flow // conn-wide inbound flow control - tlsState *tls.ConnectionState // shared by all handlers, like net/http + framer *Framer + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan readFrameResult // written by serverConn.readFrames + wantWriteFrameCh chan FrameWriteRequest // from handlers -> serve + wroteFrameCh chan frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes + bodyReadCh chan bodyReadMsg // from handlers -> serve + serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop + flow flow // conn-wide (not stream-specific) outbound flow control + inflow flow // conn-wide inbound flow control + tlsState *tls.ConnectionState // shared by all handlers, like net/http remoteAddrStr string - writeSched http2WriteScheduler + writeSched WriteScheduler // Everything following is owned by the serve loop; use serveG.check(): - serveG http2goroutineLock // used to verify funcs are on serve() + serveG goroutineLock // used to verify funcs are on serve() pushEnabled bool sawFirstSettings bool // got the initial SETTINGS frame after the preface needToSendSettingsAck bool @@ -1029,7 +1029,7 @@ type http2serverConn struct { curPushedStreams uint32 // number of open streams initiated by server push maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes - streams map[uint32]*http2stream + streams map[uint32]*stream initialStreamSendWindowSize int32 maxFrameSize int32 headerTableSize uint32 @@ -1041,7 +1041,7 @@ type http2serverConn struct { inGoAway bool // we've started to or sent GOAWAY inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop needToSendGoAway bool // we need to schedule a GOAWAY frame write - goAwayCode http2ErrCode + goAwayCode ErrCode shutdownTimer *time.Timer // nil until used idleTimer *time.Timer // nil if unused @@ -1053,7 +1053,7 @@ type http2serverConn struct { shutdownOnce sync.Once } -func (sc *http2serverConn) maxHeaderListSize() uint32 { +func (sc *serverConn) maxHeaderListSize() uint32 { n := sc.hs.MaxHeaderBytes if n <= 0 { n = http.DefaultMaxHeaderBytes @@ -1065,29 +1065,29 @@ func (sc *http2serverConn) maxHeaderListSize() uint32 { return uint32(n + typicalHeaders*perFieldOverhead) } -func (sc *http2serverConn) curOpenStreams() uint32 { +func (sc *serverConn) curOpenStreams() uint32 { sc.serveG.check() return sc.curClientStreams + sc.curPushedStreams } // A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). -type http2closeWaiter chan struct{} +type closeWaiter chan struct{} // Init makes a closeWaiter usable. // It exists because so a closeWaiter value can be placed inside a // larger struct and have the Mutex and Cond's memory in the same // allocation. -func (cw *http2closeWaiter) Init() { +func (cw *closeWaiter) Init() { *cw = make(chan struct{}) } // Close marks the closeWaiter as closed and unblocks any waiters. -func (cw http2closeWaiter) Close() { +func (cw closeWaiter) Close() { close(cw) } // Wait waits for the closeWaiter to become closed. -func (cw http2closeWaiter) Wait() { +func (cw closeWaiter) Wait() { <-cw } @@ -1098,21 +1098,21 @@ func (cw http2closeWaiter) Wait() { // handler, this struct intentionally has no pointer to the // *responseWriter{,State} itself, as the Handler ending nils out the // responseWriter's state field. -type http2stream struct { +type stream struct { // immutable: - sc *http2serverConn + sc *serverConn id uint32 - body *http2pipe // non-nil if expecting DATA frames - cw http2closeWaiter // closed wait stream transitions to closed state + body *pipe // non-nil if expecting DATA frames + cw closeWaiter // closed wait stream transitions to closed state ctx context.Context cancelCtx func() // owned by serverConn's serve loop: - bodyBytes int64 // body bytes seen so far - declBodyBytes int64 // or -1 if undeclared - flow http2flow // limits writing from Handler to client - inflow http2flow // what the client is allowed to POST/etc to us - state http2streamState + bodyBytes int64 // body bytes seen so far + declBodyBytes int64 // or -1 if undeclared + flow flow // limits writing from Handler to client + inflow flow // what the client is allowed to POST/etc to us + state streamState resetQueued bool // RST_STREAM queued for write; set by sc.resetStream gotTrailerHeader bool // HEADER frame for trailers was seen wroteHeaders bool // whether we wrote headers (not status 100) @@ -1122,25 +1122,25 @@ type http2stream struct { reqTrailer http.Header // handler's Request.Trailer } -func (sc *http2serverConn) Framer() *http2Framer { return sc.framer } +func (sc *serverConn) Framer() *Framer { return sc.framer } -func (sc *http2serverConn) CloseConn() error { return sc.conn.Close() } +func (sc *serverConn) CloseConn() error { return sc.conn.Close() } -func (sc *http2serverConn) Flush() error { return sc.bw.Flush() } +func (sc *serverConn) Flush() error { return sc.bw.Flush() } -func (sc *http2serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) { +func (sc *serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) { return sc.hpackEncoder, &sc.headerWriteBuf } const ( // SETTINGS_MAX_FRAME_SIZE default // http://http2.github.io/http2-spec/#rfc.section.6.5.2 - http2initialMaxFrameSize = 16384 + initialMaxFrameSize = 16384 - http2defaultMaxReadFrameSize = 1 << 20 + defaultMaxReadFrameSize = 1 << 20 ) -type http2streamState int +type streamState int // HTTP/2 stream states. // @@ -1155,26 +1155,26 @@ type http2streamState int // "reserved (remote)" is omitted since the client code does not // support server push. const ( - http2stateIdle http2streamState = iota - http2stateOpen - http2stateHalfClosedLocal - http2stateHalfClosedRemote - http2stateClosed + stateIdle streamState = iota + stateOpen + stateHalfClosedLocal + stateHalfClosedRemote + stateClosed ) -var http2stateName = [...]string{ - http2stateIdle: "Idle", - http2stateOpen: "Open", - http2stateHalfClosedLocal: "HalfClosedLocal", - http2stateHalfClosedRemote: "HalfClosedRemote", - http2stateClosed: "Closed", +var stateName = [...]string{ + stateIdle: "Idle", + stateOpen: "Open", + stateHalfClosedLocal: "HalfClosedLocal", + stateHalfClosedRemote: "HalfClosedRemote", + stateClosed: "Closed", } -func (st http2streamState) String() string { - return http2stateName[st] +func (st streamState) String() string { + return stateName[st] } -func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2stream) { +func (sc *serverConn) state(streamID uint32) (streamState, *stream) { sc.serveG.check() // http://tools.ietf.org/html/rfc7540#section-5.1 if st, ok := sc.streams[streamID]; ok { @@ -1188,32 +1188,32 @@ func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2strea // state when the first frame for stream 7 is sent or received." if streamID%2 == 1 { if streamID <= sc.maxClientStreamID { - return http2stateClosed, nil + return stateClosed, nil } } else { if streamID <= sc.maxPushPromiseID { - return http2stateClosed, nil + return stateClosed, nil } } - return http2stateIdle, nil + return stateIdle, nil } // setConnState calls the net/http ConnState hook for this connection, if configured. // Note that the net/http package does StateNew and StateClosed for us. // There is currently no plan for StateHijacked or hijacking HTTP/2 connections. -func (sc *http2serverConn) setConnState(state http.ConnState) { +func (sc *serverConn) setConnState(state http.ConnState) { if sc.hs.ConnState != nil { sc.hs.ConnState(sc.conn, state) } } -func (sc *http2serverConn) vlogf(format string, args ...interface{}) { - if http2VerboseLogs { +func (sc *serverConn) vlogf(format string, args ...interface{}) { + if VerboseLogs { sc.logf(format, args...) } } -func (sc *http2serverConn) logf(format string, args ...interface{}) { +func (sc *serverConn) logf(format string, args ...interface{}) { if lg := sc.hs.ErrorLog; lg != nil { lg.Printf(format, args...) } else { @@ -1225,7 +1225,7 @@ func (sc *http2serverConn) logf(format string, args ...interface{}) { // // TODO: remove this helper function once http2 can use build // tags. See comment in isClosedConnError. -func http2errno(v error) uintptr { +func errno(v error) uintptr { if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr { return uintptr(rv.Uint()) } @@ -1234,7 +1234,7 @@ func http2errno(v error) uintptr { // isClosedConnError reports whether err is an error from use of a closed // network connection. -func http2isClosedConnError(err error) bool { +func isClosedConnError(err error) bool { if err == nil { return false } @@ -1248,7 +1248,7 @@ func http2isClosedConnError(err error) bool { } // TODO(bradfitz): x/tools/cmd/bundle doesn't really support - // build tags, so I can't make an http2_windows.go file with + // build tags, so I can't make an _windows.go file with // Windows-specific stuff. Fix that and move this, once we // have a way to bundle this into std's net/http somehow. if runtime.GOOS == "windows" { @@ -1256,7 +1256,7 @@ func http2isClosedConnError(err error) bool { if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" { const WSAECONNABORTED = 10053 const WSAECONNRESET = 10054 - if n := http2errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED { + if n := errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED { return true } } @@ -1265,11 +1265,11 @@ func http2isClosedConnError(err error) bool { return false } -func (sc *http2serverConn) condlogf(err error, format string, args ...interface{}) { +func (sc *serverConn) condlogf(err error, format string, args ...interface{}) { if err == nil { return } - if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) || err == errPrefaceTimeout { + if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err) || err == errPrefaceTimeout { // Boring, expected errors. sc.vlogf(format, args...) } else { @@ -1277,10 +1277,10 @@ func (sc *http2serverConn) condlogf(err error, format string, args ...interface{ } } -func (sc *http2serverConn) canonicalHeader(v string) string { +func (sc *serverConn) canonicalHeader(v string) string { sc.serveG.check() - http2buildCommonHeaderMapsOnce() - cv, ok := http2commonCanonHeader[v] + buildCommonHeaderMapsOnce() + cv, ok := commonCanonHeader[v] if ok { return cv } @@ -1304,8 +1304,8 @@ func (sc *http2serverConn) canonicalHeader(v string) string { return cv } -type http2readFrameResult struct { - f http2Frame // valid until readMore is called +type readFrameResult struct { + f Frame // valid until readMore is called err error // readMore should be called once the consumer no longer needs or @@ -1315,23 +1315,23 @@ type http2readFrameResult struct { } // A gate lets two goroutines coordinate their activities. -type http2gate chan struct{} +type gate chan struct{} -func (g http2gate) Done() { g <- struct{}{} } +func (g gate) Done() { g <- struct{}{} } -func (g http2gate) Wait() { <-g } +func (g gate) Wait() { <-g } // readFrames is the loop that reads incoming frames. // It takes care to only read one frame at a time, blocking until the // consumer is done with the frame. // It's run on its own goroutine. -func (sc *http2serverConn) readFrames() { - gate := make(http2gate) +func (sc *serverConn) readFrames() { + gate := make(gate) gateDone := gate.Done for { f, err := sc.framer.ReadFrame() select { - case sc.readFrameCh <- http2readFrameResult{f, err, gateDone}: + case sc.readFrameCh <- readFrameResult{f, err, gateDone}: case <-sc.doneServing: return } @@ -1340,58 +1340,58 @@ func (sc *http2serverConn) readFrames() { case <-sc.doneServing: return } - if http2terminalReadFrameError(err) { + if terminalReadFrameError(err) { return } } } // frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. -type http2frameWriteResult struct { - _ http2incomparable - wr http2FrameWriteRequest // what was written (or attempted) - err error // result of the writeFrame call +type frameWriteResult struct { + _ incomparable + wr FrameWriteRequest // what was written (or attempted) + err error // result of the writeFrame call } // writeFrameAsync runs in its own goroutine and writes a single frame // and then reports when it's done. // At most one goroutine can be running writeFrameAsync at a time per // serverConn. -func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { +func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest) { err := wr.write.writeFrame(sc) - sc.wroteFrameCh <- http2frameWriteResult{wr: wr, err: err} + sc.wroteFrameCh <- frameWriteResult{wr: wr, err: err} } -func (sc *http2serverConn) closeAllStreamsOnConnClose() { +func (sc *serverConn) closeAllStreamsOnConnClose() { sc.serveG.check() for _, st := range sc.streams { sc.closeStream(st, errClientDisconnected) } } -func (sc *http2serverConn) stopShutdownTimer() { +func (sc *serverConn) stopShutdownTimer() { sc.serveG.check() if t := sc.shutdownTimer; t != nil { t.Stop() } } -func (sc *http2serverConn) notePanic() { +func (sc *serverConn) notePanic() { // Note: this is for serverConn.serve panicking, not http.Handler code. - if http2testHookOnPanicMu != nil { - http2testHookOnPanicMu.Lock() - defer http2testHookOnPanicMu.Unlock() + if testHookOnPanicMu != nil { + testHookOnPanicMu.Lock() + defer testHookOnPanicMu.Unlock() } - if http2testHookOnPanic != nil { + if testHookOnPanic != nil { if e := recover(); e != nil { - if http2testHookOnPanic(sc, e) { + if testHookOnPanic(sc, e) { panic(e) } } } } -func (sc *http2serverConn) serve() { +func (sc *serverConn) serve() { sc.serveG.check() defer sc.notePanic() defer sc.conn.Close() @@ -1399,23 +1399,23 @@ func (sc *http2serverConn) serve() { defer sc.stopShutdownTimer() defer close(sc.doneServing) // unblocks handlers trying to send - if http2VerboseLogs { + if VerboseLogs { sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) } - sc.writeFrame(http2FrameWriteRequest{ - write: http2writeSettings{ - {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, - {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, - {http2SettingMaxHeaderListSize, sc.maxHeaderListSize()}, - {http2SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, + sc.writeFrame(FrameWriteRequest{ + write: writeSettings{ + {SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, + {SettingMaxConcurrentStreams, sc.advMaxStreams}, + {SettingMaxHeaderListSize, sc.maxHeaderListSize()}, + {SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, }, }) sc.unackedSettings++ // Each connection starts with initialWindowSize inflow tokens. // If a higher value is configured, we add more tokens. - if diff := sc.srv.initialConnRecvWindowSize() - http2initialWindowSize; diff > 0 { + if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 { sc.sendWindowUpdate(nil, int(diff)) } @@ -1437,7 +1437,7 @@ func (sc *http2serverConn) serve() { go sc.readFrames() // closed by defer sc.conn.Close above - settingsTimer := time.AfterFunc(http2firstSettingsTimeout, sc.onSettingsTimer) + settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer) defer settingsTimer.Stop() loopNum := 0 @@ -1445,7 +1445,7 @@ func (sc *http2serverConn) serve() { loopNum++ select { case wr := <-sc.wantWriteFrameCh: - if se, ok := wr.write.(http2StreamError); ok { + if se, ok := wr.write.(StreamError); ok { sc.resetStream(se) break } @@ -1476,23 +1476,23 @@ func (sc *http2serverConn) serve() { switch v := msg.(type) { case func(int): v(loopNum) // for testing - case *http2serverMessage: + case *serverMessage: switch v { - case http2settingsTimerMsg: + case settingsTimerMsg: sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) return - case http2idleTimerMsg: + case idleTimerMsg: sc.vlogf("connection is idle") - sc.goAway(http2ErrCodeNo) - case http2shutdownTimerMsg: + sc.goAway(ErrCodeNo) + case shutdownTimerMsg: sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) return - case http2gracefulShutdownMsg: + case gracefulShutdownMsg: sc.startGracefulShutdownInternal() default: panic("unknown timer") } - case *http2startPushRequest: + case *startPushRequest: sc.startPush(v) default: panic(fmt.Sprintf("unexpected type %T", v)) @@ -1511,14 +1511,14 @@ func (sc *http2serverConn) serve() { // with no error code (graceful shutdown), don't start the timer until // all open streams have been completed. sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame - gracefulShutdownComplete := sc.goAwayCode == http2ErrCodeNo && sc.curOpenStreams() == 0 - if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != http2ErrCodeNo || gracefulShutdownComplete) { - sc.shutDownIn(http2goAwayTimeout) + gracefulShutdownComplete := sc.goAwayCode == ErrCodeNo && sc.curOpenStreams() == 0 + if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != ErrCodeNo || gracefulShutdownComplete) { + sc.shutDownIn(goAwayTimeout) } } } -func (sc *http2serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) { +func (sc *serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) { select { case <-sc.doneServing: case <-sharedCh: @@ -1526,23 +1526,23 @@ func (sc *http2serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, priva } } -type http2serverMessage int +type serverMessage int // Message values sent to serveMsgCh. var ( - http2settingsTimerMsg = new(http2serverMessage) - http2idleTimerMsg = new(http2serverMessage) - http2shutdownTimerMsg = new(http2serverMessage) - http2gracefulShutdownMsg = new(http2serverMessage) + settingsTimerMsg = new(serverMessage) + idleTimerMsg = new(serverMessage) + shutdownTimerMsg = new(serverMessage) + gracefulShutdownMsg = new(serverMessage) ) -func (sc *http2serverConn) onSettingsTimer() { sc.sendServeMsg(http2settingsTimerMsg) } +func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) } -func (sc *http2serverConn) onIdleTimer() { sc.sendServeMsg(http2idleTimerMsg) } +func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) } -func (sc *http2serverConn) onShutdownTimer() { sc.sendServeMsg(http2shutdownTimerMsg) } +func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) } -func (sc *http2serverConn) sendServeMsg(msg interface{}) { +func (sc *serverConn) sendServeMsg(msg interface{}) { sc.serveG.checkNotOn() // NOT select { case sc.serveMsgCh <- msg: @@ -1555,27 +1555,27 @@ var errPrefaceTimeout = errors.New("timeout waiting for client preface") // readPreface reads the ClientPreface greeting from the peer or // returns errPrefaceTimeout on timeout, or an error if the greeting // is invalid. -func (sc *http2serverConn) readPreface() error { +func (sc *serverConn) readPreface() error { errc := make(chan error, 1) go func() { // Read the client preface - buf := make([]byte, len(http2ClientPreface)) + buf := make([]byte, len(ClientPreface)) if _, err := io.ReadFull(sc.conn, buf); err != nil { errc <- err - } else if !bytes.Equal(buf, http2clientPreface) { + } else if !bytes.Equal(buf, clientPreface) { errc <- fmt.Errorf("bogus greeting %q", buf) } else { errc <- nil } }() - timer := time.NewTimer(http2prefaceTimeout) // TODO: configurable on *http2Server? + timer := time.NewTimer(prefaceTimeout) // TODO: configurable on *Server? defer timer.Stop() select { case <-timer.C: return errPrefaceTimeout case err := <-errc: if err == nil { - if http2VerboseLogs { + if VerboseLogs { sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr()) } } @@ -1583,21 +1583,21 @@ func (sc *http2serverConn) readPreface() error { } } -var http2errChanPool = sync.Pool{ +var errChanPool = sync.Pool{ New: func() interface{} { return make(chan error, 1) }, } -var http2writeDataPool = sync.Pool{ - New: func() interface{} { return new(http2writeData) }, +var writeDataPool = sync.Pool{ + New: func() interface{} { return new(writeData) }, } // writeDataFromHandler writes DATA response frames from a handler on // the given stream. -func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte, endStream bool) error { - ch := http2errChanPool.Get().(chan error) - writeArg := http2writeDataPool.Get().(*http2writeData) - *writeArg = http2writeData{stream.id, data, endStream} - err := sc.writeFrameFromHandler(http2FrameWriteRequest{ +func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStream bool) error { + ch := errChanPool.Get().(chan error) + writeArg := writeDataPool.Get().(*writeData) + *writeArg = writeData{stream.id, data, endStream} + err := sc.writeFrameFromHandler(FrameWriteRequest{ write: writeArg, stream: stream, done: ch, @@ -1626,9 +1626,9 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte return errStreamClosed } } - http2errChanPool.Put(ch) + errChanPool.Put(ch) if frameWriteDone { - http2writeDataPool.Put(writeArg) + writeDataPool.Put(writeArg) } return err } @@ -1640,7 +1640,7 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte // deadlock writing to sc.wantWriteFrameCh (which is only mildly // buffered and is read by serve itself). If you're on the serve // goroutine, call writeFrame instead. -func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error { +func (sc *serverConn) writeFrameFromHandler(wr FrameWriteRequest) error { sc.serveG.checkNotOn() // NOT select { case sc.wantWriteFrameCh <- wr: @@ -1660,7 +1660,7 @@ func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) erro // make it onto the wire // // If you're not on the serve goroutine, use writeFrameFromHandler instead. -func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { +func (sc *serverConn) writeFrame(wr FrameWriteRequest) { sc.serveG.check() // If true, wr will not be written and wr.done will not be signaled. @@ -1685,8 +1685,8 @@ func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { // may result in duplicate RST_STREAMs in some cases, but the client should // ignore those. if wr.StreamID() != 0 { - _, isReset := wr.write.(http2StreamError) - if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset { + _, isReset := wr.write.(StreamError) + if state, _ := sc.state(wr.StreamID()); state == stateClosed && !isReset { ignoreWrite = true } } @@ -1694,9 +1694,9 @@ func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { // Don't send a 100-continue response if we've already sent headers. // See golang.org/issue/14030. switch wr.write.(type) { - case *http2writeResHeaders: + case *writeResHeaders: wr.stream.wroteHeaders = true - case http2write100ContinueHeadersFrame: + case write100ContinueHeadersFrame: if wr.stream.wroteHeaders { // We do not need to notify wr.done because this frame is // never written with wr.done != nil. @@ -1724,7 +1724,7 @@ func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { // startFrameWrite starts a goroutine to write wr (in a separate // goroutine since that might block on the network), and updates the // serve goroutine's state about the world, updated from info in wr. -func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { +func (sc *serverConn) startFrameWrite(wr FrameWriteRequest) { sc.serveG.check() if sc.writingFrame { panic("internal error: can only be writing one frame at a time") @@ -1733,19 +1733,19 @@ func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { st := wr.stream if st != nil { switch st.state { - case http2stateHalfClosedLocal: + case stateHalfClosedLocal: switch wr.write.(type) { - case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate: + case StreamError, handlerPanicRST, writeWindowUpdate: // RFC 7540 Section 5.1 allows sending RST_STREAM, PRIORITY, and WINDOW_UPDATE // in this state. (We never send PRIORITY from the server, so that is not checked.) default: panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr)) } - case http2stateClosed: + case stateClosed: panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr)) } } - if wpp, ok := wr.write.(*http2writePushPromise); ok { + if wpp, ok := wr.write.(*writePushPromise); ok { var err error wpp.promisedID, err = wpp.allocatePromisedID() if err != nil { @@ -1760,7 +1760,7 @@ func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { if wr.write.staysWithinBuffer(sc.bw.Available()) { sc.writingFrameAsync = false err := wr.write.writeFrame(sc) - sc.wroteFrame(http2frameWriteResult{wr: wr, err: err}) + sc.wroteFrame(frameWriteResult{wr: wr, err: err}) } else { sc.writingFrameAsync = true go sc.writeFrameAsync(wr) @@ -1770,11 +1770,11 @@ func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { // errHandlerPanicked is the error given to any callers blocked in a read from // Request.Body when the main goroutine panics. Since most handlers read in the // main ServeHTTP goroutine, this will show up rarely. -var http2errHandlerPanicked = errors.New("http2: handler panicked") +var errHandlerPanicked = errors.New("http2: handler panicked") // wroteFrame is called on the serve goroutine with the result of // whatever happened on writeFrameAsync. -func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { +func (sc *serverConn) wroteFrame(res frameWriteResult) { sc.serveG.check() if !sc.writingFrame { panic("internal error: expected to be already writing a frame") @@ -1784,13 +1784,13 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { wr := res.wr - if http2writeEndsStream(wr.write) { + if writeEndsStream(wr.write) { st := wr.stream if st == nil { panic("internal error: expecting non-nil stream") } switch st.state { - case http2stateOpen: + case stateOpen: // Here we would go to stateHalfClosedLocal in // theory, but since our handler is done and // the net/http package provides no mechanism @@ -1801,24 +1801,24 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { // hanging up on them. We'll transition to // stateClosed after the RST_STREAM frame is // written. - st.state = http2stateHalfClosedLocal + st.state = stateHalfClosedLocal // Section 8.1: a server MAY request that the client abort // transmission of a request without error by sending a // RST_STREAM with an error code of NO_ERROR after sending // a complete response. - sc.resetStream(http2streamError(st.id, http2ErrCodeNo)) - case http2stateHalfClosedRemote: + sc.resetStream(streamError(st.id, ErrCodeNo)) + case stateHalfClosedRemote: sc.closeStream(st, errHandlerComplete) } } else { switch v := wr.write.(type) { - case http2StreamError: + case StreamError: // st may be unknown if the RST_STREAM was generated to reject bad input. if st, ok := sc.streams[v.StreamID]; ok { sc.closeStream(st, v) } - case http2handlerPanicRST: - sc.closeStream(wr.stream, http2errHandlerPanicked) + case handlerPanicRST: + sc.closeStream(wr.stream, errHandlerPanicked) } } @@ -1838,7 +1838,7 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { // // If a frame isn't being written and there's nothing else to send, we // flush the write buffer. -func (sc *http2serverConn) scheduleFrameWrite() { +func (sc *serverConn) scheduleFrameWrite() { sc.serveG.check() if sc.writingFrame || sc.inFrameScheduleLoop { return @@ -1847,8 +1847,8 @@ func (sc *http2serverConn) scheduleFrameWrite() { for !sc.writingFrameAsync { if sc.needToSendGoAway { sc.needToSendGoAway = false - sc.startFrameWrite(http2FrameWriteRequest{ - write: &http2writeGoAway{ + sc.startFrameWrite(FrameWriteRequest{ + write: &writeGoAway{ maxStreamID: sc.maxClientStreamID, code: sc.goAwayCode, }, @@ -1857,10 +1857,10 @@ func (sc *http2serverConn) scheduleFrameWrite() { } if sc.needToSendSettingsAck { sc.needToSendSettingsAck = false - sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}}) + sc.startFrameWrite(FrameWriteRequest{write: writeSettingsAck{}}) continue } - if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo { + if !sc.inGoAway || sc.goAwayCode == ErrCodeNo { if wr, ok := sc.writeSched.Pop(); ok { if wr.isControl() { sc.queuedControlFrames-- @@ -1870,7 +1870,7 @@ func (sc *http2serverConn) scheduleFrameWrite() { } } if sc.needsFrameFlush { - sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}}) + sc.startFrameWrite(FrameWriteRequest{write: flushFrameWriter{}}) sc.needsFrameFlush = false // after startFrameWrite, since it sets this true continue } @@ -1880,15 +1880,15 @@ func (sc *http2serverConn) scheduleFrameWrite() { } // startGracefulShutdown gracefully shuts down a connection. This -// sends GOAWAY with http2ErrCodeNo to tell the client we're gracefully +// sends GOAWAY with ErrCodeNo to tell the client we're gracefully // shutting down. The connection isn't closed until all current // streams are done. // // startGracefulShutdown returns immediately; it does not wait until // the connection has shut down. -func (sc *http2serverConn) startGracefulShutdown() { +func (sc *serverConn) startGracefulShutdown() { sc.serveG.checkNotOn() // NOT - sc.shutdownOnce.Do(func() { sc.sendServeMsg(http2gracefulShutdownMsg) }) + sc.shutdownOnce.Do(func() { sc.sendServeMsg(gracefulShutdownMsg) }) } // After sending GOAWAY with an error code (non-graceful shutdown), the @@ -1907,13 +1907,13 @@ func (sc *http2serverConn) startGracefulShutdown() { // loopback interface making the expected RTT very small. // // TODO: configurable? -var http2goAwayTimeout = 1 * time.Second +var goAwayTimeout = 1 * time.Second -func (sc *http2serverConn) startGracefulShutdownInternal() { - sc.goAway(http2ErrCodeNo) +func (sc *serverConn) startGracefulShutdownInternal() { + sc.goAway(ErrCodeNo) } -func (sc *http2serverConn) goAway(code http2ErrCode) { +func (sc *serverConn) goAway(code ErrCode) { sc.serveG.check() if sc.inGoAway { return @@ -1924,14 +1924,14 @@ func (sc *http2serverConn) goAway(code http2ErrCode) { sc.scheduleFrameWrite() } -func (sc *http2serverConn) shutDownIn(d time.Duration) { +func (sc *serverConn) shutDownIn(d time.Duration) { sc.serveG.check() sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) } -func (sc *http2serverConn) resetStream(se http2StreamError) { +func (sc *serverConn) resetStream(se StreamError) { sc.serveG.check() - sc.writeFrame(http2FrameWriteRequest{write: se}) + sc.writeFrame(FrameWriteRequest{write: se}) if st, ok := sc.streams[se.StreamID]; ok { st.resetQueued = true } @@ -1942,22 +1942,22 @@ func (sc *http2serverConn) resetStream(se http2StreamError) { // window to exceed this maximum it MUST terminate either the stream // or the connection, as appropriate. For streams, [...]; for the // connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code." -type http2goAwayFlowError struct{} +type goAwayFlowError struct{} -func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" } +func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" } // processFrameFromReader processes the serve loop's read from readFrameCh from the // frame-reading goroutine. // processFrameFromReader returns whether the connection should be kept open. -func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool { +func (sc *serverConn) processFrameFromReader(res readFrameResult) bool { sc.serveG.check() err := res.err if err != nil { if err == errFrameTooLarge { - sc.goAway(http2ErrCodeFrameSize) + sc.goAway(ErrCodeFrameSize) return true // goAway will close the loop } - clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) + clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err) if clientGone { // TODO: could we also get into this state if // the peer does a half close @@ -1971,8 +1971,8 @@ func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool } } else { f := res.f - if http2VerboseLogs { - sc.vlogf("http2: server read frame %v", http2summarizeFrame(f)) + if VerboseLogs { + sc.vlogf("http2: server read frame %v", summarizeFrame(f)) } err = sc.processFrame(f) if err == nil { @@ -1981,15 +1981,15 @@ func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool } switch ev := err.(type) { - case http2StreamError: + case StreamError: sc.resetStream(ev) return true - case http2goAwayFlowError: - sc.goAway(http2ErrCodeFlowControl) + case goAwayFlowError: + sc.goAway(ErrCodeFlowControl) return true - case http2ConnectionError: + case ConnectionError: sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev) - sc.goAway(http2ErrCode(ev)) + sc.goAway(ErrCode(ev)) return true // goAway will handle shutdown default: if res.err != nil { @@ -2001,45 +2001,45 @@ func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool } } -func (sc *http2serverConn) processFrame(f http2Frame) error { +func (sc *serverConn) processFrame(f Frame) error { sc.serveG.check() // First frame received must be SETTINGS. if !sc.sawFirstSettings { - if _, ok := f.(*http2SettingsFrame); !ok { - return sc.countError("first_settings", http2ConnectionError(http2ErrCodeProtocol)) + if _, ok := f.(*SettingsFrame); !ok { + return sc.countError("first_settings", ConnectionError(ErrCodeProtocol)) } sc.sawFirstSettings = true } switch f := f.(type) { - case *http2SettingsFrame: + case *SettingsFrame: return sc.processSettings(f) - case *http2MetaHeadersFrame: + case *MetaHeadersFrame: return sc.processHeaders(f) - case *http2WindowUpdateFrame: + case *WindowUpdateFrame: return sc.processWindowUpdate(f) - case *http2PingFrame: + case *PingFrame: return sc.processPing(f) - case *http2DataFrame: + case *DataFrame: return sc.processData(f) - case *http2RSTStreamFrame: + case *RSTStreamFrame: return sc.processResetStream(f) - case *http2PriorityFrame: + case *PriorityFrame: return sc.processPriority(f) - case *http2GoAwayFrame: + case *GoAwayFrame: return sc.processGoAway(f) - case *http2PushPromiseFrame: + case *PushPromiseFrame: // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. - return sc.countError("push_promise", http2ConnectionError(http2ErrCodeProtocol)) + return sc.countError("push_promise", ConnectionError(ErrCodeProtocol)) default: sc.vlogf("http2: server ignoring frame: %v", f.Header()) return nil } } -func (sc *http2serverConn) processPing(f *http2PingFrame) error { +func (sc *serverConn) processPing(f *PingFrame) error { sc.serveG.check() if f.IsAck() { // 6.7 PING: " An endpoint MUST NOT respond to PING frames @@ -2052,26 +2052,26 @@ func (sc *http2serverConn) processPing(f *http2PingFrame) error { // identifier field value other than 0x0, the recipient MUST // respond with a connection error (Section 5.4.1) of type // PROTOCOL_ERROR." - return sc.countError("ping_on_stream", http2ConnectionError(http2ErrCodeProtocol)) + return sc.countError("ping_on_stream", ConnectionError(ErrCodeProtocol)) } - if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { + if sc.inGoAway && sc.goAwayCode != ErrCodeNo { return nil } - sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}}) + sc.writeFrame(FrameWriteRequest{write: writePingAck{f}}) return nil } -func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error { +func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error { sc.serveG.check() switch { case f.StreamID != 0: // stream-level flow control state, st := sc.state(f.StreamID) - if state == http2stateIdle { + if state == stateIdle { // Section 5.1: "Receiving any frame other than HEADERS // or PRIORITY on a stream in this state MUST be // treated as a connection error (Section 5.4.1) of // type PROTOCOL_ERROR." - return sc.countError("stream_idle", http2ConnectionError(http2ErrCodeProtocol)) + return sc.countError("stream_idle", ConnectionError(ErrCodeProtocol)) } if st == nil { // "WINDOW_UPDATE can be sent by a peer that has sent a @@ -2082,42 +2082,42 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error return nil } if !st.flow.add(int32(f.Increment)) { - return sc.countError("bad_flow", http2streamError(f.StreamID, http2ErrCodeFlowControl)) + return sc.countError("bad_flow", streamError(f.StreamID, ErrCodeFlowControl)) } default: // connection-level flow control if !sc.flow.add(int32(f.Increment)) { - return http2goAwayFlowError{} + return goAwayFlowError{} } } sc.scheduleFrameWrite() return nil } -func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { +func (sc *serverConn) processResetStream(f *RSTStreamFrame) error { sc.serveG.check() state, st := sc.state(f.StreamID) - if state == http2stateIdle { + if state == stateIdle { // 6.4 "RST_STREAM frames MUST NOT be sent for a // stream in the "idle" state. If a RST_STREAM frame // identifying an idle stream is received, the // recipient MUST treat this as a connection error // (Section 5.4.1) of type PROTOCOL_ERROR. - return sc.countError("reset_idle_stream", http2ConnectionError(http2ErrCodeProtocol)) + return sc.countError("reset_idle_stream", ConnectionError(ErrCodeProtocol)) } if st != nil { st.cancelCtx() - sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode)) + sc.closeStream(st, streamError(f.StreamID, f.ErrCode)) } return nil } -func (sc *http2serverConn) closeStream(st *http2stream, err error) { +func (sc *serverConn) closeStream(st *stream, err error) { sc.serveG.check() - if st.state == http2stateIdle || st.state == http2stateClosed { + if st.state == stateIdle || st.state == stateClosed { panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) } - st.state = http2stateClosed + st.state = stateClosed if st.writeDeadline != nil { st.writeDeadline.Stop() } @@ -2132,7 +2132,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { if sc.srv.IdleTimeout != 0 { sc.idleTimer.Reset(sc.srv.IdleTimeout) } - if http2h1ServerKeepAlivesDisabled(sc.hs) { + if h1ServerKeepAlivesDisabled(sc.hs) { sc.startGracefulShutdownInternal() } } @@ -2147,7 +2147,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { sc.writeSched.CloseStream(st.id) } -func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { +func (sc *serverConn) processSettings(f *SettingsFrame) error { sc.serveG.check() if f.IsAck() { sc.unackedSettings-- @@ -2155,7 +2155,7 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { // Why is the peer ACKing settings we never sent? // The spec doesn't mention this case, but // hang up on them anyway. - return sc.countError("ack_mystery", http2ConnectionError(http2ErrCodeProtocol)) + return sc.countError("ack_mystery", ConnectionError(ErrCodeProtocol)) } return nil } @@ -2163,7 +2163,7 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { // This isn't actually in the spec, but hang up on // suspiciously large settings frames or those with // duplicate entries. - return sc.countError("settings_big_or_dups", http2ConnectionError(http2ErrCodeProtocol)) + return sc.countError("settings_big_or_dups", ConnectionError(ErrCodeProtocol)) } if err := f.ForeachSetting(sc.processSetting); err != nil { return err @@ -2175,40 +2175,40 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { return nil } -func (sc *http2serverConn) processSetting(s http2Setting) error { +func (sc *serverConn) processSetting(s Setting) error { sc.serveG.check() if err := s.Valid(); err != nil { return err } - if http2VerboseLogs { + if VerboseLogs { sc.vlogf("http2: server processing setting %v", s) } switch s.ID { - case http2SettingHeaderTableSize: + case SettingHeaderTableSize: sc.headerTableSize = s.Val sc.hpackEncoder.SetMaxDynamicTableSize(s.Val) - case http2SettingEnablePush: + case SettingEnablePush: sc.pushEnabled = s.Val != 0 - case http2SettingMaxConcurrentStreams: + case SettingMaxConcurrentStreams: sc.clientMaxStreams = s.Val - case http2SettingInitialWindowSize: + case SettingInitialWindowSize: return sc.processSettingInitialWindowSize(s.Val) - case http2SettingMaxFrameSize: + case SettingMaxFrameSize: sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31 - case http2SettingMaxHeaderListSize: + case SettingMaxHeaderListSize: sc.peerMaxHeaderListSize = s.Val default: // Unknown setting: "An endpoint that receives a SETTINGS // frame with any unknown or unsupported identifier MUST // ignore that setting." - if http2VerboseLogs { + if VerboseLogs { sc.vlogf("http2: server ignoring unknown setting %v", s) } } return nil } -func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { +func (sc *serverConn) processSettingInitialWindowSize(val uint32) error { sc.serveG.check() // Note: val already validated to be within range by // processSetting's Valid call. @@ -2230,16 +2230,16 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { // control window to exceed the maximum size as a // connection error (Section 5.4.1) of type // FLOW_CONTROL_ERROR." - return sc.countError("setting_win_size", http2ConnectionError(http2ErrCodeFlowControl)) + return sc.countError("setting_win_size", ConnectionError(ErrCodeFlowControl)) } } return nil } -func (sc *http2serverConn) processData(f *http2DataFrame) error { +func (sc *serverConn) processData(f *DataFrame) error { sc.serveG.check() id := f.Header().StreamID - if sc.inGoAway && (sc.goAwayCode != http2ErrCodeNo || id > sc.maxClientStreamID) { + if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || id > sc.maxClientStreamID) { // Discard all DATA frames if the GOAWAY is due to an // error, or: // @@ -2252,7 +2252,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { data := f.Data() state, st := sc.state(id) - if id == 0 || state == http2stateIdle { + if id == 0 || state == stateIdle { // Section 6.1: "DATA frames MUST be associated with a // stream. If a DATA frame is received whose stream // identifier field is 0x0, the recipient MUST respond @@ -2263,13 +2263,13 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // or PRIORITY on a stream in this state MUST be // treated as a connection error (Section 5.4.1) of // type PROTOCOL_ERROR." - return sc.countError("data_on_idle", http2ConnectionError(http2ErrCodeProtocol)) + return sc.countError("data_on_idle", ConnectionError(ErrCodeProtocol)) } // "If a DATA frame is received whose stream is not in "open" // or "half closed (local)" state, the recipient MUST respond // with a stream error (Section 5.4.2) of type STREAM_CLOSED." - if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued { + if st == nil || state != stateOpen || st.gotTrailerHeader || st.resetQueued { // This includes sending a RST_STREAM if the stream is // in stateHalfClosedLocal (which currently means that // the http.Handler returned, so it's done reading & @@ -2280,7 +2280,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // and return any flow control bytes since we're not going // to consume them. if sc.inflow.available() < int32(f.Length) { - return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl)) + return sc.countError("data_flow", streamError(id, ErrCodeFlowControl)) } // Deduct the flow control from inflow, since we're // going to immediately add it back in @@ -2293,7 +2293,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // Already have a stream error in flight. Don't send another. return nil } - return sc.countError("closed", http2streamError(id, http2ErrCodeStreamClosed)) + return sc.countError("closed", streamError(id, ErrCodeStreamClosed)) } if st.body == nil { panic("internal error: should have a body in this state") @@ -2305,12 +2305,12 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the // value of a content-length header field does not equal the sum of the // DATA frame payload lengths that form the body. - return sc.countError("send_too_much", http2streamError(id, http2ErrCodeProtocol)) + return sc.countError("send_too_much", streamError(id, ErrCodeProtocol)) } if f.Length > 0 { // Check whether the client has flow control quota. if st.inflow.available() < int32(f.Length) { - return sc.countError("flow_on_data_length", http2streamError(id, http2ErrCodeFlowControl)) + return sc.countError("flow_on_data_length", streamError(id, ErrCodeFlowControl)) } st.inflow.take(int32(f.Length)) @@ -2318,7 +2318,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { wrote, err := st.body.Write(data) if err != nil { sc.sendWindowUpdate(nil, int(f.Length)-wrote) - return sc.countError("body_write_err", http2streamError(id, http2ErrCodeStreamClosed)) + return sc.countError("body_write_err", streamError(id, ErrCodeStreamClosed)) } if wrote != len(data) { panic("internal error: bad Writer") @@ -2339,9 +2339,9 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { return nil } -func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error { +func (sc *serverConn) processGoAway(f *GoAwayFrame) error { sc.serveG.check() - if f.ErrCode != http2ErrCodeNo { + if f.ErrCode != ErrCodeNo { sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f) } else { sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) @@ -2354,13 +2354,13 @@ func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error { } // isPushed reports whether the stream is server-initiated. -func (st *http2stream) isPushed() bool { +func (st *stream) isPushed() bool { return st.id%2 == 0 } // endStream closes a Request.Body's pipe. It is called when a DATA // frame says a request body is over (or after trailers). -func (st *http2stream) endStream() { +func (st *stream) endStream() { sc := st.sc sc.serveG.check() @@ -2371,12 +2371,12 @@ func (st *http2stream) endStream() { st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest) st.body.CloseWithError(io.EOF) } - st.state = http2stateHalfClosedRemote + st.state = stateHalfClosedRemote } // copyTrailersToHandlerRequest is run in the Handler's goroutine in // its Request.Body.Read just before it gets io.EOF. -func (st *http2stream) copyTrailersToHandlerRequest() { +func (st *stream) copyTrailersToHandlerRequest() { for k, vv := range st.trailer { if _, ok := st.reqTrailer[k]; ok { // Only copy it over it was pre-declared. @@ -2387,11 +2387,11 @@ func (st *http2stream) copyTrailersToHandlerRequest() { // onWriteTimeout is run on its own goroutine (from time.AfterFunc) // when the stream's WriteTimeout has fired. -func (st *http2stream) onWriteTimeout() { - st.sc.writeFrameFromHandler(http2FrameWriteRequest{write: http2streamError(st.id, http2ErrCodeInternal)}) +func (st *stream) onWriteTimeout() { + st.sc.writeFrameFromHandler(FrameWriteRequest{write: streamError(st.id, ErrCodeInternal)}) } -func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { +func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { sc.serveG.check() id := f.StreamID if sc.inGoAway { @@ -2404,7 +2404,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // stream identifier MUST respond with a connection error // (Section 5.4.1) of type PROTOCOL_ERROR. if id%2 != 1 { - return sc.countError("headers_even", http2ConnectionError(http2ErrCodeProtocol)) + return sc.countError("headers_even", ConnectionError(ErrCodeProtocol)) } // A HEADERS frame can be used to create a new stream or // send a trailer for an open one. If we already have a stream @@ -2420,8 +2420,8 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // WINDOW_UPDATE, PRIORITY, or RST_STREAM, for a stream that is in // this state, it MUST respond with a stream error (Section 5.4.2) of // type STREAM_CLOSED. - if st.state == http2stateHalfClosedRemote { - return sc.countError("headers_half_closed", http2streamError(id, http2ErrCodeStreamClosed)) + if st.state == stateHalfClosedRemote { + return sc.countError("headers_half_closed", streamError(id, ErrCodeStreamClosed)) } return st.processTrailerHeaders(f) } @@ -2432,7 +2432,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // receives an unexpected stream identifier MUST respond with // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. if id <= sc.maxClientStreamID { - return sc.countError("stream_went_down", http2ConnectionError(http2ErrCodeProtocol)) + return sc.countError("stream_went_down", ConnectionError(ErrCodeProtocol)) } sc.maxClientStreamID = id @@ -2449,19 +2449,19 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { if sc.curClientStreams+1 > sc.advMaxStreams { if sc.unackedSettings == 0 { // They should know better. - return sc.countError("over_max_streams", http2streamError(id, http2ErrCodeProtocol)) + return sc.countError("over_max_streams", streamError(id, ErrCodeProtocol)) } // Assume it's a network race, where they just haven't // received our last SETTINGS update. But actually // this can't happen yet, because we don't yet provide // a way for users to adjust server parameters at // runtime. - return sc.countError("over_max_streams_race", http2streamError(id, http2ErrCodeRefusedStream)) + return sc.countError("over_max_streams_race", streamError(id, ErrCodeRefusedStream)) } - initialState := http2stateOpen + initialState := stateOpen if f.StreamEnded() { - initialState = http2stateHalfClosedRemote + initialState = stateHalfClosedRemote } st := sc.newStream(id, 0, initialState) @@ -2480,15 +2480,15 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { if st.reqTrailer != nil { st.trailer = make(http.Header) } - st.body = req.Body.(*http2requestBody).pipe // may be nil + st.body = req.Body.(*requestBody).pipe // may be nil st.declBodyBytes = req.ContentLength handler := sc.handler.ServeHTTP if f.Truncated { // Their header list was too long. Send a 431 error. - handler = http2handleHeaderListTooLong - } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil { - handler = http2new400Handler(err) + handler = handleHeaderListTooLong + } else if err := checkValidHTTP2RequestHeaders(req.Header); err != nil { + handler = new400Handler(err) } // The net/http package sets the read deadline from the @@ -2506,19 +2506,19 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { return nil } -func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { +func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error { sc := st.sc sc.serveG.check() if st.gotTrailerHeader { - return sc.countError("dup_trailers", http2ConnectionError(http2ErrCodeProtocol)) + return sc.countError("dup_trailers", ConnectionError(ErrCodeProtocol)) } st.gotTrailerHeader = true if !f.StreamEnded() { - return sc.countError("trailers_not_ended", http2streamError(st.id, http2ErrCodeProtocol)) + return sc.countError("trailers_not_ended", streamError(st.id, ErrCodeProtocol)) } if len(f.PseudoFields()) > 0 { - return sc.countError("trailers_pseudo", http2streamError(st.id, http2ErrCodeProtocol)) + return sc.countError("trailers_pseudo", streamError(st.id, ErrCodeProtocol)) } if st.trailer != nil { for _, hf := range f.RegularFields() { @@ -2527,7 +2527,7 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { // TODO: send more details to the peer somehow. But http2 has // no way to send debug data at a stream level. Discuss with // HTTP folk. - return sc.countError("trailers_bogus", http2streamError(st.id, http2ErrCodeProtocol)) + return sc.countError("trailers_bogus", streamError(st.id, ErrCodeProtocol)) } st.trailer[key] = append(st.trailer[key], hf.Value) } @@ -2536,36 +2536,36 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { return nil } -func (sc *http2serverConn) checkPriority(streamID uint32, p http2PriorityParam) error { +func (sc *serverConn) checkPriority(streamID uint32, p PriorityParam) error { if streamID == p.StreamDep { // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR." // Section 5.3.3 says that a stream can depend on one of its dependencies, // so it's only self-dependencies that are forbidden. - return sc.countError("priority", http2streamError(streamID, http2ErrCodeProtocol)) + return sc.countError("priority", streamError(streamID, ErrCodeProtocol)) } return nil } -func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { +func (sc *serverConn) processPriority(f *PriorityFrame) error { if sc.inGoAway { return nil } - if err := sc.checkPriority(f.StreamID, f.http2PriorityParam); err != nil { + if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil { return err } - sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam) + sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam) return nil } -func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream { +func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream { sc.serveG.check() if id == 0 { panic("internal error: cannot create stream with id 0") } ctx, cancelCtx := context.WithCancel(sc.baseCtx) - st := &http2stream{ + st := &stream{ sc: sc, id: id, state: state, @@ -2582,7 +2582,7 @@ func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState } sc.streams[id] = st - sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID}) + sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID}) if st.isPushed() { sc.curPushedStreams++ } else { @@ -2595,10 +2595,10 @@ func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState return st } -func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *http.Request, error) { +func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) { sc.serveG.check() - rp := http2requestParam{ + rp := requestParam{ method: f.PseudoValue("method"), scheme: f.PseudoValue("scheme"), authority: f.PseudoValue("authority"), @@ -2608,7 +2608,7 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead isConnect := rp.method == "CONNECT" if isConnect { if rp.path != "" || rp.scheme != "" || rp.authority == "" { - return nil, nil, sc.countError("bad_connect", http2streamError(f.StreamID, http2ErrCodeProtocol)) + return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol)) } } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { // See 8.1.2.6 Malformed Requests and Responses: @@ -2621,13 +2621,13 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead // "All HTTP/2 requests MUST include exactly one valid // value for the :method, :scheme, and :path // pseudo-header fields" - return nil, nil, sc.countError("bad_path_method", http2streamError(f.StreamID, http2ErrCodeProtocol)) + return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol)) } bodyOpen := !f.StreamEnded() if rp.method == "HEAD" && bodyOpen { // HEAD requests can't have bodies - return nil, nil, sc.countError("head_body", http2streamError(f.StreamID, http2ErrCodeProtocol)) + return nil, nil, sc.countError("head_body", streamError(f.StreamID, ErrCodeProtocol)) } rp.header = make(http.Header) @@ -2652,20 +2652,20 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead } else { req.ContentLength = -1 } - req.Body.(*http2requestBody).pipe = &http2pipe{ - b: &http2dataBuffer{expected: req.ContentLength}, + req.Body.(*requestBody).pipe = &pipe{ + b: &dataBuffer{expected: req.ContentLength}, } } return rw, req, nil } -type http2requestParam struct { +type requestParam struct { method string scheme, authority, path string header http.Header } -func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *http.Request, error) { +func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) { sc.serveG.check() var tlsState *tls.ConnectionState // nil if not scheme https @@ -2710,12 +2710,12 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re var err error u, err = url.ParseRequestURI(rp.path) if err != nil { - return nil, nil, sc.countError("bad_path", http2streamError(st.id, http2ErrCodeProtocol)) + return nil, nil, sc.countError("bad_path", streamError(st.id, ErrCodeProtocol)) } requestURI = rp.path } - body := &http2requestBody{ + body := &requestBody{ conn: sc, stream: st, needsContinue: needsContinue, @@ -2736,29 +2736,29 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re } req = req.WithContext(st.ctx) - rws := http2responseWriterStatePool.Get().(*http2responseWriterState) + rws := responseWriterStatePool.Get().(*responseWriterState) bwSave := rws.bw - *rws = http2responseWriterState{} // zero all the fields + *rws = responseWriterState{} // zero all the fields rws.conn = sc rws.bw = bwSave - rws.bw.Reset(http2chunkWriter{rws}) + rws.bw.Reset(chunkWriter{rws}) rws.stream = st rws.req = req rws.body = body - rw := &http2responseWriter{rws: rws} + rw := &responseWriter{rws: rws} return rw, req, nil } // Run on its own goroutine. -func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { +func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { didPanic := true defer func() { rw.rws.stream.cancelCtx() if didPanic { e := recover() - sc.writeFrameFromHandler(http2FrameWriteRequest{ - write: http2handlerPanicRST{rw.rws.stream.id}, + sc.writeFrameFromHandler(FrameWriteRequest{ + write: handlerPanicRST{rw.rws.stream.id}, stream: rw.rws.stream, }) // Same as net/http: @@ -2776,7 +2776,7 @@ func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *http.Request didPanic = false } -func http2handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) { +func handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) { // 10.5.1 Limits on Header Block Size: // .. "A server that receives a larger header block than it is // willing to handle can send an HTTP 431 (Request Header Fields Too @@ -2788,7 +2788,7 @@ func http2handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) { // called from handler goroutines. // h may be nil. -func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeResHeaders) error { +func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) error { sc.serveG.checkNotOn() // NOT on var errc chan error if headerData.h != nil { @@ -2796,9 +2796,9 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR // waiting for this frame to be written, so an http.Flush mid-handler // writes out the correct value of keys, before a handler later potentially // mutates it. - errc = http2errChanPool.Get().(chan error) + errc = errChanPool.Get().(chan error) } - if err := sc.writeFrameFromHandler(http2FrameWriteRequest{ + if err := sc.writeFrameFromHandler(FrameWriteRequest{ write: headerData, stream: st, done: errc, @@ -2808,7 +2808,7 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR if errc != nil { select { case err := <-errc: - http2errChanPool.Put(errc) + errChanPool.Put(errc) return err case <-sc.doneServing: return errClientDisconnected @@ -2820,37 +2820,37 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR } // called from handler goroutines. -func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) { - sc.writeFrameFromHandler(http2FrameWriteRequest{ - write: http2write100ContinueHeadersFrame{st.id}, +func (sc *serverConn) write100ContinueHeaders(st *stream) { + sc.writeFrameFromHandler(FrameWriteRequest{ + write: write100ContinueHeadersFrame{st.id}, stream: st, }) } // A bodyReadMsg tells the server loop that the http.Handler read n // bytes of the DATA from the client on the given stream. -type http2bodyReadMsg struct { - st *http2stream +type bodyReadMsg struct { + st *stream n int } // called from handler goroutines. // Notes that the handler for the given stream ID read n bytes of its body // and schedules flow control tokens to be sent. -func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) { +func (sc *serverConn) noteBodyReadFromHandler(st *stream, n int, err error) { sc.serveG.checkNotOn() // NOT on if n > 0 { select { - case sc.bodyReadCh <- http2bodyReadMsg{st, n}: + case sc.bodyReadCh <- bodyReadMsg{st, n}: case <-sc.doneServing: } } } -func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) { +func (sc *serverConn) noteBodyRead(st *stream, n int) { sc.serveG.check() sc.sendWindowUpdate(nil, n) // conn-level - if st.state != http2stateHalfClosedRemote && st.state != http2stateClosed { + if st.state != stateHalfClosedRemote && st.state != stateClosed { // Don't send this WINDOW_UPDATE if the stream is closed // remotely. sc.sendWindowUpdate(st, n) @@ -2858,7 +2858,7 @@ func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) { } // st may be nil for conn-level -func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) { +func (sc *serverConn) sendWindowUpdate(st *stream, n int) { sc.serveG.check() // "The legal range for the increment to the flow control // window is 1 to 2^31-1 (2,147,483,647) octets." @@ -2874,7 +2874,7 @@ func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) { } // st may be nil for conn-level -func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { +func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) { sc.serveG.check() if n == 0 { return @@ -2886,8 +2886,8 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { if st != nil { streamID = st.id } - sc.writeFrame(http2FrameWriteRequest{ - write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)}, + sc.writeFrame(FrameWriteRequest{ + write: writeWindowUpdate{streamID: streamID, n: uint32(n)}, stream: st, }) var ok bool @@ -2903,17 +2903,17 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { // requestBody is the Handler's Request.Body type. // Read and Close may be called concurrently. -type http2requestBody struct { - _ http2incomparable - stream *http2stream - conn *http2serverConn - closed bool // for use by Close only - sawEOF bool // for use by Read only - pipe *http2pipe // non-nil if we have a HTTP entity message body - needsContinue bool // need to send a 100-continue +type requestBody struct { + _ incomparable + stream *stream + conn *serverConn + closed bool // for use by Close only + sawEOF bool // for use by Read only + pipe *pipe // non-nil if we have a HTTP entity message body + needsContinue bool // need to send a 100-continue } -func (b *http2requestBody) Close() error { +func (b *requestBody) Close() error { if b.pipe != nil && !b.closed { b.pipe.BreakWithError(errClosedBody) } @@ -2921,7 +2921,7 @@ func (b *http2requestBody) Close() error { return nil } -func (b *http2requestBody) Read(p []byte) (n int, err error) { +func (b *requestBody) Read(p []byte) (n int, err error) { if b.needsContinue { b.needsContinue = false b.conn.write100ContinueHeaders(b.stream) @@ -2933,7 +2933,7 @@ func (b *http2requestBody) Read(p []byte) (n int, err error) { if err == io.EOF { b.sawEOF = true } - if b.conn == nil && http2inTests { + if b.conn == nil && inTests { return } b.conn.noteBodyReadFromHandler(b.stream, n, err) @@ -2946,28 +2946,28 @@ func (b *http2requestBody) Read(p []byte) (n int, err error) { // request (in handlerDone) and calls on the responseWriter thereafter // simply crash (caller's mistake), but the much larger responseWriterState // and buffers are reused between multiple requests. -type http2responseWriter struct { - rws *http2responseWriterState +type responseWriter struct { + rws *responseWriterState } // from pkg io -type http2stringWriter interface { +type stringWriter interface { WriteString(s string) (n int, err error) } // Optional http.ResponseWriter interfaces implemented. var ( - _ http.CloseNotifier = (*http2responseWriter)(nil) - _ http.Flusher = (*http2responseWriter)(nil) - _ http2stringWriter = (*http2responseWriter)(nil) + _ http.CloseNotifier = (*responseWriter)(nil) + _ http.Flusher = (*responseWriter)(nil) + _ stringWriter = (*responseWriter)(nil) ) -type http2responseWriterState struct { +type responseWriterState struct { // immutable within a request: - stream *http2stream + stream *stream req *http.Request - body *http2requestBody // to close at end of request, if DATA frames didn't - conn *http2serverConn + body *requestBody // to close at end of request, if DATA frames didn't + conn *serverConn // TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState} @@ -2989,13 +2989,13 @@ type http2responseWriterState struct { closeNotifierCh chan bool // nil until first used } -type http2chunkWriter struct{ rws *http2responseWriterState } +type chunkWriter struct{ rws *responseWriterState } -func (cw http2chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } +func (cw chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } -func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 } +func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 } -func (rws *http2responseWriterState) hasNonemptyTrailers() bool { +func (rws *responseWriterState) hasNonemptyTrailers() bool { for _, trailer := range rws.trailers { if _, ok := rws.handlerHeader[trailer]; ok { return true @@ -3007,14 +3007,14 @@ func (rws *http2responseWriterState) hasNonemptyTrailers() bool { // declareTrailer is called for each Trailer header when the // response header is written. It notes that a header will need to be // written in the trailers at the end of the response. -func (rws *http2responseWriterState) declareTrailer(k string) { +func (rws *responseWriterState) declareTrailer(k string) { k = http.CanonicalHeaderKey(k) if !httpguts.ValidTrailerHeader(k) { // Forbidden by RFC 7230, section 4.1.2. rws.conn.logf("ignoring invalid trailer %q", k) return } - if !http2strSliceContains(rws.trailers, k) { + if !strSliceContains(rws.trailers, k) { rws.trailers = append(rws.trailers, k) } } @@ -3025,7 +3025,7 @@ func (rws *http2responseWriterState) declareTrailer(k string) { // // writeChunk is also responsible (on the first chunk) for sending the // HEADER response. -func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { +func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) { if !rws.wroteHeader { rws.writeHeader(200) } @@ -3042,7 +3042,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { clen = "" } } - if clen == "" && rws.handlerDone && http2bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) { + if clen == "" && rws.handlerDone && bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) { clen = strconv.Itoa(len(p)) } _, hasContentType := rws.snapHeader["Content-Type"] @@ -3050,7 +3050,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { // sniff the body. See Issue golang.org/issue/31753. ce := rws.snapHeader.Get("Content-Encoding") hasCE := len(ce) > 0 - if !hasCE && !hasContentType && http2bodyAllowedForStatus(rws.status) && len(p) > 0 { + if !hasCE && !hasContentType && bodyAllowedForStatus(rws.status) && len(p) > 0 { ctype = http.DetectContentType(p) } var date string @@ -3060,7 +3060,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { } for _, v := range rws.snapHeader["Trailer"] { - http2foreachHeaderElement(v, rws.declareTrailer) + foreachHeaderElement(v, rws.declareTrailer) } // "Connection" headers aren't allowed in HTTP/2 (RFC 7540, 8.1.2.2), @@ -3077,7 +3077,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { } endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp - err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ + err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{ streamID: rws.stream.id, httpResCode: rws.status, h: rws.snapHeader, @@ -3118,7 +3118,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { } if rws.handlerDone && hasNonemptyTrailers { - err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ + err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{ streamID: rws.stream.id, h: rws.handlerHeader, trailers: rws.trailers, @@ -3144,7 +3144,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { // is preferred: // https://golang.org/pkg/net/http/#ResponseWriter // https://golang.org/pkg/net/http/#example_ResponseWriter_trailers -const http2TrailerPrefix = "Trailer:" +const TrailerPrefix = "Trailer:" // promoteUndeclaredTrailers permits http.Handlers to set trailers // after the header has already been flushed. Because the Go @@ -3167,24 +3167,24 @@ const http2TrailerPrefix = "Trailer:" // // This method runs after the Handler is done and promotes any Header // fields to be trailers. -func (rws *http2responseWriterState) promoteUndeclaredTrailers() { +func (rws *responseWriterState) promoteUndeclaredTrailers() { for k, vv := range rws.handlerHeader { - if !strings.HasPrefix(k, http2TrailerPrefix) { + if !strings.HasPrefix(k, TrailerPrefix) { continue } - trailerKey := strings.TrimPrefix(k, http2TrailerPrefix) + trailerKey := strings.TrimPrefix(k, TrailerPrefix) rws.declareTrailer(trailerKey) rws.handlerHeader[http.CanonicalHeaderKey(trailerKey)] = vv } if len(rws.trailers) > 1 { - sorter := http2sorterPool.Get().(*http2sorter) + sorter := sorterPool.Get().(*sorter) sorter.SortStrings(rws.trailers) - http2sorterPool.Put(sorter) + sorterPool.Put(sorter) } } -func (w *http2responseWriter) Flush() { +func (w *responseWriter) Flush() { rws := w.rws if rws == nil { panic("Header called after Handler finished") @@ -3203,7 +3203,7 @@ func (w *http2responseWriter) Flush() { } } -func (w *http2responseWriter) CloseNotify() <-chan bool { +func (w *responseWriter) CloseNotify() <-chan bool { rws := w.rws if rws == nil { panic("CloseNotify called after Handler finished") @@ -3223,7 +3223,7 @@ func (w *http2responseWriter) CloseNotify() <-chan bool { return ch } -func (w *http2responseWriter) Header() http.Header { +func (w *responseWriter) Header() http.Header { rws := w.rws if rws == nil { panic("Header called after Handler finished") @@ -3235,7 +3235,7 @@ func (w *http2responseWriter) Header() http.Header { } // checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode. -func http2checkWriteHeaderCode(code int) { +func checkWriteHeaderCode(code int) { // Issue 22880: require valid WriteHeader status codes. // For now we only enforce that it's three digits. // In the future we might block things over 599 (600 and above aren't defined @@ -3252,7 +3252,7 @@ func http2checkWriteHeaderCode(code int) { } } -func (w *http2responseWriter) WriteHeader(code int) { +func (w *responseWriter) WriteHeader(code int) { rws := w.rws if rws == nil { panic("WriteHeader called after Handler finished") @@ -3260,18 +3260,18 @@ func (w *http2responseWriter) WriteHeader(code int) { rws.writeHeader(code) } -func (rws *http2responseWriterState) writeHeader(code int) { +func (rws *responseWriterState) writeHeader(code int) { if !rws.wroteHeader { - http2checkWriteHeaderCode(code) + checkWriteHeaderCode(code) rws.wroteHeader = true rws.status = code if len(rws.handlerHeader) > 0 { - rws.snapHeader = http2cloneHeader(rws.handlerHeader) + rws.snapHeader = cloneHeader(rws.handlerHeader) } } } -func http2cloneHeader(h http.Header) http.Header { +func cloneHeader(h http.Header) http.Header { h2 := make(http.Header, len(h)) for k, vv := range h { vv2 := make([]string, len(vv)) @@ -3289,16 +3289,16 @@ func http2cloneHeader(h http.Header) http.Header { // * -> chunkWriter{rws} // * -> responseWriterState.writeChunk(p []byte) // * -> responseWriterState.writeChunk (most of the magic; see comment there) -func (w *http2responseWriter) Write(p []byte) (n int, err error) { +func (w *responseWriter) Write(p []byte) (n int, err error) { return w.write(len(p), p, "") } -func (w *http2responseWriter) WriteString(s string) (n int, err error) { +func (w *responseWriter) WriteString(s string) (n int, err error) { return w.write(len(s), nil, s) } // either dataB or dataS is non-zero. -func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) { +func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) { rws := w.rws if rws == nil { panic("Write called after Handler finished") @@ -3306,7 +3306,7 @@ func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n if !rws.wroteHeader { w.WriteHeader(200) } - if !http2bodyAllowedForStatus(rws.status) { + if !bodyAllowedForStatus(rws.status) { return 0, http.ErrBodyNotAllowed } rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) // only one can be set @@ -3321,7 +3321,7 @@ func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n return rws.bw.WriteString(dataS) } -func (w *http2responseWriter) handlerDone() { +func (w *responseWriter) handlerDone() { rws := w.rws dirty := rws.dirty rws.handlerDone = true @@ -3334,7 +3334,7 @@ func (w *http2responseWriter) handlerDone() { // there might still be write goroutines outstanding // from the serverConn referencing the rws memory. See // issue 20704. - http2responseWriterStatePool.Put(rws) + responseWriterStatePool.Put(rws) } } @@ -3344,9 +3344,9 @@ var ( errPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") ) -var _ http.Pusher = (*http2responseWriter)(nil) +var _ http.Pusher = (*responseWriter)(nil) -func (w *http2responseWriter) Push(target string, opts *http.PushOptions) error { +func (w *responseWriter) Push(target string, opts *http.PushOptions) error { st := w.rws.stream sc := st.sc sc.serveG.checkNotOn() @@ -3409,7 +3409,7 @@ func (w *http2responseWriter) Push(target string, opts *http.PushOptions) error return fmt.Errorf("promised request headers cannot include %q", k) } } - if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil { + if err := checkValidHTTP2RequestHeaders(opts.Header); err != nil { return err } @@ -3420,12 +3420,12 @@ func (w *http2responseWriter) Push(target string, opts *http.PushOptions) error return fmt.Errorf("method %q must be GET or HEAD", opts.Method) } - msg := &http2startPushRequest{ + msg := &startPushRequest{ parent: st, method: opts.Method, url: u, - header: http2cloneHeader(opts.Header), - done: http2errChanPool.Get().(chan error), + header: cloneHeader(opts.Header), + done: errChanPool.Get().(chan error), } select { @@ -3442,26 +3442,26 @@ func (w *http2responseWriter) Push(target string, opts *http.PushOptions) error case <-st.cw: return errStreamClosed case err := <-msg.done: - http2errChanPool.Put(msg.done) + errChanPool.Put(msg.done) return err } } -type http2startPushRequest struct { - parent *http2stream +type startPushRequest struct { + parent *stream method string url *url.URL header http.Header done chan error } -func (sc *http2serverConn) startPush(msg *http2startPushRequest) { +func (sc *serverConn) startPush(msg *startPushRequest) { sc.serveG.check() // http://tools.ietf.org/html/rfc7540#section-6.6. // PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that // is in either the "open" or "half-closed (remote)" state. - if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote { + if msg.parent.state != stateOpen && msg.parent.state != stateHalfClosedRemote { // responseWriter.Push checks that the stream is peer-initiated. msg.done <- errStreamClosed return @@ -3505,13 +3505,13 @@ func (sc *http2serverConn) startPush(msg *http2startPushRequest) { // transition to "half closed (remote)" after sending the initial HEADERS, but // we start in "half closed (remote)" for simplicity. // See further comments at the definition of stateHalfClosedRemote. - promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote) - rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{ + promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote) + rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{ method: msg.method, scheme: msg.url.Scheme, authority: msg.url.Host, path: msg.url.RequestURI(), - header: http2cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE + header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE }) if err != nil { // Should not happen, since we've already validated msg.url. @@ -3522,8 +3522,8 @@ func (sc *http2serverConn) startPush(msg *http2startPushRequest) { return promisedID, nil } - sc.writeFrame(http2FrameWriteRequest{ - write: &http2writePushPromise{ + sc.writeFrame(FrameWriteRequest{ + write: &writePushPromise{ streamID: msg.parent.id, method: msg.method, url: msg.url, @@ -3536,7 +3536,7 @@ func (sc *http2serverConn) startPush(msg *http2startPushRequest) { } // From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 -var http2connHeaders = []string{ +var connHeaders = []string{ "Connection", "Keep-Alive", "Proxy-Connection", @@ -3547,8 +3547,8 @@ var http2connHeaders = []string{ // checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, // per RFC 7540 Section 8.1.2.2. // The returned error is reported to users. -func http2checkValidHTTP2RequestHeaders(h http.Header) error { - for _, k := range http2connHeaders { +func checkValidHTTP2RequestHeaders(h http.Header) error { + for _, k := range connHeaders { if _, ok := h[k]; ok { return fmt.Errorf("request header %q is not valid in HTTP/2", k) } @@ -3560,7 +3560,7 @@ func http2checkValidHTTP2RequestHeaders(h http.Header) error { return nil } -func http2new400Handler(err error) http.HandlerFunc { +func new400Handler(err error) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) } @@ -3569,7 +3569,7 @@ func http2new400Handler(err error) http.HandlerFunc { // h1ServerKeepAlivesDisabled reports whether hs has its keep-alives // disabled. See comments on h1ServerShutdownChan above for why // the code is written this way. -func http2h1ServerKeepAlivesDisabled(hs *http.Server) bool { +func h1ServerKeepAlivesDisabled(hs *http.Server) bool { var x interface{} = hs type I interface { doKeepAlives() bool @@ -3580,7 +3580,7 @@ func http2h1ServerKeepAlivesDisabled(hs *http.Server) bool { return false } -func (sc *http2serverConn) countError(name string, err error) error { +func (sc *serverConn) countError(name string, err error) error { if sc == nil || sc.srv == nil { return err } @@ -3589,18 +3589,18 @@ func (sc *http2serverConn) countError(name string, err error) error { return err } var typ string - var code http2ErrCode + var code ErrCode switch e := err.(type) { - case http2ConnectionError: + case ConnectionError: typ = "conn" - code = http2ErrCode(e) - case http2StreamError: + code = ErrCode(e) + case StreamError: typ = "stream" - code = http2ErrCode(e.Code) + code = ErrCode(e.Code) default: return err } - codeStr := http2errCodeName[code] + codeStr := errCodeName[code] if codeStr == "" { codeStr = strconv.Itoa(int(code)) } @@ -3609,8 +3609,8 @@ func (sc *http2serverConn) countError(name string, err error) error { } // writeFramer is implemented by any type that is used to write frames. -type http2writeFramer interface { - writeFrame(http2writeContext) error +type writeFramer interface { + writeFrame(writeContext) error // staysWithinBuffer reports whether this writer promises that // it will only write less than or equal to size bytes, and it @@ -3622,14 +3622,14 @@ type http2writeFramer interface { // types below. All the writeFrame methods below are scheduled via the // frame writing scheduler (see writeScheduler in writesched.go). // -// This interface is implemented by *http2serverConn. +// This interface is implemented by *serverConn. // // TODO: decide whether to a) use this in the client code (which didn't // end up using this yet, because it has a simpler design, not // currently implementing priorities), or b) delete this and // make the server code a bit more concrete. -type http2writeContext interface { - Framer() *http2Framer +type writeContext interface { + Framer() *Framer Flush() error CloseConn() error // HeaderEncoder returns an HPACK encoder that writes to the @@ -3640,11 +3640,11 @@ type http2writeContext interface { // writeEndsStream reports whether w writes a frame that will transition // the stream to a half-closed local state. This returns false for RST_STREAM, // which closes the entire stream (not just the local half). -func http2writeEndsStream(w http2writeFramer) bool { +func writeEndsStream(w writeFramer) bool { switch v := w.(type) { - case *http2writeData: + case *writeData: return v.endStream - case *http2writeResHeaders: + case *writeResHeaders: return v.endStream case nil: // This can only happen if the caller reuses w after it's @@ -3655,110 +3655,110 @@ func http2writeEndsStream(w http2writeFramer) bool { return false } -type http2flushFrameWriter struct{} +type flushFrameWriter struct{} -func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error { +func (flushFrameWriter) writeFrame(ctx writeContext) error { return ctx.Flush() } -func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false } +func (flushFrameWriter) staysWithinBuffer(max int) bool { return false } -type http2writeSettings []http2Setting +type writeSettings []Setting -func (s http2writeSettings) staysWithinBuffer(max int) bool { +func (s writeSettings) staysWithinBuffer(max int) bool { const settingSize = 6 // uint16 + uint32 - return http2frameHeaderLen+settingSize*len(s) <= max + return frameHeaderLen+settingSize*len(s) <= max } -func (s http2writeSettings) writeFrame(ctx http2writeContext) error { - return ctx.Framer().WriteSettings([]http2Setting(s)...) +func (s writeSettings) writeFrame(ctx writeContext) error { + return ctx.Framer().WriteSettings([]Setting(s)...) } -type http2writeGoAway struct { +type writeGoAway struct { maxStreamID uint32 - code http2ErrCode + code ErrCode } -func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { +func (p *writeGoAway) writeFrame(ctx writeContext) error { err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil) ctx.Flush() // ignore error: we're hanging up on them anyway return err } -func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } // flushes +func (*writeGoAway) staysWithinBuffer(max int) bool { return false } // flushes -type http2writeData struct { +type writeData struct { streamID uint32 p []byte endStream bool } -func (w *http2writeData) String() string { +func (w *writeData) String() string { return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream) } -func (w *http2writeData) writeFrame(ctx http2writeContext) error { +func (w *writeData) writeFrame(ctx writeContext) error { return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) } -func (w *http2writeData) staysWithinBuffer(max int) bool { - return http2frameHeaderLen+len(w.p) <= max +func (w *writeData) staysWithinBuffer(max int) bool { + return frameHeaderLen+len(w.p) <= max } // handlerPanicRST is the message sent from handler goroutines when // the handler panics. -type http2handlerPanicRST struct { +type handlerPanicRST struct { StreamID uint32 } -func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error { - return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal) +func (hp handlerPanicRST) writeFrame(ctx writeContext) error { + return ctx.Framer().WriteRSTStream(hp.StreamID, ErrCodeInternal) } -func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } +func (hp handlerPanicRST) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max } -func (se http2StreamError) writeFrame(ctx http2writeContext) error { +func (se StreamError) writeFrame(ctx writeContext) error { return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) } -func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } +func (se StreamError) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max } -type http2writePingAck struct{ pf *http2PingFrame } +type writePingAck struct{ pf *PingFrame } -func (w http2writePingAck) writeFrame(ctx http2writeContext) error { +func (w writePingAck) writeFrame(ctx writeContext) error { return ctx.Framer().WritePing(true, w.pf.Data) } -func (w http2writePingAck) staysWithinBuffer(max int) bool { - return http2frameHeaderLen+len(w.pf.Data) <= max +func (w writePingAck) staysWithinBuffer(max int) bool { + return frameHeaderLen+len(w.pf.Data) <= max } -type http2writeSettingsAck struct{} +type writeSettingsAck struct{} -func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error { +func (writeSettingsAck) writeFrame(ctx writeContext) error { return ctx.Framer().WriteSettingsAck() } -func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max } +func (writeSettingsAck) staysWithinBuffer(max int) bool { return frameHeaderLen <= max } // splitHeaderBlock splits headerBlock into fragments so that each fragment fits // in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true // for the first/last fragment, respectively. -func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error { +func splitHeaderBlock(ctx writeContext, headerBlock []byte, fn func(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error) error { // For now we're lazy and just pick the minimum MAX_FRAME_SIZE // that all peers must support (16KB). Later we could care // more and send larger frames if the peer advertised it, but // there's little point. Most headers are small anyway (so we // generally won't have CONTINUATION frames), and extra frames // only waste 9 bytes anyway. - const http2maxFrameSize = 16384 + const maxFrameSize = 16384 first := true for len(headerBlock) > 0 { frag := headerBlock - if len(frag) > http2maxFrameSize { - frag = frag[:http2maxFrameSize] + if len(frag) > maxFrameSize { + frag = frag[:maxFrameSize] } headerBlock = headerBlock[len(frag):] if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil { @@ -3771,7 +3771,7 @@ func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ct // writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames // for HTTP response headers or trailers from a server handler. -type http2writeResHeaders struct { +type writeResHeaders struct { streamID uint32 httpResCode int // 0 means no ":status" line h http.Header // may be nil @@ -3783,14 +3783,14 @@ type http2writeResHeaders struct { contentLength string } -func http2encKV(enc *hpack.Encoder, k, v string) { - if http2VerboseLogs { +func encKV(enc *hpack.Encoder, k, v string) { + if VerboseLogs { log.Printf("http2: server encoding header %q = %q", k, v) } enc.WriteField(hpack.HeaderField{Name: k, Value: v}) } -func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { +func (w *writeResHeaders) staysWithinBuffer(max int) bool { // TODO: this is a common one. It'd be nice to return true // here and get into the fast path if we could be clever and // calculate the size fast enough, or at least a conservative @@ -3801,24 +3801,24 @@ func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { return false } -func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { +func (w *writeResHeaders) writeFrame(ctx writeContext) error { enc, buf := ctx.HeaderEncoder() buf.Reset() if w.httpResCode != 0 { - http2encKV(enc, ":status", http2httpCodeString(w.httpResCode)) + encKV(enc, ":status", httpCodeString(w.httpResCode)) } - http2encodeHeaders(enc, w.h, w.trailers) + encodeHeaders(enc, w.h, w.trailers) if w.contentType != "" { - http2encKV(enc, "content-type", w.contentType) + encKV(enc, "content-type", w.contentType) } if w.contentLength != "" { - http2encKV(enc, "content-length", w.contentLength) + encKV(enc, "content-length", w.contentLength) } if w.date != "" { - http2encKV(enc, "date", w.date) + encKV(enc, "date", w.date) } headerBlock := buf.Bytes() @@ -3826,12 +3826,12 @@ func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { panic("unexpected empty hpack") } - return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) + return splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) } -func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { +func (w *writeResHeaders) writeHeaderBlock(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error { if firstFrag { - return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + return ctx.Framer().WriteHeaders(HeadersFrameParam{ StreamID: w.streamID, BlockFragment: frag, EndStream: w.endStream, @@ -3842,7 +3842,7 @@ func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []by } // writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. -type http2writePushPromise struct { +type writePushPromise struct { streamID uint32 // pusher stream method string // for :method url *url.URL // for :scheme, :authority, :path @@ -3854,32 +3854,32 @@ type http2writePushPromise struct { promisedID uint32 } -func (w *http2writePushPromise) staysWithinBuffer(max int) bool { +func (w *writePushPromise) staysWithinBuffer(max int) bool { // TODO: see writeResHeaders.staysWithinBuffer return false } -func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error { +func (w *writePushPromise) writeFrame(ctx writeContext) error { enc, buf := ctx.HeaderEncoder() buf.Reset() - http2encKV(enc, ":method", w.method) - http2encKV(enc, ":scheme", w.url.Scheme) - http2encKV(enc, ":authority", w.url.Host) - http2encKV(enc, ":path", w.url.RequestURI()) - http2encodeHeaders(enc, w.h, nil) + encKV(enc, ":method", w.method) + encKV(enc, ":scheme", w.url.Scheme) + encKV(enc, ":authority", w.url.Host) + encKV(enc, ":path", w.url.RequestURI()) + encodeHeaders(enc, w.h, nil) headerBlock := buf.Bytes() if len(headerBlock) == 0 { panic("unexpected empty hpack") } - return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) + return splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) } -func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { +func (w *writePushPromise) writeHeaderBlock(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error { if firstFrag { - return ctx.Framer().WritePushPromise(http2PushPromiseParam{ + return ctx.Framer().WritePushPromise(PushPromiseParam{ StreamID: w.streamID, PromiseID: w.promisedID, BlockFragment: frag, @@ -3889,15 +3889,15 @@ func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []b return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) } -type http2write100ContinueHeadersFrame struct { +type write100ContinueHeadersFrame struct { streamID uint32 } -func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) error { +func (w write100ContinueHeadersFrame) writeFrame(ctx writeContext) error { enc, buf := ctx.HeaderEncoder() buf.Reset() - http2encKV(enc, ":status", "100") - return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + encKV(enc, ":status", "100") + return ctx.Framer().WriteHeaders(HeadersFrameParam{ StreamID: w.streamID, BlockFragment: buf.Bytes(), EndStream: false, @@ -3905,42 +3905,42 @@ func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) err }) } -func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { +func (w write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { // Sloppy but conservative: return 9+2*(len(":status")+len("100")) <= max } -type http2writeWindowUpdate struct { +type writeWindowUpdate struct { streamID uint32 // or 0 for conn-level n uint32 } -func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } +func (wu writeWindowUpdate) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max } -func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { +func (wu writeWindowUpdate) writeFrame(ctx writeContext) error { return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) } // encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) // is encoded only if k is in keys. -func http2encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { +func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { if keys == nil { - sorter := http2sorterPool.Get().(*http2sorter) + sorter := sorterPool.Get().(*sorter) // Using defer here, since the returned keys from the // sorter.Keys method is only valid until the sorter // is returned: - defer http2sorterPool.Put(sorter) + defer sorterPool.Put(sorter) keys = sorter.Keys(h) } for _, k := range keys { vv := h[k] - k, ascii := http2lowerHeader(k) + k, ascii := lowerHeader(k) if !ascii { // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // field names have to be ASCII characters (just as in HTTP/1.x). continue } - if !http2validWireHeaderFieldName(k) { + if !validWireHeaderFieldName(k) { // Skip it as backup paranoia. Per // golang.org/issue/14048, these should // already be rejected at a higher level. @@ -3957,18 +3957,18 @@ func http2encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { if isTE && v != "trailers" { continue } - http2encKV(enc, k, v) + encKV(enc, k, v) } } } // WriteScheduler is the interface implemented by HTTP/2 write schedulers. // Methods are never called concurrently. -type http2WriteScheduler interface { +type WriteScheduler interface { // OpenStream opens a new stream in the write scheduler. // It is illegal to call this with streamID=0 or with a streamID that is // already open -- the call may panic. - OpenStream(streamID uint32, options http2OpenStreamOptions) + OpenStream(streamID uint32, options OpenStreamOptions) // CloseStream closes a stream in the write scheduler. Any frames queued on // this stream should be discarded. It is illegal to call this on a stream @@ -3979,38 +3979,38 @@ type http2WriteScheduler interface { // on a stream that has not yet been opened or has been closed. Note that // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See: // https://tools.ietf.org/html/rfc7540#section-5.1 - AdjustStream(streamID uint32, priority http2PriorityParam) + AdjustStream(streamID uint32, priority PriorityParam) // Push queues a frame in the scheduler. In most cases, this will not be // called with wr.StreamID()!=0 unless that stream is currently open. The one // exception is RST_STREAM frames, which may be sent on idle or closed streams. - Push(wr http2FrameWriteRequest) + Push(wr FrameWriteRequest) // Pop dequeues the next frame to write. Returns false if no frames can // be written. Frames with a given wr.StreamID() are Pop'd in the same // order they are Push'd, except RST_STREAM frames. No frames should be // discarded except by CloseStream. - Pop() (wr http2FrameWriteRequest, ok bool) + Pop() (wr FrameWriteRequest, ok bool) } // OpenStreamOptions specifies extra options for WriteScheduler.OpenStream. -type http2OpenStreamOptions struct { +type OpenStreamOptions struct { // PusherID is zero if the stream was initiated by the client. Otherwise, // PusherID names the stream that pushed the newly opened stream. PusherID uint32 } // FrameWriteRequest is a request to write a frame. -type http2FrameWriteRequest struct { +type FrameWriteRequest struct { // write is the interface value that does the writing, once the // WriteScheduler has selected this frame to write. The write // functions are all defined in write.go. - write http2writeFramer + write writeFramer // stream is the stream on which this frame will be written. // nil for non-stream frames like PING and SETTINGS. // nil for RST_STREAM streams, which use the StreamError.StreamID field instead. - stream *http2stream + stream *stream // done, if non-nil, must be a buffered channel with space for // 1 message and is sent the return value from write (or an @@ -4020,10 +4020,10 @@ type http2FrameWriteRequest struct { // StreamID returns the id of the stream this frame will be written to. // 0 is used for non-stream frames such as PING and SETTINGS. -func (wr http2FrameWriteRequest) StreamID() uint32 { +func (wr FrameWriteRequest) StreamID() uint32 { if wr.stream == nil { - if se, ok := wr.write.(http2StreamError); ok { - // (*http2serverConn).resetStream doesn't set + if se, ok := wr.write.(StreamError); ok { + // (*serverConn).resetStream doesn't set // stream because it doesn't necessarily have // one. So special case this type of write // message. @@ -4036,14 +4036,14 @@ func (wr http2FrameWriteRequest) StreamID() uint32 { // isControl reports whether wr is a control frame for MaxQueuedControlFrames // purposes. That includes non-stream frames and RST_STREAM frames. -func (wr http2FrameWriteRequest) isControl() bool { +func (wr FrameWriteRequest) isControl() bool { return wr.stream == nil } // DataSize returns the number of flow control bytes that must be consumed // to write this entire frame. This is 0 for non-DATA frames. -func (wr http2FrameWriteRequest) DataSize() int { - if wd, ok := wr.write.(*http2writeData); ok { +func (wr FrameWriteRequest) DataSize() int { + if wd, ok := wr.write.(*writeData); ok { return len(wd.p) } return 0 @@ -4059,11 +4059,11 @@ func (wr http2FrameWriteRequest) DataSize() int { // returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and // 'rest' contains the remaining bytes. The consumed bytes are deducted from the // underlying stream's flow control budget. -func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) { - var empty http2FrameWriteRequest +func (wr FrameWriteRequest) Consume(n int32) (FrameWriteRequest, FrameWriteRequest, int) { + var empty FrameWriteRequest // Non-DATA frames are always consumed whole. - wd, ok := wr.write.(*http2writeData) + wd, ok := wr.write.(*writeData) if !ok || len(wd.p) == 0 { return wr, empty, 1 } @@ -4081,9 +4081,9 @@ func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2 } if len(wd.p) > int(allowed) { wr.stream.flow.take(allowed) - consumed := http2FrameWriteRequest{ + consumed := FrameWriteRequest{ stream: wr.stream, - write: &http2writeData{ + write: &writeData{ streamID: wd.streamID, p: wd.p[:allowed], // Even if the original had endStream set, there @@ -4095,9 +4095,9 @@ func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2 // this intermediate frame, so no need to wait. done: nil, } - rest := http2FrameWriteRequest{ + rest := FrameWriteRequest{ stream: wr.stream, - write: &http2writeData{ + write: &writeData{ streamID: wd.streamID, p: wd.p[allowed:], endStream: wd.endStream, @@ -4114,7 +4114,7 @@ func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2 } // String is for debugging only. -func (wr http2FrameWriteRequest) String() string { +func (wr FrameWriteRequest) String() string { var des string if s, ok := wr.write.(fmt.Stringer); ok { des = s.String() @@ -4126,7 +4126,7 @@ func (wr http2FrameWriteRequest) String() string { // replyToWriter sends err to wr.done and panics if the send must block // This does nothing if wr.done is nil. -func (wr *http2FrameWriteRequest) replyToWriter(err error) { +func (wr *FrameWriteRequest) replyToWriter(err error) { if wr.done == nil { return } @@ -4139,24 +4139,24 @@ func (wr *http2FrameWriteRequest) replyToWriter(err error) { } // writeQueue is used by implementations of WriteScheduler. -type http2writeQueue struct { - s []http2FrameWriteRequest +type writeQueue struct { + s []FrameWriteRequest } -func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } +func (q *writeQueue) empty() bool { return len(q.s) == 0 } -func (q *http2writeQueue) push(wr http2FrameWriteRequest) { +func (q *writeQueue) push(wr FrameWriteRequest) { q.s = append(q.s, wr) } -func (q *http2writeQueue) shift() http2FrameWriteRequest { +func (q *writeQueue) shift() FrameWriteRequest { if len(q.s) == 0 { panic("invalid use of queue") } wr := q.s[0] // TODO: less copy-happy queue. copy(q.s, q.s[1:]) - q.s[len(q.s)-1] = http2FrameWriteRequest{} + q.s[len(q.s)-1] = FrameWriteRequest{} q.s = q.s[:len(q.s)-1] return wr } @@ -4165,14 +4165,14 @@ func (q *http2writeQueue) shift() http2FrameWriteRequest { // entirely consumed, it is removed from the queue. If the frame // is partially consumed, the frame is kept with the consumed // bytes removed. Returns true iff any bytes were consumed. -func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) { +func (q *writeQueue) consume(n int32) (FrameWriteRequest, bool) { if len(q.s) == 0 { - return http2FrameWriteRequest{}, false + return FrameWriteRequest{}, false } consumed, rest, numresult := q.s[0].Consume(n) switch numresult { case 0: - return http2FrameWriteRequest{}, false + return FrameWriteRequest{}, false case 1: q.shift() case 2: @@ -4181,24 +4181,24 @@ func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) { return consumed, true } -type http2writeQueuePool []*http2writeQueue +type writeQueuePool []*writeQueue // put inserts an unused writeQueue into the pool. // put inserts an unused writeQueue into the pool. -func (p *http2writeQueuePool) put(q *http2writeQueue) { +func (p *writeQueuePool) put(q *writeQueue) { for i := range q.s { - q.s[i] = http2FrameWriteRequest{} + q.s[i] = FrameWriteRequest{} } q.s = q.s[:0] *p = append(*p, q) } // get returns an empty writeQueue. -func (p *http2writeQueuePool) get() *http2writeQueue { +func (p *writeQueuePool) get() *writeQueue { ln := len(*p) if ln == 0 { - return new(http2writeQueue) + return new(writeQueue) } x := ln - 1 q := (*p)[x] @@ -4208,12 +4208,12 @@ func (p *http2writeQueuePool) get() *http2writeQueue { } // RFC 7540, Section 5.3.5: the default weight is 16. -const http2priorityDefaultWeight = 15 // 16 = 15 + 1 +const priorityDefaultWeight = 15 // 16 = 15 + 1 // PriorityWriteSchedulerConfig configures a priorityWriteScheduler. -type http2PriorityWriteSchedulerConfig struct { +type PriorityWriteSchedulerConfig struct { // MaxClosedNodesInTree controls the maximum number of closed streams to - // retain in the priority tree. http2Setting this to zero saves a small amount + // retain in the priority tree. Setting this to zero saves a small amount // of memory at the cost of performance. // // See RFC 7540, Section 5.3.4: @@ -4227,7 +4227,7 @@ type http2PriorityWriteSchedulerConfig struct { MaxClosedNodesInTree int // MaxIdleNodesInTree controls the maximum number of idle streams to - // retain in the priority tree. http2Setting this to zero saves a small amount + // retain in the priority tree. Setting this to zero saves a small amount // of memory at the cost of performance. // // See RFC 7540, Section 5.3.4: @@ -4252,19 +4252,19 @@ type http2PriorityWriteSchedulerConfig struct { // NewPriorityWriteScheduler constructs a WriteScheduler that schedules // frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3. // If cfg is nil, default options are used. -func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler { +func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler { if cfg == nil { // For justification of these defaults, see: // https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY - cfg = &http2PriorityWriteSchedulerConfig{ + cfg = &PriorityWriteSchedulerConfig{ MaxClosedNodesInTree: 10, MaxIdleNodesInTree: 10, ThrottleOutOfOrderWrites: false, } } - ws := &http2priorityWriteScheduler{ - nodes: make(map[uint32]*http2priorityNode), + ws := &priorityWriteScheduler{ + nodes: make(map[uint32]*priorityNode), maxClosedNodesInTree: cfg.MaxClosedNodesInTree, maxIdleNodesInTree: cfg.MaxIdleNodesInTree, enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, @@ -4278,32 +4278,32 @@ func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http return ws } -type http2priorityNodeState int +type priorityNodeState int const ( - http2priorityNodeOpen http2priorityNodeState = iota - http2priorityNodeClosed - http2priorityNodeIdle + priorityNodeOpen priorityNodeState = iota + priorityNodeClosed + priorityNodeIdle ) // priorityNode is a node in an HTTP/2 priority tree. // Each node is associated with a single stream ID. // See RFC 7540, Section 5.3. -type http2priorityNode struct { - q http2writeQueue // queue of pending frames to write - id uint32 // id of the stream, or 0 for the root of the tree - weight uint8 // the actual weight is weight+1, so the value is in [1,256] - state http2priorityNodeState // open | closed | idle - bytes int64 // number of bytes written by this node, or 0 if closed - subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree +type priorityNode struct { + q writeQueue // queue of pending frames to write + id uint32 // id of the stream, or 0 for the root of the tree + weight uint8 // the actual weight is weight+1, so the value is in [1,256] + state priorityNodeState // open | closed | idle + bytes int64 // number of bytes written by this node, or 0 if closed + subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree // These links form the priority tree. - parent *http2priorityNode - kids *http2priorityNode // start of the kids list - prev, next *http2priorityNode // doubly-linked list of siblings + parent *priorityNode + kids *priorityNode // start of the kids list + prev, next *priorityNode // doubly-linked list of siblings } -func (n *http2priorityNode) setParent(parent *http2priorityNode) { +func (n *priorityNode) setParent(parent *priorityNode) { if n == parent { panic("setParent to self") } @@ -4338,7 +4338,7 @@ func (n *http2priorityNode) setParent(parent *http2priorityNode) { } } -func (n *http2priorityNode) addBytes(b int64) { +func (n *priorityNode) addBytes(b int64) { n.bytes += b for ; n != nil; n = n.parent { n.subtreeBytes += b @@ -4351,7 +4351,7 @@ func (n *http2priorityNode) addBytes(b int64) { // // f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true // if any ancestor p of n is still open (ignoring the root node). -func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool { +func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f func(*priorityNode, bool) bool) bool { if !n.q.empty() && f(n, openParent) { return true } @@ -4362,7 +4362,7 @@ func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2prior // Don't consider the root "open" when updating openParent since // we can't send data frames on the root stream (only control frames). if n.id != 0 { - openParent = openParent || (n.state == http2priorityNodeOpen) + openParent = openParent || (n.state == priorityNodeOpen) } // Common case: only one kid or all kids have the same weight. @@ -4392,7 +4392,7 @@ func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2prior *tmp = append(*tmp, n.kids) n.kids.setParent(nil) } - sort.Sort(http2sortPriorityNodeSiblings(*tmp)) + sort.Sort(sortPriorityNodeSiblings(*tmp)) for i := len(*tmp) - 1; i >= 0; i-- { (*tmp)[i].setParent(n) // setParent inserts at the head of n.kids } @@ -4404,13 +4404,13 @@ func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2prior return false } -type http2sortPriorityNodeSiblings []*http2priorityNode +type sortPriorityNodeSiblings []*priorityNode -func (z http2sortPriorityNodeSiblings) Len() int { return len(z) } +func (z sortPriorityNodeSiblings) Len() int { return len(z) } -func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } +func (z sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } -func (z http2sortPriorityNodeSiblings) Less(i, k int) bool { +func (z sortPriorityNodeSiblings) Less(i, k int) bool { // Prefer the subtree that has sent fewer bytes relative to its weight. // See sections 5.3.2 and 5.3.4. wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) @@ -4424,13 +4424,13 @@ func (z http2sortPriorityNodeSiblings) Less(i, k int) bool { return bi/bk <= wi/wk } -type http2priorityWriteScheduler struct { +type priorityWriteScheduler struct { // root is the root of the priority tree, where root.id = 0. // The root queues control frames that are not associated with any stream. - root http2priorityNode + root priorityNode // nodes maps stream ids to priority tree nodes. - nodes map[uint32]*http2priorityNode + nodes map[uint32]*priorityNode // maxID is the maximum stream id in nodes. maxID uint32 @@ -4438,7 +4438,7 @@ type http2priorityWriteScheduler struct { // lists of nodes that have been closed or are idle, but are kept in // the tree for improved prioritization. When the lengths exceed either // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. - closedNodes, idleNodes []*http2priorityNode + closedNodes, idleNodes []*priorityNode // From the config. maxClosedNodesInTree int @@ -4447,19 +4447,19 @@ type http2priorityWriteScheduler struct { enableWriteThrottle bool // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. - tmp []*http2priorityNode + tmp []*priorityNode // pool of empty queues for reuse. - queuePool http2writeQueuePool + queuePool writeQueuePool } -func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { +func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) { // The stream may be currently idle but cannot be opened or closed. if curr := ws.nodes[streamID]; curr != nil { - if curr.state != http2priorityNodeIdle { + if curr.state != priorityNodeIdle { panic(fmt.Sprintf("stream %d already opened", streamID)) } - curr.state = http2priorityNodeOpen + curr.state = priorityNodeOpen return } @@ -4471,11 +4471,11 @@ func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2 if parent == nil { parent = &ws.root } - n := &http2priorityNode{ + n := &priorityNode{ q: *ws.queuePool.get(), id: streamID, - weight: http2priorityDefaultWeight, - state: http2priorityNodeOpen, + weight: priorityDefaultWeight, + state: priorityNodeOpen, } n.setParent(parent) ws.nodes[streamID] = n @@ -4484,19 +4484,19 @@ func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2 } } -func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) { +func (ws *priorityWriteScheduler) CloseStream(streamID uint32) { if streamID == 0 { panic("violation of WriteScheduler interface: cannot close stream 0") } if ws.nodes[streamID] == nil { panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) } - if ws.nodes[streamID].state != http2priorityNodeOpen { + if ws.nodes[streamID].state != priorityNodeOpen { panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) } n := ws.nodes[streamID] - n.state = http2priorityNodeClosed + n.state = priorityNodeClosed n.addBytes(-n.bytes) q := n.q @@ -4509,7 +4509,7 @@ func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) { } } -func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { +func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) { if streamID == 0 { panic("adjustPriority on root") } @@ -4523,11 +4523,11 @@ func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority ht return } ws.maxID = streamID - n = &http2priorityNode{ + n = &priorityNode{ q: *ws.queuePool.get(), id: streamID, - weight: http2priorityDefaultWeight, - state: http2priorityNodeIdle, + weight: priorityDefaultWeight, + state: priorityNodeIdle, } n.setParent(&ws.root) ws.nodes[streamID] = n @@ -4539,7 +4539,7 @@ func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority ht parent := ws.nodes[priority.StreamDep] if parent == nil { n.setParent(&ws.root) - n.weight = http2priorityDefaultWeight + n.weight = priorityDefaultWeight return } @@ -4580,8 +4580,8 @@ func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority ht n.weight = priority.Weight } -func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { - var n *http2priorityNode +func (ws *priorityWriteScheduler) Push(wr FrameWriteRequest) { + var n *priorityNode if id := wr.StreamID(); id == 0 { n = &ws.root } else { @@ -4601,8 +4601,8 @@ func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { n.q.push(wr) } -func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) { - ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool { +func (ws *priorityWriteScheduler) Pop() (wr FrameWriteRequest, ok bool) { + ws.root.walkReadyInOrder(false, &ws.tmp, func(n *priorityNode, openParent bool) bool { limit := int32(math.MaxInt32) if openParent { limit = ws.writeThrottleLimit @@ -4628,7 +4628,7 @@ func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool return wr, ok } -func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorityNode, maxSize int, n *http2priorityNode) { +func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, maxSize int, n *priorityNode) { if maxSize == 0 { return } @@ -4642,7 +4642,7 @@ func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorit *list = append(*list, n) } -func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) { +func (ws *priorityWriteScheduler) removeNode(n *priorityNode) { for k := n.kids; k != nil; k = k.next { k.setParent(n.parent) } @@ -4654,28 +4654,28 @@ func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) { // priorities. Control frames like SETTINGS and PING are written before DATA // frames, but if no control frames are queued and multiple streams have queued // HEADERS or DATA frames, Pop selects a ready stream arbitrarily. -func http2NewRandomWriteScheduler() http2WriteScheduler { - return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)} +func NewRandomWriteScheduler() WriteScheduler { + return &randomWriteScheduler{sq: make(map[uint32]*writeQueue)} } -type http2randomWriteScheduler struct { +type randomWriteScheduler struct { // zero are frames not associated with a specific stream. - zero http2writeQueue + zero writeQueue // sq contains the stream-specific queues, keyed by stream ID. // When a stream is idle, closed, or emptied, it's deleted // from the map. - sq map[uint32]*http2writeQueue + sq map[uint32]*writeQueue // pool of empty queues for reuse. - queuePool http2writeQueuePool + queuePool writeQueuePool } -func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { +func (ws *randomWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) { // no-op: idle streams are not tracked } -func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { +func (ws *randomWriteScheduler) CloseStream(streamID uint32) { q, ok := ws.sq[streamID] if !ok { return @@ -4684,11 +4684,11 @@ func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { ws.queuePool.put(q) } -func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { +func (ws *randomWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) { // no-op: priorities are ignored } -func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { +func (ws *randomWriteScheduler) Push(wr FrameWriteRequest) { if wr.isControl() { ws.zero.push(wr) return @@ -4702,7 +4702,7 @@ func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { q.push(wr) } -func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { +func (ws *randomWriteScheduler) Pop() (FrameWriteRequest, bool) { // Control and RST_STREAM frames first. if !ws.zero.empty() { return ws.zero.shift(), true @@ -4717,7 +4717,7 @@ func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { return wr, true } } - return http2FrameWriteRequest{}, false + return FrameWriteRequest{}, false } var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered") @@ -4757,15 +4757,15 @@ type serverTester struct { cc net.Conn // client conn t testing.TB ts *httptest.Server - fr *http2Framer + fr *Framer serverLogBuf safeBuffer // logger for httptest.Server logFilter []string // substrings to filter out scMu sync.Mutex // guards sc - sc *http2serverConn + sc *serverConn hpackDec *hpack.Decoder decodedHeaders [][2]string - // If http2debug!=2, then we capture Frame debug logs that will be written + // If debug!=2, then we capture Frame debug logs that will be written // to t.Log after a test fails. The read and write logs use separate locks // and buffers so we don't accidentally introduce synchronization between // the read and write goroutines, which may hide data races. @@ -4798,14 +4798,14 @@ func (st *serverTester) decodeHeader(headerBlock []byte) (pairs [][2]string) { } func init() { - http2testHookOnPanicMu = new(sync.Mutex) - http2goAwayTimeout = 25 * time.Millisecond + testHookOnPanicMu = new(sync.Mutex) + goAwayTimeout = 25 * time.Millisecond } func resetHooks() { - http2testHookOnPanicMu.Lock() - http2testHookOnPanic = nil - http2testHookOnPanicMu.Unlock() + testHookOnPanicMu.Lock() + testHookOnPanic = nil + testHookOnPanicMu.Unlock() } // ConfigureServer adds HTTP/2 support to a net/http Server. @@ -4813,14 +4813,14 @@ func resetHooks() { // The configuration conf may be nil. // // ConfigureServer must be called before s begins serving. -func http2ConfigureServer(s *http.Server, conf *http2Server) error { +func ConfigureServer(s *http.Server, conf *Server) error { if s == nil { panic("nil *http.Server") } if conf == nil { - conf = new(http2Server) + conf = new(Server) } - conf.state = &http2serverInternalState{activeConns: make(map[*http2serverConn]struct{})} + conf.state = &serverInternalState{activeConns: make(map[*serverConn]struct{})} if h1, h2 := s, conf; h2.IdleTimeout == 0 { if h1.IdleTimeout != 0 { h2.IdleTimeout = h1.IdleTimeout @@ -4860,10 +4860,10 @@ func http2ConfigureServer(s *http.Server, conf *http2Server) error { s.TLSConfig.PreferServerCipherSuites = true - if !http2strSliceContains(s.TLSConfig.NextProtos, http2NextProtoTLS) { - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS) + if !strSliceContains(s.TLSConfig.NextProtos, NextProtoTLS) { + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, NextProtoTLS) } - if !http2strSliceContains(s.TLSConfig.NextProtos, "http/1.1") { + if !strSliceContains(s.TLSConfig.NextProtos, "http/1.1") { s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "http/1.1") } @@ -4871,8 +4871,8 @@ func http2ConfigureServer(s *http.Server, conf *http2Server) error { s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} } protoHandler := func(hs *http.Server, c *tls.Conn, h http.Handler) { - if http2testHookOnConn != nil { - http2testHookOnConn() + if testHookOnConn != nil { + testHookOnConn() } // The TLSNextProto interface predates contexts, so // the net/http package passes down its per-connection @@ -4886,13 +4886,13 @@ func http2ConfigureServer(s *http.Server, conf *http2Server) error { if bc, ok := h.(baseContexter); ok { ctx = bc.BaseContext() } - conf.ServeConn(c, &http2ServeConnOpts{ + conf.ServeConn(c, &ServeConnOpts{ Context: ctx, Handler: h, BaseConfig: hs, }) } - s.TLSNextProto[http2NextProtoTLS] = protoHandler + s.TLSNextProto[NextProtoTLS] = protoHandler return nil } @@ -4927,18 +4927,18 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} tlsConfig := &tls.Config{ InsecureSkipVerify: true, - NextProtos: []string{http2NextProtoTLS}, + NextProtos: []string{NextProtoTLS}, } var onlyServer, quiet, framerReuseFrames bool - h2server := new(http2Server) + h2server := new(Server) for _, opt := range opts { switch v := opt.(type) { case func(*tls.Config): v(tlsConfig) case func(*httptest.Server): v(ts) - case func(*http2Server): + case func(*Server): v(h2server) case serverTesterOpt: switch v { @@ -4956,14 +4956,14 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} } } - http2ConfigureServer(ts.Config, h2server) + ConfigureServer(ts.Config, h2server) st := &serverTester{ t: t, ts: ts, } st.hpackEnc = hpack.NewEncoder(&st.headerBuf) - st.hpackDec = hpack.NewDecoder(http2initialHeaderTableSize, st.onHeaderField) + st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField) ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config if quiet { @@ -4973,10 +4973,10 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} } ts.StartTLS() - if http2VerboseLogs { + if VerboseLogs { t.Logf("Running test server at: %s", ts.URL) } - http2testHookGetServerConn = func(v *http2serverConn) { + testHookGetServerConn = func(v *serverConn) { st.scMu.Lock() defer st.scMu.Unlock() st.sc = v @@ -4988,11 +4988,11 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} t.Fatal(err) } st.cc = cc - st.fr = http2NewFramer(cc, cc) + st.fr = NewFramer(cc, cc) if framerReuseFrames { st.fr.SetReuseFrames() } - if !http2logFrameReads && !http2logFrameWrites { + if !logFrameReads && !logFrameWrites { st.fr.debugReadLoggerf = func(m string, v ...interface{}) { m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n" st.frameReadLogMu.Lock() @@ -5022,16 +5022,16 @@ func (st *serverTester) addLogFilter(phrase string) { st.logFilter = append(st.logFilter, phrase) } -func (st *serverTester) stream(id uint32) *http2stream { - ch := make(chan *http2stream, 1) +func (st *serverTester) stream(id uint32) *stream { + ch := make(chan *stream, 1) st.sc.serveMsgCh <- func(int) { ch <- st.sc.streams[id] } return <-ch } -func (st *serverTester) http2streamState(id uint32) http2streamState { - ch := make(chan http2streamState, 1) +func (st *serverTester) streamState(id uint32) streamState { + ch := make(chan streamState, 1) st.sc.serveMsgCh <- func(int) { state, _ := st.sc.state(id) ch <- state @@ -5097,10 +5097,10 @@ func (st *serverTester) Close() { // greet initiates the client's HTTP/2 connection into a state where // frames may be sent. func (st *serverTester) greet() { - st.greetAndCheckSettings(func(http2Setting) error { return nil }) + st.greetAndCheckSettings(func(Setting) error { return nil }) } -func (st *serverTester) greetAndCheckSettings(checkSetting func(s http2Setting) error) { +func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error) { st.writePreface() st.writeInitialSettings() st.wantSettings().ForeachSetting(checkSetting) @@ -5116,17 +5116,17 @@ func (st *serverTester) greetAndCheckSettings(checkSetting func(s http2Setting) st.t.Fatal(err) } switch f := f.(type) { - case *http2SettingsFrame: - if !f.Header().Flags.Has(http2FlagSettingsAck) { + case *SettingsFrame: + if !f.Header().Flags.Has(FlagSettingsAck) { st.t.Fatal("Settings Frame didn't have ACK set") } gotSettingsAck = true - case *http2WindowUpdateFrame: - if f.http2FrameHeader.StreamID != 0 { - st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.http2FrameHeader.StreamID) + case *WindowUpdateFrame: + if f.FrameHeader.StreamID != 0 { + st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID) } - incr := uint32((&http2Server{}).initialConnRecvWindowSize() - http2initialWindowSize) + incr := uint32((&Server{}).initialConnRecvWindowSize() - initialWindowSize) if f.Increment != incr { st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr) } @@ -5146,12 +5146,12 @@ func (st *serverTester) greetAndCheckSettings(checkSetting func(s http2Setting) } func (st *serverTester) writePreface() { - n, err := st.cc.Write(http2clientPreface) + n, err := st.cc.Write(clientPreface) if err != nil { st.t.Fatalf("Error writing client preface: %v", err) } - if n != len(http2clientPreface) { - st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(http2clientPreface)) + if n != len(clientPreface) { + st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(clientPreface)) } } @@ -5167,13 +5167,13 @@ func (st *serverTester) writeSettingsAck() { } } -func (st *serverTester) writeHeaders(p http2HeadersFrameParam) { +func (st *serverTester) writeHeaders(p HeadersFrameParam) { if err := st.fr.WriteHeaders(p); err != nil { st.t.Fatalf("Error writing HEADERS: %v", err) } } -func (st *serverTester) writePriority(id uint32, p http2PriorityParam) { +func (st *serverTester) writePriority(id uint32, p PriorityParam) { if err := st.fr.WritePriority(id, p); err != nil { st.t.Fatalf("Error writing PRIORITY: %v", err) } @@ -5269,7 +5269,7 @@ func (st *serverTester) encodeHeader(headers ...string) []byte { // bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set. func (st *serverTester) bodylessReq1(headers ...string) { - st.writeHeaders(http2HeadersFrameParam{ + st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(headers...), EndStream: true, @@ -5289,96 +5289,96 @@ func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, p } } -func (st *serverTester) readFrame() (http2Frame, error) { +func (st *serverTester) readFrame() (Frame, error) { return st.fr.ReadFrame() } -func (st *serverTester) wantHeaders() *http2HeadersFrame { +func (st *serverTester) wantHeaders() *HeadersFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a HEADERS frame: %v", err) } - hf, ok := f.(*http2HeadersFrame) + hf, ok := f.(*HeadersFrame) if !ok { - st.t.Fatalf("got a %T; want *http2HeadersFrame", f) + st.t.Fatalf("got a %T; want *HeadersFrame", f) } return hf } -func (st *serverTester) wantContinuation() *http2ContinuationFrame { +func (st *serverTester) wantContinuation() *ContinuationFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a CONTINUATION frame: %v", err) } - cf, ok := f.(*http2ContinuationFrame) + cf, ok := f.(*ContinuationFrame) if !ok { - st.t.Fatalf("got a %T; want *http2ContinuationFrame", f) + st.t.Fatalf("got a %T; want *ContinuationFrame", f) } return cf } -func (st *serverTester) wantData() *http2DataFrame { +func (st *serverTester) wantData() *DataFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a DATA frame: %v", err) } - df, ok := f.(*http2DataFrame) + df, ok := f.(*DataFrame) if !ok { - st.t.Fatalf("got a %T; want *http2DataFrame", f) + st.t.Fatalf("got a %T; want *DataFrame", f) } return df } -func (st *serverTester) wantSettings() *http2SettingsFrame { +func (st *serverTester) wantSettings() *SettingsFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err) } - sf, ok := f.(*http2SettingsFrame) + sf, ok := f.(*SettingsFrame) if !ok { - st.t.Fatalf("got a %T; want *http2SettingsFrame", f) + st.t.Fatalf("got a %T; want *SettingsFrame", f) } return sf } -func (st *serverTester) wantPing() *http2PingFrame { +func (st *serverTester) wantPing() *PingFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a PING frame: %v", err) } - pf, ok := f.(*http2PingFrame) + pf, ok := f.(*PingFrame) if !ok { - st.t.Fatalf("got a %T; want *http2PingFrame", f) + st.t.Fatalf("got a %T; want *PingFrame", f) } return pf } -func (st *serverTester) wantGoAway() *http2GoAwayFrame { +func (st *serverTester) wantGoAway() *GoAwayFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a GOAWAY frame: %v", err) } - gf, ok := f.(*http2GoAwayFrame) + gf, ok := f.(*GoAwayFrame) if !ok { - st.t.Fatalf("got a %T; want *http2GoAwayFrame", f) + st.t.Fatalf("got a %T; want *GoAwayFrame", f) } return gf } -func (st *serverTester) wantRSTStream(streamID uint32, errCode http2ErrCode) { +func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting an RSTStream frame: %v", err) } - rs, ok := f.(*http2RSTStreamFrame) + rs, ok := f.(*RSTStreamFrame) if !ok { - st.t.Fatalf("got a %T; want *http2RSTStreamFrame", f) + st.t.Fatalf("got a %T; want *RSTStreamFrame", f) } - if rs.http2FrameHeader.StreamID != streamID { - st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.http2FrameHeader.StreamID, streamID) + if rs.FrameHeader.StreamID != streamID { + st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID) } if rs.ErrCode != errCode { - st.t.Fatalf("RSTStream http2ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode) + st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode) } } @@ -5387,12 +5387,12 @@ func (st *serverTester) wantWindowUpdate(streamID, incr uint32) { if err != nil { st.t.Fatalf("Error while expecting a WINDOW_UPDATE frame: %v", err) } - wu, ok := f.(*http2WindowUpdateFrame) + wu, ok := f.(*WindowUpdateFrame) if !ok { - st.t.Fatalf("got a %T; want *http2WindowUpdateFrame", f) + st.t.Fatalf("got a %T; want *WindowUpdateFrame", f) } - if wu.http2FrameHeader.StreamID != streamID { - st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.http2FrameHeader.StreamID, streamID) + if wu.FrameHeader.StreamID != streamID { + st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID) } if wu.Increment != incr { st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr) @@ -5404,21 +5404,21 @@ func (st *serverTester) wantSettingsAck() { if err != nil { st.t.Fatal(err) } - sf, ok := f.(*http2SettingsFrame) + sf, ok := f.(*SettingsFrame) if !ok { st.t.Fatalf("Wanting a settings ACK, received a %T", f) } - if !sf.Header().Flags.Has(http2FlagSettingsAck) { + if !sf.Header().Flags.Has(FlagSettingsAck) { st.t.Fatal("Settings Frame didn't have ACK set") } } -func (st *serverTester) wantPushPromise() *http2PushPromiseFrame { +func (st *serverTester) wantPushPromise() *PushPromiseFrame { f, err := st.readFrame() if err != nil { st.t.Fatal(err) } - ppf, ok := f.(*http2PushPromiseFrame) + ppf, ok := f.(*PushPromiseFrame) if !ok { st.t.Fatalf("Wanted PushPromise, received %T", ppf) } diff --git a/h2_trace.go b/internal/http2/trace.go similarity index 65% rename from h2_trace.go rename to internal/http2/trace.go index cf40a785..0be4bc2a 100644 --- a/h2_trace.go +++ b/internal/http2/trace.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "net/http" @@ -11,24 +11,24 @@ import ( "time" ) -func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { +func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { return trace != nil && trace.WroteHeaderField != nil } -func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) { +func traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) { if trace != nil && trace.WroteHeaderField != nil { trace.WroteHeaderField(k, []string{v}) } } -func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error { +func traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error { if trace != nil { return trace.Got1xxResponse } return nil } -func http2traceGetConn(req *http.Request, hostPort string) { +func traceGetConn(req *http.Request, hostPort string) { trace := httptrace.ContextClientTrace(req.Context()) if trace == nil || trace.GetConn == nil { return @@ -36,7 +36,7 @@ func http2traceGetConn(req *http.Request, hostPort string) { trace.GetConn(hostPort) } -func http2traceGotConn(req *http.Request, cc *http2ClientConn, reused bool) { +func traceGotConn(req *http.Request, cc *ClientConn, reused bool) { trace := httptrace.ContextClientTrace(req.Context()) if trace == nil || trace.GotConn == nil { return @@ -53,31 +53,31 @@ func http2traceGotConn(req *http.Request, cc *http2ClientConn, reused bool) { trace.GotConn(ci) } -func http2traceWroteHeaders(trace *httptrace.ClientTrace) { +func traceWroteHeaders(trace *httptrace.ClientTrace) { if trace != nil && trace.WroteHeaders != nil { trace.WroteHeaders() } } -func http2traceGot100Continue(trace *httptrace.ClientTrace) { +func traceGot100Continue(trace *httptrace.ClientTrace) { if trace != nil && trace.Got100Continue != nil { trace.Got100Continue() } } -func http2traceWait100Continue(trace *httptrace.ClientTrace) { +func traceWait100Continue(trace *httptrace.ClientTrace) { if trace != nil && trace.Wait100Continue != nil { trace.Wait100Continue() } } -func http2traceWroteRequest(trace *httptrace.ClientTrace, err error) { +func traceWroteRequest(trace *httptrace.ClientTrace, err error) { if trace != nil && trace.WroteRequest != nil { trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) } } -func http2traceFirstResponseByte(trace *httptrace.ClientTrace) { +func traceFirstResponseByte(trace *httptrace.ClientTrace) { if trace != nil && trace.GotFirstResponseByte != nil { trace.GotFirstResponseByte() } diff --git a/h2_transport.go b/internal/http2/transport.go similarity index 76% rename from h2_transport.go rename to internal/http2/transport.go index 9fd48fb5..6133594d 100644 --- a/h2_transport.go +++ b/internal/http2/transport.go @@ -4,7 +4,7 @@ // Transport code. -package req +package http2 import ( "bufio" @@ -16,6 +16,11 @@ import ( "errors" "fmt" "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/common" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/header" + reqtls "github.com/imroc/req/v3/internal/tls" + "github.com/imroc/req/v3/internal/transport" "io" "log" "math" @@ -40,32 +45,32 @@ import ( const ( // transportDefaultConnFlow is how many connection-level flow control // tokens we give the server at start-up, past the default 64k. - http2transportDefaultConnFlow = 1 << 30 + transportDefaultConnFlow = 1 << 30 // transportDefaultStreamFlow is how many stream-level flow // control tokens we announce to the peer, and how many bytes // we buffer per stream. - http2transportDefaultStreamFlow = 4 << 20 + transportDefaultStreamFlow = 4 << 20 // transportDefaultStreamMinRefresh is the minimum number of bytes we'll send // a stream-level WINDOW_UPDATE for at a time. - http2transportDefaultStreamMinRefresh = 4 << 10 + transportDefaultStreamMinRefresh = 4 << 10 // initialMaxConcurrentStreams is a connections maxConcurrentStreams until // it's received servers initial SETTINGS frame, which corresponds with the // spec's minimum recommended value. - http2initialMaxConcurrentStreams = 100 + initialMaxConcurrentStreams = 100 // defaultMaxConcurrentStreams is a connections default maxConcurrentStreams // if the server doesn't include one in its initial SETTINGS frame. - http2defaultMaxConcurrentStreams = 1000 + defaultMaxConcurrentStreams = 1000 ) // Transport is an HTTP/2 Transport. // // A Transport internally caches connections to servers. It is safe // for concurrent use by multiple goroutines. -type http2Transport struct { +type Transport struct { // DialTLS specifies an optional dial function for creating // TLS connections for requests. // @@ -77,17 +82,7 @@ type http2Transport struct { // ConnPool optionally specifies an alternate connection pool to use. // If nil, the default is used. - ConnPool http2ClientConnPool - - // DisableCompression, if true, prevents the Transport from - // requesting compression with an "Accept-Encoding: gzip" - // request header when the Request contains no existing - // Accept-Encoding value. If the Transport requests gzip on - // its own and gets a gzipped response, it's transparently - // decoded in the Response.Body. However, if the user - // explicitly requested gzip it is not automatically - // uncompressed. - DisableCompression bool + ConnPool ClientConnPool // AllowHTTP, if true, permits HTTP/2 requests using the insecure, // plain-text "http" scheme. Note that this does not enable h2c support. @@ -139,13 +134,13 @@ type http2Transport struct { // t1, if non-nil, is the standard library Transport using // this transport. Its settings are used (but not its // RoundTrip method, etc). - t1 *Transport + transport.Interface connPoolOnce sync.Once - connPoolOrDef http2ClientConnPool // non-nil version of ConnPool + connPoolOrDef ClientConnPool // non-nil version of ConnPool } -func (t *http2Transport) maxHeaderListSize() uint32 { +func (t *Transport) maxHeaderListSize() uint32 { if t.MaxHeaderListSize == 0 { return 10 << 20 } @@ -155,11 +150,7 @@ func (t *http2Transport) maxHeaderListSize() uint32 { return t.MaxHeaderListSize } -func (t *http2Transport) disableCompression() bool { - return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) -} - -func (t *http2Transport) pingTimeout() time.Duration { +func (t *Transport) pingTimeout() time.Duration { if t.PingTimeout == 0 { return 15 * time.Second } @@ -167,42 +158,33 @@ func (t *http2Transport) pingTimeout() time.Duration { } -// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. -// It returns an error if t1 has already been HTTP/2-enabled. -// -// Use ConfigureTransports instead to configure the HTTP/2 Transport. -func http2ConfigureTransport(t1 *Transport) error { - _, err := http2ConfigureTransports(t1) - return err -} - // ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2. // It returns a new HTTP/2 Transport for further configuration. // It returns an error if t1 has already been HTTP/2-enabled. -func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { - connPool := new(http2clientConnPool) - t2 := &http2Transport{ - ConnPool: http2noDialClientConnPool{connPool}, - t1: t1, +func ConfigureTransports(t1 transport.Interface) (*Transport, error) { + connPool := new(clientConnPool) + t2 := &Transport{ + ConnPool: noDialClientConnPool{connPool}, + Interface: t1, } connPool.t = t2 - if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil { + if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil { return nil, err } - if t1.TLSClientConfig == nil { - t1.TLSClientConfig = new(tls.Config) + if t1.TLSClientConfig() == nil { + t1.SetTLSClientConfig(new(tls.Config)) } - if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { - t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) + if !strSliceContains(t1.TLSClientConfig().NextProtos, "h2") { + t1.TLSClientConfig().NextProtos = append([]string{"h2"}, t1.TLSClientConfig().NextProtos...) } - if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { - t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") + if !strSliceContains(t1.TLSClientConfig().NextProtos, "http/1.1") { + t1.TLSClientConfig().NextProtos = append(t1.TLSClientConfig().NextProtos, "http/1.1") } - upgradeFn := func(authority string, c TLSConn) http.RoundTripper { - addr := http2authorityAddr("https", authority) + upgradeFn := func(authority string, c reqtls.Conn) http.RoundTripper { + addr := authorityAddr("https", authority) if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { go c.Close() - return http2erringRoundTripper{err} + return erringRoundTripper{err} } else if !used { // Turns out we don't need this c. // For example, two goroutines made requests to the same host @@ -212,34 +194,34 @@ func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { } return t2 } - if m := t1.TLSNextProto; len(m) == 0 { - t1.TLSNextProto = map[string]func(string, TLSConn) http.RoundTripper{ + if m := t1.TLSNextProto(); len(m) == 0 { + t1.SetTLSNextProto(map[string]func(string, reqtls.Conn) http.RoundTripper{ "h2": upgradeFn, - } + }) } else { m["h2"] = upgradeFn } return t2, nil } -func (t *http2Transport) connPool() http2ClientConnPool { +func (t *Transport) connPool() ClientConnPool { t.connPoolOnce.Do(t.initConnPool) return t.connPoolOrDef } -func (t *http2Transport) initConnPool() { +func (t *Transport) initConnPool() { if t.ConnPool != nil { t.connPoolOrDef = t.ConnPool } else { - t.connPoolOrDef = &http2clientConnPool{t: t} + t.connPoolOrDef = &clientConnPool{t: t} } } // ClientConn is the state of a single HTTP/2 client connection to an // HTTP/2 server. -type http2ClientConn struct { +type ClientConn struct { currentRequest *http.Request - t *http2Transport + t *Transport tconn net.Conn // usually TLSConn, except specialized impls tlsState *tls.ConnectionState // nil only for specialized impls reused uint32 // whether conn is being reused; atomic @@ -255,17 +237,17 @@ type http2ClientConn struct { mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes - flow http2flow // our conn-level flow control quota (cs.flow is per stream) - inflow http2flow // peer's conn-level flow control + flow flow // our conn-level flow control quota (cs.flow is per stream) + inflow flow // peer's conn-level flow control doNotReuse bool // whether conn is marked to not be reused for any future requests closing bool closed bool - seenSettings bool // true if we've seen a settings frame, false otherwise - wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back - goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received - goAwayDebug string // goAway frame's debug data, retained as a string - streams map[uint32]*http2clientStream // client-initiated - streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip + seenSettings bool // true if we've seen a settings frame, false otherwise + wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back + goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received + goAwayDebug string // goAway frame's debug data, retained as a string + streams map[uint32]*clientStream // client-initiated + streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip nextStreamID uint32 pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams pings map[[8]byte]chan struct{} // in flight ping data to notification channel @@ -288,7 +270,7 @@ type http2ClientConn struct { // Only acquire both at the same time when changing peer settings. wmu sync.Mutex bw *bufio.Writer - fr *http2Framer + fr *Framer werr error // first write error that has occurred hbuf bytes.Buffer // HPACK encoder writes into this henc *hpack.Encoder @@ -296,8 +278,8 @@ type http2ClientConn struct { // clientStream is the state for a single HTTP/2 stream. One of these // is created for each Transport.RoundTrip call. -type http2clientStream struct { - cc *http2ClientConn +type clientStream struct { + cc *ClientConn // Fields of Request that we may access even after the response body is closed. ctx context.Context @@ -305,7 +287,7 @@ type http2clientStream struct { trace *httptrace.ClientTrace // or nil ID uint32 - bufPipe http2pipe // buffered pipe with the flow-controlled response payload + bufPipe pipe // buffered pipe with the flow-controlled response payload requestedGzip bool isHead bool @@ -320,10 +302,10 @@ type http2clientStream struct { respHeaderRecv chan struct{} // closed when headers are received res *http.Response // set if respHeaderRecv is closed - flow http2flow // guarded by cc.mu - inflow http2flow // guarded by cc.mu - bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read - readErr error // sticky read error; owned by transportResponseBody.Read + flow flow // guarded by cc.mu + inflow flow // guarded by cc.mu + bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read + readErr error // sticky read error; owned by transportResponseBody.Read reqBody io.ReadCloser reqBodyContentLength int64 // -1 means unknown @@ -345,24 +327,24 @@ type http2clientStream struct { resTrailer *http.Header // client's Response.Trailer } -var http2got1xxFuncForTests func(int, textproto.MIMEHeader) error +var got1xxFuncForTests func(int, textproto.MIMEHeader) error // get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func, // if any. It returns nil if not set or if the Go version is too old. -func (cs *http2clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error { - if fn := http2got1xxFuncForTests; fn != nil { +func (cs *clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error { + if fn := got1xxFuncForTests; fn != nil { return fn } - return http2traceGot1xxResponseFunc(cs.trace) + return traceGot1xxResponseFunc(cs.trace) } -func (cs *http2clientStream) abortStream(err error) { +func (cs *clientStream) abortStream(err error) { cs.cc.mu.Lock() defer cs.cc.mu.Unlock() cs.abortStreamLocked(err) } -func (cs *http2clientStream) abortStreamLocked(err error) { +func (cs *clientStream) abortStreamLocked(err error) { cs.abortOnce.Do(func() { cs.abortErr = err close(cs.abort) @@ -378,7 +360,7 @@ func (cs *http2clientStream) abortStreamLocked(err error) { } } -func (cs *http2clientStream) abortRequestBodyWrite() { +func (cs *clientStream) abortRequestBodyWrite() { cc := cs.cc cc.mu.Lock() defer cc.mu.Unlock() @@ -389,13 +371,13 @@ func (cs *http2clientStream) abortRequestBodyWrite() { } } -type http2stickyErrWriter struct { +type stickyErrWriter struct { conn net.Conn timeout time.Duration err *error } -func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) { +func (sew stickyErrWriter) Write(p []byte) (n int, err error) { if *sew.err != nil { return 0, *sew.err } @@ -422,25 +404,25 @@ func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) { // bundled version (in h2_bundle.go with a rewritten type name) or // from a user's x/net/http2. As such, as it has a unique method name // (IsHTTP2NoCachedConnError) that net/http sniffs for via func -// isNoCachedConnError. -type http2noCachedConnError struct{} +// IsNoCachedConnError. +type noCachedConnError struct{} -func (http2noCachedConnError) IsHTTP2NoCachedConnError() {} +func (noCachedConnError) IsHTTP2NoCachedConnError() {} -func (http2noCachedConnError) Error() string { return "http2: no cached connection was available" } +func (noCachedConnError) Error() string { return "http2: no cached connection was available" } -// isNoCachedConnError reports whether err is of type noCachedConnError +// IsNoCachedConnError reports whether err is of type noCachedConnError // or its equivalent renamed type in net/http2's h2_bundle.go. Both types // may coexist in the same running program. -func http2isNoCachedConnError(err error) bool { +func IsNoCachedConnError(err error) bool { _, ok := err.(interface{ IsHTTP2NoCachedConnError() }) return ok } -var http2ErrNoCachedConn error = http2noCachedConnError{} +var ErrNoCachedConn error = noCachedConnError{} // RoundTripOpt are options for the Transport.RoundTripOpt method. -type http2RoundTripOpt struct { +type RoundTripOpt struct { // OnlyCachedConn controls whether RoundTripOpt may // create a new TCP connection. If set true and // no cached connection is available, RoundTripOpt @@ -448,13 +430,13 @@ type http2RoundTripOpt struct { OnlyCachedConn bool } -func (t *http2Transport) RoundTrip(req *http.Request) (*http.Response, error) { - return t.RoundTripOpt(req, http2RoundTripOpt{}) +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripOpt(req, RoundTripOpt{}) } // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) // and returns a host:port. The port 443 is added if needed. -func http2authorityAddr(scheme string, authority string) (addr string) { +func authorityAddr(scheme string, authority string) (addr string) { host, port, err := net.SplitHostPort(authority) if err != nil { // authority didn't have a port port = "443" @@ -474,12 +456,12 @@ func http2authorityAddr(scheme string, authority string) (addr string) { } // RoundTripOpt is like RoundTrip, but takes options. -func (t *http2Transport) RoundTripOpt(req *http.Request, opt http2RoundTripOpt) (*http.Response, error) { +func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { return nil, errors.New("http2: unsupported scheme") } - addr := http2authorityAddr(req.URL.Scheme, req.URL.Host) + addr := authorityAddr(req.URL.Scheme, req.URL.Host) for retry := 0; ; retry++ { cc, err := t.connPool().GetClientConn(req, addr) if err != nil { @@ -487,10 +469,10 @@ func (t *http2Transport) RoundTripOpt(req *http.Request, opt http2RoundTripOpt) return nil, err } reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1) - http2traceGotConn(req, cc, reused) + traceGotConn(req, cc, reused) res, err := cc.RoundTrip(req) if err != nil && retry <= 6 { - if req, err = http2shouldRetryRequest(req, err); err == nil { + if req, err = shouldRetryRequest(req, err); err == nil { // After the first retry, do exponential backoff with 10% jitter. if retry == 0 { t.vlogf("RoundTrip retrying after failure: %v", err) @@ -518,8 +500,8 @@ func (t *http2Transport) RoundTripOpt(req *http.Request, opt http2RoundTripOpt) // CloseIdleConnections closes any connections which were previously // connected from previous requests but are now sitting idle. // It does not interrupt any connections currently in use. -func (t *http2Transport) CloseIdleConnections() { - if cp, ok := t.connPool().(http2clientConnPoolIdleCloser); ok { +func (t *Transport) CloseIdleConnections() { + if cp, ok := t.connPool().(clientConnPoolIdleCloser); ok { cp.closeIdleConnections() } } @@ -534,8 +516,8 @@ var ( // response headers. It is always called with a non-nil error. // It returns either a request to retry (either the same request, or a // modified clone), or an error if the request can't be replayed. -func http2shouldRetryRequest(req *http.Request, err error) (*http.Request, error) { - if !http2canRetryError(err) { +func shouldRetryRequest(req *http.Request, err error) (*http.Request, error) { + if !canRetryError(err) { return nil, err } // If the Body is nil (or http.NoBody), it's safe to reuse @@ -566,21 +548,21 @@ func http2shouldRetryRequest(req *http.Request, err error) (*http.Request, error return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err) } -func http2canRetryError(err error) bool { +func canRetryError(err error) bool { if err == errClientConnUnusable || err == errClientConnGotGoAway { return true } - if se, ok := err.(http2StreamError); ok { - if se.Code == http2ErrCodeProtocol && se.Cause == errFromPeer { + if se, ok := err.(StreamError); ok { + if se.Code == ErrCodeProtocol && se.Cause == errFromPeer { // See golang/go#47635, golang/go#42777 return true } - return se.Code == http2ErrCodeRefusedStream + return se.Code == ErrCodeRefusedStream } return false } -func (t *http2Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*http2ClientConn, error) { +func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -592,13 +574,13 @@ func (t *http2Transport) dialClientConn(ctx context.Context, addr string, single return t.newClientConn(tconn, singleUse) } -func (t *http2Transport) newTLSConfig(host string) *tls.Config { +func (t *Transport) newTLSConfig(host string) *tls.Config { cfg := new(tls.Config) - if t.t1 != nil && t.t1.TLSClientConfig != nil { - *cfg = *t.t1.TLSClientConfig.Clone() + if c := t.TLSClientConfig(); c != nil { + *cfg = *c.Clone() } - if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { - cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) + if !strSliceContains(cfg.NextProtos, NextProtoTLS) { + cfg.NextProtos = append([]string{NextProtoTLS}, cfg.NextProtos...) } if cfg.ServerName == "" { cfg.ServerName = host @@ -606,7 +588,7 @@ func (t *http2Transport) newTLSConfig(host string) *tls.Config { return cfg } -func (t *http2Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) { +func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) { if t.DialTLS != nil { return t.DialTLS } @@ -616,8 +598,8 @@ func (t *http2Transport) dialTLS(ctx context.Context) func(string, string, *tls. return nil, err } state := tlsCn.ConnectionState() - if p := state.NegotiatedProtocol; p != http2NextProtoTLS { - return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS) + if p := state.NegotiatedProtocol; p != NextProtoTLS { + return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS) } if !state.NegotiatedProtocolIsMutual { return nil, errors.New("http2: could not negotiate protocol mutually") @@ -626,64 +608,51 @@ func (t *http2Transport) dialTLS(ctx context.Context) func(string, string, *tls. } } -// disableKeepAlives reports whether connections should be closed as -// soon as possible after handling the first request. -func (t *http2Transport) disableKeepAlives() bool { - return t.t1 != nil && t.t1.DisableKeepAlives -} - -func (t *http2Transport) expectContinueTimeout() time.Duration { - if t.t1 == nil { - return 0 - } - return t.t1.ExpectContinueTimeout -} - -func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { - return t.newClientConn(c, t.disableKeepAlives()) +func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { + return t.newClientConn(c, t.DisableKeepAlives()) } -func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) { - cc := &http2ClientConn{ +func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) { + cc := &ClientConn{ t: t, tconn: c, readerDone: make(chan struct{}), nextStreamID: 1, - maxFrameSize: 16 << 10, // spec default - initialWindowSize: 65535, // spec default - maxConcurrentStreams: http2initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings. - peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. - streams: make(map[uint32]*http2clientStream), + maxFrameSize: 16 << 10, // spec default + initialWindowSize: 65535, // spec default + maxConcurrentStreams: initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings. + peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. + streams: make(map[uint32]*clientStream), singleUse: singleUse, wantSettingsAck: true, pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), } - if d := t.idleConnTimeout(); d != 0 { + if d := t.IdleConnTimeout(); d != 0 { cc.idleTimeout = d cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) } - if http2VerboseLogs { + if VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) } cc.cond = sync.NewCond(&cc.mu) - cc.flow.add(int32(http2initialWindowSize)) + cc.flow.add(int32(initialWindowSize)) // TODO: adjust this writer size to account for frame size + // MTU + crypto/tls record padding. - cc.bw = bufio.NewWriter(http2stickyErrWriter{ + cc.bw = bufio.NewWriter(stickyErrWriter{ conn: c, timeout: t.WriteByteTimeout, err: &cc.werr, }) cc.br = bufio.NewReader(c) - cc.fr = http2NewFramer(cc.bw, cc.br) + cc.fr = NewFramer(cc.bw, cc.br) cc.fr.cc = cc // for dump single request if t.CountError != nil { cc.fr.countError = t.CountError } - cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + cc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) cc.fr.MaxHeaderListSize = t.maxHeaderListSize() // TODO: SetMaxDynamicTableSize, SetMaxDynamicTableSizeLimit on @@ -694,23 +663,23 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client cc.nextStreamID = 3 } - if cs, ok := c.(http2connectionStater); ok { + if cs, ok := c.(connectionStater); ok { state := cs.ConnectionState() cc.tlsState = &state } - initialSettings := []http2Setting{ - {ID: http2SettingEnablePush, Val: 0}, - {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, + initialSettings := []Setting{ + {ID: SettingEnablePush, Val: 0}, + {ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow}, } if max := t.maxHeaderListSize(); max != 0 { - initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) + initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max}) } - cc.bw.Write(http2clientPreface) + cc.bw.Write(clientPreface) cc.fr.WriteSettings(initialSettings...) - cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow) - cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize) + cc.fr.WriteWindowUpdate(0, transportDefaultConnFlow) + cc.inflow.add(transportDefaultConnFlow + initialWindowSize) cc.bw.Flush() if cc.werr != nil { cc.Close() @@ -721,7 +690,7 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client return cc, nil } -func (cc *http2ClientConn) healthCheck() { +func (cc *ClientConn) healthCheck() { pingTimeout := cc.t.pingTimeout() // We don't need to periodically ping in the health check, because the readLoop of ClientConn will // trigger the healthCheck again if there is no frame received. @@ -739,13 +708,13 @@ func (cc *http2ClientConn) healthCheck() { } // SetDoNotReuse marks cc as not reusable for future HTTP requests. -func (cc *http2ClientConn) SetDoNotReuse() { +func (cc *ClientConn) SetDoNotReuse() { cc.mu.Lock() defer cc.mu.Unlock() cc.doNotReuse = true } -func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { +func (cc *ClientConn) setGoAway(f *GoAwayFrame) { cc.mu.Lock() defer cc.mu.Unlock() @@ -756,7 +725,7 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { if cc.goAwayDebug == "" { cc.goAwayDebug = string(f.DebugData()) } - if old != nil && old.ErrCode != http2ErrCodeNo { + if old != nil && old.ErrCode != ErrCodeNo { cc.goAway.ErrCode = old.ErrCode } last := f.LastStreamID @@ -772,7 +741,7 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { // // If the caller is going to immediately make a new request on this // connection, use ReserveNewRequest instead. -func (cc *http2ClientConn) CanTakeNewRequest() bool { +func (cc *ClientConn) CanTakeNewRequest() bool { cc.mu.Lock() defer cc.mu.Unlock() return cc.canTakeNewRequestLocked() @@ -781,7 +750,7 @@ func (cc *http2ClientConn) CanTakeNewRequest() bool { // ReserveNewRequest is like CanTakeNewRequest but also reserves a // concurrent stream in cc. The reservation is decremented on the // next call to RoundTrip. -func (cc *http2ClientConn) ReserveNewRequest() bool { +func (cc *ClientConn) ReserveNewRequest() bool { cc.mu.Lock() defer cc.mu.Unlock() if st := cc.idleStateLocked(); !st.canTakeNewRequest { @@ -792,7 +761,7 @@ func (cc *http2ClientConn) ReserveNewRequest() bool { } // ClientConnState describes the state of a ClientConn. -type http2ClientConnState struct { +type ClientConnState struct { // Closed is whether the connection is closed. Closed bool @@ -826,17 +795,17 @@ type http2ClientConnState struct { // clientConnIdleState describes the suitability of a client // connection to initiate a new RoundTrip request. -type http2clientConnIdleState struct { +type clientConnIdleState struct { canTakeNewRequest bool } -func (cc *http2ClientConn) idleState() http2clientConnIdleState { +func (cc *ClientConn) idleState() clientConnIdleState { cc.mu.Lock() defer cc.mu.Unlock() return cc.idleStateLocked() } -func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) { +func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) { if cc.singleUse && cc.nextStreamID > 1 { return } @@ -858,14 +827,14 @@ func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) { return } -func (cc *http2ClientConn) canTakeNewRequestLocked() bool { +func (cc *ClientConn) canTakeNewRequestLocked() bool { st := cc.idleStateLocked() return st.canTakeNewRequest } // tooIdleLocked reports whether this connection has been sitting idle // for too much wall time. -func (cc *http2ClientConn) tooIdleLocked() bool { +func (cc *ClientConn) tooIdleLocked() bool { // The Round(0) strips the monotonic clock reading so the // times are compared based on their wall time. We don't want // to reuse a connection that's been sitting idle during @@ -879,11 +848,11 @@ func (cc *http2ClientConn) tooIdleLocked() bool { // so this simply calls the synchronized closeIfIdle to shut down this // connection. The timer could just call closeIfIdle, but this is more // clear. -func (cc *http2ClientConn) onIdleTimeout() { +func (cc *ClientConn) onIdleTimeout() { cc.closeIfIdle() } -func (cc *http2ClientConn) closeConn() error { +func (cc *ClientConn) closeConn() error { t := time.AfterFunc(250*time.Millisecond, cc.forceCloseConn) defer t.Stop() return cc.tconn.Close() @@ -891,8 +860,8 @@ func (cc *http2ClientConn) closeConn() error { // A tls.Conn.Close can hang for a long time if the peer is unresponsive. // Try to shut it down more aggressively. -func (cc *http2ClientConn) forceCloseConn() { - tc, ok := cc.tconn.(NetConnWrapper) +func (cc *ClientConn) forceCloseConn() { + tc, ok := cc.tconn.(reqtls.NetConnWrapper) if !ok { return } @@ -901,7 +870,7 @@ func (cc *http2ClientConn) forceCloseConn() { } } -func (cc *http2ClientConn) closeIfIdle() { +func (cc *ClientConn) closeIfIdle() { cc.mu.Lock() if len(cc.streams) > 0 || cc.streamsReserved > 0 { cc.mu.Unlock() @@ -912,22 +881,22 @@ func (cc *http2ClientConn) closeIfIdle() { // TODO: do clients send GOAWAY too? maybe? Just Close: cc.mu.Unlock() - if http2VerboseLogs { + if VerboseLogs { cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, nextID-2) } cc.closeConn() } -func (cc *http2ClientConn) isDoNotReuseAndIdle() bool { +func (cc *ClientConn) isDoNotReuseAndIdle() bool { cc.mu.Lock() defer cc.mu.Unlock() return cc.doNotReuse && len(cc.streams) == 0 } -var http2shutdownEnterWaitStateHook = func() {} +var shutdownEnterWaitStateHook = func() {} // Shutdown gracefully closes the client connection, waiting for running streams to complete. -func (cc *http2ClientConn) Shutdown(ctx context.Context) error { +func (cc *ClientConn) Shutdown(ctx context.Context) error { if err := cc.sendGoAway(); err != nil { return err } @@ -949,7 +918,7 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error { cc.cond.Wait() } }() - http2shutdownEnterWaitStateHook() + shutdownEnterWaitStateHook() select { case <-done: return cc.closeConn() @@ -963,7 +932,7 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error { } } -func (cc *http2ClientConn) sendGoAway() error { +func (cc *ClientConn) sendGoAway() error { cc.mu.Lock() closing := cc.closing cc.closing = true @@ -977,7 +946,7 @@ func (cc *http2ClientConn) sendGoAway() error { cc.wmu.Lock() defer cc.wmu.Unlock() // Send a graceful shutdown frame to server - if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil { + if err := cc.fr.WriteGoAway(maxStreamID, ErrCodeNo, nil); err != nil { return err } if err := cc.bw.Flush(); err != nil { @@ -989,7 +958,7 @@ func (cc *http2ClientConn) sendGoAway() error { // closes the client connection immediately. In-flight requests are interrupted. // err is sent to streams. -func (cc *http2ClientConn) closeForError(err error) error { +func (cc *ClientConn) closeForError(err error) error { cc.mu.Lock() cc.closed = true for _, cs := range cc.streams { @@ -1003,13 +972,13 @@ func (cc *http2ClientConn) closeForError(err error) error { // Close closes the client connection immediately. // // In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. -func (cc *http2ClientConn) Close() error { +func (cc *ClientConn) Close() error { err := errors.New("http2: client connection force closed via ClientConn.Close") return cc.closeForError(err) } // closes the client connection immediately. In-flight requests are interrupted. -func (cc *http2ClientConn) closeForLostPing() error { +func (cc *ClientConn) closeForLostPing() error { err := errors.New("http2: client connection lost") if f := cc.t.CountError; f != nil { f("conn_close_lost_ping") @@ -1017,11 +986,7 @@ func (cc *http2ClientConn) closeForLostPing() error { return cc.closeForError(err) } -// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not -// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. -var errRequestCanceled = errors.New("net/http: request canceled") - -func http2commaSeparatedTrailers(req *http.Request) (string, error) { +func commaSeparatedTrailers(req *http.Request) (string, error) { keys := make([]string, 0, len(req.Trailer)) for k := range req.Trailer { k = http.CanonicalHeaderKey(k) @@ -1038,21 +1003,14 @@ func http2commaSeparatedTrailers(req *http.Request) (string, error) { return "", nil } -func (cc *http2ClientConn) responseHeaderTimeout() time.Duration { - if cc.t.t1 != nil { - return cc.t.t1.ResponseHeaderTimeout - } - // No way to do this (yet?) with just an http2.Transport. Probably - // no need. Request.Cancel this is the new way. We only need to support - // this for compatibility with the old http.Transport fields when - // we're doing transparent http2. - return 0 +func (cc *ClientConn) responseHeaderTimeout() time.Duration { + return cc.t.ResponseHeaderTimeout() } // checkConnHeaders checks whether req has any invalid connection-level headers. // per RFC 7540 section 8.1.2.2: Connection-Specific Header Fields. // Certain headers are special-cased as okay but not transmitted later. -func http2checkConnHeaders(req *http.Request) error { +func checkConnHeaders(req *http.Request) error { if v := req.Header.Get("Upgrade"); v != "" { return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"]) } @@ -1068,7 +1026,7 @@ func http2checkConnHeaders(req *http.Request) error { // actualContentLength returns a sanitized version of // req.ContentLength, where 0 actually means zero (not unknown) and -1 // means unknown. -func http2actualContentLength(req *http.Request) int64 { +func actualContentLength(req *http.Request) int64 { if req.Body == nil || req.Body == http.NoBody { return 0 } @@ -1078,31 +1036,31 @@ func http2actualContentLength(req *http.Request) int64 { return -1 } -func (cc *http2ClientConn) decrStreamReservations() { +func (cc *ClientConn) decrStreamReservations() { cc.mu.Lock() defer cc.mu.Unlock() cc.decrStreamReservationsLocked() } -func (cc *http2ClientConn) decrStreamReservationsLocked() { +func (cc *ClientConn) decrStreamReservationsLocked() { if cc.streamsReserved > 0 { cc.streamsReserved-- } } -func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { - if cc.t != nil && cc.t.t1 != nil && cc.t.t1.Debugf != nil { - cc.t.t1.Debugf("HTTP/2 %s %s", req.Method, req.URL.String()) +func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + if cc.t != nil && cc.t.Debugf() != nil { + cc.t.Debugf()("HTTP/2 %s %s", req.Method, req.URL.String()) } cc.currentRequest = req ctx := req.Context() - cs := &http2clientStream{ + cs := &clientStream{ cc: cc, ctx: ctx, reqCancel: req.Cancel, isHead: req.Method == "HEAD", reqBody: req.Body, - reqBodyContentLength: http2actualContentLength(req), + reqBodyContentLength: actualContentLength(req), trace: httptrace.ContextClientTrace(ctx), peerClosed: make(chan struct{}), abort: make(chan struct{}), @@ -1118,7 +1076,7 @@ func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) case <-ctx.Done(): return ctx.Err() case <-cs.reqCancel: - return errRequestCanceled + return common.ErrRequestCanceled } } @@ -1138,7 +1096,7 @@ func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) } res.Request = req res.TLS = cc.tlsState - if res.Body == http2noBody && http2actualContentLength(req) == 0 { + if res.Body == noBody && actualContentLength(req) == 0 { // If there isn't a request or response body still being // written, then wait for the stream to be closed before // RoundTrip returns. @@ -1170,8 +1128,8 @@ func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) cs.abortStream(err) return nil, err case <-cs.reqCancel: - cs.abortStream(errRequestCanceled) - return nil, errRequestCanceled + cs.abortStream(common.ErrRequestCanceled) + return nil, common.ErrRequestCanceled } } } @@ -1179,7 +1137,7 @@ func (cc *http2ClientConn) RoundTrip(req *http.Request) (*http.Response, error) // doRequest runs for the duration of the request lifetime. // // It sends the request and performs post-request cleanup (closing Request.Body, etc.). -func (cs *http2clientStream) doRequest(req *http.Request) { +func (cs *clientStream) doRequest(req *http.Request) { err := cs.writeRequest(req) cs.cleanupWriteRequest(err) } @@ -1191,11 +1149,11 @@ func (cs *http2clientStream) doRequest(req *http.Request) { // // It returns non-nil if the request ends otherwise. // If the returned error is StreamError, the error Code may be used in resetting the stream. -func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { +func (cs *clientStream) writeRequest(req *http.Request) (err error) { cc := cs.cc ctx := cs.ctx - if err := http2checkConnHeaders(req); err != nil { + if err := checkConnHeaders(req); err != nil { return err } @@ -1208,7 +1166,7 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { select { case cc.reqHeaderMu <- struct{}{}: case <-cs.reqCancel: - return errRequestCanceled + return common.ErrRequestCanceled case <-ctx.Done(): return ctx.Err() } @@ -1224,13 +1182,13 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { return err } cc.addStreamLocked(cs) // assigns stream ID - if http2isConnectionCloseRequest(req) { + if isConnectionCloseRequest(req) { cc.doNotReuse = true } cc.mu.Unlock() // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? - if !cc.t.disableCompression() && + if !cc.t.DisableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && !cs.isHead { @@ -1249,7 +1207,7 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { cs.requestedGzip = true } - continueTimeout := cc.t.expectContinueTimeout() + continueTimeout := cc.t.ExpectContinueTimeout() if continueTimeout != 0 { if !httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") { continueTimeout = 0 @@ -1258,9 +1216,9 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { } } - var dumps []*dumper - if t1 := cs.cc.t.t1; t1 != nil { - dumps = getDumpers(req.Context(), t1.dump) + var dumps []*dump.Dumper + if t := cs.cc.t; t != nil { + dumps = dump.GetDumpers(req.Context(), t.Dump()) } // Past this point (where we send request headers), it is possible for @@ -1273,9 +1231,9 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { return err } - bodyDumps := []*dumper{} + bodyDumps := []*dump.Dumper{} for _, dump := range dumps { - if dump.RequestBody { + if dump.RequestBody() { bodyDumps = append(bodyDumps, dump) } } @@ -1285,7 +1243,7 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { cs.sentEndStream = true } else { if continueTimeout != 0 { - http2traceWait100Continue(cs.trace) + traceWait100Continue(cs.trace) timer := time.NewTimer(continueTimeout) select { case <-timer.C: @@ -1297,28 +1255,28 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { case <-ctx.Done(): err = ctx.Err() case <-cs.reqCancel: - err = errRequestCanceled + err = common.ErrRequestCanceled } timer.Stop() if err != nil { - http2traceWroteRequest(cs.trace, err) + traceWroteRequest(cs.trace, err) return err } } if err = cs.writeRequestBody(req, bodyDumps); err != nil { if err != errStopReqBodyWrite { - http2traceWroteRequest(cs.trace, err) + traceWroteRequest(cs.trace, err) return err } } else { cs.sentEndStream = true for _, dump := range bodyDumps { - dump.dump([]byte("\r\n\r\n")) + dump.Dump([]byte("\r\n\r\n")) } } } - http2traceWroteRequest(cs.trace, err) + traceWroteRequest(cs.trace, err) var respHeaderTimer <-chan time.Time var respHeaderRecv chan struct{} @@ -1345,12 +1303,12 @@ func (cs *http2clientStream) writeRequest(req *http.Request) (err error) { case <-ctx.Done(): return ctx.Err() case <-cs.reqCancel: - return errRequestCanceled + return common.ErrRequestCanceled } } } -func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dumps []*dumper) error { +func (cs *clientStream) encodeAndWriteHeaders(req *http.Request, dumps []*dump.Dumper) error { cc := cs.cc ctx := cs.ctx @@ -1364,7 +1322,7 @@ func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dumps []*d case <-ctx.Done(): return ctx.Err() case <-cs.reqCancel: - return errRequestCanceled + return common.ErrRequestCanceled default: } @@ -1373,12 +1331,12 @@ func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dumps []*d // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is // sent by writeRequestBody below, along with any Trailers, // again in form HEADERS{1}, CONTINUATION{0,}) - trailers, err := http2commaSeparatedTrailers(req) + trailers, err := commaSeparatedTrailers(req) if err != nil { return err } hasTrailers := trailers != "" - contentLen := http2actualContentLength(req) + contentLen := actualContentLength(req) hasBody := contentLen != 0 hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen, dumps) if err != nil { @@ -1389,7 +1347,7 @@ func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dumps []*d endStream := !hasBody && !hasTrailers cs.sentHeaders = true err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) - http2traceWroteHeaders(cs.trace) + traceWroteHeaders(cs.trace) return err } @@ -1397,7 +1355,7 @@ func (cs *http2clientStream) encodeAndWriteHeaders(req *http.Request, dumps []*d // // If err (the result of writeRequest) is non-nil and the stream is not closed, // cleanupWriteRequest will send a reset to the peer. -func (cs *http2clientStream) cleanupWriteRequest(err error) { +func (cs *clientStream) cleanupWriteRequest(err error) { cc := cs.cc if cs.ID == 0 { @@ -1430,20 +1388,20 @@ func (cs *http2clientStream) cleanupWriteRequest(err error) { if err != nil { cs.abortStream(err) // possibly redundant, but harmless if cs.sentHeaders { - if se, ok := err.(http2StreamError); ok { + if se, ok := err.(StreamError); ok { if se.Cause != errFromPeer { cc.writeStreamReset(cs.ID, se.Code, err) } } else { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) + cc.writeStreamReset(cs.ID, ErrCodeCancel, err) } } cs.bufPipe.CloseWithError(err) // no-op if already closed } else { if cs.sentHeaders && !cs.sentEndStream { - cc.writeStreamReset(cs.ID, http2ErrCodeNo, nil) + cc.writeStreamReset(cs.ID, ErrCodeNo, nil) } - cs.bufPipe.CloseWithError(errRequestCanceled) + cs.bufPipe.CloseWithError(common.ErrRequestCanceled) } if cs.ID != 0 { cc.forgetStreamID(cs.ID) @@ -1461,7 +1419,7 @@ func (cs *http2clientStream) cleanupWriteRequest(err error) { // awaitOpenSlotForStream waits until len(streams) < maxConcurrentStreams. // Must hold cc.mu. -func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) error { +func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error { for { cc.lastActive = time.Now() if cc.closed || !cc.canTakeNewRequestLocked() { @@ -1483,7 +1441,7 @@ func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) e } // requires cc.wmu be held -func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize int, hdrs []byte) error { +func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize int, hdrs []byte) error { first := true // first frame written (HEADERS is first, then CONTINUATION) for len(hdrs) > 0 && cc.werr == nil { chunk := hdrs @@ -1493,7 +1451,7 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFram hdrs = hdrs[len(chunk):] endHeaders := len(hdrs) == 0 if first { - cc.fr.WriteHeaders(http2HeadersFrameParam{ + cc.fr.WriteHeaders(HeadersFrameParam{ StreamID: streamID, BlockFragment: chunk, EndStream: endStream, @@ -1524,7 +1482,7 @@ var ( // // It returns max(1, min(peer's advertised max frame size, // Request.ContentLength+1, 512KB)). -func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int { +func (cs *clientStream) frameScratchBufferLen(maxFrameSize int) int { const max = 512 << 10 n := int64(maxFrameSize) if n > max { @@ -1543,9 +1501,9 @@ func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int { return int(n) // doesn't truncate; max is 512K } -var http2bufPool sync.Pool // of *[]byte +var bufPool sync.Pool // of *[]byte -func (cs *http2clientStream) writeRequestBody(req *http.Request, dumps []*dumper) (err error) { +func (cs *clientStream) writeRequestBody(req *http.Request, dumps []*dump.Dumper) (err error) { cc := cs.cc body := cs.reqBody sentEnd := false // whether we sent the final DATA frame w/ END_STREAM @@ -1561,19 +1519,19 @@ func (cs *http2clientStream) writeRequestBody(req *http.Request, dumps []*dumper // Scratch buffer for reading into & writing from. scratchLen := cs.frameScratchBufferLen(maxFrameSize) var buf []byte - if bp, ok := http2bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen { - defer http2bufPool.Put(bp) + if bp, ok := bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen { + defer bufPool.Put(bp) buf = *bp } else { buf = make([]byte, scratchLen) - defer http2bufPool.Put(&buf) + defer bufPool.Put(&buf) } writeData := cc.fr.WriteData if len(dumps) > 0 { writeData = func(streamID uint32, endStream bool, data []byte) error { for _, dump := range dumps { - dump.dump(data) + dump.Dump(data) } return cc.fr.WriteData(streamID, endStream, data) } @@ -1690,7 +1648,7 @@ func (cs *http2clientStream) writeRequestBody(req *http.Request, dumps []*dumper // control tokens from the server. // It returns either the non-zero number of tokens taken or an error // if the stream is dead. -func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) { +func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) { cc := cs.cc ctx := cs.ctx cc.mu.Lock() @@ -1708,7 +1666,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er case <-ctx.Done(): return 0, ctx.Err() case <-cs.reqCancel: - return 0, errRequestCanceled + return 0, common.ErrRequestCanceled default: } if a := cs.flow.available(); a > 0 { @@ -1730,7 +1688,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er var errNilRequestURL = errors.New("http2: Request.URI is nil") // requires cc.wmu be held. -func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64, dumps []*dumper) ([]byte, error) { +func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64, dumps []*dump.Dumper) ([]byte, error) { cc.hbuf.Reset() if req.URL == nil { return nil, errNilRequestURL @@ -1748,10 +1706,10 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, var path string if req.Method != "CONNECT" { path = req.URL.RequestURI() - if !http2validPseudoPath(path) { + if !validPseudoPath(path) { orig := path path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) - if !http2validPseudoPath(path) { + if !validPseudoPath(path) { if req.URL.Opaque != "" { return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) } @@ -1853,14 +1811,14 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, f(k, v) } } - if http2shouldSendReqContentLength(req.Method, contentLength) { + if shouldSendReqContentLength(req.Method, contentLength) { f("content-length", strconv.FormatInt(contentLength, 10)) } if addGzipHeader { f("accept-encoding", "gzip") } if !didUA { - f("user-agent", hdrUserAgentValue) + f("user-agent", header.DefaultUserAgent) } } @@ -1879,20 +1837,20 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, } trace := httptrace.ContextClientTrace(req.Context()) - traceHeaders := http2traceHasWroteHeaderField(trace) + traceHeaders := traceHasWroteHeaderField(trace) writeHeader := cc.writeHeader - headerDumps := []*dumper{} + headerDumps := []*dump.Dumper{} if len(dumps) > 0 { for _, dump := range dumps { - if dump.RequestHeader { + if dump.RequestHeader() { headerDumps = append(headerDumps, dump) } } if len(headerDumps) > 0 { writeHeader = func(name, value string) { for _, dump := range headerDumps { - dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) } cc.writeHeader(name, value) } @@ -1909,12 +1867,12 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, } writeHeader(name, value) if traceHeaders { - http2traceWroteHeaderField(trace, name, value) + traceWroteHeaderField(trace, name, value) } }) for _, dump := range headerDumps { - dump.dump([]byte("\r\n")) + dump.Dump([]byte("\r\n")) } return cc.hbuf.Bytes(), nil @@ -1925,7 +1883,7 @@ func (cc *http2ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, // transferWriter.shouldSendContentLength. // The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). // -1 means unknown. -func http2shouldSendReqContentLength(method string, contentLength int64) bool { +func shouldSendReqContentLength(method string, contentLength int64) bool { if contentLength > 0 { return true } @@ -1943,7 +1901,7 @@ func http2shouldSendReqContentLength(method string, contentLength int64) bool { } // requires cc.wmu be held. -func (cc *http2ClientConn) encodeTrailers(trailer http.Header, dumps []*dumper) ([]byte, error) { +func (cc *ClientConn) encodeTrailers(trailer http.Header, dumps []*dump.Dumper) ([]byte, error) { cc.hbuf.Reset() hlSize := uint64(0) @@ -1961,7 +1919,7 @@ func (cc *http2ClientConn) encodeTrailers(trailer http.Header, dumps []*dumper) if len(dumps) > 0 { writeHeader = func(name, value string) { for _, dump := range dumps { - dump.dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) } cc.writeHeader(name, value) } @@ -1983,24 +1941,24 @@ func (cc *http2ClientConn) encodeTrailers(trailer http.Header, dumps []*dumper) return cc.hbuf.Bytes(), nil } -func (cc *http2ClientConn) writeHeader(name, value string) { - if http2VerboseLogs { +func (cc *ClientConn) writeHeader(name, value string) { + if VerboseLogs { log.Printf("http2: Transport encoding header %q = %q", name, value) } cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) } -type http2resAndError struct { - _ http2incomparable +type resAndError struct { + _ incomparable res *http.Response err error } // requires cc.mu be held. -func (cc *http2ClientConn) addStreamLocked(cs *http2clientStream) { +func (cc *ClientConn) addStreamLocked(cs *clientStream) { cs.flow.add(int32(cc.initialWindowSize)) cs.flow.setConnFlow(&cc.flow) - cs.inflow.add(http2transportDefaultStreamFlow) + cs.inflow.add(transportDefaultStreamFlow) cs.inflow.setConnFlow(&cc.inflow) cs.ID = cc.nextStreamID cc.nextStreamID += 2 @@ -2010,7 +1968,7 @@ func (cc *http2ClientConn) addStreamLocked(cs *http2clientStream) { } } -func (cc *http2ClientConn) forgetStreamID(id uint32) { +func (cc *ClientConn) forgetStreamID(id uint32) { cc.mu.Lock() slen := len(cc.streams) delete(cc.streams, id) @@ -2026,9 +1984,9 @@ func (cc *http2ClientConn) forgetStreamID(id uint32) { // wake up RoundTrip if there is a pending request. cc.cond.Broadcast() - closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() + closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.DisableKeepAlives() if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { - if http2VerboseLogs { + if VerboseLogs { cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) } cc.closed = true @@ -2039,37 +1997,37 @@ func (cc *http2ClientConn) forgetStreamID(id uint32) { } // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. -type http2clientConnReadLoop struct { - _ http2incomparable - cc *http2ClientConn +type clientConnReadLoop struct { + _ incomparable + cc *ClientConn } // readLoop runs in its own goroutine and reads and dispatches frames. -func (cc *http2ClientConn) readLoop() { - rl := &http2clientConnReadLoop{cc: cc} +func (cc *ClientConn) readLoop() { + rl := &clientConnReadLoop{cc: cc} defer rl.cleanup() cc.readerErr = rl.run() - if ce, ok := cc.readerErr.(http2ConnectionError); ok { + if ce, ok := cc.readerErr.(ConnectionError); ok { cc.wmu.Lock() - cc.fr.WriteGoAway(0, http2ErrCode(ce), nil) + cc.fr.WriteGoAway(0, ErrCode(ce), nil) cc.wmu.Unlock() } } // GoAwayError is returned by the Transport when the server closes the // TCP connection after sending a GOAWAY frame. -type http2GoAwayError struct { +type GoAwayError struct { LastStreamID uint32 - ErrCode http2ErrCode + ErrCode ErrCode DebugData string } -func (e http2GoAwayError) Error() string { +func (e GoAwayError) Error() string { return fmt.Sprintf("http2: server sent GOAWAY and closed the connection; LastStreamID=%v, ErrCode=%v, debug=%q", e.LastStreamID, e.ErrCode, e.DebugData) } -func http2isEOFOrNetReadError(err error) bool { +func isEOFOrNetReadError(err error) bool { if err == io.EOF { return true } @@ -2077,7 +2035,7 @@ func http2isEOFOrNetReadError(err error) bool { return ok && ne.Op == "read" } -func (rl *http2clientConnReadLoop) cleanup() { +func (rl *clientConnReadLoop) cleanup() { cc := rl.cc cc.t.connPool().MarkDead(cc) defer cc.closeConn() @@ -2092,8 +2050,8 @@ func (rl *http2clientConnReadLoop) cleanup() { // gotten a response yet. err := cc.readerErr cc.mu.Lock() - if cc.goAway != nil && http2isEOFOrNetReadError(err) { - err = http2GoAwayError{ + if cc.goAway != nil && isEOFOrNetReadError(err) { + err = GoAwayError{ LastStreamID: cc.goAway.LastStreamID, ErrCode: cc.goAway.ErrCode, DebugData: cc.goAwayDebug, @@ -2117,13 +2075,13 @@ func (rl *http2clientConnReadLoop) cleanup() { // countReadFrameError calls Transport.CountError with a string // representing err. -func (cc *http2ClientConn) countReadFrameError(err error) { +func (cc *ClientConn) countReadFrameError(err error) { f := cc.t.CountError if f == nil || err == nil { return } - if ce, ok := err.(http2ConnectionError); ok { - errCode := http2ErrCode(ce) + if ce, ok := err.(ConnectionError); ok { + errCode := ErrCode(ce) f(fmt.Sprintf("read_frame_conn_error_%s", errCode.stringToken())) return } @@ -2142,7 +2100,7 @@ func (cc *http2ClientConn) countReadFrameError(err error) { f("read_frame_other") } -func (rl *http2clientConnReadLoop) run() error { +func (rl *clientConnReadLoop) run() error { cc := rl.cc gotSettings := false readIdleTimeout := cc.t.ReadIdleTimeout @@ -2159,7 +2117,7 @@ func (rl *http2clientConnReadLoop) run() error { if err != nil { cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) } - if se, ok := err.(http2StreamError); ok { + if se, ok := err.(StreamError); ok { if cs := rl.streamByID(se.StreamID); cs != nil { if se.Cause == nil { se.Cause = cc.fr.errDetail @@ -2171,47 +2129,47 @@ func (rl *http2clientConnReadLoop) run() error { cc.countReadFrameError(err) return err } - if http2VerboseLogs { - cc.vlogf("http2: Transport received %s", http2summarizeFrame(f)) + if VerboseLogs { + cc.vlogf("http2: Transport received %s", summarizeFrame(f)) } if !gotSettings { - if _, ok := f.(*http2SettingsFrame); !ok { + if _, ok := f.(*SettingsFrame); !ok { cc.logf("protocol error: received %T before a SETTINGS frame", f) - return http2ConnectionError(http2ErrCodeProtocol) + return ConnectionError(ErrCodeProtocol) } gotSettings = true } switch f := f.(type) { - case *http2MetaHeadersFrame: + case *MetaHeadersFrame: err = rl.processHeaders(f) - case *http2DataFrame: + case *DataFrame: err = rl.processData(f) - case *http2GoAwayFrame: + case *GoAwayFrame: err = rl.processGoAway(f) - case *http2RSTStreamFrame: + case *RSTStreamFrame: err = rl.processResetStream(f) - case *http2SettingsFrame: + case *SettingsFrame: err = rl.processSettings(f) - case *http2PushPromiseFrame: + case *PushPromiseFrame: err = rl.processPushPromise(f) - case *http2WindowUpdateFrame: + case *WindowUpdateFrame: err = rl.processWindowUpdate(f) - case *http2PingFrame: + case *PingFrame: err = rl.processPing(f) default: cc.logf("Transport: unhandled response frame type %T", f) } if err != nil { - if http2VerboseLogs { - cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, http2summarizeFrame(f), err) + if VerboseLogs { + cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, summarizeFrame(f), err) } return err } } } -func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error { +func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error { cs := rl.streamByID(f.StreamID) if cs == nil { // We'd get here if we canceled a request while the @@ -2220,9 +2178,9 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro return nil } if cs.readClosed { - rl.endStreamError(cs, http2StreamError{ + rl.endStreamError(cs, StreamError{ StreamID: f.StreamID, - Code: http2ErrCodeProtocol, + Code: ErrCodeProtocol, Cause: errors.New("protocol error: headers after END_STREAM"), }) return nil @@ -2233,7 +2191,7 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro // when we first read the 9 byte header, not waiting // until all the HEADERS+CONTINUATION frames have been // merged. This works for now. - http2traceFirstResponseByte(cs.trace) + traceFirstResponseByte(cs.trace) } cs.firstByte = true } @@ -2245,13 +2203,13 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro res, err := rl.handleResponse(cs, f) if err != nil { - if _, ok := err.(http2ConnectionError); ok { + if _, ok := err.(ConnectionError); ok { return err } // Any other error type is a stream error. - rl.endStreamError(cs, http2StreamError{ + rl.endStreamError(cs, StreamError{ StreamID: f.StreamID, - Code: http2ErrCodeProtocol, + Code: ErrCodeProtocol, Cause: err, }) return nil // return nil from process* funcs to keep conn alive @@ -2271,7 +2229,7 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro // foreachHeaderElement splits v according to the "#rule" construction // in RFC 7230 section 7 and calls fn for each non-empty element. -func http2foreachHeaderElement(v string, fn func(string)) { +func foreachHeaderElement(v string, fn func(string)) { v = textproto.TrimString(v) if v == "" { return @@ -2293,7 +2251,7 @@ func http2foreachHeaderElement(v string, fn func(string)) { // // As a special case, handleResponse may return (nil, nil) to skip the // frame (currently only used for 1xx responses). -func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http2MetaHeadersFrame) (*http.Response, error) { +func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFrame) (*http.Response, error) { if f.Truncated { return nil, errResponseHeaderListSize } @@ -2325,7 +2283,7 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http t = make(http.Header) res.Trailer = t } - http2foreachHeaderElement(hf.Value, func(v string) { + foreachHeaderElement(hf.Value, func(v string) { t[http.CanonicalHeaderKey(v)] = nil }) } else { @@ -2359,7 +2317,7 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http } } if statusCode == 100 { - http2traceGot100Continue(cs.trace) + traceGot100Continue(cs.trace) select { case cs.on100 <- struct{}{}: default: @@ -2385,49 +2343,49 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http } if cs.isHead { - res.Body = http2noBody + res.Body = noBody return res, nil } if f.StreamEnded() { if res.ContentLength > 0 { - res.Body = http2missingBody{} + res.Body = missingBody{} } else { - res.Body = http2noBody + res.Body = noBody } return res, nil } - cs.bufPipe.setBuffer(&http2dataBuffer{expected: res.ContentLength}) + cs.bufPipe.setBuffer(&dataBuffer{expected: res.ContentLength}) cs.bytesRemain = res.ContentLength - res.Body = http2transportResponseBody{cs} + res.Body = transportResponseBody{cs} if cs.requestedGzip && ascii.EqualFold(res.Header.Get("Content-Encoding"), "gzip") { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") res.ContentLength = -1 - res.Body = &http2gzipReader{body: res.Body} + res.Body = &GzipReader{Body: res.Body} res.Uncompressed = true } return res, nil } -func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *http2MetaHeadersFrame) error { +func (rl *clientConnReadLoop) processTrailers(cs *clientStream, f *MetaHeadersFrame) error { if cs.pastTrailers { // Too many HEADERS frames for this stream. - return http2ConnectionError(http2ErrCodeProtocol) + return ConnectionError(ErrCodeProtocol) } cs.pastTrailers = true if !f.StreamEnded() { // We expect that any headers for trailers also // has END_STREAM. - return http2ConnectionError(http2ErrCodeProtocol) + return ConnectionError(ErrCodeProtocol) } if len(f.PseudoFields()) > 0 { // No pseudo header fields are defined for trailers. // TODO: ConnectionError might be overly harsh? Check. - return http2ConnectionError(http2ErrCodeProtocol) + return ConnectionError(ErrCodeProtocol) } trailer := make(http.Header) @@ -2443,11 +2401,11 @@ func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *htt // transportResponseBody is the concrete type of Transport.RoundTrip's // Response.Body. It is an io.ReadCloser. -type http2transportResponseBody struct { - cs *http2clientStream +type transportResponseBody struct { + cs *clientStream } -func (b http2transportResponseBody) Read(p []byte) (n int, err error) { +func (b transportResponseBody) Read(p []byte) (n int, err error) { cs := b.cs cc := cs.cc @@ -2480,8 +2438,8 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { cc.mu.Lock() var connAdd, streamAdd int32 // Check the conn-level first, before the stream-level. - if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 { - connAdd = http2transportDefaultConnFlow - v + if v := cc.inflow.available(); v < transportDefaultConnFlow/2 { + connAdd = transportDefaultConnFlow - v cc.inflow.add(connAdd) } if err == nil { // No need to refresh if the stream is over or failed. @@ -2489,8 +2447,8 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { // consumed by the client) when computing flow control for this // stream. v := int(cs.inflow.available()) + cs.bufPipe.Len() - if v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { - streamAdd = int32(http2transportDefaultStreamFlow - v) + if v < transportDefaultStreamFlow-transportDefaultStreamMinRefresh { + streamAdd = int32(transportDefaultStreamFlow - v) cs.inflow.add(streamAdd) } } @@ -2500,10 +2458,10 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { cc.wmu.Lock() defer cc.wmu.Unlock() if connAdd != 0 { - cc.fr.WriteWindowUpdate(0, http2mustUint31(connAdd)) + cc.fr.WriteWindowUpdate(0, mustUint31(connAdd)) } if streamAdd != 0 { - cc.fr.WriteWindowUpdate(cs.ID, http2mustUint31(streamAdd)) + cc.fr.WriteWindowUpdate(cs.ID, mustUint31(streamAdd)) } cc.bw.Flush() } @@ -2512,7 +2470,7 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { var errClosedResponseBody = errors.New("http2: response body closed") -func (b http2transportResponseBody) Close() error { +func (b transportResponseBody) Close() error { cs := b.cs cc := cs.cc @@ -2547,12 +2505,12 @@ func (b http2transportResponseBody) Close() error { // Don't treat this as an error. return nil case <-cs.reqCancel: - return errRequestCanceled + return common.ErrRequestCanceled } return nil } -func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { +func (rl *clientConnReadLoop) processData(f *DataFrame) error { cc := rl.cc cs := rl.streamByID(f.StreamID) data := f.Data() @@ -2563,7 +2521,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { if f.StreamID >= neverSent { // We never asked for this. cc.logf("http2: Transport received unsolicited DATA frame; closing connection") - return http2ConnectionError(http2ErrCodeProtocol) + return ConnectionError(ErrCodeProtocol) } // We probably did ask for this, but canceled. Just ignore it. // TODO: be stricter here? only silently ignore things which @@ -2585,26 +2543,26 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { } if cs.readClosed { cc.logf("protocol error: received DATA after END_STREAM") - rl.endStreamError(cs, http2StreamError{ + rl.endStreamError(cs, StreamError{ StreamID: f.StreamID, - Code: http2ErrCodeProtocol, + Code: ErrCodeProtocol, }) return nil } if !cs.firstByte { cc.logf("protocol error: received DATA before a HEADERS frame") - rl.endStreamError(cs, http2StreamError{ + rl.endStreamError(cs, StreamError{ StreamID: f.StreamID, - Code: http2ErrCodeProtocol, + Code: ErrCodeProtocol, }) return nil } if f.Length > 0 { if cs.isHead && len(data) > 0 { cc.logf("protocol error: received DATA on a HEAD request") - rl.endStreamError(cs, http2StreamError{ + rl.endStreamError(cs, StreamError{ StreamID: f.StreamID, - Code: http2ErrCodeProtocol, + Code: ErrCodeProtocol, }) return nil } @@ -2614,7 +2572,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { cs.inflow.take(int32(f.Length)) } else { cc.mu.Unlock() - return http2ConnectionError(http2ErrCodeFlowControl) + return ConnectionError(ErrCodeFlowControl) } // Return any padded flow control now, since we won't // refund it later on body reads. @@ -2664,7 +2622,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { return nil } -func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { +func (rl *clientConnReadLoop) endStream(cs *clientStream) { // TODO: check that any declared content-length matches, like // server.go's (*stream).endStream method. if !cs.readClosed { @@ -2680,12 +2638,12 @@ func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { } } -func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) { +func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) { cs.readAborted = true cs.abortStream(err) } -func (rl *http2clientConnReadLoop) streamByID(id uint32) *http2clientStream { +func (rl *clientConnReadLoop) streamByID(id uint32) *clientStream { rl.cc.mu.Lock() defer rl.cc.mu.Unlock() cs := rl.cc.streams[id] @@ -2695,7 +2653,7 @@ func (rl *http2clientConnReadLoop) streamByID(id uint32) *http2clientStream { return nil } -func (cs *http2clientStream) copyTrailers() { +func (cs *clientStream) copyTrailers() { for k, vv := range cs.trailer { t := cs.resTrailer if *t == nil { @@ -2705,7 +2663,7 @@ func (cs *http2clientStream) copyTrailers() { } } -func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { +func (rl *clientConnReadLoop) processGoAway(f *GoAwayFrame) error { cc := rl.cc cc.t.connPool().MarkDead(cc) if f.ErrCode != 0 { @@ -2720,7 +2678,7 @@ func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { return nil } -func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error { +func (rl *clientConnReadLoop) processSettings(f *SettingsFrame) error { cc := rl.cc // Locking both mu and wmu here allows frame encoding to read settings with only wmu held. // Acquiring wmu when f.IsAck() is unnecessary, but convenient and mostly harmless. @@ -2737,7 +2695,7 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error return nil } -func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) error { +func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { cc := rl.cc cc.mu.Lock() defer cc.mu.Unlock() @@ -2747,26 +2705,26 @@ func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) cc.wantSettingsAck = false return nil } - return http2ConnectionError(http2ErrCodeProtocol) + return ConnectionError(ErrCodeProtocol) } var seenMaxConcurrentStreams bool - err := f.ForeachSetting(func(s http2Setting) error { + err := f.ForeachSetting(func(s Setting) error { switch s.ID { - case http2SettingMaxFrameSize: + case SettingMaxFrameSize: cc.maxFrameSize = s.Val - case http2SettingMaxConcurrentStreams: + case SettingMaxConcurrentStreams: cc.maxConcurrentStreams = s.Val seenMaxConcurrentStreams = true - case http2SettingMaxHeaderListSize: + case SettingMaxHeaderListSize: cc.peerMaxHeaderListSize = uint64(s.Val) - case http2SettingInitialWindowSize: + case SettingInitialWindowSize: // Values above the maximum flow-control // window size of 2^31-1 MUST be treated as a // connection error (Section 5.4.1) of type // FLOW_CONTROL_ERROR. if s.Val > math.MaxInt32 { - return http2ConnectionError(http2ErrCodeFlowControl) + return ConnectionError(ErrCodeFlowControl) } // Adjust flow control of currently-open @@ -2795,7 +2753,7 @@ func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) // didn't contain a MAX_CONCURRENT_STREAMS field so // increase the number of concurrent streams this // connection can establish to our default. - cc.maxConcurrentStreams = http2defaultMaxConcurrentStreams + cc.maxConcurrentStreams = defaultMaxConcurrentStreams } cc.seenSettings = true } @@ -2803,7 +2761,7 @@ func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) return nil } -func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error { +func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { cc := rl.cc cs := rl.streamByID(f.StreamID) if f.StreamID != 0 && cs == nil { @@ -2818,21 +2776,21 @@ func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame fl = &cs.flow } if !fl.add(int32(f.Increment)) { - return http2ConnectionError(http2ErrCodeFlowControl) + return ConnectionError(ErrCodeFlowControl) } cc.cond.Broadcast() return nil } -func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error { +func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error { cs := rl.streamByID(f.StreamID) if cs == nil { // TODO: return error if server tries to RST_STREAM an idle stream return nil } - serr := http2streamError(cs.ID, f.ErrCode) + serr := streamError(cs.ID, f.ErrCode) serr.Cause = errFromPeer - if f.ErrCode == http2ErrCodeProtocol { + if f.ErrCode == ErrCodeProtocol { rl.cc.SetDoNotReuse() } if fn := cs.cc.t.CountError; fn != nil { @@ -2845,7 +2803,7 @@ func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) er } // Ping sends a PING frame to the server and waits for the ack. -func (cc *http2ClientConn) Ping(ctx context.Context) error { +func (cc *ClientConn) Ping(ctx context.Context) error { c := make(chan struct{}) // Generate a random payload var p [8]byte @@ -2888,7 +2846,7 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error { } } -func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { +func (rl *clientConnReadLoop) processPing(f *PingFrame) error { if f.IsAck() { cc := rl.cc cc.mu.Lock() @@ -2909,7 +2867,7 @@ func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { return cc.bw.Flush() } -func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) error { +func (rl *clientConnReadLoop) processPushPromise(f *PushPromiseFrame) error { // We told the peer we don't want them. // Spec says: // "PUSH_PROMISE MUST NOT be sent if the SETTINGS_ENABLE_PUSH @@ -2917,10 +2875,10 @@ func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) // has set this setting and has received acknowledgement MUST // treat the receipt of a PUSH_PROMISE frame as a connection // error (Section 5.4.1) of type PROTOCOL_ERROR." - return http2ConnectionError(http2ErrCodeProtocol) + return ConnectionError(ErrCodeProtocol) } -func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) { +func (cc *ClientConn) writeStreamReset(streamID uint32, code ErrCode, err error) { // TODO: map err to more interesting error codes, once the // HTTP community comes up with some. But currently for // RST_STREAM there's no equivalent to GOAWAY frame's debug @@ -2936,38 +2894,38 @@ var ( errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit") ) -func (cc *http2ClientConn) logf(format string, args ...interface{}) { +func (cc *ClientConn) logf(format string, args ...interface{}) { cc.t.logf(format, args...) } -func (cc *http2ClientConn) vlogf(format string, args ...interface{}) { +func (cc *ClientConn) vlogf(format string, args ...interface{}) { cc.t.vlogf(format, args...) } -func (t *http2Transport) vlogf(format string, args ...interface{}) { - if http2VerboseLogs { +func (t *Transport) vlogf(format string, args ...interface{}) { + if VerboseLogs { t.logf(format, args...) } } -func (t *http2Transport) logf(format string, args ...interface{}) { +func (t *Transport) logf(format string, args ...interface{}) { log.Printf(format, args...) } -var http2noBody io.ReadCloser = noBodyReader{} +var noBody io.ReadCloser = noBodyReader{} type noBodyReader struct{} func (noBodyReader) Close() error { return nil } func (noBodyReader) Read([]byte) (int, error) { return 0, io.EOF } -type http2missingBody struct{} +type missingBody struct{} -func (http2missingBody) Close() error { return nil } +func (missingBody) Close() error { return nil } -func (http2missingBody) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF } +func (missingBody) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF } -func http2strSliceContains(ss []string, s string) bool { +func strSliceContains(ss []string, s string) bool { for _, v := range ss { if v == s { return true @@ -2976,29 +2934,29 @@ func http2strSliceContains(ss []string, s string) bool { return false } -type http2erringRoundTripper struct{ err error } +type erringRoundTripper struct{ err error } -func (rt http2erringRoundTripper) RoundTripErr() error { return rt.err } +func (rt erringRoundTripper) RoundTripErr() error { return rt.err } -func (rt http2erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { +func (rt erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { return nil, rt.err } -// gzipReader wraps a response body so it can lazily +// GzipReader wraps a response body so it can lazily // call gzip.NewReader on the first call to Read -type http2gzipReader struct { - _ http2incomparable - body io.ReadCloser // underlying Response.Body +type GzipReader struct { + _ incomparable + Body io.ReadCloser // underlying Response.Body zr *gzip.Reader // lazily-initialized gzip reader zerr error // sticky error } -func (gz *http2gzipReader) Read(p []byte) (n int, err error) { +func (gz *GzipReader) Read(p []byte) (n int, err error) { if gz.zerr != nil { return 0, gz.zerr } if gz.zr == nil { - gz.zr, err = gzip.NewReader(gz.body) + gz.zr, err = gzip.NewReader(gz.Body) if err != nil { gz.zerr = err return 0, err @@ -3007,19 +2965,19 @@ func (gz *http2gzipReader) Read(p []byte) (n int, err error) { return gz.zr.Read(p) } -func (gz *http2gzipReader) Close() error { - return gz.body.Close() +func (gz *GzipReader) Close() error { + return gz.Body.Close() } // isConnectionCloseRequest reports whether req should use its own // connection for a single request and then close the connection. -func http2isConnectionCloseRequest(req *http.Request) bool { +func isConnectionCloseRequest(req *http.Request) bool { return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close") } // registerHTTPSProtocol calls Transport.RegisterProtocol but // converting panics into errors. -func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err error) { +func registerHTTPSProtocol(t transport.Interface, rt noDialH2RoundTripper) (err error) { defer func() { if e := recover(); e != nil { err = fmt.Errorf("%v", e) @@ -3033,19 +2991,12 @@ func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err // if there's already has a cached connection to the host. // (The field is exported so it can be accessed via reflect from net/http; tested // by TestNoDialH2RoundTripperType) -type http2noDialH2RoundTripper struct{ *http2Transport } +type noDialH2RoundTripper struct{ *Transport } -func (rt http2noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - res, err := rt.http2Transport.RoundTrip(req) - if http2isNoCachedConnError(err) { +func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + res, err := rt.Transport.RoundTrip(req) + if IsNoCachedConnError(err) { return nil, http.ErrSkipAltProtocol } return res, err } - -func (t *http2Transport) idleConnTimeout() time.Duration { - if t.t1 != nil { - return t.t1.IdleConnTimeout - } - return 0 -} diff --git a/h2_transport_go117_test.go b/internal/http2/transport_go117_test.go similarity index 81% rename from h2_transport_go117_test.go rename to internal/http2/transport_go117_test.go index b27237b8..c39dc7a1 100644 --- a/h2_transport_go117_test.go +++ b/internal/http2/transport_go117_test.go @@ -5,12 +5,13 @@ //go:build go1.17 // +build go1.17 -package req +package http2 import ( "context" "crypto/tls" "errors" + "github.com/imroc/req/v3/internal/tests" "net/http" "net/http/httptest" @@ -33,15 +34,17 @@ func TestTransportDialTLSContexth2(t *testing.T) { ) defer ts.Close() tr := &Transport{ - TLSClientConfig: &tls.Config{ - GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { - // Tests that the context provided to `req` is - // passed into this function. - close(blockCh) - <-cri.Context().Done() - return nil, cri.Context().Err() + Interface: tests.Transport{ + TLSClientConfigValue: &tls.Config{ + GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + // Tests that the context provided to `req` is + // passed into this function. + close(blockCh) + <-cri.Context().Done() + return nil, cri.Context().Err() + }, + InsecureSkipVerify: true, }, - InsecureSkipVerify: true, }, } defer tr.CloseIdleConnections() @@ -97,19 +100,21 @@ func TestDialRaceResumesDial(t *testing.T) { ) defer ts.Close() tr := &Transport{ - TLSClientConfig: &tls.Config{ - GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { - select { - case <-blockCh: - // If we already errored, return without error. - return &tls.Certificate{}, nil - default: - } - close(blockCh) - <-cri.Context().Done() - return nil, cri.Context().Err() + Interface: tests.Transport{ + TLSClientConfigValue: &tls.Config{ + GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + select { + case <-blockCh: + // If we already errored, return without error. + return &tls.Certificate{}, nil + default: + } + close(blockCh) + <-cri.Context().Done() + return nil, cri.Context().Err() + }, + InsecureSkipVerify: true, }, - InsecureSkipVerify: true, }, } defer tr.CloseIdleConnections() diff --git a/h2_transport_test.go b/internal/http2/transport_test.go similarity index 87% rename from h2_transport_test.go rename to internal/http2/transport_test.go index 1b143d8a..5fce877d 100644 --- a/h2_transport_test.go +++ b/internal/http2/transport_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package req +package http2 import ( "bufio" @@ -13,6 +13,8 @@ import ( "errors" "flag" "fmt" + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/tests" "io" "io/ioutil" "log" @@ -58,7 +60,7 @@ func TestTransportExternal(t *testing.T) { t.Skip("skipping external network test") } req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil) - rt := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + rt := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} res, err := rt.RoundTrip(req) if err != nil { t.Fatalf("%v", err) @@ -73,20 +75,20 @@ type fakeTLSConn struct { func (c *fakeTLSConn) ConnectionState() tls.ConnectionState { return tls.ConnectionState{ Version: tls.VersionTLS12, - CipherSuite: http2cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, } } func startH2cServer(t *testing.T) net.Listener { - h2Server := &http2Server{} - l := newLocalListener(t) + h2Server := &Server{} + l := tests.NewLocalListener(t) go func() { conn, err := l.Accept() if err != nil { t.Error(err) return } - h2Server.ServeConn(&fakeTLSConn{conn}, &http2ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil) })}) }() @@ -109,8 +111,8 @@ func TestTransportH2c(t *testing.T) { }, } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - tr := &http2Transport{ - t1: &Transport{}, + tr := &Transport{ + Interface: tests.Transport{}, AllowHTTP: true, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { return net.Dial(network, addr) @@ -142,7 +144,7 @@ func TestTransport(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() u, err := url.Parse(st.ts.URL) @@ -198,9 +200,9 @@ func testTransportReusesConns(t *testing.T, wantSame bool, modReq func(*http.Req t.Logf("conn %v is now state %v", c.RemoteAddr(), st) }) defer st.Close() - tr := &http2Transport{ - t1: &Transport{ - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + TLSClientConfigValue: tlsConfigInsecure, }, } defer tr.CloseIdleConnections() @@ -277,9 +279,9 @@ func testTransportGetGotConnHooks(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{ - t1: &Transport{ - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + TLSClientConfigValue: tlsConfigInsecure, }, } @@ -350,9 +352,9 @@ func TestTransportGroupsPendingDials(t *testing.T) { dialCount int closeCount int ) - tr := &http2Transport{ - t1: &Transport{ - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: &tests.Transport{ + TLSClientConfigValue: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { mu.Lock() @@ -425,7 +427,7 @@ func TestTransportAbortClosesPipes(t *testing.T) { errCh := make(chan error) go func() { defer close(errCh) - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { errCh <- err @@ -468,7 +470,7 @@ func TestTransportPath(t *testing.T) { ) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() const ( path = "/testpath" @@ -535,7 +537,7 @@ func TestActualContentLength(t *testing.T) { }, } for i, tt := range tests { - got := http2actualContentLength(tt.req) + got := actualContentLength(tt.req) if got != tt.want { t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) } @@ -582,7 +584,7 @@ func TestTransportBody(t *testing.T) { defer st.Close() for i, tt := range bodyTests { - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() var body io.Reader = strings.NewReader(tt.body) @@ -638,8 +640,8 @@ func TestTransportDialTLSh2(t *testing.T) { optOnlyServer, ) defer ts.Close() - tr := &http2Transport{ - t1: &Transport{}, + tr := &Transport{ + Interface: tests.Transport{}, DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) { mu.Lock() didDial = true @@ -668,47 +670,6 @@ func TestTransportDialTLSh2(t *testing.T) { } } -func TestConfigureTransport(t *testing.T) { - t1 := &Transport{} - err := http2ConfigureTransport(t1) - if err != nil { - t.Fatal(err) - } - if got := fmt.Sprintf("%#v", t1); !strings.Contains(got, `"h2"`) { - // Laziness, to avoid buildtags. - t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got) - } - wantNextProtos := []string{"h2", "http/1.1"} - if t1.TLSClientConfig == nil { - t.Errorf("nil t1.TLSClientConfig") - } else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, wantNextProtos) { - t.Errorf("TLSClientConfig.NextProtos = %q; want %q", t1.TLSClientConfig.NextProtos, wantNextProtos) - } - if err := http2ConfigureTransport(t1); err == nil { - t.Error("unexpected success on second call to http2ConfigureTransport") - } - - // And does it work? - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, r.Proto) - }, optOnlyServer) - defer st.Close() - - t1.TLSClientConfig.InsecureSkipVerify = true - c := &http.Client{Transport: t1} - res, err := c.Get(st.ts.URL) - if err != nil { - t.Fatal(err) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if got, want := string(slurp), "HTTP/2.0"; got != want { - t.Errorf("body = %q; want %q", got, want) - } -} - type capitalizeReader struct { r io.Reader } @@ -737,9 +698,9 @@ func (fw flushWriter) Write(p []byte) (n int, err error) { type clientTester struct { t *testing.T - tr *http2Transport - sc, cc net.Conn // server and client conn - fr *http2Framer // server's framer + tr *Transport + sc, cc net.Conn // server and client conn + fr *Framer // server's framer client func() error server func() error } @@ -752,9 +713,9 @@ func newClientTester(t *testing.T) *clientTester { ct := &clientTester{ t: t, } - ct.tr = &http2Transport{ - t1: &Transport{ - TLSClientConfig: tlsConfigInsecure, + ct.tr = &Transport{ + Interface: tests.Transport{ + TLSClientConfigValue: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { dialOnce.Lock() @@ -767,7 +728,7 @@ func newClientTester(t *testing.T) *clientTester { }, } - ln := newLocalListener(t) + ln := tests.NewLocalListener(t) cc, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) @@ -780,12 +741,12 @@ func newClientTester(t *testing.T) *clientTester { ln.Close() ct.cc = cc ct.sc = sc - ct.fr = http2NewFramer(sc, sc) + ct.fr = NewFramer(sc, sc) return ct } -func (ct *clientTester) greet(settings ...http2Setting) { - buf := make([]byte, len(http2ClientPreface)) +func (ct *clientTester) greet(settings ...Setting) { + buf := make([]byte, len(ClientPreface)) _, err := io.ReadFull(ct.sc, buf) if err != nil { ct.t.Fatalf("reading client preface: %v", err) @@ -794,7 +755,7 @@ func (ct *clientTester) greet(settings ...http2Setting) { if err != nil { ct.t.Fatalf("Reading client settings frame: %v", err) } - if sf, ok := f.(*http2SettingsFrame); !ok { + if sf, ok := f.(*SettingsFrame); !ok { ct.t.Fatalf("Wanted client settings frame; got %v", f) _ = sf // stash it away? } @@ -806,13 +767,13 @@ func (ct *clientTester) greet(settings ...http2Setting) { } } -func (ct *clientTester) readNonSettingsFrame() (http2Frame, error) { +func (ct *clientTester) readNonSettingsFrame() (Frame, error) { for { f, err := ct.fr.ReadFrame() if err != nil { return nil, err } - if _, ok := f.(*http2SettingsFrame); ok { + if _, ok := f.(*SettingsFrame); ok { continue } return f, nil @@ -849,21 +810,21 @@ func (ct *clientTester) run() { errOnce.Do(ct.cleanup) // clean up if no error } -func (ct *clientTester) readFrame() (http2Frame, error) { +func (ct *clientTester) readFrame() (Frame, error) { return ct.fr.ReadFrame() } -func (ct *clientTester) firstHeaders() (*http2HeadersFrame, error) { +func (ct *clientTester) firstHeaders() (*HeadersFrame, error) { for { f, err := ct.readFrame() if err != nil { return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err) } switch f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame: + case *WindowUpdateFrame, *SettingsFrame: continue } - hf, ok := f.(*http2HeadersFrame) + hf, ok := f.(*HeadersFrame) if !ok { return nil, fmt.Errorf("Got %T; want HeadersFrame", f) } @@ -899,8 +860,8 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) { } defer close(clientDone) - body := &http2pipe{b: new(bytes.Buffer)} - io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) + body := &pipe{b: new(bytes.Buffer)} + io.Copy(body, io.LimitReader(tests.NeverEnding('A'), bodySize/2)) req, err := http.NewRequest("PUT", "https://dummy.tld/", body) if err != nil { return err @@ -912,7 +873,7 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) { if res.StatusCode != status { return fmt.Errorf("status code = %v; want %v", res.StatusCode, status) } - io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) + io.Copy(body, io.LimitReader(tests.NeverEnding('A'), bodySize/2)) body.CloseWithError(io.EOF) slurp, err := ioutil.ReadAll(res.Body) if err != nil { @@ -956,20 +917,20 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) { // println(fmt.Sprintf("server got frame: %v", f)) ended := false switch f := f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame: - case *http2HeadersFrame: + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: if !f.HeadersEnded() { return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) } if f.StreamEnded() { return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f) } - case *http2DataFrame: + case *DataFrame: dataLen := len(f.Data()) if dataLen > 0 { if dataRecv == 0 { enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: false, @@ -996,7 +957,7 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) { if f.StreamEnded() { ended = true } - case *http2RSTStreamFrame: + case *RSTStreamFrame: if status == 200 { return fmt.Errorf("Unexpected client frame %v", f) } @@ -1025,7 +986,7 @@ func TestTransportFullDuplex(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} @@ -1078,7 +1039,7 @@ func TestTransportConnectRequest(t *testing.T) { t.Fatal(err) } - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} @@ -1243,7 +1204,7 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy hbf := buf.Bytes() switch mode { case oneHeader: - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.Header().StreamID, EndHeaders: true, EndStream: endStream, @@ -1253,7 +1214,7 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy if len(hbf) < 2 { panic("too small") } - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.Header().StreamID, EndHeaders: false, EndStream: endStream, @@ -1265,8 +1226,8 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy } } switch f := f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame: - case *http2DataFrame: + case *WindowUpdateFrame, *SettingsFrame: + case *DataFrame: if !f.StreamEnded() { // No need to send flow control tokens. The test request body is tiny. continue @@ -1296,7 +1257,7 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy if endStream { return nil } - case *http2HeadersFrame: + case *HeadersFrame: if expect100Continue != noHeader { buf.Reset() enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"}) @@ -1311,8 +1272,8 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy // Issue 26189, Issue 17739: ignore unknown 1xx responses func TestTransportUnknown1xx(t *testing.T) { var buf bytes.Buffer - defer func() { http2got1xxFuncForTests = nil }() - http2got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error { + defer func() { got1xxFuncForTests = nil }() + got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error { fmt.Fprintf(&buf, "code=%d header=%v\n", code, header) return nil } @@ -1350,13 +1311,13 @@ code=114 header=map[Foo-Bar:[114]] return err } switch f := f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame: - case *http2HeadersFrame: + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: for i := 110; i <= 114; i++ { buf.Reset() enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)}) enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: false, @@ -1365,7 +1326,7 @@ code=114 header=map[Foo-Bar:[114]] } buf.Reset() enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: false, @@ -1407,13 +1368,13 @@ func TestTransportReceiveUndeclaredTrailer(t *testing.T) { ct.greet() var n int - var hf *http2HeadersFrame + var hf *HeadersFrame for hf == nil && n < 10 { f, err := ct.fr.ReadFrame() if err != nil { return err } - hf, _ = f.(*http2HeadersFrame) + hf, _ = f.(*HeadersFrame) n++ } @@ -1422,7 +1383,7 @@ func TestTransportReceiveUndeclaredTrailer(t *testing.T) { // send headers without Trailer header enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, @@ -1432,7 +1393,7 @@ func TestTransportReceiveUndeclaredTrailer(t *testing.T) { // send trailers buf.Reset() enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: true, @@ -1450,7 +1411,7 @@ func TestTransportInvalidTrailerPseudo2(t *testing.T) { testTransportInvalidTrailerPseudo(t, splitHeader) } func testTransportInvalidTrailerPseudo(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, http2pseudoHeaderError(":colon"), func(enc *hpack.Encoder) { + testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) { enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"}) enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) }) @@ -1463,18 +1424,18 @@ func TestTransportInvalidTrailerCapital2(t *testing.T) { testTransportInvalidTrailerCapital(t, splitHeader) } func testTransportInvalidTrailerCapital(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, http2headerFieldNameError("Capital"), func(enc *hpack.Encoder) { + testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) { enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"}) }) } func TestTransportInvalidTrailerEmptyFieldName(t *testing.T) { - testInvalidTrailer(t, oneHeader, http2headerFieldNameError(""), func(enc *hpack.Encoder) { + testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) { enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"}) }) } func TestTransportInvalidTrailerBinaryFieldValue(t *testing.T) { - testInvalidTrailer(t, oneHeader, http2headerFieldValueError("x"), func(enc *hpack.Encoder) { + testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), func(enc *hpack.Encoder) { enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"}) }) } @@ -1492,7 +1453,7 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT return fmt.Errorf("status code = %v; want 200", res.StatusCode) } slurp, err := ioutil.ReadAll(res.Body) - se, ok := err.(http2StreamError) + se, ok := err.(StreamError) if !ok || se.Cause != wantErr { return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr) } @@ -1512,13 +1473,13 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT return err } switch f := f.(type) { - case *http2HeadersFrame: + case *HeadersFrame: var endStream bool send := func(mode headerType) { hbf := buf.Bytes() switch mode { case oneHeader: - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: endStream, @@ -1528,7 +1489,7 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT if len(hbf) < 2 { panic("too small") } - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: false, EndStream: endStream, @@ -1693,7 +1654,7 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { ) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() checkRoundTrip := func(req *http.Request, wantErr error, desc string) { @@ -1721,12 +1682,12 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { } } headerListSizeForRequest := func(req *http.Request) (size uint64) { - contentLen := http2actualContentLength(req) - trailers, err := http2commaSeparatedTrailers(req) + contentLen := actualContentLength(req) + trailers, err := commaSeparatedTrailers(req) if err != nil { t.Fatalf("headerListSizeForRequest: %v", err) } - cc := &http2ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff} + cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff} cc.henc = hpack.NewEncoder(&cc.hbuf) cc.mu.Lock() hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen, nil) @@ -1734,7 +1695,7 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { if err != nil { t.Fatalf("headerListSizeForRequest: %v", err) } - hpackDec := hpack.NewDecoder(http2initialHeaderTableSize, func(hf hpack.HeaderField) { + hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) { size += uint64(hf.Size()) }) if len(hdrs) > 0 { @@ -1764,7 +1725,7 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { // Get the ClientConn associated with the request and validate // peerMaxHeaderListSize. - addr := http2authorityAddr(req.URL.Scheme, req.URL.Host) + addr := authorityAddr(req.URL.Scheme, req.URL.Host) cc, err := tr.connPool().GetClientConn(req, addr) if err != nil { t.Fatalf("GetClientConn: %v", err) @@ -1829,7 +1790,7 @@ func TestTransportChecksResponseHeaderListSize(t *testing.T) { ct.client = func() error { req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) res, err := ct.tr.RoundTrip(req) - if e, ok := err.(http2StreamError); ok { + if e, ok := err.(StreamError); ok { err = e.Cause } if err != errResponseHeaderListSize { @@ -1857,7 +1818,7 @@ func TestTransportChecksResponseHeaderListSize(t *testing.T) { return err } switch f := f.(type) { - case *http2HeadersFrame: + case *HeadersFrame: enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) large := strings.Repeat("a", 1<<10) for i := 0; i < 5042; i++ { @@ -1872,7 +1833,7 @@ func TestTransportChecksResponseHeaderListSize(t *testing.T) { // header block fragment frame. return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) } - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: true, @@ -1903,8 +1864,8 @@ func TestTransportCookieHeaderSplit(t *testing.T) { return err } switch f := f.(type) { - case *http2HeadersFrame: - dec := hpack.NewDecoder(http2initialHeaderTableSize, nil) + case *HeadersFrame: + dec := hpack.NewDecoder(initialHeaderTableSize, nil) hfs, err := dec.DecodeFull(f.HeaderBlockFragment()) if err != nil { return err @@ -1923,7 +1884,7 @@ func TestTransportCookieHeaderSplit(t *testing.T) { var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: true, @@ -1952,7 +1913,7 @@ func TestTransportBodyReadErrorType(t *testing.T) { ) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} @@ -1964,8 +1925,8 @@ func TestTransportBodyReadErrorType(t *testing.T) { doPanic <- true buf := make([]byte, 100) n, err := res.Body.Read(buf) - got, ok := err.(http2StreamError) - want := http2StreamError{StreamID: 0x1, Code: 0x2} + got, ok := err.(StreamError) + want := StreamError{StreamID: 0x1, Code: 0x2} if !ok || got.StreamID != want.StreamID || got.Code != want.Code { t.Errorf("Read = %v, %#v; want error %#v", n, err, want) } @@ -1992,9 +1953,9 @@ func TestTransportDoubleCloseOnWriteError(t *testing.T) { ) defer st.Close() - tr := &http2Transport{ - t1: &Transport{ - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + TLSClientConfigValue: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { tc, err := tls.Dial(network, addr, cfg) @@ -2025,10 +1986,10 @@ func TestTransportDisableKeepAlives(t *testing.T) { defer st.Close() connClosed := make(chan struct{}) // closed on tls.Conn.Close - tr := &http2Transport{ - t1: &Transport{ - DisableKeepAlives: true, - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + DisableKeepAlivesValue: true, + TLSClientConfigValue: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { tc, err := tls.Dial(network, addr, cfg) @@ -2071,10 +2032,10 @@ func TestTransportDisableKeepAlives_Concurrency(t *testing.T) { var dials int32 var conns sync.WaitGroup - tr := &http2Transport{ - t1: &Transport{ - DisableKeepAlives: true, - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + DisableKeepAlivesValue: true, + TLSClientConfigValue: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { tc, err := tls.Dial(network, addr, cfg) @@ -2152,8 +2113,8 @@ func TestTransportResponseHeaderTimeout_Body(t *testing.T) { func testTransportResponseHeaderTimeout(t *testing.T, body bool) { ct := newClientTester(t) - ct.tr.t1 = &Transport{ - ResponseHeaderTimeout: 5 * time.Millisecond, + ct.tr.Interface = &tests.Transport{ + ResponseHeaderTimeoutValue: 5 * time.Millisecond, } ct.client = func() error { c := &http.Client{Transport: ct.tr} @@ -2182,7 +2143,7 @@ func testTransportResponseHeaderTimeout(t *testing.T, body bool) { return nil } switch f := f.(type) { - case *http2DataFrame: + case *DataFrame: dataLen := len(f.Data()) if dataLen > 0 { if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { @@ -2192,8 +2153,8 @@ func testTransportResponseHeaderTimeout(t *testing.T, body bool) { return err } } - case *http2RSTStreamFrame: - if f.StreamID == 1 && f.ErrCode == http2ErrCodeCancel { + case *RSTStreamFrame: + if f.StreamID == 1 && f.ErrCode == ErrCodeCancel { return nil } } @@ -2206,7 +2167,7 @@ func TestTransportDisableCompression(t *testing.T) { const body = "sup" st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { want := http.Header{ - "User-Agent": []string{hdrUserAgentValue}, + "User-Agent": []string{header.DefaultUserAgent}, } if !reflect.DeepEqual(r.Header, want) { t.Errorf("request headers = %v; want %v", r.Header, want) @@ -2214,10 +2175,10 @@ func TestTransportDisableCompression(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{ - t1: &Transport{ - DisableCompression: true, - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + DisableCompressionValue: true, + TLSClientConfigValue: tlsConfigInsecure, }, } defer tr.CloseIdleConnections() @@ -2245,7 +2206,7 @@ func TestTransportRejectsConnHeaders(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() tests := []struct { @@ -2350,7 +2311,7 @@ func TestTransportRejectsConnHeaders(t *testing.T) { // Reject content-length headers containing a sign. // See https://golang.org/issue/39017 func TestTransportRejectsContentLengthWithSign(t *testing.T) { - tests := []struct { + testCases := []struct { name string cl []string wantCL string @@ -2382,14 +2343,14 @@ func TestTransportRejectsContentLengthWithSign(t *testing.T) { }, } - for _, tt := range tests { + for _, tt := range testCases { tt := tt t.Run(tt.name, func(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", tt.cl[0]) }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() req, _ := http.NewRequest("HEAD", st.ts.URL, nil) @@ -2422,7 +2383,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { }, optOnlyServer) defer st.Close() - tests := [...]struct { + testCases := [...]struct { h http.Header wantErr string }{ @@ -2444,10 +2405,10 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { }, } - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() - for i, tt := range tests { + for i, tt := range testCases { req, _ := http.NewRequest("GET", st.ts.URL, nil) req.Header = tt.h res, err := tr.RoundTrip(req) @@ -2472,11 +2433,11 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { } } -// Tests that gzipReader doesn't crash on a second Read call following +// Tests that GzipReader doesn't crash on a second Read call following // the first Read call's gzip.NewReader returning an error. func TestGzipReader_DoubleReadCrash(t *testing.T) { - gz := &http2gzipReader{ - body: ioutil.NopCloser(strings.NewReader("0123456789")), + gz := &GzipReader{ + Body: ioutil.NopCloser(strings.NewReader("0123456789")), } var buf [1]byte n, err1 := gz.Read(buf[:]) @@ -2490,7 +2451,7 @@ func TestGzipReader_DoubleReadCrash(t *testing.T) { } func TestTransportNewTLSConfig(t *testing.T) { - tests := [...]struct { + testCases := [...]struct { conf *tls.Config host string want *tls.Config @@ -2501,7 +2462,7 @@ func TestTransportNewTLSConfig(t *testing.T) { host: "foo.com", want: &tls.Config{ ServerName: "foo.com", - NextProtos: []string{http2NextProtoTLS}, + NextProtos: []string{NextProtoTLS}, }, }, @@ -2513,7 +2474,7 @@ func TestTransportNewTLSConfig(t *testing.T) { host: "foo.com", want: &tls.Config{ ServerName: "bar.com", - NextProtos: []string{http2NextProtoTLS}, + NextProtos: []string{NextProtoTLS}, }, }, @@ -2525,32 +2486,32 @@ func TestTransportNewTLSConfig(t *testing.T) { host: "example.com", want: &tls.Config{ ServerName: "example.com", - NextProtos: []string{http2NextProtoTLS, "foo", "bar"}, + NextProtos: []string{NextProtoTLS, "foo", "bar"}, }, }, // NextProto is not duplicated: 3: { conf: &tls.Config{ - NextProtos: []string{"foo", "bar", http2NextProtoTLS}, + NextProtos: []string{"foo", "bar", NextProtoTLS}, }, host: "example.com", want: &tls.Config{ ServerName: "example.com", - NextProtos: []string{"foo", "bar", http2NextProtoTLS}, + NextProtos: []string{"foo", "bar", NextProtoTLS}, }, }, } - for i, tt := range tests { + for i, tt := range testCases { // Ignore the session ticket keys part, which ends up populating // unexported fields in the Config: if tt.conf != nil { tt.conf.SessionTicketsDisabled = true } - tr := &http2Transport{ - t1: &Transport{ - TLSClientConfig: tt.conf, + tr := &Transport{ + Interface: tests.Transport{ + TLSClientConfigValue: tt.conf, }, } got := tr.newTLSConfig(tt.host) @@ -2596,7 +2557,7 @@ func TestTransportReadHeadResponse(t *testing.T) { t.Logf("ReadFrame: %v", err) return nil } - hf, ok := f.(*http2HeadersFrame) + hf, ok := f.(*HeadersFrame) if !ok { continue } @@ -2604,7 +2565,7 @@ func TestTransportReadHeadResponse(t *testing.T) { enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, // as the GFE does @@ -2655,7 +2616,7 @@ func TestTransportReadHeadResponseWithBody(t *testing.T) { t.Logf("ReadFrame: %v", err) return nil } - hf, ok := f.(*http2HeadersFrame) + hf, ok := f.(*HeadersFrame) if !ok { continue } @@ -2663,7 +2624,7 @@ func TestTransportReadHeadResponseWithBody(t *testing.T) { enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, @@ -2695,18 +2656,18 @@ func TestTransportHandlerBodyClose(t *testing.T) { const bodySize = 10 << 20 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { r.Body.Close() - io.Copy(w, io.LimitReader(neverEnding('A'), bodySize)) + io.Copy(w, io.LimitReader(tests.NeverEnding('A'), bodySize)) }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() g0 := runtime.NumGoroutine() const numReq = 10 for i := 0; i < numReq; i++ { - req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) + req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(tests.NeverEnding('A'), bodySize)}) if err != nil { t.Fatal(err) } @@ -2752,7 +2713,7 @@ func TestTransportFlowControl(t *testing.T) { } }, optOnlyServer) - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { @@ -2776,7 +2737,7 @@ func TestTransportFlowControl(t *testing.T) { } read += int64(n) - const max = http2transportDefaultStreamFlow + const max = transportDefaultStreamFlow if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max { t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read) } @@ -2803,7 +2764,7 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { ct := newClientTester(t) clientDone := make(chan struct{}) - const goAwayErrCode = http2ErrCodeHTTP11Required // arbitrary + const goAwayErrCode = ErrCodeHTTP11Required // arbitrary const goAwayDebugData = "some debug data" ct.client = func() error { @@ -2817,7 +2778,7 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { _, err = io.Copy(ioutil.Discard, res.Body) res.Body.Close() } - want := http2GoAwayError{ + want := GoAwayError{ LastStreamID: 5, ErrCode: goAwayErrCode, DebugData: goAwayDebugData, @@ -2835,7 +2796,7 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { t.Logf("ReadFrame: %v", err) return nil } - hf, ok := f.(*http2HeadersFrame) + hf, ok := f.(*HeadersFrame) if !ok { continue } @@ -2844,7 +2805,7 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, @@ -2853,7 +2814,7 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { } // Write two GOAWAY frames, to test that the Transport takes // the interesting parts of both. - ct.fr.WriteGoAway(5, http2ErrCodeNo, []byte(goAwayDebugData)) + ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) ct.fr.WriteGoAway(5, goAwayErrCode, nil) ct.sc.(*net.TCPConn).CloseWrite() if runtime.GOOS == "plan9" { @@ -2892,18 +2853,18 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { ct.server = func() error { ct.greet() - var hf *http2HeadersFrame + var hf *HeadersFrame for { f, err := ct.fr.ReadFrame() if err != nil { return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) } switch f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame: + case *WindowUpdateFrame, *SettingsFrame: continue } var ok bool - hf, ok = f.(*http2HeadersFrame) + hf, ok = f.(*HeadersFrame) if !ok { return fmt.Errorf("Got %T; want HeadersFrame", f) } @@ -2914,7 +2875,7 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, @@ -2951,25 +2912,25 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { return fmt.Errorf("ReadFrame while waiting for %s: %v", waitingFor, err) } switch f := f.(type) { - case *http2SettingsFrame: - case *http2RSTStreamFrame: + case *SettingsFrame: + case *RSTStreamFrame: if sawRST { - return fmt.Errorf("saw second RSTStreamFrame: %v", http2summarizeFrame(f)) + return fmt.Errorf("saw second RSTStreamFrame: %v", summarizeFrame(f)) } - if f.ErrCode != http2ErrCodeCancel { - return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", http2summarizeFrame(f)) + if f.ErrCode != ErrCodeCancel { + return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) } sawRST = true - case *http2WindowUpdateFrame: + case *WindowUpdateFrame: if sawWUF { - return fmt.Errorf("saw second WindowUpdateFrame: %v", http2summarizeFrame(f)) + return fmt.Errorf("saw second WindowUpdateFrame: %v", summarizeFrame(f)) } if f.Increment != 4999 { - return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", http2summarizeFrame(f)) + return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) } sawWUF = true default: - return fmt.Errorf("Unexpected frame: %v", http2summarizeFrame(f)) + return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) } } return nil @@ -3003,7 +2964,7 @@ func TestTransportAdjustsFlowControl(t *testing.T) { } defer close(clientDone) - req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) + req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(tests.NeverEnding('A'), bodySize)}) res, err := ct.tr.RoundTrip(req) if err != nil { return err @@ -3012,7 +2973,7 @@ func TestTransportAdjustsFlowControl(t *testing.T) { return nil } ct.server = func() error { - _, err := io.ReadFull(ct.sc, make([]byte, len(http2ClientPreface))) + _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface))) if err != nil { return fmt.Errorf("reading client preface: %v", err) } @@ -3030,16 +2991,16 @@ func TestTransportAdjustsFlowControl(t *testing.T) { } } switch f := f.(type) { - case *http2DataFrame: + case *DataFrame: gotBytes += int64(len(f.Data())) // After we've got half the client's // initial flow control window's worth // of request body data, give it just // enough flow control to finish. - if gotBytes >= http2initialWindowSize/2 && !sentSettings { + if gotBytes >= initialWindowSize/2 && !sentSettings { sentSettings = true - ct.fr.WriteSettings(http2Setting{ID: http2SettingInitialWindowSize, Val: bodySize}) + ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) ct.fr.WriteWindowUpdate(0, bodySize) ct.fr.WriteSettingsAck() } @@ -3048,7 +3009,7 @@ func TestTransportAdjustsFlowControl(t *testing.T) { var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: true, @@ -3080,18 +3041,18 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { ct.server = func() error { ct.greet() - var hf *http2HeadersFrame + var hf *HeadersFrame for { f, err := ct.fr.ReadFrame() if err != nil { return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) } switch f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame: + case *WindowUpdateFrame, *SettingsFrame: continue } var ok bool - hf, ok = f.(*http2HeadersFrame) + hf, ok = f.(*HeadersFrame) if !ok { return fmt.Errorf("Got %T; want HeadersFrame", f) } @@ -3102,7 +3063,7 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, @@ -3116,16 +3077,16 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { return fmt.Errorf("ReadFrame while waiting for first WindowUpdateFrame: %v", err) } wantBack := uint32(len(pad)) + 1 // one byte for the length of the padding - if wuf, ok := f.(*http2WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID != 0 { - return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, http2summarizeFrame(f)) + if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID != 0 { + return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f)) } f, err = ct.readNonSettingsFrame() if err != nil { return fmt.Errorf("ReadFrame while waiting for second WindowUpdateFrame: %v", err) } - if wuf, ok := f.(*http2WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID == 0 { - return fmt.Errorf("Expected stream WindowUpdateFrame for %d bytes; got %v", wantBack, http2summarizeFrame(f)) + if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID == 0 { + return fmt.Errorf("Expected stream WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f)) } unblockClient <- true return nil @@ -3145,7 +3106,7 @@ func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { res.Body.Close() return errors.New("unexpected successful GET") } - want := http2StreamError{1, http2ErrCodeProtocol, http2headerFieldNameError(" content-type")} + want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} if !reflect.DeepEqual(want, err) { t.Errorf("RoundTrip error = %#v; want %#v", err, want) } @@ -3163,7 +3124,7 @@ func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, @@ -3175,11 +3136,11 @@ func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { if err != nil { return fmt.Errorf("error waiting for RST_STREAM from client: %v", err) } - if _, ok := fr.(*http2SettingsFrame); ok { + if _, ok := fr.(*SettingsFrame); ok { continue } - if rst, ok := fr.(*http2RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != http2ErrCodeProtocol { - t.Errorf("Frame = %v; want RST_STREAM for stream 1 with http2ErrCodeProtocol", http2summarizeFrame(fr)) + if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol { + t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) } break } @@ -3216,7 +3177,7 @@ func TestTransportBodyDoubleEndStream(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() for i := 0; i < 2; i++ { @@ -3332,13 +3293,13 @@ func TestTransportRequestPathPseudo(t *testing.T) { }, } for i, tt := range tests { - cc := &http2ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff} + cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff} cc.henc = hpack.NewEncoder(&cc.hbuf) cc.mu.Lock() hdrs, err := cc.encodeHeaders(tt.req, false, "", -1, nil) cc.mu.Unlock() var got result - hpackDec := hpack.NewDecoder(http2initialHeaderTableSize, func(f hpack.HeaderField) { + hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) { if f.Name == ":path" { got.path = f.Value } @@ -3364,7 +3325,7 @@ func TestTransportRequestPathPseudo(t *testing.T) { func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { const body = "foo" req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body))) - cc := &http2ClientConn{ + cc := &ClientConn{ closed: true, reqHeaderMu: make(chan struct{}, 1), } @@ -3384,7 +3345,7 @@ func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { func TestClientConnPing(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) @@ -3423,7 +3384,7 @@ func TestTransportCancelDataResponseRace(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} @@ -3460,7 +3421,7 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() req, _ := http.NewRequest("GET", st.ts.URL, nil) @@ -3508,9 +3469,9 @@ func TestTransportPingWriteBlocks(t *testing.T) { optOnlyServer, ) defer st.Close() - tr := &http2Transport{ - t1: &Transport{ - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + TLSClientConfigValue: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { s, c := net.Pipe() // unbuffered, unlike a TCP conn @@ -3619,20 +3580,20 @@ func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.D } } switch f := f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame: - case *http2HeadersFrame: + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: if !f.HeadersEnded() { return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) } enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: false, BlockFragment: buf.Bytes(), }) streamID = f.StreamID - case *http2PingFrame: + case *PingFrame: pingCount++ if pingCount == expectedPingCount { if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil { @@ -3642,7 +3603,7 @@ func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.D if err := ct.fr.WritePing(true, f.Data); err != nil { return err } - case *http2RSTStreamFrame: + case *RSTStreamFrame: default: return fmt.Errorf("Unexpected client frame %v", f) } @@ -3659,12 +3620,12 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { ct1 := make(chan *clientTester) ct2 := make(chan *clientTester) - ln := newLocalListener(t) + ln := tests.NewLocalListener(t) defer ln.Close() - tr := &http2Transport{ - t1: &Transport{ - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + TLSClientConfigValue: tlsConfigInsecure, }, } tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { @@ -3687,7 +3648,7 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { tr: tr, cc: cc, sc: sc, - fr: http2NewFramer(sc, sc), + fr: NewFramer(sc, sc), } switch dialer.count { case 1: @@ -3730,7 +3691,7 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { return } t.Logf("server1 got %v", hf) - if err := ct.fr.WriteGoAway(0 /*max id*/, http2ErrCodeNo, nil); err != nil { + if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err) return } @@ -3754,7 +3715,7 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - err = ct.fr.WriteHeaders(http2HeadersFrameParam{ + err = ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, @@ -3821,17 +3782,17 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { } } switch f := f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame: - case *http2HeadersFrame: + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: if !f.HeadersEnded() { return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) } nreq++ if nreq == 1 { - ct.fr.WriteRSTStream(f.StreamID, http2ErrCodeRefusedStream) + ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) } else { enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: true, @@ -3870,7 +3831,7 @@ func TestTransportResponseDataBeforeHeaders(t *testing.T) { if err == nil { return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) } - if err, ok := err.(http2StreamError); !ok || err.Code != http2ErrCodeProtocol { + if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol { return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err) } return nil @@ -3885,15 +3846,15 @@ func TestTransportResponseDataBeforeHeaders(t *testing.T) { return err } switch f := f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame, *http2RSTStreamFrame: - case *http2HeadersFrame: + case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame: + case *HeadersFrame: switch f.StreamID { case 1: // Send a valid response to first request. var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: true, @@ -3912,7 +3873,7 @@ func TestTransportResponseDataBeforeHeaders(t *testing.T) { func TestTransportRequestsLowServerLimit(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - }, optOnlyServer, func(s *http2Server) { + }, optOnlyServer, func(s *Server) { s.MaxConcurrentStreams = 1 }) defer st.Close() @@ -3921,9 +3882,9 @@ func TestTransportRequestsLowServerLimit(t *testing.T) { connCountMu sync.Mutex connCount int ) - tr := &http2Transport{ - t1: &Transport{ - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + TLSClientConfigValue: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { connCountMu.Lock() @@ -4076,7 +4037,7 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { var wg sync.WaitGroup defer wg.Wait() - ct.greet(http2Setting{http2SettingMaxConcurrentStreams, maxConcurrent}) + ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) // Server write loop. var buf bytes.Buffer @@ -4090,7 +4051,7 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { for id := range writeResp { buf.Reset() enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: id, EndHeaders: true, EndStream: true, @@ -4113,11 +4074,11 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { } } switch f := f.(type) { - case *http2WindowUpdateFrame: - case *http2SettingsFrame: + case *WindowUpdateFrame: + case *SettingsFrame: // Wait for the client SETTINGS ack until ending the greet. close(greet) - case *http2HeadersFrame: + case *HeadersFrame: if !f.HeadersEnded() { return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) } @@ -4127,7 +4088,7 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { if nreq == maxConcurrent+1 { close(writeResp) } - case *http2DataFrame: + case *DataFrame: default: return fmt.Errorf("Unexpected client frame %v", f) } @@ -4151,7 +4112,7 @@ func TestAuthorityAddr(t *testing.T) { {"https", "[::1]", "[::1]:443"}, } for _, tt := range tests { - got := http2authorityAddr(tt.scheme, tt.authority) + got := authorityAddr(tt.scheme, tt.authority) if got != tt.want { t.Errorf("http2authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want) } @@ -4181,7 +4142,7 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} res, err := c.Get(st.ts.URL) @@ -4196,7 +4157,7 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { t.Error(err) } - trb, ok := res.Body.(http2transportResponseBody) + trb, ok := res.Body.(transportResponseBody) if !ok { t.Fatalf("res.Body = %T; want transportResponseBody", res.Body) } @@ -4238,9 +4199,9 @@ func TestTransportNoBodyMeansNoDATA(t *testing.T) { switch f := f.(type) { default: return fmt.Errorf("Got %T; want HeadersFrame", f) - case *http2WindowUpdateFrame, *http2SettingsFrame: + case *WindowUpdateFrame, *SettingsFrame: continue - case *http2HeadersFrame: + case *HeadersFrame: if !f.StreamEnded() { return fmt.Errorf("got headers frame without END_STREAM") } @@ -4252,9 +4213,9 @@ func TestTransportNoBodyMeansNoDATA(t *testing.T) { } func disableGoroutineTracking() (restore func()) { - old := http2DebugGoroutines - http2DebugGoroutines = false - return func() { http2DebugGoroutines = old } + old := DebugGoroutines + DebugGoroutines = false + return func() { DebugGoroutines = old } } func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { @@ -4272,7 +4233,7 @@ func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { ) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", st.ts.URL, nil) @@ -4316,7 +4277,7 @@ func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() // The request body needs to be big enough to trigger flow control. @@ -4354,9 +4315,9 @@ func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) { return err } switch f := f.(type) { - case *http2HeadersFrame: + case *HeadersFrame: enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: false, // we'll send some DATA to try to crash the transport @@ -4384,7 +4345,7 @@ func BenchmarkClientResponseHeaders(b *testing.B) { b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) }) } -func activeStreams(cc *http2ClientConn) int { +func activeStreams(cc *ClientConn) int { count := 0 cc.mu.Lock() defer cc.mu.Unlock() @@ -4434,7 +4395,7 @@ func testClientConnClose(t *testing.T, closeMode closeMode) { } }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) @@ -4502,11 +4463,11 @@ func testClientConnClose(t *testing.T, closeMode closeMode) { <-handlerDone case shutdown: wait := make(chan struct{}) - http2shutdownEnterWaitStateHook = func() { + shutdownEnterWaitStateHook = func() { close(wait) - http2shutdownEnterWaitStateHook = func() {} + shutdownEnterWaitStateHook = func() {} } - defer func() { http2shutdownEnterWaitStateHook = func() {} }() + defer func() { shutdownEnterWaitStateHook = func() {} }() shutdown := make(chan struct{}, 1) go func() { if err = cc.Shutdown(context.Background()); err != nil { @@ -4599,7 +4560,7 @@ func TestTransportUsesGetBodyWhenPresent(t *testing.T) { }, } - req2, err := http2shouldRetryRequest(req, errClientConnUnusable) + req2, err := shouldRetryRequest(req, errClientConnUnusable) if err != nil { t.Fatal(err) } @@ -4657,9 +4618,9 @@ func testTransportBodyReadError(t *testing.T, body []byte) { defer close(clientDone) checkNoStreams := func() error { - cp, ok := ct.tr.connPool().(*http2clientConnPool) + cp, ok := ct.tr.connPool().(*clientConnPool) if !ok { - return fmt.Errorf("conn pool is %T; want *http2clientConnPool", ct.tr.connPool()) + return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool()) } cp.mu.Lock() defer cp.mu.Unlock() @@ -4715,11 +4676,11 @@ func testTransportBodyReadError(t *testing.T, body []byte) { } } switch f := f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame: - case *http2HeadersFrame: - case *http2DataFrame: + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + case *DataFrame: receivedBody = append(receivedBody, f.Data()...) - case *http2RSTStreamFrame: + case *RSTStreamFrame: resetCount++ default: return fmt.Errorf("Unexpected client frame %v", f) @@ -4767,17 +4728,17 @@ func TestTransportBodyEagerEndStream(t *testing.T) { } switch f := f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame: - case *http2HeadersFrame: - case *http2DataFrame: + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + case *DataFrame: if !f.StreamEnded() { - ct.fr.WriteRSTStream(f.StreamID, http2ErrCodeRefusedStream) + ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) return fmt.Errorf("data frame without END_STREAM %v", f) } var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.Header().StreamID, EndHeaders: true, EndStream: false, @@ -4785,7 +4746,7 @@ func TestTransportBodyEagerEndStream(t *testing.T) { }) ct.fr.WriteData(f.StreamID, true, []byte(resBody)) return nil - case *http2RSTStreamFrame: + case *RSTStreamFrame: default: return fmt.Errorf("Unexpected client frame %v", f) } @@ -4836,7 +4797,7 @@ func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunk }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() req, _ := http.NewRequest("POST", st.ts.URL, body) @@ -4849,30 +4810,30 @@ func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunk func TestClientConnTooIdle(t *testing.T) { tests := []struct { - cc func() *http2ClientConn + cc func() *ClientConn want bool }{ { - func() *http2ClientConn { - return &http2ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} + func() *ClientConn { + return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} }, true, }, { - func() *http2ClientConn { - return &http2ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Time{}} + func() *ClientConn { + return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Time{}} }, false, }, { - func() *http2ClientConn { - return &http2ClientConn{idleTimeout: 60 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} + func() *ClientConn { + return &ClientConn{idleTimeout: 60 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} }, false, }, { - func() *http2ClientConn { - return &http2ClientConn{idleTimeout: 0, lastIdle: time.Now().Add(-10 * time.Second)} + func() *ClientConn { + return &ClientConn{idleTimeout: 0, lastIdle: time.Now().Add(-10 * time.Second)} }, false, }, @@ -4902,7 +4863,7 @@ func (fce *fakeConnErr) Close() error { // issue 39337: close the connection on a failed write func TestTransportNewClientConnCloseOnWriteError(t *testing.T) { - tr := &http2Transport{} + tr := &Transport{Interface: tests.Transport{}} writeErr := errors.New("write error") fakeConn := &fakeConnErr{writeErr: writeErr} _, err := tr.NewClientConn(fakeConn) @@ -4922,7 +4883,7 @@ func TestTransportRoundtripCloseOnWriteError(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) @@ -4948,177 +4909,6 @@ func TestTransportRoundtripCloseOnWriteError(t *testing.T) { } } -// Issue 31192: A failed request may be retried if the body has not been read -// already. If the request body has started to be sent, one must wait until it -// is completed. -func TestTransportBodyRewindRace(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Connection", "close") - w.WriteHeader(http.StatusOK) - return - }, optOnlyServer) - defer st.Close() - - tr := &Transport{ - TLSClientConfig: tlsConfigInsecure, - MaxConnsPerHost: 1, - } - err := http2ConfigureTransport(tr) - if err != nil { - t.Fatal(err) - } - client := &http.Client{ - Transport: tr, - } - - const clients = 50 - - var wg sync.WaitGroup - wg.Add(clients) - for i := 0; i < clients; i++ { - req, err := http.NewRequest("POST", st.ts.URL, bytes.NewBufferString("abcdef")) - if err != nil { - t.Fatalf("unexpect new request error: %v", err) - } - - go func() { - defer wg.Done() - res, err := client.Do(req) - if err == nil { - res.Body.Close() - } - }() - } - - wg.Wait() -} - -// Issue 42498: A request with a body will never be sent if the stream is -// reset prior to sending any data. -func TestTransportServerResetStreamAtHeaders(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - return - }, optOnlyServer) - defer st.Close() - - tr := &Transport{ - TLSClientConfig: tlsConfigInsecure, - MaxConnsPerHost: 1, - ExpectContinueTimeout: 10 * time.Second, - } - - err := http2ConfigureTransport(tr) - if err != nil { - t.Fatal(err) - } - client := &http.Client{ - Transport: tr, - } - - req, err := http.NewRequest("POST", st.ts.URL, errorReader{io.EOF}) - if err != nil { - t.Fatalf("unexpect new request error: %v", err) - } - req.ContentLength = 0 // so transport is tempted to sniff it - req.Header.Set("Expect", "100-continue") - res, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - res.Body.Close() -} - -type trackingReader struct { - rdr io.Reader - wasRead uint32 -} - -func (tr *trackingReader) Read(p []byte) (int, error) { - atomic.StoreUint32(&tr.wasRead, 1) - return tr.rdr.Read(p) -} - -func (tr *trackingReader) WasRead() bool { - return atomic.LoadUint32(&tr.wasRead) != 0 -} - -func TestTransportExpectContinue(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/reject": - w.WriteHeader(403) - default: - io.Copy(io.Discard, r.Body) - } - }, optOnlyServer) - defer st.Close() - - tr := &Transport{ - TLSClientConfig: tlsConfigInsecure, - MaxConnsPerHost: 1, - ExpectContinueTimeout: 10 * time.Second, - } - - err := http2ConfigureTransport(tr) - if err != nil { - t.Fatal(err) - } - client := &http.Client{ - Transport: tr, - } - - testCases := []struct { - Name string - Path string - Body *trackingReader - ExpectedCode int - ShouldRead bool - }{ - { - Name: "read-all", - Path: "/", - Body: &trackingReader{rdr: strings.NewReader("hello")}, - ExpectedCode: 200, - ShouldRead: true, - }, - { - Name: "reject", - Path: "/reject", - Body: &trackingReader{rdr: strings.NewReader("hello")}, - ExpectedCode: 403, - ShouldRead: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - startTime := time.Now() - - req, err := http.NewRequest("POST", st.ts.URL+tc.Path, tc.Body) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Expect", "100-continue") - res, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - - if delta := time.Since(startTime); delta >= tr.ExpectContinueTimeout { - t.Error("Request didn't finish before expect continue timeout") - } - if res.StatusCode != tc.ExpectedCode { - t.Errorf("Unexpected status code, got %d, expected %d", res.StatusCode, tc.ExpectedCode) - } - if tc.Body.WasRead() != tc.ShouldRead { - t.Errorf("Unexpected read status, got %v, expected %v", tc.Body.WasRead(), tc.ShouldRead) - } - }) - } -} - type closeChecker struct { io.ReadCloser closed chan struct{} @@ -5224,7 +5014,7 @@ func TestTransportFrameBufferReuse(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() var wg sync.WaitGroup @@ -5307,7 +5097,7 @@ func TestTransportBlockingRequestWrite(t *testing.T) { if v := r.Trailer.Get("Big"); v != "" && v != filler { t.Errorf("request trailer mismatch\ngot: %q\nwant: %q", string(v), filler) } - }, optOnlyServer, func(s *http2Server) { + }, optOnlyServer, func(s *Server) { s.MaxConcurrentStreams = 1 }) defer st.Close() @@ -5315,9 +5105,9 @@ func TestTransportBlockingRequestWrite(t *testing.T) { // This Transport creates connections that block on writes after 1024 bytes. connc := make(chan *blockingWriteConn, 1) connCount := 0 - tr := &http2Transport{ - t1: &Transport{ - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + TLSClientConfigValue: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { connCount++ @@ -5410,7 +5200,7 @@ func TestTransportCloseRequestBody(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) @@ -5440,17 +5230,17 @@ func TestTransportCloseRequestBody(t *testing.T) { } } -// collectClientsConnPool is a http2ClientConnPool that wraps lower and +// collectClientsConnPool is a ClientConnPool that wraps lower and // collects what calls were made on it. type collectClientsConnPool struct { - lower http2ClientConnPool + lower ClientConnPool mu sync.Mutex getErrs int - got []*http2ClientConn + got []*ClientConn } -func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*http2ClientConn, error) { +func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { cc, err := p.lower.GetClientConn(req, addr) p.mu.Lock() defer p.mu.Unlock() @@ -5462,14 +5252,14 @@ func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) ( return cc, nil } -func (p *collectClientsConnPool) MarkDead(cc *http2ClientConn) { +func (p *collectClientsConnPool) MarkDead(cc *ClientConn) { p.lower.MarkDead(cc) } func TestTransportRetriesOnStreamProtocolError(t *testing.T) { ct := newClientTester(t) pool := &collectClientsConnPool{ - lower: &http2clientConnPool{t: ct.tr}, + lower: &clientConnPool{t: ct.tr}, } ct.tr.ConnPool = pool @@ -5563,15 +5353,15 @@ func TestTransportRetriesOnStreamProtocolError(t *testing.T) { return nil } switch f := f.(type) { - case *http2WindowUpdateFrame, *http2SettingsFrame: - case *http2HeadersFrame: + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: numHeaders++ if numHeaders == 1 { firstStreamID = f.StreamID hbuf.Reset() enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"}) - ct.fr.WriteHeaders(http2HeadersFrameParam{ + ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: false, @@ -5581,7 +5371,7 @@ func TestTransportRetriesOnStreamProtocolError(t *testing.T) { } if !sentErr { sentErr = true - ct.fr.WriteRSTStream(f.StreamID, http2ErrCodeProtocol) + ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol) ct.fr.WriteData(firstStreamID, true, nil) continue } @@ -5593,20 +5383,20 @@ func TestTransportRetriesOnStreamProtocolError(t *testing.T) { } func TestClientConnReservations(t *testing.T) { - cc := &http2ClientConn{ + cc := &ClientConn{ reqHeaderMu: make(chan struct{}, 1), - streams: make(map[uint32]*http2clientStream), - maxConcurrentStreams: http2initialMaxConcurrentStreams, + streams: make(map[uint32]*clientStream), + maxConcurrentStreams: initialMaxConcurrentStreams, nextStreamID: 1, - t: &http2Transport{}, + t: &Transport{Interface: tests.Transport{}}, } cc.cond = sync.NewCond(&cc.mu) n := 0 - for n <= http2initialMaxConcurrentStreams && cc.ReserveNewRequest() { + for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() { n++ } - if n != http2initialMaxConcurrentStreams { - t.Errorf("did %v reservations; want %v", n, http2initialMaxConcurrentStreams) + if n != initialMaxConcurrentStreams { + t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams) } if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) { t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err) @@ -5625,7 +5415,7 @@ func TestClientConnReservations(t *testing.T) { } n2 = 0 - for n2 <= http2initialMaxConcurrentStreams && cc.ReserveNewRequest() { + for n2 <= initialMaxConcurrentStreams && cc.ReserveNewRequest() { n2++ } if n2 != n { @@ -5675,7 +5465,7 @@ func TestTransportContentLengthWithoutBody(t *testing.T) { w.Header().Set("Content-Length", contentLength) }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() for _, test := range []struct { @@ -5730,7 +5520,7 @@ func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() pr, pw := net.Pipe() @@ -5758,7 +5548,7 @@ func TestTransport300ResponseBody(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &http2Transport{t1: &Transport{TLSClientConfig: tlsConfigInsecure}} + tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} defer tr.CloseIdleConnections() pr, pw := net.Pipe() @@ -5788,9 +5578,9 @@ func TestTransportWriteByteTimeout(t *testing.T) { optOnlyServer, ) defer st.Close() - tr := &http2Transport{ - t1: &Transport{ - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + TLSClientConfigValue: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { _, c := net.Pipe() @@ -5834,9 +5624,9 @@ func TestTransportSlowWrites(t *testing.T) { optOnlyServer, ) defer st.Close() - tr := &http2Transport{ - t1: &Transport{ - TLSClientConfig: tlsConfigInsecure, + tr := &Transport{ + Interface: tests.Transport{ + TLSClientConfigValue: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { cfg.InsecureSkipVerify = true @@ -5849,7 +5639,7 @@ func TestTransportSlowWrites(t *testing.T) { c := &http.Client{Transport: tr} const bodySize = 1 << 20 - resp, err := c.Post(st.ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize)) + resp, err := c.Post(st.ts.URL, "text/foo", io.LimitReader(tests.NeverEnding('A'), bodySize)) if err != nil { t.Fatal(err) } @@ -5857,50 +5647,50 @@ func TestTransportSlowWrites(t *testing.T) { } func TestCountReadFrameError(t *testing.T) { - cc := &http2ClientConn{} + cc := &ClientConn{} errMsg := "" countError := func(errType string) { errMsg = errType } - cc.t = &http2Transport{CountError: countError} + cc.t = &Transport{CountError: countError} var err error cc.countReadFrameError(err) - assertEqual(t, "", errMsg) + tests.AssertEqual(t, "", errMsg) - err = http2ConnectionError(http2ErrCodeInternal) + err = ConnectionError(ErrCodeInternal) cc.countReadFrameError(err) - assertContains(t, errMsg, "read_frame_conn_error", true) + tests.AssertContains(t, errMsg, "read_frame_conn_error", true) err = io.EOF cc.countReadFrameError(err) - assertContains(t, errMsg, "read_frame_eof", true) + tests.AssertContains(t, errMsg, "read_frame_eof", true) err = io.ErrUnexpectedEOF cc.countReadFrameError(err) - assertContains(t, errMsg, "read_frame_unexpected_eof", true) + tests.AssertContains(t, errMsg, "read_frame_unexpected_eof", true) err = errFrameTooLarge cc.countReadFrameError(err) - assertContains(t, errMsg, "read_frame_too_large", true) + tests.AssertContains(t, errMsg, "read_frame_too_large", true) err = errors.New("other") cc.countReadFrameError(err) - assertContains(t, errMsg, "read_frame_other", true) + tests.AssertContains(t, errMsg, "read_frame_other", true) } func TestProcessHeaders(t *testing.T) { - rl := &http2clientConnReadLoop{} - cc := &http2ClientConn{streams: map[uint32]*http2clientStream{}} - cc.streams[1] = &http2clientStream{cc: cc, abort: make(chan struct{})} + rl := &clientConnReadLoop{} + cc := &ClientConn{streams: map[uint32]*clientStream{}} + cc.streams[1] = &clientStream{cc: cc, abort: make(chan struct{})} rl.cc = cc - f := &http2MetaHeadersFrame{http2HeadersFrame: &http2HeadersFrame{ - http2FrameHeader: http2FrameHeader{StreamID: 1}, + f := &MetaHeadersFrame{HeadersFrame: &HeadersFrame{ + FrameHeader: FrameHeader{StreamID: 1}, }} err := rl.processHeaders(f) - assertNoError(t, err) + tests.AssertNoError(t, err) f.StreamID = 0 err = rl.processHeaders(f) - assertNoError(t, err) + tests.AssertNoError(t, err) } diff --git a/internal/socks/socks_test.go b/internal/socks/socks_test.go index cc3af621..824a09d7 100644 --- a/internal/socks/socks_test.go +++ b/internal/socks/socks_test.go @@ -3,6 +3,7 @@ package socks import ( "bytes" "context" + "github.com/imroc/req/v3/internal/tests" "strings" "testing" ) @@ -20,22 +21,6 @@ func TestReply(t *testing.T) { } } -func assertNoError(t *testing.T, err error) { - if err != nil { - t.Errorf("Error occurred [%v]", err) - } -} - -func assertErrorContains(t *testing.T, err error, s string) { - if err == nil { - t.Error("err is nil") - return - } - if !strings.Contains(err.Error(), s) { - t.Errorf("%q is not included in error %q", s, err.Error()) - } -} - func TestAuthenticate(t *testing.T) { auth := &UsernamePassword{ Username: "imroc", @@ -43,21 +28,21 @@ func TestAuthenticate(t *testing.T) { } buf := bytes.NewBuffer([]byte{byte(0x01), byte(0x00)}) err := auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) - assertNoError(t, err) + tests.AssertNoError(t, err) auth.Username = "this is a very long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long name" err = auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) - assertErrorContains(t, err, "invalid") + tests.AssertErrorContains(t, err, "invalid") auth.Username = "imroc" buf = bytes.NewBuffer([]byte{byte(0x03), byte(0x00)}) err = auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) - assertErrorContains(t, err, "invalid username/password version") + tests.AssertErrorContains(t, err, "invalid username/password version") buf = bytes.NewBuffer([]byte{byte(0x01), byte(0x02)}) err = auth.Authenticate(context.Background(), buf, AuthMethodUsernamePassword) - assertErrorContains(t, err, "authentication failed") + tests.AssertErrorContains(t, err, "authentication failed") err = auth.Authenticate(context.Background(), buf, AuthMethodNoAcceptableMethods) - assertErrorContains(t, err, "unsupported authentication method") + tests.AssertErrorContains(t, err, "unsupported authentication method") } diff --git a/internal/tests/assert.go b/internal/tests/assert.go new file mode 100644 index 00000000..c4bdbbe9 --- /dev/null +++ b/internal/tests/assert.go @@ -0,0 +1,111 @@ +package tests + +import ( + "go/token" + "reflect" + "strings" + "testing" + "unsafe" +) + +func AssertIsNil(t *testing.T, v interface{}) { + if !isNil(v) { + t.Errorf("[%v] was expected to be nil", v) + } +} + +func AssertAllNotNil(t *testing.T, vv ...interface{}) { + for _, v := range vv { + AssertNotNil(t, v) + } +} + +func AssertNotNil(t *testing.T, v interface{}) { + if isNil(v) { + t.Fatalf("[%v] was expected to be non-nil", v) + } +} + +func AssertEqual(t *testing.T, e, g interface{}) { + if !equal(e, g) { + t.Errorf("Expected [%+v], got [%+v]", e, g) + } + return +} + +func AssertNoError(t *testing.T, err error) { + if err != nil { + t.Errorf("Error occurred [%v]", err) + } +} + +func AssertErrorContains(t *testing.T, err error, s string) { + if err == nil { + t.Error("err is nil") + return + } + if !strings.Contains(err.Error(), s) { + t.Errorf("%q is not included in error %q", s, err.Error()) + } +} + +func AssertContains(t *testing.T, s, substr string, shouldContain bool) { + s = strings.ToLower(s) + isContain := strings.Contains(s, substr) + if shouldContain { + if !isContain { + t.Errorf("%q is not included in %s", substr, s) + } + } else { + if isContain { + t.Errorf("%q is included in %s", substr, s) + } + } +} + +func AssertClone(t *testing.T, e, g interface{}) { + ev := reflect.ValueOf(e).Elem() + gv := reflect.ValueOf(g).Elem() + et := ev.Type() + + for i := 0; i < ev.NumField(); i++ { + sf := ev.Field(i) + st := et.Field(i) + + var ee, gg interface{} + if !token.IsExported(st.Name) { + ee = reflect.NewAt(sf.Type(), unsafe.Pointer(sf.UnsafeAddr())).Elem().Interface() + gg = reflect.NewAt(sf.Type(), unsafe.Pointer(gv.Field(i).UnsafeAddr())).Elem().Interface() + } else { + ee = sf.Interface() + gg = gv.Field(i).Interface() + } + if sf.Kind() == reflect.Func || sf.Kind() == reflect.Slice || sf.Kind() == reflect.Ptr { + if ee != nil { + if gg == nil { + t.Errorf("Field %s.%s is nil", et.Name(), et.Field(i).Name) + } + } + continue + } + if !reflect.DeepEqual(ee, gg) { + t.Errorf("Field %s.%s is not equal, expected [%v], got [%v]", et.Name(), et.Field(i).Name, ee, gg) + } + } +} + +func equal(expected, got interface{}) bool { + return reflect.DeepEqual(expected, got) +} + +func isNil(v interface{}) bool { + if v == nil { + return true + } + rv := reflect.ValueOf(v) + kind := rv.Kind() + if kind >= reflect.Chan && kind <= reflect.Slice && rv.IsNil() { + return true + } + return false +} diff --git a/internal/tests/condition.go b/internal/tests/condition.go new file mode 100644 index 00000000..46816d79 --- /dev/null +++ b/internal/tests/condition.go @@ -0,0 +1,17 @@ +package tests + +import "time" + +// WaitCondition reports whether fn eventually returned true, +// checking immediately and then every checkEvery amount, +// until waitFor has elapsed, at which point it returns false. +func WaitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { + deadline := time.Now().Add(waitFor) + for time.Now().Before(deadline) { + if fn() { + return true + } + time.Sleep(checkEvery) + } + return false +} diff --git a/internal/tests/net.go b/internal/tests/net.go new file mode 100644 index 00000000..da3f7e05 --- /dev/null +++ b/internal/tests/net.go @@ -0,0 +1,17 @@ +package tests + +import ( + "net" + "testing" +) + +func NewLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} diff --git a/internal/tests/reader.go b/internal/tests/reader.go new file mode 100644 index 00000000..251dd4b5 --- /dev/null +++ b/internal/tests/reader.go @@ -0,0 +1,10 @@ +package tests + +type NeverEnding byte + +func (b NeverEnding) Read(p []byte) (int, error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} diff --git a/internal/tests/transport.go b/internal/tests/transport.go new file mode 100644 index 00000000..9378c28b --- /dev/null +++ b/internal/tests/transport.go @@ -0,0 +1,139 @@ +package tests + +import ( + "context" + "crypto/tls" + "github.com/imroc/req/v3/internal/dump" + reqtls "github.com/imroc/req/v3/internal/tls" + "github.com/imroc/req/v3/internal/transport" + "net" + "net/http" + "net/url" + "time" +) + +type Transport struct { + ProxyValue func(*http.Request) (*url.URL, error) + TLSClientConfigValue *tls.Config + DisableCompressionValue bool + DisableKeepAlivesValue bool + TLSHandshakeTimeoutValue time.Duration + ResponseHeaderTimeoutValue time.Duration + ExpectContinueTimeoutValue time.Duration + IdleConnTimeoutValue time.Duration + DumpValue *dump.Dumper + ReadBufferSizeValue int + WriteBufferSizeValue int + MaxIdleConnsValue int + MaxIdleConnsPerHostValue int + MaxConnsPerHostValue int + MaxResponseHeaderBytesValue int64 + TLSNextProtoValue map[string]func(authority string, c reqtls.Conn) http.RoundTripper + GetProxyConnectHeaderValue func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) + ProxyConnectHeaderValue http.Header + DebugfValue func(format string, v ...interface{}) +} + +func (t Transport) Proxy() func(*http.Request) (*url.URL, error) { + return t.ProxyValue +} + +func (t Transport) Clone() transport.Interface { + return nil +} + +func (t Transport) Debugf() func(format string, v ...interface{}) { + return t.DebugfValue +} + +func (t Transport) SetDebugf(f func(format string, v ...interface{})) { + t.DebugfValue = f +} + +func (t Transport) DisableCompression() bool { + return t.DisableCompressionValue +} + +func (t Transport) TLSClientConfig() *tls.Config { + return t.TLSClientConfigValue +} + +func (t Transport) SetTLSClientConfig(c *tls.Config) { + t.TLSClientConfigValue = c +} + +func (t Transport) TLSHandshakeTimeout() time.Duration { + return t.TLSHandshakeTimeoutValue +} + +func (t Transport) DialContext() func(ctx context.Context, network string, addr string) (net.Conn, error) { + return nil +} + +func (t Transport) DialTLSContext() func(ctx context.Context, network string, addr string) (net.Conn, error) { + return nil +} + +func (t Transport) RegisterProtocol(scheme string, rt http.RoundTripper) { +} + +func (t Transport) DisableKeepAlives() bool { + return t.DisableKeepAlivesValue +} + +func (t Transport) Dump() *dump.Dumper { + return t.DumpValue + +} + +func (t Transport) MaxIdleConns() int { + return t.MaxIdleConnsValue +} + +func (t Transport) MaxIdleConnsPerHost() int { + return t.MaxIdleConnsPerHostValue +} + +func (t Transport) MaxConnsPerHost() int { + return t.MaxConnsPerHostValue +} + +func (t Transport) IdleConnTimeout() time.Duration { + return t.IdleConnTimeoutValue +} + +func (t Transport) ResponseHeaderTimeout() time.Duration { + return t.ResponseHeaderTimeoutValue +} + +func (t Transport) ExpectContinueTimeout() time.Duration { + return t.ExpectContinueTimeoutValue +} + +func (t Transport) ProxyConnectHeader() http.Header { + return t.ProxyConnectHeaderValue +} + +func (t Transport) GetProxyConnectHeader() func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) { + return t.GetProxyConnectHeaderValue +} + +func (t Transport) MaxResponseHeaderBytes() int64 { + return t.MaxResponseHeaderBytesValue +} + +func (t Transport) WriteBufferSize() int { + return t.WriteBufferSizeValue +} + +func (t Transport) ReadBufferSize() int { + return t.ReadBufferSizeValue +} + +func (t Transport) TLSNextProto() map[string]func(authority string, c reqtls.Conn) http.RoundTripper { + return t.TLSNextProtoValue +} + +func (t Transport) SetTLSNextProto(m map[string]func(authority string, c reqtls.Conn) http.RoundTripper) { + t.TLSNextProtoValue = m +} diff --git a/internal/tls/conn.go b/internal/tls/conn.go new file mode 100644 index 00000000..2eaae7c1 --- /dev/null +++ b/internal/tls/conn.go @@ -0,0 +1,35 @@ +package tls + +import ( + "crypto/tls" + "net" +) + +// Conn is the recommended interface for the connection +// returned by the DailTLS function (Client.SetDialTLS, +// Transport.DialTLSContext), so that the TLS handshake negotiation +// can automatically decide whether to use HTTP2 or HTTP1 (ALPN). +// If this interface is not implemented, HTTP1 will be used by default. +type Conn interface { + net.Conn + // ConnectionState returns basic TLS details about the connection. + ConnectionState() tls.ConnectionState + // Handshake runs the client or server handshake + // protocol if it has not yet been run. + // + // Most uses of this package need not call Handshake explicitly: the + // first Read or Write will call it automatically. + // + // For control over canceling or setting a timeout on a handshake, use + // HandshakeContext or the Dialer's DialContext method instead. + Handshake() error +} + +// NetConnWrapper is the interface to get underlying connection, which is +// introduced in go1.18 for *tls.Conn. +type NetConnWrapper interface { + // NetConn returns the underlying connection that is wrapped by c. + // Note that writing to or reading from this connection directly will corrupt the + // TLS session. + NetConn() net.Conn +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go new file mode 100644 index 00000000..f15195d2 --- /dev/null +++ b/internal/transport/transport.go @@ -0,0 +1,100 @@ +package transport + +import ( + "context" + "crypto/tls" + "github.com/imroc/req/v3/internal/dump" + reqtls "github.com/imroc/req/v3/internal/tls" + "net" + "net/http" + "net/url" + "time" +) + +type Interface interface { + Proxy() func(*http.Request) (*url.URL, error) + Clone() Interface + Debugf() func(format string, v ...interface{}) + SetDebugf(func(format string, v ...interface{})) + DisableCompression() bool + TLSClientConfig() *tls.Config + SetTLSClientConfig(c *tls.Config) + TLSHandshakeTimeout() time.Duration + DialContext() func(ctx context.Context, network, addr string) (net.Conn, error) + DialTLSContext() func(ctx context.Context, network, addr string) (net.Conn, error) + RegisterProtocol(scheme string, rt http.RoundTripper) + DisableKeepAlives() bool + Dump() *dump.Dumper + + // MaxIdleConns controls the maximum number of idle (keep-alive) + // connections across all hosts. Zero means no limit. + MaxIdleConns() int + + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle + // (keep-alive) connections to keep per-host. If zero, + // defaultMaxIdleConnsPerHost is used. + MaxIdleConnsPerHost() int + + // MaxConnsPerHost optionally limits the total number of + // connections per host, including connections in the dialing, + // active, and idle states. On limit violation, dials will block. + // + // Zero means no limit. + MaxConnsPerHost() int + + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout() time.Duration + + // ResponseHeaderTimeout, if non-zero, specifies the amount of + // time to wait for a server's response headers after fully + // writing the request (including its body, if any). This + // time does not include the time to read the response body. + ResponseHeaderTimeout() time.Duration + + // ExpectContinueTimeout, if non-zero, specifies the amount of + // time to wait for a server's first response headers after fully + // writing the request headers if the request has an + // "Expect: 100-continue" header. Zero means no timeout and + // causes the body to be sent immediately, without + // waiting for the server to approve. + // This time does not include the time to send the request header. + ExpectContinueTimeout() time.Duration + + // ProxyConnectHeader optionally specifies headers to send to + // proxies during CONNECT requests. + // To set the header dynamically, see GetProxyConnectHeader. + ProxyConnectHeader() http.Header + + // GetProxyConnectHeader optionally specifies a func to return + // headers to send to proxyURL during a CONNECT request to the + // ip:port target. + // If it returns an error, the Transport's RoundTrip fails with + // that error. It can return (nil, nil) to not add headers. + // If GetProxyConnectHeader is non-nil, ProxyConnectHeader is + // ignored. + GetProxyConnectHeader() func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) + + // MaxResponseHeaderBytes specifies a limit on how many + // response bytes are allowed in the server's response + // header. + // + // Zero means to use a default limit. + MaxResponseHeaderBytes() int64 + + // WriteBufferSize specifies the size of the write buffer used + // when writing to the transport. + // If zero, a default (currently 4KB) is used. + WriteBufferSize() int + + // ReadBufferSize specifies the size of the read buffer used + // when reading from the transport. + // If zero, a default (currently 4KB) is used. + ReadBufferSize() int + + TLSNextProto() map[string]func(authority string, c reqtls.Conn) http.RoundTripper + + SetTLSNextProto(map[string]func(authority string, c reqtls.Conn) http.RoundTripper) +} diff --git a/logger_test.go b/logger_test.go index 2718392d..c15dd6f4 100644 --- a/logger_test.go +++ b/logger_test.go @@ -2,6 +2,7 @@ package req import ( "bytes" + "github.com/imroc/req/v3/internal/tests" "log" "testing" ) @@ -11,8 +12,8 @@ func TestLogger(t *testing.T) { l := NewLogger(buf, "", log.Ldate|log.Lmicroseconds) c := tc().SetLogger(l) c.SetProxyURL(":=\\<>ksfj&*&sf") - assertContains(t, buf.String(), "error", true) + tests.AssertContains(t, buf.String(), "error", true) buf.Reset() c.R().SetOutput(nil) - assertContains(t, buf.String(), "warn", true) + tests.AssertContains(t, buf.String(), "warn", true) } diff --git a/req_test.go b/req_test.go index 9687f311..a1cdccb7 100644 --- a/req_test.go +++ b/req_test.go @@ -67,65 +67,10 @@ func getTestServerURL() string { func getTestFileContent(t *testing.T, filename string) []byte { b, err := ioutil.ReadFile(tests.GetTestFilePath(filename)) - assertNoError(t, err) + tests.AssertNoError(t, err) return b } -func assertIsNil(t *testing.T, v interface{}) { - if !isNil(v) { - t.Errorf("[%v] was expected to be nil", v) - } -} - -func assertAllNotNil(t *testing.T, vv ...interface{}) { - for _, v := range vv { - assertNotNil(t, v) - } -} - -func assertNotNil(t *testing.T, v interface{}) { - if isNil(v) { - t.Fatalf("[%v] was expected to be non-nil", v) - } -} - -func assertEqual(t *testing.T, e, g interface{}) { - if !equal(e, g) { - t.Errorf("Expected [%+v], got [%+v]", e, g) - } - return -} - -func assertNoError(t *testing.T, err error) { - if err != nil { - t.Errorf("Error occurred [%v]", err) - } -} - -func assertErrorContains(t *testing.T, err error, s string) { - if err == nil { - t.Error("err is nil") - return - } - if !strings.Contains(err.Error(), s) { - t.Errorf("%q is not included in error %q", s, err.Error()) - } -} - -func assertContains(t *testing.T, s, substr string, shouldContain bool) { - s = strings.ToLower(s) - isContain := strings.Contains(s, substr) - if shouldContain { - if !isContain { - t.Errorf("%q is not included in %s", substr, s) - } - } else { - if isContain { - t.Errorf("%q is included in %s", substr, s) - } - } -} - func assertClone(t *testing.T, e, g interface{}) { ev := reflect.ValueOf(e).Elem() gv := reflect.ValueOf(g).Elem() @@ -157,22 +102,6 @@ func assertClone(t *testing.T, e, g interface{}) { } } -func equal(expected, got interface{}) bool { - return reflect.DeepEqual(expected, got) -} - -func isNil(v interface{}) bool { - if v == nil { - return true - } - rv := reflect.ValueOf(v) - kind := rv.Kind() - if kind >= reflect.Chan && kind <= reflect.Slice && rv.IsNil() { - return true - } - return false -} - // Echo is used in "/echo" API. type Echo struct { Header http.Header `json:"header" xml:"header"` @@ -392,28 +321,28 @@ func handleGet(w http.ResponseWriter, r *http.Request) { } func assertStatus(t *testing.T, resp *Response, err error, statusCode int, status string) { - assertNoError(t, err) - assertNotNil(t, resp) - assertNotNil(t, resp.Body) - assertEqual(t, statusCode, resp.StatusCode) - assertEqual(t, status, resp.Status) + tests.AssertNoError(t, err) + tests.AssertNotNil(t, resp) + tests.AssertNotNil(t, resp.Body) + tests.AssertEqual(t, statusCode, resp.StatusCode) + tests.AssertEqual(t, status, resp.Status) } func assertSuccess(t *testing.T, resp *Response, err error) { - assertNoError(t, err) - assertNotNil(t, resp.Response) - assertNotNil(t, resp.Response.Body) - assertEqual(t, http.StatusOK, resp.StatusCode) - assertEqual(t, "200 OK", resp.Status) + tests.AssertNoError(t, err) + tests.AssertNotNil(t, resp.Response) + tests.AssertNotNil(t, resp.Response.Body) + tests.AssertEqual(t, http.StatusOK, resp.StatusCode) + tests.AssertEqual(t, "200 OK", resp.Status) if !resp.IsSuccess() { t.Error("Response.IsSuccess should return true") } } func assertIsError(t *testing.T, resp *Response, err error) { - assertNoError(t, err) - assertNotNil(t, resp) - assertNotNil(t, resp.Body) + tests.AssertNoError(t, err) + tests.AssertNotNil(t, resp) + tests.AssertNotNil(t, resp.Body) if !resp.IsError() { t.Error("Response.IsError should return true") } diff --git a/request.go b/request.go index 04c8e079..b199543e 100644 --- a/request.go +++ b/request.go @@ -5,6 +5,7 @@ import ( "context" "errors" "github.com/hashicorp/go-multierror" + "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/util" "io" "io/ioutil" @@ -743,7 +744,7 @@ func (r *Request) SetDumpOptions(opt *DumpOptions) *Request { // EnableDump enables dump, including all content for the request and response by default. func (r *Request) EnableDump() *Request { - return r.SetContext(context.WithValue(r.Context(), dumperKey, newDumper(r.getDumpOptions()))) + return r.SetContext(context.WithValue(r.Context(), dump.DumperKey, newDumper(r.getDumpOptions()))) } // EnableDumpWithoutBody enables dump only header for the request and response. diff --git a/request_test.go b/request_test.go index beac1381..a126a54c 100644 --- a/request_test.go +++ b/request_test.go @@ -177,10 +177,10 @@ func testEnableDump(t *testing.T, fn func(r *Request) (de dumpExpected)) { resp, err := r.SetBody(`test body`).Post("/") assertSuccess(t, resp, err) dump := resp.Dump() - assertContains(t, dump, "user-agent", de.ReqHeader) - assertContains(t, dump, "test body", de.ReqBody) - assertContains(t, dump, "date", de.RespHeader) - assertContains(t, dump, "testpost: text response", de.RespBody) + tests.AssertContains(t, dump, "user-agent", de.ReqHeader) + tests.AssertContains(t, dump, "test body", de.ReqBody) + tests.AssertContains(t, dump, "date", de.RespHeader) + tests.AssertContains(t, dump, "testpost: text response", de.RespBody) } c := tc() testDump(c) @@ -256,14 +256,14 @@ func TestEnableDumpTo(t *testing.T) { buff := new(bytes.Buffer) resp, err := tc().R().EnableDumpTo(buff).Get("/") assertSuccess(t, resp, err) - assertEqual(t, true, buff.Len() > 0) + tests.AssertEqual(t, true, buff.Len() > 0) } func TestEnableDumpToFIle(t *testing.T) { tmpFile := "tmp_dumpfile_req" resp, err := tc().R().EnableDumpToFile(tests.GetTestFilePath(tmpFile)).Get("/") assertSuccess(t, resp, err) - assertEqual(t, true, len(getTestFileContent(t, tmpFile)) > 0) + tests.AssertEqual(t, true, len(getTestFileContent(t, tmpFile)) > 0) os.Remove(tests.GetTestFilePath(tmpFile)) } @@ -281,14 +281,14 @@ func TestSetBodyMarshal(t *testing.T) { assertUsernameJson := func(body []byte) { var user User err := json.Unmarshal(body, &user) - assertNoError(t, err) - assertEqual(t, username, user.Username) + tests.AssertNoError(t, err) + tests.AssertEqual(t, username, user.Username) } assertUsernameXml := func(body []byte) { var user User err := xml.Unmarshal(body, &user) - assertNoError(t, err) - assertEqual(t, username, user.Username) + tests.AssertNoError(t, err) + tests.AssertEqual(t, username, user.Username) } testCases := []struct { @@ -443,8 +443,8 @@ func TestSetBody(t *testing.T) { var e Echo resp, err := r.SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - assertEqual(t, tc.ContentType, e.Header.Get(hdrContentTypeKey)) - assertEqual(t, body, e.Body) + tests.AssertEqual(t, tc.ContentType, e.Header.Get(hdrContentTypeKey)) + tests.AssertEqual(t, body, e.Body) } } @@ -461,7 +461,7 @@ func TestCookie(t *testing.T) { }, ).SetResult(&headers).Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "cookie1=value1; cookie2=value2", headers.Get("Cookie")) + tests.AssertEqual(t, "cookie1=value1; cookie2=value2", headers.Get("Cookie")) } func TestSetBasicAuth(t *testing.T) { @@ -471,7 +471,7 @@ func TestSetBasicAuth(t *testing.T) { SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "Basic aW1yb2M6MTIzNDU2", headers.Get("Authorization")) + tests.AssertEqual(t, "Basic aW1yb2M6MTIzNDU2", headers.Get("Authorization")) } func TestSetBearerAuthToken(t *testing.T) { @@ -482,7 +482,7 @@ func TestSetBearerAuthToken(t *testing.T) { SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "Bearer "+token, headers.Get("Authorization")) + tests.AssertEqual(t, "Bearer "+token, headers.Get("Authorization")) } func TestHeader(t *testing.T) { @@ -494,7 +494,7 @@ func testHeader(t *testing.T, c *Client) { customUserAgent := "My Custom User Agent" resp, err := c.R().SetHeader(hdrUserAgentKey, customUserAgent).Get("/user-agent") assertSuccess(t, resp, err) - assertEqual(t, customUserAgent, resp.String()) + tests.AssertEqual(t, customUserAgent, resp.String()) // Set custom header headers := make(http.Header) @@ -506,9 +506,9 @@ func testHeader(t *testing.T, c *Client) { }).SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "value1", headers.Get("header1")) - assertEqual(t, "value2", headers.Get("header2")) - assertEqual(t, "value3", headers.Get("header3")) + tests.AssertEqual(t, "value1", headers.Get("header1")) + tests.AssertEqual(t, "value2", headers.Get("header2")) + tests.AssertEqual(t, "value3", headers.Get("header3")) } func TestSetHeaderNonCanonical(t *testing.T) { @@ -520,21 +520,21 @@ func TestSetHeaderNonCanonical(t *testing.T) { key: "test", }).Get("/header") assertSuccess(t, resp, err) - assertEqual(t, true, strings.Contains(resp.Dump(), key)) + tests.AssertEqual(t, true, strings.Contains(resp.Dump(), key)) resp, err = c.R(). EnableDumpWithoutResponse(). SetHeaderNonCanonical(key, "test"). Get("/header") assertSuccess(t, resp, err) - assertEqual(t, true, strings.Contains(resp.Dump(), key)) + tests.AssertEqual(t, true, strings.Contains(resp.Dump(), key)) c.SetCommonHeaderNonCanonical(key, "test") resp, err = c.R(). EnableDumpWithoutResponse(). Get("/header") assertSuccess(t, resp, err) - assertEqual(t, true, strings.Contains(resp.Dump(), key)) + tests.AssertEqual(t, true, strings.Contains(resp.Dump(), key)) } func TestQueryParam(t *testing.T) { @@ -558,14 +558,14 @@ func testQueryParam(t *testing.T, c *Client) { SetQueryParam("key3", "value3"). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) + tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryString resp, err = c.R(). SetQueryString("key1=value1&key2=value2&key3=value3"). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) + tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryParams resp, err = c.R(). @@ -576,7 +576,7 @@ func testQueryParam(t *testing.T, c *Client) { }). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) + tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=client&key5=client&key5=extra", resp.String()) // SetQueryParam & SetQueryParams & SetQueryString resp, err = c.R(). @@ -588,7 +588,7 @@ func testQueryParam(t *testing.T, c *Client) { SetQueryString("key4=value4&key5=value5"). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value1&key2=value2&key3=value3&key4=value4&key5=value5", resp.String()) + tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=value4&key5=value5", resp.String()) // Set same param to override resp, err = c.R(). @@ -603,7 +603,7 @@ func testQueryParam(t *testing.T, c *Client) { SetQueryParam("key4", "value44"). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value11&key2=value22&key3=value3&key4=value44&key5=value5", resp.String()) + tests.AssertEqual(t, "key1=value11&key2=value22&key3=value3&key4=value44&key5=value5", resp.String()) // Add same param without override resp, err = c.R(). @@ -618,7 +618,7 @@ func testQueryParam(t *testing.T, c *Client) { AddQueryParam("key4", "value44"). Get("/query-parameter") assertSuccess(t, resp, err) - assertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5", resp.String()) + tests.AssertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5", resp.String()) } func TestPathParam(t *testing.T) { @@ -632,7 +632,7 @@ func testPathParam(t *testing.T, c *Client) { SetPathParam("username", username). Get("/user/{username}/profile") assertSuccess(t, resp, err) - assertEqual(t, fmt.Sprintf("%s's profile", username), resp.String()) + tests.AssertEqual(t, fmt.Sprintf("%s's profile", username), resp.String()) } func TestSuccess(t *testing.T) { @@ -646,7 +646,7 @@ func testSuccess(t *testing.T, c *Client) { SetResult(&userInfo). Get("/search") assertSuccess(t, resp, err) - assertEqual(t, "roc@imroc.cc", userInfo.Email) + tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) userInfo = UserInfo{} resp, err = c.R(). @@ -655,7 +655,7 @@ func testSuccess(t *testing.T, c *Client) { SetResult(&userInfo).EnableDump(). Get("/search") assertSuccess(t, resp, err) - assertEqual(t, "roc@imroc.cc", userInfo.Email) + tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) } func TestError(t *testing.T) { @@ -669,7 +669,7 @@ func testError(t *testing.T, c *Client) { SetError(&errMsg). Get("/search") assertIsError(t, resp, err) - assertEqual(t, 10000, errMsg.ErrorCode) + tests.AssertEqual(t, 10000, errMsg.ErrorCode) errMsg = ErrorMessage{} resp, err = c.R(). @@ -677,7 +677,7 @@ func testError(t *testing.T, c *Client) { SetError(&errMsg). Get("/search") assertIsError(t, resp, err) - assertEqual(t, 10001, errMsg.ErrorCode) + tests.AssertEqual(t, 10001, errMsg.ErrorCode) errMsg = ErrorMessage{} resp, err = c.R(). @@ -686,7 +686,7 @@ func testError(t *testing.T, c *Client) { SetError(&errMsg). Get("/search") assertIsError(t, resp, err) - assertEqual(t, 10001, errMsg.ErrorCode) + tests.AssertEqual(t, 10001, errMsg.ErrorCode) c.SetCommonError(&errMsg) resp, err = c.R(). @@ -694,8 +694,8 @@ func testError(t *testing.T, c *Client) { Get("/search") assertIsError(t, resp, err) em, ok := resp.Error().(*ErrorMessage) - assertEqual(t, true, ok) - assertEqual(t, 10000, em.ErrorCode) + tests.AssertEqual(t, true, ok) + tests.AssertEqual(t, 10000, em.ErrorCode) } func TestForm(t *testing.T) { @@ -712,7 +712,7 @@ func testForm(t *testing.T, c *Client) { SetResult(&userInfo). Post("/search") assertSuccess(t, resp, err) - assertEqual(t, "roc@imroc.cc", userInfo.Email) + tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) v := make(url.Values) v.Add("username", "imroc") @@ -722,7 +722,7 @@ func testForm(t *testing.T, c *Client) { SetResult(&userInfo). Post("/search") assertSuccess(t, resp, err) - assertEqual(t, "roc@imroc.cc", userInfo.Email) + tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) } func TestHostHeaderOverride(t *testing.T) { @@ -732,33 +732,33 @@ func TestHostHeaderOverride(t *testing.T) { func testHostHeaderOverride(t *testing.T, c *Client) { resp, err := c.R().SetHeader("Host", "testhostname").Get("/host-header") assertSuccess(t, resp, err) - assertEqual(t, "testhostname", resp.String()) + tests.AssertEqual(t, "testhostname", resp.String()) } func assertTraceInfo(t *testing.T, resp *Response, enable bool) { ti := resp.TraceInfo() - assertEqual(t, true, resp.TotalTime() > 0) + tests.AssertEqual(t, true, resp.TotalTime() > 0) if !enable { - assertEqual(t, false, ti.TotalTime > 0) - assertIsNil(t, ti.RemoteAddr) - assertContains(t, ti.String(), "not enabled", true) - assertContains(t, ti.Blame(), "not enabled", true) + tests.AssertEqual(t, false, ti.TotalTime > 0) + tests.AssertIsNil(t, ti.RemoteAddr) + tests.AssertContains(t, ti.String(), "not enabled", true) + tests.AssertContains(t, ti.Blame(), "not enabled", true) return } - assertContains(t, ti.String(), "not enabled", false) - assertContains(t, ti.Blame(), "not enabled", false) - assertEqual(t, true, ti.TotalTime > 0) - assertEqual(t, true, ti.ConnectTime > 0) - assertEqual(t, true, ti.FirstResponseTime > 0) - assertEqual(t, true, ti.ResponseTime > 0) - assertNotNil(t, ti.RemoteAddr) + tests.AssertContains(t, ti.String(), "not enabled", false) + tests.AssertContains(t, ti.Blame(), "not enabled", false) + tests.AssertEqual(t, true, ti.TotalTime > 0) + tests.AssertEqual(t, true, ti.ConnectTime > 0) + tests.AssertEqual(t, true, ti.FirstResponseTime > 0) + tests.AssertEqual(t, true, ti.ResponseTime > 0) + tests.AssertNotNil(t, ti.RemoteAddr) if ti.IsConnReused { - assertEqual(t, true, ti.TCPConnectTime == 0) - assertEqual(t, true, ti.TLSHandshakeTime == 0) + tests.AssertEqual(t, true, ti.TCPConnectTime == 0) + tests.AssertEqual(t, true, ti.TLSHandshakeTime == 0) } else { - assertEqual(t, true, ti.TCPConnectTime > 0) - assertEqual(t, true, ti.TLSHandshakeTime > 0) + tests.AssertEqual(t, true, ti.TCPConnectTime > 0) + tests.AssertEqual(t, true, ti.TLSHandshakeTime > 0) } } @@ -798,18 +798,18 @@ func TestTraceOnTimeout(t *testing.T) { c.EnableTraceAll().SetTimeout(100 * time.Millisecond) resp, err := c.R().Get("http://req-nowhere.local") - assertNotNil(t, err) - assertNotNil(t, resp) + tests.AssertNotNil(t, err) + tests.AssertNotNil(t, resp) ti := resp.TraceInfo() - assertEqual(t, true, ti.DNSLookupTime >= 0) - assertEqual(t, true, ti.ConnectTime == 0) - assertEqual(t, true, ti.TLSHandshakeTime == 0) - assertEqual(t, true, ti.TCPConnectTime == 0) - assertEqual(t, true, ti.FirstResponseTime == 0) - assertEqual(t, true, ti.ResponseTime == 0) - assertEqual(t, true, ti.TotalTime > 0) - assertEqual(t, true, ti.TotalTime == resp.TotalTime()) + tests.AssertEqual(t, true, ti.DNSLookupTime >= 0) + tests.AssertEqual(t, true, ti.ConnectTime == 0) + tests.AssertEqual(t, true, ti.TLSHandshakeTime == 0) + tests.AssertEqual(t, true, ti.TCPConnectTime == 0) + tests.AssertEqual(t, true, ti.FirstResponseTime == 0) + tests.AssertEqual(t, true, ti.ResponseTime == 0) + tests.AssertEqual(t, true, ti.TotalTime > 0) + tests.AssertEqual(t, true, ti.TotalTime == resp.TotalTime()) }) } @@ -817,32 +817,32 @@ func TestAutoDetectRequestContentType(t *testing.T) { c := tc() resp, err := c.R().SetBody(getTestFileContent(t, "sample-image.png")).Post("/content-type") assertSuccess(t, resp, err) - assertEqual(t, "image/png", resp.String()) + tests.AssertEqual(t, "image/png", resp.String()) resp, err = c.R().SetBodyJsonString(`{"msg": "test"}`).Post("/content-type") assertSuccess(t, resp, err) - assertEqual(t, jsonContentType, resp.String()) + tests.AssertEqual(t, jsonContentType, resp.String()) resp, err = c.R().SetContentType(xmlContentType).SetBody(`{"msg": "test"}`).Post("/content-type") assertSuccess(t, resp, err) - assertEqual(t, xmlContentType, resp.String()) + tests.AssertEqual(t, xmlContentType, resp.String()) resp, err = c.R().SetBody(`

hello

`).Post("/content-type") assertSuccess(t, resp, err) - assertEqual(t, "text/html; charset=utf-8", resp.String()) + tests.AssertEqual(t, "text/html; charset=utf-8", resp.String()) resp, err = c.R().SetBody(`hello world`).Post("/content-type") assertSuccess(t, resp, err) - assertEqual(t, plainTextContentType, resp.String()) + tests.AssertEqual(t, plainTextContentType, resp.String()) } func TestSetFileUploadCheck(t *testing.T) { c := tc() resp, err := c.R().SetFileUpload(FileUpload{}).Post("/multipart") - assertErrorContains(t, err, "missing param name") - assertErrorContains(t, err, "missing filename") - assertErrorContains(t, err, "missing file content") - assertEqual(t, 0, len(resp.Request.uploadFiles)) + tests.AssertErrorContains(t, err, "missing param name") + tests.AssertErrorContains(t, err, "missing filename") + tests.AssertErrorContains(t, err, "missing file content") + tests.AssertEqual(t, 0, len(resp.Request.uploadFiles)) } func TestUploadMultipart(t *testing.T) { @@ -857,23 +857,23 @@ func TestUploadMultipart(t *testing.T) { SetResult(&m). Post("/multipart") assertSuccess(t, resp, err) - assertContains(t, resp.String(), "sample-image.png", true) - assertContains(t, resp.String(), "sample-file.txt", true) - assertContains(t, resp.String(), "value1", true) - assertContains(t, resp.String(), "value2", true) + tests.AssertContains(t, resp.String(), "sample-image.png", true) + tests.AssertContains(t, resp.String(), "sample-file.txt", true) + tests.AssertContains(t, resp.String(), "value1", true) + tests.AssertContains(t, resp.String(), "value2", true) } func TestFixPragmaCache(t *testing.T) { resp, err := tc().EnableForceHTTP1().R().Get("/pragma") assertSuccess(t, resp, err) - assertEqual(t, "no-cache", resp.Header.Get("Cache-Control")) + tests.AssertEqual(t, "no-cache", resp.Header.Get("Cache-Control")) } func TestSetFileBytes(t *testing.T) { resp := uploadTextFile(t, func(r *Request) { r.SetFileBytes("file", "file.txt", []byte("test")) }) - assertEqual(t, "test", resp.String()) + tests.AssertEqual(t, "test", resp.String()) } func TestSetFileReader(t *testing.T) { @@ -881,13 +881,13 @@ func TestSetFileReader(t *testing.T) { resp := uploadTextFile(t, func(r *Request) { r.SetFileReader("file", "file.txt", buff) }) - assertEqual(t, "test", resp.String()) + tests.AssertEqual(t, "test", resp.String()) buff = bytes.NewBufferString("test") resp = uploadTextFile(t, func(r *Request) { r.SetFileReader("file", "file.txt", ioutil.NopCloser(buff)) }) - assertEqual(t, "test", resp.String()) + tests.AssertEqual(t, "test", resp.String()) } func TestSetFileWithRetry(t *testing.T) { @@ -903,7 +903,7 @@ func TestSetFileWithRetry(t *testing.T) { SetQueryParam("attempt", "0"). Post("/file-text") assertSuccess(t, resp, err) - assertEqual(t, 2, resp.Request.RetryAttempt) + tests.AssertEqual(t, 2, resp.Request.RetryAttempt) } func TestSetFile(t *testing.T) { @@ -911,10 +911,10 @@ func TestSetFile(t *testing.T) { resp := uploadTextFile(t, func(r *Request) { r.SetFile("file", tests.GetTestFilePath(filename)) }) - assertEqual(t, getTestFileContent(t, filename), resp.Bytes()) + tests.AssertEqual(t, getTestFileContent(t, filename), resp.Bytes()) resp, err := tc().SetLogger(nil).R().SetFile("file", "file-not-exists.txt").Post("/file-text") - assertErrorContains(t, err, "no such file") + tests.AssertErrorContains(t, err, "no such file") } func TestSetFiles(t *testing.T) { @@ -924,7 +924,7 @@ func TestSetFiles(t *testing.T) { "file": tests.GetTestFilePath(filename), }) }) - assertEqual(t, getTestFileContent(t, filename), resp.Bytes()) + tests.AssertEqual(t, getTestFileContent(t, filename), resp.Bytes()) } func uploadTextFile(t *testing.T, setReq func(r *Request)) *Response { @@ -966,7 +966,7 @@ func TestUploadCallback(t *testing.T) { }) resp, err := r.Post("/raw-upload") assertSuccess(t, resp, err) - assertEqual(t, true, n > 1) + tests.AssertEqual(t, true, n > 1) } func TestDownloadCallback(t *testing.T) { @@ -977,5 +977,5 @@ func TestDownloadCallback(t *testing.T) { n++ }).Get("/download") assertSuccess(t, resp, err) - assertEqual(t, true, n > 0) + tests.AssertEqual(t, true, n > 0) } diff --git a/request_wrapper_test.go b/request_wrapper_test.go index fe080646..d0204ded 100644 --- a/request_wrapper_test.go +++ b/request_wrapper_test.go @@ -3,6 +3,7 @@ package req import ( "bytes" "context" + "github.com/imroc/req/v3/internal/tests" "net/http" "testing" "time" @@ -13,7 +14,7 @@ func init() { } func TestGlobalWrapperForRequestSettings(t *testing.T) { - assertAllNotNil(t, + tests.AssertAllNotNil(t, SetFiles(map[string]string{"test": "req.go"}), SetFile("test", "req.go"), SetFileReader("test", "test.txt", bytes.NewBufferString("test")), @@ -112,8 +113,8 @@ func testGlobalWrapperMustSendMethods(t *testing.T) { url := getTestServerURL() + "/" for _, tc := range testCases { resp := tc.SendReq(url) - assertNotNil(t, resp.Response) - assertEqual(t, tc.ExpectMethod, resp.Header.Get("Method")) + tests.AssertNotNil(t, resp.Response) + tests.AssertEqual(t, tc.ExpectMethod, resp.Header.Get("Method")) } } @@ -155,6 +156,6 @@ func testGlobalWrapperSendMethods(t *testing.T) { for _, tc := range testCases { resp, err := tc.SendReq(url) assertSuccess(t, resp, err) - assertEqual(t, tc.ExpectMethod, resp.Header.Get("Method")) + tests.AssertEqual(t, tc.ExpectMethod, resp.Header.Get("Method")) } } diff --git a/response_test.go b/response_test.go index 1baaa93e..d87cd211 100644 --- a/response_test.go +++ b/response_test.go @@ -1,6 +1,7 @@ package req import ( + "github.com/imroc/req/v3/internal/tests" "net/http" "testing" ) @@ -18,8 +19,8 @@ func TestUnmarshalJson(t *testing.T) { resp, err := tc().R().Get("/json") assertSuccess(t, resp, err) err = resp.UnmarshalJson(&user) - assertNoError(t, err) - assertEqual(t, "roc", user.Name) + tests.AssertNoError(t, err) + tests.AssertEqual(t, "roc", user.Name) } func TestUnmarshalXml(t *testing.T) { @@ -27,8 +28,8 @@ func TestUnmarshalXml(t *testing.T) { resp, err := tc().R().Get("/xml") assertSuccess(t, resp, err) err = resp.UnmarshalXml(&user) - assertNoError(t, err) - assertEqual(t, "roc", user.Name) + tests.AssertNoError(t, err) + tests.AssertEqual(t, "roc", user.Name) } func TestUnmarshal(t *testing.T) { @@ -36,8 +37,8 @@ func TestUnmarshal(t *testing.T) { resp, err := tc().R().Get("/xml") assertSuccess(t, resp, err) err = resp.Unmarshal(&user) - assertNoError(t, err) - assertEqual(t, "roc", user.Name) + tests.AssertNoError(t, err) + tests.AssertEqual(t, "roc", user.Name) } func TestResponseResult(t *testing.T) { @@ -46,10 +47,10 @@ func TestResponseResult(t *testing.T) { if !ok { t.Fatal("Response.Result() should return *User") } - assertEqual(t, "roc", user.Name) + tests.AssertEqual(t, "roc", user.Name) - assertEqual(t, true, resp.TotalTime() > 0) - assertEqual(t, false, resp.ReceivedAt().IsZero()) + tests.AssertEqual(t, true, resp.TotalTime() > 0) + tests.AssertEqual(t, false, resp.ReceivedAt().IsZero()) } func TestResponseError(t *testing.T) { @@ -58,14 +59,14 @@ func TestResponseError(t *testing.T) { if !ok { t.Fatal("Response.Error() should return *Message") } - assertEqual(t, "not allowed", msg.Message) + tests.AssertEqual(t, "not allowed", msg.Message) } func TestResponseWrap(t *testing.T) { resp, err := tc().R().Get("/json") assertSuccess(t, resp, err) - assertEqual(t, true, resp.GetStatusCode() == http.StatusOK) - assertEqual(t, true, resp.GetStatus() == "200 OK") - assertEqual(t, true, resp.GetHeader(hdrContentTypeKey) == jsonContentType) - assertEqual(t, true, len(resp.GetHeaderValues(hdrContentTypeKey)) == 1) + tests.AssertEqual(t, true, resp.GetStatusCode() == http.StatusOK) + tests.AssertEqual(t, true, resp.GetStatus() == "200 OK") + tests.AssertEqual(t, true, resp.GetHeader(hdrContentTypeKey) == jsonContentType) + tests.AssertEqual(t, true, len(resp.GetHeaderValues(hdrContentTypeKey)) == 1) } diff --git a/retry_test.go b/retry_test.go index 137423ec..2bffbf13 100644 --- a/retry_test.go +++ b/retry_test.go @@ -2,6 +2,7 @@ package req import ( "bytes" + "github.com/imroc/req/v3/internal/tests" "io/ioutil" "math" "net/http" @@ -27,9 +28,9 @@ func testRetry(t *testing.T, setFunc func(r *Request)) { }) setFunc(r) resp, err := r.Get("/too-many") - assertNoError(t, err) - assertEqual(t, 3, resp.Request.RetryAttempt) - assertEqual(t, 3, attempt) + tests.AssertNoError(t, err) + tests.AssertEqual(t, 3, resp.Request.RetryAttempt) + tests.AssertEqual(t, 3, attempt) } func TestRetryInterval(t *testing.T) { @@ -54,7 +55,7 @@ func TestAddRetryHook(t *testing.T) { test = "test2" }) }) - assertEqual(t, "test2", test) + tests.AssertEqual(t, "test2", test) } func TestRetryOverride(t *testing.T) { @@ -73,9 +74,9 @@ func TestRetryOverride(t *testing.T) { }).SetRetryCondition(func(resp *Response, err error) bool { return err != nil || resp.StatusCode == http.StatusTooManyRequests }).Get("/too-many") - assertNoError(t, err) - assertEqual(t, "test1", test) - assertEqual(t, 2, resp.Request.RetryAttempt) + tests.AssertNoError(t, err) + tests.AssertEqual(t, "test1", test) + tests.AssertEqual(t, 2, resp.Request.RetryAttempt) } func TestAddRetryCondition(t *testing.T) { @@ -91,9 +92,9 @@ func TestAddRetryCondition(t *testing.T) { SetRetryHook(func(resp *Response, err error) { attempt++ }).Get("/too-many") - assertNoError(t, err) - assertEqual(t, 0, attempt) - assertEqual(t, 0, resp.Request.RetryAttempt) + tests.AssertNoError(t, err) + tests.AssertEqual(t, 0, attempt) + tests.AssertEqual(t, 0, resp.Request.RetryAttempt) attempt = 0 resp, err = tc(). @@ -107,9 +108,9 @@ func TestAddRetryCondition(t *testing.T) { SetCommonRetryHook(func(resp *Response, err error) { attempt++ }).R().Get("/too-many") - assertNoError(t, err) - assertEqual(t, 0, attempt) - assertEqual(t, 0, resp.Request.RetryAttempt) + tests.AssertNoError(t, err) + tests.AssertEqual(t, 0, attempt) + tests.AssertEqual(t, 0, resp.Request.RetryAttempt) } @@ -118,13 +119,13 @@ func TestRetryWithUnreplayableBody(t *testing.T) { SetRetryCount(1). SetBody(bytes.NewBufferString("test")). Post("/") - assertEqual(t, errRetryableWithUnReplayableBody, err) + tests.AssertEqual(t, errRetryableWithUnReplayableBody, err) _, err = tc().R(). SetRetryCount(1). SetBody(ioutil.NopCloser(bytes.NewBufferString("test"))). Post("/") - assertEqual(t, errRetryableWithUnReplayableBody, err) + tests.AssertEqual(t, errRetryableWithUnReplayableBody, err) } func TestRetryWithSetResult(t *testing.T) { @@ -137,7 +138,7 @@ func TestRetryWithSetResult(t *testing.T) { SetResult(&headers). Get("/header") assertSuccess(t, resp, err) - assertEqual(t, "test=test", headers.Get("Cookie")) + tests.AssertEqual(t, "test=test", headers.Get("Cookie")) } func TestRetryWithModify(t *testing.T) { @@ -156,7 +157,7 @@ func TestRetryWithModify(t *testing.T) { SetBearerAuthToken(tokens[tokenIndex]). Get("/protected") assertSuccess(t, resp, err) - assertEqual(t, 2, resp.Request.RetryAttempt) + tests.AssertEqual(t, 2, resp.Request.RetryAttempt) } func TestRetryFalse(t *testing.T) { @@ -165,7 +166,7 @@ func TestRetryFalse(t *testing.T) { SetRetryCondition(func(resp *Response, err error) bool { return false }).Get("https://non-exists-host.com.cn") - assertNotNil(t, err) - assertIsNil(t, resp.Response) - assertEqual(t, 0, resp.Request.RetryAttempt) + tests.AssertNotNil(t, err) + tests.AssertIsNil(t, resp.Response) + tests.AssertEqual(t, 0, resp.Request.RetryAttempt) } diff --git a/textproto_reader.go b/textproto_reader.go index 0254ee86..15581e40 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "fmt" + "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/util" "net/textproto" "strings" @@ -32,13 +33,13 @@ type textprotoReader struct { // To avoid denial of service attacks, the provided bufio.Reader // should be reading from an io.LimitReader or similar textprotoReader to bound // the size of responses. -func newTextprotoReader(r *bufio.Reader, dumps []*dumper) *textprotoReader { +func newTextprotoReader(r *bufio.Reader, dumps []*dump.Dumper) *textprotoReader { commonHeaderOnce.Do(initCommonHeader) t := &textprotoReader{R: r} if len(dumps) > 0 { - dd := []*dumper{} + dd := []*dump.Dumper{} for _, dump := range dumps { - if dump.ResponseHeader { + if dump.ResponseHeader() { dd = append(dd, dump) } } @@ -57,7 +58,7 @@ func newTextprotoReader(r *bufio.Reader, dumps []*dumper) *textprotoReader { } err = nil for _, dump := range dumps { - dump.dump(line) + dump.Dump(line) } if line[len(line)-1] == '\n' { drop := 1 diff --git a/transfer.go b/transfer.go index 05b8b3f7..8de2b9f1 100644 --- a/transfer.go +++ b/transfer.go @@ -11,6 +11,7 @@ import ( "fmt" "github.com/imroc/req/v3/internal" "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/dump" "io" "io/ioutil" "net/http" @@ -304,7 +305,7 @@ func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) } // always closes t.BodyCloser -func (t *transferWriter) writeBody(w io.Writer, dumps []*dumper) (err error) { +func (t *transferWriter) writeBody(w io.Writer, dumps []*dump.Dumper) (err error) { var ncopy int64 closed := false defer func() { @@ -318,7 +319,7 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dumper) (err error) { rw := w // raw writer for _, dump := range dumps { - if dump.RequestBody { + if dump.RequestBody() { w = dump.WrapWriter(w) } } @@ -335,7 +336,7 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dumper) (err error) { } cw := internal.NewChunkedWriter(rw) for _, dump := range dumps { - if dump.RequestBody { + if dump.RequestBody() { cw = dump.WrapWriteCloser(cw) } } @@ -362,8 +363,8 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dumper) (err error) { return err } for _, dump := range dumps { - if dump.RequestBody { - dump.dump([]byte("\r\n")) + if dump.RequestBody() { + dump.Dump([]byte("\r\n")) } } } diff --git a/transport.go b/transport.go index 24d60420..78665ccc 100644 --- a/transport.go +++ b/transport.go @@ -18,7 +18,11 @@ import ( "errors" "fmt" "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/common" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/http2" "github.com/imroc/req/v3/internal/socks" + reqtls "github.com/imroc/req/v3/internal/tls" "github.com/imroc/req/v3/internal/util" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" @@ -225,7 +229,7 @@ type Transport struct { // must return a http.RoundTripper that then handles the request. // If TLSNextProto is not nil, HTTP/2 support is not enabled // automatically. - TLSNextProto map[string]func(authority string, c TLSConn) http.RoundTripper + TLSNextProto map[string]func(authority string, c reqtls.Conn) http.RoundTripper // ProxyConnectHeader optionally specifies headers to send to // proxies during CONNECT requests. @@ -258,7 +262,7 @@ type Transport struct { // If zero, a default (currently 4KB) is used. ReadBufferSize int - t2 *http2Transport // non-nil if http2 wired up + t2 *http2.Transport // non-nil if http2 wired up // ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero // Dial, DialTLS, or DialContext func or TLSClientConfig is provided. @@ -269,7 +273,7 @@ type Transport struct { *ResponseOptions - dump *dumper + dump *dump.Dumper // Debugf is the optional debug function. Debugf func(format string, v ...interface{}) @@ -293,17 +297,17 @@ func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFu switch b := res.Body.(type) { case *gzipReader: b.body.body = wrap(b.body.body) - case *http2gzipReader: - b.body = wrap(b.body) + case *http2.GzipReader: + b.Body = wrap(b.Body) default: res.Body = wrap(res.Body) } } func (t *Transport) dumpResponseBody(res *http.Response, req *http.Request) { - dumps := getDumpers(req.Context(), t.dump) + dumps := dump.GetDumpers(req.Context(), t.dump) for _, dump := range dumps { - if dump.ResponseBody { + if dump.ResponseBody() { res.Body = dump.WrapReadCloser(res.Body) } } @@ -411,7 +415,7 @@ func (t *Transport) Clone() *Transport { t2.TLSClientConfig = t.TLSClientConfig.Clone() } if t.TLSNextProto != nil { - npm := map[string]func(authority string, c TLSConn) http.RoundTripper{} + npm := map[string]func(authority string, c reqtls.Conn) http.RoundTripper{} for k, v := range t.TLSNextProto { npm[k] = v } @@ -591,7 +595,7 @@ func (t *Transport) roundTrip(req *http.Request) (*http.Response, error) { } // Failed. Clean up and determine whether to retry. - if http2isNoCachedConnError(err) { + if http2.IsNoCachedConnError(err) { if t.removeIdleConn(pconn) { t.decConnsPerHost(pconn.cacheKey) } @@ -674,7 +678,7 @@ func rewindBody(req *http.Request) (rewound *http.Request, err error) { // HTTP request on a new connection. The non-nil input error is the // error from roundTrip. func (pc *persistConn) shouldRetryRequest(req *http.Request, err error) bool { - if http2isNoCachedConnError(err) { + if http2.IsNoCachedConnError(err) { // Issue 16582: if the user started a bunch of // requests at once, they can all pick the same conn // and violate the server's max concurrent streams. @@ -773,7 +777,7 @@ func (t *Transport) CloseIdleConnections() { // cancelable context instead. CancelRequest cannot cancel HTTP/2 // requests. func (t *Transport) CancelRequest(req *http.Request) { - t.cancelRequest(cancelKey{req}, errRequestCanceled) + t.cancelRequest(cancelKey{req}, common.ErrRequestCanceled) } // Cancel an in-flight request, recording the error value. @@ -1345,7 +1349,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi case <-req.Context().Done(): return nil, req.Context().Err() case err := <-cancelc: - if err == errRequestCanceled { + if err == common.ErrRequestCanceled { err = errRequestCanceledConn } return nil, err @@ -1359,7 +1363,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi case <-req.Context().Done(): return nil, req.Context().Err() case err := <-cancelc: - if err == errRequestCanceled { + if err == common.ErrRequestCanceled { err = errRequestCanceledConn } return nil, err @@ -1509,7 +1513,7 @@ func (pc *persistConn) addTLS(ctx context.Context, name string, trace *httptrace } pc.tlsState = &cs pc.conn = tlsConn - if !forProxy && pc.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { + if !forProxy && pc.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2.NextProtoTLS { return newHttp2NotSupportedError(cs.NegotiatedProtocol) } return nil @@ -1569,7 +1573,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers trace.TLSHandshakeDone(cs, nil) } pconn.tlsState = &cs - if cm.proxyURL == nil && pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2NextProtoTLS { + if cm.proxyURL == nil && pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2.NextProtoTLS { return nil, newHttp2NotSupportedError(cs.NegotiatedProtocol) } } @@ -1898,7 +1902,7 @@ func fixPragmaCacheControl(header http.Header) { // 100-continue") from the server. It returns the final non-100 one. // trace is optional. func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) { - dumps := getDumpers(req.Context(), pc.t.dump) + dumps := dump.GetDumpers(req.Context(), pc.t.dump) tp := newTextprotoReader(pc.br, dumps) resp := &http.Response{ Request: req, @@ -2013,7 +2017,7 @@ func (pc *persistConn) cancelRequest(err error) { pc.mu.Lock() defer pc.mu.Unlock() pc.canceledErr = err - pc.closeLocked(errRequestCanceled) + pc.closeLocked(common.ErrRequestCanceled) } // closeConnIfStillIdle closes the connection if it's still sitting idle. @@ -2530,9 +2534,9 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo } rw := w // raw writer - dumps := getDumpers(r.Context(), pc.t.dump) + dumps := dump.GetDumpers(r.Context(), pc.t.dump) for _, dump := range dumps { - if dump.RequestHeader { + if dump.RequestHeader() { w = dump.WrapWriter(w) } } @@ -2749,7 +2753,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, er testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { pc.t.putOrCloseIdleConn(pc) - return nil, errRequestCanceled + return nil, common.ErrRequestCanceled } pc.mu.Lock() pc.numExpectedResponses++ @@ -2874,7 +2878,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, er } return re.res, nil case <-cancelChan: - canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) + canceled = pc.t.cancelRequest(req.cancelKey, common.ErrRequestCanceled) cancelChan = nil case <-ctxDoneChan: canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err()) diff --git a/transport_internal_test.go b/transport_internal_test.go index 47c25a06..4ff9da08 100644 --- a/transport_internal_test.go +++ b/transport_internal_test.go @@ -11,7 +11,10 @@ import ( "context" "crypto/tls" "errors" + "github.com/imroc/req/v3/internal/http2" "github.com/imroc/req/v3/internal/testcert" + "github.com/imroc/req/v3/internal/tests" + reqtls "github.com/imroc/req/v3/internal/tls" "io" "net" "net/http" @@ -25,7 +28,7 @@ func withT(r *http.Request, t *testing.T) *http.Request { // Issue 15446: incorrect wrapping of errors when server closes an idle connection. func TestTransportPersistConnReadLoopEOF(t *testing.T) { - ln := newLocalListener(t) + ln := tests.NewLocalListener(t) defer ln.Close() connc := make(chan net.Conn, 1) @@ -79,17 +82,6 @@ func isTransportReadFromServerError(err error) bool { return ok } -func newLocalListener(t *testing.T) net.Listener { - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - ln, err = net.Listen("tcp6", "[::1]:0") - } - if err != nil { - t.Fatal(err) - } - return ln -} - func dummyRequest(method string) *http.Request { req, err := http.NewRequest(method, "http://fake.tld/", nil) if err != nil { @@ -140,7 +132,7 @@ func TestTransportShouldRetryRequest(t *testing.T) { 2: { pc: &persistConn{reused: true}, req: dummyRequest("POST"), - err: http2ErrNoCachedConn, + err: http2.ErrNoCachedConn, want: true, }, 3: { @@ -206,7 +198,7 @@ func TestTransportBodyAltRewind(t *testing.T) { if err != nil { t.Fatal(err) } - ln := newLocalListener(t) + ln := tests.NewLocalListener(t) defer ln.Close() go func() { @@ -233,8 +225,8 @@ func TestTransportBodyAltRewind(t *testing.T) { roundTripped := false tr := &Transport{ DisableKeepAlives: true, - TLSNextProto: map[string]func(string, TLSConn) http.RoundTripper{ - "foo": func(authority string, c TLSConn) http.RoundTripper { + TLSNextProto: map[string]func(string, reqtls.Conn) http.RoundTripper{ + "foo": func(authority string, c reqtls.Conn) http.RoundTripper { return roundTripFunc(func(r *http.Request) (*http.Response, error) { n, _ := io.Copy(io.Discard, r.Body) if n == 0 { @@ -247,7 +239,7 @@ func TestTransportBodyAltRewind(t *testing.T) { }, nil } roundTripped = true - return nil, http2noCachedConnError{} + return nil, http2.ErrNoCachedConn }) }, }, diff --git a/transport_test.go b/transport_test.go index 59c3f684..b3a05532 100644 --- a/transport_test.go +++ b/transport_test.go @@ -20,9 +20,14 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/imroc/req/v3/internal/common" + "github.com/imroc/req/v3/internal/http2" "github.com/imroc/req/v3/internal/testcert" + "github.com/imroc/req/v3/internal/tests" + reqtls "github.com/imroc/req/v3/internal/tls" "go/token" "golang.org/x/net/http/httpproxy" + nethttp2 "golang.org/x/net/http2" "io" "log" mrand "math/rand" @@ -84,24 +89,6 @@ func (t *Transport) IdleConnStrsForTesting() []string { return ret } -func (t *Transport) IdleConnStrsForTestingH2() []string { - var ret []string - noDialPool := t.t2.ConnPool.(http2noDialClientConnPool) - pool := noDialPool.http2clientConnPool - - pool.mu.Lock() - defer pool.mu.Unlock() - - for k, cc := range pool.conns { - for range cc { - ret = append(ret, k) - } - } - - sort.Strings(ret) - return ret -} - func (t *Transport) IdleConnCountForTesting(scheme, addr string) int { t.idleMu.Lock() defer t.idleMu.Unlock() @@ -959,7 +946,7 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) { ts.CloseClientConnections() var keys2 []string - if !waitCondition(3*time.Second, 50*time.Millisecond, func() bool { + if !tests.WaitCondition(3*time.Second, 50*time.Millisecond, func() bool { keys2 = tr.IdleConnKeysForTesting() return len(keys2) == 0 }) { @@ -1469,7 +1456,7 @@ func TestTransportExpect100Continue(t *testing.T) { func TestSOCKS5Proxy(t *testing.T) { defer afterTest(t) ch := make(chan string, 1) - l := newLocalListener(t) + l := tests.NewLocalListener(t) defer l.Close() defer close(ch) proxy := func(t *testing.T) { @@ -1720,7 +1707,7 @@ func TestTransportProxyHTTPSConnectLeak(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ln := newLocalListener(t) + ln := tests.NewLocalListener(t) defer ln.Close() listenerDone := make(chan struct{}) go func() { @@ -2212,7 +2199,7 @@ func TestIssue3595(t *testing.T) { })) defer ts.Close() c := tc().httpClient - res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) + res, err := c.Post(ts.URL, "application/octet-stream", tests.NeverEnding('a')) if err != nil { t.Errorf("Post: %v", err) return @@ -2369,7 +2356,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { const debug = false mux := http.NewServeMux() mux.HandleFunc("/get", func(w http.ResponseWriter, r *http.Request) { - io.Copy(w, neverEnding('a')) + io.Copy(w, tests.NeverEnding('a')) }) ts := httptest.NewServer(mux) defer ts.Close() @@ -2427,7 +2414,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { const debug = false mux := http.NewServeMux() mux.HandleFunc("/get", func(w http.ResponseWriter, r *http.Request) { - io.Copy(w, neverEnding('a')) + io.Copy(w, tests.NeverEnding('a')) }) mux.HandleFunc("/put", func(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() @@ -2631,7 +2618,7 @@ func TestCancelRequestWithChannel(t *testing.T) { body, err := io.ReadAll(res.Body) d := time.Since(t0) - if err != errRequestCanceled { + if err != common.ErrRequestCanceled { t.Errorf("Body.Read error = %v; want errRequestCanceled", err) } if string(body) != "Hello" { @@ -2734,7 +2721,7 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) { if err == nil { t.Fatalf("unexpected success from RoundTrip") } - if err != errRequestCanceled { + if err != common.ErrRequestCanceled { t.Errorf("RoundTrip error = %v; want errRequestCanceled", err) } } @@ -3077,10 +3064,6 @@ Content-Length: %d } } -var ( - ExportHttp2ConfigureServer = http2ConfigureServer -) - type clientServerTest struct { t *testing.T h2 bool @@ -3126,12 +3109,6 @@ var optQuietLog = func(ts *httptest.Server) { ts.Config.ErrorLog = quietLog } -func optWithServerLog(lg *log.Logger) func(*httptest.Server) { - return func(ts *httptest.Server) { - ts.Config.ErrorLog = lg - } -} - func newClientServerTest(t *testing.T, h2 bool, h http.Handler, opts ...interface{}) *clientServerTest { cst := &clientServerTest{ t: t, @@ -3157,14 +3134,14 @@ func newClientServerTest(t *testing.T, h2 bool, h http.Handler, opts ...interfac cst.ts.Start() return cst } - http2ConfigureServer(cst.ts.Config, nil) + nethttp2.ConfigureServer(cst.ts.Config, nil) cst.ts.TLS = cst.ts.Config.TLSConfig cst.ts.StartTLS() cst.tr.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, } - if _, err := http2ConfigureTransports(cst.tr); err != nil { + if _, err := http2.ConfigureTransports(transportImpl{cst.tr}); err != nil { t.Fatal(err) } return cst @@ -3462,7 +3439,7 @@ func TestTransportTLSHandshakeTimeout(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } - ln := newLocalListener(t) + ln := tests.NewLocalListener(t) defer ln.Close() testdonec := make(chan struct{}) defer close(testdonec) @@ -3639,7 +3616,7 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { const bodySize = 256 << 10 finalBit := make(byteFromChanReader, 1) - req, _ := http.NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit)) + req, _ := http.NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(tests.NeverEnding('x'), bodySize-1), finalBit)) req.ContentLength = bodySize res, err := c.Do(req) if err := wantBody(res, err, "foo"); err != nil { @@ -3736,7 +3713,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { io.Reader io.Closer }{ - io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)), + io.MultiReader(io.LimitReader(tests.NeverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)), closerFunc(func() error { select { case didClose <- true: @@ -4055,7 +4032,7 @@ func TestTransportDialCancelRace(t *testing.T) { }) defer SetEnterRoundTripHook(nil) res, err := tr.RoundTrip(req) - if err != errRequestCanceled { + if err != common.ErrRequestCanceled { t.Errorf("expected canceled request error; got %v", err) if err == nil { res.Body.Close() @@ -4244,7 +4221,7 @@ func TestTransportAutomaticHTTP2(t *testing.T) { func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) { tr := tc().t - tr.TLSNextProto = make(map[string]func(string, TLSConn) http.RoundTripper) + tr.TLSNextProto = make(map[string]func(string, reqtls.Conn) http.RoundTripper) testTransportAutoHTTP(t, tr, false) } @@ -4298,7 +4275,7 @@ func TestNoCrashReturningTransportAltConn(t *testing.T) { if err != nil { t.Fatal(err) } - ln := newLocalListener(t) + ln := tests.NewLocalListener(t) defer ln.Close() var wg sync.WaitGroup @@ -4336,8 +4313,8 @@ func TestNoCrashReturningTransportAltConn(t *testing.T) { tr := &Transport{ DisableKeepAlives: true, - TLSNextProto: map[string]func(string, TLSConn) http.RoundTripper{ - "foo": func(authority string, c TLSConn) http.RoundTripper { + TLSNextProto: map[string]func(string, reqtls.Conn) http.RoundTripper{ + "foo": func(authority string, c reqtls.Conn) http.RoundTripper { madeRoundTripper <- true return funcRoundTripper(func() { t.Error("foo http.RoundTripper should not be called") @@ -5113,7 +5090,7 @@ func TestMissingStatusNoPanic(t *testing.T) { const want = "unknown status code" - ln := newLocalListener(t) + ln := tests.NewLocalListener(t) addr := ln.Addr().String() done := make(chan bool) fullAddrURL := fmt.Sprintf("http://%s", addr) @@ -5623,8 +5600,8 @@ func TestTransportClone(t *testing.T) { GetProxyConnectHeader: func(context.Context, *url.URL, string) (http.Header, error) { return nil, nil }, MaxResponseHeaderBytes: 1, ForceAttemptHTTP2: true, - TLSNextProto: map[string]func(authority string, c TLSConn) http.RoundTripper{ - "foo": func(authority string, c TLSConn) http.RoundTripper { panic("") }, + TLSNextProto: map[string]func(authority string, c reqtls.Conn) http.RoundTripper{ + "foo": func(authority string, c reqtls.Conn) http.RoundTripper { panic("") }, }, ReadBufferSize: 1, WriteBufferSize: 1, diff --git a/transport_wrapper.go b/transport_wrapper.go new file mode 100644 index 00000000..e17d9ee7 --- /dev/null +++ b/transport_wrapper.go @@ -0,0 +1,121 @@ +package req + +import ( + "context" + "crypto/tls" + "github.com/imroc/req/v3/internal/dump" + reqtls "github.com/imroc/req/v3/internal/tls" + "github.com/imroc/req/v3/internal/transport" + "net" + "net/http" + "net/url" + "time" +) + +type transportImpl struct { + t *Transport +} + +func (t transportImpl) Proxy() func(*http.Request) (*url.URL, error) { + return t.t.Proxy +} + +func (t transportImpl) Clone() transport.Interface { + return transportImpl{t.t.Clone()} +} + +func (t transportImpl) Debugf() func(format string, v ...interface{}) { + return t.t.Debugf +} + +func (t transportImpl) SetDebugf(f func(format string, v ...interface{})) { + t.t.Debugf = f +} + +func (t transportImpl) DisableCompression() bool { + return t.t.DisableCompression +} + +func (t transportImpl) TLSClientConfig() *tls.Config { + return t.t.TLSClientConfig +} + +func (t transportImpl) SetTLSClientConfig(c *tls.Config) { + t.t.TLSClientConfig = c +} + +func (t transportImpl) TLSHandshakeTimeout() time.Duration { + return t.t.TLSHandshakeTimeout +} + +func (t transportImpl) DialContext() func(ctx context.Context, network string, addr string) (net.Conn, error) { + return t.t.DialContext +} + +func (t transportImpl) DialTLSContext() func(ctx context.Context, network string, addr string) (net.Conn, error) { + return t.t.DialTLSContext +} + +func (t transportImpl) RegisterProtocol(scheme string, rt http.RoundTripper) { + t.t.RegisterProtocol(scheme, rt) +} + +func (t transportImpl) DisableKeepAlives() bool { + return t.t.DisableKeepAlives +} + +func (t transportImpl) Dump() *dump.Dumper { + return t.t.dump +} + +func (t transportImpl) MaxIdleConns() int { + return t.t.MaxIdleConns +} + +func (t transportImpl) MaxIdleConnsPerHost() int { + return t.t.MaxIdleConnsPerHost +} + +func (t transportImpl) MaxConnsPerHost() int { + return t.t.MaxConnsPerHost +} + +func (t transportImpl) IdleConnTimeout() time.Duration { + return t.t.IdleConnTimeout +} + +func (t transportImpl) ResponseHeaderTimeout() time.Duration { + return t.t.ResponseHeaderTimeout +} + +func (t transportImpl) ExpectContinueTimeout() time.Duration { + return t.t.ExpectContinueTimeout +} + +func (t transportImpl) ProxyConnectHeader() http.Header { + return t.t.ProxyConnectHeader +} + +func (t transportImpl) GetProxyConnectHeader() func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) { + return t.t.GetProxyConnectHeader +} + +func (t transportImpl) MaxResponseHeaderBytes() int64 { + return t.t.MaxResponseHeaderBytes +} + +func (t transportImpl) WriteBufferSize() int { + return t.t.WriteBufferSize +} + +func (t transportImpl) ReadBufferSize() int { + return t.t.ReadBufferSize +} + +func (t transportImpl) TLSNextProto() map[string]func(authority string, c reqtls.Conn) http.RoundTripper { + return t.t.TLSNextProto +} + +func (t transportImpl) SetTLSNextProto(m map[string]func(authority string, c reqtls.Conn) http.RoundTripper) { + t.t.TLSNextProto = m +} From 13313e934eb4dff92e4f09969b8123ce4a76602d Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 17 Jun 2022 17:34:37 +0800 Subject: [PATCH 505/843] extract header const into internal header package --- client.go | 12 ++++++------ client_test.go | 7 ++++--- client_wrapper_test.go | 3 ++- internal/header/header.go | 11 ++++++++++- middleware.go | 14 +++++++------- req.go | 11 ----------- req_test.go | 33 +++++++++++++++++---------------- request.go | 7 ++++--- request_test.go | 31 ++++++++++++++++--------------- request_wrapper_test.go | 3 ++- response.go | 3 ++- response_test.go | 5 +++-- transport.go | 3 ++- 13 files changed, 75 insertions(+), 68 deletions(-) diff --git a/client.go b/client.go index f791f538..81856e88 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,10 @@ import ( "encoding/json" "encoding/xml" "errors" + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/http2" + "github.com/imroc/req/v3/internal/util" + "golang.org/x/net/publicsuffix" "io" "io/ioutil" "net" @@ -17,10 +21,6 @@ import ( "reflect" "strings" "time" - - "github.com/imroc/req/v3/internal/http2" - "github.com/imroc/req/v3/internal/util" - "golang.org/x/net/publicsuffix" ) // DefaultClient returns the global default Client. @@ -575,7 +575,7 @@ func (c *Client) EnableAutoDecode() *Client { // SetUserAgent set the "User-Agent" header for all requests. func (c *Client) SetUserAgent(userAgent string) *Client { - return c.SetCommonHeader(hdrUserAgentKey, userAgent) + return c.SetCommonHeader(header.UserAgent, userAgent) } // SetCommonBearerAuthToken set the bearer auth token for all requests. @@ -627,7 +627,7 @@ func (c *Client) SetCommonHeadersNonCanonical(hdrs map[string]string) *Client { // SetCommonContentType set the `Content-Type` header for all requests. func (c *Client) SetCommonContentType(ct string) *Client { - c.SetCommonHeader(hdrContentTypeKey, ct) + c.SetCommonHeader(header.ContentType, ct) return c } diff --git a/client_test.go b/client_test.go index f1ab64de..fca50bdf 100644 --- a/client_test.go +++ b/client_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "errors" + "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/tests" "io/ioutil" "net" @@ -135,8 +136,8 @@ func TestSetProxy(t *testing.T) { } func TestSetCommonContentType(t *testing.T) { - c := tc().SetCommonContentType(jsonContentType) - tests.AssertEqual(t, jsonContentType, c.Headers.Get(hdrContentTypeKey)) + c := tc().SetCommonContentType(header.JsonContentType) + tests.AssertEqual(t, header.JsonContentType, c.Headers.Get(header.ContentType)) } func TestSetCommonHeader(t *testing.T) { @@ -177,7 +178,7 @@ func TestSetCommonBearerAuthToken(t *testing.T) { func TestSetUserAgent(t *testing.T) { c := tc().SetUserAgent("test") - tests.AssertEqual(t, "test", c.Headers.Get(hdrUserAgentKey)) + tests.AssertEqual(t, "test", c.Headers.Get(header.UserAgent)) } func TestAutoDecode(t *testing.T) { diff --git a/client_wrapper_test.go b/client_wrapper_test.go index e7cf8623..bf9496f5 100644 --- a/client_wrapper_test.go +++ b/client_wrapper_test.go @@ -2,6 +2,7 @@ package req import ( "crypto/tls" + "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/tests" "net/http" "net/url" @@ -45,7 +46,7 @@ func TestGlobalWrapper(t *testing.T) { SetProxyURL("http://dummy.proxy.local"), SetProxyURL("bad url"), SetProxy(proxy), - SetCommonContentType(jsonContentType), + SetCommonContentType(header.JsonContentType), SetCommonHeader("my-header", "my-value"), SetCommonHeaders(map[string]string{ "header1": "value1", diff --git a/internal/header/header.go b/internal/header/header.go index 52135a51..dcf4c56d 100644 --- a/internal/header/header.go +++ b/internal/header/header.go @@ -1,3 +1,12 @@ package header -const DefaultUserAgent = "req/v3 (https://github.com/imroc/req)" +const ( + DefaultUserAgent = "req/v3 (https://github.com/imroc/req)" + UserAgent = "User-Agent" + Location = "Location" + ContentType = "Content-Type" + PlainTextContentType = "text/plain; charset=utf-8" + JsonContentType = "application/json; charset=utf-8" + XmlContentType = "text/xml; charset=utf-8" + FormContentType = "application/x-www-form-urlencoded" +) diff --git a/middleware.go b/middleware.go index d83dcec2..f6d595f1 100644 --- a/middleware.go +++ b/middleware.go @@ -2,6 +2,8 @@ package req import ( "bytes" + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/util" "io" "io/ioutil" "mime/multipart" @@ -13,8 +15,6 @@ import ( "reflect" "strings" "time" - - "github.com/imroc/req/v3/internal/util" ) type ( @@ -47,7 +47,7 @@ func createMultipartHeader(file *FileUpload, contentType string) textproto.MIMEH hdr.Set("Content-Disposition", contentDispositionValue) if !util.IsStringEmpty(contentType) { - hdr.Set(hdrContentTypeKey, contentType) + hdr.Set(header.ContentType, contentType) } return hdr } @@ -185,17 +185,17 @@ func handleMultiPart(c *Client, r *Request) (err error) { } func handleFormData(r *Request) { - r.SetContentType(formContentType) + r.SetContentType(header.FormContentType) r.SetBodyBytes([]byte(r.FormData.Encode())) } func handleMarshalBody(c *Client, r *Request) error { ct := "" if r.Headers != nil { - ct = r.Headers.Get(hdrContentTypeKey) + ct = r.Headers.Get(header.ContentType) } if ct == "" { - ct = c.Headers.Get(hdrContentTypeKey) + ct = c.Headers.Get(header.ContentType) } if ct != "" { if util.IsXMLType(ct) { @@ -249,7 +249,7 @@ func parseRequestBody(c *Client, r *Request) (err error) { return } // body is in-memory []byte, so we can guess content type - if r.getHeader(hdrContentTypeKey) == "" { + if r.getHeader(header.ContentType) == "" { r.SetContentType(http.DetectContentType(r.body)) } return diff --git a/req.go b/req.go index e3e4caee..507a1ce9 100644 --- a/req.go +++ b/req.go @@ -6,17 +6,6 @@ import ( "net/url" ) -const ( - hdrUserAgentKey = "User-Agent" - hdrUserAgentValue = "req/v3 (https://github.com/imroc/req)" - hdrLocationKey = "Location" - hdrContentTypeKey = "Content-Type" - plainTextContentType = "text/plain; charset=utf-8" - jsonContentType = "application/json; charset=utf-8" - xmlContentType = "text/xml; charset=utf-8" - formContentType = "application/x-www-form-urlencoded" -) - type kv struct { Key string Value string diff --git a/req_test.go b/req_test.go index a1cdccb7..27d0b6f5 100644 --- a/req_test.go +++ b/req_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "encoding/xml" "fmt" + "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/tests" "go/token" "golang.org/x/text/encoding/simplifiedchinese" @@ -128,7 +129,7 @@ func handlePost(w http.ResponseWriter, r *http.Request) { case "/form": r.ParseForm() ret, _ := json.Marshal(&r.Form) - w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Header().Set(header.ContentType, header.JsonContentType) w.Write(ret) case "/multipart": r.ParseMultipartForm(10e6) @@ -136,24 +137,24 @@ func handlePost(w http.ResponseWriter, r *http.Request) { m["values"] = r.MultipartForm.Value m["files"] = r.MultipartForm.File ret, _ := json.Marshal(&m) - w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Header().Set(header.ContentType, header.JsonContentType) w.Write(ret) case "/search": handleSearch(w, r) case "/redirect": io.Copy(ioutil.Discard, r.Body) - w.Header().Set(hdrLocationKey, "/") + w.Header().Set(header.Location, "/") w.WriteHeader(http.StatusMovedPermanently) case "/content-type": io.Copy(ioutil.Discard, r.Body) - w.Write([]byte(r.Header.Get(hdrContentTypeKey))) + w.Write([]byte(r.Header.Get(header.ContentType))) case "/echo": b, _ := ioutil.ReadAll(r.Body) e := Echo{ Header: r.Header, Body: string(b), } - w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Header().Set(header.ContentType, header.JsonContentType) result, _ := json.Marshal(&e) w.Write(result) } @@ -181,10 +182,10 @@ func handleSearch(w http.ResponseWriter, r *http.Request) { tp := r.FormValue("type") var marshalFunc func(v interface{}) ([]byte, error) if tp == "xml" { - w.Header().Set(hdrContentTypeKey, xmlContentType) + w.Header().Set(header.ContentType, header.XmlContentType) marshalFunc = xml.Marshal } else { - w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Header().Set(header.ContentType, header.JsonContentType) marshalFunc = json.Marshal } var result interface{} @@ -229,7 +230,7 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) case "/too-many": w.WriteHeader(http.StatusTooManyRequests) - w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Header().Set(header.ContentType, header.JsonContentType) w.Write([]byte(`{"errMsg":"too many requests"}`)) case "/chunked": w.Header().Add("Trailer", "Expires") @@ -239,9 +240,9 @@ func handleGet(w http.ResponseWriter, r *http.Request) { case "/json": r.ParseForm() if r.FormValue("type") != "no" { - w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Header().Set(header.ContentType, header.JsonContentType) } - w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Header().Set(header.ContentType, header.JsonContentType) if r.FormValue("error") == "yes" { w.WriteHeader(http.StatusBadRequest) w.Write([]byte(`{"message": "not allowed"}`)) @@ -251,7 +252,7 @@ func handleGet(w http.ResponseWriter, r *http.Request) { case "/xml": r.ParseForm() if r.FormValue("type") != "no" { - w.Header().Set(hdrContentTypeKey, xmlContentType) + w.Header().Set(header.ContentType, header.XmlContentType) } w.Write([]byte(`roc`)) case "/unlimited-redirect": @@ -266,23 +267,23 @@ func handleGet(w http.ResponseWriter, r *http.Request) { b, _ := ioutil.ReadAll(r.Body) w.Write(b) case "/gbk": - w.Header().Set(hdrContentTypeKey, "text/plain; charset=gbk") + w.Header().Set(header.ContentType, "text/plain; charset=gbk") w.Write(toGbk("我是roc")) case "/gbk-no-charset": b, err := ioutil.ReadFile(tests.GetTestFilePath("sample-gbk.html")) if err != nil { panic(err) } - w.Header().Set(hdrContentTypeKey, "text/html") + w.Header().Set(header.ContentType, "text/html") w.Write(b) case "/header": b, _ := json.Marshal(r.Header) - w.Header().Set(hdrContentTypeKey, jsonContentType) + w.Header().Set(header.ContentType, header.JsonContentType) w.Write(b) case "/user-agent": - w.Write([]byte(r.Header.Get(hdrUserAgentKey))) + w.Write([]byte(r.Header.Get(header.UserAgent))) case "/content-type": - w.Write([]byte(r.Header.Get(hdrContentTypeKey))) + w.Write([]byte(r.Header.Get(header.ContentType))) case "/query-parameter": w.Write([]byte(r.URL.RawQuery)) case "/search": diff --git a/request.go b/request.go index b199543e..e56a8ea4 100644 --- a/request.go +++ b/request.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/hashicorp/go-multierror" "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/util" "io" "io/ioutil" @@ -617,7 +618,7 @@ func (r *Request) SetBodyJsonString(body string) *Request { // SetBodyJsonBytes set the request body as []byte and set Content-Type header // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonBytes(body []byte) *Request { - r.SetContentType(jsonContentType) + r.SetContentType(header.JsonContentType) return r.SetBodyBytes(body) } @@ -641,7 +642,7 @@ func (r *Request) SetBodyXmlString(body string) *Request { // SetBodyXmlBytes set the request body as []byte and set Content-Type header // as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlBytes(body []byte) *Request { - r.SetContentType(xmlContentType) + r.SetContentType(header.XmlContentType) return r.SetBodyBytes(body) } @@ -658,7 +659,7 @@ func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { // SetContentType set the `Content-Type` for the request. func (r *Request) SetContentType(contentType string) *Request { - return r.SetHeader(hdrContentTypeKey, contentType) + return r.SetHeader(header.ContentType, contentType) } // Context method returns the Context if its already set in request diff --git a/request_test.go b/request_test.go index a126a54c..7a5597bf 100644 --- a/request_test.go +++ b/request_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "encoding/xml" "fmt" + "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/tests" "io" "io/ioutil" @@ -316,7 +317,7 @@ func TestSetBodyMarshal(t *testing.T) { Set: func(r *Request) { var user User user.Username = username - r.SetBody(&user).SetContentType(xmlContentType) + r.SetBody(&user).SetContentType(header.XmlContentType) }, Assert: assertUsernameXml, }, @@ -392,49 +393,49 @@ func TestSetBody(t *testing.T) { SetBody: func(r *Request) { // SetBody with string r.SetBody(body) }, - ContentType: plainTextContentType, + ContentType: header.PlainTextContentType, }, { SetBody: func(r *Request) { // SetBody with []byte r.SetBody([]byte(body)) }, - ContentType: plainTextContentType, + ContentType: header.PlainTextContentType, }, { SetBody: func(r *Request) { // SetBodyString r.SetBodyString(body) }, - ContentType: plainTextContentType, + ContentType: header.PlainTextContentType, }, { SetBody: func(r *Request) { // SetBodyBytes r.SetBodyBytes([]byte(body)) }, - ContentType: plainTextContentType, + ContentType: header.PlainTextContentType, }, { SetBody: func(r *Request) { // SetBodyJsonString r.SetBodyJsonString(body) }, - ContentType: jsonContentType, + ContentType: header.JsonContentType, }, { SetBody: func(r *Request) { // SetBodyJsonBytes r.SetBodyJsonBytes([]byte(body)) }, - ContentType: jsonContentType, + ContentType: header.JsonContentType, }, { SetBody: func(r *Request) { // SetBodyXmlString r.SetBodyXmlString(body) }, - ContentType: xmlContentType, + ContentType: header.XmlContentType, }, { SetBody: func(r *Request) { // SetBodyXmlBytes r.SetBodyXmlBytes([]byte(body)) }, - ContentType: xmlContentType, + ContentType: header.XmlContentType, }, } for _, tc := range testCases { @@ -443,7 +444,7 @@ func TestSetBody(t *testing.T) { var e Echo resp, err := r.SetResult(&e).Post("/echo") assertSuccess(t, resp, err) - tests.AssertEqual(t, tc.ContentType, e.Header.Get(hdrContentTypeKey)) + tests.AssertEqual(t, tc.ContentType, e.Header.Get(header.ContentType)) tests.AssertEqual(t, body, e.Body) } } @@ -492,7 +493,7 @@ func TestHeader(t *testing.T) { func testHeader(t *testing.T, c *Client) { // Set User-Agent customUserAgent := "My Custom User Agent" - resp, err := c.R().SetHeader(hdrUserAgentKey, customUserAgent).Get("/user-agent") + resp, err := c.R().SetHeader(header.UserAgent, customUserAgent).Get("/user-agent") assertSuccess(t, resp, err) tests.AssertEqual(t, customUserAgent, resp.String()) @@ -821,11 +822,11 @@ func TestAutoDetectRequestContentType(t *testing.T) { resp, err = c.R().SetBodyJsonString(`{"msg": "test"}`).Post("/content-type") assertSuccess(t, resp, err) - tests.AssertEqual(t, jsonContentType, resp.String()) + tests.AssertEqual(t, header.JsonContentType, resp.String()) - resp, err = c.R().SetContentType(xmlContentType).SetBody(`{"msg": "test"}`).Post("/content-type") + resp, err = c.R().SetContentType(header.XmlContentType).SetBody(`{"msg": "test"}`).Post("/content-type") assertSuccess(t, resp, err) - tests.AssertEqual(t, xmlContentType, resp.String()) + tests.AssertEqual(t, header.XmlContentType, resp.String()) resp, err = c.R().SetBody(`

hello

`).Post("/content-type") assertSuccess(t, resp, err) @@ -833,7 +834,7 @@ func TestAutoDetectRequestContentType(t *testing.T) { resp, err = c.R().SetBody(`hello world`).Post("/content-type") assertSuccess(t, resp, err) - tests.AssertEqual(t, plainTextContentType, resp.String()) + tests.AssertEqual(t, header.PlainTextContentType, resp.String()) } func TestSetFileUploadCheck(t *testing.T) { diff --git a/request_wrapper_test.go b/request_wrapper_test.go index d0204ded..4c74cbe9 100644 --- a/request_wrapper_test.go +++ b/request_wrapper_test.go @@ -3,6 +3,7 @@ package req import ( "bytes" "context" + "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/tests" "net/http" "testing" @@ -40,7 +41,7 @@ func TestGlobalWrapperForRequestSettings(t *testing.T) { SetPathParams(map[string]string{"test": "test"}), SetFormData(map[string]string{"test": "test"}), SetFormDataFromValues(nil), - SetContentType(jsonContentType), + SetContentType(header.JsonContentType), AddRetryCondition(func(rep *Response, err error) bool { return err != nil }), diff --git a/response.go b/response.go index da9654ff..60593bb8 100644 --- a/response.go +++ b/response.go @@ -1,6 +1,7 @@ package req import ( + "github.com/imroc/req/v3/internal/header" "io/ioutil" "net/http" "strings" @@ -38,7 +39,7 @@ func (r *Response) GetContentType() string { if r.Response == nil { return "" } - return r.Header.Get(hdrContentTypeKey) + return r.Header.Get(header.ContentType) } // Result returns the response value as an object if it has one diff --git a/response_test.go b/response_test.go index d87cd211..7ec82376 100644 --- a/response_test.go +++ b/response_test.go @@ -1,6 +1,7 @@ package req import ( + "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/tests" "net/http" "testing" @@ -67,6 +68,6 @@ func TestResponseWrap(t *testing.T) { assertSuccess(t, resp, err) tests.AssertEqual(t, true, resp.GetStatusCode() == http.StatusOK) tests.AssertEqual(t, true, resp.GetStatus() == "200 OK") - tests.AssertEqual(t, true, resp.GetHeader(hdrContentTypeKey) == jsonContentType) - tests.AssertEqual(t, true, len(resp.GetHeaderValues(hdrContentTypeKey)) == 1) + tests.AssertEqual(t, true, resp.GetHeader(header.ContentType) == header.JsonContentType) + tests.AssertEqual(t, true, len(resp.GetHeaderValues(header.ContentType)) == 1) } diff --git a/transport.go b/transport.go index 78665ccc..017fc2cd 100644 --- a/transport.go +++ b/transport.go @@ -20,6 +20,7 @@ import ( "github.com/imroc/req/v3/internal/ascii" "github.com/imroc/req/v3/internal/common" "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/http2" "github.com/imroc/req/v3/internal/socks" reqtls "github.com/imroc/req/v3/internal/tls" @@ -2557,7 +2558,7 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo // Use the defaultUserAgent unless the Header contains one, which // may be blank to not send the header. - userAgent := hdrUserAgentValue + userAgent := header.DefaultUserAgent if headerHas(r.Header, "User-Agent") { userAgent = r.Header.Get("User-Agent") } From fd78f1b249cbe9f37c19c941c3fda4ba37e697a1 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 17 Jun 2022 19:47:53 +0800 Subject: [PATCH 506/843] optimize dump --- internal/dump/dump.go | 39 +++++++++++++++++++++++++++++++++++++++ textproto_reader.go | 18 +++--------------- transport.go | 15 +++------------ 3 files changed, 45 insertions(+), 27 deletions(-) diff --git a/internal/dump/dump.go b/internal/dump/dump.go index f36630bf..5ffffaca 100644 --- a/internal/dump/dump.go +++ b/internal/dump/dump.go @@ -3,6 +3,7 @@ package dump import ( "context" "io" + "net/http" ) // Options controls the dump behavior. @@ -67,11 +68,40 @@ func (d *Dumper) WrapWriter(w io.Writer) io.Writer { } } +// GetResponseHeaderDumpers return Dumpers which need dump response header. +func GetResponseHeaderDumpers(ctx context.Context, dump *Dumper) Dumpers { + dumpers := GetDumpers(ctx, dump) + var ds []*Dumper + for _, d := range dumpers { + if d.ResponseHeader() { + ds = append(ds, d) + } + } + return Dumpers(ds) +} + +// Dumpers is an array of Dumpper +type Dumpers []*Dumper + +// ShouldDump is true if Dumper is not empty. +func (ds Dumpers) ShouldDump() bool { + return len(ds) > 0 +} + +// Dump with all dumpers. +func (ds Dumpers) Dump(p []byte) { + for _, d := range ds { + d.Dump(p) + } +} + +// Dumper is the dump tool. type Dumper struct { Options ch chan []byte } +// NewDumper create a new Dumper. func NewDumper(opt Options) *Dumper { d := &Dumper{ Options: opt, @@ -135,3 +165,12 @@ func GetDumpers(ctx context.Context, dump *Dumper) []*Dumper { } return dumps } + +func WrapResponseBodyIfNeeded(res *http.Response, req *http.Request, dump *Dumper) { + dumps := GetDumpers(req.Context(), dump) + for _, d := range dumps { + if d.ResponseBody() { + res.Body = d.WrapReadCloser(res.Body) + } + } +} diff --git a/textproto_reader.go b/textproto_reader.go index 15581e40..b3a9096f 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -33,21 +33,11 @@ type textprotoReader struct { // To avoid denial of service attacks, the provided bufio.Reader // should be reading from an io.LimitReader or similar textprotoReader to bound // the size of responses. -func newTextprotoReader(r *bufio.Reader, dumps []*dump.Dumper) *textprotoReader { +func newTextprotoReader(r *bufio.Reader, ds dump.Dumpers) *textprotoReader { commonHeaderOnce.Do(initCommonHeader) t := &textprotoReader{R: r} - if len(dumps) > 0 { - dd := []*dump.Dumper{} - for _, dump := range dumps { - if dump.ResponseHeader() { - dd = append(dd, dump) - } - } - dumps = dd - } - - if len(dumps) > 0 { + if ds.ShouldDump() { t.readLine = func() (line []byte, isPrefix bool, err error) { line, err = t.R.ReadSlice('\n') if len(line) == 0 { @@ -57,9 +47,7 @@ func newTextprotoReader(r *bufio.Reader, dumps []*dump.Dumper) *textprotoReader return } err = nil - for _, dump := range dumps { - dump.Dump(line) - } + ds.Dump(line) if line[len(line)-1] == '\n' { drop := 1 if len(line) > 1 && line[len(line)-2] == '\r' { diff --git a/transport.go b/transport.go index 017fc2cd..f01b077b 100644 --- a/transport.go +++ b/transport.go @@ -291,7 +291,7 @@ func (t *Transport) handleResponseBody(res *http.Response, req *http.Request) { t.wrapResponseBody(res, wrap) } t.autoDecodeResponseBody(res) - t.dumpResponseBody(res, req) + dump.WrapResponseBodyIfNeeded(res, req, t.dump) } func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFunc) { @@ -305,15 +305,6 @@ func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFu } } -func (t *Transport) dumpResponseBody(res *http.Response, req *http.Request) { - dumps := dump.GetDumpers(req.Context(), t.dump) - for _, dump := range dumps { - if dump.ResponseBody() { - res.Body = dump.WrapReadCloser(res.Body) - } - } -} - func (t *Transport) autoDecodeResponseBody(res *http.Response) { if t.ResponseOptions == nil { return @@ -1903,8 +1894,8 @@ func fixPragmaCacheControl(header http.Header) { // 100-continue") from the server. It returns the final non-100 one. // trace is optional. func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) { - dumps := dump.GetDumpers(req.Context(), pc.t.dump) - tp := newTextprotoReader(pc.br, dumps) + ds := dump.GetResponseHeaderDumpers(req.Context(), pc.t.dump) + tp := newTextprotoReader(pc.br, ds) resp := &http.Response{ Request: req, } From 6cec0a5c65cc26007f44321e05257820ccca1875 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 30 Jun 2022 17:04:14 +0800 Subject: [PATCH 507/843] Support HTTP3 --- altsvc.go | 16 + client.go | 17 +- go.mod | 16 +- go.sum | 301 +++- internal/altsvcutil/altsvcutil.go | 212 +++ internal/handshake/aead.go | 162 +++ internal/handshake/aead_test.go | 204 +++ internal/handshake/crypto_setup.go | 820 +++++++++++ internal/handshake/crypto_setup_test.go | 864 +++++++++++ internal/handshake/handshake_suite_test.go | 48 + internal/handshake/header_protector.go | 137 ++ internal/handshake/hkdf.go | 29 + internal/handshake/hkdf_test.go | 17 + internal/handshake/initial_aead.go | 82 ++ internal/handshake/initial_aead_test.go | 219 +++ internal/handshake/interface.go | 103 ++ .../handshake/mock_handshake_runner_test.go | 84 ++ internal/handshake/mockgen.go | 3 + internal/handshake/retry.go | 63 + internal/handshake/retry_test.go | 36 + internal/handshake/session_ticket.go | 48 + internal/handshake/session_ticket_test.go | 54 + internal/handshake/tls_extension_handler.go | 69 + .../handshake/tls_extension_handler_test.go | 210 +++ internal/handshake/token_generator.go | 134 ++ internal/handshake/token_generator_test.go | 127 ++ internal/handshake/token_protector.go | 89 ++ internal/handshake/token_protector_test.go | 67 + internal/handshake/updatable_aead.go | 324 +++++ internal/handshake/updatable_aead_test.go | 528 +++++++ internal/http2/transport.go | 135 +- internal/http3/body.go | 130 ++ internal/http3/body_test.go | 54 + internal/http3/client.go | 428 ++++++ internal/http3/client_test.go | 1003 +++++++++++++ internal/http3/error_codes.go | 73 + internal/http3/error_codes_test.go | 39 + internal/http3/frames.go | 164 +++ internal/http3/frames_test.go | 245 ++++ internal/http3/gzip_reader.go | 39 + internal/http3/http3_suite_test.go | 38 + internal/http3/http_stream.go | 71 + internal/http3/http_stream_test.go | 150 ++ internal/http3/request.go | 113 ++ internal/http3/request_test.go | 197 +++ internal/http3/request_writer.go | 291 ++++ internal/http3/request_writer_test.go | 112 ++ internal/http3/response_writer.go | 118 ++ internal/http3/response_writer_test.go | 150 ++ internal/http3/roundtrip.go | 251 ++++ internal/http3/roundtrip_test.go | 252 ++++ internal/http3/server.go | 736 ++++++++++ internal/http3/server_test.go | 1289 +++++++++++++++++ internal/logging/frame.go | 66 + internal/logging/interface.go | 135 ++ internal/logging/logging_suite_test.go | 25 + .../logging/mock_connection_tracer_test.go | 351 +++++ internal/logging/mock_tracer_test.go | 76 + internal/logging/mockgen.go | 4 + internal/logging/multiplex.go | 219 +++ internal/logging/multiplex_test.go | 266 ++++ internal/logging/packet_header.go | 27 + internal/logging/packet_header_test.go | 60 + internal/logging/types.go | 94 ++ .../ackhandler/received_packet_handler.go | 105 ++ .../mocks/ackhandler/sent_packet_handler.go | 240 +++ internal/mocks/congestion.go | 192 +++ internal/mocks/connection_flow_controller.go | 128 ++ internal/mocks/crypto_setup.go | 264 ++++ internal/mocks/logging/connection_tracer.go | 352 +++++ internal/mocks/logging/tracer.go | 77 + internal/mocks/long_header_opener.go | 76 + internal/mocks/mockgen.go | 20 + internal/mocks/quic/early_conn.go | 255 ++++ internal/mocks/quic/early_listener.go | 80 + internal/mocks/quic/stream.go | 176 +++ internal/mocks/short_header_opener.go | 77 + internal/mocks/short_header_sealer.go | 89 ++ internal/mocks/stream_flow_controller.go | 140 ++ internal/mocks/tls/client_session_cache.go | 62 + internal/netutil/addr.go | 40 + internal/protocol/connection_id.go | 69 + internal/protocol/connection_id_test.go | 108 ++ internal/protocol/encryption_level.go | 30 + internal/protocol/encryption_level_test.go | 20 + internal/protocol/key_phase.go | 36 + internal/protocol/key_phase_test.go | 27 + internal/protocol/packet_number.go | 79 + internal/protocol/packet_number_test.go | 204 +++ internal/protocol/params.go | 193 +++ internal/protocol/params_test.go | 13 + internal/protocol/perspective.go | 26 + internal/protocol/perspective_test.go | 19 + internal/protocol/protocol.go | 97 ++ internal/protocol/protocol_suite_test.go | 13 + internal/protocol/protocol_test.go | 25 + internal/protocol/stream.go | 76 + internal/protocol/stream_test.go | 70 + internal/protocol/version.go | 77 + internal/protocol/version_test.go | 121 ++ internal/qerr/error_codes.go | 88 ++ internal/qerr/errorcodes_test.go | 52 + internal/qerr/errors.go | 125 ++ internal/qerr/errors_suite_test.go | 13 + internal/qerr/errors_test.go | 124 ++ internal/qtls/go116.go | 100 ++ internal/qtls/go117.go | 100 ++ internal/qtls/go118.go | 100 ++ internal/qtls/go119.go | 6 + internal/qtls/go_oldversion.go | 7 + internal/qtls/qtls_suite_test.go | 25 + internal/qtls/qtls_test.go | 17 + internal/quicvarint/io.go | 68 + internal/quicvarint/io_test.go | 115 ++ internal/quicvarint/quicvarint_suite_test.go | 13 + internal/quicvarint/varint.go | 139 ++ internal/quicvarint/varint_test.go | 221 +++ internal/testdata/ca.pem | 17 + internal/testdata/cert.go | 55 + internal/testdata/cert.pem | 18 + internal/testdata/cert_test.go | 31 + internal/testdata/generate_key.sh | 24 + internal/testdata/priv.key | 28 + internal/testdata/testdata_suite_test.go | 13 + internal/transport/transport.go | 2 - internal/utils/atomic_bool.go | 22 + internal/utils/atomic_bool_test.go | 29 + internal/utils/buffered_write_closer.go | 26 + internal/utils/buffered_write_closer_test.go | 26 + internal/utils/byteinterval_linkedlist.go | 217 +++ internal/utils/byteoder_big_endian_test.go | 107 ++ internal/utils/byteorder.go | 17 + internal/utils/byteorder_big_endian.go | 89 ++ internal/utils/gen.go | 5 + internal/utils/ip.go | 10 + internal/utils/ip_test.go | 17 + internal/utils/linkedlist/README.md | 11 + internal/utils/linkedlist/linkedlist.go | 218 +++ internal/utils/log.go | 131 ++ internal/utils/log_test.go | 144 ++ internal/utils/minmax.go | 170 +++ internal/utils/minmax_test.go | 123 ++ internal/utils/new_connection_id.go | 12 + internal/utils/newconnectionid_linkedlist.go | 217 +++ internal/utils/packet_interval.go | 9 + internal/utils/packetinterval_linkedlist.go | 217 +++ internal/utils/rand.go | 29 + internal/utils/rand_test.go | 32 + internal/utils/rtt_stats.go | 127 ++ internal/utils/rtt_stats_test.go | 157 ++ internal/utils/streamframe_interval.go | 9 + internal/utils/timer.go | 53 + internal/utils/timer_test.go | 87 ++ internal/utils/utils_suite_test.go | 13 + internal/wire/ack_frame.go | 252 ++++ internal/wire/ack_frame_test.go | 454 ++++++ internal/wire/ack_range.go | 14 + internal/wire/ack_range_test.go | 13 + internal/wire/connection_close_frame.go | 84 ++ internal/wire/connection_close_frame_test.go | 153 ++ internal/wire/crypto_frame.go | 103 ++ internal/wire/crypto_frame_test.go | 148 ++ internal/wire/data_blocked_frame.go | 39 + internal/wire/data_blocked_frame_test.go | 54 + internal/wire/datagram_frame.go | 86 ++ internal/wire/datagram_frame_test.go | 154 ++ internal/wire/extended_header.go | 250 ++++ internal/wire/extended_header_test.go | 481 ++++++ internal/wire/frame_parser.go | 144 ++ internal/wire/frame_parser_test.go | 410 ++++++ internal/wire/handshake_done_frame.go | 29 + internal/wire/header.go | 275 ++++ internal/wire/header_test.go | 583 ++++++++ internal/wire/interface.go | 20 + internal/wire/log.go | 72 + internal/wire/log_test.go | 168 +++ internal/wire/max_data_frame.go | 41 + internal/wire/max_data_frame_test.go | 57 + internal/wire/max_stream_data_frame.go | 47 + internal/wire/max_stream_data_frame_test.go | 63 + internal/wire/max_streams_frame.go | 56 + internal/wire/max_streams_frame_test.go | 107 ++ internal/wire/new_connection_id_frame.go | 81 ++ internal/wire/new_connection_id_frame_test.go | 104 ++ internal/wire/new_token_frame.go | 49 + internal/wire/new_token_frame_test.go | 66 + internal/wire/path_challenge_frame.go | 39 + internal/wire/path_challenge_frame_test.go | 48 + internal/wire/path_response_frame.go | 39 + internal/wire/path_response_frame_test.go | 47 + internal/wire/ping_frame.go | 28 + internal/wire/ping_frame_test.go | 39 + internal/wire/pool.go | 33 + internal/wire/pool_test.go | 24 + internal/wire/reset_stream_frame.go | 59 + internal/wire/reset_stream_frame_test.go | 70 + internal/wire/retire_connection_id_frame.go | 37 + .../wire/retire_connection_id_frame_test.go | 53 + internal/wire/stop_sending_frame.go | 49 + internal/wire/stop_sending_frame_test.go | 63 + internal/wire/stream_data_blocked_frame.go | 47 + .../wire/stream_data_blocked_frame_test.go | 63 + internal/wire/stream_frame.go | 190 +++ internal/wire/stream_frame_test.go | 443 ++++++ internal/wire/streams_blocked_frame.go | 56 + internal/wire/streams_blocked_frame_test.go | 108 ++ internal/wire/transport_parameter_test.go | 612 ++++++++ internal/wire/transport_parameters.go | 476 ++++++ internal/wire/version_negotiation.go | 55 + internal/wire/version_negotiation_test.go | 83 ++ internal/wire/wire_suite_test.go | 31 + pkg/altsvc/altsvc.go | 51 + pkg/altsvc/jar.go | 6 + roundtrip.go | 9 +- tls.go | 35 - transport.go | 219 ++- transport_wrapper.go | 9 - 217 files changed, 27576 insertions(+), 187 deletions(-) create mode 100644 altsvc.go create mode 100644 internal/altsvcutil/altsvcutil.go create mode 100644 internal/handshake/aead.go create mode 100644 internal/handshake/aead_test.go create mode 100644 internal/handshake/crypto_setup.go create mode 100644 internal/handshake/crypto_setup_test.go create mode 100644 internal/handshake/handshake_suite_test.go create mode 100644 internal/handshake/header_protector.go create mode 100644 internal/handshake/hkdf.go create mode 100644 internal/handshake/hkdf_test.go create mode 100644 internal/handshake/initial_aead.go create mode 100644 internal/handshake/initial_aead_test.go create mode 100644 internal/handshake/interface.go create mode 100644 internal/handshake/mock_handshake_runner_test.go create mode 100644 internal/handshake/mockgen.go create mode 100644 internal/handshake/retry.go create mode 100644 internal/handshake/retry_test.go create mode 100644 internal/handshake/session_ticket.go create mode 100644 internal/handshake/session_ticket_test.go create mode 100644 internal/handshake/tls_extension_handler.go create mode 100644 internal/handshake/tls_extension_handler_test.go create mode 100644 internal/handshake/token_generator.go create mode 100644 internal/handshake/token_generator_test.go create mode 100644 internal/handshake/token_protector.go create mode 100644 internal/handshake/token_protector_test.go create mode 100644 internal/handshake/updatable_aead.go create mode 100644 internal/handshake/updatable_aead_test.go create mode 100644 internal/http3/body.go create mode 100644 internal/http3/body_test.go create mode 100644 internal/http3/client.go create mode 100644 internal/http3/client_test.go create mode 100644 internal/http3/error_codes.go create mode 100644 internal/http3/error_codes_test.go create mode 100644 internal/http3/frames.go create mode 100644 internal/http3/frames_test.go create mode 100644 internal/http3/gzip_reader.go create mode 100644 internal/http3/http3_suite_test.go create mode 100644 internal/http3/http_stream.go create mode 100644 internal/http3/http_stream_test.go create mode 100644 internal/http3/request.go create mode 100644 internal/http3/request_test.go create mode 100644 internal/http3/request_writer.go create mode 100644 internal/http3/request_writer_test.go create mode 100644 internal/http3/response_writer.go create mode 100644 internal/http3/response_writer_test.go create mode 100644 internal/http3/roundtrip.go create mode 100644 internal/http3/roundtrip_test.go create mode 100644 internal/http3/server.go create mode 100644 internal/http3/server_test.go create mode 100644 internal/logging/frame.go create mode 100644 internal/logging/interface.go create mode 100644 internal/logging/logging_suite_test.go create mode 100644 internal/logging/mock_connection_tracer_test.go create mode 100644 internal/logging/mock_tracer_test.go create mode 100644 internal/logging/mockgen.go create mode 100644 internal/logging/multiplex.go create mode 100644 internal/logging/multiplex_test.go create mode 100644 internal/logging/packet_header.go create mode 100644 internal/logging/packet_header_test.go create mode 100644 internal/logging/types.go create mode 100644 internal/mocks/ackhandler/received_packet_handler.go create mode 100644 internal/mocks/ackhandler/sent_packet_handler.go create mode 100644 internal/mocks/congestion.go create mode 100644 internal/mocks/connection_flow_controller.go create mode 100644 internal/mocks/crypto_setup.go create mode 100644 internal/mocks/logging/connection_tracer.go create mode 100644 internal/mocks/logging/tracer.go create mode 100644 internal/mocks/long_header_opener.go create mode 100644 internal/mocks/mockgen.go create mode 100644 internal/mocks/quic/early_conn.go create mode 100644 internal/mocks/quic/early_listener.go create mode 100644 internal/mocks/quic/stream.go create mode 100644 internal/mocks/short_header_opener.go create mode 100644 internal/mocks/short_header_sealer.go create mode 100644 internal/mocks/stream_flow_controller.go create mode 100644 internal/mocks/tls/client_session_cache.go create mode 100644 internal/netutil/addr.go create mode 100644 internal/protocol/connection_id.go create mode 100644 internal/protocol/connection_id_test.go create mode 100644 internal/protocol/encryption_level.go create mode 100644 internal/protocol/encryption_level_test.go create mode 100644 internal/protocol/key_phase.go create mode 100644 internal/protocol/key_phase_test.go create mode 100644 internal/protocol/packet_number.go create mode 100644 internal/protocol/packet_number_test.go create mode 100644 internal/protocol/params.go create mode 100644 internal/protocol/params_test.go create mode 100644 internal/protocol/perspective.go create mode 100644 internal/protocol/perspective_test.go create mode 100644 internal/protocol/protocol.go create mode 100644 internal/protocol/protocol_suite_test.go create mode 100644 internal/protocol/protocol_test.go create mode 100644 internal/protocol/stream.go create mode 100644 internal/protocol/stream_test.go create mode 100644 internal/protocol/version.go create mode 100644 internal/protocol/version_test.go create mode 100644 internal/qerr/error_codes.go create mode 100644 internal/qerr/errorcodes_test.go create mode 100644 internal/qerr/errors.go create mode 100644 internal/qerr/errors_suite_test.go create mode 100644 internal/qerr/errors_test.go create mode 100644 internal/qtls/go116.go create mode 100644 internal/qtls/go117.go create mode 100644 internal/qtls/go118.go create mode 100644 internal/qtls/go119.go create mode 100644 internal/qtls/go_oldversion.go create mode 100644 internal/qtls/qtls_suite_test.go create mode 100644 internal/qtls/qtls_test.go create mode 100644 internal/quicvarint/io.go create mode 100644 internal/quicvarint/io_test.go create mode 100644 internal/quicvarint/quicvarint_suite_test.go create mode 100644 internal/quicvarint/varint.go create mode 100644 internal/quicvarint/varint_test.go create mode 100644 internal/testdata/ca.pem create mode 100644 internal/testdata/cert.go create mode 100644 internal/testdata/cert.pem create mode 100644 internal/testdata/cert_test.go create mode 100755 internal/testdata/generate_key.sh create mode 100644 internal/testdata/priv.key create mode 100644 internal/testdata/testdata_suite_test.go create mode 100644 internal/utils/atomic_bool.go create mode 100644 internal/utils/atomic_bool_test.go create mode 100644 internal/utils/buffered_write_closer.go create mode 100644 internal/utils/buffered_write_closer_test.go create mode 100644 internal/utils/byteinterval_linkedlist.go create mode 100644 internal/utils/byteoder_big_endian_test.go create mode 100644 internal/utils/byteorder.go create mode 100644 internal/utils/byteorder_big_endian.go create mode 100644 internal/utils/gen.go create mode 100644 internal/utils/ip.go create mode 100644 internal/utils/ip_test.go create mode 100644 internal/utils/linkedlist/README.md create mode 100644 internal/utils/linkedlist/linkedlist.go create mode 100644 internal/utils/log.go create mode 100644 internal/utils/log_test.go create mode 100644 internal/utils/minmax.go create mode 100644 internal/utils/minmax_test.go create mode 100644 internal/utils/new_connection_id.go create mode 100644 internal/utils/newconnectionid_linkedlist.go create mode 100644 internal/utils/packet_interval.go create mode 100644 internal/utils/packetinterval_linkedlist.go create mode 100644 internal/utils/rand.go create mode 100644 internal/utils/rand_test.go create mode 100644 internal/utils/rtt_stats.go create mode 100644 internal/utils/rtt_stats_test.go create mode 100644 internal/utils/streamframe_interval.go create mode 100644 internal/utils/timer.go create mode 100644 internal/utils/timer_test.go create mode 100644 internal/utils/utils_suite_test.go create mode 100644 internal/wire/ack_frame.go create mode 100644 internal/wire/ack_frame_test.go create mode 100644 internal/wire/ack_range.go create mode 100644 internal/wire/ack_range_test.go create mode 100644 internal/wire/connection_close_frame.go create mode 100644 internal/wire/connection_close_frame_test.go create mode 100644 internal/wire/crypto_frame.go create mode 100644 internal/wire/crypto_frame_test.go create mode 100644 internal/wire/data_blocked_frame.go create mode 100644 internal/wire/data_blocked_frame_test.go create mode 100644 internal/wire/datagram_frame.go create mode 100644 internal/wire/datagram_frame_test.go create mode 100644 internal/wire/extended_header.go create mode 100644 internal/wire/extended_header_test.go create mode 100644 internal/wire/frame_parser.go create mode 100644 internal/wire/frame_parser_test.go create mode 100644 internal/wire/handshake_done_frame.go create mode 100644 internal/wire/header.go create mode 100644 internal/wire/header_test.go create mode 100644 internal/wire/interface.go create mode 100644 internal/wire/log.go create mode 100644 internal/wire/log_test.go create mode 100644 internal/wire/max_data_frame.go create mode 100644 internal/wire/max_data_frame_test.go create mode 100644 internal/wire/max_stream_data_frame.go create mode 100644 internal/wire/max_stream_data_frame_test.go create mode 100644 internal/wire/max_streams_frame.go create mode 100644 internal/wire/max_streams_frame_test.go create mode 100644 internal/wire/new_connection_id_frame.go create mode 100644 internal/wire/new_connection_id_frame_test.go create mode 100644 internal/wire/new_token_frame.go create mode 100644 internal/wire/new_token_frame_test.go create mode 100644 internal/wire/path_challenge_frame.go create mode 100644 internal/wire/path_challenge_frame_test.go create mode 100644 internal/wire/path_response_frame.go create mode 100644 internal/wire/path_response_frame_test.go create mode 100644 internal/wire/ping_frame.go create mode 100644 internal/wire/ping_frame_test.go create mode 100644 internal/wire/pool.go create mode 100644 internal/wire/pool_test.go create mode 100644 internal/wire/reset_stream_frame.go create mode 100644 internal/wire/reset_stream_frame_test.go create mode 100644 internal/wire/retire_connection_id_frame.go create mode 100644 internal/wire/retire_connection_id_frame_test.go create mode 100644 internal/wire/stop_sending_frame.go create mode 100644 internal/wire/stop_sending_frame_test.go create mode 100644 internal/wire/stream_data_blocked_frame.go create mode 100644 internal/wire/stream_data_blocked_frame_test.go create mode 100644 internal/wire/stream_frame.go create mode 100644 internal/wire/stream_frame_test.go create mode 100644 internal/wire/streams_blocked_frame.go create mode 100644 internal/wire/streams_blocked_frame_test.go create mode 100644 internal/wire/transport_parameter_test.go create mode 100644 internal/wire/transport_parameters.go create mode 100644 internal/wire/version_negotiation.go create mode 100644 internal/wire/version_negotiation_test.go create mode 100644 internal/wire/wire_suite_test.go create mode 100644 pkg/altsvc/altsvc.go create mode 100644 pkg/altsvc/jar.go delete mode 100644 tls.go diff --git a/altsvc.go b/altsvc.go new file mode 100644 index 00000000..c03dd56f --- /dev/null +++ b/altsvc.go @@ -0,0 +1,16 @@ +package req + +import ( + "github.com/imroc/req/v3/pkg/altsvc" + "net/http" + "sync" + "time" +) + +type pendingAltSvc struct { + CurrentIndex int + Entries []*altsvc.AltSvc + Mu sync.Mutex + LastTime time.Time + Transport http.RoundTripper +} diff --git a/client.go b/client.go index 81856e88..b11d8b9a 100644 --- a/client.go +++ b/client.go @@ -874,6 +874,14 @@ func (c *Client) SetUnixSocket(file string) *Client { }) } +func (c *Client) EnableHttp3() *Client { + err := c.t.enableH3() + if err != nil { + c.log.Errorf("failed to enabled http3: %s", err.Error()) + } + return c +} + // NewClient is the alias of C func NewClient() *Client { return C() @@ -882,7 +890,9 @@ func NewClient() *Client { // Clone copy and returns the Client func (c *Client) Clone() *Client { t := c.t.Clone() - t2, _ := http2.ConfigureTransports(transportImpl{t}) + t2 := &http2.Transport{ + Interface: transportImpl{t}, + } t.t2 = t2 client := *c.httpClient @@ -922,8 +932,11 @@ func C() *Client { IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{NextProtos: []string{"http/1.1", "h2"}}, + } + t2 := &http2.Transport{ + Interface: transportImpl{t}, } - t2, _ := http2.ConfigureTransports(transportImpl{t}) t.t2 = t2 jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) diff --git a/go.mod b/go.mod index 45abb226..05f85d36 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,21 @@ module github.com/imroc/req/v3 go 1.15 require ( + github.com/fsnotify/fsnotify v1.5.4 // indirect + github.com/golang/mock v1.6.0 + github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 - golang.org/x/net v0.0.0-20220111093109-d55c255bac03 + github.com/lucas-clemente/quic-go v0.27.2 + github.com/marten-seemann/qpack v0.2.1 + github.com/marten-seemann/qtls-go1-16 v0.1.5 + github.com/marten-seemann/qtls-go1-17 v0.1.2 + github.com/marten-seemann/qtls-go1-18 v0.1.2 + github.com/onsi/ginkgo v1.16.5 + github.com/onsi/gomega v1.13.0 + golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect + golang.org/x/net v0.0.0-20220615171555-694bf12d69de + golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c // indirect golang.org/x/text v0.3.7 + golang.org/x/tools v0.1.11 // indirect + golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f // indirect ) diff --git a/go.sum b/go.sum index 59f8d4e3..40a0d90f 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,310 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= +dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= +dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= +dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= +dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= +git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= +github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= +github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= +github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= +github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= +github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= +github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= -golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lucas-clemente/quic-go v0.27.2 h1:zsMwwniyybb8B/UDNXRSYee7WpQJVOcjQEGgpw2ikXs= +github.com/lucas-clemente/quic-go v0.27.2/go.mod h1:vXgO/11FBSKM+js1NxoaQ/bPtVFYfB7uxhfHXyMhl1A= +github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= +github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= +github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= +github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= +github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= +github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= +github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= +github.com/marten-seemann/qtls-go1-18 v0.1.2 h1:JH6jmzbduz0ITVQ7ShevK10Av5+jBEKAHMntXmIV7kM= +github.com/marten-seemann/qtls-go1-18 v0.1.2/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= +github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= +github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= +github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= +github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= +github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= +github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= +github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= +github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= +github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= +github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= +github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= +github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= +github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= +github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= +github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= +github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= +github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= +github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= +github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= +github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= +github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= +github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= +github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= +github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= +github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= +github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= +github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= +golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= +golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= +golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220615171555-694bf12d69de h1:ogOG2+P6LjO2j55AkRScrkB2BFpd+Z8TY2wcM0Z3MGo= +golang.org/x/net v0.0.0-20220615171555-694bf12d69de/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c h1:aFV+BgZ4svzjfabn8ERpuB4JI4N6/rdy1iusx77G3oU= +golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.11 h1:loJ25fNOEhSXfHrpoGj91eCUThwdNX6u24rO1xnNteY= +golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f h1:uF6paiQQebLeSXkrTqHqz0MXhXXS1KgF41eUdBNvxK0= +golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= +google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= +sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= diff --git a/internal/altsvcutil/altsvcutil.go b/internal/altsvcutil/altsvcutil.go new file mode 100644 index 00000000..8556c7cb --- /dev/null +++ b/internal/altsvcutil/altsvcutil.go @@ -0,0 +1,212 @@ +package altsvcutil + +import ( + "bytes" + "fmt" + "github.com/imroc/req/v3/internal/netutil" + "github.com/imroc/req/v3/pkg/altsvc" + "io" + "net" + "net/url" + "strconv" + "strings" + "time" +) + +type altAvcParser struct { + *bytes.Buffer +} + +// validOptionalPort reports whether port is either an empty string +// or matches /^:\d*$/ +func validOptionalPort(port string) bool { + if port == "" { + return true + } + if port[0] != ':' { + return false + } + for _, b := range port[1:] { + if b < '0' || b > '9' { + return false + } + } + return true +} + +// splitHostPort separates host and port. If the port is not valid, it returns +// the entire input as host, and it doesn't check the validity of the host. +// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. +func splitHostPort(hostPort string) (host, port string) { + host = hostPort + + colon := strings.LastIndexByte(host, ':') + if colon != -1 && validOptionalPort(host[colon:]) { + host, port = host[:colon], host[colon+1:] + } + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + host = host[1 : len(host)-1] + } + + return +} + +func ParseHeader(value string) ([]*altsvc.AltSvc, error) { + p := newAltSvcParser(value) + return p.Parse() +} + +func newAltSvcParser(value string) *altAvcParser { + buf := bytes.NewBufferString(value) + return &altAvcParser{buf} +} + +var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) + +func (p *altAvcParser) Parse() (as []*altsvc.AltSvc, err error) { + for { + a, e := p.parseOne() + if a != nil { + as = append(as, a) + } + if e != nil { + if e == io.EOF { + return + } else { + err = e + return + } + } + } + return +} + +func (p *altAvcParser) parseKv() (key, value string, haveNextField bool, err error) { + line, err := p.ReadBytes('=') + if len(line) == 0 { + return + } + key = strings.TrimSpace(string(line[:len(line)-1])) + bs := p.Bytes() + if len(bs) == 0 { + err = io.EOF + return + } + if bs[0] == '"' { + quoteIndex := 0 + for i := 1; i < len(bs); i++ { + if bs[i] == '"' { + quoteIndex = i + break + } + } + if quoteIndex == 0 { + err = fmt.Errorf("quote in alt-svc is not complete: %s", bs) + return + } + value = string(bs[1:quoteIndex]) + p.Next(quoteIndex + 1) + if len(bs) == quoteIndex+1 { + err = io.EOF + return + } + var b byte + b, err = p.ReadByte() + if err != nil { + return + } + if b == ';' { + haveNextField = true + } + } else { + delimIndex := 0 + LOOP: + for i, v := range bs { + switch v { + case ',': + delimIndex = i + break LOOP + case ';': + delimIndex = i + haveNextField = true + break LOOP + } + } + if delimIndex == 0 { + err = io.EOF + value = strings.TrimSpace(string(bs)) + return + } + p.Next(delimIndex + 1) + value = string(bs[:delimIndex]) + } + return +} + +func (p *altAvcParser) parseOne() (as *altsvc.AltSvc, err error) { + proto, addr, haveNextField, err := p.parseKv() + if proto == "" || addr == "" { + return + } + host, port := splitHostPort(addr) + + as = &altsvc.AltSvc{ + Protocol: proto, + Host: host, + Port: port, + Expire: endOfTime, + } + + if !haveNextField { + return + } + + key, ma, haveNextField, err := p.parseKv() + if key == "" || ma == "" { + return + } + if key != "ma" { + err = fmt.Errorf("expect ma field, got %s", key) + return + } + + maInt, err := strconv.ParseInt(ma, 10, 64) + if err != nil { + return + } + as.Expire = time.Now().Add(time.Duration(maInt) * time.Second) + + if !haveNextField { + return + } + + // drain useless fields + for { + _, _, haveNextField, err = p.parseKv() + if haveNextField { + continue + } else { + break + } + } + return +} + +func ConvertURL(a *altsvc.AltSvc, u *url.URL) *url.URL { + host, port := netutil.AuthorityHostPort(u.Scheme, u.Host) + uu := *u + modify := false + if a.Host != "" && a.Host != host { + host = a.Host + modify = true + } + if a.Port != "" && a.Port != port { + port = a.Port + modify = true + } + if modify { + uu.Host = net.JoinHostPort(host, port) + } + return &uu +} diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go new file mode 100644 index 00000000..1b8c28f8 --- /dev/null +++ b/internal/handshake/aead.go @@ -0,0 +1,162 @@ +package handshake + +import ( + "crypto/cipher" + "encoding/binary" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qtls" + "github.com/imroc/req/v3/internal/utils" +) + +func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v quic.VersionNumber) cipher.AEAD { + keyLabel := hkdfLabelKeyV1 + ivLabel := hkdfLabelIVV1 + if v == protocol.Version2 { + keyLabel = hkdfLabelKeyV2 + ivLabel = hkdfLabelIVV2 + } + key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen) + iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen()) + return suite.AEAD(key, iv) +} + +type longHeaderSealer struct { + aead cipher.AEAD + headerProtector headerProtector + + // use a single slice to avoid allocations + nonceBuf []byte +} + +var _ LongHeaderSealer = &longHeaderSealer{} + +func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer { + return &longHeaderSealer{ + aead: aead, + headerProtector: headerProtector, + nonceBuf: make([]byte, aead.NonceSize()), + } +} + +func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + return s.aead.Seal(dst, s.nonceBuf, src, ad) +} + +func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + s.headerProtector.EncryptHeader(sample, firstByte, pnBytes) +} + +func (s *longHeaderSealer) Overhead() int { + return s.aead.Overhead() +} + +type longHeaderOpener struct { + aead cipher.AEAD + headerProtector headerProtector + highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) + + // use a single slice to avoid allocations + nonceBuf []byte +} + +var _ LongHeaderOpener = &longHeaderOpener{} + +func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener { + return &longHeaderOpener{ + aead: aead, + headerProtector: headerProtector, + nonceBuf: make([]byte, aead.NonceSize()), + } +} + +func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { + return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN) +} + +func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { + binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + dec, err := o.aead.Open(dst, o.nonceBuf, src, ad) + if err == nil { + o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn) + } else { + err = ErrDecryptionFailed + } + return dec, err +} + +func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + o.headerProtector.DecryptHeader(sample, firstByte, pnBytes) +} + +type handshakeSealer struct { + LongHeaderSealer + + dropInitialKeys func() + dropped bool +} + +func newHandshakeSealer( + aead cipher.AEAD, + headerProtector headerProtector, + dropInitialKeys func(), + perspective protocol.Perspective, +) LongHeaderSealer { + sealer := newLongHeaderSealer(aead, headerProtector) + // The client drops Initial keys when sending the first Handshake packet. + if perspective == protocol.PerspectiveServer { + return sealer + } + return &handshakeSealer{ + LongHeaderSealer: sealer, + dropInitialKeys: dropInitialKeys, + } +} + +func (s *handshakeSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + data := s.LongHeaderSealer.Seal(dst, src, pn, ad) + if !s.dropped { + s.dropInitialKeys() + s.dropped = true + } + return data +} + +type handshakeOpener struct { + LongHeaderOpener + + dropInitialKeys func() + dropped bool +} + +func newHandshakeOpener( + aead cipher.AEAD, + headerProtector headerProtector, + dropInitialKeys func(), + perspective protocol.Perspective, +) LongHeaderOpener { + opener := newLongHeaderOpener(aead, headerProtector) + // The server drops Initial keys when first successfully processing a Handshake packet. + if perspective == protocol.PerspectiveClient { + return opener + } + return &handshakeOpener{ + LongHeaderOpener: opener, + dropInitialKeys: dropInitialKeys, + } +} + +func (o *handshakeOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { + dec, err := o.LongHeaderOpener.Open(dst, src, pn, ad) + if err == nil && !o.dropped { + o.dropInitialKeys() + o.dropped = true + } + return dec, err +} diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go new file mode 100644 index 00000000..ae406f10 --- /dev/null +++ b/internal/handshake/aead_test.go @@ -0,0 +1,204 @@ +package handshake + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/tls" + "fmt" + + "github.com/imroc/req/v3/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Long Header AEAD", func() { + for _, ver := range []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { + v := ver + + Context(fmt.Sprintf("using version %s", v), func() { + for i := range cipherSuites { + cs := cipherSuites[i] + + Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { + getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { + key := make([]byte, 16) + hpKey := make([]byte, 16) + rand.Read(key) + rand.Read(hpKey) + block, err := aes.NewCipher(key) + Expect(err).ToNot(HaveOccurred()) + aead, err := cipher.NewGCM(block) + Expect(err).ToNot(HaveOccurred()) + + return newLongHeaderSealer(aead, newHeaderProtector(cs, hpKey, true, v)), + newLongHeaderOpener(aead, newHeaderProtector(cs, hpKey, true, v)) + } + + Context("message encryption", func() { + msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad := []byte("Donec in velit neque.") + + It("encrypts and decrypts a message", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + opened, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + + It("fails to open a message if the associated data is not the same", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("fails to open a message if the packet number is not the same", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x42, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("decodes the packet number", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) + }) + + It("ignores packets it can't decrypt for packet number derivation", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted[:len(encrypted)-1], 0x1337, ad) + Expect(err).To(HaveOccurred()) + Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) + }) + }) + + Context("header encryption", func() { + It("encrypts and encrypts the header", func() { + sealer, opener := getSealerAndOpener() + var lastFourBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sealer.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0xf != 0xb5&0xf { + lastFourBitsDifferent++ + } + Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) + }) + + It("encrypts and encrypts the header, for a 0xfff..fff sample", func() { + sealer, opener := getSealerAndOpener() + var lastFourBitsDifferent int + for i := 0; i < 100; i++ { + sample := bytes.Repeat([]byte{0xff}, 16) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sealer.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0xf != 0xb5&0xf { + lastFourBitsDifferent++ + } + Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + }) + + It("fails to decrypt the header when using a different sample", func() { + sealer, opener := getSealerAndOpener() + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sample := make([]byte, 16) + rand.Read(sample) + sealer.EncryptHeader(sample, &header[0], header[9:13]) + rand.Read(sample) // use a different sample + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + }) + }) + }) + } + }) + + Describe("Long Header AEAD", func() { + var ( + dropped chan struct{} // use a chan because closing it twice will panic + aead cipher.AEAD + hp headerProtector + ) + dropCb := func() { close(dropped) } + msg := []byte("Lorem ipsum dolor sit amet.") + ad := []byte("Donec in velit neque.") + + BeforeEach(func() { + dropped = make(chan struct{}) + key := make([]byte, 16) + hpKey := make([]byte, 16) + rand.Read(key) + rand.Read(hpKey) + block, err := aes.NewCipher(key) + Expect(err).ToNot(HaveOccurred()) + aead, err = cipher.NewGCM(block) + Expect(err).ToNot(HaveOccurred()) + hp = newHeaderProtector(cipherSuites[0], hpKey, true, protocol.Version1) + }) + + Context("for the server", func() { + It("drops keys when first successfully processing a Handshake packet", func() { + serverOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveServer) + // first try to open an invalid message + _, err := serverOpener.Open(nil, []byte("invalid"), 0, []byte("invalid")) + Expect(err).To(HaveOccurred()) + Expect(dropped).ToNot(BeClosed()) + // then open a valid message + enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 10, ad) + _, err = serverOpener.Open(nil, enc, 10, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(dropped).To(BeClosed()) + // now open the same message again to make sure the callback is only called once + _, err = serverOpener.Open(nil, enc, 10, ad) + Expect(err).ToNot(HaveOccurred()) + }) + + It("doesn't drop keys when sealing a Handshake packet", func() { + serverSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveServer) + serverSealer.Seal(nil, msg, 1, ad) + Expect(dropped).ToNot(BeClosed()) + }) + }) + + Context("for the client", func() { + It("drops keys when first sealing a Handshake packet", func() { + clientSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveClient) + // seal the first message + clientSealer.Seal(nil, msg, 1, ad) + Expect(dropped).To(BeClosed()) + // seal another message to make sure the callback is only called once + clientSealer.Seal(nil, msg, 2, ad) + }) + + It("doesn't drop keys when processing a Handshake packet", func() { + enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 42, ad) + clientOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveClient) + _, err := clientOpener.Open(nil, enc, 42, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(dropped).ToNot(BeClosed()) + }) + }) + }) + } +}) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go new file mode 100644 index 00000000..ed70ec57 --- /dev/null +++ b/internal/handshake/crypto_setup.go @@ -0,0 +1,820 @@ +package handshake + +import ( + "bytes" + "crypto/tls" + "errors" + "fmt" + "github.com/lucas-clemente/quic-go" + "io" + "net" + "sync" + "time" + + "github.com/imroc/req/v3/internal/logging" + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + "github.com/imroc/req/v3/internal/qtls" + "github.com/imroc/req/v3/internal/utils" + "github.com/imroc/req/v3/internal/wire" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// TLS unexpected_message alert +const alertUnexpectedMessage uint8 = 10 + +type messageType uint8 + +// TLS handshake message types. +const ( + typeClientHello messageType = 1 + typeServerHello messageType = 2 + typeNewSessionTicket messageType = 4 + typeEncryptedExtensions messageType = 8 + typeCertificate messageType = 11 + typeCertificateRequest messageType = 13 + typeCertificateVerify messageType = 15 + typeFinished messageType = 20 +) + +func (m messageType) String() string { + switch m { + case typeClientHello: + return "ClientHello" + case typeServerHello: + return "ServerHello" + case typeNewSessionTicket: + return "NewSessionTicket" + case typeEncryptedExtensions: + return "EncryptedExtensions" + case typeCertificate: + return "Certificate" + case typeCertificateRequest: + return "CertificateRequest" + case typeCertificateVerify: + return "CertificateVerify" + case typeFinished: + return "Finished" + default: + return fmt.Sprintf("unknown message type: %d", m) + } +} + +const clientSessionStateRevision = 3 + +type conn struct { + localAddr, remoteAddr net.Addr + version quic.VersionNumber +} + +var _ ConnWithVersion = &conn{} + +func newConn(local, remote net.Addr, version quic.VersionNumber) ConnWithVersion { + return &conn{ + localAddr: local, + remoteAddr: remote, + version: version, + } +} + +var _ net.Conn = &conn{} + +func (c *conn) Read([]byte) (int, error) { return 0, nil } +func (c *conn) Write([]byte) (int, error) { return 0, nil } +func (c *conn) Close() error { return nil } +func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *conn) LocalAddr() net.Addr { return c.localAddr } +func (c *conn) SetReadDeadline(time.Time) error { return nil } +func (c *conn) SetWriteDeadline(time.Time) error { return nil } +func (c *conn) SetDeadline(time.Time) error { return nil } +func (c *conn) GetQUICVersion() quic.VersionNumber { return c.version } + +type cryptoSetup struct { + tlsConf *tls.Config + extraConf *qtls.ExtraConfig + conn *qtls.Conn + + version quic.VersionNumber + + messageChan chan []byte + isReadingHandshakeMessage chan struct{} + readFirstHandshakeMessage bool + + ourParams *wire.TransportParameters + peerParams *wire.TransportParameters + paramsChan <-chan []byte + + runner handshakeRunner + + alertChan chan uint8 + // handshakeDone is closed as soon as the go routine running qtls.Handshake() returns + handshakeDone chan struct{} + // is closed when Close() is called + closeChan chan struct{} + + zeroRTTParameters *wire.TransportParameters + clientHelloWritten bool + clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written + zeroRTTParametersChan chan<- *wire.TransportParameters + + rttStats *utils.RTTStats + + tracer logging.ConnectionTracer + logger utils.Logger + + perspective protocol.Perspective + + mutex sync.Mutex // protects all members below + + handshakeCompleteTime time.Time + + readEncLevel protocol.EncryptionLevel + writeEncLevel protocol.EncryptionLevel + + zeroRTTOpener LongHeaderOpener // only set for the server + zeroRTTSealer LongHeaderSealer // only set for the client + + initialStream io.Writer + initialOpener LongHeaderOpener + initialSealer LongHeaderSealer + + handshakeStream io.Writer + handshakeOpener LongHeaderOpener + handshakeSealer LongHeaderSealer + + aead *updatableAEAD + has1RTTSealer bool + has1RTTOpener bool +} + +var ( + _ qtls.RecordLayer = &cryptoSetup{} + _ CryptoSetup = &cryptoSetup{} +) + +// NewCryptoSetupClient creates a new crypto setup for the client +func NewCryptoSetupClient( + initialStream io.Writer, + handshakeStream io.Writer, + connID protocol.ConnectionID, + localAddr net.Addr, + remoteAddr net.Addr, + tp *wire.TransportParameters, + runner handshakeRunner, + tlsConf *tls.Config, + enable0RTT bool, + rttStats *utils.RTTStats, + tracer logging.ConnectionTracer, + logger utils.Logger, + version quic.VersionNumber, +) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { + cs, clientHelloWritten := newCryptoSetup( + initialStream, + handshakeStream, + connID, + tp, + runner, + tlsConf, + enable0RTT, + rttStats, + tracer, + logger, + protocol.PerspectiveClient, + version, + ) + cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) + return cs, clientHelloWritten +} + +// NewCryptoSetupServer creates a new crypto setup for the server +func NewCryptoSetupServer( + initialStream io.Writer, + handshakeStream io.Writer, + connID protocol.ConnectionID, + localAddr net.Addr, + remoteAddr net.Addr, + tp *wire.TransportParameters, + runner handshakeRunner, + tlsConf *tls.Config, + enable0RTT bool, + rttStats *utils.RTTStats, + tracer logging.ConnectionTracer, + logger utils.Logger, + version quic.VersionNumber, +) CryptoSetup { + cs, _ := newCryptoSetup( + initialStream, + handshakeStream, + connID, + tp, + runner, + tlsConf, + enable0RTT, + rttStats, + tracer, + logger, + protocol.PerspectiveServer, + version, + ) + cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) + return cs +} + +func newCryptoSetup( + initialStream io.Writer, + handshakeStream io.Writer, + connID protocol.ConnectionID, + tp *wire.TransportParameters, + runner handshakeRunner, + tlsConf *tls.Config, + enable0RTT bool, + rttStats *utils.RTTStats, + tracer logging.ConnectionTracer, + logger utils.Logger, + perspective protocol.Perspective, + version quic.VersionNumber, +) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { + initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) + if tracer != nil { + tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) + tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) + } + extHandler := newExtensionHandler(tp.Marshal(perspective), perspective, version) + zeroRTTParametersChan := make(chan *wire.TransportParameters, 1) + cs := &cryptoSetup{ + tlsConf: tlsConf, + initialStream: initialStream, + initialSealer: initialSealer, + initialOpener: initialOpener, + handshakeStream: handshakeStream, + aead: newUpdatableAEAD(rttStats, tracer, logger, version), + readEncLevel: protocol.EncryptionInitial, + writeEncLevel: protocol.EncryptionInitial, + runner: runner, + ourParams: tp, + paramsChan: extHandler.TransportParameters(), + rttStats: rttStats, + tracer: tracer, + logger: logger, + perspective: perspective, + handshakeDone: make(chan struct{}), + alertChan: make(chan uint8), + clientHelloWrittenChan: make(chan struct{}), + zeroRTTParametersChan: zeroRTTParametersChan, + messageChan: make(chan []byte, 100), + isReadingHandshakeMessage: make(chan struct{}), + closeChan: make(chan struct{}), + version: version, + } + var maxEarlyData uint32 + if enable0RTT { + maxEarlyData = 0xffffffff + } + cs.extraConf = &qtls.ExtraConfig{ + GetExtensions: extHandler.GetExtensions, + ReceivedExtensions: extHandler.ReceivedExtensions, + AlternativeRecordLayer: cs, + EnforceNextProtoSelection: true, + MaxEarlyData: maxEarlyData, + Accept0RTT: cs.accept0RTT, + Rejected0RTT: cs.rejected0RTT, + Enable0RTT: enable0RTT, + GetAppDataForSessionState: cs.marshalDataForSessionState, + SetAppDataFromSessionState: cs.handleDataFromSessionState, + } + return cs, zeroRTTParametersChan +} + +func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { + initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version) + h.initialSealer = initialSealer + h.initialOpener = initialOpener + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) + h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) + } +} + +func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error { + return h.aead.SetLargestAcked(pn) +} + +func (h *cryptoSetup) RunHandshake() { + // Handle errors that might occur when HandleData() is called. + handshakeComplete := make(chan struct{}) + handshakeErrChan := make(chan error, 1) + go func() { + defer close(h.handshakeDone) + if err := h.conn.Handshake(); err != nil { + handshakeErrChan <- err + return + } + close(handshakeComplete) + }() + + if h.perspective == protocol.PerspectiveClient { + select { + case err := <-handshakeErrChan: + h.onError(0, err.Error()) + return + case <-h.clientHelloWrittenChan: + } + } + + select { + case <-handshakeComplete: // return when the handshake is done + h.mutex.Lock() + h.handshakeCompleteTime = time.Now() + h.mutex.Unlock() + h.runner.OnHandshakeComplete() + case <-h.closeChan: + // wait until the Handshake() go routine has returned + <-h.handshakeDone + case alert := <-h.alertChan: + handshakeErr := <-handshakeErrChan + h.onError(alert, handshakeErr.Error()) + } +} + +func (h *cryptoSetup) onError(alert uint8, message string) { + var err error + if alert == 0 { + err = &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: message} + } else { + err = qerr.NewCryptoError(alert, message) + } + h.runner.OnError(err) +} + +// Close closes the crypto setup. +// It aborts the handshake, if it is still running. +// It must only be called once. +func (h *cryptoSetup) Close() error { + close(h.closeChan) + // wait until qtls.Handshake() actually returned + <-h.handshakeDone + return nil +} + +// handleMessage handles a TLS handshake message. +// It is called by the crypto streams when a new message is available. +// It returns if it is done with messages on the same encryption level. +func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ { + msgType := messageType(data[0]) + h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) + if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { + h.onError(alertUnexpectedMessage, err.Error()) + return false + } + h.messageChan <- data + if encLevel == protocol.Encryption1RTT { + h.handlePostHandshakeMessage() + return false + } +readLoop: + for { + select { + case data := <-h.paramsChan: + if data == nil { + h.onError(0x6d, "missing quic_transport_parameters extension") + } else { + h.handleTransportParameters(data) + } + case <-h.isReadingHandshakeMessage: + break readLoop + case <-h.handshakeDone: + break readLoop + case <-h.closeChan: + break readLoop + } + } + // We're done with the Initial encryption level after processing a ClientHello / ServerHello, + // but only if a handshake opener and sealer was created. + // Otherwise, a HelloRetryRequest was performed. + // We're done with the Handshake encryption level after processing the Finished message. + return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) || + msgType == typeFinished +} + +func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { + var expected protocol.EncryptionLevel + switch msgType { + case typeClientHello, + typeServerHello: + expected = protocol.EncryptionInitial + case typeEncryptedExtensions, + typeCertificate, + typeCertificateRequest, + typeCertificateVerify, + typeFinished: + expected = protocol.EncryptionHandshake + case typeNewSessionTicket: + expected = protocol.Encryption1RTT + default: + return fmt.Errorf("unexpected handshake message: %d", msgType) + } + if encLevel != expected { + return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel) + } + return nil +} + +func (h *cryptoSetup) handleTransportParameters(data []byte) { + var tp wire.TransportParameters + if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { + h.runner.OnError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: err.Error(), + }) + } + h.peerParams = &tp + h.runner.OnReceivedParams(h.peerParams) +} + +// must be called after receiving the transport parameters +func (h *cryptoSetup) marshalDataForSessionState() []byte { + buf := &bytes.Buffer{} + quicvarint.Write(buf, clientSessionStateRevision) + quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds())) + h.peerParams.MarshalForSessionTicket(buf) + return buf.Bytes() +} + +func (h *cryptoSetup) handleDataFromSessionState(data []byte) { + tp, err := h.handleDataFromSessionStateImpl(data) + if err != nil { + h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) + return + } + h.zeroRTTParameters = tp +} + +func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) { + r := bytes.NewReader(data) + ver, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if ver != clientSessionStateRevision { + return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) + } + rtt, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) + var tp wire.TransportParameters + if err := tp.UnmarshalFromSessionTicket(r); err != nil { + return nil, err + } + return &tp, nil +} + +// only valid for the server +func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { + var appData []byte + // Save transport parameters to the session ticket if we're allowing 0-RTT. + if h.extraConf.MaxEarlyData > 0 { + appData = (&sessionTicket{ + Parameters: h.ourParams, + RTT: h.rttStats.SmoothedRTT(), + }).Marshal() + } + return h.conn.GetSessionTicket(appData) +} + +// accept0RTT is called for the server when receiving the client's session ticket. +// It decides whether to accept 0-RTT. +func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { + var t sessionTicket + if err := t.Unmarshal(sessionTicketData); err != nil { + h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error()) + return false + } + valid := h.ourParams.ValidFor0RTT(t.Parameters) + if valid { + h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) + h.rttStats.SetInitialRTT(t.RTT) + } else { + h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") + } + return valid +} + +// rejected0RTT is called for the client when the server rejects 0-RTT. +func (h *cryptoSetup) rejected0RTT() { + h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") + + h.mutex.Lock() + had0RTTKeys := h.zeroRTTSealer != nil + h.zeroRTTSealer = nil + h.mutex.Unlock() + + if had0RTTKeys { + h.runner.DropKeys(protocol.Encryption0RTT) + } +} + +func (h *cryptoSetup) handlePostHandshakeMessage() { + // make sure the handshake has already completed + <-h.handshakeDone + + done := make(chan struct{}) + defer close(done) + + // h.alertChan is an unbuffered channel. + // If an error occurs during conn.HandlePostHandshakeMessage, + // it will be sent on this channel. + // Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock. + alertChan := make(chan uint8, 1) + go func() { + <-h.isReadingHandshakeMessage + select { + case alert := <-h.alertChan: + alertChan <- alert + case <-done: + } + }() + + if err := h.conn.HandlePostHandshakeMessage(); err != nil { + select { + case <-h.closeChan: + case alert := <-alertChan: + h.onError(alert, err.Error()) + } + } +} + +// ReadHandshakeMessage is called by TLS. +// It blocks until a new handshake message is available. +func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) { + if !h.readFirstHandshakeMessage { + h.readFirstHandshakeMessage = true + } else { + select { + case h.isReadingHandshakeMessage <- struct{}{}: + case <-h.closeChan: + return nil, errors.New("error while handling the handshake message") + } + } + select { + case msg := <-h.messageChan: + return msg, nil + case <-h.closeChan: + return nil, errors.New("error while handling the handshake message") + } +} + +func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + h.mutex.Lock() + switch encLevel { + case qtls.Encryption0RTT: + if h.perspective == protocol.PerspectiveClient { + panic("Received 0-RTT read key for the client") + } + h.zeroRTTOpener = newLongHeaderOpener( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + ) + h.mutex.Unlock() + h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite()) + } + return + case qtls.EncryptionHandshake: + h.readEncLevel = protocol.EncryptionHandshake + h.handshakeOpener = newHandshakeOpener( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + h.dropInitialKeys, + h.perspective, + ) + h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + case qtls.EncryptionApplication: + h.readEncLevel = protocol.Encryption1RTT + h.aead.SetReadKey(suite, trafficSecret) + h.has1RTTOpener = true + h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + default: + panic("unexpected read encryption level") + } + h.mutex.Unlock() + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite()) + } +} + +func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + h.mutex.Lock() + switch encLevel { + case qtls.Encryption0RTT: + if h.perspective == protocol.PerspectiveServer { + panic("Received 0-RTT write key for the server") + } + h.zeroRTTSealer = newLongHeaderSealer( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + ) + h.mutex.Unlock() + h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) + } + return + case qtls.EncryptionHandshake: + h.writeEncLevel = protocol.EncryptionHandshake + h.handshakeSealer = newHandshakeSealer( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + h.dropInitialKeys, + h.perspective, + ) + h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + case qtls.EncryptionApplication: + h.writeEncLevel = protocol.Encryption1RTT + h.aead.SetWriteKey(suite, trafficSecret) + h.has1RTTSealer = true + h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.zeroRTTSealer != nil { + h.zeroRTTSealer = nil + h.logger.Debugf("Dropping 0-RTT keys.") + if h.tracer != nil { + h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) + } + } + default: + panic("unexpected write encryption level") + } + h.mutex.Unlock() + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective) + } +} + +// WriteRecord is called when TLS writes data +func (h *cryptoSetup) WriteRecord(p []byte) (int, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + //nolint:exhaustive // LS records can only be written for Initial and Handshake. + switch h.writeEncLevel { + case protocol.EncryptionInitial: + // assume that the first WriteRecord call contains the ClientHello + n, err := h.initialStream.Write(p) + if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient { + h.clientHelloWritten = true + close(h.clientHelloWrittenChan) + if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil { + h.logger.Debugf("Doing 0-RTT.") + h.zeroRTTParametersChan <- h.zeroRTTParameters + } else { + h.logger.Debugf("Not doing 0-RTT.") + h.zeroRTTParametersChan <- nil + } + } + return n, err + case protocol.EncryptionHandshake: + return h.handshakeStream.Write(p) + default: + panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel)) + } +} + +func (h *cryptoSetup) SendAlert(alert uint8) { + select { + case h.alertChan <- alert: + case <-h.closeChan: + // no need to send an alert when we've already closed + } +} + +// used a callback in the handshakeSealer and handshakeOpener +func (h *cryptoSetup) dropInitialKeys() { + h.mutex.Lock() + h.initialOpener = nil + h.initialSealer = nil + h.mutex.Unlock() + h.runner.DropKeys(protocol.EncryptionInitial) + h.logger.Debugf("Dropping Initial keys.") +} + +func (h *cryptoSetup) SetHandshakeConfirmed() { + h.aead.SetHandshakeConfirmed() + // drop Handshake keys + var dropped bool + h.mutex.Lock() + if h.handshakeOpener != nil { + h.handshakeOpener = nil + h.handshakeSealer = nil + dropped = true + } + h.mutex.Unlock() + if dropped { + h.runner.DropKeys(protocol.EncryptionHandshake) + h.logger.Debugf("Dropping Handshake keys.") + } +} + +func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.initialSealer == nil { + return nil, ErrKeysDropped + } + return h.initialSealer, nil +} + +func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.zeroRTTSealer == nil { + return nil, ErrKeysDropped + } + return h.zeroRTTSealer, nil +} + +func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.handshakeSealer == nil { + if h.initialSealer == nil { + return nil, ErrKeysDropped + } + return nil, ErrKeysNotYetAvailable + } + return h.handshakeSealer, nil +} + +func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if !h.has1RTTSealer { + return nil, ErrKeysNotYetAvailable + } + return h.aead, nil +} + +func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.initialOpener == nil { + return nil, ErrKeysDropped + } + return h.initialOpener, nil +} + +func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.zeroRTTOpener == nil { + if h.initialOpener != nil { + return nil, ErrKeysNotYetAvailable + } + // if the initial opener is also not available, the keys were already dropped + return nil, ErrKeysDropped + } + return h.zeroRTTOpener, nil +} + +func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.handshakeOpener == nil { + if h.initialOpener != nil { + return nil, ErrKeysNotYetAvailable + } + // if the initial opener is also not available, the keys were already dropped + return nil, ErrKeysDropped + } + return h.handshakeOpener, nil +} + +func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { + h.zeroRTTOpener = nil + h.logger.Debugf("Dropping 0-RTT keys.") + if h.tracer != nil { + h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) + } + } + + if !h.has1RTTOpener { + return nil, ErrKeysNotYetAvailable + } + return h.aead, nil +} + +func (h *cryptoSetup) ConnectionState() ConnectionState { + return qtls.GetConnectionState(h.conn) +} diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go new file mode 100644 index 00000000..7d592428 --- /dev/null +++ b/internal/handshake/crypto_setup_test.go @@ -0,0 +1,864 @@ +package handshake + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "time" + + mocktls "github.com/imroc/req/v3/internal/mocks/tls" + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + "github.com/imroc/req/v3/internal/testdata" + "github.com/imroc/req/v3/internal/utils" + "github.com/imroc/req/v3/internal/wire" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, + 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, + 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +} + +type chunk struct { + data []byte + encLevel protocol.EncryptionLevel +} + +type stream struct { + encLevel protocol.EncryptionLevel + chunkChan chan<- chunk +} + +func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream { + return &stream{ + chunkChan: chunkChan, + encLevel: encLevel, + } +} + +func (s *stream) Write(b []byte) (int, error) { + data := make([]byte, len(b)) + copy(data, b) + select { + case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}: + default: + panic("chunkChan too small") + } + return len(b), nil +} + +var _ = Describe("Crypto Setup TLS", func() { + var clientConf, serverConf *tls.Config + + // unparam incorrectly complains that the first argument is never used. + //nolint:unparam + initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { + chunkChan := make(chan chunk, 100) + initialStream := newStream(chunkChan, protocol.EncryptionInitial) + handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake) + return chunkChan, initialStream, handshakeStream + } + + BeforeEach(func() { + serverConf = testdata.GetTLSConfig() + serverConf.NextProtos = []string{"crypto-setup"} + clientConf = &tls.Config{ + ServerName: "localhost", + RootCAs: testdata.GetRootCA(), + NextProtos: []string{"crypto-setup"}, + } + }) + + It("returns Handshake() when an error occurs in qtls", func() { + sErrChan := make(chan error, 1) + runner := NewMockHandshakeRunner(mockCtrl) + runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) + _, sInitialStream, sHandshakeStream := initStreams() + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + runner, + testdata.GetTLSConfig(), + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + server.RunHandshake() + Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ + ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), + ErrorMessage: "local error: tls: unexpected message", + }))) + close(done) + }() + + fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) + handledMessage := make(chan struct{}) + go func() { + defer GinkgoRecover() + server.HandleMessage(fakeCH, protocol.EncryptionInitial) + close(handledMessage) + }() + Eventually(handledMessage).Should(BeClosed()) + Eventually(done).Should(BeClosed()) + }) + + It("handles qtls errors occurring before during ClientHello generation", func() { + sErrChan := make(chan error, 1) + runner := NewMockHandshakeRunner(mockCtrl) + runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) + _, sInitialStream, sHandshakeStream := initStreams() + tlsConf := testdata.GetTLSConfig() + tlsConf.InsecureSkipVerify = true + tlsConf.NextProtos = []string{""} + cl, _ := NewCryptoSetupClient( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{}, + runner, + tlsConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + cl.RunHandshake() + close(done) + }() + + Eventually(done).Should(BeClosed()) + Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.InternalError, + ErrorMessage: "tls: invalid NextProtos value", + }))) + }) + + It("errors when a message is received at the wrong encryption level", func() { + sErrChan := make(chan error, 1) + _, sInitialStream, sHandshakeStream := initStreams() + runner := NewMockHandshakeRunner(mockCtrl) + runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + runner, + testdata.GetTLSConfig(), + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + server.RunHandshake() + close(done) + }() + + fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) + server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level + Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ + ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), + ErrorMessage: "expected handshake message ClientHello to have encryption level Initial, has Handshake", + }))) + + // make the go routine return + Expect(server.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("returns Handshake() when handling a message fails", func() { + sErrChan := make(chan error, 1) + _, sInitialStream, sHandshakeStream := initStreams() + runner := NewMockHandshakeRunner(mockCtrl) + runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + runner, + serverConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + server.RunHandshake() + var err error + Expect(sErrChan).To(Receive(&err)) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) + close(done) + }() + + fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...) + server.HandleMessage(fakeCH, protocol.EncryptionInitial) // wrong encryption level + Eventually(done).Should(BeClosed()) + }) + + It("returns Handshake() when it is closed", func() { + _, sInitialStream, sHandshakeStream := initStreams() + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + NewMockHandshakeRunner(mockCtrl), + serverConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + server.RunHandshake() + close(done) + }() + Expect(server.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + Context("doing the handshake", func() { + generateCert := func() tls.Certificate { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{}, + SignatureAlgorithm: x509.SHA256WithRSA, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), // valid for an hour + BasicConstraintsValid: true, + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) + Expect(err).ToNot(HaveOccurred()) + return tls.Certificate{ + PrivateKey: priv, + Certificate: [][]byte{certDER}, + } + } + + newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats { + rttStats := &utils.RTTStats{} + rttStats.UpdateRTT(rtt, 0, time.Now()) + ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt)) + return rttStats + } + + handshake := func(client CryptoSetup, cChunkChan <-chan chunk, + server CryptoSetup, sChunkChan <-chan chunk, + ) { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + for { + select { + case c := <-cChunkChan: + msgType := messageType(c.data[0]) + finished := server.HandleMessage(c.data, c.encLevel) + if msgType == typeFinished { + Expect(finished).To(BeTrue()) + } else if msgType == typeClientHello { + // If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys. + _, err := server.GetHandshakeOpener() + Expect(finished).To(Equal(err == nil)) + } else { + Expect(finished).To(BeFalse()) + } + case c := <-sChunkChan: + msgType := messageType(c.data[0]) + finished := client.HandleMessage(c.data, c.encLevel) + if msgType == typeFinished { + Expect(finished).To(BeTrue()) + } else if msgType == typeServerHello { + Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom))) + } else { + Expect(finished).To(BeFalse()) + } + case <-done: // handshake complete + return + } + } + }() + + go func() { + defer GinkgoRecover() + defer close(done) + server.RunHandshake() + ticket, err := server.GetSessionTicket() + Expect(err).ToNot(HaveOccurred()) + if ticket != nil { + client.HandleMessage(ticket, protocol.Encryption1RTT) + } + }() + + client.RunHandshake() + Eventually(done).Should(BeClosed()) + } + + handshakeWithTLSConf := func( + clientConf, serverConf *tls.Config, + clientRTTStats, serverRTTStats *utils.RTTStats, + clientTransportParameters, serverTransportParameters *wire.TransportParameters, + enable0RTT bool, + ) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { + var cHandshakeComplete bool + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cErrChan := make(chan error, 1) + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1) + cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1) + cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1) + client, clientHelloWrittenChan := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + clientTransportParameters, + cRunner, + clientConf, + enable0RTT, + clientRTTStats, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + var sHandshakeComplete bool + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sErrChan := make(chan error, 1) + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1) + sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1) + if serverTransportParameters.StatelessResetToken == nil { + var token protocol.StatelessResetToken + serverTransportParameters.StatelessResetToken = &token + } + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + serverTransportParameters, + sRunner, + serverConf, + enable0RTT, + serverRTTStats, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + handshake(client, cChunkChan, server, sChunkChan) + var cErr, sErr error + select { + case sErr = <-sErrChan: + default: + Expect(sHandshakeComplete).To(BeTrue()) + } + select { + case cErr = <-cErrChan: + default: + Expect(cHandshakeComplete).To(BeTrue()) + } + return clientHelloWrittenChan, client, cErr, server, sErr + } + + It("handshakes", func() { + _, _, clientErr, _, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + }) + + It("performs a HelloRetryRequst", func() { + serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} + _, _, clientErr, _, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + }) + + It("handshakes with client auth", func() { + clientConf.Certificates = []tls.Certificate{generateCert()} + serverConf.ClientAuth = tls.RequireAnyClientCert + _, _, clientErr, _, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + }) + + It("signals when it has written the ClientHello", func() { + runner := NewMockHandshakeRunner(mockCtrl) + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + client, chChan := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{}, + runner, + &tls.Config{InsecureSkipVerify: true}, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + client.RunHandshake() + close(done) + }() + var ch chunk + Eventually(cChunkChan).Should(Receive(&ch)) + Eventually(chChan).Should(Receive(BeNil())) + // make sure the whole ClientHello was written + Expect(len(ch.data)).To(BeNumerically(">=", 4)) + Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) + length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) + Expect(len(ch.data) - 4).To(Equal(length)) + + // make the go routine return + Expect(client.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("receives transport parameters", func() { + var cTransportParametersRcvd, sTransportParametersRcvd *wire.TransportParameters + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cTransportParameters := &wire.TransportParameters{MaxIdleTimeout: 0x42 * time.Second} + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp }) + cRunner.EXPECT().OnHandshakeComplete() + client, _ := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + cTransportParameters, + cRunner, + clientConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + var token protocol.StatelessResetToken + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp }) + sRunner.EXPECT().OnHandshakeComplete() + sTransportParameters := &wire.TransportParameters{ + MaxIdleTimeout: 0x1337 * time.Second, + StatelessResetToken: &token, + } + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + sTransportParameters, + sRunner, + serverConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handshake(client, cChunkChan, server, sChunkChan) + close(done) + }() + Eventually(done).Should(BeClosed()) + Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout)) + Expect(sTransportParametersRcvd).ToNot(BeNil()) + Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout)) + }) + + Context("with session tickets", func() { + It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnHandshakeComplete() + client, _ := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{}, + cRunner, + clientConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnHandshakeComplete() + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + sRunner, + serverConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handshake(client, cChunkChan, server, sChunkChan) + close(done) + }() + Eventually(done).Should(BeClosed()) + + // inject an invalid session ticket + cRunner.EXPECT().OnError(&qerr.TransportError{ + ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), + ErrorMessage: "expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake", + }) + b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) + client.HandleMessage(b, protocol.EncryptionHandshake) + }) + + It("errors when handling the NewSessionTicket fails", func() { + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnHandshakeComplete() + client, _ := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{}, + cRunner, + clientConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnHandshakeComplete() + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + sRunner, + serverConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handshake(client, cChunkChan, server, sChunkChan) + close(done) + }() + Eventually(done).Should(BeClosed()) + + // inject an invalid session ticket + cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) + }) + b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) + client.HandleMessage(b, protocol.Encryption1RTT) + }) + + It("uses session resumption", func() { + csc := mocktls.NewMockClientSessionCache(mockCtrl) + var state *tls.ClientSessionState + receivedSessionTicket := make(chan struct{}) + csc.EXPECT().Get(gomock.Any()) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + close(receivedSessionTicket) + }) + clientConf.ClientSessionCache = csc + const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. + clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) + clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + clientOrigRTTStats, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + Expect(clientHelloWrittenChan).To(Receive(BeNil())) + + csc.EXPECT().Get(gomock.Any()).Return(state, true) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) + clientRTTStats := &utils.RTTStats{} + clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( + clientConf, serverConf, + clientRTTStats, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeTrue()) + Expect(client.ConnectionState().DidResume).To(BeTrue()) + Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) + Expect(clientHelloWrittenChan).To(Receive(BeNil())) + }) + + It("doesn't use session resumption if the server disabled it", func() { + csc := mocktls.NewMockClientSessionCache(mockCtrl) + var state *tls.ClientSessionState + receivedSessionTicket := make(chan struct{}) + csc.EXPECT().Get(gomock.Any()) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + close(receivedSessionTicket) + }) + clientConf.ClientSessionCache = csc + _, client, clientErr, server, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + + serverConf.SessionTicketsDisabled = true + csc.EXPECT().Get(gomock.Any()).Return(state, true) + _, client, clientErr, server, serverErr = handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + }) + + It("uses 0-RTT", func() { + csc := mocktls.NewMockClientSessionCache(mockCtrl) + var state *tls.ClientSessionState + receivedSessionTicket := make(chan struct{}) + csc.EXPECT().Get(gomock.Any()) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + close(receivedSessionTicket) + }) + clientConf.ClientSessionCache = csc + const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. + const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. + serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) + clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) + const initialMaxData protocol.ByteCount = 1337 + clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + clientOrigRTTStats, serverOrigRTTStats, + &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, + true, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + Expect(clientHelloWrittenChan).To(Receive(BeNil())) + + csc.EXPECT().Get(gomock.Any()).Return(state, true) + csc.EXPECT().Put(gomock.Any(), nil) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) + + clientRTTStats := &utils.RTTStats{} + serverRTTStats := &utils.RTTStats{} + clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( + clientConf, serverConf, + clientRTTStats, serverRTTStats, + &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, + true, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) + Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) + + var tp *wire.TransportParameters + Expect(clientHelloWrittenChan).To(Receive(&tp)) + Expect(tp.InitialMaxData).To(Equal(initialMaxData)) + + Expect(server.ConnectionState().DidResume).To(BeTrue()) + Expect(client.ConnectionState().DidResume).To(BeTrue()) + Expect(server.ConnectionState().Used0RTT).To(BeTrue()) + Expect(client.ConnectionState().Used0RTT).To(BeTrue()) + }) + + It("rejects 0-RTT, when the transport parameters changed", func() { + csc := mocktls.NewMockClientSessionCache(mockCtrl) + var state *tls.ClientSessionState + receivedSessionTicket := make(chan struct{}) + csc.EXPECT().Get(gomock.Any()) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + close(receivedSessionTicket) + }) + clientConf.ClientSessionCache = csc + const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. + clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) + const initialMaxData protocol.ByteCount = 1337 + clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + clientOrigRTTStats, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, + true, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + Expect(clientHelloWrittenChan).To(Receive(BeNil())) + + csc.EXPECT().Get(gomock.Any()).Return(state, true) + csc.EXPECT().Put(gomock.Any(), nil) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) + + clientRTTStats := &utils.RTTStats{} + clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( + clientConf, serverConf, + clientRTTStats, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData - 1}, + true, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) + + var tp *wire.TransportParameters + Expect(clientHelloWrittenChan).To(Receive(&tp)) + Expect(tp.InitialMaxData).To(Equal(initialMaxData)) + + Expect(server.ConnectionState().DidResume).To(BeTrue()) + Expect(client.ConnectionState().DidResume).To(BeTrue()) + Expect(server.ConnectionState().Used0RTT).To(BeFalse()) + Expect(client.ConnectionState().Used0RTT).To(BeFalse()) + }) + }) + }) +}) diff --git a/internal/handshake/handshake_suite_test.go b/internal/handshake/handshake_suite_test.go new file mode 100644 index 00000000..065877e8 --- /dev/null +++ b/internal/handshake/handshake_suite_test.go @@ -0,0 +1,48 @@ +package handshake + +import ( + "crypto/tls" + "encoding/hex" + "strings" + "testing" + + "github.com/imroc/req/v3/internal/qtls" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestHandshake(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Handshake Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) + +func splitHexString(s string) (slice []byte) { + for _, ss := range strings.Split(s, " ") { + if ss[0:2] == "0x" { + ss = ss[2:] + } + d, err := hex.DecodeString(ss) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + slice = append(slice, d...) + } + return +} + +var cipherSuites = []*qtls.CipherSuiteTLS13{ + qtls.CipherSuiteTLS13ByID(tls.TLS_AES_128_GCM_SHA256), + qtls.CipherSuiteTLS13ByID(tls.TLS_AES_256_GCM_SHA384), + qtls.CipherSuiteTLS13ByID(tls.TLS_CHACHA20_POLY1305_SHA256), +} diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go new file mode 100644 index 00000000..8921fb05 --- /dev/null +++ b/internal/handshake/header_protector.go @@ -0,0 +1,137 @@ +package handshake + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/tls" + "encoding/binary" + "fmt" + "github.com/lucas-clemente/quic-go" + + "golang.org/x/crypto/chacha20" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qtls" +) + +type headerProtector interface { + EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) + DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) +} + +func hkdfHeaderProtectionLabel(v quic.VersionNumber) string { + if v == protocol.Version2 { + return "quicv2 hp" + } + return "quic hp" +} + +func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, v quic.VersionNumber) headerProtector { + hkdfLabel := hkdfHeaderProtectionLabel(v) + switch suite.ID { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) + case tls.TLS_CHACHA20_POLY1305_SHA256: + return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) + default: + panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) + } +} + +type aesHeaderProtector struct { + mask []byte + block cipher.Block + isLongHeader bool +} + +var _ headerProtector = &aesHeaderProtector{} + +func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { + hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) + block, err := aes.NewCipher(hpKey) + if err != nil { + panic(fmt.Sprintf("error creating new AES cipher: %s", err)) + } + return &aesHeaderProtector{ + block: block, + mask: make([]byte, block.BlockSize()), + isLongHeader: isLongHeader, + } +} + +func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { + if len(sample) != len(p.mask) { + panic("invalid sample size") + } + p.block.Encrypt(p.mask, sample) + if p.isLongHeader { + *firstByte ^= p.mask[0] & 0xf + } else { + *firstByte ^= p.mask[0] & 0x1f + } + for i := range hdrBytes { + hdrBytes[i] ^= p.mask[i+1] + } +} + +type chachaHeaderProtector struct { + mask [5]byte + + key [32]byte + isLongHeader bool +} + +var _ headerProtector = &chachaHeaderProtector{} + +func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { + hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) + + p := &chachaHeaderProtector{ + isLongHeader: isLongHeader, + } + copy(p.key[:], hpKey) + return p +} + +func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { + if len(sample) != 16 { + panic("invalid sample size") + } + for i := 0; i < 5; i++ { + p.mask[i] = 0 + } + cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:]) + if err != nil { + panic(err) + } + cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4])) + cipher.XORKeyStream(p.mask[:], p.mask[:]) + p.applyMask(firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) { + if p.isLongHeader { + *firstByte ^= p.mask[0] & 0xf + } else { + *firstByte ^= p.mask[0] & 0x1f + } + for i := range hdrBytes { + hdrBytes[i] ^= p.mask[i+1] + } +} diff --git a/internal/handshake/hkdf.go b/internal/handshake/hkdf.go new file mode 100644 index 00000000..c4fd86c5 --- /dev/null +++ b/internal/handshake/hkdf.go @@ -0,0 +1,29 @@ +package handshake + +import ( + "crypto" + "encoding/binary" + + "golang.org/x/crypto/hkdf" +) + +// hkdfExpandLabel HKDF expands a label. +// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the +// hkdfExpandLabel in the standard library. +func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte { + b := make([]byte, 3, 3+6+len(label)+1+len(context)) + binary.BigEndian.PutUint16(b, uint16(length)) + b[2] = uint8(6 + len(label)) + b = append(b, []byte("tls13 ")...) + b = append(b, []byte(label)...) + b = b[:3+6+len(label)+1] + b[3+6+len(label)] = uint8(len(context)) + b = append(b, context...) + + out := make([]byte, length) + n, err := hkdf.Expand(hash.New, secret, b).Read(out) + if err != nil || n != length { + panic("quic: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} diff --git a/internal/handshake/hkdf_test.go b/internal/handshake/hkdf_test.go new file mode 100644 index 00000000..16154199 --- /dev/null +++ b/internal/handshake/hkdf_test.go @@ -0,0 +1,17 @@ +package handshake + +import ( + "crypto" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Initial AEAD using AES-GCM", func() { + // Result generated by running in qtls: + // cipherSuiteTLS13ByID(TLS_AES_128_GCM_SHA256).expandLabel([]byte("secret"), []byte("context"), "label", 42) + It("gets the same results as qtls", func() { + expanded := hkdfExpandLabel(crypto.SHA256, []byte("secret"), []byte("context"), "label", 42) + Expect(expanded).To(Equal([]byte{0x78, 0x87, 0x6a, 0xb5, 0x84, 0xa2, 0x26, 0xb7, 0x8, 0x5a, 0x7b, 0x3a, 0x4c, 0xbb, 0x1e, 0xbc, 0x2f, 0x9b, 0x67, 0xd0, 0x6a, 0xa2, 0x24, 0xb4, 0x7d, 0x29, 0x3c, 0x7a, 0xce, 0xc7, 0xc3, 0x74, 0xcd, 0x59, 0x7a, 0xa8, 0x21, 0x5e, 0xe7, 0xca, 0x1, 0xda})) + }) +}) diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go new file mode 100644 index 00000000..97a40f74 --- /dev/null +++ b/internal/handshake/initial_aead.go @@ -0,0 +1,82 @@ +package handshake + +import ( + "crypto" + "crypto/tls" + "github.com/lucas-clemente/quic-go" + + "golang.org/x/crypto/hkdf" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qtls" +) + +var ( + quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99} + quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} + quicSaltV2 = []byte{0xa7, 0x07, 0xc2, 0x03, 0xa5, 0x9b, 0x47, 0x18, 0x4a, 0x1d, 0x62, 0xca, 0x57, 0x04, 0x06, 0xea, 0x7a, 0xe3, 0xe5, 0xd3} +) + +const ( + hkdfLabelKeyV1 = "quic key" + hkdfLabelKeyV2 = "quicv2 key" + hkdfLabelIVV1 = "quic iv" + hkdfLabelIVV2 = "quicv2 iv" +) + +func getSalt(v quic.VersionNumber) []byte { + if v == protocol.Version2 { + return quicSaltV2 + } + if v == protocol.Version1 { + return quicSaltV1 + } + return quicSaltOld +} + +var initialSuite = &qtls.CipherSuiteTLS13{ + ID: tls.TLS_AES_128_GCM_SHA256, + KeyLen: 16, + AEAD: qtls.AEADAESGCMTLS13, + Hash: crypto.SHA256, +} + +// NewInitialAEAD creates a new AEAD for Initial encryption / decryption. +func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v quic.VersionNumber) (LongHeaderSealer, LongHeaderOpener) { + clientSecret, serverSecret := computeSecrets(connID, v) + var mySecret, otherSecret []byte + if pers == protocol.PerspectiveClient { + mySecret = clientSecret + otherSecret = serverSecret + } else { + mySecret = serverSecret + otherSecret = clientSecret + } + myKey, myIV := computeInitialKeyAndIV(mySecret, v) + otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v) + + encrypter := qtls.AEADAESGCMTLS13(myKey, myIV) + decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) + + return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)), + newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v))) +} + +func computeSecrets(connID protocol.ConnectionID, v quic.VersionNumber) (clientSecret, serverSecret []byte) { + initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(v)) + clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) + serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) + return +} + +func computeInitialKeyAndIV(secret []byte, v quic.VersionNumber) (key, iv []byte) { + keyLabel := hkdfLabelKeyV1 + ivLabel := hkdfLabelIVV1 + if v == protocol.Version2 { + keyLabel = hkdfLabelKeyV2 + ivLabel = hkdfLabelIVV2 + } + key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16) + iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12) + return +} diff --git a/internal/handshake/initial_aead_test.go b/internal/handshake/initial_aead_test.go new file mode 100644 index 00000000..f7a02f63 --- /dev/null +++ b/internal/handshake/initial_aead_test.go @@ -0,0 +1,219 @@ +package handshake + +import ( + "fmt" + "math/rand" + + "github.com/imroc/req/v3/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Initial AEAD using AES-GCM", func() { + It("converts the string representation used in the draft into byte slices", func() { + Expect(splitHexString("0xdeadbeef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + Expect(splitHexString("deadbeef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + Expect(splitHexString("dead beef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + }) + + connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + + DescribeTable("computes the client key and IV", + func(v quic.VersionNumber, expectedClientSecret, expectedKey, expectedIV []byte) { + clientSecret, _ := computeSecrets(connID, v) + Expect(clientSecret).To(Equal(expectedClientSecret)) + key, iv := computeInitialKeyAndIV(clientSecret, v) + Expect(key).To(Equal(expectedKey)) + Expect(iv).To(Equal(expectedIV)) + }, + Entry("draft-29", + protocol.VersionDraft29, + splitHexString("0088119288f1d866733ceeed15ff9d50 902cf82952eee27e9d4d4918ea371d87"), + splitHexString("175257a31eb09dea9366d8bb79ad80ba"), + splitHexString("6b26114b9cba2b63a9e8dd4f"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("c00cf151ca5be075ed0ebfb5c80323c4 2d6b7db67881289af4008f1f6c357aea"), + splitHexString("1f369613dd76d5467730efcbe3b1a22d"), + splitHexString("fa044b2f42a3fd3b46fb255c"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("9fe72e1452e91f551b770005054034e4 7575d4a0fb4c27b7c6cb303a338423ae"), + splitHexString("95df2be2e8d549c82e996fc9339f4563"), + splitHexString("ea5e3c95f933db14b7020ad8"), + ), + ) + + DescribeTable("computes the server key and IV", + func(v quic.VersionNumber, expectedServerSecret, expectedKey, expectedIV []byte) { + _, serverSecret := computeSecrets(connID, v) + Expect(serverSecret).To(Equal(expectedServerSecret)) + key, iv := computeInitialKeyAndIV(serverSecret, v) + Expect(key).To(Equal(expectedKey)) + Expect(iv).To(Equal(expectedIV)) + }, + Entry("draft 29", + protocol.VersionDraft29, + splitHexString("006f881359244dd9ad1acf85f595bad6 7c13f9f5586f5e64e1acae1d9ea8f616"), + splitHexString("149d0b1662ab871fbe63c49b5e655a5d"), + splitHexString("bab2b12a4c76016ace47856d"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("3c199828fd139efd216c155ad844cc81 fb82fa8d7446fa7d78be803acdda951b"), + splitHexString("cf3a5331653c364c88f0f379b6067e37"), + splitHexString("0ac1493ca1905853b0bba03e"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("3c9bf6a9c1c8c71819876967bd8b979e fd98ec665edf27f22c06e9845ba0ae2f"), + splitHexString("15d5b4d9a2b8916aa39b1bfe574d2aad"), + splitHexString("a85e7ac31cd275cbb095c626"), + ), + ) + + DescribeTable("encrypts the client's Initial", + func(v quic.VersionNumber, header, data, expectedSample []byte, expectedHdrFirstByte byte, expectedHdr, expectedPacket []byte) { + sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveClient, v) + data = append(data, make([]byte, 1162-len(data))...) // add PADDING + sealed := sealer.Seal(nil, data, 2, header) + sample := sealed[0:16] + Expect(sample).To(Equal(expectedSample)) + sealer.EncryptHeader(sample, &header[0], header[len(header)-4:]) + Expect(header[0]).To(Equal(expectedHdrFirstByte)) + Expect(header[len(header)-4:]).To(Equal(expectedHdr)) + packet := append(header, sealed...) + Expect(packet).To(Equal(expectedPacket)) + }, + Entry("draft 29", + protocol.VersionDraft29, + splitHexString("c3ff00001d088394c8f03e5157080000449e00000002"), + splitHexString("060040c4010000c003036660261ff947 cea49cce6cfad687f457cf1b14531ba1 4131a0e8f309a1d0b9c4000006130113 031302010000910000000b0009000006 736572766572ff01000100000a001400 12001d00170018001901000101010201 03010400230000003300260024001d00 204cfdfcd178b784bf328cae793b136f 2aedce005ff183d7bb14952072366470 37002b0003020304000d0020001e0403 05030603020308040805080604010501 060102010402050206020202002d0002 0101001c00024001"), + splitHexString("fb66bc5f93032b7ddd89fe0ff15d9c4f"), + byte(0xc5), + splitHexString("4a95245b"), + splitHexString("c5ff00001d088394c8f03e5157080000 449e4a95245bfb66bc5f93032b7ddd89 fe0ff15d9c4f7050fccdb71c1cd80512 d4431643a53aafa1b0b518b44968b18b 8d3e7a4d04c30b3ed9410325b2abb2da fb1c12f8b70479eb8df98abcaf95dd8f 3d1c78660fbc719f88b23c8aef6771f3 d50e10fdfb4c9d92386d44481b6c52d5 9e5538d3d3942de9f13a7f8b702dc317 24180da9df22714d01003fc5e3d165c9 50e630b8540fbd81c9df0ee63f949970 26c4f2e1887a2def79050ac2d86ba318 e0b3adc4c5aa18bcf63c7cf8e85f5692 49813a2236a7e72269447cd1c755e451 f5e77470eb3de64c8849d29282069802 9cfa18e5d66176fe6e5ba4ed18026f90 900a5b4980e2f58e39151d5cd685b109 29636d4f02e7fad2a5a458249f5c0298 a6d53acbe41a7fc83fa7cc01973f7a74 d1237a51974e097636b6203997f921d0 7bc1940a6f2d0de9f5a11432946159ed 6cc21df65c4ddd1115f86427259a196c 7148b25b6478b0dc7766e1c4d1b1f515 9f90eabc61636226244642ee148b464c 9e619ee50a5e3ddc836227cad938987c 4ea3c1fa7c75bbf88d89e9ada642b2b8 8fe8107b7ea375b1b64889a4e9e5c38a 1c896ce275a5658d250e2d76e1ed3a34 ce7e3a3f383d0c996d0bed106c2899ca 6fc263ef0455e74bb6ac1640ea7bfedc 59f03fee0e1725ea150ff4d69a7660c5 542119c71de270ae7c3ecfd1af2c4ce5 51986949cc34a66b3e216bfe18b347e6 c05fd050f85912db303a8f054ec23e38 f44d1c725ab641ae929fecc8e3cefa56 19df4231f5b4c009fa0c0bbc60bc75f7 6d06ef154fc8577077d9d6a1d2bd9bf0 81dc783ece60111bea7da9e5a9748069 d078b2bef48de04cabe3755b197d52b3 2046949ecaa310274b4aac0d008b1948 c1082cdfe2083e386d4fd84c0ed0666d 3ee26c4515c4fee73433ac703b690a9f 7bf278a77486ace44c489a0c7ac8dfe4 d1a58fb3a730b993ff0f0d61b4d89557 831eb4c752ffd39c10f6b9f46d8db278 da624fd800e4af85548a294c1518893a 8778c4f6d6d73c93df200960104e062b 388ea97dcf4016bced7f62b4f062cb6c 04c20693d9a0e3b74ba8fe74cc012378 84f40d765ae56a51688d985cf0ceaef4 3045ed8c3f0c33bced08537f6882613a cd3b08d665fce9dd8aa73171e2d3771a 61dba2790e491d413d93d987e2745af2 9418e428be34941485c93447520ffe23 1da2304d6a0fd5d07d08372202369661 59bef3cf904d722324dd852513df39ae 030d8173908da6364786d3c1bfcb19ea 77a63b25f1e7fc661def480c5d00d444 56269ebd84efd8e3a8b2c257eec76060 682848cbf5194bc99e49ee75e4d0d254 bad4bfd74970c30e44b65511d4ad0e6e c7398e08e01307eeeea14e46ccd87cf3 6b285221254d8fc6a6765c524ded0085 dca5bd688ddf722e2c0faf9d0fb2ce7a 0c3f2cee19ca0ffba461ca8dc5d2c817 8b0762cf67135558494d2a96f1a139f0 edb42d2af89a9c9122b07acbc29e5e72 2df8615c343702491098478a389c9872 a10b0c9875125e257c7bfdf27eef4060 bd3d00f4c14fd3e3496c38d3c5d1a566 8c39350effbc2d16ca17be4ce29f02ed 969504dda2a8c6b9ff919e693ee79e09 089316e7d1d89ec099db3b2b268725d8 88536a4b8bf9aee8fb43e82a4d919d48 43b1ca70a2d8d3f725ead1391377dcc0"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("c300000001088394c8f03e5157080000449e00000002"), + splitHexString("060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), + splitHexString("d1b1c98dd7689fb8ec11d242b123dc9b"), + byte(0xc0), + splitHexString("7b9aec34"), + splitHexString("c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11 d242b123dc9bd8bab936b47d92ec356c 0bab7df5976d27cd449f63300099f399 1c260ec4c60d17b31f8429157bb35a12 82a643a8d2262cad67500cadb8e7378c 8eb7539ec4d4905fed1bee1fc8aafba1 7c750e2c7ace01e6005f80fcb7df6212 30c83711b39343fa028cea7f7fb5ff89 eac2308249a02252155e2347b63d58c5 457afd84d05dfffdb20392844ae81215 4682e9cf012f9021a6f0be17ddd0c208 4dce25ff9b06cde535d0f920a2db1bf3 62c23e596d11a4f5a6cf3948838a3aec 4e15daf8500a6ef69ec4e3feb6b1d98e 610ac8b7ec3faf6ad760b7bad1db4ba3 485e8a94dc250ae3fdb41ed15fb6a8e5 eba0fc3dd60bc8e30c5c4287e53805db 059ae0648db2f64264ed5e39be2e20d8 2df566da8dd5998ccabdae053060ae6c 7b4378e846d29f37ed7b4ea9ec5d82e7 961b7f25a9323851f681d582363aa5f8 9937f5a67258bf63ad6f1a0b1d96dbd4 faddfcefc5266ba6611722395c906556 be52afe3f565636ad1b17d508b73d874 3eeb524be22b3dcbc2c7468d54119c74 68449a13d8e3b95811a198f3491de3e7 fe942b330407abf82a4ed7c1b311663a c69890f4157015853d91e923037c227a 33cdd5ec281ca3f79c44546b9d90ca00 f064c99e3dd97911d39fe9c5d0b23a22 9a234cb36186c4819e8b9c5927726632 291d6a418211cc2962e20fe47feb3edf 330f2c603a9d48c0fcb5699dbfe58964 25c5bac4aee82e57a85aaf4e2513e4f0 5796b07ba2ee47d80506f8d2c25e50fd 14de71e6c418559302f939b0e1abd576 f279c4b2e0feb85c1f28ff18f58891ff ef132eef2fa09346aee33c28eb130ff2 8f5b766953334113211996d20011a198 e3fc433f9f2541010ae17c1bf202580f 6047472fb36857fe843b19f5984009dd c324044e847a4f4a0ab34f719595de37 252d6235365e9b84392b061085349d73 203a4a13e96f5432ec0fd4a1ee65accd d5e3904df54c1da510b0ff20dcc0c77f cb2c0e0eb605cb0504db87632cf3d8b4 dae6e705769d1de354270123cb11450e fc60ac47683d7b8d0f811365565fd98c 4c8eb936bcab8d069fc33bd801b03ade a2e1fbc5aa463d08ca19896d2bf59a07 1b851e6c239052172f296bfb5e724047 90a2181014f3b94a4e97d117b4381303 68cc39dbb2d198065ae3986547926cd2 162f40a29f0c3c8745c0f50fba3852e5 66d44575c29d39a03f0cda721984b6f4 40591f355e12d439ff150aab7613499d bd49adabc8676eef023b15b65bfc5ca0 6948109f23f350db82123535eb8a7433 bdabcb909271a6ecbcb58b936a88cd4e 8f2e6ff5800175f113253d8fa9ca8885 c2f552e657dc603f252e1a8e308f76f0 be79e2fb8f5d5fbbe2e30ecadd220723 c8c0aea8078cdfcb3868263ff8f09400 54da48781893a7e49ad5aff4af300cd8 04a6b6279ab3ff3afb64491c85194aab 760d58a606654f9f4400e8b38591356f bf6425aca26dc85244259ff2b19c41b9 f96f3ca9ec1dde434da7d2d392b905dd f3d1f9af93d1af5950bd493f5aa731b4 056df31bd267b6b90a079831aaf579be 0a39013137aac6d404f518cfd4684064 7e78bfe706ca4cf5e9c5453e9f7cfd2b 8b4c8d169a44e55c88d4a9a7f9474241 e221af44860018ab0856972e194cd934"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("d3709a50c4088394c8f03e5157080000449e00000002"), + splitHexString("060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), + splitHexString("23b8e610589c83c92d0e97eb7a6e5003"), + byte(0xdd), + splitHexString("4391d848"), + splitHexString("dd709a50c4088394c8f03e5157080000 449e4391d84823b8e610589c83c92d0e 97eb7a6e5003f57764c5c7f0095ba54b 90818f1bfeecc1c97c54fc731edbd2a2 44e3b1e639a9bc75ed545b98649343b2 53615ec6b3e4df0fd2e7fe9d691a09e6 a144b436d8a2c088a404262340dfd995 ec3865694e3026ecd8c6d2561a5a3667 2a1005018168c0f081c10e2bf14d550c 977e28bb9a759c57d0f7ffb1cdfb40bd 774dec589657542047dffefa56fc8089 a4d1ef379c81ba3df71a05ddc7928340 775910feb3ce4cbcfd8d253edd05f161 458f9dc44bea017c3117cca7065a315d eda9464e672ec80c3f79ac993437b441 ef74227ecc4dc9d597f66ab0ab8d214b 55840c70349d7616cbe38e5e1d052d07 f1fedb3dd3c4d8ce295724945e67ed2e efcd9fb52472387f318e3d9d233be7df c79d6bf6080dcbbb41feb180d7858849 7c3e439d38c334748d2b56fd19ab364d 057a9bd5a699ae145d7fdbc8f5777518 1b0a97c3bdedc91a555d6c9b8634e106 d8c9ca45a9d5450a7679edc545da9102 5bc93a7cf9a023a066ffadb9717ffaf3 414c3b646b5738b3cc4116502d18d79d 8227436306d9b2b3afc6c785ce3c817f eb703a42b9c83b59f0dcef1245d0b3e4 0299821ec19549ce489714fe2611e72c d882f4f70dce7d3671296fc045af5c9f 630d7b49a3eb821bbca60f1984dce664 91713bfe06001a56f51bb3abe92f7960 547c4d0a70f4a962b3f05dc25a34bbe8 30a7ea4736d3b0161723500d82beda9b e3327af2aa413821ff678b2a876ec4b0 0bb605ffcc3917ffdc279f187daa2fce 8cde121980bba8ec8f44ca562b0f1319 14c901cfbd847408b778e6738c7bb5b1 b3f97d01b0a24dcca40e3bed29411b1b a8f60843c4a241021b23132b9500509b 9a3516d4a9dd41d3bacbcd426b451393 521828afedcf20fa46ac24f44a8e2973 30b16705d5d5f798eff9e9134a065979 87a1db4617caa2d93837730829d4d89e 16413be4d8a8a38a7e6226623b64a820 178ec3a66954e10710e043ae73dd3fb2 715a0525a46343fb7590e5eac7ee55fc 810e0d8b4b8f7be82cd5a214575a1b99 629d47a9b281b61348c8627cab38e2a6 4db6626e97bb8f77bdcb0fee476aedd7 ba8f5441acaab00f4432edab3791047d 9091b2a753f035648431f6d12f7d6a68 1e64c861f4ac911a0f7d6ec0491a78c9 f192f96b3a5e7560a3f056bc1ca85983 67ad6acb6f2e034c7f37beeb9ed470c4 304af0107f0eb919be36a86f68f37fa6 1dae7aff14decd67ec3157a11488a14f ed0142828348f5f608b0fe03e1f3c0af 3acca0ce36852ed42e220ae9abf8f890 6f00f1b86bff8504c8f16c784fd52d25 e013ff4fda903e9e1eb453c1464b1196 6db9b28e8f26a3fc419e6a60a48d4c72 14ee9c6c6a12b68a32cac8f61580c64f 29cb6922408783c6d12e725b014fe485 cd17e484c5952bf99bc94941d4b1919d 04317b8aa1bd3754ecbaa10ec227de85 40695bf2fb8ee56f6dc526ef366625b9 1aa4970b6ffa5c8284b9b5ab852b905f 9d83f5669c0535bc377bcc05ad5e48e2 81ec0e1917ca3c6a471f8da0894bc82a c2a8965405d6eef3b5e293a88fda203f 09bdc72757b107ab14880eaa3ef7045b 580f4821ce6dd325b5a90655d8c5b55f 76fb846279a9b518c5e9b9a21165c509 3ed49baaacadf1f21873266c767f6769"), + ), + ) + + DescribeTable("encrypts the server's Initial", + func(v quic.VersionNumber, header, data, expectedSample, expectedHdr, expectedPacket []byte) { + sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveServer, v) + sealed := sealer.Seal(nil, data, 1, header) + sample := sealed[2 : 2+16] + Expect(sample).To(Equal(expectedSample)) + sealer.EncryptHeader(sample, &header[0], header[len(header)-2:]) + Expect(header).To(Equal(expectedHdr)) + packet := append(header, sealed...) + Expect(packet).To(Equal(expectedPacket)) + }, + Entry("draft 29", + protocol.VersionDraft29, + splitHexString("c1ff00001d0008f067a5502a4262b50040740001"), + splitHexString("0d0000000018410a020000560303eefc e7f7b37ba1d1632e96677825ddf73988 cfc79825df566dc5430b9a045a120013 0100002e00330024001d00209d3c940d 89690b84d08a60993c144eca684d1081 287c834d5311bcf32bb9da1a002b0002 0304"), + splitHexString("823a5d3a1207c86ee49132824f046524"), + splitHexString("caff00001d0008f067a5502a4262b5004074aaf2"), + splitHexString("caff00001d0008f067a5502a4262b500 4074aaf2f007823a5d3a1207c86ee491 32824f0465243d082d868b107a38092b c80528664cbf9456ebf27673fb5fa506 1ab573c9f001b81da028a00d52ab00b1 5bebaa70640e106cf2acd043e9c6b441 1c0a79637134d8993701fe779e58c2fe 753d14b0564021565ea92e57bc6faf56 dfc7a40870e6"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("c1000000010008f067a5502a4262b50040750001"), + splitHexString("02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), + splitHexString("2cd0991cd25b0aac406a5816b6394100"), + splitHexString("cf000000010008f067a5502a4262b5004075c0d9"), + splitHexString("cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a 5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3 dbcba3f6ea46c5b7684df3548e7ddeb9 c3bf9c73cc3f3bded74b562bfb19fb84 022f8ef4cdd93795d77d06edbb7aaf2f 58891850abbdca3d20398c276456cbc4 2158407dd074ee"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("d1709a50c40008f067a5502a4262b50040750001"), + splitHexString("02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), + splitHexString("ebb7972fdce59d50e7e49ff2a7e8de76"), + splitHexString("d0709a50c40008f067a5502a4262b5004075103e"), + splitHexString("d0709a50c40008f067a5502a4262b500 4075103e63b4ebb7972fdce59d50e7e4 9ff2a7e8de76b0cd8c10100a1f13d549 dd6fe801588fb14d279bef8d7c53ef62 66a9a7a1a5f2fa026c236a5bf8df5aa0 f9d74773aeccfffe910b0f76814b5e33 f7b7f8ec278d23fd8c7a9e66856b8bbe 72558135bca27c54d63fcc902253461c fc089d4e6b9b19"), + ), + ) + + for _, ver := range []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { + v := ver + + Context(fmt.Sprintf("using version %s", v), func() { + It("seals and opens", func() { + connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef} + clientSealer, clientOpener := NewInitialAEAD(connectionID, protocol.PerspectiveClient, v) + serverSealer, serverOpener := NewInitialAEAD(connectionID, protocol.PerspectiveServer, v) + + clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) + m, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad")) + Expect(err).ToNot(HaveOccurred()) + Expect(m).To(Equal([]byte("foobar"))) + serverMessage := serverSealer.Seal(nil, []byte("raboof"), 99, []byte("daa")) + m, err = clientOpener.Open(nil, serverMessage, 99, []byte("daa")) + Expect(err).ToNot(HaveOccurred()) + Expect(m).To(Equal([]byte("raboof"))) + }) + + It("doesn't work if initialized with different connection IDs", func() { + c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1} + c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2} + clientSealer, _ := NewInitialAEAD(c1, protocol.PerspectiveClient, v) + _, serverOpener := NewInitialAEAD(c2, protocol.PerspectiveServer, v) + + clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) + _, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("encrypts und decrypts the header", func() { + connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + clientSealer, clientOpener := NewInitialAEAD(connID, protocol.PerspectiveClient, v) + serverSealer, serverOpener := NewInitialAEAD(connID, protocol.PerspectiveServer, v) + + // the first byte and the last 4 bytes should be encrypted + header := []byte{0x5e, 0, 1, 2, 3, 4, 0xde, 0xad, 0xbe, 0xef} + sample := make([]byte, 16) + rand.Read(sample) + clientSealer.EncryptHeader(sample, &header[0], header[6:10]) + // only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified + Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + serverOpener.DecryptHeader(sample, &header[0], header[6:10]) + Expect(header[0]).To(Equal(byte(0x5e))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + + serverSealer.EncryptHeader(sample, &header[0], header[6:10]) + // only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified + Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + clientOpener.DecryptHeader(sample, &header[0], header[6:10]) + Expect(header[0]).To(Equal(byte(0x5e))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + }) + }) + } +}) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go new file mode 100644 index 00000000..e17202c3 --- /dev/null +++ b/internal/handshake/interface.go @@ -0,0 +1,103 @@ +package handshake + +import ( + "errors" + "github.com/lucas-clemente/quic-go" + "io" + "net" + "time" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qtls" + "github.com/imroc/req/v3/internal/wire" +) + +var ( + // ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level, + // but the corresponding opener has not yet been initialized + // This can happen when packets arrive out of order. + ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available") + // ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level, + // but the corresponding keys have already been dropped. + ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped") + // ErrDecryptionFailed is returned when the AEAD fails to open the packet. + ErrDecryptionFailed = errors.New("decryption failed") +) + +// ConnectionState contains information about the state of the connection. +type ConnectionState = qtls.ConnectionState + +type headerDecryptor interface { + DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) +} + +// LongHeaderOpener opens a long header packet +type LongHeaderOpener interface { + headerDecryptor + DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber + Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error) +} + +// ShortHeaderOpener opens a short header packet +type ShortHeaderOpener interface { + headerDecryptor + DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber + Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error) +} + +// LongHeaderSealer seals a long header packet +type LongHeaderSealer interface { + Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte + EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) + Overhead() int +} + +// ShortHeaderSealer seals a short header packet +type ShortHeaderSealer interface { + LongHeaderSealer + KeyPhase() protocol.KeyPhaseBit +} + +// A tlsExtensionHandler sends and received the QUIC TLS extension. +type tlsExtensionHandler interface { + GetExtensions(msgType uint8) []qtls.Extension + ReceivedExtensions(msgType uint8, exts []qtls.Extension) + TransportParameters() <-chan []byte +} + +type handshakeRunner interface { + OnReceivedParams(*wire.TransportParameters) + OnHandshakeComplete() + OnError(error) + DropKeys(protocol.EncryptionLevel) +} + +// CryptoSetup handles the handshake and protecting / unprotecting packets +type CryptoSetup interface { + RunHandshake() + io.Closer + ChangeConnectionID(protocol.ConnectionID) + GetSessionTicket() ([]byte, error) + + HandleMessage([]byte, protocol.EncryptionLevel) bool + SetLargest1RTTAcked(protocol.PacketNumber) error + SetHandshakeConfirmed() + ConnectionState() ConnectionState + + GetInitialOpener() (LongHeaderOpener, error) + GetHandshakeOpener() (LongHeaderOpener, error) + Get0RTTOpener() (LongHeaderOpener, error) + Get1RTTOpener() (ShortHeaderOpener, error) + + GetInitialSealer() (LongHeaderSealer, error) + GetHandshakeSealer() (LongHeaderSealer, error) + Get0RTTSealer() (LongHeaderSealer, error) + Get1RTTSealer() (ShortHeaderSealer, error) +} + +// ConnWithVersion is the connection used in the ClientHelloInfo. +// It can be used to determine the QUIC version in use. +type ConnWithVersion interface { + net.Conn + GetQUICVersion() quic.VersionNumber +} diff --git a/internal/handshake/mock_handshake_runner_test.go b/internal/handshake/mock_handshake_runner_test.go new file mode 100644 index 00000000..4f25e6a2 --- /dev/null +++ b/internal/handshake/mock_handshake_runner_test.go @@ -0,0 +1,84 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: interface.go + +// Package handshake is a generated GoMock package. +package handshake + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" + wire "github.com/imroc/req/v3/internal/wire" +) + +// MockHandshakeRunner is a mock of HandshakeRunner interface. +type MockHandshakeRunner struct { + ctrl *gomock.Controller + recorder *MockHandshakeRunnerMockRecorder +} + +// MockHandshakeRunnerMockRecorder is the mock recorder for MockHandshakeRunner. +type MockHandshakeRunnerMockRecorder struct { + mock *MockHandshakeRunner +} + +// NewMockHandshakeRunner creates a new mock instance. +func NewMockHandshakeRunner(ctrl *gomock.Controller) *MockHandshakeRunner { + mock := &MockHandshakeRunner{ctrl: ctrl} + mock.recorder = &MockHandshakeRunnerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockHandshakeRunner) EXPECT() *MockHandshakeRunnerMockRecorder { + return m.recorder +} + +// DropKeys mocks base method. +func (m *MockHandshakeRunner) DropKeys(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropKeys", arg0) +} + +// DropKeys indicates an expected call of DropKeys. +func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0) +} + +// OnError mocks base method. +func (m *MockHandshakeRunner) OnError(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnError", arg0) +} + +// OnError indicates an expected call of OnError. +func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0) +} + +// OnHandshakeComplete mocks base method. +func (m *MockHandshakeRunner) OnHandshakeComplete() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnHandshakeComplete") +} + +// OnHandshakeComplete indicates an expected call of OnHandshakeComplete. +func (mr *MockHandshakeRunnerMockRecorder) OnHandshakeComplete() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnHandshakeComplete", reflect.TypeOf((*MockHandshakeRunner)(nil).OnHandshakeComplete)) +} + +// OnReceivedParams mocks base method. +func (m *MockHandshakeRunner) OnReceivedParams(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnReceivedParams", arg0) +} + +// OnReceivedParams indicates an expected call of OnReceivedParams. +func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0) +} diff --git a/internal/handshake/mockgen.go b/internal/handshake/mockgen.go new file mode 100644 index 00000000..0fc6389b --- /dev/null +++ b/internal/handshake/mockgen.go @@ -0,0 +1,3 @@ +package handshake + +//go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/imroc/req/v3/internal/handshake handshakeRunner" diff --git a/internal/handshake/retry.go b/internal/handshake/retry.go new file mode 100644 index 00000000..c942cd47 --- /dev/null +++ b/internal/handshake/retry.go @@ -0,0 +1,63 @@ +package handshake + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "fmt" + "github.com/lucas-clemente/quic-go" + "sync" + + "github.com/imroc/req/v3/internal/protocol" +) + +var ( + oldRetryAEAD cipher.AEAD // used for QUIC draft versions up to 34 + retryAEAD cipher.AEAD // used for QUIC draft-34 +) + +func init() { + oldRetryAEAD = initAEAD([16]byte{0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1}) + retryAEAD = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}) +} + +func initAEAD(key [16]byte) cipher.AEAD { + aes, err := aes.NewCipher(key[:]) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + return aead +} + +var ( + retryBuf bytes.Buffer + retryMutex sync.Mutex + oldRetryNonce = [12]byte{0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c} + retryNonce = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb} +) + +// GetRetryIntegrityTag calculates the integrity tag on a Retry packet +func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version quic.VersionNumber) *[16]byte { + retryMutex.Lock() + retryBuf.WriteByte(uint8(origDestConnID.Len())) + retryBuf.Write(origDestConnID.Bytes()) + retryBuf.Write(retry) + + var tag [16]byte + var sealed []byte + if version != protocol.Version1 { + sealed = oldRetryAEAD.Seal(tag[:0], oldRetryNonce[:], nil, retryBuf.Bytes()) + } else { + sealed = retryAEAD.Seal(tag[:0], retryNonce[:], nil, retryBuf.Bytes()) + } + if len(sealed) != 16 { + panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed))) + } + retryBuf.Reset() + retryMutex.Unlock() + return &tag +} diff --git a/internal/handshake/retry_test.go b/internal/handshake/retry_test.go new file mode 100644 index 00000000..4b74fc41 --- /dev/null +++ b/internal/handshake/retry_test.go @@ -0,0 +1,36 @@ +package handshake + +import ( + "github.com/imroc/req/v3/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Retry Integrity Check", func() { + It("calculates retry integrity tags", func() { + fooTag := GetRetryIntegrityTag([]byte("foo"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) + barTag := GetRetryIntegrityTag([]byte("bar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) + Expect(fooTag).ToNot(BeNil()) + Expect(barTag).ToNot(BeNil()) + Expect(*fooTag).ToNot(Equal(*barTag)) + }) + + It("includes the original connection ID in the tag calculation", func() { + t1 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.Version1) + t2 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{4, 3, 2, 1}, protocol.Version1) + Expect(*t1).ToNot(Equal(*t2)) + }) + + It("uses the test vector from the draft, for old draft versions", func() { + connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + data := splitHexString("ffff00001d0008f067a5502a4262b574 6f6b656ed16926d81f6f9ca2953a8aa4 575e1e49") + Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.VersionDraft29)[:]).To(Equal(data[len(data)-16:])) + }) + + It("uses the test vector from the draft, for version 1", func() { + connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + data := splitHexString("ff000000010008f067a5502a4262b574 6f6b656e04a265ba2eff4d829058fb3f 0f2496ba") + Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.Version1)[:]).To(Equal(data[len(data)-16:])) + }) +}) diff --git a/internal/handshake/session_ticket.go b/internal/handshake/session_ticket.go new file mode 100644 index 00000000..1ac9e049 --- /dev/null +++ b/internal/handshake/session_ticket.go @@ -0,0 +1,48 @@ +package handshake + +import ( + "bytes" + "errors" + "fmt" + "time" + + "github.com/imroc/req/v3/internal/quicvarint" + "github.com/imroc/req/v3/internal/wire" +) + +const sessionTicketRevision = 2 + +type sessionTicket struct { + Parameters *wire.TransportParameters + RTT time.Duration // to be encoded in mus +} + +func (t *sessionTicket) Marshal() []byte { + b := &bytes.Buffer{} + quicvarint.Write(b, sessionTicketRevision) + quicvarint.Write(b, uint64(t.RTT.Microseconds())) + t.Parameters.MarshalForSessionTicket(b) + return b.Bytes() +} + +func (t *sessionTicket) Unmarshal(b []byte) error { + r := bytes.NewReader(b) + rev, err := quicvarint.Read(r) + if err != nil { + return errors.New("failed to read session ticket revision") + } + if rev != sessionTicketRevision { + return fmt.Errorf("unknown session ticket revision: %d", rev) + } + rtt, err := quicvarint.Read(r) + if err != nil { + return errors.New("failed to read RTT") + } + var tp wire.TransportParameters + if err := tp.UnmarshalFromSessionTicket(r); err != nil { + return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) + } + t.Parameters = &tp + t.RTT = time.Duration(rtt) * time.Microsecond + return nil +} diff --git a/internal/handshake/session_ticket_test.go b/internal/handshake/session_ticket_test.go new file mode 100644 index 00000000..bf5b2407 --- /dev/null +++ b/internal/handshake/session_ticket_test.go @@ -0,0 +1,54 @@ +package handshake + +import ( + "bytes" + "time" + + "github.com/imroc/req/v3/internal/wire" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Session Ticket", func() { + It("marshals and unmarshals a session ticket", func() { + ticket := &sessionTicket{ + Parameters: &wire.TransportParameters{ + InitialMaxStreamDataBidiLocal: 1, + InitialMaxStreamDataBidiRemote: 2, + }, + RTT: 1337 * time.Microsecond, + } + var t sessionTicket + Expect(t.Unmarshal(ticket.Marshal())).To(Succeed()) + Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1)) + Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2)) + Expect(t.RTT).To(Equal(1337 * time.Microsecond)) + }) + + It("refuses to unmarshal if the ticket is too short for the revision", func() { + Expect((&sessionTicket{}).Unmarshal([]byte{})).To(MatchError("failed to read session ticket revision")) + }) + + It("refuses to unmarshal if the revision doesn't match", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, 1337) + Expect((&sessionTicket{}).Unmarshal(b.Bytes())).To(MatchError("unknown session ticket revision: 1337")) + }) + + It("refuses to unmarshal if the RTT cannot be read", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, sessionTicketRevision) + Expect((&sessionTicket{}).Unmarshal(b.Bytes())).To(MatchError("failed to read RTT")) + }) + + It("refuses to unmarshal if unmarshaling the transport parameters fails", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, sessionTicketRevision) + b.Write([]byte("foobar")) + err := (&sessionTicket{}).Unmarshal(b.Bytes()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("unmarshaling transport parameters from session ticket failed")) + }) +}) diff --git a/internal/handshake/tls_extension_handler.go b/internal/handshake/tls_extension_handler.go new file mode 100644 index 00000000..967f4d1f --- /dev/null +++ b/internal/handshake/tls_extension_handler.go @@ -0,0 +1,69 @@ +package handshake + +import ( + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qtls" + "github.com/lucas-clemente/quic-go" +) + +const ( + quicTLSExtensionTypeOldDrafts = 0xffa5 + quicTLSExtensionType = 0x39 +) + +type extensionHandler struct { + ourParams []byte + paramsChan chan []byte + + extensionType uint16 + + perspective protocol.Perspective +} + +var _ tlsExtensionHandler = &extensionHandler{} + +// newExtensionHandler creates a new extension handler +func newExtensionHandler(params []byte, pers protocol.Perspective, v quic.VersionNumber) tlsExtensionHandler { + et := uint16(quicTLSExtensionType) + if v != protocol.Version1 { + et = quicTLSExtensionTypeOldDrafts + } + return &extensionHandler{ + ourParams: params, + paramsChan: make(chan []byte), + perspective: pers, + extensionType: et, + } +} + +func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension { + if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) || + (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) { + return nil + } + return []qtls.Extension{{ + Type: h.extensionType, + Data: h.ourParams, + }} +} + +func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) { + if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) || + (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) { + return + } + + var data []byte + for _, ext := range exts { + if ext.Type == h.extensionType { + data = ext.Data + break + } + } + + h.paramsChan <- data +} + +func (h *extensionHandler) TransportParameters() <-chan []byte { + return h.paramsChan +} diff --git a/internal/handshake/tls_extension_handler_test.go b/internal/handshake/tls_extension_handler_test.go new file mode 100644 index 00000000..fcd2223b --- /dev/null +++ b/internal/handshake/tls_extension_handler_test.go @@ -0,0 +1,210 @@ +package handshake + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qtls" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("TLS Extension Handler, for the server", func() { + var ( + handlerServer tlsExtensionHandler + handlerClient tlsExtensionHandler + version quic.VersionNumber + ) + + BeforeEach(func() { + version = protocol.VersionDraft29 + }) + + JustBeforeEach(func() { + handlerServer = newExtensionHandler( + []byte("foobar"), + protocol.PerspectiveServer, + version, + ) + handlerClient = newExtensionHandler( + []byte("raboof"), + protocol.PerspectiveClient, + version, + ) + }) + + Context("for the server", func() { + for _, ver := range []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1} { + v := ver + + Context(fmt.Sprintf("sending, for version %s", v), func() { + var extensionType uint16 + + BeforeEach(func() { + version = v + if v == protocol.VersionDraft29 { + extensionType = quicTLSExtensionTypeOldDrafts + } else { + extensionType = quicTLSExtensionType + } + }) + + It("only adds TransportParameters for the Encrypted Extensions", func() { + // test 2 other handshake types + Expect(handlerServer.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) + Expect(handlerServer.GetExtensions(uint8(typeFinished))).To(BeEmpty()) + }) + + It("adds TransportParameters to the EncryptedExtensions message", func() { + exts := handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) + Expect(exts).To(HaveLen(1)) + Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) + Expect(exts[0].Data).To(Equal([]byte("foobar"))) + }) + }) + } + + Context("receiving", func() { + var chExts []qtls.Extension + + JustBeforeEach(func() { + chExts = handlerClient.GetExtensions(uint8(typeClientHello)) + Expect(chExts).To(HaveLen(1)) + }) + + It("sends the extension on the channel", func() { + go func() { + defer GinkgoRecover() + handlerServer.ReceivedExtensions(uint8(typeClientHello), chExts) + }() + + var data []byte + Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) + Expect(data).To(Equal([]byte("raboof"))) + }) + + It("sends nil on the channel if the extension is missing", func() { + go func() { + defer GinkgoRecover() + handlerServer.ReceivedExtensions(uint8(typeClientHello), nil) + }() + + var data []byte + Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) + Expect(data).To(BeEmpty()) + }) + + It("ignores extensions with different code points", func() { + go func() { + defer GinkgoRecover() + exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} + handlerServer.ReceivedExtensions(uint8(typeClientHello), exts) + }() + + var data []byte + Eventually(handlerServer.TransportParameters()).Should(Receive()) + Expect(data).To(BeEmpty()) + }) + + It("ignores extensions that are not sent with the ClientHello", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handlerServer.ReceivedExtensions(uint8(typeFinished), chExts) + close(done) + }() + + Consistently(handlerServer.TransportParameters()).ShouldNot(Receive()) + Eventually(done).Should(BeClosed()) + }) + }) + }) + + Context("for the client", func() { + for _, ver := range []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1} { + v := ver + + Context(fmt.Sprintf("sending, for version %s", v), func() { + var extensionType uint16 + + BeforeEach(func() { + version = v + if v == protocol.VersionDraft29 { + extensionType = quicTLSExtensionTypeOldDrafts + } else { + extensionType = quicTLSExtensionType + } + }) + + It("only adds TransportParameters for the Encrypted Extensions", func() { + // test 2 other handshake types + Expect(handlerClient.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) + Expect(handlerClient.GetExtensions(uint8(typeFinished))).To(BeEmpty()) + }) + + It("adds TransportParameters to the ClientHello message", func() { + exts := handlerClient.GetExtensions(uint8(typeClientHello)) + Expect(exts).To(HaveLen(1)) + Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) + Expect(exts[0].Data).To(Equal([]byte("raboof"))) + }) + }) + } + + Context("receiving", func() { + var chExts []qtls.Extension + + JustBeforeEach(func() { + chExts = handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) + Expect(chExts).To(HaveLen(1)) + }) + + It("sends the extension on the channel", func() { + go func() { + defer GinkgoRecover() + handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), chExts) + }() + + var data []byte + Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) + Expect(data).To(Equal([]byte("foobar"))) + }) + + It("sends nil on the channel if the extension is missing", func() { + go func() { + defer GinkgoRecover() + handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), nil) + }() + + var data []byte + Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) + Expect(data).To(BeEmpty()) + }) + + It("ignores extensions with different code points", func() { + go func() { + defer GinkgoRecover() + exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} + handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), exts) + }() + + var data []byte + Eventually(handlerClient.TransportParameters()).Should(Receive()) + Expect(data).To(BeEmpty()) + }) + + It("ignores extensions that are not sent with the EncryptedExtensions", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handlerClient.ReceivedExtensions(uint8(typeFinished), chExts) + close(done) + }() + + Consistently(handlerClient.TransportParameters()).ShouldNot(Receive()) + Eventually(done).Should(BeClosed()) + }) + }) + }) +}) diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go new file mode 100644 index 00000000..349a1ee5 --- /dev/null +++ b/internal/handshake/token_generator.go @@ -0,0 +1,134 @@ +package handshake + +import ( + "encoding/asn1" + "fmt" + "io" + "net" + "time" + + "github.com/imroc/req/v3/internal/protocol" +) + +const ( + tokenPrefixIP byte = iota + tokenPrefixString +) + +// A Token is derived from the client address and can be used to verify the ownership of this address. +type Token struct { + IsRetryToken bool + RemoteAddr string + SentTime time.Time + // only set for retry tokens + OriginalDestConnectionID protocol.ConnectionID + RetrySrcConnectionID protocol.ConnectionID +} + +// token is the struct that is used for ASN1 serialization and deserialization +type token struct { + IsRetryToken bool + RemoteAddr []byte + Timestamp int64 + OriginalDestConnectionID []byte + RetrySrcConnectionID []byte +} + +// A TokenGenerator generates tokens +type TokenGenerator struct { + tokenProtector tokenProtector +} + +// NewTokenGenerator initializes a new TookenGenerator +func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) { + tokenProtector, err := newTokenProtector(rand) + if err != nil { + return nil, err + } + return &TokenGenerator{ + tokenProtector: tokenProtector, + }, nil +} + +// NewRetryToken generates a new token for a Retry for a given source address +func (g *TokenGenerator) NewRetryToken( + raddr net.Addr, + origDestConnID protocol.ConnectionID, + retrySrcConnID protocol.ConnectionID, +) ([]byte, error) { + data, err := asn1.Marshal(token{ + IsRetryToken: true, + RemoteAddr: encodeRemoteAddr(raddr), + OriginalDestConnectionID: origDestConnID, + RetrySrcConnectionID: retrySrcConnID, + Timestamp: time.Now().UnixNano(), + }) + if err != nil { + return nil, err + } + return g.tokenProtector.NewToken(data) +} + +// NewToken generates a new token to be sent in a NEW_TOKEN frame +func (g *TokenGenerator) NewToken(raddr net.Addr) ([]byte, error) { + data, err := asn1.Marshal(token{ + RemoteAddr: encodeRemoteAddr(raddr), + Timestamp: time.Now().UnixNano(), + }) + if err != nil { + return nil, err + } + return g.tokenProtector.NewToken(data) +} + +// DecodeToken decodes a token +func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) { + // if the client didn't send any token, DecodeToken will be called with a nil-slice + if len(encrypted) == 0 { + return nil, nil + } + + data, err := g.tokenProtector.DecodeToken(encrypted) + if err != nil { + return nil, err + } + t := &token{} + rest, err := asn1.Unmarshal(data, t) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) + } + token := &Token{ + IsRetryToken: t.IsRetryToken, + RemoteAddr: decodeRemoteAddr(t.RemoteAddr), + SentTime: time.Unix(0, t.Timestamp), + } + if t.IsRetryToken { + token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) + token.RetrySrcConnectionID = protocol.ConnectionID(t.RetrySrcConnectionID) + } + return token, nil +} + +// encodeRemoteAddr encodes a remote address such that it can be saved in the token +func encodeRemoteAddr(remoteAddr net.Addr) []byte { + if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { + return append([]byte{tokenPrefixIP}, udpAddr.IP...) + } + return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...) +} + +// decodeRemoteAddr decodes the remote address saved in the token +func decodeRemoteAddr(data []byte) string { + // data will never be empty for a token that we generated. + // Check it to be on the safe side + if len(data) == 0 { + return "" + } + if data[0] == tokenPrefixIP { + return net.IP(data[1:]).String() + } + return string(data[1:]) +} diff --git a/internal/handshake/token_generator_test.go b/internal/handshake/token_generator_test.go new file mode 100644 index 00000000..f2a2c0b3 --- /dev/null +++ b/internal/handshake/token_generator_test.go @@ -0,0 +1,127 @@ +package handshake + +import ( + "crypto/rand" + "encoding/asn1" + "net" + "time" + + "github.com/imroc/req/v3/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Token Generator", func() { + var tokenGen *TokenGenerator + + BeforeEach(func() { + var err error + tokenGen, err = NewTokenGenerator(rand.Reader) + Expect(err).ToNot(HaveOccurred()) + }) + + It("generates a token", func() { + ip := net.IPv4(127, 0, 0, 1) + token, err := tokenGen.NewRetryToken(&net.UDPAddr{IP: ip, Port: 1337}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(token).ToNot(BeEmpty()) + }) + + It("works with nil tokens", func() { + token, err := tokenGen.DecodeToken(nil) + Expect(err).ToNot(HaveOccurred()) + Expect(token).To(BeNil()) + }) + + It("accepts a valid token", func() { + ip := net.IPv4(192, 168, 0, 1) + tokenEnc, err := tokenGen.NewRetryToken( + &net.UDPAddr{IP: ip, Port: 1337}, + nil, + nil, + ) + Expect(err).ToNot(HaveOccurred()) + token, err := tokenGen.DecodeToken(tokenEnc) + Expect(err).ToNot(HaveOccurred()) + Expect(token.RemoteAddr).To(Equal("192.168.0.1")) + Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) + Expect(token.OriginalDestConnectionID.Len()).To(BeZero()) + Expect(token.RetrySrcConnectionID.Len()).To(BeZero()) + }) + + It("saves the connection ID", func() { + tokenEnc, err := tokenGen.NewRetryToken( + &net.UDPAddr{}, + protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + ) + Expect(err).ToNot(HaveOccurred()) + token, err := tokenGen.DecodeToken(tokenEnc) + Expect(err).ToNot(HaveOccurred()) + Expect(token.OriginalDestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) + Expect(token.RetrySrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) + }) + + It("rejects invalid tokens", func() { + _, err := tokenGen.DecodeToken([]byte("invalid token")) + Expect(err).To(HaveOccurred()) + }) + + It("rejects tokens that cannot be decoded", func() { + token, err := tokenGen.tokenProtector.NewToken([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + _, err = tokenGen.DecodeToken(token) + Expect(err).To(HaveOccurred()) + }) + + It("rejects tokens that can be decoded, but have additional payload", func() { + t, err := asn1.Marshal(token{RemoteAddr: []byte("foobar")}) + Expect(err).ToNot(HaveOccurred()) + t = append(t, []byte("rest")...) + enc, err := tokenGen.tokenProtector.NewToken(t) + Expect(err).ToNot(HaveOccurred()) + _, err = tokenGen.DecodeToken(enc) + Expect(err).To(MatchError("rest when unpacking token: 4")) + }) + + // we don't generate tokens that have no data, but we should be able to handle them if we receive one for whatever reason + It("doesn't panic if a tokens has no data", func() { + t, err := asn1.Marshal(token{RemoteAddr: []byte("")}) + Expect(err).ToNot(HaveOccurred()) + enc, err := tokenGen.tokenProtector.NewToken(t) + Expect(err).ToNot(HaveOccurred()) + _, err = tokenGen.DecodeToken(enc) + Expect(err).ToNot(HaveOccurred()) + }) + + It("works with an IPv6 addresses ", func() { + addresses := []string{ + "2001:db8::68", + "2001:0000:4136:e378:8000:63bf:3fff:fdd2", + "2001::1", + "ff01:0:0:0:0:0:0:2", + } + for _, addr := range addresses { + ip := net.ParseIP(addr) + Expect(ip).ToNot(BeNil()) + raddr := &net.UDPAddr{IP: ip, Port: 1337} + tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) + Expect(err).ToNot(HaveOccurred()) + token, err := tokenGen.DecodeToken(tokenEnc) + Expect(err).ToNot(HaveOccurred()) + Expect(token.RemoteAddr).To(Equal(ip.String())) + Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) + } + }) + + It("uses the string representation an address that is not a UDP address", func() { + raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} + tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) + Expect(err).ToNot(HaveOccurred()) + token, err := tokenGen.DecodeToken(tokenEnc) + Expect(err).ToNot(HaveOccurred()) + Expect(token.RemoteAddr).To(Equal("192.168.13.37:1337")) + Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) + }) +}) diff --git a/internal/handshake/token_protector.go b/internal/handshake/token_protector.go new file mode 100644 index 00000000..650f230b --- /dev/null +++ b/internal/handshake/token_protector.go @@ -0,0 +1,89 @@ +package handshake + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "fmt" + "io" + + "golang.org/x/crypto/hkdf" +) + +// TokenProtector is used to create and verify a token +type tokenProtector interface { + // NewToken creates a new token + NewToken([]byte) ([]byte, error) + // DecodeToken decodes a token + DecodeToken([]byte) ([]byte, error) +} + +const ( + tokenSecretSize = 32 + tokenNonceSize = 32 +) + +// tokenProtector is used to create and verify a token +type tokenProtectorImpl struct { + rand io.Reader + secret []byte +} + +// newTokenProtector creates a source for source address tokens +func newTokenProtector(rand io.Reader) (tokenProtector, error) { + secret := make([]byte, tokenSecretSize) + if _, err := rand.Read(secret); err != nil { + return nil, err + } + return &tokenProtectorImpl{ + rand: rand, + secret: secret, + }, nil +} + +// NewToken encodes data into a new token. +func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { + nonce := make([]byte, tokenNonceSize) + if _, err := s.rand.Read(nonce); err != nil { + return nil, err + } + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil +} + +// DecodeToken decodes a token. +func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) { + if len(p) < tokenNonceSize { + return nil, fmt.Errorf("token too short: %d", len(p)) + } + nonce := p[:tokenNonceSize] + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil) +} + +func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { + h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source")) + key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 + if _, err := io.ReadFull(h, key); err != nil { + return nil, nil, err + } + aeadNonce := make([]byte, 12) + if _, err := io.ReadFull(h, aeadNonce); err != nil { + return nil, nil, err + } + c, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, nil, err + } + return aead, aeadNonce, nil +} diff --git a/internal/handshake/token_protector_test.go b/internal/handshake/token_protector_test.go new file mode 100644 index 00000000..7171e865 --- /dev/null +++ b/internal/handshake/token_protector_test.go @@ -0,0 +1,67 @@ +package handshake + +import ( + "crypto/rand" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type zeroReader struct{} + +func (r *zeroReader) Read(b []byte) (int, error) { + for i := range b { + b[i] = 0 + } + return len(b), nil +} + +var _ = Describe("Token Protector", func() { + var tp tokenProtector + + BeforeEach(func() { + var err error + tp, err = newTokenProtector(rand.Reader) + Expect(err).ToNot(HaveOccurred()) + }) + + It("uses the random source", func() { + tp1, err := newTokenProtector(&zeroReader{}) + Expect(err).ToNot(HaveOccurred()) + tp2, err := newTokenProtector(&zeroReader{}) + Expect(err).ToNot(HaveOccurred()) + t1, err := tp1.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + t2, err := tp2.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + Expect(t1).To(Equal(t2)) + tp3, err := newTokenProtector(rand.Reader) + Expect(err).ToNot(HaveOccurred()) + t3, err := tp3.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + Expect(t3).ToNot(Equal(t1)) + }) + + It("encodes and decodes tokens", func() { + token, err := tp.NewToken([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(token).ToNot(ContainSubstring("foobar")) + decoded, err := tp.DecodeToken(token) + Expect(err).ToNot(HaveOccurred()) + Expect(decoded).To(Equal([]byte("foobar"))) + }) + + It("fails deconding invalid tokens", func() { + token, err := tp.NewToken([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + token = token[1:] // remove the first byte + _, err = tp.DecodeToken(token) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("message authentication failed")) + }) + + It("errors when decoding too short tokens", func() { + _, err := tp.DecodeToken([]byte("foobar")) + Expect(err).To(MatchError("token too short: 6")) + }) +}) diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go new file mode 100644 index 00000000..13ddbc22 --- /dev/null +++ b/internal/handshake/updatable_aead.go @@ -0,0 +1,324 @@ +package handshake + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "encoding/binary" + "fmt" + "github.com/lucas-clemente/quic-go" + "time" + + "github.com/imroc/req/v3/internal/logging" + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + "github.com/imroc/req/v3/internal/qtls" + "github.com/imroc/req/v3/internal/utils" +) + +// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. +// It's a package-level variable to allow modifying it for testing purposes. +var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval + +type updatableAEAD struct { + suite *qtls.CipherSuiteTLS13 + + keyPhase protocol.KeyPhase + largestAcked protocol.PacketNumber + firstPacketNumber protocol.PacketNumber + handshakeConfirmed bool + + keyUpdateInterval uint64 + invalidPacketLimit uint64 + invalidPacketCount uint64 + + // Time when the keys should be dropped. Keys are dropped on the next call to Open(). + prevRcvAEADExpiry time.Time + prevRcvAEAD cipher.AEAD + + firstRcvdWithCurrentKey protocol.PacketNumber + firstSentWithCurrentKey protocol.PacketNumber + highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) + numRcvdWithCurrentKey uint64 + numSentWithCurrentKey uint64 + rcvAEAD cipher.AEAD + sendAEAD cipher.AEAD + // caches cipher.AEAD.Overhead(). This speeds up calls to Overhead(). + aeadOverhead int + + nextRcvAEAD cipher.AEAD + nextSendAEAD cipher.AEAD + nextRcvTrafficSecret []byte + nextSendTrafficSecret []byte + + headerDecrypter headerProtector + headerEncrypter headerProtector + + rttStats *utils.RTTStats + + tracer logging.ConnectionTracer + logger utils.Logger + version quic.VersionNumber + + // use a single slice to avoid allocations + nonceBuf []byte +} + +var ( + _ ShortHeaderOpener = &updatableAEAD{} + _ ShortHeaderSealer = &updatableAEAD{} +) + +func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version quic.VersionNumber) *updatableAEAD { + return &updatableAEAD{ + firstPacketNumber: protocol.InvalidPacketNumber, + largestAcked: protocol.InvalidPacketNumber, + firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, + firstSentWithCurrentKey: protocol.InvalidPacketNumber, + keyUpdateInterval: KeyUpdateInterval, + rttStats: rttStats, + tracer: tracer, + logger: logger, + version: version, + } +} + +func (a *updatableAEAD) rollKeys() { + if a.prevRcvAEAD != nil { + a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) + if a.tracer != nil { + a.tracer.DroppedKey(a.keyPhase - 1) + } + a.prevRcvAEADExpiry = time.Time{} + } + + a.keyPhase++ + a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber + a.firstSentWithCurrentKey = protocol.InvalidPacketNumber + a.numRcvdWithCurrentKey = 0 + a.numSentWithCurrentKey = 0 + a.prevRcvAEAD = a.rcvAEAD + a.rcvAEAD = a.nextRcvAEAD + a.sendAEAD = a.nextSendAEAD + + a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret) + a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret) + a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version) + a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version) +} + +func (a *updatableAEAD) startKeyDropTimer(now time.Time) { + d := 3 * a.rttStats.PTO(true) + a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d) + a.prevRcvAEADExpiry = now.Add(d) +} + +func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte { + return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) +} + +// For the client, this function is called before SetWriteKey. +// For the server, this function is called after SetWriteKey. +func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + a.rcvAEAD = createAEAD(suite, trafficSecret, a.version) + a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version) + if a.suite == nil { + a.setAEADParameters(a.rcvAEAD, suite) + } + + a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) + a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) +} + +// For the client, this function is called after SetReadKey. +// For the server, this function is called before SetWriteKey. +func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + a.sendAEAD = createAEAD(suite, trafficSecret, a.version) + a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) + if a.suite == nil { + a.setAEADParameters(a.sendAEAD, suite) + } + + a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) + a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version) +} + +func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) { + a.nonceBuf = make([]byte, aead.NonceSize()) + a.aeadOverhead = aead.Overhead() + a.suite = suite + switch suite.ID { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + a.invalidPacketLimit = protocol.InvalidPacketLimitAES + case tls.TLS_CHACHA20_POLY1305_SHA256: + a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha + default: + panic(fmt.Sprintf("unknown cipher suite %d", suite.ID)) + } +} + +func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { + return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN) +} + +func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { + dec, err := a.open(dst, src, rcvTime, pn, kp, ad) + if err == ErrDecryptionFailed { + a.invalidPacketCount++ + if a.invalidPacketCount >= a.invalidPacketLimit { + return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached} + } + } + if err == nil { + a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn) + } + return dec, err +} + +func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { + if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) { + a.prevRcvAEAD = nil + a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) + a.prevRcvAEADExpiry = time.Time{} + if a.tracer != nil { + a.tracer.DroppedKey(a.keyPhase - 1) + } + } + binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) + if kp != a.keyPhase.Bit() { + if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { + if a.prevRcvAEAD == nil { + return nil, ErrKeysDropped + } + // we updated the key, but the peer hasn't updated yet + dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err != nil { + err = ErrDecryptionFailed + } + return dec, err + } + // try opening the packet with the next key phase + dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err != nil { + return nil, ErrDecryptionFailed + } + // Opening succeeded. Check if the peer was allowed to update. + if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { + return nil, &qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "keys updated too quickly", + } + } + a.rollKeys() + a.logger.Debugf("Peer updated keys to %d", a.keyPhase) + // The peer initiated this key update. It's safe to drop the keys for the previous generation now. + // Start a timer to drop the previous key generation. + a.startKeyDropTimer(rcvTime) + if a.tracer != nil { + a.tracer.UpdatedKey(a.keyPhase, true) + } + a.firstRcvdWithCurrentKey = pn + return dec, err + } + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err != nil { + return dec, ErrDecryptionFailed + } + a.numRcvdWithCurrentKey++ + if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { + // We initiated the key updated, and now we received the first packet protected with the new key phase. + // Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys. + if a.keyPhase > 0 { + a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase) + a.startKeyDropTimer(rcvTime) + } + a.firstRcvdWithCurrentKey = pn + } + return dec, err +} + +func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { + a.firstSentWithCurrentKey = pn + } + if a.firstPacketNumber == protocol.InvalidPacketNumber { + a.firstPacketNumber = pn + } + a.numSentWithCurrentKey++ + binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad) +} + +func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { + if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && + pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 { + return &qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase), + } + } + a.largestAcked = pn + return nil +} + +func (a *updatableAEAD) SetHandshakeConfirmed() { + a.handshakeConfirmed = true +} + +func (a *updatableAEAD) updateAllowed() bool { + if !a.handshakeConfirmed { + return false + } + // the first key update is allowed as soon as the handshake is confirmed + return a.keyPhase == 0 || + // subsequent key updates as soon as a packet sent with that key phase has been acknowledged + (a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && + a.largestAcked != protocol.InvalidPacketNumber && + a.largestAcked >= a.firstSentWithCurrentKey) +} + +func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { + if !a.updateAllowed() { + return false + } + if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { + a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) + return true + } + if a.numSentWithCurrentKey >= a.keyUpdateInterval { + a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) + return true + } + return false +} + +func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { + if a.shouldInitiateKeyUpdate() { + a.rollKeys() + a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) + if a.tracer != nil { + a.tracer.UpdatedKey(a.keyPhase, false) + } + } + return a.keyPhase.Bit() +} + +func (a *updatableAEAD) Overhead() int { + return a.aeadOverhead +} + +func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes) +} + +func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes) +} + +func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber { + return a.firstPacketNumber +} diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go new file mode 100644 index 00000000..0246cc23 --- /dev/null +++ b/internal/handshake/updatable_aead_test.go @@ -0,0 +1,528 @@ +package handshake + +import ( + "crypto/rand" + "crypto/tls" + "fmt" + "time" + + "github.com/golang/mock/gomock" + + mocklogging "github.com/imroc/req/v3/internal/mocks/logging" + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + "github.com/imroc/req/v3/internal/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Updatable AEAD", func() { + DescribeTable("ChaCha test vector", + func(v quic.VersionNumber, expectedPayload, expectedPacket []byte) { + secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") + aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil, v) + chacha := cipherSuites[2] + Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256)) + aead.SetWriteKey(chacha, secret) + const pnOffset = 1 + header := splitHexString("4200bff4") + payloadOffset := len(header) + plaintext := splitHexString("01") + payload := aead.Seal(nil, plaintext, 654360564, header) + Expect(payload).To(Equal(expectedPayload)) + packet := append(header, payload...) + aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset]) + Expect(packet).To(Equal(expectedPacket)) + }, + Entry("QUIC v1", + protocol.Version1, + splitHexString("655e5cd55c41f69080575d7999c25a5bfb"), + splitHexString("4cfe4189655e5cd55c41f69080575d7999c25a5bfb"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("0ae7b6b932bc27d786f4bc2bb20f2162ba"), + splitHexString("5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"), + ), + ) + + for _, ver := range []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { + v := ver + + Context(fmt.Sprintf("using version %s", v), func() { + for i := range cipherSuites { + cs := cipherSuites[i] + + Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { + var ( + client, server *updatableAEAD + serverTracer *mocklogging.MockConnectionTracer + rttStats *utils.RTTStats + ) + + BeforeEach(func() { + serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) + trafficSecret1 := make([]byte, 16) + trafficSecret2 := make([]byte, 16) + rand.Read(trafficSecret1) + rand.Read(trafficSecret2) + + rttStats = utils.NewRTTStats() + client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v) + server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger, v) + client.SetReadKey(cs, trafficSecret2) + client.SetWriteKey(cs, trafficSecret1) + server.SetReadKey(cs, trafficSecret1) + server.SetWriteKey(cs, trafficSecret2) + }) + + Context("header protection", func() { + It("encrypts and decrypts the header", func() { + var lastFiveBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + client.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0x1f != 0xb5&0x1f { + lastFiveBitsDifferent++ + } + Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + server.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) + }) + }) + + Context("message encryption", func() { + var msg, ad []byte + + BeforeEach(func() { + msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad = []byte("Donec in velit neque.") + }) + + It("encrypts and decrypts a message", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + + It("saves the first packet number", func() { + client.Seal(nil, msg, 0x1337, ad) + Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) + client.Seal(nil, msg, 0x1338, ad) + Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) + }) + + It("fails to open a message if the associated data is not the same", func() { + encrypted := client.Seal(nil, msg, 0x1337, ad) + _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("fails to open a message if the packet number is not the same", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("decodes the packet number", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) + }) + + It("ignores packets it can't decrypt for packet number derivation", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).To(HaveOccurred()) + Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) + }) + + It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() { + client.invalidPacketLimit = 10 + for i := 0; i < 9; i++ { + _, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + } + _, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).To(HaveOccurred()) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached)) + }) + + Context("key updates", func() { + Context("receiving key updates", func() { + It("updates keys", func() { + now := time.Now() + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted0 := server.Seal(nil, msg, 0x1337, ad) + server.rollKeys() + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted1 := server.Seal(nil, msg, 0x1337, ad) + Expect(encrypted0).ToNot(Equal(encrypted1)) + // expect opening to fail. The client didn't roll keys yet + _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + client.rollKeys() + decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + }) + + It("updates the keys when receiving a packet with the next key phase", func() { + now := time.Now() + // receive the first packet at key phase zero + encrypted0 := client.Seal(nil, msg, 0x42, ad) + decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + // send one packet at key phase zero + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _ = server.Seal(nil, msg, 0x1, ad) + // now received a message at key phase one + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x43, ad) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("opens a reordered packet with the old keys after an update", func() { + now := time.Now() + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("drops keys 3 PTOs after a key update", func() { + now := time.Now() + rttStats.UpdateRTT(10*time.Millisecond, 0, now) + pto := rttStats.PTO(true) + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrKeysDropped)) + }) + + It("allows the first key update immediately", func() { + // receive a packet at key phase one, before having sent or received any packets at key phase 0 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x1337, ad) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + }) + + It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { + client.rollKeys() + encrypted := client.Seal(nil, msg, 0x1337, ad) + encrypted = encrypted[:len(encrypted)-1] + _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("errors when the peer updates keys too frequently", func() { + server.rollKeys() + client.rollKeys() + // receive the first packet at key phase one + encrypted0 := client.Seal(nil, msg, 0x42, ad) + _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + // now receive a packet at key phase two, before having sent any packets + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x42, ad) + _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "keys updated too quickly", + })) + }) + }) + + Context("initiating key updates", func() { + const keyUpdateInterval = 20 + + BeforeEach(func() { + Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) + server.keyUpdateInterval = keyUpdateInterval + server.SetHandshakeConfirmed() + }) + + It("initiates a key update after sealing the maximum number of packets, for the first update", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + // the first update is allowed without receiving an acknowledgement + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() { + server.rollKeys() + client.rollKeys() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, pn, ad) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // receive an ACK for a packet sent in key phase 0 + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + + It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { + // First make sure that we update our keys. + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // Now that our keys are updated, send a packet using the new keys. + const nextPN = keyUpdateInterval + 1 + server.Seal(nil, msg, nextPN, ad) + // We haven't decrypted any packet in the new key phase yet. + // This means that the ACK must have been sent in the old key phase. + Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "received ACK for key phase 1, but peer didn't update keys", + })) + }) + + It("doesn't error before actually sending a packet in the new key phase", func() { + // First make sure that we update our keys. + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + // Now that our keys are updated, send a packet using the new keys. + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // We haven't decrypted any packet in the new key phase yet. + // This means that the ACK must have been sent in the old key phase. + Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) + }) + + It("initiates a key update after opening the maximum number of packets, for the first update", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted := client.Seal(nil, msg, pn, ad) + _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + } + // the first update is allowed without receiving an acknowledgement + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() { + server.rollKeys() + client.rollKeys() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted := client.Seal(nil, msg, pn, ad) + _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, 1, ad) + Expect(server.SetLargestAcked(1)).To(Succeed()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + + It("drops keys 3 PTOs after a key update", func() { + now := time.Now() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + Expect(server.SetLargestAcked(0)).To(Succeed()) + // Now we've initiated the first key update. + // Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there + threePTO := 3 * rttStats.PTO(false) + dataKeyPhaseZero := client.Seal(nil, msg, 1, ad) + _, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // Now receive a packet with key phase 1. + // This should start the timer to drop the keys after 3 PTOs. + client.rollKeys() + dataKeyPhaseOne := client.Seal(nil, msg, 10, ad) + t := now.Add(threePTO).Add(time.Second) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + // Make sure the keys are still here. + _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrKeysDropped)) + }) + + It("doesn't drop the first key generation too early", func() { + now := time.Now() + data1 := client.Seal(nil, msg, 1, ad) + _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + Expect(server.SetLargestAcked(pn)).To(Succeed()) + } + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // The server never received a packet at key phase 1. + // Make sure the key phase 0 is still there at a much later point. + data2 := client.Seal(nil, msg, 1, ad) + _, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + }) + + It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + const nextPN = keyUpdateInterval + 1 + // Send and receive an acknowledgement for a packet in key phase 1. + // We are now running a timer to drop the keys with 3 PTO. + server.Seal(nil, msg, nextPN, ad) + client.rollKeys() + dataKeyPhaseOne := client.Seal(nil, msg, 2, ad) + now := time.Now() + _, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.SetLargestAcked(nextPN)) + // Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over. + // This mean that we need to drop the keys for key phase 0 immediately. + client.rollKeys() + dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad) + gomock.InOrder( + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true), + ) + _, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + + It("drops keys early when we initiate another key update within the 3 PTO period", func() { + server.SetHandshakeConfirmed() + // send so many packets that we initiate the first key update + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // send so many packets that we initiate the next key update + for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, pn, ad) + } + client.rollKeys() + b = client.Seal(nil, []byte("foobar"), 2, []byte("ad")) + now := time.Now() + _, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed()) + gomock.InOrder( + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false), + ) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + // We haven't received an ACK for a packet sent in key phase 2 yet. + // Make sure we canceled the timer to drop the previous key phase. + b = client.Seal(nil, []byte("foobar"), 3, []byte("ad")) + _, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + }) + }) + }) + }) + }) + } + }) + } +}) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 6133594d..90f76c4b 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -19,6 +19,7 @@ import ( "github.com/imroc/req/v3/internal/common" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/netutil" reqtls "github.com/imroc/req/v3/internal/tls" "github.com/imroc/req/v3/internal/transport" "io" @@ -137,7 +138,7 @@ type Transport struct { transport.Interface connPoolOnce sync.Once - connPoolOrDef ClientConnPool // non-nil version of ConnPool + connPoolOrDef *clientConnPool // non-nil version of ConnPool } func (t *Transport) maxHeaderListSize() uint32 { @@ -161,60 +162,56 @@ func (t *Transport) pingTimeout() time.Duration { // ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2. // It returns a new HTTP/2 Transport for further configuration. // It returns an error if t1 has already been HTTP/2-enabled. -func ConfigureTransports(t1 transport.Interface) (*Transport, error) { - connPool := new(clientConnPool) - t2 := &Transport{ - ConnPool: noDialClientConnPool{connPool}, - Interface: t1, - } - connPool.t = t2 - if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil { - return nil, err - } - if t1.TLSClientConfig() == nil { - t1.SetTLSClientConfig(new(tls.Config)) - } - if !strSliceContains(t1.TLSClientConfig().NextProtos, "h2") { - t1.TLSClientConfig().NextProtos = append([]string{"h2"}, t1.TLSClientConfig().NextProtos...) - } - if !strSliceContains(t1.TLSClientConfig().NextProtos, "http/1.1") { - t1.TLSClientConfig().NextProtos = append(t1.TLSClientConfig().NextProtos, "http/1.1") - } - upgradeFn := func(authority string, c reqtls.Conn) http.RoundTripper { - addr := authorityAddr("https", authority) - if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { - go c.Close() - return erringRoundTripper{err} - } else if !used { - // Turns out we don't need this c. - // For example, two goroutines made requests to the same host - // at the same time, both kicking off TCP dials. (since protocol - // was unknown) - go c.Close() - } - return t2 - } - if m := t1.TLSNextProto(); len(m) == 0 { - t1.SetTLSNextProto(map[string]func(string, reqtls.Conn) http.RoundTripper{ - "h2": upgradeFn, - }) - } else { - m["h2"] = upgradeFn - } - return t2, nil -} - -func (t *Transport) connPool() ClientConnPool { +// func ConfigureTransports(t1 transport.Interface) (*Transport, error) { +// connPool := new(clientConnPool) +// t2 := &Transport{ +// ConnPool: noDialClientConnPool{connPool}, +// Interface: t1, +// } +// connPool.t = t2 +// if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil { +// return nil, err +// } +// if t1.TLSClientConfig() == nil { +// t1.SetTLSClientConfig(new(tls.Config)) +// } +// if !strSliceContains(t1.TLSClientConfig().NextProtos, "h2") { +// t1.TLSClientConfig().NextProtos = append([]string{"h2"}, t1.TLSClientConfig().NextProtos...) +// } +// if !strSliceContains(t1.TLSClientConfig().NextProtos, "http/1.1") { +// t1.TLSClientConfig().NextProtos = append(t1.TLSClientConfig().NextProtos, "http/1.1") +// } +// upgradeFn := func(authority string, c reqtls.Conn) http.RoundTripper { +// addr := authorityAddr("https", authority) +// if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { +// go c.Close() +// return erringRoundTripper{err} +// } else if !used { +// // Turns out we don't need this c. +// // For example, two goroutines made requests to the same host +// // at the same time, both kicking off TCP dials. (since protocol +// // was unknown) +// go c.Close() +// } +// return t2 +// } +// if m := t1.TLSNextProto(); len(m) == 0 { +// t1.SetTLSNextProto(map[string]func(string, reqtls.Conn) http.RoundTripper{ +// "h2": upgradeFn, +// }) +// } else { +// m["h2"] = upgradeFn +// } +// return t2, nil +// } + +func (t *Transport) connPool() *clientConnPool { t.connPoolOnce.Do(t.initConnPool) return t.connPoolOrDef } func (t *Transport) initConnPool() { - if t.ConnPool != nil { - t.connPoolOrDef = t.ConnPool - } else { - t.connPoolOrDef = &clientConnPool{t: t} - } + t.connPoolOrDef = &clientConnPool{t: t} } // ClientConn is the state of a single HTTP/2 client connection to an @@ -434,6 +431,10 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { return t.RoundTripOpt(req, RoundTripOpt{}) } +func (t *Transport) RoundTripOnlyCachedConn(req *http.Request) (*http.Response, error) { + return t.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) +} + // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) // and returns a host:port. The port 443 is added if needed. func authorityAddr(scheme string, authority string) (addr string) { @@ -455,15 +456,29 @@ func authorityAddr(scheme string, authority string) (addr string) { return net.JoinHostPort(host, port) } +func (t *Transport) AddConn(conn net.Conn, addr string) error { + _, err := t.connPool().addConnIfNeeded(addr, t, conn) + return err +} + // RoundTripOpt is like RoundTrip, but takes options. func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { return nil, errors.New("http2: unsupported scheme") } - addr := authorityAddr(req.URL.Scheme, req.URL.Host) + addr := netutil.AuthorityAddr(req.URL.Scheme, req.URL.Host) + var cc *ClientConn + var err error + if opt.OnlyCachedConn { + cc, err = t.connPool().getClientConn(req, addr, false) + if err != nil { + return nil, err + } + return cc.RoundTrip(req) + } for retry := 0; ; retry++ { - cc, err := t.connPool().GetClientConn(req, addr) + cc, err = t.connPool().getClientConn(req, addr, true) if err != nil { t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) return nil, err @@ -501,9 +516,7 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res // connected from previous requests but are now sitting idle. // It does not interrupt any connections currently in use. func (t *Transport) CloseIdleConnections() { - if cp, ok := t.connPool().(clientConnPoolIdleCloser); ok { - cp.closeIdleConnections() - } + t.connPool().closeIdleConnections() } var ( @@ -2975,18 +2988,6 @@ func isConnectionCloseRequest(req *http.Request) bool { return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close") } -// registerHTTPSProtocol calls Transport.RegisterProtocol but -// converting panics into errors. -func registerHTTPSProtocol(t transport.Interface, rt noDialH2RoundTripper) (err error) { - defer func() { - if e := recover(); e != nil { - err = fmt.Errorf("%v", e) - } - }() - t.RegisterProtocol("https", rt) - return nil -} - // noDialH2RoundTripper is a RoundTripper which only tries to complete the request // if there's already has a cached connection to the host. // (The field is exported so it can be accessed via reflect from net/http; tested diff --git a/internal/http3/body.go b/internal/http3/body.go new file mode 100644 index 00000000..b3d1afd7 --- /dev/null +++ b/internal/http3/body.go @@ -0,0 +1,130 @@ +package http3 + +import ( + "context" + "io" + "net" + + "github.com/lucas-clemente/quic-go" +) + +// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by: +// * for the server: the http.Request.Body +// * for the client: the http.Response.Body +// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set. +// When a stream is taken over, it's the caller's responsibility to close the stream. +type HTTPStreamer interface { + HTTPStream() Stream +} + +type StreamCreator interface { + OpenStream() (quic.Stream, error) + OpenStreamSync(context.Context) (quic.Stream, error) + OpenUniStream() (quic.SendStream, error) + OpenUniStreamSync(context.Context) (quic.SendStream, error) + LocalAddr() net.Addr + RemoteAddr() net.Addr +} + +var _ StreamCreator = quic.Connection(nil) + +// A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body. +// It is used by WebTransport to create WebTransport streams after a session has been established. +type Hijacker interface { + StreamCreator() StreamCreator +} + +// The body of a http.Request or http.Response. +type body struct { + str quic.Stream + + wasHijacked bool // set when HTTPStream is called +} + +var ( + _ io.ReadCloser = &body{} + _ HTTPStreamer = &body{} +) + +func newRequestBody(str Stream) *body { + return &body{str: str} +} + +func (r *body) HTTPStream() Stream { + r.wasHijacked = true + return r.str +} + +func (r *body) wasStreamHijacked() bool { + return r.wasHijacked +} + +func (r *body) Read(b []byte) (int, error) { + return r.str.Read(b) +} + +func (r *body) Close() error { + r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + return nil +} + +type hijackableBody struct { + body + conn quic.Connection // only needed to implement Hijacker + + // only set for the http.Response + // The channel is closed when the user is done with this response: + // either when Read() errors, or when Close() is called. + reqDone chan<- struct{} + reqDoneClosed bool +} + +var ( + _ Hijacker = &hijackableBody{} + _ HTTPStreamer = &hijackableBody{} +) + +func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody { + return &hijackableBody{ + body: body{ + str: str, + }, + reqDone: done, + conn: conn, + } +} + +func (r *hijackableBody) StreamCreator() StreamCreator { + return r.conn +} + +func (r *hijackableBody) Read(b []byte) (int, error) { + n, err := r.str.Read(b) + if err != nil { + r.requestDone() + } + return n, err +} + +func (r *hijackableBody) requestDone() { + if r.reqDoneClosed || r.reqDone == nil { + return + } + close(r.reqDone) + r.reqDoneClosed = true +} + +func (r *body) StreamID() quic.StreamID { + return r.str.StreamID() +} + +func (r *hijackableBody) Close() error { + r.requestDone() + // If the EOF was read, CancelRead() is a no-op. + r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + return nil +} + +func (r *hijackableBody) HTTPStream() Stream { + return r.str +} diff --git a/internal/http3/body_test.go b/internal/http3/body_test.go new file mode 100644 index 00000000..ee289d1d --- /dev/null +++ b/internal/http3/body_test.go @@ -0,0 +1,54 @@ +package http3 + +import ( + "errors" + + mockquic "github.com/imroc/req/v3/internal/mocks/quic" + "github.com/lucas-clemente/quic-go" + + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Response Body", func() { + var reqDone chan struct{} + + BeforeEach(func() { reqDone = make(chan struct{}) }) + + It("closes the reqDone channel when Read errors", func() { + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error")) + rb := newResponseBody(str, nil, reqDone) + _, err := rb.Read([]byte{0}) + Expect(err).To(MatchError("test error")) + Expect(reqDone).To(BeClosed()) + }) + + It("allows multiple calls to Read, when Read errors", func() { + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error")).Times(2) + rb := newResponseBody(str, nil, reqDone) + _, err := rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + Expect(reqDone).To(BeClosed()) + _, err = rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + }) + + It("closes responses", func() { + str := mockquic.NewMockStream(mockCtrl) + rb := newResponseBody(str, nil, reqDone) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + Expect(rb.Close()).To(Succeed()) + }) + + It("allows multiple calls to Close", func() { + str := mockquic.NewMockStream(mockCtrl) + rb := newResponseBody(str, nil, reqDone) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2) + Expect(rb.Close()).To(Succeed()) + Expect(reqDone).To(BeClosed()) + Expect(rb.Close()).To(Succeed()) + }) +}) diff --git a/internal/http3/client.go b/internal/http3/client.go new file mode 100644 index 00000000..200b742f --- /dev/null +++ b/internal/http3/client.go @@ -0,0 +1,428 @@ +package http3 + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qtls" + "github.com/imroc/req/v3/internal/quicvarint" + "github.com/imroc/req/v3/internal/utils" + "github.com/lucas-clemente/quic-go" + "github.com/marten-seemann/qpack" + "io" + "net/http" + "strconv" + "sync" +) + +// MethodGet0RTT allows a GET request to be sent using 0-RTT. +// Note that 0-RTT data doesn't provide replay protection. +const MethodGet0RTT = "GET_0RTT" + +const ( + defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB +) + +var defaultQuicConfig = &quic.Config{ + MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams + // KeepAlivePeriod: 10 * time.Second, + Versions: []quic.VersionNumber{protocol.VersionTLS}, +} + +type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) + +var dialAddr = quic.DialAddrEarlyContext + +type roundTripperOpts struct { + DisableCompression bool + EnableDatagram bool + MaxHeaderBytes int64 + AdditionalSettings map[uint64]uint64 + StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) + dump *dump.Dumper +} + +// client is a HTTP3 client doing requests +type client struct { + tlsConf *tls.Config + config *quic.Config + opts *roundTripperOpts + + dialOnce sync.Once + dialer dialFunc + handshakeErr error + + requestWriter *requestWriter + + decoder *qpack.Decoder + + hostname string + conn quic.EarlyConnection + + logger utils.Logger +} + +func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) { + if conf == nil { + conf = defaultQuicConfig.Clone() + } else if len(conf.Versions) == 0 { + conf = conf.Clone() + conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} + } + if len(conf.Versions) != 1 { + return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") + } + if conf.MaxIncomingStreams == 0 { + conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams + } + conf.EnableDatagrams = opts.EnableDatagram + logger := utils.DefaultLogger.WithPrefix("h3 client") + + if tlsConf == nil { + tlsConf = &tls.Config{} + } else { + tlsConf = tlsConf.Clone() + } + // Replace existing ALPNs by H3 + tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} + + return &client{ + hostname: authorityAddr("https", hostname), + tlsConf: tlsConf, + requestWriter: newRequestWriter(logger), + decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), + config: conf, + opts: opts, + dialer: dialer, + logger: logger, + }, nil +} + +func (c *client) dial(ctx context.Context) error { + var err error + if c.dialer != nil { + c.conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) + } else { + c.conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) + } + if err != nil { + return err + } + + // send the SETTINGs frame, using 0-RTT data, if possible + go func() { + if err := c.setupConn(); err != nil { + c.logger.Debugf("Setting up connection failed: %s", err) + c.conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "") + } + }() + + if c.opts.StreamHijacker != nil { + go c.handleBidirectionalStreams() + } + go c.handleUnidirectionalStreams() + return nil +} + +func (c *client) setupConn() error { + // open the control stream + str, err := c.conn.OpenUniStream() + if err != nil { + return err + } + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) + // send the SETTINGS frame + (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Write(buf) + _, err = str.Write(buf.Bytes()) + return err +} + +func (c *client) handleBidirectionalStreams() { + for { + str, err := c.conn.AcceptStream(context.Background()) + if err != nil { + c.logger.Debugf("accepting bidirectional stream failed: %s", err) + return + } + go func(str quic.Stream) { + _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) { + return c.opts.StreamHijacker(ft, c.conn, str, e) + }) + if err == errHijacked { + return + } + if err != nil { + c.logger.Debugf("error handling stream: %s", err) + } + c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream") + }(str) + } +} + +func (c *client) handleUnidirectionalStreams() { + for { + str, err := c.conn.AcceptUniStream(context.Background()) + if err != nil { + c.logger.Debugf("accepting unidirectional stream failed: %s", err) + return + } + + go func(str quic.ReceiveStream) { + streamType, err := quicvarint.Read(quicvarint.NewReader(str)) + if err != nil { + if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) { + return + } + c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) + return + } + // We're only interested in the control stream here. + switch streamType { + case streamTypeControlStream: + case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream: + // Our QPACK implementation doesn't use the dynamic table yet. + // TODO: check that only one stream of each type is opened. + return + case streamTypePushStream: + // We never increased the Push ID, so we don't expect any push streams. + c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") + return + default: + if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) { + return + } + str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) + return + } + f, err := parseNextFrame(str, nil) + if err != nil { + c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") + return + } + sf, ok := f.(*settingsFrame) + if !ok { + c.conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") + return + } + if !sf.Datagram { + return + } + // If datagram support was enabled on our side as well as on the server side, + // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. + // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). + if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams { + c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") + } + }(str) + } +} + +func (c *client) Close() error { + if c.conn == nil { + return nil + } + return c.conn.CloseWithError(quic.ApplicationErrorCode(errorNoError), "") +} + +func (c *client) maxHeaderBytes() uint64 { + if c.opts.MaxHeaderBytes <= 0 { + return defaultMaxResponseHeaderBytes + } + return uint64(c.opts.MaxHeaderBytes) +} + +// RoundTripOpt executes a request and returns a response +func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { + return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) + } + + c.dialOnce.Do(func() { + c.handshakeErr = c.dial(req.Context()) + }) + + if c.handshakeErr != nil { + return nil, c.handshakeErr + } + + // Immediately send out this request, if this is a 0-RTT request. + if req.Method == MethodGet0RTT { + req.Method = http.MethodGet + } else { + // wait for the handshake to complete + select { + case <-c.conn.HandshakeComplete().Done(): + case <-req.Context().Done(): + return nil, req.Context().Err() + } + } + + str, err := c.conn.OpenStreamSync(req.Context()) + if err != nil { + return nil, err + } + + // Request Cancellation: + // This go routine keeps running even after RoundTripOpt() returns. + // It is shut down when the application is done processing the body. + reqDone := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) + str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + case <-reqDone: + } + }() + + rsp, rerr := c.doRequest(req, str, opt, reqDone) + if rerr.err != nil { // if any error occurred + close(reqDone) + if rerr.streamErr != 0 { // if it was a stream error + str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) + } + if rerr.connErr != 0 { // if it was a connection error + var reason string + if rerr.err != nil { + reason = rerr.err.Error() + } + c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) + } + } + return rsp, rerr.err +} + +func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error { + defer body.Close() + b := make([]byte, bodyCopyBufferSize) + for { + n, rerr := body.Read(b) + if n == 0 { + if rerr == nil { + continue + } + if rerr == io.EOF { + break + } + } + if _, err := str.Write(b[:n]); err != nil { + return err + } + if rerr != nil { + if rerr == io.EOF { + break + } + str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) + return rerr + } + } + return nil +} + +func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) { + var requestGzip bool + if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { + requestGzip = true + } + var dumps []*dump.Dumper + for _, dump := range dump.GetDumpers(req.Context(), c.opts.dump) { + if dump.RequestHeader() { + dumps = append(dumps, dump) + } + } + if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip, dumps); err != nil { + return nil, newStreamError(errorInternalError, err) + } + + if req.Body == nil && !opt.DontCloseRequestStream { + str.Close() + } + + hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) + if req.Body != nil { + // send the request body asynchronously + go func() { + if err := c.sendRequestBody(hstr, req.Body); err != nil { + c.logger.Errorf("Error writing request: %s", err) + } + if !opt.DontCloseRequestStream { + hstr.Close() + } + }() + } + + frame, err := parseNextFrame(str, nil) + if err != nil { + return nil, newStreamError(errorFrameError, err) + } + hf, ok := frame.(*headersFrame) + if !ok { + return nil, newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) + } + if hf.Length > c.maxHeaderBytes() { + return nil, newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes())) + } + headerBlock := make([]byte, hf.Length) + if _, err := io.ReadFull(str, headerBlock); err != nil { + return nil, newStreamError(errorRequestIncomplete, err) + } + hfs, err := c.decoder.DecodeFull(headerBlock) + if err != nil { + // TODO: use the right error code + return nil, newConnError(errorGeneralProtocolError, err) + } + + connState := qtls.ToTLSConnectionState(c.conn.ConnectionState().TLS) + res := &http.Response{ + Proto: "HTTP/3.0", + ProtoMajor: 3, + Header: http.Header{}, + TLS: &connState, + } + for _, hf := range hfs { + switch hf.Name { + case ":status": + status, err := strconv.Atoi(hf.Value) + if err != nil { + return nil, newStreamError(errorGeneralProtocolError, errors.New("malformed non-numeric status pseudo header")) + } + res.StatusCode = status + res.Status = hf.Value + " " + http.StatusText(status) + default: + res.Header.Add(hf.Name, hf.Value) + } + } + respBody := newResponseBody(hstr, c.conn, reqDone) + + // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. + _, hasTransferEncoding := res.Header["Transfer-Encoding"] + isInformational := res.StatusCode >= 100 && res.StatusCode < 200 + isNoContent := res.StatusCode == 204 + isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300 + if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect { + res.ContentLength = -1 + if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 { + if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { + res.ContentLength = clen64 + } + } + } + + if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = newGzipReader(respBody) + res.Uncompressed = true + } else { + res.Body = respBody + } + + return res, requestError{} +} diff --git a/internal/http3/client_test.go b/internal/http3/client_test.go new file mode 100644 index 00000000..9952572b --- /dev/null +++ b/internal/http3/client_test.go @@ -0,0 +1,1003 @@ +package http3 + +import ( + "bytes" + "compress/gzip" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "time" + + mockquic "github.com/imroc/req/v3/internal/mocks/quic" + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/utils" + "github.com/lucas-clemente/quic-go" + "github.com/imroc/req/v3/internal/quicvarint" + + "github.com/golang/mock/gomock" + "github.com/marten-seemann/qpack" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Client", func() { + var ( + client *client + req *http.Request + origDialAddr = dialAddr + handshakeCtx context.Context // an already canceled context + ) + + BeforeEach(func() { + origDialAddr = dialAddr + hostname := "quic.clemente.io:1337" + var err error + client, err = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(client.hostname).To(Equal(hostname)) + + req, err = http.NewRequest("GET", "https://localhost:1337", nil) + Expect(err).ToNot(HaveOccurred()) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + handshakeCtx = ctx + }) + + AfterEach(func() { + dialAddr = origDialAddr + }) + + It("rejects quic.Configs that allow multiple QUIC versions", func() { + qconf := &quic.Config{ + Versions: []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1}, + } + _, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil) + Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) + }) + + It("uses the default QUIC and TLS config if none is give", func() { + client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + var dialAddrCalled bool + dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { + Expect(quicConf).To(Equal(defaultQuicConfig)) + Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3})) + Expect(quicConf.Versions).To(Equal([]quic.VersionNumber{protocol.Version1})) + dialAddrCalled = true + return nil, errors.New("test done") + } + client.RoundTripOpt(req, RoundTripOpt{}) + Expect(dialAddrCalled).To(BeTrue()) + }) + + It("adds the port to the hostname, if none is given", func() { + client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + var dialAddrCalled bool + dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { + Expect(hostname).To(Equal("quic.clemente.io:443")) + dialAddrCalled = true + return nil, errors.New("test done") + } + req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil) + Expect(err).ToNot(HaveOccurred()) + client.RoundTripOpt(req, RoundTripOpt{}) + Expect(dialAddrCalled).To(BeTrue()) + }) + + It("uses the TLS config and QUIC config", func() { + tlsConf := &tls.Config{ + ServerName: "foo.bar", + NextProtos: []string{"proto foo", "proto bar"}, + } + quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond} + client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) + Expect(err).ToNot(HaveOccurred()) + var dialAddrCalled bool + dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { + Expect(host).To(Equal("localhost:1337")) + Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) + Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3})) + Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) + dialAddrCalled = true + return nil, errors.New("test done") + } + client.RoundTripOpt(req, RoundTripOpt{}) + Expect(dialAddrCalled).To(BeTrue()) + // make sure the original tls.Config was not modified + Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) + }) + + It("uses the custom dialer, if provided", func() { + testErr := errors.New("test done") + tlsConf := &tls.Config{ServerName: "foo.bar"} + quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) + defer cancel() + var dialerCalled bool + dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { + Expect(ctxP).To(Equal(ctx)) + Expect(address).To(Equal("localhost:1337")) + Expect(tlsConfP.ServerName).To(Equal("foo.bar")) + Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) + dialerCalled = true + return nil, testErr + } + client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) + Expect(err).ToNot(HaveOccurred()) + _, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{}) + Expect(err).To(MatchError(testErr)) + Expect(dialerCalled).To(BeTrue()) + }) + + It("enables HTTP/3 Datagrams", func() { + testErr := errors.New("handshake error") + client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { + Expect(quicConf.EnableDatagrams).To(BeTrue()) + return nil, testErr + } + _, err = client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError(testErr)) + }) + + It("errors when dialing fails", func() { + testErr := errors.New("handshake error") + client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return nil, testErr + } + _, err = client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError(testErr)) + }) + + It("closes correctly if connection was not created", func() { + client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(client.Close()).To(Succeed()) + }) + + Context("validating the address", func() { + It("refuses to do requests for the wrong host", func() { + req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) + }) + + It("allows requests using a different scheme", func() { + testErr := errors.New("handshake error") + req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return nil, testErr + } + _, err = client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError(testErr)) + }) + }) + + Context("hijacking bidirectional streams", func() { + var ( + request *http.Request + conn *mockquic.MockEarlyConnection + settingsFrameWritten chan struct{} + ) + testDone := make(chan struct{}) + + BeforeEach(func() { + testDone = make(chan struct{}) + settingsFrameWritten = make(chan struct{}) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + defer GinkgoRecover() + close(settingsFrameWritten) + }) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return conn, nil + } + var err error + request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + testDone <- struct{}{} + Eventually(settingsFrameWritten).Should(BeClosed()) + }) + + It("hijacks a bidirectional stream of unknown frame type", func() { + frameTypeChan := make(chan FrameType, 1) + client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + frameTypeChan <- ft + return true, nil + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTripOpt(request, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { + frameTypeChan := make(chan FrameType, 1) + client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + frameTypeChan <- ft + return false, nil + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { + <-testDone + return nil, errors.New("test done") + }) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() + _, err := client.RoundTripOpt(request, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + }) + + It("closes the connection when hijacker returned error", func() { + frameTypeChan := make(chan FrameType, 1) + client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + frameTypeChan <- ft + return false, errors.New("error in hijacker") + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { + <-testDone + return nil, errors.New("test done") + }) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() + _, err := client.RoundTripOpt(request, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + }) + + It("handles errors that occur when reading the frame type", func() { + testErr := errors.New("test error") + unknownStr := mockquic.NewMockStream(mockCtrl) + done := make(chan struct{}) + client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { + defer close(done) + Expect(e).To(MatchError(testErr)) + Expect(ft).To(BeZero()) + Expect(str).To(Equal(unknownStr)) + return false, nil + } + + unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { + <-testDone + return nil, errors.New("test done") + }) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() + _, err := client.RoundTripOpt(request, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(done).Should(BeClosed()) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + }) + + Context("hijacking unidirectional streams", func() { + var ( + req *http.Request + conn *mockquic.MockEarlyConnection + settingsFrameWritten chan struct{} + ) + testDone := make(chan struct{}) + + BeforeEach(func() { + testDone = make(chan struct{}) + settingsFrameWritten = make(chan struct{}) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + defer GinkgoRecover() + close(settingsFrameWritten) + }) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return conn, nil + } + var err error + req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + testDone <- struct{}{} + Eventually(settingsFrameWritten).Should(BeClosed()) + }) + + It("hijacks an unidirectional stream of unknown stream type", func() { + streamTypeChan := make(chan StreamType, 1) + client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { + Expect(err).ToNot(HaveOccurred()) + streamTypeChan <- st + return true + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("handles errors that occur when reading the stream type", func() { + testErr := errors.New("test error") + done := make(chan struct{}) + unknownStr := mockquic.NewMockStream(mockCtrl) + client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { + defer close(done) + Expect(st).To(BeZero()) + Expect(str).To(Equal(unknownStr)) + Expect(err).To(MatchError(testErr)) + return true + } + + unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr) + conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(done).Should(BeClosed()) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { + streamTypeChan := make(chan StreamType, 1) + client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { + Expect(err).ToNot(HaveOccurred()) + streamTypeChan <- st + return false + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + }) + + Context("control stream handling", func() { + var ( + req *http.Request + conn *mockquic.MockEarlyConnection + settingsFrameWritten chan struct{} + ) + testDone := make(chan struct{}) + + BeforeEach(func() { + settingsFrameWritten = make(chan struct{}) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + defer GinkgoRecover() + close(settingsFrameWritten) + }) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return conn, nil + } + var err error + req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + testDone <- struct{}{} + Eventually(settingsFrameWritten).Should(BeClosed()) + }) + + It("parses the SETTINGS frame", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) + (&settingsFrame{}).Write(buf) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { + streamType := t + name := "encoder" + if streamType == streamTypeQPACKDecoderStream { + name = "decoder" + } + + It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamType) + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return str, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead + }) + } + + It("resets streams Other than the control stream and the QPACK streams", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, 1337) + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + done := make(chan struct{}) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)).Do(func(code quic.StreamErrorCode) { + close(done) + }) + + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return str, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(done).Should(BeClosed()) + }) + + It("errors when the first frame on the control stream is not a SETTINGS frame", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) + (&dataFrame{}).Write(buf) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + done := make(chan struct{}) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + defer GinkgoRecover() + Expect(code).To(BeEquivalentTo(errorMissingSettings)) + close(done) + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(done).Should(BeClosed()) + }) + + It("errors when parsing the frame on the control stream fails", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) + b := &bytes.Buffer{} + (&settingsFrame{}).Write(b) + buf.Write(b.Bytes()[:b.Len()-1]) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + done := make(chan struct{}) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + defer GinkgoRecover() + Expect(code).To(BeEquivalentTo(errorFrameError)) + close(done) + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(done).Should(BeClosed()) + }) + + It("errors when parsing the server opens a push stream", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypePushStream) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + done := make(chan struct{}) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + defer GinkgoRecover() + Expect(code).To(BeEquivalentTo(errorIDError)) + close(done) + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(done).Should(BeClosed()) + }) + + It("errors when the server advertises datagram support (and we enabled support for it)", func() { + client.opts.EnableDatagram = true + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) + (&settingsFrame{Datagram: true}).Write(buf) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) + done := make(chan struct{}) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { + defer GinkgoRecover() + Expect(code).To(BeEquivalentTo(errorSettingsError)) + Expect(reason).To(Equal("missing QUIC Datagram support")) + close(done) + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(done).Should(BeClosed()) + }) + }) + + Context("Doing requests", func() { + var ( + req *http.Request + str *mockquic.MockStream + conn *mockquic.MockEarlyConnection + settingsFrameWritten chan struct{} + ) + testDone := make(chan struct{}) + + getHeadersFrame := func(headers map[string]string) []byte { + buf := &bytes.Buffer{} + headerBuf := &bytes.Buffer{} + enc := qpack.NewEncoder(headerBuf) + for name, value := range headers { + Expect(enc.WriteField(qpack.HeaderField{Name: name, Value: value})).To(Succeed()) + } + Expect(enc.Close()).To(Succeed()) + (&headersFrame{Length: uint64(headerBuf.Len())}).Write(buf) + buf.Write(headerBuf.Bytes()) + return buf.Bytes() + } + + decodeHeader := func(str io.Reader) map[string]string { + fields := make(map[string]string) + decoder := qpack.NewDecoder(nil) + + frame, err := parseNextFrame(str, nil) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) + headersFrame := frame.(*headersFrame) + data := make([]byte, headersFrame.Length) + _, err = io.ReadFull(str, data) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + hfs, err := decoder.DecodeFull(data) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + for _, p := range hfs { + fields[p.Name] = p.Value + } + return fields + } + + getResponse := func(status int) []byte { + buf := &bytes.Buffer{} + rstr := mockquic.NewMockStream(mockCtrl) + rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() + rw := newResponseWriter(rstr, nil, utils.DefaultLogger) + rw.WriteHeader(status) + rw.Flush() + return buf.Bytes() + } + + BeforeEach(func() { + settingsFrameWritten = make(chan struct{}) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + defer GinkgoRecover() + r := bytes.NewReader(b) + streamType, err := quicvarint.Read(r) + Expect(err).ToNot(HaveOccurred()) + Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) + close(settingsFrameWritten) + }) // SETTINGS frame + str = mockquic.NewMockStream(mockCtrl) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return conn, nil + } + var err error + req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + testDone <- struct{}{} + Eventually(settingsFrameWritten).Should(BeClosed()) + }) + + It("errors if it can't open a stream", func() { + testErr := errors.New("stream open error") + conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError(testErr)) + }) + + It("performs a 0-RTT request", func() { + testErr := errors.New("stream open error") + req.Method = MethodGet0RTT + // don't EXPECT any calls to HandshakeComplete() + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + buf := &bytes.Buffer{} + str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() + str.EXPECT().Close() + str.EXPECT().CancelWrite(gomock.Any()) + str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { + return 0, testErr + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError(testErr)) + Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET")) + }) + + It("returns a response", func() { + rspBuf := bytes.NewBuffer(getResponse(418)) + gomock.InOrder( + conn.EXPECT().HandshakeComplete().Return(handshakeCtx), + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), + ) + str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) + str.EXPECT().Close() + str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() + rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.Proto).To(Equal("HTTP/3.0")) + Expect(rsp.ProtoMajor).To(Equal(3)) + Expect(rsp.StatusCode).To(Equal(418)) + }) + + Context("requests containing a Body", func() { + var strBuf *bytes.Buffer + + BeforeEach(func() { + strBuf = &bytes.Buffer{} + gomock.InOrder( + conn.EXPECT().HandshakeComplete().Return(handshakeCtx), + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), + ) + body := &mockBody{} + body.SetData([]byte("request body")) + var err error + req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body) + Expect(err).ToNot(HaveOccurred()) + str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() + }) + + It("sends a request", func() { + done := make(chan struct{}) + gomock.InOrder( + str.EXPECT().Close().Do(func() { close(done) }), + str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors + ) + // the response body is sent asynchronously, while already reading the response + str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { + <-done + return 0, errors.New("test done") + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("test done")) + hfs := decodeHeader(strBuf) + Expect(hfs).To(HaveKeyWithValue(":method", "POST")) + Expect(hfs).To(HaveKeyWithValue(":path", "/upload")) + }) + + It("returns the error that occurred when reading the body", func() { + req.Body.(*mockBody).readErr = errors.New("testErr") + done := make(chan struct{}) + gomock.InOrder( + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { + close(done) + }), + str.EXPECT().CancelWrite(gomock.Any()), + ) + + // the response body is sent asynchronously, while already reading the response + str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { + <-done + return 0, errors.New("test done") + }) + closed := make(chan struct{}) + str.EXPECT().Close().Do(func() { close(closed) }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("test done")) + Eventually(closed).Should(BeClosed()) + }) + + It("sets the Content-Length", func() { + done := make(chan struct{}) + buf := &bytes.Buffer{} + buf.Write(getHeadersFrame(map[string]string{ + ":status": "200", + "Content-Length": "1337", + })) + (&dataFrame{Length: 0x6}).Write(buf) + buf.Write([]byte("foobar")) + str.EXPECT().Close().Do(func() { close(done) }) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) + str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors + // the response body is sent asynchronously, while already reading the response + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + req, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).ToNot(HaveOccurred()) + Expect(req.ContentLength).To(BeEquivalentTo(1337)) + Eventually(done).Should(BeClosed()) + }) + + It("closes the connection when the first frame is not a HEADERS frame", func() { + buf := &bytes.Buffer{} + (&dataFrame{Length: 0x42}).Write(buf) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()) + closed := make(chan struct{}) + str.EXPECT().Close().Do(func() { close(closed) }) + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) + Eventually(closed).Should(BeClosed()) + }) + + It("cancels the stream when the HEADERS frame is too large", func() { + buf := &bytes.Buffer{} + (&headersFrame{Length: 1338}).Write(buf) + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)) + closed := make(chan struct{}) + str.EXPECT().Close().Do(func() { close(closed) }) + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) + Eventually(closed).Should(BeClosed()) + }) + }) + + Context("request cancellations", func() { + It("cancels a request while waiting for the handshake to complete", func() { + ctx, cancel := context.WithCancel(context.Background()) + req := req.WithContext(ctx) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + + errChan := make(chan error) + go func() { + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + errChan <- err + }() + Consistently(errChan).ShouldNot(Receive()) + cancel() + Eventually(errChan).Should(Receive(MatchError("context canceled"))) + }) + + It("cancels a request while the request is still in flight", func() { + ctx, cancel := context.WithCancel(context.Background()) + req := req.WithContext(ctx) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) + buf := &bytes.Buffer{} + str.EXPECT().Close().MaxTimes(1) + + str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) + + done := make(chan struct{}) + canceled := make(chan struct{}) + gomock.InOrder( + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), + ) + str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) + str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { + cancel() + <-canceled + return 0, errors.New("test done") + }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("test done")) + Eventually(done).Should(BeClosed()) + }) + + It("cancels a request after the response arrived", func() { + rspBuf := bytes.NewBuffer(getResponse(404)) + + ctx, cancel := context.WithCancel(context.Background()) + req := req.WithContext(ctx) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) + buf := &bytes.Buffer{} + str.EXPECT().Close().MaxTimes(1) + + done := make(chan struct{}) + str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) + str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).ToNot(HaveOccurred()) + cancel() + Eventually(done).Should(BeClosed()) + }) + }) + + Context("gzip compression", func() { + BeforeEach(func() { + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + }) + + It("adds the gzip header to requests", func() { + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + buf := &bytes.Buffer{} + str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) + gomock.InOrder( + str.EXPECT().Close(), + str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors + ) + str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("test done")) + hfs := decodeHeader(buf) + Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) + }) + + It("doesn't add gzip if the header disable it", func() { + client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + buf := &bytes.Buffer{} + str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) + gomock.InOrder( + str.EXPECT().Close(), + str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors + ) + str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) + _, err = client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("test done")) + hfs := decodeHeader(buf) + Expect(hfs).ToNot(HaveKey("accept-encoding")) + }) + + It("decompresses the response", func() { + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) + buf := &bytes.Buffer{} + rstr := mockquic.NewMockStream(mockCtrl) + rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() + rw := newResponseWriter(rstr, nil, utils.DefaultLogger) + rw.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(rw) + gz.Write([]byte("gzipped response")) + gz.Close() + rw.Flush() + str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + str.EXPECT().Close() + + rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(rsp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) + Expect(string(data)).To(Equal("gzipped response")) + Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) + Expect(rsp.Uncompressed).To(BeTrue()) + }) + + It("only decompresses the response if the response contains the right content-encoding header", func() { + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) + buf := &bytes.Buffer{} + rstr := mockquic.NewMockStream(mockCtrl) + rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() + rw := newResponseWriter(rstr, nil, utils.DefaultLogger) + rw.Write([]byte("not gzipped")) + rw.Flush() + str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + str.EXPECT().Close() + + rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(rsp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(data)).To(Equal("not gzipped")) + Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) + }) + }) + }) +}) diff --git a/internal/http3/error_codes.go b/internal/http3/error_codes.go new file mode 100644 index 00000000..d87eef4a --- /dev/null +++ b/internal/http3/error_codes.go @@ -0,0 +1,73 @@ +package http3 + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go" +) + +type errorCode quic.ApplicationErrorCode + +const ( + errorNoError errorCode = 0x100 + errorGeneralProtocolError errorCode = 0x101 + errorInternalError errorCode = 0x102 + errorStreamCreationError errorCode = 0x103 + errorClosedCriticalStream errorCode = 0x104 + errorFrameUnexpected errorCode = 0x105 + errorFrameError errorCode = 0x106 + errorExcessiveLoad errorCode = 0x107 + errorIDError errorCode = 0x108 + errorSettingsError errorCode = 0x109 + errorMissingSettings errorCode = 0x10a + errorRequestRejected errorCode = 0x10b + errorRequestCanceled errorCode = 0x10c + errorRequestIncomplete errorCode = 0x10d + errorMessageError errorCode = 0x10e + errorConnectError errorCode = 0x10f + errorVersionFallback errorCode = 0x110 + errorDatagramError errorCode = 0x4a1268 +) + +func (e errorCode) String() string { + switch e { + case errorNoError: + return "H3_NO_ERROR" + case errorGeneralProtocolError: + return "H3_GENERAL_PROTOCOL_ERROR" + case errorInternalError: + return "H3_INTERNAL_ERROR" + case errorStreamCreationError: + return "H3_STREAM_CREATION_ERROR" + case errorClosedCriticalStream: + return "H3_CLOSED_CRITICAL_STREAM" + case errorFrameUnexpected: + return "H3_FRAME_UNEXPECTED" + case errorFrameError: + return "H3_FRAME_ERROR" + case errorExcessiveLoad: + return "H3_EXCESSIVE_LOAD" + case errorIDError: + return "H3_ID_ERROR" + case errorSettingsError: + return "H3_SETTINGS_ERROR" + case errorMissingSettings: + return "H3_MISSING_SETTINGS" + case errorRequestRejected: + return "H3_REQUEST_REJECTED" + case errorRequestCanceled: + return "H3_REQUEST_CANCELLED" + case errorRequestIncomplete: + return "H3_INCOMPLETE_REQUEST" + case errorMessageError: + return "H3_MESSAGE_ERROR" + case errorConnectError: + return "H3_CONNECT_ERROR" + case errorVersionFallback: + return "H3_VERSION_FALLBACK" + case errorDatagramError: + return "H3_DATAGRAM_ERROR" + default: + return fmt.Sprintf("unknown error code: %#x", uint16(e)) + } +} diff --git a/internal/http3/error_codes_test.go b/internal/http3/error_codes_test.go new file mode 100644 index 00000000..e4aae37e --- /dev/null +++ b/internal/http3/error_codes_test.go @@ -0,0 +1,39 @@ +package http3 + +import ( + "go/ast" + "go/parser" + "go/token" + "path" + "runtime" + "strconv" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("error codes", func() { + It("has a string representation for every error code", func() { + // We parse the error code file, extract all constants, and verify that + // each of them has a string version. Go FTW! + _, thisfile, _, ok := runtime.Caller(0) + if !ok { + panic("Failed to get current frame") + } + filename := path.Join(path.Dir(thisfile), "error_codes.go") + fileAst, err := parser.ParseFile(token.NewFileSet(), filename, nil, 0) + Expect(err).NotTo(HaveOccurred()) + constSpecs := fileAst.Decls[2].(*ast.GenDecl).Specs + Expect(len(constSpecs)).To(BeNumerically(">", 4)) // at time of writing + for _, c := range constSpecs { + valString := c.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value + val, err := strconv.ParseInt(valString, 0, 64) + Expect(err).NotTo(HaveOccurred()) + Expect(errorCode(val).String()).ToNot(Equal("unknown error code")) + } + }) + + It("has a string representation for unknown error codes", func() { + Expect(errorCode(0x1337).String()).To(Equal("unknown error code: 0x1337")) + }) +}) diff --git a/internal/http3/frames.go b/internal/http3/frames.go new file mode 100644 index 00000000..f7f28913 --- /dev/null +++ b/internal/http3/frames.go @@ -0,0 +1,164 @@ +package http3 + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// FrameType is the frame type of a HTTP/3 frame +type FrameType uint64 + +type unknownFrameHandlerFunc func(FrameType, error) (processed bool, err error) + +type frame interface{} + +var errHijacked = errors.New("hijacked") + +func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) { + qr := quicvarint.NewReader(r) + for { + t, err := quicvarint.Read(qr) + if err != nil { + if unknownFrameHandler != nil { + hijacked, err := unknownFrameHandler(0, err) + if err != nil { + return nil, err + } + if hijacked { + return nil, errHijacked + } + } + return nil, err + } + // Call the unknownFrameHandler for frames not defined in the HTTP/3 spec + if t > 0xd && unknownFrameHandler != nil { + hijacked, err := unknownFrameHandler(FrameType(t), nil) + if err != nil { + return nil, err + } + if hijacked { + return nil, errHijacked + } + // If the unknownFrameHandler didn't process the frame, it is our responsibility to skip it. + } + l, err := quicvarint.Read(qr) + if err != nil { + return nil, err + } + + switch t { + case 0x0: + return &dataFrame{Length: l}, nil + case 0x1: + return &headersFrame{Length: l}, nil + case 0x4: + return parseSettingsFrame(r, l) + case 0x3: // CANCEL_PUSH + case 0x5: // PUSH_PROMISE + case 0x7: // GOAWAY + case 0xd: // MAX_PUSH_ID + } + // skip over unknown frames + if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil { + return nil, err + } + } +} + +type dataFrame struct { + Length uint64 +} + +func (f *dataFrame) Write(b *bytes.Buffer) { + quicvarint.Write(b, 0x0) + quicvarint.Write(b, f.Length) +} + +type headersFrame struct { + Length uint64 +} + +func (f *headersFrame) Write(b *bytes.Buffer) { + quicvarint.Write(b, 0x1) + quicvarint.Write(b, f.Length) +} + +const settingDatagram = 0xffd277 + +type settingsFrame struct { + Datagram bool + Other map[uint64]uint64 // all settings that we don't explicitly recognize +} + +func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { + if l > 8*(1<<10) { + return nil, fmt.Errorf("unexpected size for SETTINGS frame: %d", l) + } + buf := make([]byte, l) + if _, err := io.ReadFull(r, buf); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + frame := &settingsFrame{} + b := bytes.NewReader(buf) + var readDatagram bool + for b.Len() > 0 { + id, err := quicvarint.Read(b) + if err != nil { // should not happen. We allocated the whole frame already. + return nil, err + } + val, err := quicvarint.Read(b) + if err != nil { // should not happen. We allocated the whole frame already. + return nil, err + } + + switch id { + case settingDatagram: + if readDatagram { + return nil, fmt.Errorf("duplicate setting: %d", id) + } + readDatagram = true + if val != 0 && val != 1 { + return nil, fmt.Errorf("invalid value for H3_DATAGRAM: %d", val) + } + frame.Datagram = val == 1 + default: + if _, ok := frame.Other[id]; ok { + return nil, fmt.Errorf("duplicate setting: %d", id) + } + if frame.Other == nil { + frame.Other = make(map[uint64]uint64) + } + frame.Other[id] = val + } + } + return frame, nil +} + +func (f *settingsFrame) Write(b *bytes.Buffer) { + quicvarint.Write(b, 0x4) + var l protocol.ByteCount + for id, val := range f.Other { + l += quicvarint.Len(id) + quicvarint.Len(val) + } + if f.Datagram { + l += quicvarint.Len(settingDatagram) + quicvarint.Len(1) + } + quicvarint.Write(b, uint64(l)) + if f.Datagram { + quicvarint.Write(b, settingDatagram) + quicvarint.Write(b, 1) + } + for id, val := range f.Other { + quicvarint.Write(b, id) + quicvarint.Write(b, val) + } +} diff --git a/internal/http3/frames_test.go b/internal/http3/frames_test.go new file mode 100644 index 00000000..cf0efe9d --- /dev/null +++ b/internal/http3/frames_test.go @@ -0,0 +1,245 @@ +package http3 + +import ( + "bytes" + "errors" + "fmt" + "io" + + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type errReader struct{ err error } + +func (e errReader) Read([]byte) (int, error) { return 0, e.err } + +var _ = Describe("Frames", func() { + appendVarInt := func(b []byte, val uint64) []byte { + buf := &bytes.Buffer{} + quicvarint.Write(buf, val) + return append(b, buf.Bytes()...) + } + + It("skips unknown frames", func() { + data := appendVarInt(nil, 0xdeadbeef) // type byte + data = appendVarInt(data, 0x42) + data = append(data, make([]byte, 0x42)...) + buf := bytes.NewBuffer(data) + (&dataFrame{Length: 0x1234}).Write(buf) + frame, err := parseNextFrame(buf, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) + Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1234))) + }) + + Context("DATA frames", func() { + It("parses", func() { + data := appendVarInt(nil, 0) // type byte + data = appendVarInt(data, 0x1337) + frame, err := parseNextFrame(bytes.NewReader(data), nil) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) + Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1337))) + }) + + It("writes", func() { + buf := &bytes.Buffer{} + (&dataFrame{Length: 0xdeadbeef}).Write(buf) + frame, err := parseNextFrame(buf, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) + Expect(frame.(*dataFrame).Length).To(Equal(uint64(0xdeadbeef))) + }) + }) + + Context("HEADERS frames", func() { + It("parses", func() { + data := appendVarInt(nil, 1) // type byte + data = appendVarInt(data, 0x1337) + frame, err := parseNextFrame(bytes.NewReader(data), nil) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) + Expect(frame.(*headersFrame).Length).To(Equal(uint64(0x1337))) + }) + + It("writes", func() { + buf := &bytes.Buffer{} + (&headersFrame{Length: 0xdeadbeef}).Write(buf) + frame, err := parseNextFrame(buf, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) + Expect(frame.(*headersFrame).Length).To(Equal(uint64(0xdeadbeef))) + }) + }) + + Context("SETTINGS frames", func() { + It("parses", func() { + settings := appendVarInt(nil, 13) + settings = appendVarInt(settings, 37) + settings = appendVarInt(settings, 0xdead) + settings = appendVarInt(settings, 0xbeef) + data := appendVarInt(nil, 4) // type byte + data = appendVarInt(data, uint64(len(settings))) + data = append(data, settings...) + frame, err := parseNextFrame(bytes.NewReader(data), nil) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&settingsFrame{})) + sf := frame.(*settingsFrame) + Expect(sf.Other).To(HaveKeyWithValue(uint64(13), uint64(37))) + Expect(sf.Other).To(HaveKeyWithValue(uint64(0xdead), uint64(0xbeef))) + }) + + It("rejects duplicate settings", func() { + settings := appendVarInt(nil, 13) + settings = appendVarInt(settings, 37) + settings = appendVarInt(settings, 13) + settings = appendVarInt(settings, 38) + data := appendVarInt(nil, 4) // type byte + data = appendVarInt(data, uint64(len(settings))) + data = append(data, settings...) + _, err := parseNextFrame(bytes.NewReader(data), nil) + Expect(err).To(MatchError("duplicate setting: 13")) + }) + + It("writes", func() { + sf := &settingsFrame{Other: map[uint64]uint64{ + 1: 2, + 99: 999, + 13: 37, + }} + buf := &bytes.Buffer{} + sf.Write(buf) + frame, err := parseNextFrame(buf, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(sf)) + }) + + It("errors on EOF", func() { + sf := &settingsFrame{Other: map[uint64]uint64{ + 13: 37, + 0xdeadbeef: 0xdecafbad, + }} + buf := &bytes.Buffer{} + sf.Write(buf) + + data := buf.Bytes() + _, err := parseNextFrame(bytes.NewReader(data), nil) + Expect(err).ToNot(HaveOccurred()) + + for i := range data { + b := make([]byte, i) + copy(b, data[:i]) + _, err := parseNextFrame(bytes.NewReader(b), nil) + Expect(err).To(MatchError(io.EOF)) + } + }) + + Context("H3_DATAGRAM", func() { + It("reads the H3_DATAGRAM value", func() { + settings := appendVarInt(nil, settingDatagram) + settings = appendVarInt(settings, 1) + data := appendVarInt(nil, 4) // type byte + data = appendVarInt(data, uint64(len(settings))) + data = append(data, settings...) + f, err := parseNextFrame(bytes.NewReader(data), nil) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&settingsFrame{})) + sf := f.(*settingsFrame) + Expect(sf.Datagram).To(BeTrue()) + }) + + It("rejects duplicate H3_DATAGRAM entries", func() { + settings := appendVarInt(nil, settingDatagram) + settings = appendVarInt(settings, 1) + settings = appendVarInt(settings, settingDatagram) + settings = appendVarInt(settings, 1) + data := appendVarInt(nil, 4) // type byte + data = appendVarInt(data, uint64(len(settings))) + data = append(data, settings...) + _, err := parseNextFrame(bytes.NewReader(data), nil) + Expect(err).To(MatchError(fmt.Sprintf("duplicate setting: %d", settingDatagram))) + }) + + It("rejects invalid values for the H3_DATAGRAM entry", func() { + settings := appendVarInt(nil, settingDatagram) + settings = appendVarInt(settings, 1337) + data := appendVarInt(nil, 4) // type byte + data = appendVarInt(data, uint64(len(settings))) + data = append(data, settings...) + _, err := parseNextFrame(bytes.NewReader(data), nil) + Expect(err).To(MatchError("invalid value for H3_DATAGRAM: 1337")) + }) + + It("writes the H3_DATAGRAM setting", func() { + sf := &settingsFrame{Datagram: true} + buf := &bytes.Buffer{} + sf.Write(buf) + frame, err := parseNextFrame(buf, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(sf)) + }) + }) + }) + + Context("hijacking", func() { + It("reads a frame without hijacking the stream", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, 1337) + customFrameContents := []byte("foobar") + buf.Write(customFrameContents) + + var called bool + _, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + Expect(ft).To(BeEquivalentTo(1337)) + called = true + b := make([]byte, 3) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(string(b)).To(Equal("foo")) + return true, nil + }) + Expect(err).To(MatchError(errHijacked)) + Expect(called).To(BeTrue()) + }) + + It("passes on errors that occur when reading the frame type", func() { + testErr := errors.New("test error") + var called bool + _, err := parseNextFrame(errReader{err: testErr}, func(ft FrameType, e error) (hijacked bool, err error) { + Expect(e).To(MatchError(testErr)) + Expect(ft).To(BeZero()) + called = true + return true, nil + }) + Expect(err).To(MatchError(errHijacked)) + Expect(called).To(BeTrue()) + }) + + It("reads a frame without hijacking the stream", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, 1337) + customFrameContents := []byte("custom frame") + quicvarint.Write(buf, uint64(len(customFrameContents))) + buf.Write(customFrameContents) + (&dataFrame{Length: 6}).Write(buf) + buf.WriteString("foobar") + + var called bool + frame, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + Expect(ft).To(BeEquivalentTo(1337)) + called = true + return false, nil + }) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(&dataFrame{Length: 6})) + Expect(called).To(BeTrue()) + }) + }) +}) diff --git a/internal/http3/gzip_reader.go b/internal/http3/gzip_reader.go new file mode 100644 index 00000000..01983ac7 --- /dev/null +++ b/internal/http3/gzip_reader.go @@ -0,0 +1,39 @@ +package http3 + +// copied from net/transport.go + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +import ( + "compress/gzip" + "io" +) + +// call gzip.NewReader on the first call to Read +type gzipReader struct { + body io.ReadCloser // underlying Response.Body + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // sticky error +} + +func newGzipReader(body io.ReadCloser) io.ReadCloser { + return &gzipReader{body: body} +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zerr != nil { + return 0, gz.zerr + } + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + gz.zerr = err + return 0, err + } + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} diff --git a/internal/http3/http3_suite_test.go b/internal/http3/http3_suite_test.go new file mode 100644 index 00000000..c94d932a --- /dev/null +++ b/internal/http3/http3_suite_test.go @@ -0,0 +1,38 @@ +package http3 + +import ( + "os" + "strconv" + "testing" + "time" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestHttp3(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "HTTP/3 Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) + +//nolint:unparam +func scaleDuration(t time.Duration) time.Duration { + scaleFactor := 1 + if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set + scaleFactor = f + } + Expect(scaleFactor).ToNot(BeZero()) + return time.Duration(scaleFactor) * t +} diff --git a/internal/http3/http_stream.go b/internal/http3/http_stream.go new file mode 100644 index 00000000..4c69068c --- /dev/null +++ b/internal/http3/http_stream.go @@ -0,0 +1,71 @@ +package http3 + +import ( + "bytes" + "fmt" + + "github.com/lucas-clemente/quic-go" +) + +// A Stream is a HTTP/3 stream. +// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. +type Stream quic.Stream + +// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly +// from the QUIC stream, it writes to and reads from the HTTP stream. +type stream struct { + quic.Stream + + onFrameError func() + bytesRemainingInFrame uint64 +} + +var _ Stream = &stream{} + +func newStream(str quic.Stream, onFrameError func()) *stream { + return &stream{Stream: str, onFrameError: onFrameError} +} + +func (s *stream) Read(b []byte) (int, error) { + if s.bytesRemainingInFrame == 0 { + parseLoop: + for { + frame, err := parseNextFrame(s.Stream, nil) + if err != nil { + return 0, err + } + switch f := frame.(type) { + case *headersFrame: + // skip HEADERS frames + continue + case *dataFrame: + s.bytesRemainingInFrame = f.Length + break parseLoop + default: + s.onFrameError() + // parseNextFrame skips over unknown frame types + // Therefore, this condition is only entered when we parsed another known frame type. + return 0, fmt.Errorf("peer sent an unexpected frame: %T", f) + } + } + } + + var n int + var err error + if s.bytesRemainingInFrame < uint64(len(b)) { + n, err = s.Stream.Read(b[:s.bytesRemainingInFrame]) + } else { + n, err = s.Stream.Read(b) + } + s.bytesRemainingInFrame -= uint64(n) + return n, err +} + +func (s *stream) Write(b []byte) (int, error) { + buf := &bytes.Buffer{} + (&dataFrame{Length: uint64(len(b))}).Write(buf) + if _, err := s.Stream.Write(buf.Bytes()); err != nil { + return 0, err + } + return s.Stream.Write(b) +} diff --git a/internal/http3/http_stream_test.go b/internal/http3/http_stream_test.go new file mode 100644 index 00000000..6d3fef02 --- /dev/null +++ b/internal/http3/http_stream_test.go @@ -0,0 +1,150 @@ +package http3 + +import ( + "bytes" + "io" + + mockquic "github.com/imroc/req/v3/internal/mocks/quic" + + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stream", func() { + Context("reading", func() { + var ( + str Stream + qstr *mockquic.MockStream + buf *bytes.Buffer + errorCbCalled bool + ) + + errorCb := func() { errorCbCalled = true } + getDataFrame := func(data []byte) []byte { + b := &bytes.Buffer{} + (&dataFrame{Length: uint64(len(data))}).Write(b) + b.Write(data) + return b.Bytes() + } + + BeforeEach(func() { + buf = &bytes.Buffer{} + errorCbCalled = false + qstr = mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() + qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + str = newStream(qstr, errorCb) + }) + + It("reads DATA frames in a single run", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 6) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b).To(Equal([]byte("foobar"))) + }) + + It("reads DATA frames in multiple runs", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 3) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b).To(Equal([]byte("foo"))) + n, err = str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b).To(Equal([]byte("bar"))) + }) + + It("reads DATA frames into too large buffers", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 10) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b[:n]).To(Equal([]byte("foobar"))) + }) + + It("reads DATA frames into too large buffers, in multiple runs", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 4) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte("foob"))) + n, err = str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(2)) + Expect(b[:n]).To(Equal([]byte("ar"))) + }) + + It("reads multiple DATA frames", func() { + buf.Write(getDataFrame([]byte("foo"))) + buf.Write(getDataFrame([]byte("bar"))) + b := make([]byte, 6) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b[:n]).To(Equal([]byte("foo"))) + n, err = str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b[:n]).To(Equal([]byte("bar"))) + }) + + It("skips HEADERS frames", func() { + buf.Write(getDataFrame([]byte("foo"))) + (&headersFrame{Length: 10}).Write(buf) + buf.Write(make([]byte, 10)) + buf.Write(getDataFrame([]byte("bar"))) + b := make([]byte, 6) + n, err := io.ReadFull(str, b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b).To(Equal([]byte("foobar"))) + }) + + It("errors when it can't parse the frame", func() { + buf.Write([]byte("invalid")) + _, err := str.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + }) + + It("errors on unexpected frames, and calls the error callback", func() { + (&settingsFrame{}).Write(buf) + _, err := str.Read([]byte{0}) + Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame")) + Expect(errorCbCalled).To(BeTrue()) + }) + }) + + Context("writing", func() { + It("writes data frames", func() { + buf := &bytes.Buffer{} + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() + str := newStream(qstr, nil) + str.Write([]byte("foo")) + str.Write([]byte("foobar")) + + f, err := parseNextFrame(buf, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(Equal(&dataFrame{Length: 3})) + b := make([]byte, 3) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte("foo"))) + + f, err = parseNextFrame(buf, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(Equal(&dataFrame{Length: 6})) + b = make([]byte, 6) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte("foobar"))) + }) + }) +}) diff --git a/internal/http3/request.go b/internal/http3/request.go new file mode 100644 index 00000000..0b9a7278 --- /dev/null +++ b/internal/http3/request.go @@ -0,0 +1,113 @@ +package http3 + +import ( + "crypto/tls" + "errors" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/marten-seemann/qpack" +) + +func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { + var path, authority, method, protocol, scheme, contentLengthStr string + + httpHeaders := http.Header{} + for _, h := range headers { + switch h.Name { + case ":path": + path = h.Value + case ":method": + method = h.Value + case ":authority": + authority = h.Value + case ":protocol": + protocol = h.Value + case ":scheme": + scheme = h.Value + case "content-length": + contentLengthStr = h.Value + default: + if !h.IsPseudo() { + httpHeaders.Add(h.Name, h.Value) + } + } + } + + // concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4 + if len(httpHeaders["Cookie"]) > 0 { + httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; ")) + } + + isConnect := method == http.MethodConnect + // Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4 + isExtendedConnected := isConnect && protocol != "" + if isExtendedConnected { + if scheme == "" || path == "" || authority == "" { + return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty") + } + } else if isConnect { + if path != "" || authority == "" { // normal CONNECT + return nil, errors.New(":path must be empty and :authority must not be empty") + } + } else if len(path) == 0 || len(authority) == 0 || len(method) == 0 { + return nil, errors.New(":path, :authority and :method must not be empty") + } + + var u *url.URL + var requestURI string + var err error + + if isConnect { + u = &url.URL{} + if isExtendedConnected { + u, err = url.ParseRequestURI(path) + if err != nil { + return nil, err + } + } else { + u.Path = path + } + u.Scheme = scheme + u.Host = authority + requestURI = authority + } else { + protocol = "HTTP/3.0" + u, err = url.ParseRequestURI(path) + if err != nil { + return nil, err + } + requestURI = path + } + + var contentLength int64 + if len(contentLengthStr) > 0 { + contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) + if err != nil { + return nil, err + } + } + + return &http.Request{ + Method: method, + URL: u, + Proto: protocol, + ProtoMajor: 3, + ProtoMinor: 0, + Header: httpHeaders, + Body: nil, + ContentLength: contentLength, + Host: authority, + RequestURI: requestURI, + TLS: &tls.ConnectionState{}, + }, nil +} + +func hostnameFromRequest(req *http.Request) string { + if req.URL != nil { + return req.URL.Host + } + return "" +} diff --git a/internal/http3/request_test.go b/internal/http3/request_test.go new file mode 100644 index 00000000..f2b84eca --- /dev/null +++ b/internal/http3/request_test.go @@ -0,0 +1,197 @@ +package http3 + +import ( + "net/http" + "net/url" + + "github.com/marten-seemann/qpack" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Request", func() { + It("populates request", func() { + headers := []qpack.HeaderField{ + {Name: ":path", Value: "/foo"}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":method", Value: "GET"}, + {Name: "content-length", Value: "42"}, + } + req, err := requestFromHeaders(headers) + Expect(err).NotTo(HaveOccurred()) + Expect(req.Method).To(Equal("GET")) + Expect(req.URL.Path).To(Equal("/foo")) + Expect(req.URL.Host).To(BeEmpty()) + Expect(req.Proto).To(Equal("HTTP/3.0")) + Expect(req.ProtoMajor).To(Equal(3)) + Expect(req.ProtoMinor).To(BeZero()) + Expect(req.ContentLength).To(Equal(int64(42))) + Expect(req.Header).To(BeEmpty()) + Expect(req.Body).To(BeNil()) + Expect(req.Host).To(Equal("quic.clemente.io")) + Expect(req.RequestURI).To(Equal("/foo")) + Expect(req.TLS).ToNot(BeNil()) + }) + + It("parses path with leading double slashes", func() { + headers := []qpack.HeaderField{ + {Name: ":path", Value: "//foo"}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":method", Value: "GET"}, + } + req, err := requestFromHeaders(headers) + Expect(err).NotTo(HaveOccurred()) + Expect(req.Header).To(BeEmpty()) + Expect(req.Body).To(BeNil()) + Expect(req.URL.Path).To(Equal("//foo")) + Expect(req.URL.Host).To(BeEmpty()) + Expect(req.Host).To(Equal("quic.clemente.io")) + Expect(req.RequestURI).To(Equal("//foo")) + }) + + It("concatenates the cookie headers", func() { + headers := []qpack.HeaderField{ + {Name: ":path", Value: "/foo"}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":method", Value: "GET"}, + {Name: "cookie", Value: "cookie1=foobar1"}, + {Name: "cookie", Value: "cookie2=foobar2"}, + } + req, err := requestFromHeaders(headers) + Expect(err).NotTo(HaveOccurred()) + Expect(req.Header).To(Equal(http.Header{ + "Cookie": []string{"cookie1=foobar1; cookie2=foobar2"}, + })) + }) + + It("handles Other headers", func() { + headers := []qpack.HeaderField{ + {Name: ":path", Value: "/foo"}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":method", Value: "GET"}, + {Name: "cache-control", Value: "max-age=0"}, + {Name: "duplicate-header", Value: "1"}, + {Name: "duplicate-header", Value: "2"}, + } + req, err := requestFromHeaders(headers) + Expect(err).NotTo(HaveOccurred()) + Expect(req.Header).To(Equal(http.Header{ + "Cache-Control": []string{"max-age=0"}, + "Duplicate-Header": []string{"1", "2"}, + })) + }) + + It("errors with missing path", func() { + headers := []qpack.HeaderField{ + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":method", Value: "GET"}, + } + _, err := requestFromHeaders(headers) + Expect(err).To(MatchError(":path, :authority and :method must not be empty")) + }) + + It("errors with missing method", func() { + headers := []qpack.HeaderField{ + {Name: ":path", Value: "/foo"}, + {Name: ":authority", Value: "quic.clemente.io"}, + } + _, err := requestFromHeaders(headers) + Expect(err).To(MatchError(":path, :authority and :method must not be empty")) + }) + + It("errors with missing authority", func() { + headers := []qpack.HeaderField{ + {Name: ":path", Value: "/foo"}, + {Name: ":method", Value: "GET"}, + } + _, err := requestFromHeaders(headers) + Expect(err).To(MatchError(":path, :authority and :method must not be empty")) + }) + + Context("regular HTTP CONNECT", func() { + It("handles CONNECT method", func() { + headers := []qpack.HeaderField{ + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":method", Value: http.MethodConnect}, + } + req, err := requestFromHeaders(headers) + Expect(err).NotTo(HaveOccurred()) + Expect(req.Method).To(Equal(http.MethodConnect)) + Expect(req.RequestURI).To(Equal("quic.clemente.io")) + }) + + It("errors with missing authority in CONNECT method", func() { + headers := []qpack.HeaderField{ + {Name: ":method", Value: http.MethodConnect}, + } + _, err := requestFromHeaders(headers) + Expect(err).To(MatchError(":path must be empty and :authority must not be empty")) + }) + + It("errors with extra path in CONNECT method", func() { + headers := []qpack.HeaderField{ + {Name: ":path", Value: "/foo"}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":method", Value: http.MethodConnect}, + } + _, err := requestFromHeaders(headers) + Expect(err).To(MatchError(":path must be empty and :authority must not be empty")) + }) + }) + + Context("Extended CONNECT", func() { + It("handles Extended CONNECT method", func() { + headers := []qpack.HeaderField{ + {Name: ":protocol", Value: "webtransport"}, + {Name: ":scheme", Value: "ftp"}, + {Name: ":method", Value: http.MethodConnect}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":path", Value: "/foo?val=1337"}, + } + req, err := requestFromHeaders(headers) + Expect(err).NotTo(HaveOccurred()) + Expect(req.Method).To(Equal(http.MethodConnect)) + Expect(req.Proto).To(Equal("webtransport")) + Expect(req.URL.String()).To(Equal("ftp://quic.clemente.io/foo?val=1337")) + Expect(req.URL.Query().Get("val")).To(Equal("1337")) + }) + + It("errors with missing scheme", func() { + headers := []qpack.HeaderField{ + {Name: ":protocol", Value: "webtransport"}, + {Name: ":method", Value: http.MethodConnect}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":path", Value: "/foo"}, + } + _, err := requestFromHeaders(headers) + Expect(err).To(MatchError("extended CONNECT: :scheme, :path and :authority must not be empty")) + }) + }) + + Context("extracting the hostname from a request", func() { + var url *url.URL + + BeforeEach(func() { + var err error + url, err = url.Parse("https://quic.clemente.io:1337") + Expect(err).ToNot(HaveOccurred()) + }) + + It("uses req.URL.Host", func() { + req := &http.Request{URL: url} + Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337")) + }) + + It("uses req.URL.Host even if req.Host is available", func() { + req := &http.Request{ + Host: "www.example.org", + URL: url, + } + Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337")) + }) + + It("returns an empty hostname if nothing is set", func() { + Expect(hostnameFromRequest(&http.Request{})).To(BeEmpty()) + }) + }) +}) diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go new file mode 100644 index 00000000..7cfdf2dd --- /dev/null +++ b/internal/http3/request_writer.go @@ -0,0 +1,291 @@ +package http3 + +import ( + "bytes" + "fmt" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/header" + "io" + "net" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/imroc/req/v3/internal/utils" + "github.com/lucas-clemente/quic-go" + "github.com/marten-seemann/qpack" + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http2/hpack" + "golang.org/x/net/idna" +) + +const bodyCopyBufferSize = 8 * 1024 + +type requestWriter struct { + mutex sync.Mutex + encoder *qpack.Encoder + headerBuf *bytes.Buffer + + logger utils.Logger +} + +func newRequestWriter(logger utils.Logger) *requestWriter { + headerBuf := &bytes.Buffer{} + encoder := qpack.NewEncoder(headerBuf) + return &requestWriter{ + encoder: encoder, + headerBuf: headerBuf, + logger: logger, + } +} + +func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool, dumps []*dump.Dumper) error { + // TODO: figure out how to add support for trailers + buf := &bytes.Buffer{} + if err := w.writeHeaders(buf, req, gzip, dumps); err != nil { + return err + } + _, err := str.Write(buf.Bytes()) + return err +} + +func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool, dumps []*dump.Dumper) error { + w.mutex.Lock() + defer w.mutex.Unlock() + defer w.encoder.Close() + defer w.headerBuf.Reset() + + if err := w.encodeHeaders(req, gzip, "", actualContentLength(req), dumps); err != nil { + return err + } + + buf := &bytes.Buffer{} + hf := headersFrame{Length: uint64(w.headerBuf.Len())} + hf.Write(buf) + if _, err := wr.Write(buf.Bytes()); err != nil { + return err + } + _, err := wr.Write(w.headerBuf.Bytes()) + return err +} + +// copied from net/transport.go +// Modified to support Extended CONNECT: +// Contrary to what the godoc for the http.Request says, +// we do respect the Proto field if the method is CONNECT. +func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64, dumps []*dump.Dumper) error { + host := req.Host + if host == "" { + host = req.URL.Host + } + host, err := httpguts.PunycodeHostPort(host) + if err != nil { + return err + } + + // http.NewRequest sets this field to HTTP/1.1 + isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1" + + var path string + if req.Method != http.MethodConnect || isExtendedConnect { + path = req.URL.RequestURI() + if !validPseudoPath(path) { + orig := path + path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) + if !validPseudoPath(path) { + if req.URL.Opaque != "" { + return fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) + } else { + return fmt.Errorf("invalid request :path %q", orig) + } + } + } + } + + // Check for any invalid headers and return an error before we + // potentially pollute our hpack state. (We want to be able to + // continue to reuse the hpack encoder for future requests) + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return fmt.Errorf("invalid HTTP header name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return fmt.Errorf("invalid HTTP header value %q for header %q", v, k) + } + } + } + + enumerateHeaders := func(f func(name, value string)) { + // 8.1.2.3 Request Pseudo-Header Fields + // The :path pseudo-header field includes the path and query parts of the + // target URI (the path-absolute production and optionally a '?' character + // followed by the query production (see Sections 3.3 and 3.4 of + // [RFC3986]). + f(":authority", host) + f(":method", req.Method) + if req.Method != http.MethodConnect || isExtendedConnect { + f(":path", path) + f(":scheme", req.URL.Scheme) + } + if isExtendedConnect { + f(":protocol", req.Proto) + } + if trailers != "" { + f("trailer", trailers) + } + + var didUA bool + for k, vv := range req.Header { + if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") { + // Host is :authority, already sent. + // Content-Length is automatic, set below. + continue + } else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") || + strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") || + strings.EqualFold(k, "keep-alive") { + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + continue + } else if strings.EqualFold(k, "user-agent") { + // Match Go's http1 behavior: at most one + // User-Agent. If set to nil or empty string, + // then omit it. Otherwise if not mentioned, + // include the default (below). + didUA = true + if len(vv) < 1 { + continue + } + vv = vv[:1] + if vv[0] == "" { + continue + } + + } + + for _, v := range vv { + f(k, v) + } + } + if shouldSendReqContentLength(req.Method, contentLength) { + f("content-length", strconv.FormatInt(contentLength, 10)) + } + if addGzipHeader { + f("accept-encoding", "gzip") + } + if !didUA { + f("user-agent", header.DefaultUserAgent) + } + } + + // Do a first pass over the headers counting bytes to ensure + // we don't exceed cc.peerMaxHeaderListSize. This is done as a + // separate pass before encoding the headers to prevent + // modifying the hpack state. + hlSize := uint64(0) + enumerateHeaders(func(name, value string) { + for _, dump := range dumps { + dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + } + hf := hpack.HeaderField{Name: name, Value: value} + hlSize += uint64(hf.Size()) + }) + + // TODO: check maximum header list size + // if hlSize > cc.peerMaxHeaderListSize { + // return errRequestHeaderListSize + // } + + // trace := httptrace.ContextClientTrace(req.Context()) + // traceHeaders := traceHasWroteHeaderField(trace) + + // Header list size is ok. Write the headers. + enumerateHeaders(func(name, value string) { + name = strings.ToLower(name) + for _, dump := range dumps { + dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + } + w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value}) + // if traceHeaders { + // traceWroteHeaderField(trace, name, value) + // } + }) + + return nil +} + +// authorityAddr returns a given authority (a host/IP, or host:port / ip:port) +// and returns a host:port. The port 443 is added if needed. +func authorityAddr(scheme string, authority string) (addr string) { + host, port, err := net.SplitHostPort(authority) + if err != nil { // authority didn't have a port + port = "443" + if scheme == "http" { + port = "80" + } + host = authority + } + if a, err := idna.ToASCII(host); err == nil { + host = a + } + // IPv6 address literal, without a port: + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port + } + return net.JoinHostPort(host, port) +} + +// validPseudoPath reports whether v is a valid :path pseudo-header +// value. It must be either: +// +// *) a non-empty string starting with '/' +// *) the string '*', for OPTIONS requests. +// +// For now this is only used a quick check for deciding when to clean +// up Opaque URLs before sending requests from the Transport. +// See golang.org/issue/16847 +// +// We used to enforce that the path also didn't start with "//", but +// Google's GFE accepts such paths and Chrome sends them, so ignore +// that part of the spec. See golang.org/issue/19103. +func validPseudoPath(v string) bool { + return (len(v) > 0 && v[0] == '/') || v == "*" +} + +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func actualContentLength(req *http.Request) int64 { + if req.Body == nil { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} + +// shouldSendReqContentLength reports whether the http2.Transport should send +// a "content-length" request header. This logic is basically a copy of the net/http +// transferWriter.shouldSendContentLength. +// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). +// -1 means unknown. +func shouldSendReqContentLength(method string, contentLength int64) bool { + if contentLength > 0 { + return true + } + if contentLength < 0 { + return false + } + // For zero bodies, whether we send a content-length depends on the method. + // It also kinda doesn't matter for http2 either way, with END_STREAM. + switch method { + case "POST", "PUT", "PATCH": + return true + default: + return false + } +} diff --git a/internal/http3/request_writer_test.go b/internal/http3/request_writer_test.go new file mode 100644 index 00000000..a9cbb69f --- /dev/null +++ b/internal/http3/request_writer_test.go @@ -0,0 +1,112 @@ +package http3 + +import ( + "bytes" + "io" + "net/http" + + mockquic "github.com/imroc/req/v3/internal/mocks/quic" + "github.com/imroc/req/v3/internal/utils" + + "github.com/golang/mock/gomock" + "github.com/marten-seemann/qpack" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Request Writer", func() { + var ( + rw *requestWriter + str *mockquic.MockStream + strBuf *bytes.Buffer + ) + + decode := func(str io.Reader) map[string]string { + frame, err := parseNextFrame(str, nil) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) + headersFrame := frame.(*headersFrame) + data := make([]byte, headersFrame.Length) + _, err = io.ReadFull(str, data) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + decoder := qpack.NewDecoder(nil) + hfs, err := decoder.DecodeFull(data) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + values := make(map[string]string) + for _, hf := range hfs { + values[hf.Name] = hf.Value + } + return values + } + + BeforeEach(func() { + rw = newRequestWriter(utils.DefaultLogger) + strBuf = &bytes.Buffer{} + str = mockquic.NewMockStream(mockCtrl) + str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() + }) + + It("writes a GET request", func() { + req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) + headerFields := decode(strBuf) + Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) + Expect(headerFields).To(HaveKeyWithValue(":method", "GET")) + Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar")) + Expect(headerFields).To(HaveKeyWithValue(":scheme", "https")) + Expect(headerFields).ToNot(HaveKey("accept-encoding")) + }) + + It("sends cookies", func() { + req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) + Expect(err).ToNot(HaveOccurred()) + cookie1 := &http.Cookie{ + Name: "Cookie #1", + Value: "Value #1", + } + cookie2 := &http.Cookie{ + Name: "Cookie #2", + Value: "Value #2", + } + req.AddCookie(cookie1) + req.AddCookie(cookie2) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) + headerFields := decode(strBuf) + Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`)) + }) + + It("adds the header for gzip support", func() { + req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(rw.WriteRequestHeader(str, req, true)).To(Succeed()) + headerFields := decode(strBuf) + Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip")) + }) + + It("writes a CONNECT request", func() { + req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) + headerFields := decode(strBuf) + Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) + Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) + Expect(headerFields).ToNot(HaveKey(":path")) + Expect(headerFields).ToNot(HaveKey(":scheme")) + Expect(headerFields).ToNot(HaveKey(":protocol")) + }) + + It("writes an Extended CONNECT request", func() { + req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil) + Expect(err).ToNot(HaveOccurred()) + req.Proto = "webtransport" + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) + headerFields := decode(strBuf) + Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) + Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) + Expect(headerFields).To(HaveKeyWithValue(":path", "/foobar")) + Expect(headerFields).To(HaveKeyWithValue(":scheme", "https")) + Expect(headerFields).To(HaveKeyWithValue(":protocol", "webtransport")) + }) +}) diff --git a/internal/http3/response_writer.go b/internal/http3/response_writer.go new file mode 100644 index 00000000..d1c16dc3 --- /dev/null +++ b/internal/http3/response_writer.go @@ -0,0 +1,118 @@ +package http3 + +import ( + "bufio" + "bytes" + "net/http" + "strconv" + "strings" + + "github.com/imroc/req/v3/internal/utils" + "github.com/lucas-clemente/quic-go" + "github.com/marten-seemann/qpack" +) + +type responseWriter struct { + conn quic.Connection + bufferedStr *bufio.Writer + + header http.Header + status int // status code passed to WriteHeader + headerWritten bool + + logger utils.Logger +} + +var ( + _ http.ResponseWriter = &responseWriter{} + _ http.Flusher = &responseWriter{} + _ Hijacker = &responseWriter{} +) + +func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { + return &responseWriter{ + header: http.Header{}, + conn: conn, + bufferedStr: bufio.NewWriter(str), + logger: logger, + } +} + +func (w *responseWriter) Header() http.Header { + return w.header +} + +func (w *responseWriter) WriteHeader(status int) { + if w.headerWritten { + return + } + + if status < 100 || status >= 200 { + w.headerWritten = true + } + w.status = status + + var headers bytes.Buffer + enc := qpack.NewEncoder(&headers) + enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) + + for k, v := range w.header { + for index := range v { + enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) + } + } + + buf := &bytes.Buffer{} + (&headersFrame{Length: uint64(headers.Len())}).Write(buf) + w.logger.Infof("Responding with %d", status) + if _, err := w.bufferedStr.Write(buf.Bytes()); err != nil { + w.logger.Errorf("could not write headers frame: %s", err.Error()) + } + if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil { + w.logger.Errorf("could not write header frame payload: %s", err.Error()) + } + if !w.headerWritten { + w.Flush() + } +} + +func (w *responseWriter) Write(p []byte) (int, error) { + if !w.headerWritten { + w.WriteHeader(200) + } + if !bodyAllowedForStatus(w.status) { + return 0, http.ErrBodyNotAllowed + } + df := &dataFrame{Length: uint64(len(p))} + buf := &bytes.Buffer{} + df.Write(buf) + if _, err := w.bufferedStr.Write(buf.Bytes()); err != nil { + return 0, err + } + return w.bufferedStr.Write(p) +} + +func (w *responseWriter) Flush() { + if err := w.bufferedStr.Flush(); err != nil { + w.logger.Errorf("could not flush to stream: %s", err.Error()) + } +} + +func (w *responseWriter) StreamCreator() StreamCreator { + return w.conn +} + +// copied from http2/http2.go +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 2616, section 4.4. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} diff --git a/internal/http3/response_writer_test.go b/internal/http3/response_writer_test.go new file mode 100644 index 00000000..abada013 --- /dev/null +++ b/internal/http3/response_writer_test.go @@ -0,0 +1,150 @@ +package http3 + +import ( + "bytes" + "io" + "net/http" + + mockquic "github.com/imroc/req/v3/internal/mocks/quic" + "github.com/imroc/req/v3/internal/utils" + + "github.com/golang/mock/gomock" + "github.com/marten-seemann/qpack" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Response Writer", func() { + var ( + rw *responseWriter + strBuf *bytes.Buffer + ) + + BeforeEach(func() { + strBuf = &bytes.Buffer{} + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() + rw = newResponseWriter(str, nil, utils.DefaultLogger) + }) + + decodeHeader := func(str io.Reader) map[string][]string { + rw.Flush() + fields := make(map[string][]string) + decoder := qpack.NewDecoder(nil) + + frame, err := parseNextFrame(str, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) + headersFrame := frame.(*headersFrame) + data := make([]byte, headersFrame.Length) + _, err = io.ReadFull(str, data) + Expect(err).ToNot(HaveOccurred()) + hfs, err := decoder.DecodeFull(data) + Expect(err).ToNot(HaveOccurred()) + for _, p := range hfs { + fields[p.Name] = append(fields[p.Name], p.Value) + } + return fields + } + + getData := func(str io.Reader) []byte { + frame, err := parseNextFrame(str, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) + df := frame.(*dataFrame) + data := make([]byte, df.Length) + _, err = io.ReadFull(str, data) + Expect(err).ToNot(HaveOccurred()) + return data + } + + It("writes status", func() { + rw.WriteHeader(http.StatusTeapot) + fields := decodeHeader(strBuf) + Expect(fields).To(HaveLen(1)) + Expect(fields).To(HaveKeyWithValue(":status", []string{"418"})) + }) + + It("writes headers", func() { + rw.Header().Add("content-length", "42") + rw.WriteHeader(http.StatusTeapot) + fields := decodeHeader(strBuf) + Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"})) + }) + + It("writes multiple headers with the same name", func() { + const cookie1 = "test1=1; Max-Age=7200; path=/" + const cookie2 = "test2=2; Max-Age=7200; path=/" + rw.Header().Add("set-cookie", cookie1) + rw.Header().Add("set-cookie", cookie2) + rw.WriteHeader(http.StatusTeapot) + fields := decodeHeader(strBuf) + Expect(fields).To(HaveKey("set-cookie")) + cookies := fields["set-cookie"] + Expect(cookies).To(ContainElement(cookie1)) + Expect(cookies).To(ContainElement(cookie2)) + }) + + It("writes data", func() { + n, err := rw.Write([]byte("foobar")) + Expect(n).To(Equal(6)) + Expect(err).ToNot(HaveOccurred()) + // Should have written 200 on the header stream + fields := decodeHeader(strBuf) + Expect(fields).To(HaveKeyWithValue(":status", []string{"200"})) + // And foobar on the data stream + Expect(getData(strBuf)).To(Equal([]byte("foobar"))) + }) + + It("writes data after WriteHeader is called", func() { + rw.WriteHeader(http.StatusTeapot) + n, err := rw.Write([]byte("foobar")) + Expect(n).To(Equal(6)) + Expect(err).ToNot(HaveOccurred()) + // Should have written 418 on the header stream + fields := decodeHeader(strBuf) + Expect(fields).To(HaveKeyWithValue(":status", []string{"418"})) + // And foobar on the data stream + Expect(getData(strBuf)).To(Equal([]byte("foobar"))) + }) + + It("does not WriteHeader() twice", func() { + rw.WriteHeader(200) + rw.WriteHeader(500) + fields := decodeHeader(strBuf) + Expect(fields).To(HaveLen(1)) + Expect(fields).To(HaveKeyWithValue(":status", []string{"200"})) + }) + + It("allows calling WriteHeader() several times when using the 103 status code", func() { + rw.Header().Add("Link", "; rel=preload; as=style") + rw.Header().Add("Link", "; rel=preload; as=script") + rw.WriteHeader(http.StatusEarlyHints) + + n, err := rw.Write([]byte("foobar")) + Expect(n).To(Equal(6)) + Expect(err).ToNot(HaveOccurred()) + + // Early Hints must have been received + fields := decodeHeader(strBuf) + Expect(fields).To(HaveLen(2)) + Expect(fields).To(HaveKeyWithValue(":status", []string{"103"})) + Expect(fields).To(HaveKeyWithValue("link", []string{"; rel=preload; as=style", "; rel=preload; as=script"})) + + // According to the spec, headers sent in the informational response must also be included in the final response + fields = decodeHeader(strBuf) + Expect(fields).To(HaveLen(2)) + Expect(fields).To(HaveKeyWithValue(":status", []string{"200"})) + Expect(fields).To(HaveKeyWithValue("link", []string{"; rel=preload; as=style", "; rel=preload; as=script"})) + + Expect(getData(strBuf)).To(Equal([]byte("foobar"))) + }) + + It("doesn't allow writes if the status code doesn't allow a body", func() { + rw.WriteHeader(304) + n, err := rw.Write([]byte("foobar")) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(http.ErrBodyNotAllowed)) + }) +}) diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go new file mode 100644 index 00000000..d668708a --- /dev/null +++ b/internal/http3/roundtrip.go @@ -0,0 +1,251 @@ +package http3 + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "github.com/imroc/req/v3/internal/transport" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/lucas-clemente/quic-go" + + "golang.org/x/net/http/httpguts" +) + +type roundTripCloser interface { + RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) + io.Closer +} + +// RoundTripper implements the http.RoundTripper interface +type RoundTripper struct { + transport.Interface + mutex sync.Mutex + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool + + // QuicConfig is the quic.Config used for dialing new connections. + // If nil, reasonable default values will be used. + QuicConfig *quic.Config + + // Enable support for HTTP/3 datagrams. + // If set to true, QuicConfig.EnableDatagram will be set. + // See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html. + EnableDatagrams bool + + // Additional HTTP/3 settings. + // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. + AdditionalSettings map[uint64]uint64 + + // When set, this callback is called for the first unknown frame parsed on a bidirectional stream. + // It is called right after parsing the frame type. + // If parsing the frame type fails, the error is passed to the callback. + // In that case, the frame type will not be set. + // Callers can either ignore the frame and return control of the stream back to HTTP/3 + // (by returning hijacked false). + // Alternatively, callers can take over the QUIC stream (by returning hijacked true). + StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) + + // When set, this callback is called for unknown unidirectional stream of unknown stream type. + // If parsing the stream type fails, the error is passed to the callback. + // In that case, the stream type will not be set. + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) + + // Dial specifies an optional dial function for creating QUIC + // connections for requests. + // If Dial is nil, quic.DialAddrEarlyContext will be used. + Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) + + // MaxResponseHeaderBytes specifies a limit on how many response bytes are + // allowed in the server's response header. + // Zero means to use a default limit. + MaxResponseHeaderBytes int64 + + clients map[string]roundTripCloser +} + +// RoundTripOpt are options for the Transport.RoundTripOpt method. +type RoundTripOpt struct { + // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. + // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. + OnlyCachedConn bool + // DontCloseRequestStream controls whether the request stream is closed after sending the request. + DontCloseRequestStream bool +} + +var ( + _ http.RoundTripper = &RoundTripper{} + _ io.Closer = &RoundTripper{} +) + +// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set +var ErrNoCachedConn = errors.New("http3: no cached connection was available") + +// RoundTripOpt is like RoundTrip, but takes options. +func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + if req.URL == nil { + closeRequestBody(req) + return nil, errors.New("http3: nil Request.URL") + } + if req.URL.Host == "" { + closeRequestBody(req) + return nil, errors.New("http3: no Host in request URL") + } + if req.Header == nil { + closeRequestBody(req) + return nil, errors.New("http3: nil Request.Header") + } + + if req.URL.Scheme == "https" { + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("http3: invalid http header field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k) + } + } + } + } else { + closeRequestBody(req) + return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) + } + + if req.Method != "" && !validMethod(req.Method) { + closeRequestBody(req) + return nil, fmt.Errorf("http3: invalid method %q", req.Method) + } + + hostname := authorityAddr("https", hostnameFromRequest(req)) + cl, err := r.getClient(hostname, opt.OnlyCachedConn) + if err != ErrNoCachedConn { + if debugf := r.Debugf(); debugf != nil { + debugf("HTTP/3 %s %s", req.Method, req.URL.String()) + } + } + if err != nil { + return nil, err + } + return cl.RoundTripOpt(req, opt) +} + +// RoundTrip does a round trip. +func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return r.RoundTripOpt(req, RoundTripOpt{}) +} + +func (r *RoundTripper) RoundTripOnlyCachedConn(req *http.Request) (*http.Response, error) { + return r.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) +} + +func (r *RoundTripper) HaveCachedConn(addr string) bool { + _, ok := r.clients[addr] + return ok +} + +func (r *RoundTripper) AddConn(addr string) error { + c, err := r.getClient(addr, false) + if err != nil { + return err + } + client, ok := c.(*client) + if !ok { + return errors.New("bad client type") + } + client.dialOnce.Do(func() { + ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + client.handshakeErr = client.dial(ctx) + }) + return client.handshakeErr +} + +func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripCloser, error) { + r.mutex.Lock() + defer r.mutex.Unlock() + + if r.clients == nil { + r.clients = make(map[string]roundTripCloser) + } + + client, ok := r.clients[hostname] + if !ok { + if onlyCached { + return nil, ErrNoCachedConn + } + var err error + client, err = newClient( + hostname, + r.TLSClientConfig(), + &roundTripperOpts{ + EnableDatagram: r.EnableDatagrams, + DisableCompression: r.DisableCompression, + MaxHeaderBytes: r.MaxResponseHeaderBytes, + StreamHijacker: r.StreamHijacker, + UniStreamHijacker: r.UniStreamHijacker, + dump: r.Interface.Dump(), + }, + r.QuicConfig, + r.Dial, + ) + if err != nil { + return nil, err + } + r.clients[hostname] = client + } + return client, nil +} + +// Close closes the QUIC connections that this RoundTripper has used +func (r *RoundTripper) Close() error { + r.mutex.Lock() + defer r.mutex.Unlock() + for _, client := range r.clients { + if err := client.Close(); err != nil { + return err + } + } + r.clients = nil + return nil +} + +func closeRequestBody(req *http.Request) { + if req.Body != nil { + req.Body.Close() + } +} + +func validMethod(method string) bool { + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1* + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + +// copied from net/http/http.go +func isNotToken(r rune) bool { + return !httpguts.IsTokenRune(r) +} diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go new file mode 100644 index 00000000..e70c45a2 --- /dev/null +++ b/internal/http3/roundtrip_test.go @@ -0,0 +1,252 @@ +package http3 + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "io" + "net/http" + "time" + + "github.com/golang/mock/gomock" + mockquic "github.com/imroc/req/v3/internal/mocks/quic" + "github.com/lucas-clemente/quic-go" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type mockClient struct { + closed bool +} + +func (m *mockClient) RoundTripOpt(req *http.Request, _ RoundTripOpt) (*http.Response, error) { + return &http.Response{Request: req}, nil +} + +func (m *mockClient) Close() error { + m.closed = true + return nil +} + +var _ roundTripCloser = &mockClient{} + +type mockBody struct { + reader bytes.Reader + readErr error + closeErr error + closed bool +} + +// make sure the mockBody can be used as a http.Request.Body +var _ io.ReadCloser = &mockBody{} + +func (m *mockBody) Read(p []byte) (int, error) { + if m.readErr != nil { + return 0, m.readErr + } + return m.reader.Read(p) +} + +func (m *mockBody) SetData(data []byte) { + m.reader = *bytes.NewReader(data) +} + +func (m *mockBody) Close() error { + m.closed = true + return m.closeErr +} + +var _ = Describe("RoundTripper", func() { + var ( + rt *RoundTripper + req1 *http.Request + conn *mockquic.MockEarlyConnection + handshakeCtx context.Context // an already canceled context + ) + + BeforeEach(func() { + rt = &RoundTripper{} + var err error + req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) + Expect(err).ToNot(HaveOccurred()) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + handshakeCtx = ctx + }) + + Context("dialing hosts", func() { + origDialAddr := dialAddr + + BeforeEach(func() { + conn = mockquic.NewMockEarlyConnection(mockCtrl) + origDialAddr = dialAddr + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + // return an error when trying to open a stream + // we don't want to test all the dial logic here, just that dialing happens at all + return conn, nil + } + }) + + AfterEach(func() { + dialAddr = origDialAddr + }) + + It("creates new clients", func() { + closed := make(chan struct{}) + testErr := errors.New("test err") + req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-closed + return nil, errors.New("test done") + }).MaxTimes(1) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) }) + _, err = rt.RoundTrip(req) + Expect(err).To(MatchError(testErr)) + Expect(rt.clients).To(HaveLen(1)) + Eventually(closed).Should(BeClosed()) + }) + + It("uses the quic.Config, if provided", func() { + config := &quic.Config{HandshakeIdleTimeout: time.Millisecond} + var receivedConfig *quic.Config + dialAddr = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { + receivedConfig = config + return nil, errors.New("handshake error") + } + rt.QuicConfig = config + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("handshake error")) + Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout)) + }) + + It("uses the custom dialer, if provided", func() { + var dialed bool + dialer := func(_ context.Context, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + dialed = true + return nil, errors.New("handshake error") + } + rt.Dial = dialer + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("handshake error")) + Expect(dialed).To(BeTrue()) + }) + + It("reuses existing clients", func() { + closed := make(chan struct{}) + testErr := errors.New("test err") + conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2) + conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-closed + return nil, errors.New("test done") + }).MaxTimes(1) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) }) + req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTrip(req) + Expect(err).To(MatchError(testErr)) + Expect(rt.clients).To(HaveLen(1)) + req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTrip(req2) + Expect(err).To(MatchError(testErr)) + Expect(rt.clients).To(HaveLen(1)) + Eventually(closed).Should(BeClosed()) + }) + + It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() { + req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) + Expect(err).To(MatchError(ErrNoCachedConn)) + }) + }) + + Context("validating request", func() { + It("rejects plain HTTP requests", func() { + req, err := http.NewRequest("GET", "http://www.example.org/", nil) + req.Body = &mockBody{} + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTrip(req) + Expect(err).To(MatchError("http3: unsupported protocol scheme: http")) + Expect(req.Body.(*mockBody).closed).To(BeTrue()) + }) + + It("rejects requests without a URL", func() { + req1.URL = nil + req1.Body = &mockBody{} + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("http3: nil Request.URL")) + Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + }) + + It("rejects request without a URL Host", func() { + req1.URL.Host = "" + req1.Body = &mockBody{} + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("http3: no Host in request URL")) + Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + }) + + It("doesn't try to close the body if the request doesn't have one", func() { + req1.URL = nil + Expect(req1.Body).To(BeNil()) + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("http3: nil Request.URL")) + }) + + It("rejects requests without a header", func() { + req1.Header = nil + req1.Body = &mockBody{} + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("http3: nil Request.Header")) + Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + }) + + It("rejects requests with invalid header name fields", func() { + req1.Header.Add("foobär", "value") + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("http3: invalid http header field name \"foobär\"")) + }) + + It("rejects requests with invalid header name values", func() { + req1.Header.Add("foo", string([]byte{0x7})) + _, err := rt.RoundTrip(req1) + Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value")) + }) + + It("rejects requests with an invalid request method", func() { + req1.Method = "foobär" + req1.Body = &mockBody{} + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("http3: invalid method \"foobär\"")) + Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + }) + }) + + Context("closing", func() { + It("closes", func() { + rt.clients = make(map[string]roundTripCloser) + cl := &mockClient{} + rt.clients["foo.bar"] = cl + err := rt.Close() + Expect(err).ToNot(HaveOccurred()) + Expect(len(rt.clients)).To(BeZero()) + Expect(cl.closed).To(BeTrue()) + }) + + It("closes a RoundTripper that has never been used", func() { + Expect(len(rt.clients)).To(BeZero()) + err := rt.Close() + Expect(err).ToNot(HaveOccurred()) + Expect(len(rt.clients)).To(BeZero()) + }) + }) +}) diff --git a/internal/http3/server.go b/internal/http3/server.go new file mode 100644 index 00000000..4a338f00 --- /dev/null +++ b/internal/http3/server.go @@ -0,0 +1,736 @@ +package http3 + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "runtime" + "strings" + "sync" + "time" + + "github.com/imroc/req/v3/internal/handshake" + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + "github.com/imroc/req/v3/internal/utils" + "github.com/lucas-clemente/quic-go" + "github.com/marten-seemann/qpack" +) + +// allows mocking of quic.Listen and quic.ListenAddr +var ( + quicListen = quic.ListenEarly + quicListenAddr = quic.ListenAddrEarly +) + +const ( + nextProtoH3Draft29 = "h3-29" + nextProtoH3 = "h3" +) + +// StreamType is the stream type of a unidirectional stream. +type StreamType uint64 + +const ( + streamTypeControlStream = 0 + streamTypePushStream = 1 + streamTypeQPACKEncoderStream = 2 + streamTypeQPACKDecoderStream = 3 +) + +func versionToALPN(v quic.VersionNumber) string { + if v == protocol.Version1 || v == protocol.Version2 { + return nextProtoH3 + } + if v == protocol.VersionTLS || v == protocol.VersionDraft29 { + return nextProtoH3Draft29 + } + return "" +} + +// ConfigureTLSConfig creates a new tls.Config which can be used +// to create a quic.Listener meant for serving http3. The created +// tls.Config adds the functionality of detecting the used QUIC version +// in order to set the correct ALPN value for the http3 connection. +func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config { + // The tls.Config used to setup the quic.Listener needs to have the GetConfigForClient callback set. + // That way, we can get the QUIC version and set the correct ALPN value. + return &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + // determine the ALPN from the QUIC version used + proto := nextProtoH3 + if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok { + proto = versionToALPN(qconn.GetQUICVersion()) + } + config := tlsConf + if tlsConf.GetConfigForClient != nil { + getConfigForClient := tlsConf.GetConfigForClient + var err error + conf, err := getConfigForClient(ch) + if err != nil { + return nil, err + } + if conf != nil { + config = conf + } + } + if config == nil { + return nil, nil + } + config = config.Clone() + config.NextProtos = []string{proto} + return config, nil + }, + } +} + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { return "quic-go/http3 context value " + k.name } + +// ServerContextKey is a context key. It can be used in HTTP +// handlers with Context.Value to access the server that +// started the handler. The associated value will be of +// type *http3.Server. +var ServerContextKey = &contextKey{"http3-server"} + +type requestError struct { + err error + streamErr errorCode + connErr errorCode +} + +func newStreamError(code errorCode, err error) requestError { + return requestError{err: err, streamErr: code} +} + +func newConnError(code errorCode, err error) requestError { + return requestError{err: err, connErr: code} +} + +// listenerInfo contains info about specific listener added with addListener +type listenerInfo struct { + port int // 0 means that no info about port is available +} + +// Server is a HTTP/3 server. +type Server struct { + // Addr optionally specifies the UDP address for the server to listen on, + // in the form "host:port". + // + // When used by ListenAndServe and ListenAndServeTLS methods, if empty, + // ":https" (port 443) is used. See net.Dial for details of the address + // format. + // + // Otherwise, if Port is not set and underlying QUIC listeners do not + // have valid port numbers, the port part is used in Alt-Svc headers set + // with SetQuicHeaders. + Addr string + + // Port is used in Alt-Svc response headers set with SetQuicHeaders. If + // needed Port can be manually set when the Server is created. + // + // This is useful when a Layer 4 firewall is redirecting UDP traffic and + // clients must use a port different from the port the Server is + // listening on. + Port int + + // TLSConfig provides a TLS configuration for use by server. It must be + // set for ListenAndServe and Serve methods. + TLSConfig *tls.Config + + // QuicConfig provides the parameters for QUIC connection created with + // Serve. If nil, it uses reasonable default values. + // + // Configured versions are also used in Alt-Svc response header set with + // SetQuicHeaders. + QuicConfig *quic.Config + + // Handler is the HTTP request handler to use. If not set, defaults to + // http.NotFound. + Handler http.Handler + + // EnableDatagrams enables support for HTTP/3 datagrams. + // If set to true, QuicConfig.EnableDatagram will be set. + // See https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-07. + EnableDatagrams bool + + // MaxHeaderBytes controls the maximum number of bytes the server will + // read parsing the request HEADERS frame. It does not limit the size of + // the request body. If zero or negative, http.DefaultMaxHeaderBytes is + // used. + MaxHeaderBytes int + + // AdditionalSettings specifies additional HTTP/3 settings. + // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. + AdditionalSettings map[uint64]uint64 + + // StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream. + // It is called right after parsing the frame type. + // If parsing the frame type fails, the error is passed to the callback. + // In that case, the frame type will not be set. + // Callers can either ignore the frame and return control of the stream back to HTTP/3 + // (by returning hijacked false). + // Alternatively, callers can take over the QUIC stream (by returning hijacked true). + StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) + + // UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type. + // If parsing the stream type fails, the error is passed to the callback. + // In that case, the stream type will not be set. + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) + + mutex sync.RWMutex + listeners map[*quic.EarlyListener]listenerInfo + + closed bool + + altSvcHeader string + + logger utils.Logger +} + +// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. +// +// If s.Addr is blank, ":https" is used. +func (s *Server) ListenAndServe() error { + return s.serveConn(s.TLSConfig, nil) +} + +// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. +// +// If s.Addr is blank, ":https" is used. +func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { + var err error + certs := make([]tls.Certificate, 1) + certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + // We currently only use the cert-related stuff from tls.Config, + // so we don't need to make a full copy. + config := &tls.Config{ + Certificates: certs, + } + return s.serveConn(config, nil) +} + +// Serve an existing UDP connection. +// It is possible to reuse the same connection for outgoing connections. +// Closing the server does not close the packet conn. +func (s *Server) Serve(conn net.PacketConn) error { + return s.serveConn(s.TLSConfig, conn) +} + +// ServeListener serves an existing QUIC listener. +// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config +// and use it to construct a http3-friendly QUIC listener. +// Closing the server does close the listener. +func (s *Server) ServeListener(ln quic.EarlyListener) error { + if err := s.addListener(&ln); err != nil { + return err + } + err := s.serveListener(ln) + s.removeListener(&ln) + return err +} + +var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig") + +func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { + if tlsConf == nil { + return errServerWithoutTLSConfig + } + + s.mutex.Lock() + closed := s.closed + s.mutex.Unlock() + if closed { + return http.ErrServerClosed + } + + baseConf := ConfigureTLSConfig(tlsConf) + quicConf := s.QuicConfig + if quicConf == nil { + quicConf = &quic.Config{} + } else { + quicConf = s.QuicConfig.Clone() + } + if s.EnableDatagrams { + quicConf.EnableDatagrams = true + } + + var ln quic.EarlyListener + var err error + if conn == nil { + addr := s.Addr + if addr == "" { + addr = ":https" + } + ln, err = quicListenAddr(addr, baseConf, quicConf) + } else { + ln, err = quicListen(conn, baseConf, quicConf) + } + if err != nil { + return err + } + if err := s.addListener(&ln); err != nil { + return err + } + err = s.serveListener(ln) + s.removeListener(&ln) + return err +} + +func (s *Server) serveListener(ln quic.EarlyListener) error { + for { + conn, err := ln.Accept(context.Background()) + if err != nil { + return err + } + go s.handleConn(conn) + } +} + +func extractPort(addr string) (int, error) { + _, portStr, err := net.SplitHostPort(addr) + if err != nil { + return 0, err + } + + portInt, err := net.LookupPort("tcp", portStr) + if err != nil { + return 0, err + } + return portInt, nil +} + +func (s *Server) generateAltSvcHeader() { + if len(s.listeners) == 0 { + // Don't announce any ports since no one is listening for connections + s.altSvcHeader = "" + return + } + + // This code assumes that we will use protocol.SupportedVersions if no quic.Config is passed. + supportedVersions := protocol.SupportedVersions + if s.QuicConfig != nil && len(s.QuicConfig.Versions) > 0 { + supportedVersions = s.QuicConfig.Versions + } + var versionStrings []string + for _, version := range supportedVersions { + if v := versionToALPN(version); len(v) > 0 { + versionStrings = append(versionStrings, v) + } + } + + var altSvc []string + addPort := func(port int) { + for _, v := range versionStrings { + altSvc = append(altSvc, fmt.Sprintf(`%s=":%d"; ma=2592000`, v, port)) + } + } + + if s.Port != 0 { + // if Port is specified, we must use it instead of the + // listener addresses since there's a reason it's specified. + addPort(s.Port) + } else { + // if we have some listeners assigned, try to find ports + // which we can announce, otherwise nothing should be announced + validPortsFound := false + for _, info := range s.listeners { + if info.port != 0 { + addPort(info.port) + validPortsFound = true + } + } + if !validPortsFound { + if port, err := extractPort(s.Addr); err == nil { + addPort(port) + } + } + } + + s.altSvcHeader = strings.Join(altSvc, ",") +} + +// We store a pointer to interface in the map set. This is safe because we only +// call trackListener via Serve and can track+defer untrack the same pointer to +// local variable there. We never need to compare a Listener from another caller. +func (s *Server) addListener(l *quic.EarlyListener) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.closed { + return http.ErrServerClosed + } + if s.logger == nil { + s.logger = utils.DefaultLogger.WithPrefix("server") + } + if s.listeners == nil { + s.listeners = make(map[*quic.EarlyListener]listenerInfo) + } + + if port, err := extractPort((*l).Addr().String()); err == nil { + s.listeners[l] = listenerInfo{port} + } else { + s.logger.Errorf( + "Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err) + s.listeners[l] = listenerInfo{} + } + s.generateAltSvcHeader() + return nil +} + +func (s *Server) removeListener(l *quic.EarlyListener) { + s.mutex.Lock() + delete(s.listeners, l) + s.generateAltSvcHeader() + s.mutex.Unlock() +} + +func (s *Server) handleConn(conn quic.EarlyConnection) { + decoder := qpack.NewDecoder(nil) + + // send a SETTINGS frame + str, err := conn.OpenUniStream() + if err != nil { + s.logger.Debugf("Opening the control stream failed.") + return + } + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) // stream type + (&settingsFrame{Datagram: s.EnableDatagrams, Other: s.AdditionalSettings}).Write(buf) + str.Write(buf.Bytes()) + + go s.handleUnidirectionalStreams(conn) + + // Process all requests immediately. + // It's the client's responsibility to decide which requests are eligible for 0-RTT. + for { + str, err := conn.AcceptStream(context.Background()) + if err != nil { + s.logger.Debugf("Accepting stream failed: %s", err) + return + } + go func() { + rerr := s.handleRequest(conn, str, decoder, func() { + conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") + }) + if rerr.err == errHijacked { + return + } + if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 { + s.logger.Debugf("Handling request failed: %s", err) + if rerr.streamErr != 0 { + str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) + } + if rerr.connErr != 0 { + var reason string + if rerr.err != nil { + reason = rerr.err.Error() + } + conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) + } + return + } + str.Close() + }() + } +} + +func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) { + for { + str, err := conn.AcceptUniStream(context.Background()) + if err != nil { + s.logger.Debugf("accepting unidirectional stream failed: %s", err) + return + } + + go func(str quic.ReceiveStream) { + streamType, err := quicvarint.Read(quicvarint.NewReader(str)) + if err != nil { + if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, err) { + return + } + s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) + return + } + // We're only interested in the control stream here. + switch streamType { + case streamTypeControlStream: + case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream: + // Our QPACK implementation doesn't use the dynamic table yet. + // TODO: check that only one stream of each type is opened. + return + case streamTypePushStream: // only the server can push + conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "") + return + default: + if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, nil) { + return + } + str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) + return + } + f, err := parseNextFrame(str, nil) + if err != nil { + conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") + return + } + sf, ok := f.(*settingsFrame) + if !ok { + conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") + return + } + if !sf.Datagram { + return + } + // If datagram support was enabled on our side as well as on the client side, + // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. + // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). + if s.EnableDatagrams && !conn.ConnectionState().SupportsDatagrams { + conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") + } + }(str) + } +} + +func (s *Server) maxHeaderBytes() uint64 { + if s.MaxHeaderBytes <= 0 { + return http.DefaultMaxHeaderBytes + } + return uint64(s.MaxHeaderBytes) +} + +func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { + var ufh unknownFrameHandlerFunc + if s.StreamHijacker != nil { + ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) } + } + frame, err := parseNextFrame(str, ufh) + if err != nil { + if err == errHijacked { + return requestError{err: errHijacked} + } + return newStreamError(errorRequestIncomplete, err) + } + hf, ok := frame.(*headersFrame) + if !ok { + return newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) + } + if hf.Length > s.maxHeaderBytes() { + return newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes())) + } + headerBlock := make([]byte, hf.Length) + if _, err := io.ReadFull(str, headerBlock); err != nil { + return newStreamError(errorRequestIncomplete, err) + } + hfs, err := decoder.DecodeFull(headerBlock) + if err != nil { + // TODO: use the right error code + return newConnError(errorGeneralProtocolError, err) + } + req, err := requestFromHeaders(hfs) + if err != nil { + // TODO: use the right error code + return newStreamError(errorGeneralProtocolError, err) + } + + req.RemoteAddr = conn.RemoteAddr().String() + body := newRequestBody(newStream(str, onFrameError)) + req.Body = body + + if s.logger.Debug() { + s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID()) + } else { + s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) + } + + ctx := str.Context() + ctx = context.WithValue(ctx, ServerContextKey, s) + ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr()) + req = req.WithContext(ctx) + r := newResponseWriter(str, conn, s.logger) + defer r.Flush() + handler := s.Handler + if handler == nil { + handler = http.DefaultServeMux + } + + var panicked bool + func() { + defer func() { + if p := recover(); p != nil { + // Copied from net/http/server.go + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + s.logger.Errorf("http: panic serving: %v\n%s", p, buf) + panicked = true + } + }() + handler.ServeHTTP(r, req) + }() + + if body.wasStreamHijacked() { + return requestError{err: errHijacked} + } + + if panicked { + r.WriteHeader(500) + } else { + r.WriteHeader(200) + } + // If the EOF was read by the handler, CancelRead() is a no-op. + str.CancelRead(quic.StreamErrorCode(errorNoError)) + return requestError{} +} + +// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients. +// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. +func (s *Server) Close() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.closed = true + + var err error + for ln := range s.listeners { + if cerr := (*ln).Close(); cerr != nil && err == nil { + err = cerr + } + } + return err +} + +// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete. +// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. +func (s *Server) CloseGracefully(timeout time.Duration) error { + // TODO: implement + return nil +} + +// ErrNoAltSvcPort is the error returned by SetQuicHeaders when no port was found +// for Alt-Svc to announce. This can happen if listening on a PacketConn without a port +// (UNIX socket, for example) and no port is specified in Server.Port or Server.Addr. +var ErrNoAltSvcPort = errors.New("no port can be announced, specify it explicitly using Server.Port or Server.Addr") + +// SetQuicHeaders can be used to set the proper headers that announce that this server supports HTTP/3. +// The values set by default advertise all of the ports the server is listening on, but can be +// changed to a specific port by setting Server.Port before launching the serverr. +// If no listener's Addr().String() returns an address with a valid port, Server.Addr will be used +// to extract the port, if specified. +// For example, a server launched using ListenAndServe on an address with port 443 would set: +// Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +func (s *Server) SetQuicHeaders(hdr http.Header) error { + s.mutex.RLock() + defer s.mutex.RUnlock() + + if s.altSvcHeader == "" { + return ErrNoAltSvcPort + } + // use the map directly to avoid constant canonicalization + // since the key is already canonicalized + hdr["Alt-Svc"] = append(hdr["Alt-Svc"], s.altSvcHeader) + return nil +} + +// ListenAndServeQUIC listens on the UDP network address addr and calls the +// handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is +// used when handler is nil. +func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error { + server := &Server{ + Addr: addr, + Handler: handler, + } + return server.ListenAndServeTLS(certFile, keyFile) +} + +// ListenAndServe listens on the given network address for both, TLS and QUIC +// connections in parallel. It returns if one of the two returns an error. +// http.DefaultServeMux is used when handler is nil. +// The correct Alt-Svc headers for QUIC are set. +func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error { + // Load certs + var err error + certs := make([]tls.Certificate, 1) + certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + // We currently only use the cert-related stuff from tls.Config, + // so we don't need to make a full copy. + config := &tls.Config{ + Certificates: certs, + } + + if addr == "" { + addr = ":https" + } + + // Open the listeners + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return err + } + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return err + } + defer udpConn.Close() + + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return err + } + tcpConn, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return err + } + defer tcpConn.Close() + + tlsConn := tls.NewListener(tcpConn, config) + defer tlsConn.Close() + + // Start the servers + httpServer := &http.Server{} + quicServer := &Server{ + TLSConfig: config, + } + + if handler == nil { + handler = http.DefaultServeMux + } + httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + quicServer.SetQuicHeaders(w.Header()) + handler.ServeHTTP(w, r) + }) + + hErr := make(chan error) + qErr := make(chan error) + go func() { + hErr <- httpServer.Serve(tlsConn) + }() + go func() { + qErr <- quicServer.Serve(udpConn) + }() + + select { + case err := <-hErr: + quicServer.Close() + return err + case err := <-qErr: + // Cannot close the HTTP server or wait for requests to complete properly :/ + return err + } +} diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go new file mode 100644 index 00000000..3b4e68f8 --- /dev/null +++ b/internal/http3/server_test.go @@ -0,0 +1,1289 @@ +package http3 + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "runtime" + "sync/atomic" + "time" + + mockquic "github.com/imroc/req/v3/internal/mocks/quic" + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/testdata" + "github.com/imroc/req/v3/internal/utils" + "github.com/lucas-clemente/quic-go" + "github.com/imroc/req/v3/internal/quicvarint" + + "github.com/golang/mock/gomock" + "github.com/marten-seemann/qpack" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + gmtypes "github.com/onsi/gomega/types" +) + +type mockConn struct { + net.Conn + version quic.VersionNumber +} + +func newMockConn(version quic.VersionNumber) net.Conn { + return &mockConn{version: version} +} + +func (c *mockConn) GetQUICVersion() quic.VersionNumber { + return c.version +} + +type mockAddr struct { + addr string +} + +func (ma *mockAddr) Network() string { + return "udp" +} + +func (ma *mockAddr) String() string { + return ma.addr +} + +type mockAddrListener struct { + *mockquic.MockEarlyListener + addr *mockAddr +} + +func (m *mockAddrListener) Addr() net.Addr { + _ = m.MockEarlyListener.Addr() + return m.addr +} + +func newMockAddrListener(addr string) *mockAddrListener { + return &mockAddrListener{ + MockEarlyListener: mockquic.NewMockEarlyListener(mockCtrl), + addr: &mockAddr{ + addr: addr, + }, + } +} + +type noPortListener struct { + *mockAddrListener +} + +func (m *noPortListener) Addr() net.Addr { + _ = m.mockAddrListener.Addr() + return &net.UnixAddr{ + Net: "unix", + Name: "/tmp/quic.sock", + } +} + +var _ = Describe("Server", func() { + var ( + s *Server + origQuicListenAddr = quicListenAddr + ) + + BeforeEach(func() { + s = &Server{ + TLSConfig: testdata.GetTLSConfig(), + logger: utils.DefaultLogger, + } + origQuicListenAddr = quicListenAddr + }) + + AfterEach(func() { + quicListenAddr = origQuicListenAddr + }) + + Context("handling requests", func() { + var ( + qpackDecoder *qpack.Decoder + str *mockquic.MockStream + conn *mockquic.MockEarlyConnection + exampleGetRequest *http.Request + examplePostRequest *http.Request + ) + reqContext := context.Background() + + decodeHeader := func(str io.Reader) map[string][]string { + fields := make(map[string][]string) + decoder := qpack.NewDecoder(nil) + + frame, err := parseNextFrame(str, nil) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) + headersFrame := frame.(*headersFrame) + data := make([]byte, headersFrame.Length) + _, err = io.ReadFull(str, data) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + hfs, err := decoder.DecodeFull(data) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + for _, p := range hfs { + fields[p.Name] = append(fields[p.Name], p.Value) + } + return fields + } + + encodeRequest := func(req *http.Request) []byte { + buf := &bytes.Buffer{} + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() + rw := newRequestWriter(utils.DefaultLogger) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) + return buf.Bytes() + } + + setRequest := func(data []byte) { + buf := bytes.NewBuffer(data) + str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + if buf.Len() == 0 { + return 0, io.EOF + } + return buf.Read(p) + }).AnyTimes() + } + + BeforeEach(func() { + var err error + exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil) + Expect(err).ToNot(HaveOccurred()) + examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar"))) + Expect(err).ToNot(HaveOccurred()) + + qpackDecoder = qpack.NewDecoder(nil) + str = mockquic.NewMockStream(mockCtrl) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() + conn.EXPECT().LocalAddr().AnyTimes() + }) + + It("calls the HTTP handler function", func() { + requestChan := make(chan *http.Request, 1) + s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + requestChan <- r + }) + + setRequest(encodeRequest(exampleGetRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return len(p), nil + }).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + + Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{})) + var req *http.Request + Eventually(requestChan).Should(Receive(&req)) + Expect(req.Host).To(Equal("www.example.com")) + Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) + Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) + }) + + It("returns 200 with an empty handler", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + responseBuf := &bytes.Buffer{} + setRequest(encodeRequest(exampleGetRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + }) + + It("handles a panicking handler", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("foobar") + }) + + responseBuf := &bytes.Buffer{} + setRequest(encodeRequest(exampleGetRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"})) + }) + + Context("hijacking bidirectional streams", func() { + var conn *mockquic.MockEarlyConnection + testDone := make(chan struct{}) + + BeforeEach(func() { + testDone = make(chan struct{}) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() + conn.EXPECT().LocalAddr().AnyTimes() + }) + + AfterEach(func() { testDone <- struct{}{} }) + + It("hijacks a bidirectional stream of unknown frame type", func() { + frameTypeChan := make(chan FrameType, 1) + s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + frameTypeChan <- ft + return true, nil + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("cancels writing when hijacker didn't hijack a bidirectional stream", func() { + frameTypeChan := make(chan FrameType, 1) + s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + frameTypeChan <- ft + return false, nil + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)) + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("cancels writing when hijacker returned error", func() { + frameTypeChan := make(chan FrameType, 1) + s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + frameTypeChan <- ft + return false, errors.New("error in hijacker") + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)) + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("handles errors that occur when reading the stream type", func() { + testErr := errors.New("test error") + done := make(chan struct{}) + unknownStr := mockquic.NewMockStream(mockCtrl) + s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) { + defer close(done) + Expect(ft).To(BeZero()) + Expect(str).To(Equal(unknownStr)) + Expect(err).To(MatchError(testErr)) + return true, nil + } + + unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + }) + + Context("hijacking unidirectional streams", func() { + var conn *mockquic.MockEarlyConnection + testDone := make(chan struct{}) + + BeforeEach(func() { + testDone = make(chan struct{}) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() + conn.EXPECT().LocalAddr().AnyTimes() + }) + + AfterEach(func() { testDone <- struct{}{} }) + + It("hijacks an unidirectional stream of unknown stream type", func() { + streamTypeChan := make(chan StreamType, 1) + s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { + Expect(err).ToNot(HaveOccurred()) + streamTypeChan <- st + return true + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("handles errors that occur when reading the stream type", func() { + testErr := errors.New("test error") + done := make(chan struct{}) + unknownStr := mockquic.NewMockStream(mockCtrl) + s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { + defer close(done) + Expect(st).To(BeZero()) + Expect(str).To(Equal(unknownStr)) + Expect(err).To(MatchError(testErr)) + return true + } + + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) + conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { + streamTypeChan := make(chan StreamType, 1) + s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { + Expect(err).ToNot(HaveOccurred()) + streamTypeChan <- st + return false + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)) + + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + }) + + Context("control stream handling", func() { + var conn *mockquic.MockEarlyConnection + testDone := make(chan struct{}) + + BeforeEach(func() { + conn = mockquic.NewMockEarlyConnection(mockCtrl) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() + conn.EXPECT().LocalAddr().AnyTimes() + }) + + AfterEach(func() { testDone <- struct{}{} }) + + It("parses the SETTINGS frame", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) + (&settingsFrame{}).Write(buf) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { + streamType := t + name := "encoder" + if streamType == streamTypeQPACKDecoderStream { + name = "decoder" + } + + It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamType) + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return str, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead + }) + } + + It("reset streams other than the control stream and the QPACK streams", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, 1337) + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + done := make(chan struct{}) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)).Do(func(code quic.StreamErrorCode) { + close(done) + }) + + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return str, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + }) + + It("errors when the first frame on the control stream is not a SETTINGS frame", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) + (&dataFrame{}).Write(buf) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + done := make(chan struct{}) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + defer GinkgoRecover() + Expect(code).To(BeEquivalentTo(errorMissingSettings)) + close(done) + }) + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + }) + + It("errors when parsing the frame on the control stream fails", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) + b := &bytes.Buffer{} + (&settingsFrame{}).Write(b) + buf.Write(b.Bytes()[:b.Len()-1]) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + done := make(chan struct{}) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + defer GinkgoRecover() + Expect(code).To(BeEquivalentTo(errorFrameError)) + close(done) + }) + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + }) + + It("errors when the client opens a push stream", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypePushStream) + (&dataFrame{}).Write(buf) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + done := make(chan struct{}) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + defer GinkgoRecover() + Expect(code).To(BeEquivalentTo(errorStreamCreationError)) + close(done) + }) + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + }) + + It("errors when the client advertises datagram support (and we enabled support for it)", func() { + s.EnableDatagrams = true + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) + (&settingsFrame{Datagram: true}).Write(buf) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) + done := make(chan struct{}) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { + defer GinkgoRecover() + Expect(code).To(BeEquivalentTo(errorSettingsError)) + Expect(reason).To(Equal("missing QUIC Datagram support")) + close(done) + }) + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + }) + }) + + Context("stream- and connection-level errors", func() { + var conn *mockquic.MockEarlyConnection + testDone := make(chan struct{}) + + BeforeEach(func() { + testDone = make(chan struct{}) + addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + conn = mockquic.NewMockEarlyConnection(mockCtrl) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() + conn.EXPECT().LocalAddr().AnyTimes() + }) + + AfterEach(func() { testDone <- struct{}{} }) + + It("cancels reading when client sends a body in GET request", func() { + var handlerCalled bool + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCalled = true + }) + + requestData := encodeRequest(exampleGetRequest) + buf := &bytes.Buffer{} + (&dataFrame{Length: 6}).Write(buf) // add a body + buf.Write([]byte("foobar")) + responseBuf := &bytes.Buffer{} + setRequest(append(requestData, buf.Bytes()...)) + done := make(chan struct{}) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) + str.EXPECT().Close().Do(func() { close(done) }) + + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + Expect(handlerCalled).To(BeTrue()) + }) + + It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() { + handlerCalled := make(chan struct{}) + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(handlerCalled) + r.Body.(HTTPStreamer).HTTPStream() + str.Write([]byte("foobar")) + }) + + requestData := encodeRequest(exampleGetRequest) + buf := &bytes.Buffer{} + (&dataFrame{Length: 6}).Write(buf) // add a body + buf.Write([]byte("foobar")) + setRequest(append(requestData, buf.Bytes()...)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write([]byte("foobar")).Return(6, nil) + + s.handleConn(conn) + Eventually(handlerCalled).Should(BeClosed()) + }) + + It("errors when the client sends a too large header frame", func() { + s.MaxHeaderBytes = 20 + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Fail("Handler should not be called.") + }) + + requestData := encodeRequest(exampleGetRequest) + buf := &bytes.Buffer{} + (&dataFrame{Length: 6}).Write(buf) // add a body + buf.Write([]byte("foobar")) + responseBuf := &bytes.Buffer{} + setRequest(append(requestData, buf.Bytes()...)) + done := make(chan struct{}) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) + + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + }) + + It("handles a request for which the client immediately resets the stream", func() { + handlerCalled := make(chan struct{}) + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(handlerCalled) + }) + + testErr := errors.New("stream reset") + done := make(chan struct{}) + str.EXPECT().Read(gomock.Any()).Return(0, testErr) + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) + + s.handleConn(conn) + Consistently(handlerCalled).ShouldNot(BeClosed()) + }) + + It("closes the connection when the first frame is not a HEADERS frame", func() { + handlerCalled := make(chan struct{}) + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(handlerCalled) + }) + + buf := &bytes.Buffer{} + (&dataFrame{}).Write(buf) + setRequest(buf.Bytes()) + str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return len(p), nil + }).AnyTimes() + + done := make(chan struct{}) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + Expect(code).To(Equal(quic.ApplicationErrorCode(errorFrameUnexpected))) + close(done) + }) + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + }) + + It("closes the connection when the first frame is not a HEADERS frame", func() { + handlerCalled := make(chan struct{}) + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(handlerCalled) + }) + + // use 2*DefaultMaxHeaderBytes here. qpack will compress the requiest, + // but the request will still end up larger than DefaultMaxHeaderBytes. + url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2) + req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil) + Expect(err).ToNot(HaveOccurred()) + setRequest(encodeRequest(req)) + // str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return len(p), nil + }).AnyTimes() + done := make(chan struct{}) + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) + + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + }) + }) + + It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() { + handlerCalled := make(chan struct{}) + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body = struct { + io.Reader + io.Closer + }{} + close(handlerCalled) + }) + + setRequest(encodeRequest(examplePostRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return len(p), nil + }).AnyTimes() + str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) + + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + Eventually(handlerCalled).Should(BeClosed()) + }) + + It("cancels the request context when the stream is closed", func() { + handlerCalled := make(chan struct{}) + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + Expect(r.Context().Done()).To(BeClosed()) + Expect(r.Context().Err()).To(MatchError(context.Canceled)) + close(handlerCalled) + }) + setRequest(encodeRequest(examplePostRequest)) + + reqContext, cancel := context.WithCancel(context.Background()) + cancel() + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return len(p), nil + }).AnyTimes() + str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) + + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + Eventually(handlerCalled).Should(BeClosed()) + }) + }) + + Context("setting http headers", func() { + BeforeEach(func() { + s.QuicConfig = &quic.Config{Versions: []quic.VersionNumber{protocol.VersionDraft29}} + }) + + var ln1 quic.EarlyListener + var ln2 quic.EarlyListener + expected := http.Header{ + "Alt-Svc": {`h3-29=":443"; ma=2592000`}, + } + + addListener := func(addr string, ln *quic.EarlyListener) { + mln := newMockAddrListener(addr) + mln.EXPECT().Addr() + *ln = mln + s.addListener(ln) + } + + removeListener := func(ln *quic.EarlyListener) { + s.removeListener(ln) + } + + checkSetHeaders := func(expected gmtypes.GomegaMatcher) { + hdr := http.Header{} + Expect(s.SetQuicHeaders(hdr)).To(Succeed()) + Expect(hdr).To(expected) + } + + checkSetHeaderError := func() { + hdr := http.Header{} + Expect(s.SetQuicHeaders(hdr)).To(Equal(ErrNoAltSvcPort)) + } + + It("sets proper headers with numeric port", func() { + addListener(":443", &ln1) + checkSetHeaders(Equal(expected)) + removeListener(&ln1) + checkSetHeaderError() + }) + + It("sets proper headers with full addr", func() { + addListener("127.0.0.1:443", &ln1) + checkSetHeaders(Equal(expected)) + removeListener(&ln1) + checkSetHeaderError() + }) + + It("sets proper headers with string port", func() { + addListener(":https", &ln1) + checkSetHeaders(Equal(expected)) + removeListener(&ln1) + checkSetHeaderError() + }) + + It("works multiple times", func() { + addListener(":https", &ln1) + checkSetHeaders(Equal(expected)) + checkSetHeaders(Equal(expected)) + removeListener(&ln1) + checkSetHeaderError() + }) + + It("works if the quic.Config sets QUIC versions", func() { + s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.VersionDraft29} + addListener(":443", &ln1) + checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3-29=":443"; ma=2592000`}})) + removeListener(&ln1) + checkSetHeaderError() + }) + + It("uses s.Port if set to a non-zero value", func() { + s.Port = 8443 + addListener(":443", &ln1) + checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3-29=":8443"; ma=2592000`}})) + removeListener(&ln1) + checkSetHeaderError() + }) + + It("uses s.Addr if listeners don't have ports available", func() { + s.Addr = ":443" + mln := &noPortListener{newMockAddrListener("")} + mln.EXPECT().Addr() + ln1 = mln + s.addListener(&ln1) + checkSetHeaders(Equal(expected)) + s.removeListener(&ln1) + checkSetHeaderError() + }) + + It("properly announces multiple listeners", func() { + addListener(":443", &ln1) + addListener(":8443", &ln2) + checkSetHeaders(Or( + Equal(http.Header{"Alt-Svc": {`h3-29=":443"; ma=2592000,h3-29=":8443"; ma=2592000`}}), + Equal(http.Header{"Alt-Svc": {`h3-29=":8443"; ma=2592000,h3-29=":443"; ma=2592000`}}), + )) + removeListener(&ln1) + removeListener(&ln2) + checkSetHeaderError() + }) + }) + + It("errors when ListenAndServe is called with s.TLSConfig nil", func() { + Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig)) + }) + + It("should nop-Close() when s.server is nil", func() { + Expect((&Server{}).Close()).To(Succeed()) + }) + + It("errors when ListenAndServeTLS is called after Close", func() { + serv := &Server{} + Expect(serv.Close()).To(Succeed()) + Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed)) + }) + + It("handles concurrent Serve and Close", func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + c, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + s.Serve(c) + }() + runtime.Gosched() + s.Close() + Eventually(done).Should(BeClosed()) + }) + + Context("ConfigureTLSConfig", func() { + var tlsConf *tls.Config + var ch *tls.ClientHelloInfo + + BeforeEach(func() { + tlsConf = &tls.Config{} + ch = &tls.ClientHelloInfo{} + }) + + It("advertises v1 by default", func() { + tlsConf = ConfigureTLSConfig(tlsConf) + Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) + + config, err := tlsConf.GetConfigForClient(ch) + Expect(err).NotTo(HaveOccurred()) + Expect(config.NextProtos).To(Equal([]string{nextProtoH3})) + }) + + It("advertises h3-29 for draft-29", func() { + tlsConf = ConfigureTLSConfig(tlsConf) + Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) + + ch.Conn = newMockConn(protocol.VersionDraft29) + config, err := tlsConf.GetConfigForClient(ch) + Expect(err).NotTo(HaveOccurred()) + Expect(config.NextProtos).To(Equal([]string{nextProtoH3Draft29})) + }) + }) + + Context("Serve", func() { + origQuicListen := quicListen + + AfterEach(func() { + quicListen = origQuicListen + }) + + It("serves a packet conn", func() { + ln := newMockAddrListener(":443") + conn := &net.UDPConn{} + quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + Expect(c).To(Equal(conn)) + return ln, nil + } + + s := &Server{ + TLSConfig: &tls.Config{}, + } + + stopAccept := make(chan struct{}) + ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + <-stopAccept + return nil, errors.New("closed") + }) + ln.EXPECT().Addr() // generate alt-svc headers + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + s.Serve(conn) + }() + + Consistently(done).ShouldNot(BeClosed()) + ln.EXPECT().Close().Do(func() { close(stopAccept) }) + Expect(s.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("serves two packet conns", func() { + ln1 := newMockAddrListener(":443") + ln2 := newMockAddrListener(":8443") + lns := make(chan quic.EarlyListener, 2) + lns <- ln1 + lns <- ln2 + conn1 := &net.UDPConn{} + conn2 := &net.UDPConn{} + quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + return <-lns, nil + } + + s := &Server{ + TLSConfig: &tls.Config{}, + } + + stopAccept1 := make(chan struct{}) + ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + <-stopAccept1 + return nil, errors.New("closed") + }) + ln1.EXPECT().Addr() // generate alt-svc headers + stopAccept2 := make(chan struct{}) + ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + <-stopAccept2 + return nil, errors.New("closed") + }) + ln2.EXPECT().Addr() + + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done1) + s.Serve(conn1) + }() + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done2) + s.Serve(conn2) + }() + + Consistently(done1).ShouldNot(BeClosed()) + Expect(done2).ToNot(BeClosed()) + ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) + ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) + Expect(s.Close()).To(Succeed()) + Eventually(done1).Should(BeClosed()) + Eventually(done2).Should(BeClosed()) + }) + }) + + Context("ServeListener", func() { + origQuicListen := quicListen + + AfterEach(func() { + quicListen = origQuicListen + }) + + It("serves a listener", func() { + var called int32 + ln := newMockAddrListener(":443") + quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + atomic.StoreInt32(&called, 1) + return ln, nil + } + + s := &Server{} + + stopAccept := make(chan struct{}) + ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + <-stopAccept + return nil, errors.New("closed") + }) + ln.EXPECT().Addr() // generate alt-svc headers + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + s.ServeListener(ln) + }() + + Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) + Consistently(done).ShouldNot(BeClosed()) + ln.EXPECT().Close().Do(func() { close(stopAccept) }) + Expect(s.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("serves two listeners", func() { + var called int32 + ln1 := newMockAddrListener(":443") + ln2 := newMockAddrListener(":8443") + lns := make(chan quic.EarlyListener, 2) + lns <- ln1 + lns <- ln2 + quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + atomic.StoreInt32(&called, 1) + return <-lns, nil + } + + s := &Server{} + + stopAccept1 := make(chan struct{}) + ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + <-stopAccept1 + return nil, errors.New("closed") + }) + ln1.EXPECT().Addr() // generate alt-svc headers + stopAccept2 := make(chan struct{}) + ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + <-stopAccept2 + return nil, errors.New("closed") + }) + ln2.EXPECT().Addr() + + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done1) + s.ServeListener(ln1) + }() + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done2) + s.ServeListener(ln2) + }() + + Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) + Consistently(done1).ShouldNot(BeClosed()) + Expect(done2).ToNot(BeClosed()) + ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) + ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) + Expect(s.Close()).To(Succeed()) + Eventually(done1).Should(BeClosed()) + Eventually(done2).Should(BeClosed()) + }) + }) + + Context("ListenAndServe", func() { + BeforeEach(func() { + s.Addr = "localhost:0" + }) + + AfterEach(func() { + Expect(s.Close()).To(Succeed()) + }) + + checkGetConfigForClientVersions := func(conf *tls.Config) { + c, err := conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft29)}) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3Draft29})) + c, err = conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.Version1)}) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3})) + } + + It("uses the quic.Config to start the QUIC server", func() { + conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond} + var receivedConf *quic.Config + quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + receivedConf = config + return nil, errors.New("listen err") + } + s.QuicConfig = conf + Expect(s.ListenAndServe()).To(HaveOccurred()) + Expect(receivedConf).To(Equal(conf)) + }) + + It("sets the GetConfigForClient and replaces the ALPN token to the tls.Config, if the GetConfigForClient callback is not set", func() { + tlsConf := &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + NextProtos: []string{"foo", "bar"}, + } + var receivedConf *tls.Config + quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { + receivedConf = tlsConf + return nil, errors.New("listen err") + } + s.TLSConfig = tlsConf + Expect(s.ListenAndServe()).To(HaveOccurred()) + Expect(receivedConf.NextProtos).To(BeEmpty()) + Expect(receivedConf.ClientAuth).To(BeZero()) + // make sure the original tls.Config was not modified + Expect(tlsConf.NextProtos).To(Equal([]string{"foo", "bar"})) + // make sure that the config returned from the GetConfigForClient callback sets the fields of the original config + conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert)) + checkGetConfigForClientVersions(receivedConf) + }) + + It("sets the GetConfigForClient callback if no tls.Config is given", func() { + var receivedConf *tls.Config + quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { + receivedConf = tlsConf + return nil, errors.New("listen err") + } + Expect(s.ListenAndServe()).To(HaveOccurred()) + Expect(receivedConf).ToNot(BeNil()) + checkGetConfigForClientVersions(receivedConf) + }) + + It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { + tlsConf := &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + return &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + NextProtos: []string{"foo", "bar"}, + }, nil + }, + } + + var receivedConf *tls.Config + quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { + receivedConf = conf + return nil, errors.New("listen err") + } + s.TLSConfig = tlsConf + Expect(s.ListenAndServe()).To(HaveOccurred()) + // check that the original config was not modified + conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) + // check that the config returned by the GetConfigForClient callback uses the returned config + conf, err = receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert)) + checkGetConfigForClientVersions(receivedConf) + }) + + It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { + tlsClientConf := &tls.Config{NextProtos: []string{"foo", "bar"}} + tlsConf := &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + return tlsClientConf, nil + }, + } + + var receivedConf *tls.Config + quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { + receivedConf = conf + return nil, errors.New("listen err") + } + s.TLSConfig = tlsConf + Expect(s.ListenAndServe()).To(HaveOccurred()) + // check that the original config was not modified + conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) + checkGetConfigForClientVersions(receivedConf) + }) + + It("works if GetConfigForClient returns a nil tls.Config", func() { + tlsConf := &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }} + + var receivedConf *tls.Config + quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { + receivedConf = conf + return nil, errors.New("listen err") + } + s.TLSConfig = tlsConf + Expect(s.ListenAndServe()).To(HaveOccurred()) + conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf).ToNot(BeNil()) + checkGetConfigForClientVersions(receivedConf) + }) + }) + + It("closes gracefully", func() { + Expect(s.CloseGracefully(0)).To(Succeed()) + }) + + It("errors when listening fails", func() { + testErr := errors.New("listen error") + quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + return nil, testErr + } + fullpem, privkey := testdata.GetCertificatePaths() + Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr)) + }) + + It("supports H3_DATAGRAM", func() { + s.EnableDatagrams = true + var receivedConf *quic.Config + quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + receivedConf = config + return nil, errors.New("listen err") + } + Expect(s.ListenAndServe()).To(HaveOccurred()) + Expect(receivedConf.EnableDatagrams).To(BeTrue()) + }) +}) diff --git a/internal/logging/frame.go b/internal/logging/frame.go new file mode 100644 index 00000000..c2897747 --- /dev/null +++ b/internal/logging/frame.go @@ -0,0 +1,66 @@ +package logging + +import "github.com/imroc/req/v3/internal/wire" + +// A Frame is a QUIC frame +type Frame interface{} + +// The AckRange is used within the AckFrame. +// It is a range of packet numbers that is being acknowledged. +type AckRange = wire.AckRange + +type ( + // An AckFrame is an ACK frame. + AckFrame = wire.AckFrame + // A ConnectionCloseFrame is a CONNECTION_CLOSE frame. + ConnectionCloseFrame = wire.ConnectionCloseFrame + // A DataBlockedFrame is a DATA_BLOCKED frame. + DataBlockedFrame = wire.DataBlockedFrame + // A HandshakeDoneFrame is a HANDSHAKE_DONE frame. + HandshakeDoneFrame = wire.HandshakeDoneFrame + // A MaxDataFrame is a MAX_DATA frame. + MaxDataFrame = wire.MaxDataFrame + // A MaxStreamDataFrame is a MAX_STREAM_DATA frame. + MaxStreamDataFrame = wire.MaxStreamDataFrame + // A MaxStreamsFrame is a MAX_STREAMS_FRAME. + MaxStreamsFrame = wire.MaxStreamsFrame + // A NewConnectionIDFrame is a NEW_CONNECTION_ID frame. + NewConnectionIDFrame = wire.NewConnectionIDFrame + // A NewTokenFrame is a NEW_TOKEN frame. + NewTokenFrame = wire.NewTokenFrame + // A PathChallengeFrame is a PATH_CHALLENGE frame. + PathChallengeFrame = wire.PathChallengeFrame + // A PathResponseFrame is a PATH_RESPONSE frame. + PathResponseFrame = wire.PathResponseFrame + // A PingFrame is a PING frame. + PingFrame = wire.PingFrame + // A ResetStreamFrame is a RESET_STREAM frame. + ResetStreamFrame = wire.ResetStreamFrame + // A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame. + RetireConnectionIDFrame = wire.RetireConnectionIDFrame + // A StopSendingFrame is a STOP_SENDING frame. + StopSendingFrame = wire.StopSendingFrame + // A StreamsBlockedFrame is a STREAMS_BLOCKED frame. + StreamsBlockedFrame = wire.StreamsBlockedFrame + // A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame. + StreamDataBlockedFrame = wire.StreamDataBlockedFrame +) + +// A CryptoFrame is a CRYPTO frame. +type CryptoFrame struct { + Offset ByteCount + Length ByteCount +} + +// A StreamFrame is a STREAM frame. +type StreamFrame struct { + StreamID StreamID + Offset ByteCount + Length ByteCount + Fin bool +} + +// A DatagramFrame is a DATAGRAM frame. +type DatagramFrame struct { + Length ByteCount +} diff --git a/internal/logging/interface.go b/internal/logging/interface.go new file mode 100644 index 00000000..8e6eeba6 --- /dev/null +++ b/internal/logging/interface.go @@ -0,0 +1,135 @@ +// Package logging defines a logging interface for quic-go. +// This package should not be considered stable +package logging + +import ( + "context" + "github.com/lucas-clemente/quic-go" + "net" + "time" + + "github.com/imroc/req/v3/internal/utils" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + "github.com/imroc/req/v3/internal/wire" +) + +type ( + // A ByteCount is used to count bytes. + ByteCount = protocol.ByteCount + // A ConnectionID is a QUIC Connection ID. + ConnectionID = protocol.ConnectionID + // The EncryptionLevel is the encryption level of a packet. + EncryptionLevel = protocol.EncryptionLevel + // The KeyPhase is the key phase of the 1-RTT keys. + KeyPhase = protocol.KeyPhase + // The KeyPhaseBit is the value of the key phase bit of the 1-RTT packets. + KeyPhaseBit = protocol.KeyPhaseBit + // The PacketNumber is the packet number of a packet. + PacketNumber = protocol.PacketNumber + // The Perspective is the role of a QUIC endpoint (client or server). + Perspective = protocol.Perspective + // A StatelessResetToken is a stateless reset token. + StatelessResetToken = protocol.StatelessResetToken + // The StreamID is the stream ID. + StreamID = protocol.StreamID + // The StreamNum is the number of the stream. + StreamNum = protocol.StreamNum + // The StreamType is the type of the stream (unidirectional or bidirectional). + StreamType = protocol.StreamType + // The VersionNumber is the QUIC version. + VersionNumber = quic.VersionNumber + + // The Header is the QUIC packet header, before removing header protection. + Header = wire.Header + // The ExtendedHeader is the QUIC packet header, after removing header protection. + ExtendedHeader = wire.ExtendedHeader + // The TransportParameters are QUIC transport parameters. + TransportParameters = wire.TransportParameters + // The PreferredAddress is the preferred address sent in the transport parameters. + PreferredAddress = wire.PreferredAddress + + // A TransportError is a transport-level error code. + TransportError = qerr.TransportErrorCode + // An ApplicationError is an application-defined error code. + ApplicationError = qerr.TransportErrorCode + + // The RTTStats contain statistics used by the congestion controller. + RTTStats = utils.RTTStats +) + +const ( + // KeyPhaseZero is key phase bit 0 + KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero + // KeyPhaseOne is key phase bit 1 + KeyPhaseOne KeyPhaseBit = protocol.KeyPhaseOne +) + +const ( + // PerspectiveServer is used for a QUIC server + PerspectiveServer Perspective = protocol.PerspectiveServer + // PerspectiveClient is used for a QUIC client + PerspectiveClient Perspective = protocol.PerspectiveClient +) + +const ( + // EncryptionInitial is the Initial encryption level + EncryptionInitial EncryptionLevel = protocol.EncryptionInitial + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake EncryptionLevel = protocol.EncryptionHandshake + // Encryption1RTT is the 1-RTT encryption level + Encryption1RTT EncryptionLevel = protocol.Encryption1RTT + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT EncryptionLevel = protocol.Encryption0RTT +) + +const ( + // StreamTypeUni is a unidirectional stream + StreamTypeUni = protocol.StreamTypeUni + // StreamTypeBidi is a bidirectional stream + StreamTypeBidi = protocol.StreamTypeBidi +) + +// A Tracer traces events. +type Tracer interface { + // TracerForConnection requests a new tracer for a connection. + // The ODCID is the original destination connection ID: + // The destination connection ID that the client used on the first Initial packet it sent on this connection. + // If nil is returned, tracing will be disabled for this connection. + TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer + + SentPacket(net.Addr, *Header, ByteCount, []Frame) + DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) +} + +// A ConnectionTracer records events. +type ConnectionTracer interface { + StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) + NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) + ClosedConnection(error) + SentTransportParameters(*TransportParameters) + ReceivedTransportParameters(*TransportParameters) + RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT + SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) + ReceivedVersionNegotiationPacket(*Header, []VersionNumber) + ReceivedRetry(*Header) + ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) + BufferedPacket(PacketType) + DroppedPacket(PacketType, ByteCount, PacketDropReason) + UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) + AcknowledgedPacket(EncryptionLevel, PacketNumber) + LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) + UpdatedCongestionState(CongestionState) + UpdatedPTOCount(value uint32) + UpdatedKeyFromTLS(EncryptionLevel, Perspective) + UpdatedKey(generation KeyPhase, remote bool) + DroppedEncryptionLevel(EncryptionLevel) + DroppedKey(generation KeyPhase) + SetLossTimer(TimerType, EncryptionLevel, time.Time) + LossTimerExpired(TimerType, EncryptionLevel) + LossTimerCanceled() + // Close is called when the connection is closed. + Close() + Debug(name, msg string) +} diff --git a/internal/logging/logging_suite_test.go b/internal/logging/logging_suite_test.go new file mode 100644 index 00000000..0a81943d --- /dev/null +++ b/internal/logging/logging_suite_test.go @@ -0,0 +1,25 @@ +package logging + +import ( + "testing" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestLogging(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Logging Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) diff --git a/internal/logging/mock_connection_tracer_test.go b/internal/logging/mock_connection_tracer_test.go new file mode 100644 index 00000000..a0a398aa --- /dev/null +++ b/internal/logging/mock_connection_tracer_test.go @@ -0,0 +1,351 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/logging (interfaces: ConnectionTracer) + +// Package logging is a generated GoMock package. +package logging + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" + utils "github.com/imroc/req/v3/internal/utils" + wire "github.com/imroc/req/v3/internal/wire" +) + +// MockConnectionTracer is a mock of ConnectionTracer interface. +type MockConnectionTracer struct { + ctrl *gomock.Controller + recorder *MockConnectionTracerMockRecorder +} + +// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. +type MockConnectionTracerMockRecorder struct { + mock *MockConnectionTracer +} + +// NewMockConnectionTracer creates a new mock instance. +func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { + mock := &MockConnectionTracer{ctrl: ctrl} + mock.recorder = &MockConnectionTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { + return m.recorder +} + +// AcknowledgedPacket mocks base method. +func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) +} + +// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. +func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) +} + +// BufferedPacket mocks base method. +func (m *MockConnectionTracer) BufferedPacket(arg0 PacketType) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "BufferedPacket", arg0) +} + +// BufferedPacket indicates an expected call of BufferedPacket. +func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0) +} + +// Close mocks base method. +func (m *MockConnectionTracer) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) +} + +// ClosedConnection mocks base method. +func (m *MockConnectionTracer) ClosedConnection(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClosedConnection", arg0) +} + +// ClosedConnection indicates an expected call of ClosedConnection. +func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) +} + +// Debug mocks base method. +func (m *MockConnectionTracer) Debug(arg0, arg1 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0, arg1) +} + +// Debug indicates an expected call of Debug. +func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) +} + +// DroppedEncryptionLevel mocks base method. +func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) +} + +// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. +func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) +} + +// DroppedKey mocks base method. +func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedKey", arg0) +} + +// DroppedKey indicates an expected call of DroppedKey. +func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) +} + +// DroppedPacket mocks base method. +func (m *MockConnectionTracer) DroppedPacket(arg0 PacketType, arg1 protocol.ByteCount, arg2 PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) +} + +// LossTimerCanceled mocks base method. +func (m *MockConnectionTracer) LossTimerCanceled() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerCanceled") +} + +// LossTimerCanceled indicates an expected call of LossTimerCanceled. +func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) +} + +// LossTimerExpired mocks base method. +func (m *MockConnectionTracer) LossTimerExpired(arg0 TimerType, arg1 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) +} + +// LossTimerExpired indicates an expected call of LossTimerExpired. +func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) +} + +// LostPacket mocks base method. +func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 PacketLossReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) +} + +// LostPacket indicates an expected call of LostPacket. +func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) +} + +// NegotiatedVersion mocks base method. +func (m *MockConnectionTracer) NegotiatedVersion(arg0 quic.VersionNumber, arg1, arg2 []quic.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) +} + +// NegotiatedVersion indicates an expected call of NegotiatedVersion. +func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) +} + +// ReceivedPacket mocks base method. +func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2) +} + +// ReceivedPacket indicates an expected call of ReceivedPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedPacket), arg0, arg1, arg2) +} + +// ReceivedRetry mocks base method. +func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedRetry", arg0) +} + +// ReceivedRetry indicates an expected call of ReceivedRetry. +func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) +} + +// ReceivedTransportParameters mocks base method. +func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedTransportParameters", arg0) +} + +// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. +func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) +} + +// ReceivedVersionNegotiationPacket mocks base method. +func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header, arg1 []quic.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1) +} + +// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1) +} + +// RestoredTransportParameters mocks base method. +func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RestoredTransportParameters", arg0) +} + +// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. +func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) +} + +// SentPacket mocks base method. +func (m *MockConnectionTracer) SentPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockConnectionTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) +} + +// SentTransportParameters mocks base method. +func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentTransportParameters", arg0) +} + +// SentTransportParameters indicates an expected call of SentTransportParameters. +func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) +} + +// SetLossTimer mocks base method. +func (m *MockConnectionTracer) SetLossTimer(arg0 TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) +} + +// SetLossTimer indicates an expected call of SetLossTimer. +func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) +} + +// StartedConnection mocks base method. +func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) +} + +// StartedConnection indicates an expected call of StartedConnection. +func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) +} + +// UpdatedCongestionState mocks base method. +func (m *MockConnectionTracer) UpdatedCongestionState(arg0 CongestionState) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedCongestionState", arg0) +} + +// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. +func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) +} + +// UpdatedKey mocks base method. +func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKey", arg0, arg1) +} + +// UpdatedKey indicates an expected call of UpdatedKey. +func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) +} + +// UpdatedKeyFromTLS mocks base method. +func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) +} + +// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. +func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) +} + +// UpdatedMetrics mocks base method. +func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) +} + +// UpdatedMetrics indicates an expected call of UpdatedMetrics. +func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) +} + +// UpdatedPTOCount mocks base method. +func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedPTOCount", arg0) +} + +// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. +func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) +} diff --git a/internal/logging/mock_tracer_test.go b/internal/logging/mock_tracer_test.go new file mode 100644 index 00000000..98e245d6 --- /dev/null +++ b/internal/logging/mock_tracer_test.go @@ -0,0 +1,76 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/logging (interfaces: Tracer) + +// Package logging is a generated GoMock package. +package logging + +import ( + context "context" + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" + wire "github.com/imroc/req/v3/internal/wire" +) + +// MockTracer is a mock of Tracer interface. +type MockTracer struct { + ctrl *gomock.Controller + recorder *MockTracerMockRecorder +} + +// MockTracerMockRecorder is the mock recorder for MockTracer. +type MockTracerMockRecorder struct { + mock *MockTracer +} + +// NewMockTracer creates a new mock instance. +func NewMockTracer(ctrl *gomock.Controller) *MockTracer { + mock := &MockTracer{ctrl: ctrl} + mock.recorder = &MockTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTracer) EXPECT() *MockTracerMockRecorder { + return m.recorder +} + +// DroppedPacket mocks base method. +func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 PacketType, arg2 protocol.ByteCount, arg3 PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) +} + +// SentPacket mocks base method. +func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) +} + +// TracerForConnection mocks base method. +func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) + ret0, _ := ret[0].(ConnectionTracer) + return ret0 +} + +// TracerForConnection indicates an expected call of TracerForConnection. +func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) +} diff --git a/internal/logging/mockgen.go b/internal/logging/mockgen.go new file mode 100644 index 00000000..48750afb --- /dev/null +++ b/internal/logging/mockgen.go @@ -0,0 +1,4 @@ +package logging + +//go:generate sh -c "mockgen -package logging -self_package github.com/imroc/req/v3/internal/logging -destination mock_connection_tracer_test.go github.com/imroc/req/v3/internal/logging ConnectionTracer" +//go:generate sh -c "mockgen -package logging -self_package github.com/imroc/req/v3/internal/logging -destination mock_tracer_test.go github.com/imroc/req/v3/internal/logging Tracer" diff --git a/internal/logging/multiplex.go b/internal/logging/multiplex.go new file mode 100644 index 00000000..8280e8cd --- /dev/null +++ b/internal/logging/multiplex.go @@ -0,0 +1,219 @@ +package logging + +import ( + "context" + "net" + "time" +) + +type tracerMultiplexer struct { + tracers []Tracer +} + +var _ Tracer = &tracerMultiplexer{} + +// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. +func NewMultiplexedTracer(tracers ...Tracer) Tracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &tracerMultiplexer{tracers} +} + +func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer { + var connTracers []ConnectionTracer + for _, t := range m.tracers { + if ct := t.TracerForConnection(ctx, p, odcid); ct != nil { + connTracers = append(connTracers, ct) + } + } + return NewMultiplexedConnectionTracer(connTracers...) +} + +func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { + for _, t := range m.tracers { + t.SentPacket(remote, hdr, size, frames) + } +} + +func (m *tracerMultiplexer) DroppedPacket(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range m.tracers { + t.DroppedPacket(remote, typ, size, reason) + } +} + +type connTracerMultiplexer struct { + tracers []ConnectionTracer +} + +var _ ConnectionTracer = &connTracerMultiplexer{} + +// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. +func NewMultiplexedConnectionTracer(tracers ...ConnectionTracer) ConnectionTracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &connTracerMultiplexer{tracers: tracers} +} + +func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { + for _, t := range m.tracers { + t.StartedConnection(local, remote, srcConnID, destConnID) + } +} + +func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { + for _, t := range m.tracers { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + } +} + +func (m *connTracerMultiplexer) ClosedConnection(e error) { + for _, t := range m.tracers { + t.ClosedConnection(e) + } +} + +func (m *connTracerMultiplexer) SentTransportParameters(tp *TransportParameters) { + for _, t := range m.tracers { + t.SentTransportParameters(tp) + } +} + +func (m *connTracerMultiplexer) ReceivedTransportParameters(tp *TransportParameters) { + for _, t := range m.tracers { + t.ReceivedTransportParameters(tp) + } +} + +func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParameters) { + for _, t := range m.tracers { + t.RestoredTransportParameters(tp) + } +} + +func (m *connTracerMultiplexer) SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) { + for _, t := range m.tracers { + t.SentPacket(hdr, size, ack, frames) + } +} + +func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(hdr *Header, versions []VersionNumber) { + for _, t := range m.tracers { + t.ReceivedVersionNegotiationPacket(hdr, versions) + } +} + +func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) { + for _, t := range m.tracers { + t.ReceivedRetry(hdr) + } +} + +func (m *connTracerMultiplexer) ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) { + for _, t := range m.tracers { + t.ReceivedPacket(hdr, size, frames) + } +} + +func (m *connTracerMultiplexer) BufferedPacket(typ PacketType) { + for _, t := range m.tracers { + t.BufferedPacket(typ) + } +} + +func (m *connTracerMultiplexer) DroppedPacket(typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range m.tracers { + t.DroppedPacket(typ, size, reason) + } +} + +func (m *connTracerMultiplexer) UpdatedCongestionState(state CongestionState) { + for _, t := range m.tracers { + t.UpdatedCongestionState(state) + } +} + +func (m *connTracerMultiplexer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFLight ByteCount, packetsInFlight int) { + for _, t := range m.tracers { + t.UpdatedMetrics(rttStats, cwnd, bytesInFLight, packetsInFlight) + } +} + +func (m *connTracerMultiplexer) AcknowledgedPacket(encLevel EncryptionLevel, pn PacketNumber) { + for _, t := range m.tracers { + t.AcknowledgedPacket(encLevel, pn) + } +} + +func (m *connTracerMultiplexer) LostPacket(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { + for _, t := range m.tracers { + t.LostPacket(encLevel, pn, reason) + } +} + +func (m *connTracerMultiplexer) UpdatedPTOCount(value uint32) { + for _, t := range m.tracers { + t.UpdatedPTOCount(value) + } +} + +func (m *connTracerMultiplexer) UpdatedKeyFromTLS(encLevel EncryptionLevel, perspective Perspective) { + for _, t := range m.tracers { + t.UpdatedKeyFromTLS(encLevel, perspective) + } +} + +func (m *connTracerMultiplexer) UpdatedKey(generation KeyPhase, remote bool) { + for _, t := range m.tracers { + t.UpdatedKey(generation, remote) + } +} + +func (m *connTracerMultiplexer) DroppedEncryptionLevel(encLevel EncryptionLevel) { + for _, t := range m.tracers { + t.DroppedEncryptionLevel(encLevel) + } +} + +func (m *connTracerMultiplexer) DroppedKey(generation KeyPhase) { + for _, t := range m.tracers { + t.DroppedKey(generation) + } +} + +func (m *connTracerMultiplexer) SetLossTimer(typ TimerType, encLevel EncryptionLevel, exp time.Time) { + for _, t := range m.tracers { + t.SetLossTimer(typ, encLevel, exp) + } +} + +func (m *connTracerMultiplexer) LossTimerExpired(typ TimerType, encLevel EncryptionLevel) { + for _, t := range m.tracers { + t.LossTimerExpired(typ, encLevel) + } +} + +func (m *connTracerMultiplexer) LossTimerCanceled() { + for _, t := range m.tracers { + t.LossTimerCanceled() + } +} + +func (m *connTracerMultiplexer) Debug(name, msg string) { + for _, t := range m.tracers { + t.Debug(name, msg) + } +} + +func (m *connTracerMultiplexer) Close() { + for _, t := range m.tracers { + t.Close() + } +} diff --git a/internal/logging/multiplex_test.go b/internal/logging/multiplex_test.go new file mode 100644 index 00000000..acc0e9d4 --- /dev/null +++ b/internal/logging/multiplex_test.go @@ -0,0 +1,266 @@ +package logging + +import ( + "context" + "errors" + "net" + "time" + + "github.com/imroc/req/v3/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Tracing", func() { + Context("Tracer", func() { + It("returns a nil tracer if no tracers are passed in", func() { + Expect(NewMultiplexedTracer()).To(BeNil()) + }) + + It("returns the raw tracer if only one tracer is passed in", func() { + tr := NewMockTracer(mockCtrl) + tracer := NewMultiplexedTracer(tr) + Expect(tracer).To(BeAssignableToTypeOf(&MockTracer{})) + }) + + Context("tracing events", func() { + var ( + tracer Tracer + tr1, tr2 *MockTracer + ) + + BeforeEach(func() { + tr1 = NewMockTracer(mockCtrl) + tr2 = NewMockTracer(mockCtrl) + tracer = NewMultiplexedTracer(tr1, tr2) + }) + + It("multiplexes the TracerForConnection call", func() { + ctx := context.Background() + tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + }) + + It("uses multiple connection tracers", func() { + ctx := context.Background() + ctr1 := NewMockConnectionTracer(mockCtrl) + ctr2 := NewMockConnectionTracer(mockCtrl) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2) + tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) + ctr1.EXPECT().LossTimerCanceled() + ctr2.EXPECT().LossTimerCanceled() + tr.LossTimerCanceled() + }) + + It("handles tracers that return a nil ConnectionTracer", func() { + ctx := context.Background() + ctr1 := NewMockConnectionTracer(mockCtrl) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) + tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) + ctr1.EXPECT().LossTimerCanceled() + tr.LossTimerCanceled() + }) + + It("returns nil when all tracers return a nil ConnectionTracer", func() { + ctx := context.Background() + tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + Expect(tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil()) + }) + + It("traces the PacketSent event", func() { + remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} + hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} + f := &MaxDataFrame{MaximumData: 1337} + tr1.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) + tr2.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) + tracer.SentPacket(remote, hdr, 1024, []Frame{f}) + }) + + It("traces the PacketDropped event", func() { + remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} + tr1.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) + tr2.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) + tracer.DroppedPacket(remote, PacketTypeRetry, 1024, PacketDropDuplicate) + }) + }) + }) + + Context("Connection Tracer", func() { + var ( + tracer ConnectionTracer + tr1 *MockConnectionTracer + tr2 *MockConnectionTracer + ) + + BeforeEach(func() { + tr1 = NewMockConnectionTracer(mockCtrl) + tr2 = NewMockConnectionTracer(mockCtrl) + tracer = NewMultiplexedConnectionTracer(tr1, tr2) + }) + + It("trace the ConnectionStarted event", func() { + local := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4)} + remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} + tr1.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) + tr2.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) + tracer.StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) + }) + + It("traces the ClosedConnection event", func() { + e := errors.New("test err") + tr1.EXPECT().ClosedConnection(e) + tr2.EXPECT().ClosedConnection(e) + tracer.ClosedConnection(e) + }) + + It("traces the SentTransportParameters event", func() { + tp := &wire.TransportParameters{InitialMaxData: 1337} + tr1.EXPECT().SentTransportParameters(tp) + tr2.EXPECT().SentTransportParameters(tp) + tracer.SentTransportParameters(tp) + }) + + It("traces the ReceivedTransportParameters event", func() { + tp := &wire.TransportParameters{InitialMaxData: 1337} + tr1.EXPECT().ReceivedTransportParameters(tp) + tr2.EXPECT().ReceivedTransportParameters(tp) + tracer.ReceivedTransportParameters(tp) + }) + + It("traces the RestoredTransportParameters event", func() { + tp := &wire.TransportParameters{InitialMaxData: 1337} + tr1.EXPECT().RestoredTransportParameters(tp) + tr2.EXPECT().RestoredTransportParameters(tp) + tracer.RestoredTransportParameters(tp) + }) + + It("traces the SentPacket event", func() { + hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} + ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}} + ping := &PingFrame{} + tr1.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping}) + tr2.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping}) + tracer.SentPacket(hdr, 1337, ack, []Frame{ping}) + }) + + It("traces the ReceivedVersionNegotiationPacket event", func() { + hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} + tr1.EXPECT().ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) + tr2.EXPECT().ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) + tracer.ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) + }) + + It("traces the ReceivedRetry event", func() { + hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} + tr1.EXPECT().ReceivedRetry(hdr) + tr2.EXPECT().ReceivedRetry(hdr) + tracer.ReceivedRetry(hdr) + }) + + It("traces the ReceivedPacket event", func() { + hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} + ping := &PingFrame{} + tr1.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) + tr2.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) + tracer.ReceivedPacket(hdr, 1337, []Frame{ping}) + }) + + It("traces the BufferedPacket event", func() { + tr1.EXPECT().BufferedPacket(PacketTypeHandshake) + tr2.EXPECT().BufferedPacket(PacketTypeHandshake) + tracer.BufferedPacket(PacketTypeHandshake) + }) + + It("traces the DroppedPacket event", func() { + tr1.EXPECT().DroppedPacket(PacketTypeInitial, ByteCount(1337), PacketDropHeaderParseError) + tr2.EXPECT().DroppedPacket(PacketTypeInitial, ByteCount(1337), PacketDropHeaderParseError) + tracer.DroppedPacket(PacketTypeInitial, 1337, PacketDropHeaderParseError) + }) + + It("traces the UpdatedCongestionState event", func() { + tr1.EXPECT().UpdatedCongestionState(CongestionStateRecovery) + tr2.EXPECT().UpdatedCongestionState(CongestionStateRecovery) + tracer.UpdatedCongestionState(CongestionStateRecovery) + }) + + It("traces the UpdatedMetrics event", func() { + rttStats := &RTTStats{} + rttStats.UpdateRTT(time.Second, 0, time.Now()) + tr1.EXPECT().UpdatedMetrics(rttStats, ByteCount(1337), ByteCount(42), 13) + tr2.EXPECT().UpdatedMetrics(rttStats, ByteCount(1337), ByteCount(42), 13) + tracer.UpdatedMetrics(rttStats, 1337, 42, 13) + }) + + It("traces the AcknowledgedPacket event", func() { + tr1.EXPECT().AcknowledgedPacket(EncryptionHandshake, PacketNumber(42)) + tr2.EXPECT().AcknowledgedPacket(EncryptionHandshake, PacketNumber(42)) + tracer.AcknowledgedPacket(EncryptionHandshake, 42) + }) + + It("traces the LostPacket event", func() { + tr1.EXPECT().LostPacket(EncryptionHandshake, PacketNumber(42), PacketLossReorderingThreshold) + tr2.EXPECT().LostPacket(EncryptionHandshake, PacketNumber(42), PacketLossReorderingThreshold) + tracer.LostPacket(EncryptionHandshake, 42, PacketLossReorderingThreshold) + }) + + It("traces the UpdatedPTOCount event", func() { + tr1.EXPECT().UpdatedPTOCount(uint32(88)) + tr2.EXPECT().UpdatedPTOCount(uint32(88)) + tracer.UpdatedPTOCount(88) + }) + + It("traces the UpdatedKeyFromTLS event", func() { + tr1.EXPECT().UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) + tr2.EXPECT().UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) + tracer.UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) + }) + + It("traces the UpdatedKey event", func() { + tr1.EXPECT().UpdatedKey(KeyPhase(42), true) + tr2.EXPECT().UpdatedKey(KeyPhase(42), true) + tracer.UpdatedKey(KeyPhase(42), true) + }) + + It("traces the DroppedEncryptionLevel event", func() { + tr1.EXPECT().DroppedEncryptionLevel(EncryptionHandshake) + tr2.EXPECT().DroppedEncryptionLevel(EncryptionHandshake) + tracer.DroppedEncryptionLevel(EncryptionHandshake) + }) + + It("traces the DroppedKey event", func() { + tr1.EXPECT().DroppedKey(KeyPhase(123)) + tr2.EXPECT().DroppedKey(KeyPhase(123)) + tracer.DroppedKey(123) + }) + + It("traces the SetLossTimer event", func() { + now := time.Now() + tr1.EXPECT().SetLossTimer(TimerTypePTO, EncryptionHandshake, now) + tr2.EXPECT().SetLossTimer(TimerTypePTO, EncryptionHandshake, now) + tracer.SetLossTimer(TimerTypePTO, EncryptionHandshake, now) + }) + + It("traces the LossTimerExpired event", func() { + tr1.EXPECT().LossTimerExpired(TimerTypePTO, EncryptionHandshake) + tr2.EXPECT().LossTimerExpired(TimerTypePTO, EncryptionHandshake) + tracer.LossTimerExpired(TimerTypePTO, EncryptionHandshake) + }) + + It("traces the LossTimerCanceled event", func() { + tr1.EXPECT().LossTimerCanceled() + tr2.EXPECT().LossTimerCanceled() + tracer.LossTimerCanceled() + }) + + It("traces the Close event", func() { + tr1.EXPECT().Close() + tr2.EXPECT().Close() + tracer.Close() + }) + }) +}) diff --git a/internal/logging/packet_header.go b/internal/logging/packet_header.go new file mode 100644 index 00000000..dd7682ec --- /dev/null +++ b/internal/logging/packet_header.go @@ -0,0 +1,27 @@ +package logging + +import ( + "github.com/imroc/req/v3/internal/protocol" +) + +// PacketTypeFromHeader determines the packet type from a *wire.Header. +func PacketTypeFromHeader(hdr *Header) PacketType { + if !hdr.IsLongHeader { + return PacketType1RTT + } + if hdr.Version == 0 { + return PacketTypeVersionNegotiation + } + switch hdr.Type { + case protocol.PacketTypeInitial: + return PacketTypeInitial + case protocol.PacketTypeHandshake: + return PacketTypeHandshake + case protocol.PacketType0RTT: + return PacketType0RTT + case protocol.PacketTypeRetry: + return PacketTypeRetry + default: + return PacketTypeNotDetermined + } +} diff --git a/internal/logging/packet_header_test.go b/internal/logging/packet_header_test.go new file mode 100644 index 00000000..d0e7f08f --- /dev/null +++ b/internal/logging/packet_header_test.go @@ -0,0 +1,60 @@ +package logging + +import ( + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Packet Header", func() { + Context("determining the packet type from the header", func() { + It("recognizes Initial packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Version: protocol.VersionTLS, + })).To(Equal(PacketTypeInitial)) + }) + + It("recognizes Handshake packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Version: protocol.VersionTLS, + })).To(Equal(PacketTypeHandshake)) + }) + + It("recognizes Retry packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + Version: protocol.VersionTLS, + })).To(Equal(PacketTypeRetry)) + }) + + It("recognizes 0-RTT packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketType0RTT, + Version: protocol.VersionTLS, + })).To(Equal(PacketType0RTT)) + }) + + It("recognizes Version Negotiation packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{IsLongHeader: true})).To(Equal(PacketTypeVersionNegotiation)) + }) + + It("recognizes 1-RTT packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{})).To(Equal(PacketType1RTT)) + }) + + It("handles unrecognized packet types", func() { + Expect(PacketTypeFromHeader(&wire.Header{ + IsLongHeader: true, + Version: protocol.VersionTLS, + })).To(Equal(PacketTypeNotDetermined)) + }) + }) +}) diff --git a/internal/logging/types.go b/internal/logging/types.go new file mode 100644 index 00000000..ad800692 --- /dev/null +++ b/internal/logging/types.go @@ -0,0 +1,94 @@ +package logging + +// PacketType is the packet type of a QUIC packet +type PacketType uint8 + +const ( + // PacketTypeInitial is the packet type of an Initial packet + PacketTypeInitial PacketType = iota + // PacketTypeHandshake is the packet type of a Handshake packet + PacketTypeHandshake + // PacketTypeRetry is the packet type of a Retry packet + PacketTypeRetry + // PacketType0RTT is the packet type of a 0-RTT packet + PacketType0RTT + // PacketTypeVersionNegotiation is the packet type of a Version Negotiation packet + PacketTypeVersionNegotiation + // PacketType1RTT is a 1-RTT packet + PacketType1RTT + // PacketTypeStatelessReset is a stateless reset + PacketTypeStatelessReset + // PacketTypeNotDetermined is the packet type when it could not be determined + PacketTypeNotDetermined +) + +type PacketLossReason uint8 + +const ( + // PacketLossReorderingThreshold: when a packet is deemed lost due to reordering threshold + PacketLossReorderingThreshold PacketLossReason = iota + // PacketLossTimeThreshold: when a packet is deemed lost due to time threshold + PacketLossTimeThreshold +) + +type PacketDropReason uint8 + +const ( + // PacketDropKeyUnavailable is used when a packet is dropped because keys are unavailable + PacketDropKeyUnavailable PacketDropReason = iota + // PacketDropUnknownConnectionID is used when a packet is dropped because the connection ID is unknown + PacketDropUnknownConnectionID + // PacketDropHeaderParseError is used when a packet is dropped because header parsing failed + PacketDropHeaderParseError + // PacketDropPayloadDecryptError is used when a packet is dropped because decrypting the payload failed + PacketDropPayloadDecryptError + // PacketDropProtocolViolation is used when a packet is dropped due to a protocol violation + PacketDropProtocolViolation + // PacketDropDOSPrevention is used when a packet is dropped to mitigate a DoS attack + PacketDropDOSPrevention + // PacketDropUnsupportedVersion is used when a packet is dropped because the version is not supported + PacketDropUnsupportedVersion + // PacketDropUnexpectedPacket is used when an unexpected packet is received + PacketDropUnexpectedPacket + // PacketDropUnexpectedSourceConnectionID is used when a packet with an unexpected source connection ID is received + PacketDropUnexpectedSourceConnectionID + // PacketDropUnexpectedVersion is used when a packet with an unexpected version is received + PacketDropUnexpectedVersion + // PacketDropDuplicate is used when a duplicate packet is received + PacketDropDuplicate +) + +// TimerType is the type of the loss detection timer +type TimerType uint8 + +const ( + // TimerTypeACK is the timer type for the early retransmit timer + TimerTypeACK TimerType = iota + // TimerTypePTO is the timer type for the PTO retransmit timer + TimerTypePTO +) + +// TimeoutReason is the reason why a connection is closed +type TimeoutReason uint8 + +const ( + // TimeoutReasonHandshake is used when the connection is closed due to a handshake timeout + // This reason is not defined in the qlog draft, but very useful for debugging. + TimeoutReasonHandshake TimeoutReason = iota + // TimeoutReasonIdle is used when the connection is closed due to an idle timeout + // This reason is not defined in the qlog draft, but very useful for debugging. + TimeoutReasonIdle +) + +type CongestionState uint8 + +const ( + // CongestionStateSlowStart is the slow start phase of Reno / Cubic + CongestionStateSlowStart CongestionState = iota + // CongestionStateCongestionAvoidance is the slow start phase of Reno / Cubic + CongestionStateCongestionAvoidance + // CongestionStateRecovery is the recovery phase of Reno / Cubic + CongestionStateRecovery + // CongestionStateApplicationLimited means that the congestion controller is application limited + CongestionStateApplicationLimited +) diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go new file mode 100644 index 00000000..a4134c67 --- /dev/null +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -0,0 +1,105 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/ackhandler (interfaces: ReceivedPacketHandler) + +// Package mockackhandler is a generated GoMock package. +package mockackhandler + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" + wire "github.com/imroc/req/v3/internal/wire" +) + +// MockReceivedPacketHandler is a mock of ReceivedPacketHandler interface. +type MockReceivedPacketHandler struct { + ctrl *gomock.Controller + recorder *MockReceivedPacketHandlerMockRecorder +} + +// MockReceivedPacketHandlerMockRecorder is the mock recorder for MockReceivedPacketHandler. +type MockReceivedPacketHandlerMockRecorder struct { + mock *MockReceivedPacketHandler +} + +// NewMockReceivedPacketHandler creates a new mock instance. +func NewMockReceivedPacketHandler(ctrl *gomock.Controller) *MockReceivedPacketHandler { + mock := &MockReceivedPacketHandler{ctrl: ctrl} + mock.recorder = &MockReceivedPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecorder { + return m.recorder +} + +// DropPackets mocks base method. +func (m *MockReceivedPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropPackets", arg0) +} + +// DropPackets indicates an expected call of DropPackets. +func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0) +} + +// GetAckFrame mocks base method. +func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel, arg1 bool) *wire.AckFrame { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAckFrame", arg0, arg1) + ret0, _ := ret[0].(*wire.AckFrame) + return ret0 +} + +// GetAckFrame indicates an expected call of GetAckFrame. +func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0, arg1) +} + +// GetAlarmTimeout mocks base method. +func (m *MockReceivedPacketHandler) GetAlarmTimeout() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAlarmTimeout") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// GetAlarmTimeout indicates an expected call of GetAlarmTimeout. +func (mr *MockReceivedPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAlarmTimeout)) +} + +// IsPotentiallyDuplicate mocks base method. +func (m *MockReceivedPacketHandler) IsPotentiallyDuplicate(arg0 protocol.PacketNumber, arg1 protocol.EncryptionLevel) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsPotentiallyDuplicate", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsPotentiallyDuplicate indicates an expected call of IsPotentiallyDuplicate. +func (mr *MockReceivedPacketHandlerMockRecorder) IsPotentiallyDuplicate(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPotentiallyDuplicate", reflect.TypeOf((*MockReceivedPacketHandler)(nil).IsPotentiallyDuplicate), arg0, arg1) +} + +// ReceivedPacket mocks base method. +func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.ECN, arg2 protocol.EncryptionLevel, arg3 time.Time, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReceivedPacket indicates an expected call of ReceivedPacket. +func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3, arg4) +} diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go new file mode 100644 index 00000000..9c41986a --- /dev/null +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -0,0 +1,240 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/ackhandler (interfaces: SentPacketHandler) + +// Package mockackhandler is a generated GoMock package. +package mockackhandler + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/imroc/req/v3/internal/ackhandler" + protocol "github.com/imroc/req/v3/internal/protocol" + wire "github.com/imroc/req/v3/internal/wire" +) + +// MockSentPacketHandler is a mock of SentPacketHandler interface. +type MockSentPacketHandler struct { + ctrl *gomock.Controller + recorder *MockSentPacketHandlerMockRecorder +} + +// MockSentPacketHandlerMockRecorder is the mock recorder for MockSentPacketHandler. +type MockSentPacketHandlerMockRecorder struct { + mock *MockSentPacketHandler +} + +// NewMockSentPacketHandler creates a new mock instance. +func NewMockSentPacketHandler(ctrl *gomock.Controller) *MockSentPacketHandler { + mock := &MockSentPacketHandler{ctrl: ctrl} + mock.recorder = &MockSentPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder { + return m.recorder +} + +// DropPackets mocks base method. +func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropPackets", arg0) +} + +// DropPackets indicates an expected call of DropPackets. +func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) +} + +// GetLossDetectionTimeout mocks base method. +func (m *MockSentPacketHandler) GetLossDetectionTimeout() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLossDetectionTimeout") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// GetLossDetectionTimeout indicates an expected call of GetLossDetectionTimeout. +func (mr *MockSentPacketHandlerMockRecorder) GetLossDetectionTimeout() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLossDetectionTimeout)) +} + +// HasPacingBudget mocks base method. +func (m *MockSentPacketHandler) HasPacingBudget() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasPacingBudget") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasPacingBudget indicates an expected call of HasPacingBudget. +func (mr *MockSentPacketHandlerMockRecorder) HasPacingBudget() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSentPacketHandler)(nil).HasPacingBudget)) +} + +// OnLossDetectionTimeout mocks base method. +func (m *MockSentPacketHandler) OnLossDetectionTimeout() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnLossDetectionTimeout") + ret0, _ := ret[0].(error) + return ret0 +} + +// OnLossDetectionTimeout indicates an expected call of OnLossDetectionTimeout. +func (mr *MockSentPacketHandlerMockRecorder) OnLossDetectionTimeout() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).OnLossDetectionTimeout)) +} + +// PeekPacketNumber mocks base method. +func (m *MockSentPacketHandler) PeekPacketNumber(arg0 protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PeekPacketNumber", arg0) + ret0, _ := ret[0].(protocol.PacketNumber) + ret1, _ := ret[1].(protocol.PacketNumberLen) + return ret0, ret1 +} + +// PeekPacketNumber indicates an expected call of PeekPacketNumber. +func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber), arg0) +} + +// PopPacketNumber mocks base method. +func (m *MockSentPacketHandler) PopPacketNumber(arg0 protocol.EncryptionLevel) protocol.PacketNumber { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PopPacketNumber", arg0) + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// PopPacketNumber indicates an expected call of PopPacketNumber. +func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber), arg0) +} + +// QueueProbePacket mocks base method. +func (m *MockSentPacketHandler) QueueProbePacket(arg0 protocol.EncryptionLevel) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueueProbePacket", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// QueueProbePacket indicates an expected call of QueueProbePacket. +func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket), arg0) +} + +// ReceivedAck mocks base method. +func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.EncryptionLevel, arg2 time.Time) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceivedAck", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReceivedAck indicates an expected call of ReceivedAck. +func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2) +} + +// ReceivedBytes mocks base method. +func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedBytes", arg0) +} + +// ReceivedBytes indicates an expected call of ReceivedBytes. +func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0) +} + +// ResetForRetry mocks base method. +func (m *MockSentPacketHandler) ResetForRetry() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResetForRetry") + ret0, _ := ret[0].(error) + return ret0 +} + +// ResetForRetry indicates an expected call of ResetForRetry. +func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry)) +} + +// SendMode mocks base method. +func (m *MockSentPacketHandler) SendMode() ackhandler.SendMode { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMode") + ret0, _ := ret[0].(ackhandler.SendMode) + return ret0 +} + +// SendMode indicates an expected call of SendMode. +func (mr *MockSentPacketHandlerMockRecorder) SendMode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMode", reflect.TypeOf((*MockSentPacketHandler)(nil).SendMode)) +} + +// SentPacket mocks base method. +func (m *MockSentPacketHandler) SentPacket(arg0 *ackhandler.Packet) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0) +} + +// SetHandshakeConfirmed mocks base method. +func (m *MockSentPacketHandler) SetHandshakeConfirmed() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetHandshakeConfirmed") +} + +// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. +func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeConfirmed() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeConfirmed)) +} + +// SetMaxDatagramSize mocks base method. +func (m *MockSentPacketHandler) SetMaxDatagramSize(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetMaxDatagramSize", arg0) +} + +// SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. +func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSentPacketHandler)(nil).SetMaxDatagramSize), arg0) +} + +// TimeUntilSend mocks base method. +func (m *MockSentPacketHandler) TimeUntilSend() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TimeUntilSend") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// TimeUntilSend indicates an expected call of TimeUntilSend. +func (mr *MockSentPacketHandlerMockRecorder) TimeUntilSend() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSentPacketHandler)(nil).TimeUntilSend)) +} diff --git a/internal/mocks/congestion.go b/internal/mocks/congestion.go new file mode 100644 index 00000000..23114372 --- /dev/null +++ b/internal/mocks/congestion.go @@ -0,0 +1,192 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/congestion (interfaces: SendAlgorithmWithDebugInfos) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" +) + +// MockSendAlgorithmWithDebugInfos is a mock of SendAlgorithmWithDebugInfos interface. +type MockSendAlgorithmWithDebugInfos struct { + ctrl *gomock.Controller + recorder *MockSendAlgorithmWithDebugInfosMockRecorder +} + +// MockSendAlgorithmWithDebugInfosMockRecorder is the mock recorder for MockSendAlgorithmWithDebugInfos. +type MockSendAlgorithmWithDebugInfosMockRecorder struct { + mock *MockSendAlgorithmWithDebugInfos +} + +// NewMockSendAlgorithmWithDebugInfos creates a new mock instance. +func NewMockSendAlgorithmWithDebugInfos(ctrl *gomock.Controller) *MockSendAlgorithmWithDebugInfos { + mock := &MockSendAlgorithmWithDebugInfos{ctrl: ctrl} + mock.recorder = &MockSendAlgorithmWithDebugInfosMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSendAlgorithmWithDebugInfos) EXPECT() *MockSendAlgorithmWithDebugInfosMockRecorder { + return m.recorder +} + +// CanSend mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) CanSend(arg0 protocol.ByteCount) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CanSend", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// CanSend indicates an expected call of CanSend. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) CanSend(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).CanSend), arg0) +} + +// GetCongestionWindow mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) GetCongestionWindow() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCongestionWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetCongestionWindow indicates an expected call of GetCongestionWindow. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) GetCongestionWindow() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCongestionWindow", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).GetCongestionWindow)) +} + +// HasPacingBudget mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) HasPacingBudget() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasPacingBudget") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasPacingBudget indicates an expected call of HasPacingBudget. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) HasPacingBudget() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).HasPacingBudget)) +} + +// InRecovery mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) InRecovery() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InRecovery") + ret0, _ := ret[0].(bool) + return ret0 +} + +// InRecovery indicates an expected call of InRecovery. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InRecovery() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InRecovery", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InRecovery)) +} + +// InSlowStart mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) InSlowStart() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InSlowStart") + ret0, _ := ret[0].(bool) + return ret0 +} + +// InSlowStart indicates an expected call of InSlowStart. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InSlowStart() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InSlowStart)) +} + +// MaybeExitSlowStart mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) MaybeExitSlowStart() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "MaybeExitSlowStart") +} + +// MaybeExitSlowStart indicates an expected call of MaybeExitSlowStart. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) MaybeExitSlowStart() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeExitSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).MaybeExitSlowStart)) +} + +// OnPacketAcked mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount, arg3 time.Time) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2, arg3) +} + +// OnPacketAcked indicates an expected call of OnPacketAcked. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), arg0, arg1, arg2, arg3) +} + +// OnPacketLost mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) OnPacketLost(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnPacketLost", arg0, arg1, arg2) +} + +// OnPacketLost indicates an expected call of OnPacketLost. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketLost(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketLost", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketLost), arg0, arg1, arg2) +} + +// OnPacketSent mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnPacketSent", arg0, arg1, arg2, arg3, arg4) +} + +// OnPacketSent indicates an expected call of OnPacketSent. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketSent(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketSent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketSent), arg0, arg1, arg2, arg3, arg4) +} + +// OnRetransmissionTimeout mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) OnRetransmissionTimeout(arg0 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnRetransmissionTimeout", arg0) +} + +// OnRetransmissionTimeout indicates an expected call of OnRetransmissionTimeout. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnRetransmissionTimeout), arg0) +} + +// SetMaxDatagramSize mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) SetMaxDatagramSize(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetMaxDatagramSize", arg0) +} + +// SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).SetMaxDatagramSize), arg0) +} + +// TimeUntilSend mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) TimeUntilSend(arg0 protocol.ByteCount) time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TimeUntilSend", arg0) + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// TimeUntilSend indicates an expected call of TimeUntilSend. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) TimeUntilSend(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).TimeUntilSend), arg0) +} diff --git a/internal/mocks/connection_flow_controller.go b/internal/mocks/connection_flow_controller.go new file mode 100644 index 00000000..d16b3aae --- /dev/null +++ b/internal/mocks/connection_flow_controller.go @@ -0,0 +1,128 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/flowcontrol (interfaces: ConnectionFlowController) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" +) + +// MockConnectionFlowController is a mock of ConnectionFlowController interface. +type MockConnectionFlowController struct { + ctrl *gomock.Controller + recorder *MockConnectionFlowControllerMockRecorder +} + +// MockConnectionFlowControllerMockRecorder is the mock recorder for MockConnectionFlowController. +type MockConnectionFlowControllerMockRecorder struct { + mock *MockConnectionFlowController +} + +// NewMockConnectionFlowController creates a new mock instance. +func NewMockConnectionFlowController(ctrl *gomock.Controller) *MockConnectionFlowController { + mock := &MockConnectionFlowController{ctrl: ctrl} + mock.recorder = &MockConnectionFlowControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionFlowController) EXPECT() *MockConnectionFlowControllerMockRecorder { + return m.recorder +} + +// AddBytesRead mocks base method. +func (m *MockConnectionFlowController) AddBytesRead(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddBytesRead", arg0) +} + +// AddBytesRead indicates an expected call of AddBytesRead. +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesRead), arg0) +} + +// AddBytesSent mocks base method. +func (m *MockConnectionFlowController) AddBytesSent(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddBytesSent", arg0) +} + +// AddBytesSent indicates an expected call of AddBytesSent. +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesSent), arg0) +} + +// GetWindowUpdate mocks base method. +func (m *MockConnectionFlowController) GetWindowUpdate() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetWindowUpdate indicates an expected call of GetWindowUpdate. +func (mr *MockConnectionFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).GetWindowUpdate)) +} + +// IsNewlyBlocked mocks base method. +func (m *MockConnectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsNewlyBlocked") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 +} + +// IsNewlyBlocked indicates an expected call of IsNewlyBlocked. +func (mr *MockConnectionFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsNewlyBlocked)) +} + +// Reset mocks base method. +func (m *MockConnectionFlowController) Reset() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Reset") + ret0, _ := ret[0].(error) + return ret0 +} + +// Reset indicates an expected call of Reset. +func (mr *MockConnectionFlowControllerMockRecorder) Reset() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reset", reflect.TypeOf((*MockConnectionFlowController)(nil).Reset)) +} + +// SendWindowSize mocks base method. +func (m *MockConnectionFlowController) SendWindowSize() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendWindowSize") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// SendWindowSize indicates an expected call of SendWindowSize. +func (mr *MockConnectionFlowControllerMockRecorder) SendWindowSize() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockConnectionFlowController)(nil).SendWindowSize)) +} + +// UpdateSendWindow mocks base method. +func (m *MockConnectionFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateSendWindow", arg0) +} + +// UpdateSendWindow indicates an expected call of UpdateSendWindow. +func (mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockConnectionFlowController)(nil).UpdateSendWindow), arg0) +} diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go new file mode 100644 index 00000000..b28499e2 --- /dev/null +++ b/internal/mocks/crypto_setup.go @@ -0,0 +1,264 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/handshake (interfaces: CryptoSetup) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + handshake "github.com/imroc/req/v3/internal/handshake" + protocol "github.com/imroc/req/v3/internal/protocol" + qtls "github.com/imroc/req/v3/internal/qtls" +) + +// MockCryptoSetup is a mock of CryptoSetup interface. +type MockCryptoSetup struct { + ctrl *gomock.Controller + recorder *MockCryptoSetupMockRecorder +} + +// MockCryptoSetupMockRecorder is the mock recorder for MockCryptoSetup. +type MockCryptoSetupMockRecorder struct { + mock *MockCryptoSetup +} + +// NewMockCryptoSetup creates a new mock instance. +func NewMockCryptoSetup(ctrl *gomock.Controller) *MockCryptoSetup { + mock := &MockCryptoSetup{ctrl: ctrl} + mock.recorder = &MockCryptoSetupMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder { + return m.recorder +} + +// ChangeConnectionID mocks base method. +func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ChangeConnectionID", arg0) +} + +// ChangeConnectionID indicates an expected call of ChangeConnectionID. +func (mr *MockCryptoSetupMockRecorder) ChangeConnectionID(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeConnectionID", reflect.TypeOf((*MockCryptoSetup)(nil).ChangeConnectionID), arg0) +} + +// Close mocks base method. +func (m *MockCryptoSetup) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close)) +} + +// ConnectionState mocks base method. +func (m *MockCryptoSetup) ConnectionState() qtls.ConnectionState { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConnectionState") + ret0, _ := ret[0].(qtls.ConnectionState) + return ret0 +} + +// ConnectionState indicates an expected call of ConnectionState. +func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) +} + +// Get0RTTOpener mocks base method. +func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get0RTTOpener") + ret0, _ := ret[0].(handshake.LongHeaderOpener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get0RTTOpener indicates an expected call of Get0RTTOpener. +func (mr *MockCryptoSetupMockRecorder) Get0RTTOpener() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTOpener)) +} + +// Get0RTTSealer mocks base method. +func (m *MockCryptoSetup) Get0RTTSealer() (handshake.LongHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get0RTTSealer") + ret0, _ := ret[0].(handshake.LongHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get0RTTSealer indicates an expected call of Get0RTTSealer. +func (mr *MockCryptoSetupMockRecorder) Get0RTTSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTSealer)) +} + +// Get1RTTOpener mocks base method. +func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get1RTTOpener") + ret0, _ := ret[0].(handshake.ShortHeaderOpener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get1RTTOpener indicates an expected call of Get1RTTOpener. +func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTOpener)) +} + +// Get1RTTSealer mocks base method. +func (m *MockCryptoSetup) Get1RTTSealer() (handshake.ShortHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get1RTTSealer") + ret0, _ := ret[0].(handshake.ShortHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get1RTTSealer indicates an expected call of Get1RTTSealer. +func (mr *MockCryptoSetupMockRecorder) Get1RTTSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTSealer)) +} + +// GetHandshakeOpener mocks base method. +func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.LongHeaderOpener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHandshakeOpener") + ret0, _ := ret[0].(handshake.LongHeaderOpener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetHandshakeOpener indicates an expected call of GetHandshakeOpener. +func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeOpener)) +} + +// GetHandshakeSealer mocks base method. +func (m *MockCryptoSetup) GetHandshakeSealer() (handshake.LongHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHandshakeSealer") + ret0, _ := ret[0].(handshake.LongHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetHandshakeSealer indicates an expected call of GetHandshakeSealer. +func (mr *MockCryptoSetupMockRecorder) GetHandshakeSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeSealer)) +} + +// GetInitialOpener mocks base method. +func (m *MockCryptoSetup) GetInitialOpener() (handshake.LongHeaderOpener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInitialOpener") + ret0, _ := ret[0].(handshake.LongHeaderOpener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInitialOpener indicates an expected call of GetInitialOpener. +func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialOpener)) +} + +// GetInitialSealer mocks base method. +func (m *MockCryptoSetup) GetInitialSealer() (handshake.LongHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInitialSealer") + ret0, _ := ret[0].(handshake.LongHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInitialSealer indicates an expected call of GetInitialSealer. +func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer)) +} + +// GetSessionTicket mocks base method. +func (m *MockCryptoSetup) GetSessionTicket() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSessionTicket") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSessionTicket indicates an expected call of GetSessionTicket. +func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket)) +} + +// HandleMessage mocks base method. +func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// HandleMessage indicates an expected call of HandleMessage. +func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) +} + +// RunHandshake mocks base method. +func (m *MockCryptoSetup) RunHandshake() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RunHandshake") +} + +// RunHandshake indicates an expected call of RunHandshake. +func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake)) +} + +// SetHandshakeConfirmed mocks base method. +func (m *MockCryptoSetup) SetHandshakeConfirmed() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetHandshakeConfirmed") +} + +// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. +func (mr *MockCryptoSetupMockRecorder) SetHandshakeConfirmed() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockCryptoSetup)(nil).SetHandshakeConfirmed)) +} + +// SetLargest1RTTAcked mocks base method. +func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetLargest1RTTAcked", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetLargest1RTTAcked indicates an expected call of SetLargest1RTTAcked. +func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0) +} diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go new file mode 100644 index 00000000..fc1ae8a1 --- /dev/null +++ b/internal/mocks/logging/connection_tracer.go @@ -0,0 +1,352 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/logging (interfaces: ConnectionTracer) + +// Package mocklogging is a generated GoMock package. +package mocklogging + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" + utils "github.com/imroc/req/v3/internal/utils" + wire "github.com/imroc/req/v3/internal/wire" + logging "github.com/imroc/req/v3/internal/logging" +) + +// MockConnectionTracer is a mock of ConnectionTracer interface. +type MockConnectionTracer struct { + ctrl *gomock.Controller + recorder *MockConnectionTracerMockRecorder +} + +// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. +type MockConnectionTracerMockRecorder struct { + mock *MockConnectionTracer +} + +// NewMockConnectionTracer creates a new mock instance. +func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { + mock := &MockConnectionTracer{ctrl: ctrl} + mock.recorder = &MockConnectionTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { + return m.recorder +} + +// AcknowledgedPacket mocks base method. +func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) +} + +// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. +func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) +} + +// BufferedPacket mocks base method. +func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "BufferedPacket", arg0) +} + +// BufferedPacket indicates an expected call of BufferedPacket. +func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0) +} + +// Close mocks base method. +func (m *MockConnectionTracer) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) +} + +// ClosedConnection mocks base method. +func (m *MockConnectionTracer) ClosedConnection(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClosedConnection", arg0) +} + +// ClosedConnection indicates an expected call of ClosedConnection. +func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) +} + +// Debug mocks base method. +func (m *MockConnectionTracer) Debug(arg0, arg1 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0, arg1) +} + +// Debug indicates an expected call of Debug. +func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) +} + +// DroppedEncryptionLevel mocks base method. +func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) +} + +// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. +func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) +} + +// DroppedKey mocks base method. +func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedKey", arg0) +} + +// DroppedKey indicates an expected call of DroppedKey. +func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) +} + +// DroppedPacket mocks base method. +func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount, arg2 logging.PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) +} + +// LossTimerCanceled mocks base method. +func (m *MockConnectionTracer) LossTimerCanceled() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerCanceled") +} + +// LossTimerCanceled indicates an expected call of LossTimerCanceled. +func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) +} + +// LossTimerExpired mocks base method. +func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) +} + +// LossTimerExpired indicates an expected call of LossTimerExpired. +func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) +} + +// LostPacket mocks base method. +func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 logging.PacketLossReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) +} + +// LostPacket indicates an expected call of LostPacket. +func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) +} + +// NegotiatedVersion mocks base method. +func (m *MockConnectionTracer) NegotiatedVersion(arg0 quic.VersionNumber, arg1, arg2 []quic.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) +} + +// NegotiatedVersion indicates an expected call of NegotiatedVersion. +func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) +} + +// ReceivedPacket mocks base method. +func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2) +} + +// ReceivedPacket indicates an expected call of ReceivedPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedPacket), arg0, arg1, arg2) +} + +// ReceivedRetry mocks base method. +func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedRetry", arg0) +} + +// ReceivedRetry indicates an expected call of ReceivedRetry. +func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) +} + +// ReceivedTransportParameters mocks base method. +func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedTransportParameters", arg0) +} + +// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. +func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) +} + +// ReceivedVersionNegotiationPacket mocks base method. +func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header, arg1 []quic.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1) +} + +// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1) +} + +// RestoredTransportParameters mocks base method. +func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RestoredTransportParameters", arg0) +} + +// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. +func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) +} + +// SentPacket mocks base method. +func (m *MockConnectionTracer) SentPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockConnectionTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) +} + +// SentTransportParameters mocks base method. +func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentTransportParameters", arg0) +} + +// SentTransportParameters indicates an expected call of SentTransportParameters. +func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) +} + +// SetLossTimer mocks base method. +func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) +} + +// SetLossTimer indicates an expected call of SetLossTimer. +func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) +} + +// StartedConnection mocks base method. +func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) +} + +// StartedConnection indicates an expected call of StartedConnection. +func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) +} + +// UpdatedCongestionState mocks base method. +func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionState) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedCongestionState", arg0) +} + +// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. +func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) +} + +// UpdatedKey mocks base method. +func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKey", arg0, arg1) +} + +// UpdatedKey indicates an expected call of UpdatedKey. +func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) +} + +// UpdatedKeyFromTLS mocks base method. +func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) +} + +// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. +func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) +} + +// UpdatedMetrics mocks base method. +func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) +} + +// UpdatedMetrics indicates an expected call of UpdatedMetrics. +func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) +} + +// UpdatedPTOCount mocks base method. +func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedPTOCount", arg0) +} + +// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. +func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) +} diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go new file mode 100644 index 00000000..bf86f4be --- /dev/null +++ b/internal/mocks/logging/tracer.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/logging (interfaces: Tracer) + +// Package mocklogging is a generated GoMock package. +package mocklogging + +import ( + context "context" + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" + wire "github.com/imroc/req/v3/internal/wire" + logging "github.com/imroc/req/v3/internal/logging" +) + +// MockTracer is a mock of Tracer interface. +type MockTracer struct { + ctrl *gomock.Controller + recorder *MockTracerMockRecorder +} + +// MockTracerMockRecorder is the mock recorder for MockTracer. +type MockTracerMockRecorder struct { + mock *MockTracer +} + +// NewMockTracer creates a new mock instance. +func NewMockTracer(ctrl *gomock.Controller) *MockTracer { + mock := &MockTracer{ctrl: ctrl} + mock.recorder = &MockTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTracer) EXPECT() *MockTracerMockRecorder { + return m.recorder +} + +// DroppedPacket mocks base method. +func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) +} + +// SentPacket mocks base method. +func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) +} + +// TracerForConnection mocks base method. +func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) + ret0, _ := ret[0].(logging.ConnectionTracer) + return ret0 +} + +// TracerForConnection indicates an expected call of TracerForConnection. +func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) +} diff --git a/internal/mocks/long_header_opener.go b/internal/mocks/long_header_opener.go new file mode 100644 index 00000000..022bb8b3 --- /dev/null +++ b/internal/mocks/long_header_opener.go @@ -0,0 +1,76 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/handshake (interfaces: LongHeaderOpener) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" +) + +// MockLongHeaderOpener is a mock of LongHeaderOpener interface. +type MockLongHeaderOpener struct { + ctrl *gomock.Controller + recorder *MockLongHeaderOpenerMockRecorder +} + +// MockLongHeaderOpenerMockRecorder is the mock recorder for MockLongHeaderOpener. +type MockLongHeaderOpenerMockRecorder struct { + mock *MockLongHeaderOpener +} + +// NewMockLongHeaderOpener creates a new mock instance. +func NewMockLongHeaderOpener(ctrl *gomock.Controller) *MockLongHeaderOpener { + mock := &MockLongHeaderOpener{ctrl: ctrl} + mock.recorder = &MockLongHeaderOpenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLongHeaderOpener) EXPECT() *MockLongHeaderOpenerMockRecorder { + return m.recorder +} + +// DecodePacketNumber mocks base method. +func (m *MockLongHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1) + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// DecodePacketNumber indicates an expected call of DecodePacketNumber. +func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) +} + +// DecryptHeader mocks base method. +func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) +} + +// DecryptHeader indicates an expected call of DecryptHeader. +func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) +} + +// Open mocks base method. +func (m *MockLongHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Open indicates an expected call of Open. +func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3) +} diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go new file mode 100644 index 00000000..a96d7109 --- /dev/null +++ b/internal/mocks/mockgen.go @@ -0,0 +1,20 @@ +package mocks + +//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream" +//go:generate sh -c "mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/lucas-clemente/quic-go EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && goimports -w quic/early_conn.go" +//go:generate sh -c "mockgen -package mockquic -destination quic/early_listener.go github.com/lucas-clemente/quic-go EarlyListener" +//go:generate sh -c "mockgen -package mocklogging -destination logging/tracer.go github.com/imroc/req/v3/internal/logging Tracer" +//go:generate sh -c "mockgen -package mocklogging -destination logging/connection_tracer.go github.com/imroc/req/v3/internal/logging ConnectionTracer" +//go:generate sh -c "mockgen -package mocks -destination short_header_sealer.go github.com/imroc/req/v3/internal/handshake ShortHeaderSealer" +//go:generate sh -c "mockgen -package mocks -destination short_header_opener.go github.com/imroc/req/v3/internal/handshake ShortHeaderOpener" +//go:generate sh -c "mockgen -package mocks -destination long_header_opener.go github.com/imroc/req/v3/internal/handshake LongHeaderOpener" +//go:generate sh -c "mockgen -package mocks -destination crypto_setup_tmp.go github.com/imroc/req/v3/internal/handshake CryptoSetup && sed -E 's~github.com/marten-seemann/qtls[[:alnum:]_-]*~github.com/imroc/req/v3/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && goimports -w crypto_setup.go" +//go:generate sh -c "mockgen -package mocks -destination stream_flow_controller.go github.com/imroc/req/v3/internal/flowcontrol StreamFlowController" +//go:generate sh -c "mockgen -package mocks -destination congestion.go github.com/imroc/req/v3/internal/congestion SendAlgorithmWithDebugInfos" +//go:generate sh -c "mockgen -package mocks -destination connection_flow_controller.go github.com/imroc/req/v3/internal/flowcontrol ConnectionFlowController" +//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/imroc/req/v3/internal/ackhandler SentPacketHandler" +//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/imroc/req/v3/internal/ackhandler ReceivedPacketHandler" + +// The following command produces a warning message on OSX, however, it still generates the correct mock file. +// See https://github.com/golang/mock/issues/339 for details. +//go:generate sh -c "mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" diff --git a/internal/mocks/quic/early_conn.go b/internal/mocks/quic/early_conn.go new file mode 100644 index 00000000..4a9ee5ef --- /dev/null +++ b/internal/mocks/quic/early_conn.go @@ -0,0 +1,255 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: EarlyConnection) + +// Package mockquic is a generated GoMock package. +package mockquic + +import ( + context "context" + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + quic "github.com/lucas-clemente/quic-go" + qerr "github.com/imroc/req/v3/internal/qerr" +) + +// MockEarlyConnection is a mock of EarlyConnection interface. +type MockEarlyConnection struct { + ctrl *gomock.Controller + recorder *MockEarlyConnectionMockRecorder +} + +// MockEarlyConnectionMockRecorder is the mock recorder for MockEarlyConnection. +type MockEarlyConnectionMockRecorder struct { + mock *MockEarlyConnection +} + +// NewMockEarlyConnection creates a new mock instance. +func NewMockEarlyConnection(ctrl *gomock.Controller) *MockEarlyConnection { + mock := &MockEarlyConnection{ctrl: ctrl} + mock.recorder = &MockEarlyConnectionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEarlyConnection) EXPECT() *MockEarlyConnectionMockRecorder { + return m.recorder +} + +// AcceptStream mocks base method. +func (m *MockEarlyConnection) AcceptStream(arg0 context.Context) (quic.Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptStream", arg0) + ret0, _ := ret[0].(quic.Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptStream indicates an expected call of AcceptStream. +func (mr *MockEarlyConnectionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptStream), arg0) +} + +// AcceptUniStream mocks base method. +func (m *MockEarlyConnection) AcceptUniStream(arg0 context.Context) (quic.ReceiveStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptUniStream", arg0) + ret0, _ := ret[0].(quic.ReceiveStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptUniStream indicates an expected call of AcceptUniStream. +func (mr *MockEarlyConnectionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptUniStream), arg0) +} + +// CloseWithError mocks base method. +func (m *MockEarlyConnection) CloseWithError(arg0 qerr.ApplicationErrorCode, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseWithError indicates an expected call of CloseWithError. +func (mr *MockEarlyConnectionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockEarlyConnection)(nil).CloseWithError), arg0, arg1) +} + +// ConnectionState mocks base method. +func (m *MockEarlyConnection) ConnectionState() quic.ConnectionState { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConnectionState") + ret0, _ := ret[0].(quic.ConnectionState) + return ret0 +} + +// ConnectionState indicates an expected call of ConnectionState. +func (mr *MockEarlyConnectionMockRecorder) ConnectionState() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockEarlyConnection)(nil).ConnectionState)) +} + +// Context mocks base method. +func (m *MockEarlyConnection) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockEarlyConnectionMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockEarlyConnection)(nil).Context)) +} + +// HandshakeComplete mocks base method. +func (m *MockEarlyConnection) HandshakeComplete() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandshakeComplete") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// HandshakeComplete indicates an expected call of HandshakeComplete. +func (mr *MockEarlyConnectionMockRecorder) HandshakeComplete() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockEarlyConnection)(nil).HandshakeComplete)) +} + +// LocalAddr mocks base method. +func (m *MockEarlyConnection) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockEarlyConnectionMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockEarlyConnection)(nil).LocalAddr)) +} + +// NextConnection mocks base method. +func (m *MockEarlyConnection) NextConnection() quic.Connection { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextConnection") + ret0, _ := ret[0].(quic.Connection) + return ret0 +} + +// NextConnection indicates an expected call of NextConnection. +func (mr *MockEarlyConnectionMockRecorder) NextConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection)) +} + +// OpenStream mocks base method. +func (m *MockEarlyConnection) OpenStream() (quic.Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenStream") + ret0, _ := ret[0].(quic.Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStream indicates an expected call of OpenStream. +func (mr *MockEarlyConnectionMockRecorder) OpenStream() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStream)) +} + +// OpenStreamSync mocks base method. +func (m *MockEarlyConnection) OpenStreamSync(arg0 context.Context) (quic.Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenStreamSync", arg0) + ret0, _ := ret[0].(quic.Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStreamSync indicates an expected call of OpenStreamSync. +func (mr *MockEarlyConnectionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStreamSync), arg0) +} + +// OpenUniStream mocks base method. +func (m *MockEarlyConnection) OpenUniStream() (quic.SendStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenUniStream") + ret0, _ := ret[0].(quic.SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStream indicates an expected call of OpenUniStream. +func (mr *MockEarlyConnectionMockRecorder) OpenUniStream() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStream)) +} + +// OpenUniStreamSync mocks base method. +func (m *MockEarlyConnection) OpenUniStreamSync(arg0 context.Context) (quic.SendStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) + ret0, _ := ret[0].(quic.SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStreamSync indicates an expected call of OpenUniStreamSync. +func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStreamSync), arg0) +} + +// ReceiveMessage mocks base method. +func (m *MockEarlyConnection) ReceiveMessage() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceiveMessage") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReceiveMessage indicates an expected call of ReceiveMessage. +func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage)) +} + +// RemoteAddr mocks base method. +func (m *MockEarlyConnection) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr. +func (mr *MockEarlyConnectionMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockEarlyConnection)(nil).RemoteAddr)) +} + +// SendMessage mocks base method. +func (m *MockEarlyConnection) SendMessage(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMessage", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMessage indicates an expected call of SendMessage. +func (mr *MockEarlyConnectionMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockEarlyConnection)(nil).SendMessage), arg0) +} diff --git a/internal/mocks/quic/early_listener.go b/internal/mocks/quic/early_listener.go new file mode 100644 index 00000000..279096b8 --- /dev/null +++ b/internal/mocks/quic/early_listener.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: EarlyListener) + +// Package mockquic is a generated GoMock package. +package mockquic + +import ( + context "context" + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + quic "github.com/lucas-clemente/quic-go" +) + +// MockEarlyListener is a mock of EarlyListener interface. +type MockEarlyListener struct { + ctrl *gomock.Controller + recorder *MockEarlyListenerMockRecorder +} + +// MockEarlyListenerMockRecorder is the mock recorder for MockEarlyListener. +type MockEarlyListenerMockRecorder struct { + mock *MockEarlyListener +} + +// NewMockEarlyListener creates a new mock instance. +func NewMockEarlyListener(ctrl *gomock.Controller) *MockEarlyListener { + mock := &MockEarlyListener{ctrl: ctrl} + mock.recorder = &MockEarlyListenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEarlyListener) EXPECT() *MockEarlyListenerMockRecorder { + return m.recorder +} + +// Accept mocks base method. +func (m *MockEarlyListener) Accept(arg0 context.Context) (quic.EarlyConnection, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Accept", arg0) + ret0, _ := ret[0].(quic.EarlyConnection) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Accept indicates an expected call of Accept. +func (mr *MockEarlyListenerMockRecorder) Accept(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockEarlyListener)(nil).Accept), arg0) +} + +// Addr mocks base method. +func (m *MockEarlyListener) Addr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Addr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// Addr indicates an expected call of Addr. +func (mr *MockEarlyListenerMockRecorder) Addr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockEarlyListener)(nil).Addr)) +} + +// Close mocks base method. +func (m *MockEarlyListener) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockEarlyListenerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEarlyListener)(nil).Close)) +} diff --git a/internal/mocks/quic/stream.go b/internal/mocks/quic/stream.go new file mode 100644 index 00000000..a298ba5e --- /dev/null +++ b/internal/mocks/quic/stream.go @@ -0,0 +1,176 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: Stream) + +// Package mockquic is a generated GoMock package. +package mockquic + +import ( + context "context" + "github.com/lucas-clemente/quic-go" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + qerr "github.com/imroc/req/v3/internal/qerr" +) + +// MockStream is a mock of Stream interface. +type MockStream struct { + ctrl *gomock.Controller + recorder *MockStreamMockRecorder +} + +// MockStreamMockRecorder is the mock recorder for MockStream. +type MockStreamMockRecorder struct { + mock *MockStream +} + +// NewMockStream creates a new mock instance. +func NewMockStream(ctrl *gomock.Controller) *MockStream { + mock := &MockStream{ctrl: ctrl} + mock.recorder = &MockStreamMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStream) EXPECT() *MockStreamMockRecorder { + return m.recorder +} + +// CancelRead mocks base method. +func (m *MockStream) CancelRead(arg0 qerr.StreamErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelRead", arg0) +} + +// CancelRead indicates an expected call of CancelRead. +func (mr *MockStreamMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStream)(nil).CancelRead), arg0) +} + +// CancelWrite mocks base method. +func (m *MockStream) CancelWrite(arg0 qerr.StreamErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelWrite", arg0) +} + +// CancelWrite indicates an expected call of CancelWrite. +func (mr *MockStreamMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStream)(nil).CancelWrite), arg0) +} + +// Close mocks base method. +func (m *MockStream) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockStreamMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStream)(nil).Close)) +} + +// Context mocks base method. +func (m *MockStream) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockStreamMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStream)(nil).Context)) +} + +// Read mocks base method. +func (m *MockStream) Read(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockStreamMockRecorder) Read(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStream)(nil).Read), arg0) +} + +// SetDeadline mocks base method. +func (m *MockStream) SetDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline. +func (mr *MockStreamMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStream)(nil).SetDeadline), arg0) +} + +// SetReadDeadline mocks base method. +func (m *MockStream) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockStreamMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStream)(nil).SetReadDeadline), arg0) +} + +// SetWriteDeadline mocks base method. +func (m *MockStream) SetWriteDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline. +func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStream)(nil).SetWriteDeadline), arg0) +} + +// StreamID mocks base method. +func (m *MockStream) StreamID() quic.StreamID { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StreamID") + ret0, _ := ret[0].(quic.StreamID) + return ret0 +} + +// StreamID indicates an expected call of StreamID. +func (mr *MockStreamMockRecorder) StreamID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStream)(nil).StreamID)) +} + +// Write mocks base method. +func (m *MockStream) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockStreamMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStream)(nil).Write), arg0) +} diff --git a/internal/mocks/short_header_opener.go b/internal/mocks/short_header_opener.go new file mode 100644 index 00000000..146579c1 --- /dev/null +++ b/internal/mocks/short_header_opener.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/handshake (interfaces: ShortHeaderOpener) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" +) + +// MockShortHeaderOpener is a mock of ShortHeaderOpener interface. +type MockShortHeaderOpener struct { + ctrl *gomock.Controller + recorder *MockShortHeaderOpenerMockRecorder +} + +// MockShortHeaderOpenerMockRecorder is the mock recorder for MockShortHeaderOpener. +type MockShortHeaderOpenerMockRecorder struct { + mock *MockShortHeaderOpener +} + +// NewMockShortHeaderOpener creates a new mock instance. +func NewMockShortHeaderOpener(ctrl *gomock.Controller) *MockShortHeaderOpener { + mock := &MockShortHeaderOpener{ctrl: ctrl} + mock.recorder = &MockShortHeaderOpenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockShortHeaderOpener) EXPECT() *MockShortHeaderOpenerMockRecorder { + return m.recorder +} + +// DecodePacketNumber mocks base method. +func (m *MockShortHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1) + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// DecodePacketNumber indicates an expected call of DecodePacketNumber. +func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) +} + +// DecryptHeader mocks base method. +func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) +} + +// DecryptHeader indicates an expected call of DecryptHeader. +func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) +} + +// Open mocks base method. +func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 time.Time, arg3 protocol.PacketNumber, arg4 protocol.KeyPhaseBit, arg5 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4, arg5) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Open indicates an expected call of Open. +func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4, arg5) +} diff --git a/internal/mocks/short_header_sealer.go b/internal/mocks/short_header_sealer.go new file mode 100644 index 00000000..2ea53fc4 --- /dev/null +++ b/internal/mocks/short_header_sealer.go @@ -0,0 +1,89 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/handshake (interfaces: ShortHeaderSealer) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" +) + +// MockShortHeaderSealer is a mock of ShortHeaderSealer interface. +type MockShortHeaderSealer struct { + ctrl *gomock.Controller + recorder *MockShortHeaderSealerMockRecorder +} + +// MockShortHeaderSealerMockRecorder is the mock recorder for MockShortHeaderSealer. +type MockShortHeaderSealerMockRecorder struct { + mock *MockShortHeaderSealer +} + +// NewMockShortHeaderSealer creates a new mock instance. +func NewMockShortHeaderSealer(ctrl *gomock.Controller) *MockShortHeaderSealer { + mock := &MockShortHeaderSealer{ctrl: ctrl} + mock.recorder = &MockShortHeaderSealerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockShortHeaderSealer) EXPECT() *MockShortHeaderSealerMockRecorder { + return m.recorder +} + +// EncryptHeader mocks base method. +func (m *MockShortHeaderSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2) +} + +// EncryptHeader indicates an expected call of EncryptHeader. +func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockShortHeaderSealer)(nil).EncryptHeader), arg0, arg1, arg2) +} + +// KeyPhase mocks base method. +func (m *MockShortHeaderSealer) KeyPhase() protocol.KeyPhaseBit { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "KeyPhase") + ret0, _ := ret[0].(protocol.KeyPhaseBit) + return ret0 +} + +// KeyPhase indicates an expected call of KeyPhase. +func (mr *MockShortHeaderSealerMockRecorder) KeyPhase() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyPhase", reflect.TypeOf((*MockShortHeaderSealer)(nil).KeyPhase)) +} + +// Overhead mocks base method. +func (m *MockShortHeaderSealer) Overhead() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Overhead") + ret0, _ := ret[0].(int) + return ret0 +} + +// Overhead indicates an expected call of Overhead. +func (mr *MockShortHeaderSealerMockRecorder) Overhead() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockShortHeaderSealer)(nil).Overhead)) +} + +// Seal mocks base method. +func (m *MockShortHeaderSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + return ret0 +} + +// Seal indicates an expected call of Seal. +func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockShortHeaderSealer)(nil).Seal), arg0, arg1, arg2, arg3) +} diff --git a/internal/mocks/stream_flow_controller.go b/internal/mocks/stream_flow_controller.go new file mode 100644 index 00000000..ba779a6a --- /dev/null +++ b/internal/mocks/stream_flow_controller.go @@ -0,0 +1,140 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/flowcontrol (interfaces: StreamFlowController) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/protocol" +) + +// MockStreamFlowController is a mock of StreamFlowController interface. +type MockStreamFlowController struct { + ctrl *gomock.Controller + recorder *MockStreamFlowControllerMockRecorder +} + +// MockStreamFlowControllerMockRecorder is the mock recorder for MockStreamFlowController. +type MockStreamFlowControllerMockRecorder struct { + mock *MockStreamFlowController +} + +// NewMockStreamFlowController creates a new mock instance. +func NewMockStreamFlowController(ctrl *gomock.Controller) *MockStreamFlowController { + mock := &MockStreamFlowController{ctrl: ctrl} + mock.recorder = &MockStreamFlowControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStreamFlowController) EXPECT() *MockStreamFlowControllerMockRecorder { + return m.recorder +} + +// Abandon mocks base method. +func (m *MockStreamFlowController) Abandon() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Abandon") +} + +// Abandon indicates an expected call of Abandon. +func (mr *MockStreamFlowControllerMockRecorder) Abandon() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abandon", reflect.TypeOf((*MockStreamFlowController)(nil).Abandon)) +} + +// AddBytesRead mocks base method. +func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddBytesRead", arg0) +} + +// AddBytesRead indicates an expected call of AddBytesRead. +func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesRead), arg0) +} + +// AddBytesSent mocks base method. +func (m *MockStreamFlowController) AddBytesSent(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddBytesSent", arg0) +} + +// AddBytesSent indicates an expected call of AddBytesSent. +func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesSent), arg0) +} + +// GetWindowUpdate mocks base method. +func (m *MockStreamFlowController) GetWindowUpdate() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetWindowUpdate indicates an expected call of GetWindowUpdate. +func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate)) +} + +// IsNewlyBlocked mocks base method. +func (m *MockStreamFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsNewlyBlocked") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 +} + +// IsNewlyBlocked indicates an expected call of IsNewlyBlocked. +func (mr *MockStreamFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsNewlyBlocked)) +} + +// SendWindowSize mocks base method. +func (m *MockStreamFlowController) SendWindowSize() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendWindowSize") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// SendWindowSize indicates an expected call of SendWindowSize. +func (mr *MockStreamFlowControllerMockRecorder) SendWindowSize() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockStreamFlowController)(nil).SendWindowSize)) +} + +// UpdateHighestReceived mocks base method. +func (m *MockStreamFlowController) UpdateHighestReceived(arg0 protocol.ByteCount, arg1 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateHighestReceived", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateHighestReceived indicates an expected call of UpdateHighestReceived. +func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateHighestReceived), arg0, arg1) +} + +// UpdateSendWindow mocks base method. +func (m *MockStreamFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateSendWindow", arg0) +} + +// UpdateSendWindow indicates an expected call of UpdateSendWindow. +func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateSendWindow), arg0) +} diff --git a/internal/mocks/tls/client_session_cache.go b/internal/mocks/tls/client_session_cache.go new file mode 100644 index 00000000..e3ae2c8e --- /dev/null +++ b/internal/mocks/tls/client_session_cache.go @@ -0,0 +1,62 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: crypto/tls (interfaces: ClientSessionCache) + +// Package mocktls is a generated GoMock package. +package mocktls + +import ( + tls "crypto/tls" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockClientSessionCache is a mock of ClientSessionCache interface. +type MockClientSessionCache struct { + ctrl *gomock.Controller + recorder *MockClientSessionCacheMockRecorder +} + +// MockClientSessionCacheMockRecorder is the mock recorder for MockClientSessionCache. +type MockClientSessionCacheMockRecorder struct { + mock *MockClientSessionCache +} + +// NewMockClientSessionCache creates a new mock instance. +func NewMockClientSessionCache(ctrl *gomock.Controller) *MockClientSessionCache { + mock := &MockClientSessionCache{ctrl: ctrl} + mock.recorder = &MockClientSessionCacheMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClientSessionCache) EXPECT() *MockClientSessionCacheMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockClientSessionCache) Get(arg0 string) (*tls.ClientSessionState, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(*tls.ClientSessionState) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockClientSessionCacheMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockClientSessionCache)(nil).Get), arg0) +} + +// Put mocks base method. +func (m *MockClientSessionCache) Put(arg0 string, arg1 *tls.ClientSessionState) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Put", arg0, arg1) +} + +// Put indicates an expected call of Put. +func (mr *MockClientSessionCacheMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockClientSessionCache)(nil).Put), arg0, arg1) +} diff --git a/internal/netutil/addr.go b/internal/netutil/addr.go new file mode 100644 index 00000000..28a80fcb --- /dev/null +++ b/internal/netutil/addr.go @@ -0,0 +1,40 @@ +package netutil + +import ( + "golang.org/x/net/idna" + "net" + "net/url" + "strings" +) + +func AuthorityKey(u *url.URL) string { + return u.Scheme + "://" + AuthorityAddr(u.Scheme, u.Host) +} + +// AuthorityAddr returns a given authority (a host/IP, or host:port / ip:port) +// and returns a host:port. The port 443 is added if needed. +func AuthorityAddr(scheme, authority string) (addr string) { + host, port := AuthorityHostPort(scheme, authority) + // IPv6 address literal, without a port: + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port + } + addr = net.JoinHostPort(host, port) + return +} + +func AuthorityHostPort(scheme, authority string) (host, port string) { + host, port, err := net.SplitHostPort(authority) + if err != nil { // authority didn't have a port + port = "443" + if scheme == "http" { + port = "80" + } + host = authority + } + if a, err := idna.ToASCII(host); err == nil { + host = a + } + + return +} diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go new file mode 100644 index 00000000..3aec2cd3 --- /dev/null +++ b/internal/protocol/connection_id.go @@ -0,0 +1,69 @@ +package protocol + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" +) + +// A ConnectionID in QUIC +type ConnectionID []byte + +const maxConnectionIDLen = 20 + +// GenerateConnectionID generates a connection ID using cryptographic random +func GenerateConnectionID(len int) (ConnectionID, error) { + b := make([]byte, len) + if _, err := rand.Read(b); err != nil { + return nil, err + } + return ConnectionID(b), nil +} + +// GenerateConnectionIDForInitial generates a connection ID for the Initial packet. +// It uses a length randomly chosen between 8 and 20 bytes. +func GenerateConnectionIDForInitial() (ConnectionID, error) { + r := make([]byte, 1) + if _, err := rand.Read(r); err != nil { + return nil, err + } + len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) + return GenerateConnectionID(len) +} + +// ReadConnectionID reads a connection ID of length len from the given io.Reader. +// It returns io.EOF if there are not enough bytes to read. +func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { + if len == 0 { + return nil, nil + } + c := make(ConnectionID, len) + _, err := io.ReadFull(r, c) + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return c, err +} + +// Equal says if two connection IDs are equal +func (c ConnectionID) Equal(other ConnectionID) bool { + return bytes.Equal(c, other) +} + +// Len returns the length of the connection ID in bytes +func (c ConnectionID) Len() int { + return len(c) +} + +// Bytes returns the byte representation +func (c ConnectionID) Bytes() []byte { + return []byte(c) +} + +func (c ConnectionID) String() string { + if c.Len() == 0 { + return "(empty)" + } + return fmt.Sprintf("%x", c.Bytes()) +} diff --git a/internal/protocol/connection_id_test.go b/internal/protocol/connection_id_test.go new file mode 100644 index 00000000..345e656c --- /dev/null +++ b/internal/protocol/connection_id_test.go @@ -0,0 +1,108 @@ +package protocol + +import ( + "bytes" + "io" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Connection ID generation", func() { + It("generates random connection IDs", func() { + c1, err := GenerateConnectionID(8) + Expect(err).ToNot(HaveOccurred()) + Expect(c1).ToNot(BeZero()) + c2, err := GenerateConnectionID(8) + Expect(err).ToNot(HaveOccurred()) + Expect(c1).ToNot(Equal(c2)) + }) + + It("generates connection IDs with the requested length", func() { + c, err := GenerateConnectionID(5) + Expect(err).ToNot(HaveOccurred()) + Expect(c.Len()).To(Equal(5)) + }) + + It("generates random length destination connection IDs", func() { + var has8ByteConnID, has20ByteConnID bool + for i := 0; i < 1000; i++ { + c, err := GenerateConnectionIDForInitial() + Expect(err).ToNot(HaveOccurred()) + Expect(c.Len()).To(BeNumerically(">=", 8)) + Expect(c.Len()).To(BeNumerically("<=", 20)) + if c.Len() == 8 { + has8ByteConnID = true + } + if c.Len() == 20 { + has20ByteConnID = true + } + } + Expect(has8ByteConnID).To(BeTrue()) + Expect(has20ByteConnID).To(BeTrue()) + }) + + It("says if connection IDs are equal", func() { + c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + Expect(c1.Equal(c1)).To(BeTrue()) + Expect(c2.Equal(c2)).To(BeTrue()) + Expect(c1.Equal(c2)).To(BeFalse()) + Expect(c2.Equal(c1)).To(BeFalse()) + }) + + It("reads the connection ID", func() { + buf := bytes.NewBuffer([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) + c, err := ReadConnectionID(buf, 9) + Expect(err).ToNot(HaveOccurred()) + Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})) + }) + + It("returns io.EOF if there's not enough data to read", func() { + buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) + _, err := ReadConnectionID(buf, 5) + Expect(err).To(MatchError(io.EOF)) + }) + + It("returns nil for a 0 length connection ID", func() { + buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) + c, err := ReadConnectionID(buf, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(c).To(BeNil()) + }) + + It("returns the length", func() { + c := ConnectionID{1, 2, 3, 4, 5, 6, 7} + Expect(c.Len()).To(Equal(7)) + }) + + It("has 0 length for the default value", func() { + var c ConnectionID + Expect(c.Len()).To(BeZero()) + }) + + It("returns the bytes", func() { + c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) + Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7})) + }) + + It("returns a nil byte slice for the default value", func() { + var c ConnectionID + Expect(c.Bytes()).To(BeNil()) + }) + + It("has a string representation", func() { + c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) + Expect(c.String()).To(Equal("deadbeef42")) + }) + + It("has a long string representation", func() { + c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad} + Expect(c.String()).To(Equal("13370000decafbad")) + }) + + It("has a string representation for the default value", func() { + var c ConnectionID + Expect(c.String()).To(Equal("(empty)")) + }) +}) diff --git a/internal/protocol/encryption_level.go b/internal/protocol/encryption_level.go new file mode 100644 index 00000000..32d38ab1 --- /dev/null +++ b/internal/protocol/encryption_level.go @@ -0,0 +1,30 @@ +package protocol + +// EncryptionLevel is the encryption level +// Default value is Unencrypted +type EncryptionLevel uint8 + +const ( + // EncryptionInitial is the Initial encryption level + EncryptionInitial EncryptionLevel = 1 + iota + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT + // Encryption1RTT is the 1-RTT encryption level + Encryption1RTT +) + +func (e EncryptionLevel) String() string { + switch e { + case EncryptionInitial: + return "Initial" + case EncryptionHandshake: + return "Handshake" + case Encryption0RTT: + return "0-RTT" + case Encryption1RTT: + return "1-RTT" + } + return "unknown" +} diff --git a/internal/protocol/encryption_level_test.go b/internal/protocol/encryption_level_test.go new file mode 100644 index 00000000..9b07b08b --- /dev/null +++ b/internal/protocol/encryption_level_test.go @@ -0,0 +1,20 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Encryption Level", func() { + It("doesn't use 0 as a value", func() { + // 0 is used in some tests + Expect(EncryptionInitial * EncryptionHandshake * Encryption0RTT * Encryption1RTT).ToNot(BeZero()) + }) + + It("has the correct string representation", func() { + Expect(EncryptionInitial.String()).To(Equal("Initial")) + Expect(EncryptionHandshake.String()).To(Equal("Handshake")) + Expect(Encryption0RTT.String()).To(Equal("0-RTT")) + Expect(Encryption1RTT.String()).To(Equal("1-RTT")) + }) +}) diff --git a/internal/protocol/key_phase.go b/internal/protocol/key_phase.go new file mode 100644 index 00000000..edd740cf --- /dev/null +++ b/internal/protocol/key_phase.go @@ -0,0 +1,36 @@ +package protocol + +// KeyPhase is the key phase +type KeyPhase uint64 + +// Bit determines the key phase bit +func (p KeyPhase) Bit() KeyPhaseBit { + if p%2 == 0 { + return KeyPhaseZero + } + return KeyPhaseOne +} + +// KeyPhaseBit is the key phase bit +type KeyPhaseBit uint8 + +const ( + // KeyPhaseUndefined is an undefined key phase + KeyPhaseUndefined KeyPhaseBit = iota + // KeyPhaseZero is key phase 0 + KeyPhaseZero + // KeyPhaseOne is key phase 1 + KeyPhaseOne +) + +func (p KeyPhaseBit) String() string { + //nolint:exhaustive + switch p { + case KeyPhaseZero: + return "0" + case KeyPhaseOne: + return "1" + default: + return "undefined" + } +} diff --git a/internal/protocol/key_phase_test.go b/internal/protocol/key_phase_test.go new file mode 100644 index 00000000..92f404a5 --- /dev/null +++ b/internal/protocol/key_phase_test.go @@ -0,0 +1,27 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Key Phases", func() { + It("has undefined as its default value", func() { + var k KeyPhaseBit + Expect(k).To(Equal(KeyPhaseUndefined)) + }) + + It("has the correct string representation", func() { + Expect(KeyPhaseZero.String()).To(Equal("0")) + Expect(KeyPhaseOne.String()).To(Equal("1")) + }) + + It("converts the key phase to the key phase bit", func() { + Expect(KeyPhase(0).Bit()).To(Equal(KeyPhaseZero)) + Expect(KeyPhase(2).Bit()).To(Equal(KeyPhaseZero)) + Expect(KeyPhase(4).Bit()).To(Equal(KeyPhaseZero)) + Expect(KeyPhase(1).Bit()).To(Equal(KeyPhaseOne)) + Expect(KeyPhase(3).Bit()).To(Equal(KeyPhaseOne)) + Expect(KeyPhase(5).Bit()).To(Equal(KeyPhaseOne)) + }) +}) diff --git a/internal/protocol/packet_number.go b/internal/protocol/packet_number.go new file mode 100644 index 00000000..bd340161 --- /dev/null +++ b/internal/protocol/packet_number.go @@ -0,0 +1,79 @@ +package protocol + +// A PacketNumber in QUIC +type PacketNumber int64 + +// InvalidPacketNumber is a packet number that is never sent. +// In QUIC, 0 is a valid packet number. +const InvalidPacketNumber PacketNumber = -1 + +// PacketNumberLen is the length of the packet number in bytes +type PacketNumberLen uint8 + +const ( + // PacketNumberLen1 is a packet number length of 1 byte + PacketNumberLen1 PacketNumberLen = 1 + // PacketNumberLen2 is a packet number length of 2 bytes + PacketNumberLen2 PacketNumberLen = 2 + // PacketNumberLen3 is a packet number length of 3 bytes + PacketNumberLen3 PacketNumberLen = 3 + // PacketNumberLen4 is a packet number length of 4 bytes + PacketNumberLen4 PacketNumberLen = 4 +) + +// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number +func DecodePacketNumber( + packetNumberLength PacketNumberLen, + lastPacketNumber PacketNumber, + wirePacketNumber PacketNumber, +) PacketNumber { + var epochDelta PacketNumber + switch packetNumberLength { + case PacketNumberLen1: + epochDelta = PacketNumber(1) << 8 + case PacketNumberLen2: + epochDelta = PacketNumber(1) << 16 + case PacketNumberLen3: + epochDelta = PacketNumber(1) << 24 + case PacketNumberLen4: + epochDelta = PacketNumber(1) << 32 + } + epoch := lastPacketNumber & ^(epochDelta - 1) + var prevEpochBegin PacketNumber + if epoch > epochDelta { + prevEpochBegin = epoch - epochDelta + } + nextEpochBegin := epoch + epochDelta + return closestTo( + lastPacketNumber+1, + epoch+wirePacketNumber, + closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber), + ) +} + +func closestTo(target, a, b PacketNumber) PacketNumber { + if delta(target, a) < delta(target, b) { + return a + } + return b +} + +func delta(a, b PacketNumber) PacketNumber { + if a < b { + return b - a + } + return a - b +} + +// GetPacketNumberLengthForHeader gets the length of the packet number for the public header +// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances +func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen { + diff := uint64(packetNumber - leastUnacked) + if diff < (1 << (16 - 1)) { + return PacketNumberLen2 + } + if diff < (1 << (24 - 1)) { + return PacketNumberLen3 + } + return PacketNumberLen4 +} diff --git a/internal/protocol/packet_number_test.go b/internal/protocol/packet_number_test.go new file mode 100644 index 00000000..d3bfe1d5 --- /dev/null +++ b/internal/protocol/packet_number_test.go @@ -0,0 +1,204 @@ +package protocol + +import ( + "fmt" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +// Tests taken and extended from chrome +var _ = Describe("packet number calculation", func() { + It("InvalidPacketNumber is smaller than all valid packet numbers", func() { + Expect(InvalidPacketNumber).To(BeNumerically("<", 0)) + }) + + It("works with the example from the draft", func() { + Expect(DecodePacketNumber(PacketNumberLen2, 0xa82f30ea, 0x9b32)).To(Equal(PacketNumber(0xa82f9b32))) + }) + + It("works with the examples from the draft", func() { + Expect(GetPacketNumberLengthForHeader(0xac5c02, 0xabe8b3)).To(Equal(PacketNumberLen2)) + Expect(GetPacketNumberLengthForHeader(0xace8fe, 0xabe8b3)).To(Equal(PacketNumberLen3)) + }) + + getEpoch := func(len PacketNumberLen) uint64 { + if len > 4 { + Fail("invalid packet number len") + } + return uint64(1) << (len * 8) + } + + check := func(length PacketNumberLen, expected, last uint64) { + epoch := getEpoch(length) + epochMask := epoch - 1 + wirePacketNumber := expected & epochMask + ExpectWithOffset(1, DecodePacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected))) + } + + for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen3, PacketNumberLen4} { + length := l + + Context(fmt.Sprintf("with %d bytes", length), func() { + epoch := getEpoch(length) + epochMask := epoch - 1 + + It("works near epoch start", func() { + // A few quick manual sanity check + check(length, 1, 0) + check(length, epoch+1, epochMask) + check(length, epoch, epochMask) + + // Cases where the last number was close to the start of the range. + for last := uint64(0); last < 10; last++ { + // Small numbers should not wrap (even if they're out of order). + for j := uint64(0); j < 10; j++ { + check(length, j, last) + } + + // Large numbers should not wrap either (because we're near 0 already). + for j := uint64(0); j < 10; j++ { + check(length, epoch-1-j, last) + } + } + }) + + It("works near epoch end", func() { + // Cases where the last number was close to the end of the range + for i := uint64(0); i < 10; i++ { + last := epoch - i + + // Small numbers should wrap. + for j := uint64(0); j < 10; j++ { + check(length, epoch+j, last) + } + + // Large numbers should not (even if they're out of order). + for j := uint64(0); j < 10; j++ { + check(length, epoch-1-j, last) + } + } + }) + + // Next check where we're in a non-zero epoch to verify we handle + // reverse wrapping, too. + It("works near previous epoch", func() { + prevEpoch := 1 * epoch + curEpoch := 2 * epoch + // Cases where the last number was close to the start of the range + for i := uint64(0); i < 10; i++ { + last := curEpoch + i + // Small number should not wrap (even if they're out of order). + for j := uint64(0); j < 10; j++ { + check(length, curEpoch+j, last) + } + + // But large numbers should reverse wrap. + for j := uint64(0); j < 10; j++ { + num := epoch - 1 - j + check(length, prevEpoch+num, last) + } + } + }) + + It("works near next epoch", func() { + curEpoch := 2 * epoch + nextEpoch := 3 * epoch + // Cases where the last number was close to the end of the range + for i := uint64(0); i < 10; i++ { + last := nextEpoch - 1 - i + + // Small numbers should wrap. + for j := uint64(0); j < 10; j++ { + check(length, nextEpoch+j, last) + } + + // but large numbers should not (even if they're out of order). + for j := uint64(0); j < 10; j++ { + num := epoch - 1 - j + check(length, curEpoch+num, last) + } + } + }) + + Context("shortening a packet number for the header", func() { + Context("shortening", func() { + It("sends out low packet numbers as 2 byte", func() { + length := GetPacketNumberLengthForHeader(4, 2) + Expect(length).To(Equal(PacketNumberLen2)) + }) + + It("sends out high packet numbers as 2 byte, if all ACKs are received", func() { + length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1) + Expect(length).To(Equal(PacketNumberLen2)) + }) + + It("sends out higher packet numbers as 3 bytes, if a lot of ACKs are missing", func() { + length := GetPacketNumberLengthForHeader(40000, 2) + Expect(length).To(Equal(PacketNumberLen3)) + }) + + It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() { + length := GetPacketNumberLengthForHeader(40000000, 2) + Expect(length).To(Equal(PacketNumberLen4)) + }) + }) + + Context("self-consistency", func() { + It("works for small packet numbers", func() { + for i := uint64(1); i < 10000; i++ { + packetNumber := PacketNumber(i) + leastUnacked := PacketNumber(1) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) + wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) + + decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) + Expect(decodedPacketNumber).To(Equal(packetNumber)) + } + }) + + It("works for small packet numbers and increasing ACKed packets", func() { + for i := uint64(1); i < 10000; i++ { + packetNumber := PacketNumber(i) + leastUnacked := PacketNumber(i / 2) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) + epochMask := getEpoch(length) - 1 + wirePacketNumber := uint64(packetNumber) & epochMask + + decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) + Expect(decodedPacketNumber).To(Equal(packetNumber)) + } + }) + + It("also works for larger packet numbers", func() { + var increment uint64 + for i := uint64(1); i < getEpoch(PacketNumberLen4); i += increment { + packetNumber := PacketNumber(i) + leastUnacked := PacketNumber(1) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) + epochMask := getEpoch(length) - 1 + wirePacketNumber := uint64(packetNumber) & epochMask + + decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) + Expect(decodedPacketNumber).To(Equal(packetNumber)) + + increment = getEpoch(length) / 8 + } + }) + + It("works for packet numbers larger than 2^48", func() { + for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) { + packetNumber := PacketNumber(i) + leastUnacked := PacketNumber(i - 1000) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) + wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) + + decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) + Expect(decodedPacketNumber).To(Equal(packetNumber)) + } + }) + }) + }) + }) + } +}) diff --git a/internal/protocol/params.go b/internal/protocol/params.go new file mode 100644 index 00000000..83137113 --- /dev/null +++ b/internal/protocol/params.go @@ -0,0 +1,193 @@ +package protocol + +import "time" + +// DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use. +const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB + +// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets. +const InitialPacketSizeIPv4 = 1252 + +// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets. +const InitialPacketSizeIPv6 = 1232 + +// MaxCongestionWindowPackets is the maximum congestion window in packet. +const MaxCongestionWindowPackets = 10000 + +// MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the connection. +const MaxUndecryptablePackets = 32 + +// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window +// This is the value that Chromium is using +const ConnectionFlowControlMultiplier = 1.5 + +// DefaultInitialMaxStreamData is the default initial stream-level flow control window for receiving data +const DefaultInitialMaxStreamData = (1 << 10) * 512 // 512 kb + +// DefaultInitialMaxData is the connection-level flow control window for receiving data +const DefaultInitialMaxData = ConnectionFlowControlMultiplier * DefaultInitialMaxStreamData + +// DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data +const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB + +// DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data +const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 15 MB + +// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client +const WindowUpdateThreshold = 0.25 + +// DefaultMaxIncomingStreams is the maximum number of streams that a peer may open +const DefaultMaxIncomingStreams = 100 + +// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open +const DefaultMaxIncomingUniStreams = 100 + +// MaxServerUnprocessedPackets is the max number of packets stored in the server that are not yet processed. +const MaxServerUnprocessedPackets = 1024 + +// MaxConnUnprocessedPackets is the max number of packets stored in each connection that are not yet processed. +const MaxConnUnprocessedPackets = 256 + +// SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack. +// Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod. +const SkipPacketInitialPeriod PacketNumber = 256 + +// SkipPacketMaxPeriod is the maximum period length used for packet number skipping. +const SkipPacketMaxPeriod PacketNumber = 128 * 1024 + +// MaxAcceptQueueSize is the maximum number of connections that the server queues for accepting. +// If the queue is full, new connection attempts will be rejected. +const MaxAcceptQueueSize = 32 + +// TokenValidity is the duration that a (non-retry) token is considered valid +const TokenValidity = 24 * time.Hour + +// RetryTokenValidity is the duration that a retry token is considered valid +const RetryTokenValidity = 10 * time.Second + +// MaxOutstandingSentPackets is maximum number of packets saved for retransmission. +// When reached, it imposes a soft limit on sending new packets: +// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. +const MaxOutstandingSentPackets = 2 * MaxCongestionWindowPackets + +// MaxTrackedSentPackets is maximum number of sent packets saved for retransmission. +// When reached, no more packets will be sent. +// This value *must* be larger than MaxOutstandingSentPackets. +const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4 + +// MaxNonAckElicitingAcks is the maximum number of packets containing an ACK, +// but no ack-eliciting frames, that we send in a row +const MaxNonAckElicitingAcks = 19 + +// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames +// prevents DoS attacks against the streamFrameSorter +const MaxStreamFrameSorterGaps = 1000 + +// MinStreamFrameBufferSize is the minimum data length of a received STREAM frame +// that we use the buffer for. This protects against a DoS where an attacker would send us +// very small STREAM frames to consume a lot of memory. +const MinStreamFrameBufferSize = 128 + +// MinCoalescedPacketSize is the minimum size of a coalesced packet that we pack. +// If a packet has less than this number of bytes, we won't coalesce any more packets onto it. +const MinCoalescedPacketSize = 128 + +// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams. +// This limits the size of the ClientHello and Certificates that can be received. +const MaxCryptoStreamOffset = 16 * (1 << 10) + +// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout +const MinRemoteIdleTimeout = 5 * time.Second + +// DefaultIdleTimeout is the default idle timeout +const DefaultIdleTimeout = 30 * time.Second + +// DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion. +const DefaultHandshakeIdleTimeout = 5 * time.Second + +// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds. +const DefaultHandshakeTimeout = 10 * time.Second + +// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive. +// It should be shorter than the time that NATs clear their mapping. +const MaxKeepAliveInterval = 20 * time.Second + +// RetiredConnectionIDDeleteTimeout is the time we keep closed connections around in order to retransmit the CONNECTION_CLOSE. +// after this time all information about the old connection will be deleted +const RetiredConnectionIDDeleteTimeout = 5 * time.Second + +// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame. +// This avoids splitting up STREAM frames into small pieces, which has 2 advantages: +// 1. it reduces the framing overhead +// 2. it reduces the head-of-line blocking, when a packet is lost +const MinStreamFrameSize ByteCount = 128 + +// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames +// we send after the handshake completes. +const MaxPostHandshakeCryptoFrameSize = 1000 + +// MaxAckFrameSize is the maximum size for an ACK frame that we write +// Due to the varint encoding, ACK frames can grow (almost) indefinitely large. +// The MaxAckFrameSize should be large enough to encode many ACK range, +// but must ensure that a maximum size ACK frame fits into one packet. +const MaxAckFrameSize ByteCount = 1000 + +// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame (RFC 9221). +// The size is chosen such that a DATAGRAM frame fits into a QUIC packet. +const MaxDatagramFrameSize ByteCount = 1220 + +// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221) +const DatagramRcvQueueLen = 128 + +// MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame. +// It also serves as a limit for the packet history. +// If at any point we keep track of more ranges, old ranges are discarded. +const MaxNumAckRanges = 32 + +// MinPacingDelay is the minimum duration that is used for packet pacing +// If the packet packing frequency is higher, multiple packets might be sent at once. +// Example: For a packet pacing delay of 200μs, we would send 5 packets at once, wait for 1ms, and so forth. +const MinPacingDelay = time.Millisecond + +// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections +// if no other value is configured. +const DefaultConnectionIDLength = 4 + +// MaxActiveConnectionIDs is the number of connection IDs that we're storing. +const MaxActiveConnectionIDs = 4 + +// MaxIssuedConnectionIDs is the maximum number of connection IDs that we're issuing at the same time. +const MaxIssuedConnectionIDs = 6 + +// PacketsPerConnectionID is the number of packets we send using one connection ID. +// If the peer provices us with enough new connection IDs, we switch to a new connection ID. +const PacketsPerConnectionID = 10000 + +// AckDelayExponent is the ack delay exponent used when sending ACKs. +const AckDelayExponent = 3 + +// Estimated timer granularity. +// The loss detection timer will not be set to a value smaller than granularity. +const TimerGranularity = time.Millisecond + +// MaxAckDelay is the maximum time by which we delay sending ACKs. +const MaxAckDelay = 25 * time.Millisecond + +// MaxAckDelayInclGranularity is the max_ack_delay including the timer granularity. +// This is the value that should be advertised to the peer. +const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity + +// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. +const KeyUpdateInterval = 100 * 1000 + +// Max0RTTQueueingDuration is the maximum time that we store 0-RTT packets in order to wait for the corresponding Initial to be received. +const Max0RTTQueueingDuration = 100 * time.Millisecond + +// Max0RTTQueues is the maximum number of connections that we buffer 0-RTT packets for. +const Max0RTTQueues = 32 + +// Max0RTTQueueLen is the maximum number of 0-RTT packets that we buffer for each connection. +// When a new connection is created, all buffered packets are passed to the connection immediately. +// To avoid blocking, this value has to be smaller than MaxConnUnprocessedPackets. +// To avoid packets being dropped as undecryptable by the connection, this value has to be smaller than MaxUndecryptablePackets. +const Max0RTTQueueLen = 31 diff --git a/internal/protocol/params_test.go b/internal/protocol/params_test.go new file mode 100644 index 00000000..50a260d2 --- /dev/null +++ b/internal/protocol/params_test.go @@ -0,0 +1,13 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Parameters", func() { + It("can queue more packets in the session than in the 0-RTT queue", func() { + Expect(MaxConnUnprocessedPackets).To(BeNumerically(">", Max0RTTQueueLen)) + Expect(MaxUndecryptablePackets).To(BeNumerically(">", Max0RTTQueueLen)) + }) +}) diff --git a/internal/protocol/perspective.go b/internal/protocol/perspective.go new file mode 100644 index 00000000..43358fec --- /dev/null +++ b/internal/protocol/perspective.go @@ -0,0 +1,26 @@ +package protocol + +// Perspective determines if we're acting as a server or a client +type Perspective int + +// the perspectives +const ( + PerspectiveServer Perspective = 1 + PerspectiveClient Perspective = 2 +) + +// Opposite returns the perspective of the peer +func (p Perspective) Opposite() Perspective { + return 3 - p +} + +func (p Perspective) String() string { + switch p { + case PerspectiveServer: + return "Server" + case PerspectiveClient: + return "Client" + default: + return "invalid perspective" + } +} diff --git a/internal/protocol/perspective_test.go b/internal/protocol/perspective_test.go new file mode 100644 index 00000000..0ae23d7c --- /dev/null +++ b/internal/protocol/perspective_test.go @@ -0,0 +1,19 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Perspective", func() { + It("has a string representation", func() { + Expect(PerspectiveClient.String()).To(Equal("Client")) + Expect(PerspectiveServer.String()).To(Equal("Server")) + Expect(Perspective(0).String()).To(Equal("invalid perspective")) + }) + + It("returns the opposite", func() { + Expect(PerspectiveClient.Opposite()).To(Equal(PerspectiveServer)) + Expect(PerspectiveServer.Opposite()).To(Equal(PerspectiveClient)) + }) +}) diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go new file mode 100644 index 00000000..8241e274 --- /dev/null +++ b/internal/protocol/protocol.go @@ -0,0 +1,97 @@ +package protocol + +import ( + "fmt" + "time" +) + +// The PacketType is the Long Header Type +type PacketType uint8 + +const ( + // PacketTypeInitial is the packet type of an Initial packet + PacketTypeInitial PacketType = 1 + iota + // PacketTypeRetry is the packet type of a Retry packet + PacketTypeRetry + // PacketTypeHandshake is the packet type of a Handshake packet + PacketTypeHandshake + // PacketType0RTT is the packet type of a 0-RTT packet + PacketType0RTT +) + +func (t PacketType) String() string { + switch t { + case PacketTypeInitial: + return "Initial" + case PacketTypeRetry: + return "Retry" + case PacketTypeHandshake: + return "Handshake" + case PacketType0RTT: + return "0-RTT Protected" + default: + return fmt.Sprintf("unknown packet type: %d", t) + } +} + +type ECN uint8 + +const ( + ECNNon ECN = iota // 00 + ECT1 // 01 + ECT0 // 10 + ECNCE // 11 +) + +// A ByteCount in QUIC +type ByteCount int64 + +// MaxByteCount is the maximum value of a ByteCount +const MaxByteCount = ByteCount(1<<62 - 1) + +// InvalidByteCount is an invalid byte count +const InvalidByteCount ByteCount = -1 + +// A StatelessResetToken is a stateless reset token. +type StatelessResetToken [16]byte + +// MaxPacketBufferSize maximum packet size of any QUIC packet, based on +// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, +// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes. +// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452. +const MaxPacketBufferSize ByteCount = 1452 + +// MinInitialPacketSize is the minimum size an Initial packet is required to have. +const MinInitialPacketSize = 1200 + +// MinUnknownVersionPacketSize is the minimum size a packet with an unknown version +// needs to have in order to trigger a Version Negotiation packet. +const MinUnknownVersionPacketSize = MinInitialPacketSize + +// MinStatelessResetSize is the minimum size of a stateless reset packet that we send +const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */ + +// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet. +const MinConnectionIDLenInitial = 8 + +// DefaultAckDelayExponent is the default ack delay exponent +const DefaultAckDelayExponent = 3 + +// MaxAckDelayExponent is the maximum ack delay exponent +const MaxAckDelayExponent = 20 + +// DefaultMaxAckDelay is the default max_ack_delay +const DefaultMaxAckDelay = 25 * time.Millisecond + +// MaxMaxAckDelay is the maximum max_ack_delay +const MaxMaxAckDelay = (1<<14 - 1) * time.Millisecond + +// MaxConnIDLen is the maximum length of the connection ID +const MaxConnIDLen = 20 + +// InvalidPacketLimitAES is the maximum number of packets that we can fail to decrypt when using +// AEAD_AES_128_GCM or AEAD_AES_265_GCM. +const InvalidPacketLimitAES = 1 << 52 + +// InvalidPacketLimitChaCha is the maximum number of packets that we can fail to decrypt when using AEAD_CHACHA20_POLY1305. +const InvalidPacketLimitChaCha = 1 << 36 diff --git a/internal/protocol/protocol_suite_test.go b/internal/protocol/protocol_suite_test.go new file mode 100644 index 00000000..60da0157 --- /dev/null +++ b/internal/protocol/protocol_suite_test.go @@ -0,0 +1,13 @@ +package protocol + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestProtocol(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Protocol Suite") +} diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go new file mode 100644 index 00000000..117405e4 --- /dev/null +++ b/internal/protocol/protocol_test.go @@ -0,0 +1,25 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Protocol", func() { + Context("Long Header Packet Types", func() { + It("has the correct string representation", func() { + Expect(PacketTypeInitial.String()).To(Equal("Initial")) + Expect(PacketTypeRetry.String()).To(Equal("Retry")) + Expect(PacketTypeHandshake.String()).To(Equal("Handshake")) + Expect(PacketType0RTT.String()).To(Equal("0-RTT Protected")) + Expect(PacketType(10).String()).To(Equal("unknown packet type: 10")) + }) + }) + + It("converts ECN bits from the IP header wire to the correct types", func() { + Expect(ECN(0)).To(Equal(ECNNon)) + Expect(ECN(0b00000010)).To(Equal(ECT0)) + Expect(ECN(0b00000001)).To(Equal(ECT1)) + Expect(ECN(0b00000011)).To(Equal(ECNCE)) + }) +}) diff --git a/internal/protocol/stream.go b/internal/protocol/stream.go new file mode 100644 index 00000000..ad7de864 --- /dev/null +++ b/internal/protocol/stream.go @@ -0,0 +1,76 @@ +package protocol + +// StreamType encodes if this is a unidirectional or bidirectional stream +type StreamType uint8 + +const ( + // StreamTypeUni is a unidirectional stream + StreamTypeUni StreamType = iota + // StreamTypeBidi is a bidirectional stream + StreamTypeBidi +) + +// InvalidPacketNumber is a stream ID that is invalid. +// The first valid stream ID in QUIC is 0. +const InvalidStreamID StreamID = -1 + +// StreamNum is the stream number +type StreamNum int64 + +const ( + // InvalidStreamNum is an invalid stream number. + InvalidStreamNum = -1 + // MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames + // and as the stream count in the transport parameters + MaxStreamCount StreamNum = 1 << 60 +) + +// StreamID calculates the stream ID. +func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID { + if s == 0 { + return InvalidStreamID + } + var first StreamID + switch stype { + case StreamTypeBidi: + switch pers { + case PerspectiveClient: + first = 0 + case PerspectiveServer: + first = 1 + } + case StreamTypeUni: + switch pers { + case PerspectiveClient: + first = 2 + case PerspectiveServer: + first = 3 + } + } + return first + 4*StreamID(s-1) +} + +// A StreamID in QUIC +type StreamID int64 + +// InitiatedBy says if the stream was initiated by the client or by the server +func (s StreamID) InitiatedBy() Perspective { + if s%2 == 0 { + return PerspectiveClient + } + return PerspectiveServer +} + +// Type says if this is a unidirectional or bidirectional stream +func (s StreamID) Type() StreamType { + if s%4 >= 2 { + return StreamTypeUni + } + return StreamTypeBidi +} + +// StreamNum returns how many streams in total are below this +// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9) +func (s StreamID) StreamNum() StreamNum { + return StreamNum(s/4) + 1 +} diff --git a/internal/protocol/stream_test.go b/internal/protocol/stream_test.go new file mode 100644 index 00000000..4209f8a0 --- /dev/null +++ b/internal/protocol/stream_test.go @@ -0,0 +1,70 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stream ID", func() { + It("InvalidStreamID is smaller than all valid stream IDs", func() { + Expect(InvalidStreamID).To(BeNumerically("<", 0)) + }) + + It("says who initiated a stream", func() { + Expect(StreamID(4).InitiatedBy()).To(Equal(PerspectiveClient)) + Expect(StreamID(5).InitiatedBy()).To(Equal(PerspectiveServer)) + Expect(StreamID(6).InitiatedBy()).To(Equal(PerspectiveClient)) + Expect(StreamID(7).InitiatedBy()).To(Equal(PerspectiveServer)) + }) + + It("tells the directionality", func() { + Expect(StreamID(4).Type()).To(Equal(StreamTypeBidi)) + Expect(StreamID(5).Type()).To(Equal(StreamTypeBidi)) + Expect(StreamID(6).Type()).To(Equal(StreamTypeUni)) + Expect(StreamID(7).Type()).To(Equal(StreamTypeUni)) + }) + + It("tells the stream number", func() { + Expect(StreamID(0).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(1).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(2).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(3).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(8).StreamNum()).To(BeEquivalentTo(3)) + Expect(StreamID(9).StreamNum()).To(BeEquivalentTo(3)) + Expect(StreamID(10).StreamNum()).To(BeEquivalentTo(3)) + Expect(StreamID(11).StreamNum()).To(BeEquivalentTo(3)) + }) + + Context("converting stream nums to stream IDs", func() { + It("handles 0", func() { + Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(InvalidStreamID)) + Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(InvalidStreamID)) + Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(InvalidStreamID)) + Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(InvalidStreamID)) + }) + + It("handles the first", func() { + Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(0))) + Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(1))) + Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(2))) + Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3))) + }) + + It("handles others", func() { + Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(396))) + Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(397))) + Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(398))) + Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(399))) + }) + + It("has the right value for MaxStreamCount", func() { + const maxStreamID = StreamID(1<<62 - 1) + for _, dir := range []StreamType{StreamTypeUni, StreamTypeBidi} { + for _, pers := range []Perspective{PerspectiveClient, PerspectiveServer} { + Expect(MaxStreamCount.StreamID(dir, pers)).To(BeNumerically("<=", maxStreamID)) + Expect((MaxStreamCount + 1).StreamID(dir, pers)).To(BeNumerically(">", maxStreamID)) + } + } + }) + }) +}) diff --git a/internal/protocol/version.go b/internal/protocol/version.go new file mode 100644 index 00000000..5f0d93c8 --- /dev/null +++ b/internal/protocol/version.go @@ -0,0 +1,77 @@ +package protocol + +import ( + "crypto/rand" + "encoding/binary" + "github.com/lucas-clemente/quic-go" + "math" +) + +// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions +const ( + gquicVersion0 = 0x51303030 + maxGquicVersion = 0x51303439 +) + +// The version numbers, making grepping easier +const ( + VersionTLS quic.VersionNumber = 0x1 + VersionWhatever quic.VersionNumber = math.MaxUint32 - 1 // for when the version doesn't matter + VersionUnknown quic.VersionNumber = math.MaxUint32 + VersionDraft29 quic.VersionNumber = 0xff00001d + Version1 quic.VersionNumber = 0x1 + Version2 quic.VersionNumber = 0x709a50c4 +) + +// SupportedVersions lists the versions that the server supports +// must be in sorted descending order +var SupportedVersions = []quic.VersionNumber{Version1, Version2, VersionDraft29} + +// IsValidVersion says if the version is known to quic-go +func IsValidVersion(v quic.VersionNumber) bool { + return v == VersionTLS || IsSupportedVersion(SupportedVersions, v) +} + +// IsSupportedVersion returns true if the server supports this version +func IsSupportedVersion(supported []quic.VersionNumber, v quic.VersionNumber) bool { + for _, t := range supported { + if t == v { + return true + } + } + return false +} + +// ChooseSupportedVersion finds the best version in the overlap of ours and theirs +// ours is a slice of versions that we support, sorted by our preference (descending) +// theirs is a slice of versions offered by the peer. The order does not matter. +// The bool returned indicates if a matching version was found. +func ChooseSupportedVersion(ours, theirs []quic.VersionNumber) (quic.VersionNumber, bool) { + for _, ourVer := range ours { + for _, theirVer := range theirs { + if ourVer == theirVer { + return ourVer, true + } + } + } + return 0, false +} + +// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a) +func generateReservedVersion() quic.VersionNumber { + b := make([]byte, 4) + _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything + return quic.VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa) +} + +// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position +func GetGreasedVersions(supported []quic.VersionNumber) []quic.VersionNumber { + b := make([]byte, 1) + _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything + randPos := int(b[0]) % (len(supported) + 1) + greased := make([]quic.VersionNumber, len(supported)+1) + copy(greased, supported[:randPos]) + greased[randPos] = generateReservedVersion() + copy(greased[randPos+1:], supported[randPos:]) + return greased +} diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go new file mode 100644 index 00000000..33c6598b --- /dev/null +++ b/internal/protocol/version_test.go @@ -0,0 +1,121 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Version", func() { + isReservedVersion := func(v VersionNumber) bool { + return v&0x0f0f0f0f == 0x0a0a0a0a + } + + It("says if a version is valid", func() { + Expect(IsValidVersion(VersionTLS)).To(BeTrue()) + Expect(IsValidVersion(VersionWhatever)).To(BeFalse()) + Expect(IsValidVersion(VersionUnknown)).To(BeFalse()) + Expect(IsValidVersion(VersionDraft29)).To(BeTrue()) + Expect(IsValidVersion(Version1)).To(BeTrue()) + Expect(IsValidVersion(Version2)).To(BeTrue()) + Expect(IsValidVersion(1234)).To(BeFalse()) + }) + + It("versions don't have reserved version numbers", func() { + Expect(isReservedVersion(VersionTLS)).To(BeFalse()) + }) + + It("has the right string representation", func() { + Expect(VersionWhatever.String()).To(Equal("whatever")) + Expect(VersionUnknown.String()).To(Equal("unknown")) + Expect(VersionDraft29.String()).To(Equal("draft-29")) + Expect(Version1.String()).To(Equal("v1")) + Expect(Version2.String()).To(Equal("v2")) + // check with unsupported version numbers from the wiki + Expect(VersionNumber(0x51303039).String()).To(Equal("gQUIC 9")) + Expect(VersionNumber(0x51303133).String()).To(Equal("gQUIC 13")) + Expect(VersionNumber(0x51303235).String()).To(Equal("gQUIC 25")) + Expect(VersionNumber(0x51303438).String()).To(Equal("gQUIC 48")) + Expect(VersionNumber(0x01234567).String()).To(Equal("0x1234567")) + }) + + It("recognizes supported versions", func() { + Expect(IsSupportedVersion(SupportedVersions, 0)).To(BeFalse()) + Expect(IsSupportedVersion(SupportedVersions, SupportedVersions[0])).To(BeTrue()) + Expect(IsSupportedVersion(SupportedVersions, SupportedVersions[len(SupportedVersions)-1])).To(BeTrue()) + }) + + Context("highest supported version", func() { + It("finds the supported version", func() { + supportedVersions := []VersionNumber{1, 2, 3} + other := []VersionNumber{6, 5, 4, 3} + ver, ok := ChooseSupportedVersion(supportedVersions, other) + Expect(ok).To(BeTrue()) + Expect(ver).To(Equal(VersionNumber(3))) + }) + + It("picks the preferred version", func() { + supportedVersions := []VersionNumber{2, 1, 3} + other := []VersionNumber{3, 6, 1, 8, 2, 10} + ver, ok := ChooseSupportedVersion(supportedVersions, other) + Expect(ok).To(BeTrue()) + Expect(ver).To(Equal(VersionNumber(2))) + }) + + It("says when no matching version was found", func() { + _, ok := ChooseSupportedVersion([]VersionNumber{1}, []VersionNumber{2}) + Expect(ok).To(BeFalse()) + }) + + It("handles empty inputs", func() { + _, ok := ChooseSupportedVersion([]VersionNumber{102, 101}, []VersionNumber{}) + Expect(ok).To(BeFalse()) + _, ok = ChooseSupportedVersion([]VersionNumber{}, []VersionNumber{1, 2}) + Expect(ok).To(BeFalse()) + _, ok = ChooseSupportedVersion([]VersionNumber{}, []VersionNumber{}) + Expect(ok).To(BeFalse()) + }) + }) + + Context("reserved versions", func() { + It("adds a greased version if passed an empty slice", func() { + greased := GetGreasedVersions([]VersionNumber{}) + Expect(greased).To(HaveLen(1)) + Expect(isReservedVersion(greased[0])).To(BeTrue()) + }) + + It("creates greased lists of version numbers", func() { + supported := []VersionNumber{10, 18, 29} + for _, v := range supported { + Expect(isReservedVersion(v)).To(BeFalse()) + } + var greasedVersionFirst, greasedVersionLast, greasedVersionMiddle int + // check that + // 1. the greased version sometimes appears first + // 2. the greased version sometimes appears in the middle + // 3. the greased version sometimes appears last + // 4. the supported versions are kept in order + for i := 0; i < 100; i++ { + greased := GetGreasedVersions(supported) + Expect(greased).To(HaveLen(4)) + var j int + for i, v := range greased { + if isReservedVersion(v) { + if i == 0 { + greasedVersionFirst++ + } + if i == len(greased)-1 { + greasedVersionLast++ + } + greasedVersionMiddle++ + continue + } + Expect(supported[j]).To(Equal(v)) + j++ + } + } + Expect(greasedVersionFirst).ToNot(BeZero()) + Expect(greasedVersionLast).ToNot(BeZero()) + Expect(greasedVersionMiddle).ToNot(BeZero()) + }) + }) +}) diff --git a/internal/qerr/error_codes.go b/internal/qerr/error_codes.go new file mode 100644 index 00000000..58ea6b43 --- /dev/null +++ b/internal/qerr/error_codes.go @@ -0,0 +1,88 @@ +package qerr + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/qtls" +) + +// TransportErrorCode is a QUIC transport error. +type TransportErrorCode uint64 + +// The error codes defined by QUIC +const ( + NoError TransportErrorCode = 0x0 + InternalError TransportErrorCode = 0x1 + ConnectionRefused TransportErrorCode = 0x2 + FlowControlError TransportErrorCode = 0x3 + StreamLimitError TransportErrorCode = 0x4 + StreamStateError TransportErrorCode = 0x5 + FinalSizeError TransportErrorCode = 0x6 + FrameEncodingError TransportErrorCode = 0x7 + TransportParameterError TransportErrorCode = 0x8 + ConnectionIDLimitError TransportErrorCode = 0x9 + ProtocolViolation TransportErrorCode = 0xa + InvalidToken TransportErrorCode = 0xb + ApplicationErrorErrorCode TransportErrorCode = 0xc + CryptoBufferExceeded TransportErrorCode = 0xd + KeyUpdateError TransportErrorCode = 0xe + AEADLimitReached TransportErrorCode = 0xf + NoViablePathError TransportErrorCode = 0x10 +) + +func (e TransportErrorCode) IsCryptoError() bool { + return e >= 0x100 && e < 0x200 +} + +// Message is a description of the error. +// It only returns a non-empty string for crypto errors. +func (e TransportErrorCode) Message() string { + if !e.IsCryptoError() { + return "" + } + return qtls.Alert(e - 0x100).Error() +} + +func (e TransportErrorCode) String() string { + switch e { + case NoError: + return "NO_ERROR" + case InternalError: + return "INTERNAL_ERROR" + case ConnectionRefused: + return "CONNECTION_REFUSED" + case FlowControlError: + return "FLOW_CONTROL_ERROR" + case StreamLimitError: + return "STREAM_LIMIT_ERROR" + case StreamStateError: + return "STREAM_STATE_ERROR" + case FinalSizeError: + return "FINAL_SIZE_ERROR" + case FrameEncodingError: + return "FRAME_ENCODING_ERROR" + case TransportParameterError: + return "TRANSPORT_PARAMETER_ERROR" + case ConnectionIDLimitError: + return "CONNECTION_ID_LIMIT_ERROR" + case ProtocolViolation: + return "PROTOCOL_VIOLATION" + case InvalidToken: + return "INVALID_TOKEN" + case ApplicationErrorErrorCode: + return "APPLICATION_ERROR" + case CryptoBufferExceeded: + return "CRYPTO_BUFFER_EXCEEDED" + case KeyUpdateError: + return "KEY_UPDATE_ERROR" + case AEADLimitReached: + return "AEAD_LIMIT_REACHED" + case NoViablePathError: + return "NO_VIABLE_PATH" + default: + if e.IsCryptoError() { + return fmt.Sprintf("CRYPTO_ERROR (%#x)", uint16(e)) + } + return fmt.Sprintf("unknown error code: %#x", uint16(e)) + } +} diff --git a/internal/qerr/errorcodes_test.go b/internal/qerr/errorcodes_test.go new file mode 100644 index 00000000..cfc6cd85 --- /dev/null +++ b/internal/qerr/errorcodes_test.go @@ -0,0 +1,52 @@ +package qerr + +import ( + "go/ast" + "go/parser" + "go/token" + "path" + "runtime" + "strconv" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("error codes", func() { + // If this test breaks, you should run `go generate ./...` + It("has a string representation for every error code", func() { + // We parse the error code file, extract all constants, and verify that + // each of them has a string version. Go FTW! + _, thisfile, _, ok := runtime.Caller(0) + if !ok { + panic("Failed to get current frame") + } + filename := path.Join(path.Dir(thisfile), "error_codes.go") + fileAst, err := parser.ParseFile(token.NewFileSet(), filename, nil, 0) + Expect(err).NotTo(HaveOccurred()) + constSpecs := fileAst.Decls[2].(*ast.GenDecl).Specs + Expect(len(constSpecs)).To(BeNumerically(">", 4)) // at time of writing + for _, c := range constSpecs { + valString := c.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value + val, err := strconv.ParseInt(valString, 0, 64) + Expect(err).NotTo(HaveOccurred()) + Expect(TransportErrorCode(val).String()).ToNot(Equal("unknown error code")) + } + }) + + It("has a string representation for unknown error codes", func() { + Expect(TransportErrorCode(0x1337).String()).To(Equal("unknown error code: 0x1337")) + }) + + It("says if an error is a crypto error", func() { + for i := 0; i < 0x100; i++ { + Expect(TransportErrorCode(i).IsCryptoError()).To(BeFalse()) + } + for i := 0x100; i < 0x200; i++ { + Expect(TransportErrorCode(i).IsCryptoError()).To(BeTrue()) + } + for i := 0x200; i < 0x300; i++ { + Expect(TransportErrorCode(i).IsCryptoError()).To(BeFalse()) + } + }) +}) diff --git a/internal/qerr/errors.go b/internal/qerr/errors.go new file mode 100644 index 00000000..327491b8 --- /dev/null +++ b/internal/qerr/errors.go @@ -0,0 +1,125 @@ +package qerr + +import ( + "fmt" + "github.com/lucas-clemente/quic-go" + "net" + + "github.com/imroc/req/v3/internal/protocol" +) + +var ( + ErrHandshakeTimeout = &HandshakeTimeoutError{} + ErrIdleTimeout = &IdleTimeoutError{} +) + +type TransportError struct { + Remote bool + FrameType uint64 + ErrorCode TransportErrorCode + ErrorMessage string +} + +var _ error = &TransportError{} + +// NewCryptoError create a new TransportError instance for a crypto error +func NewCryptoError(tlsAlert uint8, errorMessage string) *TransportError { + return &TransportError{ + ErrorCode: 0x100 + TransportErrorCode(tlsAlert), + ErrorMessage: errorMessage, + } +} + +func (e *TransportError) Error() string { + str := e.ErrorCode.String() + if e.FrameType != 0 { + str += fmt.Sprintf(" (frame type: %#x)", e.FrameType) + } + msg := e.ErrorMessage + if len(msg) == 0 { + msg = e.ErrorCode.Message() + } + if len(msg) == 0 { + return str + } + return str + ": " + msg +} + +func (e *TransportError) Is(target error) bool { + return target == net.ErrClosed +} + +// An ApplicationErrorCode is an application-defined error code. +type ApplicationErrorCode uint64 + +func (e *ApplicationError) Is(target error) bool { + return target == net.ErrClosed +} + +// A StreamErrorCode is an error code used to cancel streams. +type StreamErrorCode uint64 + +type ApplicationError struct { + Remote bool + ErrorCode ApplicationErrorCode + ErrorMessage string +} + +var _ error = &ApplicationError{} + +func (e *ApplicationError) Error() string { + if len(e.ErrorMessage) == 0 { + return fmt.Sprintf("Application error %#x", e.ErrorCode) + } + return fmt.Sprintf("Application error %#x: %s", e.ErrorCode, e.ErrorMessage) +} + +type IdleTimeoutError struct{} + +var _ error = &IdleTimeoutError{} + +func (e *IdleTimeoutError) Timeout() bool { return true } +func (e *IdleTimeoutError) Temporary() bool { return false } +func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" } +func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed } + +type HandshakeTimeoutError struct{} + +var _ error = &HandshakeTimeoutError{} + +func (e *HandshakeTimeoutError) Timeout() bool { return true } +func (e *HandshakeTimeoutError) Temporary() bool { return false } +func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" } +func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed } + +// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version. +type VersionNegotiationError struct { + Ours []quic.VersionNumber + Theirs []quic.VersionNumber +} + +func (e *VersionNegotiationError) Error() string { + return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs) +} + +func (e *VersionNegotiationError) Is(target error) bool { + return target == net.ErrClosed +} + +// A StatelessResetError occurs when we receive a stateless reset. +type StatelessResetError struct { + Token protocol.StatelessResetToken +} + +var _ net.Error = &StatelessResetError{} + +func (e *StatelessResetError) Error() string { + return fmt.Sprintf("received a stateless reset with token %x", e.Token) +} + +func (e *StatelessResetError) Is(target error) bool { + return target == net.ErrClosed +} + +func (e *StatelessResetError) Timeout() bool { return false } +func (e *StatelessResetError) Temporary() bool { return true } diff --git a/internal/qerr/errors_suite_test.go b/internal/qerr/errors_suite_test.go new file mode 100644 index 00000000..749cdedc --- /dev/null +++ b/internal/qerr/errors_suite_test.go @@ -0,0 +1,13 @@ +package qerr + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestErrorcodes(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Errors Suite") +} diff --git a/internal/qerr/errors_test.go b/internal/qerr/errors_test.go new file mode 100644 index 00000000..4376fcc6 --- /dev/null +++ b/internal/qerr/errors_test.go @@ -0,0 +1,124 @@ +package qerr + +import ( + "errors" + "net" + + "github.com/imroc/req/v3/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("QUIC Errors", func() { + Context("Transport Errors", func() { + It("has a string representation", func() { + Expect((&TransportError{ + ErrorCode: FlowControlError, + ErrorMessage: "foobar", + }).Error()).To(Equal("FLOW_CONTROL_ERROR: foobar")) + }) + + It("has a string representation for empty error phrases", func() { + Expect((&TransportError{ErrorCode: FlowControlError}).Error()).To(Equal("FLOW_CONTROL_ERROR")) + }) + + It("includes the frame type, for errors without a message", func() { + Expect((&TransportError{ + ErrorCode: FlowControlError, + FrameType: 0x1337, + }).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337)")) + }) + + It("includes the frame type, for errors with a message", func() { + Expect((&TransportError{ + ErrorCode: FlowControlError, + FrameType: 0x1337, + ErrorMessage: "foobar", + }).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337): foobar")) + }) + + Context("crypto errors", func() { + It("has a string representation for errors with a message", func() { + err := NewCryptoError(0x42, "foobar") + Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x142): foobar")) + }) + + It("has a string representation for errors without a message", func() { + err := NewCryptoError(0x2a, "") + Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x12a): tls: bad certificate")) + }) + }) + }) + + Context("Application Errors", func() { + It("has a string representation for errors with a message", func() { + Expect((&ApplicationError{ + ErrorCode: 0x42, + ErrorMessage: "foobar", + }).Error()).To(Equal("Application error 0x42: foobar")) + }) + + It("has a string representation for errors without a message", func() { + Expect((&ApplicationError{ + ErrorCode: 0x42, + }).Error()).To(Equal("Application error 0x42")) + }) + }) + + Context("timeout errors", func() { + It("handshake timeouts", func() { + //nolint:gosimple // we need to assign to an interface here + var err error + err = &HandshakeTimeoutError{} + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(err.Error()).To(Equal("timeout: handshake did not complete in time")) + }) + + It("idle timeouts", func() { + //nolint:gosimple // we need to assign to an interface here + var err error + err = &IdleTimeoutError{} + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(err.Error()).To(Equal("timeout: no recent network activity")) + }) + }) + + Context("Version Negotiation errors", func() { + It("has a string representation", func() { + Expect((&VersionNegotiationError{ + Ours: []quic.VersionNumber{2, 3}, + Theirs: []quic.VersionNumber{4, 5, 6}, + }).Error()).To(Equal("no compatible QUIC version found (we support [0x2 0x3], server offered [0x4 0x5 0x6])")) + }) + }) + + Context("Stateless Reset errors", func() { + token := protocol.StatelessResetToken{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} + + It("has a string representation", func() { + Expect((&StatelessResetError{Token: token}).Error()).To(Equal("received a stateless reset with token 000102030405060708090a0b0c0d0e0f")) + }) + + It("is a net.Error", func() { + //nolint:gosimple // we need to assign to an interface here + var err error + err = &StatelessResetError{} + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeFalse()) + }) + }) + + It("says that errors are net.ErrClosed errors", func() { + Expect(errors.Is(&TransportError{}, net.ErrClosed)).To(BeTrue()) + Expect(errors.Is(&ApplicationError{}, net.ErrClosed)).To(BeTrue()) + Expect(errors.Is(&IdleTimeoutError{}, net.ErrClosed)).To(BeTrue()) + Expect(errors.Is(&HandshakeTimeoutError{}, net.ErrClosed)).To(BeTrue()) + Expect(errors.Is(&StatelessResetError{}, net.ErrClosed)).To(BeTrue()) + Expect(errors.Is(&VersionNegotiationError{}, net.ErrClosed)).To(BeTrue()) + }) +}) diff --git a/internal/qtls/go116.go b/internal/qtls/go116.go new file mode 100644 index 00000000..e3024624 --- /dev/null +++ b/internal/qtls/go116.go @@ -0,0 +1,100 @@ +//go:build go1.16 && !go1.17 +// +build go1.16,!go1.17 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls-go1-16" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains inforamtion about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-16.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/internal/qtls/go117.go b/internal/qtls/go117.go new file mode 100644 index 00000000..bc385f19 --- /dev/null +++ b/internal/qtls/go117.go @@ -0,0 +1,100 @@ +//go:build go1.17 && !go1.18 +// +build go1.17,!go1.18 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls-go1-17" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains inforamtion about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-17.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/internal/qtls/go118.go b/internal/qtls/go118.go new file mode 100644 index 00000000..0e0e7966 --- /dev/null +++ b/internal/qtls/go118.go @@ -0,0 +1,100 @@ +//go:build go1.18 +// +build go1.18 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls-go1-18" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains inforamtion about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-18.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/internal/qtls/go119.go b/internal/qtls/go119.go new file mode 100644 index 00000000..87e7132e --- /dev/null +++ b/internal/qtls/go119.go @@ -0,0 +1,6 @@ +//go:build go1.19 +// +build go1.19 + +package qtls + +var _ int = "The version of quic-go you're using can't be built on Go 1.19 yet. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." diff --git a/internal/qtls/go_oldversion.go b/internal/qtls/go_oldversion.go new file mode 100644 index 00000000..384d719c --- /dev/null +++ b/internal/qtls/go_oldversion.go @@ -0,0 +1,7 @@ +//go:build (go1.9 || go1.10 || go1.11 || go1.12 || go1.13 || go1.14 || go1.15) && !go1.16 +// +build go1.9 go1.10 go1.11 go1.12 go1.13 go1.14 go1.15 +// +build !go1.16 + +package qtls + +var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." diff --git a/internal/qtls/qtls_suite_test.go b/internal/qtls/qtls_suite_test.go new file mode 100644 index 00000000..24b143b2 --- /dev/null +++ b/internal/qtls/qtls_suite_test.go @@ -0,0 +1,25 @@ +package qtls + +import ( + "testing" + + gomock "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestQTLS(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "qtls Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) diff --git a/internal/qtls/qtls_test.go b/internal/qtls/qtls_test.go new file mode 100644 index 00000000..c64c5e9e --- /dev/null +++ b/internal/qtls/qtls_test.go @@ -0,0 +1,17 @@ +package qtls + +import ( + "crypto/tls" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("qtls wrapper", func() { + It("gets cipher suites", func() { + for _, id := range []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384, tls.TLS_CHACHA20_POLY1305_SHA256} { + cs := CipherSuiteTLS13ByID(id) + Expect(cs.ID).To(Equal(id)) + } + }) +}) diff --git a/internal/quicvarint/io.go b/internal/quicvarint/io.go new file mode 100644 index 00000000..9368d1c5 --- /dev/null +++ b/internal/quicvarint/io.go @@ -0,0 +1,68 @@ +package quicvarint + +import ( + "bytes" + "io" +) + +// Reader implements both the io.ByteReader and io.Reader interfaces. +type Reader interface { + io.ByteReader + io.Reader +} + +var _ Reader = &bytes.Reader{} + +type byteReader struct { + io.Reader +} + +var _ Reader = &byteReader{} + +// NewReader returns a Reader for r. +// If r already implements both io.ByteReader and io.Reader, NewReader returns r. +// Otherwise, r is wrapped to add the missing interfaces. +func NewReader(r io.Reader) Reader { + if r, ok := r.(Reader); ok { + return r + } + return &byteReader{r} +} + +func (r *byteReader) ReadByte() (byte, error) { + var b [1]byte + n, err := r.Reader.Read(b[:]) + if n == 1 && err == io.EOF { + err = nil + } + return b[0], err +} + +// Writer implements both the io.ByteWriter and io.Writer interfaces. +type Writer interface { + io.ByteWriter + io.Writer +} + +var _ Writer = &bytes.Buffer{} + +type byteWriter struct { + io.Writer +} + +var _ Writer = &byteWriter{} + +// NewWriter returns a Writer for w. +// If r already implements both io.ByteWriter and io.Writer, NewWriter returns w. +// Otherwise, w is wrapped to add the missing interfaces. +func NewWriter(w io.Writer) Writer { + if w, ok := w.(Writer); ok { + return w + } + return &byteWriter{w} +} + +func (w *byteWriter) WriteByte(c byte) error { + _, err := w.Writer.Write([]byte{c}) + return err +} diff --git a/internal/quicvarint/io_test.go b/internal/quicvarint/io_test.go new file mode 100644 index 00000000..054ab864 --- /dev/null +++ b/internal/quicvarint/io_test.go @@ -0,0 +1,115 @@ +package quicvarint + +import ( + "bytes" + "io" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type nopReader struct{} + +func (r *nopReader) Read(_ []byte) (int, error) { + return 0, io.ErrUnexpectedEOF +} + +var _ io.Reader = &nopReader{} + +type nopWriter struct{} + +func (r *nopWriter) Write(_ []byte) (int, error) { + return 0, io.ErrShortBuffer +} + +// eofReader is a reader that returns data and the io.EOF at the same time in the last Read call +type eofReader struct { + Data []byte + pos int +} + +func (r *eofReader) Read(b []byte) (int, error) { + n := copy(b, r.Data[r.pos:]) + r.pos += n + if r.pos >= len(r.Data) { + return n, io.EOF + } + return n, nil +} + +var _ io.Writer = &nopWriter{} + +var _ = Describe("Varint I/O", func() { + Context("Reader", func() { + Context("NewReader", func() { + It("passes through a Reader unchanged", func() { + b := bytes.NewReader([]byte{0}) + r := NewReader(b) + Expect(r).To(Equal(b)) + }) + + It("wraps an io.Reader", func() { + n := &nopReader{} + r := NewReader(n) + Expect(r).ToNot(Equal(n)) + }) + }) + + It("returns an error when reading from an underlying io.Reader fails", func() { + r := NewReader(&nopReader{}) + val, err := r.ReadByte() + Expect(err).To(Equal(io.ErrUnexpectedEOF)) + Expect(val).To(Equal(byte(0))) + }) + + Context("EOF handling", func() { + It("eofReader works correctly", func() { + r := &eofReader{Data: []byte("foobar")} + b := make([]byte, 3) + n, err := r.Read(b) + Expect(n).To(Equal(3)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(b)).To(Equal("foo")) + n, err = r.Read(b) + Expect(n).To(Equal(3)) + Expect(err).To(MatchError(io.EOF)) + Expect(string(b)).To(Equal("bar")) + n, err = r.Read(b) + Expect(err).To(MatchError(io.EOF)) + Expect(n).To(BeZero()) + }) + + It("correctly handles io.EOF", func() { + buf := &bytes.Buffer{} + Write(buf, 1337) + + r := NewReader(&eofReader{Data: buf.Bytes()}) + n, err := Read(r) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(BeEquivalentTo(1337)) + }) + }) + }) + + Context("Writer", func() { + Context("NewWriter", func() { + It("passes through a Writer unchanged", func() { + b := &bytes.Buffer{} + w := NewWriter(b) + Expect(w).To(Equal(b)) + }) + + It("wraps an io.Writer", func() { + n := &nopWriter{} + w := NewWriter(n) + Expect(w).ToNot(Equal(n)) + }) + }) + + It("returns an error when writing to an underlying io.Writer fails", func() { + w := NewWriter(&nopWriter{}) + err := w.WriteByte(0) + Expect(err).To(Equal(io.ErrShortBuffer)) + }) + }) +}) diff --git a/internal/quicvarint/quicvarint_suite_test.go b/internal/quicvarint/quicvarint_suite_test.go new file mode 100644 index 00000000..b7b17de7 --- /dev/null +++ b/internal/quicvarint/quicvarint_suite_test.go @@ -0,0 +1,13 @@ +package quicvarint_test + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestQuicVarint(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "QUIC Varint Suite") +} diff --git a/internal/quicvarint/varint.go b/internal/quicvarint/varint.go new file mode 100644 index 00000000..3bb242fd --- /dev/null +++ b/internal/quicvarint/varint.go @@ -0,0 +1,139 @@ +package quicvarint + +import ( + "fmt" + "io" + + "github.com/imroc/req/v3/internal/protocol" +) + +// taken from the QUIC draft +const ( + // Min is the minimum value allowed for a QUIC varint. + Min = 0 + + // Max is the maximum allowed value for a QUIC varint (2^62-1). + Max = maxVarInt8 + + maxVarInt1 = 63 + maxVarInt2 = 16383 + maxVarInt4 = 1073741823 + maxVarInt8 = 4611686018427387903 +) + +// Read reads a number in the QUIC varint format from r. +func Read(r io.ByteReader) (uint64, error) { + firstByte, err := r.ReadByte() + if err != nil { + return 0, err + } + // the first two bits of the first byte encode the length + len := 1 << ((firstByte & 0xc0) >> 6) + b1 := firstByte & (0xff - 0xc0) + if len == 1 { + return uint64(b1), nil + } + b2, err := r.ReadByte() + if err != nil { + return 0, err + } + if len == 2 { + return uint64(b2) + uint64(b1)<<8, nil + } + b3, err := r.ReadByte() + if err != nil { + return 0, err + } + b4, err := r.ReadByte() + if err != nil { + return 0, err + } + if len == 4 { + return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil + } + b5, err := r.ReadByte() + if err != nil { + return 0, err + } + b6, err := r.ReadByte() + if err != nil { + return 0, err + } + b7, err := r.ReadByte() + if err != nil { + return 0, err + } + b8, err := r.ReadByte() + if err != nil { + return 0, err + } + return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil +} + +// Write writes i in the QUIC varint format to w. +func Write(w Writer, i uint64) { + if i <= maxVarInt1 { + w.WriteByte(uint8(i)) + } else if i <= maxVarInt2 { + w.Write([]byte{uint8(i>>8) | 0x40, uint8(i)}) + } else if i <= maxVarInt4 { + w.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}) + } else if i <= maxVarInt8 { + w.Write([]byte{ + uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) + } else { + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) + } +} + +// WriteWithLen writes i in the QUIC varint format with the desired length to w. +func WriteWithLen(w Writer, i uint64, length protocol.ByteCount) { + if length != 1 && length != 2 && length != 4 && length != 8 { + panic("invalid varint length") + } + l := Len(i) + if l == length { + Write(w, i) + return + } + if l > length { + panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length)) + } + if length == 2 { + w.WriteByte(0b01000000) + } else if length == 4 { + w.WriteByte(0b10000000) + } else if length == 8 { + w.WriteByte(0b11000000) + } + for j := protocol.ByteCount(1); j < length-l; j++ { + w.WriteByte(0) + } + for j := protocol.ByteCount(0); j < l; j++ { + w.WriteByte(uint8(i >> (8 * (l - 1 - j)))) + } +} + +// Len determines the number of bytes that will be needed to write the number i. +func Len(i uint64) protocol.ByteCount { + if i <= maxVarInt1 { + return 1 + } + if i <= maxVarInt2 { + return 2 + } + if i <= maxVarInt4 { + return 4 + } + if i <= maxVarInt8 { + return 8 + } + // Don't use a fmt.Sprintf here to format the error message. + // The function would then exceed the inlining budget. + panic(struct { + message string + num uint64 + }{"value doesn't fit into 62 bits: ", i}) +} diff --git a/internal/quicvarint/varint_test.go b/internal/quicvarint/varint_test.go new file mode 100644 index 00000000..acf4a31c --- /dev/null +++ b/internal/quicvarint/varint_test.go @@ -0,0 +1,221 @@ +package quicvarint + +import ( + "bytes" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Varint encoding / decoding", func() { + Context("limits", func() { + Specify("Min == 0", func() { + Expect(Min).To(Equal(0)) + }) + + Specify("Max == 2^62-1", func() { + Expect(uint64(Max)).To(Equal(uint64(1<<62 - 1))) + }) + }) + + Context("decoding", func() { + It("reads a 1 byte number", func() { + b := bytes.NewReader([]byte{0b00011001}) + val, err := Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint64(25))) + Expect(b.Len()).To(BeZero()) + }) + + It("reads a number that is encoded too long", func() { + b := bytes.NewReader([]byte{0b01000000, 0x25}) + val, err := Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint64(37))) + Expect(b.Len()).To(BeZero()) + }) + + It("reads a 2 byte number", func() { + b := bytes.NewReader([]byte{0b01111011, 0xbd}) + val, err := Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint64(15293))) + Expect(b.Len()).To(BeZero()) + }) + + It("reads a 4 byte number", func() { + b := bytes.NewReader([]byte{0b10011101, 0x7f, 0x3e, 0x7d}) + val, err := Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint64(494878333))) + Expect(b.Len()).To(BeZero()) + }) + + It("reads an 8 byte number", func() { + b := bytes.NewReader([]byte{0b11000010, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c}) + val, err := Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint64(151288809941952652))) + Expect(b.Len()).To(BeZero()) + }) + }) + + Context("encoding", func() { + Context("with minimal length", func() { + It("writes a 1 byte number", func() { + b := &bytes.Buffer{} + Write(b, 37) + Expect(b.Bytes()).To(Equal([]byte{0x25})) + }) + + It("writes the maximum 1 byte number in 1 byte", func() { + b := &bytes.Buffer{} + Write(b, maxVarInt1) + Expect(b.Bytes()).To(Equal([]byte{0b00111111})) + }) + + It("writes the minimum 2 byte number in 2 bytes", func() { + b := &bytes.Buffer{} + Write(b, maxVarInt1+1) + Expect(b.Bytes()).To(Equal([]byte{0x40, maxVarInt1 + 1})) + }) + + It("writes a 2 byte number", func() { + b := &bytes.Buffer{} + Write(b, 15293) + Expect(b.Bytes()).To(Equal([]byte{0b01000000 ^ 0x3b, 0xbd})) + }) + + It("writes the maximum 2 byte number in 2 bytes", func() { + b := &bytes.Buffer{} + Write(b, maxVarInt2) + Expect(b.Bytes()).To(Equal([]byte{0b01111111, 0xff})) + }) + + It("writes the minimum 4 byte number in 4 bytes", func() { + b := &bytes.Buffer{} + Write(b, maxVarInt2+1) + Expect(b.Len()).To(Equal(4)) + num, err := Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(num).To(Equal(uint64(maxVarInt2 + 1))) + }) + + It("writes a 4 byte number", func() { + b := &bytes.Buffer{} + Write(b, 494878333) + Expect(b.Bytes()).To(Equal([]byte{0b10000000 ^ 0x1d, 0x7f, 0x3e, 0x7d})) + }) + + It("writes the maximum 4 byte number in 4 bytes", func() { + b := &bytes.Buffer{} + Write(b, maxVarInt4) + Expect(b.Bytes()).To(Equal([]byte{0b10111111, 0xff, 0xff, 0xff})) + }) + + It("writes the minimum 8 byte number in 8 bytes", func() { + b := &bytes.Buffer{} + Write(b, maxVarInt4+1) + Expect(b.Len()).To(Equal(8)) + num, err := Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(num).To(Equal(uint64(maxVarInt4 + 1))) + }) + + It("writes an 8 byte number", func() { + b := &bytes.Buffer{} + Write(b, 151288809941952652) + Expect(b.Bytes()).To(Equal([]byte{0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c})) + }) + + It("writes the maximum 8 byte number in 8 bytes", func() { + b := &bytes.Buffer{} + Write(b, maxVarInt8) + Expect(b.Bytes()).To(Equal([]byte{0xff /* 11111111 */, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})) + }) + + It("panics when given a too large number (> 62 bit)", func() { + Expect(func() { Write(&bytes.Buffer{}, maxVarInt8+1) }).Should(Panic()) + }) + }) + + Context("with fixed length", func() { + It("panics when given an invalid length", func() { + Expect(func() { WriteWithLen(&bytes.Buffer{}, 25, 3) }).Should(Panic()) + }) + + It("panics when given a too short length", func() { + Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt1+1, 1) }).Should(Panic()) + Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt2+1, 2) }).Should(Panic()) + Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt4+1, 4) }).Should(Panic()) + }) + + It("writes a 1-byte number in minimal encoding", func() { + b := &bytes.Buffer{} + WriteWithLen(b, 37, 1) + Expect(b.Bytes()).To(Equal([]byte{0x25})) + }) + + It("writes a 1-byte number in 2 bytes", func() { + b := &bytes.Buffer{} + WriteWithLen(b, 37, 2) + Expect(b.Bytes()).To(Equal([]byte{0b01000000, 0x25})) + Expect(Read(b)).To(BeEquivalentTo(37)) + }) + + It("writes a 1-byte number in 4 bytes", func() { + b := &bytes.Buffer{} + WriteWithLen(b, 37, 4) + Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0, 0x25})) + Expect(Read(b)).To(BeEquivalentTo(37)) + }) + + It("writes a 1-byte number in 8 bytes", func() { + b := &bytes.Buffer{} + WriteWithLen(b, 37, 8) + Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0, 0, 0, 0x25})) + Expect(Read(b)).To(BeEquivalentTo(37)) + }) + + It("writes a 2-byte number in 4 bytes", func() { + b := &bytes.Buffer{} + WriteWithLen(b, 15293, 4) + Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0x3b, 0xbd})) + Expect(Read(b)).To(BeEquivalentTo(15293)) + }) + + It("write a 4-byte number in 8 bytes", func() { + b := &bytes.Buffer{} + WriteWithLen(b, 494878333, 8) + Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0x1d, 0x7f, 0x3e, 0x7d})) + Expect(Read(b)).To(BeEquivalentTo(494878333)) + }) + }) + }) + + Context("determining the length needed for encoding", func() { + It("for numbers that need 1 byte", func() { + Expect(Len(0)).To(BeEquivalentTo(1)) + Expect(Len(maxVarInt1)).To(BeEquivalentTo(1)) + }) + + It("for numbers that need 2 bytes", func() { + Expect(Len(maxVarInt1 + 1)).To(BeEquivalentTo(2)) + Expect(Len(maxVarInt2)).To(BeEquivalentTo(2)) + }) + + It("for numbers that need 4 bytes", func() { + Expect(Len(maxVarInt2 + 1)).To(BeEquivalentTo(4)) + Expect(Len(maxVarInt4)).To(BeEquivalentTo(4)) + }) + + It("for numbers that need 8 bytes", func() { + Expect(Len(maxVarInt4 + 1)).To(BeEquivalentTo(8)) + Expect(Len(maxVarInt8)).To(BeEquivalentTo(8)) + }) + + It("panics when given a too large number (> 62 bit)", func() { + Expect(func() { Len(maxVarInt8 + 1) }).Should(Panic()) + }) + }) +}) diff --git a/internal/testdata/ca.pem b/internal/testdata/ca.pem new file mode 100644 index 00000000..67a5545e --- /dev/null +++ b/internal/testdata/ca.pem @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE----- +MIICzDCCAbQCCQDA+rLymNnfJzANBgkqhkiG9w0BAQsFADAoMSYwJAYDVQQKDB1x +dWljLWdvIENlcnRpZmljYXRlIEF1dGhvcml0eTAeFw0yMDA4MTgwOTIxMzVaFw0z +MDA4MTYwOTIxMzVaMCgxJjAkBgNVBAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0 +aG9yaXR5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1OcsYrVaSDfh +iDppl6oteVspOY3yFb96T9Y/biaGPJAkBO9VGKcqwOUPmUeiWpedRAUB9LE7Srs6 +qBX4mnl90Icjp8jbIs5cPgIWLkIu8Qm549RghFzB3bn+EmCQSe4cxvyDMN3ndClp +3YMXpZgXWgJGiPOylVi/OwHDdWDBorw4hvry+6yDtpQo2TuI2A/xtxXPT7BgsEJD +WGffdgZOYXChcFA0c1XVLIYlu2w2JhxS8c2TUF6uSDlmcoONNKVoiNCuu1Z9MorS +Qmg7a2G7dSPu123KcTcSQFcmJrt+1G81gOBtHB69kacD8xDmgksj09h/ODPL/gIU +1ZcU2ci1/QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQB0Tb1JbLXp/BvWovSAhO/j +wG7UEaUA1rCtkDB+fV2HS9bxCbV5eErdg8AMHKgB51ygUrq95vm/baZmUILr84XK +uTEoxxrw5S9Z7SrhtbOpKCumoSeTsCPjDvCcwFExHv4XHFk+CPqZwbMHueVIMT0+ +nGWss/KecCPdJLdnUgMRz0tIuXzkoRuOiUiZfUeyBNVNbDFSrLigYshTeAPGaYjX +CypoHxkeS93nWfOMUu8FTYLYkvGMU5i076zDoFGKJiEtbjSiNW+Hei7u2aSEuCzp +qyTKzYPWYffAq3MM2MKJgZdL04e9GEGeuce/qhM1o3q77aI/XJImwEDdut2LDec1 +-----END CERTIFICATE----- diff --git a/internal/testdata/cert.go b/internal/testdata/cert.go new file mode 100644 index 00000000..f862b0cb --- /dev/null +++ b/internal/testdata/cert.go @@ -0,0 +1,55 @@ +package testdata + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "path" + "runtime" +) + +var certPath string + +func init() { + _, filename, _, ok := runtime.Caller(0) + if !ok { + panic("Failed to get current frame") + } + + certPath = path.Dir(filename) +} + +// GetCertificatePaths returns the paths to certificate and key +func GetCertificatePaths() (string, string) { + return path.Join(certPath, "cert.pem"), path.Join(certPath, "priv.key") +} + +// GetTLSConfig returns a tls config for quic.clemente.io +func GetTLSConfig() *tls.Config { + cert, err := tls.LoadX509KeyPair(GetCertificatePaths()) + if err != nil { + panic(err) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + } +} + +// AddRootCA adds the root CA certificate to a cert pool +func AddRootCA(certPool *x509.CertPool) { + caCertPath := path.Join(certPath, "ca.pem") + caCertRaw, err := ioutil.ReadFile(caCertPath) + if err != nil { + panic(err) + } + if ok := certPool.AppendCertsFromPEM(caCertRaw); !ok { + panic("Could not add root ceritificate to pool.") + } +} + +// GetRootCA returns an x509.CertPool containing (only) the CA certificate +func GetRootCA() *x509.CertPool { + pool := x509.NewCertPool() + AddRootCA(pool) + return pool +} diff --git a/internal/testdata/cert.pem b/internal/testdata/cert.pem new file mode 100644 index 00000000..91d1aa9e --- /dev/null +++ b/internal/testdata/cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC1TCCAb2gAwIBAgIJAK2fcqC0BVA7MA0GCSqGSIb3DQEBCwUAMCgxJjAkBgNV +BAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0aG9yaXR5MB4XDTIwMDgxODA5MjEz +NVoXDTMwMDgxNjA5MjEzNVowEjEQMA4GA1UECgwHcXVpYy1nbzCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAN/YwrigSXdJCL/bdBGhb0UpqtU8H+krV870 ++w1yCSykLImH8x3qHZEXt9sr/vgjcJoV6Z15RZmnbEqnAx84sIClIBoIgnk0VPxu +WF+/U/dElbftCfYcfJAddhRckdmGB+yb3Wogb32UJ+q3my++h6NjHsYb+OwpJPnQ +meXjOE7Kkf+bXfFywHF3R8kzVdh5JUFYeKbxYmYgxRps1YTsbCrZCrSy1CbQ9FJw +Wg5C8t+7yvVFmOeWPECypBCz2xS2mu+kycMNIjIWMl0SL7oVM5cBkRKPeVIG/KcM +i5+/4lRSLoPh0Txh2TKBWfpzLbIOdPU8/O7cAukIGWx0XsfHUQMCAwEAAaMYMBYw +FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQAyxxvebdMz +shp5pt1SxMOSXbo8sTa1cpaf2rTmb4nxjXs6KPBEn53hSBz9bhe5wXE4f94SHadf +636rLh3d75KgrLUwO9Yq0HfCxMo1jUV/Ug++XwcHCI9vk58Tk/H4hqEM6C8RrdTj +fYeuegQ0/oNLJ4uTw2P2A8TJbL6FC2dcICEAvUGZUcVyZ8m8tHXNRYYh6MZ7ubCh +hinvL+AA5fY6EVlc5G/P4DN6fYxGn1cFNbiL4uZP4+W3dOmP+NV0YV9ihTyMzz0R +vSoOZ9FeVkyw8EhMb3LoyXYKazvJy2VQST1ltzAGit9RiM1Gv4vuna74WsFzrn1U +A/TbaR0ih/qG +-----END CERTIFICATE----- diff --git a/internal/testdata/cert_test.go b/internal/testdata/cert_test.go new file mode 100644 index 00000000..0de1bd7b --- /dev/null +++ b/internal/testdata/cert_test.go @@ -0,0 +1,31 @@ +package testdata + +import ( + "crypto/tls" + "io/ioutil" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("certificates", func() { + It("returns certificates", func() { + ln, err := tls.Listen("tcp", "localhost:4433", GetTLSConfig()) + Expect(err).ToNot(HaveOccurred()) + + go func() { + defer GinkgoRecover() + conn, err := ln.Accept() + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + _, err = conn.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + }() + + conn, err := tls.Dial("tcp", "localhost:4433", &tls.Config{RootCAs: GetRootCA()}) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(conn) + Expect(err).ToNot(HaveOccurred()) + Expect(string(data)).To(Equal("foobar")) + }) +}) diff --git a/internal/testdata/generate_key.sh b/internal/testdata/generate_key.sh new file mode 100755 index 00000000..7ecaa966 --- /dev/null +++ b/internal/testdata/generate_key.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e + +echo "Generating CA key and certificate:" +openssl req -x509 -sha256 -nodes -days 3650 -newkey rsa:2048 \ + -keyout ca.key -out ca.pem \ + -subj "/O=quic-go Certificate Authority/" + +echo "Generating CSR" +openssl req -out cert.csr -new -newkey rsa:2048 -nodes -keyout priv.key \ + -subj "/O=quic-go/" + +echo "Sign certificate:" +openssl x509 -req -sha256 -days 3650 -in cert.csr -out cert.pem \ + -CA ca.pem -CAkey ca.key -CAcreateserial \ + -extfile <(printf "subjectAltName=DNS:localhost") + +# debug output the certificate +openssl x509 -noout -text -in cert.pem + +# we don't need the CA key, the serial number and the CSR any more +rm ca.key cert.csr ca.srl + diff --git a/internal/testdata/priv.key b/internal/testdata/priv.key new file mode 100644 index 00000000..56b8d894 --- /dev/null +++ b/internal/testdata/priv.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDf2MK4oEl3SQi/ +23QRoW9FKarVPB/pK1fO9PsNcgkspCyJh/Md6h2RF7fbK/74I3CaFemdeUWZp2xK +pwMfOLCApSAaCIJ5NFT8blhfv1P3RJW37Qn2HHyQHXYUXJHZhgfsm91qIG99lCfq +t5svvoejYx7GG/jsKST50Jnl4zhOypH/m13xcsBxd0fJM1XYeSVBWHim8WJmIMUa +bNWE7Gwq2Qq0stQm0PRScFoOQvLfu8r1RZjnljxAsqQQs9sUtprvpMnDDSIyFjJd +Ei+6FTOXAZESj3lSBvynDIufv+JUUi6D4dE8YdkygVn6cy2yDnT1PPzu3ALpCBls +dF7Hx1EDAgMBAAECggEBAMm+mLDBdbUWk9YmuZNyRdC13wvT5obF05vo26OglXgw +dxt09b6OVBuCnuff3SpS9pdJDIYq2HnFlSorH/sxopIvQKF17fHDIp1n7ipNTCXd +IHrmHkY8Il/YzaVIUQMVc2rih0mw9greTqOS20DKnYC6QvAWIeDmrDaitTGl+ge3 +hm7e2lsgZi13R6fTNwQs9geEQSGzP2k7bFceHQFDChOYiQraR5+VZZ8S8AMGjk47 +AUa5EsKeUe6O9t2xuDSFxzYz5eadOAiErKGDos5KXXr3VQgFcC8uPEFFjcJ/yl+8 +tOe4iLeVwGSDJhTAThdR2deJOjaDcarWM7ixmxA3DAECgYEA/WVwmY4gWKwv49IJ +Jnh1Gu93P772GqliMNpukdjTI+joQxfl4jRSt2hk4b1KRwyT9aaKfvdz0HFlXo/r +9NVSAYT3/3vbcw61bfvPhhtz44qRAAKua6b5cUM6XqxVt1hqdP8lrf/blvA5ln+u +O51S8+wpxZMuqKz/29zdWSG6tAMCgYEA4iWXMXX9dZajI6abVkWwuosvOakXdLk4 +tUy7zd+JPF7hmUzzj2gtg4hXoiQPAOi+GY3TX+1Nza3s1LD7iWaXSKeOWvvligw9 +Q/wVTNW2P1+tdhScJf9QudzW69xOm5HNBgx9uWV2cHfjC12vg5aTH0k5axvaq15H +9WBXlH5q3wECgYBYoYGYBDFmMpvxmMagkSOMz1OrlVSpkLOKmOxx0SBRACc1SIec +7mY8RqR6nOX9IfYixyTMMittLiyhvb9vfKnZZDQGRcFFZlCpbplws+t+HDqJgWaW +uumm5zfkY2z7204pLBF24fZhvha2gGRl76pTLTiTJd79Gr3HnmJByd1vFwKBgHL7 +vfYuEeM55lT4Hz8sTAFtR2O/7+cvTgAQteSlZbfGXlp939DonUulhTkxsFc7/3wq +unCpzcdoSWSTYDGqcf1FBIKKVVltg7EPeR0KBJIQabgCHqrLOBZojPZ7m5RJ+765 +lysuxZvFuTFMPzNe2gssRf+JuBMt6tR+WclsxZYBAoGAEEFs1ppDil1xlP5rdH7T +d3TSw/u4eU/X8Ei1zi25hdRUiV76fP9fBELYFmSrPBhugYv91vtSv/LmD4zLfLv/ +yzwAD9j1lGbgM8Of8klCkk+XSJ88ryUwnMTJ5loQJW8t4L+zLv5Le7Ca9SAT0kJ1 +jT0GzDymgLMGp8RPdBkpk+w= +-----END PRIVATE KEY----- diff --git a/internal/testdata/testdata_suite_test.go b/internal/testdata/testdata_suite_test.go new file mode 100644 index 00000000..4e9011cf --- /dev/null +++ b/internal/testdata/testdata_suite_test.go @@ -0,0 +1,13 @@ +package testdata + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestTestdata(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Testdata Suite") +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index f15195d2..b9e864ce 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -13,7 +13,6 @@ import ( type Interface interface { Proxy() func(*http.Request) (*url.URL, error) - Clone() Interface Debugf() func(format string, v ...interface{}) SetDebugf(func(format string, v ...interface{})) DisableCompression() bool @@ -22,7 +21,6 @@ type Interface interface { TLSHandshakeTimeout() time.Duration DialContext() func(ctx context.Context, network, addr string) (net.Conn, error) DialTLSContext() func(ctx context.Context, network, addr string) (net.Conn, error) - RegisterProtocol(scheme string, rt http.RoundTripper) DisableKeepAlives() bool Dump() *dump.Dumper diff --git a/internal/utils/atomic_bool.go b/internal/utils/atomic_bool.go new file mode 100644 index 00000000..cf464250 --- /dev/null +++ b/internal/utils/atomic_bool.go @@ -0,0 +1,22 @@ +package utils + +import "sync/atomic" + +// An AtomicBool is an atomic bool +type AtomicBool struct { + v int32 +} + +// Set sets the value +func (a *AtomicBool) Set(value bool) { + var n int32 + if value { + n = 1 + } + atomic.StoreInt32(&a.v, n) +} + +// Get gets the value +func (a *AtomicBool) Get() bool { + return atomic.LoadInt32(&a.v) != 0 +} diff --git a/internal/utils/atomic_bool_test.go b/internal/utils/atomic_bool_test.go new file mode 100644 index 00000000..83a200c2 --- /dev/null +++ b/internal/utils/atomic_bool_test.go @@ -0,0 +1,29 @@ +package utils + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Atomic Bool", func() { + var a *AtomicBool + + BeforeEach(func() { + a = &AtomicBool{} + }) + + It("has the right default value", func() { + Expect(a.Get()).To(BeFalse()) + }) + + It("sets the value to true", func() { + a.Set(true) + Expect(a.Get()).To(BeTrue()) + }) + + It("sets the value to false", func() { + a.Set(true) + a.Set(false) + Expect(a.Get()).To(BeFalse()) + }) +}) diff --git a/internal/utils/buffered_write_closer.go b/internal/utils/buffered_write_closer.go new file mode 100644 index 00000000..b5b9d6fc --- /dev/null +++ b/internal/utils/buffered_write_closer.go @@ -0,0 +1,26 @@ +package utils + +import ( + "bufio" + "io" +) + +type bufferedWriteCloser struct { + *bufio.Writer + io.Closer +} + +// NewBufferedWriteCloser creates an io.WriteCloser from a bufio.Writer and an io.Closer +func NewBufferedWriteCloser(writer *bufio.Writer, closer io.Closer) io.WriteCloser { + return &bufferedWriteCloser{ + Writer: writer, + Closer: closer, + } +} + +func (h bufferedWriteCloser) Close() error { + if err := h.Writer.Flush(); err != nil { + return err + } + return h.Closer.Close() +} diff --git a/internal/utils/buffered_write_closer_test.go b/internal/utils/buffered_write_closer_test.go new file mode 100644 index 00000000..9c93d615 --- /dev/null +++ b/internal/utils/buffered_write_closer_test.go @@ -0,0 +1,26 @@ +package utils + +import ( + "bufio" + "bytes" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type nopCloser struct{} + +func (nopCloser) Close() error { return nil } + +var _ = Describe("buffered io.WriteCloser", func() { + It("flushes before closing", func() { + buf := &bytes.Buffer{} + + w := bufio.NewWriter(buf) + wc := NewBufferedWriteCloser(w, &nopCloser{}) + wc.Write([]byte("foobar")) + Expect(buf.Len()).To(BeZero()) + Expect(wc.Close()).To(Succeed()) + Expect(buf.String()).To(Equal("foobar")) + }) +}) diff --git a/internal/utils/byteinterval_linkedlist.go b/internal/utils/byteinterval_linkedlist.go new file mode 100644 index 00000000..096023ef --- /dev/null +++ b/internal/utils/byteinterval_linkedlist.go @@ -0,0 +1,217 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package utils + +// Linked list implementation from the Go standard library. + +// ByteIntervalElement is an element of a linked list. +type ByteIntervalElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *ByteIntervalElement + + // The list to which this element belongs. + list *ByteIntervalList + + // The value stored with this element. + Value ByteInterval +} + +// Next returns the next list element or nil. +func (e *ByteIntervalElement) Next() *ByteIntervalElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *ByteIntervalElement) Prev() *ByteIntervalElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// ByteIntervalList is a linked list of ByteIntervals. +type ByteIntervalList struct { + root ByteIntervalElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *ByteIntervalList) Init() *ByteIntervalList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewByteIntervalList returns an initialized list. +func NewByteIntervalList() *ByteIntervalList { return new(ByteIntervalList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *ByteIntervalList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *ByteIntervalList) Front() *ByteIntervalElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *ByteIntervalList) Back() *ByteIntervalElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *ByteIntervalList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *ByteIntervalList) insert(e, at *ByteIntervalElement) *ByteIntervalElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *ByteIntervalList) insertValue(v ByteInterval, at *ByteIntervalElement) *ByteIntervalElement { + return l.insert(&ByteIntervalElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *ByteIntervalList) remove(e *ByteIntervalElement) *ByteIntervalElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *ByteIntervalList) Remove(e *ByteIntervalElement) ByteInterval { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *ByteIntervalList) PushFront(v ByteInterval) *ByteIntervalElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *ByteIntervalList) MoveToFront(e *ByteIntervalElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *ByteIntervalList) MoveToBack(e *ByteIntervalElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *ByteIntervalList) PushFrontList(other *ByteIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/internal/utils/byteoder_big_endian_test.go b/internal/utils/byteoder_big_endian_test.go new file mode 100644 index 00000000..5d0873a9 --- /dev/null +++ b/internal/utils/byteoder_big_endian_test.go @@ -0,0 +1,107 @@ +package utils + +import ( + "bytes" + "io" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Big Endian encoding / decoding", func() { + Context("ReadUint16", func() { + It("reads a big endian", func() { + b := []byte{0x13, 0xEF} + val, err := BigEndian.ReadUint16(bytes.NewReader(b)) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint16(0x13EF))) + }) + + It("throws an error if less than 2 bytes are passed", func() { + b := []byte{0x13, 0xEF} + for i := 0; i < len(b); i++ { + _, err := BigEndian.ReadUint16(bytes.NewReader(b[:i])) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("ReadUint24", func() { + It("reads a big endian", func() { + b := []byte{0x13, 0xbe, 0xef} + val, err := BigEndian.ReadUint24(bytes.NewReader(b)) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint32(0x13beef))) + }) + + It("throws an error if less than 3 bytes are passed", func() { + b := []byte{0x13, 0xbe, 0xef} + for i := 0; i < len(b); i++ { + _, err := BigEndian.ReadUint24(bytes.NewReader(b[:i])) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("ReadUint32", func() { + It("reads a big endian", func() { + b := []byte{0x12, 0x35, 0xAB, 0xFF} + val, err := BigEndian.ReadUint32(bytes.NewReader(b)) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint32(0x1235ABFF))) + }) + + It("throws an error if less than 4 bytes are passed", func() { + b := []byte{0x12, 0x35, 0xAB, 0xFF} + for i := 0; i < len(b); i++ { + _, err := BigEndian.ReadUint32(bytes.NewReader(b[:i])) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("WriteUint16", func() { + It("outputs 2 bytes", func() { + b := &bytes.Buffer{} + BigEndian.WriteUint16(b, uint16(1)) + Expect(b.Len()).To(Equal(2)) + }) + + It("outputs a big endian", func() { + num := uint16(0xFF11) + b := &bytes.Buffer{} + BigEndian.WriteUint16(b, num) + Expect(b.Bytes()).To(Equal([]byte{0xFF, 0x11})) + }) + }) + + Context("WriteUint24", func() { + It("outputs 3 bytes", func() { + b := &bytes.Buffer{} + BigEndian.WriteUint24(b, uint32(1)) + Expect(b.Len()).To(Equal(3)) + }) + + It("outputs a big endian", func() { + num := uint32(0xff11aa) + b := &bytes.Buffer{} + BigEndian.WriteUint24(b, num) + Expect(b.Bytes()).To(Equal([]byte{0xff, 0x11, 0xaa})) + }) + }) + + Context("WriteUint32", func() { + It("outputs 4 bytes", func() { + b := &bytes.Buffer{} + BigEndian.WriteUint32(b, uint32(1)) + Expect(b.Len()).To(Equal(4)) + }) + + It("outputs a big endian", func() { + num := uint32(0xEFAC3512) + b := &bytes.Buffer{} + BigEndian.WriteUint32(b, num) + Expect(b.Bytes()).To(Equal([]byte{0xEF, 0xAC, 0x35, 0x12})) + }) + }) +}) diff --git a/internal/utils/byteorder.go b/internal/utils/byteorder.go new file mode 100644 index 00000000..d1f52842 --- /dev/null +++ b/internal/utils/byteorder.go @@ -0,0 +1,17 @@ +package utils + +import ( + "bytes" + "io" +) + +// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. +type ByteOrder interface { + ReadUint32(io.ByteReader) (uint32, error) + ReadUint24(io.ByteReader) (uint32, error) + ReadUint16(io.ByteReader) (uint16, error) + + WriteUint32(*bytes.Buffer, uint32) + WriteUint24(*bytes.Buffer, uint32) + WriteUint16(*bytes.Buffer, uint16) +} diff --git a/internal/utils/byteorder_big_endian.go b/internal/utils/byteorder_big_endian.go new file mode 100644 index 00000000..d05542e1 --- /dev/null +++ b/internal/utils/byteorder_big_endian.go @@ -0,0 +1,89 @@ +package utils + +import ( + "bytes" + "io" +) + +// BigEndian is the big-endian implementation of ByteOrder. +var BigEndian ByteOrder = bigEndian{} + +type bigEndian struct{} + +var _ ByteOrder = &bigEndian{} + +// ReadUintN reads N bytes +func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) { + var res uint64 + for i := uint8(0); i < length; i++ { + bt, err := b.ReadByte() + if err != nil { + return 0, err + } + res ^= uint64(bt) << ((length - 1 - i) * 8) + } + return res, nil +} + +// ReadUint32 reads a uint32 +func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) { + var b1, b2, b3, b4 uint8 + var err error + if b4, err = b.ReadByte(); err != nil { + return 0, err + } + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil +} + +// ReadUint24 reads a uint24 +func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) { + var b1, b2, b3 uint8 + var err error + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16, nil +} + +// ReadUint16 reads a uint16 +func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) { + var b1, b2 uint8 + var err error + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint16(b1) + uint16(b2)<<8, nil +} + +// WriteUint32 writes a uint32 +func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) { + b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)}) +} + +// WriteUint24 writes a uint24 +func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) { + b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)}) +} + +// WriteUint16 writes a uint16 +func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) { + b.Write([]byte{uint8(i >> 8), uint8(i)}) +} diff --git a/internal/utils/gen.go b/internal/utils/gen.go new file mode 100644 index 00000000..8a63e958 --- /dev/null +++ b/internal/utils/gen.go @@ -0,0 +1,5 @@ +package utils + +//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out byteinterval_linkedlist.go gen Item=ByteInterval +//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out packetinterval_linkedlist.go gen Item=PacketInterval +//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out newconnectionid_linkedlist.go gen Item=NewConnectionID diff --git a/internal/utils/ip.go b/internal/utils/ip.go new file mode 100644 index 00000000..7ac7ffec --- /dev/null +++ b/internal/utils/ip.go @@ -0,0 +1,10 @@ +package utils + +import "net" + +func IsIPv4(ip net.IP) bool { + // If ip is not an IPv4 address, To4 returns nil. + // Note that there might be some corner cases, where this is not correct. + // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6. + return ip.To4() != nil +} diff --git a/internal/utils/ip_test.go b/internal/utils/ip_test.go new file mode 100644 index 00000000..b61cf529 --- /dev/null +++ b/internal/utils/ip_test.go @@ -0,0 +1,17 @@ +package utils + +import ( + "net" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("IP", func() { + It("tells IPv4 and IPv6 addresses apart", func() { + Expect(IsIPv4(net.IPv4(127, 0, 0, 1))).To(BeTrue()) + Expect(IsIPv4(net.IPv4zero)).To(BeTrue()) + Expect(IsIPv4(net.IPv6zero)).To(BeFalse()) + Expect(IsIPv4(net.IPv6loopback)).To(BeFalse()) + }) +}) diff --git a/internal/utils/linkedlist/README.md b/internal/utils/linkedlist/README.md new file mode 100644 index 00000000..15b46dce --- /dev/null +++ b/internal/utils/linkedlist/README.md @@ -0,0 +1,11 @@ +# Usage + +This is the Go standard library implementation of a linked list +(https://golang.org/src/container/list/list.go), modified such that genny +(https://github.com/cheekybits/genny) can be used to generate a typed linked +list. + +To generate, run +``` +genny -pkg $PACKAGE -in linkedlist.go -out $OUTFILE gen Item=$TYPE +``` diff --git a/internal/utils/linkedlist/linkedlist.go b/internal/utils/linkedlist/linkedlist.go new file mode 100644 index 00000000..74b815a8 --- /dev/null +++ b/internal/utils/linkedlist/linkedlist.go @@ -0,0 +1,218 @@ +package linkedlist + +import "github.com/cheekybits/genny/generic" + +// Linked list implementation from the Go standard library. + +// Item is a generic type. +type Item generic.Type + +// ItemElement is an element of a linked list. +type ItemElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *ItemElement + + // The list to which this element belongs. + list *ItemList + + // The value stored with this element. + Value Item +} + +// Next returns the next list element or nil. +func (e *ItemElement) Next() *ItemElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *ItemElement) Prev() *ItemElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// ItemList is a linked list of Items. +type ItemList struct { + root ItemElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *ItemList) Init() *ItemList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewItemList returns an initialized list. +func NewItemList() *ItemList { return new(ItemList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *ItemList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *ItemList) Front() *ItemElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *ItemList) Back() *ItemElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *ItemList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *ItemList) insert(e, at *ItemElement) *ItemElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *ItemList) insertValue(v Item, at *ItemElement) *ItemElement { + return l.insert(&ItemElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *ItemList) remove(e *ItemElement) *ItemElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *ItemList) Remove(e *ItemElement) Item { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *ItemList) PushFront(v Item) *ItemElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *ItemList) PushBack(v Item) *ItemElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *ItemList) InsertBefore(v Item, mark *ItemElement) *ItemElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *ItemList) InsertAfter(v Item, mark *ItemElement) *ItemElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *ItemList) MoveToFront(e *ItemElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *ItemList) MoveToBack(e *ItemElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *ItemList) MoveBefore(e, mark *ItemElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *ItemList) MoveAfter(e, mark *ItemElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *ItemList) PushBackList(other *ItemList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *ItemList) PushFrontList(other *ItemList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/internal/utils/log.go b/internal/utils/log.go new file mode 100644 index 00000000..e27f01b4 --- /dev/null +++ b/internal/utils/log.go @@ -0,0 +1,131 @@ +package utils + +import ( + "fmt" + "log" + "os" + "strings" + "time" +) + +// LogLevel of quic-go +type LogLevel uint8 + +const ( + // LogLevelNothing disables + LogLevelNothing LogLevel = iota + // LogLevelError enables err logs + LogLevelError + // LogLevelInfo enables info logs (e.g. packets) + LogLevelInfo + // LogLevelDebug enables debug logs (e.g. packet contents) + LogLevelDebug +) + +const logEnv = "QUIC_GO_LOG_LEVEL" + +// A Logger logs. +type Logger interface { + SetLogLevel(LogLevel) + SetLogTimeFormat(format string) + WithPrefix(prefix string) Logger + Debug() bool + + Errorf(format string, args ...interface{}) + Infof(format string, args ...interface{}) + Debugf(format string, args ...interface{}) +} + +// DefaultLogger is used by quic-go for logging. +var DefaultLogger Logger + +type defaultLogger struct { + prefix string + + logLevel LogLevel + timeFormat string +} + +var _ Logger = &defaultLogger{} + +// SetLogLevel sets the log level +func (l *defaultLogger) SetLogLevel(level LogLevel) { + l.logLevel = level +} + +// SetLogTimeFormat sets the format of the timestamp +// an empty string disables the logging of timestamps +func (l *defaultLogger) SetLogTimeFormat(format string) { + log.SetFlags(0) // disable timestamp logging done by the log package + l.timeFormat = format +} + +// Debugf logs something +func (l *defaultLogger) Debugf(format string, args ...interface{}) { + if l.logLevel == LogLevelDebug { + l.logMessage(format, args...) + } +} + +// Infof logs something +func (l *defaultLogger) Infof(format string, args ...interface{}) { + if l.logLevel >= LogLevelInfo { + l.logMessage(format, args...) + } +} + +// Errorf logs something +func (l *defaultLogger) Errorf(format string, args ...interface{}) { + if l.logLevel >= LogLevelError { + l.logMessage(format, args...) + } +} + +func (l *defaultLogger) logMessage(format string, args ...interface{}) { + var pre string + + if len(l.timeFormat) > 0 { + pre = time.Now().Format(l.timeFormat) + " " + } + if len(l.prefix) > 0 { + pre += l.prefix + " " + } + log.Printf(pre+format, args...) +} + +func (l *defaultLogger) WithPrefix(prefix string) Logger { + if len(l.prefix) > 0 { + prefix = l.prefix + " " + prefix + } + return &defaultLogger{ + logLevel: l.logLevel, + timeFormat: l.timeFormat, + prefix: prefix, + } +} + +// Debug returns true if the log level is LogLevelDebug +func (l *defaultLogger) Debug() bool { + return l.logLevel == LogLevelDebug +} + +func init() { + DefaultLogger = &defaultLogger{} + DefaultLogger.SetLogLevel(readLoggingEnv()) +} + +func readLoggingEnv() LogLevel { + switch strings.ToLower(os.Getenv(logEnv)) { + case "": + return LogLevelNothing + case "debug": + return LogLevelDebug + case "info": + return LogLevelInfo + case "error": + return LogLevelError + default: + fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging") + return LogLevelNothing + } +} diff --git a/internal/utils/log_test.go b/internal/utils/log_test.go new file mode 100644 index 00000000..36edc1cc --- /dev/null +++ b/internal/utils/log_test.go @@ -0,0 +1,144 @@ +package utils + +import ( + "bytes" + "log" + "os" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Log", func() { + var b *bytes.Buffer + + BeforeEach(func() { + b = &bytes.Buffer{} + log.SetOutput(b) + }) + + AfterEach(func() { + log.SetOutput(os.Stdout) + DefaultLogger.SetLogLevel(LogLevelNothing) + }) + + It("the log level has the correct numeric value", func() { + Expect(LogLevelNothing).To(BeEquivalentTo(0)) + Expect(LogLevelError).To(BeEquivalentTo(1)) + Expect(LogLevelInfo).To(BeEquivalentTo(2)) + Expect(LogLevelDebug).To(BeEquivalentTo(3)) + }) + + It("log level nothing", func() { + DefaultLogger.SetLogLevel(LogLevelNothing) + DefaultLogger.Debugf("debug") + DefaultLogger.Infof("info") + DefaultLogger.Errorf("err") + Expect(b.String()).To(BeEmpty()) + }) + + It("log level err", func() { + DefaultLogger.SetLogLevel(LogLevelError) + DefaultLogger.Debugf("debug") + DefaultLogger.Infof("info") + DefaultLogger.Errorf("err") + Expect(b.String()).To(ContainSubstring("err\n")) + Expect(b.String()).ToNot(ContainSubstring("info")) + Expect(b.String()).ToNot(ContainSubstring("debug")) + }) + + It("log level info", func() { + DefaultLogger.SetLogLevel(LogLevelInfo) + DefaultLogger.Debugf("debug") + DefaultLogger.Infof("info") + DefaultLogger.Errorf("err") + Expect(b.String()).To(ContainSubstring("err\n")) + Expect(b.String()).To(ContainSubstring("info\n")) + Expect(b.String()).ToNot(ContainSubstring("debug")) + }) + + It("log level debug", func() { + DefaultLogger.SetLogLevel(LogLevelDebug) + DefaultLogger.Debugf("debug") + DefaultLogger.Infof("info") + DefaultLogger.Errorf("err") + Expect(b.String()).To(ContainSubstring("err\n")) + Expect(b.String()).To(ContainSubstring("info\n")) + Expect(b.String()).To(ContainSubstring("debug\n")) + }) + + It("doesn't add a timestamp if the time format is empty", func() { + DefaultLogger.SetLogLevel(LogLevelDebug) + DefaultLogger.SetLogTimeFormat("") + DefaultLogger.Debugf("debug") + Expect(b.String()).To(Equal("debug\n")) + }) + + It("adds a timestamp", func() { + format := "Jan 2, 2006" + DefaultLogger.SetLogTimeFormat(format) + DefaultLogger.SetLogLevel(LogLevelInfo) + DefaultLogger.Infof("info") + t, err := time.Parse(format, b.String()[:b.Len()-6]) + Expect(err).ToNot(HaveOccurred()) + Expect(t).To(BeTemporally("~", time.Now(), 25*time.Hour)) + }) + + It("says whether debug is enabled", func() { + Expect(DefaultLogger.Debug()).To(BeFalse()) + DefaultLogger.SetLogLevel(LogLevelDebug) + Expect(DefaultLogger.Debug()).To(BeTrue()) + }) + + It("adds a prefix", func() { + DefaultLogger.SetLogLevel(LogLevelDebug) + prefixLogger := DefaultLogger.WithPrefix("prefix") + prefixLogger.Debugf("debug") + Expect(b.String()).To(ContainSubstring("prefix")) + Expect(b.String()).To(ContainSubstring("debug")) + }) + + It("adds multiple prefixes", func() { + DefaultLogger.SetLogLevel(LogLevelDebug) + prefixLogger := DefaultLogger.WithPrefix("prefix1") + prefixPrefixLogger := prefixLogger.WithPrefix("prefix2") + prefixPrefixLogger.Debugf("debug") + Expect(b.String()).To(ContainSubstring("prefix")) + Expect(b.String()).To(ContainSubstring("debug")) + }) + + Context("reading from env", func() { + BeforeEach(func() { + Expect(DefaultLogger.(*defaultLogger).logLevel).To(Equal(LogLevelNothing)) + }) + + It("reads DEBUG", func() { + os.Setenv(logEnv, "DEBUG") + Expect(readLoggingEnv()).To(Equal(LogLevelDebug)) + }) + + It("reads debug", func() { + os.Setenv(logEnv, "debug") + Expect(readLoggingEnv()).To(Equal(LogLevelDebug)) + }) + + It("reads INFO", func() { + os.Setenv(logEnv, "INFO") + readLoggingEnv() + Expect(readLoggingEnv()).To(Equal(LogLevelInfo)) + }) + + It("reads ERROR", func() { + os.Setenv(logEnv, "ERROR") + Expect(readLoggingEnv()).To(Equal(LogLevelError)) + }) + + It("does not error reading invalid log levels from env", func() { + os.Setenv(logEnv, "") + Expect(readLoggingEnv()).To(Equal(LogLevelNothing)) + os.Setenv(logEnv, "asdf") + Expect(readLoggingEnv()).To(Equal(LogLevelNothing)) + }) + }) +}) diff --git a/internal/utils/minmax.go b/internal/utils/minmax.go new file mode 100644 index 00000000..c634aa22 --- /dev/null +++ b/internal/utils/minmax.go @@ -0,0 +1,170 @@ +package utils + +import ( + "math" + "time" + + "github.com/imroc/req/v3/internal/protocol" +) + +// InfDuration is a duration of infinite length +const InfDuration = time.Duration(math.MaxInt64) + +// Max returns the maximum of two Ints +func Max(a, b int) int { + if a < b { + return b + } + return a +} + +// MaxUint32 returns the maximum of two uint32 +func MaxUint32(a, b uint32) uint32 { + if a < b { + return b + } + return a +} + +// MaxUint64 returns the maximum of two uint64 +func MaxUint64(a, b uint64) uint64 { + if a < b { + return b + } + return a +} + +// MinUint64 returns the maximum of two uint64 +func MinUint64(a, b uint64) uint64 { + if a < b { + return a + } + return b +} + +// Min returns the minimum of two Ints +func Min(a, b int) int { + if a < b { + return a + } + return b +} + +// MinUint32 returns the maximum of two uint32 +func MinUint32(a, b uint32) uint32 { + if a < b { + return a + } + return b +} + +// MinInt64 returns the minimum of two int64 +func MinInt64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +// MaxInt64 returns the minimum of two int64 +func MaxInt64(a, b int64) int64 { + if a > b { + return a + } + return b +} + +// MinByteCount returns the minimum of two ByteCounts +func MinByteCount(a, b protocol.ByteCount) protocol.ByteCount { + if a < b { + return a + } + return b +} + +// MaxByteCount returns the maximum of two ByteCounts +func MaxByteCount(a, b protocol.ByteCount) protocol.ByteCount { + if a < b { + return b + } + return a +} + +// MaxDuration returns the max duration +func MaxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} + +// MinDuration returns the minimum duration +func MinDuration(a, b time.Duration) time.Duration { + if a > b { + return b + } + return a +} + +// MinNonZeroDuration return the minimum duration that's not zero. +func MinNonZeroDuration(a, b time.Duration) time.Duration { + if a == 0 { + return b + } + if b == 0 { + return a + } + return MinDuration(a, b) +} + +// AbsDuration returns the absolute value of a time duration +func AbsDuration(d time.Duration) time.Duration { + if d >= 0 { + return d + } + return -d +} + +// MinTime returns the earlier time +func MinTime(a, b time.Time) time.Time { + if a.After(b) { + return b + } + return a +} + +// MinNonZeroTime returns the earlist time that is not time.Time{} +// If both a and b are time.Time{}, it returns time.Time{} +func MinNonZeroTime(a, b time.Time) time.Time { + if a.IsZero() { + return b + } + if b.IsZero() { + return a + } + return MinTime(a, b) +} + +// MaxTime returns the later time +func MaxTime(a, b time.Time) time.Time { + if a.After(b) { + return a + } + return b +} + +// MaxPacketNumber returns the max packet number +func MaxPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { + if a > b { + return a + } + return b +} + +// MinPacketNumber returns the min packet number +func MinPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { + if a < b { + return a + } + return b +} diff --git a/internal/utils/minmax_test.go b/internal/utils/minmax_test.go new file mode 100644 index 00000000..021212b6 --- /dev/null +++ b/internal/utils/minmax_test.go @@ -0,0 +1,123 @@ +package utils + +import ( + "time" + + "github.com/imroc/req/v3/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Min / Max", func() { + Context("Max", func() { + It("returns the maximum", func() { + Expect(Max(5, 7)).To(Equal(7)) + Expect(Max(7, 5)).To(Equal(7)) + }) + + It("returns the maximum uint32", func() { + Expect(MaxUint32(5, 7)).To(Equal(uint32(7))) + Expect(MaxUint32(7, 5)).To(Equal(uint32(7))) + }) + + It("returns the maximum uint64", func() { + Expect(MaxUint64(5, 7)).To(Equal(uint64(7))) + Expect(MaxUint64(7, 5)).To(Equal(uint64(7))) + }) + + It("returns the minimum uint64", func() { + Expect(MinUint64(5, 7)).To(Equal(uint64(5))) + Expect(MinUint64(7, 5)).To(Equal(uint64(5))) + }) + + It("returns the maximum int64", func() { + Expect(MaxInt64(5, 7)).To(Equal(int64(7))) + Expect(MaxInt64(7, 5)).To(Equal(int64(7))) + }) + + It("returns the maximum ByteCount", func() { + Expect(MaxByteCount(7, 5)).To(Equal(protocol.ByteCount(7))) + Expect(MaxByteCount(5, 7)).To(Equal(protocol.ByteCount(7))) + }) + + It("returns the maximum duration", func() { + Expect(MaxDuration(time.Microsecond, time.Nanosecond)).To(Equal(time.Microsecond)) + Expect(MaxDuration(time.Nanosecond, time.Microsecond)).To(Equal(time.Microsecond)) + }) + + It("returns the minimum duration", func() { + Expect(MinDuration(time.Microsecond, time.Nanosecond)).To(Equal(time.Nanosecond)) + Expect(MinDuration(time.Nanosecond, time.Microsecond)).To(Equal(time.Nanosecond)) + }) + + It("returns packet number max", func() { + Expect(MaxPacketNumber(1, 2)).To(Equal(protocol.PacketNumber(2))) + Expect(MaxPacketNumber(2, 1)).To(Equal(protocol.PacketNumber(2))) + }) + + It("returns the maximum time", func() { + a := time.Now() + b := a.Add(time.Second) + Expect(MaxTime(a, b)).To(Equal(b)) + Expect(MaxTime(b, a)).To(Equal(b)) + }) + }) + + Context("Min", func() { + It("returns the minimum", func() { + Expect(Min(5, 7)).To(Equal(5)) + Expect(Min(7, 5)).To(Equal(5)) + }) + + It("returns the minimum uint32", func() { + Expect(MinUint32(7, 5)).To(Equal(uint32(5))) + Expect(MinUint32(5, 7)).To(Equal(uint32(5))) + }) + + It("returns the minimum int64", func() { + Expect(MinInt64(7, 5)).To(Equal(int64(5))) + Expect(MinInt64(5, 7)).To(Equal(int64(5))) + }) + + It("returns the minimum ByteCount", func() { + Expect(MinByteCount(7, 5)).To(Equal(protocol.ByteCount(5))) + Expect(MinByteCount(5, 7)).To(Equal(protocol.ByteCount(5))) + }) + + It("returns packet number min", func() { + Expect(MinPacketNumber(1, 2)).To(Equal(protocol.PacketNumber(1))) + Expect(MinPacketNumber(2, 1)).To(Equal(protocol.PacketNumber(1))) + }) + + It("returns the minimum duration", func() { + a := time.Now() + b := a.Add(time.Second) + Expect(MinTime(a, b)).To(Equal(a)) + Expect(MinTime(b, a)).To(Equal(a)) + }) + + It("returns the minium non-zero duration", func() { + var a time.Duration + b := time.Second + Expect(MinNonZeroDuration(0, 0)).To(BeZero()) + Expect(MinNonZeroDuration(a, b)).To(Equal(b)) + Expect(MinNonZeroDuration(b, a)).To(Equal(b)) + Expect(MinNonZeroDuration(time.Minute, time.Hour)).To(Equal(time.Minute)) + }) + + It("returns the minium non-zero time", func() { + a := time.Time{} + b := time.Now() + Expect(MinNonZeroTime(time.Time{}, time.Time{})).To(Equal(time.Time{})) + Expect(MinNonZeroTime(a, b)).To(Equal(b)) + Expect(MinNonZeroTime(b, a)).To(Equal(b)) + Expect(MinNonZeroTime(b, b.Add(time.Second))).To(Equal(b)) + Expect(MinNonZeroTime(b.Add(time.Second), b)).To(Equal(b)) + }) + }) + + It("returns the abs time", func() { + Expect(AbsDuration(time.Microsecond)).To(Equal(time.Microsecond)) + Expect(AbsDuration(-time.Microsecond)).To(Equal(time.Microsecond)) + }) +}) diff --git a/internal/utils/new_connection_id.go b/internal/utils/new_connection_id.go new file mode 100644 index 00000000..f758d63d --- /dev/null +++ b/internal/utils/new_connection_id.go @@ -0,0 +1,12 @@ +package utils + +import ( + "github.com/imroc/req/v3/internal/protocol" +) + +// NewConnectionID is a new connection ID +type NewConnectionID struct { + SequenceNumber uint64 + ConnectionID protocol.ConnectionID + StatelessResetToken protocol.StatelessResetToken +} diff --git a/internal/utils/newconnectionid_linkedlist.go b/internal/utils/newconnectionid_linkedlist.go new file mode 100644 index 00000000..d59562e5 --- /dev/null +++ b/internal/utils/newconnectionid_linkedlist.go @@ -0,0 +1,217 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package utils + +// Linked list implementation from the Go standard library. + +// NewConnectionIDElement is an element of a linked list. +type NewConnectionIDElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *NewConnectionIDElement + + // The list to which this element belongs. + list *NewConnectionIDList + + // The value stored with this element. + Value NewConnectionID +} + +// Next returns the next list element or nil. +func (e *NewConnectionIDElement) Next() *NewConnectionIDElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *NewConnectionIDElement) Prev() *NewConnectionIDElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// NewConnectionIDList is a linked list of NewConnectionIDs. +type NewConnectionIDList struct { + root NewConnectionIDElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *NewConnectionIDList) Init() *NewConnectionIDList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewNewConnectionIDList returns an initialized list. +func NewNewConnectionIDList() *NewConnectionIDList { return new(NewConnectionIDList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *NewConnectionIDList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *NewConnectionIDList) Front() *NewConnectionIDElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *NewConnectionIDList) Back() *NewConnectionIDElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *NewConnectionIDList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *NewConnectionIDList) insert(e, at *NewConnectionIDElement) *NewConnectionIDElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *NewConnectionIDList) insertValue(v NewConnectionID, at *NewConnectionIDElement) *NewConnectionIDElement { + return l.insert(&NewConnectionIDElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *NewConnectionIDList) remove(e *NewConnectionIDElement) *NewConnectionIDElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *NewConnectionIDList) Remove(e *NewConnectionIDElement) NewConnectionID { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *NewConnectionIDList) PushFront(v NewConnectionID) *NewConnectionIDElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *NewConnectionIDList) PushBack(v NewConnectionID) *NewConnectionIDElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *NewConnectionIDList) InsertBefore(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *NewConnectionIDList) InsertAfter(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *NewConnectionIDList) MoveToFront(e *NewConnectionIDElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *NewConnectionIDList) MoveToBack(e *NewConnectionIDElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *NewConnectionIDList) MoveBefore(e, mark *NewConnectionIDElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *NewConnectionIDList) MoveAfter(e, mark *NewConnectionIDElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *NewConnectionIDList) PushBackList(other *NewConnectionIDList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *NewConnectionIDList) PushFrontList(other *NewConnectionIDList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/internal/utils/packet_interval.go b/internal/utils/packet_interval.go new file mode 100644 index 00000000..ac2ca048 --- /dev/null +++ b/internal/utils/packet_interval.go @@ -0,0 +1,9 @@ +package utils + +import "github.com/imroc/req/v3/internal/protocol" + +// PacketInterval is an interval from one PacketNumber to the other +type PacketInterval struct { + Start protocol.PacketNumber + End protocol.PacketNumber +} diff --git a/internal/utils/packetinterval_linkedlist.go b/internal/utils/packetinterval_linkedlist.go new file mode 100644 index 00000000..b461e85a --- /dev/null +++ b/internal/utils/packetinterval_linkedlist.go @@ -0,0 +1,217 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package utils + +// Linked list implementation from the Go standard library. + +// PacketIntervalElement is an element of a linked list. +type PacketIntervalElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *PacketIntervalElement + + // The list to which this element belongs. + list *PacketIntervalList + + // The value stored with this element. + Value PacketInterval +} + +// Next returns the next list element or nil. +func (e *PacketIntervalElement) Next() *PacketIntervalElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *PacketIntervalElement) Prev() *PacketIntervalElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// PacketIntervalList is a linked list of PacketIntervals. +type PacketIntervalList struct { + root PacketIntervalElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *PacketIntervalList) Init() *PacketIntervalList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewPacketIntervalList returns an initialized list. +func NewPacketIntervalList() *PacketIntervalList { return new(PacketIntervalList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *PacketIntervalList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *PacketIntervalList) Front() *PacketIntervalElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *PacketIntervalList) Back() *PacketIntervalElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *PacketIntervalList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *PacketIntervalList) insert(e, at *PacketIntervalElement) *PacketIntervalElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *PacketIntervalList) insertValue(v PacketInterval, at *PacketIntervalElement) *PacketIntervalElement { + return l.insert(&PacketIntervalElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *PacketIntervalList) remove(e *PacketIntervalElement) *PacketIntervalElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *PacketIntervalList) Remove(e *PacketIntervalElement) PacketInterval { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *PacketIntervalList) PushFront(v PacketInterval) *PacketIntervalElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *PacketIntervalList) PushBack(v PacketInterval) *PacketIntervalElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *PacketIntervalList) InsertBefore(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *PacketIntervalList) InsertAfter(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *PacketIntervalList) MoveToFront(e *PacketIntervalElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *PacketIntervalList) MoveToBack(e *PacketIntervalElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *PacketIntervalList) PushFrontList(other *PacketIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/internal/utils/rand.go b/internal/utils/rand.go new file mode 100644 index 00000000..30069144 --- /dev/null +++ b/internal/utils/rand.go @@ -0,0 +1,29 @@ +package utils + +import ( + "crypto/rand" + "encoding/binary" +) + +// Rand is a wrapper around crypto/rand that adds some convenience functions known from math/rand. +type Rand struct { + buf [4]byte +} + +func (r *Rand) Int31() int32 { + rand.Read(r.buf[:]) + return int32(binary.BigEndian.Uint32(r.buf[:]) & ^uint32(1<<31)) +} + +// copied from the standard library math/rand implementation of Int63n +func (r *Rand) Int31n(n int32) int32 { + if n&(n-1) == 0 { // n is power of two, can mask + return r.Int31() & (n - 1) + } + max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) + v := r.Int31() + for v > max { + v = r.Int31() + } + return v % n +} diff --git a/internal/utils/rand_test.go b/internal/utils/rand_test.go new file mode 100644 index 00000000..f15a644e --- /dev/null +++ b/internal/utils/rand_test.go @@ -0,0 +1,32 @@ +package utils + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Rand", func() { + It("generates random numbers", func() { + const ( + num = 1000 + max = 12345678 + ) + + var values [num]int32 + var r Rand + for i := 0; i < num; i++ { + v := r.Int31n(max) + Expect(v).To(And( + BeNumerically(">=", 0), + BeNumerically("<", max), + )) + values[i] = v + } + + var sum uint64 + for _, n := range values { + sum += uint64(n) + } + Expect(float64(sum) / num).To(BeNumerically("~", max/2, max/25)) + }) +}) diff --git a/internal/utils/rtt_stats.go b/internal/utils/rtt_stats.go new file mode 100644 index 00000000..478d2596 --- /dev/null +++ b/internal/utils/rtt_stats.go @@ -0,0 +1,127 @@ +package utils + +import ( + "time" + + "github.com/imroc/req/v3/internal/protocol" +) + +const ( + rttAlpha = 0.125 + oneMinusAlpha = 1 - rttAlpha + rttBeta = 0.25 + oneMinusBeta = 1 - rttBeta + // The default RTT used before an RTT sample is taken. + defaultInitialRTT = 100 * time.Millisecond +) + +// RTTStats provides round-trip statistics +type RTTStats struct { + hasMeasurement bool + + minRTT time.Duration + latestRTT time.Duration + smoothedRTT time.Duration + meanDeviation time.Duration + + maxAckDelay time.Duration +} + +// NewRTTStats makes a properly initialized RTTStats object +func NewRTTStats() *RTTStats { + return &RTTStats{} +} + +// MinRTT Returns the minRTT for the entire connection. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) MinRTT() time.Duration { return r.minRTT } + +// LatestRTT returns the most recent rtt measurement. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT } + +// SmoothedRTT returns the smoothed RTT for the connection. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT } + +// MeanDeviation gets the mean deviation +func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } + +// MaxAckDelay gets the max_ack_delay advertised by the peer +func (r *RTTStats) MaxAckDelay() time.Duration { return r.maxAckDelay } + +// PTO gets the probe timeout duration. +func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { + if r.SmoothedRTT() == 0 { + return 2 * defaultInitialRTT + } + pto := r.SmoothedRTT() + MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) + if includeMaxAckDelay { + pto += r.MaxAckDelay() + } + return pto +} + +// UpdateRTT updates the RTT based on a new sample. +func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { + if sendDelta == InfDuration || sendDelta <= 0 { + return + } + + // Update r.minRTT first. r.minRTT does not use an rttSample corrected for + // ackDelay but the raw observed sendDelta, since poor clock granularity at + // the client may cause a high ackDelay to result in underestimation of the + // r.minRTT. + if r.minRTT == 0 || r.minRTT > sendDelta { + r.minRTT = sendDelta + } + + // Correct for ackDelay if information received from the peer results in a + // an RTT sample at least as large as minRTT. Otherwise, only use the + // sendDelta. + sample := sendDelta + if sample-r.minRTT >= ackDelay { + sample -= ackDelay + } + r.latestRTT = sample + // First time call. + if !r.hasMeasurement { + r.hasMeasurement = true + r.smoothedRTT = sample + r.meanDeviation = sample / 2 + } else { + r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond + r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond + } +} + +// SetMaxAckDelay sets the max_ack_delay +func (r *RTTStats) SetMaxAckDelay(mad time.Duration) { + r.maxAckDelay = mad +} + +// SetInitialRTT sets the initial RTT. +// It is used during the 0-RTT handshake when restoring the RTT stats from the session state. +func (r *RTTStats) SetInitialRTT(t time.Duration) { + if r.hasMeasurement { + panic("initial RTT set after first measurement") + } + r.smoothedRTT = t + r.latestRTT = t +} + +// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset. +func (r *RTTStats) OnConnectionMigration() { + r.latestRTT = 0 + r.minRTT = 0 + r.smoothedRTT = 0 + r.meanDeviation = 0 +} + +// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt +// is larger. The mean deviation is increased to the most recent deviation if +// it's larger. +func (r *RTTStats) ExpireSmoothedMetrics() { + r.meanDeviation = MaxDuration(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT)) + r.smoothedRTT = MaxDuration(r.smoothedRTT, r.latestRTT) +} diff --git a/internal/utils/rtt_stats_test.go b/internal/utils/rtt_stats_test.go new file mode 100644 index 00000000..555a8b8d --- /dev/null +++ b/internal/utils/rtt_stats_test.go @@ -0,0 +1,157 @@ +package utils + +import ( + "time" + + "github.com/imroc/req/v3/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("RTT stats", func() { + var rttStats *RTTStats + + BeforeEach(func() { + rttStats = NewRTTStats() + }) + + It("DefaultsBeforeUpdate", func() { + Expect(rttStats.MinRTT()).To(Equal(time.Duration(0))) + Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0))) + }) + + It("SmoothedRTT", func() { + // Verify that ack_delay is ignored in the first measurement. + rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond))) + Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond))) + // Verify that Smoothed RTT includes max ack delay if it's reasonable. + rttStats.UpdateRTT((350 * time.Millisecond), (50 * time.Millisecond), time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond))) + Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond))) + // Verify that large erroneous ack_delay does not change Smoothed RTT. + rttStats.UpdateRTT((200 * time.Millisecond), (300 * time.Millisecond), time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) + Expect(rttStats.SmoothedRTT()).To(Equal((287500 * time.Microsecond))) + }) + + It("MinRTT", func() { + rttStats.UpdateRTT((200 * time.Millisecond), 0, time.Time{}) + Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) + rttStats.UpdateRTT((10 * time.Millisecond), 0, time.Time{}.Add((10 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) + rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((20 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) + rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((30 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) + rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((40 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) + // Verify that ack_delay does not go into recording of MinRTT_. + rttStats.UpdateRTT((7 * time.Millisecond), (2 * time.Millisecond), time.Time{}.Add((50 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((7 * time.Millisecond))) + }) + + It("MaxAckDelay", func() { + rttStats.SetMaxAckDelay(42 * time.Minute) + Expect(rttStats.MaxAckDelay()).To(Equal(42 * time.Minute)) + }) + + It("computes the PTO", func() { + maxAckDelay := 42 * time.Minute + rttStats.SetMaxAckDelay(maxAckDelay) + rtt := time.Second + rttStats.UpdateRTT(rtt, 0, time.Time{}) + Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) + Expect(rttStats.MeanDeviation()).To(Equal(rtt / 2)) + Expect(rttStats.PTO(false)).To(Equal(rtt + 4*(rtt/2))) + Expect(rttStats.PTO(true)).To(Equal(rtt + 4*(rtt/2) + maxAckDelay)) + }) + + It("uses the granularity for computing the PTO for short RTTs", func() { + rtt := time.Microsecond + rttStats.UpdateRTT(rtt, 0, time.Time{}) + Expect(rttStats.PTO(true)).To(Equal(rtt + protocol.TimerGranularity)) + }) + + It("ExpireSmoothedMetrics", func() { + initialRtt := (10 * time.Millisecond) + rttStats.UpdateRTT(initialRtt, 0, time.Time{}) + Expect(rttStats.MinRTT()).To(Equal(initialRtt)) + Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) + + Expect(rttStats.MeanDeviation()).To(Equal(initialRtt / 2)) + + // Update once with a 20ms RTT. + doubledRtt := initialRtt * (2) + rttStats.UpdateRTT(doubledRtt, 0, time.Time{}) + Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(float32(initialRtt) * 1.125))) + + // Expire the smoothed metrics, increasing smoothed rtt and mean deviation. + rttStats.ExpireSmoothedMetrics() + Expect(rttStats.SmoothedRTT()).To(Equal(doubledRtt)) + Expect(rttStats.MeanDeviation()).To(Equal(time.Duration(float32(initialRtt) * 0.875))) + + // Now go back down to 5ms and expire the smoothed metrics, and ensure the + // mean deviation increases to 15ms. + halfRtt := initialRtt / 2 + rttStats.UpdateRTT(halfRtt, 0, time.Time{}) + Expect(doubledRtt).To(BeNumerically(">", rttStats.SmoothedRTT())) + Expect(initialRtt).To(BeNumerically("<", rttStats.MeanDeviation())) + }) + + It("UpdateRTTWithBadSendDeltas", func() { + // Make sure we ignore bad RTTs. + // base::test::MockLog log; + + initialRtt := (10 * time.Millisecond) + rttStats.UpdateRTT(initialRtt, 0, time.Time{}) + Expect(rttStats.MinRTT()).To(Equal(initialRtt)) + Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) + + badSendDeltas := []time.Duration{ + 0, + InfDuration, + -1000 * time.Microsecond, + } + // log.StartCapturingLogs(); + + for _, badSendDelta := range badSendDeltas { + // SCOPED_TRACE(Message() << "bad_send_delta = " + // << bad_send_delta.ToMicroseconds()); + // EXPECT_CALL(log, Log(LOG_WARNING, _, _, _, HasSubstr("Ignoring"))); + rttStats.UpdateRTT(badSendDelta, 0, time.Time{}) + Expect(rttStats.MinRTT()).To(Equal(initialRtt)) + Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) + } + }) + + It("ResetAfterConnectionMigrations", func() { + rttStats.UpdateRTT(200*time.Millisecond, 0, time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) + Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) + rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) + Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) + + // Reset rtt stats on connection migrations. + rttStats.OnConnectionMigration() + Expect(rttStats.LatestRTT()).To(Equal(time.Duration(0))) + Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0))) + Expect(rttStats.MinRTT()).To(Equal(time.Duration(0))) + }) + + It("restores the RTT", func() { + rttStats.SetInitialRTT(10 * time.Second) + Expect(rttStats.LatestRTT()).To(Equal(10 * time.Second)) + Expect(rttStats.SmoothedRTT()).To(Equal(10 * time.Second)) + Expect(rttStats.MeanDeviation()).To(BeZero()) + // update the RTT and make sure that the initial value is immediately forgotten + rttStats.UpdateRTT(200*time.Millisecond, 0, time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal(200 * time.Millisecond)) + Expect(rttStats.SmoothedRTT()).To(Equal(200 * time.Millisecond)) + Expect(rttStats.MeanDeviation()).To(Equal(100 * time.Millisecond)) + }) +}) diff --git a/internal/utils/streamframe_interval.go b/internal/utils/streamframe_interval.go new file mode 100644 index 00000000..71f3c6e5 --- /dev/null +++ b/internal/utils/streamframe_interval.go @@ -0,0 +1,9 @@ +package utils + +import "github.com/imroc/req/v3/internal/protocol" + +// ByteInterval is an interval from one ByteCount to the other +type ByteInterval struct { + Start protocol.ByteCount + End protocol.ByteCount +} diff --git a/internal/utils/timer.go b/internal/utils/timer.go new file mode 100644 index 00000000..a4f5e67a --- /dev/null +++ b/internal/utils/timer.go @@ -0,0 +1,53 @@ +package utils + +import ( + "math" + "time" +) + +// A Timer wrapper that behaves correctly when resetting +type Timer struct { + t *time.Timer + read bool + deadline time.Time +} + +// NewTimer creates a new timer that is not set +func NewTimer() *Timer { + return &Timer{t: time.NewTimer(time.Duration(math.MaxInt64))} +} + +// Chan returns the channel of the wrapped timer +func (t *Timer) Chan() <-chan time.Time { + return t.t.C +} + +// Reset the timer, no matter whether the value was read or not +func (t *Timer) Reset(deadline time.Time) { + if deadline.Equal(t.deadline) && !t.read { + // No need to reset the timer + return + } + + // We need to drain the timer if the value from its channel was not read yet. + // See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU + if !t.t.Stop() && !t.read { + <-t.t.C + } + if !deadline.IsZero() { + t.t.Reset(time.Until(deadline)) + } + + t.read = false + t.deadline = deadline +} + +// SetRead should be called after the value from the chan was read +func (t *Timer) SetRead() { + t.read = true +} + +// Stop stops the timer +func (t *Timer) Stop() { + t.t.Stop() +} diff --git a/internal/utils/timer_test.go b/internal/utils/timer_test.go new file mode 100644 index 00000000..0cbb4a01 --- /dev/null +++ b/internal/utils/timer_test.go @@ -0,0 +1,87 @@ +package utils + +import ( + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Timer", func() { + const d = 10 * time.Millisecond + + It("doesn't fire a newly created timer", func() { + t := NewTimer() + Consistently(t.Chan()).ShouldNot(Receive()) + }) + + It("works", func() { + t := NewTimer() + t.Reset(time.Now().Add(d)) + Eventually(t.Chan()).Should(Receive()) + }) + + It("works multiple times with reading", func() { + t := NewTimer() + for i := 0; i < 10; i++ { + t.Reset(time.Now().Add(d)) + Eventually(t.Chan()).Should(Receive()) + t.SetRead() + } + }) + + It("works multiple times without reading", func() { + t := NewTimer() + for i := 0; i < 10; i++ { + t.Reset(time.Now().Add(d)) + time.Sleep(d * 2) + } + Eventually(t.Chan()).Should(Receive()) + }) + + It("works when resetting without expiration", func() { + t := NewTimer() + for i := 0; i < 10; i++ { + t.Reset(time.Now().Add(time.Hour)) + } + t.Reset(time.Now().Add(d)) + Eventually(t.Chan()).Should(Receive()) + }) + + It("immediately fires the timer, if the deadlines has already passed", func() { + t := NewTimer() + t.Reset(time.Now().Add(-time.Second)) + Eventually(t.Chan()).Should(Receive()) + }) + + It("doesn't set a timer if the deadline is the zero value", func() { + t := NewTimer() + t.Reset(time.Time{}) + Consistently(t.Chan()).ShouldNot(Receive()) + }) + + It("fires the timer twice, if reset to the same deadline", func() { + deadline := time.Now().Add(-time.Millisecond) + t := NewTimer() + t.Reset(deadline) + Eventually(t.Chan()).Should(Receive()) + t.SetRead() + t.Reset(deadline) + Eventually(t.Chan()).Should(Receive()) + }) + + It("only fires the timer once, if it is reset to the same deadline, but not read in between", func() { + deadline := time.Now().Add(-time.Millisecond) + t := NewTimer() + t.Reset(deadline) + Eventually(t.Chan()).Should(Receive()) + Consistently(t.Chan()).ShouldNot(Receive()) + }) + + It("stops", func() { + t := NewTimer() + t.Reset(time.Now().Add(50 * time.Millisecond)) + t.Stop() + Consistently(t.Chan()).ShouldNot(Receive()) + }) +}) diff --git a/internal/utils/utils_suite_test.go b/internal/utils/utils_suite_test.go new file mode 100644 index 00000000..9ecb8c05 --- /dev/null +++ b/internal/utils/utils_suite_test.go @@ -0,0 +1,13 @@ +package utils + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestCrypto(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Utils Suite") +} diff --git a/internal/wire/ack_frame.go b/internal/wire/ack_frame.go new file mode 100644 index 00000000..a71607c9 --- /dev/null +++ b/internal/wire/ack_frame.go @@ -0,0 +1,252 @@ +package wire + +import ( + "bytes" + "errors" + "github.com/lucas-clemente/quic-go" + "sort" + "time" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + "github.com/imroc/req/v3/internal/utils" +) + +var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") + +// An AckFrame is an ACK frame +type AckFrame struct { + AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last + DelayTime time.Duration + + ECT0, ECT1, ECNCE uint64 +} + +// parseAckFrame reads an ACK frame +func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ quic.VersionNumber) (*AckFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + ecn := typeByte&0x1 > 0 + + frame := &AckFrame{} + + la, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + largestAcked := protocol.PacketNumber(la) + delay, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + delayTime := time.Duration(delay*1< largestAcked { + return nil, errors.New("invalid first ACK range") + } + smallest := largestAcked - ackBlock + + // read all the other ACK ranges + frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked}) + for i := uint64(0); i < numBlocks; i++ { + g, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + gap := protocol.PacketNumber(g) + if smallest < gap+2 { + return nil, errInvalidAckRanges + } + largest := smallest - gap - 2 + + ab, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + ackBlock := protocol.PacketNumber(ab) + + if ackBlock > largest { + return nil, errInvalidAckRanges + } + smallest = largest - ackBlock + frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) + } + + if !frame.validateAckRanges() { + return nil, errInvalidAckRanges + } + + // parse (and skip) the ECN section + if ecn { + for i := 0; i < 3; i++ { + if _, err := quicvarint.Read(r); err != nil { + return nil, err + } + } + } + + return frame, nil +} + +// Write writes an ACK frame. +func (f *AckFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 + if hasECN { + b.WriteByte(0x3) + } else { + b.WriteByte(0x2) + } + quicvarint.Write(b, uint64(f.LargestAcked())) + quicvarint.Write(b, encodeAckDelay(f.DelayTime)) + + numRanges := f.numEncodableAckRanges() + quicvarint.Write(b, uint64(numRanges-1)) + + // write the first range + _, firstRange := f.encodeAckRange(0) + quicvarint.Write(b, firstRange) + + // write all the other range + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + quicvarint.Write(b, gap) + quicvarint.Write(b, len) + } + + if hasECN { + quicvarint.Write(b, f.ECT0) + quicvarint.Write(b, f.ECT1) + quicvarint.Write(b, f.ECNCE) + } + return nil +} + +// Length of a written frame +func (f *AckFrame) Length(version quic.VersionNumber) protocol.ByteCount { + largestAcked := f.AckRanges[0].Largest + numRanges := f.numEncodableAckRanges() + + length := 1 + quicvarint.Len(uint64(largestAcked)) + quicvarint.Len(encodeAckDelay(f.DelayTime)) + + length += quicvarint.Len(uint64(numRanges - 1)) + lowestInFirstRange := f.AckRanges[0].Smallest + length += quicvarint.Len(uint64(largestAcked - lowestInFirstRange)) + + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + length += quicvarint.Len(gap) + length += quicvarint.Len(len) + } + if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 { + length += quicvarint.Len(f.ECT0) + length += quicvarint.Len(f.ECT1) + length += quicvarint.Len(f.ECNCE) + } + return length +} + +// gets the number of ACK ranges that can be encoded +// such that the resulting frame is smaller than the maximum ACK frame size +func (f *AckFrame) numEncodableAckRanges() int { + length := 1 + quicvarint.Len(uint64(f.LargestAcked())) + quicvarint.Len(encodeAckDelay(f.DelayTime)) + length += 2 // assume that the number of ranges will consume 2 bytes + for i := 1; i < len(f.AckRanges); i++ { + gap, len := f.encodeAckRange(i) + rangeLen := quicvarint.Len(gap) + quicvarint.Len(len) + if length+rangeLen > protocol.MaxAckFrameSize { + // Writing range i would exceed the MaxAckFrameSize. + // So encode one range less than that. + return i - 1 + } + length += rangeLen + } + return len(f.AckRanges) +} + +func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) { + if i == 0 { + return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest) + } + return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2), + uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest) +} + +// HasMissingRanges returns if this frame reports any missing packets +func (f *AckFrame) HasMissingRanges() bool { + return len(f.AckRanges) > 1 +} + +func (f *AckFrame) validateAckRanges() bool { + if len(f.AckRanges) == 0 { + return false + } + + // check the validity of every single ACK range + for _, ackRange := range f.AckRanges { + if ackRange.Smallest > ackRange.Largest { + return false + } + } + + // check the consistency for ACK with multiple NACK ranges + for i, ackRange := range f.AckRanges { + if i == 0 { + continue + } + lastAckRange := f.AckRanges[i-1] + if lastAckRange.Smallest <= ackRange.Smallest { + return false + } + if lastAckRange.Smallest <= ackRange.Largest+1 { + return false + } + } + + return true +} + +// LargestAcked is the largest acked packet number +func (f *AckFrame) LargestAcked() protocol.PacketNumber { + return f.AckRanges[0].Largest +} + +// LowestAcked is the lowest acked packet number +func (f *AckFrame) LowestAcked() protocol.PacketNumber { + return f.AckRanges[len(f.AckRanges)-1].Smallest +} + +// AcksPacket determines if this ACK frame acks a certain packet number +func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { + if p < f.LowestAcked() || p > f.LargestAcked() { + return false + } + + i := sort.Search(len(f.AckRanges), func(i int) bool { + return p >= f.AckRanges[i].Smallest + }) + // i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked + return p <= f.AckRanges[i].Largest +} + +func encodeAckDelay(delay time.Duration) uint64 { + return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent))) +} diff --git a/internal/wire/ack_frame_test.go b/internal/wire/ack_frame_test.go new file mode 100644 index 00000000..de1671b0 --- /dev/null +++ b/internal/wire/ack_frame_test.go @@ -0,0 +1,454 @@ +package wire + +import ( + "bytes" + "io" + "math" + "time" + + "github.com/imroc/req/v3/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("ACK Frame (for IETF QUIC)", func() { + Context("parsing", func() { + It("parses an ACK frame without any ranges", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(100)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(10)...) // first ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(b.Len()).To(BeZero()) + }) + + It("parses an ACK frame that only acks a single packet", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(55)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(0)...) // first ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(55))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(55))) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts an ACK frame that acks all packets from 0 to largest", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(20)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(20)...) // first ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(20))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(0))) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(b.Len()).To(BeZero()) + }) + + It("rejects an ACK frame that has a first ACK block which is larger than LargestAcked", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(20)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(21)...) // first ack block + b := bytes.NewReader(data) + _, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).To(MatchError("invalid first ACK range")) + }) + + It("parses an ACK frame that has a single block", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(1000)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(1)...) // num blocks + data = append(data, encodeVarInt(100)...) // first ack block + data = append(data, encodeVarInt(98)...) // gap + data = append(data, encodeVarInt(50)...) // ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(1000))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(750))) + Expect(frame.HasMissingRanges()).To(BeTrue()) + Expect(frame.AckRanges).To(Equal([]AckRange{ + {Largest: 1000, Smallest: 900}, + {Largest: 800, Smallest: 750}, + })) + Expect(b.Len()).To(BeZero()) + }) + + It("parses an ACK frame that has a multiple blocks", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(100)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(2)...) // num blocks + data = append(data, encodeVarInt(0)...) // first ack block + data = append(data, encodeVarInt(0)...) // gap + data = append(data, encodeVarInt(0)...) // ack block + data = append(data, encodeVarInt(1)...) // gap + data = append(data, encodeVarInt(1)...) // ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(94))) + Expect(frame.HasMissingRanges()).To(BeTrue()) + Expect(frame.AckRanges).To(Equal([]AckRange{ + {Largest: 100, Smallest: 100}, + {Largest: 98, Smallest: 98}, + {Largest: 95, Smallest: 94}, + })) + Expect(b.Len()).To(BeZero()) + }) + + It("uses the ack delay exponent", func() { + const delayTime = 1 << 10 * time.Millisecond + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: delayTime, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + for i := uint8(0); i < 8; i++ { + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent+i, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.DelayTime).To(Equal(delayTime * (1 << i))) + } + }) + + It("gracefully handles overflows of the delay time", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(100)...) // largest acked + data = append(data, encodeVarInt(math.MaxUint64/5)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(0)...) // first ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.DelayTime).To(BeNumerically(">", 0)) + // The maximum encodable duration is ~292 years. + Expect(frame.DelayTime.Hours()).To(BeNumerically("~", 292*365*24, 365*24)) + }) + + It("errors on EOF", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(1000)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(1)...) // num blocks + data = append(data, encodeVarInt(100)...) // first ack block + data = append(data, encodeVarInt(98)...) // gap + data = append(data, encodeVarInt(50)...) // ack block + _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + + Context("ACK_ECN", func() { + It("parses", func() { + data := []byte{0x3} + data = append(data, encodeVarInt(100)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(10)...) // first ack block + data = append(data, encodeVarInt(0x42)...) // ECT(0) + data = append(data, encodeVarInt(0x12345)...) // ECT(1) + data = append(data, encodeVarInt(0x12345678)...) // ECN-CE + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOF", func() { + data := []byte{0x3} + data = append(data, encodeVarInt(1000)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(1)...) // num blocks + data = append(data, encodeVarInt(100)...) // first ack block + data = append(data, encodeVarInt(98)...) // gap + data = append(data, encodeVarInt(50)...) // ack block + data = append(data, encodeVarInt(0x42)...) // ECT(0) + data = append(data, encodeVarInt(0x12345)...) // ECT(1) + data = append(data, encodeVarInt(0x12345678)...) // ECN-CE + _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + }) + + Context("when writing", func() { + It("writes a simple frame", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 100, Largest: 1337}}, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + expected := []byte{0x2} + expected = append(expected, encodeVarInt(1337)...) // largest acked + expected = append(expected, 0) // delay + expected = append(expected, encodeVarInt(0)...) // num ranges + expected = append(expected, encodeVarInt(1337-100)...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("writes an ACK-ECN frame", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 10, Largest: 2000}}, + ECT0: 13, + ECT1: 37, + ECNCE: 12345, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + expected := []byte{0x3} + expected = append(expected, encodeVarInt(2000)...) // largest acked + expected = append(expected, 0) // delay + expected = append(expected, encodeVarInt(0)...) // num ranges + expected = append(expected, encodeVarInt(2000-10)...) + expected = append(expected, encodeVarInt(13)...) + expected = append(expected, encodeVarInt(37)...) + expected = append(expected, encodeVarInt(12345)...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("writes a frame that acks a single packet", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 0x2eadbeef, Largest: 0x2eadbeef}}, + DelayTime: 18 * time.Millisecond, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(frame.DelayTime).To(Equal(f.DelayTime)) + Expect(b.Len()).To(BeZero()) + }) + + It("writes a frame that acks many packets", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 0x1337, Largest: 0x2eadbeef}}, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(b.Len()).To(BeZero()) + }) + + It("writes a frame with a a single gap", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{ + {Smallest: 400, Largest: 1000}, + {Smallest: 100, Largest: 200}, + }, + } + Expect(f.validateAckRanges()).To(BeTrue()) + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + Expect(frame.HasMissingRanges()).To(BeTrue()) + Expect(b.Len()).To(BeZero()) + }) + + It("writes a frame with multiple ranges", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{ + {Smallest: 10, Largest: 10}, + {Smallest: 8, Largest: 8}, + {Smallest: 5, Largest: 6}, + {Smallest: 1, Largest: 3}, + }, + } + Expect(f.validateAckRanges()).To(BeTrue()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + Expect(frame.HasMissingRanges()).To(BeTrue()) + Expect(b.Len()).To(BeZero()) + }) + + It("limits the maximum size of the ACK frame", func() { + buf := &bytes.Buffer{} + const numRanges = 1000 + ackRanges := make([]AckRange, numRanges) + for i := protocol.PacketNumber(1); i <= numRanges; i++ { + ackRanges[numRanges-i] = AckRange{Smallest: 2 * i, Largest: 2 * i} + } + f := &AckFrame{AckRanges: ackRanges} + Expect(f.validateAckRanges()).To(BeTrue()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + // make sure the ACK frame is *a little bit* smaller than the MaxAckFrameSize + Expect(buf.Len()).To(BeNumerically(">", protocol.MaxAckFrameSize-5)) + Expect(buf.Len()).To(BeNumerically("<=", protocol.MaxAckFrameSize)) + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.HasMissingRanges()).To(BeTrue()) + Expect(b.Len()).To(BeZero()) + Expect(len(frame.AckRanges)).To(BeNumerically("<", numRanges)) // make sure we dropped some ranges + }) + }) + + Context("ACK range validator", func() { + It("rejects ACKs without ranges", func() { + Expect((&AckFrame{}).validateAckRanges()).To(BeFalse()) + }) + + It("accepts an ACK without NACK Ranges", func() { + ack := AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 7}}, + } + Expect(ack.validateAckRanges()).To(BeTrue()) + }) + + It("rejects ACK ranges with Smallest greater than Largest", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 8, Largest: 10}, + {Smallest: 4, Largest: 3}, + }, + } + Expect(ack.validateAckRanges()).To(BeFalse()) + }) + + It("rejects ACK ranges in the wrong order", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 2, Largest: 2}, + {Smallest: 6, Largest: 7}, + }, + } + Expect(ack.validateAckRanges()).To(BeFalse()) + }) + + It("rejects with overlapping ACK ranges", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 5, Largest: 7}, + {Smallest: 2, Largest: 5}, + }, + } + Expect(ack.validateAckRanges()).To(BeFalse()) + }) + + It("rejects ACK ranges that are part of a larger ACK range", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 4, Largest: 7}, + {Smallest: 5, Largest: 6}, + }, + } + Expect(ack.validateAckRanges()).To(BeFalse()) + }) + + It("rejects with directly adjacent ACK ranges", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 5, Largest: 7}, + {Smallest: 2, Largest: 4}, + }, + } + Expect(ack.validateAckRanges()).To(BeFalse()) + }) + + It("accepts an ACK with one lost packet", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 5, Largest: 10}, + {Smallest: 1, Largest: 3}, + }, + } + Expect(ack.validateAckRanges()).To(BeTrue()) + }) + + It("accepts an ACK with multiple lost packets", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 15, Largest: 20}, + {Smallest: 10, Largest: 12}, + {Smallest: 1, Largest: 3}, + }, + } + Expect(ack.validateAckRanges()).To(BeTrue()) + }) + }) + + Context("check if ACK frame acks a certain packet", func() { + It("works with an ACK without any ranges", func() { + f := AckFrame{ + AckRanges: []AckRange{{Smallest: 5, Largest: 10}}, + } + Expect(f.AcksPacket(1)).To(BeFalse()) + Expect(f.AcksPacket(4)).To(BeFalse()) + Expect(f.AcksPacket(5)).To(BeTrue()) + Expect(f.AcksPacket(8)).To(BeTrue()) + Expect(f.AcksPacket(10)).To(BeTrue()) + Expect(f.AcksPacket(11)).To(BeFalse()) + Expect(f.AcksPacket(20)).To(BeFalse()) + }) + + It("works with an ACK with multiple ACK ranges", func() { + f := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 15, Largest: 20}, + {Smallest: 5, Largest: 8}, + }, + } + Expect(f.AcksPacket(4)).To(BeFalse()) + Expect(f.AcksPacket(5)).To(BeTrue()) + Expect(f.AcksPacket(6)).To(BeTrue()) + Expect(f.AcksPacket(7)).To(BeTrue()) + Expect(f.AcksPacket(8)).To(BeTrue()) + Expect(f.AcksPacket(9)).To(BeFalse()) + Expect(f.AcksPacket(14)).To(BeFalse()) + Expect(f.AcksPacket(15)).To(BeTrue()) + Expect(f.AcksPacket(18)).To(BeTrue()) + Expect(f.AcksPacket(19)).To(BeTrue()) + Expect(f.AcksPacket(20)).To(BeTrue()) + Expect(f.AcksPacket(21)).To(BeFalse()) + }) + }) +}) diff --git a/internal/wire/ack_range.go b/internal/wire/ack_range.go new file mode 100644 index 00000000..68032205 --- /dev/null +++ b/internal/wire/ack_range.go @@ -0,0 +1,14 @@ +package wire + +import "github.com/imroc/req/v3/internal/protocol" + +// AckRange is an ACK range +type AckRange struct { + Smallest protocol.PacketNumber + Largest protocol.PacketNumber +} + +// Len returns the number of packets contained in this ACK range +func (r AckRange) Len() protocol.PacketNumber { + return r.Largest - r.Smallest + 1 +} diff --git a/internal/wire/ack_range_test.go b/internal/wire/ack_range_test.go new file mode 100644 index 00000000..84ef71b5 --- /dev/null +++ b/internal/wire/ack_range_test.go @@ -0,0 +1,13 @@ +package wire + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("ACK range", func() { + It("returns the length", func() { + Expect(AckRange{Smallest: 10, Largest: 10}.Len()).To(BeEquivalentTo(1)) + Expect(AckRange{Smallest: 10, Largest: 13}.Len()).To(BeEquivalentTo(4)) + }) +}) diff --git a/internal/wire/connection_close_frame.go b/internal/wire/connection_close_frame.go new file mode 100644 index 00000000..627007dd --- /dev/null +++ b/internal/wire/connection_close_frame.go @@ -0,0 +1,84 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A ConnectionCloseFrame is a CONNECTION_CLOSE frame +type ConnectionCloseFrame struct { + IsApplicationError bool + ErrorCode uint64 + FrameType uint64 + ReasonPhrase string +} + +func parseConnectionCloseFrame(r *bytes.Reader, _ quic.VersionNumber) (*ConnectionCloseFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &ConnectionCloseFrame{IsApplicationError: typeByte == 0x1d} + ec, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.ErrorCode = ec + // read the Frame Type, if this is not an application error + if !f.IsApplicationError { + ft, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.FrameType = ft + } + var reasonPhraseLen uint64 + reasonPhraseLen, err = quicvarint.Read(r) + if err != nil { + return nil, err + } + // shortcut to prevent the unnecessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the whole reason phrase would result in EOF when attempting to READ + if int(reasonPhraseLen) > r.Len() { + return nil, io.EOF + } + + reasonPhrase := make([]byte, reasonPhraseLen) + if _, err := io.ReadFull(r, reasonPhrase); err != nil { + // this should never happen, since we already checked the reasonPhraseLen earlier + return nil, err + } + f.ReasonPhrase = string(reasonPhrase) + return f, nil +} + +// Length of a written frame +func (f *ConnectionCloseFrame) Length(quic.VersionNumber) protocol.ByteCount { + length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) + if !f.IsApplicationError { + length += quicvarint.Len(f.FrameType) // for the frame type + } + return length +} + +func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { + if f.IsApplicationError { + b.WriteByte(0x1d) + } else { + b.WriteByte(0x1c) + } + + quicvarint.Write(b, f.ErrorCode) + if !f.IsApplicationError { + quicvarint.Write(b, f.FrameType) + } + quicvarint.Write(b, uint64(len(f.ReasonPhrase))) + b.WriteString(f.ReasonPhrase) + return nil +} diff --git a/internal/wire/connection_close_frame_test.go b/internal/wire/connection_close_frame_test.go new file mode 100644 index 00000000..7947c681 --- /dev/null +++ b/internal/wire/connection_close_frame_test.go @@ -0,0 +1,153 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("CONNECTION_CLOSE Frame", func() { + Context("when parsing", func() { + It("accepts sample frame containing a QUIC error code", func() { + reason := "No recent network activity." + data := []byte{0x1c} + data = append(data, encodeVarInt(0x19)...) + data = append(data, encodeVarInt(0x1337)...) // frame type + data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length + data = append(data, []byte(reason)...) + b := bytes.NewReader(data) + frame, err := parseConnectionCloseFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.IsApplicationError).To(BeFalse()) + Expect(frame.ErrorCode).To(BeEquivalentTo(0x19)) + Expect(frame.FrameType).To(BeEquivalentTo(0x1337)) + Expect(frame.ReasonPhrase).To(Equal(reason)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts sample frame containing an application error code", func() { + reason := "The application messed things up." + data := []byte{0x1d} + data = append(data, encodeVarInt(0xcafe)...) + data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length + data = append(data, reason...) + b := bytes.NewReader(data) + frame, err := parseConnectionCloseFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.IsApplicationError).To(BeTrue()) + Expect(frame.ErrorCode).To(BeEquivalentTo(0xcafe)) + Expect(frame.ReasonPhrase).To(Equal(reason)) + Expect(b.Len()).To(BeZero()) + }) + + It("rejects long reason phrases", func() { + data := []byte{0x1c} + data = append(data, encodeVarInt(0xcafe)...) + data = append(data, encodeVarInt(0x42)...) // frame type + data = append(data, encodeVarInt(0xffff)...) // reason phrase length + b := bytes.NewReader(data) + _, err := parseConnectionCloseFrame(b, protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + }) + + It("errors on EOFs", func() { + reason := "No recent network activity." + data := []byte{0x1c} + data = append(data, encodeVarInt(0x19)...) + data = append(data, encodeVarInt(0x1337)...) // frame type + data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length + data = append(data, []byte(reason)...) + _, err := parseConnectionCloseFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseConnectionCloseFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + + It("parses a frame without a reason phrase", func() { + data := []byte{0x1c} + data = append(data, encodeVarInt(0xcafe)...) + data = append(data, encodeVarInt(0x42)...) // frame type + data = append(data, encodeVarInt(0)...) + b := bytes.NewReader(data) + frame, err := parseConnectionCloseFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.ReasonPhrase).To(BeEmpty()) + Expect(b.Len()).To(BeZero()) + }) + }) + + Context("when writing", func() { + It("writes a frame without a reason phrase", func() { + b := &bytes.Buffer{} + frame := &ConnectionCloseFrame{ + ErrorCode: 0xbeef, + FrameType: 0x12345, + } + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + expected := []byte{0x1c} + expected = append(expected, encodeVarInt(0xbeef)...) + expected = append(expected, encodeVarInt(0x12345)...) // frame type + expected = append(expected, encodeVarInt(0)...) // reason phrase length + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with a reason phrase", func() { + b := &bytes.Buffer{} + frame := &ConnectionCloseFrame{ + ErrorCode: 0xdead, + ReasonPhrase: "foobar", + } + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + expected := []byte{0x1c} + expected = append(expected, encodeVarInt(0xdead)...) + expected = append(expected, encodeVarInt(0)...) // frame type + expected = append(expected, encodeVarInt(6)...) // reason phrase length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with an application error code", func() { + b := &bytes.Buffer{} + frame := &ConnectionCloseFrame{ + IsApplicationError: true, + ErrorCode: 0xdead, + ReasonPhrase: "foobar", + } + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + expected := []byte{0x1d} + expected = append(expected, encodeVarInt(0xdead)...) + expected = append(expected, encodeVarInt(6)...) // reason phrase length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has proper min length, for a frame containing a QUIC error code", func() { + b := &bytes.Buffer{} + f := &ConnectionCloseFrame{ + ErrorCode: 0xcafe, + FrameType: 0xdeadbeef, + ReasonPhrase: "foobar", + } + Expect(f.Write(b, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) + }) + + It("has proper min length, for a frame containing an application error code", func() { + b := &bytes.Buffer{} + f := &ConnectionCloseFrame{ + IsApplicationError: true, + ErrorCode: 0xcafe, + ReasonPhrase: "foobar", + } + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) + }) + }) +}) diff --git a/internal/wire/crypto_frame.go b/internal/wire/crypto_frame.go new file mode 100644 index 00000000..56f176a7 --- /dev/null +++ b/internal/wire/crypto_frame.go @@ -0,0 +1,103 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A CryptoFrame is a CRYPTO frame +type CryptoFrame struct { + Offset protocol.ByteCount + Data []byte +} + +func parseCryptoFrame(r *bytes.Reader, _ quic.VersionNumber) (*CryptoFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + frame := &CryptoFrame{} + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + frame.Offset = protocol.ByteCount(offset) + dataLen, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if dataLen > uint64(r.Len()) { + return nil, io.EOF + } + if dataLen != 0 { + frame.Data = make([]byte, dataLen) + if _, err := io.ReadFull(r, frame.Data); err != nil { + // this should never happen, since we already checked the dataLen earlier + return nil, err + } + } + return frame, nil +} + +func (f *CryptoFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + b.WriteByte(0x6) + quicvarint.Write(b, uint64(f.Offset)) + quicvarint.Write(b, uint64(len(f.Data))) + b.Write(f.Data) + return nil +} + +// Length of a written frame +func (f *CryptoFrame) Length(_ quic.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data)) +} + +// MaxDataLen returns the maximum data length +func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen := 1 + quicvarint.Len(uint64(f.Offset)) + 1 + if headerLen > maxSize { + return 0 + } + maxDataLen := maxSize - headerLen + if quicvarint.Len(uint64(maxDataLen)) != 1 { + maxDataLen-- + } + return maxDataLen +} + +// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. +// It returns if the frame was actually split. +// The frame might not be split if: +// * the size is large enough to fit the whole frame +// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. +func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version quic.VersionNumber) (*CryptoFrame, bool /* was splitting required */) { + if f.Length(version) <= maxSize { + return nil, false + } + + n := f.MaxDataLen(maxSize) + if n == 0 { + return nil, true + } + + newLen := protocol.ByteCount(len(f.Data)) - n + + new := &CryptoFrame{} + new.Offset = f.Offset + new.Data = make([]byte, newLen) + + // swap the data slices + new.Data, f.Data = f.Data, new.Data + + copy(f.Data, new.Data[n:]) + new.Data = new.Data[:n] + f.Offset += n + + return new, true +} diff --git a/internal/wire/crypto_frame_test.go b/internal/wire/crypto_frame_test.go new file mode 100644 index 00000000..08ede5d0 --- /dev/null +++ b/internal/wire/crypto_frame_test.go @@ -0,0 +1,148 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("CRYPTO frame", func() { + Context("when parsing", func() { + It("parses", func() { + data := []byte{0x6} + data = append(data, encodeVarInt(0xdecafbad)...) // offset + data = append(data, encodeVarInt(6)...) // length + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := parseCryptoFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(r.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x6} + data = append(data, encodeVarInt(0xdecafbad)...) // offset + data = append(data, encodeVarInt(6)...) // data length + data = append(data, []byte("foobar")...) + _, err := parseCryptoFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseCryptoFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("when writing", func() { + It("writes a frame", func() { + f := &CryptoFrame{ + Offset: 0x123456, + Data: []byte("foobar"), + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x6} + expected = append(expected, encodeVarInt(0x123456)...) // offset + expected = append(expected, encodeVarInt(6)...) // length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + }) + + Context("max data length", func() { + const maxSize = 3000 + + It("always returns a data length such that the resulting frame has the right size", func() { + data := make([]byte, maxSize) + f := &CryptoFrame{ + Offset: 0xdeadbeef, + } + b := &bytes.Buffer{} + var frameOneByteTooSmallCounter int + for i := 1; i < maxSize; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i)) + if maxDataLen == 0 { // 0 means that no valid CRYTPO frame can be written + // check that writing a minimal size CRYPTO frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + // There's *one* pathological case, where a data length of x can be encoded into 1 byte + // but a data lengths of x+1 needs 2 bytes + // In that case, it's impossible to create a STREAM frame of the desired size + if b.Len() == i-1 { + frameOneByteTooSmallCounter++ + continue + } + Expect(b.Len()).To(Equal(i)) + } + Expect(frameOneByteTooSmallCounter).To(Equal(1)) + }) + }) + + Context("length", func() { + It("has the right length for a frame without offset and data length", func() { + f := &CryptoFrame{ + Offset: 0x1337, + Data: []byte("foobar"), + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(6) + 6)) + }) + }) + + Context("splitting", func() { + It("splits a frame", func() { + f := &CryptoFrame{ + Offset: 0x1337, + Data: []byte("foobar"), + } + hdrLen := f.Length(protocol.Version1) - 6 + new, needsSplit := f.MaybeSplitOffFrame(hdrLen+3, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(new.Data).To(Equal([]byte("foo"))) + Expect(new.Offset).To(Equal(protocol.ByteCount(0x1337))) + Expect(f.Data).To(Equal([]byte("bar"))) + Expect(f.Offset).To(Equal(protocol.ByteCount(0x1337 + 3))) + }) + + It("doesn't split if there's enough space in the frame", func() { + f := &CryptoFrame{ + Offset: 0x1337, + Data: []byte("foobar"), + } + f, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) + Expect(needsSplit).To(BeFalse()) + Expect(f).To(BeNil()) + }) + + It("doesn't split if the size is too small", func() { + f := &CryptoFrame{ + Offset: 0x1337, + Data: []byte("foobar"), + } + length := f.Length(protocol.Version1) - 6 + for i := protocol.ByteCount(0); i <= length; i++ { + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(f).To(BeNil()) + } + f, needsSplit := f.MaybeSplitOffFrame(length+1, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(f).ToNot(BeNil()) + }) + }) +}) diff --git a/internal/wire/data_blocked_frame.go b/internal/wire/data_blocked_frame.go new file mode 100644 index 00000000..a6ab54fc --- /dev/null +++ b/internal/wire/data_blocked_frame.go @@ -0,0 +1,39 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A DataBlockedFrame is a DATA_BLOCKED frame +type DataBlockedFrame struct { + MaximumData protocol.ByteCount +} + +func parseDataBlockedFrame(r *bytes.Reader, _ quic.VersionNumber) (*DataBlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + return &DataBlockedFrame{ + MaximumData: protocol.ByteCount(offset), + }, nil +} + +func (f *DataBlockedFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { + typeByte := uint8(0x14) + b.WriteByte(typeByte) + quicvarint.Write(b, uint64(f.MaximumData)) + return nil +} + +// Length of a written frame +func (f *DataBlockedFrame) Length(version quic.VersionNumber) protocol.ByteCount { + return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaximumData))) +} diff --git a/internal/wire/data_blocked_frame_test.go b/internal/wire/data_blocked_frame_test.go new file mode 100644 index 00000000..2aac0525 --- /dev/null +++ b/internal/wire/data_blocked_frame_test.go @@ -0,0 +1,54 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("DATA_BLOCKED frame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + data := []byte{0x14} + data = append(data, encodeVarInt(0x12345678)...) + b := bytes.NewReader(data) + frame, err := parseDataBlockedFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0x12345678))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x14} + data = append(data, encodeVarInt(0x12345678)...) + _, err := parseDataBlockedFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + for i := range data { + _, err := parseDataBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := DataBlockedFrame{MaximumData: 0xdeadbeef} + err := frame.Write(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x14} + expected = append(expected, encodeVarInt(0xdeadbeef)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := DataBlockedFrame{MaximumData: 0x12345} + Expect(frame.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x12345))) + }) + }) +}) diff --git a/internal/wire/datagram_frame.go b/internal/wire/datagram_frame.go new file mode 100644 index 00000000..a330185e --- /dev/null +++ b/internal/wire/datagram_frame.go @@ -0,0 +1,86 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A DatagramFrame is a DATAGRAM frame +type DatagramFrame struct { + DataLenPresent bool + Data []byte +} + +func parseDatagramFrame(r *bytes.Reader, _ quic.VersionNumber) (*DatagramFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &DatagramFrame{} + f.DataLenPresent = typeByte&0x1 > 0 + + var length uint64 + if f.DataLenPresent { + var err error + len, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if len > uint64(r.Len()) { + return nil, io.EOF + } + length = len + } else { + length = uint64(r.Len()) + } + f.Data = make([]byte, length) + if _, err := io.ReadFull(r, f.Data); err != nil { + return nil, err + } + return f, nil +} + +func (f *DatagramFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + typeByte := uint8(0x30) + if f.DataLenPresent { + typeByte ^= 0x1 + } + b.WriteByte(typeByte) + if f.DataLenPresent { + quicvarint.Write(b, uint64(len(f.Data))) + } + b.Write(f.Data) + return nil +} + +// MaxDataLen returns the maximum data length +func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version quic.VersionNumber) protocol.ByteCount { + headerLen := protocol.ByteCount(1) + if f.DataLenPresent { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen++ + } + if headerLen > maxSize { + return 0 + } + maxDataLen := maxSize - headerLen + if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { + maxDataLen-- + } + return maxDataLen +} + +// Length of a written frame +func (f *DatagramFrame) Length(_ quic.VersionNumber) protocol.ByteCount { + length := 1 + protocol.ByteCount(len(f.Data)) + if f.DataLenPresent { + length += protocol.ByteCount(quicvarint.Len(uint64(len(f.Data)))) + } + return length +} diff --git a/internal/wire/datagram_frame_test.go b/internal/wire/datagram_frame_test.go new file mode 100644 index 00000000..363d6c34 --- /dev/null +++ b/internal/wire/datagram_frame_test.go @@ -0,0 +1,154 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAM frame", func() { + Context("when parsing", func() { + It("parses a frame containing a length", func() { + data := []byte{0x30 ^ 0x1} + data = append(data, encodeVarInt(0x6)...) // length + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := parseDatagramFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(frame.DataLenPresent).To(BeTrue()) + Expect(r.Len()).To(BeZero()) + }) + + It("parses a frame without length", func() { + data := []byte{0x30} + data = append(data, []byte("Lorem ipsum dolor sit amet")...) + r := bytes.NewReader(data) + frame, err := parseDatagramFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.Data).To(Equal([]byte("Lorem ipsum dolor sit amet"))) + Expect(frame.DataLenPresent).To(BeFalse()) + Expect(r.Len()).To(BeZero()) + }) + + It("errors when the length is longer than the rest of the frame", func() { + data := []byte{0x30 ^ 0x1} + data = append(data, encodeVarInt(0x6)...) // length + data = append(data, []byte("fooba")...) + r := bytes.NewReader(data) + _, err := parseDatagramFrame(r, protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + }) + + It("errors on EOFs", func() { + data := []byte{0x30 ^ 0x1} + data = append(data, encodeVarInt(6)...) // length + data = append(data, []byte("foobar")...) + _, err := parseDatagramFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseDatagramFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a frame with length", func() { + f := &DatagramFrame{ + DataLenPresent: true, + Data: []byte("foobar"), + } + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + expected := []byte{0x30 ^ 0x1} + expected = append(expected, encodeVarInt(0x6)...) + expected = append(expected, []byte("foobar")...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("writes a frame without length", func() { + f := &DatagramFrame{Data: []byte("Lorem ipsum")} + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + expected := []byte{0x30} + expected = append(expected, []byte("Lorem ipsum")...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + }) + + Context("length", func() { + It("has the right length for a frame with length", func() { + f := &DatagramFrame{ + DataLenPresent: true, + Data: []byte("foobar"), + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(6) + 6)) + }) + + It("has the right length for a frame without length", func() { + f := &DatagramFrame{Data: []byte("foobar")} + Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(1 + 6))) + }) + }) + + Context("max data length", func() { + const maxSize = 3000 + + It("returns a data length such that the resulting frame has the right size, if data length is not present", func() { + data := make([]byte, maxSize) + f := &DatagramFrame{} + b := &bytes.Buffer{} + for i := 1; i < 3000; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) + if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written + // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + Expect(f.Write(b, protocol.Version1)).To(Succeed()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + Expect(f.Write(b, protocol.Version1)).To(Succeed()) + Expect(b.Len()).To(Equal(i)) + } + }) + + It("always returns a data length such that the resulting frame has the right size, if data length is present", func() { + data := make([]byte, maxSize) + f := &DatagramFrame{DataLenPresent: true} + b := &bytes.Buffer{} + var frameOneByteTooSmallCounter int + for i := 1; i < 3000; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) + if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written + // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + Expect(f.Write(b, protocol.Version1)).To(Succeed()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + Expect(f.Write(b, protocol.Version1)).To(Succeed()) + // There's *one* pathological case, where a data length of x can be encoded into 1 byte + // but a data lengths of x+1 needs 2 bytes + // In that case, it's impossible to create a STREAM frame of the desired size + if b.Len() == i-1 { + frameOneByteTooSmallCounter++ + continue + } + Expect(b.Len()).To(Equal(i)) + } + Expect(frameOneByteTooSmallCounter).To(Equal(1)) + }) + }) +}) diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go new file mode 100644 index 00000000..5434399b --- /dev/null +++ b/internal/wire/extended_header.go @@ -0,0 +1,250 @@ +package wire + +import ( + "bytes" + "errors" + "fmt" + "github.com/lucas-clemente/quic-go" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + "github.com/imroc/req/v3/internal/utils" +) + +// ErrInvalidReservedBits is returned when the reserved bits are incorrect. +// When this error is returned, parsing continues, and an ExtendedHeader is returned. +// This is necessary because we need to decrypt the packet in that case, +// in order to avoid a timing side-channel. +var ErrInvalidReservedBits = errors.New("invalid reserved bits") + +// ExtendedHeader is the header of a QUIC packet. +type ExtendedHeader struct { + Header + + typeByte byte + + KeyPhase protocol.KeyPhaseBit + + PacketNumberLen protocol.PacketNumberLen + PacketNumber protocol.PacketNumber + + parsedLen protocol.ByteCount +} + +func (h *ExtendedHeader) parse(b *bytes.Reader, v quic.VersionNumber) (bool /* reserved bits valid */, error) { + startLen := b.Len() + // read the (now unencrypted) first byte + var err error + h.typeByte, err = b.ReadByte() + if err != nil { + return false, err + } + if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil { + return false, err + } + var reservedBitsValid bool + if h.IsLongHeader { + reservedBitsValid, err = h.parseLongHeader(b, v) + } else { + reservedBitsValid, err = h.parseShortHeader(b, v) + } + if err != nil { + return false, err + } + h.parsedLen = protocol.ByteCount(startLen - b.Len()) + return reservedBitsValid, err +} + +func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ quic.VersionNumber) (bool /* reserved bits valid */, error) { + if err := h.readPacketNumber(b); err != nil { + return false, err + } + if h.typeByte&0xc != 0 { + return false, nil + } + return true, nil +} + +func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ quic.VersionNumber) (bool /* reserved bits valid */, error) { + h.KeyPhase = protocol.KeyPhaseZero + if h.typeByte&0x4 > 0 { + h.KeyPhase = protocol.KeyPhaseOne + } + + if err := h.readPacketNumber(b); err != nil { + return false, err + } + if h.typeByte&0x18 != 0 { + return false, nil + } + return true, nil +} + +func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { + h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + n, err := b.ReadByte() + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen2: + n, err := utils.BigEndian.ReadUint16(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen3: + n, err := utils.BigEndian.ReadUint24(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen4: + n, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + default: + return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + return nil +} + +// Write writes the Header. +func (h *ExtendedHeader) Write(b *bytes.Buffer, ver quic.VersionNumber) error { + if h.DestConnectionID.Len() > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) + } + if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) + } + if h.IsLongHeader { + return h.writeLongHeader(b, ver) + } + return h.writeShortHeader(b, ver) +} + +func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version quic.VersionNumber) error { + var packetType uint8 + if version == protocol.Version2 { + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0b01 + case protocol.PacketType0RTT: + packetType = 0b10 + case protocol.PacketTypeHandshake: + packetType = 0b11 + case protocol.PacketTypeRetry: + packetType = 0b00 + } + } else { + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0b00 + case protocol.PacketType0RTT: + packetType = 0b01 + case protocol.PacketTypeHandshake: + packetType = 0b10 + case protocol.PacketTypeRetry: + packetType = 0b11 + } + } + firstByte := 0xc0 | packetType<<4 + if h.Type != protocol.PacketTypeRetry { + // Retry packets don't have a packet number + firstByte |= uint8(h.PacketNumberLen - 1) + } + + b.WriteByte(firstByte) + utils.BigEndian.WriteUint32(b, uint32(h.Version)) + b.WriteByte(uint8(h.DestConnectionID.Len())) + b.Write(h.DestConnectionID.Bytes()) + b.WriteByte(uint8(h.SrcConnectionID.Len())) + b.Write(h.SrcConnectionID.Bytes()) + + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeRetry: + b.Write(h.Token) + return nil + case protocol.PacketTypeInitial: + quicvarint.Write(b, uint64(len(h.Token))) + b.Write(h.Token) + } + quicvarint.WriteWithLen(b, uint64(h.Length), 2) + return h.writePacketNumber(b) +} + +func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ quic.VersionNumber) error { + typeByte := 0x40 | uint8(h.PacketNumberLen-1) + if h.KeyPhase == protocol.KeyPhaseOne { + typeByte |= byte(1 << 2) + } + + b.WriteByte(typeByte) + b.Write(h.DestConnectionID.Bytes()) + return h.writePacketNumber(b) +} + +func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(h.PacketNumber)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) + case protocol.PacketNumberLen3: + utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + default: + return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + return nil +} + +// ParsedLen returns the number of bytes that were consumed when parsing the header +func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { + return h.parsedLen +} + +// GetLength determines the length of the Header. +func (h *ExtendedHeader) GetLength(v quic.VersionNumber) protocol.ByteCount { + if h.IsLongHeader { + length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ + if h.Type == protocol.PacketTypeInitial { + length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) + } + return length + } + + length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) + length += protocol.ByteCount(h.PacketNumberLen) + return length +} + +// Log logs the Header +func (h *ExtendedHeader) Log(logger utils.Logger) { + if h.IsLongHeader { + var token string + if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { + if len(h.Token) == 0 { + token = "Token: (empty), " + } else { + token = fmt.Sprintf("Token: %#x, ", h.Token) + } + if h.Type == protocol.PacketTypeRetry { + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version) + return + } + } + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) + } else { + logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + } +} diff --git a/internal/wire/extended_header_test.go b/internal/wire/extended_header_test.go new file mode 100644 index 00000000..f9ab54ba --- /dev/null +++ b/internal/wire/extended_header_test.go @@ -0,0 +1,481 @@ +package wire + +import ( + "bytes" + "log" + "os" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/utils" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Header", func() { + const versionIETFHeader = protocol.VersionTLS // a QUIC version that uses the IETF Header format + + Context("Writing", func() { + var buf *bytes.Buffer + + BeforeEach(func() { + buf = &bytes.Buffer{} + }) + + Context("Long Header", func() { + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + + It("writes", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}, + Version: 0x1020304, + Length: protocol.InitialPacketSizeIPv4, + }, + PacketNumber: 0xdecaf, + PacketNumberLen: protocol.PacketNumberLen3, + }).Write(buf, versionIETFHeader)).To(Succeed()) + expected := []byte{ + 0xc0 | 0x2<<4 | 0x2, + 0x1, 0x2, 0x3, 0x4, // version number + 0x6, // dest connection ID length + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // dest connection ID + 0x8, // src connection ID length + 0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37, // source connection ID + } + expected = append(expected, encodeVarInt(protocol.InitialPacketSizeIPv4)...) // length + expected = append(expected, []byte{0xd, 0xec, 0xaf}...) // packet number + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("refuses to write a header with a too long connection ID", func() { + err := (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + SrcConnectionID: srcConnID, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, // connection IDs must be at most 20 bytes long + Version: 0x1020304, + Type: 0x5, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, versionIETFHeader) + Expect(err).To(MatchError("invalid connection ID length: 21 bytes")) + }) + + It("writes a header with a 20 byte connection ID", func() { + err := (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + SrcConnectionID: srcConnID, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, // connection IDs must be at most 20 bytes long + Version: 0x1020304, + Type: 0x5, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, versionIETFHeader) + Expect(err).ToNot(HaveOccurred()) + Expect(buf.Bytes()).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}))) + }) + + It("writes an Initial containing a token", func() { + token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: 0x1020304, + Type: protocol.PacketTypeInitial, + Token: token, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0) + expectedSubstring := append(encodeVarInt(uint64(len(token))), token...) + Expect(buf.Bytes()).To(ContainSubstring(string(expectedSubstring))) + }) + + It("uses a 2-byte encoding for the length on Initial packets", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: 0x1020304, + Type: protocol.PacketTypeInitial, + Length: 37, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, versionIETFHeader)).To(Succeed()) + b := &bytes.Buffer{} + quicvarint.WriteWithLen(b, 37, 2) + Expect(buf.Bytes()[buf.Len()-6 : buf.Len()-4]).To(Equal(b.Bytes())) + }) + + It("writes a Retry packet", func() { + token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") + Expect((&ExtendedHeader{Header: Header{ + IsLongHeader: true, + Version: protocol.Version1, + Type: protocol.PacketTypeRetry, + Token: token, + }}).Write(buf, versionIETFHeader)).To(Succeed()) + expected := []byte{0xc0 | 0b11<<4} + expected = appendVersion(expected, protocol.Version1) + expected = append(expected, 0x0) // dest connection ID length + expected = append(expected, 0x0) // src connection ID length + expected = append(expected, token...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + }) + + Context("long header, version 2", func() { + It("writes an Initial", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketTypeInitial, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, protocol.Version2)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0b01) + }) + + It("writes a Retry packet", func() { + token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") + Expect((&ExtendedHeader{Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketTypeRetry, + Token: token, + }}).Write(buf, versionIETFHeader)).To(Succeed()) + expected := []byte{0xc0 | 0b11<<4} + expected = appendVersion(expected, protocol.Version2) + expected = append(expected, 0x0) // dest connection ID length + expected = append(expected, 0x0) // src connection ID length + expected = append(expected, token...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("writes a Handshake Packet", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketTypeHandshake, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, protocol.Version2)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0b11) + }) + + It("writes a 0-RTT Packet", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketType0RTT, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, protocol.Version2)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0b10) + }) + }) + + Context("short header", func() { + It("writes a header with connection ID", func() { + Expect((&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 0x42, + }).Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Bytes()).To(Equal([]byte{ + 0x40, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + 0x42, // packet number + })) + }) + + It("writes a header without connection ID", func() { + Expect((&ExtendedHeader{ + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 0x42, + }).Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Bytes()).To(Equal([]byte{ + 0x40, + 0x42, // packet number + })) + }) + + It("writes a header with a 2 byte packet number", func() { + Expect((&ExtendedHeader{ + PacketNumberLen: protocol.PacketNumberLen2, + PacketNumber: 0x765, + }).Write(buf, versionIETFHeader)).To(Succeed()) + expected := []byte{0x40 | 0x1} + expected = append(expected, []byte{0x7, 0x65}...) // packet number + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("writes a header with a 4 byte packet number", func() { + Expect((&ExtendedHeader{ + PacketNumberLen: protocol.PacketNumberLen4, + PacketNumber: 0x12345678, + }).Write(buf, versionIETFHeader)).To(Succeed()) + expected := []byte{0x40 | 0x3} + expected = append(expected, []byte{0x12, 0x34, 0x56, 0x78}...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("errors when given an invalid packet number length", func() { + err := (&ExtendedHeader{ + PacketNumberLen: 5, + PacketNumber: 0xdecafbad, + }).Write(buf, versionIETFHeader) + Expect(err).To(MatchError("invalid packet number length: 5")) + }) + + It("writes the Key Phase Bit", func() { + Expect((&ExtendedHeader{ + KeyPhase: protocol.KeyPhaseOne, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 0x42, + }).Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Bytes()).To(Equal([]byte{ + 0x40 | 0x4, + 0x42, // packet number + })) + }) + }) + }) + + Context("getting the length", func() { + var buf *bytes.Buffer + + BeforeEach(func() { + buf = &bytes.Buffer{} + }) + + It("has the right length for the Long Header, for a short length", func() { + h := &ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Length: 1, + }, + PacketNumberLen: protocol.PacketNumberLen1, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* length */ + 1 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + + It("has the right length for the Long Header, for a long length", func() { + h := &ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Length: 1500, + }, + PacketNumberLen: protocol.PacketNumberLen2, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* long len */ + 2 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + + It("has the right length for an Initial that has a short length", func() { + h := &ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 15, + }, + PacketNumberLen: protocol.PacketNumberLen2, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + + It("has the right length for an Initial not containing a Token", func() { + h := &ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 1500, + }, + PacketNumberLen: protocol.PacketNumberLen2, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + + It("has the right length for an Initial containing a Token", func() { + h := &ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Type: protocol.PacketTypeInitial, + Length: 1500, + Token: []byte("foo"), + }, + PacketNumberLen: protocol.PacketNumberLen2, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn id len */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long len */ + 2 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + + It("has the right length for a Short Header containing a connection ID", func() { + h := &ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + }, + PacketNumberLen: protocol.PacketNumberLen1, + } + Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 8 + 1))) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(10)) + }) + + It("has the right length for a short header without a connection ID", func() { + h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1} + Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 1))) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(2)) + }) + + It("has the right length for a short header with a 2 byte packet number", func() { + h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen2} + Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 2))) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(3)) + }) + + It("has the right length for a short header with a 5 byte packet number", func() { + h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen4} + Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 4))) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(5)) + }) + }) + + Context("Logging", func() { + var ( + buf *bytes.Buffer + logger utils.Logger + ) + + BeforeEach(func() { + buf = &bytes.Buffer{} + logger = utils.DefaultLogger + logger.SetLogLevel(utils.LogLevelDebug) + log.SetOutput(buf) + }) + + AfterEach(func() { + log.SetOutput(os.Stdout) + }) + + It("logs Long Headers", func() { + (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}, + Type: protocol.PacketTypeHandshake, + Length: 54321, + Version: 0xfeed, + }, + PacketNumber: 1337, + PacketNumberLen: protocol.PacketNumberLen2, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, DestConnectionID: deadbeefcafe1337, SrcConnectionID: decafbad13371337, PacketNumber: 1337, PacketNumberLen: 2, Length: 54321, Version: 0xfeed}")) + }) + + It("logs Initial Packets with a Token", func() { + (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + Type: protocol.PacketTypeInitial, + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + Length: 100, + Version: 0xfeed, + }, + PacketNumber: 42, + PacketNumberLen: protocol.PacketNumberLen2, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Long Header{Type: Initial, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0xdeadbeef, PacketNumber: 42, PacketNumberLen: 2, Length: 100, Version: 0xfeed}")) + }) + + It("logs Initial packets without a Token", func() { + (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + Type: protocol.PacketTypeInitial, + Length: 100, + Version: 0xfeed, + }, + PacketNumber: 42, + PacketNumberLen: protocol.PacketNumberLen2, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Long Header{Type: Initial, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: (empty), PacketNumber: 42, PacketNumberLen: 2, Length: 100, Version: 0xfeed}")) + }) + + It("logs Retry packets with a Token", func() { + (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + Type: protocol.PacketTypeRetry, + Token: []byte{0x12, 0x34, 0x56}, + Version: 0xfeed, + }, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Long Header{Type: Retry, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0x123456, Version: 0xfeed}")) + }) + + It("logs Short Headers containing a connection ID", func() { + (&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + }, + KeyPhase: protocol.KeyPhaseOne, + PacketNumber: 1337, + PacketNumberLen: 4, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}")) + }) + }) +}) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go new file mode 100644 index 00000000..3d06b30f --- /dev/null +++ b/internal/wire/frame_parser.go @@ -0,0 +1,144 @@ +package wire + +import ( + "bytes" + "errors" + "fmt" + "github.com/lucas-clemente/quic-go" + "reflect" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" +) + +type frameParser struct { + ackDelayExponent uint8 + + supportsDatagrams bool + + version quic.VersionNumber +} + +// NewFrameParser creates a new frame parser. +func NewFrameParser(supportsDatagrams bool, v quic.VersionNumber) FrameParser { + return &frameParser{ + supportsDatagrams: supportsDatagrams, + version: v, + } +} + +// ParseNext parses the next frame. +// It skips PADDING frames. +func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel) (Frame, error) { + for r.Len() != 0 { + typeByte, _ := r.ReadByte() + if typeByte == 0x0 { // PADDING frame + continue + } + r.UnreadByte() + + f, err := p.parseFrame(r, typeByte, encLevel) + if err != nil { + return nil, &qerr.TransportError{ + FrameType: uint64(typeByte), + ErrorCode: qerr.FrameEncodingError, + ErrorMessage: err.Error(), + } + } + return f, nil + } + return nil, nil +} + +func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protocol.EncryptionLevel) (Frame, error) { + var frame Frame + var err error + if typeByte&0xf8 == 0x8 { + frame, err = parseStreamFrame(r, p.version) + } else { + switch typeByte { + case 0x1: + frame, err = parsePingFrame(r, p.version) + case 0x2, 0x3: + ackDelayExponent := p.ackDelayExponent + if encLevel != protocol.Encryption1RTT { + ackDelayExponent = protocol.DefaultAckDelayExponent + } + frame, err = parseAckFrame(r, ackDelayExponent, p.version) + case 0x4: + frame, err = parseResetStreamFrame(r, p.version) + case 0x5: + frame, err = parseStopSendingFrame(r, p.version) + case 0x6: + frame, err = parseCryptoFrame(r, p.version) + case 0x7: + frame, err = parseNewTokenFrame(r, p.version) + case 0x10: + frame, err = parseMaxDataFrame(r, p.version) + case 0x11: + frame, err = parseMaxStreamDataFrame(r, p.version) + case 0x12, 0x13: + frame, err = parseMaxStreamsFrame(r, p.version) + case 0x14: + frame, err = parseDataBlockedFrame(r, p.version) + case 0x15: + frame, err = parseStreamDataBlockedFrame(r, p.version) + case 0x16, 0x17: + frame, err = parseStreamsBlockedFrame(r, p.version) + case 0x18: + frame, err = parseNewConnectionIDFrame(r, p.version) + case 0x19: + frame, err = parseRetireConnectionIDFrame(r, p.version) + case 0x1a: + frame, err = parsePathChallengeFrame(r, p.version) + case 0x1b: + frame, err = parsePathResponseFrame(r, p.version) + case 0x1c, 0x1d: + frame, err = parseConnectionCloseFrame(r, p.version) + case 0x1e: + frame, err = parseHandshakeDoneFrame(r, p.version) + case 0x30, 0x31: + if p.supportsDatagrams { + frame, err = parseDatagramFrame(r, p.version) + break + } + fallthrough + default: + err = errors.New("unknown frame type") + } + } + if err != nil { + return nil, err + } + if !p.isAllowedAtEncLevel(frame, encLevel) { + return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel) + } + return frame, nil +} + +func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { + switch encLevel { + case protocol.EncryptionInitial, protocol.EncryptionHandshake: + switch f.(type) { + case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *PingFrame: + return true + default: + return false + } + case protocol.Encryption0RTT: + switch f.(type) { + case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: + return false + default: + return true + } + case protocol.Encryption1RTT: + return true + default: + panic("unknown encryption level") + } +} + +func (p *frameParser) SetAckDelayExponent(exp uint8) { + p.ackDelayExponent = exp +} diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go new file mode 100644 index 00000000..ff5c83b8 --- /dev/null +++ b/internal/wire/frame_parser_test.go @@ -0,0 +1,410 @@ +package wire + +import ( + "bytes" + "time" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Frame parsing", func() { + var ( + buf *bytes.Buffer + parser FrameParser + ) + + BeforeEach(func() { + buf = &bytes.Buffer{} + parser = NewFrameParser(true, protocol.Version1) + }) + + It("returns nil if there's nothing more to read", func() { + f, err := parser.ParseNext(bytes.NewReader(nil), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeNil()) + }) + + It("skips PADDING frames", func() { + buf.Write([]byte{0}) // PADDING frame + (&PingFrame{}).Write(buf, protocol.Version1) + f, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(Equal(&PingFrame{})) + }) + + It("handles PADDING at the end", func() { + r := bytes.NewReader([]byte{0, 0, 0}) + f, err := parser.ParseNext(r, protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeNil()) + Expect(r.Len()).To(BeZero()) + }) + + It("unpacks ACK frames", func() { + f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(BeAssignableToTypeOf(f)) + Expect(frame.(*AckFrame).LargestAcked()).To(Equal(protocol.PacketNumber(0x13))) + }) + + It("uses the custom ack delay exponent for 1RTT packets", func() { + parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: time.Second, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + // The ACK frame is always written using the protocol.AckDelayExponent. + // That's why we expect a different value when parsing. + Expect(frame.(*AckFrame).DelayTime).To(Equal(4 * time.Second)) + }) + + It("uses the default ack delay exponent for non-1RTT packets", func() { + parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: time.Second, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.(*AckFrame).DelayTime).To(Equal(time.Second)) + }) + + It("unpacks RESET_STREAM frames", func() { + f := &ResetStreamFrame{ + StreamID: 0xdeadbeef, + FinalSize: 0xdecafbad1234, + ErrorCode: 0x1337, + } + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks STOP_SENDING frames", func() { + f := &StopSendingFrame{StreamID: 0x42} + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks CRYPTO frames", func() { + f := &CryptoFrame{ + Offset: 0x1337, + Data: []byte("lorem ipsum"), + } + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks NEW_TOKEN frames", func() { + f := &NewTokenFrame{Token: []byte("foobar")} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks STREAM frames", func() { + f := &StreamFrame{ + StreamID: 0x42, + Offset: 0x1337, + Fin: true, + Data: []byte("foobar"), + } + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks MAX_DATA frames", func() { + f := &MaxDataFrame{ + MaximumData: 0xcafe, + } + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks MAX_STREAM_DATA frames", func() { + f := &MaxStreamDataFrame{ + StreamID: 0xdeadbeef, + MaximumStreamData: 0xdecafbad, + } + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks MAX_STREAMS frames", func() { + f := &MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: 0x1337, + } + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks DATA_BLOCKED frames", func() { + f := &DataBlockedFrame{MaximumData: 0x1234} + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks STREAM_DATA_BLOCKED frames", func() { + f := &StreamDataBlockedFrame{ + StreamID: 0xdeadbeef, + MaximumStreamData: 0xdead, + } + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks STREAMS_BLOCKED frames", func() { + f := &StreamsBlockedFrame{ + Type: protocol.StreamTypeBidi, + StreamLimit: 0x1234567, + } + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks NEW_CONNECTION_ID frames", func() { + f := &NewConnectionIDFrame{ + SequenceNumber: 0x1337, + ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + } + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks RETIRE_CONNECTION_ID frames", func() { + f := &RetireConnectionIDFrame{SequenceNumber: 0x1337} + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks PATH_CHALLENGE frames", func() { + f := &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(BeAssignableToTypeOf(f)) + Expect(frame.(*PathChallengeFrame).Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) + }) + + It("unpacks PATH_RESPONSE frames", func() { + f := &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(BeAssignableToTypeOf(f)) + Expect(frame.(*PathResponseFrame).Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) + }) + + It("unpacks CONNECTION_CLOSE frames", func() { + f := &ConnectionCloseFrame{ + IsApplicationError: true, + ReasonPhrase: "foobar", + } + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks HANDSHAKE_DONE frames", func() { + f := &HandshakeDoneFrame{} + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks DATAGRAM frames", func() { + f := &DatagramFrame{Data: []byte("foobar")} + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when DATAGRAM frames are not supported", func() { + parser = NewFrameParser(false, protocol.Version1) + f := &DatagramFrame{Data: []byte("foobar")} + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + _, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, + FrameType: 0x30, + ErrorMessage: "unknown frame type", + })) + }) + + It("errors on invalid type", func() { + _, err := parser.ParseNext(bytes.NewReader([]byte{0x42}), protocol.Encryption1RTT) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, + FrameType: 0x42, + ErrorMessage: "unknown frame type", + })) + }) + + It("errors on invalid frames", func() { + f := &MaxStreamDataFrame{ + StreamID: 0x1337, + MaximumStreamData: 0xdeadbeef, + } + b := &bytes.Buffer{} + f.Write(b, protocol.Version1) + _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2]), protocol.Encryption1RTT) + Expect(err).To(HaveOccurred()) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + }) + + Context("encryption level check", func() { + frames := []Frame{ + &PingFrame{}, + &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 42}}}, + &ResetStreamFrame{}, + &StopSendingFrame{}, + &CryptoFrame{}, + &NewTokenFrame{Token: []byte("lorem ipsum")}, + &StreamFrame{Data: []byte("foobar")}, + &MaxDataFrame{}, + &MaxStreamDataFrame{}, + &MaxStreamsFrame{}, + &DataBlockedFrame{}, + &StreamDataBlockedFrame{}, + &StreamsBlockedFrame{}, + &NewConnectionIDFrame{ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}}, + &RetireConnectionIDFrame{}, + &PathChallengeFrame{}, + &PathResponseFrame{}, + &ConnectionCloseFrame{}, + &HandshakeDoneFrame{}, + &DatagramFrame{}, + } + + var framesSerialized [][]byte + + BeforeEach(func() { + framesSerialized = nil + for _, frame := range frames { + buf := &bytes.Buffer{} + Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) + framesSerialized = append(framesSerialized, buf.Bytes()) + } + }) + + It("rejects all frames but ACK, CRYPTO, PING and CONNECTION_CLOSE in Initial packets", func() { + for i, b := range framesSerialized { + _, err := parser.ParseNext(bytes.NewReader(b), protocol.EncryptionInitial) + switch frames[i].(type) { + case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *PingFrame: + Expect(err).ToNot(HaveOccurred()) + default: + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level Initial")) + } + } + }) + + It("rejects all frames but ACK, CRYPTO, PING and CONNECTION_CLOSE in Handshake packets", func() { + for i, b := range framesSerialized { + _, err := parser.ParseNext(bytes.NewReader(b), protocol.EncryptionHandshake) + switch frames[i].(type) { + case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *PingFrame: + Expect(err).ToNot(HaveOccurred()) + default: + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level Handshake")) + } + } + }) + + It("rejects all frames but ACK, CRYPTO, CONNECTION_CLOSE, NEW_TOKEN, PATH_RESPONSE and RETIRE_CONNECTION_ID in 0-RTT packets", func() { + for i, b := range framesSerialized { + _, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption0RTT) + switch frames[i].(type) { + case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level 0-RTT")) + default: + Expect(err).ToNot(HaveOccurred()) + } + } + }) + + It("accepts all frame types in 1-RTT packets", func() { + for _, b := range framesSerialized { + _, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + } + }) + }) +}) diff --git a/internal/wire/handshake_done_frame.go b/internal/wire/handshake_done_frame.go new file mode 100644 index 00000000..fad59371 --- /dev/null +++ b/internal/wire/handshake_done_frame.go @@ -0,0 +1,29 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" +) + +// A HandshakeDoneFrame is a HANDSHAKE_DONE frame +type HandshakeDoneFrame struct{} + +// ParseHandshakeDoneFrame parses a HandshakeDone frame +func parseHandshakeDoneFrame(r *bytes.Reader, _ quic.VersionNumber) (*HandshakeDoneFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + return &HandshakeDoneFrame{}, nil +} + +func (f *HandshakeDoneFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + b.WriteByte(0x1e) + return nil +} + +// Length of a written frame +func (f *HandshakeDoneFrame) Length(_ quic.VersionNumber) protocol.ByteCount { + return 1 +} diff --git a/internal/wire/header.go b/internal/wire/header.go new file mode 100644 index 00000000..2cce3fac --- /dev/null +++ b/internal/wire/header.go @@ -0,0 +1,275 @@ +package wire + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "github.com/lucas-clemente/quic-go" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/utils" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// ParseConnectionID parses the destination connection ID of a packet. +// It uses the data slice for the connection ID. +// That means that the connection ID must not be used after the packet buffer is released. +func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { + if len(data) == 0 { + return nil, io.EOF + } + isLongHeader := data[0]&0x80 > 0 + if !isLongHeader { + if len(data) < shortHeaderConnIDLen+1 { + return nil, io.EOF + } + return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil + } + if len(data) < 6 { + return nil, io.EOF + } + destConnIDLen := int(data[5]) + if len(data) < 6+destConnIDLen { + return nil, io.EOF + } + return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil +} + +// IsVersionNegotiationPacket says if this is a version negotiation packet +func IsVersionNegotiationPacket(b []byte) bool { + if len(b) < 5 { + return false + } + return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 +} + +// Is0RTTPacket says if this is a 0-RTT packet. +// A packet sent with a version we don't understand can never be a 0-RTT packet. +func Is0RTTPacket(b []byte) bool { + if len(b) < 5 { + return false + } + if b[0]&0x80 == 0 { + return false + } + version := quic.VersionNumber(binary.BigEndian.Uint32(b[1:5])) + if !protocol.IsSupportedVersion(protocol.SupportedVersions, version) { + return false + } + if version == protocol.Version2 { + return b[0]>>4&0b11 == 0b10 + } + return b[0]>>4&0b11 == 0b01 +} + +var ErrUnsupportedVersion = errors.New("unsupported version") + +// The Header is the version independent part of the header +type Header struct { + IsLongHeader bool + typeByte byte + Type protocol.PacketType + + Version quic.VersionNumber + SrcConnectionID protocol.ConnectionID + DestConnectionID protocol.ConnectionID + + Length protocol.ByteCount + + Token []byte + + parsedLen protocol.ByteCount // how many bytes were read while parsing this header +} + +// ParsePacket parses a packet. +// If the packet has a long header, the packet is cut according to the length field. +// If we understand the version, the packet is header up unto the packet number. +// Otherwise, only the invariant part of the header is parsed. +func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) { + hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen) + if err != nil { + if err == ErrUnsupportedVersion { + return hdr, nil, nil, ErrUnsupportedVersion + } + return nil, nil, nil, err + } + var rest []byte + if hdr.IsLongHeader { + if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length { + return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) + } + packetLen := int(hdr.ParsedLen() + hdr.Length) + rest = data[packetLen:] + data = data[:packetLen] + } + return hdr, data, rest, nil +} + +// ParseHeader parses the header. +// For short header packets: up to the packet number. +// For long header packets: +// * if we understand the version: up to the packet number +// * if not, only the invariant part of the header +func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { + startLen := b.Len() + h, err := parseHeaderImpl(b, shortHeaderConnIDLen) + if err != nil { + return h, err + } + h.parsedLen = protocol.ByteCount(startLen - b.Len()) + return h, err +} + +func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { + typeByte, err := b.ReadByte() + if err != nil { + return nil, err + } + + h := &Header{ + typeByte: typeByte, + IsLongHeader: typeByte&0x80 > 0, + } + + if !h.IsLongHeader { + if h.typeByte&0x40 == 0 { + return nil, errors.New("not a QUIC packet") + } + if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil { + return nil, err + } + return h, nil + } + return h, h.parseLongHeader(b) +} + +func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error { + var err error + h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen) + return err +} + +func (h *Header) parseLongHeader(b *bytes.Reader) error { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return err + } + h.Version = quic.VersionNumber(v) + if h.Version != 0 && h.typeByte&0x40 == 0 { + return errors.New("not a QUIC packet") + } + destConnIDLen, err := b.ReadByte() + if err != nil { + return err + } + h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen)) + if err != nil { + return err + } + srcConnIDLen, err := b.ReadByte() + if err != nil { + return err + } + h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen)) + if err != nil { + return err + } + if h.Version == 0 { // version negotiation packet + return nil + } + // If we don't understand the version, we have no idea how to interpret the rest of the bytes + if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) { + return ErrUnsupportedVersion + } + + if h.Version == protocol.Version2 { + switch h.typeByte >> 4 & 0b11 { + case 0b00: + h.Type = protocol.PacketTypeRetry + case 0b01: + h.Type = protocol.PacketTypeInitial + case 0b10: + h.Type = protocol.PacketType0RTT + case 0b11: + h.Type = protocol.PacketTypeHandshake + } + } else { + switch h.typeByte >> 4 & 0b11 { + case 0b00: + h.Type = protocol.PacketTypeInitial + case 0b01: + h.Type = protocol.PacketType0RTT + case 0b10: + h.Type = protocol.PacketTypeHandshake + case 0b11: + h.Type = protocol.PacketTypeRetry + } + } + + if h.Type == protocol.PacketTypeRetry { + tokenLen := b.Len() - 16 + if tokenLen <= 0 { + return io.EOF + } + h.Token = make([]byte, tokenLen) + if _, err := io.ReadFull(b, h.Token); err != nil { + return err + } + _, err := b.Seek(16, io.SeekCurrent) + return err + } + + if h.Type == protocol.PacketTypeInitial { + tokenLen, err := quicvarint.Read(b) + if err != nil { + return err + } + if tokenLen > uint64(b.Len()) { + return io.EOF + } + h.Token = make([]byte, tokenLen) + if _, err := io.ReadFull(b, h.Token); err != nil { + return err + } + } + + pl, err := quicvarint.Read(b) + if err != nil { + return err + } + h.Length = protocol.ByteCount(pl) + return nil +} + +// ParsedLen returns the number of bytes that were consumed when parsing the header +func (h *Header) ParsedLen() protocol.ByteCount { + return h.parsedLen +} + +// ParseExtended parses the version dependent part of the header. +// The Reader has to be set such that it points to the first byte of the header. +func (h *Header) ParseExtended(b *bytes.Reader, ver quic.VersionNumber) (*ExtendedHeader, error) { + extHdr := h.toExtendedHeader() + reservedBitsValid, err := extHdr.parse(b, ver) + if err != nil { + return nil, err + } + if !reservedBitsValid { + return extHdr, ErrInvalidReservedBits + } + return extHdr, nil +} + +func (h *Header) toExtendedHeader() *ExtendedHeader { + return &ExtendedHeader{Header: *h} +} + +// PacketType is the type of the packet, for logging purposes +func (h *Header) PacketType() string { + if h.IsLongHeader { + return h.Type.String() + } + return "1-RTT" +} diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go new file mode 100644 index 00000000..f7ecab39 --- /dev/null +++ b/internal/wire/header_test.go @@ -0,0 +1,583 @@ +package wire + +import ( + "bytes" + "encoding/binary" + "io" + + "github.com/imroc/req/v3/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Header Parsing", func() { + Context("Parsing the Connection ID", func() { + It("parses the connection ID of a long header packet", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, + Version: protocol.Version1, + }, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + connID, err := ParseConnectionID(buf.Bytes(), 8) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + }) + + It("parses the connection ID of a short header packet", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + }, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + buf.Write([]byte("foobar")) + connID, err := ParseConnectionID(buf.Bytes(), 4) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + }) + + It("errors on EOF, for short header packets", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + }, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + data := buf.Bytes()[:buf.Len()-2] // cut the packet number + _, err := ParseConnectionID(data, 8) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < len(data); i++ { + b := make([]byte, i) + copy(b, data[:i]) + _, err := ParseConnectionID(b, 8) + Expect(err).To(MatchError(io.EOF)) + } + }) + + It("errors on EOF, for long header packets", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 8, 9}, + Version: protocol.Version1, + }, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + data := buf.Bytes()[:buf.Len()-2] // cut the packet number + _, err := ParseConnectionID(data, 8) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ { + b := make([]byte, i) + copy(b, data[:i]) + _, err := ParseConnectionID(b, 8) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("identifying 0-RTT packets", func() { + It("recognizes 0-RTT packets, for QUIC v1", func() { + zeroRTTHeader := make([]byte, 5) + zeroRTTHeader[0] = 0x80 | 0b01<<4 + binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version1)) + + Expect(Is0RTTPacket(zeroRTTHeader)).To(BeTrue()) + Expect(Is0RTTPacket(zeroRTTHeader[:4])).To(BeFalse()) // too short + Expect(Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})).To(BeFalse()) // unknown version + Expect(Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})).To(BeFalse()) // short header + Expect(Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))).To(BeTrue()) + }) + + It("recognizes 0-RTT packets, for QUIC v2", func() { + zeroRTTHeader := make([]byte, 5) + zeroRTTHeader[0] = 0x80 | 0b10<<4 + binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version2)) + + Expect(Is0RTTPacket(zeroRTTHeader)).To(BeTrue()) + Expect(Is0RTTPacket(zeroRTTHeader[:4])).To(BeFalse()) // too short + Expect(Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})).To(BeFalse()) // unknown version + Expect(Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})).To(BeFalse()) // short header + Expect(Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))).To(BeTrue()) + }) + }) + + Context("Identifying Version Negotiation Packets", func() { + It("identifies version negotiation packets", func() { + Expect(IsVersionNegotiationPacket([]byte{0x80 | 0x56, 0, 0, 0, 0})).To(BeTrue()) + Expect(IsVersionNegotiationPacket([]byte{0x56, 0, 0, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 1, 0, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 1, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 1, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 0, 1})).To(BeFalse()) + }) + + It("returns false on EOF", func() { + vnp := []byte{0x80, 0, 0, 0, 0} + for i := range vnp { + Expect(IsVersionNegotiationPacket(vnp[:i])).To(BeFalse()) + } + }) + }) + + Context("Long Headers", func() { + It("parses a Long Header", func() { + destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} + srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + data := []byte{0xc0 ^ 0x3} + data = appendVersion(data, protocol.Version1) + data = append(data, 0x9) // dest conn id length + data = append(data, destConnID...) + data = append(data, 0x4) // src conn id length + data = append(data, srcConnID...) + data = append(data, encodeVarInt(6)...) // token length + data = append(data, []byte("foobar")...) // token + data = append(data, encodeVarInt(10)...) // length + hdrLen := len(data) + data = append(data, []byte{0, 0, 0xbe, 0xef}...) // packet number + data = append(data, []byte("foobar")...) + Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) + + hdr, pdata, rest, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(pdata).To(Equal(data)) + Expect(hdr.IsLongHeader).To(BeTrue()) + Expect(hdr.DestConnectionID).To(Equal(destConnID)) + Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) + Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(hdr.Token).To(Equal([]byte("foobar"))) + Expect(hdr.Length).To(Equal(protocol.ByteCount(10))) + Expect(hdr.Version).To(Equal(protocol.Version1)) + Expect(rest).To(BeEmpty()) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef))) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(b.Len()).To(Equal(6)) // foobar + Expect(hdr.ParsedLen()).To(BeEquivalentTo(hdrLen)) + Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 4)) + }) + + It("errors if 0x40 is not set", func() { + data := []byte{ + 0x80 | 0x2<<4, + 0x11, // connection ID lengths + 0xde, 0xca, 0xfb, 0xad, // dest conn ID + 0xde, 0xad, 0xbe, 0xef, // src conn ID + } + _, _, _, err := ParsePacket(data, 0) + Expect(err).To(MatchError("not a QUIC packet")) + }) + + It("stops parsing when encountering an unsupported version", func() { + data := []byte{ + 0xc0, + 0xde, 0xad, 0xbe, 0xef, + 0x8, // dest conn ID len + 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, // dest conn ID + 0x8, // src conn ID len + 0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, // src conn ID + 'f', 'o', 'o', 'b', 'a', 'r', // unspecified bytes + } + hdr, _, rest, err := ParsePacket(data, 0) + Expect(err).To(MatchError(ErrUnsupportedVersion)) + Expect(hdr.IsLongHeader).To(BeTrue()) + Expect(hdr.Version).To(Equal(quic.VersionNumber(0xdeadbeef))) + Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1})) + Expect(rest).To(BeEmpty()) + }) + + It("parses a Long Header without a destination connection ID", func() { + data := []byte{0xc0 ^ 0x1<<4} + data = appendVersion(data, protocol.Version1) + data = append(data, 0x0) // dest conn ID len + data = append(data, 0x4) // src conn ID len + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID + data = append(data, encodeVarInt(0)...) // length + data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) + Expect(hdr.DestConnectionID).To(BeEmpty()) + }) + + It("parses a Long Header without a source connection ID", func() { + data := []byte{0xc0 ^ 0x2<<4} + data = appendVersion(data, protocol.Version1) + data = append(data, 0xa) // dest conn ID len + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // dest connection ID + data = append(data, 0x0) // src conn ID len + data = append(data, encodeVarInt(0)...) // length + data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.SrcConnectionID).To(BeEmpty()) + Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + }) + + It("parses a Long Header with a 2 byte packet number", func() { + data := []byte{0xc0 ^ 0x1} + data = appendVersion(data, protocol.Version1) // version number + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths + data = append(data, encodeVarInt(0)...) // token length + data = append(data, encodeVarInt(0)...) // length + data = append(data, []byte{0x1, 0x23}...) + + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x123))) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) + Expect(b.Len()).To(BeZero()) + }) + + It("parses a Retry packet, for QUIC v1", func() { + data := []byte{0xc0 | 0b11<<4 | (10 - 3) /* connection ID length */} + data = appendVersion(data, protocol.Version1) + data = append(data, []byte{6}...) // dest conn ID len + data = append(data, []byte{6, 5, 4, 3, 2, 1}...) // dest conn ID + data = append(data, []byte{10}...) // src conn ID len + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID + data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token + data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...) + hdr, pdata, rest, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) + Expect(hdr.Version).To(Equal(protocol.Version1)) + Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(hdr.Token).To(Equal([]byte("foobar"))) + Expect(pdata).To(Equal(data)) + Expect(rest).To(BeEmpty()) + }) + + It("parses a Retry packet, for QUIC v2", func() { + data := []byte{0xc0 | 0b00<<4 | (10 - 3) /* connection ID length */} + data = appendVersion(data, protocol.Version2) + data = append(data, []byte{6}...) // dest conn ID len + data = append(data, []byte{6, 5, 4, 3, 2, 1}...) // dest conn ID + data = append(data, []byte{10}...) // src conn ID len + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID + data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token + data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...) + hdr, pdata, rest, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) + Expect(hdr.Version).To(Equal(protocol.Version2)) + Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(hdr.Token).To(Equal([]byte("foobar"))) + Expect(pdata).To(Equal(data)) + Expect(rest).To(BeEmpty()) + }) + + It("errors if the Retry packet is too short for the integrity tag", func() { + data := []byte{0xc0 | 0x3<<4 | (10 - 3) /* connection ID length */} + data = appendVersion(data, protocol.Version1) + data = append(data, []byte{0, 0}...) // conn ID lens + data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) + // this results in a token length of 0 + _, _, _, err := ParsePacket(data, 0) + Expect(err).To(MatchError(io.EOF)) + }) + + It("errors if the token length is too large", func() { + data := []byte{0xc0 ^ 0x1} + data = appendVersion(data, protocol.Version1) + data = append(data, 0x0) // connection ID lengths + data = append(data, encodeVarInt(4)...) // token length: 4 bytes (1 byte too long) + data = append(data, encodeVarInt(0x42)...) // length, 1 byte + data = append(data, []byte{0x12, 0x34}...) // packet number + + _, _, _, err := ParsePacket(data, 0) + Expect(err).To(MatchError(io.EOF)) + }) + + It("errors if the 5th or 6th bit are set", func() { + data := []byte{0xc0 | 0x2<<4 | 0x8 /* set the 5th bit */ | 0x1 /* 2 byte packet number */} + data = appendVersion(data, protocol.Version1) + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths + data = append(data, encodeVarInt(2)...) // length + data = append(data, []byte{0x12, 0x34}...) // packet number + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) + extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) + Expect(err).To(MatchError(ErrInvalidReservedBits)) + Expect(extHdr).ToNot(BeNil()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1234))) + }) + + It("errors on EOF, when parsing the header", func() { + data := []byte{0xc0 ^ 0x2<<4} + data = appendVersion(data, protocol.Version1) + data = append(data, 0x8) // dest conn ID len + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // dest conn ID + data = append(data, 0x8) // src conn ID len + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // src conn ID + for i := 0; i < len(data); i++ { + _, _, _, err := ParsePacket(data[:i], 0) + Expect(err).To(Equal(io.EOF)) + } + }) + + It("errors on EOF, when parsing the extended header", func() { + data := []byte{0xc0 | 0x2<<4 | 0x3} + data = appendVersion(data, protocol.Version1) + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths + data = append(data, encodeVarInt(0)...) // length + hdrLen := len(data) + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number + for i := hdrLen; i < len(data); i++ { + data = data[:i] + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) + _, err = hdr.ParseExtended(b, protocol.Version1) + Expect(err).To(Equal(io.EOF)) + } + }) + + It("errors on EOF, for a Retry packet", func() { + data := []byte{0xc0 ^ 0x3<<4} + data = appendVersion(data, protocol.Version1) + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths + data = append(data, 0xa) // Orig Destination Connection ID length + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID + hdrLen := len(data) + for i := hdrLen; i < len(data); i++ { + data = data[:i] + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) + _, err = hdr.ParseExtended(b, protocol.Version1) + Expect(err).To(Equal(io.EOF)) + } + }) + + Context("coalesced packets", func() { + It("cuts packets", func() { + buf := &bytes.Buffer{} + hdr := Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 2 + 6, + Version: protocol.Version1, + } + Expect((&ExtendedHeader{ + Header: hdr, + PacketNumber: 0x1337, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + hdrRaw := append([]byte{}, buf.Bytes()...) + buf.Write([]byte("foobar")) // payload of the first packet + buf.Write([]byte("raboof")) // second packet + parsedHdr, data, rest, err := ParsePacket(buf.Bytes(), 4) + Expect(err).ToNot(HaveOccurred()) + Expect(parsedHdr.Type).To(Equal(hdr.Type)) + Expect(parsedHdr.DestConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(data).To(Equal(append(hdrRaw, []byte("foobar")...))) + Expect(rest).To(Equal([]byte("raboof"))) + }) + + It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 3, + Version: protocol.Version1, + }, + PacketNumber: 0x1337, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + _, _, _, err := ParsePacket(buf.Bytes(), 4) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)")) + }) + + It("errors on packets that are smaller than the length in the packet header, for too small payload", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 1000, + Version: protocol.Version1, + }, + PacketNumber: 0x1337, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + buf.Write(make([]byte, 500-2 /* for packet number length */)) + _, _, _, err := ParsePacket(buf.Bytes(), 4) + Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) + }) + }) + }) + + Context("Short Headers", func() { + It("reads a Short Header with a 8 byte connection ID", func() { + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} + data := append([]byte{0x40}, connID...) + data = append(data, 0x42) // packet number + Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) + + hdr, pdata, rest, err := ParsePacket(data, 8) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsLongHeader).To(BeFalse()) + Expect(hdr.DestConnectionID).To(Equal(connID)) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) + Expect(extHdr.DestConnectionID).To(Equal(connID)) + Expect(extHdr.SrcConnectionID).To(BeEmpty()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) + Expect(hdr.ParsedLen()).To(BeEquivalentTo(len(data) - 1)) + Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 1)) + Expect(pdata).To(Equal(data)) + Expect(rest).To(BeEmpty()) + }) + + It("errors if 0x40 is not set", func() { + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} + data := append([]byte{0x0}, connID...) + _, _, _, err := ParsePacket(data, 8) + Expect(err).To(MatchError("not a QUIC packet")) + }) + + It("errors if the 4th or 5th bit are set", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5} + data := append([]byte{0x40 | 0x10 /* set the 4th bit */}, connID...) + data = append(data, 0x42) // packet number + hdr, _, _, err := ParsePacket(data, 5) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsLongHeader).To(BeFalse()) + extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) + Expect(err).To(MatchError(ErrInvalidReservedBits)) + Expect(extHdr).ToNot(BeNil()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) + }) + + It("reads a Short Header with a 5 byte connection ID", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5} + data := append([]byte{0x40}, connID...) + data = append(data, 0x42) // packet number + hdr, pdata, rest, err := ParsePacket(data, 5) + Expect(err).ToNot(HaveOccurred()) + Expect(pdata).To(HaveLen(len(data))) + Expect(hdr.IsLongHeader).To(BeFalse()) + Expect(hdr.DestConnectionID).To(Equal(connID)) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) + Expect(extHdr.DestConnectionID).To(Equal(connID)) + Expect(extHdr.SrcConnectionID).To(BeEmpty()) + Expect(rest).To(BeEmpty()) + }) + + It("reads the Key Phase Bit", func() { + data := []byte{ + 0x40 ^ 0x4, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID + } + data = append(data, 11) // packet number + hdr, _, _, err := ParsePacket(data, 6) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsLongHeader).To(BeFalse()) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseOne)) + Expect(b.Len()).To(BeZero()) + }) + + It("reads a header with a 2 byte packet number", func() { + data := []byte{ + 0x40 | 0x1, + 0xde, 0xad, 0xbe, 0xef, // connection ID + } + data = append(data, []byte{0x13, 0x37}...) // packet number + hdr, _, _, err := ParsePacket(data, 4) + Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.IsLongHeader).To(BeFalse()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) + Expect(b.Len()).To(BeZero()) + }) + + It("reads a header with a 3 byte packet number", func() { + data := []byte{ + 0x40 | 0x2, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x1, 0x2, 0x3, 0x4, // connection ID + } + data = append(data, []byte{0x99, 0xbe, 0xef}...) // packet number + hdr, _, _, err := ParsePacket(data, 10) + Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.IsLongHeader).To(BeFalse()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef))) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen3)) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOF, when parsing the header", func() { + data := []byte{ + 0x40 ^ 0x2, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + } + for i := 0; i < len(data); i++ { + data = data[:i] + _, _, _, err := ParsePacket(data, 8) + Expect(err).To(Equal(io.EOF)) + } + }) + + It("errors on EOF, when parsing the extended header", func() { + data := []byte{ + 0x40 ^ 0x3, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID + } + hdrLen := len(data) + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number + for i := hdrLen; i < len(data); i++ { + data = data[:i] + hdr, _, _, err := ParsePacket(data, 6) + Expect(err).ToNot(HaveOccurred()) + _, err = hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) + Expect(err).To(Equal(io.EOF)) + } + }) + }) + + It("tells its packet type for logging", func() { + Expect((&Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}).PacketType()).To(Equal("Handshake")) + Expect((&Header{}).PacketType()).To(Equal("1-RTT")) + }) +}) diff --git a/internal/wire/interface.go b/internal/wire/interface.go new file mode 100644 index 00000000..b096a6e1 --- /dev/null +++ b/internal/wire/interface.go @@ -0,0 +1,20 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" +) + +// A Frame in QUIC +type Frame interface { + Write(b *bytes.Buffer, version quic.VersionNumber) error + Length(version quic.VersionNumber) protocol.ByteCount +} + +// A FrameParser parses QUIC frames, one by one. +type FrameParser interface { + ParseNext(*bytes.Reader, protocol.EncryptionLevel) (Frame, error) + SetAckDelayExponent(uint8) +} diff --git a/internal/wire/log.go b/internal/wire/log.go new file mode 100644 index 00000000..276465ee --- /dev/null +++ b/internal/wire/log.go @@ -0,0 +1,72 @@ +package wire + +import ( + "fmt" + "strings" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/utils" +) + +// LogFrame logs a frame, either sent or received +func LogFrame(logger utils.Logger, frame Frame, sent bool) { + if !logger.Debug() { + return + } + dir := "<-" + if sent { + dir = "->" + } + switch f := frame.(type) { + case *CryptoFrame: + dataLen := protocol.ByteCount(len(f.Data)) + logger.Debugf("\t%s &wire.CryptoFrame{Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.Offset, dataLen, f.Offset+dataLen) + case *StreamFrame: + logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, Fin: %t, Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.StreamID, f.Fin, f.Offset, f.DataLen(), f.Offset+f.DataLen()) + case *ResetStreamFrame: + logger.Debugf("\t%s &wire.ResetStreamFrame{StreamID: %d, ErrorCode: %#x, FinalSize: %d}", dir, f.StreamID, f.ErrorCode, f.FinalSize) + case *AckFrame: + hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 + var ecn string + if hasECN { + ecn = fmt.Sprintf(", ECT0: %d, ECT1: %d, CE: %d", f.ECT0, f.ECT1, f.ECNCE) + } + if len(f.AckRanges) > 1 { + ackRanges := make([]string, len(f.AckRanges)) + for i, r := range f.AckRanges { + ackRanges[i] = fmt.Sprintf("{Largest: %d, Smallest: %d}", r.Largest, r.Smallest) + } + logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, AckRanges: {%s}, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), strings.Join(ackRanges, ", "), f.DelayTime.String(), ecn) + } else { + logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), f.DelayTime.String(), ecn) + } + case *MaxDataFrame: + logger.Debugf("\t%s &wire.MaxDataFrame{MaximumData: %d}", dir, f.MaximumData) + case *MaxStreamDataFrame: + logger.Debugf("\t%s &wire.MaxStreamDataFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) + case *DataBlockedFrame: + logger.Debugf("\t%s &wire.DataBlockedFrame{MaximumData: %d}", dir, f.MaximumData) + case *StreamDataBlockedFrame: + logger.Debugf("\t%s &wire.StreamDataBlockedFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) + case *MaxStreamsFrame: + switch f.Type { + case protocol.StreamTypeUni: + logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: uni, MaxStreamNum: %d}", dir, f.MaxStreamNum) + case protocol.StreamTypeBidi: + logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: %d}", dir, f.MaxStreamNum) + } + case *StreamsBlockedFrame: + switch f.Type { + case protocol.StreamTypeUni: + logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: uni, MaxStreams: %d}", dir, f.StreamLimit) + case protocol.StreamTypeBidi: + logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit) + } + case *NewConnectionIDFrame: + logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken) + case *NewTokenFrame: + logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token) + default: + logger.Debugf("\t%s %#v", dir, frame) + } +} diff --git a/internal/wire/log_test.go b/internal/wire/log_test.go new file mode 100644 index 00000000..7094fcc5 --- /dev/null +++ b/internal/wire/log_test.go @@ -0,0 +1,168 @@ +package wire + +import ( + "bytes" + "log" + "os" + "time" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Frame logging", func() { + var ( + buf *bytes.Buffer + logger utils.Logger + ) + + BeforeEach(func() { + buf = &bytes.Buffer{} + logger = utils.DefaultLogger + logger.SetLogLevel(utils.LogLevelDebug) + log.SetOutput(buf) + }) + + AfterEach(func() { + log.SetOutput(os.Stdout) + }) + + It("doesn't log when debug is disabled", func() { + logger.SetLogLevel(utils.LogLevelInfo) + LogFrame(logger, &ResetStreamFrame{}, true) + Expect(buf.Len()).To(BeZero()) + }) + + It("logs sent frames", func() { + LogFrame(logger, &ResetStreamFrame{}, true) + Expect(buf.String()).To(ContainSubstring("\t-> &wire.ResetStreamFrame{StreamID: 0, ErrorCode: 0x0, FinalSize: 0}\n")) + }) + + It("logs received frames", func() { + LogFrame(logger, &ResetStreamFrame{}, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.ResetStreamFrame{StreamID: 0, ErrorCode: 0x0, FinalSize: 0}\n")) + }) + + It("logs CRYPTO frames", func() { + frame := &CryptoFrame{ + Offset: 42, + Data: make([]byte, 123), + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.CryptoFrame{Offset: 42, Data length: 123, Offset + Data length: 165}\n")) + }) + + It("logs STREAM frames", func() { + frame := &StreamFrame{ + StreamID: 42, + Offset: 1337, + Data: bytes.Repeat([]byte{'f'}, 100), + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamFrame{StreamID: 42, Fin: false, Offset: 1337, Data length: 100, Offset + Data length: 1437}\n")) + }) + + It("logs ACK frames without missing packets", func() { + frame := &AckFrame{ + AckRanges: []AckRange{{Smallest: 42, Largest: 1337}}, + DelayTime: 1 * time.Millisecond, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 1337, LowestAcked: 42, DelayTime: 1ms}\n")) + }) + + It("logs ACK frames with ECN", func() { + frame := &AckFrame{ + AckRanges: []AckRange{{Smallest: 42, Largest: 1337}}, + DelayTime: 1 * time.Millisecond, + ECT0: 5, + ECT1: 66, + ECNCE: 777, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 1337, LowestAcked: 42, DelayTime: 1ms, ECT0: 5, ECT1: 66, CE: 777}\n")) + }) + + It("logs ACK frames with missing packets", func() { + frame := &AckFrame{ + AckRanges: []AckRange{ + {Smallest: 5, Largest: 8}, + {Smallest: 2, Largest: 3}, + }, + DelayTime: 12 * time.Millisecond, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 8, LowestAcked: 2, AckRanges: {{Largest: 8, Smallest: 5}, {Largest: 3, Smallest: 2}}, DelayTime: 12ms}\n")) + }) + + It("logs MAX_STREAMS frames", func() { + frame := &MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: 42, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: 42}\n")) + }) + + It("logs MAX_DATA frames", func() { + frame := &MaxDataFrame{ + MaximumData: 42, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxDataFrame{MaximumData: 42}\n")) + }) + + It("logs MAX_STREAM_DATA frames", func() { + frame := &MaxStreamDataFrame{ + StreamID: 10, + MaximumStreamData: 42, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 42}\n")) + }) + + It("logs DATA_BLOCKED frames", func() { + frame := &DataBlockedFrame{ + MaximumData: 1000, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.DataBlockedFrame{MaximumData: 1000}\n")) + }) + + It("logs STREAM_DATA_BLOCKED frames", func() { + frame := &StreamDataBlockedFrame{ + StreamID: 42, + MaximumStreamData: 1000, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamDataBlockedFrame{StreamID: 42, MaximumStreamData: 1000}\n")) + }) + + It("logs STREAMS_BLOCKED frames", func() { + frame := &StreamsBlockedFrame{ + Type: protocol.StreamTypeBidi, + StreamLimit: 42, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: 42}\n")) + }) + + It("logs NEW_CONNECTION_ID frames", func() { + LogFrame(logger, &NewConnectionIDFrame{ + SequenceNumber: 42, + ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + StatelessResetToken: protocol.StatelessResetToken{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10}, + }, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.NewConnectionIDFrame{SequenceNumber: 42, ConnectionID: deadbeef, StatelessResetToken: 0x0102030405060708090a0b0c0d0e0f10}")) + }) + + It("logs NEW_TOKEN frames", func() { + LogFrame(logger, &NewTokenFrame{ + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + }, true) + Expect(buf.String()).To(ContainSubstring("\t-> &wire.NewTokenFrame{Token: 0xdeadbeef")) + }) +}) diff --git a/internal/wire/max_data_frame.go b/internal/wire/max_data_frame.go new file mode 100644 index 00000000..31ec503d --- /dev/null +++ b/internal/wire/max_data_frame.go @@ -0,0 +1,41 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A MaxDataFrame carries flow control information for the connection +type MaxDataFrame struct { + MaximumData protocol.ByteCount +} + +// parseMaxDataFrame parses a MAX_DATA frame +func parseMaxDataFrame(r *bytes.Reader, _ quic.VersionNumber) (*MaxDataFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + frame := &MaxDataFrame{} + byteOffset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + frame.MaximumData = protocol.ByteCount(byteOffset) + return frame, nil +} + +// Write writes a MAX_STREAM_DATA frame +func (f *MaxDataFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { + b.WriteByte(0x10) + quicvarint.Write(b, uint64(f.MaximumData)) + return nil +} + +// Length of a written frame +func (f *MaxDataFrame) Length(version quic.VersionNumber) protocol.ByteCount { + return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaximumData))) +} diff --git a/internal/wire/max_data_frame_test.go b/internal/wire/max_data_frame_test.go new file mode 100644 index 00000000..c363ecd8 --- /dev/null +++ b/internal/wire/max_data_frame_test.go @@ -0,0 +1,57 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("MAX_DATA frame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + data := []byte{0x10} + data = append(data, encodeVarInt(0xdecafbad123456)...) // byte offset + b := bytes.NewReader(data) + frame, err := parseMaxDataFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0xdecafbad123456))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x10} + data = append(data, encodeVarInt(0xdecafbad1234567)...) // byte offset + _, err := parseMaxDataFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseMaxDataFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("writing", func() { + It("has proper min length", func() { + f := &MaxDataFrame{ + MaximumData: 0xdeadbeef, + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0xdeadbeef))) + }) + + It("writes a MAX_DATA frame", func() { + b := &bytes.Buffer{} + f := &MaxDataFrame{ + MaximumData: 0xdeadbeefcafe, + } + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x10} + expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + }) +}) diff --git a/internal/wire/max_stream_data_frame.go b/internal/wire/max_stream_data_frame.go new file mode 100644 index 00000000..bc300279 --- /dev/null +++ b/internal/wire/max_stream_data_frame.go @@ -0,0 +1,47 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A MaxStreamDataFrame is a MAX_STREAM_DATA frame +type MaxStreamDataFrame struct { + StreamID protocol.StreamID + MaximumStreamData protocol.ByteCount +} + +func parseMaxStreamDataFrame(r *bytes.Reader, _ quic.VersionNumber) (*MaxStreamDataFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + sid, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + return &MaxStreamDataFrame{ + StreamID: protocol.StreamID(sid), + MaximumStreamData: protocol.ByteCount(offset), + }, nil +} + +func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { + b.WriteByte(0x11) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.MaximumStreamData)) + return nil +} + +// Length of a written frame +func (f *MaxStreamDataFrame) Length(version quic.VersionNumber) protocol.ByteCount { + return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.MaximumStreamData))) +} diff --git a/internal/wire/max_stream_data_frame_test.go b/internal/wire/max_stream_data_frame_test.go new file mode 100644 index 00000000..4e205ad9 --- /dev/null +++ b/internal/wire/max_stream_data_frame_test.go @@ -0,0 +1,63 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("MAX_STREAM_DATA frame", func() { + Context("parsing", func() { + It("accepts sample frame", func() { + data := []byte{0x11} + data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID + data = append(data, encodeVarInt(0x12345678)...) // Offset + b := bytes.NewReader(data) + frame, err := parseMaxStreamDataFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0x12345678))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x11} + data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID + data = append(data, encodeVarInt(0x12345678)...) // Offset + _, err := parseMaxStreamDataFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseMaxStreamDataFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("writing", func() { + It("has proper min length", func() { + f := &MaxStreamDataFrame{ + StreamID: 0x1337, + MaximumStreamData: 0xdeadbeef, + } + Expect(f.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)))) + }) + + It("writes a sample frame", func() { + b := &bytes.Buffer{} + f := &MaxStreamDataFrame{ + StreamID: 0xdecafbad, + MaximumStreamData: 0xdeadbeefcafe42, + } + expected := []byte{0x11} + expected = append(expected, encodeVarInt(0xdecafbad)...) + expected = append(expected, encodeVarInt(0xdeadbeefcafe42)...) + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal(expected)) + }) + }) +}) diff --git a/internal/wire/max_streams_frame.go b/internal/wire/max_streams_frame.go new file mode 100644 index 00000000..d25c2cad --- /dev/null +++ b/internal/wire/max_streams_frame.go @@ -0,0 +1,56 @@ +package wire + +import ( + "bytes" + "fmt" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A MaxStreamsFrame is a MAX_STREAMS frame +type MaxStreamsFrame struct { + Type protocol.StreamType + MaxStreamNum protocol.StreamNum +} + +func parseMaxStreamsFrame(r *bytes.Reader, _ quic.VersionNumber) (*MaxStreamsFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &MaxStreamsFrame{} + switch typeByte { + case 0x12: + f.Type = protocol.StreamTypeBidi + case 0x13: + f.Type = protocol.StreamTypeUni + } + streamID, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.MaxStreamNum = protocol.StreamNum(streamID) + if f.MaxStreamNum > protocol.MaxStreamCount { + return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum) + } + return f, nil +} + +func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + switch f.Type { + case protocol.StreamTypeBidi: + b.WriteByte(0x12) + case protocol.StreamTypeUni: + b.WriteByte(0x13) + } + quicvarint.Write(b, uint64(f.MaxStreamNum)) + return nil +} + +// Length of a written frame +func (f *MaxStreamsFrame) Length(quic.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.MaxStreamNum)) +} diff --git a/internal/wire/max_streams_frame_test.go b/internal/wire/max_streams_frame_test.go new file mode 100644 index 00000000..bc7ec913 --- /dev/null +++ b/internal/wire/max_streams_frame_test.go @@ -0,0 +1,107 @@ +package wire + +import ( + "bytes" + "fmt" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("MAX_STREAMS frame", func() { + Context("parsing", func() { + It("accepts a frame for a bidirectional stream", func() { + data := []byte{0x12} + data = append(data, encodeVarInt(0xdecaf)...) + b := bytes.NewReader(data) + f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Type).To(Equal(protocol.StreamTypeBidi)) + Expect(f.MaxStreamNum).To(BeEquivalentTo(0xdecaf)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts a frame for a bidirectional stream", func() { + data := []byte{0x13} + data = append(data, encodeVarInt(0xdecaf)...) + b := bytes.NewReader(data) + f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Type).To(Equal(protocol.StreamTypeUni)) + Expect(f.MaxStreamNum).To(BeEquivalentTo(0xdecaf)) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x1d} + data = append(data, encodeVarInt(0xdeadbeefcafe13)...) + _, err := parseMaxStreamsFrame(bytes.NewReader(data), protocol.VersionWhatever) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseMaxStreamsFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) + } + }) + + for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { + streamType := t + + It("accepts a frame containing the maximum stream count", func() { + f := &MaxStreamsFrame{ + Type: streamType, + MaxStreamNum: protocol.MaxStreamCount, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + frame, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when receiving a too large stream count", func() { + f := &MaxStreamsFrame{ + Type: streamType, + MaxStreamNum: protocol.MaxStreamCount + 1, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + _, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) + }) + } + }) + + Context("writing", func() { + It("for a bidirectional stream", func() { + f := &MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: 0xdeadbeef, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x12} + expected = append(expected, encodeVarInt(0xdeadbeef)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("for a unidirectional stream", func() { + f := &MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreamNum: 0xdecafbad, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x13} + expected = append(expected, encodeVarInt(0xdecafbad)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := MaxStreamsFrame{MaxStreamNum: 0x1337} + Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(0x1337))) + }) + }) +}) diff --git a/internal/wire/new_connection_id_frame.go b/internal/wire/new_connection_id_frame.go new file mode 100644 index 00000000..a79603b9 --- /dev/null +++ b/internal/wire/new_connection_id_frame.go @@ -0,0 +1,81 @@ +package wire + +import ( + "bytes" + "fmt" + "github.com/lucas-clemente/quic-go" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A NewConnectionIDFrame is a NEW_CONNECTION_ID frame +type NewConnectionIDFrame struct { + SequenceNumber uint64 + RetirePriorTo uint64 + ConnectionID protocol.ConnectionID + StatelessResetToken protocol.StatelessResetToken +} + +func parseNewConnectionIDFrame(r *bytes.Reader, _ quic.VersionNumber) (*NewConnectionIDFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + seq, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + ret, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if ret > seq { + //nolint:stylecheck + return nil, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq) + } + connIDLen, err := r.ReadByte() + if err != nil { + return nil, err + } + if connIDLen > protocol.MaxConnIDLen { + return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen) + } + connID, err := protocol.ReadConnectionID(r, int(connIDLen)) + if err != nil { + return nil, err + } + frame := &NewConnectionIDFrame{ + SequenceNumber: seq, + RetirePriorTo: ret, + ConnectionID: connID, + } + if _, err := io.ReadFull(r, frame.StatelessResetToken[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + + return frame, nil +} + +func (f *NewConnectionIDFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + b.WriteByte(0x18) + quicvarint.Write(b, f.SequenceNumber) + quicvarint.Write(b, f.RetirePriorTo) + connIDLen := f.ConnectionID.Len() + if connIDLen > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d", connIDLen) + } + b.WriteByte(uint8(connIDLen)) + b.Write(f.ConnectionID.Bytes()) + b.Write(f.StatelessResetToken[:]) + return nil +} + +// Length of a written frame +func (f *NewConnectionIDFrame) Length(quic.VersionNumber) protocol.ByteCount { + return 1 + protocol.ByteCount(quicvarint.Len(f.SequenceNumber)+quicvarint.Len(f.RetirePriorTo)) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16 +} diff --git a/internal/wire/new_connection_id_frame_test.go b/internal/wire/new_connection_id_frame_test.go new file mode 100644 index 00000000..91bc2e20 --- /dev/null +++ b/internal/wire/new_connection_id_frame_test.go @@ -0,0 +1,104 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("NEW_CONNECTION_ID frame", func() { + Context("when parsing", func() { + It("accepts a sample frame", func() { + data := []byte{0x18} + data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number + data = append(data, encodeVarInt(0xcafe)...) // retire prior to + data = append(data, 10) // connection ID length + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID + data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token + b := bytes.NewReader(data) + frame, err := parseNewConnectionIDFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) + Expect(frame.RetirePriorTo).To(Equal(uint64(0xcafe))) + Expect(frame.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(string(frame.StatelessResetToken[:])).To(Equal("deadbeefdecafbad")) + }) + + It("errors when the Retire Prior To value is larger than the Sequence Number", func() { + data := []byte{0x18} + data = append(data, encodeVarInt(1000)...) // sequence number + data = append(data, encodeVarInt(1001)...) // retire prior to + data = append(data, 3) + data = append(data, []byte{1, 2, 3}...) + data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token + b := bytes.NewReader(data) + _, err := parseNewConnectionIDFrame(b, protocol.Version1) + Expect(err).To(MatchError("Retire Prior To value (1001) larger than Sequence Number (1000)")) + }) + + It("errors when the connection ID has an invalid length", func() { + data := []byte{0x18} + data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number + data = append(data, encodeVarInt(0xcafe)...) // retire prior to + data = append(data, 21) // connection ID length + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}...) // connection ID + data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token + b := bytes.NewReader(data) + _, err := parseNewConnectionIDFrame(b, protocol.Version1) + Expect(err).To(MatchError("invalid connection ID length: 21")) + }) + + It("errors on EOFs", func() { + data := []byte{0x18} + data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number + data = append(data, encodeVarInt(0xcafe1234)...) // retire prior to + data = append(data, 10) // connection ID length + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID + data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token + _, err := parseNewConnectionIDFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseNewConnectionIDFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + token := protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + frame := &NewConnectionIDFrame{ + SequenceNumber: 0x1337, + RetirePriorTo: 0x42, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, + StatelessResetToken: token, + } + b := &bytes.Buffer{} + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + expected := []byte{0x18} + expected = append(expected, encodeVarInt(0x1337)...) + expected = append(expected, encodeVarInt(0x42)...) + expected = append(expected, 6) + expected = append(expected, []byte{1, 2, 3, 4, 5, 6}...) + expected = append(expected, token[:]...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct length", func() { + token := protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + frame := &NewConnectionIDFrame{ + SequenceNumber: 0xdecafbad, + RetirePriorTo: 0xdeadbeefcafe, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + StatelessResetToken: token, + } + b := &bytes.Buffer{} + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) + }) + }) +}) diff --git a/internal/wire/new_token_frame.go b/internal/wire/new_token_frame.go new file mode 100644 index 00000000..a7ff519f --- /dev/null +++ b/internal/wire/new_token_frame.go @@ -0,0 +1,49 @@ +package wire + +import ( + "bytes" + "errors" + "github.com/lucas-clemente/quic-go" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A NewTokenFrame is a NEW_TOKEN frame +type NewTokenFrame struct { + Token []byte +} + +func parseNewTokenFrame(r *bytes.Reader, _ quic.VersionNumber) (*NewTokenFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + tokenLen, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if uint64(r.Len()) < tokenLen { + return nil, io.EOF + } + if tokenLen == 0 { + return nil, errors.New("token must not be empty") + } + token := make([]byte, int(tokenLen)) + if _, err := io.ReadFull(r, token); err != nil { + return nil, err + } + return &NewTokenFrame{Token: token}, nil +} + +func (f *NewTokenFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + b.WriteByte(0x7) + quicvarint.Write(b, uint64(len(f.Token))) + b.Write(f.Token) + return nil +} + +// Length of a written frame +func (f *NewTokenFrame) Length(quic.VersionNumber) protocol.ByteCount { + return 1 + protocol.ByteCount(quicvarint.Len(uint64(len(f.Token)))) + protocol.ByteCount(len(f.Token)) +} diff --git a/internal/wire/new_token_frame_test.go b/internal/wire/new_token_frame_test.go new file mode 100644 index 00000000..c4a6685c --- /dev/null +++ b/internal/wire/new_token_frame_test.go @@ -0,0 +1,66 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("NEW_TOKEN frame", func() { + Context("parsing", func() { + It("accepts a sample frame", func() { + token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + data := []byte{0x7} + data = append(data, encodeVarInt(uint64(len(token)))...) + data = append(data, token...) + b := bytes.NewReader(data) + f, err := parseNewTokenFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(string(f.Token)).To(Equal(token)) + Expect(b.Len()).To(BeZero()) + }) + + It("rejects empty tokens", func() { + data := []byte{0x7} + data = append(data, encodeVarInt(uint64(0))...) + b := bytes.NewReader(data) + _, err := parseNewTokenFrame(b, protocol.VersionWhatever) + Expect(err).To(MatchError("token must not be empty")) + }) + + It("errors on EOFs", func() { + token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit" + data := []byte{0x7} + data = append(data, encodeVarInt(uint64(len(token)))...) + data = append(data, token...) + _, err := parseNewTokenFrame(bytes.NewReader(data), protocol.VersionWhatever) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseNewTokenFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("writing", func() { + It("writes a sample frame", func() { + token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat." + f := &NewTokenFrame{Token: []byte(token)} + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x7} + expected = append(expected, encodeVarInt(uint64(len(token)))...) + expected = append(expected, token...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := &NewTokenFrame{Token: []byte("foobar")} + Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(6) + 6)) + }) + }) +}) diff --git a/internal/wire/path_challenge_frame.go b/internal/wire/path_challenge_frame.go new file mode 100644 index 00000000..ae519a7f --- /dev/null +++ b/internal/wire/path_challenge_frame.go @@ -0,0 +1,39 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + "io" + + "github.com/imroc/req/v3/internal/protocol" +) + +// A PathChallengeFrame is a PATH_CHALLENGE frame +type PathChallengeFrame struct { + Data [8]byte +} + +func parsePathChallengeFrame(r *bytes.Reader, _ quic.VersionNumber) (*PathChallengeFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + frame := &PathChallengeFrame{} + if _, err := io.ReadFull(r, frame.Data[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + return frame, nil +} + +func (f *PathChallengeFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + b.WriteByte(0x1a) + b.Write(f.Data[:]) + return nil +} + +// Length of a written frame +func (f *PathChallengeFrame) Length(_ quic.VersionNumber) protocol.ByteCount { + return 1 + 8 +} diff --git a/internal/wire/path_challenge_frame_test.go b/internal/wire/path_challenge_frame_test.go new file mode 100644 index 00000000..52d08d90 --- /dev/null +++ b/internal/wire/path_challenge_frame_test.go @@ -0,0 +1,48 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("PATH_CHALLENGE frame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{0x1a, 1, 2, 3, 4, 5, 6, 7, 8}) + f, err := parsePathChallengeFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeZero()) + Expect(f.Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) + }) + + It("errors on EOFs", func() { + data := []byte{0x1a, 1, 2, 3, 4, 5, 6, 7, 8} + _, err := parsePathChallengeFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parsePathChallengeFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := PathChallengeFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} + err := frame.Write(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x1a, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) + }) + + It("has the correct min length", func() { + frame := PathChallengeFrame{} + Expect(frame.Length(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(9))) + }) + }) +}) diff --git a/internal/wire/path_response_frame.go b/internal/wire/path_response_frame.go new file mode 100644 index 00000000..d8dbebdc --- /dev/null +++ b/internal/wire/path_response_frame.go @@ -0,0 +1,39 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + "io" + + "github.com/imroc/req/v3/internal/protocol" +) + +// A PathResponseFrame is a PATH_RESPONSE frame +type PathResponseFrame struct { + Data [8]byte +} + +func parsePathResponseFrame(r *bytes.Reader, _ quic.VersionNumber) (*PathResponseFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + frame := &PathResponseFrame{} + if _, err := io.ReadFull(r, frame.Data[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + return frame, nil +} + +func (f *PathResponseFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + b.WriteByte(0x1b) + b.Write(f.Data[:]) + return nil +} + +// Length of a written frame +func (f *PathResponseFrame) Length(_ quic.VersionNumber) protocol.ByteCount { + return 1 + 8 +} diff --git a/internal/wire/path_response_frame_test.go b/internal/wire/path_response_frame_test.go new file mode 100644 index 00000000..872d1c59 --- /dev/null +++ b/internal/wire/path_response_frame_test.go @@ -0,0 +1,47 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("PATH_RESPONSE frame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{0x1b, 1, 2, 3, 4, 5, 6, 7, 8}) + f, err := parsePathResponseFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeZero()) + Expect(f.Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) + }) + + It("errors on EOFs", func() { + data := []byte{0x1b, 1, 2, 3, 4, 5, 6, 7, 8} + _, err := parsePathResponseFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parsePathResponseFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := PathResponseFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} + err := frame.Write(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x1b, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) + }) + + It("has the correct min length", func() { + frame := PathResponseFrame{} + Expect(frame.Length(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(9))) + }) + }) +}) diff --git a/internal/wire/ping_frame.go b/internal/wire/ping_frame.go new file mode 100644 index 00000000..38b47c2f --- /dev/null +++ b/internal/wire/ping_frame.go @@ -0,0 +1,28 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" +) + +// A PingFrame is a PING frame +type PingFrame struct{} + +func parsePingFrame(r *bytes.Reader, _ quic.VersionNumber) (*PingFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + return &PingFrame{}, nil +} + +func (f *PingFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { + b.WriteByte(0x1) + return nil +} + +// Length of a written frame +func (f *PingFrame) Length(version quic.VersionNumber) protocol.ByteCount { + return 1 +} diff --git a/internal/wire/ping_frame_test.go b/internal/wire/ping_frame_test.go new file mode 100644 index 00000000..3664731a --- /dev/null +++ b/internal/wire/ping_frame_test.go @@ -0,0 +1,39 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("PingFrame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{0x1}) + _, err := parsePingFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + _, err := parsePingFrame(bytes.NewReader(nil), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := PingFrame{} + frame.Write(b, protocol.VersionWhatever) + Expect(b.Bytes()).To(Equal([]byte{0x1})) + }) + + It("has the correct min length", func() { + frame := PingFrame{} + Expect(frame.Length(0)).To(Equal(protocol.ByteCount(1))) + }) + }) +}) diff --git a/internal/wire/pool.go b/internal/wire/pool.go new file mode 100644 index 00000000..aba32137 --- /dev/null +++ b/internal/wire/pool.go @@ -0,0 +1,33 @@ +package wire + +import ( + "sync" + + "github.com/imroc/req/v3/internal/protocol" +) + +var pool sync.Pool + +func init() { + pool.New = func() interface{} { + return &StreamFrame{ + Data: make([]byte, 0, protocol.MaxPacketBufferSize), + fromPool: true, + } + } +} + +func GetStreamFrame() *StreamFrame { + f := pool.Get().(*StreamFrame) + return f +} + +func putStreamFrame(f *StreamFrame) { + if !f.fromPool { + return + } + if protocol.ByteCount(cap(f.Data)) != protocol.MaxPacketBufferSize { + panic("wire.PutStreamFrame called with packet of wrong size!") + } + pool.Put(f) +} diff --git a/internal/wire/pool_test.go b/internal/wire/pool_test.go new file mode 100644 index 00000000..b55e493b --- /dev/null +++ b/internal/wire/pool_test.go @@ -0,0 +1,24 @@ +package wire + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Pool", func() { + It("gets and puts STREAM frames", func() { + f := GetStreamFrame() + putStreamFrame(f) + }) + + It("panics when putting a STREAM frame with a wrong capacity", func() { + f := GetStreamFrame() + f.Data = []byte("foobar") + Expect(func() { putStreamFrame(f) }).To(Panic()) + }) + + It("accepts STREAM frames not from the buffer, but ignores them", func() { + f := &StreamFrame{Data: []byte("foobar")} + putStreamFrame(f) + }) +}) diff --git a/internal/wire/reset_stream_frame.go b/internal/wire/reset_stream_frame.go new file mode 100644 index 00000000..190ddda5 --- /dev/null +++ b/internal/wire/reset_stream_frame.go @@ -0,0 +1,59 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A ResetStreamFrame is a RESET_STREAM frame in QUIC +type ResetStreamFrame struct { + StreamID protocol.StreamID + ErrorCode qerr.StreamErrorCode + FinalSize protocol.ByteCount +} + +func parseResetStreamFrame(r *bytes.Reader, _ quic.VersionNumber) (*ResetStreamFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + + var streamID protocol.StreamID + var byteOffset protocol.ByteCount + sid, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + streamID = protocol.StreamID(sid) + errorCode, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + bo, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + byteOffset = protocol.ByteCount(bo) + + return &ResetStreamFrame{ + StreamID: streamID, + ErrorCode: qerr.StreamErrorCode(errorCode), + FinalSize: byteOffset, + }, nil +} + +func (f *ResetStreamFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + b.WriteByte(0x4) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.ErrorCode)) + quicvarint.Write(b, uint64(f.FinalSize)) + return nil +} + +// Length of a written frame +func (f *ResetStreamFrame) Length(version quic.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize)) +} diff --git a/internal/wire/reset_stream_frame_test.go b/internal/wire/reset_stream_frame_test.go new file mode 100644 index 00000000..e4b008d5 --- /dev/null +++ b/internal/wire/reset_stream_frame_test.go @@ -0,0 +1,70 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("RESET_STREAM frame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + data := []byte{0x4} + data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID + data = append(data, encodeVarInt(0x1337)...) // error code + data = append(data, encodeVarInt(0x987654321)...) // byte offset + b := bytes.NewReader(data) + frame, err := parseResetStreamFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + Expect(frame.FinalSize).To(Equal(protocol.ByteCount(0x987654321))) + Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) + }) + + It("errors on EOFs", func() { + data := []byte{0x4} + data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID + data = append(data, encodeVarInt(0x1337)...) // error code + data = append(data, encodeVarInt(0x987654321)...) // byte offset + _, err := parseResetStreamFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseResetStreamFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + frame := ResetStreamFrame{ + StreamID: 0x1337, + FinalSize: 0x11223344decafbad, + ErrorCode: 0xcafe, + } + b := &bytes.Buffer{} + err := frame.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x4} + expected = append(expected, encodeVarInt(0x1337)...) + expected = append(expected, encodeVarInt(0xcafe)...) + expected = append(expected, encodeVarInt(0x11223344decafbad)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + rst := ResetStreamFrame{ + StreamID: 0x1337, + FinalSize: 0x1234567, + ErrorCode: 0xde, + } + expectedLen := 1 + quicvarint.Len(0x1337) + quicvarint.Len(0x1234567) + 2 + Expect(rst.Length(protocol.Version1)).To(Equal(expectedLen)) + }) + }) +}) diff --git a/internal/wire/retire_connection_id_frame.go b/internal/wire/retire_connection_id_frame.go new file mode 100644 index 00000000..1d8d1dbd --- /dev/null +++ b/internal/wire/retire_connection_id_frame.go @@ -0,0 +1,37 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame +type RetireConnectionIDFrame struct { + SequenceNumber uint64 +} + +func parseRetireConnectionIDFrame(r *bytes.Reader, _ quic.VersionNumber) (*RetireConnectionIDFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + seq, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + return &RetireConnectionIDFrame{SequenceNumber: seq}, nil +} + +func (f *RetireConnectionIDFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + b.WriteByte(0x19) + quicvarint.Write(b, f.SequenceNumber) + return nil +} + +// Length of a written frame +func (f *RetireConnectionIDFrame) Length(quic.VersionNumber) protocol.ByteCount { + return 1 + protocol.ByteCount(quicvarint.Len(f.SequenceNumber)) +} diff --git a/internal/wire/retire_connection_id_frame_test.go b/internal/wire/retire_connection_id_frame_test.go new file mode 100644 index 00000000..b67f733d --- /dev/null +++ b/internal/wire/retire_connection_id_frame_test.go @@ -0,0 +1,53 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("NEW_CONNECTION_ID frame", func() { + Context("when parsing", func() { + It("accepts a sample frame", func() { + data := []byte{0x19} + data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number + b := bytes.NewReader(data) + frame, err := parseRetireConnectionIDFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) + }) + + It("errors on EOFs", func() { + data := []byte{0x18} + data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number + _, err := parseRetireConnectionIDFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseRetireConnectionIDFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + frame := &RetireConnectionIDFrame{SequenceNumber: 0x1337} + b := &bytes.Buffer{} + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + expected := []byte{0x19} + expected = append(expected, encodeVarInt(0x1337)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct length", func() { + frame := &RetireConnectionIDFrame{SequenceNumber: 0xdecafbad} + b := &bytes.Buffer{} + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) + }) + }) +}) diff --git a/internal/wire/stop_sending_frame.go b/internal/wire/stop_sending_frame.go new file mode 100644 index 00000000..f9d44db9 --- /dev/null +++ b/internal/wire/stop_sending_frame.go @@ -0,0 +1,49 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A StopSendingFrame is a STOP_SENDING frame +type StopSendingFrame struct { + StreamID protocol.StreamID + ErrorCode qerr.StreamErrorCode +} + +// parseStopSendingFrame parses a STOP_SENDING frame +func parseStopSendingFrame(r *bytes.Reader, _ quic.VersionNumber) (*StopSendingFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + streamID, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + errorCode, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + return &StopSendingFrame{ + StreamID: protocol.StreamID(streamID), + ErrorCode: qerr.StreamErrorCode(errorCode), + }, nil +} + +// Length of a written frame +func (f *StopSendingFrame) Length(_ quic.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) +} + +func (f *StopSendingFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + b.WriteByte(0x5) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.ErrorCode)) + return nil +} diff --git a/internal/wire/stop_sending_frame_test.go b/internal/wire/stop_sending_frame_test.go new file mode 100644 index 00000000..9e709b32 --- /dev/null +++ b/internal/wire/stop_sending_frame_test.go @@ -0,0 +1,63 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STOP_SENDING frame", func() { + Context("when parsing", func() { + It("parses a sample frame", func() { + data := []byte{0x5} + data = append(data, encodeVarInt(0xdecafbad)...) // stream ID + data = append(data, encodeVarInt(0x1337)...) // error code + b := bytes.NewReader(data) + frame, err := parseStopSendingFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) + Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x5} + data = append(data, encodeVarInt(0xdecafbad)...) // stream ID + data = append(data, encodeVarInt(0x123456)...) // error code + _, err := parseStopSendingFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseStopSendingFrame(bytes.NewReader(data[:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("when writing", func() { + It("writes", func() { + frame := &StopSendingFrame{ + StreamID: 0xdeadbeefcafe, + ErrorCode: 0xdecafbad, + } + buf := &bytes.Buffer{} + Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) + expected := []byte{0x5} + expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) + expected = append(expected, encodeVarInt(0xdecafbad)...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := &StopSendingFrame{ + StreamID: 0xdeadbeef, + ErrorCode: 0x1234567, + } + Expect(frame.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0xdeadbeef) + quicvarint.Len(0x1234567))) + }) + }) +}) diff --git a/internal/wire/stream_data_blocked_frame.go b/internal/wire/stream_data_blocked_frame.go new file mode 100644 index 00000000..011d14c7 --- /dev/null +++ b/internal/wire/stream_data_blocked_frame.go @@ -0,0 +1,47 @@ +package wire + +import ( + "bytes" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame +type StreamDataBlockedFrame struct { + StreamID protocol.StreamID + MaximumStreamData protocol.ByteCount +} + +func parseStreamDataBlockedFrame(r *bytes.Reader, _ quic.VersionNumber) (*StreamDataBlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + sid, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + return &StreamDataBlockedFrame{ + StreamID: protocol.StreamID(sid), + MaximumStreamData: protocol.ByteCount(offset), + }, nil +} + +func (f *StreamDataBlockedFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { + b.WriteByte(0x15) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.MaximumStreamData)) + return nil +} + +// Length of a written frame +func (f *StreamDataBlockedFrame) Length(version quic.VersionNumber) protocol.ByteCount { + return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.MaximumStreamData))) +} diff --git a/internal/wire/stream_data_blocked_frame_test.go b/internal/wire/stream_data_blocked_frame_test.go new file mode 100644 index 00000000..5306edcb --- /dev/null +++ b/internal/wire/stream_data_blocked_frame_test.go @@ -0,0 +1,63 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAM_DATA_BLOCKED frame", func() { + Context("parsing", func() { + It("accepts sample frame", func() { + data := []byte{0x15} + data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID + data = append(data, encodeVarInt(0xdecafbad)...) // offset + b := bytes.NewReader(data) + frame, err := parseStreamDataBlockedFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0xdecafbad))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x15} + data = append(data, encodeVarInt(0xdeadbeef)...) + data = append(data, encodeVarInt(0xc0010ff)...) + _, err := parseStreamDataBlockedFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseStreamDataBlockedFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("writing", func() { + It("has proper min length", func() { + f := &StreamDataBlockedFrame{ + StreamID: 0x1337, + MaximumStreamData: 0xdeadbeef, + } + Expect(f.Length(0)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0xdeadbeef))) + }) + + It("writes a sample frame", func() { + b := &bytes.Buffer{} + f := &StreamDataBlockedFrame{ + StreamID: 0xdecafbad, + MaximumStreamData: 0x1337, + } + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x15} + expected = append(expected, encodeVarInt(uint64(f.StreamID))...) + expected = append(expected, encodeVarInt(uint64(f.MaximumStreamData))...) + Expect(b.Bytes()).To(Equal(expected)) + }) + }) +}) diff --git a/internal/wire/stream_frame.go b/internal/wire/stream_frame.go new file mode 100644 index 00000000..ed21d388 --- /dev/null +++ b/internal/wire/stream_frame.go @@ -0,0 +1,190 @@ +package wire + +import ( + "bytes" + "errors" + "github.com/lucas-clemente/quic-go" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A StreamFrame of QUIC +type StreamFrame struct { + StreamID protocol.StreamID + Offset protocol.ByteCount + Data []byte + Fin bool + DataLenPresent bool + + fromPool bool +} + +func parseStreamFrame(r *bytes.Reader, _ quic.VersionNumber) (*StreamFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + hasOffset := typeByte&0x4 > 0 + fin := typeByte&0x1 > 0 + hasDataLen := typeByte&0x2 > 0 + + streamID, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + var offset uint64 + if hasOffset { + offset, err = quicvarint.Read(r) + if err != nil { + return nil, err + } + } + + var dataLen uint64 + if hasDataLen { + var err error + dataLen, err = quicvarint.Read(r) + if err != nil { + return nil, err + } + } else { + // The rest of the packet is data + dataLen = uint64(r.Len()) + } + + var frame *StreamFrame + if dataLen < protocol.MinStreamFrameBufferSize { + frame = &StreamFrame{Data: make([]byte, dataLen)} + } else { + frame = GetStreamFrame() + // The STREAM frame can't be larger than the StreamFrame we obtained from the buffer, + // since those StreamFrames have a buffer length of the maximum packet size. + if dataLen > uint64(cap(frame.Data)) { + return nil, io.EOF + } + frame.Data = frame.Data[:dataLen] + } + + frame.StreamID = protocol.StreamID(streamID) + frame.Offset = protocol.ByteCount(offset) + frame.Fin = fin + frame.DataLenPresent = hasDataLen + + if dataLen != 0 { + if _, err := io.ReadFull(r, frame.Data); err != nil { + return nil, err + } + } + if frame.Offset+frame.DataLen() > protocol.MaxByteCount { + return nil, errors.New("stream data overflows maximum offset") + } + return frame, nil +} + +// Write writes a STREAM frame +func (f *StreamFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { + if len(f.Data) == 0 && !f.Fin { + return errors.New("StreamFrame: attempting to write empty frame without FIN") + } + + typeByte := byte(0x8) + if f.Fin { + typeByte ^= 0x1 + } + hasOffset := f.Offset != 0 + if f.DataLenPresent { + typeByte ^= 0x2 + } + if hasOffset { + typeByte ^= 0x4 + } + b.WriteByte(typeByte) + quicvarint.Write(b, uint64(f.StreamID)) + if hasOffset { + quicvarint.Write(b, uint64(f.Offset)) + } + if f.DataLenPresent { + quicvarint.Write(b, uint64(f.DataLen())) + } + b.Write(f.Data) + return nil +} + +// Length returns the total length of the STREAM frame +func (f *StreamFrame) Length(version quic.VersionNumber) protocol.ByteCount { + length := 1 + quicvarint.Len(uint64(f.StreamID)) + if f.Offset != 0 { + length += quicvarint.Len(uint64(f.Offset)) + } + if f.DataLenPresent { + length += quicvarint.Len(uint64(f.DataLen())) + } + return protocol.ByteCount(length) + f.DataLen() +} + +// DataLen gives the length of data in bytes +func (f *StreamFrame) DataLen() protocol.ByteCount { + return protocol.ByteCount(len(f.Data)) +} + +// MaxDataLen returns the maximum data length +// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data). +func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version quic.VersionNumber) protocol.ByteCount { + headerLen := 1 + quicvarint.Len(uint64(f.StreamID)) + if f.Offset != 0 { + headerLen += quicvarint.Len(uint64(f.Offset)) + } + if f.DataLenPresent { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen++ + } + if protocol.ByteCount(headerLen) > maxSize { + return 0 + } + maxDataLen := maxSize - protocol.ByteCount(headerLen) + if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { + maxDataLen-- + } + return maxDataLen +} + +// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. +// It returns if the frame was actually split. +// The frame might not be split if: +// * the size is large enough to fit the whole frame +// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. +func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version quic.VersionNumber) (*StreamFrame, bool /* was splitting required */) { + if maxSize >= f.Length(version) { + return nil, false + } + + n := f.MaxDataLen(maxSize, version) + if n == 0 { + return nil, true + } + + new := GetStreamFrame() + new.StreamID = f.StreamID + new.Offset = f.Offset + new.Fin = false + new.DataLenPresent = f.DataLenPresent + + // swap the data slices + new.Data, f.Data = f.Data, new.Data + new.fromPool, f.fromPool = f.fromPool, new.fromPool + + f.Data = f.Data[:protocol.ByteCount(len(new.Data))-n] + copy(f.Data, new.Data[n:]) + new.Data = new.Data[:n] + f.Offset += n + + return new, true +} + +func (f *StreamFrame) PutBack() { + putStreamFrame(f) +} diff --git a/internal/wire/stream_frame_test.go b/internal/wire/stream_frame_test.go new file mode 100644 index 00000000..9e49de8a --- /dev/null +++ b/internal/wire/stream_frame_test.go @@ -0,0 +1,443 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAM frame", func() { + Context("when parsing", func() { + It("parses a frame with OFF bit", func() { + data := []byte{0x8 ^ 0x4} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(0xdecafbad)...) // offset + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(frame.Fin).To(BeFalse()) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) + Expect(r.Len()).To(BeZero()) + }) + + It("respects the LEN when parsing the frame", func() { + data := []byte{0x8 ^ 0x2} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(4)...) // data length + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) + Expect(frame.Data).To(Equal([]byte("foob"))) + Expect(frame.Fin).To(BeFalse()) + Expect(frame.Offset).To(BeZero()) + Expect(r.Len()).To(Equal(2)) + }) + + It("parses a frame with FIN bit", func() { + data := []byte{0x8 ^ 0x1} + data = append(data, encodeVarInt(9)...) // stream ID + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(9))) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(frame.Fin).To(BeTrue()) + Expect(frame.Offset).To(BeZero()) + Expect(r.Len()).To(BeZero()) + }) + + It("allows empty frames", func() { + data := []byte{0x8 ^ 0x4} + data = append(data, encodeVarInt(0x1337)...) // stream ID + data = append(data, encodeVarInt(0x12345)...) // offset + r := bytes.NewReader(data) + f, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(f.StreamID).To(Equal(protocol.StreamID(0x1337))) + Expect(f.Offset).To(Equal(protocol.ByteCount(0x12345))) + Expect(f.Data).To(BeEmpty()) + Expect(f.Fin).To(BeFalse()) + }) + + It("rejects frames that overflow the maximum offset", func() { + data := []byte{0x8 ^ 0x4} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(uint64(protocol.MaxByteCount-5))...) // offset + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + _, err := parseStreamFrame(r, protocol.Version1) + Expect(err).To(MatchError("stream data overflows maximum offset")) + }) + + It("rejects frames that claim to be longer than the packet size", func() { + data := []byte{0x8 ^ 0x2} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(uint64(protocol.MaxPacketBufferSize)+1)...) // data length + data = append(data, make([]byte, protocol.MaxPacketBufferSize+1)...) + r := bytes.NewReader(data) + _, err := parseStreamFrame(r, protocol.Version1) + Expect(err).To(Equal(io.EOF)) + }) + + It("errors on EOFs", func() { + data := []byte{0x8 ^ 0x4 ^ 0x2} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(0xdecafbad)...) // offset + data = append(data, encodeVarInt(6)...) // data length + data = append(data, []byte("foobar")...) + _, err := parseStreamFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseStreamFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("using the buffer", func() { + It("uses the buffer for long STREAM frames", func() { + data := []byte{0x8} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)...) + r := bytes.NewReader(data) + frame, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) + Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize))) + Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize)) + Expect(frame.Fin).To(BeFalse()) + Expect(frame.fromPool).To(BeTrue()) + Expect(r.Len()).To(BeZero()) + Expect(frame.PutBack).ToNot(Panic()) + }) + + It("doesn't use the buffer for short STREAM frames", func() { + data := []byte{0x8} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)...) + r := bytes.NewReader(data) + frame, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) + Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1))) + Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize - 1)) + Expect(frame.Fin).To(BeFalse()) + Expect(frame.fromPool).To(BeFalse()) + Expect(r.Len()).To(BeZero()) + Expect(frame.PutBack).ToNot(Panic()) + }) + }) + + Context("when writing", func() { + It("writes a frame without offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Data: []byte("foobar"), + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x8} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x123456, + Data: []byte("foobar"), + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x8 ^ 0x4} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(0x123456)...) // offset + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with FIN bit", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x123456, + Fin: true, + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x8 ^ 0x4 ^ 0x1} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(0x123456)...) // offset + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with data length", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Data: []byte("foobar"), + DataLenPresent: true, + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x8 ^ 0x2} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(6)...) // data length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with data length and offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Data: []byte("foobar"), + DataLenPresent: true, + Offset: 0x123456, + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x8 ^ 0x4 ^ 0x2} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(0x123456)...) // offset + expected = append(expected, encodeVarInt(6)...) // data length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("refuses to write an empty frame without FIN", func() { + f := &StreamFrame{ + StreamID: 0x42, + Offset: 0x1337, + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).To(MatchError("StreamFrame: attempting to write empty frame without FIN")) + }) + }) + + Context("length", func() { + It("has the right length for a frame without offset and data length", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Data: []byte("foobar"), + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + 6)) + }) + + It("has the right length for a frame with offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x42, + Data: []byte("foobar"), + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x42) + 6)) + }) + + It("has the right length for a frame with data length", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x1234567, + DataLenPresent: true, + Data: []byte("foobar"), + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x1234567) + quicvarint.Len(6) + 6)) + }) + }) + + Context("max data length", func() { + const maxSize = 3000 + + It("always returns a data length such that the resulting frame has the right size, if data length is not present", func() { + data := make([]byte, maxSize) + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0xdeadbeef, + } + b := &bytes.Buffer{} + for i := 1; i < 3000; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) + if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written + // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(Equal(i)) + } + }) + + It("always returns a data length such that the resulting frame has the right size, if data length is present", func() { + data := make([]byte, maxSize) + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0xdeadbeef, + DataLenPresent: true, + } + b := &bytes.Buffer{} + var frameOneByteTooSmallCounter int + for i := 1; i < 3000; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) + if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written + // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + // There's *one* pathological case, where a data length of x can be encoded into 1 byte + // but a data lengths of x+1 needs 2 bytes + // In that case, it's impossible to create a STREAM frame of the desired size + if b.Len() == i-1 { + frameOneByteTooSmallCounter++ + continue + } + Expect(b.Len()).To(Equal(i)) + } + Expect(frameOneByteTooSmallCounter).To(Equal(1)) + }) + }) + + Context("splitting", func() { + It("doesn't split if the frame is short enough", func() { + f := &StreamFrame{ + StreamID: 0x1337, + DataLenPresent: true, + Offset: 0xdeadbeef, + Data: make([]byte, 100), + } + frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) + Expect(needsSplit).To(BeFalse()) + Expect(frame).To(BeNil()) + Expect(f.DataLen()).To(BeEquivalentTo(100)) + frame, needsSplit = f.MaybeSplitOffFrame(f.Length(protocol.Version1)-1, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(frame.DataLen()).To(BeEquivalentTo(99)) + f.PutBack() + }) + + It("keeps the data len", func() { + f := &StreamFrame{ + StreamID: 0x1337, + DataLenPresent: true, + Data: make([]byte, 100), + } + frame, needsSplit := f.MaybeSplitOffFrame(66, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + Expect(f.DataLenPresent).To(BeTrue()) + Expect(frame.DataLenPresent).To(BeTrue()) + }) + + It("adjusts the offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x100, + Data: []byte("foobar"), + } + frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1)-3, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0x100))) + Expect(frame.Data).To(Equal([]byte("foo"))) + Expect(f.Offset).To(Equal(protocol.ByteCount(0x100 + 3))) + Expect(f.Data).To(Equal([]byte("bar"))) + }) + + It("preserves the FIN bit", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Fin: true, + Offset: 0xdeadbeef, + Data: make([]byte, 100), + } + frame, needsSplit := f.MaybeSplitOffFrame(50, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + Expect(frame.Offset).To(BeNumerically("<", f.Offset)) + Expect(f.Fin).To(BeTrue()) + Expect(frame.Fin).To(BeFalse()) + }) + + It("produces frames of the correct length, without data len", func() { + const size = 1000 + f := &StreamFrame{ + StreamID: 0xdecafbad, + Offset: 0x1234, + Data: []byte{0}, + } + minFrameSize := f.Length(protocol.Version1) + for i := protocol.ByteCount(0); i < minFrameSize; i++ { + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(f).To(BeNil()) + } + for i := minFrameSize; i < size; i++ { + f.fromPool = false + f.Data = make([]byte, size) + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(f.Length(protocol.Version1)).To(Equal(i)) + } + }) + + It("produces frames of the correct length, with data len", func() { + const size = 1000 + f := &StreamFrame{ + StreamID: 0xdecafbad, + Offset: 0x1234, + DataLenPresent: true, + Data: []byte{0}, + } + minFrameSize := f.Length(protocol.Version1) + for i := protocol.ByteCount(0); i < minFrameSize; i++ { + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(f).To(BeNil()) + } + var frameOneByteTooSmallCounter int + for i := minFrameSize; i < size; i++ { + f.fromPool = false + f.Data = make([]byte, size) + newFrame, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + // There's *one* pathological case, where a data length of x can be encoded into 1 byte + // but a data lengths of x+1 needs 2 bytes + // In that case, it's impossible to create a STREAM frame of the desired size + if newFrame.Length(protocol.Version1) == i-1 { + frameOneByteTooSmallCounter++ + continue + } + Expect(newFrame.Length(protocol.Version1)).To(Equal(i)) + } + Expect(frameOneByteTooSmallCounter).To(Equal(1)) + }) + }) +}) diff --git a/internal/wire/streams_blocked_frame.go b/internal/wire/streams_blocked_frame.go new file mode 100644 index 00000000..5e18bb34 --- /dev/null +++ b/internal/wire/streams_blocked_frame.go @@ -0,0 +1,56 @@ +package wire + +import ( + "bytes" + "fmt" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" +) + +// A StreamsBlockedFrame is a STREAMS_BLOCKED frame +type StreamsBlockedFrame struct { + Type protocol.StreamType + StreamLimit protocol.StreamNum +} + +func parseStreamsBlockedFrame(r *bytes.Reader, _ quic.VersionNumber) (*StreamsBlockedFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &StreamsBlockedFrame{} + switch typeByte { + case 0x16: + f.Type = protocol.StreamTypeBidi + case 0x17: + f.Type = protocol.StreamTypeUni + } + streamLimit, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.StreamLimit = protocol.StreamNum(streamLimit) + if f.StreamLimit > protocol.MaxStreamCount { + return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit) + } + return f, nil +} + +func (f *StreamsBlockedFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { + switch f.Type { + case protocol.StreamTypeBidi: + b.WriteByte(0x16) + case protocol.StreamTypeUni: + b.WriteByte(0x17) + } + quicvarint.Write(b, uint64(f.StreamLimit)) + return nil +} + +// Length of a written frame +func (f *StreamsBlockedFrame) Length(_ quic.VersionNumber) protocol.ByteCount { + return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamLimit))) +} diff --git a/internal/wire/streams_blocked_frame_test.go b/internal/wire/streams_blocked_frame_test.go new file mode 100644 index 00000000..eb5f94bc --- /dev/null +++ b/internal/wire/streams_blocked_frame_test.go @@ -0,0 +1,108 @@ +package wire + +import ( + "bytes" + "fmt" + "io" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAMS_BLOCKED frame", func() { + Context("parsing", func() { + It("accepts a frame for bidirectional streams", func() { + expected := []byte{0x16} + expected = append(expected, encodeVarInt(0x1337)...) + b := bytes.NewReader(expected) + f, err := parseStreamsBlockedFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Type).To(Equal(protocol.StreamTypeBidi)) + Expect(f.StreamLimit).To(BeEquivalentTo(0x1337)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts a frame for unidirectional streams", func() { + expected := []byte{0x17} + expected = append(expected, encodeVarInt(0x7331)...) + b := bytes.NewReader(expected) + f, err := parseStreamsBlockedFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Type).To(Equal(protocol.StreamTypeUni)) + Expect(f.StreamLimit).To(BeEquivalentTo(0x7331)) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x16} + data = append(data, encodeVarInt(0x12345678)...) + _, err := parseStreamsBlockedFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + for i := range data { + _, err := parseStreamsBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + + for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { + streamType := t + + It("accepts a frame containing the maximum stream count", func() { + f := &StreamsBlockedFrame{ + Type: streamType, + StreamLimit: protocol.MaxStreamCount, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + frame, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when receiving a too large stream count", func() { + f := &StreamsBlockedFrame{ + Type: streamType, + StreamLimit: protocol.MaxStreamCount + 1, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + _, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) + }) + } + }) + + Context("writing", func() { + It("writes a frame for bidirectional streams", func() { + b := &bytes.Buffer{} + f := StreamsBlockedFrame{ + Type: protocol.StreamTypeBidi, + StreamLimit: 0xdeadbeefcafe, + } + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x16} + expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame for unidirectional streams", func() { + b := &bytes.Buffer{} + f := StreamsBlockedFrame{ + Type: protocol.StreamTypeUni, + StreamLimit: 0xdeadbeefcafe, + } + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x17} + expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := StreamsBlockedFrame{StreamLimit: 0x123456} + Expect(frame.Length(0)).To(Equal(protocol.ByteCount(1) + quicvarint.Len(0x123456))) + }) + }) +}) diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go new file mode 100644 index 00000000..396ed50f --- /dev/null +++ b/internal/wire/transport_parameter_test.go @@ -0,0 +1,612 @@ +package wire + +import ( + "bytes" + "fmt" + "math" + "math/rand" + "net" + "time" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Transport Parameters", func() { + getRandomValueUpTo := func(max int64) uint64 { + maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} + m := maxVals[int(rand.Int31n(4))] + if m > max { + m = max + } + return uint64(rand.Int63n(m)) + } + + getRandomValue := func() uint64 { + return getRandomValueUpTo(math.MaxInt64) + } + + BeforeEach(func() { + rand.Seed(GinkgoRandomSeed()) + }) + + addInitialSourceConnectionID := func(b *bytes.Buffer) { + quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) + quicvarint.Write(b, 6) + b.Write([]byte("foobar")) + } + + It("has a string representation", func() { + p := &TransportParameters{ + InitialMaxStreamDataBidiLocal: 1234, + InitialMaxStreamDataBidiRemote: 2345, + InitialMaxStreamDataUni: 3456, + InitialMaxData: 4567, + MaxBidiStreamNum: 1337, + MaxUniStreamNum: 7331, + MaxIdleTimeout: 42 * time.Second, + OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + AckDelayExponent: 14, + MaxAckDelay: 37 * time.Millisecond, + StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, + ActiveConnectionIDLimit: 123, + MaxDatagramFrameSize: 876, + } + Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: deadbeef, InitialSourceConnectionID: decafbad, RetrySourceConnectionID: deadc0de, InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37ms, ActiveConnectionIDLimit: 123, StatelessResetToken: 0x112233445566778899aabbccddeeff00, MaxDatagramFrameSize: 876}")) + }) + + It("has a string representation, if there's no stateless reset token, no Retry source connection id and no datagram support", func() { + p := &TransportParameters{ + InitialMaxStreamDataBidiLocal: 1234, + InitialMaxStreamDataBidiRemote: 2345, + InitialMaxStreamDataUni: 3456, + InitialMaxData: 4567, + MaxBidiStreamNum: 1337, + MaxUniStreamNum: 7331, + MaxIdleTimeout: 42 * time.Second, + OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + InitialSourceConnectionID: protocol.ConnectionID{}, + AckDelayExponent: 14, + MaxAckDelay: 37 * time.Second, + ActiveConnectionIDLimit: 89, + MaxDatagramFrameSize: protocol.InvalidByteCount, + } + Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: deadbeef, InitialSourceConnectionID: (empty), InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37s, ActiveConnectionIDLimit: 89}")) + }) + + It("marshals and unmarshals", func() { + var token protocol.StatelessResetToken + rand.Read(token[:]) + params := &TransportParameters{ + InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), + InitialMaxData: protocol.ByteCount(getRandomValue()), + MaxIdleTimeout: 0xcafe * time.Second, + MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + DisableActiveMigration: true, + StatelessResetToken: &token, + OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + AckDelayExponent: 13, + MaxAckDelay: 42 * time.Millisecond, + ActiveConnectionIDLimit: getRandomValue(), + MaxDatagramFrameSize: protocol.ByteCount(getRandomValue()), + } + data := params.Marshal(protocol.PerspectiveServer) + + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) + Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) + Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) + Expect(p.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) + Expect(p.InitialMaxData).To(Equal(params.InitialMaxData)) + Expect(p.MaxUniStreamNum).To(Equal(params.MaxUniStreamNum)) + Expect(p.MaxBidiStreamNum).To(Equal(params.MaxBidiStreamNum)) + Expect(p.MaxIdleTimeout).To(Equal(params.MaxIdleTimeout)) + Expect(p.DisableActiveMigration).To(Equal(params.DisableActiveMigration)) + Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken)) + Expect(p.OriginalDestinationConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) + Expect(p.InitialSourceConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + Expect(p.RetrySourceConnectionID).To(Equal(&protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) + Expect(p.AckDelayExponent).To(Equal(uint8(13))) + Expect(p.MaxAckDelay).To(Equal(42 * time.Millisecond)) + Expect(p.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) + Expect(p.MaxDatagramFrameSize).To(Equal(params.MaxDatagramFrameSize)) + }) + + It("doesn't marshal a retry_source_connection_id, if no Retry was performed", func() { + data := (&TransportParameters{ + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) + Expect(p.RetrySourceConnectionID).To(BeNil()) + }) + + It("marshals a zero-length retry_source_connection_id", func() { + data := (&TransportParameters{ + RetrySourceConnectionID: &protocol.ConnectionID{}, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) + Expect(p.RetrySourceConnectionID).ToNot(BeNil()) + Expect(p.RetrySourceConnectionID.Len()).To(BeZero()) + }) + + It("errors when the stateless_reset_token has the wrong length", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(statelessResetTokenParameterID)) + quicvarint.Write(b, 15) + b.Write(make([]byte, 15)) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "wrong length for stateless_reset_token: 15 (expected 16)", + })) + }) + + It("errors when the max_packet_size is too small", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(maxUDPPayloadSizeParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(1199))) + quicvarint.Write(b, 1199) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for max_packet_size: 1199 (minimum 1200)", + })) + }) + + It("errors when disable_active_migration has content", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) + quicvarint.Write(b, 6) + b.Write([]byte("foobar")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "wrong length for disable_active_migration: 6 (expected empty)", + })) + }) + + It("errors when the server doesn't set the original_destination_connection_id", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(statelessResetTokenParameterID)) + quicvarint.Write(b, 16) + b.Write(make([]byte, 16)) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "missing original_destination_connection_id", + })) + }) + + It("errors when the initial_source_connection_id is missing", func() { + Expect((&TransportParameters{}).Unmarshal([]byte{}, protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "missing initial_source_connection_id", + })) + }) + + It("errors when the max_ack_delay is too large", func() { + data := (&TransportParameters{ + MaxAckDelay: 1 << 14 * time.Millisecond, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for max_ack_delay: 16384ms (maximum 16383ms)", + })) + }) + + It("doesn't send the max_ack_delay, if it has the default value", func() { + const num = 1000 + var defaultLen, dataLen int + // marshal 1000 times to average out the greasing transport parameter + maxAckDelay := protocol.DefaultMaxAckDelay + time.Millisecond + for i := 0; i < num; i++ { + dataDefault := (&TransportParameters{ + MaxAckDelay: protocol.DefaultMaxAckDelay, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + defaultLen += len(dataDefault) + data := (&TransportParameters{ + MaxAckDelay: maxAckDelay, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + dataLen += len(data) + } + entryLen := quicvarint.Len(uint64(ackDelayExponentParameterID)) /* parameter id */ + quicvarint.Len(uint64(quicvarint.Len(uint64(maxAckDelay.Milliseconds())))) /*length */ + quicvarint.Len(uint64(maxAckDelay.Milliseconds())) /* value */ + Expect(float32(dataLen) / num).To(BeNumerically("~", float32(defaultLen)/num+float32(entryLen), 1)) + }) + + It("errors when the ack_delay_exponenent is too large", func() { + data := (&TransportParameters{ + AckDelayExponent: 21, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for ack_delay_exponent: 21 (maximum 20)", + })) + }) + + It("doesn't send the ack_delay_exponent, if it has the default value", func() { + const num = 1000 + var defaultLen, dataLen int + // marshal 1000 times to average out the greasing transport parameter + for i := 0; i < num; i++ { + dataDefault := (&TransportParameters{ + AckDelayExponent: protocol.DefaultAckDelayExponent, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + defaultLen += len(dataDefault) + data := (&TransportParameters{ + AckDelayExponent: protocol.DefaultAckDelayExponent + 1, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + dataLen += len(data) + } + entryLen := quicvarint.Len(uint64(ackDelayExponentParameterID)) /* parameter id */ + quicvarint.Len(uint64(quicvarint.Len(protocol.DefaultAckDelayExponent+1))) /* length */ + quicvarint.Len(protocol.DefaultAckDelayExponent+1) /* value */ + Expect(float32(dataLen) / num).To(BeNumerically("~", float32(defaultLen)/num+float32(entryLen), 1)) + }) + + It("sets the default value for the ack_delay_exponent, when no value was sent", func() { + data := (&TransportParameters{ + AckDelayExponent: protocol.DefaultAckDelayExponent, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) + Expect(p.AckDelayExponent).To(BeEquivalentTo(protocol.DefaultAckDelayExponent)) + }) + + It("errors when the varint value has the wrong length", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) + quicvarint.Write(b, 2) + val := uint64(0xdeadbeef) + Expect(quicvarint.Len(val)).ToNot(BeEquivalentTo(2)) + quicvarint.Write(b, val) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: fmt.Sprintf("inconsistent transport parameter length for transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), + })) + }) + + It("errors if initial_max_streams_bidi is too large", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(initialMaxStreamsBidiParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) + quicvarint.Write(b, uint64(protocol.MaxStreamCount+1)) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "initial_max_streams_bidi too large: 1152921504606846977 (maximum 1152921504606846976)", + })) + }) + + It("errors if initial_max_streams_uni is too large", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(initialMaxStreamsUniParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) + quicvarint.Write(b, uint64(protocol.MaxStreamCount+1)) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "initial_max_streams_uni too large: 1152921504606846977 (maximum 1152921504606846976)", + })) + }) + + It("handles huge max_ack_delay values", func() { + b := &bytes.Buffer{} + val := uint64(math.MaxUint64) / 5 + quicvarint.Write(b, uint64(maxAckDelayParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(val))) + quicvarint.Write(b, val) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for max_ack_delay: 3689348814741910323ms (maximum 16383ms)", + })) + }) + + It("skips unknown parameters", func() { + b := &bytes.Buffer{} + // write a known parameter + quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) + quicvarint.Write(b, 0x1337) + // write an unknown parameter + quicvarint.Write(b, 0x42) + quicvarint.Write(b, 6) + b.Write([]byte("foobar")) + // write a known parameter + quicvarint.Write(b, uint64(initialMaxStreamDataBidiRemoteParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(0x42))) + quicvarint.Write(b, 0x42) + addInitialSourceConnectionID(b) + p := &TransportParameters{} + Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(Succeed()) + Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(protocol.ByteCount(0x1337))) + Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(protocol.ByteCount(0x42))) + }) + + It("rejects duplicate parameters", func() { + b := &bytes.Buffer{} + // write first parameter + quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) + quicvarint.Write(b, 0x1337) + // write a second parameter + quicvarint.Write(b, uint64(initialMaxStreamDataBidiRemoteParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(0x42))) + quicvarint.Write(b, 0x42) + // write first parameter again + quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) + quicvarint.Write(b, 0x1337) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: fmt.Sprintf("received duplicate transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), + })) + }) + + It("errors if there's not enough data to read", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, 0x42) + quicvarint.Write(b, 7) + b.Write([]byte("foobar")) + p := &TransportParameters{} + Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "remaining length (6) smaller than parameter length (7)", + })) + }) + + It("errors if the client sent a stateless_reset_token", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(statelessResetTokenParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(16))) + b.Write(make([]byte, 16)) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "client sent a stateless_reset_token", + })) + }) + + It("errors if the client sent the original_destination_connection_id", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) + quicvarint.Write(b, 6) + b.Write([]byte("foobar")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "client sent an original_destination_connection_id", + })) + }) + + Context("preferred address", func() { + var pa *PreferredAddress + + BeforeEach(func() { + pa = &PreferredAddress{ + IPv4: net.IPv4(127, 0, 0, 1), + IPv4Port: 42, + IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + IPv6Port: 13, + ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + } + }) + + It("marshals and unmarshals", func() { + data := (&TransportParameters{ + PreferredAddress: pa, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) + Expect(p.PreferredAddress.IPv4.String()).To(Equal(pa.IPv4.String())) + Expect(p.PreferredAddress.IPv4Port).To(Equal(pa.IPv4Port)) + Expect(p.PreferredAddress.IPv6.String()).To(Equal(pa.IPv6.String())) + Expect(p.PreferredAddress.IPv6Port).To(Equal(pa.IPv6Port)) + Expect(p.PreferredAddress.ConnectionID).To(Equal(pa.ConnectionID)) + Expect(p.PreferredAddress.StatelessResetToken).To(Equal(pa.StatelessResetToken)) + }) + + It("errors if the client sent a preferred_address", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(preferredAddressParameterID)) + quicvarint.Write(b, 6) + b.Write([]byte("foobar")) + p := &TransportParameters{} + Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "client sent a preferred_address", + })) + }) + + It("errors on zero-length connection IDs", func() { + pa.ConnectionID = protocol.ConnectionID{} + data := (&TransportParameters{ + PreferredAddress: pa, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid connection ID length: 0", + })) + }) + + It("errors on too long connection IDs", func() { + pa.ConnectionID = protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21} + Expect(pa.ConnectionID.Len()).To(BeNumerically(">", protocol.MaxConnIDLen)) + data := (&TransportParameters{ + PreferredAddress: pa, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid connection ID length: 21", + })) + }) + + It("errors on EOF", func() { + raw := []byte{ + 127, 0, 0, 1, // IPv4 + 0, 42, // IPv4 Port + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // IPv6 + 13, 37, // IPv6 Port, + 4, // conn ID len + 0xde, 0xad, 0xbe, 0xef, + 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, // stateless reset token + } + for i := 1; i < len(raw); i++ { + buf := &bytes.Buffer{} + quicvarint.Write(buf, uint64(preferredAddressParameterID)) + buf.Write(raw[:i]) + p := &TransportParameters{} + Expect(p.Unmarshal(buf.Bytes(), protocol.PerspectiveServer)).ToNot(Succeed()) + } + }) + }) + + Context("saving and retrieving from a session ticket", func() { + It("saves and retrieves the parameters", func() { + params := &TransportParameters{ + InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), + InitialMaxData: protocol.ByteCount(getRandomValue()), + MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + ActiveConnectionIDLimit: getRandomValue(), + } + Expect(params.ValidFor0RTT(params)).To(BeTrue()) + b := &bytes.Buffer{} + params.MarshalForSessionTicket(b) + var tp TransportParameters + Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(Succeed()) + Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) + Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) + Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) + Expect(tp.InitialMaxData).To(Equal(params.InitialMaxData)) + Expect(tp.MaxBidiStreamNum).To(Equal(params.MaxBidiStreamNum)) + Expect(tp.MaxUniStreamNum).To(Equal(params.MaxUniStreamNum)) + Expect(tp.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) + }) + + It("rejects the parameters if it can't parse them", func() { + var p TransportParameters + Expect(p.UnmarshalFromSessionTicket(bytes.NewReader([]byte("foobar")))).ToNot(Succeed()) + }) + + It("rejects the parameters if the version changed", func() { + var p TransportParameters + buf := &bytes.Buffer{} + p.MarshalForSessionTicket(buf) + data := buf.Bytes() + b := &bytes.Buffer{} + quicvarint.Write(b, transportParameterMarshalingVersion+1) + b.Write(data[quicvarint.Len(transportParameterMarshalingVersion):]) + Expect(p.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1))) + }) + + Context("rejects the parameters if they changed", func() { + var p TransportParameters + saved := &TransportParameters{ + InitialMaxStreamDataBidiLocal: 1, + InitialMaxStreamDataBidiRemote: 2, + InitialMaxStreamDataUni: 3, + InitialMaxData: 4, + MaxBidiStreamNum: 5, + MaxUniStreamNum: 6, + ActiveConnectionIDLimit: 7, + } + + BeforeEach(func() { + p = *saved + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataBidiLocal was reduced", func() { + p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("doesn't reject the parameters if the InitialMaxStreamDataBidiLocal was increased", func() { + p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataBidiRemote was reduced", func() { + p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("doesn't reject the parameters if the InitialMaxStreamDataBidiRemote was increased", func() { + p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataUni was reduced", func() { + p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("doesn't reject the parameters if the InitialMaxStreamDataUni was increased", func() { + p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxData was reduced", func() { + p.InitialMaxData = saved.InitialMaxData - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("doesn't reject the parameters if the InitialMaxData was increased", func() { + p.InitialMaxData = saved.InitialMaxData + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the MaxBidiStreamNum was reduced", func() { + p.MaxBidiStreamNum = saved.MaxBidiStreamNum - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("accepts the parameters if the MaxBidiStreamNum was increased", func() { + p.MaxBidiStreamNum = saved.MaxBidiStreamNum + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the MaxUniStreamNum changed", func() { + p.MaxUniStreamNum = saved.MaxUniStreamNum - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("accepts the parameters if the MaxUniStreamNum was increased", func() { + p.MaxUniStreamNum = saved.MaxUniStreamNum + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the ActiveConnectionIDLimit changed", func() { + p.ActiveConnectionIDLimit = 0 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + }) + }) +}) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go new file mode 100644 index 00000000..0d1a3401 --- /dev/null +++ b/internal/wire/transport_parameters.go @@ -0,0 +1,476 @@ +package wire + +import ( + "bytes" + "errors" + "fmt" + "io" + "math/rand" + "net" + "sort" + "time" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/qerr" + "github.com/imroc/req/v3/internal/utils" + "github.com/imroc/req/v3/internal/quicvarint" +) + +const transportParameterMarshalingVersion = 1 + +func init() { + rand.Seed(time.Now().UTC().UnixNano()) +} + +type transportParameterID uint64 + +const ( + originalDestinationConnectionIDParameterID transportParameterID = 0x0 + maxIdleTimeoutParameterID transportParameterID = 0x1 + statelessResetTokenParameterID transportParameterID = 0x2 + maxUDPPayloadSizeParameterID transportParameterID = 0x3 + initialMaxDataParameterID transportParameterID = 0x4 + initialMaxStreamDataBidiLocalParameterID transportParameterID = 0x5 + initialMaxStreamDataBidiRemoteParameterID transportParameterID = 0x6 + initialMaxStreamDataUniParameterID transportParameterID = 0x7 + initialMaxStreamsBidiParameterID transportParameterID = 0x8 + initialMaxStreamsUniParameterID transportParameterID = 0x9 + ackDelayExponentParameterID transportParameterID = 0xa + maxAckDelayParameterID transportParameterID = 0xb + disableActiveMigrationParameterID transportParameterID = 0xc + preferredAddressParameterID transportParameterID = 0xd + activeConnectionIDLimitParameterID transportParameterID = 0xe + initialSourceConnectionIDParameterID transportParameterID = 0xf + retrySourceConnectionIDParameterID transportParameterID = 0x10 + // RFC 9221 + maxDatagramFrameSizeParameterID transportParameterID = 0x20 +) + +// PreferredAddress is the value encoding in the preferred_address transport parameter +type PreferredAddress struct { + IPv4 net.IP + IPv4Port uint16 + IPv6 net.IP + IPv6Port uint16 + ConnectionID protocol.ConnectionID + StatelessResetToken protocol.StatelessResetToken +} + +// TransportParameters are parameters sent to the peer during the handshake +type TransportParameters struct { + InitialMaxStreamDataBidiLocal protocol.ByteCount + InitialMaxStreamDataBidiRemote protocol.ByteCount + InitialMaxStreamDataUni protocol.ByteCount + InitialMaxData protocol.ByteCount + + MaxAckDelay time.Duration + AckDelayExponent uint8 + + DisableActiveMigration bool + + MaxUDPPayloadSize protocol.ByteCount + + MaxUniStreamNum protocol.StreamNum + MaxBidiStreamNum protocol.StreamNum + + MaxIdleTimeout time.Duration + + PreferredAddress *PreferredAddress + + OriginalDestinationConnectionID protocol.ConnectionID + InitialSourceConnectionID protocol.ConnectionID + RetrySourceConnectionID *protocol.ConnectionID // use a pointer here to distinguish zero-length connection IDs from missing transport parameters + + StatelessResetToken *protocol.StatelessResetToken + ActiveConnectionIDLimit uint64 + + MaxDatagramFrameSize protocol.ByteCount +} + +// Unmarshal the transport parameters +func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error { + if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil { + return &qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: err.Error(), + } + } + return nil +} + +func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error { + // needed to check that every parameter is only sent at most once + var parameterIDs []transportParameterID + + var ( + readOriginalDestinationConnectionID bool + readInitialSourceConnectionID bool + ) + + p.AckDelayExponent = protocol.DefaultAckDelayExponent + p.MaxAckDelay = protocol.DefaultMaxAckDelay + p.MaxDatagramFrameSize = protocol.InvalidByteCount + + for r.Len() > 0 { + paramIDInt, err := quicvarint.Read(r) + if err != nil { + return err + } + paramID := transportParameterID(paramIDInt) + paramLen, err := quicvarint.Read(r) + if err != nil { + return err + } + if uint64(r.Len()) < paramLen { + return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen) + } + parameterIDs = append(parameterIDs, paramID) + switch paramID { + case maxIdleTimeoutParameterID, + maxUDPPayloadSizeParameterID, + initialMaxDataParameterID, + initialMaxStreamDataBidiLocalParameterID, + initialMaxStreamDataBidiRemoteParameterID, + initialMaxStreamDataUniParameterID, + initialMaxStreamsBidiParameterID, + initialMaxStreamsUniParameterID, + maxAckDelayParameterID, + activeConnectionIDLimitParameterID, + maxDatagramFrameSizeParameterID, + ackDelayExponentParameterID: + if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { + return err + } + case preferredAddressParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent a preferred_address") + } + if err := p.readPreferredAddress(r, int(paramLen)); err != nil { + return err + } + case disableActiveMigrationParameterID: + if paramLen != 0 { + return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen) + } + p.DisableActiveMigration = true + case statelessResetTokenParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent a stateless_reset_token") + } + if paramLen != 16 { + return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen) + } + var token protocol.StatelessResetToken + r.Read(token[:]) + p.StatelessResetToken = &token + case originalDestinationConnectionIDParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent an original_destination_connection_id") + } + p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) + readOriginalDestinationConnectionID = true + case initialSourceConnectionIDParameterID: + p.InitialSourceConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) + readInitialSourceConnectionID = true + case retrySourceConnectionIDParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent a retry_source_connection_id") + } + connID, _ := protocol.ReadConnectionID(r, int(paramLen)) + p.RetrySourceConnectionID = &connID + default: + r.Seek(int64(paramLen), io.SeekCurrent) + } + } + + if !fromSessionTicket { + if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID { + return errors.New("missing original_destination_connection_id") + } + if p.MaxUDPPayloadSize == 0 { + p.MaxUDPPayloadSize = protocol.MaxByteCount + } + if !readInitialSourceConnectionID { + return errors.New("missing initial_source_connection_id") + } + } + + // check that every transport parameter was sent at most once + sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] }) + for i := 0; i < len(parameterIDs)-1; i++ { + if parameterIDs[i] == parameterIDs[i+1] { + return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i]) + } + } + + return nil +} + +func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error { + remainingLen := r.Len() + pa := &PreferredAddress{} + ipv4 := make([]byte, 4) + if _, err := io.ReadFull(r, ipv4); err != nil { + return err + } + pa.IPv4 = net.IP(ipv4) + port, err := utils.BigEndian.ReadUint16(r) + if err != nil { + return err + } + pa.IPv4Port = port + ipv6 := make([]byte, 16) + if _, err := io.ReadFull(r, ipv6); err != nil { + return err + } + pa.IPv6 = net.IP(ipv6) + port, err = utils.BigEndian.ReadUint16(r) + if err != nil { + return err + } + pa.IPv6Port = port + connIDLen, err := r.ReadByte() + if err != nil { + return err + } + if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d", connIDLen) + } + connID, err := protocol.ReadConnectionID(r, int(connIDLen)) + if err != nil { + return err + } + pa.ConnectionID = connID + if _, err := io.ReadFull(r, pa.StatelessResetToken[:]); err != nil { + return err + } + if bytesRead := remainingLen - r.Len(); bytesRead != expectedLen { + return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead) + } + p.PreferredAddress = pa + return nil +} + +func (p *TransportParameters) readNumericTransportParameter( + r *bytes.Reader, + paramID transportParameterID, + expectedLen int, +) error { + remainingLen := r.Len() + val, err := quicvarint.Read(r) + if err != nil { + return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err) + } + if remainingLen-r.Len() != expectedLen { + return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID) + } + //nolint:exhaustive // This only covers the numeric transport parameters. + switch paramID { + case initialMaxStreamDataBidiLocalParameterID: + p.InitialMaxStreamDataBidiLocal = protocol.ByteCount(val) + case initialMaxStreamDataBidiRemoteParameterID: + p.InitialMaxStreamDataBidiRemote = protocol.ByteCount(val) + case initialMaxStreamDataUniParameterID: + p.InitialMaxStreamDataUni = protocol.ByteCount(val) + case initialMaxDataParameterID: + p.InitialMaxData = protocol.ByteCount(val) + case initialMaxStreamsBidiParameterID: + p.MaxBidiStreamNum = protocol.StreamNum(val) + if p.MaxBidiStreamNum > protocol.MaxStreamCount { + return fmt.Errorf("initial_max_streams_bidi too large: %d (maximum %d)", p.MaxBidiStreamNum, protocol.MaxStreamCount) + } + case initialMaxStreamsUniParameterID: + p.MaxUniStreamNum = protocol.StreamNum(val) + if p.MaxUniStreamNum > protocol.MaxStreamCount { + return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount) + } + case maxIdleTimeoutParameterID: + p.MaxIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) + case maxUDPPayloadSizeParameterID: + if val < 1200 { + return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val) + } + p.MaxUDPPayloadSize = protocol.ByteCount(val) + case ackDelayExponentParameterID: + if val > protocol.MaxAckDelayExponent { + return fmt.Errorf("invalid value for ack_delay_exponent: %d (maximum %d)", val, protocol.MaxAckDelayExponent) + } + p.AckDelayExponent = uint8(val) + case maxAckDelayParameterID: + if val > uint64(protocol.MaxMaxAckDelay/time.Millisecond) { + return fmt.Errorf("invalid value for max_ack_delay: %dms (maximum %dms)", val, protocol.MaxMaxAckDelay/time.Millisecond) + } + p.MaxAckDelay = time.Duration(val) * time.Millisecond + case activeConnectionIDLimitParameterID: + p.ActiveConnectionIDLimit = val + case maxDatagramFrameSizeParameterID: + p.MaxDatagramFrameSize = protocol.ByteCount(val) + default: + return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID) + } + return nil +} + +// Marshal the transport parameters +func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { + b := &bytes.Buffer{} + + // add a greased value + quicvarint.Write(b, uint64(27+31*rand.Intn(100))) + length := rand.Intn(16) + randomData := make([]byte, length) + rand.Read(randomData) + quicvarint.Write(b, uint64(length)) + b.Write(randomData) + + // initial_max_stream_data_bidi_local + p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) + // initial_max_stream_data_bidi_remote + p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) + // initial_max_stream_data_uni + p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) + // initial_max_data + p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) + // initial_max_bidi_streams + p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) + // initial_max_uni_streams + p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + // idle_timeout + p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond)) + // max_packet_size + p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize)) + // max_ack_delay + // Only send it if is different from the default value. + if p.MaxAckDelay != protocol.DefaultMaxAckDelay { + p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) + } + // ack_delay_exponent + // Only send it if is different from the default value. + if p.AckDelayExponent != protocol.DefaultAckDelayExponent { + p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) + } + // disable_active_migration + if p.DisableActiveMigration { + quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) + quicvarint.Write(b, 0) + } + if pers == protocol.PerspectiveServer { + // stateless_reset_token + if p.StatelessResetToken != nil { + quicvarint.Write(b, uint64(statelessResetTokenParameterID)) + quicvarint.Write(b, 16) + b.Write(p.StatelessResetToken[:]) + } + // original_destination_connection_id + quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) + quicvarint.Write(b, uint64(p.OriginalDestinationConnectionID.Len())) + b.Write(p.OriginalDestinationConnectionID.Bytes()) + // preferred_address + if p.PreferredAddress != nil { + quicvarint.Write(b, uint64(preferredAddressParameterID)) + quicvarint.Write(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) + ipv4 := p.PreferredAddress.IPv4 + b.Write(ipv4[len(ipv4)-4:]) + utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port) + b.Write(p.PreferredAddress.IPv6) + utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port) + b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len())) + b.Write(p.PreferredAddress.ConnectionID.Bytes()) + b.Write(p.PreferredAddress.StatelessResetToken[:]) + } + } + // active_connection_id_limit + p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) + // initial_source_connection_id + quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) + quicvarint.Write(b, uint64(p.InitialSourceConnectionID.Len())) + b.Write(p.InitialSourceConnectionID.Bytes()) + // retry_source_connection_id + if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil { + quicvarint.Write(b, uint64(retrySourceConnectionIDParameterID)) + quicvarint.Write(b, uint64(p.RetrySourceConnectionID.Len())) + b.Write(p.RetrySourceConnectionID.Bytes()) + } + if p.MaxDatagramFrameSize != protocol.InvalidByteCount { + p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) + } + return b.Bytes() +} + +func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportParameterID, val uint64) { + quicvarint.Write(b, uint64(id)) + quicvarint.Write(b, uint64(quicvarint.Len(val))) + quicvarint.Write(b, val) +} + +// MarshalForSessionTicket marshals the transport parameters we save in the session ticket. +// When sending a 0-RTT enabled TLS session tickets, we need to save the transport parameters. +// The client will remember the transport parameters used in the last session, +// and apply those to the 0-RTT data it sends. +// Saving the transport parameters in the ticket gives the server the option to reject 0-RTT +// if the transport parameters changed. +// Since the session ticket is encrypted, the serialization format is defined by the server. +// For convenience, we use the same format that we also use for sending the transport parameters. +func (p *TransportParameters) MarshalForSessionTicket(b *bytes.Buffer) { + quicvarint.Write(b, transportParameterMarshalingVersion) + + // initial_max_stream_data_bidi_local + p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) + // initial_max_stream_data_bidi_remote + p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) + // initial_max_stream_data_uni + p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) + // initial_max_data + p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) + // initial_max_bidi_streams + p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) + // initial_max_uni_streams + p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + // active_connection_id_limit + p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) +} + +// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket. +func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error { + version, err := quicvarint.Read(r) + if err != nil { + return err + } + if version != transportParameterMarshalingVersion { + return fmt.Errorf("unknown transport parameter marshaling version: %d", version) + } + return p.unmarshal(r, protocol.PerspectiveServer, true) +} + +// ValidFor0RTT checks if the transport parameters match those saved in the session ticket. +func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { + return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && + p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && + p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && + p.InitialMaxData >= saved.InitialMaxData && + p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && + p.MaxUniStreamNum >= saved.MaxUniStreamNum && + p.ActiveConnectionIDLimit == saved.ActiveConnectionIDLimit +} + +// String returns a string representation, intended for logging. +func (p *TransportParameters) String() string { + logString := "&wire.TransportParameters{OriginalDestinationConnectionID: %s, InitialSourceConnectionID: %s, " + logParams := []interface{}{p.OriginalDestinationConnectionID, p.InitialSourceConnectionID} + if p.RetrySourceConnectionID != nil { + logString += "RetrySourceConnectionID: %s, " + logParams = append(logParams, p.RetrySourceConnectionID) + } + logString += "InitialMaxStreamDataBidiLocal: %d, InitialMaxStreamDataBidiRemote: %d, InitialMaxStreamDataUni: %d, InitialMaxData: %d, MaxBidiStreamNum: %d, MaxUniStreamNum: %d, MaxIdleTimeout: %s, AckDelayExponent: %d, MaxAckDelay: %s, ActiveConnectionIDLimit: %d" + logParams = append(logParams, []interface{}{p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreamNum, p.MaxUniStreamNum, p.MaxIdleTimeout, p.AckDelayExponent, p.MaxAckDelay, p.ActiveConnectionIDLimit}...) + if p.StatelessResetToken != nil { // the client never sends a stateless reset token + logString += ", StatelessResetToken: %#x" + logParams = append(logParams, *p.StatelessResetToken) + } + if p.MaxDatagramFrameSize != protocol.InvalidByteCount { + logString += ", MaxDatagramFrameSize: %d" + logParams = append(logParams, p.MaxDatagramFrameSize) + } + logString += "}" + return fmt.Sprintf(logString, logParams...) +} diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go new file mode 100644 index 00000000..971571dd --- /dev/null +++ b/internal/wire/version_negotiation.go @@ -0,0 +1,55 @@ +package wire + +import ( + "bytes" + "crypto/rand" + "errors" + "github.com/lucas-clemente/quic-go" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/utils" +) + +// ParseVersionNegotiationPacket parses a Version Negotiation packet. +func ParseVersionNegotiationPacket(b *bytes.Reader) (*Header, []quic.VersionNumber, error) { + hdr, err := parseHeader(b, 0) + if err != nil { + return nil, nil, err + } + if b.Len() == 0 { + //nolint:stylecheck + return nil, nil, errors.New("Version Negotiation packet has empty version list") + } + if b.Len()%4 != 0 { + //nolint:stylecheck + return nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") + } + versions := make([]quic.VersionNumber, b.Len()/4) + for i := 0; b.Len() > 0; i++ { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, nil, err + } + versions[i] = quic.VersionNumber(v) + } + return hdr, versions, nil +} + +// ComposeVersionNegotiation composes a Version Negotiation +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []quic.VersionNumber) []byte { + greasedVersions := protocol.GetGreasedVersions(versions) + expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 + buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) + r := make([]byte, 1) + _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. + buf.WriteByte(r[0] | 0x80) + utils.BigEndian.WriteUint32(buf, 0) // version 0 + buf.WriteByte(uint8(destConnID.Len())) + buf.Write(destConnID) + buf.WriteByte(uint8(srcConnID.Len())) + buf.Write(srcConnID) + for _, v := range greasedVersions { + utils.BigEndian.WriteUint32(buf, uint32(v)) + } + return buf.Bytes() +} diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go new file mode 100644 index 00000000..4cc1f68e --- /dev/null +++ b/internal/wire/version_negotiation_test.go @@ -0,0 +1,83 @@ +package wire + +import ( + "bytes" + "encoding/binary" + + "github.com/imroc/req/v3/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Version Negotiation Packets", func() { + It("parses a Version Negotiation packet", func() { + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} + versions := []quic.VersionNumber{0x22334455, 0x33445566} + data := []byte{0x80, 0, 0, 0, 0} + data = append(data, uint8(len(destConnID))) + data = append(data, destConnID...) + data = append(data, uint8(len(srcConnID))) + data = append(data, srcConnID...) + for _, v := range versions { + data = append(data, []byte{0, 0, 0, 0}...) + binary.BigEndian.PutUint32(data[len(data)-4:], uint32(v)) + } + Expect(IsVersionNegotiationPacket(data)).To(BeTrue()) + hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.DestConnectionID).To(Equal(destConnID)) + Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) + Expect(hdr.IsLongHeader).To(BeTrue()) + Expect(hdr.Version).To(BeZero()) + Expect(supportedVersions).To(Equal(versions)) + }) + + It("errors if it contains versions of the wrong length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + versions := []quic.VersionNumber{0x22334455, 0x33445566} + data := ComposeVersionNegotiation(connID, connID, versions) + _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data[:len(data)-2])) + Expect(err).To(MatchError("Version Negotiation packet has a version list with an invalid length")) + }) + + It("errors if the version list is empty", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + versions := []quic.VersionNumber{0x22334455} + data := ComposeVersionNegotiation(connID, connID, versions) + // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number + data = data[:len(data)-8] + _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + Expect(err).To(MatchError("Version Negotiation packet has empty version list")) + }) + + It("adds a reserved version", func() { + srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} + destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + versions := []quic.VersionNumber{1001, 1003} + data := ComposeVersionNegotiation(destConnID, srcConnID, versions) + Expect(data[0] & 0x80).ToNot(BeZero()) + hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.DestConnectionID).To(Equal(destConnID)) + Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) + Expect(hdr.Version).To(BeZero()) + // the supported versions should include one reserved version number + Expect(supportedVersions).To(HaveLen(len(versions) + 1)) + for _, v := range versions { + Expect(supportedVersions).To(ContainElement(v)) + } + var reservedVersion quic.VersionNumber + versionLoop: + for _, ver := range supportedVersions { + for _, v := range versions { + if v == ver { + continue versionLoop + } + } + reservedVersion = ver + } + Expect(reservedVersion).ToNot(BeZero()) + Expect(reservedVersion&0x0f0f0f0f == 0x0a0a0a0a).To(BeTrue()) // check that it's a greased version number + }) +}) diff --git a/internal/wire/wire_suite_test.go b/internal/wire/wire_suite_test.go new file mode 100644 index 00000000..528f728f --- /dev/null +++ b/internal/wire/wire_suite_test.go @@ -0,0 +1,31 @@ +package wire + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestWire(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Wire Suite") +} + +func encodeVarInt(i uint64) []byte { + b := &bytes.Buffer{} + quicvarint.Write(b, i) + return b.Bytes() +} + +func appendVersion(data []byte, v quic.VersionNumber) []byte { + offset := len(data) + data = append(data, []byte{0, 0, 0, 0}...) + binary.BigEndian.PutUint32(data[offset:], uint32(v)) + return data +} diff --git a/pkg/altsvc/altsvc.go b/pkg/altsvc/altsvc.go new file mode 100644 index 00000000..8b095424 --- /dev/null +++ b/pkg/altsvc/altsvc.go @@ -0,0 +1,51 @@ +package altsvc + +import ( + "sync" + "time" +) + +type AltSvcJar struct { + entries map[string]*AltSvc + mu sync.Mutex +} + +func NewAltSvcJar() *AltSvcJar { + return &AltSvcJar{ + entries: make(map[string]*AltSvc), + } +} + +func (j *AltSvcJar) GetAltSvc(addr string) *AltSvc { + if addr == "" { + return nil + } + as, ok := j.entries[addr] + if !ok { + return nil + } + now := time.Now() + j.mu.Lock() + defer j.mu.Unlock() + if as.Expire.Before(now) { // expired + delete(j.entries, addr) + return nil + } + return as +} + +func (j *AltSvcJar) SetAltSvc(addr string, as *AltSvc) { + if addr == "" { + return + } + j.mu.Lock() + defer j.mu.Unlock() + j.entries[addr] = as +} + +type AltSvc struct { + Protocol string + Host string + Port string + Expire time.Time +} diff --git a/pkg/altsvc/jar.go b/pkg/altsvc/jar.go new file mode 100644 index 00000000..60ca03ca --- /dev/null +++ b/pkg/altsvc/jar.go @@ -0,0 +1,6 @@ +package altsvc + +type Jar interface { + SetAltSvc(addr string, as *AltSvc) + GetAltSvc(addr string) *AltSvc +} diff --git a/roundtrip.go b/roundtrip.go index f73d3f36..d230481b 100644 --- a/roundtrip.go +++ b/roundtrip.go @@ -7,7 +7,9 @@ package req -import "net/http" +import ( + "net/http" +) // RoundTrip implements the RoundTripper interface. // @@ -21,6 +23,11 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error if err != nil { return } + if t.altSvcJar != nil { + if v := resp.Header.Get("alt-svc"); v != "" { + t.handleAltSvc(req, v) + } + } t.handleResponseBody(resp, req) return } diff --git a/tls.go b/tls.go deleted file mode 100644 index 66f8dbb8..00000000 --- a/tls.go +++ /dev/null @@ -1,35 +0,0 @@ -package req - -import ( - "crypto/tls" - "net" -) - -// TLSConn is the recommended interface for the connection -// returned by the DailTLS function (Client.SetDialTLS, -// Transport.DialTLSContext), so that the TLS handshake negotiation -// can automatically decide whether to use HTTP2 or HTTP1 (ALPN). -// If this interface is not implemented, HTTP1 will be used by default. -type TLSConn interface { - net.Conn - // ConnectionState returns basic TLS details about the connection. - ConnectionState() tls.ConnectionState - // Handshake runs the client or server handshake - // protocol if it has not yet been run. - // - // Most uses of this package need not call Handshake explicitly: the - // first Read or Write will call it automatically. - // - // For control over canceling or setting a timeout on a handshake, use - // HandshakeContext or the Dialer's DialContext method instead. - Handshake() error -} - -// NetConnWrapper is the interface to get underlying connection, which is -// introduced in go1.18 for *tls.Conn. -type NetConnWrapper interface { - // NetConn returns the underlying connection that is wrapped by c. - // Note that writing to or reading from this connection directly will corrupt the - // TLS session. - NetConn() net.Conn -} diff --git a/transport.go b/transport.go index f01b077b..6677c852 100644 --- a/transport.go +++ b/transport.go @@ -17,14 +17,18 @@ import ( "crypto/tls" "errors" "fmt" + "github.com/imroc/req/v3/internal/altsvcutil" "github.com/imroc/req/v3/internal/ascii" "github.com/imroc/req/v3/internal/common" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/http2" + "github.com/imroc/req/v3/internal/http3" + "github.com/imroc/req/v3/internal/netutil" "github.com/imroc/req/v3/internal/socks" reqtls "github.com/imroc/req/v3/internal/tls" "github.com/imroc/req/v3/internal/util" + "github.com/imroc/req/v3/pkg/altsvc" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" "io" @@ -38,7 +42,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "golang.org/x/net/http/httpguts" @@ -52,6 +55,8 @@ const ( HTTP1 HttpVersion = "1.1" // HTTP2 represents "HTTP/2.0" HTTP2 HttpVersion = "2" + // HTTP3 represents "HTTP/3.0" + HTTP3 HttpVersion = "3" ) // defaultMaxIdleConnsPerHost is the default value of Transport's @@ -115,12 +120,12 @@ type Transport struct { reqMu sync.Mutex reqCanceler map[cancelKey]func(error) - altMu sync.Mutex // guards changing altProto only - altProto atomic.Value // of nil or map[string]http.RoundTripper, key is URI scheme - connsPerHostMu sync.Mutex connsPerHost map[connectMethodKey]int connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns + altSvcJar altsvc.Jar + pendingAltSvcs map[string]*pendingAltSvc + pendingAltSvcsMu sync.Mutex // Force using specific http version ForceHttpVersion HttpVersion @@ -264,6 +269,7 @@ type Transport struct { ReadBufferSize int t2 *http2.Transport // non-nil if http2 wired up + t3 *http3.RoundTripper // ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero // Dial, DialTLS, or DialContext func or TLSClientConfig is provided. @@ -280,6 +286,20 @@ type Transport struct { Debugf func(format string, v ...interface{}) } +func (t *Transport) enableH3() error { + if t.altSvcJar == nil { + t.altSvcJar = altsvc.NewAltSvcJar() + } + if t.pendingAltSvcs == nil { + t.pendingAltSvcs = make(map[string]*pendingAltSvc) + } + t3 := &http3.RoundTripper{ + Interface: transportImpl{t}, + } + t.t3 = t3 + return nil +} + type wrapResponseBodyKeyType int const wrapResponseBodyKey wrapResponseBodyKeyType = iota @@ -294,6 +314,65 @@ func (t *Transport) handleResponseBody(res *http.Response, req *http.Request) { dump.WrapResponseBodyIfNeeded(res, req, t.dump) } +var allowedProtocols = map[string]bool{ + "h3": true, + "h2": true, +} + +func (t *Transport) handleAltSvc(req *http.Request, value string) { + addr := netutil.AuthorityKey(req.URL) + as := t.altSvcJar.GetAltSvc(addr) + if as != nil { + return + } + + t.pendingAltSvcsMu.Lock() + defer t.pendingAltSvcsMu.Unlock() + _, ok := t.pendingAltSvcs[addr] + if ok { + return + } + ass, err := altsvcutil.ParseHeader(value) + if err != nil { + if t.Debugf != nil { + t.Debugf("failed to parse alt-svc header: %s", err.Error()) + } + return + } + var entries []*altsvc.AltSvc + for _, a := range ass { + if allowedProtocols[a.Protocol] { + entries = append(entries, a) + } + } + if len(entries) > 0 { + pas := &pendingAltSvc{ + Entries: entries, + } + t.pendingAltSvcs[addr] = pas + go t.handlePendingAltSvc(netutil.AuthorityAddr(req.URL.Scheme, req.URL.Host), pas) + } +} + +func (t *Transport) handlePendingAltSvc(hostname string, pas *pendingAltSvc) { + for i := pas.CurrentIndex; i < len(pas.Entries); i++ { + switch pas.Entries[i].Protocol { + case "h3": + err := t.t3.AddConn(hostname) + if err != nil { + if t.Debugf != nil { + t.Debugf("failed to get http3 connection: %s", err.Error()) + } + } else { + pas.CurrentIndex = i + pas.Transport = t.t3 + return + } + case "h2": // TODO + } + } +} + func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFunc) { switch b := res.Body.(type) { case *gzipReader: @@ -463,32 +542,23 @@ func (tr *transportRequest) setError(err error) { tr.mu.Unlock() } -// useRegisteredProtocol reports whether an alternate protocol (as registered -// with Transport.RegisterProtocol) should be respected for this request. -func (t *Transport) useRegisteredProtocol(req *http.Request) bool { - if req.URL.Scheme == "https" && requestRequiresHTTP1(req) { - // If this request requires HTTP/1, don't use the - // "https" alternate protocol, which is used by the - // HTTP/2 code to take over requests if there's an - // existing cached HTTP/2 connection. - return false - } - return true -} - -// alternatehttp.RoundTripper returns the alternate http.RoundTripper to use -// for this request if the Request's URL scheme requires one, -// or nil for the normal case of using the Transport. -func (t *Transport) alternateRoundTripper(req *http.Request) http.RoundTripper { - if !t.useRegisteredProtocol(req) { - return nil +func (t *Transport) roundTripAltSvc(req *http.Request, as *altsvc.AltSvc) (resp *http.Response, err error) { + r := req.Clone(req.Context()) + r.URL = altsvcutil.ConvertURL(as, req.URL) + switch as.Protocol { + case "h3": + resp, err = t.t3.RoundTrip(r) + case "h2": + resp, err = t.t2.RoundTrip(r) + default: + // impossible! + panic(fmt.Sprintf("unknown protocol %q", as.Protocol)) } - altProto, _ := t.altProto.Load().(map[string]http.RoundTripper) - return altProto[req.URL.Scheme] + return } // roundTrip implements a http.RoundTripper over HTTP. -func (t *Transport) roundTrip(req *http.Request) (*http.Response, error) { +func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error) { ctx := req.Context() trace := httptrace.ContextClientTrace(ctx) @@ -496,6 +566,37 @@ func (t *Transport) roundTrip(req *http.Request) (*http.Response, error) { closeBody(req) return nil, errors.New("http: nil Request.URL") } + + if t.altSvcJar != nil { + addr := netutil.AuthorityKey(req.URL) + pas, ok := t.pendingAltSvcs[addr] + if ok { + if pas.Transport != nil { + pas.Mu.Lock() + if pas.Transport != nil { + pas.LastTime = time.Now() + resp, err = pas.Transport.RoundTrip(req) + if err != nil { + pas.Transport = nil + if pas.CurrentIndex+1 < len(pas.Entries) { + pas.CurrentIndex++ + go t.handlePendingAltSvc(addr, pas) + } + } else { + t.altSvcJar.SetAltSvc(addr, pas.Entries[pas.CurrentIndex]) + delete(t.pendingAltSvcs, addr) + } + } + pas.Mu.Unlock() + return + } + } + as := t.altSvcJar.GetAltSvc(addr) + if as != nil { + return t.roundTripAltSvc(req, as) + } + } + if req.Header == nil { closeBody(req) return nil, errors.New("http: nil Request.Header") @@ -517,22 +618,33 @@ func (t *Transport) roundTrip(req *http.Request) (*http.Response, error) { } } + if t.ForceHttpVersion == HTTP3 { + return t.t3.RoundTrip(req) + } + origReq := req cancelKey := cancelKey{origReq} req = setupRewindBody(req) if t.ForceHttpVersion != HTTP1 { - if altRT := t.alternateRoundTripper(req); altRT != nil { - if resp, err := altRT.RoundTrip(req); err != http.ErrSkipAltProtocol { - return resp, err - } - var err error - req, err = rewindBody(req) - if err != nil { - return nil, err - } + resp, err := t.t2.RoundTripOnlyCachedConn(req) + if err != http2.ErrNoCachedConn { + return resp, err + } + req, err = rewindBody(req) + if err != nil { + return nil, err + } + resp, err = t.t3.RoundTripOnlyCachedConn(req) + if err != http3.ErrNoCachedConn { + return resp, err + } + req, err = rewindBody(req) + if err != nil { + return nil, err } } + if !isHTTP { closeBody(req) return nil, badStringError("unsupported protocol scheme", scheme) @@ -716,31 +828,6 @@ func (pc *persistConn) shouldRetryRequest(req *http.Request, err error) bool { return false // conservatively } -// RegisterProtocol registers a new protocol with scheme. -// The Transport will pass requests using the given scheme to rt. -// It is rt's responsibility to simulate HTTP request semantics. -// -// RegisterProtocol can be used by other packages to provide -// implementations of protocol schemes like "ftp" or "file". -// -// If rt.RoundTrip returns ErrSkipAltProtocol, the Transport will -// handle the RoundTrip itself for that one request, as if the -// protocol were not registered. -func (t *Transport) RegisterProtocol(scheme string, rt http.RoundTripper) { - t.altMu.Lock() - defer t.altMu.Unlock() - oldMap, _ := t.altProto.Load().(map[string]http.RoundTripper) - if _, exists := oldMap[scheme]; exists { - panic("protocol " + scheme + " already registered") - } - newMap := make(map[string]http.RoundTripper) - for k, v := range oldMap { - newMap[k] = v - } - newMap[scheme] = rt - t.altProto.Store(newMap) -} - // CloseIdleConnections closes any connections which were previously // connected from previous requests but are now sitting idle in // a "keep-alive" state. It does not interrupt any connections currently @@ -1547,7 +1634,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if err != nil { return nil, wrapErr(err) } - if tc, ok := pconn.conn.(TLSConn); ok { + if tc, ok := pconn.conn.(reqtls.Conn); ok { // Handshake here, in case DialTLS didn't. TLSNextProto below // depends on it for knowing the connection state. if trace != nil && trace.TLSHandshakeStart != nil { @@ -1700,13 +1787,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } if s := pconn.tlsState; t.ForceHttpVersion != HTTP1 && s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { - if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { - alt := next(cm.targetAddr, pconn.conn.(TLSConn)) - if e, ok := alt.(erringRoundTripper); ok { - // pconn.conn was closed by next (http2configureTransports.upgradeFn). - return nil, e.RoundTripErr() - } - return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil + if s.NegotiatedProtocol == http2.NextProtoTLS { + t.t2.AddConn(pconn.conn, cm.targetAddr) + return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: t.t2}, nil } } diff --git a/transport_wrapper.go b/transport_wrapper.go index e17d9ee7..c67c7e38 100644 --- a/transport_wrapper.go +++ b/transport_wrapper.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "github.com/imroc/req/v3/internal/dump" reqtls "github.com/imroc/req/v3/internal/tls" - "github.com/imroc/req/v3/internal/transport" "net" "net/http" "net/url" @@ -20,10 +19,6 @@ func (t transportImpl) Proxy() func(*http.Request) (*url.URL, error) { return t.t.Proxy } -func (t transportImpl) Clone() transport.Interface { - return transportImpl{t.t.Clone()} -} - func (t transportImpl) Debugf() func(format string, v ...interface{}) { return t.t.Debugf } @@ -56,10 +51,6 @@ func (t transportImpl) DialTLSContext() func(ctx context.Context, network string return t.t.DialTLSContext } -func (t transportImpl) RegisterProtocol(scheme string, rt http.RoundTripper) { - t.t.RegisterProtocol(scheme, rt) -} - func (t transportImpl) DisableKeepAlives() bool { return t.t.DisableKeepAlives } From 31bd9426408e21e70c196440a62feb427e3c146c Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 30 Jun 2022 18:02:46 +0800 Subject: [PATCH 508/843] Allow enable http3 only when go version is between 1.16 and 1.18 --- altsvc.go | 16 ---------------- client.go | 18 ++++++++++++++++-- pkg/tls/conn.go | 4 ++++ transport.go | 14 ++++++++++---- 4 files changed, 30 insertions(+), 22 deletions(-) delete mode 100644 altsvc.go create mode 100644 pkg/tls/conn.go diff --git a/altsvc.go b/altsvc.go deleted file mode 100644 index c03dd56f..00000000 --- a/altsvc.go +++ /dev/null @@ -1,16 +0,0 @@ -package req - -import ( - "github.com/imroc/req/v3/pkg/altsvc" - "net/http" - "sync" - "time" -) - -type pendingAltSvc struct { - CurrentIndex int - Entries []*altsvc.AltSvc - Mu sync.Mutex - LastTime time.Time - Transport http.RoundTripper -} diff --git a/client.go b/client.go index b11d8b9a..0f1c955e 100644 --- a/client.go +++ b/client.go @@ -19,6 +19,8 @@ import ( urlpkg "net/url" "os" "reflect" + "runtime" + "strconv" "strings" "time" ) @@ -875,9 +877,21 @@ func (c *Client) SetUnixSocket(file string) *Client { } func (c *Client) EnableHttp3() *Client { - err := c.t.enableH3() + v := runtime.Version() + ss := strings.Split(v, ".") + if len(ss) < 2 || ss[0] != "go1" { + c.log.Warnf("bad go version format: %s", v) + return c + } + minorVersion, err := strconv.Atoi(ss[1]) if err != nil { - c.log.Errorf("failed to enabled http3: %s", err.Error()) + c.log.Warnf("bad go minor version: %s", v) + return c + } + if minorVersion >= 16 && minorVersion <= 18 { + c.t.enableH3() + } else { + c.log.Warnf("%s is not support http3", v) } return c } diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go new file mode 100644 index 00000000..d9a40fb5 --- /dev/null +++ b/pkg/tls/conn.go @@ -0,0 +1,4 @@ +package tls + +type Conn interface { +} diff --git a/transport.go b/transport.go index 6677c852..4cb0fd6e 100644 --- a/transport.go +++ b/transport.go @@ -286,7 +286,15 @@ type Transport struct { Debugf func(format string, v ...interface{}) } -func (t *Transport) enableH3() error { +type pendingAltSvc struct { + CurrentIndex int + Entries []*altsvc.AltSvc + Mu sync.Mutex + LastTime time.Time + Transport http.RoundTripper +} + +func (t *Transport) enableH3() { if t.altSvcJar == nil { t.altSvcJar = altsvc.NewAltSvcJar() } @@ -297,7 +305,6 @@ func (t *Transport) enableH3() error { Interface: transportImpl{t}, } t.t3 = t3 - return nil } type wrapResponseBodyKeyType int @@ -357,7 +364,7 @@ func (t *Transport) handleAltSvc(req *http.Request, value string) { func (t *Transport) handlePendingAltSvc(hostname string, pas *pendingAltSvc) { for i := pas.CurrentIndex; i < len(pas.Entries); i++ { switch pas.Entries[i].Protocol { - case "h3": + case "h3": // only support h3 in alt-svc for now err := t.t3.AddConn(hostname) if err != nil { if t.Debugf != nil { @@ -368,7 +375,6 @@ func (t *Transport) handlePendingAltSvc(hostname string, pas *pendingAltSvc) { pas.Transport = t.t3 return } - case "h2": // TODO } } } From bf92ee058ff6d941808e48ec0ef3fa272e04b351 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 30 Jun 2022 19:18:51 +0800 Subject: [PATCH 509/843] rename package internal/tls --> pkg/tls --- internal/http2/go115.go | 2 +- internal/http2/transport.go | 2 +- internal/tests/transport.go | 2 +- internal/tls/conn.go | 35 --------------------------------- internal/transport/transport.go | 2 +- pkg/tls/conn.go | 31 +++++++++++++++++++++++++++++ transport.go | 2 +- transport_internal_test.go | 4 ++-- transport_test.go | 2 +- transport_wrapper.go | 2 +- 10 files changed, 40 insertions(+), 44 deletions(-) delete mode 100644 internal/tls/conn.go diff --git a/internal/http2/go115.go b/internal/http2/go115.go index 69769559..c9d2183f 100644 --- a/internal/http2/go115.go +++ b/internal/http2/go115.go @@ -10,7 +10,7 @@ package http2 import ( "context" "crypto/tls" - reqtls "github.com/imroc/req/v3/internal/tls" + reqtls "github.com/imroc/req/v3/pkg/tls" ) // dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 90f76c4b..0444ab40 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -20,8 +20,8 @@ import ( "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/netutil" - reqtls "github.com/imroc/req/v3/internal/tls" "github.com/imroc/req/v3/internal/transport" + reqtls "github.com/imroc/req/v3/pkg/tls" "io" "log" "math" diff --git a/internal/tests/transport.go b/internal/tests/transport.go index 9378c28b..9122a524 100644 --- a/internal/tests/transport.go +++ b/internal/tests/transport.go @@ -4,8 +4,8 @@ import ( "context" "crypto/tls" "github.com/imroc/req/v3/internal/dump" - reqtls "github.com/imroc/req/v3/internal/tls" "github.com/imroc/req/v3/internal/transport" + reqtls "github.com/imroc/req/v3/pkg/tls" "net" "net/http" "net/url" diff --git a/internal/tls/conn.go b/internal/tls/conn.go deleted file mode 100644 index 2eaae7c1..00000000 --- a/internal/tls/conn.go +++ /dev/null @@ -1,35 +0,0 @@ -package tls - -import ( - "crypto/tls" - "net" -) - -// Conn is the recommended interface for the connection -// returned by the DailTLS function (Client.SetDialTLS, -// Transport.DialTLSContext), so that the TLS handshake negotiation -// can automatically decide whether to use HTTP2 or HTTP1 (ALPN). -// If this interface is not implemented, HTTP1 will be used by default. -type Conn interface { - net.Conn - // ConnectionState returns basic TLS details about the connection. - ConnectionState() tls.ConnectionState - // Handshake runs the client or server handshake - // protocol if it has not yet been run. - // - // Most uses of this package need not call Handshake explicitly: the - // first Read or Write will call it automatically. - // - // For control over canceling or setting a timeout on a handshake, use - // HandshakeContext or the Dialer's DialContext method instead. - Handshake() error -} - -// NetConnWrapper is the interface to get underlying connection, which is -// introduced in go1.18 for *tls.Conn. -type NetConnWrapper interface { - // NetConn returns the underlying connection that is wrapped by c. - // Note that writing to or reading from this connection directly will corrupt the - // TLS session. - NetConn() net.Conn -} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index b9e864ce..91ca9555 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -4,7 +4,7 @@ import ( "context" "crypto/tls" "github.com/imroc/req/v3/internal/dump" - reqtls "github.com/imroc/req/v3/internal/tls" + reqtls "github.com/imroc/req/v3/pkg/tls" "net" "net/http" "net/url" diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index d9a40fb5..2eaae7c1 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -1,4 +1,35 @@ package tls +import ( + "crypto/tls" + "net" +) + +// Conn is the recommended interface for the connection +// returned by the DailTLS function (Client.SetDialTLS, +// Transport.DialTLSContext), so that the TLS handshake negotiation +// can automatically decide whether to use HTTP2 or HTTP1 (ALPN). +// If this interface is not implemented, HTTP1 will be used by default. type Conn interface { + net.Conn + // ConnectionState returns basic TLS details about the connection. + ConnectionState() tls.ConnectionState + // Handshake runs the client or server handshake + // protocol if it has not yet been run. + // + // Most uses of this package need not call Handshake explicitly: the + // first Read or Write will call it automatically. + // + // For control over canceling or setting a timeout on a handshake, use + // HandshakeContext or the Dialer's DialContext method instead. + Handshake() error +} + +// NetConnWrapper is the interface to get underlying connection, which is +// introduced in go1.18 for *tls.Conn. +type NetConnWrapper interface { + // NetConn returns the underlying connection that is wrapped by c. + // Note that writing to or reading from this connection directly will corrupt the + // TLS session. + NetConn() net.Conn } diff --git a/transport.go b/transport.go index 4cb0fd6e..a99bbe3b 100644 --- a/transport.go +++ b/transport.go @@ -26,9 +26,9 @@ import ( "github.com/imroc/req/v3/internal/http3" "github.com/imroc/req/v3/internal/netutil" "github.com/imroc/req/v3/internal/socks" - reqtls "github.com/imroc/req/v3/internal/tls" "github.com/imroc/req/v3/internal/util" "github.com/imroc/req/v3/pkg/altsvc" + reqtls "github.com/imroc/req/v3/pkg/tls" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" "io" diff --git a/transport_internal_test.go b/transport_internal_test.go index 4ff9da08..688f9eb3 100644 --- a/transport_internal_test.go +++ b/transport_internal_test.go @@ -14,7 +14,7 @@ import ( "github.com/imroc/req/v3/internal/http2" "github.com/imroc/req/v3/internal/testcert" "github.com/imroc/req/v3/internal/tests" - reqtls "github.com/imroc/req/v3/internal/tls" + reqtls "github.com/imroc/req/v3/pkg/tls" "io" "net" "net/http" @@ -212,7 +212,7 @@ func TestTransportBodyAltRewind(t *testing.T) { t.Error(err) return } - if err := sc.(TLSConn).Handshake(); err != nil { + if err := sc.(reqtls.Conn).Handshake(); err != nil { t.Error(err) return } diff --git a/transport_test.go b/transport_test.go index b3a05532..578e244c 100644 --- a/transport_test.go +++ b/transport_test.go @@ -24,7 +24,7 @@ import ( "github.com/imroc/req/v3/internal/http2" "github.com/imroc/req/v3/internal/testcert" "github.com/imroc/req/v3/internal/tests" - reqtls "github.com/imroc/req/v3/internal/tls" + reqtls "github.com/imroc/req/v3/pkg/tls" "go/token" "golang.org/x/net/http/httpproxy" nethttp2 "golang.org/x/net/http2" diff --git a/transport_wrapper.go b/transport_wrapper.go index c67c7e38..285a3a36 100644 --- a/transport_wrapper.go +++ b/transport_wrapper.go @@ -4,7 +4,7 @@ import ( "context" "crypto/tls" "github.com/imroc/req/v3/internal/dump" - reqtls "github.com/imroc/req/v3/internal/tls" + reqtls "github.com/imroc/req/v3/pkg/tls" "net" "net/http" "net/url" From 5dcc4145f86d52f62881479a4caebf5917f53539 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 30 Jun 2022 19:33:13 +0800 Subject: [PATCH 510/843] Remove TLSNextProto in Transport --- internal/transport/transport.go | 5 -- transport.go | 19 ----- transport_internal_test.go | 77 ------------------- transport_test.go | 126 ++------------------------------ transport_wrapper.go | 9 --- 5 files changed, 5 insertions(+), 231 deletions(-) diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 91ca9555..b6257490 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "github.com/imroc/req/v3/internal/dump" - reqtls "github.com/imroc/req/v3/pkg/tls" "net" "net/http" "net/url" @@ -91,8 +90,4 @@ type Interface interface { // when reading from the transport. // If zero, a default (currently 4KB) is used. ReadBufferSize() int - - TLSNextProto() map[string]func(authority string, c reqtls.Conn) http.RoundTripper - - SetTLSNextProto(map[string]func(authority string, c reqtls.Conn) http.RoundTripper) } diff --git a/transport.go b/transport.go index a99bbe3b..7d25aab9 100644 --- a/transport.go +++ b/transport.go @@ -225,18 +225,6 @@ type Transport struct { // This time does not include the time to send the request header. ExpectContinueTimeout time.Duration - // TLSNextProto specifies how the Transport switches to an - // alternate protocol (such as HTTP/2) after a TLS ALPN - // protocol negotiation. If Transport dials an TLS connection - // with a non-empty protocol name and TLSNextProto contains a - // map entry for that key (such as "h2"), then the func is - // called with the request's authority (such as "example.com" - // or "example.com:1234") and the TLS connection. The function - // must return a http.RoundTripper that then handles the request. - // If TLSNextProto is not nil, HTTP/2 support is not enabled - // automatically. - TLSNextProto map[string]func(authority string, c reqtls.Conn) http.RoundTripper - // ProxyConnectHeader optionally specifies headers to send to // proxies during CONNECT requests. // To set the header dynamically, see GetProxyConnectHeader. @@ -491,13 +479,6 @@ func (t *Transport) Clone() *Transport { if t.TLSClientConfig != nil { t2.TLSClientConfig = t.TLSClientConfig.Clone() } - if t.TLSNextProto != nil { - npm := map[string]func(authority string, c reqtls.Conn) http.RoundTripper{} - for k, v := range t.TLSNextProto { - npm[k] = v - } - t2.TLSNextProto = npm - } return t2 } diff --git a/transport_internal_test.go b/transport_internal_test.go index 688f9eb3..91bea4cd 100644 --- a/transport_internal_test.go +++ b/transport_internal_test.go @@ -7,15 +7,10 @@ package req import ( - "bytes" "context" - "crypto/tls" "errors" "github.com/imroc/req/v3/internal/http2" - "github.com/imroc/req/v3/internal/testcert" "github.com/imroc/req/v3/internal/tests" - reqtls "github.com/imroc/req/v3/pkg/tls" - "io" "net" "net/http" "strings" @@ -191,75 +186,3 @@ type roundTripFunc func(r *http.Request) (*http.Response, error) func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } - -// Issue 25009 -func TestTransportBodyAltRewind(t *testing.T) { - cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) - if err != nil { - t.Fatal(err) - } - ln := tests.NewLocalListener(t) - defer ln.Close() - - go func() { - tln := tls.NewListener(ln, &tls.Config{ - NextProtos: []string{"foo"}, - Certificates: []tls.Certificate{cert}, - }) - for i := 0; i < 2; i++ { - sc, err := tln.Accept() - if err != nil { - t.Error(err) - return - } - if err := sc.(reqtls.Conn).Handshake(); err != nil { - t.Error(err) - return - } - sc.Close() - } - }() - - addr := ln.Addr().String() - req, _ := http.NewRequest("POST", "https://example.org/", bytes.NewBufferString("request")) - roundTripped := false - tr := &Transport{ - DisableKeepAlives: true, - TLSNextProto: map[string]func(string, reqtls.Conn) http.RoundTripper{ - "foo": func(authority string, c reqtls.Conn) http.RoundTripper { - return roundTripFunc(func(r *http.Request) (*http.Response, error) { - n, _ := io.Copy(io.Discard, r.Body) - if n == 0 { - t.Error("body length is zero") - } - if roundTripped { - return &http.Response{ - Body: NoBody, - StatusCode: 200, - }, nil - } - roundTripped = true - return nil, http2.ErrNoCachedConn - }) - }, - }, - DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) { - tc, err := tls.Dial("tcp", addr, &tls.Config{ - InsecureSkipVerify: true, - NextProtos: []string{"foo"}, - }) - if err != nil { - return nil, err - } - if err := tc.Handshake(); err != nil { - return nil, err - } - return tc, nil - }, - } - c := &http.Client{Transport: tr} - _, err = c.Do(req) - if err != nil { - t.Error(err) - } -} diff --git a/transport_test.go b/transport_test.go index 578e244c..3c9ec311 100644 --- a/transport_test.go +++ b/transport_test.go @@ -22,9 +22,7 @@ import ( "fmt" "github.com/imroc/req/v3/internal/common" "github.com/imroc/req/v3/internal/http2" - "github.com/imroc/req/v3/internal/testcert" "github.com/imroc/req/v3/internal/tests" - reqtls "github.com/imroc/req/v3/pkg/tls" "go/token" "golang.org/x/net/http/httpproxy" nethttp2 "golang.org/x/net/http2" @@ -4214,27 +4212,6 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) { } } -func TestTransportAutomaticHTTP2(t *testing.T) { - tr := tc().t - testTransportAutoHTTP(t, tr, true) -} - -func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) { - tr := tc().t - tr.TLSNextProto = make(map[string]func(string, reqtls.Conn) http.RoundTripper) - testTransportAutoHTTP(t, tr, false) -} - -func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) { - _, err := tr.RoundTrip(new(http.Request)) - if err == nil { - t.Error("expected error from RoundTrip") - } - if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 { - t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2) - } -} - // Issue 13633: there was a race where we returned bodyless responses // to callers before recycling the persistent connection, which meant // a client doing two subsequent requests could end up on different @@ -4269,89 +4246,6 @@ func TestTransportReuseConnEmptyResponseBody(t *testing.T) { } } -// Issue 13839 -func TestNoCrashReturningTransportAltConn(t *testing.T) { - cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) - if err != nil { - t.Fatal(err) - } - ln := tests.NewLocalListener(t) - defer ln.Close() - - var wg sync.WaitGroup - SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) - defer SetPendingDialHooks(nil, nil) - - testDone := make(chan struct{}) - defer close(testDone) - go func() { - tln := tls.NewListener(ln, &tls.Config{ - NextProtos: []string{"foo"}, - Certificates: []tls.Certificate{cert}, - }) - sc, err := tln.Accept() - if err != nil { - t.Error(err) - return - } - if err := sc.(*tls.Conn).Handshake(); err != nil { - t.Error(err) - return - } - <-testDone - sc.Close() - }() - - addr := ln.Addr().String() - - req, _ := http.NewRequest("GET", "https://fake.tld/", nil) - cancel := make(chan struct{}) - req.Cancel = cancel - - doReturned := make(chan bool, 1) - madeRoundTripper := make(chan bool, 1) - - tr := &Transport{ - DisableKeepAlives: true, - TLSNextProto: map[string]func(string, reqtls.Conn) http.RoundTripper{ - "foo": func(authority string, c reqtls.Conn) http.RoundTripper { - madeRoundTripper <- true - return funcRoundTripper(func() { - t.Error("foo http.RoundTripper should not be called") - }) - }, - }, - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - panic("shouldn't be called") - }, - DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) { - tc, err := tls.Dial("tcp", addr, &tls.Config{ - InsecureSkipVerify: true, - NextProtos: []string{"foo"}, - }) - if err != nil { - return nil, err - } - if err := tc.Handshake(); err != nil { - return nil, err - } - close(cancel) - <-doReturned - return tc, nil - }, - } - c := &http.Client{Transport: tr} - - _, err = c.Do(req) - if ue, ok := err.(*url.Error); !ok || ue.Err != errRequestCanceledConn { - t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err) - } - - doReturned <- true - <-madeRoundTripper - wg.Wait() -} - func TestTransportReuseConnectionGzipChunked(t *testing.T) { testTransportReuseConnectionGzip(t, true) } @@ -5600,14 +5494,11 @@ func TestTransportClone(t *testing.T) { GetProxyConnectHeader: func(context.Context, *url.URL, string) (http.Header, error) { return nil, nil }, MaxResponseHeaderBytes: 1, ForceAttemptHTTP2: true, - TLSNextProto: map[string]func(authority string, c reqtls.Conn) http.RoundTripper{ - "foo": func(authority string, c reqtls.Conn) http.RoundTripper { panic("") }, - }, - ReadBufferSize: 1, - WriteBufferSize: 1, - ForceHttpVersion: HTTP1, - ResponseOptions: &ResponseOptions{}, - Debugf: func(format string, v ...interface{}) {}, + ReadBufferSize: 1, + WriteBufferSize: 1, + ForceHttpVersion: HTTP1, + ResponseOptions: &ResponseOptions{}, + Debugf: func(format string, v ...interface{}) {}, } tr2 := tr.Clone() rv := reflect.ValueOf(tr2).Elem() @@ -5622,16 +5513,9 @@ func TestTransportClone(t *testing.T) { } } - if _, ok := tr2.TLSNextProto["foo"]; !ok { - t.Errorf("cloned Transport lacked TLSNextProto 'foo' key") - } - // But test that a nil TLSNextProto is kept nil: tr = new(Transport) tr2 = tr.Clone() - if tr2.TLSNextProto != nil { - t.Errorf("Transport.TLSNextProto unexpected non-nil") - } } func TestIs408(t *testing.T) { diff --git a/transport_wrapper.go b/transport_wrapper.go index 285a3a36..561d3aef 100644 --- a/transport_wrapper.go +++ b/transport_wrapper.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "github.com/imroc/req/v3/internal/dump" - reqtls "github.com/imroc/req/v3/pkg/tls" "net" "net/http" "net/url" @@ -102,11 +101,3 @@ func (t transportImpl) WriteBufferSize() int { func (t transportImpl) ReadBufferSize() int { return t.t.ReadBufferSize } - -func (t transportImpl) TLSNextProto() map[string]func(authority string, c reqtls.Conn) http.RoundTripper { - return t.t.TLSNextProto -} - -func (t transportImpl) SetTLSNextProto(m map[string]func(authority string, c reqtls.Conn) http.RoundTripper) { - t.t.TLSNextProto = m -} From c1755f3662ec1e78ec860a9af59056dbc1e402c5 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 30 Jun 2022 20:11:14 +0800 Subject: [PATCH 511/843] Only allow h3 in alt-svc --- transport.go | 1 - 1 file changed, 1 deletion(-) diff --git a/transport.go b/transport.go index 7d25aab9..6d781563 100644 --- a/transport.go +++ b/transport.go @@ -311,7 +311,6 @@ func (t *Transport) handleResponseBody(res *http.Response, req *http.Request) { var allowedProtocols = map[string]bool{ "h3": true, - "h2": true, } func (t *Transport) handleAltSvc(req *http.Request, value string) { From 422b2808aa1c0a8aa11c5b1f0ba7c42d037fb16b Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 30 Jun 2022 20:36:34 +0800 Subject: [PATCH 512/843] Some fixes 1. Fix tests relies on create transport. 2. Only try http2 or http3 when sheme is https. 3. Only try http3 if t3 is not nil. 4. Expose Client.GetTransport. 5. Remove some useless tests. --- client.go | 5 +++ internal/tests/assert.go | 2 +- transport.go | 18 +++++---- transport_test.go | 79 +--------------------------------------- 4 files changed, 17 insertions(+), 87 deletions(-) diff --git a/client.go b/client.go index 0f1c955e..76344afa 100644 --- a/client.go +++ b/client.go @@ -76,6 +76,11 @@ func (c *Client) R() *Request { } } +// GetTransport return the underlying transport. +func (c *Client) GetTransport() *Transport { + return c.t +} + // SetCommonError set the common result that response body will be unmarshalled to // if it is an error response ( status `code >= 400`). func (c *Client) SetCommonError(err interface{}) *Client { diff --git a/internal/tests/assert.go b/internal/tests/assert.go index c4bdbbe9..29adad29 100644 --- a/internal/tests/assert.go +++ b/internal/tests/assert.go @@ -58,7 +58,7 @@ func AssertContains(t *testing.T, s, substr string, shouldContain bool) { } } else { if isContain { - t.Errorf("%q is included in %s", substr, s) + t.Errorf("%q is included in %q", substr, s) } } } diff --git a/transport.go b/transport.go index 6d781563..29c2e9b8 100644 --- a/transport.go +++ b/transport.go @@ -612,7 +612,7 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error cancelKey := cancelKey{origReq} req = setupRewindBody(req) - if t.ForceHttpVersion != HTTP1 { + if scheme == "https" && t.ForceHttpVersion != HTTP1 { resp, err := t.t2.RoundTripOnlyCachedConn(req) if err != http2.ErrNoCachedConn { return resp, err @@ -621,13 +621,15 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error if err != nil { return nil, err } - resp, err = t.t3.RoundTripOnlyCachedConn(req) - if err != http3.ErrNoCachedConn { - return resp, err - } - req, err = rewindBody(req) - if err != nil { - return nil, err + if t.t3 != nil { + resp, err = t.t3.RoundTripOnlyCachedConn(req) + if err != http3.ErrNoCachedConn { + return resp, err + } + req, err = rewindBody(req) + if err != nil { + return nil, err + } } } diff --git a/transport_test.go b/transport_test.go index 3c9ec311..cf8c88dc 100644 --- a/transport_test.go +++ b/transport_test.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" "github.com/imroc/req/v3/internal/common" - "github.com/imroc/req/v3/internal/http2" "github.com/imroc/req/v3/internal/tests" "go/token" "golang.org/x/net/http/httpproxy" @@ -2799,25 +2798,6 @@ func (fooProto) RoundTrip(req *http.Request) (*http.Response, error) { return res, nil } -func TestTransportAltProto(t *testing.T) { - defer afterTest(t) - tr := &Transport{} - c := &http.Client{Transport: tr} - tr.RegisterProtocol("foo", fooProto{}) - res, err := c.Get("foo://bar.com/path") - if err != nil { - t.Fatal(err) - } - bodyb, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - body := string(bodyb) - if e := "You wanted foo://bar.com/path"; body != e { - t.Errorf("got response %q, want %q", body, e) - } -} - func TestTransportNoHost(t *testing.T) { defer afterTest(t) tr := &Transport{} @@ -3112,7 +3092,7 @@ func newClientServerTest(t *testing.T, h2 bool, h http.Handler, opts ...interfac t: t, h2: h2, h: h, - tr: &Transport{}, + tr: C().GetTransport(), } cst.c = &http.Client{Transport: cst.tr} cst.ts = httptest.NewUnstartedServer(h) @@ -3139,9 +3119,6 @@ func newClientServerTest(t *testing.T, h2 bool, h http.Handler, opts ...interfac cst.tr.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, } - if _, err := http2.ConfigureTransports(transportImpl{cst.tr}); err != nil { - t.Fatal(err) - } return cst } @@ -5848,60 +5825,6 @@ func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { } } -// Issue 36820 -// Test that we use the older backward compatible cancellation protocol -// when a http.RoundTripper is registered via RegisterProtocol. -func TestAltProtoCancellation(t *testing.T) { - defer afterTest(t) - tr := &Transport{} - c := &http.Client{ - Transport: tr, - Timeout: time.Millisecond, - } - tr.RegisterProtocol("timeout", timeoutProto{}) - _, err := c.Get("timeout://bar.com/path") - if err == nil { - t.Error("request unexpectedly succeeded") - } else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) { - t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr) - } -} - -var timeoutProtoErr = errors.New("canceled as expected") - -type timeoutProto struct{} - -func (timeoutProto) RoundTrip(req *http.Request) (*http.Response, error) { - select { - case <-req.Cancel: - return nil, timeoutProtoErr - case <-time.After(5 * time.Second): - return nil, errors.New("request was not canceled") - } -} - -// Issue 32441: body is not reset after ErrSkipAltProtocol -func TestIssue32441(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if n, _ := io.Copy(io.Discard, r.Body); n == 0 { - t.Error("body length is zero") - } - })) - defer ts.Close() - c := tc().httpClient - c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *http.Request) (*http.Response, error) { - // Draining body to trigger failure condition on actual request to server. - if n, _ := io.Copy(io.Discard, r.Body); n == 0 { - t.Error("body length is zero during round trip") - } - return nil, http.ErrSkipAltProtocol - })) - if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil { - t.Error(err) - } -} - // Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers // that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. func TestTransportRejectsSignInContentLength(t *testing.T) { From f9bc9b2a35f691420fcb2a689c11dc5474d9b8d9 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 2 Jul 2022 08:40:30 +0800 Subject: [PATCH 513/843] some refactor --- client.go | 44 +---- internal/http2/frame.go | 2 +- internal/http2/transport.go | 26 ++- internal/http3/roundtrip.go | 8 +- internal/transport/option.go | 144 ++++++++++++++ internal/transport/transport.go | 93 ---------- transport.go | 320 +++++++++++++++----------------- transport_test.go | 1 - transport_wrapper.go | 103 ---------- 9 files changed, 318 insertions(+), 423 deletions(-) create mode 100644 internal/transport/option.go delete mode 100644 internal/transport/transport.go delete mode 100644 transport_wrapper.go diff --git a/client.go b/client.go index 76344afa..50e98b5e 100644 --- a/client.go +++ b/client.go @@ -19,8 +19,6 @@ import ( urlpkg "net/url" "os" "reflect" - "runtime" - "strconv" "strings" "time" ) @@ -425,7 +423,7 @@ func (c *Client) getDumpOptions() *DumpOptions { // EnableDumpAll enable dump for all requests, including // all content for the request and response by default. func (c *Client) EnableDumpAll() *Client { - if c.t.dump != nil { // dump already started + if c.t.Dump != nil { // dump already started return c } c.t.EnableDump(c.getDumpOptions()) @@ -658,8 +656,8 @@ func (c *Client) SetCommonDumpOptions(opt *DumpOptions) *Client { } } c.dumpOptions = opt - if c.t.dump != nil { - c.t.dump.SetOptions(dumpOptions{opt}) + if c.t.Dump != nil { + c.t.Dump.SetOptions(dumpOptions{opt}) } return c } @@ -881,23 +879,8 @@ func (c *Client) SetUnixSocket(file string) *Client { }) } -func (c *Client) EnableHttp3() *Client { - v := runtime.Version() - ss := strings.Split(v, ".") - if len(ss) < 2 || ss[0] != "go1" { - c.log.Warnf("bad go version format: %s", v) - return c - } - minorVersion, err := strconv.Atoi(ss[1]) - if err != nil { - c.log.Warnf("bad go minor version: %s", v) - return c - } - if minorVersion >= 16 && minorVersion <= 18 { - c.t.enableH3() - } else { - c.log.Warnf("%s is not support http3", v) - } +func (c *Client) EnableHTTP3() *Client { + c.t.EnableHTTP3() return c } @@ -910,7 +893,7 @@ func NewClient() *Client { func (c *Client) Clone() *Client { t := c.t.Clone() t2 := &http2.Transport{ - Interface: transportImpl{t}, + Options: &t.Options, } t.t2 = t2 @@ -943,20 +926,7 @@ func (c *Client) Clone() *Client { // C create a new client. func C() *Client { - t := &Transport{ - ResponseOptions: &ResponseOptions{}, - ForceAttemptHTTP2: true, - Proxy: http.ProxyFromEnvironment, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: &tls.Config{NextProtos: []string{"http/1.1", "h2"}}, - } - t2 := &http2.Transport{ - Interface: transportImpl{t}, - } - t.t2 = t2 + t := T() jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) httpClient := &http.Client{ diff --git a/internal/http2/frame.go b/internal/http2/frame.go index e91df29c..3ed6a0ac 100644 --- a/internal/http2/frame.go +++ b/internal/http2/frame.go @@ -526,7 +526,7 @@ func (h2f *Framer) ReadFrame() (Frame, error) { if fh.Type == FrameHeaders && h2f.ReadMetaHeaders != nil { var dumps []*dump.Dumper if h2f.cc != nil { - dumps = dump.GetDumpers(h2f.cc.currentRequest.Context(), h2f.cc.t.Dump()) + dumps = dump.GetDumpers(h2f.cc.currentRequest.Context(), h2f.cc.t.Dump) } if len(dumps) > 0 { dd := []*dump.Dumper{} diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 0444ab40..25a8895d 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -72,6 +72,7 @@ const ( // A Transport internally caches connections to servers. It is safe // for concurrent use by multiple goroutines. type Transport struct { + *transport.Options // DialTLS specifies an optional dial function for creating // TLS connections for requests. // @@ -132,11 +133,6 @@ type Transport struct { // The errType consists of only ASCII word characters. CountError func(errType string) - // t1, if non-nil, is the standard library Transport using - // this transport. Its settings are used (but not its - // RoundTrip method, etc). - transport.Interface - connPoolOnce sync.Once connPoolOrDef *clientConnPool // non-nil version of ConnPool } @@ -589,7 +585,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b func (t *Transport) newTLSConfig(host string) *tls.Config { cfg := new(tls.Config) - if c := t.TLSClientConfig(); c != nil { + if c := t.TLSClientConfig; c != nil { *cfg = *c.Clone() } if !strSliceContains(cfg.NextProtos, NextProtoTLS) { @@ -622,7 +618,7 @@ func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Confi } func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { - return t.newClientConn(c, t.DisableKeepAlives()) + return t.newClientConn(c, t.DisableKeepAlives) } func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) { @@ -641,7 +637,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), } - if d := t.IdleConnTimeout(); d != 0 { + if d := t.IdleConnTimeout; d != 0 { cc.idleTimeout = d cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) } @@ -1017,7 +1013,7 @@ func commaSeparatedTrailers(req *http.Request) (string, error) { } func (cc *ClientConn) responseHeaderTimeout() time.Duration { - return cc.t.ResponseHeaderTimeout() + return cc.t.ResponseHeaderTimeout } // checkConnHeaders checks whether req has any invalid connection-level headers. @@ -1062,8 +1058,8 @@ func (cc *ClientConn) decrStreamReservationsLocked() { } func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { - if cc.t != nil && cc.t.Debugf() != nil { - cc.t.Debugf()("HTTP/2 %s %s", req.Method, req.URL.String()) + if cc.t != nil && cc.t.Debugf != nil { + cc.t.Debugf("HTTP/2 %s %s", req.Method, req.URL.String()) } cc.currentRequest = req ctx := req.Context() @@ -1201,7 +1197,7 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { cc.mu.Unlock() // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? - if !cc.t.DisableCompression() && + if !cc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && !cs.isHead { @@ -1220,7 +1216,7 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { cs.requestedGzip = true } - continueTimeout := cc.t.ExpectContinueTimeout() + continueTimeout := cc.t.ExpectContinueTimeout if continueTimeout != 0 { if !httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") { continueTimeout = 0 @@ -1231,7 +1227,7 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { var dumps []*dump.Dumper if t := cs.cc.t; t != nil { - dumps = dump.GetDumpers(req.Context(), t.Dump()) + dumps = dump.GetDumpers(req.Context(), t.Dump) } // Past this point (where we send request headers), it is possible for @@ -1997,7 +1993,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) { // wake up RoundTrip if there is a pending request. cc.cond.Broadcast() - closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.DisableKeepAlives() + closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.DisableKeepAlives if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { if VerboseLogs { cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index d668708a..ba510c6e 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -24,7 +24,7 @@ type roundTripCloser interface { // RoundTripper implements the http.RoundTripper interface type RoundTripper struct { - transport.Interface + *transport.Options mutex sync.Mutex // DisableCompression, if true, prevents the Transport from @@ -133,7 +133,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. hostname := authorityAddr("https", hostnameFromRequest(req)) cl, err := r.getClient(hostname, opt.OnlyCachedConn) if err != ErrNoCachedConn { - if debugf := r.Debugf(); debugf != nil { + if debugf := r.Debugf; debugf != nil { debugf("HTTP/3 %s %s", req.Method, req.URL.String()) } } @@ -189,14 +189,14 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo var err error client, err = newClient( hostname, - r.TLSClientConfig(), + r.TLSClientConfig, &roundTripperOpts{ EnableDatagram: r.EnableDatagrams, DisableCompression: r.DisableCompression, MaxHeaderBytes: r.MaxResponseHeaderBytes, StreamHijacker: r.StreamHijacker, UniStreamHijacker: r.UniStreamHijacker, - dump: r.Interface.Dump(), + dump: r.Dump, }, r.QuicConfig, r.Dial, diff --git a/internal/transport/option.go b/internal/transport/option.go new file mode 100644 index 00000000..9faaeccb --- /dev/null +++ b/internal/transport/option.go @@ -0,0 +1,144 @@ +package transport + +import ( + "context" + "crypto/tls" + "github.com/imroc/req/v3/internal/dump" + "net" + "net/http" + "net/url" + "time" +) + +type Options struct { + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // + // The proxy type is determined by the URL scheme. "http", + // "https", and "socks5" are supported. If the scheme is empty, + // "http" is assumed. + // + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + + // DialContext specifies the dial function for creating unencrypted TCP connections. + // If DialContext is nil, then the transport dials using package net. + // + // DialContext runs concurrently with calls to RoundTrip. + // A RoundTrip call that initiates a dial may end up using + // a connection dialed previously when the earlier connection + // becomes idle before the later DialContext completes. + DialContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // DialTLSContext specifies an optional dial function for creating + // TLS connections for non-proxied HTTPS requests. + // + // If DialTLSContext is nil, DialContext and TLSClientConfig are used. + // + // If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS + // requests and the TLSClientConfig and TLSHandshakeTimeout + // are ignored. The returned net.Conn is assumed to already be + // past the TLS handshake. + DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. + // If nil, the default configuration is used. + // If non-nil, HTTP/2 support may not be enabled by default. + TLSClientConfig *tls.Config + + // TLSHandshakeTimeout specifies the maximum amount of time waiting to + // wait for a TLS handshake. Zero means no timeout. + TLSHandshakeTimeout time.Duration + + // DisableKeepAlives, if true, disables HTTP keep-alives and + // will only use the connection to the server for a single + // HTTP request. + // + // This is unrelated to the similarly named TCP keep-alives. + DisableKeepAlives bool + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool + + // MaxIdleConns controls the maximum number of idle (keep-alive) + // connections across all hosts. Zero means no limit. + MaxIdleConns int + + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle + // (keep-alive) connections to keep per-host. If zero, + // defaultMaxIdleConnsPerHost is used. + MaxIdleConnsPerHost int + + // MaxConnsPerHost optionally limits the total number of + // connections per host, including connections in the dialing, + // active, and idle states. On limit violation, dials will block. + // + // Zero means no limit. + MaxConnsPerHost int + + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout time.Duration + + // ResponseHeaderTimeout, if non-zero, specifies the amount of + // time to wait for a server's response headers after fully + // writing the request (including its body, if any). This + // time does not include the time to read the response body. + ResponseHeaderTimeout time.Duration + + // ExpectContinueTimeout, if non-zero, specifies the amount of + // time to wait for a server's first response headers after fully + // writing the request headers if the request has an + // "Expect: 100-continue" header. Zero means no timeout and + // causes the body to be sent immediately, without + // waiting for the server to approve. + // This time does not include the time to send the request header. + ExpectContinueTimeout time.Duration + + // ProxyConnectHeader optionally specifies headers to send to + // proxies during CONNECT requests. + // To set the header dynamically, see GetProxyConnectHeader. + ProxyConnectHeader http.Header + + // GetProxyConnectHeader optionally specifies a func to return + // headers to send to proxyURL during a CONNECT request to the + // ip:port target. + // If it returns an error, the Transport's RoundTrip fails with + // that error. It can return (nil, nil) to not add headers. + // If GetProxyConnectHeader is non-nil, ProxyConnectHeader is + // ignored. + GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) + + // MaxResponseHeaderBytes specifies a limit on how many + // response bytes are allowed in the server's response + // header. + // + // Zero means to use a default limit. + MaxResponseHeaderBytes int64 + + // WriteBufferSize specifies the size of the write buffer used + // when writing to the transport. + // If zero, a default (currently 4KB) is used. + WriteBufferSize int + + // ReadBufferSize specifies the size of the read buffer used + // when reading from the transport. + // If zero, a default (currently 4KB) is used. + ReadBufferSize int + + // Debugf is the optional debug function. + Debugf func(format string, v ...interface{}) + + Dump *dump.Dumper +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go deleted file mode 100644 index b6257490..00000000 --- a/internal/transport/transport.go +++ /dev/null @@ -1,93 +0,0 @@ -package transport - -import ( - "context" - "crypto/tls" - "github.com/imroc/req/v3/internal/dump" - "net" - "net/http" - "net/url" - "time" -) - -type Interface interface { - Proxy() func(*http.Request) (*url.URL, error) - Debugf() func(format string, v ...interface{}) - SetDebugf(func(format string, v ...interface{})) - DisableCompression() bool - TLSClientConfig() *tls.Config - SetTLSClientConfig(c *tls.Config) - TLSHandshakeTimeout() time.Duration - DialContext() func(ctx context.Context, network, addr string) (net.Conn, error) - DialTLSContext() func(ctx context.Context, network, addr string) (net.Conn, error) - DisableKeepAlives() bool - Dump() *dump.Dumper - - // MaxIdleConns controls the maximum number of idle (keep-alive) - // connections across all hosts. Zero means no limit. - MaxIdleConns() int - - // MaxIdleConnsPerHost, if non-zero, controls the maximum idle - // (keep-alive) connections to keep per-host. If zero, - // defaultMaxIdleConnsPerHost is used. - MaxIdleConnsPerHost() int - - // MaxConnsPerHost optionally limits the total number of - // connections per host, including connections in the dialing, - // active, and idle states. On limit violation, dials will block. - // - // Zero means no limit. - MaxConnsPerHost() int - - // IdleConnTimeout is the maximum amount of time an idle - // (keep-alive) connection will remain idle before closing - // itself. - // Zero means no limit. - IdleConnTimeout() time.Duration - - // ResponseHeaderTimeout, if non-zero, specifies the amount of - // time to wait for a server's response headers after fully - // writing the request (including its body, if any). This - // time does not include the time to read the response body. - ResponseHeaderTimeout() time.Duration - - // ExpectContinueTimeout, if non-zero, specifies the amount of - // time to wait for a server's first response headers after fully - // writing the request headers if the request has an - // "Expect: 100-continue" header. Zero means no timeout and - // causes the body to be sent immediately, without - // waiting for the server to approve. - // This time does not include the time to send the request header. - ExpectContinueTimeout() time.Duration - - // ProxyConnectHeader optionally specifies headers to send to - // proxies during CONNECT requests. - // To set the header dynamically, see GetProxyConnectHeader. - ProxyConnectHeader() http.Header - - // GetProxyConnectHeader optionally specifies a func to return - // headers to send to proxyURL during a CONNECT request to the - // ip:port target. - // If it returns an error, the Transport's RoundTrip fails with - // that error. It can return (nil, nil) to not add headers. - // If GetProxyConnectHeader is non-nil, ProxyConnectHeader is - // ignored. - GetProxyConnectHeader() func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) - - // MaxResponseHeaderBytes specifies a limit on how many - // response bytes are allowed in the server's response - // header. - // - // Zero means to use a default limit. - MaxResponseHeaderBytes() int64 - - // WriteBufferSize specifies the size of the write buffer used - // when writing to the transport. - // If zero, a default (currently 4KB) is used. - WriteBufferSize() int - - // ReadBufferSize specifies the size of the read buffer used - // when reading from the transport. - // If zero, a default (currently 4KB) is used. - ReadBufferSize() int -} diff --git a/transport.go b/transport.go index 29c2e9b8..51d7b84c 100644 --- a/transport.go +++ b/transport.go @@ -26,6 +26,7 @@ import ( "github.com/imroc/req/v3/internal/http3" "github.com/imroc/req/v3/internal/netutil" "github.com/imroc/req/v3/internal/socks" + "github.com/imroc/req/v3/internal/transport" "github.com/imroc/req/v3/internal/util" "github.com/imroc/req/v3/pkg/altsvc" reqtls "github.com/imroc/req/v3/pkg/tls" @@ -39,6 +40,7 @@ import ( "net/http/httptrace" "net/textproto" "net/url" + "runtime" "strconv" "strings" "sync" @@ -130,148 +132,124 @@ type Transport struct { // Force using specific http version ForceHttpVersion HttpVersion - // Proxy specifies a function to return a proxy for a given - // Request. If the function returns a non-nil error, the - // request is aborted with the provided error. - // - // The proxy type is determined by the URL scheme. "http", - // "https", and "socks5" are supported. If the scheme is empty, - // "http" is assumed. - // - // If Proxy is nil or returns a nil *URL, no proxy is used. - Proxy func(*http.Request) (*url.URL, error) + transport.Options - // DialContext specifies the dial function for creating unencrypted TCP connections. - // If DialContext is nil, then the transport dials using package net. - // - // DialContext runs concurrently with calls to RoundTrip. - // A RoundTrip call that initiates a dial may end up using - // a connection dialed previously when the earlier connection - // becomes idle before the later DialContext completes. - DialContext func(ctx context.Context, network, addr string) (net.Conn, error) - - // DialTLSContext specifies an optional dial function for creating - // TLS connections for non-proxied HTTPS requests. - // - // If DialTLSContext is nil, DialContext and TLSClientConfig are used. - // - // If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS - // requests and the TLSClientConfig and TLSHandshakeTimeout - // are ignored. The returned net.Conn is assumed to already be - // past the TLS handshake. - DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) - - // TLSClientConfig specifies the TLS configuration to use with - // tls.Client. - // If nil, the default configuration is used. - // If non-nil, HTTP/2 support may not be enabled by default. - TLSClientConfig *tls.Config - - // TLSHandshakeTimeout specifies the maximum amount of time waiting to - // wait for a TLS handshake. Zero means no timeout. - TLSHandshakeTimeout time.Duration - - // DisableKeepAlives, if true, disables HTTP keep-alives and - // will only use the connection to the server for a single - // HTTP request. - // - // This is unrelated to the similarly named TCP keep-alives. - DisableKeepAlives bool - - // DisableCompression, if true, prevents the Transport from - // requesting compression with an "Accept-Encoding: gzip" - // request header when the Request contains no existing - // Accept-Encoding value. If the Transport requests gzip on - // its own and gets a gzipped response, it's transparently - // decoded in the Response.Body. However, if the user - // explicitly requested gzip it is not automatically - // uncompressed. - DisableCompression bool - - // MaxIdleConns controls the maximum number of idle (keep-alive) - // connections across all hosts. Zero means no limit. - MaxIdleConns int - - // MaxIdleConnsPerHost, if non-zero, controls the maximum idle - // (keep-alive) connections to keep per-host. If zero, - // defaultMaxIdleConnsPerHost is used. - MaxIdleConnsPerHost int - - // MaxConnsPerHost optionally limits the total number of - // connections per host, including connections in the dialing, - // active, and idle states. On limit violation, dials will block. - // - // Zero means no limit. - MaxConnsPerHost int - - // IdleConnTimeout is the maximum amount of time an idle - // (keep-alive) connection will remain idle before closing - // itself. - // Zero means no limit. - IdleConnTimeout time.Duration - - // ResponseHeaderTimeout, if non-zero, specifies the amount of - // time to wait for a server's response headers after fully - // writing the request (including its body, if any). This - // time does not include the time to read the response body. - ResponseHeaderTimeout time.Duration - - // ExpectContinueTimeout, if non-zero, specifies the amount of - // time to wait for a server's first response headers after fully - // writing the request headers if the request has an - // "Expect: 100-continue" header. Zero means no timeout and - // causes the body to be sent immediately, without - // waiting for the server to approve. - // This time does not include the time to send the request header. - ExpectContinueTimeout time.Duration - - // ProxyConnectHeader optionally specifies headers to send to - // proxies during CONNECT requests. - // To set the header dynamically, see GetProxyConnectHeader. - ProxyConnectHeader http.Header - - // GetProxyConnectHeader optionally specifies a func to return - // headers to send to proxyURL during a CONNECT request to the - // ip:port target. - // If it returns an error, the Transport's RoundTrip fails with - // that error. It can return (nil, nil) to not add headers. - // If GetProxyConnectHeader is non-nil, ProxyConnectHeader is - // ignored. - GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) - - // MaxResponseHeaderBytes specifies a limit on how many - // response bytes are allowed in the server's response - // header. - // - // Zero means to use a default limit. - MaxResponseHeaderBytes int64 + t2 *http2.Transport // non-nil if http2 wired up + t3 *http3.RoundTripper - // WriteBufferSize specifies the size of the write buffer used - // when writing to the transport. - // If zero, a default (currently 4KB) is used. - WriteBufferSize int + *ResponseOptions - // ReadBufferSize specifies the size of the read buffer used - // when reading from the transport. - // If zero, a default (currently 4KB) is used. - ReadBufferSize int + // DisableAutoDecode, if true, prevents auto detect response + // body's charset and decode it to utf-8 + DisableAutoDecode bool - t2 *http2.Transport // non-nil if http2 wired up - t3 *http3.RoundTripper + // AutoDecodeContentType specifies an optional function for determine + // whether the response body should been auto decode to utf-8. + // Only valid when DisableAutoDecode is true. + AutoDecodeContentType func(contentType string) bool +} - // ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero - // Dial, DialTLS, or DialContext func or TLSClientConfig is provided. - // By default, use of any those fields conservatively disables HTTP/2. - // To use a custom dialer or TLS config and still attempt HTTP/2 - // upgrades, set this to true. - ForceAttemptHTTP2 bool +func NewTransport() *Transport { + return T() +} - *ResponseOptions +func T() *Transport { + t := &Transport{ + Options: transport.Options{ + Proxy: http.ProxyFromEnvironment, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{NextProtos: []string{"http/1.1", "h2"}}, + }, + } + t.t2 = &http2.Transport{Options: &t.Options} + return t +} + +func (t *Transport) GetMaxIdleConns() int { + return t.MaxIdleConns +} + +func (t *Transport) SetMaxIdleConns(max int) *Transport { + t.MaxIdleConns = max + return t +} - dump *dump.Dumper +func (t *Transport) SetMaxConnsPerHost(max int) *Transport { + t.MaxConnsPerHost = max + return t +} - // Debugf is the optional debug function. - Debugf func(format string, v ...interface{}) +func (t *Transport) SetIdleConnTimeout(timeout time.Duration) *Transport { + t.IdleConnTimeout = timeout + return t +} + +func (t *Transport) SetResponseHeaderTimeout(timeout time.Duration) *Transport { + t.ResponseHeaderTimeout = timeout + return t +} + +func (t *Transport) SetExpectContinueTimeout(timeout time.Duration) *Transport { + t.ExpectContinueTimeout = timeout + return t +} + +func (t *Transport) SetGetProxyConnectHeader(fn func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error)) *Transport { + t.GetProxyConnectHeader = fn + return t +} + +func (t *Transport) SetProxyConnectHeader(header http.Header) *Transport { + t.ProxyConnectHeader = header + return t +} + +func (t *Transport) SetReadBufferSize(size int) *Transport { + t.ReadBufferSize = size + return t +} + +func (t *Transport) SetWriteBufferSize(size int) *Transport { + t.WriteBufferSize = size + return t +} + +func (t *Transport) SetMaxResponseHeaderBytes(max int64) *Transport { + t.MaxResponseHeaderBytes = max + return t +} + +func (t *Transport) SetResponseOptions(opt *ResponseOptions) *Transport { + t.ResponseOptions = opt + return t +} + +func (t *Transport) SetTLSClientConfig(cfg *tls.Config) *Transport { + t.TLSClientConfig = cfg + return t +} + +func (t *Transport) SetDebug(debugf func(format string, v ...interface{})) *Transport { + t.Debugf = debugf + return t +} + +func (t *Transport) SetProxy(proxy func(*http.Request) (*url.URL, error)) *Transport { + t.Proxy = proxy + return t +} + +func (t *Transport) SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Transport { + t.DialContext = fn + return t +} + +func (t *Transport) SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Transport { + t.DialTLSContext = fn + return t } type pendingAltSvc struct { @@ -282,7 +260,30 @@ type pendingAltSvc struct { Transport http.RoundTripper } -func (t *Transport) enableH3() { +func (t *Transport) EnableHTTP3() { + v := runtime.Version() + ss := strings.Split(v, ".") + + if len(ss) < 2 || ss[0] != "go1" { + if t.Debugf != nil { + t.Debugf("bad go version format: %s", v) + } + return + } + minorVersion, err := strconv.Atoi(ss[1]) + if err != nil { + if t.Debugf != nil { + t.Debugf("bad go minor version: %s", v) + } + return + } + if !(minorVersion >= 16 && minorVersion <= 18) { + if t.Debugf != nil { + t.Debugf("%s is not support http3", v) + } + return + } + if t.altSvcJar == nil { t.altSvcJar = altsvc.NewAltSvcJar() } @@ -290,7 +291,7 @@ func (t *Transport) enableH3() { t.pendingAltSvcs = make(map[string]*pendingAltSvc) } t3 := &http3.RoundTripper{ - Interface: transportImpl{t}, + Options: &t.Options, } t.t3 = t3 } @@ -306,7 +307,7 @@ func (t *Transport) handleResponseBody(res *http.Response, req *http.Request) { t.wrapResponseBody(res, wrap) } t.autoDecodeResponseBody(res) - dump.WrapResponseBodyIfNeeded(res, req, t.dump) + dump.WrapResponseBodyIfNeeded(res, req, t.Dump) } var allowedProtocols = map[string]bool{ @@ -447,33 +448,14 @@ func (t *Transport) readBufferSize() int { // Clone returns a deep copy of t's exported fields. func (t *Transport) Clone() *Transport { - t2 := &Transport{ - Proxy: t.Proxy, - DialContext: t.DialContext, - DialTLSContext: t.DialTLSContext, - TLSHandshakeTimeout: t.TLSHandshakeTimeout, - DisableKeepAlives: t.DisableKeepAlives, - DisableCompression: t.DisableCompression, - MaxIdleConns: t.MaxIdleConns, - MaxIdleConnsPerHost: t.MaxIdleConnsPerHost, - MaxConnsPerHost: t.MaxConnsPerHost, - IdleConnTimeout: t.IdleConnTimeout, - ResponseHeaderTimeout: t.ResponseHeaderTimeout, - ExpectContinueTimeout: t.ExpectContinueTimeout, - ProxyConnectHeader: t.ProxyConnectHeader.Clone(), - GetProxyConnectHeader: t.GetProxyConnectHeader, - MaxResponseHeaderBytes: t.MaxResponseHeaderBytes, - ForceAttemptHTTP2: t.ForceAttemptHTTP2, - WriteBufferSize: t.WriteBufferSize, - ReadBufferSize: t.ReadBufferSize, - ResponseOptions: t.ResponseOptions, - ForceHttpVersion: t.ForceHttpVersion, - Debugf: t.Debugf, - dump: t.dump.Clone(), - } - if t.dump != nil { - go t.dump.Start() + Options: t.Options, + ResponseOptions: t.ResponseOptions, + ForceHttpVersion: t.ForceHttpVersion, + } + t2.Options.Dump = t.Options.Dump.Clone() + if t.Dump != nil { + go t.Dump.Start() } if t.TLSClientConfig != nil { t2.TLSClientConfig = t.TLSClientConfig.Clone() @@ -484,15 +466,15 @@ func (t *Transport) Clone() *Transport { // EnableDump enables the dump for all requests with specified dump options. func (t *Transport) EnableDump(opt *DumpOptions) { dump := newDumper(opt) - t.dump = dump + t.Dump = dump go dump.Start() } // DisableDump disables the dump. func (t *Transport) DisableDump() { - if t.dump != nil { - t.dump.Stop() - t.dump = nil + if t.Dump != nil { + t.Dump.Stop() + t.Dump = nil } } @@ -1965,7 +1947,7 @@ func fixPragmaCacheControl(header http.Header) { // 100-continue") from the server. It returns the final non-100 one. // trace is optional. func (pc *persistConn) _readResponse(req *http.Request) (*http.Response, error) { - ds := dump.GetResponseHeaderDumpers(req.Context(), pc.t.dump) + ds := dump.GetResponseHeaderDumpers(req.Context(), pc.t.Dump) tp := newTextprotoReader(pc.br, ds) resp := &http.Response{ Request: req, @@ -2597,7 +2579,7 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo } rw := w // raw writer - dumps := dump.GetDumpers(r.Context(), pc.t.dump) + dumps := dump.GetDumpers(r.Context(), pc.t.Dump) for _, dump := range dumps { if dump.RequestHeader() { w = dump.WrapWriter(w) diff --git a/transport_test.go b/transport_test.go index cf8c88dc..46a3f428 100644 --- a/transport_test.go +++ b/transport_test.go @@ -5470,7 +5470,6 @@ func TestTransportClone(t *testing.T) { ProxyConnectHeader: http.Header{}, GetProxyConnectHeader: func(context.Context, *url.URL, string) (http.Header, error) { return nil, nil }, MaxResponseHeaderBytes: 1, - ForceAttemptHTTP2: true, ReadBufferSize: 1, WriteBufferSize: 1, ForceHttpVersion: HTTP1, diff --git a/transport_wrapper.go b/transport_wrapper.go deleted file mode 100644 index 561d3aef..00000000 --- a/transport_wrapper.go +++ /dev/null @@ -1,103 +0,0 @@ -package req - -import ( - "context" - "crypto/tls" - "github.com/imroc/req/v3/internal/dump" - "net" - "net/http" - "net/url" - "time" -) - -type transportImpl struct { - t *Transport -} - -func (t transportImpl) Proxy() func(*http.Request) (*url.URL, error) { - return t.t.Proxy -} - -func (t transportImpl) Debugf() func(format string, v ...interface{}) { - return t.t.Debugf -} - -func (t transportImpl) SetDebugf(f func(format string, v ...interface{})) { - t.t.Debugf = f -} - -func (t transportImpl) DisableCompression() bool { - return t.t.DisableCompression -} - -func (t transportImpl) TLSClientConfig() *tls.Config { - return t.t.TLSClientConfig -} - -func (t transportImpl) SetTLSClientConfig(c *tls.Config) { - t.t.TLSClientConfig = c -} - -func (t transportImpl) TLSHandshakeTimeout() time.Duration { - return t.t.TLSHandshakeTimeout -} - -func (t transportImpl) DialContext() func(ctx context.Context, network string, addr string) (net.Conn, error) { - return t.t.DialContext -} - -func (t transportImpl) DialTLSContext() func(ctx context.Context, network string, addr string) (net.Conn, error) { - return t.t.DialTLSContext -} - -func (t transportImpl) DisableKeepAlives() bool { - return t.t.DisableKeepAlives -} - -func (t transportImpl) Dump() *dump.Dumper { - return t.t.dump -} - -func (t transportImpl) MaxIdleConns() int { - return t.t.MaxIdleConns -} - -func (t transportImpl) MaxIdleConnsPerHost() int { - return t.t.MaxIdleConnsPerHost -} - -func (t transportImpl) MaxConnsPerHost() int { - return t.t.MaxConnsPerHost -} - -func (t transportImpl) IdleConnTimeout() time.Duration { - return t.t.IdleConnTimeout -} - -func (t transportImpl) ResponseHeaderTimeout() time.Duration { - return t.t.ResponseHeaderTimeout -} - -func (t transportImpl) ExpectContinueTimeout() time.Duration { - return t.t.ExpectContinueTimeout -} - -func (t transportImpl) ProxyConnectHeader() http.Header { - return t.t.ProxyConnectHeader -} - -func (t transportImpl) GetProxyConnectHeader() func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) { - return t.t.GetProxyConnectHeader -} - -func (t transportImpl) MaxResponseHeaderBytes() int64 { - return t.t.MaxResponseHeaderBytes -} - -func (t transportImpl) WriteBufferSize() int { - return t.t.WriteBufferSize -} - -func (t transportImpl) ReadBufferSize() int { - return t.t.ReadBufferSize -} From 7cde489ff96416e39a3b91bd162a328e2a28822a Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 2 Jul 2022 08:49:47 +0800 Subject: [PATCH 514/843] remove duplicate header dump --- internal/http3/request_writer.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index 7cfdf2dd..d7a2c6ac 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -187,9 +187,6 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra // modifying the hpack state. hlSize := uint64(0) enumerateHeaders(func(name, value string) { - for _, dump := range dumps { - dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) - } hf := hpack.HeaderField{Name: name, Value: value} hlSize += uint64(hf.Size()) }) From 8bd495e17683643b1d487e1ba7421aa7e0d2918d Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 2 Jul 2022 09:08:18 +0800 Subject: [PATCH 515/843] support dump request body --- internal/http3/client.go | 38 ++++++++++++++++++++++++++------ internal/http3/request_writer.go | 4 ++++ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/internal/http3/client.go b/internal/http3/client.go index 200b742f..b2f68aa9 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -298,9 +298,26 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon return rsp, rerr.err } -func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error { +func (c *client) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.Dumper) error { defer body.Close() b := make([]byte, bodyCopyBufferSize) + writeData := func(data []byte) error { + if _, err := str.Write(data); err != nil { + return err + } + return nil + } + if len(dumps) > 0 { + writeData = func(data []byte) error { + for _, dump := range dumps { + dump.Dump(data) + } + if _, err := str.Write(data); err != nil { + return err + } + return nil + } + } for { n, rerr := body.Read(b) if n == 0 { @@ -311,7 +328,7 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error { break } } - if _, err := str.Write(b[:n]); err != nil { + if err := writeData(b[:n]); err != nil { return err } if rerr != nil { @@ -330,13 +347,14 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { requestGzip = true } - var dumps []*dump.Dumper - for _, dump := range dump.GetDumpers(req.Context(), c.opts.dump) { + dumps := dump.GetDumpers(req.Context(), c.opts.dump) + var headerDumps []*dump.Dumper + for _, dump := range dumps { if dump.RequestHeader() { - dumps = append(dumps, dump) + headerDumps = append(headerDumps, dump) } } - if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip, dumps); err != nil { + if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip, headerDumps); err != nil { return nil, newStreamError(errorInternalError, err) } @@ -348,7 +366,13 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, if req.Body != nil { // send the request body asynchronously go func() { - if err := c.sendRequestBody(hstr, req.Body); err != nil { + var bodyDumps []*dump.Dumper + for _, dump := range dumps { + if dump.RequestBody() { + bodyDumps = append(bodyDumps, dump) + } + } + if err := c.sendRequestBody(hstr, req.Body, bodyDumps); err != nil { c.logger.Errorf("Error writing request: %s", err) } if !opt.DontCloseRequestStream { diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index d7a2c6ac..dbd04ad0 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -211,6 +211,10 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra // } }) + for _, dump := range dumps { + dump.Dump([]byte("\r\n")) + } + return nil } From 66f1243ca526cae5617772689a04d768dc5e7920 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 2 Jul 2022 09:31:44 +0800 Subject: [PATCH 516/843] support dump response header --- internal/http3/client.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/internal/http3/client.go b/internal/http3/client.go index b2f68aa9..a93dc107 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -325,6 +325,9 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.D continue } if rerr == io.EOF { + for _, dump := range dumps { + dump.Dump([]byte("\r\n\r\n")) + } break } } @@ -333,6 +336,9 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.D } if rerr != nil { if rerr == io.EOF { + for _, dump := range dumps { + dump.Dump([]byte("\r\n\r\n")) + } break } str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) @@ -396,7 +402,23 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, if _, err := io.ReadFull(str, headerBlock); err != nil { return nil, newStreamError(errorRequestIncomplete, err) } + var respHeaderDumps []*dump.Dumper + for _, dump := range dumps { + if dump.ResponseHeader() { + respHeaderDumps = append(respHeaderDumps, dump) + } + } hfs, err := c.decoder.DecodeFull(headerBlock) + if len(respHeaderDumps) > 0 { + for _, hf := range hfs { + for _, dump := range respHeaderDumps { + dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) + } + } + for _, dump := range respHeaderDumps { + dump.Dump([]byte("\r\n")) + } + } if err != nil { // TODO: use the right error code return nil, newConnError(errorGeneralProtocolError, err) From 5d3fe92df5c887a2a8fe10c7ff5f4813d17892c5 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 2 Jul 2022 10:56:39 +0800 Subject: [PATCH 517/843] add http3 detect debug log --- transport.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transport.go b/transport.go index 51d7b84c..dd78ca65 100644 --- a/transport.go +++ b/transport.go @@ -361,6 +361,9 @@ func (t *Transport) handlePendingAltSvc(hostname string, pas *pendingAltSvc) { } else { pas.CurrentIndex = i pas.Transport = t.t3 + if t.Debugf != nil { + t.Debugf("detected that the server %s supports http3, will try to use http3 protocol in subsequent requests", hostname) + } return } } From 94770395b6475b0407136cc38a1b27ef3c3f29ef Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 2 Jul 2022 17:41:58 +0800 Subject: [PATCH 518/843] some refactor --- client.go | 22 +--- internal/http2/transport.go | 47 +-------- internal/tests/transport.go | 138 ------------------------- transport.go | 86 +++++++++------ transport_test.go | 201 ++++++++++++++++-------------------- 5 files changed, 151 insertions(+), 343 deletions(-) diff --git a/client.go b/client.go index 50e98b5e..a6b00754 100644 --- a/client.go +++ b/client.go @@ -400,13 +400,6 @@ func (c *Client) SetLogger(log Logger) *Client { return c } -func (c *Client) getResponseOptions() *ResponseOptions { - if c.t.ResponseOptions == nil { - c.t.ResponseOptions = &ResponseOptions{} - } - return c.t.ResponseOptions -} - // SetTimeout set timeout for all requests. func (c *Client) SetTimeout(d time.Duration) *Client { c.httpClient.Timeout = d @@ -541,40 +534,35 @@ func (c *Client) EnableAutoReadResponse() *Client { // SetAutoDecodeContentType set the content types that will be auto-detected and decode // to utf-8 (e.g. "json", "xml", "html", "text"). func (c *Client) SetAutoDecodeContentType(contentTypes ...string) *Client { - opt := c.getResponseOptions() - opt.AutoDecodeContentType = autoDecodeContentTypeFunc(contentTypes...) + c.t.SetAutoDecodeContentType(contentTypes...) return c } // SetAutoDecodeContentTypeFunc set the function that determines whether the // specified `Content-Type` should be auto-detected and decode to utf-8. func (c *Client) SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client { - opt := c.getResponseOptions() - opt.AutoDecodeContentType = fn + c.t.SetAutoDecodeContentTypeFunc(fn) return c } // SetAutoDecodeAllContentType enable try auto-detect charset and decode all // content type to utf-8. func (c *Client) SetAutoDecodeAllContentType() *Client { - opt := c.getResponseOptions() - opt.AutoDecodeContentType = func(contentType string) bool { - return true - } + c.t.SetAutoDecodeAllContentType() return c } // DisableAutoDecode disable auto-detect charset and decode to utf-8 // (enabled by default). func (c *Client) DisableAutoDecode() *Client { - c.getResponseOptions().DisableAutoDecode = true + c.t.DisableAutoDecode() return c } // EnableAutoDecode enable auto-detect charset and decode to utf-8 // (enabled by default). func (c *Client) EnableAutoDecode() *Client { - c.getResponseOptions().DisableAutoDecode = false + c.t.EnableAutoDecode() return c } diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 25a8895d..ee8320c7 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -155,52 +155,6 @@ func (t *Transport) pingTimeout() time.Duration { } -// ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2. -// It returns a new HTTP/2 Transport for further configuration. -// It returns an error if t1 has already been HTTP/2-enabled. -// func ConfigureTransports(t1 transport.Interface) (*Transport, error) { -// connPool := new(clientConnPool) -// t2 := &Transport{ -// ConnPool: noDialClientConnPool{connPool}, -// Interface: t1, -// } -// connPool.t = t2 -// if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil { -// return nil, err -// } -// if t1.TLSClientConfig() == nil { -// t1.SetTLSClientConfig(new(tls.Config)) -// } -// if !strSliceContains(t1.TLSClientConfig().NextProtos, "h2") { -// t1.TLSClientConfig().NextProtos = append([]string{"h2"}, t1.TLSClientConfig().NextProtos...) -// } -// if !strSliceContains(t1.TLSClientConfig().NextProtos, "http/1.1") { -// t1.TLSClientConfig().NextProtos = append(t1.TLSClientConfig().NextProtos, "http/1.1") -// } -// upgradeFn := func(authority string, c reqtls.Conn) http.RoundTripper { -// addr := authorityAddr("https", authority) -// if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { -// go c.Close() -// return erringRoundTripper{err} -// } else if !used { -// // Turns out we don't need this c. -// // For example, two goroutines made requests to the same host -// // at the same time, both kicking off TCP dials. (since protocol -// // was unknown) -// go c.Close() -// } -// return t2 -// } -// if m := t1.TLSNextProto(); len(m) == 0 { -// t1.SetTLSNextProto(map[string]func(string, reqtls.Conn) http.RoundTripper{ -// "h2": upgradeFn, -// }) -// } else { -// m["h2"] = upgradeFn -// } -// return t2, nil -// } - func (t *Transport) connPool() *clientConnPool { t.connPoolOnce.Do(t.initConnPool) return t.connPoolOrDef @@ -471,6 +425,7 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res if err != nil { return nil, err } + traceGotConn(req, cc, true) return cc.RoundTrip(req) } for retry := 0; ; retry++ { diff --git a/internal/tests/transport.go b/internal/tests/transport.go index 9122a524..ca8701d2 100644 --- a/internal/tests/transport.go +++ b/internal/tests/transport.go @@ -1,139 +1 @@ package tests - -import ( - "context" - "crypto/tls" - "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/transport" - reqtls "github.com/imroc/req/v3/pkg/tls" - "net" - "net/http" - "net/url" - "time" -) - -type Transport struct { - ProxyValue func(*http.Request) (*url.URL, error) - TLSClientConfigValue *tls.Config - DisableCompressionValue bool - DisableKeepAlivesValue bool - TLSHandshakeTimeoutValue time.Duration - ResponseHeaderTimeoutValue time.Duration - ExpectContinueTimeoutValue time.Duration - IdleConnTimeoutValue time.Duration - DumpValue *dump.Dumper - ReadBufferSizeValue int - WriteBufferSizeValue int - MaxIdleConnsValue int - MaxIdleConnsPerHostValue int - MaxConnsPerHostValue int - MaxResponseHeaderBytesValue int64 - TLSNextProtoValue map[string]func(authority string, c reqtls.Conn) http.RoundTripper - GetProxyConnectHeaderValue func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) - ProxyConnectHeaderValue http.Header - DebugfValue func(format string, v ...interface{}) -} - -func (t Transport) Proxy() func(*http.Request) (*url.URL, error) { - return t.ProxyValue -} - -func (t Transport) Clone() transport.Interface { - return nil -} - -func (t Transport) Debugf() func(format string, v ...interface{}) { - return t.DebugfValue -} - -func (t Transport) SetDebugf(f func(format string, v ...interface{})) { - t.DebugfValue = f -} - -func (t Transport) DisableCompression() bool { - return t.DisableCompressionValue -} - -func (t Transport) TLSClientConfig() *tls.Config { - return t.TLSClientConfigValue -} - -func (t Transport) SetTLSClientConfig(c *tls.Config) { - t.TLSClientConfigValue = c -} - -func (t Transport) TLSHandshakeTimeout() time.Duration { - return t.TLSHandshakeTimeoutValue -} - -func (t Transport) DialContext() func(ctx context.Context, network string, addr string) (net.Conn, error) { - return nil -} - -func (t Transport) DialTLSContext() func(ctx context.Context, network string, addr string) (net.Conn, error) { - return nil -} - -func (t Transport) RegisterProtocol(scheme string, rt http.RoundTripper) { -} - -func (t Transport) DisableKeepAlives() bool { - return t.DisableKeepAlivesValue -} - -func (t Transport) Dump() *dump.Dumper { - return t.DumpValue - -} - -func (t Transport) MaxIdleConns() int { - return t.MaxIdleConnsValue -} - -func (t Transport) MaxIdleConnsPerHost() int { - return t.MaxIdleConnsPerHostValue -} - -func (t Transport) MaxConnsPerHost() int { - return t.MaxConnsPerHostValue -} - -func (t Transport) IdleConnTimeout() time.Duration { - return t.IdleConnTimeoutValue -} - -func (t Transport) ResponseHeaderTimeout() time.Duration { - return t.ResponseHeaderTimeoutValue -} - -func (t Transport) ExpectContinueTimeout() time.Duration { - return t.ExpectContinueTimeoutValue -} - -func (t Transport) ProxyConnectHeader() http.Header { - return t.ProxyConnectHeaderValue -} - -func (t Transport) GetProxyConnectHeader() func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) { - return t.GetProxyConnectHeaderValue -} - -func (t Transport) MaxResponseHeaderBytes() int64 { - return t.MaxResponseHeaderBytesValue -} - -func (t Transport) WriteBufferSize() int { - return t.WriteBufferSizeValue -} - -func (t Transport) ReadBufferSize() int { - return t.ReadBufferSizeValue -} - -func (t Transport) TLSNextProto() map[string]func(authority string, c reqtls.Conn) http.RoundTripper { - return t.TLSNextProtoValue -} - -func (t Transport) SetTLSNextProto(m map[string]func(authority string, c reqtls.Conn) http.RoundTripper) { - t.TLSNextProtoValue = m -} diff --git a/transport.go b/transport.go index dd78ca65..b00677a2 100644 --- a/transport.go +++ b/transport.go @@ -65,18 +65,6 @@ const ( // MaxIdleConnsPerHost. const defaultMaxIdleConnsPerHost = 2 -// ResponseOptions determines that how should the response been processed. -type ResponseOptions struct { - // DisableAutoDecode, if true, prevents auto detect response - // body's charset and decode it to utf-8 - DisableAutoDecode bool - - // AutoDecodeContentType specifies an optional function for determine - // whether the response body should been auto decode to utf-8. - // Only valid when DisableAutoDecode is true. - AutoDecodeContentType func(contentType string) bool -} - // Transport is an implementation of http.RoundTripper that supports HTTP, // HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT). // @@ -137,22 +125,22 @@ type Transport struct { t2 *http2.Transport // non-nil if http2 wired up t3 *http3.RoundTripper - *ResponseOptions - - // DisableAutoDecode, if true, prevents auto detect response + // disableAutoDecode, if true, prevents auto detect response // body's charset and decode it to utf-8 - DisableAutoDecode bool + disableAutoDecode bool - // AutoDecodeContentType specifies an optional function for determine + // autoDecodeContentType specifies an optional function for determine // whether the response body should been auto decode to utf-8. // Only valid when DisableAutoDecode is true. - AutoDecodeContentType func(contentType string) bool + autoDecodeContentType func(contentType string) bool } +// NewTransport is an alias of T func NewTransport() *Transport { return T() } +// T create a new Transport. func T() *Transport { t := &Transport{ Options: transport.Options{ @@ -168,6 +156,42 @@ func T() *Transport { return t } +// DisableAutoDecode disable auto-detect charset and decode to utf-8 +// (enabled by default). +func (t *Transport) DisableAutoDecode() *Transport { + t.disableAutoDecode = true + return t +} + +// EnableAutoDecode enable auto-detect charset and decode to utf-8 +// (enabled by default). +func (t *Transport) EnableAutoDecode() *Transport { + t.disableAutoDecode = false + return t +} + +// SetAutoDecodeContentTypeFunc set the function that determines whether the +// specified `Content-Type` should be auto-detected and decode to utf-8. +func (t *Transport) SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Transport { + t.autoDecodeContentType = fn + return t +} + +// SetAutoDecodeAllContentType enable try auto-detect charset and decode all +// content type to utf-8. +func (t *Transport) SetAutoDecodeAllContentType() *Transport { + t.autoDecodeContentType = func(contentType string) bool { + return true + } + return t +} + +// SetAutoDecodeContentType set the content types that will be auto-detected and decode +// to utf-8 (e.g. "json", "xml", "html", "text"). +func (t *Transport) SetAutoDecodeContentType(contentTypes ...string) { + t.autoDecodeContentType = autoDecodeContentTypeFunc(contentTypes...) +} + func (t *Transport) GetMaxIdleConns() int { return t.MaxIdleConns } @@ -187,6 +211,11 @@ func (t *Transport) SetIdleConnTimeout(timeout time.Duration) *Transport { return t } +func (t *Transport) SetTLSHandshakeTimeout(timeout time.Duration) *Transport { + t.TLSHandshakeTimeout = timeout + return t +} + func (t *Transport) SetResponseHeaderTimeout(timeout time.Duration) *Transport { t.ResponseHeaderTimeout = timeout return t @@ -222,11 +251,6 @@ func (t *Transport) SetMaxResponseHeaderBytes(max int64) *Transport { return t } -func (t *Transport) SetResponseOptions(opt *ResponseOptions) *Transport { - t.ResponseOptions = opt - return t -} - func (t *Transport) SetTLSClientConfig(cfg *tls.Config) *Transport { t.TLSClientConfig = cfg return t @@ -382,16 +406,13 @@ func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFu } func (t *Transport) autoDecodeResponseBody(res *http.Response) { - if t.ResponseOptions == nil { - return - } - if t.ResponseOptions.DisableAutoDecode { + if t.disableAutoDecode { return } contentType := res.Header.Get("Content-Type") var shouldDecode func(contentType string) bool - if t.ResponseOptions.AutoDecodeContentType != nil { - shouldDecode = t.ResponseOptions.AutoDecodeContentType + if t.autoDecodeContentType != nil { + shouldDecode = t.autoDecodeContentType } else { shouldDecode = autoDecodeText } @@ -452,9 +473,10 @@ func (t *Transport) readBufferSize() int { // Clone returns a deep copy of t's exported fields. func (t *Transport) Clone() *Transport { t2 := &Transport{ - Options: t.Options, - ResponseOptions: t.ResponseOptions, - ForceHttpVersion: t.ForceHttpVersion, + Options: t.Options, + ForceHttpVersion: t.ForceHttpVersion, + disableAutoDecode: t.disableAutoDecode, + autoDecodeContentType: t.autoDecodeContentType, } t2.Options.Dump = t.Options.Dump.Clone() if t.Dump != nil { diff --git a/transport_test.go b/transport_test.go index 46a3f428..de69826e 100644 --- a/transport_test.go +++ b/transport_test.go @@ -22,6 +22,7 @@ import ( "fmt" "github.com/imroc/req/v3/internal/common" "github.com/imroc/req/v3/internal/tests" + "github.com/imroc/req/v3/internal/transport" "go/token" "golang.org/x/net/http/httpproxy" nethttp2 "golang.org/x/net/http2" @@ -1421,9 +1422,8 @@ func TestTransportExpect100Continue(t *testing.T) { c := tc().httpClient for i, v := range tests { - tr := &Transport{ - ExpectContinueTimeout: 2 * time.Second, - } + tr := T() + tr.ExpectContinueTimeout = 2 * time.Second defer tr.CloseIdleConnections() c.Transport = tr body := bytes.NewReader(v.body) @@ -1739,12 +1739,11 @@ func TestTransportProxyHTTPSConnectLeak(t *testing.T) { return }() + tr := T().SetProxy(func(*http.Request) (*url.URL, error) { + return url.Parse("http://" + ln.Addr().String()) + }) c := &http.Client{ - Transport: &Transport{ - Proxy: func(*http.Request) (*url.URL, error) { - return url.Parse("http://" + ln.Addr().String()) - }, - }, + Transport: tr, } req, err := http.NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil) if err != nil { @@ -1767,14 +1766,11 @@ func TestTransportDialPreservesNetOpProxyError(t *testing.T) { var errDial = errors.New("some dial error") - tr := &Transport{ - Proxy: func(*http.Request) (*url.URL, error) { - return url.Parse("http://proxy.fake.tld/") - }, - DialContext: func(context.Context, string, string) (net.Conn, error) { - return nil, errDial - }, - } + tr := T().SetProxy(func(*http.Request) (*url.URL, error) { + return url.Parse("http://proxy.fake.tld/") + }).SetDial(func(context.Context, string, string) (net.Conn, error) { + return nil, errDial + }) defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} @@ -2687,13 +2683,11 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) { defer afterTest(t) serverConnCh := make(chan net.Conn, 1) - tr := &Transport{ - DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { - cc, sc := net.Pipe() - serverConnCh <- sc - return cc, nil - }, - } + tr := T().SetDial(func(_ context.Context, network, addr string) (net.Conn, error) { + cc, sc := net.Pipe() + serverConnCh <- sc + return cc, nil + }) defer tr.CloseIdleConnections() errc := make(chan error, 1) req, _ := http.NewRequest("GET", "http://example.com/", nil) @@ -2800,7 +2794,7 @@ func (fooProto) RoundTrip(req *http.Request) (*http.Response, error) { func TestTransportNoHost(t *testing.T) { defer afterTest(t) - tr := &Transport{} + tr := T() _, err := tr.RoundTrip(&http.Request{ Header: make(http.Header), URL: &url.URL{ @@ -2995,24 +2989,21 @@ Content-Length: %d } - tr := &Transport{ - DialContext: func(_ context.Context, n, addr string) (net.Conn, error) { - sr, sw := io.Pipe() // server read/write - cr, cw := io.Pipe() // client read/write - conn := &rwTestConn{ - Reader: cr, - Writer: sw, - closeFunc: func() error { - sw.Close() - cw.Close() - return nil - }, - } - go send100Response(cw, sr) - return conn, nil - }, - DisableKeepAlives: false, - } + tr := T().SetDial(func(_ context.Context, n, addr string) (net.Conn, error) { + sr, sw := io.Pipe() // server read/write + cr, cw := io.Pipe() // client read/write + conn := &rwTestConn{ + Reader: cr, + Writer: sw, + closeFunc: func() error { + sw.Close() + cw.Close() + return nil + }, + } + go send100Response(cw, sr) + return conn, nil + }) defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} @@ -3092,7 +3083,7 @@ func newClientServerTest(t *testing.T, h2 bool, h http.Handler, opts ...interfac t: t, h2: h2, h: h, - tr: C().GetTransport(), + tr: T(), } cst.c = &http.Client{Transport: cst.tr} cst.ts = httptest.NewUnstartedServer(h) @@ -3116,9 +3107,7 @@ func newClientServerTest(t *testing.T, h2 bool, h http.Handler, opts ...interfac cst.ts.TLS = cst.ts.Config.TLSConfig cst.ts.StartTLS() - cst.tr.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, - } + cst.tr.TLSClientConfig.InsecureSkipVerify = true return cst } @@ -3432,12 +3421,9 @@ func TestTransportTLSHandshakeTimeout(t *testing.T) { getdonec := make(chan struct{}) go func() { defer close(getdonec) - tr := &Transport{ - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial("tcp", ln.Addr().String()) - }, - TLSHandshakeTimeout: 250 * time.Millisecond, - } + tr := T().SetDial(func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("tcp", ln.Addr().String()) + }).SetTLSHandshakeTimeout(250 * time.Millisecond) cl := &http.Client{Transport: tr} _, err := cl.Get("https://dummy.tld/") if err == nil { @@ -3764,7 +3750,7 @@ func TestRoundTripReturnsProxyError(t *testing.T) { return nil, errors.New("errorMessage") } - tr := &Transport{Proxy: badProxy} + tr := T().SetProxy(badProxy) req, _ := http.NewRequest("GET", "http://example.com", nil) @@ -3777,7 +3763,7 @@ func TestRoundTripReturnsProxyError(t *testing.T) { // tests that putting an idle conn after a call to CloseIdleConns does return it func TestTransportCloseIdleConnsThenReturn(t *testing.T) { - tr := &Transport{} + tr := T() wantIdle := func(when string, n int) bool { got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn if got == n { @@ -3817,7 +3803,7 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) { // Test for issue 34282 // Ensure that getConn doesn't call the GotConn trace hook on a HTTP/2 idle conn func TestTransportTraceGotConnH2IdleConns(t *testing.T) { - tr := &Transport{} + tr := T() wantIdle := func(when string, n int) bool { got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2 if got == n { @@ -4055,11 +4041,9 @@ func TestTransportFlushesBodyChunks(t *testing.T) { rch: resBody, w: connw, } - tr := &Transport{ - DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { - return lw, nil - }, - } + tr := T().SetDial(func(_ context.Context, network, addr string) (net.Conn, error) { + return lw, nil + }) bodyr, bodyw := io.Pipe() // body pipe pair go func() { defer bodyw.Close() @@ -4507,12 +4491,11 @@ func TestTransportEventTraceTLSVerify(t *testing.T) { certpool := x509.NewCertPool() certpool.AddCert(ts.Certificate()) - c := &http.Client{Transport: &Transport{ - TLSClientConfig: &tls.Config{ - ServerName: "dns-is-faked.golang", - RootCAs: certpool, - }, - }} + tr := T().SetTLSClientConfig(&tls.Config{ + ServerName: "dns-is-faked.golang", + RootCAs: certpool, + }) + c := &http.Client{Transport: tr} trace := &httptrace.ClientTrace{ TLSHandshakeStart: func() { logf("TLSHandshakeStart") }, @@ -4568,7 +4551,7 @@ func skipIfDNSHijacked(t *testing.T) { func TestTransportEventTraceRealDNS(t *testing.T) { skipIfDNSHijacked(t) defer afterTest(t) - tr := &Transport{} + tr := T() defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} @@ -4765,21 +4748,20 @@ func TestTransportReturnsPeekError(t *testing.T) { wrote := make(chan struct{}) var wroteOnce sync.Once - tr := &Transport{ - DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { - c := funcConn{ - read: func([]byte) (int, error) { - <-wrote - return 0, errValue - }, - write: func(p []byte) (int, error) { - wroteOnce.Do(func() { close(wrote) }) - return len(p), nil - }, - } - return c, nil - }, - } + tr := T().SetDial(func(_ context.Context, network, addr string) (net.Conn, error) { + c := funcConn{ + read: func([]byte) (int, error) { + <-wrote + return 0, errValue + }, + write: func(p []byte) (int, error) { + wroteOnce.Do(func() { close(wrote) }) + return len(p), nil + }, + } + return c, nil + }) + _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)) if err != errValue { t.Errorf("error = %#v; want %v", err, errValue) @@ -4989,7 +4971,7 @@ func TestMissingStatusNoPanic(t *testing.T) { t.Fatalf("proxyURL: %v", err) } - tr := &Transport{Proxy: http.ProxyURL(proxyURL)} + tr := T().SetProxy(http.ProxyURL(proxyURL)) req, _ := http.NewRequest("GET", "https://golang.org/", nil) res, err, panicked := doFetchCheckPanic(tr, req) @@ -5066,7 +5048,7 @@ func (d doneContext) Err() error { return d.err } // Issue 25852: Transport should check whether Context is done early. func TestTransportCheckContextDoneEarly(t *testing.T) { - tr := &Transport{} + tr := T() req, _ := http.NewRequest("GET", "http://fake.example/", nil) wantErr := errors.New("some error") req = req.WithContext(doneContext{context.Background(), wantErr}) @@ -5454,27 +5436,28 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { func TestTransportClone(t *testing.T) { tr := &Transport{ - Proxy: func(*http.Request) (*url.URL, error) { panic("") }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, - DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, - TLSClientConfig: new(tls.Config), - TLSHandshakeTimeout: time.Second, - DisableKeepAlives: true, - DisableCompression: true, - MaxIdleConns: 1, - MaxIdleConnsPerHost: 1, - MaxConnsPerHost: 1, - IdleConnTimeout: time.Second, - ResponseHeaderTimeout: time.Second, - ExpectContinueTimeout: time.Second, - ProxyConnectHeader: http.Header{}, - GetProxyConnectHeader: func(context.Context, *url.URL, string) (http.Header, error) { return nil, nil }, - MaxResponseHeaderBytes: 1, - ReadBufferSize: 1, - WriteBufferSize: 1, - ForceHttpVersion: HTTP1, - ResponseOptions: &ResponseOptions{}, - Debugf: func(format string, v ...interface{}) {}, + ForceHttpVersion: HTTP1, + Options: transport.Options{ + Proxy: func(*http.Request) (*url.URL, error) { panic("") }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, + TLSClientConfig: new(tls.Config), + TLSHandshakeTimeout: time.Second, + DisableKeepAlives: true, + DisableCompression: true, + MaxIdleConns: 1, + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + IdleConnTimeout: time.Second, + ResponseHeaderTimeout: time.Second, + ExpectContinueTimeout: time.Second, + ProxyConnectHeader: http.Header{}, + GetProxyConnectHeader: func(context.Context, *url.URL, string) (http.Header, error) { return nil, nil }, + MaxResponseHeaderBytes: 1, + ReadBufferSize: 1, + WriteBufferSize: 1, + Debugf: func(format string, v ...interface{}) {}, + }, } tr2 := tr.Clone() rv := reflect.ValueOf(tr2).Elem() @@ -5880,11 +5863,9 @@ func testTransportRace(req *http.Request) { defer pw.Close() dr := &delegateReader{c: make(chan io.Reader)} - t := &Transport{ - DialContext: func(_ context.Context, net, addr string) (net.Conn, error) { - return &dumpConn{pw, dr}, nil - }, - } + t := T().SetDial(func(_ context.Context, net, addr string) (net.Conn, error) { + return &dumpConn{pw, dr}, nil + }) defer t.CloseIdleConnections() quitReadCh := make(chan struct{}) From 30245bfb7d277fb9c342e82b69d3cca9edcbfc32 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 2 Jul 2022 18:04:36 +0800 Subject: [PATCH 519/843] fix test TestDontCacheBrokenHTTP2Conn --- internal/http2/transport.go | 6 +++--- transport.go | 7 ++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index ee8320c7..a93548f8 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -406,9 +406,9 @@ func authorityAddr(scheme string, authority string) (addr string) { return net.JoinHostPort(host, port) } -func (t *Transport) AddConn(conn net.Conn, addr string) error { - _, err := t.connPool().addConnIfNeeded(addr, t, conn) - return err +func (t *Transport) AddConn(conn net.Conn, addr string) (used bool, err error) { + used, err = t.connPool().addConnIfNeeded(addr, t, conn) + return } // RoundTripOpt is like RoundTrip, but takes options. diff --git a/transport.go b/transport.go index b00677a2..7b13bfc7 100644 --- a/transport.go +++ b/transport.go @@ -1783,7 +1783,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if s := pconn.tlsState; t.ForceHttpVersion != HTTP1 && s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if s.NegotiatedProtocol == http2.NextProtoTLS { - t.t2.AddConn(pconn.conn, cm.targetAddr) + if used, err := t.t2.AddConn(pconn.conn, cm.targetAddr); err != nil { + go pconn.conn.Close() + return nil, err + } else if !used { + go pconn.conn.Close() + } return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: t.t2}, nil } } From 8734f12a63dea420da030f82a157cca7d9adafa6 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 4 Jul 2022 17:16:17 +0800 Subject: [PATCH 520/843] fix tests --- go.mod | 8 +- go.sum | 11 + internal/handshake/aead_test.go | 1 + internal/handshake/initial_aead_test.go | 1 + .../handshake/tls_extension_handler_test.go | 1 + internal/handshake/updatable_aead_test.go | 1 + internal/http2/client_conn_pool.go | 40 +-- internal/http2/transport.go | 18 +- internal/http2/transport_go117_test.go | 56 ++-- internal/http2/transport_test.go | 171 ++++++------- internal/http3/request_writer_test.go | 10 +- internal/http3/roundtrip_test.go | 3 +- internal/http3/server_test.go | 4 +- .../logging/mock_connection_tracer_test.go | 1 + .../ackhandler/received_packet_handler.go | 105 -------- .../mocks/ackhandler/sent_packet_handler.go | 240 ------------------ internal/mocks/logging/connection_tracer.go | 1 + internal/mocks/quic/early_conn.go | 3 +- internal/mocks/quic/stream.go | 5 +- internal/protocol/version_test.go | 37 +-- internal/qerr/errors_test.go | 1 + internal/utils/linkedlist/README.md | 11 - internal/utils/linkedlist/linkedlist.go | 218 ---------------- internal/wire/header_test.go | 1 + internal/wire/version_negotiation_test.go | 1 + internal/wire/wire_suite_test.go | 2 +- 26 files changed, 188 insertions(+), 763 deletions(-) delete mode 100644 internal/mocks/ackhandler/received_packet_handler.go delete mode 100644 internal/mocks/ackhandler/sent_packet_handler.go delete mode 100644 internal/utils/linkedlist/README.md delete mode 100644 internal/utils/linkedlist/linkedlist.go diff --git a/go.mod b/go.mod index 05f85d36..8440cca7 100644 --- a/go.mod +++ b/go.mod @@ -7,16 +7,16 @@ require ( github.com/golang/mock v1.6.0 github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 - github.com/lucas-clemente/quic-go v0.27.2 + github.com/lucas-clemente/quic-go v0.28.0 github.com/marten-seemann/qpack v0.2.1 github.com/marten-seemann/qtls-go1-16 v0.1.5 github.com/marten-seemann/qtls-go1-17 v0.1.2 github.com/marten-seemann/qtls-go1-18 v0.1.2 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.13.0 - golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect - golang.org/x/net v0.0.0-20220615171555-694bf12d69de - golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c // indirect + golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect + golang.org/x/net v0.0.0-20220630215102-69896b714898 + golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e // indirect golang.org/x/text v0.3.7 golang.org/x/tools v0.1.11 // indirect golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f // indirect diff --git a/go.sum b/go.sum index 40a0d90f..26a2e59a 100644 --- a/go.sum +++ b/go.sum @@ -85,6 +85,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lucas-clemente/quic-go v0.27.2 h1:zsMwwniyybb8B/UDNXRSYee7WpQJVOcjQEGgpw2ikXs= github.com/lucas-clemente/quic-go v0.27.2/go.mod h1:vXgO/11FBSKM+js1NxoaQ/bPtVFYfB7uxhfHXyMhl1A= +github.com/lucas-clemente/quic-go v0.28.0 h1:9eXVRgIkMQQyiyorz/dAaOYIx3TFzXsIFkNFz4cxuJM= +github.com/lucas-clemente/quic-go v0.28.0/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= @@ -95,6 +97,8 @@ github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= github.com/marten-seemann/qtls-go1-18 v0.1.2 h1:JH6jmzbduz0ITVQ7ShevK10Av5+jBEKAHMntXmIV7kM= github.com/marten-seemann/qtls-go1-18 v0.1.2/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 h1:7m/WlWcSROrcK5NxuXaxYD32BZqe/LEEnBrWcH/cOqQ= +github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -171,6 +175,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -200,6 +206,9 @@ golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220615171555-694bf12d69de h1:ogOG2+P6LjO2j55AkRScrkB2BFpd+Z8TY2wcM0Z3MGo= golang.org/x/net v0.0.0-20220615171555-694bf12d69de/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw= +golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -236,6 +245,8 @@ golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c h1:aFV+BgZ4svzjfabn8ERpuB4JI4N6/rdy1iusx77G3oU= golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e h1:CsOuNlbOuf0mzxJIefr6Q4uAUetRUwZE4qt7VfzP+xo= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index ae406f10..76557ddc 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -7,6 +7,7 @@ import ( "crypto/rand" "crypto/tls" "fmt" + "github.com/lucas-clemente/quic-go" "github.com/imroc/req/v3/internal/protocol" diff --git a/internal/handshake/initial_aead_test.go b/internal/handshake/initial_aead_test.go index f7a02f63..7fd496e6 100644 --- a/internal/handshake/initial_aead_test.go +++ b/internal/handshake/initial_aead_test.go @@ -2,6 +2,7 @@ package handshake import ( "fmt" + "github.com/lucas-clemente/quic-go" "math/rand" "github.com/imroc/req/v3/internal/protocol" diff --git a/internal/handshake/tls_extension_handler_test.go b/internal/handshake/tls_extension_handler_test.go index fcd2223b..212ab32f 100644 --- a/internal/handshake/tls_extension_handler_test.go +++ b/internal/handshake/tls_extension_handler_test.go @@ -2,6 +2,7 @@ package handshake import ( "fmt" + "github.com/lucas-clemente/quic-go" "github.com/imroc/req/v3/internal/protocol" "github.com/imroc/req/v3/internal/qtls" diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 0246cc23..83cdeebc 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "crypto/tls" "fmt" + "github.com/lucas-clemente/quic-go" "time" "github.com/golang/mock/gomock" diff --git a/internal/http2/client_conn_pool.go b/internal/http2/client_conn_pool.go index 5c8c9958..fcecc6e0 100644 --- a/internal/http2/client_conn_pool.go +++ b/internal/http2/client_conn_pool.go @@ -20,22 +20,12 @@ type ClientConnPool interface { // call, so the caller should not omit it. If the caller needs // to, ClientConn.RoundTrip can be called with a bogus // new(http.Request) to release the stream reservation. - GetClientConn(req *http.Request, addr string) (*ClientConn, error) + GetClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) MarkDead(*ClientConn) + CloseIdleConnections() + AddConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) } -// clientConnPoolIdleCloser is the interface implemented by ClientConnPool -// implementations which can close their idle connections. -type clientConnPoolIdleCloser interface { - ClientConnPool - closeIdleConnections() -} - -var ( - _ clientConnPoolIdleCloser = (*clientConnPool)(nil) - _ clientConnPoolIdleCloser = noDialClientConnPool{} -) - // TODO: use singleflight for dialing and addConnCalls? type clientConnPool struct { t *Transport @@ -49,16 +39,7 @@ type clientConnPool struct { addConnCalls map[string]*addConnCall // in-flight addConnIfNeeded calls } -func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { - return p.getClientConn(req, addr, dialOnMiss) -} - -const ( - dialOnMiss = true - noDialOnMiss = false -) - -func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) { +func (p *clientConnPool) GetClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) { // TODO(dneil): Dial a new connection when t.DisableKeepAlives is set? if isConnectionCloseRequest(req) && dialOnMiss { // It gets its own connection. @@ -155,7 +136,7 @@ func (c *dialCall) dial(ctx context.Context, addr string) { // This code decides which ones live or die. // The return value used is whether c was used. // c is never closed. -func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) { +func (p *clientConnPool) AddConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) { p.mu.Lock() for _, cc := range p.conns[key] { if cc.CanTakeNewRequest() { @@ -242,7 +223,7 @@ func (p *clientConnPool) MarkDead(cc *ClientConn) { delete(p.keys, cc) } -func (p *clientConnPool) closeIdleConnections() { +func (p *clientConnPool) CloseIdleConnections() { p.mu.Lock() defer p.mu.Unlock() // TODO: don't close a cc if it was just added to the pool @@ -273,15 +254,6 @@ func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn { return out } -// noDialClientConnPool is an implementation of http2.ClientConnPool -// which never dials. We let the HTTP/1.1 client dial and use its TLS -// connection instead. -type noDialClientConnPool struct{ *clientConnPool } - -func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { - return p.getClientConn(req, addr, noDialOnMiss) -} - // shouldRetryDial reports whether the current request should // retry dialing after the call finished unsuccessfully, for example // if the dial was canceled because of a context cancellation or diff --git a/internal/http2/transport.go b/internal/http2/transport.go index a93548f8..01600061 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -134,7 +134,7 @@ type Transport struct { CountError func(errType string) connPoolOnce sync.Once - connPoolOrDef *clientConnPool // non-nil version of ConnPool + connPoolOrDef ClientConnPool // non-nil version of ConnPool } func (t *Transport) maxHeaderListSize() uint32 { @@ -155,13 +155,17 @@ func (t *Transport) pingTimeout() time.Duration { } -func (t *Transport) connPool() *clientConnPool { +func (t *Transport) connPool() ClientConnPool { t.connPoolOnce.Do(t.initConnPool) return t.connPoolOrDef } func (t *Transport) initConnPool() { - t.connPoolOrDef = &clientConnPool{t: t} + if t.ConnPool != nil { + t.connPoolOrDef = t.ConnPool + } else { + t.connPoolOrDef = &clientConnPool{t: t} + } } // ClientConn is the state of a single HTTP/2 client connection to an @@ -407,7 +411,7 @@ func authorityAddr(scheme string, authority string) (addr string) { } func (t *Transport) AddConn(conn net.Conn, addr string) (used bool, err error) { - used, err = t.connPool().addConnIfNeeded(addr, t, conn) + used, err = t.connPool().AddConnIfNeeded(addr, t, conn) return } @@ -421,7 +425,7 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res var cc *ClientConn var err error if opt.OnlyCachedConn { - cc, err = t.connPool().getClientConn(req, addr, false) + cc, err = t.connPool().GetClientConn(req, addr, false) if err != nil { return nil, err } @@ -429,7 +433,7 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res return cc.RoundTrip(req) } for retry := 0; ; retry++ { - cc, err = t.connPool().getClientConn(req, addr, true) + cc, err = t.connPool().GetClientConn(req, addr, true) if err != nil { t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) return nil, err @@ -467,7 +471,7 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res // connected from previous requests but are now sitting idle. // It does not interrupt any connections currently in use. func (t *Transport) CloseIdleConnections() { - t.connPool().closeIdleConnections() + t.connPool().CloseIdleConnections() } var ( diff --git a/internal/http2/transport_go117_test.go b/internal/http2/transport_go117_test.go index c39dc7a1..e46e5fcf 100644 --- a/internal/http2/transport_go117_test.go +++ b/internal/http2/transport_go117_test.go @@ -11,7 +11,7 @@ import ( "context" "crypto/tls" "errors" - "github.com/imroc/req/v3/internal/tests" + "github.com/imroc/req/v3/internal/transport" "net/http" "net/http/httptest" @@ -33,20 +33,21 @@ func TestTransportDialTLSContexth2(t *testing.T) { serverTLSConfigFunc, ) defer ts.Close() - tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: &tls.Config{ - GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { - // Tests that the context provided to `req` is - // passed into this function. - close(blockCh) - <-cri.Context().Done() - return nil, cri.Context().Err() - }, - InsecureSkipVerify: true, + opt := &transport.Options{ + TLSClientConfig: &tls.Config{ + GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + // Tests that the context provided to `req` is + // passed into this function. + close(blockCh) + <-cri.Context().Done() + return nil, cri.Context().Err() }, + InsecureSkipVerify: true, }, } + tr := &Transport{ + Options: opt, + } defer tr.CloseIdleConnections() req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil) if err != nil { @@ -99,24 +100,25 @@ func TestDialRaceResumesDial(t *testing.T) { serverTLSConfigFunc, ) defer ts.Close() - tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: &tls.Config{ - GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { - select { - case <-blockCh: - // If we already errored, return without error. - return &tls.Certificate{}, nil - default: - } - close(blockCh) - <-cri.Context().Done() - return nil, cri.Context().Err() - }, - InsecureSkipVerify: true, + opt := &transport.Options{ + TLSClientConfig: &tls.Config{ + GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + select { + case <-blockCh: + // If we already errored, return without error. + return &tls.Certificate{}, nil + default: + } + close(blockCh) + <-cri.Context().Done() + return nil, cri.Context().Err() }, + InsecureSkipVerify: true, }, } + tr := &Transport{ + Options: opt, + } defer tr.CloseIdleConnections() req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil) if err != nil { diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go index 5fce877d..a4739f10 100644 --- a/internal/http2/transport_test.go +++ b/internal/http2/transport_test.go @@ -15,6 +15,7 @@ import ( "fmt" "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/tests" + "github.com/imroc/req/v3/internal/transport" "io" "io/ioutil" "log" @@ -60,7 +61,10 @@ func TestTransportExternal(t *testing.T) { t.Skip("skipping external network test") } req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil) - rt := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + opt := &transport.Options{ + TLSClientConfig: tlsConfigInsecure, + } + rt := &Transport{Options: opt} res, err := rt.RoundTrip(req) if err != nil { t.Fatalf("%v", err) @@ -112,7 +116,7 @@ func TestTransportH2c(t *testing.T) { } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) tr := &Transport{ - Interface: tests.Transport{}, + Options: &transport.Options{}, AllowHTTP: true, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { return net.Dial(network, addr) @@ -144,7 +148,8 @@ func TestTransport(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + opt := &transport.Options{TLSClientConfig: tlsConfigInsecure} + tr := &Transport{Options: opt} defer tr.CloseIdleConnections() u, err := url.Parse(st.ts.URL) @@ -201,9 +206,7 @@ func testTransportReusesConns(t *testing.T, wantSame bool, modReq func(*http.Req }) defer st.Close() tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: tlsConfigInsecure, - }, + Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, } defer tr.CloseIdleConnections() get := func() string { @@ -280,8 +283,8 @@ func testTransportGetGotConnHooks(t *testing.T) { defer st.Close() tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: tlsConfigInsecure, + Options: &transport.Options{ + TLSClientConfig: tlsConfigInsecure, }, } @@ -353,8 +356,8 @@ func TestTransportGroupsPendingDials(t *testing.T) { closeCount int ) tr := &Transport{ - Interface: &tests.Transport{ - TLSClientConfigValue: tlsConfigInsecure, + Options: &transport.Options{ + TLSClientConfig: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { mu.Lock() @@ -427,7 +430,11 @@ func TestTransportAbortClosesPipes(t *testing.T) { errCh := make(chan error) go func() { defer close(errCh) - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{ + Options: &transport.Options{ + TLSClientConfig: tlsConfigInsecure, + }, + } req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { errCh <- err @@ -470,7 +477,11 @@ func TestTransportPath(t *testing.T) { ) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{ + Options: &transport.Options{ + TLSClientConfig: tlsConfigInsecure, + }, + } defer tr.CloseIdleConnections() const ( path = "/testpath" @@ -584,7 +595,11 @@ func TestTransportBody(t *testing.T) { defer st.Close() for i, tt := range bodyTests { - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{ + Options: &transport.Options{ + TLSClientConfig: tlsConfigInsecure, + }, + } defer tr.CloseIdleConnections() var body io.Reader = strings.NewReader(tt.body) @@ -641,7 +656,7 @@ func TestTransportDialTLSh2(t *testing.T) { ) defer ts.Close() tr := &Transport{ - Interface: tests.Transport{}, + Options: &transport.Options{}, DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) { mu.Lock() didDial = true @@ -714,8 +729,8 @@ func newClientTester(t *testing.T) *clientTester { t: t, } ct.tr = &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: tlsConfigInsecure, + Options: &transport.Options{ + TLSClientConfig: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { dialOnce.Lock() @@ -986,7 +1001,11 @@ func TestTransportFullDuplex(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{ + Options: &transport.Options{ + TLSClientConfig: tlsConfigInsecure, + }, + } defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} @@ -1039,7 +1058,7 @@ func TestTransportConnectRequest(t *testing.T) { t.Fatal(err) } - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} @@ -1654,7 +1673,7 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { ) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() checkRoundTrip := func(req *http.Request, wantErr error, desc string) { @@ -1726,7 +1745,7 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { // Get the ClientConn associated with the request and validate // peerMaxHeaderListSize. addr := authorityAddr(req.URL.Scheme, req.URL.Host) - cc, err := tr.connPool().GetClientConn(req, addr) + cc, err := tr.connPool().GetClientConn(req, addr, true) if err != nil { t.Fatalf("GetClientConn: %v", err) } @@ -1913,7 +1932,7 @@ func TestTransportBodyReadErrorType(t *testing.T) { ) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} @@ -1954,9 +1973,7 @@ func TestTransportDoubleCloseOnWriteError(t *testing.T) { defer st.Close() tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: tlsConfigInsecure, - }, + Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { tc, err := tls.Dial(network, addr, cfg) if err != nil { @@ -1987,9 +2004,9 @@ func TestTransportDisableKeepAlives(t *testing.T) { connClosed := make(chan struct{}) // closed on tls.Conn.Close tr := &Transport{ - Interface: tests.Transport{ - DisableKeepAlivesValue: true, - TLSClientConfigValue: tlsConfigInsecure, + Options: &transport.Options{ + DisableKeepAlives: true, + TLSClientConfig: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { tc, err := tls.Dial(network, addr, cfg) @@ -2033,9 +2050,9 @@ func TestTransportDisableKeepAlives_Concurrency(t *testing.T) { var dials int32 var conns sync.WaitGroup tr := &Transport{ - Interface: tests.Transport{ - DisableKeepAlivesValue: true, - TLSClientConfigValue: tlsConfigInsecure, + Options: &transport.Options{ + DisableKeepAlives: true, + TLSClientConfig: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { tc, err := tls.Dial(network, addr, cfg) @@ -2113,9 +2130,7 @@ func TestTransportResponseHeaderTimeout_Body(t *testing.T) { func testTransportResponseHeaderTimeout(t *testing.T, body bool) { ct := newClientTester(t) - ct.tr.Interface = &tests.Transport{ - ResponseHeaderTimeoutValue: 5 * time.Millisecond, - } + ct.tr.Options.ResponseHeaderTimeout = 5 * time.Millisecond ct.client = func() error { c := &http.Client{Transport: ct.tr} var err error @@ -2176,9 +2191,9 @@ func TestTransportDisableCompression(t *testing.T) { defer st.Close() tr := &Transport{ - Interface: tests.Transport{ - DisableCompressionValue: true, - TLSClientConfigValue: tlsConfigInsecure, + Options: &transport.Options{ + DisableCompression: true, + TLSClientConfig: tlsConfigInsecure, }, } defer tr.CloseIdleConnections() @@ -2206,7 +2221,7 @@ func TestTransportRejectsConnHeaders(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() tests := []struct { @@ -2350,7 +2365,7 @@ func TestTransportRejectsContentLengthWithSign(t *testing.T) { w.Header().Set("Content-Length", tt.cl[0]) }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() req, _ := http.NewRequest("HEAD", st.ts.URL, nil) @@ -2405,7 +2420,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { }, } - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() for i, tt := range testCases { @@ -2510,8 +2525,8 @@ func TestTransportNewTLSConfig(t *testing.T) { } tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: tt.conf, + Options: &transport.Options{ + TLSClientConfig: tt.conf, }, } got := tr.newTLSConfig(tt.host) @@ -2660,7 +2675,7 @@ func TestTransportHandlerBodyClose(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() g0 := runtime.NumGoroutine() @@ -2713,7 +2728,7 @@ func TestTransportFlowControl(t *testing.T) { } }, optOnlyServer) - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { @@ -3177,7 +3192,7 @@ func TestTransportBodyDoubleEndStream(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() for i := 0; i < 2; i++ { @@ -3345,7 +3360,7 @@ func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { func TestClientConnPing(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) @@ -3384,7 +3399,7 @@ func TestTransportCancelDataResponseRace(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} @@ -3421,7 +3436,7 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() req, _ := http.NewRequest("GET", st.ts.URL, nil) @@ -3470,9 +3485,7 @@ func TestTransportPingWriteBlocks(t *testing.T) { ) defer st.Close() tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: tlsConfigInsecure, - }, + Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { s, c := net.Pipe() // unbuffered, unlike a TCP conn go func() { @@ -3624,9 +3637,7 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { defer ln.Close() tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: tlsConfigInsecure, - }, + Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, } tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { dialer.Lock() @@ -3883,9 +3894,7 @@ func TestTransportRequestsLowServerLimit(t *testing.T) { connCount int ) tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: tlsConfigInsecure, - }, + Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { connCountMu.Lock() defer connCountMu.Unlock() @@ -4142,7 +4151,7 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} res, err := c.Get(st.ts.URL) @@ -4233,7 +4242,7 @@ func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { ) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", st.ts.URL, nil) @@ -4277,7 +4286,7 @@ func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() // The request body needs to be big enough to trigger flow control. @@ -4395,7 +4404,7 @@ func testClientConnClose(t *testing.T, closeMode closeMode) { } }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) @@ -4797,7 +4806,7 @@ func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunk }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() req, _ := http.NewRequest("POST", st.ts.URL, body) @@ -4863,7 +4872,7 @@ func (fce *fakeConnErr) Close() error { // issue 39337: close the connection on a failed write func TestTransportNewClientConnCloseOnWriteError(t *testing.T) { - tr := &Transport{Interface: tests.Transport{}} + tr := &Transport{Options: &transport.Options{}} writeErr := errors.New("write error") fakeConn := &fakeConnErr{writeErr: writeErr} _, err := tr.NewClientConn(fakeConn) @@ -4883,7 +4892,7 @@ func TestTransportRoundtripCloseOnWriteError(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) @@ -5014,7 +5023,7 @@ func TestTransportFrameBufferReuse(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() var wg sync.WaitGroup @@ -5106,8 +5115,8 @@ func TestTransportBlockingRequestWrite(t *testing.T) { connc := make(chan *blockingWriteConn, 1) connCount := 0 tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: tlsConfigInsecure, + Options: &transport.Options{ + TLSClientConfig: tlsConfigInsecure, }, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { connCount++ @@ -5200,7 +5209,7 @@ func TestTransportCloseRequestBody(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) @@ -5233,15 +5242,15 @@ func TestTransportCloseRequestBody(t *testing.T) { // collectClientsConnPool is a ClientConnPool that wraps lower and // collects what calls were made on it. type collectClientsConnPool struct { - lower ClientConnPool + ClientConnPool mu sync.Mutex getErrs int got []*ClientConn } -func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { - cc, err := p.lower.GetClientConn(req, addr) +func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) { + cc, err := p.ClientConnPool.GetClientConn(req, addr, dialOnMiss) p.mu.Lock() defer p.mu.Unlock() if err != nil { @@ -5252,14 +5261,10 @@ func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) ( return cc, nil } -func (p *collectClientsConnPool) MarkDead(cc *ClientConn) { - p.lower.MarkDead(cc) -} - func TestTransportRetriesOnStreamProtocolError(t *testing.T) { ct := newClientTester(t) pool := &collectClientsConnPool{ - lower: &clientConnPool{t: ct.tr}, + ClientConnPool: &clientConnPool{t: ct.tr}, } ct.tr.ConnPool = pool @@ -5388,7 +5393,7 @@ func TestClientConnReservations(t *testing.T) { streams: make(map[uint32]*clientStream), maxConcurrentStreams: initialMaxConcurrentStreams, nextStreamID: 1, - t: &Transport{Interface: tests.Transport{}}, + t: &Transport{Options: &transport.Options{}}, } cc.cond = sync.NewCond(&cc.mu) n := 0 @@ -5465,7 +5470,7 @@ func TestTransportContentLengthWithoutBody(t *testing.T) { w.Header().Set("Content-Length", contentLength) }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() for _, test := range []struct { @@ -5520,7 +5525,7 @@ func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() pr, pw := net.Pipe() @@ -5548,7 +5553,7 @@ func TestTransport300ResponseBody(t *testing.T) { }, optOnlyServer) defer st.Close() - tr := &Transport{Interface: tests.Transport{TLSClientConfigValue: tlsConfigInsecure}} + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} defer tr.CloseIdleConnections() pr, pw := net.Pipe() @@ -5579,9 +5584,7 @@ func TestTransportWriteByteTimeout(t *testing.T) { ) defer st.Close() tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: tlsConfigInsecure, - }, + Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { _, c := net.Pipe() return c, nil @@ -5625,9 +5628,7 @@ func TestTransportSlowWrites(t *testing.T) { ) defer st.Close() tr := &Transport{ - Interface: tests.Transport{ - TLSClientConfigValue: tlsConfigInsecure, - }, + Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { cfg.InsecureSkipVerify = true c, err := tls.Dial(network, addr, cfg) diff --git a/internal/http3/request_writer_test.go b/internal/http3/request_writer_test.go index a9cbb69f..345c74cb 100644 --- a/internal/http3/request_writer_test.go +++ b/internal/http3/request_writer_test.go @@ -50,7 +50,7 @@ var _ = Describe("Request Writer", func() { It("writes a GET request", func() { req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false, nil)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) Expect(headerFields).To(HaveKeyWithValue(":method", "GET")) @@ -72,7 +72,7 @@ var _ = Describe("Request Writer", func() { } req.AddCookie(cookie1) req.AddCookie(cookie2) - Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false, nil)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`)) }) @@ -80,7 +80,7 @@ var _ = Describe("Request Writer", func() { It("adds the header for gzip support", func() { req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequestHeader(str, req, true)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, true, nil)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip")) }) @@ -88,7 +88,7 @@ var _ = Describe("Request Writer", func() { It("writes a CONNECT request", func() { req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false, nil)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) @@ -101,7 +101,7 @@ var _ = Describe("Request Writer", func() { req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil) Expect(err).ToNot(HaveOccurred()) req.Proto = "webtransport" - Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false, nil)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go index e70c45a2..82b7e707 100644 --- a/internal/http3/roundtrip_test.go +++ b/internal/http3/roundtrip_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "errors" + "github.com/imroc/req/v3/internal/transport" "io" "net/http" "time" @@ -66,7 +67,7 @@ var _ = Describe("RoundTripper", func() { ) BeforeEach(func() { - rt = &RoundTripper{} + rt = &RoundTripper{Options: &transport.Options{}} var err error req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go index 3b4e68f8..a051b847 100644 --- a/internal/http3/server_test.go +++ b/internal/http3/server_test.go @@ -15,10 +15,10 @@ import ( mockquic "github.com/imroc/req/v3/internal/mocks/quic" "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" "github.com/imroc/req/v3/internal/testdata" "github.com/imroc/req/v3/internal/utils" "github.com/lucas-clemente/quic-go" - "github.com/imroc/req/v3/internal/quicvarint" "github.com/golang/mock/gomock" "github.com/marten-seemann/qpack" @@ -136,7 +136,7 @@ var _ = Describe("Server", func() { str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() rw := newRequestWriter(utils.DefaultLogger) - Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false, nil)).To(Succeed()) return buf.Bytes() } diff --git a/internal/logging/mock_connection_tracer_test.go b/internal/logging/mock_connection_tracer_test.go index a0a398aa..4d628fee 100644 --- a/internal/logging/mock_connection_tracer_test.go +++ b/internal/logging/mock_connection_tracer_test.go @@ -5,6 +5,7 @@ package logging import ( + "github.com/lucas-clemente/quic-go" net "net" reflect "reflect" time "time" diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go deleted file mode 100644 index a4134c67..00000000 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ /dev/null @@ -1,105 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/ackhandler (interfaces: ReceivedPacketHandler) - -// Package mockackhandler is a generated GoMock package. -package mockackhandler - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" - wire "github.com/imroc/req/v3/internal/wire" -) - -// MockReceivedPacketHandler is a mock of ReceivedPacketHandler interface. -type MockReceivedPacketHandler struct { - ctrl *gomock.Controller - recorder *MockReceivedPacketHandlerMockRecorder -} - -// MockReceivedPacketHandlerMockRecorder is the mock recorder for MockReceivedPacketHandler. -type MockReceivedPacketHandlerMockRecorder struct { - mock *MockReceivedPacketHandler -} - -// NewMockReceivedPacketHandler creates a new mock instance. -func NewMockReceivedPacketHandler(ctrl *gomock.Controller) *MockReceivedPacketHandler { - mock := &MockReceivedPacketHandler{ctrl: ctrl} - mock.recorder = &MockReceivedPacketHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecorder { - return m.recorder -} - -// DropPackets mocks base method. -func (m *MockReceivedPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DropPackets", arg0) -} - -// DropPackets indicates an expected call of DropPackets. -func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0) -} - -// GetAckFrame mocks base method. -func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel, arg1 bool) *wire.AckFrame { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAckFrame", arg0, arg1) - ret0, _ := ret[0].(*wire.AckFrame) - return ret0 -} - -// GetAckFrame indicates an expected call of GetAckFrame. -func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0, arg1) -} - -// GetAlarmTimeout mocks base method. -func (m *MockReceivedPacketHandler) GetAlarmTimeout() time.Time { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAlarmTimeout") - ret0, _ := ret[0].(time.Time) - return ret0 -} - -// GetAlarmTimeout indicates an expected call of GetAlarmTimeout. -func (mr *MockReceivedPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAlarmTimeout)) -} - -// IsPotentiallyDuplicate mocks base method. -func (m *MockReceivedPacketHandler) IsPotentiallyDuplicate(arg0 protocol.PacketNumber, arg1 protocol.EncryptionLevel) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsPotentiallyDuplicate", arg0, arg1) - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsPotentiallyDuplicate indicates an expected call of IsPotentiallyDuplicate. -func (mr *MockReceivedPacketHandlerMockRecorder) IsPotentiallyDuplicate(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPotentiallyDuplicate", reflect.TypeOf((*MockReceivedPacketHandler)(nil).IsPotentiallyDuplicate), arg0, arg1) -} - -// ReceivedPacket mocks base method. -func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.ECN, arg2 protocol.EncryptionLevel, arg3 time.Time, arg4 bool) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3, arg4) -} diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go deleted file mode 100644 index 9c41986a..00000000 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ /dev/null @@ -1,240 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/ackhandler (interfaces: SentPacketHandler) - -// Package mockackhandler is a generated GoMock package. -package mockackhandler - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - ackhandler "github.com/imroc/req/v3/internal/ackhandler" - protocol "github.com/imroc/req/v3/internal/protocol" - wire "github.com/imroc/req/v3/internal/wire" -) - -// MockSentPacketHandler is a mock of SentPacketHandler interface. -type MockSentPacketHandler struct { - ctrl *gomock.Controller - recorder *MockSentPacketHandlerMockRecorder -} - -// MockSentPacketHandlerMockRecorder is the mock recorder for MockSentPacketHandler. -type MockSentPacketHandlerMockRecorder struct { - mock *MockSentPacketHandler -} - -// NewMockSentPacketHandler creates a new mock instance. -func NewMockSentPacketHandler(ctrl *gomock.Controller) *MockSentPacketHandler { - mock := &MockSentPacketHandler{ctrl: ctrl} - mock.recorder = &MockSentPacketHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder { - return m.recorder -} - -// DropPackets mocks base method. -func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DropPackets", arg0) -} - -// DropPackets indicates an expected call of DropPackets. -func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) -} - -// GetLossDetectionTimeout mocks base method. -func (m *MockSentPacketHandler) GetLossDetectionTimeout() time.Time { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLossDetectionTimeout") - ret0, _ := ret[0].(time.Time) - return ret0 -} - -// GetLossDetectionTimeout indicates an expected call of GetLossDetectionTimeout. -func (mr *MockSentPacketHandlerMockRecorder) GetLossDetectionTimeout() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLossDetectionTimeout)) -} - -// HasPacingBudget mocks base method. -func (m *MockSentPacketHandler) HasPacingBudget() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasPacingBudget") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasPacingBudget indicates an expected call of HasPacingBudget. -func (mr *MockSentPacketHandlerMockRecorder) HasPacingBudget() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSentPacketHandler)(nil).HasPacingBudget)) -} - -// OnLossDetectionTimeout mocks base method. -func (m *MockSentPacketHandler) OnLossDetectionTimeout() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnLossDetectionTimeout") - ret0, _ := ret[0].(error) - return ret0 -} - -// OnLossDetectionTimeout indicates an expected call of OnLossDetectionTimeout. -func (mr *MockSentPacketHandlerMockRecorder) OnLossDetectionTimeout() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).OnLossDetectionTimeout)) -} - -// PeekPacketNumber mocks base method. -func (m *MockSentPacketHandler) PeekPacketNumber(arg0 protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PeekPacketNumber", arg0) - ret0, _ := ret[0].(protocol.PacketNumber) - ret1, _ := ret[1].(protocol.PacketNumberLen) - return ret0, ret1 -} - -// PeekPacketNumber indicates an expected call of PeekPacketNumber. -func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber), arg0) -} - -// PopPacketNumber mocks base method. -func (m *MockSentPacketHandler) PopPacketNumber(arg0 protocol.EncryptionLevel) protocol.PacketNumber { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PopPacketNumber", arg0) - ret0, _ := ret[0].(protocol.PacketNumber) - return ret0 -} - -// PopPacketNumber indicates an expected call of PopPacketNumber. -func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber), arg0) -} - -// QueueProbePacket mocks base method. -func (m *MockSentPacketHandler) QueueProbePacket(arg0 protocol.EncryptionLevel) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueueProbePacket", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// QueueProbePacket indicates an expected call of QueueProbePacket. -func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket), arg0) -} - -// ReceivedAck mocks base method. -func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.EncryptionLevel, arg2 time.Time) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceivedAck", arg0, arg1, arg2) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReceivedAck indicates an expected call of ReceivedAck. -func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2) -} - -// ReceivedBytes mocks base method. -func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedBytes", arg0) -} - -// ReceivedBytes indicates an expected call of ReceivedBytes. -func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0) -} - -// ResetForRetry mocks base method. -func (m *MockSentPacketHandler) ResetForRetry() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ResetForRetry") - ret0, _ := ret[0].(error) - return ret0 -} - -// ResetForRetry indicates an expected call of ResetForRetry. -func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry)) -} - -// SendMode mocks base method. -func (m *MockSentPacketHandler) SendMode() ackhandler.SendMode { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendMode") - ret0, _ := ret[0].(ackhandler.SendMode) - return ret0 -} - -// SendMode indicates an expected call of SendMode. -func (mr *MockSentPacketHandlerMockRecorder) SendMode() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMode", reflect.TypeOf((*MockSentPacketHandler)(nil).SendMode)) -} - -// SentPacket mocks base method. -func (m *MockSentPacketHandler) SentPacket(arg0 *ackhandler.Packet) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0) -} - -// SetHandshakeConfirmed mocks base method. -func (m *MockSentPacketHandler) SetHandshakeConfirmed() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetHandshakeConfirmed") -} - -// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. -func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeConfirmed() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeConfirmed)) -} - -// SetMaxDatagramSize mocks base method. -func (m *MockSentPacketHandler) SetMaxDatagramSize(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetMaxDatagramSize", arg0) -} - -// SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. -func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSentPacketHandler)(nil).SetMaxDatagramSize), arg0) -} - -// TimeUntilSend mocks base method. -func (m *MockSentPacketHandler) TimeUntilSend() time.Time { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TimeUntilSend") - ret0, _ := ret[0].(time.Time) - return ret0 -} - -// TimeUntilSend indicates an expected call of TimeUntilSend. -func (mr *MockSentPacketHandlerMockRecorder) TimeUntilSend() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSentPacketHandler)(nil).TimeUntilSend)) -} diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index fc1ae8a1..a305fe0a 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -5,6 +5,7 @@ package mocklogging import ( + "github.com/lucas-clemente/quic-go" net "net" reflect "reflect" time "time" diff --git a/internal/mocks/quic/early_conn.go b/internal/mocks/quic/early_conn.go index 4a9ee5ef..9eecbc65 100644 --- a/internal/mocks/quic/early_conn.go +++ b/internal/mocks/quic/early_conn.go @@ -11,7 +11,6 @@ import ( gomock "github.com/golang/mock/gomock" quic "github.com/lucas-clemente/quic-go" - qerr "github.com/imroc/req/v3/internal/qerr" ) // MockEarlyConnection is a mock of EarlyConnection interface. @@ -68,7 +67,7 @@ func (mr *MockEarlyConnectionMockRecorder) AcceptUniStream(arg0 interface{}) *go } // CloseWithError mocks base method. -func (m *MockEarlyConnection) CloseWithError(arg0 qerr.ApplicationErrorCode, arg1 string) error { +func (m *MockEarlyConnection) CloseWithError(arg0 quic.ApplicationErrorCode, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1) ret0, _ := ret[0].(error) diff --git a/internal/mocks/quic/stream.go b/internal/mocks/quic/stream.go index a298ba5e..dbdb3429 100644 --- a/internal/mocks/quic/stream.go +++ b/internal/mocks/quic/stream.go @@ -11,7 +11,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - qerr "github.com/imroc/req/v3/internal/qerr" ) // MockStream is a mock of Stream interface. @@ -38,7 +37,7 @@ func (m *MockStream) EXPECT() *MockStreamMockRecorder { } // CancelRead mocks base method. -func (m *MockStream) CancelRead(arg0 qerr.StreamErrorCode) { +func (m *MockStream) CancelRead(arg0 quic.StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelRead", arg0) } @@ -50,7 +49,7 @@ func (mr *MockStreamMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { } // CancelWrite mocks base method. -func (m *MockStream) CancelWrite(arg0 qerr.StreamErrorCode) { +func (m *MockStream) CancelWrite(arg0 quic.StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelWrite", arg0) } diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go index 33c6598b..33e59f71 100644 --- a/internal/protocol/version_test.go +++ b/internal/protocol/version_test.go @@ -1,12 +1,13 @@ package protocol import ( + "github.com/lucas-clemente/quic-go" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Version", func() { - isReservedVersion := func(v VersionNumber) bool { + isReservedVersion := func(v quic.VersionNumber) bool { return v&0x0f0f0f0f == 0x0a0a0a0a } @@ -31,11 +32,11 @@ var _ = Describe("Version", func() { Expect(Version1.String()).To(Equal("v1")) Expect(Version2.String()).To(Equal("v2")) // check with unsupported version numbers from the wiki - Expect(VersionNumber(0x51303039).String()).To(Equal("gQUIC 9")) - Expect(VersionNumber(0x51303133).String()).To(Equal("gQUIC 13")) - Expect(VersionNumber(0x51303235).String()).To(Equal("gQUIC 25")) - Expect(VersionNumber(0x51303438).String()).To(Equal("gQUIC 48")) - Expect(VersionNumber(0x01234567).String()).To(Equal("0x1234567")) + Expect(quic.VersionNumber(0x51303039).String()).To(Equal("gQUIC 9")) + Expect(quic.VersionNumber(0x51303133).String()).To(Equal("gQUIC 13")) + Expect(quic.VersionNumber(0x51303235).String()).To(Equal("gQUIC 25")) + Expect(quic.VersionNumber(0x51303438).String()).To(Equal("gQUIC 48")) + Expect(quic.VersionNumber(0x01234567).String()).To(Equal("0x1234567")) }) It("recognizes supported versions", func() { @@ -46,45 +47,45 @@ var _ = Describe("Version", func() { Context("highest supported version", func() { It("finds the supported version", func() { - supportedVersions := []VersionNumber{1, 2, 3} - other := []VersionNumber{6, 5, 4, 3} + supportedVersions := []quic.VersionNumber{1, 2, 3} + other := []quic.VersionNumber{6, 5, 4, 3} ver, ok := ChooseSupportedVersion(supportedVersions, other) Expect(ok).To(BeTrue()) - Expect(ver).To(Equal(VersionNumber(3))) + Expect(ver).To(Equal(quic.VersionNumber(3))) }) It("picks the preferred version", func() { - supportedVersions := []VersionNumber{2, 1, 3} - other := []VersionNumber{3, 6, 1, 8, 2, 10} + supportedVersions := []quic.VersionNumber{2, 1, 3} + other := []quic.VersionNumber{3, 6, 1, 8, 2, 10} ver, ok := ChooseSupportedVersion(supportedVersions, other) Expect(ok).To(BeTrue()) - Expect(ver).To(Equal(VersionNumber(2))) + Expect(ver).To(Equal(quic.VersionNumber(2))) }) It("says when no matching version was found", func() { - _, ok := ChooseSupportedVersion([]VersionNumber{1}, []VersionNumber{2}) + _, ok := ChooseSupportedVersion([]quic.VersionNumber{1}, []quic.VersionNumber{2}) Expect(ok).To(BeFalse()) }) It("handles empty inputs", func() { - _, ok := ChooseSupportedVersion([]VersionNumber{102, 101}, []VersionNumber{}) + _, ok := ChooseSupportedVersion([]quic.VersionNumber{102, 101}, []quic.VersionNumber{}) Expect(ok).To(BeFalse()) - _, ok = ChooseSupportedVersion([]VersionNumber{}, []VersionNumber{1, 2}) + _, ok = ChooseSupportedVersion([]quic.VersionNumber{}, []quic.VersionNumber{1, 2}) Expect(ok).To(BeFalse()) - _, ok = ChooseSupportedVersion([]VersionNumber{}, []VersionNumber{}) + _, ok = ChooseSupportedVersion([]quic.VersionNumber{}, []quic.VersionNumber{}) Expect(ok).To(BeFalse()) }) }) Context("reserved versions", func() { It("adds a greased version if passed an empty slice", func() { - greased := GetGreasedVersions([]VersionNumber{}) + greased := GetGreasedVersions([]quic.VersionNumber{}) Expect(greased).To(HaveLen(1)) Expect(isReservedVersion(greased[0])).To(BeTrue()) }) It("creates greased lists of version numbers", func() { - supported := []VersionNumber{10, 18, 29} + supported := []quic.VersionNumber{10, 18, 29} for _, v := range supported { Expect(isReservedVersion(v)).To(BeFalse()) } diff --git a/internal/qerr/errors_test.go b/internal/qerr/errors_test.go index 4376fcc6..81aa2df1 100644 --- a/internal/qerr/errors_test.go +++ b/internal/qerr/errors_test.go @@ -2,6 +2,7 @@ package qerr import ( "errors" + "github.com/lucas-clemente/quic-go" "net" "github.com/imroc/req/v3/internal/protocol" diff --git a/internal/utils/linkedlist/README.md b/internal/utils/linkedlist/README.md deleted file mode 100644 index 15b46dce..00000000 --- a/internal/utils/linkedlist/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Usage - -This is the Go standard library implementation of a linked list -(https://golang.org/src/container/list/list.go), modified such that genny -(https://github.com/cheekybits/genny) can be used to generate a typed linked -list. - -To generate, run -``` -genny -pkg $PACKAGE -in linkedlist.go -out $OUTFILE gen Item=$TYPE -``` diff --git a/internal/utils/linkedlist/linkedlist.go b/internal/utils/linkedlist/linkedlist.go deleted file mode 100644 index 74b815a8..00000000 --- a/internal/utils/linkedlist/linkedlist.go +++ /dev/null @@ -1,218 +0,0 @@ -package linkedlist - -import "github.com/cheekybits/genny/generic" - -// Linked list implementation from the Go standard library. - -// Item is a generic type. -type Item generic.Type - -// ItemElement is an element of a linked list. -type ItemElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *ItemElement - - // The list to which this element belongs. - list *ItemList - - // The value stored with this element. - Value Item -} - -// Next returns the next list element or nil. -func (e *ItemElement) Next() *ItemElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *ItemElement) Prev() *ItemElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// ItemList is a linked list of Items. -type ItemList struct { - root ItemElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *ItemList) Init() *ItemList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewItemList returns an initialized list. -func NewItemList() *ItemList { return new(ItemList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *ItemList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *ItemList) Front() *ItemElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *ItemList) Back() *ItemElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *ItemList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *ItemList) insert(e, at *ItemElement) *ItemElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *ItemList) insertValue(v Item, at *ItemElement) *ItemElement { - return l.insert(&ItemElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. -func (l *ItemList) remove(e *ItemElement) *ItemElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *ItemList) Remove(e *ItemElement) Item { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *ItemList) PushFront(v Item) *ItemElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *ItemList) PushBack(v Item) *ItemElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *ItemList) InsertBefore(v Item, mark *ItemElement) *ItemElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *ItemList) InsertAfter(v Item, mark *ItemElement) *ItemElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *ItemList) MoveToFront(e *ItemElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *ItemList) MoveToBack(e *ItemElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *ItemList) MoveBefore(e, mark *ItemElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *ItemList) MoveAfter(e, mark *ItemElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *ItemList) PushBackList(other *ItemList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *ItemList) PushFrontList(other *ItemList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index f7ecab39..8025571f 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -3,6 +3,7 @@ package wire import ( "bytes" "encoding/binary" + "github.com/lucas-clemente/quic-go" "io" "github.com/imroc/req/v3/internal/protocol" diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 4cc1f68e..6a2a62b9 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -3,6 +3,7 @@ package wire import ( "bytes" "encoding/binary" + "github.com/lucas-clemente/quic-go" "github.com/imroc/req/v3/internal/protocol" . "github.com/onsi/ginkgo" diff --git a/internal/wire/wire_suite_test.go b/internal/wire/wire_suite_test.go index 528f728f..9042747a 100644 --- a/internal/wire/wire_suite_test.go +++ b/internal/wire/wire_suite_test.go @@ -3,9 +3,9 @@ package wire import ( "bytes" "encoding/binary" + "github.com/lucas-clemente/quic-go" "testing" - "github.com/imroc/req/v3/internal/protocol" "github.com/imroc/req/v3/internal/quicvarint" . "github.com/onsi/ginkgo" From 238f9c95bb1cd12266264591eddd970a7011d7a7 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 4 Jul 2022 20:18:32 +0800 Subject: [PATCH 521/843] http3: set default keepalive interval --- internal/http3/client.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/http3/client.go b/internal/http3/client.go index a93dc107..032d0dce 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -17,6 +17,7 @@ import ( "net/http" "strconv" "sync" + "time" ) // MethodGet0RTT allows a GET request to be sent using 0-RTT. @@ -29,8 +30,8 @@ const ( var defaultQuicConfig = &quic.Config{ MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams - // KeepAlivePeriod: 10 * time.Second, - Versions: []quic.VersionNumber{protocol.VersionTLS}, + KeepAlivePeriod: 10 * time.Second, + Versions: []quic.VersionNumber{protocol.VersionTLS}, } type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) From e4bcf0c4833220b5f5bda7157986f535c70d5f4c Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 4 Jul 2022 21:02:02 +0800 Subject: [PATCH 522/843] remove unused code --- internal/handshake/aead.go | 162 --- internal/handshake/aead_test.go | 205 --- internal/handshake/crypto_setup.go | 820 ----------- internal/handshake/crypto_setup_test.go | 864 ----------- internal/handshake/handshake_suite_test.go | 48 - internal/handshake/header_protector.go | 137 -- internal/handshake/hkdf.go | 29 - internal/handshake/hkdf_test.go | 17 - internal/handshake/initial_aead.go | 82 -- internal/handshake/initial_aead_test.go | 220 --- internal/handshake/interface.go | 103 -- .../handshake/mock_handshake_runner_test.go | 84 -- internal/handshake/mockgen.go | 3 - internal/handshake/retry.go | 63 - internal/handshake/retry_test.go | 36 - internal/handshake/session_ticket.go | 48 - internal/handshake/session_ticket_test.go | 54 - internal/handshake/tls_extension_handler.go | 69 - .../handshake/tls_extension_handler_test.go | 211 --- internal/handshake/token_generator.go | 134 -- internal/handshake/token_generator_test.go | 127 -- internal/handshake/token_protector.go | 89 -- internal/handshake/token_protector_test.go | 67 - internal/handshake/updatable_aead.go | 324 ----- internal/handshake/updatable_aead_test.go | 529 ------- internal/http3/server.go | 685 --------- internal/http3/server_test.go | 1289 ----------------- internal/logging/frame.go | 66 - internal/logging/interface.go | 135 -- internal/logging/logging_suite_test.go | 25 - .../logging/mock_connection_tracer_test.go | 352 ----- internal/logging/mock_tracer_test.go | 76 - internal/logging/mockgen.go | 4 - internal/logging/multiplex.go | 219 --- internal/logging/multiplex_test.go | 266 ---- internal/logging/packet_header.go | 27 - internal/logging/packet_header_test.go | 60 - internal/logging/types.go | 94 -- internal/mocks/congestion.go | 192 --- internal/mocks/connection_flow_controller.go | 128 -- internal/mocks/crypto_setup.go | 264 ---- internal/mocks/logging/connection_tracer.go | 353 ----- internal/mocks/logging/tracer.go | 77 - internal/mocks/long_header_opener.go | 76 - internal/mocks/quic/early_listener.go | 80 - internal/mocks/short_header_opener.go | 77 - internal/mocks/short_header_sealer.go | 89 -- internal/mocks/stream_flow_controller.go | 140 -- internal/mocks/tls/client_session_cache.go | 62 - internal/qerr/error_codes.go | 88 -- internal/qerr/errorcodes_test.go | 52 - internal/qerr/errors.go | 125 -- internal/qerr/errors_suite_test.go | 13 - internal/qerr/errors_test.go | 125 -- internal/qtls/go116.go | 81 -- internal/qtls/go117.go | 81 -- internal/qtls/go118.go | 81 -- internal/qtls/go119.go | 15 +- internal/qtls/qtls_suite_test.go | 25 - internal/qtls/qtls_test.go | 17 - internal/wire/ack_frame.go | 252 ---- internal/wire/ack_frame_test.go | 454 ------ internal/wire/ack_range.go | 14 - internal/wire/ack_range_test.go | 13 - internal/wire/connection_close_frame.go | 84 -- internal/wire/connection_close_frame_test.go | 153 -- internal/wire/crypto_frame.go | 103 -- internal/wire/crypto_frame_test.go | 148 -- internal/wire/data_blocked_frame.go | 39 - internal/wire/data_blocked_frame_test.go | 54 - internal/wire/datagram_frame.go | 86 -- internal/wire/datagram_frame_test.go | 154 -- internal/wire/extended_header.go | 250 ---- internal/wire/extended_header_test.go | 481 ------ internal/wire/frame_parser.go | 144 -- internal/wire/frame_parser_test.go | 410 ------ internal/wire/handshake_done_frame.go | 29 - internal/wire/header.go | 275 ---- internal/wire/header_test.go | 584 -------- internal/wire/interface.go | 20 - internal/wire/log.go | 72 - internal/wire/log_test.go | 168 --- internal/wire/max_data_frame.go | 41 - internal/wire/max_data_frame_test.go | 57 - internal/wire/max_stream_data_frame.go | 47 - internal/wire/max_stream_data_frame_test.go | 63 - internal/wire/max_streams_frame.go | 56 - internal/wire/max_streams_frame_test.go | 107 -- internal/wire/new_connection_id_frame.go | 81 -- internal/wire/new_connection_id_frame_test.go | 104 -- internal/wire/new_token_frame.go | 49 - internal/wire/new_token_frame_test.go | 66 - internal/wire/path_challenge_frame.go | 39 - internal/wire/path_challenge_frame_test.go | 48 - internal/wire/path_response_frame.go | 39 - internal/wire/path_response_frame_test.go | 47 - internal/wire/ping_frame.go | 28 - internal/wire/ping_frame_test.go | 39 - internal/wire/pool.go | 33 - internal/wire/pool_test.go | 24 - internal/wire/reset_stream_frame.go | 59 - internal/wire/reset_stream_frame_test.go | 70 - internal/wire/retire_connection_id_frame.go | 37 - .../wire/retire_connection_id_frame_test.go | 53 - internal/wire/stop_sending_frame.go | 49 - internal/wire/stop_sending_frame_test.go | 63 - internal/wire/stream_data_blocked_frame.go | 47 - .../wire/stream_data_blocked_frame_test.go | 63 - internal/wire/stream_frame.go | 190 --- internal/wire/stream_frame_test.go | 443 ------ internal/wire/streams_blocked_frame.go | 56 - internal/wire/streams_blocked_frame_test.go | 108 -- internal/wire/transport_parameter_test.go | 612 -------- internal/wire/transport_parameters.go | 476 ------ internal/wire/version_negotiation.go | 55 - internal/wire/version_negotiation_test.go | 84 -- internal/wire/wire_suite_test.go | 31 - 117 files changed, 14 insertions(+), 17501 deletions(-) delete mode 100644 internal/handshake/aead.go delete mode 100644 internal/handshake/aead_test.go delete mode 100644 internal/handshake/crypto_setup.go delete mode 100644 internal/handshake/crypto_setup_test.go delete mode 100644 internal/handshake/handshake_suite_test.go delete mode 100644 internal/handshake/header_protector.go delete mode 100644 internal/handshake/hkdf.go delete mode 100644 internal/handshake/hkdf_test.go delete mode 100644 internal/handshake/initial_aead.go delete mode 100644 internal/handshake/initial_aead_test.go delete mode 100644 internal/handshake/interface.go delete mode 100644 internal/handshake/mock_handshake_runner_test.go delete mode 100644 internal/handshake/mockgen.go delete mode 100644 internal/handshake/retry.go delete mode 100644 internal/handshake/retry_test.go delete mode 100644 internal/handshake/session_ticket.go delete mode 100644 internal/handshake/session_ticket_test.go delete mode 100644 internal/handshake/tls_extension_handler.go delete mode 100644 internal/handshake/tls_extension_handler_test.go delete mode 100644 internal/handshake/token_generator.go delete mode 100644 internal/handshake/token_generator_test.go delete mode 100644 internal/handshake/token_protector.go delete mode 100644 internal/handshake/token_protector_test.go delete mode 100644 internal/handshake/updatable_aead.go delete mode 100644 internal/handshake/updatable_aead_test.go delete mode 100644 internal/http3/server_test.go delete mode 100644 internal/logging/frame.go delete mode 100644 internal/logging/interface.go delete mode 100644 internal/logging/logging_suite_test.go delete mode 100644 internal/logging/mock_connection_tracer_test.go delete mode 100644 internal/logging/mock_tracer_test.go delete mode 100644 internal/logging/mockgen.go delete mode 100644 internal/logging/multiplex.go delete mode 100644 internal/logging/multiplex_test.go delete mode 100644 internal/logging/packet_header.go delete mode 100644 internal/logging/packet_header_test.go delete mode 100644 internal/logging/types.go delete mode 100644 internal/mocks/congestion.go delete mode 100644 internal/mocks/connection_flow_controller.go delete mode 100644 internal/mocks/crypto_setup.go delete mode 100644 internal/mocks/logging/connection_tracer.go delete mode 100644 internal/mocks/logging/tracer.go delete mode 100644 internal/mocks/long_header_opener.go delete mode 100644 internal/mocks/quic/early_listener.go delete mode 100644 internal/mocks/short_header_opener.go delete mode 100644 internal/mocks/short_header_sealer.go delete mode 100644 internal/mocks/stream_flow_controller.go delete mode 100644 internal/mocks/tls/client_session_cache.go delete mode 100644 internal/qerr/error_codes.go delete mode 100644 internal/qerr/errorcodes_test.go delete mode 100644 internal/qerr/errors.go delete mode 100644 internal/qerr/errors_suite_test.go delete mode 100644 internal/qerr/errors_test.go delete mode 100644 internal/qtls/qtls_suite_test.go delete mode 100644 internal/qtls/qtls_test.go delete mode 100644 internal/wire/ack_frame.go delete mode 100644 internal/wire/ack_frame_test.go delete mode 100644 internal/wire/ack_range.go delete mode 100644 internal/wire/ack_range_test.go delete mode 100644 internal/wire/connection_close_frame.go delete mode 100644 internal/wire/connection_close_frame_test.go delete mode 100644 internal/wire/crypto_frame.go delete mode 100644 internal/wire/crypto_frame_test.go delete mode 100644 internal/wire/data_blocked_frame.go delete mode 100644 internal/wire/data_blocked_frame_test.go delete mode 100644 internal/wire/datagram_frame.go delete mode 100644 internal/wire/datagram_frame_test.go delete mode 100644 internal/wire/extended_header.go delete mode 100644 internal/wire/extended_header_test.go delete mode 100644 internal/wire/frame_parser.go delete mode 100644 internal/wire/frame_parser_test.go delete mode 100644 internal/wire/handshake_done_frame.go delete mode 100644 internal/wire/header.go delete mode 100644 internal/wire/header_test.go delete mode 100644 internal/wire/interface.go delete mode 100644 internal/wire/log.go delete mode 100644 internal/wire/log_test.go delete mode 100644 internal/wire/max_data_frame.go delete mode 100644 internal/wire/max_data_frame_test.go delete mode 100644 internal/wire/max_stream_data_frame.go delete mode 100644 internal/wire/max_stream_data_frame_test.go delete mode 100644 internal/wire/max_streams_frame.go delete mode 100644 internal/wire/max_streams_frame_test.go delete mode 100644 internal/wire/new_connection_id_frame.go delete mode 100644 internal/wire/new_connection_id_frame_test.go delete mode 100644 internal/wire/new_token_frame.go delete mode 100644 internal/wire/new_token_frame_test.go delete mode 100644 internal/wire/path_challenge_frame.go delete mode 100644 internal/wire/path_challenge_frame_test.go delete mode 100644 internal/wire/path_response_frame.go delete mode 100644 internal/wire/path_response_frame_test.go delete mode 100644 internal/wire/ping_frame.go delete mode 100644 internal/wire/ping_frame_test.go delete mode 100644 internal/wire/pool.go delete mode 100644 internal/wire/pool_test.go delete mode 100644 internal/wire/reset_stream_frame.go delete mode 100644 internal/wire/reset_stream_frame_test.go delete mode 100644 internal/wire/retire_connection_id_frame.go delete mode 100644 internal/wire/retire_connection_id_frame_test.go delete mode 100644 internal/wire/stop_sending_frame.go delete mode 100644 internal/wire/stop_sending_frame_test.go delete mode 100644 internal/wire/stream_data_blocked_frame.go delete mode 100644 internal/wire/stream_data_blocked_frame_test.go delete mode 100644 internal/wire/stream_frame.go delete mode 100644 internal/wire/stream_frame_test.go delete mode 100644 internal/wire/streams_blocked_frame.go delete mode 100644 internal/wire/streams_blocked_frame_test.go delete mode 100644 internal/wire/transport_parameter_test.go delete mode 100644 internal/wire/transport_parameters.go delete mode 100644 internal/wire/version_negotiation.go delete mode 100644 internal/wire/version_negotiation_test.go delete mode 100644 internal/wire/wire_suite_test.go diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go deleted file mode 100644 index 1b8c28f8..00000000 --- a/internal/handshake/aead.go +++ /dev/null @@ -1,162 +0,0 @@ -package handshake - -import ( - "crypto/cipher" - "encoding/binary" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qtls" - "github.com/imroc/req/v3/internal/utils" -) - -func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v quic.VersionNumber) cipher.AEAD { - keyLabel := hkdfLabelKeyV1 - ivLabel := hkdfLabelIVV1 - if v == protocol.Version2 { - keyLabel = hkdfLabelKeyV2 - ivLabel = hkdfLabelIVV2 - } - key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen) - iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen()) - return suite.AEAD(key, iv) -} - -type longHeaderSealer struct { - aead cipher.AEAD - headerProtector headerProtector - - // use a single slice to avoid allocations - nonceBuf []byte -} - -var _ LongHeaderSealer = &longHeaderSealer{} - -func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer { - return &longHeaderSealer{ - aead: aead, - headerProtector: headerProtector, - nonceBuf: make([]byte, aead.NonceSize()), - } -} - -func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { - binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - return s.aead.Seal(dst, s.nonceBuf, src, ad) -} - -func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - s.headerProtector.EncryptHeader(sample, firstByte, pnBytes) -} - -func (s *longHeaderSealer) Overhead() int { - return s.aead.Overhead() -} - -type longHeaderOpener struct { - aead cipher.AEAD - headerProtector headerProtector - highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) - - // use a single slice to avoid allocations - nonceBuf []byte -} - -var _ LongHeaderOpener = &longHeaderOpener{} - -func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener { - return &longHeaderOpener{ - aead: aead, - headerProtector: headerProtector, - nonceBuf: make([]byte, aead.NonceSize()), - } -} - -func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { - return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN) -} - -func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { - binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - dec, err := o.aead.Open(dst, o.nonceBuf, src, ad) - if err == nil { - o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn) - } else { - err = ErrDecryptionFailed - } - return dec, err -} - -func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - o.headerProtector.DecryptHeader(sample, firstByte, pnBytes) -} - -type handshakeSealer struct { - LongHeaderSealer - - dropInitialKeys func() - dropped bool -} - -func newHandshakeSealer( - aead cipher.AEAD, - headerProtector headerProtector, - dropInitialKeys func(), - perspective protocol.Perspective, -) LongHeaderSealer { - sealer := newLongHeaderSealer(aead, headerProtector) - // The client drops Initial keys when sending the first Handshake packet. - if perspective == protocol.PerspectiveServer { - return sealer - } - return &handshakeSealer{ - LongHeaderSealer: sealer, - dropInitialKeys: dropInitialKeys, - } -} - -func (s *handshakeSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { - data := s.LongHeaderSealer.Seal(dst, src, pn, ad) - if !s.dropped { - s.dropInitialKeys() - s.dropped = true - } - return data -} - -type handshakeOpener struct { - LongHeaderOpener - - dropInitialKeys func() - dropped bool -} - -func newHandshakeOpener( - aead cipher.AEAD, - headerProtector headerProtector, - dropInitialKeys func(), - perspective protocol.Perspective, -) LongHeaderOpener { - opener := newLongHeaderOpener(aead, headerProtector) - // The server drops Initial keys when first successfully processing a Handshake packet. - if perspective == protocol.PerspectiveClient { - return opener - } - return &handshakeOpener{ - LongHeaderOpener: opener, - dropInitialKeys: dropInitialKeys, - } -} - -func (o *handshakeOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { - dec, err := o.LongHeaderOpener.Open(dst, src, pn, ad) - if err == nil && !o.dropped { - o.dropInitialKeys() - o.dropped = true - } - return dec, err -} diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go deleted file mode 100644 index 76557ddc..00000000 --- a/internal/handshake/aead_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package handshake - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/tls" - "fmt" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Long Header AEAD", func() { - for _, ver := range []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { - v := ver - - Context(fmt.Sprintf("using version %s", v), func() { - for i := range cipherSuites { - cs := cipherSuites[i] - - Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { - getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { - key := make([]byte, 16) - hpKey := make([]byte, 16) - rand.Read(key) - rand.Read(hpKey) - block, err := aes.NewCipher(key) - Expect(err).ToNot(HaveOccurred()) - aead, err := cipher.NewGCM(block) - Expect(err).ToNot(HaveOccurred()) - - return newLongHeaderSealer(aead, newHeaderProtector(cs, hpKey, true, v)), - newLongHeaderOpener(aead, newHeaderProtector(cs, hpKey, true, v)) - } - - Context("message encryption", func() { - msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad := []byte("Donec in velit neque.") - - It("encrypts and decrypts a message", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - opened, err := opener.Open(nil, encrypted, 0x1337, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - - It("fails to open a message if the associated data is not the same", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("fails to open a message if the packet number is not the same", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x42, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("decodes the packet number", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x1337, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) - }) - - It("ignores packets it can't decrypt for packet number derivation", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted[:len(encrypted)-1], 0x1337, ad) - Expect(err).To(HaveOccurred()) - Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) - }) - }) - - Context("header encryption", func() { - It("encrypts and encrypts the header", func() { - sealer, opener := getSealerAndOpener() - var lastFourBitsDifferent int - for i := 0; i < 100; i++ { - sample := make([]byte, 16) - rand.Read(sample) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sealer.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0xf != 0xb5&0xf { - lastFourBitsDifferent++ - } - Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - } - Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) - }) - - It("encrypts and encrypts the header, for a 0xfff..fff sample", func() { - sealer, opener := getSealerAndOpener() - var lastFourBitsDifferent int - for i := 0; i < 100; i++ { - sample := bytes.Repeat([]byte{0xff}, 16) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sealer.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0xf != 0xb5&0xf { - lastFourBitsDifferent++ - } - Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - } - }) - - It("fails to decrypt the header when using a different sample", func() { - sealer, opener := getSealerAndOpener() - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sample := make([]byte, 16) - rand.Read(sample) - sealer.EncryptHeader(sample, &header[0], header[9:13]) - rand.Read(sample) // use a different sample - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - }) - }) - }) - } - }) - - Describe("Long Header AEAD", func() { - var ( - dropped chan struct{} // use a chan because closing it twice will panic - aead cipher.AEAD - hp headerProtector - ) - dropCb := func() { close(dropped) } - msg := []byte("Lorem ipsum dolor sit amet.") - ad := []byte("Donec in velit neque.") - - BeforeEach(func() { - dropped = make(chan struct{}) - key := make([]byte, 16) - hpKey := make([]byte, 16) - rand.Read(key) - rand.Read(hpKey) - block, err := aes.NewCipher(key) - Expect(err).ToNot(HaveOccurred()) - aead, err = cipher.NewGCM(block) - Expect(err).ToNot(HaveOccurred()) - hp = newHeaderProtector(cipherSuites[0], hpKey, true, protocol.Version1) - }) - - Context("for the server", func() { - It("drops keys when first successfully processing a Handshake packet", func() { - serverOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveServer) - // first try to open an invalid message - _, err := serverOpener.Open(nil, []byte("invalid"), 0, []byte("invalid")) - Expect(err).To(HaveOccurred()) - Expect(dropped).ToNot(BeClosed()) - // then open a valid message - enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 10, ad) - _, err = serverOpener.Open(nil, enc, 10, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(dropped).To(BeClosed()) - // now open the same message again to make sure the callback is only called once - _, err = serverOpener.Open(nil, enc, 10, ad) - Expect(err).ToNot(HaveOccurred()) - }) - - It("doesn't drop keys when sealing a Handshake packet", func() { - serverSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveServer) - serverSealer.Seal(nil, msg, 1, ad) - Expect(dropped).ToNot(BeClosed()) - }) - }) - - Context("for the client", func() { - It("drops keys when first sealing a Handshake packet", func() { - clientSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveClient) - // seal the first message - clientSealer.Seal(nil, msg, 1, ad) - Expect(dropped).To(BeClosed()) - // seal another message to make sure the callback is only called once - clientSealer.Seal(nil, msg, 2, ad) - }) - - It("doesn't drop keys when processing a Handshake packet", func() { - enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 42, ad) - clientOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveClient) - _, err := clientOpener.Open(nil, enc, 42, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(dropped).ToNot(BeClosed()) - }) - }) - }) - } -}) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go deleted file mode 100644 index ed70ec57..00000000 --- a/internal/handshake/crypto_setup.go +++ /dev/null @@ -1,820 +0,0 @@ -package handshake - -import ( - "bytes" - "crypto/tls" - "errors" - "fmt" - "github.com/lucas-clemente/quic-go" - "io" - "net" - "sync" - "time" - - "github.com/imroc/req/v3/internal/logging" - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - "github.com/imroc/req/v3/internal/qtls" - "github.com/imroc/req/v3/internal/utils" - "github.com/imroc/req/v3/internal/wire" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// TLS unexpected_message alert -const alertUnexpectedMessage uint8 = 10 - -type messageType uint8 - -// TLS handshake message types. -const ( - typeClientHello messageType = 1 - typeServerHello messageType = 2 - typeNewSessionTicket messageType = 4 - typeEncryptedExtensions messageType = 8 - typeCertificate messageType = 11 - typeCertificateRequest messageType = 13 - typeCertificateVerify messageType = 15 - typeFinished messageType = 20 -) - -func (m messageType) String() string { - switch m { - case typeClientHello: - return "ClientHello" - case typeServerHello: - return "ServerHello" - case typeNewSessionTicket: - return "NewSessionTicket" - case typeEncryptedExtensions: - return "EncryptedExtensions" - case typeCertificate: - return "Certificate" - case typeCertificateRequest: - return "CertificateRequest" - case typeCertificateVerify: - return "CertificateVerify" - case typeFinished: - return "Finished" - default: - return fmt.Sprintf("unknown message type: %d", m) - } -} - -const clientSessionStateRevision = 3 - -type conn struct { - localAddr, remoteAddr net.Addr - version quic.VersionNumber -} - -var _ ConnWithVersion = &conn{} - -func newConn(local, remote net.Addr, version quic.VersionNumber) ConnWithVersion { - return &conn{ - localAddr: local, - remoteAddr: remote, - version: version, - } -} - -var _ net.Conn = &conn{} - -func (c *conn) Read([]byte) (int, error) { return 0, nil } -func (c *conn) Write([]byte) (int, error) { return 0, nil } -func (c *conn) Close() error { return nil } -func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *conn) LocalAddr() net.Addr { return c.localAddr } -func (c *conn) SetReadDeadline(time.Time) error { return nil } -func (c *conn) SetWriteDeadline(time.Time) error { return nil } -func (c *conn) SetDeadline(time.Time) error { return nil } -func (c *conn) GetQUICVersion() quic.VersionNumber { return c.version } - -type cryptoSetup struct { - tlsConf *tls.Config - extraConf *qtls.ExtraConfig - conn *qtls.Conn - - version quic.VersionNumber - - messageChan chan []byte - isReadingHandshakeMessage chan struct{} - readFirstHandshakeMessage bool - - ourParams *wire.TransportParameters - peerParams *wire.TransportParameters - paramsChan <-chan []byte - - runner handshakeRunner - - alertChan chan uint8 - // handshakeDone is closed as soon as the go routine running qtls.Handshake() returns - handshakeDone chan struct{} - // is closed when Close() is called - closeChan chan struct{} - - zeroRTTParameters *wire.TransportParameters - clientHelloWritten bool - clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written - zeroRTTParametersChan chan<- *wire.TransportParameters - - rttStats *utils.RTTStats - - tracer logging.ConnectionTracer - logger utils.Logger - - perspective protocol.Perspective - - mutex sync.Mutex // protects all members below - - handshakeCompleteTime time.Time - - readEncLevel protocol.EncryptionLevel - writeEncLevel protocol.EncryptionLevel - - zeroRTTOpener LongHeaderOpener // only set for the server - zeroRTTSealer LongHeaderSealer // only set for the client - - initialStream io.Writer - initialOpener LongHeaderOpener - initialSealer LongHeaderSealer - - handshakeStream io.Writer - handshakeOpener LongHeaderOpener - handshakeSealer LongHeaderSealer - - aead *updatableAEAD - has1RTTSealer bool - has1RTTOpener bool -} - -var ( - _ qtls.RecordLayer = &cryptoSetup{} - _ CryptoSetup = &cryptoSetup{} -) - -// NewCryptoSetupClient creates a new crypto setup for the client -func NewCryptoSetupClient( - initialStream io.Writer, - handshakeStream io.Writer, - connID protocol.ConnectionID, - localAddr net.Addr, - remoteAddr net.Addr, - tp *wire.TransportParameters, - runner handshakeRunner, - tlsConf *tls.Config, - enable0RTT bool, - rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, - logger utils.Logger, - version quic.VersionNumber, -) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { - cs, clientHelloWritten := newCryptoSetup( - initialStream, - handshakeStream, - connID, - tp, - runner, - tlsConf, - enable0RTT, - rttStats, - tracer, - logger, - protocol.PerspectiveClient, - version, - ) - cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) - return cs, clientHelloWritten -} - -// NewCryptoSetupServer creates a new crypto setup for the server -func NewCryptoSetupServer( - initialStream io.Writer, - handshakeStream io.Writer, - connID protocol.ConnectionID, - localAddr net.Addr, - remoteAddr net.Addr, - tp *wire.TransportParameters, - runner handshakeRunner, - tlsConf *tls.Config, - enable0RTT bool, - rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, - logger utils.Logger, - version quic.VersionNumber, -) CryptoSetup { - cs, _ := newCryptoSetup( - initialStream, - handshakeStream, - connID, - tp, - runner, - tlsConf, - enable0RTT, - rttStats, - tracer, - logger, - protocol.PerspectiveServer, - version, - ) - cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) - return cs -} - -func newCryptoSetup( - initialStream io.Writer, - handshakeStream io.Writer, - connID protocol.ConnectionID, - tp *wire.TransportParameters, - runner handshakeRunner, - tlsConf *tls.Config, - enable0RTT bool, - rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, - logger utils.Logger, - perspective protocol.Perspective, - version quic.VersionNumber, -) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { - initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) - if tracer != nil { - tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) - tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) - } - extHandler := newExtensionHandler(tp.Marshal(perspective), perspective, version) - zeroRTTParametersChan := make(chan *wire.TransportParameters, 1) - cs := &cryptoSetup{ - tlsConf: tlsConf, - initialStream: initialStream, - initialSealer: initialSealer, - initialOpener: initialOpener, - handshakeStream: handshakeStream, - aead: newUpdatableAEAD(rttStats, tracer, logger, version), - readEncLevel: protocol.EncryptionInitial, - writeEncLevel: protocol.EncryptionInitial, - runner: runner, - ourParams: tp, - paramsChan: extHandler.TransportParameters(), - rttStats: rttStats, - tracer: tracer, - logger: logger, - perspective: perspective, - handshakeDone: make(chan struct{}), - alertChan: make(chan uint8), - clientHelloWrittenChan: make(chan struct{}), - zeroRTTParametersChan: zeroRTTParametersChan, - messageChan: make(chan []byte, 100), - isReadingHandshakeMessage: make(chan struct{}), - closeChan: make(chan struct{}), - version: version, - } - var maxEarlyData uint32 - if enable0RTT { - maxEarlyData = 0xffffffff - } - cs.extraConf = &qtls.ExtraConfig{ - GetExtensions: extHandler.GetExtensions, - ReceivedExtensions: extHandler.ReceivedExtensions, - AlternativeRecordLayer: cs, - EnforceNextProtoSelection: true, - MaxEarlyData: maxEarlyData, - Accept0RTT: cs.accept0RTT, - Rejected0RTT: cs.rejected0RTT, - Enable0RTT: enable0RTT, - GetAppDataForSessionState: cs.marshalDataForSessionState, - SetAppDataFromSessionState: cs.handleDataFromSessionState, - } - return cs, zeroRTTParametersChan -} - -func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { - initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version) - h.initialSealer = initialSealer - h.initialOpener = initialOpener - if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) - h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) - } -} - -func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error { - return h.aead.SetLargestAcked(pn) -} - -func (h *cryptoSetup) RunHandshake() { - // Handle errors that might occur when HandleData() is called. - handshakeComplete := make(chan struct{}) - handshakeErrChan := make(chan error, 1) - go func() { - defer close(h.handshakeDone) - if err := h.conn.Handshake(); err != nil { - handshakeErrChan <- err - return - } - close(handshakeComplete) - }() - - if h.perspective == protocol.PerspectiveClient { - select { - case err := <-handshakeErrChan: - h.onError(0, err.Error()) - return - case <-h.clientHelloWrittenChan: - } - } - - select { - case <-handshakeComplete: // return when the handshake is done - h.mutex.Lock() - h.handshakeCompleteTime = time.Now() - h.mutex.Unlock() - h.runner.OnHandshakeComplete() - case <-h.closeChan: - // wait until the Handshake() go routine has returned - <-h.handshakeDone - case alert := <-h.alertChan: - handshakeErr := <-handshakeErrChan - h.onError(alert, handshakeErr.Error()) - } -} - -func (h *cryptoSetup) onError(alert uint8, message string) { - var err error - if alert == 0 { - err = &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: message} - } else { - err = qerr.NewCryptoError(alert, message) - } - h.runner.OnError(err) -} - -// Close closes the crypto setup. -// It aborts the handshake, if it is still running. -// It must only be called once. -func (h *cryptoSetup) Close() error { - close(h.closeChan) - // wait until qtls.Handshake() actually returned - <-h.handshakeDone - return nil -} - -// handleMessage handles a TLS handshake message. -// It is called by the crypto streams when a new message is available. -// It returns if it is done with messages on the same encryption level. -func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ { - msgType := messageType(data[0]) - h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) - if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { - h.onError(alertUnexpectedMessage, err.Error()) - return false - } - h.messageChan <- data - if encLevel == protocol.Encryption1RTT { - h.handlePostHandshakeMessage() - return false - } -readLoop: - for { - select { - case data := <-h.paramsChan: - if data == nil { - h.onError(0x6d, "missing quic_transport_parameters extension") - } else { - h.handleTransportParameters(data) - } - case <-h.isReadingHandshakeMessage: - break readLoop - case <-h.handshakeDone: - break readLoop - case <-h.closeChan: - break readLoop - } - } - // We're done with the Initial encryption level after processing a ClientHello / ServerHello, - // but only if a handshake opener and sealer was created. - // Otherwise, a HelloRetryRequest was performed. - // We're done with the Handshake encryption level after processing the Finished message. - return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) || - msgType == typeFinished -} - -func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { - var expected protocol.EncryptionLevel - switch msgType { - case typeClientHello, - typeServerHello: - expected = protocol.EncryptionInitial - case typeEncryptedExtensions, - typeCertificate, - typeCertificateRequest, - typeCertificateVerify, - typeFinished: - expected = protocol.EncryptionHandshake - case typeNewSessionTicket: - expected = protocol.Encryption1RTT - default: - return fmt.Errorf("unexpected handshake message: %d", msgType) - } - if encLevel != expected { - return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel) - } - return nil -} - -func (h *cryptoSetup) handleTransportParameters(data []byte) { - var tp wire.TransportParameters - if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { - h.runner.OnError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: err.Error(), - }) - } - h.peerParams = &tp - h.runner.OnReceivedParams(h.peerParams) -} - -// must be called after receiving the transport parameters -func (h *cryptoSetup) marshalDataForSessionState() []byte { - buf := &bytes.Buffer{} - quicvarint.Write(buf, clientSessionStateRevision) - quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds())) - h.peerParams.MarshalForSessionTicket(buf) - return buf.Bytes() -} - -func (h *cryptoSetup) handleDataFromSessionState(data []byte) { - tp, err := h.handleDataFromSessionStateImpl(data) - if err != nil { - h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) - return - } - h.zeroRTTParameters = tp -} - -func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) { - r := bytes.NewReader(data) - ver, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - if ver != clientSessionStateRevision { - return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) - } - rtt, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) - var tp wire.TransportParameters - if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return nil, err - } - return &tp, nil -} - -// only valid for the server -func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { - var appData []byte - // Save transport parameters to the session ticket if we're allowing 0-RTT. - if h.extraConf.MaxEarlyData > 0 { - appData = (&sessionTicket{ - Parameters: h.ourParams, - RTT: h.rttStats.SmoothedRTT(), - }).Marshal() - } - return h.conn.GetSessionTicket(appData) -} - -// accept0RTT is called for the server when receiving the client's session ticket. -// It decides whether to accept 0-RTT. -func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { - var t sessionTicket - if err := t.Unmarshal(sessionTicketData); err != nil { - h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error()) - return false - } - valid := h.ourParams.ValidFor0RTT(t.Parameters) - if valid { - h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) - h.rttStats.SetInitialRTT(t.RTT) - } else { - h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") - } - return valid -} - -// rejected0RTT is called for the client when the server rejects 0-RTT. -func (h *cryptoSetup) rejected0RTT() { - h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") - - h.mutex.Lock() - had0RTTKeys := h.zeroRTTSealer != nil - h.zeroRTTSealer = nil - h.mutex.Unlock() - - if had0RTTKeys { - h.runner.DropKeys(protocol.Encryption0RTT) - } -} - -func (h *cryptoSetup) handlePostHandshakeMessage() { - // make sure the handshake has already completed - <-h.handshakeDone - - done := make(chan struct{}) - defer close(done) - - // h.alertChan is an unbuffered channel. - // If an error occurs during conn.HandlePostHandshakeMessage, - // it will be sent on this channel. - // Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock. - alertChan := make(chan uint8, 1) - go func() { - <-h.isReadingHandshakeMessage - select { - case alert := <-h.alertChan: - alertChan <- alert - case <-done: - } - }() - - if err := h.conn.HandlePostHandshakeMessage(); err != nil { - select { - case <-h.closeChan: - case alert := <-alertChan: - h.onError(alert, err.Error()) - } - } -} - -// ReadHandshakeMessage is called by TLS. -// It blocks until a new handshake message is available. -func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) { - if !h.readFirstHandshakeMessage { - h.readFirstHandshakeMessage = true - } else { - select { - case h.isReadingHandshakeMessage <- struct{}{}: - case <-h.closeChan: - return nil, errors.New("error while handling the handshake message") - } - } - select { - case msg := <-h.messageChan: - return msg, nil - case <-h.closeChan: - return nil, errors.New("error while handling the handshake message") - } -} - -func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { - h.mutex.Lock() - switch encLevel { - case qtls.Encryption0RTT: - if h.perspective == protocol.PerspectiveClient { - panic("Received 0-RTT read key for the client") - } - h.zeroRTTOpener = newLongHeaderOpener( - createAEAD(suite, trafficSecret, h.version), - newHeaderProtector(suite, trafficSecret, true, h.version), - ) - h.mutex.Unlock() - h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) - if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite()) - } - return - case qtls.EncryptionHandshake: - h.readEncLevel = protocol.EncryptionHandshake - h.handshakeOpener = newHandshakeOpener( - createAEAD(suite, trafficSecret, h.version), - newHeaderProtector(suite, trafficSecret, true, h.version), - h.dropInitialKeys, - h.perspective, - ) - h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) - case qtls.EncryptionApplication: - h.readEncLevel = protocol.Encryption1RTT - h.aead.SetReadKey(suite, trafficSecret) - h.has1RTTOpener = true - h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) - default: - panic("unexpected read encryption level") - } - h.mutex.Unlock() - if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite()) - } -} - -func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { - h.mutex.Lock() - switch encLevel { - case qtls.Encryption0RTT: - if h.perspective == protocol.PerspectiveServer { - panic("Received 0-RTT write key for the server") - } - h.zeroRTTSealer = newLongHeaderSealer( - createAEAD(suite, trafficSecret, h.version), - newHeaderProtector(suite, trafficSecret, true, h.version), - ) - h.mutex.Unlock() - h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) - if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) - } - return - case qtls.EncryptionHandshake: - h.writeEncLevel = protocol.EncryptionHandshake - h.handshakeSealer = newHandshakeSealer( - createAEAD(suite, trafficSecret, h.version), - newHeaderProtector(suite, trafficSecret, true, h.version), - h.dropInitialKeys, - h.perspective, - ) - h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) - case qtls.EncryptionApplication: - h.writeEncLevel = protocol.Encryption1RTT - h.aead.SetWriteKey(suite, trafficSecret) - h.has1RTTSealer = true - h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) - if h.zeroRTTSealer != nil { - h.zeroRTTSealer = nil - h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { - h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) - } - } - default: - panic("unexpected write encryption level") - } - h.mutex.Unlock() - if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective) - } -} - -// WriteRecord is called when TLS writes data -func (h *cryptoSetup) WriteRecord(p []byte) (int, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - //nolint:exhaustive // LS records can only be written for Initial and Handshake. - switch h.writeEncLevel { - case protocol.EncryptionInitial: - // assume that the first WriteRecord call contains the ClientHello - n, err := h.initialStream.Write(p) - if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient { - h.clientHelloWritten = true - close(h.clientHelloWrittenChan) - if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil { - h.logger.Debugf("Doing 0-RTT.") - h.zeroRTTParametersChan <- h.zeroRTTParameters - } else { - h.logger.Debugf("Not doing 0-RTT.") - h.zeroRTTParametersChan <- nil - } - } - return n, err - case protocol.EncryptionHandshake: - return h.handshakeStream.Write(p) - default: - panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel)) - } -} - -func (h *cryptoSetup) SendAlert(alert uint8) { - select { - case h.alertChan <- alert: - case <-h.closeChan: - // no need to send an alert when we've already closed - } -} - -// used a callback in the handshakeSealer and handshakeOpener -func (h *cryptoSetup) dropInitialKeys() { - h.mutex.Lock() - h.initialOpener = nil - h.initialSealer = nil - h.mutex.Unlock() - h.runner.DropKeys(protocol.EncryptionInitial) - h.logger.Debugf("Dropping Initial keys.") -} - -func (h *cryptoSetup) SetHandshakeConfirmed() { - h.aead.SetHandshakeConfirmed() - // drop Handshake keys - var dropped bool - h.mutex.Lock() - if h.handshakeOpener != nil { - h.handshakeOpener = nil - h.handshakeSealer = nil - dropped = true - } - h.mutex.Unlock() - if dropped { - h.runner.DropKeys(protocol.EncryptionHandshake) - h.logger.Debugf("Dropping Handshake keys.") - } -} - -func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.initialSealer == nil { - return nil, ErrKeysDropped - } - return h.initialSealer, nil -} - -func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.zeroRTTSealer == nil { - return nil, ErrKeysDropped - } - return h.zeroRTTSealer, nil -} - -func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.handshakeSealer == nil { - if h.initialSealer == nil { - return nil, ErrKeysDropped - } - return nil, ErrKeysNotYetAvailable - } - return h.handshakeSealer, nil -} - -func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if !h.has1RTTSealer { - return nil, ErrKeysNotYetAvailable - } - return h.aead, nil -} - -func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.initialOpener == nil { - return nil, ErrKeysDropped - } - return h.initialOpener, nil -} - -func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.zeroRTTOpener == nil { - if h.initialOpener != nil { - return nil, ErrKeysNotYetAvailable - } - // if the initial opener is also not available, the keys were already dropped - return nil, ErrKeysDropped - } - return h.zeroRTTOpener, nil -} - -func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.handshakeOpener == nil { - if h.initialOpener != nil { - return nil, ErrKeysNotYetAvailable - } - // if the initial opener is also not available, the keys were already dropped - return nil, ErrKeysDropped - } - return h.handshakeOpener, nil -} - -func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { - h.zeroRTTOpener = nil - h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { - h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) - } - } - - if !h.has1RTTOpener { - return nil, ErrKeysNotYetAvailable - } - return h.aead, nil -} - -func (h *cryptoSetup) ConnectionState() ConnectionState { - return qtls.GetConnectionState(h.conn) -} diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go deleted file mode 100644 index 7d592428..00000000 --- a/internal/handshake/crypto_setup_test.go +++ /dev/null @@ -1,864 +0,0 @@ -package handshake - -import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "math/big" - "time" - - mocktls "github.com/imroc/req/v3/internal/mocks/tls" - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - "github.com/imroc/req/v3/internal/testdata" - "github.com/imroc/req/v3/internal/utils" - "github.com/imroc/req/v3/internal/wire" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. - 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, - 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, - 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, -} - -type chunk struct { - data []byte - encLevel protocol.EncryptionLevel -} - -type stream struct { - encLevel protocol.EncryptionLevel - chunkChan chan<- chunk -} - -func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream { - return &stream{ - chunkChan: chunkChan, - encLevel: encLevel, - } -} - -func (s *stream) Write(b []byte) (int, error) { - data := make([]byte, len(b)) - copy(data, b) - select { - case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}: - default: - panic("chunkChan too small") - } - return len(b), nil -} - -var _ = Describe("Crypto Setup TLS", func() { - var clientConf, serverConf *tls.Config - - // unparam incorrectly complains that the first argument is never used. - //nolint:unparam - initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { - chunkChan := make(chan chunk, 100) - initialStream := newStream(chunkChan, protocol.EncryptionInitial) - handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake) - return chunkChan, initialStream, handshakeStream - } - - BeforeEach(func() { - serverConf = testdata.GetTLSConfig() - serverConf.NextProtos = []string{"crypto-setup"} - clientConf = &tls.Config{ - ServerName: "localhost", - RootCAs: testdata.GetRootCA(), - NextProtos: []string{"crypto-setup"}, - } - }) - - It("returns Handshake() when an error occurs in qtls", func() { - sErrChan := make(chan error, 1) - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - _, sInitialStream, sHandshakeStream := initStreams() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - runner, - testdata.GetTLSConfig(), - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "local error: tls: unexpected message", - }))) - close(done) - }() - - fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - handledMessage := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.HandleMessage(fakeCH, protocol.EncryptionInitial) - close(handledMessage) - }() - Eventually(handledMessage).Should(BeClosed()) - Eventually(done).Should(BeClosed()) - }) - - It("handles qtls errors occurring before during ClientHello generation", func() { - sErrChan := make(chan error, 1) - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - _, sInitialStream, sHandshakeStream := initStreams() - tlsConf := testdata.GetTLSConfig() - tlsConf.InsecureSkipVerify = true - tlsConf.NextProtos = []string{""} - cl, _ := NewCryptoSetupClient( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{}, - runner, - tlsConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cl.RunHandshake() - close(done) - }() - - Eventually(done).Should(BeClosed()) - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ - ErrorCode: qerr.InternalError, - ErrorMessage: "tls: invalid NextProtos value", - }))) - }) - - It("errors when a message is received at the wrong encryption level", func() { - sErrChan := make(chan error, 1) - _, sInitialStream, sHandshakeStream := initStreams() - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - runner, - testdata.GetTLSConfig(), - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - close(done) - }() - - fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "expected handshake message ClientHello to have encryption level Initial, has Handshake", - }))) - - // make the go routine return - Expect(server.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("returns Handshake() when handling a message fails", func() { - sErrChan := make(chan error, 1) - _, sInitialStream, sHandshakeStream := initStreams() - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - runner, - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - var err error - Expect(sErrChan).To(Receive(&err)) - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) - close(done) - }() - - fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...) - server.HandleMessage(fakeCH, protocol.EncryptionInitial) // wrong encryption level - Eventually(done).Should(BeClosed()) - }) - - It("returns Handshake() when it is closed", func() { - _, sInitialStream, sHandshakeStream := initStreams() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - NewMockHandshakeRunner(mockCtrl), - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - close(done) - }() - Expect(server.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - Context("doing the handshake", func() { - generateCert := func() tls.Certificate { - priv, err := rsa.GenerateKey(rand.Reader, 2048) - Expect(err).ToNot(HaveOccurred()) - tmpl := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{}, - SignatureAlgorithm: x509.SHA256WithRSA, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour), // valid for an hour - BasicConstraintsValid: true, - } - certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) - Expect(err).ToNot(HaveOccurred()) - return tls.Certificate{ - PrivateKey: priv, - Certificate: [][]byte{certDER}, - } - } - - newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats { - rttStats := &utils.RTTStats{} - rttStats.UpdateRTT(rtt, 0, time.Now()) - ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt)) - return rttStats - } - - handshake := func(client CryptoSetup, cChunkChan <-chan chunk, - server CryptoSetup, sChunkChan <-chan chunk, - ) { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - for { - select { - case c := <-cChunkChan: - msgType := messageType(c.data[0]) - finished := server.HandleMessage(c.data, c.encLevel) - if msgType == typeFinished { - Expect(finished).To(BeTrue()) - } else if msgType == typeClientHello { - // If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys. - _, err := server.GetHandshakeOpener() - Expect(finished).To(Equal(err == nil)) - } else { - Expect(finished).To(BeFalse()) - } - case c := <-sChunkChan: - msgType := messageType(c.data[0]) - finished := client.HandleMessage(c.data, c.encLevel) - if msgType == typeFinished { - Expect(finished).To(BeTrue()) - } else if msgType == typeServerHello { - Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom))) - } else { - Expect(finished).To(BeFalse()) - } - case <-done: // handshake complete - return - } - } - }() - - go func() { - defer GinkgoRecover() - defer close(done) - server.RunHandshake() - ticket, err := server.GetSessionTicket() - Expect(err).ToNot(HaveOccurred()) - if ticket != nil { - client.HandleMessage(ticket, protocol.Encryption1RTT) - } - }() - - client.RunHandshake() - Eventually(done).Should(BeClosed()) - } - - handshakeWithTLSConf := func( - clientConf, serverConf *tls.Config, - clientRTTStats, serverRTTStats *utils.RTTStats, - clientTransportParameters, serverTransportParameters *wire.TransportParameters, - enable0RTT bool, - ) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { - var cHandshakeComplete bool - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cErrChan := make(chan error, 1) - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1) - cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1) - cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1) - client, clientHelloWrittenChan := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - clientTransportParameters, - cRunner, - clientConf, - enable0RTT, - clientRTTStats, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - var sHandshakeComplete bool - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sErrChan := make(chan error, 1) - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1) - sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1) - if serverTransportParameters.StatelessResetToken == nil { - var token protocol.StatelessResetToken - serverTransportParameters.StatelessResetToken = &token - } - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - serverTransportParameters, - sRunner, - serverConf, - enable0RTT, - serverRTTStats, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - handshake(client, cChunkChan, server, sChunkChan) - var cErr, sErr error - select { - case sErr = <-sErrChan: - default: - Expect(sHandshakeComplete).To(BeTrue()) - } - select { - case cErr = <-cErrChan: - default: - Expect(cHandshakeComplete).To(BeTrue()) - } - return clientHelloWrittenChan, client, cErr, server, sErr - } - - It("handshakes", func() { - _, _, clientErr, _, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - &utils.RTTStats{}, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - }) - - It("performs a HelloRetryRequst", func() { - serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} - _, _, clientErr, _, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - &utils.RTTStats{}, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - }) - - It("handshakes with client auth", func() { - clientConf.Certificates = []tls.Certificate{generateCert()} - serverConf.ClientAuth = tls.RequireAnyClientCert - _, _, clientErr, _, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - &utils.RTTStats{}, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - }) - - It("signals when it has written the ClientHello", func() { - runner := NewMockHandshakeRunner(mockCtrl) - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - client, chChan := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{}, - runner, - &tls.Config{InsecureSkipVerify: true}, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - client.RunHandshake() - close(done) - }() - var ch chunk - Eventually(cChunkChan).Should(Receive(&ch)) - Eventually(chChan).Should(Receive(BeNil())) - // make sure the whole ClientHello was written - Expect(len(ch.data)).To(BeNumerically(">=", 4)) - Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) - length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) - Expect(len(ch.data) - 4).To(Equal(length)) - - // make the go routine return - Expect(client.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("receives transport parameters", func() { - var cTransportParametersRcvd, sTransportParametersRcvd *wire.TransportParameters - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cTransportParameters := &wire.TransportParameters{MaxIdleTimeout: 0x42 * time.Second} - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp }) - cRunner.EXPECT().OnHandshakeComplete() - client, _ := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - cTransportParameters, - cRunner, - clientConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - var token protocol.StatelessResetToken - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp }) - sRunner.EXPECT().OnHandshakeComplete() - sTransportParameters := &wire.TransportParameters{ - MaxIdleTimeout: 0x1337 * time.Second, - StatelessResetToken: &token, - } - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - sTransportParameters, - sRunner, - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handshake(client, cChunkChan, server, sChunkChan) - close(done) - }() - Eventually(done).Should(BeClosed()) - Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout)) - Expect(sTransportParametersRcvd).ToNot(BeNil()) - Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout)) - }) - - Context("with session tickets", func() { - It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnHandshakeComplete() - client, _ := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{}, - cRunner, - clientConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnHandshakeComplete() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - sRunner, - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handshake(client, cChunkChan, server, sChunkChan) - close(done) - }() - Eventually(done).Should(BeClosed()) - - // inject an invalid session ticket - cRunner.EXPECT().OnError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake", - }) - b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) - client.HandleMessage(b, protocol.EncryptionHandshake) - }) - - It("errors when handling the NewSessionTicket fails", func() { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnHandshakeComplete() - client, _ := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{}, - cRunner, - clientConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnHandshakeComplete() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - sRunner, - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handshake(client, cChunkChan, server, sChunkChan) - close(done) - }() - Eventually(done).Should(BeClosed()) - - // inject an invalid session ticket - cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) - }) - b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) - client.HandleMessage(b, protocol.Encryption1RTT) - }) - - It("uses session resumption", func() { - csc := mocktls.NewMockClientSessionCache(mockCtrl) - var state *tls.ClientSessionState - receivedSessionTicket := make(chan struct{}) - csc.EXPECT().Get(gomock.Any()) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { - state = css - close(receivedSessionTicket) - }) - clientConf.ClientSessionCache = csc - const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. - clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) - clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - clientOrigRTTStats, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeFalse()) - Expect(client.ConnectionState().DidResume).To(BeFalse()) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) - - csc.EXPECT().Get(gomock.Any()).Return(state, true) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) - clientRTTStats := &utils.RTTStats{} - clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( - clientConf, serverConf, - clientRTTStats, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeTrue()) - Expect(client.ConnectionState().DidResume).To(BeTrue()) - Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) - }) - - It("doesn't use session resumption if the server disabled it", func() { - csc := mocktls.NewMockClientSessionCache(mockCtrl) - var state *tls.ClientSessionState - receivedSessionTicket := make(chan struct{}) - csc.EXPECT().Get(gomock.Any()) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { - state = css - close(receivedSessionTicket) - }) - clientConf.ClientSessionCache = csc - _, client, clientErr, server, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - &utils.RTTStats{}, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeFalse()) - Expect(client.ConnectionState().DidResume).To(BeFalse()) - - serverConf.SessionTicketsDisabled = true - csc.EXPECT().Get(gomock.Any()).Return(state, true) - _, client, clientErr, server, serverErr = handshakeWithTLSConf( - clientConf, serverConf, - &utils.RTTStats{}, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeFalse()) - Expect(client.ConnectionState().DidResume).To(BeFalse()) - }) - - It("uses 0-RTT", func() { - csc := mocktls.NewMockClientSessionCache(mockCtrl) - var state *tls.ClientSessionState - receivedSessionTicket := make(chan struct{}) - csc.EXPECT().Get(gomock.Any()) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { - state = css - close(receivedSessionTicket) - }) - clientConf.ClientSessionCache = csc - const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. - const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. - serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) - clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) - const initialMaxData protocol.ByteCount = 1337 - clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - clientOrigRTTStats, serverOrigRTTStats, - &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, - true, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeFalse()) - Expect(client.ConnectionState().DidResume).To(BeFalse()) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) - - csc.EXPECT().Get(gomock.Any()).Return(state, true) - csc.EXPECT().Put(gomock.Any(), nil) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) - - clientRTTStats := &utils.RTTStats{} - serverRTTStats := &utils.RTTStats{} - clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( - clientConf, serverConf, - clientRTTStats, serverRTTStats, - &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, - true, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) - Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) - - var tp *wire.TransportParameters - Expect(clientHelloWrittenChan).To(Receive(&tp)) - Expect(tp.InitialMaxData).To(Equal(initialMaxData)) - - Expect(server.ConnectionState().DidResume).To(BeTrue()) - Expect(client.ConnectionState().DidResume).To(BeTrue()) - Expect(server.ConnectionState().Used0RTT).To(BeTrue()) - Expect(client.ConnectionState().Used0RTT).To(BeTrue()) - }) - - It("rejects 0-RTT, when the transport parameters changed", func() { - csc := mocktls.NewMockClientSessionCache(mockCtrl) - var state *tls.ClientSessionState - receivedSessionTicket := make(chan struct{}) - csc.EXPECT().Get(gomock.Any()) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { - state = css - close(receivedSessionTicket) - }) - clientConf.ClientSessionCache = csc - const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. - clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) - const initialMaxData protocol.ByteCount = 1337 - clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - clientOrigRTTStats, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, - true, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeFalse()) - Expect(client.ConnectionState().DidResume).To(BeFalse()) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) - - csc.EXPECT().Get(gomock.Any()).Return(state, true) - csc.EXPECT().Put(gomock.Any(), nil) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) - - clientRTTStats := &utils.RTTStats{} - clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( - clientConf, serverConf, - clientRTTStats, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData - 1}, - true, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) - - var tp *wire.TransportParameters - Expect(clientHelloWrittenChan).To(Receive(&tp)) - Expect(tp.InitialMaxData).To(Equal(initialMaxData)) - - Expect(server.ConnectionState().DidResume).To(BeTrue()) - Expect(client.ConnectionState().DidResume).To(BeTrue()) - Expect(server.ConnectionState().Used0RTT).To(BeFalse()) - Expect(client.ConnectionState().Used0RTT).To(BeFalse()) - }) - }) - }) -}) diff --git a/internal/handshake/handshake_suite_test.go b/internal/handshake/handshake_suite_test.go deleted file mode 100644 index 065877e8..00000000 --- a/internal/handshake/handshake_suite_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package handshake - -import ( - "crypto/tls" - "encoding/hex" - "strings" - "testing" - - "github.com/imroc/req/v3/internal/qtls" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestHandshake(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Handshake Suite") -} - -var mockCtrl *gomock.Controller - -var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) -}) - -var _ = AfterEach(func() { - mockCtrl.Finish() -}) - -func splitHexString(s string) (slice []byte) { - for _, ss := range strings.Split(s, " ") { - if ss[0:2] == "0x" { - ss = ss[2:] - } - d, err := hex.DecodeString(ss) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - slice = append(slice, d...) - } - return -} - -var cipherSuites = []*qtls.CipherSuiteTLS13{ - qtls.CipherSuiteTLS13ByID(tls.TLS_AES_128_GCM_SHA256), - qtls.CipherSuiteTLS13ByID(tls.TLS_AES_256_GCM_SHA384), - qtls.CipherSuiteTLS13ByID(tls.TLS_CHACHA20_POLY1305_SHA256), -} diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go deleted file mode 100644 index 8921fb05..00000000 --- a/internal/handshake/header_protector.go +++ /dev/null @@ -1,137 +0,0 @@ -package handshake - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/tls" - "encoding/binary" - "fmt" - "github.com/lucas-clemente/quic-go" - - "golang.org/x/crypto/chacha20" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qtls" -) - -type headerProtector interface { - EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) - DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) -} - -func hkdfHeaderProtectionLabel(v quic.VersionNumber) string { - if v == protocol.Version2 { - return "quicv2 hp" - } - return "quic hp" -} - -func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, v quic.VersionNumber) headerProtector { - hkdfLabel := hkdfHeaderProtectionLabel(v) - switch suite.ID { - case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: - return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) - case tls.TLS_CHACHA20_POLY1305_SHA256: - return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) - default: - panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) - } -} - -type aesHeaderProtector struct { - mask []byte - block cipher.Block - isLongHeader bool -} - -var _ headerProtector = &aesHeaderProtector{} - -func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { - hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) - block, err := aes.NewCipher(hpKey) - if err != nil { - panic(fmt.Sprintf("error creating new AES cipher: %s", err)) - } - return &aesHeaderProtector{ - block: block, - mask: make([]byte, block.BlockSize()), - isLongHeader: isLongHeader, - } -} - -func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - p.apply(sample, firstByte, hdrBytes) -} - -func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - p.apply(sample, firstByte, hdrBytes) -} - -func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { - if len(sample) != len(p.mask) { - panic("invalid sample size") - } - p.block.Encrypt(p.mask, sample) - if p.isLongHeader { - *firstByte ^= p.mask[0] & 0xf - } else { - *firstByte ^= p.mask[0] & 0x1f - } - for i := range hdrBytes { - hdrBytes[i] ^= p.mask[i+1] - } -} - -type chachaHeaderProtector struct { - mask [5]byte - - key [32]byte - isLongHeader bool -} - -var _ headerProtector = &chachaHeaderProtector{} - -func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { - hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) - - p := &chachaHeaderProtector{ - isLongHeader: isLongHeader, - } - copy(p.key[:], hpKey) - return p -} - -func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - p.apply(sample, firstByte, hdrBytes) -} - -func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - p.apply(sample, firstByte, hdrBytes) -} - -func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { - if len(sample) != 16 { - panic("invalid sample size") - } - for i := 0; i < 5; i++ { - p.mask[i] = 0 - } - cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:]) - if err != nil { - panic(err) - } - cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4])) - cipher.XORKeyStream(p.mask[:], p.mask[:]) - p.applyMask(firstByte, hdrBytes) -} - -func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) { - if p.isLongHeader { - *firstByte ^= p.mask[0] & 0xf - } else { - *firstByte ^= p.mask[0] & 0x1f - } - for i := range hdrBytes { - hdrBytes[i] ^= p.mask[i+1] - } -} diff --git a/internal/handshake/hkdf.go b/internal/handshake/hkdf.go deleted file mode 100644 index c4fd86c5..00000000 --- a/internal/handshake/hkdf.go +++ /dev/null @@ -1,29 +0,0 @@ -package handshake - -import ( - "crypto" - "encoding/binary" - - "golang.org/x/crypto/hkdf" -) - -// hkdfExpandLabel HKDF expands a label. -// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the -// hkdfExpandLabel in the standard library. -func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte { - b := make([]byte, 3, 3+6+len(label)+1+len(context)) - binary.BigEndian.PutUint16(b, uint16(length)) - b[2] = uint8(6 + len(label)) - b = append(b, []byte("tls13 ")...) - b = append(b, []byte(label)...) - b = b[:3+6+len(label)+1] - b[3+6+len(label)] = uint8(len(context)) - b = append(b, context...) - - out := make([]byte, length) - n, err := hkdf.Expand(hash.New, secret, b).Read(out) - if err != nil || n != length { - panic("quic: HKDF-Expand-Label invocation failed unexpectedly") - } - return out -} diff --git a/internal/handshake/hkdf_test.go b/internal/handshake/hkdf_test.go deleted file mode 100644 index 16154199..00000000 --- a/internal/handshake/hkdf_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package handshake - -import ( - "crypto" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Initial AEAD using AES-GCM", func() { - // Result generated by running in qtls: - // cipherSuiteTLS13ByID(TLS_AES_128_GCM_SHA256).expandLabel([]byte("secret"), []byte("context"), "label", 42) - It("gets the same results as qtls", func() { - expanded := hkdfExpandLabel(crypto.SHA256, []byte("secret"), []byte("context"), "label", 42) - Expect(expanded).To(Equal([]byte{0x78, 0x87, 0x6a, 0xb5, 0x84, 0xa2, 0x26, 0xb7, 0x8, 0x5a, 0x7b, 0x3a, 0x4c, 0xbb, 0x1e, 0xbc, 0x2f, 0x9b, 0x67, 0xd0, 0x6a, 0xa2, 0x24, 0xb4, 0x7d, 0x29, 0x3c, 0x7a, 0xce, 0xc7, 0xc3, 0x74, 0xcd, 0x59, 0x7a, 0xa8, 0x21, 0x5e, 0xe7, 0xca, 0x1, 0xda})) - }) -}) diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go deleted file mode 100644 index 97a40f74..00000000 --- a/internal/handshake/initial_aead.go +++ /dev/null @@ -1,82 +0,0 @@ -package handshake - -import ( - "crypto" - "crypto/tls" - "github.com/lucas-clemente/quic-go" - - "golang.org/x/crypto/hkdf" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qtls" -) - -var ( - quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99} - quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} - quicSaltV2 = []byte{0xa7, 0x07, 0xc2, 0x03, 0xa5, 0x9b, 0x47, 0x18, 0x4a, 0x1d, 0x62, 0xca, 0x57, 0x04, 0x06, 0xea, 0x7a, 0xe3, 0xe5, 0xd3} -) - -const ( - hkdfLabelKeyV1 = "quic key" - hkdfLabelKeyV2 = "quicv2 key" - hkdfLabelIVV1 = "quic iv" - hkdfLabelIVV2 = "quicv2 iv" -) - -func getSalt(v quic.VersionNumber) []byte { - if v == protocol.Version2 { - return quicSaltV2 - } - if v == protocol.Version1 { - return quicSaltV1 - } - return quicSaltOld -} - -var initialSuite = &qtls.CipherSuiteTLS13{ - ID: tls.TLS_AES_128_GCM_SHA256, - KeyLen: 16, - AEAD: qtls.AEADAESGCMTLS13, - Hash: crypto.SHA256, -} - -// NewInitialAEAD creates a new AEAD for Initial encryption / decryption. -func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v quic.VersionNumber) (LongHeaderSealer, LongHeaderOpener) { - clientSecret, serverSecret := computeSecrets(connID, v) - var mySecret, otherSecret []byte - if pers == protocol.PerspectiveClient { - mySecret = clientSecret - otherSecret = serverSecret - } else { - mySecret = serverSecret - otherSecret = clientSecret - } - myKey, myIV := computeInitialKeyAndIV(mySecret, v) - otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v) - - encrypter := qtls.AEADAESGCMTLS13(myKey, myIV) - decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) - - return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)), - newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v))) -} - -func computeSecrets(connID protocol.ConnectionID, v quic.VersionNumber) (clientSecret, serverSecret []byte) { - initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(v)) - clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) - serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) - return -} - -func computeInitialKeyAndIV(secret []byte, v quic.VersionNumber) (key, iv []byte) { - keyLabel := hkdfLabelKeyV1 - ivLabel := hkdfLabelIVV1 - if v == protocol.Version2 { - keyLabel = hkdfLabelKeyV2 - ivLabel = hkdfLabelIVV2 - } - key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16) - iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12) - return -} diff --git a/internal/handshake/initial_aead_test.go b/internal/handshake/initial_aead_test.go deleted file mode 100644 index 7fd496e6..00000000 --- a/internal/handshake/initial_aead_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package handshake - -import ( - "fmt" - "github.com/lucas-clemente/quic-go" - "math/rand" - - "github.com/imroc/req/v3/internal/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/ginkgo/extensions/table" - . "github.com/onsi/gomega" -) - -var _ = Describe("Initial AEAD using AES-GCM", func() { - It("converts the string representation used in the draft into byte slices", func() { - Expect(splitHexString("0xdeadbeef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - Expect(splitHexString("deadbeef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - Expect(splitHexString("dead beef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - }) - - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) - - DescribeTable("computes the client key and IV", - func(v quic.VersionNumber, expectedClientSecret, expectedKey, expectedIV []byte) { - clientSecret, _ := computeSecrets(connID, v) - Expect(clientSecret).To(Equal(expectedClientSecret)) - key, iv := computeInitialKeyAndIV(clientSecret, v) - Expect(key).To(Equal(expectedKey)) - Expect(iv).To(Equal(expectedIV)) - }, - Entry("draft-29", - protocol.VersionDraft29, - splitHexString("0088119288f1d866733ceeed15ff9d50 902cf82952eee27e9d4d4918ea371d87"), - splitHexString("175257a31eb09dea9366d8bb79ad80ba"), - splitHexString("6b26114b9cba2b63a9e8dd4f"), - ), - Entry("QUIC v1", - protocol.Version1, - splitHexString("c00cf151ca5be075ed0ebfb5c80323c4 2d6b7db67881289af4008f1f6c357aea"), - splitHexString("1f369613dd76d5467730efcbe3b1a22d"), - splitHexString("fa044b2f42a3fd3b46fb255c"), - ), - Entry("QUIC v2", - protocol.Version2, - splitHexString("9fe72e1452e91f551b770005054034e4 7575d4a0fb4c27b7c6cb303a338423ae"), - splitHexString("95df2be2e8d549c82e996fc9339f4563"), - splitHexString("ea5e3c95f933db14b7020ad8"), - ), - ) - - DescribeTable("computes the server key and IV", - func(v quic.VersionNumber, expectedServerSecret, expectedKey, expectedIV []byte) { - _, serverSecret := computeSecrets(connID, v) - Expect(serverSecret).To(Equal(expectedServerSecret)) - key, iv := computeInitialKeyAndIV(serverSecret, v) - Expect(key).To(Equal(expectedKey)) - Expect(iv).To(Equal(expectedIV)) - }, - Entry("draft 29", - protocol.VersionDraft29, - splitHexString("006f881359244dd9ad1acf85f595bad6 7c13f9f5586f5e64e1acae1d9ea8f616"), - splitHexString("149d0b1662ab871fbe63c49b5e655a5d"), - splitHexString("bab2b12a4c76016ace47856d"), - ), - Entry("QUIC v1", - protocol.Version1, - splitHexString("3c199828fd139efd216c155ad844cc81 fb82fa8d7446fa7d78be803acdda951b"), - splitHexString("cf3a5331653c364c88f0f379b6067e37"), - splitHexString("0ac1493ca1905853b0bba03e"), - ), - Entry("QUIC v2", - protocol.Version2, - splitHexString("3c9bf6a9c1c8c71819876967bd8b979e fd98ec665edf27f22c06e9845ba0ae2f"), - splitHexString("15d5b4d9a2b8916aa39b1bfe574d2aad"), - splitHexString("a85e7ac31cd275cbb095c626"), - ), - ) - - DescribeTable("encrypts the client's Initial", - func(v quic.VersionNumber, header, data, expectedSample []byte, expectedHdrFirstByte byte, expectedHdr, expectedPacket []byte) { - sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveClient, v) - data = append(data, make([]byte, 1162-len(data))...) // add PADDING - sealed := sealer.Seal(nil, data, 2, header) - sample := sealed[0:16] - Expect(sample).To(Equal(expectedSample)) - sealer.EncryptHeader(sample, &header[0], header[len(header)-4:]) - Expect(header[0]).To(Equal(expectedHdrFirstByte)) - Expect(header[len(header)-4:]).To(Equal(expectedHdr)) - packet := append(header, sealed...) - Expect(packet).To(Equal(expectedPacket)) - }, - Entry("draft 29", - protocol.VersionDraft29, - splitHexString("c3ff00001d088394c8f03e5157080000449e00000002"), - splitHexString("060040c4010000c003036660261ff947 cea49cce6cfad687f457cf1b14531ba1 4131a0e8f309a1d0b9c4000006130113 031302010000910000000b0009000006 736572766572ff01000100000a001400 12001d00170018001901000101010201 03010400230000003300260024001d00 204cfdfcd178b784bf328cae793b136f 2aedce005ff183d7bb14952072366470 37002b0003020304000d0020001e0403 05030603020308040805080604010501 060102010402050206020202002d0002 0101001c00024001"), - splitHexString("fb66bc5f93032b7ddd89fe0ff15d9c4f"), - byte(0xc5), - splitHexString("4a95245b"), - splitHexString("c5ff00001d088394c8f03e5157080000 449e4a95245bfb66bc5f93032b7ddd89 fe0ff15d9c4f7050fccdb71c1cd80512 d4431643a53aafa1b0b518b44968b18b 8d3e7a4d04c30b3ed9410325b2abb2da fb1c12f8b70479eb8df98abcaf95dd8f 3d1c78660fbc719f88b23c8aef6771f3 d50e10fdfb4c9d92386d44481b6c52d5 9e5538d3d3942de9f13a7f8b702dc317 24180da9df22714d01003fc5e3d165c9 50e630b8540fbd81c9df0ee63f949970 26c4f2e1887a2def79050ac2d86ba318 e0b3adc4c5aa18bcf63c7cf8e85f5692 49813a2236a7e72269447cd1c755e451 f5e77470eb3de64c8849d29282069802 9cfa18e5d66176fe6e5ba4ed18026f90 900a5b4980e2f58e39151d5cd685b109 29636d4f02e7fad2a5a458249f5c0298 a6d53acbe41a7fc83fa7cc01973f7a74 d1237a51974e097636b6203997f921d0 7bc1940a6f2d0de9f5a11432946159ed 6cc21df65c4ddd1115f86427259a196c 7148b25b6478b0dc7766e1c4d1b1f515 9f90eabc61636226244642ee148b464c 9e619ee50a5e3ddc836227cad938987c 4ea3c1fa7c75bbf88d89e9ada642b2b8 8fe8107b7ea375b1b64889a4e9e5c38a 1c896ce275a5658d250e2d76e1ed3a34 ce7e3a3f383d0c996d0bed106c2899ca 6fc263ef0455e74bb6ac1640ea7bfedc 59f03fee0e1725ea150ff4d69a7660c5 542119c71de270ae7c3ecfd1af2c4ce5 51986949cc34a66b3e216bfe18b347e6 c05fd050f85912db303a8f054ec23e38 f44d1c725ab641ae929fecc8e3cefa56 19df4231f5b4c009fa0c0bbc60bc75f7 6d06ef154fc8577077d9d6a1d2bd9bf0 81dc783ece60111bea7da9e5a9748069 d078b2bef48de04cabe3755b197d52b3 2046949ecaa310274b4aac0d008b1948 c1082cdfe2083e386d4fd84c0ed0666d 3ee26c4515c4fee73433ac703b690a9f 7bf278a77486ace44c489a0c7ac8dfe4 d1a58fb3a730b993ff0f0d61b4d89557 831eb4c752ffd39c10f6b9f46d8db278 da624fd800e4af85548a294c1518893a 8778c4f6d6d73c93df200960104e062b 388ea97dcf4016bced7f62b4f062cb6c 04c20693d9a0e3b74ba8fe74cc012378 84f40d765ae56a51688d985cf0ceaef4 3045ed8c3f0c33bced08537f6882613a cd3b08d665fce9dd8aa73171e2d3771a 61dba2790e491d413d93d987e2745af2 9418e428be34941485c93447520ffe23 1da2304d6a0fd5d07d08372202369661 59bef3cf904d722324dd852513df39ae 030d8173908da6364786d3c1bfcb19ea 77a63b25f1e7fc661def480c5d00d444 56269ebd84efd8e3a8b2c257eec76060 682848cbf5194bc99e49ee75e4d0d254 bad4bfd74970c30e44b65511d4ad0e6e c7398e08e01307eeeea14e46ccd87cf3 6b285221254d8fc6a6765c524ded0085 dca5bd688ddf722e2c0faf9d0fb2ce7a 0c3f2cee19ca0ffba461ca8dc5d2c817 8b0762cf67135558494d2a96f1a139f0 edb42d2af89a9c9122b07acbc29e5e72 2df8615c343702491098478a389c9872 a10b0c9875125e257c7bfdf27eef4060 bd3d00f4c14fd3e3496c38d3c5d1a566 8c39350effbc2d16ca17be4ce29f02ed 969504dda2a8c6b9ff919e693ee79e09 089316e7d1d89ec099db3b2b268725d8 88536a4b8bf9aee8fb43e82a4d919d48 43b1ca70a2d8d3f725ead1391377dcc0"), - ), - Entry("QUIC v1", - protocol.Version1, - splitHexString("c300000001088394c8f03e5157080000449e00000002"), - splitHexString("060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), - splitHexString("d1b1c98dd7689fb8ec11d242b123dc9b"), - byte(0xc0), - splitHexString("7b9aec34"), - splitHexString("c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11 d242b123dc9bd8bab936b47d92ec356c 0bab7df5976d27cd449f63300099f399 1c260ec4c60d17b31f8429157bb35a12 82a643a8d2262cad67500cadb8e7378c 8eb7539ec4d4905fed1bee1fc8aafba1 7c750e2c7ace01e6005f80fcb7df6212 30c83711b39343fa028cea7f7fb5ff89 eac2308249a02252155e2347b63d58c5 457afd84d05dfffdb20392844ae81215 4682e9cf012f9021a6f0be17ddd0c208 4dce25ff9b06cde535d0f920a2db1bf3 62c23e596d11a4f5a6cf3948838a3aec 4e15daf8500a6ef69ec4e3feb6b1d98e 610ac8b7ec3faf6ad760b7bad1db4ba3 485e8a94dc250ae3fdb41ed15fb6a8e5 eba0fc3dd60bc8e30c5c4287e53805db 059ae0648db2f64264ed5e39be2e20d8 2df566da8dd5998ccabdae053060ae6c 7b4378e846d29f37ed7b4ea9ec5d82e7 961b7f25a9323851f681d582363aa5f8 9937f5a67258bf63ad6f1a0b1d96dbd4 faddfcefc5266ba6611722395c906556 be52afe3f565636ad1b17d508b73d874 3eeb524be22b3dcbc2c7468d54119c74 68449a13d8e3b95811a198f3491de3e7 fe942b330407abf82a4ed7c1b311663a c69890f4157015853d91e923037c227a 33cdd5ec281ca3f79c44546b9d90ca00 f064c99e3dd97911d39fe9c5d0b23a22 9a234cb36186c4819e8b9c5927726632 291d6a418211cc2962e20fe47feb3edf 330f2c603a9d48c0fcb5699dbfe58964 25c5bac4aee82e57a85aaf4e2513e4f0 5796b07ba2ee47d80506f8d2c25e50fd 14de71e6c418559302f939b0e1abd576 f279c4b2e0feb85c1f28ff18f58891ff ef132eef2fa09346aee33c28eb130ff2 8f5b766953334113211996d20011a198 e3fc433f9f2541010ae17c1bf202580f 6047472fb36857fe843b19f5984009dd c324044e847a4f4a0ab34f719595de37 252d6235365e9b84392b061085349d73 203a4a13e96f5432ec0fd4a1ee65accd d5e3904df54c1da510b0ff20dcc0c77f cb2c0e0eb605cb0504db87632cf3d8b4 dae6e705769d1de354270123cb11450e fc60ac47683d7b8d0f811365565fd98c 4c8eb936bcab8d069fc33bd801b03ade a2e1fbc5aa463d08ca19896d2bf59a07 1b851e6c239052172f296bfb5e724047 90a2181014f3b94a4e97d117b4381303 68cc39dbb2d198065ae3986547926cd2 162f40a29f0c3c8745c0f50fba3852e5 66d44575c29d39a03f0cda721984b6f4 40591f355e12d439ff150aab7613499d bd49adabc8676eef023b15b65bfc5ca0 6948109f23f350db82123535eb8a7433 bdabcb909271a6ecbcb58b936a88cd4e 8f2e6ff5800175f113253d8fa9ca8885 c2f552e657dc603f252e1a8e308f76f0 be79e2fb8f5d5fbbe2e30ecadd220723 c8c0aea8078cdfcb3868263ff8f09400 54da48781893a7e49ad5aff4af300cd8 04a6b6279ab3ff3afb64491c85194aab 760d58a606654f9f4400e8b38591356f bf6425aca26dc85244259ff2b19c41b9 f96f3ca9ec1dde434da7d2d392b905dd f3d1f9af93d1af5950bd493f5aa731b4 056df31bd267b6b90a079831aaf579be 0a39013137aac6d404f518cfd4684064 7e78bfe706ca4cf5e9c5453e9f7cfd2b 8b4c8d169a44e55c88d4a9a7f9474241 e221af44860018ab0856972e194cd934"), - ), - Entry("QUIC v2", - protocol.Version2, - splitHexString("d3709a50c4088394c8f03e5157080000449e00000002"), - splitHexString("060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), - splitHexString("23b8e610589c83c92d0e97eb7a6e5003"), - byte(0xdd), - splitHexString("4391d848"), - splitHexString("dd709a50c4088394c8f03e5157080000 449e4391d84823b8e610589c83c92d0e 97eb7a6e5003f57764c5c7f0095ba54b 90818f1bfeecc1c97c54fc731edbd2a2 44e3b1e639a9bc75ed545b98649343b2 53615ec6b3e4df0fd2e7fe9d691a09e6 a144b436d8a2c088a404262340dfd995 ec3865694e3026ecd8c6d2561a5a3667 2a1005018168c0f081c10e2bf14d550c 977e28bb9a759c57d0f7ffb1cdfb40bd 774dec589657542047dffefa56fc8089 a4d1ef379c81ba3df71a05ddc7928340 775910feb3ce4cbcfd8d253edd05f161 458f9dc44bea017c3117cca7065a315d eda9464e672ec80c3f79ac993437b441 ef74227ecc4dc9d597f66ab0ab8d214b 55840c70349d7616cbe38e5e1d052d07 f1fedb3dd3c4d8ce295724945e67ed2e efcd9fb52472387f318e3d9d233be7df c79d6bf6080dcbbb41feb180d7858849 7c3e439d38c334748d2b56fd19ab364d 057a9bd5a699ae145d7fdbc8f5777518 1b0a97c3bdedc91a555d6c9b8634e106 d8c9ca45a9d5450a7679edc545da9102 5bc93a7cf9a023a066ffadb9717ffaf3 414c3b646b5738b3cc4116502d18d79d 8227436306d9b2b3afc6c785ce3c817f eb703a42b9c83b59f0dcef1245d0b3e4 0299821ec19549ce489714fe2611e72c d882f4f70dce7d3671296fc045af5c9f 630d7b49a3eb821bbca60f1984dce664 91713bfe06001a56f51bb3abe92f7960 547c4d0a70f4a962b3f05dc25a34bbe8 30a7ea4736d3b0161723500d82beda9b e3327af2aa413821ff678b2a876ec4b0 0bb605ffcc3917ffdc279f187daa2fce 8cde121980bba8ec8f44ca562b0f1319 14c901cfbd847408b778e6738c7bb5b1 b3f97d01b0a24dcca40e3bed29411b1b a8f60843c4a241021b23132b9500509b 9a3516d4a9dd41d3bacbcd426b451393 521828afedcf20fa46ac24f44a8e2973 30b16705d5d5f798eff9e9134a065979 87a1db4617caa2d93837730829d4d89e 16413be4d8a8a38a7e6226623b64a820 178ec3a66954e10710e043ae73dd3fb2 715a0525a46343fb7590e5eac7ee55fc 810e0d8b4b8f7be82cd5a214575a1b99 629d47a9b281b61348c8627cab38e2a6 4db6626e97bb8f77bdcb0fee476aedd7 ba8f5441acaab00f4432edab3791047d 9091b2a753f035648431f6d12f7d6a68 1e64c861f4ac911a0f7d6ec0491a78c9 f192f96b3a5e7560a3f056bc1ca85983 67ad6acb6f2e034c7f37beeb9ed470c4 304af0107f0eb919be36a86f68f37fa6 1dae7aff14decd67ec3157a11488a14f ed0142828348f5f608b0fe03e1f3c0af 3acca0ce36852ed42e220ae9abf8f890 6f00f1b86bff8504c8f16c784fd52d25 e013ff4fda903e9e1eb453c1464b1196 6db9b28e8f26a3fc419e6a60a48d4c72 14ee9c6c6a12b68a32cac8f61580c64f 29cb6922408783c6d12e725b014fe485 cd17e484c5952bf99bc94941d4b1919d 04317b8aa1bd3754ecbaa10ec227de85 40695bf2fb8ee56f6dc526ef366625b9 1aa4970b6ffa5c8284b9b5ab852b905f 9d83f5669c0535bc377bcc05ad5e48e2 81ec0e1917ca3c6a471f8da0894bc82a c2a8965405d6eef3b5e293a88fda203f 09bdc72757b107ab14880eaa3ef7045b 580f4821ce6dd325b5a90655d8c5b55f 76fb846279a9b518c5e9b9a21165c509 3ed49baaacadf1f21873266c767f6769"), - ), - ) - - DescribeTable("encrypts the server's Initial", - func(v quic.VersionNumber, header, data, expectedSample, expectedHdr, expectedPacket []byte) { - sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveServer, v) - sealed := sealer.Seal(nil, data, 1, header) - sample := sealed[2 : 2+16] - Expect(sample).To(Equal(expectedSample)) - sealer.EncryptHeader(sample, &header[0], header[len(header)-2:]) - Expect(header).To(Equal(expectedHdr)) - packet := append(header, sealed...) - Expect(packet).To(Equal(expectedPacket)) - }, - Entry("draft 29", - protocol.VersionDraft29, - splitHexString("c1ff00001d0008f067a5502a4262b50040740001"), - splitHexString("0d0000000018410a020000560303eefc e7f7b37ba1d1632e96677825ddf73988 cfc79825df566dc5430b9a045a120013 0100002e00330024001d00209d3c940d 89690b84d08a60993c144eca684d1081 287c834d5311bcf32bb9da1a002b0002 0304"), - splitHexString("823a5d3a1207c86ee49132824f046524"), - splitHexString("caff00001d0008f067a5502a4262b5004074aaf2"), - splitHexString("caff00001d0008f067a5502a4262b500 4074aaf2f007823a5d3a1207c86ee491 32824f0465243d082d868b107a38092b c80528664cbf9456ebf27673fb5fa506 1ab573c9f001b81da028a00d52ab00b1 5bebaa70640e106cf2acd043e9c6b441 1c0a79637134d8993701fe779e58c2fe 753d14b0564021565ea92e57bc6faf56 dfc7a40870e6"), - ), - Entry("QUIC v1", - protocol.Version1, - splitHexString("c1000000010008f067a5502a4262b50040750001"), - splitHexString("02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), - splitHexString("2cd0991cd25b0aac406a5816b6394100"), - splitHexString("cf000000010008f067a5502a4262b5004075c0d9"), - splitHexString("cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a 5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3 dbcba3f6ea46c5b7684df3548e7ddeb9 c3bf9c73cc3f3bded74b562bfb19fb84 022f8ef4cdd93795d77d06edbb7aaf2f 58891850abbdca3d20398c276456cbc4 2158407dd074ee"), - ), - Entry("QUIC v2", - protocol.Version2, - splitHexString("d1709a50c40008f067a5502a4262b50040750001"), - splitHexString("02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), - splitHexString("ebb7972fdce59d50e7e49ff2a7e8de76"), - splitHexString("d0709a50c40008f067a5502a4262b5004075103e"), - splitHexString("d0709a50c40008f067a5502a4262b500 4075103e63b4ebb7972fdce59d50e7e4 9ff2a7e8de76b0cd8c10100a1f13d549 dd6fe801588fb14d279bef8d7c53ef62 66a9a7a1a5f2fa026c236a5bf8df5aa0 f9d74773aeccfffe910b0f76814b5e33 f7b7f8ec278d23fd8c7a9e66856b8bbe 72558135bca27c54d63fcc902253461c fc089d4e6b9b19"), - ), - ) - - for _, ver := range []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { - v := ver - - Context(fmt.Sprintf("using version %s", v), func() { - It("seals and opens", func() { - connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef} - clientSealer, clientOpener := NewInitialAEAD(connectionID, protocol.PerspectiveClient, v) - serverSealer, serverOpener := NewInitialAEAD(connectionID, protocol.PerspectiveServer, v) - - clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) - m, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad")) - Expect(err).ToNot(HaveOccurred()) - Expect(m).To(Equal([]byte("foobar"))) - serverMessage := serverSealer.Seal(nil, []byte("raboof"), 99, []byte("daa")) - m, err = clientOpener.Open(nil, serverMessage, 99, []byte("daa")) - Expect(err).ToNot(HaveOccurred()) - Expect(m).To(Equal([]byte("raboof"))) - }) - - It("doesn't work if initialized with different connection IDs", func() { - c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1} - c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2} - clientSealer, _ := NewInitialAEAD(c1, protocol.PerspectiveClient, v) - _, serverOpener := NewInitialAEAD(c2, protocol.PerspectiveServer, v) - - clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) - _, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("encrypts und decrypts the header", func() { - connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} - clientSealer, clientOpener := NewInitialAEAD(connID, protocol.PerspectiveClient, v) - serverSealer, serverOpener := NewInitialAEAD(connID, protocol.PerspectiveServer, v) - - // the first byte and the last 4 bytes should be encrypted - header := []byte{0x5e, 0, 1, 2, 3, 4, 0xde, 0xad, 0xbe, 0xef} - sample := make([]byte, 16) - rand.Read(sample) - clientSealer.EncryptHeader(sample, &header[0], header[6:10]) - // only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified - Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0))) - Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) - Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - serverOpener.DecryptHeader(sample, &header[0], header[6:10]) - Expect(header[0]).To(Equal(byte(0x5e))) - Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) - Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - - serverSealer.EncryptHeader(sample, &header[0], header[6:10]) - // only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified - Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0))) - Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) - Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - clientOpener.DecryptHeader(sample, &header[0], header[6:10]) - Expect(header[0]).To(Equal(byte(0x5e))) - Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) - Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - }) - }) - } -}) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go deleted file mode 100644 index e17202c3..00000000 --- a/internal/handshake/interface.go +++ /dev/null @@ -1,103 +0,0 @@ -package handshake - -import ( - "errors" - "github.com/lucas-clemente/quic-go" - "io" - "net" - "time" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qtls" - "github.com/imroc/req/v3/internal/wire" -) - -var ( - // ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level, - // but the corresponding opener has not yet been initialized - // This can happen when packets arrive out of order. - ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available") - // ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level, - // but the corresponding keys have already been dropped. - ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped") - // ErrDecryptionFailed is returned when the AEAD fails to open the packet. - ErrDecryptionFailed = errors.New("decryption failed") -) - -// ConnectionState contains information about the state of the connection. -type ConnectionState = qtls.ConnectionState - -type headerDecryptor interface { - DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) -} - -// LongHeaderOpener opens a long header packet -type LongHeaderOpener interface { - headerDecryptor - DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber - Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error) -} - -// ShortHeaderOpener opens a short header packet -type ShortHeaderOpener interface { - headerDecryptor - DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber - Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error) -} - -// LongHeaderSealer seals a long header packet -type LongHeaderSealer interface { - Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte - EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) - Overhead() int -} - -// ShortHeaderSealer seals a short header packet -type ShortHeaderSealer interface { - LongHeaderSealer - KeyPhase() protocol.KeyPhaseBit -} - -// A tlsExtensionHandler sends and received the QUIC TLS extension. -type tlsExtensionHandler interface { - GetExtensions(msgType uint8) []qtls.Extension - ReceivedExtensions(msgType uint8, exts []qtls.Extension) - TransportParameters() <-chan []byte -} - -type handshakeRunner interface { - OnReceivedParams(*wire.TransportParameters) - OnHandshakeComplete() - OnError(error) - DropKeys(protocol.EncryptionLevel) -} - -// CryptoSetup handles the handshake and protecting / unprotecting packets -type CryptoSetup interface { - RunHandshake() - io.Closer - ChangeConnectionID(protocol.ConnectionID) - GetSessionTicket() ([]byte, error) - - HandleMessage([]byte, protocol.EncryptionLevel) bool - SetLargest1RTTAcked(protocol.PacketNumber) error - SetHandshakeConfirmed() - ConnectionState() ConnectionState - - GetInitialOpener() (LongHeaderOpener, error) - GetHandshakeOpener() (LongHeaderOpener, error) - Get0RTTOpener() (LongHeaderOpener, error) - Get1RTTOpener() (ShortHeaderOpener, error) - - GetInitialSealer() (LongHeaderSealer, error) - GetHandshakeSealer() (LongHeaderSealer, error) - Get0RTTSealer() (LongHeaderSealer, error) - Get1RTTSealer() (ShortHeaderSealer, error) -} - -// ConnWithVersion is the connection used in the ClientHelloInfo. -// It can be used to determine the QUIC version in use. -type ConnWithVersion interface { - net.Conn - GetQUICVersion() quic.VersionNumber -} diff --git a/internal/handshake/mock_handshake_runner_test.go b/internal/handshake/mock_handshake_runner_test.go deleted file mode 100644 index 4f25e6a2..00000000 --- a/internal/handshake/mock_handshake_runner_test.go +++ /dev/null @@ -1,84 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: interface.go - -// Package handshake is a generated GoMock package. -package handshake - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" - wire "github.com/imroc/req/v3/internal/wire" -) - -// MockHandshakeRunner is a mock of HandshakeRunner interface. -type MockHandshakeRunner struct { - ctrl *gomock.Controller - recorder *MockHandshakeRunnerMockRecorder -} - -// MockHandshakeRunnerMockRecorder is the mock recorder for MockHandshakeRunner. -type MockHandshakeRunnerMockRecorder struct { - mock *MockHandshakeRunner -} - -// NewMockHandshakeRunner creates a new mock instance. -func NewMockHandshakeRunner(ctrl *gomock.Controller) *MockHandshakeRunner { - mock := &MockHandshakeRunner{ctrl: ctrl} - mock.recorder = &MockHandshakeRunnerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockHandshakeRunner) EXPECT() *MockHandshakeRunnerMockRecorder { - return m.recorder -} - -// DropKeys mocks base method. -func (m *MockHandshakeRunner) DropKeys(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DropKeys", arg0) -} - -// DropKeys indicates an expected call of DropKeys. -func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0) -} - -// OnError mocks base method. -func (m *MockHandshakeRunner) OnError(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnError", arg0) -} - -// OnError indicates an expected call of OnError. -func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0) -} - -// OnHandshakeComplete mocks base method. -func (m *MockHandshakeRunner) OnHandshakeComplete() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnHandshakeComplete") -} - -// OnHandshakeComplete indicates an expected call of OnHandshakeComplete. -func (mr *MockHandshakeRunnerMockRecorder) OnHandshakeComplete() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnHandshakeComplete", reflect.TypeOf((*MockHandshakeRunner)(nil).OnHandshakeComplete)) -} - -// OnReceivedParams mocks base method. -func (m *MockHandshakeRunner) OnReceivedParams(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnReceivedParams", arg0) -} - -// OnReceivedParams indicates an expected call of OnReceivedParams. -func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0) -} diff --git a/internal/handshake/mockgen.go b/internal/handshake/mockgen.go deleted file mode 100644 index 0fc6389b..00000000 --- a/internal/handshake/mockgen.go +++ /dev/null @@ -1,3 +0,0 @@ -package handshake - -//go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/imroc/req/v3/internal/handshake handshakeRunner" diff --git a/internal/handshake/retry.go b/internal/handshake/retry.go deleted file mode 100644 index c942cd47..00000000 --- a/internal/handshake/retry.go +++ /dev/null @@ -1,63 +0,0 @@ -package handshake - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "fmt" - "github.com/lucas-clemente/quic-go" - "sync" - - "github.com/imroc/req/v3/internal/protocol" -) - -var ( - oldRetryAEAD cipher.AEAD // used for QUIC draft versions up to 34 - retryAEAD cipher.AEAD // used for QUIC draft-34 -) - -func init() { - oldRetryAEAD = initAEAD([16]byte{0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1}) - retryAEAD = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}) -} - -func initAEAD(key [16]byte) cipher.AEAD { - aes, err := aes.NewCipher(key[:]) - if err != nil { - panic(err) - } - aead, err := cipher.NewGCM(aes) - if err != nil { - panic(err) - } - return aead -} - -var ( - retryBuf bytes.Buffer - retryMutex sync.Mutex - oldRetryNonce = [12]byte{0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c} - retryNonce = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb} -) - -// GetRetryIntegrityTag calculates the integrity tag on a Retry packet -func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version quic.VersionNumber) *[16]byte { - retryMutex.Lock() - retryBuf.WriteByte(uint8(origDestConnID.Len())) - retryBuf.Write(origDestConnID.Bytes()) - retryBuf.Write(retry) - - var tag [16]byte - var sealed []byte - if version != protocol.Version1 { - sealed = oldRetryAEAD.Seal(tag[:0], oldRetryNonce[:], nil, retryBuf.Bytes()) - } else { - sealed = retryAEAD.Seal(tag[:0], retryNonce[:], nil, retryBuf.Bytes()) - } - if len(sealed) != 16 { - panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed))) - } - retryBuf.Reset() - retryMutex.Unlock() - return &tag -} diff --git a/internal/handshake/retry_test.go b/internal/handshake/retry_test.go deleted file mode 100644 index 4b74fc41..00000000 --- a/internal/handshake/retry_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package handshake - -import ( - "github.com/imroc/req/v3/internal/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Retry Integrity Check", func() { - It("calculates retry integrity tags", func() { - fooTag := GetRetryIntegrityTag([]byte("foo"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) - barTag := GetRetryIntegrityTag([]byte("bar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) - Expect(fooTag).ToNot(BeNil()) - Expect(barTag).ToNot(BeNil()) - Expect(*fooTag).ToNot(Equal(*barTag)) - }) - - It("includes the original connection ID in the tag calculation", func() { - t1 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.Version1) - t2 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{4, 3, 2, 1}, protocol.Version1) - Expect(*t1).ToNot(Equal(*t2)) - }) - - It("uses the test vector from the draft, for old draft versions", func() { - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) - data := splitHexString("ffff00001d0008f067a5502a4262b574 6f6b656ed16926d81f6f9ca2953a8aa4 575e1e49") - Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.VersionDraft29)[:]).To(Equal(data[len(data)-16:])) - }) - - It("uses the test vector from the draft, for version 1", func() { - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) - data := splitHexString("ff000000010008f067a5502a4262b574 6f6b656e04a265ba2eff4d829058fb3f 0f2496ba") - Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.Version1)[:]).To(Equal(data[len(data)-16:])) - }) -}) diff --git a/internal/handshake/session_ticket.go b/internal/handshake/session_ticket.go deleted file mode 100644 index 1ac9e049..00000000 --- a/internal/handshake/session_ticket.go +++ /dev/null @@ -1,48 +0,0 @@ -package handshake - -import ( - "bytes" - "errors" - "fmt" - "time" - - "github.com/imroc/req/v3/internal/quicvarint" - "github.com/imroc/req/v3/internal/wire" -) - -const sessionTicketRevision = 2 - -type sessionTicket struct { - Parameters *wire.TransportParameters - RTT time.Duration // to be encoded in mus -} - -func (t *sessionTicket) Marshal() []byte { - b := &bytes.Buffer{} - quicvarint.Write(b, sessionTicketRevision) - quicvarint.Write(b, uint64(t.RTT.Microseconds())) - t.Parameters.MarshalForSessionTicket(b) - return b.Bytes() -} - -func (t *sessionTicket) Unmarshal(b []byte) error { - r := bytes.NewReader(b) - rev, err := quicvarint.Read(r) - if err != nil { - return errors.New("failed to read session ticket revision") - } - if rev != sessionTicketRevision { - return fmt.Errorf("unknown session ticket revision: %d", rev) - } - rtt, err := quicvarint.Read(r) - if err != nil { - return errors.New("failed to read RTT") - } - var tp wire.TransportParameters - if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) - } - t.Parameters = &tp - t.RTT = time.Duration(rtt) * time.Microsecond - return nil -} diff --git a/internal/handshake/session_ticket_test.go b/internal/handshake/session_ticket_test.go deleted file mode 100644 index bf5b2407..00000000 --- a/internal/handshake/session_ticket_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package handshake - -import ( - "bytes" - "time" - - "github.com/imroc/req/v3/internal/wire" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Session Ticket", func() { - It("marshals and unmarshals a session ticket", func() { - ticket := &sessionTicket{ - Parameters: &wire.TransportParameters{ - InitialMaxStreamDataBidiLocal: 1, - InitialMaxStreamDataBidiRemote: 2, - }, - RTT: 1337 * time.Microsecond, - } - var t sessionTicket - Expect(t.Unmarshal(ticket.Marshal())).To(Succeed()) - Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1)) - Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2)) - Expect(t.RTT).To(Equal(1337 * time.Microsecond)) - }) - - It("refuses to unmarshal if the ticket is too short for the revision", func() { - Expect((&sessionTicket{}).Unmarshal([]byte{})).To(MatchError("failed to read session ticket revision")) - }) - - It("refuses to unmarshal if the revision doesn't match", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, 1337) - Expect((&sessionTicket{}).Unmarshal(b.Bytes())).To(MatchError("unknown session ticket revision: 1337")) - }) - - It("refuses to unmarshal if the RTT cannot be read", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, sessionTicketRevision) - Expect((&sessionTicket{}).Unmarshal(b.Bytes())).To(MatchError("failed to read RTT")) - }) - - It("refuses to unmarshal if unmarshaling the transport parameters fails", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, sessionTicketRevision) - b.Write([]byte("foobar")) - err := (&sessionTicket{}).Unmarshal(b.Bytes()) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("unmarshaling transport parameters from session ticket failed")) - }) -}) diff --git a/internal/handshake/tls_extension_handler.go b/internal/handshake/tls_extension_handler.go deleted file mode 100644 index 967f4d1f..00000000 --- a/internal/handshake/tls_extension_handler.go +++ /dev/null @@ -1,69 +0,0 @@ -package handshake - -import ( - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qtls" - "github.com/lucas-clemente/quic-go" -) - -const ( - quicTLSExtensionTypeOldDrafts = 0xffa5 - quicTLSExtensionType = 0x39 -) - -type extensionHandler struct { - ourParams []byte - paramsChan chan []byte - - extensionType uint16 - - perspective protocol.Perspective -} - -var _ tlsExtensionHandler = &extensionHandler{} - -// newExtensionHandler creates a new extension handler -func newExtensionHandler(params []byte, pers protocol.Perspective, v quic.VersionNumber) tlsExtensionHandler { - et := uint16(quicTLSExtensionType) - if v != protocol.Version1 { - et = quicTLSExtensionTypeOldDrafts - } - return &extensionHandler{ - ourParams: params, - paramsChan: make(chan []byte), - perspective: pers, - extensionType: et, - } -} - -func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension { - if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) || - (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) { - return nil - } - return []qtls.Extension{{ - Type: h.extensionType, - Data: h.ourParams, - }} -} - -func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) { - if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) || - (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) { - return - } - - var data []byte - for _, ext := range exts { - if ext.Type == h.extensionType { - data = ext.Data - break - } - } - - h.paramsChan <- data -} - -func (h *extensionHandler) TransportParameters() <-chan []byte { - return h.paramsChan -} diff --git a/internal/handshake/tls_extension_handler_test.go b/internal/handshake/tls_extension_handler_test.go deleted file mode 100644 index 212ab32f..00000000 --- a/internal/handshake/tls_extension_handler_test.go +++ /dev/null @@ -1,211 +0,0 @@ -package handshake - -import ( - "fmt" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qtls" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("TLS Extension Handler, for the server", func() { - var ( - handlerServer tlsExtensionHandler - handlerClient tlsExtensionHandler - version quic.VersionNumber - ) - - BeforeEach(func() { - version = protocol.VersionDraft29 - }) - - JustBeforeEach(func() { - handlerServer = newExtensionHandler( - []byte("foobar"), - protocol.PerspectiveServer, - version, - ) - handlerClient = newExtensionHandler( - []byte("raboof"), - protocol.PerspectiveClient, - version, - ) - }) - - Context("for the server", func() { - for _, ver := range []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1} { - v := ver - - Context(fmt.Sprintf("sending, for version %s", v), func() { - var extensionType uint16 - - BeforeEach(func() { - version = v - if v == protocol.VersionDraft29 { - extensionType = quicTLSExtensionTypeOldDrafts - } else { - extensionType = quicTLSExtensionType - } - }) - - It("only adds TransportParameters for the Encrypted Extensions", func() { - // test 2 other handshake types - Expect(handlerServer.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) - Expect(handlerServer.GetExtensions(uint8(typeFinished))).To(BeEmpty()) - }) - - It("adds TransportParameters to the EncryptedExtensions message", func() { - exts := handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) - Expect(exts).To(HaveLen(1)) - Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) - Expect(exts[0].Data).To(Equal([]byte("foobar"))) - }) - }) - } - - Context("receiving", func() { - var chExts []qtls.Extension - - JustBeforeEach(func() { - chExts = handlerClient.GetExtensions(uint8(typeClientHello)) - Expect(chExts).To(HaveLen(1)) - }) - - It("sends the extension on the channel", func() { - go func() { - defer GinkgoRecover() - handlerServer.ReceivedExtensions(uint8(typeClientHello), chExts) - }() - - var data []byte - Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) - Expect(data).To(Equal([]byte("raboof"))) - }) - - It("sends nil on the channel if the extension is missing", func() { - go func() { - defer GinkgoRecover() - handlerServer.ReceivedExtensions(uint8(typeClientHello), nil) - }() - - var data []byte - Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions with different code points", func() { - go func() { - defer GinkgoRecover() - exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} - handlerServer.ReceivedExtensions(uint8(typeClientHello), exts) - }() - - var data []byte - Eventually(handlerServer.TransportParameters()).Should(Receive()) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions that are not sent with the ClientHello", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handlerServer.ReceivedExtensions(uint8(typeFinished), chExts) - close(done) - }() - - Consistently(handlerServer.TransportParameters()).ShouldNot(Receive()) - Eventually(done).Should(BeClosed()) - }) - }) - }) - - Context("for the client", func() { - for _, ver := range []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1} { - v := ver - - Context(fmt.Sprintf("sending, for version %s", v), func() { - var extensionType uint16 - - BeforeEach(func() { - version = v - if v == protocol.VersionDraft29 { - extensionType = quicTLSExtensionTypeOldDrafts - } else { - extensionType = quicTLSExtensionType - } - }) - - It("only adds TransportParameters for the Encrypted Extensions", func() { - // test 2 other handshake types - Expect(handlerClient.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) - Expect(handlerClient.GetExtensions(uint8(typeFinished))).To(BeEmpty()) - }) - - It("adds TransportParameters to the ClientHello message", func() { - exts := handlerClient.GetExtensions(uint8(typeClientHello)) - Expect(exts).To(HaveLen(1)) - Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) - Expect(exts[0].Data).To(Equal([]byte("raboof"))) - }) - }) - } - - Context("receiving", func() { - var chExts []qtls.Extension - - JustBeforeEach(func() { - chExts = handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) - Expect(chExts).To(HaveLen(1)) - }) - - It("sends the extension on the channel", func() { - go func() { - defer GinkgoRecover() - handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), chExts) - }() - - var data []byte - Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) - Expect(data).To(Equal([]byte("foobar"))) - }) - - It("sends nil on the channel if the extension is missing", func() { - go func() { - defer GinkgoRecover() - handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), nil) - }() - - var data []byte - Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions with different code points", func() { - go func() { - defer GinkgoRecover() - exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} - handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), exts) - }() - - var data []byte - Eventually(handlerClient.TransportParameters()).Should(Receive()) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions that are not sent with the EncryptedExtensions", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handlerClient.ReceivedExtensions(uint8(typeFinished), chExts) - close(done) - }() - - Consistently(handlerClient.TransportParameters()).ShouldNot(Receive()) - Eventually(done).Should(BeClosed()) - }) - }) - }) -}) diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go deleted file mode 100644 index 349a1ee5..00000000 --- a/internal/handshake/token_generator.go +++ /dev/null @@ -1,134 +0,0 @@ -package handshake - -import ( - "encoding/asn1" - "fmt" - "io" - "net" - "time" - - "github.com/imroc/req/v3/internal/protocol" -) - -const ( - tokenPrefixIP byte = iota - tokenPrefixString -) - -// A Token is derived from the client address and can be used to verify the ownership of this address. -type Token struct { - IsRetryToken bool - RemoteAddr string - SentTime time.Time - // only set for retry tokens - OriginalDestConnectionID protocol.ConnectionID - RetrySrcConnectionID protocol.ConnectionID -} - -// token is the struct that is used for ASN1 serialization and deserialization -type token struct { - IsRetryToken bool - RemoteAddr []byte - Timestamp int64 - OriginalDestConnectionID []byte - RetrySrcConnectionID []byte -} - -// A TokenGenerator generates tokens -type TokenGenerator struct { - tokenProtector tokenProtector -} - -// NewTokenGenerator initializes a new TookenGenerator -func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) { - tokenProtector, err := newTokenProtector(rand) - if err != nil { - return nil, err - } - return &TokenGenerator{ - tokenProtector: tokenProtector, - }, nil -} - -// NewRetryToken generates a new token for a Retry for a given source address -func (g *TokenGenerator) NewRetryToken( - raddr net.Addr, - origDestConnID protocol.ConnectionID, - retrySrcConnID protocol.ConnectionID, -) ([]byte, error) { - data, err := asn1.Marshal(token{ - IsRetryToken: true, - RemoteAddr: encodeRemoteAddr(raddr), - OriginalDestConnectionID: origDestConnID, - RetrySrcConnectionID: retrySrcConnID, - Timestamp: time.Now().UnixNano(), - }) - if err != nil { - return nil, err - } - return g.tokenProtector.NewToken(data) -} - -// NewToken generates a new token to be sent in a NEW_TOKEN frame -func (g *TokenGenerator) NewToken(raddr net.Addr) ([]byte, error) { - data, err := asn1.Marshal(token{ - RemoteAddr: encodeRemoteAddr(raddr), - Timestamp: time.Now().UnixNano(), - }) - if err != nil { - return nil, err - } - return g.tokenProtector.NewToken(data) -} - -// DecodeToken decodes a token -func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) { - // if the client didn't send any token, DecodeToken will be called with a nil-slice - if len(encrypted) == 0 { - return nil, nil - } - - data, err := g.tokenProtector.DecodeToken(encrypted) - if err != nil { - return nil, err - } - t := &token{} - rest, err := asn1.Unmarshal(data, t) - if err != nil { - return nil, err - } - if len(rest) != 0 { - return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) - } - token := &Token{ - IsRetryToken: t.IsRetryToken, - RemoteAddr: decodeRemoteAddr(t.RemoteAddr), - SentTime: time.Unix(0, t.Timestamp), - } - if t.IsRetryToken { - token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) - token.RetrySrcConnectionID = protocol.ConnectionID(t.RetrySrcConnectionID) - } - return token, nil -} - -// encodeRemoteAddr encodes a remote address such that it can be saved in the token -func encodeRemoteAddr(remoteAddr net.Addr) []byte { - if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { - return append([]byte{tokenPrefixIP}, udpAddr.IP...) - } - return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...) -} - -// decodeRemoteAddr decodes the remote address saved in the token -func decodeRemoteAddr(data []byte) string { - // data will never be empty for a token that we generated. - // Check it to be on the safe side - if len(data) == 0 { - return "" - } - if data[0] == tokenPrefixIP { - return net.IP(data[1:]).String() - } - return string(data[1:]) -} diff --git a/internal/handshake/token_generator_test.go b/internal/handshake/token_generator_test.go deleted file mode 100644 index f2a2c0b3..00000000 --- a/internal/handshake/token_generator_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package handshake - -import ( - "crypto/rand" - "encoding/asn1" - "net" - "time" - - "github.com/imroc/req/v3/internal/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Token Generator", func() { - var tokenGen *TokenGenerator - - BeforeEach(func() { - var err error - tokenGen, err = NewTokenGenerator(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - }) - - It("generates a token", func() { - ip := net.IPv4(127, 0, 0, 1) - token, err := tokenGen.NewRetryToken(&net.UDPAddr{IP: ip, Port: 1337}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(token).ToNot(BeEmpty()) - }) - - It("works with nil tokens", func() { - token, err := tokenGen.DecodeToken(nil) - Expect(err).ToNot(HaveOccurred()) - Expect(token).To(BeNil()) - }) - - It("accepts a valid token", func() { - ip := net.IPv4(192, 168, 0, 1) - tokenEnc, err := tokenGen.NewRetryToken( - &net.UDPAddr{IP: ip, Port: 1337}, - nil, - nil, - ) - Expect(err).ToNot(HaveOccurred()) - token, err := tokenGen.DecodeToken(tokenEnc) - Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal("192.168.0.1")) - Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - Expect(token.OriginalDestConnectionID.Len()).To(BeZero()) - Expect(token.RetrySrcConnectionID.Len()).To(BeZero()) - }) - - It("saves the connection ID", func() { - tokenEnc, err := tokenGen.NewRetryToken( - &net.UDPAddr{}, - protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - ) - Expect(err).ToNot(HaveOccurred()) - token, err := tokenGen.DecodeToken(tokenEnc) - Expect(err).ToNot(HaveOccurred()) - Expect(token.OriginalDestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(token.RetrySrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) - }) - - It("rejects invalid tokens", func() { - _, err := tokenGen.DecodeToken([]byte("invalid token")) - Expect(err).To(HaveOccurred()) - }) - - It("rejects tokens that cannot be decoded", func() { - token, err := tokenGen.tokenProtector.NewToken([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - _, err = tokenGen.DecodeToken(token) - Expect(err).To(HaveOccurred()) - }) - - It("rejects tokens that can be decoded, but have additional payload", func() { - t, err := asn1.Marshal(token{RemoteAddr: []byte("foobar")}) - Expect(err).ToNot(HaveOccurred()) - t = append(t, []byte("rest")...) - enc, err := tokenGen.tokenProtector.NewToken(t) - Expect(err).ToNot(HaveOccurred()) - _, err = tokenGen.DecodeToken(enc) - Expect(err).To(MatchError("rest when unpacking token: 4")) - }) - - // we don't generate tokens that have no data, but we should be able to handle them if we receive one for whatever reason - It("doesn't panic if a tokens has no data", func() { - t, err := asn1.Marshal(token{RemoteAddr: []byte("")}) - Expect(err).ToNot(HaveOccurred()) - enc, err := tokenGen.tokenProtector.NewToken(t) - Expect(err).ToNot(HaveOccurred()) - _, err = tokenGen.DecodeToken(enc) - Expect(err).ToNot(HaveOccurred()) - }) - - It("works with an IPv6 addresses ", func() { - addresses := []string{ - "2001:db8::68", - "2001:0000:4136:e378:8000:63bf:3fff:fdd2", - "2001::1", - "ff01:0:0:0:0:0:0:2", - } - for _, addr := range addresses { - ip := net.ParseIP(addr) - Expect(ip).ToNot(BeNil()) - raddr := &net.UDPAddr{IP: ip, Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) - Expect(err).ToNot(HaveOccurred()) - token, err := tokenGen.DecodeToken(tokenEnc) - Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal(ip.String())) - Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - } - }) - - It("uses the string representation an address that is not a UDP address", func() { - raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) - Expect(err).ToNot(HaveOccurred()) - token, err := tokenGen.DecodeToken(tokenEnc) - Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal("192.168.13.37:1337")) - Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) -}) diff --git a/internal/handshake/token_protector.go b/internal/handshake/token_protector.go deleted file mode 100644 index 650f230b..00000000 --- a/internal/handshake/token_protector.go +++ /dev/null @@ -1,89 +0,0 @@ -package handshake - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/sha256" - "fmt" - "io" - - "golang.org/x/crypto/hkdf" -) - -// TokenProtector is used to create and verify a token -type tokenProtector interface { - // NewToken creates a new token - NewToken([]byte) ([]byte, error) - // DecodeToken decodes a token - DecodeToken([]byte) ([]byte, error) -} - -const ( - tokenSecretSize = 32 - tokenNonceSize = 32 -) - -// tokenProtector is used to create and verify a token -type tokenProtectorImpl struct { - rand io.Reader - secret []byte -} - -// newTokenProtector creates a source for source address tokens -func newTokenProtector(rand io.Reader) (tokenProtector, error) { - secret := make([]byte, tokenSecretSize) - if _, err := rand.Read(secret); err != nil { - return nil, err - } - return &tokenProtectorImpl{ - rand: rand, - secret: secret, - }, nil -} - -// NewToken encodes data into a new token. -func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { - nonce := make([]byte, tokenNonceSize) - if _, err := s.rand.Read(nonce); err != nil { - return nil, err - } - aead, aeadNonce, err := s.createAEAD(nonce) - if err != nil { - return nil, err - } - return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil -} - -// DecodeToken decodes a token. -func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) { - if len(p) < tokenNonceSize { - return nil, fmt.Errorf("token too short: %d", len(p)) - } - nonce := p[:tokenNonceSize] - aead, aeadNonce, err := s.createAEAD(nonce) - if err != nil { - return nil, err - } - return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil) -} - -func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { - h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source")) - key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 - if _, err := io.ReadFull(h, key); err != nil { - return nil, nil, err - } - aeadNonce := make([]byte, 12) - if _, err := io.ReadFull(h, aeadNonce); err != nil { - return nil, nil, err - } - c, err := aes.NewCipher(key) - if err != nil { - return nil, nil, err - } - aead, err := cipher.NewGCM(c) - if err != nil { - return nil, nil, err - } - return aead, aeadNonce, nil -} diff --git a/internal/handshake/token_protector_test.go b/internal/handshake/token_protector_test.go deleted file mode 100644 index 7171e865..00000000 --- a/internal/handshake/token_protector_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package handshake - -import ( - "crypto/rand" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type zeroReader struct{} - -func (r *zeroReader) Read(b []byte) (int, error) { - for i := range b { - b[i] = 0 - } - return len(b), nil -} - -var _ = Describe("Token Protector", func() { - var tp tokenProtector - - BeforeEach(func() { - var err error - tp, err = newTokenProtector(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - }) - - It("uses the random source", func() { - tp1, err := newTokenProtector(&zeroReader{}) - Expect(err).ToNot(HaveOccurred()) - tp2, err := newTokenProtector(&zeroReader{}) - Expect(err).ToNot(HaveOccurred()) - t1, err := tp1.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - t2, err := tp2.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(t1).To(Equal(t2)) - tp3, err := newTokenProtector(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - t3, err := tp3.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(t3).ToNot(Equal(t1)) - }) - - It("encodes and decodes tokens", func() { - token, err := tp.NewToken([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(token).ToNot(ContainSubstring("foobar")) - decoded, err := tp.DecodeToken(token) - Expect(err).ToNot(HaveOccurred()) - Expect(decoded).To(Equal([]byte("foobar"))) - }) - - It("fails deconding invalid tokens", func() { - token, err := tp.NewToken([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - token = token[1:] // remove the first byte - _, err = tp.DecodeToken(token) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("message authentication failed")) - }) - - It("errors when decoding too short tokens", func() { - _, err := tp.DecodeToken([]byte("foobar")) - Expect(err).To(MatchError("token too short: 6")) - }) -}) diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go deleted file mode 100644 index 13ddbc22..00000000 --- a/internal/handshake/updatable_aead.go +++ /dev/null @@ -1,324 +0,0 @@ -package handshake - -import ( - "crypto" - "crypto/cipher" - "crypto/tls" - "encoding/binary" - "fmt" - "github.com/lucas-clemente/quic-go" - "time" - - "github.com/imroc/req/v3/internal/logging" - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - "github.com/imroc/req/v3/internal/qtls" - "github.com/imroc/req/v3/internal/utils" -) - -// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. -// It's a package-level variable to allow modifying it for testing purposes. -var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval - -type updatableAEAD struct { - suite *qtls.CipherSuiteTLS13 - - keyPhase protocol.KeyPhase - largestAcked protocol.PacketNumber - firstPacketNumber protocol.PacketNumber - handshakeConfirmed bool - - keyUpdateInterval uint64 - invalidPacketLimit uint64 - invalidPacketCount uint64 - - // Time when the keys should be dropped. Keys are dropped on the next call to Open(). - prevRcvAEADExpiry time.Time - prevRcvAEAD cipher.AEAD - - firstRcvdWithCurrentKey protocol.PacketNumber - firstSentWithCurrentKey protocol.PacketNumber - highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) - numRcvdWithCurrentKey uint64 - numSentWithCurrentKey uint64 - rcvAEAD cipher.AEAD - sendAEAD cipher.AEAD - // caches cipher.AEAD.Overhead(). This speeds up calls to Overhead(). - aeadOverhead int - - nextRcvAEAD cipher.AEAD - nextSendAEAD cipher.AEAD - nextRcvTrafficSecret []byte - nextSendTrafficSecret []byte - - headerDecrypter headerProtector - headerEncrypter headerProtector - - rttStats *utils.RTTStats - - tracer logging.ConnectionTracer - logger utils.Logger - version quic.VersionNumber - - // use a single slice to avoid allocations - nonceBuf []byte -} - -var ( - _ ShortHeaderOpener = &updatableAEAD{} - _ ShortHeaderSealer = &updatableAEAD{} -) - -func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version quic.VersionNumber) *updatableAEAD { - return &updatableAEAD{ - firstPacketNumber: protocol.InvalidPacketNumber, - largestAcked: protocol.InvalidPacketNumber, - firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, - firstSentWithCurrentKey: protocol.InvalidPacketNumber, - keyUpdateInterval: KeyUpdateInterval, - rttStats: rttStats, - tracer: tracer, - logger: logger, - version: version, - } -} - -func (a *updatableAEAD) rollKeys() { - if a.prevRcvAEAD != nil { - a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) - if a.tracer != nil { - a.tracer.DroppedKey(a.keyPhase - 1) - } - a.prevRcvAEADExpiry = time.Time{} - } - - a.keyPhase++ - a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber - a.firstSentWithCurrentKey = protocol.InvalidPacketNumber - a.numRcvdWithCurrentKey = 0 - a.numSentWithCurrentKey = 0 - a.prevRcvAEAD = a.rcvAEAD - a.rcvAEAD = a.nextRcvAEAD - a.sendAEAD = a.nextSendAEAD - - a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret) - a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret) - a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version) - a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version) -} - -func (a *updatableAEAD) startKeyDropTimer(now time.Time) { - d := 3 * a.rttStats.PTO(true) - a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d) - a.prevRcvAEADExpiry = now.Add(d) -} - -func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte { - return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) -} - -// For the client, this function is called before SetWriteKey. -// For the server, this function is called after SetWriteKey. -func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { - a.rcvAEAD = createAEAD(suite, trafficSecret, a.version) - a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version) - if a.suite == nil { - a.setAEADParameters(a.rcvAEAD, suite) - } - - a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) - a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) -} - -// For the client, this function is called after SetReadKey. -// For the server, this function is called before SetWriteKey. -func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { - a.sendAEAD = createAEAD(suite, trafficSecret, a.version) - a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) - if a.suite == nil { - a.setAEADParameters(a.sendAEAD, suite) - } - - a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) - a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version) -} - -func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) { - a.nonceBuf = make([]byte, aead.NonceSize()) - a.aeadOverhead = aead.Overhead() - a.suite = suite - switch suite.ID { - case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: - a.invalidPacketLimit = protocol.InvalidPacketLimitAES - case tls.TLS_CHACHA20_POLY1305_SHA256: - a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha - default: - panic(fmt.Sprintf("unknown cipher suite %d", suite.ID)) - } -} - -func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { - return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN) -} - -func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { - dec, err := a.open(dst, src, rcvTime, pn, kp, ad) - if err == ErrDecryptionFailed { - a.invalidPacketCount++ - if a.invalidPacketCount >= a.invalidPacketLimit { - return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached} - } - } - if err == nil { - a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn) - } - return dec, err -} - -func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { - if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) { - a.prevRcvAEAD = nil - a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) - a.prevRcvAEADExpiry = time.Time{} - if a.tracer != nil { - a.tracer.DroppedKey(a.keyPhase - 1) - } - } - binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) - if kp != a.keyPhase.Bit() { - if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { - if a.prevRcvAEAD == nil { - return nil, ErrKeysDropped - } - // we updated the key, but the peer hasn't updated yet - dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) - if err != nil { - err = ErrDecryptionFailed - } - return dec, err - } - // try opening the packet with the next key phase - dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) - if err != nil { - return nil, ErrDecryptionFailed - } - // Opening succeeded. Check if the peer was allowed to update. - if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { - return nil, &qerr.TransportError{ - ErrorCode: qerr.KeyUpdateError, - ErrorMessage: "keys updated too quickly", - } - } - a.rollKeys() - a.logger.Debugf("Peer updated keys to %d", a.keyPhase) - // The peer initiated this key update. It's safe to drop the keys for the previous generation now. - // Start a timer to drop the previous key generation. - a.startKeyDropTimer(rcvTime) - if a.tracer != nil { - a.tracer.UpdatedKey(a.keyPhase, true) - } - a.firstRcvdWithCurrentKey = pn - return dec, err - } - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad) - if err != nil { - return dec, ErrDecryptionFailed - } - a.numRcvdWithCurrentKey++ - if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { - // We initiated the key updated, and now we received the first packet protected with the new key phase. - // Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys. - if a.keyPhase > 0 { - a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase) - a.startKeyDropTimer(rcvTime) - } - a.firstRcvdWithCurrentKey = pn - } - return dec, err -} - -func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { - if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { - a.firstSentWithCurrentKey = pn - } - if a.firstPacketNumber == protocol.InvalidPacketNumber { - a.firstPacketNumber = pn - } - a.numSentWithCurrentKey++ - binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad) -} - -func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { - if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && - pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 { - return &qerr.TransportError{ - ErrorCode: qerr.KeyUpdateError, - ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase), - } - } - a.largestAcked = pn - return nil -} - -func (a *updatableAEAD) SetHandshakeConfirmed() { - a.handshakeConfirmed = true -} - -func (a *updatableAEAD) updateAllowed() bool { - if !a.handshakeConfirmed { - return false - } - // the first key update is allowed as soon as the handshake is confirmed - return a.keyPhase == 0 || - // subsequent key updates as soon as a packet sent with that key phase has been acknowledged - (a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && - a.largestAcked != protocol.InvalidPacketNumber && - a.largestAcked >= a.firstSentWithCurrentKey) -} - -func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { - if !a.updateAllowed() { - return false - } - if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { - a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) - return true - } - if a.numSentWithCurrentKey >= a.keyUpdateInterval { - a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) - return true - } - return false -} - -func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { - if a.shouldInitiateKeyUpdate() { - a.rollKeys() - a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) - if a.tracer != nil { - a.tracer.UpdatedKey(a.keyPhase, false) - } - } - return a.keyPhase.Bit() -} - -func (a *updatableAEAD) Overhead() int { - return a.aeadOverhead -} - -func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes) -} - -func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes) -} - -func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber { - return a.firstPacketNumber -} diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go deleted file mode 100644 index 83cdeebc..00000000 --- a/internal/handshake/updatable_aead_test.go +++ /dev/null @@ -1,529 +0,0 @@ -package handshake - -import ( - "crypto/rand" - "crypto/tls" - "fmt" - "github.com/lucas-clemente/quic-go" - "time" - - "github.com/golang/mock/gomock" - - mocklogging "github.com/imroc/req/v3/internal/mocks/logging" - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - "github.com/imroc/req/v3/internal/utils" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/ginkgo/extensions/table" - . "github.com/onsi/gomega" -) - -var _ = Describe("Updatable AEAD", func() { - DescribeTable("ChaCha test vector", - func(v quic.VersionNumber, expectedPayload, expectedPacket []byte) { - secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") - aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil, v) - chacha := cipherSuites[2] - Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256)) - aead.SetWriteKey(chacha, secret) - const pnOffset = 1 - header := splitHexString("4200bff4") - payloadOffset := len(header) - plaintext := splitHexString("01") - payload := aead.Seal(nil, plaintext, 654360564, header) - Expect(payload).To(Equal(expectedPayload)) - packet := append(header, payload...) - aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset]) - Expect(packet).To(Equal(expectedPacket)) - }, - Entry("QUIC v1", - protocol.Version1, - splitHexString("655e5cd55c41f69080575d7999c25a5bfb"), - splitHexString("4cfe4189655e5cd55c41f69080575d7999c25a5bfb"), - ), - Entry("QUIC v2", - protocol.Version2, - splitHexString("0ae7b6b932bc27d786f4bc2bb20f2162ba"), - splitHexString("5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"), - ), - ) - - for _, ver := range []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { - v := ver - - Context(fmt.Sprintf("using version %s", v), func() { - for i := range cipherSuites { - cs := cipherSuites[i] - - Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { - var ( - client, server *updatableAEAD - serverTracer *mocklogging.MockConnectionTracer - rttStats *utils.RTTStats - ) - - BeforeEach(func() { - serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) - trafficSecret1 := make([]byte, 16) - trafficSecret2 := make([]byte, 16) - rand.Read(trafficSecret1) - rand.Read(trafficSecret2) - - rttStats = utils.NewRTTStats() - client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v) - server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger, v) - client.SetReadKey(cs, trafficSecret2) - client.SetWriteKey(cs, trafficSecret1) - server.SetReadKey(cs, trafficSecret1) - server.SetWriteKey(cs, trafficSecret2) - }) - - Context("header protection", func() { - It("encrypts and decrypts the header", func() { - var lastFiveBitsDifferent int - for i := 0; i < 100; i++ { - sample := make([]byte, 16) - rand.Read(sample) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - client.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0x1f != 0xb5&0x1f { - lastFiveBitsDifferent++ - } - Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - server.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - } - Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) - }) - }) - - Context("message encryption", func() { - var msg, ad []byte - - BeforeEach(func() { - msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad = []byte("Donec in velit neque.") - }) - - It("encrypts and decrypts a message", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - - It("saves the first packet number", func() { - client.Seal(nil, msg, 0x1337, ad) - Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) - client.Seal(nil, msg, 0x1338, ad) - Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) - }) - - It("fails to open a message if the associated data is not the same", func() { - encrypted := client.Seal(nil, msg, 0x1337, ad) - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("fails to open a message if the packet number is not the same", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("decodes the packet number", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - _, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) - }) - - It("ignores packets it can't decrypt for packet number derivation", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - _, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).To(HaveOccurred()) - Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) - }) - - It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() { - client.invalidPacketLimit = 10 - for i := 0; i < 9; i++ { - _, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - } - _, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).To(HaveOccurred()) - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached)) - }) - - Context("key updates", func() { - Context("receiving key updates", func() { - It("updates keys", func() { - now := time.Now() - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - encrypted0 := server.Seal(nil, msg, 0x1337, ad) - server.rollKeys() - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - encrypted1 := server.Seal(nil, msg, 0x1337, ad) - Expect(encrypted0).ToNot(Equal(encrypted1)) - // expect opening to fail. The client didn't roll keys yet - _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - client.rollKeys() - decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - }) - - It("updates the keys when receiving a packet with the next key phase", func() { - now := time.Now() - // receive the first packet at key phase zero - encrypted0 := client.Seal(nil, msg, 0x42, ad) - decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - // send one packet at key phase zero - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - _ = server.Seal(nil, msg, 0x1, ad) - // now received a message at key phase one - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x43, ad) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("opens a reordered packet with the old keys after an update", func() { - now := time.Now() - encrypted01 := client.Seal(nil, msg, 0x42, ad) - encrypted02 := client.Seal(nil, msg, 0x43, ad) - // receive the first packet with key phase 0 - _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // send one packet at key phase zero - _ = server.Seal(nil, msg, 0x1, ad) - // now receive a packet with key phase 1 - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x44, ad) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // now receive a reordered packet with key phase 0 - decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("drops keys 3 PTOs after a key update", func() { - now := time.Now() - rttStats.UpdateRTT(10*time.Millisecond, 0, now) - pto := rttStats.PTO(true) - encrypted01 := client.Seal(nil, msg, 0x42, ad) - encrypted02 := client.Seal(nil, msg, 0x43, ad) - // receive the first packet with key phase 0 - _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // send one packet at key phase zero - _ = server.Seal(nil, msg, 0x1, ad) - // now receive a packet with key phase 1 - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x44, ad) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // now receive a reordered packet with key phase 0 - _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrKeysDropped)) - }) - - It("allows the first key update immediately", func() { - // receive a packet at key phase one, before having sent or received any packets at key phase 0 - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x1337, ad) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - _, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - }) - - It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { - client.rollKeys() - encrypted := client.Seal(nil, msg, 0x1337, ad) - encrypted = encrypted[:len(encrypted)-1] - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("errors when the peer updates keys too frequently", func() { - server.rollKeys() - client.rollKeys() - // receive the first packet at key phase one - encrypted0 := client.Seal(nil, msg, 0x42, ad) - _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - // now receive a packet at key phase two, before having sent any packets - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x42, ad) - _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.KeyUpdateError, - ErrorMessage: "keys updated too quickly", - })) - }) - }) - - Context("initiating key updates", func() { - const keyUpdateInterval = 20 - - BeforeEach(func() { - Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) - server.keyUpdateInterval = keyUpdateInterval - server.SetHandshakeConfirmed() - }) - - It("initiates a key update after sealing the maximum number of packets, for the first update", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - // the first update is allowed without receiving an acknowledgement - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() { - server.rollKeys() - client.rollKeys() - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - server.Seal(nil, msg, pn, ad) - } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // receive an ACK for a packet sent in key phase 0 - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - }) - - It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { - // First make sure that we update our keys. - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // Now that our keys are updated, send a packet using the new keys. - const nextPN = keyUpdateInterval + 1 - server.Seal(nil, msg, nextPN, ad) - // We haven't decrypted any packet in the new key phase yet. - // This means that the ACK must have been sent in the old key phase. - Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.KeyUpdateError, - ErrorMessage: "received ACK for key phase 1, but peer didn't update keys", - })) - }) - - It("doesn't error before actually sending a packet in the new key phase", func() { - // First make sure that we update our keys. - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - // Now that our keys are updated, send a packet using the new keys. - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // We haven't decrypted any packet in the new key phase yet. - // This means that the ACK must have been sent in the old key phase. - Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) - }) - - It("initiates a key update after opening the maximum number of packets, for the first update", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - encrypted := client.Seal(nil, msg, pn, ad) - _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - } - // the first update is allowed without receiving an acknowledgement - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() { - server.rollKeys() - client.rollKeys() - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - encrypted := client.Seal(nil, msg, pn, ad) - _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - server.Seal(nil, msg, 1, ad) - Expect(server.SetLargestAcked(1)).To(Succeed()) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - }) - - It("drops keys 3 PTOs after a key update", func() { - now := time.Now() - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - Expect(server.SetLargestAcked(0)).To(Succeed()) - // Now we've initiated the first key update. - // Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there - threePTO := 3 * rttStats.PTO(false) - dataKeyPhaseZero := client.Seal(nil, msg, 1, ad) - _, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // Now receive a packet with key phase 1. - // This should start the timer to drop the keys after 3 PTOs. - client.rollKeys() - dataKeyPhaseOne := client.Seal(nil, msg, 10, ad) - t := now.Add(threePTO).Add(time.Second) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - _, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - // Make sure the keys are still here. - _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrKeysDropped)) - }) - - It("doesn't drop the first key generation too early", func() { - now := time.Now() - data1 := client.Seal(nil, msg, 1, ad) - _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - Expect(server.SetLargestAcked(pn)).To(Succeed()) - } - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // The server never received a packet at key phase 1. - // Make sure the key phase 0 is still there at a much later point. - data2 := client.Seal(nil, msg, 1, ad) - _, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - }) - - It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - const nextPN = keyUpdateInterval + 1 - // Send and receive an acknowledgement for a packet in key phase 1. - // We are now running a timer to drop the keys with 3 PTO. - server.Seal(nil, msg, nextPN, ad) - client.rollKeys() - dataKeyPhaseOne := client.Seal(nil, msg, 2, ad) - now := time.Now() - _, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.SetLargestAcked(nextPN)) - // Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over. - // This mean that we need to drop the keys for key phase 0 immediately. - client.rollKeys() - dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad) - gomock.InOrder( - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true), - ) - _, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - }) - - It("drops keys early when we initiate another key update within the 3 PTO period", func() { - server.SetHandshakeConfirmed() - // send so many packets that we initiate the first key update - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // send so many packets that we initiate the next key update - for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - server.Seal(nil, msg, pn, ad) - } - client.rollKeys() - b = client.Seal(nil, []byte("foobar"), 2, []byte("ad")) - now := time.Now() - _, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed()) - gomock.InOrder( - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false), - ) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - // We haven't received an ACK for a packet sent in key phase 2 yet. - // Make sure we canceled the timer to drop the previous key phase. - b = client.Seal(nil, []byte("foobar"), 3, []byte("ad")) - _, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - }) - }) - }) - }) - }) - } - }) - } -}) diff --git a/internal/http3/server.go b/internal/http3/server.go index 4a338f00..07d15ed2 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -1,25 +1,8 @@ package http3 import ( - "bytes" - "context" - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "net/http" - "runtime" - "strings" - "sync" - "time" - - "github.com/imroc/req/v3/internal/handshake" "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - "github.com/imroc/req/v3/internal/utils" "github.com/lucas-clemente/quic-go" - "github.com/marten-seemann/qpack" ) // allows mocking of quic.Listen and quic.ListenAddr @@ -53,56 +36,6 @@ func versionToALPN(v quic.VersionNumber) string { return "" } -// ConfigureTLSConfig creates a new tls.Config which can be used -// to create a quic.Listener meant for serving http3. The created -// tls.Config adds the functionality of detecting the used QUIC version -// in order to set the correct ALPN value for the http3 connection. -func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config { - // The tls.Config used to setup the quic.Listener needs to have the GetConfigForClient callback set. - // That way, we can get the QUIC version and set the correct ALPN value. - return &tls.Config{ - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - // determine the ALPN from the QUIC version used - proto := nextProtoH3 - if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok { - proto = versionToALPN(qconn.GetQUICVersion()) - } - config := tlsConf - if tlsConf.GetConfigForClient != nil { - getConfigForClient := tlsConf.GetConfigForClient - var err error - conf, err := getConfigForClient(ch) - if err != nil { - return nil, err - } - if conf != nil { - config = conf - } - } - if config == nil { - return nil, nil - } - config = config.Clone() - config.NextProtos = []string{proto} - return config, nil - }, - } -} - -// contextKey is a value for use with context.WithValue. It's used as -// a pointer so it fits in an interface{} without allocation. -type contextKey struct { - name string -} - -func (k *contextKey) String() string { return "quic-go/http3 context value " + k.name } - -// ServerContextKey is a context key. It can be used in HTTP -// handlers with Context.Value to access the server that -// started the handler. The associated value will be of -// type *http3.Server. -var ServerContextKey = &contextKey{"http3-server"} - type requestError struct { err error streamErr errorCode @@ -116,621 +49,3 @@ func newStreamError(code errorCode, err error) requestError { func newConnError(code errorCode, err error) requestError { return requestError{err: err, connErr: code} } - -// listenerInfo contains info about specific listener added with addListener -type listenerInfo struct { - port int // 0 means that no info about port is available -} - -// Server is a HTTP/3 server. -type Server struct { - // Addr optionally specifies the UDP address for the server to listen on, - // in the form "host:port". - // - // When used by ListenAndServe and ListenAndServeTLS methods, if empty, - // ":https" (port 443) is used. See net.Dial for details of the address - // format. - // - // Otherwise, if Port is not set and underlying QUIC listeners do not - // have valid port numbers, the port part is used in Alt-Svc headers set - // with SetQuicHeaders. - Addr string - - // Port is used in Alt-Svc response headers set with SetQuicHeaders. If - // needed Port can be manually set when the Server is created. - // - // This is useful when a Layer 4 firewall is redirecting UDP traffic and - // clients must use a port different from the port the Server is - // listening on. - Port int - - // TLSConfig provides a TLS configuration for use by server. It must be - // set for ListenAndServe and Serve methods. - TLSConfig *tls.Config - - // QuicConfig provides the parameters for QUIC connection created with - // Serve. If nil, it uses reasonable default values. - // - // Configured versions are also used in Alt-Svc response header set with - // SetQuicHeaders. - QuicConfig *quic.Config - - // Handler is the HTTP request handler to use. If not set, defaults to - // http.NotFound. - Handler http.Handler - - // EnableDatagrams enables support for HTTP/3 datagrams. - // If set to true, QuicConfig.EnableDatagram will be set. - // See https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-07. - EnableDatagrams bool - - // MaxHeaderBytes controls the maximum number of bytes the server will - // read parsing the request HEADERS frame. It does not limit the size of - // the request body. If zero or negative, http.DefaultMaxHeaderBytes is - // used. - MaxHeaderBytes int - - // AdditionalSettings specifies additional HTTP/3 settings. - // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. - AdditionalSettings map[uint64]uint64 - - // StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream. - // It is called right after parsing the frame type. - // If parsing the frame type fails, the error is passed to the callback. - // In that case, the frame type will not be set. - // Callers can either ignore the frame and return control of the stream back to HTTP/3 - // (by returning hijacked false). - // Alternatively, callers can take over the QUIC stream (by returning hijacked true). - StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) - - // UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type. - // If parsing the stream type fails, the error is passed to the callback. - // In that case, the stream type will not be set. - UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) - - mutex sync.RWMutex - listeners map[*quic.EarlyListener]listenerInfo - - closed bool - - altSvcHeader string - - logger utils.Logger -} - -// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. -// -// If s.Addr is blank, ":https" is used. -func (s *Server) ListenAndServe() error { - return s.serveConn(s.TLSConfig, nil) -} - -// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. -// -// If s.Addr is blank, ":https" is used. -func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { - var err error - certs := make([]tls.Certificate, 1) - certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return err - } - // We currently only use the cert-related stuff from tls.Config, - // so we don't need to make a full copy. - config := &tls.Config{ - Certificates: certs, - } - return s.serveConn(config, nil) -} - -// Serve an existing UDP connection. -// It is possible to reuse the same connection for outgoing connections. -// Closing the server does not close the packet conn. -func (s *Server) Serve(conn net.PacketConn) error { - return s.serveConn(s.TLSConfig, conn) -} - -// ServeListener serves an existing QUIC listener. -// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config -// and use it to construct a http3-friendly QUIC listener. -// Closing the server does close the listener. -func (s *Server) ServeListener(ln quic.EarlyListener) error { - if err := s.addListener(&ln); err != nil { - return err - } - err := s.serveListener(ln) - s.removeListener(&ln) - return err -} - -var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig") - -func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { - if tlsConf == nil { - return errServerWithoutTLSConfig - } - - s.mutex.Lock() - closed := s.closed - s.mutex.Unlock() - if closed { - return http.ErrServerClosed - } - - baseConf := ConfigureTLSConfig(tlsConf) - quicConf := s.QuicConfig - if quicConf == nil { - quicConf = &quic.Config{} - } else { - quicConf = s.QuicConfig.Clone() - } - if s.EnableDatagrams { - quicConf.EnableDatagrams = true - } - - var ln quic.EarlyListener - var err error - if conn == nil { - addr := s.Addr - if addr == "" { - addr = ":https" - } - ln, err = quicListenAddr(addr, baseConf, quicConf) - } else { - ln, err = quicListen(conn, baseConf, quicConf) - } - if err != nil { - return err - } - if err := s.addListener(&ln); err != nil { - return err - } - err = s.serveListener(ln) - s.removeListener(&ln) - return err -} - -func (s *Server) serveListener(ln quic.EarlyListener) error { - for { - conn, err := ln.Accept(context.Background()) - if err != nil { - return err - } - go s.handleConn(conn) - } -} - -func extractPort(addr string) (int, error) { - _, portStr, err := net.SplitHostPort(addr) - if err != nil { - return 0, err - } - - portInt, err := net.LookupPort("tcp", portStr) - if err != nil { - return 0, err - } - return portInt, nil -} - -func (s *Server) generateAltSvcHeader() { - if len(s.listeners) == 0 { - // Don't announce any ports since no one is listening for connections - s.altSvcHeader = "" - return - } - - // This code assumes that we will use protocol.SupportedVersions if no quic.Config is passed. - supportedVersions := protocol.SupportedVersions - if s.QuicConfig != nil && len(s.QuicConfig.Versions) > 0 { - supportedVersions = s.QuicConfig.Versions - } - var versionStrings []string - for _, version := range supportedVersions { - if v := versionToALPN(version); len(v) > 0 { - versionStrings = append(versionStrings, v) - } - } - - var altSvc []string - addPort := func(port int) { - for _, v := range versionStrings { - altSvc = append(altSvc, fmt.Sprintf(`%s=":%d"; ma=2592000`, v, port)) - } - } - - if s.Port != 0 { - // if Port is specified, we must use it instead of the - // listener addresses since there's a reason it's specified. - addPort(s.Port) - } else { - // if we have some listeners assigned, try to find ports - // which we can announce, otherwise nothing should be announced - validPortsFound := false - for _, info := range s.listeners { - if info.port != 0 { - addPort(info.port) - validPortsFound = true - } - } - if !validPortsFound { - if port, err := extractPort(s.Addr); err == nil { - addPort(port) - } - } - } - - s.altSvcHeader = strings.Join(altSvc, ",") -} - -// We store a pointer to interface in the map set. This is safe because we only -// call trackListener via Serve and can track+defer untrack the same pointer to -// local variable there. We never need to compare a Listener from another caller. -func (s *Server) addListener(l *quic.EarlyListener) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.closed { - return http.ErrServerClosed - } - if s.logger == nil { - s.logger = utils.DefaultLogger.WithPrefix("server") - } - if s.listeners == nil { - s.listeners = make(map[*quic.EarlyListener]listenerInfo) - } - - if port, err := extractPort((*l).Addr().String()); err == nil { - s.listeners[l] = listenerInfo{port} - } else { - s.logger.Errorf( - "Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err) - s.listeners[l] = listenerInfo{} - } - s.generateAltSvcHeader() - return nil -} - -func (s *Server) removeListener(l *quic.EarlyListener) { - s.mutex.Lock() - delete(s.listeners, l) - s.generateAltSvcHeader() - s.mutex.Unlock() -} - -func (s *Server) handleConn(conn quic.EarlyConnection) { - decoder := qpack.NewDecoder(nil) - - // send a SETTINGS frame - str, err := conn.OpenUniStream() - if err != nil { - s.logger.Debugf("Opening the control stream failed.") - return - } - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypeControlStream) // stream type - (&settingsFrame{Datagram: s.EnableDatagrams, Other: s.AdditionalSettings}).Write(buf) - str.Write(buf.Bytes()) - - go s.handleUnidirectionalStreams(conn) - - // Process all requests immediately. - // It's the client's responsibility to decide which requests are eligible for 0-RTT. - for { - str, err := conn.AcceptStream(context.Background()) - if err != nil { - s.logger.Debugf("Accepting stream failed: %s", err) - return - } - go func() { - rerr := s.handleRequest(conn, str, decoder, func() { - conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") - }) - if rerr.err == errHijacked { - return - } - if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 { - s.logger.Debugf("Handling request failed: %s", err) - if rerr.streamErr != 0 { - str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) - } - if rerr.connErr != 0 { - var reason string - if rerr.err != nil { - reason = rerr.err.Error() - } - conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) - } - return - } - str.Close() - }() - } -} - -func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) { - for { - str, err := conn.AcceptUniStream(context.Background()) - if err != nil { - s.logger.Debugf("accepting unidirectional stream failed: %s", err) - return - } - - go func(str quic.ReceiveStream) { - streamType, err := quicvarint.Read(quicvarint.NewReader(str)) - if err != nil { - if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, err) { - return - } - s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) - return - } - // We're only interested in the control stream here. - switch streamType { - case streamTypeControlStream: - case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream: - // Our QPACK implementation doesn't use the dynamic table yet. - // TODO: check that only one stream of each type is opened. - return - case streamTypePushStream: // only the server can push - conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "") - return - default: - if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, nil) { - return - } - str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) - return - } - f, err := parseNextFrame(str, nil) - if err != nil { - conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") - return - } - sf, ok := f.(*settingsFrame) - if !ok { - conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") - return - } - if !sf.Datagram { - return - } - // If datagram support was enabled on our side as well as on the client side, - // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. - // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). - if s.EnableDatagrams && !conn.ConnectionState().SupportsDatagrams { - conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") - } - }(str) - } -} - -func (s *Server) maxHeaderBytes() uint64 { - if s.MaxHeaderBytes <= 0 { - return http.DefaultMaxHeaderBytes - } - return uint64(s.MaxHeaderBytes) -} - -func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { - var ufh unknownFrameHandlerFunc - if s.StreamHijacker != nil { - ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) } - } - frame, err := parseNextFrame(str, ufh) - if err != nil { - if err == errHijacked { - return requestError{err: errHijacked} - } - return newStreamError(errorRequestIncomplete, err) - } - hf, ok := frame.(*headersFrame) - if !ok { - return newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) - } - if hf.Length > s.maxHeaderBytes() { - return newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes())) - } - headerBlock := make([]byte, hf.Length) - if _, err := io.ReadFull(str, headerBlock); err != nil { - return newStreamError(errorRequestIncomplete, err) - } - hfs, err := decoder.DecodeFull(headerBlock) - if err != nil { - // TODO: use the right error code - return newConnError(errorGeneralProtocolError, err) - } - req, err := requestFromHeaders(hfs) - if err != nil { - // TODO: use the right error code - return newStreamError(errorGeneralProtocolError, err) - } - - req.RemoteAddr = conn.RemoteAddr().String() - body := newRequestBody(newStream(str, onFrameError)) - req.Body = body - - if s.logger.Debug() { - s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID()) - } else { - s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) - } - - ctx := str.Context() - ctx = context.WithValue(ctx, ServerContextKey, s) - ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr()) - req = req.WithContext(ctx) - r := newResponseWriter(str, conn, s.logger) - defer r.Flush() - handler := s.Handler - if handler == nil { - handler = http.DefaultServeMux - } - - var panicked bool - func() { - defer func() { - if p := recover(); p != nil { - // Copied from net/http/server.go - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - s.logger.Errorf("http: panic serving: %v\n%s", p, buf) - panicked = true - } - }() - handler.ServeHTTP(r, req) - }() - - if body.wasStreamHijacked() { - return requestError{err: errHijacked} - } - - if panicked { - r.WriteHeader(500) - } else { - r.WriteHeader(200) - } - // If the EOF was read by the handler, CancelRead() is a no-op. - str.CancelRead(quic.StreamErrorCode(errorNoError)) - return requestError{} -} - -// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients. -// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. -func (s *Server) Close() error { - s.mutex.Lock() - defer s.mutex.Unlock() - - s.closed = true - - var err error - for ln := range s.listeners { - if cerr := (*ln).Close(); cerr != nil && err == nil { - err = cerr - } - } - return err -} - -// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete. -// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. -func (s *Server) CloseGracefully(timeout time.Duration) error { - // TODO: implement - return nil -} - -// ErrNoAltSvcPort is the error returned by SetQuicHeaders when no port was found -// for Alt-Svc to announce. This can happen if listening on a PacketConn without a port -// (UNIX socket, for example) and no port is specified in Server.Port or Server.Addr. -var ErrNoAltSvcPort = errors.New("no port can be announced, specify it explicitly using Server.Port or Server.Addr") - -// SetQuicHeaders can be used to set the proper headers that announce that this server supports HTTP/3. -// The values set by default advertise all of the ports the server is listening on, but can be -// changed to a specific port by setting Server.Port before launching the serverr. -// If no listener's Addr().String() returns an address with a valid port, Server.Addr will be used -// to extract the port, if specified. -// For example, a server launched using ListenAndServe on an address with port 443 would set: -// Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 -func (s *Server) SetQuicHeaders(hdr http.Header) error { - s.mutex.RLock() - defer s.mutex.RUnlock() - - if s.altSvcHeader == "" { - return ErrNoAltSvcPort - } - // use the map directly to avoid constant canonicalization - // since the key is already canonicalized - hdr["Alt-Svc"] = append(hdr["Alt-Svc"], s.altSvcHeader) - return nil -} - -// ListenAndServeQUIC listens on the UDP network address addr and calls the -// handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is -// used when handler is nil. -func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error { - server := &Server{ - Addr: addr, - Handler: handler, - } - return server.ListenAndServeTLS(certFile, keyFile) -} - -// ListenAndServe listens on the given network address for both, TLS and QUIC -// connections in parallel. It returns if one of the two returns an error. -// http.DefaultServeMux is used when handler is nil. -// The correct Alt-Svc headers for QUIC are set. -func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error { - // Load certs - var err error - certs := make([]tls.Certificate, 1) - certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return err - } - // We currently only use the cert-related stuff from tls.Config, - // so we don't need to make a full copy. - config := &tls.Config{ - Certificates: certs, - } - - if addr == "" { - addr = ":https" - } - - // Open the listeners - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return err - } - udpConn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - return err - } - defer udpConn.Close() - - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return err - } - tcpConn, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return err - } - defer tcpConn.Close() - - tlsConn := tls.NewListener(tcpConn, config) - defer tlsConn.Close() - - // Start the servers - httpServer := &http.Server{} - quicServer := &Server{ - TLSConfig: config, - } - - if handler == nil { - handler = http.DefaultServeMux - } - httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - quicServer.SetQuicHeaders(w.Header()) - handler.ServeHTTP(w, r) - }) - - hErr := make(chan error) - qErr := make(chan error) - go func() { - hErr <- httpServer.Serve(tlsConn) - }() - go func() { - qErr <- quicServer.Serve(udpConn) - }() - - select { - case err := <-hErr: - quicServer.Close() - return err - case err := <-qErr: - // Cannot close the HTTP server or wait for requests to complete properly :/ - return err - } -} diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go deleted file mode 100644 index a051b847..00000000 --- a/internal/http3/server_test.go +++ /dev/null @@ -1,1289 +0,0 @@ -package http3 - -import ( - "bytes" - "context" - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "net/http" - "runtime" - "sync/atomic" - "time" - - mockquic "github.com/imroc/req/v3/internal/mocks/quic" - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - "github.com/imroc/req/v3/internal/testdata" - "github.com/imroc/req/v3/internal/utils" - "github.com/lucas-clemente/quic-go" - - "github.com/golang/mock/gomock" - "github.com/marten-seemann/qpack" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - gmtypes "github.com/onsi/gomega/types" -) - -type mockConn struct { - net.Conn - version quic.VersionNumber -} - -func newMockConn(version quic.VersionNumber) net.Conn { - return &mockConn{version: version} -} - -func (c *mockConn) GetQUICVersion() quic.VersionNumber { - return c.version -} - -type mockAddr struct { - addr string -} - -func (ma *mockAddr) Network() string { - return "udp" -} - -func (ma *mockAddr) String() string { - return ma.addr -} - -type mockAddrListener struct { - *mockquic.MockEarlyListener - addr *mockAddr -} - -func (m *mockAddrListener) Addr() net.Addr { - _ = m.MockEarlyListener.Addr() - return m.addr -} - -func newMockAddrListener(addr string) *mockAddrListener { - return &mockAddrListener{ - MockEarlyListener: mockquic.NewMockEarlyListener(mockCtrl), - addr: &mockAddr{ - addr: addr, - }, - } -} - -type noPortListener struct { - *mockAddrListener -} - -func (m *noPortListener) Addr() net.Addr { - _ = m.mockAddrListener.Addr() - return &net.UnixAddr{ - Net: "unix", - Name: "/tmp/quic.sock", - } -} - -var _ = Describe("Server", func() { - var ( - s *Server - origQuicListenAddr = quicListenAddr - ) - - BeforeEach(func() { - s = &Server{ - TLSConfig: testdata.GetTLSConfig(), - logger: utils.DefaultLogger, - } - origQuicListenAddr = quicListenAddr - }) - - AfterEach(func() { - quicListenAddr = origQuicListenAddr - }) - - Context("handling requests", func() { - var ( - qpackDecoder *qpack.Decoder - str *mockquic.MockStream - conn *mockquic.MockEarlyConnection - exampleGetRequest *http.Request - examplePostRequest *http.Request - ) - reqContext := context.Background() - - decodeHeader := func(str io.Reader) map[string][]string { - fields := make(map[string][]string) - decoder := qpack.NewDecoder(nil) - - frame, err := parseNextFrame(str, nil) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) - headersFrame := frame.(*headersFrame) - data := make([]byte, headersFrame.Length) - _, err = io.ReadFull(str, data) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - hfs, err := decoder.DecodeFull(data) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - for _, p := range hfs { - fields[p.Name] = append(fields[p.Name], p.Value) - } - return fields - } - - encodeRequest := func(req *http.Request) []byte { - buf := &bytes.Buffer{} - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() - rw := newRequestWriter(utils.DefaultLogger) - Expect(rw.WriteRequestHeader(str, req, false, nil)).To(Succeed()) - return buf.Bytes() - } - - setRequest := func(data []byte) { - buf := bytes.NewBuffer(data) - str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { - if buf.Len() == 0 { - return 0, io.EOF - } - return buf.Read(p) - }).AnyTimes() - } - - BeforeEach(func() { - var err error - exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil) - Expect(err).ToNot(HaveOccurred()) - examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar"))) - Expect(err).ToNot(HaveOccurred()) - - qpackDecoder = qpack.NewDecoder(nil) - str = mockquic.NewMockStream(mockCtrl) - conn = mockquic.NewMockEarlyConnection(mockCtrl) - addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() - conn.EXPECT().LocalAddr().AnyTimes() - }) - - It("calls the HTTP handler function", func() { - requestChan := make(chan *http.Request, 1) - s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - requestChan <- r - }) - - setRequest(encodeRequest(exampleGetRequest)) - str.EXPECT().Context().Return(reqContext) - str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { - return len(p), nil - }).AnyTimes() - str.EXPECT().CancelRead(gomock.Any()) - - Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{})) - var req *http.Request - Eventually(requestChan).Should(Receive(&req)) - Expect(req.Host).To(Equal("www.example.com")) - Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) - Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) - }) - - It("returns 200 with an empty handler", func() { - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - - responseBuf := &bytes.Buffer{} - setRequest(encodeRequest(exampleGetRequest)) - str.EXPECT().Context().Return(reqContext) - str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() - str.EXPECT().CancelRead(gomock.Any()) - - serr := s.handleRequest(conn, str, qpackDecoder, nil) - Expect(serr.err).ToNot(HaveOccurred()) - hfs := decodeHeader(responseBuf) - Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) - }) - - It("handles a panicking handler", func() { - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - panic("foobar") - }) - - responseBuf := &bytes.Buffer{} - setRequest(encodeRequest(exampleGetRequest)) - str.EXPECT().Context().Return(reqContext) - str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() - str.EXPECT().CancelRead(gomock.Any()) - - serr := s.handleRequest(conn, str, qpackDecoder, nil) - Expect(serr.err).ToNot(HaveOccurred()) - hfs := decodeHeader(responseBuf) - Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"})) - }) - - Context("hijacking bidirectional streams", func() { - var conn *mockquic.MockEarlyConnection - testDone := make(chan struct{}) - - BeforeEach(func() { - testDone = make(chan struct{}) - conn = mockquic.NewMockEarlyConnection(mockCtrl) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()) - conn.EXPECT().OpenUniStream().Return(controlStr, nil) - conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() - conn.EXPECT().LocalAddr().AnyTimes() - }) - - AfterEach(func() { testDone <- struct{}{} }) - - It("hijacks a bidirectional stream of unknown frame type", func() { - frameTypeChan := make(chan FrameType, 1) - s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - frameTypeChan <- ft - return true, nil - } - - buf := &bytes.Buffer{} - quicvarint.Write(buf, 0x41) - unknownStr := mockquic.NewMockStream(mockCtrl) - unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) - conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - s.handleConn(conn) - Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - - It("cancels writing when hijacker didn't hijack a bidirectional stream", func() { - frameTypeChan := make(chan FrameType, 1) - s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - frameTypeChan <- ft - return false, nil - } - - buf := &bytes.Buffer{} - quicvarint.Write(buf, 0x41) - unknownStr := mockquic.NewMockStream(mockCtrl) - unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)) - conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) - conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - s.handleConn(conn) - Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - - It("cancels writing when hijacker returned error", func() { - frameTypeChan := make(chan FrameType, 1) - s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - frameTypeChan <- ft - return false, errors.New("error in hijacker") - } - - buf := &bytes.Buffer{} - quicvarint.Write(buf, 0x41) - unknownStr := mockquic.NewMockStream(mockCtrl) - unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)) - conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) - conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - s.handleConn(conn) - Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - - It("handles errors that occur when reading the stream type", func() { - testErr := errors.New("test error") - done := make(chan struct{}) - unknownStr := mockquic.NewMockStream(mockCtrl) - s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) { - defer close(done) - Expect(ft).To(BeZero()) - Expect(str).To(Equal(unknownStr)) - Expect(err).To(MatchError(testErr)) - return true, nil - } - - unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() - conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) - conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - s.handleConn(conn) - Eventually(done).Should(BeClosed()) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - }) - - Context("hijacking unidirectional streams", func() { - var conn *mockquic.MockEarlyConnection - testDone := make(chan struct{}) - - BeforeEach(func() { - testDone = make(chan struct{}) - conn = mockquic.NewMockEarlyConnection(mockCtrl) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()) - conn.EXPECT().OpenUniStream().Return(controlStr, nil) - conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) - conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() - conn.EXPECT().LocalAddr().AnyTimes() - }) - - AfterEach(func() { testDone <- struct{}{} }) - - It("hijacks an unidirectional stream of unknown stream type", func() { - streamTypeChan := make(chan StreamType, 1) - s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { - Expect(err).ToNot(HaveOccurred()) - streamTypeChan <- st - return true - } - - buf := &bytes.Buffer{} - quicvarint.Write(buf, 0x54) - unknownStr := mockquic.NewMockStream(mockCtrl) - unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return unknownStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - s.handleConn(conn) - Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - - It("handles errors that occur when reading the stream type", func() { - testErr := errors.New("test error") - done := make(chan struct{}) - unknownStr := mockquic.NewMockStream(mockCtrl) - s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { - defer close(done) - Expect(st).To(BeZero()) - Expect(str).To(Equal(unknownStr)) - Expect(err).To(MatchError(testErr)) - return true - } - - unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) - conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - s.handleConn(conn) - Eventually(done).Should(BeClosed()) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - - It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { - streamTypeChan := make(chan StreamType, 1) - s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { - Expect(err).ToNot(HaveOccurred()) - streamTypeChan <- st - return false - } - - buf := &bytes.Buffer{} - quicvarint.Write(buf, 0x54) - unknownStr := mockquic.NewMockStream(mockCtrl) - unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)) - - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return unknownStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - s.handleConn(conn) - Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - }) - - Context("control stream handling", func() { - var conn *mockquic.MockEarlyConnection - testDone := make(chan struct{}) - - BeforeEach(func() { - conn = mockquic.NewMockEarlyConnection(mockCtrl) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()) - conn.EXPECT().OpenUniStream().Return(controlStr, nil) - conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) - conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() - conn.EXPECT().LocalAddr().AnyTimes() - }) - - AfterEach(func() { testDone <- struct{}{} }) - - It("parses the SETTINGS frame", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypeControlStream) - (&settingsFrame{}).Write(buf) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - s.handleConn(conn) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - - for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { - streamType := t - name := "encoder" - if streamType == streamTypeQPACKDecoderStream { - name = "decoder" - } - - It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamType) - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return str, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - s.handleConn(conn) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead - }) - } - - It("reset streams other than the control stream and the QPACK streams", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, 1337) - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - done := make(chan struct{}) - str.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)).Do(func(code quic.StreamErrorCode) { - close(done) - }) - - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return str, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - s.handleConn(conn) - Eventually(done).Should(BeClosed()) - }) - - It("errors when the first frame on the control stream is not a SETTINGS frame", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypeControlStream) - (&dataFrame{}).Write(buf) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(errorMissingSettings)) - close(done) - }) - s.handleConn(conn) - Eventually(done).Should(BeClosed()) - }) - - It("errors when parsing the frame on the control stream fails", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypeControlStream) - b := &bytes.Buffer{} - (&settingsFrame{}).Write(b) - buf.Write(b.Bytes()[:b.Len()-1]) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(errorFrameError)) - close(done) - }) - s.handleConn(conn) - Eventually(done).Should(BeClosed()) - }) - - It("errors when the client opens a push stream", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypePushStream) - (&dataFrame{}).Write(buf) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(errorStreamCreationError)) - close(done) - }) - s.handleConn(conn) - Eventually(done).Should(BeClosed()) - }) - - It("errors when the client advertises datagram support (and we enabled support for it)", func() { - s.EnableDatagrams = true - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypeControlStream) - (&settingsFrame{Datagram: true}).Write(buf) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) - done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(errorSettingsError)) - Expect(reason).To(Equal("missing QUIC Datagram support")) - close(done) - }) - s.handleConn(conn) - Eventually(done).Should(BeClosed()) - }) - }) - - Context("stream- and connection-level errors", func() { - var conn *mockquic.MockEarlyConnection - testDone := make(chan struct{}) - - BeforeEach(func() { - testDone = make(chan struct{}) - addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - conn = mockquic.NewMockEarlyConnection(mockCtrl) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()) - conn.EXPECT().OpenUniStream().Return(controlStr, nil) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) - conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) - conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() - conn.EXPECT().LocalAddr().AnyTimes() - }) - - AfterEach(func() { testDone <- struct{}{} }) - - It("cancels reading when client sends a body in GET request", func() { - var handlerCalled bool - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handlerCalled = true - }) - - requestData := encodeRequest(exampleGetRequest) - buf := &bytes.Buffer{} - (&dataFrame{Length: 6}).Write(buf) // add a body - buf.Write([]byte("foobar")) - responseBuf := &bytes.Buffer{} - setRequest(append(requestData, buf.Bytes()...)) - done := make(chan struct{}) - str.EXPECT().Context().Return(reqContext) - str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() - str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) - str.EXPECT().Close().Do(func() { close(done) }) - - s.handleConn(conn) - Eventually(done).Should(BeClosed()) - hfs := decodeHeader(responseBuf) - Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) - Expect(handlerCalled).To(BeTrue()) - }) - - It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() { - handlerCalled := make(chan struct{}) - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer close(handlerCalled) - r.Body.(HTTPStreamer).HTTPStream() - str.Write([]byte("foobar")) - }) - - requestData := encodeRequest(exampleGetRequest) - buf := &bytes.Buffer{} - (&dataFrame{Length: 6}).Write(buf) // add a body - buf.Write([]byte("foobar")) - setRequest(append(requestData, buf.Bytes()...)) - str.EXPECT().Context().Return(reqContext) - str.EXPECT().Write([]byte("foobar")).Return(6, nil) - - s.handleConn(conn) - Eventually(handlerCalled).Should(BeClosed()) - }) - - It("errors when the client sends a too large header frame", func() { - s.MaxHeaderBytes = 20 - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Fail("Handler should not be called.") - }) - - requestData := encodeRequest(exampleGetRequest) - buf := &bytes.Buffer{} - (&dataFrame{Length: 6}).Write(buf) // add a body - buf.Write([]byte("foobar")) - responseBuf := &bytes.Buffer{} - setRequest(append(requestData, buf.Bytes()...)) - done := make(chan struct{}) - str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() - str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) - - s.handleConn(conn) - Eventually(done).Should(BeClosed()) - }) - - It("handles a request for which the client immediately resets the stream", func() { - handlerCalled := make(chan struct{}) - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(handlerCalled) - }) - - testErr := errors.New("stream reset") - done := make(chan struct{}) - str.EXPECT().Read(gomock.Any()).Return(0, testErr) - str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) - - s.handleConn(conn) - Consistently(handlerCalled).ShouldNot(BeClosed()) - }) - - It("closes the connection when the first frame is not a HEADERS frame", func() { - handlerCalled := make(chan struct{}) - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(handlerCalled) - }) - - buf := &bytes.Buffer{} - (&dataFrame{}).Write(buf) - setRequest(buf.Bytes()) - str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { - return len(p), nil - }).AnyTimes() - - done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - Expect(code).To(Equal(quic.ApplicationErrorCode(errorFrameUnexpected))) - close(done) - }) - s.handleConn(conn) - Eventually(done).Should(BeClosed()) - }) - - It("closes the connection when the first frame is not a HEADERS frame", func() { - handlerCalled := make(chan struct{}) - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(handlerCalled) - }) - - // use 2*DefaultMaxHeaderBytes here. qpack will compress the requiest, - // but the request will still end up larger than DefaultMaxHeaderBytes. - url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2) - req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil) - Expect(err).ToNot(HaveOccurred()) - setRequest(encodeRequest(req)) - // str.EXPECT().Context().Return(reqContext) - str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { - return len(p), nil - }).AnyTimes() - done := make(chan struct{}) - str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) - - s.handleConn(conn) - Eventually(done).Should(BeClosed()) - }) - }) - - It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() { - handlerCalled := make(chan struct{}) - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.Body = struct { - io.Reader - io.Closer - }{} - close(handlerCalled) - }) - - setRequest(encodeRequest(examplePostRequest)) - str.EXPECT().Context().Return(reqContext) - str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { - return len(p), nil - }).AnyTimes() - str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) - - serr := s.handleRequest(conn, str, qpackDecoder, nil) - Expect(serr.err).ToNot(HaveOccurred()) - Eventually(handlerCalled).Should(BeClosed()) - }) - - It("cancels the request context when the stream is closed", func() { - handlerCalled := make(chan struct{}) - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - Expect(r.Context().Done()).To(BeClosed()) - Expect(r.Context().Err()).To(MatchError(context.Canceled)) - close(handlerCalled) - }) - setRequest(encodeRequest(examplePostRequest)) - - reqContext, cancel := context.WithCancel(context.Background()) - cancel() - str.EXPECT().Context().Return(reqContext) - str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { - return len(p), nil - }).AnyTimes() - str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) - - serr := s.handleRequest(conn, str, qpackDecoder, nil) - Expect(serr.err).ToNot(HaveOccurred()) - Eventually(handlerCalled).Should(BeClosed()) - }) - }) - - Context("setting http headers", func() { - BeforeEach(func() { - s.QuicConfig = &quic.Config{Versions: []quic.VersionNumber{protocol.VersionDraft29}} - }) - - var ln1 quic.EarlyListener - var ln2 quic.EarlyListener - expected := http.Header{ - "Alt-Svc": {`h3-29=":443"; ma=2592000`}, - } - - addListener := func(addr string, ln *quic.EarlyListener) { - mln := newMockAddrListener(addr) - mln.EXPECT().Addr() - *ln = mln - s.addListener(ln) - } - - removeListener := func(ln *quic.EarlyListener) { - s.removeListener(ln) - } - - checkSetHeaders := func(expected gmtypes.GomegaMatcher) { - hdr := http.Header{} - Expect(s.SetQuicHeaders(hdr)).To(Succeed()) - Expect(hdr).To(expected) - } - - checkSetHeaderError := func() { - hdr := http.Header{} - Expect(s.SetQuicHeaders(hdr)).To(Equal(ErrNoAltSvcPort)) - } - - It("sets proper headers with numeric port", func() { - addListener(":443", &ln1) - checkSetHeaders(Equal(expected)) - removeListener(&ln1) - checkSetHeaderError() - }) - - It("sets proper headers with full addr", func() { - addListener("127.0.0.1:443", &ln1) - checkSetHeaders(Equal(expected)) - removeListener(&ln1) - checkSetHeaderError() - }) - - It("sets proper headers with string port", func() { - addListener(":https", &ln1) - checkSetHeaders(Equal(expected)) - removeListener(&ln1) - checkSetHeaderError() - }) - - It("works multiple times", func() { - addListener(":https", &ln1) - checkSetHeaders(Equal(expected)) - checkSetHeaders(Equal(expected)) - removeListener(&ln1) - checkSetHeaderError() - }) - - It("works if the quic.Config sets QUIC versions", func() { - s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.VersionDraft29} - addListener(":443", &ln1) - checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3-29=":443"; ma=2592000`}})) - removeListener(&ln1) - checkSetHeaderError() - }) - - It("uses s.Port if set to a non-zero value", func() { - s.Port = 8443 - addListener(":443", &ln1) - checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3-29=":8443"; ma=2592000`}})) - removeListener(&ln1) - checkSetHeaderError() - }) - - It("uses s.Addr if listeners don't have ports available", func() { - s.Addr = ":443" - mln := &noPortListener{newMockAddrListener("")} - mln.EXPECT().Addr() - ln1 = mln - s.addListener(&ln1) - checkSetHeaders(Equal(expected)) - s.removeListener(&ln1) - checkSetHeaderError() - }) - - It("properly announces multiple listeners", func() { - addListener(":443", &ln1) - addListener(":8443", &ln2) - checkSetHeaders(Or( - Equal(http.Header{"Alt-Svc": {`h3-29=":443"; ma=2592000,h3-29=":8443"; ma=2592000`}}), - Equal(http.Header{"Alt-Svc": {`h3-29=":8443"; ma=2592000,h3-29=":443"; ma=2592000`}}), - )) - removeListener(&ln1) - removeListener(&ln2) - checkSetHeaderError() - }) - }) - - It("errors when ListenAndServe is called with s.TLSConfig nil", func() { - Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig)) - }) - - It("should nop-Close() when s.server is nil", func() { - Expect((&Server{}).Close()).To(Succeed()) - }) - - It("errors when ListenAndServeTLS is called after Close", func() { - serv := &Server{} - Expect(serv.Close()).To(Succeed()) - Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed)) - }) - - It("handles concurrent Serve and Close", func() { - addr, err := net.ResolveUDPAddr("udp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - c, err := net.ListenUDP("udp", addr) - Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - s.Serve(c) - }() - runtime.Gosched() - s.Close() - Eventually(done).Should(BeClosed()) - }) - - Context("ConfigureTLSConfig", func() { - var tlsConf *tls.Config - var ch *tls.ClientHelloInfo - - BeforeEach(func() { - tlsConf = &tls.Config{} - ch = &tls.ClientHelloInfo{} - }) - - It("advertises v1 by default", func() { - tlsConf = ConfigureTLSConfig(tlsConf) - Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) - - config, err := tlsConf.GetConfigForClient(ch) - Expect(err).NotTo(HaveOccurred()) - Expect(config.NextProtos).To(Equal([]string{nextProtoH3})) - }) - - It("advertises h3-29 for draft-29", func() { - tlsConf = ConfigureTLSConfig(tlsConf) - Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) - - ch.Conn = newMockConn(protocol.VersionDraft29) - config, err := tlsConf.GetConfigForClient(ch) - Expect(err).NotTo(HaveOccurred()) - Expect(config.NextProtos).To(Equal([]string{nextProtoH3Draft29})) - }) - }) - - Context("Serve", func() { - origQuicListen := quicListen - - AfterEach(func() { - quicListen = origQuicListen - }) - - It("serves a packet conn", func() { - ln := newMockAddrListener(":443") - conn := &net.UDPConn{} - quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { - Expect(c).To(Equal(conn)) - return ln, nil - } - - s := &Server{ - TLSConfig: &tls.Config{}, - } - - stopAccept := make(chan struct{}) - ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { - <-stopAccept - return nil, errors.New("closed") - }) - ln.EXPECT().Addr() // generate alt-svc headers - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - s.Serve(conn) - }() - - Consistently(done).ShouldNot(BeClosed()) - ln.EXPECT().Close().Do(func() { close(stopAccept) }) - Expect(s.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("serves two packet conns", func() { - ln1 := newMockAddrListener(":443") - ln2 := newMockAddrListener(":8443") - lns := make(chan quic.EarlyListener, 2) - lns <- ln1 - lns <- ln2 - conn1 := &net.UDPConn{} - conn2 := &net.UDPConn{} - quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { - return <-lns, nil - } - - s := &Server{ - TLSConfig: &tls.Config{}, - } - - stopAccept1 := make(chan struct{}) - ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { - <-stopAccept1 - return nil, errors.New("closed") - }) - ln1.EXPECT().Addr() // generate alt-svc headers - stopAccept2 := make(chan struct{}) - ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { - <-stopAccept2 - return nil, errors.New("closed") - }) - ln2.EXPECT().Addr() - - done1 := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done1) - s.Serve(conn1) - }() - done2 := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done2) - s.Serve(conn2) - }() - - Consistently(done1).ShouldNot(BeClosed()) - Expect(done2).ToNot(BeClosed()) - ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) - ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) - Expect(s.Close()).To(Succeed()) - Eventually(done1).Should(BeClosed()) - Eventually(done2).Should(BeClosed()) - }) - }) - - Context("ServeListener", func() { - origQuicListen := quicListen - - AfterEach(func() { - quicListen = origQuicListen - }) - - It("serves a listener", func() { - var called int32 - ln := newMockAddrListener(":443") - quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { - atomic.StoreInt32(&called, 1) - return ln, nil - } - - s := &Server{} - - stopAccept := make(chan struct{}) - ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { - <-stopAccept - return nil, errors.New("closed") - }) - ln.EXPECT().Addr() // generate alt-svc headers - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - s.ServeListener(ln) - }() - - Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) - Consistently(done).ShouldNot(BeClosed()) - ln.EXPECT().Close().Do(func() { close(stopAccept) }) - Expect(s.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("serves two listeners", func() { - var called int32 - ln1 := newMockAddrListener(":443") - ln2 := newMockAddrListener(":8443") - lns := make(chan quic.EarlyListener, 2) - lns <- ln1 - lns <- ln2 - quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { - atomic.StoreInt32(&called, 1) - return <-lns, nil - } - - s := &Server{} - - stopAccept1 := make(chan struct{}) - ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { - <-stopAccept1 - return nil, errors.New("closed") - }) - ln1.EXPECT().Addr() // generate alt-svc headers - stopAccept2 := make(chan struct{}) - ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { - <-stopAccept2 - return nil, errors.New("closed") - }) - ln2.EXPECT().Addr() - - done1 := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done1) - s.ServeListener(ln1) - }() - done2 := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done2) - s.ServeListener(ln2) - }() - - Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) - Consistently(done1).ShouldNot(BeClosed()) - Expect(done2).ToNot(BeClosed()) - ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) - ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) - Expect(s.Close()).To(Succeed()) - Eventually(done1).Should(BeClosed()) - Eventually(done2).Should(BeClosed()) - }) - }) - - Context("ListenAndServe", func() { - BeforeEach(func() { - s.Addr = "localhost:0" - }) - - AfterEach(func() { - Expect(s.Close()).To(Succeed()) - }) - - checkGetConfigForClientVersions := func(conf *tls.Config) { - c, err := conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft29)}) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3Draft29})) - c, err = conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.Version1)}) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3})) - } - - It("uses the quic.Config to start the QUIC server", func() { - conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond} - var receivedConf *quic.Config - quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (quic.EarlyListener, error) { - receivedConf = config - return nil, errors.New("listen err") - } - s.QuicConfig = conf - Expect(s.ListenAndServe()).To(HaveOccurred()) - Expect(receivedConf).To(Equal(conf)) - }) - - It("sets the GetConfigForClient and replaces the ALPN token to the tls.Config, if the GetConfigForClient callback is not set", func() { - tlsConf := &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - NextProtos: []string{"foo", "bar"}, - } - var receivedConf *tls.Config - quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { - receivedConf = tlsConf - return nil, errors.New("listen err") - } - s.TLSConfig = tlsConf - Expect(s.ListenAndServe()).To(HaveOccurred()) - Expect(receivedConf.NextProtos).To(BeEmpty()) - Expect(receivedConf.ClientAuth).To(BeZero()) - // make sure the original tls.Config was not modified - Expect(tlsConf.NextProtos).To(Equal([]string{"foo", "bar"})) - // make sure that the config returned from the GetConfigForClient callback sets the fields of the original config - conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert)) - checkGetConfigForClientVersions(receivedConf) - }) - - It("sets the GetConfigForClient callback if no tls.Config is given", func() { - var receivedConf *tls.Config - quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { - receivedConf = tlsConf - return nil, errors.New("listen err") - } - Expect(s.ListenAndServe()).To(HaveOccurred()) - Expect(receivedConf).ToNot(BeNil()) - checkGetConfigForClientVersions(receivedConf) - }) - - It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { - tlsConf := &tls.Config{ - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - return &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - NextProtos: []string{"foo", "bar"}, - }, nil - }, - } - - var receivedConf *tls.Config - quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { - receivedConf = conf - return nil, errors.New("listen err") - } - s.TLSConfig = tlsConf - Expect(s.ListenAndServe()).To(HaveOccurred()) - // check that the original config was not modified - conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) - // check that the config returned by the GetConfigForClient callback uses the returned config - conf, err = receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert)) - checkGetConfigForClientVersions(receivedConf) - }) - - It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { - tlsClientConf := &tls.Config{NextProtos: []string{"foo", "bar"}} - tlsConf := &tls.Config{ - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - return tlsClientConf, nil - }, - } - - var receivedConf *tls.Config - quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { - receivedConf = conf - return nil, errors.New("listen err") - } - s.TLSConfig = tlsConf - Expect(s.ListenAndServe()).To(HaveOccurred()) - // check that the original config was not modified - conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) - checkGetConfigForClientVersions(receivedConf) - }) - - It("works if GetConfigForClient returns a nil tls.Config", func() { - tlsConf := &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }} - - var receivedConf *tls.Config - quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { - receivedConf = conf - return nil, errors.New("listen err") - } - s.TLSConfig = tlsConf - Expect(s.ListenAndServe()).To(HaveOccurred()) - conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf).ToNot(BeNil()) - checkGetConfigForClientVersions(receivedConf) - }) - }) - - It("closes gracefully", func() { - Expect(s.CloseGracefully(0)).To(Succeed()) - }) - - It("errors when listening fails", func() { - testErr := errors.New("listen error") - quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { - return nil, testErr - } - fullpem, privkey := testdata.GetCertificatePaths() - Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr)) - }) - - It("supports H3_DATAGRAM", func() { - s.EnableDatagrams = true - var receivedConf *quic.Config - quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (quic.EarlyListener, error) { - receivedConf = config - return nil, errors.New("listen err") - } - Expect(s.ListenAndServe()).To(HaveOccurred()) - Expect(receivedConf.EnableDatagrams).To(BeTrue()) - }) -}) diff --git a/internal/logging/frame.go b/internal/logging/frame.go deleted file mode 100644 index c2897747..00000000 --- a/internal/logging/frame.go +++ /dev/null @@ -1,66 +0,0 @@ -package logging - -import "github.com/imroc/req/v3/internal/wire" - -// A Frame is a QUIC frame -type Frame interface{} - -// The AckRange is used within the AckFrame. -// It is a range of packet numbers that is being acknowledged. -type AckRange = wire.AckRange - -type ( - // An AckFrame is an ACK frame. - AckFrame = wire.AckFrame - // A ConnectionCloseFrame is a CONNECTION_CLOSE frame. - ConnectionCloseFrame = wire.ConnectionCloseFrame - // A DataBlockedFrame is a DATA_BLOCKED frame. - DataBlockedFrame = wire.DataBlockedFrame - // A HandshakeDoneFrame is a HANDSHAKE_DONE frame. - HandshakeDoneFrame = wire.HandshakeDoneFrame - // A MaxDataFrame is a MAX_DATA frame. - MaxDataFrame = wire.MaxDataFrame - // A MaxStreamDataFrame is a MAX_STREAM_DATA frame. - MaxStreamDataFrame = wire.MaxStreamDataFrame - // A MaxStreamsFrame is a MAX_STREAMS_FRAME. - MaxStreamsFrame = wire.MaxStreamsFrame - // A NewConnectionIDFrame is a NEW_CONNECTION_ID frame. - NewConnectionIDFrame = wire.NewConnectionIDFrame - // A NewTokenFrame is a NEW_TOKEN frame. - NewTokenFrame = wire.NewTokenFrame - // A PathChallengeFrame is a PATH_CHALLENGE frame. - PathChallengeFrame = wire.PathChallengeFrame - // A PathResponseFrame is a PATH_RESPONSE frame. - PathResponseFrame = wire.PathResponseFrame - // A PingFrame is a PING frame. - PingFrame = wire.PingFrame - // A ResetStreamFrame is a RESET_STREAM frame. - ResetStreamFrame = wire.ResetStreamFrame - // A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame. - RetireConnectionIDFrame = wire.RetireConnectionIDFrame - // A StopSendingFrame is a STOP_SENDING frame. - StopSendingFrame = wire.StopSendingFrame - // A StreamsBlockedFrame is a STREAMS_BLOCKED frame. - StreamsBlockedFrame = wire.StreamsBlockedFrame - // A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame. - StreamDataBlockedFrame = wire.StreamDataBlockedFrame -) - -// A CryptoFrame is a CRYPTO frame. -type CryptoFrame struct { - Offset ByteCount - Length ByteCount -} - -// A StreamFrame is a STREAM frame. -type StreamFrame struct { - StreamID StreamID - Offset ByteCount - Length ByteCount - Fin bool -} - -// A DatagramFrame is a DATAGRAM frame. -type DatagramFrame struct { - Length ByteCount -} diff --git a/internal/logging/interface.go b/internal/logging/interface.go deleted file mode 100644 index 8e6eeba6..00000000 --- a/internal/logging/interface.go +++ /dev/null @@ -1,135 +0,0 @@ -// Package logging defines a logging interface for quic-go. -// This package should not be considered stable -package logging - -import ( - "context" - "github.com/lucas-clemente/quic-go" - "net" - "time" - - "github.com/imroc/req/v3/internal/utils" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - "github.com/imroc/req/v3/internal/wire" -) - -type ( - // A ByteCount is used to count bytes. - ByteCount = protocol.ByteCount - // A ConnectionID is a QUIC Connection ID. - ConnectionID = protocol.ConnectionID - // The EncryptionLevel is the encryption level of a packet. - EncryptionLevel = protocol.EncryptionLevel - // The KeyPhase is the key phase of the 1-RTT keys. - KeyPhase = protocol.KeyPhase - // The KeyPhaseBit is the value of the key phase bit of the 1-RTT packets. - KeyPhaseBit = protocol.KeyPhaseBit - // The PacketNumber is the packet number of a packet. - PacketNumber = protocol.PacketNumber - // The Perspective is the role of a QUIC endpoint (client or server). - Perspective = protocol.Perspective - // A StatelessResetToken is a stateless reset token. - StatelessResetToken = protocol.StatelessResetToken - // The StreamID is the stream ID. - StreamID = protocol.StreamID - // The StreamNum is the number of the stream. - StreamNum = protocol.StreamNum - // The StreamType is the type of the stream (unidirectional or bidirectional). - StreamType = protocol.StreamType - // The VersionNumber is the QUIC version. - VersionNumber = quic.VersionNumber - - // The Header is the QUIC packet header, before removing header protection. - Header = wire.Header - // The ExtendedHeader is the QUIC packet header, after removing header protection. - ExtendedHeader = wire.ExtendedHeader - // The TransportParameters are QUIC transport parameters. - TransportParameters = wire.TransportParameters - // The PreferredAddress is the preferred address sent in the transport parameters. - PreferredAddress = wire.PreferredAddress - - // A TransportError is a transport-level error code. - TransportError = qerr.TransportErrorCode - // An ApplicationError is an application-defined error code. - ApplicationError = qerr.TransportErrorCode - - // The RTTStats contain statistics used by the congestion controller. - RTTStats = utils.RTTStats -) - -const ( - // KeyPhaseZero is key phase bit 0 - KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero - // KeyPhaseOne is key phase bit 1 - KeyPhaseOne KeyPhaseBit = protocol.KeyPhaseOne -) - -const ( - // PerspectiveServer is used for a QUIC server - PerspectiveServer Perspective = protocol.PerspectiveServer - // PerspectiveClient is used for a QUIC client - PerspectiveClient Perspective = protocol.PerspectiveClient -) - -const ( - // EncryptionInitial is the Initial encryption level - EncryptionInitial EncryptionLevel = protocol.EncryptionInitial - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake EncryptionLevel = protocol.EncryptionHandshake - // Encryption1RTT is the 1-RTT encryption level - Encryption1RTT EncryptionLevel = protocol.Encryption1RTT - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT EncryptionLevel = protocol.Encryption0RTT -) - -const ( - // StreamTypeUni is a unidirectional stream - StreamTypeUni = protocol.StreamTypeUni - // StreamTypeBidi is a bidirectional stream - StreamTypeBidi = protocol.StreamTypeBidi -) - -// A Tracer traces events. -type Tracer interface { - // TracerForConnection requests a new tracer for a connection. - // The ODCID is the original destination connection ID: - // The destination connection ID that the client used on the first Initial packet it sent on this connection. - // If nil is returned, tracing will be disabled for this connection. - TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer - - SentPacket(net.Addr, *Header, ByteCount, []Frame) - DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) -} - -// A ConnectionTracer records events. -type ConnectionTracer interface { - StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) - NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) - ClosedConnection(error) - SentTransportParameters(*TransportParameters) - ReceivedTransportParameters(*TransportParameters) - RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT - SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) - ReceivedVersionNegotiationPacket(*Header, []VersionNumber) - ReceivedRetry(*Header) - ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) - BufferedPacket(PacketType) - DroppedPacket(PacketType, ByteCount, PacketDropReason) - UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) - AcknowledgedPacket(EncryptionLevel, PacketNumber) - LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) - UpdatedCongestionState(CongestionState) - UpdatedPTOCount(value uint32) - UpdatedKeyFromTLS(EncryptionLevel, Perspective) - UpdatedKey(generation KeyPhase, remote bool) - DroppedEncryptionLevel(EncryptionLevel) - DroppedKey(generation KeyPhase) - SetLossTimer(TimerType, EncryptionLevel, time.Time) - LossTimerExpired(TimerType, EncryptionLevel) - LossTimerCanceled() - // Close is called when the connection is closed. - Close() - Debug(name, msg string) -} diff --git a/internal/logging/logging_suite_test.go b/internal/logging/logging_suite_test.go deleted file mode 100644 index 0a81943d..00000000 --- a/internal/logging/logging_suite_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package logging - -import ( - "testing" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestLogging(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Logging Suite") -} - -var mockCtrl *gomock.Controller - -var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) -}) - -var _ = AfterEach(func() { - mockCtrl.Finish() -}) diff --git a/internal/logging/mock_connection_tracer_test.go b/internal/logging/mock_connection_tracer_test.go deleted file mode 100644 index 4d628fee..00000000 --- a/internal/logging/mock_connection_tracer_test.go +++ /dev/null @@ -1,352 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/logging (interfaces: ConnectionTracer) - -// Package logging is a generated GoMock package. -package logging - -import ( - "github.com/lucas-clemente/quic-go" - net "net" - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" - utils "github.com/imroc/req/v3/internal/utils" - wire "github.com/imroc/req/v3/internal/wire" -) - -// MockConnectionTracer is a mock of ConnectionTracer interface. -type MockConnectionTracer struct { - ctrl *gomock.Controller - recorder *MockConnectionTracerMockRecorder -} - -// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. -type MockConnectionTracerMockRecorder struct { - mock *MockConnectionTracer -} - -// NewMockConnectionTracer creates a new mock instance. -func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { - mock := &MockConnectionTracer{ctrl: ctrl} - mock.recorder = &MockConnectionTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { - return m.recorder -} - -// AcknowledgedPacket mocks base method. -func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) -} - -// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. -func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) -} - -// BufferedPacket mocks base method. -func (m *MockConnectionTracer) BufferedPacket(arg0 PacketType) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "BufferedPacket", arg0) -} - -// BufferedPacket indicates an expected call of BufferedPacket. -func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0) -} - -// Close mocks base method. -func (m *MockConnectionTracer) Close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Close") -} - -// Close indicates an expected call of Close. -func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) -} - -// ClosedConnection mocks base method. -func (m *MockConnectionTracer) ClosedConnection(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ClosedConnection", arg0) -} - -// ClosedConnection indicates an expected call of ClosedConnection. -func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) -} - -// Debug mocks base method. -func (m *MockConnectionTracer) Debug(arg0, arg1 string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Debug", arg0, arg1) -} - -// Debug indicates an expected call of Debug. -func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) -} - -// DroppedEncryptionLevel mocks base method. -func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) -} - -// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. -func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) -} - -// DroppedKey mocks base method. -func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedKey", arg0) -} - -// DroppedKey indicates an expected call of DroppedKey. -func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) -} - -// DroppedPacket mocks base method. -func (m *MockConnectionTracer) DroppedPacket(arg0 PacketType, arg1 protocol.ByteCount, arg2 PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) -} - -// LossTimerCanceled mocks base method. -func (m *MockConnectionTracer) LossTimerCanceled() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerCanceled") -} - -// LossTimerCanceled indicates an expected call of LossTimerCanceled. -func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) -} - -// LossTimerExpired mocks base method. -func (m *MockConnectionTracer) LossTimerExpired(arg0 TimerType, arg1 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) -} - -// LossTimerExpired indicates an expected call of LossTimerExpired. -func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) -} - -// LostPacket mocks base method. -func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 PacketLossReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) -} - -// LostPacket indicates an expected call of LostPacket. -func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) -} - -// NegotiatedVersion mocks base method. -func (m *MockConnectionTracer) NegotiatedVersion(arg0 quic.VersionNumber, arg1, arg2 []quic.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) -} - -// NegotiatedVersion indicates an expected call of NegotiatedVersion. -func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) -} - -// ReceivedPacket mocks base method. -func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2) -} - -// ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedPacket), arg0, arg1, arg2) -} - -// ReceivedRetry mocks base method. -func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedRetry", arg0) -} - -// ReceivedRetry indicates an expected call of ReceivedRetry. -func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) -} - -// ReceivedTransportParameters mocks base method. -func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedTransportParameters", arg0) -} - -// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. -func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) -} - -// ReceivedVersionNegotiationPacket mocks base method. -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header, arg1 []quic.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1) -} - -// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1) -} - -// RestoredTransportParameters mocks base method. -func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RestoredTransportParameters", arg0) -} - -// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. -func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) -} - -// SentPacket mocks base method. -func (m *MockConnectionTracer) SentPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockConnectionTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// SentTransportParameters mocks base method. -func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentTransportParameters", arg0) -} - -// SentTransportParameters indicates an expected call of SentTransportParameters. -func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) -} - -// SetLossTimer mocks base method. -func (m *MockConnectionTracer) SetLossTimer(arg0 TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) -} - -// SetLossTimer indicates an expected call of SetLossTimer. -func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) -} - -// StartedConnection mocks base method. -func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) -} - -// StartedConnection indicates an expected call of StartedConnection. -func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) -} - -// UpdatedCongestionState mocks base method. -func (m *MockConnectionTracer) UpdatedCongestionState(arg0 CongestionState) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedCongestionState", arg0) -} - -// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. -func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) -} - -// UpdatedKey mocks base method. -func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKey", arg0, arg1) -} - -// UpdatedKey indicates an expected call of UpdatedKey. -func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) -} - -// UpdatedKeyFromTLS mocks base method. -func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) -} - -// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. -func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) -} - -// UpdatedMetrics mocks base method. -func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) -} - -// UpdatedMetrics indicates an expected call of UpdatedMetrics. -func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) -} - -// UpdatedPTOCount mocks base method. -func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedPTOCount", arg0) -} - -// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. -func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) -} diff --git a/internal/logging/mock_tracer_test.go b/internal/logging/mock_tracer_test.go deleted file mode 100644 index 98e245d6..00000000 --- a/internal/logging/mock_tracer_test.go +++ /dev/null @@ -1,76 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/logging (interfaces: Tracer) - -// Package logging is a generated GoMock package. -package logging - -import ( - context "context" - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" - wire "github.com/imroc/req/v3/internal/wire" -) - -// MockTracer is a mock of Tracer interface. -type MockTracer struct { - ctrl *gomock.Controller - recorder *MockTracerMockRecorder -} - -// MockTracerMockRecorder is the mock recorder for MockTracer. -type MockTracerMockRecorder struct { - mock *MockTracer -} - -// NewMockTracer creates a new mock instance. -func NewMockTracer(ctrl *gomock.Controller) *MockTracer { - mock := &MockTracer{ctrl: ctrl} - mock.recorder = &MockTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTracer) EXPECT() *MockTracerMockRecorder { - return m.recorder -} - -// DroppedPacket mocks base method. -func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 PacketType, arg2 protocol.ByteCount, arg3 PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) -} - -// SentPacket mocks base method. -func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// TracerForConnection mocks base method. -func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) - ret0, _ := ret[0].(ConnectionTracer) - return ret0 -} - -// TracerForConnection indicates an expected call of TracerForConnection. -func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) -} diff --git a/internal/logging/mockgen.go b/internal/logging/mockgen.go deleted file mode 100644 index 48750afb..00000000 --- a/internal/logging/mockgen.go +++ /dev/null @@ -1,4 +0,0 @@ -package logging - -//go:generate sh -c "mockgen -package logging -self_package github.com/imroc/req/v3/internal/logging -destination mock_connection_tracer_test.go github.com/imroc/req/v3/internal/logging ConnectionTracer" -//go:generate sh -c "mockgen -package logging -self_package github.com/imroc/req/v3/internal/logging -destination mock_tracer_test.go github.com/imroc/req/v3/internal/logging Tracer" diff --git a/internal/logging/multiplex.go b/internal/logging/multiplex.go deleted file mode 100644 index 8280e8cd..00000000 --- a/internal/logging/multiplex.go +++ /dev/null @@ -1,219 +0,0 @@ -package logging - -import ( - "context" - "net" - "time" -) - -type tracerMultiplexer struct { - tracers []Tracer -} - -var _ Tracer = &tracerMultiplexer{} - -// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. -func NewMultiplexedTracer(tracers ...Tracer) Tracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &tracerMultiplexer{tracers} -} - -func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer { - var connTracers []ConnectionTracer - for _, t := range m.tracers { - if ct := t.TracerForConnection(ctx, p, odcid); ct != nil { - connTracers = append(connTracers, ct) - } - } - return NewMultiplexedConnectionTracer(connTracers...) -} - -func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.SentPacket(remote, hdr, size, frames) - } -} - -func (m *tracerMultiplexer) DroppedPacket(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(remote, typ, size, reason) - } -} - -type connTracerMultiplexer struct { - tracers []ConnectionTracer -} - -var _ ConnectionTracer = &connTracerMultiplexer{} - -// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. -func NewMultiplexedConnectionTracer(tracers ...ConnectionTracer) ConnectionTracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &connTracerMultiplexer{tracers: tracers} -} - -func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { - for _, t := range m.tracers { - t.StartedConnection(local, remote, srcConnID, destConnID) - } -} - -func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { - for _, t := range m.tracers { - t.NegotiatedVersion(chosen, clientVersions, serverVersions) - } -} - -func (m *connTracerMultiplexer) ClosedConnection(e error) { - for _, t := range m.tracers { - t.ClosedConnection(e) - } -} - -func (m *connTracerMultiplexer) SentTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.SentTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) ReceivedTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.ReceivedTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.RestoredTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) { - for _, t := range m.tracers { - t.SentPacket(hdr, size, ack, frames) - } -} - -func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(hdr *Header, versions []VersionNumber) { - for _, t := range m.tracers { - t.ReceivedVersionNegotiationPacket(hdr, versions) - } -} - -func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) { - for _, t := range m.tracers { - t.ReceivedRetry(hdr) - } -} - -func (m *connTracerMultiplexer) ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.ReceivedPacket(hdr, size, frames) - } -} - -func (m *connTracerMultiplexer) BufferedPacket(typ PacketType) { - for _, t := range m.tracers { - t.BufferedPacket(typ) - } -} - -func (m *connTracerMultiplexer) DroppedPacket(typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(typ, size, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedCongestionState(state CongestionState) { - for _, t := range m.tracers { - t.UpdatedCongestionState(state) - } -} - -func (m *connTracerMultiplexer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFLight ByteCount, packetsInFlight int) { - for _, t := range m.tracers { - t.UpdatedMetrics(rttStats, cwnd, bytesInFLight, packetsInFlight) - } -} - -func (m *connTracerMultiplexer) AcknowledgedPacket(encLevel EncryptionLevel, pn PacketNumber) { - for _, t := range m.tracers { - t.AcknowledgedPacket(encLevel, pn) - } -} - -func (m *connTracerMultiplexer) LostPacket(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { - for _, t := range m.tracers { - t.LostPacket(encLevel, pn, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedPTOCount(value uint32) { - for _, t := range m.tracers { - t.UpdatedPTOCount(value) - } -} - -func (m *connTracerMultiplexer) UpdatedKeyFromTLS(encLevel EncryptionLevel, perspective Perspective) { - for _, t := range m.tracers { - t.UpdatedKeyFromTLS(encLevel, perspective) - } -} - -func (m *connTracerMultiplexer) UpdatedKey(generation KeyPhase, remote bool) { - for _, t := range m.tracers { - t.UpdatedKey(generation, remote) - } -} - -func (m *connTracerMultiplexer) DroppedEncryptionLevel(encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.DroppedEncryptionLevel(encLevel) - } -} - -func (m *connTracerMultiplexer) DroppedKey(generation KeyPhase) { - for _, t := range m.tracers { - t.DroppedKey(generation) - } -} - -func (m *connTracerMultiplexer) SetLossTimer(typ TimerType, encLevel EncryptionLevel, exp time.Time) { - for _, t := range m.tracers { - t.SetLossTimer(typ, encLevel, exp) - } -} - -func (m *connTracerMultiplexer) LossTimerExpired(typ TimerType, encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.LossTimerExpired(typ, encLevel) - } -} - -func (m *connTracerMultiplexer) LossTimerCanceled() { - for _, t := range m.tracers { - t.LossTimerCanceled() - } -} - -func (m *connTracerMultiplexer) Debug(name, msg string) { - for _, t := range m.tracers { - t.Debug(name, msg) - } -} - -func (m *connTracerMultiplexer) Close() { - for _, t := range m.tracers { - t.Close() - } -} diff --git a/internal/logging/multiplex_test.go b/internal/logging/multiplex_test.go deleted file mode 100644 index acc0e9d4..00000000 --- a/internal/logging/multiplex_test.go +++ /dev/null @@ -1,266 +0,0 @@ -package logging - -import ( - "context" - "errors" - "net" - "time" - - "github.com/imroc/req/v3/internal/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Tracing", func() { - Context("Tracer", func() { - It("returns a nil tracer if no tracers are passed in", func() { - Expect(NewMultiplexedTracer()).To(BeNil()) - }) - - It("returns the raw tracer if only one tracer is passed in", func() { - tr := NewMockTracer(mockCtrl) - tracer := NewMultiplexedTracer(tr) - Expect(tracer).To(BeAssignableToTypeOf(&MockTracer{})) - }) - - Context("tracing events", func() { - var ( - tracer Tracer - tr1, tr2 *MockTracer - ) - - BeforeEach(func() { - tr1 = NewMockTracer(mockCtrl) - tr2 = NewMockTracer(mockCtrl) - tracer = NewMultiplexedTracer(tr1, tr2) - }) - - It("multiplexes the TracerForConnection call", func() { - ctx := context.Background() - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - }) - - It("uses multiple connection tracers", func() { - ctx := context.Background() - ctr1 := NewMockConnectionTracer(mockCtrl) - ctr2 := NewMockConnectionTracer(mockCtrl) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) - ctr1.EXPECT().LossTimerCanceled() - ctr2.EXPECT().LossTimerCanceled() - tr.LossTimerCanceled() - }) - - It("handles tracers that return a nil ConnectionTracer", func() { - ctx := context.Background() - ctr1 := NewMockConnectionTracer(mockCtrl) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) - ctr1.EXPECT().LossTimerCanceled() - tr.LossTimerCanceled() - }) - - It("returns nil when all tracers return a nil ConnectionTracer", func() { - ctx := context.Background() - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - Expect(tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil()) - }) - - It("traces the PacketSent event", func() { - remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} - f := &MaxDataFrame{MaximumData: 1337} - tr1.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) - tr2.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) - tracer.SentPacket(remote, hdr, 1024, []Frame{f}) - }) - - It("traces the PacketDropped event", func() { - remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - tr1.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) - tr2.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) - tracer.DroppedPacket(remote, PacketTypeRetry, 1024, PacketDropDuplicate) - }) - }) - }) - - Context("Connection Tracer", func() { - var ( - tracer ConnectionTracer - tr1 *MockConnectionTracer - tr2 *MockConnectionTracer - ) - - BeforeEach(func() { - tr1 = NewMockConnectionTracer(mockCtrl) - tr2 = NewMockConnectionTracer(mockCtrl) - tracer = NewMultiplexedConnectionTracer(tr1, tr2) - }) - - It("trace the ConnectionStarted event", func() { - local := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4)} - remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - tr1.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) - tr2.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) - tracer.StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) - }) - - It("traces the ClosedConnection event", func() { - e := errors.New("test err") - tr1.EXPECT().ClosedConnection(e) - tr2.EXPECT().ClosedConnection(e) - tracer.ClosedConnection(e) - }) - - It("traces the SentTransportParameters event", func() { - tp := &wire.TransportParameters{InitialMaxData: 1337} - tr1.EXPECT().SentTransportParameters(tp) - tr2.EXPECT().SentTransportParameters(tp) - tracer.SentTransportParameters(tp) - }) - - It("traces the ReceivedTransportParameters event", func() { - tp := &wire.TransportParameters{InitialMaxData: 1337} - tr1.EXPECT().ReceivedTransportParameters(tp) - tr2.EXPECT().ReceivedTransportParameters(tp) - tracer.ReceivedTransportParameters(tp) - }) - - It("traces the RestoredTransportParameters event", func() { - tp := &wire.TransportParameters{InitialMaxData: 1337} - tr1.EXPECT().RestoredTransportParameters(tp) - tr2.EXPECT().RestoredTransportParameters(tp) - tracer.RestoredTransportParameters(tp) - }) - - It("traces the SentPacket event", func() { - hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} - ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}} - ping := &PingFrame{} - tr1.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tr2.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tracer.SentPacket(hdr, 1337, ack, []Frame{ping}) - }) - - It("traces the ReceivedVersionNegotiationPacket event", func() { - hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} - tr1.EXPECT().ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) - tr2.EXPECT().ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) - tracer.ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) - }) - - It("traces the ReceivedRetry event", func() { - hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} - tr1.EXPECT().ReceivedRetry(hdr) - tr2.EXPECT().ReceivedRetry(hdr) - tracer.ReceivedRetry(hdr) - }) - - It("traces the ReceivedPacket event", func() { - hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} - ping := &PingFrame{} - tr1.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) - tr2.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) - tracer.ReceivedPacket(hdr, 1337, []Frame{ping}) - }) - - It("traces the BufferedPacket event", func() { - tr1.EXPECT().BufferedPacket(PacketTypeHandshake) - tr2.EXPECT().BufferedPacket(PacketTypeHandshake) - tracer.BufferedPacket(PacketTypeHandshake) - }) - - It("traces the DroppedPacket event", func() { - tr1.EXPECT().DroppedPacket(PacketTypeInitial, ByteCount(1337), PacketDropHeaderParseError) - tr2.EXPECT().DroppedPacket(PacketTypeInitial, ByteCount(1337), PacketDropHeaderParseError) - tracer.DroppedPacket(PacketTypeInitial, 1337, PacketDropHeaderParseError) - }) - - It("traces the UpdatedCongestionState event", func() { - tr1.EXPECT().UpdatedCongestionState(CongestionStateRecovery) - tr2.EXPECT().UpdatedCongestionState(CongestionStateRecovery) - tracer.UpdatedCongestionState(CongestionStateRecovery) - }) - - It("traces the UpdatedMetrics event", func() { - rttStats := &RTTStats{} - rttStats.UpdateRTT(time.Second, 0, time.Now()) - tr1.EXPECT().UpdatedMetrics(rttStats, ByteCount(1337), ByteCount(42), 13) - tr2.EXPECT().UpdatedMetrics(rttStats, ByteCount(1337), ByteCount(42), 13) - tracer.UpdatedMetrics(rttStats, 1337, 42, 13) - }) - - It("traces the AcknowledgedPacket event", func() { - tr1.EXPECT().AcknowledgedPacket(EncryptionHandshake, PacketNumber(42)) - tr2.EXPECT().AcknowledgedPacket(EncryptionHandshake, PacketNumber(42)) - tracer.AcknowledgedPacket(EncryptionHandshake, 42) - }) - - It("traces the LostPacket event", func() { - tr1.EXPECT().LostPacket(EncryptionHandshake, PacketNumber(42), PacketLossReorderingThreshold) - tr2.EXPECT().LostPacket(EncryptionHandshake, PacketNumber(42), PacketLossReorderingThreshold) - tracer.LostPacket(EncryptionHandshake, 42, PacketLossReorderingThreshold) - }) - - It("traces the UpdatedPTOCount event", func() { - tr1.EXPECT().UpdatedPTOCount(uint32(88)) - tr2.EXPECT().UpdatedPTOCount(uint32(88)) - tracer.UpdatedPTOCount(88) - }) - - It("traces the UpdatedKeyFromTLS event", func() { - tr1.EXPECT().UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) - tr2.EXPECT().UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) - tracer.UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) - }) - - It("traces the UpdatedKey event", func() { - tr1.EXPECT().UpdatedKey(KeyPhase(42), true) - tr2.EXPECT().UpdatedKey(KeyPhase(42), true) - tracer.UpdatedKey(KeyPhase(42), true) - }) - - It("traces the DroppedEncryptionLevel event", func() { - tr1.EXPECT().DroppedEncryptionLevel(EncryptionHandshake) - tr2.EXPECT().DroppedEncryptionLevel(EncryptionHandshake) - tracer.DroppedEncryptionLevel(EncryptionHandshake) - }) - - It("traces the DroppedKey event", func() { - tr1.EXPECT().DroppedKey(KeyPhase(123)) - tr2.EXPECT().DroppedKey(KeyPhase(123)) - tracer.DroppedKey(123) - }) - - It("traces the SetLossTimer event", func() { - now := time.Now() - tr1.EXPECT().SetLossTimer(TimerTypePTO, EncryptionHandshake, now) - tr2.EXPECT().SetLossTimer(TimerTypePTO, EncryptionHandshake, now) - tracer.SetLossTimer(TimerTypePTO, EncryptionHandshake, now) - }) - - It("traces the LossTimerExpired event", func() { - tr1.EXPECT().LossTimerExpired(TimerTypePTO, EncryptionHandshake) - tr2.EXPECT().LossTimerExpired(TimerTypePTO, EncryptionHandshake) - tracer.LossTimerExpired(TimerTypePTO, EncryptionHandshake) - }) - - It("traces the LossTimerCanceled event", func() { - tr1.EXPECT().LossTimerCanceled() - tr2.EXPECT().LossTimerCanceled() - tracer.LossTimerCanceled() - }) - - It("traces the Close event", func() { - tr1.EXPECT().Close() - tr2.EXPECT().Close() - tracer.Close() - }) - }) -}) diff --git a/internal/logging/packet_header.go b/internal/logging/packet_header.go deleted file mode 100644 index dd7682ec..00000000 --- a/internal/logging/packet_header.go +++ /dev/null @@ -1,27 +0,0 @@ -package logging - -import ( - "github.com/imroc/req/v3/internal/protocol" -) - -// PacketTypeFromHeader determines the packet type from a *wire.Header. -func PacketTypeFromHeader(hdr *Header) PacketType { - if !hdr.IsLongHeader { - return PacketType1RTT - } - if hdr.Version == 0 { - return PacketTypeVersionNegotiation - } - switch hdr.Type { - case protocol.PacketTypeInitial: - return PacketTypeInitial - case protocol.PacketTypeHandshake: - return PacketTypeHandshake - case protocol.PacketType0RTT: - return PacketType0RTT - case protocol.PacketTypeRetry: - return PacketTypeRetry - default: - return PacketTypeNotDetermined - } -} diff --git a/internal/logging/packet_header_test.go b/internal/logging/packet_header_test.go deleted file mode 100644 index d0e7f08f..00000000 --- a/internal/logging/packet_header_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package logging - -import ( - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Packet Header", func() { - Context("determining the packet type from the header", func() { - It("recognizes Initial packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Version: protocol.VersionTLS, - })).To(Equal(PacketTypeInitial)) - }) - - It("recognizes Handshake packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Version: protocol.VersionTLS, - })).To(Equal(PacketTypeHandshake)) - }) - - It("recognizes Retry packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - Version: protocol.VersionTLS, - })).To(Equal(PacketTypeRetry)) - }) - - It("recognizes 0-RTT packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketType0RTT, - Version: protocol.VersionTLS, - })).To(Equal(PacketType0RTT)) - }) - - It("recognizes Version Negotiation packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{IsLongHeader: true})).To(Equal(PacketTypeVersionNegotiation)) - }) - - It("recognizes 1-RTT packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{})).To(Equal(PacketType1RTT)) - }) - - It("handles unrecognized packet types", func() { - Expect(PacketTypeFromHeader(&wire.Header{ - IsLongHeader: true, - Version: protocol.VersionTLS, - })).To(Equal(PacketTypeNotDetermined)) - }) - }) -}) diff --git a/internal/logging/types.go b/internal/logging/types.go deleted file mode 100644 index ad800692..00000000 --- a/internal/logging/types.go +++ /dev/null @@ -1,94 +0,0 @@ -package logging - -// PacketType is the packet type of a QUIC packet -type PacketType uint8 - -const ( - // PacketTypeInitial is the packet type of an Initial packet - PacketTypeInitial PacketType = iota - // PacketTypeHandshake is the packet type of a Handshake packet - PacketTypeHandshake - // PacketTypeRetry is the packet type of a Retry packet - PacketTypeRetry - // PacketType0RTT is the packet type of a 0-RTT packet - PacketType0RTT - // PacketTypeVersionNegotiation is the packet type of a Version Negotiation packet - PacketTypeVersionNegotiation - // PacketType1RTT is a 1-RTT packet - PacketType1RTT - // PacketTypeStatelessReset is a stateless reset - PacketTypeStatelessReset - // PacketTypeNotDetermined is the packet type when it could not be determined - PacketTypeNotDetermined -) - -type PacketLossReason uint8 - -const ( - // PacketLossReorderingThreshold: when a packet is deemed lost due to reordering threshold - PacketLossReorderingThreshold PacketLossReason = iota - // PacketLossTimeThreshold: when a packet is deemed lost due to time threshold - PacketLossTimeThreshold -) - -type PacketDropReason uint8 - -const ( - // PacketDropKeyUnavailable is used when a packet is dropped because keys are unavailable - PacketDropKeyUnavailable PacketDropReason = iota - // PacketDropUnknownConnectionID is used when a packet is dropped because the connection ID is unknown - PacketDropUnknownConnectionID - // PacketDropHeaderParseError is used when a packet is dropped because header parsing failed - PacketDropHeaderParseError - // PacketDropPayloadDecryptError is used when a packet is dropped because decrypting the payload failed - PacketDropPayloadDecryptError - // PacketDropProtocolViolation is used when a packet is dropped due to a protocol violation - PacketDropProtocolViolation - // PacketDropDOSPrevention is used when a packet is dropped to mitigate a DoS attack - PacketDropDOSPrevention - // PacketDropUnsupportedVersion is used when a packet is dropped because the version is not supported - PacketDropUnsupportedVersion - // PacketDropUnexpectedPacket is used when an unexpected packet is received - PacketDropUnexpectedPacket - // PacketDropUnexpectedSourceConnectionID is used when a packet with an unexpected source connection ID is received - PacketDropUnexpectedSourceConnectionID - // PacketDropUnexpectedVersion is used when a packet with an unexpected version is received - PacketDropUnexpectedVersion - // PacketDropDuplicate is used when a duplicate packet is received - PacketDropDuplicate -) - -// TimerType is the type of the loss detection timer -type TimerType uint8 - -const ( - // TimerTypeACK is the timer type for the early retransmit timer - TimerTypeACK TimerType = iota - // TimerTypePTO is the timer type for the PTO retransmit timer - TimerTypePTO -) - -// TimeoutReason is the reason why a connection is closed -type TimeoutReason uint8 - -const ( - // TimeoutReasonHandshake is used when the connection is closed due to a handshake timeout - // This reason is not defined in the qlog draft, but very useful for debugging. - TimeoutReasonHandshake TimeoutReason = iota - // TimeoutReasonIdle is used when the connection is closed due to an idle timeout - // This reason is not defined in the qlog draft, but very useful for debugging. - TimeoutReasonIdle -) - -type CongestionState uint8 - -const ( - // CongestionStateSlowStart is the slow start phase of Reno / Cubic - CongestionStateSlowStart CongestionState = iota - // CongestionStateCongestionAvoidance is the slow start phase of Reno / Cubic - CongestionStateCongestionAvoidance - // CongestionStateRecovery is the recovery phase of Reno / Cubic - CongestionStateRecovery - // CongestionStateApplicationLimited means that the congestion controller is application limited - CongestionStateApplicationLimited -) diff --git a/internal/mocks/congestion.go b/internal/mocks/congestion.go deleted file mode 100644 index 23114372..00000000 --- a/internal/mocks/congestion.go +++ /dev/null @@ -1,192 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/congestion (interfaces: SendAlgorithmWithDebugInfos) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" -) - -// MockSendAlgorithmWithDebugInfos is a mock of SendAlgorithmWithDebugInfos interface. -type MockSendAlgorithmWithDebugInfos struct { - ctrl *gomock.Controller - recorder *MockSendAlgorithmWithDebugInfosMockRecorder -} - -// MockSendAlgorithmWithDebugInfosMockRecorder is the mock recorder for MockSendAlgorithmWithDebugInfos. -type MockSendAlgorithmWithDebugInfosMockRecorder struct { - mock *MockSendAlgorithmWithDebugInfos -} - -// NewMockSendAlgorithmWithDebugInfos creates a new mock instance. -func NewMockSendAlgorithmWithDebugInfos(ctrl *gomock.Controller) *MockSendAlgorithmWithDebugInfos { - mock := &MockSendAlgorithmWithDebugInfos{ctrl: ctrl} - mock.recorder = &MockSendAlgorithmWithDebugInfosMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSendAlgorithmWithDebugInfos) EXPECT() *MockSendAlgorithmWithDebugInfosMockRecorder { - return m.recorder -} - -// CanSend mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) CanSend(arg0 protocol.ByteCount) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CanSend", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// CanSend indicates an expected call of CanSend. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) CanSend(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).CanSend), arg0) -} - -// GetCongestionWindow mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) GetCongestionWindow() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetCongestionWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetCongestionWindow indicates an expected call of GetCongestionWindow. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) GetCongestionWindow() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCongestionWindow", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).GetCongestionWindow)) -} - -// HasPacingBudget mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) HasPacingBudget() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasPacingBudget") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasPacingBudget indicates an expected call of HasPacingBudget. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) HasPacingBudget() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).HasPacingBudget)) -} - -// InRecovery mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) InRecovery() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InRecovery") - ret0, _ := ret[0].(bool) - return ret0 -} - -// InRecovery indicates an expected call of InRecovery. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InRecovery() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InRecovery", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InRecovery)) -} - -// InSlowStart mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) InSlowStart() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InSlowStart") - ret0, _ := ret[0].(bool) - return ret0 -} - -// InSlowStart indicates an expected call of InSlowStart. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InSlowStart() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InSlowStart)) -} - -// MaybeExitSlowStart mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) MaybeExitSlowStart() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "MaybeExitSlowStart") -} - -// MaybeExitSlowStart indicates an expected call of MaybeExitSlowStart. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) MaybeExitSlowStart() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeExitSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).MaybeExitSlowStart)) -} - -// OnPacketAcked mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount, arg3 time.Time) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2, arg3) -} - -// OnPacketAcked indicates an expected call of OnPacketAcked. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), arg0, arg1, arg2, arg3) -} - -// OnPacketLost mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) OnPacketLost(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPacketLost", arg0, arg1, arg2) -} - -// OnPacketLost indicates an expected call of OnPacketLost. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketLost(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketLost", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketLost), arg0, arg1, arg2) -} - -// OnPacketSent mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPacketSent", arg0, arg1, arg2, arg3, arg4) -} - -// OnPacketSent indicates an expected call of OnPacketSent. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketSent(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketSent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketSent), arg0, arg1, arg2, arg3, arg4) -} - -// OnRetransmissionTimeout mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) OnRetransmissionTimeout(arg0 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnRetransmissionTimeout", arg0) -} - -// OnRetransmissionTimeout indicates an expected call of OnRetransmissionTimeout. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnRetransmissionTimeout), arg0) -} - -// SetMaxDatagramSize mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) SetMaxDatagramSize(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetMaxDatagramSize", arg0) -} - -// SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).SetMaxDatagramSize), arg0) -} - -// TimeUntilSend mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) TimeUntilSend(arg0 protocol.ByteCount) time.Time { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TimeUntilSend", arg0) - ret0, _ := ret[0].(time.Time) - return ret0 -} - -// TimeUntilSend indicates an expected call of TimeUntilSend. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) TimeUntilSend(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).TimeUntilSend), arg0) -} diff --git a/internal/mocks/connection_flow_controller.go b/internal/mocks/connection_flow_controller.go deleted file mode 100644 index d16b3aae..00000000 --- a/internal/mocks/connection_flow_controller.go +++ /dev/null @@ -1,128 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/flowcontrol (interfaces: ConnectionFlowController) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" -) - -// MockConnectionFlowController is a mock of ConnectionFlowController interface. -type MockConnectionFlowController struct { - ctrl *gomock.Controller - recorder *MockConnectionFlowControllerMockRecorder -} - -// MockConnectionFlowControllerMockRecorder is the mock recorder for MockConnectionFlowController. -type MockConnectionFlowControllerMockRecorder struct { - mock *MockConnectionFlowController -} - -// NewMockConnectionFlowController creates a new mock instance. -func NewMockConnectionFlowController(ctrl *gomock.Controller) *MockConnectionFlowController { - mock := &MockConnectionFlowController{ctrl: ctrl} - mock.recorder = &MockConnectionFlowControllerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConnectionFlowController) EXPECT() *MockConnectionFlowControllerMockRecorder { - return m.recorder -} - -// AddBytesRead mocks base method. -func (m *MockConnectionFlowController) AddBytesRead(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddBytesRead", arg0) -} - -// AddBytesRead indicates an expected call of AddBytesRead. -func (mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesRead), arg0) -} - -// AddBytesSent mocks base method. -func (m *MockConnectionFlowController) AddBytesSent(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddBytesSent", arg0) -} - -// AddBytesSent indicates an expected call of AddBytesSent. -func (mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesSent), arg0) -} - -// GetWindowUpdate mocks base method. -func (m *MockConnectionFlowController) GetWindowUpdate() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWindowUpdate") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetWindowUpdate indicates an expected call of GetWindowUpdate. -func (mr *MockConnectionFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).GetWindowUpdate)) -} - -// IsNewlyBlocked mocks base method. -func (m *MockConnectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsNewlyBlocked") - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(protocol.ByteCount) - return ret0, ret1 -} - -// IsNewlyBlocked indicates an expected call of IsNewlyBlocked. -func (mr *MockConnectionFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsNewlyBlocked)) -} - -// Reset mocks base method. -func (m *MockConnectionFlowController) Reset() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Reset") - ret0, _ := ret[0].(error) - return ret0 -} - -// Reset indicates an expected call of Reset. -func (mr *MockConnectionFlowControllerMockRecorder) Reset() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reset", reflect.TypeOf((*MockConnectionFlowController)(nil).Reset)) -} - -// SendWindowSize mocks base method. -func (m *MockConnectionFlowController) SendWindowSize() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendWindowSize") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// SendWindowSize indicates an expected call of SendWindowSize. -func (mr *MockConnectionFlowControllerMockRecorder) SendWindowSize() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockConnectionFlowController)(nil).SendWindowSize)) -} - -// UpdateSendWindow mocks base method. -func (m *MockConnectionFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateSendWindow", arg0) -} - -// UpdateSendWindow indicates an expected call of UpdateSendWindow. -func (mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockConnectionFlowController)(nil).UpdateSendWindow), arg0) -} diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go deleted file mode 100644 index b28499e2..00000000 --- a/internal/mocks/crypto_setup.go +++ /dev/null @@ -1,264 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/handshake (interfaces: CryptoSetup) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - handshake "github.com/imroc/req/v3/internal/handshake" - protocol "github.com/imroc/req/v3/internal/protocol" - qtls "github.com/imroc/req/v3/internal/qtls" -) - -// MockCryptoSetup is a mock of CryptoSetup interface. -type MockCryptoSetup struct { - ctrl *gomock.Controller - recorder *MockCryptoSetupMockRecorder -} - -// MockCryptoSetupMockRecorder is the mock recorder for MockCryptoSetup. -type MockCryptoSetupMockRecorder struct { - mock *MockCryptoSetup -} - -// NewMockCryptoSetup creates a new mock instance. -func NewMockCryptoSetup(ctrl *gomock.Controller) *MockCryptoSetup { - mock := &MockCryptoSetup{ctrl: ctrl} - mock.recorder = &MockCryptoSetupMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder { - return m.recorder -} - -// ChangeConnectionID mocks base method. -func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ChangeConnectionID", arg0) -} - -// ChangeConnectionID indicates an expected call of ChangeConnectionID. -func (mr *MockCryptoSetupMockRecorder) ChangeConnectionID(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeConnectionID", reflect.TypeOf((*MockCryptoSetup)(nil).ChangeConnectionID), arg0) -} - -// Close mocks base method. -func (m *MockCryptoSetup) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close)) -} - -// ConnectionState mocks base method. -func (m *MockCryptoSetup) ConnectionState() qtls.ConnectionState { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(qtls.ConnectionState) - return ret0 -} - -// ConnectionState indicates an expected call of ConnectionState. -func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) -} - -// Get0RTTOpener mocks base method. -func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get0RTTOpener") - ret0, _ := ret[0].(handshake.LongHeaderOpener) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get0RTTOpener indicates an expected call of Get0RTTOpener. -func (mr *MockCryptoSetupMockRecorder) Get0RTTOpener() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTOpener)) -} - -// Get0RTTSealer mocks base method. -func (m *MockCryptoSetup) Get0RTTSealer() (handshake.LongHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get0RTTSealer") - ret0, _ := ret[0].(handshake.LongHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get0RTTSealer indicates an expected call of Get0RTTSealer. -func (mr *MockCryptoSetupMockRecorder) Get0RTTSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTSealer)) -} - -// Get1RTTOpener mocks base method. -func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get1RTTOpener") - ret0, _ := ret[0].(handshake.ShortHeaderOpener) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get1RTTOpener indicates an expected call of Get1RTTOpener. -func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTOpener)) -} - -// Get1RTTSealer mocks base method. -func (m *MockCryptoSetup) Get1RTTSealer() (handshake.ShortHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get1RTTSealer") - ret0, _ := ret[0].(handshake.ShortHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get1RTTSealer indicates an expected call of Get1RTTSealer. -func (mr *MockCryptoSetupMockRecorder) Get1RTTSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTSealer)) -} - -// GetHandshakeOpener mocks base method. -func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.LongHeaderOpener, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHandshakeOpener") - ret0, _ := ret[0].(handshake.LongHeaderOpener) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetHandshakeOpener indicates an expected call of GetHandshakeOpener. -func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeOpener)) -} - -// GetHandshakeSealer mocks base method. -func (m *MockCryptoSetup) GetHandshakeSealer() (handshake.LongHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHandshakeSealer") - ret0, _ := ret[0].(handshake.LongHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetHandshakeSealer indicates an expected call of GetHandshakeSealer. -func (mr *MockCryptoSetupMockRecorder) GetHandshakeSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeSealer)) -} - -// GetInitialOpener mocks base method. -func (m *MockCryptoSetup) GetInitialOpener() (handshake.LongHeaderOpener, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetInitialOpener") - ret0, _ := ret[0].(handshake.LongHeaderOpener) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetInitialOpener indicates an expected call of GetInitialOpener. -func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialOpener)) -} - -// GetInitialSealer mocks base method. -func (m *MockCryptoSetup) GetInitialSealer() (handshake.LongHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetInitialSealer") - ret0, _ := ret[0].(handshake.LongHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetInitialSealer indicates an expected call of GetInitialSealer. -func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer)) -} - -// GetSessionTicket mocks base method. -func (m *MockCryptoSetup) GetSessionTicket() ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSessionTicket") - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSessionTicket indicates an expected call of GetSessionTicket. -func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket)) -} - -// HandleMessage mocks base method. -func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) - ret0, _ := ret[0].(bool) - return ret0 -} - -// HandleMessage indicates an expected call of HandleMessage. -func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) -} - -// RunHandshake mocks base method. -func (m *MockCryptoSetup) RunHandshake() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RunHandshake") -} - -// RunHandshake indicates an expected call of RunHandshake. -func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake)) -} - -// SetHandshakeConfirmed mocks base method. -func (m *MockCryptoSetup) SetHandshakeConfirmed() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetHandshakeConfirmed") -} - -// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. -func (mr *MockCryptoSetupMockRecorder) SetHandshakeConfirmed() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockCryptoSetup)(nil).SetHandshakeConfirmed)) -} - -// SetLargest1RTTAcked mocks base method. -func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetLargest1RTTAcked", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetLargest1RTTAcked indicates an expected call of SetLargest1RTTAcked. -func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0) -} diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go deleted file mode 100644 index a305fe0a..00000000 --- a/internal/mocks/logging/connection_tracer.go +++ /dev/null @@ -1,353 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/logging (interfaces: ConnectionTracer) - -// Package mocklogging is a generated GoMock package. -package mocklogging - -import ( - "github.com/lucas-clemente/quic-go" - net "net" - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" - utils "github.com/imroc/req/v3/internal/utils" - wire "github.com/imroc/req/v3/internal/wire" - logging "github.com/imroc/req/v3/internal/logging" -) - -// MockConnectionTracer is a mock of ConnectionTracer interface. -type MockConnectionTracer struct { - ctrl *gomock.Controller - recorder *MockConnectionTracerMockRecorder -} - -// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. -type MockConnectionTracerMockRecorder struct { - mock *MockConnectionTracer -} - -// NewMockConnectionTracer creates a new mock instance. -func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { - mock := &MockConnectionTracer{ctrl: ctrl} - mock.recorder = &MockConnectionTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { - return m.recorder -} - -// AcknowledgedPacket mocks base method. -func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) -} - -// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. -func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) -} - -// BufferedPacket mocks base method. -func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "BufferedPacket", arg0) -} - -// BufferedPacket indicates an expected call of BufferedPacket. -func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0) -} - -// Close mocks base method. -func (m *MockConnectionTracer) Close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Close") -} - -// Close indicates an expected call of Close. -func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) -} - -// ClosedConnection mocks base method. -func (m *MockConnectionTracer) ClosedConnection(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ClosedConnection", arg0) -} - -// ClosedConnection indicates an expected call of ClosedConnection. -func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) -} - -// Debug mocks base method. -func (m *MockConnectionTracer) Debug(arg0, arg1 string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Debug", arg0, arg1) -} - -// Debug indicates an expected call of Debug. -func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) -} - -// DroppedEncryptionLevel mocks base method. -func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) -} - -// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. -func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) -} - -// DroppedKey mocks base method. -func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedKey", arg0) -} - -// DroppedKey indicates an expected call of DroppedKey. -func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) -} - -// DroppedPacket mocks base method. -func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount, arg2 logging.PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) -} - -// LossTimerCanceled mocks base method. -func (m *MockConnectionTracer) LossTimerCanceled() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerCanceled") -} - -// LossTimerCanceled indicates an expected call of LossTimerCanceled. -func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) -} - -// LossTimerExpired mocks base method. -func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) -} - -// LossTimerExpired indicates an expected call of LossTimerExpired. -func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) -} - -// LostPacket mocks base method. -func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 logging.PacketLossReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) -} - -// LostPacket indicates an expected call of LostPacket. -func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) -} - -// NegotiatedVersion mocks base method. -func (m *MockConnectionTracer) NegotiatedVersion(arg0 quic.VersionNumber, arg1, arg2 []quic.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) -} - -// NegotiatedVersion indicates an expected call of NegotiatedVersion. -func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) -} - -// ReceivedPacket mocks base method. -func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2) -} - -// ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedPacket), arg0, arg1, arg2) -} - -// ReceivedRetry mocks base method. -func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedRetry", arg0) -} - -// ReceivedRetry indicates an expected call of ReceivedRetry. -func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) -} - -// ReceivedTransportParameters mocks base method. -func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedTransportParameters", arg0) -} - -// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. -func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) -} - -// ReceivedVersionNegotiationPacket mocks base method. -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header, arg1 []quic.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1) -} - -// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1) -} - -// RestoredTransportParameters mocks base method. -func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RestoredTransportParameters", arg0) -} - -// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. -func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) -} - -// SentPacket mocks base method. -func (m *MockConnectionTracer) SentPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockConnectionTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// SentTransportParameters mocks base method. -func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentTransportParameters", arg0) -} - -// SentTransportParameters indicates an expected call of SentTransportParameters. -func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) -} - -// SetLossTimer mocks base method. -func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) -} - -// SetLossTimer indicates an expected call of SetLossTimer. -func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) -} - -// StartedConnection mocks base method. -func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) -} - -// StartedConnection indicates an expected call of StartedConnection. -func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) -} - -// UpdatedCongestionState mocks base method. -func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionState) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedCongestionState", arg0) -} - -// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. -func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) -} - -// UpdatedKey mocks base method. -func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKey", arg0, arg1) -} - -// UpdatedKey indicates an expected call of UpdatedKey. -func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) -} - -// UpdatedKeyFromTLS mocks base method. -func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) -} - -// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. -func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) -} - -// UpdatedMetrics mocks base method. -func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) -} - -// UpdatedMetrics indicates an expected call of UpdatedMetrics. -func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) -} - -// UpdatedPTOCount mocks base method. -func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedPTOCount", arg0) -} - -// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. -func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) -} diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go deleted file mode 100644 index bf86f4be..00000000 --- a/internal/mocks/logging/tracer.go +++ /dev/null @@ -1,77 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/logging (interfaces: Tracer) - -// Package mocklogging is a generated GoMock package. -package mocklogging - -import ( - context "context" - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" - wire "github.com/imroc/req/v3/internal/wire" - logging "github.com/imroc/req/v3/internal/logging" -) - -// MockTracer is a mock of Tracer interface. -type MockTracer struct { - ctrl *gomock.Controller - recorder *MockTracerMockRecorder -} - -// MockTracerMockRecorder is the mock recorder for MockTracer. -type MockTracerMockRecorder struct { - mock *MockTracer -} - -// NewMockTracer creates a new mock instance. -func NewMockTracer(ctrl *gomock.Controller) *MockTracer { - mock := &MockTracer{ctrl: ctrl} - mock.recorder = &MockTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTracer) EXPECT() *MockTracerMockRecorder { - return m.recorder -} - -// DroppedPacket mocks base method. -func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) -} - -// SentPacket mocks base method. -func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// TracerForConnection mocks base method. -func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) - ret0, _ := ret[0].(logging.ConnectionTracer) - return ret0 -} - -// TracerForConnection indicates an expected call of TracerForConnection. -func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) -} diff --git a/internal/mocks/long_header_opener.go b/internal/mocks/long_header_opener.go deleted file mode 100644 index 022bb8b3..00000000 --- a/internal/mocks/long_header_opener.go +++ /dev/null @@ -1,76 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/handshake (interfaces: LongHeaderOpener) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" -) - -// MockLongHeaderOpener is a mock of LongHeaderOpener interface. -type MockLongHeaderOpener struct { - ctrl *gomock.Controller - recorder *MockLongHeaderOpenerMockRecorder -} - -// MockLongHeaderOpenerMockRecorder is the mock recorder for MockLongHeaderOpener. -type MockLongHeaderOpenerMockRecorder struct { - mock *MockLongHeaderOpener -} - -// NewMockLongHeaderOpener creates a new mock instance. -func NewMockLongHeaderOpener(ctrl *gomock.Controller) *MockLongHeaderOpener { - mock := &MockLongHeaderOpener{ctrl: ctrl} - mock.recorder = &MockLongHeaderOpenerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockLongHeaderOpener) EXPECT() *MockLongHeaderOpenerMockRecorder { - return m.recorder -} - -// DecodePacketNumber mocks base method. -func (m *MockLongHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1) - ret0, _ := ret[0].(protocol.PacketNumber) - return ret0 -} - -// DecodePacketNumber indicates an expected call of DecodePacketNumber. -func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) -} - -// DecryptHeader mocks base method. -func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) -} - -// DecryptHeader indicates an expected call of DecryptHeader. -func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) -} - -// Open mocks base method. -func (m *MockLongHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Open indicates an expected call of Open. -func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3) -} diff --git a/internal/mocks/quic/early_listener.go b/internal/mocks/quic/early_listener.go deleted file mode 100644 index 279096b8..00000000 --- a/internal/mocks/quic/early_listener.go +++ /dev/null @@ -1,80 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go (interfaces: EarlyListener) - -// Package mockquic is a generated GoMock package. -package mockquic - -import ( - context "context" - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - quic "github.com/lucas-clemente/quic-go" -) - -// MockEarlyListener is a mock of EarlyListener interface. -type MockEarlyListener struct { - ctrl *gomock.Controller - recorder *MockEarlyListenerMockRecorder -} - -// MockEarlyListenerMockRecorder is the mock recorder for MockEarlyListener. -type MockEarlyListenerMockRecorder struct { - mock *MockEarlyListener -} - -// NewMockEarlyListener creates a new mock instance. -func NewMockEarlyListener(ctrl *gomock.Controller) *MockEarlyListener { - mock := &MockEarlyListener{ctrl: ctrl} - mock.recorder = &MockEarlyListenerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockEarlyListener) EXPECT() *MockEarlyListenerMockRecorder { - return m.recorder -} - -// Accept mocks base method. -func (m *MockEarlyListener) Accept(arg0 context.Context) (quic.EarlyConnection, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Accept", arg0) - ret0, _ := ret[0].(quic.EarlyConnection) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Accept indicates an expected call of Accept. -func (mr *MockEarlyListenerMockRecorder) Accept(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockEarlyListener)(nil).Accept), arg0) -} - -// Addr mocks base method. -func (m *MockEarlyListener) Addr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Addr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// Addr indicates an expected call of Addr. -func (mr *MockEarlyListenerMockRecorder) Addr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockEarlyListener)(nil).Addr)) -} - -// Close mocks base method. -func (m *MockEarlyListener) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockEarlyListenerMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEarlyListener)(nil).Close)) -} diff --git a/internal/mocks/short_header_opener.go b/internal/mocks/short_header_opener.go deleted file mode 100644 index 146579c1..00000000 --- a/internal/mocks/short_header_opener.go +++ /dev/null @@ -1,77 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/handshake (interfaces: ShortHeaderOpener) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" -) - -// MockShortHeaderOpener is a mock of ShortHeaderOpener interface. -type MockShortHeaderOpener struct { - ctrl *gomock.Controller - recorder *MockShortHeaderOpenerMockRecorder -} - -// MockShortHeaderOpenerMockRecorder is the mock recorder for MockShortHeaderOpener. -type MockShortHeaderOpenerMockRecorder struct { - mock *MockShortHeaderOpener -} - -// NewMockShortHeaderOpener creates a new mock instance. -func NewMockShortHeaderOpener(ctrl *gomock.Controller) *MockShortHeaderOpener { - mock := &MockShortHeaderOpener{ctrl: ctrl} - mock.recorder = &MockShortHeaderOpenerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockShortHeaderOpener) EXPECT() *MockShortHeaderOpenerMockRecorder { - return m.recorder -} - -// DecodePacketNumber mocks base method. -func (m *MockShortHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1) - ret0, _ := ret[0].(protocol.PacketNumber) - return ret0 -} - -// DecodePacketNumber indicates an expected call of DecodePacketNumber. -func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) -} - -// DecryptHeader mocks base method. -func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) -} - -// DecryptHeader indicates an expected call of DecryptHeader. -func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) -} - -// Open mocks base method. -func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 time.Time, arg3 protocol.PacketNumber, arg4 protocol.KeyPhaseBit, arg5 []byte) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4, arg5) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Open indicates an expected call of Open. -func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4, arg5) -} diff --git a/internal/mocks/short_header_sealer.go b/internal/mocks/short_header_sealer.go deleted file mode 100644 index 2ea53fc4..00000000 --- a/internal/mocks/short_header_sealer.go +++ /dev/null @@ -1,89 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/handshake (interfaces: ShortHeaderSealer) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" -) - -// MockShortHeaderSealer is a mock of ShortHeaderSealer interface. -type MockShortHeaderSealer struct { - ctrl *gomock.Controller - recorder *MockShortHeaderSealerMockRecorder -} - -// MockShortHeaderSealerMockRecorder is the mock recorder for MockShortHeaderSealer. -type MockShortHeaderSealerMockRecorder struct { - mock *MockShortHeaderSealer -} - -// NewMockShortHeaderSealer creates a new mock instance. -func NewMockShortHeaderSealer(ctrl *gomock.Controller) *MockShortHeaderSealer { - mock := &MockShortHeaderSealer{ctrl: ctrl} - mock.recorder = &MockShortHeaderSealerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockShortHeaderSealer) EXPECT() *MockShortHeaderSealerMockRecorder { - return m.recorder -} - -// EncryptHeader mocks base method. -func (m *MockShortHeaderSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2) -} - -// EncryptHeader indicates an expected call of EncryptHeader. -func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockShortHeaderSealer)(nil).EncryptHeader), arg0, arg1, arg2) -} - -// KeyPhase mocks base method. -func (m *MockShortHeaderSealer) KeyPhase() protocol.KeyPhaseBit { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "KeyPhase") - ret0, _ := ret[0].(protocol.KeyPhaseBit) - return ret0 -} - -// KeyPhase indicates an expected call of KeyPhase. -func (mr *MockShortHeaderSealerMockRecorder) KeyPhase() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyPhase", reflect.TypeOf((*MockShortHeaderSealer)(nil).KeyPhase)) -} - -// Overhead mocks base method. -func (m *MockShortHeaderSealer) Overhead() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Overhead") - ret0, _ := ret[0].(int) - return ret0 -} - -// Overhead indicates an expected call of Overhead. -func (mr *MockShortHeaderSealerMockRecorder) Overhead() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockShortHeaderSealer)(nil).Overhead)) -} - -// Seal mocks base method. -func (m *MockShortHeaderSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - return ret0 -} - -// Seal indicates an expected call of Seal. -func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockShortHeaderSealer)(nil).Seal), arg0, arg1, arg2, arg3) -} diff --git a/internal/mocks/stream_flow_controller.go b/internal/mocks/stream_flow_controller.go deleted file mode 100644 index ba779a6a..00000000 --- a/internal/mocks/stream_flow_controller.go +++ /dev/null @@ -1,140 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/flowcontrol (interfaces: StreamFlowController) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/protocol" -) - -// MockStreamFlowController is a mock of StreamFlowController interface. -type MockStreamFlowController struct { - ctrl *gomock.Controller - recorder *MockStreamFlowControllerMockRecorder -} - -// MockStreamFlowControllerMockRecorder is the mock recorder for MockStreamFlowController. -type MockStreamFlowControllerMockRecorder struct { - mock *MockStreamFlowController -} - -// NewMockStreamFlowController creates a new mock instance. -func NewMockStreamFlowController(ctrl *gomock.Controller) *MockStreamFlowController { - mock := &MockStreamFlowController{ctrl: ctrl} - mock.recorder = &MockStreamFlowControllerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStreamFlowController) EXPECT() *MockStreamFlowControllerMockRecorder { - return m.recorder -} - -// Abandon mocks base method. -func (m *MockStreamFlowController) Abandon() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Abandon") -} - -// Abandon indicates an expected call of Abandon. -func (mr *MockStreamFlowControllerMockRecorder) Abandon() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abandon", reflect.TypeOf((*MockStreamFlowController)(nil).Abandon)) -} - -// AddBytesRead mocks base method. -func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddBytesRead", arg0) -} - -// AddBytesRead indicates an expected call of AddBytesRead. -func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesRead), arg0) -} - -// AddBytesSent mocks base method. -func (m *MockStreamFlowController) AddBytesSent(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddBytesSent", arg0) -} - -// AddBytesSent indicates an expected call of AddBytesSent. -func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesSent), arg0) -} - -// GetWindowUpdate mocks base method. -func (m *MockStreamFlowController) GetWindowUpdate() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWindowUpdate") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetWindowUpdate indicates an expected call of GetWindowUpdate. -func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate)) -} - -// IsNewlyBlocked mocks base method. -func (m *MockStreamFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsNewlyBlocked") - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(protocol.ByteCount) - return ret0, ret1 -} - -// IsNewlyBlocked indicates an expected call of IsNewlyBlocked. -func (mr *MockStreamFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsNewlyBlocked)) -} - -// SendWindowSize mocks base method. -func (m *MockStreamFlowController) SendWindowSize() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendWindowSize") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// SendWindowSize indicates an expected call of SendWindowSize. -func (mr *MockStreamFlowControllerMockRecorder) SendWindowSize() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockStreamFlowController)(nil).SendWindowSize)) -} - -// UpdateHighestReceived mocks base method. -func (m *MockStreamFlowController) UpdateHighestReceived(arg0 protocol.ByteCount, arg1 bool) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateHighestReceived", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateHighestReceived indicates an expected call of UpdateHighestReceived. -func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateHighestReceived), arg0, arg1) -} - -// UpdateSendWindow mocks base method. -func (m *MockStreamFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateSendWindow", arg0) -} - -// UpdateSendWindow indicates an expected call of UpdateSendWindow. -func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateSendWindow), arg0) -} diff --git a/internal/mocks/tls/client_session_cache.go b/internal/mocks/tls/client_session_cache.go deleted file mode 100644 index e3ae2c8e..00000000 --- a/internal/mocks/tls/client_session_cache.go +++ /dev/null @@ -1,62 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: crypto/tls (interfaces: ClientSessionCache) - -// Package mocktls is a generated GoMock package. -package mocktls - -import ( - tls "crypto/tls" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockClientSessionCache is a mock of ClientSessionCache interface. -type MockClientSessionCache struct { - ctrl *gomock.Controller - recorder *MockClientSessionCacheMockRecorder -} - -// MockClientSessionCacheMockRecorder is the mock recorder for MockClientSessionCache. -type MockClientSessionCacheMockRecorder struct { - mock *MockClientSessionCache -} - -// NewMockClientSessionCache creates a new mock instance. -func NewMockClientSessionCache(ctrl *gomock.Controller) *MockClientSessionCache { - mock := &MockClientSessionCache{ctrl: ctrl} - mock.recorder = &MockClientSessionCacheMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockClientSessionCache) EXPECT() *MockClientSessionCacheMockRecorder { - return m.recorder -} - -// Get mocks base method. -func (m *MockClientSessionCache) Get(arg0 string) (*tls.ClientSessionState, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0) - ret0, _ := ret[0].(*tls.ClientSessionState) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// Get indicates an expected call of Get. -func (mr *MockClientSessionCacheMockRecorder) Get(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockClientSessionCache)(nil).Get), arg0) -} - -// Put mocks base method. -func (m *MockClientSessionCache) Put(arg0 string, arg1 *tls.ClientSessionState) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Put", arg0, arg1) -} - -// Put indicates an expected call of Put. -func (mr *MockClientSessionCacheMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockClientSessionCache)(nil).Put), arg0, arg1) -} diff --git a/internal/qerr/error_codes.go b/internal/qerr/error_codes.go deleted file mode 100644 index 58ea6b43..00000000 --- a/internal/qerr/error_codes.go +++ /dev/null @@ -1,88 +0,0 @@ -package qerr - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/qtls" -) - -// TransportErrorCode is a QUIC transport error. -type TransportErrorCode uint64 - -// The error codes defined by QUIC -const ( - NoError TransportErrorCode = 0x0 - InternalError TransportErrorCode = 0x1 - ConnectionRefused TransportErrorCode = 0x2 - FlowControlError TransportErrorCode = 0x3 - StreamLimitError TransportErrorCode = 0x4 - StreamStateError TransportErrorCode = 0x5 - FinalSizeError TransportErrorCode = 0x6 - FrameEncodingError TransportErrorCode = 0x7 - TransportParameterError TransportErrorCode = 0x8 - ConnectionIDLimitError TransportErrorCode = 0x9 - ProtocolViolation TransportErrorCode = 0xa - InvalidToken TransportErrorCode = 0xb - ApplicationErrorErrorCode TransportErrorCode = 0xc - CryptoBufferExceeded TransportErrorCode = 0xd - KeyUpdateError TransportErrorCode = 0xe - AEADLimitReached TransportErrorCode = 0xf - NoViablePathError TransportErrorCode = 0x10 -) - -func (e TransportErrorCode) IsCryptoError() bool { - return e >= 0x100 && e < 0x200 -} - -// Message is a description of the error. -// It only returns a non-empty string for crypto errors. -func (e TransportErrorCode) Message() string { - if !e.IsCryptoError() { - return "" - } - return qtls.Alert(e - 0x100).Error() -} - -func (e TransportErrorCode) String() string { - switch e { - case NoError: - return "NO_ERROR" - case InternalError: - return "INTERNAL_ERROR" - case ConnectionRefused: - return "CONNECTION_REFUSED" - case FlowControlError: - return "FLOW_CONTROL_ERROR" - case StreamLimitError: - return "STREAM_LIMIT_ERROR" - case StreamStateError: - return "STREAM_STATE_ERROR" - case FinalSizeError: - return "FINAL_SIZE_ERROR" - case FrameEncodingError: - return "FRAME_ENCODING_ERROR" - case TransportParameterError: - return "TRANSPORT_PARAMETER_ERROR" - case ConnectionIDLimitError: - return "CONNECTION_ID_LIMIT_ERROR" - case ProtocolViolation: - return "PROTOCOL_VIOLATION" - case InvalidToken: - return "INVALID_TOKEN" - case ApplicationErrorErrorCode: - return "APPLICATION_ERROR" - case CryptoBufferExceeded: - return "CRYPTO_BUFFER_EXCEEDED" - case KeyUpdateError: - return "KEY_UPDATE_ERROR" - case AEADLimitReached: - return "AEAD_LIMIT_REACHED" - case NoViablePathError: - return "NO_VIABLE_PATH" - default: - if e.IsCryptoError() { - return fmt.Sprintf("CRYPTO_ERROR (%#x)", uint16(e)) - } - return fmt.Sprintf("unknown error code: %#x", uint16(e)) - } -} diff --git a/internal/qerr/errorcodes_test.go b/internal/qerr/errorcodes_test.go deleted file mode 100644 index cfc6cd85..00000000 --- a/internal/qerr/errorcodes_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package qerr - -import ( - "go/ast" - "go/parser" - "go/token" - "path" - "runtime" - "strconv" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("error codes", func() { - // If this test breaks, you should run `go generate ./...` - It("has a string representation for every error code", func() { - // We parse the error code file, extract all constants, and verify that - // each of them has a string version. Go FTW! - _, thisfile, _, ok := runtime.Caller(0) - if !ok { - panic("Failed to get current frame") - } - filename := path.Join(path.Dir(thisfile), "error_codes.go") - fileAst, err := parser.ParseFile(token.NewFileSet(), filename, nil, 0) - Expect(err).NotTo(HaveOccurred()) - constSpecs := fileAst.Decls[2].(*ast.GenDecl).Specs - Expect(len(constSpecs)).To(BeNumerically(">", 4)) // at time of writing - for _, c := range constSpecs { - valString := c.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value - val, err := strconv.ParseInt(valString, 0, 64) - Expect(err).NotTo(HaveOccurred()) - Expect(TransportErrorCode(val).String()).ToNot(Equal("unknown error code")) - } - }) - - It("has a string representation for unknown error codes", func() { - Expect(TransportErrorCode(0x1337).String()).To(Equal("unknown error code: 0x1337")) - }) - - It("says if an error is a crypto error", func() { - for i := 0; i < 0x100; i++ { - Expect(TransportErrorCode(i).IsCryptoError()).To(BeFalse()) - } - for i := 0x100; i < 0x200; i++ { - Expect(TransportErrorCode(i).IsCryptoError()).To(BeTrue()) - } - for i := 0x200; i < 0x300; i++ { - Expect(TransportErrorCode(i).IsCryptoError()).To(BeFalse()) - } - }) -}) diff --git a/internal/qerr/errors.go b/internal/qerr/errors.go deleted file mode 100644 index 327491b8..00000000 --- a/internal/qerr/errors.go +++ /dev/null @@ -1,125 +0,0 @@ -package qerr - -import ( - "fmt" - "github.com/lucas-clemente/quic-go" - "net" - - "github.com/imroc/req/v3/internal/protocol" -) - -var ( - ErrHandshakeTimeout = &HandshakeTimeoutError{} - ErrIdleTimeout = &IdleTimeoutError{} -) - -type TransportError struct { - Remote bool - FrameType uint64 - ErrorCode TransportErrorCode - ErrorMessage string -} - -var _ error = &TransportError{} - -// NewCryptoError create a new TransportError instance for a crypto error -func NewCryptoError(tlsAlert uint8, errorMessage string) *TransportError { - return &TransportError{ - ErrorCode: 0x100 + TransportErrorCode(tlsAlert), - ErrorMessage: errorMessage, - } -} - -func (e *TransportError) Error() string { - str := e.ErrorCode.String() - if e.FrameType != 0 { - str += fmt.Sprintf(" (frame type: %#x)", e.FrameType) - } - msg := e.ErrorMessage - if len(msg) == 0 { - msg = e.ErrorCode.Message() - } - if len(msg) == 0 { - return str - } - return str + ": " + msg -} - -func (e *TransportError) Is(target error) bool { - return target == net.ErrClosed -} - -// An ApplicationErrorCode is an application-defined error code. -type ApplicationErrorCode uint64 - -func (e *ApplicationError) Is(target error) bool { - return target == net.ErrClosed -} - -// A StreamErrorCode is an error code used to cancel streams. -type StreamErrorCode uint64 - -type ApplicationError struct { - Remote bool - ErrorCode ApplicationErrorCode - ErrorMessage string -} - -var _ error = &ApplicationError{} - -func (e *ApplicationError) Error() string { - if len(e.ErrorMessage) == 0 { - return fmt.Sprintf("Application error %#x", e.ErrorCode) - } - return fmt.Sprintf("Application error %#x: %s", e.ErrorCode, e.ErrorMessage) -} - -type IdleTimeoutError struct{} - -var _ error = &IdleTimeoutError{} - -func (e *IdleTimeoutError) Timeout() bool { return true } -func (e *IdleTimeoutError) Temporary() bool { return false } -func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" } -func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed } - -type HandshakeTimeoutError struct{} - -var _ error = &HandshakeTimeoutError{} - -func (e *HandshakeTimeoutError) Timeout() bool { return true } -func (e *HandshakeTimeoutError) Temporary() bool { return false } -func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" } -func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed } - -// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version. -type VersionNegotiationError struct { - Ours []quic.VersionNumber - Theirs []quic.VersionNumber -} - -func (e *VersionNegotiationError) Error() string { - return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs) -} - -func (e *VersionNegotiationError) Is(target error) bool { - return target == net.ErrClosed -} - -// A StatelessResetError occurs when we receive a stateless reset. -type StatelessResetError struct { - Token protocol.StatelessResetToken -} - -var _ net.Error = &StatelessResetError{} - -func (e *StatelessResetError) Error() string { - return fmt.Sprintf("received a stateless reset with token %x", e.Token) -} - -func (e *StatelessResetError) Is(target error) bool { - return target == net.ErrClosed -} - -func (e *StatelessResetError) Timeout() bool { return false } -func (e *StatelessResetError) Temporary() bool { return true } diff --git a/internal/qerr/errors_suite_test.go b/internal/qerr/errors_suite_test.go deleted file mode 100644 index 749cdedc..00000000 --- a/internal/qerr/errors_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package qerr - -import ( - "testing" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestErrorcodes(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Errors Suite") -} diff --git a/internal/qerr/errors_test.go b/internal/qerr/errors_test.go deleted file mode 100644 index 81aa2df1..00000000 --- a/internal/qerr/errors_test.go +++ /dev/null @@ -1,125 +0,0 @@ -package qerr - -import ( - "errors" - "github.com/lucas-clemente/quic-go" - "net" - - "github.com/imroc/req/v3/internal/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("QUIC Errors", func() { - Context("Transport Errors", func() { - It("has a string representation", func() { - Expect((&TransportError{ - ErrorCode: FlowControlError, - ErrorMessage: "foobar", - }).Error()).To(Equal("FLOW_CONTROL_ERROR: foobar")) - }) - - It("has a string representation for empty error phrases", func() { - Expect((&TransportError{ErrorCode: FlowControlError}).Error()).To(Equal("FLOW_CONTROL_ERROR")) - }) - - It("includes the frame type, for errors without a message", func() { - Expect((&TransportError{ - ErrorCode: FlowControlError, - FrameType: 0x1337, - }).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337)")) - }) - - It("includes the frame type, for errors with a message", func() { - Expect((&TransportError{ - ErrorCode: FlowControlError, - FrameType: 0x1337, - ErrorMessage: "foobar", - }).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337): foobar")) - }) - - Context("crypto errors", func() { - It("has a string representation for errors with a message", func() { - err := NewCryptoError(0x42, "foobar") - Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x142): foobar")) - }) - - It("has a string representation for errors without a message", func() { - err := NewCryptoError(0x2a, "") - Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x12a): tls: bad certificate")) - }) - }) - }) - - Context("Application Errors", func() { - It("has a string representation for errors with a message", func() { - Expect((&ApplicationError{ - ErrorCode: 0x42, - ErrorMessage: "foobar", - }).Error()).To(Equal("Application error 0x42: foobar")) - }) - - It("has a string representation for errors without a message", func() { - Expect((&ApplicationError{ - ErrorCode: 0x42, - }).Error()).To(Equal("Application error 0x42")) - }) - }) - - Context("timeout errors", func() { - It("handshake timeouts", func() { - //nolint:gosimple // we need to assign to an interface here - var err error - err = &HandshakeTimeoutError{} - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeTrue()) - Expect(err.Error()).To(Equal("timeout: handshake did not complete in time")) - }) - - It("idle timeouts", func() { - //nolint:gosimple // we need to assign to an interface here - var err error - err = &IdleTimeoutError{} - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeTrue()) - Expect(err.Error()).To(Equal("timeout: no recent network activity")) - }) - }) - - Context("Version Negotiation errors", func() { - It("has a string representation", func() { - Expect((&VersionNegotiationError{ - Ours: []quic.VersionNumber{2, 3}, - Theirs: []quic.VersionNumber{4, 5, 6}, - }).Error()).To(Equal("no compatible QUIC version found (we support [0x2 0x3], server offered [0x4 0x5 0x6])")) - }) - }) - - Context("Stateless Reset errors", func() { - token := protocol.StatelessResetToken{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} - - It("has a string representation", func() { - Expect((&StatelessResetError{Token: token}).Error()).To(Equal("received a stateless reset with token 000102030405060708090a0b0c0d0e0f")) - }) - - It("is a net.Error", func() { - //nolint:gosimple // we need to assign to an interface here - var err error - err = &StatelessResetError{} - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeFalse()) - }) - }) - - It("says that errors are net.ErrClosed errors", func() { - Expect(errors.Is(&TransportError{}, net.ErrClosed)).To(BeTrue()) - Expect(errors.Is(&ApplicationError{}, net.ErrClosed)).To(BeTrue()) - Expect(errors.Is(&IdleTimeoutError{}, net.ErrClosed)).To(BeTrue()) - Expect(errors.Is(&HandshakeTimeoutError{}, net.ErrClosed)).To(BeTrue()) - Expect(errors.Is(&StatelessResetError{}, net.ErrClosed)).To(BeTrue()) - Expect(errors.Is(&VersionNegotiationError{}, net.ErrClosed)).To(BeTrue()) - }) -}) diff --git a/internal/qtls/go116.go b/internal/qtls/go116.go index e3024624..73966c3e 100644 --- a/internal/qtls/go116.go +++ b/internal/qtls/go116.go @@ -4,97 +4,16 @@ package qtls import ( - "crypto" - "crypto/cipher" "crypto/tls" - "net" - "unsafe" - "github.com/marten-seemann/qtls-go1-16" ) type ( - // Alert is a TLS alert - Alert = qtls.Alert - // A Certificate is qtls.Certificate. - Certificate = qtls.Certificate - // CertificateRequestInfo contains inforamtion about a certificate request. - CertificateRequestInfo = qtls.CertificateRequestInfo - // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 - CipherSuiteTLS13 = qtls.CipherSuiteTLS13 - // ClientHelloInfo contains information about a ClientHello. - ClientHelloInfo = qtls.ClientHelloInfo - // ClientSessionCache is a cache used for session resumption. - ClientSessionCache = qtls.ClientSessionCache - // ClientSessionState is a state needed for session resumption. - ClientSessionState = qtls.ClientSessionState - // A Config is a qtls.Config. - Config = qtls.Config - // A Conn is a qtls.Conn. - Conn = qtls.Conn // ConnectionState contains information about the state of the connection. ConnectionState = qtls.ConnectionStateWith0RTT - // EncryptionLevel is the encryption level of a message. - EncryptionLevel = qtls.EncryptionLevel - // Extension is a TLS extension - Extension = qtls.Extension - // ExtraConfig is the qtls.ExtraConfig - ExtraConfig = qtls.ExtraConfig - // RecordLayer is a qtls RecordLayer. - RecordLayer = qtls.RecordLayer -) - -const ( - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake = qtls.EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT = qtls.Encryption0RTT - // EncryptionApplication is the application data encryption level - EncryptionApplication = qtls.EncryptionApplication ) -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return qtls.AEADAESGCMTLS13(key, fixedNonce) -} - -// Client returns a new TLS client side connection. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Client(conn, config, extraConfig) -} - -// Server returns a new TLS server side connection. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Server(conn, config, extraConfig) -} - -func GetConnectionState(conn *Conn) ConnectionState { - return conn.ConnectionStateWith0RTT() -} - // ToTLSConnectionState extracts the tls.ConnectionState func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { return cs.ConnectionState } - -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-16.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. -func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { - val := cipherSuiteTLS13ByID(id) - cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - return &qtls.CipherSuiteTLS13{ - ID: cs.ID, - KeyLen: cs.KeyLen, - AEAD: cs.AEAD, - Hash: cs.Hash, - } -} diff --git a/internal/qtls/go117.go b/internal/qtls/go117.go index bc385f19..8bbb04df 100644 --- a/internal/qtls/go117.go +++ b/internal/qtls/go117.go @@ -4,97 +4,16 @@ package qtls import ( - "crypto" - "crypto/cipher" "crypto/tls" - "net" - "unsafe" - "github.com/marten-seemann/qtls-go1-17" ) type ( - // Alert is a TLS alert - Alert = qtls.Alert - // A Certificate is qtls.Certificate. - Certificate = qtls.Certificate - // CertificateRequestInfo contains inforamtion about a certificate request. - CertificateRequestInfo = qtls.CertificateRequestInfo - // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 - CipherSuiteTLS13 = qtls.CipherSuiteTLS13 - // ClientHelloInfo contains information about a ClientHello. - ClientHelloInfo = qtls.ClientHelloInfo - // ClientSessionCache is a cache used for session resumption. - ClientSessionCache = qtls.ClientSessionCache - // ClientSessionState is a state needed for session resumption. - ClientSessionState = qtls.ClientSessionState - // A Config is a qtls.Config. - Config = qtls.Config - // A Conn is a qtls.Conn. - Conn = qtls.Conn // ConnectionState contains information about the state of the connection. ConnectionState = qtls.ConnectionStateWith0RTT - // EncryptionLevel is the encryption level of a message. - EncryptionLevel = qtls.EncryptionLevel - // Extension is a TLS extension - Extension = qtls.Extension - // ExtraConfig is the qtls.ExtraConfig - ExtraConfig = qtls.ExtraConfig - // RecordLayer is a qtls RecordLayer. - RecordLayer = qtls.RecordLayer -) - -const ( - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake = qtls.EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT = qtls.Encryption0RTT - // EncryptionApplication is the application data encryption level - EncryptionApplication = qtls.EncryptionApplication ) -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return qtls.AEADAESGCMTLS13(key, fixedNonce) -} - -// Client returns a new TLS client side connection. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Client(conn, config, extraConfig) -} - -// Server returns a new TLS server side connection. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Server(conn, config, extraConfig) -} - -func GetConnectionState(conn *Conn) ConnectionState { - return conn.ConnectionStateWith0RTT() -} - // ToTLSConnectionState extracts the tls.ConnectionState func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { return cs.ConnectionState } - -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-17.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. -func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { - val := cipherSuiteTLS13ByID(id) - cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - return &qtls.CipherSuiteTLS13{ - ID: cs.ID, - KeyLen: cs.KeyLen, - AEAD: cs.AEAD, - Hash: cs.Hash, - } -} diff --git a/internal/qtls/go118.go b/internal/qtls/go118.go index 0e0e7966..20f6a63f 100644 --- a/internal/qtls/go118.go +++ b/internal/qtls/go118.go @@ -4,97 +4,16 @@ package qtls import ( - "crypto" - "crypto/cipher" "crypto/tls" - "net" - "unsafe" - "github.com/marten-seemann/qtls-go1-18" ) type ( - // Alert is a TLS alert - Alert = qtls.Alert - // A Certificate is qtls.Certificate. - Certificate = qtls.Certificate - // CertificateRequestInfo contains inforamtion about a certificate request. - CertificateRequestInfo = qtls.CertificateRequestInfo - // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 - CipherSuiteTLS13 = qtls.CipherSuiteTLS13 - // ClientHelloInfo contains information about a ClientHello. - ClientHelloInfo = qtls.ClientHelloInfo - // ClientSessionCache is a cache used for session resumption. - ClientSessionCache = qtls.ClientSessionCache - // ClientSessionState is a state needed for session resumption. - ClientSessionState = qtls.ClientSessionState - // A Config is a qtls.Config. - Config = qtls.Config - // A Conn is a qtls.Conn. - Conn = qtls.Conn // ConnectionState contains information about the state of the connection. ConnectionState = qtls.ConnectionStateWith0RTT - // EncryptionLevel is the encryption level of a message. - EncryptionLevel = qtls.EncryptionLevel - // Extension is a TLS extension - Extension = qtls.Extension - // ExtraConfig is the qtls.ExtraConfig - ExtraConfig = qtls.ExtraConfig - // RecordLayer is a qtls RecordLayer. - RecordLayer = qtls.RecordLayer -) - -const ( - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake = qtls.EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT = qtls.Encryption0RTT - // EncryptionApplication is the application data encryption level - EncryptionApplication = qtls.EncryptionApplication ) -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return qtls.AEADAESGCMTLS13(key, fixedNonce) -} - -// Client returns a new TLS client side connection. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Client(conn, config, extraConfig) -} - -// Server returns a new TLS server side connection. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Server(conn, config, extraConfig) -} - -func GetConnectionState(conn *Conn) ConnectionState { - return conn.ConnectionStateWith0RTT() -} - // ToTLSConnectionState extracts the tls.ConnectionState func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { return cs.ConnectionState } - -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-18.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. -func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { - val := cipherSuiteTLS13ByID(id) - cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - return &qtls.CipherSuiteTLS13{ - ID: cs.ID, - KeyLen: cs.KeyLen, - AEAD: cs.AEAD, - Hash: cs.Hash, - } -} diff --git a/internal/qtls/go119.go b/internal/qtls/go119.go index 87e7132e..56e65c7d 100644 --- a/internal/qtls/go119.go +++ b/internal/qtls/go119.go @@ -3,4 +3,17 @@ package qtls -var _ int = "The version of quic-go you're using can't be built on Go 1.19 yet. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." +import ( + "crypto/tls" + "github.com/marten-seemann/qtls-go1-19" +) + +type ( + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT +) + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} diff --git a/internal/qtls/qtls_suite_test.go b/internal/qtls/qtls_suite_test.go deleted file mode 100644 index 24b143b2..00000000 --- a/internal/qtls/qtls_suite_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package qtls - -import ( - "testing" - - gomock "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestQTLS(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "qtls Suite") -} - -var mockCtrl *gomock.Controller - -var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) -}) - -var _ = AfterEach(func() { - mockCtrl.Finish() -}) diff --git a/internal/qtls/qtls_test.go b/internal/qtls/qtls_test.go deleted file mode 100644 index c64c5e9e..00000000 --- a/internal/qtls/qtls_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package qtls - -import ( - "crypto/tls" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("qtls wrapper", func() { - It("gets cipher suites", func() { - for _, id := range []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384, tls.TLS_CHACHA20_POLY1305_SHA256} { - cs := CipherSuiteTLS13ByID(id) - Expect(cs.ID).To(Equal(id)) - } - }) -}) diff --git a/internal/wire/ack_frame.go b/internal/wire/ack_frame.go deleted file mode 100644 index a71607c9..00000000 --- a/internal/wire/ack_frame.go +++ /dev/null @@ -1,252 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "github.com/lucas-clemente/quic-go" - "sort" - "time" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - "github.com/imroc/req/v3/internal/utils" -) - -var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") - -// An AckFrame is an ACK frame -type AckFrame struct { - AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last - DelayTime time.Duration - - ECT0, ECT1, ECNCE uint64 -} - -// parseAckFrame reads an ACK frame -func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ quic.VersionNumber) (*AckFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - ecn := typeByte&0x1 > 0 - - frame := &AckFrame{} - - la, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - largestAcked := protocol.PacketNumber(la) - delay, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - - delayTime := time.Duration(delay*1< largestAcked { - return nil, errors.New("invalid first ACK range") - } - smallest := largestAcked - ackBlock - - // read all the other ACK ranges - frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked}) - for i := uint64(0); i < numBlocks; i++ { - g, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - gap := protocol.PacketNumber(g) - if smallest < gap+2 { - return nil, errInvalidAckRanges - } - largest := smallest - gap - 2 - - ab, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - ackBlock := protocol.PacketNumber(ab) - - if ackBlock > largest { - return nil, errInvalidAckRanges - } - smallest = largest - ackBlock - frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) - } - - if !frame.validateAckRanges() { - return nil, errInvalidAckRanges - } - - // parse (and skip) the ECN section - if ecn { - for i := 0; i < 3; i++ { - if _, err := quicvarint.Read(r); err != nil { - return nil, err - } - } - } - - return frame, nil -} - -// Write writes an ACK frame. -func (f *AckFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 - if hasECN { - b.WriteByte(0x3) - } else { - b.WriteByte(0x2) - } - quicvarint.Write(b, uint64(f.LargestAcked())) - quicvarint.Write(b, encodeAckDelay(f.DelayTime)) - - numRanges := f.numEncodableAckRanges() - quicvarint.Write(b, uint64(numRanges-1)) - - // write the first range - _, firstRange := f.encodeAckRange(0) - quicvarint.Write(b, firstRange) - - // write all the other range - for i := 1; i < numRanges; i++ { - gap, len := f.encodeAckRange(i) - quicvarint.Write(b, gap) - quicvarint.Write(b, len) - } - - if hasECN { - quicvarint.Write(b, f.ECT0) - quicvarint.Write(b, f.ECT1) - quicvarint.Write(b, f.ECNCE) - } - return nil -} - -// Length of a written frame -func (f *AckFrame) Length(version quic.VersionNumber) protocol.ByteCount { - largestAcked := f.AckRanges[0].Largest - numRanges := f.numEncodableAckRanges() - - length := 1 + quicvarint.Len(uint64(largestAcked)) + quicvarint.Len(encodeAckDelay(f.DelayTime)) - - length += quicvarint.Len(uint64(numRanges - 1)) - lowestInFirstRange := f.AckRanges[0].Smallest - length += quicvarint.Len(uint64(largestAcked - lowestInFirstRange)) - - for i := 1; i < numRanges; i++ { - gap, len := f.encodeAckRange(i) - length += quicvarint.Len(gap) - length += quicvarint.Len(len) - } - if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 { - length += quicvarint.Len(f.ECT0) - length += quicvarint.Len(f.ECT1) - length += quicvarint.Len(f.ECNCE) - } - return length -} - -// gets the number of ACK ranges that can be encoded -// such that the resulting frame is smaller than the maximum ACK frame size -func (f *AckFrame) numEncodableAckRanges() int { - length := 1 + quicvarint.Len(uint64(f.LargestAcked())) + quicvarint.Len(encodeAckDelay(f.DelayTime)) - length += 2 // assume that the number of ranges will consume 2 bytes - for i := 1; i < len(f.AckRanges); i++ { - gap, len := f.encodeAckRange(i) - rangeLen := quicvarint.Len(gap) + quicvarint.Len(len) - if length+rangeLen > protocol.MaxAckFrameSize { - // Writing range i would exceed the MaxAckFrameSize. - // So encode one range less than that. - return i - 1 - } - length += rangeLen - } - return len(f.AckRanges) -} - -func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) { - if i == 0 { - return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest) - } - return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2), - uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest) -} - -// HasMissingRanges returns if this frame reports any missing packets -func (f *AckFrame) HasMissingRanges() bool { - return len(f.AckRanges) > 1 -} - -func (f *AckFrame) validateAckRanges() bool { - if len(f.AckRanges) == 0 { - return false - } - - // check the validity of every single ACK range - for _, ackRange := range f.AckRanges { - if ackRange.Smallest > ackRange.Largest { - return false - } - } - - // check the consistency for ACK with multiple NACK ranges - for i, ackRange := range f.AckRanges { - if i == 0 { - continue - } - lastAckRange := f.AckRanges[i-1] - if lastAckRange.Smallest <= ackRange.Smallest { - return false - } - if lastAckRange.Smallest <= ackRange.Largest+1 { - return false - } - } - - return true -} - -// LargestAcked is the largest acked packet number -func (f *AckFrame) LargestAcked() protocol.PacketNumber { - return f.AckRanges[0].Largest -} - -// LowestAcked is the lowest acked packet number -func (f *AckFrame) LowestAcked() protocol.PacketNumber { - return f.AckRanges[len(f.AckRanges)-1].Smallest -} - -// AcksPacket determines if this ACK frame acks a certain packet number -func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { - if p < f.LowestAcked() || p > f.LargestAcked() { - return false - } - - i := sort.Search(len(f.AckRanges), func(i int) bool { - return p >= f.AckRanges[i].Smallest - }) - // i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked - return p <= f.AckRanges[i].Largest -} - -func encodeAckDelay(delay time.Duration) uint64 { - return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent))) -} diff --git a/internal/wire/ack_frame_test.go b/internal/wire/ack_frame_test.go deleted file mode 100644 index de1671b0..00000000 --- a/internal/wire/ack_frame_test.go +++ /dev/null @@ -1,454 +0,0 @@ -package wire - -import ( - "bytes" - "io" - "math" - "time" - - "github.com/imroc/req/v3/internal/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("ACK Frame (for IETF QUIC)", func() { - Context("parsing", func() { - It("parses an ACK frame without any ranges", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(100)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(10)...) // first ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) - }) - - It("parses an ACK frame that only acks a single packet", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(55)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(0)...) // first ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(55))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(55))) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) - }) - - It("accepts an ACK frame that acks all packets from 0 to largest", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(20)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(20)...) // first ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(20))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(0))) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) - }) - - It("rejects an ACK frame that has a first ACK block which is larger than LargestAcked", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(20)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(21)...) // first ack block - b := bytes.NewReader(data) - _, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).To(MatchError("invalid first ACK range")) - }) - - It("parses an ACK frame that has a single block", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(1000)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(1)...) // num blocks - data = append(data, encodeVarInt(100)...) // first ack block - data = append(data, encodeVarInt(98)...) // gap - data = append(data, encodeVarInt(50)...) // ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(1000))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(750))) - Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(frame.AckRanges).To(Equal([]AckRange{ - {Largest: 1000, Smallest: 900}, - {Largest: 800, Smallest: 750}, - })) - Expect(b.Len()).To(BeZero()) - }) - - It("parses an ACK frame that has a multiple blocks", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(100)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(2)...) // num blocks - data = append(data, encodeVarInt(0)...) // first ack block - data = append(data, encodeVarInt(0)...) // gap - data = append(data, encodeVarInt(0)...) // ack block - data = append(data, encodeVarInt(1)...) // gap - data = append(data, encodeVarInt(1)...) // ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(94))) - Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(frame.AckRanges).To(Equal([]AckRange{ - {Largest: 100, Smallest: 100}, - {Largest: 98, Smallest: 98}, - {Largest: 95, Smallest: 94}, - })) - Expect(b.Len()).To(BeZero()) - }) - - It("uses the ack delay exponent", func() { - const delayTime = 1 << 10 * time.Millisecond - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, - DelayTime: delayTime, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - for i := uint8(0); i < 8; i++ { - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent+i, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.DelayTime).To(Equal(delayTime * (1 << i))) - } - }) - - It("gracefully handles overflows of the delay time", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(100)...) // largest acked - data = append(data, encodeVarInt(math.MaxUint64/5)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(0)...) // first ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.DelayTime).To(BeNumerically(">", 0)) - // The maximum encodable duration is ~292 years. - Expect(frame.DelayTime.Hours()).To(BeNumerically("~", 292*365*24, 365*24)) - }) - - It("errors on EOF", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(1000)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(1)...) // num blocks - data = append(data, encodeVarInt(100)...) // first ack block - data = append(data, encodeVarInt(98)...) // gap - data = append(data, encodeVarInt(50)...) // ack block - _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - - Context("ACK_ECN", func() { - It("parses", func() { - data := []byte{0x3} - data = append(data, encodeVarInt(100)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(10)...) // first ack block - data = append(data, encodeVarInt(0x42)...) // ECT(0) - data = append(data, encodeVarInt(0x12345)...) // ECT(1) - data = append(data, encodeVarInt(0x12345678)...) // ECN-CE - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOF", func() { - data := []byte{0x3} - data = append(data, encodeVarInt(1000)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(1)...) // num blocks - data = append(data, encodeVarInt(100)...) // first ack block - data = append(data, encodeVarInt(98)...) // gap - data = append(data, encodeVarInt(50)...) // ack block - data = append(data, encodeVarInt(0x42)...) // ECT(0) - data = append(data, encodeVarInt(0x12345)...) // ECT(1) - data = append(data, encodeVarInt(0x12345678)...) // ECN-CE - _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - }) - - Context("when writing", func() { - It("writes a simple frame", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 100, Largest: 1337}}, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - expected := []byte{0x2} - expected = append(expected, encodeVarInt(1337)...) // largest acked - expected = append(expected, 0) // delay - expected = append(expected, encodeVarInt(0)...) // num ranges - expected = append(expected, encodeVarInt(1337-100)...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("writes an ACK-ECN frame", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 10, Largest: 2000}}, - ECT0: 13, - ECT1: 37, - ECNCE: 12345, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - expected := []byte{0x3} - expected = append(expected, encodeVarInt(2000)...) // largest acked - expected = append(expected, 0) // delay - expected = append(expected, encodeVarInt(0)...) // num ranges - expected = append(expected, encodeVarInt(2000-10)...) - expected = append(expected, encodeVarInt(13)...) - expected = append(expected, encodeVarInt(37)...) - expected = append(expected, encodeVarInt(12345)...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("writes a frame that acks a single packet", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 0x2eadbeef, Largest: 0x2eadbeef}}, - DelayTime: 18 * time.Millisecond, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(frame.DelayTime).To(Equal(f.DelayTime)) - Expect(b.Len()).To(BeZero()) - }) - - It("writes a frame that acks many packets", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 0x1337, Largest: 0x2eadbeef}}, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) - }) - - It("writes a frame with a a single gap", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{ - {Smallest: 400, Largest: 1000}, - {Smallest: 100, Largest: 200}, - }, - } - Expect(f.validateAckRanges()).To(BeTrue()) - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(b.Len()).To(BeZero()) - }) - - It("writes a frame with multiple ranges", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{ - {Smallest: 10, Largest: 10}, - {Smallest: 8, Largest: 8}, - {Smallest: 5, Largest: 6}, - {Smallest: 1, Largest: 3}, - }, - } - Expect(f.validateAckRanges()).To(BeTrue()) - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(b.Len()).To(BeZero()) - }) - - It("limits the maximum size of the ACK frame", func() { - buf := &bytes.Buffer{} - const numRanges = 1000 - ackRanges := make([]AckRange, numRanges) - for i := protocol.PacketNumber(1); i <= numRanges; i++ { - ackRanges[numRanges-i] = AckRange{Smallest: 2 * i, Largest: 2 * i} - } - f := &AckFrame{AckRanges: ackRanges} - Expect(f.validateAckRanges()).To(BeTrue()) - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - // make sure the ACK frame is *a little bit* smaller than the MaxAckFrameSize - Expect(buf.Len()).To(BeNumerically(">", protocol.MaxAckFrameSize-5)) - Expect(buf.Len()).To(BeNumerically("<=", protocol.MaxAckFrameSize)) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(b.Len()).To(BeZero()) - Expect(len(frame.AckRanges)).To(BeNumerically("<", numRanges)) // make sure we dropped some ranges - }) - }) - - Context("ACK range validator", func() { - It("rejects ACKs without ranges", func() { - Expect((&AckFrame{}).validateAckRanges()).To(BeFalse()) - }) - - It("accepts an ACK without NACK Ranges", func() { - ack := AckFrame{ - AckRanges: []AckRange{{Smallest: 1, Largest: 7}}, - } - Expect(ack.validateAckRanges()).To(BeTrue()) - }) - - It("rejects ACK ranges with Smallest greater than Largest", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 8, Largest: 10}, - {Smallest: 4, Largest: 3}, - }, - } - Expect(ack.validateAckRanges()).To(BeFalse()) - }) - - It("rejects ACK ranges in the wrong order", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 2, Largest: 2}, - {Smallest: 6, Largest: 7}, - }, - } - Expect(ack.validateAckRanges()).To(BeFalse()) - }) - - It("rejects with overlapping ACK ranges", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 5, Largest: 7}, - {Smallest: 2, Largest: 5}, - }, - } - Expect(ack.validateAckRanges()).To(BeFalse()) - }) - - It("rejects ACK ranges that are part of a larger ACK range", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 4, Largest: 7}, - {Smallest: 5, Largest: 6}, - }, - } - Expect(ack.validateAckRanges()).To(BeFalse()) - }) - - It("rejects with directly adjacent ACK ranges", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 5, Largest: 7}, - {Smallest: 2, Largest: 4}, - }, - } - Expect(ack.validateAckRanges()).To(BeFalse()) - }) - - It("accepts an ACK with one lost packet", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 5, Largest: 10}, - {Smallest: 1, Largest: 3}, - }, - } - Expect(ack.validateAckRanges()).To(BeTrue()) - }) - - It("accepts an ACK with multiple lost packets", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 15, Largest: 20}, - {Smallest: 10, Largest: 12}, - {Smallest: 1, Largest: 3}, - }, - } - Expect(ack.validateAckRanges()).To(BeTrue()) - }) - }) - - Context("check if ACK frame acks a certain packet", func() { - It("works with an ACK without any ranges", func() { - f := AckFrame{ - AckRanges: []AckRange{{Smallest: 5, Largest: 10}}, - } - Expect(f.AcksPacket(1)).To(BeFalse()) - Expect(f.AcksPacket(4)).To(BeFalse()) - Expect(f.AcksPacket(5)).To(BeTrue()) - Expect(f.AcksPacket(8)).To(BeTrue()) - Expect(f.AcksPacket(10)).To(BeTrue()) - Expect(f.AcksPacket(11)).To(BeFalse()) - Expect(f.AcksPacket(20)).To(BeFalse()) - }) - - It("works with an ACK with multiple ACK ranges", func() { - f := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 15, Largest: 20}, - {Smallest: 5, Largest: 8}, - }, - } - Expect(f.AcksPacket(4)).To(BeFalse()) - Expect(f.AcksPacket(5)).To(BeTrue()) - Expect(f.AcksPacket(6)).To(BeTrue()) - Expect(f.AcksPacket(7)).To(BeTrue()) - Expect(f.AcksPacket(8)).To(BeTrue()) - Expect(f.AcksPacket(9)).To(BeFalse()) - Expect(f.AcksPacket(14)).To(BeFalse()) - Expect(f.AcksPacket(15)).To(BeTrue()) - Expect(f.AcksPacket(18)).To(BeTrue()) - Expect(f.AcksPacket(19)).To(BeTrue()) - Expect(f.AcksPacket(20)).To(BeTrue()) - Expect(f.AcksPacket(21)).To(BeFalse()) - }) - }) -}) diff --git a/internal/wire/ack_range.go b/internal/wire/ack_range.go deleted file mode 100644 index 68032205..00000000 --- a/internal/wire/ack_range.go +++ /dev/null @@ -1,14 +0,0 @@ -package wire - -import "github.com/imroc/req/v3/internal/protocol" - -// AckRange is an ACK range -type AckRange struct { - Smallest protocol.PacketNumber - Largest protocol.PacketNumber -} - -// Len returns the number of packets contained in this ACK range -func (r AckRange) Len() protocol.PacketNumber { - return r.Largest - r.Smallest + 1 -} diff --git a/internal/wire/ack_range_test.go b/internal/wire/ack_range_test.go deleted file mode 100644 index 84ef71b5..00000000 --- a/internal/wire/ack_range_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package wire - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("ACK range", func() { - It("returns the length", func() { - Expect(AckRange{Smallest: 10, Largest: 10}.Len()).To(BeEquivalentTo(1)) - Expect(AckRange{Smallest: 10, Largest: 13}.Len()).To(BeEquivalentTo(4)) - }) -}) diff --git a/internal/wire/connection_close_frame.go b/internal/wire/connection_close_frame.go deleted file mode 100644 index 627007dd..00000000 --- a/internal/wire/connection_close_frame.go +++ /dev/null @@ -1,84 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A ConnectionCloseFrame is a CONNECTION_CLOSE frame -type ConnectionCloseFrame struct { - IsApplicationError bool - ErrorCode uint64 - FrameType uint64 - ReasonPhrase string -} - -func parseConnectionCloseFrame(r *bytes.Reader, _ quic.VersionNumber) (*ConnectionCloseFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - f := &ConnectionCloseFrame{IsApplicationError: typeByte == 0x1d} - ec, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - f.ErrorCode = ec - // read the Frame Type, if this is not an application error - if !f.IsApplicationError { - ft, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - f.FrameType = ft - } - var reasonPhraseLen uint64 - reasonPhraseLen, err = quicvarint.Read(r) - if err != nil { - return nil, err - } - // shortcut to prevent the unnecessary allocation of dataLen bytes - // if the dataLen is larger than the remaining length of the packet - // reading the whole reason phrase would result in EOF when attempting to READ - if int(reasonPhraseLen) > r.Len() { - return nil, io.EOF - } - - reasonPhrase := make([]byte, reasonPhraseLen) - if _, err := io.ReadFull(r, reasonPhrase); err != nil { - // this should never happen, since we already checked the reasonPhraseLen earlier - return nil, err - } - f.ReasonPhrase = string(reasonPhrase) - return f, nil -} - -// Length of a written frame -func (f *ConnectionCloseFrame) Length(quic.VersionNumber) protocol.ByteCount { - length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) - if !f.IsApplicationError { - length += quicvarint.Len(f.FrameType) // for the frame type - } - return length -} - -func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { - if f.IsApplicationError { - b.WriteByte(0x1d) - } else { - b.WriteByte(0x1c) - } - - quicvarint.Write(b, f.ErrorCode) - if !f.IsApplicationError { - quicvarint.Write(b, f.FrameType) - } - quicvarint.Write(b, uint64(len(f.ReasonPhrase))) - b.WriteString(f.ReasonPhrase) - return nil -} diff --git a/internal/wire/connection_close_frame_test.go b/internal/wire/connection_close_frame_test.go deleted file mode 100644 index 7947c681..00000000 --- a/internal/wire/connection_close_frame_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("CONNECTION_CLOSE Frame", func() { - Context("when parsing", func() { - It("accepts sample frame containing a QUIC error code", func() { - reason := "No recent network activity." - data := []byte{0x1c} - data = append(data, encodeVarInt(0x19)...) - data = append(data, encodeVarInt(0x1337)...) // frame type - data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length - data = append(data, []byte(reason)...) - b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.IsApplicationError).To(BeFalse()) - Expect(frame.ErrorCode).To(BeEquivalentTo(0x19)) - Expect(frame.FrameType).To(BeEquivalentTo(0x1337)) - Expect(frame.ReasonPhrase).To(Equal(reason)) - Expect(b.Len()).To(BeZero()) - }) - - It("accepts sample frame containing an application error code", func() { - reason := "The application messed things up." - data := []byte{0x1d} - data = append(data, encodeVarInt(0xcafe)...) - data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length - data = append(data, reason...) - b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.IsApplicationError).To(BeTrue()) - Expect(frame.ErrorCode).To(BeEquivalentTo(0xcafe)) - Expect(frame.ReasonPhrase).To(Equal(reason)) - Expect(b.Len()).To(BeZero()) - }) - - It("rejects long reason phrases", func() { - data := []byte{0x1c} - data = append(data, encodeVarInt(0xcafe)...) - data = append(data, encodeVarInt(0x42)...) // frame type - data = append(data, encodeVarInt(0xffff)...) // reason phrase length - b := bytes.NewReader(data) - _, err := parseConnectionCloseFrame(b, protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - }) - - It("errors on EOFs", func() { - reason := "No recent network activity." - data := []byte{0x1c} - data = append(data, encodeVarInt(0x19)...) - data = append(data, encodeVarInt(0x1337)...) // frame type - data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length - data = append(data, []byte(reason)...) - _, err := parseConnectionCloseFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseConnectionCloseFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - - It("parses a frame without a reason phrase", func() { - data := []byte{0x1c} - data = append(data, encodeVarInt(0xcafe)...) - data = append(data, encodeVarInt(0x42)...) // frame type - data = append(data, encodeVarInt(0)...) - b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.ReasonPhrase).To(BeEmpty()) - Expect(b.Len()).To(BeZero()) - }) - }) - - Context("when writing", func() { - It("writes a frame without a reason phrase", func() { - b := &bytes.Buffer{} - frame := &ConnectionCloseFrame{ - ErrorCode: 0xbeef, - FrameType: 0x12345, - } - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - expected := []byte{0x1c} - expected = append(expected, encodeVarInt(0xbeef)...) - expected = append(expected, encodeVarInt(0x12345)...) // frame type - expected = append(expected, encodeVarInt(0)...) // reason phrase length - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with a reason phrase", func() { - b := &bytes.Buffer{} - frame := &ConnectionCloseFrame{ - ErrorCode: 0xdead, - ReasonPhrase: "foobar", - } - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - expected := []byte{0x1c} - expected = append(expected, encodeVarInt(0xdead)...) - expected = append(expected, encodeVarInt(0)...) // frame type - expected = append(expected, encodeVarInt(6)...) // reason phrase length - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with an application error code", func() { - b := &bytes.Buffer{} - frame := &ConnectionCloseFrame{ - IsApplicationError: true, - ErrorCode: 0xdead, - ReasonPhrase: "foobar", - } - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - expected := []byte{0x1d} - expected = append(expected, encodeVarInt(0xdead)...) - expected = append(expected, encodeVarInt(6)...) // reason phrase length - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has proper min length, for a frame containing a QUIC error code", func() { - b := &bytes.Buffer{} - f := &ConnectionCloseFrame{ - ErrorCode: 0xcafe, - FrameType: 0xdeadbeef, - ReasonPhrase: "foobar", - } - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) - }) - - It("has proper min length, for a frame containing an application error code", func() { - b := &bytes.Buffer{} - f := &ConnectionCloseFrame{ - IsApplicationError: true, - ErrorCode: 0xcafe, - ReasonPhrase: "foobar", - } - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) - }) - }) -}) diff --git a/internal/wire/crypto_frame.go b/internal/wire/crypto_frame.go deleted file mode 100644 index 56f176a7..00000000 --- a/internal/wire/crypto_frame.go +++ /dev/null @@ -1,103 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A CryptoFrame is a CRYPTO frame -type CryptoFrame struct { - Offset protocol.ByteCount - Data []byte -} - -func parseCryptoFrame(r *bytes.Reader, _ quic.VersionNumber) (*CryptoFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - frame := &CryptoFrame{} - offset, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - frame.Offset = protocol.ByteCount(offset) - dataLen, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - if dataLen > uint64(r.Len()) { - return nil, io.EOF - } - if dataLen != 0 { - frame.Data = make([]byte, dataLen) - if _, err := io.ReadFull(r, frame.Data); err != nil { - // this should never happen, since we already checked the dataLen earlier - return nil, err - } - } - return frame, nil -} - -func (f *CryptoFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - b.WriteByte(0x6) - quicvarint.Write(b, uint64(f.Offset)) - quicvarint.Write(b, uint64(len(f.Data))) - b.Write(f.Data) - return nil -} - -// Length of a written frame -func (f *CryptoFrame) Length(_ quic.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data)) -} - -// MaxDataLen returns the maximum data length -func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount { - // pretend that the data size will be 1 bytes - // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards - headerLen := 1 + quicvarint.Len(uint64(f.Offset)) + 1 - if headerLen > maxSize { - return 0 - } - maxDataLen := maxSize - headerLen - if quicvarint.Len(uint64(maxDataLen)) != 1 { - maxDataLen-- - } - return maxDataLen -} - -// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. -// It returns if the frame was actually split. -// The frame might not be split if: -// * the size is large enough to fit the whole frame -// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. -func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version quic.VersionNumber) (*CryptoFrame, bool /* was splitting required */) { - if f.Length(version) <= maxSize { - return nil, false - } - - n := f.MaxDataLen(maxSize) - if n == 0 { - return nil, true - } - - newLen := protocol.ByteCount(len(f.Data)) - n - - new := &CryptoFrame{} - new.Offset = f.Offset - new.Data = make([]byte, newLen) - - // swap the data slices - new.Data, f.Data = f.Data, new.Data - - copy(f.Data, new.Data[n:]) - new.Data = new.Data[:n] - f.Offset += n - - return new, true -} diff --git a/internal/wire/crypto_frame_test.go b/internal/wire/crypto_frame_test.go deleted file mode 100644 index 08ede5d0..00000000 --- a/internal/wire/crypto_frame_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("CRYPTO frame", func() { - Context("when parsing", func() { - It("parses", func() { - data := []byte{0x6} - data = append(data, encodeVarInt(0xdecafbad)...) // offset - data = append(data, encodeVarInt(6)...) // length - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseCryptoFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) - Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(r.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x6} - data = append(data, encodeVarInt(0xdecafbad)...) // offset - data = append(data, encodeVarInt(6)...) // data length - data = append(data, []byte("foobar")...) - _, err := parseCryptoFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseCryptoFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("when writing", func() { - It("writes a frame", func() { - f := &CryptoFrame{ - Offset: 0x123456, - Data: []byte("foobar"), - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x6} - expected = append(expected, encodeVarInt(0x123456)...) // offset - expected = append(expected, encodeVarInt(6)...) // length - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - }) - - Context("max data length", func() { - const maxSize = 3000 - - It("always returns a data length such that the resulting frame has the right size", func() { - data := make([]byte, maxSize) - f := &CryptoFrame{ - Offset: 0xdeadbeef, - } - b := &bytes.Buffer{} - var frameOneByteTooSmallCounter int - for i := 1; i < maxSize; i++ { - b.Reset() - f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i)) - if maxDataLen == 0 { // 0 means that no valid CRYTPO frame can be written - // check that writing a minimal size CRYPTO frame (i.e. with 1 byte data) is actually larger than the desired size - f.Data = []byte{0} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeNumerically(">", i)) - continue - } - f.Data = data[:int(maxDataLen)] - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - // There's *one* pathological case, where a data length of x can be encoded into 1 byte - // but a data lengths of x+1 needs 2 bytes - // In that case, it's impossible to create a STREAM frame of the desired size - if b.Len() == i-1 { - frameOneByteTooSmallCounter++ - continue - } - Expect(b.Len()).To(Equal(i)) - } - Expect(frameOneByteTooSmallCounter).To(Equal(1)) - }) - }) - - Context("length", func() { - It("has the right length for a frame without offset and data length", func() { - f := &CryptoFrame{ - Offset: 0x1337, - Data: []byte("foobar"), - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(6) + 6)) - }) - }) - - Context("splitting", func() { - It("splits a frame", func() { - f := &CryptoFrame{ - Offset: 0x1337, - Data: []byte("foobar"), - } - hdrLen := f.Length(protocol.Version1) - 6 - new, needsSplit := f.MaybeSplitOffFrame(hdrLen+3, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(new.Data).To(Equal([]byte("foo"))) - Expect(new.Offset).To(Equal(protocol.ByteCount(0x1337))) - Expect(f.Data).To(Equal([]byte("bar"))) - Expect(f.Offset).To(Equal(protocol.ByteCount(0x1337 + 3))) - }) - - It("doesn't split if there's enough space in the frame", func() { - f := &CryptoFrame{ - Offset: 0x1337, - Data: []byte("foobar"), - } - f, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) - Expect(needsSplit).To(BeFalse()) - Expect(f).To(BeNil()) - }) - - It("doesn't split if the size is too small", func() { - f := &CryptoFrame{ - Offset: 0x1337, - Data: []byte("foobar"), - } - length := f.Length(protocol.Version1) - 6 - for i := protocol.ByteCount(0); i <= length; i++ { - f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(f).To(BeNil()) - } - f, needsSplit := f.MaybeSplitOffFrame(length+1, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(f).ToNot(BeNil()) - }) - }) -}) diff --git a/internal/wire/data_blocked_frame.go b/internal/wire/data_blocked_frame.go deleted file mode 100644 index a6ab54fc..00000000 --- a/internal/wire/data_blocked_frame.go +++ /dev/null @@ -1,39 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A DataBlockedFrame is a DATA_BLOCKED frame -type DataBlockedFrame struct { - MaximumData protocol.ByteCount -} - -func parseDataBlockedFrame(r *bytes.Reader, _ quic.VersionNumber) (*DataBlockedFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - offset, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - return &DataBlockedFrame{ - MaximumData: protocol.ByteCount(offset), - }, nil -} - -func (f *DataBlockedFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { - typeByte := uint8(0x14) - b.WriteByte(typeByte) - quicvarint.Write(b, uint64(f.MaximumData)) - return nil -} - -// Length of a written frame -func (f *DataBlockedFrame) Length(version quic.VersionNumber) protocol.ByteCount { - return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaximumData))) -} diff --git a/internal/wire/data_blocked_frame_test.go b/internal/wire/data_blocked_frame_test.go deleted file mode 100644 index 2aac0525..00000000 --- a/internal/wire/data_blocked_frame_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("DATA_BLOCKED frame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - data := []byte{0x14} - data = append(data, encodeVarInt(0x12345678)...) - b := bytes.NewReader(data) - frame, err := parseDataBlockedFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0x12345678))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x14} - data = append(data, encodeVarInt(0x12345678)...) - _, err := parseDataBlockedFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - for i := range data { - _, err := parseDataBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - b := &bytes.Buffer{} - frame := DataBlockedFrame{MaximumData: 0xdeadbeef} - err := frame.Write(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x14} - expected = append(expected, encodeVarInt(0xdeadbeef)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - frame := DataBlockedFrame{MaximumData: 0x12345} - Expect(frame.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x12345))) - }) - }) -}) diff --git a/internal/wire/datagram_frame.go b/internal/wire/datagram_frame.go deleted file mode 100644 index a330185e..00000000 --- a/internal/wire/datagram_frame.go +++ /dev/null @@ -1,86 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A DatagramFrame is a DATAGRAM frame -type DatagramFrame struct { - DataLenPresent bool - Data []byte -} - -func parseDatagramFrame(r *bytes.Reader, _ quic.VersionNumber) (*DatagramFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - f := &DatagramFrame{} - f.DataLenPresent = typeByte&0x1 > 0 - - var length uint64 - if f.DataLenPresent { - var err error - len, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - if len > uint64(r.Len()) { - return nil, io.EOF - } - length = len - } else { - length = uint64(r.Len()) - } - f.Data = make([]byte, length) - if _, err := io.ReadFull(r, f.Data); err != nil { - return nil, err - } - return f, nil -} - -func (f *DatagramFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - typeByte := uint8(0x30) - if f.DataLenPresent { - typeByte ^= 0x1 - } - b.WriteByte(typeByte) - if f.DataLenPresent { - quicvarint.Write(b, uint64(len(f.Data))) - } - b.Write(f.Data) - return nil -} - -// MaxDataLen returns the maximum data length -func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version quic.VersionNumber) protocol.ByteCount { - headerLen := protocol.ByteCount(1) - if f.DataLenPresent { - // pretend that the data size will be 1 bytes - // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards - headerLen++ - } - if headerLen > maxSize { - return 0 - } - maxDataLen := maxSize - headerLen - if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { - maxDataLen-- - } - return maxDataLen -} - -// Length of a written frame -func (f *DatagramFrame) Length(_ quic.VersionNumber) protocol.ByteCount { - length := 1 + protocol.ByteCount(len(f.Data)) - if f.DataLenPresent { - length += protocol.ByteCount(quicvarint.Len(uint64(len(f.Data)))) - } - return length -} diff --git a/internal/wire/datagram_frame_test.go b/internal/wire/datagram_frame_test.go deleted file mode 100644 index 363d6c34..00000000 --- a/internal/wire/datagram_frame_test.go +++ /dev/null @@ -1,154 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("STREAM frame", func() { - Context("when parsing", func() { - It("parses a frame containing a length", func() { - data := []byte{0x30 ^ 0x1} - data = append(data, encodeVarInt(0x6)...) // length - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseDatagramFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(frame.DataLenPresent).To(BeTrue()) - Expect(r.Len()).To(BeZero()) - }) - - It("parses a frame without length", func() { - data := []byte{0x30} - data = append(data, []byte("Lorem ipsum dolor sit amet")...) - r := bytes.NewReader(data) - frame, err := parseDatagramFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.Data).To(Equal([]byte("Lorem ipsum dolor sit amet"))) - Expect(frame.DataLenPresent).To(BeFalse()) - Expect(r.Len()).To(BeZero()) - }) - - It("errors when the length is longer than the rest of the frame", func() { - data := []byte{0x30 ^ 0x1} - data = append(data, encodeVarInt(0x6)...) // length - data = append(data, []byte("fooba")...) - r := bytes.NewReader(data) - _, err := parseDatagramFrame(r, protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - }) - - It("errors on EOFs", func() { - data := []byte{0x30 ^ 0x1} - data = append(data, encodeVarInt(6)...) // length - data = append(data, []byte("foobar")...) - _, err := parseDatagramFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseDatagramFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a frame with length", func() { - f := &DatagramFrame{ - DataLenPresent: true, - Data: []byte("foobar"), - } - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - expected := []byte{0x30 ^ 0x1} - expected = append(expected, encodeVarInt(0x6)...) - expected = append(expected, []byte("foobar")...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("writes a frame without length", func() { - f := &DatagramFrame{Data: []byte("Lorem ipsum")} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - expected := []byte{0x30} - expected = append(expected, []byte("Lorem ipsum")...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - }) - - Context("length", func() { - It("has the right length for a frame with length", func() { - f := &DatagramFrame{ - DataLenPresent: true, - Data: []byte("foobar"), - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(6) + 6)) - }) - - It("has the right length for a frame without length", func() { - f := &DatagramFrame{Data: []byte("foobar")} - Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(1 + 6))) - }) - }) - - Context("max data length", func() { - const maxSize = 3000 - - It("returns a data length such that the resulting frame has the right size, if data length is not present", func() { - data := make([]byte, maxSize) - f := &DatagramFrame{} - b := &bytes.Buffer{} - for i := 1; i < 3000; i++ { - b.Reset() - f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) - if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written - // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size - f.Data = []byte{0} - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(b.Len()).To(BeNumerically(">", i)) - continue - } - f.Data = data[:int(maxDataLen)] - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(b.Len()).To(Equal(i)) - } - }) - - It("always returns a data length such that the resulting frame has the right size, if data length is present", func() { - data := make([]byte, maxSize) - f := &DatagramFrame{DataLenPresent: true} - b := &bytes.Buffer{} - var frameOneByteTooSmallCounter int - for i := 1; i < 3000; i++ { - b.Reset() - f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) - if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written - // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size - f.Data = []byte{0} - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(b.Len()).To(BeNumerically(">", i)) - continue - } - f.Data = data[:int(maxDataLen)] - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - // There's *one* pathological case, where a data length of x can be encoded into 1 byte - // but a data lengths of x+1 needs 2 bytes - // In that case, it's impossible to create a STREAM frame of the desired size - if b.Len() == i-1 { - frameOneByteTooSmallCounter++ - continue - } - Expect(b.Len()).To(Equal(i)) - } - Expect(frameOneByteTooSmallCounter).To(Equal(1)) - }) - }) -}) diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go deleted file mode 100644 index 5434399b..00000000 --- a/internal/wire/extended_header.go +++ /dev/null @@ -1,250 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "fmt" - "github.com/lucas-clemente/quic-go" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - "github.com/imroc/req/v3/internal/utils" -) - -// ErrInvalidReservedBits is returned when the reserved bits are incorrect. -// When this error is returned, parsing continues, and an ExtendedHeader is returned. -// This is necessary because we need to decrypt the packet in that case, -// in order to avoid a timing side-channel. -var ErrInvalidReservedBits = errors.New("invalid reserved bits") - -// ExtendedHeader is the header of a QUIC packet. -type ExtendedHeader struct { - Header - - typeByte byte - - KeyPhase protocol.KeyPhaseBit - - PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber - - parsedLen protocol.ByteCount -} - -func (h *ExtendedHeader) parse(b *bytes.Reader, v quic.VersionNumber) (bool /* reserved bits valid */, error) { - startLen := b.Len() - // read the (now unencrypted) first byte - var err error - h.typeByte, err = b.ReadByte() - if err != nil { - return false, err - } - if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil { - return false, err - } - var reservedBitsValid bool - if h.IsLongHeader { - reservedBitsValid, err = h.parseLongHeader(b, v) - } else { - reservedBitsValid, err = h.parseShortHeader(b, v) - } - if err != nil { - return false, err - } - h.parsedLen = protocol.ByteCount(startLen - b.Len()) - return reservedBitsValid, err -} - -func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ quic.VersionNumber) (bool /* reserved bits valid */, error) { - if err := h.readPacketNumber(b); err != nil { - return false, err - } - if h.typeByte&0xc != 0 { - return false, nil - } - return true, nil -} - -func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ quic.VersionNumber) (bool /* reserved bits valid */, error) { - h.KeyPhase = protocol.KeyPhaseZero - if h.typeByte&0x4 > 0 { - h.KeyPhase = protocol.KeyPhaseOne - } - - if err := h.readPacketNumber(b); err != nil { - return false, err - } - if h.typeByte&0x18 != 0 { - return false, nil - } - return true, nil -} - -func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { - h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - n, err := b.ReadByte() - if err != nil { - return err - } - h.PacketNumber = protocol.PacketNumber(n) - case protocol.PacketNumberLen2: - n, err := utils.BigEndian.ReadUint16(b) - if err != nil { - return err - } - h.PacketNumber = protocol.PacketNumber(n) - case protocol.PacketNumberLen3: - n, err := utils.BigEndian.ReadUint24(b) - if err != nil { - return err - } - h.PacketNumber = protocol.PacketNumber(n) - case protocol.PacketNumberLen4: - n, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return err - } - h.PacketNumber = protocol.PacketNumber(n) - default: - return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) - } - return nil -} - -// Write writes the Header. -func (h *ExtendedHeader) Write(b *bytes.Buffer, ver quic.VersionNumber) error { - if h.DestConnectionID.Len() > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) - } - if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) - } - if h.IsLongHeader { - return h.writeLongHeader(b, ver) - } - return h.writeShortHeader(b, ver) -} - -func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version quic.VersionNumber) error { - var packetType uint8 - if version == protocol.Version2 { - //nolint:exhaustive - switch h.Type { - case protocol.PacketTypeInitial: - packetType = 0b01 - case protocol.PacketType0RTT: - packetType = 0b10 - case protocol.PacketTypeHandshake: - packetType = 0b11 - case protocol.PacketTypeRetry: - packetType = 0b00 - } - } else { - //nolint:exhaustive - switch h.Type { - case protocol.PacketTypeInitial: - packetType = 0b00 - case protocol.PacketType0RTT: - packetType = 0b01 - case protocol.PacketTypeHandshake: - packetType = 0b10 - case protocol.PacketTypeRetry: - packetType = 0b11 - } - } - firstByte := 0xc0 | packetType<<4 - if h.Type != protocol.PacketTypeRetry { - // Retry packets don't have a packet number - firstByte |= uint8(h.PacketNumberLen - 1) - } - - b.WriteByte(firstByte) - utils.BigEndian.WriteUint32(b, uint32(h.Version)) - b.WriteByte(uint8(h.DestConnectionID.Len())) - b.Write(h.DestConnectionID.Bytes()) - b.WriteByte(uint8(h.SrcConnectionID.Len())) - b.Write(h.SrcConnectionID.Bytes()) - - //nolint:exhaustive - switch h.Type { - case protocol.PacketTypeRetry: - b.Write(h.Token) - return nil - case protocol.PacketTypeInitial: - quicvarint.Write(b, uint64(len(h.Token))) - b.Write(h.Token) - } - quicvarint.WriteWithLen(b, uint64(h.Length), 2) - return h.writePacketNumber(b) -} - -func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ quic.VersionNumber) error { - typeByte := 0x40 | uint8(h.PacketNumberLen-1) - if h.KeyPhase == protocol.KeyPhaseOne { - typeByte |= byte(1 << 2) - } - - b.WriteByte(typeByte) - b.Write(h.DestConnectionID.Bytes()) - return h.writePacketNumber(b) -} - -func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - b.WriteByte(uint8(h.PacketNumber)) - case protocol.PacketNumberLen2: - utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) - case protocol.PacketNumberLen3: - utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber)) - case protocol.PacketNumberLen4: - utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) - default: - return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) - } - return nil -} - -// ParsedLen returns the number of bytes that were consumed when parsing the header -func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { - return h.parsedLen -} - -// GetLength determines the length of the Header. -func (h *ExtendedHeader) GetLength(v quic.VersionNumber) protocol.ByteCount { - if h.IsLongHeader { - length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ - if h.Type == protocol.PacketTypeInitial { - length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) - } - return length - } - - length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) - length += protocol.ByteCount(h.PacketNumberLen) - return length -} - -// Log logs the Header -func (h *ExtendedHeader) Log(logger utils.Logger) { - if h.IsLongHeader { - var token string - if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { - if len(h.Token) == 0 { - token = "Token: (empty), " - } else { - token = fmt.Sprintf("Token: %#x, ", h.Token) - } - if h.Type == protocol.PacketTypeRetry { - logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version) - return - } - } - logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) - } else { - logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) - } -} diff --git a/internal/wire/extended_header_test.go b/internal/wire/extended_header_test.go deleted file mode 100644 index f9ab54ba..00000000 --- a/internal/wire/extended_header_test.go +++ /dev/null @@ -1,481 +0,0 @@ -package wire - -import ( - "bytes" - "log" - "os" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/utils" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Header", func() { - const versionIETFHeader = protocol.VersionTLS // a QUIC version that uses the IETF Header format - - Context("Writing", func() { - var buf *bytes.Buffer - - BeforeEach(func() { - buf = &bytes.Buffer{} - }) - - Context("Long Header", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - - It("writes", func() { - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}, - Version: 0x1020304, - Length: protocol.InitialPacketSizeIPv4, - }, - PacketNumber: 0xdecaf, - PacketNumberLen: protocol.PacketNumberLen3, - }).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{ - 0xc0 | 0x2<<4 | 0x2, - 0x1, 0x2, 0x3, 0x4, // version number - 0x6, // dest connection ID length - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // dest connection ID - 0x8, // src connection ID length - 0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37, // source connection ID - } - expected = append(expected, encodeVarInt(protocol.InitialPacketSizeIPv4)...) // length - expected = append(expected, []byte{0xd, 0xec, 0xaf}...) // packet number - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("refuses to write a header with a too long connection ID", func() { - err := (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, // connection IDs must be at most 20 bytes long - Version: 0x1020304, - Type: 0x5, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader) - Expect(err).To(MatchError("invalid connection ID length: 21 bytes")) - }) - - It("writes a header with a 20 byte connection ID", func() { - err := (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, // connection IDs must be at most 20 bytes long - Version: 0x1020304, - Type: 0x5, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader) - Expect(err).ToNot(HaveOccurred()) - Expect(buf.Bytes()).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}))) - }) - - It("writes an Initial containing a token", func() { - token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Version: 0x1020304, - Type: protocol.PacketTypeInitial, - Token: token, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0) - expectedSubstring := append(encodeVarInt(uint64(len(token))), token...) - Expect(buf.Bytes()).To(ContainSubstring(string(expectedSubstring))) - }) - - It("uses a 2-byte encoding for the length on Initial packets", func() { - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Version: 0x1020304, - Type: protocol.PacketTypeInitial, - Length: 37, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader)).To(Succeed()) - b := &bytes.Buffer{} - quicvarint.WriteWithLen(b, 37, 2) - Expect(buf.Bytes()[buf.Len()-6 : buf.Len()-4]).To(Equal(b.Bytes())) - }) - - It("writes a Retry packet", func() { - token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") - Expect((&ExtendedHeader{Header: Header{ - IsLongHeader: true, - Version: protocol.Version1, - Type: protocol.PacketTypeRetry, - Token: token, - }}).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{0xc0 | 0b11<<4} - expected = appendVersion(expected, protocol.Version1) - expected = append(expected, 0x0) // dest connection ID length - expected = append(expected, 0x0) // src connection ID length - expected = append(expected, token...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - }) - - Context("long header, version 2", func() { - It("writes an Initial", func() { - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Version: protocol.Version2, - Type: protocol.PacketTypeInitial, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, protocol.Version2)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0b01) - }) - - It("writes a Retry packet", func() { - token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") - Expect((&ExtendedHeader{Header: Header{ - IsLongHeader: true, - Version: protocol.Version2, - Type: protocol.PacketTypeRetry, - Token: token, - }}).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{0xc0 | 0b11<<4} - expected = appendVersion(expected, protocol.Version2) - expected = append(expected, 0x0) // dest connection ID length - expected = append(expected, 0x0) // src connection ID length - expected = append(expected, token...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("writes a Handshake Packet", func() { - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Version: protocol.Version2, - Type: protocol.PacketTypeHandshake, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, protocol.Version2)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0b11) - }) - - It("writes a 0-RTT Packet", func() { - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Version: protocol.Version2, - Type: protocol.PacketType0RTT, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, protocol.Version2)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0b10) - }) - }) - - Context("short header", func() { - It("writes a header with connection ID", func() { - Expect((&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, - }, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()).To(Equal([]byte{ - 0x40, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID - 0x42, // packet number - })) - }) - - It("writes a header without connection ID", func() { - Expect((&ExtendedHeader{ - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()).To(Equal([]byte{ - 0x40, - 0x42, // packet number - })) - }) - - It("writes a header with a 2 byte packet number", func() { - Expect((&ExtendedHeader{ - PacketNumberLen: protocol.PacketNumberLen2, - PacketNumber: 0x765, - }).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{0x40 | 0x1} - expected = append(expected, []byte{0x7, 0x65}...) // packet number - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("writes a header with a 4 byte packet number", func() { - Expect((&ExtendedHeader{ - PacketNumberLen: protocol.PacketNumberLen4, - PacketNumber: 0x12345678, - }).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{0x40 | 0x3} - expected = append(expected, []byte{0x12, 0x34, 0x56, 0x78}...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("errors when given an invalid packet number length", func() { - err := (&ExtendedHeader{ - PacketNumberLen: 5, - PacketNumber: 0xdecafbad, - }).Write(buf, versionIETFHeader) - Expect(err).To(MatchError("invalid packet number length: 5")) - }) - - It("writes the Key Phase Bit", func() { - Expect((&ExtendedHeader{ - KeyPhase: protocol.KeyPhaseOne, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()).To(Equal([]byte{ - 0x40 | 0x4, - 0x42, // packet number - })) - }) - }) - }) - - Context("getting the length", func() { - var buf *bytes.Buffer - - BeforeEach(func() { - buf = &bytes.Buffer{} - }) - - It("has the right length for the Long Header, for a short length", func() { - h := &ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - Length: 1, - }, - PacketNumberLen: protocol.PacketNumberLen1, - } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* length */ + 1 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) - }) - - It("has the right length for the Long Header, for a long length", func() { - h := &ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - Length: 1500, - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* long len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) - }) - - It("has the right length for an Initial that has a short length", func() { - h := &ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Length: 15, - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) - }) - - It("has the right length for an Initial not containing a Token", func() { - h := &ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Length: 1500, - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) - }) - - It("has the right length for an Initial containing a Token", func() { - h := &ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Type: protocol.PacketTypeInitial, - Length: 1500, - Token: []byte("foo"), - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn id len */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) - }) - - It("has the right length for a Short Header containing a connection ID", func() { - h := &ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - }, - PacketNumberLen: protocol.PacketNumberLen1, - } - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 8 + 1))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(10)) - }) - - It("has the right length for a short header without a connection ID", func() { - h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1} - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 1))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(2)) - }) - - It("has the right length for a short header with a 2 byte packet number", func() { - h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen2} - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 2))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(3)) - }) - - It("has the right length for a short header with a 5 byte packet number", func() { - h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen4} - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 4))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(5)) - }) - }) - - Context("Logging", func() { - var ( - buf *bytes.Buffer - logger utils.Logger - ) - - BeforeEach(func() { - buf = &bytes.Buffer{} - logger = utils.DefaultLogger - logger.SetLogLevel(utils.LogLevelDebug) - log.SetOutput(buf) - }) - - AfterEach(func() { - log.SetOutput(os.Stdout) - }) - - It("logs Long Headers", func() { - (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}, - Type: protocol.PacketTypeHandshake, - Length: 54321, - Version: 0xfeed, - }, - PacketNumber: 1337, - PacketNumberLen: protocol.PacketNumberLen2, - }).Log(logger) - Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, DestConnectionID: deadbeefcafe1337, SrcConnectionID: decafbad13371337, PacketNumber: 1337, PacketNumberLen: 2, Length: 54321, Version: 0xfeed}")) - }) - - It("logs Initial Packets with a Token", func() { - (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - Type: protocol.PacketTypeInitial, - Token: []byte{0xde, 0xad, 0xbe, 0xef}, - Length: 100, - Version: 0xfeed, - }, - PacketNumber: 42, - PacketNumberLen: protocol.PacketNumberLen2, - }).Log(logger) - Expect(buf.String()).To(ContainSubstring("Long Header{Type: Initial, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0xdeadbeef, PacketNumber: 42, PacketNumberLen: 2, Length: 100, Version: 0xfeed}")) - }) - - It("logs Initial packets without a Token", func() { - (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - Type: protocol.PacketTypeInitial, - Length: 100, - Version: 0xfeed, - }, - PacketNumber: 42, - PacketNumberLen: protocol.PacketNumberLen2, - }).Log(logger) - Expect(buf.String()).To(ContainSubstring("Long Header{Type: Initial, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: (empty), PacketNumber: 42, PacketNumberLen: 2, Length: 100, Version: 0xfeed}")) - }) - - It("logs Retry packets with a Token", func() { - (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - Type: protocol.PacketTypeRetry, - Token: []byte{0x12, 0x34, 0x56}, - Version: 0xfeed, - }, - }).Log(logger) - Expect(buf.String()).To(ContainSubstring("Long Header{Type: Retry, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0x123456, Version: 0xfeed}")) - }) - - It("logs Short Headers containing a connection ID", func() { - (&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, - }, - KeyPhase: protocol.KeyPhaseOne, - PacketNumber: 1337, - PacketNumberLen: 4, - }).Log(logger) - Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}")) - }) - }) -}) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go deleted file mode 100644 index 3d06b30f..00000000 --- a/internal/wire/frame_parser.go +++ /dev/null @@ -1,144 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "fmt" - "github.com/lucas-clemente/quic-go" - "reflect" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" -) - -type frameParser struct { - ackDelayExponent uint8 - - supportsDatagrams bool - - version quic.VersionNumber -} - -// NewFrameParser creates a new frame parser. -func NewFrameParser(supportsDatagrams bool, v quic.VersionNumber) FrameParser { - return &frameParser{ - supportsDatagrams: supportsDatagrams, - version: v, - } -} - -// ParseNext parses the next frame. -// It skips PADDING frames. -func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel) (Frame, error) { - for r.Len() != 0 { - typeByte, _ := r.ReadByte() - if typeByte == 0x0 { // PADDING frame - continue - } - r.UnreadByte() - - f, err := p.parseFrame(r, typeByte, encLevel) - if err != nil { - return nil, &qerr.TransportError{ - FrameType: uint64(typeByte), - ErrorCode: qerr.FrameEncodingError, - ErrorMessage: err.Error(), - } - } - return f, nil - } - return nil, nil -} - -func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protocol.EncryptionLevel) (Frame, error) { - var frame Frame - var err error - if typeByte&0xf8 == 0x8 { - frame, err = parseStreamFrame(r, p.version) - } else { - switch typeByte { - case 0x1: - frame, err = parsePingFrame(r, p.version) - case 0x2, 0x3: - ackDelayExponent := p.ackDelayExponent - if encLevel != protocol.Encryption1RTT { - ackDelayExponent = protocol.DefaultAckDelayExponent - } - frame, err = parseAckFrame(r, ackDelayExponent, p.version) - case 0x4: - frame, err = parseResetStreamFrame(r, p.version) - case 0x5: - frame, err = parseStopSendingFrame(r, p.version) - case 0x6: - frame, err = parseCryptoFrame(r, p.version) - case 0x7: - frame, err = parseNewTokenFrame(r, p.version) - case 0x10: - frame, err = parseMaxDataFrame(r, p.version) - case 0x11: - frame, err = parseMaxStreamDataFrame(r, p.version) - case 0x12, 0x13: - frame, err = parseMaxStreamsFrame(r, p.version) - case 0x14: - frame, err = parseDataBlockedFrame(r, p.version) - case 0x15: - frame, err = parseStreamDataBlockedFrame(r, p.version) - case 0x16, 0x17: - frame, err = parseStreamsBlockedFrame(r, p.version) - case 0x18: - frame, err = parseNewConnectionIDFrame(r, p.version) - case 0x19: - frame, err = parseRetireConnectionIDFrame(r, p.version) - case 0x1a: - frame, err = parsePathChallengeFrame(r, p.version) - case 0x1b: - frame, err = parsePathResponseFrame(r, p.version) - case 0x1c, 0x1d: - frame, err = parseConnectionCloseFrame(r, p.version) - case 0x1e: - frame, err = parseHandshakeDoneFrame(r, p.version) - case 0x30, 0x31: - if p.supportsDatagrams { - frame, err = parseDatagramFrame(r, p.version) - break - } - fallthrough - default: - err = errors.New("unknown frame type") - } - } - if err != nil { - return nil, err - } - if !p.isAllowedAtEncLevel(frame, encLevel) { - return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel) - } - return frame, nil -} - -func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { - switch encLevel { - case protocol.EncryptionInitial, protocol.EncryptionHandshake: - switch f.(type) { - case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *PingFrame: - return true - default: - return false - } - case protocol.Encryption0RTT: - switch f.(type) { - case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: - return false - default: - return true - } - case protocol.Encryption1RTT: - return true - default: - panic("unknown encryption level") - } -} - -func (p *frameParser) SetAckDelayExponent(exp uint8) { - p.ackDelayExponent = exp -} diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go deleted file mode 100644 index ff5c83b8..00000000 --- a/internal/wire/frame_parser_test.go +++ /dev/null @@ -1,410 +0,0 @@ -package wire - -import ( - "bytes" - "time" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Frame parsing", func() { - var ( - buf *bytes.Buffer - parser FrameParser - ) - - BeforeEach(func() { - buf = &bytes.Buffer{} - parser = NewFrameParser(true, protocol.Version1) - }) - - It("returns nil if there's nothing more to read", func() { - f, err := parser.ParseNext(bytes.NewReader(nil), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeNil()) - }) - - It("skips PADDING frames", func() { - buf.Write([]byte{0}) // PADDING frame - (&PingFrame{}).Write(buf, protocol.Version1) - f, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(Equal(&PingFrame{})) - }) - - It("handles PADDING at the end", func() { - r := bytes.NewReader([]byte{0, 0, 0}) - f, err := parser.ParseNext(r, protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeNil()) - Expect(r.Len()).To(BeZero()) - }) - - It("unpacks ACK frames", func() { - f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(BeAssignableToTypeOf(f)) - Expect(frame.(*AckFrame).LargestAcked()).To(Equal(protocol.PacketNumber(0x13))) - }) - - It("uses the custom ack delay exponent for 1RTT packets", func() { - parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, - DelayTime: time.Second, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - // The ACK frame is always written using the protocol.AckDelayExponent. - // That's why we expect a different value when parsing. - Expect(frame.(*AckFrame).DelayTime).To(Equal(4 * time.Second)) - }) - - It("uses the default ack delay exponent for non-1RTT packets", func() { - parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, - DelayTime: time.Second, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.(*AckFrame).DelayTime).To(Equal(time.Second)) - }) - - It("unpacks RESET_STREAM frames", func() { - f := &ResetStreamFrame{ - StreamID: 0xdeadbeef, - FinalSize: 0xdecafbad1234, - ErrorCode: 0x1337, - } - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks STOP_SENDING frames", func() { - f := &StopSendingFrame{StreamID: 0x42} - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks CRYPTO frames", func() { - f := &CryptoFrame{ - Offset: 0x1337, - Data: []byte("lorem ipsum"), - } - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks NEW_TOKEN frames", func() { - f := &NewTokenFrame{Token: []byte("foobar")} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks STREAM frames", func() { - f := &StreamFrame{ - StreamID: 0x42, - Offset: 0x1337, - Fin: true, - Data: []byte("foobar"), - } - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks MAX_DATA frames", func() { - f := &MaxDataFrame{ - MaximumData: 0xcafe, - } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks MAX_STREAM_DATA frames", func() { - f := &MaxStreamDataFrame{ - StreamID: 0xdeadbeef, - MaximumStreamData: 0xdecafbad, - } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks MAX_STREAMS frames", func() { - f := &MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: 0x1337, - } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks DATA_BLOCKED frames", func() { - f := &DataBlockedFrame{MaximumData: 0x1234} - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks STREAM_DATA_BLOCKED frames", func() { - f := &StreamDataBlockedFrame{ - StreamID: 0xdeadbeef, - MaximumStreamData: 0xdead, - } - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks STREAMS_BLOCKED frames", func() { - f := &StreamsBlockedFrame{ - Type: protocol.StreamTypeBidi, - StreamLimit: 0x1234567, - } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks NEW_CONNECTION_ID frames", func() { - f := &NewConnectionIDFrame{ - SequenceNumber: 0x1337, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - } - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks RETIRE_CONNECTION_ID frames", func() { - f := &RetireConnectionIDFrame{SequenceNumber: 0x1337} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks PATH_CHALLENGE frames", func() { - f := &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(BeAssignableToTypeOf(f)) - Expect(frame.(*PathChallengeFrame).Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) - }) - - It("unpacks PATH_RESPONSE frames", func() { - f := &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(BeAssignableToTypeOf(f)) - Expect(frame.(*PathResponseFrame).Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) - }) - - It("unpacks CONNECTION_CLOSE frames", func() { - f := &ConnectionCloseFrame{ - IsApplicationError: true, - ReasonPhrase: "foobar", - } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks HANDSHAKE_DONE frames", func() { - f := &HandshakeDoneFrame{} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks DATAGRAM frames", func() { - f := &DatagramFrame{Data: []byte("foobar")} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("errors when DATAGRAM frames are not supported", func() { - parser = NewFrameParser(false, protocol.Version1) - f := &DatagramFrame{Data: []byte("foobar")} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - _, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.FrameEncodingError, - FrameType: 0x30, - ErrorMessage: "unknown frame type", - })) - }) - - It("errors on invalid type", func() { - _, err := parser.ParseNext(bytes.NewReader([]byte{0x42}), protocol.Encryption1RTT) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.FrameEncodingError, - FrameType: 0x42, - ErrorMessage: "unknown frame type", - })) - }) - - It("errors on invalid frames", func() { - f := &MaxStreamDataFrame{ - StreamID: 0x1337, - MaximumStreamData: 0xdeadbeef, - } - b := &bytes.Buffer{} - f.Write(b, protocol.Version1) - _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2]), protocol.Encryption1RTT) - Expect(err).To(HaveOccurred()) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) - }) - - Context("encryption level check", func() { - frames := []Frame{ - &PingFrame{}, - &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 42}}}, - &ResetStreamFrame{}, - &StopSendingFrame{}, - &CryptoFrame{}, - &NewTokenFrame{Token: []byte("lorem ipsum")}, - &StreamFrame{Data: []byte("foobar")}, - &MaxDataFrame{}, - &MaxStreamDataFrame{}, - &MaxStreamsFrame{}, - &DataBlockedFrame{}, - &StreamDataBlockedFrame{}, - &StreamsBlockedFrame{}, - &NewConnectionIDFrame{ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}}, - &RetireConnectionIDFrame{}, - &PathChallengeFrame{}, - &PathResponseFrame{}, - &ConnectionCloseFrame{}, - &HandshakeDoneFrame{}, - &DatagramFrame{}, - } - - var framesSerialized [][]byte - - BeforeEach(func() { - framesSerialized = nil - for _, frame := range frames { - buf := &bytes.Buffer{} - Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) - framesSerialized = append(framesSerialized, buf.Bytes()) - } - }) - - It("rejects all frames but ACK, CRYPTO, PING and CONNECTION_CLOSE in Initial packets", func() { - for i, b := range framesSerialized { - _, err := parser.ParseNext(bytes.NewReader(b), protocol.EncryptionInitial) - switch frames[i].(type) { - case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *PingFrame: - Expect(err).ToNot(HaveOccurred()) - default: - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) - Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level Initial")) - } - } - }) - - It("rejects all frames but ACK, CRYPTO, PING and CONNECTION_CLOSE in Handshake packets", func() { - for i, b := range framesSerialized { - _, err := parser.ParseNext(bytes.NewReader(b), protocol.EncryptionHandshake) - switch frames[i].(type) { - case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *PingFrame: - Expect(err).ToNot(HaveOccurred()) - default: - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) - Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level Handshake")) - } - } - }) - - It("rejects all frames but ACK, CRYPTO, CONNECTION_CLOSE, NEW_TOKEN, PATH_RESPONSE and RETIRE_CONNECTION_ID in 0-RTT packets", func() { - for i, b := range framesSerialized { - _, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption0RTT) - switch frames[i].(type) { - case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) - Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level 0-RTT")) - default: - Expect(err).ToNot(HaveOccurred()) - } - } - }) - - It("accepts all frame types in 1-RTT packets", func() { - for _, b := range framesSerialized { - _, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - } - }) - }) -}) diff --git a/internal/wire/handshake_done_frame.go b/internal/wire/handshake_done_frame.go deleted file mode 100644 index fad59371..00000000 --- a/internal/wire/handshake_done_frame.go +++ /dev/null @@ -1,29 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" -) - -// A HandshakeDoneFrame is a HANDSHAKE_DONE frame -type HandshakeDoneFrame struct{} - -// ParseHandshakeDoneFrame parses a HandshakeDone frame -func parseHandshakeDoneFrame(r *bytes.Reader, _ quic.VersionNumber) (*HandshakeDoneFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - return &HandshakeDoneFrame{}, nil -} - -func (f *HandshakeDoneFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - b.WriteByte(0x1e) - return nil -} - -// Length of a written frame -func (f *HandshakeDoneFrame) Length(_ quic.VersionNumber) protocol.ByteCount { - return 1 -} diff --git a/internal/wire/header.go b/internal/wire/header.go deleted file mode 100644 index 2cce3fac..00000000 --- a/internal/wire/header.go +++ /dev/null @@ -1,275 +0,0 @@ -package wire - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "github.com/lucas-clemente/quic-go" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/utils" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// ParseConnectionID parses the destination connection ID of a packet. -// It uses the data slice for the connection ID. -// That means that the connection ID must not be used after the packet buffer is released. -func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { - if len(data) == 0 { - return nil, io.EOF - } - isLongHeader := data[0]&0x80 > 0 - if !isLongHeader { - if len(data) < shortHeaderConnIDLen+1 { - return nil, io.EOF - } - return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil - } - if len(data) < 6 { - return nil, io.EOF - } - destConnIDLen := int(data[5]) - if len(data) < 6+destConnIDLen { - return nil, io.EOF - } - return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil -} - -// IsVersionNegotiationPacket says if this is a version negotiation packet -func IsVersionNegotiationPacket(b []byte) bool { - if len(b) < 5 { - return false - } - return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 -} - -// Is0RTTPacket says if this is a 0-RTT packet. -// A packet sent with a version we don't understand can never be a 0-RTT packet. -func Is0RTTPacket(b []byte) bool { - if len(b) < 5 { - return false - } - if b[0]&0x80 == 0 { - return false - } - version := quic.VersionNumber(binary.BigEndian.Uint32(b[1:5])) - if !protocol.IsSupportedVersion(protocol.SupportedVersions, version) { - return false - } - if version == protocol.Version2 { - return b[0]>>4&0b11 == 0b10 - } - return b[0]>>4&0b11 == 0b01 -} - -var ErrUnsupportedVersion = errors.New("unsupported version") - -// The Header is the version independent part of the header -type Header struct { - IsLongHeader bool - typeByte byte - Type protocol.PacketType - - Version quic.VersionNumber - SrcConnectionID protocol.ConnectionID - DestConnectionID protocol.ConnectionID - - Length protocol.ByteCount - - Token []byte - - parsedLen protocol.ByteCount // how many bytes were read while parsing this header -} - -// ParsePacket parses a packet. -// If the packet has a long header, the packet is cut according to the length field. -// If we understand the version, the packet is header up unto the packet number. -// Otherwise, only the invariant part of the header is parsed. -func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) { - hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen) - if err != nil { - if err == ErrUnsupportedVersion { - return hdr, nil, nil, ErrUnsupportedVersion - } - return nil, nil, nil, err - } - var rest []byte - if hdr.IsLongHeader { - if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length { - return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) - } - packetLen := int(hdr.ParsedLen() + hdr.Length) - rest = data[packetLen:] - data = data[:packetLen] - } - return hdr, data, rest, nil -} - -// ParseHeader parses the header. -// For short header packets: up to the packet number. -// For long header packets: -// * if we understand the version: up to the packet number -// * if not, only the invariant part of the header -func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { - startLen := b.Len() - h, err := parseHeaderImpl(b, shortHeaderConnIDLen) - if err != nil { - return h, err - } - h.parsedLen = protocol.ByteCount(startLen - b.Len()) - return h, err -} - -func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { - typeByte, err := b.ReadByte() - if err != nil { - return nil, err - } - - h := &Header{ - typeByte: typeByte, - IsLongHeader: typeByte&0x80 > 0, - } - - if !h.IsLongHeader { - if h.typeByte&0x40 == 0 { - return nil, errors.New("not a QUIC packet") - } - if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil { - return nil, err - } - return h, nil - } - return h, h.parseLongHeader(b) -} - -func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error { - var err error - h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen) - return err -} - -func (h *Header) parseLongHeader(b *bytes.Reader) error { - v, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return err - } - h.Version = quic.VersionNumber(v) - if h.Version != 0 && h.typeByte&0x40 == 0 { - return errors.New("not a QUIC packet") - } - destConnIDLen, err := b.ReadByte() - if err != nil { - return err - } - h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen)) - if err != nil { - return err - } - srcConnIDLen, err := b.ReadByte() - if err != nil { - return err - } - h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen)) - if err != nil { - return err - } - if h.Version == 0 { // version negotiation packet - return nil - } - // If we don't understand the version, we have no idea how to interpret the rest of the bytes - if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) { - return ErrUnsupportedVersion - } - - if h.Version == protocol.Version2 { - switch h.typeByte >> 4 & 0b11 { - case 0b00: - h.Type = protocol.PacketTypeRetry - case 0b01: - h.Type = protocol.PacketTypeInitial - case 0b10: - h.Type = protocol.PacketType0RTT - case 0b11: - h.Type = protocol.PacketTypeHandshake - } - } else { - switch h.typeByte >> 4 & 0b11 { - case 0b00: - h.Type = protocol.PacketTypeInitial - case 0b01: - h.Type = protocol.PacketType0RTT - case 0b10: - h.Type = protocol.PacketTypeHandshake - case 0b11: - h.Type = protocol.PacketTypeRetry - } - } - - if h.Type == protocol.PacketTypeRetry { - tokenLen := b.Len() - 16 - if tokenLen <= 0 { - return io.EOF - } - h.Token = make([]byte, tokenLen) - if _, err := io.ReadFull(b, h.Token); err != nil { - return err - } - _, err := b.Seek(16, io.SeekCurrent) - return err - } - - if h.Type == protocol.PacketTypeInitial { - tokenLen, err := quicvarint.Read(b) - if err != nil { - return err - } - if tokenLen > uint64(b.Len()) { - return io.EOF - } - h.Token = make([]byte, tokenLen) - if _, err := io.ReadFull(b, h.Token); err != nil { - return err - } - } - - pl, err := quicvarint.Read(b) - if err != nil { - return err - } - h.Length = protocol.ByteCount(pl) - return nil -} - -// ParsedLen returns the number of bytes that were consumed when parsing the header -func (h *Header) ParsedLen() protocol.ByteCount { - return h.parsedLen -} - -// ParseExtended parses the version dependent part of the header. -// The Reader has to be set such that it points to the first byte of the header. -func (h *Header) ParseExtended(b *bytes.Reader, ver quic.VersionNumber) (*ExtendedHeader, error) { - extHdr := h.toExtendedHeader() - reservedBitsValid, err := extHdr.parse(b, ver) - if err != nil { - return nil, err - } - if !reservedBitsValid { - return extHdr, ErrInvalidReservedBits - } - return extHdr, nil -} - -func (h *Header) toExtendedHeader() *ExtendedHeader { - return &ExtendedHeader{Header: *h} -} - -// PacketType is the type of the packet, for logging purposes -func (h *Header) PacketType() string { - if h.IsLongHeader { - return h.Type.String() - } - return "1-RTT" -} diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go deleted file mode 100644 index 8025571f..00000000 --- a/internal/wire/header_test.go +++ /dev/null @@ -1,584 +0,0 @@ -package wire - -import ( - "bytes" - "encoding/binary" - "github.com/lucas-clemente/quic-go" - "io" - - "github.com/imroc/req/v3/internal/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Header Parsing", func() { - Context("Parsing the Connection ID", func() { - It("parses the connection ID of a long header packet", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, - Version: protocol.Version1, - }, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - connID, err := ParseConnectionID(buf.Bytes(), 8) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) - }) - - It("parses the connection ID of a short header packet", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - }, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - buf.Write([]byte("foobar")) - connID, err := ParseConnectionID(buf.Bytes(), 4) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) - }) - - It("errors on EOF, for short header packets", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - }, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - data := buf.Bytes()[:buf.Len()-2] // cut the packet number - _, err := ParseConnectionID(data, 8) - Expect(err).ToNot(HaveOccurred()) - for i := 0; i < len(data); i++ { - b := make([]byte, i) - copy(b, data[:i]) - _, err := ParseConnectionID(b, 8) - Expect(err).To(MatchError(io.EOF)) - } - }) - - It("errors on EOF, for long header packets", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 8, 9}, - Version: protocol.Version1, - }, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - data := buf.Bytes()[:buf.Len()-2] // cut the packet number - _, err := ParseConnectionID(data, 8) - Expect(err).ToNot(HaveOccurred()) - for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ { - b := make([]byte, i) - copy(b, data[:i]) - _, err := ParseConnectionID(b, 8) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("identifying 0-RTT packets", func() { - It("recognizes 0-RTT packets, for QUIC v1", func() { - zeroRTTHeader := make([]byte, 5) - zeroRTTHeader[0] = 0x80 | 0b01<<4 - binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version1)) - - Expect(Is0RTTPacket(zeroRTTHeader)).To(BeTrue()) - Expect(Is0RTTPacket(zeroRTTHeader[:4])).To(BeFalse()) // too short - Expect(Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})).To(BeFalse()) // unknown version - Expect(Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})).To(BeFalse()) // short header - Expect(Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))).To(BeTrue()) - }) - - It("recognizes 0-RTT packets, for QUIC v2", func() { - zeroRTTHeader := make([]byte, 5) - zeroRTTHeader[0] = 0x80 | 0b10<<4 - binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version2)) - - Expect(Is0RTTPacket(zeroRTTHeader)).To(BeTrue()) - Expect(Is0RTTPacket(zeroRTTHeader[:4])).To(BeFalse()) // too short - Expect(Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})).To(BeFalse()) // unknown version - Expect(Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})).To(BeFalse()) // short header - Expect(Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))).To(BeTrue()) - }) - }) - - Context("Identifying Version Negotiation Packets", func() { - It("identifies version negotiation packets", func() { - Expect(IsVersionNegotiationPacket([]byte{0x80 | 0x56, 0, 0, 0, 0})).To(BeTrue()) - Expect(IsVersionNegotiationPacket([]byte{0x56, 0, 0, 0, 0})).To(BeFalse()) - Expect(IsVersionNegotiationPacket([]byte{0x80, 1, 0, 0, 0})).To(BeFalse()) - Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 1, 0, 0})).To(BeFalse()) - Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 1, 0})).To(BeFalse()) - Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 0, 1})).To(BeFalse()) - }) - - It("returns false on EOF", func() { - vnp := []byte{0x80, 0, 0, 0, 0} - for i := range vnp { - Expect(IsVersionNegotiationPacket(vnp[:i])).To(BeFalse()) - } - }) - }) - - Context("Long Headers", func() { - It("parses a Long Header", func() { - destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} - srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - data := []byte{0xc0 ^ 0x3} - data = appendVersion(data, protocol.Version1) - data = append(data, 0x9) // dest conn id length - data = append(data, destConnID...) - data = append(data, 0x4) // src conn id length - data = append(data, srcConnID...) - data = append(data, encodeVarInt(6)...) // token length - data = append(data, []byte("foobar")...) // token - data = append(data, encodeVarInt(10)...) // length - hdrLen := len(data) - data = append(data, []byte{0, 0, 0xbe, 0xef}...) // packet number - data = append(data, []byte("foobar")...) - Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) - - hdr, pdata, rest, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(pdata).To(Equal(data)) - Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.DestConnectionID).To(Equal(destConnID)) - Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) - Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(hdr.Token).To(Equal([]byte("foobar"))) - Expect(hdr.Length).To(Equal(protocol.ByteCount(10))) - Expect(hdr.Version).To(Equal(protocol.Version1)) - Expect(rest).To(BeEmpty()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef))) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) - Expect(b.Len()).To(Equal(6)) // foobar - Expect(hdr.ParsedLen()).To(BeEquivalentTo(hdrLen)) - Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 4)) - }) - - It("errors if 0x40 is not set", func() { - data := []byte{ - 0x80 | 0x2<<4, - 0x11, // connection ID lengths - 0xde, 0xca, 0xfb, 0xad, // dest conn ID - 0xde, 0xad, 0xbe, 0xef, // src conn ID - } - _, _, _, err := ParsePacket(data, 0) - Expect(err).To(MatchError("not a QUIC packet")) - }) - - It("stops parsing when encountering an unsupported version", func() { - data := []byte{ - 0xc0, - 0xde, 0xad, 0xbe, 0xef, - 0x8, // dest conn ID len - 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, // dest conn ID - 0x8, // src conn ID len - 0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, // src conn ID - 'f', 'o', 'o', 'b', 'a', 'r', // unspecified bytes - } - hdr, _, rest, err := ParsePacket(data, 0) - Expect(err).To(MatchError(ErrUnsupportedVersion)) - Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.Version).To(Equal(quic.VersionNumber(0xdeadbeef))) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1})) - Expect(rest).To(BeEmpty()) - }) - - It("parses a Long Header without a destination connection ID", func() { - data := []byte{0xc0 ^ 0x1<<4} - data = appendVersion(data, protocol.Version1) - data = append(data, 0x0) // dest conn ID len - data = append(data, 0x4) // src conn ID len - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID - data = append(data, encodeVarInt(0)...) // length - data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.Type).To(Equal(protocol.PacketType0RTT)) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(hdr.DestConnectionID).To(BeEmpty()) - }) - - It("parses a Long Header without a source connection ID", func() { - data := []byte{0xc0 ^ 0x2<<4} - data = appendVersion(data, protocol.Version1) - data = append(data, 0xa) // dest conn ID len - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // dest connection ID - data = append(data, 0x0) // src conn ID len - data = append(data, encodeVarInt(0)...) // length - data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.SrcConnectionID).To(BeEmpty()) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) - }) - - It("parses a Long Header with a 2 byte packet number", func() { - data := []byte{0xc0 ^ 0x1} - data = appendVersion(data, protocol.Version1) // version number - data = append(data, []byte{0x0, 0x0}...) // connection ID lengths - data = append(data, encodeVarInt(0)...) // token length - data = append(data, encodeVarInt(0)...) // length - data = append(data, []byte{0x1, 0x23}...) - - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x123))) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) - Expect(b.Len()).To(BeZero()) - }) - - It("parses a Retry packet, for QUIC v1", func() { - data := []byte{0xc0 | 0b11<<4 | (10 - 3) /* connection ID length */} - data = appendVersion(data, protocol.Version1) - data = append(data, []byte{6}...) // dest conn ID len - data = append(data, []byte{6, 5, 4, 3, 2, 1}...) // dest conn ID - data = append(data, []byte{10}...) // src conn ID len - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID - data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token - data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...) - hdr, pdata, rest, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(hdr.Version).To(Equal(protocol.Version1)) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) - Expect(hdr.Token).To(Equal([]byte("foobar"))) - Expect(pdata).To(Equal(data)) - Expect(rest).To(BeEmpty()) - }) - - It("parses a Retry packet, for QUIC v2", func() { - data := []byte{0xc0 | 0b00<<4 | (10 - 3) /* connection ID length */} - data = appendVersion(data, protocol.Version2) - data = append(data, []byte{6}...) // dest conn ID len - data = append(data, []byte{6, 5, 4, 3, 2, 1}...) // dest conn ID - data = append(data, []byte{10}...) // src conn ID len - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID - data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token - data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...) - hdr, pdata, rest, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(hdr.Version).To(Equal(protocol.Version2)) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) - Expect(hdr.Token).To(Equal([]byte("foobar"))) - Expect(pdata).To(Equal(data)) - Expect(rest).To(BeEmpty()) - }) - - It("errors if the Retry packet is too short for the integrity tag", func() { - data := []byte{0xc0 | 0x3<<4 | (10 - 3) /* connection ID length */} - data = appendVersion(data, protocol.Version1) - data = append(data, []byte{0, 0}...) // conn ID lens - data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) - // this results in a token length of 0 - _, _, _, err := ParsePacket(data, 0) - Expect(err).To(MatchError(io.EOF)) - }) - - It("errors if the token length is too large", func() { - data := []byte{0xc0 ^ 0x1} - data = appendVersion(data, protocol.Version1) - data = append(data, 0x0) // connection ID lengths - data = append(data, encodeVarInt(4)...) // token length: 4 bytes (1 byte too long) - data = append(data, encodeVarInt(0x42)...) // length, 1 byte - data = append(data, []byte{0x12, 0x34}...) // packet number - - _, _, _, err := ParsePacket(data, 0) - Expect(err).To(MatchError(io.EOF)) - }) - - It("errors if the 5th or 6th bit are set", func() { - data := []byte{0xc0 | 0x2<<4 | 0x8 /* set the 5th bit */ | 0x1 /* 2 byte packet number */} - data = appendVersion(data, protocol.Version1) - data = append(data, []byte{0x0, 0x0}...) // connection ID lengths - data = append(data, encodeVarInt(2)...) // length - data = append(data, []byte{0x12, 0x34}...) // packet number - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) - extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) - Expect(err).To(MatchError(ErrInvalidReservedBits)) - Expect(extHdr).ToNot(BeNil()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1234))) - }) - - It("errors on EOF, when parsing the header", func() { - data := []byte{0xc0 ^ 0x2<<4} - data = appendVersion(data, protocol.Version1) - data = append(data, 0x8) // dest conn ID len - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // dest conn ID - data = append(data, 0x8) // src conn ID len - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // src conn ID - for i := 0; i < len(data); i++ { - _, _, _, err := ParsePacket(data[:i], 0) - Expect(err).To(Equal(io.EOF)) - } - }) - - It("errors on EOF, when parsing the extended header", func() { - data := []byte{0xc0 | 0x2<<4 | 0x3} - data = appendVersion(data, protocol.Version1) - data = append(data, []byte{0x0, 0x0}...) // connection ID lengths - data = append(data, encodeVarInt(0)...) // length - hdrLen := len(data) - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number - for i := hdrLen; i < len(data); i++ { - data = data[:i] - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - _, err = hdr.ParseExtended(b, protocol.Version1) - Expect(err).To(Equal(io.EOF)) - } - }) - - It("errors on EOF, for a Retry packet", func() { - data := []byte{0xc0 ^ 0x3<<4} - data = appendVersion(data, protocol.Version1) - data = append(data, []byte{0x0, 0x0}...) // connection ID lengths - data = append(data, 0xa) // Orig Destination Connection ID length - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID - hdrLen := len(data) - for i := hdrLen; i < len(data); i++ { - data = data[:i] - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - _, err = hdr.ParseExtended(b, protocol.Version1) - Expect(err).To(Equal(io.EOF)) - } - }) - - Context("coalesced packets", func() { - It("cuts packets", func() { - buf := &bytes.Buffer{} - hdr := Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Length: 2 + 6, - Version: protocol.Version1, - } - Expect((&ExtendedHeader{ - Header: hdr, - PacketNumber: 0x1337, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - hdrRaw := append([]byte{}, buf.Bytes()...) - buf.Write([]byte("foobar")) // payload of the first packet - buf.Write([]byte("raboof")) // second packet - parsedHdr, data, rest, err := ParsePacket(buf.Bytes(), 4) - Expect(err).ToNot(HaveOccurred()) - Expect(parsedHdr.Type).To(Equal(hdr.Type)) - Expect(parsedHdr.DestConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(data).To(Equal(append(hdrRaw, []byte("foobar")...))) - Expect(rest).To(Equal([]byte("raboof"))) - }) - - It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Length: 3, - Version: protocol.Version1, - }, - PacketNumber: 0x1337, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - _, _, _, err := ParsePacket(buf.Bytes(), 4) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)")) - }) - - It("errors on packets that are smaller than the length in the packet header, for too small payload", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Length: 1000, - Version: protocol.Version1, - }, - PacketNumber: 0x1337, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - buf.Write(make([]byte, 500-2 /* for packet number length */)) - _, _, _, err := ParsePacket(buf.Bytes(), 4) - Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) - }) - }) - }) - - Context("Short Headers", func() { - It("reads a Short Header with a 8 byte connection ID", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - data := append([]byte{0x40}, connID...) - data = append(data, 0x42) // packet number - Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) - - hdr, pdata, rest, err := ParsePacket(data, 8) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsLongHeader).To(BeFalse()) - Expect(hdr.DestConnectionID).To(Equal(connID)) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) - Expect(extHdr.DestConnectionID).To(Equal(connID)) - Expect(extHdr.SrcConnectionID).To(BeEmpty()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) - Expect(hdr.ParsedLen()).To(BeEquivalentTo(len(data) - 1)) - Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 1)) - Expect(pdata).To(Equal(data)) - Expect(rest).To(BeEmpty()) - }) - - It("errors if 0x40 is not set", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - data := append([]byte{0x0}, connID...) - _, _, _, err := ParsePacket(data, 8) - Expect(err).To(MatchError("not a QUIC packet")) - }) - - It("errors if the 4th or 5th bit are set", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5} - data := append([]byte{0x40 | 0x10 /* set the 4th bit */}, connID...) - data = append(data, 0x42) // packet number - hdr, _, _, err := ParsePacket(data, 5) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsLongHeader).To(BeFalse()) - extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) - Expect(err).To(MatchError(ErrInvalidReservedBits)) - Expect(extHdr).ToNot(BeNil()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) - }) - - It("reads a Short Header with a 5 byte connection ID", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5} - data := append([]byte{0x40}, connID...) - data = append(data, 0x42) // packet number - hdr, pdata, rest, err := ParsePacket(data, 5) - Expect(err).ToNot(HaveOccurred()) - Expect(pdata).To(HaveLen(len(data))) - Expect(hdr.IsLongHeader).To(BeFalse()) - Expect(hdr.DestConnectionID).To(Equal(connID)) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) - Expect(extHdr.DestConnectionID).To(Equal(connID)) - Expect(extHdr.SrcConnectionID).To(BeEmpty()) - Expect(rest).To(BeEmpty()) - }) - - It("reads the Key Phase Bit", func() { - data := []byte{ - 0x40 ^ 0x4, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID - } - data = append(data, 11) // packet number - hdr, _, _, err := ParsePacket(data, 6) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsLongHeader).To(BeFalse()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseOne)) - Expect(b.Len()).To(BeZero()) - }) - - It("reads a header with a 2 byte packet number", func() { - data := []byte{ - 0x40 | 0x1, - 0xde, 0xad, 0xbe, 0xef, // connection ID - } - data = append(data, []byte{0x13, 0x37}...) // packet number - hdr, _, _, err := ParsePacket(data, 4) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.IsLongHeader).To(BeFalse()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) - Expect(b.Len()).To(BeZero()) - }) - - It("reads a header with a 3 byte packet number", func() { - data := []byte{ - 0x40 | 0x2, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x1, 0x2, 0x3, 0x4, // connection ID - } - data = append(data, []byte{0x99, 0xbe, 0xef}...) // packet number - hdr, _, _, err := ParsePacket(data, 10) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.IsLongHeader).To(BeFalse()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef))) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen3)) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOF, when parsing the header", func() { - data := []byte{ - 0x40 ^ 0x2, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID - } - for i := 0; i < len(data); i++ { - data = data[:i] - _, _, _, err := ParsePacket(data, 8) - Expect(err).To(Equal(io.EOF)) - } - }) - - It("errors on EOF, when parsing the extended header", func() { - data := []byte{ - 0x40 ^ 0x3, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID - } - hdrLen := len(data) - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number - for i := hdrLen; i < len(data); i++ { - data = data[:i] - hdr, _, _, err := ParsePacket(data, 6) - Expect(err).ToNot(HaveOccurred()) - _, err = hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) - Expect(err).To(Equal(io.EOF)) - } - }) - }) - - It("tells its packet type for logging", func() { - Expect((&Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}).PacketType()).To(Equal("Handshake")) - Expect((&Header{}).PacketType()).To(Equal("1-RTT")) - }) -}) diff --git a/internal/wire/interface.go b/internal/wire/interface.go deleted file mode 100644 index b096a6e1..00000000 --- a/internal/wire/interface.go +++ /dev/null @@ -1,20 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" -) - -// A Frame in QUIC -type Frame interface { - Write(b *bytes.Buffer, version quic.VersionNumber) error - Length(version quic.VersionNumber) protocol.ByteCount -} - -// A FrameParser parses QUIC frames, one by one. -type FrameParser interface { - ParseNext(*bytes.Reader, protocol.EncryptionLevel) (Frame, error) - SetAckDelayExponent(uint8) -} diff --git a/internal/wire/log.go b/internal/wire/log.go deleted file mode 100644 index 276465ee..00000000 --- a/internal/wire/log.go +++ /dev/null @@ -1,72 +0,0 @@ -package wire - -import ( - "fmt" - "strings" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/utils" -) - -// LogFrame logs a frame, either sent or received -func LogFrame(logger utils.Logger, frame Frame, sent bool) { - if !logger.Debug() { - return - } - dir := "<-" - if sent { - dir = "->" - } - switch f := frame.(type) { - case *CryptoFrame: - dataLen := protocol.ByteCount(len(f.Data)) - logger.Debugf("\t%s &wire.CryptoFrame{Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.Offset, dataLen, f.Offset+dataLen) - case *StreamFrame: - logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, Fin: %t, Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.StreamID, f.Fin, f.Offset, f.DataLen(), f.Offset+f.DataLen()) - case *ResetStreamFrame: - logger.Debugf("\t%s &wire.ResetStreamFrame{StreamID: %d, ErrorCode: %#x, FinalSize: %d}", dir, f.StreamID, f.ErrorCode, f.FinalSize) - case *AckFrame: - hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 - var ecn string - if hasECN { - ecn = fmt.Sprintf(", ECT0: %d, ECT1: %d, CE: %d", f.ECT0, f.ECT1, f.ECNCE) - } - if len(f.AckRanges) > 1 { - ackRanges := make([]string, len(f.AckRanges)) - for i, r := range f.AckRanges { - ackRanges[i] = fmt.Sprintf("{Largest: %d, Smallest: %d}", r.Largest, r.Smallest) - } - logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, AckRanges: {%s}, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), strings.Join(ackRanges, ", "), f.DelayTime.String(), ecn) - } else { - logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), f.DelayTime.String(), ecn) - } - case *MaxDataFrame: - logger.Debugf("\t%s &wire.MaxDataFrame{MaximumData: %d}", dir, f.MaximumData) - case *MaxStreamDataFrame: - logger.Debugf("\t%s &wire.MaxStreamDataFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) - case *DataBlockedFrame: - logger.Debugf("\t%s &wire.DataBlockedFrame{MaximumData: %d}", dir, f.MaximumData) - case *StreamDataBlockedFrame: - logger.Debugf("\t%s &wire.StreamDataBlockedFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) - case *MaxStreamsFrame: - switch f.Type { - case protocol.StreamTypeUni: - logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: uni, MaxStreamNum: %d}", dir, f.MaxStreamNum) - case protocol.StreamTypeBidi: - logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: %d}", dir, f.MaxStreamNum) - } - case *StreamsBlockedFrame: - switch f.Type { - case protocol.StreamTypeUni: - logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: uni, MaxStreams: %d}", dir, f.StreamLimit) - case protocol.StreamTypeBidi: - logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit) - } - case *NewConnectionIDFrame: - logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken) - case *NewTokenFrame: - logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token) - default: - logger.Debugf("\t%s %#v", dir, frame) - } -} diff --git a/internal/wire/log_test.go b/internal/wire/log_test.go deleted file mode 100644 index 7094fcc5..00000000 --- a/internal/wire/log_test.go +++ /dev/null @@ -1,168 +0,0 @@ -package wire - -import ( - "bytes" - "log" - "os" - "time" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/utils" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Frame logging", func() { - var ( - buf *bytes.Buffer - logger utils.Logger - ) - - BeforeEach(func() { - buf = &bytes.Buffer{} - logger = utils.DefaultLogger - logger.SetLogLevel(utils.LogLevelDebug) - log.SetOutput(buf) - }) - - AfterEach(func() { - log.SetOutput(os.Stdout) - }) - - It("doesn't log when debug is disabled", func() { - logger.SetLogLevel(utils.LogLevelInfo) - LogFrame(logger, &ResetStreamFrame{}, true) - Expect(buf.Len()).To(BeZero()) - }) - - It("logs sent frames", func() { - LogFrame(logger, &ResetStreamFrame{}, true) - Expect(buf.String()).To(ContainSubstring("\t-> &wire.ResetStreamFrame{StreamID: 0, ErrorCode: 0x0, FinalSize: 0}\n")) - }) - - It("logs received frames", func() { - LogFrame(logger, &ResetStreamFrame{}, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.ResetStreamFrame{StreamID: 0, ErrorCode: 0x0, FinalSize: 0}\n")) - }) - - It("logs CRYPTO frames", func() { - frame := &CryptoFrame{ - Offset: 42, - Data: make([]byte, 123), - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.CryptoFrame{Offset: 42, Data length: 123, Offset + Data length: 165}\n")) - }) - - It("logs STREAM frames", func() { - frame := &StreamFrame{ - StreamID: 42, - Offset: 1337, - Data: bytes.Repeat([]byte{'f'}, 100), - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamFrame{StreamID: 42, Fin: false, Offset: 1337, Data length: 100, Offset + Data length: 1437}\n")) - }) - - It("logs ACK frames without missing packets", func() { - frame := &AckFrame{ - AckRanges: []AckRange{{Smallest: 42, Largest: 1337}}, - DelayTime: 1 * time.Millisecond, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 1337, LowestAcked: 42, DelayTime: 1ms}\n")) - }) - - It("logs ACK frames with ECN", func() { - frame := &AckFrame{ - AckRanges: []AckRange{{Smallest: 42, Largest: 1337}}, - DelayTime: 1 * time.Millisecond, - ECT0: 5, - ECT1: 66, - ECNCE: 777, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 1337, LowestAcked: 42, DelayTime: 1ms, ECT0: 5, ECT1: 66, CE: 777}\n")) - }) - - It("logs ACK frames with missing packets", func() { - frame := &AckFrame{ - AckRanges: []AckRange{ - {Smallest: 5, Largest: 8}, - {Smallest: 2, Largest: 3}, - }, - DelayTime: 12 * time.Millisecond, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 8, LowestAcked: 2, AckRanges: {{Largest: 8, Smallest: 5}, {Largest: 3, Smallest: 2}}, DelayTime: 12ms}\n")) - }) - - It("logs MAX_STREAMS frames", func() { - frame := &MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: 42, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: 42}\n")) - }) - - It("logs MAX_DATA frames", func() { - frame := &MaxDataFrame{ - MaximumData: 42, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxDataFrame{MaximumData: 42}\n")) - }) - - It("logs MAX_STREAM_DATA frames", func() { - frame := &MaxStreamDataFrame{ - StreamID: 10, - MaximumStreamData: 42, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 42}\n")) - }) - - It("logs DATA_BLOCKED frames", func() { - frame := &DataBlockedFrame{ - MaximumData: 1000, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.DataBlockedFrame{MaximumData: 1000}\n")) - }) - - It("logs STREAM_DATA_BLOCKED frames", func() { - frame := &StreamDataBlockedFrame{ - StreamID: 42, - MaximumStreamData: 1000, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamDataBlockedFrame{StreamID: 42, MaximumStreamData: 1000}\n")) - }) - - It("logs STREAMS_BLOCKED frames", func() { - frame := &StreamsBlockedFrame{ - Type: protocol.StreamTypeBidi, - StreamLimit: 42, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: 42}\n")) - }) - - It("logs NEW_CONNECTION_ID frames", func() { - LogFrame(logger, &NewConnectionIDFrame{ - SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - StatelessResetToken: protocol.StatelessResetToken{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10}, - }, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.NewConnectionIDFrame{SequenceNumber: 42, ConnectionID: deadbeef, StatelessResetToken: 0x0102030405060708090a0b0c0d0e0f10}")) - }) - - It("logs NEW_TOKEN frames", func() { - LogFrame(logger, &NewTokenFrame{ - Token: []byte{0xde, 0xad, 0xbe, 0xef}, - }, true) - Expect(buf.String()).To(ContainSubstring("\t-> &wire.NewTokenFrame{Token: 0xdeadbeef")) - }) -}) diff --git a/internal/wire/max_data_frame.go b/internal/wire/max_data_frame.go deleted file mode 100644 index 31ec503d..00000000 --- a/internal/wire/max_data_frame.go +++ /dev/null @@ -1,41 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A MaxDataFrame carries flow control information for the connection -type MaxDataFrame struct { - MaximumData protocol.ByteCount -} - -// parseMaxDataFrame parses a MAX_DATA frame -func parseMaxDataFrame(r *bytes.Reader, _ quic.VersionNumber) (*MaxDataFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - frame := &MaxDataFrame{} - byteOffset, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - frame.MaximumData = protocol.ByteCount(byteOffset) - return frame, nil -} - -// Write writes a MAX_STREAM_DATA frame -func (f *MaxDataFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { - b.WriteByte(0x10) - quicvarint.Write(b, uint64(f.MaximumData)) - return nil -} - -// Length of a written frame -func (f *MaxDataFrame) Length(version quic.VersionNumber) protocol.ByteCount { - return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaximumData))) -} diff --git a/internal/wire/max_data_frame_test.go b/internal/wire/max_data_frame_test.go deleted file mode 100644 index c363ecd8..00000000 --- a/internal/wire/max_data_frame_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("MAX_DATA frame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - data := []byte{0x10} - data = append(data, encodeVarInt(0xdecafbad123456)...) // byte offset - b := bytes.NewReader(data) - frame, err := parseMaxDataFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0xdecafbad123456))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x10} - data = append(data, encodeVarInt(0xdecafbad1234567)...) // byte offset - _, err := parseMaxDataFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseMaxDataFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("writing", func() { - It("has proper min length", func() { - f := &MaxDataFrame{ - MaximumData: 0xdeadbeef, - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0xdeadbeef))) - }) - - It("writes a MAX_DATA frame", func() { - b := &bytes.Buffer{} - f := &MaxDataFrame{ - MaximumData: 0xdeadbeefcafe, - } - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x10} - expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - }) -}) diff --git a/internal/wire/max_stream_data_frame.go b/internal/wire/max_stream_data_frame.go deleted file mode 100644 index bc300279..00000000 --- a/internal/wire/max_stream_data_frame.go +++ /dev/null @@ -1,47 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A MaxStreamDataFrame is a MAX_STREAM_DATA frame -type MaxStreamDataFrame struct { - StreamID protocol.StreamID - MaximumStreamData protocol.ByteCount -} - -func parseMaxStreamDataFrame(r *bytes.Reader, _ quic.VersionNumber) (*MaxStreamDataFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - sid, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - offset, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - - return &MaxStreamDataFrame{ - StreamID: protocol.StreamID(sid), - MaximumStreamData: protocol.ByteCount(offset), - }, nil -} - -func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { - b.WriteByte(0x11) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.MaximumStreamData)) - return nil -} - -// Length of a written frame -func (f *MaxStreamDataFrame) Length(version quic.VersionNumber) protocol.ByteCount { - return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.MaximumStreamData))) -} diff --git a/internal/wire/max_stream_data_frame_test.go b/internal/wire/max_stream_data_frame_test.go deleted file mode 100644 index 4e205ad9..00000000 --- a/internal/wire/max_stream_data_frame_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("MAX_STREAM_DATA frame", func() { - Context("parsing", func() { - It("accepts sample frame", func() { - data := []byte{0x11} - data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID - data = append(data, encodeVarInt(0x12345678)...) // Offset - b := bytes.NewReader(data) - frame, err := parseMaxStreamDataFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) - Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0x12345678))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x11} - data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID - data = append(data, encodeVarInt(0x12345678)...) // Offset - _, err := parseMaxStreamDataFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseMaxStreamDataFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("writing", func() { - It("has proper min length", func() { - f := &MaxStreamDataFrame{ - StreamID: 0x1337, - MaximumStreamData: 0xdeadbeef, - } - Expect(f.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)))) - }) - - It("writes a sample frame", func() { - b := &bytes.Buffer{} - f := &MaxStreamDataFrame{ - StreamID: 0xdecafbad, - MaximumStreamData: 0xdeadbeefcafe42, - } - expected := []byte{0x11} - expected = append(expected, encodeVarInt(0xdecafbad)...) - expected = append(expected, encodeVarInt(0xdeadbeefcafe42)...) - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal(expected)) - }) - }) -}) diff --git a/internal/wire/max_streams_frame.go b/internal/wire/max_streams_frame.go deleted file mode 100644 index d25c2cad..00000000 --- a/internal/wire/max_streams_frame.go +++ /dev/null @@ -1,56 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A MaxStreamsFrame is a MAX_STREAMS frame -type MaxStreamsFrame struct { - Type protocol.StreamType - MaxStreamNum protocol.StreamNum -} - -func parseMaxStreamsFrame(r *bytes.Reader, _ quic.VersionNumber) (*MaxStreamsFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - f := &MaxStreamsFrame{} - switch typeByte { - case 0x12: - f.Type = protocol.StreamTypeBidi - case 0x13: - f.Type = protocol.StreamTypeUni - } - streamID, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - f.MaxStreamNum = protocol.StreamNum(streamID) - if f.MaxStreamNum > protocol.MaxStreamCount { - return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum) - } - return f, nil -} - -func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - switch f.Type { - case protocol.StreamTypeBidi: - b.WriteByte(0x12) - case protocol.StreamTypeUni: - b.WriteByte(0x13) - } - quicvarint.Write(b, uint64(f.MaxStreamNum)) - return nil -} - -// Length of a written frame -func (f *MaxStreamsFrame) Length(quic.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.MaxStreamNum)) -} diff --git a/internal/wire/max_streams_frame_test.go b/internal/wire/max_streams_frame_test.go deleted file mode 100644 index bc7ec913..00000000 --- a/internal/wire/max_streams_frame_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("MAX_STREAMS frame", func() { - Context("parsing", func() { - It("accepts a frame for a bidirectional stream", func() { - data := []byte{0x12} - data = append(data, encodeVarInt(0xdecaf)...) - b := bytes.NewReader(data) - f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Type).To(Equal(protocol.StreamTypeBidi)) - Expect(f.MaxStreamNum).To(BeEquivalentTo(0xdecaf)) - Expect(b.Len()).To(BeZero()) - }) - - It("accepts a frame for a bidirectional stream", func() { - data := []byte{0x13} - data = append(data, encodeVarInt(0xdecaf)...) - b := bytes.NewReader(data) - f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Type).To(Equal(protocol.StreamTypeUni)) - Expect(f.MaxStreamNum).To(BeEquivalentTo(0xdecaf)) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x1d} - data = append(data, encodeVarInt(0xdeadbeefcafe13)...) - _, err := parseMaxStreamsFrame(bytes.NewReader(data), protocol.VersionWhatever) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseMaxStreamsFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) - } - }) - - for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { - streamType := t - - It("accepts a frame containing the maximum stream count", func() { - f := &MaxStreamsFrame{ - Type: streamType, - MaxStreamNum: protocol.MaxStreamCount, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - frame, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("errors when receiving a too large stream count", func() { - f := &MaxStreamsFrame{ - Type: streamType, - MaxStreamNum: protocol.MaxStreamCount + 1, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - _, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) - Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) - }) - } - }) - - Context("writing", func() { - It("for a bidirectional stream", func() { - f := &MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: 0xdeadbeef, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - expected := []byte{0x12} - expected = append(expected, encodeVarInt(0xdeadbeef)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("for a unidirectional stream", func() { - f := &MaxStreamsFrame{ - Type: protocol.StreamTypeUni, - MaxStreamNum: 0xdecafbad, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - expected := []byte{0x13} - expected = append(expected, encodeVarInt(0xdecafbad)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - frame := MaxStreamsFrame{MaxStreamNum: 0x1337} - Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(0x1337))) - }) - }) -}) diff --git a/internal/wire/new_connection_id_frame.go b/internal/wire/new_connection_id_frame.go deleted file mode 100644 index a79603b9..00000000 --- a/internal/wire/new_connection_id_frame.go +++ /dev/null @@ -1,81 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - "github.com/lucas-clemente/quic-go" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A NewConnectionIDFrame is a NEW_CONNECTION_ID frame -type NewConnectionIDFrame struct { - SequenceNumber uint64 - RetirePriorTo uint64 - ConnectionID protocol.ConnectionID - StatelessResetToken protocol.StatelessResetToken -} - -func parseNewConnectionIDFrame(r *bytes.Reader, _ quic.VersionNumber) (*NewConnectionIDFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - seq, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - ret, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - if ret > seq { - //nolint:stylecheck - return nil, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq) - } - connIDLen, err := r.ReadByte() - if err != nil { - return nil, err - } - if connIDLen > protocol.MaxConnIDLen { - return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen) - } - connID, err := protocol.ReadConnectionID(r, int(connIDLen)) - if err != nil { - return nil, err - } - frame := &NewConnectionIDFrame{ - SequenceNumber: seq, - RetirePriorTo: ret, - ConnectionID: connID, - } - if _, err := io.ReadFull(r, frame.StatelessResetToken[:]); err != nil { - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return nil, err - } - - return frame, nil -} - -func (f *NewConnectionIDFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - b.WriteByte(0x18) - quicvarint.Write(b, f.SequenceNumber) - quicvarint.Write(b, f.RetirePriorTo) - connIDLen := f.ConnectionID.Len() - if connIDLen > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d", connIDLen) - } - b.WriteByte(uint8(connIDLen)) - b.Write(f.ConnectionID.Bytes()) - b.Write(f.StatelessResetToken[:]) - return nil -} - -// Length of a written frame -func (f *NewConnectionIDFrame) Length(quic.VersionNumber) protocol.ByteCount { - return 1 + protocol.ByteCount(quicvarint.Len(f.SequenceNumber)+quicvarint.Len(f.RetirePriorTo)) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16 -} diff --git a/internal/wire/new_connection_id_frame_test.go b/internal/wire/new_connection_id_frame_test.go deleted file mode 100644 index 91bc2e20..00000000 --- a/internal/wire/new_connection_id_frame_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("NEW_CONNECTION_ID frame", func() { - Context("when parsing", func() { - It("accepts a sample frame", func() { - data := []byte{0x18} - data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number - data = append(data, encodeVarInt(0xcafe)...) // retire prior to - data = append(data, 10) // connection ID length - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID - data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - b := bytes.NewReader(data) - frame, err := parseNewConnectionIDFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) - Expect(frame.RetirePriorTo).To(Equal(uint64(0xcafe))) - Expect(frame.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) - Expect(string(frame.StatelessResetToken[:])).To(Equal("deadbeefdecafbad")) - }) - - It("errors when the Retire Prior To value is larger than the Sequence Number", func() { - data := []byte{0x18} - data = append(data, encodeVarInt(1000)...) // sequence number - data = append(data, encodeVarInt(1001)...) // retire prior to - data = append(data, 3) - data = append(data, []byte{1, 2, 3}...) - data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - b := bytes.NewReader(data) - _, err := parseNewConnectionIDFrame(b, protocol.Version1) - Expect(err).To(MatchError("Retire Prior To value (1001) larger than Sequence Number (1000)")) - }) - - It("errors when the connection ID has an invalid length", func() { - data := []byte{0x18} - data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number - data = append(data, encodeVarInt(0xcafe)...) // retire prior to - data = append(data, 21) // connection ID length - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}...) // connection ID - data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - b := bytes.NewReader(data) - _, err := parseNewConnectionIDFrame(b, protocol.Version1) - Expect(err).To(MatchError("invalid connection ID length: 21")) - }) - - It("errors on EOFs", func() { - data := []byte{0x18} - data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number - data = append(data, encodeVarInt(0xcafe1234)...) // retire prior to - data = append(data, 10) // connection ID length - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID - data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - _, err := parseNewConnectionIDFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseNewConnectionIDFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - token := protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - frame := &NewConnectionIDFrame{ - SequenceNumber: 0x1337, - RetirePriorTo: 0x42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, - StatelessResetToken: token, - } - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - expected := []byte{0x18} - expected = append(expected, encodeVarInt(0x1337)...) - expected = append(expected, encodeVarInt(0x42)...) - expected = append(expected, 6) - expected = append(expected, []byte{1, 2, 3, 4, 5, 6}...) - expected = append(expected, token[:]...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct length", func() { - token := protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - frame := &NewConnectionIDFrame{ - SequenceNumber: 0xdecafbad, - RetirePriorTo: 0xdeadbeefcafe, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - StatelessResetToken: token, - } - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) - }) - }) -}) diff --git a/internal/wire/new_token_frame.go b/internal/wire/new_token_frame.go deleted file mode 100644 index a7ff519f..00000000 --- a/internal/wire/new_token_frame.go +++ /dev/null @@ -1,49 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "github.com/lucas-clemente/quic-go" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A NewTokenFrame is a NEW_TOKEN frame -type NewTokenFrame struct { - Token []byte -} - -func parseNewTokenFrame(r *bytes.Reader, _ quic.VersionNumber) (*NewTokenFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - tokenLen, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - if uint64(r.Len()) < tokenLen { - return nil, io.EOF - } - if tokenLen == 0 { - return nil, errors.New("token must not be empty") - } - token := make([]byte, int(tokenLen)) - if _, err := io.ReadFull(r, token); err != nil { - return nil, err - } - return &NewTokenFrame{Token: token}, nil -} - -func (f *NewTokenFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - b.WriteByte(0x7) - quicvarint.Write(b, uint64(len(f.Token))) - b.Write(f.Token) - return nil -} - -// Length of a written frame -func (f *NewTokenFrame) Length(quic.VersionNumber) protocol.ByteCount { - return 1 + protocol.ByteCount(quicvarint.Len(uint64(len(f.Token)))) + protocol.ByteCount(len(f.Token)) -} diff --git a/internal/wire/new_token_frame_test.go b/internal/wire/new_token_frame_test.go deleted file mode 100644 index c4a6685c..00000000 --- a/internal/wire/new_token_frame_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("NEW_TOKEN frame", func() { - Context("parsing", func() { - It("accepts a sample frame", func() { - token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." - data := []byte{0x7} - data = append(data, encodeVarInt(uint64(len(token)))...) - data = append(data, token...) - b := bytes.NewReader(data) - f, err := parseNewTokenFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(string(f.Token)).To(Equal(token)) - Expect(b.Len()).To(BeZero()) - }) - - It("rejects empty tokens", func() { - data := []byte{0x7} - data = append(data, encodeVarInt(uint64(0))...) - b := bytes.NewReader(data) - _, err := parseNewTokenFrame(b, protocol.VersionWhatever) - Expect(err).To(MatchError("token must not be empty")) - }) - - It("errors on EOFs", func() { - token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit" - data := []byte{0x7} - data = append(data, encodeVarInt(uint64(len(token)))...) - data = append(data, token...) - _, err := parseNewTokenFrame(bytes.NewReader(data), protocol.VersionWhatever) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseNewTokenFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("writing", func() { - It("writes a sample frame", func() { - token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat." - f := &NewTokenFrame{Token: []byte(token)} - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - expected := []byte{0x7} - expected = append(expected, encodeVarInt(uint64(len(token)))...) - expected = append(expected, token...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - frame := &NewTokenFrame{Token: []byte("foobar")} - Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(6) + 6)) - }) - }) -}) diff --git a/internal/wire/path_challenge_frame.go b/internal/wire/path_challenge_frame.go deleted file mode 100644 index ae519a7f..00000000 --- a/internal/wire/path_challenge_frame.go +++ /dev/null @@ -1,39 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - "io" - - "github.com/imroc/req/v3/internal/protocol" -) - -// A PathChallengeFrame is a PATH_CHALLENGE frame -type PathChallengeFrame struct { - Data [8]byte -} - -func parsePathChallengeFrame(r *bytes.Reader, _ quic.VersionNumber) (*PathChallengeFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - frame := &PathChallengeFrame{} - if _, err := io.ReadFull(r, frame.Data[:]); err != nil { - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return nil, err - } - return frame, nil -} - -func (f *PathChallengeFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - b.WriteByte(0x1a) - b.Write(f.Data[:]) - return nil -} - -// Length of a written frame -func (f *PathChallengeFrame) Length(_ quic.VersionNumber) protocol.ByteCount { - return 1 + 8 -} diff --git a/internal/wire/path_challenge_frame_test.go b/internal/wire/path_challenge_frame_test.go deleted file mode 100644 index 52d08d90..00000000 --- a/internal/wire/path_challenge_frame_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("PATH_CHALLENGE frame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x1a, 1, 2, 3, 4, 5, 6, 7, 8}) - f, err := parsePathChallengeFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeZero()) - Expect(f.Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) - }) - - It("errors on EOFs", func() { - data := []byte{0x1a, 1, 2, 3, 4, 5, 6, 7, 8} - _, err := parsePathChallengeFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parsePathChallengeFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - b := &bytes.Buffer{} - frame := PathChallengeFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} - err := frame.Write(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x1a, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) - }) - - It("has the correct min length", func() { - frame := PathChallengeFrame{} - Expect(frame.Length(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(9))) - }) - }) -}) diff --git a/internal/wire/path_response_frame.go b/internal/wire/path_response_frame.go deleted file mode 100644 index d8dbebdc..00000000 --- a/internal/wire/path_response_frame.go +++ /dev/null @@ -1,39 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - "io" - - "github.com/imroc/req/v3/internal/protocol" -) - -// A PathResponseFrame is a PATH_RESPONSE frame -type PathResponseFrame struct { - Data [8]byte -} - -func parsePathResponseFrame(r *bytes.Reader, _ quic.VersionNumber) (*PathResponseFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - frame := &PathResponseFrame{} - if _, err := io.ReadFull(r, frame.Data[:]); err != nil { - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return nil, err - } - return frame, nil -} - -func (f *PathResponseFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - b.WriteByte(0x1b) - b.Write(f.Data[:]) - return nil -} - -// Length of a written frame -func (f *PathResponseFrame) Length(_ quic.VersionNumber) protocol.ByteCount { - return 1 + 8 -} diff --git a/internal/wire/path_response_frame_test.go b/internal/wire/path_response_frame_test.go deleted file mode 100644 index 872d1c59..00000000 --- a/internal/wire/path_response_frame_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("PATH_RESPONSE frame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x1b, 1, 2, 3, 4, 5, 6, 7, 8}) - f, err := parsePathResponseFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeZero()) - Expect(f.Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) - }) - - It("errors on EOFs", func() { - data := []byte{0x1b, 1, 2, 3, 4, 5, 6, 7, 8} - _, err := parsePathResponseFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parsePathResponseFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - b := &bytes.Buffer{} - frame := PathResponseFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} - err := frame.Write(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x1b, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) - }) - - It("has the correct min length", func() { - frame := PathResponseFrame{} - Expect(frame.Length(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(9))) - }) - }) -}) diff --git a/internal/wire/ping_frame.go b/internal/wire/ping_frame.go deleted file mode 100644 index 38b47c2f..00000000 --- a/internal/wire/ping_frame.go +++ /dev/null @@ -1,28 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" -) - -// A PingFrame is a PING frame -type PingFrame struct{} - -func parsePingFrame(r *bytes.Reader, _ quic.VersionNumber) (*PingFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - return &PingFrame{}, nil -} - -func (f *PingFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { - b.WriteByte(0x1) - return nil -} - -// Length of a written frame -func (f *PingFrame) Length(version quic.VersionNumber) protocol.ByteCount { - return 1 -} diff --git a/internal/wire/ping_frame_test.go b/internal/wire/ping_frame_test.go deleted file mode 100644 index 3664731a..00000000 --- a/internal/wire/ping_frame_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("PingFrame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x1}) - _, err := parsePingFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - _, err := parsePingFrame(bytes.NewReader(nil), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - b := &bytes.Buffer{} - frame := PingFrame{} - frame.Write(b, protocol.VersionWhatever) - Expect(b.Bytes()).To(Equal([]byte{0x1})) - }) - - It("has the correct min length", func() { - frame := PingFrame{} - Expect(frame.Length(0)).To(Equal(protocol.ByteCount(1))) - }) - }) -}) diff --git a/internal/wire/pool.go b/internal/wire/pool.go deleted file mode 100644 index aba32137..00000000 --- a/internal/wire/pool.go +++ /dev/null @@ -1,33 +0,0 @@ -package wire - -import ( - "sync" - - "github.com/imroc/req/v3/internal/protocol" -) - -var pool sync.Pool - -func init() { - pool.New = func() interface{} { - return &StreamFrame{ - Data: make([]byte, 0, protocol.MaxPacketBufferSize), - fromPool: true, - } - } -} - -func GetStreamFrame() *StreamFrame { - f := pool.Get().(*StreamFrame) - return f -} - -func putStreamFrame(f *StreamFrame) { - if !f.fromPool { - return - } - if protocol.ByteCount(cap(f.Data)) != protocol.MaxPacketBufferSize { - panic("wire.PutStreamFrame called with packet of wrong size!") - } - pool.Put(f) -} diff --git a/internal/wire/pool_test.go b/internal/wire/pool_test.go deleted file mode 100644 index b55e493b..00000000 --- a/internal/wire/pool_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package wire - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Pool", func() { - It("gets and puts STREAM frames", func() { - f := GetStreamFrame() - putStreamFrame(f) - }) - - It("panics when putting a STREAM frame with a wrong capacity", func() { - f := GetStreamFrame() - f.Data = []byte("foobar") - Expect(func() { putStreamFrame(f) }).To(Panic()) - }) - - It("accepts STREAM frames not from the buffer, but ignores them", func() { - f := &StreamFrame{Data: []byte("foobar")} - putStreamFrame(f) - }) -}) diff --git a/internal/wire/reset_stream_frame.go b/internal/wire/reset_stream_frame.go deleted file mode 100644 index 190ddda5..00000000 --- a/internal/wire/reset_stream_frame.go +++ /dev/null @@ -1,59 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A ResetStreamFrame is a RESET_STREAM frame in QUIC -type ResetStreamFrame struct { - StreamID protocol.StreamID - ErrorCode qerr.StreamErrorCode - FinalSize protocol.ByteCount -} - -func parseResetStreamFrame(r *bytes.Reader, _ quic.VersionNumber) (*ResetStreamFrame, error) { - if _, err := r.ReadByte(); err != nil { // read the TypeByte - return nil, err - } - - var streamID protocol.StreamID - var byteOffset protocol.ByteCount - sid, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - streamID = protocol.StreamID(sid) - errorCode, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - bo, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - byteOffset = protocol.ByteCount(bo) - - return &ResetStreamFrame{ - StreamID: streamID, - ErrorCode: qerr.StreamErrorCode(errorCode), - FinalSize: byteOffset, - }, nil -} - -func (f *ResetStreamFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - b.WriteByte(0x4) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.ErrorCode)) - quicvarint.Write(b, uint64(f.FinalSize)) - return nil -} - -// Length of a written frame -func (f *ResetStreamFrame) Length(version quic.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize)) -} diff --git a/internal/wire/reset_stream_frame_test.go b/internal/wire/reset_stream_frame_test.go deleted file mode 100644 index e4b008d5..00000000 --- a/internal/wire/reset_stream_frame_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("RESET_STREAM frame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - data := []byte{0x4} - data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID - data = append(data, encodeVarInt(0x1337)...) // error code - data = append(data, encodeVarInt(0x987654321)...) // byte offset - b := bytes.NewReader(data) - frame, err := parseResetStreamFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) - Expect(frame.FinalSize).To(Equal(protocol.ByteCount(0x987654321))) - Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) - }) - - It("errors on EOFs", func() { - data := []byte{0x4} - data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID - data = append(data, encodeVarInt(0x1337)...) // error code - data = append(data, encodeVarInt(0x987654321)...) // byte offset - _, err := parseResetStreamFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseResetStreamFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - frame := ResetStreamFrame{ - StreamID: 0x1337, - FinalSize: 0x11223344decafbad, - ErrorCode: 0xcafe, - } - b := &bytes.Buffer{} - err := frame.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x4} - expected = append(expected, encodeVarInt(0x1337)...) - expected = append(expected, encodeVarInt(0xcafe)...) - expected = append(expected, encodeVarInt(0x11223344decafbad)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - rst := ResetStreamFrame{ - StreamID: 0x1337, - FinalSize: 0x1234567, - ErrorCode: 0xde, - } - expectedLen := 1 + quicvarint.Len(0x1337) + quicvarint.Len(0x1234567) + 2 - Expect(rst.Length(protocol.Version1)).To(Equal(expectedLen)) - }) - }) -}) diff --git a/internal/wire/retire_connection_id_frame.go b/internal/wire/retire_connection_id_frame.go deleted file mode 100644 index 1d8d1dbd..00000000 --- a/internal/wire/retire_connection_id_frame.go +++ /dev/null @@ -1,37 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame -type RetireConnectionIDFrame struct { - SequenceNumber uint64 -} - -func parseRetireConnectionIDFrame(r *bytes.Reader, _ quic.VersionNumber) (*RetireConnectionIDFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - seq, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - return &RetireConnectionIDFrame{SequenceNumber: seq}, nil -} - -func (f *RetireConnectionIDFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - b.WriteByte(0x19) - quicvarint.Write(b, f.SequenceNumber) - return nil -} - -// Length of a written frame -func (f *RetireConnectionIDFrame) Length(quic.VersionNumber) protocol.ByteCount { - return 1 + protocol.ByteCount(quicvarint.Len(f.SequenceNumber)) -} diff --git a/internal/wire/retire_connection_id_frame_test.go b/internal/wire/retire_connection_id_frame_test.go deleted file mode 100644 index b67f733d..00000000 --- a/internal/wire/retire_connection_id_frame_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("NEW_CONNECTION_ID frame", func() { - Context("when parsing", func() { - It("accepts a sample frame", func() { - data := []byte{0x19} - data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number - b := bytes.NewReader(data) - frame, err := parseRetireConnectionIDFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) - }) - - It("errors on EOFs", func() { - data := []byte{0x18} - data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number - _, err := parseRetireConnectionIDFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseRetireConnectionIDFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - frame := &RetireConnectionIDFrame{SequenceNumber: 0x1337} - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - expected := []byte{0x19} - expected = append(expected, encodeVarInt(0x1337)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct length", func() { - frame := &RetireConnectionIDFrame{SequenceNumber: 0xdecafbad} - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) - }) - }) -}) diff --git a/internal/wire/stop_sending_frame.go b/internal/wire/stop_sending_frame.go deleted file mode 100644 index f9d44db9..00000000 --- a/internal/wire/stop_sending_frame.go +++ /dev/null @@ -1,49 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A StopSendingFrame is a STOP_SENDING frame -type StopSendingFrame struct { - StreamID protocol.StreamID - ErrorCode qerr.StreamErrorCode -} - -// parseStopSendingFrame parses a STOP_SENDING frame -func parseStopSendingFrame(r *bytes.Reader, _ quic.VersionNumber) (*StopSendingFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - streamID, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - errorCode, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - - return &StopSendingFrame{ - StreamID: protocol.StreamID(streamID), - ErrorCode: qerr.StreamErrorCode(errorCode), - }, nil -} - -// Length of a written frame -func (f *StopSendingFrame) Length(_ quic.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) -} - -func (f *StopSendingFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - b.WriteByte(0x5) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.ErrorCode)) - return nil -} diff --git a/internal/wire/stop_sending_frame_test.go b/internal/wire/stop_sending_frame_test.go deleted file mode 100644 index 9e709b32..00000000 --- a/internal/wire/stop_sending_frame_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("STOP_SENDING frame", func() { - Context("when parsing", func() { - It("parses a sample frame", func() { - data := []byte{0x5} - data = append(data, encodeVarInt(0xdecafbad)...) // stream ID - data = append(data, encodeVarInt(0x1337)...) // error code - b := bytes.NewReader(data) - frame, err := parseStopSendingFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) - Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x5} - data = append(data, encodeVarInt(0xdecafbad)...) // stream ID - data = append(data, encodeVarInt(0x123456)...) // error code - _, err := parseStopSendingFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseStopSendingFrame(bytes.NewReader(data[:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("when writing", func() { - It("writes", func() { - frame := &StopSendingFrame{ - StreamID: 0xdeadbeefcafe, - ErrorCode: 0xdecafbad, - } - buf := &bytes.Buffer{} - Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) - expected := []byte{0x5} - expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) - expected = append(expected, encodeVarInt(0xdecafbad)...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - frame := &StopSendingFrame{ - StreamID: 0xdeadbeef, - ErrorCode: 0x1234567, - } - Expect(frame.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0xdeadbeef) + quicvarint.Len(0x1234567))) - }) - }) -}) diff --git a/internal/wire/stream_data_blocked_frame.go b/internal/wire/stream_data_blocked_frame.go deleted file mode 100644 index 011d14c7..00000000 --- a/internal/wire/stream_data_blocked_frame.go +++ /dev/null @@ -1,47 +0,0 @@ -package wire - -import ( - "bytes" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame -type StreamDataBlockedFrame struct { - StreamID protocol.StreamID - MaximumStreamData protocol.ByteCount -} - -func parseStreamDataBlockedFrame(r *bytes.Reader, _ quic.VersionNumber) (*StreamDataBlockedFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - sid, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - offset, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - - return &StreamDataBlockedFrame{ - StreamID: protocol.StreamID(sid), - MaximumStreamData: protocol.ByteCount(offset), - }, nil -} - -func (f *StreamDataBlockedFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { - b.WriteByte(0x15) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.MaximumStreamData)) - return nil -} - -// Length of a written frame -func (f *StreamDataBlockedFrame) Length(version quic.VersionNumber) protocol.ByteCount { - return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.MaximumStreamData))) -} diff --git a/internal/wire/stream_data_blocked_frame_test.go b/internal/wire/stream_data_blocked_frame_test.go deleted file mode 100644 index 5306edcb..00000000 --- a/internal/wire/stream_data_blocked_frame_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("STREAM_DATA_BLOCKED frame", func() { - Context("parsing", func() { - It("accepts sample frame", func() { - data := []byte{0x15} - data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID - data = append(data, encodeVarInt(0xdecafbad)...) // offset - b := bytes.NewReader(data) - frame, err := parseStreamDataBlockedFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) - Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0xdecafbad))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x15} - data = append(data, encodeVarInt(0xdeadbeef)...) - data = append(data, encodeVarInt(0xc0010ff)...) - _, err := parseStreamDataBlockedFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseStreamDataBlockedFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("writing", func() { - It("has proper min length", func() { - f := &StreamDataBlockedFrame{ - StreamID: 0x1337, - MaximumStreamData: 0xdeadbeef, - } - Expect(f.Length(0)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0xdeadbeef))) - }) - - It("writes a sample frame", func() { - b := &bytes.Buffer{} - f := &StreamDataBlockedFrame{ - StreamID: 0xdecafbad, - MaximumStreamData: 0x1337, - } - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x15} - expected = append(expected, encodeVarInt(uint64(f.StreamID))...) - expected = append(expected, encodeVarInt(uint64(f.MaximumStreamData))...) - Expect(b.Bytes()).To(Equal(expected)) - }) - }) -}) diff --git a/internal/wire/stream_frame.go b/internal/wire/stream_frame.go deleted file mode 100644 index ed21d388..00000000 --- a/internal/wire/stream_frame.go +++ /dev/null @@ -1,190 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "github.com/lucas-clemente/quic-go" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A StreamFrame of QUIC -type StreamFrame struct { - StreamID protocol.StreamID - Offset protocol.ByteCount - Data []byte - Fin bool - DataLenPresent bool - - fromPool bool -} - -func parseStreamFrame(r *bytes.Reader, _ quic.VersionNumber) (*StreamFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - hasOffset := typeByte&0x4 > 0 - fin := typeByte&0x1 > 0 - hasDataLen := typeByte&0x2 > 0 - - streamID, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - var offset uint64 - if hasOffset { - offset, err = quicvarint.Read(r) - if err != nil { - return nil, err - } - } - - var dataLen uint64 - if hasDataLen { - var err error - dataLen, err = quicvarint.Read(r) - if err != nil { - return nil, err - } - } else { - // The rest of the packet is data - dataLen = uint64(r.Len()) - } - - var frame *StreamFrame - if dataLen < protocol.MinStreamFrameBufferSize { - frame = &StreamFrame{Data: make([]byte, dataLen)} - } else { - frame = GetStreamFrame() - // The STREAM frame can't be larger than the StreamFrame we obtained from the buffer, - // since those StreamFrames have a buffer length of the maximum packet size. - if dataLen > uint64(cap(frame.Data)) { - return nil, io.EOF - } - frame.Data = frame.Data[:dataLen] - } - - frame.StreamID = protocol.StreamID(streamID) - frame.Offset = protocol.ByteCount(offset) - frame.Fin = fin - frame.DataLenPresent = hasDataLen - - if dataLen != 0 { - if _, err := io.ReadFull(r, frame.Data); err != nil { - return nil, err - } - } - if frame.Offset+frame.DataLen() > protocol.MaxByteCount { - return nil, errors.New("stream data overflows maximum offset") - } - return frame, nil -} - -// Write writes a STREAM frame -func (f *StreamFrame) Write(b *bytes.Buffer, version quic.VersionNumber) error { - if len(f.Data) == 0 && !f.Fin { - return errors.New("StreamFrame: attempting to write empty frame without FIN") - } - - typeByte := byte(0x8) - if f.Fin { - typeByte ^= 0x1 - } - hasOffset := f.Offset != 0 - if f.DataLenPresent { - typeByte ^= 0x2 - } - if hasOffset { - typeByte ^= 0x4 - } - b.WriteByte(typeByte) - quicvarint.Write(b, uint64(f.StreamID)) - if hasOffset { - quicvarint.Write(b, uint64(f.Offset)) - } - if f.DataLenPresent { - quicvarint.Write(b, uint64(f.DataLen())) - } - b.Write(f.Data) - return nil -} - -// Length returns the total length of the STREAM frame -func (f *StreamFrame) Length(version quic.VersionNumber) protocol.ByteCount { - length := 1 + quicvarint.Len(uint64(f.StreamID)) - if f.Offset != 0 { - length += quicvarint.Len(uint64(f.Offset)) - } - if f.DataLenPresent { - length += quicvarint.Len(uint64(f.DataLen())) - } - return protocol.ByteCount(length) + f.DataLen() -} - -// DataLen gives the length of data in bytes -func (f *StreamFrame) DataLen() protocol.ByteCount { - return protocol.ByteCount(len(f.Data)) -} - -// MaxDataLen returns the maximum data length -// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data). -func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version quic.VersionNumber) protocol.ByteCount { - headerLen := 1 + quicvarint.Len(uint64(f.StreamID)) - if f.Offset != 0 { - headerLen += quicvarint.Len(uint64(f.Offset)) - } - if f.DataLenPresent { - // pretend that the data size will be 1 bytes - // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards - headerLen++ - } - if protocol.ByteCount(headerLen) > maxSize { - return 0 - } - maxDataLen := maxSize - protocol.ByteCount(headerLen) - if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { - maxDataLen-- - } - return maxDataLen -} - -// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. -// It returns if the frame was actually split. -// The frame might not be split if: -// * the size is large enough to fit the whole frame -// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. -func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version quic.VersionNumber) (*StreamFrame, bool /* was splitting required */) { - if maxSize >= f.Length(version) { - return nil, false - } - - n := f.MaxDataLen(maxSize, version) - if n == 0 { - return nil, true - } - - new := GetStreamFrame() - new.StreamID = f.StreamID - new.Offset = f.Offset - new.Fin = false - new.DataLenPresent = f.DataLenPresent - - // swap the data slices - new.Data, f.Data = f.Data, new.Data - new.fromPool, f.fromPool = f.fromPool, new.fromPool - - f.Data = f.Data[:protocol.ByteCount(len(new.Data))-n] - copy(f.Data, new.Data[n:]) - new.Data = new.Data[:n] - f.Offset += n - - return new, true -} - -func (f *StreamFrame) PutBack() { - putStreamFrame(f) -} diff --git a/internal/wire/stream_frame_test.go b/internal/wire/stream_frame_test.go deleted file mode 100644 index 9e49de8a..00000000 --- a/internal/wire/stream_frame_test.go +++ /dev/null @@ -1,443 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("STREAM frame", func() { - Context("when parsing", func() { - It("parses a frame with OFF bit", func() { - data := []byte{0x8 ^ 0x4} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, encodeVarInt(0xdecafbad)...) // offset - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) - Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(frame.Fin).To(BeFalse()) - Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) - Expect(r.Len()).To(BeZero()) - }) - - It("respects the LEN when parsing the frame", func() { - data := []byte{0x8 ^ 0x2} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, encodeVarInt(4)...) // data length - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) - Expect(frame.Data).To(Equal([]byte("foob"))) - Expect(frame.Fin).To(BeFalse()) - Expect(frame.Offset).To(BeZero()) - Expect(r.Len()).To(Equal(2)) - }) - - It("parses a frame with FIN bit", func() { - data := []byte{0x8 ^ 0x1} - data = append(data, encodeVarInt(9)...) // stream ID - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(9))) - Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(frame.Fin).To(BeTrue()) - Expect(frame.Offset).To(BeZero()) - Expect(r.Len()).To(BeZero()) - }) - - It("allows empty frames", func() { - data := []byte{0x8 ^ 0x4} - data = append(data, encodeVarInt(0x1337)...) // stream ID - data = append(data, encodeVarInt(0x12345)...) // offset - r := bytes.NewReader(data) - f, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(f.StreamID).To(Equal(protocol.StreamID(0x1337))) - Expect(f.Offset).To(Equal(protocol.ByteCount(0x12345))) - Expect(f.Data).To(BeEmpty()) - Expect(f.Fin).To(BeFalse()) - }) - - It("rejects frames that overflow the maximum offset", func() { - data := []byte{0x8 ^ 0x4} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, encodeVarInt(uint64(protocol.MaxByteCount-5))...) // offset - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - _, err := parseStreamFrame(r, protocol.Version1) - Expect(err).To(MatchError("stream data overflows maximum offset")) - }) - - It("rejects frames that claim to be longer than the packet size", func() { - data := []byte{0x8 ^ 0x2} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, encodeVarInt(uint64(protocol.MaxPacketBufferSize)+1)...) // data length - data = append(data, make([]byte, protocol.MaxPacketBufferSize+1)...) - r := bytes.NewReader(data) - _, err := parseStreamFrame(r, protocol.Version1) - Expect(err).To(Equal(io.EOF)) - }) - - It("errors on EOFs", func() { - data := []byte{0x8 ^ 0x4 ^ 0x2} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, encodeVarInt(0xdecafbad)...) // offset - data = append(data, encodeVarInt(6)...) // data length - data = append(data, []byte("foobar")...) - _, err := parseStreamFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseStreamFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("using the buffer", func() { - It("uses the buffer for long STREAM frames", func() { - data := []byte{0x8} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) - Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize))) - Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize)) - Expect(frame.Fin).To(BeFalse()) - Expect(frame.fromPool).To(BeTrue()) - Expect(r.Len()).To(BeZero()) - Expect(frame.PutBack).ToNot(Panic()) - }) - - It("doesn't use the buffer for short STREAM frames", func() { - data := []byte{0x8} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) - Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1))) - Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize - 1)) - Expect(frame.Fin).To(BeFalse()) - Expect(frame.fromPool).To(BeFalse()) - Expect(r.Len()).To(BeZero()) - Expect(frame.PutBack).ToNot(Panic()) - }) - }) - - Context("when writing", func() { - It("writes a frame without offset", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Data: []byte("foobar"), - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x8} - expected = append(expected, encodeVarInt(0x1337)...) // stream ID - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with offset", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0x123456, - Data: []byte("foobar"), - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x8 ^ 0x4} - expected = append(expected, encodeVarInt(0x1337)...) // stream ID - expected = append(expected, encodeVarInt(0x123456)...) // offset - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with FIN bit", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0x123456, - Fin: true, - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x8 ^ 0x4 ^ 0x1} - expected = append(expected, encodeVarInt(0x1337)...) // stream ID - expected = append(expected, encodeVarInt(0x123456)...) // offset - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with data length", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Data: []byte("foobar"), - DataLenPresent: true, - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x8 ^ 0x2} - expected = append(expected, encodeVarInt(0x1337)...) // stream ID - expected = append(expected, encodeVarInt(6)...) // data length - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with data length and offset", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Data: []byte("foobar"), - DataLenPresent: true, - Offset: 0x123456, - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x8 ^ 0x4 ^ 0x2} - expected = append(expected, encodeVarInt(0x1337)...) // stream ID - expected = append(expected, encodeVarInt(0x123456)...) // offset - expected = append(expected, encodeVarInt(6)...) // data length - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("refuses to write an empty frame without FIN", func() { - f := &StreamFrame{ - StreamID: 0x42, - Offset: 0x1337, - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).To(MatchError("StreamFrame: attempting to write empty frame without FIN")) - }) - }) - - Context("length", func() { - It("has the right length for a frame without offset and data length", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Data: []byte("foobar"), - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + 6)) - }) - - It("has the right length for a frame with offset", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0x42, - Data: []byte("foobar"), - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x42) + 6)) - }) - - It("has the right length for a frame with data length", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0x1234567, - DataLenPresent: true, - Data: []byte("foobar"), - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x1234567) + quicvarint.Len(6) + 6)) - }) - }) - - Context("max data length", func() { - const maxSize = 3000 - - It("always returns a data length such that the resulting frame has the right size, if data length is not present", func() { - data := make([]byte, maxSize) - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0xdeadbeef, - } - b := &bytes.Buffer{} - for i := 1; i < 3000; i++ { - b.Reset() - f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) - if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written - // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size - f.Data = []byte{0} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeNumerically(">", i)) - continue - } - f.Data = data[:int(maxDataLen)] - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(Equal(i)) - } - }) - - It("always returns a data length such that the resulting frame has the right size, if data length is present", func() { - data := make([]byte, maxSize) - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0xdeadbeef, - DataLenPresent: true, - } - b := &bytes.Buffer{} - var frameOneByteTooSmallCounter int - for i := 1; i < 3000; i++ { - b.Reset() - f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) - if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written - // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size - f.Data = []byte{0} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeNumerically(">", i)) - continue - } - f.Data = data[:int(maxDataLen)] - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - // There's *one* pathological case, where a data length of x can be encoded into 1 byte - // but a data lengths of x+1 needs 2 bytes - // In that case, it's impossible to create a STREAM frame of the desired size - if b.Len() == i-1 { - frameOneByteTooSmallCounter++ - continue - } - Expect(b.Len()).To(Equal(i)) - } - Expect(frameOneByteTooSmallCounter).To(Equal(1)) - }) - }) - - Context("splitting", func() { - It("doesn't split if the frame is short enough", func() { - f := &StreamFrame{ - StreamID: 0x1337, - DataLenPresent: true, - Offset: 0xdeadbeef, - Data: make([]byte, 100), - } - frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) - Expect(needsSplit).To(BeFalse()) - Expect(frame).To(BeNil()) - Expect(f.DataLen()).To(BeEquivalentTo(100)) - frame, needsSplit = f.MaybeSplitOffFrame(f.Length(protocol.Version1)-1, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(frame.DataLen()).To(BeEquivalentTo(99)) - f.PutBack() - }) - - It("keeps the data len", func() { - f := &StreamFrame{ - StreamID: 0x1337, - DataLenPresent: true, - Data: make([]byte, 100), - } - frame, needsSplit := f.MaybeSplitOffFrame(66, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - Expect(f.DataLenPresent).To(BeTrue()) - Expect(frame.DataLenPresent).To(BeTrue()) - }) - - It("adjusts the offset", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0x100, - Data: []byte("foobar"), - } - frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1)-3, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - Expect(frame.Offset).To(Equal(protocol.ByteCount(0x100))) - Expect(frame.Data).To(Equal([]byte("foo"))) - Expect(f.Offset).To(Equal(protocol.ByteCount(0x100 + 3))) - Expect(f.Data).To(Equal([]byte("bar"))) - }) - - It("preserves the FIN bit", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Fin: true, - Offset: 0xdeadbeef, - Data: make([]byte, 100), - } - frame, needsSplit := f.MaybeSplitOffFrame(50, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - Expect(frame.Offset).To(BeNumerically("<", f.Offset)) - Expect(f.Fin).To(BeTrue()) - Expect(frame.Fin).To(BeFalse()) - }) - - It("produces frames of the correct length, without data len", func() { - const size = 1000 - f := &StreamFrame{ - StreamID: 0xdecafbad, - Offset: 0x1234, - Data: []byte{0}, - } - minFrameSize := f.Length(protocol.Version1) - for i := protocol.ByteCount(0); i < minFrameSize; i++ { - f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(f).To(BeNil()) - } - for i := minFrameSize; i < size; i++ { - f.fromPool = false - f.Data = make([]byte, size) - f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(f.Length(protocol.Version1)).To(Equal(i)) - } - }) - - It("produces frames of the correct length, with data len", func() { - const size = 1000 - f := &StreamFrame{ - StreamID: 0xdecafbad, - Offset: 0x1234, - DataLenPresent: true, - Data: []byte{0}, - } - minFrameSize := f.Length(protocol.Version1) - for i := protocol.ByteCount(0); i < minFrameSize; i++ { - f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(f).To(BeNil()) - } - var frameOneByteTooSmallCounter int - for i := minFrameSize; i < size; i++ { - f.fromPool = false - f.Data = make([]byte, size) - newFrame, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - // There's *one* pathological case, where a data length of x can be encoded into 1 byte - // but a data lengths of x+1 needs 2 bytes - // In that case, it's impossible to create a STREAM frame of the desired size - if newFrame.Length(protocol.Version1) == i-1 { - frameOneByteTooSmallCounter++ - continue - } - Expect(newFrame.Length(protocol.Version1)).To(Equal(i)) - } - Expect(frameOneByteTooSmallCounter).To(Equal(1)) - }) - }) -}) diff --git a/internal/wire/streams_blocked_frame.go b/internal/wire/streams_blocked_frame.go deleted file mode 100644 index 5e18bb34..00000000 --- a/internal/wire/streams_blocked_frame.go +++ /dev/null @@ -1,56 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" -) - -// A StreamsBlockedFrame is a STREAMS_BLOCKED frame -type StreamsBlockedFrame struct { - Type protocol.StreamType - StreamLimit protocol.StreamNum -} - -func parseStreamsBlockedFrame(r *bytes.Reader, _ quic.VersionNumber) (*StreamsBlockedFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - f := &StreamsBlockedFrame{} - switch typeByte { - case 0x16: - f.Type = protocol.StreamTypeBidi - case 0x17: - f.Type = protocol.StreamTypeUni - } - streamLimit, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - f.StreamLimit = protocol.StreamNum(streamLimit) - if f.StreamLimit > protocol.MaxStreamCount { - return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit) - } - return f, nil -} - -func (f *StreamsBlockedFrame) Write(b *bytes.Buffer, _ quic.VersionNumber) error { - switch f.Type { - case protocol.StreamTypeBidi: - b.WriteByte(0x16) - case protocol.StreamTypeUni: - b.WriteByte(0x17) - } - quicvarint.Write(b, uint64(f.StreamLimit)) - return nil -} - -// Length of a written frame -func (f *StreamsBlockedFrame) Length(_ quic.VersionNumber) protocol.ByteCount { - return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamLimit))) -} diff --git a/internal/wire/streams_blocked_frame_test.go b/internal/wire/streams_blocked_frame_test.go deleted file mode 100644 index eb5f94bc..00000000 --- a/internal/wire/streams_blocked_frame_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - "io" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("STREAMS_BLOCKED frame", func() { - Context("parsing", func() { - It("accepts a frame for bidirectional streams", func() { - expected := []byte{0x16} - expected = append(expected, encodeVarInt(0x1337)...) - b := bytes.NewReader(expected) - f, err := parseStreamsBlockedFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Type).To(Equal(protocol.StreamTypeBidi)) - Expect(f.StreamLimit).To(BeEquivalentTo(0x1337)) - Expect(b.Len()).To(BeZero()) - }) - - It("accepts a frame for unidirectional streams", func() { - expected := []byte{0x17} - expected = append(expected, encodeVarInt(0x7331)...) - b := bytes.NewReader(expected) - f, err := parseStreamsBlockedFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Type).To(Equal(protocol.StreamTypeUni)) - Expect(f.StreamLimit).To(BeEquivalentTo(0x7331)) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x16} - data = append(data, encodeVarInt(0x12345678)...) - _, err := parseStreamsBlockedFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - for i := range data { - _, err := parseStreamsBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - - for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { - streamType := t - - It("accepts a frame containing the maximum stream count", func() { - f := &StreamsBlockedFrame{ - Type: streamType, - StreamLimit: protocol.MaxStreamCount, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - frame, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("errors when receiving a too large stream count", func() { - f := &StreamsBlockedFrame{ - Type: streamType, - StreamLimit: protocol.MaxStreamCount + 1, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - _, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) - Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) - }) - } - }) - - Context("writing", func() { - It("writes a frame for bidirectional streams", func() { - b := &bytes.Buffer{} - f := StreamsBlockedFrame{ - Type: protocol.StreamTypeBidi, - StreamLimit: 0xdeadbeefcafe, - } - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - expected := []byte{0x16} - expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame for unidirectional streams", func() { - b := &bytes.Buffer{} - f := StreamsBlockedFrame{ - Type: protocol.StreamTypeUni, - StreamLimit: 0xdeadbeefcafe, - } - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - expected := []byte{0x17} - expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - frame := StreamsBlockedFrame{StreamLimit: 0x123456} - Expect(frame.Length(0)).To(Equal(protocol.ByteCount(1) + quicvarint.Len(0x123456))) - }) - }) -}) diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go deleted file mode 100644 index 396ed50f..00000000 --- a/internal/wire/transport_parameter_test.go +++ /dev/null @@ -1,612 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - "math" - "math/rand" - "net" - "time" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Transport Parameters", func() { - getRandomValueUpTo := func(max int64) uint64 { - maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} - m := maxVals[int(rand.Int31n(4))] - if m > max { - m = max - } - return uint64(rand.Int63n(m)) - } - - getRandomValue := func() uint64 { - return getRandomValueUpTo(math.MaxInt64) - } - - BeforeEach(func() { - rand.Seed(GinkgoRandomSeed()) - }) - - addInitialSourceConnectionID := func(b *bytes.Buffer) { - quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) - quicvarint.Write(b, 6) - b.Write([]byte("foobar")) - } - - It("has a string representation", func() { - p := &TransportParameters{ - InitialMaxStreamDataBidiLocal: 1234, - InitialMaxStreamDataBidiRemote: 2345, - InitialMaxStreamDataUni: 3456, - InitialMaxData: 4567, - MaxBidiStreamNum: 1337, - MaxUniStreamNum: 7331, - MaxIdleTimeout: 42 * time.Second, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - AckDelayExponent: 14, - MaxAckDelay: 37 * time.Millisecond, - StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, - ActiveConnectionIDLimit: 123, - MaxDatagramFrameSize: 876, - } - Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: deadbeef, InitialSourceConnectionID: decafbad, RetrySourceConnectionID: deadc0de, InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37ms, ActiveConnectionIDLimit: 123, StatelessResetToken: 0x112233445566778899aabbccddeeff00, MaxDatagramFrameSize: 876}")) - }) - - It("has a string representation, if there's no stateless reset token, no Retry source connection id and no datagram support", func() { - p := &TransportParameters{ - InitialMaxStreamDataBidiLocal: 1234, - InitialMaxStreamDataBidiRemote: 2345, - InitialMaxStreamDataUni: 3456, - InitialMaxData: 4567, - MaxBidiStreamNum: 1337, - MaxUniStreamNum: 7331, - MaxIdleTimeout: 42 * time.Second, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{}, - AckDelayExponent: 14, - MaxAckDelay: 37 * time.Second, - ActiveConnectionIDLimit: 89, - MaxDatagramFrameSize: protocol.InvalidByteCount, - } - Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: deadbeef, InitialSourceConnectionID: (empty), InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37s, ActiveConnectionIDLimit: 89}")) - }) - - It("marshals and unmarshals", func() { - var token protocol.StatelessResetToken - rand.Read(token[:]) - params := &TransportParameters{ - InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), - InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), - InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), - InitialMaxData: protocol.ByteCount(getRandomValue()), - MaxIdleTimeout: 0xcafe * time.Second, - MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), - MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), - DisableActiveMigration: true, - StatelessResetToken: &token, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - AckDelayExponent: 13, - MaxAckDelay: 42 * time.Millisecond, - ActiveConnectionIDLimit: getRandomValue(), - MaxDatagramFrameSize: protocol.ByteCount(getRandomValue()), - } - data := params.Marshal(protocol.PerspectiveServer) - - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) - Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) - Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) - Expect(p.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) - Expect(p.InitialMaxData).To(Equal(params.InitialMaxData)) - Expect(p.MaxUniStreamNum).To(Equal(params.MaxUniStreamNum)) - Expect(p.MaxBidiStreamNum).To(Equal(params.MaxBidiStreamNum)) - Expect(p.MaxIdleTimeout).To(Equal(params.MaxIdleTimeout)) - Expect(p.DisableActiveMigration).To(Equal(params.DisableActiveMigration)) - Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken)) - Expect(p.OriginalDestinationConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(p.InitialSourceConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) - Expect(p.RetrySourceConnectionID).To(Equal(&protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) - Expect(p.AckDelayExponent).To(Equal(uint8(13))) - Expect(p.MaxAckDelay).To(Equal(42 * time.Millisecond)) - Expect(p.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) - Expect(p.MaxDatagramFrameSize).To(Equal(params.MaxDatagramFrameSize)) - }) - - It("doesn't marshal a retry_source_connection_id, if no Retry was performed", func() { - data := (&TransportParameters{ - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) - Expect(p.RetrySourceConnectionID).To(BeNil()) - }) - - It("marshals a zero-length retry_source_connection_id", func() { - data := (&TransportParameters{ - RetrySourceConnectionID: &protocol.ConnectionID{}, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) - Expect(p.RetrySourceConnectionID).ToNot(BeNil()) - Expect(p.RetrySourceConnectionID.Len()).To(BeZero()) - }) - - It("errors when the stateless_reset_token has the wrong length", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(statelessResetTokenParameterID)) - quicvarint.Write(b, 15) - b.Write(make([]byte, 15)) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "wrong length for stateless_reset_token: 15 (expected 16)", - })) - }) - - It("errors when the max_packet_size is too small", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(maxUDPPayloadSizeParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(1199))) - quicvarint.Write(b, 1199) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid value for max_packet_size: 1199 (minimum 1200)", - })) - }) - - It("errors when disable_active_migration has content", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) - quicvarint.Write(b, 6) - b.Write([]byte("foobar")) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "wrong length for disable_active_migration: 6 (expected empty)", - })) - }) - - It("errors when the server doesn't set the original_destination_connection_id", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(statelessResetTokenParameterID)) - quicvarint.Write(b, 16) - b.Write(make([]byte, 16)) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "missing original_destination_connection_id", - })) - }) - - It("errors when the initial_source_connection_id is missing", func() { - Expect((&TransportParameters{}).Unmarshal([]byte{}, protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "missing initial_source_connection_id", - })) - }) - - It("errors when the max_ack_delay is too large", func() { - data := (&TransportParameters{ - MaxAckDelay: 1 << 14 * time.Millisecond, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid value for max_ack_delay: 16384ms (maximum 16383ms)", - })) - }) - - It("doesn't send the max_ack_delay, if it has the default value", func() { - const num = 1000 - var defaultLen, dataLen int - // marshal 1000 times to average out the greasing transport parameter - maxAckDelay := protocol.DefaultMaxAckDelay + time.Millisecond - for i := 0; i < num; i++ { - dataDefault := (&TransportParameters{ - MaxAckDelay: protocol.DefaultMaxAckDelay, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - defaultLen += len(dataDefault) - data := (&TransportParameters{ - MaxAckDelay: maxAckDelay, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - dataLen += len(data) - } - entryLen := quicvarint.Len(uint64(ackDelayExponentParameterID)) /* parameter id */ + quicvarint.Len(uint64(quicvarint.Len(uint64(maxAckDelay.Milliseconds())))) /*length */ + quicvarint.Len(uint64(maxAckDelay.Milliseconds())) /* value */ - Expect(float32(dataLen) / num).To(BeNumerically("~", float32(defaultLen)/num+float32(entryLen), 1)) - }) - - It("errors when the ack_delay_exponenent is too large", func() { - data := (&TransportParameters{ - AckDelayExponent: 21, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid value for ack_delay_exponent: 21 (maximum 20)", - })) - }) - - It("doesn't send the ack_delay_exponent, if it has the default value", func() { - const num = 1000 - var defaultLen, dataLen int - // marshal 1000 times to average out the greasing transport parameter - for i := 0; i < num; i++ { - dataDefault := (&TransportParameters{ - AckDelayExponent: protocol.DefaultAckDelayExponent, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - defaultLen += len(dataDefault) - data := (&TransportParameters{ - AckDelayExponent: protocol.DefaultAckDelayExponent + 1, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - dataLen += len(data) - } - entryLen := quicvarint.Len(uint64(ackDelayExponentParameterID)) /* parameter id */ + quicvarint.Len(uint64(quicvarint.Len(protocol.DefaultAckDelayExponent+1))) /* length */ + quicvarint.Len(protocol.DefaultAckDelayExponent+1) /* value */ - Expect(float32(dataLen) / num).To(BeNumerically("~", float32(defaultLen)/num+float32(entryLen), 1)) - }) - - It("sets the default value for the ack_delay_exponent, when no value was sent", func() { - data := (&TransportParameters{ - AckDelayExponent: protocol.DefaultAckDelayExponent, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) - Expect(p.AckDelayExponent).To(BeEquivalentTo(protocol.DefaultAckDelayExponent)) - }) - - It("errors when the varint value has the wrong length", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) - quicvarint.Write(b, 2) - val := uint64(0xdeadbeef) - Expect(quicvarint.Len(val)).ToNot(BeEquivalentTo(2)) - quicvarint.Write(b, val) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: fmt.Sprintf("inconsistent transport parameter length for transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), - })) - }) - - It("errors if initial_max_streams_bidi is too large", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(initialMaxStreamsBidiParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) - quicvarint.Write(b, uint64(protocol.MaxStreamCount+1)) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "initial_max_streams_bidi too large: 1152921504606846977 (maximum 1152921504606846976)", - })) - }) - - It("errors if initial_max_streams_uni is too large", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(initialMaxStreamsUniParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) - quicvarint.Write(b, uint64(protocol.MaxStreamCount+1)) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "initial_max_streams_uni too large: 1152921504606846977 (maximum 1152921504606846976)", - })) - }) - - It("handles huge max_ack_delay values", func() { - b := &bytes.Buffer{} - val := uint64(math.MaxUint64) / 5 - quicvarint.Write(b, uint64(maxAckDelayParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(val))) - quicvarint.Write(b, val) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid value for max_ack_delay: 3689348814741910323ms (maximum 16383ms)", - })) - }) - - It("skips unknown parameters", func() { - b := &bytes.Buffer{} - // write a known parameter - quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) - quicvarint.Write(b, 0x1337) - // write an unknown parameter - quicvarint.Write(b, 0x42) - quicvarint.Write(b, 6) - b.Write([]byte("foobar")) - // write a known parameter - quicvarint.Write(b, uint64(initialMaxStreamDataBidiRemoteParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(0x42))) - quicvarint.Write(b, 0x42) - addInitialSourceConnectionID(b) - p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(Succeed()) - Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(protocol.ByteCount(0x1337))) - Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(protocol.ByteCount(0x42))) - }) - - It("rejects duplicate parameters", func() { - b := &bytes.Buffer{} - // write first parameter - quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) - quicvarint.Write(b, 0x1337) - // write a second parameter - quicvarint.Write(b, uint64(initialMaxStreamDataBidiRemoteParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(0x42))) - quicvarint.Write(b, 0x42) - // write first parameter again - quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) - quicvarint.Write(b, 0x1337) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: fmt.Sprintf("received duplicate transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), - })) - }) - - It("errors if there's not enough data to read", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, 0x42) - quicvarint.Write(b, 7) - b.Write([]byte("foobar")) - p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "remaining length (6) smaller than parameter length (7)", - })) - }) - - It("errors if the client sent a stateless_reset_token", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(statelessResetTokenParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(16))) - b.Write(make([]byte, 16)) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "client sent a stateless_reset_token", - })) - }) - - It("errors if the client sent the original_destination_connection_id", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) - quicvarint.Write(b, 6) - b.Write([]byte("foobar")) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "client sent an original_destination_connection_id", - })) - }) - - Context("preferred address", func() { - var pa *PreferredAddress - - BeforeEach(func() { - pa = &PreferredAddress{ - IPv4: net.IPv4(127, 0, 0, 1), - IPv4Port: 42, - IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - IPv6Port: 13, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, - } - }) - - It("marshals and unmarshals", func() { - data := (&TransportParameters{ - PreferredAddress: pa, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) - Expect(p.PreferredAddress.IPv4.String()).To(Equal(pa.IPv4.String())) - Expect(p.PreferredAddress.IPv4Port).To(Equal(pa.IPv4Port)) - Expect(p.PreferredAddress.IPv6.String()).To(Equal(pa.IPv6.String())) - Expect(p.PreferredAddress.IPv6Port).To(Equal(pa.IPv6Port)) - Expect(p.PreferredAddress.ConnectionID).To(Equal(pa.ConnectionID)) - Expect(p.PreferredAddress.StatelessResetToken).To(Equal(pa.StatelessResetToken)) - }) - - It("errors if the client sent a preferred_address", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(preferredAddressParameterID)) - quicvarint.Write(b, 6) - b.Write([]byte("foobar")) - p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "client sent a preferred_address", - })) - }) - - It("errors on zero-length connection IDs", func() { - pa.ConnectionID = protocol.ConnectionID{} - data := (&TransportParameters{ - PreferredAddress: pa, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid connection ID length: 0", - })) - }) - - It("errors on too long connection IDs", func() { - pa.ConnectionID = protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21} - Expect(pa.ConnectionID.Len()).To(BeNumerically(">", protocol.MaxConnIDLen)) - data := (&TransportParameters{ - PreferredAddress: pa, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid connection ID length: 21", - })) - }) - - It("errors on EOF", func() { - raw := []byte{ - 127, 0, 0, 1, // IPv4 - 0, 42, // IPv4 Port - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // IPv6 - 13, 37, // IPv6 Port, - 4, // conn ID len - 0xde, 0xad, 0xbe, 0xef, - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, // stateless reset token - } - for i := 1; i < len(raw); i++ { - buf := &bytes.Buffer{} - quicvarint.Write(buf, uint64(preferredAddressParameterID)) - buf.Write(raw[:i]) - p := &TransportParameters{} - Expect(p.Unmarshal(buf.Bytes(), protocol.PerspectiveServer)).ToNot(Succeed()) - } - }) - }) - - Context("saving and retrieving from a session ticket", func() { - It("saves and retrieves the parameters", func() { - params := &TransportParameters{ - InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), - InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), - InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), - InitialMaxData: protocol.ByteCount(getRandomValue()), - MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), - MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), - ActiveConnectionIDLimit: getRandomValue(), - } - Expect(params.ValidFor0RTT(params)).To(BeTrue()) - b := &bytes.Buffer{} - params.MarshalForSessionTicket(b) - var tp TransportParameters - Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(Succeed()) - Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) - Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) - Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) - Expect(tp.InitialMaxData).To(Equal(params.InitialMaxData)) - Expect(tp.MaxBidiStreamNum).To(Equal(params.MaxBidiStreamNum)) - Expect(tp.MaxUniStreamNum).To(Equal(params.MaxUniStreamNum)) - Expect(tp.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) - }) - - It("rejects the parameters if it can't parse them", func() { - var p TransportParameters - Expect(p.UnmarshalFromSessionTicket(bytes.NewReader([]byte("foobar")))).ToNot(Succeed()) - }) - - It("rejects the parameters if the version changed", func() { - var p TransportParameters - buf := &bytes.Buffer{} - p.MarshalForSessionTicket(buf) - data := buf.Bytes() - b := &bytes.Buffer{} - quicvarint.Write(b, transportParameterMarshalingVersion+1) - b.Write(data[quicvarint.Len(transportParameterMarshalingVersion):]) - Expect(p.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1))) - }) - - Context("rejects the parameters if they changed", func() { - var p TransportParameters - saved := &TransportParameters{ - InitialMaxStreamDataBidiLocal: 1, - InitialMaxStreamDataBidiRemote: 2, - InitialMaxStreamDataUni: 3, - InitialMaxData: 4, - MaxBidiStreamNum: 5, - MaxUniStreamNum: 6, - ActiveConnectionIDLimit: 7, - } - - BeforeEach(func() { - p = *saved - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the InitialMaxStreamDataBidiLocal was reduced", func() { - p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("doesn't reject the parameters if the InitialMaxStreamDataBidiLocal was increased", func() { - p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the InitialMaxStreamDataBidiRemote was reduced", func() { - p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("doesn't reject the parameters if the InitialMaxStreamDataBidiRemote was increased", func() { - p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the InitialMaxStreamDataUni was reduced", func() { - p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("doesn't reject the parameters if the InitialMaxStreamDataUni was increased", func() { - p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the InitialMaxData was reduced", func() { - p.InitialMaxData = saved.InitialMaxData - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("doesn't reject the parameters if the InitialMaxData was increased", func() { - p.InitialMaxData = saved.InitialMaxData + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the MaxBidiStreamNum was reduced", func() { - p.MaxBidiStreamNum = saved.MaxBidiStreamNum - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("accepts the parameters if the MaxBidiStreamNum was increased", func() { - p.MaxBidiStreamNum = saved.MaxBidiStreamNum + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the MaxUniStreamNum changed", func() { - p.MaxUniStreamNum = saved.MaxUniStreamNum - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("accepts the parameters if the MaxUniStreamNum was increased", func() { - p.MaxUniStreamNum = saved.MaxUniStreamNum + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the ActiveConnectionIDLimit changed", func() { - p.ActiveConnectionIDLimit = 0 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - }) - }) -}) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go deleted file mode 100644 index 0d1a3401..00000000 --- a/internal/wire/transport_parameters.go +++ /dev/null @@ -1,476 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "fmt" - "io" - "math/rand" - "net" - "sort" - "time" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qerr" - "github.com/imroc/req/v3/internal/utils" - "github.com/imroc/req/v3/internal/quicvarint" -) - -const transportParameterMarshalingVersion = 1 - -func init() { - rand.Seed(time.Now().UTC().UnixNano()) -} - -type transportParameterID uint64 - -const ( - originalDestinationConnectionIDParameterID transportParameterID = 0x0 - maxIdleTimeoutParameterID transportParameterID = 0x1 - statelessResetTokenParameterID transportParameterID = 0x2 - maxUDPPayloadSizeParameterID transportParameterID = 0x3 - initialMaxDataParameterID transportParameterID = 0x4 - initialMaxStreamDataBidiLocalParameterID transportParameterID = 0x5 - initialMaxStreamDataBidiRemoteParameterID transportParameterID = 0x6 - initialMaxStreamDataUniParameterID transportParameterID = 0x7 - initialMaxStreamsBidiParameterID transportParameterID = 0x8 - initialMaxStreamsUniParameterID transportParameterID = 0x9 - ackDelayExponentParameterID transportParameterID = 0xa - maxAckDelayParameterID transportParameterID = 0xb - disableActiveMigrationParameterID transportParameterID = 0xc - preferredAddressParameterID transportParameterID = 0xd - activeConnectionIDLimitParameterID transportParameterID = 0xe - initialSourceConnectionIDParameterID transportParameterID = 0xf - retrySourceConnectionIDParameterID transportParameterID = 0x10 - // RFC 9221 - maxDatagramFrameSizeParameterID transportParameterID = 0x20 -) - -// PreferredAddress is the value encoding in the preferred_address transport parameter -type PreferredAddress struct { - IPv4 net.IP - IPv4Port uint16 - IPv6 net.IP - IPv6Port uint16 - ConnectionID protocol.ConnectionID - StatelessResetToken protocol.StatelessResetToken -} - -// TransportParameters are parameters sent to the peer during the handshake -type TransportParameters struct { - InitialMaxStreamDataBidiLocal protocol.ByteCount - InitialMaxStreamDataBidiRemote protocol.ByteCount - InitialMaxStreamDataUni protocol.ByteCount - InitialMaxData protocol.ByteCount - - MaxAckDelay time.Duration - AckDelayExponent uint8 - - DisableActiveMigration bool - - MaxUDPPayloadSize protocol.ByteCount - - MaxUniStreamNum protocol.StreamNum - MaxBidiStreamNum protocol.StreamNum - - MaxIdleTimeout time.Duration - - PreferredAddress *PreferredAddress - - OriginalDestinationConnectionID protocol.ConnectionID - InitialSourceConnectionID protocol.ConnectionID - RetrySourceConnectionID *protocol.ConnectionID // use a pointer here to distinguish zero-length connection IDs from missing transport parameters - - StatelessResetToken *protocol.StatelessResetToken - ActiveConnectionIDLimit uint64 - - MaxDatagramFrameSize protocol.ByteCount -} - -// Unmarshal the transport parameters -func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error { - if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil { - return &qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: err.Error(), - } - } - return nil -} - -func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error { - // needed to check that every parameter is only sent at most once - var parameterIDs []transportParameterID - - var ( - readOriginalDestinationConnectionID bool - readInitialSourceConnectionID bool - ) - - p.AckDelayExponent = protocol.DefaultAckDelayExponent - p.MaxAckDelay = protocol.DefaultMaxAckDelay - p.MaxDatagramFrameSize = protocol.InvalidByteCount - - for r.Len() > 0 { - paramIDInt, err := quicvarint.Read(r) - if err != nil { - return err - } - paramID := transportParameterID(paramIDInt) - paramLen, err := quicvarint.Read(r) - if err != nil { - return err - } - if uint64(r.Len()) < paramLen { - return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen) - } - parameterIDs = append(parameterIDs, paramID) - switch paramID { - case maxIdleTimeoutParameterID, - maxUDPPayloadSizeParameterID, - initialMaxDataParameterID, - initialMaxStreamDataBidiLocalParameterID, - initialMaxStreamDataBidiRemoteParameterID, - initialMaxStreamDataUniParameterID, - initialMaxStreamsBidiParameterID, - initialMaxStreamsUniParameterID, - maxAckDelayParameterID, - activeConnectionIDLimitParameterID, - maxDatagramFrameSizeParameterID, - ackDelayExponentParameterID: - if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { - return err - } - case preferredAddressParameterID: - if sentBy == protocol.PerspectiveClient { - return errors.New("client sent a preferred_address") - } - if err := p.readPreferredAddress(r, int(paramLen)); err != nil { - return err - } - case disableActiveMigrationParameterID: - if paramLen != 0 { - return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen) - } - p.DisableActiveMigration = true - case statelessResetTokenParameterID: - if sentBy == protocol.PerspectiveClient { - return errors.New("client sent a stateless_reset_token") - } - if paramLen != 16 { - return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen) - } - var token protocol.StatelessResetToken - r.Read(token[:]) - p.StatelessResetToken = &token - case originalDestinationConnectionIDParameterID: - if sentBy == protocol.PerspectiveClient { - return errors.New("client sent an original_destination_connection_id") - } - p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) - readOriginalDestinationConnectionID = true - case initialSourceConnectionIDParameterID: - p.InitialSourceConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) - readInitialSourceConnectionID = true - case retrySourceConnectionIDParameterID: - if sentBy == protocol.PerspectiveClient { - return errors.New("client sent a retry_source_connection_id") - } - connID, _ := protocol.ReadConnectionID(r, int(paramLen)) - p.RetrySourceConnectionID = &connID - default: - r.Seek(int64(paramLen), io.SeekCurrent) - } - } - - if !fromSessionTicket { - if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID { - return errors.New("missing original_destination_connection_id") - } - if p.MaxUDPPayloadSize == 0 { - p.MaxUDPPayloadSize = protocol.MaxByteCount - } - if !readInitialSourceConnectionID { - return errors.New("missing initial_source_connection_id") - } - } - - // check that every transport parameter was sent at most once - sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] }) - for i := 0; i < len(parameterIDs)-1; i++ { - if parameterIDs[i] == parameterIDs[i+1] { - return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i]) - } - } - - return nil -} - -func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error { - remainingLen := r.Len() - pa := &PreferredAddress{} - ipv4 := make([]byte, 4) - if _, err := io.ReadFull(r, ipv4); err != nil { - return err - } - pa.IPv4 = net.IP(ipv4) - port, err := utils.BigEndian.ReadUint16(r) - if err != nil { - return err - } - pa.IPv4Port = port - ipv6 := make([]byte, 16) - if _, err := io.ReadFull(r, ipv6); err != nil { - return err - } - pa.IPv6 = net.IP(ipv6) - port, err = utils.BigEndian.ReadUint16(r) - if err != nil { - return err - } - pa.IPv6Port = port - connIDLen, err := r.ReadByte() - if err != nil { - return err - } - if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d", connIDLen) - } - connID, err := protocol.ReadConnectionID(r, int(connIDLen)) - if err != nil { - return err - } - pa.ConnectionID = connID - if _, err := io.ReadFull(r, pa.StatelessResetToken[:]); err != nil { - return err - } - if bytesRead := remainingLen - r.Len(); bytesRead != expectedLen { - return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead) - } - p.PreferredAddress = pa - return nil -} - -func (p *TransportParameters) readNumericTransportParameter( - r *bytes.Reader, - paramID transportParameterID, - expectedLen int, -) error { - remainingLen := r.Len() - val, err := quicvarint.Read(r) - if err != nil { - return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err) - } - if remainingLen-r.Len() != expectedLen { - return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID) - } - //nolint:exhaustive // This only covers the numeric transport parameters. - switch paramID { - case initialMaxStreamDataBidiLocalParameterID: - p.InitialMaxStreamDataBidiLocal = protocol.ByteCount(val) - case initialMaxStreamDataBidiRemoteParameterID: - p.InitialMaxStreamDataBidiRemote = protocol.ByteCount(val) - case initialMaxStreamDataUniParameterID: - p.InitialMaxStreamDataUni = protocol.ByteCount(val) - case initialMaxDataParameterID: - p.InitialMaxData = protocol.ByteCount(val) - case initialMaxStreamsBidiParameterID: - p.MaxBidiStreamNum = protocol.StreamNum(val) - if p.MaxBidiStreamNum > protocol.MaxStreamCount { - return fmt.Errorf("initial_max_streams_bidi too large: %d (maximum %d)", p.MaxBidiStreamNum, protocol.MaxStreamCount) - } - case initialMaxStreamsUniParameterID: - p.MaxUniStreamNum = protocol.StreamNum(val) - if p.MaxUniStreamNum > protocol.MaxStreamCount { - return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount) - } - case maxIdleTimeoutParameterID: - p.MaxIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) - case maxUDPPayloadSizeParameterID: - if val < 1200 { - return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val) - } - p.MaxUDPPayloadSize = protocol.ByteCount(val) - case ackDelayExponentParameterID: - if val > protocol.MaxAckDelayExponent { - return fmt.Errorf("invalid value for ack_delay_exponent: %d (maximum %d)", val, protocol.MaxAckDelayExponent) - } - p.AckDelayExponent = uint8(val) - case maxAckDelayParameterID: - if val > uint64(protocol.MaxMaxAckDelay/time.Millisecond) { - return fmt.Errorf("invalid value for max_ack_delay: %dms (maximum %dms)", val, protocol.MaxMaxAckDelay/time.Millisecond) - } - p.MaxAckDelay = time.Duration(val) * time.Millisecond - case activeConnectionIDLimitParameterID: - p.ActiveConnectionIDLimit = val - case maxDatagramFrameSizeParameterID: - p.MaxDatagramFrameSize = protocol.ByteCount(val) - default: - return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID) - } - return nil -} - -// Marshal the transport parameters -func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { - b := &bytes.Buffer{} - - // add a greased value - quicvarint.Write(b, uint64(27+31*rand.Intn(100))) - length := rand.Intn(16) - randomData := make([]byte, length) - rand.Read(randomData) - quicvarint.Write(b, uint64(length)) - b.Write(randomData) - - // initial_max_stream_data_bidi_local - p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) - // initial_max_stream_data_bidi_remote - p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) - // initial_max_stream_data_uni - p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) - // initial_max_data - p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) - // initial_max_bidi_streams - p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) - // initial_max_uni_streams - p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) - // idle_timeout - p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond)) - // max_packet_size - p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize)) - // max_ack_delay - // Only send it if is different from the default value. - if p.MaxAckDelay != protocol.DefaultMaxAckDelay { - p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) - } - // ack_delay_exponent - // Only send it if is different from the default value. - if p.AckDelayExponent != protocol.DefaultAckDelayExponent { - p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) - } - // disable_active_migration - if p.DisableActiveMigration { - quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) - quicvarint.Write(b, 0) - } - if pers == protocol.PerspectiveServer { - // stateless_reset_token - if p.StatelessResetToken != nil { - quicvarint.Write(b, uint64(statelessResetTokenParameterID)) - quicvarint.Write(b, 16) - b.Write(p.StatelessResetToken[:]) - } - // original_destination_connection_id - quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.OriginalDestinationConnectionID.Len())) - b.Write(p.OriginalDestinationConnectionID.Bytes()) - // preferred_address - if p.PreferredAddress != nil { - quicvarint.Write(b, uint64(preferredAddressParameterID)) - quicvarint.Write(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) - ipv4 := p.PreferredAddress.IPv4 - b.Write(ipv4[len(ipv4)-4:]) - utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port) - b.Write(p.PreferredAddress.IPv6) - utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port) - b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len())) - b.Write(p.PreferredAddress.ConnectionID.Bytes()) - b.Write(p.PreferredAddress.StatelessResetToken[:]) - } - } - // active_connection_id_limit - p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) - // initial_source_connection_id - quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.InitialSourceConnectionID.Len())) - b.Write(p.InitialSourceConnectionID.Bytes()) - // retry_source_connection_id - if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil { - quicvarint.Write(b, uint64(retrySourceConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.RetrySourceConnectionID.Len())) - b.Write(p.RetrySourceConnectionID.Bytes()) - } - if p.MaxDatagramFrameSize != protocol.InvalidByteCount { - p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) - } - return b.Bytes() -} - -func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportParameterID, val uint64) { - quicvarint.Write(b, uint64(id)) - quicvarint.Write(b, uint64(quicvarint.Len(val))) - quicvarint.Write(b, val) -} - -// MarshalForSessionTicket marshals the transport parameters we save in the session ticket. -// When sending a 0-RTT enabled TLS session tickets, we need to save the transport parameters. -// The client will remember the transport parameters used in the last session, -// and apply those to the 0-RTT data it sends. -// Saving the transport parameters in the ticket gives the server the option to reject 0-RTT -// if the transport parameters changed. -// Since the session ticket is encrypted, the serialization format is defined by the server. -// For convenience, we use the same format that we also use for sending the transport parameters. -func (p *TransportParameters) MarshalForSessionTicket(b *bytes.Buffer) { - quicvarint.Write(b, transportParameterMarshalingVersion) - - // initial_max_stream_data_bidi_local - p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) - // initial_max_stream_data_bidi_remote - p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) - // initial_max_stream_data_uni - p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) - // initial_max_data - p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) - // initial_max_bidi_streams - p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) - // initial_max_uni_streams - p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) - // active_connection_id_limit - p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) -} - -// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket. -func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error { - version, err := quicvarint.Read(r) - if err != nil { - return err - } - if version != transportParameterMarshalingVersion { - return fmt.Errorf("unknown transport parameter marshaling version: %d", version) - } - return p.unmarshal(r, protocol.PerspectiveServer, true) -} - -// ValidFor0RTT checks if the transport parameters match those saved in the session ticket. -func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { - return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && - p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && - p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && - p.InitialMaxData >= saved.InitialMaxData && - p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && - p.MaxUniStreamNum >= saved.MaxUniStreamNum && - p.ActiveConnectionIDLimit == saved.ActiveConnectionIDLimit -} - -// String returns a string representation, intended for logging. -func (p *TransportParameters) String() string { - logString := "&wire.TransportParameters{OriginalDestinationConnectionID: %s, InitialSourceConnectionID: %s, " - logParams := []interface{}{p.OriginalDestinationConnectionID, p.InitialSourceConnectionID} - if p.RetrySourceConnectionID != nil { - logString += "RetrySourceConnectionID: %s, " - logParams = append(logParams, p.RetrySourceConnectionID) - } - logString += "InitialMaxStreamDataBidiLocal: %d, InitialMaxStreamDataBidiRemote: %d, InitialMaxStreamDataUni: %d, InitialMaxData: %d, MaxBidiStreamNum: %d, MaxUniStreamNum: %d, MaxIdleTimeout: %s, AckDelayExponent: %d, MaxAckDelay: %s, ActiveConnectionIDLimit: %d" - logParams = append(logParams, []interface{}{p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreamNum, p.MaxUniStreamNum, p.MaxIdleTimeout, p.AckDelayExponent, p.MaxAckDelay, p.ActiveConnectionIDLimit}...) - if p.StatelessResetToken != nil { // the client never sends a stateless reset token - logString += ", StatelessResetToken: %#x" - logParams = append(logParams, *p.StatelessResetToken) - } - if p.MaxDatagramFrameSize != protocol.InvalidByteCount { - logString += ", MaxDatagramFrameSize: %d" - logParams = append(logParams, p.MaxDatagramFrameSize) - } - logString += "}" - return fmt.Sprintf(logString, logParams...) -} diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go deleted file mode 100644 index 971571dd..00000000 --- a/internal/wire/version_negotiation.go +++ /dev/null @@ -1,55 +0,0 @@ -package wire - -import ( - "bytes" - "crypto/rand" - "errors" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/utils" -) - -// ParseVersionNegotiationPacket parses a Version Negotiation packet. -func ParseVersionNegotiationPacket(b *bytes.Reader) (*Header, []quic.VersionNumber, error) { - hdr, err := parseHeader(b, 0) - if err != nil { - return nil, nil, err - } - if b.Len() == 0 { - //nolint:stylecheck - return nil, nil, errors.New("Version Negotiation packet has empty version list") - } - if b.Len()%4 != 0 { - //nolint:stylecheck - return nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") - } - versions := make([]quic.VersionNumber, b.Len()/4) - for i := 0; b.Len() > 0; i++ { - v, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return nil, nil, err - } - versions[i] = quic.VersionNumber(v) - } - return hdr, versions, nil -} - -// ComposeVersionNegotiation composes a Version Negotiation -func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []quic.VersionNumber) []byte { - greasedVersions := protocol.GetGreasedVersions(versions) - expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 - buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) - r := make([]byte, 1) - _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. - buf.WriteByte(r[0] | 0x80) - utils.BigEndian.WriteUint32(buf, 0) // version 0 - buf.WriteByte(uint8(destConnID.Len())) - buf.Write(destConnID) - buf.WriteByte(uint8(srcConnID.Len())) - buf.Write(srcConnID) - for _, v := range greasedVersions { - utils.BigEndian.WriteUint32(buf, uint32(v)) - } - return buf.Bytes() -} diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go deleted file mode 100644 index 6a2a62b9..00000000 --- a/internal/wire/version_negotiation_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package wire - -import ( - "bytes" - "encoding/binary" - "github.com/lucas-clemente/quic-go" - - "github.com/imroc/req/v3/internal/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Version Negotiation Packets", func() { - It("parses a Version Negotiation packet", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} - destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} - versions := []quic.VersionNumber{0x22334455, 0x33445566} - data := []byte{0x80, 0, 0, 0, 0} - data = append(data, uint8(len(destConnID))) - data = append(data, destConnID...) - data = append(data, uint8(len(srcConnID))) - data = append(data, srcConnID...) - for _, v := range versions { - data = append(data, []byte{0, 0, 0, 0}...) - binary.BigEndian.PutUint32(data[len(data)-4:], uint32(v)) - } - Expect(IsVersionNegotiationPacket(data)).To(BeTrue()) - hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(destConnID)) - Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) - Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.Version).To(BeZero()) - Expect(supportedVersions).To(Equal(versions)) - }) - - It("errors if it contains versions of the wrong length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - versions := []quic.VersionNumber{0x22334455, 0x33445566} - data := ComposeVersionNegotiation(connID, connID, versions) - _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data[:len(data)-2])) - Expect(err).To(MatchError("Version Negotiation packet has a version list with an invalid length")) - }) - - It("errors if the version list is empty", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - versions := []quic.VersionNumber{0x22334455} - data := ComposeVersionNegotiation(connID, connID, versions) - // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number - data = data[:len(data)-8] - _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) - Expect(err).To(MatchError("Version Negotiation packet has empty version list")) - }) - - It("adds a reserved version", func() { - srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - versions := []quic.VersionNumber{1001, 1003} - data := ComposeVersionNegotiation(destConnID, srcConnID, versions) - Expect(data[0] & 0x80).ToNot(BeZero()) - hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(destConnID)) - Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) - Expect(hdr.Version).To(BeZero()) - // the supported versions should include one reserved version number - Expect(supportedVersions).To(HaveLen(len(versions) + 1)) - for _, v := range versions { - Expect(supportedVersions).To(ContainElement(v)) - } - var reservedVersion quic.VersionNumber - versionLoop: - for _, ver := range supportedVersions { - for _, v := range versions { - if v == ver { - continue versionLoop - } - } - reservedVersion = ver - } - Expect(reservedVersion).ToNot(BeZero()) - Expect(reservedVersion&0x0f0f0f0f == 0x0a0a0a0a).To(BeTrue()) // check that it's a greased version number - }) -}) diff --git a/internal/wire/wire_suite_test.go b/internal/wire/wire_suite_test.go deleted file mode 100644 index 9042747a..00000000 --- a/internal/wire/wire_suite_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package wire - -import ( - "bytes" - "encoding/binary" - "github.com/lucas-clemente/quic-go" - "testing" - - "github.com/imroc/req/v3/internal/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestWire(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Wire Suite") -} - -func encodeVarInt(i uint64) []byte { - b := &bytes.Buffer{} - quicvarint.Write(b, i) - return b.Bytes() -} - -func appendVersion(data []byte, v quic.VersionNumber) []byte { - offset := len(data) - data = append(data, []byte{0, 0, 0, 0}...) - binary.BigEndian.PutUint32(data[offset:], uint32(v)) - return data -} From 35d4c97dc9f56ce1c45b8b7552fa8a47a7440f3f Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 5 Jul 2022 11:54:31 +0800 Subject: [PATCH 523/843] updat go mod --- go.mod | 1 + 1 file changed, 1 insertion(+) diff --git a/go.mod b/go.mod index 8440cca7..a1b1edd0 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/marten-seemann/qtls-go1-16 v0.1.5 github.com/marten-seemann/qtls-go1-17 v0.1.2 github.com/marten-seemann/qtls-go1-18 v0.1.2 + github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.13.0 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect From ceb0fa2433837638be3ac5364d5d73478365ad9d Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 5 Jul 2022 12:07:29 +0800 Subject: [PATCH 524/843] optimize Client.Clone() --- client.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index a6b00754..f36f6f9f 100644 --- a/client.go +++ b/client.go @@ -8,7 +8,6 @@ import ( "encoding/xml" "errors" "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/http2" "github.com/imroc/req/v3/internal/util" "golang.org/x/net/publicsuffix" "io" @@ -880,10 +879,17 @@ func NewClient() *Client { // Clone copy and returns the Client func (c *Client) Clone() *Client { t := c.t.Clone() - t2 := &http2.Transport{ - Options: &t.Options, + opt := t.Options + if t.t2 != nil { + t2 := *t.t2 + t2.Options = &opt + t.t2 = &t2 + } + if t.t3 != nil { + t3 := *t.t3 + t3.Options = &opt + t.t3 = &t3 } - t.t2 = t2 client := *c.httpClient client.Transport = t From 5e6314a94b7c5b45f2369672f6aa4b7b38df4ba0 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 5 Jul 2022 12:14:47 +0800 Subject: [PATCH 525/843] optimize Client.Clone() --- client.go | 11 ----------- transport.go | 27 ++++++++++++++++----------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/client.go b/client.go index f36f6f9f..333df2b2 100644 --- a/client.go +++ b/client.go @@ -879,17 +879,6 @@ func NewClient() *Client { // Clone copy and returns the Client func (c *Client) Clone() *Client { t := c.t.Clone() - opt := t.Options - if t.t2 != nil { - t2 := *t.t2 - t2.Options = &opt - t.t2 = &t2 - } - if t.t3 != nil { - t3 := *t.t3 - t3.Options = &opt - t.t3 = &t3 - } client := *c.httpClient client.Transport = t diff --git a/transport.go b/transport.go index 7b13bfc7..e1e7b6d4 100644 --- a/transport.go +++ b/transport.go @@ -472,20 +472,25 @@ func (t *Transport) readBufferSize() int { // Clone returns a deep copy of t's exported fields. func (t *Transport) Clone() *Transport { - t2 := &Transport{ - Options: t.Options, - ForceHttpVersion: t.ForceHttpVersion, - disableAutoDecode: t.disableAutoDecode, - autoDecodeContentType: t.autoDecodeContentType, + tt := *t + tt.Dump = t.Dump.Clone() + if tt.Dump != nil { + go tt.Dump.Start() } - t2.Options.Dump = t.Options.Dump.Clone() - if t.Dump != nil { - go t.Dump.Start() + if tt.TLSClientConfig != nil { + tt.TLSClientConfig = tt.TLSClientConfig.Clone() + } + if t.t2 != nil { + t2 := *t.t2 + t2.Options = &tt.Options + tt.t2 = &t2 } - if t.TLSClientConfig != nil { - t2.TLSClientConfig = t.TLSClientConfig.Clone() + if t.t3 != nil { + t3 := *t.t3 + t3.Options = &tt.Options + tt.t3 = &t3 } - return t2 + return &tt } // EnableDump enables the dump for all requests with specified dump options. From 2c44a59874e593e6db3faee03e78e5c03f48dc57 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 5 Jul 2022 16:41:44 +0800 Subject: [PATCH 526/843] remove unused code --- internal/protocol/connection_id.go | 69 ------ internal/protocol/connection_id_test.go | 108 --------- internal/protocol/encryption_level.go | 30 --- internal/protocol/encryption_level_test.go | 20 -- internal/protocol/key_phase.go | 36 --- internal/protocol/key_phase_test.go | 27 --- internal/protocol/packet_number.go | 79 ------- internal/protocol/packet_number_test.go | 204 ----------------- internal/protocol/params.go | 193 ----------------- internal/protocol/params_test.go | 13 -- internal/protocol/perspective.go | 26 --- internal/protocol/perspective_test.go | 19 -- internal/protocol/protocol.go | 93 -------- internal/protocol/protocol_suite_test.go | 13 -- internal/protocol/protocol_test.go | 25 --- internal/protocol/stream.go | 76 ------- internal/protocol/stream_test.go | 70 ------ internal/protocol/version.go | 72 +----- internal/protocol/version_test.go | 122 ----------- internal/utils/atomic_bool.go | 22 -- internal/utils/atomic_bool_test.go | 29 --- internal/utils/buffered_write_closer.go | 26 --- internal/utils/buffered_write_closer_test.go | 26 --- internal/utils/byteinterval_linkedlist.go | 217 ------------------- internal/utils/byteoder_big_endian_test.go | 107 --------- internal/utils/byteorder.go | 17 -- internal/utils/byteorder_big_endian.go | 89 -------- internal/utils/ip.go | 10 - internal/utils/ip_test.go | 17 -- internal/utils/minmax.go | 170 --------------- internal/utils/minmax_test.go | 123 ----------- internal/utils/new_connection_id.go | 12 - internal/utils/newconnectionid_linkedlist.go | 217 ------------------- internal/utils/packet_interval.go | 9 - internal/utils/packetinterval_linkedlist.go | 217 ------------------- internal/utils/rand.go | 29 --- internal/utils/rand_test.go | 32 --- internal/utils/rtt_stats.go | 127 ----------- internal/utils/rtt_stats_test.go | 157 -------------- internal/utils/streamframe_interval.go | 9 - internal/utils/timer.go | 53 ----- internal/utils/timer_test.go | 87 -------- 42 files changed, 4 insertions(+), 3093 deletions(-) delete mode 100644 internal/protocol/connection_id.go delete mode 100644 internal/protocol/connection_id_test.go delete mode 100644 internal/protocol/encryption_level.go delete mode 100644 internal/protocol/encryption_level_test.go delete mode 100644 internal/protocol/key_phase.go delete mode 100644 internal/protocol/key_phase_test.go delete mode 100644 internal/protocol/packet_number.go delete mode 100644 internal/protocol/packet_number_test.go delete mode 100644 internal/protocol/params.go delete mode 100644 internal/protocol/params_test.go delete mode 100644 internal/protocol/perspective.go delete mode 100644 internal/protocol/perspective_test.go delete mode 100644 internal/protocol/protocol_suite_test.go delete mode 100644 internal/protocol/protocol_test.go delete mode 100644 internal/protocol/stream.go delete mode 100644 internal/protocol/stream_test.go delete mode 100644 internal/protocol/version_test.go delete mode 100644 internal/utils/atomic_bool.go delete mode 100644 internal/utils/atomic_bool_test.go delete mode 100644 internal/utils/buffered_write_closer.go delete mode 100644 internal/utils/buffered_write_closer_test.go delete mode 100644 internal/utils/byteinterval_linkedlist.go delete mode 100644 internal/utils/byteoder_big_endian_test.go delete mode 100644 internal/utils/byteorder.go delete mode 100644 internal/utils/byteorder_big_endian.go delete mode 100644 internal/utils/ip.go delete mode 100644 internal/utils/ip_test.go delete mode 100644 internal/utils/minmax.go delete mode 100644 internal/utils/minmax_test.go delete mode 100644 internal/utils/new_connection_id.go delete mode 100644 internal/utils/newconnectionid_linkedlist.go delete mode 100644 internal/utils/packet_interval.go delete mode 100644 internal/utils/packetinterval_linkedlist.go delete mode 100644 internal/utils/rand.go delete mode 100644 internal/utils/rand_test.go delete mode 100644 internal/utils/rtt_stats.go delete mode 100644 internal/utils/rtt_stats_test.go delete mode 100644 internal/utils/streamframe_interval.go delete mode 100644 internal/utils/timer.go delete mode 100644 internal/utils/timer_test.go diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go deleted file mode 100644 index 3aec2cd3..00000000 --- a/internal/protocol/connection_id.go +++ /dev/null @@ -1,69 +0,0 @@ -package protocol - -import ( - "bytes" - "crypto/rand" - "fmt" - "io" -) - -// A ConnectionID in QUIC -type ConnectionID []byte - -const maxConnectionIDLen = 20 - -// GenerateConnectionID generates a connection ID using cryptographic random -func GenerateConnectionID(len int) (ConnectionID, error) { - b := make([]byte, len) - if _, err := rand.Read(b); err != nil { - return nil, err - } - return ConnectionID(b), nil -} - -// GenerateConnectionIDForInitial generates a connection ID for the Initial packet. -// It uses a length randomly chosen between 8 and 20 bytes. -func GenerateConnectionIDForInitial() (ConnectionID, error) { - r := make([]byte, 1) - if _, err := rand.Read(r); err != nil { - return nil, err - } - len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) - return GenerateConnectionID(len) -} - -// ReadConnectionID reads a connection ID of length len from the given io.Reader. -// It returns io.EOF if there are not enough bytes to read. -func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { - if len == 0 { - return nil, nil - } - c := make(ConnectionID, len) - _, err := io.ReadFull(r, c) - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return c, err -} - -// Equal says if two connection IDs are equal -func (c ConnectionID) Equal(other ConnectionID) bool { - return bytes.Equal(c, other) -} - -// Len returns the length of the connection ID in bytes -func (c ConnectionID) Len() int { - return len(c) -} - -// Bytes returns the byte representation -func (c ConnectionID) Bytes() []byte { - return []byte(c) -} - -func (c ConnectionID) String() string { - if c.Len() == 0 { - return "(empty)" - } - return fmt.Sprintf("%x", c.Bytes()) -} diff --git a/internal/protocol/connection_id_test.go b/internal/protocol/connection_id_test.go deleted file mode 100644 index 345e656c..00000000 --- a/internal/protocol/connection_id_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package protocol - -import ( - "bytes" - "io" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Connection ID generation", func() { - It("generates random connection IDs", func() { - c1, err := GenerateConnectionID(8) - Expect(err).ToNot(HaveOccurred()) - Expect(c1).ToNot(BeZero()) - c2, err := GenerateConnectionID(8) - Expect(err).ToNot(HaveOccurred()) - Expect(c1).ToNot(Equal(c2)) - }) - - It("generates connection IDs with the requested length", func() { - c, err := GenerateConnectionID(5) - Expect(err).ToNot(HaveOccurred()) - Expect(c.Len()).To(Equal(5)) - }) - - It("generates random length destination connection IDs", func() { - var has8ByteConnID, has20ByteConnID bool - for i := 0; i < 1000; i++ { - c, err := GenerateConnectionIDForInitial() - Expect(err).ToNot(HaveOccurred()) - Expect(c.Len()).To(BeNumerically(">=", 8)) - Expect(c.Len()).To(BeNumerically("<=", 20)) - if c.Len() == 8 { - has8ByteConnID = true - } - if c.Len() == 20 { - has20ByteConnID = true - } - } - Expect(has8ByteConnID).To(BeTrue()) - Expect(has20ByteConnID).To(BeTrue()) - }) - - It("says if connection IDs are equal", func() { - c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - Expect(c1.Equal(c1)).To(BeTrue()) - Expect(c2.Equal(c2)).To(BeTrue()) - Expect(c1.Equal(c2)).To(BeFalse()) - Expect(c2.Equal(c1)).To(BeFalse()) - }) - - It("reads the connection ID", func() { - buf := bytes.NewBuffer([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) - c, err := ReadConnectionID(buf, 9) - Expect(err).ToNot(HaveOccurred()) - Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})) - }) - - It("returns io.EOF if there's not enough data to read", func() { - buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) - _, err := ReadConnectionID(buf, 5) - Expect(err).To(MatchError(io.EOF)) - }) - - It("returns nil for a 0 length connection ID", func() { - buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) - c, err := ReadConnectionID(buf, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(c).To(BeNil()) - }) - - It("returns the length", func() { - c := ConnectionID{1, 2, 3, 4, 5, 6, 7} - Expect(c.Len()).To(Equal(7)) - }) - - It("has 0 length for the default value", func() { - var c ConnectionID - Expect(c.Len()).To(BeZero()) - }) - - It("returns the bytes", func() { - c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) - Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7})) - }) - - It("returns a nil byte slice for the default value", func() { - var c ConnectionID - Expect(c.Bytes()).To(BeNil()) - }) - - It("has a string representation", func() { - c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) - Expect(c.String()).To(Equal("deadbeef42")) - }) - - It("has a long string representation", func() { - c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad} - Expect(c.String()).To(Equal("13370000decafbad")) - }) - - It("has a string representation for the default value", func() { - var c ConnectionID - Expect(c.String()).To(Equal("(empty)")) - }) -}) diff --git a/internal/protocol/encryption_level.go b/internal/protocol/encryption_level.go deleted file mode 100644 index 32d38ab1..00000000 --- a/internal/protocol/encryption_level.go +++ /dev/null @@ -1,30 +0,0 @@ -package protocol - -// EncryptionLevel is the encryption level -// Default value is Unencrypted -type EncryptionLevel uint8 - -const ( - // EncryptionInitial is the Initial encryption level - EncryptionInitial EncryptionLevel = 1 + iota - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT - // Encryption1RTT is the 1-RTT encryption level - Encryption1RTT -) - -func (e EncryptionLevel) String() string { - switch e { - case EncryptionInitial: - return "Initial" - case EncryptionHandshake: - return "Handshake" - case Encryption0RTT: - return "0-RTT" - case Encryption1RTT: - return "1-RTT" - } - return "unknown" -} diff --git a/internal/protocol/encryption_level_test.go b/internal/protocol/encryption_level_test.go deleted file mode 100644 index 9b07b08b..00000000 --- a/internal/protocol/encryption_level_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Encryption Level", func() { - It("doesn't use 0 as a value", func() { - // 0 is used in some tests - Expect(EncryptionInitial * EncryptionHandshake * Encryption0RTT * Encryption1RTT).ToNot(BeZero()) - }) - - It("has the correct string representation", func() { - Expect(EncryptionInitial.String()).To(Equal("Initial")) - Expect(EncryptionHandshake.String()).To(Equal("Handshake")) - Expect(Encryption0RTT.String()).To(Equal("0-RTT")) - Expect(Encryption1RTT.String()).To(Equal("1-RTT")) - }) -}) diff --git a/internal/protocol/key_phase.go b/internal/protocol/key_phase.go deleted file mode 100644 index edd740cf..00000000 --- a/internal/protocol/key_phase.go +++ /dev/null @@ -1,36 +0,0 @@ -package protocol - -// KeyPhase is the key phase -type KeyPhase uint64 - -// Bit determines the key phase bit -func (p KeyPhase) Bit() KeyPhaseBit { - if p%2 == 0 { - return KeyPhaseZero - } - return KeyPhaseOne -} - -// KeyPhaseBit is the key phase bit -type KeyPhaseBit uint8 - -const ( - // KeyPhaseUndefined is an undefined key phase - KeyPhaseUndefined KeyPhaseBit = iota - // KeyPhaseZero is key phase 0 - KeyPhaseZero - // KeyPhaseOne is key phase 1 - KeyPhaseOne -) - -func (p KeyPhaseBit) String() string { - //nolint:exhaustive - switch p { - case KeyPhaseZero: - return "0" - case KeyPhaseOne: - return "1" - default: - return "undefined" - } -} diff --git a/internal/protocol/key_phase_test.go b/internal/protocol/key_phase_test.go deleted file mode 100644 index 92f404a5..00000000 --- a/internal/protocol/key_phase_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Key Phases", func() { - It("has undefined as its default value", func() { - var k KeyPhaseBit - Expect(k).To(Equal(KeyPhaseUndefined)) - }) - - It("has the correct string representation", func() { - Expect(KeyPhaseZero.String()).To(Equal("0")) - Expect(KeyPhaseOne.String()).To(Equal("1")) - }) - - It("converts the key phase to the key phase bit", func() { - Expect(KeyPhase(0).Bit()).To(Equal(KeyPhaseZero)) - Expect(KeyPhase(2).Bit()).To(Equal(KeyPhaseZero)) - Expect(KeyPhase(4).Bit()).To(Equal(KeyPhaseZero)) - Expect(KeyPhase(1).Bit()).To(Equal(KeyPhaseOne)) - Expect(KeyPhase(3).Bit()).To(Equal(KeyPhaseOne)) - Expect(KeyPhase(5).Bit()).To(Equal(KeyPhaseOne)) - }) -}) diff --git a/internal/protocol/packet_number.go b/internal/protocol/packet_number.go deleted file mode 100644 index bd340161..00000000 --- a/internal/protocol/packet_number.go +++ /dev/null @@ -1,79 +0,0 @@ -package protocol - -// A PacketNumber in QUIC -type PacketNumber int64 - -// InvalidPacketNumber is a packet number that is never sent. -// In QUIC, 0 is a valid packet number. -const InvalidPacketNumber PacketNumber = -1 - -// PacketNumberLen is the length of the packet number in bytes -type PacketNumberLen uint8 - -const ( - // PacketNumberLen1 is a packet number length of 1 byte - PacketNumberLen1 PacketNumberLen = 1 - // PacketNumberLen2 is a packet number length of 2 bytes - PacketNumberLen2 PacketNumberLen = 2 - // PacketNumberLen3 is a packet number length of 3 bytes - PacketNumberLen3 PacketNumberLen = 3 - // PacketNumberLen4 is a packet number length of 4 bytes - PacketNumberLen4 PacketNumberLen = 4 -) - -// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number -func DecodePacketNumber( - packetNumberLength PacketNumberLen, - lastPacketNumber PacketNumber, - wirePacketNumber PacketNumber, -) PacketNumber { - var epochDelta PacketNumber - switch packetNumberLength { - case PacketNumberLen1: - epochDelta = PacketNumber(1) << 8 - case PacketNumberLen2: - epochDelta = PacketNumber(1) << 16 - case PacketNumberLen3: - epochDelta = PacketNumber(1) << 24 - case PacketNumberLen4: - epochDelta = PacketNumber(1) << 32 - } - epoch := lastPacketNumber & ^(epochDelta - 1) - var prevEpochBegin PacketNumber - if epoch > epochDelta { - prevEpochBegin = epoch - epochDelta - } - nextEpochBegin := epoch + epochDelta - return closestTo( - lastPacketNumber+1, - epoch+wirePacketNumber, - closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber), - ) -} - -func closestTo(target, a, b PacketNumber) PacketNumber { - if delta(target, a) < delta(target, b) { - return a - } - return b -} - -func delta(a, b PacketNumber) PacketNumber { - if a < b { - return b - a - } - return a - b -} - -// GetPacketNumberLengthForHeader gets the length of the packet number for the public header -// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances -func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen { - diff := uint64(packetNumber - leastUnacked) - if diff < (1 << (16 - 1)) { - return PacketNumberLen2 - } - if diff < (1 << (24 - 1)) { - return PacketNumberLen3 - } - return PacketNumberLen4 -} diff --git a/internal/protocol/packet_number_test.go b/internal/protocol/packet_number_test.go deleted file mode 100644 index d3bfe1d5..00000000 --- a/internal/protocol/packet_number_test.go +++ /dev/null @@ -1,204 +0,0 @@ -package protocol - -import ( - "fmt" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -// Tests taken and extended from chrome -var _ = Describe("packet number calculation", func() { - It("InvalidPacketNumber is smaller than all valid packet numbers", func() { - Expect(InvalidPacketNumber).To(BeNumerically("<", 0)) - }) - - It("works with the example from the draft", func() { - Expect(DecodePacketNumber(PacketNumberLen2, 0xa82f30ea, 0x9b32)).To(Equal(PacketNumber(0xa82f9b32))) - }) - - It("works with the examples from the draft", func() { - Expect(GetPacketNumberLengthForHeader(0xac5c02, 0xabe8b3)).To(Equal(PacketNumberLen2)) - Expect(GetPacketNumberLengthForHeader(0xace8fe, 0xabe8b3)).To(Equal(PacketNumberLen3)) - }) - - getEpoch := func(len PacketNumberLen) uint64 { - if len > 4 { - Fail("invalid packet number len") - } - return uint64(1) << (len * 8) - } - - check := func(length PacketNumberLen, expected, last uint64) { - epoch := getEpoch(length) - epochMask := epoch - 1 - wirePacketNumber := expected & epochMask - ExpectWithOffset(1, DecodePacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected))) - } - - for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen3, PacketNumberLen4} { - length := l - - Context(fmt.Sprintf("with %d bytes", length), func() { - epoch := getEpoch(length) - epochMask := epoch - 1 - - It("works near epoch start", func() { - // A few quick manual sanity check - check(length, 1, 0) - check(length, epoch+1, epochMask) - check(length, epoch, epochMask) - - // Cases where the last number was close to the start of the range. - for last := uint64(0); last < 10; last++ { - // Small numbers should not wrap (even if they're out of order). - for j := uint64(0); j < 10; j++ { - check(length, j, last) - } - - // Large numbers should not wrap either (because we're near 0 already). - for j := uint64(0); j < 10; j++ { - check(length, epoch-1-j, last) - } - } - }) - - It("works near epoch end", func() { - // Cases where the last number was close to the end of the range - for i := uint64(0); i < 10; i++ { - last := epoch - i - - // Small numbers should wrap. - for j := uint64(0); j < 10; j++ { - check(length, epoch+j, last) - } - - // Large numbers should not (even if they're out of order). - for j := uint64(0); j < 10; j++ { - check(length, epoch-1-j, last) - } - } - }) - - // Next check where we're in a non-zero epoch to verify we handle - // reverse wrapping, too. - It("works near previous epoch", func() { - prevEpoch := 1 * epoch - curEpoch := 2 * epoch - // Cases where the last number was close to the start of the range - for i := uint64(0); i < 10; i++ { - last := curEpoch + i - // Small number should not wrap (even if they're out of order). - for j := uint64(0); j < 10; j++ { - check(length, curEpoch+j, last) - } - - // But large numbers should reverse wrap. - for j := uint64(0); j < 10; j++ { - num := epoch - 1 - j - check(length, prevEpoch+num, last) - } - } - }) - - It("works near next epoch", func() { - curEpoch := 2 * epoch - nextEpoch := 3 * epoch - // Cases where the last number was close to the end of the range - for i := uint64(0); i < 10; i++ { - last := nextEpoch - 1 - i - - // Small numbers should wrap. - for j := uint64(0); j < 10; j++ { - check(length, nextEpoch+j, last) - } - - // but large numbers should not (even if they're out of order). - for j := uint64(0); j < 10; j++ { - num := epoch - 1 - j - check(length, curEpoch+num, last) - } - } - }) - - Context("shortening a packet number for the header", func() { - Context("shortening", func() { - It("sends out low packet numbers as 2 byte", func() { - length := GetPacketNumberLengthForHeader(4, 2) - Expect(length).To(Equal(PacketNumberLen2)) - }) - - It("sends out high packet numbers as 2 byte, if all ACKs are received", func() { - length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1) - Expect(length).To(Equal(PacketNumberLen2)) - }) - - It("sends out higher packet numbers as 3 bytes, if a lot of ACKs are missing", func() { - length := GetPacketNumberLengthForHeader(40000, 2) - Expect(length).To(Equal(PacketNumberLen3)) - }) - - It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() { - length := GetPacketNumberLengthForHeader(40000000, 2) - Expect(length).To(Equal(PacketNumberLen4)) - }) - }) - - Context("self-consistency", func() { - It("works for small packet numbers", func() { - for i := uint64(1); i < 10000; i++ { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(1) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - } - }) - - It("works for small packet numbers and increasing ACKed packets", func() { - for i := uint64(1); i < 10000; i++ { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(i / 2) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - epochMask := getEpoch(length) - 1 - wirePacketNumber := uint64(packetNumber) & epochMask - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - } - }) - - It("also works for larger packet numbers", func() { - var increment uint64 - for i := uint64(1); i < getEpoch(PacketNumberLen4); i += increment { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(1) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - epochMask := getEpoch(length) - 1 - wirePacketNumber := uint64(packetNumber) & epochMask - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - - increment = getEpoch(length) / 8 - } - }) - - It("works for packet numbers larger than 2^48", func() { - for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(i - 1000) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - } - }) - }) - }) - }) - } -}) diff --git a/internal/protocol/params.go b/internal/protocol/params.go deleted file mode 100644 index 83137113..00000000 --- a/internal/protocol/params.go +++ /dev/null @@ -1,193 +0,0 @@ -package protocol - -import "time" - -// DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use. -const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB - -// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets. -const InitialPacketSizeIPv4 = 1252 - -// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets. -const InitialPacketSizeIPv6 = 1232 - -// MaxCongestionWindowPackets is the maximum congestion window in packet. -const MaxCongestionWindowPackets = 10000 - -// MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the connection. -const MaxUndecryptablePackets = 32 - -// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window -// This is the value that Chromium is using -const ConnectionFlowControlMultiplier = 1.5 - -// DefaultInitialMaxStreamData is the default initial stream-level flow control window for receiving data -const DefaultInitialMaxStreamData = (1 << 10) * 512 // 512 kb - -// DefaultInitialMaxData is the connection-level flow control window for receiving data -const DefaultInitialMaxData = ConnectionFlowControlMultiplier * DefaultInitialMaxStreamData - -// DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data -const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB - -// DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data -const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 15 MB - -// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client -const WindowUpdateThreshold = 0.25 - -// DefaultMaxIncomingStreams is the maximum number of streams that a peer may open -const DefaultMaxIncomingStreams = 100 - -// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open -const DefaultMaxIncomingUniStreams = 100 - -// MaxServerUnprocessedPackets is the max number of packets stored in the server that are not yet processed. -const MaxServerUnprocessedPackets = 1024 - -// MaxConnUnprocessedPackets is the max number of packets stored in each connection that are not yet processed. -const MaxConnUnprocessedPackets = 256 - -// SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack. -// Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod. -const SkipPacketInitialPeriod PacketNumber = 256 - -// SkipPacketMaxPeriod is the maximum period length used for packet number skipping. -const SkipPacketMaxPeriod PacketNumber = 128 * 1024 - -// MaxAcceptQueueSize is the maximum number of connections that the server queues for accepting. -// If the queue is full, new connection attempts will be rejected. -const MaxAcceptQueueSize = 32 - -// TokenValidity is the duration that a (non-retry) token is considered valid -const TokenValidity = 24 * time.Hour - -// RetryTokenValidity is the duration that a retry token is considered valid -const RetryTokenValidity = 10 * time.Second - -// MaxOutstandingSentPackets is maximum number of packets saved for retransmission. -// When reached, it imposes a soft limit on sending new packets: -// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. -const MaxOutstandingSentPackets = 2 * MaxCongestionWindowPackets - -// MaxTrackedSentPackets is maximum number of sent packets saved for retransmission. -// When reached, no more packets will be sent. -// This value *must* be larger than MaxOutstandingSentPackets. -const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4 - -// MaxNonAckElicitingAcks is the maximum number of packets containing an ACK, -// but no ack-eliciting frames, that we send in a row -const MaxNonAckElicitingAcks = 19 - -// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames -// prevents DoS attacks against the streamFrameSorter -const MaxStreamFrameSorterGaps = 1000 - -// MinStreamFrameBufferSize is the minimum data length of a received STREAM frame -// that we use the buffer for. This protects against a DoS where an attacker would send us -// very small STREAM frames to consume a lot of memory. -const MinStreamFrameBufferSize = 128 - -// MinCoalescedPacketSize is the minimum size of a coalesced packet that we pack. -// If a packet has less than this number of bytes, we won't coalesce any more packets onto it. -const MinCoalescedPacketSize = 128 - -// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams. -// This limits the size of the ClientHello and Certificates that can be received. -const MaxCryptoStreamOffset = 16 * (1 << 10) - -// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout -const MinRemoteIdleTimeout = 5 * time.Second - -// DefaultIdleTimeout is the default idle timeout -const DefaultIdleTimeout = 30 * time.Second - -// DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion. -const DefaultHandshakeIdleTimeout = 5 * time.Second - -// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds. -const DefaultHandshakeTimeout = 10 * time.Second - -// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive. -// It should be shorter than the time that NATs clear their mapping. -const MaxKeepAliveInterval = 20 * time.Second - -// RetiredConnectionIDDeleteTimeout is the time we keep closed connections around in order to retransmit the CONNECTION_CLOSE. -// after this time all information about the old connection will be deleted -const RetiredConnectionIDDeleteTimeout = 5 * time.Second - -// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame. -// This avoids splitting up STREAM frames into small pieces, which has 2 advantages: -// 1. it reduces the framing overhead -// 2. it reduces the head-of-line blocking, when a packet is lost -const MinStreamFrameSize ByteCount = 128 - -// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames -// we send after the handshake completes. -const MaxPostHandshakeCryptoFrameSize = 1000 - -// MaxAckFrameSize is the maximum size for an ACK frame that we write -// Due to the varint encoding, ACK frames can grow (almost) indefinitely large. -// The MaxAckFrameSize should be large enough to encode many ACK range, -// but must ensure that a maximum size ACK frame fits into one packet. -const MaxAckFrameSize ByteCount = 1000 - -// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame (RFC 9221). -// The size is chosen such that a DATAGRAM frame fits into a QUIC packet. -const MaxDatagramFrameSize ByteCount = 1220 - -// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221) -const DatagramRcvQueueLen = 128 - -// MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame. -// It also serves as a limit for the packet history. -// If at any point we keep track of more ranges, old ranges are discarded. -const MaxNumAckRanges = 32 - -// MinPacingDelay is the minimum duration that is used for packet pacing -// If the packet packing frequency is higher, multiple packets might be sent at once. -// Example: For a packet pacing delay of 200μs, we would send 5 packets at once, wait for 1ms, and so forth. -const MinPacingDelay = time.Millisecond - -// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections -// if no other value is configured. -const DefaultConnectionIDLength = 4 - -// MaxActiveConnectionIDs is the number of connection IDs that we're storing. -const MaxActiveConnectionIDs = 4 - -// MaxIssuedConnectionIDs is the maximum number of connection IDs that we're issuing at the same time. -const MaxIssuedConnectionIDs = 6 - -// PacketsPerConnectionID is the number of packets we send using one connection ID. -// If the peer provices us with enough new connection IDs, we switch to a new connection ID. -const PacketsPerConnectionID = 10000 - -// AckDelayExponent is the ack delay exponent used when sending ACKs. -const AckDelayExponent = 3 - -// Estimated timer granularity. -// The loss detection timer will not be set to a value smaller than granularity. -const TimerGranularity = time.Millisecond - -// MaxAckDelay is the maximum time by which we delay sending ACKs. -const MaxAckDelay = 25 * time.Millisecond - -// MaxAckDelayInclGranularity is the max_ack_delay including the timer granularity. -// This is the value that should be advertised to the peer. -const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity - -// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. -const KeyUpdateInterval = 100 * 1000 - -// Max0RTTQueueingDuration is the maximum time that we store 0-RTT packets in order to wait for the corresponding Initial to be received. -const Max0RTTQueueingDuration = 100 * time.Millisecond - -// Max0RTTQueues is the maximum number of connections that we buffer 0-RTT packets for. -const Max0RTTQueues = 32 - -// Max0RTTQueueLen is the maximum number of 0-RTT packets that we buffer for each connection. -// When a new connection is created, all buffered packets are passed to the connection immediately. -// To avoid blocking, this value has to be smaller than MaxConnUnprocessedPackets. -// To avoid packets being dropped as undecryptable by the connection, this value has to be smaller than MaxUndecryptablePackets. -const Max0RTTQueueLen = 31 diff --git a/internal/protocol/params_test.go b/internal/protocol/params_test.go deleted file mode 100644 index 50a260d2..00000000 --- a/internal/protocol/params_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Parameters", func() { - It("can queue more packets in the session than in the 0-RTT queue", func() { - Expect(MaxConnUnprocessedPackets).To(BeNumerically(">", Max0RTTQueueLen)) - Expect(MaxUndecryptablePackets).To(BeNumerically(">", Max0RTTQueueLen)) - }) -}) diff --git a/internal/protocol/perspective.go b/internal/protocol/perspective.go deleted file mode 100644 index 43358fec..00000000 --- a/internal/protocol/perspective.go +++ /dev/null @@ -1,26 +0,0 @@ -package protocol - -// Perspective determines if we're acting as a server or a client -type Perspective int - -// the perspectives -const ( - PerspectiveServer Perspective = 1 - PerspectiveClient Perspective = 2 -) - -// Opposite returns the perspective of the peer -func (p Perspective) Opposite() Perspective { - return 3 - p -} - -func (p Perspective) String() string { - switch p { - case PerspectiveServer: - return "Server" - case PerspectiveClient: - return "Client" - default: - return "invalid perspective" - } -} diff --git a/internal/protocol/perspective_test.go b/internal/protocol/perspective_test.go deleted file mode 100644 index 0ae23d7c..00000000 --- a/internal/protocol/perspective_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Perspective", func() { - It("has a string representation", func() { - Expect(PerspectiveClient.String()).To(Equal("Client")) - Expect(PerspectiveServer.String()).To(Equal("Server")) - Expect(Perspective(0).String()).To(Equal("invalid perspective")) - }) - - It("returns the opposite", func() { - Expect(PerspectiveClient.Opposite()).To(Equal(PerspectiveServer)) - Expect(PerspectiveServer.Opposite()).To(Equal(PerspectiveClient)) - }) -}) diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 8241e274..366d83ff 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -1,97 +1,4 @@ package protocol -import ( - "fmt" - "time" -) - -// The PacketType is the Long Header Type -type PacketType uint8 - -const ( - // PacketTypeInitial is the packet type of an Initial packet - PacketTypeInitial PacketType = 1 + iota - // PacketTypeRetry is the packet type of a Retry packet - PacketTypeRetry - // PacketTypeHandshake is the packet type of a Handshake packet - PacketTypeHandshake - // PacketType0RTT is the packet type of a 0-RTT packet - PacketType0RTT -) - -func (t PacketType) String() string { - switch t { - case PacketTypeInitial: - return "Initial" - case PacketTypeRetry: - return "Retry" - case PacketTypeHandshake: - return "Handshake" - case PacketType0RTT: - return "0-RTT Protected" - default: - return fmt.Sprintf("unknown packet type: %d", t) - } -} - -type ECN uint8 - -const ( - ECNNon ECN = iota // 00 - ECT1 // 01 - ECT0 // 10 - ECNCE // 11 -) - // A ByteCount in QUIC type ByteCount int64 - -// MaxByteCount is the maximum value of a ByteCount -const MaxByteCount = ByteCount(1<<62 - 1) - -// InvalidByteCount is an invalid byte count -const InvalidByteCount ByteCount = -1 - -// A StatelessResetToken is a stateless reset token. -type StatelessResetToken [16]byte - -// MaxPacketBufferSize maximum packet size of any QUIC packet, based on -// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, -// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes. -// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452. -const MaxPacketBufferSize ByteCount = 1452 - -// MinInitialPacketSize is the minimum size an Initial packet is required to have. -const MinInitialPacketSize = 1200 - -// MinUnknownVersionPacketSize is the minimum size a packet with an unknown version -// needs to have in order to trigger a Version Negotiation packet. -const MinUnknownVersionPacketSize = MinInitialPacketSize - -// MinStatelessResetSize is the minimum size of a stateless reset packet that we send -const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */ - -// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet. -const MinConnectionIDLenInitial = 8 - -// DefaultAckDelayExponent is the default ack delay exponent -const DefaultAckDelayExponent = 3 - -// MaxAckDelayExponent is the maximum ack delay exponent -const MaxAckDelayExponent = 20 - -// DefaultMaxAckDelay is the default max_ack_delay -const DefaultMaxAckDelay = 25 * time.Millisecond - -// MaxMaxAckDelay is the maximum max_ack_delay -const MaxMaxAckDelay = (1<<14 - 1) * time.Millisecond - -// MaxConnIDLen is the maximum length of the connection ID -const MaxConnIDLen = 20 - -// InvalidPacketLimitAES is the maximum number of packets that we can fail to decrypt when using -// AEAD_AES_128_GCM or AEAD_AES_265_GCM. -const InvalidPacketLimitAES = 1 << 52 - -// InvalidPacketLimitChaCha is the maximum number of packets that we can fail to decrypt when using AEAD_CHACHA20_POLY1305. -const InvalidPacketLimitChaCha = 1 << 36 diff --git a/internal/protocol/protocol_suite_test.go b/internal/protocol/protocol_suite_test.go deleted file mode 100644 index 60da0157..00000000 --- a/internal/protocol/protocol_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package protocol - -import ( - "testing" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestProtocol(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Protocol Suite") -} diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go deleted file mode 100644 index 117405e4..00000000 --- a/internal/protocol/protocol_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Protocol", func() { - Context("Long Header Packet Types", func() { - It("has the correct string representation", func() { - Expect(PacketTypeInitial.String()).To(Equal("Initial")) - Expect(PacketTypeRetry.String()).To(Equal("Retry")) - Expect(PacketTypeHandshake.String()).To(Equal("Handshake")) - Expect(PacketType0RTT.String()).To(Equal("0-RTT Protected")) - Expect(PacketType(10).String()).To(Equal("unknown packet type: 10")) - }) - }) - - It("converts ECN bits from the IP header wire to the correct types", func() { - Expect(ECN(0)).To(Equal(ECNNon)) - Expect(ECN(0b00000010)).To(Equal(ECT0)) - Expect(ECN(0b00000001)).To(Equal(ECT1)) - Expect(ECN(0b00000011)).To(Equal(ECNCE)) - }) -}) diff --git a/internal/protocol/stream.go b/internal/protocol/stream.go deleted file mode 100644 index ad7de864..00000000 --- a/internal/protocol/stream.go +++ /dev/null @@ -1,76 +0,0 @@ -package protocol - -// StreamType encodes if this is a unidirectional or bidirectional stream -type StreamType uint8 - -const ( - // StreamTypeUni is a unidirectional stream - StreamTypeUni StreamType = iota - // StreamTypeBidi is a bidirectional stream - StreamTypeBidi -) - -// InvalidPacketNumber is a stream ID that is invalid. -// The first valid stream ID in QUIC is 0. -const InvalidStreamID StreamID = -1 - -// StreamNum is the stream number -type StreamNum int64 - -const ( - // InvalidStreamNum is an invalid stream number. - InvalidStreamNum = -1 - // MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames - // and as the stream count in the transport parameters - MaxStreamCount StreamNum = 1 << 60 -) - -// StreamID calculates the stream ID. -func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID { - if s == 0 { - return InvalidStreamID - } - var first StreamID - switch stype { - case StreamTypeBidi: - switch pers { - case PerspectiveClient: - first = 0 - case PerspectiveServer: - first = 1 - } - case StreamTypeUni: - switch pers { - case PerspectiveClient: - first = 2 - case PerspectiveServer: - first = 3 - } - } - return first + 4*StreamID(s-1) -} - -// A StreamID in QUIC -type StreamID int64 - -// InitiatedBy says if the stream was initiated by the client or by the server -func (s StreamID) InitiatedBy() Perspective { - if s%2 == 0 { - return PerspectiveClient - } - return PerspectiveServer -} - -// Type says if this is a unidirectional or bidirectional stream -func (s StreamID) Type() StreamType { - if s%4 >= 2 { - return StreamTypeUni - } - return StreamTypeBidi -} - -// StreamNum returns how many streams in total are below this -// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9) -func (s StreamID) StreamNum() StreamNum { - return StreamNum(s/4) + 1 -} diff --git a/internal/protocol/stream_test.go b/internal/protocol/stream_test.go deleted file mode 100644 index 4209f8a0..00000000 --- a/internal/protocol/stream_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Stream ID", func() { - It("InvalidStreamID is smaller than all valid stream IDs", func() { - Expect(InvalidStreamID).To(BeNumerically("<", 0)) - }) - - It("says who initiated a stream", func() { - Expect(StreamID(4).InitiatedBy()).To(Equal(PerspectiveClient)) - Expect(StreamID(5).InitiatedBy()).To(Equal(PerspectiveServer)) - Expect(StreamID(6).InitiatedBy()).To(Equal(PerspectiveClient)) - Expect(StreamID(7).InitiatedBy()).To(Equal(PerspectiveServer)) - }) - - It("tells the directionality", func() { - Expect(StreamID(4).Type()).To(Equal(StreamTypeBidi)) - Expect(StreamID(5).Type()).To(Equal(StreamTypeBidi)) - Expect(StreamID(6).Type()).To(Equal(StreamTypeUni)) - Expect(StreamID(7).Type()).To(Equal(StreamTypeUni)) - }) - - It("tells the stream number", func() { - Expect(StreamID(0).StreamNum()).To(BeEquivalentTo(1)) - Expect(StreamID(1).StreamNum()).To(BeEquivalentTo(1)) - Expect(StreamID(2).StreamNum()).To(BeEquivalentTo(1)) - Expect(StreamID(3).StreamNum()).To(BeEquivalentTo(1)) - Expect(StreamID(8).StreamNum()).To(BeEquivalentTo(3)) - Expect(StreamID(9).StreamNum()).To(BeEquivalentTo(3)) - Expect(StreamID(10).StreamNum()).To(BeEquivalentTo(3)) - Expect(StreamID(11).StreamNum()).To(BeEquivalentTo(3)) - }) - - Context("converting stream nums to stream IDs", func() { - It("handles 0", func() { - Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(InvalidStreamID)) - Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(InvalidStreamID)) - Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(InvalidStreamID)) - Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(InvalidStreamID)) - }) - - It("handles the first", func() { - Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(0))) - Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(1))) - Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(2))) - Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3))) - }) - - It("handles others", func() { - Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(396))) - Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(397))) - Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(398))) - Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(399))) - }) - - It("has the right value for MaxStreamCount", func() { - const maxStreamID = StreamID(1<<62 - 1) - for _, dir := range []StreamType{StreamTypeUni, StreamTypeBidi} { - for _, pers := range []Perspective{PerspectiveClient, PerspectiveServer} { - Expect(MaxStreamCount.StreamID(dir, pers)).To(BeNumerically("<=", maxStreamID)) - Expect((MaxStreamCount + 1).StreamID(dir, pers)).To(BeNumerically(">", maxStreamID)) - } - } - }) - }) -}) diff --git a/internal/protocol/version.go b/internal/protocol/version.go index 5f0d93c8..a96656a2 100644 --- a/internal/protocol/version.go +++ b/internal/protocol/version.go @@ -1,77 +1,13 @@ package protocol import ( - "crypto/rand" - "encoding/binary" "github.com/lucas-clemente/quic-go" - "math" -) - -// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions -const ( - gquicVersion0 = 0x51303030 - maxGquicVersion = 0x51303439 ) // The version numbers, making grepping easier const ( - VersionTLS quic.VersionNumber = 0x1 - VersionWhatever quic.VersionNumber = math.MaxUint32 - 1 // for when the version doesn't matter - VersionUnknown quic.VersionNumber = math.MaxUint32 - VersionDraft29 quic.VersionNumber = 0xff00001d - Version1 quic.VersionNumber = 0x1 - Version2 quic.VersionNumber = 0x709a50c4 + VersionTLS quic.VersionNumber = 0x1 + VersionDraft29 quic.VersionNumber = 0xff00001d + Version1 quic.VersionNumber = 0x1 + Version2 quic.VersionNumber = 0x709a50c4 ) - -// SupportedVersions lists the versions that the server supports -// must be in sorted descending order -var SupportedVersions = []quic.VersionNumber{Version1, Version2, VersionDraft29} - -// IsValidVersion says if the version is known to quic-go -func IsValidVersion(v quic.VersionNumber) bool { - return v == VersionTLS || IsSupportedVersion(SupportedVersions, v) -} - -// IsSupportedVersion returns true if the server supports this version -func IsSupportedVersion(supported []quic.VersionNumber, v quic.VersionNumber) bool { - for _, t := range supported { - if t == v { - return true - } - } - return false -} - -// ChooseSupportedVersion finds the best version in the overlap of ours and theirs -// ours is a slice of versions that we support, sorted by our preference (descending) -// theirs is a slice of versions offered by the peer. The order does not matter. -// The bool returned indicates if a matching version was found. -func ChooseSupportedVersion(ours, theirs []quic.VersionNumber) (quic.VersionNumber, bool) { - for _, ourVer := range ours { - for _, theirVer := range theirs { - if ourVer == theirVer { - return ourVer, true - } - } - } - return 0, false -} - -// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a) -func generateReservedVersion() quic.VersionNumber { - b := make([]byte, 4) - _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything - return quic.VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa) -} - -// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position -func GetGreasedVersions(supported []quic.VersionNumber) []quic.VersionNumber { - b := make([]byte, 1) - _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything - randPos := int(b[0]) % (len(supported) + 1) - greased := make([]quic.VersionNumber, len(supported)+1) - copy(greased, supported[:randPos]) - greased[randPos] = generateReservedVersion() - copy(greased[randPos+1:], supported[randPos:]) - return greased -} diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go deleted file mode 100644 index 33e59f71..00000000 --- a/internal/protocol/version_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package protocol - -import ( - "github.com/lucas-clemente/quic-go" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Version", func() { - isReservedVersion := func(v quic.VersionNumber) bool { - return v&0x0f0f0f0f == 0x0a0a0a0a - } - - It("says if a version is valid", func() { - Expect(IsValidVersion(VersionTLS)).To(BeTrue()) - Expect(IsValidVersion(VersionWhatever)).To(BeFalse()) - Expect(IsValidVersion(VersionUnknown)).To(BeFalse()) - Expect(IsValidVersion(VersionDraft29)).To(BeTrue()) - Expect(IsValidVersion(Version1)).To(BeTrue()) - Expect(IsValidVersion(Version2)).To(BeTrue()) - Expect(IsValidVersion(1234)).To(BeFalse()) - }) - - It("versions don't have reserved version numbers", func() { - Expect(isReservedVersion(VersionTLS)).To(BeFalse()) - }) - - It("has the right string representation", func() { - Expect(VersionWhatever.String()).To(Equal("whatever")) - Expect(VersionUnknown.String()).To(Equal("unknown")) - Expect(VersionDraft29.String()).To(Equal("draft-29")) - Expect(Version1.String()).To(Equal("v1")) - Expect(Version2.String()).To(Equal("v2")) - // check with unsupported version numbers from the wiki - Expect(quic.VersionNumber(0x51303039).String()).To(Equal("gQUIC 9")) - Expect(quic.VersionNumber(0x51303133).String()).To(Equal("gQUIC 13")) - Expect(quic.VersionNumber(0x51303235).String()).To(Equal("gQUIC 25")) - Expect(quic.VersionNumber(0x51303438).String()).To(Equal("gQUIC 48")) - Expect(quic.VersionNumber(0x01234567).String()).To(Equal("0x1234567")) - }) - - It("recognizes supported versions", func() { - Expect(IsSupportedVersion(SupportedVersions, 0)).To(BeFalse()) - Expect(IsSupportedVersion(SupportedVersions, SupportedVersions[0])).To(BeTrue()) - Expect(IsSupportedVersion(SupportedVersions, SupportedVersions[len(SupportedVersions)-1])).To(BeTrue()) - }) - - Context("highest supported version", func() { - It("finds the supported version", func() { - supportedVersions := []quic.VersionNumber{1, 2, 3} - other := []quic.VersionNumber{6, 5, 4, 3} - ver, ok := ChooseSupportedVersion(supportedVersions, other) - Expect(ok).To(BeTrue()) - Expect(ver).To(Equal(quic.VersionNumber(3))) - }) - - It("picks the preferred version", func() { - supportedVersions := []quic.VersionNumber{2, 1, 3} - other := []quic.VersionNumber{3, 6, 1, 8, 2, 10} - ver, ok := ChooseSupportedVersion(supportedVersions, other) - Expect(ok).To(BeTrue()) - Expect(ver).To(Equal(quic.VersionNumber(2))) - }) - - It("says when no matching version was found", func() { - _, ok := ChooseSupportedVersion([]quic.VersionNumber{1}, []quic.VersionNumber{2}) - Expect(ok).To(BeFalse()) - }) - - It("handles empty inputs", func() { - _, ok := ChooseSupportedVersion([]quic.VersionNumber{102, 101}, []quic.VersionNumber{}) - Expect(ok).To(BeFalse()) - _, ok = ChooseSupportedVersion([]quic.VersionNumber{}, []quic.VersionNumber{1, 2}) - Expect(ok).To(BeFalse()) - _, ok = ChooseSupportedVersion([]quic.VersionNumber{}, []quic.VersionNumber{}) - Expect(ok).To(BeFalse()) - }) - }) - - Context("reserved versions", func() { - It("adds a greased version if passed an empty slice", func() { - greased := GetGreasedVersions([]quic.VersionNumber{}) - Expect(greased).To(HaveLen(1)) - Expect(isReservedVersion(greased[0])).To(BeTrue()) - }) - - It("creates greased lists of version numbers", func() { - supported := []quic.VersionNumber{10, 18, 29} - for _, v := range supported { - Expect(isReservedVersion(v)).To(BeFalse()) - } - var greasedVersionFirst, greasedVersionLast, greasedVersionMiddle int - // check that - // 1. the greased version sometimes appears first - // 2. the greased version sometimes appears in the middle - // 3. the greased version sometimes appears last - // 4. the supported versions are kept in order - for i := 0; i < 100; i++ { - greased := GetGreasedVersions(supported) - Expect(greased).To(HaveLen(4)) - var j int - for i, v := range greased { - if isReservedVersion(v) { - if i == 0 { - greasedVersionFirst++ - } - if i == len(greased)-1 { - greasedVersionLast++ - } - greasedVersionMiddle++ - continue - } - Expect(supported[j]).To(Equal(v)) - j++ - } - } - Expect(greasedVersionFirst).ToNot(BeZero()) - Expect(greasedVersionLast).ToNot(BeZero()) - Expect(greasedVersionMiddle).ToNot(BeZero()) - }) - }) -}) diff --git a/internal/utils/atomic_bool.go b/internal/utils/atomic_bool.go deleted file mode 100644 index cf464250..00000000 --- a/internal/utils/atomic_bool.go +++ /dev/null @@ -1,22 +0,0 @@ -package utils - -import "sync/atomic" - -// An AtomicBool is an atomic bool -type AtomicBool struct { - v int32 -} - -// Set sets the value -func (a *AtomicBool) Set(value bool) { - var n int32 - if value { - n = 1 - } - atomic.StoreInt32(&a.v, n) -} - -// Get gets the value -func (a *AtomicBool) Get() bool { - return atomic.LoadInt32(&a.v) != 0 -} diff --git a/internal/utils/atomic_bool_test.go b/internal/utils/atomic_bool_test.go deleted file mode 100644 index 83a200c2..00000000 --- a/internal/utils/atomic_bool_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package utils - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Atomic Bool", func() { - var a *AtomicBool - - BeforeEach(func() { - a = &AtomicBool{} - }) - - It("has the right default value", func() { - Expect(a.Get()).To(BeFalse()) - }) - - It("sets the value to true", func() { - a.Set(true) - Expect(a.Get()).To(BeTrue()) - }) - - It("sets the value to false", func() { - a.Set(true) - a.Set(false) - Expect(a.Get()).To(BeFalse()) - }) -}) diff --git a/internal/utils/buffered_write_closer.go b/internal/utils/buffered_write_closer.go deleted file mode 100644 index b5b9d6fc..00000000 --- a/internal/utils/buffered_write_closer.go +++ /dev/null @@ -1,26 +0,0 @@ -package utils - -import ( - "bufio" - "io" -) - -type bufferedWriteCloser struct { - *bufio.Writer - io.Closer -} - -// NewBufferedWriteCloser creates an io.WriteCloser from a bufio.Writer and an io.Closer -func NewBufferedWriteCloser(writer *bufio.Writer, closer io.Closer) io.WriteCloser { - return &bufferedWriteCloser{ - Writer: writer, - Closer: closer, - } -} - -func (h bufferedWriteCloser) Close() error { - if err := h.Writer.Flush(); err != nil { - return err - } - return h.Closer.Close() -} diff --git a/internal/utils/buffered_write_closer_test.go b/internal/utils/buffered_write_closer_test.go deleted file mode 100644 index 9c93d615..00000000 --- a/internal/utils/buffered_write_closer_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package utils - -import ( - "bufio" - "bytes" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type nopCloser struct{} - -func (nopCloser) Close() error { return nil } - -var _ = Describe("buffered io.WriteCloser", func() { - It("flushes before closing", func() { - buf := &bytes.Buffer{} - - w := bufio.NewWriter(buf) - wc := NewBufferedWriteCloser(w, &nopCloser{}) - wc.Write([]byte("foobar")) - Expect(buf.Len()).To(BeZero()) - Expect(wc.Close()).To(Succeed()) - Expect(buf.String()).To(Equal("foobar")) - }) -}) diff --git a/internal/utils/byteinterval_linkedlist.go b/internal/utils/byteinterval_linkedlist.go deleted file mode 100644 index 096023ef..00000000 --- a/internal/utils/byteinterval_linkedlist.go +++ /dev/null @@ -1,217 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package utils - -// Linked list implementation from the Go standard library. - -// ByteIntervalElement is an element of a linked list. -type ByteIntervalElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *ByteIntervalElement - - // The list to which this element belongs. - list *ByteIntervalList - - // The value stored with this element. - Value ByteInterval -} - -// Next returns the next list element or nil. -func (e *ByteIntervalElement) Next() *ByteIntervalElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *ByteIntervalElement) Prev() *ByteIntervalElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// ByteIntervalList is a linked list of ByteIntervals. -type ByteIntervalList struct { - root ByteIntervalElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *ByteIntervalList) Init() *ByteIntervalList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewByteIntervalList returns an initialized list. -func NewByteIntervalList() *ByteIntervalList { return new(ByteIntervalList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *ByteIntervalList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *ByteIntervalList) Front() *ByteIntervalElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *ByteIntervalList) Back() *ByteIntervalElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *ByteIntervalList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *ByteIntervalList) insert(e, at *ByteIntervalElement) *ByteIntervalElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *ByteIntervalList) insertValue(v ByteInterval, at *ByteIntervalElement) *ByteIntervalElement { - return l.insert(&ByteIntervalElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. -func (l *ByteIntervalList) remove(e *ByteIntervalElement) *ByteIntervalElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *ByteIntervalList) Remove(e *ByteIntervalElement) ByteInterval { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *ByteIntervalList) PushFront(v ByteInterval) *ByteIntervalElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *ByteIntervalList) MoveToFront(e *ByteIntervalElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *ByteIntervalList) MoveToBack(e *ByteIntervalElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *ByteIntervalList) PushFrontList(other *ByteIntervalList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/internal/utils/byteoder_big_endian_test.go b/internal/utils/byteoder_big_endian_test.go deleted file mode 100644 index 5d0873a9..00000000 --- a/internal/utils/byteoder_big_endian_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package utils - -import ( - "bytes" - "io" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Big Endian encoding / decoding", func() { - Context("ReadUint16", func() { - It("reads a big endian", func() { - b := []byte{0x13, 0xEF} - val, err := BigEndian.ReadUint16(bytes.NewReader(b)) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint16(0x13EF))) - }) - - It("throws an error if less than 2 bytes are passed", func() { - b := []byte{0x13, 0xEF} - for i := 0; i < len(b); i++ { - _, err := BigEndian.ReadUint16(bytes.NewReader(b[:i])) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("ReadUint24", func() { - It("reads a big endian", func() { - b := []byte{0x13, 0xbe, 0xef} - val, err := BigEndian.ReadUint24(bytes.NewReader(b)) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint32(0x13beef))) - }) - - It("throws an error if less than 3 bytes are passed", func() { - b := []byte{0x13, 0xbe, 0xef} - for i := 0; i < len(b); i++ { - _, err := BigEndian.ReadUint24(bytes.NewReader(b[:i])) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("ReadUint32", func() { - It("reads a big endian", func() { - b := []byte{0x12, 0x35, 0xAB, 0xFF} - val, err := BigEndian.ReadUint32(bytes.NewReader(b)) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint32(0x1235ABFF))) - }) - - It("throws an error if less than 4 bytes are passed", func() { - b := []byte{0x12, 0x35, 0xAB, 0xFF} - for i := 0; i < len(b); i++ { - _, err := BigEndian.ReadUint32(bytes.NewReader(b[:i])) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("WriteUint16", func() { - It("outputs 2 bytes", func() { - b := &bytes.Buffer{} - BigEndian.WriteUint16(b, uint16(1)) - Expect(b.Len()).To(Equal(2)) - }) - - It("outputs a big endian", func() { - num := uint16(0xFF11) - b := &bytes.Buffer{} - BigEndian.WriteUint16(b, num) - Expect(b.Bytes()).To(Equal([]byte{0xFF, 0x11})) - }) - }) - - Context("WriteUint24", func() { - It("outputs 3 bytes", func() { - b := &bytes.Buffer{} - BigEndian.WriteUint24(b, uint32(1)) - Expect(b.Len()).To(Equal(3)) - }) - - It("outputs a big endian", func() { - num := uint32(0xff11aa) - b := &bytes.Buffer{} - BigEndian.WriteUint24(b, num) - Expect(b.Bytes()).To(Equal([]byte{0xff, 0x11, 0xaa})) - }) - }) - - Context("WriteUint32", func() { - It("outputs 4 bytes", func() { - b := &bytes.Buffer{} - BigEndian.WriteUint32(b, uint32(1)) - Expect(b.Len()).To(Equal(4)) - }) - - It("outputs a big endian", func() { - num := uint32(0xEFAC3512) - b := &bytes.Buffer{} - BigEndian.WriteUint32(b, num) - Expect(b.Bytes()).To(Equal([]byte{0xEF, 0xAC, 0x35, 0x12})) - }) - }) -}) diff --git a/internal/utils/byteorder.go b/internal/utils/byteorder.go deleted file mode 100644 index d1f52842..00000000 --- a/internal/utils/byteorder.go +++ /dev/null @@ -1,17 +0,0 @@ -package utils - -import ( - "bytes" - "io" -) - -// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. -type ByteOrder interface { - ReadUint32(io.ByteReader) (uint32, error) - ReadUint24(io.ByteReader) (uint32, error) - ReadUint16(io.ByteReader) (uint16, error) - - WriteUint32(*bytes.Buffer, uint32) - WriteUint24(*bytes.Buffer, uint32) - WriteUint16(*bytes.Buffer, uint16) -} diff --git a/internal/utils/byteorder_big_endian.go b/internal/utils/byteorder_big_endian.go deleted file mode 100644 index d05542e1..00000000 --- a/internal/utils/byteorder_big_endian.go +++ /dev/null @@ -1,89 +0,0 @@ -package utils - -import ( - "bytes" - "io" -) - -// BigEndian is the big-endian implementation of ByteOrder. -var BigEndian ByteOrder = bigEndian{} - -type bigEndian struct{} - -var _ ByteOrder = &bigEndian{} - -// ReadUintN reads N bytes -func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) { - var res uint64 - for i := uint8(0); i < length; i++ { - bt, err := b.ReadByte() - if err != nil { - return 0, err - } - res ^= uint64(bt) << ((length - 1 - i) * 8) - } - return res, nil -} - -// ReadUint32 reads a uint32 -func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) { - var b1, b2, b3, b4 uint8 - var err error - if b4, err = b.ReadByte(); err != nil { - return 0, err - } - if b3, err = b.ReadByte(); err != nil { - return 0, err - } - if b2, err = b.ReadByte(); err != nil { - return 0, err - } - if b1, err = b.ReadByte(); err != nil { - return 0, err - } - return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil -} - -// ReadUint24 reads a uint24 -func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) { - var b1, b2, b3 uint8 - var err error - if b3, err = b.ReadByte(); err != nil { - return 0, err - } - if b2, err = b.ReadByte(); err != nil { - return 0, err - } - if b1, err = b.ReadByte(); err != nil { - return 0, err - } - return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16, nil -} - -// ReadUint16 reads a uint16 -func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) { - var b1, b2 uint8 - var err error - if b2, err = b.ReadByte(); err != nil { - return 0, err - } - if b1, err = b.ReadByte(); err != nil { - return 0, err - } - return uint16(b1) + uint16(b2)<<8, nil -} - -// WriteUint32 writes a uint32 -func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) { - b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)}) -} - -// WriteUint24 writes a uint24 -func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) { - b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)}) -} - -// WriteUint16 writes a uint16 -func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) { - b.Write([]byte{uint8(i >> 8), uint8(i)}) -} diff --git a/internal/utils/ip.go b/internal/utils/ip.go deleted file mode 100644 index 7ac7ffec..00000000 --- a/internal/utils/ip.go +++ /dev/null @@ -1,10 +0,0 @@ -package utils - -import "net" - -func IsIPv4(ip net.IP) bool { - // If ip is not an IPv4 address, To4 returns nil. - // Note that there might be some corner cases, where this is not correct. - // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6. - return ip.To4() != nil -} diff --git a/internal/utils/ip_test.go b/internal/utils/ip_test.go deleted file mode 100644 index b61cf529..00000000 --- a/internal/utils/ip_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package utils - -import ( - "net" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("IP", func() { - It("tells IPv4 and IPv6 addresses apart", func() { - Expect(IsIPv4(net.IPv4(127, 0, 0, 1))).To(BeTrue()) - Expect(IsIPv4(net.IPv4zero)).To(BeTrue()) - Expect(IsIPv4(net.IPv6zero)).To(BeFalse()) - Expect(IsIPv4(net.IPv6loopback)).To(BeFalse()) - }) -}) diff --git a/internal/utils/minmax.go b/internal/utils/minmax.go deleted file mode 100644 index c634aa22..00000000 --- a/internal/utils/minmax.go +++ /dev/null @@ -1,170 +0,0 @@ -package utils - -import ( - "math" - "time" - - "github.com/imroc/req/v3/internal/protocol" -) - -// InfDuration is a duration of infinite length -const InfDuration = time.Duration(math.MaxInt64) - -// Max returns the maximum of two Ints -func Max(a, b int) int { - if a < b { - return b - } - return a -} - -// MaxUint32 returns the maximum of two uint32 -func MaxUint32(a, b uint32) uint32 { - if a < b { - return b - } - return a -} - -// MaxUint64 returns the maximum of two uint64 -func MaxUint64(a, b uint64) uint64 { - if a < b { - return b - } - return a -} - -// MinUint64 returns the maximum of two uint64 -func MinUint64(a, b uint64) uint64 { - if a < b { - return a - } - return b -} - -// Min returns the minimum of two Ints -func Min(a, b int) int { - if a < b { - return a - } - return b -} - -// MinUint32 returns the maximum of two uint32 -func MinUint32(a, b uint32) uint32 { - if a < b { - return a - } - return b -} - -// MinInt64 returns the minimum of two int64 -func MinInt64(a, b int64) int64 { - if a < b { - return a - } - return b -} - -// MaxInt64 returns the minimum of two int64 -func MaxInt64(a, b int64) int64 { - if a > b { - return a - } - return b -} - -// MinByteCount returns the minimum of two ByteCounts -func MinByteCount(a, b protocol.ByteCount) protocol.ByteCount { - if a < b { - return a - } - return b -} - -// MaxByteCount returns the maximum of two ByteCounts -func MaxByteCount(a, b protocol.ByteCount) protocol.ByteCount { - if a < b { - return b - } - return a -} - -// MaxDuration returns the max duration -func MaxDuration(a, b time.Duration) time.Duration { - if a > b { - return a - } - return b -} - -// MinDuration returns the minimum duration -func MinDuration(a, b time.Duration) time.Duration { - if a > b { - return b - } - return a -} - -// MinNonZeroDuration return the minimum duration that's not zero. -func MinNonZeroDuration(a, b time.Duration) time.Duration { - if a == 0 { - return b - } - if b == 0 { - return a - } - return MinDuration(a, b) -} - -// AbsDuration returns the absolute value of a time duration -func AbsDuration(d time.Duration) time.Duration { - if d >= 0 { - return d - } - return -d -} - -// MinTime returns the earlier time -func MinTime(a, b time.Time) time.Time { - if a.After(b) { - return b - } - return a -} - -// MinNonZeroTime returns the earlist time that is not time.Time{} -// If both a and b are time.Time{}, it returns time.Time{} -func MinNonZeroTime(a, b time.Time) time.Time { - if a.IsZero() { - return b - } - if b.IsZero() { - return a - } - return MinTime(a, b) -} - -// MaxTime returns the later time -func MaxTime(a, b time.Time) time.Time { - if a.After(b) { - return a - } - return b -} - -// MaxPacketNumber returns the max packet number -func MaxPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { - if a > b { - return a - } - return b -} - -// MinPacketNumber returns the min packet number -func MinPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { - if a < b { - return a - } - return b -} diff --git a/internal/utils/minmax_test.go b/internal/utils/minmax_test.go deleted file mode 100644 index 021212b6..00000000 --- a/internal/utils/minmax_test.go +++ /dev/null @@ -1,123 +0,0 @@ -package utils - -import ( - "time" - - "github.com/imroc/req/v3/internal/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Min / Max", func() { - Context("Max", func() { - It("returns the maximum", func() { - Expect(Max(5, 7)).To(Equal(7)) - Expect(Max(7, 5)).To(Equal(7)) - }) - - It("returns the maximum uint32", func() { - Expect(MaxUint32(5, 7)).To(Equal(uint32(7))) - Expect(MaxUint32(7, 5)).To(Equal(uint32(7))) - }) - - It("returns the maximum uint64", func() { - Expect(MaxUint64(5, 7)).To(Equal(uint64(7))) - Expect(MaxUint64(7, 5)).To(Equal(uint64(7))) - }) - - It("returns the minimum uint64", func() { - Expect(MinUint64(5, 7)).To(Equal(uint64(5))) - Expect(MinUint64(7, 5)).To(Equal(uint64(5))) - }) - - It("returns the maximum int64", func() { - Expect(MaxInt64(5, 7)).To(Equal(int64(7))) - Expect(MaxInt64(7, 5)).To(Equal(int64(7))) - }) - - It("returns the maximum ByteCount", func() { - Expect(MaxByteCount(7, 5)).To(Equal(protocol.ByteCount(7))) - Expect(MaxByteCount(5, 7)).To(Equal(protocol.ByteCount(7))) - }) - - It("returns the maximum duration", func() { - Expect(MaxDuration(time.Microsecond, time.Nanosecond)).To(Equal(time.Microsecond)) - Expect(MaxDuration(time.Nanosecond, time.Microsecond)).To(Equal(time.Microsecond)) - }) - - It("returns the minimum duration", func() { - Expect(MinDuration(time.Microsecond, time.Nanosecond)).To(Equal(time.Nanosecond)) - Expect(MinDuration(time.Nanosecond, time.Microsecond)).To(Equal(time.Nanosecond)) - }) - - It("returns packet number max", func() { - Expect(MaxPacketNumber(1, 2)).To(Equal(protocol.PacketNumber(2))) - Expect(MaxPacketNumber(2, 1)).To(Equal(protocol.PacketNumber(2))) - }) - - It("returns the maximum time", func() { - a := time.Now() - b := a.Add(time.Second) - Expect(MaxTime(a, b)).To(Equal(b)) - Expect(MaxTime(b, a)).To(Equal(b)) - }) - }) - - Context("Min", func() { - It("returns the minimum", func() { - Expect(Min(5, 7)).To(Equal(5)) - Expect(Min(7, 5)).To(Equal(5)) - }) - - It("returns the minimum uint32", func() { - Expect(MinUint32(7, 5)).To(Equal(uint32(5))) - Expect(MinUint32(5, 7)).To(Equal(uint32(5))) - }) - - It("returns the minimum int64", func() { - Expect(MinInt64(7, 5)).To(Equal(int64(5))) - Expect(MinInt64(5, 7)).To(Equal(int64(5))) - }) - - It("returns the minimum ByteCount", func() { - Expect(MinByteCount(7, 5)).To(Equal(protocol.ByteCount(5))) - Expect(MinByteCount(5, 7)).To(Equal(protocol.ByteCount(5))) - }) - - It("returns packet number min", func() { - Expect(MinPacketNumber(1, 2)).To(Equal(protocol.PacketNumber(1))) - Expect(MinPacketNumber(2, 1)).To(Equal(protocol.PacketNumber(1))) - }) - - It("returns the minimum duration", func() { - a := time.Now() - b := a.Add(time.Second) - Expect(MinTime(a, b)).To(Equal(a)) - Expect(MinTime(b, a)).To(Equal(a)) - }) - - It("returns the minium non-zero duration", func() { - var a time.Duration - b := time.Second - Expect(MinNonZeroDuration(0, 0)).To(BeZero()) - Expect(MinNonZeroDuration(a, b)).To(Equal(b)) - Expect(MinNonZeroDuration(b, a)).To(Equal(b)) - Expect(MinNonZeroDuration(time.Minute, time.Hour)).To(Equal(time.Minute)) - }) - - It("returns the minium non-zero time", func() { - a := time.Time{} - b := time.Now() - Expect(MinNonZeroTime(time.Time{}, time.Time{})).To(Equal(time.Time{})) - Expect(MinNonZeroTime(a, b)).To(Equal(b)) - Expect(MinNonZeroTime(b, a)).To(Equal(b)) - Expect(MinNonZeroTime(b, b.Add(time.Second))).To(Equal(b)) - Expect(MinNonZeroTime(b.Add(time.Second), b)).To(Equal(b)) - }) - }) - - It("returns the abs time", func() { - Expect(AbsDuration(time.Microsecond)).To(Equal(time.Microsecond)) - Expect(AbsDuration(-time.Microsecond)).To(Equal(time.Microsecond)) - }) -}) diff --git a/internal/utils/new_connection_id.go b/internal/utils/new_connection_id.go deleted file mode 100644 index f758d63d..00000000 --- a/internal/utils/new_connection_id.go +++ /dev/null @@ -1,12 +0,0 @@ -package utils - -import ( - "github.com/imroc/req/v3/internal/protocol" -) - -// NewConnectionID is a new connection ID -type NewConnectionID struct { - SequenceNumber uint64 - ConnectionID protocol.ConnectionID - StatelessResetToken protocol.StatelessResetToken -} diff --git a/internal/utils/newconnectionid_linkedlist.go b/internal/utils/newconnectionid_linkedlist.go deleted file mode 100644 index d59562e5..00000000 --- a/internal/utils/newconnectionid_linkedlist.go +++ /dev/null @@ -1,217 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package utils - -// Linked list implementation from the Go standard library. - -// NewConnectionIDElement is an element of a linked list. -type NewConnectionIDElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *NewConnectionIDElement - - // The list to which this element belongs. - list *NewConnectionIDList - - // The value stored with this element. - Value NewConnectionID -} - -// Next returns the next list element or nil. -func (e *NewConnectionIDElement) Next() *NewConnectionIDElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *NewConnectionIDElement) Prev() *NewConnectionIDElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// NewConnectionIDList is a linked list of NewConnectionIDs. -type NewConnectionIDList struct { - root NewConnectionIDElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *NewConnectionIDList) Init() *NewConnectionIDList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewNewConnectionIDList returns an initialized list. -func NewNewConnectionIDList() *NewConnectionIDList { return new(NewConnectionIDList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *NewConnectionIDList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *NewConnectionIDList) Front() *NewConnectionIDElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *NewConnectionIDList) Back() *NewConnectionIDElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *NewConnectionIDList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *NewConnectionIDList) insert(e, at *NewConnectionIDElement) *NewConnectionIDElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *NewConnectionIDList) insertValue(v NewConnectionID, at *NewConnectionIDElement) *NewConnectionIDElement { - return l.insert(&NewConnectionIDElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. -func (l *NewConnectionIDList) remove(e *NewConnectionIDElement) *NewConnectionIDElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *NewConnectionIDList) Remove(e *NewConnectionIDElement) NewConnectionID { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *NewConnectionIDList) PushFront(v NewConnectionID) *NewConnectionIDElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *NewConnectionIDList) PushBack(v NewConnectionID) *NewConnectionIDElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *NewConnectionIDList) InsertBefore(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *NewConnectionIDList) InsertAfter(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *NewConnectionIDList) MoveToFront(e *NewConnectionIDElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *NewConnectionIDList) MoveToBack(e *NewConnectionIDElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *NewConnectionIDList) MoveBefore(e, mark *NewConnectionIDElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *NewConnectionIDList) MoveAfter(e, mark *NewConnectionIDElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *NewConnectionIDList) PushBackList(other *NewConnectionIDList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *NewConnectionIDList) PushFrontList(other *NewConnectionIDList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/internal/utils/packet_interval.go b/internal/utils/packet_interval.go deleted file mode 100644 index ac2ca048..00000000 --- a/internal/utils/packet_interval.go +++ /dev/null @@ -1,9 +0,0 @@ -package utils - -import "github.com/imroc/req/v3/internal/protocol" - -// PacketInterval is an interval from one PacketNumber to the other -type PacketInterval struct { - Start protocol.PacketNumber - End protocol.PacketNumber -} diff --git a/internal/utils/packetinterval_linkedlist.go b/internal/utils/packetinterval_linkedlist.go deleted file mode 100644 index b461e85a..00000000 --- a/internal/utils/packetinterval_linkedlist.go +++ /dev/null @@ -1,217 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package utils - -// Linked list implementation from the Go standard library. - -// PacketIntervalElement is an element of a linked list. -type PacketIntervalElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *PacketIntervalElement - - // The list to which this element belongs. - list *PacketIntervalList - - // The value stored with this element. - Value PacketInterval -} - -// Next returns the next list element or nil. -func (e *PacketIntervalElement) Next() *PacketIntervalElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *PacketIntervalElement) Prev() *PacketIntervalElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// PacketIntervalList is a linked list of PacketIntervals. -type PacketIntervalList struct { - root PacketIntervalElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *PacketIntervalList) Init() *PacketIntervalList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewPacketIntervalList returns an initialized list. -func NewPacketIntervalList() *PacketIntervalList { return new(PacketIntervalList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *PacketIntervalList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *PacketIntervalList) Front() *PacketIntervalElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *PacketIntervalList) Back() *PacketIntervalElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *PacketIntervalList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *PacketIntervalList) insert(e, at *PacketIntervalElement) *PacketIntervalElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *PacketIntervalList) insertValue(v PacketInterval, at *PacketIntervalElement) *PacketIntervalElement { - return l.insert(&PacketIntervalElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. -func (l *PacketIntervalList) remove(e *PacketIntervalElement) *PacketIntervalElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *PacketIntervalList) Remove(e *PacketIntervalElement) PacketInterval { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *PacketIntervalList) PushFront(v PacketInterval) *PacketIntervalElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *PacketIntervalList) PushBack(v PacketInterval) *PacketIntervalElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *PacketIntervalList) InsertBefore(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *PacketIntervalList) InsertAfter(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *PacketIntervalList) MoveToFront(e *PacketIntervalElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *PacketIntervalList) MoveToBack(e *PacketIntervalElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *PacketIntervalList) PushFrontList(other *PacketIntervalList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/internal/utils/rand.go b/internal/utils/rand.go deleted file mode 100644 index 30069144..00000000 --- a/internal/utils/rand.go +++ /dev/null @@ -1,29 +0,0 @@ -package utils - -import ( - "crypto/rand" - "encoding/binary" -) - -// Rand is a wrapper around crypto/rand that adds some convenience functions known from math/rand. -type Rand struct { - buf [4]byte -} - -func (r *Rand) Int31() int32 { - rand.Read(r.buf[:]) - return int32(binary.BigEndian.Uint32(r.buf[:]) & ^uint32(1<<31)) -} - -// copied from the standard library math/rand implementation of Int63n -func (r *Rand) Int31n(n int32) int32 { - if n&(n-1) == 0 { // n is power of two, can mask - return r.Int31() & (n - 1) - } - max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) - v := r.Int31() - for v > max { - v = r.Int31() - } - return v % n -} diff --git a/internal/utils/rand_test.go b/internal/utils/rand_test.go deleted file mode 100644 index f15a644e..00000000 --- a/internal/utils/rand_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package utils - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Rand", func() { - It("generates random numbers", func() { - const ( - num = 1000 - max = 12345678 - ) - - var values [num]int32 - var r Rand - for i := 0; i < num; i++ { - v := r.Int31n(max) - Expect(v).To(And( - BeNumerically(">=", 0), - BeNumerically("<", max), - )) - values[i] = v - } - - var sum uint64 - for _, n := range values { - sum += uint64(n) - } - Expect(float64(sum) / num).To(BeNumerically("~", max/2, max/25)) - }) -}) diff --git a/internal/utils/rtt_stats.go b/internal/utils/rtt_stats.go deleted file mode 100644 index 478d2596..00000000 --- a/internal/utils/rtt_stats.go +++ /dev/null @@ -1,127 +0,0 @@ -package utils - -import ( - "time" - - "github.com/imroc/req/v3/internal/protocol" -) - -const ( - rttAlpha = 0.125 - oneMinusAlpha = 1 - rttAlpha - rttBeta = 0.25 - oneMinusBeta = 1 - rttBeta - // The default RTT used before an RTT sample is taken. - defaultInitialRTT = 100 * time.Millisecond -) - -// RTTStats provides round-trip statistics -type RTTStats struct { - hasMeasurement bool - - minRTT time.Duration - latestRTT time.Duration - smoothedRTT time.Duration - meanDeviation time.Duration - - maxAckDelay time.Duration -} - -// NewRTTStats makes a properly initialized RTTStats object -func NewRTTStats() *RTTStats { - return &RTTStats{} -} - -// MinRTT Returns the minRTT for the entire connection. -// May return Zero if no valid updates have occurred. -func (r *RTTStats) MinRTT() time.Duration { return r.minRTT } - -// LatestRTT returns the most recent rtt measurement. -// May return Zero if no valid updates have occurred. -func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT } - -// SmoothedRTT returns the smoothed RTT for the connection. -// May return Zero if no valid updates have occurred. -func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT } - -// MeanDeviation gets the mean deviation -func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } - -// MaxAckDelay gets the max_ack_delay advertised by the peer -func (r *RTTStats) MaxAckDelay() time.Duration { return r.maxAckDelay } - -// PTO gets the probe timeout duration. -func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { - if r.SmoothedRTT() == 0 { - return 2 * defaultInitialRTT - } - pto := r.SmoothedRTT() + MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) - if includeMaxAckDelay { - pto += r.MaxAckDelay() - } - return pto -} - -// UpdateRTT updates the RTT based on a new sample. -func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { - if sendDelta == InfDuration || sendDelta <= 0 { - return - } - - // Update r.minRTT first. r.minRTT does not use an rttSample corrected for - // ackDelay but the raw observed sendDelta, since poor clock granularity at - // the client may cause a high ackDelay to result in underestimation of the - // r.minRTT. - if r.minRTT == 0 || r.minRTT > sendDelta { - r.minRTT = sendDelta - } - - // Correct for ackDelay if information received from the peer results in a - // an RTT sample at least as large as minRTT. Otherwise, only use the - // sendDelta. - sample := sendDelta - if sample-r.minRTT >= ackDelay { - sample -= ackDelay - } - r.latestRTT = sample - // First time call. - if !r.hasMeasurement { - r.hasMeasurement = true - r.smoothedRTT = sample - r.meanDeviation = sample / 2 - } else { - r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond - r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond - } -} - -// SetMaxAckDelay sets the max_ack_delay -func (r *RTTStats) SetMaxAckDelay(mad time.Duration) { - r.maxAckDelay = mad -} - -// SetInitialRTT sets the initial RTT. -// It is used during the 0-RTT handshake when restoring the RTT stats from the session state. -func (r *RTTStats) SetInitialRTT(t time.Duration) { - if r.hasMeasurement { - panic("initial RTT set after first measurement") - } - r.smoothedRTT = t - r.latestRTT = t -} - -// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset. -func (r *RTTStats) OnConnectionMigration() { - r.latestRTT = 0 - r.minRTT = 0 - r.smoothedRTT = 0 - r.meanDeviation = 0 -} - -// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt -// is larger. The mean deviation is increased to the most recent deviation if -// it's larger. -func (r *RTTStats) ExpireSmoothedMetrics() { - r.meanDeviation = MaxDuration(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT)) - r.smoothedRTT = MaxDuration(r.smoothedRTT, r.latestRTT) -} diff --git a/internal/utils/rtt_stats_test.go b/internal/utils/rtt_stats_test.go deleted file mode 100644 index 555a8b8d..00000000 --- a/internal/utils/rtt_stats_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package utils - -import ( - "time" - - "github.com/imroc/req/v3/internal/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("RTT stats", func() { - var rttStats *RTTStats - - BeforeEach(func() { - rttStats = NewRTTStats() - }) - - It("DefaultsBeforeUpdate", func() { - Expect(rttStats.MinRTT()).To(Equal(time.Duration(0))) - Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0))) - }) - - It("SmoothedRTT", func() { - // Verify that ack_delay is ignored in the first measurement. - rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond))) - Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond))) - // Verify that Smoothed RTT includes max ack delay if it's reasonable. - rttStats.UpdateRTT((350 * time.Millisecond), (50 * time.Millisecond), time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond))) - Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond))) - // Verify that large erroneous ack_delay does not change Smoothed RTT. - rttStats.UpdateRTT((200 * time.Millisecond), (300 * time.Millisecond), time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) - Expect(rttStats.SmoothedRTT()).To(Equal((287500 * time.Microsecond))) - }) - - It("MinRTT", func() { - rttStats.UpdateRTT((200 * time.Millisecond), 0, time.Time{}) - Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) - rttStats.UpdateRTT((10 * time.Millisecond), 0, time.Time{}.Add((10 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) - rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((20 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) - rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((30 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) - rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((40 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) - // Verify that ack_delay does not go into recording of MinRTT_. - rttStats.UpdateRTT((7 * time.Millisecond), (2 * time.Millisecond), time.Time{}.Add((50 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((7 * time.Millisecond))) - }) - - It("MaxAckDelay", func() { - rttStats.SetMaxAckDelay(42 * time.Minute) - Expect(rttStats.MaxAckDelay()).To(Equal(42 * time.Minute)) - }) - - It("computes the PTO", func() { - maxAckDelay := 42 * time.Minute - rttStats.SetMaxAckDelay(maxAckDelay) - rtt := time.Second - rttStats.UpdateRTT(rtt, 0, time.Time{}) - Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) - Expect(rttStats.MeanDeviation()).To(Equal(rtt / 2)) - Expect(rttStats.PTO(false)).To(Equal(rtt + 4*(rtt/2))) - Expect(rttStats.PTO(true)).To(Equal(rtt + 4*(rtt/2) + maxAckDelay)) - }) - - It("uses the granularity for computing the PTO for short RTTs", func() { - rtt := time.Microsecond - rttStats.UpdateRTT(rtt, 0, time.Time{}) - Expect(rttStats.PTO(true)).To(Equal(rtt + protocol.TimerGranularity)) - }) - - It("ExpireSmoothedMetrics", func() { - initialRtt := (10 * time.Millisecond) - rttStats.UpdateRTT(initialRtt, 0, time.Time{}) - Expect(rttStats.MinRTT()).To(Equal(initialRtt)) - Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) - - Expect(rttStats.MeanDeviation()).To(Equal(initialRtt / 2)) - - // Update once with a 20ms RTT. - doubledRtt := initialRtt * (2) - rttStats.UpdateRTT(doubledRtt, 0, time.Time{}) - Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(float32(initialRtt) * 1.125))) - - // Expire the smoothed metrics, increasing smoothed rtt and mean deviation. - rttStats.ExpireSmoothedMetrics() - Expect(rttStats.SmoothedRTT()).To(Equal(doubledRtt)) - Expect(rttStats.MeanDeviation()).To(Equal(time.Duration(float32(initialRtt) * 0.875))) - - // Now go back down to 5ms and expire the smoothed metrics, and ensure the - // mean deviation increases to 15ms. - halfRtt := initialRtt / 2 - rttStats.UpdateRTT(halfRtt, 0, time.Time{}) - Expect(doubledRtt).To(BeNumerically(">", rttStats.SmoothedRTT())) - Expect(initialRtt).To(BeNumerically("<", rttStats.MeanDeviation())) - }) - - It("UpdateRTTWithBadSendDeltas", func() { - // Make sure we ignore bad RTTs. - // base::test::MockLog log; - - initialRtt := (10 * time.Millisecond) - rttStats.UpdateRTT(initialRtt, 0, time.Time{}) - Expect(rttStats.MinRTT()).To(Equal(initialRtt)) - Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) - - badSendDeltas := []time.Duration{ - 0, - InfDuration, - -1000 * time.Microsecond, - } - // log.StartCapturingLogs(); - - for _, badSendDelta := range badSendDeltas { - // SCOPED_TRACE(Message() << "bad_send_delta = " - // << bad_send_delta.ToMicroseconds()); - // EXPECT_CALL(log, Log(LOG_WARNING, _, _, _, HasSubstr("Ignoring"))); - rttStats.UpdateRTT(badSendDelta, 0, time.Time{}) - Expect(rttStats.MinRTT()).To(Equal(initialRtt)) - Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) - } - }) - - It("ResetAfterConnectionMigrations", func() { - rttStats.UpdateRTT(200*time.Millisecond, 0, time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) - Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) - rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) - Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) - - // Reset rtt stats on connection migrations. - rttStats.OnConnectionMigration() - Expect(rttStats.LatestRTT()).To(Equal(time.Duration(0))) - Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0))) - Expect(rttStats.MinRTT()).To(Equal(time.Duration(0))) - }) - - It("restores the RTT", func() { - rttStats.SetInitialRTT(10 * time.Second) - Expect(rttStats.LatestRTT()).To(Equal(10 * time.Second)) - Expect(rttStats.SmoothedRTT()).To(Equal(10 * time.Second)) - Expect(rttStats.MeanDeviation()).To(BeZero()) - // update the RTT and make sure that the initial value is immediately forgotten - rttStats.UpdateRTT(200*time.Millisecond, 0, time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal(200 * time.Millisecond)) - Expect(rttStats.SmoothedRTT()).To(Equal(200 * time.Millisecond)) - Expect(rttStats.MeanDeviation()).To(Equal(100 * time.Millisecond)) - }) -}) diff --git a/internal/utils/streamframe_interval.go b/internal/utils/streamframe_interval.go deleted file mode 100644 index 71f3c6e5..00000000 --- a/internal/utils/streamframe_interval.go +++ /dev/null @@ -1,9 +0,0 @@ -package utils - -import "github.com/imroc/req/v3/internal/protocol" - -// ByteInterval is an interval from one ByteCount to the other -type ByteInterval struct { - Start protocol.ByteCount - End protocol.ByteCount -} diff --git a/internal/utils/timer.go b/internal/utils/timer.go deleted file mode 100644 index a4f5e67a..00000000 --- a/internal/utils/timer.go +++ /dev/null @@ -1,53 +0,0 @@ -package utils - -import ( - "math" - "time" -) - -// A Timer wrapper that behaves correctly when resetting -type Timer struct { - t *time.Timer - read bool - deadline time.Time -} - -// NewTimer creates a new timer that is not set -func NewTimer() *Timer { - return &Timer{t: time.NewTimer(time.Duration(math.MaxInt64))} -} - -// Chan returns the channel of the wrapped timer -func (t *Timer) Chan() <-chan time.Time { - return t.t.C -} - -// Reset the timer, no matter whether the value was read or not -func (t *Timer) Reset(deadline time.Time) { - if deadline.Equal(t.deadline) && !t.read { - // No need to reset the timer - return - } - - // We need to drain the timer if the value from its channel was not read yet. - // See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU - if !t.t.Stop() && !t.read { - <-t.t.C - } - if !deadline.IsZero() { - t.t.Reset(time.Until(deadline)) - } - - t.read = false - t.deadline = deadline -} - -// SetRead should be called after the value from the chan was read -func (t *Timer) SetRead() { - t.read = true -} - -// Stop stops the timer -func (t *Timer) Stop() { - t.t.Stop() -} diff --git a/internal/utils/timer_test.go b/internal/utils/timer_test.go deleted file mode 100644 index 0cbb4a01..00000000 --- a/internal/utils/timer_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package utils - -import ( - "time" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Timer", func() { - const d = 10 * time.Millisecond - - It("doesn't fire a newly created timer", func() { - t := NewTimer() - Consistently(t.Chan()).ShouldNot(Receive()) - }) - - It("works", func() { - t := NewTimer() - t.Reset(time.Now().Add(d)) - Eventually(t.Chan()).Should(Receive()) - }) - - It("works multiple times with reading", func() { - t := NewTimer() - for i := 0; i < 10; i++ { - t.Reset(time.Now().Add(d)) - Eventually(t.Chan()).Should(Receive()) - t.SetRead() - } - }) - - It("works multiple times without reading", func() { - t := NewTimer() - for i := 0; i < 10; i++ { - t.Reset(time.Now().Add(d)) - time.Sleep(d * 2) - } - Eventually(t.Chan()).Should(Receive()) - }) - - It("works when resetting without expiration", func() { - t := NewTimer() - for i := 0; i < 10; i++ { - t.Reset(time.Now().Add(time.Hour)) - } - t.Reset(time.Now().Add(d)) - Eventually(t.Chan()).Should(Receive()) - }) - - It("immediately fires the timer, if the deadlines has already passed", func() { - t := NewTimer() - t.Reset(time.Now().Add(-time.Second)) - Eventually(t.Chan()).Should(Receive()) - }) - - It("doesn't set a timer if the deadline is the zero value", func() { - t := NewTimer() - t.Reset(time.Time{}) - Consistently(t.Chan()).ShouldNot(Receive()) - }) - - It("fires the timer twice, if reset to the same deadline", func() { - deadline := time.Now().Add(-time.Millisecond) - t := NewTimer() - t.Reset(deadline) - Eventually(t.Chan()).Should(Receive()) - t.SetRead() - t.Reset(deadline) - Eventually(t.Chan()).Should(Receive()) - }) - - It("only fires the timer once, if it is reset to the same deadline, but not read in between", func() { - deadline := time.Now().Add(-time.Millisecond) - t := NewTimer() - t.Reset(deadline) - Eventually(t.Chan()).Should(Receive()) - Consistently(t.Chan()).ShouldNot(Receive()) - }) - - It("stops", func() { - t := NewTimer() - t.Reset(time.Now().Add(50 * time.Millisecond)) - t.Stop() - Consistently(t.Chan()).ShouldNot(Receive()) - }) -}) From 83640d24da385afd9d7e9116fa78946ce611269d Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 5 Jul 2022 17:21:20 +0800 Subject: [PATCH 527/843] add altsvcutil test --- client.go | 7 +++++++ internal/altsvcutil/altsvcutil.go | 1 - internal/altsvcutil/altsvcutil_test.go | 14 ++++++++++++++ internal/netutil/addr.go | 1 - transport.go | 5 +++++ 5 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 internal/altsvcutil/altsvcutil_test.go diff --git a/client.go b/client.go index 333df2b2..e9d28e2c 100644 --- a/client.go +++ b/client.go @@ -866,6 +866,13 @@ func (c *Client) SetUnixSocket(file string) *Client { }) } +// DisableHTTP3 disables the http3 protocol. +func (c *Client) DisableHTTP3() *Client { + c.t.DisableHTTP3() + return c +} + +// EnableHTTP3 enables the http3 protocol. func (c *Client) EnableHTTP3() *Client { c.t.EnableHTTP3() return c diff --git a/internal/altsvcutil/altsvcutil.go b/internal/altsvcutil/altsvcutil.go index 8556c7cb..5399d157 100644 --- a/internal/altsvcutil/altsvcutil.go +++ b/internal/altsvcutil/altsvcutil.go @@ -48,7 +48,6 @@ func splitHostPort(hostPort string) (host, port string) { if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { host = host[1 : len(host)-1] } - return } diff --git a/internal/altsvcutil/altsvcutil_test.go b/internal/altsvcutil/altsvcutil_test.go new file mode 100644 index 00000000..3cbc0019 --- /dev/null +++ b/internal/altsvcutil/altsvcutil_test.go @@ -0,0 +1,14 @@ +package altsvcutil + +import ( + "github.com/imroc/req/v3/internal/tests" + "testing" +) + +func TestParseHeader(t *testing.T) { + as, err := ParseHeader(` h3=":443"; ma=86400, h3-29=":443"; ma=86400`) + tests.AssertNoError(t, err) + tests.AssertEqual(t, 2, len(as)) + tests.AssertEqual(t, "h3", as[0].Protocol) + tests.AssertEqual(t, "443", as[0].Port) +} diff --git a/internal/netutil/addr.go b/internal/netutil/addr.go index 28a80fcb..d5dfc430 100644 --- a/internal/netutil/addr.go +++ b/internal/netutil/addr.go @@ -35,6 +35,5 @@ func AuthorityHostPort(scheme, authority string) (host, port string) { if a, err := idna.ToASCII(host); err == nil { host = a } - return } diff --git a/transport.go b/transport.go index e1e7b6d4..205a7566 100644 --- a/transport.go +++ b/transport.go @@ -284,6 +284,11 @@ type pendingAltSvc struct { Transport http.RoundTripper } +func (t *Transport) DisableHTTP3() { + t.altSvcJar = nil + t.pendingAltSvcs = nil +} + func (t *Transport) EnableHTTP3() { v := runtime.Version() ss := strings.Split(v, ".") From 696fa9447e465e831021deaca30be23a49fb7432 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 5 Jul 2022 21:28:42 +0800 Subject: [PATCH 528/843] add comments --- internal/altsvcutil/altsvcutil.go | 2 ++ internal/transport/option.go | 1 + pkg/altsvc/altsvc.go | 15 ++++++++++++--- pkg/altsvc/jar.go | 3 +++ 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/internal/altsvcutil/altsvcutil.go b/internal/altsvcutil/altsvcutil.go index 5399d157..8aa160a0 100644 --- a/internal/altsvcutil/altsvcutil.go +++ b/internal/altsvcutil/altsvcutil.go @@ -51,6 +51,7 @@ func splitHostPort(hostPort string) (host, port string) { return } +// ParseHeader parses the AltSvc from header value. func ParseHeader(value string) ([]*altsvc.AltSvc, error) { p := newAltSvcParser(value) return p.Parse() @@ -192,6 +193,7 @@ func (p *altAvcParser) parseOne() (as *altsvc.AltSvc, err error) { return } +// ConvertURL converts the raw request url to expected alt-svc's url. func ConvertURL(a *altsvc.AltSvc, u *url.URL) *url.URL { host, port := netutil.AuthorityHostPort(u.Scheme, u.Host) uu := *u diff --git a/internal/transport/option.go b/internal/transport/option.go index 9faaeccb..54acc9b5 100644 --- a/internal/transport/option.go +++ b/internal/transport/option.go @@ -10,6 +10,7 @@ import ( "time" ) +// Options is transport's options. type Options struct { // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the diff --git a/pkg/altsvc/altsvc.go b/pkg/altsvc/altsvc.go index 8b095424..b464b05d 100644 --- a/pkg/altsvc/altsvc.go +++ b/pkg/altsvc/altsvc.go @@ -5,11 +5,14 @@ import ( "time" ) +// AltSvcJar is default implementation of Jar, which stores +// AltSvc in memory. type AltSvcJar struct { entries map[string]*AltSvc mu sync.Mutex } +// NewAltSvcJar create a AltSvcJar which implements Jar. func NewAltSvcJar() *AltSvcJar { return &AltSvcJar{ entries: make(map[string]*AltSvc), @@ -43,9 +46,15 @@ func (j *AltSvcJar) SetAltSvc(addr string, as *AltSvc) { j.entries[addr] = as } +// AltSvc is the parsed alt-svc. type AltSvc struct { + // Protocol is the alt-svc proto, e.g. h3. Protocol string - Host string - Port string - Expire time.Time + // Host is the alt-svc's host, could be empty if + // it's the same host as the raw request. + Host string + // Port is the alt-svc's port. + Port string + // Expire is the time that the alt-svc should expire. + Expire time.Time } diff --git a/pkg/altsvc/jar.go b/pkg/altsvc/jar.go index 60ca03ca..6264bc56 100644 --- a/pkg/altsvc/jar.go +++ b/pkg/altsvc/jar.go @@ -1,6 +1,9 @@ package altsvc +// Jar is a container of AltSvc. type Jar interface { + // SetAltSvc store the AltSvc. SetAltSvc(addr string, as *AltSvc) + // GetAltSvc get the AltSvc. GetAltSvc(addr string) *AltSvc } From 56b46b91c9d580aa63689e826ca3e969d8c77289 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Jul 2022 09:48:04 +0800 Subject: [PATCH 529/843] Add Client.Logger() to expose internal logger (#132) --- client.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/client.go b/client.go index e9d28e2c..32e81e03 100644 --- a/client.go +++ b/client.go @@ -878,6 +878,15 @@ func (c *Client) EnableHTTP3() *Client { return c } +// Logger return the internal logger, usually used in middleware. +func (c *Client) Logger() Logger { + if c.log != nil { + return c.log + } + c.log = createDefaultLogger() + return c.log +} + // NewClient is the alias of C func NewClient() *Client { return C() From 69b88f404bf8a8ae36402cc489f39ef59a01300d Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Jul 2022 09:50:04 +0800 Subject: [PATCH 530/843] Cancel set header in DevMode (#134) --- client.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/client.go b/client.go index 32e81e03..f078cfd8 100644 --- a/client.go +++ b/client.go @@ -376,8 +376,7 @@ func (c *Client) EnableDebugLog() *Client { func (c *Client) DevMode() *Client { return c.EnableDumpAll(). EnableDebugLog(). - EnableTraceAll(). - SetUserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36") + EnableTraceAll() } // SetScheme set the default scheme for client, will be used when From 7f72f368acccf062503c5b11e5793e74a6343c10 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Jul 2022 09:51:04 +0800 Subject: [PATCH 531/843] Update comments --- client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/client.go b/client.go index f078cfd8..238cf5d7 100644 --- a/client.go +++ b/client.go @@ -372,7 +372,6 @@ func (c *Client) EnableDebugLog() *Client { // 1. Dump content of all requests and responses to see details. // 2. Output debug level log for deeper insights. // 3. Trace all requests, so you can get trace info to analyze performance. -// 4. Set User-Agent to pretend to be a web browser, avoid returning abnormal data from some sites. func (c *Client) DevMode() *Client { return c.EnableDumpAll(). EnableDebugLog(). From 76558ca0789caf90ddeddce0f20a874f432f3756 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Jul 2022 11:52:34 +0800 Subject: [PATCH 532/843] Optimize client and transport --- client.go | 34 +++++++++++++++++++++------------- transport.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 238cf5d7..55466888 100644 --- a/client.go +++ b/client.go @@ -387,6 +387,15 @@ func (c *Client) SetScheme(scheme string) *Client { return c } +// GetLogger return the internal logger, usually used in middleware. +func (c *Client) GetLogger() Logger { + if c.log != nil { + return c.log + } + c.log = createDefaultLogger() + return c.log +} + // SetLogger set the customized logger for client, will disable log if set to nil. func (c *Client) SetLogger(log Logger) *Client { if log == nil { @@ -744,20 +753,28 @@ func (c *Client) SetTLSHandshakeTimeout(timeout time.Duration) *Client { // EnableForceHTTP1 enable force using HTTP1 (disabled by default). func (c *Client) EnableForceHTTP1() *Client { - c.t.ForceHttpVersion = HTTP1 + c.t.EnableForceHTTP1() return c } // EnableForceHTTP2 enable force using HTTP2 for https requests // (disabled by default). func (c *Client) EnableForceHTTP2() *Client { - c.t.ForceHttpVersion = HTTP2 + c.t.EnableForceHTTP2() + return c +} + +// EnableForceHTTP3 enable force using HTTP3 for https requests +// (disabled by default). +func (c *Client) EnableForceHTTP3() *Client { + c.t.EnableForceHTTP3() return c } -// DisableForceHttpVersion disable force using HTTP1 (disabled by default). +// DisableForceHttpVersion disable force using specified http +// version (disabled by default). func (c *Client) DisableForceHttpVersion() *Client { - c.t.ForceHttpVersion = "" + c.t.DisableForceHttpVersion() return c } @@ -876,15 +893,6 @@ func (c *Client) EnableHTTP3() *Client { return c } -// Logger return the internal logger, usually used in middleware. -func (c *Client) Logger() Logger { - if c.log != nil { - return c.log - } - c.log = createDefaultLogger() - return c.log -} - // NewClient is the alias of C func NewClient() *Client { return C() diff --git a/transport.go b/transport.go index 205a7566..db144a7d 100644 --- a/transport.go +++ b/transport.go @@ -192,85 +192,102 @@ func (t *Transport) SetAutoDecodeContentType(contentTypes ...string) { t.autoDecodeContentType = autoDecodeContentTypeFunc(contentTypes...) } +// GetMaxIdleConns returns MaxIdleConns. func (t *Transport) GetMaxIdleConns() int { return t.MaxIdleConns } +// SetMaxIdleConns set the MaxIdleConns. func (t *Transport) SetMaxIdleConns(max int) *Transport { t.MaxIdleConns = max return t } +// SetMaxConnsPerHost set the MaxConnsPerHost. func (t *Transport) SetMaxConnsPerHost(max int) *Transport { t.MaxConnsPerHost = max return t } +// SetIdleConnTimeout set the IdleConnTimeout. func (t *Transport) SetIdleConnTimeout(timeout time.Duration) *Transport { t.IdleConnTimeout = timeout return t } +// SetTLSHandshakeTimeout set the TLSHandshakeTimeout. func (t *Transport) SetTLSHandshakeTimeout(timeout time.Duration) *Transport { t.TLSHandshakeTimeout = timeout return t } +// SetResponseHeaderTimeout set the ResponseHeaderTimeout. func (t *Transport) SetResponseHeaderTimeout(timeout time.Duration) *Transport { t.ResponseHeaderTimeout = timeout return t } +// SetExpectContinueTimeout set the ExpectContinueTimeout. func (t *Transport) SetExpectContinueTimeout(timeout time.Duration) *Transport { t.ExpectContinueTimeout = timeout return t } +// SetGetProxyConnectHeader set the GetProxyConnectHeader function. func (t *Transport) SetGetProxyConnectHeader(fn func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error)) *Transport { t.GetProxyConnectHeader = fn return t } +// SetProxyConnectHeader set the ProxyConnectHeader. func (t *Transport) SetProxyConnectHeader(header http.Header) *Transport { t.ProxyConnectHeader = header return t } +// SetReadBufferSize set the ReadBufferSize. func (t *Transport) SetReadBufferSize(size int) *Transport { t.ReadBufferSize = size return t } +// SetWriteBufferSize set the WriteBufferSize. func (t *Transport) SetWriteBufferSize(size int) *Transport { t.WriteBufferSize = size return t } +// SetMaxResponseHeaderBytes set the MaxResponseHeaderBytes. func (t *Transport) SetMaxResponseHeaderBytes(max int64) *Transport { t.MaxResponseHeaderBytes = max return t } +// SetTLSClientConfig set the custom tle client config. func (t *Transport) SetTLSClientConfig(cfg *tls.Config) *Transport { t.TLSClientConfig = cfg return t } +// SetDebug set the debug function. func (t *Transport) SetDebug(debugf func(format string, v ...interface{})) *Transport { t.Debugf = debugf return t } +// SetProxy set the http proxy, only valid for HTTP1 and HTTP2. func (t *Transport) SetProxy(proxy func(*http.Request) (*url.URL, error)) *Transport { t.Proxy = proxy return t } +// SetDial set the custom DialContext function, only valid for HTTP1 and HTTP2. func (t *Transport) SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Transport { t.DialContext = fn return t } +// SetDialTLS set the custom DialTLSContext function, only valid for HTTP1 and HTTP2. func (t *Transport) SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Transport { t.DialTLSContext = fn return t @@ -284,12 +301,44 @@ type pendingAltSvc struct { Transport http.RoundTripper } +// EnableForceHTTP1 enable force using HTTP1 (disabled by default). +func (t *Transport) EnableForceHTTP1() *Transport { + t.ForceHttpVersion = HTTP1 + return t +} + +// EnableForceHTTP2 enable force using HTTP2 for https requests +// (disabled by default). +func (t *Transport) EnableForceHTTP2() *Transport { + t.ForceHttpVersion = HTTP2 + return t +} + +// EnableForceHTTP3 enable force using HTTP3 for https requests +// (disabled by default). +func (t *Transport) EnableForceHTTP3() *Transport { + t.ForceHttpVersion = HTTP3 + return t +} + +// DisableForceHttpVersion disable force using specified http +// version (disabled by default). +func (t *Transport) DisableForceHttpVersion() *Transport { + t.ForceHttpVersion = "" + return t +} + func (t *Transport) DisableHTTP3() { t.altSvcJar = nil t.pendingAltSvcs = nil + t.t3 = nil } func (t *Transport) EnableHTTP3() { + if t.t3 != nil { + return + } + v := runtime.Version() ss := strings.Split(v, ".") From f1c71411de982bb2f88479763d1342af58a938c7 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Jul 2022 12:03:53 +0800 Subject: [PATCH 533/843] fix http3 option DisableCompression and MaxResponseHeaderBytes --- internal/http3/roundtrip.go | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index ba510c6e..208d81cd 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -27,16 +27,6 @@ type RoundTripper struct { *transport.Options mutex sync.Mutex - // DisableCompression, if true, prevents the Transport from - // requesting compression with an "Accept-Encoding: gzip" - // request header when the Request contains no existing - // Accept-Encoding value. If the Transport requests gzip on - // its own and gets a gzipped response, it's transparently - // decoded in the Response.Body. However, if the user - // explicitly requested gzip it is not automatically - // uncompressed. - DisableCompression bool - // QuicConfig is the quic.Config used for dialing new connections. // If nil, reasonable default values will be used. QuicConfig *quic.Config @@ -69,11 +59,6 @@ type RoundTripper struct { // If Dial is nil, quic.DialAddrEarlyContext will be used. Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) - // MaxResponseHeaderBytes specifies a limit on how many response bytes are - // allowed in the server's response header. - // Zero means to use a default limit. - MaxResponseHeaderBytes int64 - clients map[string]roundTripCloser } From 21baa4c7531237a89355e56a6325391f42ef36f9 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Jul 2022 14:14:29 +0800 Subject: [PATCH 534/843] remove unused HaveCachedConn in http3 --- internal/http3/roundtrip.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index 208d81cd..96870559 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -137,11 +137,6 @@ func (r *RoundTripper) RoundTripOnlyCachedConn(req *http.Request) (*http.Respons return r.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) } -func (r *RoundTripper) HaveCachedConn(addr string) bool { - _, ok := r.clients[addr] - return ok -} - func (r *RoundTripper) AddConn(addr string) error { c, err := r.getClient(addr, false) if err != nil { From 7eb96177e2d24918d25440c18558f06d09e2d0de Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Jul 2022 14:16:13 +0800 Subject: [PATCH 535/843] add comments for http3 roundtriper --- internal/http3/roundtrip.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index 96870559..e45ec85b 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -133,10 +133,12 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r.RoundTripOpt(req, RoundTripOpt{}) } +// RoundTripOnlyCachedConn round trip only cached conn. func (r *RoundTripper) RoundTripOnlyCachedConn(req *http.Request) (*http.Response, error) { return r.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) } +// AddConn add a http3 connection, dial new conn if not exists. func (r *RoundTripper) AddConn(addr string) error { c, err := r.getClient(addr, false) if err != nil { From bf375264feddb3cb41153090dc034f0106ed273a Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Jul 2022 17:13:27 +0800 Subject: [PATCH 536/843] wrap more methods for client --- client.go | 13 +++++++------ client_wrapper.go | 12 ++++++++++++ client_wrapper_test.go | 2 ++ 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 55466888..06df9e2c 100644 --- a/client.go +++ b/client.go @@ -658,7 +658,7 @@ func (c *Client) SetCommonDumpOptions(opt *DumpOptions) *Client { // SetProxy set the proxy function. func (c *Client) SetProxy(proxy func(*http.Request) (*urlpkg.URL, error)) *Client { - c.t.Proxy = proxy + c.t.SetProxy(proxy) return c } @@ -681,7 +681,8 @@ func (c *Client) SetProxyURL(proxyUrl string) *Client { c.log.Errorf("failed to parse proxy url %s: %v", proxyUrl, err) return c } - c.t.Proxy = http.ProxyURL(u) + proxy := http.ProxyURL(u) + c.t.SetProxy(proxy) return c } @@ -732,22 +733,22 @@ func (c *Client) SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Cli } // SetDialTLS set the customized `DialTLSContext` function to Transport. -// Make sure the returned `conn` implements TLSConn if you want your +// Make sure the returned `conn` implements pkg/tls.Conn if you want your // customized `conn` supports HTTP2. func (c *Client) SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { - c.t.DialTLSContext = fn + c.t.SetDialTLS(fn) return c } // SetDial set the customized `DialContext` function to Transport. func (c *Client) SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { - c.t.DialContext = fn + c.t.SetDial(fn) return c } // SetTLSHandshakeTimeout set the TLS handshake timeout. func (c *Client) SetTLSHandshakeTimeout(timeout time.Duration) *Client { - c.t.TLSHandshakeTimeout = timeout + c.t.SetTLSHandshakeTimeout(timeout) return c } diff --git a/client_wrapper.go b/client_wrapper.go index 56c5e0bf..30b8430b 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -442,6 +442,18 @@ func EnableForceHTTP2() *Client { return defaultClient.EnableForceHTTP2() } +// EnableForceHTTP3 is a global wrapper methods which delegated +// to the default client's EnableForceHTTP3. +func EnableForceHTTP3() *Client { + return defaultClient.EnableForceHTTP3() +} + +// EnableHTTP3 is a global wrapper methods which delegated +// to the default client's EnableHTTP3. +func EnableHTTP3() *Client { + return defaultClient.EnableHTTP3() +} + // DisableForceHttpVersion is a global wrapper methods which delegated // to the default client's DisableForceHttpVersion. func DisableForceHttpVersion() *Client { diff --git a/client_wrapper_test.go b/client_wrapper_test.go index bf9496f5..40a1ab48 100644 --- a/client_wrapper_test.go +++ b/client_wrapper_test.go @@ -102,6 +102,8 @@ func TestGlobalWrapper(t *testing.T) { SetRedirectPolicy(NoRedirectPolicy()), EnableForceHTTP1(), EnableForceHTTP2(), + EnableForceHTTP3(), + EnableHTTP3(), DisableForceHttpVersion(), SetAutoDecodeContentType("json"), SetAutoDecodeContentTypeFunc(func(contentType string) bool { return true }), From e7a6e5c4f7b39d7fc6fd9e7ed7062e282674f37e Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 7 Jul 2022 18:02:01 +0800 Subject: [PATCH 537/843] update README: add logo --- README.md | 6 ++++-- req.png | Bin 0 -> 30990 bytes 2 files changed, 4 insertions(+), 2 deletions(-) create mode 100644 req.png diff --git a/README.md b/README.md index 31c17e4b..b879acf5 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ +# Req +

-

Req

-

Simple Go HTTP client with Black Magic (Less code and More efficiency).

+

+

Simple Go HTTP client with Black Magic

Build Status Code Coverage diff --git a/req.png b/req.png new file mode 100644 index 0000000000000000000000000000000000000000..28ab2c64773bfdf5defcd9cd8563dd9db8db14a6 GIT binary patch literal 30990 zcmWh!1ymGm6dg)hNu^^!Lb_8xK)M78X&mJB!+J07k996IZKb{6gGUZzP z*!biJ*sR3KV?Oph)^2q-ontd42x9gZ8&EG6dNHCG=WpMZb>)Siv(t2`*pN-Pco$nkiKGa&|6J7MtP z1learM1o=oGC4>2)nRTlQ`m7E3r&Zj8lc3;!43Nis}IAL{2nt}WbgI$$vj((>EZ_u zz1xhrp>QXPi0XZ3DE#;e`*TlWVd2Qg$jQc_W{&8)`3tTqTSbCqu0t$19KITKg}}mp zJ~KZLZEW=N@wvLZjObf)73B8BSj>>b5T%5E9;^-?h{+Ukm}}X|lkhv;9CvqE>}`_ksi^6^T?sfQ442*0 zx1@p*sEY_9`UO8ba37u8wnlppQhp}D2edj&*L<(dy|L(5Qu^I zNyrcd0HE=|{3Gg4NQurHffIvfRgD{xS)qMPuC=GnUfjx!PJ~9R%&x4&t}Lp(&|6n2 zKNfRk5LK2NnLsv3<1uFTaGL(($B*&zZ;DiK9ISY;?TiWy#Aml<*D*85i>EsfW|;X zj9z_zIwfPzK+HJYn67*;hV@I(ZTgp95tdr+^C-5L`)pp0x1evnu?)&4csly>LKy%+ zsaVYikn|?9@NZ`I2y)X%Oe2T2w3CZwGBMMLB&6qG=n>mb3?4KfKR2ZW$+5DB!88Si2dv}yVAF8$-ZVJmHceiNK z_vo$;k1Q%LKMm{larLgAX5YfnJRTYZ|oWPGe05 zO)}j@>s}3oL(65p?{T>aaa^QLOI%AkN>le=xH8ko)A*!Q@|A) zcjY?*B}M*9K;u^%ZW4S(t`EEQ8tDbmtS$5^XT@1%<65wUftOHs9dH!b%fKB}A>ijr zSM))8rEB^i1^d2JI?J+ z;kfcKK+8_U?zk%|yD(j+Br3b)&Cgxg41P(&fwr$g3&;$#B&8VJ>>sSgS3E z8`|QSm@$jiWPtaQU8&>i#qm(oSi{1?>M`dff%5-&LQE{ipmS;4sO_#k@A(2$BI2ne z|CLP}r(czwX4fDtH6_l@=s7No4;)8HU0x1?H6N{K^>}`Ecrs65Nj2|T-InspDu>ZM zBIcGNVcMZhTb)_?Q;3YEkus1WXg95-?hQWRl_)6XkG$+?T@XE+)E4&G9Q0TZ6kCeq>$CP6UcOA>pYkTRa`N*K_Q$*4`;3$3%(sCp3u32SpIT=C)E|-^QBDM5@7Ho9a<& znQK{viX--fl+xG$!?$`-&<4%)jz+wGX}Vph+K6y$$cW^3B}F^f9~mNI{# z`d*ZCMwYt;#f5lD%K;>XHDD2Nuv#{~z}e?d532=l2&P5NY)Prooec>t8yfsE8h^zI3XV@Y(I z<-!F8Sv|$s8T61^t}Ta!*8bk#GM3tSa250fFbiPr@GSY0Q$(mP`4;fU0N1J|)%l#x&AO`Lu+LCdjgP)F;%P|(oe zsXDp;{J zga0UlFDW!^fkh;Y8IuUD5IlfpzSLl68XX-)lM#K&uMi}s4kc$m-+98%LNm&@(cj-MaZuZrmN$cA4lnqo zl1J1OK16hrmZ#w@6BdIHC-HXg4Z}69@ln;Z#jHKls%>qh0*aEE*>!fQ;3P!KDueuq ze1jV3-QI8~Xz`99|AIH0!Mk^8KyMvxfo9xt6M8D)H#eF0WHb`q2Wfe*A=JN^D=Xsy zI1UG4m8o?rKP5lKYy-bQ4t8QTJ~9`k5tlLvQ7#kYQnNB z?;`yF@;7hZxVhDl$*S}S1{t0GX8a+@mZXh0${TZccgOX*G_j{i?hPncAu%r_6HBKz0vz z_oltdmd=F|wv6m-d?18DnM8RZf|(I)Y!#uxuB=is??5qtYO^78YTOUHu5E~fYCqoh z2S`0!(73BreRfn)E6fQ4o*&n_V%fJYAX((Er$0|T$dlPD7-|>>+q48VN$k>;U*usYbK%fE7O&UsY3<^zt z>BETLa7;zL)WYm+xba4o0-!DGp1p)0I>P-{-m?As_wQ9zPO6hOmZ*N_rFrV(>B&2g z=!QwuWcTu8Zykf~GYj}15Tm=g4?ZImH8u3yQsD_|WEmM5Nr;P!iV$ToM;T*-;4yGL zV;oi!Y5$6spu!YwQDOd)mUiIZqD{Wju2?KV6D`aW4E|2k``ctlRB4((1YJ~5TQtp+ zWHP0qfpU{4P%_q zy7jE>Ib+l6wlF7^HBmhhfMPzwns@WZ!C6|~#lJV*L7eHA1`WxJ60h`rxo01FHrB2&SUlzSawZLNdjmsB;L&ssQc_kq|2zV z7u6QKhNTh>QNS`v%B4U?{zrY`NACFWITKU)CjW~)NN~{Qq78lEzaEj*z@NuooSd9E z^eTN18|S4St~W9TY@9n=o>EJD;|*+_VAJ{jIfSk{J2_q41U5D{>Xd3AoS=yRS04yu zx>z+GJ;d2pNhKl4j$4K#&JLeCb3ibbubR>n0Fszf>8GHZuj$pA8y7BEVH`4p^L?P36-<$b5u8fygXPmay*sWP(<8Xul0rx z(ghsUTa7+>C@(3oPB-z+6n0wrKMbM}dT(!U6e%mW@p+Ikf^5PI#J?;hLlnnFW29n~ zY9+U6ERs=D*U0XD4u+ry6&Ok>TcFh*MI5iM#q~XvIG9%D3Hn>`B@c4D+sX90<>gK4 zhH8P>g576P^gLK{1;3l!HpC?)&~AlwC@}r}{ARx^ zaD8Yu`1bC-z&mka{_3`lUH(tnNz9B&qI)7JJPJ_SNY29W&F@D}S)_{-6Mkm=`1PvT z=_Td7BOFhJd&y2nUtCyTwiWRpD#rcZ8-S}ilUAIOjKZS?XX;|RiK3$$=3Co5?BnoCqNjiSm*+L}w4esH_e z2r}tL8jg6`7yY234AKw{H5-MLQb{rCKI7HZ*HK}lHCkx0ZYG_|tG=y4`?LJB`V@JV z2YKk6@g2Yj*ZoxZjyry}j--4aq%x>ysLw!j?KW`xsKlP3#`1VG8H^qxp`|76J^y~# z57=Gk6+W~aYx>xfsT3h`Sagcd1gQQ3Ih!Ty(^;@8ZgGjDYyP z@RQ1XZ)y3Ys(__8HT((bKe)*FBY@uLy>($5f4O~Jd_0P*7UMR&O`Fe074w_Hdz+b) z?cIW!-G`I-Wbd1U-mufaW0*kmUL>Jz2o;4{nvv1b;IInbJ6rZ<+o{61xTgRxK->am ziDcfOq1=cV{dWye663-f*4J#yUM{r^Y;A+)mhOUrYCi@9L)gWdYSq^kwDEMO=@QF* zCp)d)&%b?c&q(-oRm%8aQdJ%xid$^5k0RX#`kv{b%P@d!?tXorc#W(al}4_2IJiji z5XbAK%}2CAYXciR)z&h;2JSJd!9H1X&J<3xB9GeQ&pToXAVp-AH8uLu3sI5am%Ip9 za>@ z&zR0_(_)a*@N`ROQqtgqo>ew6vof;|q1|i?PDY8(43C;Z>kO++j@y8Bdq-?{{LGtl zJ62{|RUC=gkI^`2gz4#g8X3-G=0;Jj*nkqF#~Hp&uIrLT>0MC~GZeg)K0f+Fw5hx; zC}*F!_PoPNFOi9u!?>L>-Q8B{xIB(O*ZXjrN9j=_MkEGPd^|uceloRD$ZyCL|sw1v}id z!Z&sIIW1li#+xl!zj7o^bZD*jLgQyYkZ;BZ6pm||08c1XcQ%%TUKe_vUDKH-QX14B zh%ier6Ib!MQNmj@h|822+|K$G!vxOEBg&pP`;sYaA%mK41u?LusM3QXFkW9RpM8&)LI*gAFR`MrikG^T7!&N|Wwd`# zVd(MW0pyKdGTJJXN4R!bO&c?aeeaCs&3&Hd6Z!C_Kq=c!bGJs;gi^cKkzp)glS;=+ zRYxl6Jx>H<7DT4Rsc&FaE5hTOvn(EnJra|M!Cl@u^TgHWC#ACmfzPtle6GzAa=cxW z?qR2DSpM9`#>{uKV30oOdbnVv`e!g?Mt?O>*ee|&S758J&jC=ZbT3(xMjCtg2yyP& zeb^r75c!T^$bo4}2gSd>T(}li%crB3{Kzj|L|j7jZy$45fh7{kaY>#6h2KCbsfb zF{3(>S(72&&ae`SAE|^h{_j_xiPPwI>ihRJf$N%qpK@Hwj7TmTp?m|&&hba?ABN=> zRXy(OkE!)I_mBQ!0_JpxB6ni8E?OWiVvmg z6q&A_%xsJR@uB)J|KS5H2K9M7Rh1u#@fB18)!&L@Jh+JC7fZ@ZvXhgMGfY>T4sU~- zA1t8rC%)|>D}gr$elA#de(>1LfA{M=>t*vt4bPN!MN3Mk)jYf^Z(E^OZ!Hcj2UX4m zxiUrF)>-odPd9j4Yq8iel$eu-i$=GW%oBg_NZtNz?znmaRJXkH&WMUpwO|6P`P_Km zjaiPVg%Qh$%laUDX=8IMUNkQcNGZ~Ycn{{TPZ{vj?K_leDY;GYCPm2yc!2k_bPvbX z3*s(|?BLe?uT{7j#XXP&>0hY1t9P@G<_MW+ot4;E`K@EbY5b z_YMY9fCMn6({zSFER>bs*}Zjpt6Td`3r099IBr>Z94#YXGbxQcoYxCH8*ebVd_1%Y zCgK-6j^^n!6}{SsH;rLY+s4<6&&E^2EJ8&Ac$KQOYJ`O9fQG6D}7-QP=pydQXXVl7zx#bSOpEm&_qEq_~cS7mpE zsr}lJ|LmFd5;V%ss+nE)H+oVfnnSWW5Eer%b)3^ulQCuj9hC}URwfFm4j}EwcQp_y;{Cd<0FY`N4@R;V7c&ZbXHatC=mzr1pO20n;oo# z7Iiv%Fe}%KePMYS)d?cfC;bv_*oEOI@`m>kkj_@R&A-HC0y(@P@C#U^XtA~M>G)qXj3(|nsr zxwvXLB-Ay>fSG`|=sP;G*iMc^Rtcl3kiKFB^DuvMNoIpOj6XTbmic+?LRvBZCM7t8 zKzoF%Opan+%qGSS!D~!cGq~g7*KTSiGw;P-oHE_SIas&lPnv@pDL^;I^LNiSV6U^y zrTOmDyIC+jlGes{$>y(kACNux43Sr;Y15g{CPT2yamNGOYLFg3BzZ zZZB6POF9<=2tcJPu_CMOHcO!lQQUX4HFOW~RY&*}Kt@m0zBH#$#UdLs2_TGL-MY*? zsdah@P$>7b473m*I=L7<2n|^tULTemHM+R@>GIQQGOts@`dwM#ff&Al_bbdBRQiqK z{1`a5)J*~sirV6}L?{}HMy&mdntWC@@(yjV&S+1c>-zlrfNll;*;mqGUryrjnyd_c zQ`T_yC{d9jC$E`c>V4;lnq_XNqvQVcBuR(UTL}%xhpw8G_!{TCl85n8NZsxIyhrrt z!^M!?7kclqR(VCmL9gKJzSZMSY9Mfn+zXpt*nTo7^!wm-w5lMK8pyQ-Ezc~p{uw2& z>C%3j$Oitw@*mDSkw>GZ=fTg}eaD67mwh*4yKiD1+q_Q~ceY|*HEv9QRS362BLaaV zoPC3`z6{FM33*7}o_-3@?|8rij3GB$&=vm*u3swwM>nX>nxk`-A1+e%c=Y@(?xdk? z0_hh;kwR%Qtw&P`Q;Y04ClV#{*zOj_a2Tw$wbTk1I1u$(`h~b;V|%6YYkgRr#ylk0 zkrbVH?r~Z8C@gHq*g1Z5Lqq=ig=*I$q~|V8D7IU&zfmGAK!Uig>|^E8si-$eqT4w6 zv7L+2Nw{a_S6XYTmF0nO`o&?T{v(07puM-65QvMMoQCd8@DlRTXUoQC6&=v*yqRPF z`804VR%Al#Z0z##GUuI;#N&F>VVsn(fy7tLg}P=ii!u&ip5s7Q!||s{z2)>!WTk8S z>(Oy%F2B`C+i_=hdOCDU(~e^3IhIT7ZV5``S<=z=Kkv$;^ynV54Pah`3Cx&}Z(XL6 zC5DAR`mKLoHKQ#z3tF)F;^pNvo-5%qRxn6=H52?A;=8td_CnILi|Y9F3BH2*KTNzZ zFy1cjfKVS(I1_D!rag)Kcx`fB^)~rQ^^@3*UkK}%{03Es8!oOL3u%+pg&C2^R}tz! z=RQ$t&|bvarEhVHiPJd83m)UTfAEz&Q$crpj)ChlRe(4$gb!dW-drY1qqH}rayC0W z-O+WFhSiY}nsq&6P@r}F$EPV-6%d$M@t z$-$7Rm{rqMGkq$D=!~_M9$cAUCe=AVl6>_yXYnUYE)R4_Sv7K`gCW9}7J6*|%R7@< zLZhi>uBR2pdPh>_0Uwx#2;U=_*jp^f=2H7X>%*U~oE|p8K!)cw&(qEwQw~Z}Wf<*P zzuF1Aew|AI&U_*iHCV;POtz%q?oo=ke;- zMtaBf!IUCBLPW&8i1(B%W^hADNQjn}R#EzXJ&u)Tv8Bl3Gv6>EquuNt+PeC9Ax)R% zGg!42aXT#C@wfKNQJhxnUK~YP6)BE=|EnyD`e~b|4I_ zz4}rp$c!12f1<^CfT>_^_T=aZ@SE^8D+T8(4jxV}9!`!)+~>YQ5cw2s603*dX2JJ? z4v)XP2&M72?A}Qzo4`D2(ZXE|6J5)`cX*BexZECyOVb-|kio}PRoo&=u}7g;BXRF- z8e<=UfsX}qT4(&Igye~;>RyShvkUSoR)-S@U_OvKf9Q93q36OpHJiG&x}?m{NnzA@ zSNtqr$~%p0?w7#i;>K48^iKx8eRYW4Dg;boWRK# zMOIq{=BhNk^qU7=<~PMo>ja9y8W4RUYcL#^sV46)TTmURrD&vYq^hau!cT?GAn1;P zWhI7H906Bk4hNT>{r~_i0pmzX;#~qEE_H=ALYG?CnJ}3*y6w;ouf-j2{zByam*ulR zzXUYm1bvLwT!SBqk7|=hC2UkweY?FB;%YnPyg@w&bdYtMmS(FLLACgq{d8ovXbZ9%(?_8 z-Nk+RvCdVKjV_i&Jv$7;{M*{u4!OVk>FZ8oa@|+7M%7kyyaO#qphC+cj+uH8?sAg`=!vMpP4ktQ&?CaZ6`ie*!}=4_gup2@ z-}N>i-#vCvLV_?~{|dXgIyt3R@X%5=!@8^0cgWY->E^67ul0I2`0-*Tfajr~M>vxa zgRhq?5CzdDrH}{8J0`BzLBac`O)O_(%68WnD}{jqH25My3fWW}h}?+Dd=Y;&RDVHu z0v~E@OvUfzfgyT{waUe_Oth)Axi_72@*x>Ix>j$&9>>i$Fu2PTK-0Ng#DMeAvY~BJ zL!o8+6f-)oM>}1ljFKBjV4lNG(JLVxWKr_5<#+R=up zq6T%M=OLi40D!P+UUpW8Q7V?$5hUVZ%19=vDS-ZnvcJIhpzfsUJM=+%EUy3}%f>Ke zm83YbLYFoL`4ZV78Uea}yJ*kqIuZ+K-&2kOEY-b)GuOea>WH`A{OEwd!|@U#I4FGE z`wR6vflWS!eA!9w6#22K+eQw&8ZD%?FHUl6>jkB3S!c%?^6kmQ)A7U0)YKh4J>=Ku zoFd9&+tGvCD7MnVdbcB1O}geC=e9#(pQB~+`1$DN!@>%JW9(vw>NNt;p6ckP*+behNbKTm0?e=e_Du3Jrr39#}kMq{Fy@_GQ z0@oQhIaQ^S-VMX)JxqRXxB(0xO;ZRz31lrM*-&ax=GDa;AE}0Yp2=O~>DnGO0no=7 ztXCXWN}KzNbaw|eKJrF}00Xp{TZhtUW;KXi`shAu?$G6+bb{oigne@>sk>P%5g zHv1g6c!S=K!A>8U0L%hA;X(bw59E=kZW3E@FBMT@5yf>FF(vq`K0`Mg8NhXRq<{#_RLrXUDF+!)ro>nbZVxx*f=1eTerU*Z<5@rcGYUyL8iM3e$@1 zz9;K(adCA42U}d6G0@UYnt8$T4ZrX@~eS;8)pFY>v$O;PY^Q1B)}5R7#BOx`XN9duJq=Nr+G7jz}dwCs$v_1v>9e zN9ozIva_-P2Cdg$sU?}-F88#TOzu?q?zZ&AD|dkr=^owr4<83O69vC?H8X&{1*r__ zF38A^d`6v;CnRkaM|n8?*1EXPq>l}M);f&Ks-Hbk`4Qyz@>Ag}nc9zQ2AU3_W%48Y zJ*OH)%hti2VJ>&%8KRS6-lWXIzAD8$S#2v2U53Tl5bu-P2`+O;y$^_i&A%8Id-q(6 zoa)e^1l&){VZpkg*FG;y+N%NZBMwK|D=)Xw{XFn8trx5%S7228hmj3)q@qBAk%&JE zq8+*RmiJ4n3)6XOc6VU6Dt-0qHi9GwSu^;w{a$le3re!7 zF+EmQreLB^UZjA>q}+3P=?(mQPWCQ;yzA(!A?Tw2_9+v3$Zl6_y|1k9VA5?Yq%(4; zIa!UNZ#s8i_UOFrb@K`w1@iOZvrm>(KDE6ejlE%I0QyRftV=DkfpEz;e{@_{ zO2U_*+>^XZC1lo08hOaFE0RDnZZy3j;C|P$ZEQU|ZUh9B|5d`liBV=xeHmZGL}S!~ zR}5BJ!x{9h#1v~{Mo*CG7EHIMSTRpz)%r8c&Myf8YDuPCZNQ-E)EG}VF8SCyQTtjs!uI5O9BM=>KFHCq8l0r{>X#m0W3e5;+bXk5fs=JVL?Ia2N zgNd$Ywq=r~$S`4$`55_-+tyY{n1NBi&Q-MD-24~n{VJs}qFl@ARI|T46%89GUfLTp zogI1JRJGnZxHam9eY*S*nzNpD_w`U&Tu+kZ*hy2ql_W9F)+c(}k5B#0p2U2oj#4&8 zFti!t%*Rqu zQ$KDNL(E}2u;DuVacPQeL+uKQ5NpW`l?`@NJ}S$S(e=`>N^NT}5KkfSf}dRfdyN<} zsH|%A{W#9ubv?T*ji}8_9Clq$W=Zg+#kPTKVrTZ=T=l-KC z8n1HeRWI@a^d}Yspe3)SX2F-2&cAok#A$e!h!<~`&z?)VyYSEO&G4qLP7=yl;(Un- zOX;Qjpn0OCt(D5De0d2n%C=F6J5PT$Cnj?Lt^I8=oZjbS!^<#H&fvtgXV;^r+0thJ z#=*~muT~#3OmZ{3p4@jHvK$!bnwX%hKJVXep4}L8r#n;F@6bib8yXo>QBmnt8@E^v zujl2N?q&?vW)k>Bz|jDqzCOWpc>4YXJFg@}^IHNncgb|ttvbduJE&;cR~GM$c0xnA z(~`BCcpOm^FKJwAZ8EOFl3cBxkK4}98*QKUeXZL1!jH?>C#bY4H8r*WWxq32F}&eJ znw6oR=rVk1dTDvtxXb_I3cssy<7C?}Uw9IEK8idVrtZ4eD9x>dX-Sm zO6jCn2Q^RY!AgKbpY@j-b?z!FB70FX&UJBBoAxNL;`+3-XKj3%%#>H?UGo zdI9gs`3=g6@~o`vbZ!%Im!svypEg}rmqrCISA%^IAGsa|`+hpNcdUnjL0kwB)gmuG zZjyz^7gEie^-k%lQTd5){@bR;;FKe~A1#Q)ZEFjb=m@rdt#%a3ZRTubKg!Gb6lU{o zZ;>*a%%po~%kRjKO@c1k;Vyd~i;ZNGhp}Sh0UGhbLN&qkO4?6eReY@CaxwK>Ih{8{ z3mHykG4Y$#RiyK?EIIEjS=)%0?mVk*GxdG&1BS}!8vCM3Tkjrpk2`1$TO{XJURCv& zX+FssEcR<}ov&zF&Kj&h-bYA(Lk3pM#6g>9%nYBiah;{6VD1l)#g`kL^ejUm9?I|+ z`54_bkc$d@?q>7zzf;ndr&Uf)@sDTuBCEb&O;G~iCnTWXd46@8I(F%G0bxCn&=Ch) z(0WKmmK4(Wj~T{y`(SY!DWZjw6f%N2y&zARx`%V~ZPb;NsMl&o@KD2i=D?7@qV+`1 z&G53H4dm2nydkQt43J@P$7)Xbcm04}dGp5vDE|Jf9%KF>VCchfhxHP}*oE-EtmjQ_ z(9{o`c!_q8lR|Ti2`=a=cbE#vfyLEqg`PU@2)Ghj~oL$!N zJPFa!e0^~o{1`Ct$7oai*V`9nchG<%w2;@Dubl=u;_GQlehL46jEOu&4?JSh?DCu# zzHgeUHeWfJ>=5=}MZi0ahCOM41dm|5#jaR)#OU`FYv;v5mk6JU95H&+ zfd13+wwD0`GqK9>ljr2vBOJ{ZIHM|JzW^XarakC_@FY$+^-F;xJ%&@e8cYk8c6YU_ zDjo(O0_-fF6)1Oi7Qb*~nWiyHgfWbZ@&e$lM zvyy*j+DkN~$dkrB?ER6Hc8>`PcOpT=XrGQQ&V^@q{EjFAwwkg9R6>#GdpcY)puk5G z)J6#|E-ps-(!MLi{-~*WydGxV)34^v<%2~H3NM6fp84n4hT0jf{^JGyrKVN+ZiV|A zsa{ASKP2KLcuwJ+4h1 z?o`$_@+onc8n;O~u6(U;F@@53IZ#3Vlvq*MImZEl;1Eo-LpCh3d6iEtZ?zsOiFR91+KbG?h^otu-ny!=7y{GckO5 z^19>^8ft4UI>&25cK=OSAt;!(WQ_V;;mOmbf2{O; zd}jtg94MLf`iYupwweeR5VrG>laHyph^7>JBCJid<{zM0 ziH(Z;C=34X!NH6df;R@`zNsiHBs8sFC9p#ns>_YrM{1f^ZH41d8to+a{NZA=1Saz# z9|uEuY*126YPwlxkC+R3z3iKpWiy|(`njh+v0!F+FB79hMRRvvUvI=k%Y!igxV0I< zm$=&w>Z>Ov#y*QZG;cuevsZt83&)X0;(>E|qRd1VS(z&B5 zCkTuF$h0&^Q8a(25DhzP32q&a>MzmEf0#3JiLEqi<`LGhsw&7cWI0vi;yXM%ymx$_ z=xT7!Q2buC)Gn)b%ls`Y0U6?K zyno^dQ5sB)JNub`fAD}jp0sa)j=$+4g!s&ANLD=QN(?(aB6x=l{>^mj=QiV>6P_8% zJ0bmehvEy57dut-=zhsG5n5^Mt>Eao%iL2%Rgw-ml%5IjA3KV!Tl^v$R_``{IsnQ3 zY0g`A#Bt^~z?_%)8NtJ6_4>tf12yG9MlPz4GCZyshw!lE_9K)@6Acx_G|3f5C7I@1 zC1PY4v<&qkBI#=8x7-=myZ==5_R3DQ%fHg~ovrkr zK~tC|Mh$4>`ngob$5CZL3<{eGqf%LG~;DKhWHI{8ShL+XP|aT$e77F z*nso7jzxGTT4s&g<E7>H%iFf{0YVzJkxB+skWfb3Q*q-pk@x zm7m-M_Lytg!DJ)HZYyj{_x<@ZIk~wtSdvu`!UW$-OIb73E@#=;4^Q>8pD7_Dk+;$_ z&-U2VW?N}Ctt6s;`8W9wWzh6yB(>yZ1$TApUi8Dfay^_s=t16e-ZFo_3k^&kwQ44q zXfAPx@-577k?=_E6oFqUxFWFm{}~9-Ts}vlr_K@*5-Jb6*+~2ATX+}`9A8PJ;zDk2 z?lZ}M-Jw|c-W}H>;iOH-y#l1sPVmaHmv=+s|4jR4g+;|?ykl?vK+<3Ijl&(463L`?<(-cs3 z2aX8$1`aF_}SKtI{a)5#{rA{DXOG zS}Z+z`=CNrNljZ;Nm7!DOEX+S-RFryrn4vs-(hKKK#w5JKO4)Dh26Iomlf##IVrV2 z!IS5=bsVkwdi+evA@}h{+~GMm4vca}%m>9g+1mb{pMXsHL_g;;Ou1<+FE2ME`Jb34 zxO{tzR5puRIPU9Ys<>RXnH=Sg{H5c-!EYTNKw3aFFKFG|JgUUHsX8cFysdL-iCw#Bx#GChuwC9_1a^!oH5CVRDC>0Ua4~3C` zJ@3!xo{q@NG0;Za5tY$l0T&heJrhi1KJAQ-Qh355BG9I$wBx$ZYiD?>j$e~4@MUJ* zRCPVh(o@l)+zZ!d19ot4bpDnVJoyBj4`=}4Dpx2t9=@6ELMBX)MNLiZ1fMAXF(*0p z_1RUB<(&)VMWI;|HMr7D;_c)=dlCX{HOeYBU6W>qDIa>%;H^`AGb+(~pNk*2&sMS@ zq0*Wi|E3Ya0jV{Wl5^8BIT#?PvNf!8;!Db)vi#1Z7DXX*OOFHrq9zpUgeG&w3$j~ig7*1s0xMyQUCVPH9_`Z>)YD@6z)7Xk1QeWZ&C`)&p0ViENKI0waaU_*vo(Ua$Wl2WV8Xem*8IJ>{ z)4i#F##$CoRQsWr0k~EfRSM-DuP@h;<4EI2DCx>=NVbwUAw+s$N677|=}&o+ddm@5C~ zXwC<9POT3pBA6r%%ihJp@L}4Wb#WEK%sFag@MoOo{UNc)Ghq=djMUV)x5IhAi}LfJ z8|5~yGQZ>?@uF8+4_>U^Z6rbhez}UaL%I}(WD8hlm298noS(w6 z;F?OFK5hslZyyw&F4b{1_e<+;AA|k94z7U!@_YZ}olWV_30#rtlk`-2Wd05UD}43t zEYSzzS74O|vpJYrT+D3Yk$1HF|MOuo?EgYl+Kx7eC1+@m^zwarSMKRqG$_^q2mzFE zO6*X09x|vh$cK)m39(RAR9tTA$S%%BZsX^m?tmtvY1&!K)!62~wTWu*y{?d+b3*+p zaEuf0f(KilyV<;GQO8)O8EUXyh}{(oDjxo9^K|sl;O3uMrejUBMSgs!brJMJG`V!1 z`hKG3mYI>HBbjhcH?BRqdBK`AY~<47=M@xW{WMbux@y8E;k{baL%2O-b;Rc3v*&T- z-4#M!)78CinEjlK+C)5nefAZ~NVYrnyoZUWvdHKlqp zGwh@dN~D-s{6tGUd>oY46uoIaA|F>tz91tH@Mwr=!3T{~cZk^MJ$~4jS3TPs;Qjnk zymdX7cDnxA9n6drk-I8=f4pE$Jx5=*S&a>{Gjyw+_bMZ>{7EKzM)B|!4<=|tO}&qY z_xDS{FOjfuW~|~A6m0EG`l$Kl6Zk1fV@TvjiYbD=VJ5l>BO1@KOV(We`}r4Fb5PvT z^XKFNsw&x1nfD|8UH*>Sbk;7DPbL?iRhb1o{N6lyl)8*O&c9A~m3MZkB9G2m`dPuQ0pg*~Yno^}(g{NjuMp|qhkO@Iu?K{;Y zAv*79mKRhH1iY5dX*sA?{SJ#gG#+<81yn(H!`oY54f02s$B))n2Q_^V6B83^6uMP% zY$vK~Z*R{}#RC#f*0f*r=x7lDvFhr-OQlRM4Vp8bucIK@|H$qRggC4N-;eV#%h$I# zoQmzt(t4I{A34l2_0cOe33m9ewdhG6o!uPWBUeigTOIncgfivZJ{cWf1Q+i6qw;SZ z9T+MiK&>zAfg$SG*6|!QKG>f-lcsG|18S@7x2TWX&(HE%vZ$dP&rI*(Fb%Pwx%wLY zqrio&lTFv}6W+N7R}ZRCnQr8+#M$F*x9X$c{aLA|xb>||b%q~e+?%u7ucNg60fJ&+ zr8-ZPkxa~Xl9aE=ueq4$6M&QEcknl^jXrMNZLvC?`T8W+C0d@?AV zRpPf_1=V}tojmYB)zl24ay5C-+YPA1QHy?%m(?&GKBYi_tT(5yZ^eWcqa$437!w$G zm}?>$+wJj=*vY4LF2fE{ z5?E94*=(OVE&>!d!HUs;1vYj8+Fm=aLI0E#~tS z`R`e3mm?4+gKLXtRa51|=u2$6;KTZX2@~%2^KV@bEjGFWVyJ1dXX!JUW8|C>JeiLc z!gP1}Sbf)z_2RwDUDbHD$nn*GC0%7yR9)LXLxaSC)JTUS0@6~_2ty;?C@J0DNP|On zDh<*ljUe4EEiE-PNWbU#z8~{r)~q$_tbKOe_f?y3>3Ctk5>;&Ep%%aGh|pfg(?zkn z_tFtWm%~UaJ&YfBy+I?XF>_6MoXrayL-%{Uj*hh()JK=q*JKZ^jeaICv6vsV?G!b} z25$%|nv75UA5j5}^$1Yp4$8h1)895V6q*Y+JtUj4>DK-IPYaK1+F9}3p}JVhcOK1f z8>qx_C#SKr(ehKZ{}@LVCDOc8_c#5ZQ3!|e9!MZ$uRCNCYxuY{9PL)lJfYaWX(X~mM`<4`lh zfaHuwk*BwBTI=)SY!xW9L`3IQ-2en&7-x%?uWDpXvwJkXvSKkBG0 zvgoNXJKGuazbHsaRKW~@Y3Xy{pOz7FE-dj-R?iVMukJ|2D-SP}bsecD%lE__R6B6L z2R;4ltssycx5WA(TY2nUbhQOhU4?wR3yt z*UjN#z^7QtL~LberQ2rwPEDq{9%Dx++p9O9d#q%^nLsE#+Pw}Hp#4_3rI{q`+~qiA zGK!WN*`PbLmp~o%Hy=V_Bo}sQYah75_59g#0D;T z0Y=>xn!79Sv35@TiPJ(A`{BdcvY{sg3)Yh%MPsGeT(fou9Ewp4UcKt0o+le=Opnyd zE6$>~f<;#9b@ebP`5ej?Nta`C+udD|40W;pQPyy|hPqZvML8QY8ygk|YzJY>lrT^M z?vY5Oiq;%x`B?Y;M)M-mt9NwAP=(jI@snO72{u8|@`ac~W0doiKfdU=POY)0q~!Si zvUL6*P$+hP%2JtiGkdT;@bqa`7KO+Mj@ZPdt<#jGg%eY7JKGcW@gG@I)+P0mq^Lngelv-uTdV8@bD+)YEGrl*>DaGvu?}N>ii1l zxgUG0Mox3CSM5A|osvp^C-tA4!Si!Def#8}eB@#NJeNJP)P1!y=lhhJ z>s?)==%rRN9#L);CoR1l`AU!ka&~wHWaxMomdLH`xxKxJ)(nM6lald(r7fo|4B_kR zz|+B{_q-3MILjufYpL^3_vWNEKw*>J+vY3enpxDU=fUZ{@{(wKYaHh`wc5#E$ExsM zQ6re;G5Ns_JB$0vS#Z6kI83{3vtda2e!1RaK{3m`>wE_ta10Y>|L5m=d-CX8;xt=MzuvqyFgDjJ%|#r+!yCGKZb&_=L`RAkB3wB?f(4>3O8U6VjY9X z3=n*2pt7|UmnyUbQ53Ov&-rjoOF7mS;-tigK!e+fM@`w*D(ISdu4CEp$9K4+w%$T+Qm^uc;dBbm2(kknscSkhr85M z;h5i&tk%I5$LCpnbe(&Zq^ZJf3R5Axv|jOW?c}QObz0`j53xFiKNtfmtFVA}oxA(5 zGL+df)<3sT5Ooa=(s}QV-dArs66ij){S{l{$(wshLP%$VZ1w5&T-{`lLJLA_lmWSg zrD#OSjo*RwZ9WZ=*OyUO~se(JzZWN z9>1kV1&|-8`J!D9^CVZa-iYB}N6G`aUq3R~l4S&ggVE6N1~}d%;lIbLv-=i)my@!)Cr1QuX(kWZ?u-0?c9!y7{cN32b z;GCi)72FSPnTB4NtgUOXdDD2BeR6%0OcpGGeZ2G6FlUcNT1sIms!xmE54qcSH7`7| zY>T@r{&Er8ur)B`eC@oTi&>2t&Htv)P{nfx#cYeK42iq|#dB4KrRJ;m;89b?!sLE) zRB@a{$|&d^VhN*ux$M|i2{H89Kf8Un7@evdqu68g?GP1GE*9ST<%`IuFB{tM+Q_S< z@x8fkH1}V0cUVmv$0h5~ndNt$t+fAlA$t;ht_&k9F0M0qzr1s`xwA|C?^1NE!)2Sf zDbK3R9GOKvF}g#6Zm9!K$}FSbv$N)lhzg4o#*EKFAZP`EOz#; zhfo2U|7g(}|MQ=T&Z8`8b+tn5>@Y%HLcq@dpDVA2%38{GQ!~#aK}wESucNoPBely# zP;P?1+rNP#s^eMfvx>*5shro^$*_548BHdgy3&i_sD~h6!|GLt8T|%5X{G;F)={?K z>0%!Zdcf0ce>@;KiBjqFwvE{D1|jlu5T0ed`zE}k$C>+Tq2HKWR(vzGvDg$kV_yrI zI5-?BFBjO$d3>1j^ZxF)Gct433M>AOlc=Rz4L{nO@x`c!@!!dAT|LywqNdct9x)jI z4S}#b1SP+Ycbet&9$m2C(R(;uWwC#B-!2x{!2-~#PDkcI6;)4yQ7mpqb4uZ>3>uFe z;pb`%`?`}=1^+Ms#^CMCe{yo^+dT_HLbZxNM;voTMn+b?(c3$&^=?lhSV1{pcl!I= zVgA3!+ZiTM*E^EqyKucMA@s&$(qPrr2r+hlE}SBKJbn~ZfDVn6kj>r*M?WL=y_(_T zwltO%fd+iNU&XWVzv)t3oJq!ei~4;CE<{ZLQf+UKrf?tK9w*xRezT^C6!*k44z zhV`^jfb#iO0#k$k+~(j)v#;tr#`N=`_(g+7zo#?PJj6blfVjV8XMbj;ux;^$TLupN zw!aAuvls0!??uGc1^!Im`f??Fa=9_)UscG43-pN3wR681TX3G#VLBWNdRUIhJlVR; zd)1*pJ8);9jn21~~v^qZ=1)g{t9WQlP3JyIJ z-q@qQyBc8^b?W)C;6y*pGkSAszl{Gebgbi%%3t7Lb_bKU_wdRJ+mI+=W{B>u!wm4S z>ck(GnHk_TG&C$UfS~4doAledp6NK;Sk8|+!YmoRG|r~Gi3TFAmyepQmrgZ}7O`B{ zQbJ!Uy-q8hxbu5>{o607YLG~WI67A*QxDAP{~Vy5K^th7)WQi zgYNr@jnNKgeS>*r{KflU4Ca)6H-VUkoNG*AbDB1{eZier_GppbukMZa?DRuhH&*#Q9HLd(ubKvHz}P`4O?mC#cHdO-oM?_UFh8 zS2jcS$A=a?XOTU_yS6MiV{^-Q_+XTFB26ks13h(Pa)lK#qA*A^+jBb&Bt!hhYv1tj z=zE+zuWUX4c2oaGtFq#iVSQ-{9Y*@1+FM7}8Ya_FZ`Vpb`-2~Y%Y`uu?wUPzEl(B^ z^$h@S8DaN!U}JHy*$SC5HP=~j9>QxagsjG^{JZ@B>&J+R&4ghNc>&?!14z5tM!q7i z>FFr;G#?~;yf7W$volA=?sv(|Tbd_y`m?p?9#K|Nsw!%M>z7t;2pU8x7SVwOKQ})a z-|o_wo$tE8kyJt-7e0m|L~nj{9lF6UgAY6u&OF-Q^+x%XH+0Cyvv(>re(t07RwW_D)|fE?P?V zVJ$hs$Zzc(ASN|4AffRL5|G@`z?AJLsIBIH{@GudG1q-d?xRk2mxR&#uUQ-;4wcBO zmLTz6EbgU@zsMj+Xa)^-{>EQOhfqguTiY_J#Mt&F?LM*1r%wO?HZygNchU~Ru?Ul% zZ|RGNdmFIDfu!5o$-{=Ltt9f5x63n$(bT7h)|KrKx*$#kqV!r?bLYELWo2dmMOM7a zV|u(`2Y&D7tJz>zH@tE86QO_FFw;`_zS?#uy&#l;2(fnNefp$1`Sn3`&tQ#OzdYUk zpq8E<#YQL}yAgl)NWhXREpwdp&-cu{*U_!bS}&-*{4Nj)PElve)V;s$(35tOZIzSH zIwc7pVsEL98f@kiOFW~=*@r3$uc&mJM_Mg6hy3D-O3H+ zO?+zLjsLDz_4TXx6{fzYXE$?lRfsqvYputQ0q{rV_}*llwesV41*v8*;h>tK3;D*R zF9m@f&_d()f1yBke*S)b?L>iKcU5U`SF+@sGil&PRf937i4q^byIX^EaiPBdaH zw`*x6aUV(kOv`(T_~Xwi!nj|!A5S6TI*HInSF0s zMk*Xt?^m{w0&apAxYGd*b%YOLK>+k^NC+?~(2scF*r$_Bf)j$D&1)TF;O2R8 zG53V&fFD_GVq%*mN_{p+5XjU41$6boG(tZn#2MgJ*lD8z6-&9>mtvQ1+o z-^ligYtfhV%xruei}qj5!^3OpPB&U1-C9_W3bgh-UlNPg>q}nZRAeC)0yQ!5GT*Qg`+`;51N|5U1E4A$%yNRYPPe$21# z6I43eMKu)wD94-sk~R7oT~b&WfEIr@r3|$qJ2OG$N{**l{*m^_PT6m!xaQD(xQv30ifLgQu4;FEoE#Tg z$Lj8v-I!n3(jsKJi2OeD-$*2dO{l5==wU$TgM=#xx_vKC+Lc~Otqytm!kkoxs0-2e z)F7+mcHzbOsvj*KUd}2NuG@FR1rmGx*DQM}foC6DUyT0bdI3NkP-KxXQ@(^hGyafG;tEG1|6nWMM1br z1X;@I)ybZTN6dUnQv*RdH93Kf71=%reW7I+PlB+pDn5fmTmY8(fs4f(#m_qVA6^pS zhe&87JmHDyL&EYy{yoOx(V8qd?26cb?`-U}d_QL?kRqssk-3uKONnxK z#Ti~js@8mnqmZT*NNr50_c%l%^x+cYUEAvdlD5i{7U>edurfIdRs zZP5{$_6%4J(VyByd!fIT?#hDuOP{=11V(RMJ(&l#h62k|f|fy%#}6M!uj69ThmNk$ zT>`=nww4Wb^j|mt>LC^6;>Pn!Vu+)=!R6OXeU=Y3HR`;!cN7ys)t zK^X(tQSn>FW%rnf5YGosBeu3RgvCnhrjNbj?(goHXdd1!FFM;kHEqx{GC%xclpcLC z9vk3r3Q7&}(q#l&QHI#sQPP&mCNu~nkWTF8Nmavk(hWD)D7M&wTj{SjLB@Nxx!H7O zZaQa}RI;&rqzTq;sC_>ariBWZBMRc$fWme<$+2Zs^9v&b`c>-LUZ9KysJ=x7h~nWW zvVc^t$`rYYDwgG9mtz2EtDW*_ZS$QGvS+(H;$|_rsZ2N480U5F{gsEG=J?l72y^fW z1e(8em56$k)4ox+QQ>_*`GmzS!w${qyTg$e+E*7FEjxUZg;}|u*$7SGn7Hg$0pS)c z;hd`-`LaBjQ1SqIifW;tGDD$Zs#l1mdGSpUGim&{=Mkpy;rdPDoBI3+7fR+#^*q;Q zF|2YBI(b^%o0)TMbZXTELAw|t{bdLKmblm;5z3o0!AjG1xWA- z?Hvm%iQ464G0mb!27MrhB6@UNLSoVWAaRboHnWSU8}<__{T@a{X|f(NhRQ#g3a^Xf z$Xwk-i`m(~K`p9CbKSdbM38tZ10^fRJJLwgUrnoRB<{!NfjcdT?0zxv+Wa(*p zp#Uh8c($X@jr-Y=SOUn6UurJg4;`K1V>Ed9Ot+C!@=m-S&qQ4MKcY}zz zdxzKjc}DkvItDHL`P1eKrgcuBK8i^UA5B#b1qG#7KvEizQvEHvSnYzybdG0KU8V`g z;sn}%`|F9kG_9R{UXnlw8*cEt52~~ojR7X)PR(3 zdVvEJh8MS?DB+Z|*gp`#!x|%XXk(q`(g>$a|EDgpb-wF7`Pc%_op0C>i0G$w$ z#-ca35^w+o+msM{R0S>@&j`dynOauB=@@0>iCFt);SBV>DsqfORhn03nR(ElaCHpt z()^=^3aIMtJE>B2|4>r67;YeEAu(C~3QE3I_BzYOOu?)Zk4Ej!owG{q>P?!Mdlu2Y zi{Dga-^QNW+FW6_Huut3Gg3%*;>)0YQ&GkRGNpoowP`-ea(&sSw&zNVwqQdQj3SQi zF~;Q`Im)vk*3#%D7=>a71X$7*F#ZHRYMLkkney0DDDamd5^dgLj1W#Ksr5{{pbadk zUe{4s9NAJ~#Is@Y{HP;Z!uf{0+3~v6Uv*1hWtUDTKO|37k_D5+jt)!&$Ag{_iw+t{ zM=u}8TJMH>LD=|{PGRpu}#n59(dqTPe<+I^P?Uk8}z9#Z1lL zJ}kWhouc29hCU-O#d<;XwBgWnR@J|Tyl{MRaR`rrFZbJ(oQ&R zS5P%BLKq+dj8t?W_cIEOW7vphdEV*BNZ79~Nioeo<@&|Bm)s(n`PIvPgY*PID zC%$C|+`Z(IiS%cfADGn`WFya__1PP%&aOD`&!4>e%>ghNp;r(S}H ztI)euDzRVpyV@JFb>!>KOp2Rui!abPoO`I|q@#l5ju=rIMgSv~m{Nc2^y@DG7KA?t zHGem~^sGqxWdV=k6WV%ylmG`xF|l4T)rD=v{bp{JsU~8sS^)LW!oq#Axz7g3NeHMcmUD zHUt1Do&q>-qs(eMv@w7ps`@_mF#($t%SD+MuaMa2uZxG1sD|UbRRlFUCvE%xJfj~4K_OO}o$p7Kd*U0B&OV9p(V=$@gHhZ^b=(WgR zfA+&T83*PSeXgE4jUu?LsS1Lew&he8U_N%6f_$-U`qxENzTBZd`s$ut!v zK-K4UIZp=4ts5n?4#5lnj32i}qQdvh`<3nMk=8d4yVpBLH%IuDVg9r-zWaR35_?tV(&8P0p|dhC7AFaCSi9S*LKXU(nGc`{lDT=ia+zOURe2)|Qg z9boIMW#eI1(-%j>;_8jb=fCAHa@-j=LSzAnKdOF>zjE@R@tTjFhV}{|6Jrk(czF z`X$v^xrqdD#*L2anWJp~wJp5tJ&^2sJnoLUm9+*yp}9a&w%6cD;}V!Z^&A|VZCFS1 z72?)t!Fj=LLfh$qkELC&y`j=(EEAXT`ju`p;Qq526ZUh294k1J~>ZS z)jTA6#&fC^79P+>?S&^2<#)1%SkgIkhk*VWa+Zv4H2 zGObFH-TW_qa=J1FajemYLqx^(ms(v6A8eEkYuu+wri;pJ)XoF0-9OunC zbgGRV9%g1{khNpkF>g=D#2=fTwn{}GPAb4v*EX34Q(k@;V(&QpoA;T9%jcy*du=TX z`SAVyk-f-u;?otcv*qO3*$;MhBgLhQ6)j(_*Xz2GEkdv)O?&%qWeLIm)>>!_d^sC$ zF|#dR+8~hf^8s8Fj-%=-K{c}_!lE^Phu*h&6%}qErSJ36Akd{7Z*FsEhmg?Z0E{`! z5cYOEG03QRd>`hL%E3d&c5nHs8@1n51)?E2q5O4hi5qIK*?NUjBG{==R%BZreU)#o1nxH5+R& z0)dDtF$o3GUR|(yV{9JCiN_P=4@-Y*zasR5y`t@M(0No~YUgal zb&f{#+SHJh)GYdG1OtLU6N?;`?$-$4>YL-v=H`W>oChsM#V4L5? zfuD1{_R{XJR>Ry4AA{Wye_4T}Rnp7_158M_?{{PncD!@f7&XjY) z%p&yO-5sks%z697BgBL$zA(8@W{Ip|ySGvL2#a2B4bD8b&$lWCh0=ln1(CSVjltL? z!n7I+0IoI*zIhjWrHU06m4<_}GBXD6jZpMXC`A)wFa=(>|Jl;i#?>fzBq-ber z$(P{1mQt0I1*Zdn$iK$G&;M<{AmF}MJe$Yw|3LP6bf?n)UQJOEFD|m_1zmpM#{I7H z*l&}VpA^&=yB{zIM;Quma@`WQ;w_(jIp;IxiV5FJ7pekKfbD)?Qw>hp-3JHrQ zBz|z+ceg+b6QBR~fE&~>>qZVGp8n(q6G>ax+y9bL{J3%{Gu_HC$F*Hldhv9Uhc7j; zu3F`*KCAUR=Nht5QdoM*lbbzltFzlr)^K)!EZioTUNzTnyR#e0{F~b}`5XZ_P1C%I)mREfTZ|@Mq zH~U)o1SCx}tr-gnEhXY3HR4pcLU%iZ9{6wmGWhs)obC}pf+B5Iyr~qfUgnEyEdOU8 zxa89^4>FoAD>`p1Yx0b>8ct27EYmEfK7YNpXO%=8s{W9G)_H#SF~|4v8C?HYG<)3S z07q;)IxUK&9>ZRo484jQmf2u)!w(7CUnDKaIlBuRZBLj0Y* zpP%6iE&|%;aJ6cl-Ts-c?QDbSHSe7j`q1$OavR);^NovOaX8tEY;vlDY1~Dkr~RU# zBDF-&Pvf|>$q_m@Z(S3sfJA7(mkavLm`>UDSoLn?!uO! z#}NcCVbII7)SU-ai~D@U_VmGkm~0~ag+QNzK4*Zw%yfB_ODqk2EThVzsDef1(@vGu z6i}xQ8m3^o7s4b8rNz0%kcePamBB$dYisMMsHn*drDuEcKddzFgRzB*#X8$P8F2%# zNMGOCjSI!b^L!yJaS5sqg2JX;@L-76ha!5j)J&=3uCVgq1XilbCEk-;^i@<|j&L)+?dtBnCBCBtRzzbZZTy zo{vFc5>}!qgUjUWk>Rlqv+oy*!eZsfw9Je@-TiY!CD3njngIhCou=Q1u{g#Ko&f-g zYg5PlVf*Ui!}jn$0?|7U_?oHBFw3kO7|quL89=2$}6v{6P%<_|gkGSu_sPn8kFgAxj8(VPZM zu_c|G>d!D`ne%hSx9q&0DDS#J(K_oaPyqO$4k(R05+CUTBfq_UeI;KBi&;iWpuA@# z?YLUGy*(uWz^e3JqLKnrI+xjG?v>+86Pe%f&oEdK{*KQ`Ixje!_v7NmX=MpDgoS4~ z`A{wmHrc^A^Cu8sSMRI4l2!Eg$$9Pg6{%o zXbwFj9Wak$u&`DFOWh}AuY>AFVmRA92I4^_-EAlZM#L*9c=Zh({pjzC7Nsl^-z)JQ zm(u|Yea{}=qmIY5*0Jr!L;I`4njz{j(T?oYuEaw@6PJATX67Iq8|rJ;taF0-{$J<; z2g$su^&vG~V(aOlp=)nkQilE}j76zVsyyK|i9SSn`W8U!A1ov6mY^EZ5b7(=4D99j6bHA7(bs^ z34hM1h-#`Tr;%b?&W6F6n!@*+vm`Bix)2Kc7FqK#+Hn-Nv1eeFuBoMUc0Tp^OxP<{ zaqY$toVJg*1OBCy#3-sWu;5rRmw^H@N^pd_-;Sw##F3lOhDZ?4!kUF?jl zIsR2Ev&X7B{zeh_FyEgmsXjr_&)Zkc`$baFj4BQ~U$F+^vV=o?`zVMwVc(^x$em0si84Y{lA5PhckfxXZ46RE^+I zpRHSbPlk*PeZZNF@3@~k*{FJA{Aw=@09*_(>GEx399Z6>kPAK>hB0KQa{R#n`Vj`P ziN_oI{AqIJSp5)w0NDW!?~;S<>WMO_BLB8dB!*DfSOR(>u#M=VDBJ={#$8^0n02e+ zQ3zRf$ES%PaEmwuiD7apE*To%Utkq=cF1?8EU+bIz?TnYl$Bf0Ec=8-0>ngLqqc){ zcnh>T{a6bv!%TASxS4qlhK6n@X&E9fyf7bgSw|t()Au zt1cKi$WuJLUBdtF8yJ~?@M8lXyktPh#6*prP;1(~WjRk;LsHPeTIGveVcyWBzaq&5 z8bB8kt4dNsJ>OSek<AKChHJE-1XzoGyQ;8Q;vHukM|o21l(j?#04;hU?5YNWhwn?F=!{ zd33?C+)sWYgVRs$-h|un6L-1VR*uY6M^EKRq8_f=sA7FrX_t6s03icfMRZtj?yDLu zG!%Hj*YhV*a3&P)t$~s7m&W;YM`~D4Q%{uwN?V>M!c^r3JW6@q-@GLv)L4e`DopW+ zK~c!f?!rF%*mKA6rX9L2NDRD8`cY1tOz5+;+t@gxp$3qc+3A@IF|6mRopw_VPraoU zdPawt`Vk5#?P^Br=O~D+QlKSOZH|##NEs%UpK>gVv?cmLC4^6}l4CXU%7CrB@Fl(g z5Y3GbM9UV~VbLW;$f&)Sc;$~!Cl5X?mj%?HuT3J%h*9`0Byi9+qX`2OB`jM|X@bBU zZ02iz4%+a1g^-)*gVHG`4k3zuyDCz-q`fyN7HW6 zk*>>Hu}AxS*#ObN4@+Gu5=|CuY5jtswL%n102+YGA6W(4Gd0z)T2t&Z=PoR`XNlD7 za#Z!I{N`GUpf!^?1R*1N7^dpC=>}UfdASfzh6GjLDRuXF@$3cNqNNZ1sx|It*%tv! zqU$>cQD9?3E@=G0M0tztjZ1DjBDcsqiq~M0*_6#9qW2rv zntZ9C5cZ0mJNzQ$%{C-mVTK|J;mVMx-{k56XussEB$3}d^P?l7$(7QIe9o+tpnfeC zODew`tuVk_XelQwvGv+qU%(9paD*bD%snU?KTwA}{HEko{Y_aXQFCJU2sloqcU>S7 zOV1=VA<{Na`|&=9Z!f#Htgj9m^bTh>+-yUQnHLl!%vikcCTCLf(FlwJa;6ijb z4y;`t@Sux%)Y0sFY5lY=6aRYsJVuj z=u1Z7A=aAH?oUep$vTl+tKBm8j(>YY>+3HT!qFoFANpLuWQ}55}z#8Mk0Zq@+?t4g1Ay_yb~jYkrp+{BQzQdE;txk6g?AOLPY^GbNNeDC01RD1DkNR!;)VfvAgNX za%~^_6OR?FiS{rHzbC;aEs__&x7dx%l$q=a@>tHZp!Zy8T8&<;f`V4vWeLlR5(_?a zyM)6G2Q{SmXQ*@NQWz4&4RN?1!>W^@vf2$MI&b8)qvs zC@VGYB3!;1G?^k`F&GqG67CaRhrlhV@+-)5fZJ|phx+)EjDg3D`#`)sXrh92?wrj` zM?bqEEYW-l08aP5ly8s&lIC=WA2g^jeqyQL!$H&v0Q+l~r?XI%FH-zzEHwr+&OC!V}XG zhVm*7yfeEKryb`nFm8mXQ#KmMaNgELZNG)5&$P4=9 z$nIjv*LL4Ru zo$urAgVC}c@jFK%4^Ml4S(pdiYfu>Ai#HQH$^7Ve!TuPN2_x;v%{mJ7u~lxkS$ z9WU>p5-g&^@*r3emHf)ovKw_bqfqJNHhxRUExMt_vdLm$hM*#UH$Zp7%ym}_sz(6@vZBrloq*FZ* zx+u=9(XAxQNG=!>LzPF;wG0Mz)-qWv!?uDXEqxzk3w;E?#&A9U7r{M< z+fG)?xn~dJrqm5Lirm~Zl2Pxt?N)0)PS#qr?c?=Dd}@@`0?S8@ZA08HI&gnFd#z(!;qc_S8=MTkn@XtNxF}CsT+og!W?X)KQ>sb zg3~}|G}bc!;u!I|mq&5Rwa^lX9ndh(WbwzF1Hbh1N)4aKZ-pYapKX++1eM|y7Zyxb z@b9FD>TAKY!~)WL`FEP}qZwl9%qH9xEnhb#!0%!)+eu3U(+|a@XU0%!+M)8~fi}-r zp0Q=~nBB$o`8=*Zwm$BDYs&F$yoIzFzn!;FWsu5eavB}DrSk8$^K9~44Q_54Yrop% zD|2Q-`fU0V5Xu&)m0b)fIjNi)MVO$!ZNfY!$`mYe9Q)OEzYC()WT@^z~P0@T==l z8!gUdoHR}*Qu)i^=^fGYuNLMTzr(|CG__}6){{|25T$hS7zhYVy}rMQ=a!ttEF!eINx_)ff!I&hiO_*+Yh>FKDq;q9%vzxQrf4x8|8 zio2-O&by}Po^hku^_~s@K+71|hZoMba{jl>nH00{kOO&~M?WI0$#W`M#3}JFP(Gdh zk=6ocIYmVuOWeh!rlDW^nKk8yLRyJ*mm;fF4HM6GW)3Kbu>g+$*3!`U&5le|tcWS< zC{>IdPP!VLdFl)Fcs>k%YY0uw$#x#%_4muZcE-ORS_LIR(_S3<^+l7#Ck3 zo_cV2vPnkiFei-c>Z+vi{dn-HwbA zDSccoyCL8%7_Tb@&w-29GZ`A6A13LbWRkqOYxS`TLMyjOXg+!TOm}uL%c2%SP77d7 zJ^PqJ^bLm3-TiXaS9)t$xX8^tjA9S;a?7M====Fk1V^9?DU7$ZYikwNvG(ViJ`hk4 zNPrQu3k{Z$^F!Q2yY@B%)3+)DwdLY>7jegj=8fcZcx%emlp;$2I@ov=1^rA3r|Xc!Z3H(e*@C T_)CEo90RiNl%y&oJ_P*_08_v< literal 0 HcmV?d00001 From 7d648752f44b729e186c09fd6e27dee2ee7c720e Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 7 Jul 2022 18:07:41 +0800 Subject: [PATCH 538/843] update README: add font-weight for slogan --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b879acf5..106a8429 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@

-

Simple Go HTTP client with Black Magic

+

Simple Go HTTP client with Black Magic

Build Status Code Coverage From 5de2513ac1c2668568ea1b94cb504c486b0561ca Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 7 Jul 2022 18:08:31 +0800 Subject: [PATCH 539/843] update README: bold for slogan --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 106a8429..7ce4e363 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@

-

Simple Go HTTP client with Black Magic

+

Simple Go HTTP client with Black Magic

Build Status Code Coverage From 42bb2b8d329719067c016d0a568f4dd1fb14a94a Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 7 Jul 2022 18:09:20 +0800 Subject: [PATCH 540/843] update README: strong tag for slogan --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7ce4e363..f6079d87 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@

-

Simple Go HTTP client with Black Magic

+

Simple Go HTTP client with Black Magic

Build Status Code Coverage From 6ab6393b61529f6e2b422af3b1a03ee3c81d03e2 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 7 Jul 2022 18:12:08 +0800 Subject: [PATCH 541/843] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f6079d87..c6845129 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Req +# req

@@ -16,7 +16,7 @@ ## Documentation -Full documentation is available on the [Req Official Website](https://req.cool/). +Full documentation is available on the official website: https://req.cool. ## Features From 8a78cfcae346ee909525a15675e19aa47c9240c8 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 7 Jul 2022 18:17:20 +0800 Subject: [PATCH 542/843] update README: update logo --- README.md | 2 +- req.png | Bin 30990 -> 34587 bytes 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c6845129..0121ca11 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # req

-

+

Simple Go HTTP client with Black Magic

Build Status diff --git a/req.png b/req.png index 28ab2c64773bfdf5defcd9cd8563dd9db8db14a6..62a770a361f0413186553f43f04d5606150c2d32 100644 GIT binary patch literal 34587 zcmXV1WmH>Dx5kRQ7m9lc?(Xhd++B;iYoNG8ai@3#6bn+^-61XR?i9FrzjfEj$@!Bt znKQFzKeG1$s;kOjppu}%z`$TA$V+QN?@Rw($cWJIx}88)=ndtoyuJqv3@gol7p!l& zm^Tbe^rV8cgtqVZ(|n(7GM(2Uv4e?;v-JRfI?#+Eocx7ON_`P!Eb%yv&%8innY=1;WNZpW{Cfp!fV6eC_c2fw+Yn{9ZUr zXpteI2wdvWa16q3#A7G7>rbvim8w;opxi{K#GaEx1}4K%0kGky5Ht-a%bq4VUqacr z;buA>_m>PlfR_8bDe3}2OLlvu>8|%0YyWzdgi2z4!jMA6;2K=Um?0Zf(Nch5^@{LJ z=l79^t{e7LWEd*vGO|8z>J?Rdd%4C_MqTd2GwCJ(DOH37TBs}%ggaEz+BEc!iCy`a z!BV*gviS)1zXpy6HA>2r;Zw&cdt|zlIY~WUNmOxanZ}Tn9J3H-xE<{iJJk6@c z5V#rGzv+_p6y^A959;9f3j#)~`CTq!{N|uvt0tgp*mDFh48X52$-2~1Z|EPP1%VsG z*FbK46_s)P@4j@Z+|kxT$m0?QgmcX}{iukp?^h1s7+eX<$WX~roV9OO35YCl;Zi4G zNhv)k*(WHYFMpy|87!&)lChtU)Lc1_cvAV6-*`` zo&LIHZ3!Z*uL~iBMg6M><6d!wWKjz$c_O0%0u}Zvg94QVGJ=j9L|l=v=N*srTrd3N z233rbHtU>J`Wn7LX*_7Q{3au1%xQ^I{HRf=2!(@K z0;kK{Ea{Ay_vc&O$-#NQ0ekqp{TRN86rY`_GW(Px^&6}@400a+b+63Qj^^6n>R?D) zFrCWdE9u!W&5jXCGS{`_aH3F&JJ!u0Q-yQr^Crsx^{77$C_V&uWGdXfeF7*pgeD47 zrloif1}VK~*to);$fR)DG!o@}wt8$4%P9w?WF}Y(t)viV;_*1PhJwHcsV67&o2!b` z$Ah(v&`af3iu1^}EGlf?hHD;#mt0^Mc8`_B2%O%grI!1jNNj*C1 zWwvtr8Dte8G;-IPc4p7)I8LoGFWFp^;_C(fgQyErRwx@njCBwZ%QC>3ja4L)RHLB< ztZ?}HGidi&aP+QqXOGMG3H|7PE>GV8QC zkqDU`Y*0h!myxuKlj7tHn!@v;$-9)3Gupy)gOg|!dmNQTT%{#!6`dXF?IMCGPrmjo z%uwBLifBi7-)y*ErF! zqn(|!8;LbC@vt|wUP^W&4a%uJg8gcAQKs>yQ|j#f)fC3LIFiB08F{LNXW?{V$=ftd zIIcaw6q=My`^Db?R&W;01;N)F{eOXkka48tjG@Vgzk5G@qfRN?qn6rTTnK7igOOa9?JV{(29gyyGf{7Ya z`n7skhrS;<_@bl>ZtdYRBR+5Z4-uztlvyL@QRNa15D=GZCv^L;%4VtsEaw%XO z<^+Nwasncg+l>s>-SbNzJ%t_g+(-z`ln1s|D1XFJP7_ua4Z9XJ!pe#2#>1#ED`(1C z>frCqxFAxf48+3Cdi zgSl)1Y=Z`KE#&et>|#!-;tbb{a)TX0ck{Fc=j-^jZ%t8pi21z^2q*(J)y{l%LOS^|1LJJ(;cNCjSKMb1j_l-Wjfh}u^pfEv zgB%EvdlJ=@Prry^L&5TO)Do`3LT$w!7l*AI^Jvp7i2r6FZ*DJy%NHT2t&SF+LJ-WF z8|P^^#uT=WrWtvo*>8&9F-V57I{mnyeh&|6 z3zT$C$B1G$U!p)`kJpe`$!SK%$cdYUfeEM57Dc|GNEQ)*0rm3Vaw7QR{p{` zZ_GP>3#qk7r$rCs2Sd0lY8W2Sp*&DysXhjmhu0_7CU%Lr6#tbvZ zUvlh@RCM=y=A82tWi{#hD{t-2r;~qOiNAR{(>XC;qntF{uC6N_&mzKD!V=BO;z5LP zk1oZB-!8*0zkZTHrqUQ2_PBxLHUU@Fhj<}$w&yi059$s zWjcZelzt+2VZ|8ozB?YdciM&SJ8a|u3j$~ftpCjtaf;9c1Dz_8gNhN$T;Dh9lL-~Z zC2ym9WGCPDiexQ}!3mvs-nB^SS@UUd!H(e!1;=byb$YA}OV zd+GkMD{xJE<~exO1oXkNnAn6G!9*0#B{-P&1=J3m+_PoV(ooOT<4)iEuO;8FcIsnX z9Gr@OGrQ#8Dn(c&ov3s}07#c2Y{((Yh69U~>?!YwgqkFU(lkxF=1Y3-H_+L<>A}hF zxuiCXuxvXpA`}pT4Uws1{Ln^fcdFh|{x<-ky2+Gpa*Q8MYmr!_ICInmVPrDyG>hQY z6v?QIv+DoGK>QQHY7h#l?R|1Zd#3uDjNj^}%SyAffpfZO6U8!coTLi_nf#z&$(wc? zmPhuDex{WS0}jG6#5F}RYV!YMz=D_{Ec=s&mGKYnl+x@3x2S%_L6Bcz9|5GtzuxKo zn+pj3N4jfskK#JY9MD-p3|T{2w=_dC=WrKphNGEmJ+N`GGsxMqit_o3L73w_vX9+O zIHRmop131{4fqW8tu>Sd*OAh*Bq*(?-%I0-aMDZ+`EsM{awGie9ZORUUgRH^t+fH3 zVQdR;2P_`Wj|gsSif{Oz1C5Hq9UBe>Sx(s%|qxhCSYMS*QlW70V zOcO>@&#o?|byU66npZCz2C(kis>GzCUe)euIN%@3GGIW;gZwe_fEgfEqJomgGJ#s$ zBCE&{OPMJFq;TM{wrYd^h6LHGzq!eY6k~|E!#aTnhTFWmfl^K$zQW{&urW%{u>{vH zS%Mr1ma&0y!$|hAK*OKFqs0}r2&cc_+rlhe)Kd1BR31hv_7Ab#FvRh=?C*XCQ~)H< z5VKX(flbs!z@de~{<}svN|+`p#6tyR+m8C5byCXd(tach;`SRcydIQi%#)TkeXk-^ z#fMCyL70n#)#UAq0fytv9lE{mi5&Vw<6E#K<8J+?6|TMd)Z(8;uBkv_&d?xGMhs8o zB?Oy70kXfjWG!x>1w;d#XZl4kbSpQzs^u*NBzSN~-=kBdBG6Hp7*$pWY_0#~2(5%^ zg(ktqz~gZY9K+kdiy@J3m2HyC{ufuuu#MDBMMb2@)5AA&k}Eg>p&Dm~`j%4VBQl5# zv#A{)Lwd5jmZ)4FrvSy9eF+)RVb8B0NDl?XmR$CcdIp#Sm>7&pL%jVcZ`ww-%U$}T zQ}&(TinRz|S-&FNY^D#TXZYdxVrpSnQbJnmV7|+55JzQ5u4!PCK*mdhz96(zg3xOy zEEECJ#dzho6%^i)bb)}nd_7J6D^%2+^Q4kHUo(xyWz634Ub%5Bf;oB}J`!_%C>zU= zqclZ1s{b!y+Hep6KI0m6{*1b#CiEy5>$MLUj~IRapp2>OSzJ@ z^rU8WFhja~pGOqiHyr)7#LU*}NWFfN!EMnw>M6_oLYs+ID7_DhCQgK9|2AJjVuq&7qrlPOT6U*Hd;W9P!^ulgAkGpkFHTFvV=F#q)g$8z0mKB$X(A5{pAvNPUm_QaWA4Md^PcUd+$}}RR*HbZc{F=ar-{o+~EhM&Uw>~r}>p* zpB0-i^7z;1zkYkd${Xm&4%PPNR^ZmBD%F7ir9CM3YvNzOC| zES>Go^*7V3HIqdq%TtLP6wTV?lOsOYY4|mjrJLgYzRLVqmGHow8;LxLTESSrbaGrG ztOg~>t6H7$fz^WD^$wL|XP z60P(U4ao#+4~^er3Fk25lA`Z?qukDzPUO{R-Cj{ZO?1z_uDit~A0J_)7;#9(gO($A zUkYOI5@t7>zHbqDtOoxu_Vt(=-*{{*Y+6@+`MPe5D`X-8(2(AsAxlt2yx zIoAbI*b9^-pD()Pf9u_T3Qn%;i8W zq8t9Kk#@ED{eHtpzG}O@{47+rOz2B`+w*Ba_wFPH1-(TAmKao3c1iL!G6RcJy9Sy+ z1vbFcS0)gu#l}5zAluI`w?_JM4*q2*82X_mGE}>4TKdC_17oKGM@5?ictCKuA>98E zX)TwWX}R=iO}-JA_r^Wh5w}5}Tf+#;5#L2r?jW_TT8%cF-77@5t-CV+8Vo!SDSjIK=#F_2= zh%N6lH|8S+eOATxwu0Y>KnM?JNm!pIW$aL@q#_zi_~>^|fvaYbACqq2lA6{6DK+8z zQq@-k6&0P@f_KhwND#sFF`ef)8=;r(|xk?AmzXb7FkBp$I)FiCr$?7z36@g|nhnnfjn?LSGgGwgBY z%!I|=Vfjpv$$F~FAeP{mf~608V%-yj@K-j#M$OKdv0=8|*a0m@K8nZICrX6Ov-DGX zs1(R@O7e45-SQ5A5!7wr^q$lGMj{3jXz1y8oE90Z@nhRABdKyGPr_3n(e(e&6-fBE zUC*1AxTskVTUgIAp|_40l^R=5f}({N6=jbk2G#!W3%Rq>eyWRU?2dH%eaE?(w{@VmJnJCH^8JrC!y#DCFmYH^;AC?||=Che^ zd@m5XTo{mh^6pO8?nF^2aB}l0V0T#H?NICR)a9UdntYhKL2TNk^IXQ(mg4NJLeYU% zL!$kVP|{4uZ5I6K+@U)`r@<8$Xo^H^v56qug2U?jvk&@=Fd;OQy7g`Fo1d(#)D`W* zbc11qNaXYg*KR;d26~B7T_uUi%cz37br#zTbzv@%>5q!DtFx-KxaJ2rFC1RaQxhRC z82jKSROT$O6cHu$qV{4Nc0*NP^I52sEOm5~qNIGI6*~`x)T&ob3ddfa$^jSgmU&we*%z(fK@s0qVZs&1v>)62^Wk?WK>z1DnNv<4Kxz&rv z{qCY|AwUtL%q!)kgR$hmXqZZbSxIEapdD z0sv$F!bN7A*%l+364+7AMcKt9{b8m(i5}2LIFLc!r{ZFL$DKc)6~`jo2PLaMap0qy zkcWZLG*gN*c%x5?Lc`K=BR|M*1V%f#G6a#;6>-XsTipKQP+qeU8d52c)o~ViGRi!k z!o;n{N86vG#`$5@6O>fZ)`m3RjNQM#b`m^Q(@OA-OeEo5G*GOk=PG94^GUURP8=EV z@5=`&7E*@PxgSJe+2!l@g@nRHs1tJH46pTi#ooK=D`j@miv+o$<~Ga31q)+Lz*K!~ zA!xCz&=|R;c2X+Exr!qxb51*V;dA{88BQD+y#>t+J_zK>ChcFQiApq=q!vtJrY||( znJOv(2#0~POuNmsCQaXEJRs&*UG1K=E@<9yI|Ix>dM&7^>;I#$^Ja+@mfk@@<*Dus zGTkqQ4Z@NU7Yr9Sq4^s(mix!2E?pmq@Y?F7W3-}k_Q^%uQi%6XMde#w7lCD^lZ>$z zTBdc{WSQH|)1_o}ZS8g@X?^jM%O9jH6E(nJoFs-{|KuQlMi0n~sC9lXX|?X;>$9cc zVB-0#Kp1_xgj?|>)--h4d=t(IZmJ4@;h9h=EjO>ixWaIu{5()Z8Kw0lxB5wHZib#{ z{_!S{qiLJC)U}|v2I0JgNRLqX4u8X&)W~CU*!Q$)9Tp%ocIf}tcl1&iGtbO~IKIo50fv;2=K11q$0(Dv6iYRR zY}KB8>#0)mrc*CY5HslTGiH6i*Q^;F=jGbjk*n++Hfb*}S_USbK#X94cn!F@i581D~vY@AV(IiHaMiE&*t1U9y(3?GpB&w`HST zgHW1Kc=n+!62$?B4_kx26=DcoTEaDzNSI=u?pNSNpEVlKgrMtJ8I}Xis`Pm3WA|0=9I(k;x5%$2VV7cY-g!A!AEI_oiX$+0jwe=B zf$-S!W5IC17lZ6a#&yC0T({$)rIec5lKQ5cv0Q_iW0I<>8hYG>;`F4v>9k3!89|Z5 zk1b31sD9pSmh`SYx5ASNr9e5y(E~2&RFne$`@&u8Zwj5~SdyinB*ae)60MyH1V0O9 zkcnH)z-2HenJfr4tTqDx<3p`yy$Z`CpQt_<*rPsudJ4qkE38bOPB}ZF z+gy$}WW-O49P4e*5@)IB8T61J?`u}hq-DURB2mLuqX!P*(WVmcCq`(>hW#Z(^(}}u znEnJLNuXn?aXhgz-5g_CHkq`BJi_k`tL#1J1y7W!y`|Y%bCp(}*=*0v$d5}X#+(ye z%(XtMY-eCVgCgUe&J{|>vMchv%`1OSEhZUXi=>18{ zlYz(3bT_POHSsOT>u#56z|jF)@=pMHd%dbZA&bQxGaK*&)@#h(G(OMIB)s^b7}NJ> z;5NL#31IZB?LE4n;3q8JI2G;Pj!#juIXXCxZ=L}m#W&6lpddOQVK@SZTNemj&t~dd zuZEJl&1LbuY<_1)@0Sr+?nDwL)-zG|K6E`#>;Q}#lb>^Co7>TkoG4l>E48>B zaL1-gLC`;BZ0m*c$K-Q8)wArnYKNS${|@q*OS$>W(f7-wkmK3;p{CtC8R5&pX<$^+ zzdTAekLso!y@PCB#2)TOhM7)JhjEuR!ANfid zZ}0vRp=JZbaX-4n(x$Gs#Vgz9(y56|F^4?X>0MrID;*67zP~`M{z}G>y7;1Jb6owg zv|DqND|mg{8$@Kzqt_7fXO6^Xgvb=TedkvNgk9`k!OeZBHkmZL66w%oEpr@nf@y^~ zV^v(V^Lb%lSj~TQUsd)^d2LWVN!T&!1Wz9XwOF5fa1%0+BthfZa`wBA@4+IGev=i` z2!8@zTX|(C1{DHEY2~)`4l#u|&{__P|0P~qN{kJ?&@4}tT4iyRm$b9*|FL^nS=4`g zQ_#dv>0x_NJCQ6Roi{1?WPSVTM*~c6LDTfn78TMKprL)AcyZQ^Y(`HY7_Kw$oVtSS zox@nfO?VB7FTZd604{++e!M@CAy#+b>rLDf)?18R&x^|J)eX!cRGeEwW%Kb2S`+vi;1{@OUJ|x581C?z_8S?GFI{OK97vvX zKe&=<_FTY~)4Ef?z{^`4BqN$N$byep;Lu$t(L*(D3%`yd`A3?1U>j=8g4$SZ-)i}C z9`;E+eY|L>brzBeBx_oM7OZGYu+TzYCP5$D?vasbOo0v>w0J_?Jt^uN1&B8dK)C!E za{2eudz)bkp4wceQ`0DWZD7fwj&|$9Wvx8mLkJ?B~XhR>UT#rauV=&b0uY|id;Bkl?IpD;UTGKTk)9{&jw zp*aqpYK^RhgWefMa~bLdH3zEAlbetj3d&5eR|FWL?`{R6|M+@Jla+Q3t_NxyA0G~; z_9uI`^m5x3@^Ukow0YE|^)gX9m*Kl$BxAH8a;QKM&97DNA6XOB%iVDHeUa7k#6o_| zB!)xk`i3!h9843Nfw0Rg2&h0<*D-(N6WKI91JW$&@gbn>S85!iJqTdD-&A9n7hBy7&w%!IT?en^Q zrD&WdNgJ9;7&5k7#7u{8#5y%^0VV^gh2T({hE9XZb^no8AIeZYXO^a8Me?Vi<$;@? z=Z`RRz~YThv)Vk1lUnji+@xw~Y_2Z;ex_sRN%ujKphiGUU`5tc2HL160I?UyudVxK zeugvdd0%6p?ORE6N1`;Rbwz>1B5t{7a#h6!Gc|hZr`|n#sX{fganxnkR!d7EOt`R% zfZMzKnRb62GhDr~V*+NTT{N>V>YKs^V;wf9y;UyZ*h-Dxi8U4H@|mH*mnGR6Bkjed zlIA^>NT;2K*{pf3@Io%yZa&$t2O(4QL&0e(ZmW5=6gMzc0cu6 zK2G!-pDk!JFPn>#=aeaFN({;OsrS7Z5`$}ewLdttbbW?D(lOo@?2N6u0=;M206l5n#3_Wd|6bR zo2pQiWYU4gUO=9qld-_f&F;T#wS7gK=?tZO)U%8!RZaQ=Q0}-xjA2pqOWmM2P)ZQ; za1}J2+#oDWH)FRl*YEzR%$bx^pAdYZHDWz_(0uFdHMa`Rom#8|wbd8>_>pI1iZ~N1 zPe9=fGS zT>tHimVt;9DP}x}hUaC2Oz19^*wsNR&hzaS6V9_c%OIi#Ow`q=xEL+F0duziLq$k)xs~8rEUngpa)IRLg1t8FpQoO!q3%NA z(T#uA>u&pLYz`m8zy!=lAwq~+U1Lp_OtBek1@oU4r2JZ+@)>`#A`vv|nR zWKP!RH`m}V7k7)U=ED}Z0ORJq`2djY3gKOv+zrIlXlh>UsU9ZVbD`Y))R zuGMFYe!i7Nh>&KLGp>eX2bzTc>|I;ww%r>tJJwZy5{WLK7M*-2Yi{V}CqBJ1r&$u| z<)G*}fEBq4jX2bA1^L3KU!*xW3{Bej?ma_(UWUEL-uu|5IXkly(9CZ1KfYpF^!LRa z4YNIE%Ci~9k*Rzg+F?72&Tp%@+58|gyj)zAL2aO;(p(+)0ZVvnCo9+x5OV$zrbfsy1llDszSh* z0;tgmiHaqVPpzUd-*cF=dGp7t0glq5<|=pC*t47bI>*Tmf}0UfRwT3AEei9(i#KsK z(OsI`x>IV%e0QH?$+hiHS*mto&B2A-Cx(5ouUIev!8e_P1{zdbZ!f0y0XuiwZ@Qcc zlVwbx(^e$tCNPQEH00c}8rxJf_pLNj_ATkPFT!znKqum!4-ZNQ)tHKq23 z*NkeBs)k&mI1BK)46THKmY_0V^zQ!a!$*dA0=0$^>IJEW4%Lgd@}CKCRuN0~BusP! zmP6`?XL~KAZzE_uZ#jpC=GSc6JgOxnqX^voPuxLd<3}fYVR>h;RKNjd^x4lR-?-5v zHIU~^nR?zAq#wR0*WMYym4FFrnne)m<;P*A;|qZ_wPOfR?hsia_b4u5wQngh5W}` z3yonfndo3NA=DMUK$@_JK#{l$u*l5~1_#U_- zC-7wIb68Kutnt(zts(?xS01aao$OFuZaGD5^0ebRcoNc)XcRNrYpw5ABd0>Izy1wR z4`bUiTofY;fMvbDX@J zMg5sLFqv<`V@Nw=+?%WV#1-B4<`)0BoKp$+&%p6DDq(3=g55~d(u@t%X2&t(H=C#5%=qT60~pS}(z-C^1mOl|~R zSWFqeV>6oshC$Ws9C%m7Nk~FwoY5bn=mbeK^GzR!Pg#uR3Y; zKhNxKkDb*76Nb(h(j3+1Ifmt<6BIt*7y$$iSS&lWl5Cd5ROf%5cln7V=;iXcbficH z`xCTJC6+AQyZe2XOPcR@i}cHWXC zHtij6%faK-_^o#-j0yu2*7)S`eX>@IIZRVqNPJIyR`(-48hKZtOp@@(HBlj6{`fb< z3E-FHs#*q(O=+)j{No|#n0pcTpnR)<^%1B_^I630_%Z=+jsxu?`p;q{j^-n`=py$TIEV?MCmHnO}AaLQtzTxypso zL_@BtIZR&ogE-Z?oZ7l(bNAVhUAq-7ZaHNd>(k?8x2o$`wM! zPFj%)SVAmoaMDI2s{f@84@JFRu2pDSj`fA==Ddf9Z@$wcrWJghI^y+UZ@>On-}IYl znNC+n22>JfBxiW`aErm=MPtrnG&WKT$FdGE8yKRpib=C!+jXh-_90$&SO z?ss7c$ahw$nwJCb0fA zX4;eD(1-{4PbyJiXGV|#aT=cedAUz-2o{V&FM6oVZ)BQh{!&jy*@U!l4DdOOPOT^ zip<(BF{{NM>_18H`=42oy58#APZMFCAOYl-&DR@K+SOG2v!83c8`Knlb3BX${!wBw zkFCA@kblO2pkovR(}1XZpDmkE-zqA(s0*7>PwP?79NN!wiMR^d*gzdzC7rY?my%{T zI7w6*mFpafWA27zzGBKTh6xPQEvG64d`a=5?w|Sc>OJEUYboN-t9qGQ%9e!tg4*Et zFC%~V=dA*C)6UK*mzSZVD^ut672!WkVd@G$pfkfI98og11L<%<%8-=(d>`vm@XHYg z*F(8Ozh3j1@or})-v~?fKg`>@bL#_-sl|q^OGBvP)jefdQjM=m8(zqa#!7&D_Jf^f zp%K5Gl-0ruX1AkL9mFvZpNq|Map^XSsL^Yn__4l$1dJ_ZH`ZJI*KtR42Iqu|2&`uG?DvtE}k z$IZdC?ft2~rDG zy3?w$!{W*vrgZieO}G{ggbYsGM)2rl2AUH2CwSsEt&xlf_~x_#d2DX&T61|5Q1Iy4 zZ6=6G@<@P#+k=<+lEva+^`gqkybNaD=D;U+LKdSDAy;r%-e}{$c3Zu-4X&=hGoBKg zHC|=0SM=`R45BiNWmmTV{X`Sf2sdP^5if7ewnhF{p*{51xr<=1s!fG~ES`c+vT7C_ zKg&8jWIkbi9SJ_mb?p!uNi6V(h8+u5ovF>l*_jt-#n#K4v5SjohL*_>51cV)YC`cf zKn&3IrubdQwE}MoW#sPYc*p>k*i-Q(oY#BRz73|pDf|6~Bb)cS*OvPNw$uB6V<`4bA2x@7PDi@>LkW}FdR)O`{%x&^o_G2HkhaNE0WhqLFw5pE1PHmehjCi}n{Ao6-}oETa_HMfuqy|H<_Cx()@5RX` zySx3!-R|oB?!52QLk|hNsky+=`3;WqYBRAeoEencvYJBG0geiNC+{O11U@=m?*c~T z2PC+Zb9oU+mhg8^S-v$-L5t6tTIjo1eeheQ7)dINRw6$tUuDXi!7xG{LLGj_4t=so zd-#i32H}!OpYSMz+HWP2C*AvxFmSi!E#}7pC=C%>7)RfjF z;2))(jZ?%;m@LaRl;>$Xw9>*`OskmTx8}Yf56VR9@lUY}I{E+u{)~)yL1-`MlB_r?cnkd-wX%93Nkd>$UN*vR5i*es|s@Oq;vr#{VLH(um zXWG88xA}q#m)`w9yomEJ9KqD<1-_^`t#2QctfXf+d@^CO1J`n}x?_|!%+Md$2eY{-NT2$0uppH;tdCE36J*o~ zvt?$B>uA4rB1>h~Yul>pYj@)c|3%)2zO1Tvw!B^)|7f^g^_bp!A~GqmuQcX=+YdQ( zTURdoIZV@ClQB`F%X#MP+y3269?@!`FHZU!pNw!Ko*J?c;j$_xbpKFN}!i!0yE`{AZx~$IXjiIcw{^Mx*x+L?$Dy_3enEL25$zAA7nr z#ubsoy5F8mK}!p|dcq#R>ox-JnJhxR*k_MPRDbP=iadpO+XXp_ax(I>0^+qj&_cwn z#hZ4o)WO zE#zSt7f7>q=|DV&=J_e1n|(vq+2`YNqqZAfivtf)Ia(c9Ni%^&!d7G677U{lN!J7m zJD}uC!AwEYoB?rCd^04Sb#2tsZ2l{*TL%%F#XVUYEV~0VO>Ox*zIoqF0t`FwArz8f zSE5wqW$o=bO_sx)!T)>-Squj7_}&v_ZiwRcb6);f&%u*|k;3w0ckXRKptF{KeVEiq zmWj>>Rn9tQ;x2 zfA2h;K__5Y?D`i2NwGHuYH7GVTr*Gvly*Ds$)4}_5^qK38;D^(%P6}isVol3QZ2=N z{9W8gB9J}Nmti^Fe7e`>$2!CPL-H?LkP;!LA)9Q;<;Hi6HACcNcx*a75LPTKqH6Xp z_usi&o~dTBY3Rb6v;H$*lOMJh*|Eo#2lm#t;;l%IP0+7@M>|s>7%-(Wgi@ zUI{0$1^z|rdfLxB*__45_!@E6Lv=8#2(7IAUb(9E@T&h2&-43W9H*|P7)Ccy#H1*%G#y(cC1iP6Y^qbf%WGs7m&+1tc z*>d62uEbdnB>1!q671;Q^SBPF2uj$gxcEgAaMrSH@RCTv8MC~mII3I;# zFL+6yzWaVn0?3n1yKYD9WFWk9o;m38{O96s$Ch|BQ|@9QEAlprH4-6gKGGPo zYio+$q0d^Io125vdsc`@!{~*dYmro275+m#S5FUhjgBD|Zfe?(TLJYV2?s-762o zc0OG9NptMyMU4J>L}DkRWo&HRCAQa-McO$2Vi?TES7k_gV&e4J9T2qH>EsU$BE+~+ z>})Wja%^4n8x|Fbq0r^el9KTKOu|U)0X&qFSw_iNiQl5fCnrsScxh@0>Xf3I8VB|1 z`v+=*7)_-Hom^XE$C!jhCwo^{Q%6Hu+~ST7|GKlk9gZsKbeo8p(=jWI$M<&uh>qSx z7RE^ENU|26eX;(S6K0U$85_Z$##Gzt1SVucnhM`fZG|6u;Rn4UE?yQEmadocT&T&D zh$tGZ^}DXem_G&G&c15zD>6YLZrXiO=L-R3DkCGK>)-WwAT+U4%nBuAan zvc4IM*$JQTSaox8n{%`9B|zk^y?uPa?7cf@q}cfBudVvQYTH-*N1nMI9UZ~*n_2xhk)Vh9Aik^>v=UU(@93C5h(`Yg%!7<${hdWTwyqgk?4jLCMP<@>q>L!jSJ zl#aJ>{23p0VQG=DEvvM(KGTj{k(Q(nHE!LSYZVq_Ztqp^W~b2(c(VeOBVP3a@+BJKB`op z#g%cVPgAG|KRbXQx-T#uzmAXMmwZ>3BGXEQYD@DxJhjEA^LlE9Zsq7qk$Y54S}Yde z^`=OfY?3)dDYz=AEp(5clC~+=VXNzC+F~$%x^9sH@*uGDjmyr?jxueusAcF&+;1$Q zZ_rfT$v+c?Ew6MvZu*uA)f<2KRj>9uuSXfUY|lR-?B{A#<4=pe0OZKL7K(z5e^S7w z+DtkdXo=W!LQT`{gPU|A+hyh4aL*$54d1l`Fi-9<1qBpDTzM@VJ*YC#HV6YnsdU#z zM2X%*EB5)B3jzu<*uahVo^p!9(|aw;CWa?g@kDGDO-*ifXRbYWyska3-aXFanJlF2 zTC*!nmSKI6Tt^pEt#FLn_1xivRa>#w6t)62wF>ThJWb8RBl`7Kp!NEQJU%l?ar!Uq zzk4@~)4!p;EweKt1*WeyJ-jtqxC{jgPY{RLH5yVhB#wyusY zZ8TY)_SyO3`Yp>%zzKL5BY@bky>_0`pO4wuojWLz!9 z3DjpyJ9QIS`#ecfm|6$2X|8GF>b80>$X0-BE;!TNgJy}bNMLnCnKq}F4L_eVL)<*;Y%} zx*Ka{@#Au6RXv1-fuiA^N$rB#If{&k_>?%XtOY>PV;mBMdrxRXZd-ozE$KobFGm>>@&L_x6NO?S-DQhr4|eR?V^r|(`$X=vBzd65PsKq# z+)5G0NqJ#G9?%UWW*;SLffFmN^jqHaJm?aT6im?^B^fWV9kp;LJ43gU?>?%?^Szyw zi_h&>Xa2?Ab!DyBY=zQ!0TxP;M_Beska88Cy51X)Kl#>&IFXB}%~+i=KD)4?Y%7QK z6j&htS!Qy<9x7?7*8NGloDzMF$<8h&w(Gr8#<;Z|#+zaK;)n`%3e*Z{|I>%g!+4KT z!m6AKH>{n1_nJZ7uY@orjltrWXj&{89~$l2llRd1+$pDslF z{=D3?zn|wypEpSzNl6%AH&0snU48E?dRY@%s0ThHrqvS67!+TBfMVLEJA5AgKglrZ zATr>Z|J;u+-z*A){){rT$=AZNBNkaq5O;;sFE5TJ+UAenR~A-FjnV?*t%BgOg0DQb z@5CUz;*O041>{@)Cv|rYGdIFpWCI^;7^lzg6z~XB*HTjhMU%*qJm$Kwjys%NY1U^G z_oPR>Kw2C&+|wprHc|2#Ve%Ed`oA*B6j|;2NilJnUH9p1bOTM7)mub?F zb_wCz@@cJ+gB%VI*R{2^=S>PK1Sk_Qllec6&M~->u8YDGb7I@(#MZ>N?TKyM_QaYP z6Wf{Cwr%_C_xo|HQkAOY_PKq|-Fxk4tyT!~&|O-cn-h`OWIA~*SodwD0Ce-E`^nr+ z|IK1){~qG$r-UZi5|`~{)6qL%@`A)M@N0WJ2-SNRHDkKiQkXINB4`~t`gS5lgPtgS zX@0joWOtQROH)+CTUG-jArg9ag^BS`JlPq~cbJ&8AVVUTH^H~JE`6Cz$9Wfylbg3I z(W!a_ph0zi!2@QcX_K?Bvr%VW)K-s0&t(sdX%Uo@6!;Ol?GWegco$yh4l^%pK8rjZ zw{KPhDeCXxCP*apW`1E(#7ya~$cs~N1szoZkIhER1(HJ5bl#%S0cm49P%IIRlP8`7 z<);$tjnnLJF=DH0S%c(h zwREsI1KfNF?3B>NT-T-D^K03Tp!j#Gn2oMCp<*a0>`}QsDVqmZ0>#gfX?jS~`dNim z6tY5MWGf37CaY?Nunb{3E^MunuA8k7UlNFQzUye2oeu%IW6J;hdHkP-3_hbrpX&CI zl?uhdetPUi9LDv%hY2|cTx!5pq=~F$=6my4pKsR0Ohjm)Sje!mq106Lk0I*8NkJ?A zj;di~pLcK;fm#CkV4qoP&mnb&A}1kQUt;^CsiT^(-gw+Rj&z{Vj2|OAbA@n=Y_XDHQHz+v(9+WBBK6H62`~t9geT zstOr{3>VbgU$r8W5?pzLD%=c0Gmbb)-Im6hL9XRnqo7D-2_#R-&$pl>?jxYOKEY2Eps_&}f6B>-(`bO;kv86I+U6`E!+V?~#H z_YJ^&L=WKN3PX#<;IQRcf{n9y$@8Kk)xhQ0kpWa zG+F>#xLJk#t%1kESDUH(#@`z2>dO13!*^3lT4!t}m&cZ%0?~>-)^c@s6fR65@*`uv zx%&;@>%5C6DvjyiySnH#^p)f()FSjc-=t+!D`=|9Dm=DZ@l$^J>maFA)jxj zdIqWA{Fp~#FHl(OH4gXJJn1)k3@ej>S(%L$-A8wI2iEs^96m@M_J*qdv7@q$UsjoJ zn>O3=FvJonT%3(AuY`nrN_1Kc;a5IS?e736f4NE@?J60~Bb`MsDRq4;UyvLI~6V254yQu}%Nv+(<&Pkp`*7uXHL z#y)?o#=pO*@d=T12tq^SQeyGLTn-y*$ixs&tt zJV?2Je4J&nvEp|f?sj%mf@`e?6_&cy&2BkyrpCSX>~qsij0s3+*^cFYSx`U8V2|+w z^=(wRP^Dn?X<`12jvux;+$2!JceUEt{Q<`Bc;e=K6i9y6DO4pSRZbag*eil~cXc<3 z!WSibdKtyg7GC4`aTdJ^9dtC9rnd`}jMUBRj&R=+PJcaqeWH7oI!*SE9ZEJ=gR?H6 zQr4blm_MI^ut4+>G6-Zjb@}2nxFp%0i51ork?~imL17}aYk-R3n7ixrWYw)9b@$GZ14{0_#CxY+KyNwD?3ICHM>3V-)d{zG&V$%rMW{iRHQ6&JUA1uyzCXu50)4Sex&f zB|u}S-rGpeb8MHqJ5gmfElt@N1Es-^^HqBT6!wsqoCgS#}M zGnIoBs0t#qc?nwnk?~iBDKeC#rA=$z^W>j+Q01g|i62R@JQK52mL_5LV5QysJ@p9X zlMuU!C;t8^?>YQ;vdC%EXgX3u7Qb#47K8s;Kj&>F#b}50Fq6!FUXs#jnR&Q&n82&H zf4=PJl1S_H;=*7@tCPpgAm+;w-xFFfXV18uZ(3*94L*G^Q*1vjRW@W`lGN_QVav;O zT=?bS-R2>p-*T+eBnXl@=$pS@Qk!|_mD*CwM6kxUcSz2!=kATFQUb4gF-3tOm|+0? zF%&f(Z#{F>`|P(HNF06(lWmw^m}ULN3e4UBMI=$CIQ%CGXRHW+ZfV8eq=T{4I{QuD z)G9q)?&g9E_t5sZ@irPw-^WY z+cK@W!)zpt_A>v`=?>c2dHS~#xF@v2sn2FjCWecr^HxU`K5k`9<@hqmuE~63{TxO3 zeD*_hdh^TKyKqaJS;?lraKhp!&*|r{C(ZrZ&yQa7pt+n~*uR6IBMz6x?pJmb=;bf< zCPbFL&N$Ma1B&yxW=AgQ_*B`!KQn;20-zlB0X&8XgKxbX^6}VVFP+zcM0R3gUR9lx z9&XBIA9ronsYvZ^sr@K%? z=qjvj>tR$wboo({liv`o6J*-qLW-oX$4JKSmlfk1b=P>=BYV84W3|2KPtoq}#I4Q# zXOY!3!mrXh zK3-*!49xKBR<(4Qf5EHnQO4u2;ZPC~ZzWM|EohD-sUq-Th;So+Hf-q?m|xNnyv}iO z9v21u>u5)#Ca0B~Ep3_Rq!T(d)!o2AUdx>oWa( z53O$&Cpc=4Bq<5g+YYYc-ra&D8e1(OZ1RY;R7J>CnGIC$V?dDI9q+4j427p`vVi6V zy>Qq-lmwbuFmK8GQ!7X<4;ErwCR9h9Cqqc-M_9P(QtfvB+erpb6qswyKcV5vpk`-u zi*Q+-L21M;4Kj3pGn)mMR^ys8+V%QT(x4Zb&|&YrVGd>OlwB5m9tbu!n3HO~#JM4d z5ls2rWJbOItmdzv5pjaQYPElmKo44r)Efhl0Q2SGG4s4!Mx?WKppu!@24V6|@HHGn zhmLFK<;;-3+9)`0EDo6eZyq-t)$2J5OM5;*W3E89kc%M@X=uw7HtWadfPZ%WcIhp5q}}d$t{0W#AKXSh164^u8;asZA)(h)y)(>qhj?g4AQMYr2m^C zsu2uUY&lEEQeqJ|<&DLzbD^*M3rxIHH_?Lgzc+_1L5_=4x5s?zlSMV|KrQ{r?)Ek?!|$d z+!yMa^{s)H6I4S}s2G=&XxUXdz1nbVaj2Qh)6W6ehO)X-oOKx-#x!ja@@kIuJlhLP zlo(kG{LsZc&Rb~CR}0SP?Epk*_?;{iK?a8-BV{YmwtlQ`gXl}7xXKMQQPvP@fBJf%Ub|u%VQ0q!&!H(_EGDSciDOU#?MK9OFubIXSaeIsxOasHrtK|WCHiwn8sLL zmI*Dc5xv~@WFLeN*fgI5lOLL*zmBT@VDUk6Y=kAr06V=mq950Lar6}x@P z*Mx_i?%9zH*q{Xq3j@E-MxI>QB2s5!W5qWT;AggU!c`ID|5S;1kZ$jF2LMcDNmY`F z-yKvc?*|z$bk;Vy*^8^6TRA!f^eSRCWs&M7aY)2R%B8P%_j3iMa=g7vtCwC5xxHvq ze$W`U>CNw5hsjfT6V#=I<6uvG!^HyKDdu_?xl(pSHiqt(&6t(@18}|}-H^sQLNw>fjW7)izLGA=wyR;C0jf{k+YI6pElF z8kz2{trf|tp_XI-H}gIZr+flNIR5IgWcJxKdC*={d;o)PvlTcrxTbkp=7QucA522}Kh2j*4nmCuip&VY@Rjl{vK3Mz$ z^&__{1*08HP836Hz}crQEKVl8@yvvcjaZ`F8jJMz!r0Kk{5vyo3tZ2_Q6DXfZR}a6pXn-jgn{>V2O@2qKGb zDEetB>mv)X)wNydeUnux(gCTJcD)6qx2?{xIouhSd^@|u(Ug~5P*{Zq&uMD?U7+iE z4$rG+a#>pzlI*MFhN0^=sGG^W*U@@S_ytyV)AyNfv`I9P`myso?A*0ilzVMPw5!6Z z&^-r{vV17!KbV>9^X^4H$i^pk&8{2V%AhqrjB2yCu9i!p(yXTpiv^MrnRu@Cg_d!Y zxM+Y8?KrbL+LDkj6`A8s16rd6 zA;l*ZQPjU?{cw!%K!7EZ6=12Q{Pwl{5m@7}J>`4*xmOD$gW*%_sfQ&=%<`YvVDR1Q zr}93u9H&1Z*s@DE6=njm6+W%Q zZMjI+!2fpO2%@XMf&jJ;Kn9|Gxef4xwq&X6aph#Y;@EF}p;NAl`Mc}(3rgAua6-d( zwZXeZ@Hs3ZjfHTZpZv|mpA%B5jD5_+heD{ken$7I_bj!pX_C~ zE_1+$=uaf}-d!wkZ!Sb({N0Wy)wn*SK~Zitjw_5cso0aBe$cj)uC74kNE(Wz1h5;F zvhclV_Aq-03+Kq)Sq3D~1s3YPfh4SBNKC zT9K^&i)y*X89{c_CiNSnoHUTShj&pjvEKUd!1H7-l2GS{^s@8payuDFPl!UN5pk@t zusN`6-jsR>mc)A+-a4_yXgiu3W<(2LT1vA1Gdyt^QzTU!uBKkdo9ivlu@d{Z+i)Eu zN5iD)kR6Fux619vhaY1!f+C*PW<&B2CUAGwK$s*_MDmQC(jLg z!cUdKTx^pXNB{>x<5g$#ynYr>gA(*uiKlRx<32W9((})|AR#Ub#Zv^qO>6DK2By#(8!I z2SxG`vLs;VivyLIS>h<*4{}xL&l3PkB zMS^1J+mwdDA-?K#@}SFoYstUKi|$`h^F?Yf+Ek)M{maTKYW5W1IZS45(W6(~{4#;q z5f?S;mv|A{5N9b(hrzq4=4cZ?u|-fS(eF$&Gt&~4C9nO`bii8^0m#0Dbw7dBgp$S@ zI3DO>#G_EqKy&0yPx>JmE^jHVz7%cK746{C!M?d zbiS-E>P`_b|xBXmNyn)8eX2^4Ez0M`hjDAU-4B$n`Y|R0jr&S ztJ|VgIIY9$f%q+!cNd!9wGRdj7J{3HhwFV>WW{Sq_hqU3OZLNc?eBsEA?e2IV}NRs zo%Mhfm{F@Oi%1di);t!5k>nb|HKUdy7Yvb>s*RGUoxx5ZB2K+cMY`v+qV?N}(|uHV zTwJFsVVnz{u``{hGFb#EIjDx_KYPIiD%;;pgoSl2cdUu3 z4{U8j3tLWP^~2wTzEy!<{19Yk$uu1kuVm~OwrsDYu!pWj19AWN>-mUIL?=vaAt}8tRkFzcBCI_5c>&(%`5gen~pb8&v)?A!!T%N&S={Kvo z!6kb>j34}-v74Dz~ zU-anHVPj{1954%vBPk`GosubT^~@}76%lst8{Lc&`u1O`PaMWom^tMs_eW@t4c348*peu+Wa`IKak7@7$bK8Q*`dWIu5spP* zM6cKJ$j%ai4S{6!XNjs16?g=&1;%fB&nNPn_e<($_yR5nJ{4*>P>bRyr_Za5 z7(!hRfq2Q6c2NvE)r{}d=WYyvVP;bT%(Ca{%rM}EKjF5Mgg)>bIE&)Cw9@RtD=&Y2 z%SS*N>11Z+9mr%a{CdKV9o%fKEB^MArUJQN$KWTRb44JP{n1%lZY#;uPX84pVb~Si zopND5l!>h|TR)A}-MCA_%PYRJO*_Pe^x3R7L%pYo96N6>2nQmSvYHA?T(#S!Mcr@( zwRuba(Ahv2k8(lVa4Wq;6)h_ZP7N`QLua`PF99|rR@C^zi9T47bD3R9D*38V2XpK0 zzg$34o2c<}RXX=$kj;(@HH>vHuN%-|HSrbmL+4l?BMNbX7C7|6Q(5ukU4Lc7ojmNd zzr!$T*SU<|deNIsp|L6|0Fpa^n6EXml}FnQOZ4}2!o zTJv9Tu59iC0_Cton^+VH*c4`#F$bK$hmG%Kwu62PL_iALs4zW8fl=XUml!3M7doJ zCJuzk4O1-M*WY)HW}sD5QlNq0DA(HK5KISk97|=wBs{nKP4C;_j}_BbG8&w6Vw8X! z){_+c=PqHex<>+60atxFo?mXZpG$BZpJ})mHw$(O$7i}V%{8}tzkeo5h@!V%K9)CH zYxD>25HF@J&X1Z;0kEdG2e=ifQ3~5SEv&wK_|tE(eH>iCtoE* zoqT1eiIK{r2-#$_18${{EJqtM&}3smBoXzI5IAEzY+w~vdgYZqSg{t>ytDjK;eK4X%$Ra&6zU#zeqk*QAc~q> zqbW#`wAgf#()0$AxRyS`bcM4CU4m)M9c>vyNn<(oJBVg*{7(`Dxd9>YAU)a`G1B6M z1|X|x4I|YQHR5M{YR{98AR2jtah8+D&KSx4`?n&d0mv97itky;1@7|T3W^H@r;DOG z_B&}&McfB}I$N%l!&sbh@relnE?G|&bls(D*iPufkU$q!#`xQ1jVopr#x>jG52ggh z0PC5afDtv}ScH+d(KqTIL-2aG82CCF6LSdvBpC!Jywo1X{SgzV^QHwA+xNAt$nNyP zmKi}yGRIC-e2n8}Ghwj&pUM^&K?q4l!D|o{TUnaz$*D?eD;_^j+?^dpl$-8Bz-fq* zh_KS7g4B9nf{E#icb$Es4A2jG;@gE?y^O@Gu=U4QoEUTFLOeJ~zT9 zAOb-c%Ms~Kv`U~7NKxdqm6tnevA_jI+w@9QYT#9ctt{pV4ee18WsmK0BCN<5h|jH# zI|M`E{C$Kyy~K$i4|3`3b8~9Z0~to4*vg%KmVyWxM(V96ld=TDt*Jwd*CeDf?=WCU z7O9L3ngRsRKCRk~Fhew=Obm^t;)#TzMc{Siyllv{yY~!;b}=4sKboi;*p-N zNJ=zvStYo{jt8}4>*a%mYoU=zxam7+1<1LD4y_+27#^E2P9jmy9h$IvO9-Api5_H* zrnHb{!qrQ6$9kBBol;{$!73&zRXvSGRBF zJglWQ=MA`EbY>=a)AuvxnDullY$I%koaDYCIr5m$U!24yOzz~u;L7-xS)v_}b%6al z61VW4Y1M+TK@lVPmsJ&8j#}?G4CwqIr=X*k68ho3Sy+8b6{3Yc;|Uphx<=zqH{BXe zU0O*FMG8qNm&pD`1<6^<_g$?aBrPdpUi4Wi)&>B0Lhkd%tpAz5>x|MFde6!ThH6k& zSBZ6ZNdFCu#Vtb$5Q~#378Lpe-7t)J*c(-;-@jjf%xR&@58kY{dF-z=Pl&1hAgsJ| z7nFco!wSWxkO{vDBJ5$Jq{BxxifHp>RnbgzAObnn<2(0x9$qQ_@OO$ZUfK z=VpSiuz@C>c^(y^McjHyNF7K-+2YXRqApgHwu(8z7PeqT2^7-Q+u(l}uxO3`Cxu<6 z_^KmEhwshK2wCW=9w0l4M;cX{2%WDWWvsbMg*+oh12Y|7P_CVdK;z4htA;_D#QpnH zYdphPs!>f{UC4&~j=(Df*VZS!qPjiK7;!U1`}3Nu-bWKQ#6|uo1WBa?g{dULim^SQ zd5dkt;M|?pUGJ<;P6>y&7@91V2qz{knDp`*=`7P~RI4F^PA`!14{B-%x0(+DyJ(a3WI5(! zRcIJ_GXWqx$&t_^ywoUx{<7{_G0TS6H1eU%-)n7yv77TeD<2;?Oq4t|cOSubkV1hM}B zx`5_QtyjE=Z6f-j?CI2yhI7LuSOB5xiaYWY!)H**TgR`-RB{jSAki3d6201Xog%H` zD9{u#SmFTH{4)=FM(m`pF(EUR*%1z{?}v@$7LuM-K#jNe)Ns9b!3=a4Vue~&K~+_i z1P(Fynl;OJYLo7VVx&OT3bntUF4}eCYc}T94(A%mqBkRHhqwFEVxHN)>zm#u=Y`8I zaW=4pTl|wWoc3nW6fMADoxyn z?vA?IXQo65r8S~F*+_&&MkyUJAfiR}%$MS*%mViYSeJA^0&LqAsPcRm6=dAgEt3=* z!_LBMd>&}m?LWSMAauVR3OHR0LD@fTJUJeu2$+7phgh-Q3^MBda~LKVbe(G!p=IGD zO^uoD!oCkctEJ0&d_=fxzw)*1{*cS@Jk1bl@0{Xi0cI;;TsH%8v(wWWj-sR*$tvfd0j81|bQ%LehJo>M*A?J-tIe6~dzRz8|H^VvGMM?ZK zcDZrB_ZpwKX}HVgxfv2;!Y}+hGPwQ5&3a-Fu1R@+({#KzZZyy;0EwGuX>lS-zO(nP zX8how8H35|(g=f2OVUd|hs{)Rp#?@p;Ev9f>k70AxrAsFa{m>{kWGB*4GFcPqT*gt zkoe52TK(UDMRk+ztu`)(1!*=NTyy+z6Fm}2aZHyyY@sklj~}6@1-%l8_YqH8;z$wK z)A|(R>L{b12V3|1!9r*-6d{sK>H6;KL_JxCb!vs*2V8aA>C)N49c8X}=zlM1#t-5& z-A8=>hcbDTn4ut-v;-rRze6l@3V{vtpi{K=Zx?RBQv2FuU%c;Ps)MQDoyhwoN;7;Q zY`q`4Ev+6+@?7l@ezj!F`>FGu0O=ujPu4=fl1L+!%^K*}NGu?UzH#apyJSS6oS_l1YF-2^8O}wcy-=N|#8*Vq^mA z*GLt0b@#0WiCtWO9nl;&B>(p4=-&+u1+br5TWSY91qzamztRw;^$1hv9_9GB_9&0% zcKc6xAHF~Qx@?%6+~089Dkvx*vDl)L|Ihm*N$F#xP*%sF<2Z7_d)?#ra=_>}8Vs{H zqh+9(Kg2K%|0;%_J7M~Kn$|s;?Y7Y_WV7zRC(HY|j0z@Y;!m059yuq+lR~KLfPFQd zInn*)BVg|R()@n@^i+FApgos1d9!Fx#0Y(X(dOI=SuiOZSD__fTDOsIWVSG zVD^NM2o)>Z0}AS5yO;X!Q&Q-wn^MEk5>gdA#Hq>@F*a}o%-@E8nE5MyEg^C9BQZnBC6{ zJ)bO`6tbEBKHsj!zH_{r{`&y7U5~uqD_k`IYz7DFhZ++0bUuA1;o9N7;$5NV19Vwc z_US22*VhN3!)O`hZO_M->ELLJin=PZ;k$;f_lRkO{8kWYO{dn;D_Y2a%hB@Aj_7H=>?^|z23GJtapz@M;hXp=~-L^jD zeeVl%;c0dNOt{lARvA{%@CJnB~i%kr)t|8?GC_*jyNnELe&8$ zZOiAYfWt|#0N(BDn9p~SBq4EeN?~V1%|v(JfU8y$Cy#R+D-AyD$i&*xVe{q<;LnFv z*#MSza+Kn!GPLwlQmq#~0mPu4Kz1B>Z+&dHECX(h+hP{T4HsGEN$xXvhK^(Jy~cdv z3>n9N#8BYfk%ErTqKu)KdhVA00(^lbHdt?Q2!K;IKwHEqHUxftzlrPK`RxR`_!F21 zb-r&yZ538GO%`zrUhhVUh>w20E(w_T1;eDr)6L~v(eB=W=YHN}^FD9)i%3<^Ucfxa zXWIX&N*;3HH73N1P60 zuMdvDIY+r~I)#QTRS{H@`rm=BV=Em^?p9tA(-r36HA*oWK0cw6es5tcjZf_0_5{aH z3}51dtPbb&hP|NlckLzP~#(V?Gx>hVo97%U!v2H zeM8`W76i6A=dp+Lwy-IGMS%%IJAE;kJ{913l4V~A%^kyjn=R2Zwvrk#MuEwtRwl~Y zXt1dc091wMu7$5RmZ_fa9M7Xb6eVsboM9I%HGNeIN|Oz){I#GGRWT@=iv=PM)-a(T z{D?|9Suto}=%r?Ul8q$FDKVS3)8Y7zmHvmT0OEs!Pr;ZT|5#cL-aQ^@q)-eT1|8ll zzf2}8b3A~A{}uZ&cRTyg|M});qF@B0c-dr%mtv}Elz02ge1H);LOFk5`lmi_nD6w| zARfmi3zig>-H2gU1|o$3IPHe^nH%VMm%jH+sye)4qVzt6rIv`zmA*=++UG%T9-&!Dvk9as&x814?-l#B`GN)w!sdzC9i`aEXz0_UFx-|BbLxG_Gw??MoU)d-SU zUQK2B!YFTDBxynG*JfG-*hU$YHOaS983~~^4$@ zlc|a*IMqPkYIwfvp z^1MIxg?m)0RI-ex!H2|+^qp&!4XCJobx@Qggj7NZv1S(`G3+=P8$ptbU^~R=E7%3g z=e@2c>y$>}B`dyIK0FO`8YM)rnx`6!;XoLRs=^}`p>qmKQKNx&!7{5dh%c6~lx0ur zum9QoPeVbJ^j3jtP%7h_J93rH-<2y*9mGDnTlKq00Ac*u^r=```%RE%|zwK}P8 zbaWCD9pmGI8Un0HXll;0N){;Hn)&&8`*jC#^fd-V3h>B`Bq@_r_pUhT&Id5luq!VK z%aj)iglJ5C?-0PB#hng|8yE~5rLb%IH-i|RA?FBc^UjP$cNRAENJWCUj-IrxK>(gz z+OQW4L@(XF&?yUbw)=qb_H{6dGoYX3VtA-|D75a zT5JD!dV) zPjLJBek7)3)XFWP&-?rIkp%Qd@WPQO6C2%%%OW2-t3cRQiTx^uR3=leX}X^-i_NN9 zP>$8-JH)!rlK|IcOCikGX+D$~QBS1L>x5nP!Q0lCV)q`Q*Evj%&tsh+wc%_nzT-b3 z-k1G3u%@SZ+~=*MFwd5B=L=-+#DxS?^NV`IwNYJ+nW~U&}*)4fCCFT9C3v1(AAN(H=l~ z;6SEn(P$8+mc$i2W1^5))V753Zu@_vvszwuvm8z&Ao$U<2 z>n?jN)S3A<$k{rzp|;56X~PYOXz6r6%m@Hu2a)4CfS=?)lR@A+H%J^oW9+;~Nz-wB ziC5PPgkDr7AN=`;SY?z(ulwq){|%OrPzfL3eo*wf4&>CpPBF5{UajuuT7Qg zl{2EL)7^4LG9l6E(gqZC#sJ>xUt`GXH`h4VY--~Gu{r!vYVV>pe z;v0{{EI0u6&P689-0{}hfM|EVs+-X%kOyTDx14&}{n z?nwBBp`5OKY+MshUv5OE=XO! zgY`?+f=uxbG5rA{tCiYa`yMBdaPF!bpS1}*uR~0jgz;#FMzuzx>_wNiLsXf3E;_(O9t7>-Zy;8i2$xjC{ns58=Gf?@4!DWVj`xi@>&_d>`+jMv%7LoA!~8L z)II1-RZIuBrEEH+u9&(ZZa*CWw5rjtqU#++}1sIG7RgZZT^ zp#`f9Yc(FUSU^}x30;!^Wkh+!vJ~F~unWK3A$dRmeRCA&<8tA$ZTm&&eFv4ir0mM- z1u7UE4Xm@R@4W%Db`3p~bT z*bO72#qsNBLK3+@nQJzFyJ$Z6?KEk2GHYt`H={sWnGVDF{n??nv1aW^AtmHr@*`Ze zKGcU|SHvvBDIwxg5<~UVA{{Q53_u>d(3@Ve2~RI3lz8o&o2rrVlivpsxE?Fjfl$-O zvkRw{cW1314)GZ87D*=0pOydN=7`vsV1uEb382^5g(4=!iP^GaO7U_I?(*s7P|gzS zU#=^jtsim{DQqCow~rb?D@IKZs8Ut9q-34e5-mgjb=dbSdZ>E}f%jh8K_lO-jrk!Y zg<{E+3~}KC4UBAzPh*kaGyM;`4A+A66lTi-1{PcS{aV+X?d@S$h(?8;$FIBzKDl}t zEXw}BJ!rt_w(zGCvI$urrb-hczSei_z7U0D9J0u{$O7dmqJLpjLcfYI3^H(f^Y2No zN1=I9VD5gw@9W^cqDxUn@p~Q23QNb(SLQhyGiiiSZxlBEU~6YBM){MUo1;o&q`=+H zWtV=&$2-(ygESE%59WZSFeHZ(qN)ckSKV1rAB`x2tyVYymHhBQ^Adf zFa$$p*7dB2QsY#{AkfDEZiddttclsWU+CXcQdn8pb~zGVQaR_{$()bH~6melr#-_p5~DqA zv?=9S_nRu}>NRh~;IPQT!~3!f5`jRwoC<{54H5*ck#eAmC2Hh2cFO)@jaE^qy)DMf za0#RK<}f}Q;kCab9}s7&^;{&$@o`~b?D}MII%F(tW)$JS+I``l`6J?D8mZIO11Oq1 z`cIn^-ItwN`no=Bn7Z}>8B*vvk-s=VEzzsX`p&4kibNF6P?my6qwJGIVBy!Am>R2{ z9pW*;9v2n-Fwu#Ha1S2qr?+M**PIn37|x@hjlN!`>lCh&lVqhv1|m??c|B0>4kz@_ zJ}+C3F3;RSair6)k$ExiuPfF@j>H8Icz3xNSw-a8H+lLTGb;4FAOs{9pHEN`iTJBC z&I?lJ7u4*>nZ2BVmq;(z)G#8edMY_`lOoo=^Mm7Bcg`WpR$DBijyv&OsofOF+l74T z8q9NERs9JeIdc(uUSL@sq3dYND^-n!`fM9nXJ2D1#h-U931aLP(2&uHL}6kJjb!h; zXjJ?7G`*Ap_kNAPiGlel!<-GnDX3lEk#2xey*+>`f92)}gik2rIZuO}c_ZK$j;6jE zhELmb3Lh(Il0>rm8xH*oJ=`nwk&o&L{;Md@K=#bvDJ_ekM{TfHW}L?%-TOd=reo6N zpDcz-0V+}M3Icvh)pdiov6G65zH-2_ua3XM?BqA#Bmxv-(i zC%Y=OZHq^Be!U%<#o3duqa_#Mq|$VRRnyh#dMBQH#0rr6Ub3nA;p7rz1xU%qGwD#< zZ^hQt#KL7Mi9MNX4*HOf1YaO_k}e z|Ex%x%0wyU|4}STN0@A)mOC4#o0566roC>WQfp-NP!%*-Mkj}XN)IkVT0ZNkyVxCX zR1q!f#`NU^X$UPJl>%}Qe;PPbQS-gh(+%Q$iAVfIYcj>H%47JlK_$^(1PzCc__U&3 zS1W1ZwwiM@{JEsj`{g*-!h2RAghg~^xytR?qvdANf)@CbPRdIg(e2|x$oDymN3(Vw zt;zKc`#ms2OHEa!T3tZOrD+r@hwwdR0wU<6u?G5sR%Y4&tH91$y9?;@=0XjcF8uw4 zd-o-u__}{G*?9pg(0{D)kd0$CZd73UsuFk_A9w3vRUxuxjY{RdR|`q^4&jj&#K!K zPSynE&E+|&XVG{=j#{_TYmuOi`lx++@Xt14MB#WfhS5i4I$3f-5_BR zSM@w0b;ZS_=wr(o_dY1(a|ykel~dD^9;5(YVlO%fTWoA8^pr+FV5LR$w(h7}7<_w0 z!GQFj5?!}pNwqv)x4gA3WJBw4ZV3ufk`ZEWadvz$wGD(JW+&8+!3HI&8bKh?ftXZL zg?I`qu6t;S-w(9zj|exP{cVo@>-K@dCx{Cb=|9foh&d0Oelj|)R z?ByBROG!C}q;-QVr}M)w$HV7;0WAd5`yo}+cj1_8uhUwqqP9^=GpNmAlEGk>!F0rE z>7*3FJ`rRjYOdV#N`qRMK`kt+SXe=pE3}IKj6jg@UI-Zw%qoVvgh^Yp(nYdJ$1(Zp zwMjvy{m`Zq*ubo{he4?DD4IgFE>2Iwn5%GV8z!(RTRaxq7DlI!mrj%G>MZ`gNuf2` zNJec=kM6z=RgJZDuj&j2 zL2y8z-BUoswiz?J0V?9q*2(3+_&Wf*pLMn#KXYd?ZJE+MK2#LwFa_8E?+IPnFC>nL zLHXjpI5>-LL`7{#(d)eYd|LbFF6y-xsUnbm8Vwm;EIzT;>>xZgRG<_aDeJ>P(7O1m z2=^~uN_fr^(Tb*sO^(nT#msfZ&O3d0!$t;yA(rR*t2;pg0OoF;YDa+EMD zD5bR2q9RMFsD%q7szu9wBny(FMVq1)VFfNk3pYX8Iw)5{3nN5Z6h#^Y5}H4uHR_B% z&# z-##A0bS5DpQG3o{N{}=?R)q}zAfAWv zhNh#-3VgcX35AdnoE*MiixLc)-LGDVmbYD$p1#Gc>N6ZI z*@-tJ4V(_b!5KQ=f2FnK3yn{j`7th&vTZx%bq!{A$suM&N0!**Hp_{q$rCvnj)unl zfq2X#p2CuWNcA3z1E&yz@Z<#Ub?Z4^--z4mrR(`qy1IurTm6X5>4_8<=8~44LT~>s zIy(EKV~eY>V?uk?MwZ7(id%GC;^qGOS#YIyR$T@b5y;WDtt>qPHvP{V5 zgMa8dWfNzFz)&@GISim6Nh`Q1Lhhnw0Q21jdp=5ND!QyJ_6#gV2!S#+g_M{`_JM=s zl$NnC5MXL>fbrfQg2O|&l9TZk>|%X>0nW8aV#Ys)p(qP?LN?1AAz^6he8h6$aI9cq z-M+N&Jf9bMv)6}jB&bDJmYCW`hN?2_AB(n3&Dc>uX3_3wjRswjQA44~d1`Kq&HsTU zMTDiRvH<#UI^!}A5z}IsHa4q-c-rkYHY*leST;5m@EboCL7SPDtj_=d002ovPDHLk FV1i4yLpJ~b literal 30990 zcmWh!1ymGm6dg)hNu^^!Lb_8xK)M78X&mJB!+J07k996IZKb{6gGUZzP z*!biJ*sR3KV?Oph)^2q-ontd42x9gZ8&EG6dNHCG=WpMZb>)Siv(t2`*pN-Pco$nkiKGa&|6J7MtP z1learM1o=oGC4>2)nRTlQ`m7E3r&Zj8lc3;!43Nis}IAL{2nt}WbgI$$vj((>EZ_u zz1xhrp>QXPi0XZ3DE#;e`*TlWVd2Qg$jQc_W{&8)`3tTqTSbCqu0t$19KITKg}}mp zJ~KZLZEW=N@wvLZjObf)73B8BSj>>b5T%5E9;^-?h{+Ukm}}X|lkhv;9CvqE>}`_ksi^6^T?sfQ442*0 zx1@p*sEY_9`UO8ba37u8wnlppQhp}D2edj&*L<(dy|L(5Qu^I zNyrcd0HE=|{3Gg4NQurHffIvfRgD{xS)qMPuC=GnUfjx!PJ~9R%&x4&t}Lp(&|6n2 zKNfRk5LK2NnLsv3<1uFTaGL(($B*&zZ;DiK9ISY;?TiWy#Aml<*D*85i>EsfW|;X zj9z_zIwfPzK+HJYn67*;hV@I(ZTgp95tdr+^C-5L`)pp0x1evnu?)&4csly>LKy%+ zsaVYikn|?9@NZ`I2y)X%Oe2T2w3CZwGBMMLB&6qG=n>mb3?4KfKR2ZW$+5DB!88Si2dv}yVAF8$-ZVJmHceiNK z_vo$;k1Q%LKMm{larLgAX5YfnJRTYZ|oWPGe05 zO)}j@>s}3oL(65p?{T>aaa^QLOI%AkN>le=xH8ko)A*!Q@|A) zcjY?*B}M*9K;u^%ZW4S(t`EEQ8tDbmtS$5^XT@1%<65wUftOHs9dH!b%fKB}A>ijr zSM))8rEB^i1^d2JI?J+ z;kfcKK+8_U?zk%|yD(j+Br3b)&Cgxg41P(&fwr$g3&;$#B&8VJ>>sSgS3E z8`|QSm@$jiWPtaQU8&>i#qm(oSi{1?>M`dff%5-&LQE{ipmS;4sO_#k@A(2$BI2ne z|CLP}r(czwX4fDtH6_l@=s7No4;)8HU0x1?H6N{K^>}`Ecrs65Nj2|T-InspDu>ZM zBIcGNVcMZhTb)_?Q;3YEkus1WXg95-?hQWRl_)6XkG$+?T@XE+)E4&G9Q0TZ6kCeq>$CP6UcOA>pYkTRa`N*K_Q$*4`;3$3%(sCp3u32SpIT=C)E|-^QBDM5@7Ho9a<& znQK{viX--fl+xG$!?$`-&<4%)jz+wGX}Vph+K6y$$cW^3B}F^f9~mNI{# z`d*ZCMwYt;#f5lD%K;>XHDD2Nuv#{~z}e?d532=l2&P5NY)Prooec>t8yfsE8h^zI3XV@Y(I z<-!F8Sv|$s8T61^t}Ta!*8bk#GM3tSa250fFbiPr@GSY0Q$(mP`4;fU0N1J|)%l#x&AO`Lu+LCdjgP)F;%P|(oe zsXDp;{J zga0UlFDW!^fkh;Y8IuUD5IlfpzSLl68XX-)lM#K&uMi}s4kc$m-+98%LNm&@(cj-MaZuZrmN$cA4lnqo zl1J1OK16hrmZ#w@6BdIHC-HXg4Z}69@ln;Z#jHKls%>qh0*aEE*>!fQ;3P!KDueuq ze1jV3-QI8~Xz`99|AIH0!Mk^8KyMvxfo9xt6M8D)H#eF0WHb`q2Wfe*A=JN^D=Xsy zI1UG4m8o?rKP5lKYy-bQ4t8QTJ~9`k5tlLvQ7#kYQnNB z?;`yF@;7hZxVhDl$*S}S1{t0GX8a+@mZXh0${TZccgOX*G_j{i?hPncAu%r_6HBKz0vz z_oltdmd=F|wv6m-d?18DnM8RZf|(I)Y!#uxuB=is??5qtYO^78YTOUHu5E~fYCqoh z2S`0!(73BreRfn)E6fQ4o*&n_V%fJYAX((Er$0|T$dlPD7-|>>+q48VN$k>;U*usYbK%fE7O&UsY3<^zt z>BETLa7;zL)WYm+xba4o0-!DGp1p)0I>P-{-m?As_wQ9zPO6hOmZ*N_rFrV(>B&2g z=!QwuWcTu8Zykf~GYj}15Tm=g4?ZImH8u3yQsD_|WEmM5Nr;P!iV$ToM;T*-;4yGL zV;oi!Y5$6spu!YwQDOd)mUiIZqD{Wju2?KV6D`aW4E|2k``ctlRB4((1YJ~5TQtp+ zWHP0qfpU{4P%_q zy7jE>Ib+l6wlF7^HBmhhfMPzwns@WZ!C6|~#lJV*L7eHA1`WxJ60h`rxo01FHrB2&SUlzSawZLNdjmsB;L&ssQc_kq|2zV z7u6QKhNTh>QNS`v%B4U?{zrY`NACFWITKU)CjW~)NN~{Qq78lEzaEj*z@NuooSd9E z^eTN18|S4St~W9TY@9n=o>EJD;|*+_VAJ{jIfSk{J2_q41U5D{>Xd3AoS=yRS04yu zx>z+GJ;d2pNhKl4j$4K#&JLeCb3ibbubR>n0Fszf>8GHZuj$pA8y7BEVH`4p^L?P36-<$b5u8fygXPmay*sWP(<8Xul0rx z(ghsUTa7+>C@(3oPB-z+6n0wrKMbM}dT(!U6e%mW@p+Ikf^5PI#J?;hLlnnFW29n~ zY9+U6ERs=D*U0XD4u+ry6&Ok>TcFh*MI5iM#q~XvIG9%D3Hn>`B@c4D+sX90<>gK4 zhH8P>g576P^gLK{1;3l!HpC?)&~AlwC@}r}{ARx^ zaD8Yu`1bC-z&mka{_3`lUH(tnNz9B&qI)7JJPJ_SNY29W&F@D}S)_{-6Mkm=`1PvT z=_Td7BOFhJd&y2nUtCyTwiWRpD#rcZ8-S}ilUAIOjKZS?XX;|RiK3$$=3Co5?BnoCqNjiSm*+L}w4esH_e z2r}tL8jg6`7yY234AKw{H5-MLQb{rCKI7HZ*HK}lHCkx0ZYG_|tG=y4`?LJB`V@JV z2YKk6@g2Yj*ZoxZjyry}j--4aq%x>ysLw!j?KW`xsKlP3#`1VG8H^qxp`|76J^y~# z57=Gk6+W~aYx>xfsT3h`Sagcd1gQQ3Ih!Ty(^;@8ZgGjDYyP z@RQ1XZ)y3Ys(__8HT((bKe)*FBY@uLy>($5f4O~Jd_0P*7UMR&O`Fe074w_Hdz+b) z?cIW!-G`I-Wbd1U-mufaW0*kmUL>Jz2o;4{nvv1b;IInbJ6rZ<+o{61xTgRxK->am ziDcfOq1=cV{dWye663-f*4J#yUM{r^Y;A+)mhOUrYCi@9L)gWdYSq^kwDEMO=@QF* zCp)d)&%b?c&q(-oRm%8aQdJ%xid$^5k0RX#`kv{b%P@d!?tXorc#W(al}4_2IJiji z5XbAK%}2CAYXciR)z&h;2JSJd!9H1X&J<3xB9GeQ&pToXAVp-AH8uLu3sI5am%Ip9 za>@ z&zR0_(_)a*@N`ROQqtgqo>ew6vof;|q1|i?PDY8(43C;Z>kO++j@y8Bdq-?{{LGtl zJ62{|RUC=gkI^`2gz4#g8X3-G=0;Jj*nkqF#~Hp&uIrLT>0MC~GZeg)K0f+Fw5hx; zC}*F!_PoPNFOi9u!?>L>-Q8B{xIB(O*ZXjrN9j=_MkEGPd^|uceloRD$ZyCL|sw1v}id z!Z&sIIW1li#+xl!zj7o^bZD*jLgQyYkZ;BZ6pm||08c1XcQ%%TUKe_vUDKH-QX14B zh%ier6Ib!MQNmj@h|822+|K$G!vxOEBg&pP`;sYaA%mK41u?LusM3QXFkW9RpM8&)LI*gAFR`MrikG^T7!&N|Wwd`# zVd(MW0pyKdGTJJXN4R!bO&c?aeeaCs&3&Hd6Z!C_Kq=c!bGJs;gi^cKkzp)glS;=+ zRYxl6Jx>H<7DT4Rsc&FaE5hTOvn(EnJra|M!Cl@u^TgHWC#ACmfzPtle6GzAa=cxW z?qR2DSpM9`#>{uKV30oOdbnVv`e!g?Mt?O>*ee|&S758J&jC=ZbT3(xMjCtg2yyP& zeb^r75c!T^$bo4}2gSd>T(}li%crB3{Kzj|L|j7jZy$45fh7{kaY>#6h2KCbsfb zF{3(>S(72&&ae`SAE|^h{_j_xiPPwI>ihRJf$N%qpK@Hwj7TmTp?m|&&hba?ABN=> zRXy(OkE!)I_mBQ!0_JpxB6ni8E?OWiVvmg z6q&A_%xsJR@uB)J|KS5H2K9M7Rh1u#@fB18)!&L@Jh+JC7fZ@ZvXhgMGfY>T4sU~- zA1t8rC%)|>D}gr$elA#de(>1LfA{M=>t*vt4bPN!MN3Mk)jYf^Z(E^OZ!Hcj2UX4m zxiUrF)>-odPd9j4Yq8iel$eu-i$=GW%oBg_NZtNz?znmaRJXkH&WMUpwO|6P`P_Km zjaiPVg%Qh$%laUDX=8IMUNkQcNGZ~Ycn{{TPZ{vj?K_leDY;GYCPm2yc!2k_bPvbX z3*s(|?BLe?uT{7j#XXP&>0hY1t9P@G<_MW+ot4;E`K@EbY5b z_YMY9fCMn6({zSFER>bs*}Zjpt6Td`3r099IBr>Z94#YXGbxQcoYxCH8*ebVd_1%Y zCgK-6j^^n!6}{SsH;rLY+s4<6&&E^2EJ8&Ac$KQOYJ`O9fQG6D}7-QP=pydQXXVl7zx#bSOpEm&_qEq_~cS7mpE zsr}lJ|LmFd5;V%ss+nE)H+oVfnnSWW5Eer%b)3^ulQCuj9hC}URwfFm4j}EwcQp_y;{Cd<0FY`N4@R;V7c&ZbXHatC=mzr1pO20n;oo# z7Iiv%Fe}%KePMYS)d?cfC;bv_*oEOI@`m>kkj_@R&A-HC0y(@P@C#U^XtA~M>G)qXj3(|nsr zxwvXLB-Ay>fSG`|=sP;G*iMc^Rtcl3kiKFB^DuvMNoIpOj6XTbmic+?LRvBZCM7t8 zKzoF%Opan+%qGSS!D~!cGq~g7*KTSiGw;P-oHE_SIas&lPnv@pDL^;I^LNiSV6U^y zrTOmDyIC+jlGes{$>y(kACNux43Sr;Y15g{CPT2yamNGOYLFg3BzZ zZZB6POF9<=2tcJPu_CMOHcO!lQQUX4HFOW~RY&*}Kt@m0zBH#$#UdLs2_TGL-MY*? zsdah@P$>7b473m*I=L7<2n|^tULTemHM+R@>GIQQGOts@`dwM#ff&Al_bbdBRQiqK z{1`a5)J*~sirV6}L?{}HMy&mdntWC@@(yjV&S+1c>-zlrfNll;*;mqGUryrjnyd_c zQ`T_yC{d9jC$E`c>V4;lnq_XNqvQVcBuR(UTL}%xhpw8G_!{TCl85n8NZsxIyhrrt z!^M!?7kclqR(VCmL9gKJzSZMSY9Mfn+zXpt*nTo7^!wm-w5lMK8pyQ-Ezc~p{uw2& z>C%3j$Oitw@*mDSkw>GZ=fTg}eaD67mwh*4yKiD1+q_Q~ceY|*HEv9QRS362BLaaV zoPC3`z6{FM33*7}o_-3@?|8rij3GB$&=vm*u3swwM>nX>nxk`-A1+e%c=Y@(?xdk? z0_hh;kwR%Qtw&P`Q;Y04ClV#{*zOj_a2Tw$wbTk1I1u$(`h~b;V|%6YYkgRr#ylk0 zkrbVH?r~Z8C@gHq*g1Z5Lqq=ig=*I$q~|V8D7IU&zfmGAK!Uig>|^E8si-$eqT4w6 zv7L+2Nw{a_S6XYTmF0nO`o&?T{v(07puM-65QvMMoQCd8@DlRTXUoQC6&=v*yqRPF z`804VR%Al#Z0z##GUuI;#N&F>VVsn(fy7tLg}P=ii!u&ip5s7Q!||s{z2)>!WTk8S z>(Oy%F2B`C+i_=hdOCDU(~e^3IhIT7ZV5``S<=z=Kkv$;^ynV54Pah`3Cx&}Z(XL6 zC5DAR`mKLoHKQ#z3tF)F;^pNvo-5%qRxn6=H52?A;=8td_CnILi|Y9F3BH2*KTNzZ zFy1cjfKVS(I1_D!rag)Kcx`fB^)~rQ^^@3*UkK}%{03Es8!oOL3u%+pg&C2^R}tz! z=RQ$t&|bvarEhVHiPJd83m)UTfAEz&Q$crpj)ChlRe(4$gb!dW-drY1qqH}rayC0W z-O+WFhSiY}nsq&6P@r}F$EPV-6%d$M@t z$-$7Rm{rqMGkq$D=!~_M9$cAUCe=AVl6>_yXYnUYE)R4_Sv7K`gCW9}7J6*|%R7@< zLZhi>uBR2pdPh>_0Uwx#2;U=_*jp^f=2H7X>%*U~oE|p8K!)cw&(qEwQw~Z}Wf<*P zzuF1Aew|AI&U_*iHCV;POtz%q?oo=ke;- zMtaBf!IUCBLPW&8i1(B%W^hADNQjn}R#EzXJ&u)Tv8Bl3Gv6>EquuNt+PeC9Ax)R% zGg!42aXT#C@wfKNQJhxnUK~YP6)BE=|EnyD`e~b|4I_ zz4}rp$c!12f1<^CfT>_^_T=aZ@SE^8D+T8(4jxV}9!`!)+~>YQ5cw2s603*dX2JJ? z4v)XP2&M72?A}Qzo4`D2(ZXE|6J5)`cX*BexZECyOVb-|kio}PRoo&=u}7g;BXRF- z8e<=UfsX}qT4(&Igye~;>RyShvkUSoR)-S@U_OvKf9Q93q36OpHJiG&x}?m{NnzA@ zSNtqr$~%p0?w7#i;>K48^iKx8eRYW4Dg;boWRK# zMOIq{=BhNk^qU7=<~PMo>ja9y8W4RUYcL#^sV46)TTmURrD&vYq^hau!cT?GAn1;P zWhI7H906Bk4hNT>{r~_i0pmzX;#~qEE_H=ALYG?CnJ}3*y6w;ouf-j2{zByam*ulR zzXUYm1bvLwT!SBqk7|=hC2UkweY?FB;%YnPyg@w&bdYtMmS(FLLACgq{d8ovXbZ9%(?_8 z-Nk+RvCdVKjV_i&Jv$7;{M*{u4!OVk>FZ8oa@|+7M%7kyyaO#qphC+cj+uH8?sAg`=!vMpP4ktQ&?CaZ6`ie*!}=4_gup2@ z-}N>i-#vCvLV_?~{|dXgIyt3R@X%5=!@8^0cgWY->E^67ul0I2`0-*Tfajr~M>vxa zgRhq?5CzdDrH}{8J0`BzLBac`O)O_(%68WnD}{jqH25My3fWW}h}?+Dd=Y;&RDVHu z0v~E@OvUfzfgyT{waUe_Oth)Axi_72@*x>Ix>j$&9>>i$Fu2PTK-0Ng#DMeAvY~BJ zL!o8+6f-)oM>}1ljFKBjV4lNG(JLVxWKr_5<#+R=up zq6T%M=OLi40D!P+UUpW8Q7V?$5hUVZ%19=vDS-ZnvcJIhpzfsUJM=+%EUy3}%f>Ke zm83YbLYFoL`4ZV78Uea}yJ*kqIuZ+K-&2kOEY-b)GuOea>WH`A{OEwd!|@U#I4FGE z`wR6vflWS!eA!9w6#22K+eQw&8ZD%?FHUl6>jkB3S!c%?^6kmQ)A7U0)YKh4J>=Ku zoFd9&+tGvCD7MnVdbcB1O}geC=e9#(pQB~+`1$DN!@>%JW9(vw>NNt;p6ckP*+behNbKTm0?e=e_Du3Jrr39#}kMq{Fy@_GQ z0@oQhIaQ^S-VMX)JxqRXxB(0xO;ZRz31lrM*-&ax=GDa;AE}0Yp2=O~>DnGO0no=7 ztXCXWN}KzNbaw|eKJrF}00Xp{TZhtUW;KXi`shAu?$G6+bb{oigne@>sk>P%5g zHv1g6c!S=K!A>8U0L%hA;X(bw59E=kZW3E@FBMT@5yf>FF(vq`K0`Mg8NhXRq<{#_RLrXUDF+!)ro>nbZVxx*f=1eTerU*Z<5@rcGYUyL8iM3e$@1 zz9;K(adCA42U}d6G0@UYnt8$T4ZrX@~eS;8)pFY>v$O;PY^Q1B)}5R7#BOx`XN9duJq=Nr+G7jz}dwCs$v_1v>9e zN9ozIva_-P2Cdg$sU?}-F88#TOzu?q?zZ&AD|dkr=^owr4<83O69vC?H8X&{1*r__ zF38A^d`6v;CnRkaM|n8?*1EXPq>l}M);f&Ks-Hbk`4Qyz@>Ag}nc9zQ2AU3_W%48Y zJ*OH)%hti2VJ>&%8KRS6-lWXIzAD8$S#2v2U53Tl5bu-P2`+O;y$^_i&A%8Id-q(6 zoa)e^1l&){VZpkg*FG;y+N%NZBMwK|D=)Xw{XFn8trx5%S7228hmj3)q@qBAk%&JE zq8+*RmiJ4n3)6XOc6VU6Dt-0qHi9GwSu^;w{a$le3re!7 zF+EmQreLB^UZjA>q}+3P=?(mQPWCQ;yzA(!A?Tw2_9+v3$Zl6_y|1k9VA5?Yq%(4; zIa!UNZ#s8i_UOFrb@K`w1@iOZvrm>(KDE6ejlE%I0QyRftV=DkfpEz;e{@_{ zO2U_*+>^XZC1lo08hOaFE0RDnZZy3j;C|P$ZEQU|ZUh9B|5d`liBV=xeHmZGL}S!~ zR}5BJ!x{9h#1v~{Mo*CG7EHIMSTRpz)%r8c&Myf8YDuPCZNQ-E)EG}VF8SCyQTtjs!uI5O9BM=>KFHCq8l0r{>X#m0W3e5;+bXk5fs=JVL?Ia2N zgNd$Ywq=r~$S`4$`55_-+tyY{n1NBi&Q-MD-24~n{VJs}qFl@ARI|T46%89GUfLTp zogI1JRJGnZxHam9eY*S*nzNpD_w`U&Tu+kZ*hy2ql_W9F)+c(}k5B#0p2U2oj#4&8 zFti!t%*Rqu zQ$KDNL(E}2u;DuVacPQeL+uKQ5NpW`l?`@NJ}S$S(e=`>N^NT}5KkfSf}dRfdyN<} zsH|%A{W#9ubv?T*ji}8_9Clq$W=Zg+#kPTKVrTZ=T=l-KC z8n1HeRWI@a^d}Yspe3)SX2F-2&cAok#A$e!h!<~`&z?)VyYSEO&G4qLP7=yl;(Un- zOX;Qjpn0OCt(D5De0d2n%C=F6J5PT$Cnj?Lt^I8=oZjbS!^<#H&fvtgXV;^r+0thJ z#=*~muT~#3OmZ{3p4@jHvK$!bnwX%hKJVXep4}L8r#n;F@6bib8yXo>QBmnt8@E^v zujl2N?q&?vW)k>Bz|jDqzCOWpc>4YXJFg@}^IHNncgb|ttvbduJE&;cR~GM$c0xnA z(~`BCcpOm^FKJwAZ8EOFl3cBxkK4}98*QKUeXZL1!jH?>C#bY4H8r*WWxq32F}&eJ znw6oR=rVk1dTDvtxXb_I3cssy<7C?}Uw9IEK8idVrtZ4eD9x>dX-Sm zO6jCn2Q^RY!AgKbpY@j-b?z!FB70FX&UJBBoAxNL;`+3-XKj3%%#>H?UGo zdI9gs`3=g6@~o`vbZ!%Im!svypEg}rmqrCISA%^IAGsa|`+hpNcdUnjL0kwB)gmuG zZjyz^7gEie^-k%lQTd5){@bR;;FKe~A1#Q)ZEFjb=m@rdt#%a3ZRTubKg!Gb6lU{o zZ;>*a%%po~%kRjKO@c1k;Vyd~i;ZNGhp}Sh0UGhbLN&qkO4?6eReY@CaxwK>Ih{8{ z3mHykG4Y$#RiyK?EIIEjS=)%0?mVk*GxdG&1BS}!8vCM3Tkjrpk2`1$TO{XJURCv& zX+FssEcR<}ov&zF&Kj&h-bYA(Lk3pM#6g>9%nYBiah;{6VD1l)#g`kL^ejUm9?I|+ z`54_bkc$d@?q>7zzf;ndr&Uf)@sDTuBCEb&O;G~iCnTWXd46@8I(F%G0bxCn&=Ch) z(0WKmmK4(Wj~T{y`(SY!DWZjw6f%N2y&zARx`%V~ZPb;NsMl&o@KD2i=D?7@qV+`1 z&G53H4dm2nydkQt43J@P$7)Xbcm04}dGp5vDE|Jf9%KF>VCchfhxHP}*oE-EtmjQ_ z(9{o`c!_q8lR|Ti2`=a=cbE#vfyLEqg`PU@2)Ghj~oL$!N zJPFa!e0^~o{1`Ct$7oai*V`9nchG<%w2;@Dubl=u;_GQlehL46jEOu&4?JSh?DCu# zzHgeUHeWfJ>=5=}MZi0ahCOM41dm|5#jaR)#OU`FYv;v5mk6JU95H&+ zfd13+wwD0`GqK9>ljr2vBOJ{ZIHM|JzW^XarakC_@FY$+^-F;xJ%&@e8cYk8c6YU_ zDjo(O0_-fF6)1Oi7Qb*~nWiyHgfWbZ@&e$lM zvyy*j+DkN~$dkrB?ER6Hc8>`PcOpT=XrGQQ&V^@q{EjFAwwkg9R6>#GdpcY)puk5G z)J6#|E-ps-(!MLi{-~*WydGxV)34^v<%2~H3NM6fp84n4hT0jf{^JGyrKVN+ZiV|A zsa{ASKP2KLcuwJ+4h1 z?o`$_@+onc8n;O~u6(U;F@@53IZ#3Vlvq*MImZEl;1Eo-LpCh3d6iEtZ?zsOiFR91+KbG?h^otu-ny!=7y{GckO5 z^19>^8ft4UI>&25cK=OSAt;!(WQ_V;;mOmbf2{O; zd}jtg94MLf`iYupwweeR5VrG>laHyph^7>JBCJid<{zM0 ziH(Z;C=34X!NH6df;R@`zNsiHBs8sFC9p#ns>_YrM{1f^ZH41d8to+a{NZA=1Saz# z9|uEuY*126YPwlxkC+R3z3iKpWiy|(`njh+v0!F+FB79hMRRvvUvI=k%Y!igxV0I< zm$=&w>Z>Ov#y*QZG;cuevsZt83&)X0;(>E|qRd1VS(z&B5 zCkTuF$h0&^Q8a(25DhzP32q&a>MzmEf0#3JiLEqi<`LGhsw&7cWI0vi;yXM%ymx$_ z=xT7!Q2buC)Gn)b%ls`Y0U6?K zyno^dQ5sB)JNub`fAD}jp0sa)j=$+4g!s&ANLD=QN(?(aB6x=l{>^mj=QiV>6P_8% zJ0bmehvEy57dut-=zhsG5n5^Mt>Eao%iL2%Rgw-ml%5IjA3KV!Tl^v$R_``{IsnQ3 zY0g`A#Bt^~z?_%)8NtJ6_4>tf12yG9MlPz4GCZyshw!lE_9K)@6Acx_G|3f5C7I@1 zC1PY4v<&qkBI#=8x7-=myZ==5_R3DQ%fHg~ovrkr zK~tC|Mh$4>`ngob$5CZL3<{eGqf%LG~;DKhWHI{8ShL+XP|aT$e77F z*nso7jzxGTT4s&g<E7>H%iFf{0YVzJkxB+skWfb3Q*q-pk@x zm7m-M_Lytg!DJ)HZYyj{_x<@ZIk~wtSdvu`!UW$-OIb73E@#=;4^Q>8pD7_Dk+;$_ z&-U2VW?N}Ctt6s;`8W9wWzh6yB(>yZ1$TApUi8Dfay^_s=t16e-ZFo_3k^&kwQ44q zXfAPx@-577k?=_E6oFqUxFWFm{}~9-Ts}vlr_K@*5-Jb6*+~2ATX+}`9A8PJ;zDk2 z?lZ}M-Jw|c-W}H>;iOH-y#l1sPVmaHmv=+s|4jR4g+;|?ykl?vK+<3Ijl&(463L`?<(-cs3 z2aX8$1`aF_}SKtI{a)5#{rA{DXOG zS}Z+z`=CNrNljZ;Nm7!DOEX+S-RFryrn4vs-(hKKK#w5JKO4)Dh26Iomlf##IVrV2 z!IS5=bsVkwdi+evA@}h{+~GMm4vca}%m>9g+1mb{pMXsHL_g;;Ou1<+FE2ME`Jb34 zxO{tzR5puRIPU9Ys<>RXnH=Sg{H5c-!EYTNKw3aFFKFG|JgUUHsX8cFysdL-iCw#Bx#GChuwC9_1a^!oH5CVRDC>0Ua4~3C` zJ@3!xo{q@NG0;Za5tY$l0T&heJrhi1KJAQ-Qh355BG9I$wBx$ZYiD?>j$e~4@MUJ* zRCPVh(o@l)+zZ!d19ot4bpDnVJoyBj4`=}4Dpx2t9=@6ELMBX)MNLiZ1fMAXF(*0p z_1RUB<(&)VMWI;|HMr7D;_c)=dlCX{HOeYBU6W>qDIa>%;H^`AGb+(~pNk*2&sMS@ zq0*Wi|E3Ya0jV{Wl5^8BIT#?PvNf!8;!Db)vi#1Z7DXX*OOFHrq9zpUgeG&w3$j~ig7*1s0xMyQUCVPH9_`Z>)YD@6z)7Xk1QeWZ&C`)&p0ViENKI0waaU_*vo(Ua$Wl2WV8Xem*8IJ>{ z)4i#F##$CoRQsWr0k~EfRSM-DuP@h;<4EI2DCx>=NVbwUAw+s$N677|=}&o+ddm@5C~ zXwC<9POT3pBA6r%%ihJp@L}4Wb#WEK%sFag@MoOo{UNc)Ghq=djMUV)x5IhAi}LfJ z8|5~yGQZ>?@uF8+4_>U^Z6rbhez}UaL%I}(WD8hlm298noS(w6 z;F?OFK5hslZyyw&F4b{1_e<+;AA|k94z7U!@_YZ}olWV_30#rtlk`-2Wd05UD}43t zEYSzzS74O|vpJYrT+D3Yk$1HF|MOuo?EgYl+Kx7eC1+@m^zwarSMKRqG$_^q2mzFE zO6*X09x|vh$cK)m39(RAR9tTA$S%%BZsX^m?tmtvY1&!K)!62~wTWu*y{?d+b3*+p zaEuf0f(KilyV<;GQO8)O8EUXyh}{(oDjxo9^K|sl;O3uMrejUBMSgs!brJMJG`V!1 z`hKG3mYI>HBbjhcH?BRqdBK`AY~<47=M@xW{WMbux@y8E;k{baL%2O-b;Rc3v*&T- z-4#M!)78CinEjlK+C)5nefAZ~NVYrnyoZUWvdHKlqp zGwh@dN~D-s{6tGUd>oY46uoIaA|F>tz91tH@Mwr=!3T{~cZk^MJ$~4jS3TPs;Qjnk zymdX7cDnxA9n6drk-I8=f4pE$Jx5=*S&a>{Gjyw+_bMZ>{7EKzM)B|!4<=|tO}&qY z_xDS{FOjfuW~|~A6m0EG`l$Kl6Zk1fV@TvjiYbD=VJ5l>BO1@KOV(We`}r4Fb5PvT z^XKFNsw&x1nfD|8UH*>Sbk;7DPbL?iRhb1o{N6lyl)8*O&c9A~m3MZkB9G2m`dPuQ0pg*~Yno^}(g{NjuMp|qhkO@Iu?K{;Y zAv*79mKRhH1iY5dX*sA?{SJ#gG#+<81yn(H!`oY54f02s$B))n2Q_^V6B83^6uMP% zY$vK~Z*R{}#RC#f*0f*r=x7lDvFhr-OQlRM4Vp8bucIK@|H$qRggC4N-;eV#%h$I# zoQmzt(t4I{A34l2_0cOe33m9ewdhG6o!uPWBUeigTOIncgfivZJ{cWf1Q+i6qw;SZ z9T+MiK&>zAfg$SG*6|!QKG>f-lcsG|18S@7x2TWX&(HE%vZ$dP&rI*(Fb%Pwx%wLY zqrio&lTFv}6W+N7R}ZRCnQr8+#M$F*x9X$c{aLA|xb>||b%q~e+?%u7ucNg60fJ&+ zr8-ZPkxa~Xl9aE=ueq4$6M&QEcknl^jXrMNZLvC?`T8W+C0d@?AV zRpPf_1=V}tojmYB)zl24ay5C-+YPA1QHy?%m(?&GKBYi_tT(5yZ^eWcqa$437!w$G zm}?>$+wJj=*vY4LF2fE{ z5?E94*=(OVE&>!d!HUs;1vYj8+Fm=aLI0E#~tS z`R`e3mm?4+gKLXtRa51|=u2$6;KTZX2@~%2^KV@bEjGFWVyJ1dXX!JUW8|C>JeiLc z!gP1}Sbf)z_2RwDUDbHD$nn*GC0%7yR9)LXLxaSC)JTUS0@6~_2ty;?C@J0DNP|On zDh<*ljUe4EEiE-PNWbU#z8~{r)~q$_tbKOe_f?y3>3Ctk5>;&Ep%%aGh|pfg(?zkn z_tFtWm%~UaJ&YfBy+I?XF>_6MoXrayL-%{Uj*hh()JK=q*JKZ^jeaICv6vsV?G!b} z25$%|nv75UA5j5}^$1Yp4$8h1)895V6q*Y+JtUj4>DK-IPYaK1+F9}3p}JVhcOK1f z8>qx_C#SKr(ehKZ{}@LVCDOc8_c#5ZQ3!|e9!MZ$uRCNCYxuY{9PL)lJfYaWX(X~mM`<4`lh zfaHuwk*BwBTI=)SY!xW9L`3IQ-2en&7-x%?uWDpXvwJkXvSKkBG0 zvgoNXJKGuazbHsaRKW~@Y3Xy{pOz7FE-dj-R?iVMukJ|2D-SP}bsecD%lE__R6B6L z2R;4ltssycx5WA(TY2nUbhQOhU4?wR3yt z*UjN#z^7QtL~LberQ2rwPEDq{9%Dx++p9O9d#q%^nLsE#+Pw}Hp#4_3rI{q`+~qiA zGK!WN*`PbLmp~o%Hy=V_Bo}sQYah75_59g#0D;T z0Y=>xn!79Sv35@TiPJ(A`{BdcvY{sg3)Yh%MPsGeT(fou9Ewp4UcKt0o+le=Opnyd zE6$>~f<;#9b@ebP`5ej?Nta`C+udD|40W;pQPyy|hPqZvML8QY8ygk|YzJY>lrT^M z?vY5Oiq;%x`B?Y;M)M-mt9NwAP=(jI@snO72{u8|@`ac~W0doiKfdU=POY)0q~!Si zvUL6*P$+hP%2JtiGkdT;@bqa`7KO+Mj@ZPdt<#jGg%eY7JKGcW@gG@I)+P0mq^Lngelv-uTdV8@bD+)YEGrl*>DaGvu?}N>ii1l zxgUG0Mox3CSM5A|osvp^C-tA4!Si!Def#8}eB@#NJeNJP)P1!y=lhhJ z>s?)==%rRN9#L);CoR1l`AU!ka&~wHWaxMomdLH`xxKxJ)(nM6lald(r7fo|4B_kR zz|+B{_q-3MILjufYpL^3_vWNEKw*>J+vY3enpxDU=fUZ{@{(wKYaHh`wc5#E$ExsM zQ6re;G5Ns_JB$0vS#Z6kI83{3vtda2e!1RaK{3m`>wE_ta10Y>|L5m=d-CX8;xt=MzuvqyFgDjJ%|#r+!yCGKZb&_=L`RAkB3wB?f(4>3O8U6VjY9X z3=n*2pt7|UmnyUbQ53Ov&-rjoOF7mS;-tigK!e+fM@`w*D(ISdu4CEp$9K4+w%$T+Qm^uc;dBbm2(kknscSkhr85M z;h5i&tk%I5$LCpnbe(&Zq^ZJf3R5Axv|jOW?c}QObz0`j53xFiKNtfmtFVA}oxA(5 zGL+df)<3sT5Ooa=(s}QV-dArs66ij){S{l{$(wshLP%$VZ1w5&T-{`lLJLA_lmWSg zrD#OSjo*RwZ9WZ=*OyUO~se(JzZWN z9>1kV1&|-8`J!D9^CVZa-iYB}N6G`aUq3R~l4S&ggVE6N1~}d%;lIbLv-=i)my@!)Cr1QuX(kWZ?u-0?c9!y7{cN32b z;GCi)72FSPnTB4NtgUOXdDD2BeR6%0OcpGGeZ2G6FlUcNT1sIms!xmE54qcSH7`7| zY>T@r{&Er8ur)B`eC@oTi&>2t&Htv)P{nfx#cYeK42iq|#dB4KrRJ;m;89b?!sLE) zRB@a{$|&d^VhN*ux$M|i2{H89Kf8Un7@evdqu68g?GP1GE*9ST<%`IuFB{tM+Q_S< z@x8fkH1}V0cUVmv$0h5~ndNt$t+fAlA$t;ht_&k9F0M0qzr1s`xwA|C?^1NE!)2Sf zDbK3R9GOKvF}g#6Zm9!K$}FSbv$N)lhzg4o#*EKFAZP`EOz#; zhfo2U|7g(}|MQ=T&Z8`8b+tn5>@Y%HLcq@dpDVA2%38{GQ!~#aK}wESucNoPBely# zP;P?1+rNP#s^eMfvx>*5shro^$*_548BHdgy3&i_sD~h6!|GLt8T|%5X{G;F)={?K z>0%!Zdcf0ce>@;KiBjqFwvE{D1|jlu5T0ed`zE}k$C>+Tq2HKWR(vzGvDg$kV_yrI zI5-?BFBjO$d3>1j^ZxF)Gct433M>AOlc=Rz4L{nO@x`c!@!!dAT|LywqNdct9x)jI z4S}#b1SP+Ycbet&9$m2C(R(;uWwC#B-!2x{!2-~#PDkcI6;)4yQ7mpqb4uZ>3>uFe z;pb`%`?`}=1^+Ms#^CMCe{yo^+dT_HLbZxNM;voTMn+b?(c3$&^=?lhSV1{pcl!I= zVgA3!+ZiTM*E^EqyKucMA@s&$(qPrr2r+hlE}SBKJbn~ZfDVn6kj>r*M?WL=y_(_T zwltO%fd+iNU&XWVzv)t3oJq!ei~4;CE<{ZLQf+UKrf?tK9w*xRezT^C6!*k44z zhV`^jfb#iO0#k$k+~(j)v#;tr#`N=`_(g+7zo#?PJj6blfVjV8XMbj;ux;^$TLupN zw!aAuvls0!??uGc1^!Im`f??Fa=9_)UscG43-pN3wR681TX3G#VLBWNdRUIhJlVR; zd)1*pJ8);9jn21~~v^qZ=1)g{t9WQlP3JyIJ z-q@qQyBc8^b?W)C;6y*pGkSAszl{Gebgbi%%3t7Lb_bKU_wdRJ+mI+=W{B>u!wm4S z>ck(GnHk_TG&C$UfS~4doAledp6NK;Sk8|+!YmoRG|r~Gi3TFAmyepQmrgZ}7O`B{ zQbJ!Uy-q8hxbu5>{o607YLG~WI67A*QxDAP{~Vy5K^th7)WQi zgYNr@jnNKgeS>*r{KflU4Ca)6H-VUkoNG*AbDB1{eZier_GppbukMZa?DRuhH&*#Q9HLd(ubKvHz}P`4O?mC#cHdO-oM?_UFh8 zS2jcS$A=a?XOTU_yS6MiV{^-Q_+XTFB26ks13h(Pa)lK#qA*A^+jBb&Bt!hhYv1tj z=zE+zuWUX4c2oaGtFq#iVSQ-{9Y*@1+FM7}8Ya_FZ`Vpb`-2~Y%Y`uu?wUPzEl(B^ z^$h@S8DaN!U}JHy*$SC5HP=~j9>QxagsjG^{JZ@B>&J+R&4ghNc>&?!14z5tM!q7i z>FFr;G#?~;yf7W$volA=?sv(|Tbd_y`m?p?9#K|Nsw!%M>z7t;2pU8x7SVwOKQ})a z-|o_wo$tE8kyJt-7e0m|L~nj{9lF6UgAY6u&OF-Q^+x%XH+0Cyvv(>re(t07RwW_D)|fE?P?V zVJ$hs$Zzc(ASN|4AffRL5|G@`z?AJLsIBIH{@GudG1q-d?xRk2mxR&#uUQ-;4wcBO zmLTz6EbgU@zsMj+Xa)^-{>EQOhfqguTiY_J#Mt&F?LM*1r%wO?HZygNchU~Ru?Ul% zZ|RGNdmFIDfu!5o$-{=Ltt9f5x63n$(bT7h)|KrKx*$#kqV!r?bLYELWo2dmMOM7a zV|u(`2Y&D7tJz>zH@tE86QO_FFw;`_zS?#uy&#l;2(fnNefp$1`Sn3`&tQ#OzdYUk zpq8E<#YQL}yAgl)NWhXREpwdp&-cu{*U_!bS}&-*{4Nj)PElve)V;s$(35tOZIzSH zIwc7pVsEL98f@kiOFW~=*@r3$uc&mJM_Mg6hy3D-O3H+ zO?+zLjsLDz_4TXx6{fzYXE$?lRfsqvYputQ0q{rV_}*llwesV41*v8*;h>tK3;D*R zF9m@f&_d()f1yBke*S)b?L>iKcU5U`SF+@sGil&PRf937i4q^byIX^EaiPBdaH zw`*x6aUV(kOv`(T_~Xwi!nj|!A5S6TI*HInSF0s zMk*Xt?^m{w0&apAxYGd*b%YOLK>+k^NC+?~(2scF*r$_Bf)j$D&1)TF;O2R8 zG53V&fFD_GVq%*mN_{p+5XjU41$6boG(tZn#2MgJ*lD8z6-&9>mtvQ1+o z-^ligYtfhV%xruei}qj5!^3OpPB&U1-C9_W3bgh-UlNPg>q}nZRAeC)0yQ!5GT*Qg`+`;51N|5U1E4A$%yNRYPPe$21# z6I43eMKu)wD94-sk~R7oT~b&WfEIr@r3|$qJ2OG$N{**l{*m^_PT6m!xaQD(xQv30ifLgQu4;FEoE#Tg z$Lj8v-I!n3(jsKJi2OeD-$*2dO{l5==wU$TgM=#xx_vKC+Lc~Otqytm!kkoxs0-2e z)F7+mcHzbOsvj*KUd}2NuG@FR1rmGx*DQM}foC6DUyT0bdI3NkP-KxXQ@(^hGyafG;tEG1|6nWMM1br z1X;@I)ybZTN6dUnQv*RdH93Kf71=%reW7I+PlB+pDn5fmTmY8(fs4f(#m_qVA6^pS zhe&87JmHDyL&EYy{yoOx(V8qd?26cb?`-U}d_QL?kRqssk-3uKONnxK z#Ti~js@8mnqmZT*NNr50_c%l%^x+cYUEAvdlD5i{7U>edurfIdRs zZP5{$_6%4J(VyByd!fIT?#hDuOP{=11V(RMJ(&l#h62k|f|fy%#}6M!uj69ThmNk$ zT>`=nww4Wb^j|mt>LC^6;>Pn!Vu+)=!R6OXeU=Y3HR`;!cN7ys)t zK^X(tQSn>FW%rnf5YGosBeu3RgvCnhrjNbj?(goHXdd1!FFM;kHEqx{GC%xclpcLC z9vk3r3Q7&}(q#l&QHI#sQPP&mCNu~nkWTF8Nmavk(hWD)D7M&wTj{SjLB@Nxx!H7O zZaQa}RI;&rqzTq;sC_>ariBWZBMRc$fWme<$+2Zs^9v&b`c>-LUZ9KysJ=x7h~nWW zvVc^t$`rYYDwgG9mtz2EtDW*_ZS$QGvS+(H;$|_rsZ2N480U5F{gsEG=J?l72y^fW z1e(8em56$k)4ox+QQ>_*`GmzS!w${qyTg$e+E*7FEjxUZg;}|u*$7SGn7Hg$0pS)c z;hd`-`LaBjQ1SqIifW;tGDD$Zs#l1mdGSpUGim&{=Mkpy;rdPDoBI3+7fR+#^*q;Q zF|2YBI(b^%o0)TMbZXTELAw|t{bdLKmblm;5z3o0!AjG1xWA- z?Hvm%iQ464G0mb!27MrhB6@UNLSoVWAaRboHnWSU8}<__{T@a{X|f(NhRQ#g3a^Xf z$Xwk-i`m(~K`p9CbKSdbM38tZ10^fRJJLwgUrnoRB<{!NfjcdT?0zxv+Wa(*p zp#Uh8c($X@jr-Y=SOUn6UurJg4;`K1V>Ed9Ot+C!@=m-S&qQ4MKcY}zz zdxzKjc}DkvItDHL`P1eKrgcuBK8i^UA5B#b1qG#7KvEizQvEHvSnYzybdG0KU8V`g z;sn}%`|F9kG_9R{UXnlw8*cEt52~~ojR7X)PR(3 zdVvEJh8MS?DB+Z|*gp`#!x|%XXk(q`(g>$a|EDgpb-wF7`Pc%_op0C>i0G$w$ z#-ca35^w+o+msM{R0S>@&j`dynOauB=@@0>iCFt);SBV>DsqfORhn03nR(ElaCHpt z()^=^3aIMtJE>B2|4>r67;YeEAu(C~3QE3I_BzYOOu?)Zk4Ej!owG{q>P?!Mdlu2Y zi{Dga-^QNW+FW6_Huut3Gg3%*;>)0YQ&GkRGNpoowP`-ea(&sSw&zNVwqQdQj3SQi zF~;Q`Im)vk*3#%D7=>a71X$7*F#ZHRYMLkkney0DDDamd5^dgLj1W#Ksr5{{pbadk zUe{4s9NAJ~#Is@Y{HP;Z!uf{0+3~v6Uv*1hWtUDTKO|37k_D5+jt)!&$Ag{_iw+t{ zM=u}8TJMH>LD=|{PGRpu}#n59(dqTPe<+I^P?Uk8}z9#Z1lL zJ}kWhouc29hCU-O#d<;XwBgWnR@J|Tyl{MRaR`rrFZbJ(oQ&R zS5P%BLKq+dj8t?W_cIEOW7vphdEV*BNZ79~Nioeo<@&|Bm)s(n`PIvPgY*PID zC%$C|+`Z(IiS%cfADGn`WFya__1PP%&aOD`&!4>e%>ghNp;r(S}H ztI)euDzRVpyV@JFb>!>KOp2Rui!abPoO`I|q@#l5ju=rIMgSv~m{Nc2^y@DG7KA?t zHGem~^sGqxWdV=k6WV%ylmG`xF|l4T)rD=v{bp{JsU~8sS^)LW!oq#Axz7g3NeHMcmUD zHUt1Do&q>-qs(eMv@w7ps`@_mF#($t%SD+MuaMa2uZxG1sD|UbRRlFUCvE%xJfj~4K_OO}o$p7Kd*U0B&OV9p(V=$@gHhZ^b=(WgR zfA+&T83*PSeXgE4jUu?LsS1Lew&he8U_N%6f_$-U`qxENzTBZd`s$ut!v zK-K4UIZp=4ts5n?4#5lnj32i}qQdvh`<3nMk=8d4yVpBLH%IuDVg9r-zWaR35_?tV(&8P0p|dhC7AFaCSi9S*LKXU(nGc`{lDT=ia+zOURe2)|Qg z9boIMW#eI1(-%j>;_8jb=fCAHa@-j=LSzAnKdOF>zjE@R@tTjFhV}{|6Jrk(czF z`X$v^xrqdD#*L2anWJp~wJp5tJ&^2sJnoLUm9+*yp}9a&w%6cD;}V!Z^&A|VZCFS1 z72?)t!Fj=LLfh$qkELC&y`j=(EEAXT`ju`p;Qq526ZUh294k1J~>ZS z)jTA6#&fC^79P+>?S&^2<#)1%SkgIkhk*VWa+Zv4H2 zGObFH-TW_qa=J1FajemYLqx^(ms(v6A8eEkYuu+wri;pJ)XoF0-9OunC zbgGRV9%g1{khNpkF>g=D#2=fTwn{}GPAb4v*EX34Q(k@;V(&QpoA;T9%jcy*du=TX z`SAVyk-f-u;?otcv*qO3*$;MhBgLhQ6)j(_*Xz2GEkdv)O?&%qWeLIm)>>!_d^sC$ zF|#dR+8~hf^8s8Fj-%=-K{c}_!lE^Phu*h&6%}qErSJ36Akd{7Z*FsEhmg?Z0E{`! z5cYOEG03QRd>`hL%E3d&c5nHs8@1n51)?E2q5O4hi5qIK*?NUjBG{==R%BZreU)#o1nxH5+R& z0)dDtF$o3GUR|(yV{9JCiN_P=4@-Y*zasR5y`t@M(0No~YUgal zb&f{#+SHJh)GYdG1OtLU6N?;`?$-$4>YL-v=H`W>oChsM#V4L5? zfuD1{_R{XJR>Ry4AA{Wye_4T}Rnp7_158M_?{{PncD!@f7&XjY) z%p&yO-5sks%z697BgBL$zA(8@W{Ip|ySGvL2#a2B4bD8b&$lWCh0=ln1(CSVjltL? z!n7I+0IoI*zIhjWrHU06m4<_}GBXD6jZpMXC`A)wFa=(>|Jl;i#?>fzBq-ber z$(P{1mQt0I1*Zdn$iK$G&;M<{AmF}MJe$Yw|3LP6bf?n)UQJOEFD|m_1zmpM#{I7H z*l&}VpA^&=yB{zIM;Quma@`WQ;w_(jIp;IxiV5FJ7pekKfbD)?Qw>hp-3JHrQ zBz|z+ceg+b6QBR~fE&~>>qZVGp8n(q6G>ax+y9bL{J3%{Gu_HC$F*Hldhv9Uhc7j; zu3F`*KCAUR=Nht5QdoM*lbbzltFzlr)^K)!EZioTUNzTnyR#e0{F~b}`5XZ_P1C%I)mREfTZ|@Mq zH~U)o1SCx}tr-gnEhXY3HR4pcLU%iZ9{6wmGWhs)obC}pf+B5Iyr~qfUgnEyEdOU8 zxa89^4>FoAD>`p1Yx0b>8ct27EYmEfK7YNpXO%=8s{W9G)_H#SF~|4v8C?HYG<)3S z07q;)IxUK&9>ZRo484jQmf2u)!w(7CUnDKaIlBuRZBLj0Y* zpP%6iE&|%;aJ6cl-Ts-c?QDbSHSe7j`q1$OavR);^NovOaX8tEY;vlDY1~Dkr~RU# zBDF-&Pvf|>$q_m@Z(S3sfJA7(mkavLm`>UDSoLn?!uO! z#}NcCVbII7)SU-ai~D@U_VmGkm~0~ag+QNzK4*Zw%yfB_ODqk2EThVzsDef1(@vGu z6i}xQ8m3^o7s4b8rNz0%kcePamBB$dYisMMsHn*drDuEcKddzFgRzB*#X8$P8F2%# zNMGOCjSI!b^L!yJaS5sqg2JX;@L-76ha!5j)J&=3uCVgq1XilbCEk-;^i@<|j&L)+?dtBnCBCBtRzzbZZTy zo{vFc5>}!qgUjUWk>Rlqv+oy*!eZsfw9Je@-TiY!CD3njngIhCou=Q1u{g#Ko&f-g zYg5PlVf*Ui!}jn$0?|7U_?oHBFw3kO7|quL89=2$}6v{6P%<_|gkGSu_sPn8kFgAxj8(VPZM zu_c|G>d!D`ne%hSx9q&0DDS#J(K_oaPyqO$4k(R05+CUTBfq_UeI;KBi&;iWpuA@# z?YLUGy*(uWz^e3JqLKnrI+xjG?v>+86Pe%f&oEdK{*KQ`Ixje!_v7NmX=MpDgoS4~ z`A{wmHrc^A^Cu8sSMRI4l2!Eg$$9Pg6{%o zXbwFj9Wak$u&`DFOWh}AuY>AFVmRA92I4^_-EAlZM#L*9c=Zh({pjzC7Nsl^-z)JQ zm(u|Yea{}=qmIY5*0Jr!L;I`4njz{j(T?oYuEaw@6PJATX67Iq8|rJ;taF0-{$J<; z2g$su^&vG~V(aOlp=)nkQilE}j76zVsyyK|i9SSn`W8U!A1ov6mY^EZ5b7(=4D99j6bHA7(bs^ z34hM1h-#`Tr;%b?&W6F6n!@*+vm`Bix)2Kc7FqK#+Hn-Nv1eeFuBoMUc0Tp^OxP<{ zaqY$toVJg*1OBCy#3-sWu;5rRmw^H@N^pd_-;Sw##F3lOhDZ?4!kUF?jl zIsR2Ev&X7B{zeh_FyEgmsXjr_&)Zkc`$baFj4BQ~U$F+^vV=o?`zVMwVc(^x$em0si84Y{lA5PhckfxXZ46RE^+I zpRHSbPlk*PeZZNF@3@~k*{FJA{Aw=@09*_(>GEx399Z6>kPAK>hB0KQa{R#n`Vj`P ziN_oI{AqIJSp5)w0NDW!?~;S<>WMO_BLB8dB!*DfSOR(>u#M=VDBJ={#$8^0n02e+ zQ3zRf$ES%PaEmwuiD7apE*To%Utkq=cF1?8EU+bIz?TnYl$Bf0Ec=8-0>ngLqqc){ zcnh>T{a6bv!%TASxS4qlhK6n@X&E9fyf7bgSw|t()Au zt1cKi$WuJLUBdtF8yJ~?@M8lXyktPh#6*prP;1(~WjRk;LsHPeTIGveVcyWBzaq&5 z8bB8kt4dNsJ>OSek<AKChHJE-1XzoGyQ;8Q;vHukM|o21l(j?#04;hU?5YNWhwn?F=!{ zd33?C+)sWYgVRs$-h|un6L-1VR*uY6M^EKRq8_f=sA7FrX_t6s03icfMRZtj?yDLu zG!%Hj*YhV*a3&P)t$~s7m&W;YM`~D4Q%{uwN?V>M!c^r3JW6@q-@GLv)L4e`DopW+ zK~c!f?!rF%*mKA6rX9L2NDRD8`cY1tOz5+;+t@gxp$3qc+3A@IF|6mRopw_VPraoU zdPawt`Vk5#?P^Br=O~D+QlKSOZH|##NEs%UpK>gVv?cmLC4^6}l4CXU%7CrB@Fl(g z5Y3GbM9UV~VbLW;$f&)Sc;$~!Cl5X?mj%?HuT3J%h*9`0Byi9+qX`2OB`jM|X@bBU zZ02iz4%+a1g^-)*gVHG`4k3zuyDCz-q`fyN7HW6 zk*>>Hu}AxS*#ObN4@+Gu5=|CuY5jtswL%n102+YGA6W(4Gd0z)T2t&Z=PoR`XNlD7 za#Z!I{N`GUpf!^?1R*1N7^dpC=>}UfdASfzh6GjLDRuXF@$3cNqNNZ1sx|It*%tv! zqU$>cQD9?3E@=G0M0tztjZ1DjBDcsqiq~M0*_6#9qW2rv zntZ9C5cZ0mJNzQ$%{C-mVTK|J;mVMx-{k56XussEB$3}d^P?l7$(7QIe9o+tpnfeC zODew`tuVk_XelQwvGv+qU%(9paD*bD%snU?KTwA}{HEko{Y_aXQFCJU2sloqcU>S7 zOV1=VA<{Na`|&=9Z!f#Htgj9m^bTh>+-yUQnHLl!%vikcCTCLf(FlwJa;6ijb z4y;`t@Sux%)Y0sFY5lY=6aRYsJVuj z=u1Z7A=aAH?oUep$vTl+tKBm8j(>YY>+3HT!qFoFANpLuWQ}55}z#8Mk0Zq@+?t4g1Ay_yb~jYkrp+{BQzQdE;txk6g?AOLPY^GbNNeDC01RD1DkNR!;)VfvAgNX za%~^_6OR?FiS{rHzbC;aEs__&x7dx%l$q=a@>tHZp!Zy8T8&<;f`V4vWeLlR5(_?a zyM)6G2Q{SmXQ*@NQWz4&4RN?1!>W^@vf2$MI&b8)qvs zC@VGYB3!;1G?^k`F&GqG67CaRhrlhV@+-)5fZJ|phx+)EjDg3D`#`)sXrh92?wrj` zM?bqEEYW-l08aP5ly8s&lIC=WA2g^jeqyQL!$H&v0Q+l~r?XI%FH-zzEHwr+&OC!V}XG zhVm*7yfeEKryb`nFm8mXQ#KmMaNgELZNG)5&$P4=9 z$nIjv*LL4Ru zo$urAgVC}c@jFK%4^Ml4S(pdiYfu>Ai#HQH$^7Ve!TuPN2_x;v%{mJ7u~lxkS$ z9WU>p5-g&^@*r3emHf)ovKw_bqfqJNHhxRUExMt_vdLm$hM*#UH$Zp7%ym}_sz(6@vZBrloq*FZ* zx+u=9(XAxQNG=!>LzPF;wG0Mz)-qWv!?uDXEqxzk3w;E?#&A9U7r{M< z+fG)?xn~dJrqm5Lirm~Zl2Pxt?N)0)PS#qr?c?=Dd}@@`0?S8@ZA08HI&gnFd#z(!;qc_S8=MTkn@XtNxF}CsT+og!W?X)KQ>sb zg3~}|G}bc!;u!I|mq&5Rwa^lX9ndh(WbwzF1Hbh1N)4aKZ-pYapKX++1eM|y7Zyxb z@b9FD>TAKY!~)WL`FEP}qZwl9%qH9xEnhb#!0%!)+eu3U(+|a@XU0%!+M)8~fi}-r zp0Q=~nBB$o`8=*Zwm$BDYs&F$yoIzFzn!;FWsu5eavB}DrSk8$^K9~44Q_54Yrop% zD|2Q-`fU0V5Xu&)m0b)fIjNi)MVO$!ZNfY!$`mYe9Q)OEzYC()WT@^z~P0@T==l z8!gUdoHR}*Qu)i^=^fGYuNLMTzr(|CG__}6){{|25T$hS7zhYVy}rMQ=a!ttEF!eINx_)ff!I&hiO_*+Yh>FKDq;q9%vzxQrf4x8|8 zio2-O&by}Po^hku^_~s@K+71|hZoMba{jl>nH00{kOO&~M?WI0$#W`M#3}JFP(Gdh zk=6ocIYmVuOWeh!rlDW^nKk8yLRyJ*mm;fF4HM6GW)3Kbu>g+$*3!`U&5le|tcWS< zC{>IdPP!VLdFl)Fcs>k%YY0uw$#x#%_4muZcE-ORS_LIR(_S3<^+l7#Ck3 zo_cV2vPnkiFei-c>Z+vi{dn-HwbA zDSccoyCL8%7_Tb@&w-29GZ`A6A13LbWRkqOYxS`TLMyjOXg+!TOm}uL%c2%SP77d7 zJ^PqJ^bLm3-TiXaS9)t$xX8^tjA9S;a?7M====Fk1V^9?DU7$ZYikwNvG(ViJ`hk4 zNPrQu3k{Z$^F!Q2yY@B%)3+)DwdLY>7jegj=8fcZcx%emlp;$2I@ov=1^rA3r|Xc!Z3H(e*@C T_)CEo90RiNl%y&oJ_P*_08_v< From ce1deca9cc69058ca1c5e55b8ea9c8453beb336f Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 7 Jul 2022 19:37:20 +0800 Subject: [PATCH 543/843] logo url change to official website --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0121ca11..7bd2a6c3 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # req

-

+

Simple Go HTTP client with Black Magic

Build Status From 0b7c11ba4254cc61d4d58167745dad1bd83c4ee8 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 7 Jul 2022 19:50:30 +0800 Subject: [PATCH 544/843] remove codecov --- .github/workflows/ci.yml | 17 ++--------------- README.md | 1 - req.png | Bin 34587 -> 0 bytes 3 files changed, 2 insertions(+), 16 deletions(-) delete mode 100644 req.png diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2e17af74..9adcefe6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,31 +16,18 @@ on: workflow_dispatch: jobs: - build: - name: Build + test: strategy: matrix: go: [ '1.18.x', '1.17.x' ] os: [ ubuntu-latest ] - runs-on: ${{ matrix.os }} - steps: - name: Checkout uses: actions/checkout@v2 - - name: Setup Go uses: actions/setup-go@v2 with: go-version: ${{ matrix.go }} - - name: Test - run: go test -p=1 ./... -coverprofile=coverage.txt - - - name: Coverage - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - run: | - curl -Os https://uploader.codecov.io/latest/linux/codecov - chmod +x codecov - ./codecov -t ${CODECOV_TOKEN} + run: go test -p=1 ./... -coverprofile=coverage.txt \ No newline at end of file diff --git a/README.md b/README.md index 7bd2a6c3..4a274c2f 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,6 @@

Simple Go HTTP client with Black Magic

Build Status - Code Coverage Go Report Card License diff --git a/req.png b/req.png deleted file mode 100644 index 62a770a361f0413186553f43f04d5606150c2d32..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 34587 zcmXV1WmH>Dx5kRQ7m9lc?(Xhd++B;iYoNG8ai@3#6bn+^-61XR?i9FrzjfEj$@!Bt znKQFzKeG1$s;kOjppu}%z`$TA$V+QN?@Rw($cWJIx}88)=ndtoyuJqv3@gol7p!l& zm^Tbe^rV8cgtqVZ(|n(7GM(2Uv4e?;v-JRfI?#+Eocx7ON_`P!Eb%yv&%8innY=1;WNZpW{Cfp!fV6eC_c2fw+Yn{9ZUr zXpteI2wdvWa16q3#A7G7>rbvim8w;opxi{K#GaEx1}4K%0kGky5Ht-a%bq4VUqacr z;buA>_m>PlfR_8bDe3}2OLlvu>8|%0YyWzdgi2z4!jMA6;2K=Um?0Zf(Nch5^@{LJ z=l79^t{e7LWEd*vGO|8z>J?Rdd%4C_MqTd2GwCJ(DOH37TBs}%ggaEz+BEc!iCy`a z!BV*gviS)1zXpy6HA>2r;Zw&cdt|zlIY~WUNmOxanZ}Tn9J3H-xE<{iJJk6@c z5V#rGzv+_p6y^A959;9f3j#)~`CTq!{N|uvt0tgp*mDFh48X52$-2~1Z|EPP1%VsG z*FbK46_s)P@4j@Z+|kxT$m0?QgmcX}{iukp?^h1s7+eX<$WX~roV9OO35YCl;Zi4G zNhv)k*(WHYFMpy|87!&)lChtU)Lc1_cvAV6-*`` zo&LIHZ3!Z*uL~iBMg6M><6d!wWKjz$c_O0%0u}Zvg94QVGJ=j9L|l=v=N*srTrd3N z233rbHtU>J`Wn7LX*_7Q{3au1%xQ^I{HRf=2!(@K z0;kK{Ea{Ay_vc&O$-#NQ0ekqp{TRN86rY`_GW(Px^&6}@400a+b+63Qj^^6n>R?D) zFrCWdE9u!W&5jXCGS{`_aH3F&JJ!u0Q-yQr^Crsx^{77$C_V&uWGdXfeF7*pgeD47 zrloif1}VK~*to);$fR)DG!o@}wt8$4%P9w?WF}Y(t)viV;_*1PhJwHcsV67&o2!b` z$Ah(v&`af3iu1^}EGlf?hHD;#mt0^Mc8`_B2%O%grI!1jNNj*C1 zWwvtr8Dte8G;-IPc4p7)I8LoGFWFp^;_C(fgQyErRwx@njCBwZ%QC>3ja4L)RHLB< ztZ?}HGidi&aP+QqXOGMG3H|7PE>GV8QC zkqDU`Y*0h!myxuKlj7tHn!@v;$-9)3Gupy)gOg|!dmNQTT%{#!6`dXF?IMCGPrmjo z%uwBLifBi7-)y*ErF! zqn(|!8;LbC@vt|wUP^W&4a%uJg8gcAQKs>yQ|j#f)fC3LIFiB08F{LNXW?{V$=ftd zIIcaw6q=My`^Db?R&W;01;N)F{eOXkka48tjG@Vgzk5G@qfRN?qn6rTTnK7igOOa9?JV{(29gyyGf{7Ya z`n7skhrS;<_@bl>ZtdYRBR+5Z4-uztlvyL@QRNa15D=GZCv^L;%4VtsEaw%XO z<^+Nwasncg+l>s>-SbNzJ%t_g+(-z`ln1s|D1XFJP7_ua4Z9XJ!pe#2#>1#ED`(1C z>frCqxFAxf48+3Cdi zgSl)1Y=Z`KE#&et>|#!-;tbb{a)TX0ck{Fc=j-^jZ%t8pi21z^2q*(J)y{l%LOS^|1LJJ(;cNCjSKMb1j_l-Wjfh}u^pfEv zgB%EvdlJ=@Prry^L&5TO)Do`3LT$w!7l*AI^Jvp7i2r6FZ*DJy%NHT2t&SF+LJ-WF z8|P^^#uT=WrWtvo*>8&9F-V57I{mnyeh&|6 z3zT$C$B1G$U!p)`kJpe`$!SK%$cdYUfeEM57Dc|GNEQ)*0rm3Vaw7QR{p{` zZ_GP>3#qk7r$rCs2Sd0lY8W2Sp*&DysXhjmhu0_7CU%Lr6#tbvZ zUvlh@RCM=y=A82tWi{#hD{t-2r;~qOiNAR{(>XC;qntF{uC6N_&mzKD!V=BO;z5LP zk1oZB-!8*0zkZTHrqUQ2_PBxLHUU@Fhj<}$w&yi059$s zWjcZelzt+2VZ|8ozB?YdciM&SJ8a|u3j$~ftpCjtaf;9c1Dz_8gNhN$T;Dh9lL-~Z zC2ym9WGCPDiexQ}!3mvs-nB^SS@UUd!H(e!1;=byb$YA}OV zd+GkMD{xJE<~exO1oXkNnAn6G!9*0#B{-P&1=J3m+_PoV(ooOT<4)iEuO;8FcIsnX z9Gr@OGrQ#8Dn(c&ov3s}07#c2Y{((Yh69U~>?!YwgqkFU(lkxF=1Y3-H_+L<>A}hF zxuiCXuxvXpA`}pT4Uws1{Ln^fcdFh|{x<-ky2+Gpa*Q8MYmr!_ICInmVPrDyG>hQY z6v?QIv+DoGK>QQHY7h#l?R|1Zd#3uDjNj^}%SyAffpfZO6U8!coTLi_nf#z&$(wc? zmPhuDex{WS0}jG6#5F}RYV!YMz=D_{Ec=s&mGKYnl+x@3x2S%_L6Bcz9|5GtzuxKo zn+pj3N4jfskK#JY9MD-p3|T{2w=_dC=WrKphNGEmJ+N`GGsxMqit_o3L73w_vX9+O zIHRmop131{4fqW8tu>Sd*OAh*Bq*(?-%I0-aMDZ+`EsM{awGie9ZORUUgRH^t+fH3 zVQdR;2P_`Wj|gsSif{Oz1C5Hq9UBe>Sx(s%|qxhCSYMS*QlW70V zOcO>@&#o?|byU66npZCz2C(kis>GzCUe)euIN%@3GGIW;gZwe_fEgfEqJomgGJ#s$ zBCE&{OPMJFq;TM{wrYd^h6LHGzq!eY6k~|E!#aTnhTFWmfl^K$zQW{&urW%{u>{vH zS%Mr1ma&0y!$|hAK*OKFqs0}r2&cc_+rlhe)Kd1BR31hv_7Ab#FvRh=?C*XCQ~)H< z5VKX(flbs!z@de~{<}svN|+`p#6tyR+m8C5byCXd(tach;`SRcydIQi%#)TkeXk-^ z#fMCyL70n#)#UAq0fytv9lE{mi5&Vw<6E#K<8J+?6|TMd)Z(8;uBkv_&d?xGMhs8o zB?Oy70kXfjWG!x>1w;d#XZl4kbSpQzs^u*NBzSN~-=kBdBG6Hp7*$pWY_0#~2(5%^ zg(ktqz~gZY9K+kdiy@J3m2HyC{ufuuu#MDBMMb2@)5AA&k}Eg>p&Dm~`j%4VBQl5# zv#A{)Lwd5jmZ)4FrvSy9eF+)RVb8B0NDl?XmR$CcdIp#Sm>7&pL%jVcZ`ww-%U$}T zQ}&(TinRz|S-&FNY^D#TXZYdxVrpSnQbJnmV7|+55JzQ5u4!PCK*mdhz96(zg3xOy zEEECJ#dzho6%^i)bb)}nd_7J6D^%2+^Q4kHUo(xyWz634Ub%5Bf;oB}J`!_%C>zU= zqclZ1s{b!y+Hep6KI0m6{*1b#CiEy5>$MLUj~IRapp2>OSzJ@ z^rU8WFhja~pGOqiHyr)7#LU*}NWFfN!EMnw>M6_oLYs+ID7_DhCQgK9|2AJjVuq&7qrlPOT6U*Hd;W9P!^ulgAkGpkFHTFvV=F#q)g$8z0mKB$X(A5{pAvNPUm_QaWA4Md^PcUd+$}}RR*HbZc{F=ar-{o+~EhM&Uw>~r}>p* zpB0-i^7z;1zkYkd${Xm&4%PPNR^ZmBD%F7ir9CM3YvNzOC| zES>Go^*7V3HIqdq%TtLP6wTV?lOsOYY4|mjrJLgYzRLVqmGHow8;LxLTESSrbaGrG ztOg~>t6H7$fz^WD^$wL|XP z60P(U4ao#+4~^er3Fk25lA`Z?qukDzPUO{R-Cj{ZO?1z_uDit~A0J_)7;#9(gO($A zUkYOI5@t7>zHbqDtOoxu_Vt(=-*{{*Y+6@+`MPe5D`X-8(2(AsAxlt2yx zIoAbI*b9^-pD()Pf9u_T3Qn%;i8W zq8t9Kk#@ED{eHtpzG}O@{47+rOz2B`+w*Ba_wFPH1-(TAmKao3c1iL!G6RcJy9Sy+ z1vbFcS0)gu#l}5zAluI`w?_JM4*q2*82X_mGE}>4TKdC_17oKGM@5?ictCKuA>98E zX)TwWX}R=iO}-JA_r^Wh5w}5}Tf+#;5#L2r?jW_TT8%cF-77@5t-CV+8Vo!SDSjIK=#F_2= zh%N6lH|8S+eOATxwu0Y>KnM?JNm!pIW$aL@q#_zi_~>^|fvaYbACqq2lA6{6DK+8z zQq@-k6&0P@f_KhwND#sFF`ef)8=;r(|xk?AmzXb7FkBp$I)FiCr$?7z36@g|nhnnfjn?LSGgGwgBY z%!I|=Vfjpv$$F~FAeP{mf~608V%-yj@K-j#M$OKdv0=8|*a0m@K8nZICrX6Ov-DGX zs1(R@O7e45-SQ5A5!7wr^q$lGMj{3jXz1y8oE90Z@nhRABdKyGPr_3n(e(e&6-fBE zUC*1AxTskVTUgIAp|_40l^R=5f}({N6=jbk2G#!W3%Rq>eyWRU?2dH%eaE?(w{@VmJnJCH^8JrC!y#DCFmYH^;AC?||=Che^ zd@m5XTo{mh^6pO8?nF^2aB}l0V0T#H?NICR)a9UdntYhKL2TNk^IXQ(mg4NJLeYU% zL!$kVP|{4uZ5I6K+@U)`r@<8$Xo^H^v56qug2U?jvk&@=Fd;OQy7g`Fo1d(#)D`W* zbc11qNaXYg*KR;d26~B7T_uUi%cz37br#zTbzv@%>5q!DtFx-KxaJ2rFC1RaQxhRC z82jKSROT$O6cHu$qV{4Nc0*NP^I52sEOm5~qNIGI6*~`x)T&ob3ddfa$^jSgmU&we*%z(fK@s0qVZs&1v>)62^Wk?WK>z1DnNv<4Kxz&rv z{qCY|AwUtL%q!)kgR$hmXqZZbSxIEapdD z0sv$F!bN7A*%l+364+7AMcKt9{b8m(i5}2LIFLc!r{ZFL$DKc)6~`jo2PLaMap0qy zkcWZLG*gN*c%x5?Lc`K=BR|M*1V%f#G6a#;6>-XsTipKQP+qeU8d52c)o~ViGRi!k z!o;n{N86vG#`$5@6O>fZ)`m3RjNQM#b`m^Q(@OA-OeEo5G*GOk=PG94^GUURP8=EV z@5=`&7E*@PxgSJe+2!l@g@nRHs1tJH46pTi#ooK=D`j@miv+o$<~Ga31q)+Lz*K!~ zA!xCz&=|R;c2X+Exr!qxb51*V;dA{88BQD+y#>t+J_zK>ChcFQiApq=q!vtJrY||( znJOv(2#0~POuNmsCQaXEJRs&*UG1K=E@<9yI|Ix>dM&7^>;I#$^Ja+@mfk@@<*Dus zGTkqQ4Z@NU7Yr9Sq4^s(mix!2E?pmq@Y?F7W3-}k_Q^%uQi%6XMde#w7lCD^lZ>$z zTBdc{WSQH|)1_o}ZS8g@X?^jM%O9jH6E(nJoFs-{|KuQlMi0n~sC9lXX|?X;>$9cc zVB-0#Kp1_xgj?|>)--h4d=t(IZmJ4@;h9h=EjO>ixWaIu{5()Z8Kw0lxB5wHZib#{ z{_!S{qiLJC)U}|v2I0JgNRLqX4u8X&)W~CU*!Q$)9Tp%ocIf}tcl1&iGtbO~IKIo50fv;2=K11q$0(Dv6iYRR zY}KB8>#0)mrc*CY5HslTGiH6i*Q^;F=jGbjk*n++Hfb*}S_USbK#X94cn!F@i581D~vY@AV(IiHaMiE&*t1U9y(3?GpB&w`HST zgHW1Kc=n+!62$?B4_kx26=DcoTEaDzNSI=u?pNSNpEVlKgrMtJ8I}Xis`Pm3WA|0=9I(k;x5%$2VV7cY-g!A!AEI_oiX$+0jwe=B zf$-S!W5IC17lZ6a#&yC0T({$)rIec5lKQ5cv0Q_iW0I<>8hYG>;`F4v>9k3!89|Z5 zk1b31sD9pSmh`SYx5ASNr9e5y(E~2&RFne$`@&u8Zwj5~SdyinB*ae)60MyH1V0O9 zkcnH)z-2HenJfr4tTqDx<3p`yy$Z`CpQt_<*rPsudJ4qkE38bOPB}ZF z+gy$}WW-O49P4e*5@)IB8T61J?`u}hq-DURB2mLuqX!P*(WVmcCq`(>hW#Z(^(}}u znEnJLNuXn?aXhgz-5g_CHkq`BJi_k`tL#1J1y7W!y`|Y%bCp(}*=*0v$d5}X#+(ye z%(XtMY-eCVgCgUe&J{|>vMchv%`1OSEhZUXi=>18{ zlYz(3bT_POHSsOT>u#56z|jF)@=pMHd%dbZA&bQxGaK*&)@#h(G(OMIB)s^b7}NJ> z;5NL#31IZB?LE4n;3q8JI2G;Pj!#juIXXCxZ=L}m#W&6lpddOQVK@SZTNemj&t~dd zuZEJl&1LbuY<_1)@0Sr+?nDwL)-zG|K6E`#>;Q}#lb>^Co7>TkoG4l>E48>B zaL1-gLC`;BZ0m*c$K-Q8)wArnYKNS${|@q*OS$>W(f7-wkmK3;p{CtC8R5&pX<$^+ zzdTAekLso!y@PCB#2)TOhM7)JhjEuR!ANfid zZ}0vRp=JZbaX-4n(x$Gs#Vgz9(y56|F^4?X>0MrID;*67zP~`M{z}G>y7;1Jb6owg zv|DqND|mg{8$@Kzqt_7fXO6^Xgvb=TedkvNgk9`k!OeZBHkmZL66w%oEpr@nf@y^~ zV^v(V^Lb%lSj~TQUsd)^d2LWVN!T&!1Wz9XwOF5fa1%0+BthfZa`wBA@4+IGev=i` z2!8@zTX|(C1{DHEY2~)`4l#u|&{__P|0P~qN{kJ?&@4}tT4iyRm$b9*|FL^nS=4`g zQ_#dv>0x_NJCQ6Roi{1?WPSVTM*~c6LDTfn78TMKprL)AcyZQ^Y(`HY7_Kw$oVtSS zox@nfO?VB7FTZd604{++e!M@CAy#+b>rLDf)?18R&x^|J)eX!cRGeEwW%Kb2S`+vi;1{@OUJ|x581C?z_8S?GFI{OK97vvX zKe&=<_FTY~)4Ef?z{^`4BqN$N$byep;Lu$t(L*(D3%`yd`A3?1U>j=8g4$SZ-)i}C z9`;E+eY|L>brzBeBx_oM7OZGYu+TzYCP5$D?vasbOo0v>w0J_?Jt^uN1&B8dK)C!E za{2eudz)bkp4wceQ`0DWZD7fwj&|$9Wvx8mLkJ?B~XhR>UT#rauV=&b0uY|id;Bkl?IpD;UTGKTk)9{&jw zp*aqpYK^RhgWefMa~bLdH3zEAlbetj3d&5eR|FWL?`{R6|M+@Jla+Q3t_NxyA0G~; z_9uI`^m5x3@^Ukow0YE|^)gX9m*Kl$BxAH8a;QKM&97DNA6XOB%iVDHeUa7k#6o_| zB!)xk`i3!h9843Nfw0Rg2&h0<*D-(N6WKI91JW$&@gbn>S85!iJqTdD-&A9n7hBy7&w%!IT?en^Q zrD&WdNgJ9;7&5k7#7u{8#5y%^0VV^gh2T({hE9XZb^no8AIeZYXO^a8Me?Vi<$;@? z=Z`RRz~YThv)Vk1lUnji+@xw~Y_2Z;ex_sRN%ujKphiGUU`5tc2HL160I?UyudVxK zeugvdd0%6p?ORE6N1`;Rbwz>1B5t{7a#h6!Gc|hZr`|n#sX{fganxnkR!d7EOt`R% zfZMzKnRb62GhDr~V*+NTT{N>V>YKs^V;wf9y;UyZ*h-Dxi8U4H@|mH*mnGR6Bkjed zlIA^>NT;2K*{pf3@Io%yZa&$t2O(4QL&0e(ZmW5=6gMzc0cu6 zK2G!-pDk!JFPn>#=aeaFN({;OsrS7Z5`$}ewLdttbbW?D(lOo@?2N6u0=;M206l5n#3_Wd|6bR zo2pQiWYU4gUO=9qld-_f&F;T#wS7gK=?tZO)U%8!RZaQ=Q0}-xjA2pqOWmM2P)ZQ; za1}J2+#oDWH)FRl*YEzR%$bx^pAdYZHDWz_(0uFdHMa`Rom#8|wbd8>_>pI1iZ~N1 zPe9=fGS zT>tHimVt;9DP}x}hUaC2Oz19^*wsNR&hzaS6V9_c%OIi#Ow`q=xEL+F0duziLq$k)xs~8rEUngpa)IRLg1t8FpQoO!q3%NA z(T#uA>u&pLYz`m8zy!=lAwq~+U1Lp_OtBek1@oU4r2JZ+@)>`#A`vv|nR zWKP!RH`m}V7k7)U=ED}Z0ORJq`2djY3gKOv+zrIlXlh>UsU9ZVbD`Y))R zuGMFYe!i7Nh>&KLGp>eX2bzTc>|I;ww%r>tJJwZy5{WLK7M*-2Yi{V}CqBJ1r&$u| z<)G*}fEBq4jX2bA1^L3KU!*xW3{Bej?ma_(UWUEL-uu|5IXkly(9CZ1KfYpF^!LRa z4YNIE%Ci~9k*Rzg+F?72&Tp%@+58|gyj)zAL2aO;(p(+)0ZVvnCo9+x5OV$zrbfsy1llDszSh* z0;tgmiHaqVPpzUd-*cF=dGp7t0glq5<|=pC*t47bI>*Tmf}0UfRwT3AEei9(i#KsK z(OsI`x>IV%e0QH?$+hiHS*mto&B2A-Cx(5ouUIev!8e_P1{zdbZ!f0y0XuiwZ@Qcc zlVwbx(^e$tCNPQEH00c}8rxJf_pLNj_ATkPFT!znKqum!4-ZNQ)tHKq23 z*NkeBs)k&mI1BK)46THKmY_0V^zQ!a!$*dA0=0$^>IJEW4%Lgd@}CKCRuN0~BusP! zmP6`?XL~KAZzE_uZ#jpC=GSc6JgOxnqX^voPuxLd<3}fYVR>h;RKNjd^x4lR-?-5v zHIU~^nR?zAq#wR0*WMYym4FFrnne)m<;P*A;|qZ_wPOfR?hsia_b4u5wQngh5W}` z3yonfndo3NA=DMUK$@_JK#{l$u*l5~1_#U_- zC-7wIb68Kutnt(zts(?xS01aao$OFuZaGD5^0ebRcoNc)XcRNrYpw5ABd0>Izy1wR z4`bUiTofY;fMvbDX@J zMg5sLFqv<`V@Nw=+?%WV#1-B4<`)0BoKp$+&%p6DDq(3=g55~d(u@t%X2&t(H=C#5%=qT60~pS}(z-C^1mOl|~R zSWFqeV>6oshC$Ws9C%m7Nk~FwoY5bn=mbeK^GzR!Pg#uR3Y; zKhNxKkDb*76Nb(h(j3+1Ifmt<6BIt*7y$$iSS&lWl5Cd5ROf%5cln7V=;iXcbficH z`xCTJC6+AQyZe2XOPcR@i}cHWXC zHtij6%faK-_^o#-j0yu2*7)S`eX>@IIZRVqNPJIyR`(-48hKZtOp@@(HBlj6{`fb< z3E-FHs#*q(O=+)j{No|#n0pcTpnR)<^%1B_^I630_%Z=+jsxu?`p;q{j^-n`=py$TIEV?MCmHnO}AaLQtzTxypso zL_@BtIZR&ogE-Z?oZ7l(bNAVhUAq-7ZaHNd>(k?8x2o$`wM! zPFj%)SVAmoaMDI2s{f@84@JFRu2pDSj`fA==Ddf9Z@$wcrWJghI^y+UZ@>On-}IYl znNC+n22>JfBxiW`aErm=MPtrnG&WKT$FdGE8yKRpib=C!+jXh-_90$&SO z?ss7c$ahw$nwJCb0fA zX4;eD(1-{4PbyJiXGV|#aT=cedAUz-2o{V&FM6oVZ)BQh{!&jy*@U!l4DdOOPOT^ zip<(BF{{NM>_18H`=42oy58#APZMFCAOYl-&DR@K+SOG2v!83c8`Knlb3BX${!wBw zkFCA@kblO2pkovR(}1XZpDmkE-zqA(s0*7>PwP?79NN!wiMR^d*gzdzC7rY?my%{T zI7w6*mFpafWA27zzGBKTh6xPQEvG64d`a=5?w|Sc>OJEUYboN-t9qGQ%9e!tg4*Et zFC%~V=dA*C)6UK*mzSZVD^ut672!WkVd@G$pfkfI98og11L<%<%8-=(d>`vm@XHYg z*F(8Ozh3j1@or})-v~?fKg`>@bL#_-sl|q^OGBvP)jefdQjM=m8(zqa#!7&D_Jf^f zp%K5Gl-0ruX1AkL9mFvZpNq|Map^XSsL^Yn__4l$1dJ_ZH`ZJI*KtR42Iqu|2&`uG?DvtE}k z$IZdC?ft2~rDG zy3?w$!{W*vrgZieO}G{ggbYsGM)2rl2AUH2CwSsEt&xlf_~x_#d2DX&T61|5Q1Iy4 zZ6=6G@<@P#+k=<+lEva+^`gqkybNaD=D;U+LKdSDAy;r%-e}{$c3Zu-4X&=hGoBKg zHC|=0SM=`R45BiNWmmTV{X`Sf2sdP^5if7ewnhF{p*{51xr<=1s!fG~ES`c+vT7C_ zKg&8jWIkbi9SJ_mb?p!uNi6V(h8+u5ovF>l*_jt-#n#K4v5SjohL*_>51cV)YC`cf zKn&3IrubdQwE}MoW#sPYc*p>k*i-Q(oY#BRz73|pDf|6~Bb)cS*OvPNw$uB6V<`4bA2x@7PDi@>LkW}FdR)O`{%x&^o_G2HkhaNE0WhqLFw5pE1PHmehjCi}n{Ao6-}oETa_HMfuqy|H<_Cx()@5RX` zySx3!-R|oB?!52QLk|hNsky+=`3;WqYBRAeoEencvYJBG0geiNC+{O11U@=m?*c~T z2PC+Zb9oU+mhg8^S-v$-L5t6tTIjo1eeheQ7)dINRw6$tUuDXi!7xG{LLGj_4t=so zd-#i32H}!OpYSMz+HWP2C*AvxFmSi!E#}7pC=C%>7)RfjF z;2))(jZ?%;m@LaRl;>$Xw9>*`OskmTx8}Yf56VR9@lUY}I{E+u{)~)yL1-`MlB_r?cnkd-wX%93Nkd>$UN*vR5i*es|s@Oq;vr#{VLH(um zXWG88xA}q#m)`w9yomEJ9KqD<1-_^`t#2QctfXf+d@^CO1J`n}x?_|!%+Md$2eY{-NT2$0uppH;tdCE36J*o~ zvt?$B>uA4rB1>h~Yul>pYj@)c|3%)2zO1Tvw!B^)|7f^g^_bp!A~GqmuQcX=+YdQ( zTURdoIZV@ClQB`F%X#MP+y3269?@!`FHZU!pNw!Ko*J?c;j$_xbpKFN}!i!0yE`{AZx~$IXjiIcw{^Mx*x+L?$Dy_3enEL25$zAA7nr z#ubsoy5F8mK}!p|dcq#R>ox-JnJhxR*k_MPRDbP=iadpO+XXp_ax(I>0^+qj&_cwn z#hZ4o)WO zE#zSt7f7>q=|DV&=J_e1n|(vq+2`YNqqZAfivtf)Ia(c9Ni%^&!d7G677U{lN!J7m zJD}uC!AwEYoB?rCd^04Sb#2tsZ2l{*TL%%F#XVUYEV~0VO>Ox*zIoqF0t`FwArz8f zSE5wqW$o=bO_sx)!T)>-Squj7_}&v_ZiwRcb6);f&%u*|k;3w0ckXRKptF{KeVEiq zmWj>>Rn9tQ;x2 zfA2h;K__5Y?D`i2NwGHuYH7GVTr*Gvly*Ds$)4}_5^qK38;D^(%P6}isVol3QZ2=N z{9W8gB9J}Nmti^Fe7e`>$2!CPL-H?LkP;!LA)9Q;<;Hi6HACcNcx*a75LPTKqH6Xp z_usi&o~dTBY3Rb6v;H$*lOMJh*|Eo#2lm#t;;l%IP0+7@M>|s>7%-(Wgi@ zUI{0$1^z|rdfLxB*__45_!@E6Lv=8#2(7IAUb(9E@T&h2&-43W9H*|P7)Ccy#H1*%G#y(cC1iP6Y^qbf%WGs7m&+1tc z*>d62uEbdnB>1!q671;Q^SBPF2uj$gxcEgAaMrSH@RCTv8MC~mII3I;# zFL+6yzWaVn0?3n1yKYD9WFWk9o;m38{O96s$Ch|BQ|@9QEAlprH4-6gKGGPo zYio+$q0d^Io125vdsc`@!{~*dYmro275+m#S5FUhjgBD|Zfe?(TLJYV2?s-762o zc0OG9NptMyMU4J>L}DkRWo&HRCAQa-McO$2Vi?TES7k_gV&e4J9T2qH>EsU$BE+~+ z>})Wja%^4n8x|Fbq0r^el9KTKOu|U)0X&qFSw_iNiQl5fCnrsScxh@0>Xf3I8VB|1 z`v+=*7)_-Hom^XE$C!jhCwo^{Q%6Hu+~ST7|GKlk9gZsKbeo8p(=jWI$M<&uh>qSx z7RE^ENU|26eX;(S6K0U$85_Z$##Gzt1SVucnhM`fZG|6u;Rn4UE?yQEmadocT&T&D zh$tGZ^}DXem_G&G&c15zD>6YLZrXiO=L-R3DkCGK>)-WwAT+U4%nBuAan zvc4IM*$JQTSaox8n{%`9B|zk^y?uPa?7cf@q}cfBudVvQYTH-*N1nMI9UZ~*n_2xhk)Vh9Aik^>v=UU(@93C5h(`Yg%!7<${hdWTwyqgk?4jLCMP<@>q>L!jSJ zl#aJ>{23p0VQG=DEvvM(KGTj{k(Q(nHE!LSYZVq_Ztqp^W~b2(c(VeOBVP3a@+BJKB`op z#g%cVPgAG|KRbXQx-T#uzmAXMmwZ>3BGXEQYD@DxJhjEA^LlE9Zsq7qk$Y54S}Yde z^`=OfY?3)dDYz=AEp(5clC~+=VXNzC+F~$%x^9sH@*uGDjmyr?jxueusAcF&+;1$Q zZ_rfT$v+c?Ew6MvZu*uA)f<2KRj>9uuSXfUY|lR-?B{A#<4=pe0OZKL7K(z5e^S7w z+DtkdXo=W!LQT`{gPU|A+hyh4aL*$54d1l`Fi-9<1qBpDTzM@VJ*YC#HV6YnsdU#z zM2X%*EB5)B3jzu<*uahVo^p!9(|aw;CWa?g@kDGDO-*ifXRbYWyska3-aXFanJlF2 zTC*!nmSKI6Tt^pEt#FLn_1xivRa>#w6t)62wF>ThJWb8RBl`7Kp!NEQJU%l?ar!Uq zzk4@~)4!p;EweKt1*WeyJ-jtqxC{jgPY{RLH5yVhB#wyusY zZ8TY)_SyO3`Yp>%zzKL5BY@bky>_0`pO4wuojWLz!9 z3DjpyJ9QIS`#ecfm|6$2X|8GF>b80>$X0-BE;!TNgJy}bNMLnCnKq}F4L_eVL)<*;Y%} zx*Ka{@#Au6RXv1-fuiA^N$rB#If{&k_>?%XtOY>PV;mBMdrxRXZd-ozE$KobFGm>>@&L_x6NO?S-DQhr4|eR?V^r|(`$X=vBzd65PsKq# z+)5G0NqJ#G9?%UWW*;SLffFmN^jqHaJm?aT6im?^B^fWV9kp;LJ43gU?>?%?^Szyw zi_h&>Xa2?Ab!DyBY=zQ!0TxP;M_Beska88Cy51X)Kl#>&IFXB}%~+i=KD)4?Y%7QK z6j&htS!Qy<9x7?7*8NGloDzMF$<8h&w(Gr8#<;Z|#+zaK;)n`%3e*Z{|I>%g!+4KT z!m6AKH>{n1_nJZ7uY@orjltrWXj&{89~$l2llRd1+$pDslF z{=D3?zn|wypEpSzNl6%AH&0snU48E?dRY@%s0ThHrqvS67!+TBfMVLEJA5AgKglrZ zATr>Z|J;u+-z*A){){rT$=AZNBNkaq5O;;sFE5TJ+UAenR~A-FjnV?*t%BgOg0DQb z@5CUz;*O041>{@)Cv|rYGdIFpWCI^;7^lzg6z~XB*HTjhMU%*qJm$Kwjys%NY1U^G z_oPR>Kw2C&+|wprHc|2#Ve%Ed`oA*B6j|;2NilJnUH9p1bOTM7)mub?F zb_wCz@@cJ+gB%VI*R{2^=S>PK1Sk_Qllec6&M~->u8YDGb7I@(#MZ>N?TKyM_QaYP z6Wf{Cwr%_C_xo|HQkAOY_PKq|-Fxk4tyT!~&|O-cn-h`OWIA~*SodwD0Ce-E`^nr+ z|IK1){~qG$r-UZi5|`~{)6qL%@`A)M@N0WJ2-SNRHDkKiQkXINB4`~t`gS5lgPtgS zX@0joWOtQROH)+CTUG-jArg9ag^BS`JlPq~cbJ&8AVVUTH^H~JE`6Cz$9Wfylbg3I z(W!a_ph0zi!2@QcX_K?Bvr%VW)K-s0&t(sdX%Uo@6!;Ol?GWegco$yh4l^%pK8rjZ zw{KPhDeCXxCP*apW`1E(#7ya~$cs~N1szoZkIhER1(HJ5bl#%S0cm49P%IIRlP8`7 z<);$tjnnLJF=DH0S%c(h zwREsI1KfNF?3B>NT-T-D^K03Tp!j#Gn2oMCp<*a0>`}QsDVqmZ0>#gfX?jS~`dNim z6tY5MWGf37CaY?Nunb{3E^MunuA8k7UlNFQzUye2oeu%IW6J;hdHkP-3_hbrpX&CI zl?uhdetPUi9LDv%hY2|cTx!5pq=~F$=6my4pKsR0Ohjm)Sje!mq106Lk0I*8NkJ?A zj;di~pLcK;fm#CkV4qoP&mnb&A}1kQUt;^CsiT^(-gw+Rj&z{Vj2|OAbA@n=Y_XDHQHz+v(9+WBBK6H62`~t9geT zstOr{3>VbgU$r8W5?pzLD%=c0Gmbb)-Im6hL9XRnqo7D-2_#R-&$pl>?jxYOKEY2Eps_&}f6B>-(`bO;kv86I+U6`E!+V?~#H z_YJ^&L=WKN3PX#<;IQRcf{n9y$@8Kk)xhQ0kpWa zG+F>#xLJk#t%1kESDUH(#@`z2>dO13!*^3lT4!t}m&cZ%0?~>-)^c@s6fR65@*`uv zx%&;@>%5C6DvjyiySnH#^p)f()FSjc-=t+!D`=|9Dm=DZ@l$^J>maFA)jxj zdIqWA{Fp~#FHl(OH4gXJJn1)k3@ej>S(%L$-A8wI2iEs^96m@M_J*qdv7@q$UsjoJ zn>O3=FvJonT%3(AuY`nrN_1Kc;a5IS?e736f4NE@?J60~Bb`MsDRq4;UyvLI~6V254yQu}%Nv+(<&Pkp`*7uXHL z#y)?o#=pO*@d=T12tq^SQeyGLTn-y*$ixs&tt zJV?2Je4J&nvEp|f?sj%mf@`e?6_&cy&2BkyrpCSX>~qsij0s3+*^cFYSx`U8V2|+w z^=(wRP^Dn?X<`12jvux;+$2!JceUEt{Q<`Bc;e=K6i9y6DO4pSRZbag*eil~cXc<3 z!WSibdKtyg7GC4`aTdJ^9dtC9rnd`}jMUBRj&R=+PJcaqeWH7oI!*SE9ZEJ=gR?H6 zQr4blm_MI^ut4+>G6-Zjb@}2nxFp%0i51ork?~imL17}aYk-R3n7ixrWYw)9b@$GZ14{0_#CxY+KyNwD?3ICHM>3V-)d{zG&V$%rMW{iRHQ6&JUA1uyzCXu50)4Sex&f zB|u}S-rGpeb8MHqJ5gmfElt@N1Es-^^HqBT6!wsqoCgS#}M zGnIoBs0t#qc?nwnk?~iBDKeC#rA=$z^W>j+Q01g|i62R@JQK52mL_5LV5QysJ@p9X zlMuU!C;t8^?>YQ;vdC%EXgX3u7Qb#47K8s;Kj&>F#b}50Fq6!FUXs#jnR&Q&n82&H zf4=PJl1S_H;=*7@tCPpgAm+;w-xFFfXV18uZ(3*94L*G^Q*1vjRW@W`lGN_QVav;O zT=?bS-R2>p-*T+eBnXl@=$pS@Qk!|_mD*CwM6kxUcSz2!=kATFQUb4gF-3tOm|+0? zF%&f(Z#{F>`|P(HNF06(lWmw^m}ULN3e4UBMI=$CIQ%CGXRHW+ZfV8eq=T{4I{QuD z)G9q)?&g9E_t5sZ@irPw-^WY z+cK@W!)zpt_A>v`=?>c2dHS~#xF@v2sn2FjCWecr^HxU`K5k`9<@hqmuE~63{TxO3 zeD*_hdh^TKyKqaJS;?lraKhp!&*|r{C(ZrZ&yQa7pt+n~*uR6IBMz6x?pJmb=;bf< zCPbFL&N$Ma1B&yxW=AgQ_*B`!KQn;20-zlB0X&8XgKxbX^6}VVFP+zcM0R3gUR9lx z9&XBIA9ronsYvZ^sr@K%? z=qjvj>tR$wboo({liv`o6J*-qLW-oX$4JKSmlfk1b=P>=BYV84W3|2KPtoq}#I4Q# zXOY!3!mrXh zK3-*!49xKBR<(4Qf5EHnQO4u2;ZPC~ZzWM|EohD-sUq-Th;So+Hf-q?m|xNnyv}iO z9v21u>u5)#Ca0B~Ep3_Rq!T(d)!o2AUdx>oWa( z53O$&Cpc=4Bq<5g+YYYc-ra&D8e1(OZ1RY;R7J>CnGIC$V?dDI9q+4j427p`vVi6V zy>Qq-lmwbuFmK8GQ!7X<4;ErwCR9h9Cqqc-M_9P(QtfvB+erpb6qswyKcV5vpk`-u zi*Q+-L21M;4Kj3pGn)mMR^ys8+V%QT(x4Zb&|&YrVGd>OlwB5m9tbu!n3HO~#JM4d z5ls2rWJbOItmdzv5pjaQYPElmKo44r)Efhl0Q2SGG4s4!Mx?WKppu!@24V6|@HHGn zhmLFK<;;-3+9)`0EDo6eZyq-t)$2J5OM5;*W3E89kc%M@X=uw7HtWadfPZ%WcIhp5q}}d$t{0W#AKXSh164^u8;asZA)(h)y)(>qhj?g4AQMYr2m^C zsu2uUY&lEEQeqJ|<&DLzbD^*M3rxIHH_?Lgzc+_1L5_=4x5s?zlSMV|KrQ{r?)Ek?!|$d z+!yMa^{s)H6I4S}s2G=&XxUXdz1nbVaj2Qh)6W6ehO)X-oOKx-#x!ja@@kIuJlhLP zlo(kG{LsZc&Rb~CR}0SP?Epk*_?;{iK?a8-BV{YmwtlQ`gXl}7xXKMQQPvP@fBJf%Ub|u%VQ0q!&!H(_EGDSciDOU#?MK9OFubIXSaeIsxOasHrtK|WCHiwn8sLL zmI*Dc5xv~@WFLeN*fgI5lOLL*zmBT@VDUk6Y=kAr06V=mq950Lar6}x@P z*Mx_i?%9zH*q{Xq3j@E-MxI>QB2s5!W5qWT;AggU!c`ID|5S;1kZ$jF2LMcDNmY`F z-yKvc?*|z$bk;Vy*^8^6TRA!f^eSRCWs&M7aY)2R%B8P%_j3iMa=g7vtCwC5xxHvq ze$W`U>CNw5hsjfT6V#=I<6uvG!^HyKDdu_?xl(pSHiqt(&6t(@18}|}-H^sQLNw>fjW7)izLGA=wyR;C0jf{k+YI6pElF z8kz2{trf|tp_XI-H}gIZr+flNIR5IgWcJxKdC*={d;o)PvlTcrxTbkp=7QucA522}Kh2j*4nmCuip&VY@Rjl{vK3Mz$ z^&__{1*08HP836Hz}crQEKVl8@yvvcjaZ`F8jJMz!r0Kk{5vyo3tZ2_Q6DXfZR}a6pXn-jgn{>V2O@2qKGb zDEetB>mv)X)wNydeUnux(gCTJcD)6qx2?{xIouhSd^@|u(Ug~5P*{Zq&uMD?U7+iE z4$rG+a#>pzlI*MFhN0^=sGG^W*U@@S_ytyV)AyNfv`I9P`myso?A*0ilzVMPw5!6Z z&^-r{vV17!KbV>9^X^4H$i^pk&8{2V%AhqrjB2yCu9i!p(yXTpiv^MrnRu@Cg_d!Y zxM+Y8?KrbL+LDkj6`A8s16rd6 zA;l*ZQPjU?{cw!%K!7EZ6=12Q{Pwl{5m@7}J>`4*xmOD$gW*%_sfQ&=%<`YvVDR1Q zr}93u9H&1Z*s@DE6=njm6+W%Q zZMjI+!2fpO2%@XMf&jJ;Kn9|Gxef4xwq&X6aph#Y;@EF}p;NAl`Mc}(3rgAua6-d( zwZXeZ@Hs3ZjfHTZpZv|mpA%B5jD5_+heD{ken$7I_bj!pX_C~ zE_1+$=uaf}-d!wkZ!Sb({N0Wy)wn*SK~Zitjw_5cso0aBe$cj)uC74kNE(Wz1h5;F zvhclV_Aq-03+Kq)Sq3D~1s3YPfh4SBNKC zT9K^&i)y*X89{c_CiNSnoHUTShj&pjvEKUd!1H7-l2GS{^s@8payuDFPl!UN5pk@t zusN`6-jsR>mc)A+-a4_yXgiu3W<(2LT1vA1Gdyt^QzTU!uBKkdo9ivlu@d{Z+i)Eu zN5iD)kR6Fux619vhaY1!f+C*PW<&B2CUAGwK$s*_MDmQC(jLg z!cUdKTx^pXNB{>x<5g$#ynYr>gA(*uiKlRx<32W9((})|AR#Ub#Zv^qO>6DK2By#(8!I z2SxG`vLs;VivyLIS>h<*4{}xL&l3PkB zMS^1J+mwdDA-?K#@}SFoYstUKi|$`h^F?Yf+Ek)M{maTKYW5W1IZS45(W6(~{4#;q z5f?S;mv|A{5N9b(hrzq4=4cZ?u|-fS(eF$&Gt&~4C9nO`bii8^0m#0Dbw7dBgp$S@ zI3DO>#G_EqKy&0yPx>JmE^jHVz7%cK746{C!M?d zbiS-E>P`_b|xBXmNyn)8eX2^4Ez0M`hjDAU-4B$n`Y|R0jr&S ztJ|VgIIY9$f%q+!cNd!9wGRdj7J{3HhwFV>WW{Sq_hqU3OZLNc?eBsEA?e2IV}NRs zo%Mhfm{F@Oi%1di);t!5k>nb|HKUdy7Yvb>s*RGUoxx5ZB2K+cMY`v+qV?N}(|uHV zTwJFsVVnz{u``{hGFb#EIjDx_KYPIiD%;;pgoSl2cdUu3 z4{U8j3tLWP^~2wTzEy!<{19Yk$uu1kuVm~OwrsDYu!pWj19AWN>-mUIL?=vaAt}8tRkFzcBCI_5c>&(%`5gen~pb8&v)?A!!T%N&S={Kvo z!6kb>j34}-v74Dz~ zU-anHVPj{1954%vBPk`GosubT^~@}76%lst8{Lc&`u1O`PaMWom^tMs_eW@t4c348*peu+Wa`IKak7@7$bK8Q*`dWIu5spP* zM6cKJ$j%ai4S{6!XNjs16?g=&1;%fB&nNPn_e<($_yR5nJ{4*>P>bRyr_Za5 z7(!hRfq2Q6c2NvE)r{}d=WYyvVP;bT%(Ca{%rM}EKjF5Mgg)>bIE&)Cw9@RtD=&Y2 z%SS*N>11Z+9mr%a{CdKV9o%fKEB^MArUJQN$KWTRb44JP{n1%lZY#;uPX84pVb~Si zopND5l!>h|TR)A}-MCA_%PYRJO*_Pe^x3R7L%pYo96N6>2nQmSvYHA?T(#S!Mcr@( zwRuba(Ahv2k8(lVa4Wq;6)h_ZP7N`QLua`PF99|rR@C^zi9T47bD3R9D*38V2XpK0 zzg$34o2c<}RXX=$kj;(@HH>vHuN%-|HSrbmL+4l?BMNbX7C7|6Q(5ukU4Lc7ojmNd zzr!$T*SU<|deNIsp|L6|0Fpa^n6EXml}FnQOZ4}2!o zTJv9Tu59iC0_Cton^+VH*c4`#F$bK$hmG%Kwu62PL_iALs4zW8fl=XUml!3M7doJ zCJuzk4O1-M*WY)HW}sD5QlNq0DA(HK5KISk97|=wBs{nKP4C;_j}_BbG8&w6Vw8X! z){_+c=PqHex<>+60atxFo?mXZpG$BZpJ})mHw$(O$7i}V%{8}tzkeo5h@!V%K9)CH zYxD>25HF@J&X1Z;0kEdG2e=ifQ3~5SEv&wK_|tE(eH>iCtoE* zoqT1eiIK{r2-#$_18${{EJqtM&}3smBoXzI5IAEzY+w~vdgYZqSg{t>ytDjK;eK4X%$Ra&6zU#zeqk*QAc~q> zqbW#`wAgf#()0$AxRyS`bcM4CU4m)M9c>vyNn<(oJBVg*{7(`Dxd9>YAU)a`G1B6M z1|X|x4I|YQHR5M{YR{98AR2jtah8+D&KSx4`?n&d0mv97itky;1@7|T3W^H@r;DOG z_B&}&McfB}I$N%l!&sbh@relnE?G|&bls(D*iPufkU$q!#`xQ1jVopr#x>jG52ggh z0PC5afDtv}ScH+d(KqTIL-2aG82CCF6LSdvBpC!Jywo1X{SgzV^QHwA+xNAt$nNyP zmKi}yGRIC-e2n8}Ghwj&pUM^&K?q4l!D|o{TUnaz$*D?eD;_^j+?^dpl$-8Bz-fq* zh_KS7g4B9nf{E#icb$Es4A2jG;@gE?y^O@Gu=U4QoEUTFLOeJ~zT9 zAOb-c%Ms~Kv`U~7NKxdqm6tnevA_jI+w@9QYT#9ctt{pV4ee18WsmK0BCN<5h|jH# zI|M`E{C$Kyy~K$i4|3`3b8~9Z0~to4*vg%KmVyWxM(V96ld=TDt*Jwd*CeDf?=WCU z7O9L3ngRsRKCRk~Fhew=Obm^t;)#TzMc{Siyllv{yY~!;b}=4sKboi;*p-N zNJ=zvStYo{jt8}4>*a%mYoU=zxam7+1<1LD4y_+27#^E2P9jmy9h$IvO9-Api5_H* zrnHb{!qrQ6$9kBBol;{$!73&zRXvSGRBF zJglWQ=MA`EbY>=a)AuvxnDullY$I%koaDYCIr5m$U!24yOzz~u;L7-xS)v_}b%6al z61VW4Y1M+TK@lVPmsJ&8j#}?G4CwqIr=X*k68ho3Sy+8b6{3Yc;|Uphx<=zqH{BXe zU0O*FMG8qNm&pD`1<6^<_g$?aBrPdpUi4Wi)&>B0Lhkd%tpAz5>x|MFde6!ThH6k& zSBZ6ZNdFCu#Vtb$5Q~#378Lpe-7t)J*c(-;-@jjf%xR&@58kY{dF-z=Pl&1hAgsJ| z7nFco!wSWxkO{vDBJ5$Jq{BxxifHp>RnbgzAObnn<2(0x9$qQ_@OO$ZUfK z=VpSiuz@C>c^(y^McjHyNF7K-+2YXRqApgHwu(8z7PeqT2^7-Q+u(l}uxO3`Cxu<6 z_^KmEhwshK2wCW=9w0l4M;cX{2%WDWWvsbMg*+oh12Y|7P_CVdK;z4htA;_D#QpnH zYdphPs!>f{UC4&~j=(Df*VZS!qPjiK7;!U1`}3Nu-bWKQ#6|uo1WBa?g{dULim^SQ zd5dkt;M|?pUGJ<;P6>y&7@91V2qz{knDp`*=`7P~RI4F^PA`!14{B-%x0(+DyJ(a3WI5(! zRcIJ_GXWqx$&t_^ywoUx{<7{_G0TS6H1eU%-)n7yv77TeD<2;?Oq4t|cOSubkV1hM}B zx`5_QtyjE=Z6f-j?CI2yhI7LuSOB5xiaYWY!)H**TgR`-RB{jSAki3d6201Xog%H` zD9{u#SmFTH{4)=FM(m`pF(EUR*%1z{?}v@$7LuM-K#jNe)Ns9b!3=a4Vue~&K~+_i z1P(Fynl;OJYLo7VVx&OT3bntUF4}eCYc}T94(A%mqBkRHhqwFEVxHN)>zm#u=Y`8I zaW=4pTl|wWoc3nW6fMADoxyn z?vA?IXQo65r8S~F*+_&&MkyUJAfiR}%$MS*%mViYSeJA^0&LqAsPcRm6=dAgEt3=* z!_LBMd>&}m?LWSMAauVR3OHR0LD@fTJUJeu2$+7phgh-Q3^MBda~LKVbe(G!p=IGD zO^uoD!oCkctEJ0&d_=fxzw)*1{*cS@Jk1bl@0{Xi0cI;;TsH%8v(wWWj-sR*$tvfd0j81|bQ%LehJo>M*A?J-tIe6~dzRz8|H^VvGMM?ZK zcDZrB_ZpwKX}HVgxfv2;!Y}+hGPwQ5&3a-Fu1R@+({#KzZZyy;0EwGuX>lS-zO(nP zX8how8H35|(g=f2OVUd|hs{)Rp#?@p;Ev9f>k70AxrAsFa{m>{kWGB*4GFcPqT*gt zkoe52TK(UDMRk+ztu`)(1!*=NTyy+z6Fm}2aZHyyY@sklj~}6@1-%l8_YqH8;z$wK z)A|(R>L{b12V3|1!9r*-6d{sK>H6;KL_JxCb!vs*2V8aA>C)N49c8X}=zlM1#t-5& z-A8=>hcbDTn4ut-v;-rRze6l@3V{vtpi{K=Zx?RBQv2FuU%c;Ps)MQDoyhwoN;7;Q zY`q`4Ev+6+@?7l@ezj!F`>FGu0O=ujPu4=fl1L+!%^K*}NGu?UzH#apyJSS6oS_l1YF-2^8O}wcy-=N|#8*Vq^mA z*GLt0b@#0WiCtWO9nl;&B>(p4=-&+u1+br5TWSY91qzamztRw;^$1hv9_9GB_9&0% zcKc6xAHF~Qx@?%6+~089Dkvx*vDl)L|Ihm*N$F#xP*%sF<2Z7_d)?#ra=_>}8Vs{H zqh+9(Kg2K%|0;%_J7M~Kn$|s;?Y7Y_WV7zRC(HY|j0z@Y;!m059yuq+lR~KLfPFQd zInn*)BVg|R()@n@^i+FApgos1d9!Fx#0Y(X(dOI=SuiOZSD__fTDOsIWVSG zVD^NM2o)>Z0}AS5yO;X!Q&Q-wn^MEk5>gdA#Hq>@F*a}o%-@E8nE5MyEg^C9BQZnBC6{ zJ)bO`6tbEBKHsj!zH_{r{`&y7U5~uqD_k`IYz7DFhZ++0bUuA1;o9N7;$5NV19Vwc z_US22*VhN3!)O`hZO_M->ELLJin=PZ;k$;f_lRkO{8kWYO{dn;D_Y2a%hB@Aj_7H=>?^|z23GJtapz@M;hXp=~-L^jD zeeVl%;c0dNOt{lARvA{%@CJnB~i%kr)t|8?GC_*jyNnELe&8$ zZOiAYfWt|#0N(BDn9p~SBq4EeN?~V1%|v(JfU8y$Cy#R+D-AyD$i&*xVe{q<;LnFv z*#MSza+Kn!GPLwlQmq#~0mPu4Kz1B>Z+&dHECX(h+hP{T4HsGEN$xXvhK^(Jy~cdv z3>n9N#8BYfk%ErTqKu)KdhVA00(^lbHdt?Q2!K;IKwHEqHUxftzlrPK`RxR`_!F21 zb-r&yZ538GO%`zrUhhVUh>w20E(w_T1;eDr)6L~v(eB=W=YHN}^FD9)i%3<^Ucfxa zXWIX&N*;3HH73N1P60 zuMdvDIY+r~I)#QTRS{H@`rm=BV=Em^?p9tA(-r36HA*oWK0cw6es5tcjZf_0_5{aH z3}51dtPbb&hP|NlckLzP~#(V?Gx>hVo97%U!v2H zeM8`W76i6A=dp+Lwy-IGMS%%IJAE;kJ{913l4V~A%^kyjn=R2Zwvrk#MuEwtRwl~Y zXt1dc091wMu7$5RmZ_fa9M7Xb6eVsboM9I%HGNeIN|Oz){I#GGRWT@=iv=PM)-a(T z{D?|9Suto}=%r?Ul8q$FDKVS3)8Y7zmHvmT0OEs!Pr;ZT|5#cL-aQ^@q)-eT1|8ll zzf2}8b3A~A{}uZ&cRTyg|M});qF@B0c-dr%mtv}Elz02ge1H);LOFk5`lmi_nD6w| zARfmi3zig>-H2gU1|o$3IPHe^nH%VMm%jH+sye)4qVzt6rIv`zmA*=++UG%T9-&!Dvk9as&x814?-l#B`GN)w!sdzC9i`aEXz0_UFx-|BbLxG_Gw??MoU)d-SU zUQK2B!YFTDBxynG*JfG-*hU$YHOaS983~~^4$@ zlc|a*IMqPkYIwfvp z^1MIxg?m)0RI-ex!H2|+^qp&!4XCJobx@Qggj7NZv1S(`G3+=P8$ptbU^~R=E7%3g z=e@2c>y$>}B`dyIK0FO`8YM)rnx`6!;XoLRs=^}`p>qmKQKNx&!7{5dh%c6~lx0ur zum9QoPeVbJ^j3jtP%7h_J93rH-<2y*9mGDnTlKq00Ac*u^r=```%RE%|zwK}P8 zbaWCD9pmGI8Un0HXll;0N){;Hn)&&8`*jC#^fd-V3h>B`Bq@_r_pUhT&Id5luq!VK z%aj)iglJ5C?-0PB#hng|8yE~5rLb%IH-i|RA?FBc^UjP$cNRAENJWCUj-IrxK>(gz z+OQW4L@(XF&?yUbw)=qb_H{6dGoYX3VtA-|D75a zT5JD!dV) zPjLJBek7)3)XFWP&-?rIkp%Qd@WPQO6C2%%%OW2-t3cRQiTx^uR3=leX}X^-i_NN9 zP>$8-JH)!rlK|IcOCikGX+D$~QBS1L>x5nP!Q0lCV)q`Q*Evj%&tsh+wc%_nzT-b3 z-k1G3u%@SZ+~=*MFwd5B=L=-+#DxS?^NV`IwNYJ+nW~U&}*)4fCCFT9C3v1(AAN(H=l~ z;6SEn(P$8+mc$i2W1^5))V753Zu@_vvszwuvm8z&Ao$U<2 z>n?jN)S3A<$k{rzp|;56X~PYOXz6r6%m@Hu2a)4CfS=?)lR@A+H%J^oW9+;~Nz-wB ziC5PPgkDr7AN=`;SY?z(ulwq){|%OrPzfL3eo*wf4&>CpPBF5{UajuuT7Qg zl{2EL)7^4LG9l6E(gqZC#sJ>xUt`GXH`h4VY--~Gu{r!vYVV>pe z;v0{{EI0u6&P689-0{}hfM|EVs+-X%kOyTDx14&}{n z?nwBBp`5OKY+MshUv5OE=XO! zgY`?+f=uxbG5rA{tCiYa`yMBdaPF!bpS1}*uR~0jgz;#FMzuzx>_wNiLsXf3E;_(O9t7>-Zy;8i2$xjC{ns58=Gf?@4!DWVj`xi@>&_d>`+jMv%7LoA!~8L z)II1-RZIuBrEEH+u9&(ZZa*CWw5rjtqU#++}1sIG7RgZZT^ zp#`f9Yc(FUSU^}x30;!^Wkh+!vJ~F~unWK3A$dRmeRCA&<8tA$ZTm&&eFv4ir0mM- z1u7UE4Xm@R@4W%Db`3p~bT z*bO72#qsNBLK3+@nQJzFyJ$Z6?KEk2GHYt`H={sWnGVDF{n??nv1aW^AtmHr@*`Ze zKGcU|SHvvBDIwxg5<~UVA{{Q53_u>d(3@Ve2~RI3lz8o&o2rrVlivpsxE?Fjfl$-O zvkRw{cW1314)GZ87D*=0pOydN=7`vsV1uEb382^5g(4=!iP^GaO7U_I?(*s7P|gzS zU#=^jtsim{DQqCow~rb?D@IKZs8Ut9q-34e5-mgjb=dbSdZ>E}f%jh8K_lO-jrk!Y zg<{E+3~}KC4UBAzPh*kaGyM;`4A+A66lTi-1{PcS{aV+X?d@S$h(?8;$FIBzKDl}t zEXw}BJ!rt_w(zGCvI$urrb-hczSei_z7U0D9J0u{$O7dmqJLpjLcfYI3^H(f^Y2No zN1=I9VD5gw@9W^cqDxUn@p~Q23QNb(SLQhyGiiiSZxlBEU~6YBM){MUo1;o&q`=+H zWtV=&$2-(ygESE%59WZSFeHZ(qN)ckSKV1rAB`x2tyVYymHhBQ^Adf zFa$$p*7dB2QsY#{AkfDEZiddttclsWU+CXcQdn8pb~zGVQaR_{$()bH~6melr#-_p5~DqA zv?=9S_nRu}>NRh~;IPQT!~3!f5`jRwoC<{54H5*ck#eAmC2Hh2cFO)@jaE^qy)DMf za0#RK<}f}Q;kCab9}s7&^;{&$@o`~b?D}MII%F(tW)$JS+I``l`6J?D8mZIO11Oq1 z`cIn^-ItwN`no=Bn7Z}>8B*vvk-s=VEzzsX`p&4kibNF6P?my6qwJGIVBy!Am>R2{ z9pW*;9v2n-Fwu#Ha1S2qr?+M**PIn37|x@hjlN!`>lCh&lVqhv1|m??c|B0>4kz@_ zJ}+C3F3;RSair6)k$ExiuPfF@j>H8Icz3xNSw-a8H+lLTGb;4FAOs{9pHEN`iTJBC z&I?lJ7u4*>nZ2BVmq;(z)G#8edMY_`lOoo=^Mm7Bcg`WpR$DBijyv&OsofOF+l74T z8q9NERs9JeIdc(uUSL@sq3dYND^-n!`fM9nXJ2D1#h-U931aLP(2&uHL}6kJjb!h; zXjJ?7G`*Ap_kNAPiGlel!<-GnDX3lEk#2xey*+>`f92)}gik2rIZuO}c_ZK$j;6jE zhELmb3Lh(Il0>rm8xH*oJ=`nwk&o&L{;Md@K=#bvDJ_ekM{TfHW}L?%-TOd=reo6N zpDcz-0V+}M3Icvh)pdiov6G65zH-2_ua3XM?BqA#Bmxv-(i zC%Y=OZHq^Be!U%<#o3duqa_#Mq|$VRRnyh#dMBQH#0rr6Ub3nA;p7rz1xU%qGwD#< zZ^hQt#KL7Mi9MNX4*HOf1YaO_k}e z|Ex%x%0wyU|4}STN0@A)mOC4#o0566roC>WQfp-NP!%*-Mkj}XN)IkVT0ZNkyVxCX zR1q!f#`NU^X$UPJl>%}Qe;PPbQS-gh(+%Q$iAVfIYcj>H%47JlK_$^(1PzCc__U&3 zS1W1ZwwiM@{JEsj`{g*-!h2RAghg~^xytR?qvdANf)@CbPRdIg(e2|x$oDymN3(Vw zt;zKc`#ms2OHEa!T3tZOrD+r@hwwdR0wU<6u?G5sR%Y4&tH91$y9?;@=0XjcF8uw4 zd-o-u__}{G*?9pg(0{D)kd0$CZd73UsuFk_A9w3vRUxuxjY{RdR|`q^4&jj&#K!K zPSynE&E+|&XVG{=j#{_TYmuOi`lx++@Xt14MB#WfhS5i4I$3f-5_BR zSM@w0b;ZS_=wr(o_dY1(a|ykel~dD^9;5(YVlO%fTWoA8^pr+FV5LR$w(h7}7<_w0 z!GQFj5?!}pNwqv)x4gA3WJBw4ZV3ufk`ZEWadvz$wGD(JW+&8+!3HI&8bKh?ftXZL zg?I`qu6t;S-w(9zj|exP{cVo@>-K@dCx{Cb=|9foh&d0Oelj|)R z?ByBROG!C}q;-QVr}M)w$HV7;0WAd5`yo}+cj1_8uhUwqqP9^=GpNmAlEGk>!F0rE z>7*3FJ`rRjYOdV#N`qRMK`kt+SXe=pE3}IKj6jg@UI-Zw%qoVvgh^Yp(nYdJ$1(Zp zwMjvy{m`Zq*ubo{he4?DD4IgFE>2Iwn5%GV8z!(RTRaxq7DlI!mrj%G>MZ`gNuf2` zNJec=kM6z=RgJZDuj&j2 zL2y8z-BUoswiz?J0V?9q*2(3+_&Wf*pLMn#KXYd?ZJE+MK2#LwFa_8E?+IPnFC>nL zLHXjpI5>-LL`7{#(d)eYd|LbFF6y-xsUnbm8Vwm;EIzT;>>xZgRG<_aDeJ>P(7O1m z2=^~uN_fr^(Tb*sO^(nT#msfZ&O3d0!$t;yA(rR*t2;pg0OoF;YDa+EMD zD5bR2q9RMFsD%q7szu9wBny(FMVq1)VFfNk3pYX8Iw)5{3nN5Z6h#^Y5}H4uHR_B% z&# z-##A0bS5DpQG3o{N{}=?R)q}zAfAWv zhNh#-3VgcX35AdnoE*MiixLc)-LGDVmbYD$p1#Gc>N6ZI z*@-tJ4V(_b!5KQ=f2FnK3yn{j`7th&vTZx%bq!{A$suM&N0!**Hp_{q$rCvnj)unl zfq2X#p2CuWNcA3z1E&yz@Z<#Ub?Z4^--z4mrR(`qy1IurTm6X5>4_8<=8~44LT~>s zIy(EKV~eY>V?uk?MwZ7(id%GC;^qGOS#YIyR$T@b5y;WDtt>qPHvP{V5 zgMa8dWfNzFz)&@GISim6Nh`Q1Lhhnw0Q21jdp=5ND!QyJ_6#gV2!S#+g_M{`_JM=s zl$NnC5MXL>fbrfQg2O|&l9TZk>|%X>0nW8aV#Ys)p(qP?LN?1AAz^6he8h6$aI9cq z-M+N&Jf9bMv)6}jB&bDJmYCW`hN?2_AB(n3&Dc>uX3_3wjRswjQA44~d1`Kq&HsTU zMTDiRvH<#UI^!}A5z}IsHa4q-c-rkYHY*leST;5m@EboCL7SPDtj_=d002ovPDHLk FV1i4yLpJ~b From 649861db57cb01aa777011d8a470ffa39dad9422 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 7 Jul 2022 20:18:12 +0800 Subject: [PATCH 545/843] fix download callback in http3 with gzip body --- internal/http3/gzip_reader.go | 16 ++++++++-------- transport.go | 2 ++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/internal/http3/gzip_reader.go b/internal/http3/gzip_reader.go index 01983ac7..9050623e 100644 --- a/internal/http3/gzip_reader.go +++ b/internal/http3/gzip_reader.go @@ -2,7 +2,7 @@ package http3 // copied from net/transport.go -// gzipReader wraps a response body so it can lazily +// GzipReader wraps a response body so it can lazily // call gzip.NewReader on the first call to Read import ( "compress/gzip" @@ -10,22 +10,22 @@ import ( ) // call gzip.NewReader on the first call to Read -type gzipReader struct { - body io.ReadCloser // underlying Response.Body +type GzipReader struct { + Body io.ReadCloser // underlying Response.Body zr *gzip.Reader // lazily-initialized gzip reader zerr error // sticky error } func newGzipReader(body io.ReadCloser) io.ReadCloser { - return &gzipReader{body: body} + return &GzipReader{Body: body} } -func (gz *gzipReader) Read(p []byte) (n int, err error) { +func (gz *GzipReader) Read(p []byte) (n int, err error) { if gz.zerr != nil { return 0, gz.zerr } if gz.zr == nil { - gz.zr, err = gzip.NewReader(gz.body) + gz.zr, err = gzip.NewReader(gz.Body) if err != nil { gz.zerr = err return 0, err @@ -34,6 +34,6 @@ func (gz *gzipReader) Read(p []byte) (n int, err error) { return gz.zr.Read(p) } -func (gz *gzipReader) Close() error { - return gz.body.Close() +func (gz *GzipReader) Close() error { + return gz.Body.Close() } diff --git a/transport.go b/transport.go index db144a7d..1e071f17 100644 --- a/transport.go +++ b/transport.go @@ -454,6 +454,8 @@ func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFu b.body.body = wrap(b.body.body) case *http2.GzipReader: b.Body = wrap(b.Body) + case *http3.GzipReader: + b.Body = wrap(b.Body) default: res.Body = wrap(res.Body) } From 027d94aab2203876a86b8aaaa1355a25773235b5 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 7 Jul 2022 20:27:08 +0800 Subject: [PATCH 546/843] unexpose ForceHttpVersion --- transport.go | 40 ++++++++++++++++++++-------------------- transport_test.go | 2 +- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/transport.go b/transport.go index 1e071f17..3d7692a9 100644 --- a/transport.go +++ b/transport.go @@ -49,16 +49,16 @@ import ( "golang.org/x/net/http/httpguts" ) -// HttpVersion represents http version. -type HttpVersion string +// httpVersion represents http version. +type httpVersion string const ( - // HTTP1 represents "HTTP/1.1" - HTTP1 HttpVersion = "1.1" - // HTTP2 represents "HTTP/2.0" - HTTP2 HttpVersion = "2" - // HTTP3 represents "HTTP/3.0" - HTTP3 HttpVersion = "3" + // h1 represents "HTTP/1.1" + h1 httpVersion = "1.1" + // h2 represents "HTTP/2.0" + h2 httpVersion = "2" + // h3 represents "HTTP/3.0" + h3 httpVersion = "3" ) // defaultMaxIdleConnsPerHost is the default value of Transport's @@ -118,7 +118,7 @@ type Transport struct { pendingAltSvcsMu sync.Mutex // Force using specific http version - ForceHttpVersion HttpVersion + forceHttpVersion httpVersion transport.Options @@ -303,28 +303,28 @@ type pendingAltSvc struct { // EnableForceHTTP1 enable force using HTTP1 (disabled by default). func (t *Transport) EnableForceHTTP1() *Transport { - t.ForceHttpVersion = HTTP1 + t.forceHttpVersion = h1 return t } // EnableForceHTTP2 enable force using HTTP2 for https requests // (disabled by default). func (t *Transport) EnableForceHTTP2() *Transport { - t.ForceHttpVersion = HTTP2 + t.forceHttpVersion = h2 return t } // EnableForceHTTP3 enable force using HTTP3 for https requests // (disabled by default). func (t *Transport) EnableForceHTTP3() *Transport { - t.ForceHttpVersion = HTTP3 + t.forceHttpVersion = h3 return t } // DisableForceHttpVersion disable force using specified http // version (disabled by default). func (t *Transport) DisableForceHttpVersion() *Transport { - t.ForceHttpVersion = "" + t.forceHttpVersion = "" return t } @@ -672,7 +672,7 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error } } - if t.ForceHttpVersion == HTTP3 { + if t.forceHttpVersion == h3 { return t.t3.RoundTrip(req) } @@ -680,7 +680,7 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error cancelKey := cancelKey{origReq} req = setupRewindBody(req) - if scheme == "https" && t.ForceHttpVersion != HTTP1 { + if scheme == "https" && t.forceHttpVersion != h1 { resp, err := t.t2.RoundTripOnlyCachedConn(req) if err != http2.ErrNoCachedConn { return resp, err @@ -742,7 +742,7 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error } var resp *http.Response - if t.ForceHttpVersion != HTTP1 && pconn.alt != nil { + if t.forceHttpVersion != h1 && pconn.alt != nil { // HTTP/2 path. t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest resp, err = pconn.alt.RoundTrip(req) @@ -941,7 +941,7 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) } - cm.onlyH1 = t.ForceHttpVersion == HTTP1 || requestRequiresHTTP1(treq.Request) + cm.onlyH1 = t.forceHttpVersion == h1 || requestRequiresHTTP1(treq.Request) return cm, err } @@ -1648,7 +1648,7 @@ func (pc *persistConn) addTLS(ctx context.Context, name string, trace *httptrace } pc.tlsState = &cs pc.conn = tlsConn - if !forProxy && pc.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2.NextProtoTLS { + if !forProxy && pc.t.forceHttpVersion == h2 && cs.NegotiatedProtocol != http2.NextProtoTLS { return newHttp2NotSupportedError(cs.NegotiatedProtocol) } return nil @@ -1708,7 +1708,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers trace.TLSHandshakeDone(cs, nil) } pconn.tlsState = &cs - if cm.proxyURL == nil && pconn.t.ForceHttpVersion == HTTP2 && cs.NegotiatedProtocol != http2.NextProtoTLS { + if cm.proxyURL == nil && pconn.t.forceHttpVersion == h2 && cs.NegotiatedProtocol != http2.NextProtoTLS { return nil, newHttp2NotSupportedError(cs.NegotiatedProtocol) } } @@ -1842,7 +1842,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } } - if s := pconn.tlsState; t.ForceHttpVersion != HTTP1 && s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { + if s := pconn.tlsState; t.forceHttpVersion != h1 && s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if s.NegotiatedProtocol == http2.NextProtoTLS { if used, err := t.t2.AddConn(pconn.conn, cm.targetAddr); err != nil { go pconn.conn.Close() diff --git a/transport_test.go b/transport_test.go index de69826e..102bbc5a 100644 --- a/transport_test.go +++ b/transport_test.go @@ -5436,7 +5436,7 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { func TestTransportClone(t *testing.T) { tr := &Transport{ - ForceHttpVersion: HTTP1, + forceHttpVersion: h1, Options: transport.Options{ Proxy: func(*http.Request) (*url.URL, error) { panic("") }, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, From 97723949ab4a2a57f85b22cd05efc4097a669835 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 11 Jul 2022 13:57:58 +0800 Subject: [PATCH 547/843] declare in the comment that h3 does not support trace yet --- client.go | 2 +- request.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 06df9e2c..f8e4dcd3 100644 --- a/client.go +++ b/client.go @@ -692,7 +692,7 @@ func (c *Client) DisableTraceAll() *Client { return c } -// EnableTraceAll enable trace for all requests. +// EnableTraceAll enable trace for all requests (http3 currently does not support trace). func (c *Client) EnableTraceAll() *Client { c.trace = true return c diff --git a/request.go b/request.go index e56a8ea4..89019b74 100644 --- a/request.go +++ b/request.go @@ -686,7 +686,7 @@ func (r *Request) DisableTrace() *Request { return r } -// EnableTrace enables trace. +// EnableTrace enables trace (http3 currently does not support trace). func (r *Request) EnableTrace() *Request { if r.trace == nil { r.trace = &clientTrace{} From db241c246357ab6a8e1a5e5324941a779faf6671 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 11 Jul 2022 14:35:26 +0800 Subject: [PATCH 548/843] ensure EnableHTTP3 in EnableForceHTTP3 --- transport.go | 1 + 1 file changed, 1 insertion(+) diff --git a/transport.go b/transport.go index 3d7692a9..f0d954b8 100644 --- a/transport.go +++ b/transport.go @@ -317,6 +317,7 @@ func (t *Transport) EnableForceHTTP2() *Transport { // EnableForceHTTP3 enable force using HTTP3 for https requests // (disabled by default). func (t *Transport) EnableForceHTTP3() *Transport { + t.EnableHTTP3() t.forceHttpVersion = h3 return t } From 637f9c0b21c75e03997b77590d62f597cf3ba1d6 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 11 Jul 2022 19:27:30 +0800 Subject: [PATCH 549/843] ajust log level to debug when cannot determine the unmarshal function(#133) --- middleware.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware.go b/middleware.go index f6d595f1..613ae422 100644 --- a/middleware.go +++ b/middleware.go @@ -266,7 +266,7 @@ func unmarshalBody(c *Client, r *Response, v interface{}) (err error) { } else if util.IsXMLType(ct) { return c.xmlUnmarshal(body, v) } else { - c.log.Warnf("cannot determine the unmarshal function with %q Content-Type, default to json", ct) + c.log.Debugf("cannot determine the unmarshal function with %q Content-Type, default to json", ct) return c.jsonUnmarshal(body, v) } return From da5269b93b58ec42ed03b6372a4a5f7f1aec6f7f Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 2 Aug 2022 17:24:13 +0800 Subject: [PATCH 550/843] ignore empty proxy url in SetProxyUR (#145) --- client.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client.go b/client.go index f8e4dcd3..24b65a22 100644 --- a/client.go +++ b/client.go @@ -676,6 +676,10 @@ func (c *Client) OnAfterResponse(m ResponseMiddleware) *Client { // SetProxyURL set proxy from the proxy URL. func (c *Client) SetProxyURL(proxyUrl string) *Client { + if proxyUrl == "" { + c.log.Warnf("ignore empty proxy url in SetProxyURL") + return c + } u, err := urlpkg.Parse(proxyUrl) if err != nil { c.log.Errorf("failed to parse proxy url %s: %v", proxyUrl, err) From fafb362acbe191c7fe3a87034b0a4b2cda3af4b0 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 3 Aug 2022 14:10:06 +0800 Subject: [PATCH 551/843] http3: ignore context after response when using DontCloseRequestStream Merged from upstream: https://github.com/lucas-clemente/quic-go/pull/3473 --- internal/http3/body.go | 4 +++- internal/http3/client.go | 13 ++++++++++++- internal/http3/client_test.go | 22 +++++++++++++++++++++- internal/http3/roundtrip.go | 2 +- 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/internal/http3/body.go b/internal/http3/body.go index b3d1afd7..d6e704eb 100644 --- a/internal/http3/body.go +++ b/internal/http3/body.go @@ -110,7 +110,9 @@ func (r *hijackableBody) requestDone() { if r.reqDoneClosed || r.reqDone == nil { return } - close(r.reqDone) + if r.reqDone != nil { + close(r.reqDone) + } r.reqDoneClosed = true } diff --git a/internal/http3/client.go b/internal/http3/client.go index 032d0dce..f43c9ba1 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -273,7 +273,9 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon // This go routine keeps running even after RoundTripOpt() returns. // It is shut down when the application is done processing the body. reqDone := make(chan struct{}) + done := make(chan struct{}) go func() { + defer close(done) select { case <-req.Context().Done(): str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) @@ -282,9 +284,14 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } }() - rsp, rerr := c.doRequest(req, str, opt, reqDone) + doneChan := reqDone + if opt.DontCloseRequestStream { + doneChan = nil + } + rsp, rerr := c.doRequest(req, str, opt, doneChan) if rerr.err != nil { // if any error occurred close(reqDone) + <-done if rerr.streamErr != 0 { // if it was a stream error str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) } @@ -296,6 +303,10 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) } } + if opt.DontCloseRequestStream { + close(reqDone) + <-done + } return rsp, rerr.err } diff --git a/internal/http3/client_test.go b/internal/http3/client_test.go index 9952572b..5531fe17 100644 --- a/internal/http3/client_test.go +++ b/internal/http3/client_test.go @@ -14,9 +14,9 @@ import ( mockquic "github.com/imroc/req/v3/internal/mocks/quic" "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quicvarint" "github.com/imroc/req/v3/internal/utils" "github.com/lucas-clemente/quic-go" - "github.com/imroc/req/v3/internal/quicvarint" "github.com/golang/mock/gomock" "github.com/marten-seemann/qpack" @@ -913,6 +913,26 @@ var _ = Describe("Client", func() { cancel() Eventually(done).Should(BeClosed()) }) + + It("doesn't cancel a request if DontCloseRequestStream is set", func() { + rspBuf := bytes.NewBuffer(getResponse(404)) + + ctx, cancel := context.WithCancel(context.Background()) + req := req.WithContext(ctx) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) + buf := &bytes.Buffer{} + str.EXPECT().Close().MaxTimes(1) + + str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) + str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() + rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) + Expect(err).ToNot(HaveOccurred()) + cancel() + _, err = io.ReadAll(rsp.Body) + Expect(err).ToNot(HaveOccurred()) + }) }) Context("gzip compression", func() { diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index e45ec85b..06cf90db 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -67,7 +67,7 @@ type RoundTripOpt struct { // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. OnlyCachedConn bool - // DontCloseRequestStream controls whether the request stream is closed after sending the request. + // If set, context cancellations have no effect after the response headers are received. DontCloseRequestStream bool } From e01e83e1d8a5c1181e2368210084a77bbb61e24b Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 3 Aug 2022 15:08:07 +0800 Subject: [PATCH 552/843] support go1.19 --- go.mod | 10 +++++----- go.sum | 15 +++++++++++++++ internal/qtls/go118.go | 4 ++-- internal/qtls/go120.go | 6 ++++++ transfer.go | 32 ++++++++++++++++++++++++-------- 5 files changed, 52 insertions(+), 15 deletions(-) create mode 100644 internal/qtls/go120.go diff --git a/go.mod b/go.mod index a1b1edd0..88995321 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/golang/mock v1.6.0 github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 - github.com/lucas-clemente/quic-go v0.28.0 + github.com/lucas-clemente/quic-go v0.28.1 github.com/marten-seemann/qpack v0.2.1 github.com/marten-seemann/qtls-go1-16 v0.1.5 github.com/marten-seemann/qtls-go1-17 v0.1.2 @@ -15,10 +15,10 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.13.0 - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect - golang.org/x/net v0.0.0-20220630215102-69896b714898 - golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e // indirect + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect + golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b + golang.org/x/sys v0.0.0-20220731174439-a90be440212d // indirect golang.org/x/text v0.3.7 - golang.org/x/tools v0.1.11 // indirect + golang.org/x/tools v0.1.12 // indirect golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f // indirect ) diff --git a/go.sum b/go.sum index 26a2e59a..9ba9df93 100644 --- a/go.sum +++ b/go.sum @@ -87,6 +87,8 @@ github.com/lucas-clemente/quic-go v0.27.2 h1:zsMwwniyybb8B/UDNXRSYee7WpQJVOcjQEG github.com/lucas-clemente/quic-go v0.27.2/go.mod h1:vXgO/11FBSKM+js1NxoaQ/bPtVFYfB7uxhfHXyMhl1A= github.com/lucas-clemente/quic-go v0.28.0 h1:9eXVRgIkMQQyiyorz/dAaOYIx3TFzXsIFkNFz4cxuJM= github.com/lucas-clemente/quic-go v0.28.0/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= +github.com/lucas-clemente/quic-go v0.28.1 h1:Uo0lvVxWg5la9gflIF9lwa39ONq85Xq2D91YNEIslzU= +github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= @@ -163,6 +165,7 @@ github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMI github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= @@ -177,6 +180,8 @@ golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -209,6 +214,9 @@ golang.org/x/net v0.0.0-20220615171555-694bf12d69de/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw= golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b h1:3ogNYyK4oIQdIKzTu68hQrr4iuVxF3AxKl9Aj/eDrw0= +golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -221,6 +229,7 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -247,6 +256,10 @@ golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c h1:aFV+BgZ4svzjfabn8ERpuB4JI golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e h1:CsOuNlbOuf0mzxJIefr6Q4uAUetRUwZE4qt7VfzP+xo= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -269,6 +282,8 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.11 h1:loJ25fNOEhSXfHrpoGj91eCUThwdNX6u24rO1xnNteY= golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/qtls/go118.go b/internal/qtls/go118.go index 20f6a63f..312bec4e 100644 --- a/internal/qtls/go118.go +++ b/internal/qtls/go118.go @@ -1,5 +1,5 @@ -//go:build go1.18 -// +build go1.18 +//go:build go1.18 && !go1.19 +// +build go1.18,!go1.19 package qtls diff --git a/internal/qtls/go120.go b/internal/qtls/go120.go new file mode 100644 index 00000000..f8d59d8b --- /dev/null +++ b/internal/qtls/go120.go @@ -0,0 +1,6 @@ +//go:build go1.20 +// +build go1.20 + +package qtls + +var _ int = "The version of quic-go you're using can't be built on Go 1.20 yet. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." diff --git a/transfer.go b/transfer.go index 8de2b9f1..69a710e5 100644 --- a/transfer.go +++ b/transfer.go @@ -169,10 +169,11 @@ func (t *transferWriter) shouldSendChunkedRequestBody() bool { // headers before the pipe is fed data), we need to be careful and bound how // long we wait for it. This delay will only affect users if all the following // are true: -// * the request body blocks -// * the content length is not set (or set to -1) -// * the method doesn't usually have a body (GET, HEAD, DELETE, ...) -// * there is no transfer-encoding=chunked already set. +// - the request body blocks +// - the content length is not set (or set to -1) +// - the method doesn't usually have a body (GET, HEAD, DELETE, ...) +// - there is no transfer-encoding=chunked already set. +// // In other words, this delay will not normally affect anybody, and there // are workarounds if it does. func (t *transferWriter) probeRequestBody() { @@ -411,8 +412,8 @@ func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err // // This function is only intended for use in writeBody. func (t *transferWriter) unwrapBody() io.Reader { - if reflect.TypeOf(t.Body) == nopCloserType { - return reflect.ValueOf(t.Body).Field(0).Interface().(io.Reader) + if r, ok := unwrapNopCloser(t.Body); ok { + return r } if r, ok := t.Body.(*readTrackingBody); ok { r.didRead = true @@ -1004,6 +1005,21 @@ func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { } var nopCloserType = reflect.TypeOf(ioutil.NopCloser(nil)) +var nopCloserWriterToType = reflect.TypeOf(ioutil.NopCloser(struct { + io.Reader + io.WriterTo +}{})) + +// unwrapNopCloser return the underlying reader and true if r is a NopCloser +// else it return false +func unwrapNopCloser(r io.Reader) (underlyingReader io.Reader, isNopCloser bool) { + switch reflect.TypeOf(r) { + case nopCloserType, nopCloserWriterToType: + return reflect.ValueOf(r).Field(0).Interface().(io.Reader), true + default: + return nil, false + } +} // isKnownInMemoryReader reports whether r is a type known to not // block on Read. Its caller uses this as an optional optimization to @@ -1013,8 +1029,8 @@ func isKnownInMemoryReader(r io.Reader) bool { case *bytes.Reader, *bytes.Buffer, *strings.Reader: return true } - if reflect.TypeOf(r) == nopCloserType { - return isKnownInMemoryReader(reflect.ValueOf(r).Field(0).Interface().(io.Reader)) + if r, ok := unwrapNopCloser(r); ok { + return isKnownInMemoryReader(r) } if r, ok := r.(*readTrackingBody); ok { return isKnownInMemoryReader(r.ReadCloser) From 4126a6bbce9e73b8240adb5303054ff651608b86 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 3 Aug 2022 15:12:19 +0800 Subject: [PATCH 553/843] add go1.19 for CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9adcefe6..ada219da 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: test: strategy: matrix: - go: [ '1.18.x', '1.17.x' ] + go: [ '1.19.x', '1.18.x', '1.17.x' ] os: [ ubuntu-latest ] runs-on: ${{ matrix.os }} steps: From 5eb452e7d0355872b54adcaca24af755c6f1a2e2 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 3 Aug 2022 15:13:41 +0800 Subject: [PATCH 554/843] remove -p=1 in go test --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ada219da..a479e10f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,4 +30,4 @@ jobs: with: go-version: ${{ matrix.go }} - name: Test - run: go test -p=1 ./... -coverprofile=coverage.txt \ No newline at end of file + run: go test ./... -coverprofile=coverage.txt \ No newline at end of file From 059a0ee91d566cff0736807e1e0aada077ec0af3 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 4 Aug 2022 19:27:43 +0800 Subject: [PATCH 555/843] Support response middleware executed even error is not nil (#140) --- client.go | 20 ++++++++++---------- response.go | 14 ++++++++++---- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/client.go b/client.go index 24b65a22..e9931aa7 100644 --- a/client.go +++ b/client.go @@ -820,10 +820,11 @@ func (c *Client) SetCommonRetryCount(count int) *Client { // SetCommonRetryInterval sets the custom GetRetryIntervalFunc for all requests, // you can use this to implement your own backoff retry algorithm. // For example: -// req.SetCommonRetryInterval(func(resp *req.Response, attempt int) time.Duration { -// sleep := 0.01 * math.Exp2(float64(attempt)) -// return time.Duration(math.Min(2, sleep)) * time.Second -// }) +// +// req.SetCommonRetryInterval(func(resp *req.Response, attempt int) time.Duration { +// sleep := 0.01 * math.Exp2(float64(attempt)) +// return time.Duration(math.Min(2, sleep)) * time.Second +// }) func (c *Client) SetCommonRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Client { c.getRetryOption().GetRetryInterval = getRetryIntervalFunc return c @@ -878,7 +879,8 @@ func (c *Client) AddCommonRetryCondition(condition RetryConditionFunc) *Client { // SetUnixSocket set client to dial connection use unix socket. // For example: -// client.SetUnixSocket("/var/run/custom.sock") +// +// client.SetUnixSocket("/var/run/custom.sock") func (c *Client) SetUnixSocket(file string) *Client { return c.SetDial(func(ctx context.Context, network, addr string) (net.Conn, error) { var d net.Dialer @@ -1113,13 +1115,11 @@ func (c *Client) do(r *Request) (resp *Response, err error) { resp.error = nil } - if err != nil { - return - } + resp.Err = err for _, f := range r.client.afterResponse { - if err = f(r.client, resp); err != nil { - return + if err := f(r.client, resp); err != nil { + return resp, err } } return diff --git a/response.go b/response.go index 60593bb8..76a2dcc4 100644 --- a/response.go +++ b/response.go @@ -10,7 +10,13 @@ import ( // Response is the http response. type Response struct { + // The underlying http.Response is embed into Response. *http.Response + // Err is the underlying error, not nil if some error occurs. + // Usually used in the ResponseMiddleware, you can skip logic in + // ResponseMiddleware that doesn't need to be executed when err occurs. + Err error + // Request is the Response's related Request. Request *Request body []byte receivedAt time.Time @@ -109,8 +115,8 @@ func (r *Response) Unmarshal(v interface{}) error { // Bytes return the response body as []bytes that hava already been read, could be // nil if not read, the following cases are already read: -// 1. `Request.SetResult` or `Request.SetError` is called. -// 2. `Client.DisableAutoReadResponse(false)` is not called, +// 1. `Request.SetResult` or `Request.SetError` is called. +// 2. `Client.DisableAutoReadResponse(false)` is not called, // also `Request.SetOutput` and `Request.SetOutputFile` is not called. func (r *Response) Bytes() []byte { return r.body @@ -118,8 +124,8 @@ func (r *Response) Bytes() []byte { // String returns the response body as string that hava already been read, could be // nil if not read, the following cases are already read: -// 1. `Request.SetResult` or `Request.SetError` is called. -// 2. `Client.DisableAutoReadResponse(false)` is not called, +// 1. `Request.SetResult` or `Request.SetError` is called. +// 2. `Client.DisableAutoReadResponse(false)` is not called, // also `Request.SetOutput` and `Request.SetOutputFile` is not called. func (r *Response) String() string { return string(r.body) From 36341dac07dfc02386d8ee5c385bce4d40d88741 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 7 Aug 2022 10:12:41 +0800 Subject: [PATCH 556/843] support nil pointer and pointer of pointer in SetResult and SetError(#139) --- internal/util/util.go | 26 ++++++++++++++++++++++---- request_test.go | 20 ++++++++++++++++++++ 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/internal/util/util.go b/internal/util/util.go index a110a600..80b56058 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -20,11 +20,29 @@ func IsXMLType(ct string) bool { // GetPointer return the pointer of the interface. func GetPointer(v interface{}) interface{} { - vv := reflect.ValueOf(v) - if vv.Kind() == reflect.Ptr { - return v + t := reflect.TypeOf(v) + if t.Kind() == reflect.Ptr { + if tt := t.Elem(); tt.Kind() == reflect.Ptr { // pointer of pointer + if tt.Elem().Kind() == reflect.Ptr { + panic("pointer of pointer of pointer is not supported") + } + el := reflect.ValueOf(v).Elem() + if el.IsZero() { + vv := reflect.New(tt.Elem()) + el.Set(vv) + return vv.Interface() + } else { + return el.Interface() + } + } else { + if reflect.ValueOf(v).IsZero() { + vv := reflect.New(t.Elem()) + return vv.Interface() + } + return v + } } - return reflect.New(vv.Type()).Interface() + return reflect.New(t).Interface() } // GetType return the underlying type. diff --git a/request_test.go b/request_test.go index 7a5597bf..1bf6946c 100644 --- a/request_test.go +++ b/request_test.go @@ -359,6 +359,26 @@ func TestSetBodyMarshal(t *testing.T) { } } +func TestSetResult(t *testing.T) { + c := tc() + var user *UserInfo + url := "/search?username=imroc&type=json" + + resp, err := c.R().SetResult(&user).Get(url) + assertSuccess(t, resp, err) + tests.AssertEqual(t, "imroc", user.Username) + + user = &UserInfo{} + resp, err = c.R().SetResult(user).Get(url) + assertSuccess(t, resp, err) + tests.AssertEqual(t, "imroc", user.Username) + + user = nil + resp, err = c.R().SetResult(user).Get(url) + assertSuccess(t, resp, err) + tests.AssertEqual(t, "imroc", resp.Result().(*UserInfo).Username) +} + func TestSetBody(t *testing.T) { body := "hello" fn := func() (io.ReadCloser, error) { From ab1e34e1e6e13ba63ce806ea28bd1fceee27ea08 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 7 Aug 2022 13:51:58 +0800 Subject: [PATCH 557/843] Support Do API style(#137) --- client.go | 49 +++++++++++++++++++++++++++++++++++++++++ request.go | 42 ++++++++++++++++++++++++++--------- request_test.go | 10 +++++++++ request_wrapper.go | 6 +++++ request_wrapper_test.go | 1 + response.go | 18 +++++++++++++++ 6 files changed, 116 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index e9931aa7..7fde82a1 100644 --- a/client.go +++ b/client.go @@ -73,6 +73,55 @@ func (c *Client) R() *Request { } } +// Get create a new GET request. +func (c *Client) Get() *Request { + r := c.R() + r.Method = http.MethodGet + return r +} + +// Post create a new POST request. +func (c *Client) Post() *Request { + r := c.R() + r.Method = http.MethodPost + return r +} + +// Patch create a new PATCH request. +func (c *Client) Patch() *Request { + r := c.R() + r.Method = http.MethodPatch + return r +} + +// Delete create a new DELETE request. +func (c *Client) Delete() *Request { + r := c.R() + r.Method = http.MethodDelete + return r +} + +// Put create a new PUT request. +func (c *Client) Put() *Request { + r := c.R() + r.Method = http.MethodPut + return r +} + +// Head create a new HEAD request. +func (c *Client) Head() *Request { + r := c.R() + r.Method = http.MethodHead + return r +} + +// Options create a new OPTIONS request. +func (c *Client) Options() *Request { + r := c.R() + r.Method = http.MethodOptions + return r +} + // GetTransport return the underlying transport. func (c *Client) GetTransport() *Transport { return c.t diff --git a/request.go b/request.go index 89019b74..3adf28bf 100644 --- a/request.go +++ b/request.go @@ -141,6 +141,12 @@ func (r *Request) TraceInfo() TraceInfo { return ti } +// SetURL set the url for request. +func (r *Request) SetURL(url string) *Request { + r.RawURL = url + return r +} + // SetFormDataFromValues set the form data from url.Values, will not // been used if request method does not allow payload. func (r *Request) SetFormDataFromValues(data urlpkg.Values) *Request { @@ -443,21 +449,36 @@ func (r *Request) appendError(err error) { var errRetryableWithUnReplayableBody = errors.New("retryable request should not have unreplayable body (io.Reader)") -// Send fires http request and return the *Response which is always -// not nil, and the error is not nil if some error happens. -func (r *Request) Send(method, url string) (*Response, error) { +// Do fires http request and return the *Response which is always +// not nil, and the error is not nil if error occurs. +func (r *Request) Do() *Response { defer func() { r.responseReturnTime = time.Now() }() if r.error != nil { - return &Response{Request: r}, r.error + resp := &Response{Request: r} + resp.Err = r.error + return resp } if r.retryOption != nil && r.retryOption.MaxRetries > 0 && r.unReplayableBody != nil { // retryable request should not have unreplayable body - return &Response{Request: r}, errRetryableWithUnReplayableBody + resp := &Response{Request: r} + resp.Err = errRetryableWithUnReplayableBody + return resp } + resp, err := r.client.do(r) + if err != nil { + resp.Err = err + } + return resp +} + +// Send fires http request with specified method and url, returns the +// *Response which is always not nil, and the error is not nil if error occurs. +func (r *Request) Send(method, url string) (*Response, error) { r.Method = method r.RawURL = url - return r.client.do(r) + resp := r.Do() + return resp, resp.Err } // MustGet like Get, panic if error happens, should only be used to @@ -812,10 +833,11 @@ func (r *Request) SetRetryCount(count int) *Request { // SetRetryInterval sets the custom GetRetryIntervalFunc, you can use this to // implement your own backoff retry algorithm. // For example: -// req.SetRetryInterval(func(resp *req.Response, attempt int) time.Duration { -// sleep := 0.01 * math.Exp2(float64(attempt)) -// return time.Duration(math.Min(2, sleep)) * time.Second -// }) +// +// req.SetRetryInterval(func(resp *req.Response, attempt int) time.Duration { +// sleep := 0.01 * math.Exp2(float64(attempt)) +// return time.Duration(math.Min(2, sleep)) * time.Second +// }) func (r *Request) SetRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Request { r.getRetryOption().GetRetryInterval = getRetryIntervalFunc return r diff --git a/request_test.go b/request_test.go index 1bf6946c..624c6f47 100644 --- a/request_test.go +++ b/request_test.go @@ -359,6 +359,16 @@ func TestSetBodyMarshal(t *testing.T) { } } +func TestDoAPIStyle(t *testing.T) { + c := tc() + user := &UserInfo{} + url := "/search?username=imroc&type=json" + + err := c.Get().SetURL(url).Do().Into(user) + tests.AssertEqual(t, true, err == nil) + tests.AssertEqual(t, "imroc", user.Username) +} + func TestSetResult(t *testing.T) { c := tc() var user *UserInfo diff --git a/request_wrapper.go b/request_wrapper.go index e496cdc4..a64c2c25 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -8,6 +8,12 @@ import ( "time" ) +// SetURL is a global wrapper methods which delegated +// to the default client, create a request and SetURL for request. +func SetURL(url string) *Request { + return defaultClient.R().SetURL(url) +} + // SetFormDataFromValues is a global wrapper methods which delegated // to the default client, create a request and SetFormDataFromValues for request. func SetFormDataFromValues(data url.Values) *Request { diff --git a/request_wrapper_test.go b/request_wrapper_test.go index 4c74cbe9..91a244a6 100644 --- a/request_wrapper_test.go +++ b/request_wrapper_test.go @@ -40,6 +40,7 @@ func TestGlobalWrapperForRequestSettings(t *testing.T) { SetPathParam("test", "test"), SetPathParams(map[string]string{"test": "test"}), SetFormData(map[string]string{"test": "test"}), + SetURL(""), SetFormDataFromValues(nil), SetContentType(header.JsonContentType), AddRetryCondition(func(rep *Response, err error) bool { diff --git a/response.go b/response.go index 76a2dcc4..9695eeed 100644 --- a/response.go +++ b/response.go @@ -85,6 +85,9 @@ func (r *Response) setReceivedAt() { // UnmarshalJson unmarshals JSON response body into the specified object. func (r *Response) UnmarshalJson(v interface{}) error { + if r.Err != nil { + return r.Err + } b, err := r.ToBytes() if err != nil { return err @@ -94,6 +97,9 @@ func (r *Response) UnmarshalJson(v interface{}) error { // UnmarshalXml unmarshals XML response body into the specified object. func (r *Response) UnmarshalXml(v interface{}) error { + if r.Err != nil { + return r.Err + } b, err := r.ToBytes() if err != nil { return err @@ -104,6 +110,9 @@ func (r *Response) UnmarshalXml(v interface{}) error { // Unmarshal unmarshals response body into the specified object according // to response `Content-Type`. func (r *Response) Unmarshal(v interface{}) error { + if r.Err != nil { + return r.Err + } contentType := r.Header.Get("Content-Type") if strings.Contains(contentType, "json") { return r.UnmarshalJson(v) @@ -113,6 +122,12 @@ func (r *Response) Unmarshal(v interface{}) error { return r.UnmarshalJson(v) } +// Into unmarshals response body into the specified object according +// to response `Content-Type`. +func (r *Response) Into(v interface{}) error { + return r.Unmarshal(v) +} + // Bytes return the response body as []bytes that hava already been read, could be // nil if not read, the following cases are already read: // 1. `Request.SetResult` or `Request.SetError` is called. @@ -139,6 +154,9 @@ func (r *Response) ToString() (string, error) { // ToBytes returns the response body as []byte, read body if not have been read. func (r *Response) ToBytes() ([]byte, error) { + if r.Err != nil { + return nil, r.Err + } if r.body != nil { return r.body, nil } From 5d8f06fbc85e102644230fde6ce10e61a9be325a Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 7 Aug 2022 14:20:36 +0800 Subject: [PATCH 558/843] SetBody support basic types --- request.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/request.go b/request.go index 3adf28bf..1f22d608 100644 --- a/request.go +++ b/request.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "github.com/hashicorp/go-multierror" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/header" @@ -14,6 +15,7 @@ import ( urlpkg "net/url" "os" "path/filepath" + "reflect" "strings" "time" ) @@ -611,7 +613,13 @@ func (r *Request) SetBody(body interface{}) *Request { case GetContentFunc: r.getBody = b default: - r.marshalBody = body + t := reflect.TypeOf(body) + switch t.Kind() { + case reflect.Ptr, reflect.Struct, reflect.Map: + r.marshalBody = body + default: + r.SetBodyString(fmt.Sprint(body)) + } } return r } From 795e668f32c4f242a41f8de5fbafd3ddc0d80381 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 7 Aug 2022 15:07:14 +0800 Subject: [PATCH 559/843] Support pointer of pointer in resp.Unmarshal and resp.Into --- response.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/response.go b/response.go index 9695eeed..5ccc650b 100644 --- a/response.go +++ b/response.go @@ -2,6 +2,7 @@ package req import ( "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/util" "io/ioutil" "net/http" "strings" @@ -113,6 +114,7 @@ func (r *Response) Unmarshal(v interface{}) error { if r.Err != nil { return r.Err } + v = util.GetPointer(v) contentType := r.Header.Get("Content-Type") if strings.Contains(contentType, "json") { return r.UnmarshalJson(v) From 2d45503c332418bf6e52fc406760a870ad1fd6cb Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 7 Aug 2022 18:16:06 +0800 Subject: [PATCH 560/843] Support middleware in Transport(#138) --- client.go | 14 +++++++++++++ client_test.go | 16 +++++++++++++++ roundtrip.go | 6 +++++- transport.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 87 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 7fde82a1..c0b1c41b 100644 --- a/client.go +++ b/client.go @@ -949,6 +949,20 @@ func (c *Client) EnableHTTP3() *Client { return c } +// WrapRoundTrip adds a transport middleware function that will give the caller +// an opportunity to wrap the underlying http.RoundTripper. +func (c *Client) WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { + c.t.WrapRoundTrip(wrappers...) + return c +} + +// WrapRoundTripFunc adds a transport middleware function that will give the caller +// an opportunity to wrap the underlying http.RoundTripper. +func (c *Client) WrapRoundTripFunc(funcs ...RoundTripWrapperFunc) *Client { + c.t.WrapRoundTripFunc(funcs...) + return c +} + // NewClient is the alias of C func NewClient() *Client { return C() diff --git a/client_test.go b/client_test.go index fca50bdf..011c87d9 100644 --- a/client_test.go +++ b/client_test.go @@ -17,6 +17,22 @@ import ( "time" ) +func TestWrapRoundTrip(t *testing.T) { + i, j := 0, 0 + c := tc().WrapRoundTripFunc(func(rt http.RoundTripper) RoundTripFunc { + return func(req *http.Request) (resp *http.Response, err error) { + i = 1 + resp, err = rt.RoundTrip(req) + j = 1 + return + } + }) + resp, err := c.R().Get("/") + assertSuccess(t, resp, err) + tests.AssertEqual(t, 1, i) + tests.AssertEqual(t, 1, j) +} + func TestAllowGetMethodPayload(t *testing.T) { c := tc() resp, err := c.R().SetBody("test").Get("/payload") diff --git a/roundtrip.go b/roundtrip.go index d230481b..0a63edbb 100644 --- a/roundtrip.go +++ b/roundtrip.go @@ -19,7 +19,11 @@ import ( // Like the RoundTripper interface, the error types returned // by RoundTrip are unspecified. func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { - resp, err = t.roundTrip(req) + if t.wrappedRoundTrip != nil { + resp, err = t.wrappedRoundTrip.RoundTrip(req) + } else { + resp, err = t.roundTrip(req) + } if err != nil { return } diff --git a/transport.go b/transport.go index f0d954b8..12843ce2 100644 --- a/transport.go +++ b/transport.go @@ -133,6 +133,7 @@ type Transport struct { // whether the response body should been auto decode to utf-8. // Only valid when DisableAutoDecode is true. autoDecodeContentType func(contentType string) bool + wrappedRoundTrip http.RoundTripper } // NewTransport is an alias of T @@ -156,6 +157,54 @@ func T() *Transport { return t } +// RoundTripFunc is a http.RoundTripper implementation, which is a simple function. +type RoundTripFunc func(req *http.Request) (resp *http.Response, err error) + +// RoundTrip implements http.RoundTripper. +func (fn RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +// RoundTripWrapper is transport middleware function. +type RoundTripWrapper func(rt http.RoundTripper) http.RoundTripper + +// RoundTripWrapperFunc is transport middleware function, more convenient than RoundTripWrapper. +type RoundTripWrapperFunc func(rt http.RoundTripper) RoundTripFunc + +func (f RoundTripWrapperFunc) wrapper() RoundTripWrapper { + return func(rt http.RoundTripper) http.RoundTripper { + return f(rt) + } +} + +// WrapRoundTripFunc adds a transport middleware function that will give the caller +// an opportunity to wrap the underlying http.RoundTripper. +func (t *Transport) WrapRoundTripFunc(funcs ...RoundTripWrapperFunc) *Transport { + var wrappers []RoundTripWrapper + for _, fn := range funcs { + wrappers = append(wrappers, fn.wrapper()) + } + return t.WrapRoundTrip(wrappers...) +} + +// WrapRoundTrip adds a transport middleware function that will give the caller +// an opportunity to wrap the underlying http.RoundTripper. +func (t *Transport) WrapRoundTrip(wrappers ...RoundTripWrapper) *Transport { + if len(wrappers) == 0 { + return t + } + if t.wrappedRoundTrip == nil { + fn := func(req *http.Request) (*http.Response, error) { + return t.roundTrip(req) + } + t.wrappedRoundTrip = RoundTripFunc(fn) + } + for _, w := range wrappers { + t.wrappedRoundTrip = w(t.wrappedRoundTrip) + } + return t +} + // DisableAutoDecode disable auto-detect charset and decode to utf-8 // (enabled by default). func (t *Transport) DisableAutoDecode() *Transport { @@ -1906,7 +1955,6 @@ var _ io.ReaderFrom = (*persistConnWriter)(nil) // socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com // https://proxy.com|https|foo.com https to proxy, then CONNECT to foo.com // https://proxy.com|http https to proxy, http to anywhere after that -// type connectMethod struct { _ incomparable proxyURL *url.URL // nil for no proxy, else full proxy URL @@ -2024,8 +2072,11 @@ type persistConn struct { } // RFC 7234, section 5.4: Should treat +// // Pragma: no-cache +// // like +// // Cache-Control: no-cache func fixPragmaCacheControl(header http.Header) { if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { From f31f0abc74f318d96584cc12e9150d8d0fe13660 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 8 Aug 2022 12:41:58 +0800 Subject: [PATCH 561/843] Support middleware in Client --- client.go | 108 ++++++++++++++++++++++++++++++++++++------------- client_test.go | 14 ++++++- request.go | 5 +-- transport.go | 26 ++++++------ 4 files changed, 106 insertions(+), 47 deletions(-) diff --git a/client.go b/client.go index c0b1c41b..cfd978d2 100644 --- a/client.go +++ b/client.go @@ -63,6 +63,7 @@ type Client struct { beforeRequest []RequestMiddleware udBeforeRequest []RequestMiddleware afterResponse []ResponseMiddleware + wrappedRoundTrip RoundTripper } // R create a new request. @@ -949,20 +950,6 @@ func (c *Client) EnableHTTP3() *Client { return c } -// WrapRoundTrip adds a transport middleware function that will give the caller -// an opportunity to wrap the underlying http.RoundTripper. -func (c *Client) WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { - c.t.WrapRoundTrip(wrappers...) - return c -} - -// WrapRoundTripFunc adds a transport middleware function that will give the caller -// an opportunity to wrap the underlying http.RoundTripper. -func (c *Client) WrapRoundTripFunc(funcs ...RoundTripWrapperFunc) *Client { - c.t.WrapRoundTripFunc(funcs...) - return c -} - // NewClient is the alias of C func NewClient() *Client { return C() @@ -1040,10 +1027,82 @@ func C() *Client { return c } -func (c *Client) do(r *Request) (resp *Response, err error) { - resp = &Response{ - Request: r, +// RoundTripper is the interface of req's Client. +type RoundTripper interface { + RoundTrip(*Request) (*Response, error) +} + +// RoundTripFunc is a RoundTripper implementation, which is a simple function. +type RoundTripFunc func(req *Request) (resp *Response, err error) + +// RoundTrip implements RoundTripper. +func (fn RoundTripFunc) RoundTrip(req *Request) (*Response, error) { + return fn(req) +} + +// RoundTripWrapper is client middleware function. +type RoundTripWrapper func(rt RoundTripper) RoundTripper + +// RoundTripWrapperFunc is client middleware function, more convenient than RoundTripWrapper. +type RoundTripWrapperFunc func(rt RoundTripper) RoundTripFunc + +func (f RoundTripWrapperFunc) wrapper() RoundTripWrapper { + return func(rt RoundTripper) RoundTripper { + return f(rt) + } +} + +// WrapRoundTripFunc adds a client middleware function that will give the caller +// an opportunity to wrap the underlying http.RoundTripper. +func (c *Client) WrapRoundTripFunc(funcs ...RoundTripWrapperFunc) *Client { + var wrappers []RoundTripWrapper + for _, fn := range funcs { + wrappers = append(wrappers, fn.wrapper()) + } + return c.WrapRoundTrip(wrappers...) +} + +// WrapRoundTrip adds a client middleware function that will give the caller +// an opportunity to wrap the underlying http.RoundTripper. +func (c *Client) WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { + if len(wrappers) == 0 { + return c } + if c.wrappedRoundTrip == nil { + c.wrappedRoundTrip = c + } + for _, w := range wrappers { + c.wrappedRoundTrip = w(c.wrappedRoundTrip) + } + return c +} + +// RoundTrip implements RoundTripper +func (c *Client) RoundTrip(r *Request) (resp *Response, err error) { + resp = &Response{Request: r} + var httpResponse *http.Response + httpResponse, err = c.httpClient.Do(r.RawRequest) + resp.Response = httpResponse + + // auto-read response body if possible + if err == nil && !c.disableAutoReadResponse && !r.isSaveResponse { + _, err = resp.ToBytes() + if err != nil { + return + } + } + return +} + +func (c *Client) do(r *Request) (resp *Response, err error) { + defer func() { + if resp == nil { + resp = &Response{Request: r} + } + if err != nil { + resp.Err = err + } + }() for { for _, f := range r.client.beforeRequest { @@ -1129,16 +1188,11 @@ func (c *Client) do(r *Request) (resp *Response, err error) { } r.RawRequest = req r.StartTime = time.Now() - var httpResponse *http.Response - httpResponse, err = c.httpClient.Do(req) - resp.Response = httpResponse - // auto-read response body if possible - if err == nil && !c.disableAutoReadResponse && !r.isSaveResponse { - _, err = resp.ToBytes() - if err != nil { - return - } + if c.wrappedRoundTrip != nil { + resp, err = c.wrappedRoundTrip.RoundTrip(r) + } else { + resp, err = c.RoundTrip(r) } if r.retryOption == nil || r.RetryAttempt >= r.retryOption.MaxRetries { // absolutely cannot retry. @@ -1178,8 +1232,6 @@ func (c *Client) do(r *Request) (resp *Response, err error) { resp.error = nil } - resp.Err = err - for _, f := range r.client.afterResponse { if err := f(r.client, resp); err != nil { return resp, err diff --git a/client_test.go b/client_test.go index 011c87d9..6f06def2 100644 --- a/client_test.go +++ b/client_test.go @@ -18,8 +18,16 @@ import ( ) func TestWrapRoundTrip(t *testing.T) { - i, j := 0, 0 - c := tc().WrapRoundTripFunc(func(rt http.RoundTripper) RoundTripFunc { + i, j, a, b := 0, 0, 0, 0 + c := tc().WrapRoundTripFunc(func(rt RoundTripper) RoundTripFunc { + return func(req *Request) (resp *Response, err error) { + a = 1 + resp, err = rt.RoundTrip(req) + b = 1 + return + } + }) + c.GetTransport().WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { return func(req *http.Request) (resp *http.Response, err error) { i = 1 resp, err = rt.RoundTrip(req) @@ -31,6 +39,8 @@ func TestWrapRoundTrip(t *testing.T) { assertSuccess(t, resp, err) tests.AssertEqual(t, 1, i) tests.AssertEqual(t, 1, j) + tests.AssertEqual(t, 1, a) + tests.AssertEqual(t, 1, b) } func TestAllowGetMethodPayload(t *testing.T) { diff --git a/request.go b/request.go index 1f22d608..f60bb7e0 100644 --- a/request.go +++ b/request.go @@ -467,10 +467,7 @@ func (r *Request) Do() *Response { resp.Err = errRetryableWithUnReplayableBody return resp } - resp, err := r.client.do(r) - if err != nil { - resp.Err = err - } + resp, _ := r.client.do(r) return resp } diff --git a/transport.go b/transport.go index 12843ce2..8cc25b43 100644 --- a/transport.go +++ b/transport.go @@ -141,7 +141,7 @@ func NewTransport() *Transport { return T() } -// T create a new Transport. +// T create a Transport. func T() *Transport { t := &Transport{ Options: transport.Options{ @@ -157,21 +157,21 @@ func T() *Transport { return t } -// RoundTripFunc is a http.RoundTripper implementation, which is a simple function. -type RoundTripFunc func(req *http.Request) (resp *http.Response, err error) +// HttpRoundTripFunc is a http.RoundTripper implementation, which is a simple function. +type HttpRoundTripFunc func(req *http.Request) (resp *http.Response, err error) // RoundTrip implements http.RoundTripper. -func (fn RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { +func (fn HttpRoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return fn(req) } -// RoundTripWrapper is transport middleware function. -type RoundTripWrapper func(rt http.RoundTripper) http.RoundTripper +// HttpRoundTripWrapper is transport middleware function. +type HttpRoundTripWrapper func(rt http.RoundTripper) http.RoundTripper -// RoundTripWrapperFunc is transport middleware function, more convenient than RoundTripWrapper. -type RoundTripWrapperFunc func(rt http.RoundTripper) RoundTripFunc +// HttpRoundTripWrapperFunc is transport middleware function, more convenient than HttpRoundTripWrapper. +type HttpRoundTripWrapperFunc func(rt http.RoundTripper) HttpRoundTripFunc -func (f RoundTripWrapperFunc) wrapper() RoundTripWrapper { +func (f HttpRoundTripWrapperFunc) wrapper() HttpRoundTripWrapper { return func(rt http.RoundTripper) http.RoundTripper { return f(rt) } @@ -179,8 +179,8 @@ func (f RoundTripWrapperFunc) wrapper() RoundTripWrapper { // WrapRoundTripFunc adds a transport middleware function that will give the caller // an opportunity to wrap the underlying http.RoundTripper. -func (t *Transport) WrapRoundTripFunc(funcs ...RoundTripWrapperFunc) *Transport { - var wrappers []RoundTripWrapper +func (t *Transport) WrapRoundTripFunc(funcs ...HttpRoundTripWrapperFunc) *Transport { + var wrappers []HttpRoundTripWrapper for _, fn := range funcs { wrappers = append(wrappers, fn.wrapper()) } @@ -189,7 +189,7 @@ func (t *Transport) WrapRoundTripFunc(funcs ...RoundTripWrapperFunc) *Transport // WrapRoundTrip adds a transport middleware function that will give the caller // an opportunity to wrap the underlying http.RoundTripper. -func (t *Transport) WrapRoundTrip(wrappers ...RoundTripWrapper) *Transport { +func (t *Transport) WrapRoundTrip(wrappers ...HttpRoundTripWrapper) *Transport { if len(wrappers) == 0 { return t } @@ -197,7 +197,7 @@ func (t *Transport) WrapRoundTrip(wrappers ...RoundTripWrapper) *Transport { fn := func(req *http.Request) (*http.Response, error) { return t.roundTrip(req) } - t.wrappedRoundTrip = RoundTripFunc(fn) + t.wrappedRoundTrip = HttpRoundTripFunc(fn) } for _, w := range wrappers { t.wrappedRoundTrip = w(t.wrappedRoundTrip) From dc3de81cf9bbdf2991a95540d0d47bc4096c6ae7 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 8 Aug 2022 12:56:19 +0800 Subject: [PATCH 562/843] add global wrapper for client --- client_wrapper.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/client_wrapper.go b/client_wrapper.go index 30b8430b..e0e535d9 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -10,6 +10,18 @@ import ( "time" ) +// WrapRoundTrip is a global wrapper methods which delegated +// to the default client's WrapRoundTrip. +func WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { + return defaultClient.WrapRoundTrip(wrappers...) +} + +// WrapRoundTripFunc is a global wrapper methods which delegated +// to the default client's WrapRoundTripFunc. +func WrapRoundTripFunc(funcs ...RoundTripWrapperFunc) *Client { + return defaultClient.WrapRoundTripFunc(funcs...) +} + // SetCommonError is a global wrapper methods which delegated // to the default client's SetCommonError. func SetCommonError(err interface{}) *Client { From 33c30fd0e87a32a3dc899d4a68415062677873c0 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 8 Aug 2022 13:30:57 +0800 Subject: [PATCH 563/843] optimize client middleware --- client.go | 156 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 80 insertions(+), 76 deletions(-) diff --git a/client.go b/client.go index cfd978d2..fd3906bd 100644 --- a/client.go +++ b/client.go @@ -1080,17 +1080,95 @@ func (c *Client) WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { // RoundTrip implements RoundTripper func (c *Client) RoundTrip(r *Request) (resp *Response, err error) { resp = &Response{Request: r} + + // setup trace + if r.trace == nil && r.client.trace { + r.trace = &clientTrace{} + } + if r.trace != nil { + r.ctx = r.trace.createContext(r.Context()) + } + + // setup url and host + var host string + if h := r.getHeader("Host"); h != "" { + host = h // Host header override + } else { + host = r.URL.Host + } + + // setup header + contentLength := int64(len(r.body)) + + var reqBody io.ReadCloser + if r.getBody != nil { + reqBody, err = r.getBody() + if err != nil { + return + } + } + req := &http.Request{ + Method: r.Method, + Header: r.Headers, + URL: r.URL, + Host: host, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: contentLength, + Body: reqBody, + GetBody: r.getBody, + } + for _, cookie := range r.Cookies { + req.AddCookie(cookie) + } + ctx := r.ctx + if r.isSaveResponse && r.downloadCallback != nil { + var wrap wrapResponseBodyFunc = func(rc io.ReadCloser) io.ReadCloser { + return &callbackReader{ + ReadCloser: rc, + callback: func(read int64) { + r.downloadCallback(DownloadInfo{ + Response: resp, + DownloadedSize: read, + }) + }, + lastTime: time.Now(), + interval: r.downloadCallbackInterval, + } + } + if ctx == nil { + ctx = context.Background() + } + ctx = context.WithValue(ctx, wrapResponseBodyKey, wrap) + } + if ctx != nil { + req = req.WithContext(ctx) + } + r.RawRequest = req + r.StartTime = time.Now() + var httpResponse *http.Response httpResponse, err = c.httpClient.Do(r.RawRequest) resp.Response = httpResponse + if err != nil { + return + } // auto-read response body if possible - if err == nil && !c.disableAutoReadResponse && !r.isSaveResponse { + if !c.disableAutoReadResponse && !r.isSaveResponse { _, err = resp.ToBytes() if err != nil { return } } + + for _, f := range r.client.afterResponse { + if err = f(r.client, resp); err != nil { + resp.Err = err + return + } + } return } @@ -1116,78 +1194,9 @@ func (c *Client) do(r *Request) (resp *Response, err error) { } } - // setup trace - if r.trace == nil && r.client.trace { - r.trace = &clientTrace{} - } - if r.trace != nil { - r.ctx = r.trace.createContext(r.Context()) - } - - // setup url and host - var host string - if h := r.getHeader("Host"); h != "" { - host = h // Host header override - } else { - host = r.URL.Host - } - - // setup header - var header http.Header if r.Headers == nil { - header = make(http.Header) - } else { - header = r.Headers - } - contentLength := int64(len(r.body)) - - var reqBody io.ReadCloser - if r.getBody != nil { - reqBody, err = r.getBody() - if err != nil { - return - } - } - req := &http.Request{ - Method: r.Method, - Header: header, - URL: r.URL, - Host: host, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - ContentLength: contentLength, - Body: reqBody, - GetBody: r.getBody, - } - for _, cookie := range r.Cookies { - req.AddCookie(cookie) - } - ctx := r.ctx - if r.isSaveResponse && r.downloadCallback != nil { - var wrap wrapResponseBodyFunc = func(rc io.ReadCloser) io.ReadCloser { - return &callbackReader{ - ReadCloser: rc, - callback: func(read int64) { - r.downloadCallback(DownloadInfo{ - Response: resp, - DownloadedSize: read, - }) - }, - lastTime: time.Now(), - interval: r.downloadCallbackInterval, - } - } - if ctx == nil { - ctx = context.Background() - } - ctx = context.WithValue(ctx, wrapResponseBodyKey, wrap) - } - if ctx != nil { - req = req.WithContext(ctx) + r.Headers = make(http.Header) } - r.RawRequest = req - r.StartTime = time.Now() if c.wrappedRoundTrip != nil { resp, err = c.wrappedRoundTrip.RoundTrip(r) @@ -1232,10 +1241,5 @@ func (c *Client) do(r *Request) (resp *Response, err error) { resp.error = nil } - for _, f := range r.client.afterResponse { - if err := f(r.client, resp); err != nil { - return resp, err - } - } return } From d9c35923a9b422fdf08a519f4f1285bacaa4a920 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 8 Aug 2022 13:34:44 +0800 Subject: [PATCH 564/843] expose Body and GetBody in Client --- client.go | 8 ++++---- middleware.go | 8 ++++---- request.go | 54 +++++++++++++++++++++++++-------------------------- 3 files changed, 35 insertions(+), 35 deletions(-) diff --git a/client.go b/client.go index fd3906bd..a8cb82db 100644 --- a/client.go +++ b/client.go @@ -1098,11 +1098,11 @@ func (c *Client) RoundTrip(r *Request) (resp *Response, err error) { } // setup header - contentLength := int64(len(r.body)) + contentLength := int64(len(r.Body)) var reqBody io.ReadCloser - if r.getBody != nil { - reqBody, err = r.getBody() + if r.GetBody != nil { + reqBody, err = r.GetBody() if err != nil { return } @@ -1117,7 +1117,7 @@ func (c *Client) RoundTrip(r *Request) (resp *Response, err error) { ProtoMinor: 1, ContentLength: contentLength, Body: reqBody, - GetBody: r.getBody, + GetBody: r.GetBody, } for _, cookie := range r.Cookies { req.AddCookie(cookie) diff --git a/middleware.go b/middleware.go index 613ae422..1278b1b1 100644 --- a/middleware.go +++ b/middleware.go @@ -175,7 +175,7 @@ func writeMultiPart(r *Request, w *multipart.Writer, pw *io.PipeWriter) { func handleMultiPart(c *Client, r *Request) (err error) { pr, pw := io.Pipe() - r.getBody = func() (io.ReadCloser, error) { + r.GetBody = func() (io.ReadCloser, error) { return pr, nil } w := multipart.NewWriter(pw) @@ -223,7 +223,7 @@ func handleMarshalBody(c *Client, r *Request) error { func parseRequestBody(c *Client, r *Request) (err error) { if c.isPayloadForbid(r.Method) { - r.getBody = nil + r.GetBody = nil return } // handle multipart @@ -245,12 +245,12 @@ func parseRequestBody(c *Client, r *Request) (err error) { handleMarshalBody(c, r) } - if r.body == nil { + if r.Body == nil { return } // body is in-memory []byte, so we can guess content type if r.getHeader(header.ContentType) == "" { - r.SetContentType(http.DetectContentType(r.body)) + r.SetContentType(http.DetectContentType(r.Body)) } return } diff --git a/request.go b/request.go index f60bb7e0..33a97d97 100644 --- a/request.go +++ b/request.go @@ -38,11 +38,12 @@ type Request struct { RetryAttempt int RawURL string // read only Method string + Body []byte + URL *urlpkg.URL + GetBody GetContentFunc isMultiPart bool isSaveResponse bool - URL *urlpkg.URL - getBody GetContentFunc uploadCallback UploadCallback uploadCallbackInterval time.Duration downloadCallback DownloadCallback @@ -50,7 +51,6 @@ type Request struct { unReplayableBody io.ReadCloser retryOption *retryOption bodyReadCloser io.ReadCloser - body []byte dumpOptions *DumpOptions marshalBody interface{} ctx context.Context @@ -323,14 +323,14 @@ func (r *Request) SetDownloadCallbackWithInterval(callback DownloadCallback, min return r } -// SetResult set the result that response body will be unmarshalled to if +// SetResult set the result that response Body will be unmarshalled to if // request is success (status `code >= 200 and <= 299`). func (r *Request) SetResult(result interface{}) *Request { r.Result = util.GetPointer(result) return r } -// SetError set the result that response body will be unmarshalled to if +// SetError set the result that response Body will be unmarshalled to if // request is error ( status `code >= 400`). func (r *Request) SetError(error interface{}) *Request { r.Error = util.GetPointer(error) @@ -384,14 +384,14 @@ func (r *Request) SetHeaderNonCanonical(key, value string) *Request { return r } -// SetOutputFile set the file that response body will be downloaded to. +// SetOutputFile set the file that response Body will be downloaded to. func (r *Request) SetOutputFile(file string) *Request { r.isSaveResponse = true r.outputFile = file return r } -// SetOutput set the io.Writer that response body will be downloaded to. +// SetOutput set the io.Writer that response Body will be downloaded to. func (r *Request) SetOutput(output io.Writer) *Request { if output == nil { r.client.log.Warnf("nil io.Writer is not allowed in SetOutput") @@ -449,7 +449,7 @@ func (r *Request) appendError(err error) { r.error = multierror.Append(r.error, err) } -var errRetryableWithUnReplayableBody = errors.New("retryable request should not have unreplayable body (io.Reader)") +var errRetryableWithUnReplayableBody = errors.New("retryable request should not have unreplayable Body (io.Reader)") // Do fires http request and return the *Response which is always // not nil, and the error is not nil if error occurs. @@ -462,7 +462,7 @@ func (r *Request) Do() *Response { resp.Err = r.error return resp } - if r.retryOption != nil && r.retryOption.MaxRetries > 0 && r.unReplayableBody != nil { // retryable request should not have unreplayable body + if r.retryOption != nil && r.retryOption.MaxRetries > 0 && r.unReplayableBody != nil { // retryable request should not have unreplayable Body resp := &Response{Request: r} resp.Err = errRetryableWithUnReplayableBody return resp @@ -585,7 +585,7 @@ func (r *Request) Head(url string) (*Response, error) { return r.Send(http.MethodHead, url) } -// SetBody set the request body, accepts string, []byte, io.Reader, map and struct. +// SetBody set the request Body, accepts string, []byte, io.Reader, map and struct. func (r *Request) SetBody(body interface{}) *Request { if body == nil { return r @@ -593,12 +593,12 @@ func (r *Request) SetBody(body interface{}) *Request { switch b := body.(type) { case io.ReadCloser: r.unReplayableBody = b - r.getBody = func() (io.ReadCloser, error) { + r.GetBody = func() (io.ReadCloser, error) { return r.unReplayableBody, nil } case io.Reader: r.unReplayableBody = ioutil.NopCloser(b) - r.getBody = func() (io.ReadCloser, error) { + r.GetBody = func() (io.ReadCloser, error) { return r.unReplayableBody, nil } case []byte: @@ -606,9 +606,9 @@ func (r *Request) SetBody(body interface{}) *Request { case string: r.SetBodyString(b) case func() (io.ReadCloser, error): - r.getBody = b + r.GetBody = b case GetContentFunc: - r.getBody = b + r.GetBody = b default: t := reflect.TypeOf(body) switch t.Kind() { @@ -621,34 +621,34 @@ func (r *Request) SetBody(body interface{}) *Request { return r } -// SetBodyBytes set the request body as []byte. +// SetBodyBytes set the request Body as []byte. func (r *Request) SetBodyBytes(body []byte) *Request { - r.body = body - r.getBody = func() (io.ReadCloser, error) { + r.Body = body + r.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(bytes.NewReader(body)), nil } return r } -// SetBodyString set the request body as string. +// SetBodyString set the request Body as string. func (r *Request) SetBodyString(body string) *Request { return r.SetBodyBytes([]byte(body)) } -// SetBodyJsonString set the request body as string and set Content-Type header +// SetBodyJsonString set the request Body as string and set Content-Type header // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonString(body string) *Request { return r.SetBodyJsonBytes([]byte(body)) } -// SetBodyJsonBytes set the request body as []byte and set Content-Type header +// SetBodyJsonBytes set the request Body as []byte and set Content-Type header // as "application/json; charset=utf-8" func (r *Request) SetBodyJsonBytes(body []byte) *Request { r.SetContentType(header.JsonContentType) return r.SetBodyBytes(body) } -// SetBodyJsonMarshal set the request body that marshaled from object, and +// SetBodyJsonMarshal set the request Body that marshaled from object, and // set Content-Type header as "application/json; charset=utf-8" func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { b, err := r.client.jsonMarshal(v) @@ -659,20 +659,20 @@ func (r *Request) SetBodyJsonMarshal(v interface{}) *Request { return r.SetBodyJsonBytes(b) } -// SetBodyXmlString set the request body as string and set Content-Type header +// SetBodyXmlString set the request Body as string and set Content-Type header // as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlString(body string) *Request { return r.SetBodyXmlBytes([]byte(body)) } -// SetBodyXmlBytes set the request body as []byte and set Content-Type header +// SetBodyXmlBytes set the request Body as []byte and set Content-Type header // as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlBytes(body []byte) *Request { r.SetContentType(header.XmlContentType) return r.SetBodyBytes(body) } -// SetBodyXmlMarshal set the request body that marshaled from object, and +// SetBodyXmlMarshal set the request Body that marshaled from object, and // set Content-Type header as "text/xml; charset=utf-8" func (r *Request) SetBodyXmlMarshal(v interface{}) *Request { b, err := r.client.xmlMarshal(v) @@ -782,7 +782,7 @@ func (r *Request) EnableDumpWithoutBody() *Request { return r.EnableDump() } -// EnableDumpWithoutHeader enables dump only body for the request and response. +// EnableDumpWithoutHeader enables dump only Body for the request and response. func (r *Request) EnableDumpWithoutHeader() *Request { o := r.getDumpOptions() o.RequestHeader = false @@ -806,7 +806,7 @@ func (r *Request) EnableDumpWithoutRequest() *Request { return r.EnableDump() } -// EnableDumpWithoutRequestBody enables dump with request body excluded, +// EnableDumpWithoutRequestBody enables dump with request Body excluded, // can be used in upload request to avoid dump the unreadable binary content. func (r *Request) EnableDumpWithoutRequestBody() *Request { o := r.getDumpOptions() @@ -814,7 +814,7 @@ func (r *Request) EnableDumpWithoutRequestBody() *Request { return r.EnableDump() } -// EnableDumpWithoutResponseBody enables dump with response body excluded, +// EnableDumpWithoutResponseBody enables dump with response Body excluded, // can be used in download request to avoid dump the unreadable binary content. func (r *Request) EnableDumpWithoutResponseBody() *Request { o := r.getDumpOptions() From 9406b764d1a0dcea0673e170c8f65ee43710842f Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 8 Aug 2022 13:36:41 +0800 Subject: [PATCH 565/843] ajust struct field order --- client.go | 17 +++++++++-------- request.go | 4 ++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index a8cb82db..6e440d97 100644 --- a/client.go +++ b/client.go @@ -38,14 +38,15 @@ var defaultClient = C() // Client is the req's http client. type Client struct { - BaseURL string - PathParams map[string]string - QueryParams urlpkg.Values - Headers http.Header - Cookies []*http.Cookie - FormData urlpkg.Values - DebugLog bool - AllowGetMethodPayload bool + BaseURL string + PathParams map[string]string + QueryParams urlpkg.Values + Headers http.Header + Cookies []*http.Cookie + FormData urlpkg.Values + DebugLog bool + AllowGetMethodPayload bool + trace bool disableAutoReadResponse bool commonErrorType reflect.Type diff --git a/request.go b/request.go index 33a97d97..0923f4a8 100644 --- a/request.go +++ b/request.go @@ -31,8 +31,6 @@ type Request struct { Cookies []*http.Cookie Result interface{} Error interface{} - error error - client *Client RawRequest *http.Request StartTime time.Time RetryAttempt int @@ -44,6 +42,8 @@ type Request struct { isMultiPart bool isSaveResponse bool + error error + client *Client uploadCallback UploadCallback uploadCallbackInterval time.Duration downloadCallback DownloadCallback From 5ece883532ed02ad8d9e8156b636b6e8f2688ad4 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 8 Aug 2022 16:46:41 +0800 Subject: [PATCH 566/843] Support SetFormDataAnyType(#148) --- middleware.go | 50 ---------------------------------------------- request.go | 13 ++++++++++++ request_wrapper.go | 6 ++++++ 3 files changed, 19 insertions(+), 50 deletions(-) diff --git a/middleware.go b/middleware.go index 1278b1b1..03c0e58e 100644 --- a/middleware.go +++ b/middleware.go @@ -108,56 +108,6 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload, r *Request) e _, err = io.Copy(pw, content) return err - // uploadedBytes := int64(size) - // progressCallback := func() { - // r.uploadCallback(UploadInfo{ - // ParamName: file.ParamName, - // FileName: file.FileName, - // FileSize: file.FileSize, - // UploadedSize: uploadedBytes, - // }) - // } - // if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { - // lastTime = now - // progressCallback() - // } - // buf := make([]byte, 1024) - // for { - // callback := false - // nr, er := content.Read(buf) - // if nr > 0 { - // nw, ew := pw.Write(buf[:nr]) - // if nw < 0 || nr < nw { - // nw = 0 - // if ew == nil { - // ew = errors.New("invalid write result") - // } - // } - // uploadedBytes += int64(nw) - // if ew != nil { - // return ew - // } - // if nr != nw { - // return io.ErrShortWrite - // } - // if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { - // lastTime = now - // progressCallback() - // callback = true - // } - // } - // if er != nil { - // if er == io.EOF { - // if !callback { - // progressCallback() - // } - // break - // } else { - // return er - // } - // } - // } - return nil } func writeMultiPart(r *Request, w *multipart.Writer, pw *io.PipeWriter) { diff --git a/request.go b/request.go index 0923f4a8..74cf8254 100644 --- a/request.go +++ b/request.go @@ -175,6 +175,19 @@ func (r *Request) SetFormData(data map[string]string) *Request { return r } +// SetFormDataAnyType set the form data from a map, which value could be any type, +// will convert to string automatically. +// It will not been used if request method does not allow payload. +func (r *Request) SetFormDataAnyType(data map[string]interface{}) *Request { + if r.FormData == nil { + r.FormData = urlpkg.Values{} + } + for k, v := range data { + r.FormData.Set(k, fmt.Sprint(v)) + } + return r +} + // SetCookies set http cookies for the request. func (r *Request) SetCookies(cookies ...*http.Cookie) *Request { r.Cookies = append(r.Cookies, cookies...) diff --git a/request_wrapper.go b/request_wrapper.go index a64c2c25..797db7c6 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -26,6 +26,12 @@ func SetFormData(data map[string]string) *Request { return defaultClient.R().SetFormData(data) } +// SetFormDataAnyType is a global wrapper methods which delegated +// to the default client, create a request and SetFormDataAnyType for request. +func SetFormDataAnyType(data map[string]interface{}) *Request { + return defaultClient.R().SetFormDataAnyType(data) +} + // SetCookies is a global wrapper methods which delegated // to the default client, create a request and SetCookies for request. func SetCookies(cookies ...*http.Cookie) *Request { From b9be2874168079ef3027e71a405b09f4aa2bbb9a Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 8 Aug 2022 20:12:45 +0800 Subject: [PATCH 567/843] Request.Do() accepts 0 or 1 context --- client.go | 1 - request.go | 27 ++++++++++++++++++--------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index 6e440d97..0ada46e0 100644 --- a/client.go +++ b/client.go @@ -1236,7 +1236,6 @@ func (c *Client) do(r *Request) (resp *Response, err error) { r.dumpBuffer.Reset() } r.trace = nil - r.ctx = nil resp.body = nil resp.result = nil resp.error = nil diff --git a/request.go b/request.go index 74cf8254..953b948a 100644 --- a/request.go +++ b/request.go @@ -464,21 +464,30 @@ func (r *Request) appendError(err error) { var errRetryableWithUnReplayableBody = errors.New("retryable request should not have unreplayable Body (io.Reader)") -// Do fires http request and return the *Response which is always -// not nil, and the error is not nil if error occurs. -func (r *Request) Do() *Response { +func (r *Request) newErrorResponse(err error) *Response { + resp := &Response{Request: r} + resp.Err = err + return resp +} + +// Do fires http request, 0 or 1 context ia allowed, and returns the *Response which +// is always not nil, and Response.Err is not nil if error occurs. +func (r *Request) Do(ctx ...context.Context) *Response { + if len(ctx) > 0 { + if len(ctx) > 1 { + return r.newErrorResponse(fmt.Errorf("only 0 or 1 context is allowed in Do, and %d are received", len(ctx))) + } + r.ctx = ctx[0] + } + defer func() { r.responseReturnTime = time.Now() }() if r.error != nil { - resp := &Response{Request: r} - resp.Err = r.error - return resp + return r.newErrorResponse(r.error) } if r.retryOption != nil && r.retryOption.MaxRetries > 0 && r.unReplayableBody != nil { // retryable request should not have unreplayable Body - resp := &Response{Request: r} - resp.Err = errRetryableWithUnReplayableBody - return resp + return r.newErrorResponse(errRetryableWithUnReplayableBody) } resp, _ := r.client.do(r) return resp From 0029353779535c154a398bd8f2722cd2eb945083 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 8 Aug 2022 20:51:35 +0800 Subject: [PATCH 568/843] Support 0 or 1 url in Client's Get, Post and so on --- client.go | 37 +++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index 0ada46e0..49e1bc86 100644 --- a/client.go +++ b/client.go @@ -75,51 +75,72 @@ func (c *Client) R() *Request { } } -// Get create a new GET request. -func (c *Client) Get() *Request { +// Get create a new GET request, accepts 0 or 1 url. +func (c *Client) Get(url ...string) *Request { r := c.R() + if len(url) > 0 { + r.RawURL = url[0] + } r.Method = http.MethodGet return r } // Post create a new POST request. -func (c *Client) Post() *Request { +func (c *Client) Post(url ...string) *Request { r := c.R() + if len(url) > 0 { + r.RawURL = url[0] + } r.Method = http.MethodPost return r } // Patch create a new PATCH request. -func (c *Client) Patch() *Request { +func (c *Client) Patch(url ...string) *Request { r := c.R() + if len(url) > 0 { + r.RawURL = url[0] + } r.Method = http.MethodPatch return r } // Delete create a new DELETE request. -func (c *Client) Delete() *Request { +func (c *Client) Delete(url ...string) *Request { r := c.R() + if len(url) > 0 { + r.RawURL = url[0] + } r.Method = http.MethodDelete return r } // Put create a new PUT request. -func (c *Client) Put() *Request { +func (c *Client) Put(url ...string) *Request { r := c.R() + if len(url) > 0 { + r.RawURL = url[0] + } r.Method = http.MethodPut return r } // Head create a new HEAD request. -func (c *Client) Head() *Request { +func (c *Client) Head(url ...string) *Request { r := c.R() + if len(url) > 0 { + r.RawURL = url[0] + } r.Method = http.MethodHead return r } // Options create a new OPTIONS request. -func (c *Client) Options() *Request { +func (c *Client) Options(url ...string) *Request { r := c.R() + if len(url) > 0 { + r.RawURL = url[0] + } r.Method = http.MethodOptions return r } From d0ecd074720df8295cec71a7746db9a7198a27b0 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 8 Aug 2022 20:52:04 +0800 Subject: [PATCH 569/843] Support Request.SetQueryParamsAnyType --- request.go | 9 +++++++++ request_wrapper.go | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/request.go b/request.go index 953b948a..13c7a549 100644 --- a/request.go +++ b/request.go @@ -423,6 +423,15 @@ func (r *Request) SetQueryParams(params map[string]string) *Request { return r } +// SetQueryParamsAnyType set URL query parameters from a map for the request. +// The value of map is any type, will be convert to string automatically. +func (r *Request) SetQueryParamsAnyType(params map[string]interface{}) *Request { + for k, v := range params { + r.SetQueryParam(k, fmt.Sprint(v)) + } + return r +} + // SetQueryParam set an URL query parameter for the request. func (r *Request) SetQueryParam(key, value string) *Request { if r.QueryParams == nil { diff --git a/request_wrapper.go b/request_wrapper.go index 797db7c6..d4a8d24d 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -128,6 +128,12 @@ func SetQueryParams(params map[string]string) *Request { return defaultClient.R().SetQueryParams(params) } +// SetQueryParamsAnyType is a global wrapper methods which delegated +// to the default client, create a request and SetQueryParamsAnyType for request. +func SetQueryParamsAnyType(params map[string]interface{}) *Request { + return defaultClient.R().SetQueryParamsAnyType(params) +} + // SetQueryParam is a global wrapper methods which delegated // to the default client, create a request and SetQueryParam for request. func SetQueryParam(key, value string) *Request { From bc476a0a7e3afd402f930878745e27be5eb79725 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 9 Aug 2022 14:01:55 +0800 Subject: [PATCH 570/843] add requtil: ConvertHeaderToString --- pkg/util/util.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 pkg/util/util.go diff --git a/pkg/util/util.go b/pkg/util/util.go new file mode 100644 index 00000000..589e8b87 --- /dev/null +++ b/pkg/util/util.go @@ -0,0 +1,13 @@ +package requtil + +import ( + "bytes" + "net/http" +) + +// ConvertHeaderToString converts http header to a string. +func ConvertHeaderToString(h http.Header) string { + buf := new(bytes.Buffer) + h.Write(buf) + return buf.String() +} From cbf1b3340df821d1ea020434a64c4426c9a0046e Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 9 Aug 2022 14:05:09 +0800 Subject: [PATCH 571/843] update go mod for examples --- examples/find-popular-repo/go.mod | 25 +- examples/find-popular-repo/go.sum | 294 ++++++++++++++++++++ examples/upload/uploadclient/go.mod | 25 +- examples/upload/uploadclient/go.sum | 294 ++++++++++++++++++++ examples/upload/uploadserver/go.mod | 26 +- examples/upload/uploadserver/go.sum | 74 +++++ examples/uploadcallback/uploadclient/go.mod | 25 +- examples/uploadcallback/uploadclient/go.sum | 294 ++++++++++++++++++++ examples/uploadcallback/uploadserver/go.mod | 26 +- examples/uploadcallback/uploadserver/go.sum | 73 +++++ 10 files changed, 1149 insertions(+), 7 deletions(-) diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index 19199729..b042eba8 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -1,7 +1,30 @@ module find-popular-repo -go 1.13 +go 1.18 replace github.com/imroc/req/v3 => ../../ require github.com/imroc/req/v3 v3.0.0 + +require ( + github.com/cheekybits/genny v1.0.0 // indirect + github.com/fsnotify/fsnotify v1.5.4 // indirect + github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/lucas-clemente/quic-go v0.28.1 // indirect + github.com/marten-seemann/qpack v0.2.1 // indirect + github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect + github.com/marten-seemann/qtls-go1-17 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect + github.com/nxadm/tail v1.4.8 // indirect + github.com/onsi/ginkgo v1.16.5 // indirect + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect + golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect + golang.org/x/net v0.0.0-20220809012201-f428fae20770 // indirect + golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 // indirect + golang.org/x/text v0.3.7 // indirect + golang.org/x/tools v0.1.12 // indirect + gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect +) diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index 59f8d4e3..f5d21907 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -1,13 +1,307 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= +dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= +dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= +dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= +dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= +git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= +github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= +github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= +github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= +github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= +github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= +github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lucas-clemente/quic-go v0.28.1 h1:Uo0lvVxWg5la9gflIF9lwa39ONq85Xq2D91YNEIslzU= +github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= +github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= +github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= +github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= +github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= +github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= +github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= +github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= +github.com/marten-seemann/qtls-go1-18 v0.1.2 h1:JH6jmzbduz0ITVQ7ShevK10Av5+jBEKAHMntXmIV7kM= +github.com/marten-seemann/qtls-go1-18 v0.1.2/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/marten-seemann/qtls-go1-19 v0.1.0 h1:rLFKD/9mp/uq1SYGYuVZhm83wkmU95pK5df3GufyYYU= +github.com/marten-seemann/qtls-go1-19 v0.1.0/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= +github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= +github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= +github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= +github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= +github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= +github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= +github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= +github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= +github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= +github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= +github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= +github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= +github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= +github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= +github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= +github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= +github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= +github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= +github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= +github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= +github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= +github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= +github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= +github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= +github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= +github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= +golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= +golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= +golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= +google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= +sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod index c95a1ec4..cfc07949 100644 --- a/examples/upload/uploadclient/go.mod +++ b/examples/upload/uploadclient/go.mod @@ -1,7 +1,30 @@ module uploadclient -go 1.13 +go 1.18 replace github.com/imroc/req/v3 => ../../../ require github.com/imroc/req/v3 v3.0.0 + +require ( + github.com/cheekybits/genny v1.0.0 // indirect + github.com/fsnotify/fsnotify v1.5.4 // indirect + github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/lucas-clemente/quic-go v0.28.1 // indirect + github.com/marten-seemann/qpack v0.2.1 // indirect + github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect + github.com/marten-seemann/qtls-go1-17 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect + github.com/nxadm/tail v1.4.8 // indirect + github.com/onsi/ginkgo v1.16.5 // indirect + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect + golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect + golang.org/x/net v0.0.0-20220809012201-f428fae20770 // indirect + golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 // indirect + golang.org/x/text v0.3.7 // indirect + golang.org/x/tools v0.1.12 // indirect + gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect +) diff --git a/examples/upload/uploadclient/go.sum b/examples/upload/uploadclient/go.sum index 59f8d4e3..f5d21907 100644 --- a/examples/upload/uploadclient/go.sum +++ b/examples/upload/uploadclient/go.sum @@ -1,13 +1,307 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= +dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= +dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= +dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= +dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= +git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= +github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= +github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= +github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= +github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= +github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= +github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lucas-clemente/quic-go v0.28.1 h1:Uo0lvVxWg5la9gflIF9lwa39ONq85Xq2D91YNEIslzU= +github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= +github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= +github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= +github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= +github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= +github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= +github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= +github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= +github.com/marten-seemann/qtls-go1-18 v0.1.2 h1:JH6jmzbduz0ITVQ7ShevK10Av5+jBEKAHMntXmIV7kM= +github.com/marten-seemann/qtls-go1-18 v0.1.2/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/marten-seemann/qtls-go1-19 v0.1.0 h1:rLFKD/9mp/uq1SYGYuVZhm83wkmU95pK5df3GufyYYU= +github.com/marten-seemann/qtls-go1-19 v0.1.0/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= +github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= +github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= +github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= +github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= +github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= +github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= +github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= +github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= +github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= +github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= +github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= +github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= +github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= +github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= +github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= +github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= +github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= +github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= +github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= +github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= +github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= +github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= +github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= +github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= +github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= +github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= +golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= +golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= +golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= +google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= +sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= diff --git a/examples/upload/uploadserver/go.mod b/examples/upload/uploadserver/go.mod index 94acf8b9..9daf3713 100644 --- a/examples/upload/uploadserver/go.mod +++ b/examples/upload/uploadserver/go.mod @@ -1,5 +1,27 @@ module uploadserver -go 1.13 +go 1.18 -require github.com/gin-gonic/gin v1.7.7 +require github.com/gin-gonic/gin v1.8.1 + +require ( + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.0 // indirect + github.com/go-playground/universal-translator v0.18.0 // indirect + github.com/go-playground/validator/v10 v10.11.0 // indirect + github.com/goccy/go-json v0.9.10 // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/leodido/go-urn v1.2.1 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.2 // indirect + github.com/ugorji/go/codec v1.2.7 // indirect + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect + golang.org/x/net v0.0.0-20220809012201-f428fae20770 // indirect + golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 // indirect + golang.org/x/text v0.3.7 // indirect + google.golang.org/protobuf v1.28.1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect +) diff --git a/examples/upload/uploadserver/go.sum b/examples/upload/uploadserver/go.sum index 5ee9be12..a10a0570 100644 --- a/examples/upload/uploadserver/go.sum +++ b/examples/upload/uploadserver/go.sum @@ -1,3 +1,4 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -5,50 +6,123 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= +github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8= +github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= +github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= +github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/go-playground/validator/v10 v10.11.0 h1:0W+xRM511GY47Yy3bZUbJVitCNg2BOGlCyvTqsp/xIw= +github.com/go-playground/validator/v10 v10.11.0/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= +github.com/goccy/go-json v0.9.10 h1:hCeNmprSNLB8B8vQKWl6DpuH0t60oEs+TAk9a7CScKc= +github.com/goccy/go-json v0.9.10/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= +github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.0.2 h1:+jQXlF3scKIcSEKkdHzXhCTDLPFi5r1wnK6yPS+49Gw= +github.com/pelletier/go-toml/v2 v2.0.2/go.mod h1:MovirKjgVRESsAvNZlAjtFwV867yGuwRkXbG66OzopI= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= +github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= +github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= +golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42 h1:vEOn+mP2zCOVzKckCZy6YsCtDblrpj/w7B9nxGNELpg= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= +google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/uploadcallback/uploadclient/go.mod b/examples/uploadcallback/uploadclient/go.mod index c95a1ec4..cfc07949 100644 --- a/examples/uploadcallback/uploadclient/go.mod +++ b/examples/uploadcallback/uploadclient/go.mod @@ -1,7 +1,30 @@ module uploadclient -go 1.13 +go 1.18 replace github.com/imroc/req/v3 => ../../../ require github.com/imroc/req/v3 v3.0.0 + +require ( + github.com/cheekybits/genny v1.0.0 // indirect + github.com/fsnotify/fsnotify v1.5.4 // indirect + github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/lucas-clemente/quic-go v0.28.1 // indirect + github.com/marten-seemann/qpack v0.2.1 // indirect + github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect + github.com/marten-seemann/qtls-go1-17 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect + github.com/nxadm/tail v1.4.8 // indirect + github.com/onsi/ginkgo v1.16.5 // indirect + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect + golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect + golang.org/x/net v0.0.0-20220809012201-f428fae20770 // indirect + golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 // indirect + golang.org/x/text v0.3.7 // indirect + golang.org/x/tools v0.1.12 // indirect + gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect +) diff --git a/examples/uploadcallback/uploadclient/go.sum b/examples/uploadcallback/uploadclient/go.sum index 59f8d4e3..f5d21907 100644 --- a/examples/uploadcallback/uploadclient/go.sum +++ b/examples/uploadcallback/uploadclient/go.sum @@ -1,13 +1,307 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= +dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= +dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= +dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= +dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= +git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= +github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= +github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= +github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= +github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= +github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= +github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lucas-clemente/quic-go v0.28.1 h1:Uo0lvVxWg5la9gflIF9lwa39ONq85Xq2D91YNEIslzU= +github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= +github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= +github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= +github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= +github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= +github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= +github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= +github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= +github.com/marten-seemann/qtls-go1-18 v0.1.2 h1:JH6jmzbduz0ITVQ7ShevK10Av5+jBEKAHMntXmIV7kM= +github.com/marten-seemann/qtls-go1-18 v0.1.2/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/marten-seemann/qtls-go1-19 v0.1.0 h1:rLFKD/9mp/uq1SYGYuVZhm83wkmU95pK5df3GufyYYU= +github.com/marten-seemann/qtls-go1-19 v0.1.0/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= +github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= +github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= +github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= +github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= +github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= +github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= +github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= +github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= +github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= +github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= +github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= +github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= +github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= +github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= +github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= +github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= +github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= +github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= +github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= +github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= +github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= +github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= +github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= +github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= +github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= +github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= +golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= +golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= +golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= +google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= +sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= diff --git a/examples/uploadcallback/uploadserver/go.mod b/examples/uploadcallback/uploadserver/go.mod index 94acf8b9..9daf3713 100644 --- a/examples/uploadcallback/uploadserver/go.mod +++ b/examples/uploadcallback/uploadserver/go.mod @@ -1,5 +1,27 @@ module uploadserver -go 1.13 +go 1.18 -require github.com/gin-gonic/gin v1.7.7 +require github.com/gin-gonic/gin v1.8.1 + +require ( + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.0 // indirect + github.com/go-playground/universal-translator v0.18.0 // indirect + github.com/go-playground/validator/v10 v10.11.0 // indirect + github.com/goccy/go-json v0.9.10 // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/leodido/go-urn v1.2.1 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.2 // indirect + github.com/ugorji/go/codec v1.2.7 // indirect + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect + golang.org/x/net v0.0.0-20220809012201-f428fae20770 // indirect + golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 // indirect + golang.org/x/text v0.3.7 // indirect + google.golang.org/protobuf v1.28.1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect +) diff --git a/examples/uploadcallback/uploadserver/go.sum b/examples/uploadcallback/uploadserver/go.sum index 5ee9be12..ac20466c 100644 --- a/examples/uploadcallback/uploadserver/go.sum +++ b/examples/uploadcallback/uploadserver/go.sum @@ -1,3 +1,4 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -5,50 +6,122 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= +github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8= +github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= +github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= +github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/go-playground/validator/v10 v10.11.0 h1:0W+xRM511GY47Yy3bZUbJVitCNg2BOGlCyvTqsp/xIw= +github.com/go-playground/validator/v10 v10.11.0/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= +github.com/goccy/go-json v0.9.10 h1:hCeNmprSNLB8B8vQKWl6DpuH0t60oEs+TAk9a7CScKc= +github.com/goccy/go-json v0.9.10/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= +github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.0.2 h1:+jQXlF3scKIcSEKkdHzXhCTDLPFi5r1wnK6yPS+49Gw= +github.com/pelletier/go-toml/v2 v2.0.2/go.mod h1:MovirKjgVRESsAvNZlAjtFwV867yGuwRkXbG66OzopI= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= +github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= +github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= +golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42 h1:vEOn+mP2zCOVzKckCZy6YsCtDblrpj/w7B9nxGNELpg= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= +google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 85a3582acc3d3a6de5d24455359053dac8e34049 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 9 Aug 2022 14:52:05 +0800 Subject: [PATCH 572/843] add HeaderToString for req and resp --- pkg/util/util.go | 13 ------------- req.go | 11 +++++++++++ request.go | 5 +++++ response.go | 8 ++++++++ 4 files changed, 24 insertions(+), 13 deletions(-) delete mode 100644 pkg/util/util.go diff --git a/pkg/util/util.go b/pkg/util/util.go deleted file mode 100644 index 589e8b87..00000000 --- a/pkg/util/util.go +++ /dev/null @@ -1,13 +0,0 @@ -package requtil - -import ( - "bytes" - "net/http" -) - -// ConvertHeaderToString converts http header to a string. -func ConvertHeaderToString(h http.Header) string { - buf := new(bytes.Buffer) - h.Write(buf) - return buf.String() -} diff --git a/req.go b/req.go index 507a1ce9..eef943bb 100644 --- a/req.go +++ b/req.go @@ -1,6 +1,7 @@ package req import ( + "bytes" "fmt" "net/http" "net/url" @@ -145,3 +146,13 @@ func cloneMap(h map[string]string) map[string]string { } return m } + +// convertHeaderToString converts http header to a string. +func convertHeaderToString(h http.Header) string { + if h == nil { + return "" + } + buf := new(bytes.Buffer) + h.Write(buf) + return buf.String() +} diff --git a/request.go b/request.go index 13c7a549..591dd89e 100644 --- a/request.go +++ b/request.go @@ -143,6 +143,11 @@ func (r *Request) TraceInfo() TraceInfo { return ti } +// HeaderToString get all header as string. +func (r *Request) HeaderToString() string { + return convertHeaderToString(r.Headers) +} + // SetURL set the url for request. func (r *Request) SetURL(url string) *Request { r.RawURL = url diff --git a/response.go b/response.go index 5ccc650b..ffbfd30a 100644 --- a/response.go +++ b/response.go @@ -212,3 +212,11 @@ func (r *Response) GetHeaderValues(key string) []string { } return r.Header.Values(key) } + +// HeaderToString get all header as string. +func (r *Response) HeaderToString() string { + if r.Response == nil { + return "" + } + return convertHeaderToString(r.Header) +} From ac9ac4d4d70b0ae2dd4990883fe97558b842a0ea Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 9 Aug 2022 14:58:12 +0800 Subject: [PATCH 573/843] ensure err in client's RoundTrip --- client.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client.go b/client.go index 49e1bc86..40257b94 100644 --- a/client.go +++ b/client.go @@ -1191,6 +1191,9 @@ func (c *Client) RoundTrip(r *Request) (resp *Response, err error) { return } } + if resp.Err != nil { // in case that set error in middleware and not return error. + err = resp.Err + } return } From 0ca4e58adab140ee1be0cd36a89dfe7c6bedb904 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 9 Aug 2022 19:36:42 +0800 Subject: [PATCH 574/843] add opentelemetry-jaeger-tracing example --- .../opentelemetry-jaeger-tracing/README.md | 30 ++ .../github/github.go | 188 +++++++++++ examples/opentelemetry-jaeger-tracing/go.mod | 38 +++ examples/opentelemetry-jaeger-tracing/go.sum | 318 ++++++++++++++++++ examples/opentelemetry-jaeger-tracing/main.go | 153 +++++++++ 5 files changed, 727 insertions(+) create mode 100644 examples/opentelemetry-jaeger-tracing/README.md create mode 100644 examples/opentelemetry-jaeger-tracing/github/github.go create mode 100644 examples/opentelemetry-jaeger-tracing/go.mod create mode 100644 examples/opentelemetry-jaeger-tracing/go.sum create mode 100644 examples/opentelemetry-jaeger-tracing/main.go diff --git a/examples/opentelemetry-jaeger-tracing/README.md b/examples/opentelemetry-jaeger-tracing/README.md new file mode 100644 index 00000000..3dfa4a56 --- /dev/null +++ b/examples/opentelemetry-jaeger-tracing/README.md @@ -0,0 +1,30 @@ +# opentelemetry-jaeger-tracing + +This is a runnable example of req, which uses the built-in tiny github sdk built on req to query and display the information of the specified user. + +Best of all, it integrates seamlessly with jaeger tracing and is very easy to extend. + +## How to run + +First, use `docker` or `podman` to start a test jeager container (see jeager official doc: [ Getting Started](https://www.jaegertracing.io/docs/1.37/getting-started/#all-in-one)). + +Then, run example: + +```bash +go run . +``` +```txt +Please give a github username: +``` + +Input a github username, e.g. `imroc`: + +```bash +$ go run . +Please give a github username: imroc +The moust popular repo of roc (https://imroc.cc) is req, which have 2500 stars +``` + +Then enter the Jaeger UI with browser (`http://127.0.0.1:16686/`), checkout the tracing details. + +Run example again, try to input some username that doesn't exist, and check the error log in Jaeger UI. diff --git a/examples/opentelemetry-jaeger-tracing/github/github.go b/examples/opentelemetry-jaeger-tracing/github/github.go new file mode 100644 index 00000000..7bf77fcb --- /dev/null +++ b/examples/opentelemetry-jaeger-tracing/github/github.go @@ -0,0 +1,188 @@ +package github + +import ( + "context" + "fmt" + "github.com/imroc/req/v3" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "strconv" + "strings" +) + +// Client is the go client for GitHub API. +type Client struct { + *req.Client +} + +// NewClient create a GitHub client. +func NewClient() *Client { + c := req.C(). + // All GitHub API requests need this header. + SetCommonHeader("Accept", "application/vnd.github.v3+json"). + // All GitHub API requests use the same base URL. + SetBaseURL("https://api.github.com"). + // EnableDump at the request level in request middleware which dump content into + // memory (not print to stdout), we can record dump content only when unexpected + // exception occurs, it is helpful to troubleshoot problems in production. + OnBeforeRequest(func(c *req.Client, r *req.Request) error { + if r.RetryAttempt > 0 { // Ignore on retry. + return nil + } + r.EnableDump() + return nil + }). + // Unmarshal response body into an APIError struct when status >= 400. + SetCommonError(&APIError{}). + // Handle common exceptions in response middleware. + OnAfterResponse(func(client *req.Client, resp *req.Response) error { + if resp.Err != nil { // Ignore if there is an underlying error. + return nil + } + if err, ok := resp.Error().(*APIError); ok { // Server returns an error message. + // Convert it to human-readable go error. + resp.Err = err + return nil + } + // Corner case: neither an error response nor a success response, + // dump content to help troubleshoot. + if !resp.IsSuccess() { + return fmt.Errorf("bad response, raw dump:\n%s", resp.Dump()) + } + return nil + }) + + return &Client{ + Client: c, + } +} + +// LoginWithToken login with GitHub personal access token. +// GitHub API doc: https://docs.github.com/en/rest/overview/other-authentication-methods#authenticating-for-saml-sso +func (c *Client) LoginWithToken(token string) *Client { + c.SetCommonHeader("Authorization", "token "+token) + return c +} + +// APIError represents the error message that GitHub API returns. +// GitHub API doc: https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors +type APIError struct { + Message string `json:"message"` + DocumentationUrl string `json:"documentation_url,omitempty"` + Errors []struct { + Resource string `json:"resource"` + Field string `json:"field"` + Code string `json:"code"` + } `json:"errors,omitempty"` +} + +// Error convert APIError to a human readable error and return. +func (e *APIError) Error() string { + msg := fmt.Sprintf("API error: %s", e.Message) + if e.DocumentationUrl != "" { + return fmt.Sprintf("%s (see doc %s)", msg, e.DocumentationUrl) + } + if len(e.Errors) == 0 { + return msg + } + errs := []string{} + for _, err := range e.Errors { + errs = append(errs, fmt.Sprintf("resource:%s field:%s code:%s", err.Resource, err.Field, err.Code)) + } + return fmt.Sprintf("%s (%s)", msg, strings.Join(errs, " | ")) +} + +// SetDebug enable debug if set to true, disable debug if set to false. +func (c *Client) SetDebug(enable bool) *Client { + if enable { + c.EnableDebugLog() + c.EnableDumpAll() + } else { + c.DisableDebugLog() + c.DisableDumpAll() + } + return c +} + +type apiNameType int + +const apiNameKey apiNameType = iota + +// SetTracer set the tracer of opentelemetry. +func (c *Client) SetTracer(tracer trace.Tracer) { + c.WrapRoundTripFunc(func(rt req.RoundTripper) req.RoundTripFunc { + return func(r *req.Request) (resp *req.Response, err error) { + ctx := r.Context() + spanName := ctx.Value(apiNameKey).(string) + _, span := tracer.Start(r.Context(), spanName) + defer span.End() + span.SetAttributes( + attribute.String("http.url", r.URL.String()), + attribute.String("http.method", r.Method), + attribute.String("http.req.header", r.HeaderToString()), + ) + if len(r.Body) > 0 { + span.SetAttributes( + attribute.String("http.req.body", string(r.Body)), + ) + } + resp, err = rt.RoundTrip(r) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return + } + span.SetAttributes( + attribute.Int("http.status_code", resp.StatusCode), + attribute.String("http.resp.header", resp.HeaderToString()), + attribute.String("resp.resp.body", resp.String()), + ) + return + } + }) +} + +func withAPIName(ctx context.Context, name string) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, apiNameKey, name) +} + +type UserProfile struct { + Name string `json:"name"` + Blog string `json:"blog"` +} + +// GetUserProfile returns the user profile for the specified user. +// Github API doc: https://docs.github.com/en/rest/users/users#get-a-user +func (c *Client) GetUserProfile(ctx context.Context, username string) (user *UserProfile, err error) { + err = c.Get("/users/{username}"). + SetPathParam("username", username). + Do(withAPIName(ctx, "GetUserProfile")). + Into(&user) + return +} + +type Repo struct { + Name string `json:"name"` + Star int `json:"stargazers_count"` +} + +// ListUserRepo returns a list of public repositories for the specified user +// Github API doc: https://docs.github.com/en/rest/repos/repos#list-repositories-for-a-user +func (c *Client) ListUserRepo(ctx context.Context, username string, page int) (repos []*Repo, err error) { + err = c.Get("/users/{username}/repos"). + SetQueryParamsAnyType(map[string]any{ + "type": "owner", + "page": strconv.Itoa(page), + "per_page": "100", + "sort": "updated", + "direction": "desc", + }). + SetPathParam("username", username). + Do(withAPIName(ctx, "ListUserRepo")). + Into(&repos) + return +} diff --git a/examples/opentelemetry-jaeger-tracing/go.mod b/examples/opentelemetry-jaeger-tracing/go.mod new file mode 100644 index 00000000..888ea567 --- /dev/null +++ b/examples/opentelemetry-jaeger-tracing/go.mod @@ -0,0 +1,38 @@ +module opentelemetry-jaeger-tracing + +go 1.18 + +replace github.com/imroc/req/v3 => ../../ + +require ( + github.com/imroc/req/v3 v3.0.0 + go.opentelemetry.io/otel v1.9.0 + go.opentelemetry.io/otel/exporters/jaeger v1.9.0 + go.opentelemetry.io/otel/sdk v1.9.0 + go.opentelemetry.io/otel/trace v1.9.0 +) + +require ( + github.com/cheekybits/genny v1.0.0 // indirect + github.com/fsnotify/fsnotify v1.5.4 // indirect + github.com/go-logr/logr v1.2.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/lucas-clemente/quic-go v0.28.1 // indirect + github.com/marten-seemann/qpack v0.2.1 // indirect + github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect + github.com/marten-seemann/qtls-go1-17 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect + github.com/nxadm/tail v1.4.8 // indirect + github.com/onsi/ginkgo v1.16.5 // indirect + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect + golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect + golang.org/x/net v0.0.0-20220809012201-f428fae20770 // indirect + golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 // indirect + golang.org/x/text v0.3.7 // indirect + golang.org/x/tools v0.1.12 // indirect + gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect +) diff --git a/examples/opentelemetry-jaeger-tracing/go.sum b/examples/opentelemetry-jaeger-tracing/go.sum new file mode 100644 index 00000000..676e2148 --- /dev/null +++ b/examples/opentelemetry-jaeger-tracing/go.sum @@ -0,0 +1,318 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= +dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= +dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= +dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= +dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= +git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= +github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= +github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= +github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= +github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= +github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= +github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lucas-clemente/quic-go v0.28.1 h1:Uo0lvVxWg5la9gflIF9lwa39ONq85Xq2D91YNEIslzU= +github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= +github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= +github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= +github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= +github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= +github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= +github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= +github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= +github.com/marten-seemann/qtls-go1-18 v0.1.2 h1:JH6jmzbduz0ITVQ7ShevK10Av5+jBEKAHMntXmIV7kM= +github.com/marten-seemann/qtls-go1-18 v0.1.2/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/marten-seemann/qtls-go1-19 v0.1.0 h1:rLFKD/9mp/uq1SYGYuVZhm83wkmU95pK5df3GufyYYU= +github.com/marten-seemann/qtls-go1-19 v0.1.0/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= +github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= +github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= +github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= +github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= +github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= +github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= +github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= +github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= +github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= +github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= +github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= +github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= +github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= +github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= +github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= +github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= +github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= +github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= +github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= +github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= +github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= +github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= +github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= +github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= +github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= +github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go.opentelemetry.io/otel v1.9.0 h1:8WZNQFIB2a71LnANS9JeyidJKKGOOremcUtb/OtHISw= +go.opentelemetry.io/otel v1.9.0/go.mod h1:np4EoPGzoPs3O67xUVNoPPcmSvsfOxNlNA4F4AC+0Eo= +go.opentelemetry.io/otel/exporters/jaeger v1.9.0 h1:gAEgEVGDWwFjcis9jJTOJqZNxDzoZfR12WNIxr7g9Ww= +go.opentelemetry.io/otel/exporters/jaeger v1.9.0/go.mod h1:hquezOLVAybNW6vanIxkdLXTXvzlj2Vn3wevSP15RYs= +go.opentelemetry.io/otel/sdk v1.9.0 h1:LNXp1vrr83fNXTHgU8eO89mhzxb/bbWAsHG6fNf3qWo= +go.opentelemetry.io/otel/sdk v1.9.0/go.mod h1:AEZc8nt5bd2F7BC24J5R0mrjYnpEgYHyTcM/vrSple4= +go.opentelemetry.io/otel/trace v1.9.0 h1:oZaCNJUjWcg60VXWee8lJKlqhPbXAPB51URuR47pQYc= +go.opentelemetry.io/otel/trace v1.9.0/go.mod h1:2737Q0MuG8q1uILYm2YYVkAyLtOofiTNGg6VODnOiPo= +go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= +golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= +golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= +golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= +google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= +sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= diff --git a/examples/opentelemetry-jaeger-tracing/main.go b/examples/opentelemetry-jaeger-tracing/main.go new file mode 100644 index 00000000..8964ef69 --- /dev/null +++ b/examples/opentelemetry-jaeger-tracing/main.go @@ -0,0 +1,153 @@ +package main + +import ( + "context" + "fmt" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/exporters/jaeger" + "go.opentelemetry.io/otel/sdk/resource" + "go.opentelemetry.io/otel/sdk/trace" + semconv "go.opentelemetry.io/otel/semconv/v1.12.0" + "log" + "opentelemetry-jaeger-tracing/github" + "os" +) + +const serviceName = "github-query" + +var githubClient *github.Client + +func findMostPopularRepo(ctx context.Context, username string) (repo *github.Repo, err error) { + ctx, span := otel.Tracer("query").Start(ctx, "findMostPopularRepo") + defer span.End() + + for page := 1; ; page++ { + repos, e := githubClient.ListUserRepo(ctx, username, page) + if e != nil { + return + } + if len(repos) == 0 { + break + } + if repo == nil { + repo = repos[0] + } + for _, rp := range repos[1:] { + if rp.Star >= repo.Star { + repo = rp + } + } + if len(repos) == 100 { + continue + } + break + } + + if repo == nil { + err = fmt.Errorf("no repo found for %s", username) + } + return +} + +// QueryUser queries information for specified GitHub user, and display a +// brief introduction which includes name, blog, and the most popular repo. +func QueryUser(username string) error { + ctx, span := otel.Tracer("query").Start(context.Background(), "QueryUser") + defer span.End() + + span.SetAttributes( + attribute.String("query.username", username), + ) + profile, err := githubClient.GetUserProfile(ctx, username) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return err + } + span.SetAttributes( + attribute.String("query.name", profile.Name), + attribute.String("result.blog", profile.Blog), + ) + repo, err := findMostPopularRepo(ctx, username) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return err + } + span.SetAttributes( + attribute.String("popular.repo.name", repo.Name), + attribute.Int("popular.repo.star", repo.Star), + ) + fmt.Printf("the moust popular repo of %s (%s) is %s, which have %d stars\n", profile.Name, profile.Blog, repo.Name, repo.Star) + return nil +} + +func traceProvider() (*trace.TracerProvider, error) { + // Create the Jaeger exporter + ep := os.Getenv("JAEGER_ENDPOINT") + if ep == "" { + ep = "http://localhost:14268/api/traces" + } + exp, err := jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(ep))) + if err != nil { + return nil, err + } + + // Record information about this application in a Resource. + res, _ := resource.Merge( + resource.Default(), + resource.NewWithAttributes( + semconv.SchemaURL, + semconv.ServiceNameKey.String(serviceName), + semconv.ServiceVersionKey.String("v0.1.0"), + attribute.String("environment", "test"), + ), + ) + + // Create the TraceProvider. + tp := trace.NewTracerProvider( + // Always be sure to batch in production. + trace.WithBatcher(exp), + // Record information about this application in a Resource. + trace.WithResource(res), + trace.WithSampler(trace.AlwaysSample()), + ) + return tp, nil +} + +func main() { + tp, err := traceProvider() + if err != nil { + panic(err) + } + defer func() { + if err := tp.Shutdown(context.Background()); err != nil { + log.Fatal(err) + } + }() + otel.SetTracerProvider(tp) + + githubClient = github.NewClient() + if os.Getenv("DEBUG") == "on" { + githubClient.SetDebug(true) + } + if token := os.Getenv("GITHUB_TOKEN"); token != "" { + githubClient.LoginWithToken(token) + } + githubClient.SetTracer(otel.Tracer("github")) + + for { + var name string + fmt.Printf("Please give a github username: ") + _, err := fmt.Fscanf(os.Stdin, "%s\n", &name) + if err != nil { + panic(err) + } + err = QueryUser(name) + if err != nil { + fmt.Println(err.Error()) + } + } +} From 0cb6666b4e25ae644d84fa840c9ec48ad814a6a8 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 10 Aug 2022 08:50:37 +0800 Subject: [PATCH 575/843] update example: opentelemetry-jaeger-tracing --- .../github/github.go | 103 +++++++++--------- 1 file changed, 53 insertions(+), 50 deletions(-) diff --git a/examples/opentelemetry-jaeger-tracing/github/github.go b/examples/opentelemetry-jaeger-tracing/github/github.go index 7bf77fcb..e223296f 100644 --- a/examples/opentelemetry-jaeger-tracing/github/github.go +++ b/examples/opentelemetry-jaeger-tracing/github/github.go @@ -16,6 +16,34 @@ type Client struct { *req.Client } +// APIError represents the error message that GitHub API returns. +// GitHub API doc: https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors +type APIError struct { + Message string `json:"message"` + DocumentationUrl string `json:"documentation_url,omitempty"` + Errors []struct { + Resource string `json:"resource"` + Field string `json:"field"` + Code string `json:"code"` + } `json:"errors,omitempty"` +} + +// Error convert APIError to a human readable error and return. +func (e *APIError) Error() string { + msg := fmt.Sprintf("API error: %s", e.Message) + if e.DocumentationUrl != "" { + return fmt.Sprintf("%s (see doc %s)", msg, e.DocumentationUrl) + } + if len(e.Errors) == 0 { + return msg + } + errs := []string{} + for _, err := range e.Errors { + errs = append(errs, fmt.Sprintf("resource:%s field:%s code:%s", err.Resource, err.Field, err.Code)) + } + return fmt.Sprintf("%s (%s)", msg, strings.Join(errs, " | ")) +} + // NewClient create a GitHub client. func NewClient() *Client { c := req.C(). @@ -37,8 +65,11 @@ func NewClient() *Client { SetCommonError(&APIError{}). // Handle common exceptions in response middleware. OnAfterResponse(func(client *req.Client, resp *req.Response) error { - if resp.Err != nil { // Ignore if there is an underlying error. - return nil + if resp.Err != nil { + if dump := resp.Dump(); dump != "" { // Append dump content to original underlying error to help troubleshoot. + resp.Err = fmt.Errorf("%s\nraw dump:\n%s", resp.Err.Error(), resp.Dump()) + } + return nil // Skip the following logic if there is an underlying error. } if err, ok := resp.Error().(*APIError); ok { // Server returns an error message. // Convert it to human-readable go error. @@ -48,7 +79,7 @@ func NewClient() *Client { // Corner case: neither an error response nor a success response, // dump content to help troubleshoot. if !resp.IsSuccess() { - return fmt.Errorf("bad response, raw dump:\n%s", resp.Dump()) + resp.Err = fmt.Errorf("bad response, raw dump:\n%s", resp.Dump()) } return nil }) @@ -58,53 +89,6 @@ func NewClient() *Client { } } -// LoginWithToken login with GitHub personal access token. -// GitHub API doc: https://docs.github.com/en/rest/overview/other-authentication-methods#authenticating-for-saml-sso -func (c *Client) LoginWithToken(token string) *Client { - c.SetCommonHeader("Authorization", "token "+token) - return c -} - -// APIError represents the error message that GitHub API returns. -// GitHub API doc: https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors -type APIError struct { - Message string `json:"message"` - DocumentationUrl string `json:"documentation_url,omitempty"` - Errors []struct { - Resource string `json:"resource"` - Field string `json:"field"` - Code string `json:"code"` - } `json:"errors,omitempty"` -} - -// Error convert APIError to a human readable error and return. -func (e *APIError) Error() string { - msg := fmt.Sprintf("API error: %s", e.Message) - if e.DocumentationUrl != "" { - return fmt.Sprintf("%s (see doc %s)", msg, e.DocumentationUrl) - } - if len(e.Errors) == 0 { - return msg - } - errs := []string{} - for _, err := range e.Errors { - errs = append(errs, fmt.Sprintf("resource:%s field:%s code:%s", err.Resource, err.Field, err.Code)) - } - return fmt.Sprintf("%s (%s)", msg, strings.Join(errs, " | ")) -} - -// SetDebug enable debug if set to true, disable debug if set to false. -func (c *Client) SetDebug(enable bool) *Client { - if enable { - c.EnableDebugLog() - c.EnableDumpAll() - } else { - c.DisableDebugLog() - c.DisableDumpAll() - } - return c -} - type apiNameType int const apiNameKey apiNameType = iota @@ -186,3 +170,22 @@ func (c *Client) ListUserRepo(ctx context.Context, username string, page int) (r Into(&repos) return } + +// LoginWithToken login with GitHub personal access token. +// GitHub API doc: https://docs.github.com/en/rest/overview/other-authentication-methods#authenticating-for-saml-sso +func (c *Client) LoginWithToken(token string) *Client { + c.SetCommonHeader("Authorization", "token "+token) + return c +} + +// SetDebug enable debug if set to true, disable debug if set to false. +func (c *Client) SetDebug(enable bool) *Client { + if enable { + c.EnableDebugLog() + c.EnableDumpAll() + } else { + c.DisableDebugLog() + c.DisableDumpAll() + } + return c +} From 19d3d871b0ba298b6cf6b2e38962fcd19db9a930 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 10 Aug 2022 08:55:51 +0800 Subject: [PATCH 576/843] use http.ErrUseLastResponse to prevent return error in NoRedirectPolicy --- client_test.go | 3 +-- redirect.go | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/client_test.go b/client_test.go index 6f06def2..70b176d4 100644 --- a/client_test.go +++ b/client_test.go @@ -346,8 +346,7 @@ func TestKeepAlives(t *testing.T) { func TestRedirect(t *testing.T) { _, err := tc().SetRedirectPolicy(NoRedirectPolicy()).R().Get("/unlimited-redirect") - tests.AssertNotNil(t, err) - tests.AssertContains(t, err.Error(), "redirect is disabled", true) + tests.AssertIsNil(t, err) _, err = tc().SetRedirectPolicy(MaxRedirectPolicy(3)).R().Get("/unlimited-redirect") tests.AssertNotNil(t, err) diff --git a/redirect.go b/redirect.go index 4f3b71e9..364a55eb 100644 --- a/redirect.go +++ b/redirect.go @@ -24,7 +24,7 @@ func MaxRedirectPolicy(noOfRedirect int) RedirectPolicy { // NoRedirectPolicy disable redirect behaviour func NoRedirectPolicy() RedirectPolicy { return func(req *http.Request, via []*http.Request) error { - return errors.New("redirect is disabled") + return http.ErrUseLastResponse } } @@ -111,7 +111,8 @@ func getDomain(host string) string { // sensitive ones to the same origin, or subdomains thereof (https://go-review.googlesource.com/c/go/+/28930/) // Check discussion: https://github.com/golang/go/issues/4800 // For example: -// client.SetRedirectPolicy(req.AlwaysCopyHeaderRedirectPolicy("Authorization")) +// +// client.SetRedirectPolicy(req.AlwaysCopyHeaderRedirectPolicy("Authorization")) func AlwaysCopyHeaderRedirectPolicy(headers ...string) RedirectPolicy { return func(req *http.Request, via []*http.Request) error { for _, header := range headers { From bb788259acc7f41c4e626f990032fd586be8745f Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 10 Aug 2022 14:14:14 +0800 Subject: [PATCH 577/843] update example: opentelemetry-jaeger-tracing --- .../github/github.go | 45 +++--- examples/opentelemetry-jaeger-tracing/main.go | 134 ++++++++++-------- 2 files changed, 95 insertions(+), 84 deletions(-) diff --git a/examples/opentelemetry-jaeger-tracing/github/github.go b/examples/opentelemetry-jaeger-tracing/github/github.go index e223296f..25809292 100644 --- a/examples/opentelemetry-jaeger-tracing/github/github.go +++ b/examples/opentelemetry-jaeger-tracing/github/github.go @@ -55,7 +55,7 @@ func NewClient() *Client { // memory (not print to stdout), we can record dump content only when unexpected // exception occurs, it is helpful to troubleshoot problems in production. OnBeforeRequest(func(c *req.Client, r *req.Request) error { - if r.RetryAttempt > 0 { // Ignore on retry. + if r.RetryAttempt > 0 { // Ignore on retry, no need to repeat EnableDump. return nil } r.EnableDump() @@ -65,9 +65,9 @@ func NewClient() *Client { SetCommonError(&APIError{}). // Handle common exceptions in response middleware. OnAfterResponse(func(client *req.Client, resp *req.Response) error { - if resp.Err != nil { + if resp.Err != nil { // There is an underlying error, e.g. network error or unmarshal error(SetResult or SetError was invoked before). if dump := resp.Dump(); dump != "" { // Append dump content to original underlying error to help troubleshoot. - resp.Err = fmt.Errorf("%s\nraw dump:\n%s", resp.Err.Error(), resp.Dump()) + resp.Err = fmt.Errorf("%s\nraw content:\n%s", resp.Err.Error(), resp.Dump()) } return nil // Skip the following logic if there is an underlying error. } @@ -76,10 +76,10 @@ func NewClient() *Client { resp.Err = err return nil } - // Corner case: neither an error response nor a success response, - // dump content to help troubleshoot. + // Corner case: neither an error response nor a success response, e.g. status code < 200 + // Just dump the raw content into error to help troubleshoot. if !resp.IsSuccess() { - resp.Err = fmt.Errorf("bad response, raw dump:\n%s", resp.Dump()) + resp.Err = fmt.Errorf("bad response, raw content:\n%s", resp.Dump()) } return nil }) @@ -96,22 +96,25 @@ const apiNameKey apiNameType = iota // SetTracer set the tracer of opentelemetry. func (c *Client) SetTracer(tracer trace.Tracer) { c.WrapRoundTripFunc(func(rt req.RoundTripper) req.RoundTripFunc { - return func(r *req.Request) (resp *req.Response, err error) { - ctx := r.Context() - spanName := ctx.Value(apiNameKey).(string) - _, span := tracer.Start(r.Context(), spanName) + return func(req *req.Request) (resp *req.Response, err error) { + ctx := req.Context() + apiName, ok := ctx.Value(apiNameKey).(string) + if !ok { + apiName = req.URL.Path + } + _, span := tracer.Start(req.Context(), apiName) defer span.End() span.SetAttributes( - attribute.String("http.url", r.URL.String()), - attribute.String("http.method", r.Method), - attribute.String("http.req.header", r.HeaderToString()), + attribute.String("http.url", req.URL.String()), + attribute.String("http.method", req.Method), + attribute.String("http.req.header", req.HeaderToString()), ) - if len(r.Body) > 0 { + if len(req.Body) > 0 { span.SetAttributes( - attribute.String("http.req.body", string(r.Body)), + attribute.String("http.req.body", string(req.Body)), ) } - resp, err = rt.RoundTrip(r) + resp, err = rt.RoundTrip(req) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) @@ -144,8 +147,8 @@ type UserProfile struct { func (c *Client) GetUserProfile(ctx context.Context, username string) (user *UserProfile, err error) { err = c.Get("/users/{username}"). SetPathParam("username", username). - Do(withAPIName(ctx, "GetUserProfile")). - Into(&user) + SetResult(&user). + Do(withAPIName(ctx, "GetUserProfile")).Err return } @@ -158,6 +161,7 @@ type Repo struct { // Github API doc: https://docs.github.com/en/rest/repos/repos#list-repositories-for-a-user func (c *Client) ListUserRepo(ctx context.Context, username string, page int) (repos []*Repo, err error) { err = c.Get("/users/{username}/repos"). + SetPathParam("username", username). SetQueryParamsAnyType(map[string]any{ "type": "owner", "page": strconv.Itoa(page), @@ -165,9 +169,8 @@ func (c *Client) ListUserRepo(ctx context.Context, username string, page int) (r "sort": "updated", "direction": "desc", }). - SetPathParam("username", username). - Do(withAPIName(ctx, "ListUserRepo")). - Into(&repos) + SetResult(&repos). + Do(withAPIName(ctx, "ListUserRepo")).Err return } diff --git a/examples/opentelemetry-jaeger-tracing/main.go b/examples/opentelemetry-jaeger-tracing/main.go index 8964ef69..193ad074 100644 --- a/examples/opentelemetry-jaeger-tracing/main.go +++ b/examples/opentelemetry-jaeger-tracing/main.go @@ -13,42 +13,45 @@ import ( "log" "opentelemetry-jaeger-tracing/github" "os" + "os/signal" + "syscall" ) const serviceName = "github-query" var githubClient *github.Client -func findMostPopularRepo(ctx context.Context, username string) (repo *github.Repo, err error) { - ctx, span := otel.Tracer("query").Start(ctx, "findMostPopularRepo") - defer span.End() - - for page := 1; ; page++ { - repos, e := githubClient.ListUserRepo(ctx, username, page) - if e != nil { - return - } - if len(repos) == 0 { - break - } - if repo == nil { - repo = repos[0] - } - for _, rp := range repos[1:] { - if rp.Star >= repo.Star { - repo = rp - } - } - if len(repos) == 100 { - continue - } - break +func traceProvider() (*trace.TracerProvider, error) { + // Create the Jaeger exporter + ep := os.Getenv("JAEGER_ENDPOINT") + if ep == "" { + ep = "http://localhost:14268/api/traces" } - - if repo == nil { - err = fmt.Errorf("no repo found for %s", username) + exp, err := jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(ep))) + if err != nil { + return nil, err } - return + + // Record information about this application in a Resource. + res, _ := resource.Merge( + resource.Default(), + resource.NewWithAttributes( + semconv.SchemaURL, + semconv.ServiceNameKey.String(serviceName), + semconv.ServiceVersionKey.String("v0.1.0"), + attribute.String("environment", "test"), + ), + ) + + // Create the TraceProvider. + tp := trace.NewTracerProvider( + // Always be sure to batch in production. + trace.WithBatcher(exp), + // Record information about this application in a Resource. + trace.WithResource(res), + trace.WithSampler(trace.AlwaysSample()), + ) + return tp, nil } // QueryUser queries information for specified GitHub user, and display a @@ -80,41 +83,40 @@ func QueryUser(username string) error { attribute.String("popular.repo.name", repo.Name), attribute.Int("popular.repo.star", repo.Star), ) - fmt.Printf("the moust popular repo of %s (%s) is %s, which have %d stars\n", profile.Name, profile.Blog, repo.Name, repo.Star) + fmt.Printf("The most popular repo of %s (%s) is %s, with %d stars\n", profile.Name, profile.Blog, repo.Name, repo.Star) return nil } -func traceProvider() (*trace.TracerProvider, error) { - // Create the Jaeger exporter - ep := os.Getenv("JAEGER_ENDPOINT") - if ep == "" { - ep = "http://localhost:14268/api/traces" - } - exp, err := jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(ep))) - if err != nil { - return nil, err - } +func findMostPopularRepo(ctx context.Context, username string) (repo *github.Repo, err error) { + ctx, span := otel.Tracer("query").Start(ctx, "findMostPopularRepo") + defer span.End() - // Record information about this application in a Resource. - res, _ := resource.Merge( - resource.Default(), - resource.NewWithAttributes( - semconv.SchemaURL, - semconv.ServiceNameKey.String(serviceName), - semconv.ServiceVersionKey.String("v0.1.0"), - attribute.String("environment", "test"), - ), - ) + for page := 1; ; page++ { + repos, e := githubClient.ListUserRepo(ctx, username, page) + if e != nil { + return + } + if len(repos) == 0 { + break + } + if repo == nil { + repo = repos[0] + } + for _, rp := range repos[1:] { + if rp.Star >= repo.Star { + repo = rp + } + } + if len(repos) == 100 { + continue + } + break + } - // Create the TraceProvider. - tp := trace.NewTracerProvider( - // Always be sure to batch in production. - trace.WithBatcher(exp), - // Record information about this application in a Resource. - trace.WithResource(res), - trace.WithSampler(trace.AlwaysSample()), - ) - return tp, nil + if repo == nil { + err = fmt.Errorf("no repo found for %s", username) + } + return } func main() { @@ -122,11 +124,6 @@ func main() { if err != nil { panic(err) } - defer func() { - if err := tp.Shutdown(context.Background()); err != nil { - log.Fatal(err) - } - }() otel.SetTracerProvider(tp) githubClient = github.NewClient() @@ -138,6 +135,17 @@ func main() { } githubClient.SetTracer(otel.Tracer("github")) + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) + go func() { + sig := <-sigs + fmt.Printf("Caught %s, shutting down\n", sig) + if err := tp.Shutdown(context.Background()); err != nil { + log.Fatal(err) + } + os.Exit(0) + }() + for { var name string fmt.Printf("Please give a github username: ") From a7293c9d4c176867989de67bc0d03dd470090dda Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 10 Aug 2022 16:53:26 +0800 Subject: [PATCH 578/843] update example: fix typo --- examples/opentelemetry-jaeger-tracing/github/github.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/opentelemetry-jaeger-tracing/github/github.go b/examples/opentelemetry-jaeger-tracing/github/github.go index 25809292..5b51c5d2 100644 --- a/examples/opentelemetry-jaeger-tracing/github/github.go +++ b/examples/opentelemetry-jaeger-tracing/github/github.go @@ -123,7 +123,7 @@ func (c *Client) SetTracer(tracer trace.Tracer) { span.SetAttributes( attribute.Int("http.status_code", resp.StatusCode), attribute.String("http.resp.header", resp.HeaderToString()), - attribute.String("resp.resp.body", resp.String()), + attribute.String("http.resp.body", resp.String()), ) return } From 7de58879dc82beb3332b4e6de121fcaf863b8f0b Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 10 Aug 2022 19:06:47 +0800 Subject: [PATCH 579/843] update example: opentelemetry-jaeger-tracing --- .../opentelemetry-jaeger-tracing/github/github.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/opentelemetry-jaeger-tracing/github/github.go b/examples/opentelemetry-jaeger-tracing/github/github.go index 5b51c5d2..7f103f80 100644 --- a/examples/opentelemetry-jaeger-tracing/github/github.go +++ b/examples/opentelemetry-jaeger-tracing/github/github.go @@ -118,13 +118,14 @@ func (c *Client) SetTracer(tracer trace.Tracer) { if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) - return } - span.SetAttributes( - attribute.Int("http.status_code", resp.StatusCode), - attribute.String("http.resp.header", resp.HeaderToString()), - attribute.String("http.resp.body", resp.String()), - ) + if resp.Response != nil { + span.SetAttributes( + attribute.Int("http.status_code", resp.StatusCode), + attribute.String("http.resp.header", resp.HeaderToString()), + attribute.String("http.resp.body", resp.String()), + ) + } return } }) From f83370ee2884c8bc6241ce9352d901626b04335e Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 10 Aug 2022 19:25:51 +0800 Subject: [PATCH 580/843] update example: opentelemetry-jaeger-tracing --- examples/opentelemetry-jaeger-tracing/main.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/opentelemetry-jaeger-tracing/main.go b/examples/opentelemetry-jaeger-tracing/main.go index 193ad074..c4620b57 100644 --- a/examples/opentelemetry-jaeger-tracing/main.go +++ b/examples/opentelemetry-jaeger-tracing/main.go @@ -92,8 +92,9 @@ func findMostPopularRepo(ctx context.Context, username string) (repo *github.Rep defer span.End() for page := 1; ; page++ { - repos, e := githubClient.ListUserRepo(ctx, username, page) - if e != nil { + var repos []*github.Repo + repos, err = githubClient.ListUserRepo(ctx, username, page) + if err != nil { return } if len(repos) == 0 { From 6984f39aca47dedeb33f1bdd089876f4d07f08db Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 12 Aug 2022 19:55:14 +0800 Subject: [PATCH 581/843] unexpose Client's RoundTrip --- client.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 40257b94..6de39af9 100644 --- a/client.go +++ b/client.go @@ -1084,6 +1084,14 @@ func (c *Client) WrapRoundTripFunc(funcs ...RoundTripWrapperFunc) *Client { return c.WrapRoundTrip(wrappers...) } +type roundTripImpl struct { + *Client +} + +func (r roundTripImpl) RoundTrip(req *Request) (resp *Response, err error) { + return r.roundTrip(req) +} + // WrapRoundTrip adds a client middleware function that will give the caller // an opportunity to wrap the underlying http.RoundTripper. func (c *Client) WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { @@ -1091,7 +1099,7 @@ func (c *Client) WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { return c } if c.wrappedRoundTrip == nil { - c.wrappedRoundTrip = c + c.wrappedRoundTrip = roundTripImpl{c} } for _, w := range wrappers { c.wrappedRoundTrip = w(c.wrappedRoundTrip) @@ -1100,7 +1108,7 @@ func (c *Client) WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { } // RoundTrip implements RoundTripper -func (c *Client) RoundTrip(r *Request) (resp *Response, err error) { +func (c *Client) roundTrip(r *Request) (resp *Response, err error) { resp = &Response{Request: r} // setup trace @@ -1226,7 +1234,7 @@ func (c *Client) do(r *Request) (resp *Response, err error) { if c.wrappedRoundTrip != nil { resp, err = c.wrappedRoundTrip.RoundTrip(r) } else { - resp, err = c.RoundTrip(r) + resp, err = c.roundTrip(r) } if r.retryOption == nil || r.RetryAttempt >= r.retryOption.MaxRetries { // absolutely cannot retry. From 6e389591f13c321a7f5f3e0c65ec21810cd8a3ea Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 14 Aug 2022 15:48:27 +0800 Subject: [PATCH 582/843] Restore Response.Body when AutoReadResponse is enabled(#152) --- client.go | 3 +++ request_test.go | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/client.go b/client.go index 6de39af9..bd56526c 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package req import ( + "bytes" "context" "crypto/tls" "crypto/x509" @@ -1191,6 +1192,8 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { if err != nil { return } + // restore body for re-reads + resp.Body = ioutil.NopCloser(bytes.NewReader(resp.body)) } for _, f := range r.client.afterResponse { diff --git a/request_test.go b/request_test.go index 624c6f47..76ddbb8c 100644 --- a/request_test.go +++ b/request_test.go @@ -1010,3 +1010,14 @@ func TestDownloadCallback(t *testing.T) { assertSuccess(t, resp, err) tests.AssertEqual(t, true, n > 0) } + +func TestRestoreResponseBody(t *testing.T) { + c := tc() + resp, err := c.R().Get("/") + assertSuccess(t, resp, err) + tests.AssertNoError(t, err) + tests.AssertEqual(t, true, len(resp.Bytes()) > 0) + body, err := ioutil.ReadAll(resp.Body) + tests.AssertNoError(t, err) + tests.AssertEqual(t, true, len(body) > 0) +} From 2861e6adc37731d3073d53b180bbdd4b9f806271 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 15 Aug 2022 12:38:42 +0800 Subject: [PATCH 583/843] update README --- README.md | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4a274c2f..cdeb56c6 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Full documentation is available on the official website: https://req.cool. * **Support Retry**: Support automatic request retry and is fully customizable (See [Retry](https://req.cool/docs/tutorial/retry/)). * **Easy Download and Upload**: You can download and upload files with simple request settings, and even set a callback to show real-time progress (See [Download](https://req.cool/docs/tutorial/download/) and [Upload](https://req.cool/docs/tutorial/upload/)). * **Exportable**: `Transport` is exportable, which support dump requests, it's easy to integrate with existing http.Client, so you can debug APIs with minimal code change. -* **Extensible**: Support Middleware for Request and Response (See [Request and Response Middleware](https://req.cool/docs/tutorial/middleware/)). +* **Extensible**: Support Middleware for Request and Response (See [Request and Response Middleware](https://req.cool/docs/tutorial/middleware-for-request-and-response/)), and also support Middleware for Client and Transport (See [Client and Transport Middleware](https://req.cool/docs/tutorial/middleware-for-client-and-transport/)). ## Get Started @@ -86,6 +86,24 @@ err = fmt.Errorf("got unexpected response, raw dump:\n%s", resp.Dump()) // ... ``` +You can also use another style if you want: + +```go +resp := client.Get("https://api.github.com/users/{username}/repos"). // Create a GET request with specified URL. + SetHeader("Accept", "application/vnd.github.v3+json"). + SetPathParam("username", "imroc"). + SetQueryParam("page", "1"). + SetResult(&result). + SetError(&errMsg). + EnableDump(). + Do() // Send request with Do. + +if resp.Err != nil { + // ... +} +// ... +``` + **Videos** * [Get Started With Req](https://www.youtube.com/watch?v=k47i0CKBVrA) (English, Youtube) From 6a39a631f68c02c15d15c7cbf9a79897520ad26f Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 15 Aug 2022 16:25:08 +0800 Subject: [PATCH 584/843] update comment for SetCookieJar --- client.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index bd56526c..2fffcdd9 100644 --- a/client.go +++ b/client.go @@ -775,7 +775,8 @@ func (c *Client) EnableTraceAll() *Client { return c } -// SetCookieJar set the `CookeJar` to the underlying `http.Client`. +// SetCookieJar set the `CookeJar` to the underlying `http.Client`, set to nil if you +// want to disable cookie. func (c *Client) SetCookieJar(jar http.CookieJar) *Client { c.httpClient.Jar = jar return c From dbb8317925f28cc80b802838461cb0f338f68407 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 15 Aug 2022 21:07:26 +0800 Subject: [PATCH 585/843] enable allow GET with body by default, discard body if disabled(#153) --- client.go | 19 ++++++++++--------- client_test.go | 10 +++++----- middleware.go | 2 ++ 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/client.go b/client.go index 2fffcdd9..57cbe487 100644 --- a/client.go +++ b/client.go @@ -1031,15 +1031,16 @@ func C() *Client { handleDownload, } c := &Client{ - beforeRequest: beforeRequest, - afterResponse: afterResponse, - log: createDefaultLogger(), - httpClient: httpClient, - t: t, - jsonMarshal: json.Marshal, - jsonUnmarshal: json.Unmarshal, - xmlMarshal: xml.Marshal, - xmlUnmarshal: xml.Unmarshal, + AllowGetMethodPayload: true, + beforeRequest: beforeRequest, + afterResponse: afterResponse, + log: createDefaultLogger(), + httpClient: httpClient, + t: t, + jsonMarshal: json.Marshal, + jsonUnmarshal: json.Unmarshal, + xmlMarshal: xml.Marshal, + xmlUnmarshal: xml.Unmarshal, } httpClient.CheckRedirect = c.defaultCheckRedirect diff --git a/client_test.go b/client_test.go index 70b176d4..960c028d 100644 --- a/client_test.go +++ b/client_test.go @@ -47,17 +47,17 @@ func TestAllowGetMethodPayload(t *testing.T) { c := tc() resp, err := c.R().SetBody("test").Get("/payload") assertSuccess(t, resp, err) - tests.AssertEqual(t, "", resp.String()) - - c.EnableAllowGetMethodPayload() - resp, err = c.R().SetBody("test").Get("/payload") - assertSuccess(t, resp, err) tests.AssertEqual(t, "test", resp.String()) c.DisableAllowGetMethodPayload() resp, err = c.R().SetBody("test").Get("/payload") assertSuccess(t, resp, err) tests.AssertEqual(t, "", resp.String()) + + c.EnableAllowGetMethodPayload() + resp, err = c.R().SetBody("test").Get("/payload") + assertSuccess(t, resp, err) + tests.AssertEqual(t, "test", resp.String()) } func TestSetTLSHandshakeTimeout(t *testing.T) { diff --git a/middleware.go b/middleware.go index 03c0e58e..8cac7da5 100644 --- a/middleware.go +++ b/middleware.go @@ -173,6 +173,8 @@ func handleMarshalBody(c *Client, r *Request) error { func parseRequestBody(c *Client, r *Request) (err error) { if c.isPayloadForbid(r.Method) { + r.marshalBody = nil + r.Body = nil r.GetBody = nil return } From bae0ebfc3f2920334179ecae6368a670324d999a Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 15 Aug 2022 21:13:16 +0800 Subject: [PATCH 586/843] remove unused api.md --- docs/api.md | 249 ---------------------------------------------------- 1 file changed, 249 deletions(-) delete mode 100644 docs/api.md diff --git a/docs/api.md b/docs/api.md deleted file mode 100644 index fe0ba0d8..00000000 --- a/docs/api.md +++ /dev/null @@ -1,249 +0,0 @@ -

-

Quick API Reference

-

- -Here is a brief and categorized list of the core APIs, for a more detailed and complete list, please refer to the [GoDoc](https://pkg.go.dev/github.com/imroc/req/v3). - -## Table of Contents - -* [Client Settings](#Client) - * [Debug Features](#Debug) - * [Common Settings for constructing HTTP Requests](#Common) - * [Auto-Decode](#Decode) - * [TLS and Certificates](#Certs) - * [Marshal&Unmarshal](#Marshal) - * [HTTP Version](#Version) - * [Retry](#Retry-Client) - * [Other Settings](#Other) -* [Request Settings](#Request) - * [URL Query and Path Parameter](#Query) - * [Header and Cookie](#Header) - * [Body and Marshal&Unmarshal](#Body) - * [Request Level Debug](#Debug-Request) - * [Multipart & Form & Upload](#Multipart) - * [Download](#Download) - * [Retry](#Retry-Request) - * [Other Settings](#Other-Request) -* [Sending Request](#Send-Request) - -## Client Settings - -The following are the chainable settings of Client, all of which have corresponding global wrappers (Just treat the package name `req` as a Client to test, set up the Client without create any Client explicitly). - -Basically, you can know the meaning of most settings directly from the method name. - -### Debug Features - -* [DevMode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DevMode) - Enable all debug features (Dump, DebugLog and Trace). -* [EnableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDebugLog) - Enable debug level log (disabled by default). -* [DisableDebugLog()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDebugLog) -* [SetLogger(log Logger)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetLogger) - Set the customized logger. -* [EnableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAll) - Enable dump for all requests. -* [DisableDumpAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableDumpAll) -* [SetCommonDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonDumpOptions) -* [EnableDumpAllAsync()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllAsync) -* [EnableDumpAllTo(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllTo) -* [EnableDumpAllToFile(filename string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllToFile) -* [EnableDumpAllWithoutBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutBody) -* [EnableDumpAllWithoutHeader()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutHeader) -* [EnableDumpAllWithoutRequest()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequest) -* [EnableDumpAllWithoutRequestBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutRequestBody) -* [EnableDumpAllWithoutResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponse) -* [EnableDumpAllWithoutResponseBody()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableDumpAllWithoutResponseBody) -* [EnableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableTraceAll) - Enable trace for all requests (disabled by default). -* [DisableTraceAll()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableTraceAll) - -### Common Settings for constructing HTTP Requests - -* [SetCommonBasicAuth(username, password string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonBasicAuth) -* [SetCommonBearerAuthToken(token string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonBearerAuthToken) -* [SetCommonContentType(ct string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonContentType) -* [SetCommonCookies(cookies ...*http.Cookie)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonCookies) -* [SetCommonFormData(data map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormData) -* [SetCommonFormDataFromValues(data url.Values)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonFormDataFromValues) -* [SetCommonHeader(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeader) -* [SetCommonHeaders(hdrs map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonHeaders) -* [SetCommonPathParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonPathParam) -* [SetCommonPathParams(pathParams map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonPathParams) -* [SetCommonQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryParam) -* [SetCommonQueryParams(params map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryParams) -* [SetCommonQueryString(query string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonQueryString) -* [AddCommonQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.AddCommonQueryParam) -* [SetUserAgent(userAgent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetUserAgent) - -### Auto-Decode - -* [EnableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAutoDecode) -* [DisableAutoDecode()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoDecode) - Disable auto-detect charset and decode to utf-8 (enabled by default). -* [SetAutoDecodeContentType(contentTypes ...string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentType) -* [SetAutoDecodeAllContentType()](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeAllContentType) -* [SetAutoDecodeContentTypeFunc(fn func(contentType string) bool)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetAutoDecodeContentTypeFunc) - -### TLS and Certificates - -* [SetCerts(certs ...tls.Certificate) ](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCerts) -* [SetCertFromFile(certFile, keyFile string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCertFromFile) -* [SetRootCertsFromFile(pemFiles ...string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertsFromFile) -* [SetRootCertFromString(pemContent string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRootCertFromString) -* [EnableInsecureSkipVerify()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableInsecureSkipVerify) - Disabled by default. -* [DisableInsecureSkipVerify](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableInsecureSkipVerify) -* [SetTLSHandshakeTimeout(timeout time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSHandshakeTimeout) -* [SetTLSClientConfig(conf *tls.Config)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTLSClientConfig) - -### Marshal&Unmarshal - -* [SetJsonUnmarshal(fn func(data []byte, v interface{}) error)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonUnmarshal) -* [SetJsonMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetJsonMarshal) -* [SetXmlUnmarshal(fn func(data []byte, v interface{}) error)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetXmlUnmarshal) -* [SetXmlMarshal(fn func(v interface{}) ([]byte, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetXmlMarshal) - -### Middleware - -* [OnBeforeRequest(m RequestMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnBeforeRequest) -* [OnAfterResponse(m ResponseMiddleware)](https://pkg.go.dev/github.com/imroc/req/v3#Client.OnAfterResponse) - -### HTTP Version - -* [DisableForceHttpVersion()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableForceHttpVersion) -* [EnableForceHTTP2()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableForceHTTP2) -* [EnableForceHTTP1()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableForceHTTP1) - -### Retry - -* [SetCommonRetryCount(count int)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryCount) -* [SetCommonRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryInterval) -* [SetCommonRetryFixedInterval(interval time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryFixedInterval) -* [SetCommonRetryBackoffInterval(min, max time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryBackoffInterval) -* [SetCommonRetryHook(hook RetryHookFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryHook) -* [AddCommonRetryHook(hook RetryHookFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Client.AddCommonRetryHook) -* [SetCommonRetryCondition(condition RetryConditionFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCommonRetryCondition) -* [AddCommonRetryCondition(condition RetryConditionFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Client.AddCommonRetryCondition) - -### Other Settings - -* [SetTimeout(d time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetTimeout) -* [EnableKeepAlives()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableKeepAlives) -* [DisableKeepAlives()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableKeepAlives) - Enabled by default. -* [SetScheme(scheme string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetScheme) -* [SetBaseURL(u string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetBaseURL) -* [SetProxyURL(proxyUrl string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetProxyURL) -* [SetProxy(proxy func(*http.Request) (*urlpkg.URL, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetProxy) -* [SetOutputDirectory(dir string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetOutputDirectory) -* [SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetDialTLS) -* [SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error))](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetDial) -* [SetCookieJar(jar http.CookieJar)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetCookieJar) -* [SetRedirectPolicy(policies ...RedirectPolicy)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetRedirectPolicy) -* [EnableCompression()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableCompression) -* [DisableCompression()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableCompression) - Enabled by default -* [EnableAutoReadResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAutoReadResponse) -* [DisableAutoReadResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAutoReadResponse) - Enabled by default -* [EnableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.EnableAllowGetMethodPayload) - Disabled by default. -* [DisableAllowGetMethodPayload()](https://pkg.go.dev/github.com/imroc/req/v3#Client.DisableAllowGetMethodPayload) -* [SetUnixSocket(file string)](https://pkg.go.dev/github.com/imroc/req/v3#Client.SetUnixSocket) - -## Request Settings - -The following are the chainable settings of Request, all of which have corresponding global wrappers. - -Basically, you can know the meaning of most settings directly from the method name. - -### URL Query and Path Parameter - -* [AddQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.AddQueryParam) -* [SetQueryParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetQueryParam) -* [SetQueryParams(params map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetQueryParams) -* [SetQueryString(query string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetQueryString) -* [SetPathParam(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetPathParam) -* [SetPathParams(params map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetPathParams) - -### Header and Cookie - -* [SetHeader(key, value string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetHeader) -* [SetHeaders(hdrs map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetHeaders) -* [SetBasicAuth(username, password string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBasicAuth) -* [SetBearerAuthToken(token string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBearerAuthToken) -* [SetContentType(contentType string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetContentType) -* [SetCookies(cookies ...*http.Cookie)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetCookies) - -### Body and Marshal&Unmarshal - -* [SetBody(body interface{})](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBody) -* [SetBodyBytes(body []byte)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyBytes) -* [SetBodyJsonBytes(body []byte)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyJsonBytes) -* [SetBodyJsonMarshal(v interface{})](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyJsonMarshal) -* [SetBodyJsonString(body string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyJsonString) -* [SetBodyString(body string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyString) -* [SetBodyXmlBytes(body []byte)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyXmlBytes) -* [SetBodyXmlMarshal(v interface{})](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyXmlMarshal) -* [SetBodyXmlString(body string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetBodyXmlString) -* [SetResult(result interface{})](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetResult) -* [SetError(error interface{})](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetError) - -### Request Level Debug - -* [EnableTrace()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableTrace) - Disabled by default. -* [DisableTrace()](https://pkg.go.dev/github.com/imroc/req/v3#Request.DisableTrace) -* [EnableDump()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDump) -* [EnableDumpTo(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpTo) -* [EnableDumpToFile(filename string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpToFile) -* [EnableDumpWithoutBody()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutBody) -* [EnableDumpWithoutHeader()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutHeader) -* [EnableDumpWithoutRequest()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutRequest) -* [EnableDumpWithoutRequestBody()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutRequestBody) -* [EnableDumpWithoutResponse()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutResponse) -* [EnableDumpWithoutResponseBody()](https://pkg.go.dev/github.com/imroc/req/v3#Request.EnableDumpWithoutResponseBody) -* [SetDumpOptions(opt *DumpOptions)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDumpOptions) - -### Multipart & Form & Upload - -* [SetFormData(data map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFormData) -* [SetFormDataFromValues(data url.Values)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFormDataFromValues) -* [SetFile(paramName, filePath string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFile) -* [SetFiles(files map[string]string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFiles) -* [SetFileBytes(paramName, filename string, content []byte)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFileBytes) -* [SetFileReader(paramName, filePath string, reader io.Reader)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFileReader) -* [SetFileUpload(uploads ...FileUpload)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetFileUpload) - Set the fully custimized multipart file upload options. -* [SetUploadCallback(callback UploadCallback)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetUploadCallback) -* [SetUploadCallbackWithInterval(callback UploadCallback, minInterval time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetUploadCallbackWithInterval) - -### Download - -* [SetOutput(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutput) -* [SetOutputFile(file string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutputFile) -* [SetDownloadCallback(callback DownloadCallback)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDownloadCallback) -* [SetDownloadCallbackWithInterval(callback DownloadCallback, minInterval time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDownloadCallbackWithInterval) - -### Retry - -* [SetRetryCount(count int)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryCount) -* [SetRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryInterval) -* [SetRetryFixedInterval(interval time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryFixedInterval) -* [SetRetryBackoffInterval(min, max time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryBackoffInterval) -* [SetRetryHook(hook RetryHookFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryHook) -* [AddRetryHook(hook RetryHookFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Request.AddRetryHook) -* [SetRetryCondition(condition RetryConditionFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetRetryCondition) -* [AddRetryCondition(condition RetryConditionFunc)](https://pkg.go.dev/github.com/imroc/req/v3#Request.AddRetryCondition) - -### Other Settings - -* [SetContext(ctx context.Context)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetContext) - -## Sending Request - -These methods will fire the http request and get response, `MustXXX` will not return any error, panic if error happens. - -* [Get(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Get) -* [Head(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Head) -* [Post(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Post) -* [Delete(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Delete) -* [Patch(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Patch) -* [Options(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Options) -* [Put(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Put) -* [Send(method, url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.Put) - Send request with given method name and url. -* [MustGet(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustGet) -* [MustHead(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustHead) -* [MustPost(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustPost) -* [MustDelete(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustDelete) -* [MustPatch(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustPatch) -* [MustOptions(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustOptions) -* [MustPut(url string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.MustPut) From dc2c3cca0d4c45f5d1d7b46c469af0e32ec72842 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 16 Aug 2022 13:55:20 +0800 Subject: [PATCH 587/843] update README: add contact --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index cdeb56c6..24ec0e7b 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,18 @@ if resp.Err != nil { Check more introduction, tutorials, examples, best practices and API references on the [official website](https://req.cool/). +## Contributing + +If you have a bug report or feature request, you can [open an issue](https://github.com/imroc/req/issues/new), and [pull requests](https://github.com/imroc/req/pulls) are also welcome. + +## Contact + +If you have questions, feel free to reach out to us in the following ways: + +* [Github Discussion](https://github.com/imroc/req/discussions) +* [Slack](https://imroc-req.slack.com/archives/C03UFPGSNC8) (International) | [Join](https://slack.req.cool/) +* QQ Group (Chinese): 621411351 - + ## License `Req` released under MIT license, refer [LICENSE](LICENSE) file. From 0925c0413347e6a2243fd30396173c127b31eb1d Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 16 Aug 2022 17:50:24 +0800 Subject: [PATCH 588/843] remove useless code --- client.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/client.go b/client.go index 57cbe487..4b3dc632 100644 --- a/client.go +++ b/client.go @@ -1204,9 +1204,6 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { return } } - if resp.Err != nil { // in case that set error in middleware and not return error. - err = resp.Err - } return } From a049e206ef5b56aa6bae9c22d466c97b535a70a0 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 16 Aug 2022 20:27:24 +0800 Subject: [PATCH 589/843] add EnableDumpEachRequestXXX for Client --- client.go | 84 +++++++++++++++++++++++++++++++++++++++++++++++ client_wrapper.go | 42 ++++++++++++++++++++++++ middleware.go | 4 +-- 3 files changed, 128 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 4b3dc632..2c720a80 100644 --- a/client.go +++ b/client.go @@ -593,6 +593,90 @@ func (c *Client) EnableDumpAllWithoutBody() *Client { return c } +// EnableDumpEachRequest enable dump at the request-level for each request, and only +// temporarily stores the dump content in memory, call Response.Dump() to get the +// dump content when needed. +func (c *Client) EnableDumpEachRequest() *Client { + return c.OnBeforeRequest(func(client *Client, req *Request) error { + if req.RetryAttempt == 0 { // Ignore on retry, no need to repeat enable dump. + req.EnableDump() + } + return nil + }) +} + +// EnableDumpEachRequestWithoutBody enable dump without body at the request-level for +// each request, and only temporarily stores the dump content in memory, call +// Response.Dump() to get the dump content when needed. +func (c *Client) EnableDumpEachRequestWithoutBody() *Client { + return c.OnBeforeRequest(func(client *Client, req *Request) error { + if req.RetryAttempt == 0 { // Ignore on retry, no need to repeat enable dump. + req.EnableDumpWithoutBody() + } + return nil + }) +} + +// EnableDumpEachRequestWithoutHeader enable dump without header at the request-level for +// each request, and only temporarily stores the dump content in memory, call +// Response.Dump() to get the dump content when needed. +func (c *Client) EnableDumpEachRequestWithoutHeader() *Client { + return c.OnBeforeRequest(func(client *Client, req *Request) error { + if req.RetryAttempt == 0 { // Ignore on retry, no need to repeat enable dump. + req.EnableDumpWithoutHeader() + } + return nil + }) +} + +// EnableDumpEachRequestWithoutRequest enable dump without request at the request-level for +// each request, and only temporarily stores the dump content in memory, call +// Response.Dump() to get the dump content when needed. +func (c *Client) EnableDumpEachRequestWithoutRequest() *Client { + return c.OnBeforeRequest(func(client *Client, req *Request) error { + if req.RetryAttempt == 0 { // Ignore on retry, no need to repeat enable dump. + req.EnableDumpWithoutRequest() + } + return nil + }) +} + +// EnableDumpEachRequestWithoutResponse enable dump without response at the request-level for +// each request, and only temporarily stores the dump content in memory, call +// Response.Dump() to get the dump content when needed. +func (c *Client) EnableDumpEachRequestWithoutResponse() *Client { + return c.OnBeforeRequest(func(client *Client, req *Request) error { + if req.RetryAttempt == 0 { // Ignore on retry, no need to repeat enable dump. + req.EnableDumpWithoutResponse() + } + return nil + }) +} + +// EnableDumpEachRequestWithoutResponseBody enable dump without response body at the +// request-level for each request, and only temporarily stores the dump content in memory, +// call Response.Dump() to get the dump content when needed. +func (c *Client) EnableDumpEachRequestWithoutResponseBody() *Client { + return c.OnBeforeRequest(func(client *Client, req *Request) error { + if req.RetryAttempt == 0 { // Ignore on retry, no need to repeat enable dump. + req.EnableDumpWithoutResponseBody() + } + return nil + }) +} + +// EnableDumpEachRequestWithoutRequestBody enable dump without request body at the +// request-level for each request, and only temporarily stores the dump content in memory, +// call Response.Dump() to get the dump content when needed. +func (c *Client) EnableDumpEachRequestWithoutRequestBody() *Client { + return c.OnBeforeRequest(func(client *Client, req *Request) error { + if req.RetryAttempt == 0 { // Ignore on retry, no need to repeat enable dump. + req.EnableDumpWithoutRequestBody() + } + return nil + }) +} + // NewRequest is the alias of R() func (c *Client) NewRequest() *Request { return c.R() diff --git a/client_wrapper.go b/client_wrapper.go index e0e535d9..84246d00 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -268,6 +268,48 @@ func EnableDumpAllWithoutBody() *Client { return defaultClient.EnableDumpAllWithoutBody() } +// EnableDumpEachRequest is a global wrapper methods which delegated +// to the default client's EnableDumpEachRequest. +func EnableDumpEachRequest() *Client { + return defaultClient.EnableDumpEachRequest() +} + +// EnableDumpEachRequestWithoutBody is a global wrapper methods which delegated +// to the default client's EnableDumpEachRequestWithoutBody. +func EnableDumpEachRequestWithoutBody() *Client { + return defaultClient.EnableDumpEachRequestWithoutBody() +} + +// EnableDumpEachRequestWithoutHeader is a global wrapper methods which delegated +// to the default client's EnableDumpEachRequestWithoutHeader. +func EnableDumpEachRequestWithoutHeader() *Client { + return defaultClient.EnableDumpEachRequestWithoutHeader() +} + +// EnableDumpEachRequestWithoutResponse is a global wrapper methods which delegated +// to the default client's EnableDumpEachRequestWithoutResponse. +func EnableDumpEachRequestWithoutResponse() *Client { + return defaultClient.EnableDumpEachRequestWithoutResponse() +} + +// EnableDumpEachRequestWithoutRequest is a global wrapper methods which delegated +// to the default client's EnableDumpEachRequestWithoutRequest. +func EnableDumpEachRequestWithoutRequest() *Client { + return defaultClient.EnableDumpEachRequestWithoutRequest() +} + +// EnableDumpEachRequestWithoutResponseBody is a global wrapper methods which delegated +// to the default client's EnableDumpEachRequestWithoutResponseBody. +func EnableDumpEachRequestWithoutResponseBody() *Client { + return defaultClient.EnableDumpEachRequestWithoutResponseBody() +} + +// EnableDumpEachRequestWithoutRequestBody is a global wrapper methods which delegated +// to the default client's EnableDumpEachRequestWithoutRequestBody. +func EnableDumpEachRequestWithoutRequestBody() *Client { + return defaultClient.EnableDumpEachRequestWithoutRequestBody() +} + // DisableAutoReadResponse is a global wrapper methods which delegated // to the default client's DisableAutoReadResponse. func DisableAutoReadResponse() *Client { diff --git a/middleware.go b/middleware.go index 8cac7da5..392a8e16 100644 --- a/middleware.go +++ b/middleware.go @@ -19,10 +19,10 @@ import ( type ( // RequestMiddleware type is for request middleware, called before a request is sent - RequestMiddleware func(*Client, *Request) error + RequestMiddleware func(client *Client, req *Request) error // ResponseMiddleware type is for response middleware, called after a response has been received - ResponseMiddleware func(*Client, *Response) error + ResponseMiddleware func(client *Client, resp *Response) error ) func createMultipartHeader(file *FileUpload, contentType string) textproto.MIMEHeader { From aa760f1a92008d84f3d8ff9160fb974275c56e5e Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 17 Aug 2022 09:26:30 +0800 Subject: [PATCH 590/843] update exmaple: opentelemetry-jaeger-tracing --- .../github/github.go | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/examples/opentelemetry-jaeger-tracing/github/github.go b/examples/opentelemetry-jaeger-tracing/github/github.go index 7f103f80..894553f4 100644 --- a/examples/opentelemetry-jaeger-tracing/github/github.go +++ b/examples/opentelemetry-jaeger-tracing/github/github.go @@ -51,16 +51,13 @@ func NewClient() *Client { SetCommonHeader("Accept", "application/vnd.github.v3+json"). // All GitHub API requests use the same base URL. SetBaseURL("https://api.github.com"). - // EnableDump at the request level in request middleware which dump content into - // memory (not print to stdout), we can record dump content only when unexpected - // exception occurs, it is helpful to troubleshoot problems in production. - OnBeforeRequest(func(c *req.Client, r *req.Request) error { - if r.RetryAttempt > 0 { // Ignore on retry, no need to repeat EnableDump. - return nil - } - r.EnableDump() - return nil - }). + // Enable dump at the request-level for each request, and only + // temporarily stores the dump content in memory, so we can call + // resp.Dump() to get the dump content when needed in response + // middleware. + // This is actually a syntax sugar, implemented internally using + // request middleware + EnableDumpEachRequest(). // Unmarshal response body into an APIError struct when status >= 400. SetCommonError(&APIError{}). // Handle common exceptions in response middleware. From 86a59044c385d34138964af2b82bf3a631a9765d Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 17 Aug 2022 09:42:33 +0800 Subject: [PATCH 591/843] update README --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 24ec0e7b..4baae90b 100644 --- a/README.md +++ b/README.md @@ -19,15 +19,15 @@ Full documentation is available on the official website: https://req.cool. ## Features -* **Simple and Powerful**: Providing rich client-level and request-level settings, all of which are intuitive and chainable methods, and the request-level setting takes precedence if both are set. +* **Simple and Powerful**: Simple and easy to use, providing rich client-level and request-level settings, all of which are intuitive and chainable methods. * **Easy Debugging**: Powerful and convenient debug utilities, including debug logs, performance traces, and even dump the complete request and response content (see [Debugging](https://req.cool/docs/tutorial/debugging/)). -* **Easy API Testing**: API testing can be done with minimal code, no need to explicitly create any Requests and Clients, or even to handle errors (See [Quick HTTP Test](https://req.cool/docs/tutorial/quick-test/)) +* **Easy API Testing**: API testing can be done with minimal code, no need to explicitly create any Request or Client, or even to handle errors (See [Quick HTTP Test](https://req.cool/docs/tutorial/quick-test/)) * **Smart by Default**: Detect and decode to utf-8 automatically if possible to avoid garbled characters (See [Auto Decode](https://req.cool/docs/tutorial/auto-decode/)), marshal request body and unmarshal response body automatically according to the Content-Type. -* **Works fine with HTTP2**: Support both with HTTP/2 and HTTP/1.1, and HTTP/2 is preferred by default if server support, you can also force the protocol if you want (See [Force HTTP version](https://req.cool/docs/tutorial/force-http-version/)). +* **Support Multiple HTTP Versions**: Support `HTTP/1.1`, `HTTP/2`, and `HTTP/3`, and can automatically detect the server side and select the optimal HTTP version for requests, you can also force the protocol if you want (See [Force HTTP version](https://req.cool/docs/tutorial/force-http-version/)). * **Support Retry**: Support automatic request retry and is fully customizable (See [Retry](https://req.cool/docs/tutorial/retry/)). * **Easy Download and Upload**: You can download and upload files with simple request settings, and even set a callback to show real-time progress (See [Download](https://req.cool/docs/tutorial/download/) and [Upload](https://req.cool/docs/tutorial/upload/)). -* **Exportable**: `Transport` is exportable, which support dump requests, it's easy to integrate with existing http.Client, so you can debug APIs with minimal code change. -* **Extensible**: Support Middleware for Request and Response (See [Request and Response Middleware](https://req.cool/docs/tutorial/middleware-for-request-and-response/)), and also support Middleware for Client and Transport (See [Client and Transport Middleware](https://req.cool/docs/tutorial/middleware-for-client-and-transport/)). +* **Exportable**: `req.Transport` is exportable. Compared with `http.Transport`, it also supports HTTP3, dump content, middleware, etc. It can directly replace the Transport of `http.Client` in existing projects, and obtain more powerful functions with minimal code change. +* **Extensible**: Support Middleware for Request, Response, Client and Transport (See [Request and Response Middleware](https://req.cool/docs/tutorial/middleware-for-request-and-response/)) and [Client and Transport Middleware](https://req.cool/docs/tutorial/middleware-for-client-and-transport/)). ## Get Started From fbbc61bbd083b2d38603b718254f366fb40cfbee Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 17 Aug 2022 11:48:49 +0800 Subject: [PATCH 592/843] add ClearCookies for Client --- client.go | 9 +++++++++ client_wrapper.go | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/client.go b/client.go index 2c720a80..62c4bbf7 100644 --- a/client.go +++ b/client.go @@ -866,6 +866,15 @@ func (c *Client) SetCookieJar(jar http.CookieJar) *Client { return c } +// ClearCookies clears all cookies if cookie is enabled. +func (c *Client) ClearCookies() *Client { + if c.httpClient.Jar != nil { + jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + c.httpClient.Jar = jar + } + return c +} + // SetJsonMarshal set the JSON marshal function which will be used // to marshal request body. func (c *Client) SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { diff --git a/client_wrapper.go b/client_wrapper.go index 84246d00..d919fd0d 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -442,6 +442,12 @@ func SetCookieJar(jar http.CookieJar) *Client { return defaultClient.SetCookieJar(jar) } +// ClearCookies is a global wrapper methods which delegated +// to the default client's ClearCookies. +func ClearCookies() *Client { + return defaultClient.ClearCookies() +} + // SetJsonMarshal is a global wrapper methods which delegated // to the default client's SetJsonMarshal. func SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { From 548f200cbf55639fd574796f68a39c19cb17b6bb Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 21 Aug 2022 09:52:57 +0800 Subject: [PATCH 593/843] fix autodecode when html page is small --- decode.go | 2 +- request.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/decode.go b/decode.go index 9d453c57..825cffc6 100644 --- a/decode.go +++ b/decode.go @@ -44,7 +44,7 @@ type autoDecodeReadCloser struct { func (a *autoDecodeReadCloser) peekRead(p []byte) (n int, err error) { n, err = a.ReadCloser.Read(p) - if n == 0 || err != nil { + if n == 0 || (err != nil && err != io.EOF) { return } a.detected = true diff --git a/request.go b/request.go index 591dd89e..2502e5eb 100644 --- a/request.go +++ b/request.go @@ -379,7 +379,6 @@ func (r *Request) SetHeader(key, value string) *Request { r.Headers = make(http.Header) } r.Headers.Set(key, value) - return r } From 13a8efced7bb280703deac7adcd8931694eb2b9e Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 21 Aug 2022 18:51:43 +0800 Subject: [PATCH 594/843] support http3 in go1.19 --- transport.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport.go b/transport.go index 8cc25b43..323254fa 100644 --- a/transport.go +++ b/transport.go @@ -405,7 +405,7 @@ func (t *Transport) EnableHTTP3() { } return } - if !(minorVersion >= 16 && minorVersion <= 18) { + if !(minorVersion >= 16 && minorVersion <= 19) { if t.Debugf != nil { t.Debugf("%s is not support http3", v) } From f91733eba0ead0f17cc5b1add9028e44226fbb4e Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 21 Aug 2022 18:55:28 +0800 Subject: [PATCH 595/843] update README --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4baae90b..06247454 100644 --- a/README.md +++ b/README.md @@ -106,8 +106,10 @@ if resp.Err != nil { **Videos** -* [Get Started With Req](https://www.youtube.com/watch?v=k47i0CKBVrA) (English, Youtube) -* [快速上手 req](https://www.bilibili.com/video/BV1Xq4y1b7UR) (Chinese, BiliBili) +The following is a series of video tutorials for req: + +* [Youtube Play List](https://www.youtube.com/watch?v=Dy8iph8JWw0&list=PLnW6i9cc0XqlhUgOJJp5Yf1FHXlANYMhF&index=2) +* [BiliBili 播放列表](https://www.bilibili.com/video/BV14t4y1J7cm) **More** From 288f4e1c1959101236bd672c29b8df6ce7304952 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 24 Aug 2022 09:24:04 +0800 Subject: [PATCH 596/843] update README --- README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 06247454..e45030d5 100644 --- a/README.md +++ b/README.md @@ -33,12 +33,16 @@ Full documentation is available on the official website: https://req.cool. **Install** +You first need [Go](https://go.dev/) installed (version 1.16+ is required), then you can use the below Go command to install req. + ``` sh go get github.com/imroc/req/v3 ``` **Import** +Import req to your code: + ```go import "github.com/imroc/req/v3" ``` @@ -109,7 +113,7 @@ if resp.Err != nil { The following is a series of video tutorials for req: * [Youtube Play List](https://www.youtube.com/watch?v=Dy8iph8JWw0&list=PLnW6i9cc0XqlhUgOJJp5Yf1FHXlANYMhF&index=2) -* [BiliBili 播放列表](https://www.bilibili.com/video/BV14t4y1J7cm) +* [BiliBili 播放列表](https://www.bilibili.com/video/BV14t4y1J7cm) (Chinese) **More** @@ -124,7 +128,7 @@ If you have a bug report or feature request, you can [open an issue](https://git If you have questions, feel free to reach out to us in the following ways: * [Github Discussion](https://github.com/imroc/req/discussions) -* [Slack](https://imroc-req.slack.com/archives/C03UFPGSNC8) (International) | [Join](https://slack.req.cool/) +* [Slack](https://imroc-req.slack.com/archives/C03UFPGSNC8) | [Join](https://slack.req.cool/) * QQ Group (Chinese): 621411351 - ## License From 1cf597212cc8d7918ef0c2c473f13eec035f8133 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 24 Aug 2022 09:37:22 +0800 Subject: [PATCH 597/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e45030d5..c901753c 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ Full documentation is available on the official website: https://req.cool. **Install** -You first need [Go](https://go.dev/) installed (version 1.16+ is required), then you can use the below Go command to install req. +You first need [Go](https://go.dev/) installed (version 1.16+ is required), then you can use the below Go command to install req: ``` sh go get github.com/imroc/req/v3 From 3da9823500484a61c692a51419dd3e01c79b7dd8 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 24 Aug 2022 14:03:39 +0800 Subject: [PATCH 598/843] fix concurrent map in Client.Clone() when high concurrency (#157) --- client.go | 28 ++++++++++++++-------------- internal/transport/option.go | 12 ++++++++++++ transport.go | 23 +++++++++-------------- 3 files changed, 35 insertions(+), 28 deletions(-) diff --git a/client.go b/client.go index 62c4bbf7..66c053cd 100644 --- a/client.go +++ b/client.go @@ -1074,15 +1074,18 @@ func NewClient() *Client { // Clone copy and returns the Client func (c *Client) Clone() *Client { - t := c.t.Clone() + cc := *c - client := *c.httpClient - client.Transport = t + // clone Transport + cc.t = c.t.Clone() + cc.initTransport() - cc := *c + // clone http.Client + client := *c.httpClient + client.Transport = cc.t cc.httpClient = &client - cc.t = t + // clone other fields that may need to be cloned cc.Headers = cloneHeaders(c.Headers) cc.Cookies = cloneCookies(c.Cookies) cc.PathParams = cloneMap(c.PathParams) @@ -1093,13 +1096,6 @@ func (c *Client) Clone() *Client { cc.afterResponse = cloneResponseMiddleware(c.afterResponse) cc.dumpOptions = c.dumpOptions.Clone() cc.retryOption = c.retryOption.Clone() - - cc.log = c.log - cc.jsonUnmarshal = c.jsonUnmarshal - cc.jsonMarshal = c.jsonMarshal - cc.xmlMarshal = c.xmlMarshal - cc.xmlUnmarshal = c.xmlUnmarshal - return &cc } @@ -1137,12 +1133,16 @@ func C() *Client { } httpClient.CheckRedirect = c.defaultCheckRedirect - t.Debugf = func(format string, v ...interface{}) { + c.initTransport() + return c +} + +func (c *Client) initTransport() { + c.t.Debugf = func(format string, v ...interface{}) { if c.DebugLog { c.log.Debugf(format, v...) } } - return c } // RoundTripper is the interface of req's Client. diff --git a/internal/transport/option.go b/internal/transport/option.go index 54acc9b5..ed9538b9 100644 --- a/internal/transport/option.go +++ b/internal/transport/option.go @@ -143,3 +143,15 @@ type Options struct { Dump *dump.Dumper } + +func (o Options) Clone() Options { + oo := o + if o.TLSClientConfig != nil { + oo.TLSClientConfig = o.TLSClientConfig.Clone() + } + if o.Dump != nil { + oo.Dump = o.Dump.Clone() + go oo.Dump.Start() + } + return oo +} diff --git a/transport.go b/transport.go index 323254fa..07dfd506 100644 --- a/transport.go +++ b/transport.go @@ -578,25 +578,20 @@ func (t *Transport) readBufferSize() int { // Clone returns a deep copy of t's exported fields. func (t *Transport) Clone() *Transport { - tt := *t - tt.Dump = t.Dump.Clone() - if tt.Dump != nil { - go tt.Dump.Start() - } - if tt.TLSClientConfig != nil { - tt.TLSClientConfig = tt.TLSClientConfig.Clone() + tt := &Transport{ + Options: t.Options.Clone(), + disableAutoDecode: t.disableAutoDecode, + autoDecodeContentType: t.autoDecodeContentType, + wrappedRoundTrip: t.wrappedRoundTrip, + forceHttpVersion: t.forceHttpVersion, } if t.t2 != nil { - t2 := *t.t2 - t2.Options = &tt.Options - tt.t2 = &t2 + tt.t2 = &http2.Transport{Options: &tt.Options} } if t.t3 != nil { - t3 := *t.t3 - t3.Options = &tt.Options - tt.t3 = &t3 + tt.EnableHTTP3() } - return &tt + return tt } // EnableDump enables the dump for all requests with specified dump options. From 9a2746374d8394b9bce0449114490d6bbec22a05 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 24 Aug 2022 14:32:22 +0800 Subject: [PATCH 599/843] fix no Host in URL when SetScheme invoked --- middleware.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/middleware.go b/middleware.go index 392a8e16..bf1f583d 100644 --- a/middleware.go +++ b/middleware.go @@ -394,6 +394,13 @@ func parseRequestURL(c *Client, r *Request) error { return err } + if reqURL.Scheme == "" && len(c.scheme) > 0 { // set scheme if missing + reqURL, err = url.Parse(c.scheme + "://" + tempURL) + if err != nil { + return err + } + } + // If RawURL is relative path then added c.BaseURL into // the request URL otherwise Request.URL will be used as-is if !reqURL.IsAbs() { @@ -408,14 +415,6 @@ func parseRequestURL(c *Client, r *Request) error { } } - if reqURL.Scheme == "" && len(c.scheme) > 0 { - reqURL.Scheme = c.scheme - reqURL, err = url.Parse(reqURL.String()) // prevent empty URL.Host - if err != nil { - return err - } - } - // Adding Query Param query := make(url.Values) for k, v := range c.QueryParams { From b5ecb93ada28a067e12a87fa865f1e13367e7bca Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 24 Aug 2022 21:26:09 +0800 Subject: [PATCH 600/843] support h2c --- client.go | 12 ++++++++++++ client_wrapper.go | 12 ++++++++++++ internal/transport/option.go | 3 +++ transport.go | 27 +++++++++++++++++++++++++-- 4 files changed, 52 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 66c053cd..6e545184 100644 --- a/client.go +++ b/client.go @@ -950,6 +950,18 @@ func (c *Client) DisableForceHttpVersion() *Client { return c } +// EnableH2C enables HTTP/2 over TCP without TLS. +func (c *Client) EnableH2C() *Client { + c.t.EnableH2C() + return c +} + +// DisableH2C disables HTTP/2 over TCP without TLS. +func (c *Client) DisableH2C() *Client { + c.t.DisableH2C() + return c +} + // DisableAllowGetMethodPayload disable sending GET method requests with body. func (c *Client) DisableAllowGetMethodPayload() *Client { c.AllowGetMethodPayload = false diff --git a/client_wrapper.go b/client_wrapper.go index d919fd0d..97661081 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -520,6 +520,18 @@ func DisableForceHttpVersion() *Client { return defaultClient.DisableForceHttpVersion() } +// EnableH2C is a global wrapper methods which delegated +// to the default client's EnableH2C. +func EnableH2C() *Client { + return defaultClient.EnableH2C() +} + +// DisableH2C is a global wrapper methods which delegated +// to the default client's DisableH2C. +func DisableH2C() *Client { + return defaultClient.DisableH2C() +} + // DisableAllowGetMethodPayload is a global wrapper methods which delegated // to the default client's DisableAllowGetMethodPayload. func DisableAllowGetMethodPayload() *Client { diff --git a/internal/transport/option.go b/internal/transport/option.go index ed9538b9..714aea34 100644 --- a/internal/transport/option.go +++ b/internal/transport/option.go @@ -70,6 +70,9 @@ type Options struct { // uncompressed. DisableCompression bool + // EnableH2C, if true, enables http2 over plain http without tls. + EnableH2C bool + // MaxIdleConns controls the maximum number of idle (keep-alive) // connections across all hosts. Zero means no limit. MaxIdleConns int diff --git a/transport.go b/transport.go index 07dfd506..84adbd5f 100644 --- a/transport.go +++ b/transport.go @@ -363,6 +363,24 @@ func (t *Transport) EnableForceHTTP2() *Transport { return t } +// EnableH2C enables HTTP2 over TCP without TLS. +func (t *Transport) EnableH2C() *Transport { + t.Options.EnableH2C = true + t.t2.AllowHTTP = true + t.t2.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { + return net.Dial(network, addr) + } + return t +} + +// DisableH2C disables HTTP2 over TCP without TLS. +func (t *Transport) DisableH2C() *Transport { + t.Options.EnableH2C = false + t.t2.AllowHTTP = false + t.t2.DialTLS = nil + return t +} + // EnableForceHTTP3 enable force using HTTP3 for https requests // (disabled by default). func (t *Transport) EnableForceHTTP3() *Transport { @@ -717,8 +735,13 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error } } - if t.forceHttpVersion == h3 { - return t.t3.RoundTrip(req) + if t.forceHttpVersion != "" { + switch t.forceHttpVersion { + case h3: + return t.t3.RoundTrip(req) + case h2: + return t.t2.RoundTrip(req) + } } origReq := req From fe49f385f328b2a724cad13f0032b07f16f3feec Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 26 Aug 2022 20:36:50 +0800 Subject: [PATCH 601/843] beautify code --- client.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 6e545184..4b671c2e 100644 --- a/client.go +++ b/client.go @@ -1345,10 +1345,7 @@ func (c *Client) do(r *Request) (resp *Response, err error) { } if r.retryOption == nil || r.RetryAttempt >= r.retryOption.MaxRetries { // absolutely cannot retry. - if err != nil { // return immediately if error occurs. - return - } - break // jump out to execute the ResponseMiddlewares if possible. + return } // check retry whether is needed. @@ -1360,7 +1357,7 @@ func (c *Client) do(r *Request) (resp *Response, err error) { } } if !needRetry { // no retry is needed. - break // jump out to execute the ResponseMiddlewares. + return } // need retry, attempt to retry From 5a3fb0994710ea09f23b025031b22948a7dfcb90 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 26 Aug 2022 21:09:51 +0800 Subject: [PATCH 602/843] fix data race (#159) --- client.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index 4b671c2e..520e0a81 100644 --- a/client.go +++ b/client.go @@ -1223,8 +1223,11 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { if r.trace == nil && r.client.trace { r.trace = &clientTrace{} } + + ctx := r.ctx + if r.trace != nil { - r.ctx = r.trace.createContext(r.Context()) + ctx = r.trace.createContext(r.Context()) } // setup url and host @@ -1260,7 +1263,6 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { for _, cookie := range r.Cookies { req.AddCookie(cookie) } - ctx := r.ctx if r.isSaveResponse && r.downloadCallback != nil { var wrap wrapResponseBodyFunc = func(rc io.ReadCloser) io.ReadCloser { return &callbackReader{ @@ -1275,10 +1277,7 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { interval: r.downloadCallbackInterval, } } - if ctx == nil { - ctx = context.Background() - } - ctx = context.WithValue(ctx, wrapResponseBodyKey, wrap) + ctx = context.WithValue(r.Context(), wrapResponseBodyKey, wrap) } if ctx != nil { req = req.WithContext(ctx) @@ -1371,7 +1370,9 @@ func (c *Client) do(r *Request) (resp *Response, err error) { if r.dumpBuffer != nil { r.dumpBuffer.Reset() } - r.trace = nil + if r.trace != nil { + r.trace = &clientTrace{} + } resp.body = nil resp.result = nil resp.error = nil From e39eaa0c74755d9e9e8fc7f87fdb46870dc1a9cf Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 27 Aug 2022 15:39:32 +0800 Subject: [PATCH 603/843] not use chunked encoding by default when upload (#160) --- middleware.go | 33 +++++++++++++++++++++++---------- request.go | 12 ++++++++++++ 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/middleware.go b/middleware.go index bf1f583d..3e0e1cd0 100644 --- a/middleware.go +++ b/middleware.go @@ -82,7 +82,7 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload, r *Request) e return err } - if r.uploadCallback != nil { + if r.forceChunkedEncoding && r.uploadCallback != nil { pw = &callbackWriter{ Writer: pw, lastTime: lastTime, @@ -110,7 +110,8 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload, r *Request) e return err } -func writeMultiPart(r *Request, w *multipart.Writer, pw *io.PipeWriter) { +func writeMultiPart(r *Request, w *multipart.Writer) { + defer w.Close() // close multipart to write tailer boundary for k, vs := range r.FormData { for _, v := range vs { w.WriteField(k, v) @@ -119,18 +120,30 @@ func writeMultiPart(r *Request, w *multipart.Writer, pw *io.PipeWriter) { for _, file := range r.uploadFiles { writeMultipartFormFile(w, file, r) } - w.Close() // close multipart to write tailer boundary - pw.Close() // close pipe writer so that pipe reader could get EOF, and stop upload } func handleMultiPart(c *Client, r *Request) (err error) { - pr, pw := io.Pipe() - r.GetBody = func() (io.ReadCloser, error) { - return pr, nil + if r.forceChunkedEncoding { + pr, pw := io.Pipe() + r.GetBody = func() (io.ReadCloser, error) { + return pr, nil + } + w := multipart.NewWriter(pw) + r.SetContentType(w.FormDataContentType()) + go func() { + writeMultiPart(r, w) + pw.Close() // close pipe writer so that pipe reader could get EOF, and stop upload + }() + } else { + buf := new(bytes.Buffer) + w := multipart.NewWriter(buf) + writeMultiPart(r, w) + r.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil + } + r.Body = buf.Bytes() + r.SetContentType(w.FormDataContentType()) } - w := multipart.NewWriter(pw) - r.SetContentType(w.FormDataContentType()) - go writeMultiPart(r, w, pw) return } diff --git a/request.go b/request.go index 2502e5eb..e387b923 100644 --- a/request.go +++ b/request.go @@ -41,6 +41,7 @@ type Request struct { GetBody GetContentFunc isMultiPart bool + forceChunkedEncoding bool isSaveResponse bool error error client *Client @@ -319,6 +320,7 @@ func (r *Request) SetUploadCallbackWithInterval(callback UploadCallback, minInte if callback == nil { return r } + r.forceChunkedEncoding = true r.uploadCallback = callback r.uploadCallbackInterval = minInterval return r @@ -857,6 +859,16 @@ func (r *Request) EnableDumpWithoutResponseBody() *Request { return r.EnableDump() } +func (r *Request) EnableForceChunkedEncoding() *Request { + r.forceChunkedEncoding = true + return r +} + +func (r *Request) DisableForceChunkedEncoding() *Request { + r.forceChunkedEncoding = true + return r +} + func (r *Request) getRetryOption() *retryOption { if r.retryOption == nil { r.retryOption = newDefaultRetryOption() From f2fd2d981fd1e02ba9027c92b58668ab7c23f9a2 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 27 Aug 2022 15:44:51 +0800 Subject: [PATCH 604/843] add comments and global wrapper --- request.go | 2 ++ request_wrapper.go | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/request.go b/request.go index e387b923..3366a8d4 100644 --- a/request.go +++ b/request.go @@ -859,11 +859,13 @@ func (r *Request) EnableDumpWithoutResponseBody() *Request { return r.EnableDump() } +// EnableForceChunkedEncoding enables force using chunked encoding when uploading. func (r *Request) EnableForceChunkedEncoding() *Request { r.forceChunkedEncoding = true return r } +// DisableForceChunkedEncoding disables force using chunked encoding when uploading. func (r *Request) DisableForceChunkedEncoding() *Request { r.forceChunkedEncoding = true return r diff --git a/request_wrapper.go b/request_wrapper.go index d4a8d24d..ac775fbc 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -320,6 +320,18 @@ func EnableTrace() *Request { return defaultClient.R().EnableTrace() } +// EnableForceChunkedEncoding is a global wrapper methods which delegated +// to the default client, create a request and EnableForceChunkedEncoding for request. +func EnableForceChunkedEncoding() *Request { + return defaultClient.R().EnableForceChunkedEncoding() +} + +// DisableForceChunkedEncoding is a global wrapper methods which delegated +// to the default client, create a request and DisableForceChunkedEncoding for request. +func DisableForceChunkedEncoding() *Request { + return defaultClient.R().DisableForceChunkedEncoding() +} + // EnableDumpTo is a global wrapper methods which delegated // to the default client, create a request and EnableDumpTo for request. func EnableDumpTo(output io.Writer) *Request { From 592eee8b8ab65539fdad907da0cdbeb298def356 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 29 Aug 2022 14:59:59 +0800 Subject: [PATCH 605/843] Add EnableForceMultipart/DisableForceMultipart for Request --- request.go | 14 +++++++++++++- request_wrapper.go | 12 ++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/request.go b/request.go index 3366a8d4..7ea8c23c 100644 --- a/request.go +++ b/request.go @@ -867,7 +867,19 @@ func (r *Request) EnableForceChunkedEncoding() *Request { // DisableForceChunkedEncoding disables force using chunked encoding when uploading. func (r *Request) DisableForceChunkedEncoding() *Request { - r.forceChunkedEncoding = true + r.forceChunkedEncoding = false + return r +} + +// EnableForceMultipart enables force using multipart to upload form data. +func (r *Request) EnableForceMultipart() *Request { + r.isMultiPart = true + return r +} + +// DisableForceMultipart disables force using multipart to upload form data. +func (r *Request) DisableForceMultipart() *Request { + r.isMultiPart = true return r } diff --git a/request_wrapper.go b/request_wrapper.go index ac775fbc..a78a0684 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -332,6 +332,18 @@ func DisableForceChunkedEncoding() *Request { return defaultClient.R().DisableForceChunkedEncoding() } +// EnableForceMultipart is a global wrapper methods which delegated +// to the default client, create a request and EnableForceMultipart for request. +func EnableForceMultipart() *Request { + return defaultClient.R().EnableForceMultipart() +} + +// DisableForceMultipart is a global wrapper methods which delegated +// to the default client, create a request and DisableForceMultipart for request. +func DisableForceMultipart() *Request { + return defaultClient.R().DisableForceMultipart() +} + // EnableDumpTo is a global wrapper methods which delegated // to the default client, create a request and EnableDumpTo for request. func EnableDumpTo(output io.Writer) *Request { From 1dff30329f6a1a33b29d9cbc5f36c8cb9d671b65 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 29 Aug 2022 20:11:34 +0800 Subject: [PATCH 606/843] Support change the client of request dynamically. Add SetClient and GetClient for Request. --- client.go | 70 -------------------------------------------- request.go | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 71 deletions(-) diff --git a/client.go b/client.go index 520e0a81..4245a09c 100644 --- a/client.go +++ b/client.go @@ -1310,73 +1310,3 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { } return } - -func (c *Client) do(r *Request) (resp *Response, err error) { - defer func() { - if resp == nil { - resp = &Response{Request: r} - } - if err != nil { - resp.Err = err - } - }() - - for { - for _, f := range r.client.beforeRequest { - if err = f(r.client, r); err != nil { - return - } - } - for _, f := range r.client.udBeforeRequest { - if err = f(r.client, r); err != nil { - return - } - } - - if r.Headers == nil { - r.Headers = make(http.Header) - } - - if c.wrappedRoundTrip != nil { - resp, err = c.wrappedRoundTrip.RoundTrip(r) - } else { - resp, err = c.roundTrip(r) - } - - if r.retryOption == nil || r.RetryAttempt >= r.retryOption.MaxRetries { // absolutely cannot retry. - return - } - - // check retry whether is needed. - needRetry := err != nil // default behaviour: retry if error occurs - for _, condition := range r.retryOption.RetryConditions { // override default behaviour if custom RetryConditions has been set. - needRetry = condition(resp, err) - if needRetry { - break - } - } - if !needRetry { // no retry is needed. - return - } - - // need retry, attempt to retry - r.RetryAttempt++ - for _, hook := range r.retryOption.RetryHooks { // run retry hooks - hook(resp, err) - } - time.Sleep(r.retryOption.GetRetryInterval(resp, r.RetryAttempt)) - - // clean up before retry - if r.dumpBuffer != nil { - r.dumpBuffer.Reset() - } - if r.trace != nil { - r.trace = &clientTrace{} - } - resp.body = nil - resp.result = nil - resp.error = nil - } - - return -} diff --git a/request.go b/request.go index 7ea8c23c..62c2b72f 100644 --- a/request.go +++ b/request.go @@ -504,10 +504,80 @@ func (r *Request) Do(ctx ...context.Context) *Response { if r.retryOption != nil && r.retryOption.MaxRetries > 0 && r.unReplayableBody != nil { // retryable request should not have unreplayable Body return r.newErrorResponse(errRetryableWithUnReplayableBody) } - resp, _ := r.client.do(r) + resp, _ := r.do() return resp } +func (r *Request) do() (resp *Response, err error) { + defer func() { + if resp == nil { + resp = &Response{Request: r} + } + if err != nil { + resp.Err = err + } + }() + + for { + for _, f := range r.client.beforeRequest { + if err = f(r.client, r); err != nil { + return + } + } + for _, f := range r.client.udBeforeRequest { + if err = f(r.client, r); err != nil { + return + } + } + + if r.Headers == nil { + r.Headers = make(http.Header) + } + + if r.client.wrappedRoundTrip != nil { + resp, err = r.client.wrappedRoundTrip.RoundTrip(r) + } else { + resp, err = r.client.roundTrip(r) + } + + if r.retryOption == nil || r.RetryAttempt >= r.retryOption.MaxRetries { // absolutely cannot retry. + return + } + + // check retry whether is needed. + needRetry := err != nil // default behaviour: retry if error occurs + for _, condition := range r.retryOption.RetryConditions { // override default behaviour if custom RetryConditions has been set. + needRetry = condition(resp, err) + if needRetry { + break + } + } + if !needRetry { // no retry is needed. + return + } + + // need retry, attempt to retry + r.RetryAttempt++ + for _, hook := range r.retryOption.RetryHooks { // run retry hooks + hook(resp, err) + } + time.Sleep(r.retryOption.GetRetryInterval(resp, r.RetryAttempt)) + + // clean up before retry + if r.dumpBuffer != nil { + r.dumpBuffer.Reset() + } + if r.trace != nil { + r.trace = &clientTrace{} + } + resp.body = nil + resp.result = nil + resp.error = nil + } + + return +} + // Send fires http request with specified method and url, returns the // *Response which is always not nil, and the error is not nil if error occurs. func (r *Request) Send(method, url string) (*Response, error) { @@ -955,3 +1025,16 @@ func (r *Request) AddRetryCondition(condition RetryConditionFunc) *Request { ro.RetryConditions = append(ro.RetryConditions, condition) return r } + +// SetClient change the client of request dynamically. +func (r *Request) SetClient(client *Client) *Request { + if client != nil { + r.client = client + } + return r +} + +// GetClient returns the current client used by request. +func (r *Request) GetClient() *Client { + return r.client +} From 47a82c9bb4de407e8c14ed2e36784a52669af9d2 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 18 Sep 2022 22:17:36 +0800 Subject: [PATCH 607/843] optimize debug log level when cannot determine the unmarshal function --- middleware.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/middleware.go b/middleware.go index 3e0e1cd0..ac670fd4 100644 --- a/middleware.go +++ b/middleware.go @@ -231,7 +231,9 @@ func unmarshalBody(c *Client, r *Response, v interface{}) (err error) { } else if util.IsXMLType(ct) { return c.xmlUnmarshal(body, v) } else { - c.log.Debugf("cannot determine the unmarshal function with %q Content-Type, default to json", ct) + if c.DebugLog { + c.log.Debugf("cannot determine the unmarshal function with %q Content-Type, default to json", ct) + } return c.jsonUnmarshal(body, v) } return From 5b2c1a444016c7aaf9b9a4b36e4beb64d6d88e51 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 19 Sep 2022 20:05:00 +0800 Subject: [PATCH 608/843] Support AddQueryParams (#164) --- client.go | 11 +++++++++++ client_wrapper.go | 6 ++++++ request.go | 11 +++++++++++ request_test.go | 19 +++---------------- request_wrapper.go | 6 ++++++ 5 files changed, 37 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index 4245a09c..fda6d6f2 100644 --- a/client.go +++ b/client.go @@ -372,6 +372,17 @@ func (c *Client) AddCommonQueryParam(key, value string) *Client { return c } +// AddCommonQueryParams add one or more values of specified URL query parameter for all requests. +func (c *Client) AddCommonQueryParams(key string, values ...string) *Client { + if c.QueryParams == nil { + c.QueryParams = make(urlpkg.Values) + } + vs := c.QueryParams[key] + vs = append(vs, values...) + c.QueryParams[key] = vs + return c +} + func (c *Client) pathParams() map[string]string { if c.PathParams == nil { c.PathParams = make(map[string]string) diff --git a/client_wrapper.go b/client_wrapper.go index 97661081..7c1016e9 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -142,6 +142,12 @@ func AddCommonQueryParam(key, value string) *Client { return defaultClient.AddCommonQueryParam(key, value) } +// AddCommonQueryParams is a global wrapper methods which delegated +// to the default client's AddCommonQueryParams. +func AddCommonQueryParams(key string, values ...string) *Client { + return defaultClient.AddCommonQueryParams(key, values...) +} + // SetCommonPathParam is a global wrapper methods which delegated // to the default client's SetCommonPathParam. func SetCommonPathParam(key, value string) *Client { diff --git a/request.go b/request.go index 62c2b72f..c9323777 100644 --- a/request.go +++ b/request.go @@ -456,6 +456,17 @@ func (r *Request) AddQueryParam(key, value string) *Request { return r } +// AddQueryParams add one or more values of specified URL query parameter for the request. +func (r *Request) AddQueryParams(key string, values ...string) *Request { + if r.QueryParams == nil { + r.QueryParams = make(urlpkg.Values) + } + vs := r.QueryParams[key] + vs = append(vs, values...) + r.QueryParams[key] = vs + return r +} + // SetPathParams set URL path parameters from a map for the request. func (r *Request) SetPathParams(params map[string]string) *Request { for key, value := range params { diff --git a/request_test.go b/request_test.go index 76ddbb8c..4dad7c4c 100644 --- a/request_test.go +++ b/request_test.go @@ -621,21 +621,6 @@ func testQueryParam(t *testing.T, c *Client) { assertSuccess(t, resp, err) tests.AssertEqual(t, "key1=value1&key2=value2&key3=value3&key4=value4&key5=value5", resp.String()) - // Set same param to override - resp, err = c.R(). - SetQueryParam("key1", "value1"). - SetQueryParams(map[string]string{ - "key2": "value2", - "key3": "value3", - }). - SetQueryString("key4=value4&key5=value5"). - SetQueryParam("key1", "value11"). - SetQueryParam("key2", "value22"). - SetQueryParam("key4", "value44"). - Get("/query-parameter") - assertSuccess(t, resp, err) - tests.AssertEqual(t, "key1=value11&key2=value22&key3=value3&key4=value44&key5=value5", resp.String()) - // Add same param without override resp, err = c.R(). SetQueryParam("key1", "value1"). @@ -647,9 +632,11 @@ func testQueryParam(t *testing.T, c *Client) { AddQueryParam("key1", "value11"). AddQueryParam("key2", "value22"). AddQueryParam("key4", "value44"). + AddQueryParams("key6", "value6", "value66"). Get("/query-parameter") assertSuccess(t, resp, err) - tests.AssertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5", resp.String()) + tests.AssertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5&key6=value6&key6=value66", resp.String()) + } func TestPathParam(t *testing.T) { diff --git a/request_wrapper.go b/request_wrapper.go index a78a0684..d765e404 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -146,6 +146,12 @@ func AddQueryParam(key, value string) *Request { return defaultClient.R().AddQueryParam(key, value) } +// AddQueryParams is a global wrapper methods which delegated +// to the default client, create a request and AddQueryParams for request. +func AddQueryParams(key string, values ...string) *Request { + return defaultClient.R().AddQueryParams(key, values...) +} + // SetPathParams is a global wrapper methods which delegated // to the default client, create a request and SetPathParams for request. func SetPathParams(params map[string]string) *Request { From 2abac7be7b6d38f7d7a70d6b63eebdb18d7a789d Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 19 Sep 2022 20:37:02 +0800 Subject: [PATCH 609/843] Record original request in http3 (#165) --- internal/http3/client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/http3/client.go b/internal/http3/client.go index f43c9ba1..bc22b0e4 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -442,6 +442,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, ProtoMajor: 3, Header: http.Header{}, TLS: &connState, + Request: req, } for _, hf := range hfs { switch hf.Name { From eb32195eba47ebd0112a2616a8d649211172e679 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 24 Sep 2022 18:13:51 +0800 Subject: [PATCH 610/843] support parallel download --- client.go | 7 ++ logger.go | 6 +- parallel_download.go | 294 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 306 insertions(+), 1 deletion(-) create mode 100644 parallel_download.go diff --git a/client.go b/client.go index fda6d6f2..29d523fb 100644 --- a/client.go +++ b/client.go @@ -693,6 +693,13 @@ func (c *Client) NewRequest() *Request { return c.R() } +func (c *Client) NewParallelDownload(url string) *ParallelDownload { + return &ParallelDownload{ + url: url, + client: c, + } +} + // DisableAutoReadResponse disable read response body automatically (enabled by default). func (c *Client) DisableAutoReadResponse() *Client { c.disableAutoReadResponse = true diff --git a/logger.go b/logger.go index af461064..0a6bbd2f 100644 --- a/logger.go +++ b/logger.go @@ -4,6 +4,7 @@ import ( "io" "log" "os" + "sync" ) // Logger is the abstract logging interface, gives control to @@ -32,7 +33,8 @@ func (l *disableLogger) Warnf(format string, v ...interface{}) {} func (l *disableLogger) Debugf(format string, v ...interface{}) {} type logger struct { - l *log.Logger + mu sync.Mutex + l *log.Logger } func (l *logger) Errorf(format string, v ...interface{}) { @@ -48,6 +50,8 @@ func (l *logger) Debugf(format string, v ...interface{}) { } func (l *logger) output(level, format string, v ...interface{}) { + l.mu.Lock() + defer l.mu.Unlock() format = level + " [req] " + format if len(v) == 0 { l.l.Print(format) diff --git a/parallel_download.go b/parallel_download.go new file mode 100644 index 00000000..3765c47c --- /dev/null +++ b/parallel_download.go @@ -0,0 +1,294 @@ +package req + +import ( + "crypto/md5" + "encoding/hex" + "fmt" + "io" + "math" + urlpkg "net/url" + "os" + "path/filepath" + "strings" + "sync" +) + +type ParallelDownload struct { + url string + client *Client + concurrency int + output io.Writer + filename string + segmentSize int64 + perm os.FileMode + cacheRootDir string + cacheDir string + taskCh chan *downloadTask + doneCh chan struct{} + wgDoneCh chan struct{} + errCh chan error + wg sync.WaitGroup + taskMap map[int]*downloadTask + taskNotifyCh chan *downloadTask + mu sync.Mutex + lastIndex int +} + +func (pd *ParallelDownload) completeTask(task *downloadTask) { + pd.mu.Lock() + pd.taskMap[task.index] = task + pd.mu.Unlock() + pd.taskNotifyCh <- task +} + +func (pd *ParallelDownload) popTask(index int) *downloadTask { + pd.mu.Lock() + if task, ok := pd.taskMap[index]; ok { + delete(pd.taskMap, index) + pd.mu.Unlock() + return task + } + pd.mu.Unlock() + for { + task := <-pd.taskNotifyCh + if task.index == index { + pd.mu.Lock() + delete(pd.taskMap, index) + pd.mu.Unlock() + return task + } + } +} + +func md5Sum(s string) string { + sum := md5.Sum([]byte(s)) + return hex.EncodeToString(sum[:]) +} + +func (pd *ParallelDownload) ensure() error { + if pd.concurrency <= 0 { + pd.concurrency = 5 + } + if pd.segmentSize <= 0 { + pd.segmentSize = 1073741824 // 10MB + } + if pd.perm == 0 { + pd.perm = 0777 + } + if pd.cacheRootDir == "" { + pd.cacheRootDir = os.TempDir() + } + if pd.client.DebugLog { + pd.client.log.Debugf("use cache root directory %s", pd.cacheRootDir) + pd.client.log.Debugf("download with %d concurrency and %d bytes segment size", pd.concurrency, pd.segmentSize) + } + pd.cacheDir = filepath.Join(pd.cacheRootDir, md5Sum(pd.url)) + err := os.MkdirAll(pd.cacheDir, os.ModePerm) + if err != nil { + return err + } + + pd.taskCh = make(chan *downloadTask) + pd.doneCh = make(chan struct{}) + pd.wgDoneCh = make(chan struct{}) + pd.errCh = make(chan error) + pd.taskMap = make(map[int]*downloadTask) + pd.taskNotifyCh = make(chan *downloadTask) + return nil +} + +func (pd *ParallelDownload) SetSegmentSize(segmentSize int64) *ParallelDownload { + pd.segmentSize = segmentSize + return pd +} + +func (pd *ParallelDownload) SetCacheRootDir(cacheRootDir string) *ParallelDownload { + pd.cacheRootDir = cacheRootDir + return pd +} + +func (pd *ParallelDownload) SetFileMode(perm os.FileMode) *ParallelDownload { + pd.perm = perm + return pd +} + +func (pd *ParallelDownload) SetConcurrency(concurrency int) *ParallelDownload { + pd.concurrency = concurrency + return pd +} + +func (pd *ParallelDownload) SetOutput(output io.Writer) *ParallelDownload { + if output != nil { + pd.output = output + } + return pd +} + +func (pd *ParallelDownload) SetOutputFile(filename string) *ParallelDownload { + pd.filename = filename + return pd +} + +func getRangeTempFile(rangeStart, rangeEnd int64, workerDir string) string { + return filepath.Join(workerDir, fmt.Sprintf("temp-%d-%d", rangeStart, rangeEnd)) +} + +type downloadTask struct { + index int + rangeStart, rangeEnd int64 + tempFilename string + tempFile *os.File +} + +func (pd *ParallelDownload) handleTask(t *downloadTask) { + pd.wg.Add(1) + defer pd.wg.Done() + t.tempFilename = getRangeTempFile(t.rangeStart, t.rangeEnd, pd.cacheDir) + if pd.client.DebugLog { + pd.client.log.Debugf("downloading segment %d-%d", t.rangeStart, t.rangeEnd) + } + file, err := os.OpenFile(t.tempFilename, os.O_RDWR|os.O_CREATE, 0666) + if err != nil { + pd.errCh <- err + return + } + _, err = pd.client.R(). + SetHeader("Range", fmt.Sprintf("bytes=%d-%d", t.rangeStart, t.rangeEnd)). + SetOutput(file). + Get(pd.url) + if err != nil { + pd.errCh <- err + return + } + t.tempFile = file + pd.completeTask(t) +} + +func (pd *ParallelDownload) startWorker() { + for { + select { + case t := <-pd.taskCh: + pd.handleTask(t) + case <-pd.doneCh: + return + } + } +} + +func (pd *ParallelDownload) mergeFile() { + defer pd.wg.Done() + file, err := pd.getOutputFile() + if err != nil { + pd.errCh <- err + return + } + for i := 0; ; i++ { + task := pd.popTask(i) + tempFile, err := os.Open(task.tempFilename) + if err != nil { + pd.errCh <- err + return + } + _, err = io.Copy(file, tempFile) + tempFile.Close() + if err != nil { + pd.errCh <- err + return + } + if i < pd.lastIndex { + continue + } + break + } + err = os.RemoveAll(pd.cacheDir) + if err != nil { + pd.errCh <- err + } + if pd.client.DebugLog { + pd.client.log.Debugf("removed cache directory %s", pd.cacheDir) + } +} + +func (pd *ParallelDownload) Do() error { + err := pd.ensure() + if err != nil { + return err + } + for i := 0; i < pd.concurrency; i++ { + go pd.startWorker() + } + resp, err := pd.client.R().Head(pd.url) + if err != nil { + return err + } + if resp.ContentLength <= 0 { + return fmt.Errorf("bad content length: %d", resp.ContentLength) + } + pd.lastIndex = int(math.Ceil(float64(resp.ContentLength)/float64(pd.segmentSize))) - 1 + pd.wg.Add(1) + go pd.mergeFile() + go func() { + pd.wg.Wait() + close(pd.wgDoneCh) + }() + totalBytes := resp.ContentLength + start := int64(0) + for i := 0; ; i++ { + end := start + (pd.segmentSize - 1) + if end > (totalBytes - 1) { + end = totalBytes - 1 + } + task := &downloadTask{ + index: i, + rangeStart: start, + rangeEnd: end, + } + pd.taskCh <- task + if end < (totalBytes - 1) { + start = end + 1 + continue + } + break + } + select { + case <-pd.wgDoneCh: + if pd.client.DebugLog { + if pd.filename != "" { + pd.client.log.Debugf("download completed from %s to %s", pd.url, pd.filename) + } else { + pd.client.log.Debugf("download completed for %s", pd.url) + } + } + close(pd.doneCh) + case err := <-pd.errCh: + return err + } + return nil +} + +func (pd *ParallelDownload) getOutputFile() (io.Writer, error) { + outputFile := pd.output + if outputFile != nil { + return outputFile, nil + } + if pd.filename == "" { + u, err := urlpkg.Parse(pd.url) + if err != nil { + panic(err) + } + paths := strings.Split(u.Path, "/") + for i := len(paths) - 1; i > 0; i-- { + if paths[i] != "" { + pd.filename = paths[i] + break + } + } + if pd.filename == "" { + pd.filename = "download" + } + } + if pd.client.outputDirectory != "" && !filepath.IsAbs(pd.filename) { + pd.filename = filepath.Join(pd.client.outputDirectory, pd.filename) + } + return os.OpenFile(pd.filename, os.O_RDWR|os.O_CREATE, pd.perm) +} From 766bee48421279ed4c237b52147286c8553e266e Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 25 Sep 2022 11:36:04 +0800 Subject: [PATCH 611/843] optimize parallel download --- parallel_download.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/parallel_download.go b/parallel_download.go index 3765c47c..94ad286a 100644 --- a/parallel_download.go +++ b/parallel_download.go @@ -38,7 +38,12 @@ func (pd *ParallelDownload) completeTask(task *downloadTask) { pd.mu.Lock() pd.taskMap[task.index] = task pd.mu.Unlock() - pd.taskNotifyCh <- task + go func() { + select { + case pd.taskNotifyCh <- task: + case <-pd.doneCh: + } + }() } func (pd *ParallelDownload) popTask(index int) *downloadTask { @@ -78,11 +83,11 @@ func (pd *ParallelDownload) ensure() error { if pd.cacheRootDir == "" { pd.cacheRootDir = os.TempDir() } + pd.cacheDir = filepath.Join(pd.cacheRootDir, md5Sum(pd.url)) if pd.client.DebugLog { - pd.client.log.Debugf("use cache root directory %s", pd.cacheRootDir) + pd.client.log.Debugf("use cache directory %s", pd.cacheDir) pd.client.log.Debugf("download with %d concurrency and %d bytes segment size", pd.concurrency, pd.segmentSize) } - pd.cacheDir = filepath.Join(pd.cacheRootDir, md5Sum(pd.url)) err := os.MkdirAll(pd.cacheDir, os.ModePerm) if err != nil { return err @@ -200,13 +205,13 @@ func (pd *ParallelDownload) mergeFile() { } break } + if pd.client.DebugLog { + pd.client.log.Debugf("removing cache directory %s", pd.cacheDir) + } err = os.RemoveAll(pd.cacheDir) if err != nil { pd.errCh <- err } - if pd.client.DebugLog { - pd.client.log.Debugf("removed cache directory %s", pd.cacheDir) - } } func (pd *ParallelDownload) Do() error { From faa14f5018469e70b69648a41f8b0718e0464693 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 26 Sep 2022 10:17:04 +0800 Subject: [PATCH 612/843] add context to parallel download --- logger.go | 6 +----- parallel_download.go | 26 ++++++++++++++------------ request.go | 11 +++++------ 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/logger.go b/logger.go index 0a6bbd2f..af461064 100644 --- a/logger.go +++ b/logger.go @@ -4,7 +4,6 @@ import ( "io" "log" "os" - "sync" ) // Logger is the abstract logging interface, gives control to @@ -33,8 +32,7 @@ func (l *disableLogger) Warnf(format string, v ...interface{}) {} func (l *disableLogger) Debugf(format string, v ...interface{}) {} type logger struct { - mu sync.Mutex - l *log.Logger + l *log.Logger } func (l *logger) Errorf(format string, v ...interface{}) { @@ -50,8 +48,6 @@ func (l *logger) Debugf(format string, v ...interface{}) { } func (l *logger) output(level, format string, v ...interface{}) { - l.mu.Lock() - defer l.mu.Unlock() format = level + " [req] " + format if len(v) == 0 { l.l.Print(format) diff --git a/parallel_download.go b/parallel_download.go index 94ad286a..c08a8a47 100644 --- a/parallel_download.go +++ b/parallel_download.go @@ -1,6 +1,7 @@ package req import ( + "context" "crypto/md5" "encoding/hex" "fmt" @@ -145,22 +146,23 @@ type downloadTask struct { tempFile *os.File } -func (pd *ParallelDownload) handleTask(t *downloadTask) { +func (pd *ParallelDownload) handleTask(t *downloadTask, ctx ...context.Context) { pd.wg.Add(1) defer pd.wg.Done() t.tempFilename = getRangeTempFile(t.rangeStart, t.rangeEnd, pd.cacheDir) if pd.client.DebugLog { pd.client.log.Debugf("downloading segment %d-%d", t.rangeStart, t.rangeEnd) } - file, err := os.OpenFile(t.tempFilename, os.O_RDWR|os.O_CREATE, 0666) + file, err := os.OpenFile(t.tempFilename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) if err != nil { pd.errCh <- err return } - _, err = pd.client.R(). + err = pd.client.Get(pd.url). SetHeader("Range", fmt.Sprintf("bytes=%d-%d", t.rangeStart, t.rangeEnd)). SetOutput(file). - Get(pd.url) + Do(ctx...).Err + if err != nil { pd.errCh <- err return @@ -169,11 +171,11 @@ func (pd *ParallelDownload) handleTask(t *downloadTask) { pd.completeTask(t) } -func (pd *ParallelDownload) startWorker() { +func (pd *ParallelDownload) startWorker(ctx ...context.Context) { for { select { case t := <-pd.taskCh: - pd.handleTask(t) + pd.handleTask(t, ctx...) case <-pd.doneCh: return } @@ -214,17 +216,17 @@ func (pd *ParallelDownload) mergeFile() { } } -func (pd *ParallelDownload) Do() error { +func (pd *ParallelDownload) Do(ctx ...context.Context) error { err := pd.ensure() if err != nil { return err } for i := 0; i < pd.concurrency; i++ { - go pd.startWorker() + go pd.startWorker(ctx...) } - resp, err := pd.client.R().Head(pd.url) - if err != nil { - return err + resp := pd.client.Head(pd.url).Do(ctx...) + if resp.Err != nil { + return resp.Err } if resp.ContentLength <= 0 { return fmt.Errorf("bad content length: %d", resp.ContentLength) @@ -295,5 +297,5 @@ func (pd *ParallelDownload) getOutputFile() (io.Writer, error) { if pd.client.outputDirectory != "" && !filepath.IsAbs(pd.filename) { pd.filename = filepath.Join(pd.client.outputDirectory, pd.filename) } - return os.OpenFile(pd.filename, os.O_RDWR|os.O_CREATE, pd.perm) + return os.OpenFile(pd.filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, pd.perm) } diff --git a/request.go b/request.go index c9323777..361a30db 100644 --- a/request.go +++ b/request.go @@ -496,13 +496,10 @@ func (r *Request) newErrorResponse(err error) *Response { return resp } -// Do fires http request, 0 or 1 context ia allowed, and returns the *Response which +// Do fires http request, 0 or 1 context is allowed, and returns the *Response which // is always not nil, and Response.Err is not nil if error occurs. func (r *Request) Do(ctx ...context.Context) *Response { - if len(ctx) > 0 { - if len(ctx) > 1 { - return r.newErrorResponse(fmt.Errorf("only 0 or 1 context is allowed in Do, and %d are received", len(ctx))) - } + if len(ctx) > 0 && ctx[0] != nil { r.ctx = ctx[0] } @@ -820,7 +817,9 @@ func (r *Request) Context() context.Context { // See https://blog.golang.org/context article and the "context" package // documentation. func (r *Request) SetContext(ctx context.Context) *Request { - r.ctx = ctx + if ctx != nil { + r.ctx = ctx + } return r } From 5afabccc30e1a8316d5fa30680af4e8aee58a5de Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 26 Sep 2022 10:25:41 +0800 Subject: [PATCH 613/843] rename cache to temp in parallel download --- parallel_download.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/parallel_download.go b/parallel_download.go index c08a8a47..8690037b 100644 --- a/parallel_download.go +++ b/parallel_download.go @@ -22,8 +22,8 @@ type ParallelDownload struct { filename string segmentSize int64 perm os.FileMode - cacheRootDir string - cacheDir string + tempRootDir string + tempDir string taskCh chan *downloadTask doneCh chan struct{} wgDoneCh chan struct{} @@ -81,15 +81,15 @@ func (pd *ParallelDownload) ensure() error { if pd.perm == 0 { pd.perm = 0777 } - if pd.cacheRootDir == "" { - pd.cacheRootDir = os.TempDir() + if pd.tempRootDir == "" { + pd.tempRootDir = os.TempDir() } - pd.cacheDir = filepath.Join(pd.cacheRootDir, md5Sum(pd.url)) + pd.tempDir = filepath.Join(pd.tempRootDir, md5Sum(pd.url)) if pd.client.DebugLog { - pd.client.log.Debugf("use cache directory %s", pd.cacheDir) + pd.client.log.Debugf("use cache directory %s", pd.tempDir) pd.client.log.Debugf("download with %d concurrency and %d bytes segment size", pd.concurrency, pd.segmentSize) } - err := os.MkdirAll(pd.cacheDir, os.ModePerm) + err := os.MkdirAll(pd.tempDir, os.ModePerm) if err != nil { return err } @@ -108,8 +108,8 @@ func (pd *ParallelDownload) SetSegmentSize(segmentSize int64) *ParallelDownload return pd } -func (pd *ParallelDownload) SetCacheRootDir(cacheRootDir string) *ParallelDownload { - pd.cacheRootDir = cacheRootDir +func (pd *ParallelDownload) SetTempRootDir(tempRootDir string) *ParallelDownload { + pd.tempRootDir = tempRootDir return pd } @@ -149,7 +149,7 @@ type downloadTask struct { func (pd *ParallelDownload) handleTask(t *downloadTask, ctx ...context.Context) { pd.wg.Add(1) defer pd.wg.Done() - t.tempFilename = getRangeTempFile(t.rangeStart, t.rangeEnd, pd.cacheDir) + t.tempFilename = getRangeTempFile(t.rangeStart, t.rangeEnd, pd.tempDir) if pd.client.DebugLog { pd.client.log.Debugf("downloading segment %d-%d", t.rangeStart, t.rangeEnd) } @@ -208,9 +208,9 @@ func (pd *ParallelDownload) mergeFile() { break } if pd.client.DebugLog { - pd.client.log.Debugf("removing cache directory %s", pd.cacheDir) + pd.client.log.Debugf("removing cache directory %s", pd.tempDir) } - err = os.RemoveAll(pd.cacheDir) + err = os.RemoveAll(pd.tempDir) if err != nil { pd.errCh <- err } From 6139df8b84276c51ad91e443f3eda65c9c8afa84 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 26 Sep 2022 11:35:28 +0800 Subject: [PATCH 614/843] improve debug log in parallel download --- parallel_download.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parallel_download.go b/parallel_download.go index 8690037b..cdff7288 100644 --- a/parallel_download.go +++ b/parallel_download.go @@ -86,7 +86,7 @@ func (pd *ParallelDownload) ensure() error { } pd.tempDir = filepath.Join(pd.tempRootDir, md5Sum(pd.url)) if pd.client.DebugLog { - pd.client.log.Debugf("use cache directory %s", pd.tempDir) + pd.client.log.Debugf("use temporary directory %s", pd.tempDir) pd.client.log.Debugf("download with %d concurrency and %d bytes segment size", pd.concurrency, pd.segmentSize) } err := os.MkdirAll(pd.tempDir, os.ModePerm) @@ -208,7 +208,7 @@ func (pd *ParallelDownload) mergeFile() { break } if pd.client.DebugLog { - pd.client.log.Debugf("removing cache directory %s", pd.tempDir) + pd.client.log.Debugf("removing temporary directory %s", pd.tempDir) } err = os.RemoveAll(pd.tempDir) if err != nil { From 333d1035e4c1d7d66a5b6c4cdf62cf7cda74031b Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 11 Oct 2022 14:41:02 +0800 Subject: [PATCH 615/843] still return body when Response.ToBytes() got an error --- response.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/response.go b/response.go index ffbfd30a..4d5f0677 100644 --- a/response.go +++ b/response.go @@ -167,12 +167,9 @@ func (r *Response) ToBytes() ([]byte, error) { } defer r.Body.Close() body, err := ioutil.ReadAll(r.Body) - r.setReceivedAt() - if err != nil { - return nil, err - } r.body = body - return body, nil + r.setReceivedAt() + return body, err } // Dump return the string content that have been dumped for the request. From b97c57951c870db4abc1d0abab5ae7c25b880374 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 17 Oct 2022 13:45:47 +0800 Subject: [PATCH 616/843] Support flexible dump Add fields for DumpOptions to control dump output more fine-grained: * RequestOutput * ResponseOutput * RequestHeaderOutput * RequestBodyOutput * ResponseHeaderOutput * ResponseBodyOutput --- dump.go | 61 ++++++++++++++++-- internal/dump/dump.go | 106 ++++++++++++++++++++++--------- internal/http2/frame.go | 9 ++- internal/http2/transport.go | 10 +-- internal/http3/client.go | 10 +-- internal/http3/request_writer.go | 8 +-- request.go | 6 +- textproto_reader.go | 14 ++-- transfer.go | 6 +- transport.go | 2 +- 10 files changed, 165 insertions(+), 67 deletions(-) diff --git a/dump.go b/dump.go index 627f0b8b..77f96349 100644 --- a/dump.go +++ b/dump.go @@ -8,12 +8,18 @@ import ( // DumpOptions controls the dump behavior. type DumpOptions struct { - Output io.Writer - RequestHeader bool - RequestBody bool - ResponseHeader bool - ResponseBody bool - Async bool + Output io.Writer + RequestOutput io.Writer + ResponseOutput io.Writer + RequestHeaderOutput io.Writer + RequestBodyOutput io.Writer + ResponseHeaderOutput io.Writer + ResponseBodyOutput io.Writer + RequestHeader bool + RequestBody bool + ResponseHeader bool + ResponseBody bool + Async bool } // Clone return a copy of DumpOptions @@ -30,9 +36,52 @@ type dumpOptions struct { } func (o dumpOptions) Output() io.Writer { + if o.DumpOptions.Output == nil { + return os.Stdout + } return o.DumpOptions.Output } +func (o dumpOptions) RequestHeaderOutput() io.Writer { + if o.DumpOptions.RequestHeaderOutput != nil { + return o.DumpOptions.RequestHeaderOutput + } + if o.DumpOptions.RequestOutput != nil { + return o.DumpOptions.RequestOutput + } + return o.Output() +} + +func (o dumpOptions) RequestBodyOutput() io.Writer { + if o.DumpOptions.RequestBodyOutput != nil { + return o.DumpOptions.RequestBodyOutput + } + if o.DumpOptions.RequestOutput != nil { + return o.DumpOptions.RequestOutput + } + return o.Output() +} + +func (o dumpOptions) ResponseHeaderOutput() io.Writer { + if o.DumpOptions.ResponseHeaderOutput != nil { + return o.DumpOptions.ResponseHeaderOutput + } + if o.DumpOptions.ResponseOutput != nil { + return o.DumpOptions.ResponseOutput + } + return o.Output() +} + +func (o dumpOptions) ResponseBodyOutput() io.Writer { + if o.DumpOptions.ResponseBodyOutput != nil { + return o.DumpOptions.ResponseBodyOutput + } + if o.DumpOptions.ResponseOutput != nil { + return o.DumpOptions.ResponseOutput + } + return o.Output() +} + func (o dumpOptions) RequestHeader() bool { return o.DumpOptions.RequestHeader } diff --git a/internal/dump/dump.go b/internal/dump/dump.go index 5ffffaca..88efac97 100644 --- a/internal/dump/dump.go +++ b/internal/dump/dump.go @@ -9,6 +9,10 @@ import ( // Options controls the dump behavior. type Options interface { Output() io.Writer + RequestHeaderOutput() io.Writer + RequestBodyOutput() io.Writer + ResponseHeaderOutput() io.Writer + ResponseBodyOutput() io.Writer RequestHeader() bool RequestBody() bool ResponseHeader() bool @@ -17,52 +21,70 @@ type Options interface { Clone() Options } -func (d *Dumper) WrapReadCloser(rc io.ReadCloser) io.ReadCloser { - return &dumpReadCloser{rc, d} +func (d *Dumper) WrapResponseBodyReadCloser(rc io.ReadCloser) io.ReadCloser { + return &dumpReponseBodyReadCloser{rc, d} } -type dumpReadCloser struct { +type dumpReponseBodyReadCloser struct { io.ReadCloser dump *Dumper } -func (r *dumpReadCloser) Read(p []byte) (n int, err error) { +func (r *dumpReponseBodyReadCloser) Read(p []byte) (n int, err error) { n, err = r.ReadCloser.Read(p) - r.dump.Dump(p[:n]) + r.dump.DumpResponseBody(p[:n]) if err == io.EOF { - r.dump.Dump([]byte("\r\n")) + r.dump.DumpDefault([]byte("\r\n")) } return } -func (d *Dumper) WrapWriteCloser(rc io.WriteCloser) io.WriteCloser { - return &dumpWriteCloser{rc, d} +func (d *Dumper) WrapRequestBodyWriteCloser(rc io.WriteCloser) io.WriteCloser { + return &dumpRequestBodyWriteCloser{rc, d} } -type dumpWriteCloser struct { +type dumpRequestBodyWriteCloser struct { io.WriteCloser dump *Dumper } -func (w *dumpWriteCloser) Write(p []byte) (n int, err error) { +func (w *dumpRequestBodyWriteCloser) Write(p []byte) (n int, err error) { n, err = w.WriteCloser.Write(p) - w.dump.Dump(p[:n]) + w.dump.DumpRequestBody(p[:n]) return } -type dumpWriter struct { +type dumpRequestHeaderWriter struct { w io.Writer dump *Dumper } -func (w *dumpWriter) Write(p []byte) (n int, err error) { +func (w *dumpRequestHeaderWriter) Write(p []byte) (n int, err error) { n, err = w.w.Write(p) - w.dump.Dump(p[:n]) + w.dump.DumpRequestHeader(p[:n]) return } -func (d *Dumper) WrapWriter(w io.Writer) io.Writer { - return &dumpWriter{ +func (d *Dumper) WrapRequestHeaderWriter(w io.Writer) io.Writer { + return &dumpRequestHeaderWriter{ + w: w, + dump: d, + } +} + +type dumpRequestBodyWriter struct { + w io.Writer + dump *Dumper +} + +func (w *dumpRequestBodyWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + w.dump.DumpRequestBody(p[:n]) + return +} + +func (d *Dumper) WrapRequestBodyWriter(w io.Writer) io.Writer { + return &dumpRequestBodyWriter{ w: w, dump: d, } @@ -88,24 +110,28 @@ func (ds Dumpers) ShouldDump() bool { return len(ds) > 0 } -// Dump with all dumpers. -func (ds Dumpers) Dump(p []byte) { +func (ds Dumpers) DumpResponseHeader(p []byte) { for _, d := range ds { - d.Dump(p) + d.DumpResponseHeader(p) } } // Dumper is the dump tool. type Dumper struct { Options - ch chan []byte + ch chan *dumpTask +} + +type dumpTask struct { + Data []byte + Output io.Writer } // NewDumper create a new Dumper. func NewDumper(opt Options) *Dumper { d := &Dumper{ Options: opt, - ch: make(chan []byte, 20), + ch: make(chan *dumpTask, 20), } return d } @@ -121,21 +147,41 @@ func (d *Dumper) Clone() *Dumper { } return &Dumper{ Options: d.Options.Clone(), - ch: make(chan []byte, 20), + ch: make(chan *dumpTask, 20), } } -func (d *Dumper) Dump(p []byte) { - if len(p) == 0 { +func (d *Dumper) DumpTo(p []byte, output io.Writer) { + if len(p) == 0 || output == nil { return } if d.Async() { b := make([]byte, len(p)) copy(b, p) - d.ch <- b + d.ch <- &dumpTask{Data: b, Output: output} return } - d.Output().Write(p) + output.Write(p) +} + +func (d *Dumper) DumpDefault(p []byte) { + d.DumpTo(p, d.Output()) +} + +func (d *Dumper) DumpRequestHeader(p []byte) { + d.DumpTo(p, d.RequestHeaderOutput()) +} + +func (d *Dumper) DumpRequestBody(p []byte) { + d.DumpTo(p, d.RequestBodyOutput()) +} + +func (d *Dumper) DumpResponseHeader(p []byte) { + d.DumpTo(p, d.ResponseHeaderOutput()) +} + +func (d *Dumper) DumpResponseBody(p []byte) { + d.DumpTo(p, d.ResponseBodyOutput()) } func (d *Dumper) Stop() { @@ -143,11 +189,11 @@ func (d *Dumper) Stop() { } func (d *Dumper) Start() { - for b := range d.ch { - if b == nil { + for t := range d.ch { + if t == nil { return } - d.Output().Write(b) + t.Output.Write(t.Data) } } @@ -170,7 +216,7 @@ func WrapResponseBodyIfNeeded(res *http.Response, req *http.Request, dump *Dumpe dumps := GetDumpers(req.Context(), dump) for _, d := range dumps { if d.ResponseBody() { - res.Body = d.WrapReadCloser(res.Body) + res.Body = d.WrapResponseBodyReadCloser(res.Body) } } } diff --git a/internal/http2/frame.go b/internal/http2/frame.go index 3ed6a0ac..e5323b74 100644 --- a/internal/http2/frame.go +++ b/internal/http2/frame.go @@ -540,7 +540,7 @@ func (h2f *Framer) ReadFrame() (Frame, error) { hr, err := h2f.readMetaFrame(f.(*HeadersFrame), dumps) if err == nil && len(dumps) > 0 { for _, dump := range dumps { - dump.Dump([]byte("\r\n")) + dump.DumpResponseHeader([]byte("\r\n")) } } return hr, err @@ -1587,11 +1587,10 @@ func (h2f *Framer) readMetaFrame(hf *HeadersFrame, dumps []*dump.Dumper) (*MetaH } emitFunc := rawEmitFunc - if len(dumps) > 0 { + ds := dump.Dumpers(dumps) + if ds.ShouldDump() { emitFunc = func(hf hpack.HeaderField) { - for _, dump := range dumps { - dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) - } + ds.DumpResponseHeader([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) rawEmitFunc(hf) } } diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 01600061..8b9d8a84 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -1239,7 +1239,7 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { } else { cs.sentEndStream = true for _, dump := range bodyDumps { - dump.Dump([]byte("\r\n\r\n")) + dump.DumpDefault([]byte("\r\n\r\n")) } } } @@ -1499,7 +1499,7 @@ func (cs *clientStream) writeRequestBody(req *http.Request, dumps []*dump.Dumper if len(dumps) > 0 { writeData = func(streamID uint32, endStream bool, data []byte) error { for _, dump := range dumps { - dump.Dump(data) + dump.DumpRequestBody(data) } return cc.fr.WriteData(streamID, endStream, data) } @@ -1818,7 +1818,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail if len(headerDumps) > 0 { writeHeader = func(name, value string) { for _, dump := range headerDumps { - dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + dump.DumpRequestHeader([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) } cc.writeHeader(name, value) } @@ -1840,7 +1840,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail }) for _, dump := range headerDumps { - dump.Dump([]byte("\r\n")) + dump.DumpRequestHeader([]byte("\r\n")) } return cc.hbuf.Bytes(), nil @@ -1887,7 +1887,7 @@ func (cc *ClientConn) encodeTrailers(trailer http.Header, dumps []*dump.Dumper) if len(dumps) > 0 { writeHeader = func(name, value string) { for _, dump := range dumps { - dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + dump.DumpRequestHeader([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) } cc.writeHeader(name, value) } diff --git a/internal/http3/client.go b/internal/http3/client.go index bc22b0e4..39032029 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -322,7 +322,7 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.D if len(dumps) > 0 { writeData = func(data []byte) error { for _, dump := range dumps { - dump.Dump(data) + dump.DumpRequestBody(data) } if _, err := str.Write(data); err != nil { return err @@ -338,7 +338,7 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.D } if rerr == io.EOF { for _, dump := range dumps { - dump.Dump([]byte("\r\n\r\n")) + dump.DumpDefault([]byte("\r\n\r\n")) } break } @@ -349,7 +349,7 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.D if rerr != nil { if rerr == io.EOF { for _, dump := range dumps { - dump.Dump([]byte("\r\n\r\n")) + dump.DumpDefault([]byte("\r\n\r\n")) } break } @@ -424,11 +424,11 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, if len(respHeaderDumps) > 0 { for _, hf := range hfs { for _, dump := range respHeaderDumps { - dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) + dump.DumpResponseHeader([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) } } for _, dump := range respHeaderDumps { - dump.Dump([]byte("\r\n")) + dump.DumpResponseHeader([]byte("\r\n")) } } if err != nil { diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index dbd04ad0..0c7312b2 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -203,7 +203,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra enumerateHeaders(func(name, value string) { name = strings.ToLower(name) for _, dump := range dumps { - dump.Dump([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + dump.DumpRequestHeader([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) } w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value}) // if traceHeaders { @@ -212,7 +212,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra }) for _, dump := range dumps { - dump.Dump([]byte("\r\n")) + dump.DumpRequestHeader([]byte("\r\n")) } return nil @@ -242,8 +242,8 @@ func authorityAddr(scheme string, authority string) (addr string) { // validPseudoPath reports whether v is a valid :path pseudo-header // value. It must be either: // -// *) a non-empty string starting with '/' -// *) the string '*', for OPTIONS requests. +// *) a non-empty string starting with '/' +// *) the string '*', for OPTIONS requests. // // For now this is only used a quick check for deciding when to clean // up Opaque URLs before sending requests from the Transport. diff --git a/request.go b/request.go index 361a30db..f4498fd6 100644 --- a/request.go +++ b/request.go @@ -882,7 +882,11 @@ func (r *Request) SetDumpOptions(opt *DumpOptions) *Request { if opt.Output == nil { opt.Output = r.getDumpBuffer() } - r.dumpOptions = opt + if r.dumpOptions != nil { + *r.dumpOptions = *opt + } else { + r.dumpOptions = opt + } return r } diff --git a/textproto_reader.go b/textproto_reader.go index b3a9096f..46103c07 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -47,7 +47,7 @@ func newTextprotoReader(r *bufio.Reader, ds dump.Dumpers) *textprotoReader { return } err = nil - ds.Dump(line) + ds.DumpResponseHeader(line) if line[len(line)-1] == '\n' { drop := 1 if len(line) > 1 && line[len(line)-2] == '\r' { @@ -202,7 +202,6 @@ var colon = []byte(":") // "My-Key": {"Value 1", "Value 2"}, // "Long-Key": {"Even Longer Value"}, // } -// func (r *textprotoReader) ReadMIMEHeader() (textproto.MIMEHeader, error) { // Avoid lots of small slice allocations later by allocating one // large one ahead of time which we'll cut up into smaller @@ -295,11 +294,12 @@ const toLower = 'a' - 'A' // validHeaderFieldByte reports whether b is a valid byte in a header // field name. RFC 7230 says: -// header-field = field-name ":" OWS field-value OWS -// field-name = token -// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / -// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA -// token = 1*tchar +// +// header-field = field-name ":" OWS field-value OWS +// field-name = token +// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / +// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA +// token = 1*tchar func validHeaderFieldByte(b byte) bool { return int(b) < len(isTokenTable) && isTokenTable[b] } diff --git a/transfer.go b/transfer.go index 69a710e5..95001f66 100644 --- a/transfer.go +++ b/transfer.go @@ -321,7 +321,7 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dump.Dumper) (err error rw := w // raw writer for _, dump := range dumps { if dump.RequestBody() { - w = dump.WrapWriter(w) + w = dump.WrapRequestBodyWriter(w) } } @@ -338,7 +338,7 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dump.Dumper) (err error cw := internal.NewChunkedWriter(rw) for _, dump := range dumps { if dump.RequestBody() { - cw = dump.WrapWriteCloser(cw) + cw = dump.WrapRequestBodyWriteCloser(cw) } } _, err = t.doBodyCopy(cw, body) @@ -365,7 +365,7 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dump.Dumper) (err error } for _, dump := range dumps { if dump.RequestBody() { - dump.Dump([]byte("\r\n")) + dump.DumpDefault([]byte("\r\n")) } } } diff --git a/transport.go b/transport.go index 84adbd5f..7d14f86d 100644 --- a/transport.go +++ b/transport.go @@ -2743,7 +2743,7 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo dumps := dump.GetDumpers(r.Context(), pc.t.Dump) for _, dump := range dumps { if dump.RequestHeader() { - w = dump.WrapWriter(w) + w = dump.WrapRequestHeaderWriter(w) } } From 150457b4d23e289c21d73165ef9635568f66d759 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 20 Oct 2022 17:01:10 +0800 Subject: [PATCH 617/843] update comments for Response.Error() and Response.Result() --- response.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/response.go b/response.go index 4d5f0677..10bd86bc 100644 --- a/response.go +++ b/response.go @@ -49,12 +49,14 @@ func (r *Response) GetContentType() string { return r.Header.Get(header.ContentType) } -// Result returns the response value as an object if it has one +// Result returns the automatically unmarshalled object if Request.SetResult is called, +// and Response.IsSuccess() == true. Otherwise return nil. func (r *Response) Result() interface{} { return r.result } -// Error returns the error object if it has one. +// Error returns the automatically unmarshalled object when Request.SetError is called, +// and Response.IsError() == true. Otherwise return nil. func (r *Response) Error() interface{} { return r.error } From a91b8f9f378c0be272dcfd4b7ddc86c1277743bc Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 28 Nov 2022 12:33:33 +0800 Subject: [PATCH 618/843] Support customize Content-Type when uploading multipart --- middleware.go | 6 +++++- req.go | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/middleware.go b/middleware.go index ac670fd4..b2a22dbb 100644 --- a/middleware.go +++ b/middleware.go @@ -77,7 +77,11 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload, r *Request) e } } - pw, err := w.CreatePart(createMultipartHeader(file, http.DetectContentType(cbuf))) + ct := file.ContentType + if ct == "" { + ct = http.DetectContentType(cbuf) + } + pw, err := w.CreatePart(createMultipartHeader(file, ct)) if err != nil { return err } diff --git a/req.go b/req.go index eef943bb..81fadac8 100644 --- a/req.go +++ b/req.go @@ -45,6 +45,8 @@ type FileUpload struct { GetFileContent GetContentFunc // Optional file length in bytes. FileSize int64 + // Optional Content-Type + ContentType string // Optional extra ContentDisposition parameters. // According to the HTTP specification, this should be nil, From 51a6c64240869b3de8c453fe3cd416c9cedf636a Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 29 Nov 2022 14:32:37 +0800 Subject: [PATCH 619/843] Support EnableCloseConnection (#183) --- client.go | 1 + request.go | 12 ++++++++++++ request_wrapper.go | 6 ++++++ 3 files changed, 19 insertions(+) diff --git a/client.go b/client.go index 29d523fb..d7d298fa 100644 --- a/client.go +++ b/client.go @@ -1277,6 +1277,7 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { ContentLength: contentLength, Body: reqBody, GetBody: r.GetBody, + Close: r.close, } for _, cookie := range r.Cookies { req.AddCookie(cookie) diff --git a/request.go b/request.go index f4498fd6..6508e9b6 100644 --- a/request.go +++ b/request.go @@ -43,6 +43,7 @@ type Request struct { isMultiPart bool forceChunkedEncoding bool isSaveResponse bool + close bool error error client *Client uploadCallback UploadCallback @@ -1052,3 +1053,14 @@ func (r *Request) SetClient(client *Client) *Request { func (r *Request) GetClient() *Client { return r.client } + +// EnableCloseConnection closes the connection after sending this +// request and reading its response if set to true in HTTP/1.1 and +// HTTP/2. +// +// Setting this field prevents re-use of TCP connections between +// requests to the same hosts event if EnableKeepAlives() were called. +func (r *Request) EnableCloseConnection() *Request { + r.close = true + return r +} diff --git a/request_wrapper.go b/request_wrapper.go index d765e404..9fabc9de 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -481,3 +481,9 @@ func SetDownloadCallback(callback DownloadCallback) *Request { func SetDownloadCallbackWithInterval(callback DownloadCallback, minInterval time.Duration) *Request { return defaultClient.R().SetDownloadCallbackWithInterval(callback, minInterval) } + +// EnableCloseConnection is a global wrapper methods which delegated +// to the default client, create a request and EnableCloseConnection for request. +func EnableCloseConnection() *Request { + return defaultClient.R().EnableCloseConnection() +} From d673df054bc9f5f5f65ed72bd0cb0593b77fbeb2 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 1 Dec 2022 18:51:42 +0800 Subject: [PATCH 620/843] fix Transport.SetDialTLS not work in http2 --- internal/http2/go115.go | 20 +++++++++++++------- internal/http2/not_go115.go | 26 -------------------------- 2 files changed, 13 insertions(+), 33 deletions(-) delete mode 100644 internal/http2/not_go115.go diff --git a/internal/http2/go115.go b/internal/http2/go115.go index c9d2183f..629d8613 100644 --- a/internal/http2/go115.go +++ b/internal/http2/go115.go @@ -11,18 +11,24 @@ import ( "context" "crypto/tls" reqtls "github.com/imroc/req/v3/pkg/tls" + "net" ) // dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS // connection. -func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (reqtls.Conn, error) { - dialer := &tls.Dialer{ - Config: cfg, +func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (tlsCn reqtls.Conn, err error) { + var conn net.Conn + if t.DialTLSContext != nil { + conn, err = t.DialTLSContext(ctx, network, addr) + } else { + dialer := &tls.Dialer{ + Config: cfg, + } + conn, err = dialer.DialContext(ctx, network, addr) } - cn, err := dialer.DialContext(ctx, network, addr) if err != nil { - return nil, err + return } - tlsCn := cn.(reqtls.Conn) // DialContext comment promises this will always succeed - return tlsCn, nil + tlsCn = conn.(reqtls.Conn) + return } diff --git a/internal/http2/not_go115.go b/internal/http2/not_go115.go deleted file mode 100644 index 47349d62..00000000 --- a/internal/http2/not_go115.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.15 -// +build !go1.15 - -package http2 - -// dialTLSWithContext opens a TLS connection. -func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (TLSConn, error) { - cn, err := tls.Dial(network, addr, cfg) - if err != nil { - return nil, err - } - if err := cn.Handshake(); err != nil { - return nil, err - } - if cfg.InsecureSkipVerify { - return cn, nil - } - if err := cn.VerifyHostname(cfg.ServerName); err != nil { - return nil, err - } - return cn, nil -} From fe9bec84ecf5e0938d537f5a6f1335f82f01c2b4 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 1 Dec 2022 19:45:07 +0800 Subject: [PATCH 621/843] SetDialTLS should override dial func in EnableH2C --- internal/http2/go115.go | 20 +++++++------------- internal/http2/transport.go | 6 ++++++ transport.go | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/internal/http2/go115.go b/internal/http2/go115.go index 629d8613..a3a4dfc5 100644 --- a/internal/http2/go115.go +++ b/internal/http2/go115.go @@ -11,24 +11,18 @@ import ( "context" "crypto/tls" reqtls "github.com/imroc/req/v3/pkg/tls" - "net" ) // dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS // connection. -func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (tlsCn reqtls.Conn, err error) { - var conn net.Conn - if t.DialTLSContext != nil { - conn, err = t.DialTLSContext(ctx, network, addr) - } else { - dialer := &tls.Dialer{ - Config: cfg, - } - conn, err = dialer.DialContext(ctx, network, addr) +func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (reqtls.Conn, error) { + dialer := &tls.Dialer{ + Config: cfg, } + conn, err := dialer.DialContext(ctx, network, addr) if err != nil { - return + return nil, err } - tlsCn = conn.(reqtls.Conn) - return + tlsCn := conn.(reqtls.Conn) + return tlsCn, nil } diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 8b9d8a84..f6b3efe5 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -73,6 +73,7 @@ const ( // for concurrent use by multiple goroutines. type Transport struct { *transport.Options + // DialTLS specifies an optional dial function for creating // TLS connections for requests. // @@ -560,6 +561,11 @@ func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Confi if t.DialTLS != nil { return t.DialTLS } + if t.DialTLSContext != nil { + return func(network string, addr string, cfg *tls.Config) (net.Conn, error) { + return t.DialTLSContext(ctx, network, addr) + } + } return func(network, addr string, cfg *tls.Config) (net.Conn, error) { tlsCn, err := t.dialTLSWithContext(ctx, network, addr, cfg) if err != nil { diff --git a/transport.go b/transport.go index 7d14f86d..e6446484 100644 --- a/transport.go +++ b/transport.go @@ -367,7 +367,7 @@ func (t *Transport) EnableForceHTTP2() *Transport { func (t *Transport) EnableH2C() *Transport { t.Options.EnableH2C = true t.t2.AllowHTTP = true - t.t2.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { + t.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { return net.Dial(network, addr) } return t @@ -377,7 +377,7 @@ func (t *Transport) EnableH2C() *Transport { func (t *Transport) DisableH2C() *Transport { t.Options.EnableH2C = false t.t2.AllowHTTP = false - t.t2.DialTLS = nil + t.t2.DialTLSContext = nil return t } From c914b21bd647f18315886673f59527e0f20b01dc Mon Sep 17 00:00:00 2001 From: shuai_yang Date: Mon, 5 Dec 2022 14:23:17 +0800 Subject: [PATCH 622/843] feat: add request method DisableAutoReadResponse add request method DisableAutoReadResponse for disable auto read response in request level --- request.go | 6 ++++++ request_test.go | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/request.go b/request.go index 6508e9b6..37556c8f 100644 --- a/request.go +++ b/request.go @@ -830,6 +830,12 @@ func (r *Request) DisableTrace() *Request { return r } +// DisableAutoReadResponse disable read response body automatically (enabled by default). +func (r *Request) DisableAutoReadResponse() *Request { + r.isSaveResponse = true + return r +} + // EnableTrace enables trace (http3 currently does not support trace). func (r *Request) EnableTrace() *Request { if r.trace == nil { diff --git a/request_test.go b/request_test.go index 4dad7c4c..70eeb963 100644 --- a/request_test.go +++ b/request_test.go @@ -998,6 +998,22 @@ func TestDownloadCallback(t *testing.T) { tests.AssertEqual(t, true, n > 0) } +func TestRequestDisableAutoReadResponse(t *testing.T) { + testWithAllTransport(t, func(t *testing.T, c *Client) { + resp, err := c.R().DisableAutoReadResponse().Get("/") + assertSuccess(t, resp, err) + tests.AssertEqual(t, "", resp.String()) + result, err := resp.ToString() + tests.AssertNoError(t, err) + tests.AssertEqual(t, "TestGet: text response", result) + + resp, err = c.R().DisableAutoReadResponse().Get("/") + assertSuccess(t, resp, err) + _, err = ioutil.ReadAll(resp.Body) + tests.AssertNoError(t, err) + }) +} + func TestRestoreResponseBody(t *testing.T) { c := tc() resp, err := c.R().Get("/") From 75c4b64f8e3bd7470a01235dafaa11e7ece37378 Mon Sep 17 00:00:00 2001 From: shuai_yang Date: Mon, 5 Dec 2022 14:30:54 +0800 Subject: [PATCH 623/843] feat: add request method DisableAutoReadResponse add request method DisableAutoReadResponse for disable auto read response in request level --- client.go | 2 +- request.go | 19 +++++++++++++------ response.go | 2 ++ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index d7d298fa..2b12c823 100644 --- a/client.go +++ b/client.go @@ -1312,7 +1312,7 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { return } // auto-read response body if possible - if !c.disableAutoReadResponse && !r.isSaveResponse { + if !c.disableAutoReadResponse && !r.isSaveResponse && !r.disableAutoReadResponse { _, err = resp.ToBytes() if err != nil { return diff --git a/request.go b/request.go index 37556c8f..f4470f06 100644 --- a/request.go +++ b/request.go @@ -41,6 +41,7 @@ type Request struct { GetBody GetContentFunc isMultiPart bool + disableAutoReadResponse bool forceChunkedEncoding bool isSaveResponse bool close bool @@ -824,15 +825,21 @@ func (r *Request) SetContext(ctx context.Context) *Request { return r } -// DisableTrace disables trace. -func (r *Request) DisableTrace() *Request { - r.trace = nil +// DisableAutoReadResponse disable read response body automatically (enabled by default). +func (r *Request) DisableAutoReadResponse() *Request { + r.disableAutoReadResponse = true return r } -// DisableAutoReadResponse disable read response body automatically (enabled by default). -func (r *Request) DisableAutoReadResponse() *Request { - r.isSaveResponse = true +// EnableAutoReadResponse enable read response body automatically (enabled by default). +func (r *Request) EnableAutoReadResponse() *Request { + r.disableAutoReadResponse = false + return r +} + +// DisableTrace disables trace. +func (r *Request) DisableTrace() *Request { + r.trace = nil return r } diff --git a/response.go b/response.go index 10bd86bc..060217ce 100644 --- a/response.go +++ b/response.go @@ -136,6 +136,7 @@ func (r *Response) Into(v interface{}) error { // nil if not read, the following cases are already read: // 1. `Request.SetResult` or `Request.SetError` is called. // 2. `Client.DisableAutoReadResponse(false)` is not called, +// 3. `Request.DisableAutoReadResponse(false)` is not called, // also `Request.SetOutput` and `Request.SetOutputFile` is not called. func (r *Response) Bytes() []byte { return r.body @@ -145,6 +146,7 @@ func (r *Response) Bytes() []byte { // nil if not read, the following cases are already read: // 1. `Request.SetResult` or `Request.SetError` is called. // 2. `Client.DisableAutoReadResponse(false)` is not called, +// 3. `Request.DisableAutoReadResponse(false)` is not called, // also `Request.SetOutput` and `Request.SetOutputFile` is not called. func (r *Response) String() string { return string(r.body) From 4ded7b91adc3acc7e8ae0d98b58778a1be964b7e Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 5 Dec 2022 16:13:55 +0800 Subject: [PATCH 624/843] improve comments in Response --- response.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/response.go b/response.go index 060217ce..31632b7e 100644 --- a/response.go +++ b/response.go @@ -135,9 +135,8 @@ func (r *Response) Into(v interface{}) error { // Bytes return the response body as []bytes that hava already been read, could be // nil if not read, the following cases are already read: // 1. `Request.SetResult` or `Request.SetError` is called. -// 2. `Client.DisableAutoReadResponse(false)` is not called, -// 3. `Request.DisableAutoReadResponse(false)` is not called, -// also `Request.SetOutput` and `Request.SetOutputFile` is not called. +// 2. `Client.DisableAutoReadResponse` and `Request.DisableAutoReadResponse` is not +// called, and also `Request.SetOutput` and `Request.SetOutputFile` is not called. func (r *Response) Bytes() []byte { return r.body } @@ -145,9 +144,8 @@ func (r *Response) Bytes() []byte { // String returns the response body as string that hava already been read, could be // nil if not read, the following cases are already read: // 1. `Request.SetResult` or `Request.SetError` is called. -// 2. `Client.DisableAutoReadResponse(false)` is not called, -// 3. `Request.DisableAutoReadResponse(false)` is not called, -// also `Request.SetOutput` and `Request.SetOutputFile` is not called. +// 2. `Client.DisableAutoReadResponse` and `Request.DisableAutoReadResponse` is not +// called, and also `Request.SetOutput` and `Request.SetOutputFile` is not called. func (r *Response) String() string { return string(r.body) } From 22eacee17e317d96bafffa75d7637d6e9c600394 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 5 Dec 2022 16:14:40 +0800 Subject: [PATCH 625/843] update go mod: require go1.16 --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 88995321..b50f7b82 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/imroc/req/v3 -go 1.15 +go 1.16 require ( github.com/fsnotify/fsnotify v1.5.4 // indirect From 2522eb249235b8fe492eaeb2c06cd594823d235b Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 20 Dec 2022 13:30:29 +0800 Subject: [PATCH 626/843] execute user defined middleware at first(#190) --- request.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/request.go b/request.go index f4470f06..57723142 100644 --- a/request.go +++ b/request.go @@ -529,12 +529,12 @@ func (r *Request) do() (resp *Response, err error) { }() for { - for _, f := range r.client.beforeRequest { + for _, f := range r.client.udBeforeRequest { if err = f(r.client, r); err != nil { return } } - for _, f := range r.client.udBeforeRequest { + for _, f := range r.client.beforeRequest { if err = f(r.client, r); err != nil { return } From e8ed815f63e51742272c12c3d9b46bdb10e8870b Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 3 Jan 2023 10:59:23 +0800 Subject: [PATCH 627/843] fix data race in http2 dump(#181) --- internal/dump/dump.go | 3 +++ internal/http2/frame.go | 34 ++++++++++++++++++++++++++++++++-- internal/http2/transport.go | 20 ++++++++++---------- 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/internal/dump/dump.go b/internal/dump/dump.go index 88efac97..231c4beb 100644 --- a/internal/dump/dump.go +++ b/internal/dump/dump.go @@ -206,6 +206,9 @@ func GetDumpers(ctx context.Context, dump *Dumper) []*Dumper { if dump != nil { dumps = append(dumps, dump) } + if ctx == nil { + return dumps + } if d, ok := ctx.Value(DumperKey).(*Dumper); ok { dumps = append(dumps, d) } diff --git a/internal/http2/frame.go b/internal/http2/frame.go index e5323b74..4510b917 100644 --- a/internal/http2/frame.go +++ b/internal/http2/frame.go @@ -6,6 +6,7 @@ package http2 import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -14,6 +15,7 @@ import ( "golang.org/x/net/http2/hpack" "io" "log" + "net/http" "strings" "sync" ) @@ -487,6 +489,28 @@ func terminalReadFrameError(err error) bool { return err != nil } +func (h2f *Framer) streamByID(id uint32) *clientStream { + if h2f.cc == nil { + return nil + } + h2f.cc.mu.Lock() + defer h2f.cc.mu.Unlock() + cs := h2f.cc.streams[id] + if cs != nil && !cs.readAborted { + return cs + } + return nil +} + +func (h2f *Framer) currentRequest(id uint32) *http.Request { + if cs := h2f.streamByID(id); cs != nil { + if req := cs.currentRequest; req != nil { + return req + } + } + return nil +} + // ReadFrame reads a single frame. The returned Frame is only valid // until the next call to ReadFrame. // @@ -524,9 +548,15 @@ func (h2f *Framer) ReadFrame() (Frame, error) { h2f.debugReadLoggerf("http2: Framer %p: read %v", h2f, summarizeFrame(f)) } if fh.Type == FrameHeaders && h2f.ReadMetaHeaders != nil { + hf := f.(*HeadersFrame) + req := h2f.currentRequest(hf.StreamID) + var ctx context.Context + if req != nil { + ctx = req.Context() + } var dumps []*dump.Dumper if h2f.cc != nil { - dumps = dump.GetDumpers(h2f.cc.currentRequest.Context(), h2f.cc.t.Dump) + dumps = dump.GetDumpers(ctx, h2f.cc.t.Dump) } if len(dumps) > 0 { dd := []*dump.Dumper{} @@ -537,7 +567,7 @@ func (h2f *Framer) ReadFrame() (Frame, error) { } dumps = dd } - hr, err := h2f.readMetaFrame(f.(*HeadersFrame), dumps) + hr, err := h2f.readMetaFrame(hf, dumps) if err == nil && len(dumps) > 0 { for _, dump := range dumps { dump.DumpResponseHeader([]byte("\r\n")) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index f6b3efe5..3f8bc930 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -172,13 +172,12 @@ func (t *Transport) initConnPool() { // ClientConn is the state of a single HTTP/2 client connection to an // HTTP/2 server. type ClientConn struct { - currentRequest *http.Request - t *Transport - tconn net.Conn // usually TLSConn, except specialized impls - tlsState *tls.ConnectionState // nil only for specialized impls - reused uint32 // whether conn is being reused; atomic - singleUse bool // whether being used for a single http.Request - getConnCalled bool // used by clientConnPool + t *Transport + tconn net.Conn // usually TLSConn, except specialized impls + tlsState *tls.ConnectionState // nil only for specialized impls + reused uint32 // whether conn is being reused; atomic + singleUse bool // whether being used for a single http.Request + getConnCalled bool // used by clientConnPool // readLoop goroutine fields: readerDone chan struct{} // closed on error @@ -231,7 +230,8 @@ type ClientConn struct { // clientStream is the state for a single HTTP/2 stream. One of these // is created for each Transport.RoundTrip call. type clientStream struct { - cc *ClientConn + currentRequest *http.Request + cc *ClientConn // Fields of Request that we may access even after the response body is closed. ctx context.Context @@ -622,7 +622,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro }) cc.br = bufio.NewReader(c) cc.fr = NewFramer(cc.bw, cc.br) - cc.fr.cc = cc // for dump single request + cc.fr.cc = cc if t.CountError != nil { cc.fr.countError = t.CountError } @@ -1026,9 +1026,9 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { if cc.t != nil && cc.t.Debugf != nil { cc.t.Debugf("HTTP/2 %s %s", req.Method, req.URL.String()) } - cc.currentRequest = req ctx := req.Context() cs := &clientStream{ + currentRequest: req, cc: cc, ctx: ctx, reqCancel: req.Cancel, From fe5fafc978eb0db30cb14532591fdef698d2a676 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 4 Jan 2023 10:37:43 +0800 Subject: [PATCH 628/843] Add comments to explain the Request.URL field (#197) --- request.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/request.go b/request.go index 57723142..94dec749 100644 --- a/request.go +++ b/request.go @@ -37,8 +37,10 @@ type Request struct { RawURL string // read only Method string Body []byte - URL *urlpkg.URL GetBody GetContentFunc + // URL is an auto-generated field, and is nil in request middleware (OnBeforeRequest), + // consider using RawURL if you want, it's not nil in client middleware (WrapRoundTripFunc) + URL *urlpkg.URL isMultiPart bool disableAutoReadResponse bool From 9d5943a7cbf125e358cf57bc56b47bb4a8596d1e Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 11:33:42 +0800 Subject: [PATCH 629/843] Fix missing TraceInfo when download callback is set(#200) --- client.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 2b12c823..466cfa16 100644 --- a/client.go +++ b/client.go @@ -1296,7 +1296,10 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { interval: r.downloadCallbackInterval, } } - ctx = context.WithValue(r.Context(), wrapResponseBodyKey, wrap) + if ctx == nil { + ctx = context.Background() + } + ctx = context.WithValue(ctx, wrapResponseBodyKey, wrap) } if ctx != nil { req = req.WithContext(ctx) From 334f0e2350b60fe3e55e473facc300508131ce0f Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 12:47:53 +0800 Subject: [PATCH 630/843] Update go modules --- go.mod | 7 ++----- go.sum | 38 +++++++------------------------------- 2 files changed, 9 insertions(+), 36 deletions(-) diff --git a/go.mod b/go.mod index b50f7b82..2ff010ed 100644 --- a/go.mod +++ b/go.mod @@ -16,9 +16,6 @@ require ( github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.13.0 golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect - golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b - golang.org/x/sys v0.0.0-20220731174439-a90be440212d // indirect - golang.org/x/text v0.3.7 - golang.org/x/tools v0.1.12 // indirect - golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f // indirect + golang.org/x/net v0.5.0 + golang.org/x/text v0.6.0 ) diff --git a/go.sum b/go.sum index 9ba9df93..805ec4cc 100644 --- a/go.sum +++ b/go.sum @@ -55,7 +55,6 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= @@ -66,7 +65,6 @@ github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE0 github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= -github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -83,10 +81,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lucas-clemente/quic-go v0.27.2 h1:zsMwwniyybb8B/UDNXRSYee7WpQJVOcjQEGgpw2ikXs= -github.com/lucas-clemente/quic-go v0.27.2/go.mod h1:vXgO/11FBSKM+js1NxoaQ/bPtVFYfB7uxhfHXyMhl1A= -github.com/lucas-clemente/quic-go v0.28.0 h1:9eXVRgIkMQQyiyorz/dAaOYIx3TFzXsIFkNFz4cxuJM= -github.com/lucas-clemente/quic-go v0.28.0/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= github.com/lucas-clemente/quic-go v0.28.1 h1:Uo0lvVxWg5la9gflIF9lwa39ONq85Xq2D91YNEIslzU= github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= @@ -164,7 +158,6 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= @@ -176,10 +169,6 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= -golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -207,16 +196,11 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220615171555-694bf12d69de h1:ogOG2+P6LjO2j55AkRScrkB2BFpd+Z8TY2wcM0Z3MGo= -golang.org/x/net v0.0.0-20220615171555-694bf12d69de/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw= -golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b h1:3ogNYyK4oIQdIKzTu68hQrr4iuVxF3AxKl9Aj/eDrw0= -golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= +golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -249,26 +233,22 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c h1:aFV+BgZ4svzjfabn8ERpuB4JI4N6/rdy1iusx77G3oU= -golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e h1:CsOuNlbOuf0mzxJIefr6Q4uAUetRUwZE4qt7VfzP+xo= -golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80= -golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= +golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -280,16 +260,12 @@ golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.11 h1:loJ25fNOEhSXfHrpoGj91eCUThwdNX6u24rO1xnNteY= -golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f h1:uF6paiQQebLeSXkrTqHqz0MXhXXS1KgF41eUdBNvxK0= -golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= From c2a3bbf860daae66da048d7b5016a882bd605cdc Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 14:07:44 +0800 Subject: [PATCH 631/843] http2: add newly dialed conns to the pool before signaling completion https://github.com/golang/net/commit/694bf12d69de87b859d3663757617cec92450753 --- internal/http2/client_conn_pool.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/http2/client_conn_pool.go b/internal/http2/client_conn_pool.go index fcecc6e0..6136871a 100644 --- a/internal/http2/client_conn_pool.go +++ b/internal/http2/client_conn_pool.go @@ -118,7 +118,6 @@ func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *d func (c *dialCall) dial(ctx context.Context, addr string) { const singleUse = false // shared conn c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse) - close(c.done) c.p.mu.Lock() delete(c.p.dialing, addr) @@ -126,6 +125,8 @@ func (c *dialCall) dial(ctx context.Context, addr string) { c.p.addConnLocked(addr, c.res) } c.p.mu.Unlock() + + close(c.done) } // addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't From d323afd4e5d343b972d2e41c12337a6924f5cb9f Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 14:22:23 +0800 Subject: [PATCH 632/843] http2: fix spec document links https://github.com/golang/net/commit/0bcc04d9c69b4e379010f97e15bb7751dc57156b --- internal/http2/frame.go | 22 +++++++++++----------- internal/http2/http2.go | 19 ++++++++++--------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/internal/http2/frame.go b/internal/http2/frame.go index 4510b917..bf829fe9 100644 --- a/internal/http2/frame.go +++ b/internal/http2/frame.go @@ -25,7 +25,7 @@ const frameHeaderLen = 9 var padZeros = make([]byte, 255) // zeros for padding // A FrameType is a registered frame type as defined in -// http://http2.github.io/http2-spec/#rfc.section.11.2 +// https://httpwg.org/specs/rfc7540.html#rfc.section.11.2 type FrameType uint8 const ( @@ -148,7 +148,7 @@ func typeFrameParser(t FrameType) frameParser { // A FrameHeader is the 9 byte header of all HTTP/2 frames. // -// See http://http2.github.io/http2-spec/#FrameHeader +// See https://httpwg.org/specs/rfc7540.html#FrameHeader type FrameHeader struct { valid bool // caller can access []byte fields in the Frame @@ -628,7 +628,7 @@ func (h2f *Framer) checkFrameOrder(f Frame) error { // A DataFrame conveys arbitrary, variable-length sequences of octets // associated with a stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.1 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.1 type DataFrame struct { FrameHeader data []byte @@ -751,7 +751,7 @@ func (h2f *Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad [] // endpoints communicate, such as preferences and constraints on peer // behavior. // -// See http://http2.github.io/http2-spec/#SETTINGS +// See https://httpwg.org/specs/rfc7540.html#SETTINGS type SettingsFrame struct { FrameHeader p []byte @@ -890,7 +890,7 @@ func (h2f *Framer) WriteSettingsAck() error { // A PingFrame is a mechanism for measuring a minimal round trip time // from the sender, as well as determining whether an idle connection // is still functional. -// See http://http2.github.io/http2-spec/#rfc.section.6.7 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.7 type PingFrame struct { FrameHeader Data [8]byte @@ -923,7 +923,7 @@ func (h2f *Framer) WritePing(ack bool, data [8]byte) error { } // A GoAwayFrame informs the remote peer to stop creating streams on this connection. -// See http://http2.github.io/http2-spec/#rfc.section.6.8 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.8 type GoAwayFrame struct { FrameHeader LastStreamID uint32 @@ -987,7 +987,7 @@ func parseUnknownFrame(_ *frameCache, fh FrameHeader, countError func(string), p } // A WindowUpdateFrame is used to implement flow control. -// See http://http2.github.io/http2-spec/#rfc.section.6.9 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.9 type WindowUpdateFrame struct { FrameHeader Increment uint32 // never read with high bit set @@ -1176,7 +1176,7 @@ func (h2f *Framer) WriteHeaders(p HeadersFrameParam) error { } // A PriorityFrame specifies the sender-advised priority of a stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.3 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.3 type PriorityFrame struct { FrameHeader PriorityParam @@ -1246,7 +1246,7 @@ func (h2f *Framer) WritePriority(streamID uint32, p PriorityParam) error { } // A RSTStreamFrame allows for abnormal termination of a stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.4 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.4 type RSTStreamFrame struct { FrameHeader ErrCode ErrCode @@ -1278,7 +1278,7 @@ func (h2f *Framer) WriteRSTStream(streamID uint32, code ErrCode) error { } // A ContinuationFrame is used to continue a sequence of header block fragments. -// See http://http2.github.io/http2-spec/#rfc.section.6.10 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.10 type ContinuationFrame struct { FrameHeader headerFragBuf []byte @@ -1319,7 +1319,7 @@ func (h2f *Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlo } // A PushPromiseFrame is used to initiate a server stream. -// See http://http2.github.io/http2-spec/#rfc.section.6.6 +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.6 type PushPromiseFrame struct { FrameHeader PromiseID uint32 diff --git a/internal/http2/http2.go b/internal/http2/http2.go index 253cc5a4..0e1fe2b6 100644 --- a/internal/http2/http2.go +++ b/internal/http2/http2.go @@ -45,7 +45,7 @@ const ( // HTTP/2's TLS setup. NextProtoTLS = "h2" - // http://http2.github.io/http2-spec/#SettingValues + // https://httpwg.org/specs/rfc7540.html#SettingValues initialHeaderTableSize = 4096 initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size @@ -58,7 +58,7 @@ var ( // Setting is a setting parameter: which setting it is, and its value. type Setting struct { // ID is which setting is being set. - // See http://http2.github.io/http2-spec/#SettingValues + // See https://httpwg.org/specs/rfc7540.html#SettingValues ID SettingID // Val is the value. @@ -90,7 +90,7 @@ func (s Setting) Valid() error { } // A SettingID is an HTTP/2 setting as defined in -// http://http2.github.io/http2-spec/#iana-settings +// https://httpwg.org/specs/rfc7540.html#iana-settings type SettingID uint16 const ( @@ -122,10 +122,11 @@ func (s SettingID) String() string { // name (key). See httpguts.ValidHeaderName for the base rules. // // Further, http2 says: -// "Just as in HTTP/1.x, header field names are strings of ASCII -// characters that are compared in a case-insensitive -// fashion. However, header field names MUST be converted to -// lowercase prior to their encoding in HTTP/2. " +// +// "Just as in HTTP/1.x, header field names are strings of ASCII +// characters that are compared in a case-insensitive +// fashion. However, header field names MUST be converted to +// lowercase prior to their encoding in HTTP/2. " func validWireHeaderFieldName(v string) bool { if len(v) == 0 { return false @@ -242,8 +243,8 @@ func (s *sorter) SortStrings(ss []string) { // validPseudoPath reports whether v is a valid :path pseudo-header // value. It must be either: // -// *) a non-empty string starting with '/' -// *) the string '*', for OPTIONS requests. +// *) a non-empty string starting with '/' +// *) the string '*', for OPTIONS requests. // // For now this is only used a quick check for deciding when to clean // up Opaque URLs before sending requests from the Transport. From f670f3281f4a2b69ea5fdf6ff606d9166568ef05 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 14:44:09 +0800 Subject: [PATCH 633/843] http2: close client connections after receiving GOAWAY https://github.com/golang/net/commit/d0c6ba3f52d93c7050946b638e6c317c1d7ea069 --- internal/http2/transport.go | 24 ++++++----- internal/http2/transport_test.go | 73 ++++++++++++++++++++++++++++++-- 2 files changed, 83 insertions(+), 14 deletions(-) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 3f8bc930..731154db 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -173,7 +173,8 @@ func (t *Transport) initConnPool() { // HTTP/2 server. type ClientConn struct { t *Transport - tconn net.Conn // usually TLSConn, except specialized impls + tconn net.Conn // usually TLSConn, except specialized impls + tconnClosed bool tlsState *tls.ConnectionState // nil only for specialized impls reused uint32 // whether conn is being reused; atomic singleUse bool // whether being used for a single http.Request @@ -826,10 +827,10 @@ func (cc *ClientConn) onIdleTimeout() { cc.closeIfIdle() } -func (cc *ClientConn) closeConn() error { +func (cc *ClientConn) closeConn() { t := time.AfterFunc(250*time.Millisecond, cc.forceCloseConn) defer t.Stop() - return cc.tconn.Close() + cc.tconn.Close() } // A tls.Conn.Close can hang for a long time if the peer is unresponsive. @@ -895,7 +896,8 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { shutdownEnterWaitStateHook() select { case <-done: - return cc.closeConn() + cc.closeConn() + return nil case <-ctx.Done(): cc.mu.Lock() // Free the goroutine above @@ -932,7 +934,7 @@ func (cc *ClientConn) sendGoAway() error { // closes the client connection immediately. In-flight requests are interrupted. // err is sent to streams. -func (cc *ClientConn) closeForError(err error) error { +func (cc *ClientConn) closeForError(err error) { cc.mu.Lock() cc.closed = true for _, cs := range cc.streams { @@ -940,7 +942,7 @@ func (cc *ClientConn) closeForError(err error) error { } cc.cond.Broadcast() cc.mu.Unlock() - return cc.closeConn() + cc.closeConn() } // Close closes the client connection immediately. @@ -948,16 +950,17 @@ func (cc *ClientConn) closeForError(err error) error { // In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. func (cc *ClientConn) Close() error { err := errors.New("http2: client connection force closed via ClientConn.Close") - return cc.closeForError(err) + cc.closeForError(err) + return nil } // closes the client connection immediately. In-flight requests are interrupted. -func (cc *ClientConn) closeForLostPing() error { +func (cc *ClientConn) closeForLostPing() { err := errors.New("http2: client connection lost") if f := cc.t.CountError; f != nil { f("conn_close_lost_ping") } - return cc.closeForError(err) + cc.closeForError(err) } func commaSeparatedTrailers(req *http.Request) (string, error) { @@ -1958,7 +1961,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) { // wake up RoundTrip if there is a pending request. cc.cond.Broadcast() - closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.DisableKeepAlives + closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.DisableKeepAlives || cc.goAway != nil if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { if VerboseLogs { cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) @@ -2646,7 +2649,6 @@ func (rl *clientConnReadLoop) processGoAway(f *GoAwayFrame) error { if fn := cc.t.CountError; fn != nil { fn("recv_goaway_" + f.ErrCode.stringToken()) } - } cc.setGoAway(f) return nil diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go index a4739f10..2e116e72 100644 --- a/internal/http2/transport_test.go +++ b/internal/http2/transport_test.go @@ -1122,7 +1122,9 @@ const ( ) // Test all 36 combinations of response frame orders: -// (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers):func TestTransportResponsePattern_00f0(t *testing.T) { testTransportResponsePattern(h0, h1, false, h0) } +// +// (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers):func TestTransportResponsePattern_00f0(t *testing.T) { testTransportResponsePattern(h0, h1, false, h0) } +// // Generated by http://play.golang.org/p/SScqYKJYXd func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) } func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) } @@ -1542,8 +1544,9 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT } // headerListSize returns the HTTP2 header list size of h. -// http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE -// http://httpwg.org/specs/rfc7540.html#MaxHeaderBlock +// +// http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE +// http://httpwg.org/specs/rfc7540.html#MaxHeaderBlock func headerListSize(h http.Header) (size uint32) { for k, vv := range h { for _, v := range vv { @@ -5695,3 +5698,67 @@ func TestProcessHeaders(t *testing.T) { err = rl.processHeaders(f) tests.AssertNoError(t, err) } + +func TestTransportClosesConnAfterGoAwayNoStreams(t *testing.T) { + testTransportClosesConnAfterGoAway(t, 0) +} +func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { + testTransportClosesConnAfterGoAway(t, 1) +} + +// testTransportClosesConnAfterGoAway verifies that the transport +// closes a connection after reading a GOAWAY from it. +// +// lastStream is the last stream ID in the GOAWAY frame. +// When 0, the transport (unsuccessfully) retries the request (stream 1); +// when 1, the transport reads the response after receiving the GOAWAY. +func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { + ct := newClientTester(t) + + var wg sync.WaitGroup + wg.Add(1) + ct.client = func() error { + defer wg.Done() + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err == nil { + res.Body.Close() + } + if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { + t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) + } + if err = ct.cc.Close(); err == nil { + err = fmt.Errorf("expected error on Close") + } else if strings.Contains(err.Error(), "use of closed network") { + err = nil + } + return err + } + + ct.server = func() error { + defer wg.Wait() + ct.greet() + hf, err := ct.firstHeaders() + if err != nil { + return fmt.Errorf("server failed reading HEADERS: %v", err) + } + if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil { + return fmt.Errorf("server failed writing GOAWAY: %v", err) + } + if lastStream > 0 { + // Send a valid response to first request. + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + } + return nil + } + + ct.run() +} From b0f1de0d9ddbc86b015661ce96d47be3c50db47f Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 14:52:59 +0800 Subject: [PATCH 634/843] http2: remove race from TestTransportCancelDataResponseRace https://github.com/golang/net/commit/db77216a4ee971c957784185f031d12f481af9c9 --- internal/http2/transport_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go index 2e116e72..3efed037 100644 --- a/internal/http2/transport_test.go +++ b/internal/http2/transport_test.go @@ -3381,7 +3381,7 @@ func TestClientConnPing(t *testing.T) { // connection. func TestTransportCancelDataResponseRace(t *testing.T) { cancel := make(chan struct{}) - clientGotError := make(chan bool, 1) + clientGotResponse := make(chan bool, 1) const msg = "Hello." st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -3394,8 +3394,8 @@ func TestTransportCancelDataResponseRace(t *testing.T) { io.WriteString(w, "Some data.") w.(http.Flusher).Flush() if i == 2 { + <-clientGotResponse close(cancel) - <-clientGotError } time.Sleep(10 * time.Millisecond) } @@ -3409,13 +3409,13 @@ func TestTransportCancelDataResponseRace(t *testing.T) { req, _ := http.NewRequest("GET", st.ts.URL, nil) req.Cancel = cancel res, err := c.Do(req) + clientGotResponse <- true if err != nil { t.Fatal(err) } if _, err = io.Copy(ioutil.Discard, res.Body); err == nil { t.Fatal("unexpected success") } - clientGotError <- true res, err = c.Get(st.ts.URL + "/hello") if err != nil { From c6cc9bfe59f108160ec8e334461cf897721229e9 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 14:59:46 +0800 Subject: [PATCH 635/843] http2: don't rely on double-close of a net.Conn failing https://github.com/golang/net/commit/d300de134e69b2e21b6def1df88a04b21275ccd9 --- internal/http2/transport_test.go | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go index 3efed037..29bc9934 100644 --- a/internal/http2/transport_test.go +++ b/internal/http2/transport_test.go @@ -5706,6 +5706,20 @@ func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { testTransportClosesConnAfterGoAway(t, 1) } +type closeOnceConn struct { + net.Conn + closed uint32 +} + +var errClosed = errors.New("Close of closed connection") + +func (c *closeOnceConn) Close() error { + if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return c.Conn.Close() + } + return errClosed +} + // testTransportClosesConnAfterGoAway verifies that the transport // closes a connection after reading a GOAWAY from it. // @@ -5714,6 +5728,7 @@ func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { // when 1, the transport reads the response after receiving the GOAWAY. func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { ct := newClientTester(t) + ct.cc = &closeOnceConn{Conn: ct.cc} var wg sync.WaitGroup wg.Add(1) @@ -5727,12 +5742,10 @@ func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) } - if err = ct.cc.Close(); err == nil { - err = fmt.Errorf("expected error on Close") - } else if strings.Contains(err.Error(), "use of closed network") { - err = nil + if err = ct.cc.Close(); err != errClosed { + return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err) } - return err + return nil } ct.server = func() error { From a2c0d01e211551ea24d1ffc39cf056ce5ad6683a Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 15:05:42 +0800 Subject: [PATCH 636/843] http2: improved Request.Body.Close not to hold lock on connection https://github.com/golang/net/commit/f486391704dcfa95dab69d779cb1574e3c1f7db1 --- internal/http2/transport.go | 45 +++++++++++++++++++++++++---- internal/http2/transport_test.go | 49 ++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 6 deletions(-) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 731154db..f5ba25ce 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -292,25 +292,33 @@ func (cs *clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error } func (cs *clientStream) abortStream(err error) { + var reqBody io.ReadCloser + defer func() { + if reqBody != nil { + reqBody.Close() + } + }() cs.cc.mu.Lock() defer cs.cc.mu.Unlock() - cs.abortStreamLocked(err) + reqBody = cs.abortStreamLocked(err) } -func (cs *clientStream) abortStreamLocked(err error) { +func (cs *clientStream) abortStreamLocked(err error) io.ReadCloser { cs.abortOnce.Do(func() { cs.abortErr = err close(cs.abort) }) + var reqBody io.ReadCloser if cs.reqBody != nil && !cs.reqBodyClosed { - cs.reqBody.Close() cs.reqBodyClosed = true + reqBody = cs.reqBody } // TODO(dneil): Clean up tests where cs.cc.cond is nil. if cs.cc.cond != nil { // Wake up writeRequestBody if it is waiting on flow control. cs.cc.cond.Broadcast() } + return reqBody } func (cs *clientStream) abortRequestBodyWrite() { @@ -690,6 +698,12 @@ func (cc *ClientConn) SetDoNotReuse() { } func (cc *ClientConn) setGoAway(f *GoAwayFrame) { + var reqBodiesToClose []io.ReadCloser + defer func() { + for _, reqBody := range reqBodiesToClose { + reqBody.Close() + } + }() cc.mu.Lock() defer cc.mu.Unlock() @@ -706,7 +720,10 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) { last := f.LastStreamID for streamID, cs := range cc.streams { if streamID > last { - cs.abortStreamLocked(errClientConnGotGoAway) + reqBody := cs.abortStreamLocked(errClientConnGotGoAway) + if reqBody != nil { + reqBodiesToClose = append(reqBodiesToClose, reqBody) + } } } } @@ -937,11 +954,19 @@ func (cc *ClientConn) sendGoAway() error { func (cc *ClientConn) closeForError(err error) { cc.mu.Lock() cc.closed = true + + var reqBodiesToClose []io.ReadCloser for _, cs := range cc.streams { - cs.abortStreamLocked(err) + reqBody := cs.abortStreamLocked(err) + if reqBody != nil { + reqBodiesToClose = append(reqBodiesToClose, reqBody) + } } cc.cond.Broadcast() cc.mu.Unlock() + for _, reqBody := range reqBodiesToClose { + reqBody.Close() + } cc.closeConn() } @@ -2037,17 +2062,25 @@ func (rl *clientConnReadLoop) cleanup() { err = io.ErrUnexpectedEOF } cc.closed = true + + var reqBodiesToClose []io.ReadCloser for _, cs := range cc.streams { select { case <-cs.peerClosed: // The server closed the stream before closing the conn, // so no need to interrupt it. default: - cs.abortStreamLocked(err) + reqBody := cs.abortStreamLocked(err) + if reqBody != nil { + reqBodiesToClose = append(reqBodiesToClose, reqBody) + } } } cc.cond.Broadcast() cc.mu.Unlock() + for _, reqBody := range reqBodiesToClose { + reqBody.Close() + } } // countReadFrameError calls Transport.CountError with a string diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go index 29bc9934..f2182686 100644 --- a/internal/http2/transport_test.go +++ b/internal/http2/transport_test.go @@ -5775,3 +5775,52 @@ func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { ct.run() } + +type slowCloser struct { + closing chan struct{} + closed chan struct{} +} + +func (r *slowCloser) Read([]byte) (int, error) { + return 0, io.EOF +} + +func (r *slowCloser) Close() error { + close(r.closing) + <-r.closed + return nil +} + +func TestTransportSlowClose(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + }, optOnlyServer) + defer st.Close() + + client := st.ts.Client() + body := &slowCloser{ + closing: make(chan struct{}), + closed: make(chan struct{}), + } + + reqc := make(chan struct{}) + go func() { + defer close(reqc) + res, err := client.Post(st.ts.URL, "text/plain", body) + if err != nil { + t.Error(err) + } + res.Body.Close() + }() + defer func() { + close(body.closed) + <-reqc // wait for POST request to finish + }() + + <-body.closing // wait for POST request to call body.Close + // This GET request should not be blocked by the in-progress POST. + res, err := client.Get(st.ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() +} From 2652b4dae14ea8e015c472472e6e83af50ded799 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 15:37:26 +0800 Subject: [PATCH 637/843] http2: don't return from RoundTrip until request body is closed https://github.com/golang/net/commit/107f3e3c3b0b37888bfdc868e563e2973f54be61 --- internal/http2/transport.go | 84 ++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 49 deletions(-) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index f5ba25ce..aa4f177e 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -261,8 +261,8 @@ type clientStream struct { readErr error // sticky read error; owned by transportResponseBody.Read reqBody io.ReadCloser - reqBodyContentLength int64 // -1 means unknown - reqBodyClosed bool // body has been closed; guarded by cc.mu + reqBodyContentLength int64 // -1 means unknown + reqBodyClosed chan struct{} // guarded by cc.mu; non-nil on Close, closed when done // owned by writeRequest: sentEndStream bool // sent an END_STREAM flag to the peer @@ -292,46 +292,48 @@ func (cs *clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error } func (cs *clientStream) abortStream(err error) { - var reqBody io.ReadCloser - defer func() { - if reqBody != nil { - reqBody.Close() - } - }() cs.cc.mu.Lock() defer cs.cc.mu.Unlock() - reqBody = cs.abortStreamLocked(err) + cs.abortStreamLocked(err) } -func (cs *clientStream) abortStreamLocked(err error) io.ReadCloser { +func (cs *clientStream) abortStreamLocked(err error) { cs.abortOnce.Do(func() { cs.abortErr = err close(cs.abort) }) - var reqBody io.ReadCloser - if cs.reqBody != nil && !cs.reqBodyClosed { - cs.reqBodyClosed = true - reqBody = cs.reqBody + if cs.reqBody != nil { + cs.closeReqBodyLocked() } // TODO(dneil): Clean up tests where cs.cc.cond is nil. if cs.cc.cond != nil { // Wake up writeRequestBody if it is waiting on flow control. cs.cc.cond.Broadcast() } - return reqBody } func (cs *clientStream) abortRequestBodyWrite() { cc := cs.cc cc.mu.Lock() defer cc.mu.Unlock() - if cs.reqBody != nil && !cs.reqBodyClosed { - cs.reqBody.Close() - cs.reqBodyClosed = true + if cs.reqBody != nil && cs.reqBodyClosed == nil { + cs.closeReqBodyLocked() cc.cond.Broadcast() } } +func (cs *clientStream) closeReqBodyLocked() { + if cs.reqBodyClosed != nil { + return + } + cs.reqBodyClosed = make(chan struct{}) + reqBodyClosed := cs.reqBodyClosed + go func() { + cs.reqBody.Close() + close(reqBodyClosed) + }() +} + type stickyErrWriter struct { conn net.Conn timeout time.Duration @@ -698,12 +700,6 @@ func (cc *ClientConn) SetDoNotReuse() { } func (cc *ClientConn) setGoAway(f *GoAwayFrame) { - var reqBodiesToClose []io.ReadCloser - defer func() { - for _, reqBody := range reqBodiesToClose { - reqBody.Close() - } - }() cc.mu.Lock() defer cc.mu.Unlock() @@ -720,10 +716,7 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) { last := f.LastStreamID for streamID, cs := range cc.streams { if streamID > last { - reqBody := cs.abortStreamLocked(errClientConnGotGoAway) - if reqBody != nil { - reqBodiesToClose = append(reqBodiesToClose, reqBody) - } + cs.abortStreamLocked(errClientConnGotGoAway) } } } @@ -954,19 +947,11 @@ func (cc *ClientConn) sendGoAway() error { func (cc *ClientConn) closeForError(err error) { cc.mu.Lock() cc.closed = true - - var reqBodiesToClose []io.ReadCloser for _, cs := range cc.streams { - reqBody := cs.abortStreamLocked(err) - if reqBody != nil { - reqBodiesToClose = append(reqBodiesToClose, reqBody) - } + cs.abortStreamLocked(err) } cc.cond.Broadcast() cc.mu.Unlock() - for _, reqBody := range reqBodiesToClose { - reqBody.Close() - } cc.closeConn() } @@ -1370,11 +1355,19 @@ func (cs *clientStream) cleanupWriteRequest(err error) { // and in multiple cases: server replies <=299 and >299 // while still writing request body cc.mu.Lock() + mustCloseBody := false + if cs.reqBody != nil && cs.reqBodyClosed == nil { + mustCloseBody = true + cs.reqBodyClosed = make(chan struct{}) + } bodyClosed := cs.reqBodyClosed - cs.reqBodyClosed = true cc.mu.Unlock() - if !bodyClosed && cs.reqBody != nil { + if mustCloseBody { cs.reqBody.Close() + close(bodyClosed) + } + if bodyClosed != nil { + <-bodyClosed } if err != nil && cs.sentEndStream { @@ -1564,7 +1557,7 @@ func (cs *clientStream) writeRequestBody(req *http.Request, dumps []*dump.Dumper } if err != nil { cc.mu.Lock() - bodyClosed := cs.reqBodyClosed + bodyClosed := cs.reqBodyClosed != nil cc.mu.Unlock() switch { case bodyClosed: @@ -1659,7 +1652,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) if cc.closed { return 0, errClientConnClosed } - if cs.reqBodyClosed { + if cs.reqBodyClosed != nil { return 0, errStopReqBodyWrite } select { @@ -2063,24 +2056,17 @@ func (rl *clientConnReadLoop) cleanup() { } cc.closed = true - var reqBodiesToClose []io.ReadCloser for _, cs := range cc.streams { select { case <-cs.peerClosed: // The server closed the stream before closing the conn, // so no need to interrupt it. default: - reqBody := cs.abortStreamLocked(err) - if reqBody != nil { - reqBodiesToClose = append(reqBodiesToClose, reqBody) - } + cs.abortStreamLocked(err) } } cc.cond.Broadcast() cc.mu.Unlock() - for _, reqBody := range reqBodiesToClose { - reqBody.Close() - } } // countReadFrameError calls Transport.CountError with a string From 58550db89a8d87837a754c5569b25b2080d6a37c Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 15:39:35 +0800 Subject: [PATCH 638/843] http2: add a few other common headers to the shared headermap cache https://github.com/golang/net/commit/c877839975014e976a41dde0045be50819199b72 --- internal/http2/headermap.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/internal/http2/headermap.go b/internal/http2/headermap.go index 50431a6d..df757ec2 100644 --- a/internal/http2/headermap.go +++ b/internal/http2/headermap.go @@ -28,7 +28,14 @@ func buildCommonHeaderMaps() { "accept-language", "accept-ranges", "age", + "access-control-allow-credentials", + "access-control-allow-headers", + "access-control-allow-methods", "access-control-allow-origin", + "access-control-expose-headers", + "access-control-max-age", + "access-control-request-headers", + "access-control-request-method", "allow", "authorization", "cache-control", @@ -54,6 +61,7 @@ func buildCommonHeaderMaps() { "link", "location", "max-forwards", + "origin", "proxy-authenticate", "proxy-authorization", "range", @@ -69,6 +77,8 @@ func buildCommonHeaderMaps() { "vary", "via", "www-authenticate", + "x-forwarded-for", + "x-forwarded-proto", } commonLowerHeader = make(map[string]string, len(common)) commonCanonHeader = make(map[string]string, len(common)) From ef53894a836b8f7cc12440f7da08c20567b06534 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 15:47:00 +0800 Subject: [PATCH 639/843] http2: add common header caching to Transport to reduce allocations https://github.com/golang/net/commit/a1278a7f7ee0c218caeda793b867e0568bbe1e77 --- internal/http2/headermap.go | 8 ++++++++ internal/http2/transport.go | 12 ++++++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/internal/http2/headermap.go b/internal/http2/headermap.go index df757ec2..a8e01cf9 100644 --- a/internal/http2/headermap.go +++ b/internal/http2/headermap.go @@ -96,3 +96,11 @@ func lowerHeader(v string) (lower string, isAscii bool) { } return ascii.ToLower(v) } + +func canonicalHeader(v string) string { + buildCommonHeaderMapsOnce() + if s, ok := commonCanonHeader[v]; ok { + return s + } + return http.CanonicalHeaderKey(v) +} diff --git a/internal/http2/transport.go b/internal/http2/transport.go index aa4f177e..60b0828b 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -976,7 +976,7 @@ func (cc *ClientConn) closeForLostPing() { func commaSeparatedTrailers(req *http.Request) (string, error) { keys := make([]string, 0, len(req.Trailer)) for k := range req.Trailer { - k = http.CanonicalHeaderKey(k) + k = canonicalHeader(k) switch k { case "Transfer-Encoding", "Trailer", "Content-Length": return "", fmt.Errorf("invalid Trailer key %q", k) @@ -1854,7 +1854,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail // Header list size is ok. Write the headers. enumerateHeaders(func(name, value string) { - name, ascii := ascii.ToLower(name) + name, ascii := lowerHeader(name) if !ascii { // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // field names have to be ASCII characters (just as in HTTP/1.x). @@ -1921,7 +1921,7 @@ func (cc *ClientConn) encodeTrailers(trailer http.Header, dumps []*dump.Dumper) } for k, vv := range trailer { - lowKey, ascii := ascii.ToLower(k) + lowKey, ascii := lowerHeader(k) if !ascii { // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // field names have to be ASCII characters (just as in HTTP/1.x). @@ -2272,7 +2272,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra Status: status + " " + http.StatusText(statusCode), } for _, hf := range regularFields { - key := http.CanonicalHeaderKey(hf.Name) + key := canonicalHeader(hf.Name) if key == "Trailer" { t := res.Trailer if t == nil { @@ -2280,7 +2280,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra res.Trailer = t } foreachHeaderElement(hf.Value, func(v string) { - t[http.CanonicalHeaderKey(v)] = nil + t[canonicalHeader(v)] = nil }) } else { vv := header[key] @@ -2386,7 +2386,7 @@ func (rl *clientConnReadLoop) processTrailers(cs *clientStream, f *MetaHeadersFr trailer := make(http.Header) for _, hf := range f.RegularFields() { - key := http.CanonicalHeaderKey(hf.Name) + key := canonicalHeader(hf.Name) trailer[key] = append(trailer[key], hf.Value) } cs.trailer = trailer From 180cfa666e5ebf8274f9859b18d2e9b7e68fc096 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 15:51:20 +0800 Subject: [PATCH 640/843] http2: GzipReader will reset zr to nil after closing body https://github.com/golang/net/commit/7a676822c292e3f405fc21ac0897393d715a12ea --- internal/http2/transport.go | 7 ++++++- internal/http2/transport_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 60b0828b..2d684010 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -23,6 +23,7 @@ import ( "github.com/imroc/req/v3/internal/transport" reqtls "github.com/imroc/req/v3/pkg/tls" "io" + "io/fs" "log" "math" mathrand "math/rand" @@ -2961,7 +2962,11 @@ func (gz *GzipReader) Read(p []byte) (n int, err error) { } func (gz *GzipReader) Close() error { - return gz.Body.Close() + if err := gz.Body.Close(); err != nil { + return err + } + gz.zerr = fs.ErrClosed + return nil } // isConnectionCloseRequest reports whether req should use its own diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go index f2182686..343d68b6 100644 --- a/internal/http2/transport_test.go +++ b/internal/http2/transport_test.go @@ -7,6 +7,7 @@ package http2 import ( "bufio" "bytes" + "compress/gzip" "context" "crypto/tls" "encoding/hex" @@ -17,6 +18,7 @@ import ( "github.com/imroc/req/v3/internal/tests" "github.com/imroc/req/v3/internal/transport" "io" + "io/fs" "io/ioutil" "log" "math/rand" @@ -2468,6 +2470,28 @@ func TestGzipReader_DoubleReadCrash(t *testing.T) { } } +func TestGzipReader_ReadAfterClose(t *testing.T) { + body := bytes.Buffer{} + w := gzip.NewWriter(&body) + w.Write([]byte("012345679")) + w.Close() + gz := &GzipReader{ + Body: io.NopCloser(&body), + } + var buf [1]byte + n, err := gz.Read(buf[:]) + if n != 1 || err != nil { + t.Fatalf("first Read = %v, %v; want 1, nil", n, err) + } + if err := gz.Close(); err != nil { + t.Fatalf("gz Close error: %v", err) + } + n, err = gz.Read(buf[:]) + if n != 0 || err != fs.ErrClosed { + t.Fatalf("Read after close = %v, %v; want 0, fs.ErrClosed", n, err) + } +} + func TestTransportNewTLSConfig(t *testing.T) { testCases := [...]struct { conf *tls.Config From b59d52ba1f00769f42161faba7d0f25048b26a67 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 15:58:12 +0800 Subject: [PATCH 641/843] http2: speed up TestTransportRetryHasLimit https://github.com/golang/net/commit/15e1b255657fc034cf76cbfc7c0c2a1e801c404d --- internal/http2/transport.go | 14 +++++++- internal/http2/transport_test.go | 58 ++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 2d684010..a5b1ff4f 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -428,6 +428,15 @@ func (t *Transport) AddConn(conn net.Conn, addr string) (used bool, err error) { return } +var retryBackoffHook func(time.Duration) *time.Timer + +func backoffNewTimer(d time.Duration) *time.Timer { + if retryBackoffHook != nil { + return retryBackoffHook(d) + } + return time.NewTimer(d) +} + // RoundTripOpt is like RoundTrip, but takes options. func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { @@ -463,11 +472,14 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res } backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) + d := time.Second * time.Duration(backoff) + timer := backoffNewTimer(d) select { - case <-time.After(time.Second * time.Duration(backoff)): + case <-timer.C: t.vlogf("RoundTrip retrying after failure: %v", err) continue case <-req.Context().Done(): + timer.Stop() err = req.Context().Err() } } diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go index 343d68b6..6c2b2303 100644 --- a/internal/http2/transport_test.go +++ b/internal/http2/transport_test.go @@ -3845,6 +3845,64 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { ct.run() } +func TestTransportRetryHasLimit(t *testing.T) { + // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s. + if testing.Short() { + t.Skip("skipping long test in short mode") + } + retryBackoffHook = func(d time.Duration) *time.Timer { + return time.NewTimer(0) // fires immediately + } + defer func() { + retryBackoffHook = nil + }() + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + if runtime.GOOS == "plan9" { + // CloseWrite not supported on Plan 9; Issue 17906 + defer ct.cc.(*net.TCPConn).Close() + } + defer close(clientDone) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + resp, err := ct.tr.RoundTrip(req) + if err == nil { + return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) + } + t.Logf("expected error, got: %v", err) + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + return nil + default: + return err + } + } + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + if !f.HeadersEnded() { + return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + } + ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} + func TestTransportResponseDataBeforeHeaders(t *testing.T) { // This test use not valid response format. // Discarding logger output to not spam tests output. From a89b4557811440103973602aa43f0207d5e1111d Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 5 Jan 2023 16:34:53 +0800 Subject: [PATCH 642/843] http2: rewrite inbound flow control tracking https://github.com/golang/net/commit/7805fdc37dc2b54b28b9d621030e14dcf1dab67c --- internal/http2/flow.go | 90 ++++++- internal/http2/flow_test.go | 66 ++++- internal/http2/server_test.go | 400 ++++++------------------------- internal/http2/transport.go | 87 +++---- internal/http2/transport_test.go | 134 +++++++---- 5 files changed, 337 insertions(+), 440 deletions(-) diff --git a/internal/http2/flow.go b/internal/http2/flow.go index 2689f132..750ac52f 100644 --- a/internal/http2/flow.go +++ b/internal/http2/flow.go @@ -2,25 +2,95 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Flow control + package http2 -// flow is the flow control window's size. -type flow struct { +// inflowMinRefresh is the minimum number of bytes we'll send for a +// flow control window update. +const inflowMinRefresh = 4 << 10 + +// inflow accounts for an inbound flow control window. +// It tracks both the latest window sent to the peer (used for enforcement) +// and the accumulated unsent window. +type inflow struct { + avail int32 + unsent int32 +} + +// set sets the initial window. +func (f *inflow) init(n int32) { + f.avail = n +} + +// add adds n bytes to the window, with a maximum window size of max, +// indicating that the peer can now send us more data. +// For example, the user read from a {Request,Response} body and consumed +// some of the buffered data, so the peer can now send more. +// It returns the number of bytes to send in a WINDOW_UPDATE frame to the peer. +// Window updates are accumulated and sent when the unsent capacity +// is at least inflowMinRefresh or will at least double the peer's available window. +func (f *inflow) add(n int) (connAdd int32) { + if n < 0 { + panic("negative update") + } + unsent := int64(f.unsent) + int64(n) + // "A sender MUST NOT allow a flow-control window to exceed 2^31-1 octets." + // RFC 7540 Section 6.9.1. + const maxWindow = 1<<31 - 1 + if unsent+int64(f.avail) > maxWindow { + panic("flow control update exceeds maximum window size") + } + f.unsent = int32(unsent) + if f.unsent < inflowMinRefresh && f.unsent < f.avail { + // If there aren't at least inflowMinRefresh bytes of window to send, + // and this update won't at least double the window, buffer the update for later. + return 0 + } + f.avail += f.unsent + f.unsent = 0 + return int32(unsent) +} + +// take attempts to take n bytes from the peer's flow control window. +// It reports whether the window has available capacity. +func (f *inflow) take(n uint32) bool { + if n > uint32(f.avail) { + return false + } + f.avail -= int32(n) + return true +} + +// takeInflows attempts to take n bytes from two inflows, +// typically connection-level and stream-level flows. +// It reports whether both windows have available capacity. +func takeInflows(f1, f2 *inflow, n uint32) bool { + if n > uint32(f1.avail) || n > uint32(f2.avail) { + return false + } + f1.avail -= int32(n) + f2.avail -= int32(n) + return true +} + +// outflow is the outbound flow control window's size. +type outflow struct { _ incomparable // n is the number of DATA bytes we're allowed to send. - // A flow is kept both on a conn and a per-stream. + // An outflow is kept both on a conn and a per-stream. n int32 - // conn points to the shared connection-level flow that is - // shared by all streams on that conn. It is nil for the flow + // conn points to the shared connection-level outflow that is + // shared by all streams on that conn. It is nil for the outflow // that's on the conn directly. - conn *flow + conn *outflow } -func (f *flow) setConnFlow(cf *flow) { f.conn = cf } +func (f *outflow) setConnFlow(cf *outflow) { f.conn = cf } -func (f *flow) available() int32 { +func (f *outflow) available() int32 { n := f.n if f.conn != nil && f.conn.n < n { n = f.conn.n @@ -28,7 +98,7 @@ func (f *flow) available() int32 { return n } -func (f *flow) take(n int32) { +func (f *outflow) take(n int32) { if n > f.available() { panic("internal error: took too much") } @@ -40,7 +110,7 @@ func (f *flow) take(n int32) { // add adds n bytes (positive or negative) to the flow control window. // It returns false if the sum would exceed 2^31-1. -func (f *flow) add(n int32) bool { +func (f *outflow) add(n int32) bool { sum := f.n + n if (sum > n) == (f.n > 0) { f.n = sum diff --git a/internal/http2/flow_test.go b/internal/http2/flow_test.go index 7ae82c78..cae4f38c 100644 --- a/internal/http2/flow_test.go +++ b/internal/http2/flow_test.go @@ -6,9 +6,61 @@ package http2 import "testing" -func TestFlow(t *testing.T) { - var st flow - var conn flow +func TestInFlowTake(t *testing.T) { + var f inflow + f.init(100) + if !f.take(40) { + t.Fatalf("f.take(40) from 100: got false, want true") + } + if !f.take(40) { + t.Fatalf("f.take(40) from 60: got false, want true") + } + if f.take(40) { + t.Fatalf("f.take(40) from 20: got true, want false") + } + if !f.take(20) { + t.Fatalf("f.take(20) from 20: got false, want true") + } +} + +func TestInflowAddSmall(t *testing.T) { + var f inflow + f.init(0) + // Adding even a small amount when there is no flow causes an immediate send. + if got, want := f.add(1), int32(1); got != want { + t.Fatalf("f.add(1) to 1 = %v, want %v", got, want) + } +} + +func TestInflowAdd(t *testing.T) { + var f inflow + f.init(10 * inflowMinRefresh) + if got, want := f.add(inflowMinRefresh-1), int32(0); got != want { + t.Fatalf("f.add(minRefresh - 1) = %v, want %v", got, want) + } + if got, want := f.add(1), int32(inflowMinRefresh); got != want { + t.Fatalf("f.add(minRefresh) = %v, want %v", got, want) + } +} + +func TestTakeInflows(t *testing.T) { + var a, b inflow + a.init(10) + b.init(20) + if !takeInflows(&a, &b, 5) { + t.Fatalf("takeInflows(a, b, 5) from 10, 20: got false, want true") + } + if takeInflows(&a, &b, 6) { + t.Fatalf("takeInflows(a, b, 6) from 5, 15: got true, want false") + } + if !takeInflows(&a, &b, 5) { + t.Fatalf("takeInflows(a, b, 5) from 5, 15: got false, want true") + } +} + +func TestOutFlow(t *testing.T) { + var st outflow + var conn outflow st.add(3) conn.add(2) @@ -29,8 +81,8 @@ func TestFlow(t *testing.T) { } } -func TestFlowAdd(t *testing.T) { - var f flow +func TestOutFlowAdd(t *testing.T) { + var f outflow if !f.add(1) { t.Fatal("failed to add 1") } @@ -51,8 +103,8 @@ func TestFlowAdd(t *testing.T) { } } -func TestFlowAddOverflow(t *testing.T) { - var f flow +func TestOutFlowAddOverflow(t *testing.T) { + var f outflow if !f.add(0) { t.Fatal("failed to add 0") } diff --git a/internal/http2/server_test.go b/internal/http2/server_test.go index c0c3074b..c1ddee69 100644 --- a/internal/http2/server_test.go +++ b/internal/http2/server_test.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "crypto/tls" - "encoding/xml" "errors" "flag" "fmt" @@ -21,7 +20,6 @@ import ( "net/url" "os" "reflect" - "regexp" "runtime" "sort" "strconv" @@ -1010,8 +1008,8 @@ type serverConn struct { wroteFrameCh chan frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes bodyReadCh chan bodyReadMsg // from handlers -> serve serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop - flow flow // conn-wide (not stream-specific) outbound flow control - inflow flow // conn-wide inbound flow control + flow outflow // conn-wide (not stream-specific) outbound flow control + inflow inflow // conn-wide inbound flow control tlsState *tls.ConnectionState // shared by all handlers, like net/http remoteAddrStr string writeSched WriteScheduler @@ -1108,10 +1106,10 @@ type stream struct { cancelCtx func() // owned by serverConn's serve loop: - bodyBytes int64 // body bytes seen so far - declBodyBytes int64 // or -1 if undeclared - flow flow // limits writing from Handler to client - inflow flow // what the client is allowed to POST/etc to us + bodyBytes int64 // body bytes seen so far + declBodyBytes int64 // or -1 if undeclared + flow outflow // limits writing from Handler to client + inflow inflow // what the client is allowed to POST/etc to us state streamState resetQueued bool // RST_STREAM queued for write; set by sc.resetStream gotTrailerHeader bool // HEADER frame for trailers was seen @@ -2279,14 +2277,9 @@ func (sc *serverConn) processData(f *DataFrame) error { // But still enforce their connection-level flow control, // and return any flow control bytes since we're not going // to consume them. - if sc.inflow.available() < int32(f.Length) { + if !sc.inflow.take(f.Length) { return sc.countError("data_flow", streamError(id, ErrCodeFlowControl)) } - // Deduct the flow control from inflow, since we're - // going to immediately add it back in - // sendWindowUpdate, which also schedules sending the - // frames. - sc.inflow.take(int32(f.Length)) sc.sendWindowUpdate(nil, int(f.Length)) // conn-level if st != nil && st.resetQueued { @@ -2301,6 +2294,11 @@ func (sc *serverConn) processData(f *DataFrame) error { // Sender sending more than they'd declared? if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { + if !sc.inflow.take(f.Length) { + return sc.countError("data_flow", streamError(id, ErrCodeFlowControl)) + } + sc.sendWindowUpdate(nil, int(f.Length)) // conn-level + st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the // value of a content-length header field does not equal the sum of the @@ -2309,10 +2307,9 @@ func (sc *serverConn) processData(f *DataFrame) error { } if f.Length > 0 { // Check whether the client has flow control quota. - if st.inflow.available() < int32(f.Length) { + if !takeInflows(&sc.inflow, &st.inflow, f.Length) { return sc.countError("flow_on_data_length", streamError(id, ErrCodeFlowControl)) } - st.inflow.take(int32(f.Length)) if len(data) > 0 { wrote, err := st.body.Write(data) @@ -2328,10 +2325,12 @@ func (sc *serverConn) processData(f *DataFrame) error { // Return any padded flow control now, since we won't // refund it later on body reads. - if pad := int32(f.Length) - int32(len(data)); pad > 0 { - sc.sendWindowUpdate32(nil, pad) - sc.sendWindowUpdate32(st, pad) - } + // Call sendWindowUpdate even if there is no padding, + // to return buffered flow control credit if the sent + // window has shrunk. + pad := int32(f.Length) - int32(len(data)) + sc.sendWindowUpdate32(nil, pad) + sc.sendWindowUpdate32(st, pad) } if f.StreamEnded() { st.endStream() @@ -2575,8 +2574,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream st.cw.Init() st.flow.conn = &sc.flow // link to conn-level counter st.flow.add(sc.initialStreamSendWindowSize) - st.inflow.conn = &sc.inflow // link to conn-level counter - st.inflow.add(sc.srv.initialStreamRecvWindowSize()) + st.inflow.init(sc.srv.initialStreamRecvWindowSize()) if sc.hs.WriteTimeout != 0 { st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } @@ -2858,47 +2856,28 @@ func (sc *serverConn) noteBodyRead(st *stream, n int) { } // st may be nil for conn-level -func (sc *serverConn) sendWindowUpdate(st *stream, n int) { - sc.serveG.check() - // "The legal range for the increment to the flow control - // window is 1 to 2^31-1 (2,147,483,647) octets." - // A Go Read call on 64-bit machines could in theory read - // a larger Read than this. Very unlikely, but we handle it here - // rather than elsewhere for now. - const maxUint31 = 1<<31 - 1 - for n >= maxUint31 { - sc.sendWindowUpdate32(st, maxUint31) - n -= maxUint31 - } - sc.sendWindowUpdate32(st, int32(n)) +func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) { + sc.sendWindowUpdate(st, int(n)) } // st may be nil for conn-level -func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) { +func (sc *serverConn) sendWindowUpdate(st *stream, n int) { sc.serveG.check() - if n == 0 { - return - } - if n < 0 { - panic("negative update") - } var streamID uint32 - if st != nil { + var send int32 + if st == nil { + send = sc.inflow.add(n) + } else { streamID = st.id + send = st.inflow.add(n) + } + if send == 0 { + return } sc.writeFrame(FrameWriteRequest{ - write: writeWindowUpdate{streamID: streamID, n: uint32(n)}, + write: writeWindowUpdate{streamID: streamID, n: uint32(send)}, stream: st, }) - var ok bool - if st == nil { - ok = sc.inflow.add(n) - } else { - ok = st.inflow.add(n) - } - if !ok { - panic("internal error; sent too many window updates without decrements?") - } } // requestBody is the Handler's Request.Body type. @@ -3142,8 +3121,9 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) { // prior to the headers being written. If the set of trailers is fixed // or known before the header is written, the normal Go trailers mechanism // is preferred: -// https://golang.org/pkg/net/http/#ResponseWriter -// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +// +// https://golang.org/pkg/net/http/#ResponseWriter +// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers const TrailerPrefix = "Trailer:" // promoteUndeclaredTrailers permits http.Handlers to set trailers @@ -5289,6 +5269,22 @@ func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, p } } +// writeReadPing sends a PING and immediately reads the PING ACK. +// It will fail if any other unread data was pending on the connection. +func (st *serverTester) writeReadPing() { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + if err := st.fr.WritePing(false, data); err != nil { + st.t.Fatalf("Error writing PING: %v", err) + } + p := st.wantPing() + if p.Flags&FlagPingAck == 0 { + st.t.Fatalf("got a PING, want a PING ACK") + } + if p.Data != data { + st.t.Fatalf("got PING data = %x, want %x", p.Data, data) + } +} + func (st *serverTester) readFrame() (Frame, error) { return st.fr.ReadFrame() } @@ -5399,284 +5395,24 @@ func (st *serverTester) wantWindowUpdate(streamID, incr uint32) { } } -func (st *serverTester) wantSettingsAck() { - f, err := st.readFrame() - if err != nil { - st.t.Fatal(err) - } - sf, ok := f.(*SettingsFrame) - if !ok { - st.t.Fatalf("Wanting a settings ACK, received a %T", f) - } - if !sf.Header().Flags.Has(FlagSettingsAck) { - st.t.Fatal("Settings Frame didn't have ACK set") - } -} - -func (st *serverTester) wantPushPromise() *PushPromiseFrame { - f, err := st.readFrame() - if err != nil { - st.t.Fatal(err) - } - ppf, ok := f.(*PushPromiseFrame) - if !ok { - st.t.Fatalf("Wanted PushPromise, received %T", ppf) - } - return ppf -} - -type specCoverage struct { - coverage map[specPart]bool - d *xml.Decoder -} - -func joinSection(sec []int) string { - s := fmt.Sprintf("%d", sec[0]) - for _, n := range sec[1:] { - s = fmt.Sprintf("%s.%d", s, n) - } - return s -} - -func (sc specCoverage) readSection(sec []int) { - var ( - buf = new(bytes.Buffer) - sub = 0 - ) - for { - tk, err := sc.d.Token() - if err != nil { - if err == io.EOF { - return - } - panic(err) - } - switch v := tk.(type) { - case xml.StartElement: - if skipElement(v) { - if err := sc.d.Skip(); err != nil { - panic(err) - } - if v.Name.Local == "section" { - sub++ - } - break - } - switch v.Name.Local { - case "section": - sub++ - sc.readSection(append(sec, sub)) - case "xref": - buf.Write(sc.readXRef(v)) - } - case xml.CharData: - if len(sec) == 0 { - break - } - buf.Write(v) - case xml.EndElement: - if v.Name.Local == "section" { - sc.addSentences(joinSection(sec), buf.String()) - return - } - } - } -} - -func attrSig(se xml.StartElement) string { - var names []string - for _, attr := range se.Attr { - if attr.Name.Local == "fmt" { - names = append(names, "fmt-"+attr.Value) +func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) { + var initial int32 + if streamID == 0 { + initial = st.sc.srv.initialConnRecvWindowSize() + } else { + initial = st.sc.srv.initialStreamRecvWindowSize() + } + donec := make(chan struct{}) + st.sc.sendServeMsg(func(sc *serverConn) { + defer close(donec) + var avail int32 + if streamID == 0 { + avail = sc.inflow.avail + sc.inflow.unsent } else { - names = append(names, attr.Name.Local) } - } - sort.Strings(names) - return strings.Join(names, ",") -} - -func attrValue(se xml.StartElement, attr string) string { - for _, a := range se.Attr { - if a.Name.Local == attr { - return a.Value + if got, want := initial-avail, consumed; got != want { + st.t.Errorf("stream %v flow control consumed: %v, want %v", streamID, got, want) } - } - panic("unknown attribute " + attr) -} - -func (sc specCoverage) readXRef(se xml.StartElement) []byte { - var b []byte - for { - tk, err := sc.d.Token() - if err != nil { - panic(err) - } - switch v := tk.(type) { - case xml.CharData: - if b != nil { - panic("unexpected CharData") - } - b = []byte(string(v)) - case xml.EndElement: - if v.Name.Local != "xref" { - panic("expected ") - } - if b != nil { - return b - } - sig := attrSig(se) - switch sig { - case "target": - return []byte(fmt.Sprintf("[%s]", attrValue(se, "target"))) - case "fmt-of,rel,target", "fmt-,,rel,target": - return []byte(fmt.Sprintf("[%s, %s]", attrValue(se, "target"), attrValue(se, "rel"))) - case "fmt-of,sec,target", "fmt-,,sec,target": - return []byte(fmt.Sprintf("[section %s of %s]", attrValue(se, "sec"), attrValue(se, "target"))) - case "fmt-of,rel,sec,target": - return []byte(fmt.Sprintf("[section %s of %s, %s]", attrValue(se, "sec"), attrValue(se, "target"), attrValue(se, "rel"))) - default: - panic(fmt.Sprintf("unknown attribute signature %q in %#v", sig, fmt.Sprintf("%#v", se))) - } - default: - panic(fmt.Sprintf("unexpected tag %q", v)) - } - } -} - -var skipAnchor = map[string]bool{ - "intro": true, - "Overview": true, -} - -var skipTitle = map[string]bool{ - "Acknowledgements": true, - "Change Log": true, - "Document Organization": true, - "Conventions and Terminology": true, -} - -func skipElement(s xml.StartElement) bool { - switch s.Name.Local { - case "artwork": - return true - case "section": - for _, attr := range s.Attr { - switch attr.Name.Local { - case "anchor": - if skipAnchor[attr.Value] || strings.HasPrefix(attr.Value, "changes.since.") { - return true - } - case "title": - if skipTitle[attr.Value] { - return true - } - } - } - } - return false -} - -type specPart struct { - section string - sentence string -} - -func (ss specPart) Less(oo specPart) bool { - atoi := func(s string) int { - n, err := strconv.Atoi(s) - if err != nil { - panic(err) - } - return n - } - a := strings.Split(ss.section, ".") - b := strings.Split(oo.section, ".") - for len(a) > 0 { - if len(b) == 0 { - return false - } - x, y := atoi(a[0]), atoi(b[0]) - if x == y { - a, b = a[1:], b[1:] - continue - } - return x < y - } - if len(b) > 0 { - return true - } - return false -} - -type bySpecSection []specPart - -func (a bySpecSection) Len() int { return len(a) } -func (a bySpecSection) Less(i, j int) bool { return a[i].Less(a[j]) } -func (a bySpecSection) Swap(i, j int) { a[i], a[j] = a[j], a[i] } - -func readSpecCov(r io.Reader) specCoverage { - sc := specCoverage{ - coverage: map[specPart]bool{}, - d: xml.NewDecoder(r)} - sc.readSection(nil) - return sc -} - -var whitespaceRx = regexp.MustCompile(`\s+`) - -func parseSentences(sens string) []string { - sens = strings.TrimSpace(sens) - if sens == "" { - return nil - } - ss := strings.Split(whitespaceRx.ReplaceAllString(sens, " "), ". ") - for i, s := range ss { - s = strings.TrimSpace(s) - if !strings.HasSuffix(s, ".") { - s += "." - } - ss[i] = s - } - return ss -} - -func (sc specCoverage) addSentences(sec string, sentence string) { - for _, s := range parseSentences(sentence) { - sc.coverage[specPart{sec, s}] = false - } -} - -func (sc specCoverage) cover(sec string, sentence string) { - for _, s := range parseSentences(sentence) { - p := specPart{sec, s} - if _, ok := sc.coverage[p]; !ok { - panic(fmt.Sprintf("Not found in spec: %q, %q", sec, s)) - } - sc.coverage[specPart{sec, s}] = true - } - -} - -var coverSpec = flag.Bool("coverspec", false, "Run spec coverage tests") - -// The global map of sentence coverage for the http2 spec. -var defaultSpecCoverage specCoverage - -var loadSpecOnce sync.Once - -func loadSpec() { - if f, err := os.Open("testdata/draft-ietf-httpbis-http2.xml"); err != nil { - panic(err) - } else { - defaultSpecCoverage = readSpecCov(f) - f.Close() - } -} - -// covers marks all sentences for section sec in defaultSpecCoverage. Sentences not -// "covered" will be included in report outputted by TestSpecCoverage. -func covers(sec, sentences string) { - loadSpecOnce.Do(loadSpec) - defaultSpecCoverage.cover(sec, sentences) + }) + <-donec } diff --git a/internal/http2/transport.go b/internal/http2/transport.go index a5b1ff4f..1461837b 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -54,10 +54,6 @@ const ( // we buffer per stream. transportDefaultStreamFlow = 4 << 20 - // transportDefaultStreamMinRefresh is the minimum number of bytes we'll send - // a stream-level WINDOW_UPDATE for at a time. - transportDefaultStreamMinRefresh = 4 << 10 - // initialMaxConcurrentStreams is a connections maxConcurrentStreams until // it's received servers initial SETTINGS frame, which corresponds with the // spec's minimum recommended value. @@ -190,8 +186,8 @@ type ClientConn struct { mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes - flow flow // our conn-level flow control quota (cs.flow is per stream) - inflow flow // peer's conn-level flow control + flow outflow // our conn-level flow control quota (cs.outflow is per stream) + inflow inflow // peer's conn-level flow control doNotReuse bool // whether conn is marked to not be reused for any future requests closing bool closed bool @@ -256,10 +252,10 @@ type clientStream struct { respHeaderRecv chan struct{} // closed when headers are received res *http.Response // set if respHeaderRecv is closed - flow flow // guarded by cc.mu - inflow flow // guarded by cc.mu - bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read - readErr error // sticky read error; owned by transportResponseBody.Read + flow outflow // guarded by cc.mu + inflow inflow // guarded by cc.mu + bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read + readErr error // sticky read error; owned by transportResponseBody.Read reqBody io.ReadCloser reqBodyContentLength int64 // -1 means unknown @@ -677,7 +673,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro cc.bw.Write(clientPreface) cc.fr.WriteSettings(initialSettings...) cc.fr.WriteWindowUpdate(0, transportDefaultConnFlow) - cc.inflow.add(transportDefaultConnFlow + initialWindowSize) + cc.inflow.init(transportDefaultConnFlow + initialWindowSize) cc.bw.Flush() if cc.werr != nil { cc.Close() @@ -1966,8 +1962,7 @@ type resAndError struct { func (cc *ClientConn) addStreamLocked(cs *clientStream) { cs.flow.add(int32(cc.initialWindowSize)) cs.flow.setConnFlow(&cc.flow) - cs.inflow.add(transportDefaultStreamFlow) - cs.inflow.setConnFlow(&cc.inflow) + cs.inflow.init(transportDefaultStreamFlow) cs.ID = cc.nextStreamID cc.nextStreamID += 2 cc.streams[cs.ID] = cs @@ -2445,21 +2440,10 @@ func (b transportResponseBody) Read(p []byte) (n int, err error) { } cc.mu.Lock() - var connAdd, streamAdd int32 - // Check the conn-level first, before the stream-level. - if v := cc.inflow.available(); v < transportDefaultConnFlow/2 { - connAdd = transportDefaultConnFlow - v - cc.inflow.add(connAdd) - } + connAdd := cc.inflow.add(n) + var streamAdd int32 if err == nil { // No need to refresh if the stream is over or failed. - // Consider any buffered body data (read from the conn but not - // consumed by the client) when computing flow control for this - // stream. - v := int(cs.inflow.available()) + cs.bufPipe.Len() - if v < transportDefaultStreamFlow-transportDefaultStreamMinRefresh { - streamAdd = int32(transportDefaultStreamFlow - v) - cs.inflow.add(streamAdd) - } + streamAdd = cs.inflow.add(n) } cc.mu.Unlock() @@ -2487,17 +2471,15 @@ func (b transportResponseBody) Close() error { if unread > 0 { cc.mu.Lock() // Return connection-level flow control. - if unread > 0 { - cc.inflow.add(int32(unread)) - } + connAdd := cc.inflow.add(unread) cc.mu.Unlock() // TODO(dneil): Acquiring this mutex can block indefinitely. // Move flow control return to a goroutine? cc.wmu.Lock() // Return connection-level flow control. - if unread > 0 { - cc.fr.WriteWindowUpdate(0, uint32(unread)) + if connAdd > 0 { + cc.fr.WriteWindowUpdate(0, uint32(connAdd)) } cc.bw.Flush() cc.wmu.Unlock() @@ -2518,7 +2500,6 @@ func (b transportResponseBody) Close() error { } return nil } - func (rl *clientConnReadLoop) processData(f *DataFrame) error { cc := rl.cc cs := rl.streamByID(f.StreamID) @@ -2540,13 +2521,18 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { // But at least return their flow control: if f.Length > 0 { cc.mu.Lock() - cc.inflow.add(int32(f.Length)) + ok := cc.inflow.take(f.Length) + connAdd := cc.inflow.add(int(f.Length)) cc.mu.Unlock() - - cc.wmu.Lock() - cc.fr.WriteWindowUpdate(0, uint32(f.Length)) - cc.bw.Flush() - cc.wmu.Unlock() + if !ok { + return ConnectionError(ErrCodeFlowControl) + } + if connAdd > 0 { + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(connAdd)) + cc.bw.Flush() + cc.wmu.Unlock() + } } return nil } @@ -2577,9 +2563,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { } // Check connection-level flow control. cc.mu.Lock() - if cs.inflow.available() >= int32(f.Length) { - cs.inflow.take(int32(f.Length)) - } else { + if !takeInflows(&cc.inflow, &cs.inflow, f.Length) { cc.mu.Unlock() return ConnectionError(ErrCodeFlowControl) } @@ -2601,19 +2585,20 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { } } - if refund > 0 { - cc.inflow.add(int32(refund)) - if !didReset { - cs.inflow.add(int32(refund)) - } + sendConn := cc.inflow.add(refund) + var sendStream int32 + if !didReset { + sendStream = cs.inflow.add(refund) } cc.mu.Unlock() - if refund > 0 { + if sendConn > 0 || sendStream > 0 { cc.wmu.Lock() - cc.fr.WriteWindowUpdate(0, uint32(refund)) - if !didReset { - cc.fr.WriteWindowUpdate(cs.ID, uint32(refund)) + if sendConn > 0 { + cc.fr.WriteWindowUpdate(0, uint32(sendConn)) + } + if sendStream > 0 { + cc.fr.WriteWindowUpdate(cs.ID, uint32(sendStream)) } cc.bw.Flush() cc.wmu.Unlock() diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go index 6c2b2303..c75086b2 100644 --- a/internal/http2/transport_test.go +++ b/internal/http2/transport_test.go @@ -797,6 +797,55 @@ func (ct *clientTester) readNonSettingsFrame() (Frame, error) { } } +// writeReadPing sends a PING and immediately reads the PING ACK. +// It will fail if any other unread data was pending on the connection, +// aside from SETTINGS frames. +func (ct *clientTester) writeReadPing() error { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + if err := ct.fr.WritePing(false, data); err != nil { + return fmt.Errorf("Error writing PING: %v", err) + } + f, err := ct.readNonSettingsFrame() + if err != nil { + return err + } + p, ok := f.(*PingFrame) + if !ok { + return fmt.Errorf("got a %v, want a PING ACK", f) + } + if p.Flags&FlagPingAck == 0 { + return fmt.Errorf("got a PING, want a PING ACK") + } + if p.Data != data { + return fmt.Errorf("got PING data = %x, want %x", p.Data, data) + } + return nil +} + +func (ct *clientTester) inflowWindow(streamID uint32) int32 { + pool := ct.tr.connPoolOrDef.(*clientConnPool) + pool.mu.Lock() + defer pool.mu.Unlock() + if n := len(pool.keys); n != 1 { + ct.t.Errorf("clientConnPool contains %v keys, expected 1", n) + return -1 + } + for cc := range pool.keys { + cc.mu.Lock() + defer cc.mu.Unlock() + if streamID == 0 { + return cc.inflow.avail + cc.inflow.unsent + } + cs := cc.streams[streamID] + if cs == nil { + ct.t.Errorf("no stream with id %v", streamID) + return -1 + } + return cs.inflow.avail + cs.inflow.unsent + } + return -1 +} + func (ct *clientTester) cleanup() { ct.tr.CloseIdleConnections() @@ -2873,22 +2922,17 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { ct := newClientTester(t) - clientClosed := make(chan struct{}) - serverWroteFirstByte := make(chan struct{}) - ct.client = func() error { req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) res, err := ct.tr.RoundTrip(req) if err != nil { return err } - <-serverWroteFirstByte if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { return fmt.Errorf("body read = %v, %v; want 1, nil", n, err) } res.Body.Close() // leaving 4999 bytes unread - close(clientClosed) return nil } @@ -2923,6 +2967,7 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { EndStream: false, BlockFragment: buf.Bytes(), }) + initialInflow := ct.inflowWindow(0) // Two cases: // - Send one DATA frame with 5000 bytes. @@ -2931,50 +2976,63 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { // In both cases, the client should consume one byte of data, // refund that byte, then refund the following 4999 bytes. // - // In the second case, the server waits for the client connection to - // close before seconding the second DATA frame. This tests the case + // In the second case, the server waits for the client to reset the + // stream before sending the second DATA frame. This tests the case // where the client receives a DATA frame after it has reset the stream. if oneDataFrame { ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000)) - close(serverWroteFirstByte) - <-clientClosed } else { ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1)) - close(serverWroteFirstByte) - <-clientClosed - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) } - waitingFor := "RSTStreamFrame" - sawRST := false - sawWUF := false - for !sawRST && !sawWUF { - f, err := ct.fr.ReadFrame() + wantRST := true + wantWUF := true + if !oneDataFrame { + wantWUF = false // flow control update is small, and will not be sent + } + for wantRST || wantWUF { + f, err := ct.readNonSettingsFrame() if err != nil { - return fmt.Errorf("ReadFrame while waiting for %s: %v", waitingFor, err) + return err } switch f := f.(type) { - case *SettingsFrame: case *RSTStreamFrame: - if sawRST { - return fmt.Errorf("saw second RSTStreamFrame: %v", summarizeFrame(f)) + if !wantRST { + return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) } if f.ErrCode != ErrCodeCancel { return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) } - sawRST = true + wantRST = false case *WindowUpdateFrame: - if sawWUF { - return fmt.Errorf("saw second WindowUpdateFrame: %v", summarizeFrame(f)) + if !wantWUF { + return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) } - if f.Increment != 4999 { + if f.Increment != 5000 { return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) } - sawWUF = true + wantWUF = false default: return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) } } + if !oneDataFrame { + ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) + f, err := ct.readNonSettingsFrame() + if err != nil { + return err + } + wuf, ok := f.(*WindowUpdateFrame) + if !ok || wuf.Increment != 5000 { + return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f)) + } + } + if err := ct.writeReadPing(); err != nil { + return err + } + if got, want := ct.inflowWindow(0), initialInflow; got != want { + return fmt.Errorf("connection flow tokens = %v, want %v", got, want) + } return nil } ct.run() @@ -3101,6 +3159,8 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { break } + initialConnWindow := ct.inflowWindow(0) + var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) @@ -3111,24 +3171,18 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { EndStream: false, BlockFragment: buf.Bytes(), }) + initialStreamWindow := ct.inflowWindow(hf.StreamID) pad := make([]byte, 5) ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream - - f, err := ct.readNonSettingsFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for first WindowUpdateFrame: %v", err) - } - wantBack := uint32(len(pad)) + 1 // one byte for the length of the padding - if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID != 0 { - return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f)) + if err := ct.writeReadPing(); err != nil { + return err } - - f, err = ct.readNonSettingsFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for second WindowUpdateFrame: %v", err) + // Padding flow control should have been returned. + if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want { + t.Errorf("conn inflow window = %v, want %v", got, want) } - if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID == 0 { - return fmt.Errorf("Expected stream WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f)) + if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want { + t.Errorf("stream inflow window = %v, want %v", got, want) } unblockClient <- true return nil From eac74319fda595b17307660315190c70cc3aa180 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 6 Jan 2023 14:19:44 +0800 Subject: [PATCH 643/843] Unexpose unnecessary NetConnWrapper interface --- internal/http2/transport.go | 12 ++++++++++-- pkg/tls/conn.go | 9 --------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 1461837b..37f251ee 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -21,7 +21,6 @@ import ( "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/netutil" "github.com/imroc/req/v3/internal/transport" - reqtls "github.com/imroc/req/v3/pkg/tls" "io" "io/fs" "log" @@ -852,10 +851,19 @@ func (cc *ClientConn) closeConn() { cc.tconn.Close() } +// netConnWrapper is the interface to get underlying connection, which is +// introduced in go1.18 for *tls.Conn. +type netConnWrapper interface { + // NetConn returns the underlying connection that is wrapped by c. + // Note that writing to or reading from this connection directly will corrupt the + // TLS session. + NetConn() net.Conn +} + // A tls.Conn.Close can hang for a long time if the peer is unresponsive. // Try to shut it down more aggressively. func (cc *ClientConn) forceCloseConn() { - tc, ok := cc.tconn.(reqtls.NetConnWrapper) + tc, ok := cc.tconn.(netConnWrapper) if !ok { return } diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index 2eaae7c1..8c6ef1ad 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -24,12 +24,3 @@ type Conn interface { // HandshakeContext or the Dialer's DialContext method instead. Handshake() error } - -// NetConnWrapper is the interface to get underlying connection, which is -// introduced in go1.18 for *tls.Conn. -type NetConnWrapper interface { - // NetConn returns the underlying connection that is wrapped by c. - // Note that writing to or reading from this connection directly will corrupt the - // TLS session. - NetConn() net.Conn -} From c66273e7aa9251e2eec4abcf9ece12deaadd8370 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 6 Jan 2023 15:01:03 +0800 Subject: [PATCH 644/843] Improve comments for Transport settings --- transport.go | 88 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 72 insertions(+), 16 deletions(-) diff --git a/transport.go b/transport.go index e6446484..2476dbf6 100644 --- a/transport.go +++ b/transport.go @@ -246,97 +246,153 @@ func (t *Transport) GetMaxIdleConns() int { return t.MaxIdleConns } -// SetMaxIdleConns set the MaxIdleConns. +// SetMaxIdleConns set the MaxIdleConns, which controls the maximum number of idle (keep-alive) +// connections across all hosts. Zero means no limit. func (t *Transport) SetMaxIdleConns(max int) *Transport { t.MaxIdleConns = max return t } -// SetMaxConnsPerHost set the MaxConnsPerHost. +// SetMaxConnsPerHost set the MaxConnsPerHost, optionally limits the +// total number of connections per host, including connections in the +// dialing, active, and idle states. On limit violation, dials will block. +// +// Zero means no limit. func (t *Transport) SetMaxConnsPerHost(max int) *Transport { t.MaxConnsPerHost = max return t } -// SetIdleConnTimeout set the IdleConnTimeout. +// SetIdleConnTimeout set the IdleConnTimeout, which is the maximum +// amount of time an idle (keep-alive) connection will remain idle before +// closing itself. +// +// Zero means no limit. func (t *Transport) SetIdleConnTimeout(timeout time.Duration) *Transport { t.IdleConnTimeout = timeout return t } -// SetTLSHandshakeTimeout set the TLSHandshakeTimeout. +// SetTLSHandshakeTimeout set the TLSHandshakeTimeout, which specifies the +// maximum amount of time waiting to wait for a TLS handshake. +// +// Zero means no timeout. func (t *Transport) SetTLSHandshakeTimeout(timeout time.Duration) *Transport { t.TLSHandshakeTimeout = timeout return t } -// SetResponseHeaderTimeout set the ResponseHeaderTimeout. +// SetResponseHeaderTimeout set the ResponseHeaderTimeout, if non-zero, specifies +// the amount of time to wait for a server's response headers after fully writing +// the request (including its body, if any). This time does not include the time +// to read the response body. func (t *Transport) SetResponseHeaderTimeout(timeout time.Duration) *Transport { t.ResponseHeaderTimeout = timeout return t } -// SetExpectContinueTimeout set the ExpectContinueTimeout. +// SetExpectContinueTimeout set the ExpectContinueTimeout, if non-zero, specifies +// the amount of time to wait for a server's first response headers after fully +// writing the request headers if the request has an "Expect: 100-continue" header. +// Zero means no timeout and causes the body to be sent immediately, without waiting +// for the server to approve. +// This time does not include the time to send the request header. func (t *Transport) SetExpectContinueTimeout(timeout time.Duration) *Transport { t.ExpectContinueTimeout = timeout return t } -// SetGetProxyConnectHeader set the GetProxyConnectHeader function. +// SetGetProxyConnectHeader set the GetProxyConnectHeader, which optionally specifies a func +// to return headers to send to proxyURL during a CONNECT request to the ip:port target. +// If it returns an error, the Transport's RoundTrip fails with that error. It can +// return (nil, nil) to not add headers. +// If GetProxyConnectHeader is non-nil, ProxyConnectHeader is ignored. func (t *Transport) SetGetProxyConnectHeader(fn func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error)) *Transport { t.GetProxyConnectHeader = fn return t } -// SetProxyConnectHeader set the ProxyConnectHeader. +// SetProxyConnectHeader set the ProxyConnectHeader, which optionally specifies headers to +// send to proxies during CONNECT requests. +// To set the header dynamically, see SetGetProxyConnectHeader. func (t *Transport) SetProxyConnectHeader(header http.Header) *Transport { t.ProxyConnectHeader = header return t } -// SetReadBufferSize set the ReadBufferSize. +// SetReadBufferSize set the ReadBufferSize, which specifies the size of the read buffer used +// when reading from the transport. +// If zero, a default (currently 4KB) is used. func (t *Transport) SetReadBufferSize(size int) *Transport { t.ReadBufferSize = size return t } -// SetWriteBufferSize set the WriteBufferSize. +// SetWriteBufferSize set the WriteBufferSize, which specifies the size of the write buffer used +// when writing to the transport. +// If zero, a default (currently 4KB) is used. func (t *Transport) SetWriteBufferSize(size int) *Transport { t.WriteBufferSize = size return t } -// SetMaxResponseHeaderBytes set the MaxResponseHeaderBytes. +// SetMaxResponseHeaderBytes set the MaxResponseHeaderBytes, which specifies a limit on how many +// response bytes are allowed in the server's response header. +// +// Zero means to use a default limit. func (t *Transport) SetMaxResponseHeaderBytes(max int64) *Transport { t.MaxResponseHeaderBytes = max return t } -// SetTLSClientConfig set the custom tle client config. +// SetTLSClientConfig set the custom TLSClientConfig, which specifies the TLS configuration to +// use with tls.Client. +// If nil, the default configuration is used. +// If non-nil, HTTP/2 support may not be enabled by default. func (t *Transport) SetTLSClientConfig(cfg *tls.Config) *Transport { t.TLSClientConfig = cfg return t } -// SetDebug set the debug function. +// SetDebug set the optional debug function. func (t *Transport) SetDebug(debugf func(format string, v ...interface{})) *Transport { t.Debugf = debugf return t } -// SetProxy set the http proxy, only valid for HTTP1 and HTTP2. +// SetProxy set the http proxy, only valid for HTTP1 and HTTP2, which specifies a function +// to return a proxy for a given Request. If the function returns a non-nil error, the request +// is aborted with the provided error. +// +// The proxy type is determined by the URL scheme. "http", +// "https", and "socks5" are supported. If the scheme is empty, +// "http" is assumed. +// +// If Proxy is nil or returns a nil *URL, no proxy is used. func (t *Transport) SetProxy(proxy func(*http.Request) (*url.URL, error)) *Transport { t.Proxy = proxy return t } -// SetDial set the custom DialContext function, only valid for HTTP1 and HTTP2. +// SetDial set the custom DialContext function, only valid for HTTP1 and HTTP2, which specifies the +// dial function for creating unencrypted TCP connections. +// If it is nil, then the transport dials using package net. +// +// The dial function runs concurrently with calls to RoundTrip. +// A RoundTrip call that initiates a dial may end up using a connection dialed previously when the +// earlier connection becomes idle before the later dial function completes. func (t *Transport) SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Transport { t.DialContext = fn return t } -// SetDialTLS set the custom DialTLSContext function, only valid for HTTP1 and HTTP2. +// SetDialTLS set the custom DialTLSContext function, only valid for HTTP1 and HTTP2, which specifies +// an optional dial function for creating TLS connections for non-proxied HTTPS requests. +// +// If it is nil, DialContext and TLSClientConfig are used. +// +// If it is set, the function that set in SetDial is not used for HTTPS requests and the TLSClientConfig +// and TLSHandshakeTimeout are ignored. The returned net.Conn is assumed to already be past the TLS handshake. func (t *Transport) SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Transport { t.DialTLSContext = fn return t From 97163fccfdb2f903ec1ef9aa7ac6d341a124593c Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 11 Jan 2023 14:26:55 +0800 Subject: [PATCH 645/843] Support SetResponseBodyTransformer --- client.go | 10 +++++++++- client_test.go | 12 ++++++++++++ client_wrapper.go | 6 ++++++ req_test.go | 12 ++++++++++++ response.go | 5 ++++- 5 files changed, 43 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 466cfa16..c37e67a7 100644 --- a/client.go +++ b/client.go @@ -66,6 +66,7 @@ type Client struct { udBeforeRequest []RequestMiddleware afterResponse []ResponseMiddleware wrappedRoundTrip RoundTripper + responseBodyTransformer func([]byte) ([]byte, error) } // R create a new request. @@ -151,6 +152,13 @@ func (c *Client) GetTransport() *Transport { return c.t } +// SetResponseBodyTransformer set the response body transformer, which can modify the +// response body before unmarshalled if auto-read response body is not disabled. +func (c *Client) SetResponseBodyTransformer(fn func(body []byte) ([]byte, error)) *Client { + c.responseBodyTransformer = fn + return c +} + // SetCommonError set the common result that response body will be unmarshalled to // if it is an error response ( status `code >= 400`). func (c *Client) SetCommonError(err interface{}) *Client { @@ -1077,7 +1085,7 @@ func (c *Client) AddCommonRetryCondition(condition RetryConditionFunc) *Client { // SetUnixSocket set client to dial connection use unix socket. // For example: // -// client.SetUnixSocket("/var/run/custom.sock") +// client.SetUnixSocket("/var/run/custom.sock") func (c *Client) SetUnixSocket(file string) *Client { return c.SetDial(func(ctx context.Context, network, addr string) (net.Conn, error) { var d net.Dialer diff --git a/client_test.go b/client_test.go index 960c028d..b03472f6 100644 --- a/client_test.go +++ b/client_test.go @@ -592,3 +592,15 @@ func TestEnableDumpAllAsync(t *testing.T) { c.EnableDumpAllTo(buf).EnableDumpAllAsync() tests.AssertEqual(t, true, c.getDumpOptions().Async) } + +func TestSetResponseBodyTransformer(t *testing.T) { + c := tc().SetResponseBodyTransformer(func(body []byte) ([]byte, error) { + result, err := url.QueryUnescape(string(body)) + return []byte(result), err + }) + user := &UserInfo{} + resp, err := c.R().SetResult(user).Get("/urlencode") + assertSuccess(t, resp, err) + tests.AssertEqual(t, user.Username, "我是roc") + tests.AssertEqual(t, user.Email, "roc@imroc.cc") +} diff --git a/client_wrapper.go b/client_wrapper.go index 7c1016e9..9bb76f90 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -598,6 +598,12 @@ func AddCommonRetryCondition(condition RetryConditionFunc) *Client { return defaultClient.AddCommonRetryCondition(condition) } +// SetResponseBodyTransformer is a global wrapper methods which delegated +// to the default client, create a request and SetResponseBodyTransformer for request. +func SetResponseBodyTransformer(fn func(body []byte) ([]byte, error)) *Client { + return defaultClient.SetResponseBodyTransformer(fn) +} + // SetUnixSocket is a global wrapper methods which delegated // to the default client, create a request and SetUnixSocket for request. func SetUnixSocket(file string) *Client { diff --git a/req_test.go b/req_test.go index 27d0b6f5..b019dc88 100644 --- a/req_test.go +++ b/req_test.go @@ -13,6 +13,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" "reflect" @@ -226,6 +227,17 @@ func handleGet(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": w.Write([]byte("TestGet: text response")) + case "/urlencode": + info := &UserInfo{ + Username: "我是roc", + Email: "roc@imroc.cc", + } + bs, err := json.Marshal(info) + if err != nil { + panic(err) + } + result := url.QueryEscape(string(bs)) + w.Write([]byte(result)) case "/bad-request": w.WriteHeader(http.StatusBadRequest) case "/too-many": diff --git a/response.go b/response.go index 31632b7e..2afa0797 100644 --- a/response.go +++ b/response.go @@ -169,8 +169,11 @@ func (r *Response) ToBytes() ([]byte, error) { } defer r.Body.Close() body, err := ioutil.ReadAll(r.Body) - r.body = body r.setReceivedAt() + if err == nil && r.Request.client.responseBodyTransformer != nil { + body, err = r.Request.client.responseBodyTransformer(body) + } + r.body = body return body, err } From 3b87ac1fcc67d44d2d1c9a28304f852e8fbc41fd Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 11 Jan 2023 16:10:15 +0800 Subject: [PATCH 646/843] Optimize SetResponseBodyTransformer --- client.go | 4 ++-- client_test.go | 9 ++++++--- client_wrapper.go | 2 +- response.go | 5 +++-- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index c37e67a7..765f96ae 100644 --- a/client.go +++ b/client.go @@ -66,7 +66,7 @@ type Client struct { udBeforeRequest []RequestMiddleware afterResponse []ResponseMiddleware wrappedRoundTrip RoundTripper - responseBodyTransformer func([]byte) ([]byte, error) + responseBodyTransformer func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error) } // R create a new request. @@ -154,7 +154,7 @@ func (c *Client) GetTransport() *Transport { // SetResponseBodyTransformer set the response body transformer, which can modify the // response body before unmarshalled if auto-read response body is not disabled. -func (c *Client) SetResponseBodyTransformer(fn func(body []byte) ([]byte, error)) *Client { +func (c *Client) SetResponseBodyTransformer(fn func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error)) *Client { c.responseBodyTransformer = fn return c } diff --git a/client_test.go b/client_test.go index b03472f6..7a6c4e1b 100644 --- a/client_test.go +++ b/client_test.go @@ -594,9 +594,12 @@ func TestEnableDumpAllAsync(t *testing.T) { } func TestSetResponseBodyTransformer(t *testing.T) { - c := tc().SetResponseBodyTransformer(func(body []byte) ([]byte, error) { - result, err := url.QueryUnescape(string(body)) - return []byte(result), err + c := tc().SetResponseBodyTransformer(func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error) { + if resp.IsSuccess() { + result, err := url.QueryUnescape(string(rawBody)) + return []byte(result), err + } + return rawBody, nil }) user := &UserInfo{} resp, err := c.R().SetResult(user).Get("/urlencode") diff --git a/client_wrapper.go b/client_wrapper.go index 9bb76f90..3031bd5a 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -600,7 +600,7 @@ func AddCommonRetryCondition(condition RetryConditionFunc) *Client { // SetResponseBodyTransformer is a global wrapper methods which delegated // to the default client, create a request and SetResponseBodyTransformer for request. -func SetResponseBodyTransformer(fn func(body []byte) ([]byte, error)) *Client { +func SetResponseBodyTransformer(fn func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error)) *Client { return defaultClient.SetResponseBodyTransformer(fn) } diff --git a/response.go b/response.go index 2afa0797..66af8364 100644 --- a/response.go +++ b/response.go @@ -170,10 +170,11 @@ func (r *Response) ToBytes() ([]byte, error) { defer r.Body.Close() body, err := ioutil.ReadAll(r.Body) r.setReceivedAt() + r.body = body if err == nil && r.Request.client.responseBodyTransformer != nil { - body, err = r.Request.client.responseBodyTransformer(body) + body, err = r.Request.client.responseBodyTransformer(body, r.Request, r) + r.body = body } - r.body = body return body, err } From 28436ebfcaea0dbc34c9b9c9caba66f502bc81e8 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 11 Jan 2023 23:36:06 +0800 Subject: [PATCH 647/843] Ensure response middleware executed when error occurs --- client.go | 10 +++------- response.go | 16 ++++++++++------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 765f96ae..b4b46235 100644 --- a/client.go +++ b/client.go @@ -1319,17 +1319,13 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { httpResponse, err = c.httpClient.Do(r.RawRequest) resp.Response = httpResponse - if err != nil { - return - } // auto-read response body if possible - if !c.disableAutoReadResponse && !r.isSaveResponse && !r.disableAutoReadResponse { + if err == nil && !c.disableAutoReadResponse && !r.isSaveResponse && !r.disableAutoReadResponse { _, err = resp.ToBytes() - if err != nil { - return - } // restore body for re-reads resp.Body = ioutil.NopCloser(bytes.NewReader(resp.body)) + } else if err != nil { + resp.Err = err } for _, f := range r.client.afterResponse { diff --git a/response.go b/response.go index 66af8364..6cbcb657 100644 --- a/response.go +++ b/response.go @@ -157,7 +157,7 @@ func (r *Response) ToString() (string, error) { } // ToBytes returns the response body as []byte, read body if not have been read. -func (r *Response) ToBytes() ([]byte, error) { +func (r *Response) ToBytes() (body []byte, err error) { if r.Err != nil { return nil, r.Err } @@ -167,15 +167,19 @@ func (r *Response) ToBytes() ([]byte, error) { if r.Response == nil || r.Response.Body == nil { return []byte{}, nil } - defer r.Body.Close() - body, err := ioutil.ReadAll(r.Body) + defer func() { + r.Body.Close() + if err != nil { + r.Err = err + } + r.body = body + }() + body, err = ioutil.ReadAll(r.Body) r.setReceivedAt() - r.body = body if err == nil && r.Request.client.responseBodyTransformer != nil { body, err = r.Request.client.responseBodyTransformer(body, r.Request, r) - r.body = body } - return body, err + return } // Dump return the string content that have been dumped for the request. From 1039e89e5b5ca921e1a28fc578968299d6827e5c Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 16 Jan 2023 18:47:32 +0800 Subject: [PATCH 648/843] embed quic-go and support go1.20 --- go.mod | 21 +- go.sum | 88 +- internal/http3/body.go | 2 +- internal/http3/body_test.go | 4 +- internal/http3/client.go | 10 +- internal/http3/client_test.go | 13 +- internal/http3/error_codes.go | 2 +- internal/http3/frames.go | 4 +- internal/http3/frames_test.go | 2 +- internal/http3/http_stream.go | 2 +- internal/http3/http_stream_test.go | 2 +- internal/http3/request.go | 3 +- internal/http3/request_test.go | 2 +- internal/http3/request_writer.go | 6 +- internal/http3/request_writer_test.go | 7 +- internal/http3/response_writer.go | 6 +- internal/http3/response_writer_test.go | 7 +- internal/http3/roundtrip.go | 2 +- internal/http3/roundtrip_test.go | 4 +- internal/http3/server.go | 4 +- internal/mocks/mockgen.go | 20 - internal/protocol/protocol.go | 4 - internal/protocol/version.go | 13 - internal/qtls/go116.go | 19 - internal/qtls/go117.go | 19 - internal/qtls/go118.go | 19 - internal/qtls/go119.go | 19 - internal/qtls/go120.go | 6 - internal/quic-go/ackhandler/ack_eliciting.go | 20 + .../quic-go/ackhandler/ack_eliciting_test.go | 34 + internal/quic-go/ackhandler/ackhandler.go | 21 + .../ackhandler/ackhandler_suite_test.go | 29 + internal/quic-go/ackhandler/frame.go | 9 + internal/quic-go/ackhandler/gen.go | 3 + internal/quic-go/ackhandler/interfaces.go | 68 + .../mock_sent_packet_tracker_test.go | 61 + internal/quic-go/ackhandler/mockgen.go | 3 + .../quic-go/ackhandler/packet_linkedlist.go | 217 ++ .../ackhandler/packet_number_generator.go | 76 + .../packet_number_generator_test.go | 99 + .../ackhandler/received_packet_handler.go | 136 + .../received_packet_handler_test.go | 168 + .../ackhandler/received_packet_history.go | 142 + .../received_packet_history_test.go | 354 ++ .../ackhandler/received_packet_tracker.go | 196 ++ .../received_packet_tracker_test.go | 348 ++ internal/quic-go/ackhandler/send_mode.go | 40 + internal/quic-go/ackhandler/send_mode_test.go | 18 + .../quic-go/ackhandler/sent_packet_handler.go | 838 +++++ .../ackhandler/sent_packet_handler_test.go | 1386 ++++++++ .../quic-go/ackhandler/sent_packet_history.go | 108 + .../ackhandler/sent_packet_history_test.go | 263 ++ internal/quic-go/buffer_pool.go | 80 + internal/quic-go/buffer_pool_test.go | 55 + internal/quic-go/client.go | 339 ++ internal/quic-go/client_test.go | 611 ++++ internal/quic-go/closed_conn.go | 112 + internal/quic-go/closed_conn_test.go | 56 + internal/quic-go/config.go | 124 + internal/quic-go/config_test.go | 180 + internal/quic-go/congestion/bandwidth.go | 25 + internal/quic-go/congestion/bandwidth_test.go | 14 + internal/quic-go/congestion/clock.go | 18 + .../congestion/congestion_suite_test.go | 13 + internal/quic-go/congestion/cubic.go | 214 ++ internal/quic-go/congestion/cubic_sender.go | 316 ++ .../quic-go/congestion/cubic_sender_test.go | 526 +++ internal/quic-go/congestion/cubic_test.go | 239 ++ .../quic-go/congestion/hybrid_slow_start.go | 113 + .../congestion/hybrid_slow_start_test.go | 72 + internal/quic-go/congestion/interface.go | 28 + internal/quic-go/congestion/pacer.go | 77 + internal/quic-go/congestion/pacer_test.go | 131 + internal/quic-go/conn_id_generator.go | 140 + internal/quic-go/conn_id_generator_test.go | 187 + internal/quic-go/conn_id_manager.go | 207 ++ internal/quic-go/conn_id_manager_test.go | 364 ++ internal/quic-go/connection.go | 2006 +++++++++++ internal/quic-go/connection_test.go | 3038 +++++++++++++++++ internal/quic-go/crypto_stream.go | 115 + internal/quic-go/crypto_stream_manager.go | 61 + .../quic-go/crypto_stream_manager_test.go | 119 + internal/quic-go/crypto_stream_test.go | 187 + internal/quic-go/datagram_queue.go | 87 + internal/quic-go/datagram_queue_test.go | 98 + internal/quic-go/errors.go | 58 + .../flowcontrol/base_flow_controller.go | 125 + .../flowcontrol/base_flow_controller_test.go | 236 ++ .../flowcontrol/connection_flow_controller.go | 112 + .../connection_flow_controller_test.go | 185 + .../flowcontrol/flowcontrol_suite_test.go | 24 + internal/quic-go/flowcontrol/interface.go | 42 + .../flowcontrol/stream_flow_controller.go | 149 + .../stream_flow_controller_test.go | 272 ++ internal/quic-go/frame_sorter.go | 224 ++ internal/quic-go/frame_sorter_test.go | 1527 +++++++++ internal/quic-go/framer.go | 171 + internal/quic-go/framer_test.go | 385 +++ internal/quic-go/handshake/aead.go | 161 + internal/quic-go/handshake/aead_test.go | 204 ++ internal/quic-go/handshake/crypto_setup.go | 819 +++++ .../quic-go/handshake/crypto_setup_test.go | 864 +++++ .../quic-go/handshake/handshake_suite_test.go | 48 + .../quic-go/handshake/header_protector.go | 136 + internal/quic-go/handshake/hkdf.go | 29 + internal/quic-go/handshake/hkdf_test.go | 17 + internal/quic-go/handshake/initial_aead.go | 81 + .../quic-go/handshake/initial_aead_test.go | 219 ++ internal/quic-go/handshake/interface.go | 102 + .../handshake/mock_handshake_runner_test.go | 84 + internal/quic-go/handshake/mockgen.go | 3 + internal/quic-go/handshake/retry.go | 62 + internal/quic-go/handshake/retry_test.go | 36 + internal/quic-go/handshake/session_ticket.go | 48 + .../quic-go/handshake/session_ticket_test.go | 54 + .../handshake/tls_extension_handler.go | 68 + .../handshake/tls_extension_handler_test.go | 210 ++ internal/quic-go/handshake/token_generator.go | 134 + .../quic-go/handshake/token_generator_test.go | 127 + internal/quic-go/handshake/token_protector.go | 89 + .../quic-go/handshake/token_protector_test.go | 67 + internal/quic-go/handshake/updatable_aead.go | 323 ++ .../quic-go/handshake/updatable_aead_test.go | 528 +++ internal/quic-go/interface.go | 328 ++ internal/quic-go/logging/frame.go | 66 + internal/quic-go/logging/interface.go | 134 + .../quic-go/logging/logging_suite_test.go | 25 + .../logging/mock_connection_tracer_test.go | 351 ++ internal/quic-go/logging/mock_tracer_test.go | 76 + internal/quic-go/logging/mockgen.go | 4 + internal/quic-go/logging/multiplex.go | 219 ++ internal/quic-go/logging/multiplex_test.go | 266 ++ internal/quic-go/logging/packet_header.go | 27 + .../quic-go/logging/packet_header_test.go | 60 + internal/quic-go/logging/types.go | 94 + internal/quic-go/logutils/frame.go | 33 + internal/quic-go/logutils/frame_test.go | 51 + .../quic-go/logutils/logutils_suite_test.go | 13 + .../quic-go/mock_ack_frame_source_test.go | 50 + internal/quic-go/mock_batch_conn_test.go | 50 + internal/quic-go/mock_conn_runner_test.go | 123 + .../quic-go/mock_crypto_data_handler_test.go | 49 + internal/quic-go/mock_crypto_stream_test.go | 121 + internal/quic-go/mock_frame_source_test.go | 80 + internal/quic-go/mock_mtu_discoverer_test.go | 66 + internal/quic-go/mock_multiplexer_test.go | 65 + internal/quic-go/mock_packer_test.go | 179 + .../mock_packet_handler_manager_test.go | 175 + internal/quic-go/mock_packet_handler_test.go | 85 + internal/quic-go/mock_packetconn_test.go | 137 + internal/quic-go/mock_quic_conn_test.go | 346 ++ .../mock_receive_stream_internal_test.go | 146 + internal/quic-go/mock_sealing_manager_test.go | 95 + internal/quic-go/mock_send_conn_test.go | 91 + .../quic-go/mock_send_stream_internal_test.go | 187 + internal/quic-go/mock_sender_test.go | 100 + internal/quic-go/mock_stream_getter_test.go | 65 + internal/quic-go/mock_stream_internal_test.go | 284 ++ internal/quic-go/mock_stream_manager_test.go | 231 ++ internal/quic-go/mock_stream_sender_test.go | 72 + internal/quic-go/mock_token_store_test.go | 60 + .../mock_unknown_packet_handler_test.go | 58 + internal/quic-go/mock_unpacker_test.go | 51 + internal/quic-go/mockgen.go | 27 + internal/quic-go/mockgen_private.sh | 49 + .../ackhandler/received_packet_handler.go | 105 + .../mocks/ackhandler/sent_packet_handler.go | 240 ++ internal/quic-go/mocks/congestion.go | 192 ++ .../mocks/connection_flow_controller.go | 128 + internal/quic-go/mocks/crypto_setup.go | 264 ++ .../mocks/logging/connection_tracer.go | 352 ++ internal/quic-go/mocks/logging/tracer.go | 77 + internal/quic-go/mocks/long_header_opener.go | 76 + internal/quic-go/mocks/mockgen.go | 20 + .../{ => quic-go}/mocks/quic/early_conn.go | 7 +- internal/quic-go/mocks/quic/early_listener.go | 80 + internal/{ => quic-go}/mocks/quic/stream.go | 13 +- internal/quic-go/mocks/short_header_opener.go | 77 + internal/quic-go/mocks/short_header_sealer.go | 89 + .../quic-go/mocks/stream_flow_controller.go | 140 + .../quic-go/mocks/tls/client_session_cache.go | 62 + internal/quic-go/mtu_discoverer.go | 74 + internal/quic-go/mtu_discoverer_test.go | 112 + internal/quic-go/multiplexer.go | 107 + internal/quic-go/multiplexer_test.go | 70 + internal/quic-go/packet_handler_map.go | 489 +++ internal/quic-go/packet_handler_map_test.go | 495 +++ internal/quic-go/packet_packer.go | 894 +++++ internal/quic-go/packet_packer_test.go | 1556 +++++++++ internal/quic-go/packet_unpacker.go | 196 ++ internal/quic-go/packet_unpacker_test.go | 292 ++ internal/quic-go/protocol/connection_id.go | 69 + .../quic-go/protocol/connection_id_test.go | 108 + internal/quic-go/protocol/encryption_level.go | 30 + .../quic-go/protocol/encryption_level_test.go | 20 + internal/quic-go/protocol/key_phase.go | 36 + internal/quic-go/protocol/key_phase_test.go | 27 + internal/quic-go/protocol/packet_number.go | 79 + .../quic-go/protocol/packet_number_test.go | 204 ++ internal/quic-go/protocol/params.go | 193 ++ internal/quic-go/protocol/params_test.go | 13 + internal/quic-go/protocol/perspective.go | 26 + internal/quic-go/protocol/perspective_test.go | 19 + internal/quic-go/protocol/protocol.go | 97 + .../quic-go/protocol/protocol_suite_test.go | 13 + internal/quic-go/protocol/protocol_test.go | 25 + internal/quic-go/protocol/stream.go | 76 + internal/quic-go/protocol/stream_test.go | 70 + internal/quic-go/protocol/version.go | 114 + internal/quic-go/protocol/version_test.go | 121 + internal/quic-go/qerr/error_codes.go | 88 + internal/quic-go/qerr/errorcodes_test.go | 52 + internal/quic-go/qerr/errors.go | 124 + internal/quic-go/qerr/errors_suite_test.go | 13 + internal/quic-go/qerr/errors_test.go | 124 + internal/quic-go/qlog/event.go | 529 +++ internal/quic-go/qlog/event_test.go | 43 + internal/quic-go/qlog/frame.go | 227 ++ internal/quic-go/qlog/frame_test.go | 377 ++ internal/quic-go/qlog/packet_header.go | 119 + internal/quic-go/qlog/packet_header_test.go | 175 + internal/quic-go/qlog/qlog.go | 486 +++ internal/quic-go/qlog/qlog_suite_test.go | 51 + internal/quic-go/qlog/qlog_test.go | 849 +++++ internal/quic-go/qlog/trace.go | 66 + internal/quic-go/qlog/types.go | 320 ++ internal/quic-go/qlog/types_test.go | 130 + internal/quic-go/qtls/go116.go | 100 + internal/quic-go/qtls/go117.go | 100 + internal/quic-go/qtls/go118.go | 100 + internal/quic-go/qtls/go119.go | 100 + internal/{ => quic-go}/qtls/go_oldversion.go | 2 +- internal/quic-go/qtls/qtls_suite_test.go | 25 + internal/quic-go/qtls/qtls_test.go | 17 + internal/quic-go/quic_suite_test.go | 34 + internal/{ => quic-go}/quicvarint/io.go | 0 internal/{ => quic-go}/quicvarint/io_test.go | 0 .../quicvarint/quicvarint_suite_test.go | 0 internal/{ => quic-go}/quicvarint/varint.go | 2 +- .../{ => quic-go}/quicvarint/varint_test.go | 0 internal/quic-go/receive_stream.go | 331 ++ internal/quic-go/receive_stream_test.go | 696 ++++ internal/quic-go/retransmission_queue.go | 131 + internal/quic-go/retransmission_queue_test.go | 187 + internal/quic-go/send_conn.go | 74 + internal/quic-go/send_conn_test.go | 45 + internal/quic-go/send_queue.go | 88 + internal/quic-go/send_queue_test.go | 126 + internal/quic-go/send_stream.go | 496 +++ internal/quic-go/send_stream_test.go | 1159 +++++++ internal/quic-go/server.go | 670 ++++ internal/quic-go/server_test.go | 1237 +++++++ internal/quic-go/stream.go | 149 + internal/quic-go/stream_test.go | 106 + internal/quic-go/streams_map.go | 317 ++ .../quic-go/streams_map_generic_helper.go | 18 + internal/quic-go/streams_map_incoming_bidi.go | 192 ++ .../quic-go/streams_map_incoming_generic.go | 190 ++ .../streams_map_incoming_generic_test.go | 307 ++ internal/quic-go/streams_map_incoming_uni.go | 192 ++ internal/quic-go/streams_map_outgoing_bidi.go | 226 ++ .../quic-go/streams_map_outgoing_generic.go | 224 ++ .../streams_map_outgoing_generic_test.go | 539 +++ internal/quic-go/streams_map_outgoing_uni.go | 226 ++ internal/quic-go/streams_map_test.go | 499 +++ internal/quic-go/sys_conn.go | 80 + internal/quic-go/sys_conn_df.go | 16 + internal/quic-go/sys_conn_df_linux.go | 40 + internal/quic-go/sys_conn_df_windows.go | 46 + internal/quic-go/sys_conn_helper_darwin.go | 22 + internal/quic-go/sys_conn_helper_freebsd.go | 22 + internal/quic-go/sys_conn_helper_linux.go | 20 + internal/quic-go/sys_conn_no_oob.go | 16 + internal/quic-go/sys_conn_oob.go | 257 ++ internal/quic-go/sys_conn_oob_test.go | 243 ++ internal/quic-go/sys_conn_test.go | 33 + internal/quic-go/sys_conn_windows.go | 40 + internal/quic-go/sys_conn_windows_test.go | 33 + internal/quic-go/testdata/ca.pem | 17 + internal/quic-go/testdata/cert.go | 55 + internal/quic-go/testdata/cert.pem | 18 + internal/quic-go/testdata/cert_test.go | 31 + internal/quic-go/testdata/generate_key.sh | 24 + internal/quic-go/testdata/priv.key | 28 + .../quic-go/testdata/testdata_suite_test.go | 13 + internal/quic-go/testutils/testutils.go | 97 + internal/quic-go/token_store.go | 117 + internal/quic-go/token_store_test.go | 108 + internal/quic-go/tools.go | 9 + internal/quic-go/utils/atomic_bool.go | 22 + internal/quic-go/utils/atomic_bool_test.go | 29 + .../quic-go/utils/buffered_write_closer.go | 26 + .../utils/buffered_write_closer_test.go | 26 + .../quic-go/utils/byteinterval_linkedlist.go | 217 ++ .../quic-go/utils/byteoder_big_endian_test.go | 107 + internal/quic-go/utils/byteorder.go | 17 + .../quic-go/utils/byteorder_big_endian.go | 89 + internal/{ => quic-go}/utils/gen.go | 0 internal/quic-go/utils/ip.go | 10 + internal/quic-go/utils/ip_test.go | 17 + internal/quic-go/utils/linkedlist/README.md | 11 + .../quic-go/utils/linkedlist/linkedlist.go | 218 ++ internal/{ => quic-go}/utils/log.go | 2 +- internal/{ => quic-go}/utils/log_test.go | 0 internal/quic-go/utils/minmax.go | 170 + internal/quic-go/utils/minmax_test.go | 123 + internal/quic-go/utils/new_connection_id.go | 12 + .../utils/newconnectionid_linkedlist.go | 217 ++ internal/quic-go/utils/packet_interval.go | 9 + .../utils/packetinterval_linkedlist.go | 217 ++ internal/quic-go/utils/rand.go | 29 + internal/quic-go/utils/rand_test.go | 32 + internal/quic-go/utils/rtt_stats.go | 127 + internal/quic-go/utils/rtt_stats_test.go | 157 + .../quic-go/utils/streamframe_interval.go | 9 + internal/quic-go/utils/timer.go | 53 + internal/quic-go/utils/timer_test.go | 87 + .../{ => quic-go}/utils/utils_suite_test.go | 0 internal/quic-go/window_update_queue.go | 71 + internal/quic-go/window_update_queue_test.go | 112 + internal/quic-go/wire/ack_frame.go | 251 ++ internal/quic-go/wire/ack_frame_test.go | 454 +++ internal/quic-go/wire/ack_range.go | 14 + internal/quic-go/wire/ack_range_test.go | 13 + .../quic-go/wire/connection_close_frame.go | 83 + .../wire/connection_close_frame_test.go | 153 + internal/quic-go/wire/crypto_frame.go | 102 + internal/quic-go/wire/crypto_frame_test.go | 148 + internal/quic-go/wire/data_blocked_frame.go | 38 + .../quic-go/wire/data_blocked_frame_test.go | 54 + internal/quic-go/wire/datagram_frame.go | 85 + internal/quic-go/wire/datagram_frame_test.go | 154 + internal/quic-go/wire/extended_header.go | 249 ++ internal/quic-go/wire/extended_header_test.go | 481 +++ internal/quic-go/wire/frame_parser.go | 143 + internal/quic-go/wire/frame_parser_test.go | 410 +++ internal/quic-go/wire/handshake_done_frame.go | 28 + internal/quic-go/wire/header.go | 274 ++ internal/quic-go/wire/header_test.go | 583 ++++ internal/quic-go/wire/interface.go | 19 + internal/quic-go/wire/log.go | 72 + internal/quic-go/wire/log_test.go | 168 + internal/quic-go/wire/max_data_frame.go | 40 + internal/quic-go/wire/max_data_frame_test.go | 57 + .../quic-go/wire/max_stream_data_frame.go | 46 + .../wire/max_stream_data_frame_test.go | 63 + internal/quic-go/wire/max_streams_frame.go | 55 + .../quic-go/wire/max_streams_frame_test.go | 107 + .../quic-go/wire/new_connection_id_frame.go | 80 + .../wire/new_connection_id_frame_test.go | 104 + internal/quic-go/wire/new_token_frame.go | 48 + internal/quic-go/wire/new_token_frame_test.go | 66 + internal/quic-go/wire/path_challenge_frame.go | 38 + .../quic-go/wire/path_challenge_frame_test.go | 48 + internal/quic-go/wire/path_response_frame.go | 38 + .../quic-go/wire/path_response_frame_test.go | 47 + internal/quic-go/wire/ping_frame.go | 27 + internal/quic-go/wire/ping_frame_test.go | 39 + internal/quic-go/wire/pool.go | 33 + internal/quic-go/wire/pool_test.go | 24 + internal/quic-go/wire/reset_stream_frame.go | 58 + .../quic-go/wire/reset_stream_frame_test.go | 70 + .../wire/retire_connection_id_frame.go | 36 + .../wire/retire_connection_id_frame_test.go | 53 + internal/quic-go/wire/stop_sending_frame.go | 48 + .../quic-go/wire/stop_sending_frame_test.go | 63 + .../quic-go/wire/stream_data_blocked_frame.go | 46 + .../wire/stream_data_blocked_frame_test.go | 63 + internal/quic-go/wire/stream_frame.go | 189 + internal/quic-go/wire/stream_frame_test.go | 443 +++ .../quic-go/wire/streams_blocked_frame.go | 55 + .../wire/streams_blocked_frame_test.go | 108 + .../quic-go/wire/transport_parameter_test.go | 612 ++++ internal/quic-go/wire/transport_parameters.go | 476 +++ internal/quic-go/wire/version_negotiation.go | 54 + .../quic-go/wire/version_negotiation_test.go | 83 + internal/quic-go/wire/wire_suite_test.go | 31 + 377 files changed, 58753 insertions(+), 208 deletions(-) delete mode 100644 internal/mocks/mockgen.go delete mode 100644 internal/protocol/protocol.go delete mode 100644 internal/protocol/version.go delete mode 100644 internal/qtls/go116.go delete mode 100644 internal/qtls/go117.go delete mode 100644 internal/qtls/go118.go delete mode 100644 internal/qtls/go119.go delete mode 100644 internal/qtls/go120.go create mode 100644 internal/quic-go/ackhandler/ack_eliciting.go create mode 100644 internal/quic-go/ackhandler/ack_eliciting_test.go create mode 100644 internal/quic-go/ackhandler/ackhandler.go create mode 100644 internal/quic-go/ackhandler/ackhandler_suite_test.go create mode 100644 internal/quic-go/ackhandler/frame.go create mode 100644 internal/quic-go/ackhandler/gen.go create mode 100644 internal/quic-go/ackhandler/interfaces.go create mode 100644 internal/quic-go/ackhandler/mock_sent_packet_tracker_test.go create mode 100644 internal/quic-go/ackhandler/mockgen.go create mode 100644 internal/quic-go/ackhandler/packet_linkedlist.go create mode 100644 internal/quic-go/ackhandler/packet_number_generator.go create mode 100644 internal/quic-go/ackhandler/packet_number_generator_test.go create mode 100644 internal/quic-go/ackhandler/received_packet_handler.go create mode 100644 internal/quic-go/ackhandler/received_packet_handler_test.go create mode 100644 internal/quic-go/ackhandler/received_packet_history.go create mode 100644 internal/quic-go/ackhandler/received_packet_history_test.go create mode 100644 internal/quic-go/ackhandler/received_packet_tracker.go create mode 100644 internal/quic-go/ackhandler/received_packet_tracker_test.go create mode 100644 internal/quic-go/ackhandler/send_mode.go create mode 100644 internal/quic-go/ackhandler/send_mode_test.go create mode 100644 internal/quic-go/ackhandler/sent_packet_handler.go create mode 100644 internal/quic-go/ackhandler/sent_packet_handler_test.go create mode 100644 internal/quic-go/ackhandler/sent_packet_history.go create mode 100644 internal/quic-go/ackhandler/sent_packet_history_test.go create mode 100644 internal/quic-go/buffer_pool.go create mode 100644 internal/quic-go/buffer_pool_test.go create mode 100644 internal/quic-go/client.go create mode 100644 internal/quic-go/client_test.go create mode 100644 internal/quic-go/closed_conn.go create mode 100644 internal/quic-go/closed_conn_test.go create mode 100644 internal/quic-go/config.go create mode 100644 internal/quic-go/config_test.go create mode 100644 internal/quic-go/congestion/bandwidth.go create mode 100644 internal/quic-go/congestion/bandwidth_test.go create mode 100644 internal/quic-go/congestion/clock.go create mode 100644 internal/quic-go/congestion/congestion_suite_test.go create mode 100644 internal/quic-go/congestion/cubic.go create mode 100644 internal/quic-go/congestion/cubic_sender.go create mode 100644 internal/quic-go/congestion/cubic_sender_test.go create mode 100644 internal/quic-go/congestion/cubic_test.go create mode 100644 internal/quic-go/congestion/hybrid_slow_start.go create mode 100644 internal/quic-go/congestion/hybrid_slow_start_test.go create mode 100644 internal/quic-go/congestion/interface.go create mode 100644 internal/quic-go/congestion/pacer.go create mode 100644 internal/quic-go/congestion/pacer_test.go create mode 100644 internal/quic-go/conn_id_generator.go create mode 100644 internal/quic-go/conn_id_generator_test.go create mode 100644 internal/quic-go/conn_id_manager.go create mode 100644 internal/quic-go/conn_id_manager_test.go create mode 100644 internal/quic-go/connection.go create mode 100644 internal/quic-go/connection_test.go create mode 100644 internal/quic-go/crypto_stream.go create mode 100644 internal/quic-go/crypto_stream_manager.go create mode 100644 internal/quic-go/crypto_stream_manager_test.go create mode 100644 internal/quic-go/crypto_stream_test.go create mode 100644 internal/quic-go/datagram_queue.go create mode 100644 internal/quic-go/datagram_queue_test.go create mode 100644 internal/quic-go/errors.go create mode 100644 internal/quic-go/flowcontrol/base_flow_controller.go create mode 100644 internal/quic-go/flowcontrol/base_flow_controller_test.go create mode 100644 internal/quic-go/flowcontrol/connection_flow_controller.go create mode 100644 internal/quic-go/flowcontrol/connection_flow_controller_test.go create mode 100644 internal/quic-go/flowcontrol/flowcontrol_suite_test.go create mode 100644 internal/quic-go/flowcontrol/interface.go create mode 100644 internal/quic-go/flowcontrol/stream_flow_controller.go create mode 100644 internal/quic-go/flowcontrol/stream_flow_controller_test.go create mode 100644 internal/quic-go/frame_sorter.go create mode 100644 internal/quic-go/frame_sorter_test.go create mode 100644 internal/quic-go/framer.go create mode 100644 internal/quic-go/framer_test.go create mode 100644 internal/quic-go/handshake/aead.go create mode 100644 internal/quic-go/handshake/aead_test.go create mode 100644 internal/quic-go/handshake/crypto_setup.go create mode 100644 internal/quic-go/handshake/crypto_setup_test.go create mode 100644 internal/quic-go/handshake/handshake_suite_test.go create mode 100644 internal/quic-go/handshake/header_protector.go create mode 100644 internal/quic-go/handshake/hkdf.go create mode 100644 internal/quic-go/handshake/hkdf_test.go create mode 100644 internal/quic-go/handshake/initial_aead.go create mode 100644 internal/quic-go/handshake/initial_aead_test.go create mode 100644 internal/quic-go/handshake/interface.go create mode 100644 internal/quic-go/handshake/mock_handshake_runner_test.go create mode 100644 internal/quic-go/handshake/mockgen.go create mode 100644 internal/quic-go/handshake/retry.go create mode 100644 internal/quic-go/handshake/retry_test.go create mode 100644 internal/quic-go/handshake/session_ticket.go create mode 100644 internal/quic-go/handshake/session_ticket_test.go create mode 100644 internal/quic-go/handshake/tls_extension_handler.go create mode 100644 internal/quic-go/handshake/tls_extension_handler_test.go create mode 100644 internal/quic-go/handshake/token_generator.go create mode 100644 internal/quic-go/handshake/token_generator_test.go create mode 100644 internal/quic-go/handshake/token_protector.go create mode 100644 internal/quic-go/handshake/token_protector_test.go create mode 100644 internal/quic-go/handshake/updatable_aead.go create mode 100644 internal/quic-go/handshake/updatable_aead_test.go create mode 100644 internal/quic-go/interface.go create mode 100644 internal/quic-go/logging/frame.go create mode 100644 internal/quic-go/logging/interface.go create mode 100644 internal/quic-go/logging/logging_suite_test.go create mode 100644 internal/quic-go/logging/mock_connection_tracer_test.go create mode 100644 internal/quic-go/logging/mock_tracer_test.go create mode 100644 internal/quic-go/logging/mockgen.go create mode 100644 internal/quic-go/logging/multiplex.go create mode 100644 internal/quic-go/logging/multiplex_test.go create mode 100644 internal/quic-go/logging/packet_header.go create mode 100644 internal/quic-go/logging/packet_header_test.go create mode 100644 internal/quic-go/logging/types.go create mode 100644 internal/quic-go/logutils/frame.go create mode 100644 internal/quic-go/logutils/frame_test.go create mode 100644 internal/quic-go/logutils/logutils_suite_test.go create mode 100644 internal/quic-go/mock_ack_frame_source_test.go create mode 100644 internal/quic-go/mock_batch_conn_test.go create mode 100644 internal/quic-go/mock_conn_runner_test.go create mode 100644 internal/quic-go/mock_crypto_data_handler_test.go create mode 100644 internal/quic-go/mock_crypto_stream_test.go create mode 100644 internal/quic-go/mock_frame_source_test.go create mode 100644 internal/quic-go/mock_mtu_discoverer_test.go create mode 100644 internal/quic-go/mock_multiplexer_test.go create mode 100644 internal/quic-go/mock_packer_test.go create mode 100644 internal/quic-go/mock_packet_handler_manager_test.go create mode 100644 internal/quic-go/mock_packet_handler_test.go create mode 100644 internal/quic-go/mock_packetconn_test.go create mode 100644 internal/quic-go/mock_quic_conn_test.go create mode 100644 internal/quic-go/mock_receive_stream_internal_test.go create mode 100644 internal/quic-go/mock_sealing_manager_test.go create mode 100644 internal/quic-go/mock_send_conn_test.go create mode 100644 internal/quic-go/mock_send_stream_internal_test.go create mode 100644 internal/quic-go/mock_sender_test.go create mode 100644 internal/quic-go/mock_stream_getter_test.go create mode 100644 internal/quic-go/mock_stream_internal_test.go create mode 100644 internal/quic-go/mock_stream_manager_test.go create mode 100644 internal/quic-go/mock_stream_sender_test.go create mode 100644 internal/quic-go/mock_token_store_test.go create mode 100644 internal/quic-go/mock_unknown_packet_handler_test.go create mode 100644 internal/quic-go/mock_unpacker_test.go create mode 100644 internal/quic-go/mockgen.go create mode 100755 internal/quic-go/mockgen_private.sh create mode 100644 internal/quic-go/mocks/ackhandler/received_packet_handler.go create mode 100644 internal/quic-go/mocks/ackhandler/sent_packet_handler.go create mode 100644 internal/quic-go/mocks/congestion.go create mode 100644 internal/quic-go/mocks/connection_flow_controller.go create mode 100644 internal/quic-go/mocks/crypto_setup.go create mode 100644 internal/quic-go/mocks/logging/connection_tracer.go create mode 100644 internal/quic-go/mocks/logging/tracer.go create mode 100644 internal/quic-go/mocks/long_header_opener.go create mode 100644 internal/quic-go/mocks/mockgen.go rename internal/{ => quic-go}/mocks/quic/early_conn.go (97%) create mode 100644 internal/quic-go/mocks/quic/early_listener.go rename internal/{ => quic-go}/mocks/quic/stream.go (92%) create mode 100644 internal/quic-go/mocks/short_header_opener.go create mode 100644 internal/quic-go/mocks/short_header_sealer.go create mode 100644 internal/quic-go/mocks/stream_flow_controller.go create mode 100644 internal/quic-go/mocks/tls/client_session_cache.go create mode 100644 internal/quic-go/mtu_discoverer.go create mode 100644 internal/quic-go/mtu_discoverer_test.go create mode 100644 internal/quic-go/multiplexer.go create mode 100644 internal/quic-go/multiplexer_test.go create mode 100644 internal/quic-go/packet_handler_map.go create mode 100644 internal/quic-go/packet_handler_map_test.go create mode 100644 internal/quic-go/packet_packer.go create mode 100644 internal/quic-go/packet_packer_test.go create mode 100644 internal/quic-go/packet_unpacker.go create mode 100644 internal/quic-go/packet_unpacker_test.go create mode 100644 internal/quic-go/protocol/connection_id.go create mode 100644 internal/quic-go/protocol/connection_id_test.go create mode 100644 internal/quic-go/protocol/encryption_level.go create mode 100644 internal/quic-go/protocol/encryption_level_test.go create mode 100644 internal/quic-go/protocol/key_phase.go create mode 100644 internal/quic-go/protocol/key_phase_test.go create mode 100644 internal/quic-go/protocol/packet_number.go create mode 100644 internal/quic-go/protocol/packet_number_test.go create mode 100644 internal/quic-go/protocol/params.go create mode 100644 internal/quic-go/protocol/params_test.go create mode 100644 internal/quic-go/protocol/perspective.go create mode 100644 internal/quic-go/protocol/perspective_test.go create mode 100644 internal/quic-go/protocol/protocol.go create mode 100644 internal/quic-go/protocol/protocol_suite_test.go create mode 100644 internal/quic-go/protocol/protocol_test.go create mode 100644 internal/quic-go/protocol/stream.go create mode 100644 internal/quic-go/protocol/stream_test.go create mode 100644 internal/quic-go/protocol/version.go create mode 100644 internal/quic-go/protocol/version_test.go create mode 100644 internal/quic-go/qerr/error_codes.go create mode 100644 internal/quic-go/qerr/errorcodes_test.go create mode 100644 internal/quic-go/qerr/errors.go create mode 100644 internal/quic-go/qerr/errors_suite_test.go create mode 100644 internal/quic-go/qerr/errors_test.go create mode 100644 internal/quic-go/qlog/event.go create mode 100644 internal/quic-go/qlog/event_test.go create mode 100644 internal/quic-go/qlog/frame.go create mode 100644 internal/quic-go/qlog/frame_test.go create mode 100644 internal/quic-go/qlog/packet_header.go create mode 100644 internal/quic-go/qlog/packet_header_test.go create mode 100644 internal/quic-go/qlog/qlog.go create mode 100644 internal/quic-go/qlog/qlog_suite_test.go create mode 100644 internal/quic-go/qlog/qlog_test.go create mode 100644 internal/quic-go/qlog/trace.go create mode 100644 internal/quic-go/qlog/types.go create mode 100644 internal/quic-go/qlog/types_test.go create mode 100644 internal/quic-go/qtls/go116.go create mode 100644 internal/quic-go/qtls/go117.go create mode 100644 internal/quic-go/qtls/go118.go create mode 100644 internal/quic-go/qtls/go119.go rename internal/{ => quic-go}/qtls/go_oldversion.go (80%) create mode 100644 internal/quic-go/qtls/qtls_suite_test.go create mode 100644 internal/quic-go/qtls/qtls_test.go create mode 100644 internal/quic-go/quic_suite_test.go rename internal/{ => quic-go}/quicvarint/io.go (100%) rename internal/{ => quic-go}/quicvarint/io_test.go (100%) rename internal/{ => quic-go}/quicvarint/quicvarint_suite_test.go (100%) rename internal/{ => quic-go}/quicvarint/varint.go (98%) rename internal/{ => quic-go}/quicvarint/varint_test.go (100%) create mode 100644 internal/quic-go/receive_stream.go create mode 100644 internal/quic-go/receive_stream_test.go create mode 100644 internal/quic-go/retransmission_queue.go create mode 100644 internal/quic-go/retransmission_queue_test.go create mode 100644 internal/quic-go/send_conn.go create mode 100644 internal/quic-go/send_conn_test.go create mode 100644 internal/quic-go/send_queue.go create mode 100644 internal/quic-go/send_queue_test.go create mode 100644 internal/quic-go/send_stream.go create mode 100644 internal/quic-go/send_stream_test.go create mode 100644 internal/quic-go/server.go create mode 100644 internal/quic-go/server_test.go create mode 100644 internal/quic-go/stream.go create mode 100644 internal/quic-go/stream_test.go create mode 100644 internal/quic-go/streams_map.go create mode 100644 internal/quic-go/streams_map_generic_helper.go create mode 100644 internal/quic-go/streams_map_incoming_bidi.go create mode 100644 internal/quic-go/streams_map_incoming_generic.go create mode 100644 internal/quic-go/streams_map_incoming_generic_test.go create mode 100644 internal/quic-go/streams_map_incoming_uni.go create mode 100644 internal/quic-go/streams_map_outgoing_bidi.go create mode 100644 internal/quic-go/streams_map_outgoing_generic.go create mode 100644 internal/quic-go/streams_map_outgoing_generic_test.go create mode 100644 internal/quic-go/streams_map_outgoing_uni.go create mode 100644 internal/quic-go/streams_map_test.go create mode 100644 internal/quic-go/sys_conn.go create mode 100644 internal/quic-go/sys_conn_df.go create mode 100644 internal/quic-go/sys_conn_df_linux.go create mode 100644 internal/quic-go/sys_conn_df_windows.go create mode 100644 internal/quic-go/sys_conn_helper_darwin.go create mode 100644 internal/quic-go/sys_conn_helper_freebsd.go create mode 100644 internal/quic-go/sys_conn_helper_linux.go create mode 100644 internal/quic-go/sys_conn_no_oob.go create mode 100644 internal/quic-go/sys_conn_oob.go create mode 100644 internal/quic-go/sys_conn_oob_test.go create mode 100644 internal/quic-go/sys_conn_test.go create mode 100644 internal/quic-go/sys_conn_windows.go create mode 100644 internal/quic-go/sys_conn_windows_test.go create mode 100644 internal/quic-go/testdata/ca.pem create mode 100644 internal/quic-go/testdata/cert.go create mode 100644 internal/quic-go/testdata/cert.pem create mode 100644 internal/quic-go/testdata/cert_test.go create mode 100755 internal/quic-go/testdata/generate_key.sh create mode 100644 internal/quic-go/testdata/priv.key create mode 100644 internal/quic-go/testdata/testdata_suite_test.go create mode 100644 internal/quic-go/testutils/testutils.go create mode 100644 internal/quic-go/token_store.go create mode 100644 internal/quic-go/token_store_test.go create mode 100644 internal/quic-go/tools.go create mode 100644 internal/quic-go/utils/atomic_bool.go create mode 100644 internal/quic-go/utils/atomic_bool_test.go create mode 100644 internal/quic-go/utils/buffered_write_closer.go create mode 100644 internal/quic-go/utils/buffered_write_closer_test.go create mode 100644 internal/quic-go/utils/byteinterval_linkedlist.go create mode 100644 internal/quic-go/utils/byteoder_big_endian_test.go create mode 100644 internal/quic-go/utils/byteorder.go create mode 100644 internal/quic-go/utils/byteorder_big_endian.go rename internal/{ => quic-go}/utils/gen.go (100%) create mode 100644 internal/quic-go/utils/ip.go create mode 100644 internal/quic-go/utils/ip_test.go create mode 100644 internal/quic-go/utils/linkedlist/README.md create mode 100644 internal/quic-go/utils/linkedlist/linkedlist.go rename internal/{ => quic-go}/utils/log.go (97%) rename internal/{ => quic-go}/utils/log_test.go (100%) create mode 100644 internal/quic-go/utils/minmax.go create mode 100644 internal/quic-go/utils/minmax_test.go create mode 100644 internal/quic-go/utils/new_connection_id.go create mode 100644 internal/quic-go/utils/newconnectionid_linkedlist.go create mode 100644 internal/quic-go/utils/packet_interval.go create mode 100644 internal/quic-go/utils/packetinterval_linkedlist.go create mode 100644 internal/quic-go/utils/rand.go create mode 100644 internal/quic-go/utils/rand_test.go create mode 100644 internal/quic-go/utils/rtt_stats.go create mode 100644 internal/quic-go/utils/rtt_stats_test.go create mode 100644 internal/quic-go/utils/streamframe_interval.go create mode 100644 internal/quic-go/utils/timer.go create mode 100644 internal/quic-go/utils/timer_test.go rename internal/{ => quic-go}/utils/utils_suite_test.go (100%) create mode 100644 internal/quic-go/window_update_queue.go create mode 100644 internal/quic-go/window_update_queue_test.go create mode 100644 internal/quic-go/wire/ack_frame.go create mode 100644 internal/quic-go/wire/ack_frame_test.go create mode 100644 internal/quic-go/wire/ack_range.go create mode 100644 internal/quic-go/wire/ack_range_test.go create mode 100644 internal/quic-go/wire/connection_close_frame.go create mode 100644 internal/quic-go/wire/connection_close_frame_test.go create mode 100644 internal/quic-go/wire/crypto_frame.go create mode 100644 internal/quic-go/wire/crypto_frame_test.go create mode 100644 internal/quic-go/wire/data_blocked_frame.go create mode 100644 internal/quic-go/wire/data_blocked_frame_test.go create mode 100644 internal/quic-go/wire/datagram_frame.go create mode 100644 internal/quic-go/wire/datagram_frame_test.go create mode 100644 internal/quic-go/wire/extended_header.go create mode 100644 internal/quic-go/wire/extended_header_test.go create mode 100644 internal/quic-go/wire/frame_parser.go create mode 100644 internal/quic-go/wire/frame_parser_test.go create mode 100644 internal/quic-go/wire/handshake_done_frame.go create mode 100644 internal/quic-go/wire/header.go create mode 100644 internal/quic-go/wire/header_test.go create mode 100644 internal/quic-go/wire/interface.go create mode 100644 internal/quic-go/wire/log.go create mode 100644 internal/quic-go/wire/log_test.go create mode 100644 internal/quic-go/wire/max_data_frame.go create mode 100644 internal/quic-go/wire/max_data_frame_test.go create mode 100644 internal/quic-go/wire/max_stream_data_frame.go create mode 100644 internal/quic-go/wire/max_stream_data_frame_test.go create mode 100644 internal/quic-go/wire/max_streams_frame.go create mode 100644 internal/quic-go/wire/max_streams_frame_test.go create mode 100644 internal/quic-go/wire/new_connection_id_frame.go create mode 100644 internal/quic-go/wire/new_connection_id_frame_test.go create mode 100644 internal/quic-go/wire/new_token_frame.go create mode 100644 internal/quic-go/wire/new_token_frame_test.go create mode 100644 internal/quic-go/wire/path_challenge_frame.go create mode 100644 internal/quic-go/wire/path_challenge_frame_test.go create mode 100644 internal/quic-go/wire/path_response_frame.go create mode 100644 internal/quic-go/wire/path_response_frame_test.go create mode 100644 internal/quic-go/wire/ping_frame.go create mode 100644 internal/quic-go/wire/ping_frame_test.go create mode 100644 internal/quic-go/wire/pool.go create mode 100644 internal/quic-go/wire/pool_test.go create mode 100644 internal/quic-go/wire/reset_stream_frame.go create mode 100644 internal/quic-go/wire/reset_stream_frame_test.go create mode 100644 internal/quic-go/wire/retire_connection_id_frame.go create mode 100644 internal/quic-go/wire/retire_connection_id_frame_test.go create mode 100644 internal/quic-go/wire/stop_sending_frame.go create mode 100644 internal/quic-go/wire/stop_sending_frame_test.go create mode 100644 internal/quic-go/wire/stream_data_blocked_frame.go create mode 100644 internal/quic-go/wire/stream_data_blocked_frame_test.go create mode 100644 internal/quic-go/wire/stream_frame.go create mode 100644 internal/quic-go/wire/stream_frame_test.go create mode 100644 internal/quic-go/wire/streams_blocked_frame.go create mode 100644 internal/quic-go/wire/streams_blocked_frame_test.go create mode 100644 internal/quic-go/wire/transport_parameter_test.go create mode 100644 internal/quic-go/wire/transport_parameters.go create mode 100644 internal/quic-go/wire/version_negotiation.go create mode 100644 internal/quic-go/wire/version_negotiation_test.go create mode 100644 internal/quic-go/wire/wire_suite_test.go diff --git a/go.mod b/go.mod index 2ff010ed..74e6a7da 100644 --- a/go.mod +++ b/go.mod @@ -3,19 +3,24 @@ module github.com/imroc/req/v3 go 1.16 require ( - github.com/fsnotify/fsnotify v1.5.4 // indirect + github.com/francoispqt/gojay v1.2.13 github.com/golang/mock v1.6.0 - github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 - github.com/lucas-clemente/quic-go v0.28.1 - github.com/marten-seemann/qpack v0.2.1 + github.com/marten-seemann/qpack v0.3.0 github.com/marten-seemann/qtls-go1-16 v0.1.5 github.com/marten-seemann/qtls-go1-17 v0.1.2 - github.com/marten-seemann/qtls-go1-18 v0.1.2 - github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 + github.com/marten-seemann/qtls-go1-18 v0.1.3 + github.com/marten-seemann/qtls-go1-19 v0.1.1 github.com/onsi/ginkgo v1.16.5 - github.com/onsi/gomega v1.13.0 - golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect + github.com/onsi/gomega v1.24.1 + golang.org/x/crypto v0.1.0 golang.org/x/net v0.5.0 + golang.org/x/sys v0.4.0 golang.org/x/text v0.6.0 ) + +require ( + github.com/cheekybits/genny v1.0.0 + github.com/fsnotify/fsnotify v1.5.4 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect +) diff --git a/go.sum b/go.sum index 805ec4cc..505b6184 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,9 @@ github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBT github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -21,6 +24,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= @@ -29,6 +33,8 @@ github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmV github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= +github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -56,10 +62,14 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -71,30 +81,27 @@ github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brv github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lucas-clemente/quic-go v0.28.1 h1:Uo0lvVxWg5la9gflIF9lwa39ONq85Xq2D91YNEIslzU= -github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= -github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= +github.com/marten-seemann/qpack v0.3.0 h1:UiWstOgT8+znlkDPOg2+3rIuYXJ2CnGDkGUXN6ki6hE= +github.com/marten-seemann/qpack v0.3.0/go.mod h1:cGfKPBiP4a9EQdxCwEwI/GEeWAsjSekBvx/X8mh58+g= github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= -github.com/marten-seemann/qtls-go1-18 v0.1.2 h1:JH6jmzbduz0ITVQ7ShevK10Av5+jBEKAHMntXmIV7kM= -github.com/marten-seemann/qtls-go1-18 v0.1.2/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= -github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 h1:7m/WlWcSROrcK5NxuXaxYD32BZqe/LEEnBrWcH/cOqQ= -github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= +github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sNlqWoDZnjRIE= +github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -106,15 +113,27 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= -github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= +github.com/onsi/ginkgo/v2 v2.1.4/go.mod h1:um6tUpWM/cxCK3/FK8BXqEiUMUwRgSM4JXG47RKZmLU= +github.com/onsi/ginkgo/v2 v2.1.6/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= +github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= +github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= +github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo= +github.com/onsi/ginkgo/v2 v2.5.0 h1:TRtrvv2vdQqzkwrQ1ke6vtXf7IK34RBUJafIy1wMwls= +github.com/onsi/ginkgo/v2 v2.5.0/go.mod h1:Luc4sArBICYCS8THh8v3i3i5CuSZO+RaQRaJoeNwomw= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= -github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= +github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= +github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro= +github.com/onsi/gomega v1.20.1/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo= +github.com/onsi/gomega v1.21.1/go.mod h1:iYAIXgPSaDHak0LCMA+AWBpIKBr8WZicMxnE8luStNc= +github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM= +github.com/onsi/gomega v1.24.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg= +github.com/onsi/gomega v1.24.1 h1:KORJXNNTzJXzu4ScJWssJfJMnJ+2QJqhoQSRwNlze9E= +github.com/onsi/gomega v1.24.1/go.mod h1:3AOiACssS3/MajrniINInwbfOOtfZvplPzuRSmvt1jM= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -158,6 +177,7 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= @@ -169,16 +189,18 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= -golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= +golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= +golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -191,14 +213,16 @@ golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -223,9 +247,9 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -233,20 +257,28 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -260,8 +292,10 @@ golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= +golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -289,11 +323,11 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= @@ -304,6 +338,8 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/internal/http3/body.go b/internal/http3/body.go index d6e704eb..07b68ff0 100644 --- a/internal/http3/body.go +++ b/internal/http3/body.go @@ -5,7 +5,7 @@ import ( "io" "net" - "github.com/lucas-clemente/quic-go" + "github.com/imroc/req/v3/internal/quic-go" ) // The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by: diff --git a/internal/http3/body_test.go b/internal/http3/body_test.go index ee289d1d..886d391b 100644 --- a/internal/http3/body_test.go +++ b/internal/http3/body_test.go @@ -3,8 +3,8 @@ package http3 import ( "errors" - mockquic "github.com/imroc/req/v3/internal/mocks/quic" - "github.com/lucas-clemente/quic-go" + "github.com/imroc/req/v3/internal/quic-go" + mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" diff --git a/internal/http3/client.go b/internal/http3/client.go index 39032029..2c4ce4c5 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -7,11 +7,11 @@ import ( "errors" "fmt" "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/qtls" - "github.com/imroc/req/v3/internal/quicvarint" - "github.com/imroc/req/v3/internal/utils" - "github.com/lucas-clemente/quic-go" + "github.com/imroc/req/v3/internal/quic-go" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qtls" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/utils" "github.com/marten-seemann/qpack" "io" "net/http" diff --git a/internal/http3/client_test.go b/internal/http3/client_test.go index 5531fe17..213a127f 100644 --- a/internal/http3/client_test.go +++ b/internal/http3/client_test.go @@ -7,20 +7,19 @@ import ( "crypto/tls" "errors" "fmt" + "github.com/marten-seemann/qpack" "io" "io/ioutil" "net/http" "time" - mockquic "github.com/imroc/req/v3/internal/mocks/quic" - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" - "github.com/imroc/req/v3/internal/utils" - "github.com/lucas-clemente/quic-go" + "github.com/imroc/req/v3/internal/quic-go" + mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/utils" "github.com/golang/mock/gomock" - "github.com/marten-seemann/qpack" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) diff --git a/internal/http3/error_codes.go b/internal/http3/error_codes.go index d87eef4a..353a1a8e 100644 --- a/internal/http3/error_codes.go +++ b/internal/http3/error_codes.go @@ -3,7 +3,7 @@ package http3 import ( "fmt" - "github.com/lucas-clemente/quic-go" + "github.com/imroc/req/v3/internal/quic-go" ) type errorCode quic.ApplicationErrorCode diff --git a/internal/http3/frames.go b/internal/http3/frames.go index f7f28913..b0d886d5 100644 --- a/internal/http3/frames.go +++ b/internal/http3/frames.go @@ -7,8 +7,8 @@ import ( "io" "io/ioutil" - "github.com/imroc/req/v3/internal/protocol" - "github.com/imroc/req/v3/internal/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" ) // FrameType is the frame type of a HTTP/3 frame diff --git a/internal/http3/frames_test.go b/internal/http3/frames_test.go index cf0efe9d..1e0146cb 100644 --- a/internal/http3/frames_test.go +++ b/internal/http3/frames_test.go @@ -6,7 +6,7 @@ import ( "fmt" "io" - "github.com/imroc/req/v3/internal/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" diff --git a/internal/http3/http_stream.go b/internal/http3/http_stream.go index 4c69068c..3ff7a61b 100644 --- a/internal/http3/http_stream.go +++ b/internal/http3/http_stream.go @@ -4,7 +4,7 @@ import ( "bytes" "fmt" - "github.com/lucas-clemente/quic-go" + "github.com/imroc/req/v3/internal/quic-go" ) // A Stream is a HTTP/3 stream. diff --git a/internal/http3/http_stream_test.go b/internal/http3/http_stream_test.go index 6d3fef02..f4ccaa6d 100644 --- a/internal/http3/http_stream_test.go +++ b/internal/http3/http_stream_test.go @@ -4,7 +4,7 @@ import ( "bytes" "io" - mockquic "github.com/imroc/req/v3/internal/mocks/quic" + mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" diff --git a/internal/http3/request.go b/internal/http3/request.go index 0b9a7278..dceb96e4 100644 --- a/internal/http3/request.go +++ b/internal/http3/request.go @@ -3,12 +3,11 @@ package http3 import ( "crypto/tls" "errors" + "github.com/marten-seemann/qpack" "net/http" "net/url" "strconv" "strings" - - "github.com/marten-seemann/qpack" ) func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { diff --git a/internal/http3/request_test.go b/internal/http3/request_test.go index f2b84eca..6841abed 100644 --- a/internal/http3/request_test.go +++ b/internal/http3/request_test.go @@ -1,10 +1,10 @@ package http3 import ( + "github.com/marten-seemann/qpack" "net/http" "net/url" - "github.com/marten-seemann/qpack" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index 0c7312b2..bf714210 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/header" + "github.com/marten-seemann/qpack" "io" "net" "net/http" @@ -12,9 +13,8 @@ import ( "strings" "sync" - "github.com/imroc/req/v3/internal/utils" - "github.com/lucas-clemente/quic-go" - "github.com/marten-seemann/qpack" + "github.com/imroc/req/v3/internal/quic-go" + "github.com/imroc/req/v3/internal/quic-go/utils" "golang.org/x/net/http/httpguts" "golang.org/x/net/http2/hpack" "golang.org/x/net/idna" diff --git a/internal/http3/request_writer_test.go b/internal/http3/request_writer_test.go index 345c74cb..d8f1eafe 100644 --- a/internal/http3/request_writer_test.go +++ b/internal/http3/request_writer_test.go @@ -2,15 +2,14 @@ package http3 import ( "bytes" + "github.com/marten-seemann/qpack" "io" "net/http" - mockquic "github.com/imroc/req/v3/internal/mocks/quic" - "github.com/imroc/req/v3/internal/utils" + mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" + "github.com/imroc/req/v3/internal/quic-go/utils" "github.com/golang/mock/gomock" - "github.com/marten-seemann/qpack" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) diff --git a/internal/http3/response_writer.go b/internal/http3/response_writer.go index d1c16dc3..2233c2b0 100644 --- a/internal/http3/response_writer.go +++ b/internal/http3/response_writer.go @@ -3,13 +3,13 @@ package http3 import ( "bufio" "bytes" + "github.com/marten-seemann/qpack" "net/http" "strconv" "strings" - "github.com/imroc/req/v3/internal/utils" - "github.com/lucas-clemente/quic-go" - "github.com/marten-seemann/qpack" + "github.com/imroc/req/v3/internal/quic-go" + "github.com/imroc/req/v3/internal/quic-go/utils" ) type responseWriter struct { diff --git a/internal/http3/response_writer_test.go b/internal/http3/response_writer_test.go index abada013..20a01d9f 100644 --- a/internal/http3/response_writer_test.go +++ b/internal/http3/response_writer_test.go @@ -2,15 +2,14 @@ package http3 import ( "bytes" + "github.com/marten-seemann/qpack" "io" "net/http" - mockquic "github.com/imroc/req/v3/internal/mocks/quic" - "github.com/imroc/req/v3/internal/utils" + mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" + "github.com/imroc/req/v3/internal/quic-go/utils" "github.com/golang/mock/gomock" - "github.com/marten-seemann/qpack" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index 06cf90db..09b5ef1c 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -12,7 +12,7 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go" + "github.com/imroc/req/v3/internal/quic-go" "golang.org/x/net/http/httpguts" ) diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go index 82b7e707..ed1aed90 100644 --- a/internal/http3/roundtrip_test.go +++ b/internal/http3/roundtrip_test.go @@ -11,8 +11,8 @@ import ( "time" "github.com/golang/mock/gomock" - mockquic "github.com/imroc/req/v3/internal/mocks/quic" - "github.com/lucas-clemente/quic-go" + "github.com/imroc/req/v3/internal/quic-go" + mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) diff --git a/internal/http3/server.go b/internal/http3/server.go index 07d15ed2..0ce67b07 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -1,8 +1,8 @@ package http3 import ( - "github.com/imroc/req/v3/internal/protocol" - "github.com/lucas-clemente/quic-go" + "github.com/imroc/req/v3/internal/quic-go" + "github.com/imroc/req/v3/internal/quic-go/protocol" ) // allows mocking of quic.Listen and quic.ListenAddr diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go deleted file mode 100644 index a96d7109..00000000 --- a/internal/mocks/mockgen.go +++ /dev/null @@ -1,20 +0,0 @@ -package mocks - -//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream" -//go:generate sh -c "mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/lucas-clemente/quic-go EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && goimports -w quic/early_conn.go" -//go:generate sh -c "mockgen -package mockquic -destination quic/early_listener.go github.com/lucas-clemente/quic-go EarlyListener" -//go:generate sh -c "mockgen -package mocklogging -destination logging/tracer.go github.com/imroc/req/v3/internal/logging Tracer" -//go:generate sh -c "mockgen -package mocklogging -destination logging/connection_tracer.go github.com/imroc/req/v3/internal/logging ConnectionTracer" -//go:generate sh -c "mockgen -package mocks -destination short_header_sealer.go github.com/imroc/req/v3/internal/handshake ShortHeaderSealer" -//go:generate sh -c "mockgen -package mocks -destination short_header_opener.go github.com/imroc/req/v3/internal/handshake ShortHeaderOpener" -//go:generate sh -c "mockgen -package mocks -destination long_header_opener.go github.com/imroc/req/v3/internal/handshake LongHeaderOpener" -//go:generate sh -c "mockgen -package mocks -destination crypto_setup_tmp.go github.com/imroc/req/v3/internal/handshake CryptoSetup && sed -E 's~github.com/marten-seemann/qtls[[:alnum:]_-]*~github.com/imroc/req/v3/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && goimports -w crypto_setup.go" -//go:generate sh -c "mockgen -package mocks -destination stream_flow_controller.go github.com/imroc/req/v3/internal/flowcontrol StreamFlowController" -//go:generate sh -c "mockgen -package mocks -destination congestion.go github.com/imroc/req/v3/internal/congestion SendAlgorithmWithDebugInfos" -//go:generate sh -c "mockgen -package mocks -destination connection_flow_controller.go github.com/imroc/req/v3/internal/flowcontrol ConnectionFlowController" -//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/imroc/req/v3/internal/ackhandler SentPacketHandler" -//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/imroc/req/v3/internal/ackhandler ReceivedPacketHandler" - -// The following command produces a warning message on OSX, however, it still generates the correct mock file. -// See https://github.com/golang/mock/issues/339 for details. -//go:generate sh -c "mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go deleted file mode 100644 index 366d83ff..00000000 --- a/internal/protocol/protocol.go +++ /dev/null @@ -1,4 +0,0 @@ -package protocol - -// A ByteCount in QUIC -type ByteCount int64 diff --git a/internal/protocol/version.go b/internal/protocol/version.go deleted file mode 100644 index a96656a2..00000000 --- a/internal/protocol/version.go +++ /dev/null @@ -1,13 +0,0 @@ -package protocol - -import ( - "github.com/lucas-clemente/quic-go" -) - -// The version numbers, making grepping easier -const ( - VersionTLS quic.VersionNumber = 0x1 - VersionDraft29 quic.VersionNumber = 0xff00001d - Version1 quic.VersionNumber = 0x1 - Version2 quic.VersionNumber = 0x709a50c4 -) diff --git a/internal/qtls/go116.go b/internal/qtls/go116.go deleted file mode 100644 index 73966c3e..00000000 --- a/internal/qtls/go116.go +++ /dev/null @@ -1,19 +0,0 @@ -//go:build go1.16 && !go1.17 -// +build go1.16,!go1.17 - -package qtls - -import ( - "crypto/tls" - "github.com/marten-seemann/qtls-go1-16" -) - -type ( - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT -) - -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState -} diff --git a/internal/qtls/go117.go b/internal/qtls/go117.go deleted file mode 100644 index 8bbb04df..00000000 --- a/internal/qtls/go117.go +++ /dev/null @@ -1,19 +0,0 @@ -//go:build go1.17 && !go1.18 -// +build go1.17,!go1.18 - -package qtls - -import ( - "crypto/tls" - "github.com/marten-seemann/qtls-go1-17" -) - -type ( - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT -) - -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState -} diff --git a/internal/qtls/go118.go b/internal/qtls/go118.go deleted file mode 100644 index 312bec4e..00000000 --- a/internal/qtls/go118.go +++ /dev/null @@ -1,19 +0,0 @@ -//go:build go1.18 && !go1.19 -// +build go1.18,!go1.19 - -package qtls - -import ( - "crypto/tls" - "github.com/marten-seemann/qtls-go1-18" -) - -type ( - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT -) - -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState -} diff --git a/internal/qtls/go119.go b/internal/qtls/go119.go deleted file mode 100644 index 56e65c7d..00000000 --- a/internal/qtls/go119.go +++ /dev/null @@ -1,19 +0,0 @@ -//go:build go1.19 -// +build go1.19 - -package qtls - -import ( - "crypto/tls" - "github.com/marten-seemann/qtls-go1-19" -) - -type ( - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT -) - -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState -} diff --git a/internal/qtls/go120.go b/internal/qtls/go120.go deleted file mode 100644 index f8d59d8b..00000000 --- a/internal/qtls/go120.go +++ /dev/null @@ -1,6 +0,0 @@ -//go:build go1.20 -// +build go1.20 - -package qtls - -var _ int = "The version of quic-go you're using can't be built on Go 1.20 yet. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." diff --git a/internal/quic-go/ackhandler/ack_eliciting.go b/internal/quic-go/ackhandler/ack_eliciting.go new file mode 100644 index 00000000..76e8cc01 --- /dev/null +++ b/internal/quic-go/ackhandler/ack_eliciting.go @@ -0,0 +1,20 @@ +package ackhandler + +import "github.com/imroc/req/v3/internal/quic-go/wire" + +// IsFrameAckEliciting returns true if the frame is ack-eliciting. +func IsFrameAckEliciting(f wire.Frame) bool { + _, isAck := f.(*wire.AckFrame) + _, isConnectionClose := f.(*wire.ConnectionCloseFrame) + return !isAck && !isConnectionClose +} + +// HasAckElicitingFrames returns true if at least one frame is ack-eliciting. +func HasAckElicitingFrames(fs []Frame) bool { + for _, f := range fs { + if IsFrameAckEliciting(f.Frame) { + return true + } + } + return false +} diff --git a/internal/quic-go/ackhandler/ack_eliciting_test.go b/internal/quic-go/ackhandler/ack_eliciting_test.go new file mode 100644 index 00000000..625899f7 --- /dev/null +++ b/internal/quic-go/ackhandler/ack_eliciting_test.go @@ -0,0 +1,34 @@ +package ackhandler + +import ( + "reflect" + + "github.com/imroc/req/v3/internal/quic-go/wire" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("ack-eliciting frames", func() { + for fl, el := range map[wire.Frame]bool{ + &wire.AckFrame{}: false, + &wire.ConnectionCloseFrame{}: false, + &wire.DataBlockedFrame{}: true, + &wire.PingFrame{}: true, + &wire.ResetStreamFrame{}: true, + &wire.StreamFrame{}: true, + &wire.MaxDataFrame{}: true, + &wire.MaxStreamDataFrame{}: true, + } { + f := fl + e := el + fName := reflect.ValueOf(f).Elem().Type().Name() + + It("works for "+fName, func() { + Expect(IsFrameAckEliciting(f)).To(Equal(e)) + }) + + It("HasAckElicitingFrames works for "+fName, func() { + Expect(HasAckElicitingFrames([]Frame{{Frame: f}})).To(Equal(e)) + }) + } +}) diff --git a/internal/quic-go/ackhandler/ackhandler.go b/internal/quic-go/ackhandler/ackhandler.go new file mode 100644 index 00000000..c5ebb712 --- /dev/null +++ b/internal/quic-go/ackhandler/ackhandler.go @@ -0,0 +1,21 @@ +package ackhandler + +import ( + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler +func NewAckHandler( + initialPacketNumber protocol.PacketNumber, + initialMaxDatagramSize protocol.ByteCount, + rttStats *utils.RTTStats, + pers protocol.Perspective, + tracer logging.ConnectionTracer, + logger utils.Logger, + version protocol.VersionNumber, +) (SentPacketHandler, ReceivedPacketHandler) { + sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, pers, tracer, logger) + return sph, newReceivedPacketHandler(sph, rttStats, logger, version) +} diff --git a/internal/quic-go/ackhandler/ackhandler_suite_test.go b/internal/quic-go/ackhandler/ackhandler_suite_test.go new file mode 100644 index 00000000..17481188 --- /dev/null +++ b/internal/quic-go/ackhandler/ackhandler_suite_test.go @@ -0,0 +1,29 @@ +package ackhandler + +import ( + "math/rand" + "testing" + + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestCrypto(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "AckHandler Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeSuite(func() { + rand.Seed(GinkgoRandomSeed()) +}) + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) diff --git a/internal/quic-go/ackhandler/frame.go b/internal/quic-go/ackhandler/frame.go new file mode 100644 index 00000000..98866e91 --- /dev/null +++ b/internal/quic-go/ackhandler/frame.go @@ -0,0 +1,9 @@ +package ackhandler + +import "github.com/imroc/req/v3/internal/quic-go/wire" + +type Frame struct { + wire.Frame // nil if the frame has already been acknowledged in another packet + OnLost func(wire.Frame) + OnAcked func(wire.Frame) +} diff --git a/internal/quic-go/ackhandler/gen.go b/internal/quic-go/ackhandler/gen.go new file mode 100644 index 00000000..32235f81 --- /dev/null +++ b/internal/quic-go/ackhandler/gen.go @@ -0,0 +1,3 @@ +package ackhandler + +//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet diff --git a/internal/quic-go/ackhandler/interfaces.go b/internal/quic-go/ackhandler/interfaces.go new file mode 100644 index 00000000..aa54bf53 --- /dev/null +++ b/internal/quic-go/ackhandler/interfaces.go @@ -0,0 +1,68 @@ +package ackhandler + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// A Packet is a packet +type Packet struct { + PacketNumber protocol.PacketNumber + Frames []Frame + LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK + Length protocol.ByteCount + EncryptionLevel protocol.EncryptionLevel + SendTime time.Time + + IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller. + + includedInBytesInFlight bool + declaredLost bool + skippedPacket bool +} + +// SentPacketHandler handles ACKs received for outgoing packets +type SentPacketHandler interface { + // SentPacket may modify the packet + SentPacket(packet *Packet) + ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) + ReceivedBytes(protocol.ByteCount) + DropPackets(protocol.EncryptionLevel) + ResetForRetry() error + SetHandshakeConfirmed() + + // The SendMode determines if and what kind of packets can be sent. + SendMode() SendMode + // TimeUntilSend is the time when the next packet should be sent. + // It is used for pacing packets. + TimeUntilSend() time.Time + // HasPacingBudget says if the pacer allows sending of a (full size) packet at this moment. + HasPacingBudget() bool + SetMaxDatagramSize(count protocol.ByteCount) + + // only to be called once the handshake is complete + QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ + + PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) + PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber + + GetLossDetectionTimeout() time.Time + OnLossDetectionTimeout() error +} + +type sentPacketTracker interface { + GetLowestPacketNotConfirmedAcked() protocol.PacketNumber + ReceivedPacket(protocol.EncryptionLevel) +} + +// ReceivedPacketHandler handles ACKs needed to send for incoming packets +type ReceivedPacketHandler interface { + IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool + ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error + DropPackets(protocol.EncryptionLevel) + + GetAlarmTimeout() time.Time + GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame +} diff --git a/internal/quic-go/ackhandler/mock_sent_packet_tracker_test.go b/internal/quic-go/ackhandler/mock_sent_packet_tracker_test.go new file mode 100644 index 00000000..01eb17df --- /dev/null +++ b/internal/quic-go/ackhandler/mock_sent_packet_tracker_test.go @@ -0,0 +1,61 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: interfaces.go + +// Package ackhandler is a generated GoMock package. +package ackhandler + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockSentPacketTracker is a mock of SentPacketTracker interface. +type MockSentPacketTracker struct { + ctrl *gomock.Controller + recorder *MockSentPacketTrackerMockRecorder +} + +// MockSentPacketTrackerMockRecorder is the mock recorder for MockSentPacketTracker. +type MockSentPacketTrackerMockRecorder struct { + mock *MockSentPacketTracker +} + +// NewMockSentPacketTracker creates a new mock instance. +func NewMockSentPacketTracker(ctrl *gomock.Controller) *MockSentPacketTracker { + mock := &MockSentPacketTracker{ctrl: ctrl} + mock.recorder = &MockSentPacketTrackerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSentPacketTracker) EXPECT() *MockSentPacketTrackerMockRecorder { + return m.recorder +} + +// GetLowestPacketNotConfirmedAcked mocks base method. +func (m *MockSentPacketTracker) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLowestPacketNotConfirmedAcked") + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// GetLowestPacketNotConfirmedAcked indicates an expected call of GetLowestPacketNotConfirmedAcked. +func (mr *MockSentPacketTrackerMockRecorder) GetLowestPacketNotConfirmedAcked() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketTracker)(nil).GetLowestPacketNotConfirmedAcked)) +} + +// ReceivedPacket mocks base method. +func (m *MockSentPacketTracker) ReceivedPacket(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedPacket", arg0) +} + +// ReceivedPacket indicates an expected call of ReceivedPacket. +func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0) +} diff --git a/internal/quic-go/ackhandler/mockgen.go b/internal/quic-go/ackhandler/mockgen.go new file mode 100644 index 00000000..8c5e33c0 --- /dev/null +++ b/internal/quic-go/ackhandler/mockgen.go @@ -0,0 +1,3 @@ +package ackhandler + +//go:generate sh -c "../../mockgen_private.sh ackhandler mock_sent_packet_tracker_test.go github.com/imroc/req/v3/internal/quic-go/ackhandler sentPacketTracker" diff --git a/internal/quic-go/ackhandler/packet_linkedlist.go b/internal/quic-go/ackhandler/packet_linkedlist.go new file mode 100644 index 00000000..bb74f4ef --- /dev/null +++ b/internal/quic-go/ackhandler/packet_linkedlist.go @@ -0,0 +1,217 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package ackhandler + +// Linked list implementation from the Go standard library. + +// PacketElement is an element of a linked list. +type PacketElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *PacketElement + + // The list to which this element belongs. + list *PacketList + + // The value stored with this element. + Value Packet +} + +// Next returns the next list element or nil. +func (e *PacketElement) Next() *PacketElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *PacketElement) Prev() *PacketElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// PacketList is a linked list of Packets. +type PacketList struct { + root PacketElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *PacketList) Init() *PacketList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewPacketList returns an initialized list. +func NewPacketList() *PacketList { return new(PacketList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *PacketList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *PacketList) Front() *PacketElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *PacketList) Back() *PacketElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *PacketList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *PacketList) insert(e, at *PacketElement) *PacketElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement { + return l.insert(&PacketElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *PacketList) remove(e *PacketElement) *PacketElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *PacketList) Remove(e *PacketElement) Packet { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *PacketList) PushFront(v Packet) *PacketElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *PacketList) PushBack(v Packet) *PacketElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *PacketList) MoveToFront(e *PacketElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *PacketList) MoveToBack(e *PacketElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *PacketList) MoveBefore(e, mark *PacketElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *PacketList) MoveAfter(e, mark *PacketElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *PacketList) PushBackList(other *PacketList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *PacketList) PushFrontList(other *PacketList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/internal/quic-go/ackhandler/packet_number_generator.go b/internal/quic-go/ackhandler/packet_number_generator.go new file mode 100644 index 00000000..0d81f6d1 --- /dev/null +++ b/internal/quic-go/ackhandler/packet_number_generator.go @@ -0,0 +1,76 @@ +package ackhandler + +import ( + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +type packetNumberGenerator interface { + Peek() protocol.PacketNumber + Pop() protocol.PacketNumber +} + +type sequentialPacketNumberGenerator struct { + next protocol.PacketNumber +} + +var _ packetNumberGenerator = &sequentialPacketNumberGenerator{} + +func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator { + return &sequentialPacketNumberGenerator{next: initial} +} + +func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber { + return p.next +} + +func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber { + next := p.next + p.next++ + return next +} + +// The skippingPacketNumberGenerator generates the packet number for the next packet +// it randomly skips a packet number every averagePeriod packets (on average). +// It is guaranteed to never skip two consecutive packet numbers. +type skippingPacketNumberGenerator struct { + period protocol.PacketNumber + maxPeriod protocol.PacketNumber + + next protocol.PacketNumber + nextToSkip protocol.PacketNumber + + rng utils.Rand +} + +var _ packetNumberGenerator = &skippingPacketNumberGenerator{} + +func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator { + g := &skippingPacketNumberGenerator{ + next: initial, + period: initialPeriod, + maxPeriod: maxPeriod, + } + g.generateNewSkip() + return g +} + +func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber { + return p.next +} + +func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber { + next := p.next + p.next++ // generate a new packet number for the next packet + if p.next == p.nextToSkip { + p.next++ + p.generateNewSkip() + } + return next +} + +func (p *skippingPacketNumberGenerator) generateNewSkip() { + // make sure that there are never two consecutive packet numbers that are skipped + p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) + p.period = utils.MinPacketNumber(2*p.period, p.maxPeriod) +} diff --git a/internal/quic-go/ackhandler/packet_number_generator_test.go b/internal/quic-go/ackhandler/packet_number_generator_test.go new file mode 100644 index 00000000..db4d096d --- /dev/null +++ b/internal/quic-go/ackhandler/packet_number_generator_test.go @@ -0,0 +1,99 @@ +package ackhandler + +import ( + "fmt" + "math" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Sequential Packet Number Generator", func() { + It("generates sequential packet numbers", func() { + const initialPN protocol.PacketNumber = 123 + png := newSequentialPacketNumberGenerator(initialPN) + + for i := initialPN; i < initialPN+1000; i++ { + Expect(png.Peek()).To(Equal(i)) + Expect(png.Peek()).To(Equal(i)) + Expect(png.Pop()).To(Equal(i)) + } + }) +}) + +var _ = Describe("Skipping Packet Number Generator", func() { + const initialPN protocol.PacketNumber = 8 + const initialPeriod protocol.PacketNumber = 25 + const maxPeriod protocol.PacketNumber = 300 + + It("uses a maximum period that is sufficiently small such that using a 32-bit random number is ok", func() { + Expect(2 * protocol.SkipPacketMaxPeriod).To(BeNumerically("<", math.MaxInt32)) + }) + + It("can be initialized to return any first packet number", func() { + png := newSkippingPacketNumberGenerator(12345, initialPeriod, maxPeriod) + Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345))) + }) + + It("allows peeking", func() { + png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod).(*skippingPacketNumberGenerator) + png.nextToSkip = 1000 + Expect(png.Peek()).To(Equal(initialPN)) + Expect(png.Peek()).To(Equal(initialPN)) + Expect(png.Pop()).To(Equal(initialPN)) + Expect(png.Peek()).To(Equal(initialPN + 1)) + Expect(png.Peek()).To(Equal(initialPN + 1)) + }) + + It("skips a packet number", func() { + png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod) + var last protocol.PacketNumber + var skipped bool + for i := 0; i < 1000; i++ { + num := png.Pop() + if num > last+1 { + skipped = true + break + } + last = num + } + Expect(skipped).To(BeTrue()) + }) + + It("generates a new packet number to skip", func() { + const rep = 2500 + periods := make([][]protocol.PacketNumber, rep) + expectedPeriods := []protocol.PacketNumber{25, 50, 100, 200, 300, 300, 300} + + for i := 0; i < rep; i++ { + png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod) + last := initialPN + lastSkip := initialPN + for len(periods[i]) < len(expectedPeriods) { + next := png.Pop() + if next > last+1 { + skipped := next - 1 + Expect(skipped).To(BeNumerically(">", lastSkip+1)) + periods[i] = append(periods[i], skipped-lastSkip-1) + lastSkip = skipped + } + last = next + } + } + + for j := 0; j < len(expectedPeriods); j++ { + var average float64 + for i := 0; i < rep; i++ { + average += float64(periods[i][j]) / float64(len(periods)) + } + fmt.Fprintf(GinkgoWriter, "Period %d: %.2f (expected %d)\n", j, average, expectedPeriods[j]) + tolerance := protocol.PacketNumber(5) + if t := expectedPeriods[j] / 10; t > tolerance { + tolerance = t + } + Expect(average).To(BeNumerically("~", expectedPeriods[j]+1 /* we never skip two packet numbers at the same time */, tolerance)) + } + }) +}) diff --git a/internal/quic-go/ackhandler/received_packet_handler.go b/internal/quic-go/ackhandler/received_packet_handler.go new file mode 100644 index 00000000..39e45da4 --- /dev/null +++ b/internal/quic-go/ackhandler/received_packet_handler.go @@ -0,0 +1,136 @@ +package ackhandler + +import ( + "fmt" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type receivedPacketHandler struct { + sentPackets sentPacketTracker + + initialPackets *receivedPacketTracker + handshakePackets *receivedPacketTracker + appDataPackets *receivedPacketTracker + + lowest1RTTPacket protocol.PacketNumber +} + +var _ ReceivedPacketHandler = &receivedPacketHandler{} + +func newReceivedPacketHandler( + sentPackets sentPacketTracker, + rttStats *utils.RTTStats, + logger utils.Logger, + version protocol.VersionNumber, +) ReceivedPacketHandler { + return &receivedPacketHandler{ + sentPackets: sentPackets, + initialPackets: newReceivedPacketTracker(rttStats, logger, version), + handshakePackets: newReceivedPacketTracker(rttStats, logger, version), + appDataPackets: newReceivedPacketTracker(rttStats, logger, version), + lowest1RTTPacket: protocol.InvalidPacketNumber, + } +} + +func (h *receivedPacketHandler) ReceivedPacket( + pn protocol.PacketNumber, + ecn protocol.ECN, + encLevel protocol.EncryptionLevel, + rcvTime time.Time, + shouldInstigateAck bool, +) error { + h.sentPackets.ReceivedPacket(encLevel) + switch encLevel { + case protocol.EncryptionInitial: + h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + case protocol.EncryptionHandshake: + h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + case protocol.Encryption0RTT: + if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { + return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket) + } + h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + case protocol.Encryption1RTT: + if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { + h.lowest1RTTPacket = pn + } + h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked()) + h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + default: + panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel)) + } + return nil +} + +func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { + //nolint:exhaustive // 1-RTT packet number space is never dropped. + switch encLevel { + case protocol.EncryptionInitial: + h.initialPackets = nil + case protocol.EncryptionHandshake: + h.handshakePackets = nil + case protocol.Encryption0RTT: + // Nothing to do here. + // If we are rejecting 0-RTT, no 0-RTT packets will have been decrypted. + default: + panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) + } +} + +func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { + var initialAlarm, handshakeAlarm time.Time + if h.initialPackets != nil { + initialAlarm = h.initialPackets.GetAlarmTimeout() + } + if h.handshakePackets != nil { + handshakeAlarm = h.handshakePackets.GetAlarmTimeout() + } + oneRTTAlarm := h.appDataPackets.GetAlarmTimeout() + return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm) +} + +func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame { + var ack *wire.AckFrame + //nolint:exhaustive // 0-RTT packets can't contain ACK frames. + switch encLevel { + case protocol.EncryptionInitial: + if h.initialPackets != nil { + ack = h.initialPackets.GetAckFrame(onlyIfQueued) + } + case protocol.EncryptionHandshake: + if h.handshakePackets != nil { + ack = h.handshakePackets.GetAckFrame(onlyIfQueued) + } + case protocol.Encryption1RTT: + // 0-RTT packets can't contain ACK frames + return h.appDataPackets.GetAckFrame(onlyIfQueued) + default: + return nil + } + // For Initial and Handshake ACKs, the delay time is ignored by the receiver. + // Set it to 0 in order to save bytes. + if ack != nil { + ack.DelayTime = 0 + } + return ack +} + +func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool { + switch encLevel { + case protocol.EncryptionInitial: + if h.initialPackets != nil { + return h.initialPackets.IsPotentiallyDuplicate(pn) + } + case protocol.EncryptionHandshake: + if h.handshakePackets != nil { + return h.handshakePackets.IsPotentiallyDuplicate(pn) + } + case protocol.Encryption0RTT, protocol.Encryption1RTT: + return h.appDataPackets.IsPotentiallyDuplicate(pn) + } + panic("unexpected encryption level") +} diff --git a/internal/quic-go/ackhandler/received_packet_handler_test.go b/internal/quic-go/ackhandler/received_packet_handler_test.go new file mode 100644 index 00000000..b852b068 --- /dev/null +++ b/internal/quic-go/ackhandler/received_packet_handler_test.go @@ -0,0 +1,168 @@ +package ackhandler + +import ( + "time" + + "github.com/golang/mock/gomock" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Received Packet Handler", func() { + var handler ReceivedPacketHandler + var sentPackets *MockSentPacketTracker + + BeforeEach(func() { + sentPackets = NewMockSentPacketTracker(mockCtrl) + handler = newReceivedPacketHandler( + sentPackets, + &utils.RTTStats{}, + utils.DefaultLogger, + protocol.VersionWhatever, + ) + }) + + It("generates ACKs for different packet number spaces", func() { + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() + sendTime := time.Now().Add(-time.Second) + sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionInitial).Times(2) + sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionHandshake).Times(2) + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT).Times(2) + Expect(handler.ReceivedPacket(2, protocol.ECT0, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(1, protocol.ECT1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(5, protocol.ECNCE, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(3, protocol.ECT0, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.ECT1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(4, protocol.ECNCE, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + initialAck := handler.GetAckFrame(protocol.EncryptionInitial, true) + Expect(initialAck).ToNot(BeNil()) + Expect(initialAck.AckRanges).To(HaveLen(1)) + Expect(initialAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 2, Largest: 3})) + Expect(initialAck.DelayTime).To(BeZero()) + Expect(initialAck.ECT0).To(BeEquivalentTo(2)) + Expect(initialAck.ECT1).To(BeZero()) + Expect(initialAck.ECNCE).To(BeZero()) + handshakeAck := handler.GetAckFrame(protocol.EncryptionHandshake, true) + Expect(handshakeAck).ToNot(BeNil()) + Expect(handshakeAck.AckRanges).To(HaveLen(1)) + Expect(handshakeAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 1, Largest: 2})) + Expect(handshakeAck.DelayTime).To(BeZero()) + Expect(handshakeAck.ECT0).To(BeZero()) + Expect(handshakeAck.ECT1).To(BeEquivalentTo(2)) + Expect(handshakeAck.ECNCE).To(BeZero()) + oneRTTAck := handler.GetAckFrame(protocol.Encryption1RTT, true) + Expect(oneRTTAck).ToNot(BeNil()) + Expect(oneRTTAck.AckRanges).To(HaveLen(1)) + Expect(oneRTTAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) + Expect(oneRTTAck.DelayTime).To(BeNumerically("~", time.Second, 50*time.Millisecond)) + Expect(oneRTTAck.ECT0).To(BeZero()) + Expect(oneRTTAck.ECT1).To(BeZero()) + Expect(oneRTTAck.ECNCE).To(BeEquivalentTo(2)) + }) + + It("uses the same packet number space for 0-RTT and 1-RTT packets", func() { + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption0RTT) + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT) + sendTime := time.Now().Add(-time.Second) + Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + ack := handler.GetAckFrame(protocol.Encryption1RTT, true) + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(HaveLen(1)) + Expect(ack.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 2, Largest: 3})) + }) + + It("rejects 0-RTT packets with higher packet numbers than 1-RTT packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3) + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() + sendTime := time.Now() + Expect(handler.ReceivedPacket(10, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(11, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(12, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(MatchError("received packet number 12 on a 0-RTT packet after receiving 11 on a 1-RTT packet")) + }) + + It("allows reordered 0-RTT packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3) + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() + sendTime := time.Now() + Expect(handler.ReceivedPacket(10, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(12, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(11, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + }) + + It("drops Initial packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2) + sendTime := time.Now().Add(-time.Second) + Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(1, protocol.ECNNon, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.GetAckFrame(protocol.EncryptionInitial, true)).ToNot(BeNil()) + handler.DropPackets(protocol.EncryptionInitial) + Expect(handler.GetAckFrame(protocol.EncryptionInitial, true)).To(BeNil()) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake, true)).ToNot(BeNil()) + }) + + It("drops Handshake packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2) + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() + sendTime := time.Now().Add(-time.Second) + Expect(handler.ReceivedPacket(1, protocol.ECNNon, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake, true)).ToNot(BeNil()) + handler.DropPackets(protocol.EncryptionInitial) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake, true)).To(BeNil()) + Expect(handler.GetAckFrame(protocol.Encryption1RTT, true)).ToNot(BeNil()) + }) + + It("does nothing when dropping 0-RTT packets", func() { + handler.DropPackets(protocol.Encryption0RTT) + }) + + It("drops old ACK ranges", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes() + sendTime := time.Now() + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Times(2) + Expect(handler.ReceivedPacket(1, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + ack := handler.GetAckFrame(protocol.Encryption1RTT, true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2))) + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked() + Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Return(protocol.PacketNumber(2)) + Expect(handler.ReceivedPacket(4, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + ack = handler.GetAckFrame(protocol.Encryption1RTT, true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(2))) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4))) + }) + + It("says if packets are duplicates", func() { + sendTime := time.Now() + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes() + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() + // Initial + Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionInitial)).To(BeFalse()) + Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionInitial)).To(BeTrue()) + // Handshake + Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionHandshake)).To(BeFalse()) + Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionHandshake)).To(BeTrue()) + // 0-RTT + Expect(handler.IsPotentiallyDuplicate(3, protocol.Encryption0RTT)).To(BeFalse()) + Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.IsPotentiallyDuplicate(3, protocol.Encryption0RTT)).To(BeTrue()) + // 1-RTT + Expect(handler.IsPotentiallyDuplicate(3, protocol.Encryption1RTT)).To(BeTrue()) + Expect(handler.IsPotentiallyDuplicate(4, protocol.Encryption1RTT)).To(BeFalse()) + Expect(handler.ReceivedPacket(4, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.IsPotentiallyDuplicate(4, protocol.Encryption1RTT)).To(BeTrue()) + }) +}) diff --git a/internal/quic-go/ackhandler/received_packet_history.go b/internal/quic-go/ackhandler/received_packet_history.go new file mode 100644 index 00000000..f3bcc25e --- /dev/null +++ b/internal/quic-go/ackhandler/received_packet_history.go @@ -0,0 +1,142 @@ +package ackhandler + +import ( + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// The receivedPacketHistory stores if a packet number has already been received. +// It generates ACK ranges which can be used to assemble an ACK frame. +// It does not store packet contents. +type receivedPacketHistory struct { + ranges *utils.PacketIntervalList + + deletedBelow protocol.PacketNumber +} + +func newReceivedPacketHistory() *receivedPacketHistory { + return &receivedPacketHistory{ + ranges: utils.NewPacketIntervalList(), + } +} + +// ReceivedPacket registers a packet with PacketNumber p and updates the ranges +func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ { + // ignore delayed packets, if we already deleted the range + if p < h.deletedBelow { + return false + } + isNew := h.addToRanges(p) + h.maybeDeleteOldRanges() + return isNew +} + +func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ { + if h.ranges.Len() == 0 { + h.ranges.PushBack(utils.PacketInterval{Start: p, End: p}) + return true + } + + for el := h.ranges.Back(); el != nil; el = el.Prev() { + // p already included in an existing range. Nothing to do here + if p >= el.Value.Start && p <= el.Value.End { + return false + } + + if el.Value.End == p-1 { // extend a range at the end + el.Value.End = p + return true + } + if el.Value.Start == p+1 { // extend a range at the beginning + el.Value.Start = p + + prev := el.Prev() + if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges + prev.Value.End = el.Value.End + h.ranges.Remove(el) + } + return true + } + + // create a new range at the end + if p > el.Value.End { + h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el) + return true + } + } + + // create a new range at the beginning + h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front()) + return true +} + +// Delete old ranges, if we're tracking more than 500 of them. +// This is a DoS defense against a peer that sends us too many gaps. +func (h *receivedPacketHistory) maybeDeleteOldRanges() { + for h.ranges.Len() > protocol.MaxNumAckRanges { + h.ranges.Remove(h.ranges.Front()) + } +} + +// DeleteBelow deletes all entries below (but not including) p +func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) { + if p < h.deletedBelow { + return + } + h.deletedBelow = p + + nextEl := h.ranges.Front() + for el := h.ranges.Front(); nextEl != nil; el = nextEl { + nextEl = el.Next() + + if el.Value.End < p { // delete a whole range + h.ranges.Remove(el) + } else if p > el.Value.Start && p <= el.Value.End { + el.Value.Start = p + return + } else { // no ranges affected. Nothing to do + return + } + } +} + +// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame +func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange { + if h.ranges.Len() == 0 { + return nil + } + + ackRanges := make([]wire.AckRange, h.ranges.Len()) + i := 0 + for el := h.ranges.Back(); el != nil; el = el.Prev() { + ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End} + i++ + } + return ackRanges +} + +func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange { + ackRange := wire.AckRange{} + if h.ranges.Len() > 0 { + r := h.ranges.Back().Value + ackRange.Smallest = r.Start + ackRange.Largest = r.End + } + return ackRange +} + +func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber) bool { + if p < h.deletedBelow { + return true + } + for el := h.ranges.Back(); el != nil; el = el.Prev() { + if p > el.Value.End { + return false + } + if p <= el.Value.End && p >= el.Value.Start { + return true + } + } + return false +} diff --git a/internal/quic-go/ackhandler/received_packet_history_test.go b/internal/quic-go/ackhandler/received_packet_history_test.go new file mode 100644 index 00000000..9994b489 --- /dev/null +++ b/internal/quic-go/ackhandler/received_packet_history_test.go @@ -0,0 +1,354 @@ +package ackhandler + +import ( + "fmt" + "math/rand" + "sort" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("receivedPacketHistory", func() { + var hist *receivedPacketHistory + + BeforeEach(func() { + hist = newReceivedPacketHistory() + }) + + Context("ranges", func() { + It("adds the first packet", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) + }) + + It("doesn't care about duplicate packets", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(4)).To(BeFalse()) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) + }) + + It("adds a few consecutive packets", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) + }) + + It("doesn't care about a duplicate packet contained in an existing range", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeFalse()) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) + }) + + It("extends a range at the front", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(3)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 3, End: 4})) + }) + + It("creates a new range when a packet is lost", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(2)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) + Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6})) + }) + + It("creates a new range in between two ranges", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(2)) + Expect(hist.ReceivedPacket(7)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(3)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) + Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7})) + Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) + }) + + It("creates a new range before an existing range for a belated packet", func() { + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(2)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) + Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6})) + }) + + It("extends a previous range at the end", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(7)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(2)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 5})) + Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7})) + }) + + It("extends a range at the front", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(7)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(2)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) + Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 7})) + }) + + It("closes a range", func() { + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(2)) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) + }) + + It("closes a range in the middle", func() { + Expect(hist.ReceivedPacket(1)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(4)) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ranges.Len()).To(Equal(3)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 1, End: 1})) + Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) + Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) + }) + }) + + Context("deleting", func() { + It("does nothing when the history is empty", func() { + hist.DeleteBelow(5) + Expect(hist.ranges.Len()).To(BeZero()) + }) + + It("deletes a range", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) + hist.DeleteBelow(6) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) + }) + + It("deletes multiple ranges", func() { + Expect(hist.ReceivedPacket(1)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) + hist.DeleteBelow(8) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) + }) + + It("adjusts a range, if packets are delete from an existing range", func() { + Expect(hist.ReceivedPacket(3)).To(BeTrue()) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(7)).To(BeTrue()) + hist.DeleteBelow(5) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 7})) + }) + + It("adjusts a range, if only one packet remains in the range", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) + hist.DeleteBelow(5) + Expect(hist.ranges.Len()).To(Equal(2)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 5})) + Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) + }) + + It("keeps a one-packet range, if deleting up to the packet directly below", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + hist.DeleteBelow(4) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) + }) + + It("doesn't add delayed packets below deleted ranges", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + hist.DeleteBelow(5) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 6})) + Expect(hist.ReceivedPacket(2)).To(BeFalse()) + Expect(hist.ranges.Len()).To(Equal(1)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 6})) + }) + + It("doesn't create more than MaxNumAckRanges ranges", func() { + for i := protocol.PacketNumber(0); i < protocol.MaxNumAckRanges; i++ { + Expect(hist.ReceivedPacket(2 * i)).To(BeTrue()) + } + Expect(hist.ranges.Len()).To(Equal(protocol.MaxNumAckRanges)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 0, End: 0})) + hist.ReceivedPacket(2*protocol.MaxNumAckRanges + 1000) + // check that the oldest ACK range was deleted + Expect(hist.ranges.Len()).To(Equal(protocol.MaxNumAckRanges)) + Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 2, End: 2})) + }) + }) + + Context("ACK range export", func() { + It("returns nil if there are no ranges", func() { + Expect(hist.GetAckRanges()).To(BeNil()) + }) + + It("gets a single ACK range", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + ackRanges := hist.GetAckRanges() + Expect(ackRanges).To(HaveLen(1)) + Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) + }) + + It("gets multiple ACK ranges", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(1)).To(BeTrue()) + Expect(hist.ReceivedPacket(11)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) + Expect(hist.ReceivedPacket(2)).To(BeTrue()) + ackRanges := hist.GetAckRanges() + Expect(ackRanges).To(HaveLen(3)) + Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 10, Largest: 11})) + Expect(ackRanges[1]).To(Equal(wire.AckRange{Smallest: 4, Largest: 6})) + Expect(ackRanges[2]).To(Equal(wire.AckRange{Smallest: 1, Largest: 2})) + }) + }) + + Context("Getting the highest ACK range", func() { + It("returns the zero value if there are no ranges", func() { + Expect(hist.GetHighestAckRange()).To(BeZero()) + }) + + It("gets a single ACK range", func() { + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) + }) + + It("gets the highest of multiple ACK ranges", func() { + Expect(hist.ReceivedPacket(3)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(7)).To(BeTrue()) + Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 6, Largest: 7})) + }) + }) + + Context("duplicate detection", func() { + It("doesn't declare the first packet a duplicate", func() { + Expect(hist.IsPotentiallyDuplicate(5)).To(BeFalse()) + }) + + It("detects a duplicate in a range", func() { + hist.ReceivedPacket(4) + hist.ReceivedPacket(5) + hist.ReceivedPacket(6) + Expect(hist.IsPotentiallyDuplicate(3)).To(BeFalse()) + Expect(hist.IsPotentiallyDuplicate(4)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(5)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(6)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(7)).To(BeFalse()) + }) + + It("detects a duplicate in multiple ranges", func() { + hist.ReceivedPacket(4) + hist.ReceivedPacket(5) + hist.ReceivedPacket(8) + hist.ReceivedPacket(9) + Expect(hist.IsPotentiallyDuplicate(3)).To(BeFalse()) + Expect(hist.IsPotentiallyDuplicate(4)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(5)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(6)).To(BeFalse()) + Expect(hist.IsPotentiallyDuplicate(7)).To(BeFalse()) + Expect(hist.IsPotentiallyDuplicate(8)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(9)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(10)).To(BeFalse()) + }) + + It("says a packet is a potentially duplicate if the ranges were already deleted", func() { + hist.ReceivedPacket(4) + hist.ReceivedPacket(5) + hist.ReceivedPacket(8) + hist.ReceivedPacket(9) + hist.ReceivedPacket(11) + hist.DeleteBelow(8) + Expect(hist.IsPotentiallyDuplicate(3)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(4)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(5)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(6)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(7)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(8)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(9)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(10)).To(BeFalse()) + Expect(hist.IsPotentiallyDuplicate(11)).To(BeTrue()) + Expect(hist.IsPotentiallyDuplicate(12)).To(BeFalse()) + }) + }) + + Context("randomized receiving", func() { + It("receiving packets in a random order, with gaps", func() { + packets := make(map[protocol.PacketNumber]int) + // Make sure we never end up with more than protocol.MaxNumAckRanges ACK ranges, even + // when we're receiving packets in a random order. + const num = 2 * protocol.MaxNumAckRanges + numLostPackets := rand.Intn(protocol.MaxNumAckRanges) + numRcvdPackets := num - numLostPackets + + for i := 0; i < num; i++ { + packets[protocol.PacketNumber(i)] = 0 + } + lostPackets := make([]protocol.PacketNumber, 0, numLostPackets) + for len(lostPackets) < numLostPackets { + p := protocol.PacketNumber(rand.Intn(num)) + if _, ok := packets[p]; ok { + lostPackets = append(lostPackets, p) + delete(packets, p) + } + } + sort.Slice(lostPackets, func(i, j int) bool { return lostPackets[i] < lostPackets[j] }) + fmt.Fprintf(GinkgoWriter, "Losing packets: %v\n", lostPackets) + + ordered := make([]protocol.PacketNumber, 0, numRcvdPackets) + for p := range packets { + ordered = append(ordered, p) + } + rand.Shuffle(len(ordered), func(i, j int) { ordered[i], ordered[j] = ordered[j], ordered[i] }) + + fmt.Fprintf(GinkgoWriter, "Receiving packets: %v\n", ordered) + for i, p := range ordered { + Expect(hist.ReceivedPacket(p)).To(BeTrue()) + // sometimes receive a duplicate + if i > 0 && rand.Int()%5 == 0 { + Expect(hist.ReceivedPacket(ordered[rand.Intn(i)])).To(BeFalse()) + } + } + var counter int + ackRanges := hist.GetAckRanges() + fmt.Fprintf(GinkgoWriter, "ACK ranges: %v\n", ackRanges) + Expect(len(ackRanges)).To(BeNumerically("<=", numLostPackets+1)) + for _, ackRange := range ackRanges { + for p := ackRange.Smallest; p <= ackRange.Largest; p++ { + counter++ + Expect(packets).To(HaveKey(p)) + } + } + Expect(counter).To(Equal(numRcvdPackets)) + }) + }) +}) diff --git a/internal/quic-go/ackhandler/received_packet_tracker.go b/internal/quic-go/ackhandler/received_packet_tracker.go new file mode 100644 index 00000000..31882311 --- /dev/null +++ b/internal/quic-go/ackhandler/received_packet_tracker.go @@ -0,0 +1,196 @@ +package ackhandler + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// number of ack-eliciting packets received before sending an ack. +const packetsBeforeAck = 2 + +type receivedPacketTracker struct { + largestObserved protocol.PacketNumber + ignoreBelow protocol.PacketNumber + largestObservedReceivedTime time.Time + ect0, ect1, ecnce uint64 + + packetHistory *receivedPacketHistory + + maxAckDelay time.Duration + rttStats *utils.RTTStats + + hasNewAck bool // true as soon as we received an ack-eliciting new packet + ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets + + ackElicitingPacketsReceivedSinceLastAck int + ackAlarm time.Time + lastAck *wire.AckFrame + + logger utils.Logger + + version protocol.VersionNumber +} + +func newReceivedPacketTracker( + rttStats *utils.RTTStats, + logger utils.Logger, + version protocol.VersionNumber, +) *receivedPacketTracker { + return &receivedPacketTracker{ + packetHistory: newReceivedPacketHistory(), + maxAckDelay: protocol.MaxAckDelay, + rttStats: rttStats, + logger: logger, + version: version, + } +} + +func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) { + if packetNumber < h.ignoreBelow { + return + } + + isMissing := h.isMissing(packetNumber) + if packetNumber >= h.largestObserved { + h.largestObserved = packetNumber + h.largestObservedReceivedTime = rcvTime + } + + if isNew := h.packetHistory.ReceivedPacket(packetNumber); isNew && shouldInstigateAck { + h.hasNewAck = true + } + if shouldInstigateAck { + h.maybeQueueAck(packetNumber, rcvTime, isMissing) + } + switch ecn { + case protocol.ECNNon: + case protocol.ECT0: + h.ect0++ + case protocol.ECT1: + h.ect1++ + case protocol.ECNCE: + h.ecnce++ + } +} + +// IgnoreBelow sets a lower limit for acknowledging packets. +// Packets with packet numbers smaller than p will not be acked. +func (h *receivedPacketTracker) IgnoreBelow(p protocol.PacketNumber) { + if p <= h.ignoreBelow { + return + } + h.ignoreBelow = p + h.packetHistory.DeleteBelow(p) + if h.logger.Debug() { + h.logger.Debugf("\tIgnoring all packets below %d.", p) + } +} + +// isMissing says if a packet was reported missing in the last ACK. +func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool { + if h.lastAck == nil || p < h.ignoreBelow { + return false + } + return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p) +} + +func (h *receivedPacketTracker) hasNewMissingPackets() bool { + if h.lastAck == nil { + return false + } + highestRange := h.packetHistory.GetHighestAckRange() + return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1 +} + +// maybeQueueAck queues an ACK, if necessary. +func (h *receivedPacketTracker) maybeQueueAck(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) { + // always acknowledge the first packet + if h.lastAck == nil { + if !h.ackQueued { + h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") + } + h.ackQueued = true + return + } + + if h.ackQueued { + return + } + + h.ackElicitingPacketsReceivedSinceLastAck++ + + // Send an ACK if this packet was reported missing in an ACK sent before. + // Ack decimation with reordering relies on the timer to send an ACK, but if + // missing packets we reported in the previous ack, send an ACK immediately. + if wasMissing { + if h.logger.Debug() { + h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn) + } + h.ackQueued = true + } + + // send an ACK every 2 ack-eliciting packets + if h.ackElicitingPacketsReceivedSinceLastAck >= packetsBeforeAck { + if h.logger.Debug() { + h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck) + } + h.ackQueued = true + } else if h.ackAlarm.IsZero() { + if h.logger.Debug() { + h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay) + } + h.ackAlarm = rcvTime.Add(h.maxAckDelay) + } + + // Queue an ACK if there are new missing packets to report. + if h.hasNewMissingPackets() { + h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.") + h.ackQueued = true + } + + if h.ackQueued { + // cancel the ack alarm + h.ackAlarm = time.Time{} + } +} + +func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { + if !h.hasNewAck { + return nil + } + now := time.Now() + if onlyIfQueued { + if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) { + return nil + } + if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() { + h.logger.Debugf("Sending ACK because the ACK timer expired.") + } + } + + ack := &wire.AckFrame{ + AckRanges: h.packetHistory.GetAckRanges(), + // Make sure that the DelayTime is always positive. + // This is not guaranteed on systems that don't have a monotonic clock. + DelayTime: utils.MaxDuration(0, now.Sub(h.largestObservedReceivedTime)), + ECT0: h.ect0, + ECT1: h.ect1, + ECNCE: h.ecnce, + } + + h.lastAck = ack + h.ackAlarm = time.Time{} + h.ackQueued = false + h.hasNewAck = false + h.ackElicitingPacketsReceivedSinceLastAck = 0 + return ack +} + +func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm } + +func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool { + return h.packetHistory.IsPotentiallyDuplicate(pn) +} diff --git a/internal/quic-go/ackhandler/received_packet_tracker_test.go b/internal/quic-go/ackhandler/received_packet_tracker_test.go new file mode 100644 index 00000000..66b43cde --- /dev/null +++ b/internal/quic-go/ackhandler/received_packet_tracker_test.go @@ -0,0 +1,348 @@ +package ackhandler + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Received Packet Tracker", func() { + var ( + tracker *receivedPacketTracker + rttStats *utils.RTTStats + ) + + BeforeEach(func() { + rttStats = &utils.RTTStats{} + tracker = newReceivedPacketTracker(rttStats, utils.DefaultLogger, protocol.VersionWhatever) + }) + + Context("accepting packets", func() { + It("saves the time when each packet arrived", func() { + tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, time.Now(), true) + Expect(tracker.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) + }) + + It("updates the largestObserved and the largestObservedReceivedTime", func() { + now := time.Now() + tracker.largestObserved = 3 + tracker.largestObservedReceivedTime = now.Add(-1 * time.Second) + tracker.ReceivedPacket(5, protocol.ECNNon, now, true) + Expect(tracker.largestObserved).To(Equal(protocol.PacketNumber(5))) + Expect(tracker.largestObservedReceivedTime).To(Equal(now)) + }) + + It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() { + now := time.Now() + timestamp := now.Add(-1 * time.Second) + tracker.largestObserved = 5 + tracker.largestObservedReceivedTime = timestamp + tracker.ReceivedPacket(4, protocol.ECNNon, now, true) + Expect(tracker.largestObserved).To(Equal(protocol.PacketNumber(5))) + Expect(tracker.largestObservedReceivedTime).To(Equal(timestamp)) + }) + }) + + Context("ACKs", func() { + Context("queueing ACKs", func() { + receiveAndAck10Packets := func() { + for i := 1; i <= 10; i++ { + tracker.ReceivedPacket(protocol.PacketNumber(i), protocol.ECNNon, time.Time{}, true) + } + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) + Expect(tracker.ackQueued).To(BeFalse()) + } + + It("always queues an ACK for the first packet", func() { + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + Expect(tracker.ackQueued).To(BeTrue()) + Expect(tracker.GetAlarmTimeout()).To(BeZero()) + Expect(tracker.GetAckFrame(true).DelayTime).To(BeNumerically("~", 0, time.Second)) + }) + + It("works with packet number 0", func() { + tracker.ReceivedPacket(0, protocol.ECNNon, time.Now(), true) + Expect(tracker.ackQueued).To(BeTrue()) + Expect(tracker.GetAlarmTimeout()).To(BeZero()) + Expect(tracker.GetAckFrame(true).DelayTime).To(BeNumerically("~", 0, time.Second)) + }) + + It("sets ECN flags", func() { + tracker.ReceivedPacket(0, protocol.ECT0, time.Now(), true) + pn := protocol.PacketNumber(1) + for i := 0; i < 2; i++ { + tracker.ReceivedPacket(pn, protocol.ECT1, time.Now(), true) + pn++ + } + for i := 0; i < 3; i++ { + tracker.ReceivedPacket(pn, protocol.ECNCE, time.Now(), true) + pn++ + } + ack := tracker.GetAckFrame(false) + Expect(ack.ECT0).To(BeEquivalentTo(1)) + Expect(ack.ECT1).To(BeEquivalentTo(2)) + Expect(ack.ECNCE).To(BeEquivalentTo(3)) + }) + + It("queues an ACK for every second ack-eliciting packet", func() { + receiveAndAck10Packets() + p := protocol.PacketNumber(11) + for i := 0; i <= 20; i++ { + tracker.ReceivedPacket(p, protocol.ECNNon, time.Time{}, true) + Expect(tracker.ackQueued).To(BeFalse()) + p++ + tracker.ReceivedPacket(p, protocol.ECNNon, time.Time{}, true) + Expect(tracker.ackQueued).To(BeTrue()) + p++ + // dequeue the ACK frame + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) + } + }) + + It("resets the counter when a non-queued ACK frame is generated", func() { + receiveAndAck10Packets() + rcvTime := time.Now() + tracker.ReceivedPacket(11, protocol.ECNNon, rcvTime, true) + Expect(tracker.GetAckFrame(false)).ToNot(BeNil()) + tracker.ReceivedPacket(12, protocol.ECNNon, rcvTime, true) + Expect(tracker.GetAckFrame(true)).To(BeNil()) + tracker.ReceivedPacket(13, protocol.ECNNon, rcvTime, true) + Expect(tracker.GetAckFrame(false)).ToNot(BeNil()) + }) + + It("only sets the timer when receiving a ack-eliciting packets", func() { + receiveAndAck10Packets() + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), false) + Expect(tracker.ackQueued).To(BeFalse()) + Expect(tracker.GetAlarmTimeout()).To(BeZero()) + rcvTime := time.Now().Add(10 * time.Millisecond) + tracker.ReceivedPacket(12, protocol.ECNNon, rcvTime, true) + Expect(tracker.ackQueued).To(BeFalse()) + Expect(tracker.GetAlarmTimeout()).To(Equal(rcvTime.Add(protocol.MaxAckDelay))) + }) + + It("queues an ACK if it was reported missing before", func() { + receiveAndAck10Packets() + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(13, protocol.ECNNon, time.Now(), true) + ack := tracker.GetAckFrame(true) // ACK: 1-11 and 13, missing: 12 + Expect(ack).ToNot(BeNil()) + Expect(ack.HasMissingRanges()).To(BeTrue()) + Expect(tracker.ackQueued).To(BeFalse()) + tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), true) + Expect(tracker.ackQueued).To(BeTrue()) + }) + + It("doesn't queue an ACK if it was reported missing before, but is below the threshold", func() { + receiveAndAck10Packets() + // 11 is missing + tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(13, protocol.ECNNon, time.Now(), true) + ack := tracker.GetAckFrame(true) // ACK: 1-10, 12-13 + Expect(ack).ToNot(BeNil()) + // now receive 11 + tracker.IgnoreBelow(12) + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), false) + ack = tracker.GetAckFrame(true) + Expect(ack).To(BeNil()) + }) + + It("doesn't recognize in-order packets as out-of-order after raising the threshold", func() { + receiveAndAck10Packets() + Expect(tracker.lastAck.LargestAcked()).To(Equal(protocol.PacketNumber(10))) + Expect(tracker.ackQueued).To(BeFalse()) + tracker.IgnoreBelow(11) + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) + Expect(tracker.GetAckFrame(true)).To(BeNil()) + }) + + It("recognizes out-of-order packets after raising the threshold", func() { + receiveAndAck10Packets() + Expect(tracker.lastAck.LargestAcked()).To(Equal(protocol.PacketNumber(10))) + Expect(tracker.ackQueued).To(BeFalse()) + tracker.IgnoreBelow(11) + tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{{Smallest: 12, Largest: 12}})) + }) + + It("doesn't queue an ACK if for non-ack-eliciting packets arriving out-of-order", func() { + receiveAndAck10Packets() + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) + Expect(tracker.GetAckFrame(true)).To(BeNil()) + tracker.ReceivedPacket(13, protocol.ECNNon, time.Now(), false) // receive a non-ack-eliciting packet out-of-order + Expect(tracker.GetAckFrame(true)).To(BeNil()) + }) + + It("doesn't queue an ACK if packets arrive out-of-order, but haven't been acknowledged yet", func() { + receiveAndAck10Packets() + Expect(tracker.lastAck).ToNot(BeNil()) + tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), false) + Expect(tracker.GetAckFrame(true)).To(BeNil()) + // 11 is received out-of-order, but this hasn't been reported in an ACK frame yet + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) + Expect(tracker.GetAckFrame(true)).To(BeNil()) + }) + }) + + Context("ACK generation", func() { + It("generates an ACK for an ack-eliciting packet, if no ACK is queued yet", func() { + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + // The first packet is always acknowledged. + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) + }) + + It("doesn't generate ACK for a non-ack-eliciting packet, if no ACK is queued yet", func() { + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + // The first packet is always acknowledged. + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) + + tracker.ReceivedPacket(2, protocol.ECNNon, time.Now(), false) + Expect(tracker.GetAckFrame(false)).To(BeNil()) + tracker.ReceivedPacket(3, protocol.ECNNon, time.Now(), true) + ack := tracker.GetAckFrame(false) + Expect(ack).ToNot(BeNil()) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3))) + }) + + Context("for queued ACKs", func() { + BeforeEach(func() { + tracker.ackQueued = true + }) + + It("generates a simple ACK frame", func() { + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(2, protocol.ECNNon, time.Now(), true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) + Expect(ack.HasMissingRanges()).To(BeFalse()) + }) + + It("generates an ACK for packet number 0", func() { + tracker.ReceivedPacket(0, protocol.ECNNon, time.Now(), true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(0))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0))) + Expect(ack.HasMissingRanges()).To(BeFalse()) + }) + + It("sets the delay time", func() { + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(2, protocol.ECNNon, time.Now().Add(-1337*time.Millisecond), true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.DelayTime).To(BeNumerically("~", 1337*time.Millisecond, 50*time.Millisecond)) + }) + + It("uses a 0 delay time if the delay would be negative", func() { + tracker.ReceivedPacket(0, protocol.ECNNon, time.Now().Add(time.Hour), true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.DelayTime).To(BeZero()) + }) + + It("saves the last sent ACK", func() { + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(tracker.lastAck).To(Equal(ack)) + tracker.ReceivedPacket(2, protocol.ECNNon, time.Now(), true) + tracker.ackQueued = true + ack = tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(tracker.lastAck).To(Equal(ack)) + }) + + It("generates an ACK frame with missing packets", func() { + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(4, protocol.ECNNon, time.Now(), true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{ + {Smallest: 4, Largest: 4}, + {Smallest: 1, Largest: 1}, + })) + }) + + It("generates an ACK for packet number 0 and other packets", func() { + tracker.ReceivedPacket(0, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(3, protocol.ECNNon, time.Now(), true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0))) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{ + {Smallest: 3, Largest: 3}, + {Smallest: 0, Largest: 1}, + })) + }) + + It("doesn't add delayed packets to the packetHistory", func() { + tracker.IgnoreBelow(7) + tracker.ReceivedPacket(4, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(10, protocol.ECNNon, time.Now(), true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(10))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(10))) + }) + + It("deletes packets from the packetHistory when a lower limit is set", func() { + for i := 1; i <= 12; i++ { + tracker.ReceivedPacket(protocol.PacketNumber(i), protocol.ECNNon, time.Now(), true) + } + tracker.IgnoreBelow(7) + // check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(12))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(7))) + Expect(ack.HasMissingRanges()).To(BeFalse()) + }) + + It("resets all counters needed for the ACK queueing decision when sending an ACK", func() { + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ackAlarm = time.Now().Add(-time.Minute) + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) + Expect(tracker.GetAlarmTimeout()).To(BeZero()) + Expect(tracker.ackElicitingPacketsReceivedSinceLastAck).To(BeZero()) + Expect(tracker.ackQueued).To(BeFalse()) + }) + + It("doesn't generate an ACK when none is queued and the timer is not set", func() { + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ackQueued = false + tracker.ackAlarm = time.Time{} + Expect(tracker.GetAckFrame(true)).To(BeNil()) + }) + + It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() { + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ackQueued = false + tracker.ackAlarm = time.Now().Add(time.Minute) + Expect(tracker.GetAckFrame(true)).To(BeNil()) + }) + + It("generates an ACK when the timer has expired", func() { + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ackQueued = false + tracker.ackAlarm = time.Now().Add(-time.Minute) + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) + }) + }) + }) + }) +}) diff --git a/internal/quic-go/ackhandler/send_mode.go b/internal/quic-go/ackhandler/send_mode.go new file mode 100644 index 00000000..3d5fe560 --- /dev/null +++ b/internal/quic-go/ackhandler/send_mode.go @@ -0,0 +1,40 @@ +package ackhandler + +import "fmt" + +// The SendMode says what kind of packets can be sent. +type SendMode uint8 + +const ( + // SendNone means that no packets should be sent + SendNone SendMode = iota + // SendAck means an ACK-only packet should be sent + SendAck + // SendPTOInitial means that an Initial probe packet should be sent + SendPTOInitial + // SendPTOHandshake means that a Handshake probe packet should be sent + SendPTOHandshake + // SendPTOAppData means that an Application data probe packet should be sent + SendPTOAppData + // SendAny means that any packet should be sent + SendAny +) + +func (s SendMode) String() string { + switch s { + case SendNone: + return "none" + case SendAck: + return "ack" + case SendPTOInitial: + return "pto (Initial)" + case SendPTOHandshake: + return "pto (Handshake)" + case SendPTOAppData: + return "pto (Application Data)" + case SendAny: + return "any" + default: + return fmt.Sprintf("invalid send mode: %d", s) + } +} diff --git a/internal/quic-go/ackhandler/send_mode_test.go b/internal/quic-go/ackhandler/send_mode_test.go new file mode 100644 index 00000000..86515d74 --- /dev/null +++ b/internal/quic-go/ackhandler/send_mode_test.go @@ -0,0 +1,18 @@ +package ackhandler + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Send Mode", func() { + It("has a string representation", func() { + Expect(SendNone.String()).To(Equal("none")) + Expect(SendAny.String()).To(Equal("any")) + Expect(SendAck.String()).To(Equal("ack")) + Expect(SendPTOInitial.String()).To(Equal("pto (Initial)")) + Expect(SendPTOHandshake.String()).To(Equal("pto (Handshake)")) + Expect(SendPTOAppData.String()).To(Equal("pto (Application Data)")) + Expect(SendMode(123).String()).To(Equal("invalid send mode: 123")) + }) +}) diff --git a/internal/quic-go/ackhandler/sent_packet_handler.go b/internal/quic-go/ackhandler/sent_packet_handler.go new file mode 100644 index 00000000..2a6b19b8 --- /dev/null +++ b/internal/quic-go/ackhandler/sent_packet_handler.go @@ -0,0 +1,838 @@ +package ackhandler + +import ( + "errors" + "fmt" + "time" + + "github.com/imroc/req/v3/internal/quic-go/congestion" + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +const ( + // Maximum reordering in time space before time based loss detection considers a packet lost. + // Specified as an RTT multiplier. + timeThreshold = 9.0 / 8 + // Maximum reordering in packets before packet threshold loss detection considers a packet lost. + packetThreshold = 3 + // Before validating the client's address, the server won't send more than 3x bytes than it received. + amplificationFactor = 3 + // We use Retry packets to derive an RTT estimate. Make sure we don't set the RTT to a super low value yet. + minRTTAfterRetry = 5 * time.Millisecond +) + +type packetNumberSpace struct { + history *sentPacketHistory + pns packetNumberGenerator + + lossTime time.Time + lastAckElicitingPacketTime time.Time + + largestAcked protocol.PacketNumber + largestSent protocol.PacketNumber +} + +func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStats *utils.RTTStats) *packetNumberSpace { + var pns packetNumberGenerator + if skipPNs { + pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod) + } else { + pns = newSequentialPacketNumberGenerator(initialPN) + } + return &packetNumberSpace{ + history: newSentPacketHistory(rttStats), + pns: pns, + largestSent: protocol.InvalidPacketNumber, + largestAcked: protocol.InvalidPacketNumber, + } +} + +type sentPacketHandler struct { + initialPackets *packetNumberSpace + handshakePackets *packetNumberSpace + appDataPackets *packetNumberSpace + + // Do we know that the peer completed address validation yet? + // Always true for the server. + peerCompletedAddressValidation bool + bytesReceived protocol.ByteCount + bytesSent protocol.ByteCount + // Have we validated the peer's address yet? + // Always true for the client. + peerAddressValidated bool + + handshakeConfirmed bool + + // lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived + // example: we send an ACK for packets 90-100 with packet number 20 + // once we receive an ACK from the peer for packet 20, the lowestNotConfirmedAcked is 101 + // Only applies to the application-data packet number space. + lowestNotConfirmedAcked protocol.PacketNumber + + ackedPackets []*Packet // to avoid allocations in detectAndRemoveAckedPackets + + bytesInFlight protocol.ByteCount + + congestion congestion.SendAlgorithmWithDebugInfos + rttStats *utils.RTTStats + + // The number of times a PTO has been sent without receiving an ack. + ptoCount uint32 + ptoMode SendMode + // The number of PTO probe packets that should be sent. + // Only applies to the application-data packet number space. + numProbesToSend int + + // The alarm timeout + alarm time.Time + + perspective protocol.Perspective + + tracer logging.ConnectionTracer + logger utils.Logger +} + +var ( + _ SentPacketHandler = &sentPacketHandler{} + _ sentPacketTracker = &sentPacketHandler{} +) + +func newSentPacketHandler( + initialPN protocol.PacketNumber, + initialMaxDatagramSize protocol.ByteCount, + rttStats *utils.RTTStats, + pers protocol.Perspective, + tracer logging.ConnectionTracer, + logger utils.Logger, +) *sentPacketHandler { + congestion := congestion.NewCubicSender( + congestion.DefaultClock{}, + rttStats, + initialMaxDatagramSize, + true, // use Reno + tracer, + ) + + return &sentPacketHandler{ + peerCompletedAddressValidation: pers == protocol.PerspectiveServer, + peerAddressValidated: pers == protocol.PerspectiveClient, + initialPackets: newPacketNumberSpace(initialPN, false, rttStats), + handshakePackets: newPacketNumberSpace(0, false, rttStats), + appDataPackets: newPacketNumberSpace(0, true, rttStats), + rttStats: rttStats, + congestion: congestion, + perspective: pers, + tracer: tracer, + logger: logger, + } +} + +func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { + if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionInitial { + // This function is called when the crypto setup seals a Handshake packet. + // If this Handshake packet is coalesced behind an Initial packet, we would drop the Initial packet number space + // before SentPacket() was called for that Initial packet. + return + } + h.dropPackets(encLevel) +} + +func (h *sentPacketHandler) removeFromBytesInFlight(p *Packet) { + if p.includedInBytesInFlight { + if p.Length > h.bytesInFlight { + panic("negative bytes_in_flight") + } + h.bytesInFlight -= p.Length + p.includedInBytesInFlight = false + } +} + +func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { + // The server won't await address validation after the handshake is confirmed. + // This applies even if we didn't receive an ACK for a Handshake packet. + if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake { + h.peerCompletedAddressValidation = true + } + // remove outstanding packets from bytes_in_flight + if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake { + pnSpace := h.getPacketNumberSpace(encLevel) + pnSpace.history.Iterate(func(p *Packet) (bool, error) { + h.removeFromBytesInFlight(p) + return true, nil + }) + } + // drop the packet history + //nolint:exhaustive // Not every packet number space can be dropped. + switch encLevel { + case protocol.EncryptionInitial: + h.initialPackets = nil + case protocol.EncryptionHandshake: + h.handshakePackets = nil + case protocol.Encryption0RTT: + // This function is only called when 0-RTT is rejected, + // and not when the client drops 0-RTT keys when the handshake completes. + // When 0-RTT is rejected, all application data sent so far becomes invalid. + // Delete the packets from the history and remove them from bytes_in_flight. + h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { + if p.EncryptionLevel != protocol.Encryption0RTT { + return false, nil + } + h.removeFromBytesInFlight(p) + h.appDataPackets.history.Remove(p.PacketNumber) + return true, nil + }) + default: + panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) + } + if h.tracer != nil && h.ptoCount != 0 { + h.tracer.UpdatedPTOCount(0) + } + h.ptoCount = 0 + h.numProbesToSend = 0 + h.ptoMode = SendNone + h.setLossDetectionTimer() +} + +func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) { + wasAmplificationLimit := h.isAmplificationLimited() + h.bytesReceived += n + if wasAmplificationLimit && !h.isAmplificationLimited() { + h.setLossDetectionTimer() + } +} + +func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) { + if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated { + h.peerAddressValidated = true + h.setLossDetectionTimer() + } +} + +func (h *sentPacketHandler) packetsInFlight() int { + packetsInFlight := h.appDataPackets.history.Len() + if h.handshakePackets != nil { + packetsInFlight += h.handshakePackets.history.Len() + } + if h.initialPackets != nil { + packetsInFlight += h.initialPackets.history.Len() + } + return packetsInFlight +} + +func (h *sentPacketHandler) SentPacket(packet *Packet) { + h.bytesSent += packet.Length + // For the client, drop the Initial packet number space when the first Handshake packet is sent. + if h.perspective == protocol.PerspectiveClient && packet.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { + h.dropPackets(protocol.EncryptionInitial) + } + isAckEliciting := h.sentPacketImpl(packet) + h.getPacketNumberSpace(packet.EncryptionLevel).history.SentPacket(packet, isAckEliciting) + if h.tracer != nil && isAckEliciting { + h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) + } + if isAckEliciting || !h.peerCompletedAddressValidation { + h.setLossDetectionTimer() + } +} + +func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace { + switch encLevel { + case protocol.EncryptionInitial: + return h.initialPackets + case protocol.EncryptionHandshake: + return h.handshakePackets + case protocol.Encryption0RTT, protocol.Encryption1RTT: + return h.appDataPackets + default: + panic("invalid packet number space") + } +} + +func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-eliciting */ { + pnSpace := h.getPacketNumberSpace(packet.EncryptionLevel) + + if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() { + for p := utils.MaxPacketNumber(0, pnSpace.largestSent+1); p < packet.PacketNumber; p++ { + h.logger.Debugf("Skipping packet number %d", p) + } + } + + pnSpace.largestSent = packet.PacketNumber + isAckEliciting := len(packet.Frames) > 0 + + if isAckEliciting { + pnSpace.lastAckElicitingPacketTime = packet.SendTime + packet.includedInBytesInFlight = true + h.bytesInFlight += packet.Length + if h.numProbesToSend > 0 { + h.numProbesToSend-- + } + } + h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isAckEliciting) + + return isAckEliciting +} + +func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) { + pnSpace := h.getPacketNumberSpace(encLevel) + + largestAcked := ack.LargestAcked() + if largestAcked > pnSpace.largestSent { + return false, &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received ACK for an unsent packet", + } + } + + pnSpace.largestAcked = utils.MaxPacketNumber(pnSpace.largestAcked, largestAcked) + + // Servers complete address validation when a protected packet is received. + if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation && + (encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) { + h.peerCompletedAddressValidation = true + h.logger.Debugf("Peer doesn't await address validation any longer.") + // Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets. + h.setLossDetectionTimer() + } + + priorInFlight := h.bytesInFlight + ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel) + if err != nil || len(ackedPackets) == 0 { + return false, err + } + // update the RTT, if the largest acked is newly acknowledged + if len(ackedPackets) > 0 { + if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() { + // don't use the ack delay for Initial and Handshake packets + var ackDelay time.Duration + if encLevel == protocol.Encryption1RTT { + ackDelay = utils.MinDuration(ack.DelayTime, h.rttStats.MaxAckDelay()) + } + h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime) + if h.logger.Debug() { + h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) + } + h.congestion.MaybeExitSlowStart() + } + } + if err := h.detectLostPackets(rcvTime, encLevel); err != nil { + return false, err + } + var acked1RTTPacket bool + for _, p := range ackedPackets { + if p.includedInBytesInFlight && !p.declaredLost { + h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) + } + if p.EncryptionLevel == protocol.Encryption1RTT { + acked1RTTPacket = true + } + h.removeFromBytesInFlight(p) + } + + // Reset the pto_count unless the client is unsure if the server has validated the client's address. + if h.peerCompletedAddressValidation { + if h.tracer != nil && h.ptoCount != 0 { + h.tracer.UpdatedPTOCount(0) + } + h.ptoCount = 0 + } + h.numProbesToSend = 0 + + if h.tracer != nil { + h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) + } + + pnSpace.history.DeleteOldPackets(rcvTime) + h.setLossDetectionTimer() + return acked1RTTPacket, nil +} + +func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { + return h.lowestNotConfirmedAcked +} + +// Packets are returned in ascending packet number order. +func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) { + pnSpace := h.getPacketNumberSpace(encLevel) + h.ackedPackets = h.ackedPackets[:0] + ackRangeIndex := 0 + lowestAcked := ack.LowestAcked() + largestAcked := ack.LargestAcked() + err := pnSpace.history.Iterate(func(p *Packet) (bool, error) { + // Ignore packets below the lowest acked + if p.PacketNumber < lowestAcked { + return true, nil + } + // Break after largest acked is reached + if p.PacketNumber > largestAcked { + return false, nil + } + + if ack.HasMissingRanges() { + ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex] + + for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 { + ackRangeIndex++ + ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex] + } + + if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range + return true, nil + } + if p.PacketNumber > ackRange.Largest { + return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest) + } + } + if p.skippedPacket { + return false, &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel), + } + } + h.ackedPackets = append(h.ackedPackets, p) + return true, nil + }) + if h.logger.Debug() && len(h.ackedPackets) > 0 { + pns := make([]protocol.PacketNumber, len(h.ackedPackets)) + for i, p := range h.ackedPackets { + pns[i] = p.PacketNumber + } + h.logger.Debugf("\tnewly acked packets (%d): %d", len(pns), pns) + } + + for _, p := range h.ackedPackets { + if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { + h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1) + } + + for _, f := range p.Frames { + if f.OnAcked != nil { + f.OnAcked(f.Frame) + } + } + if err := pnSpace.history.Remove(p.PacketNumber); err != nil { + return nil, err + } + if h.tracer != nil { + h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber) + } + } + + return h.ackedPackets, err +} + +func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) { + var encLevel protocol.EncryptionLevel + var lossTime time.Time + + if h.initialPackets != nil { + lossTime = h.initialPackets.lossTime + encLevel = protocol.EncryptionInitial + } + if h.handshakePackets != nil && (lossTime.IsZero() || (!h.handshakePackets.lossTime.IsZero() && h.handshakePackets.lossTime.Before(lossTime))) { + lossTime = h.handshakePackets.lossTime + encLevel = protocol.EncryptionHandshake + } + if lossTime.IsZero() || (!h.appDataPackets.lossTime.IsZero() && h.appDataPackets.lossTime.Before(lossTime)) { + lossTime = h.appDataPackets.lossTime + encLevel = protocol.Encryption1RTT + } + return lossTime, encLevel +} + +// same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime +func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) { + // We only send application data probe packets once the handshake is confirmed, + // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. + if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() { + if h.peerCompletedAddressValidation { + return + } + t := time.Now().Add(h.rttStats.PTO(false) << h.ptoCount) + if h.initialPackets != nil { + return t, protocol.EncryptionInitial, true + } + return t, protocol.EncryptionHandshake, true + } + + if h.initialPackets != nil { + encLevel = protocol.EncryptionInitial + if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() { + pto = t.Add(h.rttStats.PTO(false) << h.ptoCount) + } + } + if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() { + t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(false) << h.ptoCount) + if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { + pto = t + encLevel = protocol.EncryptionHandshake + } + } + if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { + t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount) + if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { + pto = t + encLevel = protocol.Encryption1RTT + } + } + return pto, encLevel, true +} + +func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { + if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() { + return true + } + if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() { + return true + } + return false +} + +func (h *sentPacketHandler) hasOutstandingPackets() bool { + return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets() +} + +func (h *sentPacketHandler) setLossDetectionTimer() { + oldAlarm := h.alarm // only needed in case tracing is enabled + lossTime, encLevel := h.getLossTimeAndSpace() + if !lossTime.IsZero() { + // Early retransmit timer or time loss detection. + h.alarm = lossTime + if h.tracer != nil && h.alarm != oldAlarm { + h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm) + } + return + } + + // Cancel the alarm if amplification limited. + if h.isAmplificationLimited() { + h.alarm = time.Time{} + if !oldAlarm.IsZero() { + h.logger.Debugf("Canceling loss detection timer. Amplification limited.") + if h.tracer != nil { + h.tracer.LossTimerCanceled() + } + } + return + } + + // Cancel the alarm if no packets are outstanding + if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation { + h.alarm = time.Time{} + if !oldAlarm.IsZero() { + h.logger.Debugf("Canceling loss detection timer. No packets in flight.") + if h.tracer != nil { + h.tracer.LossTimerCanceled() + } + } + return + } + + // PTO alarm + ptoTime, encLevel, ok := h.getPTOTimeAndSpace() + if !ok { + if !oldAlarm.IsZero() { + h.alarm = time.Time{} + h.logger.Debugf("Canceling loss detection timer. No PTO needed..") + if h.tracer != nil { + h.tracer.LossTimerCanceled() + } + } + return + } + h.alarm = ptoTime + if h.tracer != nil && h.alarm != oldAlarm { + h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm) + } +} + +func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error { + pnSpace := h.getPacketNumberSpace(encLevel) + pnSpace.lossTime = time.Time{} + + maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) + lossDelay := time.Duration(timeThreshold * maxRTT) + + // Minimum time of granularity before packets are deemed lost. + lossDelay = utils.MaxDuration(lossDelay, protocol.TimerGranularity) + + // Packets sent before this time are deemed lost. + lostSendTime := now.Add(-lossDelay) + + priorInFlight := h.bytesInFlight + return pnSpace.history.Iterate(func(p *Packet) (bool, error) { + if p.PacketNumber > pnSpace.largestAcked { + return false, nil + } + if p.declaredLost || p.skippedPacket { + return true, nil + } + + var packetLost bool + if p.SendTime.Before(lostSendTime) { + packetLost = true + if h.logger.Debug() { + h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) + } + if h.tracer != nil { + h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) + } + } else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold { + packetLost = true + if h.logger.Debug() { + h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) + } + if h.tracer != nil { + h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) + } + } else if pnSpace.lossTime.IsZero() { + // Note: This conditional is only entered once per call + lossTime := p.SendTime.Add(lossDelay) + if h.logger.Debug() { + h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", p.PacketNumber, encLevel, lossDelay, lossTime) + } + pnSpace.lossTime = lossTime + } + if packetLost { + p.declaredLost = true + // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted + h.removeFromBytesInFlight(p) + h.queueFramesForRetransmission(p) + if !p.IsPathMTUProbePacket { + h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + } + } + return true, nil + }) +} + +func (h *sentPacketHandler) OnLossDetectionTimeout() error { + defer h.setLossDetectionTimer() + earliestLossTime, encLevel := h.getLossTimeAndSpace() + if !earliestLossTime.IsZero() { + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) + } + if h.tracer != nil { + h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) + } + // Early retransmit or time loss detection + return h.detectLostPackets(time.Now(), encLevel) + } + + // PTO + // When all outstanding are acknowledged, the alarm is canceled in + // setLossDetectionTimer. This doesn't reset the timer in the session though. + // When OnAlarm is called, we therefore need to make sure that there are + // actually packets outstanding. + if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation { + h.ptoCount++ + h.numProbesToSend++ + if h.initialPackets != nil { + h.ptoMode = SendPTOInitial + } else if h.handshakePackets != nil { + h.ptoMode = SendPTOHandshake + } else { + return errors.New("sentPacketHandler BUG: PTO fired, but bytes_in_flight is 0 and Initial and Handshake already dropped") + } + return nil + } + + _, encLevel, ok := h.getPTOTimeAndSpace() + if !ok { + return nil + } + if ps := h.getPacketNumberSpace(encLevel); !ps.history.HasOutstandingPackets() && !h.peerCompletedAddressValidation { + return nil + } + h.ptoCount++ + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) + } + if h.tracer != nil { + h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) + h.tracer.UpdatedPTOCount(h.ptoCount) + } + h.numProbesToSend += 2 + //nolint:exhaustive // We never arm a PTO timer for 0-RTT packets. + switch encLevel { + case protocol.EncryptionInitial: + h.ptoMode = SendPTOInitial + case protocol.EncryptionHandshake: + h.ptoMode = SendPTOHandshake + case protocol.Encryption1RTT: + // skip a packet number in order to elicit an immediate ACK + _ = h.PopPacketNumber(protocol.Encryption1RTT) + h.ptoMode = SendPTOAppData + default: + return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel) + } + return nil +} + +func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { + return h.alarm +} + +func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { + pnSpace := h.getPacketNumberSpace(encLevel) + + var lowestUnacked protocol.PacketNumber + if p := pnSpace.history.FirstOutstanding(); p != nil { + lowestUnacked = p.PacketNumber + } else { + lowestUnacked = pnSpace.largestAcked + 1 + } + + pn := pnSpace.pns.Peek() + return pn, protocol.GetPacketNumberLengthForHeader(pn, lowestUnacked) +} + +func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber { + return h.getPacketNumberSpace(encLevel).pns.Pop() +} + +func (h *sentPacketHandler) SendMode() SendMode { + numTrackedPackets := h.appDataPackets.history.Len() + if h.initialPackets != nil { + numTrackedPackets += h.initialPackets.history.Len() + } + if h.handshakePackets != nil { + numTrackedPackets += h.handshakePackets.history.Len() + } + + if h.isAmplificationLimited() { + h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent) + return SendNone + } + // Don't send any packets if we're keeping track of the maximum number of packets. + // Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets, + // we will stop sending out new data when reaching MaxOutstandingSentPackets, + // but still allow sending of retransmissions and ACKs. + if numTrackedPackets >= protocol.MaxTrackedSentPackets { + if h.logger.Debug() { + h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets) + } + return SendNone + } + if h.numProbesToSend > 0 { + return h.ptoMode + } + // Only send ACKs if we're congestion limited. + if !h.congestion.CanSend(h.bytesInFlight) { + if h.logger.Debug() { + h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, h.congestion.GetCongestionWindow()) + } + return SendAck + } + if numTrackedPackets >= protocol.MaxOutstandingSentPackets { + if h.logger.Debug() { + h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets) + } + return SendAck + } + return SendAny +} + +func (h *sentPacketHandler) TimeUntilSend() time.Time { + return h.congestion.TimeUntilSend(h.bytesInFlight) +} + +func (h *sentPacketHandler) HasPacingBudget() bool { + return h.congestion.HasPacingBudget() +} + +func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) { + h.congestion.SetMaxDatagramSize(s) +} + +func (h *sentPacketHandler) isAmplificationLimited() bool { + if h.peerAddressValidated { + return false + } + return h.bytesSent >= amplificationFactor*h.bytesReceived +} + +func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool { + pnSpace := h.getPacketNumberSpace(encLevel) + p := pnSpace.history.FirstOutstanding() + if p == nil { + return false + } + h.queueFramesForRetransmission(p) + // TODO: don't declare the packet lost here. + // Keep track of acknowledged frames instead. + h.removeFromBytesInFlight(p) + p.declaredLost = true + return true +} + +func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) { + if len(p.Frames) == 0 { + panic("no frames") + } + for _, f := range p.Frames { + f.OnLost(f.Frame) + } + p.Frames = nil +} + +func (h *sentPacketHandler) ResetForRetry() error { + h.bytesInFlight = 0 + var firstPacketSendTime time.Time + h.initialPackets.history.Iterate(func(p *Packet) (bool, error) { + if firstPacketSendTime.IsZero() { + firstPacketSendTime = p.SendTime + } + if p.declaredLost || p.skippedPacket { + return true, nil + } + h.queueFramesForRetransmission(p) + return true, nil + }) + // All application data packets sent at this point are 0-RTT packets. + // In the case of a Retry, we can assume that the server dropped all of them. + h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { + if !p.declaredLost && !p.skippedPacket { + h.queueFramesForRetransmission(p) + } + return true, nil + }) + + // Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial. + // Otherwise, we don't know which Initial the Retry was sent in response to. + if h.ptoCount == 0 { + // Don't set the RTT to a value lower than 5ms here. + now := time.Now() + h.rttStats.UpdateRTT(utils.MaxDuration(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) + if h.logger.Debug() { + h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) + } + if h.tracer != nil { + h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) + } + } + h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), false, h.rttStats) + h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), true, h.rttStats) + oldAlarm := h.alarm + h.alarm = time.Time{} + if h.tracer != nil { + h.tracer.UpdatedPTOCount(0) + if !oldAlarm.IsZero() { + h.tracer.LossTimerCanceled() + } + } + h.ptoCount = 0 + return nil +} + +func (h *sentPacketHandler) SetHandshakeConfirmed() { + h.handshakeConfirmed = true + // We don't send PTOs for application data packets before the handshake completes. + // Make sure the timer is armed now, if necessary. + h.setLossDetectionTimer() +} diff --git a/internal/quic-go/ackhandler/sent_packet_handler_test.go b/internal/quic-go/ackhandler/sent_packet_handler_test.go new file mode 100644 index 00000000..e7a19250 --- /dev/null +++ b/internal/quic-go/ackhandler/sent_packet_handler_test.go @@ -0,0 +1,1386 @@ +package ackhandler + +import ( + "fmt" + "time" + + "github.com/golang/mock/gomock" + + "github.com/imroc/req/v3/internal/quic-go/mocks" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("SentPacketHandler", func() { + var ( + handler *sentPacketHandler + streamFrame wire.StreamFrame + lostPackets []protocol.PacketNumber + perspective protocol.Perspective + ) + + BeforeEach(func() { perspective = protocol.PerspectiveServer }) + + JustBeforeEach(func() { + lostPackets = nil + rttStats := utils.NewRTTStats() + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, perspective, nil, utils.DefaultLogger) + streamFrame = wire.StreamFrame{ + StreamID: 5, + Data: []byte{0x13, 0x37}, + } + }) + + getPacket := func(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) *Packet { + if el, ok := handler.getPacketNumberSpace(encLevel).history.packetMap[pn]; ok { + return &el.Value + } + return nil + } + + ackElicitingPacket := func(p *Packet) *Packet { + if p.EncryptionLevel == 0 { + p.EncryptionLevel = protocol.Encryption1RTT + } + if p.Length == 0 { + p.Length = 1 + } + if p.SendTime.IsZero() { + p.SendTime = time.Now() + } + if len(p.Frames) == 0 { + p.Frames = []Frame{ + {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, p.PacketNumber) }}, + } + } + return p + } + + nonAckElicitingPacket := func(p *Packet) *Packet { + p = ackElicitingPacket(p) + p.Frames = nil + p.LargestAcked = 1 + return p + } + + initialPacket := func(p *Packet) *Packet { + p = ackElicitingPacket(p) + p.EncryptionLevel = protocol.EncryptionInitial + return p + } + + handshakePacket := func(p *Packet) *Packet { + p = ackElicitingPacket(p) + p.EncryptionLevel = protocol.EncryptionHandshake + return p + } + + handshakePacketNonAckEliciting := func(p *Packet) *Packet { + p = nonAckElicitingPacket(p) + p.EncryptionLevel = protocol.EncryptionHandshake + return p + } + + expectInPacketHistory := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { + pnSpace := handler.getPacketNumberSpace(encLevel) + var length int + pnSpace.history.Iterate(func(p *Packet) (bool, error) { + if !p.declaredLost && !p.skippedPacket { + length++ + } + return true, nil + }) + ExpectWithOffset(1, length).To(Equal(len(expected))) + for _, p := range expected { + ExpectWithOffset(2, pnSpace.history.packetMap).To(HaveKey(p)) + } + } + + updateRTT := func(rtt time.Duration) { + handler.rttStats.UpdateRTT(rtt, 0, time.Now()) + ExpectWithOffset(1, handler.rttStats.SmoothedRTT()).To(Equal(rtt)) + } + + Context("registering sent packets", func() { + It("accepts two consecutive packets", func() { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, EncryptionLevel: protocol.EncryptionHandshake})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, EncryptionLevel: protocol.EncryptionHandshake})) + Expect(handler.handshakePackets.largestSent).To(Equal(protocol.PacketNumber(2))) + expectInPacketHistory([]protocol.PacketNumber{1, 2}, protocol.EncryptionHandshake) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) + }) + + It("uses the same packet number space for 0-RTT and 1-RTT packets", func() { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, EncryptionLevel: protocol.Encryption0RTT})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, EncryptionLevel: protocol.Encryption1RTT})) + Expect(handler.appDataPackets.largestSent).To(Equal(protocol.PacketNumber(2))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) + }) + + It("accepts packet number 0", func() { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 0, EncryptionLevel: protocol.Encryption1RTT})) + Expect(handler.appDataPackets.largestSent).To(BeZero()) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, EncryptionLevel: protocol.Encryption1RTT})) + Expect(handler.appDataPackets.largestSent).To(Equal(protocol.PacketNumber(1))) + expectInPacketHistory([]protocol.PacketNumber{0, 1}, protocol.Encryption1RTT) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) + }) + + It("stores the sent time", func() { + sendTime := time.Now().Add(-time.Minute) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) + Expect(handler.appDataPackets.lastAckElicitingPacketTime).To(Equal(sendTime)) + }) + + It("stores the sent time of Initial packets", func() { + sendTime := time.Now().Add(-time.Minute) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime, EncryptionLevel: protocol.EncryptionInitial})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: sendTime.Add(time.Hour), EncryptionLevel: protocol.Encryption1RTT})) + Expect(handler.initialPackets.lastAckElicitingPacketTime).To(Equal(sendTime)) + }) + }) + + Context("ACK processing", func() { + JustBeforeEach(func() { + for i := protocol.PacketNumber(0); i < 10; i++ { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) + } + // Increase RTT, because the tests would be flaky otherwise + updateRTT(time.Hour) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) + }) + + Context("ACK processing", func() { + It("accepts ACKs sent in packet 0", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 5}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(5))) + }) + + It("says if a 1-RTT packet was acknowledged", func() { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 100, EncryptionLevel: protocol.Encryption0RTT})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 101, EncryptionLevel: protocol.Encryption0RTT})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 102, EncryptionLevel: protocol.Encryption1RTT})) + acked1RTT, err := handler.ReceivedAck( + &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 100, Largest: 101}}}, + protocol.Encryption1RTT, + time.Now(), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(acked1RTT).To(BeFalse()) + acked1RTT, err = handler.ReceivedAck( + &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 101, Largest: 102}}}, + protocol.Encryption1RTT, + time.Now(), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(acked1RTT).To(BeTrue()) + }) + + It("accepts multiple ACKs sent in the same packet", func() { + ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 3}}} + ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 4}}} + _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(3))) + // this wouldn't happen in practice + // for testing purposes, we pretend send a different ACK frame in a duplicated packet, to be able to verify that it actually doesn't get processed + _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(4))) + }) + + It("rejects ACKs that acknowledge a skipped packet number", func() { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 100})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 102})) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 100, Largest: 102}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received an ACK for skipped packet number: 101 (1-RTT)", + })) + }) + + It("rejects ACKs with a too high LargestAcked packet number", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 9999}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received ACK for an unsent packet", + })) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) + }) + + It("ignores repeated ACKs", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 3}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) + _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(3))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) + }) + }) + + Context("acks the right packets", func() { + expectInPacketHistoryOrLost := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { + pnSpace := handler.getPacketNumberSpace(encLevel) + var length int + pnSpace.history.Iterate(func(p *Packet) (bool, error) { + if !p.declaredLost { + length++ + } + return true, nil + }) + ExpectWithOffset(1, length+len(lostPackets)).To(Equal(len(expected))) + expectedLoop: + for _, p := range expected { + if _, ok := pnSpace.history.packetMap[p]; ok { + continue + } + for _, lostP := range lostPackets { + if lostP == p { + continue expectedLoop + } + } + Fail(fmt.Sprintf("Packet %d not in packet history.", p)) + } + } + + It("adjusts the LargestAcked, and adjusts the bytes in flight", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 5}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(5))) + expectInPacketHistoryOrLost([]protocol.PacketNumber{6, 7, 8, 9}, protocol.Encryption1RTT) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(4))) + }) + + It("acks packet 0", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 0}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(getPacket(0, protocol.Encryption1RTT)).To(BeNil()) + expectInPacketHistoryOrLost([]protocol.PacketNumber{1, 2, 3, 4, 5, 6, 7, 8, 9}, protocol.Encryption1RTT) + }) + + It("calls the OnAcked callback", func() { + var acked bool + ping := &wire.PingFrame{} + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 13, + Frames: []Frame{{ + Frame: ping, OnAcked: func(f wire.Frame) { + Expect(f).To(Equal(ping)) + acked = true + }, + }}, + })) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(acked).To(BeTrue()) + }) + + It("handles an ACK frame with one missing packet range", func() { + ack := &wire.AckFrame{ // lose 4 and 5 + AckRanges: []wire.AckRange{ + {Smallest: 6, Largest: 9}, + {Smallest: 1, Largest: 3}, + }, + } + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 4, 5}, protocol.Encryption1RTT) + }) + + It("does not ack packets below the LowestAcked", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 8}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 1, 2, 9}, protocol.Encryption1RTT) + }) + + It("handles an ACK with multiple missing packet ranges", func() { + ack := &wire.AckFrame{ // packets 2, 4 and 5, and 8 were lost + AckRanges: []wire.AckRange{ + {Smallest: 9, Largest: 9}, + {Smallest: 6, Largest: 7}, + {Smallest: 3, Largest: 3}, + {Smallest: 1, Largest: 1}, + }, + } + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 2, 4, 5, 8}, protocol.Encryption1RTT) + }) + + It("processes an ACK frame that would be sent after a late arrival of a packet", func() { + ack1 := &wire.AckFrame{ // 5 lost + AckRanges: []wire.AckRange{ + {Smallest: 6, Largest: 6}, + {Smallest: 1, Largest: 4}, + }, + } + _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 5, 7, 8, 9}, protocol.Encryption1RTT) + ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}} // now ack 5 + _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 7, 8, 9}, protocol.Encryption1RTT) + }) + + It("processes an ACK that contains old ACK ranges", func() { + ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}} + _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 7, 8, 9}, protocol.Encryption1RTT) + ack2 := &wire.AckFrame{ + AckRanges: []wire.AckRange{ + {Smallest: 8, Largest: 8}, + {Smallest: 3, Largest: 3}, + {Smallest: 1, Largest: 1}, + }, + } + _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 7, 9}, protocol.Encryption1RTT) + }) + }) + + Context("calculating RTT", func() { + It("computes the RTT", func() { + now := time.Now() + // First, fake the sent times of the first, second and last packet + getPacket(1, protocol.Encryption1RTT).SendTime = now.Add(-10 * time.Minute) + getPacket(2, protocol.Encryption1RTT).SendTime = now.Add(-5 * time.Minute) + getPacket(6, protocol.Encryption1RTT).SendTime = now.Add(-1 * time.Minute) + // Now, check that the proper times are used when calculating the deltas + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 10*time.Minute, 1*time.Second)) + ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}} + _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) + ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}} + _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 1*time.Minute, 1*time.Second)) + }) + + It("ignores the DelayTime for Initial and Handshake packets", func() { + handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) + handler.rttStats.SetMaxAckDelay(time.Hour) + // make sure the rttStats have a min RTT, so that the delay is used + handler.rttStats.UpdateRTT(5*time.Minute, 0, time.Now()) + getPacket(1, protocol.EncryptionInitial).SendTime = time.Now().Add(-10 * time.Minute) + ack := &wire.AckFrame{ + AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: 5 * time.Minute, + } + _, err := handler.ReceivedAck(ack, protocol.EncryptionInitial, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 10*time.Minute, 1*time.Second)) + }) + + It("uses the DelayTime in the ACK frame", func() { + handler.rttStats.SetMaxAckDelay(time.Hour) + // make sure the rttStats have a min RTT, so that the delay is used + handler.rttStats.UpdateRTT(5*time.Minute, 0, time.Now()) + getPacket(1, protocol.Encryption1RTT).SendTime = time.Now().Add(-10 * time.Minute) + ack := &wire.AckFrame{ + AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: 5 * time.Minute, + } + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) + }) + + It("limits the DelayTime in the ACK frame to max_ack_delay", func() { + handler.rttStats.SetMaxAckDelay(time.Minute) + // make sure the rttStats have a min RTT, so that the delay is used + handler.rttStats.UpdateRTT(5*time.Minute, 0, time.Now()) + getPacket(1, protocol.Encryption1RTT).SendTime = time.Now().Add(-10 * time.Minute) + ack := &wire.AckFrame{ + AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: 5 * time.Minute, + } + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 9*time.Minute, 1*time.Second)) + }) + }) + + Context("determining which ACKs we have received an ACK for", func() { + JustBeforeEach(func() { + morePackets := []*Packet{ + { + PacketNumber: 13, + LargestAcked: 100, + Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, + Length: 1, + EncryptionLevel: protocol.Encryption1RTT, + }, + { + PacketNumber: 14, + LargestAcked: 200, + Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, + Length: 1, + EncryptionLevel: protocol.Encryption1RTT, + }, + { + PacketNumber: 15, + Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, + Length: 1, + EncryptionLevel: protocol.Encryption1RTT, + }, + } + for _, packet := range morePackets { + handler.SentPacket(packet) + } + }) + + It("determines which ACK we have received an ACK for", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 15}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) + }) + + It("doesn't do anything when the acked packet didn't contain an ACK", func() { + ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} + ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 15, Largest: 15}}} + _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101))) + _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101))) + }) + + It("doesn't decrease the value", func() { + ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 14, Largest: 14}}} + ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} + _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) + _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) + }) + }) + }) + + Context("congestion", func() { + var cong *mocks.MockSendAlgorithmWithDebugInfos + + JustBeforeEach(func() { + cong = mocks.NewMockSendAlgorithmWithDebugInfos(mockCtrl) + handler.congestion = cong + }) + + It("should call OnSent", func() { + cong.EXPECT().OnPacketSent( + gomock.Any(), + protocol.ByteCount(42), + protocol.PacketNumber(1), + protocol.ByteCount(42), + true, + ) + handler.SentPacket(&Packet{ + PacketNumber: 1, + Length: 42, + Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) {}}}, + EncryptionLevel: protocol.Encryption1RTT, + }) + }) + + It("should call MaybeExitSlowStart and OnPacketAcked", func() { + rcvTime := time.Now().Add(-5 * time.Second) + cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + gomock.InOrder( + cong.EXPECT().MaybeExitSlowStart(), // must be called before packets are acked + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(3), rcvTime), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(3), rcvTime), + ) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, rcvTime) + Expect(err).ToNot(HaveOccurred()) + }) + + It("doesn't call OnPacketAcked when a retransmitted packet is acked", func() { + cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) + // lose packet 1 + gomock.InOrder( + cong.EXPECT().MaybeExitSlowStart(), + cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), + ) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + // don't EXPECT any further calls to the congestion controller + ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}} + _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("doesn't call OnPacketLost when a Path MTU probe packet is lost", func() { + cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + var mtuPacketDeclaredLost bool + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 1, + SendTime: time.Now().Add(-time.Hour), + IsPathMTUProbePacket: true, + Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { mtuPacketDeclaredLost = true }}}, + })) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) + // lose packet 1, but don't EXPECT any calls to OnPacketLost() + gomock.InOrder( + cong.EXPECT().MaybeExitSlowStart(), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), + ) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(mtuPacketDeclaredLost).To(BeTrue()) + Expect(handler.bytesInFlight).To(BeZero()) + }) + + It("calls OnPacketAcked and OnPacketLost with the right bytes_in_flight value", func() { + cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(4) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: time.Now().Add(-30 * time.Minute)})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3, SendTime: time.Now().Add(-30 * time.Minute)})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 4, SendTime: time.Now()})) + // receive the first ACK + gomock.InOrder( + cong.EXPECT().MaybeExitSlowStart(), + cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(4), gomock.Any()), + ) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now().Add(-30*time.Minute)) + Expect(err).ToNot(HaveOccurred()) + // receive the second ACK + gomock.InOrder( + cong.EXPECT().MaybeExitSlowStart(), + cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), + ) + ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 4, Largest: 4}}} + _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("passes the bytes in flight to the congestion controller", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(42), gomock.Any(), protocol.ByteCount(42), true) + handler.SentPacket(&Packet{ + Length: 42, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + cong.EXPECT().CanSend(protocol.ByteCount(42)).Return(true) + handler.SendMode() + }) + + It("allows sending of ACKs when congestion limited", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + cong.EXPECT().CanSend(gomock.Any()).Return(true) + Expect(handler.SendMode()).To(Equal(SendAny)) + cong.EXPECT().CanSend(gomock.Any()).Return(false) + Expect(handler.SendMode()).To(Equal(SendAck)) + }) + + It("allows sending of ACKs when we're keeping track of MaxOutstandingSentPackets packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + cong.EXPECT().CanSend(gomock.Any()).Return(true).AnyTimes() + cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + for i := protocol.PacketNumber(0); i < protocol.MaxOutstandingSentPackets; i++ { + Expect(handler.SendMode()).To(Equal(SendAny)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) + } + Expect(handler.SendMode()).To(Equal(SendAck)) + }) + + It("allows PTOs, even when congestion limited", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + // note that we don't EXPECT a call to GetCongestionWindow + // that means retransmissions are sent without considering the congestion window + handler.numProbesToSend = 1 + handler.ptoMode = SendPTOHandshake + Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) + }) + + It("says if it has pacing budget", func() { + cong.EXPECT().HasPacingBudget().Return(true) + Expect(handler.HasPacingBudget()).To(BeTrue()) + cong.EXPECT().HasPacingBudget().Return(false) + Expect(handler.HasPacingBudget()).To(BeFalse()) + }) + + It("returns the pacing delay", func() { + t := time.Now() + cong.EXPECT().TimeUntilSend(gomock.Any()).Return(t) + Expect(handler.TimeUntilSend()).To(Equal(t)) + }) + }) + + It("doesn't set an alarm if there are no outstanding packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11})) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + }) + + It("does nothing on OnAlarm if there are no outstanding packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendAny)) + }) + + Context("probe packets", func() { + It("queues a probe packet", func() { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11})) + queued := handler.QueueProbePacket(protocol.Encryption1RTT) + Expect(queued).To(BeTrue()) + Expect(lostPackets).To(Equal([]protocol.PacketNumber{10})) + }) + + It("says when it can't queue a probe packet", func() { + queued := handler.QueueProbePacket(protocol.Encryption1RTT) + Expect(queued).To(BeFalse()) + }) + + It("implements exponential backoff", func() { + handler.peerAddressValidated = true + handler.SetHandshakeConfirmed() + sendTime := time.Now().Add(-time.Hour) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) + timeout := handler.GetLossDetectionTimeout().Sub(sendTime) + Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(timeout)) + handler.ptoCount = 1 + handler.setLossDetectionTimer() + Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(2 * timeout)) + handler.ptoCount = 2 + handler.setLossDetectionTimer() + Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(4 * timeout)) + }) + + It("reset the PTO count when receiving an ACK", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + now := time.Now() + handler.SetHandshakeConfirmed() + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) + Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + Expect(handler.ptoCount).To(BeEquivalentTo(1)) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ptoCount).To(BeZero()) + }) + + It("resets the PTO mode and PTO count when a packet number space is dropped", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + + now := time.Now() + handler.rttStats.UpdateRTT(time.Second/2, 0, now) + Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second / 2)) + Expect(handler.rttStats.PTO(true)).To(And( + BeNumerically(">", time.Second), + BeNumerically("<", 2*time.Second), + )) + sendTimeHandshake := now.Add(-2 * time.Minute) + sendTimeAppData := now.Add(-time.Minute) + + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 1, + EncryptionLevel: protocol.EncryptionHandshake, + SendTime: sendTimeHandshake, + })) + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 2, + SendTime: sendTimeAppData, + })) + + // PTO timer based on the Handshake packet + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.ptoCount).To(BeEquivalentTo(1)) + Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) + Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeHandshake.Add(handler.rttStats.PTO(false) << 1))) + handler.SetHandshakeConfirmed() + handler.DropPackets(protocol.EncryptionHandshake) + // PTO timer based on the 1-RTT packet + Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeAppData.Add(handler.rttStats.PTO(true)))) // no backoff. PTO count = 0 + Expect(handler.SendMode()).ToNot(Equal(SendPTOHandshake)) + Expect(handler.ptoCount).To(BeZero()) + }) + + It("allows two 1-RTT PTOs", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.SetHandshakeConfirmed() + var lostPackets []protocol.PacketNumber + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 1, + SendTime: time.Now().Add(-time.Hour), + Frames: []Frame{ + {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, 1) }}, + }, + })) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) + Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData)) + }) + + It("skips a packet number for 1-RTT PTOs", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.SetHandshakeConfirmed() + var lostPackets []protocol.PacketNumber + pn := handler.PopPacketNumber(protocol.Encryption1RTT) + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: pn, + SendTime: time.Now().Add(-time.Hour), + Frames: []Frame{ + {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, 1) }}, + }, + })) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + // The packet number generator might have introduced another skipped a packet number. + Expect(handler.PopPacketNumber(protocol.Encryption1RTT)).To(BeNumerically(">=", pn+2)) + }) + + It("only counts ack-eliciting packets as probe packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.SetHandshakeConfirmed() + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + for p := protocol.PacketNumber(3); p < 30; p++ { + handler.SentPacket(nonAckElicitingPacket(&Packet{PacketNumber: p})) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + } + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 30})) + Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData)) + }) + + It("gets two probe packets if PTO expires", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.SetHandshakeConfirmed() + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) + + updateRTT(time.Hour) + Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) + + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP + Expect(handler.ptoCount).To(BeEquivalentTo(1)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 4})) + + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // PTO + Expect(handler.ptoCount).To(BeEquivalentTo(2)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5})) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 6})) + + Expect(handler.SendMode()).To(Equal(SendAny)) + }) + + It("gets two probe packets if PTO expires, for Handshake packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) + handler.SentPacket(initialPacket(&Packet{PacketNumber: 2})) + + updateRTT(time.Hour) + Expect(handler.initialPackets.lossTime.IsZero()).To(BeTrue()) + + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOInitial)) + handler.SentPacket(initialPacket(&Packet{PacketNumber: 3})) + Expect(handler.SendMode()).To(Equal(SendPTOInitial)) + handler.SentPacket(initialPacket(&Packet{PacketNumber: 4})) + + Expect(handler.SendMode()).To(Equal(SendAny)) + }) + + It("doesn't send 1-RTT probe packets before the handshake completes", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) + updateRTT(time.Hour) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + Expect(handler.SendMode()).To(Equal(SendAny)) + handler.SetHandshakeConfirmed() + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + }) + + It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.SetHandshakeConfirmed() + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) + updateRTT(time.Second) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.SendMode()).To(Equal(SendAny)) + }) + + It("handles ACKs for the original packet", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)})) + updateRTT(time.Second) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + }) + + It("doesn't set the PTO timer for Path MTU probe packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.SetHandshakeConfirmed() + updateRTT(time.Second) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now(), IsPathMTUProbePacket: true})) + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + }) + }) + + Context("amplification limit, for the server", func() { + It("limits the window to 3x the bytes received, to avoid amplification attacks", func() { + handler.ReceivedPacket(protocol.EncryptionInitial) // receiving an Initial packet doesn't validate the client's address + handler.ReceivedBytes(200) + handler.SentPacket(&Packet{ + PacketNumber: 1, + Length: 599, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + Expect(handler.SendMode()).To(Equal(SendAny)) + handler.SentPacket(&Packet{ + PacketNumber: 2, + Length: 1, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + Expect(handler.SendMode()).To(Equal(SendNone)) + }) + + It("cancels the loss detection timer when it is amplification limited, and resets it when becoming unblocked", func() { + handler.ReceivedBytes(300) + handler.SentPacket(&Packet{ + PacketNumber: 1, + Length: 900, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + // Amplification limited. We don't need to set a timer now. + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + // Unblock the server. Now we should fire up the timer. + handler.ReceivedBytes(1) + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + }) + + It("resets the loss detection timer when the client's address is validated", func() { + handler.ReceivedBytes(300) + handler.SentPacket(&Packet{ + PacketNumber: 1, + Length: 900, + EncryptionLevel: protocol.EncryptionHandshake, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + // Amplification limited. We don't need to set a timer now. + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + handler.ReceivedPacket(protocol.EncryptionHandshake) + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + }) + + It("cancels the loss detection alarm when all Handshake packets are acknowledged", func() { + t := time.Now().Add(-time.Second) + handler.ReceivedBytes(99999) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: t})) + handler.SentPacket(handshakePacket(&Packet{PacketNumber: 3, SendTime: t})) + handler.SentPacket(handshakePacket(&Packet{PacketNumber: 4, SendTime: t})) + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 4}}}, protocol.EncryptionHandshake, time.Now()) + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + }) + }) + + Context("amplification limit, for the client", func() { + BeforeEach(func() { + perspective = protocol.PerspectiveClient + }) + + It("sends an Initial packet to unblock the server", func() { + handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) + _, err := handler.ReceivedAck( + &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, + protocol.EncryptionInitial, + time.Now(), + ) + Expect(err).ToNot(HaveOccurred()) + // No packets are outstanding at this point. + // Make sure that a probe packet is sent. + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOInitial)) + + // send a single packet to unblock the server + handler.SentPacket(initialPacket(&Packet{PacketNumber: 2})) + Expect(handler.SendMode()).To(Equal(SendAny)) + + // Now receive an ACK for a Handshake packet. + // This tells the client that the server completed address validation. + handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1})) + _, err = handler.ReceivedAck( + &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, + protocol.EncryptionHandshake, + time.Now(), + ) + Expect(err).ToNot(HaveOccurred()) + // Make sure that no timer is set at this point. + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + }) + + It("sends a Handshake packet to unblock the server, if Initial keys were already dropped", func() { + handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) + _, err := handler.ReceivedAck( + &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, + protocol.EncryptionInitial, + time.Now(), + ) + Expect(err).ToNot(HaveOccurred()) + + handler.SentPacket(handshakePacketNonAckEliciting(&Packet{PacketNumber: 1})) // also drops Initial packets + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) + + // Now receive an ACK for this packet, and send another one. + _, err = handler.ReceivedAck( + &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, + protocol.EncryptionHandshake, + time.Now(), + ) + Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(handshakePacketNonAckEliciting(&Packet{PacketNumber: 2})) + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + }) + + It("doesn't send a packet to unblock the server after handshake confirmation, even if no Handshake ACK was received", func() { + handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1})) + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) + // confirm the handshake + handler.DropPackets(protocol.EncryptionHandshake) + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + }) + + It("correctly sets the timer after the Initial packet number space has been dropped", func() { + handler.SentPacket(initialPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-42 * time.Second)})) + _, err := handler.ReceivedAck( + &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, + protocol.EncryptionInitial, + time.Now(), + ) + Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(handshakePacketNonAckEliciting(&Packet{PacketNumber: 1, SendTime: time.Now()})) + Expect(handler.initialPackets).To(BeNil()) + + pto := handler.rttStats.PTO(false) + Expect(pto).ToNot(BeZero()) + Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", time.Now().Add(pto), 10*time.Millisecond)) + }) + + It("doesn't reset the PTO count when receiving an ACK", func() { + now := time.Now() + handler.SentPacket(initialPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) + handler.SentPacket(initialPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) + Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOInitial)) + Expect(handler.ptoCount).To(BeEquivalentTo(1)) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionInitial, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ptoCount).To(BeEquivalentTo(1)) + }) + }) + + Context("Packet-based loss detection", func() { + It("declares packet below the packet loss threshold as lost", func() { + now := time.Now() + for i := protocol.PacketNumber(1); i <= 6; i++ { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) + } + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 6, Largest: 6}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now) + Expect(err).ToNot(HaveOccurred()) + expectInPacketHistory([]protocol.PacketNumber{4, 5}, protocol.Encryption1RTT) + Expect(lostPackets).To(Equal([]protocol.PacketNumber{1, 2, 3})) + }) + }) + + Context("Delay-based loss detection", func() { + It("immediately detects old packets as lost when receiving an ACK", func() { + now := time.Now() + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Hour)})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Second)})) + Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) + + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now) + Expect(err).ToNot(HaveOccurred()) + // no need to set an alarm, since packet 1 was already declared lost + Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) + Expect(handler.bytesInFlight).To(BeZero()) + }) + + It("sets the early retransmit alarm", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.handshakeConfirmed = true + now := time.Now() + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-2 * time.Second)})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-2 * time.Second)})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3, SendTime: now})) + Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) + + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now.Add(-time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second)) + + // Packet 1 should be considered lost (1+1/8) RTTs after it was sent. + Expect(handler.GetLossDetectionTimeout().Sub(getPacket(1, protocol.Encryption1RTT).SendTime)).To(Equal(time.Second * 9 / 8)) + Expect(handler.SendMode()).To(Equal(SendAny)) + + expectInPacketHistory([]protocol.PacketNumber{1, 3}, protocol.Encryption1RTT) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + expectInPacketHistory([]protocol.PacketNumber{3}, protocol.Encryption1RTT) + Expect(handler.SendMode()).To(Equal(SendAny)) + }) + + It("sets the early retransmit alarm for crypto packets", func() { + handler.ReceivedBytes(1000) + now := time.Now() + handler.SentPacket(initialPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-2 * time.Second)})) + handler.SentPacket(initialPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-2 * time.Second)})) + handler.SentPacket(initialPacket(&Packet{PacketNumber: 3, SendTime: now})) + Expect(handler.initialPackets.lossTime.IsZero()).To(BeTrue()) + + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} + _, err := handler.ReceivedAck(ack, protocol.EncryptionInitial, now.Add(-time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second)) + + // Packet 1 should be considered lost (1+1/8) RTTs after it was sent. + Expect(handler.GetLossDetectionTimeout().Sub(getPacket(1, protocol.EncryptionInitial).SendTime)).To(Equal(time.Second * 9 / 8)) + Expect(handler.SendMode()).To(Equal(SendAny)) + + expectInPacketHistory([]protocol.PacketNumber{1, 3}, protocol.EncryptionInitial) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + expectInPacketHistory([]protocol.PacketNumber{3}, protocol.EncryptionInitial) + Expect(handler.SendMode()).To(Equal(SendAny)) + }) + + It("sets the early retransmit alarm for Path MTU probe packets", func() { + var mtuPacketDeclaredLost bool + now := time.Now() + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 1, + SendTime: now.Add(-3 * time.Second), + IsPathMTUProbePacket: true, + Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { mtuPacketDeclaredLost = true }}}, + })) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-3 * time.Second)})) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now.Add(-time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(mtuPacketDeclaredLost).To(BeFalse()) + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(mtuPacketDeclaredLost).To(BeTrue()) + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + }) + }) + + Context("crypto packets", func() { + It("rejects an ACK that acks packets with a higher encryption level", func() { + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 13, + EncryptionLevel: protocol.Encryption1RTT, + })) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} + _, err := handler.ReceivedAck(ack, protocol.EncryptionHandshake, time.Now()) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received ACK for an unsent packet", + })) + }) + + It("deletes Initial packets, as a server", func() { + for i := protocol.PacketNumber(0); i < 6; i++ { + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: i, + EncryptionLevel: protocol.EncryptionInitial, + })) + } + for i := protocol.PacketNumber(0); i < 10; i++ { + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: i, + EncryptionLevel: protocol.EncryptionHandshake, + })) + } + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) + handler.DropPackets(protocol.EncryptionInitial) + Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) + Expect(handler.initialPackets).To(BeNil()) + Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) + }) + + Context("deleting Initials", func() { + BeforeEach(func() { perspective = protocol.PerspectiveClient }) + + It("deletes Initials, as a client", func() { + for i := protocol.PacketNumber(0); i < 6; i++ { + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: i, + EncryptionLevel: protocol.EncryptionInitial, + })) + } + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) + handler.DropPackets(protocol.EncryptionInitial) + // DropPackets should be ignored for clients and the Initial packet number space. + // It has to be possible to send another Initial packets after this function was called. + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 10, + EncryptionLevel: protocol.EncryptionInitial, + })) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(7))) + // Sending a Handshake packet triggers dropping of Initials. + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 1, + EncryptionLevel: protocol.EncryptionHandshake, + })) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1))) + Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission + Expect(handler.initialPackets).To(BeNil()) + Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) + }) + }) + + It("deletes Handshake packets", func() { + for i := protocol.PacketNumber(0); i < 6; i++ { + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: i, + EncryptionLevel: protocol.EncryptionHandshake, + })) + } + for i := protocol.PacketNumber(0); i < 10; i++ { + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: i, + EncryptionLevel: protocol.Encryption1RTT, + })) + } + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) + handler.DropPackets(protocol.EncryptionHandshake) + Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) + Expect(handler.handshakePackets).To(BeNil()) + }) + + It("doesn't retransmit 0-RTT packets when 0-RTT keys are dropped", func() { + for i := protocol.PacketNumber(0); i < 6; i++ { + if i == 3 { + continue + } + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: i, + EncryptionLevel: protocol.Encryption0RTT, + })) + } + for i := protocol.PacketNumber(6); i < 12; i++ { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) + } + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(11))) + handler.DropPackets(protocol.Encryption0RTT) + Expect(lostPackets).To(BeEmpty()) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) + }) + + It("cancels the PTO when dropping a packet number space", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + now := time.Now() + handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) + handler.SentPacket(handshakePacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) + Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) + Expect(handler.ptoCount).To(BeEquivalentTo(1)) + handler.DropPackets(protocol.EncryptionHandshake) + Expect(handler.ptoCount).To(BeZero()) + Expect(handler.SendMode()).To(Equal(SendAny)) + }) + }) + + Context("peeking and popping packet number", func() { + It("peeks and pops the initial packet number", func() { + pn, _ := handler.PeekPacketNumber(protocol.EncryptionInitial) + Expect(pn).To(Equal(protocol.PacketNumber(42))) + Expect(handler.PopPacketNumber(protocol.EncryptionInitial)).To(Equal(protocol.PacketNumber(42))) + }) + + It("peeks and pops beyond the initial packet number", func() { + Expect(handler.PopPacketNumber(protocol.EncryptionInitial)).To(Equal(protocol.PacketNumber(42))) + Expect(handler.PopPacketNumber(protocol.EncryptionInitial)).To(BeNumerically(">", 42)) + }) + + It("starts at 0 for handshake and application-data packet number space", func() { + pn, _ := handler.PeekPacketNumber(protocol.EncryptionHandshake) + Expect(pn).To(BeZero()) + Expect(handler.PopPacketNumber(protocol.EncryptionHandshake)).To(BeZero()) + pn, _ = handler.PeekPacketNumber(protocol.Encryption1RTT) + Expect(pn).To(BeZero()) + Expect(handler.PopPacketNumber(protocol.Encryption1RTT)).To(BeZero()) + }) + }) + + Context("for the client", func() { + BeforeEach(func() { + perspective = protocol.PerspectiveClient + }) + + It("considers the server's address validated right away", func() { + }) + + It("queues outstanding packets for retransmission, cancels alarms and resets PTO count when receiving a Retry", func() { + handler.SentPacket(initialPacket(&Packet{PacketNumber: 42})) + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + Expect(handler.bytesInFlight).ToNot(BeZero()) + Expect(handler.SendMode()).To(Equal(SendAny)) + // now receive a Retry + Expect(handler.ResetForRetry()).To(Succeed()) + Expect(lostPackets).To(Equal([]protocol.PacketNumber{42})) + Expect(handler.bytesInFlight).To(BeZero()) + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + Expect(handler.SendMode()).To(Equal(SendAny)) + Expect(handler.ptoCount).To(BeZero()) + }) + + It("queues outstanding frames for retransmission and cancels alarms when receiving a Retry", func() { + var lostInitial, lost0RTT bool + handler.SentPacket(&Packet{ + PacketNumber: 13, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{ + {Frame: &wire.CryptoFrame{Data: []byte("foobar")}, OnLost: func(wire.Frame) { lostInitial = true }}, + }, + Length: 100, + }) + pn := handler.PopPacketNumber(protocol.Encryption0RTT) + handler.SentPacket(&Packet{ + PacketNumber: pn, + EncryptionLevel: protocol.Encryption0RTT, + Frames: []Frame{ + {Frame: &wire.StreamFrame{Data: []byte("foobar")}, OnLost: func(wire.Frame) { lost0RTT = true }}, + }, + Length: 999, + }) + Expect(handler.bytesInFlight).ToNot(BeZero()) + // now receive a Retry + Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.bytesInFlight).To(BeZero()) + Expect(lostInitial).To(BeTrue()) + Expect(lost0RTT).To(BeTrue()) + + // make sure we keep increasing the packet number for 0-RTT packets + Expect(handler.PopPacketNumber(protocol.Encryption0RTT)).To(BeNumerically(">", pn)) + }) + + It("uses a Retry for an RTT estimate, if it was not retransmitted", func() { + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 42, + EncryptionLevel: protocol.EncryptionInitial, + SendTime: time.Now().Add(-500 * time.Millisecond), + })) + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 43, + EncryptionLevel: protocol.EncryptionInitial, + SendTime: time.Now().Add(-10 * time.Millisecond), + })) + Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.rttStats.SmoothedRTT()).To(BeNumerically("~", 500*time.Millisecond, 100*time.Millisecond)) + }) + + It("uses a Retry for an RTT estimate, but doesn't set the RTT to a value lower than 5ms", func() { + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 42, + EncryptionLevel: protocol.EncryptionInitial, + SendTime: time.Now().Add(-500 * time.Microsecond), + })) + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 43, + EncryptionLevel: protocol.EncryptionInitial, + SendTime: time.Now().Add(-10 * time.Microsecond), + })) + Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.rttStats.SmoothedRTT()).To(Equal(minRTTAfterRetry)) + }) + + It("doesn't use a Retry for an RTT estimate, if it was not retransmitted", func() { + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 42, + EncryptionLevel: protocol.EncryptionInitial, + SendTime: time.Now().Add(-800 * time.Millisecond), + })) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOInitial)) + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 43, + EncryptionLevel: protocol.EncryptionInitial, + SendTime: time.Now().Add(-100 * time.Millisecond), + })) + Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.rttStats.SmoothedRTT()).To(BeZero()) + }) + }) +}) diff --git a/internal/quic-go/ackhandler/sent_packet_history.go b/internal/quic-go/ackhandler/sent_packet_history.go new file mode 100644 index 00000000..f6acc2be --- /dev/null +++ b/internal/quic-go/ackhandler/sent_packet_history.go @@ -0,0 +1,108 @@ +package ackhandler + +import ( + "fmt" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +type sentPacketHistory struct { + rttStats *utils.RTTStats + packetList *PacketList + packetMap map[protocol.PacketNumber]*PacketElement + highestSent protocol.PacketNumber +} + +func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory { + return &sentPacketHistory{ + rttStats: rttStats, + packetList: NewPacketList(), + packetMap: make(map[protocol.PacketNumber]*PacketElement), + highestSent: protocol.InvalidPacketNumber, + } +} + +func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) { + if p.PacketNumber <= h.highestSent { + panic("non-sequential packet number use") + } + // Skipped packet numbers. + for pn := h.highestSent + 1; pn < p.PacketNumber; pn++ { + el := h.packetList.PushBack(Packet{ + PacketNumber: pn, + EncryptionLevel: p.EncryptionLevel, + SendTime: p.SendTime, + skippedPacket: true, + }) + h.packetMap[pn] = el + } + h.highestSent = p.PacketNumber + + if isAckEliciting { + el := h.packetList.PushBack(*p) + h.packetMap[p.PacketNumber] = el + } +} + +// Iterate iterates through all packets. +func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error { + cont := true + var next *PacketElement + for el := h.packetList.Front(); cont && el != nil; el = next { + var err error + next = el.Next() + cont, err = cb(&el.Value) + if err != nil { + return err + } + } + return nil +} + +// FirstOutStanding returns the first outstanding packet. +func (h *sentPacketHistory) FirstOutstanding() *Packet { + for el := h.packetList.Front(); el != nil; el = el.Next() { + p := &el.Value + if !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket { + return p + } + } + return nil +} + +func (h *sentPacketHistory) Len() int { + return len(h.packetMap) +} + +func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { + el, ok := h.packetMap[p] + if !ok { + return fmt.Errorf("packet %d not found in sent packet history", p) + } + h.packetList.Remove(el) + delete(h.packetMap, p) + return nil +} + +func (h *sentPacketHistory) HasOutstandingPackets() bool { + return h.FirstOutstanding() != nil +} + +func (h *sentPacketHistory) DeleteOldPackets(now time.Time) { + maxAge := 3 * h.rttStats.PTO(false) + var nextEl *PacketElement + for el := h.packetList.Front(); el != nil; el = nextEl { + nextEl = el.Next() + p := el.Value + if p.SendTime.After(now.Add(-maxAge)) { + break + } + if !p.skippedPacket && !p.declaredLost { // should only happen in the case of drastic RTT changes + continue + } + delete(h.packetMap, p.PacketNumber) + h.packetList.Remove(el) + } +} diff --git a/internal/quic-go/ackhandler/sent_packet_history_test.go b/internal/quic-go/ackhandler/sent_packet_history_test.go new file mode 100644 index 00000000..b3cf2f0e --- /dev/null +++ b/internal/quic-go/ackhandler/sent_packet_history_test.go @@ -0,0 +1,263 @@ +package ackhandler + +import ( + "errors" + "time" + + "github.com/imroc/req/v3/internal/quic-go/utils" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("SentPacketHistory", func() { + var ( + hist *sentPacketHistory + rttStats *utils.RTTStats + ) + + expectInHistory := func(packetNumbers []protocol.PacketNumber) { + var mapLen int + for _, el := range hist.packetMap { + if !el.Value.skippedPacket { + mapLen++ + } + } + var listLen int + for el := hist.packetList.Front(); el != nil; el = el.Next() { + if !el.Value.skippedPacket { + listLen++ + } + } + ExpectWithOffset(1, mapLen).To(Equal(len(packetNumbers))) + ExpectWithOffset(1, listLen).To(Equal(len(packetNumbers))) + i := 0 + err := hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + return true, nil + } + pn := packetNumbers[i] + ExpectWithOffset(1, p.PacketNumber).To(Equal(pn)) + ExpectWithOffset(1, hist.packetMap[pn].Value.PacketNumber).To(Equal(pn)) + i++ + return true, nil + }) + Expect(err).ToNot(HaveOccurred()) + } + + BeforeEach(func() { + rttStats = utils.NewRTTStats() + hist = newSentPacketHistory(rttStats) + }) + + It("saves sent packets", func() { + hist.SentPacket(&Packet{PacketNumber: 1}, true) + hist.SentPacket(&Packet{PacketNumber: 3}, true) + hist.SentPacket(&Packet{PacketNumber: 4}, true) + expectInHistory([]protocol.PacketNumber{1, 3, 4}) + }) + + It("doesn't save non-ack-eliciting packets", func() { + hist.SentPacket(&Packet{PacketNumber: 1}, true) + hist.SentPacket(&Packet{PacketNumber: 3}, false) + hist.SentPacket(&Packet{PacketNumber: 4}, true) + expectInHistory([]protocol.PacketNumber{1, 4}) + for el := hist.packetList.Front(); el != nil; el = el.Next() { + Expect(el.Value.PacketNumber).ToNot(Equal(protocol.PacketNumber(3))) + } + }) + + It("gets the length", func() { + hist.SentPacket(&Packet{PacketNumber: 0}, true) + hist.SentPacket(&Packet{PacketNumber: 1}, true) + hist.SentPacket(&Packet{PacketNumber: 2}, true) + Expect(hist.Len()).To(Equal(3)) + }) + + Context("getting the first outstanding packet", func() { + It("gets nil, if there are no packets", func() { + Expect(hist.FirstOutstanding()).To(BeNil()) + }) + + It("gets the first outstanding packet", func() { + hist.SentPacket(&Packet{PacketNumber: 2}, true) + hist.SentPacket(&Packet{PacketNumber: 3}, true) + front := hist.FirstOutstanding() + Expect(front).ToNot(BeNil()) + Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2))) + }) + + It("doesn't regard path MTU packets as outstanding", func() { + hist.SentPacket(&Packet{PacketNumber: 2}, true) + hist.SentPacket(&Packet{PacketNumber: 4, IsPathMTUProbePacket: true}, true) + front := hist.FirstOutstanding() + Expect(front).ToNot(BeNil()) + Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2))) + }) + }) + + It("removes packets", func() { + hist.SentPacket(&Packet{PacketNumber: 1}, true) + hist.SentPacket(&Packet{PacketNumber: 4}, true) + hist.SentPacket(&Packet{PacketNumber: 8}, true) + err := hist.Remove(4) + Expect(err).ToNot(HaveOccurred()) + expectInHistory([]protocol.PacketNumber{1, 8}) + }) + + It("errors when trying to remove a non existing packet", func() { + hist.SentPacket(&Packet{PacketNumber: 1}, true) + err := hist.Remove(2) + Expect(err).To(MatchError("packet 2 not found in sent packet history")) + }) + + Context("iterating", func() { + BeforeEach(func() { + hist.SentPacket(&Packet{PacketNumber: 1}, true) + hist.SentPacket(&Packet{PacketNumber: 4}, true) + hist.SentPacket(&Packet{PacketNumber: 8}, true) + }) + + It("iterates over all packets", func() { + var iterations []protocol.PacketNumber + Expect(hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + return true, nil + } + iterations = append(iterations, p.PacketNumber) + return true, nil + })).To(Succeed()) + Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4, 8})) + }) + + It("also iterates over skipped packets", func() { + var packets, skippedPackets []protocol.PacketNumber + Expect(hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + skippedPackets = append(skippedPackets, p.PacketNumber) + } else { + packets = append(packets, p.PacketNumber) + } + return true, nil + })).To(Succeed()) + Expect(packets).To(Equal([]protocol.PacketNumber{1, 4, 8})) + Expect(skippedPackets).To(Equal([]protocol.PacketNumber{0, 2, 3, 5, 6, 7})) + }) + + It("stops iterating", func() { + var iterations []protocol.PacketNumber + Expect(hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + return true, nil + } + iterations = append(iterations, p.PacketNumber) + return p.PacketNumber != 4, nil + })).To(Succeed()) + Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4})) + }) + + It("returns the error", func() { + testErr := errors.New("test error") + var iterations []protocol.PacketNumber + Expect(hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + return true, nil + } + iterations = append(iterations, p.PacketNumber) + if p.PacketNumber == 4 { + return false, testErr + } + return true, nil + })).To(MatchError(testErr)) + Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4})) + }) + + It("allows deletions", func() { + var iterations []protocol.PacketNumber + Expect(hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + return true, nil + } + iterations = append(iterations, p.PacketNumber) + if p.PacketNumber == 4 { + Expect(hist.Remove(4)).To(Succeed()) + } + return true, nil + })).To(Succeed()) + expectInHistory([]protocol.PacketNumber{1, 8}) + Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4, 8})) + }) + }) + + Context("outstanding packets", func() { + It("says if it has outstanding packets", func() { + Expect(hist.HasOutstandingPackets()).To(BeFalse()) + hist.SentPacket(&Packet{EncryptionLevel: protocol.Encryption1RTT}, true) + Expect(hist.HasOutstandingPackets()).To(BeTrue()) + }) + + It("accounts for deleted packets", func() { + hist.SentPacket(&Packet{ + PacketNumber: 10, + EncryptionLevel: protocol.Encryption1RTT, + }, true) + Expect(hist.HasOutstandingPackets()).To(BeTrue()) + Expect(hist.Remove(10)).To(Succeed()) + Expect(hist.HasOutstandingPackets()).To(BeFalse()) + }) + + It("counts the number of packets", func() { + hist.SentPacket(&Packet{ + PacketNumber: 10, + EncryptionLevel: protocol.Encryption1RTT, + }, true) + hist.SentPacket(&Packet{ + PacketNumber: 11, + EncryptionLevel: protocol.Encryption1RTT, + }, true) + Expect(hist.Remove(11)).To(Succeed()) + Expect(hist.HasOutstandingPackets()).To(BeTrue()) + Expect(hist.Remove(10)).To(Succeed()) + Expect(hist.HasOutstandingPackets()).To(BeFalse()) + }) + }) + + Context("deleting old packets", func() { + const pto = 3 * time.Second + + BeforeEach(func() { + rttStats.UpdateRTT(time.Second, 0, time.Time{}) + Expect(rttStats.PTO(false)).To(Equal(pto)) + }) + + It("deletes old packets after 3 PTOs", func() { + now := time.Now() + hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto), declaredLost: true}, true) + expectInHistory([]protocol.PacketNumber{10}) + hist.DeleteOldPackets(now.Add(-time.Nanosecond)) + expectInHistory([]protocol.PacketNumber{10}) + hist.DeleteOldPackets(now) + expectInHistory([]protocol.PacketNumber{}) + }) + + It("doesn't delete a packet if it hasn't been declared lost yet", func() { + now := time.Now() + hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto), declaredLost: true}, true) + hist.SentPacket(&Packet{PacketNumber: 11, SendTime: now.Add(-3 * pto), declaredLost: false}, true) + expectInHistory([]protocol.PacketNumber{10, 11}) + hist.DeleteOldPackets(now) + expectInHistory([]protocol.PacketNumber{11}) + }) + + It("deletes skipped packets", func() { + now := time.Now() + hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto)}, true) + expectInHistory([]protocol.PacketNumber{10}) + Expect(hist.Len()).To(Equal(11)) + hist.DeleteOldPackets(now) + expectInHistory([]protocol.PacketNumber{10}) // the packet was not declared lost + Expect(hist.Len()).To(Equal(1)) + }) + }) +}) diff --git a/internal/quic-go/buffer_pool.go b/internal/quic-go/buffer_pool.go new file mode 100644 index 00000000..c8d50a43 --- /dev/null +++ b/internal/quic-go/buffer_pool.go @@ -0,0 +1,80 @@ +package quic + +import ( + "sync" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +type packetBuffer struct { + Data []byte + + // refCount counts how many packets Data is used in. + // It doesn't support concurrent use. + // It is > 1 when used for coalesced packet. + refCount int +} + +// Split increases the refCount. +// It must be called when a packet buffer is used for more than one packet, +// e.g. when splitting coalesced packets. +func (b *packetBuffer) Split() { + b.refCount++ +} + +// Decrement decrements the reference counter. +// It doesn't put the buffer back into the pool. +func (b *packetBuffer) Decrement() { + b.refCount-- + if b.refCount < 0 { + panic("negative packetBuffer refCount") + } +} + +// MaybeRelease puts the packet buffer back into the pool, +// if the reference counter already reached 0. +func (b *packetBuffer) MaybeRelease() { + // only put the packetBuffer back if it's not used any more + if b.refCount == 0 { + b.putBack() + } +} + +// Release puts back the packet buffer into the pool. +// It should be called when processing is definitely finished. +func (b *packetBuffer) Release() { + b.Decrement() + if b.refCount != 0 { + panic("packetBuffer refCount not zero") + } + b.putBack() +} + +// Len returns the length of Data +func (b *packetBuffer) Len() protocol.ByteCount { + return protocol.ByteCount(len(b.Data)) +} + +func (b *packetBuffer) putBack() { + if cap(b.Data) != int(protocol.MaxPacketBufferSize) { + panic("putPacketBuffer called with packet of wrong size!") + } + bufferPool.Put(b) +} + +var bufferPool sync.Pool + +func getPacketBuffer() *packetBuffer { + buf := bufferPool.Get().(*packetBuffer) + buf.refCount = 1 + buf.Data = buf.Data[:0] + return buf +} + +func init() { + bufferPool.New = func() interface{} { + return &packetBuffer{ + Data: make([]byte, 0, protocol.MaxPacketBufferSize), + } + } +} diff --git a/internal/quic-go/buffer_pool_test.go b/internal/quic-go/buffer_pool_test.go new file mode 100644 index 00000000..8e28ad02 --- /dev/null +++ b/internal/quic-go/buffer_pool_test.go @@ -0,0 +1,55 @@ +package quic + +import ( + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Buffer Pool", func() { + It("returns buffers of cap", func() { + buf := getPacketBuffer() + Expect(buf.Data).To(HaveCap(int(protocol.MaxPacketBufferSize))) + }) + + It("releases buffers", func() { + buf := getPacketBuffer() + buf.Release() + }) + + It("gets the length", func() { + buf := getPacketBuffer() + buf.Data = append(buf.Data, []byte("foobar")...) + Expect(buf.Len()).To(BeEquivalentTo(6)) + }) + + It("panics if wrong-sized buffers are passed", func() { + buf := getPacketBuffer() + buf.Data = make([]byte, 10) + Expect(func() { buf.Release() }).To(Panic()) + }) + + It("panics if it is released twice", func() { + buf := getPacketBuffer() + buf.Release() + Expect(func() { buf.Release() }).To(Panic()) + }) + + It("panics if it is decremented too many times", func() { + buf := getPacketBuffer() + buf.Decrement() + Expect(func() { buf.Decrement() }).To(Panic()) + }) + + It("waits until all parts have been released", func() { + buf := getPacketBuffer() + buf.Split() + buf.Split() + // now we have 3 parts + buf.Decrement() + buf.Decrement() + buf.Decrement() + Expect(func() { buf.Decrement() }).To(Panic()) + }) +}) diff --git a/internal/quic-go/client.go b/internal/quic-go/client.go new file mode 100644 index 00000000..3bcdbece --- /dev/null +++ b/internal/quic-go/client.go @@ -0,0 +1,339 @@ +package quic + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "strings" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +type client struct { + sconn sendConn + // If the client is created with DialAddr, we create a packet conn. + // If it is started with Dial, we take a packet conn as a parameter. + createdPacketConn bool + + use0RTT bool + + packetHandlers packetHandlerManager + + tlsConf *tls.Config + config *Config + + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID + + initialPacketNumber protocol.PacketNumber + hasNegotiatedVersion bool + version protocol.VersionNumber + + handshakeChan chan struct{} + + conn quicConn + + tracer logging.ConnectionTracer + tracingID uint64 + logger utils.Logger +} + +var ( + // make it possible to mock connection ID generation in the tests + generateConnectionID = protocol.GenerateConnectionID + generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial +) + +// DialAddr establishes a new QUIC connection to a server. +// It uses a new UDP connection and closes this connection when the QUIC connection is closed. +// The hostname for SNI is taken from the given address. +// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. +func DialAddr( + addr string, + tlsConf *tls.Config, + config *Config, +) (Connection, error) { + return DialAddrContext(context.Background(), addr, tlsConf, config) +} + +// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. +// It uses a new UDP connection and closes this connection when the QUIC connection is closed. +// The hostname for SNI is taken from the given address. +// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. +func DialAddrEarly( + addr string, + tlsConf *tls.Config, + config *Config, +) (EarlyConnection, error) { + return DialAddrEarlyContext(context.Background(), addr, tlsConf, config) +} + +// DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context. +// See DialAddrEarly for details +func DialAddrEarlyContext( + ctx context.Context, + addr string, + tlsConf *tls.Config, + config *Config, +) (EarlyConnection, error) { + conn, err := dialAddrContext(ctx, addr, tlsConf, config, true) + if err != nil { + return nil, err + } + utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection") + return conn, nil +} + +// DialAddrContext establishes a new QUIC connection to a server using the provided context. +// See DialAddr for details. +func DialAddrContext( + ctx context.Context, + addr string, + tlsConf *tls.Config, + config *Config, +) (Connection, error) { + return dialAddrContext(ctx, addr, tlsConf, config, false) +} + +func dialAddrContext( + ctx context.Context, + addr string, + tlsConf *tls.Config, + config *Config, + use0RTT bool, +) (quicConn, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return nil, err + } + return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true) +} + +// Dial establishes a new QUIC connection to a server using a net.PacketConn. If +// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn +// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP +// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write +// packets. The same PacketConn can be used for multiple calls to Dial and +// Listen, QUIC connection IDs are used for demultiplexing the different +// connections. The host parameter is used for SNI. The tls.Config must define +// an application protocol (using NextProtos). +func Dial( + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (Connection, error) { + return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false) +} + +// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn. +// The same PacketConn can be used for multiple calls to Dial and Listen, +// QUIC connection IDs are used for demultiplexing the different connections. +// The host parameter is used for SNI. +// The tls.Config must define an application protocol (using NextProtos). +func DialEarly( + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (EarlyConnection, error) { + return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config) +} + +// DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context. +// See DialEarly for details. +func DialEarlyContext( + ctx context.Context, + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (EarlyConnection, error) { + return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false) +} + +// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context. +// See Dial for details. +func DialContext( + ctx context.Context, + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (Connection, error) { + return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false) +} + +func dialContext( + ctx context.Context, + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, + use0RTT bool, + createdPacketConn bool, +) (quicConn, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(config); err != nil { + return nil, err + } + config = populateClientConfig(config, createdPacketConn) + packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) + if err != nil { + return nil, err + } + c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn) + if err != nil { + return nil, err + } + c.packetHandlers = packetHandlers + + c.tracingID = nextConnTracingID() + if c.config.Tracer != nil { + c.tracer = c.config.Tracer.TracerForConnection( + context.WithValue(ctx, ConnectionTracingKey, c.tracingID), + protocol.PerspectiveClient, + c.destConnID, + ) + } + if c.tracer != nil { + c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID) + } + if err := c.dial(ctx); err != nil { + return nil, err + } + return c.conn, nil +} + +func newClient( + pconn net.PacketConn, + remoteAddr net.Addr, + config *Config, + tlsConf *tls.Config, + host string, + use0RTT bool, + createdPacketConn bool, +) (*client, error) { + if tlsConf == nil { + tlsConf = &tls.Config{} + } else { + tlsConf = tlsConf.Clone() + } + if tlsConf.ServerName == "" { + sni := host + if strings.IndexByte(sni, ':') != -1 { + var err error + sni, _, err = net.SplitHostPort(sni) + if err != nil { + return nil, err + } + } + + tlsConf.ServerName = sni + } + + // check that all versions are actually supported + if config != nil { + for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return nil, fmt.Errorf("%s is not a valid QUIC version", v) + } + } + } + + srcConnID, err := generateConnectionID(config.ConnectionIDLength) + if err != nil { + return nil, err + } + destConnID, err := generateConnectionIDForInitial() + if err != nil { + return nil, err + } + c := &client{ + srcConnID: srcConnID, + destConnID: destConnID, + sconn: newSendPconn(pconn, remoteAddr), + createdPacketConn: createdPacketConn, + use0RTT: use0RTT, + tlsConf: tlsConf, + config: config, + version: config.Versions[0], + handshakeChan: make(chan struct{}), + logger: utils.DefaultLogger.WithPrefix("client"), + } + return c, nil +} + +func (c *client) dial(ctx context.Context) error { + c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + + c.conn = newClientConnection( + c.sconn, + c.packetHandlers, + c.destConnID, + c.srcConnID, + c.config, + c.tlsConf, + c.initialPacketNumber, + c.use0RTT, + c.hasNegotiatedVersion, + c.tracer, + c.tracingID, + c.logger, + c.version, + ) + c.packetHandlers.Add(c.srcConnID, c.conn) + + errorChan := make(chan error, 1) + go func() { + err := c.conn.run() // returns as soon as the connection is closed + + if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn { + c.packetHandlers.Destroy() + } + errorChan <- err + }() + + // only set when we're using 0-RTT + // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. + var earlyConnChan <-chan struct{} + if c.use0RTT { + earlyConnChan = c.conn.earlyConnReady() + } + + select { + case <-ctx.Done(): + c.conn.shutdown() + return ctx.Err() + case err := <-errorChan: + var recreateErr *errCloseForRecreating + if errors.As(err, &recreateErr) { + c.initialPacketNumber = recreateErr.nextPacketNumber + c.version = recreateErr.nextVersion + c.hasNegotiatedVersion = true + return c.dial(ctx) + } + return err + case <-earlyConnChan: + // ready to send 0-RTT data + return nil + case <-c.conn.HandshakeComplete().Done(): + // handshake successfully completed + return nil + } +} diff --git a/internal/quic-go/client_test.go b/internal/quic-go/client_test.go new file mode 100644 index 00000000..5dd3bf00 --- /dev/null +++ b/internal/quic-go/client_test.go @@ -0,0 +1,611 @@ +package quic + +import ( + "context" + "crypto/tls" + "errors" + "net" + "os" + "time" + + "github.com/imroc/req/v3/internal/quic-go/logging" + mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Client", func() { + var ( + cl *client + packetConn *MockPacketConn + addr net.Addr + connID protocol.ConnectionID + mockMultiplexer *MockMultiplexer + origMultiplexer multiplexer + tlsConf *tls.Config + tracer *mocklogging.MockConnectionTracer + config *Config + + originalClientConnConstructor func( + conn sendConn, + runner connRunner, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + conf *Config, + tlsConf *tls.Config, + initialPacketNumber protocol.PacketNumber, + enable0RTT bool, + hasNegotiatedVersion bool, + tracer logging.ConnectionTracer, + tracingID uint64, + logger utils.Logger, + v protocol.VersionNumber, + ) quicConn + ) + + BeforeEach(func() { + tlsConf = &tls.Config{NextProtos: []string{"proto1"}} + connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} + originalClientConnConstructor = newClientConnection + tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + tr := mocklogging.NewMockTracer(mockCtrl) + tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) + config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.VersionTLS}} + Eventually(areConnsRunning).Should(BeFalse()) + addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} + packetConn = NewMockPacketConn(mockCtrl) + packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() + cl = &client{ + srcConnID: connID, + destConnID: connID, + version: protocol.VersionTLS, + sconn: newSendPconn(packetConn, addr), + tracer: tracer, + logger: utils.DefaultLogger, + } + getMultiplexer() // make the sync.Once execute + // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer + mockMultiplexer = NewMockMultiplexer(mockCtrl) + origMultiplexer = connMuxer + connMuxer = mockMultiplexer + }) + + AfterEach(func() { + connMuxer = origMultiplexer + newClientConnection = originalClientConnConstructor + }) + + AfterEach(func() { + if s, ok := cl.conn.(*connection); ok { + s.shutdown() + } + Eventually(areConnsRunning).Should(BeFalse()) + }) + + Context("Dialing", func() { + var origGenerateConnectionID func(int) (protocol.ConnectionID, error) + var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error) + + BeforeEach(func() { + origGenerateConnectionID = generateConnectionID + origGenerateConnectionIDForInitial = generateConnectionIDForInitial + generateConnectionID = func(int) (protocol.ConnectionID, error) { + return connID, nil + } + generateConnectionIDForInitial = func() (protocol.ConnectionID, error) { + return connID, nil + } + }) + + AfterEach(func() { + generateConnectionID = origGenerateConnectionID + generateConnectionIDForInitial = origGenerateConnectionIDForInitial + }) + + It("resolves the address", func() { + if os.Getenv("APPVEYOR") == "True" { + Skip("This test is flaky on AppVeyor.") + } + + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(gomock.Any(), gomock.Any()) + manager.EXPECT().Destroy() + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + + remoteAddrChan := make(chan string, 1) + newClientConnection = func( + sconn sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ protocol.PacketNumber, + _ bool, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + remoteAddrChan <- sconn.RemoteAddr().String() + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run() + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn + } + _, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond}) + Expect(err).ToNot(HaveOccurred()) + Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890"))) + }) + + It("uses the tls.Config.ServerName as the hostname, if present", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(gomock.Any(), gomock.Any()) + manager.EXPECT().Destroy() + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + + hostnameChan := make(chan string, 1) + newClientConnection = func( + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + tlsConf *tls.Config, + _ protocol.PacketNumber, + _ bool, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + hostnameChan <- tlsConf.ServerName + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run() + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn + } + tlsConf.ServerName = "foobar" + _, err := DialAddr("localhost:17890", tlsConf, nil) + Expect(err).ToNot(HaveOccurred()) + Eventually(hostnameChan).Should(Receive(Equal("foobar"))) + }) + + It("allows passing host without port as server name", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(gomock.Any(), gomock.Any()) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + + hostnameChan := make(chan string, 1) + newClientConnection = func( + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + tlsConf *tls.Config, + _ protocol.PacketNumber, + _ bool, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + hostnameChan <- tlsConf.ServerName + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().run() + return conn + } + tracer.EXPECT().StartedConnection(packetConn.LocalAddr(), addr, gomock.Any(), gomock.Any()) + _, err := Dial( + packetConn, + addr, + "test.com", + tlsConf, + config, + ) + Expect(err).ToNot(HaveOccurred()) + Eventually(hostnameChan).Should(Receive(Equal("test.com"))) + }) + + It("returns after the handshake is complete", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(gomock.Any(), gomock.Any()) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + + run := make(chan struct{}) + newClientConnection = func( + _ sendConn, + runner connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ protocol.PacketNumber, + enable0RTT bool, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + Expect(enable0RTT).To(BeFalse()) + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run().Do(func() { close(run) }) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + conn.EXPECT().HandshakeComplete().Return(ctx) + return conn + } + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + s, err := Dial( + packetConn, + addr, + "localhost:1337", + tlsConf, + config, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Eventually(run).Should(BeClosed()) + }) + + It("returns early connections", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(gomock.Any(), gomock.Any()) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + + readyChan := make(chan struct{}) + done := make(chan struct{}) + newClientConnection = func( + _ sendConn, + runner connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ protocol.PacketNumber, + enable0RTT bool, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + Expect(enable0RTT).To(BeTrue()) + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run().Do(func() { <-done }) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().earlyConnReady().Return(readyChan) + return conn + } + + go func() { + defer GinkgoRecover() + defer close(done) + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + s, err := DialEarly( + packetConn, + addr, + "localhost:1337", + tlsConf, + config, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + }() + Consistently(done).ShouldNot(BeClosed()) + close(readyChan) + Eventually(done).Should(BeClosed()) + }) + + It("returns an error that occurs while waiting for the handshake to complete", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(gomock.Any(), gomock.Any()) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + + testErr := errors.New("early handshake error") + newClientConnection = func( + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ protocol.PacketNumber, + _ bool, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run().Return(testErr) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn + } + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + _, err := Dial( + packetConn, + addr, + "localhost:1337", + tlsConf, + config, + ) + Expect(err).To(MatchError(testErr)) + }) + + It("closes the connection when the context is canceled", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(gomock.Any(), gomock.Any()) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + + connRunning := make(chan struct{}) + defer close(connRunning) + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run().Do(func() { + <-connRunning + }) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + newClientConnection = func( + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ protocol.PacketNumber, + _ bool, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + return conn + } + ctx, cancel := context.WithCancel(context.Background()) + dialed := make(chan struct{}) + go func() { + defer GinkgoRecover() + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + _, err := DialContext( + ctx, + packetConn, + addr, + "localhost:1337", + tlsConf, + config, + ) + Expect(err).To(MatchError(context.Canceled)) + close(dialed) + }() + Consistently(dialed).ShouldNot(BeClosed()) + conn.EXPECT().shutdown() + cancel() + Eventually(dialed).Should(BeClosed()) + }) + + It("closes the connection when it was created by DialAddr", func() { + if os.Getenv("APPVEYOR") == "True" { + Skip("This test is flaky on AppVeyor.") + } + + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + manager.EXPECT().Add(gomock.Any(), gomock.Any()) + + var sconn sendConn + run := make(chan struct{}) + connCreated := make(chan struct{}) + conn := NewMockQuicConn(mockCtrl) + newClientConnection = func( + connP sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ protocol.PacketNumber, + _ bool, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + sconn = connP + close(connCreated) + return conn + } + conn.EXPECT().run().Do(func() { + <-run + }) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := DialAddr("localhost:1337", tlsConf, nil) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + Eventually(connCreated).Should(BeClosed()) + + // check that the connection is not closed + Expect(sconn.Write([]byte("foobar"))).To(Succeed()) + + manager.EXPECT().Destroy() + close(run) + time.Sleep(50 * time.Millisecond) + + Eventually(done).Should(BeClosed()) + }) + + Context("quic.Config", func() { + It("setups with the right values", func() { + tokenStore := NewLRUTokenStore(10, 4) + config := &Config{ + HandshakeIdleTimeout: 1337 * time.Minute, + MaxIdleTimeout: 42 * time.Hour, + MaxIncomingStreams: 1234, + MaxIncomingUniStreams: 4321, + ConnectionIDLength: 13, + StatelessResetKey: []byte("foobar"), + TokenStore: tokenStore, + EnableDatagrams: true, + } + c := populateClientConfig(config, false) + Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute)) + Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour)) + Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) + Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) + Expect(c.ConnectionIDLength).To(Equal(13)) + Expect(c.StatelessResetKey).To(Equal([]byte("foobar"))) + Expect(c.TokenStore).To(Equal(tokenStore)) + Expect(c.EnableDatagrams).To(BeTrue()) + }) + + It("errors when the Config contains an invalid version", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + + version := protocol.VersionNumber(0x1234) + _, err := Dial(packetConn, nil, "localhost:1234", tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) + Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) + }) + + It("disables bidirectional streams", func() { + config := &Config{ + MaxIncomingStreams: -1, + MaxIncomingUniStreams: 4321, + } + c := populateClientConfig(config, false) + Expect(c.MaxIncomingStreams).To(BeZero()) + Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) + }) + + It("disables unidirectional streams", func() { + config := &Config{ + MaxIncomingStreams: 1234, + MaxIncomingUniStreams: -1, + } + c := populateClientConfig(config, false) + Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) + Expect(c.MaxIncomingUniStreams).To(BeZero()) + }) + + It("uses 0-byte connection IDs when dialing an address", func() { + c := populateClientConfig(&Config{}, true) + Expect(c.ConnectionIDLength).To(BeZero()) + }) + + It("fills in default values if options are not set in the Config", func() { + c := populateClientConfig(&Config{}, false) + Expect(c.Versions).To(Equal(protocol.SupportedVersions)) + Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) + Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) + }) + }) + + It("creates new connections with the right parameters", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(connID, gomock.Any()) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + + config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} + c := make(chan struct{}) + var cconn sendConn + var version protocol.VersionNumber + var conf *Config + newClientConnection = func( + connP sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + configP *Config, + _ *tls.Config, + _ protocol.PacketNumber, + _ bool, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + versionP protocol.VersionNumber, + ) quicConn { + cconn = connP + version = versionP + conf = configP + close(c) + // TODO: check connection IDs? + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run() + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn + } + _, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config) + Expect(err).ToNot(HaveOccurred()) + Eventually(c).Should(BeClosed()) + Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn)) + Expect(version).To(Equal(config.Versions[0])) + Expect(conf.Versions).To(Equal(config.Versions)) + }) + + It("creates a new connections after version negotiation", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(connID, gomock.Any()).Times(2) + manager.EXPECT().Destroy() + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + + var counter int + newClientConnection = func( + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + configP *Config, + _ *tls.Config, + pn protocol.PacketNumber, + _ bool, + hasNegotiatedVersion bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + versionP protocol.VersionNumber, + ) quicConn { + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + if counter == 0 { + Expect(pn).To(BeZero()) + Expect(hasNegotiatedVersion).To(BeFalse()) + conn.EXPECT().run().Return(&errCloseForRecreating{ + nextPacketNumber: 109, + nextVersion: 789, + }) + } else { + Expect(pn).To(Equal(protocol.PacketNumber(109))) + Expect(hasNegotiatedVersion).To(BeTrue()) + conn.EXPECT().run() + } + counter++ + return conn + } + + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + _, err := DialAddr("localhost:7890", tlsConf, config) + Expect(err).ToNot(HaveOccurred()) + Expect(counter).To(Equal(2)) + }) + }) +}) diff --git a/internal/quic-go/closed_conn.go b/internal/quic-go/closed_conn.go new file mode 100644 index 00000000..a97861b9 --- /dev/null +++ b/internal/quic-go/closed_conn.go @@ -0,0 +1,112 @@ +package quic + +import ( + "sync" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +// A closedLocalConn is a connection that we closed locally. +// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, +// with an exponential backoff. +type closedLocalConn struct { + conn sendConn + connClosePacket []byte + + closeOnce sync.Once + closeChan chan struct{} // is closed when the connection is closed or destroyed + + receivedPackets chan *receivedPacket + counter uint64 // number of packets received + + perspective protocol.Perspective + + logger utils.Logger +} + +var _ packetHandler = &closedLocalConn{} + +// newClosedLocalConn creates a new closedLocalConn and runs it. +func newClosedLocalConn( + conn sendConn, + connClosePacket []byte, + perspective protocol.Perspective, + logger utils.Logger, +) packetHandler { + s := &closedLocalConn{ + conn: conn, + connClosePacket: connClosePacket, + perspective: perspective, + logger: logger, + closeChan: make(chan struct{}), + receivedPackets: make(chan *receivedPacket, 64), + } + go s.run() + return s +} + +func (s *closedLocalConn) run() { + for { + select { + case p := <-s.receivedPackets: + s.handlePacketImpl(p) + case <-s.closeChan: + return + } + } +} + +func (s *closedLocalConn) handlePacket(p *receivedPacket) { + select { + case s.receivedPackets <- p: + default: + } +} + +func (s *closedLocalConn) handlePacketImpl(_ *receivedPacket) { + s.counter++ + // exponential backoff + // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving + for n := s.counter; n > 1; n = n / 2 { + if n%2 != 0 { + return + } + } + s.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", s.counter) + if err := s.conn.Write(s.connClosePacket); err != nil { + s.logger.Debugf("Error retransmitting CONNECTION_CLOSE: %s", err) + } +} + +func (s *closedLocalConn) shutdown() { + s.destroy(nil) +} + +func (s *closedLocalConn) destroy(error) { + s.closeOnce.Do(func() { + close(s.closeChan) + }) +} + +func (s *closedLocalConn) getPerspective() protocol.Perspective { + return s.perspective +} + +// A closedRemoteConn is a connection that was closed remotely. +// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. +// We can just ignore those packets. +type closedRemoteConn struct { + perspective protocol.Perspective +} + +var _ packetHandler = &closedRemoteConn{} + +func newClosedRemoteConn(pers protocol.Perspective) packetHandler { + return &closedRemoteConn{perspective: pers} +} + +func (s *closedRemoteConn) handlePacket(*receivedPacket) {} +func (s *closedRemoteConn) shutdown() {} +func (s *closedRemoteConn) destroy(error) {} +func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } diff --git a/internal/quic-go/closed_conn_test.go b/internal/quic-go/closed_conn_test.go new file mode 100644 index 00000000..330a9dad --- /dev/null +++ b/internal/quic-go/closed_conn_test.go @@ -0,0 +1,56 @@ +package quic + +import ( + "errors" + "time" + + "github.com/golang/mock/gomock" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Closed local connection", func() { + var ( + conn packetHandler + mconn *MockSendConn + ) + + BeforeEach(func() { + mconn = NewMockSendConn(mockCtrl) + conn = newClosedLocalConn(mconn, []byte("close"), protocol.PerspectiveClient, utils.DefaultLogger) + }) + + AfterEach(func() { + Eventually(areClosedConnsRunning).Should(BeFalse()) + }) + + It("tells its perspective", func() { + Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient)) + // stop the connection + conn.shutdown() + }) + + It("repeats the packet containing the CONNECTION_CLOSE frame", func() { + written := make(chan []byte) + mconn.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }).AnyTimes() + for i := 1; i <= 20; i++ { + conn.handlePacket(&receivedPacket{}) + if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 { + Eventually(written).Should(Receive(Equal([]byte("close")))) // receive the CONNECTION_CLOSE + } else { + Consistently(written, 10*time.Millisecond).Should(HaveLen(0)) + } + } + // stop the connection + conn.shutdown() + }) + + It("destroys connections", func() { + Eventually(areClosedConnsRunning).Should(BeTrue()) + conn.destroy(errors.New("destroy")) + Eventually(areClosedConnsRunning).Should(BeFalse()) + }) +}) diff --git a/internal/quic-go/config.go b/internal/quic-go/config.go new file mode 100644 index 00000000..8f444f5d --- /dev/null +++ b/internal/quic-go/config.go @@ -0,0 +1,124 @@ +package quic + +import ( + "errors" + "time" + + "github.com/imroc/req/v3/internal/quic-go/utils" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// Clone clones a Config +func (c *Config) Clone() *Config { + copy := *c + return © +} + +func (c *Config) handshakeTimeout() time.Duration { + return utils.MaxDuration(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout) +} + +func validateConfig(config *Config) error { + if config == nil { + return nil + } + if config.MaxIncomingStreams > 1<<60 { + return errors.New("invalid value for Config.MaxIncomingStreams") + } + if config.MaxIncomingUniStreams > 1<<60 { + return errors.New("invalid value for Config.MaxIncomingUniStreams") + } + return nil +} + +// populateServerConfig populates fields in the quic.Config with their default values, if none are set +// it may be called with nil +func populateServerConfig(config *Config) *Config { + config = populateConfig(config) + if config.ConnectionIDLength == 0 { + config.ConnectionIDLength = protocol.DefaultConnectionIDLength + } + if config.AcceptToken == nil { + config.AcceptToken = defaultAcceptToken + } + return config +} + +// populateClientConfig populates fields in the quic.Config with their default values, if none are set +// it may be called with nil +func populateClientConfig(config *Config, createdPacketConn bool) *Config { + config = populateConfig(config) + if config.ConnectionIDLength == 0 && !createdPacketConn { + config.ConnectionIDLength = protocol.DefaultConnectionIDLength + } + return config +} + +func populateConfig(config *Config) *Config { + if config == nil { + config = &Config{} + } + versions := config.Versions + if len(versions) == 0 { + versions = protocol.SupportedVersions + } + handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout + if config.HandshakeIdleTimeout != 0 { + handshakeIdleTimeout = config.HandshakeIdleTimeout + } + idleTimeout := protocol.DefaultIdleTimeout + if config.MaxIdleTimeout != 0 { + idleTimeout = config.MaxIdleTimeout + } + initialStreamReceiveWindow := config.InitialStreamReceiveWindow + if initialStreamReceiveWindow == 0 { + initialStreamReceiveWindow = protocol.DefaultInitialMaxStreamData + } + maxStreamReceiveWindow := config.MaxStreamReceiveWindow + if maxStreamReceiveWindow == 0 { + maxStreamReceiveWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow + } + initialConnectionReceiveWindow := config.InitialConnectionReceiveWindow + if initialConnectionReceiveWindow == 0 { + initialConnectionReceiveWindow = protocol.DefaultInitialMaxData + } + maxConnectionReceiveWindow := config.MaxConnectionReceiveWindow + if maxConnectionReceiveWindow == 0 { + maxConnectionReceiveWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow + } + maxIncomingStreams := config.MaxIncomingStreams + if maxIncomingStreams == 0 { + maxIncomingStreams = protocol.DefaultMaxIncomingStreams + } else if maxIncomingStreams < 0 { + maxIncomingStreams = 0 + } + maxIncomingUniStreams := config.MaxIncomingUniStreams + if maxIncomingUniStreams == 0 { + maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams + } else if maxIncomingUniStreams < 0 { + maxIncomingUniStreams = 0 + } + + return &Config{ + Versions: versions, + HandshakeIdleTimeout: handshakeIdleTimeout, + MaxIdleTimeout: idleTimeout, + AcceptToken: config.AcceptToken, + KeepAlivePeriod: config.KeepAlivePeriod, + InitialStreamReceiveWindow: initialStreamReceiveWindow, + MaxStreamReceiveWindow: maxStreamReceiveWindow, + InitialConnectionReceiveWindow: initialConnectionReceiveWindow, + MaxConnectionReceiveWindow: maxConnectionReceiveWindow, + AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, + MaxIncomingStreams: maxIncomingStreams, + MaxIncomingUniStreams: maxIncomingUniStreams, + ConnectionIDLength: config.ConnectionIDLength, + StatelessResetKey: config.StatelessResetKey, + TokenStore: config.TokenStore, + EnableDatagrams: config.EnableDatagrams, + DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, + DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, + Tracer: config.Tracer, + } +} diff --git a/internal/quic-go/config_test.go b/internal/quic-go/config_test.go new file mode 100644 index 00000000..1710bdab --- /dev/null +++ b/internal/quic-go/config_test.go @@ -0,0 +1,180 @@ +package quic + +import ( + "fmt" + "net" + "reflect" + "time" + + mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Config", func() { + Context("validating", func() { + It("validates a nil config", func() { + Expect(validateConfig(nil)).To(Succeed()) + }) + + It("validates a config with normal values", func() { + Expect(validateConfig(populateServerConfig(&Config{}))).To(Succeed()) + }) + + It("errors on too large values for MaxIncomingStreams", func() { + Expect(validateConfig(&Config{MaxIncomingStreams: 1<<60 + 1})).To(MatchError("invalid value for Config.MaxIncomingStreams")) + }) + + It("errors on too large values for MaxIncomingUniStreams", func() { + Expect(validateConfig(&Config{MaxIncomingUniStreams: 1<<60 + 1})).To(MatchError("invalid value for Config.MaxIncomingUniStreams")) + }) + }) + + configWithNonZeroNonFunctionFields := func() *Config { + c := &Config{} + v := reflect.ValueOf(c).Elem() + + typ := v.Type() + for i := 0; i < typ.NumField(); i++ { + f := v.Field(i) + if !f.CanSet() { + // unexported field; not cloned. + continue + } + + switch fn := typ.Field(i).Name; fn { + case "AcceptToken", "GetLogWriter", "AllowConnectionWindowIncrease": + // Can't compare functions. + case "Versions": + f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) + case "ConnectionIDLength": + f.Set(reflect.ValueOf(8)) + case "HandshakeIdleTimeout": + f.Set(reflect.ValueOf(time.Second)) + case "MaxIdleTimeout": + f.Set(reflect.ValueOf(time.Hour)) + case "TokenStore": + f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) + case "InitialStreamReceiveWindow": + f.Set(reflect.ValueOf(uint64(1234))) + case "MaxStreamReceiveWindow": + f.Set(reflect.ValueOf(uint64(9))) + case "InitialConnectionReceiveWindow": + f.Set(reflect.ValueOf(uint64(4321))) + case "MaxConnectionReceiveWindow": + f.Set(reflect.ValueOf(uint64(10))) + case "MaxIncomingStreams": + f.Set(reflect.ValueOf(int64(11))) + case "MaxIncomingUniStreams": + f.Set(reflect.ValueOf(int64(12))) + case "StatelessResetKey": + f.Set(reflect.ValueOf([]byte{1, 2, 3, 4})) + case "KeepAlivePeriod": + f.Set(reflect.ValueOf(time.Second)) + case "EnableDatagrams": + f.Set(reflect.ValueOf(true)) + case "DisableVersionNegotiationPackets": + f.Set(reflect.ValueOf(true)) + case "DisablePathMTUDiscovery": + f.Set(reflect.ValueOf(true)) + case "Tracer": + f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl))) + default: + Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) + } + } + return c + } + + It("uses 10s handshake timeout for short handshake idle timeouts", func() { + c := &Config{HandshakeIdleTimeout: time.Second} + Expect(c.handshakeTimeout()).To(Equal(protocol.DefaultHandshakeTimeout)) + }) + + It("uses twice the handshake idle timeouts for the handshake timeout, for long handshake idle timeouts", func() { + c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2} + Expect(c.handshakeTimeout()).To(Equal(11 * time.Second)) + }) + + Context("cloning", func() { + It("clones function fields", func() { + var calledAcceptToken, calledAllowConnectionWindowIncrease bool + c1 := &Config{ + AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, + AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, + } + c2 := c1.Clone() + c2.AcceptToken(&net.UDPAddr{}, &Token{}) + Expect(calledAcceptToken).To(BeTrue()) + c2.AllowConnectionWindowIncrease(nil, 1234) + Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) + }) + + It("clones non-function fields", func() { + c := configWithNonZeroNonFunctionFields() + Expect(c.Clone()).To(Equal(c)) + }) + + It("returns a copy", func() { + c1 := &Config{ + MaxIncomingStreams: 100, + AcceptToken: func(_ net.Addr, _ *Token) bool { return true }, + } + c2 := c1.Clone() + c2.MaxIncomingStreams = 200 + c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + + Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) + Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue()) + }) + }) + + Context("populating", func() { + It("populates function fields", func() { + var calledAcceptToken bool + c1 := &Config{ + AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, + } + c2 := populateConfig(c1) + c2.AcceptToken(&net.UDPAddr{}, &Token{}) + Expect(calledAcceptToken).To(BeTrue()) + }) + + It("copies non-function fields", func() { + c := configWithNonZeroNonFunctionFields() + Expect(populateConfig(c)).To(Equal(c)) + }) + + It("populates empty fields with default values", func() { + c := populateConfig(&Config{}) + Expect(c.Versions).To(Equal(protocol.SupportedVersions)) + Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) + Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData)) + Expect(c.MaxStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow)) + Expect(c.InitialConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxData)) + Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow)) + Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams)) + Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams)) + Expect(c.DisableVersionNegotiationPackets).To(BeFalse()) + Expect(c.DisablePathMTUDiscovery).To(BeFalse()) + }) + + It("populates empty fields with default values, for the server", func() { + c := populateServerConfig(&Config{}) + Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) + Expect(c.AcceptToken).ToNot(BeNil()) + }) + + It("sets a default connection ID length if we didn't create the conn, for the client", func() { + c := populateClientConfig(&Config{}, false) + Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) + }) + + It("doesn't set a default connection ID length if we created the conn, for the client", func() { + c := populateClientConfig(&Config{}, true) + Expect(c.ConnectionIDLength).To(BeZero()) + }) + }) +}) diff --git a/internal/quic-go/congestion/bandwidth.go b/internal/quic-go/congestion/bandwidth.go new file mode 100644 index 00000000..a6560980 --- /dev/null +++ b/internal/quic-go/congestion/bandwidth.go @@ -0,0 +1,25 @@ +package congestion + +import ( + "math" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// Bandwidth of a connection +type Bandwidth uint64 + +const infBandwidth Bandwidth = math.MaxUint64 + +const ( + // BitsPerSecond is 1 bit per second + BitsPerSecond Bandwidth = 1 + // BytesPerSecond is 1 byte per second + BytesPerSecond = 8 * BitsPerSecond +) + +// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta +func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth { + return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond +} diff --git a/internal/quic-go/congestion/bandwidth_test.go b/internal/quic-go/congestion/bandwidth_test.go new file mode 100644 index 00000000..03162747 --- /dev/null +++ b/internal/quic-go/congestion/bandwidth_test.go @@ -0,0 +1,14 @@ +package congestion + +import ( + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Bandwidth", func() { + It("converts from time delta", func() { + Expect(BandwidthFromDelta(1, time.Millisecond)).To(Equal(1000 * BytesPerSecond)) + }) +}) diff --git a/internal/quic-go/congestion/clock.go b/internal/quic-go/congestion/clock.go new file mode 100644 index 00000000..405fae70 --- /dev/null +++ b/internal/quic-go/congestion/clock.go @@ -0,0 +1,18 @@ +package congestion + +import "time" + +// A Clock returns the current time +type Clock interface { + Now() time.Time +} + +// DefaultClock implements the Clock interface using the Go stdlib clock. +type DefaultClock struct{} + +var _ Clock = DefaultClock{} + +// Now gets the current time +func (DefaultClock) Now() time.Time { + return time.Now() +} diff --git a/internal/quic-go/congestion/congestion_suite_test.go b/internal/quic-go/congestion/congestion_suite_test.go new file mode 100644 index 00000000..6a0f7ed7 --- /dev/null +++ b/internal/quic-go/congestion/congestion_suite_test.go @@ -0,0 +1,13 @@ +package congestion + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestCongestion(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Congestion Suite") +} diff --git a/internal/quic-go/congestion/cubic.go b/internal/quic-go/congestion/cubic.go new file mode 100644 index 00000000..acbb6bcc --- /dev/null +++ b/internal/quic-go/congestion/cubic.go @@ -0,0 +1,214 @@ +package congestion + +import ( + "math" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +// This cubic implementation is based on the one found in Chromiums's QUIC +// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}. + +// Constants based on TCP defaults. +// The following constants are in 2^10 fractions of a second instead of ms to +// allow a 10 shift right to divide. + +// 1024*1024^3 (first 1024 is from 0.100^3) +// where 0.100 is 100 ms which is the scaling round trip time. +const ( + cubeScale = 40 + cubeCongestionWindowScale = 410 + cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize + // TODO: when re-enabling cubic, make sure to use the actual packet size here + maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4) +) + +const defaultNumConnections = 1 + +// Default Cubic backoff factor +const beta float32 = 0.7 + +// Additional backoff factor when loss occurs in the concave part of the Cubic +// curve. This additional backoff factor is expected to give up bandwidth to +// new concurrent flows and speed up convergence. +const betaLastMax float32 = 0.85 + +// Cubic implements the cubic algorithm from TCP +type Cubic struct { + clock Clock + + // Number of connections to simulate. + numConnections int + + // Time when this cycle started, after last loss event. + epoch time.Time + + // Max congestion window used just before last loss event. + // Note: to improve fairness to other streams an additional back off is + // applied to this value if the new value is below our latest value. + lastMaxCongestionWindow protocol.ByteCount + + // Number of acked bytes since the cycle started (epoch). + ackedBytesCount protocol.ByteCount + + // TCP Reno equivalent congestion window in packets. + estimatedTCPcongestionWindow protocol.ByteCount + + // Origin point of cubic function. + originPointCongestionWindow protocol.ByteCount + + // Time to origin point of cubic function in 2^10 fractions of a second. + timeToOriginPoint uint32 + + // Last congestion window in packets computed by cubic function. + lastTargetCongestionWindow protocol.ByteCount +} + +// NewCubic returns a new Cubic instance +func NewCubic(clock Clock) *Cubic { + c := &Cubic{ + clock: clock, + numConnections: defaultNumConnections, + } + c.Reset() + return c +} + +// Reset is called after a timeout to reset the cubic state +func (c *Cubic) Reset() { + c.epoch = time.Time{} + c.lastMaxCongestionWindow = 0 + c.ackedBytesCount = 0 + c.estimatedTCPcongestionWindow = 0 + c.originPointCongestionWindow = 0 + c.timeToOriginPoint = 0 + c.lastTargetCongestionWindow = 0 +} + +func (c *Cubic) alpha() float32 { + // TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that + // beta here is a cwnd multiplier, and is equal to 1-beta from the paper. + // We derive the equivalent alpha for an N-connection emulation as: + b := c.beta() + return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b) +} + +func (c *Cubic) beta() float32 { + // kNConnectionBeta is the backoff factor after loss for our N-connection + // emulation, which emulates the effective backoff of an ensemble of N + // TCP-Reno connections on a single loss event. The effective multiplier is + // computed as: + return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) +} + +func (c *Cubic) betaLastMax() float32 { + // betaLastMax is the additional backoff factor after loss for our + // N-connection emulation, which emulates the additional backoff of + // an ensemble of N TCP-Reno connections on a single loss event. The + // effective multiplier is computed as: + return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections) +} + +// OnApplicationLimited is called on ack arrival when sender is unable to use +// the available congestion window. Resets Cubic state during quiescence. +func (c *Cubic) OnApplicationLimited() { + // When sender is not using the available congestion window, the window does + // not grow. But to be RTT-independent, Cubic assumes that the sender has been + // using the entire window during the time since the beginning of the current + // "epoch" (the end of the last loss recovery period). Since + // application-limited periods break this assumption, we reset the epoch when + // in such a period. This reset effectively freezes congestion window growth + // through application-limited periods and allows Cubic growth to continue + // when the entire window is being used. + c.epoch = time.Time{} +} + +// CongestionWindowAfterPacketLoss computes a new congestion window to use after +// a loss event. Returns the new congestion window in packets. The new +// congestion window is a multiplicative decrease of our current window. +func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount { + if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow { + // We never reached the old max, so assume we are competing with another + // flow. Use our extra back off factor to allow the other flow to go up. + c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow)) + } else { + c.lastMaxCongestionWindow = currentCongestionWindow + } + c.epoch = time.Time{} // Reset time. + return protocol.ByteCount(float32(currentCongestionWindow) * c.beta()) +} + +// CongestionWindowAfterAck computes a new congestion window to use after a received ACK. +// Returns the new congestion window in packets. The new congestion window +// follows a cubic function that depends on the time passed since last +// packet loss. +func (c *Cubic) CongestionWindowAfterAck( + ackedBytes protocol.ByteCount, + currentCongestionWindow protocol.ByteCount, + delayMin time.Duration, + eventTime time.Time, +) protocol.ByteCount { + c.ackedBytesCount += ackedBytes + + if c.epoch.IsZero() { + // First ACK after a loss event. + c.epoch = eventTime // Start of epoch. + c.ackedBytesCount = ackedBytes // Reset count. + // Reset estimated_tcp_congestion_window_ to be in sync with cubic. + c.estimatedTCPcongestionWindow = currentCongestionWindow + if c.lastMaxCongestionWindow <= currentCongestionWindow { + c.timeToOriginPoint = 0 + c.originPointCongestionWindow = currentCongestionWindow + } else { + c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow)))) + c.originPointCongestionWindow = c.lastMaxCongestionWindow + } + } + + // Change the time unit from microseconds to 2^10 fractions per second. Take + // the round trip time in account. This is done to allow us to use shift as a + // divide operator. + elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000) + + // Right-shifts of negative, signed numbers have implementation-dependent + // behavior, so force the offset to be positive, as is done in the kernel. + offset := int64(c.timeToOriginPoint) - elapsedTime + if offset < 0 { + offset = -offset + } + + deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale + var targetCongestionWindow protocol.ByteCount + if elapsedTime > int64(c.timeToOriginPoint) { + targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow + } else { + targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow + } + // Limit the CWND increase to half the acked bytes. + targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) + + // Increase the window by approximately Alpha * 1 MSS of bytes every + // time we ack an estimated tcp window of bytes. For small + // congestion windows (less than 25), the formula below will + // increase slightly slower than linearly per estimated tcp window + // of bytes. + c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow)) + c.ackedBytesCount = 0 + + // We have a new cubic congestion window. + c.lastTargetCongestionWindow = targetCongestionWindow + + // Compute target congestion_window based on cubic target and estimated TCP + // congestion_window, use highest (fastest). + if targetCongestionWindow < c.estimatedTCPcongestionWindow { + targetCongestionWindow = c.estimatedTCPcongestionWindow + } + return targetCongestionWindow +} + +// SetNumConnections sets the number of emulated connections +func (c *Cubic) SetNumConnections(n int) { + c.numConnections = n +} diff --git a/internal/quic-go/congestion/cubic_sender.go b/internal/quic-go/congestion/cubic_sender.go new file mode 100644 index 00000000..12074d90 --- /dev/null +++ b/internal/quic-go/congestion/cubic_sender.go @@ -0,0 +1,316 @@ +package congestion + +import ( + "fmt" + "time" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +const ( + // maxDatagramSize is the default maximum packet size used in the Linux TCP implementation. + // Used in QUIC for congestion window computations in bytes. + initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4) + maxBurstPackets = 3 + renoBeta = 0.7 // Reno backoff factor. + minCongestionWindowPackets = 2 + initialCongestionWindow = 32 +) + +type cubicSender struct { + hybridSlowStart HybridSlowStart + rttStats *utils.RTTStats + cubic *Cubic + pacer *pacer + clock Clock + + reno bool + + // Track the largest packet that has been sent. + largestSentPacketNumber protocol.PacketNumber + + // Track the largest packet that has been acked. + largestAckedPacketNumber protocol.PacketNumber + + // Track the largest packet number outstanding when a CWND cutback occurs. + largestSentAtLastCutback protocol.PacketNumber + + // Whether the last loss event caused us to exit slowstart. + // Used for stats collection of slowstartPacketsLost + lastCutbackExitedSlowstart bool + + // Congestion window in bytes. + congestionWindow protocol.ByteCount + + // Slow start congestion window in bytes, aka ssthresh. + slowStartThreshold protocol.ByteCount + + // ACK counter for the Reno implementation. + numAckedPackets uint64 + + initialCongestionWindow protocol.ByteCount + initialMaxCongestionWindow protocol.ByteCount + + maxDatagramSize protocol.ByteCount + + lastState logging.CongestionState + tracer logging.ConnectionTracer +} + +var ( + _ SendAlgorithm = &cubicSender{} + _ SendAlgorithmWithDebugInfos = &cubicSender{} +) + +// NewCubicSender makes a new cubic sender +func NewCubicSender( + clock Clock, + rttStats *utils.RTTStats, + initialMaxDatagramSize protocol.ByteCount, + reno bool, + tracer logging.ConnectionTracer, +) *cubicSender { + return newCubicSender( + clock, + rttStats, + reno, + initialMaxDatagramSize, + initialCongestionWindow*initialMaxDatagramSize, + protocol.MaxCongestionWindowPackets*initialMaxDatagramSize, + tracer, + ) +} + +func newCubicSender( + clock Clock, + rttStats *utils.RTTStats, + reno bool, + initialMaxDatagramSize, + initialCongestionWindow, + initialMaxCongestionWindow protocol.ByteCount, + tracer logging.ConnectionTracer, +) *cubicSender { + c := &cubicSender{ + rttStats: rttStats, + largestSentPacketNumber: protocol.InvalidPacketNumber, + largestAckedPacketNumber: protocol.InvalidPacketNumber, + largestSentAtLastCutback: protocol.InvalidPacketNumber, + initialCongestionWindow: initialCongestionWindow, + initialMaxCongestionWindow: initialMaxCongestionWindow, + congestionWindow: initialCongestionWindow, + slowStartThreshold: protocol.MaxByteCount, + cubic: NewCubic(clock), + clock: clock, + reno: reno, + tracer: tracer, + maxDatagramSize: initialMaxDatagramSize, + } + c.pacer = newPacer(c.BandwidthEstimate) + if c.tracer != nil { + c.lastState = logging.CongestionStateSlowStart + c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) + } + return c +} + +// TimeUntilSend returns when the next packet should be sent. +func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time { + return c.pacer.TimeUntilSend() +} + +func (c *cubicSender) HasPacingBudget() bool { + return c.pacer.Budget(c.clock.Now()) >= c.maxDatagramSize +} + +func (c *cubicSender) maxCongestionWindow() protocol.ByteCount { + return c.maxDatagramSize * protocol.MaxCongestionWindowPackets +} + +func (c *cubicSender) minCongestionWindow() protocol.ByteCount { + return c.maxDatagramSize * minCongestionWindowPackets +} + +func (c *cubicSender) OnPacketSent( + sentTime time.Time, + _ protocol.ByteCount, + packetNumber protocol.PacketNumber, + bytes protocol.ByteCount, + isRetransmittable bool, +) { + c.pacer.SentPacket(sentTime, bytes) + if !isRetransmittable { + return + } + c.largestSentPacketNumber = packetNumber + c.hybridSlowStart.OnPacketSent(packetNumber) +} + +func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool { + return bytesInFlight < c.GetCongestionWindow() +} + +func (c *cubicSender) InRecovery() bool { + return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback +} + +func (c *cubicSender) InSlowStart() bool { + return c.GetCongestionWindow() < c.slowStartThreshold +} + +func (c *cubicSender) GetCongestionWindow() protocol.ByteCount { + return c.congestionWindow +} + +func (c *cubicSender) MaybeExitSlowStart() { + if c.InSlowStart() && + c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) { + // exit slow start + c.slowStartThreshold = c.congestionWindow + c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) + } +} + +func (c *cubicSender) OnPacketAcked( + ackedPacketNumber protocol.PacketNumber, + ackedBytes protocol.ByteCount, + priorInFlight protocol.ByteCount, + eventTime time.Time, +) { + c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber) + if c.InRecovery() { + return + } + c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime) + if c.InSlowStart() { + c.hybridSlowStart.OnPacketAcked(ackedPacketNumber) + } +} + +func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { + // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets + // already sent should be treated as a single loss event, since it's expected. + if packetNumber <= c.largestSentAtLastCutback { + return + } + c.lastCutbackExitedSlowstart = c.InSlowStart() + c.maybeTraceStateChange(logging.CongestionStateRecovery) + + if c.reno { + c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta) + } else { + c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow) + } + if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd { + c.congestionWindow = minCwnd + } + c.slowStartThreshold = c.congestionWindow + c.largestSentAtLastCutback = c.largestSentPacketNumber + // reset packet count from congestion avoidance mode. We start + // counting again when we're out of recovery. + c.numAckedPackets = 0 +} + +// Called when we receive an ack. Normal TCP tracks how many packets one ack +// represents, but quic has a separate ack for each packet. +func (c *cubicSender) maybeIncreaseCwnd( + _ protocol.PacketNumber, + ackedBytes protocol.ByteCount, + priorInFlight protocol.ByteCount, + eventTime time.Time, +) { + // Do not increase the congestion window unless the sender is close to using + // the current window. + if !c.isCwndLimited(priorInFlight) { + c.cubic.OnApplicationLimited() + c.maybeTraceStateChange(logging.CongestionStateApplicationLimited) + return + } + if c.congestionWindow >= c.maxCongestionWindow() { + return + } + if c.InSlowStart() { + // TCP slow start, exponential growth, increase by one for each ACK. + c.congestionWindow += c.maxDatagramSize + c.maybeTraceStateChange(logging.CongestionStateSlowStart) + return + } + // Congestion avoidance + c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) + if c.reno { + // Classic Reno congestion avoidance. + c.numAckedPackets++ + if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) { + c.congestionWindow += c.maxDatagramSize + c.numAckedPackets = 0 + } + } else { + c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) + } +} + +func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool { + congestionWindow := c.GetCongestionWindow() + if bytesInFlight >= congestionWindow { + return true + } + availableBytes := congestionWindow - bytesInFlight + slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2 + return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize +} + +// BandwidthEstimate returns the current bandwidth estimate +func (c *cubicSender) BandwidthEstimate() Bandwidth { + srtt := c.rttStats.SmoothedRTT() + if srtt == 0 { + // If we haven't measured an rtt, the bandwidth estimate is unknown. + return infBandwidth + } + return BandwidthFromDelta(c.GetCongestionWindow(), srtt) +} + +// OnRetransmissionTimeout is called on an retransmission timeout +func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) { + c.largestSentAtLastCutback = protocol.InvalidPacketNumber + if !packetsRetransmitted { + return + } + c.hybridSlowStart.Restart() + c.cubic.Reset() + c.slowStartThreshold = c.congestionWindow / 2 + c.congestionWindow = c.minCongestionWindow() +} + +// OnConnectionMigration is called when the connection is migrated (?) +func (c *cubicSender) OnConnectionMigration() { + c.hybridSlowStart.Restart() + c.largestSentPacketNumber = protocol.InvalidPacketNumber + c.largestAckedPacketNumber = protocol.InvalidPacketNumber + c.largestSentAtLastCutback = protocol.InvalidPacketNumber + c.lastCutbackExitedSlowstart = false + c.cubic.Reset() + c.numAckedPackets = 0 + c.congestionWindow = c.initialCongestionWindow + c.slowStartThreshold = c.initialMaxCongestionWindow +} + +func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { + if c.tracer == nil || new == c.lastState { + return + } + c.tracer.UpdatedCongestionState(new) + c.lastState = new +} + +func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) { + if s < c.maxDatagramSize { + panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s)) + } + cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow() + c.maxDatagramSize = s + if cwndIsMinCwnd { + c.congestionWindow = c.minCongestionWindow() + } + c.pacer.SetMaxDatagramSize(s) +} diff --git a/internal/quic-go/congestion/cubic_sender_test.go b/internal/quic-go/congestion/cubic_sender_test.go new file mode 100644 index 00000000..cddab314 --- /dev/null +++ b/internal/quic-go/congestion/cubic_sender_test.go @@ -0,0 +1,526 @@ +package congestion + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +const ( + initialCongestionWindowPackets = 10 + defaultWindowTCP = protocol.ByteCount(initialCongestionWindowPackets) * maxDatagramSize +) + +type mockClock time.Time + +func (c *mockClock) Now() time.Time { + return time.Time(*c) +} + +func (c *mockClock) Advance(d time.Duration) { + *c = mockClock(time.Time(*c).Add(d)) +} + +const MaxCongestionWindow protocol.ByteCount = 200 * maxDatagramSize + +var _ = Describe("Cubic Sender", func() { + var ( + sender *cubicSender + clock mockClock + bytesInFlight protocol.ByteCount + packetNumber protocol.PacketNumber + ackedPacketNumber protocol.PacketNumber + rttStats *utils.RTTStats + ) + + BeforeEach(func() { + bytesInFlight = 0 + packetNumber = 1 + ackedPacketNumber = 0 + clock = mockClock{} + rttStats = utils.NewRTTStats() + sender = newCubicSender( + &clock, + rttStats, + true, /*reno*/ + protocol.InitialPacketSizeIPv4, + initialCongestionWindowPackets*maxDatagramSize, + MaxCongestionWindow, + nil, + ) + }) + + SendAvailableSendWindowLen := func(packetLength protocol.ByteCount) int { + var packetsSent int + for sender.CanSend(bytesInFlight) { + sender.OnPacketSent(clock.Now(), bytesInFlight, packetNumber, packetLength, true) + packetNumber++ + packetsSent++ + bytesInFlight += packetLength + } + return packetsSent + } + + // Normal is that TCP acks every other segment. + AckNPackets := func(n int) { + rttStats.UpdateRTT(60*time.Millisecond, 0, clock.Now()) + sender.MaybeExitSlowStart() + for i := 0; i < n; i++ { + ackedPacketNumber++ + sender.OnPacketAcked(ackedPacketNumber, maxDatagramSize, bytesInFlight, clock.Now()) + } + bytesInFlight -= protocol.ByteCount(n) * maxDatagramSize + clock.Advance(time.Millisecond) + } + + LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) { + for i := 0; i < n; i++ { + ackedPacketNumber++ + sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight) + } + bytesInFlight -= protocol.ByteCount(n) * packetLength + } + + // Does not increment acked_packet_number_. + LosePacket := func(number protocol.PacketNumber) { + sender.OnPacketLost(number, maxDatagramSize, bytesInFlight) + bytesInFlight -= maxDatagramSize + } + + SendAvailableSendWindow := func() int { return SendAvailableSendWindowLen(maxDatagramSize) } + LoseNPackets := func(n int) { LoseNPacketsLen(n, maxDatagramSize) } + + It("has the right values at startup", func() { + // At startup make sure we are at the default. + Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) + // Make sure we can send. + Expect(sender.TimeUntilSend(0)).To(BeZero()) + Expect(sender.CanSend(bytesInFlight)).To(BeTrue()) + // And that window is un-affected. + Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) + + // Fill the send window with data, then verify that we can't send. + SendAvailableSendWindow() + Expect(sender.CanSend(bytesInFlight)).To(BeFalse()) + }) + + It("paces", func() { + rttStats.UpdateRTT(10*time.Millisecond, 0, time.Now()) + clock.Advance(time.Hour) + // Fill the send window with data, then verify that we can't send. + SendAvailableSendWindow() + AckNPackets(1) + delay := sender.TimeUntilSend(bytesInFlight) + Expect(delay).ToNot(BeZero()) + Expect(delay).ToNot(Equal(utils.InfDuration)) + }) + + It("application limited slow start", func() { + // Send exactly 10 packets and ensure the CWND ends at 14 packets. + const numberOfAcks = 5 + // At startup make sure we can send. + Expect(sender.CanSend(0)).To(BeTrue()) + Expect(sender.TimeUntilSend(0)).To(BeZero()) + + SendAvailableSendWindow() + for i := 0; i < numberOfAcks; i++ { + AckNPackets(2) + } + bytesToSend := sender.GetCongestionWindow() + // It's expected 2 acks will arrive when the bytes_in_flight are greater than + // half the CWND. + Expect(bytesToSend).To(Equal(defaultWindowTCP + maxDatagramSize*2*2)) + }) + + It("exponential slow start", func() { + const numberOfAcks = 20 + // At startup make sure we can send. + Expect(sender.CanSend(0)).To(BeTrue()) + Expect(sender.TimeUntilSend(0)).To(BeZero()) + Expect(sender.BandwidthEstimate()).To(Equal(infBandwidth)) + // Make sure we can send. + Expect(sender.TimeUntilSend(0)).To(BeZero()) + + for i := 0; i < numberOfAcks; i++ { + // Send our full send window. + SendAvailableSendWindow() + AckNPackets(2) + } + cwnd := sender.GetCongestionWindow() + Expect(cwnd).To(Equal(defaultWindowTCP + maxDatagramSize*2*numberOfAcks)) + Expect(sender.BandwidthEstimate()).To(Equal(BandwidthFromDelta(cwnd, rttStats.SmoothedRTT()))) + }) + + It("slow start packet loss", func() { + const numberOfAcks = 10 + for i := 0; i < numberOfAcks; i++ { + // Send our full send window. + SendAvailableSendWindow() + AckNPackets(2) + } + SendAvailableSendWindow() + expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + // Lose a packet to exit slow start. + LoseNPackets(1) + packetsInRecoveryWindow := expectedSendWindow / maxDatagramSize + + // We should now have fallen out of slow start with a reduced window. + expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + // Recovery phase. We need to ack every packet in the recovery window before + // we exit recovery. + numberOfPacketsInWindow := expectedSendWindow / maxDatagramSize + AckNPackets(int(packetsInRecoveryWindow)) + SendAvailableSendWindow() + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + // We need to ack an entire window before we increase CWND by 1. + AckNPackets(int(numberOfPacketsInWindow) - 2) + SendAvailableSendWindow() + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + // Next ack should increase cwnd by 1. + AckNPackets(1) + expectedSendWindow += maxDatagramSize + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + // Now RTO and ensure slow start gets reset. + Expect(sender.hybridSlowStart.Started()).To(BeTrue()) + sender.OnRetransmissionTimeout(true) + Expect(sender.hybridSlowStart.Started()).To(BeFalse()) + }) + + It("slow start packet loss PRR", func() { + // Test based on the first example in RFC6937. + // Ack 10 packets in 5 acks to raise the CWND to 20, as in the example. + const numberOfAcks = 5 + for i := 0; i < numberOfAcks; i++ { + // Send our full send window. + SendAvailableSendWindow() + AckNPackets(2) + } + SendAvailableSendWindow() + expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + LoseNPackets(1) + + // We should now have fallen out of slow start with a reduced window. + sendWindowBeforeLoss := expectedSendWindow + expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + // Testing TCP proportional rate reduction. + // We should send packets paced over the received acks for the remaining + // outstanding packets. The number of packets before we exit recovery is the + // original CWND minus the packet that has been lost and the one which + // triggered the loss. + remainingPacketsInRecovery := sendWindowBeforeLoss/maxDatagramSize - 2 + + for i := protocol.ByteCount(0); i < remainingPacketsInRecovery; i++ { + AckNPackets(1) + SendAvailableSendWindow() + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + } + + // We need to ack another window before we increase CWND by 1. + numberOfPacketsInWindow := expectedSendWindow / maxDatagramSize + for i := protocol.ByteCount(0); i < numberOfPacketsInWindow; i++ { + AckNPackets(1) + Expect(SendAvailableSendWindow()).To(Equal(1)) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + } + + AckNPackets(1) + expectedSendWindow += maxDatagramSize + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + }) + + It("slow start burst packet loss PRR", func() { + // Test based on the second example in RFC6937, though we also implement + // forward acknowledgements, so the first two incoming acks will trigger + // PRR immediately. + // Ack 20 packets in 10 acks to raise the CWND to 30. + const numberOfAcks = 10 + for i := 0; i < numberOfAcks; i++ { + // Send our full send window. + SendAvailableSendWindow() + AckNPackets(2) + } + SendAvailableSendWindow() + expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + // Lose one more than the congestion window reduction, so that after loss, + // bytes_in_flight is lesser than the congestion window. + sendWindowAfterLoss := protocol.ByteCount(renoBeta * float32(expectedSendWindow)) + numPacketsToLose := (expectedSendWindow-sendWindowAfterLoss)/maxDatagramSize + 1 + LoseNPackets(int(numPacketsToLose)) + // Immediately after the loss, ensure at least one packet can be sent. + // Losses without subsequent acks can occur with timer based loss detection. + Expect(sender.CanSend(bytesInFlight)).To(BeTrue()) + AckNPackets(1) + + // We should now have fallen out of slow start with a reduced window. + expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + // Only 2 packets should be allowed to be sent, per PRR-SSRB + Expect(SendAvailableSendWindow()).To(Equal(2)) + + // Ack the next packet, which triggers another loss. + LoseNPackets(1) + AckNPackets(1) + + // Send 2 packets to simulate PRR-SSRB. + Expect(SendAvailableSendWindow()).To(Equal(2)) + + // Ack the next packet, which triggers another loss. + LoseNPackets(1) + AckNPackets(1) + + // Send 2 packets to simulate PRR-SSRB. + Expect(SendAvailableSendWindow()).To(Equal(2)) + + // Exit recovery and return to sending at the new rate. + for i := 0; i < numberOfAcks; i++ { + AckNPackets(1) + Expect(SendAvailableSendWindow()).To(Equal(1)) + } + }) + + It("RTO congestion window", func() { + Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) + Expect(sender.slowStartThreshold).To(Equal(protocol.MaxByteCount)) + + // Expect the window to decrease to the minimum once the RTO fires + // and slow start threshold to be set to 1/2 of the CWND. + sender.OnRetransmissionTimeout(true) + Expect(sender.GetCongestionWindow()).To(Equal(2 * maxDatagramSize)) + Expect(sender.slowStartThreshold).To(Equal(5 * maxDatagramSize)) + }) + + It("RTO congestion window no retransmission", func() { + Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) + + // Expect the window to remain unchanged if the RTO fires but no + // packets are retransmitted. + sender.OnRetransmissionTimeout(false) + Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) + }) + + It("tcp cubic reset epoch on quiescence", func() { + const maxCongestionWindow = 50 + const maxCongestionWindowBytes = maxCongestionWindow * maxDatagramSize + sender = newCubicSender(&clock, rttStats, false, protocol.InitialPacketSizeIPv4, initialCongestionWindowPackets*maxDatagramSize, maxCongestionWindowBytes, nil) + + numSent := SendAvailableSendWindow() + + // Make sure we fall out of slow start. + savedCwnd := sender.GetCongestionWindow() + LoseNPackets(1) + Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow())) + + // Ack the rest of the outstanding packets to get out of recovery. + for i := 1; i < numSent; i++ { + AckNPackets(1) + } + Expect(bytesInFlight).To(BeZero()) + + // Send a new window of data and ack all; cubic growth should occur. + savedCwnd = sender.GetCongestionWindow() + numSent = SendAvailableSendWindow() + for i := 0; i < numSent; i++ { + AckNPackets(1) + } + Expect(savedCwnd).To(BeNumerically("<", sender.GetCongestionWindow())) + Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow())) + Expect(bytesInFlight).To(BeZero()) + + // Quiescent time of 100 seconds + clock.Advance(100 * time.Second) + + // Send new window of data and ack one packet. Cubic epoch should have + // been reset; ensure cwnd increase is not dramatic. + savedCwnd = sender.GetCongestionWindow() + SendAvailableSendWindow() + AckNPackets(1) + Expect(savedCwnd).To(BeNumerically("~", sender.GetCongestionWindow(), maxDatagramSize)) + Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow())) + }) + + It("multiple losses in one window", func() { + SendAvailableSendWindow() + initialWindow := sender.GetCongestionWindow() + LosePacket(ackedPacketNumber + 1) + postLossWindow := sender.GetCongestionWindow() + Expect(initialWindow).To(BeNumerically(">", postLossWindow)) + LosePacket(ackedPacketNumber + 3) + Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow)) + LosePacket(packetNumber - 1) + Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow)) + + // Lose a later packet and ensure the window decreases. + LosePacket(packetNumber) + Expect(postLossWindow).To(BeNumerically(">", sender.GetCongestionWindow())) + }) + + It("1 connection congestion avoidance at end of recovery", func() { + // Ack 10 packets in 5 acks to raise the CWND to 20. + const numberOfAcks = 5 + for i := 0; i < numberOfAcks; i++ { + // Send our full send window. + SendAvailableSendWindow() + AckNPackets(2) + } + SendAvailableSendWindow() + expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + LoseNPackets(1) + + // We should now have fallen out of slow start with a reduced window. + expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + // No congestion window growth should occur in recovery phase, i.e., until the + // currently outstanding 20 packets are acked. + for i := 0; i < 10; i++ { + // Send our full send window. + SendAvailableSendWindow() + Expect(sender.InRecovery()).To(BeTrue()) + AckNPackets(2) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + } + Expect(sender.InRecovery()).To(BeFalse()) + + // Out of recovery now. Congestion window should not grow during RTT. + for i := protocol.ByteCount(0); i < expectedSendWindow/maxDatagramSize-2; i += 2 { + // Send our full send window. + SendAvailableSendWindow() + AckNPackets(2) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + } + + // Next ack should cause congestion window to grow by 1MSS. + SendAvailableSendWindow() + AckNPackets(2) + expectedSendWindow += maxDatagramSize + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + }) + + It("no PRR", func() { + SendAvailableSendWindow() + LoseNPackets(9) + AckNPackets(1) + + Expect(sender.GetCongestionWindow()).To(Equal(protocol.ByteCount(renoBeta * float32(defaultWindowTCP)))) + windowInPackets := renoBeta * float32(defaultWindowTCP) / float32(maxDatagramSize) + numSent := SendAvailableSendWindow() + Expect(numSent).To(BeEquivalentTo(windowInPackets)) + }) + + It("reset after connection migration", func() { + Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) + Expect(sender.slowStartThreshold).To(Equal(protocol.MaxByteCount)) + + // Starts with slow start. + const numberOfAcks = 10 + for i := 0; i < numberOfAcks; i++ { + // Send our full send window. + SendAvailableSendWindow() + AckNPackets(2) + } + SendAvailableSendWindow() + expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + + // Loses a packet to exit slow start. + LoseNPackets(1) + + // We should now have fallen out of slow start with a reduced window. Slow + // start threshold is also updated. + expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) + Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) + Expect(sender.slowStartThreshold).To(Equal(expectedSendWindow)) + + // Resets cwnd and slow start threshold on connection migrations. + sender.OnConnectionMigration() + Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) + Expect(sender.slowStartThreshold).To(Equal(MaxCongestionWindow)) + Expect(sender.hybridSlowStart.Started()).To(BeFalse()) + }) + + It("slow starts up to the maximum congestion window", func() { + const initialMaxCongestionWindow = protocol.MaxCongestionWindowPackets * initialMaxDatagramSize + sender = newCubicSender(&clock, rttStats, true, protocol.InitialPacketSizeIPv4, initialCongestionWindowPackets*maxDatagramSize, initialMaxCongestionWindow, nil) + + for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { + sender.MaybeExitSlowStart() + sender.OnPacketAcked(protocol.PacketNumber(i), 1350, sender.GetCongestionWindow(), clock.Now()) + } + Expect(sender.GetCongestionWindow()).To(Equal(initialMaxCongestionWindow)) + }) + + It("doesn't allow reductions of the maximum packet size", func() { + Expect(func() { sender.SetMaxDatagramSize(initialMaxDatagramSize - 1) }).To(Panic()) + }) + + It("slow starts up to maximum congestion window, if larger packets are sent", func() { + const initialMaxCongestionWindow = protocol.MaxCongestionWindowPackets * initialMaxDatagramSize + sender = newCubicSender(&clock, rttStats, true, protocol.InitialPacketSizeIPv4, initialCongestionWindowPackets*maxDatagramSize, initialMaxCongestionWindow, nil) + const packetSize = initialMaxDatagramSize + 100 + sender.SetMaxDatagramSize(packetSize) + for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { + sender.OnPacketAcked(protocol.PacketNumber(i), packetSize, sender.GetCongestionWindow(), clock.Now()) + } + const maxCwnd = protocol.MaxCongestionWindowPackets * packetSize + Expect(sender.GetCongestionWindow()).To(And( + BeNumerically(">", maxCwnd), + BeNumerically("<=", maxCwnd+packetSize), + )) + }) + + It("limit cwnd increase in congestion avoidance", func() { + // Enable Cubic. + sender = newCubicSender(&clock, rttStats, false, protocol.InitialPacketSizeIPv4, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow, nil) + numSent := SendAvailableSendWindow() + + // Make sure we fall out of slow start. + savedCwnd := sender.GetCongestionWindow() + LoseNPackets(1) + Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow())) + + // Ack the rest of the outstanding packets to get out of recovery. + for i := 1; i < numSent; i++ { + AckNPackets(1) + } + Expect(bytesInFlight).To(BeZero()) + + savedCwnd = sender.GetCongestionWindow() + SendAvailableSendWindow() + + // Ack packets until the CWND increases. + for sender.GetCongestionWindow() == savedCwnd { + AckNPackets(1) + SendAvailableSendWindow() + } + // Bytes in flight may be larger than the CWND if the CWND isn't an exact + // multiple of the packet sizes being sent. + Expect(bytesInFlight).To(BeNumerically(">=", sender.GetCongestionWindow())) + savedCwnd = sender.GetCongestionWindow() + + // Advance time 2 seconds waiting for an ack. + clock.Advance(2 * time.Second) + + // Ack two packets. The CWND should increase by only one packet. + AckNPackets(2) + Expect(sender.GetCongestionWindow()).To(Equal(savedCwnd + maxDatagramSize)) + }) +}) diff --git a/internal/quic-go/congestion/cubic_test.go b/internal/quic-go/congestion/cubic_test.go new file mode 100644 index 00000000..e2fc5d33 --- /dev/null +++ b/internal/quic-go/congestion/cubic_test.go @@ -0,0 +1,239 @@ +package congestion + +import ( + "math" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +const ( + numConnections uint32 = 2 + nConnectionBeta float32 = (float32(numConnections) - 1 + beta) / float32(numConnections) + nConnectionBetaLastMax float32 = (float32(numConnections) - 1 + betaLastMax) / float32(numConnections) + nConnectionAlpha float32 = 3 * float32(numConnections) * float32(numConnections) * (1 - nConnectionBeta) / (1 + nConnectionBeta) + maxCubicTimeInterval = 30 * time.Millisecond +) + +var _ = Describe("Cubic", func() { + var ( + clock mockClock + cubic *Cubic + ) + + BeforeEach(func() { + clock = mockClock{} + cubic = NewCubic(&clock) + cubic.SetNumConnections(int(numConnections)) + }) + + renoCwnd := func(currentCwnd protocol.ByteCount) protocol.ByteCount { + return currentCwnd + protocol.ByteCount(float32(maxDatagramSize)*nConnectionAlpha*float32(maxDatagramSize)/float32(currentCwnd)) + } + + cubicConvexCwnd := func(initialCwnd protocol.ByteCount, rtt, elapsedTime time.Duration) protocol.ByteCount { + offset := protocol.ByteCount((elapsedTime+rtt)/time.Microsecond) << 10 / 1000000 + deltaCongestionWindow := 410 * offset * offset * offset * maxDatagramSize >> 40 + return initialCwnd + deltaCongestionWindow + } + + It("works above origin (with tighter bounds)", func() { + // Convex growth. + const rttMin = 100 * time.Millisecond + const rttMinS = float32(rttMin/time.Millisecond) / 1000.0 + currentCwnd := 10 * maxDatagramSize + initialCwnd := currentCwnd + + clock.Advance(time.Millisecond) + initialTime := clock.Now() + expectedFirstCwnd := renoCwnd(currentCwnd) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, initialTime) + Expect(expectedFirstCwnd).To(Equal(currentCwnd)) + + // Normal TCP phase. + // The maximum number of expected reno RTTs can be calculated by + // finding the point where the cubic curve and the reno curve meet. + maxRenoRtts := int(math.Sqrt(float64(nConnectionAlpha/(0.4*rttMinS*rttMinS*rttMinS))) - 2) + for i := 0; i < maxRenoRtts; i++ { + // Alternatively, we expect it to increase by one, every time we + // receive current_cwnd/Alpha acks back. (This is another way of + // saying we expect cwnd to increase by approximately Alpha once + // we receive current_cwnd number ofacks back). + numAcksThisEpoch := int(float32(currentCwnd/maxDatagramSize) / nConnectionAlpha) + + initialCwndThisEpoch := currentCwnd + for n := 0; n < numAcksThisEpoch; n++ { + // Call once per ACK. + expectedNextCwnd := renoCwnd(currentCwnd) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + Expect(currentCwnd).To(Equal(expectedNextCwnd)) + } + // Our byte-wise Reno implementation is an estimate. We expect + // the cwnd to increase by approximately one MSS every + // cwnd/kDefaultTCPMSS/Alpha acks, but it may be off by as much as + // half a packet for smaller values of current_cwnd. + cwndChangeThisEpoch := currentCwnd - initialCwndThisEpoch + Expect(cwndChangeThisEpoch).To(BeNumerically("~", maxDatagramSize, maxDatagramSize/2)) + clock.Advance(100 * time.Millisecond) + } + + for i := 0; i < 54; i++ { + maxAcksThisEpoch := currentCwnd / maxDatagramSize + interval := time.Duration(100*1000/maxAcksThisEpoch) * time.Microsecond + for n := 0; n < int(maxAcksThisEpoch); n++ { + clock.Advance(interval) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) + // If we allow per-ack updates, every update is a small cubic update. + Expect(currentCwnd).To(Equal(expectedCwnd)) + } + } + expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + Expect(currentCwnd).To(Equal(expectedCwnd)) + }) + + It("works above the origin with fine grained cubing", func() { + // Start the test with an artificially large cwnd to prevent Reno + // from over-taking cubic. + currentCwnd := 1000 * maxDatagramSize + initialCwnd := currentCwnd + rttMin := 100 * time.Millisecond + clock.Advance(time.Millisecond) + initialTime := clock.Now() + + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + clock.Advance(600 * time.Millisecond) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + + // We expect the algorithm to perform only non-zero, fine-grained cubic + // increases on every ack in this case. + for i := 0; i < 100; i++ { + clock.Advance(10 * time.Millisecond) + expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) + nextCwnd := cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + // Make sure we are performing cubic increases. + Expect(nextCwnd).To(Equal(expectedCwnd)) + // Make sure that these are non-zero, less-than-packet sized increases. + Expect(nextCwnd).To(BeNumerically(">", currentCwnd)) + cwndDelta := nextCwnd - currentCwnd + Expect(maxDatagramSize / 10).To(BeNumerically(">", cwndDelta)) + currentCwnd = nextCwnd + } + }) + + It("handles per ack updates", func() { + // Start the test with a large cwnd and RTT, to force the first + // increase to be a cubic increase. + initialCwndPackets := 150 + currentCwnd := protocol.ByteCount(initialCwndPackets) * maxDatagramSize + rttMin := 350 * time.Millisecond + + // Initialize the epoch + clock.Advance(time.Millisecond) + // Keep track of the growth of the reno-equivalent cwnd. + rCwnd := renoCwnd(currentCwnd) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + initialCwnd := currentCwnd + + // Simulate the return of cwnd packets in less than + // MaxCubicInterval() time. + maxAcks := int(float32(initialCwndPackets) / nConnectionAlpha) + interval := maxCubicTimeInterval / time.Duration(maxAcks+1) + + // In this scenario, the first increase is dictated by the cubic + // equation, but it is less than one byte, so the cwnd doesn't + // change. Normally, without per-ack increases, any cwnd plateau + // will cause the cwnd to be pinned for MaxCubicTimeInterval(). If + // we enable per-ack updates, the cwnd will continue to grow, + // regardless of the temporary plateau. + clock.Advance(interval) + rCwnd = renoCwnd(rCwnd) + Expect(cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())).To(Equal(currentCwnd)) + for i := 1; i < maxAcks; i++ { + clock.Advance(interval) + nextCwnd := cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + rCwnd = renoCwnd(rCwnd) + // The window shoud increase on every ack. + Expect(nextCwnd).To(BeNumerically(">", currentCwnd)) + Expect(nextCwnd).To(Equal(rCwnd)) + currentCwnd = nextCwnd + } + + // After all the acks are returned from the epoch, we expect the + // cwnd to have increased by nearly one packet. (Not exactly one + // packet, because our byte-wise Reno algorithm is always a slight + // under-estimation). Without per-ack updates, the current_cwnd + // would otherwise be unchanged. + minimumExpectedIncrease := maxDatagramSize * 9 / 10 + Expect(currentCwnd).To(BeNumerically(">", initialCwnd+minimumExpectedIncrease)) + }) + + It("handles loss events", func() { + rttMin := 100 * time.Millisecond + currentCwnd := 422 * maxDatagramSize + expectedCwnd := renoCwnd(currentCwnd) + // Initialize the state. + clock.Advance(time.Millisecond) + Expect(cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd)) + + // On the first loss, the last max congestion window is set to the + // congestion window before the loss. + preLossCwnd := currentCwnd + Expect(cubic.lastMaxCongestionWindow).To(BeZero()) + expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) + Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) + Expect(cubic.lastMaxCongestionWindow).To(Equal(preLossCwnd)) + currentCwnd = expectedCwnd + + // On the second loss, the current congestion window has not yet + // reached the last max congestion window. The last max congestion + // window will be reduced by an additional backoff factor to allow + // for competition. + preLossCwnd = currentCwnd + expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) + Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) + currentCwnd = expectedCwnd + Expect(preLossCwnd).To(BeNumerically(">", cubic.lastMaxCongestionWindow)) + expectedLastMax := protocol.ByteCount(float32(preLossCwnd) * nConnectionBetaLastMax) + Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax)) + Expect(expectedCwnd).To(BeNumerically("<", cubic.lastMaxCongestionWindow)) + // Simulate an increase, and check that we are below the origin. + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + Expect(cubic.lastMaxCongestionWindow).To(BeNumerically(">", currentCwnd)) + + // On the final loss, simulate the condition where the congestion + // window had a chance to grow nearly to the last congestion window. + currentCwnd = cubic.lastMaxCongestionWindow - 1 + preLossCwnd = currentCwnd + expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) + Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) + expectedLastMax = preLossCwnd + Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax)) + }) + + It("works below origin", func() { + // Concave growth. + rttMin := 100 * time.Millisecond + currentCwnd := 422 * maxDatagramSize + expectedCwnd := renoCwnd(currentCwnd) + // Initialize the state. + clock.Advance(time.Millisecond) + Expect(cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd)) + + expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) + Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) + currentCwnd = expectedCwnd + // First update after loss to initialize the epoch. + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + // Cubic phase. + for i := 0; i < 40; i++ { + clock.Advance(100 * time.Millisecond) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + } + expectedCwnd = 553632 * maxDatagramSize / 1460 + Expect(currentCwnd).To(Equal(expectedCwnd)) + }) +}) diff --git a/internal/quic-go/congestion/hybrid_slow_start.go b/internal/quic-go/congestion/hybrid_slow_start.go new file mode 100644 index 00000000..035bc0da --- /dev/null +++ b/internal/quic-go/congestion/hybrid_slow_start.go @@ -0,0 +1,113 @@ +package congestion + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +// Note(pwestin): the magic clamping numbers come from the original code in +// tcp_cubic.c. +const hybridStartLowWindow = protocol.ByteCount(16) + +// Number of delay samples for detecting the increase of delay. +const hybridStartMinSamples = uint32(8) + +// Exit slow start if the min rtt has increased by more than 1/8th. +const hybridStartDelayFactorExp = 3 // 2^3 = 8 +// The original paper specifies 2 and 8ms, but those have changed over time. +const ( + hybridStartDelayMinThresholdUs = int64(4000) + hybridStartDelayMaxThresholdUs = int64(16000) +) + +// HybridSlowStart implements the TCP hybrid slow start algorithm +type HybridSlowStart struct { + endPacketNumber protocol.PacketNumber + lastSentPacketNumber protocol.PacketNumber + started bool + currentMinRTT time.Duration + rttSampleCount uint32 + hystartFound bool +} + +// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase. +func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) { + s.endPacketNumber = lastSent + s.currentMinRTT = 0 + s.rttSampleCount = 0 + s.started = true +} + +// IsEndOfRound returns true if this ack is the last packet number of our current slow start round. +func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool { + return s.endPacketNumber < ack +} + +// ShouldExitSlowStart should be called on every new ack frame, since a new +// RTT measurement can be made then. +// rtt: the RTT for this ack packet. +// minRTT: is the lowest delay (RTT) we have seen during the session. +// congestionWindow: the congestion window in packets. +func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool { + if !s.started { + // Time to start the hybrid slow start. + s.StartReceiveRound(s.lastSentPacketNumber) + } + if s.hystartFound { + return true + } + // Second detection parameter - delay increase detection. + // Compare the minimum delay (s.currentMinRTT) of the current + // burst of packets relative to the minimum delay during the session. + // Note: we only look at the first few(8) packets in each burst, since we + // only want to compare the lowest RTT of the burst relative to previous + // bursts. + s.rttSampleCount++ + if s.rttSampleCount <= hybridStartMinSamples { + if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT { + s.currentMinRTT = latestRTT + } + } + // We only need to check this once per round. + if s.rttSampleCount == hybridStartMinSamples { + // Divide minRTT by 8 to get a rtt increase threshold for exiting. + minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp) + // Ensure the rtt threshold is never less than 2ms or more than 16ms. + minRTTincreaseThresholdUs = utils.MinInt64(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) + minRTTincreaseThreshold := time.Duration(utils.MaxInt64(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond + + if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) { + s.hystartFound = true + } + } + // Exit from slow start if the cwnd is greater than 16 and + // increasing delay is found. + return congestionWindow >= hybridStartLowWindow && s.hystartFound +} + +// OnPacketSent is called when a packet was sent +func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) { + s.lastSentPacketNumber = packetNumber +} + +// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end +// the round when the final packet of the burst is received and start it on +// the next incoming ack. +func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) { + if s.IsEndOfRound(ackedPacketNumber) { + s.started = false + } +} + +// Started returns true if started +func (s *HybridSlowStart) Started() bool { + return s.started +} + +// Restart the slow start phase +func (s *HybridSlowStart) Restart() { + s.started = false + s.hystartFound = false +} diff --git a/internal/quic-go/congestion/hybrid_slow_start_test.go b/internal/quic-go/congestion/hybrid_slow_start_test.go new file mode 100644 index 00000000..6de9ca8e --- /dev/null +++ b/internal/quic-go/congestion/hybrid_slow_start_test.go @@ -0,0 +1,72 @@ +package congestion + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Hybrid slow start", func() { + var slowStart HybridSlowStart + + BeforeEach(func() { + slowStart = HybridSlowStart{} + }) + + It("works in a simple case", func() { + packetNumber := protocol.PacketNumber(1) + endPacketNumber := protocol.PacketNumber(3) + slowStart.StartReceiveRound(endPacketNumber) + + packetNumber++ + Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) + + // Test duplicates. + Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) + + packetNumber++ + Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) + packetNumber++ + Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue()) + + // Test without a new registered end_packet_number; + packetNumber++ + Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue()) + + endPacketNumber = 20 + slowStart.StartReceiveRound(endPacketNumber) + for packetNumber < endPacketNumber { + packetNumber++ + Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) + } + packetNumber++ + Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue()) + }) + + It("works with delay", func() { + rtt := 60 * time.Millisecond + // We expect to detect the increase at +1/8 of the RTT; hence at a typical + // RTT of 60ms the detection will happen at 67.5 ms. + const hybridStartMinSamples = 8 // Number of acks required to trigger. + + endPacketNumber := protocol.PacketNumber(1) + endPacketNumber++ + slowStart.StartReceiveRound(endPacketNumber) + + // Will not trigger since our lowest RTT in our burst is the same as the long + // term RTT provided. + for n := 0; n < hybridStartMinSamples; n++ { + Expect(slowStart.ShouldExitSlowStart(rtt+time.Duration(n)*time.Millisecond, rtt, 100)).To(BeFalse()) + } + endPacketNumber++ + slowStart.StartReceiveRound(endPacketNumber) + for n := 1; n < hybridStartMinSamples; n++ { + Expect(slowStart.ShouldExitSlowStart(rtt+(time.Duration(n)+10)*time.Millisecond, rtt, 100)).To(BeFalse()) + } + // Expect to trigger since all packets in this burst was above the long term + // RTT provided. + Expect(slowStart.ShouldExitSlowStart(rtt+10*time.Millisecond, rtt, 100)).To(BeTrue()) + }) +}) diff --git a/internal/quic-go/congestion/interface.go b/internal/quic-go/congestion/interface.go new file mode 100644 index 00000000..f56ed395 --- /dev/null +++ b/internal/quic-go/congestion/interface.go @@ -0,0 +1,28 @@ +package congestion + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// A SendAlgorithm performs congestion control +type SendAlgorithm interface { + TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time + HasPacingBudget() bool + OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) + CanSend(bytesInFlight protocol.ByteCount) bool + MaybeExitSlowStart() + OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time) + OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) + OnRetransmissionTimeout(packetsRetransmitted bool) + SetMaxDatagramSize(protocol.ByteCount) +} + +// A SendAlgorithmWithDebugInfos is a SendAlgorithm that exposes some debug infos +type SendAlgorithmWithDebugInfos interface { + SendAlgorithm + InSlowStart() bool + InRecovery() bool + GetCongestionWindow() protocol.ByteCount +} diff --git a/internal/quic-go/congestion/pacer.go b/internal/quic-go/congestion/pacer.go new file mode 100644 index 00000000..0dd26607 --- /dev/null +++ b/internal/quic-go/congestion/pacer.go @@ -0,0 +1,77 @@ +package congestion + +import ( + "math" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +const maxBurstSizePackets = 10 + +// The pacer implements a token bucket pacing algorithm. +type pacer struct { + budgetAtLastSent protocol.ByteCount + maxDatagramSize protocol.ByteCount + lastSentTime time.Time + getAdjustedBandwidth func() uint64 // in bytes/s +} + +func newPacer(getBandwidth func() Bandwidth) *pacer { + p := &pacer{ + maxDatagramSize: initialMaxDatagramSize, + getAdjustedBandwidth: func() uint64 { + // Bandwidth is in bits/s. We need the value in bytes/s. + bw := uint64(getBandwidth() / BytesPerSecond) + // Use a slightly higher value than the actual measured bandwidth. + // RTT variations then won't result in under-utilization of the congestion window. + // Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire, + // provided the congestion window is fully utilized and acknowledgments arrive at regular intervals. + return bw * 5 / 4 + }, + } + p.budgetAtLastSent = p.maxBurstSize() + return p +} + +func (p *pacer) SentPacket(sendTime time.Time, size protocol.ByteCount) { + budget := p.Budget(sendTime) + if size > budget { + p.budgetAtLastSent = 0 + } else { + p.budgetAtLastSent = budget - size + } + p.lastSentTime = sendTime +} + +func (p *pacer) Budget(now time.Time) protocol.ByteCount { + if p.lastSentTime.IsZero() { + return p.maxBurstSize() + } + budget := p.budgetAtLastSent + (protocol.ByteCount(p.getAdjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + return utils.MinByteCount(p.maxBurstSize(), budget) +} + +func (p *pacer) maxBurstSize() protocol.ByteCount { + return utils.MaxByteCount( + protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9, + maxBurstSizePackets*p.maxDatagramSize, + ) +} + +// TimeUntilSend returns when the next packet should be sent. +// It returns the zero value of time.Time if a packet can be sent immediately. +func (p *pacer) TimeUntilSend() time.Time { + if p.budgetAtLastSent >= p.maxDatagramSize { + return time.Time{} + } + return p.lastSentTime.Add(utils.MaxDuration( + protocol.MinPacingDelay, + time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond, + )) +} + +func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) { + p.maxDatagramSize = s +} diff --git a/internal/quic-go/congestion/pacer_test.go b/internal/quic-go/congestion/pacer_test.go new file mode 100644 index 00000000..e840ff22 --- /dev/null +++ b/internal/quic-go/congestion/pacer_test.go @@ -0,0 +1,131 @@ +package congestion + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Pacer", func() { + var p *pacer + + const packetsPerSecond = 50 + var bandwidth uint64 // in bytes/s + + BeforeEach(func() { + bandwidth = uint64(packetsPerSecond * initialMaxDatagramSize) // 50 full-size packets per second + // The pacer will multiply the bandwidth with 1.25 to achieve a slightly higher pacing speed. + // For the tests, cancel out this factor, so we can do the math using the exact bandwidth. + p = newPacer(func() Bandwidth { return Bandwidth(bandwidth) * BytesPerSecond * 4 / 5 }) + }) + + It("allows a burst at the beginning", func() { + t := time.Now() + Expect(p.TimeUntilSend()).To(BeZero()) + Expect(p.Budget(t)).To(BeEquivalentTo(maxBurstSizePackets * initialMaxDatagramSize)) + }) + + It("allows a big burst for high pacing rates", func() { + t := time.Now() + bandwidth = uint64(10000 * packetsPerSecond * initialMaxDatagramSize) + Expect(p.TimeUntilSend()).To(BeZero()) + Expect(p.Budget(t)).To(BeNumerically(">", maxBurstSizePackets*initialMaxDatagramSize)) + }) + + It("reduces the budget when sending packets", func() { + t := time.Now() + budget := p.Budget(t) + for budget > 0 { + Expect(p.TimeUntilSend()).To(BeZero()) + Expect(p.Budget(t)).To(Equal(budget)) + p.SentPacket(t, initialMaxDatagramSize) + budget -= initialMaxDatagramSize + } + Expect(p.Budget(t)).To(BeZero()) + Expect(p.TimeUntilSend()).ToNot(BeZero()) + }) + + sendBurst := func(t time.Time) { + for p.Budget(t) > 0 { + p.SentPacket(t, initialMaxDatagramSize) + } + } + + It("paces packets after a burst", func() { + t := time.Now() + sendBurst(t) + // send 100 exactly paced packets + for i := 0; i < 100; i++ { + t2 := p.TimeUntilSend() + Expect(t2.Sub(t)).To(BeNumerically("~", time.Second/packetsPerSecond, time.Nanosecond)) + Expect(p.Budget(t2)).To(BeEquivalentTo(initialMaxDatagramSize)) + p.SentPacket(t2, initialMaxDatagramSize) + t = t2 + } + }) + + It("accounts for non-full-size packets", func() { + t := time.Now() + sendBurst(t) + t2 := p.TimeUntilSend() + Expect(t2.Sub(t)).To(BeNumerically("~", time.Second/packetsPerSecond, time.Nanosecond)) + // send a half-full packet + Expect(p.Budget(t2)).To(BeEquivalentTo(initialMaxDatagramSize)) + size := initialMaxDatagramSize / 2 + p.SentPacket(t2, size) + Expect(p.Budget(t2)).To(Equal(initialMaxDatagramSize - size)) + Expect(p.TimeUntilSend()).To(BeTemporally("~", t2.Add(time.Second/packetsPerSecond/2), time.Nanosecond)) + }) + + It("accumulates budget, if no packets are sent", func() { + t := time.Now() + sendBurst(t) + t2 := p.TimeUntilSend() + Expect(t2).To(BeTemporally(">", t)) + // wait for 5 times the duration + Expect(p.Budget(t.Add(5 * t2.Sub(t)))).To(BeEquivalentTo(5 * initialMaxDatagramSize)) + }) + + It("accumulates budget, if no packets are sent, for larger packet sizes", func() { + t := time.Now() + sendBurst(t) + const packetSize = initialMaxDatagramSize + 200 + p.SetMaxDatagramSize(packetSize) + t2 := p.TimeUntilSend() + Expect(t2).To(BeTemporally(">", t)) + // wait for 5 times the duration + Expect(p.Budget(t.Add(5 * t2.Sub(t)))).To(BeEquivalentTo(5 * packetSize)) + }) + + It("never allows bursts larger than the maximum burst size", func() { + t := time.Now() + sendBurst(t) + Expect(p.Budget(t.Add(time.Hour))).To(BeEquivalentTo(maxBurstSizePackets * initialMaxDatagramSize)) + }) + + It("never allows bursts larger than the maximum burst size, for larger packets", func() { + t := time.Now() + const packetSize = initialMaxDatagramSize + 200 + p.SetMaxDatagramSize(packetSize) + sendBurst(t) + Expect(p.Budget(t.Add(time.Hour))).To(BeEquivalentTo(maxBurstSizePackets * packetSize)) + }) + + It("changes the bandwidth", func() { + t := time.Now() + sendBurst(t) + bandwidth = uint64(5 * initialMaxDatagramSize) // reduce the bandwidth to 5 packet per second + Expect(p.TimeUntilSend()).To(Equal(t.Add(time.Second / 5))) + }) + + It("doesn't pace faster than the minimum pacing duration", func() { + t := time.Now() + sendBurst(t) + bandwidth = uint64(1e6 * initialMaxDatagramSize) + Expect(p.TimeUntilSend()).To(Equal(t.Add(protocol.MinPacingDelay))) + Expect(p.Budget(t.Add(protocol.MinPacingDelay))).To(Equal(protocol.ByteCount(protocol.MinPacingDelay) * initialMaxDatagramSize * 1e6 / 1e9)) + }) +}) diff --git a/internal/quic-go/conn_id_generator.go b/internal/quic-go/conn_id_generator.go new file mode 100644 index 00000000..10f30ae9 --- /dev/null +++ b/internal/quic-go/conn_id_generator.go @@ -0,0 +1,140 @@ +package quic + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type connIDGenerator struct { + connIDLen int + highestSeq uint64 + + activeSrcConnIDs map[uint64]protocol.ConnectionID + initialClientDestConnID protocol.ConnectionID + + addConnectionID func(protocol.ConnectionID) + getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken + removeConnectionID func(protocol.ConnectionID) + retireConnectionID func(protocol.ConnectionID) + replaceWithClosed func(protocol.ConnectionID, packetHandler) + queueControlFrame func(wire.Frame) + + version protocol.VersionNumber +} + +func newConnIDGenerator( + initialConnectionID protocol.ConnectionID, + initialClientDestConnID protocol.ConnectionID, // nil for the client + addConnectionID func(protocol.ConnectionID), + getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, + removeConnectionID func(protocol.ConnectionID), + retireConnectionID func(protocol.ConnectionID), + replaceWithClosed func(protocol.ConnectionID, packetHandler), + queueControlFrame func(wire.Frame), + version protocol.VersionNumber, +) *connIDGenerator { + m := &connIDGenerator{ + connIDLen: initialConnectionID.Len(), + activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), + addConnectionID: addConnectionID, + getStatelessResetToken: getStatelessResetToken, + removeConnectionID: removeConnectionID, + retireConnectionID: retireConnectionID, + replaceWithClosed: replaceWithClosed, + queueControlFrame: queueControlFrame, + version: version, + } + m.activeSrcConnIDs[0] = initialConnectionID + m.initialClientDestConnID = initialClientDestConnID + return m +} + +func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { + if m.connIDLen == 0 { + return nil + } + // The active_connection_id_limit transport parameter is the number of + // connection IDs the peer will store. This limit includes the connection ID + // used during the handshake, and the one sent in the preferred_address + // transport parameter. + // We currently don't send the preferred_address transport parameter, + // so we can issue (limit - 1) connection IDs. + for i := uint64(len(m.activeSrcConnIDs)); i < utils.MinUint64(limit, protocol.MaxIssuedConnectionIDs); i++ { + if err := m.issueNewConnID(); err != nil { + return err + } + } + return nil +} + +func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error { + if seq > m.highestSeq { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq), + } + } + connID, ok := m.activeSrcConnIDs[seq] + // We might already have deleted this connection ID, if this is a duplicate frame. + if !ok { + return nil + } + if connID.Equal(sentWithDestConnID) { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), + } + } + m.retireConnectionID(connID) + delete(m.activeSrcConnIDs, seq) + // Don't issue a replacement for the initial connection ID. + if seq == 0 { + return nil + } + return m.issueNewConnID() +} + +func (m *connIDGenerator) issueNewConnID() error { + connID, err := protocol.GenerateConnectionID(m.connIDLen) + if err != nil { + return err + } + m.activeSrcConnIDs[m.highestSeq+1] = connID + m.addConnectionID(connID) + m.queueControlFrame(&wire.NewConnectionIDFrame{ + SequenceNumber: m.highestSeq + 1, + ConnectionID: connID, + StatelessResetToken: m.getStatelessResetToken(connID), + }) + m.highestSeq++ + return nil +} + +func (m *connIDGenerator) SetHandshakeComplete() { + if m.initialClientDestConnID != nil { + m.retireConnectionID(m.initialClientDestConnID) + m.initialClientDestConnID = nil + } +} + +func (m *connIDGenerator) RemoveAll() { + if m.initialClientDestConnID != nil { + m.removeConnectionID(m.initialClientDestConnID) + } + for _, connID := range m.activeSrcConnIDs { + m.removeConnectionID(connID) + } +} + +func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) { + if m.initialClientDestConnID != nil { + m.replaceWithClosed(m.initialClientDestConnID, handler) + } + for _, connID := range m.activeSrcConnIDs { + m.replaceWithClosed(connID, handler) + } +} diff --git a/internal/quic-go/conn_id_generator_test.go b/internal/quic-go/conn_id_generator_test.go new file mode 100644 index 00000000..543fce4b --- /dev/null +++ b/internal/quic-go/conn_id_generator_test.go @@ -0,0 +1,187 @@ +package quic + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Connection ID Generator", func() { + var ( + addedConnIDs []protocol.ConnectionID + retiredConnIDs []protocol.ConnectionID + removedConnIDs []protocol.ConnectionID + replacedWithClosed map[string]packetHandler + queuedFrames []wire.Frame + g *connIDGenerator + ) + initialConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} + initialClientDestConnID := protocol.ConnectionID{0xa, 0xb, 0xc, 0xd, 0xe} + + connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken { + return protocol.StatelessResetToken{c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0]} + } + + BeforeEach(func() { + addedConnIDs = nil + retiredConnIDs = nil + removedConnIDs = nil + queuedFrames = nil + replacedWithClosed = make(map[string]packetHandler) + g = newConnIDGenerator( + initialConnID, + initialClientDestConnID, + func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) }, + connIDToToken, + func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, + func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, + func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h }, + func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, + protocol.VersionDraft29, + ) + }) + + It("issues new connection IDs", func() { + Expect(g.SetMaxActiveConnIDs(4)).To(Succeed()) + Expect(retiredConnIDs).To(BeEmpty()) + Expect(addedConnIDs).To(HaveLen(3)) + for i := 0; i < len(addedConnIDs)-1; i++ { + Expect(addedConnIDs[i]).ToNot(Equal(addedConnIDs[i+1])) + } + Expect(queuedFrames).To(HaveLen(3)) + for i := 0; i < 3; i++ { + f := queuedFrames[i] + Expect(f).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{})) + nf := f.(*wire.NewConnectionIDFrame) + Expect(nf.SequenceNumber).To(BeEquivalentTo(i + 1)) + Expect(nf.ConnectionID.Len()).To(Equal(7)) + Expect(nf.StatelessResetToken).To(Equal(connIDToToken(nf.ConnectionID))) + } + }) + + It("limits the number of connection IDs that it issues", func() { + Expect(g.SetMaxActiveConnIDs(9999999)).To(Succeed()) + Expect(retiredConnIDs).To(BeEmpty()) + Expect(addedConnIDs).To(HaveLen(protocol.MaxIssuedConnectionIDs - 1)) + Expect(queuedFrames).To(HaveLen(protocol.MaxIssuedConnectionIDs - 1)) + }) + + // SetMaxActiveConnIDs is called twice when we dialing a 0-RTT connection: + // once for the restored from the old connections, once when we receive the transport parameters + Context("dealing with 0-RTT", func() { + It("doesn't issue new connection IDs when SetMaxActiveConnIDs is called with the same value", func() { + Expect(g.SetMaxActiveConnIDs(4)).To(Succeed()) + Expect(queuedFrames).To(HaveLen(3)) + queuedFrames = nil + Expect(g.SetMaxActiveConnIDs(4)).To(Succeed()) + Expect(queuedFrames).To(BeEmpty()) + }) + + It("issues more connection IDs if the server allows a higher limit on the resumed connection", func() { + Expect(g.SetMaxActiveConnIDs(3)).To(Succeed()) + Expect(queuedFrames).To(HaveLen(2)) + queuedFrames = nil + Expect(g.SetMaxActiveConnIDs(6)).To(Succeed()) + Expect(queuedFrames).To(HaveLen(3)) + }) + + It("issues more connection IDs if the server allows a higher limit on the resumed connection, when connection IDs were retired in between", func() { + Expect(g.SetMaxActiveConnIDs(3)).To(Succeed()) + Expect(queuedFrames).To(HaveLen(2)) + queuedFrames = nil + g.Retire(1, protocol.ConnectionID{}) + Expect(queuedFrames).To(HaveLen(1)) + queuedFrames = nil + Expect(g.SetMaxActiveConnIDs(6)).To(Succeed()) + Expect(queuedFrames).To(HaveLen(3)) + }) + }) + + It("errors if the peers tries to retire a connection ID that wasn't yet issued", func() { + Expect(g.Retire(1, protocol.ConnectionID{})).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "retired connection ID 1 (highest issued: 0)", + })) + }) + + It("errors if the peers tries to retire a connection ID in a packet with that connection ID", func() { + Expect(g.SetMaxActiveConnIDs(4)).To(Succeed()) + Expect(queuedFrames).ToNot(BeEmpty()) + Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{})) + f := queuedFrames[0].(*wire.NewConnectionIDFrame) + Expect(g.Retire(f.SequenceNumber, f.ConnectionID)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", f.SequenceNumber, f.ConnectionID), + })) + }) + + It("issues new connection IDs, when old ones are retired", func() { + Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) + queuedFrames = nil + Expect(retiredConnIDs).To(BeEmpty()) + Expect(g.Retire(3, protocol.ConnectionID{})).To(Succeed()) + Expect(queuedFrames).To(HaveLen(1)) + Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{})) + nf := queuedFrames[0].(*wire.NewConnectionIDFrame) + Expect(nf.SequenceNumber).To(BeEquivalentTo(5)) + Expect(nf.ConnectionID.Len()).To(Equal(7)) + }) + + It("retires the initial connection ID", func() { + Expect(g.Retire(0, protocol.ConnectionID{})).To(Succeed()) + Expect(removedConnIDs).To(BeEmpty()) + Expect(retiredConnIDs).To(HaveLen(1)) + Expect(retiredConnIDs[0]).To(Equal(initialConnID)) + Expect(addedConnIDs).To(BeEmpty()) + }) + + It("handles duplicate retirements", func() { + Expect(g.SetMaxActiveConnIDs(11)).To(Succeed()) + queuedFrames = nil + Expect(retiredConnIDs).To(BeEmpty()) + Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed()) + Expect(retiredConnIDs).To(HaveLen(1)) + Expect(queuedFrames).To(HaveLen(1)) + Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed()) + Expect(retiredConnIDs).To(HaveLen(1)) + Expect(queuedFrames).To(HaveLen(1)) + }) + + It("retires the client's initial destination connection ID when the handshake completes", func() { + g.SetHandshakeComplete() + Expect(retiredConnIDs).To(HaveLen(1)) + Expect(retiredConnIDs[0]).To(Equal(initialClientDestConnID)) + }) + + It("removes all connection IDs", func() { + Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) + Expect(queuedFrames).To(HaveLen(4)) + g.RemoveAll() + Expect(removedConnIDs).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones + Expect(removedConnIDs).To(ContainElement(initialConnID)) + Expect(removedConnIDs).To(ContainElement(initialClientDestConnID)) + for _, f := range queuedFrames { + nf := f.(*wire.NewConnectionIDFrame) + Expect(removedConnIDs).To(ContainElement(nf.ConnectionID)) + } + }) + + It("replaces with a closed connection for all connection IDs", func() { + Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) + Expect(queuedFrames).To(HaveLen(4)) + sess := NewMockPacketHandler(mockCtrl) + g.ReplaceWithClosed(sess) + Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones + Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialClientDestConnID), sess)) + Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess)) + for _, f := range queuedFrames { + nf := f.(*wire.NewConnectionIDFrame) + Expect(replacedWithClosed).To(HaveKeyWithValue(string(nf.ConnectionID), sess)) + } + }) +}) diff --git a/internal/quic-go/conn_id_manager.go b/internal/quic-go/conn_id_manager.go new file mode 100644 index 00000000..bb12de28 --- /dev/null +++ b/internal/quic-go/conn_id_manager.go @@ -0,0 +1,207 @@ +package quic + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type connIDManager struct { + queue utils.NewConnectionIDList + + handshakeComplete bool + activeSequenceNumber uint64 + highestRetired uint64 + activeConnectionID protocol.ConnectionID + activeStatelessResetToken *protocol.StatelessResetToken + + // We change the connection ID after sending on average + // protocol.PacketsPerConnectionID packets. The actual value is randomized + // hide the packet loss rate from on-path observers. + rand utils.Rand + packetsSinceLastChange uint32 + packetsPerConnectionID uint32 + + addStatelessResetToken func(protocol.StatelessResetToken) + removeStatelessResetToken func(protocol.StatelessResetToken) + queueControlFrame func(wire.Frame) +} + +func newConnIDManager( + initialDestConnID protocol.ConnectionID, + addStatelessResetToken func(protocol.StatelessResetToken), + removeStatelessResetToken func(protocol.StatelessResetToken), + queueControlFrame func(wire.Frame), +) *connIDManager { + return &connIDManager{ + activeConnectionID: initialDestConnID, + addStatelessResetToken: addStatelessResetToken, + removeStatelessResetToken: removeStatelessResetToken, + queueControlFrame: queueControlFrame, + } +} + +func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error { + return h.addConnectionID(1, connID, resetToken) +} + +func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { + if err := h.add(f); err != nil { + return err + } + if h.queue.Len() >= protocol.MaxActiveConnectionIDs { + return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError} + } + return nil +} + +func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { + // If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active + // connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately. + if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired { + h.queueControlFrame(&wire.RetireConnectionIDFrame{ + SequenceNumber: f.SequenceNumber, + }) + return nil + } + + // Retire elements in the queue. + // Doesn't retire the active connection ID. + if f.RetirePriorTo > h.highestRetired { + var next *utils.NewConnectionIDElement + for el := h.queue.Front(); el != nil; el = next { + if el.Value.SequenceNumber >= f.RetirePriorTo { + break + } + next = el.Next() + h.queueControlFrame(&wire.RetireConnectionIDFrame{ + SequenceNumber: el.Value.SequenceNumber, + }) + h.queue.Remove(el) + } + h.highestRetired = f.RetirePriorTo + } + + if f.SequenceNumber == h.activeSequenceNumber { + return nil + } + + if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil { + return err + } + + // Retire the active connection ID, if necessary. + if h.activeSequenceNumber < f.RetirePriorTo { + // The queue is guaranteed to have at least one element at this point. + h.updateConnectionID() + } + return nil +} + +func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error { + // insert a new element at the end + if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq { + h.queue.PushBack(utils.NewConnectionID{ + SequenceNumber: seq, + ConnectionID: connID, + StatelessResetToken: resetToken, + }) + return nil + } + // insert a new element somewhere in the middle + for el := h.queue.Front(); el != nil; el = el.Next() { + if el.Value.SequenceNumber == seq { + if !el.Value.ConnectionID.Equal(connID) { + return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq) + } + if el.Value.StatelessResetToken != resetToken { + return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq) + } + break + } + if el.Value.SequenceNumber > seq { + h.queue.InsertBefore(utils.NewConnectionID{ + SequenceNumber: seq, + ConnectionID: connID, + StatelessResetToken: resetToken, + }, el) + break + } + } + return nil +} + +func (h *connIDManager) updateConnectionID() { + h.queueControlFrame(&wire.RetireConnectionIDFrame{ + SequenceNumber: h.activeSequenceNumber, + }) + h.highestRetired = utils.MaxUint64(h.highestRetired, h.activeSequenceNumber) + if h.activeStatelessResetToken != nil { + h.removeStatelessResetToken(*h.activeStatelessResetToken) + } + + front := h.queue.Remove(h.queue.Front()) + h.activeSequenceNumber = front.SequenceNumber + h.activeConnectionID = front.ConnectionID + h.activeStatelessResetToken = &front.StatelessResetToken + h.packetsSinceLastChange = 0 + h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID)) + h.addStatelessResetToken(*h.activeStatelessResetToken) +} + +func (h *connIDManager) Close() { + if h.activeStatelessResetToken != nil { + h.removeStatelessResetToken(*h.activeStatelessResetToken) + } +} + +// is called when the server performs a Retry +// and when the server changes the connection ID in the first Initial sent +func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) { + if h.activeSequenceNumber != 0 { + panic("expected first connection ID to have sequence number 0") + } + h.activeConnectionID = newConnID +} + +// is called when the server provides a stateless reset token in the transport parameters +func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) { + if h.activeSequenceNumber != 0 { + panic("expected first connection ID to have sequence number 0") + } + h.activeStatelessResetToken = &token + h.addStatelessResetToken(token) +} + +func (h *connIDManager) SentPacket() { + h.packetsSinceLastChange++ +} + +func (h *connIDManager) shouldUpdateConnID() bool { + if !h.handshakeComplete { + return false + } + // initiate the first change as early as possible (after handshake completion) + if h.queue.Len() > 0 && h.activeSequenceNumber == 0 { + return true + } + // For later changes, only change if + // 1. The queue of connection IDs is filled more than 50%. + // 2. We sent at least PacketsPerConnectionID packets + return 2*h.queue.Len() >= protocol.MaxActiveConnectionIDs && + h.packetsSinceLastChange >= h.packetsPerConnectionID +} + +func (h *connIDManager) Get() protocol.ConnectionID { + if h.shouldUpdateConnID() { + h.updateConnectionID() + } + return h.activeConnectionID +} + +func (h *connIDManager) SetHandshakeComplete() { + h.handshakeComplete = true +} diff --git a/internal/quic-go/conn_id_manager_test.go b/internal/quic-go/conn_id_manager_test.go new file mode 100644 index 00000000..5348a0d7 --- /dev/null +++ b/internal/quic-go/conn_id_manager_test.go @@ -0,0 +1,364 @@ +package quic + +import ( + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Connection ID Manager", func() { + var ( + m *connIDManager + frameQueue []wire.Frame + tokenAdded *protocol.StatelessResetToken + removedTokens []protocol.StatelessResetToken + ) + initialConnID := protocol.ConnectionID{0, 0, 0, 0} + + BeforeEach(func() { + frameQueue = nil + tokenAdded = nil + removedTokens = nil + m = newConnIDManager( + initialConnID, + func(token protocol.StatelessResetToken) { tokenAdded = &token }, + func(token protocol.StatelessResetToken) { removedTokens = append(removedTokens, token) }, + func(f wire.Frame, + ) { + frameQueue = append(frameQueue, f) + }) + }) + + get := func() (protocol.ConnectionID, protocol.StatelessResetToken) { + if m.queue.Len() == 0 { + return nil, protocol.StatelessResetToken{} + } + val := m.queue.Remove(m.queue.Front()) + return val.ConnectionID, val.StatelessResetToken + } + + It("returns the initial connection ID", func() { + Expect(m.Get()).To(Equal(initialConnID)) + }) + + It("changes the initial connection ID", func() { + m.ChangeInitialConnID(protocol.ConnectionID{1, 2, 3, 4, 5}) + Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) + }) + + It("sets the token for the first connection ID", func() { + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + m.SetStatelessResetToken(token) + Expect(*m.activeStatelessResetToken).To(Equal(token)) + Expect(*tokenAdded).To(Equal(token)) + }) + + It("adds and gets connection IDs", func() { + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 10, + ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, + StatelessResetToken: protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, + })).To(Succeed()) + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 4, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, + })).To(Succeed()) + c1, rt1 := get() + Expect(c1).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(rt1).To(Equal(protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe})) + c2, rt2 := get() + Expect(c2).To(Equal(protocol.ConnectionID{2, 3, 4, 5})) + Expect(rt2).To(Equal(protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0})) + c3, _ := get() + Expect(c3).To(BeNil()) + }) + + It("accepts duplicates", func() { + f1 := &wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, + } + f2 := &wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, + } + Expect(m.Add(f1)).To(Succeed()) + Expect(m.Add(f2)).To(Succeed()) + c1, rt1 := get() + Expect(c1).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(rt1).To(Equal(protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe})) + c2, _ := get() + Expect(c2).To(BeNil()) + }) + + It("ignores duplicates for the currently used connection ID", func() { + f := &wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, + } + m.SetHandshakeComplete() + Expect(m.Add(f)).To(Succeed()) + Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + c, _ := get() + Expect(c).To(BeNil()) + // Now send the same connection ID again. It should not be queued. + Expect(m.Add(f)).To(Succeed()) + c, _ = get() + Expect(c).To(BeNil()) + }) + + It("rejects duplicates with different connection IDs", func() { + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 42, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + })).To(Succeed()) + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 42, + ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, + })).To(MatchError("received conflicting connection IDs for sequence number 42")) + }) + + It("rejects duplicates with different connection IDs", func() { + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 42, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, + })).To(Succeed()) + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 42, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, + })).To(MatchError("received conflicting stateless reset tokens for sequence number 42")) + }) + + It("retires connection IDs", func() { + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 10, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + })).To(Succeed()) + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 13, + ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, + })).To(Succeed()) + Expect(frameQueue).To(BeEmpty()) + Expect(m.Add(&wire.NewConnectionIDFrame{ + RetirePriorTo: 14, + SequenceNumber: 17, + ConnectionID: protocol.ConnectionID{3, 4, 5, 6}, + })).To(Succeed()) + Expect(frameQueue).To(HaveLen(3)) + Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(10)) + Expect(frameQueue[1].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(13)) + Expect(frameQueue[2].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) + Expect(m.Get()).To(Equal(protocol.ConnectionID{3, 4, 5, 6})) + }) + + It("ignores reordered connection IDs, if their sequence number was already retired", func() { + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 10, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + RetirePriorTo: 5, + })).To(Succeed()) + Expect(frameQueue).To(HaveLen(1)) + Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) + frameQueue = nil + // If this NEW_CONNECTION_ID frame hadn't been reordered, we would have retired it before. + // Make sure it gets retired immediately now. + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 4, + ConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + })).To(Succeed()) + Expect(frameQueue).To(HaveLen(1)) + Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(4)) + }) + + It("ignores reordered connection IDs, if their sequence number was already retired or less than active", func() { + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 10, + ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + RetirePriorTo: 5, + })).To(Succeed()) + Expect(frameQueue).To(HaveLen(1)) + Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) + frameQueue = nil + Expect(m.Get()).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) + + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 9, + ConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + RetirePriorTo: 5, + })).To(Succeed()) + Expect(frameQueue).To(HaveLen(1)) + Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(9)) + }) + + It("accepts retransmissions for the connection ID that is in use", func() { + connID := protocol.ConnectionID{1, 2, 3, 4} + + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: connID, + })).To(Succeed()) + m.SetHandshakeComplete() + Expect(frameQueue).To(BeEmpty()) + Expect(m.Get()).To(Equal(connID)) + Expect(frameQueue).To(HaveLen(1)) + Expect(frameQueue[0]).To(BeAssignableToTypeOf(&wire.RetireConnectionIDFrame{})) + Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) + frameQueue = nil + + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: connID, + })).To(Succeed()) + Expect(frameQueue).To(BeEmpty()) + }) + + It("errors when the peer sends too connection IDs", func() { + for i := uint8(1); i < protocol.MaxActiveConnectionIDs; i++ { + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: uint64(i), + ConnectionID: protocol.ConnectionID{i, i, i, i}, + StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i}, + })).To(Succeed()) + } + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: uint64(9999), + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + })).To(MatchError(&qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError})) + }) + + It("initiates the first connection ID update as soon as possible", func() { + Expect(m.Get()).To(Equal(initialConnID)) + m.SetHandshakeComplete() + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + })).To(Succeed()) + Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + }) + + It("waits until handshake completion before initiating a connection ID update", func() { + Expect(m.Get()).To(Equal(initialConnID)) + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + })).To(Succeed()) + Expect(m.Get()).To(Equal(initialConnID)) + m.SetHandshakeComplete() + Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + }) + + It("initiates subsequent updates when enough packets are sent", func() { + var s uint8 + for s = uint8(1); s < protocol.MaxActiveConnectionIDs; s++ { + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: uint64(s), + ConnectionID: protocol.ConnectionID{s, s, s, s}, + StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, + })).To(Succeed()) + } + + m.SetHandshakeComplete() + lastConnID := m.Get() + Expect(lastConnID).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) + + var counter int + for i := 0; i < 50*protocol.PacketsPerConnectionID; i++ { + m.SentPacket() + + connID := m.Get() + if !connID.Equal(lastConnID) { + counter++ + lastConnID = connID + Expect(removedTokens).To(HaveLen(1)) + removedTokens = nil + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: uint64(s), + ConnectionID: protocol.ConnectionID{s, s, s, s}, + StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, + })).To(Succeed()) + s++ + } + } + Expect(counter).To(BeNumerically("~", 50, 10)) + }) + + It("retires delayed connection IDs that arrive after a higher connection ID was already retired", func() { + for s := uint8(10); s <= 10+protocol.MaxActiveConnectionIDs/2; s++ { + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: uint64(s), + ConnectionID: protocol.ConnectionID{s, s, s, s}, + StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, + })).To(Succeed()) + } + m.SetHandshakeComplete() + Expect(m.Get()).To(Equal(protocol.ConnectionID{10, 10, 10, 10})) + for { + m.SentPacket() + if m.Get().Equal(protocol.ConnectionID{11, 11, 11, 11}) { + break + } + } + // The active conn ID is now {11, 11, 11, 11} + Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ConnectionID{12, 12, 12, 12})) + // Add a delayed connection ID. It should just be ignored now. + frameQueue = nil + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: uint64(5), + ConnectionID: protocol.ConnectionID{5, 5, 5, 5}, + StatelessResetToken: protocol.StatelessResetToken{5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}, + })).To(Succeed()) + Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ConnectionID{12, 12, 12, 12})) + Expect(frameQueue).To(HaveLen(1)) + Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(5)) + }) + + It("only initiates subsequent updates when enough if enough connection IDs are queued", func() { + for i := uint8(1); i <= protocol.MaxActiveConnectionIDs/2; i++ { + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: uint64(i), + ConnectionID: protocol.ConnectionID{i, i, i, i}, + StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i}, + })).To(Succeed()) + } + m.SetHandshakeComplete() + Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) + for i := 0; i < 2*protocol.PacketsPerConnectionID; i++ { + m.SentPacket() + } + Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 1337, + ConnectionID: protocol.ConnectionID{1, 3, 3, 7}, + })).To(Succeed()) + Expect(m.Get()).To(Equal(protocol.ConnectionID{2, 2, 2, 2})) + Expect(removedTokens).To(HaveLen(1)) + Expect(removedTokens[0]).To(Equal(protocol.StatelessResetToken{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})) + }) + + It("removes the currently active stateless reset token when it is closed", func() { + m.Close() + Expect(removedTokens).To(BeEmpty()) + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + })).To(Succeed()) + m.SetHandshakeComplete() + Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + m.Close() + Expect(removedTokens).To(HaveLen(1)) + Expect(removedTokens[0]).To(Equal(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})) + }) +}) diff --git a/internal/quic-go/connection.go b/internal/quic-go/connection.go new file mode 100644 index 00000000..fd981709 --- /dev/null +++ b/internal/quic-go/connection.go @@ -0,0 +1,2006 @@ +package quic + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "reflect" + "sync" + "sync/atomic" + "time" + + "github.com/imroc/req/v3/internal/quic-go/ackhandler" + "github.com/imroc/req/v3/internal/quic-go/flowcontrol" + "github.com/imroc/req/v3/internal/quic-go/handshake" + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/logutils" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type unpacker interface { + Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) +} + +type streamGetter interface { + GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) + GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) +} + +type streamManager interface { + GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) + GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) + OpenStream() (Stream, error) + OpenUniStream() (SendStream, error) + OpenStreamSync(context.Context) (Stream, error) + OpenUniStreamSync(context.Context) (SendStream, error) + AcceptStream(context.Context) (Stream, error) + AcceptUniStream(context.Context) (ReceiveStream, error) + DeleteStream(protocol.StreamID) error + UpdateLimits(*wire.TransportParameters) + HandleMaxStreamsFrame(*wire.MaxStreamsFrame) + CloseWithError(error) + ResetFor0RTT() + UseResetMaps() +} + +type cryptoStreamHandler interface { + RunHandshake() + ChangeConnectionID(protocol.ConnectionID) + SetLargest1RTTAcked(protocol.PacketNumber) error + SetHandshakeConfirmed() + GetSessionTicket() ([]byte, error) + io.Closer + ConnectionState() handshake.ConnectionState +} + +type packetInfo struct { + addr net.IP + ifIndex uint32 +} + +type receivedPacket struct { + buffer *packetBuffer + + remoteAddr net.Addr + rcvTime time.Time + data []byte + + ecn protocol.ECN + + info *packetInfo +} + +func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) } + +func (p *receivedPacket) Clone() *receivedPacket { + return &receivedPacket{ + remoteAddr: p.remoteAddr, + rcvTime: p.rcvTime, + data: p.data, + buffer: p.buffer, + ecn: p.ecn, + info: p.info, + } +} + +type connRunner interface { + Add(protocol.ConnectionID, packetHandler) bool + GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken + Retire(protocol.ConnectionID) + Remove(protocol.ConnectionID) + ReplaceWithClosed(protocol.ConnectionID, packetHandler) + AddResetToken(protocol.StatelessResetToken, packetHandler) + RemoveResetToken(protocol.StatelessResetToken) +} + +type handshakeRunner struct { + onReceivedParams func(*wire.TransportParameters) + onError func(error) + dropKeys func(protocol.EncryptionLevel) + onHandshakeComplete func() +} + +func (r *handshakeRunner) OnReceivedParams(tp *wire.TransportParameters) { r.onReceivedParams(tp) } +func (r *handshakeRunner) OnError(e error) { r.onError(e) } +func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) } +func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() } + +type closeError struct { + err error + remote bool + immediate bool +} + +type errCloseForRecreating struct { + nextPacketNumber protocol.PacketNumber + nextVersion protocol.VersionNumber +} + +func (e *errCloseForRecreating) Error() string { + return "closing connection in order to recreate it" +} + +var connTracingID uint64 // to be accessed atomically +func nextConnTracingID() uint64 { return atomic.AddUint64(&connTracingID, 1) } + +// A Connection is a QUIC connection +type connection struct { + // Destination connection ID used during the handshake. + // Used to check source connection ID on incoming packets. + handshakeDestConnID protocol.ConnectionID + // Set for the client. Destination connection ID used on the first Initial sent. + origDestConnID protocol.ConnectionID + retrySrcConnID *protocol.ConnectionID // only set for the client (and if a Retry was performed) + + srcConnIDLen int + + perspective protocol.Perspective + version protocol.VersionNumber + config *Config + + conn sendConn + sendQueue sender + + streamsMap streamManager + connIDManager *connIDManager + connIDGenerator *connIDGenerator + + rttStats *utils.RTTStats + + cryptoStreamManager *cryptoStreamManager + sentPacketHandler ackhandler.SentPacketHandler + receivedPacketHandler ackhandler.ReceivedPacketHandler + retransmissionQueue *retransmissionQueue + framer framer + windowUpdateQueue *windowUpdateQueue + connFlowController flowcontrol.ConnectionFlowController + tokenStoreKey string // only set for the client + tokenGenerator *handshake.TokenGenerator // only set for the server + + unpacker unpacker + frameParser wire.FrameParser + packer packer + mtuDiscoverer mtuDiscoverer // initialized when the handshake completes + + oneRTTStream cryptoStream // only set for the server + cryptoStreamHandler cryptoStreamHandler + + receivedPackets chan *receivedPacket + sendingScheduled chan struct{} + + closeOnce sync.Once + // closeChan is used to notify the run loop that it should terminate + closeChan chan closeError + + ctx context.Context + ctxCancel context.CancelFunc + handshakeCtx context.Context + handshakeCtxCancel context.CancelFunc + + undecryptablePackets []*receivedPacket // undecryptable packets, waiting for a change in encryption level + undecryptablePacketsToProcess []*receivedPacket + + clientHelloWritten <-chan *wire.TransportParameters + earlyConnReadyChan chan struct{} + handshakeCompleteChan chan struct{} // is closed when the handshake completes + sentFirstPacket bool + handshakeComplete bool + handshakeConfirmed bool + + receivedRetry bool + versionNegotiated bool + receivedFirstPacket bool + + idleTimeout time.Duration + creationTime time.Time + // The idle timeout is set based on the max of the time we received the last packet... + lastPacketReceivedTime time.Time + // ... and the time we sent a new ack-eliciting packet after receiving a packet. + firstAckElicitingPacketAfterIdleSentTime time.Time + // pacingDeadline is the time when the next packet should be sent + pacingDeadline time.Time + + peerParams *wire.TransportParameters + + timer *utils.Timer + // keepAlivePingSent stores whether a keep alive PING is in flight. + // It is reset as soon as we receive a packet from the peer. + keepAlivePingSent bool + keepAliveInterval time.Duration + + datagramQueue *datagramQueue + + logID string + tracer logging.ConnectionTracer + logger utils.Logger +} + +var ( + _ Connection = &connection{} + _ EarlyConnection = &connection{} + _ streamSender = &connection{} + deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine +) + +var newConnection = func( + conn sendConn, + runner connRunner, + origDestConnID protocol.ConnectionID, + retrySrcConnID *protocol.ConnectionID, + clientDestConnID protocol.ConnectionID, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + statelessResetToken protocol.StatelessResetToken, + conf *Config, + tlsConf *tls.Config, + tokenGenerator *handshake.TokenGenerator, + enable0RTT bool, + tracer logging.ConnectionTracer, + tracingID uint64, + logger utils.Logger, + v protocol.VersionNumber, +) quicConn { + s := &connection{ + conn: conn, + config: conf, + handshakeDestConnID: destConnID, + srcConnIDLen: srcConnID.Len(), + tokenGenerator: tokenGenerator, + oneRTTStream: newCryptoStream(), + perspective: protocol.PerspectiveServer, + handshakeCompleteChan: make(chan struct{}), + tracer: tracer, + logger: logger, + version: v, + } + if origDestConnID != nil { + s.logID = origDestConnID.String() + } else { + s.logID = destConnID.String() + } + s.connIDManager = newConnIDManager( + destConnID, + func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, + runner.RemoveResetToken, + s.queueControlFrame, + ) + s.connIDGenerator = newConnIDGenerator( + srcConnID, + clientDestConnID, + func(connID protocol.ConnectionID) { runner.Add(connID, s) }, + runner.GetStatelessResetToken, + runner.Remove, + runner.Retire, + runner.ReplaceWithClosed, + s.queueControlFrame, + s.version, + ) + s.preSetup() + s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) + s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( + 0, + getMaxPacketSize(s.conn.RemoteAddr()), + s.rttStats, + s.perspective, + s.tracer, + s.logger, + s.version, + ) + initialStream := newCryptoStream() + handshakeStream := newCryptoStream() + params := &wire.TransportParameters{ + InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), + MaxIdleTimeout: s.config.MaxIdleTimeout, + MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), + MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), + MaxAckDelay: protocol.MaxAckDelayInclGranularity, + AckDelayExponent: protocol.AckDelayExponent, + DisableActiveMigration: true, + StatelessResetToken: &statelessResetToken, + OriginalDestinationConnectionID: origDestConnID, + ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + InitialSourceConnectionID: srcConnID, + RetrySourceConnectionID: retrySrcConnID, + } + if s.config.EnableDatagrams { + params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + } + if s.tracer != nil { + s.tracer.SentTransportParameters(params) + } + cs := handshake.NewCryptoSetupServer( + initialStream, + handshakeStream, + clientDestConnID, + conn.LocalAddr(), + conn.RemoteAddr(), + params, + &handshakeRunner{ + onReceivedParams: s.handleTransportParameters, + onError: s.closeLocal, + dropKeys: s.dropEncryptionLevel, + onHandshakeComplete: func() { + runner.Retire(clientDestConnID) + close(s.handshakeCompleteChan) + }, + }, + tlsConf, + enable0RTT, + s.rttStats, + tracer, + logger, + s.version, + ) + s.cryptoStreamHandler = cs + s.packer = newPacketPacker( + srcConnID, + s.connIDManager.Get, + initialStream, + handshakeStream, + s.sentPacketHandler, + s.retransmissionQueue, + s.RemoteAddr(), + cs, + s.framer, + s.receivedPacketHandler, + s.datagramQueue, + s.perspective, + s.version, + ) + s.unpacker = newPacketUnpacker(cs, s.version) + s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) + return s +} + +// declare this as a variable, such that we can it mock it in the tests +var newClientConnection = func( + conn sendConn, + runner connRunner, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + conf *Config, + tlsConf *tls.Config, + initialPacketNumber protocol.PacketNumber, + enable0RTT bool, + hasNegotiatedVersion bool, + tracer logging.ConnectionTracer, + tracingID uint64, + logger utils.Logger, + v protocol.VersionNumber, +) quicConn { + s := &connection{ + conn: conn, + config: conf, + origDestConnID: destConnID, + handshakeDestConnID: destConnID, + srcConnIDLen: srcConnID.Len(), + perspective: protocol.PerspectiveClient, + handshakeCompleteChan: make(chan struct{}), + logID: destConnID.String(), + logger: logger, + tracer: tracer, + versionNegotiated: hasNegotiatedVersion, + version: v, + } + s.connIDManager = newConnIDManager( + destConnID, + func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, + runner.RemoveResetToken, + s.queueControlFrame, + ) + s.connIDGenerator = newConnIDGenerator( + srcConnID, + nil, + func(connID protocol.ConnectionID) { runner.Add(connID, s) }, + runner.GetStatelessResetToken, + runner.Remove, + runner.Retire, + runner.ReplaceWithClosed, + s.queueControlFrame, + s.version, + ) + s.preSetup() + s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) + s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( + initialPacketNumber, + getMaxPacketSize(s.conn.RemoteAddr()), + s.rttStats, + s.perspective, + s.tracer, + s.logger, + s.version, + ) + initialStream := newCryptoStream() + handshakeStream := newCryptoStream() + params := &wire.TransportParameters{ + InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), + MaxIdleTimeout: s.config.MaxIdleTimeout, + MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), + MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), + MaxAckDelay: protocol.MaxAckDelayInclGranularity, + AckDelayExponent: protocol.AckDelayExponent, + DisableActiveMigration: true, + ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + InitialSourceConnectionID: srcConnID, + } + if s.config.EnableDatagrams { + params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + } + if s.tracer != nil { + s.tracer.SentTransportParameters(params) + } + cs, clientHelloWritten := handshake.NewCryptoSetupClient( + initialStream, + handshakeStream, + destConnID, + conn.LocalAddr(), + conn.RemoteAddr(), + params, + &handshakeRunner{ + onReceivedParams: s.handleTransportParameters, + onError: s.closeLocal, + dropKeys: s.dropEncryptionLevel, + onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, + }, + tlsConf, + enable0RTT, + s.rttStats, + tracer, + logger, + s.version, + ) + s.clientHelloWritten = clientHelloWritten + s.cryptoStreamHandler = cs + s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) + s.unpacker = newPacketUnpacker(cs, s.version) + s.packer = newPacketPacker( + srcConnID, + s.connIDManager.Get, + initialStream, + handshakeStream, + s.sentPacketHandler, + s.retransmissionQueue, + s.RemoteAddr(), + cs, + s.framer, + s.receivedPacketHandler, + s.datagramQueue, + s.perspective, + s.version, + ) + if len(tlsConf.ServerName) > 0 { + s.tokenStoreKey = tlsConf.ServerName + } else { + s.tokenStoreKey = conn.RemoteAddr().String() + } + if s.config.TokenStore != nil { + if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil { + s.packer.SetToken(token.data) + } + } + return s +} + +func (s *connection) preSetup() { + s.sendQueue = newSendQueue(s.conn) + s.retransmissionQueue = newRetransmissionQueue(s.version) + s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams, s.version) + s.rttStats = &utils.RTTStats{} + s.connFlowController = flowcontrol.NewConnectionFlowController( + protocol.ByteCount(s.config.InitialConnectionReceiveWindow), + protocol.ByteCount(s.config.MaxConnectionReceiveWindow), + s.onHasConnectionWindowUpdate, + func(size protocol.ByteCount) bool { + if s.config.AllowConnectionWindowIncrease == nil { + return true + } + return s.config.AllowConnectionWindowIncrease(s, uint64(size)) + }, + s.rttStats, + s.logger, + ) + s.earlyConnReadyChan = make(chan struct{}) + s.streamsMap = newStreamsMap( + s, + s.newFlowController, + uint64(s.config.MaxIncomingStreams), + uint64(s.config.MaxIncomingUniStreams), + s.perspective, + s.version, + ) + s.framer = newFramer(s.streamsMap, s.version) + s.receivedPackets = make(chan *receivedPacket, protocol.MaxConnUnprocessedPackets) + s.closeChan = make(chan closeError, 1) + s.sendingScheduled = make(chan struct{}, 1) + s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background()) + + now := time.Now() + s.lastPacketReceivedTime = now + s.creationTime = now + + s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) + if s.config.EnableDatagrams { + s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) + } +} + +// run the connection main loop +func (s *connection) run() error { + defer s.ctxCancel() + + s.timer = utils.NewTimer() + + go s.cryptoStreamHandler.RunHandshake() + go func() { + if err := s.sendQueue.Run(); err != nil { + s.destroyImpl(err) + } + }() + + if s.perspective == protocol.PerspectiveClient { + select { + case zeroRTTParams := <-s.clientHelloWritten: + s.scheduleSending() + if zeroRTTParams != nil { + s.restoreTransportParameters(zeroRTTParams) + close(s.earlyConnReadyChan) + } + case closeErr := <-s.closeChan: + // put the close error back into the channel, so that the run loop can receive it + s.closeChan <- closeErr + } + } + + var ( + closeErr closeError + sendQueueAvailable <-chan struct{} + ) + +runLoop: + for { + // Close immediately if requested + select { + case closeErr = <-s.closeChan: + break runLoop + case <-s.handshakeCompleteChan: + s.handleHandshakeComplete() + default: + } + + s.maybeResetTimer() + + var processedUndecryptablePacket bool + if len(s.undecryptablePacketsToProcess) > 0 { + queue := s.undecryptablePacketsToProcess + s.undecryptablePacketsToProcess = nil + for _, p := range queue { + if processed := s.handlePacketImpl(p); processed { + processedUndecryptablePacket = true + } + // Don't set timers and send packets if the packet made us close the connection. + select { + case closeErr = <-s.closeChan: + break runLoop + default: + } + } + } + // If we processed any undecryptable packets, jump to the resetting of the timers directly. + if !processedUndecryptablePacket { + select { + case closeErr = <-s.closeChan: + break runLoop + case <-s.timer.Chan(): + s.timer.SetRead() + // We do all the interesting stuff after the switch statement, so + // nothing to see here. + case <-s.sendingScheduled: + // We do all the interesting stuff after the switch statement, so + // nothing to see here. + case <-sendQueueAvailable: + case firstPacket := <-s.receivedPackets: + wasProcessed := s.handlePacketImpl(firstPacket) + // Don't set timers and send packets if the packet made us close the connection. + select { + case closeErr = <-s.closeChan: + break runLoop + default: + } + if s.handshakeComplete { + // Now process all packets in the receivedPackets channel. + // Limit the number of packets to the length of the receivedPackets channel, + // so we eventually get a chance to send out an ACK when receiving a lot of packets. + numPackets := len(s.receivedPackets) + receiveLoop: + for i := 0; i < numPackets; i++ { + select { + case p := <-s.receivedPackets: + if processed := s.handlePacketImpl(p); processed { + wasProcessed = true + } + select { + case closeErr = <-s.closeChan: + break runLoop + default: + } + default: + break receiveLoop + } + } + } + // Only reset the timers if this packet was actually processed. + // This avoids modifying any state when handling undecryptable packets, + // which could be injected by an attacker. + if !wasProcessed { + continue + } + case <-s.handshakeCompleteChan: + s.handleHandshakeComplete() + } + } + + now := time.Now() + if timeout := s.sentPacketHandler.GetLossDetectionTimeout(); !timeout.IsZero() && timeout.Before(now) { + // This could cause packets to be retransmitted. + // Check it before trying to send packets. + if err := s.sentPacketHandler.OnLossDetectionTimeout(); err != nil { + s.closeLocal(err) + } + } + + if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() && !now.Before(keepAliveTime) { + // send a PING frame since there is no activity in the connection + s.logger.Debugf("Sending a keep-alive PING to keep the connection alive.") + s.framer.QueueControlFrame(&wire.PingFrame{}) + s.keepAlivePingSent = true + } else if !s.handshakeComplete && now.Sub(s.creationTime) >= s.config.handshakeTimeout() { + s.destroyImpl(qerr.ErrHandshakeTimeout) + continue + } else { + idleTimeoutStartTime := s.idleTimeoutStartTime() + if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) || + (s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.idleTimeout) { + s.destroyImpl(qerr.ErrIdleTimeout) + continue + } + } + + if s.sendQueue.WouldBlock() { + // The send queue is still busy sending out packets. + // Wait until there's space to enqueue new packets. + sendQueueAvailable = s.sendQueue.Available() + continue + } + if err := s.sendPackets(); err != nil { + s.closeLocal(err) + } + if s.sendQueue.WouldBlock() { + sendQueueAvailable = s.sendQueue.Available() + } else { + sendQueueAvailable = nil + } + } + + s.handleCloseError(&closeErr) + if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { + s.tracer.Close() + } + s.logger.Infof("Connection %s closed.", s.logID) + s.cryptoStreamHandler.Close() + s.sendQueue.Close() + s.timer.Stop() + return closeErr.err +} + +// blocks until the early connection can be used +func (s *connection) earlyConnReady() <-chan struct{} { + return s.earlyConnReadyChan +} + +func (s *connection) HandshakeComplete() context.Context { + return s.handshakeCtx +} + +func (s *connection) Context() context.Context { + return s.ctx +} + +func (s *connection) supportsDatagrams() bool { + return s.peerParams.MaxDatagramFrameSize != protocol.InvalidByteCount +} + +func (s *connection) ConnectionState() ConnectionState { + return ConnectionState{ + TLS: s.cryptoStreamHandler.ConnectionState(), + SupportsDatagrams: s.supportsDatagrams(), + } +} + +// Time when the next keep-alive packet should be sent. +// It returns a zero time if no keep-alive should be sent. +func (s *connection) nextKeepAliveTime() time.Time { + if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() { + return time.Time{} + } + return s.lastPacketReceivedTime.Add(s.keepAliveInterval) +} + +func (s *connection) maybeResetTimer() { + var deadline time.Time + if !s.handshakeComplete { + deadline = utils.MinTime( + s.creationTime.Add(s.config.handshakeTimeout()), + s.idleTimeoutStartTime().Add(s.config.HandshakeIdleTimeout), + ) + } else { + if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() { + deadline = keepAliveTime + } else { + deadline = s.idleTimeoutStartTime().Add(s.idleTimeout) + } + } + + if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() { + deadline = utils.MinTime(deadline, ackAlarm) + } + if lossTime := s.sentPacketHandler.GetLossDetectionTimeout(); !lossTime.IsZero() { + deadline = utils.MinTime(deadline, lossTime) + } + if !s.pacingDeadline.IsZero() { + deadline = utils.MinTime(deadline, s.pacingDeadline) + } + + s.timer.Reset(deadline) +} + +func (s *connection) idleTimeoutStartTime() time.Time { + return utils.MaxTime(s.lastPacketReceivedTime, s.firstAckElicitingPacketAfterIdleSentTime) +} + +func (s *connection) handleHandshakeComplete() { + s.handshakeComplete = true + s.handshakeCompleteChan = nil // prevent this case from ever being selected again + defer s.handshakeCtxCancel() + // Once the handshake completes, we have derived 1-RTT keys. + // There's no point in queueing undecryptable packets for later decryption any more. + s.undecryptablePackets = nil + + s.connIDManager.SetHandshakeComplete() + s.connIDGenerator.SetHandshakeComplete() + + if s.perspective == protocol.PerspectiveClient { + s.applyTransportParameters() + return + } + + s.handleHandshakeConfirmed() + + ticket, err := s.cryptoStreamHandler.GetSessionTicket() + if err != nil { + s.closeLocal(err) + } + if ticket != nil { + s.oneRTTStream.Write(ticket) + for s.oneRTTStream.HasData() { + s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize)) + } + } + token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr()) + if err != nil { + s.closeLocal(err) + } + s.queueControlFrame(&wire.NewTokenFrame{Token: token}) + s.queueControlFrame(&wire.HandshakeDoneFrame{}) +} + +func (s *connection) handleHandshakeConfirmed() { + s.handshakeConfirmed = true + s.sentPacketHandler.SetHandshakeConfirmed() + s.cryptoStreamHandler.SetHandshakeConfirmed() + + if !s.config.DisablePathMTUDiscovery { + maxPacketSize := s.peerParams.MaxUDPPayloadSize + if maxPacketSize == 0 { + maxPacketSize = protocol.MaxByteCount + } + maxPacketSize = utils.MinByteCount(maxPacketSize, protocol.MaxPacketBufferSize) + s.mtuDiscoverer = newMTUDiscoverer( + s.rttStats, + getMaxPacketSize(s.conn.RemoteAddr()), + maxPacketSize, + func(size protocol.ByteCount) { + s.sentPacketHandler.SetMaxDatagramSize(size) + s.packer.SetMaxPacketSize(size) + }, + ) + } +} + +func (s *connection) handlePacketImpl(rp *receivedPacket) bool { + s.sentPacketHandler.ReceivedBytes(rp.Size()) + + if wire.IsVersionNegotiationPacket(rp.data) { + s.handleVersionNegotiationPacket(rp) + return false + } + + var counter uint8 + var lastConnID protocol.ConnectionID + var processed bool + data := rp.data + p := rp + for len(data) > 0 { + if counter > 0 { + p = p.Clone() + p.data = data + } + + hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen) + if err != nil { + if s.tracer != nil { + dropReason := logging.PacketDropHeaderParseError + if err == wire.ErrUnsupportedVersion { + dropReason = logging.PacketDropUnsupportedVersion + } + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason) + } + s.logger.Debugf("error parsing packet: %s", err) + break + } + + if hdr.IsLongHeader && hdr.Version != s.version { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) + } + s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) + break + } + + if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) + } + s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) + break + } + lastConnID = hdr.DestConnectionID + + if counter > 0 { + p.buffer.Split() + } + counter++ + + // only log if this actually a coalesced packet + if s.logger.Debug() && (counter > 1 || len(rest) > 0) { + s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest)) + } + p.data = packetData + if wasProcessed := s.handleSinglePacket(p, hdr); wasProcessed { + processed = true + } + data = rest + } + p.buffer.MaybeRelease() + return processed +} + +func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { + var wasQueued bool + + defer func() { + // Put back the packet buffer if the packet wasn't queued for later decryption. + if !wasQueued { + p.buffer.Decrement() + } + }() + + if hdr.Type == protocol.PacketTypeRetry { + return s.handleRetryPacket(hdr, p.data) + } + + // The server can change the source connection ID with the first Handshake packet. + // After this, all packets with a different source connection have to be ignored. + if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && !hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) + } + s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID) + return false + } + // drop 0-RTT packets, if we are a client + if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketType0RTT, p.Size(), logging.PacketDropKeyUnavailable) + } + return false + } + + packet, err := s.unpacker.Unpack(hdr, p.rcvTime, p.data) + if err != nil { + switch err { + case handshake.ErrKeysDropped: + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropKeyUnavailable) + } + s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", hdr.PacketType(), p.Size()) + case handshake.ErrKeysNotYetAvailable: + // Sealer for this encryption level not yet available. + // Try again later. + wasQueued = true + s.tryQueueingUndecryptablePacket(p, hdr) + case wire.ErrInvalidReservedBits: + s.closeLocal(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: err.Error(), + }) + case handshake.ErrDecryptionFailed: + // This might be a packet injected by an attacker. Drop it. + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropPayloadDecryptError) + } + s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", hdr.PacketType(), p.Size(), err) + default: + var headerErr *headerParseError + if errors.As(err, &headerErr) { + // This might be a packet injected by an attacker. Drop it. + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", hdr.PacketType(), p.Size(), err) + } else { + // This is an error returned by the AEAD (other than ErrDecryptionFailed). + // For example, a PROTOCOL_VIOLATION due to key updates. + s.closeLocal(err) + } + } + return false + } + + if s.logger.Debug() { + s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, %s", packet.packetNumber, p.Size(), hdr.DestConnectionID, packet.encryptionLevel) + packet.hdr.Log(s.logger) + } + + if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.packetNumber, packet.encryptionLevel) { + s.logger.Debugf("Dropping (potentially) duplicate packet.") + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate) + } + return false + } + + if err := s.handleUnpackedPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil { + s.closeLocal(err) + return false + } + return true +} + +func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { + if s.perspective == protocol.PerspectiveServer { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + } + s.logger.Debugf("Ignoring Retry.") + return false + } + if s.receivedFirstPacket { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + } + s.logger.Debugf("Ignoring Retry, since we already received a packet.") + return false + } + destConnID := s.connIDManager.Get() + if hdr.SrcConnectionID.Equal(destConnID) { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + } + s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") + return false + } + // If a token is already set, this means that we already received a Retry from the server. + // Ignore this Retry packet. + if s.receivedRetry { + s.logger.Debugf("Ignoring Retry, since a Retry was already received.") + return false + } + + tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version) + if !bytes.Equal(data[len(data)-16:], tag[:]) { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) + } + s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.") + return false + } + + if s.logger.Debug() { + s.logger.Debugf("<- Received Retry:") + (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) + s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID) + } + if s.tracer != nil { + s.tracer.ReceivedRetry(hdr) + } + newDestConnID := hdr.SrcConnectionID + s.receivedRetry = true + if err := s.sentPacketHandler.ResetForRetry(); err != nil { + s.closeLocal(err) + return false + } + s.handshakeDestConnID = newDestConnID + s.retrySrcConnID = &newDestConnID + s.cryptoStreamHandler.ChangeConnectionID(newDestConnID) + s.packer.SetToken(hdr.Token) + s.connIDManager.ChangeInitialConnID(newDestConnID) + s.scheduleSending() + return true +} + +func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { + if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets + s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) + } + return + } + + hdr, supportedVersions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(p.data)) + if err != nil { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Error parsing Version Negotiation packet: %s", err) + return + } + + for _, v := range supportedVersions { + if v == s.version { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) + } + // The Version Negotiation packet contains the version that we offered. + // This might be a packet sent by an attacker, or it was corrupted. + return + } + } + + s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) + if s.tracer != nil { + s.tracer.ReceivedVersionNegotiationPacket(hdr, supportedVersions) + } + newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) + if !ok { + s.destroyImpl(&VersionNegotiationError{ + Ours: s.config.Versions, + Theirs: supportedVersions, + }) + s.logger.Infof("No compatible QUIC version found.") + return + } + if s.tracer != nil { + s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions) + } + + s.logger.Infof("Switching to QUIC version %s.", newVersion) + nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial) + s.destroyImpl(&errCloseForRecreating{ + nextPacketNumber: nextPN, + nextVersion: newVersion, + }) +} + +func (s *connection) handleUnpackedPacket( + packet *unpackedPacket, + ecn protocol.ECN, + rcvTime time.Time, + packetSize protocol.ByteCount, // only for logging +) error { + if len(packet.data) == 0 { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "empty packet", + } + } + + if !s.receivedFirstPacket { + s.receivedFirstPacket = true + if !s.versionNegotiated && s.tracer != nil { + var clientVersions, serverVersions []protocol.VersionNumber + switch s.perspective { + case protocol.PerspectiveClient: + clientVersions = s.config.Versions + case protocol.PerspectiveServer: + serverVersions = s.config.Versions + } + s.tracer.NegotiatedVersion(s.version, clientVersions, serverVersions) + } + // The server can change the source connection ID with the first Handshake packet. + if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + cid := packet.hdr.SrcConnectionID + s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid) + s.handshakeDestConnID = cid + s.connIDManager.ChangeInitialConnID(cid) + } + // We create the connection as soon as we receive the first packet from the client. + // We do that before authenticating the packet. + // That means that if the source connection ID was corrupted, + // we might have create a connection with an incorrect source connection ID. + // Once we authenticate the first packet, we need to update it. + if s.perspective == protocol.PerspectiveServer { + if !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + s.handshakeDestConnID = packet.hdr.SrcConnectionID + s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) + } + if s.tracer != nil { + s.tracer.StartedConnection( + s.conn.LocalAddr(), + s.conn.RemoteAddr(), + packet.hdr.SrcConnectionID, + packet.hdr.DestConnectionID, + ) + } + } + } + + s.lastPacketReceivedTime = rcvTime + s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} + s.keepAlivePingSent = false + + // Only used for tracing. + // If we're not tracing, this slice will always remain empty. + var frames []wire.Frame + r := bytes.NewReader(packet.data) + var isAckEliciting bool + for { + frame, err := s.frameParser.ParseNext(r, packet.encryptionLevel) + if err != nil { + return err + } + if frame == nil { + break + } + if ackhandler.IsFrameAckEliciting(frame) { + isAckEliciting = true + } + // Only process frames now if we're not logging. + // If we're logging, we need to make sure that the packet_received event is logged first. + if s.tracer == nil { + if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { + return err + } + } else { + frames = append(frames, frame) + } + } + + if s.tracer != nil { + fs := make([]logging.Frame, len(frames)) + for i, frame := range frames { + fs[i] = logutils.ConvertFrame(frame) + } + s.tracer.ReceivedPacket(packet.hdr, packetSize, fs) + for _, frame := range frames { + if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { + return err + } + } + } + + return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) +} + +func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { + var err error + wire.LogFrame(s.logger, f, false) + switch frame := f.(type) { + case *wire.CryptoFrame: + err = s.handleCryptoFrame(frame, encLevel) + case *wire.StreamFrame: + err = s.handleStreamFrame(frame) + case *wire.AckFrame: + err = s.handleAckFrame(frame, encLevel) + case *wire.ConnectionCloseFrame: + s.handleConnectionCloseFrame(frame) + case *wire.ResetStreamFrame: + err = s.handleResetStreamFrame(frame) + case *wire.MaxDataFrame: + s.handleMaxDataFrame(frame) + case *wire.MaxStreamDataFrame: + err = s.handleMaxStreamDataFrame(frame) + case *wire.MaxStreamsFrame: + s.handleMaxStreamsFrame(frame) + case *wire.DataBlockedFrame: + case *wire.StreamDataBlockedFrame: + case *wire.StreamsBlockedFrame: + case *wire.StopSendingFrame: + err = s.handleStopSendingFrame(frame) + case *wire.PingFrame: + case *wire.PathChallengeFrame: + s.handlePathChallengeFrame(frame) + case *wire.PathResponseFrame: + // since we don't send PATH_CHALLENGEs, we don't expect PATH_RESPONSEs + err = errors.New("unexpected PATH_RESPONSE frame") + case *wire.NewTokenFrame: + err = s.handleNewTokenFrame(frame) + case *wire.NewConnectionIDFrame: + err = s.handleNewConnectionIDFrame(frame) + case *wire.RetireConnectionIDFrame: + err = s.handleRetireConnectionIDFrame(frame, destConnID) + case *wire.HandshakeDoneFrame: + err = s.handleHandshakeDoneFrame() + case *wire.DatagramFrame: + err = s.handleDatagramFrame(frame) + default: + err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name()) + } + return err +} + +// handlePacket is called by the server with a new packet +func (s *connection) handlePacket(p *receivedPacket) { + // Discard packets once the amount of queued packets is larger than + // the channel size, protocol.MaxConnUnprocessedPackets + select { + case s.receivedPackets <- p: + default: + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + } + } +} + +func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) { + if frame.IsApplicationError { + s.closeRemote(&qerr.ApplicationError{ + Remote: true, + ErrorCode: qerr.ApplicationErrorCode(frame.ErrorCode), + ErrorMessage: frame.ReasonPhrase, + }) + return + } + s.closeRemote(&qerr.TransportError{ + Remote: true, + ErrorCode: qerr.TransportErrorCode(frame.ErrorCode), + FrameType: frame.FrameType, + ErrorMessage: frame.ReasonPhrase, + }) +} + +func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { + encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) + if err != nil { + return err + } + if encLevelChanged { + // Queue all packets for decryption that have been undecryptable so far. + s.undecryptablePacketsToProcess = s.undecryptablePackets + s.undecryptablePackets = nil + } + return nil +} + +func (s *connection) handleStreamFrame(frame *wire.StreamFrame) error { + str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // Stream is closed and already garbage collected + // ignore this StreamFrame + return nil + } + return str.handleStreamFrame(frame) +} + +func (s *connection) handleMaxDataFrame(frame *wire.MaxDataFrame) { + s.connFlowController.UpdateSendWindow(frame.MaximumData) +} + +func (s *connection) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { + str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // stream is closed and already garbage collected + return nil + } + str.updateSendWindow(frame.MaximumStreamData) + return nil +} + +func (s *connection) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) { + s.streamsMap.HandleMaxStreamsFrame(frame) +} + +func (s *connection) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { + str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // stream is closed and already garbage collected + return nil + } + return str.handleResetStreamFrame(frame) +} + +func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error { + str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // stream is closed and already garbage collected + return nil + } + str.handleStopSendingFrame(frame) + return nil +} + +func (s *connection) handlePathChallengeFrame(frame *wire.PathChallengeFrame) { + s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data}) +} + +func (s *connection) handleNewTokenFrame(frame *wire.NewTokenFrame) error { + if s.perspective == protocol.PerspectiveServer { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received NEW_TOKEN frame from the client", + } + } + if s.config.TokenStore != nil { + s.config.TokenStore.Put(s.tokenStoreKey, &ClientToken{data: frame.Token}) + } + return nil +} + +func (s *connection) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) error { + return s.connIDManager.Add(f) +} + +func (s *connection) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error { + return s.connIDGenerator.Retire(f.SequenceNumber, destConnID) +} + +func (s *connection) handleHandshakeDoneFrame() error { + if s.perspective == protocol.PerspectiveServer { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received a HANDSHAKE_DONE frame", + } + } + if !s.handshakeConfirmed { + s.handleHandshakeConfirmed() + } + return nil +} + +func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { + acked1RTTPacket, err := s.sentPacketHandler.ReceivedAck(frame, encLevel, s.lastPacketReceivedTime) + if err != nil { + return err + } + if !acked1RTTPacket { + return nil + } + if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed { + s.handleHandshakeConfirmed() + } + return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) +} + +func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error { + if f.Length(s.version) > protocol.MaxDatagramFrameSize { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "DATAGRAM frame too large", + } + } + s.datagramQueue.HandleDatagramFrame(f) + return nil +} + +// closeLocal closes the connection and send a CONNECTION_CLOSE containing the error +func (s *connection) closeLocal(e error) { + s.closeOnce.Do(func() { + if e == nil { + s.logger.Infof("Closing connection.") + } else { + s.logger.Errorf("Closing connection with error: %s", e) + } + s.closeChan <- closeError{err: e, immediate: false, remote: false} + }) +} + +// destroy closes the connection without sending the error on the wire +func (s *connection) destroy(e error) { + s.destroyImpl(e) + <-s.ctx.Done() +} + +func (s *connection) destroyImpl(e error) { + s.closeOnce.Do(func() { + if nerr, ok := e.(net.Error); ok && nerr.Timeout() { + s.logger.Errorf("Destroying connection: %s", e) + } else { + s.logger.Errorf("Destroying connection with error: %s", e) + } + s.closeChan <- closeError{err: e, immediate: true, remote: false} + }) +} + +func (s *connection) closeRemote(e error) { + s.closeOnce.Do(func() { + s.logger.Errorf("Peer closed connection with error: %s", e) + s.closeChan <- closeError{err: e, immediate: true, remote: true} + }) +} + +// Close the connection. It sends a NO_ERROR application error. +// It waits until the run loop has stopped before returning +func (s *connection) shutdown() { + s.closeLocal(nil) + <-s.ctx.Done() +} + +func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error { + s.closeLocal(&qerr.ApplicationError{ + ErrorCode: code, + ErrorMessage: desc, + }) + <-s.ctx.Done() + return nil +} + +func (s *connection) handleCloseError(closeErr *closeError) { + e := closeErr.err + if e == nil { + e = &qerr.ApplicationError{} + } else { + defer func() { + closeErr.err = e + }() + } + + var ( + statelessResetErr *StatelessResetError + versionNegotiationErr *VersionNegotiationError + recreateErr *errCloseForRecreating + applicationErr *ApplicationError + transportErr *TransportError + ) + switch { + case errors.Is(e, qerr.ErrIdleTimeout), + errors.Is(e, qerr.ErrHandshakeTimeout), + errors.As(e, &statelessResetErr), + errors.As(e, &versionNegotiationErr), + errors.As(e, &recreateErr), + errors.As(e, &applicationErr), + errors.As(e, &transportErr): + default: + e = &qerr.TransportError{ + ErrorCode: qerr.InternalError, + ErrorMessage: e.Error(), + } + } + + s.streamsMap.CloseWithError(e) + s.connIDManager.Close() + if s.datagramQueue != nil { + s.datagramQueue.CloseWithError(e) + } + + if s.tracer != nil && !errors.As(e, &recreateErr) { + s.tracer.ClosedConnection(e) + } + + // If this is a remote close we're done here + if closeErr.remote { + s.connIDGenerator.ReplaceWithClosed(newClosedRemoteConn(s.perspective)) + return + } + if closeErr.immediate { + s.connIDGenerator.RemoveAll() + return + } + // Don't send out any CONNECTION_CLOSE if this is an error that occurred + // before we even sent out the first packet. + if s.perspective == protocol.PerspectiveClient && !s.sentFirstPacket { + s.connIDGenerator.RemoveAll() + return + } + connClosePacket, err := s.sendConnectionClose(e) + if err != nil { + s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) + } + cs := newClosedLocalConn(s.conn, connClosePacket, s.perspective, s.logger) + s.connIDGenerator.ReplaceWithClosed(cs) +} + +func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { + s.sentPacketHandler.DropPackets(encLevel) + s.receivedPacketHandler.DropPackets(encLevel) + if s.tracer != nil { + s.tracer.DroppedEncryptionLevel(encLevel) + } + if encLevel == protocol.Encryption0RTT { + s.streamsMap.ResetFor0RTT() + if err := s.connFlowController.Reset(); err != nil { + s.closeLocal(err) + } + if err := s.framer.Handle0RTTRejection(); err != nil { + s.closeLocal(err) + } + } +} + +// is called for the client, when restoring transport parameters saved for 0-RTT +func (s *connection) restoreTransportParameters(params *wire.TransportParameters) { + if s.logger.Debug() { + s.logger.Debugf("Restoring Transport Parameters: %s", params) + } + + s.peerParams = params + s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) + s.connFlowController.UpdateSendWindow(params.InitialMaxData) + s.streamsMap.UpdateLimits(params) +} + +func (s *connection) handleTransportParameters(params *wire.TransportParameters) { + if err := s.checkTransportParameters(params); err != nil { + s.closeLocal(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: err.Error(), + }) + } + s.peerParams = params + // On the client side we have to wait for handshake completion. + // During a 0-RTT connection, we are only allowed to use the new transport parameters for 1-RTT packets. + if s.perspective == protocol.PerspectiveServer { + s.applyTransportParameters() + // On the server side, the early connection is ready as soon as we processed + // the client's transport parameters. + close(s.earlyConnReadyChan) + } +} + +func (s *connection) checkTransportParameters(params *wire.TransportParameters) error { + if s.logger.Debug() { + s.logger.Debugf("Processed Transport Parameters: %s", params) + } + if s.tracer != nil { + s.tracer.ReceivedTransportParameters(params) + } + + // check the initial_source_connection_id + if !params.InitialSourceConnectionID.Equal(s.handshakeDestConnID) { + return fmt.Errorf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID) + } + + if s.perspective == protocol.PerspectiveServer { + return nil + } + // check the original_destination_connection_id + if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { + return fmt.Errorf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID) + } + if s.retrySrcConnID != nil { // a Retry was performed + if params.RetrySourceConnectionID == nil { + return errors.New("missing retry_source_connection_id") + } + if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) { + return fmt.Errorf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID) + } + } else if params.RetrySourceConnectionID != nil { + return errors.New("received retry_source_connection_id, although no Retry was performed") + } + return nil +} + +func (s *connection) applyTransportParameters() { + params := s.peerParams + // Our local idle timeout will always be > 0. + s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) + s.keepAliveInterval = utils.MinDuration(s.config.KeepAlivePeriod, utils.MinDuration(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) + s.streamsMap.UpdateLimits(params) + s.packer.HandleTransportParameters(params) + s.frameParser.SetAckDelayExponent(params.AckDelayExponent) + s.connFlowController.UpdateSendWindow(params.InitialMaxData) + s.rttStats.SetMaxAckDelay(params.MaxAckDelay) + s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) + if params.StatelessResetToken != nil { + s.connIDManager.SetStatelessResetToken(*params.StatelessResetToken) + } + // We don't support connection migration yet, so we don't have any use for the preferred_address. + if params.PreferredAddress != nil { + // Retire the connection ID. + s.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, params.PreferredAddress.StatelessResetToken) + } +} + +func (s *connection) sendPackets() error { + s.pacingDeadline = time.Time{} + + var sentPacket bool // only used in for packets sent in send mode SendAny + for { + sendMode := s.sentPacketHandler.SendMode() + if sendMode == ackhandler.SendAny && s.handshakeComplete && !s.sentPacketHandler.HasPacingBudget() { + deadline := s.sentPacketHandler.TimeUntilSend() + if deadline.IsZero() { + deadline = deadlineSendImmediately + } + s.pacingDeadline = deadline + // Allow sending of an ACK if we're pacing limit (if we haven't sent out a packet yet). + // This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate) + // sends enough ACKs to allow its peer to utilize the bandwidth. + if sentPacket { + return nil + } + sendMode = ackhandler.SendAck + } + switch sendMode { + case ackhandler.SendNone: + return nil + case ackhandler.SendAck: + // If we already sent packets, and the send mode switches to SendAck, + // as we've just become congestion limited. + // There's no need to try to send an ACK at this moment. + if sentPacket { + return nil + } + // We can at most send a single ACK only packet. + // There will only be a new ACK after receiving new packets. + // SendAck is only returned when we're congestion limited, so we don't need to set the pacingt timer. + return s.maybeSendAckOnlyPacket() + case ackhandler.SendPTOInitial: + if err := s.sendProbePacket(protocol.EncryptionInitial); err != nil { + return err + } + case ackhandler.SendPTOHandshake: + if err := s.sendProbePacket(protocol.EncryptionHandshake); err != nil { + return err + } + case ackhandler.SendPTOAppData: + if err := s.sendProbePacket(protocol.Encryption1RTT); err != nil { + return err + } + case ackhandler.SendAny: + sent, err := s.sendPacket() + if err != nil || !sent { + return err + } + sentPacket = true + default: + return fmt.Errorf("BUG: invalid send mode %d", sendMode) + } + // Prioritize receiving of packets over sending out more packets. + if len(s.receivedPackets) > 0 { + s.pacingDeadline = deadlineSendImmediately + return nil + } + if s.sendQueue.WouldBlock() { + return nil + } + } +} + +func (s *connection) maybeSendAckOnlyPacket() error { + packet, err := s.packer.MaybePackAckPacket(s.handshakeConfirmed) + if err != nil { + return err + } + if packet == nil { + return nil + } + s.sendPackedPacket(packet, time.Now()) + return nil +} + +func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { + // Queue probe packets until we actually send out a packet, + // or until there are no more packets to queue. + var packet *packedPacket + for { + if wasQueued := s.sentPacketHandler.QueueProbePacket(encLevel); !wasQueued { + break + } + var err error + packet, err = s.packer.MaybePackProbePacket(encLevel) + if err != nil { + return err + } + if packet != nil { + break + } + } + if packet == nil { + //nolint:exhaustive // Cannot send probe packets for 0-RTT. + switch encLevel { + case protocol.EncryptionInitial: + s.retransmissionQueue.AddInitial(&wire.PingFrame{}) + case protocol.EncryptionHandshake: + s.retransmissionQueue.AddHandshake(&wire.PingFrame{}) + case protocol.Encryption1RTT: + s.retransmissionQueue.AddAppData(&wire.PingFrame{}) + default: + panic("unexpected encryption level") + } + var err error + packet, err = s.packer.MaybePackProbePacket(encLevel) + if err != nil { + return err + } + } + if packet == nil || packet.packetContents == nil { + return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) + } + s.sendPackedPacket(packet, time.Now()) + return nil +} + +func (s *connection) sendPacket() (bool, error) { + if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { + s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset}) + } + s.windowUpdateQueue.QueueAll() + + now := time.Now() + if !s.handshakeConfirmed { + packet, err := s.packer.PackCoalescedPacket() + if err != nil || packet == nil { + return false, err + } + s.sentFirstPacket = true + s.logCoalescedPacket(packet) + for _, p := range packet.packets { + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { + s.firstAckElicitingPacketAfterIdleSentTime = now + } + s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue)) + } + s.connIDManager.SentPacket() + s.sendQueue.Send(packet.buffer) + return true, nil + } + if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) { + packet, err := s.packer.PackMTUProbePacket(s.mtuDiscoverer.GetPing()) + if err != nil { + return false, err + } + s.sendPackedPacket(packet, now) + return true, nil + } + packet, err := s.packer.PackPacket() + if err != nil || packet == nil { + return false, err + } + s.sendPackedPacket(packet, now) + return true, nil +} + +func (s *connection) sendPackedPacket(packet *packedPacket, now time.Time) { + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && packet.IsAckEliciting() { + s.firstAckElicitingPacketAfterIdleSentTime = now + } + s.logPacket(packet) + s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(now, s.retransmissionQueue)) + s.connIDManager.SentPacket() + s.sendQueue.Send(packet.buffer) +} + +func (s *connection) sendConnectionClose(e error) ([]byte, error) { + var packet *coalescedPacket + var err error + var transportErr *qerr.TransportError + var applicationErr *qerr.ApplicationError + if errors.As(e, &transportErr) { + packet, err = s.packer.PackConnectionClose(transportErr) + } else if errors.As(e, &applicationErr) { + packet, err = s.packer.PackApplicationClose(applicationErr) + } else { + packet, err = s.packer.PackConnectionClose(&qerr.TransportError{ + ErrorCode: qerr.InternalError, + ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()), + }) + } + if err != nil { + return nil, err + } + s.logCoalescedPacket(packet) + return packet.buffer.Data, s.conn.Write(packet.buffer.Data) +} + +func (s *connection) logPacketContents(p *packetContents) { + // tracing + if s.tracer != nil { + frames := make([]logging.Frame, 0, len(p.frames)) + for _, f := range p.frames { + frames = append(frames, logutils.ConvertFrame(f.Frame)) + } + s.tracer.SentPacket(p.header, p.length, p.ack, frames) + } + + // quic-go logging + if !s.logger.Debug() { + return + } + p.header.Log(s.logger) + if p.ack != nil { + wire.LogFrame(s.logger, p.ack, true) + } + for _, frame := range p.frames { + wire.LogFrame(s.logger, frame.Frame, true) + } +} + +func (s *connection) logCoalescedPacket(packet *coalescedPacket) { + if s.logger.Debug() { + if len(packet.packets) > 1 { + s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.packets), packet.buffer.Len(), s.logID) + } else { + s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.packets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.packets[0].EncryptionLevel()) + } + } + for _, p := range packet.packets { + s.logPacketContents(p) + } +} + +func (s *connection) logPacket(packet *packedPacket) { + if s.logger.Debug() { + s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.header.PacketNumber, packet.buffer.Len(), s.logID, packet.EncryptionLevel()) + } + s.logPacketContents(packet.packetContents) +} + +// AcceptStream returns the next stream openend by the peer +func (s *connection) AcceptStream(ctx context.Context) (Stream, error) { + return s.streamsMap.AcceptStream(ctx) +} + +func (s *connection) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { + return s.streamsMap.AcceptUniStream(ctx) +} + +// OpenStream opens a stream +func (s *connection) OpenStream() (Stream, error) { + return s.streamsMap.OpenStream() +} + +func (s *connection) OpenStreamSync(ctx context.Context) (Stream, error) { + return s.streamsMap.OpenStreamSync(ctx) +} + +func (s *connection) OpenUniStream() (SendStream, error) { + return s.streamsMap.OpenUniStream() +} + +func (s *connection) OpenUniStreamSync(ctx context.Context) (SendStream, error) { + return s.streamsMap.OpenUniStreamSync(ctx) +} + +func (s *connection) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController { + initialSendWindow := s.peerParams.InitialMaxStreamDataUni + if id.Type() == protocol.StreamTypeBidi { + if id.InitiatedBy() == s.perspective { + initialSendWindow = s.peerParams.InitialMaxStreamDataBidiRemote + } else { + initialSendWindow = s.peerParams.InitialMaxStreamDataBidiLocal + } + } + return flowcontrol.NewStreamFlowController( + id, + s.connFlowController, + protocol.ByteCount(s.config.InitialStreamReceiveWindow), + protocol.ByteCount(s.config.MaxStreamReceiveWindow), + initialSendWindow, + s.onHasStreamWindowUpdate, + s.rttStats, + s.logger, + ) +} + +// scheduleSending signals that we have data for sending +func (s *connection) scheduleSending() { + select { + case s.sendingScheduled <- struct{}{}: + default: + } +} + +func (s *connection) tryQueueingUndecryptablePacket(p *receivedPacket, hdr *wire.Header) { + if s.handshakeComplete { + panic("shouldn't queue undecryptable packets after handshake completion") + } + if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDOSPrevention) + } + s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size()) + return + } + s.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size()) + if s.tracer != nil { + s.tracer.BufferedPacket(logging.PacketTypeFromHeader(hdr)) + } + s.undecryptablePackets = append(s.undecryptablePackets, p) +} + +func (s *connection) queueControlFrame(f wire.Frame) { + s.framer.QueueControlFrame(f) + s.scheduleSending() +} + +func (s *connection) onHasStreamWindowUpdate(id protocol.StreamID) { + s.windowUpdateQueue.AddStream(id) + s.scheduleSending() +} + +func (s *connection) onHasConnectionWindowUpdate() { + s.windowUpdateQueue.AddConnection() + s.scheduleSending() +} + +func (s *connection) onHasStreamData(id protocol.StreamID) { + s.framer.AddActiveStream(id) + s.scheduleSending() +} + +func (s *connection) onStreamCompleted(id protocol.StreamID) { + if err := s.streamsMap.DeleteStream(id); err != nil { + s.closeLocal(err) + } +} + +func (s *connection) SendMessage(p []byte) error { + f := &wire.DatagramFrame{DataLenPresent: true} + if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) { + return errors.New("message too large") + } + f.Data = make([]byte, len(p)) + copy(f.Data, p) + return s.datagramQueue.AddAndWait(f) +} + +func (s *connection) ReceiveMessage() ([]byte, error) { + return s.datagramQueue.Receive() +} + +func (s *connection) LocalAddr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *connection) RemoteAddr() net.Addr { + return s.conn.RemoteAddr() +} + +func (s *connection) getPerspective() protocol.Perspective { + return s.perspective +} + +func (s *connection) GetVersion() protocol.VersionNumber { + return s.version +} + +func (s *connection) NextConnection() Connection { + <-s.HandshakeComplete().Done() + s.streamsMap.UseResetMaps() + return s +} diff --git a/internal/quic-go/connection_test.go b/internal/quic-go/connection_test.go new file mode 100644 index 00000000..66c88156 --- /dev/null +++ b/internal/quic-go/connection_test.go @@ -0,0 +1,3038 @@ +package quic + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "runtime/pprof" + "strings" + "time" + + "github.com/imroc/req/v3/internal/quic-go/ackhandler" + "github.com/imroc/req/v3/internal/quic-go/handshake" + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/mocks" + mockackhandler "github.com/imroc/req/v3/internal/quic-go/mocks/ackhandler" + mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/testutils" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func areConnsRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*connection).run") +} + +func areClosedConnsRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*closedLocalConn).run") +} + +var _ = Describe("Connection", func() { + var ( + conn *connection + connRunner *MockConnRunner + mconn *MockSendConn + streamManager *MockStreamManager + packer *MockPacker + cryptoSetup *mocks.MockCryptoSetup + tracer *mocklogging.MockConnectionTracer + ) + remoteAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7331} + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + + getPacket := func(pn protocol.PacketNumber) *packedPacket { + buffer := getPacketBuffer() + buffer.Data = append(buffer.Data, []byte("foobar")...) + return &packedPacket{ + buffer: buffer, + packetContents: &packetContents{ + header: &wire.ExtendedHeader{PacketNumber: pn}, + length: 6, // foobar + }, + } + } + + expectReplaceWithClosed := func() { + connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).MaxTimes(1) + connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{})) + s.shutdown() + Eventually(areClosedConnsRunning).Should(BeFalse()) + }) + } + + BeforeEach(func() { + Eventually(areConnsRunning).Should(BeFalse()) + + connRunner = NewMockConnRunner(mockCtrl) + mconn = NewMockSendConn(mockCtrl) + mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() + mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() + tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) + Expect(err).ToNot(HaveOccurred()) + tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) + tracer.EXPECT().SentTransportParameters(gomock.Any()) + tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() + tracer.EXPECT().UpdatedCongestionState(gomock.Any()) + conn = newConnection( + mconn, + connRunner, + nil, + nil, + clientDestConnID, + destConnID, + srcConnID, + protocol.StatelessResetToken{}, + populateServerConfig(&Config{DisablePathMTUDiscovery: true}), + nil, // tls.Config + tokenGenerator, + false, + tracer, + 1234, + utils.DefaultLogger, + protocol.VersionTLS, + ).(*connection) + streamManager = NewMockStreamManager(mockCtrl) + conn.streamsMap = streamManager + packer = NewMockPacker(mockCtrl) + conn.packer = packer + cryptoSetup = mocks.NewMockCryptoSetup(mockCtrl) + conn.cryptoStreamHandler = cryptoSetup + conn.handshakeComplete = true + conn.idleTimeout = time.Hour + }) + + AfterEach(func() { + Eventually(areConnsRunning).Should(BeFalse()) + }) + + Context("frame handling", func() { + Context("handling STREAM frames", func() { + It("passes STREAM frames to the stream", func() { + f := &wire.StreamFrame{ + StreamID: 5, + Data: []byte{0xde, 0xca, 0xfb, 0xad}, + } + str := NewMockReceiveStreamI(mockCtrl) + str.EXPECT().handleStreamFrame(f) + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil) + Expect(conn.handleStreamFrame(f)).To(Succeed()) + }) + + It("returns errors", func() { + testErr := errors.New("test err") + f := &wire.StreamFrame{ + StreamID: 5, + Data: []byte{0xde, 0xca, 0xfb, 0xad}, + } + str := NewMockReceiveStreamI(mockCtrl) + str.EXPECT().handleStreamFrame(f).Return(testErr) + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil) + Expect(conn.handleStreamFrame(f)).To(MatchError(testErr)) + }) + + It("ignores STREAM frames for closed streams", func() { + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(nil, nil) // for closed streams, the streamManager returns nil + Expect(conn.handleStreamFrame(&wire.StreamFrame{ + StreamID: 5, + Data: []byte("foobar"), + })).To(Succeed()) + }) + }) + + Context("handling ACK frames", func() { + It("informs the SentPacketHandler about ACKs", func() { + f := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 3}}} + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().ReceivedAck(f, protocol.EncryptionHandshake, gomock.Any()) + conn.sentPacketHandler = sph + err := conn.handleAckFrame(f, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Context("handling RESET_STREAM frames", func() { + It("closes the streams for writing", func() { + f := &wire.ResetStreamFrame{ + StreamID: 555, + ErrorCode: 42, + FinalSize: 0x1337, + } + str := NewMockReceiveStreamI(mockCtrl) + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(555)).Return(str, nil) + str.EXPECT().handleResetStreamFrame(f) + err := conn.handleResetStreamFrame(f) + Expect(err).ToNot(HaveOccurred()) + }) + + It("returns errors", func() { + f := &wire.ResetStreamFrame{ + StreamID: 7, + FinalSize: 0x1337, + } + testErr := errors.New("flow control violation") + str := NewMockReceiveStreamI(mockCtrl) + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(7)).Return(str, nil) + str.EXPECT().handleResetStreamFrame(f).Return(testErr) + err := conn.handleResetStreamFrame(f) + Expect(err).To(MatchError(testErr)) + }) + + It("ignores RESET_STREAM frames for closed streams", func() { + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(nil, nil) + Expect(conn.handleFrame(&wire.ResetStreamFrame{ + StreamID: 3, + ErrorCode: 42, + }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) + }) + }) + + Context("handling MAX_DATA and MAX_STREAM_DATA frames", func() { + var connFC *mocks.MockConnectionFlowController + + BeforeEach(func() { + connFC = mocks.NewMockConnectionFlowController(mockCtrl) + conn.connFlowController = connFC + }) + + It("updates the flow control window of a stream", func() { + f := &wire.MaxStreamDataFrame{ + StreamID: 12345, + MaximumStreamData: 0x1337, + } + str := NewMockSendStreamI(mockCtrl) + streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(12345)).Return(str, nil) + str.EXPECT().updateSendWindow(protocol.ByteCount(0x1337)) + Expect(conn.handleMaxStreamDataFrame(f)).To(Succeed()) + }) + + It("updates the flow control window of the connection", func() { + offset := protocol.ByteCount(0x800000) + connFC.EXPECT().UpdateSendWindow(offset) + conn.handleMaxDataFrame(&wire.MaxDataFrame{MaximumData: offset}) + }) + + It("ignores MAX_STREAM_DATA frames for a closed stream", func() { + streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(10)).Return(nil, nil) + Expect(conn.handleFrame(&wire.MaxStreamDataFrame{ + StreamID: 10, + MaximumStreamData: 1337, + }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) + }) + }) + + Context("handling MAX_STREAM_ID frames", func() { + It("passes the frame to the streamsMap", func() { + f := &wire.MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreamNum: 10, + } + streamManager.EXPECT().HandleMaxStreamsFrame(f) + conn.handleMaxStreamsFrame(f) + }) + }) + + Context("handling STOP_SENDING frames", func() { + It("passes the frame to the stream", func() { + f := &wire.StopSendingFrame{ + StreamID: 5, + ErrorCode: 10, + } + str := NewMockSendStreamI(mockCtrl) + streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(5)).Return(str, nil) + str.EXPECT().handleStopSendingFrame(f) + err := conn.handleStopSendingFrame(f) + Expect(err).ToNot(HaveOccurred()) + }) + + It("ignores STOP_SENDING frames for a closed stream", func() { + streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(3)).Return(nil, nil) + Expect(conn.handleFrame(&wire.StopSendingFrame{ + StreamID: 3, + ErrorCode: 1337, + }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) + }) + }) + + It("handles NEW_CONNECTION_ID frames", func() { + Expect(conn.handleFrame(&wire.NewConnectionIDFrame{ + SequenceNumber: 10, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) + Expect(conn.connIDManager.queue.Back().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + }) + + It("handles PING frames", func() { + err := conn.handleFrame(&wire.PingFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) + Expect(err).NotTo(HaveOccurred()) + }) + + It("rejects PATH_RESPONSE frames", func() { + err := conn.handleFrame(&wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, protocol.Encryption1RTT, protocol.ConnectionID{}) + Expect(err).To(MatchError("unexpected PATH_RESPONSE frame")) + }) + + It("handles PATH_CHALLENGE frames", func() { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + err := conn.handleFrame(&wire.PathChallengeFrame{Data: data}, protocol.Encryption1RTT, protocol.ConnectionID{}) + Expect(err).ToNot(HaveOccurred()) + frames, _ := conn.framer.AppendControlFrames(nil, 1000) + Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PathResponseFrame{Data: data}}})) + }) + + It("rejects NEW_TOKEN frames", func() { + err := conn.handleNewTokenFrame(&wire.NewTokenFrame{}) + Expect(err).To(HaveOccurred()) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ProtocolViolation)) + }) + + It("handles BLOCKED frames", func() { + err := conn.handleFrame(&wire.DataBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) + Expect(err).NotTo(HaveOccurred()) + }) + + It("handles STREAM_BLOCKED frames", func() { + err := conn.handleFrame(&wire.StreamDataBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) + Expect(err).NotTo(HaveOccurred()) + }) + + It("handles STREAMS_BLOCKED frames", func() { + err := conn.handleFrame(&wire.StreamsBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) + Expect(err).NotTo(HaveOccurred()) + }) + + It("handles CONNECTION_CLOSE frames, with a transport error code", func() { + expectedErr := &qerr.TransportError{ + Remote: true, + ErrorCode: qerr.StreamLimitError, + ErrorMessage: "foobar", + } + streamManager.EXPECT().CloseWithError(expectedErr) + connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) + }) + connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) + }) + cryptoSetup.EXPECT().Close() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(expectedErr), + tracer.EXPECT().Close(), + ) + + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + Expect(conn.run()).To(MatchError(expectedErr)) + }() + Expect(conn.handleFrame(&wire.ConnectionCloseFrame{ + ErrorCode: uint64(qerr.StreamLimitError), + ReasonPhrase: "foobar", + }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("handles CONNECTION_CLOSE frames, with an application error code", func() { + testErr := &qerr.ApplicationError{ + Remote: true, + ErrorCode: 0x1337, + ErrorMessage: "foobar", + } + streamManager.EXPECT().CloseWithError(testErr) + connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) + }) + connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) + }) + cryptoSetup.EXPECT().Close() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(testErr), + tracer.EXPECT().Close(), + ) + + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + Expect(conn.run()).To(MatchError(testErr)) + }() + ccf := &wire.ConnectionCloseFrame{ + ErrorCode: 0x1337, + ReasonPhrase: "foobar", + IsApplicationError: true, + } + Expect(conn.handleFrame(ccf, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("errors on HANDSHAKE_DONE frames", func() { + Expect(conn.handleHandshakeDoneFrame()).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received a HANDSHAKE_DONE frame", + })) + }) + }) + + It("tells its versions", func() { + conn.version = 4242 + Expect(conn.GetVersion()).To(Equal(protocol.VersionNumber(4242))) + }) + + Context("closing", func() { + var ( + runErr chan error + expectedRunErr error + ) + + BeforeEach(func() { + runErr = make(chan error, 1) + expectedRunErr = nil + }) + + AfterEach(func() { + if expectedRunErr != nil { + Eventually(runErr).Should(Receive(MatchError(expectedRunErr))) + } else { + Eventually(runErr).Should(Receive()) + } + }) + + runConn := func() { + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + runErr <- conn.run() + }() + Eventually(areConnsRunning).Should(BeTrue()) + } + + It("shuts down without error", func() { + conn.handshakeComplete = true + runConn() + streamManager.EXPECT().CloseWithError(&qerr.ApplicationError{}) + expectReplaceWithClosed() + cryptoSetup.EXPECT().Close() + buffer := getPacketBuffer() + buffer.Data = append(buffer.Data, []byte("connection close")...) + packer.EXPECT().PackApplicationClose(gomock.Any()).DoAndReturn(func(e *qerr.ApplicationError) (*coalescedPacket, error) { + Expect(e.ErrorCode).To(BeEquivalentTo(qerr.NoError)) + Expect(e.ErrorMessage).To(BeEmpty()) + return &coalescedPacket{buffer: buffer}, nil + }) + mconn.EXPECT().Write([]byte("connection close")) + gomock.InOrder( + tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { + var appErr *ApplicationError + Expect(errors.As(e, &appErr)).To(BeTrue()) + Expect(appErr.Remote).To(BeFalse()) + Expect(appErr.ErrorCode).To(BeZero()) + }), + tracer.EXPECT().Close(), + ) + conn.shutdown() + Eventually(areConnsRunning).Should(BeFalse()) + Expect(conn.Context().Done()).To(BeClosed()) + }) + + It("only closes once", func() { + runConn() + streamManager.EXPECT().CloseWithError(gomock.Any()) + expectReplaceWithClosed() + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.shutdown() + conn.shutdown() + Eventually(areConnsRunning).Should(BeFalse()) + Expect(conn.Context().Done()).To(BeClosed()) + }) + + It("closes with an error", func() { + runConn() + expectedErr := &qerr.ApplicationError{ + ErrorCode: 0x1337, + ErrorMessage: "test error", + } + streamManager.EXPECT().CloseWithError(expectedErr) + expectReplaceWithClosed() + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackApplicationClose(expectedErr).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + mconn.EXPECT().Write(gomock.Any()) + gomock.InOrder( + tracer.EXPECT().ClosedConnection(expectedErr), + tracer.EXPECT().Close(), + ) + conn.CloseWithError(0x1337, "test error") + Eventually(areConnsRunning).Should(BeFalse()) + Expect(conn.Context().Done()).To(BeClosed()) + }) + + It("includes the frame type in transport-level close frames", func() { + runConn() + expectedErr := &qerr.TransportError{ + ErrorCode: 0x1337, + FrameType: 0x42, + ErrorMessage: "test error", + } + streamManager.EXPECT().CloseWithError(expectedErr) + expectReplaceWithClosed() + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackConnectionClose(expectedErr).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + mconn.EXPECT().Write(gomock.Any()) + gomock.InOrder( + tracer.EXPECT().ClosedConnection(expectedErr), + tracer.EXPECT().Close(), + ) + conn.closeLocal(expectedErr) + Eventually(areConnsRunning).Should(BeFalse()) + Expect(conn.Context().Done()).To(BeClosed()) + }) + + It("destroys the connection", func() { + runConn() + testErr := errors.New("close") + streamManager.EXPECT().CloseWithError(gomock.Any()) + connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + cryptoSetup.EXPECT().Close() + // don't EXPECT any calls to mconn.Write() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { + var transportErr *TransportError + Expect(errors.As(e, &transportErr)).To(BeTrue()) + Expect(transportErr.Remote).To(BeFalse()) + Expect(transportErr.ErrorCode).To(Equal(qerr.InternalError)) + }), + tracer.EXPECT().Close(), + ) + conn.destroy(testErr) + Eventually(areConnsRunning).Should(BeFalse()) + expectedRunErr = &qerr.TransportError{ + ErrorCode: qerr.InternalError, + ErrorMessage: testErr.Error(), + } + }) + + It("cancels the context when the run loop exists", func() { + runConn() + streamManager.EXPECT().CloseWithError(gomock.Any()) + expectReplaceWithClosed() + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + returned := make(chan struct{}) + go func() { + defer GinkgoRecover() + ctx := conn.Context() + <-ctx.Done() + Expect(ctx.Err()).To(MatchError(context.Canceled)) + close(returned) + }() + Consistently(returned).ShouldNot(BeClosed()) + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.shutdown() + Eventually(returned).Should(BeClosed()) + }) + + It("doesn't send any more packets after receiving a CONNECTION_CLOSE", func() { + unpacker := NewMockUnpacker(mockCtrl) + conn.handshakeConfirmed = true + conn.unpacker = unpacker + runConn() + cryptoSetup.EXPECT().Close() + streamManager.EXPECT().CloseWithError(gomock.Any()) + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes() + buf := &bytes.Buffer{} + hdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumberLen: protocol.PacketNumberLen2, + } + Expect(hdr.Write(buf, conn.version)).To(Succeed()) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) { + buf := &bytes.Buffer{} + Expect((&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Write(buf, conn.version)).To(Succeed()) + return &unpackedPacket{ + hdr: hdr, + data: buf.Bytes(), + encryptionLevel: protocol.Encryption1RTT, + }, nil + }) + gomock.InOrder( + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()), + tracer.EXPECT().ClosedConnection(gomock.Any()), + tracer.EXPECT().Close(), + ) + // don't EXPECT any calls to packer.PackPacket() + conn.handlePacket(&receivedPacket{ + rcvTime: time.Now(), + remoteAddr: &net.UDPAddr{}, + buffer: getPacketBuffer(), + data: buf.Bytes(), + }) + // Consistently(pack).ShouldNot(Receive()) + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("closes when the sendQueue encounters an error", func() { + conn.handshakeConfirmed = true + sconn := NewMockSendConn(mockCtrl) + sconn.EXPECT().Write(gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() + conn.sendQueue = newSendQueue(sconn) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + // only expect a single SentPacket() call + sph.EXPECT().SentPacket(gomock.Any()) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + streamManager.EXPECT().CloseWithError(gomock.Any()) + connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + cryptoSetup.EXPECT().Close() + conn.sentPacketHandler = sph + p := getPacket(1) + packer.EXPECT().PackPacket().Return(p, nil) + packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() + runConn() + conn.queueControlFrame(&wire.PingFrame{}) + conn.scheduleSending() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("closes due to a stateless reset", func() { + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + runConn() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { + var srErr *StatelessResetError + Expect(errors.As(e, &srErr)).To(BeTrue()) + Expect(srErr.Token).To(Equal(token)) + }), + tracer.EXPECT().Close(), + ) + streamManager.EXPECT().CloseWithError(gomock.Any()) + connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + cryptoSetup.EXPECT().Close() + conn.destroy(&StatelessResetError{Token: token}) + }) + }) + + Context("receiving packets", func() { + var unpacker *MockUnpacker + + BeforeEach(func() { + unpacker = NewMockUnpacker(mockCtrl) + conn.unpacker = unpacker + }) + + getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { + buf := &bytes.Buffer{} + Expect(extHdr.Write(buf, conn.version)).To(Succeed()) + return &receivedPacket{ + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + rcvTime: time.Now(), + } + } + + It("drops Retry packets", func() { + p := getPacket(&wire.ExtendedHeader{Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + DestConnectionID: destConnID, + SrcConnectionID: srcConnID, + Version: conn.version, + Token: []byte("foobar"), + }}, make([]byte, 16) /* Retry integrity tag */) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) + }) + + It("drops Version Negotiation packets", func() { + b := wire.ComposeVersionNegotiation(srcConnID, destConnID, conn.config.Versions) + tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) + Expect(conn.handlePacketImpl(&receivedPacket{ + data: b, + buffer: getPacketBuffer(), + })).To(BeFalse()) + }) + + It("drops packets for which header decryption fails", func() { + p := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Version: conn.version, + }, + PacketNumberLen: protocol.PacketNumberLen2, + }, nil) + p.data[0] ^= 0x40 // unset the QUIC bit + tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) + }) + + It("drops packets for which the version is unsupported", func() { + p := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Version: conn.version + 1, + }, + PacketNumberLen: protocol.PacketNumberLen2, + }, nil) + tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnsupportedVersion) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) + }) + + It("drops packets with an unsupported version", func() { + origSupportedVersions := make([]protocol.VersionNumber, len(protocol.SupportedVersions)) + copy(origSupportedVersions, protocol.SupportedVersions) + defer func() { + protocol.SupportedVersions = origSupportedVersions + }() + + protocol.SupportedVersions = append(protocol.SupportedVersions, conn.version+1) + p := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: destConnID, + SrcConnectionID: srcConnID, + Version: conn.version + 1, + }, + PacketNumberLen: protocol.PacketNumberLen2, + }, nil) + tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedVersion) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) + }) + + It("informs the ReceivedPacketHandler about non-ack-eliciting packets", func() { + hdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumber: 0x37, + PacketNumberLen: protocol.PacketNumberLen1, + } + packet := getPacket(hdr, nil) + packet.ecn = protocol.ECNCE + rcvTime := time.Now().Add(-10 * time.Second) + unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ + packetNumber: 0x1337, + encryptionLevel: protocol.EncryptionInitial, + hdr: hdr, + data: []byte{0}, // one PADDING frame + }, nil) + rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) + gomock.InOrder( + rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial), + rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNCE, protocol.EncryptionInitial, rcvTime, false), + ) + conn.receivedPacketHandler = rph + packet.rcvTime = rcvTime + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), []logging.Frame{}) + Expect(conn.handlePacketImpl(packet)).To(BeTrue()) + }) + + It("informs the ReceivedPacketHandler about ack-eliciting packets", func() { + hdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumber: 0x37, + PacketNumberLen: protocol.PacketNumberLen1, + } + rcvTime := time.Now().Add(-10 * time.Second) + buf := &bytes.Buffer{} + Expect((&wire.PingFrame{}).Write(buf, conn.version)).To(Succeed()) + packet := getPacket(hdr, nil) + packet.ecn = protocol.ECT1 + unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ + packetNumber: 0x1337, + encryptionLevel: protocol.Encryption1RTT, + hdr: hdr, + data: buf.Bytes(), + }, nil) + rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) + gomock.InOrder( + rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT), + rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECT1, protocol.Encryption1RTT, rcvTime, true), + ) + conn.receivedPacketHandler = rph + packet.rcvTime = rcvTime + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}}) + Expect(conn.handlePacketImpl(packet)).To(BeTrue()) + }) + + It("drops duplicate packets", func() { + hdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumber: 0x37, + PacketNumberLen: protocol.PacketNumberLen1, + } + packet := getPacket(hdr, nil) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ + packetNumber: 0x1337, + encryptionLevel: protocol.Encryption1RTT, + hdr: hdr, + data: []byte("foobar"), + }, nil) + rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) + rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true) + conn.receivedPacketHandler = rph + tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate) + Expect(conn.handlePacketImpl(packet)).To(BeFalse()) + }) + + It("drops a packet when unpacking fails", func() { + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) + streamManager.EXPECT().CloseWithError(gomock.Any()) + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + expectReplaceWithClosed() + p := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: srcConnID, + Version: conn.version, + Length: 2 + 6, + }, + PacketNumber: 0x1337, + PacketNumberLen: protocol.PacketNumberLen2, + }, []byte("foobar")) + tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, p.Size(), logging.PacketDropPayloadDecryptError) + conn.handlePacket(p) + Consistently(conn.Context().Done()).ShouldNot(BeClosed()) + // make the go routine return + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + conn.closeLocal(errors.New("close")) + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("processes multiple received packets before sending one", func() { + conn.creationTime = time.Now() + var pn protocol.PacketNumber + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { + pn++ + return &unpackedPacket{ + data: []byte{0}, // PADDING frame + encryptionLevel: protocol.Encryption1RTT, + packetNumber: pn, + hdr: &wire.ExtendedHeader{Header: *hdr}, + }, nil + }).Times(3) + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ []logging.Frame) { + }).Times(3) + packer.EXPECT().PackCoalescedPacket() // only expect a single call + + for i := 0; i < 3; i++ { + conn.handlePacket(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumber: 0x1337, + PacketNumberLen: protocol.PacketNumberLen2, + }, []byte("foobar"))) + } + + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + Consistently(conn.Context().Done()).ShouldNot(BeClosed()) + + // make the go routine return + streamManager.EXPECT().CloseWithError(gomock.Any()) + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + expectReplaceWithClosed() + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + conn.closeLocal(errors.New("close")) + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("doesn't processes multiple received packets before sending one before handshake completion", func() { + conn.handshakeComplete = false + conn.creationTime = time.Now() + var pn protocol.PacketNumber + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { + pn++ + return &unpackedPacket{ + data: []byte{0}, // PADDING frame + encryptionLevel: protocol.Encryption1RTT, + packetNumber: pn, + hdr: &wire.ExtendedHeader{Header: *hdr}, + }, nil + }).Times(3) + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ []logging.Frame) { + }).Times(3) + packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call + + for i := 0; i < 3; i++ { + conn.handlePacket(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumber: 0x1337, + PacketNumberLen: protocol.PacketNumberLen2, + }, []byte("foobar"))) + } + + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + Consistently(conn.Context().Done()).ShouldNot(BeClosed()) + + // make the go routine return + streamManager.EXPECT().CloseWithError(gomock.Any()) + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + expectReplaceWithClosed() + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + conn.closeLocal(errors.New("close")) + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("closes the connection when unpacking fails because the reserved bits were incorrect", func() { + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits) + streamManager.EXPECT().CloseWithError(gomock.Any()) + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + err := conn.run() + Expect(err).To(HaveOccurred()) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ProtocolViolation)) + close(done) + }() + expectReplaceWithClosed() + mconn.EXPECT().Write(gomock.Any()) + packet := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.handlePacket(packet) + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("ignores packets when unpacking the header fails", func() { + testErr := &headerParseError{errors.New("test error")} + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) + streamManager.EXPECT().CloseWithError(gomock.Any()) + cryptoSetup.EXPECT().Close() + runErr := make(chan error) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + runErr <- conn.run() + }() + expectReplaceWithClosed() + tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, gomock.Any(), logging.PacketDropHeaderParseError) + conn.handlePacket(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil)) + Consistently(runErr).ShouldNot(Receive()) + // make the go routine return + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("closes the connection when unpacking fails because of an error other than a decryption error", func() { + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}) + streamManager.EXPECT().CloseWithError(gomock.Any()) + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + err := conn.run() + Expect(err).To(HaveOccurred()) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ConnectionIDLimitError)) + close(done) + }() + expectReplaceWithClosed() + mconn.EXPECT().Write(gomock.Any()) + packet := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.handlePacket(packet) + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("rejects packets with empty payload", func() { + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ + hdr: &wire.ExtendedHeader{}, + data: []byte{}, // no payload + encryptionLevel: protocol.Encryption1RTT, + }, nil) + streamManager.EXPECT().CloseWithError(gomock.Any()) + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + Expect(conn.run()).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "empty packet", + })) + close(done) + }() + expectReplaceWithClosed() + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.handlePacket(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil)) + Eventually(done).Should(BeClosed()) + }) + + It("ignores packets with a different source connection ID", func() { + hdr1 := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: destConnID, + SrcConnectionID: srcConnID, + Length: 1, + Version: conn.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 1, + } + hdr2 := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: destConnID, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + Length: 1, + Version: conn.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 2, + } + Expect(srcConnID).ToNot(Equal(hdr2.SrcConnectionID)) + // Send one packet, which might change the connection ID. + // only EXPECT one call to the unpacker + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ + encryptionLevel: protocol.Encryption1RTT, + hdr: hdr1, + data: []byte{0}, // one PADDING frame + }, nil) + p1 := getPacket(hdr1, nil) + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any()) + Expect(conn.handlePacketImpl(p1)).To(BeTrue()) + // The next packet has to be ignored, since the source connection ID doesn't match. + p2 := getPacket(hdr2, nil) + tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.ByteCount(len(p2.data)), logging.PacketDropUnknownConnectionID) + Expect(conn.handlePacketImpl(p2)).To(BeFalse()) + }) + + It("queues undecryptable packets", func() { + conn.handshakeComplete = false + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: destConnID, + SrcConnectionID: srcConnID, + Length: 1, + Version: conn.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 1, + } + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable) + packet := getPacket(hdr, nil) + tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake) + Expect(conn.handlePacketImpl(packet)).To(BeFalse()) + Expect(conn.undecryptablePackets).To(Equal([]*receivedPacket{packet})) + }) + + Context("updating the remote address", func() { + It("doesn't support connection migration", func() { + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ + encryptionLevel: protocol.Encryption1RTT, + hdr: &wire.ExtendedHeader{}, + data: []byte{0}, // one PADDING frame + }, nil) + packet := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil) + packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) + Expect(conn.handlePacketImpl(packet)).To(BeTrue()) + }) + }) + + Context("coalesced packets", func() { + BeforeEach(func() { + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) + }) + getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) { + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: connID, + SrcConnectionID: destConnID, + Version: protocol.VersionTLS, + Length: length, + }, + PacketNumberLen: protocol.PacketNumberLen3, + } + hdrLen := hdr.GetLength(conn.version) + b := make([]byte, 1) + rand.Read(b) + packet := getPacket(hdr, bytes.Repeat(b, int(length)-3)) + return int(hdrLen), packet + } + + It("cuts packets to the right length", func() { + hdrLen, packet := getPacketWithLength(srcConnID, 456) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(hdrLen + 456 - 3)) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + hdr: &wire.ExtendedHeader{}, + }, nil + }) + tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) + Expect(conn.handlePacketImpl(packet)).To(BeTrue()) + }) + + It("handles coalesced packets", func() { + hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + packetNumber: 1, + hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}}, + }, nil + }) + hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + packetNumber: 2, + hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}}, + }, nil + }) + gomock.InOrder( + tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), + tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), + ) + packet1.data = append(packet1.data, packet2.data...) + Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) + }) + + It("works with undecryptable packets", func() { + conn.handshakeComplete = false + hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) + hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) + gomock.InOrder( + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable), + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + hdr: &wire.ExtendedHeader{}, + }, nil + }), + ) + gomock.InOrder( + tracer.EXPECT().BufferedPacket(gomock.Any()), + tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), + ) + packet1.data = append(packet1.data, packet2.data...) + Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) + + Expect(conn.undecryptablePackets).To(HaveLen(1)) + Expect(conn.undecryptablePackets[0].data).To(HaveLen(hdrLen1 + 456 - 3)) + }) + + It("ignores coalesced packet parts if the destination connection IDs don't match", func() { + wrongConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + Expect(srcConnID).ToNot(Equal(wrongConnID)) + hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + hdr: &wire.ExtendedHeader{}, + }, nil + }) + _, packet2 := getPacketWithLength(wrongConnID, 123) + // don't EXPECT any more calls to unpacker.Unpack() + gomock.InOrder( + tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), + tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID), + ) + packet1.data = append(packet1.data, packet2.data...) + Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) + }) + }) + }) + + Context("sending packets", func() { + var ( + connDone chan struct{} + sender *MockSender + ) + + BeforeEach(func() { + sender = NewMockSender(mockCtrl) + sender.EXPECT().Run() + sender.EXPECT().WouldBlock().AnyTimes() + conn.sendQueue = sender + connDone = make(chan struct{}) + }) + + AfterEach(func() { + streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + expectReplaceWithClosed() + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + sender.EXPECT().Close() + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + Eventually(connDone).Should(BeClosed()) + }) + + runConn := func() { + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + close(connDone) + }() + } + + It("sends packets", func() { + conn.handshakeConfirmed = true + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any()) + conn.sentPacketHandler = sph + runConn() + p := getPacket(1) + packer.EXPECT().PackPacket().Return(p, nil) + packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() + sent := make(chan struct{}) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) + tracer.EXPECT().SentPacket(p.header, p.buffer.Len(), nil, []logging.Frame{}) + conn.scheduleSending() + Eventually(sent).Should(BeClosed()) + }) + + It("doesn't send packets if there's nothing to send", func() { + conn.handshakeConfirmed = true + runConn() + packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() + conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) + conn.scheduleSending() + time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() + }) + + It("sends ACK only packets", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAck) + done := make(chan struct{}) + packer.EXPECT().MaybePackAckPacket(false).Do(func(bool) { close(done) }) + conn.sentPacketHandler = sph + runConn() + conn.scheduleSending() + Eventually(done).Should(BeClosed()) + }) + + It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { + conn.handshakeConfirmed = true + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any()) + conn.sentPacketHandler = sph + fc := mocks.NewMockConnectionFlowController(mockCtrl) + fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) + fc.EXPECT().IsNewlyBlocked() + p := getPacket(1) + packer.EXPECT().PackPacket().Return(p, nil) + packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() + conn.connFlowController = fc + runConn() + sent := make(chan struct{}) + sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) + tracer.EXPECT().SentPacket(p.header, p.length, nil, []logging.Frame{}) + conn.scheduleSending() + Eventually(sent).Should(BeClosed()) + frames, _ := conn.framer.AppendControlFrames(nil, 1000) + Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &logging.DataBlockedFrame{MaximumData: 1337}}})) + }) + + It("doesn't send when the SentPacketHandler doesn't allow it", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendNone).AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + conn.sentPacketHandler = sph + runConn() + conn.scheduleSending() + time.Sleep(50 * time.Millisecond) + }) + + for _, enc := range []protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption1RTT} { + encLevel := enc + + Context(fmt.Sprintf("sending %s probe packets", encLevel), func() { + var sendMode ackhandler.SendMode + var getFrame func(protocol.ByteCount) wire.Frame + + BeforeEach(func() { + //nolint:exhaustive + switch encLevel { + case protocol.EncryptionInitial: + sendMode = ackhandler.SendPTOInitial + getFrame = conn.retransmissionQueue.GetInitialFrame + case protocol.EncryptionHandshake: + sendMode = ackhandler.SendPTOHandshake + getFrame = conn.retransmissionQueue.GetHandshakeFrame + case protocol.Encryption1RTT: + sendMode = ackhandler.SendPTOAppData + getFrame = conn.retransmissionQueue.GetAppDataFrame + } + }) + + It("sends a probe packet", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().SendMode().Return(sendMode) + sph.EXPECT().SendMode().Return(ackhandler.SendNone) + sph.EXPECT().QueueProbePacket(encLevel) + p := getPacket(123) + packer.EXPECT().MaybePackProbePacket(encLevel).Return(p, nil) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) + }) + conn.sentPacketHandler = sph + runConn() + sent := make(chan struct{}) + sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) + tracer.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any()) + conn.scheduleSending() + Eventually(sent).Should(BeClosed()) + }) + + It("sends a PING as a probe packet", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().SendMode().Return(sendMode) + sph.EXPECT().SendMode().Return(ackhandler.SendNone) + sph.EXPECT().QueueProbePacket(encLevel).Return(false) + p := getPacket(123) + packer.EXPECT().MaybePackProbePacket(encLevel).Return(p, nil) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) + }) + conn.sentPacketHandler = sph + runConn() + sent := make(chan struct{}) + sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) + tracer.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any()) + conn.scheduleSending() + Eventually(sent).Should(BeClosed()) + // We're using a mock packet packer in this test. + // We therefore need to test separately that the PING was actually queued. + Expect(getFrame(1000)).To(BeAssignableToTypeOf(&wire.PingFrame{})) + }) + }) + } + }) + + Context("packet pacing", func() { + var ( + sph *mockackhandler.MockSentPacketHandler + sender *MockSender + ) + + BeforeEach(func() { + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + conn.handshakeConfirmed = true + conn.handshakeComplete = true + conn.sentPacketHandler = sph + sender = NewMockSender(mockCtrl) + sender.EXPECT().Run() + conn.sendQueue = sender + streamManager.EXPECT().CloseWithError(gomock.Any()) + }) + + AfterEach(func() { + // make the go routine return + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + expectReplaceWithClosed() + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + sender.EXPECT().Close() + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("sends multiple packets one by one immediately", func() { + sph.EXPECT().SentPacket(gomock.Any()).Times(2) + sph.EXPECT().HasPacingBudget().Return(true).Times(2) + sph.EXPECT().HasPacingBudget() + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(3) + packer.EXPECT().PackPacket().Return(getPacket(10), nil) + packer.EXPECT().PackPacket().Return(getPacket(11), nil) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any()).Times(2) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + time.Sleep(50 * time.Millisecond) // make sure that only 2 packets are sent + }) + + It("sends multiple packets, when the pacer allows immediate sending", func() { + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) + packer.EXPECT().PackPacket().Return(getPacket(10), nil) + packer.EXPECT().PackPacket().Return(nil, nil) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any()) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + time.Sleep(50 * time.Millisecond) // make sure that only 1 packet is sent + }) + + It("allows an ACK to be sent when pacing limited", func() { + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().HasPacingBudget() + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) + sph.EXPECT().SendMode().Return(ackhandler.SendAny) + packer.EXPECT().MaybePackAckPacket(gomock.Any()).Return(getPacket(10), nil) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any()) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + time.Sleep(50 * time.Millisecond) // make sure that only 1 packet is sent + }) + + // when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck + // we shouldn't send the ACK in the same run + It("doesn't send an ACK right after becoming congestion limited", func() { + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().HasPacingBudget().Return(true) + sph.EXPECT().SendMode().Return(ackhandler.SendAny) + sph.EXPECT().SendMode().Return(ackhandler.SendAck) + packer.EXPECT().PackPacket().Return(getPacket(100), nil) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any()) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + time.Sleep(50 * time.Millisecond) // make sure that only 1 packet is sent + }) + + It("paces packets", func() { + pacingDelay := scaleDuration(100 * time.Millisecond) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + gomock.InOrder( + sph.EXPECT().HasPacingBudget().Return(true), + packer.EXPECT().PackPacket().Return(getPacket(100), nil), + sph.EXPECT().SentPacket(gomock.Any()), + sph.EXPECT().HasPacingBudget(), + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), + sph.EXPECT().HasPacingBudget().Return(true), + packer.EXPECT().PackPacket().Return(getPacket(101), nil), + sph.EXPECT().SentPacket(gomock.Any()), + sph.EXPECT().HasPacingBudget(), + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)), + ) + written := make(chan struct{}, 2) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }).Times(2) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + Eventually(written).Should(HaveLen(1)) + Consistently(written, pacingDelay/2).Should(HaveLen(1)) + Eventually(written, 2*pacingDelay).Should(HaveLen(2)) + }) + + It("sends multiple packets at once", func() { + sph.EXPECT().SentPacket(gomock.Any()).Times(3) + sph.EXPECT().HasPacingBudget().Return(true).Times(3) + sph.EXPECT().HasPacingBudget() + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(4) + packer.EXPECT().PackPacket().Return(getPacket(1000), nil) + packer.EXPECT().PackPacket().Return(getPacket(1001), nil) + packer.EXPECT().PackPacket().Return(getPacket(1002), nil) + written := make(chan struct{}, 3) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }).Times(3) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + Eventually(written).Should(HaveLen(3)) + }) + + It("doesn't try to send if the send queue is full", func() { + available := make(chan struct{}, 1) + sender.EXPECT().WouldBlock().Return(true) + sender.EXPECT().Available().Return(available) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + time.Sleep(scaleDuration(50 * time.Millisecond)) + + written := make(chan struct{}) + sender.EXPECT().WouldBlock().AnyTimes() + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + packer.EXPECT().PackPacket().Return(getPacket(1000), nil) + packer.EXPECT().PackPacket().Return(nil, nil) + sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) + available <- struct{}{} + Eventually(written).Should(BeClosed()) + }) + + It("stops sending when there are new packets to receive", func() { + sender.EXPECT().WouldBlock().AnyTimes() + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + + written := make(chan struct{}) + sender.EXPECT().WouldBlock().AnyTimes() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(*ackhandler.Packet) { + sph.EXPECT().ReceivedBytes(gomock.Any()) + conn.handlePacket(&receivedPacket{buffer: getPacketBuffer()}) + }) + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + packer.EXPECT().PackPacket().Return(getPacket(1000), nil) + packer.EXPECT().PackPacket().Return(nil, nil) + sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) + + conn.scheduleSending() + time.Sleep(scaleDuration(50 * time.Millisecond)) + + Eventually(written).Should(BeClosed()) + }) + + It("stops sending when the send queue is full", func() { + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny) + packer.EXPECT().PackPacket().Return(getPacket(1000), nil) + written := make(chan struct{}, 1) + sender.EXPECT().WouldBlock() + sender.EXPECT().WouldBlock().Return(true).Times(2) + sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + available := make(chan struct{}, 1) + sender.EXPECT().Available().Return(available) + conn.scheduleSending() + Eventually(written).Should(Receive()) + time.Sleep(scaleDuration(50 * time.Millisecond)) + + // now make room in the send queue + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sender.EXPECT().WouldBlock().AnyTimes() + packer.EXPECT().PackPacket().Return(getPacket(1001), nil) + packer.EXPECT().PackPacket().Return(nil, nil) + sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) + available <- struct{}{} + Eventually(written).Should(Receive()) + + // The send queue is not full any more. Sending on the available channel should have no effect. + available <- struct{}{} + time.Sleep(scaleDuration(50 * time.Millisecond)) + }) + + It("doesn't set a pacing timer when there is no data to send", func() { + sph.EXPECT().HasPacingBudget().Return(true) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sender.EXPECT().WouldBlock().AnyTimes() + packer.EXPECT().PackPacket() + // don't EXPECT any calls to mconn.Write() + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() // no packet will get sent + time.Sleep(50 * time.Millisecond) + }) + + It("sends a Path MTU probe packet", func() { + mtuDiscoverer := NewMockMtuDiscoverer(mockCtrl) + conn.mtuDiscoverer = mtuDiscoverer + conn.config.DisablePathMTUDiscovery = false + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny) + sph.EXPECT().SendMode().Return(ackhandler.SendNone) + written := make(chan struct{}, 1) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) + mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true) + ping := ackhandler.Frame{Frame: &wire.PingFrame{}} + mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) + packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234)).Return(getPacket(1), nil) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + Eventually(written).Should(Receive()) + }) + }) + + Context("scheduling sending", func() { + var sender *MockSender + + BeforeEach(func() { + sender = NewMockSender(mockCtrl) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Run() + conn.sendQueue = sender + conn.handshakeConfirmed = true + }) + + AfterEach(func() { + // make the go routine return + expectReplaceWithClosed() + streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + sender.EXPECT().Close() + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("sends when scheduleSending is called", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any()) + conn.sentPacketHandler = sph + packer.EXPECT().PackPacket().Return(getPacket(1), nil) + packer.EXPECT().PackPacket().Return(nil, nil) + + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + // don't EXPECT any calls to mconn.Write() + time.Sleep(50 * time.Millisecond) + // only EXPECT calls after scheduleSending is called + written := make(chan struct{}) + sender.EXPECT().Send(gomock.Any()).Do(func(*packetBuffer) { close(written) }) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + conn.scheduleSending() + Eventually(written).Should(BeClosed()) + }) + + It("sets the timer to the ack timer", func() { + packer.EXPECT().PackPacket().Return(getPacket(1234), nil) + packer.EXPECT().PackPacket().Return(nil, nil) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(1234))) + }) + conn.sentPacketHandler = sph + rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) + rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) + // make the run loop wait + rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1) + conn.receivedPacketHandler = rph + + written := make(chan struct{}) + sender.EXPECT().Send(gomock.Any()).Do(func(*packetBuffer) { close(written) }) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + Eventually(written).Should(BeClosed()) + }) + }) + + It("sends coalesced packets before the handshake is confirmed", func() { + conn.handshakeComplete = false + conn.handshakeConfirmed = false + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + conn.sentPacketHandler = sph + buffer := getPacketBuffer() + buffer.Data = append(buffer.Data, []byte("foobar")...) + packer.EXPECT().PackCoalescedPacket().Return(&coalescedPacket{ + buffer: buffer, + packets: []*packetContents{ + { + header: &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + }, + PacketNumber: 13, + }, + length: 123, + }, + { + header: &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + }, + PacketNumber: 37, + }, + length: 1234, + }, + }, + }, nil) + packer.EXPECT().PackCoalescedPacket().AnyTimes() + + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().TimeUntilSend().Return(time.Now()).AnyTimes() + gomock.InOrder( + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionInitial)) + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(13))) + Expect(p.Length).To(BeEquivalentTo(123)) + }), + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionHandshake)) + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(37))) + Expect(p.Length).To(BeEquivalentTo(1234)) + }), + ) + gomock.InOrder( + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) { + Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) + }), + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) { + Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) + }), + ) + + sent := make(chan struct{}) + mconn.EXPECT().Write([]byte("foobar")).Do(func([]byte) { close(sent) }) + + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + + conn.scheduleSending() + Eventually(sent).Should(BeClosed()) + + // make sure the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) + expectReplaceWithClosed() + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("cancels the HandshakeComplete context when the handshake completes", func() { + packer.EXPECT().PackCoalescedPacket().AnyTimes() + finishHandshake := make(chan struct{}) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + conn.sentPacketHandler = sph + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().SendMode().AnyTimes() + sph.EXPECT().SetHandshakeConfirmed() + connRunner.EXPECT().Retire(clientDestConnID) + go func() { + defer GinkgoRecover() + <-finishHandshake + cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().SetHandshakeConfirmed() + cryptoSetup.EXPECT().GetSessionTicket() + close(conn.handshakeCompleteChan) + conn.run() + }() + handshakeCtx := conn.HandshakeComplete() + Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + close(finishHandshake) + Eventually(handshakeCtx.Done()).Should(BeClosed()) + // make sure the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) + expectReplaceWithClosed() + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("sends a connection ticket when the handshake completes", func() { + const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2 + packer.EXPECT().PackCoalescedPacket().AnyTimes() + finishHandshake := make(chan struct{}) + connRunner.EXPECT().Retire(clientDestConnID) + go func() { + defer GinkgoRecover() + <-finishHandshake + cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().SetHandshakeConfirmed() + cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) + close(conn.handshakeCompleteChan) + conn.run() + }() + + handshakeCtx := conn.HandshakeComplete() + Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + close(finishHandshake) + var frames []ackhandler.Frame + Eventually(func() []ackhandler.Frame { + frames, _ = conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) + return frames + }).ShouldNot(BeEmpty()) + var count int + var s int + for _, f := range frames { + if cf, ok := f.Frame.(*wire.CryptoFrame); ok { + count++ + s += len(cf.Data) + Expect(f.Length(conn.version)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize)) + } + } + Expect(size).To(BeEquivalentTo(s)) + // make sure the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) + expectReplaceWithClosed() + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { + packer.EXPECT().PackCoalescedPacket().AnyTimes() + streamManager.EXPECT().CloseWithError(gomock.Any()) + expectReplaceWithClosed() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + cryptoSetup.EXPECT().Close() + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake() + conn.run() + }() + handshakeCtx := conn.HandshakeComplete() + Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + mconn.EXPECT().Write(gomock.Any()) + conn.closeLocal(errors.New("handshake error")) + Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SetHandshakeConfirmed() + sph.EXPECT().SentPacket(gomock.Any()) + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + conn.sentPacketHandler = sph + done := make(chan struct{}) + connRunner.EXPECT().Retire(clientDestConnID) + packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { + frames, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) + Expect(frames).ToNot(BeEmpty()) + Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) + defer close(done) + return &packedPacket{ + packetContents: &packetContents{ + header: &wire.ExtendedHeader{}, + }, + buffer: getPacketBuffer(), + }, nil + }) + packer.EXPECT().PackPacket().AnyTimes() + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().SetHandshakeConfirmed() + cryptoSetup.EXPECT().GetSessionTicket() + mconn.EXPECT().Write(gomock.Any()) + close(conn.handshakeCompleteChan) + conn.run() + }() + Eventually(done).Should(BeClosed()) + // make sure the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) + expectReplaceWithClosed() + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + cryptoSetup.EXPECT().Close() + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("doesn't return a run error when closing", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + Expect(conn.run()).To(Succeed()) + close(done) + }() + streamManager.EXPECT().CloseWithError(gomock.Any()) + expectReplaceWithClosed() + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.shutdown() + Eventually(done).Should(BeClosed()) + }) + + It("passes errors to the connection runner", func() { + testErr := errors.New("handshake error") + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + err := conn.run() + Expect(err).To(MatchError(&qerr.ApplicationError{ + ErrorCode: 0x1337, + ErrorMessage: testErr.Error(), + })) + close(done) + }() + streamManager.EXPECT().CloseWithError(gomock.Any()) + expectReplaceWithClosed() + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + Expect(conn.CloseWithError(0x1337, testErr.Error())).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + Context("transport parameters", func() { + It("processes transport parameters received from the client", func() { + params := &wire.TransportParameters{ + MaxIdleTimeout: 90 * time.Second, + InitialMaxStreamDataBidiLocal: 0x5000, + InitialMaxData: 0x5000, + ActiveConnectionIDLimit: 3, + // marshaling always sets it to this value + MaxUDPPayloadSize: protocol.MaxPacketBufferSize, + InitialSourceConnectionID: destConnID, + } + streamManager.EXPECT().UpdateLimits(params) + packer.EXPECT().HandleTransportParameters(params) + packer.EXPECT().PackCoalescedPacket().MaxTimes(3) + Expect(conn.earlyConnReady()).ToNot(BeClosed()) + connRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2) + connRunner.EXPECT().Add(gomock.Any(), conn).Times(2) + tracer.EXPECT().ReceivedTransportParameters(params) + conn.handleTransportParameters(params) + Expect(conn.earlyConnReady()).To(BeClosed()) + }) + }) + + Context("keep-alives", func() { + setRemoteIdleTimeout := func(t time.Duration) { + streamManager.EXPECT().UpdateLimits(gomock.Any()) + packer.EXPECT().HandleTransportParameters(gomock.Any()) + tracer.EXPECT().ReceivedTransportParameters(gomock.Any()) + conn.handleTransportParameters(&wire.TransportParameters{ + MaxIdleTimeout: t, + InitialSourceConnectionID: destConnID, + }) + } + + runConn := func() { + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + } + + BeforeEach(func() { + conn.config.MaxIdleTimeout = 30 * time.Second + conn.config.KeepAlivePeriod = 15 * time.Second + conn.receivedPacketHandler.ReceivedPacket(0, protocol.ECNNon, protocol.EncryptionHandshake, time.Now(), true) + }) + + AfterEach(func() { + // make the go routine return + expectReplaceWithClosed() + streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("sends a PING as a keep-alive after half the idle timeout", func() { + setRemoteIdleTimeout(5 * time.Second) + conn.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2) + sent := make(chan struct{}) + packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { + close(sent) + return nil, nil + }) + runConn() + Eventually(sent).Should(BeClosed()) + }) + + It("sends a PING after a maximum of protocol.MaxKeepAliveInterval", func() { + conn.config.MaxIdleTimeout = time.Hour + setRemoteIdleTimeout(time.Hour) + conn.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond) + sent := make(chan struct{}) + packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { + close(sent) + return nil, nil + }) + runConn() + Eventually(sent).Should(BeClosed()) + }) + + It("doesn't send a PING packet if keep-alive is disabled", func() { + setRemoteIdleTimeout(5 * time.Second) + conn.config.KeepAlivePeriod = 0 + conn.lastPacketReceivedTime = time.Now().Add(-time.Second * 5 / 2) + runConn() + // don't EXPECT() any calls to mconn.Write() + time.Sleep(50 * time.Millisecond) + }) + + It("doesn't send a PING if the handshake isn't completed yet", func() { + conn.config.HandshakeIdleTimeout = time.Hour + conn.handshakeComplete = false + // Needs to be shorter than our idle timeout. + // Otherwise we'll try to send a CONNECTION_CLOSE. + conn.lastPacketReceivedTime = time.Now().Add(-20 * time.Second) + runConn() + // don't EXPECT() any calls to mconn.Write() + time.Sleep(50 * time.Millisecond) + }) + }) + + Context("timeouts", func() { + BeforeEach(func() { + streamManager.EXPECT().CloseWithError(gomock.Any()) + }) + + It("times out due to no network activity", func() { + connRunner.EXPECT().Remove(gomock.Any()).Times(2) + conn.lastPacketReceivedTime = time.Now().Add(-time.Hour) + done := make(chan struct{}) + cryptoSetup.EXPECT().Close() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { + Expect(e).To(MatchError(&qerr.IdleTimeoutError{})) + }), + tracer.EXPECT().Close(), + ) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + err := conn.run() + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(err).To(MatchError(qerr.ErrIdleTimeout)) + close(done) + }() + Eventually(done).Should(BeClosed()) + }) + + It("times out due to non-completed handshake", func() { + conn.handshakeComplete = false + conn.creationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) + connRunner.EXPECT().Remove(gomock.Any()).Times(2) + cryptoSetup.EXPECT().Close() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { + Expect(e).To(MatchError(&HandshakeTimeoutError{})) + }), + tracer.EXPECT().Close(), + ) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + err := conn.run() + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(err).To(MatchError(qerr.ErrHandshakeTimeout)) + close(done) + }() + Eventually(done).Should(BeClosed()) + }) + + It("does not use the idle timeout before the handshake complete", func() { + conn.handshakeComplete = false + conn.config.HandshakeIdleTimeout = 9999 * time.Second + conn.config.MaxIdleTimeout = 9999 * time.Second + conn.lastPacketReceivedTime = time.Now().Add(-time.Minute) + packer.EXPECT().PackApplicationClose(gomock.Any()).DoAndReturn(func(e *qerr.ApplicationError) (*coalescedPacket, error) { + Expect(e.ErrorCode).To(BeZero()) + return &coalescedPacket{buffer: getPacketBuffer()}, nil + }) + gomock.InOrder( + tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { + idleTimeout := &IdleTimeoutError{} + handshakeTimeout := &HandshakeTimeoutError{} + Expect(errors.As(e, &idleTimeout)).To(BeFalse()) + Expect(errors.As(e, &handshakeTimeout)).To(BeFalse()) + }), + tracer.EXPECT().Close(), + ) + // the handshake timeout is irrelevant here, since it depends on the time the connection was created, + // and not on the last network activity + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + Consistently(conn.Context().Done()).ShouldNot(BeClosed()) + // make the go routine return + expectReplaceWithClosed() + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("closes the connection due to the idle timeout before handshake", func() { + conn.config.HandshakeIdleTimeout = 0 + packer.EXPECT().PackCoalescedPacket().AnyTimes() + connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + cryptoSetup.EXPECT().Close() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { + Expect(e).To(MatchError(&IdleTimeoutError{})) + }), + tracer.EXPECT().Close(), + ) + done := make(chan struct{}) + conn.handshakeComplete = false + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) + err := conn.run() + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(err).To(MatchError(qerr.ErrIdleTimeout)) + close(done) + }() + Eventually(done).Should(BeClosed()) + }) + + It("closes the connection due to the idle timeout after handshake", func() { + packer.EXPECT().PackCoalescedPacket().AnyTimes() + gomock.InOrder( + connRunner.EXPECT().Retire(clientDestConnID), + connRunner.EXPECT().Remove(gomock.Any()), + ) + cryptoSetup.EXPECT().Close() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { + Expect(e).To(MatchError(&IdleTimeoutError{})) + }), + tracer.EXPECT().Close(), + ) + conn.idleTimeout = 0 + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) + cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) + close(conn.handshakeCompleteChan) + err := conn.run() + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(err).To(MatchError(qerr.ErrIdleTimeout)) + close(done) + }() + Eventually(done).Should(BeClosed()) + }) + + It("doesn't time out when it just sent a packet", func() { + conn.lastPacketReceivedTime = time.Now().Add(-time.Hour) + conn.firstAckElicitingPacketAfterIdleSentTime = time.Now().Add(-time.Second) + conn.idleTimeout = 30 * time.Second + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + Consistently(conn.Context().Done()).ShouldNot(BeClosed()) + // make the go routine return + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + expectReplaceWithClosed() + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + }) + + It("stores up to MaxConnUnprocessedPackets packets", func() { + done := make(chan struct{}) + tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, logging.ByteCount(6), logging.PacketDropDOSPrevention).Do(func(logging.PacketType, logging.ByteCount, logging.PacketDropReason) { + close(done) + }) + // Nothing here should block + for i := protocol.PacketNumber(0); i < protocol.MaxConnUnprocessedPackets+1; i++ { + conn.handlePacket(&receivedPacket{data: []byte("foobar")}) + } + Eventually(done).Should(BeClosed()) + }) + + Context("getting streams", func() { + It("opens streams", func() { + mstr := NewMockStreamI(mockCtrl) + streamManager.EXPECT().OpenStream().Return(mstr, nil) + str, err := conn.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(mstr)) + }) + + It("opens streams synchronously", func() { + mstr := NewMockStreamI(mockCtrl) + streamManager.EXPECT().OpenStreamSync(context.Background()).Return(mstr, nil) + str, err := conn.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(mstr)) + }) + + It("opens unidirectional streams", func() { + mstr := NewMockSendStreamI(mockCtrl) + streamManager.EXPECT().OpenUniStream().Return(mstr, nil) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(mstr)) + }) + + It("opens unidirectional streams synchronously", func() { + mstr := NewMockSendStreamI(mockCtrl) + streamManager.EXPECT().OpenUniStreamSync(context.Background()).Return(mstr, nil) + str, err := conn.OpenUniStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(mstr)) + }) + + It("accepts streams", func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + mstr := NewMockStreamI(mockCtrl) + streamManager.EXPECT().AcceptStream(ctx).Return(mstr, nil) + str, err := conn.AcceptStream(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(mstr)) + }) + + It("accepts unidirectional streams", func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + mstr := NewMockReceiveStreamI(mockCtrl) + streamManager.EXPECT().AcceptUniStream(ctx).Return(mstr, nil) + str, err := conn.AcceptUniStream(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(mstr)) + }) + }) + + It("returns the local address", func() { + Expect(conn.LocalAddr()).To(Equal(localAddr)) + }) + + It("returns the remote address", func() { + Expect(conn.RemoteAddr()).To(Equal(remoteAddr)) + }) +}) + +var _ = Describe("Client Connection", func() { + var ( + conn *connection + connRunner *MockConnRunner + packer *MockPacker + mconn *MockSendConn + cryptoSetup *mocks.MockCryptoSetup + tracer *mocklogging.MockConnectionTracer + tlsConf *tls.Config + quicConf *Config + ) + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + + getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { + buf := &bytes.Buffer{} + Expect(hdr.Write(buf, conn.version)).To(Succeed()) + return &receivedPacket{ + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + } + } + + expectReplaceWithClosed := func() { + connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + s.shutdown() + Eventually(areClosedConnsRunning).Should(BeFalse()) + }) + } + + BeforeEach(func() { + quicConf = populateClientConfig(&Config{}, true) + tlsConf = nil + }) + + JustBeforeEach(func() { + Eventually(areConnsRunning).Should(BeFalse()) + + mconn = NewMockSendConn(mockCtrl) + mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes() + mconn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() + if tlsConf == nil { + tlsConf = &tls.Config{} + } + connRunner = NewMockConnRunner(mockCtrl) + tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) + tracer.EXPECT().SentTransportParameters(gomock.Any()) + tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() + tracer.EXPECT().UpdatedCongestionState(gomock.Any()) + conn = newClientConnection( + mconn, + connRunner, + destConnID, + protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + quicConf, + tlsConf, + 42, // initial packet number + false, + false, + tracer, + 1234, + utils.DefaultLogger, + protocol.VersionTLS, + ).(*connection) + packer = NewMockPacker(mockCtrl) + conn.packer = packer + cryptoSetup = mocks.NewMockCryptoSetup(mockCtrl) + conn.cryptoStreamHandler = cryptoSetup + conn.sentFirstPacket = true + }) + + It("changes the connection ID when receiving the first packet from the server", func() { + unpacker := NewMockUnpacker(mockCtrl) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { + return &unpackedPacket{ + encryptionLevel: protocol.Encryption1RTT, + hdr: &wire.ExtendedHeader{Header: *hdr}, + data: []byte{0}, // one PADDING frame + }, nil + }) + conn.unpacker = unpacker + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} + p := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + SrcConnectionID: newConnID, + DestConnectionID: srcConnID, + Length: 2 + 6, + Version: conn.version, + }, + PacketNumberLen: protocol.PacketNumberLen2, + }, []byte("foobar")) + tracer.EXPECT().ReceivedPacket(gomock.Any(), p.Size(), []logging.Frame{}) + Expect(conn.handlePacketImpl(p)).To(BeTrue()) + // make sure the go routine returns + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + expectReplaceWithClosed() + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + }) + + It("continues accepting Long Header packets after using a new connection ID", func() { + unpacker := NewMockUnpacker(mockCtrl) + conn.unpacker = unpacker + connRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any()) + conn.connIDManager.SetHandshakeComplete() + conn.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, + }) + Expect(conn.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) + // now receive a packet with the original source connection ID + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) { + return &unpackedPacket{ + hdr: &wire.ExtendedHeader{Header: *hdr}, + data: []byte{0}, + encryptionLevel: protocol.EncryptionHandshake, + }, nil + }) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: srcConnID, + SrcConnectionID: destConnID, + } + tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) + Expect(conn.handleSinglePacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) + }) + + It("handles HANDSHAKE_DONE frames", func() { + conn.peerParams = &wire.TransportParameters{} + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + conn.sentPacketHandler = sph + sph.EXPECT().SetHandshakeConfirmed() + cryptoSetup.EXPECT().SetHandshakeConfirmed() + Expect(conn.handleHandshakeDoneFrame()).To(Succeed()) + }) + + It("interprets an ACK for 1-RTT packets as confirmation of the handshake", func() { + conn.peerParams = &wire.TransportParameters{} + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + conn.sentPacketHandler = sph + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 3}}} + sph.EXPECT().ReceivedAck(ack, protocol.Encryption1RTT, gomock.Any()).Return(true, nil) + sph.EXPECT().SetHandshakeConfirmed() + cryptoSetup.EXPECT().SetLargest1RTTAcked(protocol.PacketNumber(3)) + cryptoSetup.EXPECT().SetHandshakeConfirmed() + Expect(conn.handleAckFrame(ack, protocol.Encryption1RTT)).To(Succeed()) + }) + + It("doesn't send a CONNECTION_CLOSE when no packet was sent", func() { + conn.sentFirstPacket = false + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + running := make(chan struct{}) + cryptoSetup.EXPECT().RunHandshake().Do(func() { + close(running) + conn.closeLocal(errors.New("early error")) + }) + cryptoSetup.EXPECT().Close() + connRunner.EXPECT().Remove(gomock.Any()) + go func() { + defer GinkgoRecover() + conn.run() + }() + Eventually(running).Should(BeClosed()) + Eventually(areConnsRunning).Should(BeFalse()) + }) + + Context("handling tokens", func() { + var mockTokenStore *MockTokenStore + + BeforeEach(func() { + mockTokenStore = NewMockTokenStore(mockCtrl) + tlsConf = &tls.Config{ServerName: "server"} + quicConf.TokenStore = mockTokenStore + mockTokenStore.EXPECT().Pop(gomock.Any()) + quicConf.TokenStore = mockTokenStore + }) + + It("handles NEW_TOKEN frames", func() { + mockTokenStore.EXPECT().Put("server", &ClientToken{data: []byte("foobar")}) + Expect(conn.handleNewTokenFrame(&wire.NewTokenFrame{Token: []byte("foobar")})).To(Succeed()) + }) + }) + + Context("handling Version Negotiation", func() { + getVNP := func(versions ...protocol.VersionNumber) *receivedPacket { + b := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions) + return &receivedPacket{ + data: b, + buffer: getPacketBuffer(), + } + } + + It("closes and returns the right error", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + conn.sentPacketHandler = sph + sph.EXPECT().ReceivedBytes(gomock.Any()) + sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4) + conn.config.Versions = []protocol.VersionNumber{1234, 4321} + errChan := make(chan error, 1) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + errChan <- conn.run() + }() + connRunner.EXPECT().Remove(srcConnID) + tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any()).Do(func(hdr *wire.Header, versions []logging.VersionNumber) { + Expect(hdr.Version).To(BeZero()) + Expect(versions).To(And( + ContainElement(protocol.VersionNumber(4321)), + ContainElement(protocol.VersionNumber(1337)), + )) + }) + cryptoSetup.EXPECT().Close() + Expect(conn.handlePacketImpl(getVNP(4321, 1337))).To(BeFalse()) + var err error + Eventually(errChan).Should(Receive(&err)) + Expect(err).To(HaveOccurred()) + Expect(err).To(BeAssignableToTypeOf(&errCloseForRecreating{})) + recreateErr := err.(*errCloseForRecreating) + Expect(recreateErr.nextVersion).To(Equal(protocol.VersionNumber(4321))) + Expect(recreateErr.nextPacketNumber).To(Equal(protocol.PacketNumber(128))) + }) + + It("it closes when no matching version is found", func() { + errChan := make(chan error, 1) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + errChan <- conn.run() + }() + connRunner.EXPECT().Remove(srcConnID).MaxTimes(1) + gomock.InOrder( + tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any()), + tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { + var vnErr *VersionNegotiationError + Expect(errors.As(e, &vnErr)).To(BeTrue()) + Expect(vnErr.Theirs).To(ContainElement(logging.VersionNumber(12345678))) + }), + tracer.EXPECT().Close(), + ) + cryptoSetup.EXPECT().Close() + Expect(conn.handlePacketImpl(getVNP(12345678))).To(BeFalse()) + var err error + Eventually(errChan).Should(Receive(&err)) + Expect(err).To(HaveOccurred()) + Expect(err).ToNot(BeAssignableToTypeOf(errCloseForRecreating{})) + Expect(err.Error()).To(ContainSubstring("no compatible QUIC version found")) + }) + + It("ignores Version Negotiation packets that offer the current version", func() { + p := getVNP(conn.version) + tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) + }) + + It("ignores unparseable Version Negotiation packets", func() { + p := getVNP(conn.version) + p.data = p.data[:len(p.data)-2] + tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) + }) + }) + + Context("handling Retry", func() { + origDestConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + + var retryHdr *wire.ExtendedHeader + + JustBeforeEach(func() { + retryHdr = &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Token: []byte("foobar"), + Version: conn.version, + }, + } + }) + + getRetryTag := func(hdr *wire.ExtendedHeader) []byte { + buf := &bytes.Buffer{} + hdr.Write(buf, conn.version) + return handshake.GetRetryIntegrityTag(buf.Bytes(), origDestConnID, hdr.Version)[:] + } + + It("handles Retry packets", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + conn.sentPacketHandler = sph + sph.EXPECT().ResetForRetry() + sph.EXPECT().ReceivedBytes(gomock.Any()) + cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) + packer.EXPECT().SetToken([]byte("foobar")) + tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { + Expect(hdr.DestConnectionID).To(Equal(retryHdr.DestConnectionID)) + Expect(hdr.SrcConnectionID).To(Equal(retryHdr.SrcConnectionID)) + Expect(hdr.Token).To(Equal(retryHdr.Token)) + }) + Expect(conn.handlePacketImpl(getPacket(retryHdr, getRetryTag(retryHdr)))).To(BeTrue()) + }) + + It("ignores Retry packets after receiving a regular packet", func() { + conn.receivedFirstPacket = true + p := getPacket(retryHdr, getRetryTag(retryHdr)) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) + }) + + It("ignores Retry packets if the server didn't change the connection ID", func() { + retryHdr.SrcConnectionID = destConnID + p := getPacket(retryHdr, getRetryTag(retryHdr)) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) + }) + + It("ignores Retry packets with the a wrong Integrity tag", func() { + tag := getRetryTag(retryHdr) + tag[0]++ + p := getPacket(retryHdr, tag) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropPayloadDecryptError) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) + }) + }) + + Context("transport parameters", func() { + var ( + closed bool + errChan chan error + ) + + JustBeforeEach(func() { + errChan = make(chan error, 1) + closed = false + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + errChan <- conn.run() + close(errChan) + }() + }) + + expectClose := func(applicationClose bool) { + if !closed { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{})) + s.shutdown() + }) + if applicationClose { + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) + } else { + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) + } + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + gomock.InOrder( + tracer.EXPECT().ClosedConnection(gomock.Any()), + tracer.EXPECT().Close(), + ) + } + closed = true + } + + AfterEach(func() { + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + Eventually(errChan).Should(BeClosed()) + }) + + It("uses the preferred_address connection ID", func() { + params := &wire.TransportParameters{ + OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: destConnID, + PreferredAddress: &wire.PreferredAddress{ + IPv4: net.IPv4(127, 0, 0, 1), + IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + }, + } + packer.EXPECT().HandleTransportParameters(gomock.Any()) + packer.EXPECT().PackCoalescedPacket().MaxTimes(1) + tracer.EXPECT().ReceivedTransportParameters(params) + conn.handleTransportParameters(params) + conn.handleHandshakeComplete() + // make sure the connection ID is not retired + cf, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) + Expect(cf).To(BeEmpty()) + connRunner.EXPECT().AddResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, conn) + Expect(conn.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + // shut down + connRunner.EXPECT().RemoveResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}) + expectClose(true) + }) + + It("uses the minimum of the peers' idle timeouts", func() { + conn.config.MaxIdleTimeout = 19 * time.Second + params := &wire.TransportParameters{ + OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: destConnID, + MaxIdleTimeout: 18 * time.Second, + } + packer.EXPECT().HandleTransportParameters(gomock.Any()) + tracer.EXPECT().ReceivedTransportParameters(params) + conn.handleTransportParameters(params) + conn.handleHandshakeComplete() + Expect(conn.idleTimeout).To(Equal(18 * time.Second)) + expectClose(true) + }) + + It("errors if the transport parameters contain a wrong initial_source_connection_id", func() { + conn.handshakeDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + params := &wire.TransportParameters{ + OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + } + expectClose(false) + tracer.EXPECT().ReceivedTransportParameters(params) + conn.handleTransportParameters(params) + Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "expected initial_source_connection_id to equal deadbeef, is decafbad", + }))) + }) + + It("errors if the transport parameters don't contain the retry_source_connection_id, if a Retry was performed", func() { + conn.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + params := &wire.TransportParameters{ + OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: destConnID, + StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + } + expectClose(false) + tracer.EXPECT().ReceivedTransportParameters(params) + conn.handleTransportParameters(params) + Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "missing retry_source_connection_id", + }))) + }) + + It("errors if the transport parameters contain the wrong retry_source_connection_id, if a Retry was performed", func() { + conn.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + params := &wire.TransportParameters{ + OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: destConnID, + RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + } + expectClose(false) + tracer.EXPECT().ReceivedTransportParameters(params) + conn.handleTransportParameters(params) + Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "expected retry_source_connection_id to equal deadbeef, is deadc0de", + }))) + }) + + It("errors if the transport parameters contain the retry_source_connection_id, if no Retry was performed", func() { + params := &wire.TransportParameters{ + OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: destConnID, + RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + } + expectClose(false) + tracer.EXPECT().ReceivedTransportParameters(params) + conn.handleTransportParameters(params) + Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "received retry_source_connection_id, although no Retry was performed", + }))) + }) + + It("errors if the transport parameters contain a wrong original_destination_connection_id", func() { + conn.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + params := &wire.TransportParameters{ + OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + InitialSourceConnectionID: conn.handshakeDestConnID, + StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + } + expectClose(false) + tracer.EXPECT().ReceivedTransportParameters(params) + conn.handleTransportParameters(params) + Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "expected original_destination_connection_id to equal deadbeef, is decafbad", + }))) + }) + }) + + Context("handling potentially injected packets", func() { + var unpacker *MockUnpacker + + getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { + buf := &bytes.Buffer{} + Expect(extHdr.Write(buf, conn.version)).To(Succeed()) + return &receivedPacket{ + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + } + } + + // Convert an already packed raw packet into a receivedPacket + wrapPacket := func(packet []byte) *receivedPacket { + return &receivedPacket{ + data: packet, + buffer: getPacketBuffer(), + } + } + + // Illustrates that attacker may inject an Initial packet with a different + // source connection ID, causing endpoint to ignore a subsequent real Initial packets. + It("ignores Initial packets with a different source connection ID", func() { + // Modified from test "ignores packets with a different source connection ID" + unpacker = NewMockUnpacker(mockCtrl) + conn.unpacker = unpacker + + hdr1 := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: destConnID, + SrcConnectionID: srcConnID, + Length: 1, + Version: conn.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 1, + } + hdr2 := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: destConnID, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + Length: 1, + Version: conn.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 2, + } + Expect(hdr2.SrcConnectionID).ToNot(Equal(srcConnID)) + // Send one packet, which might change the connection ID. + // only EXPECT one call to the unpacker + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ + encryptionLevel: protocol.EncryptionInitial, + hdr: hdr1, + data: []byte{0}, // one PADDING frame + }, nil) + tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) + Expect(conn.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue()) + // The next packet has to be ignored, since the source connection ID doesn't match. + tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) + Expect(conn.handlePacketImpl(getPacket(hdr2, nil))).To(BeFalse()) + }) + + It("ignores 0-RTT packets", func() { + p := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketType0RTT, + DestConnectionID: srcConnID, + Length: 2 + 6, + Version: conn.version, + }, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, + }, []byte("foobar")) + tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, p.Size(), gomock.Any()) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) + }) + + // Illustrates that an injected Initial with an ACK frame for an unsent packet causes + // the connection to immediately break down + It("fails on Initial-level ACK for unsent packet", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} + initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{ack}) + tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) + Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) + }) + + // Illustrates that an injected Initial with a CONNECTION_CLOSE frame causes + // the connection to immediately break down + It("fails on Initial-level CONNECTION_CLOSE frame", func() { + connCloseFrame := &wire.ConnectionCloseFrame{ + IsApplicationError: true, + ReasonPhrase: "mitm attacker", + } + initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{connCloseFrame}) + tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) + Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue()) + }) + + // Illustrates that attacker who injects a Retry packet and changes the connection ID + // can cause subsequent real Initial packets to be ignored + It("ignores Initial packets which use original source id, after accepting a Retry", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + conn.sentPacketHandler = sph + sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2) + sph.EXPECT().ResetForRetry() + newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID) + packer.EXPECT().SetToken([]byte("foobar")) + + tracer.EXPECT().ReceivedRetry(gomock.Any()) + conn.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, destConnID, destConnID, []byte("foobar"), conn.version))) + initialPacket := testutils.ComposeInitialPacket(conn.connIDManager.Get(), srcConnID, conn.version, conn.connIDManager.Get(), nil) + tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) + Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) + }) + }) +}) diff --git a/internal/quic-go/crypto_stream.go b/internal/quic-go/crypto_stream.go new file mode 100644 index 00000000..0763b165 --- /dev/null +++ b/internal/quic-go/crypto_stream.go @@ -0,0 +1,115 @@ +package quic + +import ( + "fmt" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type cryptoStream interface { + // for receiving data + HandleCryptoFrame(*wire.CryptoFrame) error + GetCryptoData() []byte + Finish() error + // for sending data + io.Writer + HasData() bool + PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame +} + +type cryptoStreamImpl struct { + queue *frameSorter + msgBuf []byte + + highestOffset protocol.ByteCount + finished bool + + writeOffset protocol.ByteCount + writeBuf []byte +} + +func newCryptoStream() cryptoStream { + return &cryptoStreamImpl{queue: newFrameSorter()} +} + +func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { + highestOffset := f.Offset + protocol.ByteCount(len(f.Data)) + if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset { + return &qerr.TransportError{ + ErrorCode: qerr.CryptoBufferExceeded, + ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset), + } + } + if s.finished { + if highestOffset > s.highestOffset { + // reject crypto data received after this stream was already finished + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received crypto data after change of encryption level", + } + } + // ignore data with a smaller offset than the highest received + // could e.g. be a retransmission + return nil + } + s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset) + if err := s.queue.Push(f.Data, f.Offset, nil); err != nil { + return err + } + for { + _, data, _ := s.queue.Pop() + if data == nil { + return nil + } + s.msgBuf = append(s.msgBuf, data...) + } +} + +// GetCryptoData retrieves data that was received in CRYPTO frames +func (s *cryptoStreamImpl) GetCryptoData() []byte { + if len(s.msgBuf) < 4 { + return nil + } + msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) + if len(s.msgBuf) < msgLen { + return nil + } + msg := make([]byte, msgLen) + copy(msg, s.msgBuf[:msgLen]) + s.msgBuf = s.msgBuf[msgLen:] + return msg +} + +func (s *cryptoStreamImpl) Finish() error { + if s.queue.HasMoreData() { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "encryption level changed, but crypto stream has more data to read", + } + } + s.finished = true + return nil +} + +// Writes writes data that should be sent out in CRYPTO frames +func (s *cryptoStreamImpl) Write(p []byte) (int, error) { + s.writeBuf = append(s.writeBuf, p...) + return len(p), nil +} + +func (s *cryptoStreamImpl) HasData() bool { + return len(s.writeBuf) > 0 +} + +func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame { + f := &wire.CryptoFrame{Offset: s.writeOffset} + n := utils.MinByteCount(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) + f.Data = s.writeBuf[:n] + s.writeBuf = s.writeBuf[n:] + s.writeOffset += n + return f +} diff --git a/internal/quic-go/crypto_stream_manager.go b/internal/quic-go/crypto_stream_manager.go new file mode 100644 index 00000000..83a70ae5 --- /dev/null +++ b/internal/quic-go/crypto_stream_manager.go @@ -0,0 +1,61 @@ +package quic + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type cryptoDataHandler interface { + HandleMessage([]byte, protocol.EncryptionLevel) bool +} + +type cryptoStreamManager struct { + cryptoHandler cryptoDataHandler + + initialStream cryptoStream + handshakeStream cryptoStream + oneRTTStream cryptoStream +} + +func newCryptoStreamManager( + cryptoHandler cryptoDataHandler, + initialStream cryptoStream, + handshakeStream cryptoStream, + oneRTTStream cryptoStream, +) *cryptoStreamManager { + return &cryptoStreamManager{ + cryptoHandler: cryptoHandler, + initialStream: initialStream, + handshakeStream: handshakeStream, + oneRTTStream: oneRTTStream, + } +} + +func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) { + var str cryptoStream + //nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets. + switch encLevel { + case protocol.EncryptionInitial: + str = m.initialStream + case protocol.EncryptionHandshake: + str = m.handshakeStream + case protocol.Encryption1RTT: + str = m.oneRTTStream + default: + return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) + } + if err := str.HandleCryptoFrame(frame); err != nil { + return false, err + } + for { + data := str.GetCryptoData() + if data == nil { + return false, nil + } + if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished { + return true, str.Finish() + } + } +} diff --git a/internal/quic-go/crypto_stream_manager_test.go b/internal/quic-go/crypto_stream_manager_test.go new file mode 100644 index 00000000..d5d7ed85 --- /dev/null +++ b/internal/quic-go/crypto_stream_manager_test.go @@ -0,0 +1,119 @@ +package quic + +import ( + "errors" + + "github.com/golang/mock/gomock" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Crypto Stream Manager", func() { + var ( + csm *cryptoStreamManager + cs *MockCryptoDataHandler + + initialStream *MockCryptoStream + handshakeStream *MockCryptoStream + oneRTTStream *MockCryptoStream + ) + + BeforeEach(func() { + initialStream = NewMockCryptoStream(mockCtrl) + handshakeStream = NewMockCryptoStream(mockCtrl) + oneRTTStream = NewMockCryptoStream(mockCtrl) + cs = NewMockCryptoDataHandler(mockCtrl) + csm = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream) + }) + + It("passes messages to the initial stream", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + initialStream.EXPECT().HandleCryptoFrame(cf) + initialStream.EXPECT().GetCryptoData().Return([]byte("foobar")) + initialStream.EXPECT().GetCryptoData() + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) + }) + + It("passes messages to the handshake stream", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + handshakeStream.EXPECT().HandleCryptoFrame(cf) + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")) + handshakeStream.EXPECT().GetCryptoData() + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) + }) + + It("passes messages to the 1-RTT stream", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + oneRTTStream.EXPECT().HandleCryptoFrame(cf) + oneRTTStream.EXPECT().GetCryptoData().Return([]byte("foobar")) + oneRTTStream.EXPECT().GetCryptoData() + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.Encryption1RTT) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) + }) + + It("doesn't call the message handler, if there's no message", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + handshakeStream.EXPECT().HandleCryptoFrame(cf) + handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle + // don't EXPECT any calls to HandleMessage() + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) + }) + + It("processes all messages", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + handshakeStream.EXPECT().HandleCryptoFrame(cf) + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foo")) + handshakeStream.EXPECT().GetCryptoData().Return([]byte("bar")) + handshakeStream.EXPECT().GetCryptoData() + cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake) + cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) + }) + + It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + gomock.InOrder( + handshakeStream.EXPECT().HandleCryptoFrame(cf), + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), + handshakeStream.EXPECT().Finish(), + ) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeTrue()) + }) + + It("returns errors that occur when finishing a stream", func() { + testErr := errors.New("test error") + cf := &wire.CryptoFrame{Data: []byte("foobar")} + gomock.InOrder( + handshakeStream.EXPECT().HandleCryptoFrame(cf), + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), + handshakeStream.EXPECT().Finish().Return(testErr), + ) + _, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).To(MatchError(err)) + }) + + It("errors for unknown encryption levels", func() { + _, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, 42) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("received CRYPTO frame with unexpected encryption level")) + }) +}) diff --git a/internal/quic-go/crypto_stream_test.go b/internal/quic-go/crypto_stream_test.go new file mode 100644 index 00000000..7c8301b7 --- /dev/null +++ b/internal/quic-go/crypto_stream_test.go @@ -0,0 +1,187 @@ +package quic + +import ( + "crypto/rand" + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func createHandshakeMessage(len int) []byte { + msg := make([]byte, 4+len) + rand.Read(msg[:1]) // random message type + msg[1] = uint8(len >> 16) + msg[2] = uint8(len >> 8) + msg[3] = uint8(len) + rand.Read(msg[4:]) + return msg +} + +var _ = Describe("Crypto Stream", func() { + var str cryptoStream + + BeforeEach(func() { + str = newCryptoStream() + }) + + Context("handling incoming data", func() { + It("handles in-order CRYPTO frames", func() { + msg := createHandshakeMessage(6) + err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg}) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.GetCryptoData()).To(BeNil()) + }) + + It("handles multiple messages in one CRYPTO frame", func() { + msg1 := createHandshakeMessage(6) + msg2 := createHandshakeMessage(10) + msg := append(append([]byte{}, msg1...), msg2...) + err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg}) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(Equal(msg1)) + Expect(str.GetCryptoData()).To(Equal(msg2)) + Expect(str.GetCryptoData()).To(BeNil()) + }) + + It("errors if the frame exceeds the maximum offset", func() { + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: protocol.MaxCryptoStreamOffset - 5, + Data: []byte("foobar"), + })).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.CryptoBufferExceeded, + ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", protocol.MaxCryptoStreamOffset+1, protocol.MaxCryptoStreamOffset), + })) + }) + + It("handles messages split over multiple CRYPTO frames", func() { + msg := createHandshakeMessage(6) + err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: msg[:4], + }) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(BeNil()) + err = str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: 4, + Data: msg[4:], + }) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.GetCryptoData()).To(BeNil()) + }) + + It("handles out-of-order CRYPTO frames", func() { + msg := createHandshakeMessage(6) + err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: 4, + Data: msg[4:], + }) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(BeNil()) + err = str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: msg[:4], + }) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.GetCryptoData()).To(BeNil()) + }) + + Context("finishing", func() { + It("errors if there's still data to read after finishing", func() { + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: createHandshakeMessage(5), + Offset: 10, + })).To(Succeed()) + Expect(str.Finish()).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "encryption level changed, but crypto stream has more data to read", + })) + }) + + It("works with reordered data", func() { + f1 := &wire.CryptoFrame{ + Data: []byte("foo"), + } + f2 := &wire.CryptoFrame{ + Offset: 3, + Data: []byte("bar"), + } + Expect(str.HandleCryptoFrame(f2)).To(Succeed()) + Expect(str.HandleCryptoFrame(f1)).To(Succeed()) + Expect(str.Finish()).To(Succeed()) + Expect(str.HandleCryptoFrame(f2)).To(Succeed()) + }) + + It("rejects new crypto data after finishing", func() { + Expect(str.Finish()).To(Succeed()) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: createHandshakeMessage(5), + })).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received crypto data after change of encryption level", + })) + }) + + It("ignores crypto data below the maximum offset received before finishing", func() { + msg := createHandshakeMessage(15) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: msg, + })).To(Succeed()) + Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.Finish()).To(Succeed()) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: protocol.ByteCount(len(msg) - 6), + Data: []byte("foobar"), + })).To(Succeed()) + }) + }) + }) + + Context("writing data", func() { + It("says if it has data", func() { + Expect(str.HasData()).To(BeFalse()) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.HasData()).To(BeTrue()) + }) + + It("pops crypto frames", func() { + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + f := str.PopCryptoFrame(1000) + Expect(f).ToNot(BeNil()) + Expect(f.Offset).To(BeZero()) + Expect(f.Data).To(Equal([]byte("foobar"))) + }) + + It("coalesces multiple writes", func() { + _, err := str.Write([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("bar")) + Expect(err).ToNot(HaveOccurred()) + f := str.PopCryptoFrame(1000) + Expect(f).ToNot(BeNil()) + Expect(f.Offset).To(BeZero()) + Expect(f.Data).To(Equal([]byte("foobar"))) + }) + + It("respects the maximum size", func() { + frameHeaderLen := (&wire.CryptoFrame{}).Length(protocol.VersionWhatever) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + f := str.PopCryptoFrame(frameHeaderLen + 3) + Expect(f).ToNot(BeNil()) + Expect(f.Offset).To(BeZero()) + Expect(f.Data).To(Equal([]byte("foo"))) + f = str.PopCryptoFrame(frameHeaderLen + 3) + Expect(f).ToNot(BeNil()) + Expect(f.Offset).To(Equal(protocol.ByteCount(3))) + Expect(f.Data).To(Equal([]byte("bar"))) + }) + }) +}) diff --git a/internal/quic-go/datagram_queue.go b/internal/quic-go/datagram_queue.go new file mode 100644 index 00000000..561b7a8e --- /dev/null +++ b/internal/quic-go/datagram_queue.go @@ -0,0 +1,87 @@ +package quic + +import ( + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type datagramQueue struct { + sendQueue chan *wire.DatagramFrame + rcvQueue chan []byte + + closeErr error + closed chan struct{} + + hasData func() + + dequeued chan struct{} + + logger utils.Logger +} + +func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { + return &datagramQueue{ + hasData: hasData, + sendQueue: make(chan *wire.DatagramFrame, 1), + rcvQueue: make(chan []byte, protocol.DatagramRcvQueueLen), + dequeued: make(chan struct{}), + closed: make(chan struct{}), + logger: logger, + } +} + +// AddAndWait queues a new DATAGRAM frame for sending. +// It blocks until the frame has been dequeued. +func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { + select { + case h.sendQueue <- f: + h.hasData() + case <-h.closed: + return h.closeErr + } + + select { + case <-h.dequeued: + return nil + case <-h.closed: + return h.closeErr + } +} + +// Get dequeues a DATAGRAM frame for sending. +func (h *datagramQueue) Get() *wire.DatagramFrame { + select { + case f := <-h.sendQueue: + h.dequeued <- struct{}{} + return f + default: + return nil + } +} + +// HandleDatagramFrame handles a received DATAGRAM frame. +func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { + data := make([]byte, len(f.Data)) + copy(data, f.Data) + select { + case h.rcvQueue <- data: + default: + h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data)) + } +} + +// Receive gets a received DATAGRAM frame. +func (h *datagramQueue) Receive() ([]byte, error) { + select { + case data := <-h.rcvQueue: + return data, nil + case <-h.closed: + return nil, h.closeErr + } +} + +func (h *datagramQueue) CloseWithError(e error) { + h.closeErr = e + close(h.closed) +} diff --git a/internal/quic-go/datagram_queue_test.go b/internal/quic-go/datagram_queue_test.go new file mode 100644 index 00000000..29351ce6 --- /dev/null +++ b/internal/quic-go/datagram_queue_test.go @@ -0,0 +1,98 @@ +package quic + +import ( + "errors" + + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Datagram Queue", func() { + var queue *datagramQueue + var queued chan struct{} + + BeforeEach(func() { + queued = make(chan struct{}, 100) + queue = newDatagramQueue(func() { + queued <- struct{}{} + }, utils.DefaultLogger) + }) + + Context("sending", func() { + It("returns nil when there's no datagram to send", func() { + Expect(queue.Get()).To(BeNil()) + }) + + It("queues a datagram", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + Expect(queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")})).To(Succeed()) + }() + + Eventually(queued).Should(HaveLen(1)) + Consistently(done).ShouldNot(BeClosed()) + f := queue.Get() + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(Equal([]byte("foobar"))) + Eventually(done).Should(BeClosed()) + Expect(queue.Get()).To(BeNil()) + }) + + It("closes", func() { + errChan := make(chan error, 1) + go func() { + defer GinkgoRecover() + errChan <- queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")}) + }() + + Consistently(errChan).ShouldNot(Receive()) + queue.CloseWithError(errors.New("test error")) + Eventually(errChan).Should(Receive(MatchError("test error"))) + }) + }) + + Context("receiving", func() { + It("receives DATAGRAM frames", func() { + queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foo")}) + queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("bar")}) + data, err := queue.Receive() + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foo"))) + data, err = queue.Receive() + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("bar"))) + }) + + It("blocks until a frame is received", func() { + c := make(chan []byte, 1) + go func() { + defer GinkgoRecover() + data, err := queue.Receive() + Expect(err).ToNot(HaveOccurred()) + c <- data + }() + + Consistently(c).ShouldNot(Receive()) + queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foobar")}) + Eventually(c).Should(Receive(Equal([]byte("foobar")))) + }) + + It("closes", func() { + errChan := make(chan error, 1) + go func() { + defer GinkgoRecover() + _, err := queue.Receive() + errChan <- err + }() + + Consistently(errChan).ShouldNot(Receive()) + queue.CloseWithError(errors.New("test error")) + Eventually(errChan).Should(Receive(MatchError("test error"))) + }) + }) +}) diff --git a/internal/quic-go/errors.go b/internal/quic-go/errors.go new file mode 100644 index 00000000..5f9050ac --- /dev/null +++ b/internal/quic-go/errors.go @@ -0,0 +1,58 @@ +package quic + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/qerr" +) + +type ( + TransportError = qerr.TransportError + ApplicationError = qerr.ApplicationError + VersionNegotiationError = qerr.VersionNegotiationError + StatelessResetError = qerr.StatelessResetError + IdleTimeoutError = qerr.IdleTimeoutError + HandshakeTimeoutError = qerr.HandshakeTimeoutError +) + +type ( + TransportErrorCode = qerr.TransportErrorCode + ApplicationErrorCode = qerr.ApplicationErrorCode + StreamErrorCode = qerr.StreamErrorCode +) + +const ( + NoError = qerr.NoError + InternalError = qerr.InternalError + ConnectionRefused = qerr.ConnectionRefused + FlowControlError = qerr.FlowControlError + StreamLimitError = qerr.StreamLimitError + StreamStateError = qerr.StreamStateError + FinalSizeError = qerr.FinalSizeError + FrameEncodingError = qerr.FrameEncodingError + TransportParameterError = qerr.TransportParameterError + ConnectionIDLimitError = qerr.ConnectionIDLimitError + ProtocolViolation = qerr.ProtocolViolation + InvalidToken = qerr.InvalidToken + ApplicationErrorErrorCode = qerr.ApplicationErrorErrorCode + CryptoBufferExceeded = qerr.CryptoBufferExceeded + KeyUpdateError = qerr.KeyUpdateError + AEADLimitReached = qerr.AEADLimitReached + NoViablePathError = qerr.NoViablePathError +) + +// A StreamError is used for Stream.CancelRead and Stream.CancelWrite. +// It is also returned from Stream.Read and Stream.Write if the peer canceled reading or writing. +type StreamError struct { + StreamID StreamID + ErrorCode StreamErrorCode +} + +func (e *StreamError) Is(target error) bool { + _, ok := target.(*StreamError) + return ok +} + +func (e *StreamError) Error() string { + return fmt.Sprintf("stream %d canceled with error code %d", e.StreamID, e.ErrorCode) +} diff --git a/internal/quic-go/flowcontrol/base_flow_controller.go b/internal/quic-go/flowcontrol/base_flow_controller.go new file mode 100644 index 00000000..4c7bcb70 --- /dev/null +++ b/internal/quic-go/flowcontrol/base_flow_controller.go @@ -0,0 +1,125 @@ +package flowcontrol + +import ( + "sync" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +type baseFlowController struct { + // for sending data + bytesSent protocol.ByteCount + sendWindow protocol.ByteCount + lastBlockedAt protocol.ByteCount + + // for receiving data + //nolint:structcheck // The mutex is used both by the stream and the connection flow controller + mutex sync.Mutex + bytesRead protocol.ByteCount + highestReceived protocol.ByteCount + receiveWindow protocol.ByteCount + receiveWindowSize protocol.ByteCount + maxReceiveWindowSize protocol.ByteCount + + allowWindowIncrease func(size protocol.ByteCount) bool + + epochStartTime time.Time + epochStartOffset protocol.ByteCount + rttStats *utils.RTTStats + + logger utils.Logger +} + +// IsNewlyBlocked says if it is newly blocked by flow control. +// For every offset, it only returns true once. +// If it is blocked, the offset is returned. +func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt { + return false, 0 + } + c.lastBlockedAt = c.sendWindow + return true, c.sendWindow +} + +func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { + c.bytesSent += n +} + +// UpdateSendWindow is be called after receiving a MAX_{STREAM_}DATA frame. +func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { + if offset > c.sendWindow { + c.sendWindow = offset + } +} + +func (c *baseFlowController) sendWindowSize() protocol.ByteCount { + // this only happens during connection establishment, when data is sent before we receive the peer's transport parameters + if c.bytesSent > c.sendWindow { + return 0 + } + return c.sendWindow - c.bytesSent +} + +// needs to be called with locked mutex +func (c *baseFlowController) addBytesRead(n protocol.ByteCount) { + // pretend we sent a WindowUpdate when reading the first byte + // this way auto-tuning of the window size already works for the first WindowUpdate + if c.bytesRead == 0 { + c.startNewAutoTuningEpoch(time.Now()) + } + c.bytesRead += n +} + +func (c *baseFlowController) hasWindowUpdate() bool { + bytesRemaining := c.receiveWindow - c.bytesRead + // update the window when more than the threshold was consumed + return bytesRemaining <= protocol.ByteCount(float64(c.receiveWindowSize)*(1-protocol.WindowUpdateThreshold)) +} + +// getWindowUpdate updates the receive window, if necessary +// it returns the new offset +func (c *baseFlowController) getWindowUpdate() protocol.ByteCount { + if !c.hasWindowUpdate() { + return 0 + } + + c.maybeAdjustWindowSize() + c.receiveWindow = c.bytesRead + c.receiveWindowSize + return c.receiveWindow +} + +// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often. +// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing. +func (c *baseFlowController) maybeAdjustWindowSize() { + bytesReadInEpoch := c.bytesRead - c.epochStartOffset + // don't do anything if less than half the window has been consumed + if bytesReadInEpoch <= c.receiveWindowSize/2 { + return + } + rtt := c.rttStats.SmoothedRTT() + if rtt == 0 { + return + } + + fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize) + now := time.Now() + if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) { + // window is consumed too fast, try to increase the window size + newSize := utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize) + if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) { + c.receiveWindowSize = newSize + } + } + c.startNewAutoTuningEpoch(now) +} + +func (c *baseFlowController) startNewAutoTuningEpoch(now time.Time) { + c.epochStartTime = now + c.epochStartOffset = c.bytesRead +} + +func (c *baseFlowController) checkFlowControlViolation() bool { + return c.highestReceived > c.receiveWindow +} diff --git a/internal/quic-go/flowcontrol/base_flow_controller_test.go b/internal/quic-go/flowcontrol/base_flow_controller_test.go new file mode 100644 index 00000000..e5a9f578 --- /dev/null +++ b/internal/quic-go/flowcontrol/base_flow_controller_test.go @@ -0,0 +1,236 @@ +package flowcontrol + +import ( + "os" + "strconv" + "time" + + "github.com/imroc/req/v3/internal/quic-go/utils" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +// on the CIs, the timing is a lot less precise, so scale every duration by this factor +// +//nolint:unparam +func scaleDuration(t time.Duration) time.Duration { + scaleFactor := 1 + if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set + scaleFactor = f + } + Expect(scaleFactor).ToNot(BeZero()) + return time.Duration(scaleFactor) * t +} + +var _ = Describe("Base Flow controller", func() { + var controller *baseFlowController + + BeforeEach(func() { + controller = &baseFlowController{} + controller.rttStats = &utils.RTTStats{} + }) + + Context("send flow control", func() { + It("adds bytes sent", func() { + controller.bytesSent = 5 + controller.AddBytesSent(6) + Expect(controller.bytesSent).To(Equal(protocol.ByteCount(5 + 6))) + }) + + It("gets the size of the remaining flow control window", func() { + controller.bytesSent = 5 + controller.sendWindow = 12 + Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(12 - 5))) + }) + + It("updates the size of the flow control window", func() { + controller.AddBytesSent(5) + controller.UpdateSendWindow(15) + Expect(controller.sendWindow).To(Equal(protocol.ByteCount(15))) + Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(15 - 5))) + }) + + It("says that the window size is 0 if we sent more than we were allowed to", func() { + controller.AddBytesSent(15) + controller.UpdateSendWindow(10) + Expect(controller.sendWindowSize()).To(BeZero()) + }) + + It("does not decrease the flow control window", func() { + controller.UpdateSendWindow(20) + Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20))) + controller.UpdateSendWindow(10) + Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20))) + }) + + It("says when it's blocked", func() { + controller.UpdateSendWindow(100) + Expect(controller.IsNewlyBlocked()).To(BeFalse()) + controller.AddBytesSent(100) + blocked, offset := controller.IsNewlyBlocked() + Expect(blocked).To(BeTrue()) + Expect(offset).To(Equal(protocol.ByteCount(100))) + }) + + It("doesn't say that it's newly blocked multiple times for the same offset", func() { + controller.UpdateSendWindow(100) + controller.AddBytesSent(100) + newlyBlocked, offset := controller.IsNewlyBlocked() + Expect(newlyBlocked).To(BeTrue()) + Expect(offset).To(Equal(protocol.ByteCount(100))) + newlyBlocked, _ = controller.IsNewlyBlocked() + Expect(newlyBlocked).To(BeFalse()) + controller.UpdateSendWindow(150) + controller.AddBytesSent(150) + newlyBlocked, _ = controller.IsNewlyBlocked() + Expect(newlyBlocked).To(BeTrue()) + }) + }) + + Context("receive flow control", func() { + var ( + receiveWindow protocol.ByteCount = 10000 + receiveWindowSize protocol.ByteCount = 1000 + ) + + BeforeEach(func() { + controller.bytesRead = receiveWindow - receiveWindowSize + controller.receiveWindow = receiveWindow + controller.receiveWindowSize = receiveWindowSize + }) + + It("adds bytes read", func() { + controller.bytesRead = 5 + controller.addBytesRead(6) + Expect(controller.bytesRead).To(Equal(protocol.ByteCount(5 + 6))) + }) + + It("triggers a window update when necessary", func() { + bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold + 1 // consumed 1 byte more than the threshold + bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed) + readPosition := receiveWindow - bytesRemaining + controller.bytesRead = readPosition + offset := controller.getWindowUpdate() + Expect(offset).To(Equal(readPosition + receiveWindowSize)) + Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowSize)) + }) + + It("doesn't trigger a window update when not necessary", func() { + bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold - 1 // consumed 1 byte less than the threshold + bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed) + readPosition := receiveWindow - bytesRemaining + controller.bytesRead = readPosition + offset := controller.getWindowUpdate() + Expect(offset).To(BeZero()) + }) + + Context("receive window size auto-tuning", func() { + var oldWindowSize protocol.ByteCount + + BeforeEach(func() { + oldWindowSize = controller.receiveWindowSize + controller.maxReceiveWindowSize = 5000 + }) + + // update the congestion such that it returns a given value for the smoothed RTT + setRtt := func(t time.Duration) { + controller.rttStats.UpdateRTT(t, 0, time.Now()) + Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked + } + + It("doesn't increase the window size for a new stream", func() { + controller.maybeAdjustWindowSize() + Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) + }) + + It("doesn't increase the window size when no RTT estimate is available", func() { + setRtt(0) + controller.startNewAutoTuningEpoch(time.Now()) + controller.addBytesRead(400) + offset := controller.getWindowUpdate() + Expect(offset).ToNot(BeZero()) // make sure a window update is sent + Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) + }) + + It("increases the window size if read so fast that the window would be consumed in less than 4 RTTs", func() { + bytesRead := controller.bytesRead + rtt := scaleDuration(50 * time.Millisecond) + setRtt(rtt) + // consume more than 2/3 of the window... + dataRead := receiveWindowSize*2/3 + 1 + // ... in 4*2/3 of the RTT + controller.epochStartOffset = controller.bytesRead + controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3) + controller.addBytesRead(dataRead) + offset := controller.getWindowUpdate() + Expect(offset).ToNot(BeZero()) + // check that the window size was increased + newWindowSize := controller.receiveWindowSize + Expect(newWindowSize).To(Equal(2 * oldWindowSize)) + // check that the new window size was used to increase the offset + Expect(offset).To(Equal(bytesRead + dataRead + newWindowSize)) + }) + + It("doesn't increase the window size if data is read so fast that the window would be consumed in less than 4 RTTs, but less than half the window has been read", func() { + bytesRead := controller.bytesRead + rtt := scaleDuration(20 * time.Millisecond) + setRtt(rtt) + // consume more than 2/3 of the window... + dataRead := receiveWindowSize*1/3 + 1 + // ... in 4*2/3 of the RTT + controller.epochStartOffset = controller.bytesRead + controller.epochStartTime = time.Now().Add(-rtt * 4 * 1 / 3) + controller.addBytesRead(dataRead) + offset := controller.getWindowUpdate() + Expect(offset).ToNot(BeZero()) + // check that the window size was not increased + newWindowSize := controller.receiveWindowSize + Expect(newWindowSize).To(Equal(oldWindowSize)) + // check that the new window size was used to increase the offset + Expect(offset).To(Equal(bytesRead + dataRead + newWindowSize)) + }) + + It("doesn't increase the window size if read too slowly", func() { + bytesRead := controller.bytesRead + rtt := scaleDuration(20 * time.Millisecond) + setRtt(rtt) + // consume less than 2/3 of the window... + dataRead := receiveWindowSize*2/3 - 1 + // ... in 4*2/3 of the RTT + controller.epochStartOffset = controller.bytesRead + controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3) + controller.addBytesRead(dataRead) + offset := controller.getWindowUpdate() + Expect(offset).ToNot(BeZero()) + // check that the window size was not increased + Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) + // check that the new window size was used to increase the offset + Expect(offset).To(Equal(bytesRead + dataRead + oldWindowSize)) + }) + + It("doesn't increase the window size to a value higher than the maxReceiveWindowSize", func() { + resetEpoch := func() { + // make sure the next call to maybeAdjustWindowSize will increase the window + controller.epochStartTime = time.Now().Add(-time.Millisecond) + controller.epochStartOffset = controller.bytesRead + controller.addBytesRead(controller.receiveWindowSize/2 + 1) + } + setRtt(scaleDuration(20 * time.Millisecond)) + resetEpoch() + controller.maybeAdjustWindowSize() + Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) // 2000 + // because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowSize() multiple times and get an increase of the window size every time + resetEpoch() + controller.maybeAdjustWindowSize() + Expect(controller.receiveWindowSize).To(Equal(2 * 2 * oldWindowSize)) // 4000 + resetEpoch() + controller.maybeAdjustWindowSize() + Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000 + controller.maybeAdjustWindowSize() + Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000 + }) + }) + }) +}) diff --git a/internal/quic-go/flowcontrol/connection_flow_controller.go b/internal/quic-go/flowcontrol/connection_flow_controller.go new file mode 100644 index 00000000..3e40e0d5 --- /dev/null +++ b/internal/quic-go/flowcontrol/connection_flow_controller.go @@ -0,0 +1,112 @@ +package flowcontrol + +import ( + "errors" + "fmt" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +type connectionFlowController struct { + baseFlowController + + queueWindowUpdate func() +} + +var _ ConnectionFlowController = &connectionFlowController{} + +// NewConnectionFlowController gets a new flow controller for the connection +// It is created before we receive the peer's transport parameters, thus it starts with a sendWindow of 0. +func NewConnectionFlowController( + receiveWindow protocol.ByteCount, + maxReceiveWindow protocol.ByteCount, + queueWindowUpdate func(), + allowWindowIncrease func(size protocol.ByteCount) bool, + rttStats *utils.RTTStats, + logger utils.Logger, +) ConnectionFlowController { + return &connectionFlowController{ + baseFlowController: baseFlowController{ + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowSize: receiveWindow, + maxReceiveWindowSize: maxReceiveWindow, + allowWindowIncrease: allowWindowIncrease, + logger: logger, + }, + queueWindowUpdate: queueWindowUpdate, + } +} + +func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { + return c.baseFlowController.sendWindowSize() +} + +// IncrementHighestReceived adds an increment to the highestReceived value +func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.highestReceived += increment + if c.checkFlowControlViolation() { + return &qerr.TransportError{ + ErrorCode: qerr.FlowControlError, + ErrorMessage: fmt.Sprintf("received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow), + } + } + return nil +} + +func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) { + c.mutex.Lock() + c.baseFlowController.addBytesRead(n) + shouldQueueWindowUpdate := c.hasWindowUpdate() + c.mutex.Unlock() + if shouldQueueWindowUpdate { + c.queueWindowUpdate() + } +} + +func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { + c.mutex.Lock() + oldWindowSize := c.receiveWindowSize + offset := c.baseFlowController.getWindowUpdate() + if oldWindowSize < c.receiveWindowSize { + c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) + } + c.mutex.Unlock() + return offset +} + +// EnsureMinimumWindowSize sets a minimum window size +// it should make sure that the connection-level window is increased when a stream-level window grows +func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) { + c.mutex.Lock() + if inc > c.receiveWindowSize { + c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10)) + newSize := utils.MinByteCount(inc, c.maxReceiveWindowSize) + if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) { + c.receiveWindowSize = newSize + } + c.startNewAutoTuningEpoch(time.Now()) + } + c.mutex.Unlock() +} + +// Reset rests the flow controller. This happens when 0-RTT is rejected. +// All stream data is invalidated, it's if we had never opened a stream and never sent any data. +// At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet. +func (c *connectionFlowController) Reset() error { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.bytesRead > 0 || c.highestReceived > 0 || !c.epochStartTime.IsZero() { + return errors.New("flow controller reset after reading data") + } + c.bytesSent = 0 + c.lastBlockedAt = 0 + return nil +} diff --git a/internal/quic-go/flowcontrol/connection_flow_controller_test.go b/internal/quic-go/flowcontrol/connection_flow_controller_test.go new file mode 100644 index 00000000..32ad79c5 --- /dev/null +++ b/internal/quic-go/flowcontrol/connection_flow_controller_test.go @@ -0,0 +1,185 @@ +package flowcontrol + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Connection Flow controller", func() { + var ( + controller *connectionFlowController + queuedWindowUpdate bool + ) + + // update the congestion such that it returns a given value for the smoothed RTT + setRtt := func(t time.Duration) { + controller.rttStats.UpdateRTT(t, 0, time.Now()) + Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked + } + + BeforeEach(func() { + queuedWindowUpdate = false + controller = &connectionFlowController{} + controller.rttStats = &utils.RTTStats{} + controller.logger = utils.DefaultLogger + controller.queueWindowUpdate = func() { queuedWindowUpdate = true } + controller.allowWindowIncrease = func(protocol.ByteCount) bool { return true } + }) + + Context("Constructor", func() { + rttStats := &utils.RTTStats{} + + It("sets the send and receive windows", func() { + receiveWindow := protocol.ByteCount(2000) + maxReceiveWindow := protocol.ByteCount(3000) + + fc := NewConnectionFlowController( + receiveWindow, + maxReceiveWindow, + nil, + func(protocol.ByteCount) bool { return true }, + rttStats, + utils.DefaultLogger).(*connectionFlowController) + Expect(fc.receiveWindow).To(Equal(receiveWindow)) + Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow)) + }) + }) + + Context("receive flow control", func() { + It("increases the highestReceived by a given window size", func() { + controller.highestReceived = 1337 + controller.IncrementHighestReceived(123) + Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337 + 123))) + }) + + Context("getting window updates", func() { + BeforeEach(func() { + controller.receiveWindow = 100 + controller.receiveWindowSize = 60 + controller.maxReceiveWindowSize = 1000 + controller.bytesRead = 100 - 60 + }) + + It("queues window updates", func() { + controller.AddBytesRead(1) + Expect(queuedWindowUpdate).To(BeFalse()) + controller.AddBytesRead(29) + Expect(queuedWindowUpdate).To(BeTrue()) + Expect(controller.GetWindowUpdate()).ToNot(BeZero()) + queuedWindowUpdate = false + controller.AddBytesRead(1) + Expect(queuedWindowUpdate).To(BeFalse()) + }) + + It("gets a window update", func() { + windowSize := controller.receiveWindowSize + oldOffset := controller.bytesRead + dataRead := windowSize/2 - 1 // make sure not to trigger auto-tuning + controller.AddBytesRead(dataRead) + offset := controller.GetWindowUpdate() + Expect(offset).To(Equal(oldOffset + dataRead + 60)) + }) + + It("auto-tunes the window", func() { + var allowed protocol.ByteCount + controller.allowWindowIncrease = func(size protocol.ByteCount) bool { + allowed = size + return true + } + oldOffset := controller.bytesRead + oldWindowSize := controller.receiveWindowSize + rtt := scaleDuration(20 * time.Millisecond) + setRtt(rtt) + controller.epochStartTime = time.Now().Add(-time.Millisecond) + controller.epochStartOffset = oldOffset + dataRead := oldWindowSize/2 + 1 + controller.AddBytesRead(dataRead) + offset := controller.GetWindowUpdate() + newWindowSize := controller.receiveWindowSize + Expect(newWindowSize).To(Equal(2 * oldWindowSize)) + Expect(offset).To(Equal(oldOffset + dataRead + newWindowSize)) + Expect(allowed).To(Equal(oldWindowSize)) + }) + + It("doesn't auto-tune the window if it's not allowed", func() { + controller.allowWindowIncrease = func(protocol.ByteCount) bool { return false } + oldOffset := controller.bytesRead + oldWindowSize := controller.receiveWindowSize + rtt := scaleDuration(20 * time.Millisecond) + setRtt(rtt) + controller.epochStartTime = time.Now().Add(-time.Millisecond) + controller.epochStartOffset = oldOffset + dataRead := oldWindowSize/2 + 1 + controller.AddBytesRead(dataRead) + offset := controller.GetWindowUpdate() + newWindowSize := controller.receiveWindowSize + Expect(newWindowSize).To(Equal(oldWindowSize)) + Expect(offset).To(Equal(oldOffset + dataRead + newWindowSize)) + }) + }) + }) + + Context("setting the minimum window size", func() { + var ( + oldWindowSize protocol.ByteCount + receiveWindow protocol.ByteCount = 10000 + receiveWindowSize protocol.ByteCount = 1000 + ) + + BeforeEach(func() { + controller.receiveWindow = receiveWindow + controller.receiveWindowSize = receiveWindowSize + oldWindowSize = controller.receiveWindowSize + controller.maxReceiveWindowSize = 3000 + }) + + It("sets the minimum window window size", func() { + controller.EnsureMinimumWindowSize(1800) + Expect(controller.receiveWindowSize).To(Equal(protocol.ByteCount(1800))) + }) + + It("doesn't reduce the window window size", func() { + controller.EnsureMinimumWindowSize(1) + Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) + }) + + It("doesn't increase the window size beyond the maxReceiveWindowSize", func() { + max := controller.maxReceiveWindowSize + controller.EnsureMinimumWindowSize(2 * max) + Expect(controller.receiveWindowSize).To(Equal(max)) + }) + + It("starts a new epoch after the window size was increased", func() { + controller.EnsureMinimumWindowSize(1912) + Expect(controller.epochStartTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) + }) + }) + + Context("resetting", func() { + It("resets", func() { + const initialWindow protocol.ByteCount = 1337 + controller.UpdateSendWindow(initialWindow) + controller.AddBytesSent(1000) + Expect(controller.SendWindowSize()).To(Equal(initialWindow - 1000)) + Expect(controller.Reset()).To(Succeed()) + Expect(controller.SendWindowSize()).To(Equal(initialWindow)) + }) + + It("says if is blocked after resetting", func() { + const initialWindow protocol.ByteCount = 1337 + controller.UpdateSendWindow(initialWindow) + controller.AddBytesSent(initialWindow) + blocked, _ := controller.IsNewlyBlocked() + Expect(blocked).To(BeTrue()) + Expect(controller.Reset()).To(Succeed()) + controller.AddBytesSent(initialWindow) + blocked, blockedAt := controller.IsNewlyBlocked() + Expect(blocked).To(BeTrue()) + Expect(blockedAt).To(Equal(initialWindow)) + }) + }) +}) diff --git a/internal/quic-go/flowcontrol/flowcontrol_suite_test.go b/internal/quic-go/flowcontrol/flowcontrol_suite_test.go new file mode 100644 index 00000000..91102815 --- /dev/null +++ b/internal/quic-go/flowcontrol/flowcontrol_suite_test.go @@ -0,0 +1,24 @@ +package flowcontrol + +import ( + "testing" + + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestFlowControl(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "FlowControl Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) diff --git a/internal/quic-go/flowcontrol/interface.go b/internal/quic-go/flowcontrol/interface.go new file mode 100644 index 00000000..58a0e88f --- /dev/null +++ b/internal/quic-go/flowcontrol/interface.go @@ -0,0 +1,42 @@ +package flowcontrol + +import "github.com/imroc/req/v3/internal/quic-go/protocol" + +type flowController interface { + // for sending + SendWindowSize() protocol.ByteCount + UpdateSendWindow(protocol.ByteCount) + AddBytesSent(protocol.ByteCount) + // for receiving + AddBytesRead(protocol.ByteCount) + GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary + IsNewlyBlocked() (bool, protocol.ByteCount) +} + +// A StreamFlowController is a flow controller for a QUIC stream. +type StreamFlowController interface { + flowController + // for receiving + // UpdateHighestReceived should be called when a new highest offset is received + // final has to be to true if this is the final offset of the stream, + // as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame + UpdateHighestReceived(offset protocol.ByteCount, final bool) error + // Abandon should be called when reading from the stream is aborted early, + // and there won't be any further calls to AddBytesRead. + Abandon() +} + +// The ConnectionFlowController is the flow controller for the connection. +type ConnectionFlowController interface { + flowController + Reset() error +} + +type connectionFlowControllerI interface { + ConnectionFlowController + // The following two methods are not supposed to be called from outside this packet, but are needed internally + // for sending + EnsureMinimumWindowSize(protocol.ByteCount) + // for receiving + IncrementHighestReceived(protocol.ByteCount) error +} diff --git a/internal/quic-go/flowcontrol/stream_flow_controller.go b/internal/quic-go/flowcontrol/stream_flow_controller.go new file mode 100644 index 00000000..7ee95862 --- /dev/null +++ b/internal/quic-go/flowcontrol/stream_flow_controller.go @@ -0,0 +1,149 @@ +package flowcontrol + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +type streamFlowController struct { + baseFlowController + + streamID protocol.StreamID + + queueWindowUpdate func() + + connection connectionFlowControllerI + + receivedFinalOffset bool +} + +var _ StreamFlowController = &streamFlowController{} + +// NewStreamFlowController gets a new flow controller for a stream +func NewStreamFlowController( + streamID protocol.StreamID, + cfc ConnectionFlowController, + receiveWindow protocol.ByteCount, + maxReceiveWindow protocol.ByteCount, + initialSendWindow protocol.ByteCount, + queueWindowUpdate func(protocol.StreamID), + rttStats *utils.RTTStats, + logger utils.Logger, +) StreamFlowController { + return &streamFlowController{ + streamID: streamID, + connection: cfc.(connectionFlowControllerI), + queueWindowUpdate: func() { queueWindowUpdate(streamID) }, + baseFlowController: baseFlowController{ + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowSize: receiveWindow, + maxReceiveWindowSize: maxReceiveWindow, + sendWindow: initialSendWindow, + logger: logger, + }, + } +} + +// UpdateHighestReceived updates the highestReceived value, if the offset is higher. +func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool) error { + // If the final offset for this stream is already known, check for consistency. + if c.receivedFinalOffset { + // If we receive another final offset, check that it's the same. + if final && offset != c.highestReceived { + return &qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: fmt.Sprintf("received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset), + } + } + // Check that the offset is below the final offset. + if offset > c.highestReceived { + return &qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: fmt.Sprintf("received offset %d for stream %d, but final offset was already received at %d", offset, c.streamID, c.highestReceived), + } + } + } + + if final { + c.receivedFinalOffset = true + } + if offset == c.highestReceived { + return nil + } + // A higher offset was received before. + // This can happen due to reordering. + if offset <= c.highestReceived { + if final { + return &qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: fmt.Sprintf("received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived), + } + } + return nil + } + + increment := offset - c.highestReceived + c.highestReceived = offset + if c.checkFlowControlViolation() { + return &qerr.TransportError{ + ErrorCode: qerr.FlowControlError, + ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow), + } + } + return c.connection.IncrementHighestReceived(increment) +} + +func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) { + c.mutex.Lock() + c.baseFlowController.addBytesRead(n) + shouldQueueWindowUpdate := c.shouldQueueWindowUpdate() + c.mutex.Unlock() + if shouldQueueWindowUpdate { + c.queueWindowUpdate() + } + c.connection.AddBytesRead(n) +} + +func (c *streamFlowController) Abandon() { + c.mutex.Lock() + unread := c.highestReceived - c.bytesRead + c.mutex.Unlock() + if unread > 0 { + c.connection.AddBytesRead(unread) + } +} + +func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { + c.baseFlowController.AddBytesSent(n) + c.connection.AddBytesSent(n) +} + +func (c *streamFlowController) SendWindowSize() protocol.ByteCount { + return utils.MinByteCount(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize()) +} + +func (c *streamFlowController) shouldQueueWindowUpdate() bool { + return !c.receivedFinalOffset && c.hasWindowUpdate() +} + +func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { + // If we already received the final offset for this stream, the peer won't need any additional flow control credit. + if c.receivedFinalOffset { + return 0 + } + + // Don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler + c.mutex.Lock() + oldWindowSize := c.receiveWindowSize + offset := c.baseFlowController.getWindowUpdate() + if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size + c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10)) + c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)) + } + c.mutex.Unlock() + return offset +} diff --git a/internal/quic-go/flowcontrol/stream_flow_controller_test.go b/internal/quic-go/flowcontrol/stream_flow_controller_test.go new file mode 100644 index 00000000..61084795 --- /dev/null +++ b/internal/quic-go/flowcontrol/stream_flow_controller_test.go @@ -0,0 +1,272 @@ +package flowcontrol + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stream Flow controller", func() { + var ( + controller *streamFlowController + queuedWindowUpdate bool + ) + + BeforeEach(func() { + queuedWindowUpdate = false + rttStats := &utils.RTTStats{} + controller = &streamFlowController{ + streamID: 10, + connection: NewConnectionFlowController( + 1000, + 1000, + func() {}, + func(protocol.ByteCount) bool { return true }, + rttStats, + utils.DefaultLogger, + ).(*connectionFlowController), + } + controller.maxReceiveWindowSize = 10000 + controller.rttStats = rttStats + controller.logger = utils.DefaultLogger + controller.queueWindowUpdate = func() { queuedWindowUpdate = true } + }) + + Context("Constructor", func() { + rttStats := &utils.RTTStats{} + const receiveWindow protocol.ByteCount = 2000 + const maxReceiveWindow protocol.ByteCount = 3000 + const sendWindow protocol.ByteCount = 4000 + + It("sets the send and receive windows", func() { + cc := NewConnectionFlowController(0, 0, nil, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger) + fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController) + Expect(fc.streamID).To(Equal(protocol.StreamID(5))) + Expect(fc.receiveWindow).To(Equal(receiveWindow)) + Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow)) + Expect(fc.sendWindow).To(Equal(sendWindow)) + }) + + It("queues window updates with the correct stream ID", func() { + var queued bool + queueWindowUpdate := func(id protocol.StreamID) { + Expect(id).To(Equal(protocol.StreamID(5))) + queued = true + } + + cc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, func() {}, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger) + fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController) + fc.AddBytesRead(receiveWindow) + Expect(queued).To(BeTrue()) + }) + }) + + Context("receiving data", func() { + Context("registering received offsets", func() { + var receiveWindow protocol.ByteCount = 10000 + var receiveWindowSize protocol.ByteCount = 600 + + BeforeEach(func() { + controller.receiveWindow = receiveWindow + controller.receiveWindowSize = receiveWindowSize + }) + + It("updates the highestReceived", func() { + controller.highestReceived = 1337 + Expect(controller.UpdateHighestReceived(1338, false)).To(Succeed()) + Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1338))) + }) + + It("informs the connection flow controller about received data", func() { + controller.highestReceived = 10 + controller.connection.(*connectionFlowController).highestReceived = 100 + Expect(controller.UpdateHighestReceived(20, false)).To(Succeed()) + Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100 + 10))) + }) + + It("does not decrease the highestReceived", func() { + controller.highestReceived = 1337 + Expect(controller.UpdateHighestReceived(1000, false)).To(Succeed()) + Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337))) + }) + + It("does nothing when setting the same byte offset", func() { + controller.highestReceived = 1337 + Expect(controller.UpdateHighestReceived(1337, false)).To(Succeed()) + }) + + It("does not give a flow control violation when using the window completely", func() { + controller.connection.(*connectionFlowController).receiveWindow = receiveWindow + Expect(controller.UpdateHighestReceived(receiveWindow, false)).To(Succeed()) + }) + + It("detects a flow control violation", func() { + Expect(controller.UpdateHighestReceived(receiveWindow+1, false)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FlowControlError, + ErrorMessage: "received 10001 bytes on stream 10, allowed 10000 bytes", + })) + }) + + It("accepts a final offset higher than the highest received", func() { + Expect(controller.UpdateHighestReceived(100, false)).To(Succeed()) + Expect(controller.UpdateHighestReceived(101, true)).To(Succeed()) + Expect(controller.highestReceived).To(Equal(protocol.ByteCount(101))) + }) + + It("errors when receiving a final offset smaller than the highest offset received so far", func() { + controller.UpdateHighestReceived(100, false) + Expect(controller.UpdateHighestReceived(50, true)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: "received final offset 50 for stream 10, but already received offset 100 before", + })) + }) + + It("accepts delayed data after receiving a final offset", func() { + Expect(controller.UpdateHighestReceived(300, true)).To(Succeed()) + Expect(controller.UpdateHighestReceived(250, false)).To(Succeed()) + }) + + It("errors when receiving a higher offset after receiving a final offset", func() { + Expect(controller.UpdateHighestReceived(200, true)).To(Succeed()) + Expect(controller.UpdateHighestReceived(250, false)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: "received offset 250 for stream 10, but final offset was already received at 200", + })) + }) + + It("accepts duplicate final offsets", func() { + Expect(controller.UpdateHighestReceived(200, true)).To(Succeed()) + Expect(controller.UpdateHighestReceived(200, true)).To(Succeed()) + Expect(controller.highestReceived).To(Equal(protocol.ByteCount(200))) + }) + + It("errors when receiving inconsistent final offsets", func() { + Expect(controller.UpdateHighestReceived(200, true)).To(Succeed()) + Expect(controller.UpdateHighestReceived(201, true)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: "received inconsistent final offset for stream 10 (old: 200, new: 201 bytes)", + })) + }) + + It("tells the connection flow controller when a stream is abandoned", func() { + controller.AddBytesRead(5) + Expect(controller.UpdateHighestReceived(100, true)).To(Succeed()) + controller.Abandon() + Expect(controller.connection.(*connectionFlowController).bytesRead).To(Equal(protocol.ByteCount(100))) + }) + }) + + It("saves when data is read", func() { + controller.AddBytesRead(200) + Expect(controller.bytesRead).To(Equal(protocol.ByteCount(200))) + Expect(controller.connection.(*connectionFlowController).bytesRead).To(Equal(protocol.ByteCount(200))) + }) + + Context("generating window updates", func() { + var oldWindowSize protocol.ByteCount + + // update the congestion such that it returns a given value for the smoothed RTT + setRtt := func(t time.Duration) { + controller.rttStats.UpdateRTT(t, 0, time.Now()) + Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked + } + + BeforeEach(func() { + controller.receiveWindow = 100 + controller.receiveWindowSize = 60 + controller.bytesRead = 100 - 60 + controller.connection.(*connectionFlowController).receiveWindow = 100 + controller.connection.(*connectionFlowController).receiveWindowSize = 120 + oldWindowSize = controller.receiveWindowSize + }) + + It("queues window updates", func() { + controller.AddBytesRead(1) + Expect(queuedWindowUpdate).To(BeFalse()) + controller.AddBytesRead(29) + Expect(queuedWindowUpdate).To(BeTrue()) + Expect(controller.GetWindowUpdate()).ToNot(BeZero()) + queuedWindowUpdate = false + controller.AddBytesRead(1) + Expect(queuedWindowUpdate).To(BeFalse()) + }) + + It("tells the connection flow controller when the window was auto-tuned", func() { + var allowed protocol.ByteCount + controller.connection.(*connectionFlowController).allowWindowIncrease = func(size protocol.ByteCount) bool { + allowed = size + return true + } + oldOffset := controller.bytesRead + setRtt(scaleDuration(20 * time.Millisecond)) + controller.epochStartOffset = oldOffset + controller.epochStartTime = time.Now().Add(-time.Millisecond) + controller.AddBytesRead(55) + offset := controller.GetWindowUpdate() + Expect(offset).To(Equal(oldOffset + 55 + 2*oldWindowSize)) + Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) + Expect(allowed).To(Equal(oldWindowSize)) + Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(float64(controller.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))) + }) + + It("doesn't increase the connection flow control window if it's not allowed", func() { + oldOffset := controller.bytesRead + oldConnectionSize := controller.connection.(*connectionFlowController).receiveWindowSize + controller.connection.(*connectionFlowController).allowWindowIncrease = func(protocol.ByteCount) bool { return false } + setRtt(scaleDuration(20 * time.Millisecond)) + controller.epochStartOffset = oldOffset + controller.epochStartTime = time.Now().Add(-time.Millisecond) + controller.AddBytesRead(55) + offset := controller.GetWindowUpdate() + Expect(offset).To(Equal(oldOffset + 55 + 2*oldWindowSize)) + Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) + Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(oldConnectionSize)) + }) + + It("sends a connection-level window update when a large stream is abandoned", func() { + Expect(controller.UpdateHighestReceived(90, true)).To(Succeed()) + Expect(controller.connection.GetWindowUpdate()).To(BeZero()) + controller.Abandon() + Expect(controller.connection.GetWindowUpdate()).ToNot(BeZero()) + }) + + It("doesn't increase the window after a final offset was already received", func() { + Expect(controller.UpdateHighestReceived(90, true)).To(Succeed()) + controller.AddBytesRead(30) + Expect(queuedWindowUpdate).To(BeFalse()) + offset := controller.GetWindowUpdate() + Expect(offset).To(BeZero()) + }) + }) + }) + + Context("sending data", func() { + It("gets the size of the send window", func() { + controller.connection.UpdateSendWindow(1000) + controller.UpdateSendWindow(15) + controller.AddBytesSent(5) + Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10))) + }) + + It("makes sure that it doesn't overflow the connection-level window", func() { + controller.connection.UpdateSendWindow(12) + controller.UpdateSendWindow(20) + controller.AddBytesSent(10) + Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(2))) + }) + + It("doesn't say that it's blocked, if only the connection is blocked", func() { + controller.connection.UpdateSendWindow(50) + controller.UpdateSendWindow(100) + controller.AddBytesSent(50) + blocked, _ := controller.connection.IsNewlyBlocked() + Expect(blocked).To(BeTrue()) + Expect(controller.IsNewlyBlocked()).To(BeFalse()) + }) + }) +}) diff --git a/internal/quic-go/frame_sorter.go b/internal/quic-go/frame_sorter.go new file mode 100644 index 00000000..aa16e38c --- /dev/null +++ b/internal/quic-go/frame_sorter.go @@ -0,0 +1,224 @@ +package quic + +import ( + "errors" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +type frameSorterEntry struct { + Data []byte + DoneCb func() +} + +type frameSorter struct { + queue map[protocol.ByteCount]frameSorterEntry + readPos protocol.ByteCount + gaps *utils.ByteIntervalList +} + +var errDuplicateStreamData = errors.New("duplicate stream data") + +func newFrameSorter() *frameSorter { + s := frameSorter{ + gaps: utils.NewByteIntervalList(), + queue: make(map[protocol.ByteCount]frameSorterEntry), + } + s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount}) + return &s +} + +func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error { + err := s.push(data, offset, doneCb) + if err == errDuplicateStreamData { + if doneCb != nil { + doneCb() + } + return nil + } + return err +} + +func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error { + if len(data) == 0 { + return errDuplicateStreamData + } + + start := offset + end := offset + protocol.ByteCount(len(data)) + + if end <= s.gaps.Front().Value.Start { + return errDuplicateStreamData + } + + startGap, startsInGap := s.findStartGap(start) + endGap, endsInGap := s.findEndGap(startGap, end) + + startGapEqualsEndGap := startGap == endGap + + if (startGapEqualsEndGap && end <= startGap.Value.Start) || + (!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) { + return errDuplicateStreamData + } + + startGapNext := startGap.Next() + startGapEnd := startGap.Value.End // save it, in case startGap is modified + endGapStart := endGap.Value.Start // save it, in case endGap is modified + endGapEnd := endGap.Value.End // save it, in case endGap is modified + var adjustedStartGapEnd bool + var wasCut bool + + pos := start + var hasReplacedAtLeastOne bool + for { + oldEntry, ok := s.queue[pos] + if !ok { + break + } + oldEntryLen := protocol.ByteCount(len(oldEntry.Data)) + if end-pos > oldEntryLen || (hasReplacedAtLeastOne && end-pos == oldEntryLen) { + // The existing frame is shorter than the new frame. Replace it. + delete(s.queue, pos) + pos += oldEntryLen + hasReplacedAtLeastOne = true + if oldEntry.DoneCb != nil { + oldEntry.DoneCb() + } + } else { + if !hasReplacedAtLeastOne { + return errDuplicateStreamData + } + // The existing frame is longer than the new frame. + // Cut the new frame such that the end aligns with the start of the existing frame. + data = data[:pos-start] + end = pos + wasCut = true + break + } + } + + if !startsInGap && !hasReplacedAtLeastOne { + // cut the frame, such that it starts at the start of the gap + data = data[startGap.Value.Start-start:] + start = startGap.Value.Start + wasCut = true + } + if start <= startGap.Value.Start { + if end >= startGap.Value.End { + // The frame covers the whole startGap. Delete the gap. + s.gaps.Remove(startGap) + } else { + startGap.Value.Start = end + } + } else if !hasReplacedAtLeastOne { + startGap.Value.End = start + adjustedStartGapEnd = true + } + + if !startGapEqualsEndGap { + s.deleteConsecutive(startGapEnd) + var nextGap *utils.ByteIntervalElement + for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap { + nextGap = gap.Next() + s.deleteConsecutive(gap.Value.End) + s.gaps.Remove(gap) + } + } + + if !endsInGap && start != endGapEnd && end > endGapEnd { + // cut the frame, such that it ends at the end of the gap + data = data[:endGapEnd-start] + end = endGapEnd + wasCut = true + } + if end == endGapEnd { + if !startGapEqualsEndGap { + // The frame covers the whole endGap. Delete the gap. + s.gaps.Remove(endGap) + } + } else { + if startGapEqualsEndGap && adjustedStartGapEnd { + // The frame split the existing gap into two. + s.gaps.InsertAfter(utils.ByteInterval{Start: end, End: startGapEnd}, startGap) + } else if !startGapEqualsEndGap { + endGap.Value.Start = end + } + } + + if wasCut && len(data) < protocol.MinStreamFrameBufferSize { + newData := make([]byte, len(data)) + copy(newData, data) + data = newData + if doneCb != nil { + doneCb() + doneCb = nil + } + } + + if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps { + return errors.New("too many gaps in received data") + } + + s.queue[start] = frameSorterEntry{Data: data, DoneCb: doneCb} + return nil +} + +func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*utils.ByteIntervalElement, bool) { + for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { + if offset >= gap.Value.Start && offset <= gap.Value.End { + return gap, true + } + if offset < gap.Value.Start { + return gap, false + } + } + panic("no gap found") +} + +func (s *frameSorter) findEndGap(startGap *utils.ByteIntervalElement, offset protocol.ByteCount) (*utils.ByteIntervalElement, bool) { + for gap := startGap; gap != nil; gap = gap.Next() { + if offset >= gap.Value.Start && offset < gap.Value.End { + return gap, true + } + if offset < gap.Value.Start { + return gap.Prev(), false + } + } + panic("no gap found") +} + +// deleteConsecutive deletes consecutive frames from the queue, starting at pos +func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) { + for { + oldEntry, ok := s.queue[pos] + if !ok { + break + } + oldEntryLen := protocol.ByteCount(len(oldEntry.Data)) + delete(s.queue, pos) + if oldEntry.DoneCb != nil { + oldEntry.DoneCb() + } + pos += oldEntryLen + } +} + +func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) { + entry, ok := s.queue[s.readPos] + if !ok { + return s.readPos, nil, nil + } + delete(s.queue, s.readPos) + offset := s.readPos + s.readPos += protocol.ByteCount(len(entry.Data)) + if s.gaps.Front().Value.End <= s.readPos { + panic("frame sorter BUG: read position higher than a gap") + } + return offset, entry.Data, entry.DoneCb +} + +// HasMoreData says if there is any more data queued at *any* offset. +func (s *frameSorter) HasMoreData() bool { + return len(s.queue) > 0 +} diff --git a/internal/quic-go/frame_sorter_test.go b/internal/quic-go/frame_sorter_test.go new file mode 100644 index 00000000..52614111 --- /dev/null +++ b/internal/quic-go/frame_sorter_test.go @@ -0,0 +1,1527 @@ +package quic + +import ( + "bytes" + "fmt" + "math" + "math/rand" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("frame sorter", func() { + var s *frameSorter + + checkGaps := func(expectedGaps []utils.ByteInterval) { + if s.gaps.Len() != len(expectedGaps) { + fmt.Println("Gaps:") + for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { + fmt.Printf("\t%d - %d\n", gap.Value.Start, gap.Value.End) + } + ExpectWithOffset(1, s.gaps.Len()).To(Equal(len(expectedGaps))) + } + var i int + for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { + ExpectWithOffset(1, gap.Value).To(Equal(expectedGaps[i])) + i++ + } + } + + type callbackTracker struct { + called *bool + cb func() + } + + getCallback := func() (func(), callbackTracker) { + var called bool + cb := func() { + if called { + panic("double free") + } + called = true + } + return cb, callbackTracker{ + cb: cb, + called: &called, + } + } + + checkCallbackCalled := func(t callbackTracker) { + ExpectWithOffset(1, *t.called).To(BeTrue()) + } + + checkCallbackNotCalled := func(t callbackTracker) { + ExpectWithOffset(1, *t.called).To(BeFalse()) + t.cb() + ExpectWithOffset(1, *t.called).To(BeTrue()) + } + + BeforeEach(func() { + s = newFrameSorter() + }) + + It("returns nil when empty", func() { + _, data, doneCb := s.Pop() + Expect(data).To(BeNil()) + Expect(doneCb).To(BeNil()) + }) + + It("inserts and pops a single frame", func() { + cb, t := getCallback() + Expect(s.Push([]byte("foobar"), 0, cb)).To(Succeed()) + offset, data, doneCb := s.Pop() + Expect(offset).To(BeZero()) + Expect(data).To(Equal([]byte("foobar"))) + Expect(doneCb).ToNot(BeNil()) + checkCallbackNotCalled(t) + offset, data, doneCb = s.Pop() + Expect(offset).To(Equal(protocol.ByteCount(6))) + Expect(data).To(BeNil()) + Expect(doneCb).To(BeNil()) + }) + + It("inserts and pops two consecutive frame", func() { + cb1, t1 := getCallback() + cb2, t2 := getCallback() + Expect(s.Push([]byte("bar"), 3, cb2)).To(Succeed()) + Expect(s.Push([]byte("foo"), 0, cb1)).To(Succeed()) + offset, data, doneCb := s.Pop() + Expect(offset).To(BeZero()) + Expect(data).To(Equal([]byte("foo"))) + Expect(doneCb).ToNot(BeNil()) + doneCb() + checkCallbackCalled(t1) + offset, data, doneCb = s.Pop() + Expect(offset).To(Equal(protocol.ByteCount(3))) + Expect(data).To(Equal([]byte("bar"))) + Expect(doneCb).ToNot(BeNil()) + doneCb() + checkCallbackCalled(t2) + offset, data, doneCb = s.Pop() + Expect(offset).To(Equal(protocol.ByteCount(6))) + Expect(data).To(BeNil()) + Expect(doneCb).To(BeNil()) + }) + + It("ignores empty frames", func() { + Expect(s.Push(nil, 0, nil)).To(Succeed()) + _, data, doneCb := s.Pop() + Expect(data).To(BeNil()) + Expect(doneCb).To(BeNil()) + }) + + It("says if has more data", func() { + Expect(s.HasMoreData()).To(BeFalse()) + Expect(s.Push([]byte("foo"), 0, nil)).To(Succeed()) + Expect(s.HasMoreData()).To(BeTrue()) + _, data, _ := s.Pop() + Expect(data).To(Equal([]byte("foo"))) + Expect(s.HasMoreData()).To(BeFalse()) + }) + + Context("Gap handling", func() { + var dataCounter uint8 + + BeforeEach(func() { + dataCounter = 0 + }) + + checkQueue := func(m map[protocol.ByteCount][]byte) { + ExpectWithOffset(1, s.queue).To(HaveLen(len(m))) + for offset, data := range m { + ExpectWithOffset(1, s.queue).To(HaveKey(offset)) + ExpectWithOffset(1, s.queue[offset].Data).To(Equal(data)) + } + } + + getData := func(l protocol.ByteCount) []byte { + dataCounter++ + return bytes.Repeat([]byte{dataCounter}, int(l)) + } + + // ---xxx-------------- + // ++++++ + // => + // ---xxx++++++-------- + It("case 1", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(5) + cb2, t2 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 11 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 6: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 11, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + }) + + // ---xxx----------------- + // +++++++ + // => + // ---xxx---+++++++-------- + It("case 2", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(5) + cb2, t2 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 10, cb2)).To(Succeed()) // 10 -15 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 10: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 6, End: 10}, + {Start: 15, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + }) + + // ---xxx----xxxxxx------- + // ++++ + // => + // ---xxx++++xxxxx-------- + It("case 3", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(4) + cb2, t2 := getCallback() + f3 := getData(5) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f3, 10, cb2)).To(Succeed()) // 10 - 15 + Expect(s.Push(f2, 6, cb3)).To(Succeed()) // 6 - 10 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 6: f2, + 10: f3, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 15, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ----xxxx------- + // ++++ + // => + // ----xxxx++----- + It("case 4", func() { + f1 := getData(4) + cb1, t1 := getCallback() + f2 := getData(4) + cb2, t2 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 7 + Expect(s.Push(f2, 5, cb2)).To(Succeed()) // 5 - 9 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 7: f2[2:], + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 9, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + }) + + It("case 4, for long frames", func() { + mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) + f1 := getData(4 * mult) + cb1, t1 := getCallback() + f2 := getData(4 * mult) + cb2, t2 := getCallback() + Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 7 + Expect(s.Push(f2, 5*mult, cb2)).To(Succeed()) // 5 - 9 + checkQueue(map[protocol.ByteCount][]byte{ + 3 * mult: f1, + 7 * mult: f2[2*mult:], + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3 * mult}, + {Start: 9 * mult, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + }) + + // xxxx------- + // ++++ + // => + // xxxx+++----- + It("case 5", func() { + f1 := getData(4) + cb1, t1 := getCallback() + f2 := getData(4) + cb2, t2 := getCallback() + Expect(s.Push(f1, 0, cb1)).To(Succeed()) // 0 - 4 + Expect(s.Push(f2, 3, cb2)).To(Succeed()) // 3 - 7 + checkQueue(map[protocol.ByteCount][]byte{ + 0: f1, + 4: f2[1:], + }) + checkGaps([]utils.ByteInterval{ + {Start: 7, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + }) + + It("case 5, for long frames", func() { + mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) + f1 := getData(4 * mult) + cb1, t1 := getCallback() + f2 := getData(4 * mult) + cb2, t2 := getCallback() + Expect(s.Push(f1, 0, cb1)).To(Succeed()) // 0 - 4 + Expect(s.Push(f2, 3*mult, cb2)).To(Succeed()) // 3 - 7 + checkQueue(map[protocol.ByteCount][]byte{ + 0: f1, + 4 * mult: f2[mult:], + }) + checkGaps([]utils.ByteInterval{ + {Start: 7 * mult, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + }) + + // ----xxxx------- + // ++++ + // => + // --++xxxx------- + It("case 6", func() { + f1 := getData(4) + cb1, t1 := getCallback() + f2 := getData(4) + cb2, t2 := getCallback() + Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 9 + Expect(s.Push(f2, 3, cb2)).To(Succeed()) // 3 - 7 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f2[:2], + 5: f1, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 9, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + }) + + It("case 6, for long frames", func() { + mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) + f1 := getData(4 * mult) + cb1, t1 := getCallback() + f2 := getData(4 * mult) + cb2, t2 := getCallback() + Expect(s.Push(f1, 5*mult, cb1)).To(Succeed()) // 5 - 9 + Expect(s.Push(f2, 3*mult, cb2)).To(Succeed()) // 3 - 7 + checkQueue(map[protocol.ByteCount][]byte{ + 3 * mult: f2[:2*mult], + 5 * mult: f1, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3 * mult}, + {Start: 9 * mult, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + }) + + // ---xxx----xxxxxx------- + // ++ + // => + // ---xxx++--xxxxx-------- + It("case 7", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(2) + cb2, t2 := getCallback() + f3 := getData(5) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f3, 10, cb2)).To(Succeed()) // 10 - 15 + Expect(s.Push(f2, 6, cb3)).To(Succeed()) // 6 - 8 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 6: f2, + 10: f3, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 8, End: 10}, + {Start: 15, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ---xxx---------xxxxxx-- + // ++ + // => + // ---xxx---++----xxxxx-- + It("case 8", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(2) + cb2, t2 := getCallback() + f3 := getData(5) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f3, 15, cb2)).To(Succeed()) // 15 - 20 + Expect(s.Push(f2, 10, cb3)).To(Succeed()) // 10 - 12 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 10: f2, + 15: f3, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 6, End: 10}, + {Start: 12, End: 15}, + {Start: 20, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ---xxx----xxxxxx------- + // ++ + // => + // ---xxx--++xxxxx-------- + It("case 9", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(2) + cb2, t2 := getCallback() + cb3, t3 := getCallback() + f3 := getData(5) + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f3, 10, cb2)).To(Succeed()) // 10 - 15 + Expect(s.Push(f2, 8, cb3)).To(Succeed()) // 8 - 10 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 8: f2, + 10: f3, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 6, End: 8}, + {Start: 15, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ---xxx----=====------- + // +++++++ + // => + // ---xxx++++=====-------- + It("case 10", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(5) + cb2, t2 := getCallback() + f3 := getData(6) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 10, cb2)).To(Succeed()) // 10 - 15 + Expect(s.Push(f3, 5, cb3)).To(Succeed()) // 5 - 11 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 6: f3[1:5], + 10: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 15, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackCalled(t3) + }) + + It("case 10, for long frames", func() { + mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 4)) + f1 := getData(3 * mult) + cb1, t1 := getCallback() + f2 := getData(5 * mult) + cb2, t2 := getCallback() + f3 := getData(6 * mult) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 10*mult, cb2)).To(Succeed()) // 10 - 15 + Expect(s.Push(f3, 5*mult, cb3)).To(Succeed()) // 5 - 11 + checkQueue(map[protocol.ByteCount][]byte{ + 3 * mult: f1, + 6 * mult: f3[mult : 5*mult], + 10 * mult: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3 * mult}, + {Start: 15 * mult, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ---xxxx----=====------- + // ++++++ + // => + // ---xxx++++=====-------- + It("case 11", func() { + f1 := getData(4) + cb1, t1 := getCallback() + f2 := getData(5) + cb2, t2 := getCallback() + f3 := getData(5) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 7 + Expect(s.Push(f2, 10, cb2)).To(Succeed()) // 10 - 15 + Expect(s.Push(f3, 5, cb3)).To(Succeed()) // 5 - 10 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 7: f3[2:], + 10: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 15, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackCalled(t3) + }) + + // ---xxxx----=====------- + // ++++++ + // => + // ---xxx++++=====-------- + It("case 11, for long frames", func() { + mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 3)) + f1 := getData(4 * mult) + cb1, t1 := getCallback() + f2 := getData(5 * mult) + cb2, t2 := getCallback() + f3 := getData(5 * mult) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 7 + Expect(s.Push(f2, 10*mult, cb2)).To(Succeed()) // 10 - 15 + Expect(s.Push(f3, 5*mult, cb3)).To(Succeed()) // 5 - 10 + checkQueue(map[protocol.ByteCount][]byte{ + 3 * mult: f1, + 7 * mult: f3[2*mult:], + 10 * mult: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3 * mult}, + {Start: 15 * mult, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ----xxxx------- + // +++++++ + // => + // ----+++++++----- + It("case 12", func() { + f1 := getData(4) + cb1, t1 := getCallback() + f2 := getData(7) + cb2, t2 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 7 + Expect(s.Push(f2, 3, cb2)).To(Succeed()) // 3 - 10 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 10, End: protocol.MaxByteCount}, + }) + checkCallbackCalled(t1) + checkCallbackNotCalled(t2) + }) + + // ----xxx===------- + // +++++++ + // => + // ----+++++++----- + It("case 13", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + f3 := getData(7) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 9 + Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 10 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f3, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 10, End: protocol.MaxByteCount}, + }) + checkCallbackCalled(t1) + checkCallbackCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ----xxx====------- + // +++++ + // => + // ----+++====----- + It("case 14", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(4) + cb2, t2 := getCallback() + f3 := getData(5) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 10 + Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 8 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f3[:3], + 6: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 10, End: protocol.MaxByteCount}, + }) + checkCallbackCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackCalled(t3) + }) + + It("case 14, for long frames", func() { + mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 3)) + f1 := getData(3 * mult) + cb1, t1 := getCallback() + f2 := getData(4 * mult) + cb2, t2 := getCallback() + f3 := getData(5 * mult) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6*mult, cb2)).To(Succeed()) // 6 - 10 + Expect(s.Push(f3, 3*mult, cb3)).To(Succeed()) // 3 - 8 + checkQueue(map[protocol.ByteCount][]byte{ + 3 * mult: f3[:3*mult], + 6 * mult: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3 * mult}, + {Start: 10 * mult, End: protocol.MaxByteCount}, + }) + checkCallbackCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ----xxx===------- + // ++++++ + // => + // ----++++++----- + It("case 15", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + f3 := getData(6) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 9 + Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 9 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f3, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 9, End: protocol.MaxByteCount}, + }) + checkCallbackCalled(t1) + checkCallbackCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ---xxxx------- + // ++++ + // => + // ---xxxx----- + It("case 16", func() { + f1 := getData(4) + cb1, t1 := getCallback() + f2 := getData(4) + cb2, t2 := getCallback() + Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 9 + Expect(s.Push(f2, 5, cb2)).To(Succeed()) // 5 - 9 + checkQueue(map[protocol.ByteCount][]byte{ + 5: f1, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 5}, + {Start: 9, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + }) + + // ----xxx===------- + // +++ + // => + // ----xxx===----- + It("case 17", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + f3 := getData(3) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 9 + Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 6 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 6: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 9, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackCalled(t3) + }) + + // ---xxxx------- + // ++ + // => + // ---xxxx----- + It("case 18", func() { + f1 := getData(4) + cb1, t1 := getCallback() + f2 := getData(2) + cb2, t2 := getCallback() + Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 9 + Expect(s.Push(f2, 5, cb2)).To(Succeed()) // 5 - 7 + checkQueue(map[protocol.ByteCount][]byte{ + 5: f1, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 5}, + {Start: 9, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + }) + + // ---xxxxx------ + // ++ + // => + // ---xxxxx---- + It("case 19", func() { + f1 := getData(5) + cb1, t1 := getCallback() + f2 := getData(2) + cb2, t2 := getCallback() + Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 10 + checkQueue(map[protocol.ByteCount][]byte{ + 5: f1, + }) + Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 8 + checkQueue(map[protocol.ByteCount][]byte{ + 5: f1, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 5}, + {Start: 10, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + }) + + // xxxxx------ + // ++ + // => + // xxxxx------ + It("case 20", func() { + f1 := getData(10) + cb1, t1 := getCallback() + f2 := getData(4) + cb2, t2 := getCallback() + Expect(s.Push(f1, 0, cb1)).To(Succeed()) // 0 - 10 + Expect(s.Push(f2, 5, cb2)).To(Succeed()) // 5 - 9 + checkQueue(map[protocol.ByteCount][]byte{ + 0: f1, + }) + checkGaps([]utils.ByteInterval{ + {Start: 10, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + }) + + // ---xxxxx--- + // +++ + // => + // ---xxxxx--- + It("case 21", func() { + f1 := getData(5) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 10 + Expect(s.Push(f2, 7, cb2)).To(Succeed()) // 7 - 10 + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 5}, + {Start: 10, End: protocol.MaxByteCount}, + }) + checkQueue(map[protocol.ByteCount][]byte{ + 5: f1, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + }) + + // ----xxx------ + // +++++ + // => + // --+++++---- + It("case 22", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(5) + cb2, t2 := getCallback() + Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 8 + Expect(s.Push(f2, 3, cb2)).To(Succeed()) // 3 - 8 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 8, End: protocol.MaxByteCount}, + }) + checkCallbackCalled(t1) + checkCallbackNotCalled(t2) + }) + + // ----xxx===------ + // ++++++++ + // => + // --++++++++---- + It("case 23", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + f3 := getData(8) + cb3, t3 := getCallback() + Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 8 + Expect(s.Push(f2, 8, cb2)).To(Succeed()) // 8 - 11 + Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 11 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f3, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 11, End: protocol.MaxByteCount}, + }) + checkCallbackCalled(t1) + checkCallbackCalled(t2) + checkCallbackNotCalled(t3) + }) + + // --xxx---===--- + // ++++++ + // => + // --xxx++++++---- + It("case 24", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + f3 := getData(6) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 12 + Expect(s.Push(f3, 6, cb3)).To(Succeed()) // 6 - 12 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 6: f3, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 12, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + checkCallbackNotCalled(t3) + }) + + // --xxx---===---### + // +++++++++ + // => + // --xxx+++++++++### + It("case 25", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + f3 := getData(3) + cb3, t3 := getCallback() + f4 := getData(9) + cb4, t4 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 12 + Expect(s.Push(f3, 15, cb3)).To(Succeed()) // 15 - 18 + Expect(s.Push(f4, 6, cb4)).To(Succeed()) // 6 - 15 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 6: f4, + 15: f3, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 18, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + checkCallbackNotCalled(t3) + checkCallbackNotCalled(t4) + }) + + // ----xxx------ + // +++++++ + // => + // --+++++++--- + It("case 26", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(10) + cb2, t2 := getCallback() + Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 8 + Expect(s.Push(f2, 3, cb2)).To(Succeed()) // 3 - 13 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 13, End: protocol.MaxByteCount}, + }) + checkCallbackCalled(t1) + checkCallbackNotCalled(t2) + }) + + // ---xxx====--- + // ++++ + // => + // --+xxx====--- + It("case 27", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(4) + cb2, t2 := getCallback() + f3 := getData(4) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 10 + Expect(s.Push(f3, 2, cb3)).To(Succeed()) // 2 - 6 + checkQueue(map[protocol.ByteCount][]byte{ + 2: f3[:1], + 3: f1, + 6: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 2}, + {Start: 10, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackCalled(t3) + }) + + It("case 27, for long frames", func() { + const mult = protocol.MinStreamFrameSize + f1 := getData(3 * mult) + cb1, t1 := getCallback() + f2 := getData(4 * mult) + cb2, t2 := getCallback() + f3 := getData(4 * mult) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6*mult, cb2)).To(Succeed()) // 6 - 10 + Expect(s.Push(f3, 2*mult, cb3)).To(Succeed()) // 2 - 6 + checkQueue(map[protocol.ByteCount][]byte{ + 2 * mult: f3[:mult], + 3 * mult: f1, + 6 * mult: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 2 * mult}, + {Start: 10 * mult, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ---xxx====--- + // ++++++ + // => + // --+xxx====--- + It("case 28", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(4) + cb2, t2 := getCallback() + f3 := getData(6) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 10 + Expect(s.Push(f3, 2, cb3)).To(Succeed()) // 2 - 8 + checkQueue(map[protocol.ByteCount][]byte{ + 2: f3[:1], + 3: f1, + 6: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 2}, + {Start: 10, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackCalled(t3) + }) + + It("case 28, for long frames", func() { + const mult = protocol.MinStreamFrameSize + f1 := getData(3 * mult) + cb1, t1 := getCallback() + f2 := getData(4 * mult) + cb2, t2 := getCallback() + f3 := getData(6 * mult) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6*mult, cb2)).To(Succeed()) // 6 - 10 + Expect(s.Push(f3, 2*mult, cb3)).To(Succeed()) // 2 - 8 + checkQueue(map[protocol.ByteCount][]byte{ + 2 * mult: f3[:mult], + 3 * mult: f1, + 6 * mult: f2, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 2 * mult}, + {Start: 10 * mult, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ---xxx===----- + // +++++ + // => + // ---xxx+++++--- + It("case 29", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + f3 := getData(5) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 9 + Expect(s.Push(f3, 6, cb3)).To(Succeed()) // 6 - 11 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 6: f3, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 11, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ---xxx===---- + // ++++++ + // => + // ---xxx===++-- + It("case 30", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + f3 := getData(6) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 9 + Expect(s.Push(f3, 5, cb3)).To(Succeed()) // 5 - 11 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 6: f2, + 9: f3[4:], + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 11, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackCalled(t3) + }) + + It("case 30, for long frames", func() { + mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) + f1 := getData(3 * mult) + cb1, t1 := getCallback() + f2 := getData(3 * mult) + cb2, t2 := getCallback() + f3 := getData(6 * mult) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 6*mult, cb2)).To(Succeed()) // 6 - 9 + Expect(s.Push(f3, 5*mult, cb3)).To(Succeed()) // 5 - 11 + checkQueue(map[protocol.ByteCount][]byte{ + 3 * mult: f1, + 6 * mult: f2, + 9 * mult: f3[4*mult:], + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3 * mult}, + {Start: 11 * mult, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ---xxx---===----- + // ++++++++++ + // => + // ---xxx++++++++--- + It("case 31", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + f3 := getData(10) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 12 + Expect(s.Push(f3, 5, cb3)).To(Succeed()) // 5 - 15 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 6: f3[1:], + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 15, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + checkCallbackCalled(t3) + }) + + It("case 31, for long frames", func() { + mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 9)) + f1 := getData(3 * mult) + cb1, t1 := getCallback() + f2 := getData(3 * mult) + cb2, t2 := getCallback() + f3 := getData(10 * mult) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 9*mult, cb2)).To(Succeed()) // 9 - 12 + Expect(s.Push(f3, 5*mult, cb3)).To(Succeed()) // 5 - 15 + checkQueue(map[protocol.ByteCount][]byte{ + 3 * mult: f1, + 6 * mult: f3[mult:], + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3 * mult}, + {Start: 15 * mult, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ---xxx---===----- + // +++++++++ + // => + // ---+++++++++--- + It("case 32", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + f3 := getData(9) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 12 + Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 12 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f3, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 12, End: protocol.MaxByteCount}, + }) + checkCallbackCalled(t1) + checkCallbackCalled(t2) + checkCallbackNotCalled(t3) + }) + + // ---xxx---===###----- + // ++++++++++++ + // => + // ---xxx++++++++++--- + It("case 33", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(3) + cb2, t2 := getCallback() + f3 := getData(3) + cb3, t3 := getCallback() + f4 := getData(12) + cb4, t4 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 12 + Expect(s.Push(f3, 9, cb3)).To(Succeed()) // 12 - 15 + Expect(s.Push(f4, 5, cb4)).To(Succeed()) // 5 - 17 + checkQueue(map[protocol.ByteCount][]byte{ + 3: f1, + 6: f4[1:], + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 17, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + checkCallbackCalled(t3) + checkCallbackCalled(t4) + }) + + It("case 33, for long frames", func() { + mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 11)) + f1 := getData(3 * mult) + cb1, t1 := getCallback() + f2 := getData(3 * mult) + cb2, t2 := getCallback() + f3 := getData(3 * mult) + cb3, t3 := getCallback() + f4 := getData(12 * mult) + cb4, t4 := getCallback() + Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 9*mult, cb2)).To(Succeed()) // 9 - 12 + Expect(s.Push(f3, 9*mult, cb3)).To(Succeed()) // 12 - 15 + Expect(s.Push(f4, 5*mult, cb4)).To(Succeed()) // 5 - 17 + checkQueue(map[protocol.ByteCount][]byte{ + 3 * mult: f1, + 6 * mult: f4[mult:], + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3 * mult}, + {Start: 17 * mult, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + checkCallbackCalled(t3) + checkCallbackNotCalled(t4) + }) + + // ---xxx===---### + // ++++++ + // => + // ---xxx++++++### + It("case 34", func() { + f1 := getData(5) + cb1, t1 := getCallback() + f2 := getData(5) + cb2, t2 := getCallback() + f3 := getData(10) + cb3, t3 := getCallback() + f4 := getData(5) + cb4, t4 := getCallback() + Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 10 + Expect(s.Push(f2, 10, cb2)).To(Succeed()) // 10 - 15 + Expect(s.Push(f4, 20, cb3)).To(Succeed()) // 20 - 25 + Expect(s.Push(f3, 10, cb4)).To(Succeed()) // 10 - 20 + checkQueue(map[protocol.ByteCount][]byte{ + 5: f1, + 10: f3, + 20: f4, + }) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 5}, + {Start: 25, End: protocol.MaxByteCount}, + }) + checkCallbackNotCalled(t1) + checkCallbackCalled(t2) + checkCallbackNotCalled(t3) + checkCallbackNotCalled(t4) + }) + + // ---xxx---####--- + // ++++++++ + // => + // ---++++++####--- + It("case 35", func() { + f1 := getData(3) + cb1, t1 := getCallback() + f2 := getData(4) + cb2, t2 := getCallback() + f3 := getData(8) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 13 + Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 11 + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3}, + {Start: 13, End: protocol.MaxByteCount}, + }) + checkQueue(map[protocol.ByteCount][]byte{ + 3: f3[:6], + 9: f2, + }) + checkCallbackCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackCalled(t3) + }) + + It("case 35, for long frames", func() { + mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 6)) + f1 := getData(3 * mult) + cb1, t1 := getCallback() + f2 := getData(4 * mult) + cb2, t2 := getCallback() + f3 := getData(8 * mult) + cb3, t3 := getCallback() + Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 + Expect(s.Push(f2, 9*mult, cb2)).To(Succeed()) // 9 - 13 + Expect(s.Push(f3, 3*mult, cb3)).To(Succeed()) // 3 - 11 + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 3 * mult}, + {Start: 13 * mult, End: protocol.MaxByteCount}, + }) + checkQueue(map[protocol.ByteCount][]byte{ + 3 * mult: f3[:6*mult], + 9 * mult: f2, + }) + checkCallbackCalled(t1) + checkCallbackNotCalled(t2) + checkCallbackNotCalled(t3) + }) + + Context("receiving data after reads", func() { + It("ignores duplicate frames", func() { + Expect(s.Push([]byte("foobar"), 0, nil)).To(Succeed()) + offset, data, _ := s.Pop() + Expect(offset).To(BeZero()) + Expect(data).To(Equal([]byte("foobar"))) + // now receive the duplicate + Expect(s.Push([]byte("foobar"), 0, nil)).To(Succeed()) + Expect(s.queue).To(BeEmpty()) + checkGaps([]utils.ByteInterval{ + {Start: 6, End: protocol.MaxByteCount}, + }) + }) + + It("ignores parts of frames that have already been read", func() { + Expect(s.Push([]byte("foo"), 0, nil)).To(Succeed()) + offset, data, _ := s.Pop() + Expect(offset).To(BeZero()) + Expect(data).To(Equal([]byte("foo"))) + // now receive the duplicate + Expect(s.Push([]byte("foobar"), 0, nil)).To(Succeed()) + offset, data, _ = s.Pop() + Expect(offset).To(Equal(protocol.ByteCount(3))) + Expect(data).To(Equal([]byte("bar"))) + Expect(s.queue).To(BeEmpty()) + checkGaps([]utils.ByteInterval{ + {Start: 6, End: protocol.MaxByteCount}, + }) + }) + }) + + Context("DoS protection", func() { + It("errors when too many gaps are created", func() { + for i := 0; i < protocol.MaxStreamFrameSorterGaps; i++ { + Expect(s.Push([]byte("foobar"), protocol.ByteCount(i*7), nil)).To(Succeed()) + } + Expect(s.gaps.Len()).To(Equal(protocol.MaxStreamFrameSorterGaps)) + err := s.Push([]byte("foobar"), protocol.ByteCount(protocol.MaxStreamFrameSorterGaps*7)+100, nil) + Expect(err).To(MatchError("too many gaps in received data")) + }) + }) + }) + + Context("stress testing", func() { + type frame struct { + offset protocol.ByteCount + data []byte + } + + for _, lf := range []bool{true, false} { + longFrames := lf + + const num = 1000 + + name := "short" + if longFrames { + name = "long" + } + + Context(fmt.Sprintf("using %s frames", name), func() { + var data []byte + var dataLen protocol.ByteCount + var callbacks []callbackTracker + + BeforeEach(func() { + seed := time.Now().UnixNano() + fmt.Fprintf(GinkgoWriter, "Seed: %d\n", seed) + rand.Seed(seed) + + callbacks = nil + dataLen = 25 + if longFrames { + dataLen = 2 * protocol.MinStreamFrameSize + } + + data = make([]byte, num*dataLen) + for i := 0; i < num; i++ { + for j := protocol.ByteCount(0); j < dataLen; j++ { + data[protocol.ByteCount(i)*dataLen+j] = uint8(i) + } + } + }) + + getRandomFrames := func() []frame { + frames := make([]frame, num) + for i := protocol.ByteCount(0); i < num; i++ { + b := make([]byte, dataLen) + Expect(copy(b, data[i*dataLen:])).To(BeEquivalentTo(dataLen)) + frames[i] = frame{ + offset: i * dataLen, + data: b, + } + } + rand.Shuffle(len(frames), func(i, j int) { frames[i], frames[j] = frames[j], frames[i] }) + return frames + } + + getData := func() []byte { + var data []byte + for { + offset, b, cb := s.Pop() + if b == nil { + break + } + Expect(offset).To(BeEquivalentTo(len(data))) + data = append(data, b...) + if cb != nil { + cb() + } + } + return data + } + + // push pushes data to the frame sorter + // It creates a new callback and adds the + push := func(data []byte, offset protocol.ByteCount) { + cb, t := getCallback() + ExpectWithOffset(1, s.Push(data, offset, cb)).To(Succeed()) + callbacks = append(callbacks, t) + } + + checkCallbacks := func() { + ExpectWithOffset(1, callbacks).ToNot(BeEmpty()) + for _, t := range callbacks { + checkCallbackCalled(t) + } + } + + It("inserting frames in a random order", func() { + frames := getRandomFrames() + + for _, f := range frames { + push(f.data, f.offset) + } + checkGaps([]utils.ByteInterval{{Start: num * dataLen, End: protocol.MaxByteCount}}) + + Expect(getData()).To(Equal(data)) + Expect(s.queue).To(BeEmpty()) + checkCallbacks() + }) + + It("inserting frames in a random order, with some duplicates", func() { + frames := getRandomFrames() + + for _, f := range frames { + push(f.data, f.offset) + if rand.Intn(10) < 5 { + df := frames[rand.Intn(len(frames))] + push(df.data, df.offset) + } + } + checkGaps([]utils.ByteInterval{{Start: num * dataLen, End: protocol.MaxByteCount}}) + + Expect(getData()).To(Equal(data)) + Expect(s.queue).To(BeEmpty()) + checkCallbacks() + }) + + It("inserting frames in a random order, with randomly cut retransmissions", func() { + frames := getRandomFrames() + + for _, f := range frames { + push(f.data, f.offset) + if rand.Intn(10) < 5 { + length := protocol.ByteCount(1 + rand.Intn(int(4*dataLen))) + if length >= num*dataLen { + length = num*dataLen - 1 + } + b := make([]byte, length) + offset := protocol.ByteCount(rand.Intn(int(num*dataLen - length))) + Expect(copy(b, data[offset:offset+length])).To(BeEquivalentTo(length)) + push(b, offset) + } + } + checkGaps([]utils.ByteInterval{{Start: num * dataLen, End: protocol.MaxByteCount}}) + + Expect(getData()).To(Equal(data)) + Expect(s.queue).To(BeEmpty()) + checkCallbacks() + }) + }) + } + }) +}) diff --git a/internal/quic-go/framer.go b/internal/quic-go/framer.go new file mode 100644 index 00000000..db989480 --- /dev/null +++ b/internal/quic-go/framer.go @@ -0,0 +1,171 @@ +package quic + +import ( + "errors" + "sync" + + "github.com/imroc/req/v3/internal/quic-go/ackhandler" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type framer interface { + HasData() bool + + QueueControlFrame(wire.Frame) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) + + AddActiveStream(protocol.StreamID) + AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) + + Handle0RTTRejection() error +} + +type framerI struct { + mutex sync.Mutex + + streamGetter streamGetter + version protocol.VersionNumber + + activeStreams map[protocol.StreamID]struct{} + streamQueue []protocol.StreamID + + controlFrameMutex sync.Mutex + controlFrames []wire.Frame +} + +var _ framer = &framerI{} + +func newFramer( + streamGetter streamGetter, + v protocol.VersionNumber, +) framer { + return &framerI{ + streamGetter: streamGetter, + activeStreams: make(map[protocol.StreamID]struct{}), + version: v, + } +} + +func (f *framerI) HasData() bool { + f.mutex.Lock() + hasData := len(f.streamQueue) > 0 + f.mutex.Unlock() + if hasData { + return true + } + f.controlFrameMutex.Lock() + hasData = len(f.controlFrames) > 0 + f.controlFrameMutex.Unlock() + return hasData +} + +func (f *framerI) QueueControlFrame(frame wire.Frame) { + f.controlFrameMutex.Lock() + f.controlFrames = append(f.controlFrames, frame) + f.controlFrameMutex.Unlock() +} + +func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + var length protocol.ByteCount + f.controlFrameMutex.Lock() + for len(f.controlFrames) > 0 { + frame := f.controlFrames[len(f.controlFrames)-1] + frameLen := frame.Length(f.version) + if length+frameLen > maxLen { + break + } + frames = append(frames, ackhandler.Frame{Frame: frame}) + length += frameLen + f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] + } + f.controlFrameMutex.Unlock() + return frames, length +} + +func (f *framerI) AddActiveStream(id protocol.StreamID) { + f.mutex.Lock() + if _, ok := f.activeStreams[id]; !ok { + f.streamQueue = append(f.streamQueue, id) + f.activeStreams[id] = struct{}{} + } + f.mutex.Unlock() +} + +func (f *framerI) AppendStreamFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + var length protocol.ByteCount + var lastFrame *ackhandler.Frame + f.mutex.Lock() + // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet + numActiveStreams := len(f.streamQueue) + for i := 0; i < numActiveStreams; i++ { + if protocol.MinStreamFrameSize+length > maxLen { + break + } + id := f.streamQueue[0] + f.streamQueue = f.streamQueue[1:] + // This should never return an error. Better check it anyway. + // The stream will only be in the streamQueue, if it enqueued itself there. + str, err := f.streamGetter.GetOrOpenSendStream(id) + // The stream can be nil if it completed after it said it had data. + if str == nil || err != nil { + delete(f.activeStreams, id) + continue + } + remainingLen := maxLen - length + // For the last STREAM frame, we'll remove the DataLen field later. + // Therefore, we can pretend to have more bytes available when popping + // the STREAM frame (which will always have the DataLen set). + remainingLen += quicvarint.Len(uint64(remainingLen)) + frame, hasMoreData := str.popStreamFrame(remainingLen) + if hasMoreData { // put the stream back in the queue (at the end) + f.streamQueue = append(f.streamQueue, id) + } else { // no more data to send. Stream is not active any more + delete(f.activeStreams, id) + } + // The frame can be nil + // * if the receiveStream was canceled after it said it had data + // * the remaining size doesn't allow us to add another STREAM frame + if frame == nil { + continue + } + frames = append(frames, *frame) + length += frame.Length(f.version) + lastFrame = frame + } + f.mutex.Unlock() + if lastFrame != nil { + lastFrameLen := lastFrame.Length(f.version) + // account for the smaller size of the last STREAM frame + lastFrame.Frame.(*wire.StreamFrame).DataLenPresent = false + length += lastFrame.Length(f.version) - lastFrameLen + } + return frames, length +} + +func (f *framerI) Handle0RTTRejection() error { + f.mutex.Lock() + defer f.mutex.Unlock() + + f.controlFrameMutex.Lock() + f.streamQueue = f.streamQueue[:0] + for id := range f.activeStreams { + delete(f.activeStreams, id) + } + var j int + for i, frame := range f.controlFrames { + switch frame.(type) { + case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame: + return errors.New("didn't expect MAX_DATA / MAX_STREAM_DATA / MAX_STREAMS frame to be sent in 0-RTT") + case *wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame: + continue + default: + f.controlFrames[j] = f.controlFrames[i] + j++ + } + } + f.controlFrames = f.controlFrames[:j] + f.controlFrameMutex.Unlock() + return nil +} diff --git a/internal/quic-go/framer_test.go b/internal/quic-go/framer_test.go new file mode 100644 index 00000000..201053f3 --- /dev/null +++ b/internal/quic-go/framer_test.go @@ -0,0 +1,385 @@ +package quic + +import ( + "bytes" + "math/rand" + + "github.com/imroc/req/v3/internal/quic-go/ackhandler" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Framer", func() { + const ( + id1 = protocol.StreamID(10) + id2 = protocol.StreamID(11) + ) + + var ( + framer framer + stream1, stream2 *MockSendStreamI + streamGetter *MockStreamGetter + version protocol.VersionNumber + ) + + BeforeEach(func() { + streamGetter = NewMockStreamGetter(mockCtrl) + stream1 = NewMockSendStreamI(mockCtrl) + stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes() + stream2 = NewMockSendStreamI(mockCtrl) + stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() + framer = newFramer(streamGetter, version) + }) + + Context("handling control frames", func() { + It("adds control frames", func() { + mdf := &wire.MaxDataFrame{MaximumData: 0x42} + msf := &wire.MaxStreamsFrame{MaxStreamNum: 0x1337} + framer.QueueControlFrame(mdf) + framer.QueueControlFrame(msf) + frames, length := framer.AppendControlFrames(nil, 1000) + Expect(frames).To(HaveLen(2)) + fs := []wire.Frame{frames[0].Frame, frames[1].Frame} + Expect(fs).To(ContainElement(mdf)) + Expect(fs).To(ContainElement(msf)) + Expect(length).To(Equal(mdf.Length(version) + msf.Length(version))) + }) + + It("says if it has data", func() { + Expect(framer.HasData()).To(BeFalse()) + f := &wire.MaxDataFrame{MaximumData: 0x42} + framer.QueueControlFrame(f) + Expect(framer.HasData()).To(BeTrue()) + frames, _ := framer.AppendControlFrames(nil, 1000) + Expect(frames).To(HaveLen(1)) + Expect(framer.HasData()).To(BeFalse()) + }) + + It("appends to the slice given", func() { + ping := &wire.PingFrame{} + mdf := &wire.MaxDataFrame{MaximumData: 0x42} + framer.QueueControlFrame(mdf) + frames, length := framer.AppendControlFrames([]ackhandler.Frame{{Frame: ping}}, 1000) + Expect(frames).To(HaveLen(2)) + Expect(frames[0].Frame).To(Equal(ping)) + Expect(frames[1].Frame).To(Equal(mdf)) + Expect(length).To(Equal(mdf.Length(version))) + }) + + It("adds the right number of frames", func() { + maxSize := protocol.ByteCount(1000) + bf := &wire.DataBlockedFrame{MaximumData: 0x1337} + bfLen := bf.Length(version) + numFrames := int(maxSize / bfLen) // max number of frames that fit into maxSize + for i := 0; i < numFrames+1; i++ { + framer.QueueControlFrame(bf) + } + frames, length := framer.AppendControlFrames(nil, maxSize) + Expect(frames).To(HaveLen(numFrames)) + Expect(length).To(BeNumerically(">", maxSize-bfLen)) + frames, length = framer.AppendControlFrames(nil, maxSize) + Expect(frames).To(HaveLen(1)) + Expect(length).To(Equal(bfLen)) + }) + + It("drops *_BLOCKED frames when 0-RTT is rejected", func() { + ping := &wire.PingFrame{} + ncid := &wire.NewConnectionIDFrame{SequenceNumber: 10, ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}} + frames := []wire.Frame{ + &wire.DataBlockedFrame{MaximumData: 1337}, + &wire.StreamDataBlockedFrame{StreamID: 42, MaximumStreamData: 1337}, + &wire.StreamsBlockedFrame{StreamLimit: 13}, + ping, + ncid, + } + rand.Shuffle(len(frames), func(i, j int) { frames[i], frames[j] = frames[j], frames[i] }) + for _, f := range frames { + framer.QueueControlFrame(f) + } + Expect(framer.Handle0RTTRejection()).To(Succeed()) + fs, length := framer.AppendControlFrames(nil, protocol.MaxByteCount) + Expect(fs).To(HaveLen(2)) + Expect(length).To(Equal(ping.Length(version) + ncid.Length(version))) + }) + }) + + Context("popping STREAM frames", func() { + It("returns nil when popping an empty framer", func() { + Expect(framer.AppendStreamFrames(nil, 1000)).To(BeEmpty()) + }) + + It("returns STREAM frames", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + f := &wire.StreamFrame{ + StreamID: id1, + Data: []byte("foobar"), + Offset: 42, + DataLenPresent: true, + } + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) + framer.AddActiveStream(id1) + fs, length := framer.AppendStreamFrames(nil, 1000) + Expect(fs).To(HaveLen(1)) + Expect(fs[0].Frame.(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + Expect(length).To(Equal(f.Length(version))) + }) + + It("says if it has data", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2) + Expect(framer.HasData()).To(BeFalse()) + framer.AddActiveStream(id1) + Expect(framer.HasData()).To(BeTrue()) + f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foo")} + f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("bar")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f1}, true) + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f2}, false) + frames, _ := framer.AppendStreamFrames(nil, protocol.MaxByteCount) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f1)) + Expect(framer.HasData()).To(BeTrue()) + frames, _ = framer.AppendStreamFrames(nil, protocol.MaxByteCount) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f2)) + Expect(framer.HasData()).To(BeFalse()) + }) + + It("appends to a frame slice", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + f := &wire.StreamFrame{ + StreamID: id1, + Data: []byte("foobar"), + DataLenPresent: true, + } + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) + framer.AddActiveStream(id1) + mdf := &wire.MaxDataFrame{MaximumData: 1337} + frames := []ackhandler.Frame{{Frame: mdf}} + fs, length := framer.AppendStreamFrames(frames, 1000) + Expect(fs).To(HaveLen(2)) + Expect(fs[0].Frame).To(Equal(mdf)) + Expect(fs[1].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + Expect(fs[1].Frame.(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + Expect(length).To(Equal(f.Length(version))) + }) + + It("skips a stream that was reported active, but was completed shortly after", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(nil, nil) + streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) + f := &wire.StreamFrame{ + StreamID: id2, + Data: []byte("foobar"), + DataLenPresent: true, + } + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) + framer.AddActiveStream(id1) + framer.AddActiveStream(id2) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f)) + }) + + It("skips a stream that was reported active, but doesn't have any data", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) + f := &wire.StreamFrame{ + StreamID: id2, + Data: []byte("foobar"), + DataLenPresent: true, + } + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(nil, false) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) + framer.AddActiveStream(id1) + framer.AddActiveStream(id2) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f)) + }) + + It("pops from a stream multiple times, if it has enough data", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2) + f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} + f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f1}, true) + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f2}, false) + framer.AddActiveStream(id1) // only add it once + frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f1)) + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f2)) + // no further calls to popStreamFrame, after popStreamFrame said there's no more data + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(BeNil()) + }) + + It("re-queues a stream at the end, if it has enough data", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2) + streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) + f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} + f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")} + f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f11}, true) + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f12}, false) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f2}, false) + framer.AddActiveStream(id1) // only add it once + framer.AddActiveStream(id2) + // first a frame from stream 1 + frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f11)) + // then a frame from stream 2 + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f2)) + // then another frame from stream 1 + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f12)) + }) + + It("only dequeues data from each stream once per packet", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) + f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} + f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")} + // both streams have more data, and will be re-queued + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f1}, true) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f2}, true) + framer.AddActiveStream(id1) + framer.AddActiveStream(id2) + frames, length := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(HaveLen(2)) + Expect(frames[0].Frame).To(Equal(f1)) + Expect(frames[1].Frame).To(Equal(f2)) + Expect(length).To(Equal(f1.Length(version) + f2.Length(version))) + }) + + It("returns multiple normal frames in the order they were reported active", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) + f1 := &wire.StreamFrame{Data: []byte("foobar")} + f2 := &wire.StreamFrame{Data: []byte("foobaz")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f1}, false) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f2}, false) + framer.AddActiveStream(id2) + framer.AddActiveStream(id1) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(HaveLen(2)) + Expect(frames[0].Frame).To(Equal(f2)) + Expect(frames[1].Frame).To(Equal(f1)) + }) + + It("only asks a stream for data once, even if it was reported active multiple times", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + f := &wire.StreamFrame{Data: []byte("foobar")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) // only one call to this function + framer.AddActiveStream(id1) + framer.AddActiveStream(id1) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(HaveLen(1)) + }) + + It("does not pop empty frames", func() { + fs, length := framer.AppendStreamFrames(nil, 500) + Expect(fs).To(BeEmpty()) + Expect(length).To(BeZero()) + }) + + It("pops maximum size STREAM frames", func() { + for i := protocol.MinStreamFrameSize; i < 2000; i++ { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + stream1.EXPECT().popStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) (*ackhandler.Frame, bool) { + f := &wire.StreamFrame{ + StreamID: id1, + DataLenPresent: true, + } + f.Data = make([]byte, f.MaxDataLen(size, version)) + Expect(f.Length(version)).To(Equal(size)) + return &ackhandler.Frame{Frame: f}, false + }) + framer.AddActiveStream(id1) + frames, _ := framer.AppendStreamFrames(nil, i) + Expect(frames).To(HaveLen(1)) + f := frames[0].Frame.(*wire.StreamFrame) + Expect(f.DataLenPresent).To(BeFalse()) + Expect(f.Length(version)).To(Equal(i)) + } + }) + + It("pops multiple STREAM frames", func() { + for i := 2 * protocol.MinStreamFrameSize; i < 2000; i++ { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) + stream1.EXPECT().popStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) (*ackhandler.Frame, bool) { + f := &wire.StreamFrame{ + StreamID: id2, + DataLenPresent: true, + } + f.Data = make([]byte, f.MaxDataLen(protocol.MinStreamFrameSize, version)) + return &ackhandler.Frame{Frame: f}, false + }) + stream2.EXPECT().popStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) (*ackhandler.Frame, bool) { + f := &wire.StreamFrame{ + StreamID: id2, + DataLenPresent: true, + } + f.Data = make([]byte, f.MaxDataLen(size, version)) + Expect(f.Length(version)).To(Equal(size)) + return &ackhandler.Frame{Frame: f}, false + }) + framer.AddActiveStream(id1) + framer.AddActiveStream(id2) + frames, _ := framer.AppendStreamFrames(nil, i) + Expect(frames).To(HaveLen(2)) + f1 := frames[0].Frame.(*wire.StreamFrame) + f2 := frames[1].Frame.(*wire.StreamFrame) + Expect(f1.DataLenPresent).To(BeTrue()) + Expect(f2.DataLenPresent).To(BeFalse()) + Expect(f1.Length(version) + f2.Length(version)).To(Equal(i)) + } + }) + + It("pops frames that when asked for the the minimum STREAM frame size", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + f := &wire.StreamFrame{Data: []byte("foobar")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) + framer.AddActiveStream(id1) + framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + }) + + It("does not pop frames smaller than the minimum size", func() { + // don't expect a call to PopStreamFrame() + framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize-1) + }) + + It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + // pop a frame such that the remaining size is one byte less than the minimum STREAM frame size + f := &wire.StreamFrame{ + StreamID: id1, + Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)), + DataLenPresent: true, + } + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) + framer.AddActiveStream(id1) + fs, length := framer.AppendStreamFrames(nil, 500) + Expect(fs).To(HaveLen(1)) + Expect(fs[0].Frame).To(Equal(f)) + Expect(length).To(Equal(f.Length(version))) + }) + + It("drops all STREAM frames when 0-RTT is rejected", func() { + framer.AddActiveStream(id1) + Expect(framer.Handle0RTTRejection()).To(Succeed()) + fs, length := framer.AppendStreamFrames(nil, protocol.MaxByteCount) + Expect(fs).To(BeEmpty()) + Expect(length).To(BeZero()) + }) + }) +}) diff --git a/internal/quic-go/handshake/aead.go b/internal/quic-go/handshake/aead.go new file mode 100644 index 00000000..e4e76cba --- /dev/null +++ b/internal/quic-go/handshake/aead.go @@ -0,0 +1,161 @@ +package handshake + +import ( + "crypto/cipher" + "encoding/binary" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qtls" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD { + keyLabel := hkdfLabelKeyV1 + ivLabel := hkdfLabelIVV1 + if v == protocol.Version2 { + keyLabel = hkdfLabelKeyV2 + ivLabel = hkdfLabelIVV2 + } + key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen) + iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen()) + return suite.AEAD(key, iv) +} + +type longHeaderSealer struct { + aead cipher.AEAD + headerProtector headerProtector + + // use a single slice to avoid allocations + nonceBuf []byte +} + +var _ LongHeaderSealer = &longHeaderSealer{} + +func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer { + return &longHeaderSealer{ + aead: aead, + headerProtector: headerProtector, + nonceBuf: make([]byte, aead.NonceSize()), + } +} + +func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + return s.aead.Seal(dst, s.nonceBuf, src, ad) +} + +func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + s.headerProtector.EncryptHeader(sample, firstByte, pnBytes) +} + +func (s *longHeaderSealer) Overhead() int { + return s.aead.Overhead() +} + +type longHeaderOpener struct { + aead cipher.AEAD + headerProtector headerProtector + highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) + + // use a single slice to avoid allocations + nonceBuf []byte +} + +var _ LongHeaderOpener = &longHeaderOpener{} + +func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener { + return &longHeaderOpener{ + aead: aead, + headerProtector: headerProtector, + nonceBuf: make([]byte, aead.NonceSize()), + } +} + +func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { + return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN) +} + +func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { + binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + dec, err := o.aead.Open(dst, o.nonceBuf, src, ad) + if err == nil { + o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn) + } else { + err = ErrDecryptionFailed + } + return dec, err +} + +func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + o.headerProtector.DecryptHeader(sample, firstByte, pnBytes) +} + +type handshakeSealer struct { + LongHeaderSealer + + dropInitialKeys func() + dropped bool +} + +func newHandshakeSealer( + aead cipher.AEAD, + headerProtector headerProtector, + dropInitialKeys func(), + perspective protocol.Perspective, +) LongHeaderSealer { + sealer := newLongHeaderSealer(aead, headerProtector) + // The client drops Initial keys when sending the first Handshake packet. + if perspective == protocol.PerspectiveServer { + return sealer + } + return &handshakeSealer{ + LongHeaderSealer: sealer, + dropInitialKeys: dropInitialKeys, + } +} + +func (s *handshakeSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + data := s.LongHeaderSealer.Seal(dst, src, pn, ad) + if !s.dropped { + s.dropInitialKeys() + s.dropped = true + } + return data +} + +type handshakeOpener struct { + LongHeaderOpener + + dropInitialKeys func() + dropped bool +} + +func newHandshakeOpener( + aead cipher.AEAD, + headerProtector headerProtector, + dropInitialKeys func(), + perspective protocol.Perspective, +) LongHeaderOpener { + opener := newLongHeaderOpener(aead, headerProtector) + // The server drops Initial keys when first successfully processing a Handshake packet. + if perspective == protocol.PerspectiveClient { + return opener + } + return &handshakeOpener{ + LongHeaderOpener: opener, + dropInitialKeys: dropInitialKeys, + } +} + +func (o *handshakeOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { + dec, err := o.LongHeaderOpener.Open(dst, src, pn, ad) + if err == nil && !o.dropped { + o.dropInitialKeys() + o.dropped = true + } + return dec, err +} diff --git a/internal/quic-go/handshake/aead_test.go b/internal/quic-go/handshake/aead_test.go new file mode 100644 index 00000000..672d60f0 --- /dev/null +++ b/internal/quic-go/handshake/aead_test.go @@ -0,0 +1,204 @@ +package handshake + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/tls" + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Long Header AEAD", func() { + for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { + v := ver + + Context(fmt.Sprintf("using version %s", v), func() { + for i := range cipherSuites { + cs := cipherSuites[i] + + Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { + getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { + key := make([]byte, 16) + hpKey := make([]byte, 16) + rand.Read(key) + rand.Read(hpKey) + block, err := aes.NewCipher(key) + Expect(err).ToNot(HaveOccurred()) + aead, err := cipher.NewGCM(block) + Expect(err).ToNot(HaveOccurred()) + + return newLongHeaderSealer(aead, newHeaderProtector(cs, hpKey, true, v)), + newLongHeaderOpener(aead, newHeaderProtector(cs, hpKey, true, v)) + } + + Context("message encryption", func() { + msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad := []byte("Donec in velit neque.") + + It("encrypts and decrypts a message", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + opened, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + + It("fails to open a message if the associated data is not the same", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("fails to open a message if the packet number is not the same", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x42, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("decodes the packet number", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) + }) + + It("ignores packets it can't decrypt for packet number derivation", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted[:len(encrypted)-1], 0x1337, ad) + Expect(err).To(HaveOccurred()) + Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) + }) + }) + + Context("header encryption", func() { + It("encrypts and encrypts the header", func() { + sealer, opener := getSealerAndOpener() + var lastFourBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sealer.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0xf != 0xb5&0xf { + lastFourBitsDifferent++ + } + Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) + }) + + It("encrypts and encrypts the header, for a 0xfff..fff sample", func() { + sealer, opener := getSealerAndOpener() + var lastFourBitsDifferent int + for i := 0; i < 100; i++ { + sample := bytes.Repeat([]byte{0xff}, 16) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sealer.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0xf != 0xb5&0xf { + lastFourBitsDifferent++ + } + Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + }) + + It("fails to decrypt the header when using a different sample", func() { + sealer, opener := getSealerAndOpener() + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sample := make([]byte, 16) + rand.Read(sample) + sealer.EncryptHeader(sample, &header[0], header[9:13]) + rand.Read(sample) // use a different sample + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + }) + }) + }) + } + }) + + Describe("Long Header AEAD", func() { + var ( + dropped chan struct{} // use a chan because closing it twice will panic + aead cipher.AEAD + hp headerProtector + ) + dropCb := func() { close(dropped) } + msg := []byte("Lorem ipsum dolor sit amet.") + ad := []byte("Donec in velit neque.") + + BeforeEach(func() { + dropped = make(chan struct{}) + key := make([]byte, 16) + hpKey := make([]byte, 16) + rand.Read(key) + rand.Read(hpKey) + block, err := aes.NewCipher(key) + Expect(err).ToNot(HaveOccurred()) + aead, err = cipher.NewGCM(block) + Expect(err).ToNot(HaveOccurred()) + hp = newHeaderProtector(cipherSuites[0], hpKey, true, protocol.Version1) + }) + + Context("for the server", func() { + It("drops keys when first successfully processing a Handshake packet", func() { + serverOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveServer) + // first try to open an invalid message + _, err := serverOpener.Open(nil, []byte("invalid"), 0, []byte("invalid")) + Expect(err).To(HaveOccurred()) + Expect(dropped).ToNot(BeClosed()) + // then open a valid message + enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 10, ad) + _, err = serverOpener.Open(nil, enc, 10, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(dropped).To(BeClosed()) + // now open the same message again to make sure the callback is only called once + _, err = serverOpener.Open(nil, enc, 10, ad) + Expect(err).ToNot(HaveOccurred()) + }) + + It("doesn't drop keys when sealing a Handshake packet", func() { + serverSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveServer) + serverSealer.Seal(nil, msg, 1, ad) + Expect(dropped).ToNot(BeClosed()) + }) + }) + + Context("for the client", func() { + It("drops keys when first sealing a Handshake packet", func() { + clientSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveClient) + // seal the first message + clientSealer.Seal(nil, msg, 1, ad) + Expect(dropped).To(BeClosed()) + // seal another message to make sure the callback is only called once + clientSealer.Seal(nil, msg, 2, ad) + }) + + It("doesn't drop keys when processing a Handshake packet", func() { + enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 42, ad) + clientOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveClient) + _, err := clientOpener.Open(nil, enc, 42, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(dropped).ToNot(BeClosed()) + }) + }) + }) + } +}) diff --git a/internal/quic-go/handshake/crypto_setup.go b/internal/quic-go/handshake/crypto_setup.go new file mode 100644 index 00000000..1ef63cd0 --- /dev/null +++ b/internal/quic-go/handshake/crypto_setup.go @@ -0,0 +1,819 @@ +package handshake + +import ( + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/qtls" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// TLS unexpected_message alert +const alertUnexpectedMessage uint8 = 10 + +type messageType uint8 + +// TLS handshake message types. +const ( + typeClientHello messageType = 1 + typeServerHello messageType = 2 + typeNewSessionTicket messageType = 4 + typeEncryptedExtensions messageType = 8 + typeCertificate messageType = 11 + typeCertificateRequest messageType = 13 + typeCertificateVerify messageType = 15 + typeFinished messageType = 20 +) + +func (m messageType) String() string { + switch m { + case typeClientHello: + return "ClientHello" + case typeServerHello: + return "ServerHello" + case typeNewSessionTicket: + return "NewSessionTicket" + case typeEncryptedExtensions: + return "EncryptedExtensions" + case typeCertificate: + return "Certificate" + case typeCertificateRequest: + return "CertificateRequest" + case typeCertificateVerify: + return "CertificateVerify" + case typeFinished: + return "Finished" + default: + return fmt.Sprintf("unknown message type: %d", m) + } +} + +const clientSessionStateRevision = 3 + +type conn struct { + localAddr, remoteAddr net.Addr + version protocol.VersionNumber +} + +var _ ConnWithVersion = &conn{} + +func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion { + return &conn{ + localAddr: local, + remoteAddr: remote, + version: version, + } +} + +var _ net.Conn = &conn{} + +func (c *conn) Read([]byte) (int, error) { return 0, nil } +func (c *conn) Write([]byte) (int, error) { return 0, nil } +func (c *conn) Close() error { return nil } +func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *conn) LocalAddr() net.Addr { return c.localAddr } +func (c *conn) SetReadDeadline(time.Time) error { return nil } +func (c *conn) SetWriteDeadline(time.Time) error { return nil } +func (c *conn) SetDeadline(time.Time) error { return nil } +func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version } + +type cryptoSetup struct { + tlsConf *tls.Config + extraConf *qtls.ExtraConfig + conn *qtls.Conn + + version protocol.VersionNumber + + messageChan chan []byte + isReadingHandshakeMessage chan struct{} + readFirstHandshakeMessage bool + + ourParams *wire.TransportParameters + peerParams *wire.TransportParameters + paramsChan <-chan []byte + + runner handshakeRunner + + alertChan chan uint8 + // handshakeDone is closed as soon as the go routine running qtls.Handshake() returns + handshakeDone chan struct{} + // is closed when Close() is called + closeChan chan struct{} + + zeroRTTParameters *wire.TransportParameters + clientHelloWritten bool + clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written + zeroRTTParametersChan chan<- *wire.TransportParameters + + rttStats *utils.RTTStats + + tracer logging.ConnectionTracer + logger utils.Logger + + perspective protocol.Perspective + + mutex sync.Mutex // protects all members below + + handshakeCompleteTime time.Time + + readEncLevel protocol.EncryptionLevel + writeEncLevel protocol.EncryptionLevel + + zeroRTTOpener LongHeaderOpener // only set for the server + zeroRTTSealer LongHeaderSealer // only set for the client + + initialStream io.Writer + initialOpener LongHeaderOpener + initialSealer LongHeaderSealer + + handshakeStream io.Writer + handshakeOpener LongHeaderOpener + handshakeSealer LongHeaderSealer + + aead *updatableAEAD + has1RTTSealer bool + has1RTTOpener bool +} + +var ( + _ qtls.RecordLayer = &cryptoSetup{} + _ CryptoSetup = &cryptoSetup{} +) + +// NewCryptoSetupClient creates a new crypto setup for the client +func NewCryptoSetupClient( + initialStream io.Writer, + handshakeStream io.Writer, + connID protocol.ConnectionID, + localAddr net.Addr, + remoteAddr net.Addr, + tp *wire.TransportParameters, + runner handshakeRunner, + tlsConf *tls.Config, + enable0RTT bool, + rttStats *utils.RTTStats, + tracer logging.ConnectionTracer, + logger utils.Logger, + version protocol.VersionNumber, +) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { + cs, clientHelloWritten := newCryptoSetup( + initialStream, + handshakeStream, + connID, + tp, + runner, + tlsConf, + enable0RTT, + rttStats, + tracer, + logger, + protocol.PerspectiveClient, + version, + ) + cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) + return cs, clientHelloWritten +} + +// NewCryptoSetupServer creates a new crypto setup for the server +func NewCryptoSetupServer( + initialStream io.Writer, + handshakeStream io.Writer, + connID protocol.ConnectionID, + localAddr net.Addr, + remoteAddr net.Addr, + tp *wire.TransportParameters, + runner handshakeRunner, + tlsConf *tls.Config, + enable0RTT bool, + rttStats *utils.RTTStats, + tracer logging.ConnectionTracer, + logger utils.Logger, + version protocol.VersionNumber, +) CryptoSetup { + cs, _ := newCryptoSetup( + initialStream, + handshakeStream, + connID, + tp, + runner, + tlsConf, + enable0RTT, + rttStats, + tracer, + logger, + protocol.PerspectiveServer, + version, + ) + cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) + return cs +} + +func newCryptoSetup( + initialStream io.Writer, + handshakeStream io.Writer, + connID protocol.ConnectionID, + tp *wire.TransportParameters, + runner handshakeRunner, + tlsConf *tls.Config, + enable0RTT bool, + rttStats *utils.RTTStats, + tracer logging.ConnectionTracer, + logger utils.Logger, + perspective protocol.Perspective, + version protocol.VersionNumber, +) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { + initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) + if tracer != nil { + tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) + tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) + } + extHandler := newExtensionHandler(tp.Marshal(perspective), perspective, version) + zeroRTTParametersChan := make(chan *wire.TransportParameters, 1) + cs := &cryptoSetup{ + tlsConf: tlsConf, + initialStream: initialStream, + initialSealer: initialSealer, + initialOpener: initialOpener, + handshakeStream: handshakeStream, + aead: newUpdatableAEAD(rttStats, tracer, logger, version), + readEncLevel: protocol.EncryptionInitial, + writeEncLevel: protocol.EncryptionInitial, + runner: runner, + ourParams: tp, + paramsChan: extHandler.TransportParameters(), + rttStats: rttStats, + tracer: tracer, + logger: logger, + perspective: perspective, + handshakeDone: make(chan struct{}), + alertChan: make(chan uint8), + clientHelloWrittenChan: make(chan struct{}), + zeroRTTParametersChan: zeroRTTParametersChan, + messageChan: make(chan []byte, 100), + isReadingHandshakeMessage: make(chan struct{}), + closeChan: make(chan struct{}), + version: version, + } + var maxEarlyData uint32 + if enable0RTT { + maxEarlyData = 0xffffffff + } + cs.extraConf = &qtls.ExtraConfig{ + GetExtensions: extHandler.GetExtensions, + ReceivedExtensions: extHandler.ReceivedExtensions, + AlternativeRecordLayer: cs, + EnforceNextProtoSelection: true, + MaxEarlyData: maxEarlyData, + Accept0RTT: cs.accept0RTT, + Rejected0RTT: cs.rejected0RTT, + Enable0RTT: enable0RTT, + GetAppDataForSessionState: cs.marshalDataForSessionState, + SetAppDataFromSessionState: cs.handleDataFromSessionState, + } + return cs, zeroRTTParametersChan +} + +func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { + initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version) + h.initialSealer = initialSealer + h.initialOpener = initialOpener + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) + h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) + } +} + +func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error { + return h.aead.SetLargestAcked(pn) +} + +func (h *cryptoSetup) RunHandshake() { + // Handle errors that might occur when HandleData() is called. + handshakeComplete := make(chan struct{}) + handshakeErrChan := make(chan error, 1) + go func() { + defer close(h.handshakeDone) + if err := h.conn.Handshake(); err != nil { + handshakeErrChan <- err + return + } + close(handshakeComplete) + }() + + if h.perspective == protocol.PerspectiveClient { + select { + case err := <-handshakeErrChan: + h.onError(0, err.Error()) + return + case <-h.clientHelloWrittenChan: + } + } + + select { + case <-handshakeComplete: // return when the handshake is done + h.mutex.Lock() + h.handshakeCompleteTime = time.Now() + h.mutex.Unlock() + h.runner.OnHandshakeComplete() + case <-h.closeChan: + // wait until the Handshake() go routine has returned + <-h.handshakeDone + case alert := <-h.alertChan: + handshakeErr := <-handshakeErrChan + h.onError(alert, handshakeErr.Error()) + } +} + +func (h *cryptoSetup) onError(alert uint8, message string) { + var err error + if alert == 0 { + err = &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: message} + } else { + err = qerr.NewCryptoError(alert, message) + } + h.runner.OnError(err) +} + +// Close closes the crypto setup. +// It aborts the handshake, if it is still running. +// It must only be called once. +func (h *cryptoSetup) Close() error { + close(h.closeChan) + // wait until qtls.Handshake() actually returned + <-h.handshakeDone + return nil +} + +// handleMessage handles a TLS handshake message. +// It is called by the crypto streams when a new message is available. +// It returns if it is done with messages on the same encryption level. +func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ { + msgType := messageType(data[0]) + h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) + if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { + h.onError(alertUnexpectedMessage, err.Error()) + return false + } + h.messageChan <- data + if encLevel == protocol.Encryption1RTT { + h.handlePostHandshakeMessage() + return false + } +readLoop: + for { + select { + case data := <-h.paramsChan: + if data == nil { + h.onError(0x6d, "missing quic_transport_parameters extension") + } else { + h.handleTransportParameters(data) + } + case <-h.isReadingHandshakeMessage: + break readLoop + case <-h.handshakeDone: + break readLoop + case <-h.closeChan: + break readLoop + } + } + // We're done with the Initial encryption level after processing a ClientHello / ServerHello, + // but only if a handshake opener and sealer was created. + // Otherwise, a HelloRetryRequest was performed. + // We're done with the Handshake encryption level after processing the Finished message. + return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) || + msgType == typeFinished +} + +func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { + var expected protocol.EncryptionLevel + switch msgType { + case typeClientHello, + typeServerHello: + expected = protocol.EncryptionInitial + case typeEncryptedExtensions, + typeCertificate, + typeCertificateRequest, + typeCertificateVerify, + typeFinished: + expected = protocol.EncryptionHandshake + case typeNewSessionTicket: + expected = protocol.Encryption1RTT + default: + return fmt.Errorf("unexpected handshake message: %d", msgType) + } + if encLevel != expected { + return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel) + } + return nil +} + +func (h *cryptoSetup) handleTransportParameters(data []byte) { + var tp wire.TransportParameters + if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { + h.runner.OnError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: err.Error(), + }) + } + h.peerParams = &tp + h.runner.OnReceivedParams(h.peerParams) +} + +// must be called after receiving the transport parameters +func (h *cryptoSetup) marshalDataForSessionState() []byte { + buf := &bytes.Buffer{} + quicvarint.Write(buf, clientSessionStateRevision) + quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds())) + h.peerParams.MarshalForSessionTicket(buf) + return buf.Bytes() +} + +func (h *cryptoSetup) handleDataFromSessionState(data []byte) { + tp, err := h.handleDataFromSessionStateImpl(data) + if err != nil { + h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) + return + } + h.zeroRTTParameters = tp +} + +func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) { + r := bytes.NewReader(data) + ver, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if ver != clientSessionStateRevision { + return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) + } + rtt, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) + var tp wire.TransportParameters + if err := tp.UnmarshalFromSessionTicket(r); err != nil { + return nil, err + } + return &tp, nil +} + +// only valid for the server +func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { + var appData []byte + // Save transport parameters to the session ticket if we're allowing 0-RTT. + if h.extraConf.MaxEarlyData > 0 { + appData = (&sessionTicket{ + Parameters: h.ourParams, + RTT: h.rttStats.SmoothedRTT(), + }).Marshal() + } + return h.conn.GetSessionTicket(appData) +} + +// accept0RTT is called for the server when receiving the client's session ticket. +// It decides whether to accept 0-RTT. +func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { + var t sessionTicket + if err := t.Unmarshal(sessionTicketData); err != nil { + h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error()) + return false + } + valid := h.ourParams.ValidFor0RTT(t.Parameters) + if valid { + h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) + h.rttStats.SetInitialRTT(t.RTT) + } else { + h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") + } + return valid +} + +// rejected0RTT is called for the client when the server rejects 0-RTT. +func (h *cryptoSetup) rejected0RTT() { + h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") + + h.mutex.Lock() + had0RTTKeys := h.zeroRTTSealer != nil + h.zeroRTTSealer = nil + h.mutex.Unlock() + + if had0RTTKeys { + h.runner.DropKeys(protocol.Encryption0RTT) + } +} + +func (h *cryptoSetup) handlePostHandshakeMessage() { + // make sure the handshake has already completed + <-h.handshakeDone + + done := make(chan struct{}) + defer close(done) + + // h.alertChan is an unbuffered channel. + // If an error occurs during conn.HandlePostHandshakeMessage, + // it will be sent on this channel. + // Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock. + alertChan := make(chan uint8, 1) + go func() { + <-h.isReadingHandshakeMessage + select { + case alert := <-h.alertChan: + alertChan <- alert + case <-done: + } + }() + + if err := h.conn.HandlePostHandshakeMessage(); err != nil { + select { + case <-h.closeChan: + case alert := <-alertChan: + h.onError(alert, err.Error()) + } + } +} + +// ReadHandshakeMessage is called by TLS. +// It blocks until a new handshake message is available. +func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) { + if !h.readFirstHandshakeMessage { + h.readFirstHandshakeMessage = true + } else { + select { + case h.isReadingHandshakeMessage <- struct{}{}: + case <-h.closeChan: + return nil, errors.New("error while handling the handshake message") + } + } + select { + case msg := <-h.messageChan: + return msg, nil + case <-h.closeChan: + return nil, errors.New("error while handling the handshake message") + } +} + +func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + h.mutex.Lock() + switch encLevel { + case qtls.Encryption0RTT: + if h.perspective == protocol.PerspectiveClient { + panic("Received 0-RTT read key for the client") + } + h.zeroRTTOpener = newLongHeaderOpener( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + ) + h.mutex.Unlock() + h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite()) + } + return + case qtls.EncryptionHandshake: + h.readEncLevel = protocol.EncryptionHandshake + h.handshakeOpener = newHandshakeOpener( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + h.dropInitialKeys, + h.perspective, + ) + h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + case qtls.EncryptionApplication: + h.readEncLevel = protocol.Encryption1RTT + h.aead.SetReadKey(suite, trafficSecret) + h.has1RTTOpener = true + h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + default: + panic("unexpected read encryption level") + } + h.mutex.Unlock() + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite()) + } +} + +func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + h.mutex.Lock() + switch encLevel { + case qtls.Encryption0RTT: + if h.perspective == protocol.PerspectiveServer { + panic("Received 0-RTT write key for the server") + } + h.zeroRTTSealer = newLongHeaderSealer( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + ) + h.mutex.Unlock() + h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) + } + return + case qtls.EncryptionHandshake: + h.writeEncLevel = protocol.EncryptionHandshake + h.handshakeSealer = newHandshakeSealer( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + h.dropInitialKeys, + h.perspective, + ) + h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + case qtls.EncryptionApplication: + h.writeEncLevel = protocol.Encryption1RTT + h.aead.SetWriteKey(suite, trafficSecret) + h.has1RTTSealer = true + h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.zeroRTTSealer != nil { + h.zeroRTTSealer = nil + h.logger.Debugf("Dropping 0-RTT keys.") + if h.tracer != nil { + h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) + } + } + default: + panic("unexpected write encryption level") + } + h.mutex.Unlock() + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective) + } +} + +// WriteRecord is called when TLS writes data +func (h *cryptoSetup) WriteRecord(p []byte) (int, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + //nolint:exhaustive // LS records can only be written for Initial and Handshake. + switch h.writeEncLevel { + case protocol.EncryptionInitial: + // assume that the first WriteRecord call contains the ClientHello + n, err := h.initialStream.Write(p) + if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient { + h.clientHelloWritten = true + close(h.clientHelloWrittenChan) + if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil { + h.logger.Debugf("Doing 0-RTT.") + h.zeroRTTParametersChan <- h.zeroRTTParameters + } else { + h.logger.Debugf("Not doing 0-RTT.") + h.zeroRTTParametersChan <- nil + } + } + return n, err + case protocol.EncryptionHandshake: + return h.handshakeStream.Write(p) + default: + panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel)) + } +} + +func (h *cryptoSetup) SendAlert(alert uint8) { + select { + case h.alertChan <- alert: + case <-h.closeChan: + // no need to send an alert when we've already closed + } +} + +// used a callback in the handshakeSealer and handshakeOpener +func (h *cryptoSetup) dropInitialKeys() { + h.mutex.Lock() + h.initialOpener = nil + h.initialSealer = nil + h.mutex.Unlock() + h.runner.DropKeys(protocol.EncryptionInitial) + h.logger.Debugf("Dropping Initial keys.") +} + +func (h *cryptoSetup) SetHandshakeConfirmed() { + h.aead.SetHandshakeConfirmed() + // drop Handshake keys + var dropped bool + h.mutex.Lock() + if h.handshakeOpener != nil { + h.handshakeOpener = nil + h.handshakeSealer = nil + dropped = true + } + h.mutex.Unlock() + if dropped { + h.runner.DropKeys(protocol.EncryptionHandshake) + h.logger.Debugf("Dropping Handshake keys.") + } +} + +func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.initialSealer == nil { + return nil, ErrKeysDropped + } + return h.initialSealer, nil +} + +func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.zeroRTTSealer == nil { + return nil, ErrKeysDropped + } + return h.zeroRTTSealer, nil +} + +func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.handshakeSealer == nil { + if h.initialSealer == nil { + return nil, ErrKeysDropped + } + return nil, ErrKeysNotYetAvailable + } + return h.handshakeSealer, nil +} + +func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if !h.has1RTTSealer { + return nil, ErrKeysNotYetAvailable + } + return h.aead, nil +} + +func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.initialOpener == nil { + return nil, ErrKeysDropped + } + return h.initialOpener, nil +} + +func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.zeroRTTOpener == nil { + if h.initialOpener != nil { + return nil, ErrKeysNotYetAvailable + } + // if the initial opener is also not available, the keys were already dropped + return nil, ErrKeysDropped + } + return h.zeroRTTOpener, nil +} + +func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.handshakeOpener == nil { + if h.initialOpener != nil { + return nil, ErrKeysNotYetAvailable + } + // if the initial opener is also not available, the keys were already dropped + return nil, ErrKeysDropped + } + return h.handshakeOpener, nil +} + +func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { + h.zeroRTTOpener = nil + h.logger.Debugf("Dropping 0-RTT keys.") + if h.tracer != nil { + h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) + } + } + + if !h.has1RTTOpener { + return nil, ErrKeysNotYetAvailable + } + return h.aead, nil +} + +func (h *cryptoSetup) ConnectionState() ConnectionState { + return qtls.GetConnectionState(h.conn) +} diff --git a/internal/quic-go/handshake/crypto_setup_test.go b/internal/quic-go/handshake/crypto_setup_test.go new file mode 100644 index 00000000..4adefcaa --- /dev/null +++ b/internal/quic-go/handshake/crypto_setup_test.go @@ -0,0 +1,864 @@ +package handshake + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "time" + + mocktls "github.com/imroc/req/v3/internal/quic-go/mocks/tls" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/testdata" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, + 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, + 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +} + +type chunk struct { + data []byte + encLevel protocol.EncryptionLevel +} + +type stream struct { + encLevel protocol.EncryptionLevel + chunkChan chan<- chunk +} + +func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream { + return &stream{ + chunkChan: chunkChan, + encLevel: encLevel, + } +} + +func (s *stream) Write(b []byte) (int, error) { + data := make([]byte, len(b)) + copy(data, b) + select { + case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}: + default: + panic("chunkChan too small") + } + return len(b), nil +} + +var _ = Describe("Crypto Setup TLS", func() { + var clientConf, serverConf *tls.Config + + // unparam incorrectly complains that the first argument is never used. + //nolint:unparam + initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { + chunkChan := make(chan chunk, 100) + initialStream := newStream(chunkChan, protocol.EncryptionInitial) + handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake) + return chunkChan, initialStream, handshakeStream + } + + BeforeEach(func() { + serverConf = testdata.GetTLSConfig() + serverConf.NextProtos = []string{"crypto-setup"} + clientConf = &tls.Config{ + ServerName: "localhost", + RootCAs: testdata.GetRootCA(), + NextProtos: []string{"crypto-setup"}, + } + }) + + It("returns Handshake() when an error occurs in qtls", func() { + sErrChan := make(chan error, 1) + runner := NewMockHandshakeRunner(mockCtrl) + runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) + _, sInitialStream, sHandshakeStream := initStreams() + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + runner, + testdata.GetTLSConfig(), + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + server.RunHandshake() + Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ + ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), + ErrorMessage: "local error: tls: unexpected message", + }))) + close(done) + }() + + fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) + handledMessage := make(chan struct{}) + go func() { + defer GinkgoRecover() + server.HandleMessage(fakeCH, protocol.EncryptionInitial) + close(handledMessage) + }() + Eventually(handledMessage).Should(BeClosed()) + Eventually(done).Should(BeClosed()) + }) + + It("handles qtls errors occurring before during ClientHello generation", func() { + sErrChan := make(chan error, 1) + runner := NewMockHandshakeRunner(mockCtrl) + runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) + _, sInitialStream, sHandshakeStream := initStreams() + tlsConf := testdata.GetTLSConfig() + tlsConf.InsecureSkipVerify = true + tlsConf.NextProtos = []string{""} + cl, _ := NewCryptoSetupClient( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{}, + runner, + tlsConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + cl.RunHandshake() + close(done) + }() + + Eventually(done).Should(BeClosed()) + Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.InternalError, + ErrorMessage: "tls: invalid NextProtos value", + }))) + }) + + It("errors when a message is received at the wrong encryption level", func() { + sErrChan := make(chan error, 1) + _, sInitialStream, sHandshakeStream := initStreams() + runner := NewMockHandshakeRunner(mockCtrl) + runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + runner, + testdata.GetTLSConfig(), + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + server.RunHandshake() + close(done) + }() + + fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) + server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level + Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ + ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), + ErrorMessage: "expected handshake message ClientHello to have encryption level Initial, has Handshake", + }))) + + // make the go routine return + Expect(server.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("returns Handshake() when handling a message fails", func() { + sErrChan := make(chan error, 1) + _, sInitialStream, sHandshakeStream := initStreams() + runner := NewMockHandshakeRunner(mockCtrl) + runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + runner, + serverConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + server.RunHandshake() + var err error + Expect(sErrChan).To(Receive(&err)) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) + close(done) + }() + + fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...) + server.HandleMessage(fakeCH, protocol.EncryptionInitial) // wrong encryption level + Eventually(done).Should(BeClosed()) + }) + + It("returns Handshake() when it is closed", func() { + _, sInitialStream, sHandshakeStream := initStreams() + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + NewMockHandshakeRunner(mockCtrl), + serverConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + server.RunHandshake() + close(done) + }() + Expect(server.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + Context("doing the handshake", func() { + generateCert := func() tls.Certificate { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{}, + SignatureAlgorithm: x509.SHA256WithRSA, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), // valid for an hour + BasicConstraintsValid: true, + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) + Expect(err).ToNot(HaveOccurred()) + return tls.Certificate{ + PrivateKey: priv, + Certificate: [][]byte{certDER}, + } + } + + newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats { + rttStats := &utils.RTTStats{} + rttStats.UpdateRTT(rtt, 0, time.Now()) + ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt)) + return rttStats + } + + handshake := func(client CryptoSetup, cChunkChan <-chan chunk, + server CryptoSetup, sChunkChan <-chan chunk, + ) { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + for { + select { + case c := <-cChunkChan: + msgType := messageType(c.data[0]) + finished := server.HandleMessage(c.data, c.encLevel) + if msgType == typeFinished { + Expect(finished).To(BeTrue()) + } else if msgType == typeClientHello { + // If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys. + _, err := server.GetHandshakeOpener() + Expect(finished).To(Equal(err == nil)) + } else { + Expect(finished).To(BeFalse()) + } + case c := <-sChunkChan: + msgType := messageType(c.data[0]) + finished := client.HandleMessage(c.data, c.encLevel) + if msgType == typeFinished { + Expect(finished).To(BeTrue()) + } else if msgType == typeServerHello { + Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom))) + } else { + Expect(finished).To(BeFalse()) + } + case <-done: // handshake complete + return + } + } + }() + + go func() { + defer GinkgoRecover() + defer close(done) + server.RunHandshake() + ticket, err := server.GetSessionTicket() + Expect(err).ToNot(HaveOccurred()) + if ticket != nil { + client.HandleMessage(ticket, protocol.Encryption1RTT) + } + }() + + client.RunHandshake() + Eventually(done).Should(BeClosed()) + } + + handshakeWithTLSConf := func( + clientConf, serverConf *tls.Config, + clientRTTStats, serverRTTStats *utils.RTTStats, + clientTransportParameters, serverTransportParameters *wire.TransportParameters, + enable0RTT bool, + ) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { + var cHandshakeComplete bool + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cErrChan := make(chan error, 1) + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1) + cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1) + cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1) + client, clientHelloWrittenChan := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + clientTransportParameters, + cRunner, + clientConf, + enable0RTT, + clientRTTStats, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + var sHandshakeComplete bool + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sErrChan := make(chan error, 1) + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1) + sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1) + if serverTransportParameters.StatelessResetToken == nil { + var token protocol.StatelessResetToken + serverTransportParameters.StatelessResetToken = &token + } + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + serverTransportParameters, + sRunner, + serverConf, + enable0RTT, + serverRTTStats, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + handshake(client, cChunkChan, server, sChunkChan) + var cErr, sErr error + select { + case sErr = <-sErrChan: + default: + Expect(sHandshakeComplete).To(BeTrue()) + } + select { + case cErr = <-cErrChan: + default: + Expect(cHandshakeComplete).To(BeTrue()) + } + return clientHelloWrittenChan, client, cErr, server, sErr + } + + It("handshakes", func() { + _, _, clientErr, _, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + }) + + It("performs a HelloRetryRequst", func() { + serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} + _, _, clientErr, _, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + }) + + It("handshakes with client auth", func() { + clientConf.Certificates = []tls.Certificate{generateCert()} + serverConf.ClientAuth = tls.RequireAnyClientCert + _, _, clientErr, _, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + }) + + It("signals when it has written the ClientHello", func() { + runner := NewMockHandshakeRunner(mockCtrl) + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + client, chChan := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{}, + runner, + &tls.Config{InsecureSkipVerify: true}, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + client.RunHandshake() + close(done) + }() + var ch chunk + Eventually(cChunkChan).Should(Receive(&ch)) + Eventually(chChan).Should(Receive(BeNil())) + // make sure the whole ClientHello was written + Expect(len(ch.data)).To(BeNumerically(">=", 4)) + Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) + length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) + Expect(len(ch.data) - 4).To(Equal(length)) + + // make the go routine return + Expect(client.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("receives transport parameters", func() { + var cTransportParametersRcvd, sTransportParametersRcvd *wire.TransportParameters + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cTransportParameters := &wire.TransportParameters{MaxIdleTimeout: 0x42 * time.Second} + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp }) + cRunner.EXPECT().OnHandshakeComplete() + client, _ := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + cTransportParameters, + cRunner, + clientConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + var token protocol.StatelessResetToken + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp }) + sRunner.EXPECT().OnHandshakeComplete() + sTransportParameters := &wire.TransportParameters{ + MaxIdleTimeout: 0x1337 * time.Second, + StatelessResetToken: &token, + } + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + sTransportParameters, + sRunner, + serverConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handshake(client, cChunkChan, server, sChunkChan) + close(done) + }() + Eventually(done).Should(BeClosed()) + Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout)) + Expect(sTransportParametersRcvd).ToNot(BeNil()) + Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout)) + }) + + Context("with session tickets", func() { + It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnHandshakeComplete() + client, _ := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{}, + cRunner, + clientConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnHandshakeComplete() + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + sRunner, + serverConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handshake(client, cChunkChan, server, sChunkChan) + close(done) + }() + Eventually(done).Should(BeClosed()) + + // inject an invalid session ticket + cRunner.EXPECT().OnError(&qerr.TransportError{ + ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), + ErrorMessage: "expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake", + }) + b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) + client.HandleMessage(b, protocol.EncryptionHandshake) + }) + + It("errors when handling the NewSessionTicket fails", func() { + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnHandshakeComplete() + client, _ := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{}, + cRunner, + clientConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, + ) + + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnHandshakeComplete() + var token protocol.StatelessResetToken + server := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + nil, + nil, + &wire.TransportParameters{StatelessResetToken: &token}, + sRunner, + serverConf, + false, + &utils.RTTStats{}, + nil, + utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handshake(client, cChunkChan, server, sChunkChan) + close(done) + }() + Eventually(done).Should(BeClosed()) + + // inject an invalid session ticket + cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) + }) + b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) + client.HandleMessage(b, protocol.Encryption1RTT) + }) + + It("uses session resumption", func() { + csc := mocktls.NewMockClientSessionCache(mockCtrl) + var state *tls.ClientSessionState + receivedSessionTicket := make(chan struct{}) + csc.EXPECT().Get(gomock.Any()) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + close(receivedSessionTicket) + }) + clientConf.ClientSessionCache = csc + const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. + clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) + clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + clientOrigRTTStats, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + Expect(clientHelloWrittenChan).To(Receive(BeNil())) + + csc.EXPECT().Get(gomock.Any()).Return(state, true) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) + clientRTTStats := &utils.RTTStats{} + clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( + clientConf, serverConf, + clientRTTStats, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeTrue()) + Expect(client.ConnectionState().DidResume).To(BeTrue()) + Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) + Expect(clientHelloWrittenChan).To(Receive(BeNil())) + }) + + It("doesn't use session resumption if the server disabled it", func() { + csc := mocktls.NewMockClientSessionCache(mockCtrl) + var state *tls.ClientSessionState + receivedSessionTicket := make(chan struct{}) + csc.EXPECT().Get(gomock.Any()) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + close(receivedSessionTicket) + }) + clientConf.ClientSessionCache = csc + _, client, clientErr, server, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + + serverConf.SessionTicketsDisabled = true + csc.EXPECT().Get(gomock.Any()).Return(state, true) + _, client, clientErr, server, serverErr = handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{}, + false, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + }) + + It("uses 0-RTT", func() { + csc := mocktls.NewMockClientSessionCache(mockCtrl) + var state *tls.ClientSessionState + receivedSessionTicket := make(chan struct{}) + csc.EXPECT().Get(gomock.Any()) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + close(receivedSessionTicket) + }) + clientConf.ClientSessionCache = csc + const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. + const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. + serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) + clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) + const initialMaxData protocol.ByteCount = 1337 + clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + clientOrigRTTStats, serverOrigRTTStats, + &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, + true, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + Expect(clientHelloWrittenChan).To(Receive(BeNil())) + + csc.EXPECT().Get(gomock.Any()).Return(state, true) + csc.EXPECT().Put(gomock.Any(), nil) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) + + clientRTTStats := &utils.RTTStats{} + serverRTTStats := &utils.RTTStats{} + clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( + clientConf, serverConf, + clientRTTStats, serverRTTStats, + &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, + true, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) + Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) + + var tp *wire.TransportParameters + Expect(clientHelloWrittenChan).To(Receive(&tp)) + Expect(tp.InitialMaxData).To(Equal(initialMaxData)) + + Expect(server.ConnectionState().DidResume).To(BeTrue()) + Expect(client.ConnectionState().DidResume).To(BeTrue()) + Expect(server.ConnectionState().Used0RTT).To(BeTrue()) + Expect(client.ConnectionState().Used0RTT).To(BeTrue()) + }) + + It("rejects 0-RTT, when the transport parameters changed", func() { + csc := mocktls.NewMockClientSessionCache(mockCtrl) + var state *tls.ClientSessionState + receivedSessionTicket := make(chan struct{}) + csc.EXPECT().Get(gomock.Any()) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + close(receivedSessionTicket) + }) + clientConf.ClientSessionCache = csc + const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. + clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) + const initialMaxData protocol.ByteCount = 1337 + clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + clientOrigRTTStats, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, + true, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + Expect(clientHelloWrittenChan).To(Receive(BeNil())) + + csc.EXPECT().Get(gomock.Any()).Return(state, true) + csc.EXPECT().Put(gomock.Any(), nil) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) + + clientRTTStats := &utils.RTTStats{} + clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( + clientConf, serverConf, + clientRTTStats, &utils.RTTStats{}, + &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData - 1}, + true, + ) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) + + var tp *wire.TransportParameters + Expect(clientHelloWrittenChan).To(Receive(&tp)) + Expect(tp.InitialMaxData).To(Equal(initialMaxData)) + + Expect(server.ConnectionState().DidResume).To(BeTrue()) + Expect(client.ConnectionState().DidResume).To(BeTrue()) + Expect(server.ConnectionState().Used0RTT).To(BeFalse()) + Expect(client.ConnectionState().Used0RTT).To(BeFalse()) + }) + }) + }) +}) diff --git a/internal/quic-go/handshake/handshake_suite_test.go b/internal/quic-go/handshake/handshake_suite_test.go new file mode 100644 index 00000000..a464ba1c --- /dev/null +++ b/internal/quic-go/handshake/handshake_suite_test.go @@ -0,0 +1,48 @@ +package handshake + +import ( + "crypto/tls" + "encoding/hex" + "strings" + "testing" + + "github.com/imroc/req/v3/internal/quic-go/qtls" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestHandshake(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Handshake Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) + +func splitHexString(s string) (slice []byte) { + for _, ss := range strings.Split(s, " ") { + if ss[0:2] == "0x" { + ss = ss[2:] + } + d, err := hex.DecodeString(ss) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + slice = append(slice, d...) + } + return +} + +var cipherSuites = []*qtls.CipherSuiteTLS13{ + qtls.CipherSuiteTLS13ByID(tls.TLS_AES_128_GCM_SHA256), + qtls.CipherSuiteTLS13ByID(tls.TLS_AES_256_GCM_SHA384), + qtls.CipherSuiteTLS13ByID(tls.TLS_CHACHA20_POLY1305_SHA256), +} diff --git a/internal/quic-go/handshake/header_protector.go b/internal/quic-go/handshake/header_protector.go new file mode 100644 index 00000000..c3e96f24 --- /dev/null +++ b/internal/quic-go/handshake/header_protector.go @@ -0,0 +1,136 @@ +package handshake + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/tls" + "encoding/binary" + "fmt" + + "golang.org/x/crypto/chacha20" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qtls" +) + +type headerProtector interface { + EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) + DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) +} + +func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string { + if v == protocol.Version2 { + return "quicv2 hp" + } + return "quic hp" +} + +func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector { + hkdfLabel := hkdfHeaderProtectionLabel(v) + switch suite.ID { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) + case tls.TLS_CHACHA20_POLY1305_SHA256: + return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) + default: + panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) + } +} + +type aesHeaderProtector struct { + mask []byte + block cipher.Block + isLongHeader bool +} + +var _ headerProtector = &aesHeaderProtector{} + +func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { + hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) + block, err := aes.NewCipher(hpKey) + if err != nil { + panic(fmt.Sprintf("error creating new AES cipher: %s", err)) + } + return &aesHeaderProtector{ + block: block, + mask: make([]byte, block.BlockSize()), + isLongHeader: isLongHeader, + } +} + +func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { + if len(sample) != len(p.mask) { + panic("invalid sample size") + } + p.block.Encrypt(p.mask, sample) + if p.isLongHeader { + *firstByte ^= p.mask[0] & 0xf + } else { + *firstByte ^= p.mask[0] & 0x1f + } + for i := range hdrBytes { + hdrBytes[i] ^= p.mask[i+1] + } +} + +type chachaHeaderProtector struct { + mask [5]byte + + key [32]byte + isLongHeader bool +} + +var _ headerProtector = &chachaHeaderProtector{} + +func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { + hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) + + p := &chachaHeaderProtector{ + isLongHeader: isLongHeader, + } + copy(p.key[:], hpKey) + return p +} + +func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { + if len(sample) != 16 { + panic("invalid sample size") + } + for i := 0; i < 5; i++ { + p.mask[i] = 0 + } + cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:]) + if err != nil { + panic(err) + } + cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4])) + cipher.XORKeyStream(p.mask[:], p.mask[:]) + p.applyMask(firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) { + if p.isLongHeader { + *firstByte ^= p.mask[0] & 0xf + } else { + *firstByte ^= p.mask[0] & 0x1f + } + for i := range hdrBytes { + hdrBytes[i] ^= p.mask[i+1] + } +} diff --git a/internal/quic-go/handshake/hkdf.go b/internal/quic-go/handshake/hkdf.go new file mode 100644 index 00000000..c4fd86c5 --- /dev/null +++ b/internal/quic-go/handshake/hkdf.go @@ -0,0 +1,29 @@ +package handshake + +import ( + "crypto" + "encoding/binary" + + "golang.org/x/crypto/hkdf" +) + +// hkdfExpandLabel HKDF expands a label. +// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the +// hkdfExpandLabel in the standard library. +func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte { + b := make([]byte, 3, 3+6+len(label)+1+len(context)) + binary.BigEndian.PutUint16(b, uint16(length)) + b[2] = uint8(6 + len(label)) + b = append(b, []byte("tls13 ")...) + b = append(b, []byte(label)...) + b = b[:3+6+len(label)+1] + b[3+6+len(label)] = uint8(len(context)) + b = append(b, context...) + + out := make([]byte, length) + n, err := hkdf.Expand(hash.New, secret, b).Read(out) + if err != nil || n != length { + panic("quic: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} diff --git a/internal/quic-go/handshake/hkdf_test.go b/internal/quic-go/handshake/hkdf_test.go new file mode 100644 index 00000000..16154199 --- /dev/null +++ b/internal/quic-go/handshake/hkdf_test.go @@ -0,0 +1,17 @@ +package handshake + +import ( + "crypto" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Initial AEAD using AES-GCM", func() { + // Result generated by running in qtls: + // cipherSuiteTLS13ByID(TLS_AES_128_GCM_SHA256).expandLabel([]byte("secret"), []byte("context"), "label", 42) + It("gets the same results as qtls", func() { + expanded := hkdfExpandLabel(crypto.SHA256, []byte("secret"), []byte("context"), "label", 42) + Expect(expanded).To(Equal([]byte{0x78, 0x87, 0x6a, 0xb5, 0x84, 0xa2, 0x26, 0xb7, 0x8, 0x5a, 0x7b, 0x3a, 0x4c, 0xbb, 0x1e, 0xbc, 0x2f, 0x9b, 0x67, 0xd0, 0x6a, 0xa2, 0x24, 0xb4, 0x7d, 0x29, 0x3c, 0x7a, 0xce, 0xc7, 0xc3, 0x74, 0xcd, 0x59, 0x7a, 0xa8, 0x21, 0x5e, 0xe7, 0xca, 0x1, 0xda})) + }) +}) diff --git a/internal/quic-go/handshake/initial_aead.go b/internal/quic-go/handshake/initial_aead.go new file mode 100644 index 00000000..8a579b20 --- /dev/null +++ b/internal/quic-go/handshake/initial_aead.go @@ -0,0 +1,81 @@ +package handshake + +import ( + "crypto" + "crypto/tls" + + "golang.org/x/crypto/hkdf" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qtls" +) + +var ( + quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99} + quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} + quicSaltV2 = []byte{0xa7, 0x07, 0xc2, 0x03, 0xa5, 0x9b, 0x47, 0x18, 0x4a, 0x1d, 0x62, 0xca, 0x57, 0x04, 0x06, 0xea, 0x7a, 0xe3, 0xe5, 0xd3} +) + +const ( + hkdfLabelKeyV1 = "quic key" + hkdfLabelKeyV2 = "quicv2 key" + hkdfLabelIVV1 = "quic iv" + hkdfLabelIVV2 = "quicv2 iv" +) + +func getSalt(v protocol.VersionNumber) []byte { + if v == protocol.Version2 { + return quicSaltV2 + } + if v == protocol.Version1 { + return quicSaltV1 + } + return quicSaltOld +} + +var initialSuite = &qtls.CipherSuiteTLS13{ + ID: tls.TLS_AES_128_GCM_SHA256, + KeyLen: 16, + AEAD: qtls.AEADAESGCMTLS13, + Hash: crypto.SHA256, +} + +// NewInitialAEAD creates a new AEAD for Initial encryption / decryption. +func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) { + clientSecret, serverSecret := computeSecrets(connID, v) + var mySecret, otherSecret []byte + if pers == protocol.PerspectiveClient { + mySecret = clientSecret + otherSecret = serverSecret + } else { + mySecret = serverSecret + otherSecret = clientSecret + } + myKey, myIV := computeInitialKeyAndIV(mySecret, v) + otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v) + + encrypter := qtls.AEADAESGCMTLS13(myKey, myIV) + decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) + + return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)), + newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v))) +} + +func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) { + initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(v)) + clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) + serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) + return +} + +func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) { + keyLabel := hkdfLabelKeyV1 + ivLabel := hkdfLabelIVV1 + if v == protocol.Version2 { + keyLabel = hkdfLabelKeyV2 + ivLabel = hkdfLabelIVV2 + } + key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16) + iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12) + return +} diff --git a/internal/quic-go/handshake/initial_aead_test.go b/internal/quic-go/handshake/initial_aead_test.go new file mode 100644 index 00000000..6a97eb70 --- /dev/null +++ b/internal/quic-go/handshake/initial_aead_test.go @@ -0,0 +1,219 @@ +package handshake + +import ( + "fmt" + "math/rand" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Initial AEAD using AES-GCM", func() { + It("converts the string representation used in the draft into byte slices", func() { + Expect(splitHexString("0xdeadbeef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + Expect(splitHexString("deadbeef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + Expect(splitHexString("dead beef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + }) + + connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + + DescribeTable("computes the client key and IV", + func(v protocol.VersionNumber, expectedClientSecret, expectedKey, expectedIV []byte) { + clientSecret, _ := computeSecrets(connID, v) + Expect(clientSecret).To(Equal(expectedClientSecret)) + key, iv := computeInitialKeyAndIV(clientSecret, v) + Expect(key).To(Equal(expectedKey)) + Expect(iv).To(Equal(expectedIV)) + }, + Entry("draft-29", + protocol.VersionDraft29, + splitHexString("0088119288f1d866733ceeed15ff9d50 902cf82952eee27e9d4d4918ea371d87"), + splitHexString("175257a31eb09dea9366d8bb79ad80ba"), + splitHexString("6b26114b9cba2b63a9e8dd4f"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("c00cf151ca5be075ed0ebfb5c80323c4 2d6b7db67881289af4008f1f6c357aea"), + splitHexString("1f369613dd76d5467730efcbe3b1a22d"), + splitHexString("fa044b2f42a3fd3b46fb255c"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("9fe72e1452e91f551b770005054034e4 7575d4a0fb4c27b7c6cb303a338423ae"), + splitHexString("95df2be2e8d549c82e996fc9339f4563"), + splitHexString("ea5e3c95f933db14b7020ad8"), + ), + ) + + DescribeTable("computes the server key and IV", + func(v protocol.VersionNumber, expectedServerSecret, expectedKey, expectedIV []byte) { + _, serverSecret := computeSecrets(connID, v) + Expect(serverSecret).To(Equal(expectedServerSecret)) + key, iv := computeInitialKeyAndIV(serverSecret, v) + Expect(key).To(Equal(expectedKey)) + Expect(iv).To(Equal(expectedIV)) + }, + Entry("draft 29", + protocol.VersionDraft29, + splitHexString("006f881359244dd9ad1acf85f595bad6 7c13f9f5586f5e64e1acae1d9ea8f616"), + splitHexString("149d0b1662ab871fbe63c49b5e655a5d"), + splitHexString("bab2b12a4c76016ace47856d"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("3c199828fd139efd216c155ad844cc81 fb82fa8d7446fa7d78be803acdda951b"), + splitHexString("cf3a5331653c364c88f0f379b6067e37"), + splitHexString("0ac1493ca1905853b0bba03e"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("3c9bf6a9c1c8c71819876967bd8b979e fd98ec665edf27f22c06e9845ba0ae2f"), + splitHexString("15d5b4d9a2b8916aa39b1bfe574d2aad"), + splitHexString("a85e7ac31cd275cbb095c626"), + ), + ) + + DescribeTable("encrypts the client's Initial", + func(v protocol.VersionNumber, header, data, expectedSample []byte, expectedHdrFirstByte byte, expectedHdr, expectedPacket []byte) { + sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveClient, v) + data = append(data, make([]byte, 1162-len(data))...) // add PADDING + sealed := sealer.Seal(nil, data, 2, header) + sample := sealed[0:16] + Expect(sample).To(Equal(expectedSample)) + sealer.EncryptHeader(sample, &header[0], header[len(header)-4:]) + Expect(header[0]).To(Equal(expectedHdrFirstByte)) + Expect(header[len(header)-4:]).To(Equal(expectedHdr)) + packet := append(header, sealed...) + Expect(packet).To(Equal(expectedPacket)) + }, + Entry("draft 29", + protocol.VersionDraft29, + splitHexString("c3ff00001d088394c8f03e5157080000449e00000002"), + splitHexString("060040c4010000c003036660261ff947 cea49cce6cfad687f457cf1b14531ba1 4131a0e8f309a1d0b9c4000006130113 031302010000910000000b0009000006 736572766572ff01000100000a001400 12001d00170018001901000101010201 03010400230000003300260024001d00 204cfdfcd178b784bf328cae793b136f 2aedce005ff183d7bb14952072366470 37002b0003020304000d0020001e0403 05030603020308040805080604010501 060102010402050206020202002d0002 0101001c00024001"), + splitHexString("fb66bc5f93032b7ddd89fe0ff15d9c4f"), + byte(0xc5), + splitHexString("4a95245b"), + splitHexString("c5ff00001d088394c8f03e5157080000 449e4a95245bfb66bc5f93032b7ddd89 fe0ff15d9c4f7050fccdb71c1cd80512 d4431643a53aafa1b0b518b44968b18b 8d3e7a4d04c30b3ed9410325b2abb2da fb1c12f8b70479eb8df98abcaf95dd8f 3d1c78660fbc719f88b23c8aef6771f3 d50e10fdfb4c9d92386d44481b6c52d5 9e5538d3d3942de9f13a7f8b702dc317 24180da9df22714d01003fc5e3d165c9 50e630b8540fbd81c9df0ee63f949970 26c4f2e1887a2def79050ac2d86ba318 e0b3adc4c5aa18bcf63c7cf8e85f5692 49813a2236a7e72269447cd1c755e451 f5e77470eb3de64c8849d29282069802 9cfa18e5d66176fe6e5ba4ed18026f90 900a5b4980e2f58e39151d5cd685b109 29636d4f02e7fad2a5a458249f5c0298 a6d53acbe41a7fc83fa7cc01973f7a74 d1237a51974e097636b6203997f921d0 7bc1940a6f2d0de9f5a11432946159ed 6cc21df65c4ddd1115f86427259a196c 7148b25b6478b0dc7766e1c4d1b1f515 9f90eabc61636226244642ee148b464c 9e619ee50a5e3ddc836227cad938987c 4ea3c1fa7c75bbf88d89e9ada642b2b8 8fe8107b7ea375b1b64889a4e9e5c38a 1c896ce275a5658d250e2d76e1ed3a34 ce7e3a3f383d0c996d0bed106c2899ca 6fc263ef0455e74bb6ac1640ea7bfedc 59f03fee0e1725ea150ff4d69a7660c5 542119c71de270ae7c3ecfd1af2c4ce5 51986949cc34a66b3e216bfe18b347e6 c05fd050f85912db303a8f054ec23e38 f44d1c725ab641ae929fecc8e3cefa56 19df4231f5b4c009fa0c0bbc60bc75f7 6d06ef154fc8577077d9d6a1d2bd9bf0 81dc783ece60111bea7da9e5a9748069 d078b2bef48de04cabe3755b197d52b3 2046949ecaa310274b4aac0d008b1948 c1082cdfe2083e386d4fd84c0ed0666d 3ee26c4515c4fee73433ac703b690a9f 7bf278a77486ace44c489a0c7ac8dfe4 d1a58fb3a730b993ff0f0d61b4d89557 831eb4c752ffd39c10f6b9f46d8db278 da624fd800e4af85548a294c1518893a 8778c4f6d6d73c93df200960104e062b 388ea97dcf4016bced7f62b4f062cb6c 04c20693d9a0e3b74ba8fe74cc012378 84f40d765ae56a51688d985cf0ceaef4 3045ed8c3f0c33bced08537f6882613a cd3b08d665fce9dd8aa73171e2d3771a 61dba2790e491d413d93d987e2745af2 9418e428be34941485c93447520ffe23 1da2304d6a0fd5d07d08372202369661 59bef3cf904d722324dd852513df39ae 030d8173908da6364786d3c1bfcb19ea 77a63b25f1e7fc661def480c5d00d444 56269ebd84efd8e3a8b2c257eec76060 682848cbf5194bc99e49ee75e4d0d254 bad4bfd74970c30e44b65511d4ad0e6e c7398e08e01307eeeea14e46ccd87cf3 6b285221254d8fc6a6765c524ded0085 dca5bd688ddf722e2c0faf9d0fb2ce7a 0c3f2cee19ca0ffba461ca8dc5d2c817 8b0762cf67135558494d2a96f1a139f0 edb42d2af89a9c9122b07acbc29e5e72 2df8615c343702491098478a389c9872 a10b0c9875125e257c7bfdf27eef4060 bd3d00f4c14fd3e3496c38d3c5d1a566 8c39350effbc2d16ca17be4ce29f02ed 969504dda2a8c6b9ff919e693ee79e09 089316e7d1d89ec099db3b2b268725d8 88536a4b8bf9aee8fb43e82a4d919d48 43b1ca70a2d8d3f725ead1391377dcc0"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("c300000001088394c8f03e5157080000449e00000002"), + splitHexString("060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), + splitHexString("d1b1c98dd7689fb8ec11d242b123dc9b"), + byte(0xc0), + splitHexString("7b9aec34"), + splitHexString("c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11 d242b123dc9bd8bab936b47d92ec356c 0bab7df5976d27cd449f63300099f399 1c260ec4c60d17b31f8429157bb35a12 82a643a8d2262cad67500cadb8e7378c 8eb7539ec4d4905fed1bee1fc8aafba1 7c750e2c7ace01e6005f80fcb7df6212 30c83711b39343fa028cea7f7fb5ff89 eac2308249a02252155e2347b63d58c5 457afd84d05dfffdb20392844ae81215 4682e9cf012f9021a6f0be17ddd0c208 4dce25ff9b06cde535d0f920a2db1bf3 62c23e596d11a4f5a6cf3948838a3aec 4e15daf8500a6ef69ec4e3feb6b1d98e 610ac8b7ec3faf6ad760b7bad1db4ba3 485e8a94dc250ae3fdb41ed15fb6a8e5 eba0fc3dd60bc8e30c5c4287e53805db 059ae0648db2f64264ed5e39be2e20d8 2df566da8dd5998ccabdae053060ae6c 7b4378e846d29f37ed7b4ea9ec5d82e7 961b7f25a9323851f681d582363aa5f8 9937f5a67258bf63ad6f1a0b1d96dbd4 faddfcefc5266ba6611722395c906556 be52afe3f565636ad1b17d508b73d874 3eeb524be22b3dcbc2c7468d54119c74 68449a13d8e3b95811a198f3491de3e7 fe942b330407abf82a4ed7c1b311663a c69890f4157015853d91e923037c227a 33cdd5ec281ca3f79c44546b9d90ca00 f064c99e3dd97911d39fe9c5d0b23a22 9a234cb36186c4819e8b9c5927726632 291d6a418211cc2962e20fe47feb3edf 330f2c603a9d48c0fcb5699dbfe58964 25c5bac4aee82e57a85aaf4e2513e4f0 5796b07ba2ee47d80506f8d2c25e50fd 14de71e6c418559302f939b0e1abd576 f279c4b2e0feb85c1f28ff18f58891ff ef132eef2fa09346aee33c28eb130ff2 8f5b766953334113211996d20011a198 e3fc433f9f2541010ae17c1bf202580f 6047472fb36857fe843b19f5984009dd c324044e847a4f4a0ab34f719595de37 252d6235365e9b84392b061085349d73 203a4a13e96f5432ec0fd4a1ee65accd d5e3904df54c1da510b0ff20dcc0c77f cb2c0e0eb605cb0504db87632cf3d8b4 dae6e705769d1de354270123cb11450e fc60ac47683d7b8d0f811365565fd98c 4c8eb936bcab8d069fc33bd801b03ade a2e1fbc5aa463d08ca19896d2bf59a07 1b851e6c239052172f296bfb5e724047 90a2181014f3b94a4e97d117b4381303 68cc39dbb2d198065ae3986547926cd2 162f40a29f0c3c8745c0f50fba3852e5 66d44575c29d39a03f0cda721984b6f4 40591f355e12d439ff150aab7613499d bd49adabc8676eef023b15b65bfc5ca0 6948109f23f350db82123535eb8a7433 bdabcb909271a6ecbcb58b936a88cd4e 8f2e6ff5800175f113253d8fa9ca8885 c2f552e657dc603f252e1a8e308f76f0 be79e2fb8f5d5fbbe2e30ecadd220723 c8c0aea8078cdfcb3868263ff8f09400 54da48781893a7e49ad5aff4af300cd8 04a6b6279ab3ff3afb64491c85194aab 760d58a606654f9f4400e8b38591356f bf6425aca26dc85244259ff2b19c41b9 f96f3ca9ec1dde434da7d2d392b905dd f3d1f9af93d1af5950bd493f5aa731b4 056df31bd267b6b90a079831aaf579be 0a39013137aac6d404f518cfd4684064 7e78bfe706ca4cf5e9c5453e9f7cfd2b 8b4c8d169a44e55c88d4a9a7f9474241 e221af44860018ab0856972e194cd934"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("d3709a50c4088394c8f03e5157080000449e00000002"), + splitHexString("060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), + splitHexString("23b8e610589c83c92d0e97eb7a6e5003"), + byte(0xdd), + splitHexString("4391d848"), + splitHexString("dd709a50c4088394c8f03e5157080000 449e4391d84823b8e610589c83c92d0e 97eb7a6e5003f57764c5c7f0095ba54b 90818f1bfeecc1c97c54fc731edbd2a2 44e3b1e639a9bc75ed545b98649343b2 53615ec6b3e4df0fd2e7fe9d691a09e6 a144b436d8a2c088a404262340dfd995 ec3865694e3026ecd8c6d2561a5a3667 2a1005018168c0f081c10e2bf14d550c 977e28bb9a759c57d0f7ffb1cdfb40bd 774dec589657542047dffefa56fc8089 a4d1ef379c81ba3df71a05ddc7928340 775910feb3ce4cbcfd8d253edd05f161 458f9dc44bea017c3117cca7065a315d eda9464e672ec80c3f79ac993437b441 ef74227ecc4dc9d597f66ab0ab8d214b 55840c70349d7616cbe38e5e1d052d07 f1fedb3dd3c4d8ce295724945e67ed2e efcd9fb52472387f318e3d9d233be7df c79d6bf6080dcbbb41feb180d7858849 7c3e439d38c334748d2b56fd19ab364d 057a9bd5a699ae145d7fdbc8f5777518 1b0a97c3bdedc91a555d6c9b8634e106 d8c9ca45a9d5450a7679edc545da9102 5bc93a7cf9a023a066ffadb9717ffaf3 414c3b646b5738b3cc4116502d18d79d 8227436306d9b2b3afc6c785ce3c817f eb703a42b9c83b59f0dcef1245d0b3e4 0299821ec19549ce489714fe2611e72c d882f4f70dce7d3671296fc045af5c9f 630d7b49a3eb821bbca60f1984dce664 91713bfe06001a56f51bb3abe92f7960 547c4d0a70f4a962b3f05dc25a34bbe8 30a7ea4736d3b0161723500d82beda9b e3327af2aa413821ff678b2a876ec4b0 0bb605ffcc3917ffdc279f187daa2fce 8cde121980bba8ec8f44ca562b0f1319 14c901cfbd847408b778e6738c7bb5b1 b3f97d01b0a24dcca40e3bed29411b1b a8f60843c4a241021b23132b9500509b 9a3516d4a9dd41d3bacbcd426b451393 521828afedcf20fa46ac24f44a8e2973 30b16705d5d5f798eff9e9134a065979 87a1db4617caa2d93837730829d4d89e 16413be4d8a8a38a7e6226623b64a820 178ec3a66954e10710e043ae73dd3fb2 715a0525a46343fb7590e5eac7ee55fc 810e0d8b4b8f7be82cd5a214575a1b99 629d47a9b281b61348c8627cab38e2a6 4db6626e97bb8f77bdcb0fee476aedd7 ba8f5441acaab00f4432edab3791047d 9091b2a753f035648431f6d12f7d6a68 1e64c861f4ac911a0f7d6ec0491a78c9 f192f96b3a5e7560a3f056bc1ca85983 67ad6acb6f2e034c7f37beeb9ed470c4 304af0107f0eb919be36a86f68f37fa6 1dae7aff14decd67ec3157a11488a14f ed0142828348f5f608b0fe03e1f3c0af 3acca0ce36852ed42e220ae9abf8f890 6f00f1b86bff8504c8f16c784fd52d25 e013ff4fda903e9e1eb453c1464b1196 6db9b28e8f26a3fc419e6a60a48d4c72 14ee9c6c6a12b68a32cac8f61580c64f 29cb6922408783c6d12e725b014fe485 cd17e484c5952bf99bc94941d4b1919d 04317b8aa1bd3754ecbaa10ec227de85 40695bf2fb8ee56f6dc526ef366625b9 1aa4970b6ffa5c8284b9b5ab852b905f 9d83f5669c0535bc377bcc05ad5e48e2 81ec0e1917ca3c6a471f8da0894bc82a c2a8965405d6eef3b5e293a88fda203f 09bdc72757b107ab14880eaa3ef7045b 580f4821ce6dd325b5a90655d8c5b55f 76fb846279a9b518c5e9b9a21165c509 3ed49baaacadf1f21873266c767f6769"), + ), + ) + + DescribeTable("encrypts the server's Initial", + func(v protocol.VersionNumber, header, data, expectedSample, expectedHdr, expectedPacket []byte) { + sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveServer, v) + sealed := sealer.Seal(nil, data, 1, header) + sample := sealed[2 : 2+16] + Expect(sample).To(Equal(expectedSample)) + sealer.EncryptHeader(sample, &header[0], header[len(header)-2:]) + Expect(header).To(Equal(expectedHdr)) + packet := append(header, sealed...) + Expect(packet).To(Equal(expectedPacket)) + }, + Entry("draft 29", + protocol.VersionDraft29, + splitHexString("c1ff00001d0008f067a5502a4262b50040740001"), + splitHexString("0d0000000018410a020000560303eefc e7f7b37ba1d1632e96677825ddf73988 cfc79825df566dc5430b9a045a120013 0100002e00330024001d00209d3c940d 89690b84d08a60993c144eca684d1081 287c834d5311bcf32bb9da1a002b0002 0304"), + splitHexString("823a5d3a1207c86ee49132824f046524"), + splitHexString("caff00001d0008f067a5502a4262b5004074aaf2"), + splitHexString("caff00001d0008f067a5502a4262b500 4074aaf2f007823a5d3a1207c86ee491 32824f0465243d082d868b107a38092b c80528664cbf9456ebf27673fb5fa506 1ab573c9f001b81da028a00d52ab00b1 5bebaa70640e106cf2acd043e9c6b441 1c0a79637134d8993701fe779e58c2fe 753d14b0564021565ea92e57bc6faf56 dfc7a40870e6"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("c1000000010008f067a5502a4262b50040750001"), + splitHexString("02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), + splitHexString("2cd0991cd25b0aac406a5816b6394100"), + splitHexString("cf000000010008f067a5502a4262b5004075c0d9"), + splitHexString("cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a 5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3 dbcba3f6ea46c5b7684df3548e7ddeb9 c3bf9c73cc3f3bded74b562bfb19fb84 022f8ef4cdd93795d77d06edbb7aaf2f 58891850abbdca3d20398c276456cbc4 2158407dd074ee"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("d1709a50c40008f067a5502a4262b50040750001"), + splitHexString("02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), + splitHexString("ebb7972fdce59d50e7e49ff2a7e8de76"), + splitHexString("d0709a50c40008f067a5502a4262b5004075103e"), + splitHexString("d0709a50c40008f067a5502a4262b500 4075103e63b4ebb7972fdce59d50e7e4 9ff2a7e8de76b0cd8c10100a1f13d549 dd6fe801588fb14d279bef8d7c53ef62 66a9a7a1a5f2fa026c236a5bf8df5aa0 f9d74773aeccfffe910b0f76814b5e33 f7b7f8ec278d23fd8c7a9e66856b8bbe 72558135bca27c54d63fcc902253461c fc089d4e6b9b19"), + ), + ) + + for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { + v := ver + + Context(fmt.Sprintf("using version %s", v), func() { + It("seals and opens", func() { + connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef} + clientSealer, clientOpener := NewInitialAEAD(connectionID, protocol.PerspectiveClient, v) + serverSealer, serverOpener := NewInitialAEAD(connectionID, protocol.PerspectiveServer, v) + + clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) + m, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad")) + Expect(err).ToNot(HaveOccurred()) + Expect(m).To(Equal([]byte("foobar"))) + serverMessage := serverSealer.Seal(nil, []byte("raboof"), 99, []byte("daa")) + m, err = clientOpener.Open(nil, serverMessage, 99, []byte("daa")) + Expect(err).ToNot(HaveOccurred()) + Expect(m).To(Equal([]byte("raboof"))) + }) + + It("doesn't work if initialized with different connection IDs", func() { + c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1} + c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2} + clientSealer, _ := NewInitialAEAD(c1, protocol.PerspectiveClient, v) + _, serverOpener := NewInitialAEAD(c2, protocol.PerspectiveServer, v) + + clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) + _, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("encrypts und decrypts the header", func() { + connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + clientSealer, clientOpener := NewInitialAEAD(connID, protocol.PerspectiveClient, v) + serverSealer, serverOpener := NewInitialAEAD(connID, protocol.PerspectiveServer, v) + + // the first byte and the last 4 bytes should be encrypted + header := []byte{0x5e, 0, 1, 2, 3, 4, 0xde, 0xad, 0xbe, 0xef} + sample := make([]byte, 16) + rand.Read(sample) + clientSealer.EncryptHeader(sample, &header[0], header[6:10]) + // only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified + Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + serverOpener.DecryptHeader(sample, &header[0], header[6:10]) + Expect(header[0]).To(Equal(byte(0x5e))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + + serverSealer.EncryptHeader(sample, &header[0], header[6:10]) + // only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified + Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + clientOpener.DecryptHeader(sample, &header[0], header[6:10]) + Expect(header[0]).To(Equal(byte(0x5e))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + }) + }) + } +}) diff --git a/internal/quic-go/handshake/interface.go b/internal/quic-go/handshake/interface.go new file mode 100644 index 00000000..43ed0236 --- /dev/null +++ b/internal/quic-go/handshake/interface.go @@ -0,0 +1,102 @@ +package handshake + +import ( + "errors" + "io" + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qtls" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +var ( + // ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level, + // but the corresponding opener has not yet been initialized + // This can happen when packets arrive out of order. + ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available") + // ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level, + // but the corresponding keys have already been dropped. + ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped") + // ErrDecryptionFailed is returned when the AEAD fails to open the packet. + ErrDecryptionFailed = errors.New("decryption failed") +) + +// ConnectionState contains information about the state of the connection. +type ConnectionState = qtls.ConnectionState + +type headerDecryptor interface { + DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) +} + +// LongHeaderOpener opens a long header packet +type LongHeaderOpener interface { + headerDecryptor + DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber + Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error) +} + +// ShortHeaderOpener opens a short header packet +type ShortHeaderOpener interface { + headerDecryptor + DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber + Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error) +} + +// LongHeaderSealer seals a long header packet +type LongHeaderSealer interface { + Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte + EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) + Overhead() int +} + +// ShortHeaderSealer seals a short header packet +type ShortHeaderSealer interface { + LongHeaderSealer + KeyPhase() protocol.KeyPhaseBit +} + +// A tlsExtensionHandler sends and received the QUIC TLS extension. +type tlsExtensionHandler interface { + GetExtensions(msgType uint8) []qtls.Extension + ReceivedExtensions(msgType uint8, exts []qtls.Extension) + TransportParameters() <-chan []byte +} + +type handshakeRunner interface { + OnReceivedParams(*wire.TransportParameters) + OnHandshakeComplete() + OnError(error) + DropKeys(protocol.EncryptionLevel) +} + +// CryptoSetup handles the handshake and protecting / unprotecting packets +type CryptoSetup interface { + RunHandshake() + io.Closer + ChangeConnectionID(protocol.ConnectionID) + GetSessionTicket() ([]byte, error) + + HandleMessage([]byte, protocol.EncryptionLevel) bool + SetLargest1RTTAcked(protocol.PacketNumber) error + SetHandshakeConfirmed() + ConnectionState() ConnectionState + + GetInitialOpener() (LongHeaderOpener, error) + GetHandshakeOpener() (LongHeaderOpener, error) + Get0RTTOpener() (LongHeaderOpener, error) + Get1RTTOpener() (ShortHeaderOpener, error) + + GetInitialSealer() (LongHeaderSealer, error) + GetHandshakeSealer() (LongHeaderSealer, error) + Get0RTTSealer() (LongHeaderSealer, error) + Get1RTTSealer() (ShortHeaderSealer, error) +} + +// ConnWithVersion is the connection used in the ClientHelloInfo. +// It can be used to determine the QUIC version in use. +type ConnWithVersion interface { + net.Conn + GetQUICVersion() protocol.VersionNumber +} diff --git a/internal/quic-go/handshake/mock_handshake_runner_test.go b/internal/quic-go/handshake/mock_handshake_runner_test.go new file mode 100644 index 00000000..eae6a898 --- /dev/null +++ b/internal/quic-go/handshake/mock_handshake_runner_test.go @@ -0,0 +1,84 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: interface.go + +// Package handshake is a generated GoMock package. +package handshake + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockHandshakeRunner is a mock of HandshakeRunner interface. +type MockHandshakeRunner struct { + ctrl *gomock.Controller + recorder *MockHandshakeRunnerMockRecorder +} + +// MockHandshakeRunnerMockRecorder is the mock recorder for MockHandshakeRunner. +type MockHandshakeRunnerMockRecorder struct { + mock *MockHandshakeRunner +} + +// NewMockHandshakeRunner creates a new mock instance. +func NewMockHandshakeRunner(ctrl *gomock.Controller) *MockHandshakeRunner { + mock := &MockHandshakeRunner{ctrl: ctrl} + mock.recorder = &MockHandshakeRunnerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockHandshakeRunner) EXPECT() *MockHandshakeRunnerMockRecorder { + return m.recorder +} + +// DropKeys mocks base method. +func (m *MockHandshakeRunner) DropKeys(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropKeys", arg0) +} + +// DropKeys indicates an expected call of DropKeys. +func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0) +} + +// OnError mocks base method. +func (m *MockHandshakeRunner) OnError(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnError", arg0) +} + +// OnError indicates an expected call of OnError. +func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0) +} + +// OnHandshakeComplete mocks base method. +func (m *MockHandshakeRunner) OnHandshakeComplete() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnHandshakeComplete") +} + +// OnHandshakeComplete indicates an expected call of OnHandshakeComplete. +func (mr *MockHandshakeRunnerMockRecorder) OnHandshakeComplete() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnHandshakeComplete", reflect.TypeOf((*MockHandshakeRunner)(nil).OnHandshakeComplete)) +} + +// OnReceivedParams mocks base method. +func (m *MockHandshakeRunner) OnReceivedParams(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnReceivedParams", arg0) +} + +// OnReceivedParams indicates an expected call of OnReceivedParams. +func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0) +} diff --git a/internal/quic-go/handshake/mockgen.go b/internal/quic-go/handshake/mockgen.go new file mode 100644 index 00000000..b5534225 --- /dev/null +++ b/internal/quic-go/handshake/mockgen.go @@ -0,0 +1,3 @@ +package handshake + +//go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/imroc/req/v3/internal/quic-go/handshake handshakeRunner" diff --git a/internal/quic-go/handshake/retry.go b/internal/quic-go/handshake/retry.go new file mode 100644 index 00000000..a9906086 --- /dev/null +++ b/internal/quic-go/handshake/retry.go @@ -0,0 +1,62 @@ +package handshake + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "fmt" + "sync" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +var ( + oldRetryAEAD cipher.AEAD // used for QUIC draft versions up to 34 + retryAEAD cipher.AEAD // used for QUIC draft-34 +) + +func init() { + oldRetryAEAD = initAEAD([16]byte{0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1}) + retryAEAD = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}) +} + +func initAEAD(key [16]byte) cipher.AEAD { + aes, err := aes.NewCipher(key[:]) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + return aead +} + +var ( + retryBuf bytes.Buffer + retryMutex sync.Mutex + oldRetryNonce = [12]byte{0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c} + retryNonce = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb} +) + +// GetRetryIntegrityTag calculates the integrity tag on a Retry packet +func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte { + retryMutex.Lock() + retryBuf.WriteByte(uint8(origDestConnID.Len())) + retryBuf.Write(origDestConnID.Bytes()) + retryBuf.Write(retry) + + var tag [16]byte + var sealed []byte + if version != protocol.Version1 { + sealed = oldRetryAEAD.Seal(tag[:0], oldRetryNonce[:], nil, retryBuf.Bytes()) + } else { + sealed = retryAEAD.Seal(tag[:0], retryNonce[:], nil, retryBuf.Bytes()) + } + if len(sealed) != 16 { + panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed))) + } + retryBuf.Reset() + retryMutex.Unlock() + return &tag +} diff --git a/internal/quic-go/handshake/retry_test.go b/internal/quic-go/handshake/retry_test.go new file mode 100644 index 00000000..fdb3ff75 --- /dev/null +++ b/internal/quic-go/handshake/retry_test.go @@ -0,0 +1,36 @@ +package handshake + +import ( + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Retry Integrity Check", func() { + It("calculates retry integrity tags", func() { + fooTag := GetRetryIntegrityTag([]byte("foo"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) + barTag := GetRetryIntegrityTag([]byte("bar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) + Expect(fooTag).ToNot(BeNil()) + Expect(barTag).ToNot(BeNil()) + Expect(*fooTag).ToNot(Equal(*barTag)) + }) + + It("includes the original connection ID in the tag calculation", func() { + t1 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.Version1) + t2 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{4, 3, 2, 1}, protocol.Version1) + Expect(*t1).ToNot(Equal(*t2)) + }) + + It("uses the test vector from the draft, for old draft versions", func() { + connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + data := splitHexString("ffff00001d0008f067a5502a4262b574 6f6b656ed16926d81f6f9ca2953a8aa4 575e1e49") + Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.VersionDraft29)[:]).To(Equal(data[len(data)-16:])) + }) + + It("uses the test vector from the draft, for version 1", func() { + connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + data := splitHexString("ff000000010008f067a5502a4262b574 6f6b656e04a265ba2eff4d829058fb3f 0f2496ba") + Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.Version1)[:]).To(Equal(data[len(data)-16:])) + }) +}) diff --git a/internal/quic-go/handshake/session_ticket.go b/internal/quic-go/handshake/session_ticket.go new file mode 100644 index 00000000..afefe8a7 --- /dev/null +++ b/internal/quic-go/handshake/session_ticket.go @@ -0,0 +1,48 @@ +package handshake + +import ( + "bytes" + "errors" + "fmt" + "time" + + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +const sessionTicketRevision = 2 + +type sessionTicket struct { + Parameters *wire.TransportParameters + RTT time.Duration // to be encoded in mus +} + +func (t *sessionTicket) Marshal() []byte { + b := &bytes.Buffer{} + quicvarint.Write(b, sessionTicketRevision) + quicvarint.Write(b, uint64(t.RTT.Microseconds())) + t.Parameters.MarshalForSessionTicket(b) + return b.Bytes() +} + +func (t *sessionTicket) Unmarshal(b []byte) error { + r := bytes.NewReader(b) + rev, err := quicvarint.Read(r) + if err != nil { + return errors.New("failed to read session ticket revision") + } + if rev != sessionTicketRevision { + return fmt.Errorf("unknown session ticket revision: %d", rev) + } + rtt, err := quicvarint.Read(r) + if err != nil { + return errors.New("failed to read RTT") + } + var tp wire.TransportParameters + if err := tp.UnmarshalFromSessionTicket(r); err != nil { + return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) + } + t.Parameters = &tp + t.RTT = time.Duration(rtt) * time.Microsecond + return nil +} diff --git a/internal/quic-go/handshake/session_ticket_test.go b/internal/quic-go/handshake/session_ticket_test.go new file mode 100644 index 00000000..832def9d --- /dev/null +++ b/internal/quic-go/handshake/session_ticket_test.go @@ -0,0 +1,54 @@ +package handshake + +import ( + "bytes" + "time" + + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Session Ticket", func() { + It("marshals and unmarshals a session ticket", func() { + ticket := &sessionTicket{ + Parameters: &wire.TransportParameters{ + InitialMaxStreamDataBidiLocal: 1, + InitialMaxStreamDataBidiRemote: 2, + }, + RTT: 1337 * time.Microsecond, + } + var t sessionTicket + Expect(t.Unmarshal(ticket.Marshal())).To(Succeed()) + Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1)) + Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2)) + Expect(t.RTT).To(Equal(1337 * time.Microsecond)) + }) + + It("refuses to unmarshal if the ticket is too short for the revision", func() { + Expect((&sessionTicket{}).Unmarshal([]byte{})).To(MatchError("failed to read session ticket revision")) + }) + + It("refuses to unmarshal if the revision doesn't match", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, 1337) + Expect((&sessionTicket{}).Unmarshal(b.Bytes())).To(MatchError("unknown session ticket revision: 1337")) + }) + + It("refuses to unmarshal if the RTT cannot be read", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, sessionTicketRevision) + Expect((&sessionTicket{}).Unmarshal(b.Bytes())).To(MatchError("failed to read RTT")) + }) + + It("refuses to unmarshal if unmarshaling the transport parameters fails", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, sessionTicketRevision) + b.Write([]byte("foobar")) + err := (&sessionTicket{}).Unmarshal(b.Bytes()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("unmarshaling transport parameters from session ticket failed")) + }) +}) diff --git a/internal/quic-go/handshake/tls_extension_handler.go b/internal/quic-go/handshake/tls_extension_handler.go new file mode 100644 index 00000000..245f27c8 --- /dev/null +++ b/internal/quic-go/handshake/tls_extension_handler.go @@ -0,0 +1,68 @@ +package handshake + +import ( + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qtls" +) + +const ( + quicTLSExtensionTypeOldDrafts = 0xffa5 + quicTLSExtensionType = 0x39 +) + +type extensionHandler struct { + ourParams []byte + paramsChan chan []byte + + extensionType uint16 + + perspective protocol.Perspective +} + +var _ tlsExtensionHandler = &extensionHandler{} + +// newExtensionHandler creates a new extension handler +func newExtensionHandler(params []byte, pers protocol.Perspective, v protocol.VersionNumber) tlsExtensionHandler { + et := uint16(quicTLSExtensionType) + if v != protocol.Version1 { + et = quicTLSExtensionTypeOldDrafts + } + return &extensionHandler{ + ourParams: params, + paramsChan: make(chan []byte), + perspective: pers, + extensionType: et, + } +} + +func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension { + if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) || + (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) { + return nil + } + return []qtls.Extension{{ + Type: h.extensionType, + Data: h.ourParams, + }} +} + +func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) { + if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) || + (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) { + return + } + + var data []byte + for _, ext := range exts { + if ext.Type == h.extensionType { + data = ext.Data + break + } + } + + h.paramsChan <- data +} + +func (h *extensionHandler) TransportParameters() <-chan []byte { + return h.paramsChan +} diff --git a/internal/quic-go/handshake/tls_extension_handler_test.go b/internal/quic-go/handshake/tls_extension_handler_test.go new file mode 100644 index 00000000..4fcd48c1 --- /dev/null +++ b/internal/quic-go/handshake/tls_extension_handler_test.go @@ -0,0 +1,210 @@ +package handshake + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qtls" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("TLS Extension Handler, for the server", func() { + var ( + handlerServer tlsExtensionHandler + handlerClient tlsExtensionHandler + version protocol.VersionNumber + ) + + BeforeEach(func() { + version = protocol.VersionDraft29 + }) + + JustBeforeEach(func() { + handlerServer = newExtensionHandler( + []byte("foobar"), + protocol.PerspectiveServer, + version, + ) + handlerClient = newExtensionHandler( + []byte("raboof"), + protocol.PerspectiveClient, + version, + ) + }) + + Context("for the server", func() { + for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1} { + v := ver + + Context(fmt.Sprintf("sending, for version %s", v), func() { + var extensionType uint16 + + BeforeEach(func() { + version = v + if v == protocol.VersionDraft29 { + extensionType = quicTLSExtensionTypeOldDrafts + } else { + extensionType = quicTLSExtensionType + } + }) + + It("only adds TransportParameters for the Encrypted Extensions", func() { + // test 2 other handshake types + Expect(handlerServer.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) + Expect(handlerServer.GetExtensions(uint8(typeFinished))).To(BeEmpty()) + }) + + It("adds TransportParameters to the EncryptedExtensions message", func() { + exts := handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) + Expect(exts).To(HaveLen(1)) + Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) + Expect(exts[0].Data).To(Equal([]byte("foobar"))) + }) + }) + } + + Context("receiving", func() { + var chExts []qtls.Extension + + JustBeforeEach(func() { + chExts = handlerClient.GetExtensions(uint8(typeClientHello)) + Expect(chExts).To(HaveLen(1)) + }) + + It("sends the extension on the channel", func() { + go func() { + defer GinkgoRecover() + handlerServer.ReceivedExtensions(uint8(typeClientHello), chExts) + }() + + var data []byte + Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) + Expect(data).To(Equal([]byte("raboof"))) + }) + + It("sends nil on the channel if the extension is missing", func() { + go func() { + defer GinkgoRecover() + handlerServer.ReceivedExtensions(uint8(typeClientHello), nil) + }() + + var data []byte + Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) + Expect(data).To(BeEmpty()) + }) + + It("ignores extensions with different code points", func() { + go func() { + defer GinkgoRecover() + exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} + handlerServer.ReceivedExtensions(uint8(typeClientHello), exts) + }() + + var data []byte + Eventually(handlerServer.TransportParameters()).Should(Receive()) + Expect(data).To(BeEmpty()) + }) + + It("ignores extensions that are not sent with the ClientHello", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handlerServer.ReceivedExtensions(uint8(typeFinished), chExts) + close(done) + }() + + Consistently(handlerServer.TransportParameters()).ShouldNot(Receive()) + Eventually(done).Should(BeClosed()) + }) + }) + }) + + Context("for the client", func() { + for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1} { + v := ver + + Context(fmt.Sprintf("sending, for version %s", v), func() { + var extensionType uint16 + + BeforeEach(func() { + version = v + if v == protocol.VersionDraft29 { + extensionType = quicTLSExtensionTypeOldDrafts + } else { + extensionType = quicTLSExtensionType + } + }) + + It("only adds TransportParameters for the Encrypted Extensions", func() { + // test 2 other handshake types + Expect(handlerClient.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) + Expect(handlerClient.GetExtensions(uint8(typeFinished))).To(BeEmpty()) + }) + + It("adds TransportParameters to the ClientHello message", func() { + exts := handlerClient.GetExtensions(uint8(typeClientHello)) + Expect(exts).To(HaveLen(1)) + Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) + Expect(exts[0].Data).To(Equal([]byte("raboof"))) + }) + }) + } + + Context("receiving", func() { + var chExts []qtls.Extension + + JustBeforeEach(func() { + chExts = handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) + Expect(chExts).To(HaveLen(1)) + }) + + It("sends the extension on the channel", func() { + go func() { + defer GinkgoRecover() + handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), chExts) + }() + + var data []byte + Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) + Expect(data).To(Equal([]byte("foobar"))) + }) + + It("sends nil on the channel if the extension is missing", func() { + go func() { + defer GinkgoRecover() + handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), nil) + }() + + var data []byte + Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) + Expect(data).To(BeEmpty()) + }) + + It("ignores extensions with different code points", func() { + go func() { + defer GinkgoRecover() + exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} + handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), exts) + }() + + var data []byte + Eventually(handlerClient.TransportParameters()).Should(Receive()) + Expect(data).To(BeEmpty()) + }) + + It("ignores extensions that are not sent with the EncryptedExtensions", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handlerClient.ReceivedExtensions(uint8(typeFinished), chExts) + close(done) + }() + + Consistently(handlerClient.TransportParameters()).ShouldNot(Receive()) + Eventually(done).Should(BeClosed()) + }) + }) + }) +}) diff --git a/internal/quic-go/handshake/token_generator.go b/internal/quic-go/handshake/token_generator.go new file mode 100644 index 00000000..3dcfa090 --- /dev/null +++ b/internal/quic-go/handshake/token_generator.go @@ -0,0 +1,134 @@ +package handshake + +import ( + "encoding/asn1" + "fmt" + "io" + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +const ( + tokenPrefixIP byte = iota + tokenPrefixString +) + +// A Token is derived from the client address and can be used to verify the ownership of this address. +type Token struct { + IsRetryToken bool + RemoteAddr string + SentTime time.Time + // only set for retry tokens + OriginalDestConnectionID protocol.ConnectionID + RetrySrcConnectionID protocol.ConnectionID +} + +// token is the struct that is used for ASN1 serialization and deserialization +type token struct { + IsRetryToken bool + RemoteAddr []byte + Timestamp int64 + OriginalDestConnectionID []byte + RetrySrcConnectionID []byte +} + +// A TokenGenerator generates tokens +type TokenGenerator struct { + tokenProtector tokenProtector +} + +// NewTokenGenerator initializes a new TookenGenerator +func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) { + tokenProtector, err := newTokenProtector(rand) + if err != nil { + return nil, err + } + return &TokenGenerator{ + tokenProtector: tokenProtector, + }, nil +} + +// NewRetryToken generates a new token for a Retry for a given source address +func (g *TokenGenerator) NewRetryToken( + raddr net.Addr, + origDestConnID protocol.ConnectionID, + retrySrcConnID protocol.ConnectionID, +) ([]byte, error) { + data, err := asn1.Marshal(token{ + IsRetryToken: true, + RemoteAddr: encodeRemoteAddr(raddr), + OriginalDestConnectionID: origDestConnID, + RetrySrcConnectionID: retrySrcConnID, + Timestamp: time.Now().UnixNano(), + }) + if err != nil { + return nil, err + } + return g.tokenProtector.NewToken(data) +} + +// NewToken generates a new token to be sent in a NEW_TOKEN frame +func (g *TokenGenerator) NewToken(raddr net.Addr) ([]byte, error) { + data, err := asn1.Marshal(token{ + RemoteAddr: encodeRemoteAddr(raddr), + Timestamp: time.Now().UnixNano(), + }) + if err != nil { + return nil, err + } + return g.tokenProtector.NewToken(data) +} + +// DecodeToken decodes a token +func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) { + // if the client didn't send any token, DecodeToken will be called with a nil-slice + if len(encrypted) == 0 { + return nil, nil + } + + data, err := g.tokenProtector.DecodeToken(encrypted) + if err != nil { + return nil, err + } + t := &token{} + rest, err := asn1.Unmarshal(data, t) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) + } + token := &Token{ + IsRetryToken: t.IsRetryToken, + RemoteAddr: decodeRemoteAddr(t.RemoteAddr), + SentTime: time.Unix(0, t.Timestamp), + } + if t.IsRetryToken { + token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) + token.RetrySrcConnectionID = protocol.ConnectionID(t.RetrySrcConnectionID) + } + return token, nil +} + +// encodeRemoteAddr encodes a remote address such that it can be saved in the token +func encodeRemoteAddr(remoteAddr net.Addr) []byte { + if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { + return append([]byte{tokenPrefixIP}, udpAddr.IP...) + } + return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...) +} + +// decodeRemoteAddr decodes the remote address saved in the token +func decodeRemoteAddr(data []byte) string { + // data will never be empty for a token that we generated. + // Check it to be on the safe side + if len(data) == 0 { + return "" + } + if data[0] == tokenPrefixIP { + return net.IP(data[1:]).String() + } + return string(data[1:]) +} diff --git a/internal/quic-go/handshake/token_generator_test.go b/internal/quic-go/handshake/token_generator_test.go new file mode 100644 index 00000000..a1a22ee1 --- /dev/null +++ b/internal/quic-go/handshake/token_generator_test.go @@ -0,0 +1,127 @@ +package handshake + +import ( + "crypto/rand" + "encoding/asn1" + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Token Generator", func() { + var tokenGen *TokenGenerator + + BeforeEach(func() { + var err error + tokenGen, err = NewTokenGenerator(rand.Reader) + Expect(err).ToNot(HaveOccurred()) + }) + + It("generates a token", func() { + ip := net.IPv4(127, 0, 0, 1) + token, err := tokenGen.NewRetryToken(&net.UDPAddr{IP: ip, Port: 1337}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(token).ToNot(BeEmpty()) + }) + + It("works with nil tokens", func() { + token, err := tokenGen.DecodeToken(nil) + Expect(err).ToNot(HaveOccurred()) + Expect(token).To(BeNil()) + }) + + It("accepts a valid token", func() { + ip := net.IPv4(192, 168, 0, 1) + tokenEnc, err := tokenGen.NewRetryToken( + &net.UDPAddr{IP: ip, Port: 1337}, + nil, + nil, + ) + Expect(err).ToNot(HaveOccurred()) + token, err := tokenGen.DecodeToken(tokenEnc) + Expect(err).ToNot(HaveOccurred()) + Expect(token.RemoteAddr).To(Equal("192.168.0.1")) + Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) + Expect(token.OriginalDestConnectionID.Len()).To(BeZero()) + Expect(token.RetrySrcConnectionID.Len()).To(BeZero()) + }) + + It("saves the connection ID", func() { + tokenEnc, err := tokenGen.NewRetryToken( + &net.UDPAddr{}, + protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + ) + Expect(err).ToNot(HaveOccurred()) + token, err := tokenGen.DecodeToken(tokenEnc) + Expect(err).ToNot(HaveOccurred()) + Expect(token.OriginalDestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) + Expect(token.RetrySrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) + }) + + It("rejects invalid tokens", func() { + _, err := tokenGen.DecodeToken([]byte("invalid token")) + Expect(err).To(HaveOccurred()) + }) + + It("rejects tokens that cannot be decoded", func() { + token, err := tokenGen.tokenProtector.NewToken([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + _, err = tokenGen.DecodeToken(token) + Expect(err).To(HaveOccurred()) + }) + + It("rejects tokens that can be decoded, but have additional payload", func() { + t, err := asn1.Marshal(token{RemoteAddr: []byte("foobar")}) + Expect(err).ToNot(HaveOccurred()) + t = append(t, []byte("rest")...) + enc, err := tokenGen.tokenProtector.NewToken(t) + Expect(err).ToNot(HaveOccurred()) + _, err = tokenGen.DecodeToken(enc) + Expect(err).To(MatchError("rest when unpacking token: 4")) + }) + + // we don't generate tokens that have no data, but we should be able to handle them if we receive one for whatever reason + It("doesn't panic if a tokens has no data", func() { + t, err := asn1.Marshal(token{RemoteAddr: []byte("")}) + Expect(err).ToNot(HaveOccurred()) + enc, err := tokenGen.tokenProtector.NewToken(t) + Expect(err).ToNot(HaveOccurred()) + _, err = tokenGen.DecodeToken(enc) + Expect(err).ToNot(HaveOccurred()) + }) + + It("works with an IPv6 addresses ", func() { + addresses := []string{ + "2001:db8::68", + "2001:0000:4136:e378:8000:63bf:3fff:fdd2", + "2001::1", + "ff01:0:0:0:0:0:0:2", + } + for _, addr := range addresses { + ip := net.ParseIP(addr) + Expect(ip).ToNot(BeNil()) + raddr := &net.UDPAddr{IP: ip, Port: 1337} + tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) + Expect(err).ToNot(HaveOccurred()) + token, err := tokenGen.DecodeToken(tokenEnc) + Expect(err).ToNot(HaveOccurred()) + Expect(token.RemoteAddr).To(Equal(ip.String())) + Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) + } + }) + + It("uses the string representation an address that is not a UDP address", func() { + raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} + tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) + Expect(err).ToNot(HaveOccurred()) + token, err := tokenGen.DecodeToken(tokenEnc) + Expect(err).ToNot(HaveOccurred()) + Expect(token.RemoteAddr).To(Equal("192.168.13.37:1337")) + Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) + }) +}) diff --git a/internal/quic-go/handshake/token_protector.go b/internal/quic-go/handshake/token_protector.go new file mode 100644 index 00000000..650f230b --- /dev/null +++ b/internal/quic-go/handshake/token_protector.go @@ -0,0 +1,89 @@ +package handshake + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "fmt" + "io" + + "golang.org/x/crypto/hkdf" +) + +// TokenProtector is used to create and verify a token +type tokenProtector interface { + // NewToken creates a new token + NewToken([]byte) ([]byte, error) + // DecodeToken decodes a token + DecodeToken([]byte) ([]byte, error) +} + +const ( + tokenSecretSize = 32 + tokenNonceSize = 32 +) + +// tokenProtector is used to create and verify a token +type tokenProtectorImpl struct { + rand io.Reader + secret []byte +} + +// newTokenProtector creates a source for source address tokens +func newTokenProtector(rand io.Reader) (tokenProtector, error) { + secret := make([]byte, tokenSecretSize) + if _, err := rand.Read(secret); err != nil { + return nil, err + } + return &tokenProtectorImpl{ + rand: rand, + secret: secret, + }, nil +} + +// NewToken encodes data into a new token. +func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { + nonce := make([]byte, tokenNonceSize) + if _, err := s.rand.Read(nonce); err != nil { + return nil, err + } + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil +} + +// DecodeToken decodes a token. +func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) { + if len(p) < tokenNonceSize { + return nil, fmt.Errorf("token too short: %d", len(p)) + } + nonce := p[:tokenNonceSize] + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil) +} + +func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { + h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source")) + key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 + if _, err := io.ReadFull(h, key); err != nil { + return nil, nil, err + } + aeadNonce := make([]byte, 12) + if _, err := io.ReadFull(h, aeadNonce); err != nil { + return nil, nil, err + } + c, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, nil, err + } + return aead, aeadNonce, nil +} diff --git a/internal/quic-go/handshake/token_protector_test.go b/internal/quic-go/handshake/token_protector_test.go new file mode 100644 index 00000000..7171e865 --- /dev/null +++ b/internal/quic-go/handshake/token_protector_test.go @@ -0,0 +1,67 @@ +package handshake + +import ( + "crypto/rand" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type zeroReader struct{} + +func (r *zeroReader) Read(b []byte) (int, error) { + for i := range b { + b[i] = 0 + } + return len(b), nil +} + +var _ = Describe("Token Protector", func() { + var tp tokenProtector + + BeforeEach(func() { + var err error + tp, err = newTokenProtector(rand.Reader) + Expect(err).ToNot(HaveOccurred()) + }) + + It("uses the random source", func() { + tp1, err := newTokenProtector(&zeroReader{}) + Expect(err).ToNot(HaveOccurred()) + tp2, err := newTokenProtector(&zeroReader{}) + Expect(err).ToNot(HaveOccurred()) + t1, err := tp1.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + t2, err := tp2.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + Expect(t1).To(Equal(t2)) + tp3, err := newTokenProtector(rand.Reader) + Expect(err).ToNot(HaveOccurred()) + t3, err := tp3.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + Expect(t3).ToNot(Equal(t1)) + }) + + It("encodes and decodes tokens", func() { + token, err := tp.NewToken([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(token).ToNot(ContainSubstring("foobar")) + decoded, err := tp.DecodeToken(token) + Expect(err).ToNot(HaveOccurred()) + Expect(decoded).To(Equal([]byte("foobar"))) + }) + + It("fails deconding invalid tokens", func() { + token, err := tp.NewToken([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + token = token[1:] // remove the first byte + _, err = tp.DecodeToken(token) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("message authentication failed")) + }) + + It("errors when decoding too short tokens", func() { + _, err := tp.DecodeToken([]byte("foobar")) + Expect(err).To(MatchError("token too short: 6")) + }) +}) diff --git a/internal/quic-go/handshake/updatable_aead.go b/internal/quic-go/handshake/updatable_aead.go new file mode 100644 index 00000000..e22cea45 --- /dev/null +++ b/internal/quic-go/handshake/updatable_aead.go @@ -0,0 +1,323 @@ +package handshake + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "encoding/binary" + "fmt" + "time" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/qtls" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. +// It's a package-level variable to allow modifying it for testing purposes. +var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval + +type updatableAEAD struct { + suite *qtls.CipherSuiteTLS13 + + keyPhase protocol.KeyPhase + largestAcked protocol.PacketNumber + firstPacketNumber protocol.PacketNumber + handshakeConfirmed bool + + keyUpdateInterval uint64 + invalidPacketLimit uint64 + invalidPacketCount uint64 + + // Time when the keys should be dropped. Keys are dropped on the next call to Open(). + prevRcvAEADExpiry time.Time + prevRcvAEAD cipher.AEAD + + firstRcvdWithCurrentKey protocol.PacketNumber + firstSentWithCurrentKey protocol.PacketNumber + highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) + numRcvdWithCurrentKey uint64 + numSentWithCurrentKey uint64 + rcvAEAD cipher.AEAD + sendAEAD cipher.AEAD + // caches cipher.AEAD.Overhead(). This speeds up calls to Overhead(). + aeadOverhead int + + nextRcvAEAD cipher.AEAD + nextSendAEAD cipher.AEAD + nextRcvTrafficSecret []byte + nextSendTrafficSecret []byte + + headerDecrypter headerProtector + headerEncrypter headerProtector + + rttStats *utils.RTTStats + + tracer logging.ConnectionTracer + logger utils.Logger + version protocol.VersionNumber + + // use a single slice to avoid allocations + nonceBuf []byte +} + +var ( + _ ShortHeaderOpener = &updatableAEAD{} + _ ShortHeaderSealer = &updatableAEAD{} +) + +func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { + return &updatableAEAD{ + firstPacketNumber: protocol.InvalidPacketNumber, + largestAcked: protocol.InvalidPacketNumber, + firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, + firstSentWithCurrentKey: protocol.InvalidPacketNumber, + keyUpdateInterval: KeyUpdateInterval, + rttStats: rttStats, + tracer: tracer, + logger: logger, + version: version, + } +} + +func (a *updatableAEAD) rollKeys() { + if a.prevRcvAEAD != nil { + a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) + if a.tracer != nil { + a.tracer.DroppedKey(a.keyPhase - 1) + } + a.prevRcvAEADExpiry = time.Time{} + } + + a.keyPhase++ + a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber + a.firstSentWithCurrentKey = protocol.InvalidPacketNumber + a.numRcvdWithCurrentKey = 0 + a.numSentWithCurrentKey = 0 + a.prevRcvAEAD = a.rcvAEAD + a.rcvAEAD = a.nextRcvAEAD + a.sendAEAD = a.nextSendAEAD + + a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret) + a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret) + a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version) + a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version) +} + +func (a *updatableAEAD) startKeyDropTimer(now time.Time) { + d := 3 * a.rttStats.PTO(true) + a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d) + a.prevRcvAEADExpiry = now.Add(d) +} + +func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte { + return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) +} + +// For the client, this function is called before SetWriteKey. +// For the server, this function is called after SetWriteKey. +func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + a.rcvAEAD = createAEAD(suite, trafficSecret, a.version) + a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version) + if a.suite == nil { + a.setAEADParameters(a.rcvAEAD, suite) + } + + a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) + a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) +} + +// For the client, this function is called after SetReadKey. +// For the server, this function is called before SetWriteKey. +func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + a.sendAEAD = createAEAD(suite, trafficSecret, a.version) + a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) + if a.suite == nil { + a.setAEADParameters(a.sendAEAD, suite) + } + + a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) + a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version) +} + +func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) { + a.nonceBuf = make([]byte, aead.NonceSize()) + a.aeadOverhead = aead.Overhead() + a.suite = suite + switch suite.ID { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + a.invalidPacketLimit = protocol.InvalidPacketLimitAES + case tls.TLS_CHACHA20_POLY1305_SHA256: + a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha + default: + panic(fmt.Sprintf("unknown cipher suite %d", suite.ID)) + } +} + +func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { + return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN) +} + +func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { + dec, err := a.open(dst, src, rcvTime, pn, kp, ad) + if err == ErrDecryptionFailed { + a.invalidPacketCount++ + if a.invalidPacketCount >= a.invalidPacketLimit { + return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached} + } + } + if err == nil { + a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn) + } + return dec, err +} + +func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { + if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) { + a.prevRcvAEAD = nil + a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) + a.prevRcvAEADExpiry = time.Time{} + if a.tracer != nil { + a.tracer.DroppedKey(a.keyPhase - 1) + } + } + binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) + if kp != a.keyPhase.Bit() { + if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { + if a.prevRcvAEAD == nil { + return nil, ErrKeysDropped + } + // we updated the key, but the peer hasn't updated yet + dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err != nil { + err = ErrDecryptionFailed + } + return dec, err + } + // try opening the packet with the next key phase + dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err != nil { + return nil, ErrDecryptionFailed + } + // Opening succeeded. Check if the peer was allowed to update. + if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { + return nil, &qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "keys updated too quickly", + } + } + a.rollKeys() + a.logger.Debugf("Peer updated keys to %d", a.keyPhase) + // The peer initiated this key update. It's safe to drop the keys for the previous generation now. + // Start a timer to drop the previous key generation. + a.startKeyDropTimer(rcvTime) + if a.tracer != nil { + a.tracer.UpdatedKey(a.keyPhase, true) + } + a.firstRcvdWithCurrentKey = pn + return dec, err + } + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err != nil { + return dec, ErrDecryptionFailed + } + a.numRcvdWithCurrentKey++ + if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { + // We initiated the key updated, and now we received the first packet protected with the new key phase. + // Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys. + if a.keyPhase > 0 { + a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase) + a.startKeyDropTimer(rcvTime) + } + a.firstRcvdWithCurrentKey = pn + } + return dec, err +} + +func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { + a.firstSentWithCurrentKey = pn + } + if a.firstPacketNumber == protocol.InvalidPacketNumber { + a.firstPacketNumber = pn + } + a.numSentWithCurrentKey++ + binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad) +} + +func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { + if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && + pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 { + return &qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase), + } + } + a.largestAcked = pn + return nil +} + +func (a *updatableAEAD) SetHandshakeConfirmed() { + a.handshakeConfirmed = true +} + +func (a *updatableAEAD) updateAllowed() bool { + if !a.handshakeConfirmed { + return false + } + // the first key update is allowed as soon as the handshake is confirmed + return a.keyPhase == 0 || + // subsequent key updates as soon as a packet sent with that key phase has been acknowledged + (a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && + a.largestAcked != protocol.InvalidPacketNumber && + a.largestAcked >= a.firstSentWithCurrentKey) +} + +func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { + if !a.updateAllowed() { + return false + } + if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { + a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) + return true + } + if a.numSentWithCurrentKey >= a.keyUpdateInterval { + a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) + return true + } + return false +} + +func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { + if a.shouldInitiateKeyUpdate() { + a.rollKeys() + a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) + if a.tracer != nil { + a.tracer.UpdatedKey(a.keyPhase, false) + } + } + return a.keyPhase.Bit() +} + +func (a *updatableAEAD) Overhead() int { + return a.aeadOverhead +} + +func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes) +} + +func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes) +} + +func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber { + return a.firstPacketNumber +} diff --git a/internal/quic-go/handshake/updatable_aead_test.go b/internal/quic-go/handshake/updatable_aead_test.go new file mode 100644 index 00000000..35ec718f --- /dev/null +++ b/internal/quic-go/handshake/updatable_aead_test.go @@ -0,0 +1,528 @@ +package handshake + +import ( + "crypto/rand" + "crypto/tls" + "fmt" + "time" + + "github.com/golang/mock/gomock" + + mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Updatable AEAD", func() { + DescribeTable("ChaCha test vector", + func(v protocol.VersionNumber, expectedPayload, expectedPacket []byte) { + secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") + aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil, v) + chacha := cipherSuites[2] + Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256)) + aead.SetWriteKey(chacha, secret) + const pnOffset = 1 + header := splitHexString("4200bff4") + payloadOffset := len(header) + plaintext := splitHexString("01") + payload := aead.Seal(nil, plaintext, 654360564, header) + Expect(payload).To(Equal(expectedPayload)) + packet := append(header, payload...) + aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset]) + Expect(packet).To(Equal(expectedPacket)) + }, + Entry("QUIC v1", + protocol.Version1, + splitHexString("655e5cd55c41f69080575d7999c25a5bfb"), + splitHexString("4cfe4189655e5cd55c41f69080575d7999c25a5bfb"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("0ae7b6b932bc27d786f4bc2bb20f2162ba"), + splitHexString("5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"), + ), + ) + + for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { + v := ver + + Context(fmt.Sprintf("using version %s", v), func() { + for i := range cipherSuites { + cs := cipherSuites[i] + + Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { + var ( + client, server *updatableAEAD + serverTracer *mocklogging.MockConnectionTracer + rttStats *utils.RTTStats + ) + + BeforeEach(func() { + serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) + trafficSecret1 := make([]byte, 16) + trafficSecret2 := make([]byte, 16) + rand.Read(trafficSecret1) + rand.Read(trafficSecret2) + + rttStats = utils.NewRTTStats() + client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v) + server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger, v) + client.SetReadKey(cs, trafficSecret2) + client.SetWriteKey(cs, trafficSecret1) + server.SetReadKey(cs, trafficSecret1) + server.SetWriteKey(cs, trafficSecret2) + }) + + Context("header protection", func() { + It("encrypts and decrypts the header", func() { + var lastFiveBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + client.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0x1f != 0xb5&0x1f { + lastFiveBitsDifferent++ + } + Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + server.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) + }) + }) + + Context("message encryption", func() { + var msg, ad []byte + + BeforeEach(func() { + msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad = []byte("Donec in velit neque.") + }) + + It("encrypts and decrypts a message", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + + It("saves the first packet number", func() { + client.Seal(nil, msg, 0x1337, ad) + Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) + client.Seal(nil, msg, 0x1338, ad) + Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) + }) + + It("fails to open a message if the associated data is not the same", func() { + encrypted := client.Seal(nil, msg, 0x1337, ad) + _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("fails to open a message if the packet number is not the same", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("decodes the packet number", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) + }) + + It("ignores packets it can't decrypt for packet number derivation", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).To(HaveOccurred()) + Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) + }) + + It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() { + client.invalidPacketLimit = 10 + for i := 0; i < 9; i++ { + _, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + } + _, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).To(HaveOccurred()) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached)) + }) + + Context("key updates", func() { + Context("receiving key updates", func() { + It("updates keys", func() { + now := time.Now() + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted0 := server.Seal(nil, msg, 0x1337, ad) + server.rollKeys() + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted1 := server.Seal(nil, msg, 0x1337, ad) + Expect(encrypted0).ToNot(Equal(encrypted1)) + // expect opening to fail. The client didn't roll keys yet + _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + client.rollKeys() + decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + }) + + It("updates the keys when receiving a packet with the next key phase", func() { + now := time.Now() + // receive the first packet at key phase zero + encrypted0 := client.Seal(nil, msg, 0x42, ad) + decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + // send one packet at key phase zero + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _ = server.Seal(nil, msg, 0x1, ad) + // now received a message at key phase one + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x43, ad) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("opens a reordered packet with the old keys after an update", func() { + now := time.Now() + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("drops keys 3 PTOs after a key update", func() { + now := time.Now() + rttStats.UpdateRTT(10*time.Millisecond, 0, now) + pto := rttStats.PTO(true) + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrKeysDropped)) + }) + + It("allows the first key update immediately", func() { + // receive a packet at key phase one, before having sent or received any packets at key phase 0 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x1337, ad) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + }) + + It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { + client.rollKeys() + encrypted := client.Seal(nil, msg, 0x1337, ad) + encrypted = encrypted[:len(encrypted)-1] + _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("errors when the peer updates keys too frequently", func() { + server.rollKeys() + client.rollKeys() + // receive the first packet at key phase one + encrypted0 := client.Seal(nil, msg, 0x42, ad) + _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + // now receive a packet at key phase two, before having sent any packets + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x42, ad) + _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "keys updated too quickly", + })) + }) + }) + + Context("initiating key updates", func() { + const keyUpdateInterval = 20 + + BeforeEach(func() { + Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) + server.keyUpdateInterval = keyUpdateInterval + server.SetHandshakeConfirmed() + }) + + It("initiates a key update after sealing the maximum number of packets, for the first update", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + // the first update is allowed without receiving an acknowledgement + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() { + server.rollKeys() + client.rollKeys() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, pn, ad) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // receive an ACK for a packet sent in key phase 0 + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + + It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { + // First make sure that we update our keys. + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // Now that our keys are updated, send a packet using the new keys. + const nextPN = keyUpdateInterval + 1 + server.Seal(nil, msg, nextPN, ad) + // We haven't decrypted any packet in the new key phase yet. + // This means that the ACK must have been sent in the old key phase. + Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "received ACK for key phase 1, but peer didn't update keys", + })) + }) + + It("doesn't error before actually sending a packet in the new key phase", func() { + // First make sure that we update our keys. + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + // Now that our keys are updated, send a packet using the new keys. + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // We haven't decrypted any packet in the new key phase yet. + // This means that the ACK must have been sent in the old key phase. + Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) + }) + + It("initiates a key update after opening the maximum number of packets, for the first update", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted := client.Seal(nil, msg, pn, ad) + _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + } + // the first update is allowed without receiving an acknowledgement + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() { + server.rollKeys() + client.rollKeys() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted := client.Seal(nil, msg, pn, ad) + _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, 1, ad) + Expect(server.SetLargestAcked(1)).To(Succeed()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + + It("drops keys 3 PTOs after a key update", func() { + now := time.Now() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + Expect(server.SetLargestAcked(0)).To(Succeed()) + // Now we've initiated the first key update. + // Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there + threePTO := 3 * rttStats.PTO(false) + dataKeyPhaseZero := client.Seal(nil, msg, 1, ad) + _, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // Now receive a packet with key phase 1. + // This should start the timer to drop the keys after 3 PTOs. + client.rollKeys() + dataKeyPhaseOne := client.Seal(nil, msg, 10, ad) + t := now.Add(threePTO).Add(time.Second) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + // Make sure the keys are still here. + _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrKeysDropped)) + }) + + It("doesn't drop the first key generation too early", func() { + now := time.Now() + data1 := client.Seal(nil, msg, 1, ad) + _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + Expect(server.SetLargestAcked(pn)).To(Succeed()) + } + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // The server never received a packet at key phase 1. + // Make sure the key phase 0 is still there at a much later point. + data2 := client.Seal(nil, msg, 1, ad) + _, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + }) + + It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + const nextPN = keyUpdateInterval + 1 + // Send and receive an acknowledgement for a packet in key phase 1. + // We are now running a timer to drop the keys with 3 PTO. + server.Seal(nil, msg, nextPN, ad) + client.rollKeys() + dataKeyPhaseOne := client.Seal(nil, msg, 2, ad) + now := time.Now() + _, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.SetLargestAcked(nextPN)) + // Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over. + // This mean that we need to drop the keys for key phase 0 immediately. + client.rollKeys() + dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad) + gomock.InOrder( + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true), + ) + _, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + + It("drops keys early when we initiate another key update within the 3 PTO period", func() { + server.SetHandshakeConfirmed() + // send so many packets that we initiate the first key update + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // send so many packets that we initiate the next key update + for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, pn, ad) + } + client.rollKeys() + b = client.Seal(nil, []byte("foobar"), 2, []byte("ad")) + now := time.Now() + _, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed()) + gomock.InOrder( + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false), + ) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + // We haven't received an ACK for a packet sent in key phase 2 yet. + // Make sure we canceled the timer to drop the previous key phase. + b = client.Seal(nil, []byte("foobar"), 3, []byte("ad")) + _, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + }) + }) + }) + }) + }) + } + }) + } +}) diff --git a/internal/quic-go/interface.go b/internal/quic-go/interface.go new file mode 100644 index 00000000..76af5fcb --- /dev/null +++ b/internal/quic-go/interface.go @@ -0,0 +1,328 @@ +package quic + +import ( + "context" + "errors" + "io" + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go/handshake" + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// The StreamID is the ID of a QUIC stream. +type StreamID = protocol.StreamID + +// A VersionNumber is a QUIC version number. +type VersionNumber = protocol.VersionNumber + +const ( + // VersionDraft29 is IETF QUIC draft-29 + VersionDraft29 = protocol.VersionDraft29 + // Version1 is RFC 9000 + Version1 = protocol.Version1 + Version2 = protocol.Version2 +) + +// A Token can be used to verify the ownership of the client address. +type Token struct { + // IsRetryToken encodes how the client received the token. There are two ways: + // * In a Retry packet sent when trying to establish a new connection. + // * In a NEW_TOKEN frame on a previous connection. + IsRetryToken bool + RemoteAddr string + SentTime time.Time +} + +// A ClientToken is a token received by the client. +// It can be used to skip address validation on future connection attempts. +type ClientToken struct { + data []byte +} + +type TokenStore interface { + // Pop searches for a ClientToken associated with the given key. + // Since tokens are not supposed to be reused, it must remove the token from the cache. + // It returns nil when no token is found. + Pop(key string) (token *ClientToken) + + // Put adds a token to the cache with the given key. It might get called + // multiple times in a connection. + Put(key string, token *ClientToken) +} + +// Err0RTTRejected is the returned from: +// * Open{Uni}Stream{Sync} +// * Accept{Uni}Stream +// * Stream.Read and Stream.Write +// when the server rejects a 0-RTT connection attempt. +var Err0RTTRejected = errors.New("0-RTT rejected") + +// ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection. +// It is set on the Connection.Context() context, +// as well as on the context passed to logging.Tracer.NewConnectionTracer. +var ConnectionTracingKey = connTracingCtxKey{} + +type connTracingCtxKey struct{} + +// Stream is the interface implemented by QUIC streams +// In addition to the errors listed on the Connection, +// calls to stream functions can return a StreamError if the stream is canceled. +type Stream interface { + ReceiveStream + SendStream + // SetDeadline sets the read and write deadlines associated + // with the connection. It is equivalent to calling both + // SetReadDeadline and SetWriteDeadline. + SetDeadline(t time.Time) error +} + +// A ReceiveStream is a unidirectional Receive Stream. +type ReceiveStream interface { + // StreamID returns the stream ID. + StreamID() StreamID + // Read reads data from the stream. + // Read can be made to time out and return a net.Error with Timeout() == true + // after a fixed time limit; see SetDeadline and SetReadDeadline. + // If the stream was canceled by the peer, the error implements the StreamError + // interface, and Canceled() == true. + // If the connection was closed due to a timeout, the error satisfies + // the net.Error interface, and Timeout() will be true. + io.Reader + // CancelRead aborts receiving on this stream. + // It will ask the peer to stop transmitting stream data. + // Read will unblock immediately, and future Read calls will fail. + // When called multiple times or after reading the io.EOF it is a no-op. + CancelRead(StreamErrorCode) + // SetReadDeadline sets the deadline for future Read calls and + // any currently-blocked Read call. + // A zero value for t means Read will not time out. + + SetReadDeadline(t time.Time) error +} + +// A SendStream is a unidirectional Send Stream. +type SendStream interface { + // StreamID returns the stream ID. + StreamID() StreamID + // Write writes data to the stream. + // Write can be made to time out and return a net.Error with Timeout() == true + // after a fixed time limit; see SetDeadline and SetWriteDeadline. + // If the stream was canceled by the peer, the error implements the StreamError + // interface, and Canceled() == true. + // If the connection was closed due to a timeout, the error satisfies + // the net.Error interface, and Timeout() will be true. + io.Writer + // Close closes the write-direction of the stream. + // Future calls to Write are not permitted after calling Close. + // It must not be called concurrently with Write. + // It must not be called after calling CancelWrite. + io.Closer + // CancelWrite aborts sending on this stream. + // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably. + // Write will unblock immediately, and future calls to Write will fail. + // When called multiple times or after closing the stream it is a no-op. + CancelWrite(StreamErrorCode) + // The Context is canceled as soon as the write-side of the stream is closed. + // This happens when Close() or CancelWrite() is called, or when the peer + // cancels the read-side of their stream. + Context() context.Context + // SetWriteDeadline sets the deadline for future Write calls + // and any currently-blocked Write call. + // Even if write times out, it may return n > 0, indicating that + // some data was successfully written. + // A zero value for t means Write will not time out. + SetWriteDeadline(t time.Time) error +} + +// A Connection is a QUIC connection between two peers. +// Calls to the connection (and to streams) can return the following types of errors: +// * ApplicationError: for errors triggered by the application running on top of QUIC +// * TransportError: for errors triggered by the QUIC transport (in many cases a misbehaving peer) +// * IdleTimeoutError: when the peer goes away unexpectedly (this is a net.Error timeout error) +// * HandshakeTimeoutError: when the cryptographic handshake takes too long (this is a net.Error timeout error) +// * StatelessResetError: when we receive a stateless reset (this is a net.Error temporary error) +// * VersionNegotiationError: returned by the client, when there's no version overlap between the peers +type Connection interface { + // AcceptStream returns the next stream opened by the peer, blocking until one is available. + // If the connection was closed due to a timeout, the error satisfies + // the net.Error interface, and Timeout() will be true. + AcceptStream(context.Context) (Stream, error) + // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available. + // If the connection was closed due to a timeout, the error satisfies + // the net.Error interface, and Timeout() will be true. + AcceptUniStream(context.Context) (ReceiveStream, error) + // OpenStream opens a new bidirectional QUIC stream. + // There is no signaling to the peer about new streams: + // The peer can only accept the stream after data has been sent on the stream. + // If the error is non-nil, it satisfies the net.Error interface. + // When reaching the peer's stream limit, err.Temporary() will be true. + // If the connection was closed due to a timeout, Timeout() will be true. + OpenStream() (Stream, error) + // OpenStreamSync opens a new bidirectional QUIC stream. + // It blocks until a new stream can be opened. + // If the error is non-nil, it satisfies the net.Error interface. + // If the connection was closed due to a timeout, Timeout() will be true. + OpenStreamSync(context.Context) (Stream, error) + // OpenUniStream opens a new outgoing unidirectional QUIC stream. + // If the error is non-nil, it satisfies the net.Error interface. + // When reaching the peer's stream limit, Temporary() will be true. + // If the connection was closed due to a timeout, Timeout() will be true. + OpenUniStream() (SendStream, error) + // OpenUniStreamSync opens a new outgoing unidirectional QUIC stream. + // It blocks until a new stream can be opened. + // If the error is non-nil, it satisfies the net.Error interface. + // If the connection was closed due to a timeout, Timeout() will be true. + OpenUniStreamSync(context.Context) (SendStream, error) + // LocalAddr returns the local address. + LocalAddr() net.Addr + // RemoteAddr returns the address of the peer. + RemoteAddr() net.Addr + // CloseWithError closes the connection with an error. + // The error string will be sent to the peer. + CloseWithError(ApplicationErrorCode, string) error + // The context is cancelled when the connection is closed. + Context() context.Context + // ConnectionState returns basic details about the QUIC connection. + // It blocks until the handshake completes. + // Warning: This API should not be considered stable and might change soon. + ConnectionState() ConnectionState + + // SendMessage sends a message as a datagram, as specified in RFC 9221. + SendMessage([]byte) error + // ReceiveMessage gets a message received in a datagram, as specified in RFC 9221. + ReceiveMessage() ([]byte, error) +} + +// An EarlyConnection is a connection that is handshaking. +// Data sent during the handshake is encrypted using the forward secure keys. +// When using client certificates, the client's identity is only verified +// after completion of the handshake. +type EarlyConnection interface { + Connection + + // HandshakeComplete blocks until the handshake completes (or fails). + // Data sent before completion of the handshake is encrypted with 1-RTT keys. + // Note that the client's identity hasn't been verified yet. + HandshakeComplete() context.Context + + NextConnection() Connection +} + +// Config contains all configuration data needed for a QUIC server or client. +type Config struct { + // The QUIC versions that can be negotiated. + // If not set, it uses all versions available. + Versions []VersionNumber + // The length of the connection ID in bytes. + // It can be 0, or any value between 4 and 18. + // If not set, the interpretation depends on where the Config is used: + // If used for dialing an address, a 0 byte connection ID will be used. + // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. + // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. + ConnectionIDLength int + // HandshakeIdleTimeout is the idle timeout before completion of the handshake. + // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. + // If this value is zero, the timeout is set to 5 seconds. + HandshakeIdleTimeout time.Duration + // MaxIdleTimeout is the maximum duration that may pass without any incoming network activity. + // The actual value for the idle timeout is the minimum of this value and the peer's. + // This value only applies after the handshake has completed. + // If the timeout is exceeded, the connection is closed. + // If this value is zero, the timeout is set to 30 seconds. + MaxIdleTimeout time.Duration + // AcceptToken determines if a Token is accepted. + // It is called with token = nil if the client didn't send a token. + // If not set, a default verification function is used: + // * it verifies that the address matches, and + // * if the token is a retry token, that it was issued within the last 5 seconds + // * else, that it was issued within the last 24 hours. + // This option is only valid for the server. + AcceptToken func(clientAddr net.Addr, token *Token) bool + // The TokenStore stores tokens received from the server. + // Tokens are used to skip address validation on future connection attempts. + // The key used to store tokens is the ServerName from the tls.Config, if set + // otherwise the token is associated with the server's IP address. + TokenStore TokenStore + // InitialStreamReceiveWindow is the initial size of the stream-level flow control window for receiving data. + // If the application is consuming data quickly enough, the flow control auto-tuning algorithm + // will increase the window up to MaxStreamReceiveWindow. + // If this value is zero, it will default to 512 KB. + InitialStreamReceiveWindow uint64 + // MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data. + // If this value is zero, it will default to 6 MB. + MaxStreamReceiveWindow uint64 + // InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data. + // If the application is consuming data quickly enough, the flow control auto-tuning algorithm + // will increase the window up to MaxConnectionReceiveWindow. + // If this value is zero, it will default to 512 KB. + InitialConnectionReceiveWindow uint64 + // MaxConnectionReceiveWindow is the connection-level flow control window for receiving data. + // If this value is zero, it will default to 15 MB. + MaxConnectionReceiveWindow uint64 + // AllowConnectionWindowIncrease is called every time the connection flow controller attempts + // to increase the connection flow control window. + // If set, the caller can prevent an increase of the window. Typically, it would do so to + // limit the memory usage. + // To avoid deadlocks, it is not valid to call other functions on the connection or on streams + // in this callback. + AllowConnectionWindowIncrease func(sess Connection, delta uint64) bool + // MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open. + // Values above 2^60 are invalid. + // If not set, it will default to 100. + // If set to a negative value, it doesn't allow any bidirectional streams. + MaxIncomingStreams int64 + // MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open. + // Values above 2^60 are invalid. + // If not set, it will default to 100. + // If set to a negative value, it doesn't allow any unidirectional streams. + MaxIncomingUniStreams int64 + // The StatelessResetKey is used to generate stateless reset tokens. + // If no key is configured, sending of stateless resets is disabled. + StatelessResetKey []byte + // KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive. + // If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most + // every half of MaxIdleTimeout, whichever is smaller). + KeepAlivePeriod time.Duration + // DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899). + // Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. + // Note that if Path MTU discovery is causing issues on your system, please open a new issue + DisablePathMTUDiscovery bool + // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. + // This can be useful if version information is exchanged out-of-band. + // It has no effect for a client. + DisableVersionNegotiationPackets bool + // See https://datatracker.ietf.org/doc/draft-ietf-quic-datagram/. + // Datagrams will only be available when both peers enable datagram support. + EnableDatagrams bool + Tracer logging.Tracer +} + +// ConnectionState records basic details about a QUIC connection +type ConnectionState struct { + TLS handshake.ConnectionState + SupportsDatagrams bool +} + +// A Listener for incoming QUIC connections +type Listener interface { + // Close the server. All active connections will be closed. + Close() error + // Addr returns the local network addr that the server is listening on. + Addr() net.Addr + // Accept returns new connections. It should be called in a loop. + Accept(context.Context) (Connection, error) +} + +// An EarlyListener listens for incoming QUIC connections, +// and returns them before the handshake completes. +type EarlyListener interface { + // Close the server. All active connections will be closed. + Close() error + // Addr returns the local network addr that the server is listening on. + Addr() net.Addr + // Accept returns new early connections. It should be called in a loop. + Accept(context.Context) (EarlyConnection, error) +} diff --git a/internal/quic-go/logging/frame.go b/internal/quic-go/logging/frame.go new file mode 100644 index 00000000..8675e0f9 --- /dev/null +++ b/internal/quic-go/logging/frame.go @@ -0,0 +1,66 @@ +package logging + +import "github.com/imroc/req/v3/internal/quic-go/wire" + +// A Frame is a QUIC frame +type Frame interface{} + +// The AckRange is used within the AckFrame. +// It is a range of packet numbers that is being acknowledged. +type AckRange = wire.AckRange + +type ( + // An AckFrame is an ACK frame. + AckFrame = wire.AckFrame + // A ConnectionCloseFrame is a CONNECTION_CLOSE frame. + ConnectionCloseFrame = wire.ConnectionCloseFrame + // A DataBlockedFrame is a DATA_BLOCKED frame. + DataBlockedFrame = wire.DataBlockedFrame + // A HandshakeDoneFrame is a HANDSHAKE_DONE frame. + HandshakeDoneFrame = wire.HandshakeDoneFrame + // A MaxDataFrame is a MAX_DATA frame. + MaxDataFrame = wire.MaxDataFrame + // A MaxStreamDataFrame is a MAX_STREAM_DATA frame. + MaxStreamDataFrame = wire.MaxStreamDataFrame + // A MaxStreamsFrame is a MAX_STREAMS_FRAME. + MaxStreamsFrame = wire.MaxStreamsFrame + // A NewConnectionIDFrame is a NEW_CONNECTION_ID frame. + NewConnectionIDFrame = wire.NewConnectionIDFrame + // A NewTokenFrame is a NEW_TOKEN frame. + NewTokenFrame = wire.NewTokenFrame + // A PathChallengeFrame is a PATH_CHALLENGE frame. + PathChallengeFrame = wire.PathChallengeFrame + // A PathResponseFrame is a PATH_RESPONSE frame. + PathResponseFrame = wire.PathResponseFrame + // A PingFrame is a PING frame. + PingFrame = wire.PingFrame + // A ResetStreamFrame is a RESET_STREAM frame. + ResetStreamFrame = wire.ResetStreamFrame + // A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame. + RetireConnectionIDFrame = wire.RetireConnectionIDFrame + // A StopSendingFrame is a STOP_SENDING frame. + StopSendingFrame = wire.StopSendingFrame + // A StreamsBlockedFrame is a STREAMS_BLOCKED frame. + StreamsBlockedFrame = wire.StreamsBlockedFrame + // A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame. + StreamDataBlockedFrame = wire.StreamDataBlockedFrame +) + +// A CryptoFrame is a CRYPTO frame. +type CryptoFrame struct { + Offset ByteCount + Length ByteCount +} + +// A StreamFrame is a STREAM frame. +type StreamFrame struct { + StreamID StreamID + Offset ByteCount + Length ByteCount + Fin bool +} + +// A DatagramFrame is a DATAGRAM frame. +type DatagramFrame struct { + Length ByteCount +} diff --git a/internal/quic-go/logging/interface.go b/internal/quic-go/logging/interface.go new file mode 100644 index 00000000..f4e64840 --- /dev/null +++ b/internal/quic-go/logging/interface.go @@ -0,0 +1,134 @@ +// Package logging defines a logging interface for quic-go. +// This package should not be considered stable +package logging + +import ( + "context" + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go/utils" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type ( + // A ByteCount is used to count bytes. + ByteCount = protocol.ByteCount + // A ConnectionID is a QUIC Connection ID. + ConnectionID = protocol.ConnectionID + // The EncryptionLevel is the encryption level of a packet. + EncryptionLevel = protocol.EncryptionLevel + // The KeyPhase is the key phase of the 1-RTT keys. + KeyPhase = protocol.KeyPhase + // The KeyPhaseBit is the value of the key phase bit of the 1-RTT packets. + KeyPhaseBit = protocol.KeyPhaseBit + // The PacketNumber is the packet number of a packet. + PacketNumber = protocol.PacketNumber + // The Perspective is the role of a QUIC endpoint (client or server). + Perspective = protocol.Perspective + // A StatelessResetToken is a stateless reset token. + StatelessResetToken = protocol.StatelessResetToken + // The StreamID is the stream ID. + StreamID = protocol.StreamID + // The StreamNum is the number of the stream. + StreamNum = protocol.StreamNum + // The StreamType is the type of the stream (unidirectional or bidirectional). + StreamType = protocol.StreamType + // The VersionNumber is the QUIC version. + VersionNumber = protocol.VersionNumber + + // The Header is the QUIC packet header, before removing header protection. + Header = wire.Header + // The ExtendedHeader is the QUIC packet header, after removing header protection. + ExtendedHeader = wire.ExtendedHeader + // The TransportParameters are QUIC transport parameters. + TransportParameters = wire.TransportParameters + // The PreferredAddress is the preferred address sent in the transport parameters. + PreferredAddress = wire.PreferredAddress + + // A TransportError is a transport-level error code. + TransportError = qerr.TransportErrorCode + // An ApplicationError is an application-defined error code. + ApplicationError = qerr.TransportErrorCode + + // The RTTStats contain statistics used by the congestion controller. + RTTStats = utils.RTTStats +) + +const ( + // KeyPhaseZero is key phase bit 0 + KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero + // KeyPhaseOne is key phase bit 1 + KeyPhaseOne KeyPhaseBit = protocol.KeyPhaseOne +) + +const ( + // PerspectiveServer is used for a QUIC server + PerspectiveServer Perspective = protocol.PerspectiveServer + // PerspectiveClient is used for a QUIC client + PerspectiveClient Perspective = protocol.PerspectiveClient +) + +const ( + // EncryptionInitial is the Initial encryption level + EncryptionInitial EncryptionLevel = protocol.EncryptionInitial + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake EncryptionLevel = protocol.EncryptionHandshake + // Encryption1RTT is the 1-RTT encryption level + Encryption1RTT EncryptionLevel = protocol.Encryption1RTT + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT EncryptionLevel = protocol.Encryption0RTT +) + +const ( + // StreamTypeUni is a unidirectional stream + StreamTypeUni = protocol.StreamTypeUni + // StreamTypeBidi is a bidirectional stream + StreamTypeBidi = protocol.StreamTypeBidi +) + +// A Tracer traces events. +type Tracer interface { + // TracerForConnection requests a new tracer for a connection. + // The ODCID is the original destination connection ID: + // The destination connection ID that the client used on the first Initial packet it sent on this connection. + // If nil is returned, tracing will be disabled for this connection. + TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer + + SentPacket(net.Addr, *Header, ByteCount, []Frame) + DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) +} + +// A ConnectionTracer records events. +type ConnectionTracer interface { + StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) + NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) + ClosedConnection(error) + SentTransportParameters(*TransportParameters) + ReceivedTransportParameters(*TransportParameters) + RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT + SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) + ReceivedVersionNegotiationPacket(*Header, []VersionNumber) + ReceivedRetry(*Header) + ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) + BufferedPacket(PacketType) + DroppedPacket(PacketType, ByteCount, PacketDropReason) + UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) + AcknowledgedPacket(EncryptionLevel, PacketNumber) + LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) + UpdatedCongestionState(CongestionState) + UpdatedPTOCount(value uint32) + UpdatedKeyFromTLS(EncryptionLevel, Perspective) + UpdatedKey(generation KeyPhase, remote bool) + DroppedEncryptionLevel(EncryptionLevel) + DroppedKey(generation KeyPhase) + SetLossTimer(TimerType, EncryptionLevel, time.Time) + LossTimerExpired(TimerType, EncryptionLevel) + LossTimerCanceled() + // Close is called when the connection is closed. + Close() + Debug(name, msg string) +} diff --git a/internal/quic-go/logging/logging_suite_test.go b/internal/quic-go/logging/logging_suite_test.go new file mode 100644 index 00000000..0a81943d --- /dev/null +++ b/internal/quic-go/logging/logging_suite_test.go @@ -0,0 +1,25 @@ +package logging + +import ( + "testing" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestLogging(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Logging Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) diff --git a/internal/quic-go/logging/mock_connection_tracer_test.go b/internal/quic-go/logging/mock_connection_tracer_test.go new file mode 100644 index 00000000..620f181b --- /dev/null +++ b/internal/quic-go/logging/mock_connection_tracer_test.go @@ -0,0 +1,351 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/logging (interfaces: ConnectionTracer) + +// Package logging is a generated GoMock package. +package logging + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + utils "github.com/imroc/req/v3/internal/quic-go/utils" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockConnectionTracer is a mock of ConnectionTracer interface. +type MockConnectionTracer struct { + ctrl *gomock.Controller + recorder *MockConnectionTracerMockRecorder +} + +// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. +type MockConnectionTracerMockRecorder struct { + mock *MockConnectionTracer +} + +// NewMockConnectionTracer creates a new mock instance. +func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { + mock := &MockConnectionTracer{ctrl: ctrl} + mock.recorder = &MockConnectionTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { + return m.recorder +} + +// AcknowledgedPacket mocks base method. +func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) +} + +// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. +func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) +} + +// BufferedPacket mocks base method. +func (m *MockConnectionTracer) BufferedPacket(arg0 PacketType) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "BufferedPacket", arg0) +} + +// BufferedPacket indicates an expected call of BufferedPacket. +func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0) +} + +// Close mocks base method. +func (m *MockConnectionTracer) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) +} + +// ClosedConnection mocks base method. +func (m *MockConnectionTracer) ClosedConnection(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClosedConnection", arg0) +} + +// ClosedConnection indicates an expected call of ClosedConnection. +func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) +} + +// Debug mocks base method. +func (m *MockConnectionTracer) Debug(arg0, arg1 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0, arg1) +} + +// Debug indicates an expected call of Debug. +func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) +} + +// DroppedEncryptionLevel mocks base method. +func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) +} + +// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. +func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) +} + +// DroppedKey mocks base method. +func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedKey", arg0) +} + +// DroppedKey indicates an expected call of DroppedKey. +func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) +} + +// DroppedPacket mocks base method. +func (m *MockConnectionTracer) DroppedPacket(arg0 PacketType, arg1 protocol.ByteCount, arg2 PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) +} + +// LossTimerCanceled mocks base method. +func (m *MockConnectionTracer) LossTimerCanceled() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerCanceled") +} + +// LossTimerCanceled indicates an expected call of LossTimerCanceled. +func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) +} + +// LossTimerExpired mocks base method. +func (m *MockConnectionTracer) LossTimerExpired(arg0 TimerType, arg1 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) +} + +// LossTimerExpired indicates an expected call of LossTimerExpired. +func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) +} + +// LostPacket mocks base method. +func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 PacketLossReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) +} + +// LostPacket indicates an expected call of LostPacket. +func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) +} + +// NegotiatedVersion mocks base method. +func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) +} + +// NegotiatedVersion indicates an expected call of NegotiatedVersion. +func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) +} + +// ReceivedPacket mocks base method. +func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2) +} + +// ReceivedPacket indicates an expected call of ReceivedPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedPacket), arg0, arg1, arg2) +} + +// ReceivedRetry mocks base method. +func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedRetry", arg0) +} + +// ReceivedRetry indicates an expected call of ReceivedRetry. +func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) +} + +// ReceivedTransportParameters mocks base method. +func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedTransportParameters", arg0) +} + +// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. +func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) +} + +// ReceivedVersionNegotiationPacket mocks base method. +func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header, arg1 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1) +} + +// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1) +} + +// RestoredTransportParameters mocks base method. +func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RestoredTransportParameters", arg0) +} + +// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. +func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) +} + +// SentPacket mocks base method. +func (m *MockConnectionTracer) SentPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockConnectionTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) +} + +// SentTransportParameters mocks base method. +func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentTransportParameters", arg0) +} + +// SentTransportParameters indicates an expected call of SentTransportParameters. +func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) +} + +// SetLossTimer mocks base method. +func (m *MockConnectionTracer) SetLossTimer(arg0 TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) +} + +// SetLossTimer indicates an expected call of SetLossTimer. +func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) +} + +// StartedConnection mocks base method. +func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) +} + +// StartedConnection indicates an expected call of StartedConnection. +func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) +} + +// UpdatedCongestionState mocks base method. +func (m *MockConnectionTracer) UpdatedCongestionState(arg0 CongestionState) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedCongestionState", arg0) +} + +// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. +func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) +} + +// UpdatedKey mocks base method. +func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKey", arg0, arg1) +} + +// UpdatedKey indicates an expected call of UpdatedKey. +func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) +} + +// UpdatedKeyFromTLS mocks base method. +func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) +} + +// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. +func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) +} + +// UpdatedMetrics mocks base method. +func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) +} + +// UpdatedMetrics indicates an expected call of UpdatedMetrics. +func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) +} + +// UpdatedPTOCount mocks base method. +func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedPTOCount", arg0) +} + +// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. +func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) +} diff --git a/internal/quic-go/logging/mock_tracer_test.go b/internal/quic-go/logging/mock_tracer_test.go new file mode 100644 index 00000000..6d49601a --- /dev/null +++ b/internal/quic-go/logging/mock_tracer_test.go @@ -0,0 +1,76 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/logging (interfaces: Tracer) + +// Package logging is a generated GoMock package. +package logging + +import ( + context "context" + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockTracer is a mock of Tracer interface. +type MockTracer struct { + ctrl *gomock.Controller + recorder *MockTracerMockRecorder +} + +// MockTracerMockRecorder is the mock recorder for MockTracer. +type MockTracerMockRecorder struct { + mock *MockTracer +} + +// NewMockTracer creates a new mock instance. +func NewMockTracer(ctrl *gomock.Controller) *MockTracer { + mock := &MockTracer{ctrl: ctrl} + mock.recorder = &MockTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTracer) EXPECT() *MockTracerMockRecorder { + return m.recorder +} + +// DroppedPacket mocks base method. +func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 PacketType, arg2 protocol.ByteCount, arg3 PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) +} + +// SentPacket mocks base method. +func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) +} + +// TracerForConnection mocks base method. +func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) + ret0, _ := ret[0].(ConnectionTracer) + return ret0 +} + +// TracerForConnection indicates an expected call of TracerForConnection. +func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) +} diff --git a/internal/quic-go/logging/mockgen.go b/internal/quic-go/logging/mockgen.go new file mode 100644 index 00000000..09122480 --- /dev/null +++ b/internal/quic-go/logging/mockgen.go @@ -0,0 +1,4 @@ +package logging + +//go:generate sh -c "mockgen -package logging -self_package github.com/imroc/req/v3/internal/quic-go/logging -destination mock_connection_tracer_test.go github.com/imroc/req/v3/internal/quic-go/logging ConnectionTracer" +//go:generate sh -c "mockgen -package logging -self_package github.com/imroc/req/v3/internal/quic-go/logging -destination mock_tracer_test.go github.com/imroc/req/v3/internal/quic-go/logging Tracer" diff --git a/internal/quic-go/logging/multiplex.go b/internal/quic-go/logging/multiplex.go new file mode 100644 index 00000000..8280e8cd --- /dev/null +++ b/internal/quic-go/logging/multiplex.go @@ -0,0 +1,219 @@ +package logging + +import ( + "context" + "net" + "time" +) + +type tracerMultiplexer struct { + tracers []Tracer +} + +var _ Tracer = &tracerMultiplexer{} + +// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. +func NewMultiplexedTracer(tracers ...Tracer) Tracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &tracerMultiplexer{tracers} +} + +func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer { + var connTracers []ConnectionTracer + for _, t := range m.tracers { + if ct := t.TracerForConnection(ctx, p, odcid); ct != nil { + connTracers = append(connTracers, ct) + } + } + return NewMultiplexedConnectionTracer(connTracers...) +} + +func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { + for _, t := range m.tracers { + t.SentPacket(remote, hdr, size, frames) + } +} + +func (m *tracerMultiplexer) DroppedPacket(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range m.tracers { + t.DroppedPacket(remote, typ, size, reason) + } +} + +type connTracerMultiplexer struct { + tracers []ConnectionTracer +} + +var _ ConnectionTracer = &connTracerMultiplexer{} + +// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. +func NewMultiplexedConnectionTracer(tracers ...ConnectionTracer) ConnectionTracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &connTracerMultiplexer{tracers: tracers} +} + +func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { + for _, t := range m.tracers { + t.StartedConnection(local, remote, srcConnID, destConnID) + } +} + +func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { + for _, t := range m.tracers { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + } +} + +func (m *connTracerMultiplexer) ClosedConnection(e error) { + for _, t := range m.tracers { + t.ClosedConnection(e) + } +} + +func (m *connTracerMultiplexer) SentTransportParameters(tp *TransportParameters) { + for _, t := range m.tracers { + t.SentTransportParameters(tp) + } +} + +func (m *connTracerMultiplexer) ReceivedTransportParameters(tp *TransportParameters) { + for _, t := range m.tracers { + t.ReceivedTransportParameters(tp) + } +} + +func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParameters) { + for _, t := range m.tracers { + t.RestoredTransportParameters(tp) + } +} + +func (m *connTracerMultiplexer) SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) { + for _, t := range m.tracers { + t.SentPacket(hdr, size, ack, frames) + } +} + +func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(hdr *Header, versions []VersionNumber) { + for _, t := range m.tracers { + t.ReceivedVersionNegotiationPacket(hdr, versions) + } +} + +func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) { + for _, t := range m.tracers { + t.ReceivedRetry(hdr) + } +} + +func (m *connTracerMultiplexer) ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) { + for _, t := range m.tracers { + t.ReceivedPacket(hdr, size, frames) + } +} + +func (m *connTracerMultiplexer) BufferedPacket(typ PacketType) { + for _, t := range m.tracers { + t.BufferedPacket(typ) + } +} + +func (m *connTracerMultiplexer) DroppedPacket(typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range m.tracers { + t.DroppedPacket(typ, size, reason) + } +} + +func (m *connTracerMultiplexer) UpdatedCongestionState(state CongestionState) { + for _, t := range m.tracers { + t.UpdatedCongestionState(state) + } +} + +func (m *connTracerMultiplexer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFLight ByteCount, packetsInFlight int) { + for _, t := range m.tracers { + t.UpdatedMetrics(rttStats, cwnd, bytesInFLight, packetsInFlight) + } +} + +func (m *connTracerMultiplexer) AcknowledgedPacket(encLevel EncryptionLevel, pn PacketNumber) { + for _, t := range m.tracers { + t.AcknowledgedPacket(encLevel, pn) + } +} + +func (m *connTracerMultiplexer) LostPacket(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { + for _, t := range m.tracers { + t.LostPacket(encLevel, pn, reason) + } +} + +func (m *connTracerMultiplexer) UpdatedPTOCount(value uint32) { + for _, t := range m.tracers { + t.UpdatedPTOCount(value) + } +} + +func (m *connTracerMultiplexer) UpdatedKeyFromTLS(encLevel EncryptionLevel, perspective Perspective) { + for _, t := range m.tracers { + t.UpdatedKeyFromTLS(encLevel, perspective) + } +} + +func (m *connTracerMultiplexer) UpdatedKey(generation KeyPhase, remote bool) { + for _, t := range m.tracers { + t.UpdatedKey(generation, remote) + } +} + +func (m *connTracerMultiplexer) DroppedEncryptionLevel(encLevel EncryptionLevel) { + for _, t := range m.tracers { + t.DroppedEncryptionLevel(encLevel) + } +} + +func (m *connTracerMultiplexer) DroppedKey(generation KeyPhase) { + for _, t := range m.tracers { + t.DroppedKey(generation) + } +} + +func (m *connTracerMultiplexer) SetLossTimer(typ TimerType, encLevel EncryptionLevel, exp time.Time) { + for _, t := range m.tracers { + t.SetLossTimer(typ, encLevel, exp) + } +} + +func (m *connTracerMultiplexer) LossTimerExpired(typ TimerType, encLevel EncryptionLevel) { + for _, t := range m.tracers { + t.LossTimerExpired(typ, encLevel) + } +} + +func (m *connTracerMultiplexer) LossTimerCanceled() { + for _, t := range m.tracers { + t.LossTimerCanceled() + } +} + +func (m *connTracerMultiplexer) Debug(name, msg string) { + for _, t := range m.tracers { + t.Debug(name, msg) + } +} + +func (m *connTracerMultiplexer) Close() { + for _, t := range m.tracers { + t.Close() + } +} diff --git a/internal/quic-go/logging/multiplex_test.go b/internal/quic-go/logging/multiplex_test.go new file mode 100644 index 00000000..e9458d81 --- /dev/null +++ b/internal/quic-go/logging/multiplex_test.go @@ -0,0 +1,266 @@ +package logging + +import ( + "context" + "errors" + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Tracing", func() { + Context("Tracer", func() { + It("returns a nil tracer if no tracers are passed in", func() { + Expect(NewMultiplexedTracer()).To(BeNil()) + }) + + It("returns the raw tracer if only one tracer is passed in", func() { + tr := NewMockTracer(mockCtrl) + tracer := NewMultiplexedTracer(tr) + Expect(tracer).To(BeAssignableToTypeOf(&MockTracer{})) + }) + + Context("tracing events", func() { + var ( + tracer Tracer + tr1, tr2 *MockTracer + ) + + BeforeEach(func() { + tr1 = NewMockTracer(mockCtrl) + tr2 = NewMockTracer(mockCtrl) + tracer = NewMultiplexedTracer(tr1, tr2) + }) + + It("multiplexes the TracerForConnection call", func() { + ctx := context.Background() + tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + }) + + It("uses multiple connection tracers", func() { + ctx := context.Background() + ctr1 := NewMockConnectionTracer(mockCtrl) + ctr2 := NewMockConnectionTracer(mockCtrl) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2) + tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) + ctr1.EXPECT().LossTimerCanceled() + ctr2.EXPECT().LossTimerCanceled() + tr.LossTimerCanceled() + }) + + It("handles tracers that return a nil ConnectionTracer", func() { + ctx := context.Background() + ctr1 := NewMockConnectionTracer(mockCtrl) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) + tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) + ctr1.EXPECT().LossTimerCanceled() + tr.LossTimerCanceled() + }) + + It("returns nil when all tracers return a nil ConnectionTracer", func() { + ctx := context.Background() + tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + Expect(tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil()) + }) + + It("traces the PacketSent event", func() { + remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} + hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} + f := &MaxDataFrame{MaximumData: 1337} + tr1.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) + tr2.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) + tracer.SentPacket(remote, hdr, 1024, []Frame{f}) + }) + + It("traces the PacketDropped event", func() { + remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} + tr1.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) + tr2.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) + tracer.DroppedPacket(remote, PacketTypeRetry, 1024, PacketDropDuplicate) + }) + }) + }) + + Context("Connection Tracer", func() { + var ( + tracer ConnectionTracer + tr1 *MockConnectionTracer + tr2 *MockConnectionTracer + ) + + BeforeEach(func() { + tr1 = NewMockConnectionTracer(mockCtrl) + tr2 = NewMockConnectionTracer(mockCtrl) + tracer = NewMultiplexedConnectionTracer(tr1, tr2) + }) + + It("trace the ConnectionStarted event", func() { + local := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4)} + remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} + tr1.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) + tr2.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) + tracer.StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) + }) + + It("traces the ClosedConnection event", func() { + e := errors.New("test err") + tr1.EXPECT().ClosedConnection(e) + tr2.EXPECT().ClosedConnection(e) + tracer.ClosedConnection(e) + }) + + It("traces the SentTransportParameters event", func() { + tp := &wire.TransportParameters{InitialMaxData: 1337} + tr1.EXPECT().SentTransportParameters(tp) + tr2.EXPECT().SentTransportParameters(tp) + tracer.SentTransportParameters(tp) + }) + + It("traces the ReceivedTransportParameters event", func() { + tp := &wire.TransportParameters{InitialMaxData: 1337} + tr1.EXPECT().ReceivedTransportParameters(tp) + tr2.EXPECT().ReceivedTransportParameters(tp) + tracer.ReceivedTransportParameters(tp) + }) + + It("traces the RestoredTransportParameters event", func() { + tp := &wire.TransportParameters{InitialMaxData: 1337} + tr1.EXPECT().RestoredTransportParameters(tp) + tr2.EXPECT().RestoredTransportParameters(tp) + tracer.RestoredTransportParameters(tp) + }) + + It("traces the SentPacket event", func() { + hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} + ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}} + ping := &PingFrame{} + tr1.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping}) + tr2.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping}) + tracer.SentPacket(hdr, 1337, ack, []Frame{ping}) + }) + + It("traces the ReceivedVersionNegotiationPacket event", func() { + hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} + tr1.EXPECT().ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) + tr2.EXPECT().ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) + tracer.ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) + }) + + It("traces the ReceivedRetry event", func() { + hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} + tr1.EXPECT().ReceivedRetry(hdr) + tr2.EXPECT().ReceivedRetry(hdr) + tracer.ReceivedRetry(hdr) + }) + + It("traces the ReceivedPacket event", func() { + hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} + ping := &PingFrame{} + tr1.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) + tr2.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) + tracer.ReceivedPacket(hdr, 1337, []Frame{ping}) + }) + + It("traces the BufferedPacket event", func() { + tr1.EXPECT().BufferedPacket(PacketTypeHandshake) + tr2.EXPECT().BufferedPacket(PacketTypeHandshake) + tracer.BufferedPacket(PacketTypeHandshake) + }) + + It("traces the DroppedPacket event", func() { + tr1.EXPECT().DroppedPacket(PacketTypeInitial, ByteCount(1337), PacketDropHeaderParseError) + tr2.EXPECT().DroppedPacket(PacketTypeInitial, ByteCount(1337), PacketDropHeaderParseError) + tracer.DroppedPacket(PacketTypeInitial, 1337, PacketDropHeaderParseError) + }) + + It("traces the UpdatedCongestionState event", func() { + tr1.EXPECT().UpdatedCongestionState(CongestionStateRecovery) + tr2.EXPECT().UpdatedCongestionState(CongestionStateRecovery) + tracer.UpdatedCongestionState(CongestionStateRecovery) + }) + + It("traces the UpdatedMetrics event", func() { + rttStats := &RTTStats{} + rttStats.UpdateRTT(time.Second, 0, time.Now()) + tr1.EXPECT().UpdatedMetrics(rttStats, ByteCount(1337), ByteCount(42), 13) + tr2.EXPECT().UpdatedMetrics(rttStats, ByteCount(1337), ByteCount(42), 13) + tracer.UpdatedMetrics(rttStats, 1337, 42, 13) + }) + + It("traces the AcknowledgedPacket event", func() { + tr1.EXPECT().AcknowledgedPacket(EncryptionHandshake, PacketNumber(42)) + tr2.EXPECT().AcknowledgedPacket(EncryptionHandshake, PacketNumber(42)) + tracer.AcknowledgedPacket(EncryptionHandshake, 42) + }) + + It("traces the LostPacket event", func() { + tr1.EXPECT().LostPacket(EncryptionHandshake, PacketNumber(42), PacketLossReorderingThreshold) + tr2.EXPECT().LostPacket(EncryptionHandshake, PacketNumber(42), PacketLossReorderingThreshold) + tracer.LostPacket(EncryptionHandshake, 42, PacketLossReorderingThreshold) + }) + + It("traces the UpdatedPTOCount event", func() { + tr1.EXPECT().UpdatedPTOCount(uint32(88)) + tr2.EXPECT().UpdatedPTOCount(uint32(88)) + tracer.UpdatedPTOCount(88) + }) + + It("traces the UpdatedKeyFromTLS event", func() { + tr1.EXPECT().UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) + tr2.EXPECT().UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) + tracer.UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) + }) + + It("traces the UpdatedKey event", func() { + tr1.EXPECT().UpdatedKey(KeyPhase(42), true) + tr2.EXPECT().UpdatedKey(KeyPhase(42), true) + tracer.UpdatedKey(KeyPhase(42), true) + }) + + It("traces the DroppedEncryptionLevel event", func() { + tr1.EXPECT().DroppedEncryptionLevel(EncryptionHandshake) + tr2.EXPECT().DroppedEncryptionLevel(EncryptionHandshake) + tracer.DroppedEncryptionLevel(EncryptionHandshake) + }) + + It("traces the DroppedKey event", func() { + tr1.EXPECT().DroppedKey(KeyPhase(123)) + tr2.EXPECT().DroppedKey(KeyPhase(123)) + tracer.DroppedKey(123) + }) + + It("traces the SetLossTimer event", func() { + now := time.Now() + tr1.EXPECT().SetLossTimer(TimerTypePTO, EncryptionHandshake, now) + tr2.EXPECT().SetLossTimer(TimerTypePTO, EncryptionHandshake, now) + tracer.SetLossTimer(TimerTypePTO, EncryptionHandshake, now) + }) + + It("traces the LossTimerExpired event", func() { + tr1.EXPECT().LossTimerExpired(TimerTypePTO, EncryptionHandshake) + tr2.EXPECT().LossTimerExpired(TimerTypePTO, EncryptionHandshake) + tracer.LossTimerExpired(TimerTypePTO, EncryptionHandshake) + }) + + It("traces the LossTimerCanceled event", func() { + tr1.EXPECT().LossTimerCanceled() + tr2.EXPECT().LossTimerCanceled() + tracer.LossTimerCanceled() + }) + + It("traces the Close event", func() { + tr1.EXPECT().Close() + tr2.EXPECT().Close() + tracer.Close() + }) + }) +}) diff --git a/internal/quic-go/logging/packet_header.go b/internal/quic-go/logging/packet_header.go new file mode 100644 index 00000000..9bb397dd --- /dev/null +++ b/internal/quic-go/logging/packet_header.go @@ -0,0 +1,27 @@ +package logging + +import ( + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// PacketTypeFromHeader determines the packet type from a *wire.Header. +func PacketTypeFromHeader(hdr *Header) PacketType { + if !hdr.IsLongHeader { + return PacketType1RTT + } + if hdr.Version == 0 { + return PacketTypeVersionNegotiation + } + switch hdr.Type { + case protocol.PacketTypeInitial: + return PacketTypeInitial + case protocol.PacketTypeHandshake: + return PacketTypeHandshake + case protocol.PacketType0RTT: + return PacketType0RTT + case protocol.PacketTypeRetry: + return PacketTypeRetry + default: + return PacketTypeNotDetermined + } +} diff --git a/internal/quic-go/logging/packet_header_test.go b/internal/quic-go/logging/packet_header_test.go new file mode 100644 index 00000000..de8b3e68 --- /dev/null +++ b/internal/quic-go/logging/packet_header_test.go @@ -0,0 +1,60 @@ +package logging + +import ( + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Packet Header", func() { + Context("determining the packet type from the header", func() { + It("recognizes Initial packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Version: protocol.VersionTLS, + })).To(Equal(PacketTypeInitial)) + }) + + It("recognizes Handshake packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Version: protocol.VersionTLS, + })).To(Equal(PacketTypeHandshake)) + }) + + It("recognizes Retry packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + Version: protocol.VersionTLS, + })).To(Equal(PacketTypeRetry)) + }) + + It("recognizes 0-RTT packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketType0RTT, + Version: protocol.VersionTLS, + })).To(Equal(PacketType0RTT)) + }) + + It("recognizes Version Negotiation packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{IsLongHeader: true})).To(Equal(PacketTypeVersionNegotiation)) + }) + + It("recognizes 1-RTT packets", func() { + Expect(PacketTypeFromHeader(&wire.Header{})).To(Equal(PacketType1RTT)) + }) + + It("handles unrecognized packet types", func() { + Expect(PacketTypeFromHeader(&wire.Header{ + IsLongHeader: true, + Version: protocol.VersionTLS, + })).To(Equal(PacketTypeNotDetermined)) + }) + }) +}) diff --git a/internal/quic-go/logging/types.go b/internal/quic-go/logging/types.go new file mode 100644 index 00000000..ad800692 --- /dev/null +++ b/internal/quic-go/logging/types.go @@ -0,0 +1,94 @@ +package logging + +// PacketType is the packet type of a QUIC packet +type PacketType uint8 + +const ( + // PacketTypeInitial is the packet type of an Initial packet + PacketTypeInitial PacketType = iota + // PacketTypeHandshake is the packet type of a Handshake packet + PacketTypeHandshake + // PacketTypeRetry is the packet type of a Retry packet + PacketTypeRetry + // PacketType0RTT is the packet type of a 0-RTT packet + PacketType0RTT + // PacketTypeVersionNegotiation is the packet type of a Version Negotiation packet + PacketTypeVersionNegotiation + // PacketType1RTT is a 1-RTT packet + PacketType1RTT + // PacketTypeStatelessReset is a stateless reset + PacketTypeStatelessReset + // PacketTypeNotDetermined is the packet type when it could not be determined + PacketTypeNotDetermined +) + +type PacketLossReason uint8 + +const ( + // PacketLossReorderingThreshold: when a packet is deemed lost due to reordering threshold + PacketLossReorderingThreshold PacketLossReason = iota + // PacketLossTimeThreshold: when a packet is deemed lost due to time threshold + PacketLossTimeThreshold +) + +type PacketDropReason uint8 + +const ( + // PacketDropKeyUnavailable is used when a packet is dropped because keys are unavailable + PacketDropKeyUnavailable PacketDropReason = iota + // PacketDropUnknownConnectionID is used when a packet is dropped because the connection ID is unknown + PacketDropUnknownConnectionID + // PacketDropHeaderParseError is used when a packet is dropped because header parsing failed + PacketDropHeaderParseError + // PacketDropPayloadDecryptError is used when a packet is dropped because decrypting the payload failed + PacketDropPayloadDecryptError + // PacketDropProtocolViolation is used when a packet is dropped due to a protocol violation + PacketDropProtocolViolation + // PacketDropDOSPrevention is used when a packet is dropped to mitigate a DoS attack + PacketDropDOSPrevention + // PacketDropUnsupportedVersion is used when a packet is dropped because the version is not supported + PacketDropUnsupportedVersion + // PacketDropUnexpectedPacket is used when an unexpected packet is received + PacketDropUnexpectedPacket + // PacketDropUnexpectedSourceConnectionID is used when a packet with an unexpected source connection ID is received + PacketDropUnexpectedSourceConnectionID + // PacketDropUnexpectedVersion is used when a packet with an unexpected version is received + PacketDropUnexpectedVersion + // PacketDropDuplicate is used when a duplicate packet is received + PacketDropDuplicate +) + +// TimerType is the type of the loss detection timer +type TimerType uint8 + +const ( + // TimerTypeACK is the timer type for the early retransmit timer + TimerTypeACK TimerType = iota + // TimerTypePTO is the timer type for the PTO retransmit timer + TimerTypePTO +) + +// TimeoutReason is the reason why a connection is closed +type TimeoutReason uint8 + +const ( + // TimeoutReasonHandshake is used when the connection is closed due to a handshake timeout + // This reason is not defined in the qlog draft, but very useful for debugging. + TimeoutReasonHandshake TimeoutReason = iota + // TimeoutReasonIdle is used when the connection is closed due to an idle timeout + // This reason is not defined in the qlog draft, but very useful for debugging. + TimeoutReasonIdle +) + +type CongestionState uint8 + +const ( + // CongestionStateSlowStart is the slow start phase of Reno / Cubic + CongestionStateSlowStart CongestionState = iota + // CongestionStateCongestionAvoidance is the slow start phase of Reno / Cubic + CongestionStateCongestionAvoidance + // CongestionStateRecovery is the recovery phase of Reno / Cubic + CongestionStateRecovery + // CongestionStateApplicationLimited means that the congestion controller is application limited + CongestionStateApplicationLimited +) diff --git a/internal/quic-go/logutils/frame.go b/internal/quic-go/logutils/frame.go new file mode 100644 index 00000000..9076c8f0 --- /dev/null +++ b/internal/quic-go/logutils/frame.go @@ -0,0 +1,33 @@ +package logutils + +import ( + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// ConvertFrame converts a wire.Frame into a logging.Frame. +// This makes it possible for external packages to access the frames. +// Furthermore, it removes the data slices from CRYPTO and STREAM frames. +func ConvertFrame(frame wire.Frame) logging.Frame { + switch f := frame.(type) { + case *wire.CryptoFrame: + return &logging.CryptoFrame{ + Offset: f.Offset, + Length: protocol.ByteCount(len(f.Data)), + } + case *wire.StreamFrame: + return &logging.StreamFrame{ + StreamID: f.StreamID, + Offset: f.Offset, + Length: f.DataLen(), + Fin: f.Fin, + } + case *wire.DatagramFrame: + return &logging.DatagramFrame{ + Length: logging.ByteCount(len(f.Data)), + } + default: + return logging.Frame(frame) + } +} diff --git a/internal/quic-go/logutils/frame_test.go b/internal/quic-go/logutils/frame_test.go new file mode 100644 index 00000000..9a1acc13 --- /dev/null +++ b/internal/quic-go/logutils/frame_test.go @@ -0,0 +1,51 @@ +package logutils + +import ( + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/wire" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("CRYPTO frame", func() { + It("converts CRYPTO frames", func() { + f := ConvertFrame(&wire.CryptoFrame{ + Offset: 1234, + Data: []byte("foobar"), + }) + Expect(f).To(BeAssignableToTypeOf(&logging.CryptoFrame{})) + cf := f.(*logging.CryptoFrame) + Expect(cf.Offset).To(Equal(logging.ByteCount(1234))) + Expect(cf.Length).To(Equal(logging.ByteCount(6))) + }) + + It("converts STREAM frames", func() { + f := ConvertFrame(&wire.StreamFrame{ + StreamID: 42, + Offset: 1234, + Data: []byte("foo"), + Fin: true, + }) + Expect(f).To(BeAssignableToTypeOf(&logging.StreamFrame{})) + sf := f.(*logging.StreamFrame) + Expect(sf.StreamID).To(Equal(logging.StreamID(42))) + Expect(sf.Offset).To(Equal(logging.ByteCount(1234))) + Expect(sf.Length).To(Equal(logging.ByteCount(3))) + Expect(sf.Fin).To(BeTrue()) + }) + + It("converts DATAGRAM frames", func() { + f := ConvertFrame(&wire.DatagramFrame{Data: []byte("foobar")}) + Expect(f).To(BeAssignableToTypeOf(&logging.DatagramFrame{})) + df := f.(*logging.DatagramFrame) + Expect(df.Length).To(Equal(logging.ByteCount(6))) + }) + + It("converts other frames", func() { + f := ConvertFrame(&wire.MaxDataFrame{MaximumData: 1234}) + Expect(f).To(BeAssignableToTypeOf(&logging.MaxDataFrame{})) + Expect(f).ToNot(BeAssignableToTypeOf(&logging.MaxStreamDataFrame{})) + mdf := f.(*logging.MaxDataFrame) + Expect(mdf.MaximumData).To(Equal(logging.ByteCount(1234))) + }) +}) diff --git a/internal/quic-go/logutils/logutils_suite_test.go b/internal/quic-go/logutils/logutils_suite_test.go new file mode 100644 index 00000000..dc496b2d --- /dev/null +++ b/internal/quic-go/logutils/logutils_suite_test.go @@ -0,0 +1,13 @@ +package logutils + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestLogutils(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Logutils Suite") +} diff --git a/internal/quic-go/mock_ack_frame_source_test.go b/internal/quic-go/mock_ack_frame_source_test.go new file mode 100644 index 00000000..4d498553 --- /dev/null +++ b/internal/quic-go/mock_ack_frame_source_test.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: packet_packer.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockAckFrameSource is a mock of AckFrameSource interface. +type MockAckFrameSource struct { + ctrl *gomock.Controller + recorder *MockAckFrameSourceMockRecorder +} + +// MockAckFrameSourceMockRecorder is the mock recorder for MockAckFrameSource. +type MockAckFrameSourceMockRecorder struct { + mock *MockAckFrameSource +} + +// NewMockAckFrameSource creates a new mock instance. +func NewMockAckFrameSource(ctrl *gomock.Controller) *MockAckFrameSource { + mock := &MockAckFrameSource{ctrl: ctrl} + mock.recorder = &MockAckFrameSourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAckFrameSource) EXPECT() *MockAckFrameSourceMockRecorder { + return m.recorder +} + +// GetAckFrame mocks base method. +func (m *MockAckFrameSource) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAckFrame", encLevel, onlyIfQueued) + ret0, _ := ret[0].(*wire.AckFrame) + return ret0 +} + +// GetAckFrame indicates an expected call of GetAckFrame. +func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(encLevel, onlyIfQueued interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame), encLevel, onlyIfQueued) +} diff --git a/internal/quic-go/mock_batch_conn_test.go b/internal/quic-go/mock_batch_conn_test.go new file mode 100644 index 00000000..74032900 --- /dev/null +++ b/internal/quic-go/mock_batch_conn_test.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: sys_conn_oob.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + ipv4 "golang.org/x/net/ipv4" +) + +// MockBatchConn is a mock of BatchConn interface. +type MockBatchConn struct { + ctrl *gomock.Controller + recorder *MockBatchConnMockRecorder +} + +// MockBatchConnMockRecorder is the mock recorder for MockBatchConn. +type MockBatchConnMockRecorder struct { + mock *MockBatchConn +} + +// NewMockBatchConn creates a new mock instance. +func NewMockBatchConn(ctrl *gomock.Controller) *MockBatchConn { + mock := &MockBatchConn{ctrl: ctrl} + mock.recorder = &MockBatchConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBatchConn) EXPECT() *MockBatchConnMockRecorder { + return m.recorder +} + +// ReadBatch mocks base method. +func (m *MockBatchConn) ReadBatch(ms []ipv4.Message, flags int) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadBatch", ms, flags) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadBatch indicates an expected call of ReadBatch. +func (mr *MockBatchConnMockRecorder) ReadBatch(ms, flags interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadBatch", reflect.TypeOf((*MockBatchConn)(nil).ReadBatch), ms, flags) +} diff --git a/internal/quic-go/mock_conn_runner_test.go b/internal/quic-go/mock_conn_runner_test.go new file mode 100644 index 00000000..99d7dc8f --- /dev/null +++ b/internal/quic-go/mock_conn_runner_test.go @@ -0,0 +1,123 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: connection.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockConnRunner is a mock of ConnRunner interface. +type MockConnRunner struct { + ctrl *gomock.Controller + recorder *MockConnRunnerMockRecorder +} + +// MockConnRunnerMockRecorder is the mock recorder for MockConnRunner. +type MockConnRunnerMockRecorder struct { + mock *MockConnRunner +} + +// NewMockConnRunner creates a new mock instance. +func NewMockConnRunner(ctrl *gomock.Controller) *MockConnRunner { + mock := &MockConnRunner{ctrl: ctrl} + mock.recorder = &MockConnRunnerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnRunner) EXPECT() *MockConnRunnerMockRecorder { + return m.recorder +} + +// Add mocks base method. +func (m *MockConnRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Add indicates an expected call of Add. +func (mr *MockConnRunnerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockConnRunner)(nil).Add), arg0, arg1) +} + +// AddResetToken mocks base method. +func (m *MockConnRunner) AddResetToken(arg0 protocol.StatelessResetToken, arg1 packetHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddResetToken", arg0, arg1) +} + +// AddResetToken indicates an expected call of AddResetToken. +func (mr *MockConnRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockConnRunner)(nil).AddResetToken), arg0, arg1) +} + +// GetStatelessResetToken mocks base method. +func (m *MockConnRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) + ret0, _ := ret[0].(protocol.StatelessResetToken) + return ret0 +} + +// GetStatelessResetToken indicates an expected call of GetStatelessResetToken. +func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockConnRunner)(nil).GetStatelessResetToken), arg0) +} + +// Remove mocks base method. +func (m *MockConnRunner) Remove(arg0 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Remove", arg0) +} + +// Remove indicates an expected call of Remove. +func (mr *MockConnRunnerMockRecorder) Remove(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockConnRunner)(nil).Remove), arg0) +} + +// RemoveResetToken mocks base method. +func (m *MockConnRunner) RemoveResetToken(arg0 protocol.StatelessResetToken) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RemoveResetToken", arg0) +} + +// RemoveResetToken indicates an expected call of RemoveResetToken. +func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockConnRunner)(nil).RemoveResetToken), arg0) +} + +// ReplaceWithClosed mocks base method. +func (m *MockConnRunner) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) +} + +// ReplaceWithClosed indicates an expected call of ReplaceWithClosed. +func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1) +} + +// Retire mocks base method. +func (m *MockConnRunner) Retire(arg0 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Retire", arg0) +} + +// Retire indicates an expected call of Retire. +func (mr *MockConnRunnerMockRecorder) Retire(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockConnRunner)(nil).Retire), arg0) +} diff --git a/internal/quic-go/mock_crypto_data_handler_test.go b/internal/quic-go/mock_crypto_data_handler_test.go new file mode 100644 index 00000000..9c70ff2d --- /dev/null +++ b/internal/quic-go/mock_crypto_data_handler_test.go @@ -0,0 +1,49 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: crypto_stream_manager.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockCryptoDataHandler is a mock of CryptoDataHandler interface. +type MockCryptoDataHandler struct { + ctrl *gomock.Controller + recorder *MockCryptoDataHandlerMockRecorder +} + +// MockCryptoDataHandlerMockRecorder is the mock recorder for MockCryptoDataHandler. +type MockCryptoDataHandlerMockRecorder struct { + mock *MockCryptoDataHandler +} + +// NewMockCryptoDataHandler creates a new mock instance. +func NewMockCryptoDataHandler(ctrl *gomock.Controller) *MockCryptoDataHandler { + mock := &MockCryptoDataHandler{ctrl: ctrl} + mock.recorder = &MockCryptoDataHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder { + return m.recorder +} + +// HandleMessage mocks base method. +func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// HandleMessage indicates an expected call of HandleMessage. +func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1) +} diff --git a/internal/quic-go/mock_crypto_stream_test.go b/internal/quic-go/mock_crypto_stream_test.go new file mode 100644 index 00000000..2cdf22de --- /dev/null +++ b/internal/quic-go/mock_crypto_stream_test.go @@ -0,0 +1,121 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: crypto_stream.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockCryptoStream is a mock of CryptoStream interface. +type MockCryptoStream struct { + ctrl *gomock.Controller + recorder *MockCryptoStreamMockRecorder +} + +// MockCryptoStreamMockRecorder is the mock recorder for MockCryptoStream. +type MockCryptoStreamMockRecorder struct { + mock *MockCryptoStream +} + +// NewMockCryptoStream creates a new mock instance. +func NewMockCryptoStream(ctrl *gomock.Controller) *MockCryptoStream { + mock := &MockCryptoStream{ctrl: ctrl} + mock.recorder = &MockCryptoStreamMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder { + return m.recorder +} + +// Finish mocks base method. +func (m *MockCryptoStream) Finish() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Finish") + ret0, _ := ret[0].(error) + return ret0 +} + +// Finish indicates an expected call of Finish. +func (mr *MockCryptoStreamMockRecorder) Finish() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*MockCryptoStream)(nil).Finish)) +} + +// GetCryptoData mocks base method. +func (m *MockCryptoStream) GetCryptoData() []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCryptoData") + ret0, _ := ret[0].([]byte) + return ret0 +} + +// GetCryptoData indicates an expected call of GetCryptoData. +func (mr *MockCryptoStreamMockRecorder) GetCryptoData() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCryptoData", reflect.TypeOf((*MockCryptoStream)(nil).GetCryptoData)) +} + +// HandleCryptoFrame mocks base method. +func (m *MockCryptoStream) HandleCryptoFrame(arg0 *wire.CryptoFrame) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleCryptoFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleCryptoFrame indicates an expected call of HandleCryptoFrame. +func (mr *MockCryptoStreamMockRecorder) HandleCryptoFrame(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).HandleCryptoFrame), arg0) +} + +// HasData mocks base method. +func (m *MockCryptoStream) HasData() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasData indicates an expected call of HasData. +func (mr *MockCryptoStreamMockRecorder) HasData() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasData", reflect.TypeOf((*MockCryptoStream)(nil).HasData)) +} + +// PopCryptoFrame mocks base method. +func (m *MockCryptoStream) PopCryptoFrame(arg0 protocol.ByteCount) *wire.CryptoFrame { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PopCryptoFrame", arg0) + ret0, _ := ret[0].(*wire.CryptoFrame) + return ret0 +} + +// PopCryptoFrame indicates an expected call of PopCryptoFrame. +func (mr *MockCryptoStreamMockRecorder) PopCryptoFrame(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).PopCryptoFrame), arg0) +} + +// Write mocks base method. +func (m *MockCryptoStream) Write(p []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", p) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockCryptoStreamMockRecorder) Write(p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), p) +} diff --git a/internal/quic-go/mock_frame_source_test.go b/internal/quic-go/mock_frame_source_test.go new file mode 100644 index 00000000..efe0bccf --- /dev/null +++ b/internal/quic-go/mock_frame_source_test.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: packet_packer.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockFrameSource is a mock of FrameSource interface. +type MockFrameSource struct { + ctrl *gomock.Controller + recorder *MockFrameSourceMockRecorder +} + +// MockFrameSourceMockRecorder is the mock recorder for MockFrameSource. +type MockFrameSourceMockRecorder struct { + mock *MockFrameSource +} + +// NewMockFrameSource creates a new mock instance. +func NewMockFrameSource(ctrl *gomock.Controller) *MockFrameSource { + mock := &MockFrameSource{ctrl: ctrl} + mock.recorder = &MockFrameSourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFrameSource) EXPECT() *MockFrameSourceMockRecorder { + return m.recorder +} + +// AppendControlFrames mocks base method. +func (m *MockFrameSource) AppendControlFrames(arg0 []ackhandler.Frame, arg1 protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AppendControlFrames", arg0, arg1) + ret0, _ := ret[0].([]ackhandler.Frame) + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 +} + +// AppendControlFrames indicates an expected call of AppendControlFrames. +func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendControlFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendControlFrames), arg0, arg1) +} + +// AppendStreamFrames mocks base method. +func (m *MockFrameSource) AppendStreamFrames(arg0 []ackhandler.Frame, arg1 protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AppendStreamFrames", arg0, arg1) + ret0, _ := ret[0].([]ackhandler.Frame) + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 +} + +// AppendStreamFrames indicates an expected call of AppendStreamFrames. +func (mr *MockFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendStreamFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendStreamFrames), arg0, arg1) +} + +// HasData mocks base method. +func (m *MockFrameSource) HasData() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasData indicates an expected call of HasData. +func (mr *MockFrameSourceMockRecorder) HasData() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasData", reflect.TypeOf((*MockFrameSource)(nil).HasData)) +} diff --git a/internal/quic-go/mock_mtu_discoverer_test.go b/internal/quic-go/mock_mtu_discoverer_test.go new file mode 100644 index 00000000..57993be1 --- /dev/null +++ b/internal/quic-go/mock_mtu_discoverer_test.go @@ -0,0 +1,66 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: mtu_discoverer.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockMtuDiscoverer is a mock of MtuDiscoverer interface. +type MockMtuDiscoverer struct { + ctrl *gomock.Controller + recorder *MockMtuDiscovererMockRecorder +} + +// MockMtuDiscovererMockRecorder is the mock recorder for MockMtuDiscoverer. +type MockMtuDiscovererMockRecorder struct { + mock *MockMtuDiscoverer +} + +// NewMockMtuDiscoverer creates a new mock instance. +func NewMockMtuDiscoverer(ctrl *gomock.Controller) *MockMtuDiscoverer { + mock := &MockMtuDiscoverer{ctrl: ctrl} + mock.recorder = &MockMtuDiscovererMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMtuDiscoverer) EXPECT() *MockMtuDiscovererMockRecorder { + return m.recorder +} + +// GetPing mocks base method. +func (m *MockMtuDiscoverer) GetPing() (ackhandler.Frame, protocol.ByteCount) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPing") + ret0, _ := ret[0].(ackhandler.Frame) + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 +} + +// GetPing indicates an expected call of GetPing. +func (mr *MockMtuDiscovererMockRecorder) GetPing() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPing", reflect.TypeOf((*MockMtuDiscoverer)(nil).GetPing)) +} + +// ShouldSendProbe mocks base method. +func (m *MockMtuDiscoverer) ShouldSendProbe(now time.Time) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ShouldSendProbe", now) + ret0, _ := ret[0].(bool) + return ret0 +} + +// ShouldSendProbe indicates an expected call of ShouldSendProbe. +func (mr *MockMtuDiscovererMockRecorder) ShouldSendProbe(now interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendProbe", reflect.TypeOf((*MockMtuDiscoverer)(nil).ShouldSendProbe), now) +} diff --git a/internal/quic-go/mock_multiplexer_test.go b/internal/quic-go/mock_multiplexer_test.go new file mode 100644 index 00000000..2bce0112 --- /dev/null +++ b/internal/quic-go/mock_multiplexer_test.go @@ -0,0 +1,65 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: multiplexer.go + +// Package quic is a generated GoMock package. +package quic + +import ( + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + logging "github.com/imroc/req/v3/internal/quic-go/logging" +) + +// MockMultiplexer is a mock of Multiplexer interface. +type MockMultiplexer struct { + ctrl *gomock.Controller + recorder *MockMultiplexerMockRecorder +} + +// MockMultiplexerMockRecorder is the mock recorder for MockMultiplexer. +type MockMultiplexerMockRecorder struct { + mock *MockMultiplexer +} + +// NewMockMultiplexer creates a new mock instance. +func NewMockMultiplexer(ctrl *gomock.Controller) *MockMultiplexer { + mock := &MockMultiplexer{ctrl: ctrl} + mock.recorder = &MockMultiplexerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder { + return m.recorder +} + +// AddConn mocks base method. +func (m *MockMultiplexer) AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddConn", c, connIDLen, statelessResetKey, tracer) + ret0, _ := ret[0].(packetHandlerManager) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddConn indicates an expected call of AddConn. +func (mr *MockMultiplexerMockRecorder) AddConn(c, connIDLen, statelessResetKey, tracer interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), c, connIDLen, statelessResetKey, tracer) +} + +// RemoveConn mocks base method. +func (m *MockMultiplexer) RemoveConn(arg0 indexableConn) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveConn", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveConn indicates an expected call of RemoveConn. +func (mr *MockMultiplexerMockRecorder) RemoveConn(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveConn", reflect.TypeOf((*MockMultiplexer)(nil).RemoveConn), arg0) +} diff --git a/internal/quic-go/mock_packer_test.go b/internal/quic-go/mock_packer_test.go new file mode 100644 index 00000000..ec1e5cca --- /dev/null +++ b/internal/quic-go/mock_packer_test.go @@ -0,0 +1,179 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: packet_packer.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + qerr "github.com/imroc/req/v3/internal/quic-go/qerr" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockPacker is a mock of Packer interface. +type MockPacker struct { + ctrl *gomock.Controller + recorder *MockPackerMockRecorder +} + +// MockPackerMockRecorder is the mock recorder for MockPacker. +type MockPackerMockRecorder struct { + mock *MockPacker +} + +// NewMockPacker creates a new mock instance. +func NewMockPacker(ctrl *gomock.Controller) *MockPacker { + mock := &MockPacker{ctrl: ctrl} + mock.recorder = &MockPackerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPacker) EXPECT() *MockPackerMockRecorder { + return m.recorder +} + +// HandleTransportParameters mocks base method. +func (m *MockPacker) HandleTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "HandleTransportParameters", arg0) +} + +// HandleTransportParameters indicates an expected call of HandleTransportParameters. +func (mr *MockPackerMockRecorder) HandleTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleTransportParameters", reflect.TypeOf((*MockPacker)(nil).HandleTransportParameters), arg0) +} + +// MaybePackAckPacket mocks base method. +func (m *MockPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MaybePackAckPacket", handshakeConfirmed) + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MaybePackAckPacket indicates an expected call of MaybePackAckPacket. +func (mr *MockPackerMockRecorder) MaybePackAckPacket(handshakeConfirmed interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket), handshakeConfirmed) +} + +// MaybePackProbePacket mocks base method. +func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel) (*packedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MaybePackProbePacket", arg0) + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MaybePackProbePacket indicates an expected call of MaybePackProbePacket. +func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0) +} + +// PackApplicationClose mocks base method. +func (m *MockPacker) PackApplicationClose(arg0 *qerr.ApplicationError) (*coalescedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PackApplicationClose", arg0) + ret0, _ := ret[0].(*coalescedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackApplicationClose indicates an expected call of PackApplicationClose. +func (mr *MockPackerMockRecorder) PackApplicationClose(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackApplicationClose", reflect.TypeOf((*MockPacker)(nil).PackApplicationClose), arg0) +} + +// PackCoalescedPacket mocks base method. +func (m *MockPacker) PackCoalescedPacket() (*coalescedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PackCoalescedPacket") + ret0, _ := ret[0].(*coalescedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackCoalescedPacket indicates an expected call of PackCoalescedPacket. +func (mr *MockPackerMockRecorder) PackCoalescedPacket() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket)) +} + +// PackConnectionClose mocks base method. +func (m *MockPacker) PackConnectionClose(arg0 *qerr.TransportError) (*coalescedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PackConnectionClose", arg0) + ret0, _ := ret[0].(*coalescedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackConnectionClose indicates an expected call of PackConnectionClose. +func (mr *MockPackerMockRecorder) PackConnectionClose(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0) +} + +// PackMTUProbePacket mocks base method. +func (m *MockPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PackMTUProbePacket", ping, size) + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackMTUProbePacket indicates an expected call of PackMTUProbePacket. +func (mr *MockPackerMockRecorder) PackMTUProbePacket(ping, size interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), ping, size) +} + +// PackPacket mocks base method. +func (m *MockPacker) PackPacket() (*packedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PackPacket") + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackPacket indicates an expected call of PackPacket. +func (mr *MockPackerMockRecorder) PackPacket() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket)) +} + +// SetMaxPacketSize mocks base method. +func (m *MockPacker) SetMaxPacketSize(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetMaxPacketSize", arg0) +} + +// SetMaxPacketSize indicates an expected call of SetMaxPacketSize. +func (mr *MockPackerMockRecorder) SetMaxPacketSize(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxPacketSize", reflect.TypeOf((*MockPacker)(nil).SetMaxPacketSize), arg0) +} + +// SetToken mocks base method. +func (m *MockPacker) SetToken(arg0 []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetToken", arg0) +} + +// SetToken indicates an expected call of SetToken. +func (mr *MockPackerMockRecorder) SetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetToken", reflect.TypeOf((*MockPacker)(nil).SetToken), arg0) +} diff --git a/internal/quic-go/mock_packet_handler_manager_test.go b/internal/quic-go/mock_packet_handler_manager_test.go new file mode 100644 index 00000000..eb8539da --- /dev/null +++ b/internal/quic-go/mock_packet_handler_manager_test.go @@ -0,0 +1,175 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: server.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockPacketHandlerManager is a mock of PacketHandlerManager interface. +type MockPacketHandlerManager struct { + ctrl *gomock.Controller + recorder *MockPacketHandlerManagerMockRecorder +} + +// MockPacketHandlerManagerMockRecorder is the mock recorder for MockPacketHandlerManager. +type MockPacketHandlerManagerMockRecorder struct { + mock *MockPacketHandlerManager +} + +// NewMockPacketHandlerManager creates a new mock instance. +func NewMockPacketHandlerManager(ctrl *gomock.Controller) *MockPacketHandlerManager { + mock := &MockPacketHandlerManager{ctrl: ctrl} + mock.recorder = &MockPacketHandlerManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorder { + return m.recorder +} + +// Add mocks base method. +func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Add indicates an expected call of Add. +func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) +} + +// AddResetToken mocks base method. +func (m *MockPacketHandlerManager) AddResetToken(arg0 protocol.StatelessResetToken, arg1 packetHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddResetToken", arg0, arg1) +} + +// AddResetToken indicates an expected call of AddResetToken. +func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1) +} + +// AddWithConnID mocks base method. +func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() packetHandler) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + return ret0 +} + +// AddWithConnID indicates an expected call of AddWithConnID. +func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2) +} + +// CloseServer mocks base method. +func (m *MockPacketHandlerManager) CloseServer() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CloseServer") +} + +// CloseServer indicates an expected call of CloseServer. +func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) +} + +// Destroy mocks base method. +func (m *MockPacketHandlerManager) Destroy() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Destroy") + ret0, _ := ret[0].(error) + return ret0 +} + +// Destroy indicates an expected call of Destroy. +func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy)) +} + +// GetStatelessResetToken mocks base method. +func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) + ret0, _ := ret[0].(protocol.StatelessResetToken) + return ret0 +} + +// GetStatelessResetToken indicates an expected call of GetStatelessResetToken. +func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0) +} + +// Remove mocks base method. +func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Remove", arg0) +} + +// Remove indicates an expected call of Remove. +func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) +} + +// RemoveResetToken mocks base method. +func (m *MockPacketHandlerManager) RemoveResetToken(arg0 protocol.StatelessResetToken) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RemoveResetToken", arg0) +} + +// RemoveResetToken indicates an expected call of RemoveResetToken. +func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0) +} + +// ReplaceWithClosed mocks base method. +func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) +} + +// ReplaceWithClosed indicates an expected call of ReplaceWithClosed. +func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1) +} + +// Retire mocks base method. +func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Retire", arg0) +} + +// Retire indicates an expected call of Retire. +func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0) +} + +// SetServer mocks base method. +func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetServer", arg0) +} + +// SetServer indicates an expected call of SetServer. +func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0) +} diff --git a/internal/quic-go/mock_packet_handler_test.go b/internal/quic-go/mock_packet_handler_test.go new file mode 100644 index 00000000..82bb383e --- /dev/null +++ b/internal/quic-go/mock_packet_handler_test.go @@ -0,0 +1,85 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: server.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockPacketHandler is a mock of PacketHandler interface. +type MockPacketHandler struct { + ctrl *gomock.Controller + recorder *MockPacketHandlerMockRecorder +} + +// MockPacketHandlerMockRecorder is the mock recorder for MockPacketHandler. +type MockPacketHandlerMockRecorder struct { + mock *MockPacketHandler +} + +// NewMockPacketHandler creates a new mock instance. +func NewMockPacketHandler(ctrl *gomock.Controller) *MockPacketHandler { + mock := &MockPacketHandler{ctrl: ctrl} + mock.recorder = &MockPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPacketHandler) EXPECT() *MockPacketHandlerMockRecorder { + return m.recorder +} + +// destroy mocks base method. +func (m *MockPacketHandler) destroy(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "destroy", arg0) +} + +// destroy indicates an expected call of destroy. +func (mr *MockPacketHandlerMockRecorder) destroy(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0) +} + +// getPerspective mocks base method. +func (m *MockPacketHandler) getPerspective() protocol.Perspective { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "getPerspective") + ret0, _ := ret[0].(protocol.Perspective) + return ret0 +} + +// getPerspective indicates an expected call of getPerspective. +func (mr *MockPacketHandlerMockRecorder) getPerspective() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockPacketHandler)(nil).getPerspective)) +} + +// handlePacket mocks base method. +func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "handlePacket", arg0) +} + +// handlePacket indicates an expected call of handlePacket. +func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0) +} + +// shutdown mocks base method. +func (m *MockPacketHandler) shutdown() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "shutdown") +} + +// shutdown indicates an expected call of shutdown. +func (mr *MockPacketHandlerMockRecorder) shutdown() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "shutdown", reflect.TypeOf((*MockPacketHandler)(nil).shutdown)) +} diff --git a/internal/quic-go/mock_packetconn_test.go b/internal/quic-go/mock_packetconn_test.go new file mode 100644 index 00000000..d6731e4a --- /dev/null +++ b/internal/quic-go/mock_packetconn_test.go @@ -0,0 +1,137 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: net (interfaces: PacketConn) + +// Package quic is a generated GoMock package. +package quic + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" +) + +// MockPacketConn is a mock of PacketConn interface. +type MockPacketConn struct { + ctrl *gomock.Controller + recorder *MockPacketConnMockRecorder +} + +// MockPacketConnMockRecorder is the mock recorder for MockPacketConn. +type MockPacketConnMockRecorder struct { + mock *MockPacketConn +} + +// NewMockPacketConn creates a new mock instance. +func NewMockPacketConn(ctrl *gomock.Controller) *MockPacketConn { + mock := &MockPacketConn{ctrl: ctrl} + mock.recorder = &MockPacketConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPacketConn) EXPECT() *MockPacketConnMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockPacketConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockPacketConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketConn)(nil).Close)) +} + +// LocalAddr mocks base method. +func (m *MockPacketConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockPacketConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockPacketConn)(nil).LocalAddr)) +} + +// ReadFrom mocks base method. +func (m *MockPacketConn) ReadFrom(arg0 []byte) (int, net.Addr, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadFrom", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(net.Addr) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ReadFrom indicates an expected call of ReadFrom. +func (mr *MockPacketConnMockRecorder) ReadFrom(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFrom", reflect.TypeOf((*MockPacketConn)(nil).ReadFrom), arg0) +} + +// SetDeadline mocks base method. +func (m *MockPacketConn) SetDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline. +func (mr *MockPacketConnMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetDeadline), arg0) +} + +// SetReadDeadline mocks base method. +func (m *MockPacketConn) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockPacketConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetReadDeadline), arg0) +} + +// SetWriteDeadline mocks base method. +func (m *MockPacketConn) SetWriteDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline. +func (mr *MockPacketConnMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetWriteDeadline), arg0) +} + +// WriteTo mocks base method. +func (m *MockPacketConn) WriteTo(arg0 []byte, arg1 net.Addr) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteTo", arg0, arg1) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WriteTo indicates an expected call of WriteTo. +func (mr *MockPacketConnMockRecorder) WriteTo(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockPacketConn)(nil).WriteTo), arg0, arg1) +} diff --git a/internal/quic-go/mock_quic_conn_test.go b/internal/quic-go/mock_quic_conn_test.go new file mode 100644 index 00000000..c79523ed --- /dev/null +++ b/internal/quic-go/mock_quic_conn_test.go @@ -0,0 +1,346 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: server.go + +// Package quic is a generated GoMock package. +package quic + +import ( + context "context" + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockQuicConn is a mock of QuicConn interface. +type MockQuicConn struct { + ctrl *gomock.Controller + recorder *MockQuicConnMockRecorder +} + +// MockQuicConnMockRecorder is the mock recorder for MockQuicConn. +type MockQuicConnMockRecorder struct { + mock *MockQuicConn +} + +// NewMockQuicConn creates a new mock instance. +func NewMockQuicConn(ctrl *gomock.Controller) *MockQuicConn { + mock := &MockQuicConn{ctrl: ctrl} + mock.recorder = &MockQuicConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockQuicConn) EXPECT() *MockQuicConnMockRecorder { + return m.recorder +} + +// AcceptStream mocks base method. +func (m *MockQuicConn) AcceptStream(arg0 context.Context) (Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptStream", arg0) + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptStream indicates an expected call of AcceptStream. +func (mr *MockQuicConnMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicConn)(nil).AcceptStream), arg0) +} + +// AcceptUniStream mocks base method. +func (m *MockQuicConn) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptUniStream", arg0) + ret0, _ := ret[0].(ReceiveStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptUniStream indicates an expected call of AcceptUniStream. +func (mr *MockQuicConnMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicConn)(nil).AcceptUniStream), arg0) +} + +// CloseWithError mocks base method. +func (m *MockQuicConn) CloseWithError(arg0 ApplicationErrorCode, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseWithError indicates an expected call of CloseWithError. +func (mr *MockQuicConnMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicConn)(nil).CloseWithError), arg0, arg1) +} + +// ConnectionState mocks base method. +func (m *MockQuicConn) ConnectionState() ConnectionState { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConnectionState") + ret0, _ := ret[0].(ConnectionState) + return ret0 +} + +// ConnectionState indicates an expected call of ConnectionState. +func (mr *MockQuicConnMockRecorder) ConnectionState() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockQuicConn)(nil).ConnectionState)) +} + +// Context mocks base method. +func (m *MockQuicConn) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockQuicConnMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockQuicConn)(nil).Context)) +} + +// GetVersion mocks base method. +func (m *MockQuicConn) GetVersion() protocol.VersionNumber { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetVersion") + ret0, _ := ret[0].(protocol.VersionNumber) + return ret0 +} + +// GetVersion indicates an expected call of GetVersion. +func (mr *MockQuicConnMockRecorder) GetVersion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockQuicConn)(nil).GetVersion)) +} + +// HandshakeComplete mocks base method. +func (m *MockQuicConn) HandshakeComplete() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandshakeComplete") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// HandshakeComplete indicates an expected call of HandshakeComplete. +func (mr *MockQuicConnMockRecorder) HandshakeComplete() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockQuicConn)(nil).HandshakeComplete)) +} + +// LocalAddr mocks base method. +func (m *MockQuicConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockQuicConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockQuicConn)(nil).LocalAddr)) +} + +// NextConnection mocks base method. +func (m *MockQuicConn) NextConnection() Connection { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextConnection") + ret0, _ := ret[0].(Connection) + return ret0 +} + +// NextConnection indicates an expected call of NextConnection. +func (mr *MockQuicConnMockRecorder) NextConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQuicConn)(nil).NextConnection)) +} + +// OpenStream mocks base method. +func (m *MockQuicConn) OpenStream() (Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenStream") + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStream indicates an expected call of OpenStream. +func (mr *MockQuicConnMockRecorder) OpenStream() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockQuicConn)(nil).OpenStream)) +} + +// OpenStreamSync mocks base method. +func (m *MockQuicConn) OpenStreamSync(arg0 context.Context) (Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenStreamSync", arg0) + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStreamSync indicates an expected call of OpenStreamSync. +func (mr *MockQuicConnMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicConn)(nil).OpenStreamSync), arg0) +} + +// OpenUniStream mocks base method. +func (m *MockQuicConn) OpenUniStream() (SendStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenUniStream") + ret0, _ := ret[0].(SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStream indicates an expected call of OpenUniStream. +func (mr *MockQuicConnMockRecorder) OpenUniStream() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockQuicConn)(nil).OpenUniStream)) +} + +// OpenUniStreamSync mocks base method. +func (m *MockQuicConn) OpenUniStreamSync(arg0 context.Context) (SendStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) + ret0, _ := ret[0].(SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStreamSync indicates an expected call of OpenUniStreamSync. +func (mr *MockQuicConnMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicConn)(nil).OpenUniStreamSync), arg0) +} + +// ReceiveMessage mocks base method. +func (m *MockQuicConn) ReceiveMessage() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceiveMessage") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReceiveMessage indicates an expected call of ReceiveMessage. +func (mr *MockQuicConnMockRecorder) ReceiveMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQuicConn)(nil).ReceiveMessage)) +} + +// RemoteAddr mocks base method. +func (m *MockQuicConn) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr. +func (mr *MockQuicConnMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicConn)(nil).RemoteAddr)) +} + +// SendMessage mocks base method. +func (m *MockQuicConn) SendMessage(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMessage", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMessage indicates an expected call of SendMessage. +func (mr *MockQuicConnMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockQuicConn)(nil).SendMessage), arg0) +} + +// destroy mocks base method. +func (m *MockQuicConn) destroy(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "destroy", arg0) +} + +// destroy indicates an expected call of destroy. +func (mr *MockQuicConnMockRecorder) destroy(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQuicConn)(nil).destroy), arg0) +} + +// earlyConnReady mocks base method. +func (m *MockQuicConn) earlyConnReady() <-chan struct{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "earlyConnReady") + ret0, _ := ret[0].(<-chan struct{}) + return ret0 +} + +// earlyConnReady indicates an expected call of earlyConnReady. +func (mr *MockQuicConnMockRecorder) earlyConnReady() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "earlyConnReady", reflect.TypeOf((*MockQuicConn)(nil).earlyConnReady)) +} + +// getPerspective mocks base method. +func (m *MockQuicConn) getPerspective() protocol.Perspective { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "getPerspective") + ret0, _ := ret[0].(protocol.Perspective) + return ret0 +} + +// getPerspective indicates an expected call of getPerspective. +func (mr *MockQuicConnMockRecorder) getPerspective() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockQuicConn)(nil).getPerspective)) +} + +// handlePacket mocks base method. +func (m *MockQuicConn) handlePacket(arg0 *receivedPacket) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "handlePacket", arg0) +} + +// handlePacket indicates an expected call of handlePacket. +func (mr *MockQuicConnMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockQuicConn)(nil).handlePacket), arg0) +} + +// run mocks base method. +func (m *MockQuicConn) run() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "run") + ret0, _ := ret[0].(error) + return ret0 +} + +// run indicates an expected call of run. +func (mr *MockQuicConnMockRecorder) run() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "run", reflect.TypeOf((*MockQuicConn)(nil).run)) +} + +// shutdown mocks base method. +func (m *MockQuicConn) shutdown() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "shutdown") +} + +// shutdown indicates an expected call of shutdown. +func (mr *MockQuicConnMockRecorder) shutdown() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "shutdown", reflect.TypeOf((*MockQuicConn)(nil).shutdown)) +} diff --git a/internal/quic-go/mock_receive_stream_internal_test.go b/internal/quic-go/mock_receive_stream_internal_test.go new file mode 100644 index 00000000..5389b85f --- /dev/null +++ b/internal/quic-go/mock_receive_stream_internal_test.go @@ -0,0 +1,146 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: receive_stream.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockReceiveStreamI is a mock of ReceiveStreamI interface. +type MockReceiveStreamI struct { + ctrl *gomock.Controller + recorder *MockReceiveStreamIMockRecorder +} + +// MockReceiveStreamIMockRecorder is the mock recorder for MockReceiveStreamI. +type MockReceiveStreamIMockRecorder struct { + mock *MockReceiveStreamI +} + +// NewMockReceiveStreamI creates a new mock instance. +func NewMockReceiveStreamI(ctrl *gomock.Controller) *MockReceiveStreamI { + mock := &MockReceiveStreamI{ctrl: ctrl} + mock.recorder = &MockReceiveStreamIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockReceiveStreamI) EXPECT() *MockReceiveStreamIMockRecorder { + return m.recorder +} + +// CancelRead mocks base method. +func (m *MockReceiveStreamI) CancelRead(arg0 StreamErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelRead", arg0) +} + +// CancelRead indicates an expected call of CancelRead. +func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockReceiveStreamI)(nil).CancelRead), arg0) +} + +// Read mocks base method. +func (m *MockReceiveStreamI) Read(p []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", p) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockReceiveStreamIMockRecorder) Read(p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReceiveStreamI)(nil).Read), p) +} + +// SetReadDeadline mocks base method. +func (m *MockReceiveStreamI) SetReadDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockReceiveStreamI)(nil).SetReadDeadline), t) +} + +// StreamID mocks base method. +func (m *MockReceiveStreamI) StreamID() StreamID { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StreamID") + ret0, _ := ret[0].(StreamID) + return ret0 +} + +// StreamID indicates an expected call of StreamID. +func (mr *MockReceiveStreamIMockRecorder) StreamID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockReceiveStreamI)(nil).StreamID)) +} + +// closeForShutdown mocks base method. +func (m *MockReceiveStreamI) closeForShutdown(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "closeForShutdown", arg0) +} + +// closeForShutdown indicates an expected call of closeForShutdown. +func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockReceiveStreamI)(nil).closeForShutdown), arg0) +} + +// getWindowUpdate mocks base method. +func (m *MockReceiveStreamI) getWindowUpdate() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "getWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// getWindowUpdate indicates an expected call of getWindowUpdate. +func (mr *MockReceiveStreamIMockRecorder) getWindowUpdate() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockReceiveStreamI)(nil).getWindowUpdate)) +} + +// handleResetStreamFrame mocks base method. +func (m *MockReceiveStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "handleResetStreamFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// handleResetStreamFrame indicates an expected call of handleResetStreamFrame. +func (mr *MockReceiveStreamIMockRecorder) handleResetStreamFrame(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleResetStreamFrame), arg0) +} + +// handleStreamFrame mocks base method. +func (m *MockReceiveStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "handleStreamFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// handleStreamFrame indicates an expected call of handleStreamFrame. +func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleStreamFrame), arg0) +} diff --git a/internal/quic-go/mock_sealing_manager_test.go b/internal/quic-go/mock_sealing_manager_test.go new file mode 100644 index 00000000..a046c897 --- /dev/null +++ b/internal/quic-go/mock_sealing_manager_test.go @@ -0,0 +1,95 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: packet_packer.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + handshake "github.com/imroc/req/v3/internal/quic-go/handshake" +) + +// MockSealingManager is a mock of SealingManager interface. +type MockSealingManager struct { + ctrl *gomock.Controller + recorder *MockSealingManagerMockRecorder +} + +// MockSealingManagerMockRecorder is the mock recorder for MockSealingManager. +type MockSealingManagerMockRecorder struct { + mock *MockSealingManager +} + +// NewMockSealingManager creates a new mock instance. +func NewMockSealingManager(ctrl *gomock.Controller) *MockSealingManager { + mock := &MockSealingManager{ctrl: ctrl} + mock.recorder = &MockSealingManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSealingManager) EXPECT() *MockSealingManagerMockRecorder { + return m.recorder +} + +// Get0RTTSealer mocks base method. +func (m *MockSealingManager) Get0RTTSealer() (handshake.LongHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get0RTTSealer") + ret0, _ := ret[0].(handshake.LongHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get0RTTSealer indicates an expected call of Get0RTTSealer. +func (mr *MockSealingManagerMockRecorder) Get0RTTSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockSealingManager)(nil).Get0RTTSealer)) +} + +// Get1RTTSealer mocks base method. +func (m *MockSealingManager) Get1RTTSealer() (handshake.ShortHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get1RTTSealer") + ret0, _ := ret[0].(handshake.ShortHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get1RTTSealer indicates an expected call of Get1RTTSealer. +func (mr *MockSealingManagerMockRecorder) Get1RTTSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockSealingManager)(nil).Get1RTTSealer)) +} + +// GetHandshakeSealer mocks base method. +func (m *MockSealingManager) GetHandshakeSealer() (handshake.LongHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHandshakeSealer") + ret0, _ := ret[0].(handshake.LongHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetHandshakeSealer indicates an expected call of GetHandshakeSealer. +func (mr *MockSealingManagerMockRecorder) GetHandshakeSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockSealingManager)(nil).GetHandshakeSealer)) +} + +// GetInitialSealer mocks base method. +func (m *MockSealingManager) GetInitialSealer() (handshake.LongHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInitialSealer") + ret0, _ := ret[0].(handshake.LongHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInitialSealer indicates an expected call of GetInitialSealer. +func (mr *MockSealingManagerMockRecorder) GetInitialSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockSealingManager)(nil).GetInitialSealer)) +} diff --git a/internal/quic-go/mock_send_conn_test.go b/internal/quic-go/mock_send_conn_test.go new file mode 100644 index 00000000..d66fec5f --- /dev/null +++ b/internal/quic-go/mock_send_conn_test.go @@ -0,0 +1,91 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: send_conn.go + +// Package quic is a generated GoMock package. +package quic + +import ( + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockSendConn is a mock of SendConn interface. +type MockSendConn struct { + ctrl *gomock.Controller + recorder *MockSendConnMockRecorder +} + +// MockSendConnMockRecorder is the mock recorder for MockSendConn. +type MockSendConnMockRecorder struct { + mock *MockSendConn +} + +// NewMockSendConn creates a new mock instance. +func NewMockSendConn(ctrl *gomock.Controller) *MockSendConn { + mock := &MockSendConn{ctrl: ctrl} + mock.recorder = &MockSendConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSendConn) EXPECT() *MockSendConnMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockSendConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockSendConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendConn)(nil).Close)) +} + +// LocalAddr mocks base method. +func (m *MockSendConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockSendConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockSendConn)(nil).LocalAddr)) +} + +// RemoteAddr mocks base method. +func (m *MockSendConn) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr. +func (mr *MockSendConnMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockSendConn)(nil).RemoteAddr)) +} + +// Write mocks base method. +func (m *MockSendConn) Write(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Write indicates an expected call of Write. +func (mr *MockSendConnMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0) +} diff --git a/internal/quic-go/mock_send_stream_internal_test.go b/internal/quic-go/mock_send_stream_internal_test.go new file mode 100644 index 00000000..7ce194aa --- /dev/null +++ b/internal/quic-go/mock_send_stream_internal_test.go @@ -0,0 +1,187 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: send_stream.go + +// Package quic is a generated GoMock package. +package quic + +import ( + context "context" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockSendStreamI is a mock of SendStreamI interface. +type MockSendStreamI struct { + ctrl *gomock.Controller + recorder *MockSendStreamIMockRecorder +} + +// MockSendStreamIMockRecorder is the mock recorder for MockSendStreamI. +type MockSendStreamIMockRecorder struct { + mock *MockSendStreamI +} + +// NewMockSendStreamI creates a new mock instance. +func NewMockSendStreamI(ctrl *gomock.Controller) *MockSendStreamI { + mock := &MockSendStreamI{ctrl: ctrl} + mock.recorder = &MockSendStreamIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSendStreamI) EXPECT() *MockSendStreamIMockRecorder { + return m.recorder +} + +// CancelWrite mocks base method. +func (m *MockSendStreamI) CancelWrite(arg0 StreamErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelWrite", arg0) +} + +// CancelWrite indicates an expected call of CancelWrite. +func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0) +} + +// Close mocks base method. +func (m *MockSendStreamI) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockSendStreamIMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendStreamI)(nil).Close)) +} + +// Context mocks base method. +func (m *MockSendStreamI) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockSendStreamIMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSendStreamI)(nil).Context)) +} + +// SetWriteDeadline mocks base method. +func (m *MockSendStreamI) SetWriteDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline. +func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), t) +} + +// StreamID mocks base method. +func (m *MockSendStreamI) StreamID() StreamID { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StreamID") + ret0, _ := ret[0].(StreamID) + return ret0 +} + +// StreamID indicates an expected call of StreamID. +func (mr *MockSendStreamIMockRecorder) StreamID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockSendStreamI)(nil).StreamID)) +} + +// Write mocks base method. +func (m *MockSendStreamI) Write(p []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", p) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockSendStreamIMockRecorder) Write(p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), p) +} + +// closeForShutdown mocks base method. +func (m *MockSendStreamI) closeForShutdown(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "closeForShutdown", arg0) +} + +// closeForShutdown indicates an expected call of closeForShutdown. +func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0) +} + +// handleStopSendingFrame mocks base method. +func (m *MockSendStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "handleStopSendingFrame", arg0) +} + +// handleStopSendingFrame indicates an expected call of handleStopSendingFrame. +func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0) +} + +// hasData mocks base method. +func (m *MockSendStreamI) hasData() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "hasData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// hasData indicates an expected call of hasData. +func (mr *MockSendStreamIMockRecorder) hasData() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockSendStreamI)(nil).hasData)) +} + +// popStreamFrame mocks base method. +func (m *MockSendStreamI) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "popStreamFrame", maxBytes) + ret0, _ := ret[0].(*ackhandler.Frame) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// popStreamFrame indicates an expected call of popStreamFrame. +func (mr *MockSendStreamIMockRecorder) popStreamFrame(maxBytes interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), maxBytes) +} + +// updateSendWindow mocks base method. +func (m *MockSendStreamI) updateSendWindow(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "updateSendWindow", arg0) +} + +// updateSendWindow indicates an expected call of updateSendWindow. +func (mr *MockSendStreamIMockRecorder) updateSendWindow(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockSendStreamI)(nil).updateSendWindow), arg0) +} diff --git a/internal/quic-go/mock_sender_test.go b/internal/quic-go/mock_sender_test.go new file mode 100644 index 00000000..bad5f149 --- /dev/null +++ b/internal/quic-go/mock_sender_test.go @@ -0,0 +1,100 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: send_queue.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockSender is a mock of Sender interface. +type MockSender struct { + ctrl *gomock.Controller + recorder *MockSenderMockRecorder +} + +// MockSenderMockRecorder is the mock recorder for MockSender. +type MockSenderMockRecorder struct { + mock *MockSender +} + +// NewMockSender creates a new mock instance. +func NewMockSender(ctrl *gomock.Controller) *MockSender { + mock := &MockSender{ctrl: ctrl} + mock.recorder = &MockSenderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSender) EXPECT() *MockSenderMockRecorder { + return m.recorder +} + +// Available mocks base method. +func (m *MockSender) Available() <-chan struct{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Available") + ret0, _ := ret[0].(<-chan struct{}) + return ret0 +} + +// Available indicates an expected call of Available. +func (mr *MockSenderMockRecorder) Available() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Available", reflect.TypeOf((*MockSender)(nil).Available)) +} + +// Close mocks base method. +func (m *MockSender) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockSenderMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSender)(nil).Close)) +} + +// Run mocks base method. +func (m *MockSender) Run() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Run") + ret0, _ := ret[0].(error) + return ret0 +} + +// Run indicates an expected call of Run. +func (mr *MockSenderMockRecorder) Run() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSender)(nil).Run)) +} + +// Send mocks base method. +func (m *MockSender) Send(p *packetBuffer) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Send", p) +} + +// Send indicates an expected call of Send. +func (mr *MockSenderMockRecorder) Send(p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), p) +} + +// WouldBlock mocks base method. +func (m *MockSender) WouldBlock() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WouldBlock") + ret0, _ := ret[0].(bool) + return ret0 +} + +// WouldBlock indicates an expected call of WouldBlock. +func (mr *MockSenderMockRecorder) WouldBlock() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WouldBlock", reflect.TypeOf((*MockSender)(nil).WouldBlock)) +} diff --git a/internal/quic-go/mock_stream_getter_test.go b/internal/quic-go/mock_stream_getter_test.go new file mode 100644 index 00000000..71df5186 --- /dev/null +++ b/internal/quic-go/mock_stream_getter_test.go @@ -0,0 +1,65 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: connection.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockStreamGetter is a mock of StreamGetter interface. +type MockStreamGetter struct { + ctrl *gomock.Controller + recorder *MockStreamGetterMockRecorder +} + +// MockStreamGetterMockRecorder is the mock recorder for MockStreamGetter. +type MockStreamGetterMockRecorder struct { + mock *MockStreamGetter +} + +// NewMockStreamGetter creates a new mock instance. +func NewMockStreamGetter(ctrl *gomock.Controller) *MockStreamGetter { + mock := &MockStreamGetter{ctrl: ctrl} + mock.recorder = &MockStreamGetterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStreamGetter) EXPECT() *MockStreamGetterMockRecorder { + return m.recorder +} + +// GetOrOpenReceiveStream mocks base method. +func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0) + ret0, _ := ret[0].(receiveStreamI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream. +func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0) +} + +// GetOrOpenSendStream mocks base method. +func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0) + ret0, _ := ret[0].(sendStreamI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream. +func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0) +} diff --git a/internal/quic-go/mock_stream_internal_test.go b/internal/quic-go/mock_stream_internal_test.go new file mode 100644 index 00000000..381eb2ae --- /dev/null +++ b/internal/quic-go/mock_stream_internal_test.go @@ -0,0 +1,284 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: stream.go + +// Package quic is a generated GoMock package. +package quic + +import ( + context "context" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockStreamI is a mock of StreamI interface. +type MockStreamI struct { + ctrl *gomock.Controller + recorder *MockStreamIMockRecorder +} + +// MockStreamIMockRecorder is the mock recorder for MockStreamI. +type MockStreamIMockRecorder struct { + mock *MockStreamI +} + +// NewMockStreamI creates a new mock instance. +func NewMockStreamI(ctrl *gomock.Controller) *MockStreamI { + mock := &MockStreamI{ctrl: ctrl} + mock.recorder = &MockStreamIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStreamI) EXPECT() *MockStreamIMockRecorder { + return m.recorder +} + +// CancelRead mocks base method. +func (m *MockStreamI) CancelRead(arg0 StreamErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelRead", arg0) +} + +// CancelRead indicates an expected call of CancelRead. +func (mr *MockStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStreamI)(nil).CancelRead), arg0) +} + +// CancelWrite mocks base method. +func (m *MockStreamI) CancelWrite(arg0 StreamErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelWrite", arg0) +} + +// CancelWrite indicates an expected call of CancelWrite. +func (mr *MockStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStreamI)(nil).CancelWrite), arg0) +} + +// Close mocks base method. +func (m *MockStreamI) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockStreamIMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStreamI)(nil).Close)) +} + +// Context mocks base method. +func (m *MockStreamI) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockStreamIMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStreamI)(nil).Context)) +} + +// Read mocks base method. +func (m *MockStreamI) Read(p []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", p) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockStreamIMockRecorder) Read(p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), p) +} + +// SetDeadline mocks base method. +func (m *MockStreamI) SetDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline. +func (mr *MockStreamIMockRecorder) SetDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStreamI)(nil).SetDeadline), t) +} + +// SetReadDeadline mocks base method. +func (m *MockStreamI) SetReadDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockStreamIMockRecorder) SetReadDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStreamI)(nil).SetReadDeadline), t) +} + +// SetWriteDeadline mocks base method. +func (m *MockStreamI) SetWriteDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline. +func (mr *MockStreamIMockRecorder) SetWriteDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), t) +} + +// StreamID mocks base method. +func (m *MockStreamI) StreamID() StreamID { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StreamID") + ret0, _ := ret[0].(StreamID) + return ret0 +} + +// StreamID indicates an expected call of StreamID. +func (mr *MockStreamIMockRecorder) StreamID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStreamI)(nil).StreamID)) +} + +// Write mocks base method. +func (m *MockStreamI) Write(p []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", p) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockStreamIMockRecorder) Write(p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStreamI)(nil).Write), p) +} + +// closeForShutdown mocks base method. +func (m *MockStreamI) closeForShutdown(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "closeForShutdown", arg0) +} + +// closeForShutdown indicates an expected call of closeForShutdown. +func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0) +} + +// getWindowUpdate mocks base method. +func (m *MockStreamI) getWindowUpdate() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "getWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// getWindowUpdate indicates an expected call of getWindowUpdate. +func (mr *MockStreamIMockRecorder) getWindowUpdate() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).getWindowUpdate)) +} + +// handleResetStreamFrame mocks base method. +func (m *MockStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "handleResetStreamFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// handleResetStreamFrame indicates an expected call of handleResetStreamFrame. +func (mr *MockStreamIMockRecorder) handleResetStreamFrame(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleResetStreamFrame), arg0) +} + +// handleStopSendingFrame mocks base method. +func (m *MockStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "handleStopSendingFrame", arg0) +} + +// handleStopSendingFrame indicates an expected call of handleStopSendingFrame. +func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockStreamI)(nil).handleStopSendingFrame), arg0) +} + +// handleStreamFrame mocks base method. +func (m *MockStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "handleStreamFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// handleStreamFrame indicates an expected call of handleStreamFrame. +func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleStreamFrame), arg0) +} + +// hasData mocks base method. +func (m *MockStreamI) hasData() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "hasData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// hasData indicates an expected call of hasData. +func (mr *MockStreamIMockRecorder) hasData() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockStreamI)(nil).hasData)) +} + +// popStreamFrame mocks base method. +func (m *MockStreamI) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "popStreamFrame", maxBytes) + ret0, _ := ret[0].(*ackhandler.Frame) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// popStreamFrame indicates an expected call of popStreamFrame. +func (mr *MockStreamIMockRecorder) popStreamFrame(maxBytes interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamI)(nil).popStreamFrame), maxBytes) +} + +// updateSendWindow mocks base method. +func (m *MockStreamI) updateSendWindow(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "updateSendWindow", arg0) +} + +// updateSendWindow indicates an expected call of updateSendWindow. +func (mr *MockStreamIMockRecorder) updateSendWindow(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockStreamI)(nil).updateSendWindow), arg0) +} diff --git a/internal/quic-go/mock_stream_manager_test.go b/internal/quic-go/mock_stream_manager_test.go new file mode 100644 index 00000000..34d7b72c --- /dev/null +++ b/internal/quic-go/mock_stream_manager_test.go @@ -0,0 +1,231 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: connection.go + +// Package quic is a generated GoMock package. +package quic + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockStreamManager is a mock of StreamManager interface. +type MockStreamManager struct { + ctrl *gomock.Controller + recorder *MockStreamManagerMockRecorder +} + +// MockStreamManagerMockRecorder is the mock recorder for MockStreamManager. +type MockStreamManagerMockRecorder struct { + mock *MockStreamManager +} + +// NewMockStreamManager creates a new mock instance. +func NewMockStreamManager(ctrl *gomock.Controller) *MockStreamManager { + mock := &MockStreamManager{ctrl: ctrl} + mock.recorder = &MockStreamManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder { + return m.recorder +} + +// AcceptStream mocks base method. +func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptStream", arg0) + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptStream indicates an expected call of AcceptStream. +func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0) +} + +// AcceptUniStream mocks base method. +func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptUniStream", arg0) + ret0, _ := ret[0].(ReceiveStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptUniStream indicates an expected call of AcceptUniStream. +func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0) +} + +// CloseWithError mocks base method. +func (m *MockStreamManager) CloseWithError(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CloseWithError", arg0) +} + +// CloseWithError indicates an expected call of CloseWithError. +func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0) +} + +// DeleteStream mocks base method. +func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteStream", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteStream indicates an expected call of DeleteStream. +func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0) +} + +// GetOrOpenReceiveStream mocks base method. +func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0) + ret0, _ := ret[0].(receiveStreamI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream. +func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0) +} + +// GetOrOpenSendStream mocks base method. +func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0) + ret0, _ := ret[0].(sendStreamI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream. +func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0) +} + +// HandleMaxStreamsFrame mocks base method. +func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "HandleMaxStreamsFrame", arg0) +} + +// HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame. +func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0) +} + +// OpenStream mocks base method. +func (m *MockStreamManager) OpenStream() (Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenStream") + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStream indicates an expected call of OpenStream. +func (mr *MockStreamManagerMockRecorder) OpenStream() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamManager)(nil).OpenStream)) +} + +// OpenStreamSync mocks base method. +func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenStreamSync", arg0) + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStreamSync indicates an expected call of OpenStreamSync. +func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0) +} + +// OpenUniStream mocks base method. +func (m *MockStreamManager) OpenUniStream() (SendStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenUniStream") + ret0, _ := ret[0].(SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStream indicates an expected call of OpenUniStream. +func (mr *MockStreamManagerMockRecorder) OpenUniStream() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStream)) +} + +// OpenUniStreamSync mocks base method. +func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (SendStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) + ret0, _ := ret[0].(SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStreamSync indicates an expected call of OpenUniStreamSync. +func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0) +} + +// ResetFor0RTT mocks base method. +func (m *MockStreamManager) ResetFor0RTT() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ResetFor0RTT") +} + +// ResetFor0RTT indicates an expected call of ResetFor0RTT. +func (mr *MockStreamManagerMockRecorder) ResetFor0RTT() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetFor0RTT", reflect.TypeOf((*MockStreamManager)(nil).ResetFor0RTT)) +} + +// UpdateLimits mocks base method. +func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateLimits", arg0) +} + +// UpdateLimits indicates an expected call of UpdateLimits. +func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0) +} + +// UseResetMaps mocks base method. +func (m *MockStreamManager) UseResetMaps() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UseResetMaps") +} + +// UseResetMaps indicates an expected call of UseResetMaps. +func (mr *MockStreamManagerMockRecorder) UseResetMaps() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UseResetMaps", reflect.TypeOf((*MockStreamManager)(nil).UseResetMaps)) +} diff --git a/internal/quic-go/mock_stream_sender_test.go b/internal/quic-go/mock_stream_sender_test.go new file mode 100644 index 00000000..3cd97f48 --- /dev/null +++ b/internal/quic-go/mock_stream_sender_test.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: stream.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockStreamSender is a mock of StreamSender interface. +type MockStreamSender struct { + ctrl *gomock.Controller + recorder *MockStreamSenderMockRecorder +} + +// MockStreamSenderMockRecorder is the mock recorder for MockStreamSender. +type MockStreamSenderMockRecorder struct { + mock *MockStreamSender +} + +// NewMockStreamSender creates a new mock instance. +func NewMockStreamSender(ctrl *gomock.Controller) *MockStreamSender { + mock := &MockStreamSender{ctrl: ctrl} + mock.recorder = &MockStreamSenderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder { + return m.recorder +} + +// onHasStreamData mocks base method. +func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "onHasStreamData", arg0) +} + +// onHasStreamData indicates an expected call of onHasStreamData. +func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0) +} + +// onStreamCompleted mocks base method. +func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "onStreamCompleted", arg0) +} + +// onStreamCompleted indicates an expected call of onStreamCompleted. +func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0) +} + +// queueControlFrame mocks base method. +func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "queueControlFrame", arg0) +} + +// queueControlFrame indicates an expected call of queueControlFrame. +func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0) +} diff --git a/internal/quic-go/mock_token_store_test.go b/internal/quic-go/mock_token_store_test.go new file mode 100644 index 00000000..a0f02b41 --- /dev/null +++ b/internal/quic-go/mock_token_store_test.go @@ -0,0 +1,60 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go (interfaces: TokenStore) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockTokenStore is a mock of TokenStore interface. +type MockTokenStore struct { + ctrl *gomock.Controller + recorder *MockTokenStoreMockRecorder +} + +// MockTokenStoreMockRecorder is the mock recorder for MockTokenStore. +type MockTokenStoreMockRecorder struct { + mock *MockTokenStore +} + +// NewMockTokenStore creates a new mock instance. +func NewMockTokenStore(ctrl *gomock.Controller) *MockTokenStore { + mock := &MockTokenStore{ctrl: ctrl} + mock.recorder = &MockTokenStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTokenStore) EXPECT() *MockTokenStoreMockRecorder { + return m.recorder +} + +// Pop mocks base method. +func (m *MockTokenStore) Pop(arg0 string) *ClientToken { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Pop", arg0) + ret0, _ := ret[0].(*ClientToken) + return ret0 +} + +// Pop indicates an expected call of Pop. +func (mr *MockTokenStoreMockRecorder) Pop(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pop", reflect.TypeOf((*MockTokenStore)(nil).Pop), arg0) +} + +// Put mocks base method. +func (m *MockTokenStore) Put(arg0 string, arg1 *ClientToken) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Put", arg0, arg1) +} + +// Put indicates an expected call of Put. +func (mr *MockTokenStoreMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockTokenStore)(nil).Put), arg0, arg1) +} diff --git a/internal/quic-go/mock_unknown_packet_handler_test.go b/internal/quic-go/mock_unknown_packet_handler_test.go new file mode 100644 index 00000000..d82acf1a --- /dev/null +++ b/internal/quic-go/mock_unknown_packet_handler_test.go @@ -0,0 +1,58 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: server.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockUnknownPacketHandler is a mock of UnknownPacketHandler interface. +type MockUnknownPacketHandler struct { + ctrl *gomock.Controller + recorder *MockUnknownPacketHandlerMockRecorder +} + +// MockUnknownPacketHandlerMockRecorder is the mock recorder for MockUnknownPacketHandler. +type MockUnknownPacketHandlerMockRecorder struct { + mock *MockUnknownPacketHandler +} + +// NewMockUnknownPacketHandler creates a new mock instance. +func NewMockUnknownPacketHandler(ctrl *gomock.Controller) *MockUnknownPacketHandler { + mock := &MockUnknownPacketHandler{ctrl: ctrl} + mock.recorder = &MockUnknownPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorder { + return m.recorder +} + +// handlePacket mocks base method. +func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "handlePacket", arg0) +} + +// handlePacket indicates an expected call of handlePacket. +func (mr *MockUnknownPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0) +} + +// setCloseError mocks base method. +func (m *MockUnknownPacketHandler) setCloseError(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "setCloseError", arg0) +} + +// setCloseError indicates an expected call of setCloseError. +func (mr *MockUnknownPacketHandlerMockRecorder) setCloseError(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setCloseError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).setCloseError), arg0) +} diff --git a/internal/quic-go/mock_unpacker_test.go b/internal/quic-go/mock_unpacker_test.go new file mode 100644 index 00000000..0dca2d03 --- /dev/null +++ b/internal/quic-go/mock_unpacker_test.go @@ -0,0 +1,51 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: connection.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockUnpacker is a mock of Unpacker interface. +type MockUnpacker struct { + ctrl *gomock.Controller + recorder *MockUnpackerMockRecorder +} + +// MockUnpackerMockRecorder is the mock recorder for MockUnpacker. +type MockUnpackerMockRecorder struct { + mock *MockUnpacker +} + +// NewMockUnpacker creates a new mock instance. +func NewMockUnpacker(ctrl *gomock.Controller) *MockUnpacker { + mock := &MockUnpacker{ctrl: ctrl} + mock.recorder = &MockUnpackerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder { + return m.recorder +} + +// Unpack mocks base method. +func (m *MockUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Unpack", hdr, rcvTime, data) + ret0, _ := ret[0].(*unpackedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Unpack indicates an expected call of Unpack. +func (mr *MockUnpackerMockRecorder) Unpack(hdr, rcvTime, data interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), hdr, rcvTime, data) +} diff --git a/internal/quic-go/mockgen.go b/internal/quic-go/mockgen.go new file mode 100644 index 00000000..edecbbe3 --- /dev/null +++ b/internal/quic-go/mockgen.go @@ -0,0 +1,27 @@ +package quic + +//go:generate sh -c "./mockgen_private.sh quic mock_send_conn_test.go github.com/imroc/req/v3/internal/quic-go sendConn" +//go:generate sh -c "./mockgen_private.sh quic mock_sender_test.go github.com/imroc/req/v3/internal/quic-go sender" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/imroc/req/v3/internal/quic-go streamI" +//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/imroc/req/v3/internal/quic-go cryptoStream" +//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/imroc/req/v3/internal/quic-go receiveStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/imroc/req/v3/internal/quic-go sendStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/imroc/req/v3/internal/quic-go streamSender" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/imroc/req/v3/internal/quic-go streamGetter" +//go:generate sh -c "./mockgen_private.sh quic mock_crypto_data_handler_test.go github.com/imroc/req/v3/internal/quic-go cryptoDataHandler" +//go:generate sh -c "./mockgen_private.sh quic mock_frame_source_test.go github.com/imroc/req/v3/internal/quic-go frameSource" +//go:generate sh -c "./mockgen_private.sh quic mock_ack_frame_source_test.go github.com/imroc/req/v3/internal/quic-go ackFrameSource" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/imroc/req/v3/internal/quic-go streamManager" +//go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/imroc/req/v3/internal/quic-go sealingManager" +//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/imroc/req/v3/internal/quic-go unpacker" +//go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/imroc/req/v3/internal/quic-go packer" +//go:generate sh -c "./mockgen_private.sh quic mock_mtu_discoverer_test.go github.com/imroc/req/v3/internal/quic-go mtuDiscoverer" +//go:generate sh -c "./mockgen_private.sh quic mock_conn_runner_test.go github.com/imroc/req/v3/internal/quic-go connRunner" +//go:generate sh -c "./mockgen_private.sh quic mock_quic_conn_test.go github.com/imroc/req/v3/internal/quic-go quicConn" +//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/imroc/req/v3/internal/quic-go packetHandler" +//go:generate sh -c "./mockgen_private.sh quic mock_unknown_packet_handler_test.go github.com/imroc/req/v3/internal/quic-go unknownPacketHandler" +//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/imroc/req/v3/internal/quic-go packetHandlerManager" +//go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/imroc/req/v3/internal/quic-go multiplexer" +//go:generate sh -c "./mockgen_private.sh quic mock_batch_conn_test.go github.com/imroc/req/v3/internal/quic-go batchConn" +//go:generate sh -c "mockgen -package quic -self_package github.com/imroc/req/v3/internal/quic-go -destination mock_token_store_test.go github.com/imroc/req/v3/internal/quic-go TokenStore" +//go:generate sh -c "mockgen -package quic -self_package github.com/imroc/req/v3/internal/quic-go -destination mock_packetconn_test.go net PacketConn" diff --git a/internal/quic-go/mockgen_private.sh b/internal/quic-go/mockgen_private.sh new file mode 100755 index 00000000..92829d77 --- /dev/null +++ b/internal/quic-go/mockgen_private.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +DEST=$2 +PACKAGE=$3 +TMPFILE="mockgen_tmp.go" +# uppercase the name of the interface +ORIG_INTERFACE_NAME=$4 +INTERFACE_NAME="$(tr '[:lower:]' '[:upper:]' <<< ${ORIG_INTERFACE_NAME:0:1})${ORIG_INTERFACE_NAME:1}" + +# Gather all files that contain interface definitions. +# These interfaces might be used as embedded interfaces, +# so we need to pass them to mockgen as aux_files. +AUX=() +for f in *.go; do + if [[ -z ${f##*_test.go} ]]; then + # skip test files + continue; + fi + if $(egrep -qe "type (.*) interface" $f); then + AUX+=("github.com/lucas-clemente/quic-go=$f") + fi +done + +# Find the file that defines the interface we're mocking. +for f in *.go; do + if [[ -z ${f##*_test.go} ]]; then + # skip test files + continue; + fi + INTERFACE=$(sed -n "/^type $ORIG_INTERFACE_NAME interface/,/^}/p" $f) + if [[ -n "$INTERFACE" ]]; then + SRC=$f + break + fi +done + +if [[ -z "$INTERFACE" ]]; then + echo "Interface $ORIG_INTERFACE_NAME not found." + exit 1 +fi + +AUX_FILES=$(IFS=, ; echo "${AUX[*]}") + +## create a public alias for the interface, so that mockgen can process it +echo -e "package $1\n" > $TMPFILE +echo "$INTERFACE" | sed "s/$ORIG_INTERFACE_NAME/$INTERFACE_NAME/" >> $TMPFILE +mockgen -package $1 -self_package $3 -destination $DEST -source=$TMPFILE -aux_files $AUX_FILES +sed "s/$TMPFILE/$SRC/" "$DEST" > "$DEST.new" && mv "$DEST.new" "$DEST" +rm "$TMPFILE" diff --git a/internal/quic-go/mocks/ackhandler/received_packet_handler.go b/internal/quic-go/mocks/ackhandler/received_packet_handler.go new file mode 100644 index 00000000..004fed2e --- /dev/null +++ b/internal/quic-go/mocks/ackhandler/received_packet_handler.go @@ -0,0 +1,105 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/ackhandler (interfaces: ReceivedPacketHandler) + +// Package mockackhandler is a generated GoMock package. +package mockackhandler + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockReceivedPacketHandler is a mock of ReceivedPacketHandler interface. +type MockReceivedPacketHandler struct { + ctrl *gomock.Controller + recorder *MockReceivedPacketHandlerMockRecorder +} + +// MockReceivedPacketHandlerMockRecorder is the mock recorder for MockReceivedPacketHandler. +type MockReceivedPacketHandlerMockRecorder struct { + mock *MockReceivedPacketHandler +} + +// NewMockReceivedPacketHandler creates a new mock instance. +func NewMockReceivedPacketHandler(ctrl *gomock.Controller) *MockReceivedPacketHandler { + mock := &MockReceivedPacketHandler{ctrl: ctrl} + mock.recorder = &MockReceivedPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecorder { + return m.recorder +} + +// DropPackets mocks base method. +func (m *MockReceivedPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropPackets", arg0) +} + +// DropPackets indicates an expected call of DropPackets. +func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0) +} + +// GetAckFrame mocks base method. +func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel, arg1 bool) *wire.AckFrame { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAckFrame", arg0, arg1) + ret0, _ := ret[0].(*wire.AckFrame) + return ret0 +} + +// GetAckFrame indicates an expected call of GetAckFrame. +func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0, arg1) +} + +// GetAlarmTimeout mocks base method. +func (m *MockReceivedPacketHandler) GetAlarmTimeout() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAlarmTimeout") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// GetAlarmTimeout indicates an expected call of GetAlarmTimeout. +func (mr *MockReceivedPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAlarmTimeout)) +} + +// IsPotentiallyDuplicate mocks base method. +func (m *MockReceivedPacketHandler) IsPotentiallyDuplicate(arg0 protocol.PacketNumber, arg1 protocol.EncryptionLevel) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsPotentiallyDuplicate", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsPotentiallyDuplicate indicates an expected call of IsPotentiallyDuplicate. +func (mr *MockReceivedPacketHandlerMockRecorder) IsPotentiallyDuplicate(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPotentiallyDuplicate", reflect.TypeOf((*MockReceivedPacketHandler)(nil).IsPotentiallyDuplicate), arg0, arg1) +} + +// ReceivedPacket mocks base method. +func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.ECN, arg2 protocol.EncryptionLevel, arg3 time.Time, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReceivedPacket indicates an expected call of ReceivedPacket. +func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3, arg4) +} diff --git a/internal/quic-go/mocks/ackhandler/sent_packet_handler.go b/internal/quic-go/mocks/ackhandler/sent_packet_handler.go new file mode 100644 index 00000000..ed16a6ed --- /dev/null +++ b/internal/quic-go/mocks/ackhandler/sent_packet_handler.go @@ -0,0 +1,240 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/ackhandler (interfaces: SentPacketHandler) + +// Package mockackhandler is a generated GoMock package. +package mockackhandler + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// MockSentPacketHandler is a mock of SentPacketHandler interface. +type MockSentPacketHandler struct { + ctrl *gomock.Controller + recorder *MockSentPacketHandlerMockRecorder +} + +// MockSentPacketHandlerMockRecorder is the mock recorder for MockSentPacketHandler. +type MockSentPacketHandlerMockRecorder struct { + mock *MockSentPacketHandler +} + +// NewMockSentPacketHandler creates a new mock instance. +func NewMockSentPacketHandler(ctrl *gomock.Controller) *MockSentPacketHandler { + mock := &MockSentPacketHandler{ctrl: ctrl} + mock.recorder = &MockSentPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder { + return m.recorder +} + +// DropPackets mocks base method. +func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropPackets", arg0) +} + +// DropPackets indicates an expected call of DropPackets. +func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) +} + +// GetLossDetectionTimeout mocks base method. +func (m *MockSentPacketHandler) GetLossDetectionTimeout() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLossDetectionTimeout") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// GetLossDetectionTimeout indicates an expected call of GetLossDetectionTimeout. +func (mr *MockSentPacketHandlerMockRecorder) GetLossDetectionTimeout() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLossDetectionTimeout)) +} + +// HasPacingBudget mocks base method. +func (m *MockSentPacketHandler) HasPacingBudget() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasPacingBudget") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasPacingBudget indicates an expected call of HasPacingBudget. +func (mr *MockSentPacketHandlerMockRecorder) HasPacingBudget() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSentPacketHandler)(nil).HasPacingBudget)) +} + +// OnLossDetectionTimeout mocks base method. +func (m *MockSentPacketHandler) OnLossDetectionTimeout() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnLossDetectionTimeout") + ret0, _ := ret[0].(error) + return ret0 +} + +// OnLossDetectionTimeout indicates an expected call of OnLossDetectionTimeout. +func (mr *MockSentPacketHandlerMockRecorder) OnLossDetectionTimeout() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).OnLossDetectionTimeout)) +} + +// PeekPacketNumber mocks base method. +func (m *MockSentPacketHandler) PeekPacketNumber(arg0 protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PeekPacketNumber", arg0) + ret0, _ := ret[0].(protocol.PacketNumber) + ret1, _ := ret[1].(protocol.PacketNumberLen) + return ret0, ret1 +} + +// PeekPacketNumber indicates an expected call of PeekPacketNumber. +func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber), arg0) +} + +// PopPacketNumber mocks base method. +func (m *MockSentPacketHandler) PopPacketNumber(arg0 protocol.EncryptionLevel) protocol.PacketNumber { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PopPacketNumber", arg0) + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// PopPacketNumber indicates an expected call of PopPacketNumber. +func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber), arg0) +} + +// QueueProbePacket mocks base method. +func (m *MockSentPacketHandler) QueueProbePacket(arg0 protocol.EncryptionLevel) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueueProbePacket", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// QueueProbePacket indicates an expected call of QueueProbePacket. +func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket), arg0) +} + +// ReceivedAck mocks base method. +func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.EncryptionLevel, arg2 time.Time) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceivedAck", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReceivedAck indicates an expected call of ReceivedAck. +func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2) +} + +// ReceivedBytes mocks base method. +func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedBytes", arg0) +} + +// ReceivedBytes indicates an expected call of ReceivedBytes. +func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0) +} + +// ResetForRetry mocks base method. +func (m *MockSentPacketHandler) ResetForRetry() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResetForRetry") + ret0, _ := ret[0].(error) + return ret0 +} + +// ResetForRetry indicates an expected call of ResetForRetry. +func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry)) +} + +// SendMode mocks base method. +func (m *MockSentPacketHandler) SendMode() ackhandler.SendMode { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMode") + ret0, _ := ret[0].(ackhandler.SendMode) + return ret0 +} + +// SendMode indicates an expected call of SendMode. +func (mr *MockSentPacketHandlerMockRecorder) SendMode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMode", reflect.TypeOf((*MockSentPacketHandler)(nil).SendMode)) +} + +// SentPacket mocks base method. +func (m *MockSentPacketHandler) SentPacket(arg0 *ackhandler.Packet) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0) +} + +// SetHandshakeConfirmed mocks base method. +func (m *MockSentPacketHandler) SetHandshakeConfirmed() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetHandshakeConfirmed") +} + +// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. +func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeConfirmed() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeConfirmed)) +} + +// SetMaxDatagramSize mocks base method. +func (m *MockSentPacketHandler) SetMaxDatagramSize(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetMaxDatagramSize", arg0) +} + +// SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. +func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSentPacketHandler)(nil).SetMaxDatagramSize), arg0) +} + +// TimeUntilSend mocks base method. +func (m *MockSentPacketHandler) TimeUntilSend() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TimeUntilSend") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// TimeUntilSend indicates an expected call of TimeUntilSend. +func (mr *MockSentPacketHandlerMockRecorder) TimeUntilSend() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSentPacketHandler)(nil).TimeUntilSend)) +} diff --git a/internal/quic-go/mocks/congestion.go b/internal/quic-go/mocks/congestion.go new file mode 100644 index 00000000..6c92f86a --- /dev/null +++ b/internal/quic-go/mocks/congestion.go @@ -0,0 +1,192 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/congestion (interfaces: SendAlgorithmWithDebugInfos) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockSendAlgorithmWithDebugInfos is a mock of SendAlgorithmWithDebugInfos interface. +type MockSendAlgorithmWithDebugInfos struct { + ctrl *gomock.Controller + recorder *MockSendAlgorithmWithDebugInfosMockRecorder +} + +// MockSendAlgorithmWithDebugInfosMockRecorder is the mock recorder for MockSendAlgorithmWithDebugInfos. +type MockSendAlgorithmWithDebugInfosMockRecorder struct { + mock *MockSendAlgorithmWithDebugInfos +} + +// NewMockSendAlgorithmWithDebugInfos creates a new mock instance. +func NewMockSendAlgorithmWithDebugInfos(ctrl *gomock.Controller) *MockSendAlgorithmWithDebugInfos { + mock := &MockSendAlgorithmWithDebugInfos{ctrl: ctrl} + mock.recorder = &MockSendAlgorithmWithDebugInfosMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSendAlgorithmWithDebugInfos) EXPECT() *MockSendAlgorithmWithDebugInfosMockRecorder { + return m.recorder +} + +// CanSend mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) CanSend(arg0 protocol.ByteCount) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CanSend", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// CanSend indicates an expected call of CanSend. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) CanSend(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).CanSend), arg0) +} + +// GetCongestionWindow mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) GetCongestionWindow() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCongestionWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetCongestionWindow indicates an expected call of GetCongestionWindow. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) GetCongestionWindow() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCongestionWindow", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).GetCongestionWindow)) +} + +// HasPacingBudget mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) HasPacingBudget() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasPacingBudget") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasPacingBudget indicates an expected call of HasPacingBudget. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) HasPacingBudget() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).HasPacingBudget)) +} + +// InRecovery mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) InRecovery() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InRecovery") + ret0, _ := ret[0].(bool) + return ret0 +} + +// InRecovery indicates an expected call of InRecovery. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InRecovery() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InRecovery", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InRecovery)) +} + +// InSlowStart mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) InSlowStart() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InSlowStart") + ret0, _ := ret[0].(bool) + return ret0 +} + +// InSlowStart indicates an expected call of InSlowStart. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InSlowStart() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InSlowStart)) +} + +// MaybeExitSlowStart mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) MaybeExitSlowStart() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "MaybeExitSlowStart") +} + +// MaybeExitSlowStart indicates an expected call of MaybeExitSlowStart. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) MaybeExitSlowStart() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeExitSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).MaybeExitSlowStart)) +} + +// OnPacketAcked mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount, arg3 time.Time) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2, arg3) +} + +// OnPacketAcked indicates an expected call of OnPacketAcked. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), arg0, arg1, arg2, arg3) +} + +// OnPacketLost mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) OnPacketLost(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnPacketLost", arg0, arg1, arg2) +} + +// OnPacketLost indicates an expected call of OnPacketLost. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketLost(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketLost", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketLost), arg0, arg1, arg2) +} + +// OnPacketSent mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnPacketSent", arg0, arg1, arg2, arg3, arg4) +} + +// OnPacketSent indicates an expected call of OnPacketSent. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketSent(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketSent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketSent), arg0, arg1, arg2, arg3, arg4) +} + +// OnRetransmissionTimeout mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) OnRetransmissionTimeout(arg0 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnRetransmissionTimeout", arg0) +} + +// OnRetransmissionTimeout indicates an expected call of OnRetransmissionTimeout. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnRetransmissionTimeout), arg0) +} + +// SetMaxDatagramSize mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) SetMaxDatagramSize(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetMaxDatagramSize", arg0) +} + +// SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).SetMaxDatagramSize), arg0) +} + +// TimeUntilSend mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) TimeUntilSend(arg0 protocol.ByteCount) time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TimeUntilSend", arg0) + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// TimeUntilSend indicates an expected call of TimeUntilSend. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) TimeUntilSend(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).TimeUntilSend), arg0) +} diff --git a/internal/quic-go/mocks/connection_flow_controller.go b/internal/quic-go/mocks/connection_flow_controller.go new file mode 100644 index 00000000..ee8a14ea --- /dev/null +++ b/internal/quic-go/mocks/connection_flow_controller.go @@ -0,0 +1,128 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/flowcontrol (interfaces: ConnectionFlowController) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockConnectionFlowController is a mock of ConnectionFlowController interface. +type MockConnectionFlowController struct { + ctrl *gomock.Controller + recorder *MockConnectionFlowControllerMockRecorder +} + +// MockConnectionFlowControllerMockRecorder is the mock recorder for MockConnectionFlowController. +type MockConnectionFlowControllerMockRecorder struct { + mock *MockConnectionFlowController +} + +// NewMockConnectionFlowController creates a new mock instance. +func NewMockConnectionFlowController(ctrl *gomock.Controller) *MockConnectionFlowController { + mock := &MockConnectionFlowController{ctrl: ctrl} + mock.recorder = &MockConnectionFlowControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionFlowController) EXPECT() *MockConnectionFlowControllerMockRecorder { + return m.recorder +} + +// AddBytesRead mocks base method. +func (m *MockConnectionFlowController) AddBytesRead(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddBytesRead", arg0) +} + +// AddBytesRead indicates an expected call of AddBytesRead. +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesRead), arg0) +} + +// AddBytesSent mocks base method. +func (m *MockConnectionFlowController) AddBytesSent(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddBytesSent", arg0) +} + +// AddBytesSent indicates an expected call of AddBytesSent. +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesSent), arg0) +} + +// GetWindowUpdate mocks base method. +func (m *MockConnectionFlowController) GetWindowUpdate() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetWindowUpdate indicates an expected call of GetWindowUpdate. +func (mr *MockConnectionFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).GetWindowUpdate)) +} + +// IsNewlyBlocked mocks base method. +func (m *MockConnectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsNewlyBlocked") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 +} + +// IsNewlyBlocked indicates an expected call of IsNewlyBlocked. +func (mr *MockConnectionFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsNewlyBlocked)) +} + +// Reset mocks base method. +func (m *MockConnectionFlowController) Reset() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Reset") + ret0, _ := ret[0].(error) + return ret0 +} + +// Reset indicates an expected call of Reset. +func (mr *MockConnectionFlowControllerMockRecorder) Reset() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reset", reflect.TypeOf((*MockConnectionFlowController)(nil).Reset)) +} + +// SendWindowSize mocks base method. +func (m *MockConnectionFlowController) SendWindowSize() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendWindowSize") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// SendWindowSize indicates an expected call of SendWindowSize. +func (mr *MockConnectionFlowControllerMockRecorder) SendWindowSize() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockConnectionFlowController)(nil).SendWindowSize)) +} + +// UpdateSendWindow mocks base method. +func (m *MockConnectionFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateSendWindow", arg0) +} + +// UpdateSendWindow indicates an expected call of UpdateSendWindow. +func (mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockConnectionFlowController)(nil).UpdateSendWindow), arg0) +} diff --git a/internal/quic-go/mocks/crypto_setup.go b/internal/quic-go/mocks/crypto_setup.go new file mode 100644 index 00000000..86e21aa8 --- /dev/null +++ b/internal/quic-go/mocks/crypto_setup.go @@ -0,0 +1,264 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/handshake (interfaces: CryptoSetup) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + handshake "github.com/imroc/req/v3/internal/quic-go/handshake" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + qtls "github.com/imroc/req/v3/internal/quic-go/qtls" +) + +// MockCryptoSetup is a mock of CryptoSetup interface. +type MockCryptoSetup struct { + ctrl *gomock.Controller + recorder *MockCryptoSetupMockRecorder +} + +// MockCryptoSetupMockRecorder is the mock recorder for MockCryptoSetup. +type MockCryptoSetupMockRecorder struct { + mock *MockCryptoSetup +} + +// NewMockCryptoSetup creates a new mock instance. +func NewMockCryptoSetup(ctrl *gomock.Controller) *MockCryptoSetup { + mock := &MockCryptoSetup{ctrl: ctrl} + mock.recorder = &MockCryptoSetupMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder { + return m.recorder +} + +// ChangeConnectionID mocks base method. +func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ChangeConnectionID", arg0) +} + +// ChangeConnectionID indicates an expected call of ChangeConnectionID. +func (mr *MockCryptoSetupMockRecorder) ChangeConnectionID(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeConnectionID", reflect.TypeOf((*MockCryptoSetup)(nil).ChangeConnectionID), arg0) +} + +// Close mocks base method. +func (m *MockCryptoSetup) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close)) +} + +// ConnectionState mocks base method. +func (m *MockCryptoSetup) ConnectionState() qtls.ConnectionState { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConnectionState") + ret0, _ := ret[0].(qtls.ConnectionState) + return ret0 +} + +// ConnectionState indicates an expected call of ConnectionState. +func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) +} + +// Get0RTTOpener mocks base method. +func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get0RTTOpener") + ret0, _ := ret[0].(handshake.LongHeaderOpener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get0RTTOpener indicates an expected call of Get0RTTOpener. +func (mr *MockCryptoSetupMockRecorder) Get0RTTOpener() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTOpener)) +} + +// Get0RTTSealer mocks base method. +func (m *MockCryptoSetup) Get0RTTSealer() (handshake.LongHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get0RTTSealer") + ret0, _ := ret[0].(handshake.LongHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get0RTTSealer indicates an expected call of Get0RTTSealer. +func (mr *MockCryptoSetupMockRecorder) Get0RTTSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTSealer)) +} + +// Get1RTTOpener mocks base method. +func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get1RTTOpener") + ret0, _ := ret[0].(handshake.ShortHeaderOpener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get1RTTOpener indicates an expected call of Get1RTTOpener. +func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTOpener)) +} + +// Get1RTTSealer mocks base method. +func (m *MockCryptoSetup) Get1RTTSealer() (handshake.ShortHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get1RTTSealer") + ret0, _ := ret[0].(handshake.ShortHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get1RTTSealer indicates an expected call of Get1RTTSealer. +func (mr *MockCryptoSetupMockRecorder) Get1RTTSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTSealer)) +} + +// GetHandshakeOpener mocks base method. +func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.LongHeaderOpener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHandshakeOpener") + ret0, _ := ret[0].(handshake.LongHeaderOpener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetHandshakeOpener indicates an expected call of GetHandshakeOpener. +func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeOpener)) +} + +// GetHandshakeSealer mocks base method. +func (m *MockCryptoSetup) GetHandshakeSealer() (handshake.LongHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHandshakeSealer") + ret0, _ := ret[0].(handshake.LongHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetHandshakeSealer indicates an expected call of GetHandshakeSealer. +func (mr *MockCryptoSetupMockRecorder) GetHandshakeSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeSealer)) +} + +// GetInitialOpener mocks base method. +func (m *MockCryptoSetup) GetInitialOpener() (handshake.LongHeaderOpener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInitialOpener") + ret0, _ := ret[0].(handshake.LongHeaderOpener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInitialOpener indicates an expected call of GetInitialOpener. +func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialOpener)) +} + +// GetInitialSealer mocks base method. +func (m *MockCryptoSetup) GetInitialSealer() (handshake.LongHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInitialSealer") + ret0, _ := ret[0].(handshake.LongHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInitialSealer indicates an expected call of GetInitialSealer. +func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer)) +} + +// GetSessionTicket mocks base method. +func (m *MockCryptoSetup) GetSessionTicket() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSessionTicket") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSessionTicket indicates an expected call of GetSessionTicket. +func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket)) +} + +// HandleMessage mocks base method. +func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// HandleMessage indicates an expected call of HandleMessage. +func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) +} + +// RunHandshake mocks base method. +func (m *MockCryptoSetup) RunHandshake() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RunHandshake") +} + +// RunHandshake indicates an expected call of RunHandshake. +func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake)) +} + +// SetHandshakeConfirmed mocks base method. +func (m *MockCryptoSetup) SetHandshakeConfirmed() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetHandshakeConfirmed") +} + +// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. +func (mr *MockCryptoSetupMockRecorder) SetHandshakeConfirmed() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockCryptoSetup)(nil).SetHandshakeConfirmed)) +} + +// SetLargest1RTTAcked mocks base method. +func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetLargest1RTTAcked", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetLargest1RTTAcked indicates an expected call of SetLargest1RTTAcked. +func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0) +} diff --git a/internal/quic-go/mocks/logging/connection_tracer.go b/internal/quic-go/mocks/logging/connection_tracer.go new file mode 100644 index 00000000..9fd58412 --- /dev/null +++ b/internal/quic-go/mocks/logging/connection_tracer.go @@ -0,0 +1,352 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/logging (interfaces: ConnectionTracer) + +// Package mocklogging is a generated GoMock package. +package mocklogging + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + utils "github.com/imroc/req/v3/internal/quic-go/utils" + wire "github.com/imroc/req/v3/internal/quic-go/wire" + logging "github.com/imroc/req/v3/internal/quic-go/logging" +) + +// MockConnectionTracer is a mock of ConnectionTracer interface. +type MockConnectionTracer struct { + ctrl *gomock.Controller + recorder *MockConnectionTracerMockRecorder +} + +// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. +type MockConnectionTracerMockRecorder struct { + mock *MockConnectionTracer +} + +// NewMockConnectionTracer creates a new mock instance. +func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { + mock := &MockConnectionTracer{ctrl: ctrl} + mock.recorder = &MockConnectionTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { + return m.recorder +} + +// AcknowledgedPacket mocks base method. +func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) +} + +// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. +func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) +} + +// BufferedPacket mocks base method. +func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "BufferedPacket", arg0) +} + +// BufferedPacket indicates an expected call of BufferedPacket. +func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0) +} + +// Close mocks base method. +func (m *MockConnectionTracer) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) +} + +// ClosedConnection mocks base method. +func (m *MockConnectionTracer) ClosedConnection(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClosedConnection", arg0) +} + +// ClosedConnection indicates an expected call of ClosedConnection. +func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) +} + +// Debug mocks base method. +func (m *MockConnectionTracer) Debug(arg0, arg1 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0, arg1) +} + +// Debug indicates an expected call of Debug. +func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) +} + +// DroppedEncryptionLevel mocks base method. +func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) +} + +// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. +func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) +} + +// DroppedKey mocks base method. +func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedKey", arg0) +} + +// DroppedKey indicates an expected call of DroppedKey. +func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) +} + +// DroppedPacket mocks base method. +func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount, arg2 logging.PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) +} + +// LossTimerCanceled mocks base method. +func (m *MockConnectionTracer) LossTimerCanceled() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerCanceled") +} + +// LossTimerCanceled indicates an expected call of LossTimerCanceled. +func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) +} + +// LossTimerExpired mocks base method. +func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) +} + +// LossTimerExpired indicates an expected call of LossTimerExpired. +func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) +} + +// LostPacket mocks base method. +func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 logging.PacketLossReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) +} + +// LostPacket indicates an expected call of LostPacket. +func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) +} + +// NegotiatedVersion mocks base method. +func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) +} + +// NegotiatedVersion indicates an expected call of NegotiatedVersion. +func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) +} + +// ReceivedPacket mocks base method. +func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2) +} + +// ReceivedPacket indicates an expected call of ReceivedPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedPacket), arg0, arg1, arg2) +} + +// ReceivedRetry mocks base method. +func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedRetry", arg0) +} + +// ReceivedRetry indicates an expected call of ReceivedRetry. +func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) +} + +// ReceivedTransportParameters mocks base method. +func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedTransportParameters", arg0) +} + +// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. +func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) +} + +// ReceivedVersionNegotiationPacket mocks base method. +func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header, arg1 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1) +} + +// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1) +} + +// RestoredTransportParameters mocks base method. +func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RestoredTransportParameters", arg0) +} + +// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. +func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) +} + +// SentPacket mocks base method. +func (m *MockConnectionTracer) SentPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockConnectionTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) +} + +// SentTransportParameters mocks base method. +func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentTransportParameters", arg0) +} + +// SentTransportParameters indicates an expected call of SentTransportParameters. +func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) +} + +// SetLossTimer mocks base method. +func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) +} + +// SetLossTimer indicates an expected call of SetLossTimer. +func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) +} + +// StartedConnection mocks base method. +func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) +} + +// StartedConnection indicates an expected call of StartedConnection. +func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) +} + +// UpdatedCongestionState mocks base method. +func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionState) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedCongestionState", arg0) +} + +// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. +func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) +} + +// UpdatedKey mocks base method. +func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKey", arg0, arg1) +} + +// UpdatedKey indicates an expected call of UpdatedKey. +func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) +} + +// UpdatedKeyFromTLS mocks base method. +func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) +} + +// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. +func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) +} + +// UpdatedMetrics mocks base method. +func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) +} + +// UpdatedMetrics indicates an expected call of UpdatedMetrics. +func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) +} + +// UpdatedPTOCount mocks base method. +func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedPTOCount", arg0) +} + +// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. +func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) +} diff --git a/internal/quic-go/mocks/logging/tracer.go b/internal/quic-go/mocks/logging/tracer.go new file mode 100644 index 00000000..b0b9700f --- /dev/null +++ b/internal/quic-go/mocks/logging/tracer.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/logging (interfaces: Tracer) + +// Package mocklogging is a generated GoMock package. +package mocklogging + +import ( + context "context" + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + wire "github.com/imroc/req/v3/internal/quic-go/wire" + logging "github.com/imroc/req/v3/internal/quic-go/logging" +) + +// MockTracer is a mock of Tracer interface. +type MockTracer struct { + ctrl *gomock.Controller + recorder *MockTracerMockRecorder +} + +// MockTracerMockRecorder is the mock recorder for MockTracer. +type MockTracerMockRecorder struct { + mock *MockTracer +} + +// NewMockTracer creates a new mock instance. +func NewMockTracer(ctrl *gomock.Controller) *MockTracer { + mock := &MockTracer{ctrl: ctrl} + mock.recorder = &MockTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTracer) EXPECT() *MockTracerMockRecorder { + return m.recorder +} + +// DroppedPacket mocks base method. +func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) +} + +// SentPacket mocks base method. +func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) +} + +// TracerForConnection mocks base method. +func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) + ret0, _ := ret[0].(logging.ConnectionTracer) + return ret0 +} + +// TracerForConnection indicates an expected call of TracerForConnection. +func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) +} diff --git a/internal/quic-go/mocks/long_header_opener.go b/internal/quic-go/mocks/long_header_opener.go new file mode 100644 index 00000000..158941ed --- /dev/null +++ b/internal/quic-go/mocks/long_header_opener.go @@ -0,0 +1,76 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/handshake (interfaces: LongHeaderOpener) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockLongHeaderOpener is a mock of LongHeaderOpener interface. +type MockLongHeaderOpener struct { + ctrl *gomock.Controller + recorder *MockLongHeaderOpenerMockRecorder +} + +// MockLongHeaderOpenerMockRecorder is the mock recorder for MockLongHeaderOpener. +type MockLongHeaderOpenerMockRecorder struct { + mock *MockLongHeaderOpener +} + +// NewMockLongHeaderOpener creates a new mock instance. +func NewMockLongHeaderOpener(ctrl *gomock.Controller) *MockLongHeaderOpener { + mock := &MockLongHeaderOpener{ctrl: ctrl} + mock.recorder = &MockLongHeaderOpenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLongHeaderOpener) EXPECT() *MockLongHeaderOpenerMockRecorder { + return m.recorder +} + +// DecodePacketNumber mocks base method. +func (m *MockLongHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1) + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// DecodePacketNumber indicates an expected call of DecodePacketNumber. +func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) +} + +// DecryptHeader mocks base method. +func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) +} + +// DecryptHeader indicates an expected call of DecryptHeader. +func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) +} + +// Open mocks base method. +func (m *MockLongHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Open indicates an expected call of Open. +func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3) +} diff --git a/internal/quic-go/mocks/mockgen.go b/internal/quic-go/mocks/mockgen.go new file mode 100644 index 00000000..7d470aa2 --- /dev/null +++ b/internal/quic-go/mocks/mockgen.go @@ -0,0 +1,20 @@ +package mocks + +//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/imroc/req/v3/internal/quic-go Stream" +//go:generate sh -c "mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/imroc/req/v3/internal/quic-go EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && goimports -w quic/early_conn.go" +//go:generate sh -c "mockgen -package mockquic -destination quic/early_listener.go github.com/imroc/req/v3/internal/quic-go EarlyListener" +//go:generate sh -c "mockgen -package mocklogging -destination logging/tracer.go github.com/imroc/req/v3/internal/quic-go/logging Tracer" +//go:generate sh -c "mockgen -package mocklogging -destination logging/connection_tracer.go github.com/imroc/req/v3/internal/quic-go/logging ConnectionTracer" +//go:generate sh -c "mockgen -package mocks -destination short_header_sealer.go github.com/imroc/req/v3/internal/quic-go/handshake ShortHeaderSealer" +//go:generate sh -c "mockgen -package mocks -destination short_header_opener.go github.com/imroc/req/v3/internal/quic-go/handshake ShortHeaderOpener" +//go:generate sh -c "mockgen -package mocks -destination long_header_opener.go github.com/imroc/req/v3/internal/quic-go/handshake LongHeaderOpener" +//go:generate sh -c "mockgen -package mocks -destination crypto_setup_tmp.go github.com/imroc/req/v3/internal/quic-go/handshake CryptoSetup && sed -E 's~github.com/marten-seemann/qtls[[:alnum:]_-]*~github.com/imroc/req/v3/internal/quic-go/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && goimports -w crypto_setup.go" +//go:generate sh -c "mockgen -package mocks -destination stream_flow_controller.go github.com/imroc/req/v3/internal/quic-go/flowcontrol StreamFlowController" +//go:generate sh -c "mockgen -package mocks -destination congestion.go github.com/imroc/req/v3/internal/quic-go/congestion SendAlgorithmWithDebugInfos" +//go:generate sh -c "mockgen -package mocks -destination connection_flow_controller.go github.com/imroc/req/v3/internal/quic-go/flowcontrol ConnectionFlowController" +//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/imroc/req/v3/internal/quic-go/ackhandler SentPacketHandler" +//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/imroc/req/v3/internal/quic-go/ackhandler ReceivedPacketHandler" + +// The following command produces a warning message on OSX, however, it still generates the correct mock file. +// See https://github.com/golang/mock/issues/339 for details. +//go:generate sh -c "mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" diff --git a/internal/mocks/quic/early_conn.go b/internal/quic-go/mocks/quic/early_conn.go similarity index 97% rename from internal/mocks/quic/early_conn.go rename to internal/quic-go/mocks/quic/early_conn.go index 9eecbc65..7bb07774 100644 --- a/internal/mocks/quic/early_conn.go +++ b/internal/quic-go/mocks/quic/early_conn.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go (interfaces: EarlyConnection) +// Source: github.com/imroc/req/v3/internal/quic-go (interfaces: EarlyConnection) // Package mockquic is a generated GoMock package. package mockquic @@ -10,7 +10,8 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - quic "github.com/lucas-clemente/quic-go" + quic "github.com/imroc/req/v3/internal/quic-go" + qerr "github.com/imroc/req/v3/internal/quic-go/qerr" ) // MockEarlyConnection is a mock of EarlyConnection interface. @@ -67,7 +68,7 @@ func (mr *MockEarlyConnectionMockRecorder) AcceptUniStream(arg0 interface{}) *go } // CloseWithError mocks base method. -func (m *MockEarlyConnection) CloseWithError(arg0 quic.ApplicationErrorCode, arg1 string) error { +func (m *MockEarlyConnection) CloseWithError(arg0 qerr.ApplicationErrorCode, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1) ret0, _ := ret[0].(error) diff --git a/internal/quic-go/mocks/quic/early_listener.go b/internal/quic-go/mocks/quic/early_listener.go new file mode 100644 index 00000000..a57247a3 --- /dev/null +++ b/internal/quic-go/mocks/quic/early_listener.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go (interfaces: EarlyListener) + +// Package mockquic is a generated GoMock package. +package mockquic + +import ( + context "context" + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + quic "github.com/imroc/req/v3/internal/quic-go" +) + +// MockEarlyListener is a mock of EarlyListener interface. +type MockEarlyListener struct { + ctrl *gomock.Controller + recorder *MockEarlyListenerMockRecorder +} + +// MockEarlyListenerMockRecorder is the mock recorder for MockEarlyListener. +type MockEarlyListenerMockRecorder struct { + mock *MockEarlyListener +} + +// NewMockEarlyListener creates a new mock instance. +func NewMockEarlyListener(ctrl *gomock.Controller) *MockEarlyListener { + mock := &MockEarlyListener{ctrl: ctrl} + mock.recorder = &MockEarlyListenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEarlyListener) EXPECT() *MockEarlyListenerMockRecorder { + return m.recorder +} + +// Accept mocks base method. +func (m *MockEarlyListener) Accept(arg0 context.Context) (quic.EarlyConnection, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Accept", arg0) + ret0, _ := ret[0].(quic.EarlyConnection) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Accept indicates an expected call of Accept. +func (mr *MockEarlyListenerMockRecorder) Accept(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockEarlyListener)(nil).Accept), arg0) +} + +// Addr mocks base method. +func (m *MockEarlyListener) Addr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Addr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// Addr indicates an expected call of Addr. +func (mr *MockEarlyListenerMockRecorder) Addr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockEarlyListener)(nil).Addr)) +} + +// Close mocks base method. +func (m *MockEarlyListener) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockEarlyListenerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEarlyListener)(nil).Close)) +} diff --git a/internal/mocks/quic/stream.go b/internal/quic-go/mocks/quic/stream.go similarity index 92% rename from internal/mocks/quic/stream.go rename to internal/quic-go/mocks/quic/stream.go index dbdb3429..97a1f042 100644 --- a/internal/mocks/quic/stream.go +++ b/internal/quic-go/mocks/quic/stream.go @@ -1,16 +1,17 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go (interfaces: Stream) +// Source: github.com/imroc/req/v3/internal/quic-go (interfaces: Stream) // Package mockquic is a generated GoMock package. package mockquic import ( context "context" - "github.com/lucas-clemente/quic-go" reflect "reflect" time "time" gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" + qerr "github.com/imroc/req/v3/internal/quic-go/qerr" ) // MockStream is a mock of Stream interface. @@ -37,7 +38,7 @@ func (m *MockStream) EXPECT() *MockStreamMockRecorder { } // CancelRead mocks base method. -func (m *MockStream) CancelRead(arg0 quic.StreamErrorCode) { +func (m *MockStream) CancelRead(arg0 qerr.StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelRead", arg0) } @@ -49,7 +50,7 @@ func (mr *MockStreamMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { } // CancelWrite mocks base method. -func (m *MockStream) CancelWrite(arg0 quic.StreamErrorCode) { +func (m *MockStream) CancelWrite(arg0 qerr.StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelWrite", arg0) } @@ -146,10 +147,10 @@ func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Cal } // StreamID mocks base method. -func (m *MockStream) StreamID() quic.StreamID { +func (m *MockStream) StreamID() protocol.StreamID { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "StreamID") - ret0, _ := ret[0].(quic.StreamID) + ret0, _ := ret[0].(protocol.StreamID) return ret0 } diff --git a/internal/quic-go/mocks/short_header_opener.go b/internal/quic-go/mocks/short_header_opener.go new file mode 100644 index 00000000..1109cf0b --- /dev/null +++ b/internal/quic-go/mocks/short_header_opener.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/handshake (interfaces: ShortHeaderOpener) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockShortHeaderOpener is a mock of ShortHeaderOpener interface. +type MockShortHeaderOpener struct { + ctrl *gomock.Controller + recorder *MockShortHeaderOpenerMockRecorder +} + +// MockShortHeaderOpenerMockRecorder is the mock recorder for MockShortHeaderOpener. +type MockShortHeaderOpenerMockRecorder struct { + mock *MockShortHeaderOpener +} + +// NewMockShortHeaderOpener creates a new mock instance. +func NewMockShortHeaderOpener(ctrl *gomock.Controller) *MockShortHeaderOpener { + mock := &MockShortHeaderOpener{ctrl: ctrl} + mock.recorder = &MockShortHeaderOpenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockShortHeaderOpener) EXPECT() *MockShortHeaderOpenerMockRecorder { + return m.recorder +} + +// DecodePacketNumber mocks base method. +func (m *MockShortHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1) + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// DecodePacketNumber indicates an expected call of DecodePacketNumber. +func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) +} + +// DecryptHeader mocks base method. +func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) +} + +// DecryptHeader indicates an expected call of DecryptHeader. +func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) +} + +// Open mocks base method. +func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 time.Time, arg3 protocol.PacketNumber, arg4 protocol.KeyPhaseBit, arg5 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4, arg5) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Open indicates an expected call of Open. +func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4, arg5) +} diff --git a/internal/quic-go/mocks/short_header_sealer.go b/internal/quic-go/mocks/short_header_sealer.go new file mode 100644 index 00000000..72c6cbf1 --- /dev/null +++ b/internal/quic-go/mocks/short_header_sealer.go @@ -0,0 +1,89 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/handshake (interfaces: ShortHeaderSealer) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockShortHeaderSealer is a mock of ShortHeaderSealer interface. +type MockShortHeaderSealer struct { + ctrl *gomock.Controller + recorder *MockShortHeaderSealerMockRecorder +} + +// MockShortHeaderSealerMockRecorder is the mock recorder for MockShortHeaderSealer. +type MockShortHeaderSealerMockRecorder struct { + mock *MockShortHeaderSealer +} + +// NewMockShortHeaderSealer creates a new mock instance. +func NewMockShortHeaderSealer(ctrl *gomock.Controller) *MockShortHeaderSealer { + mock := &MockShortHeaderSealer{ctrl: ctrl} + mock.recorder = &MockShortHeaderSealerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockShortHeaderSealer) EXPECT() *MockShortHeaderSealerMockRecorder { + return m.recorder +} + +// EncryptHeader mocks base method. +func (m *MockShortHeaderSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2) +} + +// EncryptHeader indicates an expected call of EncryptHeader. +func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockShortHeaderSealer)(nil).EncryptHeader), arg0, arg1, arg2) +} + +// KeyPhase mocks base method. +func (m *MockShortHeaderSealer) KeyPhase() protocol.KeyPhaseBit { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "KeyPhase") + ret0, _ := ret[0].(protocol.KeyPhaseBit) + return ret0 +} + +// KeyPhase indicates an expected call of KeyPhase. +func (mr *MockShortHeaderSealerMockRecorder) KeyPhase() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyPhase", reflect.TypeOf((*MockShortHeaderSealer)(nil).KeyPhase)) +} + +// Overhead mocks base method. +func (m *MockShortHeaderSealer) Overhead() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Overhead") + ret0, _ := ret[0].(int) + return ret0 +} + +// Overhead indicates an expected call of Overhead. +func (mr *MockShortHeaderSealerMockRecorder) Overhead() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockShortHeaderSealer)(nil).Overhead)) +} + +// Seal mocks base method. +func (m *MockShortHeaderSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + return ret0 +} + +// Seal indicates an expected call of Seal. +func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockShortHeaderSealer)(nil).Seal), arg0, arg1, arg2, arg3) +} diff --git a/internal/quic-go/mocks/stream_flow_controller.go b/internal/quic-go/mocks/stream_flow_controller.go new file mode 100644 index 00000000..66d8c2ac --- /dev/null +++ b/internal/quic-go/mocks/stream_flow_controller.go @@ -0,0 +1,140 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/imroc/req/v3/internal/quic-go/flowcontrol (interfaces: StreamFlowController) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// MockStreamFlowController is a mock of StreamFlowController interface. +type MockStreamFlowController struct { + ctrl *gomock.Controller + recorder *MockStreamFlowControllerMockRecorder +} + +// MockStreamFlowControllerMockRecorder is the mock recorder for MockStreamFlowController. +type MockStreamFlowControllerMockRecorder struct { + mock *MockStreamFlowController +} + +// NewMockStreamFlowController creates a new mock instance. +func NewMockStreamFlowController(ctrl *gomock.Controller) *MockStreamFlowController { + mock := &MockStreamFlowController{ctrl: ctrl} + mock.recorder = &MockStreamFlowControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStreamFlowController) EXPECT() *MockStreamFlowControllerMockRecorder { + return m.recorder +} + +// Abandon mocks base method. +func (m *MockStreamFlowController) Abandon() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Abandon") +} + +// Abandon indicates an expected call of Abandon. +func (mr *MockStreamFlowControllerMockRecorder) Abandon() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abandon", reflect.TypeOf((*MockStreamFlowController)(nil).Abandon)) +} + +// AddBytesRead mocks base method. +func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddBytesRead", arg0) +} + +// AddBytesRead indicates an expected call of AddBytesRead. +func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesRead), arg0) +} + +// AddBytesSent mocks base method. +func (m *MockStreamFlowController) AddBytesSent(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddBytesSent", arg0) +} + +// AddBytesSent indicates an expected call of AddBytesSent. +func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesSent), arg0) +} + +// GetWindowUpdate mocks base method. +func (m *MockStreamFlowController) GetWindowUpdate() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetWindowUpdate indicates an expected call of GetWindowUpdate. +func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate)) +} + +// IsNewlyBlocked mocks base method. +func (m *MockStreamFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsNewlyBlocked") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 +} + +// IsNewlyBlocked indicates an expected call of IsNewlyBlocked. +func (mr *MockStreamFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsNewlyBlocked)) +} + +// SendWindowSize mocks base method. +func (m *MockStreamFlowController) SendWindowSize() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendWindowSize") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// SendWindowSize indicates an expected call of SendWindowSize. +func (mr *MockStreamFlowControllerMockRecorder) SendWindowSize() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockStreamFlowController)(nil).SendWindowSize)) +} + +// UpdateHighestReceived mocks base method. +func (m *MockStreamFlowController) UpdateHighestReceived(arg0 protocol.ByteCount, arg1 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateHighestReceived", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateHighestReceived indicates an expected call of UpdateHighestReceived. +func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateHighestReceived), arg0, arg1) +} + +// UpdateSendWindow mocks base method. +func (m *MockStreamFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateSendWindow", arg0) +} + +// UpdateSendWindow indicates an expected call of UpdateSendWindow. +func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateSendWindow), arg0) +} diff --git a/internal/quic-go/mocks/tls/client_session_cache.go b/internal/quic-go/mocks/tls/client_session_cache.go new file mode 100644 index 00000000..e3ae2c8e --- /dev/null +++ b/internal/quic-go/mocks/tls/client_session_cache.go @@ -0,0 +1,62 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: crypto/tls (interfaces: ClientSessionCache) + +// Package mocktls is a generated GoMock package. +package mocktls + +import ( + tls "crypto/tls" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockClientSessionCache is a mock of ClientSessionCache interface. +type MockClientSessionCache struct { + ctrl *gomock.Controller + recorder *MockClientSessionCacheMockRecorder +} + +// MockClientSessionCacheMockRecorder is the mock recorder for MockClientSessionCache. +type MockClientSessionCacheMockRecorder struct { + mock *MockClientSessionCache +} + +// NewMockClientSessionCache creates a new mock instance. +func NewMockClientSessionCache(ctrl *gomock.Controller) *MockClientSessionCache { + mock := &MockClientSessionCache{ctrl: ctrl} + mock.recorder = &MockClientSessionCacheMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClientSessionCache) EXPECT() *MockClientSessionCacheMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockClientSessionCache) Get(arg0 string) (*tls.ClientSessionState, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(*tls.ClientSessionState) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockClientSessionCacheMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockClientSessionCache)(nil).Get), arg0) +} + +// Put mocks base method. +func (m *MockClientSessionCache) Put(arg0 string, arg1 *tls.ClientSessionState) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Put", arg0, arg1) +} + +// Put indicates an expected call of Put. +func (mr *MockClientSessionCacheMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockClientSessionCache)(nil).Put), arg0, arg1) +} diff --git a/internal/quic-go/mtu_discoverer.go b/internal/quic-go/mtu_discoverer.go new file mode 100644 index 00000000..d8259fc7 --- /dev/null +++ b/internal/quic-go/mtu_discoverer.go @@ -0,0 +1,74 @@ +package quic + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/ackhandler" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type mtuDiscoverer interface { + ShouldSendProbe(now time.Time) bool + GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount) +} + +const ( + // At some point, we have to stop searching for a higher MTU. + // We're happy to send a packet that's 10 bytes smaller than the actual MTU. + maxMTUDiff = 20 + // send a probe packet every mtuProbeDelay RTTs + mtuProbeDelay = 5 +) + +type mtuFinder struct { + lastProbeTime time.Time + probeInFlight bool + mtuIncreased func(protocol.ByteCount) + + rttStats *utils.RTTStats + current protocol.ByteCount + max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer) +} + +var _ mtuDiscoverer = &mtuFinder{} + +func newMTUDiscoverer(rttStats *utils.RTTStats, start, max protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) mtuDiscoverer { + return &mtuFinder{ + current: start, + rttStats: rttStats, + lastProbeTime: time.Now(), // to make sure the first probe packet is not sent immediately + mtuIncreased: mtuIncreased, + max: max, + } +} + +func (f *mtuFinder) done() bool { + return f.max-f.current <= maxMTUDiff+1 +} + +func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { + if f.probeInFlight || f.done() { + return false + } + return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT())) +} + +func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { + size := (f.max + f.current) / 2 + f.lastProbeTime = time.Now() + f.probeInFlight = true + return ackhandler.Frame{ + Frame: &wire.PingFrame{}, + OnLost: func(wire.Frame) { + f.probeInFlight = false + f.max = size + }, + OnAcked: func(wire.Frame) { + f.probeInFlight = false + f.current = size + f.mtuIncreased(size) + }, + }, size +} diff --git a/internal/quic-go/mtu_discoverer_test.go b/internal/quic-go/mtu_discoverer_test.go new file mode 100644 index 00000000..f6701827 --- /dev/null +++ b/internal/quic-go/mtu_discoverer_test.go @@ -0,0 +1,112 @@ +package quic + +import ( + "math/rand" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + "github.com/imroc/req/v3/internal/quic-go/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("MTU Discoverer", func() { + const ( + rtt = 100 * time.Millisecond + startMTU protocol.ByteCount = 1000 + maxMTU protocol.ByteCount = 2000 + ) + + var ( + d mtuDiscoverer + rttStats *utils.RTTStats + now time.Time + discoveredMTU protocol.ByteCount + ) + + BeforeEach(func() { + rttStats = &utils.RTTStats{} + rttStats.SetInitialRTT(rtt) + Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) + d = newMTUDiscoverer(rttStats, startMTU, maxMTU, func(s protocol.ByteCount) { discoveredMTU = s }) + now = time.Now() + _ = discoveredMTU + }) + + It("only allows a probe 5 RTTs after the handshake completes", func() { + Expect(d.ShouldSendProbe(now)).To(BeFalse()) + Expect(d.ShouldSendProbe(now.Add(rtt * 9 / 2))).To(BeFalse()) + Expect(d.ShouldSendProbe(now.Add(rtt * 5))).To(BeTrue()) + }) + + It("doesn't allow a probe if another probe is still in flight", func() { + ping, _ := d.GetPing() + Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeFalse()) + ping.OnLost(ping.Frame) + Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeTrue()) + }) + + It("tries a lower size when a probe is lost", func() { + ping, size := d.GetPing() + Expect(size).To(Equal(protocol.ByteCount(1500))) + ping.OnLost(ping.Frame) + _, size = d.GetPing() + Expect(size).To(Equal(protocol.ByteCount(1250))) + }) + + It("tries a higher size and calls the callback when a probe is acknowledged", func() { + ping, size := d.GetPing() + Expect(size).To(Equal(protocol.ByteCount(1500))) + ping.OnAcked(ping.Frame) + Expect(discoveredMTU).To(Equal(protocol.ByteCount(1500))) + _, size = d.GetPing() + Expect(size).To(Equal(protocol.ByteCount(1750))) + }) + + It("stops discovery after getting close enough to the MTU", func() { + var sizes []protocol.ByteCount + t := now.Add(5 * rtt) + for d.ShouldSendProbe(t) { + ping, size := d.GetPing() + ping.OnAcked(ping.Frame) + sizes = append(sizes, size) + t = t.Add(5 * rtt) + } + Expect(sizes).To(Equal([]protocol.ByteCount{1500, 1750, 1875, 1937, 1968, 1984})) + Expect(d.ShouldSendProbe(t.Add(10 * rtt))).To(BeFalse()) + }) + + It("finds the MTU", func() { + const rep = 3000 + var maxDiff protocol.ByteCount + for i := 0; i < rep; i++ { + max := protocol.ByteCount(rand.Intn(int(3000-startMTU))) + startMTU + 1 + currentMTU := startMTU + d := newMTUDiscoverer(rttStats, startMTU, max, func(s protocol.ByteCount) { currentMTU = s }) + now := time.Now() + realMTU := protocol.ByteCount(rand.Intn(int(max-startMTU))) + startMTU + t := now.Add(mtuProbeDelay * rtt) + var count int + for d.ShouldSendProbe(t) { + if count > 25 { + Fail("too many iterations") + } + count++ + + ping, size := d.GetPing() + if size <= realMTU { + ping.OnAcked(ping.Frame) + } else { + ping.OnLost(ping.Frame) + } + t = t.Add(mtuProbeDelay * rtt) + } + diff := realMTU - currentMTU + Expect(diff).To(BeNumerically(">=", 0)) + maxDiff = utils.MaxByteCount(maxDiff, diff) + } + Expect(maxDiff).To(BeEquivalentTo(maxMTUDiff)) + }) +}) diff --git a/internal/quic-go/multiplexer.go b/internal/quic-go/multiplexer.go new file mode 100644 index 00000000..af943300 --- /dev/null +++ b/internal/quic-go/multiplexer.go @@ -0,0 +1,107 @@ +package quic + +import ( + "bytes" + "fmt" + "net" + "sync" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +var ( + connMuxerOnce sync.Once + connMuxer multiplexer +) + +type indexableConn interface { + LocalAddr() net.Addr +} + +type multiplexer interface { + AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error) + RemoveConn(indexableConn) error +} + +type connManager struct { + connIDLen int + statelessResetKey []byte + tracer logging.Tracer + manager packetHandlerManager +} + +// The connMultiplexer listens on multiple net.PacketConns and dispatches +// incoming packets to the connection handler. +type connMultiplexer struct { + mutex sync.Mutex + + conns map[string] /* LocalAddr().String() */ connManager + newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests + + logger utils.Logger +} + +var _ multiplexer = &connMultiplexer{} + +func getMultiplexer() multiplexer { + connMuxerOnce.Do(func() { + connMuxer = &connMultiplexer{ + conns: make(map[string]connManager), + logger: utils.DefaultLogger.WithPrefix("muxer"), + newPacketHandlerManager: newPacketHandlerMap, + } + }) + return connMuxer +} + +func (m *connMultiplexer) AddConn( + c net.PacketConn, + connIDLen int, + statelessResetKey []byte, + tracer logging.Tracer, +) (packetHandlerManager, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + addr := c.LocalAddr() + connIndex := addr.Network() + " " + addr.String() + p, ok := m.conns[connIndex] + if !ok { + manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) + if err != nil { + return nil, err + } + p = connManager{ + connIDLen: connIDLen, + statelessResetKey: statelessResetKey, + manager: manager, + tracer: tracer, + } + m.conns[connIndex] = p + } else { + if p.connIDLen != connIDLen { + return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) + } + if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) { + return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") + } + if tracer != p.tracer { + return nil, fmt.Errorf("cannot use different tracers on the same packet conn") + } + } + return p.manager, nil +} + +func (m *connMultiplexer) RemoveConn(c indexableConn) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() + if _, ok := m.conns[connIndex]; !ok { + return fmt.Errorf("cannote remove connection, connection is unknown") + } + + delete(m.conns, connIndex) + return nil +} diff --git a/internal/quic-go/multiplexer_test.go b/internal/quic-go/multiplexer_test.go new file mode 100644 index 00000000..58c3fd84 --- /dev/null +++ b/internal/quic-go/multiplexer_test.go @@ -0,0 +1,70 @@ +package quic + +import ( + "net" + + "github.com/golang/mock/gomock" + mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type testConn struct { + counter int + net.PacketConn +} + +var _ = Describe("Multiplexer", func() { + It("adds a new packet conn ", func() { + conn := NewMockPacketConn(mockCtrl) + conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}) + _, err := getMultiplexer().AddConn(conn, 8, nil, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("recognizes when the same connection is added twice", func() { + pconn := NewMockPacketConn(mockCtrl) + pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2) + pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) + conn := testConn{PacketConn: pconn} + tracer := mocklogging.NewMockTracer(mockCtrl) + _, err := getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer) + Expect(err).ToNot(HaveOccurred()) + conn.counter++ + _, err = getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer) + Expect(err).ToNot(HaveOccurred()) + Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1)) + }) + + It("errors when adding an existing conn with a different connection ID length", func() { + conn := NewMockPacketConn(mockCtrl) + conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) + _, err := getMultiplexer().AddConn(conn, 5, nil, nil) + Expect(err).ToNot(HaveOccurred()) + _, err = getMultiplexer().AddConn(conn, 6, nil, nil) + Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs")) + }) + + It("errors when adding an existing conn with a different stateless rest key", func() { + conn := NewMockPacketConn(mockCtrl) + conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) + _, err := getMultiplexer().AddConn(conn, 7, []byte("foobar"), nil) + Expect(err).ToNot(HaveOccurred()) + _, err = getMultiplexer().AddConn(conn, 7, []byte("raboof"), nil) + Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn")) + }) + + It("errors when adding an existing conn with different tracers", func() { + conn := NewMockPacketConn(mockCtrl) + conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) + _, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) + Expect(err).ToNot(HaveOccurred()) + _, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) + Expect(err).To(MatchError("cannot use different tracers on the same packet conn")) + }) +}) diff --git a/internal/quic-go/packet_handler_map.go b/internal/quic-go/packet_handler_map.go new file mode 100644 index 00000000..119b011d --- /dev/null +++ b/internal/quic-go/packet_handler_map.go @@ -0,0 +1,489 @@ +package quic + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "hash" + "io" + "log" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type zeroRTTQueue struct { + queue []*receivedPacket + retireTimer *time.Timer +} + +var _ packetHandler = &zeroRTTQueue{} + +func (h *zeroRTTQueue) handlePacket(p *receivedPacket) { + if len(h.queue) < protocol.Max0RTTQueueLen { + h.queue = append(h.queue, p) + } +} +func (h *zeroRTTQueue) shutdown() {} +func (h *zeroRTTQueue) destroy(error) {} +func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient } +func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) { + for _, p := range h.queue { + sess.handlePacket(p) + } +} + +func (h *zeroRTTQueue) Clear() { + for _, p := range h.queue { + p.buffer.Release() + } +} + +// rawConn is a connection that allow reading of a receivedPacket. +type rawConn interface { + ReadPacket() (*receivedPacket, error) + WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) + LocalAddr() net.Addr + io.Closer +} + +type packetHandlerMapEntry struct { + packetHandler packetHandler + is0RTTQueue bool +} + +// The packetHandlerMap stores packetHandlers, identified by connection ID. +// It is used: +// * by the server to store connections +// * when multiplexing outgoing connections to store clients +type packetHandlerMap struct { + mutex sync.Mutex + + conn rawConn + connIDLen int + + handlers map[string] /* string(ConnectionID)*/ packetHandlerMapEntry + resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler + server unknownPacketHandler + numZeroRTTEntries int + + listening chan struct{} // is closed when listen returns + closed bool + + deleteRetiredConnsAfter time.Duration + zeroRTTQueueDuration time.Duration + + statelessResetEnabled bool + statelessResetMutex sync.Mutex + statelessResetHasher hash.Hash + + tracer logging.Tracer + logger utils.Logger +} + +var _ packetHandlerManager = &packetHandlerMap{} + +func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { + conn, ok := c.(interface{ SetReadBuffer(int) error }) + if !ok { + return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") + } + size, err := inspectReadBuffer(c) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if size >= protocol.DesiredReceiveBufferSize { + logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) + return nil + } + if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { + return fmt.Errorf("failed to increase receive buffer size: %w", err) + } + newSize, err := inspectReadBuffer(c) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if newSize == size { + return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + if newSize < protocol.DesiredReceiveBufferSize { + return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) + return nil +} + +// only print warnings about the UDP receive buffer size once +var receiveBufferWarningOnce sync.Once + +func newPacketHandlerMap( + c net.PacketConn, + connIDLen int, + statelessResetKey []byte, + tracer logging.Tracer, + logger utils.Logger, +) (packetHandlerManager, error) { + if err := setReceiveBuffer(c, logger); err != nil { + if !strings.Contains(err.Error(), "use of closed network connection") { + receiveBufferWarningOnce.Do(func() { + if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { + return + } + log.Printf("%s. See https://github.com/imroc/req/v3/internal/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) + }) + } + } + conn, err := wrapConn(c) + if err != nil { + return nil, err + } + m := &packetHandlerMap{ + conn: conn, + connIDLen: connIDLen, + listening: make(chan struct{}), + handlers: make(map[string]packetHandlerMapEntry), + resetTokens: make(map[protocol.StatelessResetToken]packetHandler), + deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, + zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, + statelessResetEnabled: len(statelessResetKey) > 0, + statelessResetHasher: hmac.New(sha256.New, statelessResetKey), + tracer: tracer, + logger: logger, + } + go m.listen() + + if logger.Debug() { + go m.logUsage() + } + return m, nil +} + +func (h *packetHandlerMap) logUsage() { + ticker := time.NewTicker(2 * time.Second) + var printedZero bool + for { + select { + case <-h.listening: + return + case <-ticker.C: + } + + h.mutex.Lock() + numHandlers := len(h.handlers) + numTokens := len(h.resetTokens) + h.mutex.Unlock() + // If the number tracked handlers and tokens is zero, only print it a single time. + hasZero := numHandlers == 0 && numTokens == 0 + if !hasZero || (hasZero && !printedZero) { + h.logger.Debugf("Tracking %d connection IDs and %d reset tokens.\n", numHandlers, numTokens) + printedZero = false + if hasZero { + printedZero = true + } + } + } +} + +func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { + h.mutex.Lock() + defer h.mutex.Unlock() + + if _, ok := h.handlers[string(id)]; ok { + h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) + return false + } + h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} + h.logger.Debugf("Adding connection ID %s.", id) + return true +} + +func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool { + h.mutex.Lock() + defer h.mutex.Unlock() + + var q *zeroRTTQueue + if entry, ok := h.handlers[string(clientDestConnID)]; ok { + if !entry.is0RTTQueue { + h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) + return false + } + q = entry.packetHandler.(*zeroRTTQueue) + q.retireTimer.Stop() + h.numZeroRTTEntries-- + if h.numZeroRTTEntries < 0 { + panic("number of 0-RTT queues < 0") + } + } + sess := fn() + if q != nil { + q.EnqueueAll(sess) + } + h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess} + h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess} + h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) + return true +} + +func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { + h.mutex.Lock() + delete(h.handlers, string(id)) + h.mutex.Unlock() + h.logger.Debugf("Removing connection ID %s.", id) +} + +func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { + h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter) + time.AfterFunc(h.deleteRetiredConnsAfter, func() { + h.mutex.Lock() + delete(h.handlers, string(id)) + h.mutex.Unlock() + h.logger.Debugf("Removing connection ID %s after it has been retired.", id) + }) +} + +func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) { + h.mutex.Lock() + h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} + h.mutex.Unlock() + h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id) + + time.AfterFunc(h.deleteRetiredConnsAfter, func() { + h.mutex.Lock() + handler.shutdown() + delete(h.handlers, string(id)) + h.mutex.Unlock() + h.logger.Debugf("Removing connection ID %s for a closed connection after it has been retired.", id) + }) +} + +func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { + h.mutex.Lock() + h.resetTokens[token] = handler + h.mutex.Unlock() +} + +func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) { + h.mutex.Lock() + delete(h.resetTokens, token) + h.mutex.Unlock() +} + +func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { + h.mutex.Lock() + h.server = s + h.mutex.Unlock() +} + +func (h *packetHandlerMap) CloseServer() { + h.mutex.Lock() + if h.server == nil { + h.mutex.Unlock() + return + } + h.server = nil + var wg sync.WaitGroup + for _, entry := range h.handlers { + if entry.packetHandler.getPerspective() == protocol.PerspectiveServer { + wg.Add(1) + go func(handler packetHandler) { + // blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped + handler.shutdown() + wg.Done() + }(entry.packetHandler) + } + } + h.mutex.Unlock() + wg.Wait() +} + +// Destroy closes the underlying connection and waits until listen() has returned. +// It does not close active connections. +func (h *packetHandlerMap) Destroy() error { + if err := h.conn.Close(); err != nil { + return err + } + <-h.listening // wait until listening returns + return nil +} + +func (h *packetHandlerMap) close(e error) error { + h.mutex.Lock() + if h.closed { + h.mutex.Unlock() + return nil + } + + var wg sync.WaitGroup + for _, entry := range h.handlers { + wg.Add(1) + go func(handler packetHandler) { + handler.destroy(e) + wg.Done() + }(entry.packetHandler) + } + + if h.server != nil { + h.server.setCloseError(e) + } + h.closed = true + h.mutex.Unlock() + wg.Wait() + return getMultiplexer().RemoveConn(h.conn) +} + +func (h *packetHandlerMap) listen() { + defer close(h.listening) + for { + p, err := h.conn.ReadPacket() + //nolint:staticcheck // SA1019 ignore this! + // TODO: This code is used to ignore wsa errors on Windows. + // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. + // See https://github.com/imroc/req/v3/internal/quic-go/issues/1737 for details. + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + h.logger.Debugf("Temporary error reading from conn: %w", err) + continue + } + if err != nil { + h.close(err) + return + } + h.handlePacket(p) + } +} + +func (h *packetHandlerMap) handlePacket(p *receivedPacket) { + connID, err := wire.ParseConnectionID(p.data, h.connIDLen) + if err != nil { + h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) + if h.tracer != nil { + h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + } + p.buffer.MaybeRelease() + return + } + + h.mutex.Lock() + defer h.mutex.Unlock() + + if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset { + return + } + + if entry, ok := h.handlers[string(connID)]; ok { + if entry.is0RTTQueue { // only enqueue 0-RTT packets in the 0-RTT queue + if wire.Is0RTTPacket(p.data) { + entry.packetHandler.handlePacket(p) + return + } + } else { // existing connection + entry.packetHandler.handlePacket(p) + return + } + } + if p.data[0]&0x80 == 0 { + go h.maybeSendStatelessReset(p, connID) + return + } + if h.server == nil { // no server set + h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) + return + } + if wire.Is0RTTPacket(p.data) { + if h.numZeroRTTEntries >= protocol.Max0RTTQueues { + return + } + h.numZeroRTTEntries++ + queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)} + h.handlers[string(connID)] = packetHandlerMapEntry{ + packetHandler: queue, + is0RTTQueue: true, + } + queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { + h.mutex.Lock() + defer h.mutex.Unlock() + // The entry might have been replaced by an actual connection. + // Only delete it if it's still a 0-RTT queue. + if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue { + delete(h.handlers, string(connID)) + h.numZeroRTTEntries-- + if h.numZeroRTTEntries < 0 { + panic("number of 0-RTT queues < 0") + } + entry.packetHandler.(*zeroRTTQueue).Clear() + if h.logger.Debug() { + h.logger.Debugf("Removing 0-RTT queue for %s.", connID) + } + } + }) + queue.handlePacket(p) + return + } + h.server.handlePacket(p) +} + +func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { + // stateless resets are always short header packets + if data[0]&0x80 != 0 { + return false + } + if len(data) < 17 /* type byte + 16 bytes for the reset token */ { + return false + } + + var token protocol.StatelessResetToken + copy(token[:], data[len(data)-16:]) + if sess, ok := h.resetTokens[token]; ok { + h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) + go sess.destroy(&StatelessResetError{Token: token}) + return true + } + return false +} + +func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { + var token protocol.StatelessResetToken + if !h.statelessResetEnabled { + // Return a random stateless reset token. + // This token will be sent in the server's transport parameters. + // By using a random token, an off-path attacker won't be able to disrupt the connection. + rand.Read(token[:]) + return token + } + h.statelessResetMutex.Lock() + h.statelessResetHasher.Write(connID.Bytes()) + copy(token[:], h.statelessResetHasher.Sum(nil)) + h.statelessResetHasher.Reset() + h.statelessResetMutex.Unlock() + return token +} + +func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { + defer p.buffer.Release() + if !h.statelessResetEnabled { + return + } + // Don't send a stateless reset in response to very small packets. + // This includes packets that could be stateless resets. + if len(p.data) <= protocol.MinStatelessResetSize { + return + } + token := h.GetStatelessResetToken(connID) + h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) + data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) + rand.Read(data) + data[0] = (data[0] & 0x7f) | 0x40 + data = append(data, token[:]...) + if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + h.logger.Debugf("Error sending Stateless Reset: %s", err) + } +} diff --git a/internal/quic-go/packet_handler_map_test.go b/internal/quic-go/packet_handler_map_test.go new file mode 100644 index 00000000..21c1fcbe --- /dev/null +++ b/internal/quic-go/packet_handler_map_test.go @@ -0,0 +1,495 @@ +package quic + +import ( + "bytes" + "crypto/rand" + "errors" + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go/logging" + mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Packet Handler Map", func() { + type packetToRead struct { + addr net.Addr + data []byte + err error + } + + var ( + handler *packetHandlerMap + conn *MockPacketConn + tracer *mocklogging.MockTracer + packetChan chan packetToRead + + connIDLen int + statelessResetKey []byte + ) + + getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { + buf := &bytes.Buffer{} + Expect((&wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: t, + DestConnectionID: connID, + Length: length, + Version: protocol.VersionTLS, + }, + PacketNumberLen: protocol.PacketNumberLen2, + }).Write(buf, protocol.VersionWhatever)).To(Succeed()) + return buf.Bytes() + } + + getPacket := func(connID protocol.ConnectionID) []byte { + return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2) + } + + BeforeEach(func() { + statelessResetKey = nil + connIDLen = 0 + tracer = mocklogging.NewMockTracer(mockCtrl) + packetChan = make(chan packetToRead, 10) + }) + + JustBeforeEach(func() { + conn = NewMockPacketConn(mockCtrl) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() + conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { + p, ok := <-packetChan + if !ok { + return 0, nil, errors.New("closed") + } + return copy(b, p.data), p.addr, p.err + }).AnyTimes() + phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger) + Expect(err).ToNot(HaveOccurred()) + handler = phm.(*packetHandlerMap) + }) + + It("closes", func() { + getMultiplexer() // make the sync.Once execute + // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer + mockMultiplexer := NewMockMultiplexer(mockCtrl) + origMultiplexer := connMuxer + connMuxer = mockMultiplexer + + defer func() { + connMuxer = origMultiplexer + }() + + testErr := errors.New("test error ") + conn1 := NewMockPacketHandler(mockCtrl) + conn1.EXPECT().destroy(testErr) + conn2 := NewMockPacketHandler(mockCtrl) + conn2.EXPECT().destroy(testErr) + handler.Add(protocol.ConnectionID{1, 1, 1, 1}, conn1) + handler.Add(protocol.ConnectionID{2, 2, 2, 2}, conn2) + mockMultiplexer.EXPECT().RemoveConn(gomock.Any()) + handler.close(testErr) + close(packetChan) + Eventually(handler.listening).Should(BeClosed()) + }) + + Context("other operations", func() { + AfterEach(func() { + // delete connections and the server before closing + // They might be mock implementations, and we'd have to register the expected calls before otherwise. + handler.mutex.Lock() + for connID := range handler.handlers { + delete(handler.handlers, connID) + } + handler.server = nil + handler.mutex.Unlock() + conn.EXPECT().Close().MaxTimes(1) + close(packetChan) + handler.Destroy() + Eventually(handler.listening).Should(BeClosed()) + }) + + Context("handling packets", func() { + BeforeEach(func() { + connIDLen = 5 + }) + + It("handles packets for different packet handlers on the same packet conn", func() { + connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + packetHandler1 := NewMockPacketHandler(mockCtrl) + packetHandler2 := NewMockPacketHandler(mockCtrl) + handledPacket1 := make(chan struct{}) + handledPacket2 := make(chan struct{}) + packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID1)) + close(handledPacket1) + }) + packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID2)) + close(handledPacket2) + }) + handler.Add(connID1, packetHandler1) + handler.Add(connID2, packetHandler2) + packetChan <- packetToRead{data: getPacket(connID1)} + packetChan <- packetToRead{data: getPacket(connID2)} + + Eventually(handledPacket1).Should(BeClosed()) + Eventually(handledPacket2).Should(BeClosed()) + }) + + It("drops unparseable packets", func() { + addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} + tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: []byte{0, 1, 2, 3}, + }) + }) + + It("deletes removed connections immediately", func() { + handler.deleteRetiredConnsAfter = time.Hour + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + handler.Add(connID, NewMockPacketHandler(mockCtrl)) + handler.Remove(connID) + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) + // don't EXPECT any calls to handlePacket of the MockPacketHandler + }) + + It("deletes retired connection entries after a wait time", func() { + handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond) + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + conn := NewMockPacketHandler(mockCtrl) + handler.Add(connID, conn) + handler.Retire(connID) + time.Sleep(scaleDuration(30 * time.Millisecond)) + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) + // don't EXPECT any calls to handlePacket of the MockPacketHandler + }) + + It("passes packets arriving late for closed connections to that connection", func() { + handler.deleteRetiredConnsAfter = time.Hour + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + packetHandler := NewMockPacketHandler(mockCtrl) + handled := make(chan struct{}) + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + close(handled) + }) + handler.Add(connID, packetHandler) + handler.Retire(connID) + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) + Eventually(handled).Should(BeClosed()) + }) + + It("drops packets for unknown receivers", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) + }) + + It("closes the packet handlers when reading from the conn fails", func() { + done := make(chan struct{}) + packetHandler := NewMockPacketHandler(mockCtrl) + packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) { + Expect(e).To(HaveOccurred()) + close(done) + }) + handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) + packetChan <- packetToRead{err: errors.New("read failed")} + Eventually(done).Should(BeClosed()) + }) + + It("continues listening for temporary errors", func() { + packetHandler := NewMockPacketHandler(mockCtrl) + handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) + err := deadlineError{} + Expect(err.Temporary()).To(BeTrue()) + packetChan <- packetToRead{err: err} + // don't EXPECT any calls to packetHandler.destroy + time.Sleep(50 * time.Millisecond) + }) + + It("says if a connection ID is already taken", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) + Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) + }) + + It("says if a connection ID is already taken, for AddWithConnID", func() { + clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + newConnID1 := protocol.ConnectionID{1, 2, 3, 4} + newConnID2 := protocol.ConnectionID{4, 3, 2, 1} + Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue()) + Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse()) + }) + }) + + Context("running a server", func() { + It("adds a server", func() { + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + p := getPacket(connID) + server := NewMockUnknownPacketHandler(mockCtrl) + server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + cid, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(cid).To(Equal(connID)) + }) + handler.SetServer(server) + handler.handlePacket(&receivedPacket{data: p}) + }) + + It("closes all server connections", func() { + handler.SetServer(NewMockUnknownPacketHandler(mockCtrl)) + clientConn := NewMockPacketHandler(mockCtrl) + clientConn.EXPECT().getPerspective().Return(protocol.PerspectiveClient) + serverConn := NewMockPacketHandler(mockCtrl) + serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer) + serverConn.EXPECT().shutdown() + + handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientConn) + handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverConn) + handler.CloseServer() + }) + + It("stops handling packets with unknown connection IDs after the server is closed", func() { + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + p := getPacket(connID) + server := NewMockUnknownPacketHandler(mockCtrl) + // don't EXPECT any calls to server.handlePacket + handler.SetServer(server) + handler.CloseServer() + handler.handlePacket(&receivedPacket{data: p}) + }) + }) + + Context("0-RTT", func() { + JustBeforeEach(func() { + handler.zeroRTTQueueDuration = time.Hour + server := NewMockUnknownPacketHandler(mockCtrl) + // we don't expect any calls to server.handlePacket + handler.SetServer(server) + }) + + It("queues 0-RTT packets", func() { + server := NewMockUnknownPacketHandler(mockCtrl) + // don't EXPECT any calls to server.handlePacket + handler.SetServer(server) + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} + p2 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2)} + p3 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 3)} + handler.handlePacket(p1) + handler.handlePacket(p2) + handler.handlePacket(p3) + conn := NewMockPacketHandler(mockCtrl) + done := make(chan struct{}) + gomock.InOrder( + conn.EXPECT().handlePacket(p1), + conn.EXPECT().handlePacket(p2), + conn.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }), + ) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + Eventually(done).Should(BeClosed()) + }) + + It("directs 0-RTT packets to existing connections", func() { + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + conn := NewMockPacketHandler(mockCtrl) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} + conn.EXPECT().handlePacket(p1) + handler.handlePacket(p1) + }) + + It("limits the number of 0-RTT queues", func() { + for i := 0; i < protocol.Max0RTTQueues; i++ { + connID := make(protocol.ConnectionID, 8) + rand.Read(connID) + p := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} + handler.handlePacket(p) + } + // We're already storing the maximum number of queues. This packet will be dropped. + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9} + handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}) + // Don't EXPECT any handlePacket() calls. + conn := NewMockPacketHandler(mockCtrl) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + time.Sleep(20 * time.Millisecond) + }) + + It("deletes queues if no connection is created for this connection ID", func() { + queueDuration := scaleDuration(10 * time.Millisecond) + handler.zeroRTTQueueDuration = queueDuration + + server := NewMockUnknownPacketHandler(mockCtrl) + // don't EXPECT any calls to server.handlePacket + handler.SetServer(server) + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + p1 := &receivedPacket{ + data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1), + buffer: getPacketBuffer(), + } + p2 := &receivedPacket{ + data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2), + buffer: getPacketBuffer(), + } + handler.handlePacket(p1) + handler.handlePacket(p2) + // wait a bit. The queue should now already be deleted. + time.Sleep(queueDuration * 3) + // Don't EXPECT any handlePacket() calls. + conn := NewMockPacketHandler(mockCtrl) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + time.Sleep(20 * time.Millisecond) + }) + }) + + Context("stateless resets", func() { + BeforeEach(func() { + connIDLen = 5 + }) + + Context("handling", func() { + It("handles stateless resets", func() { + packetHandler := NewMockPacketHandler(mockCtrl) + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, packetHandler) + destroyed := make(chan struct{}) + packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { + defer GinkgoRecover() + defer close(destroyed) + Expect(err).To(HaveOccurred()) + var resetErr *StatelessResetError + Expect(errors.As(err, &resetErr)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("received a stateless reset")) + Expect(resetErr.Token).To(Equal(token)) + }) + packetChan <- packetToRead{data: packet} + Eventually(destroyed).Should(BeClosed()) + }) + + It("handles stateless resets for 0-length connection IDs", func() { + handler.connIDLen = 0 + packetHandler := NewMockPacketHandler(mockCtrl) + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, packetHandler) + destroyed := make(chan struct{}) + packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { + defer GinkgoRecover() + Expect(err).To(HaveOccurred()) + var resetErr *StatelessResetError + Expect(errors.As(err, &resetErr)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("received a stateless reset")) + Expect(resetErr.Token).To(Equal(token)) + close(destroyed) + }) + packetChan <- packetToRead{data: packet} + Eventually(destroyed).Should(BeClosed()) + }) + + It("removes reset tokens", func() { + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} + packetHandler := NewMockPacketHandler(mockCtrl) + handler.Add(connID, packetHandler) + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) + handler.RemoveResetToken(token) + // don't EXPECT any call to packetHandler.destroy() + packetHandler.EXPECT().handlePacket(gomock.Any()) + p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) + p = append(p, make([]byte, 50)...) + p = append(p, token[:]...) + + handler.handlePacket(&receivedPacket{data: p}) + }) + + It("ignores packets too small to contain a stateless reset", func() { + handler.connIDLen = 0 + packetHandler := NewMockPacketHandler(mockCtrl) + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, packetHandler) + done := make(chan struct{}) + // don't EXPECT any calls here, but register the closing of the done channel + packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) { + close(done) + }).AnyTimes() + packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)} + Consistently(done).ShouldNot(BeClosed()) + }) + }) + + Context("generating", func() { + BeforeEach(func() { + key := make([]byte, 32) + rand.Read(key) + statelessResetKey = key + }) + + It("generates stateless reset tokens", func() { + connID1 := []byte{0xde, 0xad, 0xbe, 0xef} + connID2 := []byte{0xde, 0xca, 0xfb, 0xad} + Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2))) + }) + + It("sends stateless resets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, 100)...) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) { + defer close(done) + Expect(b[0] & 0x80).To(BeZero()) // short header packet + Expect(b).To(HaveLen(protocol.MinStatelessResetSize)) + }) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: p, + }) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't send stateless resets for small packets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: p, + }) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + }) + }) + + Context("if no key is configured", func() { + It("doesn't send stateless resets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, 100)...) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: p, + }) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + }) + }) + }) + }) +}) diff --git a/internal/quic-go/packet_packer.go b/internal/quic-go/packet_packer.go new file mode 100644 index 00000000..de9ce11c --- /dev/null +++ b/internal/quic-go/packet_packer.go @@ -0,0 +1,894 @@ +package quic + +import ( + "bytes" + "errors" + "fmt" + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go/ackhandler" + "github.com/imroc/req/v3/internal/quic-go/handshake" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type packer interface { + PackCoalescedPacket() (*coalescedPacket, error) + PackPacket() (*packedPacket, error) + MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) + MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) + PackConnectionClose(*qerr.TransportError) (*coalescedPacket, error) + PackApplicationClose(*qerr.ApplicationError) (*coalescedPacket, error) + + SetMaxPacketSize(protocol.ByteCount) + PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) + + HandleTransportParameters(*wire.TransportParameters) + SetToken([]byte) +} + +type sealer interface { + handshake.LongHeaderSealer +} + +type payload struct { + frames []ackhandler.Frame + ack *wire.AckFrame + length protocol.ByteCount +} + +type packedPacket struct { + buffer *packetBuffer + *packetContents +} + +type packetContents struct { + header *wire.ExtendedHeader + ack *wire.AckFrame + frames []ackhandler.Frame + + length protocol.ByteCount + + isMTUProbePacket bool +} + +type coalescedPacket struct { + buffer *packetBuffer + packets []*packetContents +} + +func (p *packetContents) EncryptionLevel() protocol.EncryptionLevel { + if !p.header.IsLongHeader { + return protocol.Encryption1RTT + } + //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data). + switch p.header.Type { + case protocol.PacketTypeInitial: + return protocol.EncryptionInitial + case protocol.PacketTypeHandshake: + return protocol.EncryptionHandshake + case protocol.PacketType0RTT: + return protocol.Encryption0RTT + default: + panic("can't determine encryption level") + } +} + +func (p *packetContents) IsAckEliciting() bool { + return ackhandler.HasAckElicitingFrames(p.frames) +} + +func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet { + largestAcked := protocol.InvalidPacketNumber + if p.ack != nil { + largestAcked = p.ack.LargestAcked() + } + encLevel := p.EncryptionLevel() + for i := range p.frames { + if p.frames[i].OnLost != nil { + continue + } + switch encLevel { + case protocol.EncryptionInitial: + p.frames[i].OnLost = q.AddInitial + case protocol.EncryptionHandshake: + p.frames[i].OnLost = q.AddHandshake + case protocol.Encryption0RTT, protocol.Encryption1RTT: + p.frames[i].OnLost = q.AddAppData + } + } + return &ackhandler.Packet{ + PacketNumber: p.header.PacketNumber, + LargestAcked: largestAcked, + Frames: p.frames, + Length: p.length, + EncryptionLevel: encLevel, + SendTime: now, + IsPathMTUProbePacket: p.isMTUProbePacket, + } +} + +func getMaxPacketSize(addr net.Addr) protocol.ByteCount { + maxSize := protocol.ByteCount(protocol.MinInitialPacketSize) + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if udpAddr, ok := addr.(*net.UDPAddr); ok { + if utils.IsIPv4(udpAddr.IP) { + maxSize = protocol.InitialPacketSizeIPv4 + } else { + maxSize = protocol.InitialPacketSizeIPv6 + } + } + return maxSize +} + +type packetNumberManager interface { + PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) + PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber +} + +type sealingManager interface { + GetInitialSealer() (handshake.LongHeaderSealer, error) + GetHandshakeSealer() (handshake.LongHeaderSealer, error) + Get0RTTSealer() (handshake.LongHeaderSealer, error) + Get1RTTSealer() (handshake.ShortHeaderSealer, error) +} + +type frameSource interface { + HasData() bool + AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) +} + +type ackFrameSource interface { + GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame +} + +type packetPacker struct { + srcConnID protocol.ConnectionID + getDestConnID func() protocol.ConnectionID + + perspective protocol.Perspective + version protocol.VersionNumber + cryptoSetup sealingManager + + initialStream cryptoStream + handshakeStream cryptoStream + + token []byte + + pnManager packetNumberManager + framer frameSource + acks ackFrameSource + datagramQueue *datagramQueue + retransmissionQueue *retransmissionQueue + + maxPacketSize protocol.ByteCount + numNonAckElicitingAcks int +} + +var _ packer = &packetPacker{} + +func newPacketPacker( + srcConnID protocol.ConnectionID, + getDestConnID func() protocol.ConnectionID, + initialStream cryptoStream, + handshakeStream cryptoStream, + packetNumberManager packetNumberManager, + retransmissionQueue *retransmissionQueue, + remoteAddr net.Addr, // only used for determining the max packet size + cryptoSetup sealingManager, + framer frameSource, + acks ackFrameSource, + datagramQueue *datagramQueue, + perspective protocol.Perspective, + version protocol.VersionNumber, +) *packetPacker { + return &packetPacker{ + cryptoSetup: cryptoSetup, + getDestConnID: getDestConnID, + srcConnID: srcConnID, + initialStream: initialStream, + handshakeStream: handshakeStream, + retransmissionQueue: retransmissionQueue, + datagramQueue: datagramQueue, + perspective: perspective, + version: version, + framer: framer, + acks: acks, + pnManager: packetNumberManager, + maxPacketSize: getMaxPacketSize(remoteAddr), + } +} + +// PackConnectionClose packs a packet that closes the connection with a transport error. +func (p *packetPacker) PackConnectionClose(e *qerr.TransportError) (*coalescedPacket, error) { + var reason string + // don't send details of crypto errors + if !e.ErrorCode.IsCryptoError() { + reason = e.ErrorMessage + } + return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason) +} + +// PackApplicationClose packs a packet that closes the connection with an application error. +func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError) (*coalescedPacket, error) { + return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage) +} + +func (p *packetPacker) packConnectionClose( + isApplicationError bool, + errorCode uint64, + frameType uint64, + reason string, +) (*coalescedPacket, error) { + var sealers [4]sealer + var hdrs [4]*wire.ExtendedHeader + var payloads [4]*payload + var size protocol.ByteCount + var numPackets uint8 + encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT} + for i, encLevel := range encLevels { + if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT { + continue + } + ccf := &wire.ConnectionCloseFrame{ + IsApplicationError: isApplicationError, + ErrorCode: errorCode, + FrameType: frameType, + ReasonPhrase: reason, + } + // don't send application errors in Initial or Handshake packets + if isApplicationError && (encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake) { + ccf.IsApplicationError = false + ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode) + ccf.ReasonPhrase = "" + } + payload := &payload{ + frames: []ackhandler.Frame{{Frame: ccf}}, + length: ccf.Length(p.version), + } + + var sealer sealer + var err error + var keyPhase protocol.KeyPhaseBit // only set for 1-RTT + switch encLevel { + case protocol.EncryptionInitial: + sealer, err = p.cryptoSetup.GetInitialSealer() + case protocol.EncryptionHandshake: + sealer, err = p.cryptoSetup.GetHandshakeSealer() + case protocol.Encryption0RTT: + sealer, err = p.cryptoSetup.Get0RTTSealer() + case protocol.Encryption1RTT: + var s handshake.ShortHeaderSealer + s, err = p.cryptoSetup.Get1RTTSealer() + if err == nil { + keyPhase = s.KeyPhase() + } + sealer = s + } + if err == handshake.ErrKeysNotYetAvailable || err == handshake.ErrKeysDropped { + continue + } + if err != nil { + return nil, err + } + sealers[i] = sealer + var hdr *wire.ExtendedHeader + if encLevel == protocol.Encryption1RTT { + hdr = p.getShortHeader(keyPhase) + } else { + hdr = p.getLongHeader(encLevel) + } + hdrs[i] = hdr + payloads[i] = payload + size += p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) + numPackets++ + } + contents := make([]*packetContents, 0, numPackets) + buffer := getPacketBuffer() + for i, encLevel := range encLevels { + if sealers[i] == nil { + continue + } + var paddingLen protocol.ByteCount + if encLevel == protocol.EncryptionInitial { + paddingLen = p.initialPaddingLen(payloads[i].frames, size) + } + c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], false) + if err != nil { + return nil, err + } + contents = append(contents, c) + } + return &coalescedPacket{buffer: buffer, packets: contents}, nil +} + +// packetLength calculates the length of the serialized packet. +// It takes into account that packets that have a tiny payload need to be padded, +// such that len(payload) + packet number len >= 4 + AEAD overhead +func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) protocol.ByteCount { + var paddingLen protocol.ByteCount + pnLen := protocol.ByteCount(hdr.PacketNumberLen) + if payload.length < 4-pnLen { + paddingLen = 4 - pnLen - payload.length + } + return hdr.GetLength(p.version) + payload.length + paddingLen +} + +func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { + var encLevel protocol.EncryptionLevel + var ack *wire.AckFrame + if !handshakeConfirmed { + ack = p.acks.GetAckFrame(protocol.EncryptionInitial, true) + if ack != nil { + encLevel = protocol.EncryptionInitial + } else { + ack = p.acks.GetAckFrame(protocol.EncryptionHandshake, true) + if ack != nil { + encLevel = protocol.EncryptionHandshake + } + } + } + if ack == nil { + ack = p.acks.GetAckFrame(protocol.Encryption1RTT, true) + if ack == nil { + return nil, nil + } + encLevel = protocol.Encryption1RTT + } + payload := &payload{ + ack: ack, + length: ack.Length(p.version), + } + + sealer, hdr, err := p.getSealerAndHeader(encLevel) + if err != nil { + return nil, err + } + return p.writeSinglePacket(hdr, payload, encLevel, sealer) +} + +// size is the expected size of the packet, if no padding was applied. +func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { + // For the server, only ack-eliciting Initial packets need to be padded. + if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) { + return 0 + } + if size >= p.maxPacketSize { + return 0 + } + return p.maxPacketSize - size +} + +// PackCoalescedPacket packs a new packet. +// It packs an Initial / Handshake if there is data to send in these packet number spaces. +// It should only be called before the handshake is confirmed. +func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { + maxPacketSize := p.maxPacketSize + if p.perspective == protocol.PerspectiveClient { + maxPacketSize = protocol.MinInitialPacketSize + } + var initialHdr, handshakeHdr, appDataHdr *wire.ExtendedHeader + var initialPayload, handshakePayload, appDataPayload *payload + var numPackets int + // Try packing an Initial packet. + initialSealer, err := p.cryptoSetup.GetInitialSealer() + if err != nil && err != handshake.ErrKeysDropped { + return nil, err + } + var size protocol.ByteCount + if initialSealer != nil { + initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), size, protocol.EncryptionInitial) + if initialPayload != nil { + size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead()) + numPackets++ + } + } + + // Add a Handshake packet. + var handshakeSealer sealer + if size < maxPacketSize-protocol.MinCoalescedPacketSize { + var err error + handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if handshakeSealer != nil { + handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), size, protocol.EncryptionHandshake) + if handshakePayload != nil { + s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead()) + size += s + numPackets++ + } + } + } + + // Add a 0-RTT / 1-RTT packet. + var appDataSealer sealer + appDataEncLevel := protocol.Encryption1RTT + if size < maxPacketSize-protocol.MinCoalescedPacketSize { + var err error + appDataSealer, appDataHdr, appDataPayload = p.maybeGetAppDataPacket(maxPacketSize-size, size) + if err != nil { + return nil, err + } + if appDataHdr != nil { + if appDataHdr.IsLongHeader { + appDataEncLevel = protocol.Encryption0RTT + } + if appDataPayload != nil { + size += p.packetLength(appDataHdr, appDataPayload) + protocol.ByteCount(appDataSealer.Overhead()) + numPackets++ + } + } + } + + if numPackets == 0 { + return nil, nil + } + + buffer := getPacketBuffer() + packet := &coalescedPacket{ + buffer: buffer, + packets: make([]*packetContents, 0, numPackets), + } + if initialPayload != nil { + padding := p.initialPaddingLen(initialPayload.frames, size) + cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, false) + if err != nil { + return nil, err + } + packet.packets = append(packet.packets, cont) + } + if handshakePayload != nil { + cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, false) + if err != nil { + return nil, err + } + packet.packets = append(packet.packets, cont) + } + if appDataPayload != nil { + cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer, false) + if err != nil { + return nil, err + } + packet.packets = append(packet.packets, cont) + } + return packet, nil +} + +// PackPacket packs a packet in the application data packet number space. +// It should be called after the handshake is confirmed. +func (p *packetPacker) PackPacket() (*packedPacket, error) { + sealer, hdr, payload := p.maybeGetAppDataPacket(p.maxPacketSize, 0) + if payload == nil { + return nil, nil + } + buffer := getPacketBuffer() + encLevel := protocol.Encryption1RTT + if hdr.IsLongHeader { + encLevel = protocol.Encryption0RTT + } + cont, err := p.appendPacket(buffer, hdr, payload, 0, encLevel, sealer, false) + if err != nil { + return nil, err + } + return &packedPacket{ + buffer: buffer, + packetContents: cont, + }, nil +} + +func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) { + var s cryptoStream + var hasRetransmission bool + //nolint:exhaustive // Initial and Handshake are the only two encryption levels here. + switch encLevel { + case protocol.EncryptionInitial: + s = p.initialStream + hasRetransmission = p.retransmissionQueue.HasInitialData() + case protocol.EncryptionHandshake: + s = p.handshakeStream + hasRetransmission = p.retransmissionQueue.HasHandshakeData() + } + + hasData := s.HasData() + var ack *wire.AckFrame + if encLevel == protocol.EncryptionInitial || currentSize == 0 { + ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData) + } + if !hasData && !hasRetransmission && ack == nil { + // nothing to send + return nil, nil + } + + var payload payload + if ack != nil { + payload.ack = ack + payload.length = ack.Length(p.version) + maxPacketSize -= payload.length + } + hdr := p.getLongHeader(encLevel) + maxPacketSize -= hdr.GetLength(p.version) + if hasRetransmission { + for { + var f wire.Frame + //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s + switch encLevel { + case protocol.EncryptionInitial: + f = p.retransmissionQueue.GetInitialFrame(maxPacketSize) + case protocol.EncryptionHandshake: + f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize) + } + if f == nil { + break + } + payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) + frameLen := f.Length(p.version) + payload.length += frameLen + maxPacketSize -= frameLen + } + } else if s.HasData() { + cf := s.PopCryptoFrame(maxPacketSize) + payload.frames = []ackhandler.Frame{{Frame: cf}} + payload.length += cf.Length(p.version) + } + return hdr, &payload +} + +func (p *packetPacker) maybeGetAppDataPacket(maxPacketSize, currentSize protocol.ByteCount) (sealer, *wire.ExtendedHeader, *payload) { + var sealer sealer + var encLevel protocol.EncryptionLevel + var hdr *wire.ExtendedHeader + oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() + if err == nil { + encLevel = protocol.Encryption1RTT + sealer = oneRTTSealer + hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) + } else { + // 1-RTT sealer not yet available + if p.perspective != protocol.PerspectiveClient { + return nil, nil, nil + } + sealer, err = p.cryptoSetup.Get0RTTSealer() + if sealer == nil || err != nil { + return nil, nil, nil + } + encLevel = protocol.Encryption0RTT + hdr = p.getLongHeader(protocol.Encryption0RTT) + } + + maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) + payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, encLevel == protocol.Encryption1RTT && currentSize == 0) + return sealer, hdr, payload +} + +func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol.ByteCount, ackAllowed bool) *payload { + payload := p.composeNextPacket(maxPayloadSize, ackAllowed) + + // check if we have anything to send + if len(payload.frames) == 0 { + if payload.ack == nil { + return nil + } + // the packet only contains an ACK + if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { + ping := &wire.PingFrame{} + // don't retransmit the PING frame when it is lost + payload.frames = append(payload.frames, ackhandler.Frame{Frame: ping, OnLost: func(wire.Frame) {}}) + payload.length += ping.Length(p.version) + p.numNonAckElicitingAcks = 0 + } else { + p.numNonAckElicitingAcks++ + } + } else { + p.numNonAckElicitingAcks = 0 + } + return payload +} + +func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload { + payload := &payload{frames: make([]ackhandler.Frame, 0, 1)} + + var hasDatagram bool + if p.datagramQueue != nil { + if datagram := p.datagramQueue.Get(); datagram != nil { + payload.frames = append(payload.frames, ackhandler.Frame{ + Frame: datagram, + // set it to a no-op. Then we won't set the default callback, which would retransmit the frame. + OnLost: func(wire.Frame) {}, + }) + payload.length += datagram.Length(p.version) + hasDatagram = true + } + } + + var ack *wire.AckFrame + hasData := p.framer.HasData() + hasRetransmission := p.retransmissionQueue.HasAppData() + // TODO: make sure ACKs are sent when a lot of DATAGRAMs are queued + if !hasDatagram && ackAllowed { + ack = p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData) + if ack != nil { + payload.ack = ack + payload.length += ack.Length(p.version) + } + } + + if ack == nil && !hasData && !hasRetransmission { + return payload + } + + if hasRetransmission { + for { + remainingLen := maxFrameSize - payload.length + if remainingLen < protocol.MinStreamFrameSize { + break + } + f := p.retransmissionQueue.GetAppDataFrame(remainingLen) + if f == nil { + break + } + payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) + payload.length += f.Length(p.version) + } + } + + if hasData { + var lengthAdded protocol.ByteCount + payload.frames, lengthAdded = p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length) + payload.length += lengthAdded + + payload.frames, lengthAdded = p.framer.AppendStreamFrames(payload.frames, maxFrameSize-payload.length) + payload.length += lengthAdded + } + return payload +} + +func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*packedPacket, error) { + var hdr *wire.ExtendedHeader + var payload *payload + var sealer sealer + //nolint:exhaustive // Probe packets are never sent for 0-RTT. + switch encLevel { + case protocol.EncryptionInitial: + var err error + sealer, err = p.cryptoSetup.GetInitialSealer() + if err != nil { + return nil, err + } + hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionInitial) + case protocol.EncryptionHandshake: + var err error + sealer, err = p.cryptoSetup.GetHandshakeSealer() + if err != nil { + return nil, err + } + hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionHandshake) + case protocol.Encryption1RTT: + oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, err + } + sealer = oneRTTSealer + hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) + payload = p.maybeGetAppDataPacketWithEncLevel(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), true) + default: + panic("unknown encryption level") + } + if payload == nil { + return nil, nil + } + size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) + var padding protocol.ByteCount + if encLevel == protocol.EncryptionInitial { + padding = p.initialPaddingLen(payload.frames, size) + } + buffer := getPacketBuffer() + cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer, false) + if err != nil { + return nil, err + } + return &packedPacket{ + buffer: buffer, + packetContents: cont, + }, nil +} + +func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) { + payload := &payload{ + frames: []ackhandler.Frame{ping}, + length: ping.Length(p.version), + } + buffer := getPacketBuffer() + sealer, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, err + } + hdr := p.getShortHeader(sealer.KeyPhase()) + padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead()) + contents, err := p.appendPacket(buffer, hdr, payload, padding, protocol.Encryption1RTT, sealer, true) + if err != nil { + return nil, err + } + contents.isMTUProbePacket = true + return &packedPacket{ + buffer: buffer, + packetContents: contents, + }, nil +} + +func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) { + switch encLevel { + case protocol.EncryptionInitial: + sealer, err := p.cryptoSetup.GetInitialSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getLongHeader(protocol.EncryptionInitial) + return sealer, hdr, nil + case protocol.Encryption0RTT: + sealer, err := p.cryptoSetup.Get0RTTSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getLongHeader(protocol.Encryption0RTT) + return sealer, hdr, nil + case protocol.EncryptionHandshake: + sealer, err := p.cryptoSetup.GetHandshakeSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getLongHeader(protocol.EncryptionHandshake) + return sealer, hdr, nil + case protocol.Encryption1RTT: + sealer, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getShortHeader(sealer.KeyPhase()) + return sealer, hdr, nil + default: + return nil, nil, fmt.Errorf("unexpected encryption level: %s", encLevel) + } +} + +func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { + pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + hdr := &wire.ExtendedHeader{} + hdr.PacketNumber = pn + hdr.PacketNumberLen = pnLen + hdr.DestConnectionID = p.getDestConnID() + hdr.KeyPhase = kp + return hdr +} + +func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { + pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) + hdr := &wire.ExtendedHeader{ + PacketNumber: pn, + PacketNumberLen: pnLen, + } + hdr.IsLongHeader = true + hdr.Version = p.version + hdr.SrcConnectionID = p.srcConnID + hdr.DestConnectionID = p.getDestConnID() + + //nolint:exhaustive // 1-RTT packets are not long header packets. + switch encLevel { + case protocol.EncryptionInitial: + hdr.Type = protocol.PacketTypeInitial + hdr.Token = p.token + case protocol.EncryptionHandshake: + hdr.Type = protocol.PacketTypeHandshake + case protocol.Encryption0RTT: + hdr.Type = protocol.PacketType0RTT + } + return hdr +} + +// writeSinglePacket packs a single packet. +func (p *packetPacker) writeSinglePacket( + hdr *wire.ExtendedHeader, + payload *payload, + encLevel protocol.EncryptionLevel, + sealer sealer, +) (*packedPacket, error) { + buffer := getPacketBuffer() + var paddingLen protocol.ByteCount + if encLevel == protocol.EncryptionInitial { + paddingLen = p.initialPaddingLen(payload.frames, hdr.GetLength(p.version)+payload.length+protocol.ByteCount(sealer.Overhead())) + } + contents, err := p.appendPacket(buffer, hdr, payload, paddingLen, encLevel, sealer, false) + if err != nil { + return nil, err + } + return &packedPacket{ + buffer: buffer, + packetContents: contents, + }, nil +} + +func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, isMTUProbePacket bool) (*packetContents, error) { + var paddingLen protocol.ByteCount + pnLen := protocol.ByteCount(header.PacketNumberLen) + if payload.length < 4-pnLen { + paddingLen = 4 - pnLen - payload.length + } + paddingLen += padding + if header.IsLongHeader { + header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen + } + + hdrOffset := buffer.Len() + buf := bytes.NewBuffer(buffer.Data) + if err := header.Write(buf, p.version); err != nil { + return nil, err + } + payloadOffset := buf.Len() + + if payload.ack != nil { + if err := payload.ack.Write(buf, p.version); err != nil { + return nil, err + } + } + if paddingLen > 0 { + buf.Write(make([]byte, paddingLen)) + } + for _, frame := range payload.frames { + if err := frame.Write(buf, p.version); err != nil { + return nil, err + } + } + + if payloadSize := protocol.ByteCount(buf.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { + return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) + } + if !isMTUProbePacket { + if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize { + return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + } + } + + raw := buffer.Data + // encrypt the packet + raw = raw[:buf.Len()] + _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[hdrOffset:payloadOffset]) + raw = raw[0 : buf.Len()+sealer.Overhead()] + // apply header protection + pnOffset := payloadOffset - int(header.PacketNumberLen) + sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[hdrOffset], raw[pnOffset:payloadOffset]) + buffer.Data = raw + + num := p.pnManager.PopPacketNumber(encLevel) + if num != header.PacketNumber { + return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") + } + return &packetContents{ + header: header, + ack: payload.ack, + frames: payload.frames, + length: buffer.Len() - hdrOffset, + }, nil +} + +func (p *packetPacker) SetToken(token []byte) { + p.token = token +} + +// When a higher MTU is discovered, use it. +func (p *packetPacker) SetMaxPacketSize(s protocol.ByteCount) { + p.maxPacketSize = s +} + +// If the peer sets a max_packet_size that's smaller than the size we're currently using, +// we need to reduce the size of packets we send. +func (p *packetPacker) HandleTransportParameters(params *wire.TransportParameters) { + if params.MaxUDPPayloadSize != 0 { + p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxUDPPayloadSize) + } +} diff --git a/internal/quic-go/packet_packer_test.go b/internal/quic-go/packet_packer_test.go new file mode 100644 index 00000000..d069d0cc --- /dev/null +++ b/internal/quic-go/packet_packer_test.go @@ -0,0 +1,1556 @@ +package quic + +import ( + "bytes" + "fmt" + "math/rand" + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go/ackhandler" + "github.com/imroc/req/v3/internal/quic-go/handshake" + "github.com/imroc/req/v3/internal/quic-go/mocks" + mockackhandler "github.com/imroc/req/v3/internal/quic-go/mocks/ackhandler" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Packet packer", func() { + const maxPacketSize protocol.ByteCount = 1357 + const version = protocol.VersionTLS + + var ( + packer *packetPacker + retransmissionQueue *retransmissionQueue + datagramQueue *datagramQueue + framer *MockFrameSource + ackFramer *MockAckFrameSource + initialStream *MockCryptoStream + handshakeStream *MockCryptoStream + sealingManager *MockSealingManager + pnManager *mockackhandler.MockSentPacketHandler + ) + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + + parsePacket := func(data []byte) []*wire.ExtendedHeader { + var hdrs []*wire.ExtendedHeader + for len(data) > 0 { + hdr, payload, rest, err := wire.ParsePacket(data, connID.Len()) + Expect(err).ToNot(HaveOccurred()) + r := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(r, version) + Expect(err).ToNot(HaveOccurred()) + if extHdr.IsLongHeader { + ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() - len(rest) + int(extHdr.PacketNumberLen))) + ExpectWithOffset(1, extHdr.Length+protocol.ByteCount(extHdr.PacketNumberLen)).To(BeNumerically(">=", 4)) + } else { + ExpectWithOffset(1, len(payload)+int(extHdr.PacketNumberLen)).To(BeNumerically(">=", 4)) + } + data = rest + hdrs = append(hdrs, extHdr) + } + return hdrs + } + + appendFrames := func(fs, frames []ackhandler.Frame) ([]ackhandler.Frame, protocol.ByteCount) { + var length protocol.ByteCount + for _, f := range frames { + length += f.Frame.Length(packer.version) + } + return append(fs, frames...), length + } + + expectAppendStreamFrames := func(frames ...ackhandler.Frame) { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + return appendFrames(fs, frames) + }) + } + + expectAppendControlFrames := func(frames ...ackhandler.Frame) { + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + return appendFrames(fs, frames) + }) + } + + BeforeEach(func() { + rand.Seed(GinkgoRandomSeed()) + retransmissionQueue = newRetransmissionQueue(version) + mockSender := NewMockStreamSender(mockCtrl) + mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() + initialStream = NewMockCryptoStream(mockCtrl) + handshakeStream = NewMockCryptoStream(mockCtrl) + framer = NewMockFrameSource(mockCtrl) + ackFramer = NewMockAckFrameSource(mockCtrl) + sealingManager = NewMockSealingManager(mockCtrl) + pnManager = mockackhandler.NewMockSentPacketHandler(mockCtrl) + datagramQueue = newDatagramQueue(func() {}, utils.DefaultLogger) + + packer = newPacketPacker( + protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + func() protocol.ConnectionID { return connID }, + initialStream, + handshakeStream, + pnManager, + retransmissionQueue, + &net.TCPAddr{}, + sealingManager, + framer, + ackFramer, + datagramQueue, + protocol.PerspectiveServer, + version, + ) + packer.version = version + packer.maxPacketSize = maxPacketSize + }) + + Context("determining the maximum packet size", func() { + It("uses the minimum initial size, if it can't determine if the remote address is IPv4 or IPv6", func() { + Expect(getMaxPacketSize(&net.TCPAddr{})).To(BeEquivalentTo(protocol.MinInitialPacketSize)) + }) + + It("uses the maximum IPv4 packet size, if the remote address is IPv4", func() { + addr := &net.UDPAddr{IP: net.IPv4(11, 12, 13, 14), Port: 1337} + Expect(getMaxPacketSize(addr)).To(BeEquivalentTo(protocol.InitialPacketSizeIPv4)) + }) + + It("uses the maximum IPv6 packet size, if the remote address is IPv6", func() { + ip := net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334") + addr := &net.UDPAddr{IP: ip, Port: 1337} + Expect(getMaxPacketSize(addr)).To(BeEquivalentTo(protocol.InitialPacketSizeIPv6)) + }) + }) + + Context("generating a packet header", func() { + It("uses the Long Header format", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen3) + h := packer.getLongHeader(protocol.EncryptionHandshake) + Expect(h.IsLongHeader).To(BeTrue()) + Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) + Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen3)) + Expect(h.Version).To(Equal(packer.version)) + }) + + It("sets source and destination connection ID", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + packer.srcConnID = srcConnID + packer.getDestConnID = func() protocol.ConnectionID { return destConnID } + h := packer.getLongHeader(protocol.EncryptionHandshake) + Expect(h.SrcConnectionID).To(Equal(srcConnID)) + Expect(h.DestConnectionID).To(Equal(destConnID)) + }) + + It("gets a short header", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen4) + h := packer.getShortHeader(protocol.KeyPhaseOne) + Expect(h.IsLongHeader).To(BeFalse()) + Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) + Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(h.KeyPhase).To(Equal(protocol.KeyPhaseOne)) + }) + }) + + Context("encrypting packets", func() { + It("encrypts a packet", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337)) + sealer := mocks.NewMockShortHeaderSealer(mockCtrl) + sealer.EXPECT().Overhead().Return(4).AnyTimes() + var hdrRaw []byte + gomock.InOrder( + sealer.EXPECT().KeyPhase().Return(protocol.KeyPhaseOne), + sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1337), gomock.Any()).DoAndReturn(func(_, src []byte, _ protocol.PacketNumber, aad []byte) []byte { + hdrRaw = append([]byte{}, aad...) + return append(src, []byte{0xde, 0xca, 0xfb, 0xad}...) + }), + sealer.EXPECT().EncryptHeader(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(sample []byte, firstByte *byte, pnBytes []byte) { + Expect(firstByte).To(Equal(&hdrRaw[0])) + Expect(pnBytes).To(Equal(hdrRaw[len(hdrRaw)-2:])) + *firstByte ^= 0xff // invert the first byte + // invert the packet number bytes + for i := range pnBytes { + pnBytes[i] ^= 0xff + } + }), + ) + framer.EXPECT().HasData().Return(true) + sealingManager.EXPECT().GetInitialSealer().Return(nil, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + expectAppendControlFrames() + f := &wire.StreamFrame{Data: []byte{0xde, 0xca, 0xfb, 0xad}} + expectAppendStreamFrames(ackhandler.Frame{Frame: f}) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{{Frame: f}})) + hdrRawEncrypted := append([]byte{}, hdrRaw...) + hdrRawEncrypted[0] ^= 0xff + hdrRawEncrypted[len(hdrRaw)-2] ^= 0xff + hdrRawEncrypted[len(hdrRaw)-1] ^= 0xff + Expect(p.buffer.Data[0:len(hdrRaw)]).To(Equal(hdrRawEncrypted)) + Expect(p.buffer.Data[p.buffer.Len()-4:]).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) + }) + }) + + Context("packing packets", func() { + // getSealer gets a sealer that's expected to seal exactly one packet + getSealer := func() *mocks.MockShortHeaderSealer { + sealer := mocks.NewMockShortHeaderSealer(mockCtrl) + sealer.EXPECT().KeyPhase().Return(protocol.KeyPhaseOne).AnyTimes() + sealer.EXPECT().Overhead().Return(7).AnyTimes() + sealer.EXPECT().EncryptHeader(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte { + return append(src, bytes.Repeat([]byte{'s'}, sealer.Overhead())...) + }).AnyTimes() + return sealer + } + + Context("packing ACK packets", func() { + It("doesn't pack a packet if there's no ACK to send", func() { + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) + p, err := packer.MaybePackAckPacket(false) + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + }) + + It("packs Initial ACK-only packets, and pads them (for the client)", func() { + packer.perspective = protocol.PerspectiveClient + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).Return(ack) + p, err := packer.MaybePackAckPacket(false) + Expect(err).NotTo(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.ack).To(Equal(ack)) + Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + parsePacket(p.buffer.Data) + }) + + It("packs Initial ACK-only packets, and doesn't pads them (for the server)", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).Return(ack) + p, err := packer.MaybePackAckPacket(false) + Expect(err).NotTo(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.ack).To(Equal(ack)) + parsePacket(p.buffer.Data) + }) + + It("packs 1-RTT ACK-only packets", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) + p, err := packer.MaybePackAckPacket(true) + Expect(err).NotTo(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(p.ack).To(Equal(ack)) + parsePacket(p.buffer.Data) + }) + }) + + Context("packing 0-RTT packets", func() { + BeforeEach(func() { + packer.perspective = protocol.PerspectiveClient + sealingManager.EXPECT().GetInitialSealer().Return(nil, nil).AnyTimes() + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil).AnyTimes() + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable).AnyTimes() + initialStream.EXPECT().HasData().AnyTimes() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).AnyTimes() + handshakeStream.EXPECT().HasData().AnyTimes() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true).AnyTimes() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).AnyTimes() + }) + + It("packs a 0-RTT packet", func() { + sealingManager.EXPECT().Get0RTTSealer().Return(getSealer(), nil).AnyTimes() + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42)) + cf := ackhandler.Frame{Frame: &wire.MaxDataFrame{MaximumData: 0x1337}} + framer.EXPECT().HasData().Return(true) + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + Expect(frames).To(BeEmpty()) + return append(frames, cf), cf.Length(packer.version) + }) + // TODO: check sizes + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + return frames, 0 + }) + p, err := packer.PackCoalescedPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].header.Type).To(Equal(protocol.PacketType0RTT)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) + Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{cf})) + }) + }) + + Context("packing CONNECTION_CLOSE", func() { + It("clears the reason phrase for crypto errors", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + quicErr := qerr.NewCryptoError(0x42, "crypto error") + quicErr.FrameType = 0x1234 + p, err := packer.PackConnectionClose(quicErr) + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeHandshake)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(0x100 + 0x42)) + Expect(ccf.FrameType).To(BeEquivalentTo(0x1234)) + Expect(ccf.ReasonPhrase).To(BeEmpty()) + }) + + It("packs a CONNECTION_CLOSE in 1-RTT", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + // expect no framer.PopStreamFrames + p, err := packer.PackConnectionClose(&qerr.TransportError{ + ErrorCode: qerr.CryptoBufferExceeded, + ErrorMessage: "test error", + }) + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].header.IsLongHeader).To(BeFalse()) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.CryptoBufferExceeded)) + Expect(ccf.ReasonPhrase).To(Equal("test error")) + }) + + It("packs a CONNECTION_CLOSE in all available encryption levels, and replaces application errors in Initial and Handshake", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(1)) + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(2), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(2)) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(3), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(3)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + p, err := packer.PackApplicationClose(&qerr.ApplicationError{ + ErrorCode: 0x1337, + ErrorMessage: "test error", + }) + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(3)) + Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(p.packets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) + Expect(ccf.ReasonPhrase).To(BeEmpty()) + Expect(p.packets[1].header.Type).To(Equal(protocol.PacketTypeHandshake)) + Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf = p.packets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) + Expect(ccf.ReasonPhrase).To(BeEmpty()) + Expect(p.packets[2].header.IsLongHeader).To(BeFalse()) + Expect(p.packets[2].header.PacketNumber).To(Equal(protocol.PacketNumber(3))) + Expect(p.packets[2].frames).To(HaveLen(1)) + Expect(p.packets[2].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf = p.packets[2].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeTrue()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(0x1337)) + Expect(ccf.ReasonPhrase).To(Equal("test error")) + }) + + It("packs a CONNECTION_CLOSE in all available encryption levels, as a client", func() { + packer.perspective = protocol.PerspectiveClient + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(1)) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(2), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(2)) + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + p, err := packer.PackApplicationClose(&qerr.ApplicationError{ + ErrorCode: 0x1337, + ErrorMessage: "test error", + }) + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(2)) + Expect(p.buffer.Len()).To(BeNumerically("<", protocol.MinInitialPacketSize)) + Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeHandshake)) + Expect(p.packets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) + Expect(ccf.ReasonPhrase).To(BeEmpty()) + Expect(p.packets[1].header.IsLongHeader).To(BeFalse()) + Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf = p.packets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeTrue()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(0x1337)) + Expect(ccf.ReasonPhrase).To(Equal("test error")) + }) + + It("packs a CONNECTION_CLOSE in all available encryption levels and pads, as a client", func() { + packer.perspective = protocol.PerspectiveClient + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(1)) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(2), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(2)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get0RTTSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + p, err := packer.PackApplicationClose(&qerr.ApplicationError{ + ErrorCode: 0x1337, + ErrorMessage: "test error", + }) + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(2)) + Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) + Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) + Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(p.packets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) + Expect(ccf.ReasonPhrase).To(BeEmpty()) + Expect(p.packets[1].header.Type).To(Equal(protocol.PacketType0RTT)) + Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf = p.packets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeTrue()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(0x1337)) + Expect(ccf.ReasonPhrase).To(Equal("test error")) + hdrs := parsePacket(p.buffer.Data) + Expect(hdrs).To(HaveLen(2)) + Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) + Expect(hdrs[1].Type).To(Equal(protocol.PacketType0RTT)) + }) + }) + + Context("packing normal packets", func() { + It("returns nil when no packet is queued", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + // don't expect any calls to PopPacketNumber + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) + framer.EXPECT().HasData() + p, err := packer.PackPacket() + Expect(p).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("packs single packets", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + expectAppendControlFrames() + f := &wire.StreamFrame{ + StreamID: 5, + Data: []byte{0xde, 0xca, 0xfb, 0xad}, + } + expectAppendStreamFrames(ackhandler.Frame{Frame: f}) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).ToNot(BeNil()) + b := &bytes.Buffer{} + f.Write(b, packer.version) + Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: f}})) + Expect(p.buffer.Data).To(ContainSubstring(b.String())) + }) + + It("stores the encryption level a packet was sealed with", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + expectAppendControlFrames() + expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{ + StreamID: 5, + Data: []byte("foobar"), + }}) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + }) + + It("packs a single ACK", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} + framer.EXPECT().HasData() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + p, err := packer.PackPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.ack).To(Equal(ack)) + }) + + It("packs control frames", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + frames := []ackhandler.Frame{ + {Frame: &wire.ResetStreamFrame{}}, + {Frame: &wire.MaxDataFrame{}}, + } + expectAppendControlFrames(frames...) + expectAppendStreamFrames() + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(Equal(frames)) + Expect(p.buffer.Len()).ToNot(BeZero()) + }) + + It("packs DATAGRAM frames", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + f := &wire.DatagramFrame{ + DataLenPresent: true, + Data: []byte("foobar"), + } + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + datagramQueue.AddAndWait(f) + }() + // make sure the DATAGRAM has actually been queued + time.Sleep(scaleDuration(20 * time.Millisecond)) + + framer.EXPECT().HasData() + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + Expect(p.frames[0].Frame).To(Equal(f)) + Expect(p.buffer.Data).ToNot(BeEmpty()) + Eventually(done).Should(BeClosed()) + }) + + It("accounts for the space consumed by control frames", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + var maxSize protocol.ByteCount + gomock.InOrder( + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + maxSize = maxLen + return fs, 444 + }), + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Do(func(fs []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + Expect(maxLen).To(Equal(maxSize - 444)) + return fs, 0 + }), + ) + _, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + }) + + It("pads if payload length + packet number length is smaller than 4, for Long Header packets", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + sealer := getSealer() + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) + handshakeStream.EXPECT().HasData() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + packet, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.packets).To(HaveLen(1)) + // cut off the tag that the mock sealer added + // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] + hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) + Expect(err).ToNot(HaveOccurred()) + r := bytes.NewReader(packet.buffer.Data) + extHdr, err := hdr.ParseExtended(r, packer.version) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) + Expect(r.Len()).To(Equal(4 - 1 /* packet number length */ + sealer.Overhead())) + // the first bytes of the payload should be a 2 PADDING frames... + firstPayloadByte, err := r.ReadByte() + Expect(err).ToNot(HaveOccurred()) + Expect(firstPayloadByte).To(Equal(byte(0))) + secondPayloadByte, err := r.ReadByte() + Expect(err).ToNot(HaveOccurred()) + Expect(secondPayloadByte).To(Equal(byte(0))) + // ... followed by the PING + frameParser := wire.NewFrameParser(false, packer.version) + frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + Expect(r.Len()).To(Equal(sealer.Overhead())) + }) + + It("pads if payload length + packet number length is smaller than 4", func() { + f := &wire.StreamFrame{ + StreamID: 0x10, // small stream ID, such that only a single byte is consumed + Fin: true, + } + Expect(f.Length(packer.version)).To(BeEquivalentTo(2)) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealer := getSealer() + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + expectAppendControlFrames() + expectAppendStreamFrames(ackhandler.Frame{Frame: f}) + packet, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + // cut off the tag that the mock sealer added + packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] + hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) + Expect(err).ToNot(HaveOccurred()) + r := bytes.NewReader(packet.buffer.Data) + extHdr, err := hdr.ParseExtended(r, packer.version) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) + Expect(r.Len()).To(Equal(4 - 1 /* packet number length */)) + // the first byte of the payload should be a PADDING frame... + firstPayloadByte, err := r.ReadByte() + Expect(err).ToNot(HaveOccurred()) + Expect(firstPayloadByte).To(Equal(byte(0))) + // ... followed by the STREAM frame + frameParser := wire.NewFrameParser(true, packer.version) + frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + sf := frame.(*wire.StreamFrame) + Expect(sf.StreamID).To(Equal(f.StreamID)) + Expect(sf.Fin).To(Equal(f.Fin)) + Expect(sf.Data).To(BeEmpty()) + Expect(r.Len()).To(BeZero()) + }) + + It("packs multiple small STREAM frames into single packet", func() { + f1 := &wire.StreamFrame{ + StreamID: 5, + Data: []byte("frame 1"), + DataLenPresent: true, + } + f2 := &wire.StreamFrame{ + StreamID: 5, + Data: []byte("frame 2"), + DataLenPresent: true, + } + f3 := &wire.StreamFrame{ + StreamID: 3, + Data: []byte("frame 3"), + DataLenPresent: true, + } + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + expectAppendControlFrames() + expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3}) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(3)) + Expect(p.frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) + Expect(p.frames[1].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) + Expect(p.frames[2].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) + }) + + Context("making ACK packets ack-eliciting", func() { + sendMaxNumNonAckElicitingAcks := func() { + for i := 0; i < protocol.MaxNonAckElicitingAcks; i++ { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + expectAppendControlFrames() + expectAppendStreamFrames() + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.ack).ToNot(BeNil()) + Expect(p.frames).To(BeEmpty()) + } + } + + It("adds a PING frame when it's supposed to send a ack-eliciting packet", func() { + sendMaxNumNonAckElicitingAcks() + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + expectAppendControlFrames() + expectAppendStreamFrames() + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + var hasPing bool + for _, f := range p.frames { + if _, ok := f.Frame.(*wire.PingFrame); ok { + hasPing = true + Expect(f.OnLost).ToNot(BeNil()) // make sure the PING is not retransmitted if lost + } + } + Expect(hasPing).To(BeTrue()) + // make sure the next packet doesn't contain another PING + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + expectAppendControlFrames() + expectAppendStreamFrames() + p, err = packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.ack).ToNot(BeNil()) + Expect(p.frames).To(BeEmpty()) + }) + + It("waits until there's something to send before adding a PING frame", func() { + sendMaxNumNonAckElicitingAcks() + // nothing to send + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + expectAppendControlFrames() + expectAppendStreamFrames() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + // now add some frame to send + expectAppendControlFrames() + expectAppendStreamFrames() + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(ack) + p, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.ack).To(Equal(ack)) + var hasPing bool + for _, f := range p.frames { + if _, ok := f.Frame.(*wire.PingFrame); ok { + hasPing = true + Expect(f.OnLost).ToNot(BeNil()) // make sure the PING is not retransmitted if lost + } + } + Expect(hasPing).To(BeTrue()) + }) + + It("doesn't send a PING if it already sent another ack-eliciting frame", func() { + sendMaxNumNonAckElicitingAcks() + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + expectAppendStreamFrames() + expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) + }) + }) + + Context("handling transport parameters", func() { + It("lowers the maximum packet size", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) + framer.EXPECT().HasData().Return(true).Times(2) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Times(2) + var initialMaxPacketSize protocol.ByteCount + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + initialMaxPacketSize = maxLen + return nil, 0 + }) + expectAppendStreamFrames() + _, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + // now reduce the maxPacketSize + packer.HandleTransportParameters(&wire.TransportParameters{ + MaxUDPPayloadSize: maxPacketSize - 10, + }) + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + Expect(maxLen).To(Equal(initialMaxPacketSize - 10)) + return nil, 0 + }) + expectAppendStreamFrames() + _, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + }) + + It("doesn't increase the max packet size", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) + framer.EXPECT().HasData().Return(true).Times(2) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Times(2) + var initialMaxPacketSize protocol.ByteCount + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + initialMaxPacketSize = maxLen + return nil, 0 + }) + expectAppendStreamFrames() + _, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + // now try to increase the maxPacketSize + packer.HandleTransportParameters(&wire.TransportParameters{ + MaxUDPPayloadSize: maxPacketSize + 10, + }) + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + Expect(maxLen).To(Equal(initialMaxPacketSize)) + return nil, 0 + }) + expectAppendStreamFrames() + _, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Context("max packet size", func() { + It("increases the max packet size", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) + framer.EXPECT().HasData().Return(true).Times(2) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Times(2) + var initialMaxPacketSize protocol.ByteCount + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + initialMaxPacketSize = maxLen + return nil, 0 + }) + expectAppendStreamFrames() + _, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + // now reduce the maxPacketSize + const packetSizeIncrease = 50 + packer.SetMaxPacketSize(maxPacketSize + packetSizeIncrease) + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + Expect(maxLen).To(Equal(initialMaxPacketSize + packetSizeIncrease)) + return nil, 0 + }) + expectAppendStreamFrames() + _, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + }) + }) + }) + + Context("packing crypto packets", func() { + It("sets the length", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + f := &wire.CryptoFrame{ + Offset: 0x1337, + Data: []byte("foobar"), + } + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + handshakeStream.EXPECT().HasData().Return(true).AnyTimes() + handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).ToNot(BeNil()) + parsePacket(p.buffer.Data) + }) + + It("packs an Initial packet and pads it", func() { + packer.perspective = protocol.PerspectiveClient + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) + initialStream.EXPECT().HasData().Return(true).Times(2) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} + }) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) + Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + hdrs := parsePacket(p.buffer.Data) + Expect(hdrs).To(HaveLen(1)) + Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) + }) + + It("packs a maximum size Handshake packet", func() { + var f *wire.CryptoFrame + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + initialStream.EXPECT().HasData() + handshakeStream.EXPECT().HasData().Return(true).Times(2) + handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + f = &wire.CryptoFrame{Offset: 0x1337} + f.Data = bytes.Repeat([]byte{'f'}, int(size-f.Length(packer.version)-1)) + Expect(f.Length(packer.version)).To(Equal(size)) + return f + }) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].header.IsLongHeader).To(BeTrue()) + Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + parsePacket(p.buffer.Data) + }) + + It("packs a coalesced packet with Initial / Handshake, and pads it", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) + // don't EXPECT any calls for a Handshake ACK frame + initialStream.EXPECT().HasData().Return(true).Times(2) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} + }) + handshakeStream.EXPECT().HasData().Return(true).Times(2) + handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")} + }) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + Expect(p.packets).To(HaveLen(2)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) + hdrs := parsePacket(p.buffer.Data) + Expect(hdrs).To(HaveLen(2)) + Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) + Expect(hdrs[1].Type).To(Equal(protocol.PacketTypeHandshake)) + }) + + It("packs a coalesced packet with Initial / super short Handshake, and pads it", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) + // don't EXPECT any calls for a Handshake ACK frame + initialStream.EXPECT().HasData().Return(true).Times(2) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} + }) + handshakeStream.EXPECT().HasData() + packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + Expect(p.packets).To(HaveLen(2)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + hdrs := parsePacket(p.buffer.Data) + Expect(hdrs).To(HaveLen(2)) + Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) + Expect(hdrs[1].Type).To(Equal(protocol.PacketTypeHandshake)) + }) + + It("packs a coalesced packet with super short Initial / super short Handshake, and pads it", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, gomock.Any()) + initialStream.EXPECT().HasData() + handshakeStream.EXPECT().HasData() + packer.retransmissionQueue.AddInitial(&wire.PingFrame{}) + packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + Expect(p.packets).To(HaveLen(2)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + hdrs := parsePacket(p.buffer.Data) + Expect(hdrs).To(HaveLen(2)) + Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) + Expect(hdrs[1].Type).To(Equal(protocol.PacketTypeHandshake)) + }) + + It("packs a coalesced packet with Initial / super short 1-RTT, and pads it", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) + initialStream.EXPECT().HasData().Return(true).Times(2) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} + }) + expectAppendControlFrames() + expectAppendStreamFrames() + framer.EXPECT().HasData().Return(true) + packer.retransmissionQueue.AddAppData(&wire.PingFrame{}) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + Expect(p.packets).To(HaveLen(2)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + hdrs := parsePacket(p.buffer.Data) + Expect(hdrs).To(HaveLen(2)) + Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) + Expect(hdrs[1].IsLongHeader).To(BeFalse()) + }) + + It("packs a coalesced packet with Initial / 0-RTT, and pads it", func() { + packer.perspective = protocol.PerspectiveClient + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get0RTTSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) + // don't EXPECT any calls for a Handshake ACK frame + initialStream.EXPECT().HasData().Return(true).Times(2) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} + }) + expectAppendControlFrames() + expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) + Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) + Expect(p.packets).To(HaveLen(2)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + hdrs := parsePacket(p.buffer.Data) + Expect(hdrs).To(HaveLen(2)) + Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) + Expect(hdrs[1].Type).To(Equal(protocol.PacketType0RTT)) + }) + + It("packs a coalesced packet with Handshake / 1-RTT", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24)) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + // don't EXPECT any calls for a 1-RTT ACK frame + handshakeStream.EXPECT().HasData().Return(true).Times(2) + handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")} + }) + expectAppendControlFrames() + expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.buffer.Len()).To(BeNumerically("<", 100)) + Expect(p.packets).To(HaveLen(2)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) + Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + hdr, _, rest, err := wire.ParsePacket(p.buffer.Data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) + hdr, _, rest, err = wire.ParsePacket(rest, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsLongHeader).To(BeFalse()) + Expect(rest).To(BeEmpty()) + }) + + It("doesn't add a coalesced packet if the remaining size is smaller than MaxCoalescedPacketSize", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24)) + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + // don't EXPECT any calls to GetHandshakeSealer and Get1RTTSealer + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + handshakeStream.EXPECT().HasData().Return(true).Times(2) + handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + s := size - protocol.MinCoalescedPacketSize + f := &wire.CryptoFrame{Offset: 0x1337} + f.Data = bytes.Repeat([]byte{'f'}, int(s-f.Length(packer.version)-1)) + Expect(f.Length(packer.version)).To(Equal(s)) + return f + }) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(len(p.buffer.Data)).To(BeEquivalentTo(maxPacketSize - protocol.MinCoalescedPacketSize)) + parsePacket(p.buffer.Data) + }) + + It("pads if payload length + packet number length is smaller than 4, for Long Header packets", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + sealer := getSealer() + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) + handshakeStream.EXPECT().HasData() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + packet, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.packets).To(HaveLen(1)) + // cut off the tag that the mock sealer added + // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] + hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) + Expect(err).ToNot(HaveOccurred()) + r := bytes.NewReader(packet.buffer.Data) + extHdr, err := hdr.ParseExtended(r, packer.version) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) + Expect(r.Len()).To(Equal(4 - 1 /* packet number length */ + sealer.Overhead())) + // the first bytes of the payload should be a 2 PADDING frames... + firstPayloadByte, err := r.ReadByte() + Expect(err).ToNot(HaveOccurred()) + Expect(firstPayloadByte).To(Equal(byte(0))) + secondPayloadByte, err := r.ReadByte() + Expect(err).ToNot(HaveOccurred()) + Expect(secondPayloadByte).To(Equal(byte(0))) + // ... followed by the PING + frameParser := wire.NewFrameParser(false, packer.version) + frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + Expect(r.Len()).To(Equal(sealer.Overhead())) + }) + + It("adds retransmissions", func() { + f := &wire.CryptoFrame{Data: []byte("Initial")} + retransmissionQueue.AddInitial(f) + retransmissionQueue.AddHandshake(&wire.CryptoFrame{Data: []byte("Handshake")}) + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) + initialStream.EXPECT().HasData() + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{{Frame: f}})) + Expect(p.packets[0].header.IsLongHeader).To(BeTrue()) + }) + + It("sends an Initial packet containing only an ACK", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).Return(ack) + initialStream.EXPECT().HasData().Times(2) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].ack).To(Equal(ack)) + }) + + It("doesn't pack anything if there's nothing to send at Initial and Handshake keys are not yet available", func() { + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + initialStream.EXPECT().HasData() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + }) + + It("sends a Handshake packet containing only an ACK", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true).Return(ack) + initialStream.EXPECT().HasData() + handshakeStream.EXPECT().HasData().Times(2) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].ack).To(Equal(ack)) + }) + + for _, pers := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} { + perspective := pers + + It(fmt.Sprintf("pads Initial packets to the required minimum packet size, for the %s", perspective), func() { + token := []byte("initial token") + packer.SetToken(token) + f := &wire.CryptoFrame{Data: []byte("foobar")} + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) + initialStream.EXPECT().HasData().Return(true).Times(2) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) + packer.perspective = protocol.PerspectiveClient + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) + Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].header.Token).To(Equal(token)) + Expect(p.packets[0].frames).To(HaveLen(1)) + cf := p.packets[0].frames[0].Frame.(*wire.CryptoFrame) + Expect(cf.Data).To(Equal([]byte("foobar"))) + }) + } + + It("adds an ACK frame", func() { + f := &wire.CryptoFrame{Data: []byte("foobar")} + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 42, Largest: 1337}}} + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false).Return(ack) + initialStream.EXPECT().HasData().Return(true).Times(2) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) + packer.version = protocol.VersionTLS + packer.perspective = protocol.PerspectiveClient + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].ack).To(Equal(ack)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) + }) + }) + + Context("packing probe packets", func() { + for _, pers := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} { + perspective := pers + + It(fmt.Sprintf("packs an Initial probe packet and pads it, for the %s", perspective), func() { + packer.perspective = perspective + f := &wire.CryptoFrame{Data: []byte("Initial")} + retransmissionQueue.AddInitial(f) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) + initialStream.EXPECT().HasData() + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) + + packet, err := packer.MaybePackProbePacket(protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(packet.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) + Expect(packet.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(Equal(f)) + parsePacket(packet.buffer.Data) + }) + + It(fmt.Sprintf("packs an Initial probe packet with 1 byte payload, for the %s", perspective), func() { + packer.perspective = perspective + retransmissionQueue.AddInitial(&wire.PingFrame{}) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) + initialStream.EXPECT().HasData() + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) + + packet, err := packer.MaybePackProbePacket(protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(packet.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) + Expect(packet.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + parsePacket(packet.buffer.Data) + }) + } + + It("packs a Handshake probe packet", func() { + f := &wire.CryptoFrame{Data: []byte("Handshake")} + retransmissionQueue.AddHandshake(f) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + handshakeStream.EXPECT().HasData() + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + + packet, err := packer.MaybePackProbePacket(protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(Equal(f)) + parsePacket(packet.buffer.Data) + }) + + It("packs a full size Handshake probe packet", func() { + f := &wire.CryptoFrame{Data: make([]byte, 2000)} + retransmissionQueue.AddHandshake(f) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + handshakeStream.EXPECT().HasData() + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + + packet, err := packer.MaybePackProbePacket(protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) + Expect(packet.length).To(Equal(maxPacketSize)) + parsePacket(packet.buffer.Data) + }) + + It("packs a 1-RTT probe packet", func() { + f := &wire.StreamFrame{Data: []byte("1-RTT")} + retransmissionQueue.AddInitial(f) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + framer.EXPECT().HasData().Return(true) + expectAppendControlFrames() + expectAppendStreamFrames(ackhandler.Frame{Frame: f}) + + packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(Equal(f)) + }) + + It("packs a full size 1-RTT probe packet", func() { + f := &wire.StreamFrame{Data: make([]byte, 2000)} + retransmissionQueue.AddInitial(f) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + framer.EXPECT().HasData().Return(true) + expectAppendControlFrames() + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, maxSize protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + sf, split := f.MaybeSplitOffFrame(maxSize, packer.version) + Expect(split).To(BeTrue()) + return append(fs, ackhandler.Frame{Frame: sf}), sf.Length(packer.version) + }) + + packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + Expect(packet.length).To(Equal(maxPacketSize)) + }) + + It("returns nil if there's no probe data to send", func() { + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) + framer.EXPECT().HasData() + + packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).To(BeNil()) + }) + + It("packs an MTU probe packet", func() { + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43)) + ping := ackhandler.Frame{Frame: &wire.PingFrame{}} + const probePacketSize = maxPacketSize + 42 + p, err := packer.PackMTUProbePacket(ping, probePacketSize) + Expect(err).ToNot(HaveOccurred()) + Expect(p.length).To(BeEquivalentTo(probePacketSize)) + Expect(p.header.IsLongHeader).To(BeFalse()) + Expect(p.header.PacketNumber).To(Equal(protocol.PacketNumber(0x43))) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(p.buffer.Data).To(HaveLen(int(probePacketSize))) + Expect(p.packetContents.isMTUProbePacket).To(BeTrue()) + }) + }) + }) +}) + +var _ = Describe("Converting to AckHandler packets", func() { + It("convert a packet", func() { + packet := &packetContents{ + header: &wire.ExtendedHeader{Header: wire.Header{}}, + frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, + ack: &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 80}}}, + length: 42, + } + t := time.Now() + p := packet.ToAckHandlerPacket(t, nil) + Expect(p.Length).To(Equal(protocol.ByteCount(42))) + Expect(p.Frames).To(Equal(packet.frames)) + Expect(p.LargestAcked).To(Equal(protocol.PacketNumber(100))) + Expect(p.SendTime).To(Equal(t)) + }) + + It("sets the LargestAcked to invalid, if the packet doesn't have an ACK frame", func() { + packet := &packetContents{ + header: &wire.ExtendedHeader{Header: wire.Header{}}, + frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, + } + p := packet.ToAckHandlerPacket(time.Now(), nil) + Expect(p.LargestAcked).To(Equal(protocol.InvalidPacketNumber)) + }) + + It("marks MTU probe packets", func() { + packet := &packetContents{ + header: &wire.ExtendedHeader{Header: wire.Header{}}, + isMTUProbePacket: true, + } + Expect(packet.ToAckHandlerPacket(time.Now(), nil).IsPathMTUProbePacket).To(BeTrue()) + }) + + DescribeTable( + "doesn't overwrite the OnLost callback, if it is set", + func(hdr wire.Header) { + var pingLost bool + packet := &packetContents{ + header: &wire.ExtendedHeader{Header: hdr}, + frames: []ackhandler.Frame{ + {Frame: &wire.MaxDataFrame{}}, + {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { pingLost = true }}, + }, + } + p := packet.ToAckHandlerPacket(time.Now(), newRetransmissionQueue(protocol.VersionTLS)) + Expect(p.Frames).To(HaveLen(2)) + Expect(p.Frames[0].OnLost).ToNot(BeNil()) + p.Frames[1].OnLost(nil) + Expect(pingLost).To(BeTrue()) + }, + Entry(protocol.EncryptionInitial.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketTypeInitial}), + Entry(protocol.EncryptionHandshake.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}), + Entry(protocol.Encryption0RTT.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketType0RTT}), + Entry(protocol.Encryption1RTT.String(), wire.Header{}), + ) +}) diff --git a/internal/quic-go/packet_unpacker.go b/internal/quic-go/packet_unpacker.go new file mode 100644 index 00000000..829907ef --- /dev/null +++ b/internal/quic-go/packet_unpacker.go @@ -0,0 +1,196 @@ +package quic + +import ( + "bytes" + "fmt" + "time" + + "github.com/imroc/req/v3/internal/quic-go/handshake" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type headerDecryptor interface { + DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) +} + +type headerParseError struct { + err error +} + +func (e *headerParseError) Unwrap() error { + return e.err +} + +func (e *headerParseError) Error() string { + return e.err.Error() +} + +type unpackedPacket struct { + packetNumber protocol.PacketNumber // the decoded packet number + hdr *wire.ExtendedHeader + encryptionLevel protocol.EncryptionLevel + data []byte +} + +// The packetUnpacker unpacks QUIC packets. +type packetUnpacker struct { + cs handshake.CryptoSetup + + version protocol.VersionNumber +} + +var _ unpacker = &packetUnpacker{} + +func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) unpacker { + return &packetUnpacker{ + cs: cs, + version: version, + } +} + +// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. +// If any other error occurred when parsing the header, the error is of type headerParseError. +// If decrypting the payload fails for any reason, the error is the error returned by the AEAD. +func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { + var encLevel protocol.EncryptionLevel + var extHdr *wire.ExtendedHeader + var decrypted []byte + //nolint:exhaustive // Retry packets can't be unpacked. + switch hdr.Type { + case protocol.PacketTypeInitial: + encLevel = protocol.EncryptionInitial + opener, err := u.cs.GetInitialOpener() + if err != nil { + return nil, err + } + extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) + if err != nil { + return nil, err + } + case protocol.PacketTypeHandshake: + encLevel = protocol.EncryptionHandshake + opener, err := u.cs.GetHandshakeOpener() + if err != nil { + return nil, err + } + extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) + if err != nil { + return nil, err + } + case protocol.PacketType0RTT: + encLevel = protocol.Encryption0RTT + opener, err := u.cs.Get0RTTOpener() + if err != nil { + return nil, err + } + extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) + if err != nil { + return nil, err + } + default: + if hdr.IsLongHeader { + return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) + } + encLevel = protocol.Encryption1RTT + opener, err := u.cs.Get1RTTOpener() + if err != nil { + return nil, err + } + extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, rcvTime, data) + if err != nil { + return nil, err + } + } + + return &unpackedPacket{ + hdr: extHdr, + packetNumber: extHdr.PacketNumber, + encryptionLevel: encLevel, + data: decrypted, + }, nil +} + +func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { + extHdr, parseErr := u.unpackHeader(opener, hdr, data) + // If the reserved bits are set incorrectly, we still need to continue unpacking. + // This avoids a timing side-channel, which otherwise might allow an attacker + // to gain information about the header encryption. + if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { + return nil, nil, parseErr + } + extHdrLen := extHdr.ParsedLen() + extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) + decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen]) + if err != nil { + return nil, nil, err + } + if parseErr != nil { + return nil, nil, parseErr + } + return extHdr, decrypted, nil +} + +func (u *packetUnpacker) unpackShortHeaderPacket( + opener handshake.ShortHeaderOpener, + hdr *wire.Header, + rcvTime time.Time, + data []byte, +) (*wire.ExtendedHeader, []byte, error) { + extHdr, parseErr := u.unpackHeader(opener, hdr, data) + // If the reserved bits are set incorrectly, we still need to continue unpacking. + // This avoids a timing side-channel, which otherwise might allow an attacker + // to gain information about the header encryption. + if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { + return nil, nil, parseErr + } + extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) + extHdrLen := extHdr.ParsedLen() + decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen]) + if err != nil { + return nil, nil, err + } + if parseErr != nil { + return nil, nil, parseErr + } + return extHdr, decrypted, nil +} + +// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. +func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { + extHdr, err := unpackHeader(hd, hdr, data, u.version) + if err != nil && err != wire.ErrInvalidReservedBits { + return nil, &headerParseError{err: err} + } + return extHdr, err +} + +func unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) { + r := bytes.NewReader(data) + + hdrLen := hdr.ParsedLen() + if protocol.ByteCount(len(data)) < hdrLen+4+16 { + //nolint:stylecheck + return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen) + } + // The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it. + // 1. save a copy of the 4 bytes + origPNBytes := make([]byte, 4) + copy(origPNBytes, data[hdrLen:hdrLen+4]) + // 2. decrypt the header, assuming a 4 byte packet number + hd.DecryptHeader( + data[hdrLen+4:hdrLen+4+16], + &data[0], + data[hdrLen:hdrLen+4], + ) + // 3. parse the header (and learn the actual length of the packet number) + extHdr, parseErr := hdr.ParseExtended(r, version) + if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { + return nil, parseErr + } + // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier + if extHdr.PacketNumberLen != protocol.PacketNumberLen4 { + copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):]) + } + return extHdr, parseErr +} diff --git a/internal/quic-go/packet_unpacker_test.go b/internal/quic-go/packet_unpacker_test.go new file mode 100644 index 00000000..a111faa4 --- /dev/null +++ b/internal/quic-go/packet_unpacker_test.go @@ -0,0 +1,292 @@ +package quic + +import ( + "bytes" + "errors" + "time" + + "github.com/imroc/req/v3/internal/quic-go/handshake" + "github.com/imroc/req/v3/internal/quic-go/mocks" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/wire" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Packet Unpacker", func() { + const version = protocol.VersionTLS + + var ( + unpacker *packetUnpacker + cs *mocks.MockCryptoSetup + connID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + payload = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ) + + getHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) { + buf := &bytes.Buffer{} + ExpectWithOffset(1, extHdr.Write(buf, version)).To(Succeed()) + hdrLen := buf.Len() + if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) { + buf.Write(make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen))) + } + hdr, _, _, err := wire.ParsePacket(buf.Bytes(), connID.Len()) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + return hdr, buf.Bytes()[:hdrLen] + } + + BeforeEach(func() { + cs = mocks.NewMockCryptoSetup(mockCtrl) + unpacker = newPacketUnpacker(cs, version).(*packetUnpacker) + }) + + It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: connID, + Version: version, + }, + PacketNumber: 1337, + PacketNumberLen: protocol.PacketNumberLen2, + } + hdr, hdrRaw := getHeader(extHdr) + data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) + opener := mocks.NewMockLongHeaderOpener(mockCtrl) + cs.EXPECT().GetHandshakeOpener().Return(opener, nil) + _, err := unpacker.Unpack(hdr, time.Now(), data) + Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) + var headerErr *headerParseError + Expect(errors.As(err, &headerErr)).To(BeTrue()) + Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) + }) + + It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: connID}, + PacketNumber: 1337, + PacketNumberLen: protocol.PacketNumberLen2, + } + hdr, hdrRaw := getHeader(extHdr) + data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) + opener := mocks.NewMockShortHeaderOpener(mockCtrl) + cs.EXPECT().Get1RTTOpener().Return(opener, nil) + _, err := unpacker.Unpack(hdr, time.Now(), data) + Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) + Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) + }) + + It("opens Initial packets", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Length: 3 + 6, // packet number len + payload + DestConnectionID: connID, + Version: version, + }, + PacketNumber: 2, + PacketNumberLen: 3, + } + hdr, hdrRaw := getHeader(extHdr) + opener := mocks.NewMockLongHeaderOpener(mockCtrl) + gomock.InOrder( + cs.EXPECT().GetInitialOpener().Return(opener, nil), + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()), + opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)), + opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil), + ) + packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) + Expect(packet.data).To(Equal([]byte("decrypted"))) + }) + + It("opens 0-RTT packets", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketType0RTT, + Length: 3 + 6, // packet number len + payload + DestConnectionID: connID, + Version: version, + }, + PacketNumber: 20, + PacketNumberLen: 2, + } + hdr, hdrRaw := getHeader(extHdr) + opener := mocks.NewMockLongHeaderOpener(mockCtrl) + gomock.InOrder( + cs.EXPECT().Get0RTTOpener().Return(opener, nil), + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()), + opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)), + opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil), + ) + packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT)) + Expect(packet.data).To(Equal([]byte("decrypted"))) + }) + + It("opens short header packets", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: connID}, + KeyPhase: protocol.KeyPhaseOne, + PacketNumber: 99, + PacketNumberLen: protocol.PacketNumberLen4, + } + hdr, hdrRaw := getHeader(extHdr) + opener := mocks.NewMockShortHeaderOpener(mockCtrl) + now := time.Now() + gomock.InOrder( + cs.EXPECT().Get1RTTOpener().Return(opener, nil), + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()), + opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)), + opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil), + ) + packet, err := unpacker.Unpack(hdr, now, append(hdrRaw, payload...)) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.encryptionLevel).To(Equal(protocol.Encryption1RTT)) + Expect(packet.data).To(Equal([]byte("decrypted"))) + }) + + It("returns the error when getting the sealer fails", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: connID}, + PacketNumber: 0x1337, + PacketNumberLen: 2, + } + hdr, hdrRaw := getHeader(extHdr) + cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable) + _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable)) + }) + + It("returns the error when unpacking fails", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Length: 3, // packet number len + DestConnectionID: connID, + Version: version, + }, + PacketNumber: 2, + PacketNumberLen: 3, + } + hdr, hdrRaw := getHeader(extHdr) + opener := mocks.NewMockLongHeaderOpener(mockCtrl) + cs.EXPECT().GetHandshakeOpener().Return(opener, nil) + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) + opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) + unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded} + opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr) + _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + Expect(err).To(MatchError(unpackErr)) + }) + + It("defends against the timing side-channel when the reserved bits are wrong, for long header packets", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: connID, + Version: version, + }, + PacketNumber: 0x1337, + PacketNumberLen: 2, + } + hdr, hdrRaw := getHeader(extHdr) + hdrRaw[0] |= 0xc + opener := mocks.NewMockLongHeaderOpener(mockCtrl) + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) + cs.EXPECT().GetHandshakeOpener().Return(opener, nil) + opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) + _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) + }) + + It("defends against the timing side-channel when the reserved bits are wrong, for short header packets", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: connID}, + PacketNumber: 0x1337, + PacketNumberLen: 2, + } + hdr, hdrRaw := getHeader(extHdr) + hdrRaw[0] |= 0x18 + opener := mocks.NewMockShortHeaderOpener(mockCtrl) + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) + cs.EXPECT().Get1RTTOpener().Return(opener, nil) + opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) + _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) + }) + + It("returns the decryption error, when unpacking a packet with wrong reserved bits fails", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: connID}, + PacketNumber: 0x1337, + PacketNumberLen: 2, + } + hdr, hdrRaw := getHeader(extHdr) + hdrRaw[0] |= 0x18 + opener := mocks.NewMockShortHeaderOpener(mockCtrl) + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) + cs.EXPECT().Get1RTTOpener().Return(opener, nil) + opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) + _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) + Expect(err).To(MatchError(handshake.ErrDecryptionFailed)) + }) + + It("decrypts the header", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Length: 3, // packet number len + DestConnectionID: connID, + Version: version, + }, + PacketNumber: 0x1337, + PacketNumberLen: 2, + } + hdr, hdrRaw := getHeader(extHdr) + origHdrRaw := append([]byte{}, hdrRaw...) // save a copy of the header + firstHdrByte := hdrRaw[0] + hdrRaw[0] ^= 0xff // invert the first byte + hdrRaw[len(hdrRaw)-2] ^= 0xff // invert the packet number + hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number + Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte)) + opener := mocks.NewMockLongHeaderOpener(mockCtrl) + cs.EXPECT().GetHandshakeOpener().Return(opener, nil) + gomock.InOrder( + // we're using a 2 byte packet number, so the sample starts at the 3rd payload byte + opener.EXPECT().DecryptHeader( + []byte{3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + &hdrRaw[0], + append(hdrRaw[len(hdrRaw)-2:], []byte{1, 2}...)).Do(func(_ []byte, firstByte *byte, pnBytes []byte) { + *firstByte ^= 0xff // invert the first byte back + for i := range pnBytes { + pnBytes[i] ^= 0xff // invert the packet number bytes + } + }), + opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2).Return(protocol.PacketNumber(0x7331)), + opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x7331), origHdrRaw).Return([]byte{0}, nil), + ) + data := hdrRaw + for i := 1; i <= 100; i++ { + data = append(data, uint8(i)) + } + packet, err := unpacker.Unpack(hdr, time.Now(), data) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x7331))) + }) +}) diff --git a/internal/quic-go/protocol/connection_id.go b/internal/quic-go/protocol/connection_id.go new file mode 100644 index 00000000..3aec2cd3 --- /dev/null +++ b/internal/quic-go/protocol/connection_id.go @@ -0,0 +1,69 @@ +package protocol + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" +) + +// A ConnectionID in QUIC +type ConnectionID []byte + +const maxConnectionIDLen = 20 + +// GenerateConnectionID generates a connection ID using cryptographic random +func GenerateConnectionID(len int) (ConnectionID, error) { + b := make([]byte, len) + if _, err := rand.Read(b); err != nil { + return nil, err + } + return ConnectionID(b), nil +} + +// GenerateConnectionIDForInitial generates a connection ID for the Initial packet. +// It uses a length randomly chosen between 8 and 20 bytes. +func GenerateConnectionIDForInitial() (ConnectionID, error) { + r := make([]byte, 1) + if _, err := rand.Read(r); err != nil { + return nil, err + } + len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) + return GenerateConnectionID(len) +} + +// ReadConnectionID reads a connection ID of length len from the given io.Reader. +// It returns io.EOF if there are not enough bytes to read. +func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { + if len == 0 { + return nil, nil + } + c := make(ConnectionID, len) + _, err := io.ReadFull(r, c) + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return c, err +} + +// Equal says if two connection IDs are equal +func (c ConnectionID) Equal(other ConnectionID) bool { + return bytes.Equal(c, other) +} + +// Len returns the length of the connection ID in bytes +func (c ConnectionID) Len() int { + return len(c) +} + +// Bytes returns the byte representation +func (c ConnectionID) Bytes() []byte { + return []byte(c) +} + +func (c ConnectionID) String() string { + if c.Len() == 0 { + return "(empty)" + } + return fmt.Sprintf("%x", c.Bytes()) +} diff --git a/internal/quic-go/protocol/connection_id_test.go b/internal/quic-go/protocol/connection_id_test.go new file mode 100644 index 00000000..345e656c --- /dev/null +++ b/internal/quic-go/protocol/connection_id_test.go @@ -0,0 +1,108 @@ +package protocol + +import ( + "bytes" + "io" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Connection ID generation", func() { + It("generates random connection IDs", func() { + c1, err := GenerateConnectionID(8) + Expect(err).ToNot(HaveOccurred()) + Expect(c1).ToNot(BeZero()) + c2, err := GenerateConnectionID(8) + Expect(err).ToNot(HaveOccurred()) + Expect(c1).ToNot(Equal(c2)) + }) + + It("generates connection IDs with the requested length", func() { + c, err := GenerateConnectionID(5) + Expect(err).ToNot(HaveOccurred()) + Expect(c.Len()).To(Equal(5)) + }) + + It("generates random length destination connection IDs", func() { + var has8ByteConnID, has20ByteConnID bool + for i := 0; i < 1000; i++ { + c, err := GenerateConnectionIDForInitial() + Expect(err).ToNot(HaveOccurred()) + Expect(c.Len()).To(BeNumerically(">=", 8)) + Expect(c.Len()).To(BeNumerically("<=", 20)) + if c.Len() == 8 { + has8ByteConnID = true + } + if c.Len() == 20 { + has20ByteConnID = true + } + } + Expect(has8ByteConnID).To(BeTrue()) + Expect(has20ByteConnID).To(BeTrue()) + }) + + It("says if connection IDs are equal", func() { + c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + Expect(c1.Equal(c1)).To(BeTrue()) + Expect(c2.Equal(c2)).To(BeTrue()) + Expect(c1.Equal(c2)).To(BeFalse()) + Expect(c2.Equal(c1)).To(BeFalse()) + }) + + It("reads the connection ID", func() { + buf := bytes.NewBuffer([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) + c, err := ReadConnectionID(buf, 9) + Expect(err).ToNot(HaveOccurred()) + Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})) + }) + + It("returns io.EOF if there's not enough data to read", func() { + buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) + _, err := ReadConnectionID(buf, 5) + Expect(err).To(MatchError(io.EOF)) + }) + + It("returns nil for a 0 length connection ID", func() { + buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) + c, err := ReadConnectionID(buf, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(c).To(BeNil()) + }) + + It("returns the length", func() { + c := ConnectionID{1, 2, 3, 4, 5, 6, 7} + Expect(c.Len()).To(Equal(7)) + }) + + It("has 0 length for the default value", func() { + var c ConnectionID + Expect(c.Len()).To(BeZero()) + }) + + It("returns the bytes", func() { + c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) + Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7})) + }) + + It("returns a nil byte slice for the default value", func() { + var c ConnectionID + Expect(c.Bytes()).To(BeNil()) + }) + + It("has a string representation", func() { + c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) + Expect(c.String()).To(Equal("deadbeef42")) + }) + + It("has a long string representation", func() { + c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad} + Expect(c.String()).To(Equal("13370000decafbad")) + }) + + It("has a string representation for the default value", func() { + var c ConnectionID + Expect(c.String()).To(Equal("(empty)")) + }) +}) diff --git a/internal/quic-go/protocol/encryption_level.go b/internal/quic-go/protocol/encryption_level.go new file mode 100644 index 00000000..32d38ab1 --- /dev/null +++ b/internal/quic-go/protocol/encryption_level.go @@ -0,0 +1,30 @@ +package protocol + +// EncryptionLevel is the encryption level +// Default value is Unencrypted +type EncryptionLevel uint8 + +const ( + // EncryptionInitial is the Initial encryption level + EncryptionInitial EncryptionLevel = 1 + iota + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT + // Encryption1RTT is the 1-RTT encryption level + Encryption1RTT +) + +func (e EncryptionLevel) String() string { + switch e { + case EncryptionInitial: + return "Initial" + case EncryptionHandshake: + return "Handshake" + case Encryption0RTT: + return "0-RTT" + case Encryption1RTT: + return "1-RTT" + } + return "unknown" +} diff --git a/internal/quic-go/protocol/encryption_level_test.go b/internal/quic-go/protocol/encryption_level_test.go new file mode 100644 index 00000000..9b07b08b --- /dev/null +++ b/internal/quic-go/protocol/encryption_level_test.go @@ -0,0 +1,20 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Encryption Level", func() { + It("doesn't use 0 as a value", func() { + // 0 is used in some tests + Expect(EncryptionInitial * EncryptionHandshake * Encryption0RTT * Encryption1RTT).ToNot(BeZero()) + }) + + It("has the correct string representation", func() { + Expect(EncryptionInitial.String()).To(Equal("Initial")) + Expect(EncryptionHandshake.String()).To(Equal("Handshake")) + Expect(Encryption0RTT.String()).To(Equal("0-RTT")) + Expect(Encryption1RTT.String()).To(Equal("1-RTT")) + }) +}) diff --git a/internal/quic-go/protocol/key_phase.go b/internal/quic-go/protocol/key_phase.go new file mode 100644 index 00000000..edd740cf --- /dev/null +++ b/internal/quic-go/protocol/key_phase.go @@ -0,0 +1,36 @@ +package protocol + +// KeyPhase is the key phase +type KeyPhase uint64 + +// Bit determines the key phase bit +func (p KeyPhase) Bit() KeyPhaseBit { + if p%2 == 0 { + return KeyPhaseZero + } + return KeyPhaseOne +} + +// KeyPhaseBit is the key phase bit +type KeyPhaseBit uint8 + +const ( + // KeyPhaseUndefined is an undefined key phase + KeyPhaseUndefined KeyPhaseBit = iota + // KeyPhaseZero is key phase 0 + KeyPhaseZero + // KeyPhaseOne is key phase 1 + KeyPhaseOne +) + +func (p KeyPhaseBit) String() string { + //nolint:exhaustive + switch p { + case KeyPhaseZero: + return "0" + case KeyPhaseOne: + return "1" + default: + return "undefined" + } +} diff --git a/internal/quic-go/protocol/key_phase_test.go b/internal/quic-go/protocol/key_phase_test.go new file mode 100644 index 00000000..92f404a5 --- /dev/null +++ b/internal/quic-go/protocol/key_phase_test.go @@ -0,0 +1,27 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Key Phases", func() { + It("has undefined as its default value", func() { + var k KeyPhaseBit + Expect(k).To(Equal(KeyPhaseUndefined)) + }) + + It("has the correct string representation", func() { + Expect(KeyPhaseZero.String()).To(Equal("0")) + Expect(KeyPhaseOne.String()).To(Equal("1")) + }) + + It("converts the key phase to the key phase bit", func() { + Expect(KeyPhase(0).Bit()).To(Equal(KeyPhaseZero)) + Expect(KeyPhase(2).Bit()).To(Equal(KeyPhaseZero)) + Expect(KeyPhase(4).Bit()).To(Equal(KeyPhaseZero)) + Expect(KeyPhase(1).Bit()).To(Equal(KeyPhaseOne)) + Expect(KeyPhase(3).Bit()).To(Equal(KeyPhaseOne)) + Expect(KeyPhase(5).Bit()).To(Equal(KeyPhaseOne)) + }) +}) diff --git a/internal/quic-go/protocol/packet_number.go b/internal/quic-go/protocol/packet_number.go new file mode 100644 index 00000000..bd340161 --- /dev/null +++ b/internal/quic-go/protocol/packet_number.go @@ -0,0 +1,79 @@ +package protocol + +// A PacketNumber in QUIC +type PacketNumber int64 + +// InvalidPacketNumber is a packet number that is never sent. +// In QUIC, 0 is a valid packet number. +const InvalidPacketNumber PacketNumber = -1 + +// PacketNumberLen is the length of the packet number in bytes +type PacketNumberLen uint8 + +const ( + // PacketNumberLen1 is a packet number length of 1 byte + PacketNumberLen1 PacketNumberLen = 1 + // PacketNumberLen2 is a packet number length of 2 bytes + PacketNumberLen2 PacketNumberLen = 2 + // PacketNumberLen3 is a packet number length of 3 bytes + PacketNumberLen3 PacketNumberLen = 3 + // PacketNumberLen4 is a packet number length of 4 bytes + PacketNumberLen4 PacketNumberLen = 4 +) + +// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number +func DecodePacketNumber( + packetNumberLength PacketNumberLen, + lastPacketNumber PacketNumber, + wirePacketNumber PacketNumber, +) PacketNumber { + var epochDelta PacketNumber + switch packetNumberLength { + case PacketNumberLen1: + epochDelta = PacketNumber(1) << 8 + case PacketNumberLen2: + epochDelta = PacketNumber(1) << 16 + case PacketNumberLen3: + epochDelta = PacketNumber(1) << 24 + case PacketNumberLen4: + epochDelta = PacketNumber(1) << 32 + } + epoch := lastPacketNumber & ^(epochDelta - 1) + var prevEpochBegin PacketNumber + if epoch > epochDelta { + prevEpochBegin = epoch - epochDelta + } + nextEpochBegin := epoch + epochDelta + return closestTo( + lastPacketNumber+1, + epoch+wirePacketNumber, + closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber), + ) +} + +func closestTo(target, a, b PacketNumber) PacketNumber { + if delta(target, a) < delta(target, b) { + return a + } + return b +} + +func delta(a, b PacketNumber) PacketNumber { + if a < b { + return b - a + } + return a - b +} + +// GetPacketNumberLengthForHeader gets the length of the packet number for the public header +// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances +func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen { + diff := uint64(packetNumber - leastUnacked) + if diff < (1 << (16 - 1)) { + return PacketNumberLen2 + } + if diff < (1 << (24 - 1)) { + return PacketNumberLen3 + } + return PacketNumberLen4 +} diff --git a/internal/quic-go/protocol/packet_number_test.go b/internal/quic-go/protocol/packet_number_test.go new file mode 100644 index 00000000..d3bfe1d5 --- /dev/null +++ b/internal/quic-go/protocol/packet_number_test.go @@ -0,0 +1,204 @@ +package protocol + +import ( + "fmt" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +// Tests taken and extended from chrome +var _ = Describe("packet number calculation", func() { + It("InvalidPacketNumber is smaller than all valid packet numbers", func() { + Expect(InvalidPacketNumber).To(BeNumerically("<", 0)) + }) + + It("works with the example from the draft", func() { + Expect(DecodePacketNumber(PacketNumberLen2, 0xa82f30ea, 0x9b32)).To(Equal(PacketNumber(0xa82f9b32))) + }) + + It("works with the examples from the draft", func() { + Expect(GetPacketNumberLengthForHeader(0xac5c02, 0xabe8b3)).To(Equal(PacketNumberLen2)) + Expect(GetPacketNumberLengthForHeader(0xace8fe, 0xabe8b3)).To(Equal(PacketNumberLen3)) + }) + + getEpoch := func(len PacketNumberLen) uint64 { + if len > 4 { + Fail("invalid packet number len") + } + return uint64(1) << (len * 8) + } + + check := func(length PacketNumberLen, expected, last uint64) { + epoch := getEpoch(length) + epochMask := epoch - 1 + wirePacketNumber := expected & epochMask + ExpectWithOffset(1, DecodePacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected))) + } + + for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen3, PacketNumberLen4} { + length := l + + Context(fmt.Sprintf("with %d bytes", length), func() { + epoch := getEpoch(length) + epochMask := epoch - 1 + + It("works near epoch start", func() { + // A few quick manual sanity check + check(length, 1, 0) + check(length, epoch+1, epochMask) + check(length, epoch, epochMask) + + // Cases where the last number was close to the start of the range. + for last := uint64(0); last < 10; last++ { + // Small numbers should not wrap (even if they're out of order). + for j := uint64(0); j < 10; j++ { + check(length, j, last) + } + + // Large numbers should not wrap either (because we're near 0 already). + for j := uint64(0); j < 10; j++ { + check(length, epoch-1-j, last) + } + } + }) + + It("works near epoch end", func() { + // Cases where the last number was close to the end of the range + for i := uint64(0); i < 10; i++ { + last := epoch - i + + // Small numbers should wrap. + for j := uint64(0); j < 10; j++ { + check(length, epoch+j, last) + } + + // Large numbers should not (even if they're out of order). + for j := uint64(0); j < 10; j++ { + check(length, epoch-1-j, last) + } + } + }) + + // Next check where we're in a non-zero epoch to verify we handle + // reverse wrapping, too. + It("works near previous epoch", func() { + prevEpoch := 1 * epoch + curEpoch := 2 * epoch + // Cases where the last number was close to the start of the range + for i := uint64(0); i < 10; i++ { + last := curEpoch + i + // Small number should not wrap (even if they're out of order). + for j := uint64(0); j < 10; j++ { + check(length, curEpoch+j, last) + } + + // But large numbers should reverse wrap. + for j := uint64(0); j < 10; j++ { + num := epoch - 1 - j + check(length, prevEpoch+num, last) + } + } + }) + + It("works near next epoch", func() { + curEpoch := 2 * epoch + nextEpoch := 3 * epoch + // Cases where the last number was close to the end of the range + for i := uint64(0); i < 10; i++ { + last := nextEpoch - 1 - i + + // Small numbers should wrap. + for j := uint64(0); j < 10; j++ { + check(length, nextEpoch+j, last) + } + + // but large numbers should not (even if they're out of order). + for j := uint64(0); j < 10; j++ { + num := epoch - 1 - j + check(length, curEpoch+num, last) + } + } + }) + + Context("shortening a packet number for the header", func() { + Context("shortening", func() { + It("sends out low packet numbers as 2 byte", func() { + length := GetPacketNumberLengthForHeader(4, 2) + Expect(length).To(Equal(PacketNumberLen2)) + }) + + It("sends out high packet numbers as 2 byte, if all ACKs are received", func() { + length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1) + Expect(length).To(Equal(PacketNumberLen2)) + }) + + It("sends out higher packet numbers as 3 bytes, if a lot of ACKs are missing", func() { + length := GetPacketNumberLengthForHeader(40000, 2) + Expect(length).To(Equal(PacketNumberLen3)) + }) + + It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() { + length := GetPacketNumberLengthForHeader(40000000, 2) + Expect(length).To(Equal(PacketNumberLen4)) + }) + }) + + Context("self-consistency", func() { + It("works for small packet numbers", func() { + for i := uint64(1); i < 10000; i++ { + packetNumber := PacketNumber(i) + leastUnacked := PacketNumber(1) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) + wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) + + decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) + Expect(decodedPacketNumber).To(Equal(packetNumber)) + } + }) + + It("works for small packet numbers and increasing ACKed packets", func() { + for i := uint64(1); i < 10000; i++ { + packetNumber := PacketNumber(i) + leastUnacked := PacketNumber(i / 2) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) + epochMask := getEpoch(length) - 1 + wirePacketNumber := uint64(packetNumber) & epochMask + + decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) + Expect(decodedPacketNumber).To(Equal(packetNumber)) + } + }) + + It("also works for larger packet numbers", func() { + var increment uint64 + for i := uint64(1); i < getEpoch(PacketNumberLen4); i += increment { + packetNumber := PacketNumber(i) + leastUnacked := PacketNumber(1) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) + epochMask := getEpoch(length) - 1 + wirePacketNumber := uint64(packetNumber) & epochMask + + decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) + Expect(decodedPacketNumber).To(Equal(packetNumber)) + + increment = getEpoch(length) / 8 + } + }) + + It("works for packet numbers larger than 2^48", func() { + for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) { + packetNumber := PacketNumber(i) + leastUnacked := PacketNumber(i - 1000) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) + wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) + + decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) + Expect(decodedPacketNumber).To(Equal(packetNumber)) + } + }) + }) + }) + }) + } +}) diff --git a/internal/quic-go/protocol/params.go b/internal/quic-go/protocol/params.go new file mode 100644 index 00000000..83137113 --- /dev/null +++ b/internal/quic-go/protocol/params.go @@ -0,0 +1,193 @@ +package protocol + +import "time" + +// DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use. +const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB + +// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets. +const InitialPacketSizeIPv4 = 1252 + +// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets. +const InitialPacketSizeIPv6 = 1232 + +// MaxCongestionWindowPackets is the maximum congestion window in packet. +const MaxCongestionWindowPackets = 10000 + +// MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the connection. +const MaxUndecryptablePackets = 32 + +// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window +// This is the value that Chromium is using +const ConnectionFlowControlMultiplier = 1.5 + +// DefaultInitialMaxStreamData is the default initial stream-level flow control window for receiving data +const DefaultInitialMaxStreamData = (1 << 10) * 512 // 512 kb + +// DefaultInitialMaxData is the connection-level flow control window for receiving data +const DefaultInitialMaxData = ConnectionFlowControlMultiplier * DefaultInitialMaxStreamData + +// DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data +const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB + +// DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data +const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 15 MB + +// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client +const WindowUpdateThreshold = 0.25 + +// DefaultMaxIncomingStreams is the maximum number of streams that a peer may open +const DefaultMaxIncomingStreams = 100 + +// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open +const DefaultMaxIncomingUniStreams = 100 + +// MaxServerUnprocessedPackets is the max number of packets stored in the server that are not yet processed. +const MaxServerUnprocessedPackets = 1024 + +// MaxConnUnprocessedPackets is the max number of packets stored in each connection that are not yet processed. +const MaxConnUnprocessedPackets = 256 + +// SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack. +// Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod. +const SkipPacketInitialPeriod PacketNumber = 256 + +// SkipPacketMaxPeriod is the maximum period length used for packet number skipping. +const SkipPacketMaxPeriod PacketNumber = 128 * 1024 + +// MaxAcceptQueueSize is the maximum number of connections that the server queues for accepting. +// If the queue is full, new connection attempts will be rejected. +const MaxAcceptQueueSize = 32 + +// TokenValidity is the duration that a (non-retry) token is considered valid +const TokenValidity = 24 * time.Hour + +// RetryTokenValidity is the duration that a retry token is considered valid +const RetryTokenValidity = 10 * time.Second + +// MaxOutstandingSentPackets is maximum number of packets saved for retransmission. +// When reached, it imposes a soft limit on sending new packets: +// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. +const MaxOutstandingSentPackets = 2 * MaxCongestionWindowPackets + +// MaxTrackedSentPackets is maximum number of sent packets saved for retransmission. +// When reached, no more packets will be sent. +// This value *must* be larger than MaxOutstandingSentPackets. +const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4 + +// MaxNonAckElicitingAcks is the maximum number of packets containing an ACK, +// but no ack-eliciting frames, that we send in a row +const MaxNonAckElicitingAcks = 19 + +// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames +// prevents DoS attacks against the streamFrameSorter +const MaxStreamFrameSorterGaps = 1000 + +// MinStreamFrameBufferSize is the minimum data length of a received STREAM frame +// that we use the buffer for. This protects against a DoS where an attacker would send us +// very small STREAM frames to consume a lot of memory. +const MinStreamFrameBufferSize = 128 + +// MinCoalescedPacketSize is the minimum size of a coalesced packet that we pack. +// If a packet has less than this number of bytes, we won't coalesce any more packets onto it. +const MinCoalescedPacketSize = 128 + +// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams. +// This limits the size of the ClientHello and Certificates that can be received. +const MaxCryptoStreamOffset = 16 * (1 << 10) + +// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout +const MinRemoteIdleTimeout = 5 * time.Second + +// DefaultIdleTimeout is the default idle timeout +const DefaultIdleTimeout = 30 * time.Second + +// DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion. +const DefaultHandshakeIdleTimeout = 5 * time.Second + +// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds. +const DefaultHandshakeTimeout = 10 * time.Second + +// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive. +// It should be shorter than the time that NATs clear their mapping. +const MaxKeepAliveInterval = 20 * time.Second + +// RetiredConnectionIDDeleteTimeout is the time we keep closed connections around in order to retransmit the CONNECTION_CLOSE. +// after this time all information about the old connection will be deleted +const RetiredConnectionIDDeleteTimeout = 5 * time.Second + +// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame. +// This avoids splitting up STREAM frames into small pieces, which has 2 advantages: +// 1. it reduces the framing overhead +// 2. it reduces the head-of-line blocking, when a packet is lost +const MinStreamFrameSize ByteCount = 128 + +// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames +// we send after the handshake completes. +const MaxPostHandshakeCryptoFrameSize = 1000 + +// MaxAckFrameSize is the maximum size for an ACK frame that we write +// Due to the varint encoding, ACK frames can grow (almost) indefinitely large. +// The MaxAckFrameSize should be large enough to encode many ACK range, +// but must ensure that a maximum size ACK frame fits into one packet. +const MaxAckFrameSize ByteCount = 1000 + +// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame (RFC 9221). +// The size is chosen such that a DATAGRAM frame fits into a QUIC packet. +const MaxDatagramFrameSize ByteCount = 1220 + +// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221) +const DatagramRcvQueueLen = 128 + +// MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame. +// It also serves as a limit for the packet history. +// If at any point we keep track of more ranges, old ranges are discarded. +const MaxNumAckRanges = 32 + +// MinPacingDelay is the minimum duration that is used for packet pacing +// If the packet packing frequency is higher, multiple packets might be sent at once. +// Example: For a packet pacing delay of 200μs, we would send 5 packets at once, wait for 1ms, and so forth. +const MinPacingDelay = time.Millisecond + +// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections +// if no other value is configured. +const DefaultConnectionIDLength = 4 + +// MaxActiveConnectionIDs is the number of connection IDs that we're storing. +const MaxActiveConnectionIDs = 4 + +// MaxIssuedConnectionIDs is the maximum number of connection IDs that we're issuing at the same time. +const MaxIssuedConnectionIDs = 6 + +// PacketsPerConnectionID is the number of packets we send using one connection ID. +// If the peer provices us with enough new connection IDs, we switch to a new connection ID. +const PacketsPerConnectionID = 10000 + +// AckDelayExponent is the ack delay exponent used when sending ACKs. +const AckDelayExponent = 3 + +// Estimated timer granularity. +// The loss detection timer will not be set to a value smaller than granularity. +const TimerGranularity = time.Millisecond + +// MaxAckDelay is the maximum time by which we delay sending ACKs. +const MaxAckDelay = 25 * time.Millisecond + +// MaxAckDelayInclGranularity is the max_ack_delay including the timer granularity. +// This is the value that should be advertised to the peer. +const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity + +// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. +const KeyUpdateInterval = 100 * 1000 + +// Max0RTTQueueingDuration is the maximum time that we store 0-RTT packets in order to wait for the corresponding Initial to be received. +const Max0RTTQueueingDuration = 100 * time.Millisecond + +// Max0RTTQueues is the maximum number of connections that we buffer 0-RTT packets for. +const Max0RTTQueues = 32 + +// Max0RTTQueueLen is the maximum number of 0-RTT packets that we buffer for each connection. +// When a new connection is created, all buffered packets are passed to the connection immediately. +// To avoid blocking, this value has to be smaller than MaxConnUnprocessedPackets. +// To avoid packets being dropped as undecryptable by the connection, this value has to be smaller than MaxUndecryptablePackets. +const Max0RTTQueueLen = 31 diff --git a/internal/quic-go/protocol/params_test.go b/internal/quic-go/protocol/params_test.go new file mode 100644 index 00000000..50a260d2 --- /dev/null +++ b/internal/quic-go/protocol/params_test.go @@ -0,0 +1,13 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Parameters", func() { + It("can queue more packets in the session than in the 0-RTT queue", func() { + Expect(MaxConnUnprocessedPackets).To(BeNumerically(">", Max0RTTQueueLen)) + Expect(MaxUndecryptablePackets).To(BeNumerically(">", Max0RTTQueueLen)) + }) +}) diff --git a/internal/quic-go/protocol/perspective.go b/internal/quic-go/protocol/perspective.go new file mode 100644 index 00000000..43358fec --- /dev/null +++ b/internal/quic-go/protocol/perspective.go @@ -0,0 +1,26 @@ +package protocol + +// Perspective determines if we're acting as a server or a client +type Perspective int + +// the perspectives +const ( + PerspectiveServer Perspective = 1 + PerspectiveClient Perspective = 2 +) + +// Opposite returns the perspective of the peer +func (p Perspective) Opposite() Perspective { + return 3 - p +} + +func (p Perspective) String() string { + switch p { + case PerspectiveServer: + return "Server" + case PerspectiveClient: + return "Client" + default: + return "invalid perspective" + } +} diff --git a/internal/quic-go/protocol/perspective_test.go b/internal/quic-go/protocol/perspective_test.go new file mode 100644 index 00000000..0ae23d7c --- /dev/null +++ b/internal/quic-go/protocol/perspective_test.go @@ -0,0 +1,19 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Perspective", func() { + It("has a string representation", func() { + Expect(PerspectiveClient.String()).To(Equal("Client")) + Expect(PerspectiveServer.String()).To(Equal("Server")) + Expect(Perspective(0).String()).To(Equal("invalid perspective")) + }) + + It("returns the opposite", func() { + Expect(PerspectiveClient.Opposite()).To(Equal(PerspectiveServer)) + Expect(PerspectiveServer.Opposite()).To(Equal(PerspectiveClient)) + }) +}) diff --git a/internal/quic-go/protocol/protocol.go b/internal/quic-go/protocol/protocol.go new file mode 100644 index 00000000..8241e274 --- /dev/null +++ b/internal/quic-go/protocol/protocol.go @@ -0,0 +1,97 @@ +package protocol + +import ( + "fmt" + "time" +) + +// The PacketType is the Long Header Type +type PacketType uint8 + +const ( + // PacketTypeInitial is the packet type of an Initial packet + PacketTypeInitial PacketType = 1 + iota + // PacketTypeRetry is the packet type of a Retry packet + PacketTypeRetry + // PacketTypeHandshake is the packet type of a Handshake packet + PacketTypeHandshake + // PacketType0RTT is the packet type of a 0-RTT packet + PacketType0RTT +) + +func (t PacketType) String() string { + switch t { + case PacketTypeInitial: + return "Initial" + case PacketTypeRetry: + return "Retry" + case PacketTypeHandshake: + return "Handshake" + case PacketType0RTT: + return "0-RTT Protected" + default: + return fmt.Sprintf("unknown packet type: %d", t) + } +} + +type ECN uint8 + +const ( + ECNNon ECN = iota // 00 + ECT1 // 01 + ECT0 // 10 + ECNCE // 11 +) + +// A ByteCount in QUIC +type ByteCount int64 + +// MaxByteCount is the maximum value of a ByteCount +const MaxByteCount = ByteCount(1<<62 - 1) + +// InvalidByteCount is an invalid byte count +const InvalidByteCount ByteCount = -1 + +// A StatelessResetToken is a stateless reset token. +type StatelessResetToken [16]byte + +// MaxPacketBufferSize maximum packet size of any QUIC packet, based on +// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, +// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes. +// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452. +const MaxPacketBufferSize ByteCount = 1452 + +// MinInitialPacketSize is the minimum size an Initial packet is required to have. +const MinInitialPacketSize = 1200 + +// MinUnknownVersionPacketSize is the minimum size a packet with an unknown version +// needs to have in order to trigger a Version Negotiation packet. +const MinUnknownVersionPacketSize = MinInitialPacketSize + +// MinStatelessResetSize is the minimum size of a stateless reset packet that we send +const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */ + +// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet. +const MinConnectionIDLenInitial = 8 + +// DefaultAckDelayExponent is the default ack delay exponent +const DefaultAckDelayExponent = 3 + +// MaxAckDelayExponent is the maximum ack delay exponent +const MaxAckDelayExponent = 20 + +// DefaultMaxAckDelay is the default max_ack_delay +const DefaultMaxAckDelay = 25 * time.Millisecond + +// MaxMaxAckDelay is the maximum max_ack_delay +const MaxMaxAckDelay = (1<<14 - 1) * time.Millisecond + +// MaxConnIDLen is the maximum length of the connection ID +const MaxConnIDLen = 20 + +// InvalidPacketLimitAES is the maximum number of packets that we can fail to decrypt when using +// AEAD_AES_128_GCM or AEAD_AES_265_GCM. +const InvalidPacketLimitAES = 1 << 52 + +// InvalidPacketLimitChaCha is the maximum number of packets that we can fail to decrypt when using AEAD_CHACHA20_POLY1305. +const InvalidPacketLimitChaCha = 1 << 36 diff --git a/internal/quic-go/protocol/protocol_suite_test.go b/internal/quic-go/protocol/protocol_suite_test.go new file mode 100644 index 00000000..60da0157 --- /dev/null +++ b/internal/quic-go/protocol/protocol_suite_test.go @@ -0,0 +1,13 @@ +package protocol + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestProtocol(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Protocol Suite") +} diff --git a/internal/quic-go/protocol/protocol_test.go b/internal/quic-go/protocol/protocol_test.go new file mode 100644 index 00000000..117405e4 --- /dev/null +++ b/internal/quic-go/protocol/protocol_test.go @@ -0,0 +1,25 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Protocol", func() { + Context("Long Header Packet Types", func() { + It("has the correct string representation", func() { + Expect(PacketTypeInitial.String()).To(Equal("Initial")) + Expect(PacketTypeRetry.String()).To(Equal("Retry")) + Expect(PacketTypeHandshake.String()).To(Equal("Handshake")) + Expect(PacketType0RTT.String()).To(Equal("0-RTT Protected")) + Expect(PacketType(10).String()).To(Equal("unknown packet type: 10")) + }) + }) + + It("converts ECN bits from the IP header wire to the correct types", func() { + Expect(ECN(0)).To(Equal(ECNNon)) + Expect(ECN(0b00000010)).To(Equal(ECT0)) + Expect(ECN(0b00000001)).To(Equal(ECT1)) + Expect(ECN(0b00000011)).To(Equal(ECNCE)) + }) +}) diff --git a/internal/quic-go/protocol/stream.go b/internal/quic-go/protocol/stream.go new file mode 100644 index 00000000..ad7de864 --- /dev/null +++ b/internal/quic-go/protocol/stream.go @@ -0,0 +1,76 @@ +package protocol + +// StreamType encodes if this is a unidirectional or bidirectional stream +type StreamType uint8 + +const ( + // StreamTypeUni is a unidirectional stream + StreamTypeUni StreamType = iota + // StreamTypeBidi is a bidirectional stream + StreamTypeBidi +) + +// InvalidPacketNumber is a stream ID that is invalid. +// The first valid stream ID in QUIC is 0. +const InvalidStreamID StreamID = -1 + +// StreamNum is the stream number +type StreamNum int64 + +const ( + // InvalidStreamNum is an invalid stream number. + InvalidStreamNum = -1 + // MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames + // and as the stream count in the transport parameters + MaxStreamCount StreamNum = 1 << 60 +) + +// StreamID calculates the stream ID. +func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID { + if s == 0 { + return InvalidStreamID + } + var first StreamID + switch stype { + case StreamTypeBidi: + switch pers { + case PerspectiveClient: + first = 0 + case PerspectiveServer: + first = 1 + } + case StreamTypeUni: + switch pers { + case PerspectiveClient: + first = 2 + case PerspectiveServer: + first = 3 + } + } + return first + 4*StreamID(s-1) +} + +// A StreamID in QUIC +type StreamID int64 + +// InitiatedBy says if the stream was initiated by the client or by the server +func (s StreamID) InitiatedBy() Perspective { + if s%2 == 0 { + return PerspectiveClient + } + return PerspectiveServer +} + +// Type says if this is a unidirectional or bidirectional stream +func (s StreamID) Type() StreamType { + if s%4 >= 2 { + return StreamTypeUni + } + return StreamTypeBidi +} + +// StreamNum returns how many streams in total are below this +// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9) +func (s StreamID) StreamNum() StreamNum { + return StreamNum(s/4) + 1 +} diff --git a/internal/quic-go/protocol/stream_test.go b/internal/quic-go/protocol/stream_test.go new file mode 100644 index 00000000..4209f8a0 --- /dev/null +++ b/internal/quic-go/protocol/stream_test.go @@ -0,0 +1,70 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stream ID", func() { + It("InvalidStreamID is smaller than all valid stream IDs", func() { + Expect(InvalidStreamID).To(BeNumerically("<", 0)) + }) + + It("says who initiated a stream", func() { + Expect(StreamID(4).InitiatedBy()).To(Equal(PerspectiveClient)) + Expect(StreamID(5).InitiatedBy()).To(Equal(PerspectiveServer)) + Expect(StreamID(6).InitiatedBy()).To(Equal(PerspectiveClient)) + Expect(StreamID(7).InitiatedBy()).To(Equal(PerspectiveServer)) + }) + + It("tells the directionality", func() { + Expect(StreamID(4).Type()).To(Equal(StreamTypeBidi)) + Expect(StreamID(5).Type()).To(Equal(StreamTypeBidi)) + Expect(StreamID(6).Type()).To(Equal(StreamTypeUni)) + Expect(StreamID(7).Type()).To(Equal(StreamTypeUni)) + }) + + It("tells the stream number", func() { + Expect(StreamID(0).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(1).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(2).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(3).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(8).StreamNum()).To(BeEquivalentTo(3)) + Expect(StreamID(9).StreamNum()).To(BeEquivalentTo(3)) + Expect(StreamID(10).StreamNum()).To(BeEquivalentTo(3)) + Expect(StreamID(11).StreamNum()).To(BeEquivalentTo(3)) + }) + + Context("converting stream nums to stream IDs", func() { + It("handles 0", func() { + Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(InvalidStreamID)) + Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(InvalidStreamID)) + Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(InvalidStreamID)) + Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(InvalidStreamID)) + }) + + It("handles the first", func() { + Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(0))) + Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(1))) + Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(2))) + Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3))) + }) + + It("handles others", func() { + Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(396))) + Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(397))) + Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(398))) + Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(399))) + }) + + It("has the right value for MaxStreamCount", func() { + const maxStreamID = StreamID(1<<62 - 1) + for _, dir := range []StreamType{StreamTypeUni, StreamTypeBidi} { + for _, pers := range []Perspective{PerspectiveClient, PerspectiveServer} { + Expect(MaxStreamCount.StreamID(dir, pers)).To(BeNumerically("<=", maxStreamID)) + Expect((MaxStreamCount + 1).StreamID(dir, pers)).To(BeNumerically(">", maxStreamID)) + } + } + }) + }) +}) diff --git a/internal/quic-go/protocol/version.go b/internal/quic-go/protocol/version.go new file mode 100644 index 00000000..dd54dbd3 --- /dev/null +++ b/internal/quic-go/protocol/version.go @@ -0,0 +1,114 @@ +package protocol + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "math" +) + +// VersionNumber is a version number as int +type VersionNumber uint32 + +// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions +const ( + gquicVersion0 = 0x51303030 + maxGquicVersion = 0x51303439 +) + +// The version numbers, making grepping easier +const ( + VersionTLS VersionNumber = 0x1 + VersionWhatever VersionNumber = math.MaxUint32 - 1 // for when the version doesn't matter + VersionUnknown VersionNumber = math.MaxUint32 + VersionDraft29 VersionNumber = 0xff00001d + Version1 VersionNumber = 0x1 + Version2 VersionNumber = 0x709a50c4 +) + +// SupportedVersions lists the versions that the server supports +// must be in sorted descending order +var SupportedVersions = []VersionNumber{Version1, Version2, VersionDraft29} + +// IsValidVersion says if the version is known to quic-go +func IsValidVersion(v VersionNumber) bool { + return v == VersionTLS || IsSupportedVersion(SupportedVersions, v) +} + +func (vn VersionNumber) String() string { + // For releases, VersionTLS will be set to a draft version. + // A switch statement can't contain duplicate cases. + if vn == VersionTLS && VersionTLS != VersionDraft29 && VersionTLS != Version1 { + return "TLS dev version (WIP)" + } + //nolint:exhaustive + switch vn { + case VersionWhatever: + return "whatever" + case VersionUnknown: + return "unknown" + case VersionDraft29: + return "draft-29" + case Version1: + return "v1" + case Version2: + return "v2" + default: + if vn.isGQUIC() { + return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion()) + } + return fmt.Sprintf("%#x", uint32(vn)) + } +} + +func (vn VersionNumber) isGQUIC() bool { + return vn > gquicVersion0 && vn <= maxGquicVersion +} + +func (vn VersionNumber) toGQUICVersion() int { + return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10) +} + +// IsSupportedVersion returns true if the server supports this version +func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { + for _, t := range supported { + if t == v { + return true + } + } + return false +} + +// ChooseSupportedVersion finds the best version in the overlap of ours and theirs +// ours is a slice of versions that we support, sorted by our preference (descending) +// theirs is a slice of versions offered by the peer. The order does not matter. +// The bool returned indicates if a matching version was found. +func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) { + for _, ourVer := range ours { + for _, theirVer := range theirs { + if ourVer == theirVer { + return ourVer, true + } + } + } + return 0, false +} + +// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a) +func generateReservedVersion() VersionNumber { + b := make([]byte, 4) + _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything + return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa) +} + +// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position +func GetGreasedVersions(supported []VersionNumber) []VersionNumber { + b := make([]byte, 1) + _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything + randPos := int(b[0]) % (len(supported) + 1) + greased := make([]VersionNumber, len(supported)+1) + copy(greased, supported[:randPos]) + greased[randPos] = generateReservedVersion() + copy(greased[randPos+1:], supported[randPos:]) + return greased +} diff --git a/internal/quic-go/protocol/version_test.go b/internal/quic-go/protocol/version_test.go new file mode 100644 index 00000000..33c6598b --- /dev/null +++ b/internal/quic-go/protocol/version_test.go @@ -0,0 +1,121 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Version", func() { + isReservedVersion := func(v VersionNumber) bool { + return v&0x0f0f0f0f == 0x0a0a0a0a + } + + It("says if a version is valid", func() { + Expect(IsValidVersion(VersionTLS)).To(BeTrue()) + Expect(IsValidVersion(VersionWhatever)).To(BeFalse()) + Expect(IsValidVersion(VersionUnknown)).To(BeFalse()) + Expect(IsValidVersion(VersionDraft29)).To(BeTrue()) + Expect(IsValidVersion(Version1)).To(BeTrue()) + Expect(IsValidVersion(Version2)).To(BeTrue()) + Expect(IsValidVersion(1234)).To(BeFalse()) + }) + + It("versions don't have reserved version numbers", func() { + Expect(isReservedVersion(VersionTLS)).To(BeFalse()) + }) + + It("has the right string representation", func() { + Expect(VersionWhatever.String()).To(Equal("whatever")) + Expect(VersionUnknown.String()).To(Equal("unknown")) + Expect(VersionDraft29.String()).To(Equal("draft-29")) + Expect(Version1.String()).To(Equal("v1")) + Expect(Version2.String()).To(Equal("v2")) + // check with unsupported version numbers from the wiki + Expect(VersionNumber(0x51303039).String()).To(Equal("gQUIC 9")) + Expect(VersionNumber(0x51303133).String()).To(Equal("gQUIC 13")) + Expect(VersionNumber(0x51303235).String()).To(Equal("gQUIC 25")) + Expect(VersionNumber(0x51303438).String()).To(Equal("gQUIC 48")) + Expect(VersionNumber(0x01234567).String()).To(Equal("0x1234567")) + }) + + It("recognizes supported versions", func() { + Expect(IsSupportedVersion(SupportedVersions, 0)).To(BeFalse()) + Expect(IsSupportedVersion(SupportedVersions, SupportedVersions[0])).To(BeTrue()) + Expect(IsSupportedVersion(SupportedVersions, SupportedVersions[len(SupportedVersions)-1])).To(BeTrue()) + }) + + Context("highest supported version", func() { + It("finds the supported version", func() { + supportedVersions := []VersionNumber{1, 2, 3} + other := []VersionNumber{6, 5, 4, 3} + ver, ok := ChooseSupportedVersion(supportedVersions, other) + Expect(ok).To(BeTrue()) + Expect(ver).To(Equal(VersionNumber(3))) + }) + + It("picks the preferred version", func() { + supportedVersions := []VersionNumber{2, 1, 3} + other := []VersionNumber{3, 6, 1, 8, 2, 10} + ver, ok := ChooseSupportedVersion(supportedVersions, other) + Expect(ok).To(BeTrue()) + Expect(ver).To(Equal(VersionNumber(2))) + }) + + It("says when no matching version was found", func() { + _, ok := ChooseSupportedVersion([]VersionNumber{1}, []VersionNumber{2}) + Expect(ok).To(BeFalse()) + }) + + It("handles empty inputs", func() { + _, ok := ChooseSupportedVersion([]VersionNumber{102, 101}, []VersionNumber{}) + Expect(ok).To(BeFalse()) + _, ok = ChooseSupportedVersion([]VersionNumber{}, []VersionNumber{1, 2}) + Expect(ok).To(BeFalse()) + _, ok = ChooseSupportedVersion([]VersionNumber{}, []VersionNumber{}) + Expect(ok).To(BeFalse()) + }) + }) + + Context("reserved versions", func() { + It("adds a greased version if passed an empty slice", func() { + greased := GetGreasedVersions([]VersionNumber{}) + Expect(greased).To(HaveLen(1)) + Expect(isReservedVersion(greased[0])).To(BeTrue()) + }) + + It("creates greased lists of version numbers", func() { + supported := []VersionNumber{10, 18, 29} + for _, v := range supported { + Expect(isReservedVersion(v)).To(BeFalse()) + } + var greasedVersionFirst, greasedVersionLast, greasedVersionMiddle int + // check that + // 1. the greased version sometimes appears first + // 2. the greased version sometimes appears in the middle + // 3. the greased version sometimes appears last + // 4. the supported versions are kept in order + for i := 0; i < 100; i++ { + greased := GetGreasedVersions(supported) + Expect(greased).To(HaveLen(4)) + var j int + for i, v := range greased { + if isReservedVersion(v) { + if i == 0 { + greasedVersionFirst++ + } + if i == len(greased)-1 { + greasedVersionLast++ + } + greasedVersionMiddle++ + continue + } + Expect(supported[j]).To(Equal(v)) + j++ + } + } + Expect(greasedVersionFirst).ToNot(BeZero()) + Expect(greasedVersionLast).ToNot(BeZero()) + Expect(greasedVersionMiddle).ToNot(BeZero()) + }) + }) +}) diff --git a/internal/quic-go/qerr/error_codes.go b/internal/quic-go/qerr/error_codes.go new file mode 100644 index 00000000..fe61d146 --- /dev/null +++ b/internal/quic-go/qerr/error_codes.go @@ -0,0 +1,88 @@ +package qerr + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/qtls" +) + +// TransportErrorCode is a QUIC transport error. +type TransportErrorCode uint64 + +// The error codes defined by QUIC +const ( + NoError TransportErrorCode = 0x0 + InternalError TransportErrorCode = 0x1 + ConnectionRefused TransportErrorCode = 0x2 + FlowControlError TransportErrorCode = 0x3 + StreamLimitError TransportErrorCode = 0x4 + StreamStateError TransportErrorCode = 0x5 + FinalSizeError TransportErrorCode = 0x6 + FrameEncodingError TransportErrorCode = 0x7 + TransportParameterError TransportErrorCode = 0x8 + ConnectionIDLimitError TransportErrorCode = 0x9 + ProtocolViolation TransportErrorCode = 0xa + InvalidToken TransportErrorCode = 0xb + ApplicationErrorErrorCode TransportErrorCode = 0xc + CryptoBufferExceeded TransportErrorCode = 0xd + KeyUpdateError TransportErrorCode = 0xe + AEADLimitReached TransportErrorCode = 0xf + NoViablePathError TransportErrorCode = 0x10 +) + +func (e TransportErrorCode) IsCryptoError() bool { + return e >= 0x100 && e < 0x200 +} + +// Message is a description of the error. +// It only returns a non-empty string for crypto errors. +func (e TransportErrorCode) Message() string { + if !e.IsCryptoError() { + return "" + } + return qtls.Alert(e - 0x100).Error() +} + +func (e TransportErrorCode) String() string { + switch e { + case NoError: + return "NO_ERROR" + case InternalError: + return "INTERNAL_ERROR" + case ConnectionRefused: + return "CONNECTION_REFUSED" + case FlowControlError: + return "FLOW_CONTROL_ERROR" + case StreamLimitError: + return "STREAM_LIMIT_ERROR" + case StreamStateError: + return "STREAM_STATE_ERROR" + case FinalSizeError: + return "FINAL_SIZE_ERROR" + case FrameEncodingError: + return "FRAME_ENCODING_ERROR" + case TransportParameterError: + return "TRANSPORT_PARAMETER_ERROR" + case ConnectionIDLimitError: + return "CONNECTION_ID_LIMIT_ERROR" + case ProtocolViolation: + return "PROTOCOL_VIOLATION" + case InvalidToken: + return "INVALID_TOKEN" + case ApplicationErrorErrorCode: + return "APPLICATION_ERROR" + case CryptoBufferExceeded: + return "CRYPTO_BUFFER_EXCEEDED" + case KeyUpdateError: + return "KEY_UPDATE_ERROR" + case AEADLimitReached: + return "AEAD_LIMIT_REACHED" + case NoViablePathError: + return "NO_VIABLE_PATH" + default: + if e.IsCryptoError() { + return fmt.Sprintf("CRYPTO_ERROR (%#x)", uint16(e)) + } + return fmt.Sprintf("unknown error code: %#x", uint16(e)) + } +} diff --git a/internal/quic-go/qerr/errorcodes_test.go b/internal/quic-go/qerr/errorcodes_test.go new file mode 100644 index 00000000..cfc6cd85 --- /dev/null +++ b/internal/quic-go/qerr/errorcodes_test.go @@ -0,0 +1,52 @@ +package qerr + +import ( + "go/ast" + "go/parser" + "go/token" + "path" + "runtime" + "strconv" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("error codes", func() { + // If this test breaks, you should run `go generate ./...` + It("has a string representation for every error code", func() { + // We parse the error code file, extract all constants, and verify that + // each of them has a string version. Go FTW! + _, thisfile, _, ok := runtime.Caller(0) + if !ok { + panic("Failed to get current frame") + } + filename := path.Join(path.Dir(thisfile), "error_codes.go") + fileAst, err := parser.ParseFile(token.NewFileSet(), filename, nil, 0) + Expect(err).NotTo(HaveOccurred()) + constSpecs := fileAst.Decls[2].(*ast.GenDecl).Specs + Expect(len(constSpecs)).To(BeNumerically(">", 4)) // at time of writing + for _, c := range constSpecs { + valString := c.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value + val, err := strconv.ParseInt(valString, 0, 64) + Expect(err).NotTo(HaveOccurred()) + Expect(TransportErrorCode(val).String()).ToNot(Equal("unknown error code")) + } + }) + + It("has a string representation for unknown error codes", func() { + Expect(TransportErrorCode(0x1337).String()).To(Equal("unknown error code: 0x1337")) + }) + + It("says if an error is a crypto error", func() { + for i := 0; i < 0x100; i++ { + Expect(TransportErrorCode(i).IsCryptoError()).To(BeFalse()) + } + for i := 0x100; i < 0x200; i++ { + Expect(TransportErrorCode(i).IsCryptoError()).To(BeTrue()) + } + for i := 0x200; i < 0x300; i++ { + Expect(TransportErrorCode(i).IsCryptoError()).To(BeFalse()) + } + }) +}) diff --git a/internal/quic-go/qerr/errors.go b/internal/quic-go/qerr/errors.go new file mode 100644 index 00000000..c2be1040 --- /dev/null +++ b/internal/quic-go/qerr/errors.go @@ -0,0 +1,124 @@ +package qerr + +import ( + "fmt" + "net" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +var ( + ErrHandshakeTimeout = &HandshakeTimeoutError{} + ErrIdleTimeout = &IdleTimeoutError{} +) + +type TransportError struct { + Remote bool + FrameType uint64 + ErrorCode TransportErrorCode + ErrorMessage string +} + +var _ error = &TransportError{} + +// NewCryptoError create a new TransportError instance for a crypto error +func NewCryptoError(tlsAlert uint8, errorMessage string) *TransportError { + return &TransportError{ + ErrorCode: 0x100 + TransportErrorCode(tlsAlert), + ErrorMessage: errorMessage, + } +} + +func (e *TransportError) Error() string { + str := e.ErrorCode.String() + if e.FrameType != 0 { + str += fmt.Sprintf(" (frame type: %#x)", e.FrameType) + } + msg := e.ErrorMessage + if len(msg) == 0 { + msg = e.ErrorCode.Message() + } + if len(msg) == 0 { + return str + } + return str + ": " + msg +} + +func (e *TransportError) Is(target error) bool { + return target == net.ErrClosed +} + +// An ApplicationErrorCode is an application-defined error code. +type ApplicationErrorCode uint64 + +func (e *ApplicationError) Is(target error) bool { + return target == net.ErrClosed +} + +// A StreamErrorCode is an error code used to cancel streams. +type StreamErrorCode uint64 + +type ApplicationError struct { + Remote bool + ErrorCode ApplicationErrorCode + ErrorMessage string +} + +var _ error = &ApplicationError{} + +func (e *ApplicationError) Error() string { + if len(e.ErrorMessage) == 0 { + return fmt.Sprintf("Application error %#x", e.ErrorCode) + } + return fmt.Sprintf("Application error %#x: %s", e.ErrorCode, e.ErrorMessage) +} + +type IdleTimeoutError struct{} + +var _ error = &IdleTimeoutError{} + +func (e *IdleTimeoutError) Timeout() bool { return true } +func (e *IdleTimeoutError) Temporary() bool { return false } +func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" } +func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed } + +type HandshakeTimeoutError struct{} + +var _ error = &HandshakeTimeoutError{} + +func (e *HandshakeTimeoutError) Timeout() bool { return true } +func (e *HandshakeTimeoutError) Temporary() bool { return false } +func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" } +func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed } + +// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version. +type VersionNegotiationError struct { + Ours []protocol.VersionNumber + Theirs []protocol.VersionNumber +} + +func (e *VersionNegotiationError) Error() string { + return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs) +} + +func (e *VersionNegotiationError) Is(target error) bool { + return target == net.ErrClosed +} + +// A StatelessResetError occurs when we receive a stateless reset. +type StatelessResetError struct { + Token protocol.StatelessResetToken +} + +var _ net.Error = &StatelessResetError{} + +func (e *StatelessResetError) Error() string { + return fmt.Sprintf("received a stateless reset with token %x", e.Token) +} + +func (e *StatelessResetError) Is(target error) bool { + return target == net.ErrClosed +} + +func (e *StatelessResetError) Timeout() bool { return false } +func (e *StatelessResetError) Temporary() bool { return true } diff --git a/internal/quic-go/qerr/errors_suite_test.go b/internal/quic-go/qerr/errors_suite_test.go new file mode 100644 index 00000000..749cdedc --- /dev/null +++ b/internal/quic-go/qerr/errors_suite_test.go @@ -0,0 +1,13 @@ +package qerr + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestErrorcodes(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Errors Suite") +} diff --git a/internal/quic-go/qerr/errors_test.go b/internal/quic-go/qerr/errors_test.go new file mode 100644 index 00000000..3b7f7cae --- /dev/null +++ b/internal/quic-go/qerr/errors_test.go @@ -0,0 +1,124 @@ +package qerr + +import ( + "errors" + "net" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("QUIC Errors", func() { + Context("Transport Errors", func() { + It("has a string representation", func() { + Expect((&TransportError{ + ErrorCode: FlowControlError, + ErrorMessage: "foobar", + }).Error()).To(Equal("FLOW_CONTROL_ERROR: foobar")) + }) + + It("has a string representation for empty error phrases", func() { + Expect((&TransportError{ErrorCode: FlowControlError}).Error()).To(Equal("FLOW_CONTROL_ERROR")) + }) + + It("includes the frame type, for errors without a message", func() { + Expect((&TransportError{ + ErrorCode: FlowControlError, + FrameType: 0x1337, + }).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337)")) + }) + + It("includes the frame type, for errors with a message", func() { + Expect((&TransportError{ + ErrorCode: FlowControlError, + FrameType: 0x1337, + ErrorMessage: "foobar", + }).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337): foobar")) + }) + + Context("crypto errors", func() { + It("has a string representation for errors with a message", func() { + err := NewCryptoError(0x42, "foobar") + Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x142): foobar")) + }) + + It("has a string representation for errors without a message", func() { + err := NewCryptoError(0x2a, "") + Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x12a): tls: bad certificate")) + }) + }) + }) + + Context("Application Errors", func() { + It("has a string representation for errors with a message", func() { + Expect((&ApplicationError{ + ErrorCode: 0x42, + ErrorMessage: "foobar", + }).Error()).To(Equal("Application error 0x42: foobar")) + }) + + It("has a string representation for errors without a message", func() { + Expect((&ApplicationError{ + ErrorCode: 0x42, + }).Error()).To(Equal("Application error 0x42")) + }) + }) + + Context("timeout errors", func() { + It("handshake timeouts", func() { + //nolint:gosimple // we need to assign to an interface here + var err error + err = &HandshakeTimeoutError{} + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(err.Error()).To(Equal("timeout: handshake did not complete in time")) + }) + + It("idle timeouts", func() { + //nolint:gosimple // we need to assign to an interface here + var err error + err = &IdleTimeoutError{} + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(err.Error()).To(Equal("timeout: no recent network activity")) + }) + }) + + Context("Version Negotiation errors", func() { + It("has a string representation", func() { + Expect((&VersionNegotiationError{ + Ours: []protocol.VersionNumber{2, 3}, + Theirs: []protocol.VersionNumber{4, 5, 6}, + }).Error()).To(Equal("no compatible QUIC version found (we support [0x2 0x3], server offered [0x4 0x5 0x6])")) + }) + }) + + Context("Stateless Reset errors", func() { + token := protocol.StatelessResetToken{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} + + It("has a string representation", func() { + Expect((&StatelessResetError{Token: token}).Error()).To(Equal("received a stateless reset with token 000102030405060708090a0b0c0d0e0f")) + }) + + It("is a net.Error", func() { + //nolint:gosimple // we need to assign to an interface here + var err error + err = &StatelessResetError{} + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeFalse()) + }) + }) + + It("says that errors are net.ErrClosed errors", func() { + Expect(errors.Is(&TransportError{}, net.ErrClosed)).To(BeTrue()) + Expect(errors.Is(&ApplicationError{}, net.ErrClosed)).To(BeTrue()) + Expect(errors.Is(&IdleTimeoutError{}, net.ErrClosed)).To(BeTrue()) + Expect(errors.Is(&HandshakeTimeoutError{}, net.ErrClosed)).To(BeTrue()) + Expect(errors.Is(&StatelessResetError{}, net.ErrClosed)).To(BeTrue()) + Expect(errors.Is(&VersionNegotiationError{}, net.ErrClosed)).To(BeTrue()) + }) +}) diff --git a/internal/quic-go/qlog/event.go b/internal/quic-go/qlog/event.go new file mode 100644 index 00000000..4d799090 --- /dev/null +++ b/internal/quic-go/qlog/event.go @@ -0,0 +1,529 @@ +package qlog + +import ( + "errors" + "fmt" + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go" + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + + "github.com/francoispqt/gojay" +) + +func milliseconds(dur time.Duration) float64 { return float64(dur.Nanoseconds()) / 1e6 } + +type eventDetails interface { + Category() category + Name() string + gojay.MarshalerJSONObject +} + +type event struct { + RelativeTime time.Duration + eventDetails +} + +var _ gojay.MarshalerJSONObject = event{} + +func (e event) IsNil() bool { return false } +func (e event) MarshalJSONObject(enc *gojay.Encoder) { + enc.Float64Key("time", milliseconds(e.RelativeTime)) + enc.StringKey("name", e.Category().String()+":"+e.Name()) + enc.ObjectKey("data", e.eventDetails) +} + +type versions []versionNumber + +func (v versions) IsNil() bool { return false } +func (v versions) MarshalJSONArray(enc *gojay.Encoder) { + for _, e := range v { + enc.AddString(e.String()) + } +} + +type rawInfo struct { + Length logging.ByteCount // full packet length, including header and AEAD authentication tag + PayloadLength logging.ByteCount // length of the packet payload, excluding AEAD tag +} + +func (i rawInfo) IsNil() bool { return false } +func (i rawInfo) MarshalJSONObject(enc *gojay.Encoder) { + enc.Uint64Key("length", uint64(i.Length)) + enc.Uint64KeyOmitEmpty("payload_length", uint64(i.PayloadLength)) +} + +type eventConnectionStarted struct { + SrcAddr *net.UDPAddr + DestAddr *net.UDPAddr + + SrcConnectionID protocol.ConnectionID + DestConnectionID protocol.ConnectionID +} + +var _ eventDetails = &eventConnectionStarted{} + +func (e eventConnectionStarted) Category() category { return categoryTransport } +func (e eventConnectionStarted) Name() string { return "connection_started" } +func (e eventConnectionStarted) IsNil() bool { return false } + +func (e eventConnectionStarted) MarshalJSONObject(enc *gojay.Encoder) { + if utils.IsIPv4(e.SrcAddr.IP) { + enc.StringKey("ip_version", "ipv4") + } else { + enc.StringKey("ip_version", "ipv6") + } + enc.StringKey("src_ip", e.SrcAddr.IP.String()) + enc.IntKey("src_port", e.SrcAddr.Port) + enc.StringKey("dst_ip", e.DestAddr.IP.String()) + enc.IntKey("dst_port", e.DestAddr.Port) + enc.StringKey("src_cid", connectionID(e.SrcConnectionID).String()) + enc.StringKey("dst_cid", connectionID(e.DestConnectionID).String()) +} + +type eventVersionNegotiated struct { + clientVersions, serverVersions []versionNumber + chosenVersion versionNumber +} + +func (e eventVersionNegotiated) Category() category { return categoryTransport } +func (e eventVersionNegotiated) Name() string { return "version_information" } +func (e eventVersionNegotiated) IsNil() bool { return false } + +func (e eventVersionNegotiated) MarshalJSONObject(enc *gojay.Encoder) { + if len(e.clientVersions) > 0 { + enc.ArrayKey("client_versions", versions(e.clientVersions)) + } + if len(e.serverVersions) > 0 { + enc.ArrayKey("server_versions", versions(e.serverVersions)) + } + enc.StringKey("chosen_version", e.chosenVersion.String()) +} + +type eventConnectionClosed struct { + e error +} + +func (e eventConnectionClosed) Category() category { return categoryTransport } +func (e eventConnectionClosed) Name() string { return "connection_closed" } +func (e eventConnectionClosed) IsNil() bool { return false } + +func (e eventConnectionClosed) MarshalJSONObject(enc *gojay.Encoder) { + var ( + statelessResetErr *quic.StatelessResetError + handshakeTimeoutErr *quic.HandshakeTimeoutError + idleTimeoutErr *quic.IdleTimeoutError + applicationErr *quic.ApplicationError + transportErr *quic.TransportError + versionNegotiationErr *quic.VersionNegotiationError + ) + switch { + case errors.As(e.e, &statelessResetErr): + enc.StringKey("owner", ownerRemote.String()) + enc.StringKey("trigger", "stateless_reset") + enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", statelessResetErr.Token)) + case errors.As(e.e, &handshakeTimeoutErr): + enc.StringKey("owner", ownerLocal.String()) + enc.StringKey("trigger", "handshake_timeout") + case errors.As(e.e, &idleTimeoutErr): + enc.StringKey("owner", ownerLocal.String()) + enc.StringKey("trigger", "idle_timeout") + case errors.As(e.e, &applicationErr): + owner := ownerLocal + if applicationErr.Remote { + owner = ownerRemote + } + enc.StringKey("owner", owner.String()) + enc.Uint64Key("application_code", uint64(applicationErr.ErrorCode)) + enc.StringKey("reason", applicationErr.ErrorMessage) + case errors.As(e.e, &transportErr): + owner := ownerLocal + if transportErr.Remote { + owner = ownerRemote + } + enc.StringKey("owner", owner.String()) + enc.StringKey("connection_code", transportError(transportErr.ErrorCode).String()) + enc.StringKey("reason", transportErr.ErrorMessage) + case errors.As(e.e, &versionNegotiationErr): + enc.StringKey("owner", ownerRemote.String()) + enc.StringKey("trigger", "version_negotiation") + } +} + +type eventPacketSent struct { + Header packetHeader + Length logging.ByteCount + PayloadLength logging.ByteCount + Frames frames + IsCoalesced bool + Trigger string +} + +var _ eventDetails = eventPacketSent{} + +func (e eventPacketSent) Category() category { return categoryTransport } +func (e eventPacketSent) Name() string { return "packet_sent" } +func (e eventPacketSent) IsNil() bool { return false } + +func (e eventPacketSent) MarshalJSONObject(enc *gojay.Encoder) { + enc.ObjectKey("header", e.Header) + enc.ObjectKey("raw", rawInfo{Length: e.Length, PayloadLength: e.PayloadLength}) + enc.ArrayKeyOmitEmpty("frames", e.Frames) + enc.BoolKeyOmitEmpty("is_coalesced", e.IsCoalesced) + enc.StringKeyOmitEmpty("trigger", e.Trigger) +} + +type eventPacketReceived struct { + Header packetHeader + Length logging.ByteCount + PayloadLength logging.ByteCount + Frames frames + IsCoalesced bool + Trigger string +} + +var _ eventDetails = eventPacketReceived{} + +func (e eventPacketReceived) Category() category { return categoryTransport } +func (e eventPacketReceived) Name() string { return "packet_received" } +func (e eventPacketReceived) IsNil() bool { return false } + +func (e eventPacketReceived) MarshalJSONObject(enc *gojay.Encoder) { + enc.ObjectKey("header", e.Header) + enc.ObjectKey("raw", rawInfo{Length: e.Length, PayloadLength: e.PayloadLength}) + enc.ArrayKeyOmitEmpty("frames", e.Frames) + enc.BoolKeyOmitEmpty("is_coalesced", e.IsCoalesced) + enc.StringKeyOmitEmpty("trigger", e.Trigger) +} + +type eventRetryReceived struct { + Header packetHeader +} + +func (e eventRetryReceived) Category() category { return categoryTransport } +func (e eventRetryReceived) Name() string { return "packet_received" } +func (e eventRetryReceived) IsNil() bool { return false } + +func (e eventRetryReceived) MarshalJSONObject(enc *gojay.Encoder) { + enc.ObjectKey("header", e.Header) +} + +type eventVersionNegotiationReceived struct { + Header packetHeader + SupportedVersions []versionNumber +} + +func (e eventVersionNegotiationReceived) Category() category { return categoryTransport } +func (e eventVersionNegotiationReceived) Name() string { return "packet_received" } +func (e eventVersionNegotiationReceived) IsNil() bool { return false } + +func (e eventVersionNegotiationReceived) MarshalJSONObject(enc *gojay.Encoder) { + enc.ObjectKey("header", e.Header) + enc.ArrayKey("supported_versions", versions(e.SupportedVersions)) +} + +type eventPacketBuffered struct { + PacketType logging.PacketType +} + +func (e eventPacketBuffered) Category() category { return categoryTransport } +func (e eventPacketBuffered) Name() string { return "packet_buffered" } +func (e eventPacketBuffered) IsNil() bool { return false } + +func (e eventPacketBuffered) MarshalJSONObject(enc *gojay.Encoder) { + //nolint:gosimple + enc.ObjectKey("header", packetHeaderWithType{PacketType: e.PacketType}) + enc.StringKey("trigger", "keys_unavailable") +} + +type eventPacketDropped struct { + PacketType logging.PacketType + PacketSize protocol.ByteCount + Trigger packetDropReason +} + +func (e eventPacketDropped) Category() category { return categoryTransport } +func (e eventPacketDropped) Name() string { return "packet_dropped" } +func (e eventPacketDropped) IsNil() bool { return false } + +func (e eventPacketDropped) MarshalJSONObject(enc *gojay.Encoder) { + enc.ObjectKey("header", packetHeaderWithType{PacketType: e.PacketType}) + enc.ObjectKey("raw", rawInfo{Length: e.PacketSize}) + enc.StringKey("trigger", e.Trigger.String()) +} + +type metrics struct { + MinRTT time.Duration + SmoothedRTT time.Duration + LatestRTT time.Duration + RTTVariance time.Duration + + CongestionWindow protocol.ByteCount + BytesInFlight protocol.ByteCount + PacketsInFlight int +} + +type eventMetricsUpdated struct { + Last *metrics + Current *metrics +} + +func (e eventMetricsUpdated) Category() category { return categoryRecovery } +func (e eventMetricsUpdated) Name() string { return "metrics_updated" } +func (e eventMetricsUpdated) IsNil() bool { return false } + +func (e eventMetricsUpdated) MarshalJSONObject(enc *gojay.Encoder) { + if e.Last == nil || e.Last.MinRTT != e.Current.MinRTT { + enc.FloatKey("min_rtt", milliseconds(e.Current.MinRTT)) + } + if e.Last == nil || e.Last.SmoothedRTT != e.Current.SmoothedRTT { + enc.FloatKey("smoothed_rtt", milliseconds(e.Current.SmoothedRTT)) + } + if e.Last == nil || e.Last.LatestRTT != e.Current.LatestRTT { + enc.FloatKey("latest_rtt", milliseconds(e.Current.LatestRTT)) + } + if e.Last == nil || e.Last.RTTVariance != e.Current.RTTVariance { + enc.FloatKey("rtt_variance", milliseconds(e.Current.RTTVariance)) + } + + if e.Last == nil || e.Last.CongestionWindow != e.Current.CongestionWindow { + enc.Uint64Key("congestion_window", uint64(e.Current.CongestionWindow)) + } + if e.Last == nil || e.Last.BytesInFlight != e.Current.BytesInFlight { + enc.Uint64Key("bytes_in_flight", uint64(e.Current.BytesInFlight)) + } + if e.Last == nil || e.Last.PacketsInFlight != e.Current.PacketsInFlight { + enc.Uint64KeyOmitEmpty("packets_in_flight", uint64(e.Current.PacketsInFlight)) + } +} + +type eventUpdatedPTO struct { + Value uint32 +} + +func (e eventUpdatedPTO) Category() category { return categoryRecovery } +func (e eventUpdatedPTO) Name() string { return "metrics_updated" } +func (e eventUpdatedPTO) IsNil() bool { return false } + +func (e eventUpdatedPTO) MarshalJSONObject(enc *gojay.Encoder) { + enc.Uint32Key("pto_count", e.Value) +} + +type eventPacketLost struct { + PacketType logging.PacketType + PacketNumber protocol.PacketNumber + Trigger packetLossReason +} + +func (e eventPacketLost) Category() category { return categoryRecovery } +func (e eventPacketLost) Name() string { return "packet_lost" } +func (e eventPacketLost) IsNil() bool { return false } + +func (e eventPacketLost) MarshalJSONObject(enc *gojay.Encoder) { + enc.ObjectKey("header", packetHeaderWithTypeAndPacketNumber{ + PacketType: e.PacketType, + PacketNumber: e.PacketNumber, + }) + enc.StringKey("trigger", e.Trigger.String()) +} + +type eventKeyUpdated struct { + Trigger keyUpdateTrigger + KeyType keyType + Generation protocol.KeyPhase + // we don't log the keys here, so we don't need `old` and `new`. +} + +func (e eventKeyUpdated) Category() category { return categorySecurity } +func (e eventKeyUpdated) Name() string { return "key_updated" } +func (e eventKeyUpdated) IsNil() bool { return false } + +func (e eventKeyUpdated) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("trigger", e.Trigger.String()) + enc.StringKey("key_type", e.KeyType.String()) + if e.KeyType == keyTypeClient1RTT || e.KeyType == keyTypeServer1RTT { + enc.Uint64Key("generation", uint64(e.Generation)) + } +} + +type eventKeyRetired struct { + KeyType keyType + Generation protocol.KeyPhase +} + +func (e eventKeyRetired) Category() category { return categorySecurity } +func (e eventKeyRetired) Name() string { return "key_retired" } +func (e eventKeyRetired) IsNil() bool { return false } + +func (e eventKeyRetired) MarshalJSONObject(enc *gojay.Encoder) { + if e.KeyType != keyTypeClient1RTT && e.KeyType != keyTypeServer1RTT { + enc.StringKey("trigger", "tls") + } + enc.StringKey("key_type", e.KeyType.String()) + if e.KeyType == keyTypeClient1RTT || e.KeyType == keyTypeServer1RTT { + enc.Uint64Key("generation", uint64(e.Generation)) + } +} + +type eventTransportParameters struct { + Restore bool + Owner owner + SentBy protocol.Perspective + + OriginalDestinationConnectionID protocol.ConnectionID + InitialSourceConnectionID protocol.ConnectionID + RetrySourceConnectionID *protocol.ConnectionID + + StatelessResetToken *protocol.StatelessResetToken + DisableActiveMigration bool + MaxIdleTimeout time.Duration + MaxUDPPayloadSize protocol.ByteCount + AckDelayExponent uint8 + MaxAckDelay time.Duration + ActiveConnectionIDLimit uint64 + + InitialMaxData protocol.ByteCount + InitialMaxStreamDataBidiLocal protocol.ByteCount + InitialMaxStreamDataBidiRemote protocol.ByteCount + InitialMaxStreamDataUni protocol.ByteCount + InitialMaxStreamsBidi int64 + InitialMaxStreamsUni int64 + + PreferredAddress *preferredAddress + + MaxDatagramFrameSize protocol.ByteCount +} + +func (e eventTransportParameters) Category() category { return categoryTransport } +func (e eventTransportParameters) Name() string { + if e.Restore { + return "parameters_restored" + } + return "parameters_set" +} +func (e eventTransportParameters) IsNil() bool { return false } + +func (e eventTransportParameters) MarshalJSONObject(enc *gojay.Encoder) { + if !e.Restore { + enc.StringKey("owner", e.Owner.String()) + if e.SentBy == protocol.PerspectiveServer { + enc.StringKey("original_destination_connection_id", connectionID(e.OriginalDestinationConnectionID).String()) + if e.StatelessResetToken != nil { + enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", e.StatelessResetToken[:])) + } + if e.RetrySourceConnectionID != nil { + enc.StringKey("retry_source_connection_id", connectionID(*e.RetrySourceConnectionID).String()) + } + } + enc.StringKey("initial_source_connection_id", connectionID(e.InitialSourceConnectionID).String()) + } + enc.BoolKey("disable_active_migration", e.DisableActiveMigration) + enc.FloatKeyOmitEmpty("max_idle_timeout", milliseconds(e.MaxIdleTimeout)) + enc.Int64KeyNullEmpty("max_udp_payload_size", int64(e.MaxUDPPayloadSize)) + enc.Uint8KeyOmitEmpty("ack_delay_exponent", e.AckDelayExponent) + enc.FloatKeyOmitEmpty("max_ack_delay", milliseconds(e.MaxAckDelay)) + enc.Uint64KeyOmitEmpty("active_connection_id_limit", e.ActiveConnectionIDLimit) + + enc.Int64KeyOmitEmpty("initial_max_data", int64(e.InitialMaxData)) + enc.Int64KeyOmitEmpty("initial_max_stream_data_bidi_local", int64(e.InitialMaxStreamDataBidiLocal)) + enc.Int64KeyOmitEmpty("initial_max_stream_data_bidi_remote", int64(e.InitialMaxStreamDataBidiRemote)) + enc.Int64KeyOmitEmpty("initial_max_stream_data_uni", int64(e.InitialMaxStreamDataUni)) + enc.Int64KeyOmitEmpty("initial_max_streams_bidi", e.InitialMaxStreamsBidi) + enc.Int64KeyOmitEmpty("initial_max_streams_uni", e.InitialMaxStreamsUni) + + if e.PreferredAddress != nil { + enc.ObjectKey("preferred_address", e.PreferredAddress) + } + if e.MaxDatagramFrameSize != protocol.InvalidByteCount { + enc.Int64Key("max_datagram_frame_size", int64(e.MaxDatagramFrameSize)) + } +} + +type preferredAddress struct { + IPv4, IPv6 net.IP + PortV4, PortV6 uint16 + ConnectionID protocol.ConnectionID + StatelessResetToken protocol.StatelessResetToken +} + +var _ gojay.MarshalerJSONObject = &preferredAddress{} + +func (a preferredAddress) IsNil() bool { return false } +func (a preferredAddress) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("ip_v4", a.IPv4.String()) + enc.Uint16Key("port_v4", a.PortV4) + enc.StringKey("ip_v6", a.IPv6.String()) + enc.Uint16Key("port_v6", a.PortV6) + enc.StringKey("connection_id", connectionID(a.ConnectionID).String()) + enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", a.StatelessResetToken)) +} + +type eventLossTimerSet struct { + TimerType timerType + EncLevel protocol.EncryptionLevel + Delta time.Duration +} + +func (e eventLossTimerSet) Category() category { return categoryRecovery } +func (e eventLossTimerSet) Name() string { return "loss_timer_updated" } +func (e eventLossTimerSet) IsNil() bool { return false } + +func (e eventLossTimerSet) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("event_type", "set") + enc.StringKey("timer_type", e.TimerType.String()) + enc.StringKey("packet_number_space", encLevelToPacketNumberSpace(e.EncLevel)) + enc.Float64Key("delta", milliseconds(e.Delta)) +} + +type eventLossTimerExpired struct { + TimerType timerType + EncLevel protocol.EncryptionLevel +} + +func (e eventLossTimerExpired) Category() category { return categoryRecovery } +func (e eventLossTimerExpired) Name() string { return "loss_timer_updated" } +func (e eventLossTimerExpired) IsNil() bool { return false } + +func (e eventLossTimerExpired) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("event_type", "expired") + enc.StringKey("timer_type", e.TimerType.String()) + enc.StringKey("packet_number_space", encLevelToPacketNumberSpace(e.EncLevel)) +} + +type eventLossTimerCanceled struct{} + +func (e eventLossTimerCanceled) Category() category { return categoryRecovery } +func (e eventLossTimerCanceled) Name() string { return "loss_timer_updated" } +func (e eventLossTimerCanceled) IsNil() bool { return false } + +func (e eventLossTimerCanceled) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("event_type", "cancelled") +} + +type eventCongestionStateUpdated struct { + state congestionState +} + +func (e eventCongestionStateUpdated) Category() category { return categoryRecovery } +func (e eventCongestionStateUpdated) Name() string { return "congestion_state_updated" } +func (e eventCongestionStateUpdated) IsNil() bool { return false } + +func (e eventCongestionStateUpdated) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("new", e.state.String()) +} + +type eventGeneric struct { + name string + msg string +} + +func (e eventGeneric) Category() category { return categoryTransport } +func (e eventGeneric) Name() string { return e.name } +func (e eventGeneric) IsNil() bool { return false } + +func (e eventGeneric) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("details", e.msg) +} diff --git a/internal/quic-go/qlog/event_test.go b/internal/quic-go/qlog/event_test.go new file mode 100644 index 00000000..4248e69e --- /dev/null +++ b/internal/quic-go/qlog/event_test.go @@ -0,0 +1,43 @@ +package qlog + +import ( + "bytes" + "encoding/json" + "time" + + "github.com/francoispqt/gojay" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type mevent struct{} + +var _ eventDetails = mevent{} + +func (mevent) Category() category { return categoryConnectivity } +func (mevent) Name() string { return "mevent" } +func (mevent) IsNil() bool { return false } +func (mevent) MarshalJSONObject(enc *gojay.Encoder) { enc.StringKey("event", "details") } + +var _ = Describe("Events", func() { + It("marshals the fields before the event details", func() { + buf := &bytes.Buffer{} + enc := gojay.NewEncoder(buf) + Expect(enc.Encode(event{ + RelativeTime: 1337 * time.Microsecond, + eventDetails: mevent{}, + })).To(Succeed()) + + var decoded interface{} + Expect(json.Unmarshal(buf.Bytes(), &decoded)).To(Succeed()) + Expect(decoded).To(HaveLen(3)) + + Expect(decoded).To(HaveKeyWithValue("time", 1.337)) + Expect(decoded).To(HaveKeyWithValue("name", "connectivity:mevent")) + Expect(decoded).To(HaveKey("data")) + data := decoded.(map[string]interface{})["data"].(map[string]interface{}) + Expect(data).To(HaveLen(1)) + Expect(data).To(HaveKeyWithValue("event", "details")) + }) +}) diff --git a/internal/quic-go/qlog/frame.go b/internal/quic-go/qlog/frame.go new file mode 100644 index 00000000..c6e58253 --- /dev/null +++ b/internal/quic-go/qlog/frame.go @@ -0,0 +1,227 @@ +package qlog + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/wire" + + "github.com/francoispqt/gojay" +) + +type frame struct { + Frame logging.Frame +} + +var _ gojay.MarshalerJSONObject = frame{} + +var _ gojay.MarshalerJSONArray = frames{} + +func (f frame) MarshalJSONObject(enc *gojay.Encoder) { + switch frame := f.Frame.(type) { + case *logging.PingFrame: + marshalPingFrame(enc, frame) + case *logging.AckFrame: + marshalAckFrame(enc, frame) + case *logging.ResetStreamFrame: + marshalResetStreamFrame(enc, frame) + case *logging.StopSendingFrame: + marshalStopSendingFrame(enc, frame) + case *logging.CryptoFrame: + marshalCryptoFrame(enc, frame) + case *logging.NewTokenFrame: + marshalNewTokenFrame(enc, frame) + case *logging.StreamFrame: + marshalStreamFrame(enc, frame) + case *logging.MaxDataFrame: + marshalMaxDataFrame(enc, frame) + case *logging.MaxStreamDataFrame: + marshalMaxStreamDataFrame(enc, frame) + case *logging.MaxStreamsFrame: + marshalMaxStreamsFrame(enc, frame) + case *logging.DataBlockedFrame: + marshalDataBlockedFrame(enc, frame) + case *logging.StreamDataBlockedFrame: + marshalStreamDataBlockedFrame(enc, frame) + case *logging.StreamsBlockedFrame: + marshalStreamsBlockedFrame(enc, frame) + case *logging.NewConnectionIDFrame: + marshalNewConnectionIDFrame(enc, frame) + case *logging.RetireConnectionIDFrame: + marshalRetireConnectionIDFrame(enc, frame) + case *logging.PathChallengeFrame: + marshalPathChallengeFrame(enc, frame) + case *logging.PathResponseFrame: + marshalPathResponseFrame(enc, frame) + case *logging.ConnectionCloseFrame: + marshalConnectionCloseFrame(enc, frame) + case *logging.HandshakeDoneFrame: + marshalHandshakeDoneFrame(enc, frame) + case *logging.DatagramFrame: + marshalDatagramFrame(enc, frame) + default: + panic("unknown frame type") + } +} + +func (f frame) IsNil() bool { return false } + +type frames []frame + +func (fs frames) IsNil() bool { return fs == nil } +func (fs frames) MarshalJSONArray(enc *gojay.Encoder) { + for _, f := range fs { + enc.Object(f) + } +} + +func marshalPingFrame(enc *gojay.Encoder, _ *wire.PingFrame) { + enc.StringKey("frame_type", "ping") +} + +type ackRanges []wire.AckRange + +func (ars ackRanges) MarshalJSONArray(enc *gojay.Encoder) { + for _, r := range ars { + enc.Array(ackRange(r)) + } +} + +func (ars ackRanges) IsNil() bool { return false } + +type ackRange wire.AckRange + +func (ar ackRange) MarshalJSONArray(enc *gojay.Encoder) { + enc.AddInt64(int64(ar.Smallest)) + if ar.Smallest != ar.Largest { + enc.AddInt64(int64(ar.Largest)) + } +} + +func (ar ackRange) IsNil() bool { return false } + +func marshalAckFrame(enc *gojay.Encoder, f *logging.AckFrame) { + enc.StringKey("frame_type", "ack") + enc.FloatKeyOmitEmpty("ack_delay", milliseconds(f.DelayTime)) + enc.ArrayKey("acked_ranges", ackRanges(f.AckRanges)) + if hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0; hasECN { + enc.Uint64Key("ect0", f.ECT0) + enc.Uint64Key("ect1", f.ECT1) + enc.Uint64Key("ce", f.ECNCE) + } +} + +func marshalResetStreamFrame(enc *gojay.Encoder, f *logging.ResetStreamFrame) { + enc.StringKey("frame_type", "reset_stream") + enc.Int64Key("stream_id", int64(f.StreamID)) + enc.Int64Key("error_code", int64(f.ErrorCode)) + enc.Int64Key("final_size", int64(f.FinalSize)) +} + +func marshalStopSendingFrame(enc *gojay.Encoder, f *logging.StopSendingFrame) { + enc.StringKey("frame_type", "stop_sending") + enc.Int64Key("stream_id", int64(f.StreamID)) + enc.Int64Key("error_code", int64(f.ErrorCode)) +} + +func marshalCryptoFrame(enc *gojay.Encoder, f *logging.CryptoFrame) { + enc.StringKey("frame_type", "crypto") + enc.Int64Key("offset", int64(f.Offset)) + enc.Int64Key("length", int64(f.Length)) +} + +func marshalNewTokenFrame(enc *gojay.Encoder, f *logging.NewTokenFrame) { + enc.StringKey("frame_type", "new_token") + enc.ObjectKey("token", &token{Raw: f.Token}) +} + +func marshalStreamFrame(enc *gojay.Encoder, f *logging.StreamFrame) { + enc.StringKey("frame_type", "stream") + enc.Int64Key("stream_id", int64(f.StreamID)) + enc.Int64Key("offset", int64(f.Offset)) + enc.IntKey("length", int(f.Length)) + enc.BoolKeyOmitEmpty("fin", f.Fin) +} + +func marshalMaxDataFrame(enc *gojay.Encoder, f *logging.MaxDataFrame) { + enc.StringKey("frame_type", "max_data") + enc.Int64Key("maximum", int64(f.MaximumData)) +} + +func marshalMaxStreamDataFrame(enc *gojay.Encoder, f *logging.MaxStreamDataFrame) { + enc.StringKey("frame_type", "max_stream_data") + enc.Int64Key("stream_id", int64(f.StreamID)) + enc.Int64Key("maximum", int64(f.MaximumStreamData)) +} + +func marshalMaxStreamsFrame(enc *gojay.Encoder, f *logging.MaxStreamsFrame) { + enc.StringKey("frame_type", "max_streams") + enc.StringKey("stream_type", streamType(f.Type).String()) + enc.Int64Key("maximum", int64(f.MaxStreamNum)) +} + +func marshalDataBlockedFrame(enc *gojay.Encoder, f *logging.DataBlockedFrame) { + enc.StringKey("frame_type", "data_blocked") + enc.Int64Key("limit", int64(f.MaximumData)) +} + +func marshalStreamDataBlockedFrame(enc *gojay.Encoder, f *logging.StreamDataBlockedFrame) { + enc.StringKey("frame_type", "stream_data_blocked") + enc.Int64Key("stream_id", int64(f.StreamID)) + enc.Int64Key("limit", int64(f.MaximumStreamData)) +} + +func marshalStreamsBlockedFrame(enc *gojay.Encoder, f *logging.StreamsBlockedFrame) { + enc.StringKey("frame_type", "streams_blocked") + enc.StringKey("stream_type", streamType(f.Type).String()) + enc.Int64Key("limit", int64(f.StreamLimit)) +} + +func marshalNewConnectionIDFrame(enc *gojay.Encoder, f *logging.NewConnectionIDFrame) { + enc.StringKey("frame_type", "new_connection_id") + enc.Int64Key("sequence_number", int64(f.SequenceNumber)) + enc.Int64Key("retire_prior_to", int64(f.RetirePriorTo)) + enc.IntKey("length", f.ConnectionID.Len()) + enc.StringKey("connection_id", connectionID(f.ConnectionID).String()) + enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", f.StatelessResetToken)) +} + +func marshalRetireConnectionIDFrame(enc *gojay.Encoder, f *logging.RetireConnectionIDFrame) { + enc.StringKey("frame_type", "retire_connection_id") + enc.Int64Key("sequence_number", int64(f.SequenceNumber)) +} + +func marshalPathChallengeFrame(enc *gojay.Encoder, f *logging.PathChallengeFrame) { + enc.StringKey("frame_type", "path_challenge") + enc.StringKey("data", fmt.Sprintf("%x", f.Data[:])) +} + +func marshalPathResponseFrame(enc *gojay.Encoder, f *logging.PathResponseFrame) { + enc.StringKey("frame_type", "path_response") + enc.StringKey("data", fmt.Sprintf("%x", f.Data[:])) +} + +func marshalConnectionCloseFrame(enc *gojay.Encoder, f *logging.ConnectionCloseFrame) { + errorSpace := "transport" + if f.IsApplicationError { + errorSpace = "application" + } + enc.StringKey("frame_type", "connection_close") + enc.StringKey("error_space", errorSpace) + if errName := transportError(f.ErrorCode).String(); len(errName) > 0 { + enc.StringKey("error_code", errName) + } else { + enc.Uint64Key("error_code", f.ErrorCode) + } + enc.Uint64Key("raw_error_code", f.ErrorCode) + enc.StringKey("reason", f.ReasonPhrase) +} + +func marshalHandshakeDoneFrame(enc *gojay.Encoder, _ *logging.HandshakeDoneFrame) { + enc.StringKey("frame_type", "handshake_done") +} + +func marshalDatagramFrame(enc *gojay.Encoder, f *logging.DatagramFrame) { + enc.StringKey("frame_type", "datagram") + enc.Int64Key("length", int64(f.Length)) +} diff --git a/internal/quic-go/qlog/frame_test.go b/internal/quic-go/qlog/frame_test.go new file mode 100644 index 00000000..5e78d8ef --- /dev/null +++ b/internal/quic-go/qlog/frame_test.go @@ -0,0 +1,377 @@ +package qlog + +import ( + "bytes" + "encoding/json" + "time" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + + "github.com/francoispqt/gojay" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Frames", func() { + check := func(f logging.Frame, expected map[string]interface{}) { + buf := &bytes.Buffer{} + enc := gojay.NewEncoder(buf) + ExpectWithOffset(1, enc.Encode(frame{Frame: f})).To(Succeed()) + data := buf.Bytes() + ExpectWithOffset(1, json.Valid(data)).To(BeTrue()) + checkEncoding(data, expected) + } + + It("marshals PING frames", func() { + check( + &logging.PingFrame{}, + map[string]interface{}{ + "frame_type": "ping", + }, + ) + }) + + It("marshals ACK frames with a range acknowledging a single packet", func() { + check( + &logging.AckFrame{ + DelayTime: 86 * time.Millisecond, + AckRanges: []logging.AckRange{{Smallest: 120, Largest: 120}}, + }, + map[string]interface{}{ + "frame_type": "ack", + "ack_delay": 86, + "acked_ranges": [][]float64{{120}}, + }, + ) + }) + + It("marshals ACK frames without a delay", func() { + check( + &logging.AckFrame{ + AckRanges: []logging.AckRange{{Smallest: 120, Largest: 120}}, + }, + map[string]interface{}{ + "frame_type": "ack", + "acked_ranges": [][]float64{{120}}, + }, + ) + }) + + It("marshals ACK frames with ECN counts", func() { + check( + &logging.AckFrame{ + AckRanges: []logging.AckRange{{Smallest: 120, Largest: 120}}, + ECT0: 10, + ECT1: 100, + ECNCE: 1000, + }, + map[string]interface{}{ + "frame_type": "ack", + "acked_ranges": [][]float64{{120}}, + "ect0": 10, + "ect1": 100, + "ce": 1000, + }, + ) + }) + + It("marshals ACK frames with a range acknowledging ranges of packets", func() { + check( + &logging.AckFrame{ + DelayTime: 86 * time.Millisecond, + AckRanges: []logging.AckRange{ + {Smallest: 5, Largest: 50}, + {Smallest: 100, Largest: 120}, + }, + }, + map[string]interface{}{ + "frame_type": "ack", + "ack_delay": 86, + "acked_ranges": [][]float64{ + {5, 50}, + {100, 120}, + }, + }, + ) + }) + + It("marshals RESET_STREAM frames", func() { + check( + &logging.ResetStreamFrame{ + StreamID: 987, + FinalSize: 1234, + ErrorCode: 42, + }, + map[string]interface{}{ + "frame_type": "reset_stream", + "stream_id": 987, + "error_code": 42, + "final_size": 1234, + }, + ) + }) + + It("marshals STOP_SENDING frames", func() { + check( + &logging.StopSendingFrame{ + StreamID: 987, + ErrorCode: 42, + }, + map[string]interface{}{ + "frame_type": "stop_sending", + "stream_id": 987, + "error_code": 42, + }, + ) + }) + + It("marshals CRYPTO frames", func() { + check( + &logging.CryptoFrame{ + Offset: 1337, + Length: 6, + }, + map[string]interface{}{ + "frame_type": "crypto", + "offset": 1337, + "length": 6, + }, + ) + }) + + It("marshals NEW_TOKEN frames", func() { + check( + &logging.NewTokenFrame{ + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + }, + map[string]interface{}{ + "frame_type": "new_token", + "token": map[string]interface{}{"data": "deadbeef"}, + }, + ) + }) + + It("marshals STREAM frames with FIN", func() { + check( + &logging.StreamFrame{ + StreamID: 42, + Offset: 1337, + Fin: true, + Length: 9876, + }, + map[string]interface{}{ + "frame_type": "stream", + "stream_id": 42, + "offset": 1337, + "fin": true, + "length": 9876, + }, + ) + }) + + It("marshals STREAM frames without FIN", func() { + check( + &logging.StreamFrame{ + StreamID: 42, + Offset: 1337, + Length: 3, + }, + map[string]interface{}{ + "frame_type": "stream", + "stream_id": 42, + "offset": 1337, + "length": 3, + }, + ) + }) + + It("marshals MAX_DATA frames", func() { + check( + &logging.MaxDataFrame{ + MaximumData: 1337, + }, + map[string]interface{}{ + "frame_type": "max_data", + "maximum": 1337, + }, + ) + }) + + It("marshals MAX_STREAM_DATA frames", func() { + check( + &logging.MaxStreamDataFrame{ + StreamID: 1234, + MaximumStreamData: 1337, + }, + map[string]interface{}{ + "frame_type": "max_stream_data", + "stream_id": 1234, + "maximum": 1337, + }, + ) + }) + + It("marshals MAX_STREAMS frames", func() { + check( + &logging.MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: 42, + }, + map[string]interface{}{ + "frame_type": "max_streams", + "stream_type": "bidirectional", + "maximum": 42, + }, + ) + }) + + It("marshals DATA_BLOCKED frames", func() { + check( + &logging.DataBlockedFrame{ + MaximumData: 1337, + }, + map[string]interface{}{ + "frame_type": "data_blocked", + "limit": 1337, + }, + ) + }) + + It("marshals STREAM_DATA_BLOCKED frames", func() { + check( + &logging.StreamDataBlockedFrame{ + StreamID: 42, + MaximumStreamData: 1337, + }, + map[string]interface{}{ + "frame_type": "stream_data_blocked", + "stream_id": 42, + "limit": 1337, + }, + ) + }) + + It("marshals STREAMS_BLOCKED frames", func() { + check( + &logging.StreamsBlockedFrame{ + Type: protocol.StreamTypeUni, + StreamLimit: 123, + }, + map[string]interface{}{ + "frame_type": "streams_blocked", + "stream_type": "unidirectional", + "limit": 123, + }, + ) + }) + + It("marshals NEW_CONNECTION_ID frames", func() { + check( + &logging.NewConnectionIDFrame{ + SequenceNumber: 42, + RetirePriorTo: 24, + ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, + }, + map[string]interface{}{ + "frame_type": "new_connection_id", + "sequence_number": 42, + "retire_prior_to": 24, + "length": 4, + "connection_id": "deadbeef", + "stateless_reset_token": "000102030405060708090a0b0c0d0e0f", + }, + ) + }) + + It("marshals RETIRE_CONNECTION_ID frames", func() { + check( + &logging.RetireConnectionIDFrame{ + SequenceNumber: 1337, + }, + map[string]interface{}{ + "frame_type": "retire_connection_id", + "sequence_number": 1337, + }, + ) + }) + + It("marshals PATH_CHALLENGE frames", func() { + check( + &logging.PathChallengeFrame{ + Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xc0, 0x01}, + }, + map[string]interface{}{ + "frame_type": "path_challenge", + "data": "deadbeefcafec001", + }, + ) + }) + + It("marshals PATH_RESPONSE frames", func() { + check( + &logging.PathResponseFrame{ + Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xc0, 0x01}, + }, + map[string]interface{}{ + "frame_type": "path_response", + "data": "deadbeefcafec001", + }, + ) + }) + + It("marshals CONNECTION_CLOSE frames, for application error codes", func() { + check( + &logging.ConnectionCloseFrame{ + IsApplicationError: true, + ErrorCode: 1337, + ReasonPhrase: "lorem ipsum", + }, + map[string]interface{}{ + "frame_type": "connection_close", + "error_space": "application", + "error_code": 1337, + "raw_error_code": 1337, + "reason": "lorem ipsum", + }, + ) + }) + + It("marshals CONNECTION_CLOSE frames, for transport error codes", func() { + check( + &logging.ConnectionCloseFrame{ + ErrorCode: uint64(qerr.FlowControlError), + ReasonPhrase: "lorem ipsum", + }, + map[string]interface{}{ + "frame_type": "connection_close", + "error_space": "transport", + "error_code": "flow_control_error", + "raw_error_code": int(qerr.FlowControlError), + "reason": "lorem ipsum", + }, + ) + }) + + It("marshals HANDSHAKE_DONE frames", func() { + check( + &logging.HandshakeDoneFrame{}, + map[string]interface{}{ + "frame_type": "handshake_done", + }, + ) + }) + + It("marshals DATAGRAM frames", func() { + check( + &logging.DatagramFrame{Length: 1337}, + map[string]interface{}{ + "frame_type": "datagram", + "length": 1337, + }, + ) + }) +}) diff --git a/internal/quic-go/qlog/packet_header.go b/internal/quic-go/qlog/packet_header.go new file mode 100644 index 00000000..d029c8d2 --- /dev/null +++ b/internal/quic-go/qlog/packet_header.go @@ -0,0 +1,119 @@ +package qlog + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + + "github.com/francoispqt/gojay" +) + +func getPacketTypeFromEncryptionLevel(encLevel protocol.EncryptionLevel) logging.PacketType { + switch encLevel { + case protocol.EncryptionInitial: + return logging.PacketTypeInitial + case protocol.EncryptionHandshake: + return logging.PacketTypeHandshake + case protocol.Encryption0RTT: + return logging.PacketType0RTT + case protocol.Encryption1RTT: + return logging.PacketType1RTT + default: + panic("unknown encryption level") + } +} + +type token struct { + Raw []byte +} + +var _ gojay.MarshalerJSONObject = &token{} + +func (t token) IsNil() bool { return false } +func (t token) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("data", fmt.Sprintf("%x", t.Raw)) +} + +// PacketHeader is a QUIC packet header. +type packetHeader struct { + PacketType logging.PacketType + + KeyPhaseBit logging.KeyPhaseBit + PacketNumber logging.PacketNumber + + Version logging.VersionNumber + SrcConnectionID logging.ConnectionID + DestConnectionID logging.ConnectionID + + Token *token +} + +func transformHeader(hdr *wire.Header) *packetHeader { + h := &packetHeader{ + PacketType: logging.PacketTypeFromHeader(hdr), + SrcConnectionID: hdr.SrcConnectionID, + DestConnectionID: hdr.DestConnectionID, + Version: hdr.Version, + } + if len(hdr.Token) > 0 { + h.Token = &token{Raw: hdr.Token} + } + return h +} + +func transformExtendedHeader(hdr *wire.ExtendedHeader) *packetHeader { + h := transformHeader(&hdr.Header) + h.PacketNumber = hdr.PacketNumber + h.KeyPhaseBit = hdr.KeyPhase + return h +} + +func (h packetHeader) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("packet_type", packetType(h.PacketType).String()) + if h.PacketType != logging.PacketTypeRetry && h.PacketType != logging.PacketTypeVersionNegotiation { + enc.Int64Key("packet_number", int64(h.PacketNumber)) + } + if h.Version != 0 { + enc.StringKey("version", versionNumber(h.Version).String()) + } + if h.PacketType != logging.PacketType1RTT { + enc.IntKey("scil", h.SrcConnectionID.Len()) + if h.SrcConnectionID.Len() > 0 { + enc.StringKey("scid", connectionID(h.SrcConnectionID).String()) + } + } + enc.IntKey("dcil", h.DestConnectionID.Len()) + if h.DestConnectionID.Len() > 0 { + enc.StringKey("dcid", connectionID(h.DestConnectionID).String()) + } + if h.KeyPhaseBit == logging.KeyPhaseZero || h.KeyPhaseBit == logging.KeyPhaseOne { + enc.StringKey("key_phase_bit", h.KeyPhaseBit.String()) + } + if h.Token != nil { + enc.ObjectKey("token", h.Token) + } +} + +// a minimal header that only outputs the packet type +type packetHeaderWithType struct { + PacketType logging.PacketType +} + +func (h packetHeaderWithType) IsNil() bool { return false } +func (h packetHeaderWithType) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("packet_type", packetType(h.PacketType).String()) +} + +// a minimal header that only outputs the packet type +type packetHeaderWithTypeAndPacketNumber struct { + PacketType logging.PacketType + PacketNumber logging.PacketNumber +} + +func (h packetHeaderWithTypeAndPacketNumber) IsNil() bool { return false } +func (h packetHeaderWithTypeAndPacketNumber) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("packet_type", packetType(h.PacketType).String()) + enc.Int64Key("packet_number", int64(h.PacketNumber)) +} diff --git a/internal/quic-go/qlog/packet_header_test.go b/internal/quic-go/qlog/packet_header_test.go new file mode 100644 index 00000000..91dfeb39 --- /dev/null +++ b/internal/quic-go/qlog/packet_header_test.go @@ -0,0 +1,175 @@ +package qlog + +import ( + "bytes" + "encoding/json" + + "github.com/francoispqt/gojay" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Packet Header", func() { + It("determines the packet type from the encryption level", func() { + Expect(getPacketTypeFromEncryptionLevel(protocol.EncryptionInitial)).To(BeEquivalentTo(logging.PacketTypeInitial)) + Expect(getPacketTypeFromEncryptionLevel(protocol.EncryptionHandshake)).To(BeEquivalentTo(logging.PacketTypeHandshake)) + Expect(getPacketTypeFromEncryptionLevel(protocol.Encryption0RTT)).To(BeEquivalentTo(logging.PacketType0RTT)) + Expect(getPacketTypeFromEncryptionLevel(protocol.Encryption1RTT)).To(BeEquivalentTo(logging.PacketType1RTT)) + }) + + Context("marshalling", func() { + check := func(hdr *wire.ExtendedHeader, expected map[string]interface{}) { + buf := &bytes.Buffer{} + enc := gojay.NewEncoder(buf) + ExpectWithOffset(1, enc.Encode(transformExtendedHeader(hdr))).To(Succeed()) + data := buf.Bytes() + ExpectWithOffset(1, json.Valid(data)).To(BeTrue()) + checkEncoding(data, expected) + } + + It("marshals a header for a 1-RTT packet", func() { + check( + &wire.ExtendedHeader{ + PacketNumber: 42, + KeyPhase: protocol.KeyPhaseZero, + }, + map[string]interface{}{ + "packet_type": "1RTT", + "packet_number": 42, + "dcil": 0, + "key_phase_bit": "0", + }, + ) + }) + + It("marshals a header with a payload length", func() { + check( + &wire.ExtendedHeader{ + PacketNumber: 42, + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Length: 123, + Version: protocol.VersionNumber(0xdecafbad), + }, + }, + map[string]interface{}{ + "packet_type": "initial", + "packet_number": 42, + "dcil": 0, + "scil": 0, + "version": "decafbad", + }, + ) + }) + + It("marshals an Initial with a token", func() { + check( + &wire.ExtendedHeader{ + PacketNumber: 4242, + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Length: 123, + Version: protocol.VersionNumber(0xdecafbad), + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + }, + }, + map[string]interface{}{ + "packet_type": "initial", + "packet_number": 4242, + "dcil": 0, + "scil": 0, + "version": "decafbad", + "token": map[string]interface{}{"data": "deadbeef"}, + }, + ) + }) + + It("marshals a Retry packet", func() { + check( + &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + SrcConnectionID: protocol.ConnectionID{0x11, 0x22, 0x33, 0x44}, + Version: protocol.VersionNumber(0xdecafbad), + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + }, + }, + map[string]interface{}{ + "packet_type": "retry", + "dcil": 0, + "scil": 4, + "scid": "11223344", + "token": map[string]interface{}{"data": "deadbeef"}, + "version": "decafbad", + }, + ) + }) + + It("marshals a packet with packet number 0", func() { + check( + &wire.ExtendedHeader{ + PacketNumber: 0, + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Version: protocol.VersionNumber(0xdecafbad), + }, + }, + map[string]interface{}{ + "packet_type": "handshake", + "packet_number": 0, + "dcil": 0, + "scil": 0, + "version": "decafbad", + }, + ) + }) + + It("marshals a header with a source connection ID", func() { + check( + &wire.ExtendedHeader{ + PacketNumber: 42, + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + SrcConnectionID: protocol.ConnectionID{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, + Version: protocol.VersionNumber(0xdecafbad), + }, + }, + map[string]interface{}{ + "packet_type": "handshake", + "packet_number": 42, + "dcil": 0, + "scil": 16, + "scid": "00112233445566778899aabbccddeeff", + "version": "decafbad", + }, + ) + }) + + It("marshals a 1-RTT header with a destination connection ID", func() { + check( + &wire.ExtendedHeader{ + PacketNumber: 42, + Header: wire.Header{DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}}, + KeyPhase: protocol.KeyPhaseOne, + }, + map[string]interface{}{ + "packet_type": "1RTT", + "packet_number": 42, + "dcil": 4, + "dcid": "deadbeef", + "key_phase_bit": "1", + }, + ) + }) + }) +}) diff --git a/internal/quic-go/qlog/qlog.go b/internal/quic-go/qlog/qlog.go new file mode 100644 index 00000000..feaa296e --- /dev/null +++ b/internal/quic-go/qlog/qlog.go @@ -0,0 +1,486 @@ +package qlog + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "net" + "runtime/debug" + "sync" + "time" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" + + "github.com/francoispqt/gojay" +) + +// Setting of this only works when quic-go is used as a library. +// When building a binary from this repository, the version can be set using the following go build flag: +// -ldflags="-X github.com/imroc/req/v3/internal/quic-go/qlog.quicGoVersion=foobar" +var quicGoVersion = "(devel)" + +func init() { + if quicGoVersion != "(devel)" { // variable set by ldflags + return + } + info, ok := debug.ReadBuildInfo() + if !ok { // no build info available. This happens when quic-go is not used as a library. + return + } + for _, d := range info.Deps { + if d.Path == "github.com/imroc/req/v3/internal/quic-go" { + quicGoVersion = d.Version + if d.Replace != nil { + if len(d.Replace.Version) > 0 { + quicGoVersion = d.Version + } else { + quicGoVersion += " (replaced)" + } + } + break + } + } +} + +const eventChanSize = 50 + +type tracer struct { + getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser +} + +var _ logging.Tracer = &tracer{} + +// NewTracer creates a new qlog tracer. +func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser) logging.Tracer { + return &tracer{getLogWriter: getLogWriter} +} + +func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { + if w := t.getLogWriter(p, odcid.Bytes()); w != nil { + return NewConnectionTracer(w, p, odcid) + } + return nil +} + +func (t *tracer) SentPacket(net.Addr, *logging.Header, protocol.ByteCount, []logging.Frame) {} +func (t *tracer) DroppedPacket(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { +} + +type connectionTracer struct { + mutex sync.Mutex + + w io.WriteCloser + odcid protocol.ConnectionID + perspective protocol.Perspective + referenceTime time.Time + + events chan event + encodeErr error + runStopped chan struct{} + + lastMetrics *metrics +} + +var _ logging.ConnectionTracer = &connectionTracer{} + +// NewConnectionTracer creates a new tracer to record a qlog for a connection. +func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { + t := &connectionTracer{ + w: w, + perspective: p, + odcid: odcid, + runStopped: make(chan struct{}), + events: make(chan event, eventChanSize), + referenceTime: time.Now(), + } + go t.run() + return t +} + +func (t *connectionTracer) run() { + defer close(t.runStopped) + buf := &bytes.Buffer{} + enc := gojay.NewEncoder(buf) + tl := &topLevel{ + trace: trace{ + VantagePoint: vantagePoint{Type: t.perspective}, + CommonFields: commonFields{ + ODCID: connectionID(t.odcid), + GroupID: connectionID(t.odcid), + ReferenceTime: t.referenceTime, + }, + }, + } + if err := enc.Encode(tl); err != nil { + panic(fmt.Sprintf("qlog encoding into a bytes.Buffer failed: %s", err)) + } + if err := buf.WriteByte('\n'); err != nil { + panic(fmt.Sprintf("qlog encoding into a bytes.Buffer failed: %s", err)) + } + if _, err := t.w.Write(buf.Bytes()); err != nil { + t.encodeErr = err + } + enc = gojay.NewEncoder(t.w) + for ev := range t.events { + if t.encodeErr != nil { // if encoding failed, just continue draining the event channel + continue + } + if err := enc.Encode(ev); err != nil { + t.encodeErr = err + continue + } + if _, err := t.w.Write([]byte{'\n'}); err != nil { + t.encodeErr = err + } + } +} + +func (t *connectionTracer) Close() { + if err := t.export(); err != nil { + log.Printf("exporting qlog failed: %s\n", err) + } +} + +// export writes a qlog. +func (t *connectionTracer) export() error { + close(t.events) + <-t.runStopped + if t.encodeErr != nil { + return t.encodeErr + } + return t.w.Close() +} + +func (t *connectionTracer) recordEvent(eventTime time.Time, details eventDetails) { + t.events <- event{ + RelativeTime: eventTime.Sub(t.referenceTime), + eventDetails: details, + } +} + +func (t *connectionTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID protocol.ConnectionID) { + // ignore this event if we're not dealing with UDP addresses here + localAddr, ok := local.(*net.UDPAddr) + if !ok { + return + } + remoteAddr, ok := remote.(*net.UDPAddr) + if !ok { + return + } + t.mutex.Lock() + t.recordEvent(time.Now(), &eventConnectionStarted{ + SrcAddr: localAddr, + DestAddr: remoteAddr, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) NegotiatedVersion(chosen logging.VersionNumber, client, server []logging.VersionNumber) { + var clientVersions, serverVersions []versionNumber + if len(client) > 0 { + clientVersions = make([]versionNumber, len(client)) + for i, v := range client { + clientVersions[i] = versionNumber(v) + } + } + if len(server) > 0 { + serverVersions = make([]versionNumber, len(server)) + for i, v := range server { + serverVersions[i] = versionNumber(v) + } + } + t.mutex.Lock() + t.recordEvent(time.Now(), &eventVersionNegotiated{ + clientVersions: clientVersions, + serverVersions: serverVersions, + chosenVersion: versionNumber(chosen), + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) ClosedConnection(e error) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventConnectionClosed{e: e}) + t.mutex.Unlock() +} + +func (t *connectionTracer) SentTransportParameters(tp *wire.TransportParameters) { + t.recordTransportParameters(t.perspective, tp) +} + +func (t *connectionTracer) ReceivedTransportParameters(tp *wire.TransportParameters) { + t.recordTransportParameters(t.perspective.Opposite(), tp) +} + +func (t *connectionTracer) RestoredTransportParameters(tp *wire.TransportParameters) { + ev := t.toTransportParameters(tp) + ev.Restore = true + + t.mutex.Lock() + t.recordEvent(time.Now(), ev) + t.mutex.Unlock() +} + +func (t *connectionTracer) recordTransportParameters(sentBy protocol.Perspective, tp *wire.TransportParameters) { + ev := t.toTransportParameters(tp) + ev.Owner = ownerLocal + if sentBy != t.perspective { + ev.Owner = ownerRemote + } + ev.SentBy = sentBy + + t.mutex.Lock() + t.recordEvent(time.Now(), ev) + t.mutex.Unlock() +} + +func (t *connectionTracer) toTransportParameters(tp *wire.TransportParameters) *eventTransportParameters { + var pa *preferredAddress + if tp.PreferredAddress != nil { + pa = &preferredAddress{ + IPv4: tp.PreferredAddress.IPv4, + PortV4: tp.PreferredAddress.IPv4Port, + IPv6: tp.PreferredAddress.IPv6, + PortV6: tp.PreferredAddress.IPv6Port, + ConnectionID: tp.PreferredAddress.ConnectionID, + StatelessResetToken: tp.PreferredAddress.StatelessResetToken, + } + } + return &eventTransportParameters{ + OriginalDestinationConnectionID: tp.OriginalDestinationConnectionID, + InitialSourceConnectionID: tp.InitialSourceConnectionID, + RetrySourceConnectionID: tp.RetrySourceConnectionID, + StatelessResetToken: tp.StatelessResetToken, + DisableActiveMigration: tp.DisableActiveMigration, + MaxIdleTimeout: tp.MaxIdleTimeout, + MaxUDPPayloadSize: tp.MaxUDPPayloadSize, + AckDelayExponent: tp.AckDelayExponent, + MaxAckDelay: tp.MaxAckDelay, + ActiveConnectionIDLimit: tp.ActiveConnectionIDLimit, + InitialMaxData: tp.InitialMaxData, + InitialMaxStreamDataBidiLocal: tp.InitialMaxStreamDataBidiLocal, + InitialMaxStreamDataBidiRemote: tp.InitialMaxStreamDataBidiRemote, + InitialMaxStreamDataUni: tp.InitialMaxStreamDataUni, + InitialMaxStreamsBidi: int64(tp.MaxBidiStreamNum), + InitialMaxStreamsUni: int64(tp.MaxUniStreamNum), + PreferredAddress: pa, + MaxDatagramFrameSize: tp.MaxDatagramFrameSize, + } +} + +func (t *connectionTracer) SentPacket(hdr *wire.ExtendedHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) { + numFrames := len(frames) + if ack != nil { + numFrames++ + } + fs := make([]frame, 0, numFrames) + if ack != nil { + fs = append(fs, frame{Frame: ack}) + } + for _, f := range frames { + fs = append(fs, frame{Frame: f}) + } + header := *transformExtendedHeader(hdr) + t.mutex.Lock() + t.recordEvent(time.Now(), &eventPacketSent{ + Header: header, + Length: packetSize, + PayloadLength: hdr.Length, + Frames: fs, + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) ReceivedPacket(hdr *wire.ExtendedHeader, packetSize logging.ByteCount, frames []logging.Frame) { + fs := make([]frame, len(frames)) + for i, f := range frames { + fs[i] = frame{Frame: f} + } + header := *transformExtendedHeader(hdr) + t.mutex.Lock() + t.recordEvent(time.Now(), &eventPacketReceived{ + Header: header, + Length: packetSize, + PayloadLength: hdr.Length, + Frames: fs, + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) ReceivedRetry(hdr *wire.Header) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventRetryReceived{ + Header: *transformHeader(hdr), + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) ReceivedVersionNegotiationPacket(hdr *wire.Header, versions []logging.VersionNumber) { + ver := make([]versionNumber, len(versions)) + for i, v := range versions { + ver[i] = versionNumber(v) + } + t.mutex.Lock() + t.recordEvent(time.Now(), &eventVersionNegotiationReceived{ + Header: *transformHeader(hdr), + SupportedVersions: ver, + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) BufferedPacket(pt logging.PacketType) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventPacketBuffered{PacketType: pt}) + t.mutex.Unlock() +} + +func (t *connectionTracer) DroppedPacket(pt logging.PacketType, size protocol.ByteCount, reason logging.PacketDropReason) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventPacketDropped{ + PacketType: pt, + PacketSize: size, + Trigger: packetDropReason(reason), + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) UpdatedMetrics(rttStats *utils.RTTStats, cwnd, bytesInFlight protocol.ByteCount, packetsInFlight int) { + m := &metrics{ + MinRTT: rttStats.MinRTT(), + SmoothedRTT: rttStats.SmoothedRTT(), + LatestRTT: rttStats.LatestRTT(), + RTTVariance: rttStats.MeanDeviation(), + CongestionWindow: cwnd, + BytesInFlight: bytesInFlight, + PacketsInFlight: packetsInFlight, + } + t.mutex.Lock() + t.recordEvent(time.Now(), &eventMetricsUpdated{ + Last: t.lastMetrics, + Current: m, + }) + t.lastMetrics = m + t.mutex.Unlock() +} + +func (t *connectionTracer) AcknowledgedPacket(protocol.EncryptionLevel, protocol.PacketNumber) {} + +func (t *connectionTracer) LostPacket(encLevel protocol.EncryptionLevel, pn protocol.PacketNumber, lossReason logging.PacketLossReason) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventPacketLost{ + PacketType: getPacketTypeFromEncryptionLevel(encLevel), + PacketNumber: pn, + Trigger: packetLossReason(lossReason), + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) UpdatedCongestionState(state logging.CongestionState) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventCongestionStateUpdated{state: congestionState(state)}) + t.mutex.Unlock() +} + +func (t *connectionTracer) UpdatedPTOCount(value uint32) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventUpdatedPTO{Value: value}) + t.mutex.Unlock() +} + +func (t *connectionTracer) UpdatedKeyFromTLS(encLevel protocol.EncryptionLevel, pers protocol.Perspective) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventKeyUpdated{ + Trigger: keyUpdateTLS, + KeyType: encLevelToKeyType(encLevel, pers), + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) UpdatedKey(generation protocol.KeyPhase, remote bool) { + trigger := keyUpdateLocal + if remote { + trigger = keyUpdateRemote + } + t.mutex.Lock() + now := time.Now() + t.recordEvent(now, &eventKeyUpdated{ + Trigger: trigger, + KeyType: keyTypeClient1RTT, + Generation: generation, + }) + t.recordEvent(now, &eventKeyUpdated{ + Trigger: trigger, + KeyType: keyTypeServer1RTT, + Generation: generation, + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) DroppedEncryptionLevel(encLevel protocol.EncryptionLevel) { + t.mutex.Lock() + now := time.Now() + if encLevel == protocol.Encryption0RTT { + t.recordEvent(now, &eventKeyRetired{KeyType: encLevelToKeyType(encLevel, t.perspective)}) + } else { + t.recordEvent(now, &eventKeyRetired{KeyType: encLevelToKeyType(encLevel, protocol.PerspectiveServer)}) + t.recordEvent(now, &eventKeyRetired{KeyType: encLevelToKeyType(encLevel, protocol.PerspectiveClient)}) + } + t.mutex.Unlock() +} + +func (t *connectionTracer) DroppedKey(generation protocol.KeyPhase) { + t.mutex.Lock() + now := time.Now() + t.recordEvent(now, &eventKeyRetired{ + KeyType: encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveServer), + Generation: generation, + }) + t.recordEvent(now, &eventKeyRetired{ + KeyType: encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveClient), + Generation: generation, + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) SetLossTimer(tt logging.TimerType, encLevel protocol.EncryptionLevel, timeout time.Time) { + t.mutex.Lock() + now := time.Now() + t.recordEvent(now, &eventLossTimerSet{ + TimerType: timerType(tt), + EncLevel: encLevel, + Delta: timeout.Sub(now), + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) LossTimerExpired(tt logging.TimerType, encLevel protocol.EncryptionLevel) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventLossTimerExpired{ + TimerType: timerType(tt), + EncLevel: encLevel, + }) + t.mutex.Unlock() +} + +func (t *connectionTracer) LossTimerCanceled() { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventLossTimerCanceled{}) + t.mutex.Unlock() +} + +func (t *connectionTracer) Debug(name, msg string) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventGeneric{ + name: name, + msg: msg, + }) + t.mutex.Unlock() +} diff --git a/internal/quic-go/qlog/qlog_suite_test.go b/internal/quic-go/qlog/qlog_suite_test.go new file mode 100644 index 00000000..73f4917e --- /dev/null +++ b/internal/quic-go/qlog/qlog_suite_test.go @@ -0,0 +1,51 @@ +package qlog + +import ( + "encoding/json" + "os" + "strconv" + "testing" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestQlog(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "qlog Suite") +} + +//nolint:unparam +func scaleDuration(t time.Duration) time.Duration { + scaleFactor := 1 + if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set + scaleFactor = f + } + Expect(scaleFactor).ToNot(BeZero()) + return time.Duration(scaleFactor) * t +} + +func checkEncoding(data []byte, expected map[string]interface{}) { + // unmarshal the data + m := make(map[string]interface{}) + ExpectWithOffset(1, json.Unmarshal(data, &m)).To(Succeed()) + ExpectWithOffset(1, m).To(HaveLen(len(expected))) + for key, value := range expected { + switch v := value.(type) { + case bool, string, map[string]interface{}: + ExpectWithOffset(1, m).To(HaveKeyWithValue(key, v)) + case int: + ExpectWithOffset(1, m).To(HaveKeyWithValue(key, float64(v))) + case [][]float64: // used in the ACK frame + ExpectWithOffset(1, m).To(HaveKey(key)) + for i, l := range v { + for j, s := range l { + ExpectWithOffset(1, m[key].([]interface{})[i].([]interface{})[j].(float64)).To(Equal(s)) + } + } + default: + Fail("unexpected type") + } + } +} diff --git a/internal/quic-go/qlog/qlog_test.go b/internal/quic-go/qlog/qlog_test.go new file mode 100644 index 00000000..f5849927 --- /dev/null +++ b/internal/quic-go/qlog/qlog_test.go @@ -0,0 +1,849 @@ +package qlog + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "log" + "net" + "os" + "time" + + "github.com/imroc/req/v3/internal/quic-go" + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type nopWriteCloserImpl struct{ io.Writer } + +func (nopWriteCloserImpl) Close() error { return nil } + +func nopWriteCloser(w io.Writer) io.WriteCloser { + return &nopWriteCloserImpl{Writer: w} +} + +type limitedWriter struct { + io.WriteCloser + N int + written int +} + +func (w *limitedWriter) Write(p []byte) (int, error) { + if w.written+len(p) > w.N { + return 0, errors.New("writer full") + } + n, err := w.WriteCloser.Write(p) + w.written += n + return n, err +} + +type entry struct { + Time time.Time + Name string + Event map[string]interface{} +} + +var _ = Describe("Tracing", func() { + Context("tracer", func() { + It("returns nil when there's no io.WriteCloser", func() { + t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil }) + Expect(t.TracerForConnection(context.Background(), logging.PerspectiveClient, logging.ConnectionID{1, 2, 3, 4})).To(BeNil()) + }) + }) + + It("stops writing when encountering an error", func() { + buf := &bytes.Buffer{} + t := NewConnectionTracer( + &limitedWriter{WriteCloser: nopWriteCloser(buf), N: 250}, + protocol.PerspectiveServer, + protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ) + for i := uint32(0); i < 1000; i++ { + t.UpdatedPTOCount(i) + } + + b := &bytes.Buffer{} + log.SetOutput(b) + defer log.SetOutput(os.Stdout) + t.Close() + Expect(b.String()).To(ContainSubstring("writer full")) + }) + + Context("connection tracer", func() { + var ( + tracer logging.ConnectionTracer + buf *bytes.Buffer + ) + + BeforeEach(func() { + buf = &bytes.Buffer{} + t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) }) + tracer = t.TracerForConnection(context.Background(), logging.PerspectiveServer, logging.ConnectionID{0xde, 0xad, 0xbe, 0xef}) + }) + + It("exports a trace that has the right metadata", func() { + tracer.Close() + + m := make(map[string]interface{}) + Expect(json.Unmarshal(buf.Bytes(), &m)).To(Succeed()) + Expect(m).To(HaveKeyWithValue("qlog_version", "draft-02")) + Expect(m).To(HaveKey("title")) + Expect(m).To(HaveKey("trace")) + trace := m["trace"].(map[string]interface{}) + Expect(trace).To(HaveKey(("common_fields"))) + commonFields := trace["common_fields"].(map[string]interface{}) + Expect(commonFields).To(HaveKeyWithValue("ODCID", "deadbeef")) + Expect(commonFields).To(HaveKeyWithValue("group_id", "deadbeef")) + Expect(commonFields).To(HaveKey("reference_time")) + referenceTime := time.Unix(0, int64(commonFields["reference_time"].(float64)*1e6)) + Expect(referenceTime).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(commonFields).To(HaveKeyWithValue("time_format", "relative")) + Expect(trace).To(HaveKey("vantage_point")) + vantagePoint := trace["vantage_point"].(map[string]interface{}) + Expect(vantagePoint).To(HaveKeyWithValue("type", "server")) + }) + + Context("Events", func() { + exportAndParse := func() []entry { + tracer.Close() + + m := make(map[string]interface{}) + line, err := buf.ReadBytes('\n') + Expect(err).ToNot(HaveOccurred()) + Expect(json.Unmarshal(line, &m)).To(Succeed()) + Expect(m).To(HaveKey("trace")) + var entries []entry + trace := m["trace"].(map[string]interface{}) + Expect(trace).To(HaveKey("common_fields")) + commonFields := trace["common_fields"].(map[string]interface{}) + Expect(commonFields).To(HaveKey("reference_time")) + referenceTime := time.Unix(0, int64(commonFields["reference_time"].(float64)*1e6)) + Expect(trace).ToNot(HaveKey("events")) + + for buf.Len() > 0 { + line, err := buf.ReadBytes('\n') + Expect(err).ToNot(HaveOccurred()) + ev := make(map[string]interface{}) + Expect(json.Unmarshal(line, &ev)).To(Succeed()) + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKey("time")) + Expect(ev).To(HaveKey("name")) + Expect(ev).To(HaveKey("data")) + entries = append(entries, entry{ + Time: referenceTime.Add(time.Duration(ev["time"].(float64)*1e6) * time.Nanosecond), + Name: ev["name"].(string), + Event: ev["data"].(map[string]interface{}), + }) + } + return entries + } + + exportAndParseSingle := func() entry { + entries := exportAndParse() + Expect(entries).To(HaveLen(1)) + return entries[0] + } + + It("records connection starts", func() { + tracer.StartedConnection( + &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 42}, + &net.UDPAddr{IP: net.IPv4(192, 168, 12, 34), Port: 24}, + protocol.ConnectionID{1, 2, 3, 4}, + protocol.ConnectionID{5, 6, 7, 8}, + ) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_started")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("ip_version", "ipv4")) + Expect(ev).To(HaveKeyWithValue("src_ip", "192.168.13.37")) + Expect(ev).To(HaveKeyWithValue("src_port", float64(42))) + Expect(ev).To(HaveKeyWithValue("dst_ip", "192.168.12.34")) + Expect(ev).To(HaveKeyWithValue("dst_port", float64(24))) + Expect(ev).To(HaveKeyWithValue("src_cid", "01020304")) + Expect(ev).To(HaveKeyWithValue("dst_cid", "05060708")) + }) + + It("records the version, if no version negotiation happened", func() { + tracer.NegotiatedVersion(0x1337, nil, nil) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:version_information")) + ev := entry.Event + Expect(ev).To(HaveLen(1)) + Expect(ev).To(HaveKeyWithValue("chosen_version", "1337")) + }) + + It("records the version, if version negotiation happened", func() { + tracer.NegotiatedVersion(0x1337, []logging.VersionNumber{1, 2, 3}, []logging.VersionNumber{4, 5, 6}) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:version_information")) + ev := entry.Event + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKeyWithValue("chosen_version", "1337")) + Expect(ev).To(HaveKey("client_versions")) + Expect(ev["client_versions"].([]interface{})).To(Equal([]interface{}{"1", "2", "3"})) + Expect(ev).To(HaveKey("server_versions")) + Expect(ev["server_versions"].([]interface{})).To(Equal([]interface{}{"4", "5", "6"})) + }) + + It("records idle timeouts", func() { + tracer.ClosedConnection(&quic.IdleTimeoutError{}) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(2)) + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).To(HaveKeyWithValue("trigger", "idle_timeout")) + }) + + It("records handshake timeouts", func() { + tracer.ClosedConnection(&quic.HandshakeTimeoutError{}) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(2)) + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).To(HaveKeyWithValue("trigger", "handshake_timeout")) + }) + + It("records a received stateless reset packet", func() { + tracer.ClosedConnection(&quic.StatelessResetError{ + Token: protocol.StatelessResetToken{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, + }) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKeyWithValue("owner", "remote")) + Expect(ev).To(HaveKeyWithValue("trigger", "stateless_reset")) + Expect(ev).To(HaveKeyWithValue("stateless_reset_token", "00112233445566778899aabbccddeeff")) + }) + + It("records connection closing due to version negotiation failure", func() { + tracer.ClosedConnection(&quic.VersionNegotiationError{}) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(2)) + Expect(ev).To(HaveKeyWithValue("owner", "remote")) + Expect(ev).To(HaveKeyWithValue("trigger", "version_negotiation")) + }) + + It("records application errors", func() { + tracer.ClosedConnection(&quic.ApplicationError{ + Remote: true, + ErrorCode: 1337, + ErrorMessage: "foobar", + }) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKeyWithValue("owner", "remote")) + Expect(ev).To(HaveKeyWithValue("application_code", float64(1337))) + Expect(ev).To(HaveKeyWithValue("reason", "foobar")) + }) + + It("records transport errors", func() { + tracer.ClosedConnection(&quic.TransportError{ + ErrorCode: qerr.AEADLimitReached, + ErrorMessage: "foobar", + }) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).To(HaveKeyWithValue("connection_code", "aead_limit_reached")) + Expect(ev).To(HaveKeyWithValue("reason", "foobar")) + }) + + It("records sent transport parameters", func() { + tracer.SentTransportParameters(&logging.TransportParameters{ + InitialMaxStreamDataBidiLocal: 1000, + InitialMaxStreamDataBidiRemote: 2000, + InitialMaxStreamDataUni: 3000, + InitialMaxData: 4000, + MaxBidiStreamNum: 10, + MaxUniStreamNum: 20, + MaxAckDelay: 123 * time.Millisecond, + AckDelayExponent: 12, + DisableActiveMigration: true, + MaxUDPPayloadSize: 1234, + MaxIdleTimeout: 321 * time.Millisecond, + StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, + OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + ActiveConnectionIDLimit: 7, + MaxDatagramFrameSize: protocol.InvalidByteCount, + }) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).To(HaveKeyWithValue("original_destination_connection_id", "deadc0de")) + Expect(ev).To(HaveKeyWithValue("initial_source_connection_id", "deadbeef")) + Expect(ev).To(HaveKeyWithValue("retry_source_connection_id", "decafbad")) + Expect(ev).To(HaveKeyWithValue("stateless_reset_token", "112233445566778899aabbccddeeff00")) + Expect(ev).To(HaveKeyWithValue("max_idle_timeout", float64(321))) + Expect(ev).To(HaveKeyWithValue("max_udp_payload_size", float64(1234))) + Expect(ev).To(HaveKeyWithValue("ack_delay_exponent", float64(12))) + Expect(ev).To(HaveKeyWithValue("active_connection_id_limit", float64(7))) + Expect(ev).To(HaveKeyWithValue("initial_max_data", float64(4000))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_local", float64(1000))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_remote", float64(2000))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_uni", float64(3000))) + Expect(ev).To(HaveKeyWithValue("initial_max_streams_bidi", float64(10))) + Expect(ev).To(HaveKeyWithValue("initial_max_streams_uni", float64(20))) + Expect(ev).ToNot(HaveKey("preferred_address")) + Expect(ev).ToNot(HaveKey("max_datagram_frame_size")) + }) + + It("records the server's transport parameters, without a stateless reset token", func() { + tracer.SentTransportParameters(&logging.TransportParameters{ + OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + ActiveConnectionIDLimit: 7, + }) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).ToNot(HaveKey("stateless_reset_token")) + }) + + It("records transport parameters without retry_source_connection_id", func() { + tracer.SentTransportParameters(&logging.TransportParameters{ + StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, + }) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).ToNot(HaveKey("retry_source_connection_id")) + }) + + It("records transport parameters with a preferred address", func() { + tracer.SentTransportParameters(&logging.TransportParameters{ + PreferredAddress: &logging.PreferredAddress{ + IPv4: net.IPv4(12, 34, 56, 78), + IPv4Port: 123, + IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + IPv6Port: 456, + ConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, + StatelessResetToken: protocol.StatelessResetToken{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, + }, + }) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).To(HaveKey("preferred_address")) + pa := ev["preferred_address"].(map[string]interface{}) + Expect(pa).To(HaveKeyWithValue("ip_v4", "12.34.56.78")) + Expect(pa).To(HaveKeyWithValue("port_v4", float64(123))) + Expect(pa).To(HaveKeyWithValue("ip_v6", "102:304:506:708:90a:b0c:d0e:f10")) + Expect(pa).To(HaveKeyWithValue("port_v6", float64(456))) + Expect(pa).To(HaveKeyWithValue("connection_id", "0807060504030201")) + Expect(pa).To(HaveKeyWithValue("stateless_reset_token", "0f0e0d0c0b0a09080706050403020100")) + }) + + It("records transport parameters that enable the datagram extension", func() { + tracer.SentTransportParameters(&logging.TransportParameters{ + MaxDatagramFrameSize: 1337, + }) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("max_datagram_frame_size", float64(1337))) + }) + + It("records received transport parameters", func() { + tracer.ReceivedTransportParameters(&logging.TransportParameters{}) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("owner", "remote")) + Expect(ev).ToNot(HaveKey("original_destination_connection_id")) + }) + + It("records restored transport parameters", func() { + tracer.RestoredTransportParameters(&logging.TransportParameters{ + InitialMaxStreamDataBidiLocal: 100, + InitialMaxStreamDataBidiRemote: 200, + InitialMaxStreamDataUni: 300, + InitialMaxData: 400, + MaxIdleTimeout: 123 * time.Millisecond, + }) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_restored")) + ev := entry.Event + Expect(ev).ToNot(HaveKey("owner")) + Expect(ev).ToNot(HaveKey("original_destination_connection_id")) + Expect(ev).ToNot(HaveKey("stateless_reset_token")) + Expect(ev).ToNot(HaveKey("retry_source_connection_id")) + Expect(ev).ToNot(HaveKey("initial_source_connection_id")) + Expect(ev).To(HaveKeyWithValue("max_idle_timeout", float64(123))) + Expect(ev).To(HaveKeyWithValue("initial_max_data", float64(400))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_local", float64(100))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_remote", float64(200))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_uni", float64(300))) + }) + + It("records a sent packet, without an ACK", func() { + tracer.SentPacket( + &logging.ExtendedHeader{ + Header: logging.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + Length: 1337, + Version: protocol.VersionTLS, + }, + PacketNumber: 1337, + }, + 987, + nil, + []logging.Frame{ + &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, + &logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}, + }, + ) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_sent")) + ev := entry.Event + Expect(ev).To(HaveKey("raw")) + raw := ev["raw"].(map[string]interface{}) + Expect(raw).To(HaveKeyWithValue("length", float64(987))) + Expect(raw).To(HaveKeyWithValue("payload_length", float64(1337))) + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) + Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) + Expect(hdr).To(HaveKeyWithValue("scid", "04030201")) + Expect(ev).To(HaveKey("frames")) + frames := ev["frames"].([]interface{}) + Expect(frames).To(HaveLen(2)) + Expect(frames[0].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "max_stream_data")) + Expect(frames[1].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "stream")) + }) + + It("records a sent packet, without an ACK", func() { + tracer.SentPacket( + &logging.ExtendedHeader{ + Header: logging.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}}, + PacketNumber: 1337, + }, + 123, + &logging.AckFrame{AckRanges: []logging.AckRange{{Smallest: 1, Largest: 10}}}, + []logging.Frame{&logging.MaxDataFrame{MaximumData: 987}}, + ) + entry := exportAndParseSingle() + ev := entry.Event + raw := ev["raw"].(map[string]interface{}) + Expect(raw).To(HaveKeyWithValue("length", float64(123))) + Expect(raw).ToNot(HaveKey("payload_length")) + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT")) + Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) + Expect(ev).To(HaveKey("frames")) + frames := ev["frames"].([]interface{}) + Expect(frames).To(HaveLen(2)) + Expect(frames[0].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "ack")) + Expect(frames[1].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "max_data")) + }) + + It("records a received packet", func() { + tracer.ReceivedPacket( + &logging.ExtendedHeader{ + Header: logging.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + Length: 1234, + Version: protocol.VersionTLS, + }, + PacketNumber: 1337, + }, + 789, + []logging.Frame{ + &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, + &logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}, + }, + ) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_received")) + ev := entry.Event + Expect(ev).To(HaveKey("raw")) + raw := ev["raw"].(map[string]interface{}) + Expect(raw).To(HaveKeyWithValue("length", float64(789))) + Expect(raw).To(HaveKeyWithValue("payload_length", float64(1234))) + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveKeyWithValue("packet_type", "initial")) + Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) + Expect(hdr).To(HaveKeyWithValue("scid", "04030201")) + Expect(hdr).To(HaveKey("token")) + token := hdr["token"].(map[string]interface{}) + Expect(token).To(HaveKeyWithValue("data", "deadbeef")) + Expect(ev).To(HaveKey("frames")) + Expect(ev["frames"].([]interface{})).To(HaveLen(2)) + }) + + It("records a received Retry packet", func() { + tracer.ReceivedRetry( + &logging.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + Version: protocol.VersionTLS, + }, + ) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_received")) + ev := entry.Event + Expect(ev).ToNot(HaveKey("raw")) + Expect(ev).To(HaveKey("header")) + header := ev["header"].(map[string]interface{}) + Expect(header).To(HaveKeyWithValue("packet_type", "retry")) + Expect(header).ToNot(HaveKey("packet_number")) + Expect(header).To(HaveKey("version")) + Expect(header).To(HaveKey("dcid")) + Expect(header).To(HaveKey("scid")) + Expect(header).To(HaveKey("token")) + token := header["token"].(map[string]interface{}) + Expect(token).To(HaveKeyWithValue("data", "deadbeef")) + Expect(ev).ToNot(HaveKey("frames")) + }) + + It("records a received Version Negotiation packet", func() { + tracer.ReceivedVersionNegotiationPacket( + &logging.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + }, + []protocol.VersionNumber{0xdeadbeef, 0xdecafbad}, + ) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_received")) + ev := entry.Event + Expect(ev).To(HaveKey("header")) + Expect(ev).ToNot(HaveKey("frames")) + Expect(ev).To(HaveKey("supported_versions")) + Expect(ev["supported_versions"].([]interface{})).To(Equal([]interface{}{"deadbeef", "decafbad"})) + header := ev["header"] + Expect(header).To(HaveKeyWithValue("packet_type", "version_negotiation")) + Expect(header).ToNot(HaveKey("packet_number")) + Expect(header).ToNot(HaveKey("version")) + Expect(header).To(HaveKey("dcid")) + Expect(header).To(HaveKey("scid")) + }) + + It("records buffered packets", func() { + tracer.BufferedPacket(logging.PacketTypeHandshake) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_buffered")) + ev := entry.Event + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveLen(1)) + Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) + Expect(ev).To(HaveKeyWithValue("trigger", "keys_unavailable")) + }) + + It("records dropped packets", func() { + tracer.DroppedPacket(logging.PacketTypeHandshake, 1337, logging.PacketDropPayloadDecryptError) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_dropped")) + ev := entry.Event + Expect(ev).To(HaveKey("raw")) + Expect(ev["raw"].(map[string]interface{})).To(HaveKeyWithValue("length", float64(1337))) + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveLen(1)) + Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) + Expect(ev).To(HaveKeyWithValue("trigger", "payload_decrypt_error")) + }) + + It("records metrics updates", func() { + now := time.Now() + rttStats := utils.NewRTTStats() + rttStats.UpdateRTT(15*time.Millisecond, 0, now) + rttStats.UpdateRTT(20*time.Millisecond, 0, now) + rttStats.UpdateRTT(25*time.Millisecond, 0, now) + Expect(rttStats.MinRTT()).To(Equal(15 * time.Millisecond)) + Expect(rttStats.SmoothedRTT()).To(And( + BeNumerically(">", 15*time.Millisecond), + BeNumerically("<", 25*time.Millisecond), + )) + Expect(rttStats.LatestRTT()).To(Equal(25 * time.Millisecond)) + tracer.UpdatedMetrics( + rttStats, + 4321, + 1234, + 42, + ) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:metrics_updated")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("min_rtt", float64(15))) + Expect(ev).To(HaveKeyWithValue("latest_rtt", float64(25))) + Expect(ev).To(HaveKey("smoothed_rtt")) + Expect(time.Duration(ev["smoothed_rtt"].(float64)) * time.Millisecond).To(BeNumerically("~", rttStats.SmoothedRTT(), time.Millisecond)) + Expect(ev).To(HaveKey("rtt_variance")) + Expect(time.Duration(ev["rtt_variance"].(float64)) * time.Millisecond).To(BeNumerically("~", rttStats.MeanDeviation(), time.Millisecond)) + Expect(ev).To(HaveKeyWithValue("congestion_window", float64(4321))) + Expect(ev).To(HaveKeyWithValue("bytes_in_flight", float64(1234))) + Expect(ev).To(HaveKeyWithValue("packets_in_flight", float64(42))) + }) + + It("only logs the diff between two metrics updates", func() { + now := time.Now() + rttStats := utils.NewRTTStats() + rttStats.UpdateRTT(15*time.Millisecond, 0, now) + rttStats.UpdateRTT(20*time.Millisecond, 0, now) + rttStats.UpdateRTT(25*time.Millisecond, 0, now) + Expect(rttStats.MinRTT()).To(Equal(15 * time.Millisecond)) + + rttStats2 := utils.NewRTTStats() + rttStats2.UpdateRTT(15*time.Millisecond, 0, now) + rttStats2.UpdateRTT(15*time.Millisecond, 0, now) + rttStats2.UpdateRTT(15*time.Millisecond, 0, now) + Expect(rttStats2.MinRTT()).To(Equal(15 * time.Millisecond)) + + Expect(rttStats.LatestRTT()).To(Equal(25 * time.Millisecond)) + tracer.UpdatedMetrics( + rttStats, + 4321, + 1234, + 42, + ) + tracer.UpdatedMetrics( + rttStats2, + 4321, + 12345, // changed + 42, + ) + entries := exportAndParse() + Expect(entries).To(HaveLen(2)) + Expect(entries[0].Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entries[0].Name).To(Equal("recovery:metrics_updated")) + Expect(entries[0].Event).To(HaveLen(7)) + Expect(entries[1].Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entries[1].Name).To(Equal("recovery:metrics_updated")) + ev := entries[1].Event + Expect(ev).ToNot(HaveKey("min_rtt")) + Expect(ev).ToNot(HaveKey("congestion_window")) + Expect(ev).ToNot(HaveKey("packets_in_flight")) + Expect(ev).To(HaveKeyWithValue("bytes_in_flight", float64(12345))) + Expect(ev).To(HaveKeyWithValue("smoothed_rtt", float64(15))) + }) + + It("records lost packets", func() { + tracer.LostPacket(protocol.EncryptionHandshake, 42, logging.PacketLossReorderingThreshold) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:packet_lost")) + ev := entry.Event + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveLen(2)) + Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) + Expect(hdr).To(HaveKeyWithValue("packet_number", float64(42))) + Expect(ev).To(HaveKeyWithValue("trigger", "reordering_threshold")) + }) + + It("records congestion state updates", func() { + tracer.UpdatedCongestionState(logging.CongestionStateCongestionAvoidance) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:congestion_state_updated")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("new", "congestion_avoidance")) + }) + + It("records PTO changes", func() { + tracer.UpdatedPTOCount(42) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:metrics_updated")) + Expect(entry.Event).To(HaveKeyWithValue("pto_count", float64(42))) + }) + + It("records TLS key updates", func() { + tracer.UpdatedKeyFromTLS(protocol.EncryptionHandshake, protocol.PerspectiveClient) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_updated")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("key_type", "client_handshake_secret")) + Expect(ev).To(HaveKeyWithValue("trigger", "tls")) + Expect(ev).ToNot(HaveKey("generation")) + Expect(ev).ToNot(HaveKey("old")) + Expect(ev).ToNot(HaveKey("new")) + }) + + It("records TLS key updates, for 1-RTT keys", func() { + tracer.UpdatedKeyFromTLS(protocol.Encryption1RTT, protocol.PerspectiveServer) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_updated")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("key_type", "server_1rtt_secret")) + Expect(ev).To(HaveKeyWithValue("trigger", "tls")) + Expect(ev).To(HaveKeyWithValue("generation", float64(0))) + Expect(ev).ToNot(HaveKey("old")) + Expect(ev).ToNot(HaveKey("new")) + }) + + It("records QUIC key updates", func() { + tracer.UpdatedKey(1337, true) + entries := exportAndParse() + Expect(entries).To(HaveLen(2)) + var keyTypes []string + for _, entry := range entries { + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_updated")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("generation", float64(1337))) + Expect(ev).To(HaveKeyWithValue("trigger", "remote_update")) + Expect(ev).To(HaveKey("key_type")) + keyTypes = append(keyTypes, ev["key_type"].(string)) + } + Expect(keyTypes).To(ContainElement("server_1rtt_secret")) + Expect(keyTypes).To(ContainElement("client_1rtt_secret")) + }) + + It("records dropped encryption levels", func() { + tracer.DroppedEncryptionLevel(protocol.EncryptionInitial) + entries := exportAndParse() + Expect(entries).To(HaveLen(2)) + var keyTypes []string + for _, entry := range entries { + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_retired")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("trigger", "tls")) + Expect(ev).To(HaveKey("key_type")) + keyTypes = append(keyTypes, ev["key_type"].(string)) + } + Expect(keyTypes).To(ContainElement("server_initial_secret")) + Expect(keyTypes).To(ContainElement("client_initial_secret")) + }) + + It("records dropped 0-RTT keys", func() { + tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) + entries := exportAndParse() + Expect(entries).To(HaveLen(1)) + entry := entries[0] + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_retired")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("trigger", "tls")) + Expect(ev).To(HaveKeyWithValue("key_type", "server_0rtt_secret")) + }) + + It("records dropped keys", func() { + tracer.DroppedKey(42) + entries := exportAndParse() + Expect(entries).To(HaveLen(2)) + var keyTypes []string + for _, entry := range entries { + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_retired")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("generation", float64(42))) + Expect(ev).ToNot(HaveKey("trigger")) + Expect(ev).To(HaveKey("key_type")) + keyTypes = append(keyTypes, ev["key_type"].(string)) + } + Expect(keyTypes).To(ContainElement("server_1rtt_secret")) + Expect(keyTypes).To(ContainElement("client_1rtt_secret")) + }) + + It("records when the timer is set", func() { + timeout := time.Now().Add(137 * time.Millisecond) + tracer.SetLossTimer(logging.TimerTypePTO, protocol.EncryptionHandshake, timeout) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:loss_timer_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(4)) + Expect(ev).To(HaveKeyWithValue("event_type", "set")) + Expect(ev).To(HaveKeyWithValue("timer_type", "pto")) + Expect(ev).To(HaveKeyWithValue("packet_number_space", "handshake")) + Expect(ev).To(HaveKey("delta")) + delta := time.Duration(ev["delta"].(float64)*1e6) * time.Nanosecond + Expect(entry.Time.Add(delta)).To(BeTemporally("~", timeout, scaleDuration(10*time.Microsecond))) + }) + + It("records when the loss timer expires", func() { + tracer.LossTimerExpired(logging.TimerTypeACK, protocol.Encryption1RTT) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:loss_timer_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKeyWithValue("event_type", "expired")) + Expect(ev).To(HaveKeyWithValue("timer_type", "ack")) + Expect(ev).To(HaveKeyWithValue("packet_number_space", "application_data")) + }) + + It("records when the timer is canceled", func() { + tracer.LossTimerCanceled() + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:loss_timer_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(1)) + Expect(ev).To(HaveKeyWithValue("event_type", "cancelled")) + }) + + It("records a generic event", func() { + tracer.Debug("foo", "bar") + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:foo")) + ev := entry.Event + Expect(ev).To(HaveLen(1)) + Expect(ev).To(HaveKeyWithValue("details", "bar")) + }) + }) + }) +}) diff --git a/internal/quic-go/qlog/trace.go b/internal/quic-go/qlog/trace.go new file mode 100644 index 00000000..a3ae43b4 --- /dev/null +++ b/internal/quic-go/qlog/trace.go @@ -0,0 +1,66 @@ +package qlog + +import ( + "time" + + "github.com/francoispqt/gojay" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +type topLevel struct { + trace trace +} + +func (topLevel) IsNil() bool { return false } +func (l topLevel) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("qlog_format", "NDJSON") + enc.StringKey("qlog_version", "draft-02") + enc.StringKeyOmitEmpty("title", "quic-go qlog") + enc.StringKey("code_version", quicGoVersion) + enc.ObjectKey("trace", l.trace) +} + +type vantagePoint struct { + Name string + Type protocol.Perspective +} + +func (p vantagePoint) IsNil() bool { return false } +func (p vantagePoint) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKeyOmitEmpty("name", p.Name) + switch p.Type { + case protocol.PerspectiveClient: + enc.StringKey("type", "client") + case protocol.PerspectiveServer: + enc.StringKey("type", "server") + } +} + +type commonFields struct { + ODCID connectionID + GroupID connectionID + ProtocolType string + ReferenceTime time.Time +} + +func (f commonFields) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("ODCID", f.ODCID.String()) + enc.StringKey("group_id", f.ODCID.String()) + enc.StringKeyOmitEmpty("protocol_type", f.ProtocolType) + enc.Float64Key("reference_time", float64(f.ReferenceTime.UnixNano())/1e6) + enc.StringKey("time_format", "relative") +} + +func (f commonFields) IsNil() bool { return false } + +type trace struct { + VantagePoint vantagePoint + CommonFields commonFields +} + +func (trace) IsNil() bool { return false } +func (t trace) MarshalJSONObject(enc *gojay.Encoder) { + enc.ObjectKey("vantage_point", t.VantagePoint) + enc.ObjectKey("common_fields", t.CommonFields) +} diff --git a/internal/quic-go/qlog/types.go b/internal/quic-go/qlog/types.go new file mode 100644 index 00000000..30e7f3ee --- /dev/null +++ b/internal/quic-go/qlog/types.go @@ -0,0 +1,320 @@ +package qlog + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" +) + +type owner uint8 + +const ( + ownerLocal owner = iota + ownerRemote +) + +func (o owner) String() string { + switch o { + case ownerLocal: + return "local" + case ownerRemote: + return "remote" + default: + return "unknown owner" + } +} + +type streamType protocol.StreamType + +func (s streamType) String() string { + switch protocol.StreamType(s) { + case protocol.StreamTypeUni: + return "unidirectional" + case protocol.StreamTypeBidi: + return "bidirectional" + default: + return "unknown stream type" + } +} + +type connectionID protocol.ConnectionID + +func (c connectionID) String() string { + return fmt.Sprintf("%x", []byte(c)) +} + +// category is the qlog event category. +type category uint8 + +const ( + categoryConnectivity category = iota + categoryTransport + categorySecurity + categoryRecovery +) + +func (c category) String() string { + switch c { + case categoryConnectivity: + return "connectivity" + case categoryTransport: + return "transport" + case categorySecurity: + return "security" + case categoryRecovery: + return "recovery" + default: + return "unknown category" + } +} + +type versionNumber protocol.VersionNumber + +func (v versionNumber) String() string { + return fmt.Sprintf("%x", uint32(v)) +} + +func (packetHeader) IsNil() bool { return false } + +func encLevelToPacketNumberSpace(encLevel protocol.EncryptionLevel) string { + switch encLevel { + case protocol.EncryptionInitial: + return "initial" + case protocol.EncryptionHandshake: + return "handshake" + case protocol.Encryption0RTT, protocol.Encryption1RTT: + return "application_data" + default: + return "unknown encryption level" + } +} + +type keyType uint8 + +const ( + keyTypeServerInitial keyType = 1 + iota + keyTypeClientInitial + keyTypeServerHandshake + keyTypeClientHandshake + keyTypeServer0RTT + keyTypeClient0RTT + keyTypeServer1RTT + keyTypeClient1RTT +) + +func encLevelToKeyType(encLevel protocol.EncryptionLevel, pers protocol.Perspective) keyType { + if pers == protocol.PerspectiveServer { + switch encLevel { + case protocol.EncryptionInitial: + return keyTypeServerInitial + case protocol.EncryptionHandshake: + return keyTypeServerHandshake + case protocol.Encryption0RTT: + return keyTypeServer0RTT + case protocol.Encryption1RTT: + return keyTypeServer1RTT + default: + return 0 + } + } + switch encLevel { + case protocol.EncryptionInitial: + return keyTypeClientInitial + case protocol.EncryptionHandshake: + return keyTypeClientHandshake + case protocol.Encryption0RTT: + return keyTypeClient0RTT + case protocol.Encryption1RTT: + return keyTypeClient1RTT + default: + return 0 + } +} + +func (t keyType) String() string { + switch t { + case keyTypeServerInitial: + return "server_initial_secret" + case keyTypeClientInitial: + return "client_initial_secret" + case keyTypeServerHandshake: + return "server_handshake_secret" + case keyTypeClientHandshake: + return "client_handshake_secret" + case keyTypeServer0RTT: + return "server_0rtt_secret" + case keyTypeClient0RTT: + return "client_0rtt_secret" + case keyTypeServer1RTT: + return "server_1rtt_secret" + case keyTypeClient1RTT: + return "client_1rtt_secret" + default: + return "unknown key type" + } +} + +type keyUpdateTrigger uint8 + +const ( + keyUpdateTLS keyUpdateTrigger = iota + keyUpdateRemote + keyUpdateLocal +) + +func (t keyUpdateTrigger) String() string { + switch t { + case keyUpdateTLS: + return "tls" + case keyUpdateRemote: + return "remote_update" + case keyUpdateLocal: + return "local_update" + default: + return "unknown key update trigger" + } +} + +type transportError uint64 + +func (e transportError) String() string { + switch qerr.TransportErrorCode(e) { + case qerr.NoError: + return "no_error" + case qerr.InternalError: + return "internal_error" + case qerr.ConnectionRefused: + return "connection_refused" + case qerr.FlowControlError: + return "flow_control_error" + case qerr.StreamLimitError: + return "stream_limit_error" + case qerr.StreamStateError: + return "stream_state_error" + case qerr.FinalSizeError: + return "final_size_error" + case qerr.FrameEncodingError: + return "frame_encoding_error" + case qerr.TransportParameterError: + return "transport_parameter_error" + case qerr.ConnectionIDLimitError: + return "connection_id_limit_error" + case qerr.ProtocolViolation: + return "protocol_violation" + case qerr.InvalidToken: + return "invalid_token" + case qerr.ApplicationErrorErrorCode: + return "application_error" + case qerr.CryptoBufferExceeded: + return "crypto_buffer_exceeded" + case qerr.KeyUpdateError: + return "key_update_error" + case qerr.AEADLimitReached: + return "aead_limit_reached" + case qerr.NoViablePathError: + return "no_viable_path" + default: + return "" + } +} + +type packetType logging.PacketType + +func (t packetType) String() string { + switch logging.PacketType(t) { + case logging.PacketTypeInitial: + return "initial" + case logging.PacketTypeHandshake: + return "handshake" + case logging.PacketTypeRetry: + return "retry" + case logging.PacketType0RTT: + return "0RTT" + case logging.PacketTypeVersionNegotiation: + return "version_negotiation" + case logging.PacketTypeStatelessReset: + return "stateless_reset" + case logging.PacketType1RTT: + return "1RTT" + case logging.PacketTypeNotDetermined: + return "" + default: + return "unknown packet type" + } +} + +type packetLossReason logging.PacketLossReason + +func (r packetLossReason) String() string { + switch logging.PacketLossReason(r) { + case logging.PacketLossReorderingThreshold: + return "reordering_threshold" + case logging.PacketLossTimeThreshold: + return "time_threshold" + default: + return "unknown loss reason" + } +} + +type packetDropReason logging.PacketDropReason + +func (r packetDropReason) String() string { + switch logging.PacketDropReason(r) { + case logging.PacketDropKeyUnavailable: + return "key_unavailable" + case logging.PacketDropUnknownConnectionID: + return "unknown_connection_id" + case logging.PacketDropHeaderParseError: + return "header_parse_error" + case logging.PacketDropPayloadDecryptError: + return "payload_decrypt_error" + case logging.PacketDropProtocolViolation: + return "protocol_violation" + case logging.PacketDropDOSPrevention: + return "dos_prevention" + case logging.PacketDropUnsupportedVersion: + return "unsupported_version" + case logging.PacketDropUnexpectedPacket: + return "unexpected_packet" + case logging.PacketDropUnexpectedSourceConnectionID: + return "unexpected_source_connection_id" + case logging.PacketDropUnexpectedVersion: + return "unexpected_version" + case logging.PacketDropDuplicate: + return "duplicate" + default: + return "unknown packet drop reason" + } +} + +type timerType logging.TimerType + +func (t timerType) String() string { + switch logging.TimerType(t) { + case logging.TimerTypeACK: + return "ack" + case logging.TimerTypePTO: + return "pto" + default: + return "unknown timer type" + } +} + +type congestionState logging.CongestionState + +func (s congestionState) String() string { + switch logging.CongestionState(s) { + case logging.CongestionStateSlowStart: + return "slow_start" + case logging.CongestionStateCongestionAvoidance: + return "congestion_avoidance" + case logging.CongestionStateRecovery: + return "recovery" + case logging.CongestionStateApplicationLimited: + return "application_limited" + default: + return "unknown congestion state" + } +} diff --git a/internal/quic-go/qlog/types_test.go b/internal/quic-go/qlog/types_test.go new file mode 100644 index 00000000..cbdc8514 --- /dev/null +++ b/internal/quic-go/qlog/types_test.go @@ -0,0 +1,130 @@ +package qlog + +import ( + "go/ast" + "go/parser" + gotoken "go/token" + "path" + "runtime" + "strconv" + + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Types", func() { + It("has a string representation for the owner", func() { + Expect(ownerLocal.String()).To(Equal("local")) + Expect(ownerRemote.String()).To(Equal("remote")) + }) + + It("has a string representation for the category", func() { + Expect(categoryConnectivity.String()).To(Equal("connectivity")) + Expect(categoryTransport.String()).To(Equal("transport")) + Expect(categoryRecovery.String()).To(Equal("recovery")) + Expect(categorySecurity.String()).To(Equal("security")) + }) + + It("has a string representation for the packet type", func() { + Expect(packetType(logging.PacketTypeInitial).String()).To(Equal("initial")) + Expect(packetType(logging.PacketTypeHandshake).String()).To(Equal("handshake")) + Expect(packetType(logging.PacketType0RTT).String()).To(Equal("0RTT")) + Expect(packetType(logging.PacketType1RTT).String()).To(Equal("1RTT")) + Expect(packetType(logging.PacketTypeStatelessReset).String()).To(Equal("stateless_reset")) + Expect(packetType(logging.PacketTypeRetry).String()).To(Equal("retry")) + Expect(packetType(logging.PacketTypeVersionNegotiation).String()).To(Equal("version_negotiation")) + Expect(packetType(logging.PacketTypeNotDetermined).String()).To(BeEmpty()) + }) + + It("has a string representation for the packet drop reason", func() { + Expect(packetDropReason(logging.PacketDropKeyUnavailable).String()).To(Equal("key_unavailable")) + Expect(packetDropReason(logging.PacketDropUnknownConnectionID).String()).To(Equal("unknown_connection_id")) + Expect(packetDropReason(logging.PacketDropHeaderParseError).String()).To(Equal("header_parse_error")) + Expect(packetDropReason(logging.PacketDropPayloadDecryptError).String()).To(Equal("payload_decrypt_error")) + Expect(packetDropReason(logging.PacketDropProtocolViolation).String()).To(Equal("protocol_violation")) + Expect(packetDropReason(logging.PacketDropDOSPrevention).String()).To(Equal("dos_prevention")) + Expect(packetDropReason(logging.PacketDropUnsupportedVersion).String()).To(Equal("unsupported_version")) + Expect(packetDropReason(logging.PacketDropUnexpectedPacket).String()).To(Equal("unexpected_packet")) + Expect(packetDropReason(logging.PacketDropUnexpectedSourceConnectionID).String()).To(Equal("unexpected_source_connection_id")) + Expect(packetDropReason(logging.PacketDropUnexpectedVersion).String()).To(Equal("unexpected_version")) + }) + + It("has a string representation for the timer type", func() { + Expect(timerType(logging.TimerTypeACK).String()).To(Equal("ack")) + Expect(timerType(logging.TimerTypePTO).String()).To(Equal("pto")) + }) + + It("has a string representation for the key type", func() { + Expect(encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveClient).String()).To(Equal("client_initial_secret")) + Expect(encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveServer).String()).To(Equal("server_initial_secret")) + Expect(encLevelToKeyType(protocol.EncryptionHandshake, protocol.PerspectiveClient).String()).To(Equal("client_handshake_secret")) + Expect(encLevelToKeyType(protocol.EncryptionHandshake, protocol.PerspectiveServer).String()).To(Equal("server_handshake_secret")) + Expect(encLevelToKeyType(protocol.Encryption0RTT, protocol.PerspectiveClient).String()).To(Equal("client_0rtt_secret")) + Expect(encLevelToKeyType(protocol.Encryption0RTT, protocol.PerspectiveServer).String()).To(Equal("server_0rtt_secret")) + Expect(encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveClient).String()).To(Equal("client_1rtt_secret")) + Expect(encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveServer).String()).To(Equal("server_1rtt_secret")) + }) + + It("has a string representation for the key update trigger", func() { + Expect(keyUpdateTLS.String()).To(Equal("tls")) + Expect(keyUpdateRemote.String()).To(Equal("remote_update")) + Expect(keyUpdateLocal.String()).To(Equal("local_update")) + }) + + It("tells the packet number space from the encryption level", func() { + Expect(encLevelToPacketNumberSpace(protocol.EncryptionInitial)).To(Equal("initial")) + Expect(encLevelToPacketNumberSpace(protocol.EncryptionHandshake)).To(Equal("handshake")) + Expect(encLevelToPacketNumberSpace(protocol.Encryption0RTT)).To(Equal("application_data")) + Expect(encLevelToPacketNumberSpace(protocol.Encryption1RTT)).To(Equal("application_data")) + }) + + Context("transport errors", func() { + It("has a string representation for every error code", func() { + // We parse the error code file, extract all constants, and verify that + // each of them has a string version. Go FTW! + _, thisfile, _, ok := runtime.Caller(0) + if !ok { + panic("Failed to get current frame") + } + filename := path.Join(path.Dir(thisfile), "../qerr/error_codes.go") + fileAst, err := parser.ParseFile(gotoken.NewFileSet(), filename, nil, 0) + Expect(err).NotTo(HaveOccurred()) + constSpecs := fileAst.Decls[2].(*ast.GenDecl).Specs + Expect(len(constSpecs)).To(BeNumerically(">", 4)) // at time of writing + for _, c := range constSpecs { + valString := c.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value + val, err := strconv.ParseInt(valString, 0, 64) + Expect(err).NotTo(HaveOccurred()) + Expect(transportError(val).String()).ToNot(BeEmpty()) + } + }) + + It("has a string representation for transport errors", func() { + Expect(transportError(qerr.NoError).String()).To(Equal("no_error")) + Expect(transportError(qerr.InternalError).String()).To(Equal("internal_error")) + Expect(transportError(qerr.ConnectionRefused).String()).To(Equal("connection_refused")) + Expect(transportError(qerr.FlowControlError).String()).To(Equal("flow_control_error")) + Expect(transportError(qerr.StreamLimitError).String()).To(Equal("stream_limit_error")) + Expect(transportError(qerr.StreamStateError).String()).To(Equal("stream_state_error")) + Expect(transportError(qerr.FrameEncodingError).String()).To(Equal("frame_encoding_error")) + Expect(transportError(qerr.ConnectionIDLimitError).String()).To(Equal("connection_id_limit_error")) + Expect(transportError(qerr.ProtocolViolation).String()).To(Equal("protocol_violation")) + Expect(transportError(qerr.InvalidToken).String()).To(Equal("invalid_token")) + Expect(transportError(qerr.ApplicationErrorErrorCode).String()).To(Equal("application_error")) + Expect(transportError(qerr.CryptoBufferExceeded).String()).To(Equal("crypto_buffer_exceeded")) + Expect(transportError(qerr.NoViablePathError).String()).To(Equal("no_viable_path")) + Expect(transportError(1337).String()).To(BeEmpty()) + }) + }) + + It("has a string representation for congestion state updates", func() { + Expect(congestionState(logging.CongestionStateSlowStart).String()).To(Equal("slow_start")) + Expect(congestionState(logging.CongestionStateCongestionAvoidance).String()).To(Equal("congestion_avoidance")) + Expect(congestionState(logging.CongestionStateApplicationLimited).String()).To(Equal("application_limited")) + Expect(congestionState(logging.CongestionStateRecovery).String()).To(Equal("recovery")) + }) +}) diff --git a/internal/quic-go/qtls/go116.go b/internal/quic-go/qtls/go116.go new file mode 100644 index 00000000..e3024624 --- /dev/null +++ b/internal/quic-go/qtls/go116.go @@ -0,0 +1,100 @@ +//go:build go1.16 && !go1.17 +// +build go1.16,!go1.17 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls-go1-16" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains inforamtion about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-16.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/internal/quic-go/qtls/go117.go b/internal/quic-go/qtls/go117.go new file mode 100644 index 00000000..bc385f19 --- /dev/null +++ b/internal/quic-go/qtls/go117.go @@ -0,0 +1,100 @@ +//go:build go1.17 && !go1.18 +// +build go1.17,!go1.18 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls-go1-17" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains inforamtion about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-17.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/internal/quic-go/qtls/go118.go b/internal/quic-go/qtls/go118.go new file mode 100644 index 00000000..5de030c7 --- /dev/null +++ b/internal/quic-go/qtls/go118.go @@ -0,0 +1,100 @@ +//go:build go1.18 && !go1.19 +// +build go1.18,!go1.19 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls-go1-18" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains inforamtion about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-18.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/internal/quic-go/qtls/go119.go b/internal/quic-go/qtls/go119.go new file mode 100644 index 00000000..86dcaea3 --- /dev/null +++ b/internal/quic-go/qtls/go119.go @@ -0,0 +1,100 @@ +//go:build go1.19 +// +build go1.19 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls-go1-19" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains information about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-19.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/internal/qtls/go_oldversion.go b/internal/quic-go/qtls/go_oldversion.go similarity index 80% rename from internal/qtls/go_oldversion.go rename to internal/quic-go/qtls/go_oldversion.go index 384d719c..c17b8589 100644 --- a/internal/qtls/go_oldversion.go +++ b/internal/quic-go/qtls/go_oldversion.go @@ -4,4 +4,4 @@ package qtls -var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." +var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/imroc/req/v3/internal/quic-go/wiki/quic-go-and-Go-versions." diff --git a/internal/quic-go/qtls/qtls_suite_test.go b/internal/quic-go/qtls/qtls_suite_test.go new file mode 100644 index 00000000..24b143b2 --- /dev/null +++ b/internal/quic-go/qtls/qtls_suite_test.go @@ -0,0 +1,25 @@ +package qtls + +import ( + "testing" + + gomock "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestQTLS(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "qtls Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) diff --git a/internal/quic-go/qtls/qtls_test.go b/internal/quic-go/qtls/qtls_test.go new file mode 100644 index 00000000..c64c5e9e --- /dev/null +++ b/internal/quic-go/qtls/qtls_test.go @@ -0,0 +1,17 @@ +package qtls + +import ( + "crypto/tls" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("qtls wrapper", func() { + It("gets cipher suites", func() { + for _, id := range []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384, tls.TLS_CHACHA20_POLY1305_SHA256} { + cs := CipherSuiteTLS13ByID(id) + Expect(cs.ID).To(Equal(id)) + } + }) +}) diff --git a/internal/quic-go/quic_suite_test.go b/internal/quic-go/quic_suite_test.go new file mode 100644 index 00000000..d8a0a2e0 --- /dev/null +++ b/internal/quic-go/quic_suite_test.go @@ -0,0 +1,34 @@ +package quic + +import ( + "io/ioutil" + "log" + "sync" + "testing" + + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestQuicGo(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "QUIC Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) + + // reset the sync.Once + connMuxerOnce = *new(sync.Once) +}) + +var _ = BeforeSuite(func() { + log.SetOutput(ioutil.Discard) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) diff --git a/internal/quicvarint/io.go b/internal/quic-go/quicvarint/io.go similarity index 100% rename from internal/quicvarint/io.go rename to internal/quic-go/quicvarint/io.go diff --git a/internal/quicvarint/io_test.go b/internal/quic-go/quicvarint/io_test.go similarity index 100% rename from internal/quicvarint/io_test.go rename to internal/quic-go/quicvarint/io_test.go diff --git a/internal/quicvarint/quicvarint_suite_test.go b/internal/quic-go/quicvarint/quicvarint_suite_test.go similarity index 100% rename from internal/quicvarint/quicvarint_suite_test.go rename to internal/quic-go/quicvarint/quicvarint_suite_test.go diff --git a/internal/quicvarint/varint.go b/internal/quic-go/quicvarint/varint.go similarity index 98% rename from internal/quicvarint/varint.go rename to internal/quic-go/quicvarint/varint.go index 3bb242fd..ba7e8772 100644 --- a/internal/quicvarint/varint.go +++ b/internal/quic-go/quicvarint/varint.go @@ -4,7 +4,7 @@ import ( "fmt" "io" - "github.com/imroc/req/v3/internal/protocol" + "github.com/imroc/req/v3/internal/quic-go/protocol" ) // taken from the QUIC draft diff --git a/internal/quicvarint/varint_test.go b/internal/quic-go/quicvarint/varint_test.go similarity index 100% rename from internal/quicvarint/varint_test.go rename to internal/quic-go/quicvarint/varint_test.go diff --git a/internal/quic-go/receive_stream.go b/internal/quic-go/receive_stream.go new file mode 100644 index 00000000..6fea6afd --- /dev/null +++ b/internal/quic-go/receive_stream.go @@ -0,0 +1,331 @@ +package quic + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/imroc/req/v3/internal/quic-go/flowcontrol" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type receiveStreamI interface { + ReceiveStream + + handleStreamFrame(*wire.StreamFrame) error + handleResetStreamFrame(*wire.ResetStreamFrame) error + closeForShutdown(error) + getWindowUpdate() protocol.ByteCount +} + +type receiveStream struct { + mutex sync.Mutex + + streamID protocol.StreamID + + sender streamSender + + frameQueue *frameSorter + finalOffset protocol.ByteCount + + currentFrame []byte + currentFrameDone func() + currentFrameIsLast bool // is the currentFrame the last frame on this stream + readPosInFrame int + + closeForShutdownErr error + cancelReadErr error + resetRemotelyErr *StreamError + + closedForShutdown bool // set when CloseForShutdown() is called + finRead bool // set once we read a frame with a Fin + canceledRead bool // set when CancelRead() is called + resetRemotely bool // set when HandleResetStreamFrame() is called + + readChan chan struct{} + readOnce chan struct{} // cap: 1, to protect against concurrent use of Read + deadline time.Time + + flowController flowcontrol.StreamFlowController + version protocol.VersionNumber +} + +var ( + _ ReceiveStream = &receiveStream{} + _ receiveStreamI = &receiveStream{} +) + +func newReceiveStream( + streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *receiveStream { + return &receiveStream{ + streamID: streamID, + sender: sender, + flowController: flowController, + frameQueue: newFrameSorter(), + readChan: make(chan struct{}, 1), + readOnce: make(chan struct{}, 1), + finalOffset: protocol.MaxByteCount, + version: version, + } +} + +func (s *receiveStream) StreamID() protocol.StreamID { + return s.streamID +} + +// Read implements io.Reader. It is not thread safe! +func (s *receiveStream) Read(p []byte) (int, error) { + // Concurrent use of Read is not permitted (and doesn't make any sense), + // but sometimes people do it anyway. + // Make sure that we only execute one call at any given time to avoid hard to debug failures. + s.readOnce <- struct{}{} + defer func() { <-s.readOnce }() + + s.mutex.Lock() + completed, n, err := s.readImpl(p) + s.mutex.Unlock() + + if completed { + s.sender.onStreamCompleted(s.streamID) + } + return n, err +} + +func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, error) { + if s.finRead { + return false, 0, io.EOF + } + if s.canceledRead { + return false, 0, s.cancelReadErr + } + if s.resetRemotely { + return false, 0, s.resetRemotelyErr + } + if s.closedForShutdown { + return false, 0, s.closeForShutdownErr + } + + var bytesRead int + var deadlineTimer *utils.Timer + for bytesRead < len(p) { + if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) { + s.dequeueNextFrame() + } + if s.currentFrame == nil && bytesRead > 0 { + return false, bytesRead, s.closeForShutdownErr + } + + for { + // Stop waiting on errors + if s.closedForShutdown { + return false, bytesRead, s.closeForShutdownErr + } + if s.canceledRead { + return false, bytesRead, s.cancelReadErr + } + if s.resetRemotely { + return false, bytesRead, s.resetRemotelyErr + } + + deadline := s.deadline + if !deadline.IsZero() { + if !time.Now().Before(deadline) { + return false, bytesRead, errDeadline + } + if deadlineTimer == nil { + deadlineTimer = utils.NewTimer() + defer deadlineTimer.Stop() + } + deadlineTimer.Reset(deadline) + } + + if s.currentFrame != nil || s.currentFrameIsLast { + break + } + + s.mutex.Unlock() + if deadline.IsZero() { + <-s.readChan + } else { + select { + case <-s.readChan: + case <-deadlineTimer.Chan(): + deadlineTimer.SetRead() + } + } + s.mutex.Lock() + if s.currentFrame == nil { + s.dequeueNextFrame() + } + } + + if bytesRead > len(p) { + return false, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) + } + if s.readPosInFrame > len(s.currentFrame) { + return false, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) + } + + m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:]) + s.readPosInFrame += m + bytesRead += m + + // when a RESET_STREAM was received, the was already informed about the final byteOffset for this stream + if !s.resetRemotely { + s.flowController.AddBytesRead(protocol.ByteCount(m)) + } + + if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast { + s.finRead = true + return true, bytesRead, io.EOF + } + } + return false, bytesRead, nil +} + +func (s *receiveStream) dequeueNextFrame() { + var offset protocol.ByteCount + // We're done with the last frame. Release the buffer. + if s.currentFrameDone != nil { + s.currentFrameDone() + } + offset, s.currentFrame, s.currentFrameDone = s.frameQueue.Pop() + s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset + s.readPosInFrame = 0 +} + +func (s *receiveStream) CancelRead(errorCode StreamErrorCode) { + s.mutex.Lock() + completed := s.cancelReadImpl(errorCode) + s.mutex.Unlock() + + if completed { + s.flowController.Abandon() + s.sender.onStreamCompleted(s.streamID) + } +} + +func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ { + if s.finRead || s.canceledRead || s.resetRemotely { + return false + } + s.canceledRead = true + s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) + s.signalRead() + s.sender.queueControlFrame(&wire.StopSendingFrame{ + StreamID: s.streamID, + ErrorCode: errorCode, + }) + // We're done with this stream if the final offset was already received. + return s.finalOffset != protocol.MaxByteCount +} + +func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { + s.mutex.Lock() + completed, err := s.handleStreamFrameImpl(frame) + s.mutex.Unlock() + + if completed { + s.flowController.Abandon() + s.sender.onStreamCompleted(s.streamID) + } + return err +} + +func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* completed */, error) { + maxOffset := frame.Offset + frame.DataLen() + if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil { + return false, err + } + var newlyRcvdFinalOffset bool + if frame.Fin { + newlyRcvdFinalOffset = s.finalOffset == protocol.MaxByteCount + s.finalOffset = maxOffset + } + if s.canceledRead { + return newlyRcvdFinalOffset, nil + } + if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil { + return false, err + } + s.signalRead() + return false, nil +} + +func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { + s.mutex.Lock() + completed, err := s.handleResetStreamFrameImpl(frame) + s.mutex.Unlock() + + if completed { + s.flowController.Abandon() + s.sender.onStreamCompleted(s.streamID) + } + return err +} + +func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) (bool /*completed */, error) { + if s.closedForShutdown { + return false, nil + } + if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil { + return false, err + } + newlyRcvdFinalOffset := s.finalOffset == protocol.MaxByteCount + s.finalOffset = frame.FinalSize + + // ignore duplicate RESET_STREAM frames for this stream (after checking their final offset) + if s.resetRemotely { + return false, nil + } + s.resetRemotely = true + s.resetRemotelyErr = &StreamError{ + StreamID: s.streamID, + ErrorCode: frame.ErrorCode, + } + s.signalRead() + return newlyRcvdFinalOffset, nil +} + +func (s *receiveStream) CloseRemote(offset protocol.ByteCount) { + s.handleStreamFrame(&wire.StreamFrame{Fin: true, Offset: offset}) +} + +func (s *receiveStream) SetReadDeadline(t time.Time) error { + s.mutex.Lock() + s.deadline = t + s.mutex.Unlock() + s.signalRead() + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Read unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RESET. +func (s *receiveStream) closeForShutdown(err error) { + s.mutex.Lock() + s.closedForShutdown = true + s.closeForShutdownErr = err + s.mutex.Unlock() + s.signalRead() +} + +func (s *receiveStream) getWindowUpdate() protocol.ByteCount { + return s.flowController.GetWindowUpdate() +} + +// signalRead performs a non-blocking send on the readChan +func (s *receiveStream) signalRead() { + select { + case s.readChan <- struct{}{}: + default: + } +} diff --git a/internal/quic-go/receive_stream_test.go b/internal/quic-go/receive_stream_test.go new file mode 100644 index 00000000..06b30ef9 --- /dev/null +++ b/internal/quic-go/receive_stream_test.go @@ -0,0 +1,696 @@ +package quic + +import ( + "errors" + "io" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/golang/mock/gomock" + "github.com/imroc/req/v3/internal/quic-go/mocks" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" +) + +var _ = Describe("Receive Stream", func() { + const streamID protocol.StreamID = 1337 + + var ( + str *receiveStream + strWithTimeout io.Reader // str wrapped with gbytes.TimeoutReader + mockFC *mocks.MockStreamFlowController + mockSender *MockStreamSender + ) + + BeforeEach(func() { + mockSender = NewMockStreamSender(mockCtrl) + mockFC = mocks.NewMockStreamFlowController(mockCtrl) + str = newReceiveStream(streamID, mockSender, mockFC, protocol.VersionWhatever) + + timeout := scaleDuration(250 * time.Millisecond) + strWithTimeout = gbytes.TimeoutReader(str, timeout) + }) + + It("gets stream id", func() { + Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) + }) + + Context("reading", func() { + It("reads a single STREAM frame", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) + frame := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + } + err := str.handleStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + }) + + It("reads a single STREAM frame in multiple goes", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + frame := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + } + err := str.handleStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 2) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(2)) + Expect(b).To(Equal([]byte{0xDE, 0xAD})) + n, err = strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(2)) + Expect(b).To(Equal([]byte{0xBE, 0xEF})) + }) + + It("reads all data available", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + frame1 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + frame2 := wire.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 6) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x00})) + }) + + It("assembles multiple STREAM frames", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + frame1 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + frame2 := wire.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + }) + + It("waits until data is available", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + go func() { + defer GinkgoRecover() + frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}} + time.Sleep(10 * time.Millisecond) + err := str.handleStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) + }() + b := make([]byte, 2) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(2)) + }) + + It("handles STREAM frames in wrong order", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + frame1 := wire.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + } + frame2 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + }) + + It("ignores duplicate STREAM frames", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + frame1 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + frame2 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0x13, 0x37}, + } + frame3 := wire.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame3) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + }) + + It("doesn't rejects a STREAM frames with an overlapping data range", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) + frame1 := wire.StreamFrame{ + Offset: 0, + Data: []byte("foob"), + } + frame2 := wire.StreamFrame{ + Offset: 2, + Data: []byte("obar"), + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 6) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b).To(Equal([]byte("foobar"))) + }) + + Context("deadlines", func() { + It("the deadline error has the right net.Error properties", func() { + Expect(errDeadline.Timeout()).To(BeTrue()) + Expect(errDeadline).To(MatchError("deadline exceeded")) + }) + + It("returns an error when Read is called after the deadline", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes() + f := &wire.StreamFrame{Data: []byte("foobar")} + err := str.handleStreamFrame(f) + Expect(err).ToNot(HaveOccurred()) + str.SetReadDeadline(time.Now().Add(-time.Second)) + b := make([]byte, 6) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + + It("unblocks when the deadline is changed to the past", func() { + str.SetReadDeadline(time.Now().Add(time.Hour)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Read(make([]byte, 6)) + Expect(err).To(MatchError(errDeadline)) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + str.SetReadDeadline(time.Now().Add(-time.Hour)) + Eventually(done).Should(BeClosed()) + }) + + It("unblocks after the deadline", func() { + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetReadDeadline(deadline) + b := make([]byte, 6) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) + }) + + It("doesn't unblock if the deadline is changed before the first one expires", func() { + deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) + str.SetReadDeadline(deadline1) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(20 * time.Millisecond)) + str.SetReadDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline1)) + }() + runtime.Gosched() + b := make([]byte, 10) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) + }) + + It("unblocks earlier, when a new deadline is set", func() { + deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(10 * time.Millisecond)) + str.SetReadDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline2)) + }() + str.SetReadDeadline(deadline1) + runtime.Gosched() + b := make([]byte, 10) + _, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(25*time.Millisecond))) + }) + + It("doesn't unblock if the deadline is removed", func() { + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetReadDeadline(deadline) + deadlineUnset := make(chan struct{}) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(20 * time.Millisecond)) + str.SetReadDeadline(time.Time{}) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline)) + close(deadlineUnset) + }() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Read(make([]byte, 1)) + Expect(err).To(MatchError("test done")) + close(done) + }() + runtime.Gosched() + Eventually(deadlineUnset).Should(BeClosed()) + Consistently(done, scaleDuration(100*time.Millisecond)).ShouldNot(BeClosed()) + // make the go routine return + str.closeForShutdown(errors.New("test done")) + Eventually(done).Should(BeClosed()) + }) + }) + + Context("closing", func() { + Context("with FIN bit", func() { + It("returns EOFs", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) + str.handleStreamFrame(&wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + Fin: true, + }) + mockSender.EXPECT().onStreamCompleted(streamID) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(io.EOF)) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + n, err = strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + + It("handles out-of-order frames", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + frame1 := wire.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + Fin: true, + } + frame2 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onStreamCompleted(streamID) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(io.EOF)) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + n, err = strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + + It("returns EOFs with partial read", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + err := str.handleStreamFrame(&wire.StreamFrame{ + Offset: 0, + Data: []byte{0xde, 0xad}, + Fin: true, + }) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onStreamCompleted(streamID) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(io.EOF)) + Expect(n).To(Equal(2)) + Expect(b[:n]).To(Equal([]byte{0xde, 0xad})) + }) + + It("handles immediate FINs", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) + err := str.handleStreamFrame(&wire.StreamFrame{ + Offset: 0, + Fin: true, + }) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onStreamCompleted(streamID) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + + // Calling Read concurrently doesn't make any sense (and is forbidden), + // but we still want to make sure that we don't complete the stream more than once + // if the user misuses our API. + // This would lead to an INTERNAL_ERROR ("tried to delete unknown outgoing stream"), + // which can be hard to debug. + // Note that even without the protection built into the receiveStream, this test + // is very timing-dependent, and would need to run a few hundred times to trigger the failure. + It("handles concurrent reads", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any()).AnyTimes() + var bytesRead protocol.ByteCount + mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) { bytesRead += n }).AnyTimes() + + var numCompleted int32 + mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { + atomic.AddInt32(&numCompleted, 1) + }).AnyTimes() + const num = 3 + var wg sync.WaitGroup + wg.Add(num) + for i := 0; i < num; i++ { + go func() { + defer wg.Done() + defer GinkgoRecover() + _, err := str.Read(make([]byte, 8)) + Expect(err).To(MatchError(io.EOF)) + }() + } + str.handleStreamFrame(&wire.StreamFrame{ + Offset: 0, + Data: []byte("foobar"), + Fin: true, + }) + wg.Wait() + Expect(bytesRead).To(BeEquivalentTo(6)) + Expect(atomic.LoadInt32(&numCompleted)).To(BeEquivalentTo(1)) + }) + }) + + It("closes when CloseRemote is called", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) + str.CloseRemote(0) + mockSender.EXPECT().onStreamCompleted(streamID) + b := make([]byte, 8) + n, err := strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + }) + + Context("closing for shutdown", func() { + testErr := errors.New("test error") + + It("immediately returns all reads", func() { + done := make(chan struct{}) + b := make([]byte, 4) + go func() { + defer GinkgoRecover() + n, err := strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(testErr)) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + str.closeForShutdown(testErr) + Eventually(done).Should(BeClosed()) + }) + + It("errors for all following reads", func() { + str.closeForShutdown(testErr) + b := make([]byte, 1) + n, err := strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(testErr)) + }) + }) + }) + + Context("stream cancelations", func() { + Context("canceling read", func() { + It("unblocks Read", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Read([]byte{0}) + Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + str.CancelRead(1234) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't allow further calls to Read", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + str.CancelRead(1234) + _, err := strWithTimeout.Read([]byte{0}) + Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) + }) + + It("does nothing when CancelRead is called twice", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + str.CancelRead(1234) + str.CancelRead(1234) + _, err := strWithTimeout.Read([]byte{0}) + Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) + }) + + It("queues a STOP_SENDING frame", func() { + mockSender.EXPECT().queueControlFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 1234, + }) + str.CancelRead(1234) + }) + + It("doesn't send a STOP_SENDING frame, if the FIN was already read", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + // no calls to mockSender.queueControlFrame + Expect(str.handleStreamFrame(&wire.StreamFrame{ + StreamID: streamID, + Data: []byte("foobar"), + Fin: true, + })).To(Succeed()) + mockSender.EXPECT().onStreamCompleted(streamID) + _, err := strWithTimeout.Read(make([]byte, 100)) + Expect(err).To(MatchError(io.EOF)) + str.CancelRead(1234) + }) + + It("doesn't send a STOP_SENDING frame, if the stream was already reset", func() { + gomock.InOrder( + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true), + mockFC.EXPECT().Abandon(), + ) + mockSender.EXPECT().onStreamCompleted(streamID) + Expect(str.handleResetStreamFrame(&wire.ResetStreamFrame{ + StreamID: streamID, + FinalSize: 42, + })).To(Succeed()) + str.CancelRead(1234) + }) + + It("sends a STOP_SENDING and completes the stream after receiving the final offset", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true) + Expect(str.handleStreamFrame(&wire.StreamFrame{ + Offset: 1000, + Fin: true, + })).To(Succeed()) + mockFC.EXPECT().Abandon() + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(streamID) + str.CancelRead(1234) + }) + + It("completes the stream when receiving the Fin after the stream was canceled", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + str.CancelRead(1234) + gomock.InOrder( + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true), + mockFC.EXPECT().Abandon(), + ) + mockSender.EXPECT().onStreamCompleted(streamID) + Expect(str.handleStreamFrame(&wire.StreamFrame{ + Offset: 1000, + Fin: true, + })).To(Succeed()) + }) + + It("handles duplicate FinBits after the stream was canceled", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + str.CancelRead(1234) + gomock.InOrder( + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true), + mockFC.EXPECT().Abandon(), + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true), + ) + mockSender.EXPECT().onStreamCompleted(streamID) + Expect(str.handleStreamFrame(&wire.StreamFrame{ + Offset: 1000, + Fin: true, + })).To(Succeed()) + Expect(str.handleStreamFrame(&wire.StreamFrame{ + Offset: 1000, + Fin: true, + })).To(Succeed()) + }) + }) + + Context("receiving RESET_STREAM frames", func() { + rst := &wire.ResetStreamFrame{ + StreamID: streamID, + FinalSize: 42, + ErrorCode: 1234, + } + + It("unblocks Read", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Read([]byte{0}) + Expect(err).To(MatchError(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + })) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + mockSender.EXPECT().onStreamCompleted(streamID) + gomock.InOrder( + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true), + mockFC.EXPECT().Abandon(), + ) + Expect(str.handleResetStreamFrame(rst)).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't allow further calls to Read", func() { + mockSender.EXPECT().onStreamCompleted(streamID) + gomock.InOrder( + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true), + mockFC.EXPECT().Abandon(), + ) + Expect(str.handleResetStreamFrame(rst)).To(Succeed()) + _, err := strWithTimeout.Read([]byte{0}) + Expect(err).To(MatchError(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + })) + }) + + It("errors when receiving a RESET_STREAM with an inconsistent offset", func() { + testErr := errors.New("already received a different final offset before") + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Return(testErr) + err := str.handleResetStreamFrame(rst) + Expect(err).To(MatchError(testErr)) + }) + + It("ignores duplicate RESET_STREAM frames", func() { + mockSender.EXPECT().onStreamCompleted(streamID) + mockFC.EXPECT().Abandon() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2) + Expect(str.handleResetStreamFrame(rst)).To(Succeed()) + Expect(str.handleResetStreamFrame(rst)).To(Succeed()) + }) + + It("doesn't call onStreamCompleted again when the final offset was already received via Fin", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + str.CancelRead(1234) + mockSender.EXPECT().onStreamCompleted(streamID) + mockFC.EXPECT().Abandon() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2) + Expect(str.handleStreamFrame(&wire.StreamFrame{ + StreamID: streamID, + Offset: rst.FinalSize, + Fin: true, + })).To(Succeed()) + Expect(str.handleResetStreamFrame(rst)).To(Succeed()) + }) + + It("doesn't do anyting when it was closed for shutdown", func() { + str.closeForShutdown(nil) + err := str.handleResetStreamFrame(rst) + Expect(err).ToNot(HaveOccurred()) + }) + }) + }) + + Context("flow control", func() { + It("errors when a STREAM frame causes a flow control violation", func() { + testErr := errors.New("flow control violation") + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false).Return(testErr) + frame := wire.StreamFrame{ + Offset: 2, + Data: []byte("foobar"), + } + err := str.handleStreamFrame(&frame) + Expect(err).To(MatchError(testErr)) + }) + + It("gets a window update", func() { + mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x100)) + Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100))) + }) + }) +}) diff --git a/internal/quic-go/retransmission_queue.go b/internal/quic-go/retransmission_queue.go new file mode 100644 index 00000000..57d54e5f --- /dev/null +++ b/internal/quic-go/retransmission_queue.go @@ -0,0 +1,131 @@ +package quic + +import ( + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type retransmissionQueue struct { + initial []wire.Frame + initialCryptoData []*wire.CryptoFrame + + handshake []wire.Frame + handshakeCryptoData []*wire.CryptoFrame + + appData []wire.Frame + + version protocol.VersionNumber +} + +func newRetransmissionQueue(ver protocol.VersionNumber) *retransmissionQueue { + return &retransmissionQueue{version: ver} +} + +func (q *retransmissionQueue) AddInitial(f wire.Frame) { + if cf, ok := f.(*wire.CryptoFrame); ok { + q.initialCryptoData = append(q.initialCryptoData, cf) + return + } + q.initial = append(q.initial, f) +} + +func (q *retransmissionQueue) AddHandshake(f wire.Frame) { + if cf, ok := f.(*wire.CryptoFrame); ok { + q.handshakeCryptoData = append(q.handshakeCryptoData, cf) + return + } + q.handshake = append(q.handshake, f) +} + +func (q *retransmissionQueue) HasInitialData() bool { + return len(q.initialCryptoData) > 0 || len(q.initial) > 0 +} + +func (q *retransmissionQueue) HasHandshakeData() bool { + return len(q.handshakeCryptoData) > 0 || len(q.handshake) > 0 +} + +func (q *retransmissionQueue) HasAppData() bool { + return len(q.appData) > 0 +} + +func (q *retransmissionQueue) AddAppData(f wire.Frame) { + if _, ok := f.(*wire.StreamFrame); ok { + panic("STREAM frames are handled with their respective streams.") + } + q.appData = append(q.appData, f) +} + +func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount) wire.Frame { + if len(q.initialCryptoData) > 0 { + f := q.initialCryptoData[0] + newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) + if newFrame == nil && !needsSplit { // the whole frame fits + q.initialCryptoData = q.initialCryptoData[1:] + return f + } + if newFrame != nil { // frame was split. Leave the original frame in the queue. + return newFrame + } + } + if len(q.initial) == 0 { + return nil + } + f := q.initial[0] + if f.Length(q.version) > maxLen { + return nil + } + q.initial = q.initial[1:] + return f +} + +func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount) wire.Frame { + if len(q.handshakeCryptoData) > 0 { + f := q.handshakeCryptoData[0] + newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) + if newFrame == nil && !needsSplit { // the whole frame fits + q.handshakeCryptoData = q.handshakeCryptoData[1:] + return f + } + if newFrame != nil { // frame was split. Leave the original frame in the queue. + return newFrame + } + } + if len(q.handshake) == 0 { + return nil + } + f := q.handshake[0] + if f.Length(q.version) > maxLen { + return nil + } + q.handshake = q.handshake[1:] + return f +} + +func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount) wire.Frame { + if len(q.appData) == 0 { + return nil + } + f := q.appData[0] + if f.Length(q.version) > maxLen { + return nil + } + q.appData = q.appData[1:] + return f +} + +func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) { + //nolint:exhaustive // Can only drop Initial and Handshake packet number space. + switch encLevel { + case protocol.EncryptionInitial: + q.initial = nil + q.initialCryptoData = nil + case protocol.EncryptionHandshake: + q.handshake = nil + q.handshakeCryptoData = nil + default: + panic(fmt.Sprintf("unexpected encryption level: %s", encLevel)) + } +} diff --git a/internal/quic-go/retransmission_queue_test.go b/internal/quic-go/retransmission_queue_test.go new file mode 100644 index 00000000..4780571f --- /dev/null +++ b/internal/quic-go/retransmission_queue_test.go @@ -0,0 +1,187 @@ +package quic + +import ( + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Retransmission queue", func() { + const version = protocol.VersionTLS + + var q *retransmissionQueue + + BeforeEach(func() { + q = newRetransmissionQueue(version) + }) + + Context("Initial data", func() { + It("doesn't dequeue anything when it's empty", func() { + Expect(q.HasInitialData()).To(BeFalse()) + Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil()) + }) + + It("queues and retrieves a control frame", func() { + f := &wire.MaxDataFrame{MaximumData: 0x42} + q.AddInitial(f) + Expect(q.HasInitialData()).To(BeTrue()) + Expect(q.GetInitialFrame(f.Length(version) - 1)).To(BeNil()) + Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f)) + Expect(q.HasInitialData()).To(BeFalse()) + }) + + It("queues and retrieves a CRYPTO frame", func() { + f := &wire.CryptoFrame{Data: []byte("foobar")} + q.AddInitial(f) + Expect(q.HasInitialData()).To(BeTrue()) + Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f)) + Expect(q.HasInitialData()).To(BeFalse()) + }) + + It("returns split CRYPTO frames", func() { + f := &wire.CryptoFrame{ + Offset: 100, + Data: []byte("foobar"), + } + q.AddInitial(f) + Expect(q.HasInitialData()).To(BeTrue()) + f1 := q.GetInitialFrame(f.Length(version) - 3) + Expect(f1).ToNot(BeNil()) + Expect(f1).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) + Expect(f1.(*wire.CryptoFrame).Data).To(Equal([]byte("foo"))) + Expect(f1.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(100))) + Expect(q.HasInitialData()).To(BeTrue()) + f2 := q.GetInitialFrame(protocol.MaxByteCount) + Expect(f2).ToNot(BeNil()) + Expect(f2).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) + Expect(f2.(*wire.CryptoFrame).Data).To(Equal([]byte("bar"))) + Expect(f2.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(103))) + Expect(q.HasInitialData()).To(BeFalse()) + }) + + It("returns other frames when a CRYPTO frame wouldn't fit", func() { + f := &wire.CryptoFrame{Data: []byte("foobar")} + q.AddInitial(f) + q.AddInitial(&wire.PingFrame{}) + f1 := q.GetInitialFrame(2) // too small for a CRYPTO frame + Expect(f1).ToNot(BeNil()) + Expect(f1).To(BeAssignableToTypeOf(&wire.PingFrame{})) + Expect(q.HasInitialData()).To(BeTrue()) + f2 := q.GetInitialFrame(protocol.MaxByteCount) + Expect(f2).To(Equal(f)) + }) + + It("retrieves both a CRYPTO frame and a control frame", func() { + cf := &wire.MaxDataFrame{MaximumData: 0x42} + f := &wire.CryptoFrame{Data: []byte("foobar")} + q.AddInitial(f) + q.AddInitial(cf) + Expect(q.HasInitialData()).To(BeTrue()) + Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(f)) + Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(cf)) + Expect(q.HasInitialData()).To(BeFalse()) + }) + + It("drops all Initial frames", func() { + q.AddInitial(&wire.CryptoFrame{Data: []byte("foobar")}) + q.AddInitial(&wire.MaxDataFrame{MaximumData: 0x42}) + q.DropPackets(protocol.EncryptionInitial) + Expect(q.HasInitialData()).To(BeFalse()) + Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil()) + }) + }) + + Context("Handshake data", func() { + It("doesn't dequeue anything when it's empty", func() { + Expect(q.HasHandshakeData()).To(BeFalse()) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil()) + }) + + It("queues and retrieves a control frame", func() { + f := &wire.MaxDataFrame{MaximumData: 0x42} + q.AddHandshake(f) + Expect(q.HasHandshakeData()).To(BeTrue()) + Expect(q.GetHandshakeFrame(f.Length(version) - 1)).To(BeNil()) + Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f)) + Expect(q.HasHandshakeData()).To(BeFalse()) + }) + + It("queues and retrieves a CRYPTO frame", func() { + f := &wire.CryptoFrame{Data: []byte("foobar")} + q.AddHandshake(f) + Expect(q.HasHandshakeData()).To(BeTrue()) + Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f)) + Expect(q.HasHandshakeData()).To(BeFalse()) + }) + + It("returns split CRYPTO frames", func() { + f := &wire.CryptoFrame{ + Offset: 100, + Data: []byte("foobar"), + } + q.AddHandshake(f) + Expect(q.HasHandshakeData()).To(BeTrue()) + f1 := q.GetHandshakeFrame(f.Length(version) - 3) + Expect(f1).ToNot(BeNil()) + Expect(f1).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) + Expect(f1.(*wire.CryptoFrame).Data).To(Equal([]byte("foo"))) + Expect(f1.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(100))) + Expect(q.HasHandshakeData()).To(BeTrue()) + f2 := q.GetHandshakeFrame(protocol.MaxByteCount) + Expect(f2).ToNot(BeNil()) + Expect(f2).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) + Expect(f2.(*wire.CryptoFrame).Data).To(Equal([]byte("bar"))) + Expect(f2.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(103))) + Expect(q.HasHandshakeData()).To(BeFalse()) + }) + + It("returns other frames when a CRYPTO frame wouldn't fit", func() { + f := &wire.CryptoFrame{Data: []byte("foobar")} + q.AddHandshake(f) + q.AddHandshake(&wire.PingFrame{}) + f1 := q.GetHandshakeFrame(2) // too small for a CRYPTO frame + Expect(f1).ToNot(BeNil()) + Expect(f1).To(BeAssignableToTypeOf(&wire.PingFrame{})) + Expect(q.HasHandshakeData()).To(BeTrue()) + f2 := q.GetHandshakeFrame(protocol.MaxByteCount) + Expect(f2).To(Equal(f)) + }) + + It("retrieves both a CRYPTO frame and a control frame", func() { + cf := &wire.MaxDataFrame{MaximumData: 0x42} + f := &wire.CryptoFrame{Data: []byte("foobar")} + q.AddHandshake(f) + q.AddHandshake(cf) + Expect(q.HasHandshakeData()).To(BeTrue()) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(f)) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(cf)) + Expect(q.HasHandshakeData()).To(BeFalse()) + }) + + It("drops all Handshake frames", func() { + q.AddHandshake(&wire.CryptoFrame{Data: []byte("foobar")}) + q.AddHandshake(&wire.MaxDataFrame{MaximumData: 0x42}) + q.DropPackets(protocol.EncryptionHandshake) + Expect(q.HasHandshakeData()).To(BeFalse()) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil()) + }) + }) + + Context("Application data", func() { + It("doesn't dequeue anything when it's empty", func() { + Expect(q.GetAppDataFrame(protocol.MaxByteCount)).To(BeNil()) + }) + + It("queues and retrieves a control frame", func() { + f := &wire.MaxDataFrame{MaximumData: 0x42} + Expect(q.HasAppData()).To(BeFalse()) + q.AddAppData(f) + Expect(q.HasAppData()).To(BeTrue()) + Expect(q.GetAppDataFrame(f.Length(version) - 1)).To(BeNil()) + Expect(q.GetAppDataFrame(f.Length(version))).To(Equal(f)) + Expect(q.HasAppData()).To(BeFalse()) + }) + }) +}) diff --git a/internal/quic-go/send_conn.go b/internal/quic-go/send_conn.go new file mode 100644 index 00000000..c53ebdfa --- /dev/null +++ b/internal/quic-go/send_conn.go @@ -0,0 +1,74 @@ +package quic + +import ( + "net" +) + +// A sendConn allows sending using a simple Write() on a non-connected packet conn. +type sendConn interface { + Write([]byte) error + Close() error + LocalAddr() net.Addr + RemoteAddr() net.Addr +} + +type sconn struct { + rawConn + + remoteAddr net.Addr + info *packetInfo + oob []byte +} + +var _ sendConn = &sconn{} + +func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn { + return &sconn{ + rawConn: c, + remoteAddr: remote, + info: info, + oob: info.OOB(), + } +} + +func (c *sconn) Write(p []byte) error { + _, err := c.WritePacket(p, c.remoteAddr, c.oob) + return err +} + +func (c *sconn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *sconn) LocalAddr() net.Addr { + addr := c.rawConn.LocalAddr() + if c.info != nil { + if udpAddr, ok := addr.(*net.UDPAddr); ok { + addrCopy := *udpAddr + addrCopy.IP = c.info.addr + addr = &addrCopy + } + } + return addr +} + +type spconn struct { + net.PacketConn + + remoteAddr net.Addr +} + +var _ sendConn = &spconn{} + +func newSendPconn(c net.PacketConn, remote net.Addr) sendConn { + return &spconn{PacketConn: c, remoteAddr: remote} +} + +func (c *spconn) Write(p []byte) error { + _, err := c.WriteTo(p, c.remoteAddr) + return err +} + +func (c *spconn) RemoteAddr() net.Addr { + return c.remoteAddr +} diff --git a/internal/quic-go/send_conn_test.go b/internal/quic-go/send_conn_test.go new file mode 100644 index 00000000..15e8f3b4 --- /dev/null +++ b/internal/quic-go/send_conn_test.go @@ -0,0 +1,45 @@ +package quic + +import ( + "net" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Connection (for sending packets)", func() { + var ( + c sendConn + packetConn *MockPacketConn + addr net.Addr + ) + + BeforeEach(func() { + addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} + packetConn = NewMockPacketConn(mockCtrl) + c = newSendPconn(packetConn, addr) + }) + + It("writes", func() { + packetConn.EXPECT().WriteTo([]byte("foobar"), addr) + Expect(c.Write([]byte("foobar"))).To(Succeed()) + }) + + It("gets the remote address", func() { + Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337")) + }) + + It("gets the local address", func() { + addr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 0, 1), + Port: 1234, + } + packetConn.EXPECT().LocalAddr().Return(addr) + Expect(c.LocalAddr()).To(Equal(addr)) + }) + + It("closes", func() { + packetConn.EXPECT().Close() + Expect(c.Close()).To(Succeed()) + }) +}) diff --git a/internal/quic-go/send_queue.go b/internal/quic-go/send_queue.go new file mode 100644 index 00000000..1fc8c1bf --- /dev/null +++ b/internal/quic-go/send_queue.go @@ -0,0 +1,88 @@ +package quic + +type sender interface { + Send(p *packetBuffer) + Run() error + WouldBlock() bool + Available() <-chan struct{} + Close() +} + +type sendQueue struct { + queue chan *packetBuffer + closeCalled chan struct{} // runStopped when Close() is called + runStopped chan struct{} // runStopped when the run loop returns + available chan struct{} + conn sendConn +} + +var _ sender = &sendQueue{} + +const sendQueueCapacity = 8 + +func newSendQueue(conn sendConn) sender { + return &sendQueue{ + conn: conn, + runStopped: make(chan struct{}), + closeCalled: make(chan struct{}), + available: make(chan struct{}, 1), + queue: make(chan *packetBuffer, sendQueueCapacity), + } +} + +// Send sends out a packet. It's guaranteed to not block. +// Callers need to make sure that there's actually space in the send queue by calling WouldBlock. +// Otherwise Send will panic. +func (h *sendQueue) Send(p *packetBuffer) { + select { + case h.queue <- p: + case <-h.runStopped: + default: + panic("sendQueue.Send would have blocked") + } +} + +func (h *sendQueue) WouldBlock() bool { + return len(h.queue) == sendQueueCapacity +} + +func (h *sendQueue) Available() <-chan struct{} { + return h.available +} + +func (h *sendQueue) Run() error { + defer close(h.runStopped) + var shouldClose bool + for { + if shouldClose && len(h.queue) == 0 { + return nil + } + select { + case <-h.closeCalled: + h.closeCalled = nil // prevent this case from being selected again + // make sure that all queued packets are actually sent out + shouldClose = true + case p := <-h.queue: + if err := h.conn.Write(p.Data); err != nil { + // This additional check enables: + // 1. Checking for "datagram too large" message from the kernel, as such, + // 2. Path MTU discovery,and + // 3. Eventual detection of loss PingFrame. + if !isMsgSizeErr(err) { + return err + } + } + p.Release() + select { + case h.available <- struct{}{}: + default: + } + } + } +} + +func (h *sendQueue) Close() { + close(h.closeCalled) + // wait until the run loop returned + <-h.runStopped +} diff --git a/internal/quic-go/send_queue_test.go b/internal/quic-go/send_queue_test.go new file mode 100644 index 00000000..dc3179c4 --- /dev/null +++ b/internal/quic-go/send_queue_test.go @@ -0,0 +1,126 @@ +package quic + +import ( + "errors" + + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Send Queue", func() { + var q sender + var c *MockSendConn + + BeforeEach(func() { + c = NewMockSendConn(mockCtrl) + q = newSendQueue(c) + }) + + getPacket := func(b []byte) *packetBuffer { + buf := getPacketBuffer() + buf.Data = buf.Data[:len(b)] + copy(buf.Data, b) + return buf + } + + It("sends a packet", func() { + p := getPacket([]byte("foobar")) + q.Send(p) + + written := make(chan struct{}) + c.EXPECT().Write([]byte("foobar")).Do(func([]byte) { close(written) }) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + q.Run() + close(done) + }() + + Eventually(written).Should(BeClosed()) + q.Close() + Eventually(done).Should(BeClosed()) + }) + + It("panics when Send() is called although there's no space in the queue", func() { + for i := 0; i < sendQueueCapacity; i++ { + Expect(q.WouldBlock()).To(BeFalse()) + q.Send(getPacket([]byte("foobar"))) + } + Expect(q.WouldBlock()).To(BeTrue()) + Expect(func() { q.Send(getPacket([]byte("raboof"))) }).To(Panic()) + }) + + It("signals when sending is possible again", func() { + Expect(q.WouldBlock()).To(BeFalse()) + q.Send(getPacket([]byte("foobar1"))) + Consistently(q.Available()).ShouldNot(Receive()) + + // now start sending out packets. This should free up queue space. + c.EXPECT().Write(gomock.Any()).MinTimes(1).MaxTimes(2) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + q.Run() + close(done) + }() + + Eventually(q.Available()).Should(Receive()) + Expect(q.WouldBlock()).To(BeFalse()) + Expect(func() { q.Send(getPacket([]byte("foobar2"))) }).ToNot(Panic()) + + q.Close() + Eventually(done).Should(BeClosed()) + }) + + It("does not block pending send after the queue has stopped running", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + q.Run() + close(done) + }() + + // the run loop exits if there is a write error + testErr := errors.New("test error") + c.EXPECT().Write(gomock.Any()).Return(testErr) + q.Send(getPacket([]byte("foobar"))) + Eventually(done).Should(BeClosed()) + + sent := make(chan struct{}) + go func() { + defer GinkgoRecover() + q.Send(getPacket([]byte("raboof"))) + q.Send(getPacket([]byte("quux"))) + close(sent) + }() + + Eventually(sent).Should(BeClosed()) + }) + + It("blocks Close() until the packet has been sent out", func() { + written := make(chan []byte) + c.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + q.Run() + close(done) + }() + + q.Send(getPacket([]byte("foobar"))) + + closed := make(chan struct{}) + go func() { + defer GinkgoRecover() + q.Close() + close(closed) + }() + + Consistently(closed).ShouldNot(BeClosed()) + // now write the packet + Expect(written).To(Receive()) + Eventually(done).Should(BeClosed()) + Eventually(closed).Should(BeClosed()) + }) +}) diff --git a/internal/quic-go/send_stream.go b/internal/quic-go/send_stream.go new file mode 100644 index 00000000..f2e912e8 --- /dev/null +++ b/internal/quic-go/send_stream.go @@ -0,0 +1,496 @@ +package quic + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/imroc/req/v3/internal/quic-go/ackhandler" + "github.com/imroc/req/v3/internal/quic-go/flowcontrol" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type sendStreamI interface { + SendStream + handleStopSendingFrame(*wire.StopSendingFrame) + hasData() bool + popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) + closeForShutdown(error) + updateSendWindow(protocol.ByteCount) +} + +type sendStream struct { + mutex sync.Mutex + + numOutstandingFrames int64 + retransmissionQueue []*wire.StreamFrame + + ctx context.Context + ctxCancel context.CancelFunc + + streamID protocol.StreamID + sender streamSender + + writeOffset protocol.ByteCount + + cancelWriteErr error + closeForShutdownErr error + + closedForShutdown bool // set when CloseForShutdown() is called + finishedWriting bool // set once Close() is called + canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received + finSent bool // set when a STREAM_FRAME with FIN bit has been sent + completed bool // set when this stream has been reported to the streamSender as completed + + dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out + nextFrame *wire.StreamFrame + + writeChan chan struct{} + writeOnce chan struct{} + deadline time.Time + + flowController flowcontrol.StreamFlowController + + version protocol.VersionNumber +} + +var ( + _ SendStream = &sendStream{} + _ sendStreamI = &sendStream{} +) + +func newSendStream( + streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *sendStream { + s := &sendStream{ + streamID: streamID, + sender: sender, + flowController: flowController, + writeChan: make(chan struct{}, 1), + writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write + version: version, + } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + return s +} + +func (s *sendStream) StreamID() protocol.StreamID { + return s.streamID // same for receiveStream and sendStream +} + +func (s *sendStream) Write(p []byte) (int, error) { + // Concurrent use of Write is not permitted (and doesn't make any sense), + // but sometimes people do it anyway. + // Make sure that we only execute one call at any given time to avoid hard to debug failures. + s.writeOnce <- struct{}{} + defer func() { <-s.writeOnce }() + + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.finishedWriting { + return 0, fmt.Errorf("write on closed stream %d", s.streamID) + } + if s.canceledWrite { + return 0, s.cancelWriteErr + } + if s.closeForShutdownErr != nil { + return 0, s.closeForShutdownErr + } + if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { + return 0, errDeadline + } + if len(p) == 0 { + return 0, nil + } + + s.dataForWriting = p + + var ( + deadlineTimer *utils.Timer + bytesWritten int + notifiedSender bool + ) + for { + var copied bool + var deadline time.Time + // As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame), + // which can the be popped the next time we assemble a packet. + // This allows us to return Write() when all data but x bytes have been sent out. + // When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame, + // allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN). + if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 { + if s.nextFrame == nil { + f := wire.GetStreamFrame() + f.Offset = s.writeOffset + f.StreamID = s.streamID + f.DataLenPresent = true + f.Data = f.Data[:len(s.dataForWriting)] + copy(f.Data, s.dataForWriting) + s.nextFrame = f + } else { + l := len(s.nextFrame.Data) + s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)] + copy(s.nextFrame.Data[l:], s.dataForWriting) + } + s.dataForWriting = nil + bytesWritten = len(p) + copied = true + } else { + bytesWritten = len(p) - len(s.dataForWriting) + deadline = s.deadline + if !deadline.IsZero() { + if !time.Now().Before(deadline) { + s.dataForWriting = nil + return bytesWritten, errDeadline + } + if deadlineTimer == nil { + deadlineTimer = utils.NewTimer() + defer deadlineTimer.Stop() + } + deadlineTimer.Reset(deadline) + } + if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown { + break + } + } + + s.mutex.Unlock() + if !notifiedSender { + s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex + notifiedSender = true + } + if copied { + s.mutex.Lock() + break + } + if deadline.IsZero() { + <-s.writeChan + } else { + select { + case <-s.writeChan: + case <-deadlineTimer.Chan(): + deadlineTimer.SetRead() + } + } + s.mutex.Lock() + } + + if bytesWritten == len(p) { + return bytesWritten, nil + } + if s.closeForShutdownErr != nil { + return bytesWritten, s.closeForShutdownErr + } else if s.cancelWriteErr != nil { + return bytesWritten, s.cancelWriteErr + } + return bytesWritten, nil +} + +func (s *sendStream) canBufferStreamFrame() bool { + var l protocol.ByteCount + if s.nextFrame != nil { + l = s.nextFrame.DataLen() + } + return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize +} + +// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream +// maxBytes is the maximum length this frame (including frame header) will have. +func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool /* has more data to send */) { + s.mutex.Lock() + f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes) + if f != nil { + s.numOutstandingFrames++ + } + s.mutex.Unlock() + + if f == nil { + return nil, hasMoreData + } + return &ackhandler.Frame{Frame: f, OnLost: s.queueRetransmission, OnAcked: s.frameAcked}, hasMoreData +} + +func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) { + if s.canceledWrite || s.closeForShutdownErr != nil { + return nil, false + } + + if len(s.retransmissionQueue) > 0 { + f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes) + if f != nil || hasMoreRetransmissions { + if f == nil { + return nil, true + } + // We always claim that we have more data to send. + // This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future. + return f, true + } + } + + if len(s.dataForWriting) == 0 && s.nextFrame == nil { + if s.finishedWriting && !s.finSent { + s.finSent = true + return &wire.StreamFrame{ + StreamID: s.streamID, + Offset: s.writeOffset, + DataLenPresent: true, + Fin: true, + }, false + } + return nil, false + } + + sendWindow := s.flowController.SendWindowSize() + if sendWindow == 0 { + if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked { + s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{ + StreamID: s.streamID, + MaximumStreamData: offset, + }) + return nil, false + } + return nil, true + } + + f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow) + if dataLen := f.DataLen(); dataLen > 0 { + s.writeOffset += f.DataLen() + s.flowController.AddBytesSent(f.DataLen()) + } + f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent + if f.Fin { + s.finSent = true + } + return f, hasMoreData +} + +func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) (*wire.StreamFrame, bool) { + if s.nextFrame != nil { + nextFrame := s.nextFrame + s.nextFrame = nil + + maxDataLen := utils.MinByteCount(sendWindow, nextFrame.MaxDataLen(maxBytes, s.version)) + if nextFrame.DataLen() > maxDataLen { + s.nextFrame = wire.GetStreamFrame() + s.nextFrame.StreamID = s.streamID + s.nextFrame.Offset = s.writeOffset + maxDataLen + s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen] + s.nextFrame.DataLenPresent = true + copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:]) + nextFrame.Data = nextFrame.Data[:maxDataLen] + } else { + s.signalWrite() + } + return nextFrame, s.nextFrame != nil || s.dataForWriting != nil + } + + f := wire.GetStreamFrame() + f.Fin = false + f.StreamID = s.streamID + f.Offset = s.writeOffset + f.DataLenPresent = true + f.Data = f.Data[:0] + + hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow) + if len(f.Data) == 0 && !f.Fin { + f.PutBack() + return nil, hasMoreData + } + return f, hasMoreData +} + +func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount) bool { + maxDataLen := f.MaxDataLen(maxBytes, s.version) + if maxDataLen == 0 { // a STREAM frame must have at least one byte of data + return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting + } + s.getDataForWriting(f, utils.MinByteCount(maxDataLen, sendWindow)) + + return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting +} + +func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more retransmissions */) { + f := s.retransmissionQueue[0] + newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, s.version) + if needsSplit { + return newFrame, true + } + s.retransmissionQueue = s.retransmissionQueue[1:] + return f, len(s.retransmissionQueue) > 0 +} + +func (s *sendStream) hasData() bool { + s.mutex.Lock() + hasData := len(s.dataForWriting) > 0 + s.mutex.Unlock() + return hasData +} + +func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) { + if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes { + f.Data = f.Data[:len(s.dataForWriting)] + copy(f.Data, s.dataForWriting) + s.dataForWriting = nil + s.signalWrite() + return + } + f.Data = f.Data[:maxBytes] + copy(f.Data, s.dataForWriting) + s.dataForWriting = s.dataForWriting[maxBytes:] + if s.canBufferStreamFrame() { + s.signalWrite() + } +} + +func (s *sendStream) frameAcked(f wire.Frame) { + f.(*wire.StreamFrame).PutBack() + + s.mutex.Lock() + if s.canceledWrite { + s.mutex.Unlock() + return + } + s.numOutstandingFrames-- + if s.numOutstandingFrames < 0 { + panic("numOutStandingFrames negative") + } + newlyCompleted := s.isNewlyCompleted() + s.mutex.Unlock() + + if newlyCompleted { + s.sender.onStreamCompleted(s.streamID) + } +} + +func (s *sendStream) isNewlyCompleted() bool { + completed := (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 + if completed && !s.completed { + s.completed = true + return true + } + return false +} + +func (s *sendStream) queueRetransmission(f wire.Frame) { + sf := f.(*wire.StreamFrame) + sf.DataLenPresent = true + s.mutex.Lock() + if s.canceledWrite { + s.mutex.Unlock() + return + } + s.retransmissionQueue = append(s.retransmissionQueue, sf) + s.numOutstandingFrames-- + if s.numOutstandingFrames < 0 { + panic("numOutStandingFrames negative") + } + s.mutex.Unlock() + + s.sender.onHasStreamData(s.streamID) +} + +func (s *sendStream) Close() error { + s.mutex.Lock() + if s.closedForShutdown { + s.mutex.Unlock() + return nil + } + if s.canceledWrite { + s.mutex.Unlock() + return fmt.Errorf("close called for canceled stream %d", s.streamID) + } + s.ctxCancel() + s.finishedWriting = true + s.mutex.Unlock() + + s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex + return nil +} + +func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { + s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) +} + +// must be called after locking the mutex +func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) { + s.mutex.Lock() + if s.canceledWrite { + s.mutex.Unlock() + return + } + s.ctxCancel() + s.canceledWrite = true + s.cancelWriteErr = writeErr + s.numOutstandingFrames = 0 + s.retransmissionQueue = nil + newlyCompleted := s.isNewlyCompleted() + s.mutex.Unlock() + + s.signalWrite() + s.sender.queueControlFrame(&wire.ResetStreamFrame{ + StreamID: s.streamID, + FinalSize: s.writeOffset, + ErrorCode: errorCode, + }) + if newlyCompleted { + s.sender.onStreamCompleted(s.streamID) + } +} + +func (s *sendStream) updateSendWindow(limit protocol.ByteCount) { + s.mutex.Lock() + hasStreamData := s.dataForWriting != nil || s.nextFrame != nil + s.mutex.Unlock() + + s.flowController.UpdateSendWindow(limit) + if hasStreamData { + s.sender.onHasStreamData(s.streamID) + } +} + +func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { + s.cancelWriteImpl(frame.ErrorCode, &StreamError{ + StreamID: s.streamID, + ErrorCode: frame.ErrorCode, + }) +} + +func (s *sendStream) Context() context.Context { + return s.ctx +} + +func (s *sendStream) SetWriteDeadline(t time.Time) error { + s.mutex.Lock() + s.deadline = t + s.mutex.Unlock() + s.signalWrite() + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Write unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *sendStream) closeForShutdown(err error) { + s.mutex.Lock() + s.ctxCancel() + s.closedForShutdown = true + s.closeForShutdownErr = err + s.mutex.Unlock() + s.signalWrite() +} + +// signalWrite performs a non-blocking send on the writeChan +func (s *sendStream) signalWrite() { + select { + case s.writeChan <- struct{}{}: + default: + } +} diff --git a/internal/quic-go/send_stream_test.go b/internal/quic-go/send_stream_test.go new file mode 100644 index 00000000..066cdb8c --- /dev/null +++ b/internal/quic-go/send_stream_test.go @@ -0,0 +1,1159 @@ +package quic + +import ( + "bytes" + "errors" + "io" + mrand "math/rand" + "runtime" + "time" + + "github.com/golang/mock/gomock" + "github.com/imroc/req/v3/internal/quic-go/ackhandler" + "github.com/imroc/req/v3/internal/quic-go/mocks" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" +) + +var _ = Describe("Send Stream", func() { + const streamID protocol.StreamID = 1337 + + var ( + str *sendStream + strWithTimeout io.Writer // str wrapped with gbytes.TimeoutWriter + mockFC *mocks.MockStreamFlowController + mockSender *MockStreamSender + ) + + BeforeEach(func() { + mockSender = NewMockStreamSender(mockCtrl) + mockFC = mocks.NewMockStreamFlowController(mockCtrl) + str = newSendStream(streamID, mockSender, mockFC, protocol.VersionWhatever) + + timeout := scaleDuration(250 * time.Millisecond) + strWithTimeout = gbytes.TimeoutWriter(str, timeout) + }) + + expectedFrameHeaderLen := func(offset protocol.ByteCount) protocol.ByteCount { + return (&wire.StreamFrame{ + StreamID: streamID, + Offset: offset, + DataLenPresent: true, + }).Length(protocol.VersionWhatever) + } + + waitForWrite := func() { + EventuallyWithOffset(0, func() bool { + str.mutex.Lock() + hasData := str.dataForWriting != nil || str.nextFrame != nil + str.mutex.Unlock() + return hasData + }).Should(BeTrue()) + } + + getDataAtOffset := func(offset, length protocol.ByteCount) []byte { + b := make([]byte, length) + for i := protocol.ByteCount(0); i < length; i++ { + b[i] = uint8(offset + i) + } + return b + } + + getData := func(length protocol.ByteCount) []byte { + return getDataAtOffset(0, length) + } + + It("gets stream id", func() { + Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) + }) + + Context("writing", func() { + It("writes and gets all data at once", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) + n, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + }() + waitForWrite() + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + frame, _ := str.popStreamFrame(protocol.MaxByteCount) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal([]byte("foobar"))) + Expect(f.Fin).To(BeFalse()) + Expect(f.Offset).To(BeZero()) + Expect(f.DataLenPresent).To(BeTrue()) + Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) + Expect(str.dataForWriting).To(BeNil()) + Eventually(done).Should(BeClosed()) + }) + + It("writes and gets data in two turns", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + mockSender.EXPECT().onHasStreamData(streamID) + n, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + close(done) + }() + waitForWrite() + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) + frame, _ := str.popStreamFrame(expectedFrameHeaderLen(0) + 3) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Offset).To(BeZero()) + Expect(f.Fin).To(BeFalse()) + Expect(f.Data).To(Equal([]byte("foo"))) + Expect(f.DataLenPresent).To(BeTrue()) + frame, _ = str.popStreamFrame(protocol.MaxByteCount) + f = frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal([]byte("bar"))) + Expect(f.Fin).To(BeFalse()) + Expect(f.Offset).To(Equal(protocol.ByteCount(3))) + Expect(f.DataLenPresent).To(BeTrue()) + Expect(str.popStreamFrame(1000)).To(BeNil()) + Eventually(done).Should(BeClosed()) + }) + + It("bundles small writes", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + mockSender.EXPECT().onHasStreamData(streamID).Times(2) + n, err := strWithTimeout.Write([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + n, err = strWithTimeout.Write([]byte("bar")) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + close(done) + }() + Eventually(done).Should(BeClosed()) // both Write calls returned without any data having been dequeued yet + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + frame, _ := str.popStreamFrame(protocol.MaxByteCount) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Offset).To(BeZero()) + Expect(f.Fin).To(BeFalse()) + Expect(f.Data).To(Equal([]byte("foobar"))) + }) + + It("writes and gets data in multiple turns, for large writes", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(5) + var totalBytesSent protocol.ByteCount + mockFC.EXPECT().AddBytesSent(gomock.Any()).Do(func(l protocol.ByteCount) { totalBytesSent += l }).Times(5) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + mockSender.EXPECT().onHasStreamData(streamID) + n, err := strWithTimeout.Write(getData(5000)) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(5000)) + close(done) + }() + waitForWrite() + for i := 0; i < 5; i++ { + frame, _ := str.popStreamFrame(1100) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Offset).To(BeNumerically("~", 1100*i, 10*i)) + Expect(f.Fin).To(BeFalse()) + Expect(f.Data).To(Equal(getDataAtOffset(f.Offset, f.DataLen()))) + Expect(f.DataLenPresent).To(BeTrue()) + } + Expect(totalBytesSent).To(Equal(protocol.ByteCount(5000))) + Eventually(done).Should(BeClosed()) + }) + + It("unblocks Write as soon as a STREAM frame can be buffered", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) + _, err := strWithTimeout.Write(getData(protocol.MaxPacketBufferSize + 3)) + Expect(err).ToNot(HaveOccurred()) + }() + waitForWrite() + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) + frame, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0) + 2) + Expect(hasMoreData).To(BeTrue()) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.DataLen()).To(Equal(protocol.ByteCount(2))) + Consistently(done).ShouldNot(BeClosed()) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) + frame, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(1) + 1) + Expect(hasMoreData).To(BeTrue()) + f = frame.Frame.(*wire.StreamFrame) + Expect(f.DataLen()).To(Equal(protocol.ByteCount(1))) + Eventually(done).Should(BeClosed()) + }) + + It("only unblocks Write once a previously buffered STREAM frame has been fully dequeued", func() { + mockSender.EXPECT().onHasStreamData(streamID) + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) + _, err := str.Write(getData(protocol.MaxPacketBufferSize)) + Expect(err).ToNot(HaveOccurred()) + }() + waitForWrite() + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) + frame, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0) + 2) + Expect(hasMoreData).To(BeTrue()) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal([]byte("fo"))) + Consistently(done).ShouldNot(BeClosed()) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(4)) + frame, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(2) + 4) + Expect(hasMoreData).To(BeTrue()) + f = frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal([]byte("obar"))) + Eventually(done).Should(BeClosed()) + }) + + It("popStreamFrame returns nil if no data is available", func() { + frame, hasMoreData := str.popStreamFrame(1000) + Expect(frame).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + }) + + It("says if it has more data for writing", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) + n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(100)) + }() + waitForWrite() + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) + frame, hasMoreData := str.popStreamFrame(50) + Expect(frame).ToNot(BeNil()) + Expect(frame.Frame.(*wire.StreamFrame).Fin).To(BeFalse()) + Expect(hasMoreData).To(BeTrue()) + frame, hasMoreData = str.popStreamFrame(protocol.MaxByteCount) + Expect(frame).ToNot(BeNil()) + Expect(frame.Frame.(*wire.StreamFrame).Fin).To(BeFalse()) + Expect(hasMoreData).To(BeFalse()) + frame, _ = str.popStreamFrame(protocol.MaxByteCount) + Expect(frame).To(BeNil()) + Eventually(done).Should(BeClosed()) + }) + + It("copies the slice while writing", func() { + frameHeaderSize := protocol.ByteCount(4) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) + s := []byte("foo") + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) + n, err := strWithTimeout.Write(s) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + }() + waitForWrite() + frame, _ := str.popStreamFrame(frameHeaderSize + 1) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal([]byte("f"))) + frame, _ = str.popStreamFrame(100) + Expect(frame).ToNot(BeNil()) + f = frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal([]byte("oo"))) + s[1] = 'e' + Expect(f.Data).To(Equal([]byte("oo"))) + Eventually(done).Should(BeClosed()) + }) + + It("returns when given a nil input", func() { + n, err := strWithTimeout.Write(nil) + Expect(n).To(BeZero()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("returns when given an empty slice", func() { + n, err := strWithTimeout.Write([]byte("")) + Expect(n).To(BeZero()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("cancels the context when Close is called", func() { + mockSender.EXPECT().onHasStreamData(streamID) + Expect(str.Context().Done()).ToNot(BeClosed()) + Expect(str.Close()).To(Succeed()) + Expect(str.Context().Done()).To(BeClosed()) + }) + + Context("flow control blocking", func() { + It("queues a BLOCKED frame if the stream is flow control blocked", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(0)) + mockFC.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(12)) + mockSender.EXPECT().queueControlFrame(&wire.StreamDataBlockedFrame{ + StreamID: streamID, + MaximumStreamData: 12, + }) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + }() + waitForWrite() + f, hasMoreData := str.popStreamFrame(1000) + Expect(f).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + // make the Write go routine return + str.closeForShutdown(nil) + Eventually(done).Should(BeClosed()) + }) + + It("says that it doesn't have any more data, when it is flow control blocked", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + }() + waitForWrite() + + // first pop a STREAM frame of the maximum size allowed by flow control + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(3)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) + f, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0) + 3) + Expect(f).ToNot(BeNil()) + Expect(hasMoreData).To(BeTrue()) + + // try to pop again, this time noticing that we're blocked + mockFC.EXPECT().SendWindowSize() + // don't use offset 3 here, to make sure the BLOCKED frame contains the number returned by the flow controller + mockFC.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(10)) + mockSender.EXPECT().queueControlFrame(&wire.StreamDataBlockedFrame{ + StreamID: streamID, + MaximumStreamData: 10, + }) + f, hasMoreData = str.popStreamFrame(1000) + Expect(f).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + // make the Write go routine return + str.closeForShutdown(nil) + Eventually(done).Should(BeClosed()) + }) + }) + + Context("deadlines", func() { + It("returns an error when Write is called after the deadline", func() { + str.SetWriteDeadline(time.Now().Add(-time.Second)) + n, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + + It("unblocks after the deadline", func() { + mockSender.EXPECT().onHasStreamData(streamID) + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetWriteDeadline(deadline) + n, err := strWithTimeout.Write(getData(5000)) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) + }) + + It("unblocks when the deadline is changed to the past", func() { + mockSender.EXPECT().onHasStreamData(streamID) + str.SetWriteDeadline(time.Now().Add(time.Hour)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write(getData(5000)) + Expect(err).To(MatchError(errDeadline)) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + str.SetWriteDeadline(time.Now().Add(-time.Hour)) + Eventually(done).Should(BeClosed()) + }) + + It("returns the number of bytes written, when the deadline expires", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() + mockFC.EXPECT().AddBytesSent(gomock.Any()) + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetWriteDeadline(deadline) + var n int + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(writeReturned) + mockSender.EXPECT().onHasStreamData(streamID) + var err error + n, err = strWithTimeout.Write(getData(5000)) + Expect(err).To(MatchError(errDeadline)) + Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) + Expect(frame).ToNot(BeNil()) + Expect(hasMoreData).To(BeTrue()) + Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) + Expect(n).To(BeEquivalentTo(frame.Frame.(*wire.StreamFrame).DataLen())) + }) + + It("doesn't pop any data after the deadline expired", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() + mockFC.EXPECT().AddBytesSent(gomock.Any()) + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetWriteDeadline(deadline) + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(writeReturned) + mockSender.EXPECT().onHasStreamData(streamID) + _, err := strWithTimeout.Write(getData(5000)) + Expect(err).To(MatchError(errDeadline)) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) + Expect(frame).ToNot(BeNil()) + Expect(hasMoreData).To(BeTrue()) + Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) + frame, hasMoreData = str.popStreamFrame(50) + Expect(frame).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + }) + + It("doesn't unblock if the deadline is changed before the first one expires", func() { + mockSender.EXPECT().onHasStreamData(streamID) + deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) + str.SetWriteDeadline(deadline1) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(20 * time.Millisecond)) + str.SetWriteDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline1)) + close(done) + }() + runtime.Gosched() + n, err := strWithTimeout.Write(getData(5000)) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) + Eventually(done).Should(BeClosed()) + }) + + It("unblocks earlier, when a new deadline is set", func() { + mockSender.EXPECT().onHasStreamData(streamID) + deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(10 * time.Millisecond)) + str.SetWriteDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline2)) + close(done) + }() + str.SetWriteDeadline(deadline1) + runtime.Gosched() + _, err := strWithTimeout.Write(getData(5000)) + Expect(err).To(MatchError(errDeadline)) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't unblock if the deadline is removed", func() { + mockSender.EXPECT().onHasStreamData(streamID) + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetWriteDeadline(deadline) + deadlineUnset := make(chan struct{}) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(20 * time.Millisecond)) + str.SetWriteDeadline(time.Time{}) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline)) + close(deadlineUnset) + }() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write(getData(5000)) + Expect(err).To(MatchError("test done")) + close(done) + }() + runtime.Gosched() + Eventually(deadlineUnset).Should(BeClosed()) + Consistently(done, scaleDuration(100*time.Millisecond)).ShouldNot(BeClosed()) + // make the go routine return + str.closeForShutdown(errors.New("test done")) + Eventually(done).Should(BeClosed()) + }) + }) + + Context("closing", func() { + It("doesn't allow writes after it has been closed", func() { + mockSender.EXPECT().onHasStreamData(streamID) + str.Close() + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError("write on closed stream 1337")) + }) + + It("allows FIN", func() { + mockSender.EXPECT().onHasStreamData(streamID) + str.Close() + frame, hasMoreData := str.popStreamFrame(1000) + Expect(frame).ToNot(BeNil()) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(BeEmpty()) + Expect(f.Fin).To(BeTrue()) + Expect(f.DataLenPresent).To(BeTrue()) + Expect(hasMoreData).To(BeFalse()) + }) + + It("doesn't send a FIN when there's still data", func() { + const frameHeaderLen protocol.ByteCount = 4 + mockSender.EXPECT().onHasStreamData(streamID).Times(2) + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) + frame, _ := str.popStreamFrame(3 + frameHeaderLen) + Expect(frame).ToNot(BeNil()) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal([]byte("foo"))) + Expect(f.Fin).To(BeFalse()) + frame, _ = str.popStreamFrame(protocol.MaxByteCount) + f = frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal([]byte("bar"))) + Expect(f.Fin).To(BeTrue()) + }) + + It("doesn't send a FIN when there's still data, for long writes", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) + _, err := strWithTimeout.Write(getData(5000)) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onHasStreamData(streamID) + Expect(str.Close()).To(Succeed()) + }() + waitForWrite() + for i := 1; i <= 5; i++ { + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + if i == 5 { + Eventually(done).Should(BeClosed()) + } + frame, _ := str.popStreamFrame(1100) + Expect(frame).ToNot(BeNil()) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal(getDataAtOffset(f.Offset, f.DataLen()))) + Expect(f.Fin).To(Equal(i == 5)) // the last frame should have the FIN bit set + } + }) + + It("doesn't allow FIN after it is closed for shutdown", func() { + str.closeForShutdown(errors.New("test")) + f, hasMoreData := str.popStreamFrame(1000) + Expect(f).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + + Expect(str.Close()).To(Succeed()) + f, hasMoreData = str.popStreamFrame(1000) + Expect(f).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + }) + + It("doesn't allow FIN twice", func() { + mockSender.EXPECT().onHasStreamData(streamID) + str.Close() + frame, _ := str.popStreamFrame(1000) + Expect(frame).ToNot(BeNil()) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(BeEmpty()) + Expect(f.Fin).To(BeTrue()) + frame, hasMoreData := str.popStreamFrame(1000) + Expect(frame).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + }) + }) + + Context("closing for shutdown", func() { + testErr := errors.New("test") + + It("returns errors when the stream is cancelled", func() { + str.closeForShutdown(testErr) + n, err := strWithTimeout.Write([]byte("foo")) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(testErr)) + }) + + It("doesn't get data for writing if an error occurred", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write(getData(5000)) + Expect(err).To(MatchError(testErr)) + close(done) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) // get a STREAM frame containing some data, but not all + Expect(frame).ToNot(BeNil()) + Expect(hasMoreData).To(BeTrue()) + str.closeForShutdown(testErr) + frame, hasMoreData = str.popStreamFrame(1000) + Expect(frame).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + Eventually(done).Should(BeClosed()) + }) + + It("cancels the context", func() { + Expect(str.Context().Done()).ToNot(BeClosed()) + str.closeForShutdown(testErr) + Expect(str.Context().Done()).To(BeClosed()) + }) + }) + }) + + Context("handling MAX_STREAM_DATA frames", func() { + It("informs the flow controller", func() { + mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(0x1337)) + str.updateSendWindow(0x1337) + }) + + It("says when it has data for sending", func() { + mockFC.EXPECT().UpdateSendWindow(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + mockSender.EXPECT().onHasStreamData(streamID) + str.updateSendWindow(42) + // make sure the Write go routine returns + str.closeForShutdown(nil) + Eventually(done).Should(BeClosed()) + }) + }) + + Context("stream cancellations", func() { + Context("canceling writing", func() { + It("queues a RESET_STREAM frame", func() { + gomock.InOrder( + mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{ + StreamID: streamID, + FinalSize: 1234, + ErrorCode: 9876, + }), + mockSender.EXPECT().onStreamCompleted(streamID), + ) + str.writeOffset = 1234 + str.CancelWrite(9876) + }) + + // This test is inherently racy, as it tests a concurrent call to Write() and CancelRead(). + // A single successful run of this test therefore doesn't mean a lot, + // for reliable results it has to be run many times. + It("returns a nil error when the whole slice has been sent out", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).MaxTimes(1) + mockSender.EXPECT().onHasStreamData(streamID).MaxTimes(1) + mockSender.EXPECT().onStreamCompleted(streamID).MaxTimes(1) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).MaxTimes(1) + mockFC.EXPECT().AddBytesSent(gomock.Any()).MaxTimes(1) + errChan := make(chan error) + go func() { + defer GinkgoRecover() + n, err := strWithTimeout.Write(getData(100)) + if n == 0 { + errChan <- nil + return + } + errChan <- err + }() + + runtime.Gosched() + go str.popStreamFrame(protocol.MaxByteCount) + go str.CancelWrite(1234) + Eventually(errChan).Should(Receive(Not(HaveOccurred()))) + }) + + It("unblocks Write", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + writeReturned := make(chan struct{}) + var n int + go func() { + defer GinkgoRecover() + var err error + n, err = strWithTimeout.Write(getData(5000)) + Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) + close(writeReturned) + }() + waitForWrite() + frame, _ := str.popStreamFrame(50) + Expect(frame).ToNot(BeNil()) + mockSender.EXPECT().onStreamCompleted(streamID) + str.CancelWrite(1234) + Eventually(writeReturned).Should(BeClosed()) + Expect(n).To(BeEquivalentTo(frame.Frame.(*wire.StreamFrame).DataLen())) + }) + + It("doesn't pop STREAM frames after being canceled", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + strWithTimeout.Write(getData(100)) + close(writeReturned) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) + Expect(hasMoreData).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + mockSender.EXPECT().onStreamCompleted(streamID) + str.CancelWrite(1234) + frame, hasMoreData = str.popStreamFrame(10) + Expect(frame).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + Eventually(writeReturned).Should(BeClosed()) + }) + + It("doesn't pop STREAM frames after being canceled, for large writes", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write(getData(5000)) + Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) + close(writeReturned) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) + Expect(hasMoreData).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + mockSender.EXPECT().onStreamCompleted(streamID) + str.CancelWrite(1234) + frame, hasMoreData = str.popStreamFrame(10) + Expect(hasMoreData).To(BeFalse()) + Expect(frame).To(BeNil()) + Eventually(writeReturned).Should(BeClosed()) + }) + + It("ignores acknowledgements for STREAM frames after it was cancelled", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + strWithTimeout.Write(getData(100)) + close(writeReturned) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) + Expect(hasMoreData).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + mockSender.EXPECT().onStreamCompleted(streamID) + str.CancelWrite(1234) + frame.OnAcked(frame.Frame) + }) + + It("cancels the context", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) + Expect(str.Context().Done()).ToNot(BeClosed()) + str.CancelWrite(1234) + Expect(str.Context().Done()).To(BeClosed()) + }) + + It("doesn't allow further calls to Write", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) + str.CancelWrite(1234) + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) + }) + + It("only cancels once", func() { + mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 1234}) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) + str.CancelWrite(1234) + str.CancelWrite(4321) + }) + + It("queues a RESET_STREAM frame, even if the stream was already closed", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f).To(BeAssignableToTypeOf(&wire.ResetStreamFrame{})) + }) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) + Expect(str.Close()).To(Succeed()) + // don't EXPECT any calls to queueControlFrame + str.CancelWrite(123) + }) + }) + + Context("receiving STOP_SENDING frames", func() { + It("queues a RESET_STREAM frames, and copies the error code from the STOP_SENDING frame", func() { + mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{ + StreamID: streamID, + ErrorCode: 101, + }) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) + + str.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 101, + }) + }) + + It("unblocks Write", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write(getData(5000)) + Expect(err).To(MatchError(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + })) + close(done) + }() + waitForWrite() + str.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 123, + }) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't allow further calls to Write", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) + str.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 123, + }) + _, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + })) + }) + }) + }) + + Context("retransmissions", func() { + It("queues and retrieves frames", func() { + str.numOutstandingFrames = 1 + f := &wire.StreamFrame{ + Data: []byte("foobar"), + Offset: 0x42, + DataLenPresent: false, + } + mockSender.EXPECT().onHasStreamData(streamID) + str.queueRetransmission(f) + frame, _ := str.popStreamFrame(protocol.MaxByteCount) + Expect(frame).ToNot(BeNil()) + f = frame.Frame.(*wire.StreamFrame) + Expect(f.Offset).To(Equal(protocol.ByteCount(0x42))) + Expect(f.Data).To(Equal([]byte("foobar"))) + Expect(f.DataLenPresent).To(BeTrue()) + }) + + It("splits a retransmission", func() { + str.numOutstandingFrames = 1 + sf := &wire.StreamFrame{ + Data: []byte("foobar"), + Offset: 0x42, + DataLenPresent: false, + } + mockSender.EXPECT().onHasStreamData(streamID) + str.queueRetransmission(sf) + frame, hasMoreData := str.popStreamFrame(sf.Length(str.version) - 3) + Expect(frame).ToNot(BeNil()) + f := frame.Frame.(*wire.StreamFrame) + Expect(hasMoreData).To(BeTrue()) + Expect(f.Offset).To(Equal(protocol.ByteCount(0x42))) + Expect(f.Data).To(Equal([]byte("foo"))) + Expect(f.DataLenPresent).To(BeTrue()) + frame, _ = str.popStreamFrame(protocol.MaxByteCount) + Expect(frame).ToNot(BeNil()) + f = frame.Frame.(*wire.StreamFrame) + Expect(f.Offset).To(Equal(protocol.ByteCount(0x45))) + Expect(f.Data).To(Equal([]byte("bar"))) + Expect(f.DataLenPresent).To(BeTrue()) + }) + + It("returns nil if the size is too small", func() { + str.numOutstandingFrames = 1 + f := &wire.StreamFrame{ + Data: []byte("foobar"), + Offset: 0x42, + DataLenPresent: false, + } + mockSender.EXPECT().onHasStreamData(streamID) + str.queueRetransmission(f) + frame, hasMoreData := str.popStreamFrame(2) + Expect(hasMoreData).To(BeTrue()) + Expect(frame).To(BeNil()) + }) + + It("queues lost STREAM frames", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + frame, _ := str.popStreamFrame(protocol.MaxByteCount) + Eventually(done).Should(BeClosed()) + Expect(frame).ToNot(BeNil()) + Expect(frame.Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + + // now lose the frame + mockSender.EXPECT().onHasStreamData(streamID) + frame.OnLost(frame.Frame) + newFrame, _ := str.popStreamFrame(protocol.MaxByteCount) + Expect(newFrame).ToNot(BeNil()) + Expect(newFrame.Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + }) + + It("doesn't queue retransmissions for a stream that was canceled", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + f, _ := str.popStreamFrame(100) + Expect(f).ToNot(BeNil()) + gomock.InOrder( + mockSender.EXPECT().queueControlFrame(gomock.Any()), + mockSender.EXPECT().onStreamCompleted(streamID), + ) + str.CancelWrite(9876) + // don't EXPECT any calls to onHasStreamData + f.OnLost(f.Frame) + Expect(str.retransmissionQueue).To(BeEmpty()) + }) + }) + + Context("determining when a stream is completed", func() { + BeforeEach(func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() + mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() + }) + + It("says when a stream is completed", func() { + mockSender.EXPECT().onHasStreamData(streamID) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write(make([]byte, 100)) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + + // get a bunch of small frames (max. 20 bytes) + var frames []ackhandler.Frame + for { + frame, hasMoreData := str.popStreamFrame(20) + if frame == nil { + continue + } + frames = append(frames, *frame) + if !hasMoreData { + break + } + } + Eventually(done).Should(BeClosed()) + + // Acknowledge all frames. + // We don't expect the stream to be completed, since we still need to send the FIN. + for _, f := range frames { + f.OnAcked(f.Frame) + } + + // Now close the stream and acknowledge the FIN. + mockSender.EXPECT().onHasStreamData(streamID) + Expect(str.Close()).To(Succeed()) + frame, _ := str.popStreamFrame(protocol.MaxByteCount) + Expect(frame).ToNot(BeNil()) + mockSender.EXPECT().onStreamCompleted(streamID) + frame.OnAcked(frame.Frame) + }) + + It("says when a stream is completed, if Close() is called before popping the frame", func() { + mockSender.EXPECT().onHasStreamData(streamID).Times(2) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write(make([]byte, 100)) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + Eventually(done).Should(BeClosed()) + Expect(str.Close()).To(Succeed()) + + frame, hasMoreData := str.popStreamFrame(protocol.MaxByteCount) + Expect(hasMoreData).To(BeFalse()) + Expect(frame).ToNot(BeNil()) + Expect(frame.Frame.(*wire.StreamFrame).Fin).To(BeTrue()) + + mockSender.EXPECT().onStreamCompleted(streamID) + frame.OnAcked(frame.Frame) + }) + + It("doesn't say it's completed when there are frames waiting to be retransmitted", func() { + mockSender.EXPECT().onHasStreamData(streamID) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write(getData(100)) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onHasStreamData(streamID) + Expect(str.Close()).To(Succeed()) + close(done) + }() + waitForWrite() + + // get a bunch of small frames (max. 20 bytes) + var frames []ackhandler.Frame + for { + frame, _ := str.popStreamFrame(20) + if frame == nil { + continue + } + frames = append(frames, *frame) + if frame.Frame.(*wire.StreamFrame).Fin { + break + } + } + Eventually(done).Should(BeClosed()) + + // lose the first frame, acknowledge all others + for _, f := range frames[1:] { + f.OnAcked(f.Frame) + } + mockSender.EXPECT().onHasStreamData(streamID) + frames[0].OnLost(frames[0].Frame) + + // get the retransmission and acknowledge it + ret, _ := str.popStreamFrame(protocol.MaxByteCount) + Expect(ret).ToNot(BeNil()) + mockSender.EXPECT().onStreamCompleted(streamID) + ret.OnAcked(ret.Frame) + }) + + // This test is kind of an integration test. + // It writes 4 MB of data, and pops STREAM frames that sometimes are and sometimes aren't limited by flow control. + // Half of these STREAM frames are then received and their content saved, while the other half is reported lost + // and has to be retransmitted. + It("retransmits data until everything has been acknowledged", func() { + const dataLen = 1 << 22 // 4 MB + mockSender.EXPECT().onHasStreamData(streamID).AnyTimes() + mockFC.EXPECT().SendWindowSize().DoAndReturn(func() protocol.ByteCount { + return protocol.ByteCount(mrand.Intn(500)) + 50 + }).AnyTimes() + mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() + + data := make([]byte, dataLen) + _, err := mrand.Read(data) + Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + _, err := str.Write(data) + Expect(err).ToNot(HaveOccurred()) + str.Close() + }() + + var completed bool + mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { completed = true }) + + received := make([]byte, dataLen) + for { + if completed { + break + } + f, _ := str.popStreamFrame(protocol.ByteCount(mrand.Intn(300) + 100)) + if f == nil { + continue + } + sf := f.Frame.(*wire.StreamFrame) + // 50%: acknowledge the frame and save the data + // 50%: lose the frame + if mrand.Intn(100) < 50 { + copy(received[sf.Offset:sf.Offset+sf.DataLen()], sf.Data) + f.OnAcked(f.Frame) + } else { + f.OnLost(f.Frame) + } + } + Expect(received).To(Equal(data)) + }) + }) +}) diff --git a/internal/quic-go/server.go b/internal/quic-go/server.go new file mode 100644 index 00000000..07173357 --- /dev/null +++ b/internal/quic-go/server.go @@ -0,0 +1,670 @@ +package quic + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/imroc/req/v3/internal/quic-go/handshake" + "github.com/imroc/req/v3/internal/quic-go/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close. +var ErrServerClosed = errors.New("quic: Server closed") + +// packetHandler handles packets +type packetHandler interface { + handlePacket(*receivedPacket) + shutdown() + destroy(error) + getPerspective() protocol.Perspective +} + +type unknownPacketHandler interface { + handlePacket(*receivedPacket) + setCloseError(error) +} + +type packetHandlerManager interface { + AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool + Destroy() error + connRunner + SetServer(unknownPacketHandler) + CloseServer() +} + +type quicConn interface { + EarlyConnection + earlyConnReady() <-chan struct{} + handlePacket(*receivedPacket) + GetVersion() protocol.VersionNumber + getPerspective() protocol.Perspective + run() error + destroy(error) + shutdown() +} + +// A Listener of QUIC +type baseServer struct { + mutex sync.Mutex + + acceptEarlyConns bool + + tlsConf *tls.Config + config *Config + + conn rawConn + // If the server is started with ListenAddr, we create a packet conn. + // If it is started with Listen, we take a packet conn as a parameter. + createdPacketConn bool + + tokenGenerator *handshake.TokenGenerator + + connHandler packetHandlerManager + + receivedPackets chan *receivedPacket + + // set as a member, so they can be set in the tests + newConn func( + sendConn, + connRunner, + protocol.ConnectionID, /* original dest connection ID */ + *protocol.ConnectionID, /* retry src connection ID */ + protocol.ConnectionID, /* client dest connection ID */ + protocol.ConnectionID, /* destination connection ID */ + protocol.ConnectionID, /* source connection ID */ + protocol.StatelessResetToken, + *Config, + *tls.Config, + *handshake.TokenGenerator, + bool, /* enable 0-RTT */ + logging.ConnectionTracer, + uint64, + utils.Logger, + protocol.VersionNumber, + ) quicConn + + serverError error + errorChan chan struct{} + closed bool + running chan struct{} // closed as soon as run() returns + + connQueue chan quicConn + connQueueLen int32 // to be used as an atomic + + logger utils.Logger +} + +var ( + _ Listener = &baseServer{} + _ unknownPacketHandler = &baseServer{} +) + +type earlyServer struct{ *baseServer } + +var _ EarlyListener = &earlyServer{} + +func (s *earlyServer) Accept(ctx context.Context) (EarlyConnection, error) { + return s.baseServer.accept(ctx) +} + +// ListenAddr creates a QUIC server listening on a given address. +// The tls.Config must not be nil and must contain a certificate configuration. +// The quic.Config may be nil, in that case the default values will be used. +func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { + return listenAddr(addr, tlsConf, config, false) +} + +// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. +func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) { + s, err := listenAddr(addr, tlsConf, config, true) + if err != nil { + return nil, err + } + return &earlyServer{s}, nil +} + +func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, err + } + serv, err := listen(conn, tlsConf, config, acceptEarly) + if err != nil { + return nil, err + } + serv.createdPacketConn = true + return serv, nil +} + +// Listen listens for QUIC connections on a given net.PacketConn. If the +// PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn +// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP +// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write +// packets. A single net.PacketConn only be used for a single call to Listen. +// The PacketConn can be used for simultaneous calls to Dial. QUIC connection +// IDs are used for demultiplexing the different connections. The tls.Config +// must not be nil and must contain a certificate configuration. The +// tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. Furthermore, +// it must define an application control (using NextProtos). The quic.Config may +// be nil, in that case the default values will be used. +func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { + return listen(conn, tlsConf, config, false) +} + +// ListenEarly works like Listen, but it returns connections before the handshake completes. +func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) { + s, err := listen(conn, tlsConf, config, true) + if err != nil { + return nil, err + } + return &earlyServer{s}, nil +} + +func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(config); err != nil { + return nil, err + } + config = populateServerConfig(config) + for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return nil, fmt.Errorf("%s is not a valid QUIC version", v) + } + } + + connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) + if err != nil { + return nil, err + } + tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) + if err != nil { + return nil, err + } + c, err := wrapConn(conn) + if err != nil { + return nil, err + } + s := &baseServer{ + conn: c, + tlsConf: tlsConf, + config: config, + tokenGenerator: tokenGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), + newConn: newConnection, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, + } + go s.run() + connHandler.SetServer(s) + s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) + return s, nil +} + +func (s *baseServer) run() { + defer close(s.running) + for { + select { + case <-s.errorChan: + return + default: + } + select { + case <-s.errorChan: + return + case p := <-s.receivedPackets: + if bufferStillInUse := s.handlePacketImpl(p); !bufferStillInUse { + p.buffer.Release() + } + } + } +} + +var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool { + if token == nil { + return false + } + validity := protocol.TokenValidity + if token.IsRetryToken { + validity = protocol.RetryTokenValidity + } + if time.Now().After(token.SentTime.Add(validity)) { + return false + } + var sourceAddr string + if udpAddr, ok := clientAddr.(*net.UDPAddr); ok { + sourceAddr = udpAddr.IP.String() + } else { + sourceAddr = clientAddr.String() + } + return sourceAddr == token.RemoteAddr +} + +// Accept returns connections that already completed the handshake. +// It is only valid if acceptEarlyConns is false. +func (s *baseServer) Accept(ctx context.Context) (Connection, error) { + return s.accept(ctx) +} + +func (s *baseServer) accept(ctx context.Context) (quicConn, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case conn := <-s.connQueue: + atomic.AddInt32(&s.connQueueLen, -1) + return conn, nil + case <-s.errorChan: + return nil, s.serverError + } +} + +// Close the server +func (s *baseServer) Close() error { + s.mutex.Lock() + if s.closed { + s.mutex.Unlock() + return nil + } + if s.serverError == nil { + s.serverError = ErrServerClosed + } + // If the server was started with ListenAddr, we created the packet conn. + // We need to close it in order to make the go routine reading from that conn return. + createdPacketConn := s.createdPacketConn + s.closed = true + close(s.errorChan) + s.mutex.Unlock() + + <-s.running + s.connHandler.CloseServer() + if createdPacketConn { + return s.connHandler.Destroy() + } + return nil +} + +func (s *baseServer) setCloseError(e error) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.closed { + return + } + s.closed = true + s.serverError = e + close(s.errorChan) +} + +// Addr returns the server's network address +func (s *baseServer) Addr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *baseServer) handlePacket(p *receivedPacket) { + select { + case s.receivedPackets <- p: + default: + s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + } + } +} + +func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { + if wire.IsVersionNegotiationPacket(p.data) { + s.logger.Debugf("Dropping Version Negotiation packet.") + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + // If we're creating a new connection, the packet will be passed to the connection. + // The header will then be parsed again. + hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength) + if err != nil && err != wire.ErrUnsupportedVersion { + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Error parsing packet: %s", err) + return false + } + // Short header packets should never end up here in the first place + if !hdr.IsLongHeader { + panic(fmt.Sprintf("misrouted packet: %#v", hdr)) + } + if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { + s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + // send a Version Negotiation Packet if the client is speaking a different protocol version + if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { + if p.Size() < protocol.MinUnknownVersionPacketSize { + s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + if !s.config.DisableVersionNegotiationPackets { + go s.sendVersionNegotiationPacket(p, hdr) + } + return false + } + if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial { + // Drop long header packets. + // There's little point in sending a Stateless Reset, since the client + // might not have received the token yet. + s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + + s.logger.Debugf("<- Received Initial packet.") + + if err := s.handleInitialImpl(p, hdr); err != nil { + s.logger.Errorf("Error occurred handling initial packet: %s", err) + } + // Don't put the packet buffer back. + // handleInitialImpl deals with the buffer. + return true +} + +func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { + if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { + p.buffer.Release() + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + } + return errors.New("too short connection ID") + } + + var ( + token *Token + retrySrcConnID *protocol.ConnectionID + ) + origDestConnID := hdr.DestConnectionID + if len(hdr.Token) > 0 { + c, err := s.tokenGenerator.DecodeToken(hdr.Token) + if err == nil { + token = &Token{ + IsRetryToken: c.IsRetryToken, + RemoteAddr: c.RemoteAddr, + SentTime: c.SentTime, + } + if token.IsRetryToken { + origDestConnID = c.OriginalDestConnectionID + retrySrcConnID = &c.RetrySrcConnectionID + } + } + } + if !s.config.AcceptToken(p.remoteAddr, token) { + go func() { + defer p.buffer.Release() + if token != nil && token.IsRetryToken { + if err := s.maybeSendInvalidToken(p, hdr); err != nil { + s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) + } + return + } + if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { + s.logger.Debugf("Error sending Retry: %s", err) + } + }() + return nil + } + + if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize { + s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) + go func() { + defer p.buffer.Release() + if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil { + s.logger.Debugf("Error rejecting connection: %s", err) + } + }() + return nil + } + + connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) + if err != nil { + return err + } + s.logger.Debugf("Changing connection ID to %s.", connID) + var conn quicConn + tracingID := nextConnTracingID() + if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { + var tracer logging.ConnectionTracer + if s.config.Tracer != nil { + // Use the same connection ID that is passed to the client's GetLogWriter callback. + connID := hdr.DestConnectionID + if origDestConnID.Len() > 0 { + connID = origDestConnID + } + tracer = s.config.Tracer.TracerForConnection( + context.WithValue(context.Background(), ConnectionTracingKey, tracingID), + protocol.PerspectiveServer, + connID, + ) + } + conn = s.newConn( + newSendConn(s.conn, p.remoteAddr, p.info), + s.connHandler, + origDestConnID, + retrySrcConnID, + hdr.DestConnectionID, + hdr.SrcConnectionID, + connID, + s.connHandler.GetStatelessResetToken(connID), + s.config, + s.tlsConf, + s.tokenGenerator, + s.acceptEarlyConns, + tracer, + tracingID, + s.logger, + hdr.Version, + ) + conn.handlePacket(p) + return conn + }); !added { + return nil + } + go conn.run() + go s.handleNewConn(conn) + if conn == nil { + p.buffer.Release() + return nil + } + return nil +} + +func (s *baseServer) handleNewConn(conn quicConn) { + connCtx := conn.Context() + if s.acceptEarlyConns { + // wait until the early connection is ready (or the handshake fails) + select { + case <-conn.earlyConnReady(): + case <-connCtx.Done(): + return + } + } else { + // wait until the handshake is complete (or fails) + select { + case <-conn.HandshakeComplete().Done(): + case <-connCtx.Done(): + return + } + } + + atomic.AddInt32(&s.connQueueLen, 1) + select { + case s.connQueue <- conn: + // blocks until the connection is accepted + case <-connCtx.Done(): + atomic.AddInt32(&s.connQueueLen, -1) + // don't pass connections that were already closed to Accept() + } +} + +func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { + // Log the Initial packet now. + // If no Retry is sent, the packet will be logged by the connection. + (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) + srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) + if err != nil { + return err + } + token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID, srcConnID) + if err != nil { + return err + } + replyHdr := &wire.ExtendedHeader{} + replyHdr.IsLongHeader = true + replyHdr.Type = protocol.PacketTypeRetry + replyHdr.Version = hdr.Version + replyHdr.SrcConnectionID = srcConnID + replyHdr.DestConnectionID = hdr.SrcConnectionID + replyHdr.Token = token + if s.logger.Debug() { + s.logger.Debugf("Changing connection ID to %s.", srcConnID) + s.logger.Debugf("-> Sending Retry") + replyHdr.Log(s.logger) + } + + packetBuffer := getPacketBuffer() + defer packetBuffer.Release() + buf := bytes.NewBuffer(packetBuffer.Data) + if err := replyHdr.Write(buf, hdr.Version); err != nil { + return err + } + // append the Retry integrity tag + tag := handshake.GetRetryIntegrityTag(buf.Bytes(), hdr.DestConnectionID, hdr.Version) + buf.Write(tag[:]) + if s.config.Tracer != nil { + s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil) + } + _, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info.OOB()) + return err +} + +func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) error { + // Only send INVALID_TOKEN if we can unprotect the packet. + // This makes sure that we won't send it for packets that were corrupted. + sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) + data := p.data[:hdr.ParsedLen()+hdr.Length] + extHdr, err := unpackHeader(opener, hdr, data, hdr.Version) + if err != nil { + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) + } + // don't return the error here. Just drop the packet. + return nil + } + hdrLen := extHdr.ParsedLen() + if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { + // don't return the error here. Just drop the packet. + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) + } + return nil + } + if s.logger.Debug() { + s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr) + } + return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info) +} + +func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { + sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) + return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info) +} + +// sendError sends the error as a response to the packet received with header hdr +func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { + packetBuffer := getPacketBuffer() + defer packetBuffer.Release() + buf := bytes.NewBuffer(packetBuffer.Data) + + ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)} + + replyHdr := &wire.ExtendedHeader{} + replyHdr.IsLongHeader = true + replyHdr.Type = protocol.PacketTypeInitial + replyHdr.Version = hdr.Version + replyHdr.SrcConnectionID = hdr.DestConnectionID + replyHdr.DestConnectionID = hdr.SrcConnectionID + replyHdr.PacketNumberLen = protocol.PacketNumberLen4 + replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead()) + if err := replyHdr.Write(buf, hdr.Version); err != nil { + return err + } + payloadOffset := buf.Len() + + if err := ccf.Write(buf, hdr.Version); err != nil { + return err + } + + raw := buf.Bytes() + _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], replyHdr.PacketNumber, raw[:payloadOffset]) + raw = raw[0 : buf.Len()+sealer.Overhead()] + + pnOffset := payloadOffset - int(replyHdr.PacketNumberLen) + sealer.EncryptHeader( + raw[pnOffset+4:pnOffset+4+16], + &raw[0], + raw[pnOffset:payloadOffset], + ) + + replyHdr.Log(s.logger) + wire.LogFrame(s.logger, ccf, true) + if s.config.Tracer != nil { + s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf}) + } + _, err := s.conn.WritePacket(raw, remoteAddr, info.OOB()) + return err +} + +func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { + s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) + data := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) + if s.config.Tracer != nil { + s.config.Tracer.SentPacket( + p.remoteAddr, + &wire.Header{ + IsLongHeader: true, + DestConnectionID: hdr.SrcConnectionID, + SrcConnectionID: hdr.DestConnectionID, + }, + protocol.ByteCount(len(data)), + nil, + ) + } + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + s.logger.Debugf("Error sending Version Negotiation: %s", err) + } +} diff --git a/internal/quic-go/server_test.go b/internal/quic-go/server_test.go new file mode 100644 index 00000000..e5e9c48e --- /dev/null +++ b/internal/quic-go/server_test.go @@ -0,0 +1,1237 @@ +package quic + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/tls" + "errors" + "net" + "reflect" + "runtime/pprof" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/imroc/req/v3/internal/quic-go/handshake" + "github.com/imroc/req/v3/internal/quic-go/logging" + mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/testdata" + "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/quic-go/wire" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func areServersRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*baseServer).run") +} + +var _ = Describe("Server", func() { + var ( + conn *MockPacketConn + tlsConf *tls.Config + ) + + getPacket := func(hdr *wire.Header, p []byte) *receivedPacket { + buffer := getPacketBuffer() + buf := bytes.NewBuffer(buffer.Data) + if hdr.IsLongHeader { + hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 + } + Expect((&wire.ExtendedHeader{ + Header: *hdr, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, protocol.VersionTLS)).To(Succeed()) + n := buf.Len() + buf.Write(p) + data := buffer.Data[:buf.Len()] + sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) + _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) + data = data[:len(data)+16] + sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) + return &receivedPacket{ + remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, + data: data, + buffer: buffer, + } + } + + getInitial := func(destConnID protocol.ConnectionID) *receivedPacket { + senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: destConnID, + Version: protocol.VersionTLS, + } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + p.buffer = getPacketBuffer() + p.remoteAddr = senderAddr + return p + } + + getInitialWithRandomDestConnID := func() *receivedPacket { + destConnID := make([]byte, 10) + _, err := rand.Read(destConnID) + Expect(err).ToNot(HaveOccurred()) + + return getInitial(destConnID) + } + + parseHeader := func(data []byte) *wire.Header { + hdr, _, _, err := wire.ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + return hdr + } + + BeforeEach(func() { + conn = NewMockPacketConn(mockCtrl) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() + conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1) + tlsConf = testdata.GetTLSConfig() + tlsConf.NextProtos = []string{"proto1"} + }) + + AfterEach(func() { + Eventually(areServersRunning).Should(BeFalse()) + }) + + It("errors when no tls.Config is given", func() { + _, err := ListenAddr("localhost:0", nil, nil) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set")) + }) + + It("errors when the Config contains an invalid version", func() { + version := protocol.VersionNumber(0x1234) + _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) + Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) + }) + + It("fills in default values if options are not set in the Config", func() { + ln, err := Listen(conn, tlsConf, &Config{}) + Expect(err).ToNot(HaveOccurred()) + server := ln.(*baseServer) + Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) + Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) + Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) + Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(defaultAcceptToken))) + Expect(server.config.KeepAlivePeriod).To(Equal(0 * time.Second)) + // stop the listener + Expect(ln.Close()).To(Succeed()) + }) + + It("setups with the right values", func() { + supportedVersions := []protocol.VersionNumber{protocol.VersionTLS} + acceptToken := func(_ net.Addr, _ *Token) bool { return true } + config := Config{ + Versions: supportedVersions, + AcceptToken: acceptToken, + HandshakeIdleTimeout: 1337 * time.Hour, + MaxIdleTimeout: 42 * time.Minute, + KeepAlivePeriod: 5 * time.Second, + StatelessResetKey: []byte("foobar"), + } + ln, err := Listen(conn, tlsConf, &config) + Expect(err).ToNot(HaveOccurred()) + server := ln.(*baseServer) + Expect(server.connHandler).ToNot(BeNil()) + Expect(server.config.Versions).To(Equal(supportedVersions)) + Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) + Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) + Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(acceptToken))) + Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) + Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar"))) + // stop the listener + Expect(ln.Close()).To(Succeed()) + }) + + It("listens on a given address", func() { + addr := "127.0.0.1:13579" + ln, err := ListenAddr(addr, tlsConf, &Config{}) + Expect(err).ToNot(HaveOccurred()) + Expect(ln.Addr().String()).To(Equal(addr)) + // stop the listener + Expect(ln.Close()).To(Succeed()) + }) + + It("errors if given an invalid address", func() { + addr := "127.0.0.1" + _, err := ListenAddr(addr, tlsConf, &Config{}) + Expect(err).To(BeAssignableToTypeOf(&net.AddrError{})) + }) + + It("errors if given an invalid address", func() { + addr := "1.1.1.1:1111" + _, err := ListenAddr(addr, tlsConf, &Config{}) + Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) + }) + + Context("server accepting connections that completed the handshake", func() { + var ( + serv *baseServer + phm *MockPacketHandlerManager + tracer *mocklogging.MockTracer + ) + + BeforeEach(func() { + tracer = mocklogging.NewMockTracer(mockCtrl) + ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer}) + Expect(err).ToNot(HaveOccurred()) + serv = ln.(*baseServer) + phm = NewMockPacketHandlerManager(mockCtrl) + serv.connHandler = phm + }) + + AfterEach(func() { + phm.EXPECT().CloseServer().MaxTimes(1) + serv.Close() + }) + + Context("handling packets", func() { + It("drops Initial packets with a too short connection ID", func() { + p := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Version: serv.config.Versions[0], + }, nil) + tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + serv.handlePacket(p) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + }) + + It("drops too small Initial", func() { + p := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize-100), + ) + tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + serv.handlePacket(p) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + }) + + It("drops non-Initial packets", func() { + p := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Version: serv.config.Versions[0], + }, []byte("invalid")) + tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket) + serv.handlePacket(p) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + }) + + It("decodes the token from the Token field", func() { + raddr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 13, 37), + Port: 1337, + } + done := make(chan struct{}) + serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { + Expect(addr).To(Equal(raddr)) + Expect(token).ToNot(BeNil()) + close(done) + return false + } + token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) + Expect(err).ToNot(HaveOccurred()) + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: token, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("passes an empty token to the callback, if decoding fails", func() { + raddr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 13, 37), + Port: 1337, + } + done := make(chan struct{}) + serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { + Expect(addr).To(Equal(raddr)) + Expect(token).To(BeNil()) + close(done) + return false + } + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: []byte("foobar"), + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("creates a connection when the token is accepted", func() { + serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true } + retryToken, err := serv.tokenGenerator.NewRetryToken( + &net.UDPAddr{}, + protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + ) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Version: protocol.VersionTLS, + Token: retryToken, + } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + run := make(chan struct{}) + var token protocol.StatelessResetToken + rand.Read(token[:]) + + var newConnID protocol.ConnectionID + phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + newConnID = c + phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { + newConnID = c + return token + }) + fn() + return true + }) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}) + conn := NewMockQuicConn(mockCtrl) + serv.newConn = func( + _ sendConn, + _ connRunner, + origDestConnID protocol.ConnectionID, + retrySrcConnID *protocol.ConnectionID, + clientDestConnID protocol.ConnectionID, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + tokenP protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + enable0RTT bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + Expect(enable0RTT).To(BeFalse()) + Expect(origDestConnID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) + Expect(retrySrcConnID).To(Equal(&protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) + Expect(destConnID).To(Equal(hdr.SrcConnectionID)) + // make sure we're using a server-generated connection ID + Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) + Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) + Expect(srcConnID).To(Equal(newConnID)) + Expect(tokenP).To(Equal(token)) + conn.EXPECT().handlePacket(p) + conn.EXPECT().run().Do(func() { close(run) }) + conn.EXPECT().Context().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn + } + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + serv.handlePacket(p) + // the Handshake packet is written by the connection. + // Make sure there are no Write calls on the packet conn. + time.Sleep(50 * time.Millisecond) + close(done) + }() + // make sure we're using a server-generated connection ID + Eventually(run).Should(BeClosed()) + Eventually(done).Should(BeClosed()) + }) + + It("sends a Version Negotiation Packet for unsupported versions", func() { + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} + destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + Version: 0x42, + }, make([]byte, protocol.MinUnknownVersionPacketSize)) + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + packet.remoteAddr = raddr + tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { + Expect(replyHdr.IsLongHeader).To(BeTrue()) + Expect(replyHdr.Version).To(BeZero()) + Expect(replyHdr.SrcConnectionID).To(Equal(destConnID)) + Expect(replyHdr.DestConnectionID).To(Equal(srcConnID)) + }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) + hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(b)) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.DestConnectionID).To(Equal(srcConnID)) + Expect(hdr.SrcConnectionID).To(Equal(destConnID)) + Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42))) + return len(b), nil + }) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't send a Version Negotiation packets if sending them is disabled", func() { + serv.config.DisableVersionNegotiationPackets = true + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} + destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + Version: 0x42, + }, make([]byte, protocol.MinUnknownVersionPacketSize)) + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + packet.remoteAddr = raddr + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).Do(func() { close(done) }).Times(0) + serv.handlePacket(packet) + Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed()) + }) + + It("ignores Version Negotiation packets", func() { + data := wire.ComposeVersionNegotiation( + protocol.ConnectionID{1, 2, 3, 4}, + protocol.ConnectionID{4, 3, 2, 1}, + []protocol.VersionNumber{1, 2, 3}, + ) + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + done := make(chan struct{}) + tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { + close(done) + }) + serv.handlePacket(&receivedPacket{ + remoteAddr: raddr, + data: data, + buffer: getPacketBuffer(), + }) + Eventually(done).Should(BeClosed()) + // make sure no other packet is sent + time.Sleep(scaleDuration(20 * time.Millisecond)) + }) + + It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() { + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} + destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} + p := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + Version: 0x42, + }, make([]byte, protocol.MinUnknownVersionPacketSize-50)) + Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize)) + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + p.remoteAddr = raddr + done := make(chan struct{}) + tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { + close(done) + }) + serv.handlePacket(p) + Eventually(done).Should(BeClosed()) + // make sure no other packet is sent + time.Sleep(scaleDuration(20 * time.Millisecond)) + }) + + It("replies with a Retry packet, if a Token is required", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + packet.remoteAddr = raddr + tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) + Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(replyHdr.Token).ToNot(BeEmpty()) + }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + replyHdr := parseHeader(b) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) + Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(replyHdr.Token).ToNot(BeEmpty()) + Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:])) + return len(b), nil + }) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + packet.remoteAddr = raddr + tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(frames).To(HaveLen(1)) + Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := frames[0].(*logging.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) + }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + replyHdr := parseHeader(b) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + _, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) + extHdr, err := unpackHeader(opener, replyHdr, b, hdr.Version) + Expect(err).ToNot(HaveOccurred()) + data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) + Expect(err).ToNot(HaveOccurred()) + f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := f.(*wire.ConnectionCloseFrame) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) + Expect(ccf.ReasonPhrase).To(BeEmpty()) + return len(b), nil + }) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet + packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + done := make(chan struct{}) + tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) + serv.handlePacket(packet) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + Eventually(done).Should(BeClosed()) + }) + + It("creates a connection, if no Token is required", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Version: protocol.VersionTLS, + } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + run := make(chan struct{}) + var token protocol.StatelessResetToken + rand.Read(token[:]) + + var newConnID protocol.ConnectionID + phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + newConnID = c + phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { + newConnID = c + return token + }) + fn() + return true + }) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + + conn := NewMockQuicConn(mockCtrl) + serv.newConn = func( + _ sendConn, + _ connRunner, + origDestConnID protocol.ConnectionID, + retrySrcConnID *protocol.ConnectionID, + clientDestConnID protocol.ConnectionID, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + tokenP protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + enable0RTT bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + Expect(enable0RTT).To(BeFalse()) + Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) + Expect(retrySrcConnID).To(BeNil()) + Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) + Expect(destConnID).To(Equal(hdr.SrcConnectionID)) + // make sure we're using a server-generated connection ID + Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) + Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) + Expect(srcConnID).To(Equal(newConnID)) + Expect(tokenP).To(Equal(token)) + conn.EXPECT().handlePacket(p) + conn.EXPECT().run().Do(func() { close(run) }) + conn.EXPECT().Context().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn + } + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + serv.handlePacket(p) + // the Handshake packet is written by the connection + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + close(done) + }() + // make sure we're using a server-generated connection ID + Eventually(run).Should(BeClosed()) + Eventually(done).Should(BeClosed()) + }) + + It("drops packets if the receive queue is full", func() { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }).AnyTimes() + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() + + serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } + acceptConn := make(chan struct{}) + var counter uint32 // to be used as an atomic, so we query it in Eventually + serv.newConn = func( + _ sendConn, + runner connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + <-acceptConn + atomic.AddUint32(&counter, 1) + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) + conn.EXPECT().run().MaxTimes(1) + conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) + conn.EXPECT().HandshakeComplete().Return(context.Background()).MaxTimes(1) + return conn + } + + p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}) + serv.handlePacket(p) + tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1) + var wg sync.WaitGroup + for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ { + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})) + }() + } + wg.Wait() + + close(acceptConn) + Eventually( + func() uint32 { return atomic.LoadUint32(&counter) }, + scaleDuration(100*time.Millisecond), + ).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) + Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) + }) + + It("only creates a single connection for a duplicate Initial", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + var createdConn bool + conn := NewMockQuicConn(mockCtrl) + serv.newConn = func( + _ sendConn, + runner connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + createdConn = true + return conn + } + + p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) + phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false) + Expect(serv.handlePacketImpl(p)).To(BeTrue()) + Expect(createdConn).To(BeFalse()) + }) + + It("rejects new connection attempts if the accept queue is full", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + + serv.newConn = func( + _ sendConn, + runner connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().run() + conn.EXPECT().Context().Return(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + conn.EXPECT().HandshakeComplete().Return(ctx) + return conn + } + + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }).Times(protocol.MaxAcceptQueueSize) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize) + + var wg sync.WaitGroup + wg.Add(protocol.MaxAcceptQueueSize) + for i := 0; i < protocol.MaxAcceptQueueSize; i++ { + go func() { + defer GinkgoRecover() + defer wg.Done() + serv.handlePacket(getInitialWithRandomDestConnID()) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + }() + } + wg.Wait() + p := getInitialWithRandomDestConnID() + hdr, _, _, err := wire.ParsePacket(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), p.remoteAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + rejectHdr := parseHeader(b) + Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(rejectHdr.Version).To(Equal(hdr.Version)) + Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + return len(b), nil + }) + serv.handlePacket(p) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't accept new connections if they were closed in the mean time", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + + p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + ctx, cancel := context.WithCancel(context.Background()) + connCreated := make(chan struct{}) + conn := NewMockQuicConn(mockCtrl) + serv.newConn = func( + _ sendConn, + runner connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + conn.EXPECT().handlePacket(p) + conn.EXPECT().run() + conn.EXPECT().Context().Return(ctx) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + conn.EXPECT().HandshakeComplete().Return(ctx) + close(connCreated) + return conn + } + + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) + + serv.handlePacket(p) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + Eventually(connCreated).Should(BeClosed()) + cancel() + time.Sleep(scaleDuration(200 * time.Millisecond)) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + serv.Accept(context.Background()) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + + // make the go routine return + phm.EXPECT().CloseServer() + conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID + Expect(serv.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + }) + + Context("accepting connections", func() { + It("returns Accept when an error occurs", func() { + testErr := errors.New("test err") + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := serv.Accept(context.Background()) + Expect(err).To(MatchError(testErr)) + close(done) + }() + + serv.setCloseError(testErr) + Eventually(done).Should(BeClosed()) + }) + + It("returns immediately, if an error occurred before", func() { + testErr := errors.New("test err") + serv.setCloseError(testErr) + for i := 0; i < 3; i++ { + _, err := serv.Accept(context.Background()) + Expect(err).To(MatchError(testErr)) + } + }) + + It("returns when the context is canceled", func() { + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := serv.Accept(ctx) + Expect(err).To(MatchError("context canceled")) + close(done) + }() + + Consistently(done).ShouldNot(BeClosed()) + cancel() + Eventually(done).Should(BeClosed()) + }) + + It("accepts new connections when the handshake completes", func() { + conn := NewMockQuicConn(mockCtrl) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + s, err := serv.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(s).To(Equal(conn)) + close(done) + }() + + ctx, cancel := context.WithCancel(context.Background()) // handshake context + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + serv.newConn = func( + _ sendConn, + runner connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().HandshakeComplete().Return(ctx) + conn.EXPECT().run().Do(func() {}) + conn.EXPECT().Context().Return(context.Background()) + return conn + } + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) + serv.handleInitialImpl( + &receivedPacket{buffer: getPacketBuffer()}, + &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, + ) + Consistently(done).ShouldNot(BeClosed()) + cancel() // complete the handshake + Eventually(done).Should(BeClosed()) + }) + }) + }) + + Context("server accepting connections that haven't completed the handshake", func() { + var ( + serv *earlyServer + phm *MockPacketHandlerManager + ) + + BeforeEach(func() { + ln, err := ListenEarly(conn, tlsConf, nil) + Expect(err).ToNot(HaveOccurred()) + serv = ln.(*earlyServer) + phm = NewMockPacketHandlerManager(mockCtrl) + serv.connHandler = phm + }) + + AfterEach(func() { + phm.EXPECT().CloseServer().MaxTimes(1) + serv.Close() + }) + + It("accepts new connections when they become ready", func() { + conn := NewMockQuicConn(mockCtrl) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + s, err := serv.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(s).To(Equal(conn)) + close(done) + }() + + ready := make(chan struct{}) + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + serv.newConn = func( + _ sendConn, + runner connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + enable0RTT bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + Expect(enable0RTT).To(BeTrue()) + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().run().Do(func() {}) + conn.EXPECT().earlyConnReady().Return(ready) + conn.EXPECT().Context().Return(context.Background()) + return conn + } + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }) + serv.handleInitialImpl( + &receivedPacket{buffer: getPacketBuffer()}, + &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, + ) + Consistently(done).ShouldNot(BeClosed()) + close(ready) + Eventually(done).Should(BeClosed()) + }) + + It("rejects new connection attempts if the accept queue is full", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} + + serv.newConn = func( + _ sendConn, + runner connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + ready := make(chan struct{}) + close(ready) + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().run() + conn.EXPECT().earlyConnReady().Return(ready) + conn.EXPECT().Context().Return(context.Background()) + return conn + } + + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }).Times(protocol.MaxAcceptQueueSize) + for i := 0; i < protocol.MaxAcceptQueueSize; i++ { + serv.handlePacket(getInitialWithRandomDestConnID()) + } + + Eventually(func() int32 { return atomic.LoadInt32(&serv.connQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize)) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + + p := getInitialWithRandomDestConnID() + hdr := parseHeader(p.data) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), senderAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + rejectHdr := parseHeader(b) + Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(rejectHdr.Version).To(Equal(hdr.Version)) + Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + return len(b), nil + }) + serv.handlePacket(p) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't accept new connections if they were closed in the mean time", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + + p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + ctx, cancel := context.WithCancel(context.Background()) + connCreated := make(chan struct{}) + conn := NewMockQuicConn(mockCtrl) + serv.newConn = func( + _ sendConn, + runner connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + conn.EXPECT().handlePacket(p) + conn.EXPECT().run() + conn.EXPECT().earlyConnReady() + conn.EXPECT().Context().Return(ctx) + close(connCreated) + return conn + } + + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }) + serv.handlePacket(p) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + Eventually(connCreated).Should(BeClosed()) + cancel() + time.Sleep(scaleDuration(200 * time.Millisecond)) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + serv.Accept(context.Background()) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + + // make the go routine return + phm.EXPECT().CloseServer() + conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID + Expect(serv.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + }) +}) + +var _ = Describe("default source address verification", func() { + It("accepts a token", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} + token := &Token{ + IsRetryToken: true, + RemoteAddr: "192.168.0.1", + SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(time.Second), // will expire in 1 second + } + Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) + }) + + It("requests verification if no token is provided", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} + Expect(defaultAcceptToken(remoteAddr, nil)).To(BeFalse()) + }) + + It("rejects a token if the address doesn't match", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} + token := &Token{ + IsRetryToken: true, + RemoteAddr: "127.0.0.1", + SentTime: time.Now(), + } + Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) + }) + + It("accepts a token for a remote address is not a UDP address", func() { + remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + token := &Token{ + IsRetryToken: true, + RemoteAddr: "192.168.0.1:1337", + SentTime: time.Now(), + } + Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) + }) + + It("rejects an invalid token for a remote address is not a UDP address", func() { + remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + token := &Token{ + IsRetryToken: true, + RemoteAddr: "192.168.0.1:7331", // mismatching port + SentTime: time.Now(), + } + Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) + }) + + It("rejects an expired token", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} + token := &Token{ + IsRetryToken: true, + RemoteAddr: "192.168.0.1", + SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), // expired 1 second ago + } + Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) + }) + + It("accepts a non-retry token", func() { + Expect(protocol.RetryTokenValidity).To(BeNumerically("<", protocol.TokenValidity)) + remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} + token := &Token{ + IsRetryToken: false, + RemoteAddr: "192.168.0.1", + // if this was a retry token, it would have expired one second ago + SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), + } + Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) + }) +}) diff --git a/internal/quic-go/stream.go b/internal/quic-go/stream.go new file mode 100644 index 00000000..708304d1 --- /dev/null +++ b/internal/quic-go/stream.go @@ -0,0 +1,149 @@ +package quic + +import ( + "net" + "os" + "sync" + "time" + + "github.com/imroc/req/v3/internal/quic-go/ackhandler" + "github.com/imroc/req/v3/internal/quic-go/flowcontrol" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type deadlineError struct{} + +func (deadlineError) Error() string { return "deadline exceeded" } +func (deadlineError) Temporary() bool { return true } +func (deadlineError) Timeout() bool { return true } +func (deadlineError) Unwrap() error { return os.ErrDeadlineExceeded } + +var errDeadline net.Error = &deadlineError{} + +// The streamSender is notified by the stream about various events. +type streamSender interface { + queueControlFrame(wire.Frame) + onHasStreamData(protocol.StreamID) + // must be called without holding the mutex that is acquired by closeForShutdown + onStreamCompleted(protocol.StreamID) +} + +// Each of the both stream halves gets its own uniStreamSender. +// This is necessary in order to keep track when both halves have been completed. +type uniStreamSender struct { + streamSender + onStreamCompletedImpl func() +} + +func (s *uniStreamSender) queueControlFrame(f wire.Frame) { + s.streamSender.queueControlFrame(f) +} + +func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) { + s.streamSender.onHasStreamData(id) +} + +func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { + s.onStreamCompletedImpl() +} + +var _ streamSender = &uniStreamSender{} + +type streamI interface { + Stream + closeForShutdown(error) + // for receiving + handleStreamFrame(*wire.StreamFrame) error + handleResetStreamFrame(*wire.ResetStreamFrame) error + getWindowUpdate() protocol.ByteCount + // for sending + hasData() bool + handleStopSendingFrame(*wire.StopSendingFrame) + popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) + updateSendWindow(protocol.ByteCount) +} + +var ( + _ receiveStreamI = (streamI)(nil) + _ sendStreamI = (streamI)(nil) +) + +// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface +// +// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. +type stream struct { + receiveStream + sendStream + + completedMutex sync.Mutex + sender streamSender + receiveStreamCompleted bool + sendStreamCompleted bool + + version protocol.VersionNumber +} + +var _ Stream = &stream{} + +// newStream creates a new Stream +func newStream(streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *stream { + s := &stream{sender: sender, version: version} + senderForSendStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.sendStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, + } + s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version) + senderForReceiveStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.receiveStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, + } + s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController, version) + return s +} + +// need to define StreamID() here, since both receiveStream and readStream have a StreamID() +func (s *stream) StreamID() protocol.StreamID { + // the result is same for receiveStream and sendStream + return s.sendStream.StreamID() +} + +func (s *stream) Close() error { + return s.sendStream.Close() +} + +func (s *stream) SetDeadline(t time.Time) error { + _ = s.SetReadDeadline(t) // SetReadDeadline never errors + _ = s.SetWriteDeadline(t) // SetWriteDeadline never errors + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Read and Write unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *stream) closeForShutdown(err error) { + s.sendStream.closeForShutdown(err) + s.receiveStream.closeForShutdown(err) +} + +// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed. +// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed. +func (s *stream) checkIfCompleted() { + if s.sendStreamCompleted && s.receiveStreamCompleted { + s.sender.onStreamCompleted(s.StreamID()) + } +} diff --git a/internal/quic-go/stream_test.go b/internal/quic-go/stream_test.go new file mode 100644 index 00000000..51f1e131 --- /dev/null +++ b/internal/quic-go/stream_test.go @@ -0,0 +1,106 @@ +package quic + +import ( + "errors" + "io" + "os" + "strconv" + "time" + + "github.com/imroc/req/v3/internal/quic-go/mocks" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" +) + +// in the tests for the stream deadlines we set a deadline +// and wait to make an assertion when Read / Write was unblocked +// on the CIs, the timing is a lot less precise, so scale every duration by this factor +func scaleDuration(t time.Duration) time.Duration { + scaleFactor := 1 + if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set + scaleFactor = f + } + Expect(scaleFactor).ToNot(BeZero()) + return time.Duration(scaleFactor) * t +} + +var _ = Describe("Stream", func() { + const streamID protocol.StreamID = 1337 + + var ( + str *stream + strWithTimeout io.ReadWriter // str wrapped with gbytes.Timeout{Reader,Writer} + mockFC *mocks.MockStreamFlowController + mockSender *MockStreamSender + ) + + BeforeEach(func() { + mockSender = NewMockStreamSender(mockCtrl) + mockFC = mocks.NewMockStreamFlowController(mockCtrl) + str = newStream(streamID, mockSender, mockFC, protocol.VersionWhatever) + + timeout := scaleDuration(250 * time.Millisecond) + strWithTimeout = struct { + io.Reader + io.Writer + }{ + gbytes.TimeoutReader(str, timeout), + gbytes.TimeoutWriter(str, timeout), + } + }) + + It("gets stream id", func() { + Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) + }) + + Context("deadlines", func() { + It("sets a write deadline, when SetDeadline is called", func() { + str.SetDeadline(time.Now().Add(-time.Second)) + n, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + + It("sets a read deadline, when SetDeadline is called", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes() + f := &wire.StreamFrame{Data: []byte("foobar")} + err := str.handleStreamFrame(f) + Expect(err).ToNot(HaveOccurred()) + str.SetDeadline(time.Now().Add(-time.Second)) + b := make([]byte, 6) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + }) + + Context("completing", func() { + It("is not completed when only the receive side is completed", func() { + // don't EXPECT a call to mockSender.onStreamCompleted() + str.receiveStream.sender.onStreamCompleted(streamID) + }) + + It("is not completed when only the send side is completed", func() { + // don't EXPECT a call to mockSender.onStreamCompleted() + str.sendStream.sender.onStreamCompleted(streamID) + }) + + It("is completed when both sides are completed", func() { + mockSender.EXPECT().onStreamCompleted(streamID) + str.sendStream.sender.onStreamCompleted(streamID) + str.receiveStream.sender.onStreamCompleted(streamID) + }) + }) +}) + +var _ = Describe("Deadline Error", func() { + It("is a net.Error that wraps os.ErrDeadlineError", func() { + err := deadlineError{} + Expect(err.Timeout()).To(BeTrue()) + Expect(errors.Is(err, os.ErrDeadlineExceeded)).To(BeTrue()) + Expect(errors.Unwrap(err)).To(Equal(os.ErrDeadlineExceeded)) + }) +}) diff --git a/internal/quic-go/streams_map.go b/internal/quic-go/streams_map.go new file mode 100644 index 00000000..93465533 --- /dev/null +++ b/internal/quic-go/streams_map.go @@ -0,0 +1,317 @@ +package quic + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/imroc/req/v3/internal/quic-go/flowcontrol" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type streamError struct { + message string + nums []protocol.StreamNum +} + +func (e streamError) Error() string { + return e.message +} + +func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error { + strError, ok := err.(streamError) + if !ok { + return err + } + ids := make([]interface{}, len(strError.nums)) + for i, num := range strError.nums { + ids[i] = num.StreamID(stype, pers) + } + return fmt.Errorf(strError.Error(), ids...) +} + +type streamOpenErr struct{ error } + +var _ net.Error = &streamOpenErr{} + +func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams } +func (streamOpenErr) Timeout() bool { return false } + +// errTooManyOpenStreams is used internally by the outgoing streams maps. +var errTooManyOpenStreams = errors.New("too many open streams") + +type streamsMap struct { + perspective protocol.Perspective + version protocol.VersionNumber + + maxIncomingBidiStreams uint64 + maxIncomingUniStreams uint64 + + sender streamSender + newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController + + mutex sync.Mutex + outgoingBidiStreams *outgoingBidiStreamsMap + outgoingUniStreams *outgoingUniStreamsMap + incomingBidiStreams *incomingBidiStreamsMap + incomingUniStreams *incomingUniStreamsMap + reset bool +} + +var _ streamManager = &streamsMap{} + +func newStreamsMap( + sender streamSender, + newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController, + maxIncomingBidiStreams uint64, + maxIncomingUniStreams uint64, + perspective protocol.Perspective, + version protocol.VersionNumber, +) streamManager { + m := &streamsMap{ + perspective: perspective, + newFlowController: newFlowController, + maxIncomingBidiStreams: maxIncomingBidiStreams, + maxIncomingUniStreams: maxIncomingUniStreams, + sender: sender, + version: version, + } + m.initMaps() + return m +} + +func (m *streamsMap) initMaps() { + m.outgoingBidiStreams = newOutgoingBidiStreamsMap( + func(num protocol.StreamNum) streamI { + id := num.StreamID(protocol.StreamTypeBidi, m.perspective) + return newStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.sender.queueControlFrame, + ) + m.incomingBidiStreams = newIncomingBidiStreamsMap( + func(num protocol.StreamNum) streamI { + id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite()) + return newStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.maxIncomingBidiStreams, + m.sender.queueControlFrame, + ) + m.outgoingUniStreams = newOutgoingUniStreamsMap( + func(num protocol.StreamNum) sendStreamI { + id := num.StreamID(protocol.StreamTypeUni, m.perspective) + return newSendStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.sender.queueControlFrame, + ) + m.incomingUniStreams = newIncomingUniStreamsMap( + func(num protocol.StreamNum) receiveStreamI { + id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite()) + return newReceiveStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.maxIncomingUniStreams, + m.sender.queueControlFrame, + ) +} + +func (m *streamsMap) OpenStream() (Stream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.outgoingBidiStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStream() + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) +} + +func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.outgoingBidiStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStreamSync(ctx) + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) +} + +func (m *streamsMap) OpenUniStream() (SendStream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.outgoingUniStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStream() + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) +} + +func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.outgoingUniStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStreamSync(ctx) + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) +} + +func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.incomingBidiStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.AcceptStream(ctx) + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) +} + +func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.incomingUniStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.AcceptStream(ctx) + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) +} + +func (m *streamsMap) DeleteStream(id protocol.StreamID) error { + num := id.StreamNum() + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective) + } + return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite()) + case protocol.StreamTypeBidi: + if id.InitiatedBy() == m.perspective { + return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective) + } + return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite()) + } + panic("") +} + +func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + str, err := m.getOrOpenReceiveStream(id) + if err != nil { + return nil, &qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: err.Error(), + } + } + return str, nil +} + +func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + num := id.StreamNum() + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + // an outgoing unidirectional stream is a send stream, not a receive stream + return nil, fmt.Errorf("peer attempted to open receive stream %d", id) + } + str, err := m.incomingUniStreams.GetOrOpenStream(num) + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) + case protocol.StreamTypeBidi: + var str receiveStreamI + var err error + if id.InitiatedBy() == m.perspective { + str, err = m.outgoingBidiStreams.GetStream(num) + } else { + str, err = m.incomingBidiStreams.GetOrOpenStream(num) + } + return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) + } + panic("") +} + +func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + str, err := m.getOrOpenSendStream(id) + if err != nil { + return nil, &qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: err.Error(), + } + } + return str, nil +} + +func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + num := id.StreamNum() + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + str, err := m.outgoingUniStreams.GetStream(num) + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) + } + // an incoming unidirectional stream is a receive stream, not a send stream + return nil, fmt.Errorf("peer attempted to open send stream %d", id) + case protocol.StreamTypeBidi: + var str sendStreamI + var err error + if id.InitiatedBy() == m.perspective { + str, err = m.outgoingBidiStreams.GetStream(num) + } else { + str, err = m.incomingBidiStreams.GetOrOpenStream(num) + } + return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) + } + panic("") +} + +func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) { + switch f.Type { + case protocol.StreamTypeUni: + m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum) + case protocol.StreamTypeBidi: + m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum) + } +} + +func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) { + m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote) + m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum) + m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni) + m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum) +} + +func (m *streamsMap) CloseWithError(err error) { + m.outgoingBidiStreams.CloseWithError(err) + m.outgoingUniStreams.CloseWithError(err) + m.incomingBidiStreams.CloseWithError(err) + m.incomingUniStreams.CloseWithError(err) +} + +// ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are +// 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error. +// 2. reset to their initial state, such that we can immediately process new incoming stream data. +// Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error, +// until UseResetMaps() has been called. +func (m *streamsMap) ResetFor0RTT() { + m.mutex.Lock() + defer m.mutex.Unlock() + m.reset = true + m.CloseWithError(Err0RTTRejected) + m.initMaps() +} + +func (m *streamsMap) UseResetMaps() { + m.mutex.Lock() + m.reset = false + m.mutex.Unlock() +} diff --git a/internal/quic-go/streams_map_generic_helper.go b/internal/quic-go/streams_map_generic_helper.go new file mode 100644 index 00000000..eda9f741 --- /dev/null +++ b/internal/quic-go/streams_map_generic_helper.go @@ -0,0 +1,18 @@ +package quic + +import ( + "github.com/cheekybits/genny/generic" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// In the auto-generated streams maps, we need to be able to close the streams. +// Therefore, extend the generic.Type with the stream close method. +// This definition must be in a file that Genny doesn't process. +type item interface { + generic.Type + updateSendWindow(protocol.ByteCount) + closeForShutdown(error) +} + +const streamTypeGeneric protocol.StreamType = protocol.StreamTypeUni diff --git a/internal/quic-go/streams_map_incoming_bidi.go b/internal/quic-go/streams_map_incoming_bidi.go new file mode 100644 index 00000000..6b80a359 --- /dev/null +++ b/internal/quic-go/streams_map_incoming_bidi.go @@ -0,0 +1,192 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "context" + "sync" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// When a stream is deleted before it was accepted, we can't delete it from the map immediately. +// We need to wait until the application accepts it, and delete it then. +type streamIEntry struct { + stream streamI + shouldDelete bool +} + +type incomingBidiStreamsMap struct { + mutex sync.RWMutex + newStreamChan chan struct{} + + streams map[protocol.StreamNum]streamIEntry + + nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened + maxStream protocol.StreamNum // the highest stream that the peer is allowed to open + maxNumStreams uint64 // maximum number of streams + + newStream func(protocol.StreamNum) streamI + queueMaxStreamID func(*wire.MaxStreamsFrame) + + closeErr error +} + +func newIncomingBidiStreamsMap( + newStream func(protocol.StreamNum) streamI, + maxStreams uint64, + queueControlFrame func(wire.Frame), +) *incomingBidiStreamsMap { + return &incomingBidiStreamsMap{ + newStreamChan: make(chan struct{}, 1), + streams: make(map[protocol.StreamNum]streamIEntry), + maxStream: protocol.StreamNum(maxStreams), + maxNumStreams: maxStreams, + newStream: newStream, + nextStreamToOpen: 1, + nextStreamToAccept: 1, + queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, + } +} + +func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + + m.mutex.Lock() + + var num protocol.StreamNum + var entry streamIEntry + for { + num = m.nextStreamToAccept + if m.closeErr != nil { + m.mutex.Unlock() + return nil, m.closeErr + } + var ok bool + entry, ok = m.streams[num] + if ok { + break + } + m.mutex.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } + m.mutex.Lock() + } + m.nextStreamToAccept++ + // If this stream was completed before being accepted, we can delete it now. + if entry.shouldDelete { + if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() + return nil, err + } + } + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingBidiStreamsMap) GetOrOpenStream(num protocol.StreamNum) (streamI, error) { + m.mutex.RLock() + if num > m.maxStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer tried to open stream %d (current limit: %d)", + nums: []protocol.StreamNum{num, m.maxStream}, + } + } + // if the num is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if num < m.nextStreamToOpen { + var s streamI + // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. + if entry, ok := m.streams[num]; ok && !entry.shouldDelete { + s = entry.stream + } + m.mutex.RUnlock() + return s, nil + } + m.mutex.RUnlock() + + m.mutex.Lock() + // no need to check the two error conditions from above again + // * maxStream can only increase, so if the id was valid before, it definitely is valid now + // * highestStream is only modified by this function + for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { + m.streams[newNum] = streamIEntry{stream: m.newStream(newNum)} + select { + case m.newStreamChan <- struct{}{}: + default: + } + } + m.nextStreamToOpen = num + 1 + entry := m.streams[num] + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.deleteStream(num) +} + +func (m *incomingBidiStreamsMap) deleteStream(num protocol.StreamNum) error { + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown incoming stream %d", + nums: []protocol.StreamNum{num}, + } + } + + // Don't delete this stream yet, if it was not yet accepted. + // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. + if num >= m.nextStreamToAccept { + entry, ok := m.streams[num] + if ok && entry.shouldDelete { + return streamError{ + message: "tried to delete incoming stream %d multiple times", + nums: []protocol.StreamNum{num}, + } + } + entry.shouldDelete = true + m.streams[num] = entry // can't assign to struct in map, so we need to reassign + return nil + } + + delete(m.streams, num) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if m.maxNumStreams > uint64(len(m.streams)) { + maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 + // Never send a value larger than protocol.MaxStreamCount. + if maxStream <= protocol.MaxStreamCount { + m.maxStream = maxStream + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: m.maxStream, + }) + } + } + return nil +} + +func (m *incomingBidiStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, entry := range m.streams { + entry.stream.closeForShutdown(err) + } + m.mutex.Unlock() + close(m.newStreamChan) +} diff --git a/internal/quic-go/streams_map_incoming_generic.go b/internal/quic-go/streams_map_incoming_generic.go new file mode 100644 index 00000000..35a9e12d --- /dev/null +++ b/internal/quic-go/streams_map_incoming_generic.go @@ -0,0 +1,190 @@ +package quic + +import ( + "context" + "sync" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// When a stream is deleted before it was accepted, we can't delete it from the map immediately. +// We need to wait until the application accepts it, and delete it then. +type itemEntry struct { + stream item + shouldDelete bool +} + +//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi" +//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni" +type incomingItemsMap struct { + mutex sync.RWMutex + newStreamChan chan struct{} + + streams map[protocol.StreamNum]itemEntry + + nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened + maxStream protocol.StreamNum // the highest stream that the peer is allowed to open + maxNumStreams uint64 // maximum number of streams + + newStream func(protocol.StreamNum) item + queueMaxStreamID func(*wire.MaxStreamsFrame) + + closeErr error +} + +func newIncomingItemsMap( + newStream func(protocol.StreamNum) item, + maxStreams uint64, + queueControlFrame func(wire.Frame), +) *incomingItemsMap { + return &incomingItemsMap{ + newStreamChan: make(chan struct{}, 1), + streams: make(map[protocol.StreamNum]itemEntry), + maxStream: protocol.StreamNum(maxStreams), + maxNumStreams: maxStreams, + newStream: newStream, + nextStreamToOpen: 1, + nextStreamToAccept: 1, + queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, + } +} + +func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + + m.mutex.Lock() + + var num protocol.StreamNum + var entry itemEntry + for { + num = m.nextStreamToAccept + if m.closeErr != nil { + m.mutex.Unlock() + return nil, m.closeErr + } + var ok bool + entry, ok = m.streams[num] + if ok { + break + } + m.mutex.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } + m.mutex.Lock() + } + m.nextStreamToAccept++ + // If this stream was completed before being accepted, we can delete it now. + if entry.shouldDelete { + if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() + return nil, err + } + } + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingItemsMap) GetOrOpenStream(num protocol.StreamNum) (item, error) { + m.mutex.RLock() + if num > m.maxStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer tried to open stream %d (current limit: %d)", + nums: []protocol.StreamNum{num, m.maxStream}, + } + } + // if the num is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if num < m.nextStreamToOpen { + var s item + // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. + if entry, ok := m.streams[num]; ok && !entry.shouldDelete { + s = entry.stream + } + m.mutex.RUnlock() + return s, nil + } + m.mutex.RUnlock() + + m.mutex.Lock() + // no need to check the two error conditions from above again + // * maxStream can only increase, so if the id was valid before, it definitely is valid now + // * highestStream is only modified by this function + for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { + m.streams[newNum] = itemEntry{stream: m.newStream(newNum)} + select { + case m.newStreamChan <- struct{}{}: + default: + } + } + m.nextStreamToOpen = num + 1 + entry := m.streams[num] + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingItemsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.deleteStream(num) +} + +func (m *incomingItemsMap) deleteStream(num protocol.StreamNum) error { + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown incoming stream %d", + nums: []protocol.StreamNum{num}, + } + } + + // Don't delete this stream yet, if it was not yet accepted. + // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. + if num >= m.nextStreamToAccept { + entry, ok := m.streams[num] + if ok && entry.shouldDelete { + return streamError{ + message: "tried to delete incoming stream %d multiple times", + nums: []protocol.StreamNum{num}, + } + } + entry.shouldDelete = true + m.streams[num] = entry // can't assign to struct in map, so we need to reassign + return nil + } + + delete(m.streams, num) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if m.maxNumStreams > uint64(len(m.streams)) { + maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 + // Never send a value larger than protocol.MaxStreamCount. + if maxStream <= protocol.MaxStreamCount { + m.maxStream = maxStream + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: streamTypeGeneric, + MaxStreamNum: m.maxStream, + }) + } + } + return nil +} + +func (m *incomingItemsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, entry := range m.streams { + entry.stream.closeForShutdown(err) + } + m.mutex.Unlock() + close(m.newStreamChan) +} diff --git a/internal/quic-go/streams_map_incoming_generic_test.go b/internal/quic-go/streams_map_incoming_generic_test.go new file mode 100644 index 00000000..0983ba4f --- /dev/null +++ b/internal/quic-go/streams_map_incoming_generic_test.go @@ -0,0 +1,307 @@ +package quic + +import ( + "bytes" + "context" + "errors" + "math/rand" + "time" + + "github.com/golang/mock/gomock" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type mockGenericStream struct { + num protocol.StreamNum + + closed bool + closeErr error + sendWindow protocol.ByteCount +} + +func (s *mockGenericStream) closeForShutdown(err error) { + s.closed = true + s.closeErr = err +} + +func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) { + s.sendWindow = limit +} + +var _ = Describe("Streams Map (incoming)", func() { + var ( + m *incomingItemsMap + newItemCounter int + mockSender *MockStreamSender + maxNumStreams uint64 + ) + + // check that the frame can be serialized and deserialized + checkFrameSerialization := func(f wire.Frame) { + b := &bytes.Buffer{} + ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed()) + frame, err := wire.NewFrameParser(false, protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + Expect(f).To(Equal(frame)) + } + + BeforeEach(func() { maxNumStreams = 5 }) + + JustBeforeEach(func() { + newItemCounter = 0 + mockSender = NewMockStreamSender(mockCtrl) + m = newIncomingItemsMap( + func(num protocol.StreamNum) item { + newItemCounter++ + return &mockGenericStream{num: num} + }, + maxNumStreams, + mockSender.queueControlFrame, + ) + }) + + It("opens all streams up to the id on GetOrOpenStream", func() { + _, err := m.GetOrOpenStream(4) + Expect(err).ToNot(HaveOccurred()) + Expect(newItemCounter).To(Equal(4)) + }) + + It("starts opening streams at the right position", func() { + // like the test above, but with 2 calls to GetOrOpenStream + _, err := m.GetOrOpenStream(2) + Expect(err).ToNot(HaveOccurred()) + Expect(newItemCounter).To(Equal(2)) + _, err = m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + Expect(newItemCounter).To(Equal(5)) + }) + + It("accepts streams in the right order", func() { + _, err := m.GetOrOpenStream(2) // open streams 1 and 2 + Expect(err).ToNot(HaveOccurred()) + str, err := m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + str, err = m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + }) + + It("allows opening the maximum stream ID", func() { + str, err := m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + }) + + It("errors when trying to get a stream ID higher than the maximum", func() { + _, err := m.GetOrOpenStream(6) + Expect(err).To(HaveOccurred()) + Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)")) + }) + + It("blocks AcceptStream until a new stream is available", func() { + strChan := make(chan item) + go func() { + defer GinkgoRecover() + str, err := m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + strChan <- str + }() + Consistently(strChan).ShouldNot(Receive()) + str, err := m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + var acceptedStr item + Eventually(strChan).Should(Receive(&acceptedStr)) + Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + }) + + It("unblocks AcceptStream when the context is canceled", func() { + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.AcceptStream(ctx) + Expect(err).To(MatchError("context canceled")) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + cancel() + Eventually(done).Should(BeClosed()) + }) + + It("unblocks AcceptStream when it is closed", func() { + testErr := errors.New("test error") + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.AcceptStream(context.Background()) + Expect(err).To(MatchError(testErr)) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + m.CloseWithError(testErr) + Eventually(done).Should(BeClosed()) + }) + + It("errors AcceptStream immediately if it is closed", func() { + testErr := errors.New("test error") + m.CloseWithError(testErr) + _, err := m.AcceptStream(context.Background()) + Expect(err).To(MatchError(testErr)) + }) + + It("closes all streams when CloseWithError is called", func() { + str1, err := m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + str2, err := m.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + testErr := errors.New("test err") + m.CloseWithError(testErr) + Expect(str1.(*mockGenericStream).closed).To(BeTrue()) + Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr)) + Expect(str2.(*mockGenericStream).closed).To(BeTrue()) + Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr)) + }) + + It("deletes streams", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + _, err := m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + str, err := m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + Expect(m.DeleteStream(1)).To(Succeed()) + str, err = m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + + It("waits until a stream is accepted before actually deleting it", func() { + _, err := m.GetOrOpenStream(2) + Expect(err).ToNot(HaveOccurred()) + Expect(m.DeleteStream(2)).To(Succeed()) + str, err := m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued + mockSender.EXPECT().queueControlFrame(gomock.Any()) + str, err = m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + }) + + It("doesn't return a stream queued for deleting from GetOrOpenStream", func() { + str, err := m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + Expect(str).ToNot(BeNil()) + Expect(m.DeleteStream(1)).To(Succeed()) + str, err = m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued + mockSender.EXPECT().queueControlFrame(gomock.Any()) + str, err = m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str).ToNot(BeNil()) + }) + + It("errors when deleting a non-existing stream", func() { + err := m.DeleteStream(1337) + Expect(err).To(HaveOccurred()) + Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown incoming stream 1337")) + }) + + It("sends MAX_STREAMS frames when streams are deleted", func() { + // open a bunch of streams + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + // accept all streams + for i := 0; i < 5; i++ { + _, err := m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + } + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1))) + checkFrameSerialization(f) + }) + Expect(m.DeleteStream(3)).To(Succeed()) + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2))) + checkFrameSerialization(f) + }) + Expect(m.DeleteStream(4)).To(Succeed()) + }) + + Context("using high stream limits", func() { + BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 }) + + It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() { + // open a bunch of streams + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + // accept all streams + for i := 0; i < 5; i++ { + _, err := m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + } + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1)) + checkFrameSerialization(f) + }) + Expect(m.DeleteStream(4)).To(Succeed()) + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount)) + checkFrameSerialization(f) + }) + Expect(m.DeleteStream(3)).To(Succeed()) + // at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent + Expect(m.DeleteStream(2)).To(Succeed()) + Expect(m.DeleteStream(1)).To(Succeed()) + }) + }) + + Context("randomized tests", func() { + const num = 1000 + + BeforeEach(func() { maxNumStreams = num }) + + It("opens and accepts streams", func() { + rand.Seed(GinkgoRandomSeed()) + ids := make([]protocol.StreamNum, num) + for i := 0; i < num; i++ { + ids[i] = protocol.StreamNum(i + 1) + } + rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] }) + + const timeout = 5 * time.Second + done := make(chan struct{}, 2) + go func() { + defer GinkgoRecover() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + for i := 0; i < num; i++ { + _, err := m.AcceptStream(ctx) + Expect(err).ToNot(HaveOccurred()) + } + done <- struct{}{} + }() + + go func() { + defer GinkgoRecover() + for i := 0; i < num; i++ { + _, err := m.GetOrOpenStream(ids[i]) + Expect(err).ToNot(HaveOccurred()) + } + done <- struct{}{} + }() + + Eventually(done, timeout*3/2).Should(Receive()) + Eventually(done, timeout*3/2).Should(Receive()) + }) + }) +}) diff --git a/internal/quic-go/streams_map_incoming_uni.go b/internal/quic-go/streams_map_incoming_uni.go new file mode 100644 index 00000000..c567a562 --- /dev/null +++ b/internal/quic-go/streams_map_incoming_uni.go @@ -0,0 +1,192 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "context" + "sync" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// When a stream is deleted before it was accepted, we can't delete it from the map immediately. +// We need to wait until the application accepts it, and delete it then. +type receiveStreamIEntry struct { + stream receiveStreamI + shouldDelete bool +} + +type incomingUniStreamsMap struct { + mutex sync.RWMutex + newStreamChan chan struct{} + + streams map[protocol.StreamNum]receiveStreamIEntry + + nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened + maxStream protocol.StreamNum // the highest stream that the peer is allowed to open + maxNumStreams uint64 // maximum number of streams + + newStream func(protocol.StreamNum) receiveStreamI + queueMaxStreamID func(*wire.MaxStreamsFrame) + + closeErr error +} + +func newIncomingUniStreamsMap( + newStream func(protocol.StreamNum) receiveStreamI, + maxStreams uint64, + queueControlFrame func(wire.Frame), +) *incomingUniStreamsMap { + return &incomingUniStreamsMap{ + newStreamChan: make(chan struct{}, 1), + streams: make(map[protocol.StreamNum]receiveStreamIEntry), + maxStream: protocol.StreamNum(maxStreams), + maxNumStreams: maxStreams, + newStream: newStream, + nextStreamToOpen: 1, + nextStreamToAccept: 1, + queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, + } +} + +func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + + m.mutex.Lock() + + var num protocol.StreamNum + var entry receiveStreamIEntry + for { + num = m.nextStreamToAccept + if m.closeErr != nil { + m.mutex.Unlock() + return nil, m.closeErr + } + var ok bool + entry, ok = m.streams[num] + if ok { + break + } + m.mutex.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } + m.mutex.Lock() + } + m.nextStreamToAccept++ + // If this stream was completed before being accepted, we can delete it now. + if entry.shouldDelete { + if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() + return nil, err + } + } + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receiveStreamI, error) { + m.mutex.RLock() + if num > m.maxStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer tried to open stream %d (current limit: %d)", + nums: []protocol.StreamNum{num, m.maxStream}, + } + } + // if the num is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if num < m.nextStreamToOpen { + var s receiveStreamI + // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. + if entry, ok := m.streams[num]; ok && !entry.shouldDelete { + s = entry.stream + } + m.mutex.RUnlock() + return s, nil + } + m.mutex.RUnlock() + + m.mutex.Lock() + // no need to check the two error conditions from above again + // * maxStream can only increase, so if the id was valid before, it definitely is valid now + // * highestStream is only modified by this function + for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { + m.streams[newNum] = receiveStreamIEntry{stream: m.newStream(newNum)} + select { + case m.newStreamChan <- struct{}{}: + default: + } + } + m.nextStreamToOpen = num + 1 + entry := m.streams[num] + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.deleteStream(num) +} + +func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error { + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown incoming stream %d", + nums: []protocol.StreamNum{num}, + } + } + + // Don't delete this stream yet, if it was not yet accepted. + // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. + if num >= m.nextStreamToAccept { + entry, ok := m.streams[num] + if ok && entry.shouldDelete { + return streamError{ + message: "tried to delete incoming stream %d multiple times", + nums: []protocol.StreamNum{num}, + } + } + entry.shouldDelete = true + m.streams[num] = entry // can't assign to struct in map, so we need to reassign + return nil + } + + delete(m.streams, num) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if m.maxNumStreams > uint64(len(m.streams)) { + maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 + // Never send a value larger than protocol.MaxStreamCount. + if maxStream <= protocol.MaxStreamCount { + m.maxStream = maxStream + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreamNum: m.maxStream, + }) + } + } + return nil +} + +func (m *incomingUniStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, entry := range m.streams { + entry.stream.closeForShutdown(err) + } + m.mutex.Unlock() + close(m.newStreamChan) +} diff --git a/internal/quic-go/streams_map_outgoing_bidi.go b/internal/quic-go/streams_map_outgoing_bidi.go new file mode 100644 index 00000000..f76eda16 --- /dev/null +++ b/internal/quic-go/streams_map_outgoing_bidi.go @@ -0,0 +1,226 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "context" + "sync" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type outgoingBidiStreamsMap struct { + mutex sync.RWMutex + + streams map[protocol.StreamNum]streamI + + openQueue map[uint64]chan struct{} + lowestInQueue uint64 + highestInQueue uint64 + + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamNum // the maximum stream ID we're allowed to open + blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream + + newStream func(protocol.StreamNum) streamI + queueStreamIDBlocked func(*wire.StreamsBlockedFrame) + + closeErr error +} + +func newOutgoingBidiStreamsMap( + newStream func(protocol.StreamNum) streamI, + queueControlFrame func(wire.Frame), +) *outgoingBidiStreamsMap { + return &outgoingBidiStreamsMap{ + streams: make(map[protocol.StreamNum]streamI), + openQueue: make(map[uint64]chan struct{}), + maxStream: protocol.InvalidStreamNum, + nextStream: 1, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, + } +} + +func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + // if there are OpenStreamSync calls waiting, return an error here + if len(m.openQueue) > 0 || m.nextStream > m.maxStream { + m.maybeSendBlockedFrame() + return nil, streamOpenErr{errTooManyOpenStreams} + } + return m.openStream(), nil +} + +func (m *outgoingBidiStreamsMap) OpenStreamSync(ctx context.Context) (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + if err := ctx.Err(); err != nil { + return nil, err + } + + if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { + return m.openStream(), nil + } + + waitChan := make(chan struct{}, 1) + queuePos := m.highestInQueue + m.highestInQueue++ + if len(m.openQueue) == 0 { + m.lowestInQueue = queuePos + } + m.openQueue[queuePos] = waitChan + m.maybeSendBlockedFrame() + + for { + m.mutex.Unlock() + select { + case <-ctx.Done(): + m.mutex.Lock() + delete(m.openQueue, queuePos) + return nil, ctx.Err() + case <-waitChan: + } + m.mutex.Lock() + + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + // no stream available. Continue waiting + continue + } + str := m.openStream() + delete(m.openQueue, queuePos) + m.lowestInQueue = queuePos + 1 + m.unblockOpenSync() + return str, nil + } +} + +func (m *outgoingBidiStreamsMap) openStream() streamI { + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream++ + return s +} + +// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, +// if we haven't sent one for this offset yet +func (m *outgoingBidiStreamsMap) maybeSendBlockedFrame() { + if m.blockedSent { + return + } + + var streamNum protocol.StreamNum + if m.maxStream != protocol.InvalidStreamNum { + streamNum = m.maxStream + } + m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ + Type: protocol.StreamTypeBidi, + StreamLimit: streamNum, + }) + m.blockedSent = true +} + +func (m *outgoingBidiStreamsMap) GetStream(num protocol.StreamNum) (streamI, error) { + m.mutex.RLock() + if num >= m.nextStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer attempted to open stream %d", + nums: []protocol.StreamNum{num}, + } + } + s := m.streams[num] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown outgoing stream %d", + nums: []protocol.StreamNum{num}, + } + } + delete(m.streams, num) + return nil +} + +func (m *outgoingBidiStreamsMap) SetMaxStream(num protocol.StreamNum) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if num <= m.maxStream { + return + } + m.maxStream = num + m.blockedSent = false + if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { + m.maybeSendBlockedFrame() + } + m.unblockOpenSync() +} + +// UpdateSendWindow is called when the peer's transport parameters are received. +// Only in the case of a 0-RTT handshake will we have open streams at this point. +// We might need to update the send window, in case the server increased it. +func (m *outgoingBidiStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { + m.mutex.Lock() + for _, str := range m.streams { + str.updateSendWindow(limit) + } + m.mutex.Unlock() +} + +// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream +func (m *outgoingBidiStreamsMap) unblockOpenSync() { + if len(m.openQueue) == 0 { + return + } + for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { + c, ok := m.openQueue[qp] + if !ok { // entry was deleted because the context was canceled + continue + } + // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. + // It's sufficient to only unblock OpenStreamSync once. + select { + case c <- struct{}{}: + default: + } + return + } +} + +func (m *outgoingBidiStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + for _, c := range m.openQueue { + if c != nil { + close(c) + } + } + m.mutex.Unlock() +} diff --git a/internal/quic-go/streams_map_outgoing_generic.go b/internal/quic-go/streams_map_outgoing_generic.go new file mode 100644 index 00000000..f5449ed2 --- /dev/null +++ b/internal/quic-go/streams_map_outgoing_generic.go @@ -0,0 +1,224 @@ +package quic + +import ( + "context" + "sync" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +//go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi" +//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni" +type outgoingItemsMap struct { + mutex sync.RWMutex + + streams map[protocol.StreamNum]item + + openQueue map[uint64]chan struct{} + lowestInQueue uint64 + highestInQueue uint64 + + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamNum // the maximum stream ID we're allowed to open + blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream + + newStream func(protocol.StreamNum) item + queueStreamIDBlocked func(*wire.StreamsBlockedFrame) + + closeErr error +} + +func newOutgoingItemsMap( + newStream func(protocol.StreamNum) item, + queueControlFrame func(wire.Frame), +) *outgoingItemsMap { + return &outgoingItemsMap{ + streams: make(map[protocol.StreamNum]item), + openQueue: make(map[uint64]chan struct{}), + maxStream: protocol.InvalidStreamNum, + nextStream: 1, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, + } +} + +func (m *outgoingItemsMap) OpenStream() (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + // if there are OpenStreamSync calls waiting, return an error here + if len(m.openQueue) > 0 || m.nextStream > m.maxStream { + m.maybeSendBlockedFrame() + return nil, streamOpenErr{errTooManyOpenStreams} + } + return m.openStream(), nil +} + +func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + if err := ctx.Err(); err != nil { + return nil, err + } + + if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { + return m.openStream(), nil + } + + waitChan := make(chan struct{}, 1) + queuePos := m.highestInQueue + m.highestInQueue++ + if len(m.openQueue) == 0 { + m.lowestInQueue = queuePos + } + m.openQueue[queuePos] = waitChan + m.maybeSendBlockedFrame() + + for { + m.mutex.Unlock() + select { + case <-ctx.Done(): + m.mutex.Lock() + delete(m.openQueue, queuePos) + return nil, ctx.Err() + case <-waitChan: + } + m.mutex.Lock() + + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + // no stream available. Continue waiting + continue + } + str := m.openStream() + delete(m.openQueue, queuePos) + m.lowestInQueue = queuePos + 1 + m.unblockOpenSync() + return str, nil + } +} + +func (m *outgoingItemsMap) openStream() item { + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream++ + return s +} + +// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, +// if we haven't sent one for this offset yet +func (m *outgoingItemsMap) maybeSendBlockedFrame() { + if m.blockedSent { + return + } + + var streamNum protocol.StreamNum + if m.maxStream != protocol.InvalidStreamNum { + streamNum = m.maxStream + } + m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ + Type: streamTypeGeneric, + StreamLimit: streamNum, + }) + m.blockedSent = true +} + +func (m *outgoingItemsMap) GetStream(num protocol.StreamNum) (item, error) { + m.mutex.RLock() + if num >= m.nextStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer attempted to open stream %d", + nums: []protocol.StreamNum{num}, + } + } + s := m.streams[num] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingItemsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown outgoing stream %d", + nums: []protocol.StreamNum{num}, + } + } + delete(m.streams, num) + return nil +} + +func (m *outgoingItemsMap) SetMaxStream(num protocol.StreamNum) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if num <= m.maxStream { + return + } + m.maxStream = num + m.blockedSent = false + if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { + m.maybeSendBlockedFrame() + } + m.unblockOpenSync() +} + +// UpdateSendWindow is called when the peer's transport parameters are received. +// Only in the case of a 0-RTT handshake will we have open streams at this point. +// We might need to update the send window, in case the server increased it. +func (m *outgoingItemsMap) UpdateSendWindow(limit protocol.ByteCount) { + m.mutex.Lock() + for _, str := range m.streams { + str.updateSendWindow(limit) + } + m.mutex.Unlock() +} + +// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream +func (m *outgoingItemsMap) unblockOpenSync() { + if len(m.openQueue) == 0 { + return + } + for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { + c, ok := m.openQueue[qp] + if !ok { // entry was deleted because the context was canceled + continue + } + // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. + // It's sufficient to only unblock OpenStreamSync once. + select { + case c <- struct{}{}: + default: + } + return + } +} + +func (m *outgoingItemsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + for _, c := range m.openQueue { + if c != nil { + close(c) + } + } + m.mutex.Unlock() +} diff --git a/internal/quic-go/streams_map_outgoing_generic_test.go b/internal/quic-go/streams_map_outgoing_generic_test.go new file mode 100644 index 00000000..dc2f13a1 --- /dev/null +++ b/internal/quic-go/streams_map_outgoing_generic_test.go @@ -0,0 +1,539 @@ +package quic + +import ( + "context" + "errors" + "fmt" + "math/rand" + "sort" + "sync" + "time" + + "github.com/golang/mock/gomock" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Streams Map (outgoing)", func() { + var ( + m *outgoingItemsMap + newItem func(num protocol.StreamNum) item + mockSender *MockStreamSender + ) + + // waitForEnqueued waits until there are n go routines waiting on OpenStreamSync() + waitForEnqueued := func(n int) { + Eventually(func() int { + m.mutex.Lock() + defer m.mutex.Unlock() + return len(m.openQueue) + }, 50*time.Millisecond, 100*time.Microsecond).Should(Equal(n)) + } + + BeforeEach(func() { + newItem = func(num protocol.StreamNum) item { + return &mockGenericStream{num: num} + } + mockSender = NewMockStreamSender(mockCtrl) + m = newOutgoingItemsMap(newItem, mockSender.queueControlFrame) + }) + + Context("no stream ID limit", func() { + BeforeEach(func() { + m.SetMaxStream(0xffffffff) + }) + + It("opens streams", func() { + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + str, err = m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + }) + + It("doesn't open streams after it has been closed", func() { + testErr := errors.New("close") + m.CloseWithError(testErr) + _, err := m.OpenStream() + Expect(err).To(MatchError(testErr)) + }) + + It("gets streams", func() { + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetStream(1) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + }) + + It("errors when trying to get a stream that has not yet been opened", func() { + _, err := m.GetStream(1) + Expect(err).To(HaveOccurred()) + Expect(err.(streamError).TestError()).To(MatchError("peer attempted to open stream 1")) + }) + + It("deletes streams", func() { + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(m.DeleteStream(1)).To(Succeed()) + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetStream(1) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + + It("errors when deleting a non-existing stream", func() { + err := m.DeleteStream(1337) + Expect(err).To(HaveOccurred()) + Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown outgoing stream 1337")) + }) + + It("errors when deleting a stream twice", func() { + _, err := m.OpenStream() // opens firstNewStream + Expect(err).ToNot(HaveOccurred()) + Expect(m.DeleteStream(1)).To(Succeed()) + err = m.DeleteStream(1) + Expect(err).To(HaveOccurred()) + Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown outgoing stream 1")) + }) + + It("closes all streams when CloseWithError is called", func() { + str1, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + str2, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + testErr := errors.New("test err") + m.CloseWithError(testErr) + Expect(str1.(*mockGenericStream).closed).To(BeTrue()) + Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr)) + Expect(str2.(*mockGenericStream).closed).To(BeTrue()) + Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr)) + }) + + It("updates the send window", func() { + str1, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + str2, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + m.UpdateSendWindow(1337) + Expect(str1.(*mockGenericStream).sendWindow).To(BeEquivalentTo(1337)) + Expect(str2.(*mockGenericStream).sendWindow).To(BeEquivalentTo(1337)) + }) + }) + + Context("with stream ID limits", func() { + It("errors when no stream can be opened immediately", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + _, err := m.OpenStream() + expectTooManyStreamsError(err) + }) + + It("returns immediately when called with a canceled context", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := m.OpenStreamSync(ctx) + Expect(err).To(MatchError("context canceled")) + }) + + It("blocks until a stream can be opened synchronously", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + str, err := m.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + close(done) + }() + waitForEnqueued(1) + + m.SetMaxStream(1) + Eventually(done).Should(BeClosed()) + }) + + It("unblocks when the context is canceled", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync(ctx) + Expect(err).To(MatchError("context canceled")) + close(done) + }() + waitForEnqueued(1) + + cancel() + Eventually(done).Should(BeClosed()) + + // make sure that the next stream opened is stream 1 + m.SetMaxStream(1000) + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + }) + + It("opens streams in the right order", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + str, err := m.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + close(done1) + }() + waitForEnqueued(1) + + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + str, err := m.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + close(done2) + }() + waitForEnqueued(2) + + m.SetMaxStream(1) + Eventually(done1).Should(BeClosed()) + Consistently(done2).ShouldNot(BeClosed()) + m.SetMaxStream(2) + Eventually(done2).Should(BeClosed()) + }) + + It("opens streams in the right order, when one of the contexts is canceled", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + str, err := m.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + close(done1) + }() + waitForEnqueued(1) + + done2 := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync(ctx) + Expect(err).To(MatchError(context.Canceled)) + close(done2) + }() + waitForEnqueued(2) + + done3 := make(chan struct{}) + go func() { + defer GinkgoRecover() + str, err := m.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + close(done3) + }() + waitForEnqueued(3) + + cancel() + Eventually(done2).Should(BeClosed()) + m.SetMaxStream(1000) + Eventually(done1).Should(BeClosed()) + Eventually(done3).Should(BeClosed()) + }) + + It("unblocks multiple OpenStreamSync calls at the same time", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + done <- struct{}{} + }() + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + done <- struct{}{} + }() + waitForEnqueued(2) + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync(context.Background()) + Expect(err).To(MatchError("test done")) + done <- struct{}{} + }() + waitForEnqueued(3) + + m.SetMaxStream(2) + Eventually(done).Should(Receive()) + Eventually(done).Should(Receive()) + Consistently(done).ShouldNot(Receive()) + + m.CloseWithError(errors.New("test done")) + Eventually(done).Should(Receive()) + }) + + It("returns an error for OpenStream while an OpenStreamSync call is blocking", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).MaxTimes(2) + openedSync := make(chan struct{}) + go func() { + defer GinkgoRecover() + str, err := m.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + close(openedSync) + }() + waitForEnqueued(1) + + start := make(chan struct{}) + openend := make(chan struct{}) + go func() { + defer GinkgoRecover() + var hasStarted bool + for { + str, err := m.OpenStream() + if err == nil { + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + close(openend) + return + } + expectTooManyStreamsError(err) + if !hasStarted { + close(start) + hasStarted = true + } + } + }() + + Eventually(start).Should(BeClosed()) + m.SetMaxStream(1) + Eventually(openedSync).Should(BeClosed()) + Consistently(openend).ShouldNot(BeClosed()) + m.SetMaxStream(2) + Eventually(openend).Should(BeClosed()) + }) + + It("stops opening synchronously when it is closed", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + testErr := errors.New("test error") + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync(context.Background()) + Expect(err).To(MatchError(testErr)) + close(done) + }() + + Consistently(done).ShouldNot(BeClosed()) + m.CloseWithError(testErr) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't reduce the stream limit", func() { + m.SetMaxStream(2) + m.SetMaxStream(1) + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + }) + + It("queues a STREAMS_BLOCKED frame if no stream can be opened", func() { + m.SetMaxStream(6) + // open the 6 allowed streams + for i := 0; i < 6; i++ { + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + } + + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(6)) + }) + _, err := m.OpenStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(errTooManyOpenStreams.Error())) + }) + + It("only sends one STREAMS_BLOCKED frame for one stream ID", func() { + m.SetMaxStream(1) + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1)) + }) + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + // try to open a stream twice, but expect only one STREAMS_BLOCKED to be sent + _, err = m.OpenStream() + expectTooManyStreamsError(err) + _, err = m.OpenStream() + expectTooManyStreamsError(err) + }) + + It("queues a STREAMS_BLOCKED frame when there more streams waiting for OpenStreamSync than MAX_STREAMS allows", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(0)) + }) + done := make(chan struct{}, 2) + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + done <- struct{}{} + }() + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + done <- struct{}{} + }() + waitForEnqueued(2) + + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1)) + }) + m.SetMaxStream(1) + Eventually(done).Should(Receive()) + Consistently(done).ShouldNot(Receive()) + m.SetMaxStream(2) + Eventually(done).Should(Receive()) + }) + }) + + Context("randomized tests", func() { + It("opens streams", func() { + rand.Seed(GinkgoRandomSeed()) + const n = 100 + fmt.Fprintf(GinkgoWriter, "Opening %d streams concurrently.\n", n) + + var blockedAt []protocol.StreamNum + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit) + }).AnyTimes() + done := make(map[int]chan struct{}) + for i := 1; i <= n; i++ { + c := make(chan struct{}) + done[i] = c + + go func(doneChan chan struct{}, id protocol.StreamNum) { + defer GinkgoRecover() + defer close(doneChan) + str, err := m.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(id)) + }(c, protocol.StreamNum(i)) + waitForEnqueued(i) + } + + var limit int + limits := []protocol.StreamNum{0} + for limit < n { + limit += rand.Intn(n/5) + 1 + if limit <= n { + limits = append(limits, protocol.StreamNum(limit)) + } + fmt.Fprintf(GinkgoWriter, "Setting stream limit to %d.\n", limit) + m.SetMaxStream(protocol.StreamNum(limit)) + for i := 1; i <= n; i++ { + if i <= limit { + Eventually(done[i]).Should(BeClosed()) + } else { + Expect(done[i]).ToNot(BeClosed()) + } + } + str, err := m.OpenStream() + if limit <= n { + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(errTooManyOpenStreams.Error())) + } else { + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(n + 1))) + } + } + Expect(blockedAt).To(Equal(limits)) + }) + + It("opens streams, when some of them are getting canceled", func() { + rand.Seed(GinkgoRandomSeed()) + const n = 100 + fmt.Fprintf(GinkgoWriter, "Opening %d streams concurrently.\n", n) + + var blockedAt []protocol.StreamNum + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit) + }).AnyTimes() + + ctx, cancel := context.WithCancel(context.Background()) + streamsToCancel := make(map[protocol.StreamNum]struct{}) // used as a set + for i := 0; i < 10; i++ { + id := protocol.StreamNum(rand.Intn(n) + 1) + fmt.Fprintf(GinkgoWriter, "Canceling stream %d.\n", id) + streamsToCancel[id] = struct{}{} + } + + streamWillBeCanceled := func(id protocol.StreamNum) bool { + _, ok := streamsToCancel[id] + return ok + } + + var streamIDs []int + var mutex sync.Mutex + done := make(map[int]chan struct{}) + for i := 1; i <= n; i++ { + c := make(chan struct{}) + done[i] = c + + go func(doneChan chan struct{}, id protocol.StreamNum) { + defer GinkgoRecover() + defer close(doneChan) + cont := context.Background() + if streamWillBeCanceled(id) { + cont = ctx + } + str, err := m.OpenStreamSync(cont) + if streamWillBeCanceled(id) { + Expect(err).To(MatchError(context.Canceled)) + return + } + Expect(err).ToNot(HaveOccurred()) + mutex.Lock() + streamIDs = append(streamIDs, int(str.(*mockGenericStream).num)) + mutex.Unlock() + }(c, protocol.StreamNum(i)) + waitForEnqueued(i) + } + + cancel() + for id := range streamsToCancel { + Eventually(done[int(id)]).Should(BeClosed()) + } + var limit int + numStreams := n - len(streamsToCancel) + var limits []protocol.StreamNum + for limit < numStreams { + limits = append(limits, protocol.StreamNum(limit)) + limit += rand.Intn(n/5) + 1 + fmt.Fprintf(GinkgoWriter, "Setting stream limit to %d.\n", limit) + m.SetMaxStream(protocol.StreamNum(limit)) + l := limit + if l > numStreams { + l = numStreams + } + Eventually(func() int { + mutex.Lock() + defer mutex.Unlock() + return len(streamIDs) + }).Should(Equal(l)) + // check that all stream IDs were used + Expect(streamIDs).To(HaveLen(l)) + sort.Ints(streamIDs) + for i := 0; i < l; i++ { + Expect(streamIDs[i]).To(Equal(i + 1)) + } + } + Expect(blockedAt).To(Equal(limits)) + }) + }) +}) diff --git a/internal/quic-go/streams_map_outgoing_uni.go b/internal/quic-go/streams_map_outgoing_uni.go new file mode 100644 index 00000000..22261fb4 --- /dev/null +++ b/internal/quic-go/streams_map_outgoing_uni.go @@ -0,0 +1,226 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "context" + "sync" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type outgoingUniStreamsMap struct { + mutex sync.RWMutex + + streams map[protocol.StreamNum]sendStreamI + + openQueue map[uint64]chan struct{} + lowestInQueue uint64 + highestInQueue uint64 + + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamNum // the maximum stream ID we're allowed to open + blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream + + newStream func(protocol.StreamNum) sendStreamI + queueStreamIDBlocked func(*wire.StreamsBlockedFrame) + + closeErr error +} + +func newOutgoingUniStreamsMap( + newStream func(protocol.StreamNum) sendStreamI, + queueControlFrame func(wire.Frame), +) *outgoingUniStreamsMap { + return &outgoingUniStreamsMap{ + streams: make(map[protocol.StreamNum]sendStreamI), + openQueue: make(map[uint64]chan struct{}), + maxStream: protocol.InvalidStreamNum, + nextStream: 1, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, + } +} + +func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + // if there are OpenStreamSync calls waiting, return an error here + if len(m.openQueue) > 0 || m.nextStream > m.maxStream { + m.maybeSendBlockedFrame() + return nil, streamOpenErr{errTooManyOpenStreams} + } + return m.openStream(), nil +} + +func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + if err := ctx.Err(); err != nil { + return nil, err + } + + if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { + return m.openStream(), nil + } + + waitChan := make(chan struct{}, 1) + queuePos := m.highestInQueue + m.highestInQueue++ + if len(m.openQueue) == 0 { + m.lowestInQueue = queuePos + } + m.openQueue[queuePos] = waitChan + m.maybeSendBlockedFrame() + + for { + m.mutex.Unlock() + select { + case <-ctx.Done(): + m.mutex.Lock() + delete(m.openQueue, queuePos) + return nil, ctx.Err() + case <-waitChan: + } + m.mutex.Lock() + + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + // no stream available. Continue waiting + continue + } + str := m.openStream() + delete(m.openQueue, queuePos) + m.lowestInQueue = queuePos + 1 + m.unblockOpenSync() + return str, nil + } +} + +func (m *outgoingUniStreamsMap) openStream() sendStreamI { + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream++ + return s +} + +// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, +// if we haven't sent one for this offset yet +func (m *outgoingUniStreamsMap) maybeSendBlockedFrame() { + if m.blockedSent { + return + } + + var streamNum protocol.StreamNum + if m.maxStream != protocol.InvalidStreamNum { + streamNum = m.maxStream + } + m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ + Type: protocol.StreamTypeUni, + StreamLimit: streamNum, + }) + m.blockedSent = true +} + +func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI, error) { + m.mutex.RLock() + if num >= m.nextStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer attempted to open stream %d", + nums: []protocol.StreamNum{num}, + } + } + s := m.streams[num] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown outgoing stream %d", + nums: []protocol.StreamNum{num}, + } + } + delete(m.streams, num) + return nil +} + +func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if num <= m.maxStream { + return + } + m.maxStream = num + m.blockedSent = false + if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { + m.maybeSendBlockedFrame() + } + m.unblockOpenSync() +} + +// UpdateSendWindow is called when the peer's transport parameters are received. +// Only in the case of a 0-RTT handshake will we have open streams at this point. +// We might need to update the send window, in case the server increased it. +func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { + m.mutex.Lock() + for _, str := range m.streams { + str.updateSendWindow(limit) + } + m.mutex.Unlock() +} + +// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream +func (m *outgoingUniStreamsMap) unblockOpenSync() { + if len(m.openQueue) == 0 { + return + } + for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { + c, ok := m.openQueue[qp] + if !ok { // entry was deleted because the context was canceled + continue + } + // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. + // It's sufficient to only unblock OpenStreamSync once. + select { + case c <- struct{}{}: + default: + } + return + } +} + +func (m *outgoingUniStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + for _, c := range m.openQueue { + if c != nil { + close(c) + } + } + m.mutex.Unlock() +} diff --git a/internal/quic-go/streams_map_test.go b/internal/quic-go/streams_map_test.go new file mode 100644 index 00000000..c3e9fd4b --- /dev/null +++ b/internal/quic-go/streams_map_test.go @@ -0,0 +1,499 @@ +package quic + +import ( + "context" + "errors" + "fmt" + "net" + + "github.com/golang/mock/gomock" + + "github.com/imroc/req/v3/internal/quic-go/flowcontrol" + "github.com/imroc/req/v3/internal/quic-go/mocks" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func (e streamError) TestError() error { + nums := make([]interface{}, len(e.nums)) + for i, num := range e.nums { + nums[i] = num + } + return fmt.Errorf(e.message, nums...) +} + +type streamMapping struct { + firstIncomingBidiStream protocol.StreamID + firstIncomingUniStream protocol.StreamID + firstOutgoingBidiStream protocol.StreamID + firstOutgoingUniStream protocol.StreamID +} + +func expectTooManyStreamsError(err error) { + ExpectWithOffset(1, err).To(HaveOccurred()) + ExpectWithOffset(1, err.Error()).To(Equal(errTooManyOpenStreams.Error())) + nerr, ok := err.(net.Error) + ExpectWithOffset(1, ok).To(BeTrue()) + ExpectWithOffset(1, nerr.Timeout()).To(BeFalse()) +} + +var _ = Describe("Streams Map", func() { + newFlowController := func(protocol.StreamID) flowcontrol.StreamFlowController { + return mocks.NewMockStreamFlowController(mockCtrl) + } + + serverStreamMapping := streamMapping{ + firstIncomingBidiStream: 0, + firstOutgoingBidiStream: 1, + firstIncomingUniStream: 2, + firstOutgoingUniStream: 3, + } + clientStreamMapping := streamMapping{ + firstIncomingBidiStream: 1, + firstOutgoingBidiStream: 0, + firstIncomingUniStream: 3, + firstOutgoingUniStream: 2, + } + + for _, p := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} { + perspective := p + var ids streamMapping + if perspective == protocol.PerspectiveClient { + ids = clientStreamMapping + } else { + ids = serverStreamMapping + } + + Context(perspective.String(), func() { + var ( + m *streamsMap + mockSender *MockStreamSender + ) + + const ( + MaxBidiStreamNum = 111 + MaxUniStreamNum = 222 + ) + + allowUnlimitedStreams := func() { + m.UpdateLimits(&wire.TransportParameters{ + MaxBidiStreamNum: protocol.MaxStreamCount, + MaxUniStreamNum: protocol.MaxStreamCount, + }) + } + + BeforeEach(func() { + mockSender = NewMockStreamSender(mockCtrl) + m = newStreamsMap(mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective, protocol.VersionWhatever).(*streamsMap) + }) + + Context("opening", func() { + It("opens bidirectional streams", func() { + allowUnlimitedStreams() + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeAssignableToTypeOf(&stream{})) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) + str, err = m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeAssignableToTypeOf(&stream{})) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + 4)) + }) + + It("opens unidirectional streams", func() { + allowUnlimitedStreams() + str, err := m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeAssignableToTypeOf(&sendStream{})) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) + str, err = m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeAssignableToTypeOf(&sendStream{})) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + 4)) + }) + }) + + Context("accepting", func() { + It("accepts bidirectional streams", func() { + _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) + Expect(err).ToNot(HaveOccurred()) + str, err := m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeAssignableToTypeOf(&stream{})) + Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream)) + }) + + It("accepts unidirectional streams", func() { + _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) + Expect(err).ToNot(HaveOccurred()) + str, err := m.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeAssignableToTypeOf(&receiveStream{})) + Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream)) + }) + }) + + Context("deleting", func() { + BeforeEach(func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() + allowUnlimitedStreams() + }) + + It("deletes outgoing bidirectional streams", func() { + id := ids.firstOutgoingBidiStream + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(id)) + Expect(m.DeleteStream(id)).To(Succeed()) + dstr, err := m.GetOrOpenSendStream(id) + Expect(err).ToNot(HaveOccurred()) + Expect(dstr).To(BeNil()) + }) + + It("deletes incoming bidirectional streams", func() { + id := ids.firstIncomingBidiStream + str, err := m.GetOrOpenReceiveStream(id) + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(id)) + Expect(m.DeleteStream(id)).To(Succeed()) + dstr, err := m.GetOrOpenReceiveStream(id) + Expect(err).ToNot(HaveOccurred()) + Expect(dstr).To(BeNil()) + }) + + It("accepts bidirectional streams after they have been deleted", func() { + id := ids.firstIncomingBidiStream + _, err := m.GetOrOpenReceiveStream(id) + Expect(err).ToNot(HaveOccurred()) + Expect(m.DeleteStream(id)).To(Succeed()) + str, err := m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str).ToNot(BeNil()) + Expect(str.StreamID()).To(Equal(id)) + }) + + It("deletes outgoing unidirectional streams", func() { + id := ids.firstOutgoingUniStream + str, err := m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(id)) + Expect(m.DeleteStream(id)).To(Succeed()) + dstr, err := m.GetOrOpenSendStream(id) + Expect(err).ToNot(HaveOccurred()) + Expect(dstr).To(BeNil()) + }) + + It("deletes incoming unidirectional streams", func() { + id := ids.firstIncomingUniStream + str, err := m.GetOrOpenReceiveStream(id) + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(id)) + Expect(m.DeleteStream(id)).To(Succeed()) + dstr, err := m.GetOrOpenReceiveStream(id) + Expect(err).ToNot(HaveOccurred()) + Expect(dstr).To(BeNil()) + }) + + It("accepts unirectional streams after they have been deleted", func() { + id := ids.firstIncomingUniStream + _, err := m.GetOrOpenReceiveStream(id) + Expect(err).ToNot(HaveOccurred()) + Expect(m.DeleteStream(id)).To(Succeed()) + str, err := m.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str).ToNot(BeNil()) + Expect(str.StreamID()).To(Equal(id)) + }) + + It("errors when deleting unknown incoming unidirectional streams", func() { + id := ids.firstIncomingUniStream + 4 + Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id))) + }) + + It("errors when deleting unknown outgoing unidirectional streams", func() { + id := ids.firstOutgoingUniStream + 4 + Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id))) + }) + + It("errors when deleting unknown incoming bidirectional streams", func() { + id := ids.firstIncomingBidiStream + 4 + Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id))) + }) + + It("errors when deleting unknown outgoing bidirectional streams", func() { + id := ids.firstOutgoingBidiStream + 4 + Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id))) + }) + }) + + Context("getting streams", func() { + BeforeEach(func() { + allowUnlimitedStreams() + }) + + Context("send streams", func() { + It("gets an outgoing bidirectional stream", func() { + // need to open the stream ourselves first + // the peer is not allowed to create a stream initiated by us + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetOrOpenSendStream(ids.firstOutgoingBidiStream) + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) + }) + + It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { + id := ids.firstOutgoingBidiStream + 5*4 + _, err := m.GetOrOpenSendStream(id) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), + })) + }) + + It("gets an outgoing unidirectional stream", func() { + // need to open the stream ourselves first + // the peer is not allowed to create a stream initiated by us + _, err := m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetOrOpenSendStream(ids.firstOutgoingUniStream) + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) + }) + + It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { + id := ids.firstOutgoingUniStream + 5*4 + _, err := m.GetOrOpenSendStream(id) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), + })) + }) + + It("gets an incoming bidirectional stream", func() { + id := ids.firstIncomingBidiStream + 4*7 + str, err := m.GetOrOpenSendStream(id) + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(id)) + }) + + It("errors when trying to get an incoming unidirectional stream", func() { + id := ids.firstIncomingUniStream + _, err := m.GetOrOpenSendStream(id) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("peer attempted to open send stream %d", id), + })) + }) + }) + + Context("receive streams", func() { + It("gets an outgoing bidirectional stream", func() { + // need to open the stream ourselves first + // the peer is not allowed to create a stream initiated by us + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetOrOpenReceiveStream(ids.firstOutgoingBidiStream) + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) + }) + + It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { + id := ids.firstOutgoingBidiStream + 5*4 + _, err := m.GetOrOpenReceiveStream(id) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), + })) + }) + + It("gets an incoming bidirectional stream", func() { + id := ids.firstIncomingBidiStream + 4*7 + str, err := m.GetOrOpenReceiveStream(id) + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(id)) + }) + + It("gets an incoming unidirectional stream", func() { + id := ids.firstIncomingUniStream + 4*10 + str, err := m.GetOrOpenReceiveStream(id) + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(id)) + }) + + It("errors when trying to get an outgoing unidirectional stream", func() { + id := ids.firstOutgoingUniStream + _, err := m.GetOrOpenReceiveStream(id) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("peer attempted to open receive stream %d", id), + })) + }) + }) + }) + + It("processes the parameter for outgoing streams", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + _, err := m.OpenStream() + expectTooManyStreamsError(err) + m.UpdateLimits(&wire.TransportParameters{ + MaxBidiStreamNum: 5, + MaxUniStreamNum: 8, + }) + + mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2) + // test we can only 5 bidirectional streams + for i := 0; i < 5; i++ { + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + protocol.StreamID(4*i))) + } + _, err = m.OpenStream() + expectTooManyStreamsError(err) + // test we can only 8 unidirectional streams + for i := 0; i < 8; i++ { + str, err := m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + protocol.StreamID(4*i))) + } + _, err = m.OpenUniStream() + expectTooManyStreamsError(err) + }) + + if perspective == protocol.PerspectiveClient { + It("applies parameters to existing streams (needed for 0-RTT)", func() { + m.UpdateLimits(&wire.TransportParameters{ + MaxBidiStreamNum: 1000, + MaxUniStreamNum: 1000, + }) + flowControllers := make(map[protocol.StreamID]*mocks.MockStreamFlowController) + m.newFlowController = func(id protocol.StreamID) flowcontrol.StreamFlowController { + fc := mocks.NewMockStreamFlowController(mockCtrl) + flowControllers[id] = fc + return fc + } + + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + unistr, err := m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + + Expect(flowControllers).To(HaveKey(str.StreamID())) + flowControllers[str.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(4321)) + Expect(flowControllers).To(HaveKey(unistr.StreamID())) + flowControllers[unistr.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(1234)) + + m.UpdateLimits(&wire.TransportParameters{ + MaxBidiStreamNum: 1000, + InitialMaxStreamDataUni: 1234, + MaxUniStreamNum: 1000, + InitialMaxStreamDataBidiRemote: 4321, + }) + }) + } + + Context("handling MAX_STREAMS frames", func() { + BeforeEach(func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() + }) + + It("processes IDs for outgoing bidirectional streams", func() { + _, err := m.OpenStream() + expectTooManyStreamsError(err) + m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: 1, + }) + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) + _, err = m.OpenStream() + expectTooManyStreamsError(err) + }) + + It("processes IDs for outgoing unidirectional streams", func() { + _, err := m.OpenUniStream() + expectTooManyStreamsError(err) + m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreamNum: 1, + }) + str, err := m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) + _, err = m.OpenUniStream() + expectTooManyStreamsError(err) + }) + }) + + Context("sending MAX_STREAMS frames", func() { + It("sends a MAX_STREAMS frame for bidirectional streams", func() { + _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) + Expect(err).ToNot(HaveOccurred()) + _, err = m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: MaxBidiStreamNum + 1, + }) + Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed()) + }) + + It("sends a MAX_STREAMS frame for unidirectional streams", func() { + _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) + Expect(err).ToNot(HaveOccurred()) + _, err = m.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreamNum: MaxUniStreamNum + 1, + }) + Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed()) + }) + }) + + It("closes", func() { + testErr := errors.New("test error") + m.CloseWithError(testErr) + _, err := m.OpenStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(testErr.Error())) + _, err = m.OpenUniStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(testErr.Error())) + _, err = m.AcceptStream(context.Background()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(testErr.Error())) + _, err = m.AcceptUniStream(context.Background()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(testErr.Error())) + }) + + if perspective == protocol.PerspectiveClient { + It("resets for 0-RTT", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() + m.ResetFor0RTT() + // make sure that calls to open / accept streams fail + _, err := m.OpenStream() + Expect(err).To(MatchError(Err0RTTRejected)) + _, err = m.AcceptStream(context.Background()) + Expect(err).To(MatchError(Err0RTTRejected)) + // make sure that we can still get new streams, as the server might be sending us data + str, err := m.GetOrOpenReceiveStream(3) + Expect(err).ToNot(HaveOccurred()) + Expect(str).ToNot(BeNil()) + + // now switch to using the new streams map + m.UseResetMaps() + _, err = m.OpenStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too many open streams")) + }) + } + }) + } +}) diff --git a/internal/quic-go/sys_conn.go b/internal/quic-go/sys_conn.go new file mode 100644 index 00000000..315c26cc --- /dev/null +++ b/internal/quic-go/sys_conn.go @@ -0,0 +1,80 @@ +package quic + +import ( + "net" + "syscall" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +// OOBCapablePacketConn is a connection that allows the reading of ECN bits from the IP header. +// If the PacketConn passed to Dial or Listen satisfies this interface, quic-go will use it. +// In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets. +type OOBCapablePacketConn interface { + net.PacketConn + SyscallConn() (syscall.RawConn, error) + ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) + WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) +} + +var _ OOBCapablePacketConn = &net.UDPConn{} + +func wrapConn(pc net.PacketConn) (rawConn, error) { + conn, ok := pc.(interface { + SyscallConn() (syscall.RawConn, error) + }) + if ok { + rawConn, err := conn.SyscallConn() + if err != nil { + return nil, err + } + + if _, ok := pc.LocalAddr().(*net.UDPAddr); ok { + // Only set DF on sockets that we expect to be able to handle that configuration. + err = setDF(rawConn) + if err != nil { + return nil, err + } + } + } + c, ok := pc.(OOBCapablePacketConn) + if !ok { + utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.") + return &basicConn{PacketConn: pc}, nil + } + return newConn(c) +} + +// The basicConn is the most trivial implementation of a connection. +// It reads a single packet from the underlying net.PacketConn. +// It is used when +// * the net.PacketConn is not a OOBCapablePacketConn, and +// * when the OS doesn't support OOB. +type basicConn struct { + net.PacketConn +} + +var _ rawConn = &basicConn{} + +func (c *basicConn) ReadPacket() (*receivedPacket, error) { + buffer := getPacketBuffer() + // The packet size should not exceed protocol.MaxPacketBufferSize bytes + // If it does, we only read a truncated packet, which will then end up undecryptable + buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] + n, addr, err := c.PacketConn.ReadFrom(buffer.Data) + if err != nil { + return nil, err + } + return &receivedPacket{ + remoteAddr: addr, + rcvTime: time.Now(), + data: buffer.Data[:n], + buffer: buffer, + }, nil +} + +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { + return c.PacketConn.WriteTo(b, addr) +} diff --git a/internal/quic-go/sys_conn_df.go b/internal/quic-go/sys_conn_df.go new file mode 100644 index 00000000..ae9274d9 --- /dev/null +++ b/internal/quic-go/sys_conn_df.go @@ -0,0 +1,16 @@ +//go:build !linux && !windows +// +build !linux,!windows + +package quic + +import "syscall" + +func setDF(rawConn syscall.RawConn) error { + // no-op on unsupported platforms + return nil +} + +func isMsgSizeErr(err error) bool { + // to be implemented for more specific platforms + return false +} diff --git a/internal/quic-go/sys_conn_df_linux.go b/internal/quic-go/sys_conn_df_linux.go new file mode 100644 index 00000000..d3345658 --- /dev/null +++ b/internal/quic-go/sys_conn_df_linux.go @@ -0,0 +1,40 @@ +//go:build linux +// +build linux + +package quic + +import ( + "errors" + "syscall" + + "github.com/imroc/req/v3/internal/quic-go/utils" + "golang.org/x/sys/unix" +) + +func setDF(rawConn syscall.RawConn) error { + // Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long" + // and the datagram will not be fragmented + var errDFIPv4, errDFIPv6 error + if err := rawConn.Control(func(fd uintptr) { + errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO) + errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_DO) + }); err != nil { + return err + } + switch { + case errDFIPv4 == nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.") + case errDFIPv4 == nil && errDFIPv6 != nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4.") + case errDFIPv4 != nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv6.") + case errDFIPv4 != nil && errDFIPv6 != nil: + return errors.New("setting DF failed for both IPv4 and IPv6") + } + return nil +} + +func isMsgSizeErr(err error) bool { + // https://man7.org/linux/man-pages/man7/udp.7.html + return errors.Is(err, unix.EMSGSIZE) +} diff --git a/internal/quic-go/sys_conn_df_windows.go b/internal/quic-go/sys_conn_df_windows.go new file mode 100644 index 00000000..e5e2e5b3 --- /dev/null +++ b/internal/quic-go/sys_conn_df_windows.go @@ -0,0 +1,46 @@ +//go:build windows +// +build windows + +package quic + +import ( + "errors" + "syscall" + + "github.com/imroc/req/v3/internal/quic-go/utils" + "golang.org/x/sys/windows" +) + +const ( + // same for both IPv4 and IPv6 on Windows + // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAG.html + // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html + IP_DONTFRAGMENT = 14 + IPV6_DONTFRAG = 14 +) + +func setDF(rawConn syscall.RawConn) error { + var errDFIPv4, errDFIPv6 error + if err := rawConn.Control(func(fd uintptr) { + errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1) + errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_DONTFRAG, 1) + }); err != nil { + return err + } + switch { + case errDFIPv4 == nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.") + case errDFIPv4 == nil && errDFIPv6 != nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4.") + case errDFIPv4 != nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv6.") + case errDFIPv4 != nil && errDFIPv6 != nil: + return errors.New("setting DF failed for both IPv4 and IPv6") + } + return nil +} + +func isMsgSizeErr(err error) bool { + // https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2 + return errors.Is(err, windows.WSAEMSGSIZE) +} diff --git a/internal/quic-go/sys_conn_helper_darwin.go b/internal/quic-go/sys_conn_helper_darwin.go new file mode 100644 index 00000000..eabf489f --- /dev/null +++ b/internal/quic-go/sys_conn_helper_darwin.go @@ -0,0 +1,22 @@ +//go:build darwin +// +build darwin + +package quic + +import "golang.org/x/sys/unix" + +const msgTypeIPTOS = unix.IP_RECVTOS + +const ( + ipv4RECVPKTINFO = unix.IP_RECVPKTINFO + ipv6RECVPKTINFO = 0x3d +) + +const ( + msgTypeIPv4PKTINFO = unix.IP_PKTINFO + msgTypeIPv6PKTINFO = 0x2e +) + +// ReadBatch only returns a single packet on OSX, +// see https://godoc.org/golang.org/x/net/ipv4#PacketConn.ReadBatch. +const batchSize = 1 diff --git a/internal/quic-go/sys_conn_helper_freebsd.go b/internal/quic-go/sys_conn_helper_freebsd.go new file mode 100644 index 00000000..0b3e8434 --- /dev/null +++ b/internal/quic-go/sys_conn_helper_freebsd.go @@ -0,0 +1,22 @@ +//go:build freebsd +// +build freebsd + +package quic + +import "golang.org/x/sys/unix" + +const ( + msgTypeIPTOS = unix.IP_RECVTOS +) + +const ( + ipv4RECVPKTINFO = 0x7 + ipv6RECVPKTINFO = 0x24 +) + +const ( + msgTypeIPv4PKTINFO = 0x7 + msgTypeIPv6PKTINFO = 0x2e +) + +const batchSize = 8 diff --git a/internal/quic-go/sys_conn_helper_linux.go b/internal/quic-go/sys_conn_helper_linux.go new file mode 100644 index 00000000..51bec900 --- /dev/null +++ b/internal/quic-go/sys_conn_helper_linux.go @@ -0,0 +1,20 @@ +//go:build linux +// +build linux + +package quic + +import "golang.org/x/sys/unix" + +const msgTypeIPTOS = unix.IP_TOS + +const ( + ipv4RECVPKTINFO = unix.IP_PKTINFO + ipv6RECVPKTINFO = unix.IPV6_RECVPKTINFO +) + +const ( + msgTypeIPv4PKTINFO = unix.IP_PKTINFO + msgTypeIPv6PKTINFO = unix.IPV6_PKTINFO +) + +const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed) diff --git a/internal/quic-go/sys_conn_no_oob.go b/internal/quic-go/sys_conn_no_oob.go new file mode 100644 index 00000000..e3b0d11f --- /dev/null +++ b/internal/quic-go/sys_conn_no_oob.go @@ -0,0 +1,16 @@ +//go:build !darwin && !linux && !freebsd && !windows +// +build !darwin,!linux,!freebsd,!windows + +package quic + +import "net" + +func newConn(c net.PacketConn) (rawConn, error) { + return &basicConn{PacketConn: c}, nil +} + +func inspectReadBuffer(interface{}) (int, error) { + return 0, nil +} + +func (i *packetInfo) OOB() []byte { return nil } diff --git a/internal/quic-go/sys_conn_oob.go b/internal/quic-go/sys_conn_oob.go new file mode 100644 index 00000000..38fe9831 --- /dev/null +++ b/internal/quic-go/sys_conn_oob.go @@ -0,0 +1,257 @@ +//go:build darwin || linux || freebsd +// +build darwin linux freebsd + +package quic + +import ( + "encoding/binary" + "errors" + "fmt" + "net" + "syscall" + "time" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +const ( + ecnMask = 0x3 + oobBufferSize = 128 +) + +// Contrary to what the naming suggests, the ipv{4,6}.Message is not dependent on the IP version. +// They're both just aliases for x/net/internal/socket.Message. +// This means we can use this struct to read from a socket that receives both IPv4 and IPv6 messages. +var _ ipv4.Message = ipv6.Message{} + +type batchConn interface { + ReadBatch(ms []ipv4.Message, flags int) (int, error) +} + +func inspectReadBuffer(c interface{}) (int, error) { + conn, ok := c.(interface { + SyscallConn() (syscall.RawConn, error) + }) + if !ok { + return 0, errors.New("doesn't have a SyscallConn") + } + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err) + } + var size int + var serr error + if err := rawConn.Control(func(fd uintptr) { + size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) + }); err != nil { + return 0, err + } + return size, serr +} + +type oobConn struct { + OOBCapablePacketConn + batchConn batchConn + + readPos uint8 + // Packets received from the kernel, but not yet returned by ReadPacket(). + messages []ipv4.Message + buffers [batchSize]*packetBuffer +} + +var _ rawConn = &oobConn{} + +func newConn(c OOBCapablePacketConn) (*oobConn, error) { + rawConn, err := c.SyscallConn() + if err != nil { + return nil, err + } + needsPacketInfo := false + if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() { + needsPacketInfo = true + } + // We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection. + // Try enabling receiving of ECN and packet info for both IP versions. + // We expect at least one of those syscalls to succeed. + var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error + if err := rawConn.Control(func(fd uintptr) { + errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1) + errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1) + + if needsPacketInfo { + errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4RECVPKTINFO, 1) + errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, ipv6RECVPKTINFO, 1) + } + }); err != nil { + return nil, err + } + switch { + case errECNIPv4 == nil && errECNIPv6 == nil: + utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.") + case errECNIPv4 == nil && errECNIPv6 != nil: + utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.") + case errECNIPv4 != nil && errECNIPv6 == nil: + utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.") + case errECNIPv4 != nil && errECNIPv6 != nil: + return nil, errors.New("activating ECN failed for both IPv4 and IPv6") + } + if needsPacketInfo { + switch { + case errPIIPv4 == nil && errPIIPv6 == nil: + utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.") + case errPIIPv4 == nil && errPIIPv6 != nil: + utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.") + case errPIIPv4 != nil && errPIIPv6 == nil: + utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.") + case errPIIPv4 != nil && errPIIPv6 != nil: + return nil, errors.New("activating packet info failed for both IPv4 and IPv6") + } + } + + // Allows callers to pass in a connection that already satisfies batchConn interface + // to make use of the optimisation. Otherwise, ipv4.NewPacketConn would unwrap the file descriptor + // via SyscallConn(), and read it that way, which might not be what the caller wants. + var bc batchConn + if ibc, ok := c.(batchConn); ok { + bc = ibc + } else { + bc = ipv4.NewPacketConn(c) + } + + oobConn := &oobConn{ + OOBCapablePacketConn: c, + batchConn: bc, + messages: make([]ipv4.Message, batchSize), + readPos: batchSize, + } + for i := 0; i < batchSize; i++ { + oobConn.messages[i].OOB = make([]byte, oobBufferSize) + } + return oobConn, nil +} + +func (c *oobConn) ReadPacket() (*receivedPacket, error) { + if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages. + c.messages = c.messages[:batchSize] + // replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call + for i := uint8(0); i < c.readPos; i++ { + buffer := getPacketBuffer() + buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] + c.buffers[i] = buffer + c.messages[i].Buffers = [][]byte{c.buffers[i].Data} + } + c.readPos = 0 + + n, err := c.batchConn.ReadBatch(c.messages, 0) + if n == 0 || err != nil { + return nil, err + } + c.messages = c.messages[:n] + } + + msg := c.messages[c.readPos] + buffer := c.buffers[c.readPos] + c.readPos++ + ctrlMsgs, err := unix.ParseSocketControlMessage(msg.OOB[:msg.NN]) + if err != nil { + return nil, err + } + var ecn protocol.ECN + var destIP net.IP + var ifIndex uint32 + for _, ctrlMsg := range ctrlMsgs { + if ctrlMsg.Header.Level == unix.IPPROTO_IP { + switch ctrlMsg.Header.Type { + case msgTypeIPTOS: + ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask) + case msgTypeIPv4PKTINFO: + // struct in_pktinfo { + // unsigned int ipi_ifindex; /* Interface index */ + // struct in_addr ipi_spec_dst; /* Local address */ + // struct in_addr ipi_addr; /* Header Destination + // address */ + // }; + ip := make([]byte, 4) + if len(ctrlMsg.Data) == 12 { + ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data) + copy(ip, ctrlMsg.Data[8:12]) + } else if len(ctrlMsg.Data) == 4 { + // FreeBSD + copy(ip, ctrlMsg.Data) + } + destIP = net.IP(ip) + } + } + if ctrlMsg.Header.Level == unix.IPPROTO_IPV6 { + switch ctrlMsg.Header.Type { + case unix.IPV6_TCLASS: + ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask) + case msgTypeIPv6PKTINFO: + // struct in6_pktinfo { + // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ + // unsigned int ipi6_ifindex; /* send/recv interface index */ + // }; + if len(ctrlMsg.Data) == 20 { + ip := make([]byte, 16) + copy(ip, ctrlMsg.Data[:16]) + destIP = net.IP(ip) + ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data[16:]) + } + } + } + } + var info *packetInfo + if destIP != nil { + info = &packetInfo{ + addr: destIP, + ifIndex: ifIndex, + } + } + return &receivedPacket{ + remoteAddr: msg.Addr, + rcvTime: time.Now(), + data: msg.Buffers[0][:msg.N], + ecn: ecn, + info: info, + buffer: buffer, + }, nil +} + +func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (n int, err error) { + n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) + return n, err +} + +func (info *packetInfo) OOB() []byte { + if info == nil { + return nil + } + if ip4 := info.addr.To4(); ip4 != nil { + // struct in_pktinfo { + // unsigned int ipi_ifindex; /* Interface index */ + // struct in_addr ipi_spec_dst; /* Local address */ + // struct in_addr ipi_addr; /* Header Destination address */ + // }; + cm := ipv4.ControlMessage{ + Src: ip4, + IfIndex: int(info.ifIndex), + } + return cm.Marshal() + } else if len(info.addr) == 16 { + // struct in6_pktinfo { + // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ + // unsigned int ipi6_ifindex; /* send/recv interface index */ + // }; + cm := ipv6.ControlMessage{ + Src: info.addr, + IfIndex: int(info.ifIndex), + } + return cm.Marshal() + } + return nil +} diff --git a/internal/quic-go/sys_conn_oob_test.go b/internal/quic-go/sys_conn_oob_test.go new file mode 100644 index 00000000..82afa1c8 --- /dev/null +++ b/internal/quic-go/sys_conn_oob_test.go @@ -0,0 +1,243 @@ +//go:build !windows +// +build !windows + +package quic + +import ( + "fmt" + "net" + "time" + + "golang.org/x/net/ipv4" + "golang.org/x/sys/unix" + + "github.com/golang/mock/gomock" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("OOB Conn Test", func() { + runServer := func(network, address string) (*net.UDPConn, <-chan *receivedPacket) { + addr, err := net.ResolveUDPAddr(network, address) + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP(network, addr) + Expect(err).ToNot(HaveOccurred()) + oobConn, err := newConn(udpConn) + Expect(err).ToNot(HaveOccurred()) + + packetChan := make(chan *receivedPacket) + go func() { + defer GinkgoRecover() + for { + p, err := oobConn.ReadPacket() + if err != nil { + return + } + packetChan <- p + } + }() + + return udpConn, packetChan + } + + Context("ECN conn", func() { + sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr { + conn, err := net.DialUDP(network, nil, addr) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + rawConn, err := conn.SyscallConn() + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, rawConn.Control(func(fd uintptr) { + setECN(fd) + })).To(Succeed()) + _, err = conn.Write([]byte("foobar")) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + return conn.LocalAddr() + } + + It("reads ECN flags on IPv4", func() { + conn, packetChan := runServer("udp4", "localhost:0") + defer conn.Close() + + sentFrom := sendPacketWithECN( + "udp4", + conn.LocalAddr().(*net.UDPAddr), + func(fd uintptr) { + Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_TOS, 2)).To(Succeed()) + }, + ) + + var p *receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.remoteAddr).To(Equal(sentFrom)) + Expect(p.ecn).To(Equal(protocol.ECT0)) + }) + + It("reads ECN flags on IPv6", func() { + conn, packetChan := runServer("udp6", "[::]:0") + defer conn.Close() + + sentFrom := sendPacketWithECN( + "udp6", + conn.LocalAddr().(*net.UDPAddr), + func(fd uintptr) { + Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 3)).To(Succeed()) + }, + ) + + var p *receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.remoteAddr).To(Equal(sentFrom)) + Expect(p.ecn).To(Equal(protocol.ECNCE)) + }) + + It("reads ECN flags on a connection that supports both IPv4 and IPv6", func() { + conn, packetChan := runServer("udp", "0.0.0.0:0") + defer conn.Close() + port := conn.LocalAddr().(*net.UDPAddr).Port + + // IPv4 + sendPacketWithECN( + "udp4", + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port}, + func(fd uintptr) { + Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_TOS, 3)).To(Succeed()) + }, + ) + + var p *receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue()) + Expect(p.ecn).To(Equal(protocol.ECNCE)) + + // IPv6 + sendPacketWithECN( + "udp6", + &net.UDPAddr{IP: net.IPv6loopback, Port: port}, + func(fd uintptr) { + Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 1)).To(Succeed()) + }, + ) + + Eventually(packetChan).Should(Receive(&p)) + Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse()) + Expect(p.ecn).To(Equal(protocol.ECT1)) + }) + }) + + Context("Packet Info conn", func() { + sendPacket := func(network string, addr *net.UDPAddr) net.Addr { + conn, err := net.DialUDP(network, nil, addr) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + _, err = conn.Write([]byte("foobar")) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + return conn.LocalAddr() + } + + It("reads packet info on IPv4", func() { + conn, packetChan := runServer("udp4", ":0") + defer conn.Close() + + addr := conn.LocalAddr().(*net.UDPAddr) + ip := net.ParseIP("127.0.0.1").To4() + addr.IP = ip + sentFrom := sendPacket("udp4", addr) + + var p *receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.remoteAddr).To(Equal(sentFrom)) + Expect(p.info).To(Not(BeNil())) + Expect(p.info.addr.To4()).To(Equal(ip)) + }) + + It("reads packet info on IPv6", func() { + conn, packetChan := runServer("udp6", ":0") + defer conn.Close() + + addr := conn.LocalAddr().(*net.UDPAddr) + ip := net.ParseIP("::1") + addr.IP = ip + sentFrom := sendPacket("udp6", addr) + + var p *receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.remoteAddr).To(Equal(sentFrom)) + Expect(p.info).To(Not(BeNil())) + Expect(p.info.addr).To(Equal(ip)) + }) + + It("reads packet info on a connection that supports both IPv4 and IPv6", func() { + conn, packetChan := runServer("udp", ":0") + defer conn.Close() + port := conn.LocalAddr().(*net.UDPAddr).Port + + // IPv4 + ip4 := net.ParseIP("127.0.0.1").To4() + sendPacket("udp4", &net.UDPAddr{IP: ip4, Port: port}) + + var p *receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue()) + Expect(p.info).To(Not(BeNil())) + Expect(p.info.addr.To4()).To(Equal(ip4)) + + // IPv6 + ip6 := net.ParseIP("::1") + sendPacket("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: port}) + + Eventually(packetChan).Should(Receive(&p)) + Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse()) + Expect(p.info).To(Not(BeNil())) + Expect(p.info.addr).To(Equal(ip6)) + }) + }) + + Context("Batch Reading", func() { + var batchConn *MockBatchConn + + BeforeEach(func() { + batchConn = NewMockBatchConn(mockCtrl) + }) + + It("reads multiple messages in one batch", func() { + const numMsgRead = batchSize/2 + 1 + var counter int + batchConn.EXPECT().ReadBatch(gomock.Any(), gomock.Any()).DoAndReturn(func(ms []ipv4.Message, flags int) (int, error) { + Expect(ms).To(HaveLen(batchSize)) + for i := 0; i < numMsgRead; i++ { + Expect(ms[i].Buffers).To(HaveLen(1)) + Expect(ms[i].Buffers[0]).To(HaveLen(int(protocol.MaxPacketBufferSize))) + data := []byte(fmt.Sprintf("message %d", counter)) + counter++ + ms[i].Buffers[0] = data + ms[i].N = len(data) + } + return numMsgRead, nil + }).Times(2) + + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + oobConn, err := newConn(udpConn) + Expect(err).ToNot(HaveOccurred()) + oobConn.batchConn = batchConn + + for i := 0; i < batchSize+1; i++ { + p, err := oobConn.ReadPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(string(p.data)).To(Equal(fmt.Sprintf("message %d", i))) + } + }) + }) +}) diff --git a/internal/quic-go/sys_conn_test.go b/internal/quic-go/sys_conn_test.go new file mode 100644 index 00000000..15df7760 --- /dev/null +++ b/internal/quic-go/sys_conn_test.go @@ -0,0 +1,33 @@ +package quic + +import ( + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Basic Conn Test", func() { + It("reads a packet", func() { + c := NewMockPacketConn(mockCtrl) + addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} + c.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { + data := []byte("foobar") + Expect(b).To(HaveLen(int(protocol.MaxPacketBufferSize))) + return copy(b, data), addr, nil + }) + + conn, err := wrapConn(c) + Expect(err).ToNot(HaveOccurred()) + p, err := conn.ReadPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(100*time.Millisecond))) + Expect(p.remoteAddr).To(Equal(addr)) + }) +}) diff --git a/internal/quic-go/sys_conn_windows.go b/internal/quic-go/sys_conn_windows.go new file mode 100644 index 00000000..f2cc22ab --- /dev/null +++ b/internal/quic-go/sys_conn_windows.go @@ -0,0 +1,40 @@ +//go:build windows +// +build windows + +package quic + +import ( + "errors" + "fmt" + "net" + "syscall" + + "golang.org/x/sys/windows" +) + +func newConn(c OOBCapablePacketConn) (rawConn, error) { + return &basicConn{PacketConn: c}, nil +} + +func inspectReadBuffer(c net.PacketConn) (int, error) { + conn, ok := c.(interface { + SyscallConn() (syscall.RawConn, error) + }) + if !ok { + return 0, errors.New("doesn't have a SyscallConn") + } + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err) + } + var size int + var serr error + if err := rawConn.Control(func(fd uintptr) { + size, serr = windows.GetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF) + }); err != nil { + return 0, err + } + return size, serr +} + +func (i *packetInfo) OOB() []byte { return nil } diff --git a/internal/quic-go/sys_conn_windows_test.go b/internal/quic-go/sys_conn_windows_test.go new file mode 100644 index 00000000..0dae55a9 --- /dev/null +++ b/internal/quic-go/sys_conn_windows_test.go @@ -0,0 +1,33 @@ +//go:build windows +// +build windows + +package quic + +import ( + "net" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Windows Conn Test", func() { + It("works on IPv4", func() { + addr, err := net.ResolveUDPAddr("udp4", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp4", addr) + Expect(err).ToNot(HaveOccurred()) + conn, err := newConn(udpConn) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.Close()).To(Succeed()) + }) + + It("works on IPv6", func() { + addr, err := net.ResolveUDPAddr("udp6", "[::1]:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp6", addr) + Expect(err).ToNot(HaveOccurred()) + conn, err := newConn(udpConn) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.Close()).To(Succeed()) + }) +}) diff --git a/internal/quic-go/testdata/ca.pem b/internal/quic-go/testdata/ca.pem new file mode 100644 index 00000000..67a5545e --- /dev/null +++ b/internal/quic-go/testdata/ca.pem @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE----- +MIICzDCCAbQCCQDA+rLymNnfJzANBgkqhkiG9w0BAQsFADAoMSYwJAYDVQQKDB1x +dWljLWdvIENlcnRpZmljYXRlIEF1dGhvcml0eTAeFw0yMDA4MTgwOTIxMzVaFw0z +MDA4MTYwOTIxMzVaMCgxJjAkBgNVBAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0 +aG9yaXR5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1OcsYrVaSDfh +iDppl6oteVspOY3yFb96T9Y/biaGPJAkBO9VGKcqwOUPmUeiWpedRAUB9LE7Srs6 +qBX4mnl90Icjp8jbIs5cPgIWLkIu8Qm549RghFzB3bn+EmCQSe4cxvyDMN3ndClp +3YMXpZgXWgJGiPOylVi/OwHDdWDBorw4hvry+6yDtpQo2TuI2A/xtxXPT7BgsEJD +WGffdgZOYXChcFA0c1XVLIYlu2w2JhxS8c2TUF6uSDlmcoONNKVoiNCuu1Z9MorS +Qmg7a2G7dSPu123KcTcSQFcmJrt+1G81gOBtHB69kacD8xDmgksj09h/ODPL/gIU +1ZcU2ci1/QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQB0Tb1JbLXp/BvWovSAhO/j +wG7UEaUA1rCtkDB+fV2HS9bxCbV5eErdg8AMHKgB51ygUrq95vm/baZmUILr84XK +uTEoxxrw5S9Z7SrhtbOpKCumoSeTsCPjDvCcwFExHv4XHFk+CPqZwbMHueVIMT0+ +nGWss/KecCPdJLdnUgMRz0tIuXzkoRuOiUiZfUeyBNVNbDFSrLigYshTeAPGaYjX +CypoHxkeS93nWfOMUu8FTYLYkvGMU5i076zDoFGKJiEtbjSiNW+Hei7u2aSEuCzp +qyTKzYPWYffAq3MM2MKJgZdL04e9GEGeuce/qhM1o3q77aI/XJImwEDdut2LDec1 +-----END CERTIFICATE----- diff --git a/internal/quic-go/testdata/cert.go b/internal/quic-go/testdata/cert.go new file mode 100644 index 00000000..f862b0cb --- /dev/null +++ b/internal/quic-go/testdata/cert.go @@ -0,0 +1,55 @@ +package testdata + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "path" + "runtime" +) + +var certPath string + +func init() { + _, filename, _, ok := runtime.Caller(0) + if !ok { + panic("Failed to get current frame") + } + + certPath = path.Dir(filename) +} + +// GetCertificatePaths returns the paths to certificate and key +func GetCertificatePaths() (string, string) { + return path.Join(certPath, "cert.pem"), path.Join(certPath, "priv.key") +} + +// GetTLSConfig returns a tls config for quic.clemente.io +func GetTLSConfig() *tls.Config { + cert, err := tls.LoadX509KeyPair(GetCertificatePaths()) + if err != nil { + panic(err) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + } +} + +// AddRootCA adds the root CA certificate to a cert pool +func AddRootCA(certPool *x509.CertPool) { + caCertPath := path.Join(certPath, "ca.pem") + caCertRaw, err := ioutil.ReadFile(caCertPath) + if err != nil { + panic(err) + } + if ok := certPool.AppendCertsFromPEM(caCertRaw); !ok { + panic("Could not add root ceritificate to pool.") + } +} + +// GetRootCA returns an x509.CertPool containing (only) the CA certificate +func GetRootCA() *x509.CertPool { + pool := x509.NewCertPool() + AddRootCA(pool) + return pool +} diff --git a/internal/quic-go/testdata/cert.pem b/internal/quic-go/testdata/cert.pem new file mode 100644 index 00000000..91d1aa9e --- /dev/null +++ b/internal/quic-go/testdata/cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC1TCCAb2gAwIBAgIJAK2fcqC0BVA7MA0GCSqGSIb3DQEBCwUAMCgxJjAkBgNV +BAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0aG9yaXR5MB4XDTIwMDgxODA5MjEz +NVoXDTMwMDgxNjA5MjEzNVowEjEQMA4GA1UECgwHcXVpYy1nbzCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAN/YwrigSXdJCL/bdBGhb0UpqtU8H+krV870 ++w1yCSykLImH8x3qHZEXt9sr/vgjcJoV6Z15RZmnbEqnAx84sIClIBoIgnk0VPxu +WF+/U/dElbftCfYcfJAddhRckdmGB+yb3Wogb32UJ+q3my++h6NjHsYb+OwpJPnQ +meXjOE7Kkf+bXfFywHF3R8kzVdh5JUFYeKbxYmYgxRps1YTsbCrZCrSy1CbQ9FJw +Wg5C8t+7yvVFmOeWPECypBCz2xS2mu+kycMNIjIWMl0SL7oVM5cBkRKPeVIG/KcM +i5+/4lRSLoPh0Txh2TKBWfpzLbIOdPU8/O7cAukIGWx0XsfHUQMCAwEAAaMYMBYw +FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQAyxxvebdMz +shp5pt1SxMOSXbo8sTa1cpaf2rTmb4nxjXs6KPBEn53hSBz9bhe5wXE4f94SHadf +636rLh3d75KgrLUwO9Yq0HfCxMo1jUV/Ug++XwcHCI9vk58Tk/H4hqEM6C8RrdTj +fYeuegQ0/oNLJ4uTw2P2A8TJbL6FC2dcICEAvUGZUcVyZ8m8tHXNRYYh6MZ7ubCh +hinvL+AA5fY6EVlc5G/P4DN6fYxGn1cFNbiL4uZP4+W3dOmP+NV0YV9ihTyMzz0R +vSoOZ9FeVkyw8EhMb3LoyXYKazvJy2VQST1ltzAGit9RiM1Gv4vuna74WsFzrn1U +A/TbaR0ih/qG +-----END CERTIFICATE----- diff --git a/internal/quic-go/testdata/cert_test.go b/internal/quic-go/testdata/cert_test.go new file mode 100644 index 00000000..0de1bd7b --- /dev/null +++ b/internal/quic-go/testdata/cert_test.go @@ -0,0 +1,31 @@ +package testdata + +import ( + "crypto/tls" + "io/ioutil" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("certificates", func() { + It("returns certificates", func() { + ln, err := tls.Listen("tcp", "localhost:4433", GetTLSConfig()) + Expect(err).ToNot(HaveOccurred()) + + go func() { + defer GinkgoRecover() + conn, err := ln.Accept() + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + _, err = conn.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + }() + + conn, err := tls.Dial("tcp", "localhost:4433", &tls.Config{RootCAs: GetRootCA()}) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(conn) + Expect(err).ToNot(HaveOccurred()) + Expect(string(data)).To(Equal("foobar")) + }) +}) diff --git a/internal/quic-go/testdata/generate_key.sh b/internal/quic-go/testdata/generate_key.sh new file mode 100755 index 00000000..7ecaa966 --- /dev/null +++ b/internal/quic-go/testdata/generate_key.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e + +echo "Generating CA key and certificate:" +openssl req -x509 -sha256 -nodes -days 3650 -newkey rsa:2048 \ + -keyout ca.key -out ca.pem \ + -subj "/O=quic-go Certificate Authority/" + +echo "Generating CSR" +openssl req -out cert.csr -new -newkey rsa:2048 -nodes -keyout priv.key \ + -subj "/O=quic-go/" + +echo "Sign certificate:" +openssl x509 -req -sha256 -days 3650 -in cert.csr -out cert.pem \ + -CA ca.pem -CAkey ca.key -CAcreateserial \ + -extfile <(printf "subjectAltName=DNS:localhost") + +# debug output the certificate +openssl x509 -noout -text -in cert.pem + +# we don't need the CA key, the serial number and the CSR any more +rm ca.key cert.csr ca.srl + diff --git a/internal/quic-go/testdata/priv.key b/internal/quic-go/testdata/priv.key new file mode 100644 index 00000000..56b8d894 --- /dev/null +++ b/internal/quic-go/testdata/priv.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDf2MK4oEl3SQi/ +23QRoW9FKarVPB/pK1fO9PsNcgkspCyJh/Md6h2RF7fbK/74I3CaFemdeUWZp2xK +pwMfOLCApSAaCIJ5NFT8blhfv1P3RJW37Qn2HHyQHXYUXJHZhgfsm91qIG99lCfq +t5svvoejYx7GG/jsKST50Jnl4zhOypH/m13xcsBxd0fJM1XYeSVBWHim8WJmIMUa +bNWE7Gwq2Qq0stQm0PRScFoOQvLfu8r1RZjnljxAsqQQs9sUtprvpMnDDSIyFjJd +Ei+6FTOXAZESj3lSBvynDIufv+JUUi6D4dE8YdkygVn6cy2yDnT1PPzu3ALpCBls +dF7Hx1EDAgMBAAECggEBAMm+mLDBdbUWk9YmuZNyRdC13wvT5obF05vo26OglXgw +dxt09b6OVBuCnuff3SpS9pdJDIYq2HnFlSorH/sxopIvQKF17fHDIp1n7ipNTCXd +IHrmHkY8Il/YzaVIUQMVc2rih0mw9greTqOS20DKnYC6QvAWIeDmrDaitTGl+ge3 +hm7e2lsgZi13R6fTNwQs9geEQSGzP2k7bFceHQFDChOYiQraR5+VZZ8S8AMGjk47 +AUa5EsKeUe6O9t2xuDSFxzYz5eadOAiErKGDos5KXXr3VQgFcC8uPEFFjcJ/yl+8 +tOe4iLeVwGSDJhTAThdR2deJOjaDcarWM7ixmxA3DAECgYEA/WVwmY4gWKwv49IJ +Jnh1Gu93P772GqliMNpukdjTI+joQxfl4jRSt2hk4b1KRwyT9aaKfvdz0HFlXo/r +9NVSAYT3/3vbcw61bfvPhhtz44qRAAKua6b5cUM6XqxVt1hqdP8lrf/blvA5ln+u +O51S8+wpxZMuqKz/29zdWSG6tAMCgYEA4iWXMXX9dZajI6abVkWwuosvOakXdLk4 +tUy7zd+JPF7hmUzzj2gtg4hXoiQPAOi+GY3TX+1Nza3s1LD7iWaXSKeOWvvligw9 +Q/wVTNW2P1+tdhScJf9QudzW69xOm5HNBgx9uWV2cHfjC12vg5aTH0k5axvaq15H +9WBXlH5q3wECgYBYoYGYBDFmMpvxmMagkSOMz1OrlVSpkLOKmOxx0SBRACc1SIec +7mY8RqR6nOX9IfYixyTMMittLiyhvb9vfKnZZDQGRcFFZlCpbplws+t+HDqJgWaW +uumm5zfkY2z7204pLBF24fZhvha2gGRl76pTLTiTJd79Gr3HnmJByd1vFwKBgHL7 +vfYuEeM55lT4Hz8sTAFtR2O/7+cvTgAQteSlZbfGXlp939DonUulhTkxsFc7/3wq +unCpzcdoSWSTYDGqcf1FBIKKVVltg7EPeR0KBJIQabgCHqrLOBZojPZ7m5RJ+765 +lysuxZvFuTFMPzNe2gssRf+JuBMt6tR+WclsxZYBAoGAEEFs1ppDil1xlP5rdH7T +d3TSw/u4eU/X8Ei1zi25hdRUiV76fP9fBELYFmSrPBhugYv91vtSv/LmD4zLfLv/ +yzwAD9j1lGbgM8Of8klCkk+XSJ88ryUwnMTJ5loQJW8t4L+zLv5Le7Ca9SAT0kJ1 +jT0GzDymgLMGp8RPdBkpk+w= +-----END PRIVATE KEY----- diff --git a/internal/quic-go/testdata/testdata_suite_test.go b/internal/quic-go/testdata/testdata_suite_test.go new file mode 100644 index 00000000..4e9011cf --- /dev/null +++ b/internal/quic-go/testdata/testdata_suite_test.go @@ -0,0 +1,13 @@ +package testdata + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestTestdata(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Testdata Suite") +} diff --git a/internal/quic-go/testutils/testutils.go b/internal/quic-go/testutils/testutils.go new file mode 100644 index 00000000..cbd24c91 --- /dev/null +++ b/internal/quic-go/testutils/testutils.go @@ -0,0 +1,97 @@ +package testutils + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/handshake" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +// Utilities for simulating packet injection and man-in-the-middle (MITM) attacker tests. +// Do not use for non-testing purposes. + +// writePacket returns a new raw packet with the specified header and payload +func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte { + buf := &bytes.Buffer{} + hdr.Write(buf, hdr.Version) + return append(buf.Bytes(), data...) +} + +// packRawPayload returns a new raw payload containing given frames +func packRawPayload(version protocol.VersionNumber, frames []wire.Frame) []byte { + buf := new(bytes.Buffer) + for _, cf := range frames { + cf.Write(buf, version) + } + return buf.Bytes() +} + +// ComposeInitialPacket returns an Initial packet encrypted under key +// (the original destination connection ID) containing specified frames +func ComposeInitialPacket(srcConnID protocol.ConnectionID, destConnID protocol.ConnectionID, version protocol.VersionNumber, key protocol.ConnectionID, frames []wire.Frame) []byte { + sealer, _ := handshake.NewInitialAEAD(key, protocol.PerspectiveServer, version) + + // compose payload + var payload []byte + if len(frames) == 0 { + payload = make([]byte, protocol.MinInitialPacketSize) + } else { + payload = packRawPayload(version, frames) + } + + // compose Initial header + payloadSize := len(payload) + pnLength := protocol.PacketNumberLen4 + length := payloadSize + int(pnLength) + sealer.Overhead() + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + Length: protocol.ByteCount(length), + Version: version, + }, + PacketNumberLen: pnLength, + PacketNumber: 0x0, + } + + raw := writePacket(hdr, payload) + + // encrypt payload and header + payloadOffset := len(raw) - payloadSize + var encrypted []byte + encrypted = sealer.Seal(encrypted, payload, hdr.PacketNumber, raw[:payloadOffset]) + hdrBytes := raw[0:payloadOffset] + encrypted = append(hdrBytes, encrypted...) + pnOffset := payloadOffset - int(pnLength) // packet number offset + sealer.EncryptHeader( + encrypted[payloadOffset:payloadOffset+16], // first 16 bytes of payload (sample) + &encrypted[0], // first byte of header + encrypted[pnOffset:payloadOffset], // packet number bytes + ) + return encrypted +} + +// ComposeRetryPacket returns a new raw Retry Packet +func ComposeRetryPacket( + srcConnID protocol.ConnectionID, + destConnID protocol.ConnectionID, + origDestConnID protocol.ConnectionID, + token []byte, + version protocol.VersionNumber, +) []byte { + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + Token: token, + Version: version, + }, + } + data := writePacket(hdr, nil) + return append(data, handshake.GetRetryIntegrityTag(data, origDestConnID, version)[:]...) +} diff --git a/internal/quic-go/token_store.go b/internal/quic-go/token_store.go new file mode 100644 index 00000000..cbc07830 --- /dev/null +++ b/internal/quic-go/token_store.go @@ -0,0 +1,117 @@ +package quic + +import ( + "container/list" + "sync" + + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +type singleOriginTokenStore struct { + tokens []*ClientToken + len int + p int +} + +func newSingleOriginTokenStore(size int) *singleOriginTokenStore { + return &singleOriginTokenStore{tokens: make([]*ClientToken, size)} +} + +func (s *singleOriginTokenStore) Add(token *ClientToken) { + s.tokens[s.p] = token + s.p = s.index(s.p + 1) + s.len = utils.Min(s.len+1, len(s.tokens)) +} + +func (s *singleOriginTokenStore) Pop() *ClientToken { + s.p = s.index(s.p - 1) + token := s.tokens[s.p] + s.tokens[s.p] = nil + s.len = utils.Max(s.len-1, 0) + return token +} + +func (s *singleOriginTokenStore) Len() int { + return s.len +} + +func (s *singleOriginTokenStore) index(i int) int { + mod := len(s.tokens) + return (i + mod) % mod +} + +type lruTokenStoreEntry struct { + key string + cache *singleOriginTokenStore +} + +type lruTokenStore struct { + mutex sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int + singleOriginSize int +} + +var _ TokenStore = &lruTokenStore{} + +// NewLRUTokenStore creates a new LRU cache for tokens received by the client. +// maxOrigins specifies how many origins this cache is saving tokens for. +// tokensPerOrigin specifies the maximum number of tokens per origin. +func NewLRUTokenStore(maxOrigins, tokensPerOrigin int) TokenStore { + return &lruTokenStore{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: maxOrigins, + singleOriginSize: tokensPerOrigin, + } +} + +func (s *lruTokenStore) Put(key string, token *ClientToken) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if el, ok := s.m[key]; ok { + entry := el.Value.(*lruTokenStoreEntry) + entry.cache.Add(token) + s.q.MoveToFront(el) + return + } + + if s.q.Len() < s.capacity { + entry := &lruTokenStoreEntry{ + key: key, + cache: newSingleOriginTokenStore(s.singleOriginSize), + } + entry.cache.Add(token) + s.m[key] = s.q.PushFront(entry) + return + } + + elem := s.q.Back() + entry := elem.Value.(*lruTokenStoreEntry) + delete(s.m, entry.key) + entry.key = key + entry.cache = newSingleOriginTokenStore(s.singleOriginSize) + entry.cache.Add(token) + s.q.MoveToFront(elem) + s.m[key] = elem +} + +func (s *lruTokenStore) Pop(key string) *ClientToken { + s.mutex.Lock() + defer s.mutex.Unlock() + + var token *ClientToken + if el, ok := s.m[key]; ok { + s.q.MoveToFront(el) + cache := el.Value.(*lruTokenStoreEntry).cache + token = cache.Pop() + if cache.Len() == 0 { + s.q.Remove(el) + delete(s.m, key) + } + } + return token +} diff --git a/internal/quic-go/token_store_test.go b/internal/quic-go/token_store_test.go new file mode 100644 index 00000000..01107821 --- /dev/null +++ b/internal/quic-go/token_store_test.go @@ -0,0 +1,108 @@ +package quic + +import ( + "fmt" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Token Cache", func() { + var s TokenStore + + BeforeEach(func() { + s = NewLRUTokenStore(3, 4) + }) + + mockToken := func(num int) *ClientToken { + return &ClientToken{data: []byte(fmt.Sprintf("%d", num))} + } + + Context("for a single origin", func() { + const origin = "localhost" + + It("adds and gets tokens", func() { + s.Put(origin, mockToken(1)) + s.Put(origin, mockToken(2)) + Expect(s.Pop(origin)).To(Equal(mockToken(2))) + Expect(s.Pop(origin)).To(Equal(mockToken(1))) + Expect(s.Pop(origin)).To(BeNil()) + }) + + It("overwrites old tokens", func() { + s.Put(origin, mockToken(1)) + s.Put(origin, mockToken(2)) + s.Put(origin, mockToken(3)) + s.Put(origin, mockToken(4)) + s.Put(origin, mockToken(5)) + Expect(s.Pop(origin)).To(Equal(mockToken(5))) + Expect(s.Pop(origin)).To(Equal(mockToken(4))) + Expect(s.Pop(origin)).To(Equal(mockToken(3))) + Expect(s.Pop(origin)).To(Equal(mockToken(2))) + Expect(s.Pop(origin)).To(BeNil()) + }) + + It("continues after getting a token", func() { + s.Put(origin, mockToken(1)) + s.Put(origin, mockToken(2)) + s.Put(origin, mockToken(3)) + Expect(s.Pop(origin)).To(Equal(mockToken(3))) + s.Put(origin, mockToken(4)) + s.Put(origin, mockToken(5)) + Expect(s.Pop(origin)).To(Equal(mockToken(5))) + Expect(s.Pop(origin)).To(Equal(mockToken(4))) + Expect(s.Pop(origin)).To(Equal(mockToken(2))) + Expect(s.Pop(origin)).To(Equal(mockToken(1))) + Expect(s.Pop(origin)).To(BeNil()) + }) + }) + + Context("for multiple origins", func() { + It("adds and gets tokens", func() { + s.Put("host1", mockToken(1)) + s.Put("host2", mockToken(2)) + Expect(s.Pop("host1")).To(Equal(mockToken(1))) + Expect(s.Pop("host1")).To(BeNil()) + Expect(s.Pop("host2")).To(Equal(mockToken(2))) + Expect(s.Pop("host2")).To(BeNil()) + }) + + It("evicts old entries", func() { + s.Put("host1", mockToken(1)) + s.Put("host2", mockToken(2)) + s.Put("host3", mockToken(3)) + s.Put("host4", mockToken(4)) + Expect(s.Pop("host1")).To(BeNil()) + Expect(s.Pop("host2")).To(Equal(mockToken(2))) + Expect(s.Pop("host3")).To(Equal(mockToken(3))) + Expect(s.Pop("host4")).To(Equal(mockToken(4))) + }) + + It("moves old entries to the front, when new tokens are added", func() { + s.Put("host1", mockToken(1)) + s.Put("host2", mockToken(2)) + s.Put("host3", mockToken(3)) + s.Put("host1", mockToken(11)) + // make sure one is evicted + s.Put("host4", mockToken(4)) + Expect(s.Pop("host2")).To(BeNil()) + Expect(s.Pop("host1")).To(Equal(mockToken(11))) + Expect(s.Pop("host1")).To(Equal(mockToken(1))) + Expect(s.Pop("host3")).To(Equal(mockToken(3))) + Expect(s.Pop("host4")).To(Equal(mockToken(4))) + }) + + It("deletes hosts that are empty", func() { + s.Put("host1", mockToken(1)) + s.Put("host2", mockToken(2)) + s.Put("host3", mockToken(3)) + Expect(s.Pop("host2")).To(Equal(mockToken(2))) + Expect(s.Pop("host2")).To(BeNil()) + // host2 is now empty and should have been deleted, making space for host4 + s.Put("host4", mockToken(4)) + Expect(s.Pop("host1")).To(Equal(mockToken(1))) + Expect(s.Pop("host3")).To(Equal(mockToken(3))) + Expect(s.Pop("host4")).To(Equal(mockToken(4))) + }) + }) +}) diff --git a/internal/quic-go/tools.go b/internal/quic-go/tools.go new file mode 100644 index 00000000..ee68fafb --- /dev/null +++ b/internal/quic-go/tools.go @@ -0,0 +1,9 @@ +//go:build tools +// +build tools + +package quic + +import ( + _ "github.com/cheekybits/genny" + _ "github.com/onsi/ginkgo/ginkgo" +) diff --git a/internal/quic-go/utils/atomic_bool.go b/internal/quic-go/utils/atomic_bool.go new file mode 100644 index 00000000..cf464250 --- /dev/null +++ b/internal/quic-go/utils/atomic_bool.go @@ -0,0 +1,22 @@ +package utils + +import "sync/atomic" + +// An AtomicBool is an atomic bool +type AtomicBool struct { + v int32 +} + +// Set sets the value +func (a *AtomicBool) Set(value bool) { + var n int32 + if value { + n = 1 + } + atomic.StoreInt32(&a.v, n) +} + +// Get gets the value +func (a *AtomicBool) Get() bool { + return atomic.LoadInt32(&a.v) != 0 +} diff --git a/internal/quic-go/utils/atomic_bool_test.go b/internal/quic-go/utils/atomic_bool_test.go new file mode 100644 index 00000000..83a200c2 --- /dev/null +++ b/internal/quic-go/utils/atomic_bool_test.go @@ -0,0 +1,29 @@ +package utils + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Atomic Bool", func() { + var a *AtomicBool + + BeforeEach(func() { + a = &AtomicBool{} + }) + + It("has the right default value", func() { + Expect(a.Get()).To(BeFalse()) + }) + + It("sets the value to true", func() { + a.Set(true) + Expect(a.Get()).To(BeTrue()) + }) + + It("sets the value to false", func() { + a.Set(true) + a.Set(false) + Expect(a.Get()).To(BeFalse()) + }) +}) diff --git a/internal/quic-go/utils/buffered_write_closer.go b/internal/quic-go/utils/buffered_write_closer.go new file mode 100644 index 00000000..b5b9d6fc --- /dev/null +++ b/internal/quic-go/utils/buffered_write_closer.go @@ -0,0 +1,26 @@ +package utils + +import ( + "bufio" + "io" +) + +type bufferedWriteCloser struct { + *bufio.Writer + io.Closer +} + +// NewBufferedWriteCloser creates an io.WriteCloser from a bufio.Writer and an io.Closer +func NewBufferedWriteCloser(writer *bufio.Writer, closer io.Closer) io.WriteCloser { + return &bufferedWriteCloser{ + Writer: writer, + Closer: closer, + } +} + +func (h bufferedWriteCloser) Close() error { + if err := h.Writer.Flush(); err != nil { + return err + } + return h.Closer.Close() +} diff --git a/internal/quic-go/utils/buffered_write_closer_test.go b/internal/quic-go/utils/buffered_write_closer_test.go new file mode 100644 index 00000000..9c93d615 --- /dev/null +++ b/internal/quic-go/utils/buffered_write_closer_test.go @@ -0,0 +1,26 @@ +package utils + +import ( + "bufio" + "bytes" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type nopCloser struct{} + +func (nopCloser) Close() error { return nil } + +var _ = Describe("buffered io.WriteCloser", func() { + It("flushes before closing", func() { + buf := &bytes.Buffer{} + + w := bufio.NewWriter(buf) + wc := NewBufferedWriteCloser(w, &nopCloser{}) + wc.Write([]byte("foobar")) + Expect(buf.Len()).To(BeZero()) + Expect(wc.Close()).To(Succeed()) + Expect(buf.String()).To(Equal("foobar")) + }) +}) diff --git a/internal/quic-go/utils/byteinterval_linkedlist.go b/internal/quic-go/utils/byteinterval_linkedlist.go new file mode 100644 index 00000000..096023ef --- /dev/null +++ b/internal/quic-go/utils/byteinterval_linkedlist.go @@ -0,0 +1,217 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package utils + +// Linked list implementation from the Go standard library. + +// ByteIntervalElement is an element of a linked list. +type ByteIntervalElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *ByteIntervalElement + + // The list to which this element belongs. + list *ByteIntervalList + + // The value stored with this element. + Value ByteInterval +} + +// Next returns the next list element or nil. +func (e *ByteIntervalElement) Next() *ByteIntervalElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *ByteIntervalElement) Prev() *ByteIntervalElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// ByteIntervalList is a linked list of ByteIntervals. +type ByteIntervalList struct { + root ByteIntervalElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *ByteIntervalList) Init() *ByteIntervalList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewByteIntervalList returns an initialized list. +func NewByteIntervalList() *ByteIntervalList { return new(ByteIntervalList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *ByteIntervalList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *ByteIntervalList) Front() *ByteIntervalElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *ByteIntervalList) Back() *ByteIntervalElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *ByteIntervalList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *ByteIntervalList) insert(e, at *ByteIntervalElement) *ByteIntervalElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *ByteIntervalList) insertValue(v ByteInterval, at *ByteIntervalElement) *ByteIntervalElement { + return l.insert(&ByteIntervalElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *ByteIntervalList) remove(e *ByteIntervalElement) *ByteIntervalElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *ByteIntervalList) Remove(e *ByteIntervalElement) ByteInterval { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *ByteIntervalList) PushFront(v ByteInterval) *ByteIntervalElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *ByteIntervalList) MoveToFront(e *ByteIntervalElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *ByteIntervalList) MoveToBack(e *ByteIntervalElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *ByteIntervalList) PushFrontList(other *ByteIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/internal/quic-go/utils/byteoder_big_endian_test.go b/internal/quic-go/utils/byteoder_big_endian_test.go new file mode 100644 index 00000000..5d0873a9 --- /dev/null +++ b/internal/quic-go/utils/byteoder_big_endian_test.go @@ -0,0 +1,107 @@ +package utils + +import ( + "bytes" + "io" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Big Endian encoding / decoding", func() { + Context("ReadUint16", func() { + It("reads a big endian", func() { + b := []byte{0x13, 0xEF} + val, err := BigEndian.ReadUint16(bytes.NewReader(b)) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint16(0x13EF))) + }) + + It("throws an error if less than 2 bytes are passed", func() { + b := []byte{0x13, 0xEF} + for i := 0; i < len(b); i++ { + _, err := BigEndian.ReadUint16(bytes.NewReader(b[:i])) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("ReadUint24", func() { + It("reads a big endian", func() { + b := []byte{0x13, 0xbe, 0xef} + val, err := BigEndian.ReadUint24(bytes.NewReader(b)) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint32(0x13beef))) + }) + + It("throws an error if less than 3 bytes are passed", func() { + b := []byte{0x13, 0xbe, 0xef} + for i := 0; i < len(b); i++ { + _, err := BigEndian.ReadUint24(bytes.NewReader(b[:i])) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("ReadUint32", func() { + It("reads a big endian", func() { + b := []byte{0x12, 0x35, 0xAB, 0xFF} + val, err := BigEndian.ReadUint32(bytes.NewReader(b)) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint32(0x1235ABFF))) + }) + + It("throws an error if less than 4 bytes are passed", func() { + b := []byte{0x12, 0x35, 0xAB, 0xFF} + for i := 0; i < len(b); i++ { + _, err := BigEndian.ReadUint32(bytes.NewReader(b[:i])) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("WriteUint16", func() { + It("outputs 2 bytes", func() { + b := &bytes.Buffer{} + BigEndian.WriteUint16(b, uint16(1)) + Expect(b.Len()).To(Equal(2)) + }) + + It("outputs a big endian", func() { + num := uint16(0xFF11) + b := &bytes.Buffer{} + BigEndian.WriteUint16(b, num) + Expect(b.Bytes()).To(Equal([]byte{0xFF, 0x11})) + }) + }) + + Context("WriteUint24", func() { + It("outputs 3 bytes", func() { + b := &bytes.Buffer{} + BigEndian.WriteUint24(b, uint32(1)) + Expect(b.Len()).To(Equal(3)) + }) + + It("outputs a big endian", func() { + num := uint32(0xff11aa) + b := &bytes.Buffer{} + BigEndian.WriteUint24(b, num) + Expect(b.Bytes()).To(Equal([]byte{0xff, 0x11, 0xaa})) + }) + }) + + Context("WriteUint32", func() { + It("outputs 4 bytes", func() { + b := &bytes.Buffer{} + BigEndian.WriteUint32(b, uint32(1)) + Expect(b.Len()).To(Equal(4)) + }) + + It("outputs a big endian", func() { + num := uint32(0xEFAC3512) + b := &bytes.Buffer{} + BigEndian.WriteUint32(b, num) + Expect(b.Bytes()).To(Equal([]byte{0xEF, 0xAC, 0x35, 0x12})) + }) + }) +}) diff --git a/internal/quic-go/utils/byteorder.go b/internal/quic-go/utils/byteorder.go new file mode 100644 index 00000000..d1f52842 --- /dev/null +++ b/internal/quic-go/utils/byteorder.go @@ -0,0 +1,17 @@ +package utils + +import ( + "bytes" + "io" +) + +// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. +type ByteOrder interface { + ReadUint32(io.ByteReader) (uint32, error) + ReadUint24(io.ByteReader) (uint32, error) + ReadUint16(io.ByteReader) (uint16, error) + + WriteUint32(*bytes.Buffer, uint32) + WriteUint24(*bytes.Buffer, uint32) + WriteUint16(*bytes.Buffer, uint16) +} diff --git a/internal/quic-go/utils/byteorder_big_endian.go b/internal/quic-go/utils/byteorder_big_endian.go new file mode 100644 index 00000000..d05542e1 --- /dev/null +++ b/internal/quic-go/utils/byteorder_big_endian.go @@ -0,0 +1,89 @@ +package utils + +import ( + "bytes" + "io" +) + +// BigEndian is the big-endian implementation of ByteOrder. +var BigEndian ByteOrder = bigEndian{} + +type bigEndian struct{} + +var _ ByteOrder = &bigEndian{} + +// ReadUintN reads N bytes +func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) { + var res uint64 + for i := uint8(0); i < length; i++ { + bt, err := b.ReadByte() + if err != nil { + return 0, err + } + res ^= uint64(bt) << ((length - 1 - i) * 8) + } + return res, nil +} + +// ReadUint32 reads a uint32 +func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) { + var b1, b2, b3, b4 uint8 + var err error + if b4, err = b.ReadByte(); err != nil { + return 0, err + } + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil +} + +// ReadUint24 reads a uint24 +func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) { + var b1, b2, b3 uint8 + var err error + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16, nil +} + +// ReadUint16 reads a uint16 +func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) { + var b1, b2 uint8 + var err error + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint16(b1) + uint16(b2)<<8, nil +} + +// WriteUint32 writes a uint32 +func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) { + b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)}) +} + +// WriteUint24 writes a uint24 +func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) { + b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)}) +} + +// WriteUint16 writes a uint16 +func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) { + b.Write([]byte{uint8(i >> 8), uint8(i)}) +} diff --git a/internal/utils/gen.go b/internal/quic-go/utils/gen.go similarity index 100% rename from internal/utils/gen.go rename to internal/quic-go/utils/gen.go diff --git a/internal/quic-go/utils/ip.go b/internal/quic-go/utils/ip.go new file mode 100644 index 00000000..7ac7ffec --- /dev/null +++ b/internal/quic-go/utils/ip.go @@ -0,0 +1,10 @@ +package utils + +import "net" + +func IsIPv4(ip net.IP) bool { + // If ip is not an IPv4 address, To4 returns nil. + // Note that there might be some corner cases, where this is not correct. + // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6. + return ip.To4() != nil +} diff --git a/internal/quic-go/utils/ip_test.go b/internal/quic-go/utils/ip_test.go new file mode 100644 index 00000000..b61cf529 --- /dev/null +++ b/internal/quic-go/utils/ip_test.go @@ -0,0 +1,17 @@ +package utils + +import ( + "net" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("IP", func() { + It("tells IPv4 and IPv6 addresses apart", func() { + Expect(IsIPv4(net.IPv4(127, 0, 0, 1))).To(BeTrue()) + Expect(IsIPv4(net.IPv4zero)).To(BeTrue()) + Expect(IsIPv4(net.IPv6zero)).To(BeFalse()) + Expect(IsIPv4(net.IPv6loopback)).To(BeFalse()) + }) +}) diff --git a/internal/quic-go/utils/linkedlist/README.md b/internal/quic-go/utils/linkedlist/README.md new file mode 100644 index 00000000..15b46dce --- /dev/null +++ b/internal/quic-go/utils/linkedlist/README.md @@ -0,0 +1,11 @@ +# Usage + +This is the Go standard library implementation of a linked list +(https://golang.org/src/container/list/list.go), modified such that genny +(https://github.com/cheekybits/genny) can be used to generate a typed linked +list. + +To generate, run +``` +genny -pkg $PACKAGE -in linkedlist.go -out $OUTFILE gen Item=$TYPE +``` diff --git a/internal/quic-go/utils/linkedlist/linkedlist.go b/internal/quic-go/utils/linkedlist/linkedlist.go new file mode 100644 index 00000000..74b815a8 --- /dev/null +++ b/internal/quic-go/utils/linkedlist/linkedlist.go @@ -0,0 +1,218 @@ +package linkedlist + +import "github.com/cheekybits/genny/generic" + +// Linked list implementation from the Go standard library. + +// Item is a generic type. +type Item generic.Type + +// ItemElement is an element of a linked list. +type ItemElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *ItemElement + + // The list to which this element belongs. + list *ItemList + + // The value stored with this element. + Value Item +} + +// Next returns the next list element or nil. +func (e *ItemElement) Next() *ItemElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *ItemElement) Prev() *ItemElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// ItemList is a linked list of Items. +type ItemList struct { + root ItemElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *ItemList) Init() *ItemList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewItemList returns an initialized list. +func NewItemList() *ItemList { return new(ItemList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *ItemList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *ItemList) Front() *ItemElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *ItemList) Back() *ItemElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *ItemList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *ItemList) insert(e, at *ItemElement) *ItemElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *ItemList) insertValue(v Item, at *ItemElement) *ItemElement { + return l.insert(&ItemElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *ItemList) remove(e *ItemElement) *ItemElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *ItemList) Remove(e *ItemElement) Item { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *ItemList) PushFront(v Item) *ItemElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *ItemList) PushBack(v Item) *ItemElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *ItemList) InsertBefore(v Item, mark *ItemElement) *ItemElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *ItemList) InsertAfter(v Item, mark *ItemElement) *ItemElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *ItemList) MoveToFront(e *ItemElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *ItemList) MoveToBack(e *ItemElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *ItemList) MoveBefore(e, mark *ItemElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *ItemList) MoveAfter(e, mark *ItemElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *ItemList) PushBackList(other *ItemList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *ItemList) PushFrontList(other *ItemList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/internal/utils/log.go b/internal/quic-go/utils/log.go similarity index 97% rename from internal/utils/log.go rename to internal/quic-go/utils/log.go index e27f01b4..e04ee514 100644 --- a/internal/utils/log.go +++ b/internal/quic-go/utils/log.go @@ -125,7 +125,7 @@ func readLoggingEnv() LogLevel { case "error": return LogLevelError default: - fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging") + fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/imroc/req/v3/internal/quic-go/wiki/Logging") return LogLevelNothing } } diff --git a/internal/utils/log_test.go b/internal/quic-go/utils/log_test.go similarity index 100% rename from internal/utils/log_test.go rename to internal/quic-go/utils/log_test.go diff --git a/internal/quic-go/utils/minmax.go b/internal/quic-go/utils/minmax.go new file mode 100644 index 00000000..1a3448d8 --- /dev/null +++ b/internal/quic-go/utils/minmax.go @@ -0,0 +1,170 @@ +package utils + +import ( + "math" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// InfDuration is a duration of infinite length +const InfDuration = time.Duration(math.MaxInt64) + +// Max returns the maximum of two Ints +func Max(a, b int) int { + if a < b { + return b + } + return a +} + +// MaxUint32 returns the maximum of two uint32 +func MaxUint32(a, b uint32) uint32 { + if a < b { + return b + } + return a +} + +// MaxUint64 returns the maximum of two uint64 +func MaxUint64(a, b uint64) uint64 { + if a < b { + return b + } + return a +} + +// MinUint64 returns the maximum of two uint64 +func MinUint64(a, b uint64) uint64 { + if a < b { + return a + } + return b +} + +// Min returns the minimum of two Ints +func Min(a, b int) int { + if a < b { + return a + } + return b +} + +// MinUint32 returns the maximum of two uint32 +func MinUint32(a, b uint32) uint32 { + if a < b { + return a + } + return b +} + +// MinInt64 returns the minimum of two int64 +func MinInt64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +// MaxInt64 returns the minimum of two int64 +func MaxInt64(a, b int64) int64 { + if a > b { + return a + } + return b +} + +// MinByteCount returns the minimum of two ByteCounts +func MinByteCount(a, b protocol.ByteCount) protocol.ByteCount { + if a < b { + return a + } + return b +} + +// MaxByteCount returns the maximum of two ByteCounts +func MaxByteCount(a, b protocol.ByteCount) protocol.ByteCount { + if a < b { + return b + } + return a +} + +// MaxDuration returns the max duration +func MaxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} + +// MinDuration returns the minimum duration +func MinDuration(a, b time.Duration) time.Duration { + if a > b { + return b + } + return a +} + +// MinNonZeroDuration return the minimum duration that's not zero. +func MinNonZeroDuration(a, b time.Duration) time.Duration { + if a == 0 { + return b + } + if b == 0 { + return a + } + return MinDuration(a, b) +} + +// AbsDuration returns the absolute value of a time duration +func AbsDuration(d time.Duration) time.Duration { + if d >= 0 { + return d + } + return -d +} + +// MinTime returns the earlier time +func MinTime(a, b time.Time) time.Time { + if a.After(b) { + return b + } + return a +} + +// MinNonZeroTime returns the earlist time that is not time.Time{} +// If both a and b are time.Time{}, it returns time.Time{} +func MinNonZeroTime(a, b time.Time) time.Time { + if a.IsZero() { + return b + } + if b.IsZero() { + return a + } + return MinTime(a, b) +} + +// MaxTime returns the later time +func MaxTime(a, b time.Time) time.Time { + if a.After(b) { + return a + } + return b +} + +// MaxPacketNumber returns the max packet number +func MaxPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { + if a > b { + return a + } + return b +} + +// MinPacketNumber returns the min packet number +func MinPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { + if a < b { + return a + } + return b +} diff --git a/internal/quic-go/utils/minmax_test.go b/internal/quic-go/utils/minmax_test.go new file mode 100644 index 00000000..5ee0caa5 --- /dev/null +++ b/internal/quic-go/utils/minmax_test.go @@ -0,0 +1,123 @@ +package utils + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Min / Max", func() { + Context("Max", func() { + It("returns the maximum", func() { + Expect(Max(5, 7)).To(Equal(7)) + Expect(Max(7, 5)).To(Equal(7)) + }) + + It("returns the maximum uint32", func() { + Expect(MaxUint32(5, 7)).To(Equal(uint32(7))) + Expect(MaxUint32(7, 5)).To(Equal(uint32(7))) + }) + + It("returns the maximum uint64", func() { + Expect(MaxUint64(5, 7)).To(Equal(uint64(7))) + Expect(MaxUint64(7, 5)).To(Equal(uint64(7))) + }) + + It("returns the minimum uint64", func() { + Expect(MinUint64(5, 7)).To(Equal(uint64(5))) + Expect(MinUint64(7, 5)).To(Equal(uint64(5))) + }) + + It("returns the maximum int64", func() { + Expect(MaxInt64(5, 7)).To(Equal(int64(7))) + Expect(MaxInt64(7, 5)).To(Equal(int64(7))) + }) + + It("returns the maximum ByteCount", func() { + Expect(MaxByteCount(7, 5)).To(Equal(protocol.ByteCount(7))) + Expect(MaxByteCount(5, 7)).To(Equal(protocol.ByteCount(7))) + }) + + It("returns the maximum duration", func() { + Expect(MaxDuration(time.Microsecond, time.Nanosecond)).To(Equal(time.Microsecond)) + Expect(MaxDuration(time.Nanosecond, time.Microsecond)).To(Equal(time.Microsecond)) + }) + + It("returns the minimum duration", func() { + Expect(MinDuration(time.Microsecond, time.Nanosecond)).To(Equal(time.Nanosecond)) + Expect(MinDuration(time.Nanosecond, time.Microsecond)).To(Equal(time.Nanosecond)) + }) + + It("returns packet number max", func() { + Expect(MaxPacketNumber(1, 2)).To(Equal(protocol.PacketNumber(2))) + Expect(MaxPacketNumber(2, 1)).To(Equal(protocol.PacketNumber(2))) + }) + + It("returns the maximum time", func() { + a := time.Now() + b := a.Add(time.Second) + Expect(MaxTime(a, b)).To(Equal(b)) + Expect(MaxTime(b, a)).To(Equal(b)) + }) + }) + + Context("Min", func() { + It("returns the minimum", func() { + Expect(Min(5, 7)).To(Equal(5)) + Expect(Min(7, 5)).To(Equal(5)) + }) + + It("returns the minimum uint32", func() { + Expect(MinUint32(7, 5)).To(Equal(uint32(5))) + Expect(MinUint32(5, 7)).To(Equal(uint32(5))) + }) + + It("returns the minimum int64", func() { + Expect(MinInt64(7, 5)).To(Equal(int64(5))) + Expect(MinInt64(5, 7)).To(Equal(int64(5))) + }) + + It("returns the minimum ByteCount", func() { + Expect(MinByteCount(7, 5)).To(Equal(protocol.ByteCount(5))) + Expect(MinByteCount(5, 7)).To(Equal(protocol.ByteCount(5))) + }) + + It("returns packet number min", func() { + Expect(MinPacketNumber(1, 2)).To(Equal(protocol.PacketNumber(1))) + Expect(MinPacketNumber(2, 1)).To(Equal(protocol.PacketNumber(1))) + }) + + It("returns the minimum duration", func() { + a := time.Now() + b := a.Add(time.Second) + Expect(MinTime(a, b)).To(Equal(a)) + Expect(MinTime(b, a)).To(Equal(a)) + }) + + It("returns the minium non-zero duration", func() { + var a time.Duration + b := time.Second + Expect(MinNonZeroDuration(0, 0)).To(BeZero()) + Expect(MinNonZeroDuration(a, b)).To(Equal(b)) + Expect(MinNonZeroDuration(b, a)).To(Equal(b)) + Expect(MinNonZeroDuration(time.Minute, time.Hour)).To(Equal(time.Minute)) + }) + + It("returns the minium non-zero time", func() { + a := time.Time{} + b := time.Now() + Expect(MinNonZeroTime(time.Time{}, time.Time{})).To(Equal(time.Time{})) + Expect(MinNonZeroTime(a, b)).To(Equal(b)) + Expect(MinNonZeroTime(b, a)).To(Equal(b)) + Expect(MinNonZeroTime(b, b.Add(time.Second))).To(Equal(b)) + Expect(MinNonZeroTime(b.Add(time.Second), b)).To(Equal(b)) + }) + }) + + It("returns the abs time", func() { + Expect(AbsDuration(time.Microsecond)).To(Equal(time.Microsecond)) + Expect(AbsDuration(-time.Microsecond)).To(Equal(time.Microsecond)) + }) +}) diff --git a/internal/quic-go/utils/new_connection_id.go b/internal/quic-go/utils/new_connection_id.go new file mode 100644 index 00000000..41b7484f --- /dev/null +++ b/internal/quic-go/utils/new_connection_id.go @@ -0,0 +1,12 @@ +package utils + +import ( + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// NewConnectionID is a new connection ID +type NewConnectionID struct { + SequenceNumber uint64 + ConnectionID protocol.ConnectionID + StatelessResetToken protocol.StatelessResetToken +} diff --git a/internal/quic-go/utils/newconnectionid_linkedlist.go b/internal/quic-go/utils/newconnectionid_linkedlist.go new file mode 100644 index 00000000..d59562e5 --- /dev/null +++ b/internal/quic-go/utils/newconnectionid_linkedlist.go @@ -0,0 +1,217 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package utils + +// Linked list implementation from the Go standard library. + +// NewConnectionIDElement is an element of a linked list. +type NewConnectionIDElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *NewConnectionIDElement + + // The list to which this element belongs. + list *NewConnectionIDList + + // The value stored with this element. + Value NewConnectionID +} + +// Next returns the next list element or nil. +func (e *NewConnectionIDElement) Next() *NewConnectionIDElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *NewConnectionIDElement) Prev() *NewConnectionIDElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// NewConnectionIDList is a linked list of NewConnectionIDs. +type NewConnectionIDList struct { + root NewConnectionIDElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *NewConnectionIDList) Init() *NewConnectionIDList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewNewConnectionIDList returns an initialized list. +func NewNewConnectionIDList() *NewConnectionIDList { return new(NewConnectionIDList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *NewConnectionIDList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *NewConnectionIDList) Front() *NewConnectionIDElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *NewConnectionIDList) Back() *NewConnectionIDElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *NewConnectionIDList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *NewConnectionIDList) insert(e, at *NewConnectionIDElement) *NewConnectionIDElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *NewConnectionIDList) insertValue(v NewConnectionID, at *NewConnectionIDElement) *NewConnectionIDElement { + return l.insert(&NewConnectionIDElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *NewConnectionIDList) remove(e *NewConnectionIDElement) *NewConnectionIDElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *NewConnectionIDList) Remove(e *NewConnectionIDElement) NewConnectionID { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *NewConnectionIDList) PushFront(v NewConnectionID) *NewConnectionIDElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *NewConnectionIDList) PushBack(v NewConnectionID) *NewConnectionIDElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *NewConnectionIDList) InsertBefore(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *NewConnectionIDList) InsertAfter(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *NewConnectionIDList) MoveToFront(e *NewConnectionIDElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *NewConnectionIDList) MoveToBack(e *NewConnectionIDElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *NewConnectionIDList) MoveBefore(e, mark *NewConnectionIDElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *NewConnectionIDList) MoveAfter(e, mark *NewConnectionIDElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *NewConnectionIDList) PushBackList(other *NewConnectionIDList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *NewConnectionIDList) PushFrontList(other *NewConnectionIDList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/internal/quic-go/utils/packet_interval.go b/internal/quic-go/utils/packet_interval.go new file mode 100644 index 00000000..5518141d --- /dev/null +++ b/internal/quic-go/utils/packet_interval.go @@ -0,0 +1,9 @@ +package utils + +import "github.com/imroc/req/v3/internal/quic-go/protocol" + +// PacketInterval is an interval from one PacketNumber to the other +type PacketInterval struct { + Start protocol.PacketNumber + End protocol.PacketNumber +} diff --git a/internal/quic-go/utils/packetinterval_linkedlist.go b/internal/quic-go/utils/packetinterval_linkedlist.go new file mode 100644 index 00000000..b461e85a --- /dev/null +++ b/internal/quic-go/utils/packetinterval_linkedlist.go @@ -0,0 +1,217 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package utils + +// Linked list implementation from the Go standard library. + +// PacketIntervalElement is an element of a linked list. +type PacketIntervalElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *PacketIntervalElement + + // The list to which this element belongs. + list *PacketIntervalList + + // The value stored with this element. + Value PacketInterval +} + +// Next returns the next list element or nil. +func (e *PacketIntervalElement) Next() *PacketIntervalElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *PacketIntervalElement) Prev() *PacketIntervalElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// PacketIntervalList is a linked list of PacketIntervals. +type PacketIntervalList struct { + root PacketIntervalElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *PacketIntervalList) Init() *PacketIntervalList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewPacketIntervalList returns an initialized list. +func NewPacketIntervalList() *PacketIntervalList { return new(PacketIntervalList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *PacketIntervalList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *PacketIntervalList) Front() *PacketIntervalElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *PacketIntervalList) Back() *PacketIntervalElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *PacketIntervalList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *PacketIntervalList) insert(e, at *PacketIntervalElement) *PacketIntervalElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *PacketIntervalList) insertValue(v PacketInterval, at *PacketIntervalElement) *PacketIntervalElement { + return l.insert(&PacketIntervalElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *PacketIntervalList) remove(e *PacketIntervalElement) *PacketIntervalElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *PacketIntervalList) Remove(e *PacketIntervalElement) PacketInterval { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *PacketIntervalList) PushFront(v PacketInterval) *PacketIntervalElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *PacketIntervalList) PushBack(v PacketInterval) *PacketIntervalElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *PacketIntervalList) InsertBefore(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *PacketIntervalList) InsertAfter(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *PacketIntervalList) MoveToFront(e *PacketIntervalElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *PacketIntervalList) MoveToBack(e *PacketIntervalElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *PacketIntervalList) PushFrontList(other *PacketIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/internal/quic-go/utils/rand.go b/internal/quic-go/utils/rand.go new file mode 100644 index 00000000..30069144 --- /dev/null +++ b/internal/quic-go/utils/rand.go @@ -0,0 +1,29 @@ +package utils + +import ( + "crypto/rand" + "encoding/binary" +) + +// Rand is a wrapper around crypto/rand that adds some convenience functions known from math/rand. +type Rand struct { + buf [4]byte +} + +func (r *Rand) Int31() int32 { + rand.Read(r.buf[:]) + return int32(binary.BigEndian.Uint32(r.buf[:]) & ^uint32(1<<31)) +} + +// copied from the standard library math/rand implementation of Int63n +func (r *Rand) Int31n(n int32) int32 { + if n&(n-1) == 0 { // n is power of two, can mask + return r.Int31() & (n - 1) + } + max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) + v := r.Int31() + for v > max { + v = r.Int31() + } + return v % n +} diff --git a/internal/quic-go/utils/rand_test.go b/internal/quic-go/utils/rand_test.go new file mode 100644 index 00000000..f15a644e --- /dev/null +++ b/internal/quic-go/utils/rand_test.go @@ -0,0 +1,32 @@ +package utils + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Rand", func() { + It("generates random numbers", func() { + const ( + num = 1000 + max = 12345678 + ) + + var values [num]int32 + var r Rand + for i := 0; i < num; i++ { + v := r.Int31n(max) + Expect(v).To(And( + BeNumerically(">=", 0), + BeNumerically("<", max), + )) + values[i] = v + } + + var sum uint64 + for _, n := range values { + sum += uint64(n) + } + Expect(float64(sum) / num).To(BeNumerically("~", max/2, max/25)) + }) +}) diff --git a/internal/quic-go/utils/rtt_stats.go b/internal/quic-go/utils/rtt_stats.go new file mode 100644 index 00000000..75bfc6d3 --- /dev/null +++ b/internal/quic-go/utils/rtt_stats.go @@ -0,0 +1,127 @@ +package utils + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +const ( + rttAlpha = 0.125 + oneMinusAlpha = 1 - rttAlpha + rttBeta = 0.25 + oneMinusBeta = 1 - rttBeta + // The default RTT used before an RTT sample is taken. + defaultInitialRTT = 100 * time.Millisecond +) + +// RTTStats provides round-trip statistics +type RTTStats struct { + hasMeasurement bool + + minRTT time.Duration + latestRTT time.Duration + smoothedRTT time.Duration + meanDeviation time.Duration + + maxAckDelay time.Duration +} + +// NewRTTStats makes a properly initialized RTTStats object +func NewRTTStats() *RTTStats { + return &RTTStats{} +} + +// MinRTT Returns the minRTT for the entire connection. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) MinRTT() time.Duration { return r.minRTT } + +// LatestRTT returns the most recent rtt measurement. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT } + +// SmoothedRTT returns the smoothed RTT for the connection. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT } + +// MeanDeviation gets the mean deviation +func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } + +// MaxAckDelay gets the max_ack_delay advertised by the peer +func (r *RTTStats) MaxAckDelay() time.Duration { return r.maxAckDelay } + +// PTO gets the probe timeout duration. +func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { + if r.SmoothedRTT() == 0 { + return 2 * defaultInitialRTT + } + pto := r.SmoothedRTT() + MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) + if includeMaxAckDelay { + pto += r.MaxAckDelay() + } + return pto +} + +// UpdateRTT updates the RTT based on a new sample. +func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { + if sendDelta == InfDuration || sendDelta <= 0 { + return + } + + // Update r.minRTT first. r.minRTT does not use an rttSample corrected for + // ackDelay but the raw observed sendDelta, since poor clock granularity at + // the client may cause a high ackDelay to result in underestimation of the + // r.minRTT. + if r.minRTT == 0 || r.minRTT > sendDelta { + r.minRTT = sendDelta + } + + // Correct for ackDelay if information received from the peer results in a + // an RTT sample at least as large as minRTT. Otherwise, only use the + // sendDelta. + sample := sendDelta + if sample-r.minRTT >= ackDelay { + sample -= ackDelay + } + r.latestRTT = sample + // First time call. + if !r.hasMeasurement { + r.hasMeasurement = true + r.smoothedRTT = sample + r.meanDeviation = sample / 2 + } else { + r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond + r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond + } +} + +// SetMaxAckDelay sets the max_ack_delay +func (r *RTTStats) SetMaxAckDelay(mad time.Duration) { + r.maxAckDelay = mad +} + +// SetInitialRTT sets the initial RTT. +// It is used during the 0-RTT handshake when restoring the RTT stats from the session state. +func (r *RTTStats) SetInitialRTT(t time.Duration) { + if r.hasMeasurement { + panic("initial RTT set after first measurement") + } + r.smoothedRTT = t + r.latestRTT = t +} + +// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset. +func (r *RTTStats) OnConnectionMigration() { + r.latestRTT = 0 + r.minRTT = 0 + r.smoothedRTT = 0 + r.meanDeviation = 0 +} + +// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt +// is larger. The mean deviation is increased to the most recent deviation if +// it's larger. +func (r *RTTStats) ExpireSmoothedMetrics() { + r.meanDeviation = MaxDuration(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT)) + r.smoothedRTT = MaxDuration(r.smoothedRTT, r.latestRTT) +} diff --git a/internal/quic-go/utils/rtt_stats_test.go b/internal/quic-go/utils/rtt_stats_test.go new file mode 100644 index 00000000..a0de1b93 --- /dev/null +++ b/internal/quic-go/utils/rtt_stats_test.go @@ -0,0 +1,157 @@ +package utils + +import ( + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("RTT stats", func() { + var rttStats *RTTStats + + BeforeEach(func() { + rttStats = NewRTTStats() + }) + + It("DefaultsBeforeUpdate", func() { + Expect(rttStats.MinRTT()).To(Equal(time.Duration(0))) + Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0))) + }) + + It("SmoothedRTT", func() { + // Verify that ack_delay is ignored in the first measurement. + rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond))) + Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond))) + // Verify that Smoothed RTT includes max ack delay if it's reasonable. + rttStats.UpdateRTT((350 * time.Millisecond), (50 * time.Millisecond), time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond))) + Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond))) + // Verify that large erroneous ack_delay does not change Smoothed RTT. + rttStats.UpdateRTT((200 * time.Millisecond), (300 * time.Millisecond), time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) + Expect(rttStats.SmoothedRTT()).To(Equal((287500 * time.Microsecond))) + }) + + It("MinRTT", func() { + rttStats.UpdateRTT((200 * time.Millisecond), 0, time.Time{}) + Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) + rttStats.UpdateRTT((10 * time.Millisecond), 0, time.Time{}.Add((10 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) + rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((20 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) + rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((30 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) + rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((40 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) + // Verify that ack_delay does not go into recording of MinRTT_. + rttStats.UpdateRTT((7 * time.Millisecond), (2 * time.Millisecond), time.Time{}.Add((50 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((7 * time.Millisecond))) + }) + + It("MaxAckDelay", func() { + rttStats.SetMaxAckDelay(42 * time.Minute) + Expect(rttStats.MaxAckDelay()).To(Equal(42 * time.Minute)) + }) + + It("computes the PTO", func() { + maxAckDelay := 42 * time.Minute + rttStats.SetMaxAckDelay(maxAckDelay) + rtt := time.Second + rttStats.UpdateRTT(rtt, 0, time.Time{}) + Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) + Expect(rttStats.MeanDeviation()).To(Equal(rtt / 2)) + Expect(rttStats.PTO(false)).To(Equal(rtt + 4*(rtt/2))) + Expect(rttStats.PTO(true)).To(Equal(rtt + 4*(rtt/2) + maxAckDelay)) + }) + + It("uses the granularity for computing the PTO for short RTTs", func() { + rtt := time.Microsecond + rttStats.UpdateRTT(rtt, 0, time.Time{}) + Expect(rttStats.PTO(true)).To(Equal(rtt + protocol.TimerGranularity)) + }) + + It("ExpireSmoothedMetrics", func() { + initialRtt := (10 * time.Millisecond) + rttStats.UpdateRTT(initialRtt, 0, time.Time{}) + Expect(rttStats.MinRTT()).To(Equal(initialRtt)) + Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) + + Expect(rttStats.MeanDeviation()).To(Equal(initialRtt / 2)) + + // Update once with a 20ms RTT. + doubledRtt := initialRtt * (2) + rttStats.UpdateRTT(doubledRtt, 0, time.Time{}) + Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(float32(initialRtt) * 1.125))) + + // Expire the smoothed metrics, increasing smoothed rtt and mean deviation. + rttStats.ExpireSmoothedMetrics() + Expect(rttStats.SmoothedRTT()).To(Equal(doubledRtt)) + Expect(rttStats.MeanDeviation()).To(Equal(time.Duration(float32(initialRtt) * 0.875))) + + // Now go back down to 5ms and expire the smoothed metrics, and ensure the + // mean deviation increases to 15ms. + halfRtt := initialRtt / 2 + rttStats.UpdateRTT(halfRtt, 0, time.Time{}) + Expect(doubledRtt).To(BeNumerically(">", rttStats.SmoothedRTT())) + Expect(initialRtt).To(BeNumerically("<", rttStats.MeanDeviation())) + }) + + It("UpdateRTTWithBadSendDeltas", func() { + // Make sure we ignore bad RTTs. + // base::test::MockLog log; + + initialRtt := (10 * time.Millisecond) + rttStats.UpdateRTT(initialRtt, 0, time.Time{}) + Expect(rttStats.MinRTT()).To(Equal(initialRtt)) + Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) + + badSendDeltas := []time.Duration{ + 0, + InfDuration, + -1000 * time.Microsecond, + } + // log.StartCapturingLogs(); + + for _, badSendDelta := range badSendDeltas { + // SCOPED_TRACE(Message() << "bad_send_delta = " + // << bad_send_delta.ToMicroseconds()); + // EXPECT_CALL(log, Log(LOG_WARNING, _, _, _, HasSubstr("Ignoring"))); + rttStats.UpdateRTT(badSendDelta, 0, time.Time{}) + Expect(rttStats.MinRTT()).To(Equal(initialRtt)) + Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) + } + }) + + It("ResetAfterConnectionMigrations", func() { + rttStats.UpdateRTT(200*time.Millisecond, 0, time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) + Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) + rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) + Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond))) + Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) + + // Reset rtt stats on connection migrations. + rttStats.OnConnectionMigration() + Expect(rttStats.LatestRTT()).To(Equal(time.Duration(0))) + Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0))) + Expect(rttStats.MinRTT()).To(Equal(time.Duration(0))) + }) + + It("restores the RTT", func() { + rttStats.SetInitialRTT(10 * time.Second) + Expect(rttStats.LatestRTT()).To(Equal(10 * time.Second)) + Expect(rttStats.SmoothedRTT()).To(Equal(10 * time.Second)) + Expect(rttStats.MeanDeviation()).To(BeZero()) + // update the RTT and make sure that the initial value is immediately forgotten + rttStats.UpdateRTT(200*time.Millisecond, 0, time.Time{}) + Expect(rttStats.LatestRTT()).To(Equal(200 * time.Millisecond)) + Expect(rttStats.SmoothedRTT()).To(Equal(200 * time.Millisecond)) + Expect(rttStats.MeanDeviation()).To(Equal(100 * time.Millisecond)) + }) +}) diff --git a/internal/quic-go/utils/streamframe_interval.go b/internal/quic-go/utils/streamframe_interval.go new file mode 100644 index 00000000..4efbd64c --- /dev/null +++ b/internal/quic-go/utils/streamframe_interval.go @@ -0,0 +1,9 @@ +package utils + +import "github.com/imroc/req/v3/internal/quic-go/protocol" + +// ByteInterval is an interval from one ByteCount to the other +type ByteInterval struct { + Start protocol.ByteCount + End protocol.ByteCount +} diff --git a/internal/quic-go/utils/timer.go b/internal/quic-go/utils/timer.go new file mode 100644 index 00000000..a4f5e67a --- /dev/null +++ b/internal/quic-go/utils/timer.go @@ -0,0 +1,53 @@ +package utils + +import ( + "math" + "time" +) + +// A Timer wrapper that behaves correctly when resetting +type Timer struct { + t *time.Timer + read bool + deadline time.Time +} + +// NewTimer creates a new timer that is not set +func NewTimer() *Timer { + return &Timer{t: time.NewTimer(time.Duration(math.MaxInt64))} +} + +// Chan returns the channel of the wrapped timer +func (t *Timer) Chan() <-chan time.Time { + return t.t.C +} + +// Reset the timer, no matter whether the value was read or not +func (t *Timer) Reset(deadline time.Time) { + if deadline.Equal(t.deadline) && !t.read { + // No need to reset the timer + return + } + + // We need to drain the timer if the value from its channel was not read yet. + // See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU + if !t.t.Stop() && !t.read { + <-t.t.C + } + if !deadline.IsZero() { + t.t.Reset(time.Until(deadline)) + } + + t.read = false + t.deadline = deadline +} + +// SetRead should be called after the value from the chan was read +func (t *Timer) SetRead() { + t.read = true +} + +// Stop stops the timer +func (t *Timer) Stop() { + t.t.Stop() +} diff --git a/internal/quic-go/utils/timer_test.go b/internal/quic-go/utils/timer_test.go new file mode 100644 index 00000000..0cbb4a01 --- /dev/null +++ b/internal/quic-go/utils/timer_test.go @@ -0,0 +1,87 @@ +package utils + +import ( + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Timer", func() { + const d = 10 * time.Millisecond + + It("doesn't fire a newly created timer", func() { + t := NewTimer() + Consistently(t.Chan()).ShouldNot(Receive()) + }) + + It("works", func() { + t := NewTimer() + t.Reset(time.Now().Add(d)) + Eventually(t.Chan()).Should(Receive()) + }) + + It("works multiple times with reading", func() { + t := NewTimer() + for i := 0; i < 10; i++ { + t.Reset(time.Now().Add(d)) + Eventually(t.Chan()).Should(Receive()) + t.SetRead() + } + }) + + It("works multiple times without reading", func() { + t := NewTimer() + for i := 0; i < 10; i++ { + t.Reset(time.Now().Add(d)) + time.Sleep(d * 2) + } + Eventually(t.Chan()).Should(Receive()) + }) + + It("works when resetting without expiration", func() { + t := NewTimer() + for i := 0; i < 10; i++ { + t.Reset(time.Now().Add(time.Hour)) + } + t.Reset(time.Now().Add(d)) + Eventually(t.Chan()).Should(Receive()) + }) + + It("immediately fires the timer, if the deadlines has already passed", func() { + t := NewTimer() + t.Reset(time.Now().Add(-time.Second)) + Eventually(t.Chan()).Should(Receive()) + }) + + It("doesn't set a timer if the deadline is the zero value", func() { + t := NewTimer() + t.Reset(time.Time{}) + Consistently(t.Chan()).ShouldNot(Receive()) + }) + + It("fires the timer twice, if reset to the same deadline", func() { + deadline := time.Now().Add(-time.Millisecond) + t := NewTimer() + t.Reset(deadline) + Eventually(t.Chan()).Should(Receive()) + t.SetRead() + t.Reset(deadline) + Eventually(t.Chan()).Should(Receive()) + }) + + It("only fires the timer once, if it is reset to the same deadline, but not read in between", func() { + deadline := time.Now().Add(-time.Millisecond) + t := NewTimer() + t.Reset(deadline) + Eventually(t.Chan()).Should(Receive()) + Consistently(t.Chan()).ShouldNot(Receive()) + }) + + It("stops", func() { + t := NewTimer() + t.Reset(time.Now().Add(50 * time.Millisecond)) + t.Stop() + Consistently(t.Chan()).ShouldNot(Receive()) + }) +}) diff --git a/internal/utils/utils_suite_test.go b/internal/quic-go/utils/utils_suite_test.go similarity index 100% rename from internal/utils/utils_suite_test.go rename to internal/quic-go/utils/utils_suite_test.go diff --git a/internal/quic-go/window_update_queue.go b/internal/quic-go/window_update_queue.go new file mode 100644 index 00000000..67e4ac5f --- /dev/null +++ b/internal/quic-go/window_update_queue.go @@ -0,0 +1,71 @@ +package quic + +import ( + "sync" + + "github.com/imroc/req/v3/internal/quic-go/flowcontrol" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" +) + +type windowUpdateQueue struct { + mutex sync.Mutex + + queue map[protocol.StreamID]struct{} // used as a set + queuedConn bool // connection-level window update + + streamGetter streamGetter + connFlowController flowcontrol.ConnectionFlowController + callback func(wire.Frame) +} + +func newWindowUpdateQueue( + streamGetter streamGetter, + connFC flowcontrol.ConnectionFlowController, + cb func(wire.Frame), +) *windowUpdateQueue { + return &windowUpdateQueue{ + queue: make(map[protocol.StreamID]struct{}), + streamGetter: streamGetter, + connFlowController: connFC, + callback: cb, + } +} + +func (q *windowUpdateQueue) AddStream(id protocol.StreamID) { + q.mutex.Lock() + q.queue[id] = struct{}{} + q.mutex.Unlock() +} + +func (q *windowUpdateQueue) AddConnection() { + q.mutex.Lock() + q.queuedConn = true + q.mutex.Unlock() +} + +func (q *windowUpdateQueue) QueueAll() { + q.mutex.Lock() + // queue a connection-level window update + if q.queuedConn { + q.callback(&wire.MaxDataFrame{MaximumData: q.connFlowController.GetWindowUpdate()}) + q.queuedConn = false + } + // queue all stream-level window updates + for id := range q.queue { + delete(q.queue, id) + str, err := q.streamGetter.GetOrOpenReceiveStream(id) + if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update + continue + } + offset := str.getWindowUpdate() + if offset == 0 { // can happen if we received a final offset, right after queueing the window update + continue + } + q.callback(&wire.MaxStreamDataFrame{ + StreamID: id, + MaximumStreamData: offset, + }) + } + q.mutex.Unlock() +} diff --git a/internal/quic-go/window_update_queue_test.go b/internal/quic-go/window_update_queue_test.go new file mode 100644 index 00000000..bacefb23 --- /dev/null +++ b/internal/quic-go/window_update_queue_test.go @@ -0,0 +1,112 @@ +package quic + +import ( + "github.com/imroc/req/v3/internal/quic-go/mocks" + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Window Update Queue", func() { + var ( + q *windowUpdateQueue + streamGetter *MockStreamGetter + connFC *mocks.MockConnectionFlowController + queuedFrames []wire.Frame + ) + + BeforeEach(func() { + streamGetter = NewMockStreamGetter(mockCtrl) + connFC = mocks.NewMockConnectionFlowController(mockCtrl) + queuedFrames = queuedFrames[:0] + q = newWindowUpdateQueue(streamGetter, connFC, func(f wire.Frame) { + queuedFrames = append(queuedFrames, f) + }) + }) + + It("adds stream offsets and gets MAX_STREAM_DATA frames", func() { + stream1 := NewMockStreamI(mockCtrl) + stream1.EXPECT().getWindowUpdate().Return(protocol.ByteCount(10)) + stream3 := NewMockStreamI(mockCtrl) + stream3.EXPECT().getWindowUpdate().Return(protocol.ByteCount(30)) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(stream3, nil) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(1)).Return(stream1, nil) + q.AddStream(3) + q.AddStream(1) + q.QueueAll() + Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 1, MaximumStreamData: 10})) + Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 3, MaximumStreamData: 30})) + }) + + It("deletes the entry after getting the MAX_STREAM_DATA frame", func() { + stream10 := NewMockStreamI(mockCtrl) + stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(100)) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil) + q.AddStream(10) + q.QueueAll() + Expect(queuedFrames).To(HaveLen(1)) + q.QueueAll() + Expect(queuedFrames).To(HaveLen(1)) + }) + + It("doesn't queue a MAX_STREAM_DATA for a closed stream", func() { + q.AddStream(12) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(12)).Return(nil, nil) + q.QueueAll() + Expect(queuedFrames).To(BeEmpty()) + }) + + It("removes closed streams from the queue", func() { + q.AddStream(12) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(12)).Return(nil, nil) + q.QueueAll() + Expect(queuedFrames).To(BeEmpty()) + // don't EXPECT any further calls to GetOrOpenReceiveStream + q.QueueAll() + Expect(queuedFrames).To(BeEmpty()) + }) + + It("doesn't queue a MAX_STREAM_DATA if the flow controller returns an offset of 0", func() { + stream5 := NewMockStreamI(mockCtrl) + stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0)) + q.AddStream(5) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(stream5, nil) + q.QueueAll() + Expect(queuedFrames).To(BeEmpty()) + }) + + It("removes streams for which the flow controller returns an offset of 0 from the queue", func() { + stream5 := NewMockStreamI(mockCtrl) + stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0)) + q.AddStream(5) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(stream5, nil) + q.QueueAll() + Expect(queuedFrames).To(BeEmpty()) + // don't EXPECT any further calls to GetOrOpenReveiveStream and to getWindowUpdate + q.QueueAll() + Expect(queuedFrames).To(BeEmpty()) + }) + + It("queues MAX_DATA frames", func() { + connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1337)) + q.AddConnection() + q.QueueAll() + Expect(queuedFrames).To(Equal([]wire.Frame{ + &wire.MaxDataFrame{MaximumData: 0x1337}, + })) + }) + + It("deduplicates", func() { + stream10 := NewMockStreamI(mockCtrl) + stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(200)) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil) + q.AddStream(10) + q.AddStream(10) + q.QueueAll() + Expect(queuedFrames).To(Equal([]wire.Frame{ + &wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 200}, + })) + }) +}) diff --git a/internal/quic-go/wire/ack_frame.go b/internal/quic-go/wire/ack_frame.go new file mode 100644 index 00000000..68dc6aa8 --- /dev/null +++ b/internal/quic-go/wire/ack_frame.go @@ -0,0 +1,251 @@ +package wire + +import ( + "bytes" + "errors" + "sort" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") + +// An AckFrame is an ACK frame +type AckFrame struct { + AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last + DelayTime time.Duration + + ECT0, ECT1, ECNCE uint64 +} + +// parseAckFrame reads an ACK frame +func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + ecn := typeByte&0x1 > 0 + + frame := &AckFrame{} + + la, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + largestAcked := protocol.PacketNumber(la) + delay, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + delayTime := time.Duration(delay*1< largestAcked { + return nil, errors.New("invalid first ACK range") + } + smallest := largestAcked - ackBlock + + // read all the other ACK ranges + frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked}) + for i := uint64(0); i < numBlocks; i++ { + g, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + gap := protocol.PacketNumber(g) + if smallest < gap+2 { + return nil, errInvalidAckRanges + } + largest := smallest - gap - 2 + + ab, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + ackBlock := protocol.PacketNumber(ab) + + if ackBlock > largest { + return nil, errInvalidAckRanges + } + smallest = largest - ackBlock + frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) + } + + if !frame.validateAckRanges() { + return nil, errInvalidAckRanges + } + + // parse (and skip) the ECN section + if ecn { + for i := 0; i < 3; i++ { + if _, err := quicvarint.Read(r); err != nil { + return nil, err + } + } + } + + return frame, nil +} + +// Write writes an ACK frame. +func (f *AckFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 + if hasECN { + b.WriteByte(0x3) + } else { + b.WriteByte(0x2) + } + quicvarint.Write(b, uint64(f.LargestAcked())) + quicvarint.Write(b, encodeAckDelay(f.DelayTime)) + + numRanges := f.numEncodableAckRanges() + quicvarint.Write(b, uint64(numRanges-1)) + + // write the first range + _, firstRange := f.encodeAckRange(0) + quicvarint.Write(b, firstRange) + + // write all the other range + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + quicvarint.Write(b, gap) + quicvarint.Write(b, len) + } + + if hasECN { + quicvarint.Write(b, f.ECT0) + quicvarint.Write(b, f.ECT1) + quicvarint.Write(b, f.ECNCE) + } + return nil +} + +// Length of a written frame +func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + largestAcked := f.AckRanges[0].Largest + numRanges := f.numEncodableAckRanges() + + length := 1 + quicvarint.Len(uint64(largestAcked)) + quicvarint.Len(encodeAckDelay(f.DelayTime)) + + length += quicvarint.Len(uint64(numRanges - 1)) + lowestInFirstRange := f.AckRanges[0].Smallest + length += quicvarint.Len(uint64(largestAcked - lowestInFirstRange)) + + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + length += quicvarint.Len(gap) + length += quicvarint.Len(len) + } + if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 { + length += quicvarint.Len(f.ECT0) + length += quicvarint.Len(f.ECT1) + length += quicvarint.Len(f.ECNCE) + } + return length +} + +// gets the number of ACK ranges that can be encoded +// such that the resulting frame is smaller than the maximum ACK frame size +func (f *AckFrame) numEncodableAckRanges() int { + length := 1 + quicvarint.Len(uint64(f.LargestAcked())) + quicvarint.Len(encodeAckDelay(f.DelayTime)) + length += 2 // assume that the number of ranges will consume 2 bytes + for i := 1; i < len(f.AckRanges); i++ { + gap, len := f.encodeAckRange(i) + rangeLen := quicvarint.Len(gap) + quicvarint.Len(len) + if length+rangeLen > protocol.MaxAckFrameSize { + // Writing range i would exceed the MaxAckFrameSize. + // So encode one range less than that. + return i - 1 + } + length += rangeLen + } + return len(f.AckRanges) +} + +func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) { + if i == 0 { + return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest) + } + return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2), + uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest) +} + +// HasMissingRanges returns if this frame reports any missing packets +func (f *AckFrame) HasMissingRanges() bool { + return len(f.AckRanges) > 1 +} + +func (f *AckFrame) validateAckRanges() bool { + if len(f.AckRanges) == 0 { + return false + } + + // check the validity of every single ACK range + for _, ackRange := range f.AckRanges { + if ackRange.Smallest > ackRange.Largest { + return false + } + } + + // check the consistency for ACK with multiple NACK ranges + for i, ackRange := range f.AckRanges { + if i == 0 { + continue + } + lastAckRange := f.AckRanges[i-1] + if lastAckRange.Smallest <= ackRange.Smallest { + return false + } + if lastAckRange.Smallest <= ackRange.Largest+1 { + return false + } + } + + return true +} + +// LargestAcked is the largest acked packet number +func (f *AckFrame) LargestAcked() protocol.PacketNumber { + return f.AckRanges[0].Largest +} + +// LowestAcked is the lowest acked packet number +func (f *AckFrame) LowestAcked() protocol.PacketNumber { + return f.AckRanges[len(f.AckRanges)-1].Smallest +} + +// AcksPacket determines if this ACK frame acks a certain packet number +func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { + if p < f.LowestAcked() || p > f.LargestAcked() { + return false + } + + i := sort.Search(len(f.AckRanges), func(i int) bool { + return p >= f.AckRanges[i].Smallest + }) + // i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked + return p <= f.AckRanges[i].Largest +} + +func encodeAckDelay(delay time.Duration) uint64 { + return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent))) +} diff --git a/internal/quic-go/wire/ack_frame_test.go b/internal/quic-go/wire/ack_frame_test.go new file mode 100644 index 00000000..c57d99ff --- /dev/null +++ b/internal/quic-go/wire/ack_frame_test.go @@ -0,0 +1,454 @@ +package wire + +import ( + "bytes" + "io" + "math" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("ACK Frame (for IETF QUIC)", func() { + Context("parsing", func() { + It("parses an ACK frame without any ranges", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(100)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(10)...) // first ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(b.Len()).To(BeZero()) + }) + + It("parses an ACK frame that only acks a single packet", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(55)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(0)...) // first ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(55))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(55))) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts an ACK frame that acks all packets from 0 to largest", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(20)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(20)...) // first ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(20))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(0))) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(b.Len()).To(BeZero()) + }) + + It("rejects an ACK frame that has a first ACK block which is larger than LargestAcked", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(20)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(21)...) // first ack block + b := bytes.NewReader(data) + _, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).To(MatchError("invalid first ACK range")) + }) + + It("parses an ACK frame that has a single block", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(1000)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(1)...) // num blocks + data = append(data, encodeVarInt(100)...) // first ack block + data = append(data, encodeVarInt(98)...) // gap + data = append(data, encodeVarInt(50)...) // ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(1000))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(750))) + Expect(frame.HasMissingRanges()).To(BeTrue()) + Expect(frame.AckRanges).To(Equal([]AckRange{ + {Largest: 1000, Smallest: 900}, + {Largest: 800, Smallest: 750}, + })) + Expect(b.Len()).To(BeZero()) + }) + + It("parses an ACK frame that has a multiple blocks", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(100)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(2)...) // num blocks + data = append(data, encodeVarInt(0)...) // first ack block + data = append(data, encodeVarInt(0)...) // gap + data = append(data, encodeVarInt(0)...) // ack block + data = append(data, encodeVarInt(1)...) // gap + data = append(data, encodeVarInt(1)...) // ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(94))) + Expect(frame.HasMissingRanges()).To(BeTrue()) + Expect(frame.AckRanges).To(Equal([]AckRange{ + {Largest: 100, Smallest: 100}, + {Largest: 98, Smallest: 98}, + {Largest: 95, Smallest: 94}, + })) + Expect(b.Len()).To(BeZero()) + }) + + It("uses the ack delay exponent", func() { + const delayTime = 1 << 10 * time.Millisecond + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: delayTime, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + for i := uint8(0); i < 8; i++ { + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent+i, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.DelayTime).To(Equal(delayTime * (1 << i))) + } + }) + + It("gracefully handles overflows of the delay time", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(100)...) // largest acked + data = append(data, encodeVarInt(math.MaxUint64/5)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(0)...) // first ack block + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.DelayTime).To(BeNumerically(">", 0)) + // The maximum encodable duration is ~292 years. + Expect(frame.DelayTime.Hours()).To(BeNumerically("~", 292*365*24, 365*24)) + }) + + It("errors on EOF", func() { + data := []byte{0x2} + data = append(data, encodeVarInt(1000)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(1)...) // num blocks + data = append(data, encodeVarInt(100)...) // first ack block + data = append(data, encodeVarInt(98)...) // gap + data = append(data, encodeVarInt(50)...) // ack block + _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + + Context("ACK_ECN", func() { + It("parses", func() { + data := []byte{0x3} + data = append(data, encodeVarInt(100)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(0)...) // num blocks + data = append(data, encodeVarInt(10)...) // first ack block + data = append(data, encodeVarInt(0x42)...) // ECT(0) + data = append(data, encodeVarInt(0x12345)...) // ECT(1) + data = append(data, encodeVarInt(0x12345678)...) // ECN-CE + b := bytes.NewReader(data) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) + Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOF", func() { + data := []byte{0x3} + data = append(data, encodeVarInt(1000)...) // largest acked + data = append(data, encodeVarInt(0)...) // delay + data = append(data, encodeVarInt(1)...) // num blocks + data = append(data, encodeVarInt(100)...) // first ack block + data = append(data, encodeVarInt(98)...) // gap + data = append(data, encodeVarInt(50)...) // ack block + data = append(data, encodeVarInt(0x42)...) // ECT(0) + data = append(data, encodeVarInt(0x12345)...) // ECT(1) + data = append(data, encodeVarInt(0x12345678)...) // ECN-CE + _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + }) + + Context("when writing", func() { + It("writes a simple frame", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 100, Largest: 1337}}, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + expected := []byte{0x2} + expected = append(expected, encodeVarInt(1337)...) // largest acked + expected = append(expected, 0) // delay + expected = append(expected, encodeVarInt(0)...) // num ranges + expected = append(expected, encodeVarInt(1337-100)...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("writes an ACK-ECN frame", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 10, Largest: 2000}}, + ECT0: 13, + ECT1: 37, + ECNCE: 12345, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + expected := []byte{0x3} + expected = append(expected, encodeVarInt(2000)...) // largest acked + expected = append(expected, 0) // delay + expected = append(expected, encodeVarInt(0)...) // num ranges + expected = append(expected, encodeVarInt(2000-10)...) + expected = append(expected, encodeVarInt(13)...) + expected = append(expected, encodeVarInt(37)...) + expected = append(expected, encodeVarInt(12345)...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("writes a frame that acks a single packet", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 0x2eadbeef, Largest: 0x2eadbeef}}, + DelayTime: 18 * time.Millisecond, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(frame.DelayTime).To(Equal(f.DelayTime)) + Expect(b.Len()).To(BeZero()) + }) + + It("writes a frame that acks many packets", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 0x1337, Largest: 0x2eadbeef}}, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + Expect(frame.HasMissingRanges()).To(BeFalse()) + Expect(b.Len()).To(BeZero()) + }) + + It("writes a frame with a a single gap", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{ + {Smallest: 400, Largest: 1000}, + {Smallest: 100, Largest: 200}, + }, + } + Expect(f.validateAckRanges()).To(BeTrue()) + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + Expect(frame.HasMissingRanges()).To(BeTrue()) + Expect(b.Len()).To(BeZero()) + }) + + It("writes a frame with multiple ranges", func() { + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{ + {Smallest: 10, Largest: 10}, + {Smallest: 8, Largest: 8}, + {Smallest: 5, Largest: 6}, + {Smallest: 1, Largest: 3}, + }, + } + Expect(f.validateAckRanges()).To(BeTrue()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + Expect(frame.HasMissingRanges()).To(BeTrue()) + Expect(b.Len()).To(BeZero()) + }) + + It("limits the maximum size of the ACK frame", func() { + buf := &bytes.Buffer{} + const numRanges = 1000 + ackRanges := make([]AckRange, numRanges) + for i := protocol.PacketNumber(1); i <= numRanges; i++ { + ackRanges[numRanges-i] = AckRange{Smallest: 2 * i, Largest: 2 * i} + } + f := &AckFrame{AckRanges: ackRanges} + Expect(f.validateAckRanges()).To(BeTrue()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + // make sure the ACK frame is *a little bit* smaller than the MaxAckFrameSize + Expect(buf.Len()).To(BeNumerically(">", protocol.MaxAckFrameSize-5)) + Expect(buf.Len()).To(BeNumerically("<=", protocol.MaxAckFrameSize)) + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.HasMissingRanges()).To(BeTrue()) + Expect(b.Len()).To(BeZero()) + Expect(len(frame.AckRanges)).To(BeNumerically("<", numRanges)) // make sure we dropped some ranges + }) + }) + + Context("ACK range validator", func() { + It("rejects ACKs without ranges", func() { + Expect((&AckFrame{}).validateAckRanges()).To(BeFalse()) + }) + + It("accepts an ACK without NACK Ranges", func() { + ack := AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 7}}, + } + Expect(ack.validateAckRanges()).To(BeTrue()) + }) + + It("rejects ACK ranges with Smallest greater than Largest", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 8, Largest: 10}, + {Smallest: 4, Largest: 3}, + }, + } + Expect(ack.validateAckRanges()).To(BeFalse()) + }) + + It("rejects ACK ranges in the wrong order", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 2, Largest: 2}, + {Smallest: 6, Largest: 7}, + }, + } + Expect(ack.validateAckRanges()).To(BeFalse()) + }) + + It("rejects with overlapping ACK ranges", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 5, Largest: 7}, + {Smallest: 2, Largest: 5}, + }, + } + Expect(ack.validateAckRanges()).To(BeFalse()) + }) + + It("rejects ACK ranges that are part of a larger ACK range", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 4, Largest: 7}, + {Smallest: 5, Largest: 6}, + }, + } + Expect(ack.validateAckRanges()).To(BeFalse()) + }) + + It("rejects with directly adjacent ACK ranges", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 5, Largest: 7}, + {Smallest: 2, Largest: 4}, + }, + } + Expect(ack.validateAckRanges()).To(BeFalse()) + }) + + It("accepts an ACK with one lost packet", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 5, Largest: 10}, + {Smallest: 1, Largest: 3}, + }, + } + Expect(ack.validateAckRanges()).To(BeTrue()) + }) + + It("accepts an ACK with multiple lost packets", func() { + ack := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 15, Largest: 20}, + {Smallest: 10, Largest: 12}, + {Smallest: 1, Largest: 3}, + }, + } + Expect(ack.validateAckRanges()).To(BeTrue()) + }) + }) + + Context("check if ACK frame acks a certain packet", func() { + It("works with an ACK without any ranges", func() { + f := AckFrame{ + AckRanges: []AckRange{{Smallest: 5, Largest: 10}}, + } + Expect(f.AcksPacket(1)).To(BeFalse()) + Expect(f.AcksPacket(4)).To(BeFalse()) + Expect(f.AcksPacket(5)).To(BeTrue()) + Expect(f.AcksPacket(8)).To(BeTrue()) + Expect(f.AcksPacket(10)).To(BeTrue()) + Expect(f.AcksPacket(11)).To(BeFalse()) + Expect(f.AcksPacket(20)).To(BeFalse()) + }) + + It("works with an ACK with multiple ACK ranges", func() { + f := AckFrame{ + AckRanges: []AckRange{ + {Smallest: 15, Largest: 20}, + {Smallest: 5, Largest: 8}, + }, + } + Expect(f.AcksPacket(4)).To(BeFalse()) + Expect(f.AcksPacket(5)).To(BeTrue()) + Expect(f.AcksPacket(6)).To(BeTrue()) + Expect(f.AcksPacket(7)).To(BeTrue()) + Expect(f.AcksPacket(8)).To(BeTrue()) + Expect(f.AcksPacket(9)).To(BeFalse()) + Expect(f.AcksPacket(14)).To(BeFalse()) + Expect(f.AcksPacket(15)).To(BeTrue()) + Expect(f.AcksPacket(18)).To(BeTrue()) + Expect(f.AcksPacket(19)).To(BeTrue()) + Expect(f.AcksPacket(20)).To(BeTrue()) + Expect(f.AcksPacket(21)).To(BeFalse()) + }) + }) +}) diff --git a/internal/quic-go/wire/ack_range.go b/internal/quic-go/wire/ack_range.go new file mode 100644 index 00000000..e373835c --- /dev/null +++ b/internal/quic-go/wire/ack_range.go @@ -0,0 +1,14 @@ +package wire + +import "github.com/imroc/req/v3/internal/quic-go/protocol" + +// AckRange is an ACK range +type AckRange struct { + Smallest protocol.PacketNumber + Largest protocol.PacketNumber +} + +// Len returns the number of packets contained in this ACK range +func (r AckRange) Len() protocol.PacketNumber { + return r.Largest - r.Smallest + 1 +} diff --git a/internal/quic-go/wire/ack_range_test.go b/internal/quic-go/wire/ack_range_test.go new file mode 100644 index 00000000..84ef71b5 --- /dev/null +++ b/internal/quic-go/wire/ack_range_test.go @@ -0,0 +1,13 @@ +package wire + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("ACK range", func() { + It("returns the length", func() { + Expect(AckRange{Smallest: 10, Largest: 10}.Len()).To(BeEquivalentTo(1)) + Expect(AckRange{Smallest: 10, Largest: 13}.Len()).To(BeEquivalentTo(4)) + }) +}) diff --git a/internal/quic-go/wire/connection_close_frame.go b/internal/quic-go/wire/connection_close_frame.go new file mode 100644 index 00000000..1fe48837 --- /dev/null +++ b/internal/quic-go/wire/connection_close_frame.go @@ -0,0 +1,83 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A ConnectionCloseFrame is a CONNECTION_CLOSE frame +type ConnectionCloseFrame struct { + IsApplicationError bool + ErrorCode uint64 + FrameType uint64 + ReasonPhrase string +} + +func parseConnectionCloseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &ConnectionCloseFrame{IsApplicationError: typeByte == 0x1d} + ec, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.ErrorCode = ec + // read the Frame Type, if this is not an application error + if !f.IsApplicationError { + ft, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.FrameType = ft + } + var reasonPhraseLen uint64 + reasonPhraseLen, err = quicvarint.Read(r) + if err != nil { + return nil, err + } + // shortcut to prevent the unnecessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the whole reason phrase would result in EOF when attempting to READ + if int(reasonPhraseLen) > r.Len() { + return nil, io.EOF + } + + reasonPhrase := make([]byte, reasonPhraseLen) + if _, err := io.ReadFull(r, reasonPhrase); err != nil { + // this should never happen, since we already checked the reasonPhraseLen earlier + return nil, err + } + f.ReasonPhrase = string(reasonPhrase) + return f, nil +} + +// Length of a written frame +func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount { + length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) + if !f.IsApplicationError { + length += quicvarint.Len(f.FrameType) // for the frame type + } + return length +} + +func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if f.IsApplicationError { + b.WriteByte(0x1d) + } else { + b.WriteByte(0x1c) + } + + quicvarint.Write(b, f.ErrorCode) + if !f.IsApplicationError { + quicvarint.Write(b, f.FrameType) + } + quicvarint.Write(b, uint64(len(f.ReasonPhrase))) + b.WriteString(f.ReasonPhrase) + return nil +} diff --git a/internal/quic-go/wire/connection_close_frame_test.go b/internal/quic-go/wire/connection_close_frame_test.go new file mode 100644 index 00000000..c507fc8a --- /dev/null +++ b/internal/quic-go/wire/connection_close_frame_test.go @@ -0,0 +1,153 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("CONNECTION_CLOSE Frame", func() { + Context("when parsing", func() { + It("accepts sample frame containing a QUIC error code", func() { + reason := "No recent network activity." + data := []byte{0x1c} + data = append(data, encodeVarInt(0x19)...) + data = append(data, encodeVarInt(0x1337)...) // frame type + data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length + data = append(data, []byte(reason)...) + b := bytes.NewReader(data) + frame, err := parseConnectionCloseFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.IsApplicationError).To(BeFalse()) + Expect(frame.ErrorCode).To(BeEquivalentTo(0x19)) + Expect(frame.FrameType).To(BeEquivalentTo(0x1337)) + Expect(frame.ReasonPhrase).To(Equal(reason)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts sample frame containing an application error code", func() { + reason := "The application messed things up." + data := []byte{0x1d} + data = append(data, encodeVarInt(0xcafe)...) + data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length + data = append(data, reason...) + b := bytes.NewReader(data) + frame, err := parseConnectionCloseFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.IsApplicationError).To(BeTrue()) + Expect(frame.ErrorCode).To(BeEquivalentTo(0xcafe)) + Expect(frame.ReasonPhrase).To(Equal(reason)) + Expect(b.Len()).To(BeZero()) + }) + + It("rejects long reason phrases", func() { + data := []byte{0x1c} + data = append(data, encodeVarInt(0xcafe)...) + data = append(data, encodeVarInt(0x42)...) // frame type + data = append(data, encodeVarInt(0xffff)...) // reason phrase length + b := bytes.NewReader(data) + _, err := parseConnectionCloseFrame(b, protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + }) + + It("errors on EOFs", func() { + reason := "No recent network activity." + data := []byte{0x1c} + data = append(data, encodeVarInt(0x19)...) + data = append(data, encodeVarInt(0x1337)...) // frame type + data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length + data = append(data, []byte(reason)...) + _, err := parseConnectionCloseFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseConnectionCloseFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + + It("parses a frame without a reason phrase", func() { + data := []byte{0x1c} + data = append(data, encodeVarInt(0xcafe)...) + data = append(data, encodeVarInt(0x42)...) // frame type + data = append(data, encodeVarInt(0)...) + b := bytes.NewReader(data) + frame, err := parseConnectionCloseFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.ReasonPhrase).To(BeEmpty()) + Expect(b.Len()).To(BeZero()) + }) + }) + + Context("when writing", func() { + It("writes a frame without a reason phrase", func() { + b := &bytes.Buffer{} + frame := &ConnectionCloseFrame{ + ErrorCode: 0xbeef, + FrameType: 0x12345, + } + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + expected := []byte{0x1c} + expected = append(expected, encodeVarInt(0xbeef)...) + expected = append(expected, encodeVarInt(0x12345)...) // frame type + expected = append(expected, encodeVarInt(0)...) // reason phrase length + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with a reason phrase", func() { + b := &bytes.Buffer{} + frame := &ConnectionCloseFrame{ + ErrorCode: 0xdead, + ReasonPhrase: "foobar", + } + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + expected := []byte{0x1c} + expected = append(expected, encodeVarInt(0xdead)...) + expected = append(expected, encodeVarInt(0)...) // frame type + expected = append(expected, encodeVarInt(6)...) // reason phrase length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with an application error code", func() { + b := &bytes.Buffer{} + frame := &ConnectionCloseFrame{ + IsApplicationError: true, + ErrorCode: 0xdead, + ReasonPhrase: "foobar", + } + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + expected := []byte{0x1d} + expected = append(expected, encodeVarInt(0xdead)...) + expected = append(expected, encodeVarInt(6)...) // reason phrase length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has proper min length, for a frame containing a QUIC error code", func() { + b := &bytes.Buffer{} + f := &ConnectionCloseFrame{ + ErrorCode: 0xcafe, + FrameType: 0xdeadbeef, + ReasonPhrase: "foobar", + } + Expect(f.Write(b, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) + }) + + It("has proper min length, for a frame containing an application error code", func() { + b := &bytes.Buffer{} + f := &ConnectionCloseFrame{ + IsApplicationError: true, + ErrorCode: 0xcafe, + ReasonPhrase: "foobar", + } + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) + }) + }) +}) diff --git a/internal/quic-go/wire/crypto_frame.go b/internal/quic-go/wire/crypto_frame.go new file mode 100644 index 00000000..3e7e1808 --- /dev/null +++ b/internal/quic-go/wire/crypto_frame.go @@ -0,0 +1,102 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A CryptoFrame is a CRYPTO frame +type CryptoFrame struct { + Offset protocol.ByteCount + Data []byte +} + +func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + frame := &CryptoFrame{} + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + frame.Offset = protocol.ByteCount(offset) + dataLen, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if dataLen > uint64(r.Len()) { + return nil, io.EOF + } + if dataLen != 0 { + frame.Data = make([]byte, dataLen) + if _, err := io.ReadFull(r, frame.Data); err != nil { + // this should never happen, since we already checked the dataLen earlier + return nil, err + } + } + return frame, nil +} + +func (f *CryptoFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x6) + quicvarint.Write(b, uint64(f.Offset)) + quicvarint.Write(b, uint64(len(f.Data))) + b.Write(f.Data) + return nil +} + +// Length of a written frame +func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data)) +} + +// MaxDataLen returns the maximum data length +func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen := 1 + quicvarint.Len(uint64(f.Offset)) + 1 + if headerLen > maxSize { + return 0 + } + maxDataLen := maxSize - headerLen + if quicvarint.Len(uint64(maxDataLen)) != 1 { + maxDataLen-- + } + return maxDataLen +} + +// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. +// It returns if the frame was actually split. +// The frame might not be split if: +// * the size is large enough to fit the whole frame +// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. +func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*CryptoFrame, bool /* was splitting required */) { + if f.Length(version) <= maxSize { + return nil, false + } + + n := f.MaxDataLen(maxSize) + if n == 0 { + return nil, true + } + + newLen := protocol.ByteCount(len(f.Data)) - n + + new := &CryptoFrame{} + new.Offset = f.Offset + new.Data = make([]byte, newLen) + + // swap the data slices + new.Data, f.Data = f.Data, new.Data + + copy(f.Data, new.Data[n:]) + new.Data = new.Data[:n] + f.Offset += n + + return new, true +} diff --git a/internal/quic-go/wire/crypto_frame_test.go b/internal/quic-go/wire/crypto_frame_test.go new file mode 100644 index 00000000..a52229e1 --- /dev/null +++ b/internal/quic-go/wire/crypto_frame_test.go @@ -0,0 +1,148 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("CRYPTO frame", func() { + Context("when parsing", func() { + It("parses", func() { + data := []byte{0x6} + data = append(data, encodeVarInt(0xdecafbad)...) // offset + data = append(data, encodeVarInt(6)...) // length + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := parseCryptoFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(r.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x6} + data = append(data, encodeVarInt(0xdecafbad)...) // offset + data = append(data, encodeVarInt(6)...) // data length + data = append(data, []byte("foobar")...) + _, err := parseCryptoFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseCryptoFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("when writing", func() { + It("writes a frame", func() { + f := &CryptoFrame{ + Offset: 0x123456, + Data: []byte("foobar"), + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x6} + expected = append(expected, encodeVarInt(0x123456)...) // offset + expected = append(expected, encodeVarInt(6)...) // length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + }) + + Context("max data length", func() { + const maxSize = 3000 + + It("always returns a data length such that the resulting frame has the right size", func() { + data := make([]byte, maxSize) + f := &CryptoFrame{ + Offset: 0xdeadbeef, + } + b := &bytes.Buffer{} + var frameOneByteTooSmallCounter int + for i := 1; i < maxSize; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i)) + if maxDataLen == 0 { // 0 means that no valid CRYTPO frame can be written + // check that writing a minimal size CRYPTO frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + // There's *one* pathological case, where a data length of x can be encoded into 1 byte + // but a data lengths of x+1 needs 2 bytes + // In that case, it's impossible to create a STREAM frame of the desired size + if b.Len() == i-1 { + frameOneByteTooSmallCounter++ + continue + } + Expect(b.Len()).To(Equal(i)) + } + Expect(frameOneByteTooSmallCounter).To(Equal(1)) + }) + }) + + Context("length", func() { + It("has the right length for a frame without offset and data length", func() { + f := &CryptoFrame{ + Offset: 0x1337, + Data: []byte("foobar"), + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(6) + 6)) + }) + }) + + Context("splitting", func() { + It("splits a frame", func() { + f := &CryptoFrame{ + Offset: 0x1337, + Data: []byte("foobar"), + } + hdrLen := f.Length(protocol.Version1) - 6 + new, needsSplit := f.MaybeSplitOffFrame(hdrLen+3, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(new.Data).To(Equal([]byte("foo"))) + Expect(new.Offset).To(Equal(protocol.ByteCount(0x1337))) + Expect(f.Data).To(Equal([]byte("bar"))) + Expect(f.Offset).To(Equal(protocol.ByteCount(0x1337 + 3))) + }) + + It("doesn't split if there's enough space in the frame", func() { + f := &CryptoFrame{ + Offset: 0x1337, + Data: []byte("foobar"), + } + f, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) + Expect(needsSplit).To(BeFalse()) + Expect(f).To(BeNil()) + }) + + It("doesn't split if the size is too small", func() { + f := &CryptoFrame{ + Offset: 0x1337, + Data: []byte("foobar"), + } + length := f.Length(protocol.Version1) - 6 + for i := protocol.ByteCount(0); i <= length; i++ { + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(f).To(BeNil()) + } + f, needsSplit := f.MaybeSplitOffFrame(length+1, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(f).ToNot(BeNil()) + }) + }) +}) diff --git a/internal/quic-go/wire/data_blocked_frame.go b/internal/quic-go/wire/data_blocked_frame.go new file mode 100644 index 00000000..ddce5a02 --- /dev/null +++ b/internal/quic-go/wire/data_blocked_frame.go @@ -0,0 +1,38 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A DataBlockedFrame is a DATA_BLOCKED frame +type DataBlockedFrame struct { + MaximumData protocol.ByteCount +} + +func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + return &DataBlockedFrame{ + MaximumData: protocol.ByteCount(offset), + }, nil +} + +func (f *DataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + typeByte := uint8(0x14) + b.WriteByte(typeByte) + quicvarint.Write(b, uint64(f.MaximumData)) + return nil +} + +// Length of a written frame +func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.MaximumData)) +} diff --git a/internal/quic-go/wire/data_blocked_frame_test.go b/internal/quic-go/wire/data_blocked_frame_test.go new file mode 100644 index 00000000..8f19310b --- /dev/null +++ b/internal/quic-go/wire/data_blocked_frame_test.go @@ -0,0 +1,54 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("DATA_BLOCKED frame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + data := []byte{0x14} + data = append(data, encodeVarInt(0x12345678)...) + b := bytes.NewReader(data) + frame, err := parseDataBlockedFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0x12345678))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x14} + data = append(data, encodeVarInt(0x12345678)...) + _, err := parseDataBlockedFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + for i := range data { + _, err := parseDataBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := DataBlockedFrame{MaximumData: 0xdeadbeef} + err := frame.Write(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x14} + expected = append(expected, encodeVarInt(0xdeadbeef)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := DataBlockedFrame{MaximumData: 0x12345} + Expect(frame.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x12345))) + }) + }) +}) diff --git a/internal/quic-go/wire/datagram_frame.go b/internal/quic-go/wire/datagram_frame.go new file mode 100644 index 00000000..1b3aeb96 --- /dev/null +++ b/internal/quic-go/wire/datagram_frame.go @@ -0,0 +1,85 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A DatagramFrame is a DATAGRAM frame +type DatagramFrame struct { + DataLenPresent bool + Data []byte +} + +func parseDatagramFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DatagramFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &DatagramFrame{} + f.DataLenPresent = typeByte&0x1 > 0 + + var length uint64 + if f.DataLenPresent { + var err error + len, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if len > uint64(r.Len()) { + return nil, io.EOF + } + length = len + } else { + length = uint64(r.Len()) + } + f.Data = make([]byte, length) + if _, err := io.ReadFull(r, f.Data); err != nil { + return nil, err + } + return f, nil +} + +func (f *DatagramFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := uint8(0x30) + if f.DataLenPresent { + typeByte ^= 0x1 + } + b.WriteByte(typeByte) + if f.DataLenPresent { + quicvarint.Write(b, uint64(len(f.Data))) + } + b.Write(f.Data) + return nil +} + +// MaxDataLen returns the maximum data length +func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { + headerLen := protocol.ByteCount(1) + if f.DataLenPresent { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen++ + } + if headerLen > maxSize { + return 0 + } + maxDataLen := maxSize - headerLen + if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { + maxDataLen-- + } + return maxDataLen +} + +// Length of a written frame +func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + length := 1 + protocol.ByteCount(len(f.Data)) + if f.DataLenPresent { + length += quicvarint.Len(uint64(len(f.Data))) + } + return length +} diff --git a/internal/quic-go/wire/datagram_frame_test.go b/internal/quic-go/wire/datagram_frame_test.go new file mode 100644 index 00000000..6fc581ce --- /dev/null +++ b/internal/quic-go/wire/datagram_frame_test.go @@ -0,0 +1,154 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAM frame", func() { + Context("when parsing", func() { + It("parses a frame containing a length", func() { + data := []byte{0x30 ^ 0x1} + data = append(data, encodeVarInt(0x6)...) // length + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := parseDatagramFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(frame.DataLenPresent).To(BeTrue()) + Expect(r.Len()).To(BeZero()) + }) + + It("parses a frame without length", func() { + data := []byte{0x30} + data = append(data, []byte("Lorem ipsum dolor sit amet")...) + r := bytes.NewReader(data) + frame, err := parseDatagramFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.Data).To(Equal([]byte("Lorem ipsum dolor sit amet"))) + Expect(frame.DataLenPresent).To(BeFalse()) + Expect(r.Len()).To(BeZero()) + }) + + It("errors when the length is longer than the rest of the frame", func() { + data := []byte{0x30 ^ 0x1} + data = append(data, encodeVarInt(0x6)...) // length + data = append(data, []byte("fooba")...) + r := bytes.NewReader(data) + _, err := parseDatagramFrame(r, protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + }) + + It("errors on EOFs", func() { + data := []byte{0x30 ^ 0x1} + data = append(data, encodeVarInt(6)...) // length + data = append(data, []byte("foobar")...) + _, err := parseDatagramFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseDatagramFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a frame with length", func() { + f := &DatagramFrame{ + DataLenPresent: true, + Data: []byte("foobar"), + } + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + expected := []byte{0x30 ^ 0x1} + expected = append(expected, encodeVarInt(0x6)...) + expected = append(expected, []byte("foobar")...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("writes a frame without length", func() { + f := &DatagramFrame{Data: []byte("Lorem ipsum")} + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + expected := []byte{0x30} + expected = append(expected, []byte("Lorem ipsum")...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + }) + + Context("length", func() { + It("has the right length for a frame with length", func() { + f := &DatagramFrame{ + DataLenPresent: true, + Data: []byte("foobar"), + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(6) + 6)) + }) + + It("has the right length for a frame without length", func() { + f := &DatagramFrame{Data: []byte("foobar")} + Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(1 + 6))) + }) + }) + + Context("max data length", func() { + const maxSize = 3000 + + It("returns a data length such that the resulting frame has the right size, if data length is not present", func() { + data := make([]byte, maxSize) + f := &DatagramFrame{} + b := &bytes.Buffer{} + for i := 1; i < 3000; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) + if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written + // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + Expect(f.Write(b, protocol.Version1)).To(Succeed()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + Expect(f.Write(b, protocol.Version1)).To(Succeed()) + Expect(b.Len()).To(Equal(i)) + } + }) + + It("always returns a data length such that the resulting frame has the right size, if data length is present", func() { + data := make([]byte, maxSize) + f := &DatagramFrame{DataLenPresent: true} + b := &bytes.Buffer{} + var frameOneByteTooSmallCounter int + for i := 1; i < 3000; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) + if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written + // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + Expect(f.Write(b, protocol.Version1)).To(Succeed()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + Expect(f.Write(b, protocol.Version1)).To(Succeed()) + // There's *one* pathological case, where a data length of x can be encoded into 1 byte + // but a data lengths of x+1 needs 2 bytes + // In that case, it's impossible to create a STREAM frame of the desired size + if b.Len() == i-1 { + frameOneByteTooSmallCounter++ + continue + } + Expect(b.Len()).To(Equal(i)) + } + Expect(frameOneByteTooSmallCounter).To(Equal(1)) + }) + }) +}) diff --git a/internal/quic-go/wire/extended_header.go b/internal/quic-go/wire/extended_header.go new file mode 100644 index 00000000..766ccbc1 --- /dev/null +++ b/internal/quic-go/wire/extended_header.go @@ -0,0 +1,249 @@ +package wire + +import ( + "bytes" + "errors" + "fmt" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +// ErrInvalidReservedBits is returned when the reserved bits are incorrect. +// When this error is returned, parsing continues, and an ExtendedHeader is returned. +// This is necessary because we need to decrypt the packet in that case, +// in order to avoid a timing side-channel. +var ErrInvalidReservedBits = errors.New("invalid reserved bits") + +// ExtendedHeader is the header of a QUIC packet. +type ExtendedHeader struct { + Header + + typeByte byte + + KeyPhase protocol.KeyPhaseBit + + PacketNumberLen protocol.PacketNumberLen + PacketNumber protocol.PacketNumber + + parsedLen protocol.ByteCount +} + +func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) { + startLen := b.Len() + // read the (now unencrypted) first byte + var err error + h.typeByte, err = b.ReadByte() + if err != nil { + return false, err + } + if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil { + return false, err + } + var reservedBitsValid bool + if h.IsLongHeader { + reservedBitsValid, err = h.parseLongHeader(b, v) + } else { + reservedBitsValid, err = h.parseShortHeader(b, v) + } + if err != nil { + return false, err + } + h.parsedLen = protocol.ByteCount(startLen - b.Len()) + return reservedBitsValid, err +} + +func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { + if err := h.readPacketNumber(b); err != nil { + return false, err + } + if h.typeByte&0xc != 0 { + return false, nil + } + return true, nil +} + +func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { + h.KeyPhase = protocol.KeyPhaseZero + if h.typeByte&0x4 > 0 { + h.KeyPhase = protocol.KeyPhaseOne + } + + if err := h.readPacketNumber(b); err != nil { + return false, err + } + if h.typeByte&0x18 != 0 { + return false, nil + } + return true, nil +} + +func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { + h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + n, err := b.ReadByte() + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen2: + n, err := utils.BigEndian.ReadUint16(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen3: + n, err := utils.BigEndian.ReadUint24(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen4: + n, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + default: + return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + return nil +} + +// Write writes the Header. +func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error { + if h.DestConnectionID.Len() > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) + } + if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) + } + if h.IsLongHeader { + return h.writeLongHeader(b, ver) + } + return h.writeShortHeader(b, ver) +} + +func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error { + var packetType uint8 + if version == protocol.Version2 { + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0b01 + case protocol.PacketType0RTT: + packetType = 0b10 + case protocol.PacketTypeHandshake: + packetType = 0b11 + case protocol.PacketTypeRetry: + packetType = 0b00 + } + } else { + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0b00 + case protocol.PacketType0RTT: + packetType = 0b01 + case protocol.PacketTypeHandshake: + packetType = 0b10 + case protocol.PacketTypeRetry: + packetType = 0b11 + } + } + firstByte := 0xc0 | packetType<<4 + if h.Type != protocol.PacketTypeRetry { + // Retry packets don't have a packet number + firstByte |= uint8(h.PacketNumberLen - 1) + } + + b.WriteByte(firstByte) + utils.BigEndian.WriteUint32(b, uint32(h.Version)) + b.WriteByte(uint8(h.DestConnectionID.Len())) + b.Write(h.DestConnectionID.Bytes()) + b.WriteByte(uint8(h.SrcConnectionID.Len())) + b.Write(h.SrcConnectionID.Bytes()) + + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeRetry: + b.Write(h.Token) + return nil + case protocol.PacketTypeInitial: + quicvarint.Write(b, uint64(len(h.Token))) + b.Write(h.Token) + } + quicvarint.WriteWithLen(b, uint64(h.Length), 2) + return h.writePacketNumber(b) +} + +func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := 0x40 | uint8(h.PacketNumberLen-1) + if h.KeyPhase == protocol.KeyPhaseOne { + typeByte |= byte(1 << 2) + } + + b.WriteByte(typeByte) + b.Write(h.DestConnectionID.Bytes()) + return h.writePacketNumber(b) +} + +func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(h.PacketNumber)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) + case protocol.PacketNumberLen3: + utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + default: + return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + return nil +} + +// ParsedLen returns the number of bytes that were consumed when parsing the header +func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { + return h.parsedLen +} + +// GetLength determines the length of the Header. +func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount { + if h.IsLongHeader { + length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ + if h.Type == protocol.PacketTypeInitial { + length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) + } + return length + } + + length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) + length += protocol.ByteCount(h.PacketNumberLen) + return length +} + +// Log logs the Header +func (h *ExtendedHeader) Log(logger utils.Logger) { + if h.IsLongHeader { + var token string + if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { + if len(h.Token) == 0 { + token = "Token: (empty), " + } else { + token = fmt.Sprintf("Token: %#x, ", h.Token) + } + if h.Type == protocol.PacketTypeRetry { + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version) + return + } + } + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) + } else { + logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + } +} diff --git a/internal/quic-go/wire/extended_header_test.go b/internal/quic-go/wire/extended_header_test.go new file mode 100644 index 00000000..4ec4cde1 --- /dev/null +++ b/internal/quic-go/wire/extended_header_test.go @@ -0,0 +1,481 @@ +package wire + +import ( + "bytes" + "log" + "os" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Header", func() { + const versionIETFHeader = protocol.VersionTLS // a QUIC version that uses the IETF Header format + + Context("Writing", func() { + var buf *bytes.Buffer + + BeforeEach(func() { + buf = &bytes.Buffer{} + }) + + Context("Long Header", func() { + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + + It("writes", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}, + Version: 0x1020304, + Length: protocol.InitialPacketSizeIPv4, + }, + PacketNumber: 0xdecaf, + PacketNumberLen: protocol.PacketNumberLen3, + }).Write(buf, versionIETFHeader)).To(Succeed()) + expected := []byte{ + 0xc0 | 0x2<<4 | 0x2, + 0x1, 0x2, 0x3, 0x4, // version number + 0x6, // dest connection ID length + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // dest connection ID + 0x8, // src connection ID length + 0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37, // source connection ID + } + expected = append(expected, encodeVarInt(protocol.InitialPacketSizeIPv4)...) // length + expected = append(expected, []byte{0xd, 0xec, 0xaf}...) // packet number + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("refuses to write a header with a too long connection ID", func() { + err := (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + SrcConnectionID: srcConnID, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, // connection IDs must be at most 20 bytes long + Version: 0x1020304, + Type: 0x5, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, versionIETFHeader) + Expect(err).To(MatchError("invalid connection ID length: 21 bytes")) + }) + + It("writes a header with a 20 byte connection ID", func() { + err := (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + SrcConnectionID: srcConnID, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, // connection IDs must be at most 20 bytes long + Version: 0x1020304, + Type: 0x5, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, versionIETFHeader) + Expect(err).ToNot(HaveOccurred()) + Expect(buf.Bytes()).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}))) + }) + + It("writes an Initial containing a token", func() { + token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: 0x1020304, + Type: protocol.PacketTypeInitial, + Token: token, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0) + expectedSubstring := append(encodeVarInt(uint64(len(token))), token...) + Expect(buf.Bytes()).To(ContainSubstring(string(expectedSubstring))) + }) + + It("uses a 2-byte encoding for the length on Initial packets", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: 0x1020304, + Type: protocol.PacketTypeInitial, + Length: 37, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, versionIETFHeader)).To(Succeed()) + b := &bytes.Buffer{} + quicvarint.WriteWithLen(b, 37, 2) + Expect(buf.Bytes()[buf.Len()-6 : buf.Len()-4]).To(Equal(b.Bytes())) + }) + + It("writes a Retry packet", func() { + token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") + Expect((&ExtendedHeader{Header: Header{ + IsLongHeader: true, + Version: protocol.Version1, + Type: protocol.PacketTypeRetry, + Token: token, + }}).Write(buf, versionIETFHeader)).To(Succeed()) + expected := []byte{0xc0 | 0b11<<4} + expected = appendVersion(expected, protocol.Version1) + expected = append(expected, 0x0) // dest connection ID length + expected = append(expected, 0x0) // src connection ID length + expected = append(expected, token...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + }) + + Context("long header, version 2", func() { + It("writes an Initial", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketTypeInitial, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, protocol.Version2)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0b01) + }) + + It("writes a Retry packet", func() { + token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") + Expect((&ExtendedHeader{Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketTypeRetry, + Token: token, + }}).Write(buf, versionIETFHeader)).To(Succeed()) + expected := []byte{0xc0 | 0b11<<4} + expected = appendVersion(expected, protocol.Version2) + expected = append(expected, 0x0) // dest connection ID length + expected = append(expected, 0x0) // src connection ID length + expected = append(expected, token...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("writes a Handshake Packet", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketTypeHandshake, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, protocol.Version2)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0b11) + }) + + It("writes a 0-RTT Packet", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketType0RTT, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, protocol.Version2)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0b10) + }) + }) + + Context("short header", func() { + It("writes a header with connection ID", func() { + Expect((&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 0x42, + }).Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Bytes()).To(Equal([]byte{ + 0x40, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + 0x42, // packet number + })) + }) + + It("writes a header without connection ID", func() { + Expect((&ExtendedHeader{ + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 0x42, + }).Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Bytes()).To(Equal([]byte{ + 0x40, + 0x42, // packet number + })) + }) + + It("writes a header with a 2 byte packet number", func() { + Expect((&ExtendedHeader{ + PacketNumberLen: protocol.PacketNumberLen2, + PacketNumber: 0x765, + }).Write(buf, versionIETFHeader)).To(Succeed()) + expected := []byte{0x40 | 0x1} + expected = append(expected, []byte{0x7, 0x65}...) // packet number + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("writes a header with a 4 byte packet number", func() { + Expect((&ExtendedHeader{ + PacketNumberLen: protocol.PacketNumberLen4, + PacketNumber: 0x12345678, + }).Write(buf, versionIETFHeader)).To(Succeed()) + expected := []byte{0x40 | 0x3} + expected = append(expected, []byte{0x12, 0x34, 0x56, 0x78}...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("errors when given an invalid packet number length", func() { + err := (&ExtendedHeader{ + PacketNumberLen: 5, + PacketNumber: 0xdecafbad, + }).Write(buf, versionIETFHeader) + Expect(err).To(MatchError("invalid packet number length: 5")) + }) + + It("writes the Key Phase Bit", func() { + Expect((&ExtendedHeader{ + KeyPhase: protocol.KeyPhaseOne, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 0x42, + }).Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Bytes()).To(Equal([]byte{ + 0x40 | 0x4, + 0x42, // packet number + })) + }) + }) + }) + + Context("getting the length", func() { + var buf *bytes.Buffer + + BeforeEach(func() { + buf = &bytes.Buffer{} + }) + + It("has the right length for the Long Header, for a short length", func() { + h := &ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Length: 1, + }, + PacketNumberLen: protocol.PacketNumberLen1, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* length */ + 1 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + + It("has the right length for the Long Header, for a long length", func() { + h := &ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Length: 1500, + }, + PacketNumberLen: protocol.PacketNumberLen2, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* long len */ + 2 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + + It("has the right length for an Initial that has a short length", func() { + h := &ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 15, + }, + PacketNumberLen: protocol.PacketNumberLen2, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + + It("has the right length for an Initial not containing a Token", func() { + h := &ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 1500, + }, + PacketNumberLen: protocol.PacketNumberLen2, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + + It("has the right length for an Initial containing a Token", func() { + h := &ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Type: protocol.PacketTypeInitial, + Length: 1500, + Token: []byte("foo"), + }, + PacketNumberLen: protocol.PacketNumberLen2, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn id len */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long len */ + 2 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + + It("has the right length for a Short Header containing a connection ID", func() { + h := &ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + }, + PacketNumberLen: protocol.PacketNumberLen1, + } + Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 8 + 1))) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(10)) + }) + + It("has the right length for a short header without a connection ID", func() { + h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1} + Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 1))) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(2)) + }) + + It("has the right length for a short header with a 2 byte packet number", func() { + h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen2} + Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 2))) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(3)) + }) + + It("has the right length for a short header with a 5 byte packet number", func() { + h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen4} + Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 4))) + Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Len()).To(Equal(5)) + }) + }) + + Context("Logging", func() { + var ( + buf *bytes.Buffer + logger utils.Logger + ) + + BeforeEach(func() { + buf = &bytes.Buffer{} + logger = utils.DefaultLogger + logger.SetLogLevel(utils.LogLevelDebug) + log.SetOutput(buf) + }) + + AfterEach(func() { + log.SetOutput(os.Stdout) + }) + + It("logs Long Headers", func() { + (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}, + Type: protocol.PacketTypeHandshake, + Length: 54321, + Version: 0xfeed, + }, + PacketNumber: 1337, + PacketNumberLen: protocol.PacketNumberLen2, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, DestConnectionID: deadbeefcafe1337, SrcConnectionID: decafbad13371337, PacketNumber: 1337, PacketNumberLen: 2, Length: 54321, Version: 0xfeed}")) + }) + + It("logs Initial Packets with a Token", func() { + (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + Type: protocol.PacketTypeInitial, + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + Length: 100, + Version: 0xfeed, + }, + PacketNumber: 42, + PacketNumberLen: protocol.PacketNumberLen2, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Long Header{Type: Initial, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0xdeadbeef, PacketNumber: 42, PacketNumberLen: 2, Length: 100, Version: 0xfeed}")) + }) + + It("logs Initial packets without a Token", func() { + (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + Type: protocol.PacketTypeInitial, + Length: 100, + Version: 0xfeed, + }, + PacketNumber: 42, + PacketNumberLen: protocol.PacketNumberLen2, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Long Header{Type: Initial, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: (empty), PacketNumber: 42, PacketNumberLen: 2, Length: 100, Version: 0xfeed}")) + }) + + It("logs Retry packets with a Token", func() { + (&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + Type: protocol.PacketTypeRetry, + Token: []byte{0x12, 0x34, 0x56}, + Version: 0xfeed, + }, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Long Header{Type: Retry, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0x123456, Version: 0xfeed}")) + }) + + It("logs Short Headers containing a connection ID", func() { + (&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + }, + KeyPhase: protocol.KeyPhaseOne, + PacketNumber: 1337, + PacketNumberLen: 4, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}")) + }) + }) +}) diff --git a/internal/quic-go/wire/frame_parser.go b/internal/quic-go/wire/frame_parser.go new file mode 100644 index 00000000..b1a3659b --- /dev/null +++ b/internal/quic-go/wire/frame_parser.go @@ -0,0 +1,143 @@ +package wire + +import ( + "bytes" + "errors" + "fmt" + "reflect" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" +) + +type frameParser struct { + ackDelayExponent uint8 + + supportsDatagrams bool + + version protocol.VersionNumber +} + +// NewFrameParser creates a new frame parser. +func NewFrameParser(supportsDatagrams bool, v protocol.VersionNumber) FrameParser { + return &frameParser{ + supportsDatagrams: supportsDatagrams, + version: v, + } +} + +// ParseNext parses the next frame. +// It skips PADDING frames. +func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel) (Frame, error) { + for r.Len() != 0 { + typeByte, _ := r.ReadByte() + if typeByte == 0x0 { // PADDING frame + continue + } + r.UnreadByte() + + f, err := p.parseFrame(r, typeByte, encLevel) + if err != nil { + return nil, &qerr.TransportError{ + FrameType: uint64(typeByte), + ErrorCode: qerr.FrameEncodingError, + ErrorMessage: err.Error(), + } + } + return f, nil + } + return nil, nil +} + +func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protocol.EncryptionLevel) (Frame, error) { + var frame Frame + var err error + if typeByte&0xf8 == 0x8 { + frame, err = parseStreamFrame(r, p.version) + } else { + switch typeByte { + case 0x1: + frame, err = parsePingFrame(r, p.version) + case 0x2, 0x3: + ackDelayExponent := p.ackDelayExponent + if encLevel != protocol.Encryption1RTT { + ackDelayExponent = protocol.DefaultAckDelayExponent + } + frame, err = parseAckFrame(r, ackDelayExponent, p.version) + case 0x4: + frame, err = parseResetStreamFrame(r, p.version) + case 0x5: + frame, err = parseStopSendingFrame(r, p.version) + case 0x6: + frame, err = parseCryptoFrame(r, p.version) + case 0x7: + frame, err = parseNewTokenFrame(r, p.version) + case 0x10: + frame, err = parseMaxDataFrame(r, p.version) + case 0x11: + frame, err = parseMaxStreamDataFrame(r, p.version) + case 0x12, 0x13: + frame, err = parseMaxStreamsFrame(r, p.version) + case 0x14: + frame, err = parseDataBlockedFrame(r, p.version) + case 0x15: + frame, err = parseStreamDataBlockedFrame(r, p.version) + case 0x16, 0x17: + frame, err = parseStreamsBlockedFrame(r, p.version) + case 0x18: + frame, err = parseNewConnectionIDFrame(r, p.version) + case 0x19: + frame, err = parseRetireConnectionIDFrame(r, p.version) + case 0x1a: + frame, err = parsePathChallengeFrame(r, p.version) + case 0x1b: + frame, err = parsePathResponseFrame(r, p.version) + case 0x1c, 0x1d: + frame, err = parseConnectionCloseFrame(r, p.version) + case 0x1e: + frame, err = parseHandshakeDoneFrame(r, p.version) + case 0x30, 0x31: + if p.supportsDatagrams { + frame, err = parseDatagramFrame(r, p.version) + break + } + fallthrough + default: + err = errors.New("unknown frame type") + } + } + if err != nil { + return nil, err + } + if !p.isAllowedAtEncLevel(frame, encLevel) { + return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel) + } + return frame, nil +} + +func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { + switch encLevel { + case protocol.EncryptionInitial, protocol.EncryptionHandshake: + switch f.(type) { + case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *PingFrame: + return true + default: + return false + } + case protocol.Encryption0RTT: + switch f.(type) { + case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: + return false + default: + return true + } + case protocol.Encryption1RTT: + return true + default: + panic("unknown encryption level") + } +} + +func (p *frameParser) SetAckDelayExponent(exp uint8) { + p.ackDelayExponent = exp +} diff --git a/internal/quic-go/wire/frame_parser_test.go b/internal/quic-go/wire/frame_parser_test.go new file mode 100644 index 00000000..f46bafd8 --- /dev/null +++ b/internal/quic-go/wire/frame_parser_test.go @@ -0,0 +1,410 @@ +package wire + +import ( + "bytes" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Frame parsing", func() { + var ( + buf *bytes.Buffer + parser FrameParser + ) + + BeforeEach(func() { + buf = &bytes.Buffer{} + parser = NewFrameParser(true, protocol.Version1) + }) + + It("returns nil if there's nothing more to read", func() { + f, err := parser.ParseNext(bytes.NewReader(nil), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeNil()) + }) + + It("skips PADDING frames", func() { + buf.Write([]byte{0}) // PADDING frame + (&PingFrame{}).Write(buf, protocol.Version1) + f, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(Equal(&PingFrame{})) + }) + + It("handles PADDING at the end", func() { + r := bytes.NewReader([]byte{0, 0, 0}) + f, err := parser.ParseNext(r, protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeNil()) + Expect(r.Len()).To(BeZero()) + }) + + It("unpacks ACK frames", func() { + f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(BeAssignableToTypeOf(f)) + Expect(frame.(*AckFrame).LargestAcked()).To(Equal(protocol.PacketNumber(0x13))) + }) + + It("uses the custom ack delay exponent for 1RTT packets", func() { + parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: time.Second, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + // The ACK frame is always written using the protocol.AckDelayExponent. + // That's why we expect a different value when parsing. + Expect(frame.(*AckFrame).DelayTime).To(Equal(4 * time.Second)) + }) + + It("uses the default ack delay exponent for non-1RTT packets", func() { + parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: time.Second, + } + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.(*AckFrame).DelayTime).To(Equal(time.Second)) + }) + + It("unpacks RESET_STREAM frames", func() { + f := &ResetStreamFrame{ + StreamID: 0xdeadbeef, + FinalSize: 0xdecafbad1234, + ErrorCode: 0x1337, + } + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks STOP_SENDING frames", func() { + f := &StopSendingFrame{StreamID: 0x42} + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks CRYPTO frames", func() { + f := &CryptoFrame{ + Offset: 0x1337, + Data: []byte("lorem ipsum"), + } + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks NEW_TOKEN frames", func() { + f := &NewTokenFrame{Token: []byte("foobar")} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks STREAM frames", func() { + f := &StreamFrame{ + StreamID: 0x42, + Offset: 0x1337, + Fin: true, + Data: []byte("foobar"), + } + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks MAX_DATA frames", func() { + f := &MaxDataFrame{ + MaximumData: 0xcafe, + } + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks MAX_STREAM_DATA frames", func() { + f := &MaxStreamDataFrame{ + StreamID: 0xdeadbeef, + MaximumStreamData: 0xdecafbad, + } + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks MAX_STREAMS frames", func() { + f := &MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: 0x1337, + } + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks DATA_BLOCKED frames", func() { + f := &DataBlockedFrame{MaximumData: 0x1234} + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks STREAM_DATA_BLOCKED frames", func() { + f := &StreamDataBlockedFrame{ + StreamID: 0xdeadbeef, + MaximumStreamData: 0xdead, + } + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks STREAMS_BLOCKED frames", func() { + f := &StreamsBlockedFrame{ + Type: protocol.StreamTypeBidi, + StreamLimit: 0x1234567, + } + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks NEW_CONNECTION_ID frames", func() { + f := &NewConnectionIDFrame{ + SequenceNumber: 0x1337, + ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + } + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks RETIRE_CONNECTION_ID frames", func() { + f := &RetireConnectionIDFrame{SequenceNumber: 0x1337} + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks PATH_CHALLENGE frames", func() { + f := &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(BeAssignableToTypeOf(f)) + Expect(frame.(*PathChallengeFrame).Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) + }) + + It("unpacks PATH_RESPONSE frames", func() { + f := &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).ToNot(BeNil()) + Expect(frame).To(BeAssignableToTypeOf(f)) + Expect(frame.(*PathResponseFrame).Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) + }) + + It("unpacks CONNECTION_CLOSE frames", func() { + f := &ConnectionCloseFrame{ + IsApplicationError: true, + ReasonPhrase: "foobar", + } + buf := &bytes.Buffer{} + err := f.Write(buf, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks HANDSHAKE_DONE frames", func() { + f := &HandshakeDoneFrame{} + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("unpacks DATAGRAM frames", func() { + f := &DatagramFrame{Data: []byte("foobar")} + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when DATAGRAM frames are not supported", func() { + parser = NewFrameParser(false, protocol.Version1) + f := &DatagramFrame{Data: []byte("foobar")} + buf := &bytes.Buffer{} + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + _, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, + FrameType: 0x30, + ErrorMessage: "unknown frame type", + })) + }) + + It("errors on invalid type", func() { + _, err := parser.ParseNext(bytes.NewReader([]byte{0x42}), protocol.Encryption1RTT) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, + FrameType: 0x42, + ErrorMessage: "unknown frame type", + })) + }) + + It("errors on invalid frames", func() { + f := &MaxStreamDataFrame{ + StreamID: 0x1337, + MaximumStreamData: 0xdeadbeef, + } + b := &bytes.Buffer{} + f.Write(b, protocol.Version1) + _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2]), protocol.Encryption1RTT) + Expect(err).To(HaveOccurred()) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + }) + + Context("encryption level check", func() { + frames := []Frame{ + &PingFrame{}, + &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 42}}}, + &ResetStreamFrame{}, + &StopSendingFrame{}, + &CryptoFrame{}, + &NewTokenFrame{Token: []byte("lorem ipsum")}, + &StreamFrame{Data: []byte("foobar")}, + &MaxDataFrame{}, + &MaxStreamDataFrame{}, + &MaxStreamsFrame{}, + &DataBlockedFrame{}, + &StreamDataBlockedFrame{}, + &StreamsBlockedFrame{}, + &NewConnectionIDFrame{ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}}, + &RetireConnectionIDFrame{}, + &PathChallengeFrame{}, + &PathResponseFrame{}, + &ConnectionCloseFrame{}, + &HandshakeDoneFrame{}, + &DatagramFrame{}, + } + + var framesSerialized [][]byte + + BeforeEach(func() { + framesSerialized = nil + for _, frame := range frames { + buf := &bytes.Buffer{} + Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) + framesSerialized = append(framesSerialized, buf.Bytes()) + } + }) + + It("rejects all frames but ACK, CRYPTO, PING and CONNECTION_CLOSE in Initial packets", func() { + for i, b := range framesSerialized { + _, err := parser.ParseNext(bytes.NewReader(b), protocol.EncryptionInitial) + switch frames[i].(type) { + case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *PingFrame: + Expect(err).ToNot(HaveOccurred()) + default: + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level Initial")) + } + } + }) + + It("rejects all frames but ACK, CRYPTO, PING and CONNECTION_CLOSE in Handshake packets", func() { + for i, b := range framesSerialized { + _, err := parser.ParseNext(bytes.NewReader(b), protocol.EncryptionHandshake) + switch frames[i].(type) { + case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *PingFrame: + Expect(err).ToNot(HaveOccurred()) + default: + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level Handshake")) + } + } + }) + + It("rejects all frames but ACK, CRYPTO, CONNECTION_CLOSE, NEW_TOKEN, PATH_RESPONSE and RETIRE_CONNECTION_ID in 0-RTT packets", func() { + for i, b := range framesSerialized { + _, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption0RTT) + switch frames[i].(type) { + case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level 0-RTT")) + default: + Expect(err).ToNot(HaveOccurred()) + } + } + }) + + It("accepts all frame types in 1-RTT packets", func() { + for _, b := range framesSerialized { + _, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + } + }) + }) +}) diff --git a/internal/quic-go/wire/handshake_done_frame.go b/internal/quic-go/wire/handshake_done_frame.go new file mode 100644 index 00000000..d940ddda --- /dev/null +++ b/internal/quic-go/wire/handshake_done_frame.go @@ -0,0 +1,28 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// A HandshakeDoneFrame is a HANDSHAKE_DONE frame +type HandshakeDoneFrame struct{} + +// ParseHandshakeDoneFrame parses a HandshakeDone frame +func parseHandshakeDoneFrame(r *bytes.Reader, _ protocol.VersionNumber) (*HandshakeDoneFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + return &HandshakeDoneFrame{}, nil +} + +func (f *HandshakeDoneFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x1e) + return nil +} + +// Length of a written frame +func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 +} diff --git a/internal/quic-go/wire/header.go b/internal/quic-go/wire/header.go new file mode 100644 index 00000000..8455e748 --- /dev/null +++ b/internal/quic-go/wire/header.go @@ -0,0 +1,274 @@ +package wire + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +// ParseConnectionID parses the destination connection ID of a packet. +// It uses the data slice for the connection ID. +// That means that the connection ID must not be used after the packet buffer is released. +func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { + if len(data) == 0 { + return nil, io.EOF + } + isLongHeader := data[0]&0x80 > 0 + if !isLongHeader { + if len(data) < shortHeaderConnIDLen+1 { + return nil, io.EOF + } + return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil + } + if len(data) < 6 { + return nil, io.EOF + } + destConnIDLen := int(data[5]) + if len(data) < 6+destConnIDLen { + return nil, io.EOF + } + return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil +} + +// IsVersionNegotiationPacket says if this is a version negotiation packet +func IsVersionNegotiationPacket(b []byte) bool { + if len(b) < 5 { + return false + } + return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 +} + +// Is0RTTPacket says if this is a 0-RTT packet. +// A packet sent with a version we don't understand can never be a 0-RTT packet. +func Is0RTTPacket(b []byte) bool { + if len(b) < 5 { + return false + } + if b[0]&0x80 == 0 { + return false + } + version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5])) + if !protocol.IsSupportedVersion(protocol.SupportedVersions, version) { + return false + } + if version == protocol.Version2 { + return b[0]>>4&0b11 == 0b10 + } + return b[0]>>4&0b11 == 0b01 +} + +var ErrUnsupportedVersion = errors.New("unsupported version") + +// The Header is the version independent part of the header +type Header struct { + IsLongHeader bool + typeByte byte + Type protocol.PacketType + + Version protocol.VersionNumber + SrcConnectionID protocol.ConnectionID + DestConnectionID protocol.ConnectionID + + Length protocol.ByteCount + + Token []byte + + parsedLen protocol.ByteCount // how many bytes were read while parsing this header +} + +// ParsePacket parses a packet. +// If the packet has a long header, the packet is cut according to the length field. +// If we understand the version, the packet is header up unto the packet number. +// Otherwise, only the invariant part of the header is parsed. +func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) { + hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen) + if err != nil { + if err == ErrUnsupportedVersion { + return hdr, nil, nil, ErrUnsupportedVersion + } + return nil, nil, nil, err + } + var rest []byte + if hdr.IsLongHeader { + if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length { + return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) + } + packetLen := int(hdr.ParsedLen() + hdr.Length) + rest = data[packetLen:] + data = data[:packetLen] + } + return hdr, data, rest, nil +} + +// ParseHeader parses the header. +// For short header packets: up to the packet number. +// For long header packets: +// * if we understand the version: up to the packet number +// * if not, only the invariant part of the header +func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { + startLen := b.Len() + h, err := parseHeaderImpl(b, shortHeaderConnIDLen) + if err != nil { + return h, err + } + h.parsedLen = protocol.ByteCount(startLen - b.Len()) + return h, err +} + +func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { + typeByte, err := b.ReadByte() + if err != nil { + return nil, err + } + + h := &Header{ + typeByte: typeByte, + IsLongHeader: typeByte&0x80 > 0, + } + + if !h.IsLongHeader { + if h.typeByte&0x40 == 0 { + return nil, errors.New("not a QUIC packet") + } + if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil { + return nil, err + } + return h, nil + } + return h, h.parseLongHeader(b) +} + +func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error { + var err error + h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen) + return err +} + +func (h *Header) parseLongHeader(b *bytes.Reader) error { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return err + } + h.Version = protocol.VersionNumber(v) + if h.Version != 0 && h.typeByte&0x40 == 0 { + return errors.New("not a QUIC packet") + } + destConnIDLen, err := b.ReadByte() + if err != nil { + return err + } + h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen)) + if err != nil { + return err + } + srcConnIDLen, err := b.ReadByte() + if err != nil { + return err + } + h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen)) + if err != nil { + return err + } + if h.Version == 0 { // version negotiation packet + return nil + } + // If we don't understand the version, we have no idea how to interpret the rest of the bytes + if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) { + return ErrUnsupportedVersion + } + + if h.Version == protocol.Version2 { + switch h.typeByte >> 4 & 0b11 { + case 0b00: + h.Type = protocol.PacketTypeRetry + case 0b01: + h.Type = protocol.PacketTypeInitial + case 0b10: + h.Type = protocol.PacketType0RTT + case 0b11: + h.Type = protocol.PacketTypeHandshake + } + } else { + switch h.typeByte >> 4 & 0b11 { + case 0b00: + h.Type = protocol.PacketTypeInitial + case 0b01: + h.Type = protocol.PacketType0RTT + case 0b10: + h.Type = protocol.PacketTypeHandshake + case 0b11: + h.Type = protocol.PacketTypeRetry + } + } + + if h.Type == protocol.PacketTypeRetry { + tokenLen := b.Len() - 16 + if tokenLen <= 0 { + return io.EOF + } + h.Token = make([]byte, tokenLen) + if _, err := io.ReadFull(b, h.Token); err != nil { + return err + } + _, err := b.Seek(16, io.SeekCurrent) + return err + } + + if h.Type == protocol.PacketTypeInitial { + tokenLen, err := quicvarint.Read(b) + if err != nil { + return err + } + if tokenLen > uint64(b.Len()) { + return io.EOF + } + h.Token = make([]byte, tokenLen) + if _, err := io.ReadFull(b, h.Token); err != nil { + return err + } + } + + pl, err := quicvarint.Read(b) + if err != nil { + return err + } + h.Length = protocol.ByteCount(pl) + return nil +} + +// ParsedLen returns the number of bytes that were consumed when parsing the header +func (h *Header) ParsedLen() protocol.ByteCount { + return h.parsedLen +} + +// ParseExtended parses the version dependent part of the header. +// The Reader has to be set such that it points to the first byte of the header. +func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) { + extHdr := h.toExtendedHeader() + reservedBitsValid, err := extHdr.parse(b, ver) + if err != nil { + return nil, err + } + if !reservedBitsValid { + return extHdr, ErrInvalidReservedBits + } + return extHdr, nil +} + +func (h *Header) toExtendedHeader() *ExtendedHeader { + return &ExtendedHeader{Header: *h} +} + +// PacketType is the type of the packet, for logging purposes +func (h *Header) PacketType() string { + if h.IsLongHeader { + return h.Type.String() + } + return "1-RTT" +} diff --git a/internal/quic-go/wire/header_test.go b/internal/quic-go/wire/header_test.go new file mode 100644 index 00000000..cdcc08b3 --- /dev/null +++ b/internal/quic-go/wire/header_test.go @@ -0,0 +1,583 @@ +package wire + +import ( + "bytes" + "encoding/binary" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Header Parsing", func() { + Context("Parsing the Connection ID", func() { + It("parses the connection ID of a long header packet", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, + Version: protocol.Version1, + }, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + connID, err := ParseConnectionID(buf.Bytes(), 8) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + }) + + It("parses the connection ID of a short header packet", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + }, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + buf.Write([]byte("foobar")) + connID, err := ParseConnectionID(buf.Bytes(), 4) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + }) + + It("errors on EOF, for short header packets", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + }, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + data := buf.Bytes()[:buf.Len()-2] // cut the packet number + _, err := ParseConnectionID(data, 8) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < len(data); i++ { + b := make([]byte, i) + copy(b, data[:i]) + _, err := ParseConnectionID(b, 8) + Expect(err).To(MatchError(io.EOF)) + } + }) + + It("errors on EOF, for long header packets", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 8, 9}, + Version: protocol.Version1, + }, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + data := buf.Bytes()[:buf.Len()-2] // cut the packet number + _, err := ParseConnectionID(data, 8) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ { + b := make([]byte, i) + copy(b, data[:i]) + _, err := ParseConnectionID(b, 8) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("identifying 0-RTT packets", func() { + It("recognizes 0-RTT packets, for QUIC v1", func() { + zeroRTTHeader := make([]byte, 5) + zeroRTTHeader[0] = 0x80 | 0b01<<4 + binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version1)) + + Expect(Is0RTTPacket(zeroRTTHeader)).To(BeTrue()) + Expect(Is0RTTPacket(zeroRTTHeader[:4])).To(BeFalse()) // too short + Expect(Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})).To(BeFalse()) // unknown version + Expect(Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})).To(BeFalse()) // short header + Expect(Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))).To(BeTrue()) + }) + + It("recognizes 0-RTT packets, for QUIC v2", func() { + zeroRTTHeader := make([]byte, 5) + zeroRTTHeader[0] = 0x80 | 0b10<<4 + binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version2)) + + Expect(Is0RTTPacket(zeroRTTHeader)).To(BeTrue()) + Expect(Is0RTTPacket(zeroRTTHeader[:4])).To(BeFalse()) // too short + Expect(Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})).To(BeFalse()) // unknown version + Expect(Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})).To(BeFalse()) // short header + Expect(Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))).To(BeTrue()) + }) + }) + + Context("Identifying Version Negotiation Packets", func() { + It("identifies version negotiation packets", func() { + Expect(IsVersionNegotiationPacket([]byte{0x80 | 0x56, 0, 0, 0, 0})).To(BeTrue()) + Expect(IsVersionNegotiationPacket([]byte{0x56, 0, 0, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 1, 0, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 1, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 1, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 0, 1})).To(BeFalse()) + }) + + It("returns false on EOF", func() { + vnp := []byte{0x80, 0, 0, 0, 0} + for i := range vnp { + Expect(IsVersionNegotiationPacket(vnp[:i])).To(BeFalse()) + } + }) + }) + + Context("Long Headers", func() { + It("parses a Long Header", func() { + destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} + srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + data := []byte{0xc0 ^ 0x3} + data = appendVersion(data, protocol.Version1) + data = append(data, 0x9) // dest conn id length + data = append(data, destConnID...) + data = append(data, 0x4) // src conn id length + data = append(data, srcConnID...) + data = append(data, encodeVarInt(6)...) // token length + data = append(data, []byte("foobar")...) // token + data = append(data, encodeVarInt(10)...) // length + hdrLen := len(data) + data = append(data, []byte{0, 0, 0xbe, 0xef}...) // packet number + data = append(data, []byte("foobar")...) + Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) + + hdr, pdata, rest, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(pdata).To(Equal(data)) + Expect(hdr.IsLongHeader).To(BeTrue()) + Expect(hdr.DestConnectionID).To(Equal(destConnID)) + Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) + Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(hdr.Token).To(Equal([]byte("foobar"))) + Expect(hdr.Length).To(Equal(protocol.ByteCount(10))) + Expect(hdr.Version).To(Equal(protocol.Version1)) + Expect(rest).To(BeEmpty()) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef))) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(b.Len()).To(Equal(6)) // foobar + Expect(hdr.ParsedLen()).To(BeEquivalentTo(hdrLen)) + Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 4)) + }) + + It("errors if 0x40 is not set", func() { + data := []byte{ + 0x80 | 0x2<<4, + 0x11, // connection ID lengths + 0xde, 0xca, 0xfb, 0xad, // dest conn ID + 0xde, 0xad, 0xbe, 0xef, // src conn ID + } + _, _, _, err := ParsePacket(data, 0) + Expect(err).To(MatchError("not a QUIC packet")) + }) + + It("stops parsing when encountering an unsupported version", func() { + data := []byte{ + 0xc0, + 0xde, 0xad, 0xbe, 0xef, + 0x8, // dest conn ID len + 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, // dest conn ID + 0x8, // src conn ID len + 0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, // src conn ID + 'f', 'o', 'o', 'b', 'a', 'r', // unspecified bytes + } + hdr, _, rest, err := ParsePacket(data, 0) + Expect(err).To(MatchError(ErrUnsupportedVersion)) + Expect(hdr.IsLongHeader).To(BeTrue()) + Expect(hdr.Version).To(Equal(protocol.VersionNumber(0xdeadbeef))) + Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1})) + Expect(rest).To(BeEmpty()) + }) + + It("parses a Long Header without a destination connection ID", func() { + data := []byte{0xc0 ^ 0x1<<4} + data = appendVersion(data, protocol.Version1) + data = append(data, 0x0) // dest conn ID len + data = append(data, 0x4) // src conn ID len + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID + data = append(data, encodeVarInt(0)...) // length + data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) + Expect(hdr.DestConnectionID).To(BeEmpty()) + }) + + It("parses a Long Header without a source connection ID", func() { + data := []byte{0xc0 ^ 0x2<<4} + data = appendVersion(data, protocol.Version1) + data = append(data, 0xa) // dest conn ID len + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // dest connection ID + data = append(data, 0x0) // src conn ID len + data = append(data, encodeVarInt(0)...) // length + data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.SrcConnectionID).To(BeEmpty()) + Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + }) + + It("parses a Long Header with a 2 byte packet number", func() { + data := []byte{0xc0 ^ 0x1} + data = appendVersion(data, protocol.Version1) // version number + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths + data = append(data, encodeVarInt(0)...) // token length + data = append(data, encodeVarInt(0)...) // length + data = append(data, []byte{0x1, 0x23}...) + + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x123))) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) + Expect(b.Len()).To(BeZero()) + }) + + It("parses a Retry packet, for QUIC v1", func() { + data := []byte{0xc0 | 0b11<<4 | (10 - 3) /* connection ID length */} + data = appendVersion(data, protocol.Version1) + data = append(data, []byte{6}...) // dest conn ID len + data = append(data, []byte{6, 5, 4, 3, 2, 1}...) // dest conn ID + data = append(data, []byte{10}...) // src conn ID len + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID + data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token + data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...) + hdr, pdata, rest, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) + Expect(hdr.Version).To(Equal(protocol.Version1)) + Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(hdr.Token).To(Equal([]byte("foobar"))) + Expect(pdata).To(Equal(data)) + Expect(rest).To(BeEmpty()) + }) + + It("parses a Retry packet, for QUIC v2", func() { + data := []byte{0xc0 | 0b00<<4 | (10 - 3) /* connection ID length */} + data = appendVersion(data, protocol.Version2) + data = append(data, []byte{6}...) // dest conn ID len + data = append(data, []byte{6, 5, 4, 3, 2, 1}...) // dest conn ID + data = append(data, []byte{10}...) // src conn ID len + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID + data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token + data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...) + hdr, pdata, rest, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) + Expect(hdr.Version).To(Equal(protocol.Version2)) + Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(hdr.Token).To(Equal([]byte("foobar"))) + Expect(pdata).To(Equal(data)) + Expect(rest).To(BeEmpty()) + }) + + It("errors if the Retry packet is too short for the integrity tag", func() { + data := []byte{0xc0 | 0x3<<4 | (10 - 3) /* connection ID length */} + data = appendVersion(data, protocol.Version1) + data = append(data, []byte{0, 0}...) // conn ID lens + data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) + // this results in a token length of 0 + _, _, _, err := ParsePacket(data, 0) + Expect(err).To(MatchError(io.EOF)) + }) + + It("errors if the token length is too large", func() { + data := []byte{0xc0 ^ 0x1} + data = appendVersion(data, protocol.Version1) + data = append(data, 0x0) // connection ID lengths + data = append(data, encodeVarInt(4)...) // token length: 4 bytes (1 byte too long) + data = append(data, encodeVarInt(0x42)...) // length, 1 byte + data = append(data, []byte{0x12, 0x34}...) // packet number + + _, _, _, err := ParsePacket(data, 0) + Expect(err).To(MatchError(io.EOF)) + }) + + It("errors if the 5th or 6th bit are set", func() { + data := []byte{0xc0 | 0x2<<4 | 0x8 /* set the 5th bit */ | 0x1 /* 2 byte packet number */} + data = appendVersion(data, protocol.Version1) + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths + data = append(data, encodeVarInt(2)...) // length + data = append(data, []byte{0x12, 0x34}...) // packet number + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) + extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) + Expect(err).To(MatchError(ErrInvalidReservedBits)) + Expect(extHdr).ToNot(BeNil()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1234))) + }) + + It("errors on EOF, when parsing the header", func() { + data := []byte{0xc0 ^ 0x2<<4} + data = appendVersion(data, protocol.Version1) + data = append(data, 0x8) // dest conn ID len + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // dest conn ID + data = append(data, 0x8) // src conn ID len + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // src conn ID + for i := 0; i < len(data); i++ { + _, _, _, err := ParsePacket(data[:i], 0) + Expect(err).To(Equal(io.EOF)) + } + }) + + It("errors on EOF, when parsing the extended header", func() { + data := []byte{0xc0 | 0x2<<4 | 0x3} + data = appendVersion(data, protocol.Version1) + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths + data = append(data, encodeVarInt(0)...) // length + hdrLen := len(data) + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number + for i := hdrLen; i < len(data); i++ { + data = data[:i] + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) + _, err = hdr.ParseExtended(b, protocol.Version1) + Expect(err).To(Equal(io.EOF)) + } + }) + + It("errors on EOF, for a Retry packet", func() { + data := []byte{0xc0 ^ 0x3<<4} + data = appendVersion(data, protocol.Version1) + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths + data = append(data, 0xa) // Orig Destination Connection ID length + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID + hdrLen := len(data) + for i := hdrLen; i < len(data); i++ { + data = data[:i] + hdr, _, _, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) + _, err = hdr.ParseExtended(b, protocol.Version1) + Expect(err).To(Equal(io.EOF)) + } + }) + + Context("coalesced packets", func() { + It("cuts packets", func() { + buf := &bytes.Buffer{} + hdr := Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 2 + 6, + Version: protocol.Version1, + } + Expect((&ExtendedHeader{ + Header: hdr, + PacketNumber: 0x1337, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + hdrRaw := append([]byte{}, buf.Bytes()...) + buf.Write([]byte("foobar")) // payload of the first packet + buf.Write([]byte("raboof")) // second packet + parsedHdr, data, rest, err := ParsePacket(buf.Bytes(), 4) + Expect(err).ToNot(HaveOccurred()) + Expect(parsedHdr.Type).To(Equal(hdr.Type)) + Expect(parsedHdr.DestConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(data).To(Equal(append(hdrRaw, []byte("foobar")...))) + Expect(rest).To(Equal([]byte("raboof"))) + }) + + It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 3, + Version: protocol.Version1, + }, + PacketNumber: 0x1337, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + _, _, _, err := ParsePacket(buf.Bytes(), 4) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)")) + }) + + It("errors on packets that are smaller than the length in the packet header, for too small payload", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 1000, + Version: protocol.Version1, + }, + PacketNumber: 0x1337, + PacketNumberLen: 2, + }).Write(buf, protocol.Version1)).To(Succeed()) + buf.Write(make([]byte, 500-2 /* for packet number length */)) + _, _, _, err := ParsePacket(buf.Bytes(), 4) + Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) + }) + }) + }) + + Context("Short Headers", func() { + It("reads a Short Header with a 8 byte connection ID", func() { + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} + data := append([]byte{0x40}, connID...) + data = append(data, 0x42) // packet number + Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) + + hdr, pdata, rest, err := ParsePacket(data, 8) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsLongHeader).To(BeFalse()) + Expect(hdr.DestConnectionID).To(Equal(connID)) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) + Expect(extHdr.DestConnectionID).To(Equal(connID)) + Expect(extHdr.SrcConnectionID).To(BeEmpty()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) + Expect(hdr.ParsedLen()).To(BeEquivalentTo(len(data) - 1)) + Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 1)) + Expect(pdata).To(Equal(data)) + Expect(rest).To(BeEmpty()) + }) + + It("errors if 0x40 is not set", func() { + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} + data := append([]byte{0x0}, connID...) + _, _, _, err := ParsePacket(data, 8) + Expect(err).To(MatchError("not a QUIC packet")) + }) + + It("errors if the 4th or 5th bit are set", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5} + data := append([]byte{0x40 | 0x10 /* set the 4th bit */}, connID...) + data = append(data, 0x42) // packet number + hdr, _, _, err := ParsePacket(data, 5) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsLongHeader).To(BeFalse()) + extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) + Expect(err).To(MatchError(ErrInvalidReservedBits)) + Expect(extHdr).ToNot(BeNil()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) + }) + + It("reads a Short Header with a 5 byte connection ID", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5} + data := append([]byte{0x40}, connID...) + data = append(data, 0x42) // packet number + hdr, pdata, rest, err := ParsePacket(data, 5) + Expect(err).ToNot(HaveOccurred()) + Expect(pdata).To(HaveLen(len(data))) + Expect(hdr.IsLongHeader).To(BeFalse()) + Expect(hdr.DestConnectionID).To(Equal(connID)) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) + Expect(extHdr.DestConnectionID).To(Equal(connID)) + Expect(extHdr.SrcConnectionID).To(BeEmpty()) + Expect(rest).To(BeEmpty()) + }) + + It("reads the Key Phase Bit", func() { + data := []byte{ + 0x40 ^ 0x4, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID + } + data = append(data, 11) // packet number + hdr, _, _, err := ParsePacket(data, 6) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsLongHeader).To(BeFalse()) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseOne)) + Expect(b.Len()).To(BeZero()) + }) + + It("reads a header with a 2 byte packet number", func() { + data := []byte{ + 0x40 | 0x1, + 0xde, 0xad, 0xbe, 0xef, // connection ID + } + data = append(data, []byte{0x13, 0x37}...) // packet number + hdr, _, _, err := ParsePacket(data, 4) + Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.IsLongHeader).To(BeFalse()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) + Expect(b.Len()).To(BeZero()) + }) + + It("reads a header with a 3 byte packet number", func() { + data := []byte{ + 0x40 | 0x2, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x1, 0x2, 0x3, 0x4, // connection ID + } + data = append(data, []byte{0x99, 0xbe, 0xef}...) // packet number + hdr, _, _, err := ParsePacket(data, 10) + Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.IsLongHeader).To(BeFalse()) + Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef))) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen3)) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOF, when parsing the header", func() { + data := []byte{ + 0x40 ^ 0x2, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + } + for i := 0; i < len(data); i++ { + data = data[:i] + _, _, _, err := ParsePacket(data, 8) + Expect(err).To(Equal(io.EOF)) + } + }) + + It("errors on EOF, when parsing the extended header", func() { + data := []byte{ + 0x40 ^ 0x3, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID + } + hdrLen := len(data) + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number + for i := hdrLen; i < len(data); i++ { + data = data[:i] + hdr, _, _, err := ParsePacket(data, 6) + Expect(err).ToNot(HaveOccurred()) + _, err = hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) + Expect(err).To(Equal(io.EOF)) + } + }) + }) + + It("tells its packet type for logging", func() { + Expect((&Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}).PacketType()).To(Equal("Handshake")) + Expect((&Header{}).PacketType()).To(Equal("1-RTT")) + }) +}) diff --git a/internal/quic-go/wire/interface.go b/internal/quic-go/wire/interface.go new file mode 100644 index 00000000..b5804af9 --- /dev/null +++ b/internal/quic-go/wire/interface.go @@ -0,0 +1,19 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// A Frame in QUIC +type Frame interface { + Write(b *bytes.Buffer, version protocol.VersionNumber) error + Length(version protocol.VersionNumber) protocol.ByteCount +} + +// A FrameParser parses QUIC frames, one by one. +type FrameParser interface { + ParseNext(*bytes.Reader, protocol.EncryptionLevel) (Frame, error) + SetAckDelayExponent(uint8) +} diff --git a/internal/quic-go/wire/log.go b/internal/quic-go/wire/log.go new file mode 100644 index 00000000..030549d2 --- /dev/null +++ b/internal/quic-go/wire/log.go @@ -0,0 +1,72 @@ +package wire + +import ( + "fmt" + "strings" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +// LogFrame logs a frame, either sent or received +func LogFrame(logger utils.Logger, frame Frame, sent bool) { + if !logger.Debug() { + return + } + dir := "<-" + if sent { + dir = "->" + } + switch f := frame.(type) { + case *CryptoFrame: + dataLen := protocol.ByteCount(len(f.Data)) + logger.Debugf("\t%s &wire.CryptoFrame{Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.Offset, dataLen, f.Offset+dataLen) + case *StreamFrame: + logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, Fin: %t, Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.StreamID, f.Fin, f.Offset, f.DataLen(), f.Offset+f.DataLen()) + case *ResetStreamFrame: + logger.Debugf("\t%s &wire.ResetStreamFrame{StreamID: %d, ErrorCode: %#x, FinalSize: %d}", dir, f.StreamID, f.ErrorCode, f.FinalSize) + case *AckFrame: + hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 + var ecn string + if hasECN { + ecn = fmt.Sprintf(", ECT0: %d, ECT1: %d, CE: %d", f.ECT0, f.ECT1, f.ECNCE) + } + if len(f.AckRanges) > 1 { + ackRanges := make([]string, len(f.AckRanges)) + for i, r := range f.AckRanges { + ackRanges[i] = fmt.Sprintf("{Largest: %d, Smallest: %d}", r.Largest, r.Smallest) + } + logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, AckRanges: {%s}, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), strings.Join(ackRanges, ", "), f.DelayTime.String(), ecn) + } else { + logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), f.DelayTime.String(), ecn) + } + case *MaxDataFrame: + logger.Debugf("\t%s &wire.MaxDataFrame{MaximumData: %d}", dir, f.MaximumData) + case *MaxStreamDataFrame: + logger.Debugf("\t%s &wire.MaxStreamDataFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) + case *DataBlockedFrame: + logger.Debugf("\t%s &wire.DataBlockedFrame{MaximumData: %d}", dir, f.MaximumData) + case *StreamDataBlockedFrame: + logger.Debugf("\t%s &wire.StreamDataBlockedFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) + case *MaxStreamsFrame: + switch f.Type { + case protocol.StreamTypeUni: + logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: uni, MaxStreamNum: %d}", dir, f.MaxStreamNum) + case protocol.StreamTypeBidi: + logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: %d}", dir, f.MaxStreamNum) + } + case *StreamsBlockedFrame: + switch f.Type { + case protocol.StreamTypeUni: + logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: uni, MaxStreams: %d}", dir, f.StreamLimit) + case protocol.StreamTypeBidi: + logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit) + } + case *NewConnectionIDFrame: + logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken) + case *NewTokenFrame: + logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token) + default: + logger.Debugf("\t%s %#v", dir, frame) + } +} diff --git a/internal/quic-go/wire/log_test.go b/internal/quic-go/wire/log_test.go new file mode 100644 index 00000000..38e7b645 --- /dev/null +++ b/internal/quic-go/wire/log_test.go @@ -0,0 +1,168 @@ +package wire + +import ( + "bytes" + "log" + "os" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Frame logging", func() { + var ( + buf *bytes.Buffer + logger utils.Logger + ) + + BeforeEach(func() { + buf = &bytes.Buffer{} + logger = utils.DefaultLogger + logger.SetLogLevel(utils.LogLevelDebug) + log.SetOutput(buf) + }) + + AfterEach(func() { + log.SetOutput(os.Stdout) + }) + + It("doesn't log when debug is disabled", func() { + logger.SetLogLevel(utils.LogLevelInfo) + LogFrame(logger, &ResetStreamFrame{}, true) + Expect(buf.Len()).To(BeZero()) + }) + + It("logs sent frames", func() { + LogFrame(logger, &ResetStreamFrame{}, true) + Expect(buf.String()).To(ContainSubstring("\t-> &wire.ResetStreamFrame{StreamID: 0, ErrorCode: 0x0, FinalSize: 0}\n")) + }) + + It("logs received frames", func() { + LogFrame(logger, &ResetStreamFrame{}, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.ResetStreamFrame{StreamID: 0, ErrorCode: 0x0, FinalSize: 0}\n")) + }) + + It("logs CRYPTO frames", func() { + frame := &CryptoFrame{ + Offset: 42, + Data: make([]byte, 123), + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.CryptoFrame{Offset: 42, Data length: 123, Offset + Data length: 165}\n")) + }) + + It("logs STREAM frames", func() { + frame := &StreamFrame{ + StreamID: 42, + Offset: 1337, + Data: bytes.Repeat([]byte{'f'}, 100), + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamFrame{StreamID: 42, Fin: false, Offset: 1337, Data length: 100, Offset + Data length: 1437}\n")) + }) + + It("logs ACK frames without missing packets", func() { + frame := &AckFrame{ + AckRanges: []AckRange{{Smallest: 42, Largest: 1337}}, + DelayTime: 1 * time.Millisecond, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 1337, LowestAcked: 42, DelayTime: 1ms}\n")) + }) + + It("logs ACK frames with ECN", func() { + frame := &AckFrame{ + AckRanges: []AckRange{{Smallest: 42, Largest: 1337}}, + DelayTime: 1 * time.Millisecond, + ECT0: 5, + ECT1: 66, + ECNCE: 777, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 1337, LowestAcked: 42, DelayTime: 1ms, ECT0: 5, ECT1: 66, CE: 777}\n")) + }) + + It("logs ACK frames with missing packets", func() { + frame := &AckFrame{ + AckRanges: []AckRange{ + {Smallest: 5, Largest: 8}, + {Smallest: 2, Largest: 3}, + }, + DelayTime: 12 * time.Millisecond, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 8, LowestAcked: 2, AckRanges: {{Largest: 8, Smallest: 5}, {Largest: 3, Smallest: 2}}, DelayTime: 12ms}\n")) + }) + + It("logs MAX_STREAMS frames", func() { + frame := &MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: 42, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: 42}\n")) + }) + + It("logs MAX_DATA frames", func() { + frame := &MaxDataFrame{ + MaximumData: 42, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxDataFrame{MaximumData: 42}\n")) + }) + + It("logs MAX_STREAM_DATA frames", func() { + frame := &MaxStreamDataFrame{ + StreamID: 10, + MaximumStreamData: 42, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 42}\n")) + }) + + It("logs DATA_BLOCKED frames", func() { + frame := &DataBlockedFrame{ + MaximumData: 1000, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.DataBlockedFrame{MaximumData: 1000}\n")) + }) + + It("logs STREAM_DATA_BLOCKED frames", func() { + frame := &StreamDataBlockedFrame{ + StreamID: 42, + MaximumStreamData: 1000, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamDataBlockedFrame{StreamID: 42, MaximumStreamData: 1000}\n")) + }) + + It("logs STREAMS_BLOCKED frames", func() { + frame := &StreamsBlockedFrame{ + Type: protocol.StreamTypeBidi, + StreamLimit: 42, + } + LogFrame(logger, frame, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: 42}\n")) + }) + + It("logs NEW_CONNECTION_ID frames", func() { + LogFrame(logger, &NewConnectionIDFrame{ + SequenceNumber: 42, + ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + StatelessResetToken: protocol.StatelessResetToken{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10}, + }, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.NewConnectionIDFrame{SequenceNumber: 42, ConnectionID: deadbeef, StatelessResetToken: 0x0102030405060708090a0b0c0d0e0f10}")) + }) + + It("logs NEW_TOKEN frames", func() { + LogFrame(logger, &NewTokenFrame{ + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + }, true) + Expect(buf.String()).To(ContainSubstring("\t-> &wire.NewTokenFrame{Token: 0xdeadbeef")) + }) +}) diff --git a/internal/quic-go/wire/max_data_frame.go b/internal/quic-go/wire/max_data_frame.go new file mode 100644 index 00000000..cfa54d7c --- /dev/null +++ b/internal/quic-go/wire/max_data_frame.go @@ -0,0 +1,40 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A MaxDataFrame carries flow control information for the connection +type MaxDataFrame struct { + MaximumData protocol.ByteCount +} + +// parseMaxDataFrame parses a MAX_DATA frame +func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + frame := &MaxDataFrame{} + byteOffset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + frame.MaximumData = protocol.ByteCount(byteOffset) + return frame, nil +} + +// Write writes a MAX_STREAM_DATA frame +func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x10) + quicvarint.Write(b, uint64(f.MaximumData)) + return nil +} + +// Length of a written frame +func (f *MaxDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.MaximumData)) +} diff --git a/internal/quic-go/wire/max_data_frame_test.go b/internal/quic-go/wire/max_data_frame_test.go new file mode 100644 index 00000000..73f6a452 --- /dev/null +++ b/internal/quic-go/wire/max_data_frame_test.go @@ -0,0 +1,57 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("MAX_DATA frame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + data := []byte{0x10} + data = append(data, encodeVarInt(0xdecafbad123456)...) // byte offset + b := bytes.NewReader(data) + frame, err := parseMaxDataFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0xdecafbad123456))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x10} + data = append(data, encodeVarInt(0xdecafbad1234567)...) // byte offset + _, err := parseMaxDataFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseMaxDataFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("writing", func() { + It("has proper min length", func() { + f := &MaxDataFrame{ + MaximumData: 0xdeadbeef, + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0xdeadbeef))) + }) + + It("writes a MAX_DATA frame", func() { + b := &bytes.Buffer{} + f := &MaxDataFrame{ + MaximumData: 0xdeadbeefcafe, + } + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x10} + expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + }) +}) diff --git a/internal/quic-go/wire/max_stream_data_frame.go b/internal/quic-go/wire/max_stream_data_frame.go new file mode 100644 index 00000000..5c6f37d0 --- /dev/null +++ b/internal/quic-go/wire/max_stream_data_frame.go @@ -0,0 +1,46 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A MaxStreamDataFrame is a MAX_STREAM_DATA frame +type MaxStreamDataFrame struct { + StreamID protocol.StreamID + MaximumStreamData protocol.ByteCount +} + +func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamDataFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + sid, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + return &MaxStreamDataFrame{ + StreamID: protocol.StreamID(sid), + MaximumStreamData: protocol.ByteCount(offset), + }, nil +} + +func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x11) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.MaximumStreamData)) + return nil +} + +// Length of a written frame +func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) +} diff --git a/internal/quic-go/wire/max_stream_data_frame_test.go b/internal/quic-go/wire/max_stream_data_frame_test.go new file mode 100644 index 00000000..4d8e6fd8 --- /dev/null +++ b/internal/quic-go/wire/max_stream_data_frame_test.go @@ -0,0 +1,63 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("MAX_STREAM_DATA frame", func() { + Context("parsing", func() { + It("accepts sample frame", func() { + data := []byte{0x11} + data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID + data = append(data, encodeVarInt(0x12345678)...) // Offset + b := bytes.NewReader(data) + frame, err := parseMaxStreamDataFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0x12345678))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x11} + data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID + data = append(data, encodeVarInt(0x12345678)...) // Offset + _, err := parseMaxStreamDataFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseMaxStreamDataFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("writing", func() { + It("has proper min length", func() { + f := &MaxStreamDataFrame{ + StreamID: 0x1337, + MaximumStreamData: 0xdeadbeef, + } + Expect(f.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)))) + }) + + It("writes a sample frame", func() { + b := &bytes.Buffer{} + f := &MaxStreamDataFrame{ + StreamID: 0xdecafbad, + MaximumStreamData: 0xdeadbeefcafe42, + } + expected := []byte{0x11} + expected = append(expected, encodeVarInt(0xdecafbad)...) + expected = append(expected, encodeVarInt(0xdeadbeefcafe42)...) + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal(expected)) + }) + }) +}) diff --git a/internal/quic-go/wire/max_streams_frame.go b/internal/quic-go/wire/max_streams_frame.go new file mode 100644 index 00000000..0681fa24 --- /dev/null +++ b/internal/quic-go/wire/max_streams_frame.go @@ -0,0 +1,55 @@ +package wire + +import ( + "bytes" + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A MaxStreamsFrame is a MAX_STREAMS frame +type MaxStreamsFrame struct { + Type protocol.StreamType + MaxStreamNum protocol.StreamNum +} + +func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamsFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &MaxStreamsFrame{} + switch typeByte { + case 0x12: + f.Type = protocol.StreamTypeBidi + case 0x13: + f.Type = protocol.StreamTypeUni + } + streamID, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.MaxStreamNum = protocol.StreamNum(streamID) + if f.MaxStreamNum > protocol.MaxStreamCount { + return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum) + } + return f, nil +} + +func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + switch f.Type { + case protocol.StreamTypeBidi: + b.WriteByte(0x12) + case protocol.StreamTypeUni: + b.WriteByte(0x13) + } + quicvarint.Write(b, uint64(f.MaxStreamNum)) + return nil +} + +// Length of a written frame +func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.MaxStreamNum)) +} diff --git a/internal/quic-go/wire/max_streams_frame_test.go b/internal/quic-go/wire/max_streams_frame_test.go new file mode 100644 index 00000000..114b534d --- /dev/null +++ b/internal/quic-go/wire/max_streams_frame_test.go @@ -0,0 +1,107 @@ +package wire + +import ( + "bytes" + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("MAX_STREAMS frame", func() { + Context("parsing", func() { + It("accepts a frame for a bidirectional stream", func() { + data := []byte{0x12} + data = append(data, encodeVarInt(0xdecaf)...) + b := bytes.NewReader(data) + f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Type).To(Equal(protocol.StreamTypeBidi)) + Expect(f.MaxStreamNum).To(BeEquivalentTo(0xdecaf)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts a frame for a bidirectional stream", func() { + data := []byte{0x13} + data = append(data, encodeVarInt(0xdecaf)...) + b := bytes.NewReader(data) + f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Type).To(Equal(protocol.StreamTypeUni)) + Expect(f.MaxStreamNum).To(BeEquivalentTo(0xdecaf)) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x1d} + data = append(data, encodeVarInt(0xdeadbeefcafe13)...) + _, err := parseMaxStreamsFrame(bytes.NewReader(data), protocol.VersionWhatever) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseMaxStreamsFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) + } + }) + + for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { + streamType := t + + It("accepts a frame containing the maximum stream count", func() { + f := &MaxStreamsFrame{ + Type: streamType, + MaxStreamNum: protocol.MaxStreamCount, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + frame, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when receiving a too large stream count", func() { + f := &MaxStreamsFrame{ + Type: streamType, + MaxStreamNum: protocol.MaxStreamCount + 1, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + _, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) + }) + } + }) + + Context("writing", func() { + It("for a bidirectional stream", func() { + f := &MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: 0xdeadbeef, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x12} + expected = append(expected, encodeVarInt(0xdeadbeef)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("for a unidirectional stream", func() { + f := &MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreamNum: 0xdecafbad, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x13} + expected = append(expected, encodeVarInt(0xdecafbad)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := MaxStreamsFrame{MaxStreamNum: 0x1337} + Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(0x1337))) + }) + }) +}) diff --git a/internal/quic-go/wire/new_connection_id_frame.go b/internal/quic-go/wire/new_connection_id_frame.go new file mode 100644 index 00000000..9eb1fcbc --- /dev/null +++ b/internal/quic-go/wire/new_connection_id_frame.go @@ -0,0 +1,80 @@ +package wire + +import ( + "bytes" + "fmt" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A NewConnectionIDFrame is a NEW_CONNECTION_ID frame +type NewConnectionIDFrame struct { + SequenceNumber uint64 + RetirePriorTo uint64 + ConnectionID protocol.ConnectionID + StatelessResetToken protocol.StatelessResetToken +} + +func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewConnectionIDFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + seq, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + ret, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if ret > seq { + //nolint:stylecheck + return nil, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq) + } + connIDLen, err := r.ReadByte() + if err != nil { + return nil, err + } + if connIDLen > protocol.MaxConnIDLen { + return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen) + } + connID, err := protocol.ReadConnectionID(r, int(connIDLen)) + if err != nil { + return nil, err + } + frame := &NewConnectionIDFrame{ + SequenceNumber: seq, + RetirePriorTo: ret, + ConnectionID: connID, + } + if _, err := io.ReadFull(r, frame.StatelessResetToken[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + + return frame, nil +} + +func (f *NewConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x18) + quicvarint.Write(b, f.SequenceNumber) + quicvarint.Write(b, f.RetirePriorTo) + connIDLen := f.ConnectionID.Len() + if connIDLen > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d", connIDLen) + } + b.WriteByte(uint8(connIDLen)) + b.Write(f.ConnectionID.Bytes()) + b.Write(f.StatelessResetToken[:]) + return nil +} + +// Length of a written frame +func (f *NewConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16 +} diff --git a/internal/quic-go/wire/new_connection_id_frame_test.go b/internal/quic-go/wire/new_connection_id_frame_test.go new file mode 100644 index 00000000..776b9670 --- /dev/null +++ b/internal/quic-go/wire/new_connection_id_frame_test.go @@ -0,0 +1,104 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("NEW_CONNECTION_ID frame", func() { + Context("when parsing", func() { + It("accepts a sample frame", func() { + data := []byte{0x18} + data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number + data = append(data, encodeVarInt(0xcafe)...) // retire prior to + data = append(data, 10) // connection ID length + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID + data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token + b := bytes.NewReader(data) + frame, err := parseNewConnectionIDFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) + Expect(frame.RetirePriorTo).To(Equal(uint64(0xcafe))) + Expect(frame.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(string(frame.StatelessResetToken[:])).To(Equal("deadbeefdecafbad")) + }) + + It("errors when the Retire Prior To value is larger than the Sequence Number", func() { + data := []byte{0x18} + data = append(data, encodeVarInt(1000)...) // sequence number + data = append(data, encodeVarInt(1001)...) // retire prior to + data = append(data, 3) + data = append(data, []byte{1, 2, 3}...) + data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token + b := bytes.NewReader(data) + _, err := parseNewConnectionIDFrame(b, protocol.Version1) + Expect(err).To(MatchError("Retire Prior To value (1001) larger than Sequence Number (1000)")) + }) + + It("errors when the connection ID has an invalid length", func() { + data := []byte{0x18} + data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number + data = append(data, encodeVarInt(0xcafe)...) // retire prior to + data = append(data, 21) // connection ID length + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}...) // connection ID + data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token + b := bytes.NewReader(data) + _, err := parseNewConnectionIDFrame(b, protocol.Version1) + Expect(err).To(MatchError("invalid connection ID length: 21")) + }) + + It("errors on EOFs", func() { + data := []byte{0x18} + data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number + data = append(data, encodeVarInt(0xcafe1234)...) // retire prior to + data = append(data, 10) // connection ID length + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID + data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token + _, err := parseNewConnectionIDFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseNewConnectionIDFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + token := protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + frame := &NewConnectionIDFrame{ + SequenceNumber: 0x1337, + RetirePriorTo: 0x42, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, + StatelessResetToken: token, + } + b := &bytes.Buffer{} + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + expected := []byte{0x18} + expected = append(expected, encodeVarInt(0x1337)...) + expected = append(expected, encodeVarInt(0x42)...) + expected = append(expected, 6) + expected = append(expected, []byte{1, 2, 3, 4, 5, 6}...) + expected = append(expected, token[:]...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct length", func() { + token := protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + frame := &NewConnectionIDFrame{ + SequenceNumber: 0xdecafbad, + RetirePriorTo: 0xdeadbeefcafe, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + StatelessResetToken: token, + } + b := &bytes.Buffer{} + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) + }) + }) +}) diff --git a/internal/quic-go/wire/new_token_frame.go b/internal/quic-go/wire/new_token_frame.go new file mode 100644 index 00000000..3a44eb21 --- /dev/null +++ b/internal/quic-go/wire/new_token_frame.go @@ -0,0 +1,48 @@ +package wire + +import ( + "bytes" + "errors" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A NewTokenFrame is a NEW_TOKEN frame +type NewTokenFrame struct { + Token []byte +} + +func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + tokenLen, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if uint64(r.Len()) < tokenLen { + return nil, io.EOF + } + if tokenLen == 0 { + return nil, errors.New("token must not be empty") + } + token := make([]byte, int(tokenLen)) + if _, err := io.ReadFull(r, token); err != nil { + return nil, err + } + return &NewTokenFrame{Token: token}, nil +} + +func (f *NewTokenFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x7) + quicvarint.Write(b, uint64(len(f.Token))) + b.Write(f.Token) + return nil +} + +// Length of a written frame +func (f *NewTokenFrame) Length(protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token)) +} diff --git a/internal/quic-go/wire/new_token_frame_test.go b/internal/quic-go/wire/new_token_frame_test.go new file mode 100644 index 00000000..3a3389c7 --- /dev/null +++ b/internal/quic-go/wire/new_token_frame_test.go @@ -0,0 +1,66 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("NEW_TOKEN frame", func() { + Context("parsing", func() { + It("accepts a sample frame", func() { + token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + data := []byte{0x7} + data = append(data, encodeVarInt(uint64(len(token)))...) + data = append(data, token...) + b := bytes.NewReader(data) + f, err := parseNewTokenFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(string(f.Token)).To(Equal(token)) + Expect(b.Len()).To(BeZero()) + }) + + It("rejects empty tokens", func() { + data := []byte{0x7} + data = append(data, encodeVarInt(uint64(0))...) + b := bytes.NewReader(data) + _, err := parseNewTokenFrame(b, protocol.VersionWhatever) + Expect(err).To(MatchError("token must not be empty")) + }) + + It("errors on EOFs", func() { + token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit" + data := []byte{0x7} + data = append(data, encodeVarInt(uint64(len(token)))...) + data = append(data, token...) + _, err := parseNewTokenFrame(bytes.NewReader(data), protocol.VersionWhatever) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseNewTokenFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("writing", func() { + It("writes a sample frame", func() { + token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat." + f := &NewTokenFrame{Token: []byte(token)} + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x7} + expected = append(expected, encodeVarInt(uint64(len(token)))...) + expected = append(expected, token...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := &NewTokenFrame{Token: []byte("foobar")} + Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(6) + 6)) + }) + }) +}) diff --git a/internal/quic-go/wire/path_challenge_frame.go b/internal/quic-go/wire/path_challenge_frame.go new file mode 100644 index 00000000..5d802249 --- /dev/null +++ b/internal/quic-go/wire/path_challenge_frame.go @@ -0,0 +1,38 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// A PathChallengeFrame is a PATH_CHALLENGE frame +type PathChallengeFrame struct { + Data [8]byte +} + +func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathChallengeFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + frame := &PathChallengeFrame{} + if _, err := io.ReadFull(r, frame.Data[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + return frame, nil +} + +func (f *PathChallengeFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x1a) + b.Write(f.Data[:]) + return nil +} + +// Length of a written frame +func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + 8 +} diff --git a/internal/quic-go/wire/path_challenge_frame_test.go b/internal/quic-go/wire/path_challenge_frame_test.go new file mode 100644 index 00000000..620e1b20 --- /dev/null +++ b/internal/quic-go/wire/path_challenge_frame_test.go @@ -0,0 +1,48 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("PATH_CHALLENGE frame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{0x1a, 1, 2, 3, 4, 5, 6, 7, 8}) + f, err := parsePathChallengeFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeZero()) + Expect(f.Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) + }) + + It("errors on EOFs", func() { + data := []byte{0x1a, 1, 2, 3, 4, 5, 6, 7, 8} + _, err := parsePathChallengeFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parsePathChallengeFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := PathChallengeFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} + err := frame.Write(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x1a, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) + }) + + It("has the correct min length", func() { + frame := PathChallengeFrame{} + Expect(frame.Length(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(9))) + }) + }) +}) diff --git a/internal/quic-go/wire/path_response_frame.go b/internal/quic-go/wire/path_response_frame.go new file mode 100644 index 00000000..d7334ac1 --- /dev/null +++ b/internal/quic-go/wire/path_response_frame.go @@ -0,0 +1,38 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// A PathResponseFrame is a PATH_RESPONSE frame +type PathResponseFrame struct { + Data [8]byte +} + +func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathResponseFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + frame := &PathResponseFrame{} + if _, err := io.ReadFull(r, frame.Data[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + return frame, nil +} + +func (f *PathResponseFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x1b) + b.Write(f.Data[:]) + return nil +} + +// Length of a written frame +func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + 8 +} diff --git a/internal/quic-go/wire/path_response_frame_test.go b/internal/quic-go/wire/path_response_frame_test.go new file mode 100644 index 00000000..757a08f9 --- /dev/null +++ b/internal/quic-go/wire/path_response_frame_test.go @@ -0,0 +1,47 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("PATH_RESPONSE frame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{0x1b, 1, 2, 3, 4, 5, 6, 7, 8}) + f, err := parsePathResponseFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeZero()) + Expect(f.Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) + }) + + It("errors on EOFs", func() { + data := []byte{0x1b, 1, 2, 3, 4, 5, 6, 7, 8} + _, err := parsePathResponseFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parsePathResponseFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := PathResponseFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} + err := frame.Write(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x1b, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) + }) + + It("has the correct min length", func() { + frame := PathResponseFrame{} + Expect(frame.Length(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(9))) + }) + }) +}) diff --git a/internal/quic-go/wire/ping_frame.go b/internal/quic-go/wire/ping_frame.go new file mode 100644 index 00000000..d47d8ce9 --- /dev/null +++ b/internal/quic-go/wire/ping_frame.go @@ -0,0 +1,27 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +// A PingFrame is a PING frame +type PingFrame struct{} + +func parsePingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PingFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + return &PingFrame{}, nil +} + +func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x1) + return nil +} + +// Length of a written frame +func (f *PingFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 +} diff --git a/internal/quic-go/wire/ping_frame_test.go b/internal/quic-go/wire/ping_frame_test.go new file mode 100644 index 00000000..cb9b2259 --- /dev/null +++ b/internal/quic-go/wire/ping_frame_test.go @@ -0,0 +1,39 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("PingFrame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{0x1}) + _, err := parsePingFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + _, err := parsePingFrame(bytes.NewReader(nil), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := PingFrame{} + frame.Write(b, protocol.VersionWhatever) + Expect(b.Bytes()).To(Equal([]byte{0x1})) + }) + + It("has the correct min length", func() { + frame := PingFrame{} + Expect(frame.Length(0)).To(Equal(protocol.ByteCount(1))) + }) + }) +}) diff --git a/internal/quic-go/wire/pool.go b/internal/quic-go/wire/pool.go new file mode 100644 index 00000000..2fb1f82c --- /dev/null +++ b/internal/quic-go/wire/pool.go @@ -0,0 +1,33 @@ +package wire + +import ( + "sync" + + "github.com/imroc/req/v3/internal/quic-go/protocol" +) + +var pool sync.Pool + +func init() { + pool.New = func() interface{} { + return &StreamFrame{ + Data: make([]byte, 0, protocol.MaxPacketBufferSize), + fromPool: true, + } + } +} + +func GetStreamFrame() *StreamFrame { + f := pool.Get().(*StreamFrame) + return f +} + +func putStreamFrame(f *StreamFrame) { + if !f.fromPool { + return + } + if protocol.ByteCount(cap(f.Data)) != protocol.MaxPacketBufferSize { + panic("wire.PutStreamFrame called with packet of wrong size!") + } + pool.Put(f) +} diff --git a/internal/quic-go/wire/pool_test.go b/internal/quic-go/wire/pool_test.go new file mode 100644 index 00000000..b55e493b --- /dev/null +++ b/internal/quic-go/wire/pool_test.go @@ -0,0 +1,24 @@ +package wire + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Pool", func() { + It("gets and puts STREAM frames", func() { + f := GetStreamFrame() + putStreamFrame(f) + }) + + It("panics when putting a STREAM frame with a wrong capacity", func() { + f := GetStreamFrame() + f.Data = []byte("foobar") + Expect(func() { putStreamFrame(f) }).To(Panic()) + }) + + It("accepts STREAM frames not from the buffer, but ignores them", func() { + f := &StreamFrame{Data: []byte("foobar")} + putStreamFrame(f) + }) +}) diff --git a/internal/quic-go/wire/reset_stream_frame.go b/internal/quic-go/wire/reset_stream_frame.go new file mode 100644 index 00000000..29910473 --- /dev/null +++ b/internal/quic-go/wire/reset_stream_frame.go @@ -0,0 +1,58 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A ResetStreamFrame is a RESET_STREAM frame in QUIC +type ResetStreamFrame struct { + StreamID protocol.StreamID + ErrorCode qerr.StreamErrorCode + FinalSize protocol.ByteCount +} + +func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStreamFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + + var streamID protocol.StreamID + var byteOffset protocol.ByteCount + sid, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + streamID = protocol.StreamID(sid) + errorCode, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + bo, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + byteOffset = protocol.ByteCount(bo) + + return &ResetStreamFrame{ + StreamID: streamID, + ErrorCode: qerr.StreamErrorCode(errorCode), + FinalSize: byteOffset, + }, nil +} + +func (f *ResetStreamFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x4) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.ErrorCode)) + quicvarint.Write(b, uint64(f.FinalSize)) + return nil +} + +// Length of a written frame +func (f *ResetStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize)) +} diff --git a/internal/quic-go/wire/reset_stream_frame_test.go b/internal/quic-go/wire/reset_stream_frame_test.go new file mode 100644 index 00000000..b60ba3f7 --- /dev/null +++ b/internal/quic-go/wire/reset_stream_frame_test.go @@ -0,0 +1,70 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("RESET_STREAM frame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + data := []byte{0x4} + data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID + data = append(data, encodeVarInt(0x1337)...) // error code + data = append(data, encodeVarInt(0x987654321)...) // byte offset + b := bytes.NewReader(data) + frame, err := parseResetStreamFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + Expect(frame.FinalSize).To(Equal(protocol.ByteCount(0x987654321))) + Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) + }) + + It("errors on EOFs", func() { + data := []byte{0x4} + data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID + data = append(data, encodeVarInt(0x1337)...) // error code + data = append(data, encodeVarInt(0x987654321)...) // byte offset + _, err := parseResetStreamFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseResetStreamFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + frame := ResetStreamFrame{ + StreamID: 0x1337, + FinalSize: 0x11223344decafbad, + ErrorCode: 0xcafe, + } + b := &bytes.Buffer{} + err := frame.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x4} + expected = append(expected, encodeVarInt(0x1337)...) + expected = append(expected, encodeVarInt(0xcafe)...) + expected = append(expected, encodeVarInt(0x11223344decafbad)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + rst := ResetStreamFrame{ + StreamID: 0x1337, + FinalSize: 0x1234567, + ErrorCode: 0xde, + } + expectedLen := 1 + quicvarint.Len(0x1337) + quicvarint.Len(0x1234567) + 2 + Expect(rst.Length(protocol.Version1)).To(Equal(expectedLen)) + }) + }) +}) diff --git a/internal/quic-go/wire/retire_connection_id_frame.go b/internal/quic-go/wire/retire_connection_id_frame.go new file mode 100644 index 00000000..a7e09aab --- /dev/null +++ b/internal/quic-go/wire/retire_connection_id_frame.go @@ -0,0 +1,36 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame +type RetireConnectionIDFrame struct { + SequenceNumber uint64 +} + +func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*RetireConnectionIDFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + seq, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + return &RetireConnectionIDFrame{SequenceNumber: seq}, nil +} + +func (f *RetireConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x19) + quicvarint.Write(b, f.SequenceNumber) + return nil +} + +// Length of a written frame +func (f *RetireConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(f.SequenceNumber) +} diff --git a/internal/quic-go/wire/retire_connection_id_frame_test.go b/internal/quic-go/wire/retire_connection_id_frame_test.go new file mode 100644 index 00000000..2e531d34 --- /dev/null +++ b/internal/quic-go/wire/retire_connection_id_frame_test.go @@ -0,0 +1,53 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("NEW_CONNECTION_ID frame", func() { + Context("when parsing", func() { + It("accepts a sample frame", func() { + data := []byte{0x19} + data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number + b := bytes.NewReader(data) + frame, err := parseRetireConnectionIDFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) + }) + + It("errors on EOFs", func() { + data := []byte{0x18} + data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number + _, err := parseRetireConnectionIDFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseRetireConnectionIDFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + frame := &RetireConnectionIDFrame{SequenceNumber: 0x1337} + b := &bytes.Buffer{} + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + expected := []byte{0x19} + expected = append(expected, encodeVarInt(0x1337)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct length", func() { + frame := &RetireConnectionIDFrame{SequenceNumber: 0xdecafbad} + b := &bytes.Buffer{} + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) + }) + }) +}) diff --git a/internal/quic-go/wire/stop_sending_frame.go b/internal/quic-go/wire/stop_sending_frame.go new file mode 100644 index 00000000..5e40c4e3 --- /dev/null +++ b/internal/quic-go/wire/stop_sending_frame.go @@ -0,0 +1,48 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A StopSendingFrame is a STOP_SENDING frame +type StopSendingFrame struct { + StreamID protocol.StreamID + ErrorCode qerr.StreamErrorCode +} + +// parseStopSendingFrame parses a STOP_SENDING frame +func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + streamID, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + errorCode, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + return &StopSendingFrame{ + StreamID: protocol.StreamID(streamID), + ErrorCode: qerr.StreamErrorCode(errorCode), + }, nil +} + +// Length of a written frame +func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) +} + +func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x5) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.ErrorCode)) + return nil +} diff --git a/internal/quic-go/wire/stop_sending_frame_test.go b/internal/quic-go/wire/stop_sending_frame_test.go new file mode 100644 index 00000000..e36e75ee --- /dev/null +++ b/internal/quic-go/wire/stop_sending_frame_test.go @@ -0,0 +1,63 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STOP_SENDING frame", func() { + Context("when parsing", func() { + It("parses a sample frame", func() { + data := []byte{0x5} + data = append(data, encodeVarInt(0xdecafbad)...) // stream ID + data = append(data, encodeVarInt(0x1337)...) // error code + b := bytes.NewReader(data) + frame, err := parseStopSendingFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) + Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x5} + data = append(data, encodeVarInt(0xdecafbad)...) // stream ID + data = append(data, encodeVarInt(0x123456)...) // error code + _, err := parseStopSendingFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseStopSendingFrame(bytes.NewReader(data[:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("when writing", func() { + It("writes", func() { + frame := &StopSendingFrame{ + StreamID: 0xdeadbeefcafe, + ErrorCode: 0xdecafbad, + } + buf := &bytes.Buffer{} + Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) + expected := []byte{0x5} + expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) + expected = append(expected, encodeVarInt(0xdecafbad)...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := &StopSendingFrame{ + StreamID: 0xdeadbeef, + ErrorCode: 0x1234567, + } + Expect(frame.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0xdeadbeef) + quicvarint.Len(0x1234567))) + }) + }) +}) diff --git a/internal/quic-go/wire/stream_data_blocked_frame.go b/internal/quic-go/wire/stream_data_blocked_frame.go new file mode 100644 index 00000000..447dc089 --- /dev/null +++ b/internal/quic-go/wire/stream_data_blocked_frame.go @@ -0,0 +1,46 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame +type StreamDataBlockedFrame struct { + StreamID protocol.StreamID + MaximumStreamData protocol.ByteCount +} + +func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + sid, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + return &StreamDataBlockedFrame{ + StreamID: protocol.StreamID(sid), + MaximumStreamData: protocol.ByteCount(offset), + }, nil +} + +func (f *StreamDataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x15) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.MaximumStreamData)) + return nil +} + +// Length of a written frame +func (f *StreamDataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) +} diff --git a/internal/quic-go/wire/stream_data_blocked_frame_test.go b/internal/quic-go/wire/stream_data_blocked_frame_test.go new file mode 100644 index 00000000..6d5f50db --- /dev/null +++ b/internal/quic-go/wire/stream_data_blocked_frame_test.go @@ -0,0 +1,63 @@ +package wire + +import ( + "bytes" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAM_DATA_BLOCKED frame", func() { + Context("parsing", func() { + It("accepts sample frame", func() { + data := []byte{0x15} + data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID + data = append(data, encodeVarInt(0xdecafbad)...) // offset + b := bytes.NewReader(data) + frame, err := parseStreamDataBlockedFrame(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0xdecafbad))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x15} + data = append(data, encodeVarInt(0xdeadbeef)...) + data = append(data, encodeVarInt(0xc0010ff)...) + _, err := parseStreamDataBlockedFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseStreamDataBlockedFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("writing", func() { + It("has proper min length", func() { + f := &StreamDataBlockedFrame{ + StreamID: 0x1337, + MaximumStreamData: 0xdeadbeef, + } + Expect(f.Length(0)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0xdeadbeef))) + }) + + It("writes a sample frame", func() { + b := &bytes.Buffer{} + f := &StreamDataBlockedFrame{ + StreamID: 0xdecafbad, + MaximumStreamData: 0x1337, + } + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x15} + expected = append(expected, encodeVarInt(uint64(f.StreamID))...) + expected = append(expected, encodeVarInt(uint64(f.MaximumStreamData))...) + Expect(b.Bytes()).To(Equal(expected)) + }) + }) +}) diff --git a/internal/quic-go/wire/stream_frame.go b/internal/quic-go/wire/stream_frame.go new file mode 100644 index 00000000..c4b9db48 --- /dev/null +++ b/internal/quic-go/wire/stream_frame.go @@ -0,0 +1,189 @@ +package wire + +import ( + "bytes" + "errors" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A StreamFrame of QUIC +type StreamFrame struct { + StreamID protocol.StreamID + Offset protocol.ByteCount + Data []byte + Fin bool + DataLenPresent bool + + fromPool bool +} + +func parseStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + hasOffset := typeByte&0x4 > 0 + fin := typeByte&0x1 > 0 + hasDataLen := typeByte&0x2 > 0 + + streamID, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + var offset uint64 + if hasOffset { + offset, err = quicvarint.Read(r) + if err != nil { + return nil, err + } + } + + var dataLen uint64 + if hasDataLen { + var err error + dataLen, err = quicvarint.Read(r) + if err != nil { + return nil, err + } + } else { + // The rest of the packet is data + dataLen = uint64(r.Len()) + } + + var frame *StreamFrame + if dataLen < protocol.MinStreamFrameBufferSize { + frame = &StreamFrame{Data: make([]byte, dataLen)} + } else { + frame = GetStreamFrame() + // The STREAM frame can't be larger than the StreamFrame we obtained from the buffer, + // since those StreamFrames have a buffer length of the maximum packet size. + if dataLen > uint64(cap(frame.Data)) { + return nil, io.EOF + } + frame.Data = frame.Data[:dataLen] + } + + frame.StreamID = protocol.StreamID(streamID) + frame.Offset = protocol.ByteCount(offset) + frame.Fin = fin + frame.DataLenPresent = hasDataLen + + if dataLen != 0 { + if _, err := io.ReadFull(r, frame.Data); err != nil { + return nil, err + } + } + if frame.Offset+frame.DataLen() > protocol.MaxByteCount { + return nil, errors.New("stream data overflows maximum offset") + } + return frame, nil +} + +// Write writes a STREAM frame +func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if len(f.Data) == 0 && !f.Fin { + return errors.New("StreamFrame: attempting to write empty frame without FIN") + } + + typeByte := byte(0x8) + if f.Fin { + typeByte ^= 0x1 + } + hasOffset := f.Offset != 0 + if f.DataLenPresent { + typeByte ^= 0x2 + } + if hasOffset { + typeByte ^= 0x4 + } + b.WriteByte(typeByte) + quicvarint.Write(b, uint64(f.StreamID)) + if hasOffset { + quicvarint.Write(b, uint64(f.Offset)) + } + if f.DataLenPresent { + quicvarint.Write(b, uint64(f.DataLen())) + } + b.Write(f.Data) + return nil +} + +// Length returns the total length of the STREAM frame +func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + length := 1 + quicvarint.Len(uint64(f.StreamID)) + if f.Offset != 0 { + length += quicvarint.Len(uint64(f.Offset)) + } + if f.DataLenPresent { + length += quicvarint.Len(uint64(f.DataLen())) + } + return length + f.DataLen() +} + +// DataLen gives the length of data in bytes +func (f *StreamFrame) DataLen() protocol.ByteCount { + return protocol.ByteCount(len(f.Data)) +} + +// MaxDataLen returns the maximum data length +// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data). +func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { + headerLen := 1 + quicvarint.Len(uint64(f.StreamID)) + if f.Offset != 0 { + headerLen += quicvarint.Len(uint64(f.Offset)) + } + if f.DataLenPresent { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen++ + } + if headerLen > maxSize { + return 0 + } + maxDataLen := maxSize - headerLen + if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { + maxDataLen-- + } + return maxDataLen +} + +// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. +// It returns if the frame was actually split. +// The frame might not be split if: +// * the size is large enough to fit the whole frame +// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. +func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, bool /* was splitting required */) { + if maxSize >= f.Length(version) { + return nil, false + } + + n := f.MaxDataLen(maxSize, version) + if n == 0 { + return nil, true + } + + new := GetStreamFrame() + new.StreamID = f.StreamID + new.Offset = f.Offset + new.Fin = false + new.DataLenPresent = f.DataLenPresent + + // swap the data slices + new.Data, f.Data = f.Data, new.Data + new.fromPool, f.fromPool = f.fromPool, new.fromPool + + f.Data = f.Data[:protocol.ByteCount(len(new.Data))-n] + copy(f.Data, new.Data[n:]) + new.Data = new.Data[:n] + f.Offset += n + + return new, true +} + +func (f *StreamFrame) PutBack() { + putStreamFrame(f) +} diff --git a/internal/quic-go/wire/stream_frame_test.go b/internal/quic-go/wire/stream_frame_test.go new file mode 100644 index 00000000..a533a183 --- /dev/null +++ b/internal/quic-go/wire/stream_frame_test.go @@ -0,0 +1,443 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAM frame", func() { + Context("when parsing", func() { + It("parses a frame with OFF bit", func() { + data := []byte{0x8 ^ 0x4} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(0xdecafbad)...) // offset + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(frame.Fin).To(BeFalse()) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) + Expect(r.Len()).To(BeZero()) + }) + + It("respects the LEN when parsing the frame", func() { + data := []byte{0x8 ^ 0x2} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(4)...) // data length + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) + Expect(frame.Data).To(Equal([]byte("foob"))) + Expect(frame.Fin).To(BeFalse()) + Expect(frame.Offset).To(BeZero()) + Expect(r.Len()).To(Equal(2)) + }) + + It("parses a frame with FIN bit", func() { + data := []byte{0x8 ^ 0x1} + data = append(data, encodeVarInt(9)...) // stream ID + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(9))) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(frame.Fin).To(BeTrue()) + Expect(frame.Offset).To(BeZero()) + Expect(r.Len()).To(BeZero()) + }) + + It("allows empty frames", func() { + data := []byte{0x8 ^ 0x4} + data = append(data, encodeVarInt(0x1337)...) // stream ID + data = append(data, encodeVarInt(0x12345)...) // offset + r := bytes.NewReader(data) + f, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(f.StreamID).To(Equal(protocol.StreamID(0x1337))) + Expect(f.Offset).To(Equal(protocol.ByteCount(0x12345))) + Expect(f.Data).To(BeEmpty()) + Expect(f.Fin).To(BeFalse()) + }) + + It("rejects frames that overflow the maximum offset", func() { + data := []byte{0x8 ^ 0x4} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(uint64(protocol.MaxByteCount-5))...) // offset + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + _, err := parseStreamFrame(r, protocol.Version1) + Expect(err).To(MatchError("stream data overflows maximum offset")) + }) + + It("rejects frames that claim to be longer than the packet size", func() { + data := []byte{0x8 ^ 0x2} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(uint64(protocol.MaxPacketBufferSize)+1)...) // data length + data = append(data, make([]byte, protocol.MaxPacketBufferSize+1)...) + r := bytes.NewReader(data) + _, err := parseStreamFrame(r, protocol.Version1) + Expect(err).To(Equal(io.EOF)) + }) + + It("errors on EOFs", func() { + data := []byte{0x8 ^ 0x4 ^ 0x2} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(0xdecafbad)...) // offset + data = append(data, encodeVarInt(6)...) // data length + data = append(data, []byte("foobar")...) + _, err := parseStreamFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseStreamFrame(bytes.NewReader(data[0:i]), protocol.Version1) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("using the buffer", func() { + It("uses the buffer for long STREAM frames", func() { + data := []byte{0x8} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)...) + r := bytes.NewReader(data) + frame, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) + Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize))) + Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize)) + Expect(frame.Fin).To(BeFalse()) + Expect(frame.fromPool).To(BeTrue()) + Expect(r.Len()).To(BeZero()) + Expect(frame.PutBack).ToNot(Panic()) + }) + + It("doesn't use the buffer for short STREAM frames", func() { + data := []byte{0x8} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)...) + r := bytes.NewReader(data) + frame, err := parseStreamFrame(r, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) + Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1))) + Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize - 1)) + Expect(frame.Fin).To(BeFalse()) + Expect(frame.fromPool).To(BeFalse()) + Expect(r.Len()).To(BeZero()) + Expect(frame.PutBack).ToNot(Panic()) + }) + }) + + Context("when writing", func() { + It("writes a frame without offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Data: []byte("foobar"), + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x8} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x123456, + Data: []byte("foobar"), + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x8 ^ 0x4} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(0x123456)...) // offset + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with FIN bit", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x123456, + Fin: true, + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x8 ^ 0x4 ^ 0x1} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(0x123456)...) // offset + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with data length", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Data: []byte("foobar"), + DataLenPresent: true, + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x8 ^ 0x2} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(6)...) // data length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with data length and offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Data: []byte("foobar"), + DataLenPresent: true, + Offset: 0x123456, + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x8 ^ 0x4 ^ 0x2} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(0x123456)...) // offset + expected = append(expected, encodeVarInt(6)...) // data length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("refuses to write an empty frame without FIN", func() { + f := &StreamFrame{ + StreamID: 0x42, + Offset: 0x1337, + } + b := &bytes.Buffer{} + err := f.Write(b, protocol.Version1) + Expect(err).To(MatchError("StreamFrame: attempting to write empty frame without FIN")) + }) + }) + + Context("length", func() { + It("has the right length for a frame without offset and data length", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Data: []byte("foobar"), + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + 6)) + }) + + It("has the right length for a frame with offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x42, + Data: []byte("foobar"), + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x42) + 6)) + }) + + It("has the right length for a frame with data length", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x1234567, + DataLenPresent: true, + Data: []byte("foobar"), + } + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x1234567) + quicvarint.Len(6) + 6)) + }) + }) + + Context("max data length", func() { + const maxSize = 3000 + + It("always returns a data length such that the resulting frame has the right size, if data length is not present", func() { + data := make([]byte, maxSize) + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0xdeadbeef, + } + b := &bytes.Buffer{} + for i := 1; i < 3000; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) + if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written + // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(Equal(i)) + } + }) + + It("always returns a data length such that the resulting frame has the right size, if data length is present", func() { + data := make([]byte, maxSize) + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0xdeadbeef, + DataLenPresent: true, + } + b := &bytes.Buffer{} + var frameOneByteTooSmallCounter int + for i := 1; i < 3000; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) + if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written + // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + err := f.Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + // There's *one* pathological case, where a data length of x can be encoded into 1 byte + // but a data lengths of x+1 needs 2 bytes + // In that case, it's impossible to create a STREAM frame of the desired size + if b.Len() == i-1 { + frameOneByteTooSmallCounter++ + continue + } + Expect(b.Len()).To(Equal(i)) + } + Expect(frameOneByteTooSmallCounter).To(Equal(1)) + }) + }) + + Context("splitting", func() { + It("doesn't split if the frame is short enough", func() { + f := &StreamFrame{ + StreamID: 0x1337, + DataLenPresent: true, + Offset: 0xdeadbeef, + Data: make([]byte, 100), + } + frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) + Expect(needsSplit).To(BeFalse()) + Expect(frame).To(BeNil()) + Expect(f.DataLen()).To(BeEquivalentTo(100)) + frame, needsSplit = f.MaybeSplitOffFrame(f.Length(protocol.Version1)-1, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(frame.DataLen()).To(BeEquivalentTo(99)) + f.PutBack() + }) + + It("keeps the data len", func() { + f := &StreamFrame{ + StreamID: 0x1337, + DataLenPresent: true, + Data: make([]byte, 100), + } + frame, needsSplit := f.MaybeSplitOffFrame(66, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + Expect(f.DataLenPresent).To(BeTrue()) + Expect(frame.DataLenPresent).To(BeTrue()) + }) + + It("adjusts the offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x100, + Data: []byte("foobar"), + } + frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1)-3, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0x100))) + Expect(frame.Data).To(Equal([]byte("foo"))) + Expect(f.Offset).To(Equal(protocol.ByteCount(0x100 + 3))) + Expect(f.Data).To(Equal([]byte("bar"))) + }) + + It("preserves the FIN bit", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Fin: true, + Offset: 0xdeadbeef, + Data: make([]byte, 100), + } + frame, needsSplit := f.MaybeSplitOffFrame(50, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + Expect(frame.Offset).To(BeNumerically("<", f.Offset)) + Expect(f.Fin).To(BeTrue()) + Expect(frame.Fin).To(BeFalse()) + }) + + It("produces frames of the correct length, without data len", func() { + const size = 1000 + f := &StreamFrame{ + StreamID: 0xdecafbad, + Offset: 0x1234, + Data: []byte{0}, + } + minFrameSize := f.Length(protocol.Version1) + for i := protocol.ByteCount(0); i < minFrameSize; i++ { + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(f).To(BeNil()) + } + for i := minFrameSize; i < size; i++ { + f.fromPool = false + f.Data = make([]byte, size) + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(f.Length(protocol.Version1)).To(Equal(i)) + } + }) + + It("produces frames of the correct length, with data len", func() { + const size = 1000 + f := &StreamFrame{ + StreamID: 0xdecafbad, + Offset: 0x1234, + DataLenPresent: true, + Data: []byte{0}, + } + minFrameSize := f.Length(protocol.Version1) + for i := protocol.ByteCount(0); i < minFrameSize; i++ { + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + Expect(f).To(BeNil()) + } + var frameOneByteTooSmallCounter int + for i := minFrameSize; i < size; i++ { + f.fromPool = false + f.Data = make([]byte, size) + newFrame, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) + Expect(needsSplit).To(BeTrue()) + // There's *one* pathological case, where a data length of x can be encoded into 1 byte + // but a data lengths of x+1 needs 2 bytes + // In that case, it's impossible to create a STREAM frame of the desired size + if newFrame.Length(protocol.Version1) == i-1 { + frameOneByteTooSmallCounter++ + continue + } + Expect(newFrame.Length(protocol.Version1)).To(Equal(i)) + } + Expect(frameOneByteTooSmallCounter).To(Equal(1)) + }) + }) +}) diff --git a/internal/quic-go/wire/streams_blocked_frame.go b/internal/quic-go/wire/streams_blocked_frame.go new file mode 100644 index 00000000..aab28c24 --- /dev/null +++ b/internal/quic-go/wire/streams_blocked_frame.go @@ -0,0 +1,55 @@ +package wire + +import ( + "bytes" + "fmt" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" +) + +// A StreamsBlockedFrame is a STREAMS_BLOCKED frame +type StreamsBlockedFrame struct { + Type protocol.StreamType + StreamLimit protocol.StreamNum +} + +func parseStreamsBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &StreamsBlockedFrame{} + switch typeByte { + case 0x16: + f.Type = protocol.StreamTypeBidi + case 0x17: + f.Type = protocol.StreamTypeUni + } + streamLimit, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.StreamLimit = protocol.StreamNum(streamLimit) + if f.StreamLimit > protocol.MaxStreamCount { + return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit) + } + return f, nil +} + +func (f *StreamsBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + switch f.Type { + case protocol.StreamTypeBidi: + b.WriteByte(0x16) + case protocol.StreamTypeUni: + b.WriteByte(0x17) + } + quicvarint.Write(b, uint64(f.StreamLimit)) + return nil +} + +// Length of a written frame +func (f *StreamsBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamLimit)) +} diff --git a/internal/quic-go/wire/streams_blocked_frame_test.go b/internal/quic-go/wire/streams_blocked_frame_test.go new file mode 100644 index 00000000..39b29dad --- /dev/null +++ b/internal/quic-go/wire/streams_blocked_frame_test.go @@ -0,0 +1,108 @@ +package wire + +import ( + "bytes" + "fmt" + "io" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAMS_BLOCKED frame", func() { + Context("parsing", func() { + It("accepts a frame for bidirectional streams", func() { + expected := []byte{0x16} + expected = append(expected, encodeVarInt(0x1337)...) + b := bytes.NewReader(expected) + f, err := parseStreamsBlockedFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Type).To(Equal(protocol.StreamTypeBidi)) + Expect(f.StreamLimit).To(BeEquivalentTo(0x1337)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts a frame for unidirectional streams", func() { + expected := []byte{0x17} + expected = append(expected, encodeVarInt(0x7331)...) + b := bytes.NewReader(expected) + f, err := parseStreamsBlockedFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Type).To(Equal(protocol.StreamTypeUni)) + Expect(f.StreamLimit).To(BeEquivalentTo(0x7331)) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x16} + data = append(data, encodeVarInt(0x12345678)...) + _, err := parseStreamsBlockedFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + for i := range data { + _, err := parseStreamsBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) + Expect(err).To(MatchError(io.EOF)) + } + }) + + for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { + streamType := t + + It("accepts a frame containing the maximum stream count", func() { + f := &StreamsBlockedFrame{ + Type: streamType, + StreamLimit: protocol.MaxStreamCount, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + frame, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when receiving a too large stream count", func() { + f := &StreamsBlockedFrame{ + Type: streamType, + StreamLimit: protocol.MaxStreamCount + 1, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + _, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) + }) + } + }) + + Context("writing", func() { + It("writes a frame for bidirectional streams", func() { + b := &bytes.Buffer{} + f := StreamsBlockedFrame{ + Type: protocol.StreamTypeBidi, + StreamLimit: 0xdeadbeefcafe, + } + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x16} + expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame for unidirectional streams", func() { + b := &bytes.Buffer{} + f := StreamsBlockedFrame{ + Type: protocol.StreamTypeUni, + StreamLimit: 0xdeadbeefcafe, + } + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x17} + expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := StreamsBlockedFrame{StreamLimit: 0x123456} + Expect(frame.Length(0)).To(Equal(protocol.ByteCount(1) + quicvarint.Len(0x123456))) + }) + }) +}) diff --git a/internal/quic-go/wire/transport_parameter_test.go b/internal/quic-go/wire/transport_parameter_test.go new file mode 100644 index 00000000..56d43fad --- /dev/null +++ b/internal/quic-go/wire/transport_parameter_test.go @@ -0,0 +1,612 @@ +package wire + +import ( + "bytes" + "fmt" + "math" + "math/rand" + "net" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Transport Parameters", func() { + getRandomValueUpTo := func(max int64) uint64 { + maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} + m := maxVals[int(rand.Int31n(4))] + if m > max { + m = max + } + return uint64(rand.Int63n(m)) + } + + getRandomValue := func() uint64 { + return getRandomValueUpTo(math.MaxInt64) + } + + BeforeEach(func() { + rand.Seed(GinkgoRandomSeed()) + }) + + addInitialSourceConnectionID := func(b *bytes.Buffer) { + quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) + quicvarint.Write(b, 6) + b.Write([]byte("foobar")) + } + + It("has a string representation", func() { + p := &TransportParameters{ + InitialMaxStreamDataBidiLocal: 1234, + InitialMaxStreamDataBidiRemote: 2345, + InitialMaxStreamDataUni: 3456, + InitialMaxData: 4567, + MaxBidiStreamNum: 1337, + MaxUniStreamNum: 7331, + MaxIdleTimeout: 42 * time.Second, + OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + AckDelayExponent: 14, + MaxAckDelay: 37 * time.Millisecond, + StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, + ActiveConnectionIDLimit: 123, + MaxDatagramFrameSize: 876, + } + Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: deadbeef, InitialSourceConnectionID: decafbad, RetrySourceConnectionID: deadc0de, InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37ms, ActiveConnectionIDLimit: 123, StatelessResetToken: 0x112233445566778899aabbccddeeff00, MaxDatagramFrameSize: 876}")) + }) + + It("has a string representation, if there's no stateless reset token, no Retry source connection id and no datagram support", func() { + p := &TransportParameters{ + InitialMaxStreamDataBidiLocal: 1234, + InitialMaxStreamDataBidiRemote: 2345, + InitialMaxStreamDataUni: 3456, + InitialMaxData: 4567, + MaxBidiStreamNum: 1337, + MaxUniStreamNum: 7331, + MaxIdleTimeout: 42 * time.Second, + OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + InitialSourceConnectionID: protocol.ConnectionID{}, + AckDelayExponent: 14, + MaxAckDelay: 37 * time.Second, + ActiveConnectionIDLimit: 89, + MaxDatagramFrameSize: protocol.InvalidByteCount, + } + Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: deadbeef, InitialSourceConnectionID: (empty), InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37s, ActiveConnectionIDLimit: 89}")) + }) + + It("marshals and unmarshals", func() { + var token protocol.StatelessResetToken + rand.Read(token[:]) + params := &TransportParameters{ + InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), + InitialMaxData: protocol.ByteCount(getRandomValue()), + MaxIdleTimeout: 0xcafe * time.Second, + MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + DisableActiveMigration: true, + StatelessResetToken: &token, + OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + AckDelayExponent: 13, + MaxAckDelay: 42 * time.Millisecond, + ActiveConnectionIDLimit: getRandomValue(), + MaxDatagramFrameSize: protocol.ByteCount(getRandomValue()), + } + data := params.Marshal(protocol.PerspectiveServer) + + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) + Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) + Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) + Expect(p.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) + Expect(p.InitialMaxData).To(Equal(params.InitialMaxData)) + Expect(p.MaxUniStreamNum).To(Equal(params.MaxUniStreamNum)) + Expect(p.MaxBidiStreamNum).To(Equal(params.MaxBidiStreamNum)) + Expect(p.MaxIdleTimeout).To(Equal(params.MaxIdleTimeout)) + Expect(p.DisableActiveMigration).To(Equal(params.DisableActiveMigration)) + Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken)) + Expect(p.OriginalDestinationConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) + Expect(p.InitialSourceConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + Expect(p.RetrySourceConnectionID).To(Equal(&protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) + Expect(p.AckDelayExponent).To(Equal(uint8(13))) + Expect(p.MaxAckDelay).To(Equal(42 * time.Millisecond)) + Expect(p.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) + Expect(p.MaxDatagramFrameSize).To(Equal(params.MaxDatagramFrameSize)) + }) + + It("doesn't marshal a retry_source_connection_id, if no Retry was performed", func() { + data := (&TransportParameters{ + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) + Expect(p.RetrySourceConnectionID).To(BeNil()) + }) + + It("marshals a zero-length retry_source_connection_id", func() { + data := (&TransportParameters{ + RetrySourceConnectionID: &protocol.ConnectionID{}, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) + Expect(p.RetrySourceConnectionID).ToNot(BeNil()) + Expect(p.RetrySourceConnectionID.Len()).To(BeZero()) + }) + + It("errors when the stateless_reset_token has the wrong length", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(statelessResetTokenParameterID)) + quicvarint.Write(b, 15) + b.Write(make([]byte, 15)) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "wrong length for stateless_reset_token: 15 (expected 16)", + })) + }) + + It("errors when the max_packet_size is too small", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(maxUDPPayloadSizeParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(1199))) + quicvarint.Write(b, 1199) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for max_packet_size: 1199 (minimum 1200)", + })) + }) + + It("errors when disable_active_migration has content", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) + quicvarint.Write(b, 6) + b.Write([]byte("foobar")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "wrong length for disable_active_migration: 6 (expected empty)", + })) + }) + + It("errors when the server doesn't set the original_destination_connection_id", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(statelessResetTokenParameterID)) + quicvarint.Write(b, 16) + b.Write(make([]byte, 16)) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "missing original_destination_connection_id", + })) + }) + + It("errors when the initial_source_connection_id is missing", func() { + Expect((&TransportParameters{}).Unmarshal([]byte{}, protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "missing initial_source_connection_id", + })) + }) + + It("errors when the max_ack_delay is too large", func() { + data := (&TransportParameters{ + MaxAckDelay: 1 << 14 * time.Millisecond, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for max_ack_delay: 16384ms (maximum 16383ms)", + })) + }) + + It("doesn't send the max_ack_delay, if it has the default value", func() { + const num = 1000 + var defaultLen, dataLen int + // marshal 1000 times to average out the greasing transport parameter + maxAckDelay := protocol.DefaultMaxAckDelay + time.Millisecond + for i := 0; i < num; i++ { + dataDefault := (&TransportParameters{ + MaxAckDelay: protocol.DefaultMaxAckDelay, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + defaultLen += len(dataDefault) + data := (&TransportParameters{ + MaxAckDelay: maxAckDelay, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + dataLen += len(data) + } + entryLen := quicvarint.Len(uint64(ackDelayExponentParameterID)) /* parameter id */ + quicvarint.Len(uint64(quicvarint.Len(uint64(maxAckDelay.Milliseconds())))) /*length */ + quicvarint.Len(uint64(maxAckDelay.Milliseconds())) /* value */ + Expect(float32(dataLen) / num).To(BeNumerically("~", float32(defaultLen)/num+float32(entryLen), 1)) + }) + + It("errors when the ack_delay_exponenent is too large", func() { + data := (&TransportParameters{ + AckDelayExponent: 21, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for ack_delay_exponent: 21 (maximum 20)", + })) + }) + + It("doesn't send the ack_delay_exponent, if it has the default value", func() { + const num = 1000 + var defaultLen, dataLen int + // marshal 1000 times to average out the greasing transport parameter + for i := 0; i < num; i++ { + dataDefault := (&TransportParameters{ + AckDelayExponent: protocol.DefaultAckDelayExponent, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + defaultLen += len(dataDefault) + data := (&TransportParameters{ + AckDelayExponent: protocol.DefaultAckDelayExponent + 1, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + dataLen += len(data) + } + entryLen := quicvarint.Len(uint64(ackDelayExponentParameterID)) /* parameter id */ + quicvarint.Len(uint64(quicvarint.Len(protocol.DefaultAckDelayExponent+1))) /* length */ + quicvarint.Len(protocol.DefaultAckDelayExponent+1) /* value */ + Expect(float32(dataLen) / num).To(BeNumerically("~", float32(defaultLen)/num+float32(entryLen), 1)) + }) + + It("sets the default value for the ack_delay_exponent, when no value was sent", func() { + data := (&TransportParameters{ + AckDelayExponent: protocol.DefaultAckDelayExponent, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) + Expect(p.AckDelayExponent).To(BeEquivalentTo(protocol.DefaultAckDelayExponent)) + }) + + It("errors when the varint value has the wrong length", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) + quicvarint.Write(b, 2) + val := uint64(0xdeadbeef) + Expect(quicvarint.Len(val)).ToNot(BeEquivalentTo(2)) + quicvarint.Write(b, val) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: fmt.Sprintf("inconsistent transport parameter length for transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), + })) + }) + + It("errors if initial_max_streams_bidi is too large", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(initialMaxStreamsBidiParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) + quicvarint.Write(b, uint64(protocol.MaxStreamCount+1)) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "initial_max_streams_bidi too large: 1152921504606846977 (maximum 1152921504606846976)", + })) + }) + + It("errors if initial_max_streams_uni is too large", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(initialMaxStreamsUniParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) + quicvarint.Write(b, uint64(protocol.MaxStreamCount+1)) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "initial_max_streams_uni too large: 1152921504606846977 (maximum 1152921504606846976)", + })) + }) + + It("handles huge max_ack_delay values", func() { + b := &bytes.Buffer{} + val := uint64(math.MaxUint64) / 5 + quicvarint.Write(b, uint64(maxAckDelayParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(val))) + quicvarint.Write(b, val) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for max_ack_delay: 3689348814741910323ms (maximum 16383ms)", + })) + }) + + It("skips unknown parameters", func() { + b := &bytes.Buffer{} + // write a known parameter + quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) + quicvarint.Write(b, 0x1337) + // write an unknown parameter + quicvarint.Write(b, 0x42) + quicvarint.Write(b, 6) + b.Write([]byte("foobar")) + // write a known parameter + quicvarint.Write(b, uint64(initialMaxStreamDataBidiRemoteParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(0x42))) + quicvarint.Write(b, 0x42) + addInitialSourceConnectionID(b) + p := &TransportParameters{} + Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(Succeed()) + Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(protocol.ByteCount(0x1337))) + Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(protocol.ByteCount(0x42))) + }) + + It("rejects duplicate parameters", func() { + b := &bytes.Buffer{} + // write first parameter + quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) + quicvarint.Write(b, 0x1337) + // write a second parameter + quicvarint.Write(b, uint64(initialMaxStreamDataBidiRemoteParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(0x42))) + quicvarint.Write(b, 0x42) + // write first parameter again + quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) + quicvarint.Write(b, 0x1337) + addInitialSourceConnectionID(b) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: fmt.Sprintf("received duplicate transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), + })) + }) + + It("errors if there's not enough data to read", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, 0x42) + quicvarint.Write(b, 7) + b.Write([]byte("foobar")) + p := &TransportParameters{} + Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "remaining length (6) smaller than parameter length (7)", + })) + }) + + It("errors if the client sent a stateless_reset_token", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(statelessResetTokenParameterID)) + quicvarint.Write(b, uint64(quicvarint.Len(16))) + b.Write(make([]byte, 16)) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "client sent a stateless_reset_token", + })) + }) + + It("errors if the client sent the original_destination_connection_id", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) + quicvarint.Write(b, 6) + b.Write([]byte("foobar")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "client sent an original_destination_connection_id", + })) + }) + + Context("preferred address", func() { + var pa *PreferredAddress + + BeforeEach(func() { + pa = &PreferredAddress{ + IPv4: net.IPv4(127, 0, 0, 1), + IPv4Port: 42, + IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + IPv6Port: 13, + ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + } + }) + + It("marshals and unmarshals", func() { + data := (&TransportParameters{ + PreferredAddress: pa, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) + Expect(p.PreferredAddress.IPv4.String()).To(Equal(pa.IPv4.String())) + Expect(p.PreferredAddress.IPv4Port).To(Equal(pa.IPv4Port)) + Expect(p.PreferredAddress.IPv6.String()).To(Equal(pa.IPv6.String())) + Expect(p.PreferredAddress.IPv6Port).To(Equal(pa.IPv6Port)) + Expect(p.PreferredAddress.ConnectionID).To(Equal(pa.ConnectionID)) + Expect(p.PreferredAddress.StatelessResetToken).To(Equal(pa.StatelessResetToken)) + }) + + It("errors if the client sent a preferred_address", func() { + b := &bytes.Buffer{} + quicvarint.Write(b, uint64(preferredAddressParameterID)) + quicvarint.Write(b, 6) + b.Write([]byte("foobar")) + p := &TransportParameters{} + Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "client sent a preferred_address", + })) + }) + + It("errors on zero-length connection IDs", func() { + pa.ConnectionID = protocol.ConnectionID{} + data := (&TransportParameters{ + PreferredAddress: pa, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid connection ID length: 0", + })) + }) + + It("errors on too long connection IDs", func() { + pa.ConnectionID = protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21} + Expect(pa.ConnectionID.Len()).To(BeNumerically(">", protocol.MaxConnIDLen)) + data := (&TransportParameters{ + PreferredAddress: pa, + StatelessResetToken: &protocol.StatelessResetToken{}, + }).Marshal(protocol.PerspectiveServer) + p := &TransportParameters{} + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid connection ID length: 21", + })) + }) + + It("errors on EOF", func() { + raw := []byte{ + 127, 0, 0, 1, // IPv4 + 0, 42, // IPv4 Port + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // IPv6 + 13, 37, // IPv6 Port, + 4, // conn ID len + 0xde, 0xad, 0xbe, 0xef, + 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, // stateless reset token + } + for i := 1; i < len(raw); i++ { + buf := &bytes.Buffer{} + quicvarint.Write(buf, uint64(preferredAddressParameterID)) + buf.Write(raw[:i]) + p := &TransportParameters{} + Expect(p.Unmarshal(buf.Bytes(), protocol.PerspectiveServer)).ToNot(Succeed()) + } + }) + }) + + Context("saving and retrieving from a session ticket", func() { + It("saves and retrieves the parameters", func() { + params := &TransportParameters{ + InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), + InitialMaxData: protocol.ByteCount(getRandomValue()), + MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + ActiveConnectionIDLimit: getRandomValue(), + } + Expect(params.ValidFor0RTT(params)).To(BeTrue()) + b := &bytes.Buffer{} + params.MarshalForSessionTicket(b) + var tp TransportParameters + Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(Succeed()) + Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) + Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) + Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) + Expect(tp.InitialMaxData).To(Equal(params.InitialMaxData)) + Expect(tp.MaxBidiStreamNum).To(Equal(params.MaxBidiStreamNum)) + Expect(tp.MaxUniStreamNum).To(Equal(params.MaxUniStreamNum)) + Expect(tp.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) + }) + + It("rejects the parameters if it can't parse them", func() { + var p TransportParameters + Expect(p.UnmarshalFromSessionTicket(bytes.NewReader([]byte("foobar")))).ToNot(Succeed()) + }) + + It("rejects the parameters if the version changed", func() { + var p TransportParameters + buf := &bytes.Buffer{} + p.MarshalForSessionTicket(buf) + data := buf.Bytes() + b := &bytes.Buffer{} + quicvarint.Write(b, transportParameterMarshalingVersion+1) + b.Write(data[quicvarint.Len(transportParameterMarshalingVersion):]) + Expect(p.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1))) + }) + + Context("rejects the parameters if they changed", func() { + var p TransportParameters + saved := &TransportParameters{ + InitialMaxStreamDataBidiLocal: 1, + InitialMaxStreamDataBidiRemote: 2, + InitialMaxStreamDataUni: 3, + InitialMaxData: 4, + MaxBidiStreamNum: 5, + MaxUniStreamNum: 6, + ActiveConnectionIDLimit: 7, + } + + BeforeEach(func() { + p = *saved + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataBidiLocal was reduced", func() { + p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("doesn't reject the parameters if the InitialMaxStreamDataBidiLocal was increased", func() { + p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataBidiRemote was reduced", func() { + p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("doesn't reject the parameters if the InitialMaxStreamDataBidiRemote was increased", func() { + p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataUni was reduced", func() { + p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("doesn't reject the parameters if the InitialMaxStreamDataUni was increased", func() { + p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxData was reduced", func() { + p.InitialMaxData = saved.InitialMaxData - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("doesn't reject the parameters if the InitialMaxData was increased", func() { + p.InitialMaxData = saved.InitialMaxData + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the MaxBidiStreamNum was reduced", func() { + p.MaxBidiStreamNum = saved.MaxBidiStreamNum - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("accepts the parameters if the MaxBidiStreamNum was increased", func() { + p.MaxBidiStreamNum = saved.MaxBidiStreamNum + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the MaxUniStreamNum changed", func() { + p.MaxUniStreamNum = saved.MaxUniStreamNum - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + + It("accepts the parameters if the MaxUniStreamNum was increased", func() { + p.MaxUniStreamNum = saved.MaxUniStreamNum + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the ActiveConnectionIDLimit changed", func() { + p.ActiveConnectionIDLimit = 0 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) + }) + }) +}) diff --git a/internal/quic-go/wire/transport_parameters.go b/internal/quic-go/wire/transport_parameters.go new file mode 100644 index 00000000..544e3506 --- /dev/null +++ b/internal/quic-go/wire/transport_parameters.go @@ -0,0 +1,476 @@ +package wire + +import ( + "bytes" + "errors" + "fmt" + "io" + "math/rand" + "net" + "sort" + "time" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/qerr" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +const transportParameterMarshalingVersion = 1 + +func init() { + rand.Seed(time.Now().UTC().UnixNano()) +} + +type transportParameterID uint64 + +const ( + originalDestinationConnectionIDParameterID transportParameterID = 0x0 + maxIdleTimeoutParameterID transportParameterID = 0x1 + statelessResetTokenParameterID transportParameterID = 0x2 + maxUDPPayloadSizeParameterID transportParameterID = 0x3 + initialMaxDataParameterID transportParameterID = 0x4 + initialMaxStreamDataBidiLocalParameterID transportParameterID = 0x5 + initialMaxStreamDataBidiRemoteParameterID transportParameterID = 0x6 + initialMaxStreamDataUniParameterID transportParameterID = 0x7 + initialMaxStreamsBidiParameterID transportParameterID = 0x8 + initialMaxStreamsUniParameterID transportParameterID = 0x9 + ackDelayExponentParameterID transportParameterID = 0xa + maxAckDelayParameterID transportParameterID = 0xb + disableActiveMigrationParameterID transportParameterID = 0xc + preferredAddressParameterID transportParameterID = 0xd + activeConnectionIDLimitParameterID transportParameterID = 0xe + initialSourceConnectionIDParameterID transportParameterID = 0xf + retrySourceConnectionIDParameterID transportParameterID = 0x10 + // RFC 9221 + maxDatagramFrameSizeParameterID transportParameterID = 0x20 +) + +// PreferredAddress is the value encoding in the preferred_address transport parameter +type PreferredAddress struct { + IPv4 net.IP + IPv4Port uint16 + IPv6 net.IP + IPv6Port uint16 + ConnectionID protocol.ConnectionID + StatelessResetToken protocol.StatelessResetToken +} + +// TransportParameters are parameters sent to the peer during the handshake +type TransportParameters struct { + InitialMaxStreamDataBidiLocal protocol.ByteCount + InitialMaxStreamDataBidiRemote protocol.ByteCount + InitialMaxStreamDataUni protocol.ByteCount + InitialMaxData protocol.ByteCount + + MaxAckDelay time.Duration + AckDelayExponent uint8 + + DisableActiveMigration bool + + MaxUDPPayloadSize protocol.ByteCount + + MaxUniStreamNum protocol.StreamNum + MaxBidiStreamNum protocol.StreamNum + + MaxIdleTimeout time.Duration + + PreferredAddress *PreferredAddress + + OriginalDestinationConnectionID protocol.ConnectionID + InitialSourceConnectionID protocol.ConnectionID + RetrySourceConnectionID *protocol.ConnectionID // use a pointer here to distinguish zero-length connection IDs from missing transport parameters + + StatelessResetToken *protocol.StatelessResetToken + ActiveConnectionIDLimit uint64 + + MaxDatagramFrameSize protocol.ByteCount +} + +// Unmarshal the transport parameters +func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error { + if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil { + return &qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: err.Error(), + } + } + return nil +} + +func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error { + // needed to check that every parameter is only sent at most once + var parameterIDs []transportParameterID + + var ( + readOriginalDestinationConnectionID bool + readInitialSourceConnectionID bool + ) + + p.AckDelayExponent = protocol.DefaultAckDelayExponent + p.MaxAckDelay = protocol.DefaultMaxAckDelay + p.MaxDatagramFrameSize = protocol.InvalidByteCount + + for r.Len() > 0 { + paramIDInt, err := quicvarint.Read(r) + if err != nil { + return err + } + paramID := transportParameterID(paramIDInt) + paramLen, err := quicvarint.Read(r) + if err != nil { + return err + } + if uint64(r.Len()) < paramLen { + return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen) + } + parameterIDs = append(parameterIDs, paramID) + switch paramID { + case maxIdleTimeoutParameterID, + maxUDPPayloadSizeParameterID, + initialMaxDataParameterID, + initialMaxStreamDataBidiLocalParameterID, + initialMaxStreamDataBidiRemoteParameterID, + initialMaxStreamDataUniParameterID, + initialMaxStreamsBidiParameterID, + initialMaxStreamsUniParameterID, + maxAckDelayParameterID, + activeConnectionIDLimitParameterID, + maxDatagramFrameSizeParameterID, + ackDelayExponentParameterID: + if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { + return err + } + case preferredAddressParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent a preferred_address") + } + if err := p.readPreferredAddress(r, int(paramLen)); err != nil { + return err + } + case disableActiveMigrationParameterID: + if paramLen != 0 { + return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen) + } + p.DisableActiveMigration = true + case statelessResetTokenParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent a stateless_reset_token") + } + if paramLen != 16 { + return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen) + } + var token protocol.StatelessResetToken + r.Read(token[:]) + p.StatelessResetToken = &token + case originalDestinationConnectionIDParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent an original_destination_connection_id") + } + p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) + readOriginalDestinationConnectionID = true + case initialSourceConnectionIDParameterID: + p.InitialSourceConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) + readInitialSourceConnectionID = true + case retrySourceConnectionIDParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent a retry_source_connection_id") + } + connID, _ := protocol.ReadConnectionID(r, int(paramLen)) + p.RetrySourceConnectionID = &connID + default: + r.Seek(int64(paramLen), io.SeekCurrent) + } + } + + if !fromSessionTicket { + if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID { + return errors.New("missing original_destination_connection_id") + } + if p.MaxUDPPayloadSize == 0 { + p.MaxUDPPayloadSize = protocol.MaxByteCount + } + if !readInitialSourceConnectionID { + return errors.New("missing initial_source_connection_id") + } + } + + // check that every transport parameter was sent at most once + sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] }) + for i := 0; i < len(parameterIDs)-1; i++ { + if parameterIDs[i] == parameterIDs[i+1] { + return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i]) + } + } + + return nil +} + +func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error { + remainingLen := r.Len() + pa := &PreferredAddress{} + ipv4 := make([]byte, 4) + if _, err := io.ReadFull(r, ipv4); err != nil { + return err + } + pa.IPv4 = net.IP(ipv4) + port, err := utils.BigEndian.ReadUint16(r) + if err != nil { + return err + } + pa.IPv4Port = port + ipv6 := make([]byte, 16) + if _, err := io.ReadFull(r, ipv6); err != nil { + return err + } + pa.IPv6 = net.IP(ipv6) + port, err = utils.BigEndian.ReadUint16(r) + if err != nil { + return err + } + pa.IPv6Port = port + connIDLen, err := r.ReadByte() + if err != nil { + return err + } + if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d", connIDLen) + } + connID, err := protocol.ReadConnectionID(r, int(connIDLen)) + if err != nil { + return err + } + pa.ConnectionID = connID + if _, err := io.ReadFull(r, pa.StatelessResetToken[:]); err != nil { + return err + } + if bytesRead := remainingLen - r.Len(); bytesRead != expectedLen { + return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead) + } + p.PreferredAddress = pa + return nil +} + +func (p *TransportParameters) readNumericTransportParameter( + r *bytes.Reader, + paramID transportParameterID, + expectedLen int, +) error { + remainingLen := r.Len() + val, err := quicvarint.Read(r) + if err != nil { + return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err) + } + if remainingLen-r.Len() != expectedLen { + return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID) + } + //nolint:exhaustive // This only covers the numeric transport parameters. + switch paramID { + case initialMaxStreamDataBidiLocalParameterID: + p.InitialMaxStreamDataBidiLocal = protocol.ByteCount(val) + case initialMaxStreamDataBidiRemoteParameterID: + p.InitialMaxStreamDataBidiRemote = protocol.ByteCount(val) + case initialMaxStreamDataUniParameterID: + p.InitialMaxStreamDataUni = protocol.ByteCount(val) + case initialMaxDataParameterID: + p.InitialMaxData = protocol.ByteCount(val) + case initialMaxStreamsBidiParameterID: + p.MaxBidiStreamNum = protocol.StreamNum(val) + if p.MaxBidiStreamNum > protocol.MaxStreamCount { + return fmt.Errorf("initial_max_streams_bidi too large: %d (maximum %d)", p.MaxBidiStreamNum, protocol.MaxStreamCount) + } + case initialMaxStreamsUniParameterID: + p.MaxUniStreamNum = protocol.StreamNum(val) + if p.MaxUniStreamNum > protocol.MaxStreamCount { + return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount) + } + case maxIdleTimeoutParameterID: + p.MaxIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) + case maxUDPPayloadSizeParameterID: + if val < 1200 { + return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val) + } + p.MaxUDPPayloadSize = protocol.ByteCount(val) + case ackDelayExponentParameterID: + if val > protocol.MaxAckDelayExponent { + return fmt.Errorf("invalid value for ack_delay_exponent: %d (maximum %d)", val, protocol.MaxAckDelayExponent) + } + p.AckDelayExponent = uint8(val) + case maxAckDelayParameterID: + if val > uint64(protocol.MaxMaxAckDelay/time.Millisecond) { + return fmt.Errorf("invalid value for max_ack_delay: %dms (maximum %dms)", val, protocol.MaxMaxAckDelay/time.Millisecond) + } + p.MaxAckDelay = time.Duration(val) * time.Millisecond + case activeConnectionIDLimitParameterID: + p.ActiveConnectionIDLimit = val + case maxDatagramFrameSizeParameterID: + p.MaxDatagramFrameSize = protocol.ByteCount(val) + default: + return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID) + } + return nil +} + +// Marshal the transport parameters +func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { + b := &bytes.Buffer{} + + // add a greased value + quicvarint.Write(b, uint64(27+31*rand.Intn(100))) + length := rand.Intn(16) + randomData := make([]byte, length) + rand.Read(randomData) + quicvarint.Write(b, uint64(length)) + b.Write(randomData) + + // initial_max_stream_data_bidi_local + p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) + // initial_max_stream_data_bidi_remote + p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) + // initial_max_stream_data_uni + p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) + // initial_max_data + p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) + // initial_max_bidi_streams + p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) + // initial_max_uni_streams + p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + // idle_timeout + p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond)) + // max_packet_size + p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize)) + // max_ack_delay + // Only send it if is different from the default value. + if p.MaxAckDelay != protocol.DefaultMaxAckDelay { + p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) + } + // ack_delay_exponent + // Only send it if is different from the default value. + if p.AckDelayExponent != protocol.DefaultAckDelayExponent { + p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) + } + // disable_active_migration + if p.DisableActiveMigration { + quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) + quicvarint.Write(b, 0) + } + if pers == protocol.PerspectiveServer { + // stateless_reset_token + if p.StatelessResetToken != nil { + quicvarint.Write(b, uint64(statelessResetTokenParameterID)) + quicvarint.Write(b, 16) + b.Write(p.StatelessResetToken[:]) + } + // original_destination_connection_id + quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) + quicvarint.Write(b, uint64(p.OriginalDestinationConnectionID.Len())) + b.Write(p.OriginalDestinationConnectionID.Bytes()) + // preferred_address + if p.PreferredAddress != nil { + quicvarint.Write(b, uint64(preferredAddressParameterID)) + quicvarint.Write(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) + ipv4 := p.PreferredAddress.IPv4 + b.Write(ipv4[len(ipv4)-4:]) + utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port) + b.Write(p.PreferredAddress.IPv6) + utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port) + b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len())) + b.Write(p.PreferredAddress.ConnectionID.Bytes()) + b.Write(p.PreferredAddress.StatelessResetToken[:]) + } + } + // active_connection_id_limit + p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) + // initial_source_connection_id + quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) + quicvarint.Write(b, uint64(p.InitialSourceConnectionID.Len())) + b.Write(p.InitialSourceConnectionID.Bytes()) + // retry_source_connection_id + if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil { + quicvarint.Write(b, uint64(retrySourceConnectionIDParameterID)) + quicvarint.Write(b, uint64(p.RetrySourceConnectionID.Len())) + b.Write(p.RetrySourceConnectionID.Bytes()) + } + if p.MaxDatagramFrameSize != protocol.InvalidByteCount { + p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) + } + return b.Bytes() +} + +func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportParameterID, val uint64) { + quicvarint.Write(b, uint64(id)) + quicvarint.Write(b, uint64(quicvarint.Len(val))) + quicvarint.Write(b, val) +} + +// MarshalForSessionTicket marshals the transport parameters we save in the session ticket. +// When sending a 0-RTT enabled TLS session tickets, we need to save the transport parameters. +// The client will remember the transport parameters used in the last session, +// and apply those to the 0-RTT data it sends. +// Saving the transport parameters in the ticket gives the server the option to reject 0-RTT +// if the transport parameters changed. +// Since the session ticket is encrypted, the serialization format is defined by the server. +// For convenience, we use the same format that we also use for sending the transport parameters. +func (p *TransportParameters) MarshalForSessionTicket(b *bytes.Buffer) { + quicvarint.Write(b, transportParameterMarshalingVersion) + + // initial_max_stream_data_bidi_local + p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) + // initial_max_stream_data_bidi_remote + p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) + // initial_max_stream_data_uni + p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) + // initial_max_data + p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) + // initial_max_bidi_streams + p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) + // initial_max_uni_streams + p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + // active_connection_id_limit + p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) +} + +// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket. +func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error { + version, err := quicvarint.Read(r) + if err != nil { + return err + } + if version != transportParameterMarshalingVersion { + return fmt.Errorf("unknown transport parameter marshaling version: %d", version) + } + return p.unmarshal(r, protocol.PerspectiveServer, true) +} + +// ValidFor0RTT checks if the transport parameters match those saved in the session ticket. +func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { + return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && + p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && + p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && + p.InitialMaxData >= saved.InitialMaxData && + p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && + p.MaxUniStreamNum >= saved.MaxUniStreamNum && + p.ActiveConnectionIDLimit == saved.ActiveConnectionIDLimit +} + +// String returns a string representation, intended for logging. +func (p *TransportParameters) String() string { + logString := "&wire.TransportParameters{OriginalDestinationConnectionID: %s, InitialSourceConnectionID: %s, " + logParams := []interface{}{p.OriginalDestinationConnectionID, p.InitialSourceConnectionID} + if p.RetrySourceConnectionID != nil { + logString += "RetrySourceConnectionID: %s, " + logParams = append(logParams, p.RetrySourceConnectionID) + } + logString += "InitialMaxStreamDataBidiLocal: %d, InitialMaxStreamDataBidiRemote: %d, InitialMaxStreamDataUni: %d, InitialMaxData: %d, MaxBidiStreamNum: %d, MaxUniStreamNum: %d, MaxIdleTimeout: %s, AckDelayExponent: %d, MaxAckDelay: %s, ActiveConnectionIDLimit: %d" + logParams = append(logParams, []interface{}{p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreamNum, p.MaxUniStreamNum, p.MaxIdleTimeout, p.AckDelayExponent, p.MaxAckDelay, p.ActiveConnectionIDLimit}...) + if p.StatelessResetToken != nil { // the client never sends a stateless reset token + logString += ", StatelessResetToken: %#x" + logParams = append(logParams, *p.StatelessResetToken) + } + if p.MaxDatagramFrameSize != protocol.InvalidByteCount { + logString += ", MaxDatagramFrameSize: %d" + logParams = append(logParams, p.MaxDatagramFrameSize) + } + logString += "}" + return fmt.Sprintf(logString, logParams...) +} diff --git a/internal/quic-go/wire/version_negotiation.go b/internal/quic-go/wire/version_negotiation.go new file mode 100644 index 00000000..ee1613e4 --- /dev/null +++ b/internal/quic-go/wire/version_negotiation.go @@ -0,0 +1,54 @@ +package wire + +import ( + "bytes" + "crypto/rand" + "errors" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/utils" +) + +// ParseVersionNegotiationPacket parses a Version Negotiation packet. +func ParseVersionNegotiationPacket(b *bytes.Reader) (*Header, []protocol.VersionNumber, error) { + hdr, err := parseHeader(b, 0) + if err != nil { + return nil, nil, err + } + if b.Len() == 0 { + //nolint:stylecheck + return nil, nil, errors.New("Version Negotiation packet has empty version list") + } + if b.Len()%4 != 0 { + //nolint:stylecheck + return nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") + } + versions := make([]protocol.VersionNumber, b.Len()/4) + for i := 0; b.Len() > 0; i++ { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, nil, err + } + versions[i] = protocol.VersionNumber(v) + } + return hdr, versions, nil +} + +// ComposeVersionNegotiation composes a Version Negotiation +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { + greasedVersions := protocol.GetGreasedVersions(versions) + expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 + buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) + r := make([]byte, 1) + _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. + buf.WriteByte(r[0] | 0x80) + utils.BigEndian.WriteUint32(buf, 0) // version 0 + buf.WriteByte(uint8(destConnID.Len())) + buf.Write(destConnID) + buf.WriteByte(uint8(srcConnID.Len())) + buf.Write(srcConnID) + for _, v := range greasedVersions { + utils.BigEndian.WriteUint32(buf, uint32(v)) + } + return buf.Bytes() +} diff --git a/internal/quic-go/wire/version_negotiation_test.go b/internal/quic-go/wire/version_negotiation_test.go new file mode 100644 index 00000000..9a312b82 --- /dev/null +++ b/internal/quic-go/wire/version_negotiation_test.go @@ -0,0 +1,83 @@ +package wire + +import ( + "bytes" + "encoding/binary" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Version Negotiation Packets", func() { + It("parses a Version Negotiation packet", func() { + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} + versions := []protocol.VersionNumber{0x22334455, 0x33445566} + data := []byte{0x80, 0, 0, 0, 0} + data = append(data, uint8(len(destConnID))) + data = append(data, destConnID...) + data = append(data, uint8(len(srcConnID))) + data = append(data, srcConnID...) + for _, v := range versions { + data = append(data, []byte{0, 0, 0, 0}...) + binary.BigEndian.PutUint32(data[len(data)-4:], uint32(v)) + } + Expect(IsVersionNegotiationPacket(data)).To(BeTrue()) + hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.DestConnectionID).To(Equal(destConnID)) + Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) + Expect(hdr.IsLongHeader).To(BeTrue()) + Expect(hdr.Version).To(BeZero()) + Expect(supportedVersions).To(Equal(versions)) + }) + + It("errors if it contains versions of the wrong length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + versions := []protocol.VersionNumber{0x22334455, 0x33445566} + data := ComposeVersionNegotiation(connID, connID, versions) + _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data[:len(data)-2])) + Expect(err).To(MatchError("Version Negotiation packet has a version list with an invalid length")) + }) + + It("errors if the version list is empty", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + versions := []protocol.VersionNumber{0x22334455} + data := ComposeVersionNegotiation(connID, connID, versions) + // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number + data = data[:len(data)-8] + _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + Expect(err).To(MatchError("Version Negotiation packet has empty version list")) + }) + + It("adds a reserved version", func() { + srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} + destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + versions := []protocol.VersionNumber{1001, 1003} + data := ComposeVersionNegotiation(destConnID, srcConnID, versions) + Expect(data[0] & 0x80).ToNot(BeZero()) + hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.DestConnectionID).To(Equal(destConnID)) + Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) + Expect(hdr.Version).To(BeZero()) + // the supported versions should include one reserved version number + Expect(supportedVersions).To(HaveLen(len(versions) + 1)) + for _, v := range versions { + Expect(supportedVersions).To(ContainElement(v)) + } + var reservedVersion protocol.VersionNumber + versionLoop: + for _, ver := range supportedVersions { + for _, v := range versions { + if v == ver { + continue versionLoop + } + } + reservedVersion = ver + } + Expect(reservedVersion).ToNot(BeZero()) + Expect(reservedVersion&0x0f0f0f0f == 0x0a0a0a0a).To(BeTrue()) // check that it's a greased version number + }) +}) diff --git a/internal/quic-go/wire/wire_suite_test.go b/internal/quic-go/wire/wire_suite_test.go new file mode 100644 index 00000000..54e5c70b --- /dev/null +++ b/internal/quic-go/wire/wire_suite_test.go @@ -0,0 +1,31 @@ +package wire + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/imroc/req/v3/internal/quic-go/protocol" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestWire(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Wire Suite") +} + +func encodeVarInt(i uint64) []byte { + b := &bytes.Buffer{} + quicvarint.Write(b, i) + return b.Bytes() +} + +func appendVersion(data []byte, v protocol.VersionNumber) []byte { + offset := len(data) + data = append(data, []byte{0, 0, 0, 0}...) + binary.BigEndian.PutUint32(data[offset:], uint32(v)) + return data +} From ef46f078c69f88ca2374380215f208b722e2040b Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 16 Jan 2023 19:15:21 +0800 Subject: [PATCH 649/843] Update go modules --- go.mod | 14 +++++----- go.sum | 81 ++++++++++++---------------------------------------------- 2 files changed, 24 insertions(+), 71 deletions(-) diff --git a/go.mod b/go.mod index 74e6a7da..bf0e8046 100644 --- a/go.mod +++ b/go.mod @@ -6,21 +6,23 @@ require ( github.com/francoispqt/gojay v1.2.13 github.com/golang/mock v1.6.0 github.com/hashicorp/go-multierror v1.1.1 - github.com/marten-seemann/qpack v0.3.0 + github.com/marten-seemann/qpack v0.2.1 github.com/marten-seemann/qtls-go1-16 v0.1.5 github.com/marten-seemann/qtls-go1-17 v0.1.2 github.com/marten-seemann/qtls-go1-18 v0.1.3 github.com/marten-seemann/qtls-go1-19 v0.1.1 github.com/onsi/ginkgo v1.16.5 - github.com/onsi/gomega v1.24.1 - golang.org/x/crypto v0.1.0 - golang.org/x/net v0.5.0 - golang.org/x/sys v0.4.0 - golang.org/x/text v0.6.0 + golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 + golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f + golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a + golang.org/x/text v0.3.7 ) require ( github.com/cheekybits/genny v1.0.0 github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/onsi/gomega v1.13.0 + golang.org/x/tools v0.1.11 // indirect + google.golang.org/protobuf v1.28.0 // indirect ) diff --git a/go.sum b/go.sum index 505b6184..956ca20b 100644 --- a/go.sum +++ b/go.sum @@ -14,9 +14,6 @@ github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBT github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -33,8 +30,6 @@ github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmV github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= -github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= -github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -62,14 +57,10 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -81,7 +72,6 @@ github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brv github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= @@ -92,8 +82,8 @@ github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/marten-seemann/qpack v0.3.0 h1:UiWstOgT8+znlkDPOg2+3rIuYXJ2CnGDkGUXN6ki6hE= -github.com/marten-seemann/qpack v0.3.0/go.mod h1:cGfKPBiP4a9EQdxCwEwI/GEeWAsjSekBvx/X8mh58+g= +github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= +github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= @@ -113,27 +103,14 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= -github.com/onsi/ginkgo/v2 v2.1.4/go.mod h1:um6tUpWM/cxCK3/FK8BXqEiUMUwRgSM4JXG47RKZmLU= -github.com/onsi/ginkgo/v2 v2.1.6/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= -github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= -github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= -github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo= -github.com/onsi/ginkgo/v2 v2.5.0 h1:TRtrvv2vdQqzkwrQ1ke6vtXf7IK34RBUJafIy1wMwls= -github.com/onsi/ginkgo/v2 v2.5.0/go.mod h1:Luc4sArBICYCS8THh8v3i3i5CuSZO+RaQRaJoeNwomw= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= -github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro= -github.com/onsi/gomega v1.20.1/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo= -github.com/onsi/gomega v1.21.1/go.mod h1:iYAIXgPSaDHak0LCMA+AWBpIKBr8WZicMxnE8luStNc= -github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM= -github.com/onsi/gomega v1.24.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg= -github.com/onsi/gomega v1.24.1 h1:KORJXNNTzJXzu4ScJWssJfJMnJ+2QJqhoQSRwNlze9E= -github.com/onsi/gomega v1.24.1/go.mod h1:3AOiACssS3/MajrniINInwbfOOtfZvplPzuRSmvt1jM= +github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= +github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -178,7 +155,6 @@ github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMI github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= @@ -188,19 +164,16 @@ golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= -golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= -golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -213,18 +186,13 @@ golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= -golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= -golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= -golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -237,7 +205,6 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -247,9 +214,9 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -258,29 +225,17 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= -golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= -golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -292,10 +247,8 @@ golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= -golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= +golang.org/x/tools v0.1.11 h1:loJ25fNOEhSXfHrpoGj91eCUThwdNX6u24rO1xnNteY= +golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -338,8 +291,6 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= From 0b160ef6df9dfc55882f276cd95d4e3c10bb5352 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 16 Jan 2023 19:30:35 +0800 Subject: [PATCH 650/843] Update go modules --- go.mod | 5 ++--- go.sum | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index bf0e8046..c2240999 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/imroc/req/v3 go 1.16 require ( + github.com/cheekybits/genny v1.0.0 github.com/francoispqt/gojay v1.2.13 github.com/golang/mock v1.6.0 github.com/hashicorp/go-multierror v1.1.1 @@ -12,6 +13,7 @@ require ( github.com/marten-seemann/qtls-go1-18 v0.1.3 github.com/marten-seemann/qtls-go1-19 v0.1.1 github.com/onsi/ginkgo v1.16.5 + github.com/onsi/gomega v1.13.0 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a @@ -19,10 +21,7 @@ require ( ) require ( - github.com/cheekybits/genny v1.0.0 github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/onsi/gomega v1.13.0 golang.org/x/tools v0.1.11 // indirect - google.golang.org/protobuf v1.28.0 // indirect ) diff --git a/go.sum b/go.sum index 956ca20b..6fda9b95 100644 --- a/go.sum +++ b/go.sum @@ -276,9 +276,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= -google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= From b43a82b2aba978e26db502f9385ec61e83e8f327 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 16 Jan 2023 20:59:30 +0800 Subject: [PATCH 651/843] Add sponsor to README --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index c901753c..c2052337 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,18 @@ If you have questions, feel free to reach out to us in the following ways: * [Slack](https://imroc-req.slack.com/archives/C03UFPGSNC8) | [Join](https://slack.req.cool/) * QQ Group (Chinese): 621411351 - +## Sponsor + +If you like req and it really helps you, feel free to reward me with a cup of coffee, and don't forget to mention your github id. + +Wechat: + +![](https://req.cool/images/wechat.jpg) + +Alipay: + +![](https://req.cool/images/alipay.jpg) + ## License `Req` released under MIT license, refer [LICENSE](LICENSE) file. From 660c4cabfb960195461980c97ca2d93cddd626d2 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 16 Jan 2023 21:16:22 +0800 Subject: [PATCH 652/843] Add aadog to the sponsors list --- README.md | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c2052337..30abf108 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,7 @@ If you have questions, feel free to reach out to us in the following ways: * [Slack](https://imroc-req.slack.com/archives/C03UFPGSNC8) | [Join](https://slack.req.cool/) * QQ Group (Chinese): 621411351 - -## Sponsor +## Sponsors If you like req and it really helps you, feel free to reward me with a cup of coffee, and don't forget to mention your github id. @@ -143,6 +143,21 @@ Alipay: ![](https://req.cool/images/alipay.jpg) +Many thanks to the following sponsors: + + + + + +
+ + +
+ aadog 🥇 +
+
+ + ## License `Req` released under MIT license, refer [LICENSE](LICENSE) file. From 2ecc7865e9037b2593b2fefd3ef0ba5a0d1b4178 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 17 Jan 2023 14:53:00 +0800 Subject: [PATCH 653/843] Add M-Cosmosss to the sponsors list --- README.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 30abf108..2928971d 100644 --- a/README.md +++ b/README.md @@ -146,18 +146,26 @@ Alipay: Many thanks to the following sponsors: + + +
+ + +
+ M-Cosmosss 🥇 +
+

- aadog 🥇 + aadog 🥈
- ## License `Req` released under MIT license, refer [LICENSE](LICENSE) file. From b4799505026b090bf23094796c20422fff12e362 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 17 Jan 2023 14:54:10 +0800 Subject: [PATCH 654/843] Update README --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 2928971d..3042c44f 100644 --- a/README.md +++ b/README.md @@ -154,8 +154,6 @@ Many thanks to the following sponsors: M-Cosmosss 🥇 - - From a00754ea83c037e2c9a846899e375a10588f8216 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 17 Jan 2023 15:02:07 +0800 Subject: [PATCH 655/843] Update README --- README.md | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 3042c44f..d33858d8 100644 --- a/README.md +++ b/README.md @@ -135,13 +135,20 @@ If you have questions, feel free to reach out to us in the following ways: If you like req and it really helps you, feel free to reward me with a cup of coffee, and don't forget to mention your github id. -Wechat: - -![](https://req.cool/images/wechat.jpg) - -Alipay: - -![](https://req.cool/images/alipay.jpg) + + + + + +
+ +
+ Wechat +
+ +
+ Alipay +
Many thanks to the following sponsors: From ce5480115f54c8ab30ece8d77e9e11f078453830 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 18 Jan 2023 15:14:03 +0800 Subject: [PATCH 656/843] Refactor API style For Client: * Deprecate SetCommonError, add SetCommonErrorResult * Add SetCommonUnknownResultHandlerFunc * Add SetCommonResultStateChecker For Request: * Deprecate SetResult, add SetSuccessResult * Deprecate SetError, add SetErrorResult * Add SetUnknownResultHandlerFunc * Add SetResultStateChecker For Response: * Deprecate IsSuccess, add IsSuccessState * Deprecate IsError, add IsErrorState * Deprecate Result, add SuccessResult * Deprecate Error, add ErrorResult * Add ResultState --- client.go | 78 ++++++++++++++++++++++++++++++++++++++------------- middleware.go | 41 ++++++++++++++++++++------- request.go | 61 ++++++++++++++++++++++++++++++++++++++-- response.go | 78 +++++++++++++++++++++++++++++++++++++++++++-------- 4 files changed, 213 insertions(+), 45 deletions(-) diff --git a/client.go b/client.go index b4b46235..72fc4202 100644 --- a/client.go +++ b/client.go @@ -48,25 +48,27 @@ type Client struct { DebugLog bool AllowGetMethodPayload bool - trace bool - disableAutoReadResponse bool - commonErrorType reflect.Type - retryOption *retryOption - jsonMarshal func(v interface{}) ([]byte, error) - jsonUnmarshal func(data []byte, v interface{}) error - xmlMarshal func(v interface{}) ([]byte, error) - xmlUnmarshal func(data []byte, v interface{}) error - outputDirectory string - scheme string - log Logger - t *Transport - dumpOptions *DumpOptions - httpClient *http.Client - beforeRequest []RequestMiddleware - udBeforeRequest []RequestMiddleware - afterResponse []ResponseMiddleware - wrappedRoundTrip RoundTripper - responseBodyTransformer func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error) + trace bool + disableAutoReadResponse bool + commonErrorType reflect.Type + retryOption *retryOption + jsonMarshal func(v interface{}) ([]byte, error) + jsonUnmarshal func(data []byte, v interface{}) error + xmlMarshal func(v interface{}) ([]byte, error) + xmlUnmarshal func(data []byte, v interface{}) error + outputDirectory string + scheme string + log Logger + t *Transport + dumpOptions *DumpOptions + httpClient *http.Client + beforeRequest []RequestMiddleware + udBeforeRequest []RequestMiddleware + afterResponse []ResponseMiddleware + wrappedRoundTrip RoundTripper + responseBodyTransformer func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error) + unknownResultHandlerFunc func(resp *Response) error + resultStateCheckFunc func(resp *Response) ResultState } // R create a new request. @@ -160,14 +162,41 @@ func (c *Client) SetResponseBodyTransformer(fn func(rawBody []byte, req *Request } // SetCommonError set the common result that response body will be unmarshalled to -// if it is an error response ( status `code >= 400`). +// if no error occurs but Response.ResultState returns ErrorState, by default it +// is HTTP status `code >= 400`, you can also use SetCommonResultStateChecker +// to customize the result state check logic. +// +// Deprecated: Use SetCommonErrorResult instead. func (c *Client) SetCommonError(err interface{}) *Client { + return c.SetCommonErrorResult(err) +} + +// SetCommonErrorResult set the common result that response body will be unmarshalled to +// if no error occurs but Response.ResultState returns ErrorState, by default it +// is HTTP status `code >= 400`, you can also use SetCommonResultStateChecker +// to customize the result state check logic. +func (c *Client) SetCommonErrorResult(err interface{}) *Client { if err != nil { c.commonErrorType = util.GetType(err) } return c } +// SetCommonUnknownResultHandlerFunc set the response handler which will be executed when no +// error occurs, but Response.ResultState returns UnknownState. +func (c *Client) SetCommonUnknownResultHandlerFunc(handler func(resp *Response) error) *Client { + c.unknownResultHandlerFunc = handler + return c +} + +// SetCommonResultStateChecker overrides the default result state checker with customized one, +// which returns SuccessState when HTTP status `code >= 200 and <= 299`, and returns +// ErrorState when HTTP status `code >= 400`, otherwise returns UnknownState. +func (c *Client) SetCommonResultStateChecker(checker func(resp *Response) ResultState) *Client { + c.resultStateCheckFunc = checker + return c +} + // SetCommonFormDataFromValues set the form data from url.Values for all requests // which request method allows payload. func (c *Client) SetCommonFormDataFromValues(data urlpkg.Values) *Client { @@ -1319,6 +1348,15 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { httpResponse, err = c.httpClient.Do(r.RawRequest) resp.Response = httpResponse + // setup resultStateCheckFunc + if r.resultStateCheckFunc == nil { + if c.resultStateCheckFunc != nil { + r.resultStateCheckFunc = c.resultStateCheckFunc + } else { + r.resultStateCheckFunc = defaultResultStateChecker + } + } + // auto-read response body if possible if err == nil && !c.disableAutoReadResponse && !r.isSaveResponse && !r.disableAutoReadResponse { _, err = resp.ToBytes() diff --git a/middleware.go b/middleware.go index b2a22dbb..a53ae46c 100644 --- a/middleware.go +++ b/middleware.go @@ -243,21 +243,34 @@ func unmarshalBody(c *Client, r *Response, v interface{}) (err error) { return } +func defaultResultStateChecker(resp *Response) ResultState { + if code := resp.StatusCode; code > 199 && code < 300 { + return SuccessState + } else if code > 399 { + return ErrorState + } else { + return UnknownState + } +} + func parseResponseBody(c *Client, r *Response) (err error) { if r.Response == nil || r.StatusCode == http.StatusNoContent { return } - if r.Request.Result != nil && r.IsSuccess() { - err = unmarshalBody(c, r, r.Request.Result) - if err == nil { - r.result = r.Request.Result + req := r.Request + switch req.resultStateCheckFunc(r) { + case SuccessState: + if req.Result != nil { + err = unmarshalBody(c, r, r.Request.Result) + if err == nil { + r.result = r.Request.Result + } } - } - if r.IsError() { - if r.Request.Error != nil { - err = unmarshalBody(c, r, r.Request.Error) + case ErrorState: + if req.Error != nil { + err = unmarshalBody(c, r, req.Error) if err == nil { - r.error = r.Request.Error + r.error = req.Error } } else if c.commonErrorType != nil { e := reflect.New(c.commonErrorType).Interface() @@ -266,6 +279,14 @@ func parseResponseBody(c *Client, r *Response) (err error) { r.error = e } } + default: + handleUnknownResult := req.unknownResultHandlerFunc + if handleUnknownResult == nil { + handleUnknownResult = c.unknownResultHandlerFunc + } + if handleUnknownResult != nil { + handleUnknownResult(r) + } } return } @@ -325,7 +346,7 @@ func (r *callbackReader) Read(p []byte) (n int, err error) { } func handleDownload(c *Client, r *Response) (err error) { - if !r.Request.isSaveResponse { + if r.Response == nil || !r.Request.isSaveResponse { return nil } var body io.ReadCloser diff --git a/request.go b/request.go index 94dec749..5ccf68b6 100644 --- a/request.go +++ b/request.go @@ -66,6 +66,8 @@ type Request struct { trace *clientTrace dumpBuffer *bytes.Buffer responseReturnTime time.Time + unknownResultHandlerFunc func(resp *Response) error + resultStateCheckFunc func(resp *Response) ResultState } type GetContentFunc func() (io.ReadCloser, error) @@ -348,19 +350,72 @@ func (r *Request) SetDownloadCallbackWithInterval(callback DownloadCallback, min } // SetResult set the result that response Body will be unmarshalled to if -// request is success (status `code >= 200 and <= 299`). +// Response.IsSuccess() == true (Response.Err == nil && status `code >= 200 and <= 299`). +// +// Deprecated: Use SetSuccessResult instead. func (r *Request) SetResult(result interface{}) *Request { + return r.SetSuccessResult(result) +} + +// SetSuccessResult set the result that response Body will be unmarshalled to if +// Response.IsSuccess() == true (Response.Err == nil && status `code >= 200 and <= 299`). +func (r *Request) SetSuccessResult(result interface{}) *Request { r.Result = util.GetPointer(result) return r } -// SetError set the result that response Body will be unmarshalled to if -// request is error ( status `code >= 400`). +// SetError set the result that response body will be unmarshalled to if +// no error occurs and Response.ResultState() returns ErrorState, by default +// it requires HTTP status `code >= 400`, you can also use Request.SetResultStateChecker +// or Client.SetCommonResultStateChecker to customize the result state check logic. +// +// Deprecated: Use SetErrorResult result. func (r *Request) SetError(error interface{}) *Request { + return r.SetErrorResult(error) +} + +// SetErrorResult set the result that response body will be unmarshalled to if +// no error occurs and Response.ResultState() returns ErrorState, by default +// it requires HTTP status `code >= 400`, you can also use Request.SetResultStateChecker +// or Client.SetCommonResultStateChecker to customize the result state check logic. +func (r *Request) SetErrorResult(error interface{}) *Request { r.Error = util.GetPointer(error) return r } +// SetUnknownResultHandlerFunc set the response handler which will be executed when no +// error occurs, but Response.ResultState returns UnknownState. +func (r *Request) SetUnknownResultHandlerFunc(handler func(resp *Response) error) *Request { + r.unknownResultHandlerFunc = handler + return r +} + +// ResultState represents the state of the result. +type ResultState int + +const ( + // SuccessState indicates the response is in success state, + // and result will be unmarshalled if Request.SetSuccessResult + // is called. + SuccessState ResultState = iota + // ErrorState indicates the response is in error state, + // and result will be unmarshalled if Request.SetErrorResult + // or Client.SetCommonErrorResult is called. + ErrorState + // UnknownState indicates the response is in unknown state, + // and handler will be invoked if Request.SetUnknownResultHandlerFunc + // or Client.SetCommonUnknownResultHandlerFunc is called. + UnknownState +) + +// SetResultStateChecker overrides the default result state checker with customized one, +// which returns SuccessState when HTTP status `code >= 200 and <= 299`, and returns +// ErrorState when HTTP status `code >= 400`, otherwise returns UnknownState. +func (r *Request) SetResultStateChecker(checker func(resp *Response) ResultState) *Request { + r.resultStateCheckFunc = checker + return r +} + // SetBearerAuthToken set bearer auth token for the request. func (r *Request) SetBearerAuthToken(token string) *Request { return r.SetHeader("Authorization", "Bearer "+token) diff --git a/response.go b/response.go index 6cbcb657..73140c6c 100644 --- a/response.go +++ b/response.go @@ -25,20 +25,42 @@ type Response struct { result interface{} } -// IsSuccess method returns true if HTTP status `code >= 200 and <= 299` otherwise false. +// IsSuccess method returns true if no error occurs and HTTP status `code >= 200 and <= 299` +// by default, you can also use Request.SetResultStateChecker to customize the result +// state check logic. +// +// Deprecated: Use IsSuccessState instead. func (r *Response) IsSuccess() bool { + return r.IsSuccessState() +} + +// IsSuccessState method returns true if no error occurs and HTTP status `code >= 200 and <= 299` +// by default, you can also use Request.SetResultStateChecker to customize the result state +// check logic. +func (r *Response) IsSuccessState() bool { if r.Response == nil { return false } - return r.StatusCode > 199 && r.StatusCode < 300 + return r.ResultState() == SuccessState } -// IsError method returns true if HTTP status `code >= 400` otherwise false. +// IsError method returns true if no error occurs and HTTP status `code >= 400` +// by default, you can also use Request.SetResultStateChecker to customize the result +// state check logic. +// +// Deprecated: Use IsErrorState instead. func (r *Response) IsError() bool { + return r.IsErrorState() +} + +// IsErrorState method returns true if no error occurs and HTTP status `code >= 400` +// by default, you can also use Request.SetResultStateChecker to customize the result +// state check logic. +func (r *Response) IsErrorState() bool { if r.Response == nil { return false } - return r.StatusCode > 399 + return r.ResultState() == ErrorState } // GetContentType return the `Content-Type` header value. @@ -49,18 +71,50 @@ func (r *Response) GetContentType() string { return r.Header.Get(header.ContentType) } -// Result returns the automatically unmarshalled object if Request.SetResult is called, -// and Response.IsSuccess() == true. Otherwise return nil. +// ResultState returns the result state. +// By default, it returns SuccessState if HTTP status `code >= 400`, and returns +// ErrorState if HTTP status `code >= 400`, otherwise returns UnknownState. +// You can also use Request.SetResultStateChecker or Client.SetCommonResultStateChecker +// to customize the result state check logic. +func (r *Response) ResultState() ResultState { + if r.Response == nil { + return UnknownState + } + return r.Request.resultStateCheckFunc(r) +} + +// Result returns the automatically unmarshalled object if Request.SetSuccessResult +// is called and ResultState returns SuccessState. +// Otherwise, return nil. +// +// Deprecated: Use SuccessResult instead. func (r *Response) Result() interface{} { + return r.SuccessResult() +} + +// SuccessResult returns the automatically unmarshalled object if Request.SetSuccessResult +// is called and ResultState returns SuccessState. +// Otherwise, return nil. +func (r *Response) SuccessResult() interface{} { return r.result } -// Error returns the automatically unmarshalled object when Request.SetError is called, -// and Response.IsError() == true. Otherwise return nil. +// Error returns the automatically unmarshalled object when Request.SetErrorResult +// or Client.SetCommonErrorResult is called, and ResultState returns ErrorState. +// Otherwise, return nil. +// +// Deprecated: Use ErrorResult instead. func (r *Response) Error() interface{} { return r.error } +// ErrorResult returns the automatically unmarshalled object when Request.SetErrorResult +// or Client.SetCommonErrorResult is called, and ResultState returns ErrorState. +// Otherwise, return nil. +func (r *Response) ErrorResult() interface{} { + return r.error +} + // TraceInfo returns the TraceInfo from Request. func (r *Response) TraceInfo() TraceInfo { return r.Request.TraceInfo() @@ -86,7 +140,7 @@ func (r *Response) setReceivedAt() { } } -// UnmarshalJson unmarshals JSON response body into the specified object. +// UnmarshalJson unmarshalls JSON response body into the specified object. func (r *Response) UnmarshalJson(v interface{}) error { if r.Err != nil { return r.Err @@ -98,7 +152,7 @@ func (r *Response) UnmarshalJson(v interface{}) error { return r.Request.client.jsonUnmarshal(b, v) } -// UnmarshalXml unmarshals XML response body into the specified object. +// UnmarshalXml unmarshalls XML response body into the specified object. func (r *Response) UnmarshalXml(v interface{}) error { if r.Err != nil { return r.Err @@ -110,7 +164,7 @@ func (r *Response) UnmarshalXml(v interface{}) error { return r.Request.client.xmlUnmarshal(b, v) } -// Unmarshal unmarshals response body into the specified object according +// Unmarshal unmarshalls response body into the specified object according // to response `Content-Type`. func (r *Response) Unmarshal(v interface{}) error { if r.Err != nil { @@ -126,7 +180,7 @@ func (r *Response) Unmarshal(v interface{}) error { return r.UnmarshalJson(v) } -// Into unmarshals response body into the specified object according +// Into unmarshalls response body into the specified object according // to response `Content-Type`. func (r *Response) Into(v interface{}) error { return r.Unmarshal(v) From 15ce662b4ab4078643b82f9c76b2bd135fd8cac3 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 18 Jan 2023 15:25:55 +0800 Subject: [PATCH 657/843] fix comments --- request.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/request.go b/request.go index 5ccf68b6..b929a10f 100644 --- a/request.go +++ b/request.go @@ -350,7 +350,10 @@ func (r *Request) SetDownloadCallbackWithInterval(callback DownloadCallback, min } // SetResult set the result that response Body will be unmarshalled to if -// Response.IsSuccess() == true (Response.Err == nil && status `code >= 200 and <= 299`). +// no error occurs and Response.ResultState() returns SuccessState, by default +// it requires HTTP status `code >= 200 && code <= 299`, you can also use +// Request.SetResultStateChecker or Client.SetCommonResultStateChecker to customize +// the result state check logic. // // Deprecated: Use SetSuccessResult instead. func (r *Request) SetResult(result interface{}) *Request { @@ -358,7 +361,10 @@ func (r *Request) SetResult(result interface{}) *Request { } // SetSuccessResult set the result that response Body will be unmarshalled to if -// Response.IsSuccess() == true (Response.Err == nil && status `code >= 200 and <= 299`). +// no error occurs and Response.ResultState() returns SuccessState, by default +// it requires HTTP status `code >= 200 && code <= 299`, you can also use +// Request.SetResultStateChecker or Client.SetCommonResultStateChecker to customize +// the result state check logic. func (r *Request) SetSuccessResult(result interface{}) *Request { r.Result = util.GetPointer(result) return r From 220aaf505e7b9ad351f94e4d90cb606cae939840 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 18 Jan 2023 15:39:31 +0800 Subject: [PATCH 658/843] Complete global wrapper and rename Checker to CheckFunc --- client.go | 10 +++++----- client_wrapper.go | 24 ++++++++++++++++++++++-- request.go | 22 +++++++++++----------- request_wrapper.go | 36 ++++++++++++++++++++++++++++++++---- response.go | 10 +++++----- 5 files changed, 75 insertions(+), 27 deletions(-) diff --git a/client.go b/client.go index 72fc4202..d65a02c2 100644 --- a/client.go +++ b/client.go @@ -184,16 +184,16 @@ func (c *Client) SetCommonErrorResult(err interface{}) *Client { // SetCommonUnknownResultHandlerFunc set the response handler which will be executed when no // error occurs, but Response.ResultState returns UnknownState. -func (c *Client) SetCommonUnknownResultHandlerFunc(handler func(resp *Response) error) *Client { - c.unknownResultHandlerFunc = handler +func (c *Client) SetCommonUnknownResultHandlerFunc(fn func(resp *Response) error) *Client { + c.unknownResultHandlerFunc = fn return c } -// SetCommonResultStateChecker overrides the default result state checker with customized one, +// SetCommonResultStateCheckFunc overrides the default result state checker with customized one, // which returns SuccessState when HTTP status `code >= 200 and <= 299`, and returns // ErrorState when HTTP status `code >= 400`, otherwise returns UnknownState. -func (c *Client) SetCommonResultStateChecker(checker func(resp *Response) ResultState) *Client { - c.resultStateCheckFunc = checker +func (c *Client) SetCommonResultStateCheckFunc(fn func(resp *Response) ResultState) *Client { + c.resultStateCheckFunc = fn return c } diff --git a/client_wrapper.go b/client_wrapper.go index 3031bd5a..05e22b54 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -23,9 +23,29 @@ func WrapRoundTripFunc(funcs ...RoundTripWrapperFunc) *Client { } // SetCommonError is a global wrapper methods which delegated -// to the default client's SetCommonError. +// to the default client's SetCommonErrorResult. +// +// Deprecated: Use SetCommonErrorResult instead. func SetCommonError(err interface{}) *Client { - return defaultClient.SetCommonError(err) + return defaultClient.SetCommonErrorResult(err) +} + +// SetCommonErrorResult is a global wrapper methods which delegated +// to the default client's SetCommonError. +func SetCommonErrorResult(err interface{}) *Client { + return defaultClient.SetCommonErrorResult(err) +} + +// SetCommonUnknownResultHandlerFunc is a global wrapper methods which delegated +// to the default client's SetCommonUnknownResultHandlerFunc. +func SetCommonUnknownResultHandlerFunc(fn func(resp *Response) error) *Client { + return defaultClient.SetCommonUnknownResultHandlerFunc(fn) +} + +// SetCommonResultStateCheckFunc is a global wrapper methods which delegated +// to the default client's SetCommonResultStateCheckFunc. +func SetCommonResultStateCheckFunc(fn func(resp *Response) ResultState) *Client { + return defaultClient.SetCommonResultStateCheckFunc(fn) } // SetCommonFormDataFromValues is a global wrapper methods which delegated diff --git a/request.go b/request.go index b929a10f..af655817 100644 --- a/request.go +++ b/request.go @@ -352,7 +352,7 @@ func (r *Request) SetDownloadCallbackWithInterval(callback DownloadCallback, min // SetResult set the result that response Body will be unmarshalled to if // no error occurs and Response.ResultState() returns SuccessState, by default // it requires HTTP status `code >= 200 && code <= 299`, you can also use -// Request.SetResultStateChecker or Client.SetCommonResultStateChecker to customize +// Request.SetResultStateCheckFunc or Client.SetCommonResultStateCheckFunc to customize // the result state check logic. // // Deprecated: Use SetSuccessResult instead. @@ -363,7 +363,7 @@ func (r *Request) SetResult(result interface{}) *Request { // SetSuccessResult set the result that response Body will be unmarshalled to if // no error occurs and Response.ResultState() returns SuccessState, by default // it requires HTTP status `code >= 200 && code <= 299`, you can also use -// Request.SetResultStateChecker or Client.SetCommonResultStateChecker to customize +// Request.SetResultStateCheckFunc or Client.SetCommonResultStateCheckFunc to customize // the result state check logic. func (r *Request) SetSuccessResult(result interface{}) *Request { r.Result = util.GetPointer(result) @@ -372,8 +372,8 @@ func (r *Request) SetSuccessResult(result interface{}) *Request { // SetError set the result that response body will be unmarshalled to if // no error occurs and Response.ResultState() returns ErrorState, by default -// it requires HTTP status `code >= 400`, you can also use Request.SetResultStateChecker -// or Client.SetCommonResultStateChecker to customize the result state check logic. +// it requires HTTP status `code >= 400`, you can also use Request.SetResultStateCheckFunc +// or Client.SetCommonResultStateCheckFunc to customize the result state check logic. // // Deprecated: Use SetErrorResult result. func (r *Request) SetError(error interface{}) *Request { @@ -382,8 +382,8 @@ func (r *Request) SetError(error interface{}) *Request { // SetErrorResult set the result that response body will be unmarshalled to if // no error occurs and Response.ResultState() returns ErrorState, by default -// it requires HTTP status `code >= 400`, you can also use Request.SetResultStateChecker -// or Client.SetCommonResultStateChecker to customize the result state check logic. +// it requires HTTP status `code >= 400`, you can also use Request.SetResultStateCheckFunc +// or Client.SetCommonResultStateCheckFunc to customize the result state check logic. func (r *Request) SetErrorResult(error interface{}) *Request { r.Error = util.GetPointer(error) return r @@ -391,8 +391,8 @@ func (r *Request) SetErrorResult(error interface{}) *Request { // SetUnknownResultHandlerFunc set the response handler which will be executed when no // error occurs, but Response.ResultState returns UnknownState. -func (r *Request) SetUnknownResultHandlerFunc(handler func(resp *Response) error) *Request { - r.unknownResultHandlerFunc = handler +func (r *Request) SetUnknownResultHandlerFunc(fn func(resp *Response) error) *Request { + r.unknownResultHandlerFunc = fn return r } @@ -414,11 +414,11 @@ const ( UnknownState ) -// SetResultStateChecker overrides the default result state checker with customized one, +// SetResultStateCheckFunc overrides the default result state checker with customized one, // which returns SuccessState when HTTP status `code >= 200 and <= 299`, and returns // ErrorState when HTTP status `code >= 400`, otherwise returns UnknownState. -func (r *Request) SetResultStateChecker(checker func(resp *Response) ResultState) *Request { - r.resultStateCheckFunc = checker +func (r *Request) SetResultStateCheckFunc(fn func(resp *Response) ResultState) *Request { + r.resultStateCheckFunc = fn return r } diff --git a/request_wrapper.go b/request_wrapper.go index 9fabc9de..47f154f7 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -75,15 +75,43 @@ func SetFileUpload(f ...FileUpload) *Request { } // SetResult is a global wrapper methods which delegated -// to the default client, create a request and SetResult for request. +// to the default client, create a request and SetSuccessResult for request. +// +// Deprecated: Use SetSuccessResult instead. func SetResult(result interface{}) *Request { - return defaultClient.R().SetResult(result) + return defaultClient.R().SetSuccessResult(result) +} + +// SetSuccessResult is a global wrapper methods which delegated +// to the default client, create a request and SetSuccessResult for request. +func SetSuccessResult(result interface{}) *Request { + return defaultClient.R().SetSuccessResult(result) } // SetError is a global wrapper methods which delegated -// to the default client, create a request and SetError for request. +// to the default client, create a request and SetErrorResult for request. +// +// Deprecated: Use SetErrorResult instead. func SetError(error interface{}) *Request { - return defaultClient.R().SetError(error) + return defaultClient.R().SetErrorResult(error) +} + +// SetErrorResult is a global wrapper methods which delegated +// to the default client, create a request and SetErrorResult for request. +func SetErrorResult(error interface{}) *Request { + return defaultClient.R().SetErrorResult(error) +} + +// SetResultStateCheckFunc is a global wrapper methods which delegated +// to the default client, create a request and SetResultStateCheckFunc for request. +func SetResultStateCheckFunc(fn func(resp *Response) ResultState) *Request { + return defaultClient.R().SetResultStateCheckFunc(fn) +} + +// SetUnknownResultHandlerFunc is a global wrapper methods which delegated +// to the default client, create a request and SetUnknownResultHandlerFunc for request. +func SetUnknownResultHandlerFunc(fn func(resp *Response) error) *Request { + return defaultClient.R().SetUnknownResultHandlerFunc(fn) } // SetBearerAuthToken is a global wrapper methods which delegated diff --git a/response.go b/response.go index 73140c6c..eaf29400 100644 --- a/response.go +++ b/response.go @@ -26,7 +26,7 @@ type Response struct { } // IsSuccess method returns true if no error occurs and HTTP status `code >= 200 and <= 299` -// by default, you can also use Request.SetResultStateChecker to customize the result +// by default, you can also use Request.SetResultStateCheckFunc to customize the result // state check logic. // // Deprecated: Use IsSuccessState instead. @@ -35,7 +35,7 @@ func (r *Response) IsSuccess() bool { } // IsSuccessState method returns true if no error occurs and HTTP status `code >= 200 and <= 299` -// by default, you can also use Request.SetResultStateChecker to customize the result state +// by default, you can also use Request.SetResultStateCheckFunc to customize the result state // check logic. func (r *Response) IsSuccessState() bool { if r.Response == nil { @@ -45,7 +45,7 @@ func (r *Response) IsSuccessState() bool { } // IsError method returns true if no error occurs and HTTP status `code >= 400` -// by default, you can also use Request.SetResultStateChecker to customize the result +// by default, you can also use Request.SetResultStateCheckFunc to customize the result // state check logic. // // Deprecated: Use IsErrorState instead. @@ -54,7 +54,7 @@ func (r *Response) IsError() bool { } // IsErrorState method returns true if no error occurs and HTTP status `code >= 400` -// by default, you can also use Request.SetResultStateChecker to customize the result +// by default, you can also use Request.SetResultStateCheckFunc to customize the result // state check logic. func (r *Response) IsErrorState() bool { if r.Response == nil { @@ -74,7 +74,7 @@ func (r *Response) GetContentType() string { // ResultState returns the result state. // By default, it returns SuccessState if HTTP status `code >= 400`, and returns // ErrorState if HTTP status `code >= 400`, otherwise returns UnknownState. -// You can also use Request.SetResultStateChecker or Client.SetCommonResultStateChecker +// You can also use Request.SetResultStateCheckFunc or Client.SetCommonResultStateCheckFunc // to customize the result state check logic. func (r *Response) ResultState() ResultState { if r.Response == nil { From 09ccf1dd6a89f09d094229616019d458a1e03fe4 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 18 Jan 2023 15:55:44 +0800 Subject: [PATCH 659/843] Improve comments for Client --- client.go | 126 ++++++++++++++++++++++++++++++------------------------ 1 file changed, 70 insertions(+), 56 deletions(-) diff --git a/client.go b/client.go index d65a02c2..68974140 100644 --- a/client.go +++ b/client.go @@ -197,8 +197,8 @@ func (c *Client) SetCommonResultStateCheckFunc(fn func(resp *Response) ResultSta return c } -// SetCommonFormDataFromValues set the form data from url.Values for all requests -// which request method allows payload. +// SetCommonFormDataFromValues set the form data from url.Values for requests +// fired from the client which request method allows payload. func (c *Client) SetCommonFormDataFromValues(data urlpkg.Values) *Client { if c.FormData == nil { c.FormData = urlpkg.Values{} @@ -211,7 +211,7 @@ func (c *Client) SetCommonFormDataFromValues(data urlpkg.Values) *Client { return c } -// SetCommonFormData set the form data from map for all requests +// SetCommonFormData set the form data from map for requests fired from the client // which request method allows payload. func (c *Client) SetCommonFormData(data map[string]string) *Client { if c.FormData == nil { @@ -391,7 +391,7 @@ func (c *Client) DisableInsecureSkipVerify() *Client { } // SetCommonQueryParams set URL query parameters with a map -// for all requests. +// for requests fired from the client. func (c *Client) SetCommonQueryParams(params map[string]string) *Client { for k, v := range params { c.SetCommonQueryParam(k, v) @@ -400,7 +400,7 @@ func (c *Client) SetCommonQueryParams(params map[string]string) *Client { } // AddCommonQueryParam add a URL query parameter with a key-value -// pair for all requests. +// pair for requests fired from the client. func (c *Client) AddCommonQueryParam(key, value string) *Client { if c.QueryParams == nil { c.QueryParams = make(urlpkg.Values) @@ -409,7 +409,8 @@ func (c *Client) AddCommonQueryParam(key, value string) *Client { return c } -// AddCommonQueryParams add one or more values of specified URL query parameter for all requests. +// AddCommonQueryParams add one or more values of specified URL query parameter +// for requests fired from the client. func (c *Client) AddCommonQueryParams(key string, values ...string) *Client { if c.QueryParams == nil { c.QueryParams = make(urlpkg.Values) @@ -427,13 +428,13 @@ func (c *Client) pathParams() map[string]string { return c.PathParams } -// SetCommonPathParam set a path parameter for all requests. +// SetCommonPathParam set a path parameter for requests fired from the client. func (c *Client) SetCommonPathParam(key, value string) *Client { c.pathParams()[key] = value return c } -// SetCommonPathParams set path parameters for all requests. +// SetCommonPathParams set path parameters for requests fired from the client. func (c *Client) SetCommonPathParams(pathParams map[string]string) *Client { m := c.pathParams() for k, v := range pathParams { @@ -443,7 +444,7 @@ func (c *Client) SetCommonPathParams(pathParams map[string]string) *Client { } // SetCommonQueryParam set a URL query parameter with a key-value -// pair for all requests. +// pair for requests fired from the client. func (c *Client) SetCommonQueryParam(key, value string) *Client { if c.QueryParams == nil { c.QueryParams = make(urlpkg.Values) @@ -453,7 +454,7 @@ func (c *Client) SetCommonQueryParam(key, value string) *Client { } // SetCommonQueryString set URL query parameters with a raw query string -// for all requests. +// for requests fired from the client. func (c *Client) SetCommonQueryString(query string) *Client { params, err := urlpkg.ParseQuery(strings.TrimSpace(query)) if err != nil { @@ -471,7 +472,7 @@ func (c *Client) SetCommonQueryString(query string) *Client { return c } -// SetCommonCookies set HTTP cookies for all requests. +// SetCommonCookies set HTTP cookies for requests fired from the client. func (c *Client) SetCommonCookies(cookies ...*http.Cookie) *Client { c.Cookies = append(c.Cookies, cookies...) return c @@ -527,7 +528,7 @@ func (c *Client) SetLogger(log Logger) *Client { return c } -// SetTimeout set timeout for all requests. +// SetTimeout set timeout for requests fired from the client. func (c *Client) SetTimeout(d time.Duration) *Client { c.httpClient.Timeout = d return c @@ -540,7 +541,7 @@ func (c *Client) getDumpOptions() *DumpOptions { return c.dumpOptions } -// EnableDumpAll enable dump for all requests, including +// EnableDumpAll enable dump for requests fired from the client, including // all content for the request and response by default. func (c *Client) EnableDumpAll() *Client { if c.t.Dump != nil { // dump already started @@ -550,8 +551,8 @@ func (c *Client) EnableDumpAll() *Client { return c } -// EnableDumpAllToFile enable dump for all requests and output -// to the specified file. +// EnableDumpAllToFile enable dump for requests fired from the +// client and output to the specified file. func (c *Client) EnableDumpAllToFile(filename string) *Client { file, err := os.Create(filename) if err != nil { @@ -563,17 +564,17 @@ func (c *Client) EnableDumpAllToFile(filename string) *Client { return c } -// EnableDumpAllTo enable dump for all requests and output to -// the specified io.Writer. +// EnableDumpAllTo enable dump for requests fired from the +// client and output to the specified io.Writer. func (c *Client) EnableDumpAllTo(output io.Writer) *Client { c.getDumpOptions().Output = output c.EnableDumpAll() return c } -// EnableDumpAllAsync enable dump for all requests and output -// asynchronously, can be used for debugging in production -// environment without affecting performance. +// EnableDumpAllAsync enable dump for requests fired from the +// client and output asynchronously, can be used for debugging +// in production environment without affecting performance. func (c *Client) EnableDumpAllAsync() *Client { o := c.getDumpOptions() o.Async = true @@ -581,9 +582,9 @@ func (c *Client) EnableDumpAllAsync() *Client { return c } -// EnableDumpAllWithoutRequestBody enable dump for all requests without -// request body, can be used in the upload request to avoid dumping the -// unreadable binary content. +// EnableDumpAllWithoutRequestBody enable dump for requests fired +// from the client without request body, can be used in the upload +// request to avoid dumping the unreadable binary content. func (c *Client) EnableDumpAllWithoutRequestBody() *Client { o := c.getDumpOptions() o.RequestBody = false @@ -591,9 +592,9 @@ func (c *Client) EnableDumpAllWithoutRequestBody() *Client { return c } -// EnableDumpAllWithoutResponseBody enable dump for all requests without -// response body, can be used in the download request to avoid dumping the -// unreadable binary content. +// EnableDumpAllWithoutResponseBody enable dump for requests fired +// from the client without response body, can be used in the download +// request to avoid dumping the unreadable binary content. func (c *Client) EnableDumpAllWithoutResponseBody() *Client { o := c.getDumpOptions() o.ResponseBody = false @@ -601,8 +602,9 @@ func (c *Client) EnableDumpAllWithoutResponseBody() *Client { return c } -// EnableDumpAllWithoutResponse enable dump for all requests without response, -// can be used if you only care about the request. +// EnableDumpAllWithoutResponse enable dump for requests fired from +// the client without response, can be used if you only care about +// the request. func (c *Client) EnableDumpAllWithoutResponse() *Client { o := c.getDumpOptions() o.ResponseBody = false @@ -611,8 +613,9 @@ func (c *Client) EnableDumpAllWithoutResponse() *Client { return c } -// EnableDumpAllWithoutRequest enables dump for all requests without request, -// can be used if you only care about the response. +// EnableDumpAllWithoutRequest enables dump for requests fired from +// the client without request, can be used if you only care about +// the response. func (c *Client) EnableDumpAllWithoutRequest() *Client { o := c.getDumpOptions() o.RequestHeader = false @@ -621,8 +624,9 @@ func (c *Client) EnableDumpAllWithoutRequest() *Client { return c } -// EnableDumpAllWithoutHeader enable dump for all requests without header, -// can be used if you only care about the body. +// EnableDumpAllWithoutHeader enable dump for requests fired from +// the client without header, can be used if you only care about +// the body. func (c *Client) EnableDumpAllWithoutHeader() *Client { o := c.getDumpOptions() o.RequestHeader = false @@ -631,8 +635,9 @@ func (c *Client) EnableDumpAllWithoutHeader() *Client { return c } -// EnableDumpAllWithoutBody enable dump for all requests without body, -// can be used if you only care about the header. +// EnableDumpAllWithoutBody enable dump for requests fired from +// the client without body, can be used if you only care about +// the header. func (c *Client) EnableDumpAllWithoutBody() *Client { o := c.getDumpOptions() o.RequestBody = false @@ -784,23 +789,26 @@ func (c *Client) EnableAutoDecode() *Client { return c } -// SetUserAgent set the "User-Agent" header for all requests. +// SetUserAgent set the "User-Agent" header for requests fired from +// the client. func (c *Client) SetUserAgent(userAgent string) *Client { return c.SetCommonHeader(header.UserAgent, userAgent) } -// SetCommonBearerAuthToken set the bearer auth token for all requests. +// SetCommonBearerAuthToken set the bearer auth token for requests +// fired from the client. func (c *Client) SetCommonBearerAuthToken(token string) *Client { return c.SetCommonHeader("Authorization", "Bearer "+token) } -// SetCommonBasicAuth set the basic auth for all requests. +// SetCommonBasicAuth set the basic auth for requests fired from +// the client. func (c *Client) SetCommonBasicAuth(username, password string) *Client { c.SetCommonHeader("Authorization", util.BasicAuthHeaderValue(username, password)) return c } -// SetCommonHeaders set headers for all requests. +// SetCommonHeaders set headers for requests fired from the client. func (c *Client) SetCommonHeaders(hdrs map[string]string) *Client { for k, v := range hdrs { c.SetCommonHeader(k, v) @@ -808,7 +816,7 @@ func (c *Client) SetCommonHeaders(hdrs map[string]string) *Client { return c } -// SetCommonHeader set a header for all requests. +// SetCommonHeader set a header for requests fired from the client. func (c *Client) SetCommonHeader(key, value string) *Client { if c.Headers == nil { c.Headers = make(http.Header) @@ -817,8 +825,9 @@ func (c *Client) SetCommonHeader(key, value string) *Client { return c } -// SetCommonHeaderNonCanonical set a header for all requests which key is a -// non-canonical key (keep case unchanged), only valid for HTTP/1.1. +// SetCommonHeaderNonCanonical set a header for requests fired from +// the client which key is a non-canonical key (keep case unchanged), +// only valid for HTTP/1.1. func (c *Client) SetCommonHeaderNonCanonical(key, value string) *Client { if c.Headers == nil { c.Headers = make(http.Header) @@ -827,8 +836,9 @@ func (c *Client) SetCommonHeaderNonCanonical(key, value string) *Client { return c } -// SetCommonHeadersNonCanonical set headers for all requests which key is a -// non-canonical key (keep case unchanged), only valid for HTTP/1.1. +// SetCommonHeadersNonCanonical set headers for requests fired from the +// client which key is a non-canonical key (keep case unchanged), only +// valid for HTTP/1.1. func (c *Client) SetCommonHeadersNonCanonical(hdrs map[string]string) *Client { for k, v := range hdrs { c.SetCommonHeaderNonCanonical(k, v) @@ -836,20 +846,21 @@ func (c *Client) SetCommonHeadersNonCanonical(hdrs map[string]string) *Client { return c } -// SetCommonContentType set the `Content-Type` header for all requests. +// SetCommonContentType set the `Content-Type` header for requests fired +// from the client. func (c *Client) SetCommonContentType(ct string) *Client { c.SetCommonHeader(header.ContentType, ct) return c } -// DisableDumpAll disable dump for all requests. +// DisableDumpAll disable dump for requests fired from the client. func (c *Client) DisableDumpAll() *Client { c.t.DisableDump() return c } // SetCommonDumpOptions configures the underlying Transport's DumpOptions -// for all requests. +// for requests fired from the client. func (c *Client) SetCommonDumpOptions(opt *DumpOptions) *Client { if opt == nil { return c @@ -902,13 +913,14 @@ func (c *Client) SetProxyURL(proxyUrl string) *Client { return c } -// DisableTraceAll disable trace for all requests. +// DisableTraceAll disable trace for requests fired from the client. func (c *Client) DisableTraceAll() *Client { c.trace = false return c } -// EnableTraceAll enable trace for all requests (http3 currently does not support trace). +// EnableTraceAll enable trace for requests fired from the client (http3 +// currently does not support trace). func (c *Client) EnableTraceAll() *Client { c.trace = true return c @@ -1045,14 +1057,15 @@ func (c *Client) getRetryOption() *retryOption { return c.retryOption } -// SetCommonRetryCount enables retry and set the maximum retry count for all requests. +// SetCommonRetryCount enables retry and set the maximum retry count for requests +// fired from the client. func (c *Client) SetCommonRetryCount(count int) *Client { c.getRetryOption().MaxRetries = count return c } -// SetCommonRetryInterval sets the custom GetRetryIntervalFunc for all requests, -// you can use this to implement your own backoff retry algorithm. +// SetCommonRetryInterval sets the custom GetRetryIntervalFunc for requests fired +// from the client, you can use this to implement your own backoff retry algorithm. // For example: // // req.SetCommonRetryInterval(func(resp *req.Response, attempt int) time.Duration { @@ -1064,7 +1077,8 @@ func (c *Client) SetCommonRetryInterval(getRetryIntervalFunc GetRetryIntervalFun return c } -// SetCommonRetryFixedInterval set retry to use a fixed interval for all requests. +// SetCommonRetryFixedInterval set retry to use a fixed interval for requests +// fired from the client. func (c *Client) SetCommonRetryFixedInterval(interval time.Duration) *Client { c.getRetryOption().GetRetryInterval = func(resp *Response, attempt int) time.Duration { return interval @@ -1072,8 +1086,8 @@ func (c *Client) SetCommonRetryFixedInterval(interval time.Duration) *Client { return c } -// SetCommonRetryBackoffInterval set retry to use a capped exponential backoff with jitter -// for all requests. +// SetCommonRetryBackoffInterval set retry to use a capped exponential backoff +// with jitter for requests fired from the client. // https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ func (c *Client) SetCommonRetryBackoffInterval(min, max time.Duration) *Client { c.getRetryOption().GetRetryInterval = backoffInterval(min, max) @@ -1087,8 +1101,8 @@ func (c *Client) SetCommonRetryHook(hook RetryHookFunc) *Client { return c } -// AddCommonRetryHook adds a retry hook for all requests, which will be -// executed before a retry. +// AddCommonRetryHook adds a retry hook for requests fired from the client, +// which will be executed before a retry. func (c *Client) AddCommonRetryHook(hook RetryHookFunc) *Client { ro := c.getRetryOption() ro.RetryHooks = append(ro.RetryHooks, hook) From da458d69c453f8758acf7305ccd5ea2c0ad1f870 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 18 Jan 2023 16:05:18 +0800 Subject: [PATCH 660/843] Rename deprecated methods to new one in tests --- client_test.go | 10 +++++----- req_test.go | 8 ++++---- request_test.go | 36 ++++++++++++++++++------------------ 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/client_test.go b/client_test.go index 7a6c4e1b..712cd118 100644 --- a/client_test.go +++ b/client_test.go @@ -270,7 +270,7 @@ func TestSetCommonCookies(t *testing.T) { resp, err := tc().SetCommonCookies(&http.Cookie{ Name: "test", Value: "test", - }).R().SetResult(&headers).Get("/header") + }).R().SetSuccessResult(&headers).Get("/header") assertSuccess(t, resp, err) tests.AssertEqual(t, "test=test", headers.Get("Cookie")) } @@ -434,7 +434,7 @@ func TestSetCommonFormDataFromValues(t *testing.T) { expectedForm.Set("test", "test") resp, err := tc(). SetCommonFormDataFromValues(expectedForm). - R().SetResult(&gotForm). + R().SetSuccessResult(&gotForm). Post("/form") assertSuccess(t, resp, err) tests.AssertEqual(t, "test", gotForm.Get("test")) @@ -447,7 +447,7 @@ func TestSetCommonFormData(t *testing.T) { map[string]string{ "test": "test", }).R(). - SetResult(&form). + SetSuccessResult(&form). Post("/form") assertSuccess(t, resp, err) tests.AssertEqual(t, "test", form.Get("test")) @@ -595,14 +595,14 @@ func TestEnableDumpAllAsync(t *testing.T) { func TestSetResponseBodyTransformer(t *testing.T) { c := tc().SetResponseBodyTransformer(func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error) { - if resp.IsSuccess() { + if resp.IsSuccessState() { result, err := url.QueryUnescape(string(rawBody)) return []byte(result), err } return rawBody, nil }) user := &UserInfo{} - resp, err := c.R().SetResult(user).Get("/urlencode") + resp, err := c.R().SetSuccessResult(user).Get("/urlencode") assertSuccess(t, resp, err) tests.AssertEqual(t, user.Username, "我是roc") tests.AssertEqual(t, user.Email, "roc@imroc.cc") diff --git a/req_test.go b/req_test.go index b019dc88..35f9f574 100644 --- a/req_test.go +++ b/req_test.go @@ -347,8 +347,8 @@ func assertSuccess(t *testing.T, resp *Response, err error) { tests.AssertNotNil(t, resp.Response.Body) tests.AssertEqual(t, http.StatusOK, resp.StatusCode) tests.AssertEqual(t, "200 OK", resp.Status) - if !resp.IsSuccess() { - t.Error("Response.IsSuccess should return true") + if !resp.IsSuccessState() { + t.Error("Response.IsSuccessState should return true") } } @@ -356,8 +356,8 @@ func assertIsError(t *testing.T, resp *Response, err error) { tests.AssertNoError(t, err) tests.AssertNotNil(t, resp) tests.AssertNotNil(t, resp.Body) - if !resp.IsError() { - t.Error("Response.IsError should return true") + if !resp.IsErrorState() { + t.Error("Response.IsErrorState should return true") } } diff --git a/request_test.go b/request_test.go index 70eeb963..65499775 100644 --- a/request_test.go +++ b/request_test.go @@ -353,7 +353,7 @@ func TestSetBodyMarshal(t *testing.T) { r := c.R() tc.Set(r) var e Echo - resp, err := r.SetResult(&e).Post("/echo") + resp, err := r.SetSuccessResult(&e).Post("/echo") assertSuccess(t, resp, err) tc.Assert([]byte(e.Body)) } @@ -369,22 +369,22 @@ func TestDoAPIStyle(t *testing.T) { tests.AssertEqual(t, "imroc", user.Username) } -func TestSetResult(t *testing.T) { +func TestSetSuccessResult(t *testing.T) { c := tc() var user *UserInfo url := "/search?username=imroc&type=json" - resp, err := c.R().SetResult(&user).Get(url) + resp, err := c.R().SetSuccessResult(&user).Get(url) assertSuccess(t, resp, err) tests.AssertEqual(t, "imroc", user.Username) user = &UserInfo{} - resp, err = c.R().SetResult(user).Get(url) + resp, err = c.R().SetSuccessResult(user).Get(url) assertSuccess(t, resp, err) tests.AssertEqual(t, "imroc", user.Username) user = nil - resp, err = c.R().SetResult(user).Get(url) + resp, err = c.R().SetSuccessResult(user).Get(url) assertSuccess(t, resp, err) tests.AssertEqual(t, "imroc", resp.Result().(*UserInfo).Username) } @@ -472,7 +472,7 @@ func TestSetBody(t *testing.T) { r := c.R() tc.SetBody(r) var e Echo - resp, err := r.SetResult(&e).Post("/echo") + resp, err := r.SetSuccessResult(&e).Post("/echo") assertSuccess(t, resp, err) tests.AssertEqual(t, tc.ContentType, e.Header.Get(header.ContentType)) tests.AssertEqual(t, body, e.Body) @@ -490,7 +490,7 @@ func TestCookie(t *testing.T) { Name: "cookie2", Value: "value2", }, - ).SetResult(&headers).Get("/header") + ).SetSuccessResult(&headers).Get("/header") assertSuccess(t, resp, err) tests.AssertEqual(t, "cookie1=value1; cookie2=value2", headers.Get("Cookie")) } @@ -499,7 +499,7 @@ func TestSetBasicAuth(t *testing.T) { headers := make(http.Header) resp, err := tc().R(). SetBasicAuth("imroc", "123456"). - SetResult(&headers). + SetSuccessResult(&headers). Get("/header") assertSuccess(t, resp, err) tests.AssertEqual(t, "Basic aW1yb2M6MTIzNDU2", headers.Get("Authorization")) @@ -510,7 +510,7 @@ func TestSetBearerAuthToken(t *testing.T) { headers := make(http.Header) resp, err := tc().R(). SetBearerAuthToken(token). - SetResult(&headers). + SetSuccessResult(&headers). Get("/header") assertSuccess(t, resp, err) tests.AssertEqual(t, "Bearer "+token, headers.Get("Authorization")) @@ -534,7 +534,7 @@ func testHeader(t *testing.T, c *Client) { SetHeaders(map[string]string{ "header2": "value2", "header3": "value3", - }).SetResult(&headers). + }).SetSuccessResult(&headers). Get("/header") assertSuccess(t, resp, err) tests.AssertEqual(t, "value1", headers.Get("header1")) @@ -661,7 +661,7 @@ func testSuccess(t *testing.T, c *Client) { var userInfo UserInfo resp, err := c.R(). SetQueryParam("username", "imroc"). - SetResult(&userInfo). + SetSuccessResult(&userInfo). Get("/search") assertSuccess(t, resp, err) tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) @@ -670,7 +670,7 @@ func testSuccess(t *testing.T, c *Client) { resp, err = c.R(). SetQueryParam("username", "imroc"). SetQueryParam("type", "xml"). // auto unmarshal to xml - SetResult(&userInfo).EnableDump(). + SetSuccessResult(&userInfo).EnableDump(). Get("/search") assertSuccess(t, resp, err) tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) @@ -684,7 +684,7 @@ func testError(t *testing.T, c *Client) { var errMsg ErrorMessage resp, err := c.R(). SetQueryParam("username", ""). - SetError(&errMsg). + SetErrorResult(&errMsg). Get("/search") assertIsError(t, resp, err) tests.AssertEqual(t, 10000, errMsg.ErrorCode) @@ -692,7 +692,7 @@ func testError(t *testing.T, c *Client) { errMsg = ErrorMessage{} resp, err = c.R(). SetQueryParam("username", "test"). - SetError(&errMsg). + SetErrorResult(&errMsg). Get("/search") assertIsError(t, resp, err) tests.AssertEqual(t, 10001, errMsg.ErrorCode) @@ -701,7 +701,7 @@ func testError(t *testing.T, c *Client) { resp, err = c.R(). SetQueryParam("username", "test"). SetQueryParam("type", "xml"). // auto unmarshal to xml - SetError(&errMsg). + SetErrorResult(&errMsg). Get("/search") assertIsError(t, resp, err) tests.AssertEqual(t, 10001, errMsg.ErrorCode) @@ -727,7 +727,7 @@ func testForm(t *testing.T, c *Client) { "username": "imroc", "type": "xml", }). - SetResult(&userInfo). + SetSuccessResult(&userInfo). Post("/search") assertSuccess(t, resp, err) tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) @@ -737,7 +737,7 @@ func testForm(t *testing.T, c *Client) { v.Add("type", "xml") resp, err = c.R(). SetFormDataFromValues(v). - SetResult(&userInfo). + SetSuccessResult(&userInfo). Post("/search") assertSuccess(t, resp, err) tests.AssertEqual(t, "roc@imroc.cc", userInfo.Email) @@ -872,7 +872,7 @@ func TestUploadMultipart(t *testing.T) { "param1": "value1", "param2": "value2", }). - SetResult(&m). + SetSuccessResult(&m). Post("/multipart") assertSuccess(t, resp, err) tests.AssertContains(t, resp.String(), "sample-image.png", true) From 34449599b8883627c92bf7c8d5004764f2e97824 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 18 Jan 2023 16:08:46 +0800 Subject: [PATCH 661/843] Remove global wrapper tests --- client_wrapper_test.go | 143 ----------------------------------- request_wrapper_test.go | 163 ---------------------------------------- 2 files changed, 306 deletions(-) delete mode 100644 client_wrapper_test.go delete mode 100644 request_wrapper_test.go diff --git a/client_wrapper_test.go b/client_wrapper_test.go deleted file mode 100644 index 40a1ab48..00000000 --- a/client_wrapper_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package req - -import ( - "crypto/tls" - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/tests" - "net/http" - "net/url" - "os" - "path/filepath" - "testing" - "time" -) - -func TestGlobalWrapper(t *testing.T) { - EnableInsecureSkipVerify() - testGlobalWrapperSendMethods(t) - testGlobalWrapperMustSendMethods(t) - DisableInsecureSkipVerify() - - u, _ := url.Parse("http://dummy.proxy.local") - proxy := http.ProxyURL(u) - form := make(url.Values) - form.Add("test", "test") - - tests.AssertAllNotNil(t, - SetCommonError(nil), - SetCookieJar(nil), - SetDialTLS(nil), - SetDial(nil), - SetTLSHandshakeTimeout(time.Second), - EnableAllowGetMethodPayload(), - DisableAllowGetMethodPayload(), - SetJsonMarshal(nil), - SetJsonUnmarshal(nil), - SetXmlMarshal(nil), - SetXmlUnmarshal(nil), - EnableTraceAll(), - DisableTraceAll(), - OnAfterResponse(func(client *Client, response *Response) error { - return nil - }), - OnBeforeRequest(func(client *Client, request *Request) error { - return nil - }), - SetProxyURL("http://dummy.proxy.local"), - SetProxyURL("bad url"), - SetProxy(proxy), - SetCommonContentType(header.JsonContentType), - SetCommonHeader("my-header", "my-value"), - SetCommonHeaders(map[string]string{ - "header1": "value1", - "header2": "value2", - }), - SetCommonBasicAuth("imroc", "123456"), - SetCommonBearerAuthToken("123456"), - SetUserAgent("test"), - SetTimeout(1*time.Second), - SetLogger(createDefaultLogger()), - SetScheme("https"), - EnableDebugLog(), - DisableDebugLog(), - SetCommonCookies(&http.Cookie{Name: "test", Value: "test"}), - SetCommonQueryString("test1=test1"), - SetCommonPathParams(map[string]string{"test1": "test1"}), - SetCommonPathParam("test2", "test2"), - AddCommonQueryParam("test1", "test11"), - SetCommonQueryParam("test1", "test111"), - SetCommonQueryParams(map[string]string{"test1": "test1"}), - EnableInsecureSkipVerify(), - DisableInsecureSkipVerify(), - DisableCompression(), - EnableCompression(), - DisableKeepAlives(), - EnableKeepAlives(), - SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")), - SetRootCertFromString(string(getTestFileContent(t, "sample-root.pem"))), - SetCerts(tls.Certificate{}, tls.Certificate{}), - SetCertFromFile( - tests.GetTestFilePath("sample-client.pem"), - tests.GetTestFilePath("sample-client-key.pem"), - ), - SetOutputDirectory(testDataPath), - SetBaseURL("http://dummy-req.local/test"), - SetCommonFormDataFromValues(form), - SetCommonFormData(map[string]string{"test2": "test2"}), - DisableAutoReadResponse(), - EnableAutoReadResponse(), - EnableDumpAll(), - EnableDumpAllAsync(), - EnableDumpAllWithoutBody(), - EnableDumpAllWithoutResponse(), - EnableDumpAllWithoutRequest(), - EnableDumpAllWithoutHeader(), - SetLogger(nil), - EnableDumpAllToFile(filepath.Join(testDataPath, "path-not-exists", "dump.out")), - EnableDumpAllToFile(tests.GetTestFilePath("tmpdump.out")), - SetCommonDumpOptions(&DumpOptions{ - RequestHeader: true, - }), - DisableDumpAll(), - SetRedirectPolicy(NoRedirectPolicy()), - EnableForceHTTP1(), - EnableForceHTTP2(), - EnableForceHTTP3(), - EnableHTTP3(), - DisableForceHttpVersion(), - SetAutoDecodeContentType("json"), - SetAutoDecodeContentTypeFunc(func(contentType string) bool { return true }), - SetAutoDecodeAllContentType(), - DisableAutoDecode(), - EnableAutoDecode(), - AddCommonRetryCondition(func(resp *Response, err error) bool { return true }), - SetCommonRetryCondition(func(resp *Response, err error) bool { return true }), - AddCommonRetryHook(func(resp *Response, err error) {}), - SetCommonRetryHook(func(resp *Response, err error) {}), - SetCommonRetryCount(2), - SetCommonRetryInterval(func(resp *Response, attempt int) time.Duration { - return 1 * time.Second - }), - SetCommonRetryBackoffInterval(1*time.Millisecond, 2*time.Second), - SetCommonRetryFixedInterval(1*time.Second), - SetUnixSocket("/var/run/custom.sock"), - ) - os.Remove(tests.GetTestFilePath("tmpdump.out")) - - config := GetTLSClientConfig() - tests.AssertEqual(t, config, DefaultClient().t.TLSClientConfig) - - r := R() - tests.AssertEqual(t, true, r != nil) - c := C() - - c.SetTimeout(10 * time.Second) - SetDefaultClient(c) - tests.AssertEqual(t, true, DefaultClient().httpClient.Timeout == 10*time.Second) - tests.AssertEqual(t, GetClient(), DefaultClient().httpClient) - - r = NewRequest() - tests.AssertEqual(t, true, r != nil) - c = NewClient() - tests.AssertEqual(t, true, c != nil) -} diff --git a/request_wrapper_test.go b/request_wrapper_test.go deleted file mode 100644 index 91a244a6..00000000 --- a/request_wrapper_test.go +++ /dev/null @@ -1,163 +0,0 @@ -package req - -import ( - "bytes" - "context" - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/tests" - "net/http" - "testing" - "time" -) - -func init() { - SetLogger(nil) // disable log -} - -func TestGlobalWrapperForRequestSettings(t *testing.T) { - tests.AssertAllNotNil(t, - SetFiles(map[string]string{"test": "req.go"}), - SetFile("test", "req.go"), - SetFileReader("test", "test.txt", bytes.NewBufferString("test")), - SetFileBytes("test", "test.txt", []byte("test")), - SetFileUpload(FileUpload{}), - SetError(&ErrorMessage{}), - SetResult(&UserInfo{}), - SetOutput(new(bytes.Buffer)), - SetHeader("test", "test"), - SetHeaders(map[string]string{"test": "test"}), - SetCookies(&http.Cookie{ - Name: "test", - Value: "test", - }), - SetBasicAuth("imroc", "123456"), - SetBearerAuthToken("123456"), - SetQueryString("test=test"), - SetQueryString("ksjlfjk?"), - SetQueryParam("test", "test"), - AddQueryParam("test", "test"), - SetQueryParams(map[string]string{"test": "test"}), - SetPathParam("test", "test"), - SetPathParams(map[string]string{"test": "test"}), - SetFormData(map[string]string{"test": "test"}), - SetURL(""), - SetFormDataFromValues(nil), - SetContentType(header.JsonContentType), - AddRetryCondition(func(rep *Response, err error) bool { - return err != nil - }), - SetRetryCondition(func(rep *Response, err error) bool { - return err != nil - }), - AddRetryHook(func(resp *Response, err error) {}), - SetRetryHook(func(resp *Response, err error) {}), - SetRetryBackoffInterval(0, 0), - SetRetryFixedInterval(0), - SetRetryInterval(func(resp *Response, attempt int) time.Duration { - return 1 * time.Millisecond - }), - SetRetryCount(3), - SetBodyXmlMarshal(0), - SetBodyString("test"), - SetBodyBytes([]byte("test")), - SetBodyJsonBytes([]byte(`{"user":"roc"}`)), - SetBodyJsonString(`{"user":"roc"}`), - SetBodyXmlBytes([]byte("test")), - SetBodyXmlString("test"), - SetBody("test"), - SetBodyJsonMarshal(User{ - Name: "roc", - }), - EnableTrace(), - DisableTrace(), - SetContext(context.Background()), - SetUploadCallback(nil), - SetUploadCallbackWithInterval(nil, 0), - SetDownloadCallback(nil), - SetDownloadCallbackWithInterval(nil, 0), - ) -} - -func testGlobalWrapperMustSendMethods(t *testing.T) { - testCases := []struct { - SendReq func(string) *Response - ExpectMethod string - }{ - { - SendReq: MustGet, - ExpectMethod: "GET", - }, - { - SendReq: MustPost, - ExpectMethod: "POST", - }, - { - SendReq: MustPatch, - ExpectMethod: "PATCH", - }, - { - SendReq: MustPut, - ExpectMethod: "PUT", - }, - { - SendReq: MustDelete, - ExpectMethod: "DELETE", - }, - { - SendReq: MustOptions, - ExpectMethod: "OPTIONS", - }, - { - SendReq: MustHead, - ExpectMethod: "HEAD", - }, - } - url := getTestServerURL() + "/" - for _, tc := range testCases { - resp := tc.SendReq(url) - tests.AssertNotNil(t, resp.Response) - tests.AssertEqual(t, tc.ExpectMethod, resp.Header.Get("Method")) - } -} - -func testGlobalWrapperSendMethods(t *testing.T) { - testCases := []struct { - SendReq func(string) (*Response, error) - ExpectMethod string - }{ - { - SendReq: Get, - ExpectMethod: "GET", - }, - { - SendReq: Post, - ExpectMethod: "POST", - }, - { - SendReq: Patch, - ExpectMethod: "PATCH", - }, - { - SendReq: Put, - ExpectMethod: "PUT", - }, - { - SendReq: Delete, - ExpectMethod: "DELETE", - }, - { - SendReq: Options, - ExpectMethod: "OPTIONS", - }, - { - SendReq: Head, - ExpectMethod: "HEAD", - }, - } - url := getTestServerURL() + "/" - for _, tc := range testCases { - resp, err := tc.SendReq(url) - assertSuccess(t, resp, err) - tests.AssertEqual(t, tc.ExpectMethod, resp.Header.Get("Method")) - } -} From 9dd6a03f53d5628ac221602cd0e8f6b2f8d28203 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 18 Jan 2023 17:01:18 +0800 Subject: [PATCH 662/843] Update examples --- examples/find-popular-repo/go.mod | 5 +- examples/find-popular-repo/go.sum | 39 +++++------- examples/find-popular-repo/main.go | 61 +++++++++++-------- .../github/github.go | 18 +++--- middleware.go | 9 ++- request_test.go | 2 +- 6 files changed, 68 insertions(+), 66 deletions(-) diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index b042eba8..cc76a320 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -12,12 +12,11 @@ require ( github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect - github.com/lucas-clemente/quic-go v0.28.1 // indirect github.com/marten-seemann/qpack v0.2.1 // indirect github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect github.com/marten-seemann/qtls-go1-17 v0.1.2 // indirect - github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect - github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect + github.com/marten-seemann/qtls-go1-18 v0.1.3 // indirect + github.com/marten-seemann/qtls-go1-19 v0.1.1 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index f5d21907..bad63e63 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -17,6 +17,7 @@ github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wX github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= @@ -36,6 +37,7 @@ github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200j github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -46,6 +48,7 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -62,7 +65,6 @@ github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE0 github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= -github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -77,8 +79,6 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lucas-clemente/quic-go v0.28.1 h1:Uo0lvVxWg5la9gflIF9lwa39ONq85Xq2D91YNEIslzU= -github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= @@ -87,11 +87,10 @@ github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtU github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= -github.com/marten-seemann/qtls-go1-18 v0.1.2 h1:JH6jmzbduz0ITVQ7ShevK10Av5+jBEKAHMntXmIV7kM= -github.com/marten-seemann/qtls-go1-18 v0.1.2/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= -github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= -github.com/marten-seemann/qtls-go1-19 v0.1.0 h1:rLFKD/9mp/uq1SYGYuVZhm83wkmU95pK5df3GufyYYU= -github.com/marten-seemann/qtls-go1-19 v0.1.0/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= +github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sNlqWoDZnjRIE= +github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -105,14 +104,15 @@ github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= -github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= @@ -146,13 +146,14 @@ github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:Udh github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= @@ -190,12 +191,7 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220111093109-d55c255bac03 h1:0FB83qp0AzVJm+0wcIlauAjJ+tNdh7jLuacRYCIVv7s= -golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -210,7 +206,6 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -230,15 +225,12 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -257,13 +249,13 @@ golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= @@ -287,9 +279,9 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= @@ -298,6 +290,7 @@ gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/examples/find-popular-repo/main.go b/examples/find-popular-repo/main.go index c9e43097..44708f3b 100644 --- a/examples/find-popular-repo/main.go +++ b/examples/find-popular-repo/main.go @@ -1,7 +1,6 @@ package main import ( - "errors" "fmt" "strconv" @@ -21,7 +20,23 @@ func main() { } func init() { - req.EnableDumpAllWithoutBody().EnableDebugLog().EnableTraceAll() + req.EnableDebugLog(). + EnableTraceAll(). + EnableDumpEachRequest(). + SetCommonErrorResult(&ErrorMessage{}). + OnAfterResponse(func(client *req.Client, resp *req.Response) error { + if resp.Err != nil { + return nil + } + if errMsg, ok := resp.ErrorResult().(*ErrorMessage); ok { + resp.Err = errMsg + return nil + } + if !resp.IsSuccessState() { + resp.Err = fmt.Errorf("bad status: %s\nraw content:\n%s", resp.Status, resp.Dump()) + } + return nil + }) } type Repo struct { @@ -32,14 +47,16 @@ type ErrorMessage struct { Message string `json:"message"` } -func findTheMostPopularRepo(username string) (repo string, star int, err error) { +func (msg *ErrorMessage) Error() string { + return fmt.Sprintf("API Error: %s", msg.Message) +} +func findTheMostPopularRepo(username string) (repo string, star int, err error) { var popularRepo Repo var resp *req.Response for page := 1; ; page++ { repos := []*Repo{} - errMsg := ErrorMessage{} resp, err = req.SetHeader("Accept", "application/vnd.github.v3+json"). SetQueryParams(map[string]string{ "type": "owner", @@ -49,8 +66,7 @@ func findTheMostPopularRepo(username string) (repo string, star int, err error) "direction": "desc", }). SetPathParam("username", username). - SetResult(&repos). - SetError(&errMsg). + SetSuccessResult(&repos). Get("https://api.github.com/users/{username}/repos") fmt.Println("TraceInfo:") @@ -62,29 +78,20 @@ func findTheMostPopularRepo(username string) (repo string, star int, err error) return } - if resp.IsSuccess() { // HTTP status `code >= 200 and <= 299` is considred as success - for _, repo := range repos { - if repo.Star >= popularRepo.Star { - popularRepo = *repo - } - } - if len(repo) == 100 { // Try Next page - continue - } - // All repos have been traversed, return the final result - repo = popularRepo.Name - star = popularRepo.Star - return - } else if resp.IsError() { // HTTP status `code >= 400` is considred as an error - // Extract the error message, wrap and return err - err = errors.New(errMsg.Message) + if !resp.IsSuccessState() { // HTTP status `code >= 200 and <= 299` is considered as success by default return } - - // Unkown http status code, record and return error, here we can use - // String() to get response body, cuz response body have already been read - // and no error returned, do not need to use ToString(). - err = fmt.Errorf("unknown error. status code %d; body: %s", resp.StatusCode, resp.String()) + for _, repo := range repos { + if repo.Star >= popularRepo.Star { + popularRepo = *repo + } + } + if len(repo) == 100 { // Try Next page + continue + } + // All repos have been traversed, return the final result + repo = popularRepo.Name + star = popularRepo.Star return } } diff --git a/examples/opentelemetry-jaeger-tracing/github/github.go b/examples/opentelemetry-jaeger-tracing/github/github.go index 894553f4..fb914898 100644 --- a/examples/opentelemetry-jaeger-tracing/github/github.go +++ b/examples/opentelemetry-jaeger-tracing/github/github.go @@ -59,23 +59,23 @@ func NewClient() *Client { // request middleware EnableDumpEachRequest(). // Unmarshal response body into an APIError struct when status >= 400. - SetCommonError(&APIError{}). + SetCommonErrorResult(&APIError{}). // Handle common exceptions in response middleware. OnAfterResponse(func(client *req.Client, resp *req.Response) error { if resp.Err != nil { // There is an underlying error, e.g. network error or unmarshal error(SetResult or SetError was invoked before). if dump := resp.Dump(); dump != "" { // Append dump content to original underlying error to help troubleshoot. - resp.Err = fmt.Errorf("%s\nraw content:\n%s", resp.Err.Error(), resp.Dump()) + resp.Err = fmt.Errorf("error: %s\nraw content:\n%s", resp.Err.Error(), resp.Dump()) } return nil // Skip the following logic if there is an underlying error. } - if err, ok := resp.Error().(*APIError); ok { // Server returns an error message. - // Convert it to human-readable go error. + if err, ok := resp.ErrorResult().(*APIError); ok { // Server returns an error message. + // Convert it to human-readable go error which implements the error interface. resp.Err = err return nil } - // Corner case: neither an error response nor a success response, e.g. status code < 200 - // Just dump the raw content into error to help troubleshoot. - if !resp.IsSuccess() { + // Corner case: neither an error response nor a success response, e.g. status code < 200 or + // code >= 300 && code <= 399, just dump the raw content into error to help troubleshoot. + if !resp.IsSuccessState() { resp.Err = fmt.Errorf("bad response, raw content:\n%s", resp.Dump()) } return nil @@ -145,7 +145,7 @@ type UserProfile struct { func (c *Client) GetUserProfile(ctx context.Context, username string) (user *UserProfile, err error) { err = c.Get("/users/{username}"). SetPathParam("username", username). - SetResult(&user). + SetSuccessResult(&user). Do(withAPIName(ctx, "GetUserProfile")).Err return } @@ -167,7 +167,7 @@ func (c *Client) ListUserRepo(ctx context.Context, username string, page int) (r "sort": "updated", "direction": "desc", }). - SetResult(&repos). + SetSuccessResult(&repos). Do(withAPIName(ctx, "ListUserRepo")).Err return } diff --git a/middleware.go b/middleware.go index a53ae46c..1e05be7e 100644 --- a/middleware.go +++ b/middleware.go @@ -254,19 +254,22 @@ func defaultResultStateChecker(resp *Response) ResultState { } func parseResponseBody(c *Client, r *Response) (err error) { - if r.Response == nil || r.StatusCode == http.StatusNoContent { + if r.Response == nil { return } req := r.Request switch req.resultStateCheckFunc(r) { case SuccessState: - if req.Result != nil { + if req.Result != nil && r.StatusCode != http.StatusNoContent { err = unmarshalBody(c, r, r.Request.Result) if err == nil { r.result = r.Request.Result } } case ErrorState: + if r.StatusCode == http.StatusNoContent { + return + } if req.Error != nil { err = unmarshalBody(c, r, req.Error) if err == nil { @@ -285,7 +288,7 @@ func parseResponseBody(c *Client, r *Response) (err error) { handleUnknownResult = c.unknownResultHandlerFunc } if handleUnknownResult != nil { - handleUnknownResult(r) + return handleUnknownResult(r) } } return diff --git a/request_test.go b/request_test.go index 65499775..c6800846 100644 --- a/request_test.go +++ b/request_test.go @@ -706,7 +706,7 @@ func testError(t *testing.T, c *Client) { assertIsError(t, resp, err) tests.AssertEqual(t, 10001, errMsg.ErrorCode) - c.SetCommonError(&errMsg) + c.SetCommonErrorResult(&errMsg) resp, err = c.R(). SetQueryParam("username", ""). Get("/search") From 3af71dfa05bdc2b63aec170131559eb51f341822 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 24 Jan 2023 20:40:28 +0800 Subject: [PATCH 663/843] Update examples --- examples/opentelemetry-jaeger-tracing/github/github.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/opentelemetry-jaeger-tracing/github/github.go b/examples/opentelemetry-jaeger-tracing/github/github.go index fb914898..b7898301 100644 --- a/examples/opentelemetry-jaeger-tracing/github/github.go +++ b/examples/opentelemetry-jaeger-tracing/github/github.go @@ -62,7 +62,7 @@ func NewClient() *Client { SetCommonErrorResult(&APIError{}). // Handle common exceptions in response middleware. OnAfterResponse(func(client *req.Client, resp *req.Response) error { - if resp.Err != nil { // There is an underlying error, e.g. network error or unmarshal error(SetResult or SetError was invoked before). + if resp.Err != nil { // There is an underlying error, e.g. network error or unmarshal error (SetSuccessResult or SetErrorResult was invoked before). if dump := resp.Dump(); dump != "" { // Append dump content to original underlying error to help troubleshoot. resp.Err = fmt.Errorf("error: %s\nraw content:\n%s", resp.Err.Error(), resp.Dump()) } @@ -167,8 +167,8 @@ func (c *Client) ListUserRepo(ctx context.Context, username string, page int) (r "sort": "updated", "direction": "desc", }). - SetSuccessResult(&repos). - Do(withAPIName(ctx, "ListUserRepo")).Err + Do(withAPIName(ctx, "ListUserRepo")). + Into(&repos) return } From f0e23255b6c068d17dc83d57065b2e60bf137aba Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 24 Jan 2023 20:59:42 +0800 Subject: [PATCH 664/843] SetCommonResultStateCheckFunc --> SetResultStateCheckFunc --- client.go | 31 ++++++++++++++++++++----------- client_wrapper.go | 6 +++--- middleware.go | 2 +- request.go | 35 ++++------------------------------- request_wrapper.go | 6 ------ response.go | 10 ++++++++-- 6 files changed, 36 insertions(+), 54 deletions(-) diff --git a/client.go b/client.go index 68974140..b98eb44b 100644 --- a/client.go +++ b/client.go @@ -189,10 +189,28 @@ func (c *Client) SetCommonUnknownResultHandlerFunc(fn func(resp *Response) error return c } -// SetCommonResultStateCheckFunc overrides the default result state checker with customized one, +// ResultState represents the state of the result. +type ResultState int + +const ( + // SuccessState indicates the response is in success state, + // and result will be unmarshalled if Request.SetSuccessResult + // is called. + SuccessState ResultState = iota + // ErrorState indicates the response is in error state, + // and result will be unmarshalled if Request.SetErrorResult + // or Client.SetCommonErrorResult is called. + ErrorState + // UnknownState indicates the response is in unknown state, + // and handler will be invoked if Request.SetUnknownResultHandlerFunc + // or Client.SetCommonUnknownResultHandlerFunc is called. + UnknownState +) + +// SetResultStateCheckFunc overrides the default result state checker with customized one, // which returns SuccessState when HTTP status `code >= 200 and <= 299`, and returns // ErrorState when HTTP status `code >= 400`, otherwise returns UnknownState. -func (c *Client) SetCommonResultStateCheckFunc(fn func(resp *Response) ResultState) *Client { +func (c *Client) SetResultStateCheckFunc(fn func(resp *Response) ResultState) *Client { c.resultStateCheckFunc = fn return c } @@ -1362,15 +1380,6 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { httpResponse, err = c.httpClient.Do(r.RawRequest) resp.Response = httpResponse - // setup resultStateCheckFunc - if r.resultStateCheckFunc == nil { - if c.resultStateCheckFunc != nil { - r.resultStateCheckFunc = c.resultStateCheckFunc - } else { - r.resultStateCheckFunc = defaultResultStateChecker - } - } - // auto-read response body if possible if err == nil && !c.disableAutoReadResponse && !r.isSaveResponse && !r.disableAutoReadResponse { _, err = resp.ToBytes() diff --git a/client_wrapper.go b/client_wrapper.go index 05e22b54..aa366c9f 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -42,10 +42,10 @@ func SetCommonUnknownResultHandlerFunc(fn func(resp *Response) error) *Client { return defaultClient.SetCommonUnknownResultHandlerFunc(fn) } -// SetCommonResultStateCheckFunc is a global wrapper methods which delegated +// SetResultStateCheckFunc is a global wrapper methods which delegated // to the default client's SetCommonResultStateCheckFunc. -func SetCommonResultStateCheckFunc(fn func(resp *Response) ResultState) *Client { - return defaultClient.SetCommonResultStateCheckFunc(fn) +func SetResultStateCheckFunc(fn func(resp *Response) ResultState) *Client { + return defaultClient.SetResultStateCheckFunc(fn) } // SetCommonFormDataFromValues is a global wrapper methods which delegated diff --git a/middleware.go b/middleware.go index 1e05be7e..9eed2af6 100644 --- a/middleware.go +++ b/middleware.go @@ -258,7 +258,7 @@ func parseResponseBody(c *Client, r *Response) (err error) { return } req := r.Request - switch req.resultStateCheckFunc(r) { + switch r.ResultState() { case SuccessState: if req.Result != nil && r.StatusCode != http.StatusNoContent { err = unmarshalBody(c, r, r.Request.Result) diff --git a/request.go b/request.go index af655817..6d859277 100644 --- a/request.go +++ b/request.go @@ -67,7 +67,6 @@ type Request struct { dumpBuffer *bytes.Buffer responseReturnTime time.Time unknownResultHandlerFunc func(resp *Response) error - resultStateCheckFunc func(resp *Response) ResultState } type GetContentFunc func() (io.ReadCloser, error) @@ -352,7 +351,7 @@ func (r *Request) SetDownloadCallbackWithInterval(callback DownloadCallback, min // SetResult set the result that response Body will be unmarshalled to if // no error occurs and Response.ResultState() returns SuccessState, by default // it requires HTTP status `code >= 200 && code <= 299`, you can also use -// Request.SetResultStateCheckFunc or Client.SetCommonResultStateCheckFunc to customize +// Request.SetResultStateCheckFunc or Client.SetResultStateCheckFunc to customize // the result state check logic. // // Deprecated: Use SetSuccessResult instead. @@ -363,7 +362,7 @@ func (r *Request) SetResult(result interface{}) *Request { // SetSuccessResult set the result that response Body will be unmarshalled to if // no error occurs and Response.ResultState() returns SuccessState, by default // it requires HTTP status `code >= 200 && code <= 299`, you can also use -// Request.SetResultStateCheckFunc or Client.SetCommonResultStateCheckFunc to customize +// Request.SetResultStateCheckFunc or Client.SetResultStateCheckFunc to customize // the result state check logic. func (r *Request) SetSuccessResult(result interface{}) *Request { r.Result = util.GetPointer(result) @@ -373,7 +372,7 @@ func (r *Request) SetSuccessResult(result interface{}) *Request { // SetError set the result that response body will be unmarshalled to if // no error occurs and Response.ResultState() returns ErrorState, by default // it requires HTTP status `code >= 400`, you can also use Request.SetResultStateCheckFunc -// or Client.SetCommonResultStateCheckFunc to customize the result state check logic. +// or Client.SetResultStateCheckFunc to customize the result state check logic. // // Deprecated: Use SetErrorResult result. func (r *Request) SetError(error interface{}) *Request { @@ -383,7 +382,7 @@ func (r *Request) SetError(error interface{}) *Request { // SetErrorResult set the result that response body will be unmarshalled to if // no error occurs and Response.ResultState() returns ErrorState, by default // it requires HTTP status `code >= 400`, you can also use Request.SetResultStateCheckFunc -// or Client.SetCommonResultStateCheckFunc to customize the result state check logic. +// or Client.SetResultStateCheckFunc to customize the result state check logic. func (r *Request) SetErrorResult(error interface{}) *Request { r.Error = util.GetPointer(error) return r @@ -396,32 +395,6 @@ func (r *Request) SetUnknownResultHandlerFunc(fn func(resp *Response) error) *Re return r } -// ResultState represents the state of the result. -type ResultState int - -const ( - // SuccessState indicates the response is in success state, - // and result will be unmarshalled if Request.SetSuccessResult - // is called. - SuccessState ResultState = iota - // ErrorState indicates the response is in error state, - // and result will be unmarshalled if Request.SetErrorResult - // or Client.SetCommonErrorResult is called. - ErrorState - // UnknownState indicates the response is in unknown state, - // and handler will be invoked if Request.SetUnknownResultHandlerFunc - // or Client.SetCommonUnknownResultHandlerFunc is called. - UnknownState -) - -// SetResultStateCheckFunc overrides the default result state checker with customized one, -// which returns SuccessState when HTTP status `code >= 200 and <= 299`, and returns -// ErrorState when HTTP status `code >= 400`, otherwise returns UnknownState. -func (r *Request) SetResultStateCheckFunc(fn func(resp *Response) ResultState) *Request { - r.resultStateCheckFunc = fn - return r -} - // SetBearerAuthToken set bearer auth token for the request. func (r *Request) SetBearerAuthToken(token string) *Request { return r.SetHeader("Authorization", "Bearer "+token) diff --git a/request_wrapper.go b/request_wrapper.go index 47f154f7..a0475686 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -102,12 +102,6 @@ func SetErrorResult(error interface{}) *Request { return defaultClient.R().SetErrorResult(error) } -// SetResultStateCheckFunc is a global wrapper methods which delegated -// to the default client, create a request and SetResultStateCheckFunc for request. -func SetResultStateCheckFunc(fn func(resp *Response) ResultState) *Request { - return defaultClient.R().SetResultStateCheckFunc(fn) -} - // SetUnknownResultHandlerFunc is a global wrapper methods which delegated // to the default client, create a request and SetUnknownResultHandlerFunc for request. func SetUnknownResultHandlerFunc(fn func(resp *Response) error) *Request { diff --git a/response.go b/response.go index eaf29400..6a3e1451 100644 --- a/response.go +++ b/response.go @@ -74,13 +74,19 @@ func (r *Response) GetContentType() string { // ResultState returns the result state. // By default, it returns SuccessState if HTTP status `code >= 400`, and returns // ErrorState if HTTP status `code >= 400`, otherwise returns UnknownState. -// You can also use Request.SetResultStateCheckFunc or Client.SetCommonResultStateCheckFunc +// You can also use Request.SetResultStateCheckFunc or Client.SetResultStateCheckFunc // to customize the result state check logic. func (r *Response) ResultState() ResultState { if r.Response == nil { return UnknownState } - return r.Request.resultStateCheckFunc(r) + var resultStateCheckFunc func(resp *Response) ResultState + if r.Request.client.resultStateCheckFunc != nil { + resultStateCheckFunc = r.Request.client.resultStateCheckFunc + } else { + resultStateCheckFunc = defaultResultStateChecker + } + return resultStateCheckFunc(r) } // Result returns the automatically unmarshalled object if Request.SetSuccessResult From f5594c56cf17ee7251d37fb23f71c6197f86327e Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 24 Jan 2023 21:03:04 +0800 Subject: [PATCH 665/843] Remove SetUnknownResultHandlerFunc --- client.go | 48 +++++++++++++++++++--------------------------- client_wrapper.go | 6 ------ middleware.go | 8 -------- request.go | 8 -------- request_wrapper.go | 6 ------ 5 files changed, 20 insertions(+), 56 deletions(-) diff --git a/client.go b/client.go index b98eb44b..5386ed60 100644 --- a/client.go +++ b/client.go @@ -48,27 +48,26 @@ type Client struct { DebugLog bool AllowGetMethodPayload bool - trace bool - disableAutoReadResponse bool - commonErrorType reflect.Type - retryOption *retryOption - jsonMarshal func(v interface{}) ([]byte, error) - jsonUnmarshal func(data []byte, v interface{}) error - xmlMarshal func(v interface{}) ([]byte, error) - xmlUnmarshal func(data []byte, v interface{}) error - outputDirectory string - scheme string - log Logger - t *Transport - dumpOptions *DumpOptions - httpClient *http.Client - beforeRequest []RequestMiddleware - udBeforeRequest []RequestMiddleware - afterResponse []ResponseMiddleware - wrappedRoundTrip RoundTripper - responseBodyTransformer func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error) - unknownResultHandlerFunc func(resp *Response) error - resultStateCheckFunc func(resp *Response) ResultState + trace bool + disableAutoReadResponse bool + commonErrorType reflect.Type + retryOption *retryOption + jsonMarshal func(v interface{}) ([]byte, error) + jsonUnmarshal func(data []byte, v interface{}) error + xmlMarshal func(v interface{}) ([]byte, error) + xmlUnmarshal func(data []byte, v interface{}) error + outputDirectory string + scheme string + log Logger + t *Transport + dumpOptions *DumpOptions + httpClient *http.Client + beforeRequest []RequestMiddleware + udBeforeRequest []RequestMiddleware + afterResponse []ResponseMiddleware + wrappedRoundTrip RoundTripper + responseBodyTransformer func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error) + resultStateCheckFunc func(resp *Response) ResultState } // R create a new request. @@ -182,13 +181,6 @@ func (c *Client) SetCommonErrorResult(err interface{}) *Client { return c } -// SetCommonUnknownResultHandlerFunc set the response handler which will be executed when no -// error occurs, but Response.ResultState returns UnknownState. -func (c *Client) SetCommonUnknownResultHandlerFunc(fn func(resp *Response) error) *Client { - c.unknownResultHandlerFunc = fn - return c -} - // ResultState represents the state of the result. type ResultState int diff --git a/client_wrapper.go b/client_wrapper.go index aa366c9f..96907b63 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -36,12 +36,6 @@ func SetCommonErrorResult(err interface{}) *Client { return defaultClient.SetCommonErrorResult(err) } -// SetCommonUnknownResultHandlerFunc is a global wrapper methods which delegated -// to the default client's SetCommonUnknownResultHandlerFunc. -func SetCommonUnknownResultHandlerFunc(fn func(resp *Response) error) *Client { - return defaultClient.SetCommonUnknownResultHandlerFunc(fn) -} - // SetResultStateCheckFunc is a global wrapper methods which delegated // to the default client's SetCommonResultStateCheckFunc. func SetResultStateCheckFunc(fn func(resp *Response) ResultState) *Client { diff --git a/middleware.go b/middleware.go index 9eed2af6..b8865bdf 100644 --- a/middleware.go +++ b/middleware.go @@ -282,14 +282,6 @@ func parseResponseBody(c *Client, r *Response) (err error) { r.error = e } } - default: - handleUnknownResult := req.unknownResultHandlerFunc - if handleUnknownResult == nil { - handleUnknownResult = c.unknownResultHandlerFunc - } - if handleUnknownResult != nil { - return handleUnknownResult(r) - } } return } diff --git a/request.go b/request.go index 6d859277..4a48e9ac 100644 --- a/request.go +++ b/request.go @@ -66,7 +66,6 @@ type Request struct { trace *clientTrace dumpBuffer *bytes.Buffer responseReturnTime time.Time - unknownResultHandlerFunc func(resp *Response) error } type GetContentFunc func() (io.ReadCloser, error) @@ -388,13 +387,6 @@ func (r *Request) SetErrorResult(error interface{}) *Request { return r } -// SetUnknownResultHandlerFunc set the response handler which will be executed when no -// error occurs, but Response.ResultState returns UnknownState. -func (r *Request) SetUnknownResultHandlerFunc(fn func(resp *Response) error) *Request { - r.unknownResultHandlerFunc = fn - return r -} - // SetBearerAuthToken set bearer auth token for the request. func (r *Request) SetBearerAuthToken(token string) *Request { return r.SetHeader("Authorization", "Bearer "+token) diff --git a/request_wrapper.go b/request_wrapper.go index a0475686..daa1667a 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -102,12 +102,6 @@ func SetErrorResult(error interface{}) *Request { return defaultClient.R().SetErrorResult(error) } -// SetUnknownResultHandlerFunc is a global wrapper methods which delegated -// to the default client, create a request and SetUnknownResultHandlerFunc for request. -func SetUnknownResultHandlerFunc(fn func(resp *Response) error) *Request { - return defaultClient.R().SetUnknownResultHandlerFunc(fn) -} - // SetBearerAuthToken is a global wrapper methods which delegated // to the default client, create a request and SetBearerAuthToken for request. func SetBearerAuthToken(token string) *Request { From c15612a7a0b4d1d2a4309efcd1320bc3488741ab Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 24 Jan 2023 23:25:16 +0800 Subject: [PATCH 666/843] Add TestSetResultStateCheckFunc --- client_test.go | 25 +++++++++++++++++++++++++ req_test.go | 10 ++++++++++ 2 files changed, 35 insertions(+) diff --git a/client_test.go b/client_test.go index 712cd118..7cb911fd 100644 --- a/client_test.go +++ b/client_test.go @@ -607,3 +607,28 @@ func TestSetResponseBodyTransformer(t *testing.T) { tests.AssertEqual(t, user.Username, "我是roc") tests.AssertEqual(t, user.Email, "roc@imroc.cc") } + +func TestSetResultStateCheckFunc(t *testing.T) { + c := tc().SetResultStateCheckFunc(func(resp *Response) ResultState { + if resp.StatusCode == http.StatusOK { + return SuccessState + } else { + return ErrorState + } + }) + resp, err := c.R().Get("/status?code=200") + tests.AssertNoError(t, err) + tests.AssertEqual(t, SuccessState, resp.ResultState()) + + resp, err = c.R().Get("/status?code=201") + tests.AssertNoError(t, err) + tests.AssertEqual(t, ErrorState, resp.ResultState()) + + resp, err = c.R().Get("/status?code=301") + tests.AssertNoError(t, err) + tests.AssertEqual(t, ErrorState, resp.ResultState()) + + resp, err = c.R().Get("/status?code=404") + tests.AssertNoError(t, err) + tests.AssertEqual(t, ErrorState, resp.ResultState()) +} diff --git a/req_test.go b/req_test.go index 35f9f574..bbb61ee4 100644 --- a/req_test.go +++ b/req_test.go @@ -227,6 +227,16 @@ func handleGet(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": w.Write([]byte("TestGet: text response")) + case "/status": + r.ParseForm() + code := r.FormValue("code") + codeInt, err := strconv.Atoi(code) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } + w.WriteHeader(codeInt) case "/urlencode": info := &UserInfo{ Username: "我是roc", From 11f1141c9a9ec341f3f36b5df4318b6d2fa500f4 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 24 Jan 2023 23:28:49 +0800 Subject: [PATCH 667/843] Update TestSetResultStateCheckFunc --- client_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client_test.go b/client_test.go index 7cb911fd..ee87f14c 100644 --- a/client_test.go +++ b/client_test.go @@ -624,7 +624,7 @@ func TestSetResultStateCheckFunc(t *testing.T) { tests.AssertNoError(t, err) tests.AssertEqual(t, ErrorState, resp.ResultState()) - resp, err = c.R().Get("/status?code=301") + resp, err = c.R().Get("/status?code=399") tests.AssertNoError(t, err) tests.AssertEqual(t, ErrorState, resp.ResultState()) From 2febcc7d5be43ed5a38eb2f29254c949046de12f Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 28 Jan 2023 10:29:25 +0800 Subject: [PATCH 668/843] Update README --- README.md | 472 ++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 421 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index d33858d8..4ba1e787 100644 --- a/README.md +++ b/README.md @@ -49,64 +49,72 @@ import "github.com/imroc/req/v3" **Basic Usage** +```bash +# assume the following codes in main.go file +$ cat main.go +``` + ```go -// For testing, you can create and send a request with the global wrapper methods -// that use the default client behind the scenes to initiate the request (you can -// just treat package name `req` as a Client or Request, no need to create any client -// or Request explicitly). -req.DevMode() // Use Client.DevMode to see all details, try and surprise :) -req.Get("https://httpbin.org/get") // Use Request.Get to send a GET request. - -// In production, create a client explicitly and reuse it to send all requests -client := req.C(). // Use C() to create a client and set with chainable client settings. - SetUserAgent("my-custom-client"). - SetTimeout(5 * time.Second). - DevMode() -resp, err := client.R(). // Use R() to create a request and set with chainable request settings. - SetHeader("Accept", "application/vnd.github.v3+json"). - SetPathParam("username", "imroc"). - SetQueryParam("page", "1"). - SetResult(&result). // Unmarshal response into struct automatically if status code >= 200 and <= 299. - SetError(&errMsg). // Unmarshal response into struct automatically if status code >= 400. - EnableDump(). // Enable dump at request level to help troubleshoot, log content only when an unexpected exception occurs. - Get("https://api.github.com/users/{username}/repos") -if err != nil { - // Handle error. - // ... - return +package main + +import ( + "github.com/imroc/req/v3" +) + +func main() { + req.DevMode() // Treat the package name as a Client, enable development mode + req.MustGet("https://httpbin.org/uuid") // Treat the package name as a Request, send GET request. + + req.EnableForceHTTP1() // Force using HTTP/1.1 + req.MustGet("https://httpbin.org/uuid") } -if resp.IsSuccess() { - // Handle result. - // ... - return +``` + +```bash +$ go run main.go +2022/05/19 10:05:07.920113 DEBUG [req] HTTP/2 GET https://httpbin.org/uuid +:authority: httpbin.org +:method: GET +:path: /uuid +:scheme: https +user-agent: req/v3 (https://github.com/imroc/req/v3) +accept-encoding: gzip + +:status: 200 +date: Thu, 19 May 2022 02:05:08 GMT +content-type: application/json +content-length: 53 +server: gunicorn/19.9.0 +access-control-allow-origin: * +access-control-allow-credentials: true + +{ + "uuid": "bd519208-35d1-4483-ad9f-e1555ae108ba" } -if resp.IsError() { - // Handle errMsg. - // ... - return + +2022/05/19 10:05:09.340974 DEBUG [req] HTTP/1.1 GET https://httpbin.org/uuid +GET /uuid HTTP/1.1 +Host: httpbin.org +User-Agent: req/v3 (https://github.com/imroc/req/v3) +Accept-Encoding: gzip + +HTTP/1.1 200 OK +Date: Thu, 19 May 2022 02:05:09 GMT +Content-Type: application/json +Content-Length: 53 +Connection: keep-alive +Server: gunicorn/19.9.0 +Access-Control-Allow-Origin: * +Access-Control-Allow-Credentials: true + +{ + "uuid": "49b7f916-c6f3-49d4-a6d4-22ae93b71969" } -// Handle unexpected response (corner case). -err = fmt.Errorf("got unexpected response, raw dump:\n%s", resp.Dump()) -// ... ``` -You can also use another style if you want: +The sample code above is good for quick testing purposes, which use `DevMode()` to see request details, and send requests using global wrapper methods that use the default client behind the scenes to initiate the request. -```go -resp := client.Get("https://api.github.com/users/{username}/repos"). // Create a GET request with specified URL. - SetHeader("Accept", "application/vnd.github.v3+json"). - SetPathParam("username", "imroc"). - SetQueryParam("page", "1"). - SetResult(&result). - SetError(&errMsg). - EnableDump(). - Do() // Send request with Do. - -if resp.Err != nil { - // ... -} -// ... -``` +In production, it is recommended to explicitly create a client, and then use the same client to send all requests, please see other examples below. **Videos** @@ -119,6 +127,368 @@ The following is a series of video tutorials for req: Check more introduction, tutorials, examples, best practices and API references on the [official website](https://req.cool/). +##
Simple GET + +```go +package main + +import ( + "fmt" + "github.com/imroc/req/v3" + "log" +) + +func main() { + client := req.C() // Use C() to create a client. + resp, err := client.R(). // Use R() to create a request. + Get("https://httpbin.org/uuid") + if err != nil { + log.Fatal(err) + } + fmt.Println(resp) +} +``` + +```txt +{ + "uuid": "a4d4430d-0e5f-412f-88f5-722d84bc2a62" +} +``` + +## Advanced GET + +```go +package main + +import ( + "fmt" + "github.com/imroc/req/v3" + "log" + "time" +) + +type ErrorMessage struct { + Message string `json:"message"` +} + +type UserInfo struct { + Name string `json:"name"` + Blog string `json:"blog"` +} + +func main() { + client := req.C(). + SetUserAgent("my-custom-client"). // Chainable client settings. + SetTimeout(5 * time.Second) + + var userInfo UserInfo + var errMsg ErrorMessage + resp, err := client.R(). + SetHeader("Accept", "application/vnd.github.v3+json"). // Chainable request settings. + SetPathParam("username", "imroc"). // Replace path variable in url. + SetSuccessResult(&userInfo). // Unmarshal response body into userInfo automatically if status code is between 200 and 299. + SetErrorResult(&errMsg). // Unmarshal response body into errMsg automatically if status code >= 400. + EnableDump(). // Enable dump at request level, only print dump content if there is an error or some unknown situation occurs to help troubleshoot. + Get("https://api.github.com/users/{username}") + + if err != nil { // Error handling. + log.Println("error:", err) + log.Println("raw content:") + log.Println(resp.Dump()) // Record raw content when error occurs. + return + } + + if resp.IsErrorState() { // Status code >= 400. + fmt.Println(errMsg.Message) // Record error message returned. + return + } + + if resp.IsSuccessState() { // Status code is between 200 and 299. + fmt.Printf("%s (%s)\n", userInfo.Name, userInfo.Blog) + return + } + + // Unknown status code. + log.Println("unknown status", resp.Status) + log.Println("raw content:") + log.Println(resp.Dump()) // Record raw content when server returned unknown status code. +} +``` + +Normally it will output (SuccessState): + +```txt +roc (https://imroc.cc) +``` + +## More Advanced GET + +You can set up a unified logic for error handling on the client, so that each time you send a request you only need to focus on the success situation, reducing duplicate code. + +```go +package main + +import ( + "fmt" + "github.com/imroc/req/v3" + "log" + "time" +) + +type ErrorMessage struct { + Message string `json:"message"` +} + +func (msg *ErrorMessage) Error() string { + return fmt.Sprintf("API Error: %s", msg.Message) +} + +type UserInfo struct { + Name string `json:"name"` + Blog string `json:"blog"` +} + +var client = req.C(). + SetUserAgent("my-custom-client"). // Chainable client settings. + SetTimeout(5 * time.Second). + EnableDumpEachRequest(). + SetCommonErrorResult(&ErrorMessage{}). + OnAfterResponse(func(client *req.Client, resp *req.Response) error { + if resp.Err != nil { // There is an underlying error, e.g. network error or unmarshal error. + return nil + } + if errMsg, ok := resp.ErrorResult().(*ErrorMessage); ok { + resp.Err = errMsg // Convert api error into go error + return nil + } + if !resp.IsSuccessState() { + // Neither a success response nor a error response, record details to help troubleshooting + resp.Err = fmt.Errorf("bad status: %s\nraw content:\n%s", resp.Status, resp.Dump()) + } + return nil + }) + +func main() { + var userInfo UserInfo + resp, err := client.R(). + SetHeader("Accept", "application/vnd.github.v3+json"). // Chainable request settings + SetPathParam("username", "imroc"). + SetSuccessResult(&userInfo). // Unmarshal response body into userInfo automatically if status code is between 200 and 299. + Get("https://api.github.com/users/{username}") + + if err != nil { // Error handling. + log.Println("error:", err) + return + } + + if resp.IsSuccessState() { // Status code is between 200 and 299. + fmt.Printf("%s (%s)\n", userInfo.Name, userInfo.Blog) + } +} +``` + +## Simple POST + +```go +package main + +import ( + "fmt" + "github.com/imroc/req/v3" + "log" +) + +type Repo struct { + Name string `json:"name"` + Url string `json:"url"` +} + +type Result struct { + Data string `json:"data"` +} + +func main() { + client := req.C().DevMode() + var result Result + + resp, err := client.R(). + SetBody(&Repo{Name: "req", Url: "https://github.com/imroc/req"}). + SetSuccessResult(&result). + Post("https://httpbin.org/post") + if err != nil { + log.Fatal(err) + } + + if !resp.IsSuccessState() { + fmt.Println("bad response status:", resp.Status) + return + } + fmt.Println("++++++++++++++++++++++++++++++++++++++++++++++++") + fmt.Println("data:", result.Data) + fmt.Println("++++++++++++++++++++++++++++++++++++++++++++++++") +} +``` + +```txt +2022/05/19 20:11:00.151171 DEBUG [req] HTTP/2 POST https://httpbin.org/post +:authority: httpbin.org +:method: POST +:path: /post +:scheme: https +user-agent: req/v3 (https://github.com/imroc/req/v3) +content-type: application/json; charset=utf-8 +content-length: 55 +accept-encoding: gzip + +{"name":"req","website":"https://github.com/imroc/req"} + +:status: 200 +date: Thu, 19 May 2022 12:11:00 GMT +content-type: application/json +content-length: 651 +server: gunicorn/19.9.0 +access-control-allow-origin: * +access-control-allow-credentials: true + +{ + "args": {}, + "data": "{\"name\":\"req\",\"website\":\"https://github.com/imroc/req\"}", + "files": {}, + "form": {}, + "headers": { + "Accept-Encoding": "gzip", + "Content-Length": "55", + "Content-Type": "application/json; charset=utf-8", + "Host": "httpbin.org", + "User-Agent": "req/v3 (https://github.com/imroc/req/v3)", + "X-Amzn-Trace-Id": "Root=1-628633d4-7559d633152b4307288ead2e" + }, + "json": { + "name": "req", + "website": "https://github.com/imroc/req" + }, + "origin": "103.7.29.30", + "url": "https://httpbin.org/post" +} + +++++++++++++++++++++++++++++++++++++++++++++++++ +data: {"name":"req","url":"https://github.com/imroc/req"} +++++++++++++++++++++++++++++++++++++++++++++++++ +``` + +## Do API Style + +If you like, you can also use a Do API style like the following to make requests: + +```go +package main + +import ( + "fmt" + "github.com/imroc/req/v3" +) + +type APIResponse struct { + Origin string `json:"origin"` + Url string `json:"url"` +} + +func main() { + var resp APIResponse + c := req.C().SetBaseURL("https://httpbin.org/post") + err := c.Post(). + SetBody("hello"). + Do(). + Into(&resp) + if err != nil { + panic(err) + } + fmt.Println("My IP is", resp.Origin) +} +``` + +```txt +My IP is 182.138.155.113 +``` + +* The order of chain calls is more intuitive: first call Client to create a request with a specified Method, then use chain calls to set the request, then use `Do()` to fire the request, return Response, and finally call `Response.Into` to unmarshal response body into specified object. +* `Response.Into` will return an error if an error occurs during sending the request or during unmarshalling. +* The url of some APIs is fixed, and different types of requests are implemented by passing different bodies. In this scenario, `Client.SetBaseURL` can be used to set a unified url, and there is no need to set the url for each request when initiating a request. Of course, you can also call `Request.SetURL` to set it if you need it. + +## Build SDK With Req + +Here is an example of building GitHub's SDK with req, using two styles (`GetUserProfile_Style1`, `GetUserProfile_Style2`). + +```go +import ( + "context" + "fmt" + "github.com/imroc/req/v3" +) + +type ErrorMessage struct { + Message string `json:"message"` +} + +// Error implements go error interface. +func (msg *ErrorMessage) Error() string { + return fmt.Sprintf("API Error: %s", msg.Message) +} + +type GithubClient struct { + *req.Client +} + +func NewGithubClient() *GithubClient { + return &GithubClient{ + Client: req.C(). + SetBaseURL("https://api.github.com"). + SetCommonErrorResult(&ErrorMessage{}). + EnableDumpEachRequest(). + OnAfterResponse(func(client *req.Client, resp *req.Response) error { + if resp.Err != nil { // There is an underlying error, e.g. network error or unmarshal error. + return nil + } + if errMsg, ok := resp.ErrorResult().(*ErrorMessage); ok { + resp.Err = errMsg // Convert api error into go error + return nil + } + if !resp.IsSuccessState() { + // Neither a success response nor a error response, record details to help troubleshooting + resp.Err = fmt.Errorf("bad status: %s\nraw content:\n%s", resp.Status, resp.Dump()) + } + return nil + }), + } +} + +type UserProfile struct { + Name string `json:"name"` + Blog string `json:"blog"` +} + +// GetUserProfile_Style1 returns the user profile for the specified user. +// Github API doc: https://docs.github.com/en/rest/users/users#get-a-user +func (c *GithubClient) GetUserProfile_Style1(ctx context.Context, username string) (user *UserProfile, err error) { + _, err = c.R(). + SetContext(ctx). + SetPathParam("username", username). + SetSuccessResult(&user). + Get("/users/{username}") + return +} + +// GetUserProfile_Style2 returns the user profile for the specified user. +// Github API doc: https://docs.github.com/en/rest/users/users#get-a-user +func (c *GithubClient) GetUserProfile_Style2(ctx context.Context, username string) (user *UserProfile, err error) { + err = c.Get("/users/{username}"). + SetPathParam("username", username). + Do(ctx). + Into(&user) + return +} +``` + ## Contributing If you have a bug report or feature request, you can [open an issue](https://github.com/imroc/req/issues/new), and [pull requests](https://github.com/imroc/req/pulls) are also welcome. From b064bf0e5cb63167ee2876e65501a00c9fa45c68 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 30 Jan 2023 16:16:07 +0800 Subject: [PATCH 669/843] Avoid err been override when response middleware is set --- client.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 5386ed60..359e82c8 100644 --- a/client.go +++ b/client.go @@ -1382,8 +1382,9 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { } for _, f := range r.client.afterResponse { - if err = f(r.client, resp); err != nil { - resp.Err = err + if e := f(r.client, resp); e != nil { + err = e + resp.Err = e return } } From bc789938f75213100a0643ac0b5946f56861bd9e Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 31 Jan 2023 10:51:13 +0800 Subject: [PATCH 670/843] Support infinity retry --- client.go | 1 + request.go | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 359e82c8..d80458c9 100644 --- a/client.go +++ b/client.go @@ -1069,6 +1069,7 @@ func (c *Client) getRetryOption() *retryOption { // SetCommonRetryCount enables retry and set the maximum retry count for requests // fired from the client. +// It will retry infinitely if count is negative. func (c *Client) SetCommonRetryCount(count int) *Client { c.getRetryOption().MaxRetries = count return c diff --git a/request.go b/request.go index 4a48e9ac..f1759a86 100644 --- a/request.go +++ b/request.go @@ -539,7 +539,7 @@ func (r *Request) Do(ctx ...context.Context) *Response { if r.error != nil { return r.newErrorResponse(r.error) } - if r.retryOption != nil && r.retryOption.MaxRetries > 0 && r.unReplayableBody != nil { // retryable request should not have unreplayable Body + if r.retryOption != nil && r.retryOption.MaxRetries != 0 && r.unReplayableBody != nil { // retryable request should not have unreplayable Body return r.newErrorResponse(errRetryableWithUnReplayableBody) } resp, _ := r.do() @@ -578,7 +578,7 @@ func (r *Request) do() (resp *Response, err error) { resp, err = r.client.roundTrip(r) } - if r.retryOption == nil || r.RetryAttempt >= r.retryOption.MaxRetries { // absolutely cannot retry. + if r.retryOption == nil || (r.RetryAttempt >= r.retryOption.MaxRetries && r.retryOption.MaxRetries > 0) { // absolutely cannot retry. return } @@ -1017,6 +1017,7 @@ func (r *Request) getRetryOption() *retryOption { } // SetRetryCount enables retry and set the maximum retry count. +// It will retry infinitely if count is negative. func (r *Request) SetRetryCount(count int) *Request { r.getRetryOption().MaxRetries = count return r From a6c58eb054d7bb30812873942c8891370bef2ef5 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 31 Jan 2023 11:09:40 +0800 Subject: [PATCH 671/843] Support slice and array in SetBody --- request.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/request.go b/request.go index f1759a86..cc145c2b 100644 --- a/request.go +++ b/request.go @@ -757,7 +757,7 @@ func (r *Request) SetBody(body interface{}) *Request { default: t := reflect.TypeOf(body) switch t.Kind() { - case reflect.Ptr, reflect.Struct, reflect.Map: + case reflect.Ptr, reflect.Struct, reflect.Map, reflect.Slice, reflect.Array: r.marshalBody = body default: r.SetBodyString(fmt.Sprint(body)) From 8adac86f930255d26f34140ff0576ef3a828fbd4 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 2 Feb 2023 11:48:44 +0800 Subject: [PATCH 672/843] Move github.com/marten-seemann/qpack to github.com/quic-go/qpack --- go.mod | 9 ++--- go.sum | 54 +++++++++++++++++++------- internal/http3/client.go | 2 +- internal/http3/client_test.go | 2 +- internal/http3/request.go | 2 +- internal/http3/request_test.go | 2 +- internal/http3/request_writer.go | 2 +- internal/http3/request_writer_test.go | 2 +- internal/http3/response_writer.go | 2 +- internal/http3/response_writer_test.go | 2 +- 10 files changed, 52 insertions(+), 27 deletions(-) diff --git a/go.mod b/go.mod index c2240999..65c0b17e 100644 --- a/go.mod +++ b/go.mod @@ -7,21 +7,20 @@ require ( github.com/francoispqt/gojay v1.2.13 github.com/golang/mock v1.6.0 github.com/hashicorp/go-multierror v1.1.1 - github.com/marten-seemann/qpack v0.2.1 github.com/marten-seemann/qtls-go1-16 v0.1.5 github.com/marten-seemann/qtls-go1-17 v0.1.2 github.com/marten-seemann/qtls-go1-18 v0.1.3 github.com/marten-seemann/qtls-go1-19 v0.1.1 github.com/onsi/ginkgo v1.16.5 - github.com/onsi/gomega v1.13.0 + github.com/onsi/gomega v1.20.1 + github.com/quic-go/qpack v0.4.0 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 - golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f - golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a + golang.org/x/net v0.0.0-20220722155237-a158d28d115b + golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f golang.org/x/text v0.3.7 ) require ( github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - golang.org/x/tools v0.1.11 // indirect ) diff --git a/go.sum b/go.sum index 6fda9b95..9e3014db 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,9 @@ github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBT github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -57,10 +60,14 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -72,6 +79,7 @@ github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brv github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= @@ -82,8 +90,6 @@ github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= -github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= @@ -103,14 +109,19 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= -github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= +github.com/onsi/ginkgo/v2 v2.1.4/go.mod h1:um6tUpWM/cxCK3/FK8BXqEiUMUwRgSM4JXG47RKZmLU= +github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI= +github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= -github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= +github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= +github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro= +github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q= +github.com/onsi/gomega v1.20.1/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -119,6 +130,8 @@ github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= +github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= @@ -155,6 +168,7 @@ github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMI github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= @@ -172,6 +186,7 @@ golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTk golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -186,13 +201,15 @@ golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -205,6 +222,7 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -214,9 +232,9 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -225,13 +243,17 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k= +golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= @@ -247,8 +269,9 @@ golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.11 h1:loJ25fNOEhSXfHrpoGj91eCUThwdNX6u24rO1xnNteY= -golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= +golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -276,8 +299,9 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= @@ -290,6 +314,8 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/internal/http3/client.go b/internal/http3/client.go index 2c4ce4c5..b7e96e7e 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -12,7 +12,7 @@ import ( "github.com/imroc/req/v3/internal/quic-go/qtls" "github.com/imroc/req/v3/internal/quic-go/quicvarint" "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/marten-seemann/qpack" + "github.com/quic-go/qpack" "io" "net/http" "strconv" diff --git a/internal/http3/client_test.go b/internal/http3/client_test.go index 213a127f..14b1129f 100644 --- a/internal/http3/client_test.go +++ b/internal/http3/client_test.go @@ -7,7 +7,7 @@ import ( "crypto/tls" "errors" "fmt" - "github.com/marten-seemann/qpack" + "github.com/quic-go/qpack" "io" "io/ioutil" "net/http" diff --git a/internal/http3/request.go b/internal/http3/request.go index dceb96e4..a6fe02cd 100644 --- a/internal/http3/request.go +++ b/internal/http3/request.go @@ -3,7 +3,7 @@ package http3 import ( "crypto/tls" "errors" - "github.com/marten-seemann/qpack" + "github.com/quic-go/qpack" "net/http" "net/url" "strconv" diff --git a/internal/http3/request_test.go b/internal/http3/request_test.go index 6841abed..059fcfe1 100644 --- a/internal/http3/request_test.go +++ b/internal/http3/request_test.go @@ -1,7 +1,7 @@ package http3 import ( - "github.com/marten-seemann/qpack" + "github.com/quic-go/qpack" "net/http" "net/url" diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index bf714210..f77d3a87 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/header" - "github.com/marten-seemann/qpack" + "github.com/quic-go/qpack" "io" "net" "net/http" diff --git a/internal/http3/request_writer_test.go b/internal/http3/request_writer_test.go index d8f1eafe..bddedcc0 100644 --- a/internal/http3/request_writer_test.go +++ b/internal/http3/request_writer_test.go @@ -2,7 +2,7 @@ package http3 import ( "bytes" - "github.com/marten-seemann/qpack" + "github.com/quic-go/qpack" "io" "net/http" diff --git a/internal/http3/response_writer.go b/internal/http3/response_writer.go index 2233c2b0..ab45e379 100644 --- a/internal/http3/response_writer.go +++ b/internal/http3/response_writer.go @@ -3,7 +3,7 @@ package http3 import ( "bufio" "bytes" - "github.com/marten-seemann/qpack" + "github.com/quic-go/qpack" "net/http" "strconv" "strings" diff --git a/internal/http3/response_writer_test.go b/internal/http3/response_writer_test.go index 20a01d9f..b24f2c28 100644 --- a/internal/http3/response_writer_test.go +++ b/internal/http3/response_writer_test.go @@ -2,7 +2,7 @@ package http3 import ( "bytes" - "github.com/marten-seemann/qpack" + "github.com/quic-go/qpack" "io" "net/http" From 6d9b23e7e5a38ec7fa79e0da252f75003439eccc Mon Sep 17 00:00:00 2001 From: rockerchen Date: Mon, 6 Feb 2023 19:38:50 +0800 Subject: [PATCH 673/843] Refactor http3 implementation --- README.md | 2 +- go.mod | 31 +- go.sum | 287 +- internal/http3/body.go | 2 +- internal/http3/body_test.go | 54 - internal/http3/client.go | 45 +- internal/http3/client_test.go | 1022 ------ internal/http3/error_codes.go | 2 +- internal/http3/error_codes_test.go | 39 - internal/http3/frames.go | 5 +- internal/http3/frames_test.go | 245 -- internal/http3/http3_suite_test.go | 38 - internal/http3/http_stream.go | 2 +- internal/http3/http_stream_test.go | 150 - internal/http3/request_test.go | 197 -- internal/http3/request_writer.go | 9 +- internal/http3/request_writer_test.go | 111 - internal/http3/response_writer.go | 118 - internal/http3/response_writer_test.go | 149 - internal/http3/roundtrip.go | 3 +- internal/http3/roundtrip_test.go | 253 -- internal/http3/server.go | 15 +- internal/quic-go/ackhandler/ack_eliciting.go | 20 - .../quic-go/ackhandler/ack_eliciting_test.go | 34 - internal/quic-go/ackhandler/ackhandler.go | 21 - .../ackhandler/ackhandler_suite_test.go | 29 - internal/quic-go/ackhandler/frame.go | 9 - internal/quic-go/ackhandler/gen.go | 3 - internal/quic-go/ackhandler/interfaces.go | 68 - .../mock_sent_packet_tracker_test.go | 61 - internal/quic-go/ackhandler/mockgen.go | 3 - .../quic-go/ackhandler/packet_linkedlist.go | 217 -- .../ackhandler/packet_number_generator.go | 76 - .../packet_number_generator_test.go | 99 - .../ackhandler/received_packet_handler.go | 136 - .../received_packet_handler_test.go | 168 - .../ackhandler/received_packet_history.go | 142 - .../received_packet_history_test.go | 354 -- .../ackhandler/received_packet_tracker.go | 196 -- .../received_packet_tracker_test.go | 348 -- internal/quic-go/ackhandler/send_mode.go | 40 - internal/quic-go/ackhandler/send_mode_test.go | 18 - .../quic-go/ackhandler/sent_packet_handler.go | 838 ----- .../ackhandler/sent_packet_handler_test.go | 1386 -------- .../quic-go/ackhandler/sent_packet_history.go | 108 - .../ackhandler/sent_packet_history_test.go | 263 -- internal/quic-go/buffer_pool.go | 80 - internal/quic-go/buffer_pool_test.go | 55 - internal/quic-go/client.go | 339 -- internal/quic-go/client_test.go | 611 ---- internal/quic-go/closed_conn.go | 112 - internal/quic-go/closed_conn_test.go | 56 - internal/quic-go/config.go | 124 - internal/quic-go/config_test.go | 180 - internal/quic-go/congestion/bandwidth.go | 25 - internal/quic-go/congestion/bandwidth_test.go | 14 - internal/quic-go/congestion/clock.go | 18 - .../congestion/congestion_suite_test.go | 13 - internal/quic-go/congestion/cubic.go | 214 -- internal/quic-go/congestion/cubic_sender.go | 316 -- .../quic-go/congestion/cubic_sender_test.go | 526 --- internal/quic-go/congestion/cubic_test.go | 239 -- .../quic-go/congestion/hybrid_slow_start.go | 113 - .../congestion/hybrid_slow_start_test.go | 72 - internal/quic-go/congestion/interface.go | 28 - internal/quic-go/congestion/pacer.go | 77 - internal/quic-go/congestion/pacer_test.go | 131 - internal/quic-go/conn_id_generator.go | 140 - internal/quic-go/conn_id_generator_test.go | 187 - internal/quic-go/conn_id_manager.go | 207 -- internal/quic-go/conn_id_manager_test.go | 364 -- internal/quic-go/connection.go | 2006 ----------- internal/quic-go/connection_test.go | 3038 ----------------- internal/quic-go/crypto_stream.go | 115 - internal/quic-go/crypto_stream_manager.go | 61 - .../quic-go/crypto_stream_manager_test.go | 119 - internal/quic-go/crypto_stream_test.go | 187 - internal/quic-go/datagram_queue.go | 87 - internal/quic-go/datagram_queue_test.go | 98 - internal/quic-go/errors.go | 58 - .../flowcontrol/base_flow_controller.go | 125 - .../flowcontrol/base_flow_controller_test.go | 236 -- .../flowcontrol/connection_flow_controller.go | 112 - .../connection_flow_controller_test.go | 185 - .../flowcontrol/flowcontrol_suite_test.go | 24 - internal/quic-go/flowcontrol/interface.go | 42 - .../flowcontrol/stream_flow_controller.go | 149 - .../stream_flow_controller_test.go | 272 -- internal/quic-go/frame_sorter.go | 224 -- internal/quic-go/frame_sorter_test.go | 1527 --------- internal/quic-go/framer.go | 171 - internal/quic-go/framer_test.go | 385 --- internal/quic-go/handshake/aead.go | 161 - internal/quic-go/handshake/aead_test.go | 204 -- internal/quic-go/handshake/crypto_setup.go | 819 ----- .../quic-go/handshake/crypto_setup_test.go | 864 ----- .../quic-go/handshake/handshake_suite_test.go | 48 - .../quic-go/handshake/header_protector.go | 136 - internal/quic-go/handshake/hkdf.go | 29 - internal/quic-go/handshake/hkdf_test.go | 17 - internal/quic-go/handshake/initial_aead.go | 81 - .../quic-go/handshake/initial_aead_test.go | 219 -- internal/quic-go/handshake/interface.go | 102 - .../handshake/mock_handshake_runner_test.go | 84 - internal/quic-go/handshake/mockgen.go | 3 - internal/quic-go/handshake/retry.go | 62 - internal/quic-go/handshake/retry_test.go | 36 - internal/quic-go/handshake/session_ticket.go | 48 - .../quic-go/handshake/session_ticket_test.go | 54 - .../handshake/tls_extension_handler.go | 68 - .../handshake/tls_extension_handler_test.go | 210 -- internal/quic-go/handshake/token_generator.go | 134 - .../quic-go/handshake/token_generator_test.go | 127 - internal/quic-go/handshake/token_protector.go | 89 - .../quic-go/handshake/token_protector_test.go | 67 - internal/quic-go/handshake/updatable_aead.go | 323 -- .../quic-go/handshake/updatable_aead_test.go | 528 --- internal/quic-go/interface.go | 328 -- internal/quic-go/logging/frame.go | 66 - internal/quic-go/logging/interface.go | 134 - .../quic-go/logging/logging_suite_test.go | 25 - .../logging/mock_connection_tracer_test.go | 351 -- internal/quic-go/logging/mock_tracer_test.go | 76 - internal/quic-go/logging/mockgen.go | 4 - internal/quic-go/logging/multiplex.go | 219 -- internal/quic-go/logging/multiplex_test.go | 266 -- internal/quic-go/logging/packet_header.go | 27 - .../quic-go/logging/packet_header_test.go | 60 - internal/quic-go/logging/types.go | 94 - internal/quic-go/logutils/frame.go | 33 - internal/quic-go/logutils/frame_test.go | 51 - .../quic-go/logutils/logutils_suite_test.go | 13 - .../quic-go/mock_ack_frame_source_test.go | 50 - internal/quic-go/mock_batch_conn_test.go | 50 - internal/quic-go/mock_conn_runner_test.go | 123 - .../quic-go/mock_crypto_data_handler_test.go | 49 - internal/quic-go/mock_crypto_stream_test.go | 121 - internal/quic-go/mock_frame_source_test.go | 80 - internal/quic-go/mock_mtu_discoverer_test.go | 66 - internal/quic-go/mock_multiplexer_test.go | 65 - internal/quic-go/mock_packer_test.go | 179 - .../mock_packet_handler_manager_test.go | 175 - internal/quic-go/mock_packet_handler_test.go | 85 - internal/quic-go/mock_packetconn_test.go | 137 - internal/quic-go/mock_quic_conn_test.go | 346 -- .../mock_receive_stream_internal_test.go | 146 - internal/quic-go/mock_sealing_manager_test.go | 95 - internal/quic-go/mock_send_conn_test.go | 91 - .../quic-go/mock_send_stream_internal_test.go | 187 - internal/quic-go/mock_sender_test.go | 100 - internal/quic-go/mock_stream_getter_test.go | 65 - internal/quic-go/mock_stream_internal_test.go | 284 -- internal/quic-go/mock_stream_manager_test.go | 231 -- internal/quic-go/mock_stream_sender_test.go | 72 - internal/quic-go/mock_token_store_test.go | 60 - .../mock_unknown_packet_handler_test.go | 58 - internal/quic-go/mock_unpacker_test.go | 51 - internal/quic-go/mockgen.go | 27 - internal/quic-go/mockgen_private.sh | 49 - .../ackhandler/received_packet_handler.go | 105 - .../mocks/ackhandler/sent_packet_handler.go | 240 -- internal/quic-go/mocks/congestion.go | 192 -- .../mocks/connection_flow_controller.go | 128 - internal/quic-go/mocks/crypto_setup.go | 264 -- .../mocks/logging/connection_tracer.go | 352 -- internal/quic-go/mocks/logging/tracer.go | 77 - internal/quic-go/mocks/long_header_opener.go | 76 - internal/quic-go/mocks/mockgen.go | 20 - internal/quic-go/mocks/quic/early_conn.go | 255 -- internal/quic-go/mocks/quic/early_listener.go | 80 - internal/quic-go/mocks/quic/stream.go | 176 - internal/quic-go/mocks/short_header_opener.go | 77 - internal/quic-go/mocks/short_header_sealer.go | 89 - .../quic-go/mocks/stream_flow_controller.go | 140 - .../quic-go/mocks/tls/client_session_cache.go | 62 - internal/quic-go/mtu_discoverer.go | 74 - internal/quic-go/mtu_discoverer_test.go | 112 - internal/quic-go/multiplexer.go | 107 - internal/quic-go/multiplexer_test.go | 70 - internal/quic-go/packet_handler_map.go | 489 --- internal/quic-go/packet_handler_map_test.go | 495 --- internal/quic-go/packet_packer.go | 894 ----- internal/quic-go/packet_packer_test.go | 1556 --------- internal/quic-go/packet_unpacker.go | 196 -- internal/quic-go/packet_unpacker_test.go | 292 -- internal/quic-go/protocol/connection_id.go | 69 - .../quic-go/protocol/connection_id_test.go | 108 - internal/quic-go/protocol/encryption_level.go | 30 - .../quic-go/protocol/encryption_level_test.go | 20 - internal/quic-go/protocol/key_phase.go | 36 - internal/quic-go/protocol/key_phase_test.go | 27 - internal/quic-go/protocol/packet_number.go | 79 - .../quic-go/protocol/packet_number_test.go | 204 -- internal/quic-go/protocol/params.go | 193 -- internal/quic-go/protocol/params_test.go | 13 - internal/quic-go/protocol/perspective.go | 26 - internal/quic-go/protocol/perspective_test.go | 19 - internal/quic-go/protocol/protocol.go | 97 - .../quic-go/protocol/protocol_suite_test.go | 13 - internal/quic-go/protocol/protocol_test.go | 25 - internal/quic-go/protocol/stream.go | 76 - internal/quic-go/protocol/stream_test.go | 70 - internal/quic-go/protocol/version.go | 114 - internal/quic-go/protocol/version_test.go | 121 - internal/quic-go/qerr/error_codes.go | 88 - internal/quic-go/qerr/errorcodes_test.go | 52 - internal/quic-go/qerr/errors.go | 124 - internal/quic-go/qerr/errors_suite_test.go | 13 - internal/quic-go/qerr/errors_test.go | 124 - internal/quic-go/qlog/event.go | 529 --- internal/quic-go/qlog/event_test.go | 43 - internal/quic-go/qlog/frame.go | 227 -- internal/quic-go/qlog/frame_test.go | 377 -- internal/quic-go/qlog/packet_header.go | 119 - internal/quic-go/qlog/packet_header_test.go | 175 - internal/quic-go/qlog/qlog.go | 486 --- internal/quic-go/qlog/qlog_suite_test.go | 51 - internal/quic-go/qlog/qlog_test.go | 849 ----- internal/quic-go/qlog/trace.go | 66 - internal/quic-go/qlog/types.go | 320 -- internal/quic-go/qlog/types_test.go | 130 - internal/quic-go/qtls/go116.go | 100 - internal/quic-go/qtls/go117.go | 100 - internal/quic-go/qtls/go118.go | 100 - internal/quic-go/qtls/go119.go | 100 - internal/quic-go/qtls/go_oldversion.go | 7 - internal/quic-go/qtls/qtls_suite_test.go | 25 - internal/quic-go/qtls/qtls_test.go | 17 - internal/quic-go/quic_suite_test.go | 34 - internal/quic-go/quicvarint/io_test.go | 115 - .../quicvarint/quicvarint_suite_test.go | 13 - internal/quic-go/quicvarint/varint.go | 38 +- internal/quic-go/quicvarint/varint_test.go | 221 -- internal/quic-go/receive_stream.go | 331 -- internal/quic-go/receive_stream_test.go | 696 ---- internal/quic-go/retransmission_queue.go | 131 - internal/quic-go/retransmission_queue_test.go | 187 - internal/quic-go/send_conn.go | 74 - internal/quic-go/send_conn_test.go | 45 - internal/quic-go/send_queue.go | 88 - internal/quic-go/send_queue_test.go | 126 - internal/quic-go/send_stream.go | 496 --- internal/quic-go/send_stream_test.go | 1159 ------- internal/quic-go/server.go | 670 ---- internal/quic-go/server_test.go | 1237 ------- internal/quic-go/stream.go | 149 - internal/quic-go/stream_test.go | 106 - internal/quic-go/streams_map.go | 317 -- .../quic-go/streams_map_generic_helper.go | 18 - internal/quic-go/streams_map_incoming_bidi.go | 192 -- .../quic-go/streams_map_incoming_generic.go | 190 -- .../streams_map_incoming_generic_test.go | 307 -- internal/quic-go/streams_map_incoming_uni.go | 192 -- internal/quic-go/streams_map_outgoing_bidi.go | 226 -- .../quic-go/streams_map_outgoing_generic.go | 224 -- .../streams_map_outgoing_generic_test.go | 539 --- internal/quic-go/streams_map_outgoing_uni.go | 226 -- internal/quic-go/streams_map_test.go | 499 --- internal/quic-go/sys_conn.go | 80 - internal/quic-go/sys_conn_df.go | 16 - internal/quic-go/sys_conn_df_linux.go | 40 - internal/quic-go/sys_conn_df_windows.go | 46 - internal/quic-go/sys_conn_helper_darwin.go | 22 - internal/quic-go/sys_conn_helper_freebsd.go | 22 - internal/quic-go/sys_conn_helper_linux.go | 20 - internal/quic-go/sys_conn_no_oob.go | 16 - internal/quic-go/sys_conn_oob.go | 257 -- internal/quic-go/sys_conn_oob_test.go | 243 -- internal/quic-go/sys_conn_test.go | 33 - internal/quic-go/sys_conn_windows.go | 40 - internal/quic-go/sys_conn_windows_test.go | 33 - internal/quic-go/testdata/ca.pem | 17 - internal/quic-go/testdata/cert.go | 55 - internal/quic-go/testdata/cert.pem | 18 - internal/quic-go/testdata/cert_test.go | 31 - internal/quic-go/testdata/generate_key.sh | 24 - internal/quic-go/testdata/priv.key | 28 - .../quic-go/testdata/testdata_suite_test.go | 13 - internal/quic-go/testutils/testutils.go | 97 - internal/quic-go/token_store.go | 117 - internal/quic-go/token_store_test.go | 108 - internal/quic-go/tools.go | 9 - internal/quic-go/utils/atomic_bool.go | 22 - internal/quic-go/utils/atomic_bool_test.go | 29 - .../quic-go/utils/buffered_write_closer.go | 26 - .../utils/buffered_write_closer_test.go | 26 - .../quic-go/utils/byteinterval_linkedlist.go | 217 -- .../quic-go/utils/byteoder_big_endian_test.go | 107 - internal/quic-go/utils/byteorder.go | 17 - .../quic-go/utils/byteorder_big_endian.go | 89 - internal/quic-go/utils/gen.go | 5 - internal/quic-go/utils/ip.go | 10 - internal/quic-go/utils/ip_test.go | 17 - internal/quic-go/utils/linkedlist/README.md | 11 - .../quic-go/utils/linkedlist/linkedlist.go | 218 -- internal/quic-go/utils/log.go | 131 - internal/quic-go/utils/log_test.go | 144 - internal/quic-go/utils/minmax.go | 170 - internal/quic-go/utils/minmax_test.go | 123 - internal/quic-go/utils/new_connection_id.go | 12 - .../utils/newconnectionid_linkedlist.go | 217 -- internal/quic-go/utils/packet_interval.go | 9 - .../utils/packetinterval_linkedlist.go | 217 -- internal/quic-go/utils/rand.go | 29 - internal/quic-go/utils/rand_test.go | 32 - internal/quic-go/utils/rtt_stats.go | 127 - internal/quic-go/utils/rtt_stats_test.go | 157 - .../quic-go/utils/streamframe_interval.go | 9 - internal/quic-go/utils/timer.go | 53 - internal/quic-go/utils/timer_test.go | 87 - internal/quic-go/utils/utils_suite_test.go | 13 - internal/quic-go/window_update_queue.go | 71 - internal/quic-go/window_update_queue_test.go | 112 - internal/quic-go/wire/ack_frame.go | 251 -- internal/quic-go/wire/ack_frame_test.go | 454 --- internal/quic-go/wire/ack_range.go | 14 - internal/quic-go/wire/ack_range_test.go | 13 - .../quic-go/wire/connection_close_frame.go | 83 - .../wire/connection_close_frame_test.go | 153 - internal/quic-go/wire/crypto_frame.go | 102 - internal/quic-go/wire/crypto_frame_test.go | 148 - internal/quic-go/wire/data_blocked_frame.go | 38 - .../quic-go/wire/data_blocked_frame_test.go | 54 - internal/quic-go/wire/datagram_frame.go | 85 - internal/quic-go/wire/datagram_frame_test.go | 154 - internal/quic-go/wire/extended_header.go | 249 -- internal/quic-go/wire/extended_header_test.go | 481 --- internal/quic-go/wire/frame_parser.go | 143 - internal/quic-go/wire/frame_parser_test.go | 410 --- internal/quic-go/wire/handshake_done_frame.go | 28 - internal/quic-go/wire/header.go | 274 -- internal/quic-go/wire/header_test.go | 583 ---- internal/quic-go/wire/interface.go | 19 - internal/quic-go/wire/log.go | 72 - internal/quic-go/wire/log_test.go | 168 - internal/quic-go/wire/max_data_frame.go | 40 - internal/quic-go/wire/max_data_frame_test.go | 57 - .../quic-go/wire/max_stream_data_frame.go | 46 - .../wire/max_stream_data_frame_test.go | 63 - internal/quic-go/wire/max_streams_frame.go | 55 - .../quic-go/wire/max_streams_frame_test.go | 107 - .../quic-go/wire/new_connection_id_frame.go | 80 - .../wire/new_connection_id_frame_test.go | 104 - internal/quic-go/wire/new_token_frame.go | 48 - internal/quic-go/wire/new_token_frame_test.go | 66 - internal/quic-go/wire/path_challenge_frame.go | 38 - .../quic-go/wire/path_challenge_frame_test.go | 48 - internal/quic-go/wire/path_response_frame.go | 38 - .../quic-go/wire/path_response_frame_test.go | 47 - internal/quic-go/wire/ping_frame.go | 27 - internal/quic-go/wire/ping_frame_test.go | 39 - internal/quic-go/wire/pool.go | 33 - internal/quic-go/wire/pool_test.go | 24 - internal/quic-go/wire/reset_stream_frame.go | 58 - .../quic-go/wire/reset_stream_frame_test.go | 70 - .../wire/retire_connection_id_frame.go | 36 - .../wire/retire_connection_id_frame_test.go | 53 - internal/quic-go/wire/stop_sending_frame.go | 48 - .../quic-go/wire/stop_sending_frame_test.go | 63 - .../quic-go/wire/stream_data_blocked_frame.go | 46 - .../wire/stream_data_blocked_frame_test.go | 63 - internal/quic-go/wire/stream_frame.go | 189 - internal/quic-go/wire/stream_frame_test.go | 443 --- .../quic-go/wire/streams_blocked_frame.go | 55 - .../wire/streams_blocked_frame_test.go | 108 - .../quic-go/wire/transport_parameter_test.go | 612 ---- internal/quic-go/wire/transport_parameters.go | 476 --- internal/quic-go/wire/version_negotiation.go | 54 - .../quic-go/wire/version_negotiation_test.go | 83 - internal/quic-go/wire/wire_suite_test.go | 31 - transport.go | 2 +- 371 files changed, 84 insertions(+), 62440 deletions(-) delete mode 100644 internal/http3/body_test.go delete mode 100644 internal/http3/client_test.go delete mode 100644 internal/http3/error_codes_test.go delete mode 100644 internal/http3/frames_test.go delete mode 100644 internal/http3/http3_suite_test.go delete mode 100644 internal/http3/http_stream_test.go delete mode 100644 internal/http3/request_test.go delete mode 100644 internal/http3/request_writer_test.go delete mode 100644 internal/http3/response_writer.go delete mode 100644 internal/http3/response_writer_test.go delete mode 100644 internal/http3/roundtrip_test.go delete mode 100644 internal/quic-go/ackhandler/ack_eliciting.go delete mode 100644 internal/quic-go/ackhandler/ack_eliciting_test.go delete mode 100644 internal/quic-go/ackhandler/ackhandler.go delete mode 100644 internal/quic-go/ackhandler/ackhandler_suite_test.go delete mode 100644 internal/quic-go/ackhandler/frame.go delete mode 100644 internal/quic-go/ackhandler/gen.go delete mode 100644 internal/quic-go/ackhandler/interfaces.go delete mode 100644 internal/quic-go/ackhandler/mock_sent_packet_tracker_test.go delete mode 100644 internal/quic-go/ackhandler/mockgen.go delete mode 100644 internal/quic-go/ackhandler/packet_linkedlist.go delete mode 100644 internal/quic-go/ackhandler/packet_number_generator.go delete mode 100644 internal/quic-go/ackhandler/packet_number_generator_test.go delete mode 100644 internal/quic-go/ackhandler/received_packet_handler.go delete mode 100644 internal/quic-go/ackhandler/received_packet_handler_test.go delete mode 100644 internal/quic-go/ackhandler/received_packet_history.go delete mode 100644 internal/quic-go/ackhandler/received_packet_history_test.go delete mode 100644 internal/quic-go/ackhandler/received_packet_tracker.go delete mode 100644 internal/quic-go/ackhandler/received_packet_tracker_test.go delete mode 100644 internal/quic-go/ackhandler/send_mode.go delete mode 100644 internal/quic-go/ackhandler/send_mode_test.go delete mode 100644 internal/quic-go/ackhandler/sent_packet_handler.go delete mode 100644 internal/quic-go/ackhandler/sent_packet_handler_test.go delete mode 100644 internal/quic-go/ackhandler/sent_packet_history.go delete mode 100644 internal/quic-go/ackhandler/sent_packet_history_test.go delete mode 100644 internal/quic-go/buffer_pool.go delete mode 100644 internal/quic-go/buffer_pool_test.go delete mode 100644 internal/quic-go/client.go delete mode 100644 internal/quic-go/client_test.go delete mode 100644 internal/quic-go/closed_conn.go delete mode 100644 internal/quic-go/closed_conn_test.go delete mode 100644 internal/quic-go/config.go delete mode 100644 internal/quic-go/config_test.go delete mode 100644 internal/quic-go/congestion/bandwidth.go delete mode 100644 internal/quic-go/congestion/bandwidth_test.go delete mode 100644 internal/quic-go/congestion/clock.go delete mode 100644 internal/quic-go/congestion/congestion_suite_test.go delete mode 100644 internal/quic-go/congestion/cubic.go delete mode 100644 internal/quic-go/congestion/cubic_sender.go delete mode 100644 internal/quic-go/congestion/cubic_sender_test.go delete mode 100644 internal/quic-go/congestion/cubic_test.go delete mode 100644 internal/quic-go/congestion/hybrid_slow_start.go delete mode 100644 internal/quic-go/congestion/hybrid_slow_start_test.go delete mode 100644 internal/quic-go/congestion/interface.go delete mode 100644 internal/quic-go/congestion/pacer.go delete mode 100644 internal/quic-go/congestion/pacer_test.go delete mode 100644 internal/quic-go/conn_id_generator.go delete mode 100644 internal/quic-go/conn_id_generator_test.go delete mode 100644 internal/quic-go/conn_id_manager.go delete mode 100644 internal/quic-go/conn_id_manager_test.go delete mode 100644 internal/quic-go/connection.go delete mode 100644 internal/quic-go/connection_test.go delete mode 100644 internal/quic-go/crypto_stream.go delete mode 100644 internal/quic-go/crypto_stream_manager.go delete mode 100644 internal/quic-go/crypto_stream_manager_test.go delete mode 100644 internal/quic-go/crypto_stream_test.go delete mode 100644 internal/quic-go/datagram_queue.go delete mode 100644 internal/quic-go/datagram_queue_test.go delete mode 100644 internal/quic-go/errors.go delete mode 100644 internal/quic-go/flowcontrol/base_flow_controller.go delete mode 100644 internal/quic-go/flowcontrol/base_flow_controller_test.go delete mode 100644 internal/quic-go/flowcontrol/connection_flow_controller.go delete mode 100644 internal/quic-go/flowcontrol/connection_flow_controller_test.go delete mode 100644 internal/quic-go/flowcontrol/flowcontrol_suite_test.go delete mode 100644 internal/quic-go/flowcontrol/interface.go delete mode 100644 internal/quic-go/flowcontrol/stream_flow_controller.go delete mode 100644 internal/quic-go/flowcontrol/stream_flow_controller_test.go delete mode 100644 internal/quic-go/frame_sorter.go delete mode 100644 internal/quic-go/frame_sorter_test.go delete mode 100644 internal/quic-go/framer.go delete mode 100644 internal/quic-go/framer_test.go delete mode 100644 internal/quic-go/handshake/aead.go delete mode 100644 internal/quic-go/handshake/aead_test.go delete mode 100644 internal/quic-go/handshake/crypto_setup.go delete mode 100644 internal/quic-go/handshake/crypto_setup_test.go delete mode 100644 internal/quic-go/handshake/handshake_suite_test.go delete mode 100644 internal/quic-go/handshake/header_protector.go delete mode 100644 internal/quic-go/handshake/hkdf.go delete mode 100644 internal/quic-go/handshake/hkdf_test.go delete mode 100644 internal/quic-go/handshake/initial_aead.go delete mode 100644 internal/quic-go/handshake/initial_aead_test.go delete mode 100644 internal/quic-go/handshake/interface.go delete mode 100644 internal/quic-go/handshake/mock_handshake_runner_test.go delete mode 100644 internal/quic-go/handshake/mockgen.go delete mode 100644 internal/quic-go/handshake/retry.go delete mode 100644 internal/quic-go/handshake/retry_test.go delete mode 100644 internal/quic-go/handshake/session_ticket.go delete mode 100644 internal/quic-go/handshake/session_ticket_test.go delete mode 100644 internal/quic-go/handshake/tls_extension_handler.go delete mode 100644 internal/quic-go/handshake/tls_extension_handler_test.go delete mode 100644 internal/quic-go/handshake/token_generator.go delete mode 100644 internal/quic-go/handshake/token_generator_test.go delete mode 100644 internal/quic-go/handshake/token_protector.go delete mode 100644 internal/quic-go/handshake/token_protector_test.go delete mode 100644 internal/quic-go/handshake/updatable_aead.go delete mode 100644 internal/quic-go/handshake/updatable_aead_test.go delete mode 100644 internal/quic-go/interface.go delete mode 100644 internal/quic-go/logging/frame.go delete mode 100644 internal/quic-go/logging/interface.go delete mode 100644 internal/quic-go/logging/logging_suite_test.go delete mode 100644 internal/quic-go/logging/mock_connection_tracer_test.go delete mode 100644 internal/quic-go/logging/mock_tracer_test.go delete mode 100644 internal/quic-go/logging/mockgen.go delete mode 100644 internal/quic-go/logging/multiplex.go delete mode 100644 internal/quic-go/logging/multiplex_test.go delete mode 100644 internal/quic-go/logging/packet_header.go delete mode 100644 internal/quic-go/logging/packet_header_test.go delete mode 100644 internal/quic-go/logging/types.go delete mode 100644 internal/quic-go/logutils/frame.go delete mode 100644 internal/quic-go/logutils/frame_test.go delete mode 100644 internal/quic-go/logutils/logutils_suite_test.go delete mode 100644 internal/quic-go/mock_ack_frame_source_test.go delete mode 100644 internal/quic-go/mock_batch_conn_test.go delete mode 100644 internal/quic-go/mock_conn_runner_test.go delete mode 100644 internal/quic-go/mock_crypto_data_handler_test.go delete mode 100644 internal/quic-go/mock_crypto_stream_test.go delete mode 100644 internal/quic-go/mock_frame_source_test.go delete mode 100644 internal/quic-go/mock_mtu_discoverer_test.go delete mode 100644 internal/quic-go/mock_multiplexer_test.go delete mode 100644 internal/quic-go/mock_packer_test.go delete mode 100644 internal/quic-go/mock_packet_handler_manager_test.go delete mode 100644 internal/quic-go/mock_packet_handler_test.go delete mode 100644 internal/quic-go/mock_packetconn_test.go delete mode 100644 internal/quic-go/mock_quic_conn_test.go delete mode 100644 internal/quic-go/mock_receive_stream_internal_test.go delete mode 100644 internal/quic-go/mock_sealing_manager_test.go delete mode 100644 internal/quic-go/mock_send_conn_test.go delete mode 100644 internal/quic-go/mock_send_stream_internal_test.go delete mode 100644 internal/quic-go/mock_sender_test.go delete mode 100644 internal/quic-go/mock_stream_getter_test.go delete mode 100644 internal/quic-go/mock_stream_internal_test.go delete mode 100644 internal/quic-go/mock_stream_manager_test.go delete mode 100644 internal/quic-go/mock_stream_sender_test.go delete mode 100644 internal/quic-go/mock_token_store_test.go delete mode 100644 internal/quic-go/mock_unknown_packet_handler_test.go delete mode 100644 internal/quic-go/mock_unpacker_test.go delete mode 100644 internal/quic-go/mockgen.go delete mode 100755 internal/quic-go/mockgen_private.sh delete mode 100644 internal/quic-go/mocks/ackhandler/received_packet_handler.go delete mode 100644 internal/quic-go/mocks/ackhandler/sent_packet_handler.go delete mode 100644 internal/quic-go/mocks/congestion.go delete mode 100644 internal/quic-go/mocks/connection_flow_controller.go delete mode 100644 internal/quic-go/mocks/crypto_setup.go delete mode 100644 internal/quic-go/mocks/logging/connection_tracer.go delete mode 100644 internal/quic-go/mocks/logging/tracer.go delete mode 100644 internal/quic-go/mocks/long_header_opener.go delete mode 100644 internal/quic-go/mocks/mockgen.go delete mode 100644 internal/quic-go/mocks/quic/early_conn.go delete mode 100644 internal/quic-go/mocks/quic/early_listener.go delete mode 100644 internal/quic-go/mocks/quic/stream.go delete mode 100644 internal/quic-go/mocks/short_header_opener.go delete mode 100644 internal/quic-go/mocks/short_header_sealer.go delete mode 100644 internal/quic-go/mocks/stream_flow_controller.go delete mode 100644 internal/quic-go/mocks/tls/client_session_cache.go delete mode 100644 internal/quic-go/mtu_discoverer.go delete mode 100644 internal/quic-go/mtu_discoverer_test.go delete mode 100644 internal/quic-go/multiplexer.go delete mode 100644 internal/quic-go/multiplexer_test.go delete mode 100644 internal/quic-go/packet_handler_map.go delete mode 100644 internal/quic-go/packet_handler_map_test.go delete mode 100644 internal/quic-go/packet_packer.go delete mode 100644 internal/quic-go/packet_packer_test.go delete mode 100644 internal/quic-go/packet_unpacker.go delete mode 100644 internal/quic-go/packet_unpacker_test.go delete mode 100644 internal/quic-go/protocol/connection_id.go delete mode 100644 internal/quic-go/protocol/connection_id_test.go delete mode 100644 internal/quic-go/protocol/encryption_level.go delete mode 100644 internal/quic-go/protocol/encryption_level_test.go delete mode 100644 internal/quic-go/protocol/key_phase.go delete mode 100644 internal/quic-go/protocol/key_phase_test.go delete mode 100644 internal/quic-go/protocol/packet_number.go delete mode 100644 internal/quic-go/protocol/packet_number_test.go delete mode 100644 internal/quic-go/protocol/params.go delete mode 100644 internal/quic-go/protocol/params_test.go delete mode 100644 internal/quic-go/protocol/perspective.go delete mode 100644 internal/quic-go/protocol/perspective_test.go delete mode 100644 internal/quic-go/protocol/protocol.go delete mode 100644 internal/quic-go/protocol/protocol_suite_test.go delete mode 100644 internal/quic-go/protocol/protocol_test.go delete mode 100644 internal/quic-go/protocol/stream.go delete mode 100644 internal/quic-go/protocol/stream_test.go delete mode 100644 internal/quic-go/protocol/version.go delete mode 100644 internal/quic-go/protocol/version_test.go delete mode 100644 internal/quic-go/qerr/error_codes.go delete mode 100644 internal/quic-go/qerr/errorcodes_test.go delete mode 100644 internal/quic-go/qerr/errors.go delete mode 100644 internal/quic-go/qerr/errors_suite_test.go delete mode 100644 internal/quic-go/qerr/errors_test.go delete mode 100644 internal/quic-go/qlog/event.go delete mode 100644 internal/quic-go/qlog/event_test.go delete mode 100644 internal/quic-go/qlog/frame.go delete mode 100644 internal/quic-go/qlog/frame_test.go delete mode 100644 internal/quic-go/qlog/packet_header.go delete mode 100644 internal/quic-go/qlog/packet_header_test.go delete mode 100644 internal/quic-go/qlog/qlog.go delete mode 100644 internal/quic-go/qlog/qlog_suite_test.go delete mode 100644 internal/quic-go/qlog/qlog_test.go delete mode 100644 internal/quic-go/qlog/trace.go delete mode 100644 internal/quic-go/qlog/types.go delete mode 100644 internal/quic-go/qlog/types_test.go delete mode 100644 internal/quic-go/qtls/go116.go delete mode 100644 internal/quic-go/qtls/go117.go delete mode 100644 internal/quic-go/qtls/go118.go delete mode 100644 internal/quic-go/qtls/go119.go delete mode 100644 internal/quic-go/qtls/go_oldversion.go delete mode 100644 internal/quic-go/qtls/qtls_suite_test.go delete mode 100644 internal/quic-go/qtls/qtls_test.go delete mode 100644 internal/quic-go/quic_suite_test.go delete mode 100644 internal/quic-go/quicvarint/io_test.go delete mode 100644 internal/quic-go/quicvarint/quicvarint_suite_test.go delete mode 100644 internal/quic-go/quicvarint/varint_test.go delete mode 100644 internal/quic-go/receive_stream.go delete mode 100644 internal/quic-go/receive_stream_test.go delete mode 100644 internal/quic-go/retransmission_queue.go delete mode 100644 internal/quic-go/retransmission_queue_test.go delete mode 100644 internal/quic-go/send_conn.go delete mode 100644 internal/quic-go/send_conn_test.go delete mode 100644 internal/quic-go/send_queue.go delete mode 100644 internal/quic-go/send_queue_test.go delete mode 100644 internal/quic-go/send_stream.go delete mode 100644 internal/quic-go/send_stream_test.go delete mode 100644 internal/quic-go/server.go delete mode 100644 internal/quic-go/server_test.go delete mode 100644 internal/quic-go/stream.go delete mode 100644 internal/quic-go/stream_test.go delete mode 100644 internal/quic-go/streams_map.go delete mode 100644 internal/quic-go/streams_map_generic_helper.go delete mode 100644 internal/quic-go/streams_map_incoming_bidi.go delete mode 100644 internal/quic-go/streams_map_incoming_generic.go delete mode 100644 internal/quic-go/streams_map_incoming_generic_test.go delete mode 100644 internal/quic-go/streams_map_incoming_uni.go delete mode 100644 internal/quic-go/streams_map_outgoing_bidi.go delete mode 100644 internal/quic-go/streams_map_outgoing_generic.go delete mode 100644 internal/quic-go/streams_map_outgoing_generic_test.go delete mode 100644 internal/quic-go/streams_map_outgoing_uni.go delete mode 100644 internal/quic-go/streams_map_test.go delete mode 100644 internal/quic-go/sys_conn.go delete mode 100644 internal/quic-go/sys_conn_df.go delete mode 100644 internal/quic-go/sys_conn_df_linux.go delete mode 100644 internal/quic-go/sys_conn_df_windows.go delete mode 100644 internal/quic-go/sys_conn_helper_darwin.go delete mode 100644 internal/quic-go/sys_conn_helper_freebsd.go delete mode 100644 internal/quic-go/sys_conn_helper_linux.go delete mode 100644 internal/quic-go/sys_conn_no_oob.go delete mode 100644 internal/quic-go/sys_conn_oob.go delete mode 100644 internal/quic-go/sys_conn_oob_test.go delete mode 100644 internal/quic-go/sys_conn_test.go delete mode 100644 internal/quic-go/sys_conn_windows.go delete mode 100644 internal/quic-go/sys_conn_windows_test.go delete mode 100644 internal/quic-go/testdata/ca.pem delete mode 100644 internal/quic-go/testdata/cert.go delete mode 100644 internal/quic-go/testdata/cert.pem delete mode 100644 internal/quic-go/testdata/cert_test.go delete mode 100755 internal/quic-go/testdata/generate_key.sh delete mode 100644 internal/quic-go/testdata/priv.key delete mode 100644 internal/quic-go/testdata/testdata_suite_test.go delete mode 100644 internal/quic-go/testutils/testutils.go delete mode 100644 internal/quic-go/token_store.go delete mode 100644 internal/quic-go/token_store_test.go delete mode 100644 internal/quic-go/tools.go delete mode 100644 internal/quic-go/utils/atomic_bool.go delete mode 100644 internal/quic-go/utils/atomic_bool_test.go delete mode 100644 internal/quic-go/utils/buffered_write_closer.go delete mode 100644 internal/quic-go/utils/buffered_write_closer_test.go delete mode 100644 internal/quic-go/utils/byteinterval_linkedlist.go delete mode 100644 internal/quic-go/utils/byteoder_big_endian_test.go delete mode 100644 internal/quic-go/utils/byteorder.go delete mode 100644 internal/quic-go/utils/byteorder_big_endian.go delete mode 100644 internal/quic-go/utils/gen.go delete mode 100644 internal/quic-go/utils/ip.go delete mode 100644 internal/quic-go/utils/ip_test.go delete mode 100644 internal/quic-go/utils/linkedlist/README.md delete mode 100644 internal/quic-go/utils/linkedlist/linkedlist.go delete mode 100644 internal/quic-go/utils/log.go delete mode 100644 internal/quic-go/utils/log_test.go delete mode 100644 internal/quic-go/utils/minmax.go delete mode 100644 internal/quic-go/utils/minmax_test.go delete mode 100644 internal/quic-go/utils/new_connection_id.go delete mode 100644 internal/quic-go/utils/newconnectionid_linkedlist.go delete mode 100644 internal/quic-go/utils/packet_interval.go delete mode 100644 internal/quic-go/utils/packetinterval_linkedlist.go delete mode 100644 internal/quic-go/utils/rand.go delete mode 100644 internal/quic-go/utils/rand_test.go delete mode 100644 internal/quic-go/utils/rtt_stats.go delete mode 100644 internal/quic-go/utils/rtt_stats_test.go delete mode 100644 internal/quic-go/utils/streamframe_interval.go delete mode 100644 internal/quic-go/utils/timer.go delete mode 100644 internal/quic-go/utils/timer_test.go delete mode 100644 internal/quic-go/utils/utils_suite_test.go delete mode 100644 internal/quic-go/window_update_queue.go delete mode 100644 internal/quic-go/window_update_queue_test.go delete mode 100644 internal/quic-go/wire/ack_frame.go delete mode 100644 internal/quic-go/wire/ack_frame_test.go delete mode 100644 internal/quic-go/wire/ack_range.go delete mode 100644 internal/quic-go/wire/ack_range_test.go delete mode 100644 internal/quic-go/wire/connection_close_frame.go delete mode 100644 internal/quic-go/wire/connection_close_frame_test.go delete mode 100644 internal/quic-go/wire/crypto_frame.go delete mode 100644 internal/quic-go/wire/crypto_frame_test.go delete mode 100644 internal/quic-go/wire/data_blocked_frame.go delete mode 100644 internal/quic-go/wire/data_blocked_frame_test.go delete mode 100644 internal/quic-go/wire/datagram_frame.go delete mode 100644 internal/quic-go/wire/datagram_frame_test.go delete mode 100644 internal/quic-go/wire/extended_header.go delete mode 100644 internal/quic-go/wire/extended_header_test.go delete mode 100644 internal/quic-go/wire/frame_parser.go delete mode 100644 internal/quic-go/wire/frame_parser_test.go delete mode 100644 internal/quic-go/wire/handshake_done_frame.go delete mode 100644 internal/quic-go/wire/header.go delete mode 100644 internal/quic-go/wire/header_test.go delete mode 100644 internal/quic-go/wire/interface.go delete mode 100644 internal/quic-go/wire/log.go delete mode 100644 internal/quic-go/wire/log_test.go delete mode 100644 internal/quic-go/wire/max_data_frame.go delete mode 100644 internal/quic-go/wire/max_data_frame_test.go delete mode 100644 internal/quic-go/wire/max_stream_data_frame.go delete mode 100644 internal/quic-go/wire/max_stream_data_frame_test.go delete mode 100644 internal/quic-go/wire/max_streams_frame.go delete mode 100644 internal/quic-go/wire/max_streams_frame_test.go delete mode 100644 internal/quic-go/wire/new_connection_id_frame.go delete mode 100644 internal/quic-go/wire/new_connection_id_frame_test.go delete mode 100644 internal/quic-go/wire/new_token_frame.go delete mode 100644 internal/quic-go/wire/new_token_frame_test.go delete mode 100644 internal/quic-go/wire/path_challenge_frame.go delete mode 100644 internal/quic-go/wire/path_challenge_frame_test.go delete mode 100644 internal/quic-go/wire/path_response_frame.go delete mode 100644 internal/quic-go/wire/path_response_frame_test.go delete mode 100644 internal/quic-go/wire/ping_frame.go delete mode 100644 internal/quic-go/wire/ping_frame_test.go delete mode 100644 internal/quic-go/wire/pool.go delete mode 100644 internal/quic-go/wire/pool_test.go delete mode 100644 internal/quic-go/wire/reset_stream_frame.go delete mode 100644 internal/quic-go/wire/reset_stream_frame_test.go delete mode 100644 internal/quic-go/wire/retire_connection_id_frame.go delete mode 100644 internal/quic-go/wire/retire_connection_id_frame_test.go delete mode 100644 internal/quic-go/wire/stop_sending_frame.go delete mode 100644 internal/quic-go/wire/stop_sending_frame_test.go delete mode 100644 internal/quic-go/wire/stream_data_blocked_frame.go delete mode 100644 internal/quic-go/wire/stream_data_blocked_frame_test.go delete mode 100644 internal/quic-go/wire/stream_frame.go delete mode 100644 internal/quic-go/wire/stream_frame_test.go delete mode 100644 internal/quic-go/wire/streams_blocked_frame.go delete mode 100644 internal/quic-go/wire/streams_blocked_frame_test.go delete mode 100644 internal/quic-go/wire/transport_parameter_test.go delete mode 100644 internal/quic-go/wire/transport_parameters.go delete mode 100644 internal/quic-go/wire/version_negotiation.go delete mode 100644 internal/quic-go/wire/version_negotiation_test.go delete mode 100644 internal/quic-go/wire/wire_suite_test.go diff --git a/README.md b/README.md index 4ba1e787..5f7d61f7 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ Full documentation is available on the official website: https://req.cool. **Install** -You first need [Go](https://go.dev/) installed (version 1.16+ is required), then you can use the below Go command to install req: +You first need [Go](https://go.dev/) installed (version 1.18+ is required), then you can use the below Go command to install req: ``` sh go get github.com/imroc/req/v3 diff --git a/go.mod b/go.mod index 65c0b17e..f364a57a 100644 --- a/go.mod +++ b/go.mod @@ -1,26 +1,27 @@ module github.com/imroc/req/v3 -go 1.16 +go 1.18 require ( - github.com/cheekybits/genny v1.0.0 - github.com/francoispqt/gojay v1.2.13 - github.com/golang/mock v1.6.0 github.com/hashicorp/go-multierror v1.1.1 - github.com/marten-seemann/qtls-go1-16 v0.1.5 - github.com/marten-seemann/qtls-go1-17 v0.1.2 - github.com/marten-seemann/qtls-go1-18 v0.1.3 - github.com/marten-seemann/qtls-go1-19 v0.1.1 - github.com/onsi/ginkgo v1.16.5 - github.com/onsi/gomega v1.20.1 github.com/quic-go/qpack v0.4.0 - golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 - golang.org/x/net v0.0.0-20220722155237-a158d28d115b - golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f - golang.org/x/text v0.3.7 + github.com/quic-go/quic-go v0.32.0 + golang.org/x/net v0.4.0 + golang.org/x/text v0.5.0 ) require ( - github.com/fsnotify/fsnotify v1.5.4 // indirect + github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/golang/mock v1.6.0 // indirect + github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/onsi/ginkgo/v2 v2.2.0 // indirect + github.com/quic-go/qtls-go1-18 v0.2.0 // indirect + github.com/quic-go/qtls-go1-19 v0.2.0 // indirect + github.com/quic-go/qtls-go1-20 v0.1.0 // indirect + golang.org/x/crypto v0.4.0 // indirect + golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect + golang.org/x/mod v0.6.0 // indirect + golang.org/x/sys v0.3.0 // indirect + golang.org/x/tools v0.2.0 // indirect ) diff --git a/go.sum b/go.sum index 9e3014db..a9b87a39 100644 --- a/go.sum +++ b/go.sum @@ -1,324 +1,81 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= -dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= -dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= -dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= -dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= -git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= -github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= -github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= -github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= -github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk= -github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= -github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= -github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= -github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= -github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= -github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= -github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= -github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= -github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= -github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= -github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= -github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= -github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= -github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= -github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= -github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sNlqWoDZnjRIE= -github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= -github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= -github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= -github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= -github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= -github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= -github.com/onsi/ginkgo/v2 v2.1.4/go.mod h1:um6tUpWM/cxCK3/FK8BXqEiUMUwRgSM4JXG47RKZmLU= github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI= github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= -github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= -github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro= github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q= -github.com/onsi/gomega v1.20.1/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo= -github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= -github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= -github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= -github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= -github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= -github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= -github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= -github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= -github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= -github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= -github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= -github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= -github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= -github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= -github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= -github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= -github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= -github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= -github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= -github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= -github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= -github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= -github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= -github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= +github.com/quic-go/qtls-go1-18 v0.2.0 h1:5ViXqBZ90wpUcZS0ge79rf029yx0dYB0McyPJwqqj7U= +github.com/quic-go/qtls-go1-18 v0.2.0/go.mod h1:moGulGHK7o6O8lSPSZNoOwcLvJKJ85vVNc7oJFD65bc= +github.com/quic-go/qtls-go1-19 v0.2.0 h1:Cvn2WdhyViFUHoOqK52i51k4nDX8EwIh5VJiVM4nttk= +github.com/quic-go/qtls-go1-19 v0.2.0/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= +github.com/quic-go/qtls-go1-20 v0.1.0 h1:d1PK3ErFy9t7zxKsG3NXBJXZjp/kMLoIb3y/kV54oAI= +github.com/quic-go/qtls-go1-20 v0.1.0/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= +github.com/quic-go/quic-go v0.32.0 h1:lY02md31s1JgPiiyfqJijpu/UX/Iun304FI3yUqX7tA= +github.com/quic-go/quic-go v0.32.0/go.mod h1:/fCsKANhQIeD5l76c2JFU+07gVE3KaA0FP+0zMWwfwo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= -github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= -github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= -github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= -go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= -golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= -golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= +golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= +golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= +golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= +golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU= +golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= +golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= +golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= -golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= +golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= -google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= -google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= -google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= -google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= -honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= -sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= diff --git a/internal/http3/body.go b/internal/http3/body.go index 07b68ff0..50a64330 100644 --- a/internal/http3/body.go +++ b/internal/http3/body.go @@ -5,7 +5,7 @@ import ( "io" "net" - "github.com/imroc/req/v3/internal/quic-go" + "github.com/quic-go/quic-go" ) // The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by: diff --git a/internal/http3/body_test.go b/internal/http3/body_test.go deleted file mode 100644 index 886d391b..00000000 --- a/internal/http3/body_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package http3 - -import ( - "errors" - - "github.com/imroc/req/v3/internal/quic-go" - mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" - - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Response Body", func() { - var reqDone chan struct{} - - BeforeEach(func() { reqDone = make(chan struct{}) }) - - It("closes the reqDone channel when Read errors", func() { - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error")) - rb := newResponseBody(str, nil, reqDone) - _, err := rb.Read([]byte{0}) - Expect(err).To(MatchError("test error")) - Expect(reqDone).To(BeClosed()) - }) - - It("allows multiple calls to Read, when Read errors", func() { - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error")).Times(2) - rb := newResponseBody(str, nil, reqDone) - _, err := rb.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - Expect(reqDone).To(BeClosed()) - _, err = rb.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - }) - - It("closes responses", func() { - str := mockquic.NewMockStream(mockCtrl) - rb := newResponseBody(str, nil, reqDone) - str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)) - Expect(rb.Close()).To(Succeed()) - }) - - It("allows multiple calls to Close", func() { - str := mockquic.NewMockStream(mockCtrl) - rb := newResponseBody(str, nil, reqDone) - str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2) - Expect(rb.Close()).To(Succeed()) - Expect(reqDone).To(BeClosed()) - Expect(rb.Close()).To(Succeed()) - }) -}) diff --git a/internal/http3/client.go b/internal/http3/client.go index b7e96e7e..e9d2957e 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -7,14 +7,13 @@ import ( "errors" "fmt" "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/quic-go" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qtls" "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/imroc/req/v3/internal/transport" "github.com/quic-go/qpack" + "github.com/quic-go/quic-go" "io" "net/http" + "reflect" "strconv" "sync" "time" @@ -28,10 +27,16 @@ const ( defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB ) +const ( + VersionDraft29 quic.VersionNumber = 0xff00001d + Version1 quic.VersionNumber = 0x1 + Version2 quic.VersionNumber = 0x6b3343cf +) + var defaultQuicConfig = &quic.Config{ MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams KeepAlivePeriod: 10 * time.Second, - Versions: []quic.VersionNumber{protocol.VersionTLS}, + Versions: []quic.VersionNumber{Version1}, } type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) @@ -65,10 +70,10 @@ type client struct { hostname string conn quic.EarlyConnection - logger utils.Logger + opt *transport.Options } -func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) { +func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc, opt *transport.Options) (*client, error) { if conf == nil { conf = defaultQuicConfig.Clone() } else if len(conf.Versions) == 0 { @@ -82,7 +87,10 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams } conf.EnableDatagrams = opts.EnableDatagram - logger := utils.DefaultLogger.WithPrefix("h3 client") + var debugf func(format string, v ...interface{}) + if opt != nil && opt.Debugf != nil { + debugf = opt.Debugf + } if tlsConf == nil { tlsConf = &tls.Config{} @@ -95,12 +103,12 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con return &client{ hostname: authorityAddr("https", hostname), tlsConf: tlsConf, - requestWriter: newRequestWriter(logger), + requestWriter: newRequestWriter(debugf), decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), config: conf, opts: opts, dialer: dialer, - logger: logger, + opt: opt, }, nil } @@ -118,7 +126,7 @@ func (c *client) dial(ctx context.Context) error { // send the SETTINGs frame, using 0-RTT data, if possible go func() { if err := c.setupConn(); err != nil { - c.logger.Debugf("Setting up connection failed: %s", err) + c.opt.Debugf("setting up http3 connection failed: %s", err) c.conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "") } }() @@ -148,7 +156,7 @@ func (c *client) handleBidirectionalStreams() { for { str, err := c.conn.AcceptStream(context.Background()) if err != nil { - c.logger.Debugf("accepting bidirectional stream failed: %s", err) + c.opt.Debugf("accepting bidirectional stream failed: %s", err) return } go func(str quic.Stream) { @@ -159,7 +167,7 @@ func (c *client) handleBidirectionalStreams() { return } if err != nil { - c.logger.Debugf("error handling stream: %s", err) + c.opt.Debugf("error handling stream: %s", err) } c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream") }(str) @@ -170,7 +178,7 @@ func (c *client) handleUnidirectionalStreams() { for { str, err := c.conn.AcceptUniStream(context.Background()) if err != nil { - c.logger.Debugf("accepting unidirectional stream failed: %s", err) + c.opt.Debugf("accepting unidirectional stream failed: %s", err) return } @@ -180,7 +188,7 @@ func (c *client) handleUnidirectionalStreams() { if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) { return } - c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) + c.opt.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) return } // We're only interested in the control stream here. @@ -391,7 +399,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, } } if err := c.sendRequestBody(hstr, req.Body, bodyDumps); err != nil { - c.logger.Errorf("Error writing request: %s", err) + c.opt.Debugf("error writing request: %s", err) } if !opt.DontCloseRequestStream { hstr.Close() @@ -436,7 +444,10 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, return nil, newConnError(errorGeneralProtocolError, err) } - connState := qtls.ToTLSConnectionState(c.conn.ConnectionState().TLS) + connState, ok := reflect.ValueOf(c.conn.ConnectionState().TLS).Field(0).Interface().(tls.ConnectionState) + if !ok { + panic(fmt.Sprintf("bad tls connection state type: %s", reflect.ValueOf(c.conn.ConnectionState().TLS).Field(0).Type().Name())) + } res := &http.Response{ Proto: "HTTP/3.0", ProtoMajor: 3, diff --git a/internal/http3/client_test.go b/internal/http3/client_test.go deleted file mode 100644 index 14b1129f..00000000 --- a/internal/http3/client_test.go +++ /dev/null @@ -1,1022 +0,0 @@ -package http3 - -import ( - "bytes" - "compress/gzip" - "context" - "crypto/tls" - "errors" - "fmt" - "github.com/quic-go/qpack" - "io" - "io/ioutil" - "net/http" - "time" - - "github.com/imroc/req/v3/internal/quic-go" - mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/quic-go/utils" - - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Client", func() { - var ( - client *client - req *http.Request - origDialAddr = dialAddr - handshakeCtx context.Context // an already canceled context - ) - - BeforeEach(func() { - origDialAddr = dialAddr - hostname := "quic.clemente.io:1337" - var err error - client, err = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(client.hostname).To(Equal(hostname)) - - req, err = http.NewRequest("GET", "https://localhost:1337", nil) - Expect(err).ToNot(HaveOccurred()) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - handshakeCtx = ctx - }) - - AfterEach(func() { - dialAddr = origDialAddr - }) - - It("rejects quic.Configs that allow multiple QUIC versions", func() { - qconf := &quic.Config{ - Versions: []quic.VersionNumber{protocol.VersionDraft29, protocol.Version1}, - } - _, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil) - Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) - }) - - It("uses the default QUIC and TLS config if none is give", func() { - client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - var dialAddrCalled bool - dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { - Expect(quicConf).To(Equal(defaultQuicConfig)) - Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3})) - Expect(quicConf.Versions).To(Equal([]quic.VersionNumber{protocol.Version1})) - dialAddrCalled = true - return nil, errors.New("test done") - } - client.RoundTripOpt(req, RoundTripOpt{}) - Expect(dialAddrCalled).To(BeTrue()) - }) - - It("adds the port to the hostname, if none is given", func() { - client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - var dialAddrCalled bool - dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { - Expect(hostname).To(Equal("quic.clemente.io:443")) - dialAddrCalled = true - return nil, errors.New("test done") - } - req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil) - Expect(err).ToNot(HaveOccurred()) - client.RoundTripOpt(req, RoundTripOpt{}) - Expect(dialAddrCalled).To(BeTrue()) - }) - - It("uses the TLS config and QUIC config", func() { - tlsConf := &tls.Config{ - ServerName: "foo.bar", - NextProtos: []string{"proto foo", "proto bar"}, - } - quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond} - client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) - Expect(err).ToNot(HaveOccurred()) - var dialAddrCalled bool - dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { - Expect(host).To(Equal("localhost:1337")) - Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) - Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3})) - Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) - dialAddrCalled = true - return nil, errors.New("test done") - } - client.RoundTripOpt(req, RoundTripOpt{}) - Expect(dialAddrCalled).To(BeTrue()) - // make sure the original tls.Config was not modified - Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) - }) - - It("uses the custom dialer, if provided", func() { - testErr := errors.New("test done") - tlsConf := &tls.Config{ServerName: "foo.bar"} - quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} - ctx, cancel := context.WithTimeout(context.Background(), time.Hour) - defer cancel() - var dialerCalled bool - dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { - Expect(ctxP).To(Equal(ctx)) - Expect(address).To(Equal("localhost:1337")) - Expect(tlsConfP.ServerName).To(Equal("foo.bar")) - Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) - dialerCalled = true - return nil, testErr - } - client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) - Expect(err).ToNot(HaveOccurred()) - _, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{}) - Expect(err).To(MatchError(testErr)) - Expect(dialerCalled).To(BeTrue()) - }) - - It("enables HTTP/3 Datagrams", func() { - testErr := errors.New("handshake error") - client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { - Expect(quicConf.EnableDatagrams).To(BeTrue()) - return nil, testErr - } - _, err = client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError(testErr)) - }) - - It("errors when dialing fails", func() { - testErr := errors.New("handshake error") - client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return nil, testErr - } - _, err = client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError(testErr)) - }) - - It("closes correctly if connection was not created", func() { - client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(client.Close()).To(Succeed()) - }) - - Context("validating the address", func() { - It("refuses to do requests for the wrong host", func() { - req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) - Expect(err).ToNot(HaveOccurred()) - _, err = client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) - }) - - It("allows requests using a different scheme", func() { - testErr := errors.New("handshake error") - req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil) - Expect(err).ToNot(HaveOccurred()) - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return nil, testErr - } - _, err = client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError(testErr)) - }) - }) - - Context("hijacking bidirectional streams", func() { - var ( - request *http.Request - conn *mockquic.MockEarlyConnection - settingsFrameWritten chan struct{} - ) - testDone := make(chan struct{}) - - BeforeEach(func() { - testDone = make(chan struct{}) - settingsFrameWritten = make(chan struct{}) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { - defer GinkgoRecover() - close(settingsFrameWritten) - }) - conn = mockquic.NewMockEarlyConnection(mockCtrl) - conn.EXPECT().OpenUniStream().Return(controlStr, nil) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) - conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return conn, nil - } - var err error - request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) - Expect(err).ToNot(HaveOccurred()) - }) - - AfterEach(func() { - testDone <- struct{}{} - Eventually(settingsFrameWritten).Should(BeClosed()) - }) - - It("hijacks a bidirectional stream of unknown frame type", func() { - frameTypeChan := make(chan FrameType, 1) - client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - frameTypeChan <- ft - return true, nil - } - - buf := &bytes.Buffer{} - quicvarint.Write(buf, 0x41) - unknownStr := mockquic.NewMockStream(mockCtrl) - unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) - conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { - <-testDone - return nil, errors.New("test done") - }) - _, err := client.RoundTripOpt(request, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - - It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { - frameTypeChan := make(chan FrameType, 1) - client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - frameTypeChan <- ft - return false, nil - } - - buf := &bytes.Buffer{} - quicvarint.Write(buf, 0x41) - unknownStr := mockquic.NewMockStream(mockCtrl) - unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) - conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { - <-testDone - return nil, errors.New("test done") - }) - conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := client.RoundTripOpt(request, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) - }) - - It("closes the connection when hijacker returned error", func() { - frameTypeChan := make(chan FrameType, 1) - client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - frameTypeChan <- ft - return false, errors.New("error in hijacker") - } - - buf := &bytes.Buffer{} - quicvarint.Write(buf, 0x41) - unknownStr := mockquic.NewMockStream(mockCtrl) - unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) - conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { - <-testDone - return nil, errors.New("test done") - }) - conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := client.RoundTripOpt(request, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) - }) - - It("handles errors that occur when reading the frame type", func() { - testErr := errors.New("test error") - unknownStr := mockquic.NewMockStream(mockCtrl) - done := make(chan struct{}) - client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { - defer close(done) - Expect(e).To(MatchError(testErr)) - Expect(ft).To(BeZero()) - Expect(str).To(Equal(unknownStr)) - return false, nil - } - - unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() - conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) - conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { - <-testDone - return nil, errors.New("test done") - }) - conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := client.RoundTripOpt(request, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(done).Should(BeClosed()) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - }) - - Context("hijacking unidirectional streams", func() { - var ( - req *http.Request - conn *mockquic.MockEarlyConnection - settingsFrameWritten chan struct{} - ) - testDone := make(chan struct{}) - - BeforeEach(func() { - testDone = make(chan struct{}) - settingsFrameWritten = make(chan struct{}) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { - defer GinkgoRecover() - close(settingsFrameWritten) - }) - conn = mockquic.NewMockEarlyConnection(mockCtrl) - conn.EXPECT().OpenUniStream().Return(controlStr, nil) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return conn, nil - } - var err error - req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) - Expect(err).ToNot(HaveOccurred()) - }) - - AfterEach(func() { - testDone <- struct{}{} - Eventually(settingsFrameWritten).Should(BeClosed()) - }) - - It("hijacks an unidirectional stream of unknown stream type", func() { - streamTypeChan := make(chan StreamType, 1) - client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { - Expect(err).ToNot(HaveOccurred()) - streamTypeChan <- st - return true - } - - buf := &bytes.Buffer{} - quicvarint.Write(buf, 0x54) - unknownStr := mockquic.NewMockStream(mockCtrl) - unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return unknownStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - - It("handles errors that occur when reading the stream type", func() { - testErr := errors.New("test error") - done := make(chan struct{}) - unknownStr := mockquic.NewMockStream(mockCtrl) - client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { - defer close(done) - Expect(st).To(BeZero()) - Expect(str).To(Equal(unknownStr)) - Expect(err).To(MatchError(testErr)) - return true - } - - unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr) - conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(done).Should(BeClosed()) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - - It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { - streamTypeChan := make(chan StreamType, 1) - client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { - Expect(err).ToNot(HaveOccurred()) - streamTypeChan <- st - return false - } - - buf := &bytes.Buffer{} - quicvarint.Write(buf, 0x54) - unknownStr := mockquic.NewMockStream(mockCtrl) - unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return unknownStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - }) - - Context("control stream handling", func() { - var ( - req *http.Request - conn *mockquic.MockEarlyConnection - settingsFrameWritten chan struct{} - ) - testDone := make(chan struct{}) - - BeforeEach(func() { - settingsFrameWritten = make(chan struct{}) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { - defer GinkgoRecover() - close(settingsFrameWritten) - }) - conn = mockquic.NewMockEarlyConnection(mockCtrl) - conn.EXPECT().OpenUniStream().Return(controlStr, nil) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return conn, nil - } - var err error - req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) - Expect(err).ToNot(HaveOccurred()) - }) - - AfterEach(func() { - testDone <- struct{}{} - Eventually(settingsFrameWritten).Should(BeClosed()) - }) - - It("parses the SETTINGS frame", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypeControlStream) - (&settingsFrame{}).Write(buf) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError - }) - - for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { - streamType := t - name := "encoder" - if streamType == streamTypeQPACKDecoderStream { - name = "decoder" - } - - It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamType) - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return str, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead - }) - } - - It("resets streams Other than the control stream and the QPACK streams", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, 1337) - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - done := make(chan struct{}) - str.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)).Do(func(code quic.StreamErrorCode) { - close(done) - }) - - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return str, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(done).Should(BeClosed()) - }) - - It("errors when the first frame on the control stream is not a SETTINGS frame", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypeControlStream) - (&dataFrame{}).Write(buf) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(errorMissingSettings)) - close(done) - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(done).Should(BeClosed()) - }) - - It("errors when parsing the frame on the control stream fails", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypeControlStream) - b := &bytes.Buffer{} - (&settingsFrame{}).Write(b) - buf.Write(b.Bytes()[:b.Len()-1]) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(errorFrameError)) - close(done) - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(done).Should(BeClosed()) - }) - - It("errors when parsing the server opens a push stream", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypePushStream) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(errorIDError)) - close(done) - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(done).Should(BeClosed()) - }) - - It("errors when the server advertises datagram support (and we enabled support for it)", func() { - client.opts.EnableDatagram = true - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypeControlStream) - (&settingsFrame{Datagram: true}).Write(buf) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) - done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(errorSettingsError)) - Expect(reason).To(Equal("missing QUIC Datagram support")) - close(done) - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("done")) - Eventually(done).Should(BeClosed()) - }) - }) - - Context("Doing requests", func() { - var ( - req *http.Request - str *mockquic.MockStream - conn *mockquic.MockEarlyConnection - settingsFrameWritten chan struct{} - ) - testDone := make(chan struct{}) - - getHeadersFrame := func(headers map[string]string) []byte { - buf := &bytes.Buffer{} - headerBuf := &bytes.Buffer{} - enc := qpack.NewEncoder(headerBuf) - for name, value := range headers { - Expect(enc.WriteField(qpack.HeaderField{Name: name, Value: value})).To(Succeed()) - } - Expect(enc.Close()).To(Succeed()) - (&headersFrame{Length: uint64(headerBuf.Len())}).Write(buf) - buf.Write(headerBuf.Bytes()) - return buf.Bytes() - } - - decodeHeader := func(str io.Reader) map[string]string { - fields := make(map[string]string) - decoder := qpack.NewDecoder(nil) - - frame, err := parseNextFrame(str, nil) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) - headersFrame := frame.(*headersFrame) - data := make([]byte, headersFrame.Length) - _, err = io.ReadFull(str, data) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - hfs, err := decoder.DecodeFull(data) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - for _, p := range hfs { - fields[p.Name] = p.Value - } - return fields - } - - getResponse := func(status int) []byte { - buf := &bytes.Buffer{} - rstr := mockquic.NewMockStream(mockCtrl) - rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, nil, utils.DefaultLogger) - rw.WriteHeader(status) - rw.Flush() - return buf.Bytes() - } - - BeforeEach(func() { - settingsFrameWritten = make(chan struct{}) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { - defer GinkgoRecover() - r := bytes.NewReader(b) - streamType, err := quicvarint.Read(r) - Expect(err).ToNot(HaveOccurred()) - Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) - close(settingsFrameWritten) - }) // SETTINGS frame - str = mockquic.NewMockStream(mockCtrl) - conn = mockquic.NewMockEarlyConnection(mockCtrl) - conn.EXPECT().OpenUniStream().Return(controlStr, nil) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone - return nil, errors.New("test done") - }) - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return conn, nil - } - var err error - req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) - Expect(err).ToNot(HaveOccurred()) - }) - - AfterEach(func() { - testDone <- struct{}{} - Eventually(settingsFrameWritten).Should(BeClosed()) - }) - - It("errors if it can't open a stream", func() { - testErr := errors.New("stream open error") - conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError(testErr)) - }) - - It("performs a 0-RTT request", func() { - testErr := errors.New("stream open error") - req.Method = MethodGet0RTT - // don't EXPECT any calls to HandshakeComplete() - conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) - buf := &bytes.Buffer{} - str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() - str.EXPECT().Close() - str.EXPECT().CancelWrite(gomock.Any()) - str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { - return 0, testErr - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError(testErr)) - Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET")) - }) - - It("returns a response", func() { - rspBuf := bytes.NewBuffer(getResponse(418)) - gomock.InOrder( - conn.EXPECT().HandshakeComplete().Return(handshakeCtx), - conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), - conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), - ) - str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) - str.EXPECT().Close() - str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() - rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).ToNot(HaveOccurred()) - Expect(rsp.Proto).To(Equal("HTTP/3.0")) - Expect(rsp.ProtoMajor).To(Equal(3)) - Expect(rsp.StatusCode).To(Equal(418)) - }) - - Context("requests containing a Body", func() { - var strBuf *bytes.Buffer - - BeforeEach(func() { - strBuf = &bytes.Buffer{} - gomock.InOrder( - conn.EXPECT().HandshakeComplete().Return(handshakeCtx), - conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), - ) - body := &mockBody{} - body.SetData([]byte("request body")) - var err error - req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body) - Expect(err).ToNot(HaveOccurred()) - str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() - }) - - It("sends a request", func() { - done := make(chan struct{}) - gomock.InOrder( - str.EXPECT().Close().Do(func() { close(done) }), - str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors - ) - // the response body is sent asynchronously, while already reading the response - str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { - <-done - return 0, errors.New("test done") - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("test done")) - hfs := decodeHeader(strBuf) - Expect(hfs).To(HaveKeyWithValue(":method", "POST")) - Expect(hfs).To(HaveKeyWithValue(":path", "/upload")) - }) - - It("returns the error that occurred when reading the body", func() { - req.Body.(*mockBody).readErr = errors.New("testErr") - done := make(chan struct{}) - gomock.InOrder( - str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { - close(done) - }), - str.EXPECT().CancelWrite(gomock.Any()), - ) - - // the response body is sent asynchronously, while already reading the response - str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { - <-done - return 0, errors.New("test done") - }) - closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("test done")) - Eventually(closed).Should(BeClosed()) - }) - - It("sets the Content-Length", func() { - done := make(chan struct{}) - buf := &bytes.Buffer{} - buf.Write(getHeadersFrame(map[string]string{ - ":status": "200", - "Content-Length": "1337", - })) - (&dataFrame{Length: 0x6}).Write(buf) - buf.Write([]byte("foobar")) - str.EXPECT().Close().Do(func() { close(done) }) - conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) - str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors - // the response body is sent asynchronously, while already reading the response - str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - req, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).ToNot(HaveOccurred()) - Expect(req.ContentLength).To(BeEquivalentTo(1337)) - Eventually(done).Should(BeClosed()) - }) - - It("closes the connection when the first frame is not a HEADERS frame", func() { - buf := &bytes.Buffer{} - (&dataFrame{Length: 0x42}).Write(buf) - conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()) - closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) - str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) - Eventually(closed).Should(BeClosed()) - }) - - It("cancels the stream when the HEADERS frame is too large", func() { - buf := &bytes.Buffer{} - (&headersFrame{Length: 1338}).Write(buf) - str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)) - closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) - str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) - Eventually(closed).Should(BeClosed()) - }) - }) - - Context("request cancellations", func() { - It("cancels a request while waiting for the handshake to complete", func() { - ctx, cancel := context.WithCancel(context.Background()) - req := req.WithContext(ctx) - conn.EXPECT().HandshakeComplete().Return(context.Background()) - - errChan := make(chan error) - go func() { - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - errChan <- err - }() - Consistently(errChan).ShouldNot(Receive()) - cancel() - Eventually(errChan).Should(Receive(MatchError("context canceled"))) - }) - - It("cancels a request while the request is still in flight", func() { - ctx, cancel := context.WithCancel(context.Background()) - req := req.WithContext(ctx) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) - buf := &bytes.Buffer{} - str.EXPECT().Close().MaxTimes(1) - - str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) - - done := make(chan struct{}) - canceled := make(chan struct{}) - gomock.InOrder( - str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), - str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), - ) - str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) - str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { - cancel() - <-canceled - return 0, errors.New("test done") - }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("test done")) - Eventually(done).Should(BeClosed()) - }) - - It("cancels a request after the response arrived", func() { - rspBuf := bytes.NewBuffer(getResponse(404)) - - ctx, cancel := context.WithCancel(context.Background()) - req := req.WithContext(ctx) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) - conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) - buf := &bytes.Buffer{} - str.EXPECT().Close().MaxTimes(1) - - done := make(chan struct{}) - str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) - str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() - str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) - str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).ToNot(HaveOccurred()) - cancel() - Eventually(done).Should(BeClosed()) - }) - - It("doesn't cancel a request if DontCloseRequestStream is set", func() { - rspBuf := bytes.NewBuffer(getResponse(404)) - - ctx, cancel := context.WithCancel(context.Background()) - req := req.WithContext(ctx) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) - conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) - buf := &bytes.Buffer{} - str.EXPECT().Close().MaxTimes(1) - - str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) - str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() - rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) - Expect(err).ToNot(HaveOccurred()) - cancel() - _, err = io.ReadAll(rsp.Body) - Expect(err).ToNot(HaveOccurred()) - }) - }) - - Context("gzip compression", func() { - BeforeEach(func() { - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - }) - - It("adds the gzip header to requests", func() { - conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) - buf := &bytes.Buffer{} - str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) - gomock.InOrder( - str.EXPECT().Close(), - str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors - ) - str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("test done")) - hfs := decodeHeader(buf) - Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) - }) - - It("doesn't add gzip if the header disable it", func() { - client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) - buf := &bytes.Buffer{} - str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) - gomock.InOrder( - str.EXPECT().Close(), - str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors - ) - str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) - _, err = client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("test done")) - hfs := decodeHeader(buf) - Expect(hfs).ToNot(HaveKey("accept-encoding")) - }) - - It("decompresses the response", func() { - conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) - conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) - buf := &bytes.Buffer{} - rstr := mockquic.NewMockStream(mockCtrl) - rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, nil, utils.DefaultLogger) - rw.Header().Set("Content-Encoding", "gzip") - gz := gzip.NewWriter(rw) - gz.Write([]byte("gzipped response")) - gz.Close() - rw.Flush() - str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) - str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - str.EXPECT().Close() - - rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).ToNot(HaveOccurred()) - data, err := ioutil.ReadAll(rsp.Body) - Expect(err).ToNot(HaveOccurred()) - Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) - Expect(string(data)).To(Equal("gzipped response")) - Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) - Expect(rsp.Uncompressed).To(BeTrue()) - }) - - It("only decompresses the response if the response contains the right content-encoding header", func() { - conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) - conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) - buf := &bytes.Buffer{} - rstr := mockquic.NewMockStream(mockCtrl) - rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, nil, utils.DefaultLogger) - rw.Write([]byte("not gzipped")) - rw.Flush() - str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) - str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - str.EXPECT().Close() - - rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).ToNot(HaveOccurred()) - data, err := ioutil.ReadAll(rsp.Body) - Expect(err).ToNot(HaveOccurred()) - Expect(string(data)).To(Equal("not gzipped")) - Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) - }) - }) - }) -}) diff --git a/internal/http3/error_codes.go b/internal/http3/error_codes.go index 353a1a8e..5df9b5df 100644 --- a/internal/http3/error_codes.go +++ b/internal/http3/error_codes.go @@ -3,7 +3,7 @@ package http3 import ( "fmt" - "github.com/imroc/req/v3/internal/quic-go" + "github.com/quic-go/quic-go" ) type errorCode quic.ApplicationErrorCode diff --git a/internal/http3/error_codes_test.go b/internal/http3/error_codes_test.go deleted file mode 100644 index e4aae37e..00000000 --- a/internal/http3/error_codes_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package http3 - -import ( - "go/ast" - "go/parser" - "go/token" - "path" - "runtime" - "strconv" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("error codes", func() { - It("has a string representation for every error code", func() { - // We parse the error code file, extract all constants, and verify that - // each of them has a string version. Go FTW! - _, thisfile, _, ok := runtime.Caller(0) - if !ok { - panic("Failed to get current frame") - } - filename := path.Join(path.Dir(thisfile), "error_codes.go") - fileAst, err := parser.ParseFile(token.NewFileSet(), filename, nil, 0) - Expect(err).NotTo(HaveOccurred()) - constSpecs := fileAst.Decls[2].(*ast.GenDecl).Specs - Expect(len(constSpecs)).To(BeNumerically(">", 4)) // at time of writing - for _, c := range constSpecs { - valString := c.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value - val, err := strconv.ParseInt(valString, 0, 64) - Expect(err).NotTo(HaveOccurred()) - Expect(errorCode(val).String()).ToNot(Equal("unknown error code")) - } - }) - - It("has a string representation for unknown error codes", func() { - Expect(errorCode(0x1337).String()).To(Equal("unknown error code: 0x1337")) - }) -}) diff --git a/internal/http3/frames.go b/internal/http3/frames.go index b0d886d5..37eb0290 100644 --- a/internal/http3/frames.go +++ b/internal/http3/frames.go @@ -7,7 +7,6 @@ import ( "io" "io/ioutil" - "github.com/imroc/req/v3/internal/quic-go/protocol" "github.com/imroc/req/v3/internal/quic-go/quicvarint" ) @@ -145,14 +144,14 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { func (f *settingsFrame) Write(b *bytes.Buffer) { quicvarint.Write(b, 0x4) - var l protocol.ByteCount + var l uint64 for id, val := range f.Other { l += quicvarint.Len(id) + quicvarint.Len(val) } if f.Datagram { l += quicvarint.Len(settingDatagram) + quicvarint.Len(1) } - quicvarint.Write(b, uint64(l)) + quicvarint.Write(b, l) if f.Datagram { quicvarint.Write(b, settingDatagram) quicvarint.Write(b, 1) diff --git a/internal/http3/frames_test.go b/internal/http3/frames_test.go deleted file mode 100644 index 1e0146cb..00000000 --- a/internal/http3/frames_test.go +++ /dev/null @@ -1,245 +0,0 @@ -package http3 - -import ( - "bytes" - "errors" - "fmt" - "io" - - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type errReader struct{ err error } - -func (e errReader) Read([]byte) (int, error) { return 0, e.err } - -var _ = Describe("Frames", func() { - appendVarInt := func(b []byte, val uint64) []byte { - buf := &bytes.Buffer{} - quicvarint.Write(buf, val) - return append(b, buf.Bytes()...) - } - - It("skips unknown frames", func() { - data := appendVarInt(nil, 0xdeadbeef) // type byte - data = appendVarInt(data, 0x42) - data = append(data, make([]byte, 0x42)...) - buf := bytes.NewBuffer(data) - (&dataFrame{Length: 0x1234}).Write(buf) - frame, err := parseNextFrame(buf, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) - Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1234))) - }) - - Context("DATA frames", func() { - It("parses", func() { - data := appendVarInt(nil, 0) // type byte - data = appendVarInt(data, 0x1337) - frame, err := parseNextFrame(bytes.NewReader(data), nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) - Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1337))) - }) - - It("writes", func() { - buf := &bytes.Buffer{} - (&dataFrame{Length: 0xdeadbeef}).Write(buf) - frame, err := parseNextFrame(buf, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) - Expect(frame.(*dataFrame).Length).To(Equal(uint64(0xdeadbeef))) - }) - }) - - Context("HEADERS frames", func() { - It("parses", func() { - data := appendVarInt(nil, 1) // type byte - data = appendVarInt(data, 0x1337) - frame, err := parseNextFrame(bytes.NewReader(data), nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) - Expect(frame.(*headersFrame).Length).To(Equal(uint64(0x1337))) - }) - - It("writes", func() { - buf := &bytes.Buffer{} - (&headersFrame{Length: 0xdeadbeef}).Write(buf) - frame, err := parseNextFrame(buf, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) - Expect(frame.(*headersFrame).Length).To(Equal(uint64(0xdeadbeef))) - }) - }) - - Context("SETTINGS frames", func() { - It("parses", func() { - settings := appendVarInt(nil, 13) - settings = appendVarInt(settings, 37) - settings = appendVarInt(settings, 0xdead) - settings = appendVarInt(settings, 0xbeef) - data := appendVarInt(nil, 4) // type byte - data = appendVarInt(data, uint64(len(settings))) - data = append(data, settings...) - frame, err := parseNextFrame(bytes.NewReader(data), nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&settingsFrame{})) - sf := frame.(*settingsFrame) - Expect(sf.Other).To(HaveKeyWithValue(uint64(13), uint64(37))) - Expect(sf.Other).To(HaveKeyWithValue(uint64(0xdead), uint64(0xbeef))) - }) - - It("rejects duplicate settings", func() { - settings := appendVarInt(nil, 13) - settings = appendVarInt(settings, 37) - settings = appendVarInt(settings, 13) - settings = appendVarInt(settings, 38) - data := appendVarInt(nil, 4) // type byte - data = appendVarInt(data, uint64(len(settings))) - data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data), nil) - Expect(err).To(MatchError("duplicate setting: 13")) - }) - - It("writes", func() { - sf := &settingsFrame{Other: map[uint64]uint64{ - 1: 2, - 99: 999, - 13: 37, - }} - buf := &bytes.Buffer{} - sf.Write(buf) - frame, err := parseNextFrame(buf, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(sf)) - }) - - It("errors on EOF", func() { - sf := &settingsFrame{Other: map[uint64]uint64{ - 13: 37, - 0xdeadbeef: 0xdecafbad, - }} - buf := &bytes.Buffer{} - sf.Write(buf) - - data := buf.Bytes() - _, err := parseNextFrame(bytes.NewReader(data), nil) - Expect(err).ToNot(HaveOccurred()) - - for i := range data { - b := make([]byte, i) - copy(b, data[:i]) - _, err := parseNextFrame(bytes.NewReader(b), nil) - Expect(err).To(MatchError(io.EOF)) - } - }) - - Context("H3_DATAGRAM", func() { - It("reads the H3_DATAGRAM value", func() { - settings := appendVarInt(nil, settingDatagram) - settings = appendVarInt(settings, 1) - data := appendVarInt(nil, 4) // type byte - data = appendVarInt(data, uint64(len(settings))) - data = append(data, settings...) - f, err := parseNextFrame(bytes.NewReader(data), nil) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeAssignableToTypeOf(&settingsFrame{})) - sf := f.(*settingsFrame) - Expect(sf.Datagram).To(BeTrue()) - }) - - It("rejects duplicate H3_DATAGRAM entries", func() { - settings := appendVarInt(nil, settingDatagram) - settings = appendVarInt(settings, 1) - settings = appendVarInt(settings, settingDatagram) - settings = appendVarInt(settings, 1) - data := appendVarInt(nil, 4) // type byte - data = appendVarInt(data, uint64(len(settings))) - data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data), nil) - Expect(err).To(MatchError(fmt.Sprintf("duplicate setting: %d", settingDatagram))) - }) - - It("rejects invalid values for the H3_DATAGRAM entry", func() { - settings := appendVarInt(nil, settingDatagram) - settings = appendVarInt(settings, 1337) - data := appendVarInt(nil, 4) // type byte - data = appendVarInt(data, uint64(len(settings))) - data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data), nil) - Expect(err).To(MatchError("invalid value for H3_DATAGRAM: 1337")) - }) - - It("writes the H3_DATAGRAM setting", func() { - sf := &settingsFrame{Datagram: true} - buf := &bytes.Buffer{} - sf.Write(buf) - frame, err := parseNextFrame(buf, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(sf)) - }) - }) - }) - - Context("hijacking", func() { - It("reads a frame without hijacking the stream", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, 1337) - customFrameContents := []byte("foobar") - buf.Write(customFrameContents) - - var called bool - _, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - Expect(ft).To(BeEquivalentTo(1337)) - called = true - b := make([]byte, 3) - _, err = io.ReadFull(buf, b) - Expect(err).ToNot(HaveOccurred()) - Expect(string(b)).To(Equal("foo")) - return true, nil - }) - Expect(err).To(MatchError(errHijacked)) - Expect(called).To(BeTrue()) - }) - - It("passes on errors that occur when reading the frame type", func() { - testErr := errors.New("test error") - var called bool - _, err := parseNextFrame(errReader{err: testErr}, func(ft FrameType, e error) (hijacked bool, err error) { - Expect(e).To(MatchError(testErr)) - Expect(ft).To(BeZero()) - called = true - return true, nil - }) - Expect(err).To(MatchError(errHijacked)) - Expect(called).To(BeTrue()) - }) - - It("reads a frame without hijacking the stream", func() { - buf := &bytes.Buffer{} - quicvarint.Write(buf, 1337) - customFrameContents := []byte("custom frame") - quicvarint.Write(buf, uint64(len(customFrameContents))) - buf.Write(customFrameContents) - (&dataFrame{Length: 6}).Write(buf) - buf.WriteString("foobar") - - var called bool - frame, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - Expect(ft).To(BeEquivalentTo(1337)) - called = true - return false, nil - }) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(&dataFrame{Length: 6})) - Expect(called).To(BeTrue()) - }) - }) -}) diff --git a/internal/http3/http3_suite_test.go b/internal/http3/http3_suite_test.go deleted file mode 100644 index c94d932a..00000000 --- a/internal/http3/http3_suite_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package http3 - -import ( - "os" - "strconv" - "testing" - "time" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestHttp3(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "HTTP/3 Suite") -} - -var mockCtrl *gomock.Controller - -var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) -}) - -var _ = AfterEach(func() { - mockCtrl.Finish() -}) - -//nolint:unparam -func scaleDuration(t time.Duration) time.Duration { - scaleFactor := 1 - if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set - scaleFactor = f - } - Expect(scaleFactor).ToNot(BeZero()) - return time.Duration(scaleFactor) * t -} diff --git a/internal/http3/http_stream.go b/internal/http3/http_stream.go index 3ff7a61b..9ce69b45 100644 --- a/internal/http3/http_stream.go +++ b/internal/http3/http_stream.go @@ -4,7 +4,7 @@ import ( "bytes" "fmt" - "github.com/imroc/req/v3/internal/quic-go" + "github.com/quic-go/quic-go" ) // A Stream is a HTTP/3 stream. diff --git a/internal/http3/http_stream_test.go b/internal/http3/http_stream_test.go deleted file mode 100644 index f4ccaa6d..00000000 --- a/internal/http3/http_stream_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package http3 - -import ( - "bytes" - "io" - - mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" - - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Stream", func() { - Context("reading", func() { - var ( - str Stream - qstr *mockquic.MockStream - buf *bytes.Buffer - errorCbCalled bool - ) - - errorCb := func() { errorCbCalled = true } - getDataFrame := func(data []byte) []byte { - b := &bytes.Buffer{} - (&dataFrame{Length: uint64(len(data))}).Write(b) - b.Write(data) - return b.Bytes() - } - - BeforeEach(func() { - buf = &bytes.Buffer{} - errorCbCalled = false - qstr = mockquic.NewMockStream(mockCtrl) - qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() - qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - str = newStream(qstr, errorCb) - }) - - It("reads DATA frames in a single run", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 6) - n, err := str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b).To(Equal([]byte("foobar"))) - }) - - It("reads DATA frames in multiple runs", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 3) - n, err := str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b).To(Equal([]byte("foo"))) - n, err = str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b).To(Equal([]byte("bar"))) - }) - - It("reads DATA frames into too large buffers", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 10) - n, err := str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b[:n]).To(Equal([]byte("foobar"))) - }) - - It("reads DATA frames into too large buffers, in multiple runs", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 4) - n, err := str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte("foob"))) - n, err = str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - Expect(b[:n]).To(Equal([]byte("ar"))) - }) - - It("reads multiple DATA frames", func() { - buf.Write(getDataFrame([]byte("foo"))) - buf.Write(getDataFrame([]byte("bar"))) - b := make([]byte, 6) - n, err := str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b[:n]).To(Equal([]byte("foo"))) - n, err = str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b[:n]).To(Equal([]byte("bar"))) - }) - - It("skips HEADERS frames", func() { - buf.Write(getDataFrame([]byte("foo"))) - (&headersFrame{Length: 10}).Write(buf) - buf.Write(make([]byte, 10)) - buf.Write(getDataFrame([]byte("bar"))) - b := make([]byte, 6) - n, err := io.ReadFull(str, b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b).To(Equal([]byte("foobar"))) - }) - - It("errors when it can't parse the frame", func() { - buf.Write([]byte("invalid")) - _, err := str.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - }) - - It("errors on unexpected frames, and calls the error callback", func() { - (&settingsFrame{}).Write(buf) - _, err := str.Read([]byte{0}) - Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame")) - Expect(errorCbCalled).To(BeTrue()) - }) - }) - - Context("writing", func() { - It("writes data frames", func() { - buf := &bytes.Buffer{} - qstr := mockquic.NewMockStream(mockCtrl) - qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() - str := newStream(qstr, nil) - str.Write([]byte("foo")) - str.Write([]byte("foobar")) - - f, err := parseNextFrame(buf, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(Equal(&dataFrame{Length: 3})) - b := make([]byte, 3) - _, err = io.ReadFull(buf, b) - Expect(err).ToNot(HaveOccurred()) - Expect(b).To(Equal([]byte("foo"))) - - f, err = parseNextFrame(buf, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(Equal(&dataFrame{Length: 6})) - b = make([]byte, 6) - _, err = io.ReadFull(buf, b) - Expect(err).ToNot(HaveOccurred()) - Expect(b).To(Equal([]byte("foobar"))) - }) - }) -}) diff --git a/internal/http3/request_test.go b/internal/http3/request_test.go deleted file mode 100644 index 059fcfe1..00000000 --- a/internal/http3/request_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package http3 - -import ( - "github.com/quic-go/qpack" - "net/http" - "net/url" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Request", func() { - It("populates request", func() { - headers := []qpack.HeaderField{ - {Name: ":path", Value: "/foo"}, - {Name: ":authority", Value: "quic.clemente.io"}, - {Name: ":method", Value: "GET"}, - {Name: "content-length", Value: "42"}, - } - req, err := requestFromHeaders(headers) - Expect(err).NotTo(HaveOccurred()) - Expect(req.Method).To(Equal("GET")) - Expect(req.URL.Path).To(Equal("/foo")) - Expect(req.URL.Host).To(BeEmpty()) - Expect(req.Proto).To(Equal("HTTP/3.0")) - Expect(req.ProtoMajor).To(Equal(3)) - Expect(req.ProtoMinor).To(BeZero()) - Expect(req.ContentLength).To(Equal(int64(42))) - Expect(req.Header).To(BeEmpty()) - Expect(req.Body).To(BeNil()) - Expect(req.Host).To(Equal("quic.clemente.io")) - Expect(req.RequestURI).To(Equal("/foo")) - Expect(req.TLS).ToNot(BeNil()) - }) - - It("parses path with leading double slashes", func() { - headers := []qpack.HeaderField{ - {Name: ":path", Value: "//foo"}, - {Name: ":authority", Value: "quic.clemente.io"}, - {Name: ":method", Value: "GET"}, - } - req, err := requestFromHeaders(headers) - Expect(err).NotTo(HaveOccurred()) - Expect(req.Header).To(BeEmpty()) - Expect(req.Body).To(BeNil()) - Expect(req.URL.Path).To(Equal("//foo")) - Expect(req.URL.Host).To(BeEmpty()) - Expect(req.Host).To(Equal("quic.clemente.io")) - Expect(req.RequestURI).To(Equal("//foo")) - }) - - It("concatenates the cookie headers", func() { - headers := []qpack.HeaderField{ - {Name: ":path", Value: "/foo"}, - {Name: ":authority", Value: "quic.clemente.io"}, - {Name: ":method", Value: "GET"}, - {Name: "cookie", Value: "cookie1=foobar1"}, - {Name: "cookie", Value: "cookie2=foobar2"}, - } - req, err := requestFromHeaders(headers) - Expect(err).NotTo(HaveOccurred()) - Expect(req.Header).To(Equal(http.Header{ - "Cookie": []string{"cookie1=foobar1; cookie2=foobar2"}, - })) - }) - - It("handles Other headers", func() { - headers := []qpack.HeaderField{ - {Name: ":path", Value: "/foo"}, - {Name: ":authority", Value: "quic.clemente.io"}, - {Name: ":method", Value: "GET"}, - {Name: "cache-control", Value: "max-age=0"}, - {Name: "duplicate-header", Value: "1"}, - {Name: "duplicate-header", Value: "2"}, - } - req, err := requestFromHeaders(headers) - Expect(err).NotTo(HaveOccurred()) - Expect(req.Header).To(Equal(http.Header{ - "Cache-Control": []string{"max-age=0"}, - "Duplicate-Header": []string{"1", "2"}, - })) - }) - - It("errors with missing path", func() { - headers := []qpack.HeaderField{ - {Name: ":authority", Value: "quic.clemente.io"}, - {Name: ":method", Value: "GET"}, - } - _, err := requestFromHeaders(headers) - Expect(err).To(MatchError(":path, :authority and :method must not be empty")) - }) - - It("errors with missing method", func() { - headers := []qpack.HeaderField{ - {Name: ":path", Value: "/foo"}, - {Name: ":authority", Value: "quic.clemente.io"}, - } - _, err := requestFromHeaders(headers) - Expect(err).To(MatchError(":path, :authority and :method must not be empty")) - }) - - It("errors with missing authority", func() { - headers := []qpack.HeaderField{ - {Name: ":path", Value: "/foo"}, - {Name: ":method", Value: "GET"}, - } - _, err := requestFromHeaders(headers) - Expect(err).To(MatchError(":path, :authority and :method must not be empty")) - }) - - Context("regular HTTP CONNECT", func() { - It("handles CONNECT method", func() { - headers := []qpack.HeaderField{ - {Name: ":authority", Value: "quic.clemente.io"}, - {Name: ":method", Value: http.MethodConnect}, - } - req, err := requestFromHeaders(headers) - Expect(err).NotTo(HaveOccurred()) - Expect(req.Method).To(Equal(http.MethodConnect)) - Expect(req.RequestURI).To(Equal("quic.clemente.io")) - }) - - It("errors with missing authority in CONNECT method", func() { - headers := []qpack.HeaderField{ - {Name: ":method", Value: http.MethodConnect}, - } - _, err := requestFromHeaders(headers) - Expect(err).To(MatchError(":path must be empty and :authority must not be empty")) - }) - - It("errors with extra path in CONNECT method", func() { - headers := []qpack.HeaderField{ - {Name: ":path", Value: "/foo"}, - {Name: ":authority", Value: "quic.clemente.io"}, - {Name: ":method", Value: http.MethodConnect}, - } - _, err := requestFromHeaders(headers) - Expect(err).To(MatchError(":path must be empty and :authority must not be empty")) - }) - }) - - Context("Extended CONNECT", func() { - It("handles Extended CONNECT method", func() { - headers := []qpack.HeaderField{ - {Name: ":protocol", Value: "webtransport"}, - {Name: ":scheme", Value: "ftp"}, - {Name: ":method", Value: http.MethodConnect}, - {Name: ":authority", Value: "quic.clemente.io"}, - {Name: ":path", Value: "/foo?val=1337"}, - } - req, err := requestFromHeaders(headers) - Expect(err).NotTo(HaveOccurred()) - Expect(req.Method).To(Equal(http.MethodConnect)) - Expect(req.Proto).To(Equal("webtransport")) - Expect(req.URL.String()).To(Equal("ftp://quic.clemente.io/foo?val=1337")) - Expect(req.URL.Query().Get("val")).To(Equal("1337")) - }) - - It("errors with missing scheme", func() { - headers := []qpack.HeaderField{ - {Name: ":protocol", Value: "webtransport"}, - {Name: ":method", Value: http.MethodConnect}, - {Name: ":authority", Value: "quic.clemente.io"}, - {Name: ":path", Value: "/foo"}, - } - _, err := requestFromHeaders(headers) - Expect(err).To(MatchError("extended CONNECT: :scheme, :path and :authority must not be empty")) - }) - }) - - Context("extracting the hostname from a request", func() { - var url *url.URL - - BeforeEach(func() { - var err error - url, err = url.Parse("https://quic.clemente.io:1337") - Expect(err).ToNot(HaveOccurred()) - }) - - It("uses req.URL.Host", func() { - req := &http.Request{URL: url} - Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337")) - }) - - It("uses req.URL.Host even if req.Host is available", func() { - req := &http.Request{ - Host: "www.example.org", - URL: url, - } - Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337")) - }) - - It("returns an empty hostname if nothing is set", func() { - Expect(hostnameFromRequest(&http.Request{})).To(BeEmpty()) - }) - }) -}) diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index f77d3a87..66c7ead0 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -13,8 +13,7 @@ import ( "strings" "sync" - "github.com/imroc/req/v3/internal/quic-go" - "github.com/imroc/req/v3/internal/quic-go/utils" + "github.com/quic-go/quic-go" "golang.org/x/net/http/httpguts" "golang.org/x/net/http2/hpack" "golang.org/x/net/idna" @@ -27,16 +26,16 @@ type requestWriter struct { encoder *qpack.Encoder headerBuf *bytes.Buffer - logger utils.Logger + debugf func(format string, v ...interface{}) } -func newRequestWriter(logger utils.Logger) *requestWriter { +func newRequestWriter(debugf func(format string, v ...interface{})) *requestWriter { headerBuf := &bytes.Buffer{} encoder := qpack.NewEncoder(headerBuf) return &requestWriter{ encoder: encoder, headerBuf: headerBuf, - logger: logger, + debugf: debugf, } } diff --git a/internal/http3/request_writer_test.go b/internal/http3/request_writer_test.go deleted file mode 100644 index bddedcc0..00000000 --- a/internal/http3/request_writer_test.go +++ /dev/null @@ -1,111 +0,0 @@ -package http3 - -import ( - "bytes" - "github.com/quic-go/qpack" - "io" - "net/http" - - mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" - "github.com/imroc/req/v3/internal/quic-go/utils" - - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Request Writer", func() { - var ( - rw *requestWriter - str *mockquic.MockStream - strBuf *bytes.Buffer - ) - - decode := func(str io.Reader) map[string]string { - frame, err := parseNextFrame(str, nil) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) - headersFrame := frame.(*headersFrame) - data := make([]byte, headersFrame.Length) - _, err = io.ReadFull(str, data) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - decoder := qpack.NewDecoder(nil) - hfs, err := decoder.DecodeFull(data) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - values := make(map[string]string) - for _, hf := range hfs { - values[hf.Name] = hf.Value - } - return values - } - - BeforeEach(func() { - rw = newRequestWriter(utils.DefaultLogger) - strBuf = &bytes.Buffer{} - str = mockquic.NewMockStream(mockCtrl) - str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() - }) - - It("writes a GET request", func() { - req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil) - Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequestHeader(str, req, false, nil)).To(Succeed()) - headerFields := decode(strBuf) - Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) - Expect(headerFields).To(HaveKeyWithValue(":method", "GET")) - Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar")) - Expect(headerFields).To(HaveKeyWithValue(":scheme", "https")) - Expect(headerFields).ToNot(HaveKey("accept-encoding")) - }) - - It("sends cookies", func() { - req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) - Expect(err).ToNot(HaveOccurred()) - cookie1 := &http.Cookie{ - Name: "Cookie #1", - Value: "Value #1", - } - cookie2 := &http.Cookie{ - Name: "Cookie #2", - Value: "Value #2", - } - req.AddCookie(cookie1) - req.AddCookie(cookie2) - Expect(rw.WriteRequestHeader(str, req, false, nil)).To(Succeed()) - headerFields := decode(strBuf) - Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`)) - }) - - It("adds the header for gzip support", func() { - req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) - Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequestHeader(str, req, true, nil)).To(Succeed()) - headerFields := decode(strBuf) - Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip")) - }) - - It("writes a CONNECT request", func() { - req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil) - Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequestHeader(str, req, false, nil)).To(Succeed()) - headerFields := decode(strBuf) - Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) - Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) - Expect(headerFields).ToNot(HaveKey(":path")) - Expect(headerFields).ToNot(HaveKey(":scheme")) - Expect(headerFields).ToNot(HaveKey(":protocol")) - }) - - It("writes an Extended CONNECT request", func() { - req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil) - Expect(err).ToNot(HaveOccurred()) - req.Proto = "webtransport" - Expect(rw.WriteRequestHeader(str, req, false, nil)).To(Succeed()) - headerFields := decode(strBuf) - Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) - Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) - Expect(headerFields).To(HaveKeyWithValue(":path", "/foobar")) - Expect(headerFields).To(HaveKeyWithValue(":scheme", "https")) - Expect(headerFields).To(HaveKeyWithValue(":protocol", "webtransport")) - }) -}) diff --git a/internal/http3/response_writer.go b/internal/http3/response_writer.go deleted file mode 100644 index ab45e379..00000000 --- a/internal/http3/response_writer.go +++ /dev/null @@ -1,118 +0,0 @@ -package http3 - -import ( - "bufio" - "bytes" - "github.com/quic-go/qpack" - "net/http" - "strconv" - "strings" - - "github.com/imroc/req/v3/internal/quic-go" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -type responseWriter struct { - conn quic.Connection - bufferedStr *bufio.Writer - - header http.Header - status int // status code passed to WriteHeader - headerWritten bool - - logger utils.Logger -} - -var ( - _ http.ResponseWriter = &responseWriter{} - _ http.Flusher = &responseWriter{} - _ Hijacker = &responseWriter{} -) - -func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { - return &responseWriter{ - header: http.Header{}, - conn: conn, - bufferedStr: bufio.NewWriter(str), - logger: logger, - } -} - -func (w *responseWriter) Header() http.Header { - return w.header -} - -func (w *responseWriter) WriteHeader(status int) { - if w.headerWritten { - return - } - - if status < 100 || status >= 200 { - w.headerWritten = true - } - w.status = status - - var headers bytes.Buffer - enc := qpack.NewEncoder(&headers) - enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - - for k, v := range w.header { - for index := range v { - enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) - } - } - - buf := &bytes.Buffer{} - (&headersFrame{Length: uint64(headers.Len())}).Write(buf) - w.logger.Infof("Responding with %d", status) - if _, err := w.bufferedStr.Write(buf.Bytes()); err != nil { - w.logger.Errorf("could not write headers frame: %s", err.Error()) - } - if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil { - w.logger.Errorf("could not write header frame payload: %s", err.Error()) - } - if !w.headerWritten { - w.Flush() - } -} - -func (w *responseWriter) Write(p []byte) (int, error) { - if !w.headerWritten { - w.WriteHeader(200) - } - if !bodyAllowedForStatus(w.status) { - return 0, http.ErrBodyNotAllowed - } - df := &dataFrame{Length: uint64(len(p))} - buf := &bytes.Buffer{} - df.Write(buf) - if _, err := w.bufferedStr.Write(buf.Bytes()); err != nil { - return 0, err - } - return w.bufferedStr.Write(p) -} - -func (w *responseWriter) Flush() { - if err := w.bufferedStr.Flush(); err != nil { - w.logger.Errorf("could not flush to stream: %s", err.Error()) - } -} - -func (w *responseWriter) StreamCreator() StreamCreator { - return w.conn -} - -// copied from http2/http2.go -// bodyAllowedForStatus reports whether a given response status code -// permits a body. See RFC 2616, section 4.4. -func bodyAllowedForStatus(status int) bool { - switch { - case status >= 100 && status <= 199: - return false - case status == 204: - return false - case status == 304: - return false - } - return true -} diff --git a/internal/http3/response_writer_test.go b/internal/http3/response_writer_test.go deleted file mode 100644 index b24f2c28..00000000 --- a/internal/http3/response_writer_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package http3 - -import ( - "bytes" - "github.com/quic-go/qpack" - "io" - "net/http" - - mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" - "github.com/imroc/req/v3/internal/quic-go/utils" - - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Response Writer", func() { - var ( - rw *responseWriter - strBuf *bytes.Buffer - ) - - BeforeEach(func() { - strBuf = &bytes.Buffer{} - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() - rw = newResponseWriter(str, nil, utils.DefaultLogger) - }) - - decodeHeader := func(str io.Reader) map[string][]string { - rw.Flush() - fields := make(map[string][]string) - decoder := qpack.NewDecoder(nil) - - frame, err := parseNextFrame(str, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) - headersFrame := frame.(*headersFrame) - data := make([]byte, headersFrame.Length) - _, err = io.ReadFull(str, data) - Expect(err).ToNot(HaveOccurred()) - hfs, err := decoder.DecodeFull(data) - Expect(err).ToNot(HaveOccurred()) - for _, p := range hfs { - fields[p.Name] = append(fields[p.Name], p.Value) - } - return fields - } - - getData := func(str io.Reader) []byte { - frame, err := parseNextFrame(str, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) - df := frame.(*dataFrame) - data := make([]byte, df.Length) - _, err = io.ReadFull(str, data) - Expect(err).ToNot(HaveOccurred()) - return data - } - - It("writes status", func() { - rw.WriteHeader(http.StatusTeapot) - fields := decodeHeader(strBuf) - Expect(fields).To(HaveLen(1)) - Expect(fields).To(HaveKeyWithValue(":status", []string{"418"})) - }) - - It("writes headers", func() { - rw.Header().Add("content-length", "42") - rw.WriteHeader(http.StatusTeapot) - fields := decodeHeader(strBuf) - Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"})) - }) - - It("writes multiple headers with the same name", func() { - const cookie1 = "test1=1; Max-Age=7200; path=/" - const cookie2 = "test2=2; Max-Age=7200; path=/" - rw.Header().Add("set-cookie", cookie1) - rw.Header().Add("set-cookie", cookie2) - rw.WriteHeader(http.StatusTeapot) - fields := decodeHeader(strBuf) - Expect(fields).To(HaveKey("set-cookie")) - cookies := fields["set-cookie"] - Expect(cookies).To(ContainElement(cookie1)) - Expect(cookies).To(ContainElement(cookie2)) - }) - - It("writes data", func() { - n, err := rw.Write([]byte("foobar")) - Expect(n).To(Equal(6)) - Expect(err).ToNot(HaveOccurred()) - // Should have written 200 on the header stream - fields := decodeHeader(strBuf) - Expect(fields).To(HaveKeyWithValue(":status", []string{"200"})) - // And foobar on the data stream - Expect(getData(strBuf)).To(Equal([]byte("foobar"))) - }) - - It("writes data after WriteHeader is called", func() { - rw.WriteHeader(http.StatusTeapot) - n, err := rw.Write([]byte("foobar")) - Expect(n).To(Equal(6)) - Expect(err).ToNot(HaveOccurred()) - // Should have written 418 on the header stream - fields := decodeHeader(strBuf) - Expect(fields).To(HaveKeyWithValue(":status", []string{"418"})) - // And foobar on the data stream - Expect(getData(strBuf)).To(Equal([]byte("foobar"))) - }) - - It("does not WriteHeader() twice", func() { - rw.WriteHeader(200) - rw.WriteHeader(500) - fields := decodeHeader(strBuf) - Expect(fields).To(HaveLen(1)) - Expect(fields).To(HaveKeyWithValue(":status", []string{"200"})) - }) - - It("allows calling WriteHeader() several times when using the 103 status code", func() { - rw.Header().Add("Link", "; rel=preload; as=style") - rw.Header().Add("Link", "; rel=preload; as=script") - rw.WriteHeader(http.StatusEarlyHints) - - n, err := rw.Write([]byte("foobar")) - Expect(n).To(Equal(6)) - Expect(err).ToNot(HaveOccurred()) - - // Early Hints must have been received - fields := decodeHeader(strBuf) - Expect(fields).To(HaveLen(2)) - Expect(fields).To(HaveKeyWithValue(":status", []string{"103"})) - Expect(fields).To(HaveKeyWithValue("link", []string{"; rel=preload; as=style", "; rel=preload; as=script"})) - - // According to the spec, headers sent in the informational response must also be included in the final response - fields = decodeHeader(strBuf) - Expect(fields).To(HaveLen(2)) - Expect(fields).To(HaveKeyWithValue(":status", []string{"200"})) - Expect(fields).To(HaveKeyWithValue("link", []string{"; rel=preload; as=style", "; rel=preload; as=script"})) - - Expect(getData(strBuf)).To(Equal([]byte("foobar"))) - }) - - It("doesn't allow writes if the status code doesn't allow a body", func() { - rw.WriteHeader(304) - n, err := rw.Write([]byte("foobar")) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(http.ErrBodyNotAllowed)) - }) -}) diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index 09b5ef1c..a6f4b080 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -12,7 +12,7 @@ import ( "sync" "time" - "github.com/imroc/req/v3/internal/quic-go" + "github.com/quic-go/quic-go" "golang.org/x/net/http/httpguts" ) @@ -182,6 +182,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo }, r.QuicConfig, r.Dial, + r.Options, ) if err != nil { return nil, err diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go deleted file mode 100644 index ed1aed90..00000000 --- a/internal/http3/roundtrip_test.go +++ /dev/null @@ -1,253 +0,0 @@ -package http3 - -import ( - "bytes" - "context" - "crypto/tls" - "errors" - "github.com/imroc/req/v3/internal/transport" - "io" - "net/http" - "time" - - "github.com/golang/mock/gomock" - "github.com/imroc/req/v3/internal/quic-go" - mockquic "github.com/imroc/req/v3/internal/quic-go/mocks/quic" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type mockClient struct { - closed bool -} - -func (m *mockClient) RoundTripOpt(req *http.Request, _ RoundTripOpt) (*http.Response, error) { - return &http.Response{Request: req}, nil -} - -func (m *mockClient) Close() error { - m.closed = true - return nil -} - -var _ roundTripCloser = &mockClient{} - -type mockBody struct { - reader bytes.Reader - readErr error - closeErr error - closed bool -} - -// make sure the mockBody can be used as a http.Request.Body -var _ io.ReadCloser = &mockBody{} - -func (m *mockBody) Read(p []byte) (int, error) { - if m.readErr != nil { - return 0, m.readErr - } - return m.reader.Read(p) -} - -func (m *mockBody) SetData(data []byte) { - m.reader = *bytes.NewReader(data) -} - -func (m *mockBody) Close() error { - m.closed = true - return m.closeErr -} - -var _ = Describe("RoundTripper", func() { - var ( - rt *RoundTripper - req1 *http.Request - conn *mockquic.MockEarlyConnection - handshakeCtx context.Context // an already canceled context - ) - - BeforeEach(func() { - rt = &RoundTripper{Options: &transport.Options{}} - var err error - req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) - Expect(err).ToNot(HaveOccurred()) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - handshakeCtx = ctx - }) - - Context("dialing hosts", func() { - origDialAddr := dialAddr - - BeforeEach(func() { - conn = mockquic.NewMockEarlyConnection(mockCtrl) - origDialAddr = dialAddr - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - // return an error when trying to open a stream - // we don't want to test all the dial logic here, just that dialing happens at all - return conn, nil - } - }) - - AfterEach(func() { - dialAddr = origDialAddr - }) - - It("creates new clients", func() { - closed := make(chan struct{}) - testErr := errors.New("test err") - req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) - Expect(err).ToNot(HaveOccurred()) - conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-closed - return nil, errors.New("test done") - }).MaxTimes(1) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) }) - _, err = rt.RoundTrip(req) - Expect(err).To(MatchError(testErr)) - Expect(rt.clients).To(HaveLen(1)) - Eventually(closed).Should(BeClosed()) - }) - - It("uses the quic.Config, if provided", func() { - config := &quic.Config{HandshakeIdleTimeout: time.Millisecond} - var receivedConfig *quic.Config - dialAddr = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { - receivedConfig = config - return nil, errors.New("handshake error") - } - rt.QuicConfig = config - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("handshake error")) - Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout)) - }) - - It("uses the custom dialer, if provided", func() { - var dialed bool - dialer := func(_ context.Context, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - dialed = true - return nil, errors.New("handshake error") - } - rt.Dial = dialer - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("handshake error")) - Expect(dialed).To(BeTrue()) - }) - - It("reuses existing clients", func() { - closed := make(chan struct{}) - testErr := errors.New("test err") - conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2) - conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-closed - return nil, errors.New("test done") - }).MaxTimes(1) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) }) - req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) - Expect(err).ToNot(HaveOccurred()) - _, err = rt.RoundTrip(req) - Expect(err).To(MatchError(testErr)) - Expect(rt.clients).To(HaveLen(1)) - req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil) - Expect(err).ToNot(HaveOccurred()) - _, err = rt.RoundTrip(req2) - Expect(err).To(MatchError(testErr)) - Expect(rt.clients).To(HaveLen(1)) - Eventually(closed).Should(BeClosed()) - }) - - It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() { - req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) - Expect(err).ToNot(HaveOccurred()) - _, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) - Expect(err).To(MatchError(ErrNoCachedConn)) - }) - }) - - Context("validating request", func() { - It("rejects plain HTTP requests", func() { - req, err := http.NewRequest("GET", "http://www.example.org/", nil) - req.Body = &mockBody{} - Expect(err).ToNot(HaveOccurred()) - _, err = rt.RoundTrip(req) - Expect(err).To(MatchError("http3: unsupported protocol scheme: http")) - Expect(req.Body.(*mockBody).closed).To(BeTrue()) - }) - - It("rejects requests without a URL", func() { - req1.URL = nil - req1.Body = &mockBody{} - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("http3: nil Request.URL")) - Expect(req1.Body.(*mockBody).closed).To(BeTrue()) - }) - - It("rejects request without a URL Host", func() { - req1.URL.Host = "" - req1.Body = &mockBody{} - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("http3: no Host in request URL")) - Expect(req1.Body.(*mockBody).closed).To(BeTrue()) - }) - - It("doesn't try to close the body if the request doesn't have one", func() { - req1.URL = nil - Expect(req1.Body).To(BeNil()) - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("http3: nil Request.URL")) - }) - - It("rejects requests without a header", func() { - req1.Header = nil - req1.Body = &mockBody{} - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("http3: nil Request.Header")) - Expect(req1.Body.(*mockBody).closed).To(BeTrue()) - }) - - It("rejects requests with invalid header name fields", func() { - req1.Header.Add("foobär", "value") - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("http3: invalid http header field name \"foobär\"")) - }) - - It("rejects requests with invalid header name values", func() { - req1.Header.Add("foo", string([]byte{0x7})) - _, err := rt.RoundTrip(req1) - Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value")) - }) - - It("rejects requests with an invalid request method", func() { - req1.Method = "foobär" - req1.Body = &mockBody{} - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("http3: invalid method \"foobär\"")) - Expect(req1.Body.(*mockBody).closed).To(BeTrue()) - }) - }) - - Context("closing", func() { - It("closes", func() { - rt.clients = make(map[string]roundTripCloser) - cl := &mockClient{} - rt.clients["foo.bar"] = cl - err := rt.Close() - Expect(err).ToNot(HaveOccurred()) - Expect(len(rt.clients)).To(BeZero()) - Expect(cl.closed).To(BeTrue()) - }) - - It("closes a RoundTripper that has never been used", func() { - Expect(len(rt.clients)).To(BeZero()) - err := rt.Close() - Expect(err).ToNot(HaveOccurred()) - Expect(len(rt.clients)).To(BeZero()) - }) - }) -}) diff --git a/internal/http3/server.go b/internal/http3/server.go index 0ce67b07..7449c3c7 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -1,14 +1,7 @@ package http3 import ( - "github.com/imroc/req/v3/internal/quic-go" - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// allows mocking of quic.Listen and quic.ListenAddr -var ( - quicListen = quic.ListenEarly - quicListenAddr = quic.ListenAddrEarly + "github.com/quic-go/quic-go" ) const ( @@ -27,10 +20,10 @@ const ( ) func versionToALPN(v quic.VersionNumber) string { - if v == protocol.Version1 || v == protocol.Version2 { + switch v { + case Version1, Version2: return nextProtoH3 - } - if v == protocol.VersionTLS || v == protocol.VersionDraft29 { + case VersionDraft29: return nextProtoH3Draft29 } return "" diff --git a/internal/quic-go/ackhandler/ack_eliciting.go b/internal/quic-go/ackhandler/ack_eliciting.go deleted file mode 100644 index 76e8cc01..00000000 --- a/internal/quic-go/ackhandler/ack_eliciting.go +++ /dev/null @@ -1,20 +0,0 @@ -package ackhandler - -import "github.com/imroc/req/v3/internal/quic-go/wire" - -// IsFrameAckEliciting returns true if the frame is ack-eliciting. -func IsFrameAckEliciting(f wire.Frame) bool { - _, isAck := f.(*wire.AckFrame) - _, isConnectionClose := f.(*wire.ConnectionCloseFrame) - return !isAck && !isConnectionClose -} - -// HasAckElicitingFrames returns true if at least one frame is ack-eliciting. -func HasAckElicitingFrames(fs []Frame) bool { - for _, f := range fs { - if IsFrameAckEliciting(f.Frame) { - return true - } - } - return false -} diff --git a/internal/quic-go/ackhandler/ack_eliciting_test.go b/internal/quic-go/ackhandler/ack_eliciting_test.go deleted file mode 100644 index 625899f7..00000000 --- a/internal/quic-go/ackhandler/ack_eliciting_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package ackhandler - -import ( - "reflect" - - "github.com/imroc/req/v3/internal/quic-go/wire" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("ack-eliciting frames", func() { - for fl, el := range map[wire.Frame]bool{ - &wire.AckFrame{}: false, - &wire.ConnectionCloseFrame{}: false, - &wire.DataBlockedFrame{}: true, - &wire.PingFrame{}: true, - &wire.ResetStreamFrame{}: true, - &wire.StreamFrame{}: true, - &wire.MaxDataFrame{}: true, - &wire.MaxStreamDataFrame{}: true, - } { - f := fl - e := el - fName := reflect.ValueOf(f).Elem().Type().Name() - - It("works for "+fName, func() { - Expect(IsFrameAckEliciting(f)).To(Equal(e)) - }) - - It("HasAckElicitingFrames works for "+fName, func() { - Expect(HasAckElicitingFrames([]Frame{{Frame: f}})).To(Equal(e)) - }) - } -}) diff --git a/internal/quic-go/ackhandler/ackhandler.go b/internal/quic-go/ackhandler/ackhandler.go deleted file mode 100644 index c5ebb712..00000000 --- a/internal/quic-go/ackhandler/ackhandler.go +++ /dev/null @@ -1,21 +0,0 @@ -package ackhandler - -import ( - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler -func NewAckHandler( - initialPacketNumber protocol.PacketNumber, - initialMaxDatagramSize protocol.ByteCount, - rttStats *utils.RTTStats, - pers protocol.Perspective, - tracer logging.ConnectionTracer, - logger utils.Logger, - version protocol.VersionNumber, -) (SentPacketHandler, ReceivedPacketHandler) { - sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, pers, tracer, logger) - return sph, newReceivedPacketHandler(sph, rttStats, logger, version) -} diff --git a/internal/quic-go/ackhandler/ackhandler_suite_test.go b/internal/quic-go/ackhandler/ackhandler_suite_test.go deleted file mode 100644 index 17481188..00000000 --- a/internal/quic-go/ackhandler/ackhandler_suite_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package ackhandler - -import ( - "math/rand" - "testing" - - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestCrypto(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "AckHandler Suite") -} - -var mockCtrl *gomock.Controller - -var _ = BeforeSuite(func() { - rand.Seed(GinkgoRandomSeed()) -}) - -var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) -}) - -var _ = AfterEach(func() { - mockCtrl.Finish() -}) diff --git a/internal/quic-go/ackhandler/frame.go b/internal/quic-go/ackhandler/frame.go deleted file mode 100644 index 98866e91..00000000 --- a/internal/quic-go/ackhandler/frame.go +++ /dev/null @@ -1,9 +0,0 @@ -package ackhandler - -import "github.com/imroc/req/v3/internal/quic-go/wire" - -type Frame struct { - wire.Frame // nil if the frame has already been acknowledged in another packet - OnLost func(wire.Frame) - OnAcked func(wire.Frame) -} diff --git a/internal/quic-go/ackhandler/gen.go b/internal/quic-go/ackhandler/gen.go deleted file mode 100644 index 32235f81..00000000 --- a/internal/quic-go/ackhandler/gen.go +++ /dev/null @@ -1,3 +0,0 @@ -package ackhandler - -//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet diff --git a/internal/quic-go/ackhandler/interfaces.go b/internal/quic-go/ackhandler/interfaces.go deleted file mode 100644 index aa54bf53..00000000 --- a/internal/quic-go/ackhandler/interfaces.go +++ /dev/null @@ -1,68 +0,0 @@ -package ackhandler - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// A Packet is a packet -type Packet struct { - PacketNumber protocol.PacketNumber - Frames []Frame - LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK - Length protocol.ByteCount - EncryptionLevel protocol.EncryptionLevel - SendTime time.Time - - IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller. - - includedInBytesInFlight bool - declaredLost bool - skippedPacket bool -} - -// SentPacketHandler handles ACKs received for outgoing packets -type SentPacketHandler interface { - // SentPacket may modify the packet - SentPacket(packet *Packet) - ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) - ReceivedBytes(protocol.ByteCount) - DropPackets(protocol.EncryptionLevel) - ResetForRetry() error - SetHandshakeConfirmed() - - // The SendMode determines if and what kind of packets can be sent. - SendMode() SendMode - // TimeUntilSend is the time when the next packet should be sent. - // It is used for pacing packets. - TimeUntilSend() time.Time - // HasPacingBudget says if the pacer allows sending of a (full size) packet at this moment. - HasPacingBudget() bool - SetMaxDatagramSize(count protocol.ByteCount) - - // only to be called once the handshake is complete - QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ - - PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) - PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber - - GetLossDetectionTimeout() time.Time - OnLossDetectionTimeout() error -} - -type sentPacketTracker interface { - GetLowestPacketNotConfirmedAcked() protocol.PacketNumber - ReceivedPacket(protocol.EncryptionLevel) -} - -// ReceivedPacketHandler handles ACKs needed to send for incoming packets -type ReceivedPacketHandler interface { - IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool - ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error - DropPackets(protocol.EncryptionLevel) - - GetAlarmTimeout() time.Time - GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame -} diff --git a/internal/quic-go/ackhandler/mock_sent_packet_tracker_test.go b/internal/quic-go/ackhandler/mock_sent_packet_tracker_test.go deleted file mode 100644 index 01eb17df..00000000 --- a/internal/quic-go/ackhandler/mock_sent_packet_tracker_test.go +++ /dev/null @@ -1,61 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: interfaces.go - -// Package ackhandler is a generated GoMock package. -package ackhandler - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockSentPacketTracker is a mock of SentPacketTracker interface. -type MockSentPacketTracker struct { - ctrl *gomock.Controller - recorder *MockSentPacketTrackerMockRecorder -} - -// MockSentPacketTrackerMockRecorder is the mock recorder for MockSentPacketTracker. -type MockSentPacketTrackerMockRecorder struct { - mock *MockSentPacketTracker -} - -// NewMockSentPacketTracker creates a new mock instance. -func NewMockSentPacketTracker(ctrl *gomock.Controller) *MockSentPacketTracker { - mock := &MockSentPacketTracker{ctrl: ctrl} - mock.recorder = &MockSentPacketTrackerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSentPacketTracker) EXPECT() *MockSentPacketTrackerMockRecorder { - return m.recorder -} - -// GetLowestPacketNotConfirmedAcked mocks base method. -func (m *MockSentPacketTracker) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLowestPacketNotConfirmedAcked") - ret0, _ := ret[0].(protocol.PacketNumber) - return ret0 -} - -// GetLowestPacketNotConfirmedAcked indicates an expected call of GetLowestPacketNotConfirmedAcked. -func (mr *MockSentPacketTrackerMockRecorder) GetLowestPacketNotConfirmedAcked() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketTracker)(nil).GetLowestPacketNotConfirmedAcked)) -} - -// ReceivedPacket mocks base method. -func (m *MockSentPacketTracker) ReceivedPacket(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedPacket", arg0) -} - -// ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0) -} diff --git a/internal/quic-go/ackhandler/mockgen.go b/internal/quic-go/ackhandler/mockgen.go deleted file mode 100644 index 8c5e33c0..00000000 --- a/internal/quic-go/ackhandler/mockgen.go +++ /dev/null @@ -1,3 +0,0 @@ -package ackhandler - -//go:generate sh -c "../../mockgen_private.sh ackhandler mock_sent_packet_tracker_test.go github.com/imroc/req/v3/internal/quic-go/ackhandler sentPacketTracker" diff --git a/internal/quic-go/ackhandler/packet_linkedlist.go b/internal/quic-go/ackhandler/packet_linkedlist.go deleted file mode 100644 index bb74f4ef..00000000 --- a/internal/quic-go/ackhandler/packet_linkedlist.go +++ /dev/null @@ -1,217 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package ackhandler - -// Linked list implementation from the Go standard library. - -// PacketElement is an element of a linked list. -type PacketElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *PacketElement - - // The list to which this element belongs. - list *PacketList - - // The value stored with this element. - Value Packet -} - -// Next returns the next list element or nil. -func (e *PacketElement) Next() *PacketElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *PacketElement) Prev() *PacketElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// PacketList is a linked list of Packets. -type PacketList struct { - root PacketElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *PacketList) Init() *PacketList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewPacketList returns an initialized list. -func NewPacketList() *PacketList { return new(PacketList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *PacketList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *PacketList) Front() *PacketElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *PacketList) Back() *PacketElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *PacketList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *PacketList) insert(e, at *PacketElement) *PacketElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement { - return l.insert(&PacketElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. -func (l *PacketList) remove(e *PacketElement) *PacketElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *PacketList) Remove(e *PacketElement) Packet { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *PacketList) PushFront(v Packet) *PacketElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *PacketList) PushBack(v Packet) *PacketElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *PacketList) MoveToFront(e *PacketElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *PacketList) MoveToBack(e *PacketElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *PacketList) MoveBefore(e, mark *PacketElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *PacketList) MoveAfter(e, mark *PacketElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *PacketList) PushBackList(other *PacketList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *PacketList) PushFrontList(other *PacketList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/internal/quic-go/ackhandler/packet_number_generator.go b/internal/quic-go/ackhandler/packet_number_generator.go deleted file mode 100644 index 0d81f6d1..00000000 --- a/internal/quic-go/ackhandler/packet_number_generator.go +++ /dev/null @@ -1,76 +0,0 @@ -package ackhandler - -import ( - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -type packetNumberGenerator interface { - Peek() protocol.PacketNumber - Pop() protocol.PacketNumber -} - -type sequentialPacketNumberGenerator struct { - next protocol.PacketNumber -} - -var _ packetNumberGenerator = &sequentialPacketNumberGenerator{} - -func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator { - return &sequentialPacketNumberGenerator{next: initial} -} - -func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber { - return p.next -} - -func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber { - next := p.next - p.next++ - return next -} - -// The skippingPacketNumberGenerator generates the packet number for the next packet -// it randomly skips a packet number every averagePeriod packets (on average). -// It is guaranteed to never skip two consecutive packet numbers. -type skippingPacketNumberGenerator struct { - period protocol.PacketNumber - maxPeriod protocol.PacketNumber - - next protocol.PacketNumber - nextToSkip protocol.PacketNumber - - rng utils.Rand -} - -var _ packetNumberGenerator = &skippingPacketNumberGenerator{} - -func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator { - g := &skippingPacketNumberGenerator{ - next: initial, - period: initialPeriod, - maxPeriod: maxPeriod, - } - g.generateNewSkip() - return g -} - -func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber { - return p.next -} - -func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber { - next := p.next - p.next++ // generate a new packet number for the next packet - if p.next == p.nextToSkip { - p.next++ - p.generateNewSkip() - } - return next -} - -func (p *skippingPacketNumberGenerator) generateNewSkip() { - // make sure that there are never two consecutive packet numbers that are skipped - p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) - p.period = utils.MinPacketNumber(2*p.period, p.maxPeriod) -} diff --git a/internal/quic-go/ackhandler/packet_number_generator_test.go b/internal/quic-go/ackhandler/packet_number_generator_test.go deleted file mode 100644 index db4d096d..00000000 --- a/internal/quic-go/ackhandler/packet_number_generator_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package ackhandler - -import ( - "fmt" - "math" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Sequential Packet Number Generator", func() { - It("generates sequential packet numbers", func() { - const initialPN protocol.PacketNumber = 123 - png := newSequentialPacketNumberGenerator(initialPN) - - for i := initialPN; i < initialPN+1000; i++ { - Expect(png.Peek()).To(Equal(i)) - Expect(png.Peek()).To(Equal(i)) - Expect(png.Pop()).To(Equal(i)) - } - }) -}) - -var _ = Describe("Skipping Packet Number Generator", func() { - const initialPN protocol.PacketNumber = 8 - const initialPeriod protocol.PacketNumber = 25 - const maxPeriod protocol.PacketNumber = 300 - - It("uses a maximum period that is sufficiently small such that using a 32-bit random number is ok", func() { - Expect(2 * protocol.SkipPacketMaxPeriod).To(BeNumerically("<", math.MaxInt32)) - }) - - It("can be initialized to return any first packet number", func() { - png := newSkippingPacketNumberGenerator(12345, initialPeriod, maxPeriod) - Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345))) - }) - - It("allows peeking", func() { - png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod).(*skippingPacketNumberGenerator) - png.nextToSkip = 1000 - Expect(png.Peek()).To(Equal(initialPN)) - Expect(png.Peek()).To(Equal(initialPN)) - Expect(png.Pop()).To(Equal(initialPN)) - Expect(png.Peek()).To(Equal(initialPN + 1)) - Expect(png.Peek()).To(Equal(initialPN + 1)) - }) - - It("skips a packet number", func() { - png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod) - var last protocol.PacketNumber - var skipped bool - for i := 0; i < 1000; i++ { - num := png.Pop() - if num > last+1 { - skipped = true - break - } - last = num - } - Expect(skipped).To(BeTrue()) - }) - - It("generates a new packet number to skip", func() { - const rep = 2500 - periods := make([][]protocol.PacketNumber, rep) - expectedPeriods := []protocol.PacketNumber{25, 50, 100, 200, 300, 300, 300} - - for i := 0; i < rep; i++ { - png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod) - last := initialPN - lastSkip := initialPN - for len(periods[i]) < len(expectedPeriods) { - next := png.Pop() - if next > last+1 { - skipped := next - 1 - Expect(skipped).To(BeNumerically(">", lastSkip+1)) - periods[i] = append(periods[i], skipped-lastSkip-1) - lastSkip = skipped - } - last = next - } - } - - for j := 0; j < len(expectedPeriods); j++ { - var average float64 - for i := 0; i < rep; i++ { - average += float64(periods[i][j]) / float64(len(periods)) - } - fmt.Fprintf(GinkgoWriter, "Period %d: %.2f (expected %d)\n", j, average, expectedPeriods[j]) - tolerance := protocol.PacketNumber(5) - if t := expectedPeriods[j] / 10; t > tolerance { - tolerance = t - } - Expect(average).To(BeNumerically("~", expectedPeriods[j]+1 /* we never skip two packet numbers at the same time */, tolerance)) - } - }) -}) diff --git a/internal/quic-go/ackhandler/received_packet_handler.go b/internal/quic-go/ackhandler/received_packet_handler.go deleted file mode 100644 index 39e45da4..00000000 --- a/internal/quic-go/ackhandler/received_packet_handler.go +++ /dev/null @@ -1,136 +0,0 @@ -package ackhandler - -import ( - "fmt" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type receivedPacketHandler struct { - sentPackets sentPacketTracker - - initialPackets *receivedPacketTracker - handshakePackets *receivedPacketTracker - appDataPackets *receivedPacketTracker - - lowest1RTTPacket protocol.PacketNumber -} - -var _ ReceivedPacketHandler = &receivedPacketHandler{} - -func newReceivedPacketHandler( - sentPackets sentPacketTracker, - rttStats *utils.RTTStats, - logger utils.Logger, - version protocol.VersionNumber, -) ReceivedPacketHandler { - return &receivedPacketHandler{ - sentPackets: sentPackets, - initialPackets: newReceivedPacketTracker(rttStats, logger, version), - handshakePackets: newReceivedPacketTracker(rttStats, logger, version), - appDataPackets: newReceivedPacketTracker(rttStats, logger, version), - lowest1RTTPacket: protocol.InvalidPacketNumber, - } -} - -func (h *receivedPacketHandler) ReceivedPacket( - pn protocol.PacketNumber, - ecn protocol.ECN, - encLevel protocol.EncryptionLevel, - rcvTime time.Time, - shouldInstigateAck bool, -) error { - h.sentPackets.ReceivedPacket(encLevel) - switch encLevel { - case protocol.EncryptionInitial: - h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) - case protocol.EncryptionHandshake: - h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) - case protocol.Encryption0RTT: - if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { - return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket) - } - h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) - case protocol.Encryption1RTT: - if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { - h.lowest1RTTPacket = pn - } - h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked()) - h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) - default: - panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel)) - } - return nil -} - -func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { - //nolint:exhaustive // 1-RTT packet number space is never dropped. - switch encLevel { - case protocol.EncryptionInitial: - h.initialPackets = nil - case protocol.EncryptionHandshake: - h.handshakePackets = nil - case protocol.Encryption0RTT: - // Nothing to do here. - // If we are rejecting 0-RTT, no 0-RTT packets will have been decrypted. - default: - panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) - } -} - -func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { - var initialAlarm, handshakeAlarm time.Time - if h.initialPackets != nil { - initialAlarm = h.initialPackets.GetAlarmTimeout() - } - if h.handshakePackets != nil { - handshakeAlarm = h.handshakePackets.GetAlarmTimeout() - } - oneRTTAlarm := h.appDataPackets.GetAlarmTimeout() - return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm) -} - -func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame { - var ack *wire.AckFrame - //nolint:exhaustive // 0-RTT packets can't contain ACK frames. - switch encLevel { - case protocol.EncryptionInitial: - if h.initialPackets != nil { - ack = h.initialPackets.GetAckFrame(onlyIfQueued) - } - case protocol.EncryptionHandshake: - if h.handshakePackets != nil { - ack = h.handshakePackets.GetAckFrame(onlyIfQueued) - } - case protocol.Encryption1RTT: - // 0-RTT packets can't contain ACK frames - return h.appDataPackets.GetAckFrame(onlyIfQueued) - default: - return nil - } - // For Initial and Handshake ACKs, the delay time is ignored by the receiver. - // Set it to 0 in order to save bytes. - if ack != nil { - ack.DelayTime = 0 - } - return ack -} - -func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool { - switch encLevel { - case protocol.EncryptionInitial: - if h.initialPackets != nil { - return h.initialPackets.IsPotentiallyDuplicate(pn) - } - case protocol.EncryptionHandshake: - if h.handshakePackets != nil { - return h.handshakePackets.IsPotentiallyDuplicate(pn) - } - case protocol.Encryption0RTT, protocol.Encryption1RTT: - return h.appDataPackets.IsPotentiallyDuplicate(pn) - } - panic("unexpected encryption level") -} diff --git a/internal/quic-go/ackhandler/received_packet_handler_test.go b/internal/quic-go/ackhandler/received_packet_handler_test.go deleted file mode 100644 index b852b068..00000000 --- a/internal/quic-go/ackhandler/received_packet_handler_test.go +++ /dev/null @@ -1,168 +0,0 @@ -package ackhandler - -import ( - "time" - - "github.com/golang/mock/gomock" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Received Packet Handler", func() { - var handler ReceivedPacketHandler - var sentPackets *MockSentPacketTracker - - BeforeEach(func() { - sentPackets = NewMockSentPacketTracker(mockCtrl) - handler = newReceivedPacketHandler( - sentPackets, - &utils.RTTStats{}, - utils.DefaultLogger, - protocol.VersionWhatever, - ) - }) - - It("generates ACKs for different packet number spaces", func() { - sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() - sendTime := time.Now().Add(-time.Second) - sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionInitial).Times(2) - sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionHandshake).Times(2) - sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT).Times(2) - Expect(handler.ReceivedPacket(2, protocol.ECT0, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(1, protocol.ECT1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(5, protocol.ECNCE, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(3, protocol.ECT0, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(2, protocol.ECT1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(4, protocol.ECNCE, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - initialAck := handler.GetAckFrame(protocol.EncryptionInitial, true) - Expect(initialAck).ToNot(BeNil()) - Expect(initialAck.AckRanges).To(HaveLen(1)) - Expect(initialAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 2, Largest: 3})) - Expect(initialAck.DelayTime).To(BeZero()) - Expect(initialAck.ECT0).To(BeEquivalentTo(2)) - Expect(initialAck.ECT1).To(BeZero()) - Expect(initialAck.ECNCE).To(BeZero()) - handshakeAck := handler.GetAckFrame(protocol.EncryptionHandshake, true) - Expect(handshakeAck).ToNot(BeNil()) - Expect(handshakeAck.AckRanges).To(HaveLen(1)) - Expect(handshakeAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 1, Largest: 2})) - Expect(handshakeAck.DelayTime).To(BeZero()) - Expect(handshakeAck.ECT0).To(BeZero()) - Expect(handshakeAck.ECT1).To(BeEquivalentTo(2)) - Expect(handshakeAck.ECNCE).To(BeZero()) - oneRTTAck := handler.GetAckFrame(protocol.Encryption1RTT, true) - Expect(oneRTTAck).ToNot(BeNil()) - Expect(oneRTTAck.AckRanges).To(HaveLen(1)) - Expect(oneRTTAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) - Expect(oneRTTAck.DelayTime).To(BeNumerically("~", time.Second, 50*time.Millisecond)) - Expect(oneRTTAck.ECT0).To(BeZero()) - Expect(oneRTTAck.ECT1).To(BeZero()) - Expect(oneRTTAck.ECNCE).To(BeEquivalentTo(2)) - }) - - It("uses the same packet number space for 0-RTT and 1-RTT packets", func() { - sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() - sentPackets.EXPECT().ReceivedPacket(protocol.Encryption0RTT) - sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT) - sendTime := time.Now().Add(-time.Second) - Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - ack := handler.GetAckFrame(protocol.Encryption1RTT, true) - Expect(ack).ToNot(BeNil()) - Expect(ack.AckRanges).To(HaveLen(1)) - Expect(ack.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 2, Largest: 3})) - }) - - It("rejects 0-RTT packets with higher packet numbers than 1-RTT packets", func() { - sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3) - sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() - sendTime := time.Now() - Expect(handler.ReceivedPacket(10, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(11, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(12, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(MatchError("received packet number 12 on a 0-RTT packet after receiving 11 on a 1-RTT packet")) - }) - - It("allows reordered 0-RTT packets", func() { - sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3) - sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() - sendTime := time.Now() - Expect(handler.ReceivedPacket(10, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(12, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(11, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) - }) - - It("drops Initial packets", func() { - sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2) - sendTime := time.Now().Add(-time.Second) - Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(1, protocol.ECNNon, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) - Expect(handler.GetAckFrame(protocol.EncryptionInitial, true)).ToNot(BeNil()) - handler.DropPackets(protocol.EncryptionInitial) - Expect(handler.GetAckFrame(protocol.EncryptionInitial, true)).To(BeNil()) - Expect(handler.GetAckFrame(protocol.EncryptionHandshake, true)).ToNot(BeNil()) - }) - - It("drops Handshake packets", func() { - sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2) - sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() - sendTime := time.Now().Add(-time.Second) - Expect(handler.ReceivedPacket(1, protocol.ECNNon, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - Expect(handler.GetAckFrame(protocol.EncryptionHandshake, true)).ToNot(BeNil()) - handler.DropPackets(protocol.EncryptionInitial) - Expect(handler.GetAckFrame(protocol.EncryptionHandshake, true)).To(BeNil()) - Expect(handler.GetAckFrame(protocol.Encryption1RTT, true)).ToNot(BeNil()) - }) - - It("does nothing when dropping 0-RTT packets", func() { - handler.DropPackets(protocol.Encryption0RTT) - }) - - It("drops old ACK ranges", func() { - sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes() - sendTime := time.Now() - sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Times(2) - Expect(handler.ReceivedPacket(1, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - ack := handler.GetAckFrame(protocol.Encryption1RTT, true) - Expect(ack).ToNot(BeNil()) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2))) - sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked() - Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Return(protocol.PacketNumber(2)) - Expect(handler.ReceivedPacket(4, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - ack = handler.GetAckFrame(protocol.Encryption1RTT, true) - Expect(ack).ToNot(BeNil()) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(2))) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4))) - }) - - It("says if packets are duplicates", func() { - sendTime := time.Now() - sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes() - sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() - // Initial - Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionInitial)).To(BeFalse()) - Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) - Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionInitial)).To(BeTrue()) - // Handshake - Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionHandshake)).To(BeFalse()) - Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) - Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionHandshake)).To(BeTrue()) - // 0-RTT - Expect(handler.IsPotentiallyDuplicate(3, protocol.Encryption0RTT)).To(BeFalse()) - Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) - Expect(handler.IsPotentiallyDuplicate(3, protocol.Encryption0RTT)).To(BeTrue()) - // 1-RTT - Expect(handler.IsPotentiallyDuplicate(3, protocol.Encryption1RTT)).To(BeTrue()) - Expect(handler.IsPotentiallyDuplicate(4, protocol.Encryption1RTT)).To(BeFalse()) - Expect(handler.ReceivedPacket(4, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - Expect(handler.IsPotentiallyDuplicate(4, protocol.Encryption1RTT)).To(BeTrue()) - }) -}) diff --git a/internal/quic-go/ackhandler/received_packet_history.go b/internal/quic-go/ackhandler/received_packet_history.go deleted file mode 100644 index f3bcc25e..00000000 --- a/internal/quic-go/ackhandler/received_packet_history.go +++ /dev/null @@ -1,142 +0,0 @@ -package ackhandler - -import ( - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// The receivedPacketHistory stores if a packet number has already been received. -// It generates ACK ranges which can be used to assemble an ACK frame. -// It does not store packet contents. -type receivedPacketHistory struct { - ranges *utils.PacketIntervalList - - deletedBelow protocol.PacketNumber -} - -func newReceivedPacketHistory() *receivedPacketHistory { - return &receivedPacketHistory{ - ranges: utils.NewPacketIntervalList(), - } -} - -// ReceivedPacket registers a packet with PacketNumber p and updates the ranges -func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ { - // ignore delayed packets, if we already deleted the range - if p < h.deletedBelow { - return false - } - isNew := h.addToRanges(p) - h.maybeDeleteOldRanges() - return isNew -} - -func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ { - if h.ranges.Len() == 0 { - h.ranges.PushBack(utils.PacketInterval{Start: p, End: p}) - return true - } - - for el := h.ranges.Back(); el != nil; el = el.Prev() { - // p already included in an existing range. Nothing to do here - if p >= el.Value.Start && p <= el.Value.End { - return false - } - - if el.Value.End == p-1 { // extend a range at the end - el.Value.End = p - return true - } - if el.Value.Start == p+1 { // extend a range at the beginning - el.Value.Start = p - - prev := el.Prev() - if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges - prev.Value.End = el.Value.End - h.ranges.Remove(el) - } - return true - } - - // create a new range at the end - if p > el.Value.End { - h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el) - return true - } - } - - // create a new range at the beginning - h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front()) - return true -} - -// Delete old ranges, if we're tracking more than 500 of them. -// This is a DoS defense against a peer that sends us too many gaps. -func (h *receivedPacketHistory) maybeDeleteOldRanges() { - for h.ranges.Len() > protocol.MaxNumAckRanges { - h.ranges.Remove(h.ranges.Front()) - } -} - -// DeleteBelow deletes all entries below (but not including) p -func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) { - if p < h.deletedBelow { - return - } - h.deletedBelow = p - - nextEl := h.ranges.Front() - for el := h.ranges.Front(); nextEl != nil; el = nextEl { - nextEl = el.Next() - - if el.Value.End < p { // delete a whole range - h.ranges.Remove(el) - } else if p > el.Value.Start && p <= el.Value.End { - el.Value.Start = p - return - } else { // no ranges affected. Nothing to do - return - } - } -} - -// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame -func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange { - if h.ranges.Len() == 0 { - return nil - } - - ackRanges := make([]wire.AckRange, h.ranges.Len()) - i := 0 - for el := h.ranges.Back(); el != nil; el = el.Prev() { - ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End} - i++ - } - return ackRanges -} - -func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange { - ackRange := wire.AckRange{} - if h.ranges.Len() > 0 { - r := h.ranges.Back().Value - ackRange.Smallest = r.Start - ackRange.Largest = r.End - } - return ackRange -} - -func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber) bool { - if p < h.deletedBelow { - return true - } - for el := h.ranges.Back(); el != nil; el = el.Prev() { - if p > el.Value.End { - return false - } - if p <= el.Value.End && p >= el.Value.Start { - return true - } - } - return false -} diff --git a/internal/quic-go/ackhandler/received_packet_history_test.go b/internal/quic-go/ackhandler/received_packet_history_test.go deleted file mode 100644 index 9994b489..00000000 --- a/internal/quic-go/ackhandler/received_packet_history_test.go +++ /dev/null @@ -1,354 +0,0 @@ -package ackhandler - -import ( - "fmt" - "math/rand" - "sort" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("receivedPacketHistory", func() { - var hist *receivedPacketHistory - - BeforeEach(func() { - hist = newReceivedPacketHistory() - }) - - Context("ranges", func() { - It("adds the first packet", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) - }) - - It("doesn't care about duplicate packets", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(4)).To(BeFalse()) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) - }) - - It("adds a few consecutive packets", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.ReceivedPacket(6)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) - }) - - It("doesn't care about a duplicate packet contained in an existing range", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.ReceivedPacket(6)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeFalse()) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) - }) - - It("extends a range at the front", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(3)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 3, End: 4})) - }) - - It("creates a new range when a packet is lost", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(6)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(2)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) - Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6})) - }) - - It("creates a new range in between two ranges", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(10)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(2)) - Expect(hist.ReceivedPacket(7)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(3)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) - Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7})) - Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) - }) - - It("creates a new range before an existing range for a belated packet", func() { - Expect(hist.ReceivedPacket(6)).To(BeTrue()) - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(2)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) - Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6})) - }) - - It("extends a previous range at the end", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(7)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(2)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 5})) - Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7})) - }) - - It("extends a range at the front", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(7)).To(BeTrue()) - Expect(hist.ReceivedPacket(6)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(2)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) - Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 7})) - }) - - It("closes a range", func() { - Expect(hist.ReceivedPacket(6)).To(BeTrue()) - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(2)) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) - }) - - It("closes a range in the middle", func() { - Expect(hist.ReceivedPacket(1)).To(BeTrue()) - Expect(hist.ReceivedPacket(10)).To(BeTrue()) - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(6)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(4)) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.ranges.Len()).To(Equal(3)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 1, End: 1})) - Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) - Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) - }) - }) - - Context("deleting", func() { - It("does nothing when the history is empty", func() { - hist.DeleteBelow(5) - Expect(hist.ranges.Len()).To(BeZero()) - }) - - It("deletes a range", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.ReceivedPacket(10)).To(BeTrue()) - hist.DeleteBelow(6) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) - }) - - It("deletes multiple ranges", func() { - Expect(hist.ReceivedPacket(1)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.ReceivedPacket(10)).To(BeTrue()) - hist.DeleteBelow(8) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) - }) - - It("adjusts a range, if packets are delete from an existing range", func() { - Expect(hist.ReceivedPacket(3)).To(BeTrue()) - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.ReceivedPacket(6)).To(BeTrue()) - Expect(hist.ReceivedPacket(7)).To(BeTrue()) - hist.DeleteBelow(5) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 7})) - }) - - It("adjusts a range, if only one packet remains in the range", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.ReceivedPacket(10)).To(BeTrue()) - hist.DeleteBelow(5) - Expect(hist.ranges.Len()).To(Equal(2)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 5})) - Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) - }) - - It("keeps a one-packet range, if deleting up to the packet directly below", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - hist.DeleteBelow(4) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) - }) - - It("doesn't add delayed packets below deleted ranges", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.ReceivedPacket(6)).To(BeTrue()) - hist.DeleteBelow(5) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 6})) - Expect(hist.ReceivedPacket(2)).To(BeFalse()) - Expect(hist.ranges.Len()).To(Equal(1)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 6})) - }) - - It("doesn't create more than MaxNumAckRanges ranges", func() { - for i := protocol.PacketNumber(0); i < protocol.MaxNumAckRanges; i++ { - Expect(hist.ReceivedPacket(2 * i)).To(BeTrue()) - } - Expect(hist.ranges.Len()).To(Equal(protocol.MaxNumAckRanges)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 0, End: 0})) - hist.ReceivedPacket(2*protocol.MaxNumAckRanges + 1000) - // check that the oldest ACK range was deleted - Expect(hist.ranges.Len()).To(Equal(protocol.MaxNumAckRanges)) - Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 2, End: 2})) - }) - }) - - Context("ACK range export", func() { - It("returns nil if there are no ranges", func() { - Expect(hist.GetAckRanges()).To(BeNil()) - }) - - It("gets a single ACK range", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - ackRanges := hist.GetAckRanges() - Expect(ackRanges).To(HaveLen(1)) - Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) - }) - - It("gets multiple ACK ranges", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.ReceivedPacket(6)).To(BeTrue()) - Expect(hist.ReceivedPacket(1)).To(BeTrue()) - Expect(hist.ReceivedPacket(11)).To(BeTrue()) - Expect(hist.ReceivedPacket(10)).To(BeTrue()) - Expect(hist.ReceivedPacket(2)).To(BeTrue()) - ackRanges := hist.GetAckRanges() - Expect(ackRanges).To(HaveLen(3)) - Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 10, Largest: 11})) - Expect(ackRanges[1]).To(Equal(wire.AckRange{Smallest: 4, Largest: 6})) - Expect(ackRanges[2]).To(Equal(wire.AckRange{Smallest: 1, Largest: 2})) - }) - }) - - Context("Getting the highest ACK range", func() { - It("returns the zero value if there are no ranges", func() { - Expect(hist.GetHighestAckRange()).To(BeZero()) - }) - - It("gets a single ACK range", func() { - Expect(hist.ReceivedPacket(4)).To(BeTrue()) - Expect(hist.ReceivedPacket(5)).To(BeTrue()) - Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) - }) - - It("gets the highest of multiple ACK ranges", func() { - Expect(hist.ReceivedPacket(3)).To(BeTrue()) - Expect(hist.ReceivedPacket(6)).To(BeTrue()) - Expect(hist.ReceivedPacket(7)).To(BeTrue()) - Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 6, Largest: 7})) - }) - }) - - Context("duplicate detection", func() { - It("doesn't declare the first packet a duplicate", func() { - Expect(hist.IsPotentiallyDuplicate(5)).To(BeFalse()) - }) - - It("detects a duplicate in a range", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) - hist.ReceivedPacket(6) - Expect(hist.IsPotentiallyDuplicate(3)).To(BeFalse()) - Expect(hist.IsPotentiallyDuplicate(4)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(5)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(6)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(7)).To(BeFalse()) - }) - - It("detects a duplicate in multiple ranges", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) - hist.ReceivedPacket(8) - hist.ReceivedPacket(9) - Expect(hist.IsPotentiallyDuplicate(3)).To(BeFalse()) - Expect(hist.IsPotentiallyDuplicate(4)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(5)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(6)).To(BeFalse()) - Expect(hist.IsPotentiallyDuplicate(7)).To(BeFalse()) - Expect(hist.IsPotentiallyDuplicate(8)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(9)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(10)).To(BeFalse()) - }) - - It("says a packet is a potentially duplicate if the ranges were already deleted", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) - hist.ReceivedPacket(8) - hist.ReceivedPacket(9) - hist.ReceivedPacket(11) - hist.DeleteBelow(8) - Expect(hist.IsPotentiallyDuplicate(3)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(4)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(5)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(6)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(7)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(8)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(9)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(10)).To(BeFalse()) - Expect(hist.IsPotentiallyDuplicate(11)).To(BeTrue()) - Expect(hist.IsPotentiallyDuplicate(12)).To(BeFalse()) - }) - }) - - Context("randomized receiving", func() { - It("receiving packets in a random order, with gaps", func() { - packets := make(map[protocol.PacketNumber]int) - // Make sure we never end up with more than protocol.MaxNumAckRanges ACK ranges, even - // when we're receiving packets in a random order. - const num = 2 * protocol.MaxNumAckRanges - numLostPackets := rand.Intn(protocol.MaxNumAckRanges) - numRcvdPackets := num - numLostPackets - - for i := 0; i < num; i++ { - packets[protocol.PacketNumber(i)] = 0 - } - lostPackets := make([]protocol.PacketNumber, 0, numLostPackets) - for len(lostPackets) < numLostPackets { - p := protocol.PacketNumber(rand.Intn(num)) - if _, ok := packets[p]; ok { - lostPackets = append(lostPackets, p) - delete(packets, p) - } - } - sort.Slice(lostPackets, func(i, j int) bool { return lostPackets[i] < lostPackets[j] }) - fmt.Fprintf(GinkgoWriter, "Losing packets: %v\n", lostPackets) - - ordered := make([]protocol.PacketNumber, 0, numRcvdPackets) - for p := range packets { - ordered = append(ordered, p) - } - rand.Shuffle(len(ordered), func(i, j int) { ordered[i], ordered[j] = ordered[j], ordered[i] }) - - fmt.Fprintf(GinkgoWriter, "Receiving packets: %v\n", ordered) - for i, p := range ordered { - Expect(hist.ReceivedPacket(p)).To(BeTrue()) - // sometimes receive a duplicate - if i > 0 && rand.Int()%5 == 0 { - Expect(hist.ReceivedPacket(ordered[rand.Intn(i)])).To(BeFalse()) - } - } - var counter int - ackRanges := hist.GetAckRanges() - fmt.Fprintf(GinkgoWriter, "ACK ranges: %v\n", ackRanges) - Expect(len(ackRanges)).To(BeNumerically("<=", numLostPackets+1)) - for _, ackRange := range ackRanges { - for p := ackRange.Smallest; p <= ackRange.Largest; p++ { - counter++ - Expect(packets).To(HaveKey(p)) - } - } - Expect(counter).To(Equal(numRcvdPackets)) - }) - }) -}) diff --git a/internal/quic-go/ackhandler/received_packet_tracker.go b/internal/quic-go/ackhandler/received_packet_tracker.go deleted file mode 100644 index 31882311..00000000 --- a/internal/quic-go/ackhandler/received_packet_tracker.go +++ /dev/null @@ -1,196 +0,0 @@ -package ackhandler - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// number of ack-eliciting packets received before sending an ack. -const packetsBeforeAck = 2 - -type receivedPacketTracker struct { - largestObserved protocol.PacketNumber - ignoreBelow protocol.PacketNumber - largestObservedReceivedTime time.Time - ect0, ect1, ecnce uint64 - - packetHistory *receivedPacketHistory - - maxAckDelay time.Duration - rttStats *utils.RTTStats - - hasNewAck bool // true as soon as we received an ack-eliciting new packet - ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets - - ackElicitingPacketsReceivedSinceLastAck int - ackAlarm time.Time - lastAck *wire.AckFrame - - logger utils.Logger - - version protocol.VersionNumber -} - -func newReceivedPacketTracker( - rttStats *utils.RTTStats, - logger utils.Logger, - version protocol.VersionNumber, -) *receivedPacketTracker { - return &receivedPacketTracker{ - packetHistory: newReceivedPacketHistory(), - maxAckDelay: protocol.MaxAckDelay, - rttStats: rttStats, - logger: logger, - version: version, - } -} - -func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) { - if packetNumber < h.ignoreBelow { - return - } - - isMissing := h.isMissing(packetNumber) - if packetNumber >= h.largestObserved { - h.largestObserved = packetNumber - h.largestObservedReceivedTime = rcvTime - } - - if isNew := h.packetHistory.ReceivedPacket(packetNumber); isNew && shouldInstigateAck { - h.hasNewAck = true - } - if shouldInstigateAck { - h.maybeQueueAck(packetNumber, rcvTime, isMissing) - } - switch ecn { - case protocol.ECNNon: - case protocol.ECT0: - h.ect0++ - case protocol.ECT1: - h.ect1++ - case protocol.ECNCE: - h.ecnce++ - } -} - -// IgnoreBelow sets a lower limit for acknowledging packets. -// Packets with packet numbers smaller than p will not be acked. -func (h *receivedPacketTracker) IgnoreBelow(p protocol.PacketNumber) { - if p <= h.ignoreBelow { - return - } - h.ignoreBelow = p - h.packetHistory.DeleteBelow(p) - if h.logger.Debug() { - h.logger.Debugf("\tIgnoring all packets below %d.", p) - } -} - -// isMissing says if a packet was reported missing in the last ACK. -func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool { - if h.lastAck == nil || p < h.ignoreBelow { - return false - } - return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p) -} - -func (h *receivedPacketTracker) hasNewMissingPackets() bool { - if h.lastAck == nil { - return false - } - highestRange := h.packetHistory.GetHighestAckRange() - return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1 -} - -// maybeQueueAck queues an ACK, if necessary. -func (h *receivedPacketTracker) maybeQueueAck(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) { - // always acknowledge the first packet - if h.lastAck == nil { - if !h.ackQueued { - h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") - } - h.ackQueued = true - return - } - - if h.ackQueued { - return - } - - h.ackElicitingPacketsReceivedSinceLastAck++ - - // Send an ACK if this packet was reported missing in an ACK sent before. - // Ack decimation with reordering relies on the timer to send an ACK, but if - // missing packets we reported in the previous ack, send an ACK immediately. - if wasMissing { - if h.logger.Debug() { - h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn) - } - h.ackQueued = true - } - - // send an ACK every 2 ack-eliciting packets - if h.ackElicitingPacketsReceivedSinceLastAck >= packetsBeforeAck { - if h.logger.Debug() { - h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck) - } - h.ackQueued = true - } else if h.ackAlarm.IsZero() { - if h.logger.Debug() { - h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay) - } - h.ackAlarm = rcvTime.Add(h.maxAckDelay) - } - - // Queue an ACK if there are new missing packets to report. - if h.hasNewMissingPackets() { - h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.") - h.ackQueued = true - } - - if h.ackQueued { - // cancel the ack alarm - h.ackAlarm = time.Time{} - } -} - -func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { - if !h.hasNewAck { - return nil - } - now := time.Now() - if onlyIfQueued { - if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) { - return nil - } - if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() { - h.logger.Debugf("Sending ACK because the ACK timer expired.") - } - } - - ack := &wire.AckFrame{ - AckRanges: h.packetHistory.GetAckRanges(), - // Make sure that the DelayTime is always positive. - // This is not guaranteed on systems that don't have a monotonic clock. - DelayTime: utils.MaxDuration(0, now.Sub(h.largestObservedReceivedTime)), - ECT0: h.ect0, - ECT1: h.ect1, - ECNCE: h.ecnce, - } - - h.lastAck = ack - h.ackAlarm = time.Time{} - h.ackQueued = false - h.hasNewAck = false - h.ackElicitingPacketsReceivedSinceLastAck = 0 - return ack -} - -func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm } - -func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool { - return h.packetHistory.IsPotentiallyDuplicate(pn) -} diff --git a/internal/quic-go/ackhandler/received_packet_tracker_test.go b/internal/quic-go/ackhandler/received_packet_tracker_test.go deleted file mode 100644 index 66b43cde..00000000 --- a/internal/quic-go/ackhandler/received_packet_tracker_test.go +++ /dev/null @@ -1,348 +0,0 @@ -package ackhandler - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Received Packet Tracker", func() { - var ( - tracker *receivedPacketTracker - rttStats *utils.RTTStats - ) - - BeforeEach(func() { - rttStats = &utils.RTTStats{} - tracker = newReceivedPacketTracker(rttStats, utils.DefaultLogger, protocol.VersionWhatever) - }) - - Context("accepting packets", func() { - It("saves the time when each packet arrived", func() { - tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, time.Now(), true) - Expect(tracker.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) - }) - - It("updates the largestObserved and the largestObservedReceivedTime", func() { - now := time.Now() - tracker.largestObserved = 3 - tracker.largestObservedReceivedTime = now.Add(-1 * time.Second) - tracker.ReceivedPacket(5, protocol.ECNNon, now, true) - Expect(tracker.largestObserved).To(Equal(protocol.PacketNumber(5))) - Expect(tracker.largestObservedReceivedTime).To(Equal(now)) - }) - - It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() { - now := time.Now() - timestamp := now.Add(-1 * time.Second) - tracker.largestObserved = 5 - tracker.largestObservedReceivedTime = timestamp - tracker.ReceivedPacket(4, protocol.ECNNon, now, true) - Expect(tracker.largestObserved).To(Equal(protocol.PacketNumber(5))) - Expect(tracker.largestObservedReceivedTime).To(Equal(timestamp)) - }) - }) - - Context("ACKs", func() { - Context("queueing ACKs", func() { - receiveAndAck10Packets := func() { - for i := 1; i <= 10; i++ { - tracker.ReceivedPacket(protocol.PacketNumber(i), protocol.ECNNon, time.Time{}, true) - } - Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) - Expect(tracker.ackQueued).To(BeFalse()) - } - - It("always queues an ACK for the first packet", func() { - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - Expect(tracker.ackQueued).To(BeTrue()) - Expect(tracker.GetAlarmTimeout()).To(BeZero()) - Expect(tracker.GetAckFrame(true).DelayTime).To(BeNumerically("~", 0, time.Second)) - }) - - It("works with packet number 0", func() { - tracker.ReceivedPacket(0, protocol.ECNNon, time.Now(), true) - Expect(tracker.ackQueued).To(BeTrue()) - Expect(tracker.GetAlarmTimeout()).To(BeZero()) - Expect(tracker.GetAckFrame(true).DelayTime).To(BeNumerically("~", 0, time.Second)) - }) - - It("sets ECN flags", func() { - tracker.ReceivedPacket(0, protocol.ECT0, time.Now(), true) - pn := protocol.PacketNumber(1) - for i := 0; i < 2; i++ { - tracker.ReceivedPacket(pn, protocol.ECT1, time.Now(), true) - pn++ - } - for i := 0; i < 3; i++ { - tracker.ReceivedPacket(pn, protocol.ECNCE, time.Now(), true) - pn++ - } - ack := tracker.GetAckFrame(false) - Expect(ack.ECT0).To(BeEquivalentTo(1)) - Expect(ack.ECT1).To(BeEquivalentTo(2)) - Expect(ack.ECNCE).To(BeEquivalentTo(3)) - }) - - It("queues an ACK for every second ack-eliciting packet", func() { - receiveAndAck10Packets() - p := protocol.PacketNumber(11) - for i := 0; i <= 20; i++ { - tracker.ReceivedPacket(p, protocol.ECNNon, time.Time{}, true) - Expect(tracker.ackQueued).To(BeFalse()) - p++ - tracker.ReceivedPacket(p, protocol.ECNNon, time.Time{}, true) - Expect(tracker.ackQueued).To(BeTrue()) - p++ - // dequeue the ACK frame - Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) - } - }) - - It("resets the counter when a non-queued ACK frame is generated", func() { - receiveAndAck10Packets() - rcvTime := time.Now() - tracker.ReceivedPacket(11, protocol.ECNNon, rcvTime, true) - Expect(tracker.GetAckFrame(false)).ToNot(BeNil()) - tracker.ReceivedPacket(12, protocol.ECNNon, rcvTime, true) - Expect(tracker.GetAckFrame(true)).To(BeNil()) - tracker.ReceivedPacket(13, protocol.ECNNon, rcvTime, true) - Expect(tracker.GetAckFrame(false)).ToNot(BeNil()) - }) - - It("only sets the timer when receiving a ack-eliciting packets", func() { - receiveAndAck10Packets() - tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), false) - Expect(tracker.ackQueued).To(BeFalse()) - Expect(tracker.GetAlarmTimeout()).To(BeZero()) - rcvTime := time.Now().Add(10 * time.Millisecond) - tracker.ReceivedPacket(12, protocol.ECNNon, rcvTime, true) - Expect(tracker.ackQueued).To(BeFalse()) - Expect(tracker.GetAlarmTimeout()).To(Equal(rcvTime.Add(protocol.MaxAckDelay))) - }) - - It("queues an ACK if it was reported missing before", func() { - receiveAndAck10Packets() - tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) - tracker.ReceivedPacket(13, protocol.ECNNon, time.Now(), true) - ack := tracker.GetAckFrame(true) // ACK: 1-11 and 13, missing: 12 - Expect(ack).ToNot(BeNil()) - Expect(ack.HasMissingRanges()).To(BeTrue()) - Expect(tracker.ackQueued).To(BeFalse()) - tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), true) - Expect(tracker.ackQueued).To(BeTrue()) - }) - - It("doesn't queue an ACK if it was reported missing before, but is below the threshold", func() { - receiveAndAck10Packets() - // 11 is missing - tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), true) - tracker.ReceivedPacket(13, protocol.ECNNon, time.Now(), true) - ack := tracker.GetAckFrame(true) // ACK: 1-10, 12-13 - Expect(ack).ToNot(BeNil()) - // now receive 11 - tracker.IgnoreBelow(12) - tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), false) - ack = tracker.GetAckFrame(true) - Expect(ack).To(BeNil()) - }) - - It("doesn't recognize in-order packets as out-of-order after raising the threshold", func() { - receiveAndAck10Packets() - Expect(tracker.lastAck.LargestAcked()).To(Equal(protocol.PacketNumber(10))) - Expect(tracker.ackQueued).To(BeFalse()) - tracker.IgnoreBelow(11) - tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) - Expect(tracker.GetAckFrame(true)).To(BeNil()) - }) - - It("recognizes out-of-order packets after raising the threshold", func() { - receiveAndAck10Packets() - Expect(tracker.lastAck.LargestAcked()).To(Equal(protocol.PacketNumber(10))) - Expect(tracker.ackQueued).To(BeFalse()) - tracker.IgnoreBelow(11) - tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), true) - ack := tracker.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.AckRanges).To(Equal([]wire.AckRange{{Smallest: 12, Largest: 12}})) - }) - - It("doesn't queue an ACK if for non-ack-eliciting packets arriving out-of-order", func() { - receiveAndAck10Packets() - tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) - Expect(tracker.GetAckFrame(true)).To(BeNil()) - tracker.ReceivedPacket(13, protocol.ECNNon, time.Now(), false) // receive a non-ack-eliciting packet out-of-order - Expect(tracker.GetAckFrame(true)).To(BeNil()) - }) - - It("doesn't queue an ACK if packets arrive out-of-order, but haven't been acknowledged yet", func() { - receiveAndAck10Packets() - Expect(tracker.lastAck).ToNot(BeNil()) - tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), false) - Expect(tracker.GetAckFrame(true)).To(BeNil()) - // 11 is received out-of-order, but this hasn't been reported in an ACK frame yet - tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) - Expect(tracker.GetAckFrame(true)).To(BeNil()) - }) - }) - - Context("ACK generation", func() { - It("generates an ACK for an ack-eliciting packet, if no ACK is queued yet", func() { - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - // The first packet is always acknowledged. - Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) - }) - - It("doesn't generate ACK for a non-ack-eliciting packet, if no ACK is queued yet", func() { - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - // The first packet is always acknowledged. - Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) - - tracker.ReceivedPacket(2, protocol.ECNNon, time.Now(), false) - Expect(tracker.GetAckFrame(false)).To(BeNil()) - tracker.ReceivedPacket(3, protocol.ECNNon, time.Now(), true) - ack := tracker.GetAckFrame(false) - Expect(ack).ToNot(BeNil()) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3))) - }) - - Context("for queued ACKs", func() { - BeforeEach(func() { - tracker.ackQueued = true - }) - - It("generates a simple ACK frame", func() { - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - tracker.ReceivedPacket(2, protocol.ECNNon, time.Now(), true) - ack := tracker.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) - Expect(ack.HasMissingRanges()).To(BeFalse()) - }) - - It("generates an ACK for packet number 0", func() { - tracker.ReceivedPacket(0, protocol.ECNNon, time.Now(), true) - ack := tracker.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(0))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0))) - Expect(ack.HasMissingRanges()).To(BeFalse()) - }) - - It("sets the delay time", func() { - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - tracker.ReceivedPacket(2, protocol.ECNNon, time.Now().Add(-1337*time.Millisecond), true) - ack := tracker.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.DelayTime).To(BeNumerically("~", 1337*time.Millisecond, 50*time.Millisecond)) - }) - - It("uses a 0 delay time if the delay would be negative", func() { - tracker.ReceivedPacket(0, protocol.ECNNon, time.Now().Add(time.Hour), true) - ack := tracker.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.DelayTime).To(BeZero()) - }) - - It("saves the last sent ACK", func() { - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - ack := tracker.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(tracker.lastAck).To(Equal(ack)) - tracker.ReceivedPacket(2, protocol.ECNNon, time.Now(), true) - tracker.ackQueued = true - ack = tracker.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(tracker.lastAck).To(Equal(ack)) - }) - - It("generates an ACK frame with missing packets", func() { - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - tracker.ReceivedPacket(4, protocol.ECNNon, time.Now(), true) - ack := tracker.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) - Expect(ack.AckRanges).To(Equal([]wire.AckRange{ - {Smallest: 4, Largest: 4}, - {Smallest: 1, Largest: 1}, - })) - }) - - It("generates an ACK for packet number 0 and other packets", func() { - tracker.ReceivedPacket(0, protocol.ECNNon, time.Now(), true) - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - tracker.ReceivedPacket(3, protocol.ECNNon, time.Now(), true) - ack := tracker.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0))) - Expect(ack.AckRanges).To(Equal([]wire.AckRange{ - {Smallest: 3, Largest: 3}, - {Smallest: 0, Largest: 1}, - })) - }) - - It("doesn't add delayed packets to the packetHistory", func() { - tracker.IgnoreBelow(7) - tracker.ReceivedPacket(4, protocol.ECNNon, time.Now(), true) - tracker.ReceivedPacket(10, protocol.ECNNon, time.Now(), true) - ack := tracker.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(10))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(10))) - }) - - It("deletes packets from the packetHistory when a lower limit is set", func() { - for i := 1; i <= 12; i++ { - tracker.ReceivedPacket(protocol.PacketNumber(i), protocol.ECNNon, time.Now(), true) - } - tracker.IgnoreBelow(7) - // check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame - ack := tracker.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(12))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(7))) - Expect(ack.HasMissingRanges()).To(BeFalse()) - }) - - It("resets all counters needed for the ACK queueing decision when sending an ACK", func() { - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - tracker.ackAlarm = time.Now().Add(-time.Minute) - Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) - Expect(tracker.GetAlarmTimeout()).To(BeZero()) - Expect(tracker.ackElicitingPacketsReceivedSinceLastAck).To(BeZero()) - Expect(tracker.ackQueued).To(BeFalse()) - }) - - It("doesn't generate an ACK when none is queued and the timer is not set", func() { - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - tracker.ackQueued = false - tracker.ackAlarm = time.Time{} - Expect(tracker.GetAckFrame(true)).To(BeNil()) - }) - - It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() { - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - tracker.ackQueued = false - tracker.ackAlarm = time.Now().Add(time.Minute) - Expect(tracker.GetAckFrame(true)).To(BeNil()) - }) - - It("generates an ACK when the timer has expired", func() { - tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) - tracker.ackQueued = false - tracker.ackAlarm = time.Now().Add(-time.Minute) - Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) - }) - }) - }) - }) -}) diff --git a/internal/quic-go/ackhandler/send_mode.go b/internal/quic-go/ackhandler/send_mode.go deleted file mode 100644 index 3d5fe560..00000000 --- a/internal/quic-go/ackhandler/send_mode.go +++ /dev/null @@ -1,40 +0,0 @@ -package ackhandler - -import "fmt" - -// The SendMode says what kind of packets can be sent. -type SendMode uint8 - -const ( - // SendNone means that no packets should be sent - SendNone SendMode = iota - // SendAck means an ACK-only packet should be sent - SendAck - // SendPTOInitial means that an Initial probe packet should be sent - SendPTOInitial - // SendPTOHandshake means that a Handshake probe packet should be sent - SendPTOHandshake - // SendPTOAppData means that an Application data probe packet should be sent - SendPTOAppData - // SendAny means that any packet should be sent - SendAny -) - -func (s SendMode) String() string { - switch s { - case SendNone: - return "none" - case SendAck: - return "ack" - case SendPTOInitial: - return "pto (Initial)" - case SendPTOHandshake: - return "pto (Handshake)" - case SendPTOAppData: - return "pto (Application Data)" - case SendAny: - return "any" - default: - return fmt.Sprintf("invalid send mode: %d", s) - } -} diff --git a/internal/quic-go/ackhandler/send_mode_test.go b/internal/quic-go/ackhandler/send_mode_test.go deleted file mode 100644 index 86515d74..00000000 --- a/internal/quic-go/ackhandler/send_mode_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package ackhandler - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Send Mode", func() { - It("has a string representation", func() { - Expect(SendNone.String()).To(Equal("none")) - Expect(SendAny.String()).To(Equal("any")) - Expect(SendAck.String()).To(Equal("ack")) - Expect(SendPTOInitial.String()).To(Equal("pto (Initial)")) - Expect(SendPTOHandshake.String()).To(Equal("pto (Handshake)")) - Expect(SendPTOAppData.String()).To(Equal("pto (Application Data)")) - Expect(SendMode(123).String()).To(Equal("invalid send mode: 123")) - }) -}) diff --git a/internal/quic-go/ackhandler/sent_packet_handler.go b/internal/quic-go/ackhandler/sent_packet_handler.go deleted file mode 100644 index 2a6b19b8..00000000 --- a/internal/quic-go/ackhandler/sent_packet_handler.go +++ /dev/null @@ -1,838 +0,0 @@ -package ackhandler - -import ( - "errors" - "fmt" - "time" - - "github.com/imroc/req/v3/internal/quic-go/congestion" - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -const ( - // Maximum reordering in time space before time based loss detection considers a packet lost. - // Specified as an RTT multiplier. - timeThreshold = 9.0 / 8 - // Maximum reordering in packets before packet threshold loss detection considers a packet lost. - packetThreshold = 3 - // Before validating the client's address, the server won't send more than 3x bytes than it received. - amplificationFactor = 3 - // We use Retry packets to derive an RTT estimate. Make sure we don't set the RTT to a super low value yet. - minRTTAfterRetry = 5 * time.Millisecond -) - -type packetNumberSpace struct { - history *sentPacketHistory - pns packetNumberGenerator - - lossTime time.Time - lastAckElicitingPacketTime time.Time - - largestAcked protocol.PacketNumber - largestSent protocol.PacketNumber -} - -func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStats *utils.RTTStats) *packetNumberSpace { - var pns packetNumberGenerator - if skipPNs { - pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod) - } else { - pns = newSequentialPacketNumberGenerator(initialPN) - } - return &packetNumberSpace{ - history: newSentPacketHistory(rttStats), - pns: pns, - largestSent: protocol.InvalidPacketNumber, - largestAcked: protocol.InvalidPacketNumber, - } -} - -type sentPacketHandler struct { - initialPackets *packetNumberSpace - handshakePackets *packetNumberSpace - appDataPackets *packetNumberSpace - - // Do we know that the peer completed address validation yet? - // Always true for the server. - peerCompletedAddressValidation bool - bytesReceived protocol.ByteCount - bytesSent protocol.ByteCount - // Have we validated the peer's address yet? - // Always true for the client. - peerAddressValidated bool - - handshakeConfirmed bool - - // lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived - // example: we send an ACK for packets 90-100 with packet number 20 - // once we receive an ACK from the peer for packet 20, the lowestNotConfirmedAcked is 101 - // Only applies to the application-data packet number space. - lowestNotConfirmedAcked protocol.PacketNumber - - ackedPackets []*Packet // to avoid allocations in detectAndRemoveAckedPackets - - bytesInFlight protocol.ByteCount - - congestion congestion.SendAlgorithmWithDebugInfos - rttStats *utils.RTTStats - - // The number of times a PTO has been sent without receiving an ack. - ptoCount uint32 - ptoMode SendMode - // The number of PTO probe packets that should be sent. - // Only applies to the application-data packet number space. - numProbesToSend int - - // The alarm timeout - alarm time.Time - - perspective protocol.Perspective - - tracer logging.ConnectionTracer - logger utils.Logger -} - -var ( - _ SentPacketHandler = &sentPacketHandler{} - _ sentPacketTracker = &sentPacketHandler{} -) - -func newSentPacketHandler( - initialPN protocol.PacketNumber, - initialMaxDatagramSize protocol.ByteCount, - rttStats *utils.RTTStats, - pers protocol.Perspective, - tracer logging.ConnectionTracer, - logger utils.Logger, -) *sentPacketHandler { - congestion := congestion.NewCubicSender( - congestion.DefaultClock{}, - rttStats, - initialMaxDatagramSize, - true, // use Reno - tracer, - ) - - return &sentPacketHandler{ - peerCompletedAddressValidation: pers == protocol.PerspectiveServer, - peerAddressValidated: pers == protocol.PerspectiveClient, - initialPackets: newPacketNumberSpace(initialPN, false, rttStats), - handshakePackets: newPacketNumberSpace(0, false, rttStats), - appDataPackets: newPacketNumberSpace(0, true, rttStats), - rttStats: rttStats, - congestion: congestion, - perspective: pers, - tracer: tracer, - logger: logger, - } -} - -func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { - if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionInitial { - // This function is called when the crypto setup seals a Handshake packet. - // If this Handshake packet is coalesced behind an Initial packet, we would drop the Initial packet number space - // before SentPacket() was called for that Initial packet. - return - } - h.dropPackets(encLevel) -} - -func (h *sentPacketHandler) removeFromBytesInFlight(p *Packet) { - if p.includedInBytesInFlight { - if p.Length > h.bytesInFlight { - panic("negative bytes_in_flight") - } - h.bytesInFlight -= p.Length - p.includedInBytesInFlight = false - } -} - -func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { - // The server won't await address validation after the handshake is confirmed. - // This applies even if we didn't receive an ACK for a Handshake packet. - if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake { - h.peerCompletedAddressValidation = true - } - // remove outstanding packets from bytes_in_flight - if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake { - pnSpace := h.getPacketNumberSpace(encLevel) - pnSpace.history.Iterate(func(p *Packet) (bool, error) { - h.removeFromBytesInFlight(p) - return true, nil - }) - } - // drop the packet history - //nolint:exhaustive // Not every packet number space can be dropped. - switch encLevel { - case protocol.EncryptionInitial: - h.initialPackets = nil - case protocol.EncryptionHandshake: - h.handshakePackets = nil - case protocol.Encryption0RTT: - // This function is only called when 0-RTT is rejected, - // and not when the client drops 0-RTT keys when the handshake completes. - // When 0-RTT is rejected, all application data sent so far becomes invalid. - // Delete the packets from the history and remove them from bytes_in_flight. - h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { - if p.EncryptionLevel != protocol.Encryption0RTT { - return false, nil - } - h.removeFromBytesInFlight(p) - h.appDataPackets.history.Remove(p.PacketNumber) - return true, nil - }) - default: - panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) - } - if h.tracer != nil && h.ptoCount != 0 { - h.tracer.UpdatedPTOCount(0) - } - h.ptoCount = 0 - h.numProbesToSend = 0 - h.ptoMode = SendNone - h.setLossDetectionTimer() -} - -func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) { - wasAmplificationLimit := h.isAmplificationLimited() - h.bytesReceived += n - if wasAmplificationLimit && !h.isAmplificationLimited() { - h.setLossDetectionTimer() - } -} - -func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) { - if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated { - h.peerAddressValidated = true - h.setLossDetectionTimer() - } -} - -func (h *sentPacketHandler) packetsInFlight() int { - packetsInFlight := h.appDataPackets.history.Len() - if h.handshakePackets != nil { - packetsInFlight += h.handshakePackets.history.Len() - } - if h.initialPackets != nil { - packetsInFlight += h.initialPackets.history.Len() - } - return packetsInFlight -} - -func (h *sentPacketHandler) SentPacket(packet *Packet) { - h.bytesSent += packet.Length - // For the client, drop the Initial packet number space when the first Handshake packet is sent. - if h.perspective == protocol.PerspectiveClient && packet.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { - h.dropPackets(protocol.EncryptionInitial) - } - isAckEliciting := h.sentPacketImpl(packet) - h.getPacketNumberSpace(packet.EncryptionLevel).history.SentPacket(packet, isAckEliciting) - if h.tracer != nil && isAckEliciting { - h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) - } - if isAckEliciting || !h.peerCompletedAddressValidation { - h.setLossDetectionTimer() - } -} - -func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace { - switch encLevel { - case protocol.EncryptionInitial: - return h.initialPackets - case protocol.EncryptionHandshake: - return h.handshakePackets - case protocol.Encryption0RTT, protocol.Encryption1RTT: - return h.appDataPackets - default: - panic("invalid packet number space") - } -} - -func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-eliciting */ { - pnSpace := h.getPacketNumberSpace(packet.EncryptionLevel) - - if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() { - for p := utils.MaxPacketNumber(0, pnSpace.largestSent+1); p < packet.PacketNumber; p++ { - h.logger.Debugf("Skipping packet number %d", p) - } - } - - pnSpace.largestSent = packet.PacketNumber - isAckEliciting := len(packet.Frames) > 0 - - if isAckEliciting { - pnSpace.lastAckElicitingPacketTime = packet.SendTime - packet.includedInBytesInFlight = true - h.bytesInFlight += packet.Length - if h.numProbesToSend > 0 { - h.numProbesToSend-- - } - } - h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isAckEliciting) - - return isAckEliciting -} - -func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) { - pnSpace := h.getPacketNumberSpace(encLevel) - - largestAcked := ack.LargestAcked() - if largestAcked > pnSpace.largestSent { - return false, &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "received ACK for an unsent packet", - } - } - - pnSpace.largestAcked = utils.MaxPacketNumber(pnSpace.largestAcked, largestAcked) - - // Servers complete address validation when a protected packet is received. - if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation && - (encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) { - h.peerCompletedAddressValidation = true - h.logger.Debugf("Peer doesn't await address validation any longer.") - // Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets. - h.setLossDetectionTimer() - } - - priorInFlight := h.bytesInFlight - ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel) - if err != nil || len(ackedPackets) == 0 { - return false, err - } - // update the RTT, if the largest acked is newly acknowledged - if len(ackedPackets) > 0 { - if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() { - // don't use the ack delay for Initial and Handshake packets - var ackDelay time.Duration - if encLevel == protocol.Encryption1RTT { - ackDelay = utils.MinDuration(ack.DelayTime, h.rttStats.MaxAckDelay()) - } - h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime) - if h.logger.Debug() { - h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) - } - h.congestion.MaybeExitSlowStart() - } - } - if err := h.detectLostPackets(rcvTime, encLevel); err != nil { - return false, err - } - var acked1RTTPacket bool - for _, p := range ackedPackets { - if p.includedInBytesInFlight && !p.declaredLost { - h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) - } - if p.EncryptionLevel == protocol.Encryption1RTT { - acked1RTTPacket = true - } - h.removeFromBytesInFlight(p) - } - - // Reset the pto_count unless the client is unsure if the server has validated the client's address. - if h.peerCompletedAddressValidation { - if h.tracer != nil && h.ptoCount != 0 { - h.tracer.UpdatedPTOCount(0) - } - h.ptoCount = 0 - } - h.numProbesToSend = 0 - - if h.tracer != nil { - h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) - } - - pnSpace.history.DeleteOldPackets(rcvTime) - h.setLossDetectionTimer() - return acked1RTTPacket, nil -} - -func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { - return h.lowestNotConfirmedAcked -} - -// Packets are returned in ascending packet number order. -func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) { - pnSpace := h.getPacketNumberSpace(encLevel) - h.ackedPackets = h.ackedPackets[:0] - ackRangeIndex := 0 - lowestAcked := ack.LowestAcked() - largestAcked := ack.LargestAcked() - err := pnSpace.history.Iterate(func(p *Packet) (bool, error) { - // Ignore packets below the lowest acked - if p.PacketNumber < lowestAcked { - return true, nil - } - // Break after largest acked is reached - if p.PacketNumber > largestAcked { - return false, nil - } - - if ack.HasMissingRanges() { - ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex] - - for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 { - ackRangeIndex++ - ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex] - } - - if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range - return true, nil - } - if p.PacketNumber > ackRange.Largest { - return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest) - } - } - if p.skippedPacket { - return false, &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel), - } - } - h.ackedPackets = append(h.ackedPackets, p) - return true, nil - }) - if h.logger.Debug() && len(h.ackedPackets) > 0 { - pns := make([]protocol.PacketNumber, len(h.ackedPackets)) - for i, p := range h.ackedPackets { - pns[i] = p.PacketNumber - } - h.logger.Debugf("\tnewly acked packets (%d): %d", len(pns), pns) - } - - for _, p := range h.ackedPackets { - if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { - h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1) - } - - for _, f := range p.Frames { - if f.OnAcked != nil { - f.OnAcked(f.Frame) - } - } - if err := pnSpace.history.Remove(p.PacketNumber); err != nil { - return nil, err - } - if h.tracer != nil { - h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber) - } - } - - return h.ackedPackets, err -} - -func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) { - var encLevel protocol.EncryptionLevel - var lossTime time.Time - - if h.initialPackets != nil { - lossTime = h.initialPackets.lossTime - encLevel = protocol.EncryptionInitial - } - if h.handshakePackets != nil && (lossTime.IsZero() || (!h.handshakePackets.lossTime.IsZero() && h.handshakePackets.lossTime.Before(lossTime))) { - lossTime = h.handshakePackets.lossTime - encLevel = protocol.EncryptionHandshake - } - if lossTime.IsZero() || (!h.appDataPackets.lossTime.IsZero() && h.appDataPackets.lossTime.Before(lossTime)) { - lossTime = h.appDataPackets.lossTime - encLevel = protocol.Encryption1RTT - } - return lossTime, encLevel -} - -// same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime -func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) { - // We only send application data probe packets once the handshake is confirmed, - // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. - if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() { - if h.peerCompletedAddressValidation { - return - } - t := time.Now().Add(h.rttStats.PTO(false) << h.ptoCount) - if h.initialPackets != nil { - return t, protocol.EncryptionInitial, true - } - return t, protocol.EncryptionHandshake, true - } - - if h.initialPackets != nil { - encLevel = protocol.EncryptionInitial - if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() { - pto = t.Add(h.rttStats.PTO(false) << h.ptoCount) - } - } - if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() { - t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(false) << h.ptoCount) - if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { - pto = t - encLevel = protocol.EncryptionHandshake - } - } - if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { - t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount) - if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { - pto = t - encLevel = protocol.Encryption1RTT - } - } - return pto, encLevel, true -} - -func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { - if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() { - return true - } - if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() { - return true - } - return false -} - -func (h *sentPacketHandler) hasOutstandingPackets() bool { - return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets() -} - -func (h *sentPacketHandler) setLossDetectionTimer() { - oldAlarm := h.alarm // only needed in case tracing is enabled - lossTime, encLevel := h.getLossTimeAndSpace() - if !lossTime.IsZero() { - // Early retransmit timer or time loss detection. - h.alarm = lossTime - if h.tracer != nil && h.alarm != oldAlarm { - h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm) - } - return - } - - // Cancel the alarm if amplification limited. - if h.isAmplificationLimited() { - h.alarm = time.Time{} - if !oldAlarm.IsZero() { - h.logger.Debugf("Canceling loss detection timer. Amplification limited.") - if h.tracer != nil { - h.tracer.LossTimerCanceled() - } - } - return - } - - // Cancel the alarm if no packets are outstanding - if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation { - h.alarm = time.Time{} - if !oldAlarm.IsZero() { - h.logger.Debugf("Canceling loss detection timer. No packets in flight.") - if h.tracer != nil { - h.tracer.LossTimerCanceled() - } - } - return - } - - // PTO alarm - ptoTime, encLevel, ok := h.getPTOTimeAndSpace() - if !ok { - if !oldAlarm.IsZero() { - h.alarm = time.Time{} - h.logger.Debugf("Canceling loss detection timer. No PTO needed..") - if h.tracer != nil { - h.tracer.LossTimerCanceled() - } - } - return - } - h.alarm = ptoTime - if h.tracer != nil && h.alarm != oldAlarm { - h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm) - } -} - -func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error { - pnSpace := h.getPacketNumberSpace(encLevel) - pnSpace.lossTime = time.Time{} - - maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) - lossDelay := time.Duration(timeThreshold * maxRTT) - - // Minimum time of granularity before packets are deemed lost. - lossDelay = utils.MaxDuration(lossDelay, protocol.TimerGranularity) - - // Packets sent before this time are deemed lost. - lostSendTime := now.Add(-lossDelay) - - priorInFlight := h.bytesInFlight - return pnSpace.history.Iterate(func(p *Packet) (bool, error) { - if p.PacketNumber > pnSpace.largestAcked { - return false, nil - } - if p.declaredLost || p.skippedPacket { - return true, nil - } - - var packetLost bool - if p.SendTime.Before(lostSendTime) { - packetLost = true - if h.logger.Debug() { - h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) - } - if h.tracer != nil { - h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) - } - } else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold { - packetLost = true - if h.logger.Debug() { - h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) - } - if h.tracer != nil { - h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) - } - } else if pnSpace.lossTime.IsZero() { - // Note: This conditional is only entered once per call - lossTime := p.SendTime.Add(lossDelay) - if h.logger.Debug() { - h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", p.PacketNumber, encLevel, lossDelay, lossTime) - } - pnSpace.lossTime = lossTime - } - if packetLost { - p.declaredLost = true - // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted - h.removeFromBytesInFlight(p) - h.queueFramesForRetransmission(p) - if !p.IsPathMTUProbePacket { - h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) - } - } - return true, nil - }) -} - -func (h *sentPacketHandler) OnLossDetectionTimeout() error { - defer h.setLossDetectionTimer() - earliestLossTime, encLevel := h.getLossTimeAndSpace() - if !earliestLossTime.IsZero() { - if h.logger.Debug() { - h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) - } - if h.tracer != nil { - h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) - } - // Early retransmit or time loss detection - return h.detectLostPackets(time.Now(), encLevel) - } - - // PTO - // When all outstanding are acknowledged, the alarm is canceled in - // setLossDetectionTimer. This doesn't reset the timer in the session though. - // When OnAlarm is called, we therefore need to make sure that there are - // actually packets outstanding. - if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation { - h.ptoCount++ - h.numProbesToSend++ - if h.initialPackets != nil { - h.ptoMode = SendPTOInitial - } else if h.handshakePackets != nil { - h.ptoMode = SendPTOHandshake - } else { - return errors.New("sentPacketHandler BUG: PTO fired, but bytes_in_flight is 0 and Initial and Handshake already dropped") - } - return nil - } - - _, encLevel, ok := h.getPTOTimeAndSpace() - if !ok { - return nil - } - if ps := h.getPacketNumberSpace(encLevel); !ps.history.HasOutstandingPackets() && !h.peerCompletedAddressValidation { - return nil - } - h.ptoCount++ - if h.logger.Debug() { - h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) - } - if h.tracer != nil { - h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) - h.tracer.UpdatedPTOCount(h.ptoCount) - } - h.numProbesToSend += 2 - //nolint:exhaustive // We never arm a PTO timer for 0-RTT packets. - switch encLevel { - case protocol.EncryptionInitial: - h.ptoMode = SendPTOInitial - case protocol.EncryptionHandshake: - h.ptoMode = SendPTOHandshake - case protocol.Encryption1RTT: - // skip a packet number in order to elicit an immediate ACK - _ = h.PopPacketNumber(protocol.Encryption1RTT) - h.ptoMode = SendPTOAppData - default: - return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel) - } - return nil -} - -func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { - return h.alarm -} - -func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { - pnSpace := h.getPacketNumberSpace(encLevel) - - var lowestUnacked protocol.PacketNumber - if p := pnSpace.history.FirstOutstanding(); p != nil { - lowestUnacked = p.PacketNumber - } else { - lowestUnacked = pnSpace.largestAcked + 1 - } - - pn := pnSpace.pns.Peek() - return pn, protocol.GetPacketNumberLengthForHeader(pn, lowestUnacked) -} - -func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber { - return h.getPacketNumberSpace(encLevel).pns.Pop() -} - -func (h *sentPacketHandler) SendMode() SendMode { - numTrackedPackets := h.appDataPackets.history.Len() - if h.initialPackets != nil { - numTrackedPackets += h.initialPackets.history.Len() - } - if h.handshakePackets != nil { - numTrackedPackets += h.handshakePackets.history.Len() - } - - if h.isAmplificationLimited() { - h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent) - return SendNone - } - // Don't send any packets if we're keeping track of the maximum number of packets. - // Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets, - // we will stop sending out new data when reaching MaxOutstandingSentPackets, - // but still allow sending of retransmissions and ACKs. - if numTrackedPackets >= protocol.MaxTrackedSentPackets { - if h.logger.Debug() { - h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets) - } - return SendNone - } - if h.numProbesToSend > 0 { - return h.ptoMode - } - // Only send ACKs if we're congestion limited. - if !h.congestion.CanSend(h.bytesInFlight) { - if h.logger.Debug() { - h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, h.congestion.GetCongestionWindow()) - } - return SendAck - } - if numTrackedPackets >= protocol.MaxOutstandingSentPackets { - if h.logger.Debug() { - h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets) - } - return SendAck - } - return SendAny -} - -func (h *sentPacketHandler) TimeUntilSend() time.Time { - return h.congestion.TimeUntilSend(h.bytesInFlight) -} - -func (h *sentPacketHandler) HasPacingBudget() bool { - return h.congestion.HasPacingBudget() -} - -func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) { - h.congestion.SetMaxDatagramSize(s) -} - -func (h *sentPacketHandler) isAmplificationLimited() bool { - if h.peerAddressValidated { - return false - } - return h.bytesSent >= amplificationFactor*h.bytesReceived -} - -func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool { - pnSpace := h.getPacketNumberSpace(encLevel) - p := pnSpace.history.FirstOutstanding() - if p == nil { - return false - } - h.queueFramesForRetransmission(p) - // TODO: don't declare the packet lost here. - // Keep track of acknowledged frames instead. - h.removeFromBytesInFlight(p) - p.declaredLost = true - return true -} - -func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) { - if len(p.Frames) == 0 { - panic("no frames") - } - for _, f := range p.Frames { - f.OnLost(f.Frame) - } - p.Frames = nil -} - -func (h *sentPacketHandler) ResetForRetry() error { - h.bytesInFlight = 0 - var firstPacketSendTime time.Time - h.initialPackets.history.Iterate(func(p *Packet) (bool, error) { - if firstPacketSendTime.IsZero() { - firstPacketSendTime = p.SendTime - } - if p.declaredLost || p.skippedPacket { - return true, nil - } - h.queueFramesForRetransmission(p) - return true, nil - }) - // All application data packets sent at this point are 0-RTT packets. - // In the case of a Retry, we can assume that the server dropped all of them. - h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { - if !p.declaredLost && !p.skippedPacket { - h.queueFramesForRetransmission(p) - } - return true, nil - }) - - // Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial. - // Otherwise, we don't know which Initial the Retry was sent in response to. - if h.ptoCount == 0 { - // Don't set the RTT to a value lower than 5ms here. - now := time.Now() - h.rttStats.UpdateRTT(utils.MaxDuration(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) - if h.logger.Debug() { - h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) - } - if h.tracer != nil { - h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) - } - } - h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), false, h.rttStats) - h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), true, h.rttStats) - oldAlarm := h.alarm - h.alarm = time.Time{} - if h.tracer != nil { - h.tracer.UpdatedPTOCount(0) - if !oldAlarm.IsZero() { - h.tracer.LossTimerCanceled() - } - } - h.ptoCount = 0 - return nil -} - -func (h *sentPacketHandler) SetHandshakeConfirmed() { - h.handshakeConfirmed = true - // We don't send PTOs for application data packets before the handshake completes. - // Make sure the timer is armed now, if necessary. - h.setLossDetectionTimer() -} diff --git a/internal/quic-go/ackhandler/sent_packet_handler_test.go b/internal/quic-go/ackhandler/sent_packet_handler_test.go deleted file mode 100644 index e7a19250..00000000 --- a/internal/quic-go/ackhandler/sent_packet_handler_test.go +++ /dev/null @@ -1,1386 +0,0 @@ -package ackhandler - -import ( - "fmt" - "time" - - "github.com/golang/mock/gomock" - - "github.com/imroc/req/v3/internal/quic-go/mocks" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("SentPacketHandler", func() { - var ( - handler *sentPacketHandler - streamFrame wire.StreamFrame - lostPackets []protocol.PacketNumber - perspective protocol.Perspective - ) - - BeforeEach(func() { perspective = protocol.PerspectiveServer }) - - JustBeforeEach(func() { - lostPackets = nil - rttStats := utils.NewRTTStats() - handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, perspective, nil, utils.DefaultLogger) - streamFrame = wire.StreamFrame{ - StreamID: 5, - Data: []byte{0x13, 0x37}, - } - }) - - getPacket := func(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) *Packet { - if el, ok := handler.getPacketNumberSpace(encLevel).history.packetMap[pn]; ok { - return &el.Value - } - return nil - } - - ackElicitingPacket := func(p *Packet) *Packet { - if p.EncryptionLevel == 0 { - p.EncryptionLevel = protocol.Encryption1RTT - } - if p.Length == 0 { - p.Length = 1 - } - if p.SendTime.IsZero() { - p.SendTime = time.Now() - } - if len(p.Frames) == 0 { - p.Frames = []Frame{ - {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, p.PacketNumber) }}, - } - } - return p - } - - nonAckElicitingPacket := func(p *Packet) *Packet { - p = ackElicitingPacket(p) - p.Frames = nil - p.LargestAcked = 1 - return p - } - - initialPacket := func(p *Packet) *Packet { - p = ackElicitingPacket(p) - p.EncryptionLevel = protocol.EncryptionInitial - return p - } - - handshakePacket := func(p *Packet) *Packet { - p = ackElicitingPacket(p) - p.EncryptionLevel = protocol.EncryptionHandshake - return p - } - - handshakePacketNonAckEliciting := func(p *Packet) *Packet { - p = nonAckElicitingPacket(p) - p.EncryptionLevel = protocol.EncryptionHandshake - return p - } - - expectInPacketHistory := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { - pnSpace := handler.getPacketNumberSpace(encLevel) - var length int - pnSpace.history.Iterate(func(p *Packet) (bool, error) { - if !p.declaredLost && !p.skippedPacket { - length++ - } - return true, nil - }) - ExpectWithOffset(1, length).To(Equal(len(expected))) - for _, p := range expected { - ExpectWithOffset(2, pnSpace.history.packetMap).To(HaveKey(p)) - } - } - - updateRTT := func(rtt time.Duration) { - handler.rttStats.UpdateRTT(rtt, 0, time.Now()) - ExpectWithOffset(1, handler.rttStats.SmoothedRTT()).To(Equal(rtt)) - } - - Context("registering sent packets", func() { - It("accepts two consecutive packets", func() { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, EncryptionLevel: protocol.EncryptionHandshake})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, EncryptionLevel: protocol.EncryptionHandshake})) - Expect(handler.handshakePackets.largestSent).To(Equal(protocol.PacketNumber(2))) - expectInPacketHistory([]protocol.PacketNumber{1, 2}, protocol.EncryptionHandshake) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) - }) - - It("uses the same packet number space for 0-RTT and 1-RTT packets", func() { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, EncryptionLevel: protocol.Encryption0RTT})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, EncryptionLevel: protocol.Encryption1RTT})) - Expect(handler.appDataPackets.largestSent).To(Equal(protocol.PacketNumber(2))) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) - }) - - It("accepts packet number 0", func() { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 0, EncryptionLevel: protocol.Encryption1RTT})) - Expect(handler.appDataPackets.largestSent).To(BeZero()) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, EncryptionLevel: protocol.Encryption1RTT})) - Expect(handler.appDataPackets.largestSent).To(Equal(protocol.PacketNumber(1))) - expectInPacketHistory([]protocol.PacketNumber{0, 1}, protocol.Encryption1RTT) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) - }) - - It("stores the sent time", func() { - sendTime := time.Now().Add(-time.Minute) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) - Expect(handler.appDataPackets.lastAckElicitingPacketTime).To(Equal(sendTime)) - }) - - It("stores the sent time of Initial packets", func() { - sendTime := time.Now().Add(-time.Minute) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime, EncryptionLevel: protocol.EncryptionInitial})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: sendTime.Add(time.Hour), EncryptionLevel: protocol.Encryption1RTT})) - Expect(handler.initialPackets.lastAckElicitingPacketTime).To(Equal(sendTime)) - }) - }) - - Context("ACK processing", func() { - JustBeforeEach(func() { - for i := protocol.PacketNumber(0); i < 10; i++ { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) - } - // Increase RTT, because the tests would be flaky otherwise - updateRTT(time.Hour) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) - }) - - Context("ACK processing", func() { - It("accepts ACKs sent in packet 0", func() { - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 5}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(5))) - }) - - It("says if a 1-RTT packet was acknowledged", func() { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 100, EncryptionLevel: protocol.Encryption0RTT})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 101, EncryptionLevel: protocol.Encryption0RTT})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 102, EncryptionLevel: protocol.Encryption1RTT})) - acked1RTT, err := handler.ReceivedAck( - &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 100, Largest: 101}}}, - protocol.Encryption1RTT, - time.Now(), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(acked1RTT).To(BeFalse()) - acked1RTT, err = handler.ReceivedAck( - &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 101, Largest: 102}}}, - protocol.Encryption1RTT, - time.Now(), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(acked1RTT).To(BeTrue()) - }) - - It("accepts multiple ACKs sent in the same packet", func() { - ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 3}}} - ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 4}}} - _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(3))) - // this wouldn't happen in practice - // for testing purposes, we pretend send a different ACK frame in a duplicated packet, to be able to verify that it actually doesn't get processed - _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(4))) - }) - - It("rejects ACKs that acknowledge a skipped packet number", func() { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 100})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 102})) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 100, Largest: 102}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "received an ACK for skipped packet number: 101 (1-RTT)", - })) - }) - - It("rejects ACKs with a too high LargestAcked packet number", func() { - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 9999}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "received ACK for an unsent packet", - })) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) - }) - - It("ignores repeated ACKs", func() { - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 3}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) - _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(3))) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) - }) - }) - - Context("acks the right packets", func() { - expectInPacketHistoryOrLost := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { - pnSpace := handler.getPacketNumberSpace(encLevel) - var length int - pnSpace.history.Iterate(func(p *Packet) (bool, error) { - if !p.declaredLost { - length++ - } - return true, nil - }) - ExpectWithOffset(1, length+len(lostPackets)).To(Equal(len(expected))) - expectedLoop: - for _, p := range expected { - if _, ok := pnSpace.history.packetMap[p]; ok { - continue - } - for _, lostP := range lostPackets { - if lostP == p { - continue expectedLoop - } - } - Fail(fmt.Sprintf("Packet %d not in packet history.", p)) - } - } - - It("adjusts the LargestAcked, and adjusts the bytes in flight", func() { - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 5}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(5))) - expectInPacketHistoryOrLost([]protocol.PacketNumber{6, 7, 8, 9}, protocol.Encryption1RTT) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(4))) - }) - - It("acks packet 0", func() { - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 0}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(getPacket(0, protocol.Encryption1RTT)).To(BeNil()) - expectInPacketHistoryOrLost([]protocol.PacketNumber{1, 2, 3, 4, 5, 6, 7, 8, 9}, protocol.Encryption1RTT) - }) - - It("calls the OnAcked callback", func() { - var acked bool - ping := &wire.PingFrame{} - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 13, - Frames: []Frame{{ - Frame: ping, OnAcked: func(f wire.Frame) { - Expect(f).To(Equal(ping)) - acked = true - }, - }}, - })) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(acked).To(BeTrue()) - }) - - It("handles an ACK frame with one missing packet range", func() { - ack := &wire.AckFrame{ // lose 4 and 5 - AckRanges: []wire.AckRange{ - {Smallest: 6, Largest: 9}, - {Smallest: 1, Largest: 3}, - }, - } - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 4, 5}, protocol.Encryption1RTT) - }) - - It("does not ack packets below the LowestAcked", func() { - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 8}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 1, 2, 9}, protocol.Encryption1RTT) - }) - - It("handles an ACK with multiple missing packet ranges", func() { - ack := &wire.AckFrame{ // packets 2, 4 and 5, and 8 were lost - AckRanges: []wire.AckRange{ - {Smallest: 9, Largest: 9}, - {Smallest: 6, Largest: 7}, - {Smallest: 3, Largest: 3}, - {Smallest: 1, Largest: 1}, - }, - } - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 2, 4, 5, 8}, protocol.Encryption1RTT) - }) - - It("processes an ACK frame that would be sent after a late arrival of a packet", func() { - ack1 := &wire.AckFrame{ // 5 lost - AckRanges: []wire.AckRange{ - {Smallest: 6, Largest: 6}, - {Smallest: 1, Largest: 4}, - }, - } - _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 5, 7, 8, 9}, protocol.Encryption1RTT) - ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}} // now ack 5 - _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 7, 8, 9}, protocol.Encryption1RTT) - }) - - It("processes an ACK that contains old ACK ranges", func() { - ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}} - _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 7, 8, 9}, protocol.Encryption1RTT) - ack2 := &wire.AckFrame{ - AckRanges: []wire.AckRange{ - {Smallest: 8, Largest: 8}, - {Smallest: 3, Largest: 3}, - {Smallest: 1, Largest: 1}, - }, - } - _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 7, 9}, protocol.Encryption1RTT) - }) - }) - - Context("calculating RTT", func() { - It("computes the RTT", func() { - now := time.Now() - // First, fake the sent times of the first, second and last packet - getPacket(1, protocol.Encryption1RTT).SendTime = now.Add(-10 * time.Minute) - getPacket(2, protocol.Encryption1RTT).SendTime = now.Add(-5 * time.Minute) - getPacket(6, protocol.Encryption1RTT).SendTime = now.Add(-1 * time.Minute) - // Now, check that the proper times are used when calculating the deltas - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 10*time.Minute, 1*time.Second)) - ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}} - _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) - ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}} - _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 1*time.Minute, 1*time.Second)) - }) - - It("ignores the DelayTime for Initial and Handshake packets", func() { - handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) - handler.rttStats.SetMaxAckDelay(time.Hour) - // make sure the rttStats have a min RTT, so that the delay is used - handler.rttStats.UpdateRTT(5*time.Minute, 0, time.Now()) - getPacket(1, protocol.EncryptionInitial).SendTime = time.Now().Add(-10 * time.Minute) - ack := &wire.AckFrame{ - AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}, - DelayTime: 5 * time.Minute, - } - _, err := handler.ReceivedAck(ack, protocol.EncryptionInitial, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 10*time.Minute, 1*time.Second)) - }) - - It("uses the DelayTime in the ACK frame", func() { - handler.rttStats.SetMaxAckDelay(time.Hour) - // make sure the rttStats have a min RTT, so that the delay is used - handler.rttStats.UpdateRTT(5*time.Minute, 0, time.Now()) - getPacket(1, protocol.Encryption1RTT).SendTime = time.Now().Add(-10 * time.Minute) - ack := &wire.AckFrame{ - AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}, - DelayTime: 5 * time.Minute, - } - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) - }) - - It("limits the DelayTime in the ACK frame to max_ack_delay", func() { - handler.rttStats.SetMaxAckDelay(time.Minute) - // make sure the rttStats have a min RTT, so that the delay is used - handler.rttStats.UpdateRTT(5*time.Minute, 0, time.Now()) - getPacket(1, protocol.Encryption1RTT).SendTime = time.Now().Add(-10 * time.Minute) - ack := &wire.AckFrame{ - AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}, - DelayTime: 5 * time.Minute, - } - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 9*time.Minute, 1*time.Second)) - }) - }) - - Context("determining which ACKs we have received an ACK for", func() { - JustBeforeEach(func() { - morePackets := []*Packet{ - { - PacketNumber: 13, - LargestAcked: 100, - Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, - Length: 1, - EncryptionLevel: protocol.Encryption1RTT, - }, - { - PacketNumber: 14, - LargestAcked: 200, - Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, - Length: 1, - EncryptionLevel: protocol.Encryption1RTT, - }, - { - PacketNumber: 15, - Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, - Length: 1, - EncryptionLevel: protocol.Encryption1RTT, - }, - } - for _, packet := range morePackets { - handler.SentPacket(packet) - } - }) - - It("determines which ACK we have received an ACK for", func() { - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 15}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) - }) - - It("doesn't do anything when the acked packet didn't contain an ACK", func() { - ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} - ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 15, Largest: 15}}} - _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101))) - _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101))) - }) - - It("doesn't decrease the value", func() { - ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 14, Largest: 14}}} - ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} - _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) - _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) - }) - }) - }) - - Context("congestion", func() { - var cong *mocks.MockSendAlgorithmWithDebugInfos - - JustBeforeEach(func() { - cong = mocks.NewMockSendAlgorithmWithDebugInfos(mockCtrl) - handler.congestion = cong - }) - - It("should call OnSent", func() { - cong.EXPECT().OnPacketSent( - gomock.Any(), - protocol.ByteCount(42), - protocol.PacketNumber(1), - protocol.ByteCount(42), - true, - ) - handler.SentPacket(&Packet{ - PacketNumber: 1, - Length: 42, - Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) {}}}, - EncryptionLevel: protocol.Encryption1RTT, - }) - }) - - It("should call MaybeExitSlowStart and OnPacketAcked", func() { - rcvTime := time.Now().Add(-5 * time.Second) - cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) - gomock.InOrder( - cong.EXPECT().MaybeExitSlowStart(), // must be called before packets are acked - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(3), rcvTime), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(3), rcvTime), - ) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, rcvTime) - Expect(err).ToNot(HaveOccurred()) - }) - - It("doesn't call OnPacketAcked when a retransmitted packet is acked", func() { - cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) - // lose packet 1 - gomock.InOrder( - cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), - ) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - // don't EXPECT any further calls to the congestion controller - ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}} - _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - }) - - It("doesn't call OnPacketLost when a Path MTU probe packet is lost", func() { - cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) - var mtuPacketDeclaredLost bool - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 1, - SendTime: time.Now().Add(-time.Hour), - IsPathMTUProbePacket: true, - Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { mtuPacketDeclaredLost = true }}}, - })) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) - // lose packet 1, but don't EXPECT any calls to OnPacketLost() - gomock.InOrder( - cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), - ) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(mtuPacketDeclaredLost).To(BeTrue()) - Expect(handler.bytesInFlight).To(BeZero()) - }) - - It("calls OnPacketAcked and OnPacketLost with the right bytes_in_flight value", func() { - cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(4) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: time.Now().Add(-30 * time.Minute)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3, SendTime: time.Now().Add(-30 * time.Minute)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 4, SendTime: time.Now()})) - // receive the first ACK - gomock.InOrder( - cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(4), gomock.Any()), - ) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now().Add(-30*time.Minute)) - Expect(err).ToNot(HaveOccurred()) - // receive the second ACK - gomock.InOrder( - cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), - ) - ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 4, Largest: 4}}} - _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - }) - - It("passes the bytes in flight to the congestion controller", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(42), gomock.Any(), protocol.ByteCount(42), true) - handler.SentPacket(&Packet{ - Length: 42, - EncryptionLevel: protocol.EncryptionInitial, - Frames: []Frame{{Frame: &wire.PingFrame{}}}, - SendTime: time.Now(), - }) - cong.EXPECT().CanSend(protocol.ByteCount(42)).Return(true) - handler.SendMode() - }) - - It("allows sending of ACKs when congestion limited", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - cong.EXPECT().CanSend(gomock.Any()).Return(true) - Expect(handler.SendMode()).To(Equal(SendAny)) - cong.EXPECT().CanSend(gomock.Any()).Return(false) - Expect(handler.SendMode()).To(Equal(SendAck)) - }) - - It("allows sending of ACKs when we're keeping track of MaxOutstandingSentPackets packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - cong.EXPECT().CanSend(gomock.Any()).Return(true).AnyTimes() - cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - for i := protocol.PacketNumber(0); i < protocol.MaxOutstandingSentPackets; i++ { - Expect(handler.SendMode()).To(Equal(SendAny)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) - } - Expect(handler.SendMode()).To(Equal(SendAck)) - }) - - It("allows PTOs, even when congestion limited", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - // note that we don't EXPECT a call to GetCongestionWindow - // that means retransmissions are sent without considering the congestion window - handler.numProbesToSend = 1 - handler.ptoMode = SendPTOHandshake - Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) - }) - - It("says if it has pacing budget", func() { - cong.EXPECT().HasPacingBudget().Return(true) - Expect(handler.HasPacingBudget()).To(BeTrue()) - cong.EXPECT().HasPacingBudget().Return(false) - Expect(handler.HasPacingBudget()).To(BeFalse()) - }) - - It("returns the pacing delay", func() { - t := time.Now() - cong.EXPECT().TimeUntilSend(gomock.Any()).Return(t) - Expect(handler.TimeUntilSend()).To(Equal(t)) - }) - }) - - It("doesn't set an alarm if there are no outstanding packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11})) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - }) - - It("does nothing on OnAlarm if there are no outstanding packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendAny)) - }) - - Context("probe packets", func() { - It("queues a probe packet", func() { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11})) - queued := handler.QueueProbePacket(protocol.Encryption1RTT) - Expect(queued).To(BeTrue()) - Expect(lostPackets).To(Equal([]protocol.PacketNumber{10})) - }) - - It("says when it can't queue a probe packet", func() { - queued := handler.QueueProbePacket(protocol.Encryption1RTT) - Expect(queued).To(BeFalse()) - }) - - It("implements exponential backoff", func() { - handler.peerAddressValidated = true - handler.SetHandshakeConfirmed() - sendTime := time.Now().Add(-time.Hour) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) - timeout := handler.GetLossDetectionTimeout().Sub(sendTime) - Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(timeout)) - handler.ptoCount = 1 - handler.setLossDetectionTimer() - Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(2 * timeout)) - handler.ptoCount = 2 - handler.setLossDetectionTimer() - Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(4 * timeout)) - }) - - It("reset the PTO count when receiving an ACK", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - now := time.Now() - handler.SetHandshakeConfirmed() - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) - Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - Expect(handler.ptoCount).To(BeEquivalentTo(1)) - _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.ptoCount).To(BeZero()) - }) - - It("resets the PTO mode and PTO count when a packet number space is dropped", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - - now := time.Now() - handler.rttStats.UpdateRTT(time.Second/2, 0, now) - Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second / 2)) - Expect(handler.rttStats.PTO(true)).To(And( - BeNumerically(">", time.Second), - BeNumerically("<", 2*time.Second), - )) - sendTimeHandshake := now.Add(-2 * time.Minute) - sendTimeAppData := now.Add(-time.Minute) - - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 1, - EncryptionLevel: protocol.EncryptionHandshake, - SendTime: sendTimeHandshake, - })) - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 2, - SendTime: sendTimeAppData, - })) - - // PTO timer based on the Handshake packet - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.ptoCount).To(BeEquivalentTo(1)) - Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) - Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeHandshake.Add(handler.rttStats.PTO(false) << 1))) - handler.SetHandshakeConfirmed() - handler.DropPackets(protocol.EncryptionHandshake) - // PTO timer based on the 1-RTT packet - Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeAppData.Add(handler.rttStats.PTO(true)))) // no backoff. PTO count = 0 - Expect(handler.SendMode()).ToNot(Equal(SendPTOHandshake)) - Expect(handler.ptoCount).To(BeZero()) - }) - - It("allows two 1-RTT PTOs", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() - var lostPackets []protocol.PacketNumber - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 1, - SendTime: time.Now().Add(-time.Hour), - Frames: []Frame{ - {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, 1) }}, - }, - })) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) - Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData)) - }) - - It("skips a packet number for 1-RTT PTOs", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() - var lostPackets []protocol.PacketNumber - pn := handler.PopPacketNumber(protocol.Encryption1RTT) - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: pn, - SendTime: time.Now().Add(-time.Hour), - Frames: []Frame{ - {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, 1) }}, - }, - })) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - // The packet number generator might have introduced another skipped a packet number. - Expect(handler.PopPacketNumber(protocol.Encryption1RTT)).To(BeNumerically(">=", pn+2)) - }) - - It("only counts ack-eliciting packets as probe packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - for p := protocol.PacketNumber(3); p < 30; p++ { - handler.SentPacket(nonAckElicitingPacket(&Packet{PacketNumber: p})) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - } - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 30})) - Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData)) - }) - - It("gets two probe packets if PTO expires", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) - - updateRTT(time.Hour) - Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) - - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP - Expect(handler.ptoCount).To(BeEquivalentTo(1)) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 4})) - - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // PTO - Expect(handler.ptoCount).To(BeEquivalentTo(2)) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5})) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 6})) - - Expect(handler.SendMode()).To(Equal(SendAny)) - }) - - It("gets two probe packets if PTO expires, for Handshake packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) - handler.SentPacket(initialPacket(&Packet{PacketNumber: 2})) - - updateRTT(time.Hour) - Expect(handler.initialPackets.lossTime.IsZero()).To(BeTrue()) - - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOInitial)) - handler.SentPacket(initialPacket(&Packet{PacketNumber: 3})) - Expect(handler.SendMode()).To(Equal(SendPTOInitial)) - handler.SentPacket(initialPacket(&Packet{PacketNumber: 4})) - - Expect(handler.SendMode()).To(Equal(SendAny)) - }) - - It("doesn't send 1-RTT probe packets before the handshake completes", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) - updateRTT(time.Hour) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - Expect(handler.SendMode()).To(Equal(SendAny)) - handler.SetHandshakeConfirmed() - Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - }) - - It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) - updateRTT(time.Second) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOAppData)) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.SendMode()).To(Equal(SendAny)) - }) - - It("handles ACKs for the original packet", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)})) - updateRTT(time.Second) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - }) - - It("doesn't set the PTO timer for Path MTU probe packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() - updateRTT(time.Second) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now(), IsPathMTUProbePacket: true})) - Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - }) - }) - - Context("amplification limit, for the server", func() { - It("limits the window to 3x the bytes received, to avoid amplification attacks", func() { - handler.ReceivedPacket(protocol.EncryptionInitial) // receiving an Initial packet doesn't validate the client's address - handler.ReceivedBytes(200) - handler.SentPacket(&Packet{ - PacketNumber: 1, - Length: 599, - EncryptionLevel: protocol.EncryptionInitial, - Frames: []Frame{{Frame: &wire.PingFrame{}}}, - SendTime: time.Now(), - }) - Expect(handler.SendMode()).To(Equal(SendAny)) - handler.SentPacket(&Packet{ - PacketNumber: 2, - Length: 1, - EncryptionLevel: protocol.EncryptionInitial, - Frames: []Frame{{Frame: &wire.PingFrame{}}}, - SendTime: time.Now(), - }) - Expect(handler.SendMode()).To(Equal(SendNone)) - }) - - It("cancels the loss detection timer when it is amplification limited, and resets it when becoming unblocked", func() { - handler.ReceivedBytes(300) - handler.SentPacket(&Packet{ - PacketNumber: 1, - Length: 900, - EncryptionLevel: protocol.EncryptionInitial, - Frames: []Frame{{Frame: &wire.PingFrame{}}}, - SendTime: time.Now(), - }) - // Amplification limited. We don't need to set a timer now. - Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - // Unblock the server. Now we should fire up the timer. - handler.ReceivedBytes(1) - Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - }) - - It("resets the loss detection timer when the client's address is validated", func() { - handler.ReceivedBytes(300) - handler.SentPacket(&Packet{ - PacketNumber: 1, - Length: 900, - EncryptionLevel: protocol.EncryptionHandshake, - Frames: []Frame{{Frame: &wire.PingFrame{}}}, - SendTime: time.Now(), - }) - // Amplification limited. We don't need to set a timer now. - Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - handler.ReceivedPacket(protocol.EncryptionHandshake) - Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - }) - - It("cancels the loss detection alarm when all Handshake packets are acknowledged", func() { - t := time.Now().Add(-time.Second) - handler.ReceivedBytes(99999) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: t})) - handler.SentPacket(handshakePacket(&Packet{PacketNumber: 3, SendTime: t})) - handler.SentPacket(handshakePacket(&Packet{PacketNumber: 4, SendTime: t})) - Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 4}}}, protocol.EncryptionHandshake, time.Now()) - Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - }) - }) - - Context("amplification limit, for the client", func() { - BeforeEach(func() { - perspective = protocol.PerspectiveClient - }) - - It("sends an Initial packet to unblock the server", func() { - handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) - _, err := handler.ReceivedAck( - &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, - protocol.EncryptionInitial, - time.Now(), - ) - Expect(err).ToNot(HaveOccurred()) - // No packets are outstanding at this point. - // Make sure that a probe packet is sent. - Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOInitial)) - - // send a single packet to unblock the server - handler.SentPacket(initialPacket(&Packet{PacketNumber: 2})) - Expect(handler.SendMode()).To(Equal(SendAny)) - - // Now receive an ACK for a Handshake packet. - // This tells the client that the server completed address validation. - handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1})) - _, err = handler.ReceivedAck( - &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, - protocol.EncryptionHandshake, - time.Now(), - ) - Expect(err).ToNot(HaveOccurred()) - // Make sure that no timer is set at this point. - Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - }) - - It("sends a Handshake packet to unblock the server, if Initial keys were already dropped", func() { - handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) - _, err := handler.ReceivedAck( - &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, - protocol.EncryptionInitial, - time.Now(), - ) - Expect(err).ToNot(HaveOccurred()) - - handler.SentPacket(handshakePacketNonAckEliciting(&Packet{PacketNumber: 1})) // also drops Initial packets - Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) - - // Now receive an ACK for this packet, and send another one. - _, err = handler.ReceivedAck( - &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, - protocol.EncryptionHandshake, - time.Now(), - ) - Expect(err).ToNot(HaveOccurred()) - handler.SentPacket(handshakePacketNonAckEliciting(&Packet{PacketNumber: 2})) - Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - }) - - It("doesn't send a packet to unblock the server after handshake confirmation, even if no Handshake ACK was received", func() { - handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1})) - Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) - // confirm the handshake - handler.DropPackets(protocol.EncryptionHandshake) - Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - }) - - It("correctly sets the timer after the Initial packet number space has been dropped", func() { - handler.SentPacket(initialPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-42 * time.Second)})) - _, err := handler.ReceivedAck( - &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, - protocol.EncryptionInitial, - time.Now(), - ) - Expect(err).ToNot(HaveOccurred()) - handler.SentPacket(handshakePacketNonAckEliciting(&Packet{PacketNumber: 1, SendTime: time.Now()})) - Expect(handler.initialPackets).To(BeNil()) - - pto := handler.rttStats.PTO(false) - Expect(pto).ToNot(BeZero()) - Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", time.Now().Add(pto), 10*time.Millisecond)) - }) - - It("doesn't reset the PTO count when receiving an ACK", func() { - now := time.Now() - handler.SentPacket(initialPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) - handler.SentPacket(initialPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) - Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOInitial)) - Expect(handler.ptoCount).To(BeEquivalentTo(1)) - _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionInitial, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.ptoCount).To(BeEquivalentTo(1)) - }) - }) - - Context("Packet-based loss detection", func() { - It("declares packet below the packet loss threshold as lost", func() { - now := time.Now() - for i := protocol.PacketNumber(1); i <= 6; i++ { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) - } - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 6, Largest: 6}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now) - Expect(err).ToNot(HaveOccurred()) - expectInPacketHistory([]protocol.PacketNumber{4, 5}, protocol.Encryption1RTT) - Expect(lostPackets).To(Equal([]protocol.PacketNumber{1, 2, 3})) - }) - }) - - Context("Delay-based loss detection", func() { - It("immediately detects old packets as lost when receiving an ACK", func() { - now := time.Now() - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Hour)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Second)})) - Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) - - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now) - Expect(err).ToNot(HaveOccurred()) - // no need to set an alarm, since packet 1 was already declared lost - Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) - Expect(handler.bytesInFlight).To(BeZero()) - }) - - It("sets the early retransmit alarm", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.handshakeConfirmed = true - now := time.Now() - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-2 * time.Second)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-2 * time.Second)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3, SendTime: now})) - Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) - - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now.Add(-time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second)) - - // Packet 1 should be considered lost (1+1/8) RTTs after it was sent. - Expect(handler.GetLossDetectionTimeout().Sub(getPacket(1, protocol.Encryption1RTT).SendTime)).To(Equal(time.Second * 9 / 8)) - Expect(handler.SendMode()).To(Equal(SendAny)) - - expectInPacketHistory([]protocol.PacketNumber{1, 3}, protocol.Encryption1RTT) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - expectInPacketHistory([]protocol.PacketNumber{3}, protocol.Encryption1RTT) - Expect(handler.SendMode()).To(Equal(SendAny)) - }) - - It("sets the early retransmit alarm for crypto packets", func() { - handler.ReceivedBytes(1000) - now := time.Now() - handler.SentPacket(initialPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-2 * time.Second)})) - handler.SentPacket(initialPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-2 * time.Second)})) - handler.SentPacket(initialPacket(&Packet{PacketNumber: 3, SendTime: now})) - Expect(handler.initialPackets.lossTime.IsZero()).To(BeTrue()) - - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - _, err := handler.ReceivedAck(ack, protocol.EncryptionInitial, now.Add(-time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second)) - - // Packet 1 should be considered lost (1+1/8) RTTs after it was sent. - Expect(handler.GetLossDetectionTimeout().Sub(getPacket(1, protocol.EncryptionInitial).SendTime)).To(Equal(time.Second * 9 / 8)) - Expect(handler.SendMode()).To(Equal(SendAny)) - - expectInPacketHistory([]protocol.PacketNumber{1, 3}, protocol.EncryptionInitial) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - expectInPacketHistory([]protocol.PacketNumber{3}, protocol.EncryptionInitial) - Expect(handler.SendMode()).To(Equal(SendAny)) - }) - - It("sets the early retransmit alarm for Path MTU probe packets", func() { - var mtuPacketDeclaredLost bool - now := time.Now() - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 1, - SendTime: now.Add(-3 * time.Second), - IsPathMTUProbePacket: true, - Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { mtuPacketDeclaredLost = true }}}, - })) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-3 * time.Second)})) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now.Add(-time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(mtuPacketDeclaredLost).To(BeFalse()) - Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(mtuPacketDeclaredLost).To(BeTrue()) - Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - }) - }) - - Context("crypto packets", func() { - It("rejects an ACK that acks packets with a higher encryption level", func() { - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 13, - EncryptionLevel: protocol.Encryption1RTT, - })) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} - _, err := handler.ReceivedAck(ack, protocol.EncryptionHandshake, time.Now()) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "received ACK for an unsent packet", - })) - }) - - It("deletes Initial packets, as a server", func() { - for i := protocol.PacketNumber(0); i < 6; i++ { - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: i, - EncryptionLevel: protocol.EncryptionInitial, - })) - } - for i := protocol.PacketNumber(0); i < 10; i++ { - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: i, - EncryptionLevel: protocol.EncryptionHandshake, - })) - } - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) - handler.DropPackets(protocol.EncryptionInitial) - Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) - Expect(handler.initialPackets).To(BeNil()) - Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) - }) - - Context("deleting Initials", func() { - BeforeEach(func() { perspective = protocol.PerspectiveClient }) - - It("deletes Initials, as a client", func() { - for i := protocol.PacketNumber(0); i < 6; i++ { - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: i, - EncryptionLevel: protocol.EncryptionInitial, - })) - } - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) - handler.DropPackets(protocol.EncryptionInitial) - // DropPackets should be ignored for clients and the Initial packet number space. - // It has to be possible to send another Initial packets after this function was called. - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 10, - EncryptionLevel: protocol.EncryptionInitial, - })) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(7))) - // Sending a Handshake packet triggers dropping of Initials. - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 1, - EncryptionLevel: protocol.EncryptionHandshake, - })) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1))) - Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission - Expect(handler.initialPackets).To(BeNil()) - Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) - }) - }) - - It("deletes Handshake packets", func() { - for i := protocol.PacketNumber(0); i < 6; i++ { - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: i, - EncryptionLevel: protocol.EncryptionHandshake, - })) - } - for i := protocol.PacketNumber(0); i < 10; i++ { - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: i, - EncryptionLevel: protocol.Encryption1RTT, - })) - } - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) - handler.DropPackets(protocol.EncryptionHandshake) - Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) - Expect(handler.handshakePackets).To(BeNil()) - }) - - It("doesn't retransmit 0-RTT packets when 0-RTT keys are dropped", func() { - for i := protocol.PacketNumber(0); i < 6; i++ { - if i == 3 { - continue - } - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: i, - EncryptionLevel: protocol.Encryption0RTT, - })) - } - for i := protocol.PacketNumber(6); i < 12; i++ { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) - } - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(11))) - handler.DropPackets(protocol.Encryption0RTT) - Expect(lostPackets).To(BeEmpty()) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) - }) - - It("cancels the PTO when dropping a packet number space", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - now := time.Now() - handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) - handler.SentPacket(handshakePacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) - Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) - Expect(handler.ptoCount).To(BeEquivalentTo(1)) - handler.DropPackets(protocol.EncryptionHandshake) - Expect(handler.ptoCount).To(BeZero()) - Expect(handler.SendMode()).To(Equal(SendAny)) - }) - }) - - Context("peeking and popping packet number", func() { - It("peeks and pops the initial packet number", func() { - pn, _ := handler.PeekPacketNumber(protocol.EncryptionInitial) - Expect(pn).To(Equal(protocol.PacketNumber(42))) - Expect(handler.PopPacketNumber(protocol.EncryptionInitial)).To(Equal(protocol.PacketNumber(42))) - }) - - It("peeks and pops beyond the initial packet number", func() { - Expect(handler.PopPacketNumber(protocol.EncryptionInitial)).To(Equal(protocol.PacketNumber(42))) - Expect(handler.PopPacketNumber(protocol.EncryptionInitial)).To(BeNumerically(">", 42)) - }) - - It("starts at 0 for handshake and application-data packet number space", func() { - pn, _ := handler.PeekPacketNumber(protocol.EncryptionHandshake) - Expect(pn).To(BeZero()) - Expect(handler.PopPacketNumber(protocol.EncryptionHandshake)).To(BeZero()) - pn, _ = handler.PeekPacketNumber(protocol.Encryption1RTT) - Expect(pn).To(BeZero()) - Expect(handler.PopPacketNumber(protocol.Encryption1RTT)).To(BeZero()) - }) - }) - - Context("for the client", func() { - BeforeEach(func() { - perspective = protocol.PerspectiveClient - }) - - It("considers the server's address validated right away", func() { - }) - - It("queues outstanding packets for retransmission, cancels alarms and resets PTO count when receiving a Retry", func() { - handler.SentPacket(initialPacket(&Packet{PacketNumber: 42})) - Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - Expect(handler.bytesInFlight).ToNot(BeZero()) - Expect(handler.SendMode()).To(Equal(SendAny)) - // now receive a Retry - Expect(handler.ResetForRetry()).To(Succeed()) - Expect(lostPackets).To(Equal([]protocol.PacketNumber{42})) - Expect(handler.bytesInFlight).To(BeZero()) - Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - Expect(handler.SendMode()).To(Equal(SendAny)) - Expect(handler.ptoCount).To(BeZero()) - }) - - It("queues outstanding frames for retransmission and cancels alarms when receiving a Retry", func() { - var lostInitial, lost0RTT bool - handler.SentPacket(&Packet{ - PacketNumber: 13, - EncryptionLevel: protocol.EncryptionInitial, - Frames: []Frame{ - {Frame: &wire.CryptoFrame{Data: []byte("foobar")}, OnLost: func(wire.Frame) { lostInitial = true }}, - }, - Length: 100, - }) - pn := handler.PopPacketNumber(protocol.Encryption0RTT) - handler.SentPacket(&Packet{ - PacketNumber: pn, - EncryptionLevel: protocol.Encryption0RTT, - Frames: []Frame{ - {Frame: &wire.StreamFrame{Data: []byte("foobar")}, OnLost: func(wire.Frame) { lost0RTT = true }}, - }, - Length: 999, - }) - Expect(handler.bytesInFlight).ToNot(BeZero()) - // now receive a Retry - Expect(handler.ResetForRetry()).To(Succeed()) - Expect(handler.bytesInFlight).To(BeZero()) - Expect(lostInitial).To(BeTrue()) - Expect(lost0RTT).To(BeTrue()) - - // make sure we keep increasing the packet number for 0-RTT packets - Expect(handler.PopPacketNumber(protocol.Encryption0RTT)).To(BeNumerically(">", pn)) - }) - - It("uses a Retry for an RTT estimate, if it was not retransmitted", func() { - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 42, - EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-500 * time.Millisecond), - })) - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 43, - EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-10 * time.Millisecond), - })) - Expect(handler.ResetForRetry()).To(Succeed()) - Expect(handler.rttStats.SmoothedRTT()).To(BeNumerically("~", 500*time.Millisecond, 100*time.Millisecond)) - }) - - It("uses a Retry for an RTT estimate, but doesn't set the RTT to a value lower than 5ms", func() { - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 42, - EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-500 * time.Microsecond), - })) - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 43, - EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-10 * time.Microsecond), - })) - Expect(handler.ResetForRetry()).To(Succeed()) - Expect(handler.rttStats.SmoothedRTT()).To(Equal(minRTTAfterRetry)) - }) - - It("doesn't use a Retry for an RTT estimate, if it was not retransmitted", func() { - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 42, - EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-800 * time.Millisecond), - })) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTOInitial)) - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 43, - EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-100 * time.Millisecond), - })) - Expect(handler.ResetForRetry()).To(Succeed()) - Expect(handler.rttStats.SmoothedRTT()).To(BeZero()) - }) - }) -}) diff --git a/internal/quic-go/ackhandler/sent_packet_history.go b/internal/quic-go/ackhandler/sent_packet_history.go deleted file mode 100644 index f6acc2be..00000000 --- a/internal/quic-go/ackhandler/sent_packet_history.go +++ /dev/null @@ -1,108 +0,0 @@ -package ackhandler - -import ( - "fmt" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -type sentPacketHistory struct { - rttStats *utils.RTTStats - packetList *PacketList - packetMap map[protocol.PacketNumber]*PacketElement - highestSent protocol.PacketNumber -} - -func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory { - return &sentPacketHistory{ - rttStats: rttStats, - packetList: NewPacketList(), - packetMap: make(map[protocol.PacketNumber]*PacketElement), - highestSent: protocol.InvalidPacketNumber, - } -} - -func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) { - if p.PacketNumber <= h.highestSent { - panic("non-sequential packet number use") - } - // Skipped packet numbers. - for pn := h.highestSent + 1; pn < p.PacketNumber; pn++ { - el := h.packetList.PushBack(Packet{ - PacketNumber: pn, - EncryptionLevel: p.EncryptionLevel, - SendTime: p.SendTime, - skippedPacket: true, - }) - h.packetMap[pn] = el - } - h.highestSent = p.PacketNumber - - if isAckEliciting { - el := h.packetList.PushBack(*p) - h.packetMap[p.PacketNumber] = el - } -} - -// Iterate iterates through all packets. -func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error { - cont := true - var next *PacketElement - for el := h.packetList.Front(); cont && el != nil; el = next { - var err error - next = el.Next() - cont, err = cb(&el.Value) - if err != nil { - return err - } - } - return nil -} - -// FirstOutStanding returns the first outstanding packet. -func (h *sentPacketHistory) FirstOutstanding() *Packet { - for el := h.packetList.Front(); el != nil; el = el.Next() { - p := &el.Value - if !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket { - return p - } - } - return nil -} - -func (h *sentPacketHistory) Len() int { - return len(h.packetMap) -} - -func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { - el, ok := h.packetMap[p] - if !ok { - return fmt.Errorf("packet %d not found in sent packet history", p) - } - h.packetList.Remove(el) - delete(h.packetMap, p) - return nil -} - -func (h *sentPacketHistory) HasOutstandingPackets() bool { - return h.FirstOutstanding() != nil -} - -func (h *sentPacketHistory) DeleteOldPackets(now time.Time) { - maxAge := 3 * h.rttStats.PTO(false) - var nextEl *PacketElement - for el := h.packetList.Front(); el != nil; el = nextEl { - nextEl = el.Next() - p := el.Value - if p.SendTime.After(now.Add(-maxAge)) { - break - } - if !p.skippedPacket && !p.declaredLost { // should only happen in the case of drastic RTT changes - continue - } - delete(h.packetMap, p.PacketNumber) - h.packetList.Remove(el) - } -} diff --git a/internal/quic-go/ackhandler/sent_packet_history_test.go b/internal/quic-go/ackhandler/sent_packet_history_test.go deleted file mode 100644 index b3cf2f0e..00000000 --- a/internal/quic-go/ackhandler/sent_packet_history_test.go +++ /dev/null @@ -1,263 +0,0 @@ -package ackhandler - -import ( - "errors" - "time" - - "github.com/imroc/req/v3/internal/quic-go/utils" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("SentPacketHistory", func() { - var ( - hist *sentPacketHistory - rttStats *utils.RTTStats - ) - - expectInHistory := func(packetNumbers []protocol.PacketNumber) { - var mapLen int - for _, el := range hist.packetMap { - if !el.Value.skippedPacket { - mapLen++ - } - } - var listLen int - for el := hist.packetList.Front(); el != nil; el = el.Next() { - if !el.Value.skippedPacket { - listLen++ - } - } - ExpectWithOffset(1, mapLen).To(Equal(len(packetNumbers))) - ExpectWithOffset(1, listLen).To(Equal(len(packetNumbers))) - i := 0 - err := hist.Iterate(func(p *Packet) (bool, error) { - if p.skippedPacket { - return true, nil - } - pn := packetNumbers[i] - ExpectWithOffset(1, p.PacketNumber).To(Equal(pn)) - ExpectWithOffset(1, hist.packetMap[pn].Value.PacketNumber).To(Equal(pn)) - i++ - return true, nil - }) - Expect(err).ToNot(HaveOccurred()) - } - - BeforeEach(func() { - rttStats = utils.NewRTTStats() - hist = newSentPacketHistory(rttStats) - }) - - It("saves sent packets", func() { - hist.SentPacket(&Packet{PacketNumber: 1}, true) - hist.SentPacket(&Packet{PacketNumber: 3}, true) - hist.SentPacket(&Packet{PacketNumber: 4}, true) - expectInHistory([]protocol.PacketNumber{1, 3, 4}) - }) - - It("doesn't save non-ack-eliciting packets", func() { - hist.SentPacket(&Packet{PacketNumber: 1}, true) - hist.SentPacket(&Packet{PacketNumber: 3}, false) - hist.SentPacket(&Packet{PacketNumber: 4}, true) - expectInHistory([]protocol.PacketNumber{1, 4}) - for el := hist.packetList.Front(); el != nil; el = el.Next() { - Expect(el.Value.PacketNumber).ToNot(Equal(protocol.PacketNumber(3))) - } - }) - - It("gets the length", func() { - hist.SentPacket(&Packet{PacketNumber: 0}, true) - hist.SentPacket(&Packet{PacketNumber: 1}, true) - hist.SentPacket(&Packet{PacketNumber: 2}, true) - Expect(hist.Len()).To(Equal(3)) - }) - - Context("getting the first outstanding packet", func() { - It("gets nil, if there are no packets", func() { - Expect(hist.FirstOutstanding()).To(BeNil()) - }) - - It("gets the first outstanding packet", func() { - hist.SentPacket(&Packet{PacketNumber: 2}, true) - hist.SentPacket(&Packet{PacketNumber: 3}, true) - front := hist.FirstOutstanding() - Expect(front).ToNot(BeNil()) - Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2))) - }) - - It("doesn't regard path MTU packets as outstanding", func() { - hist.SentPacket(&Packet{PacketNumber: 2}, true) - hist.SentPacket(&Packet{PacketNumber: 4, IsPathMTUProbePacket: true}, true) - front := hist.FirstOutstanding() - Expect(front).ToNot(BeNil()) - Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2))) - }) - }) - - It("removes packets", func() { - hist.SentPacket(&Packet{PacketNumber: 1}, true) - hist.SentPacket(&Packet{PacketNumber: 4}, true) - hist.SentPacket(&Packet{PacketNumber: 8}, true) - err := hist.Remove(4) - Expect(err).ToNot(HaveOccurred()) - expectInHistory([]protocol.PacketNumber{1, 8}) - }) - - It("errors when trying to remove a non existing packet", func() { - hist.SentPacket(&Packet{PacketNumber: 1}, true) - err := hist.Remove(2) - Expect(err).To(MatchError("packet 2 not found in sent packet history")) - }) - - Context("iterating", func() { - BeforeEach(func() { - hist.SentPacket(&Packet{PacketNumber: 1}, true) - hist.SentPacket(&Packet{PacketNumber: 4}, true) - hist.SentPacket(&Packet{PacketNumber: 8}, true) - }) - - It("iterates over all packets", func() { - var iterations []protocol.PacketNumber - Expect(hist.Iterate(func(p *Packet) (bool, error) { - if p.skippedPacket { - return true, nil - } - iterations = append(iterations, p.PacketNumber) - return true, nil - })).To(Succeed()) - Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4, 8})) - }) - - It("also iterates over skipped packets", func() { - var packets, skippedPackets []protocol.PacketNumber - Expect(hist.Iterate(func(p *Packet) (bool, error) { - if p.skippedPacket { - skippedPackets = append(skippedPackets, p.PacketNumber) - } else { - packets = append(packets, p.PacketNumber) - } - return true, nil - })).To(Succeed()) - Expect(packets).To(Equal([]protocol.PacketNumber{1, 4, 8})) - Expect(skippedPackets).To(Equal([]protocol.PacketNumber{0, 2, 3, 5, 6, 7})) - }) - - It("stops iterating", func() { - var iterations []protocol.PacketNumber - Expect(hist.Iterate(func(p *Packet) (bool, error) { - if p.skippedPacket { - return true, nil - } - iterations = append(iterations, p.PacketNumber) - return p.PacketNumber != 4, nil - })).To(Succeed()) - Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4})) - }) - - It("returns the error", func() { - testErr := errors.New("test error") - var iterations []protocol.PacketNumber - Expect(hist.Iterate(func(p *Packet) (bool, error) { - if p.skippedPacket { - return true, nil - } - iterations = append(iterations, p.PacketNumber) - if p.PacketNumber == 4 { - return false, testErr - } - return true, nil - })).To(MatchError(testErr)) - Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4})) - }) - - It("allows deletions", func() { - var iterations []protocol.PacketNumber - Expect(hist.Iterate(func(p *Packet) (bool, error) { - if p.skippedPacket { - return true, nil - } - iterations = append(iterations, p.PacketNumber) - if p.PacketNumber == 4 { - Expect(hist.Remove(4)).To(Succeed()) - } - return true, nil - })).To(Succeed()) - expectInHistory([]protocol.PacketNumber{1, 8}) - Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4, 8})) - }) - }) - - Context("outstanding packets", func() { - It("says if it has outstanding packets", func() { - Expect(hist.HasOutstandingPackets()).To(BeFalse()) - hist.SentPacket(&Packet{EncryptionLevel: protocol.Encryption1RTT}, true) - Expect(hist.HasOutstandingPackets()).To(BeTrue()) - }) - - It("accounts for deleted packets", func() { - hist.SentPacket(&Packet{ - PacketNumber: 10, - EncryptionLevel: protocol.Encryption1RTT, - }, true) - Expect(hist.HasOutstandingPackets()).To(BeTrue()) - Expect(hist.Remove(10)).To(Succeed()) - Expect(hist.HasOutstandingPackets()).To(BeFalse()) - }) - - It("counts the number of packets", func() { - hist.SentPacket(&Packet{ - PacketNumber: 10, - EncryptionLevel: protocol.Encryption1RTT, - }, true) - hist.SentPacket(&Packet{ - PacketNumber: 11, - EncryptionLevel: protocol.Encryption1RTT, - }, true) - Expect(hist.Remove(11)).To(Succeed()) - Expect(hist.HasOutstandingPackets()).To(BeTrue()) - Expect(hist.Remove(10)).To(Succeed()) - Expect(hist.HasOutstandingPackets()).To(BeFalse()) - }) - }) - - Context("deleting old packets", func() { - const pto = 3 * time.Second - - BeforeEach(func() { - rttStats.UpdateRTT(time.Second, 0, time.Time{}) - Expect(rttStats.PTO(false)).To(Equal(pto)) - }) - - It("deletes old packets after 3 PTOs", func() { - now := time.Now() - hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto), declaredLost: true}, true) - expectInHistory([]protocol.PacketNumber{10}) - hist.DeleteOldPackets(now.Add(-time.Nanosecond)) - expectInHistory([]protocol.PacketNumber{10}) - hist.DeleteOldPackets(now) - expectInHistory([]protocol.PacketNumber{}) - }) - - It("doesn't delete a packet if it hasn't been declared lost yet", func() { - now := time.Now() - hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto), declaredLost: true}, true) - hist.SentPacket(&Packet{PacketNumber: 11, SendTime: now.Add(-3 * pto), declaredLost: false}, true) - expectInHistory([]protocol.PacketNumber{10, 11}) - hist.DeleteOldPackets(now) - expectInHistory([]protocol.PacketNumber{11}) - }) - - It("deletes skipped packets", func() { - now := time.Now() - hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto)}, true) - expectInHistory([]protocol.PacketNumber{10}) - Expect(hist.Len()).To(Equal(11)) - hist.DeleteOldPackets(now) - expectInHistory([]protocol.PacketNumber{10}) // the packet was not declared lost - Expect(hist.Len()).To(Equal(1)) - }) - }) -}) diff --git a/internal/quic-go/buffer_pool.go b/internal/quic-go/buffer_pool.go deleted file mode 100644 index c8d50a43..00000000 --- a/internal/quic-go/buffer_pool.go +++ /dev/null @@ -1,80 +0,0 @@ -package quic - -import ( - "sync" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -type packetBuffer struct { - Data []byte - - // refCount counts how many packets Data is used in. - // It doesn't support concurrent use. - // It is > 1 when used for coalesced packet. - refCount int -} - -// Split increases the refCount. -// It must be called when a packet buffer is used for more than one packet, -// e.g. when splitting coalesced packets. -func (b *packetBuffer) Split() { - b.refCount++ -} - -// Decrement decrements the reference counter. -// It doesn't put the buffer back into the pool. -func (b *packetBuffer) Decrement() { - b.refCount-- - if b.refCount < 0 { - panic("negative packetBuffer refCount") - } -} - -// MaybeRelease puts the packet buffer back into the pool, -// if the reference counter already reached 0. -func (b *packetBuffer) MaybeRelease() { - // only put the packetBuffer back if it's not used any more - if b.refCount == 0 { - b.putBack() - } -} - -// Release puts back the packet buffer into the pool. -// It should be called when processing is definitely finished. -func (b *packetBuffer) Release() { - b.Decrement() - if b.refCount != 0 { - panic("packetBuffer refCount not zero") - } - b.putBack() -} - -// Len returns the length of Data -func (b *packetBuffer) Len() protocol.ByteCount { - return protocol.ByteCount(len(b.Data)) -} - -func (b *packetBuffer) putBack() { - if cap(b.Data) != int(protocol.MaxPacketBufferSize) { - panic("putPacketBuffer called with packet of wrong size!") - } - bufferPool.Put(b) -} - -var bufferPool sync.Pool - -func getPacketBuffer() *packetBuffer { - buf := bufferPool.Get().(*packetBuffer) - buf.refCount = 1 - buf.Data = buf.Data[:0] - return buf -} - -func init() { - bufferPool.New = func() interface{} { - return &packetBuffer{ - Data: make([]byte, 0, protocol.MaxPacketBufferSize), - } - } -} diff --git a/internal/quic-go/buffer_pool_test.go b/internal/quic-go/buffer_pool_test.go deleted file mode 100644 index 8e28ad02..00000000 --- a/internal/quic-go/buffer_pool_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package quic - -import ( - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Buffer Pool", func() { - It("returns buffers of cap", func() { - buf := getPacketBuffer() - Expect(buf.Data).To(HaveCap(int(protocol.MaxPacketBufferSize))) - }) - - It("releases buffers", func() { - buf := getPacketBuffer() - buf.Release() - }) - - It("gets the length", func() { - buf := getPacketBuffer() - buf.Data = append(buf.Data, []byte("foobar")...) - Expect(buf.Len()).To(BeEquivalentTo(6)) - }) - - It("panics if wrong-sized buffers are passed", func() { - buf := getPacketBuffer() - buf.Data = make([]byte, 10) - Expect(func() { buf.Release() }).To(Panic()) - }) - - It("panics if it is released twice", func() { - buf := getPacketBuffer() - buf.Release() - Expect(func() { buf.Release() }).To(Panic()) - }) - - It("panics if it is decremented too many times", func() { - buf := getPacketBuffer() - buf.Decrement() - Expect(func() { buf.Decrement() }).To(Panic()) - }) - - It("waits until all parts have been released", func() { - buf := getPacketBuffer() - buf.Split() - buf.Split() - // now we have 3 parts - buf.Decrement() - buf.Decrement() - buf.Decrement() - Expect(func() { buf.Decrement() }).To(Panic()) - }) -}) diff --git a/internal/quic-go/client.go b/internal/quic-go/client.go deleted file mode 100644 index 3bcdbece..00000000 --- a/internal/quic-go/client.go +++ /dev/null @@ -1,339 +0,0 @@ -package quic - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "net" - "strings" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -type client struct { - sconn sendConn - // If the client is created with DialAddr, we create a packet conn. - // If it is started with Dial, we take a packet conn as a parameter. - createdPacketConn bool - - use0RTT bool - - packetHandlers packetHandlerManager - - tlsConf *tls.Config - config *Config - - srcConnID protocol.ConnectionID - destConnID protocol.ConnectionID - - initialPacketNumber protocol.PacketNumber - hasNegotiatedVersion bool - version protocol.VersionNumber - - handshakeChan chan struct{} - - conn quicConn - - tracer logging.ConnectionTracer - tracingID uint64 - logger utils.Logger -} - -var ( - // make it possible to mock connection ID generation in the tests - generateConnectionID = protocol.GenerateConnectionID - generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial -) - -// DialAddr establishes a new QUIC connection to a server. -// It uses a new UDP connection and closes this connection when the QUIC connection is closed. -// The hostname for SNI is taken from the given address. -// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. -func DialAddr( - addr string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return DialAddrContext(context.Background(), addr, tlsConf, config) -} - -// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. -// It uses a new UDP connection and closes this connection when the QUIC connection is closed. -// The hostname for SNI is taken from the given address. -// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. -func DialAddrEarly( - addr string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - return DialAddrEarlyContext(context.Background(), addr, tlsConf, config) -} - -// DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context. -// See DialAddrEarly for details -func DialAddrEarlyContext( - ctx context.Context, - addr string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - conn, err := dialAddrContext(ctx, addr, tlsConf, config, true) - if err != nil { - return nil, err - } - utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection") - return conn, nil -} - -// DialAddrContext establishes a new QUIC connection to a server using the provided context. -// See DialAddr for details. -func DialAddrContext( - ctx context.Context, - addr string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return dialAddrContext(ctx, addr, tlsConf, config, false) -} - -func dialAddrContext( - ctx context.Context, - addr string, - tlsConf *tls.Config, - config *Config, - use0RTT bool, -) (quicConn, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return nil, err - } - return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true) -} - -// Dial establishes a new QUIC connection to a server using a net.PacketConn. If -// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn -// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP -// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write -// packets. The same PacketConn can be used for multiple calls to Dial and -// Listen, QUIC connection IDs are used for demultiplexing the different -// connections. The host parameter is used for SNI. The tls.Config must define -// an application protocol (using NextProtos). -func Dial( - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false) -} - -// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn. -// The same PacketConn can be used for multiple calls to Dial and Listen, -// QUIC connection IDs are used for demultiplexing the different connections. -// The host parameter is used for SNI. -// The tls.Config must define an application protocol (using NextProtos). -func DialEarly( - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config) -} - -// DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context. -// See DialEarly for details. -func DialEarlyContext( - ctx context.Context, - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false) -} - -// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context. -// See Dial for details. -func DialContext( - ctx context.Context, - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false) -} - -func dialContext( - ctx context.Context, - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, - use0RTT bool, - createdPacketConn bool, -) (quicConn, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(config); err != nil { - return nil, err - } - config = populateClientConfig(config, createdPacketConn) - packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) - if err != nil { - return nil, err - } - c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn) - if err != nil { - return nil, err - } - c.packetHandlers = packetHandlers - - c.tracingID = nextConnTracingID() - if c.config.Tracer != nil { - c.tracer = c.config.Tracer.TracerForConnection( - context.WithValue(ctx, ConnectionTracingKey, c.tracingID), - protocol.PerspectiveClient, - c.destConnID, - ) - } - if c.tracer != nil { - c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID) - } - if err := c.dial(ctx); err != nil { - return nil, err - } - return c.conn, nil -} - -func newClient( - pconn net.PacketConn, - remoteAddr net.Addr, - config *Config, - tlsConf *tls.Config, - host string, - use0RTT bool, - createdPacketConn bool, -) (*client, error) { - if tlsConf == nil { - tlsConf = &tls.Config{} - } else { - tlsConf = tlsConf.Clone() - } - if tlsConf.ServerName == "" { - sni := host - if strings.IndexByte(sni, ':') != -1 { - var err error - sni, _, err = net.SplitHostPort(sni) - if err != nil { - return nil, err - } - } - - tlsConf.ServerName = sni - } - - // check that all versions are actually supported - if config != nil { - for _, v := range config.Versions { - if !protocol.IsValidVersion(v) { - return nil, fmt.Errorf("%s is not a valid QUIC version", v) - } - } - } - - srcConnID, err := generateConnectionID(config.ConnectionIDLength) - if err != nil { - return nil, err - } - destConnID, err := generateConnectionIDForInitial() - if err != nil { - return nil, err - } - c := &client{ - srcConnID: srcConnID, - destConnID: destConnID, - sconn: newSendPconn(pconn, remoteAddr), - createdPacketConn: createdPacketConn, - use0RTT: use0RTT, - tlsConf: tlsConf, - config: config, - version: config.Versions[0], - handshakeChan: make(chan struct{}), - logger: utils.DefaultLogger.WithPrefix("client"), - } - return c, nil -} - -func (c *client) dial(ctx context.Context) error { - c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) - - c.conn = newClientConnection( - c.sconn, - c.packetHandlers, - c.destConnID, - c.srcConnID, - c.config, - c.tlsConf, - c.initialPacketNumber, - c.use0RTT, - c.hasNegotiatedVersion, - c.tracer, - c.tracingID, - c.logger, - c.version, - ) - c.packetHandlers.Add(c.srcConnID, c.conn) - - errorChan := make(chan error, 1) - go func() { - err := c.conn.run() // returns as soon as the connection is closed - - if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn { - c.packetHandlers.Destroy() - } - errorChan <- err - }() - - // only set when we're using 0-RTT - // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. - var earlyConnChan <-chan struct{} - if c.use0RTT { - earlyConnChan = c.conn.earlyConnReady() - } - - select { - case <-ctx.Done(): - c.conn.shutdown() - return ctx.Err() - case err := <-errorChan: - var recreateErr *errCloseForRecreating - if errors.As(err, &recreateErr) { - c.initialPacketNumber = recreateErr.nextPacketNumber - c.version = recreateErr.nextVersion - c.hasNegotiatedVersion = true - return c.dial(ctx) - } - return err - case <-earlyConnChan: - // ready to send 0-RTT data - return nil - case <-c.conn.HandshakeComplete().Done(): - // handshake successfully completed - return nil - } -} diff --git a/internal/quic-go/client_test.go b/internal/quic-go/client_test.go deleted file mode 100644 index 5dd3bf00..00000000 --- a/internal/quic-go/client_test.go +++ /dev/null @@ -1,611 +0,0 @@ -package quic - -import ( - "context" - "crypto/tls" - "errors" - "net" - "os" - "time" - - "github.com/imroc/req/v3/internal/quic-go/logging" - mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Client", func() { - var ( - cl *client - packetConn *MockPacketConn - addr net.Addr - connID protocol.ConnectionID - mockMultiplexer *MockMultiplexer - origMultiplexer multiplexer - tlsConf *tls.Config - tracer *mocklogging.MockConnectionTracer - config *Config - - originalClientConnConstructor func( - conn sendConn, - runner connRunner, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, - conf *Config, - tlsConf *tls.Config, - initialPacketNumber protocol.PacketNumber, - enable0RTT bool, - hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, - tracingID uint64, - logger utils.Logger, - v protocol.VersionNumber, - ) quicConn - ) - - BeforeEach(func() { - tlsConf = &tls.Config{NextProtos: []string{"proto1"}} - connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} - originalClientConnConstructor = newClientConnection - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) - tr := mocklogging.NewMockTracer(mockCtrl) - tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) - config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.VersionTLS}} - Eventually(areConnsRunning).Should(BeFalse()) - addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - packetConn = NewMockPacketConn(mockCtrl) - packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() - cl = &client{ - srcConnID: connID, - destConnID: connID, - version: protocol.VersionTLS, - sconn: newSendPconn(packetConn, addr), - tracer: tracer, - logger: utils.DefaultLogger, - } - getMultiplexer() // make the sync.Once execute - // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer - mockMultiplexer = NewMockMultiplexer(mockCtrl) - origMultiplexer = connMuxer - connMuxer = mockMultiplexer - }) - - AfterEach(func() { - connMuxer = origMultiplexer - newClientConnection = originalClientConnConstructor - }) - - AfterEach(func() { - if s, ok := cl.conn.(*connection); ok { - s.shutdown() - } - Eventually(areConnsRunning).Should(BeFalse()) - }) - - Context("Dialing", func() { - var origGenerateConnectionID func(int) (protocol.ConnectionID, error) - var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error) - - BeforeEach(func() { - origGenerateConnectionID = generateConnectionID - origGenerateConnectionIDForInitial = generateConnectionIDForInitial - generateConnectionID = func(int) (protocol.ConnectionID, error) { - return connID, nil - } - generateConnectionIDForInitial = func() (protocol.ConnectionID, error) { - return connID, nil - } - }) - - AfterEach(func() { - generateConnectionID = origGenerateConnectionID - generateConnectionIDForInitial = origGenerateConnectionIDForInitial - }) - - It("resolves the address", func() { - if os.Getenv("APPVEYOR") == "True" { - Skip("This test is flaky on AppVeyor.") - } - - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - manager.EXPECT().Destroy() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - remoteAddrChan := make(chan string, 1) - newClientConnection = func( - sconn sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - remoteAddrChan <- sconn.RemoteAddr().String() - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().run() - conn.EXPECT().HandshakeComplete().Return(context.Background()) - return conn - } - _, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond}) - Expect(err).ToNot(HaveOccurred()) - Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890"))) - }) - - It("uses the tls.Config.ServerName as the hostname, if present", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - manager.EXPECT().Destroy() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - hostnameChan := make(chan string, 1) - newClientConnection = func( - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - tlsConf *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - hostnameChan <- tlsConf.ServerName - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().run() - conn.EXPECT().HandshakeComplete().Return(context.Background()) - return conn - } - tlsConf.ServerName = "foobar" - _, err := DialAddr("localhost:17890", tlsConf, nil) - Expect(err).ToNot(HaveOccurred()) - Eventually(hostnameChan).Should(Receive(Equal("foobar"))) - }) - - It("allows passing host without port as server name", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - hostnameChan := make(chan string, 1) - newClientConnection = func( - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - tlsConf *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - hostnameChan <- tlsConf.ServerName - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().HandshakeComplete().Return(context.Background()) - conn.EXPECT().run() - return conn - } - tracer.EXPECT().StartedConnection(packetConn.LocalAddr(), addr, gomock.Any(), gomock.Any()) - _, err := Dial( - packetConn, - addr, - "test.com", - tlsConf, - config, - ) - Expect(err).ToNot(HaveOccurred()) - Eventually(hostnameChan).Should(Receive(Equal("test.com"))) - }) - - It("returns after the handshake is complete", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - run := make(chan struct{}) - newClientConnection = func( - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - enable0RTT bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - Expect(enable0RTT).To(BeFalse()) - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().run().Do(func() { close(run) }) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn.EXPECT().HandshakeComplete().Return(ctx) - return conn - } - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - s, err := Dial( - packetConn, - addr, - "localhost:1337", - tlsConf, - config, - ) - Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) - Eventually(run).Should(BeClosed()) - }) - - It("returns early connections", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - readyChan := make(chan struct{}) - done := make(chan struct{}) - newClientConnection = func( - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - enable0RTT bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - Expect(enable0RTT).To(BeTrue()) - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().run().Do(func() { <-done }) - conn.EXPECT().HandshakeComplete().Return(context.Background()) - conn.EXPECT().earlyConnReady().Return(readyChan) - return conn - } - - go func() { - defer GinkgoRecover() - defer close(done) - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - s, err := DialEarly( - packetConn, - addr, - "localhost:1337", - tlsConf, - config, - ) - Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) - }() - Consistently(done).ShouldNot(BeClosed()) - close(readyChan) - Eventually(done).Should(BeClosed()) - }) - - It("returns an error that occurs while waiting for the handshake to complete", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - testErr := errors.New("early handshake error") - newClientConnection = func( - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().run().Return(testErr) - conn.EXPECT().HandshakeComplete().Return(context.Background()) - return conn - } - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, err := Dial( - packetConn, - addr, - "localhost:1337", - tlsConf, - config, - ) - Expect(err).To(MatchError(testErr)) - }) - - It("closes the connection when the context is canceled", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - connRunning := make(chan struct{}) - defer close(connRunning) - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().run().Do(func() { - <-connRunning - }) - conn.EXPECT().HandshakeComplete().Return(context.Background()) - newClientConnection = func( - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - return conn - } - ctx, cancel := context.WithCancel(context.Background()) - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, err := DialContext( - ctx, - packetConn, - addr, - "localhost:1337", - tlsConf, - config, - ) - Expect(err).To(MatchError(context.Canceled)) - close(dialed) - }() - Consistently(dialed).ShouldNot(BeClosed()) - conn.EXPECT().shutdown() - cancel() - Eventually(dialed).Should(BeClosed()) - }) - - It("closes the connection when it was created by DialAddr", func() { - if os.Getenv("APPVEYOR") == "True" { - Skip("This test is flaky on AppVeyor.") - } - - manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - - var sconn sendConn - run := make(chan struct{}) - connCreated := make(chan struct{}) - conn := NewMockQuicConn(mockCtrl) - newClientConnection = func( - connP sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - sconn = connP - close(connCreated) - return conn - } - conn.EXPECT().run().Do(func() { - <-run - }) - conn.EXPECT().HandshakeComplete().Return(context.Background()) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := DialAddr("localhost:1337", tlsConf, nil) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - - Eventually(connCreated).Should(BeClosed()) - - // check that the connection is not closed - Expect(sconn.Write([]byte("foobar"))).To(Succeed()) - - manager.EXPECT().Destroy() - close(run) - time.Sleep(50 * time.Millisecond) - - Eventually(done).Should(BeClosed()) - }) - - Context("quic.Config", func() { - It("setups with the right values", func() { - tokenStore := NewLRUTokenStore(10, 4) - config := &Config{ - HandshakeIdleTimeout: 1337 * time.Minute, - MaxIdleTimeout: 42 * time.Hour, - MaxIncomingStreams: 1234, - MaxIncomingUniStreams: 4321, - ConnectionIDLength: 13, - StatelessResetKey: []byte("foobar"), - TokenStore: tokenStore, - EnableDatagrams: true, - } - c := populateClientConfig(config, false) - Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute)) - Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour)) - Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) - Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) - Expect(c.ConnectionIDLength).To(Equal(13)) - Expect(c.StatelessResetKey).To(Equal([]byte("foobar"))) - Expect(c.TokenStore).To(Equal(tokenStore)) - Expect(c.EnableDatagrams).To(BeTrue()) - }) - - It("errors when the Config contains an invalid version", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - version := protocol.VersionNumber(0x1234) - _, err := Dial(packetConn, nil, "localhost:1234", tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) - Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) - }) - - It("disables bidirectional streams", func() { - config := &Config{ - MaxIncomingStreams: -1, - MaxIncomingUniStreams: 4321, - } - c := populateClientConfig(config, false) - Expect(c.MaxIncomingStreams).To(BeZero()) - Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) - }) - - It("disables unidirectional streams", func() { - config := &Config{ - MaxIncomingStreams: 1234, - MaxIncomingUniStreams: -1, - } - c := populateClientConfig(config, false) - Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) - Expect(c.MaxIncomingUniStreams).To(BeZero()) - }) - - It("uses 0-byte connection IDs when dialing an address", func() { - c := populateClientConfig(&Config{}, true) - Expect(c.ConnectionIDLength).To(BeZero()) - }) - - It("fills in default values if options are not set in the Config", func() { - c := populateClientConfig(&Config{}, false) - Expect(c.Versions).To(Equal(protocol.SupportedVersions)) - Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) - Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) - }) - }) - - It("creates new connections with the right parameters", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(connID, gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} - c := make(chan struct{}) - var cconn sendConn - var version protocol.VersionNumber - var conf *Config - newClientConnection = func( - connP sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - configP *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - versionP protocol.VersionNumber, - ) quicConn { - cconn = connP - version = versionP - conf = configP - close(c) - // TODO: check connection IDs? - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().run() - conn.EXPECT().HandshakeComplete().Return(context.Background()) - return conn - } - _, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config) - Expect(err).ToNot(HaveOccurred()) - Eventually(c).Should(BeClosed()) - Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn)) - Expect(version).To(Equal(config.Versions[0])) - Expect(conf.Versions).To(Equal(config.Versions)) - }) - - It("creates a new connections after version negotiation", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(connID, gomock.Any()).Times(2) - manager.EXPECT().Destroy() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - var counter int - newClientConnection = func( - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - configP *Config, - _ *tls.Config, - pn protocol.PacketNumber, - _ bool, - hasNegotiatedVersion bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - versionP protocol.VersionNumber, - ) quicConn { - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().HandshakeComplete().Return(context.Background()) - if counter == 0 { - Expect(pn).To(BeZero()) - Expect(hasNegotiatedVersion).To(BeFalse()) - conn.EXPECT().run().Return(&errCloseForRecreating{ - nextPacketNumber: 109, - nextVersion: 789, - }) - } else { - Expect(pn).To(Equal(protocol.PacketNumber(109))) - Expect(hasNegotiatedVersion).To(BeTrue()) - conn.EXPECT().run() - } - counter++ - return conn - } - - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, err := DialAddr("localhost:7890", tlsConf, config) - Expect(err).ToNot(HaveOccurred()) - Expect(counter).To(Equal(2)) - }) - }) -}) diff --git a/internal/quic-go/closed_conn.go b/internal/quic-go/closed_conn.go deleted file mode 100644 index a97861b9..00000000 --- a/internal/quic-go/closed_conn.go +++ /dev/null @@ -1,112 +0,0 @@ -package quic - -import ( - "sync" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -// A closedLocalConn is a connection that we closed locally. -// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, -// with an exponential backoff. -type closedLocalConn struct { - conn sendConn - connClosePacket []byte - - closeOnce sync.Once - closeChan chan struct{} // is closed when the connection is closed or destroyed - - receivedPackets chan *receivedPacket - counter uint64 // number of packets received - - perspective protocol.Perspective - - logger utils.Logger -} - -var _ packetHandler = &closedLocalConn{} - -// newClosedLocalConn creates a new closedLocalConn and runs it. -func newClosedLocalConn( - conn sendConn, - connClosePacket []byte, - perspective protocol.Perspective, - logger utils.Logger, -) packetHandler { - s := &closedLocalConn{ - conn: conn, - connClosePacket: connClosePacket, - perspective: perspective, - logger: logger, - closeChan: make(chan struct{}), - receivedPackets: make(chan *receivedPacket, 64), - } - go s.run() - return s -} - -func (s *closedLocalConn) run() { - for { - select { - case p := <-s.receivedPackets: - s.handlePacketImpl(p) - case <-s.closeChan: - return - } - } -} - -func (s *closedLocalConn) handlePacket(p *receivedPacket) { - select { - case s.receivedPackets <- p: - default: - } -} - -func (s *closedLocalConn) handlePacketImpl(_ *receivedPacket) { - s.counter++ - // exponential backoff - // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving - for n := s.counter; n > 1; n = n / 2 { - if n%2 != 0 { - return - } - } - s.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", s.counter) - if err := s.conn.Write(s.connClosePacket); err != nil { - s.logger.Debugf("Error retransmitting CONNECTION_CLOSE: %s", err) - } -} - -func (s *closedLocalConn) shutdown() { - s.destroy(nil) -} - -func (s *closedLocalConn) destroy(error) { - s.closeOnce.Do(func() { - close(s.closeChan) - }) -} - -func (s *closedLocalConn) getPerspective() protocol.Perspective { - return s.perspective -} - -// A closedRemoteConn is a connection that was closed remotely. -// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. -// We can just ignore those packets. -type closedRemoteConn struct { - perspective protocol.Perspective -} - -var _ packetHandler = &closedRemoteConn{} - -func newClosedRemoteConn(pers protocol.Perspective) packetHandler { - return &closedRemoteConn{perspective: pers} -} - -func (s *closedRemoteConn) handlePacket(*receivedPacket) {} -func (s *closedRemoteConn) shutdown() {} -func (s *closedRemoteConn) destroy(error) {} -func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } diff --git a/internal/quic-go/closed_conn_test.go b/internal/quic-go/closed_conn_test.go deleted file mode 100644 index 330a9dad..00000000 --- a/internal/quic-go/closed_conn_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package quic - -import ( - "errors" - "time" - - "github.com/golang/mock/gomock" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Closed local connection", func() { - var ( - conn packetHandler - mconn *MockSendConn - ) - - BeforeEach(func() { - mconn = NewMockSendConn(mockCtrl) - conn = newClosedLocalConn(mconn, []byte("close"), protocol.PerspectiveClient, utils.DefaultLogger) - }) - - AfterEach(func() { - Eventually(areClosedConnsRunning).Should(BeFalse()) - }) - - It("tells its perspective", func() { - Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient)) - // stop the connection - conn.shutdown() - }) - - It("repeats the packet containing the CONNECTION_CLOSE frame", func() { - written := make(chan []byte) - mconn.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }).AnyTimes() - for i := 1; i <= 20; i++ { - conn.handlePacket(&receivedPacket{}) - if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 { - Eventually(written).Should(Receive(Equal([]byte("close")))) // receive the CONNECTION_CLOSE - } else { - Consistently(written, 10*time.Millisecond).Should(HaveLen(0)) - } - } - // stop the connection - conn.shutdown() - }) - - It("destroys connections", func() { - Eventually(areClosedConnsRunning).Should(BeTrue()) - conn.destroy(errors.New("destroy")) - Eventually(areClosedConnsRunning).Should(BeFalse()) - }) -}) diff --git a/internal/quic-go/config.go b/internal/quic-go/config.go deleted file mode 100644 index 8f444f5d..00000000 --- a/internal/quic-go/config.go +++ /dev/null @@ -1,124 +0,0 @@ -package quic - -import ( - "errors" - "time" - - "github.com/imroc/req/v3/internal/quic-go/utils" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// Clone clones a Config -func (c *Config) Clone() *Config { - copy := *c - return © -} - -func (c *Config) handshakeTimeout() time.Duration { - return utils.MaxDuration(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout) -} - -func validateConfig(config *Config) error { - if config == nil { - return nil - } - if config.MaxIncomingStreams > 1<<60 { - return errors.New("invalid value for Config.MaxIncomingStreams") - } - if config.MaxIncomingUniStreams > 1<<60 { - return errors.New("invalid value for Config.MaxIncomingUniStreams") - } - return nil -} - -// populateServerConfig populates fields in the quic.Config with their default values, if none are set -// it may be called with nil -func populateServerConfig(config *Config) *Config { - config = populateConfig(config) - if config.ConnectionIDLength == 0 { - config.ConnectionIDLength = protocol.DefaultConnectionIDLength - } - if config.AcceptToken == nil { - config.AcceptToken = defaultAcceptToken - } - return config -} - -// populateClientConfig populates fields in the quic.Config with their default values, if none are set -// it may be called with nil -func populateClientConfig(config *Config, createdPacketConn bool) *Config { - config = populateConfig(config) - if config.ConnectionIDLength == 0 && !createdPacketConn { - config.ConnectionIDLength = protocol.DefaultConnectionIDLength - } - return config -} - -func populateConfig(config *Config) *Config { - if config == nil { - config = &Config{} - } - versions := config.Versions - if len(versions) == 0 { - versions = protocol.SupportedVersions - } - handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout - if config.HandshakeIdleTimeout != 0 { - handshakeIdleTimeout = config.HandshakeIdleTimeout - } - idleTimeout := protocol.DefaultIdleTimeout - if config.MaxIdleTimeout != 0 { - idleTimeout = config.MaxIdleTimeout - } - initialStreamReceiveWindow := config.InitialStreamReceiveWindow - if initialStreamReceiveWindow == 0 { - initialStreamReceiveWindow = protocol.DefaultInitialMaxStreamData - } - maxStreamReceiveWindow := config.MaxStreamReceiveWindow - if maxStreamReceiveWindow == 0 { - maxStreamReceiveWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow - } - initialConnectionReceiveWindow := config.InitialConnectionReceiveWindow - if initialConnectionReceiveWindow == 0 { - initialConnectionReceiveWindow = protocol.DefaultInitialMaxData - } - maxConnectionReceiveWindow := config.MaxConnectionReceiveWindow - if maxConnectionReceiveWindow == 0 { - maxConnectionReceiveWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow - } - maxIncomingStreams := config.MaxIncomingStreams - if maxIncomingStreams == 0 { - maxIncomingStreams = protocol.DefaultMaxIncomingStreams - } else if maxIncomingStreams < 0 { - maxIncomingStreams = 0 - } - maxIncomingUniStreams := config.MaxIncomingUniStreams - if maxIncomingUniStreams == 0 { - maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams - } else if maxIncomingUniStreams < 0 { - maxIncomingUniStreams = 0 - } - - return &Config{ - Versions: versions, - HandshakeIdleTimeout: handshakeIdleTimeout, - MaxIdleTimeout: idleTimeout, - AcceptToken: config.AcceptToken, - KeepAlivePeriod: config.KeepAlivePeriod, - InitialStreamReceiveWindow: initialStreamReceiveWindow, - MaxStreamReceiveWindow: maxStreamReceiveWindow, - InitialConnectionReceiveWindow: initialConnectionReceiveWindow, - MaxConnectionReceiveWindow: maxConnectionReceiveWindow, - AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, - MaxIncomingStreams: maxIncomingStreams, - MaxIncomingUniStreams: maxIncomingUniStreams, - ConnectionIDLength: config.ConnectionIDLength, - StatelessResetKey: config.StatelessResetKey, - TokenStore: config.TokenStore, - EnableDatagrams: config.EnableDatagrams, - DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, - DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, - Tracer: config.Tracer, - } -} diff --git a/internal/quic-go/config_test.go b/internal/quic-go/config_test.go deleted file mode 100644 index 1710bdab..00000000 --- a/internal/quic-go/config_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package quic - -import ( - "fmt" - "net" - "reflect" - "time" - - mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Config", func() { - Context("validating", func() { - It("validates a nil config", func() { - Expect(validateConfig(nil)).To(Succeed()) - }) - - It("validates a config with normal values", func() { - Expect(validateConfig(populateServerConfig(&Config{}))).To(Succeed()) - }) - - It("errors on too large values for MaxIncomingStreams", func() { - Expect(validateConfig(&Config{MaxIncomingStreams: 1<<60 + 1})).To(MatchError("invalid value for Config.MaxIncomingStreams")) - }) - - It("errors on too large values for MaxIncomingUniStreams", func() { - Expect(validateConfig(&Config{MaxIncomingUniStreams: 1<<60 + 1})).To(MatchError("invalid value for Config.MaxIncomingUniStreams")) - }) - }) - - configWithNonZeroNonFunctionFields := func() *Config { - c := &Config{} - v := reflect.ValueOf(c).Elem() - - typ := v.Type() - for i := 0; i < typ.NumField(); i++ { - f := v.Field(i) - if !f.CanSet() { - // unexported field; not cloned. - continue - } - - switch fn := typ.Field(i).Name; fn { - case "AcceptToken", "GetLogWriter", "AllowConnectionWindowIncrease": - // Can't compare functions. - case "Versions": - f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) - case "ConnectionIDLength": - f.Set(reflect.ValueOf(8)) - case "HandshakeIdleTimeout": - f.Set(reflect.ValueOf(time.Second)) - case "MaxIdleTimeout": - f.Set(reflect.ValueOf(time.Hour)) - case "TokenStore": - f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) - case "InitialStreamReceiveWindow": - f.Set(reflect.ValueOf(uint64(1234))) - case "MaxStreamReceiveWindow": - f.Set(reflect.ValueOf(uint64(9))) - case "InitialConnectionReceiveWindow": - f.Set(reflect.ValueOf(uint64(4321))) - case "MaxConnectionReceiveWindow": - f.Set(reflect.ValueOf(uint64(10))) - case "MaxIncomingStreams": - f.Set(reflect.ValueOf(int64(11))) - case "MaxIncomingUniStreams": - f.Set(reflect.ValueOf(int64(12))) - case "StatelessResetKey": - f.Set(reflect.ValueOf([]byte{1, 2, 3, 4})) - case "KeepAlivePeriod": - f.Set(reflect.ValueOf(time.Second)) - case "EnableDatagrams": - f.Set(reflect.ValueOf(true)) - case "DisableVersionNegotiationPackets": - f.Set(reflect.ValueOf(true)) - case "DisablePathMTUDiscovery": - f.Set(reflect.ValueOf(true)) - case "Tracer": - f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl))) - default: - Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) - } - } - return c - } - - It("uses 10s handshake timeout for short handshake idle timeouts", func() { - c := &Config{HandshakeIdleTimeout: time.Second} - Expect(c.handshakeTimeout()).To(Equal(protocol.DefaultHandshakeTimeout)) - }) - - It("uses twice the handshake idle timeouts for the handshake timeout, for long handshake idle timeouts", func() { - c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2} - Expect(c.handshakeTimeout()).To(Equal(11 * time.Second)) - }) - - Context("cloning", func() { - It("clones function fields", func() { - var calledAcceptToken, calledAllowConnectionWindowIncrease bool - c1 := &Config{ - AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, - AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, - } - c2 := c1.Clone() - c2.AcceptToken(&net.UDPAddr{}, &Token{}) - Expect(calledAcceptToken).To(BeTrue()) - c2.AllowConnectionWindowIncrease(nil, 1234) - Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) - }) - - It("clones non-function fields", func() { - c := configWithNonZeroNonFunctionFields() - Expect(c.Clone()).To(Equal(c)) - }) - - It("returns a copy", func() { - c1 := &Config{ - MaxIncomingStreams: 100, - AcceptToken: func(_ net.Addr, _ *Token) bool { return true }, - } - c2 := c1.Clone() - c2.MaxIncomingStreams = 200 - c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } - - Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) - Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue()) - }) - }) - - Context("populating", func() { - It("populates function fields", func() { - var calledAcceptToken bool - c1 := &Config{ - AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, - } - c2 := populateConfig(c1) - c2.AcceptToken(&net.UDPAddr{}, &Token{}) - Expect(calledAcceptToken).To(BeTrue()) - }) - - It("copies non-function fields", func() { - c := configWithNonZeroNonFunctionFields() - Expect(populateConfig(c)).To(Equal(c)) - }) - - It("populates empty fields with default values", func() { - c := populateConfig(&Config{}) - Expect(c.Versions).To(Equal(protocol.SupportedVersions)) - Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) - Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData)) - Expect(c.MaxStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow)) - Expect(c.InitialConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxData)) - Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow)) - Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams)) - Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams)) - Expect(c.DisableVersionNegotiationPackets).To(BeFalse()) - Expect(c.DisablePathMTUDiscovery).To(BeFalse()) - }) - - It("populates empty fields with default values, for the server", func() { - c := populateServerConfig(&Config{}) - Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) - Expect(c.AcceptToken).ToNot(BeNil()) - }) - - It("sets a default connection ID length if we didn't create the conn, for the client", func() { - c := populateClientConfig(&Config{}, false) - Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) - }) - - It("doesn't set a default connection ID length if we created the conn, for the client", func() { - c := populateClientConfig(&Config{}, true) - Expect(c.ConnectionIDLength).To(BeZero()) - }) - }) -}) diff --git a/internal/quic-go/congestion/bandwidth.go b/internal/quic-go/congestion/bandwidth.go deleted file mode 100644 index a6560980..00000000 --- a/internal/quic-go/congestion/bandwidth.go +++ /dev/null @@ -1,25 +0,0 @@ -package congestion - -import ( - "math" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// Bandwidth of a connection -type Bandwidth uint64 - -const infBandwidth Bandwidth = math.MaxUint64 - -const ( - // BitsPerSecond is 1 bit per second - BitsPerSecond Bandwidth = 1 - // BytesPerSecond is 1 byte per second - BytesPerSecond = 8 * BitsPerSecond -) - -// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta -func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth { - return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond -} diff --git a/internal/quic-go/congestion/bandwidth_test.go b/internal/quic-go/congestion/bandwidth_test.go deleted file mode 100644 index 03162747..00000000 --- a/internal/quic-go/congestion/bandwidth_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package congestion - -import ( - "time" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Bandwidth", func() { - It("converts from time delta", func() { - Expect(BandwidthFromDelta(1, time.Millisecond)).To(Equal(1000 * BytesPerSecond)) - }) -}) diff --git a/internal/quic-go/congestion/clock.go b/internal/quic-go/congestion/clock.go deleted file mode 100644 index 405fae70..00000000 --- a/internal/quic-go/congestion/clock.go +++ /dev/null @@ -1,18 +0,0 @@ -package congestion - -import "time" - -// A Clock returns the current time -type Clock interface { - Now() time.Time -} - -// DefaultClock implements the Clock interface using the Go stdlib clock. -type DefaultClock struct{} - -var _ Clock = DefaultClock{} - -// Now gets the current time -func (DefaultClock) Now() time.Time { - return time.Now() -} diff --git a/internal/quic-go/congestion/congestion_suite_test.go b/internal/quic-go/congestion/congestion_suite_test.go deleted file mode 100644 index 6a0f7ed7..00000000 --- a/internal/quic-go/congestion/congestion_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package congestion - -import ( - "testing" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestCongestion(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Congestion Suite") -} diff --git a/internal/quic-go/congestion/cubic.go b/internal/quic-go/congestion/cubic.go deleted file mode 100644 index acbb6bcc..00000000 --- a/internal/quic-go/congestion/cubic.go +++ /dev/null @@ -1,214 +0,0 @@ -package congestion - -import ( - "math" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -// This cubic implementation is based on the one found in Chromiums's QUIC -// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}. - -// Constants based on TCP defaults. -// The following constants are in 2^10 fractions of a second instead of ms to -// allow a 10 shift right to divide. - -// 1024*1024^3 (first 1024 is from 0.100^3) -// where 0.100 is 100 ms which is the scaling round trip time. -const ( - cubeScale = 40 - cubeCongestionWindowScale = 410 - cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize - // TODO: when re-enabling cubic, make sure to use the actual packet size here - maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4) -) - -const defaultNumConnections = 1 - -// Default Cubic backoff factor -const beta float32 = 0.7 - -// Additional backoff factor when loss occurs in the concave part of the Cubic -// curve. This additional backoff factor is expected to give up bandwidth to -// new concurrent flows and speed up convergence. -const betaLastMax float32 = 0.85 - -// Cubic implements the cubic algorithm from TCP -type Cubic struct { - clock Clock - - // Number of connections to simulate. - numConnections int - - // Time when this cycle started, after last loss event. - epoch time.Time - - // Max congestion window used just before last loss event. - // Note: to improve fairness to other streams an additional back off is - // applied to this value if the new value is below our latest value. - lastMaxCongestionWindow protocol.ByteCount - - // Number of acked bytes since the cycle started (epoch). - ackedBytesCount protocol.ByteCount - - // TCP Reno equivalent congestion window in packets. - estimatedTCPcongestionWindow protocol.ByteCount - - // Origin point of cubic function. - originPointCongestionWindow protocol.ByteCount - - // Time to origin point of cubic function in 2^10 fractions of a second. - timeToOriginPoint uint32 - - // Last congestion window in packets computed by cubic function. - lastTargetCongestionWindow protocol.ByteCount -} - -// NewCubic returns a new Cubic instance -func NewCubic(clock Clock) *Cubic { - c := &Cubic{ - clock: clock, - numConnections: defaultNumConnections, - } - c.Reset() - return c -} - -// Reset is called after a timeout to reset the cubic state -func (c *Cubic) Reset() { - c.epoch = time.Time{} - c.lastMaxCongestionWindow = 0 - c.ackedBytesCount = 0 - c.estimatedTCPcongestionWindow = 0 - c.originPointCongestionWindow = 0 - c.timeToOriginPoint = 0 - c.lastTargetCongestionWindow = 0 -} - -func (c *Cubic) alpha() float32 { - // TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that - // beta here is a cwnd multiplier, and is equal to 1-beta from the paper. - // We derive the equivalent alpha for an N-connection emulation as: - b := c.beta() - return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b) -} - -func (c *Cubic) beta() float32 { - // kNConnectionBeta is the backoff factor after loss for our N-connection - // emulation, which emulates the effective backoff of an ensemble of N - // TCP-Reno connections on a single loss event. The effective multiplier is - // computed as: - return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) -} - -func (c *Cubic) betaLastMax() float32 { - // betaLastMax is the additional backoff factor after loss for our - // N-connection emulation, which emulates the additional backoff of - // an ensemble of N TCP-Reno connections on a single loss event. The - // effective multiplier is computed as: - return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections) -} - -// OnApplicationLimited is called on ack arrival when sender is unable to use -// the available congestion window. Resets Cubic state during quiescence. -func (c *Cubic) OnApplicationLimited() { - // When sender is not using the available congestion window, the window does - // not grow. But to be RTT-independent, Cubic assumes that the sender has been - // using the entire window during the time since the beginning of the current - // "epoch" (the end of the last loss recovery period). Since - // application-limited periods break this assumption, we reset the epoch when - // in such a period. This reset effectively freezes congestion window growth - // through application-limited periods and allows Cubic growth to continue - // when the entire window is being used. - c.epoch = time.Time{} -} - -// CongestionWindowAfterPacketLoss computes a new congestion window to use after -// a loss event. Returns the new congestion window in packets. The new -// congestion window is a multiplicative decrease of our current window. -func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount { - if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow { - // We never reached the old max, so assume we are competing with another - // flow. Use our extra back off factor to allow the other flow to go up. - c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow)) - } else { - c.lastMaxCongestionWindow = currentCongestionWindow - } - c.epoch = time.Time{} // Reset time. - return protocol.ByteCount(float32(currentCongestionWindow) * c.beta()) -} - -// CongestionWindowAfterAck computes a new congestion window to use after a received ACK. -// Returns the new congestion window in packets. The new congestion window -// follows a cubic function that depends on the time passed since last -// packet loss. -func (c *Cubic) CongestionWindowAfterAck( - ackedBytes protocol.ByteCount, - currentCongestionWindow protocol.ByteCount, - delayMin time.Duration, - eventTime time.Time, -) protocol.ByteCount { - c.ackedBytesCount += ackedBytes - - if c.epoch.IsZero() { - // First ACK after a loss event. - c.epoch = eventTime // Start of epoch. - c.ackedBytesCount = ackedBytes // Reset count. - // Reset estimated_tcp_congestion_window_ to be in sync with cubic. - c.estimatedTCPcongestionWindow = currentCongestionWindow - if c.lastMaxCongestionWindow <= currentCongestionWindow { - c.timeToOriginPoint = 0 - c.originPointCongestionWindow = currentCongestionWindow - } else { - c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow)))) - c.originPointCongestionWindow = c.lastMaxCongestionWindow - } - } - - // Change the time unit from microseconds to 2^10 fractions per second. Take - // the round trip time in account. This is done to allow us to use shift as a - // divide operator. - elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000) - - // Right-shifts of negative, signed numbers have implementation-dependent - // behavior, so force the offset to be positive, as is done in the kernel. - offset := int64(c.timeToOriginPoint) - elapsedTime - if offset < 0 { - offset = -offset - } - - deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale - var targetCongestionWindow protocol.ByteCount - if elapsedTime > int64(c.timeToOriginPoint) { - targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow - } else { - targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow - } - // Limit the CWND increase to half the acked bytes. - targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) - - // Increase the window by approximately Alpha * 1 MSS of bytes every - // time we ack an estimated tcp window of bytes. For small - // congestion windows (less than 25), the formula below will - // increase slightly slower than linearly per estimated tcp window - // of bytes. - c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow)) - c.ackedBytesCount = 0 - - // We have a new cubic congestion window. - c.lastTargetCongestionWindow = targetCongestionWindow - - // Compute target congestion_window based on cubic target and estimated TCP - // congestion_window, use highest (fastest). - if targetCongestionWindow < c.estimatedTCPcongestionWindow { - targetCongestionWindow = c.estimatedTCPcongestionWindow - } - return targetCongestionWindow -} - -// SetNumConnections sets the number of emulated connections -func (c *Cubic) SetNumConnections(n int) { - c.numConnections = n -} diff --git a/internal/quic-go/congestion/cubic_sender.go b/internal/quic-go/congestion/cubic_sender.go deleted file mode 100644 index 12074d90..00000000 --- a/internal/quic-go/congestion/cubic_sender.go +++ /dev/null @@ -1,316 +0,0 @@ -package congestion - -import ( - "fmt" - "time" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -const ( - // maxDatagramSize is the default maximum packet size used in the Linux TCP implementation. - // Used in QUIC for congestion window computations in bytes. - initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4) - maxBurstPackets = 3 - renoBeta = 0.7 // Reno backoff factor. - minCongestionWindowPackets = 2 - initialCongestionWindow = 32 -) - -type cubicSender struct { - hybridSlowStart HybridSlowStart - rttStats *utils.RTTStats - cubic *Cubic - pacer *pacer - clock Clock - - reno bool - - // Track the largest packet that has been sent. - largestSentPacketNumber protocol.PacketNumber - - // Track the largest packet that has been acked. - largestAckedPacketNumber protocol.PacketNumber - - // Track the largest packet number outstanding when a CWND cutback occurs. - largestSentAtLastCutback protocol.PacketNumber - - // Whether the last loss event caused us to exit slowstart. - // Used for stats collection of slowstartPacketsLost - lastCutbackExitedSlowstart bool - - // Congestion window in bytes. - congestionWindow protocol.ByteCount - - // Slow start congestion window in bytes, aka ssthresh. - slowStartThreshold protocol.ByteCount - - // ACK counter for the Reno implementation. - numAckedPackets uint64 - - initialCongestionWindow protocol.ByteCount - initialMaxCongestionWindow protocol.ByteCount - - maxDatagramSize protocol.ByteCount - - lastState logging.CongestionState - tracer logging.ConnectionTracer -} - -var ( - _ SendAlgorithm = &cubicSender{} - _ SendAlgorithmWithDebugInfos = &cubicSender{} -) - -// NewCubicSender makes a new cubic sender -func NewCubicSender( - clock Clock, - rttStats *utils.RTTStats, - initialMaxDatagramSize protocol.ByteCount, - reno bool, - tracer logging.ConnectionTracer, -) *cubicSender { - return newCubicSender( - clock, - rttStats, - reno, - initialMaxDatagramSize, - initialCongestionWindow*initialMaxDatagramSize, - protocol.MaxCongestionWindowPackets*initialMaxDatagramSize, - tracer, - ) -} - -func newCubicSender( - clock Clock, - rttStats *utils.RTTStats, - reno bool, - initialMaxDatagramSize, - initialCongestionWindow, - initialMaxCongestionWindow protocol.ByteCount, - tracer logging.ConnectionTracer, -) *cubicSender { - c := &cubicSender{ - rttStats: rttStats, - largestSentPacketNumber: protocol.InvalidPacketNumber, - largestAckedPacketNumber: protocol.InvalidPacketNumber, - largestSentAtLastCutback: protocol.InvalidPacketNumber, - initialCongestionWindow: initialCongestionWindow, - initialMaxCongestionWindow: initialMaxCongestionWindow, - congestionWindow: initialCongestionWindow, - slowStartThreshold: protocol.MaxByteCount, - cubic: NewCubic(clock), - clock: clock, - reno: reno, - tracer: tracer, - maxDatagramSize: initialMaxDatagramSize, - } - c.pacer = newPacer(c.BandwidthEstimate) - if c.tracer != nil { - c.lastState = logging.CongestionStateSlowStart - c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) - } - return c -} - -// TimeUntilSend returns when the next packet should be sent. -func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time { - return c.pacer.TimeUntilSend() -} - -func (c *cubicSender) HasPacingBudget() bool { - return c.pacer.Budget(c.clock.Now()) >= c.maxDatagramSize -} - -func (c *cubicSender) maxCongestionWindow() protocol.ByteCount { - return c.maxDatagramSize * protocol.MaxCongestionWindowPackets -} - -func (c *cubicSender) minCongestionWindow() protocol.ByteCount { - return c.maxDatagramSize * minCongestionWindowPackets -} - -func (c *cubicSender) OnPacketSent( - sentTime time.Time, - _ protocol.ByteCount, - packetNumber protocol.PacketNumber, - bytes protocol.ByteCount, - isRetransmittable bool, -) { - c.pacer.SentPacket(sentTime, bytes) - if !isRetransmittable { - return - } - c.largestSentPacketNumber = packetNumber - c.hybridSlowStart.OnPacketSent(packetNumber) -} - -func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool { - return bytesInFlight < c.GetCongestionWindow() -} - -func (c *cubicSender) InRecovery() bool { - return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback -} - -func (c *cubicSender) InSlowStart() bool { - return c.GetCongestionWindow() < c.slowStartThreshold -} - -func (c *cubicSender) GetCongestionWindow() protocol.ByteCount { - return c.congestionWindow -} - -func (c *cubicSender) MaybeExitSlowStart() { - if c.InSlowStart() && - c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) { - // exit slow start - c.slowStartThreshold = c.congestionWindow - c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) - } -} - -func (c *cubicSender) OnPacketAcked( - ackedPacketNumber protocol.PacketNumber, - ackedBytes protocol.ByteCount, - priorInFlight protocol.ByteCount, - eventTime time.Time, -) { - c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber) - if c.InRecovery() { - return - } - c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime) - if c.InSlowStart() { - c.hybridSlowStart.OnPacketAcked(ackedPacketNumber) - } -} - -func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { - // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets - // already sent should be treated as a single loss event, since it's expected. - if packetNumber <= c.largestSentAtLastCutback { - return - } - c.lastCutbackExitedSlowstart = c.InSlowStart() - c.maybeTraceStateChange(logging.CongestionStateRecovery) - - if c.reno { - c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta) - } else { - c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow) - } - if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd { - c.congestionWindow = minCwnd - } - c.slowStartThreshold = c.congestionWindow - c.largestSentAtLastCutback = c.largestSentPacketNumber - // reset packet count from congestion avoidance mode. We start - // counting again when we're out of recovery. - c.numAckedPackets = 0 -} - -// Called when we receive an ack. Normal TCP tracks how many packets one ack -// represents, but quic has a separate ack for each packet. -func (c *cubicSender) maybeIncreaseCwnd( - _ protocol.PacketNumber, - ackedBytes protocol.ByteCount, - priorInFlight protocol.ByteCount, - eventTime time.Time, -) { - // Do not increase the congestion window unless the sender is close to using - // the current window. - if !c.isCwndLimited(priorInFlight) { - c.cubic.OnApplicationLimited() - c.maybeTraceStateChange(logging.CongestionStateApplicationLimited) - return - } - if c.congestionWindow >= c.maxCongestionWindow() { - return - } - if c.InSlowStart() { - // TCP slow start, exponential growth, increase by one for each ACK. - c.congestionWindow += c.maxDatagramSize - c.maybeTraceStateChange(logging.CongestionStateSlowStart) - return - } - // Congestion avoidance - c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) - if c.reno { - // Classic Reno congestion avoidance. - c.numAckedPackets++ - if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) { - c.congestionWindow += c.maxDatagramSize - c.numAckedPackets = 0 - } - } else { - c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) - } -} - -func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool { - congestionWindow := c.GetCongestionWindow() - if bytesInFlight >= congestionWindow { - return true - } - availableBytes := congestionWindow - bytesInFlight - slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2 - return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize -} - -// BandwidthEstimate returns the current bandwidth estimate -func (c *cubicSender) BandwidthEstimate() Bandwidth { - srtt := c.rttStats.SmoothedRTT() - if srtt == 0 { - // If we haven't measured an rtt, the bandwidth estimate is unknown. - return infBandwidth - } - return BandwidthFromDelta(c.GetCongestionWindow(), srtt) -} - -// OnRetransmissionTimeout is called on an retransmission timeout -func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) { - c.largestSentAtLastCutback = protocol.InvalidPacketNumber - if !packetsRetransmitted { - return - } - c.hybridSlowStart.Restart() - c.cubic.Reset() - c.slowStartThreshold = c.congestionWindow / 2 - c.congestionWindow = c.minCongestionWindow() -} - -// OnConnectionMigration is called when the connection is migrated (?) -func (c *cubicSender) OnConnectionMigration() { - c.hybridSlowStart.Restart() - c.largestSentPacketNumber = protocol.InvalidPacketNumber - c.largestAckedPacketNumber = protocol.InvalidPacketNumber - c.largestSentAtLastCutback = protocol.InvalidPacketNumber - c.lastCutbackExitedSlowstart = false - c.cubic.Reset() - c.numAckedPackets = 0 - c.congestionWindow = c.initialCongestionWindow - c.slowStartThreshold = c.initialMaxCongestionWindow -} - -func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { - if c.tracer == nil || new == c.lastState { - return - } - c.tracer.UpdatedCongestionState(new) - c.lastState = new -} - -func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) { - if s < c.maxDatagramSize { - panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s)) - } - cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow() - c.maxDatagramSize = s - if cwndIsMinCwnd { - c.congestionWindow = c.minCongestionWindow() - } - c.pacer.SetMaxDatagramSize(s) -} diff --git a/internal/quic-go/congestion/cubic_sender_test.go b/internal/quic-go/congestion/cubic_sender_test.go deleted file mode 100644 index cddab314..00000000 --- a/internal/quic-go/congestion/cubic_sender_test.go +++ /dev/null @@ -1,526 +0,0 @@ -package congestion - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -const ( - initialCongestionWindowPackets = 10 - defaultWindowTCP = protocol.ByteCount(initialCongestionWindowPackets) * maxDatagramSize -) - -type mockClock time.Time - -func (c *mockClock) Now() time.Time { - return time.Time(*c) -} - -func (c *mockClock) Advance(d time.Duration) { - *c = mockClock(time.Time(*c).Add(d)) -} - -const MaxCongestionWindow protocol.ByteCount = 200 * maxDatagramSize - -var _ = Describe("Cubic Sender", func() { - var ( - sender *cubicSender - clock mockClock - bytesInFlight protocol.ByteCount - packetNumber protocol.PacketNumber - ackedPacketNumber protocol.PacketNumber - rttStats *utils.RTTStats - ) - - BeforeEach(func() { - bytesInFlight = 0 - packetNumber = 1 - ackedPacketNumber = 0 - clock = mockClock{} - rttStats = utils.NewRTTStats() - sender = newCubicSender( - &clock, - rttStats, - true, /*reno*/ - protocol.InitialPacketSizeIPv4, - initialCongestionWindowPackets*maxDatagramSize, - MaxCongestionWindow, - nil, - ) - }) - - SendAvailableSendWindowLen := func(packetLength protocol.ByteCount) int { - var packetsSent int - for sender.CanSend(bytesInFlight) { - sender.OnPacketSent(clock.Now(), bytesInFlight, packetNumber, packetLength, true) - packetNumber++ - packetsSent++ - bytesInFlight += packetLength - } - return packetsSent - } - - // Normal is that TCP acks every other segment. - AckNPackets := func(n int) { - rttStats.UpdateRTT(60*time.Millisecond, 0, clock.Now()) - sender.MaybeExitSlowStart() - for i := 0; i < n; i++ { - ackedPacketNumber++ - sender.OnPacketAcked(ackedPacketNumber, maxDatagramSize, bytesInFlight, clock.Now()) - } - bytesInFlight -= protocol.ByteCount(n) * maxDatagramSize - clock.Advance(time.Millisecond) - } - - LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) { - for i := 0; i < n; i++ { - ackedPacketNumber++ - sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight) - } - bytesInFlight -= protocol.ByteCount(n) * packetLength - } - - // Does not increment acked_packet_number_. - LosePacket := func(number protocol.PacketNumber) { - sender.OnPacketLost(number, maxDatagramSize, bytesInFlight) - bytesInFlight -= maxDatagramSize - } - - SendAvailableSendWindow := func() int { return SendAvailableSendWindowLen(maxDatagramSize) } - LoseNPackets := func(n int) { LoseNPacketsLen(n, maxDatagramSize) } - - It("has the right values at startup", func() { - // At startup make sure we are at the default. - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - // Make sure we can send. - Expect(sender.TimeUntilSend(0)).To(BeZero()) - Expect(sender.CanSend(bytesInFlight)).To(BeTrue()) - // And that window is un-affected. - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - - // Fill the send window with data, then verify that we can't send. - SendAvailableSendWindow() - Expect(sender.CanSend(bytesInFlight)).To(BeFalse()) - }) - - It("paces", func() { - rttStats.UpdateRTT(10*time.Millisecond, 0, time.Now()) - clock.Advance(time.Hour) - // Fill the send window with data, then verify that we can't send. - SendAvailableSendWindow() - AckNPackets(1) - delay := sender.TimeUntilSend(bytesInFlight) - Expect(delay).ToNot(BeZero()) - Expect(delay).ToNot(Equal(utils.InfDuration)) - }) - - It("application limited slow start", func() { - // Send exactly 10 packets and ensure the CWND ends at 14 packets. - const numberOfAcks = 5 - // At startup make sure we can send. - Expect(sender.CanSend(0)).To(BeTrue()) - Expect(sender.TimeUntilSend(0)).To(BeZero()) - - SendAvailableSendWindow() - for i := 0; i < numberOfAcks; i++ { - AckNPackets(2) - } - bytesToSend := sender.GetCongestionWindow() - // It's expected 2 acks will arrive when the bytes_in_flight are greater than - // half the CWND. - Expect(bytesToSend).To(Equal(defaultWindowTCP + maxDatagramSize*2*2)) - }) - - It("exponential slow start", func() { - const numberOfAcks = 20 - // At startup make sure we can send. - Expect(sender.CanSend(0)).To(BeTrue()) - Expect(sender.TimeUntilSend(0)).To(BeZero()) - Expect(sender.BandwidthEstimate()).To(Equal(infBandwidth)) - // Make sure we can send. - Expect(sender.TimeUntilSend(0)).To(BeZero()) - - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - cwnd := sender.GetCongestionWindow() - Expect(cwnd).To(Equal(defaultWindowTCP + maxDatagramSize*2*numberOfAcks)) - Expect(sender.BandwidthEstimate()).To(Equal(BandwidthFromDelta(cwnd, rttStats.SmoothedRTT()))) - }) - - It("slow start packet loss", func() { - const numberOfAcks = 10 - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - SendAvailableSendWindow() - expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Lose a packet to exit slow start. - LoseNPackets(1) - packetsInRecoveryWindow := expectedSendWindow / maxDatagramSize - - // We should now have fallen out of slow start with a reduced window. - expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Recovery phase. We need to ack every packet in the recovery window before - // we exit recovery. - numberOfPacketsInWindow := expectedSendWindow / maxDatagramSize - AckNPackets(int(packetsInRecoveryWindow)) - SendAvailableSendWindow() - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // We need to ack an entire window before we increase CWND by 1. - AckNPackets(int(numberOfPacketsInWindow) - 2) - SendAvailableSendWindow() - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Next ack should increase cwnd by 1. - AckNPackets(1) - expectedSendWindow += maxDatagramSize - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Now RTO and ensure slow start gets reset. - Expect(sender.hybridSlowStart.Started()).To(BeTrue()) - sender.OnRetransmissionTimeout(true) - Expect(sender.hybridSlowStart.Started()).To(BeFalse()) - }) - - It("slow start packet loss PRR", func() { - // Test based on the first example in RFC6937. - // Ack 10 packets in 5 acks to raise the CWND to 20, as in the example. - const numberOfAcks = 5 - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - SendAvailableSendWindow() - expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - LoseNPackets(1) - - // We should now have fallen out of slow start with a reduced window. - sendWindowBeforeLoss := expectedSendWindow - expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Testing TCP proportional rate reduction. - // We should send packets paced over the received acks for the remaining - // outstanding packets. The number of packets before we exit recovery is the - // original CWND minus the packet that has been lost and the one which - // triggered the loss. - remainingPacketsInRecovery := sendWindowBeforeLoss/maxDatagramSize - 2 - - for i := protocol.ByteCount(0); i < remainingPacketsInRecovery; i++ { - AckNPackets(1) - SendAvailableSendWindow() - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - } - - // We need to ack another window before we increase CWND by 1. - numberOfPacketsInWindow := expectedSendWindow / maxDatagramSize - for i := protocol.ByteCount(0); i < numberOfPacketsInWindow; i++ { - AckNPackets(1) - Expect(SendAvailableSendWindow()).To(Equal(1)) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - } - - AckNPackets(1) - expectedSendWindow += maxDatagramSize - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - }) - - It("slow start burst packet loss PRR", func() { - // Test based on the second example in RFC6937, though we also implement - // forward acknowledgements, so the first two incoming acks will trigger - // PRR immediately. - // Ack 20 packets in 10 acks to raise the CWND to 30. - const numberOfAcks = 10 - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - SendAvailableSendWindow() - expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Lose one more than the congestion window reduction, so that after loss, - // bytes_in_flight is lesser than the congestion window. - sendWindowAfterLoss := protocol.ByteCount(renoBeta * float32(expectedSendWindow)) - numPacketsToLose := (expectedSendWindow-sendWindowAfterLoss)/maxDatagramSize + 1 - LoseNPackets(int(numPacketsToLose)) - // Immediately after the loss, ensure at least one packet can be sent. - // Losses without subsequent acks can occur with timer based loss detection. - Expect(sender.CanSend(bytesInFlight)).To(BeTrue()) - AckNPackets(1) - - // We should now have fallen out of slow start with a reduced window. - expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Only 2 packets should be allowed to be sent, per PRR-SSRB - Expect(SendAvailableSendWindow()).To(Equal(2)) - - // Ack the next packet, which triggers another loss. - LoseNPackets(1) - AckNPackets(1) - - // Send 2 packets to simulate PRR-SSRB. - Expect(SendAvailableSendWindow()).To(Equal(2)) - - // Ack the next packet, which triggers another loss. - LoseNPackets(1) - AckNPackets(1) - - // Send 2 packets to simulate PRR-SSRB. - Expect(SendAvailableSendWindow()).To(Equal(2)) - - // Exit recovery and return to sending at the new rate. - for i := 0; i < numberOfAcks; i++ { - AckNPackets(1) - Expect(SendAvailableSendWindow()).To(Equal(1)) - } - }) - - It("RTO congestion window", func() { - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - Expect(sender.slowStartThreshold).To(Equal(protocol.MaxByteCount)) - - // Expect the window to decrease to the minimum once the RTO fires - // and slow start threshold to be set to 1/2 of the CWND. - sender.OnRetransmissionTimeout(true) - Expect(sender.GetCongestionWindow()).To(Equal(2 * maxDatagramSize)) - Expect(sender.slowStartThreshold).To(Equal(5 * maxDatagramSize)) - }) - - It("RTO congestion window no retransmission", func() { - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - - // Expect the window to remain unchanged if the RTO fires but no - // packets are retransmitted. - sender.OnRetransmissionTimeout(false) - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - }) - - It("tcp cubic reset epoch on quiescence", func() { - const maxCongestionWindow = 50 - const maxCongestionWindowBytes = maxCongestionWindow * maxDatagramSize - sender = newCubicSender(&clock, rttStats, false, protocol.InitialPacketSizeIPv4, initialCongestionWindowPackets*maxDatagramSize, maxCongestionWindowBytes, nil) - - numSent := SendAvailableSendWindow() - - // Make sure we fall out of slow start. - savedCwnd := sender.GetCongestionWindow() - LoseNPackets(1) - Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow())) - - // Ack the rest of the outstanding packets to get out of recovery. - for i := 1; i < numSent; i++ { - AckNPackets(1) - } - Expect(bytesInFlight).To(BeZero()) - - // Send a new window of data and ack all; cubic growth should occur. - savedCwnd = sender.GetCongestionWindow() - numSent = SendAvailableSendWindow() - for i := 0; i < numSent; i++ { - AckNPackets(1) - } - Expect(savedCwnd).To(BeNumerically("<", sender.GetCongestionWindow())) - Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow())) - Expect(bytesInFlight).To(BeZero()) - - // Quiescent time of 100 seconds - clock.Advance(100 * time.Second) - - // Send new window of data and ack one packet. Cubic epoch should have - // been reset; ensure cwnd increase is not dramatic. - savedCwnd = sender.GetCongestionWindow() - SendAvailableSendWindow() - AckNPackets(1) - Expect(savedCwnd).To(BeNumerically("~", sender.GetCongestionWindow(), maxDatagramSize)) - Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow())) - }) - - It("multiple losses in one window", func() { - SendAvailableSendWindow() - initialWindow := sender.GetCongestionWindow() - LosePacket(ackedPacketNumber + 1) - postLossWindow := sender.GetCongestionWindow() - Expect(initialWindow).To(BeNumerically(">", postLossWindow)) - LosePacket(ackedPacketNumber + 3) - Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow)) - LosePacket(packetNumber - 1) - Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow)) - - // Lose a later packet and ensure the window decreases. - LosePacket(packetNumber) - Expect(postLossWindow).To(BeNumerically(">", sender.GetCongestionWindow())) - }) - - It("1 connection congestion avoidance at end of recovery", func() { - // Ack 10 packets in 5 acks to raise the CWND to 20. - const numberOfAcks = 5 - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - SendAvailableSendWindow() - expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - LoseNPackets(1) - - // We should now have fallen out of slow start with a reduced window. - expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // No congestion window growth should occur in recovery phase, i.e., until the - // currently outstanding 20 packets are acked. - for i := 0; i < 10; i++ { - // Send our full send window. - SendAvailableSendWindow() - Expect(sender.InRecovery()).To(BeTrue()) - AckNPackets(2) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - } - Expect(sender.InRecovery()).To(BeFalse()) - - // Out of recovery now. Congestion window should not grow during RTT. - for i := protocol.ByteCount(0); i < expectedSendWindow/maxDatagramSize-2; i += 2 { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - } - - // Next ack should cause congestion window to grow by 1MSS. - SendAvailableSendWindow() - AckNPackets(2) - expectedSendWindow += maxDatagramSize - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - }) - - It("no PRR", func() { - SendAvailableSendWindow() - LoseNPackets(9) - AckNPackets(1) - - Expect(sender.GetCongestionWindow()).To(Equal(protocol.ByteCount(renoBeta * float32(defaultWindowTCP)))) - windowInPackets := renoBeta * float32(defaultWindowTCP) / float32(maxDatagramSize) - numSent := SendAvailableSendWindow() - Expect(numSent).To(BeEquivalentTo(windowInPackets)) - }) - - It("reset after connection migration", func() { - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - Expect(sender.slowStartThreshold).To(Equal(protocol.MaxByteCount)) - - // Starts with slow start. - const numberOfAcks = 10 - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - SendAvailableSendWindow() - expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Loses a packet to exit slow start. - LoseNPackets(1) - - // We should now have fallen out of slow start with a reduced window. Slow - // start threshold is also updated. - expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - Expect(sender.slowStartThreshold).To(Equal(expectedSendWindow)) - - // Resets cwnd and slow start threshold on connection migrations. - sender.OnConnectionMigration() - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - Expect(sender.slowStartThreshold).To(Equal(MaxCongestionWindow)) - Expect(sender.hybridSlowStart.Started()).To(BeFalse()) - }) - - It("slow starts up to the maximum congestion window", func() { - const initialMaxCongestionWindow = protocol.MaxCongestionWindowPackets * initialMaxDatagramSize - sender = newCubicSender(&clock, rttStats, true, protocol.InitialPacketSizeIPv4, initialCongestionWindowPackets*maxDatagramSize, initialMaxCongestionWindow, nil) - - for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { - sender.MaybeExitSlowStart() - sender.OnPacketAcked(protocol.PacketNumber(i), 1350, sender.GetCongestionWindow(), clock.Now()) - } - Expect(sender.GetCongestionWindow()).To(Equal(initialMaxCongestionWindow)) - }) - - It("doesn't allow reductions of the maximum packet size", func() { - Expect(func() { sender.SetMaxDatagramSize(initialMaxDatagramSize - 1) }).To(Panic()) - }) - - It("slow starts up to maximum congestion window, if larger packets are sent", func() { - const initialMaxCongestionWindow = protocol.MaxCongestionWindowPackets * initialMaxDatagramSize - sender = newCubicSender(&clock, rttStats, true, protocol.InitialPacketSizeIPv4, initialCongestionWindowPackets*maxDatagramSize, initialMaxCongestionWindow, nil) - const packetSize = initialMaxDatagramSize + 100 - sender.SetMaxDatagramSize(packetSize) - for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { - sender.OnPacketAcked(protocol.PacketNumber(i), packetSize, sender.GetCongestionWindow(), clock.Now()) - } - const maxCwnd = protocol.MaxCongestionWindowPackets * packetSize - Expect(sender.GetCongestionWindow()).To(And( - BeNumerically(">", maxCwnd), - BeNumerically("<=", maxCwnd+packetSize), - )) - }) - - It("limit cwnd increase in congestion avoidance", func() { - // Enable Cubic. - sender = newCubicSender(&clock, rttStats, false, protocol.InitialPacketSizeIPv4, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow, nil) - numSent := SendAvailableSendWindow() - - // Make sure we fall out of slow start. - savedCwnd := sender.GetCongestionWindow() - LoseNPackets(1) - Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow())) - - // Ack the rest of the outstanding packets to get out of recovery. - for i := 1; i < numSent; i++ { - AckNPackets(1) - } - Expect(bytesInFlight).To(BeZero()) - - savedCwnd = sender.GetCongestionWindow() - SendAvailableSendWindow() - - // Ack packets until the CWND increases. - for sender.GetCongestionWindow() == savedCwnd { - AckNPackets(1) - SendAvailableSendWindow() - } - // Bytes in flight may be larger than the CWND if the CWND isn't an exact - // multiple of the packet sizes being sent. - Expect(bytesInFlight).To(BeNumerically(">=", sender.GetCongestionWindow())) - savedCwnd = sender.GetCongestionWindow() - - // Advance time 2 seconds waiting for an ack. - clock.Advance(2 * time.Second) - - // Ack two packets. The CWND should increase by only one packet. - AckNPackets(2) - Expect(sender.GetCongestionWindow()).To(Equal(savedCwnd + maxDatagramSize)) - }) -}) diff --git a/internal/quic-go/congestion/cubic_test.go b/internal/quic-go/congestion/cubic_test.go deleted file mode 100644 index e2fc5d33..00000000 --- a/internal/quic-go/congestion/cubic_test.go +++ /dev/null @@ -1,239 +0,0 @@ -package congestion - -import ( - "math" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -const ( - numConnections uint32 = 2 - nConnectionBeta float32 = (float32(numConnections) - 1 + beta) / float32(numConnections) - nConnectionBetaLastMax float32 = (float32(numConnections) - 1 + betaLastMax) / float32(numConnections) - nConnectionAlpha float32 = 3 * float32(numConnections) * float32(numConnections) * (1 - nConnectionBeta) / (1 + nConnectionBeta) - maxCubicTimeInterval = 30 * time.Millisecond -) - -var _ = Describe("Cubic", func() { - var ( - clock mockClock - cubic *Cubic - ) - - BeforeEach(func() { - clock = mockClock{} - cubic = NewCubic(&clock) - cubic.SetNumConnections(int(numConnections)) - }) - - renoCwnd := func(currentCwnd protocol.ByteCount) protocol.ByteCount { - return currentCwnd + protocol.ByteCount(float32(maxDatagramSize)*nConnectionAlpha*float32(maxDatagramSize)/float32(currentCwnd)) - } - - cubicConvexCwnd := func(initialCwnd protocol.ByteCount, rtt, elapsedTime time.Duration) protocol.ByteCount { - offset := protocol.ByteCount((elapsedTime+rtt)/time.Microsecond) << 10 / 1000000 - deltaCongestionWindow := 410 * offset * offset * offset * maxDatagramSize >> 40 - return initialCwnd + deltaCongestionWindow - } - - It("works above origin (with tighter bounds)", func() { - // Convex growth. - const rttMin = 100 * time.Millisecond - const rttMinS = float32(rttMin/time.Millisecond) / 1000.0 - currentCwnd := 10 * maxDatagramSize - initialCwnd := currentCwnd - - clock.Advance(time.Millisecond) - initialTime := clock.Now() - expectedFirstCwnd := renoCwnd(currentCwnd) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, initialTime) - Expect(expectedFirstCwnd).To(Equal(currentCwnd)) - - // Normal TCP phase. - // The maximum number of expected reno RTTs can be calculated by - // finding the point where the cubic curve and the reno curve meet. - maxRenoRtts := int(math.Sqrt(float64(nConnectionAlpha/(0.4*rttMinS*rttMinS*rttMinS))) - 2) - for i := 0; i < maxRenoRtts; i++ { - // Alternatively, we expect it to increase by one, every time we - // receive current_cwnd/Alpha acks back. (This is another way of - // saying we expect cwnd to increase by approximately Alpha once - // we receive current_cwnd number ofacks back). - numAcksThisEpoch := int(float32(currentCwnd/maxDatagramSize) / nConnectionAlpha) - - initialCwndThisEpoch := currentCwnd - for n := 0; n < numAcksThisEpoch; n++ { - // Call once per ACK. - expectedNextCwnd := renoCwnd(currentCwnd) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - Expect(currentCwnd).To(Equal(expectedNextCwnd)) - } - // Our byte-wise Reno implementation is an estimate. We expect - // the cwnd to increase by approximately one MSS every - // cwnd/kDefaultTCPMSS/Alpha acks, but it may be off by as much as - // half a packet for smaller values of current_cwnd. - cwndChangeThisEpoch := currentCwnd - initialCwndThisEpoch - Expect(cwndChangeThisEpoch).To(BeNumerically("~", maxDatagramSize, maxDatagramSize/2)) - clock.Advance(100 * time.Millisecond) - } - - for i := 0; i < 54; i++ { - maxAcksThisEpoch := currentCwnd / maxDatagramSize - interval := time.Duration(100*1000/maxAcksThisEpoch) * time.Microsecond - for n := 0; n < int(maxAcksThisEpoch); n++ { - clock.Advance(interval) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) - // If we allow per-ack updates, every update is a small cubic update. - Expect(currentCwnd).To(Equal(expectedCwnd)) - } - } - expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - Expect(currentCwnd).To(Equal(expectedCwnd)) - }) - - It("works above the origin with fine grained cubing", func() { - // Start the test with an artificially large cwnd to prevent Reno - // from over-taking cubic. - currentCwnd := 1000 * maxDatagramSize - initialCwnd := currentCwnd - rttMin := 100 * time.Millisecond - clock.Advance(time.Millisecond) - initialTime := clock.Now() - - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - clock.Advance(600 * time.Millisecond) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - - // We expect the algorithm to perform only non-zero, fine-grained cubic - // increases on every ack in this case. - for i := 0; i < 100; i++ { - clock.Advance(10 * time.Millisecond) - expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) - nextCwnd := cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - // Make sure we are performing cubic increases. - Expect(nextCwnd).To(Equal(expectedCwnd)) - // Make sure that these are non-zero, less-than-packet sized increases. - Expect(nextCwnd).To(BeNumerically(">", currentCwnd)) - cwndDelta := nextCwnd - currentCwnd - Expect(maxDatagramSize / 10).To(BeNumerically(">", cwndDelta)) - currentCwnd = nextCwnd - } - }) - - It("handles per ack updates", func() { - // Start the test with a large cwnd and RTT, to force the first - // increase to be a cubic increase. - initialCwndPackets := 150 - currentCwnd := protocol.ByteCount(initialCwndPackets) * maxDatagramSize - rttMin := 350 * time.Millisecond - - // Initialize the epoch - clock.Advance(time.Millisecond) - // Keep track of the growth of the reno-equivalent cwnd. - rCwnd := renoCwnd(currentCwnd) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - initialCwnd := currentCwnd - - // Simulate the return of cwnd packets in less than - // MaxCubicInterval() time. - maxAcks := int(float32(initialCwndPackets) / nConnectionAlpha) - interval := maxCubicTimeInterval / time.Duration(maxAcks+1) - - // In this scenario, the first increase is dictated by the cubic - // equation, but it is less than one byte, so the cwnd doesn't - // change. Normally, without per-ack increases, any cwnd plateau - // will cause the cwnd to be pinned for MaxCubicTimeInterval(). If - // we enable per-ack updates, the cwnd will continue to grow, - // regardless of the temporary plateau. - clock.Advance(interval) - rCwnd = renoCwnd(rCwnd) - Expect(cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())).To(Equal(currentCwnd)) - for i := 1; i < maxAcks; i++ { - clock.Advance(interval) - nextCwnd := cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - rCwnd = renoCwnd(rCwnd) - // The window shoud increase on every ack. - Expect(nextCwnd).To(BeNumerically(">", currentCwnd)) - Expect(nextCwnd).To(Equal(rCwnd)) - currentCwnd = nextCwnd - } - - // After all the acks are returned from the epoch, we expect the - // cwnd to have increased by nearly one packet. (Not exactly one - // packet, because our byte-wise Reno algorithm is always a slight - // under-estimation). Without per-ack updates, the current_cwnd - // would otherwise be unchanged. - minimumExpectedIncrease := maxDatagramSize * 9 / 10 - Expect(currentCwnd).To(BeNumerically(">", initialCwnd+minimumExpectedIncrease)) - }) - - It("handles loss events", func() { - rttMin := 100 * time.Millisecond - currentCwnd := 422 * maxDatagramSize - expectedCwnd := renoCwnd(currentCwnd) - // Initialize the state. - clock.Advance(time.Millisecond) - Expect(cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd)) - - // On the first loss, the last max congestion window is set to the - // congestion window before the loss. - preLossCwnd := currentCwnd - Expect(cubic.lastMaxCongestionWindow).To(BeZero()) - expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) - Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) - Expect(cubic.lastMaxCongestionWindow).To(Equal(preLossCwnd)) - currentCwnd = expectedCwnd - - // On the second loss, the current congestion window has not yet - // reached the last max congestion window. The last max congestion - // window will be reduced by an additional backoff factor to allow - // for competition. - preLossCwnd = currentCwnd - expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) - Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) - currentCwnd = expectedCwnd - Expect(preLossCwnd).To(BeNumerically(">", cubic.lastMaxCongestionWindow)) - expectedLastMax := protocol.ByteCount(float32(preLossCwnd) * nConnectionBetaLastMax) - Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax)) - Expect(expectedCwnd).To(BeNumerically("<", cubic.lastMaxCongestionWindow)) - // Simulate an increase, and check that we are below the origin. - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - Expect(cubic.lastMaxCongestionWindow).To(BeNumerically(">", currentCwnd)) - - // On the final loss, simulate the condition where the congestion - // window had a chance to grow nearly to the last congestion window. - currentCwnd = cubic.lastMaxCongestionWindow - 1 - preLossCwnd = currentCwnd - expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) - Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) - expectedLastMax = preLossCwnd - Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax)) - }) - - It("works below origin", func() { - // Concave growth. - rttMin := 100 * time.Millisecond - currentCwnd := 422 * maxDatagramSize - expectedCwnd := renoCwnd(currentCwnd) - // Initialize the state. - clock.Advance(time.Millisecond) - Expect(cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd)) - - expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) - Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) - currentCwnd = expectedCwnd - // First update after loss to initialize the epoch. - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - // Cubic phase. - for i := 0; i < 40; i++ { - clock.Advance(100 * time.Millisecond) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - } - expectedCwnd = 553632 * maxDatagramSize / 1460 - Expect(currentCwnd).To(Equal(expectedCwnd)) - }) -}) diff --git a/internal/quic-go/congestion/hybrid_slow_start.go b/internal/quic-go/congestion/hybrid_slow_start.go deleted file mode 100644 index 035bc0da..00000000 --- a/internal/quic-go/congestion/hybrid_slow_start.go +++ /dev/null @@ -1,113 +0,0 @@ -package congestion - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -// Note(pwestin): the magic clamping numbers come from the original code in -// tcp_cubic.c. -const hybridStartLowWindow = protocol.ByteCount(16) - -// Number of delay samples for detecting the increase of delay. -const hybridStartMinSamples = uint32(8) - -// Exit slow start if the min rtt has increased by more than 1/8th. -const hybridStartDelayFactorExp = 3 // 2^3 = 8 -// The original paper specifies 2 and 8ms, but those have changed over time. -const ( - hybridStartDelayMinThresholdUs = int64(4000) - hybridStartDelayMaxThresholdUs = int64(16000) -) - -// HybridSlowStart implements the TCP hybrid slow start algorithm -type HybridSlowStart struct { - endPacketNumber protocol.PacketNumber - lastSentPacketNumber protocol.PacketNumber - started bool - currentMinRTT time.Duration - rttSampleCount uint32 - hystartFound bool -} - -// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase. -func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) { - s.endPacketNumber = lastSent - s.currentMinRTT = 0 - s.rttSampleCount = 0 - s.started = true -} - -// IsEndOfRound returns true if this ack is the last packet number of our current slow start round. -func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool { - return s.endPacketNumber < ack -} - -// ShouldExitSlowStart should be called on every new ack frame, since a new -// RTT measurement can be made then. -// rtt: the RTT for this ack packet. -// minRTT: is the lowest delay (RTT) we have seen during the session. -// congestionWindow: the congestion window in packets. -func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool { - if !s.started { - // Time to start the hybrid slow start. - s.StartReceiveRound(s.lastSentPacketNumber) - } - if s.hystartFound { - return true - } - // Second detection parameter - delay increase detection. - // Compare the minimum delay (s.currentMinRTT) of the current - // burst of packets relative to the minimum delay during the session. - // Note: we only look at the first few(8) packets in each burst, since we - // only want to compare the lowest RTT of the burst relative to previous - // bursts. - s.rttSampleCount++ - if s.rttSampleCount <= hybridStartMinSamples { - if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT { - s.currentMinRTT = latestRTT - } - } - // We only need to check this once per round. - if s.rttSampleCount == hybridStartMinSamples { - // Divide minRTT by 8 to get a rtt increase threshold for exiting. - minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp) - // Ensure the rtt threshold is never less than 2ms or more than 16ms. - minRTTincreaseThresholdUs = utils.MinInt64(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) - minRTTincreaseThreshold := time.Duration(utils.MaxInt64(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond - - if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) { - s.hystartFound = true - } - } - // Exit from slow start if the cwnd is greater than 16 and - // increasing delay is found. - return congestionWindow >= hybridStartLowWindow && s.hystartFound -} - -// OnPacketSent is called when a packet was sent -func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) { - s.lastSentPacketNumber = packetNumber -} - -// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end -// the round when the final packet of the burst is received and start it on -// the next incoming ack. -func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) { - if s.IsEndOfRound(ackedPacketNumber) { - s.started = false - } -} - -// Started returns true if started -func (s *HybridSlowStart) Started() bool { - return s.started -} - -// Restart the slow start phase -func (s *HybridSlowStart) Restart() { - s.started = false - s.hystartFound = false -} diff --git a/internal/quic-go/congestion/hybrid_slow_start_test.go b/internal/quic-go/congestion/hybrid_slow_start_test.go deleted file mode 100644 index 6de9ca8e..00000000 --- a/internal/quic-go/congestion/hybrid_slow_start_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package congestion - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Hybrid slow start", func() { - var slowStart HybridSlowStart - - BeforeEach(func() { - slowStart = HybridSlowStart{} - }) - - It("works in a simple case", func() { - packetNumber := protocol.PacketNumber(1) - endPacketNumber := protocol.PacketNumber(3) - slowStart.StartReceiveRound(endPacketNumber) - - packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) - - // Test duplicates. - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) - - packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) - packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue()) - - // Test without a new registered end_packet_number; - packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue()) - - endPacketNumber = 20 - slowStart.StartReceiveRound(endPacketNumber) - for packetNumber < endPacketNumber { - packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) - } - packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue()) - }) - - It("works with delay", func() { - rtt := 60 * time.Millisecond - // We expect to detect the increase at +1/8 of the RTT; hence at a typical - // RTT of 60ms the detection will happen at 67.5 ms. - const hybridStartMinSamples = 8 // Number of acks required to trigger. - - endPacketNumber := protocol.PacketNumber(1) - endPacketNumber++ - slowStart.StartReceiveRound(endPacketNumber) - - // Will not trigger since our lowest RTT in our burst is the same as the long - // term RTT provided. - for n := 0; n < hybridStartMinSamples; n++ { - Expect(slowStart.ShouldExitSlowStart(rtt+time.Duration(n)*time.Millisecond, rtt, 100)).To(BeFalse()) - } - endPacketNumber++ - slowStart.StartReceiveRound(endPacketNumber) - for n := 1; n < hybridStartMinSamples; n++ { - Expect(slowStart.ShouldExitSlowStart(rtt+(time.Duration(n)+10)*time.Millisecond, rtt, 100)).To(BeFalse()) - } - // Expect to trigger since all packets in this burst was above the long term - // RTT provided. - Expect(slowStart.ShouldExitSlowStart(rtt+10*time.Millisecond, rtt, 100)).To(BeTrue()) - }) -}) diff --git a/internal/quic-go/congestion/interface.go b/internal/quic-go/congestion/interface.go deleted file mode 100644 index f56ed395..00000000 --- a/internal/quic-go/congestion/interface.go +++ /dev/null @@ -1,28 +0,0 @@ -package congestion - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// A SendAlgorithm performs congestion control -type SendAlgorithm interface { - TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time - HasPacingBudget() bool - OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) - CanSend(bytesInFlight protocol.ByteCount) bool - MaybeExitSlowStart() - OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time) - OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) - OnRetransmissionTimeout(packetsRetransmitted bool) - SetMaxDatagramSize(protocol.ByteCount) -} - -// A SendAlgorithmWithDebugInfos is a SendAlgorithm that exposes some debug infos -type SendAlgorithmWithDebugInfos interface { - SendAlgorithm - InSlowStart() bool - InRecovery() bool - GetCongestionWindow() protocol.ByteCount -} diff --git a/internal/quic-go/congestion/pacer.go b/internal/quic-go/congestion/pacer.go deleted file mode 100644 index 0dd26607..00000000 --- a/internal/quic-go/congestion/pacer.go +++ /dev/null @@ -1,77 +0,0 @@ -package congestion - -import ( - "math" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -const maxBurstSizePackets = 10 - -// The pacer implements a token bucket pacing algorithm. -type pacer struct { - budgetAtLastSent protocol.ByteCount - maxDatagramSize protocol.ByteCount - lastSentTime time.Time - getAdjustedBandwidth func() uint64 // in bytes/s -} - -func newPacer(getBandwidth func() Bandwidth) *pacer { - p := &pacer{ - maxDatagramSize: initialMaxDatagramSize, - getAdjustedBandwidth: func() uint64 { - // Bandwidth is in bits/s. We need the value in bytes/s. - bw := uint64(getBandwidth() / BytesPerSecond) - // Use a slightly higher value than the actual measured bandwidth. - // RTT variations then won't result in under-utilization of the congestion window. - // Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire, - // provided the congestion window is fully utilized and acknowledgments arrive at regular intervals. - return bw * 5 / 4 - }, - } - p.budgetAtLastSent = p.maxBurstSize() - return p -} - -func (p *pacer) SentPacket(sendTime time.Time, size protocol.ByteCount) { - budget := p.Budget(sendTime) - if size > budget { - p.budgetAtLastSent = 0 - } else { - p.budgetAtLastSent = budget - size - } - p.lastSentTime = sendTime -} - -func (p *pacer) Budget(now time.Time) protocol.ByteCount { - if p.lastSentTime.IsZero() { - return p.maxBurstSize() - } - budget := p.budgetAtLastSent + (protocol.ByteCount(p.getAdjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 - return utils.MinByteCount(p.maxBurstSize(), budget) -} - -func (p *pacer) maxBurstSize() protocol.ByteCount { - return utils.MaxByteCount( - protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9, - maxBurstSizePackets*p.maxDatagramSize, - ) -} - -// TimeUntilSend returns when the next packet should be sent. -// It returns the zero value of time.Time if a packet can be sent immediately. -func (p *pacer) TimeUntilSend() time.Time { - if p.budgetAtLastSent >= p.maxDatagramSize { - return time.Time{} - } - return p.lastSentTime.Add(utils.MaxDuration( - protocol.MinPacingDelay, - time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond, - )) -} - -func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) { - p.maxDatagramSize = s -} diff --git a/internal/quic-go/congestion/pacer_test.go b/internal/quic-go/congestion/pacer_test.go deleted file mode 100644 index e840ff22..00000000 --- a/internal/quic-go/congestion/pacer_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package congestion - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Pacer", func() { - var p *pacer - - const packetsPerSecond = 50 - var bandwidth uint64 // in bytes/s - - BeforeEach(func() { - bandwidth = uint64(packetsPerSecond * initialMaxDatagramSize) // 50 full-size packets per second - // The pacer will multiply the bandwidth with 1.25 to achieve a slightly higher pacing speed. - // For the tests, cancel out this factor, so we can do the math using the exact bandwidth. - p = newPacer(func() Bandwidth { return Bandwidth(bandwidth) * BytesPerSecond * 4 / 5 }) - }) - - It("allows a burst at the beginning", func() { - t := time.Now() - Expect(p.TimeUntilSend()).To(BeZero()) - Expect(p.Budget(t)).To(BeEquivalentTo(maxBurstSizePackets * initialMaxDatagramSize)) - }) - - It("allows a big burst for high pacing rates", func() { - t := time.Now() - bandwidth = uint64(10000 * packetsPerSecond * initialMaxDatagramSize) - Expect(p.TimeUntilSend()).To(BeZero()) - Expect(p.Budget(t)).To(BeNumerically(">", maxBurstSizePackets*initialMaxDatagramSize)) - }) - - It("reduces the budget when sending packets", func() { - t := time.Now() - budget := p.Budget(t) - for budget > 0 { - Expect(p.TimeUntilSend()).To(BeZero()) - Expect(p.Budget(t)).To(Equal(budget)) - p.SentPacket(t, initialMaxDatagramSize) - budget -= initialMaxDatagramSize - } - Expect(p.Budget(t)).To(BeZero()) - Expect(p.TimeUntilSend()).ToNot(BeZero()) - }) - - sendBurst := func(t time.Time) { - for p.Budget(t) > 0 { - p.SentPacket(t, initialMaxDatagramSize) - } - } - - It("paces packets after a burst", func() { - t := time.Now() - sendBurst(t) - // send 100 exactly paced packets - for i := 0; i < 100; i++ { - t2 := p.TimeUntilSend() - Expect(t2.Sub(t)).To(BeNumerically("~", time.Second/packetsPerSecond, time.Nanosecond)) - Expect(p.Budget(t2)).To(BeEquivalentTo(initialMaxDatagramSize)) - p.SentPacket(t2, initialMaxDatagramSize) - t = t2 - } - }) - - It("accounts for non-full-size packets", func() { - t := time.Now() - sendBurst(t) - t2 := p.TimeUntilSend() - Expect(t2.Sub(t)).To(BeNumerically("~", time.Second/packetsPerSecond, time.Nanosecond)) - // send a half-full packet - Expect(p.Budget(t2)).To(BeEquivalentTo(initialMaxDatagramSize)) - size := initialMaxDatagramSize / 2 - p.SentPacket(t2, size) - Expect(p.Budget(t2)).To(Equal(initialMaxDatagramSize - size)) - Expect(p.TimeUntilSend()).To(BeTemporally("~", t2.Add(time.Second/packetsPerSecond/2), time.Nanosecond)) - }) - - It("accumulates budget, if no packets are sent", func() { - t := time.Now() - sendBurst(t) - t2 := p.TimeUntilSend() - Expect(t2).To(BeTemporally(">", t)) - // wait for 5 times the duration - Expect(p.Budget(t.Add(5 * t2.Sub(t)))).To(BeEquivalentTo(5 * initialMaxDatagramSize)) - }) - - It("accumulates budget, if no packets are sent, for larger packet sizes", func() { - t := time.Now() - sendBurst(t) - const packetSize = initialMaxDatagramSize + 200 - p.SetMaxDatagramSize(packetSize) - t2 := p.TimeUntilSend() - Expect(t2).To(BeTemporally(">", t)) - // wait for 5 times the duration - Expect(p.Budget(t.Add(5 * t2.Sub(t)))).To(BeEquivalentTo(5 * packetSize)) - }) - - It("never allows bursts larger than the maximum burst size", func() { - t := time.Now() - sendBurst(t) - Expect(p.Budget(t.Add(time.Hour))).To(BeEquivalentTo(maxBurstSizePackets * initialMaxDatagramSize)) - }) - - It("never allows bursts larger than the maximum burst size, for larger packets", func() { - t := time.Now() - const packetSize = initialMaxDatagramSize + 200 - p.SetMaxDatagramSize(packetSize) - sendBurst(t) - Expect(p.Budget(t.Add(time.Hour))).To(BeEquivalentTo(maxBurstSizePackets * packetSize)) - }) - - It("changes the bandwidth", func() { - t := time.Now() - sendBurst(t) - bandwidth = uint64(5 * initialMaxDatagramSize) // reduce the bandwidth to 5 packet per second - Expect(p.TimeUntilSend()).To(Equal(t.Add(time.Second / 5))) - }) - - It("doesn't pace faster than the minimum pacing duration", func() { - t := time.Now() - sendBurst(t) - bandwidth = uint64(1e6 * initialMaxDatagramSize) - Expect(p.TimeUntilSend()).To(Equal(t.Add(protocol.MinPacingDelay))) - Expect(p.Budget(t.Add(protocol.MinPacingDelay))).To(Equal(protocol.ByteCount(protocol.MinPacingDelay) * initialMaxDatagramSize * 1e6 / 1e9)) - }) -}) diff --git a/internal/quic-go/conn_id_generator.go b/internal/quic-go/conn_id_generator.go deleted file mode 100644 index 10f30ae9..00000000 --- a/internal/quic-go/conn_id_generator.go +++ /dev/null @@ -1,140 +0,0 @@ -package quic - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type connIDGenerator struct { - connIDLen int - highestSeq uint64 - - activeSrcConnIDs map[uint64]protocol.ConnectionID - initialClientDestConnID protocol.ConnectionID - - addConnectionID func(protocol.ConnectionID) - getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken - removeConnectionID func(protocol.ConnectionID) - retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func(protocol.ConnectionID, packetHandler) - queueControlFrame func(wire.Frame) - - version protocol.VersionNumber -} - -func newConnIDGenerator( - initialConnectionID protocol.ConnectionID, - initialClientDestConnID protocol.ConnectionID, // nil for the client - addConnectionID func(protocol.ConnectionID), - getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, - removeConnectionID func(protocol.ConnectionID), - retireConnectionID func(protocol.ConnectionID), - replaceWithClosed func(protocol.ConnectionID, packetHandler), - queueControlFrame func(wire.Frame), - version protocol.VersionNumber, -) *connIDGenerator { - m := &connIDGenerator{ - connIDLen: initialConnectionID.Len(), - activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), - addConnectionID: addConnectionID, - getStatelessResetToken: getStatelessResetToken, - removeConnectionID: removeConnectionID, - retireConnectionID: retireConnectionID, - replaceWithClosed: replaceWithClosed, - queueControlFrame: queueControlFrame, - version: version, - } - m.activeSrcConnIDs[0] = initialConnectionID - m.initialClientDestConnID = initialClientDestConnID - return m -} - -func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { - if m.connIDLen == 0 { - return nil - } - // The active_connection_id_limit transport parameter is the number of - // connection IDs the peer will store. This limit includes the connection ID - // used during the handshake, and the one sent in the preferred_address - // transport parameter. - // We currently don't send the preferred_address transport parameter, - // so we can issue (limit - 1) connection IDs. - for i := uint64(len(m.activeSrcConnIDs)); i < utils.MinUint64(limit, protocol.MaxIssuedConnectionIDs); i++ { - if err := m.issueNewConnID(); err != nil { - return err - } - } - return nil -} - -func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error { - if seq > m.highestSeq { - return &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq), - } - } - connID, ok := m.activeSrcConnIDs[seq] - // We might already have deleted this connection ID, if this is a duplicate frame. - if !ok { - return nil - } - if connID.Equal(sentWithDestConnID) { - return &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), - } - } - m.retireConnectionID(connID) - delete(m.activeSrcConnIDs, seq) - // Don't issue a replacement for the initial connection ID. - if seq == 0 { - return nil - } - return m.issueNewConnID() -} - -func (m *connIDGenerator) issueNewConnID() error { - connID, err := protocol.GenerateConnectionID(m.connIDLen) - if err != nil { - return err - } - m.activeSrcConnIDs[m.highestSeq+1] = connID - m.addConnectionID(connID) - m.queueControlFrame(&wire.NewConnectionIDFrame{ - SequenceNumber: m.highestSeq + 1, - ConnectionID: connID, - StatelessResetToken: m.getStatelessResetToken(connID), - }) - m.highestSeq++ - return nil -} - -func (m *connIDGenerator) SetHandshakeComplete() { - if m.initialClientDestConnID != nil { - m.retireConnectionID(m.initialClientDestConnID) - m.initialClientDestConnID = nil - } -} - -func (m *connIDGenerator) RemoveAll() { - if m.initialClientDestConnID != nil { - m.removeConnectionID(m.initialClientDestConnID) - } - for _, connID := range m.activeSrcConnIDs { - m.removeConnectionID(connID) - } -} - -func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) { - if m.initialClientDestConnID != nil { - m.replaceWithClosed(m.initialClientDestConnID, handler) - } - for _, connID := range m.activeSrcConnIDs { - m.replaceWithClosed(connID, handler) - } -} diff --git a/internal/quic-go/conn_id_generator_test.go b/internal/quic-go/conn_id_generator_test.go deleted file mode 100644 index 543fce4b..00000000 --- a/internal/quic-go/conn_id_generator_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package quic - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Connection ID Generator", func() { - var ( - addedConnIDs []protocol.ConnectionID - retiredConnIDs []protocol.ConnectionID - removedConnIDs []protocol.ConnectionID - replacedWithClosed map[string]packetHandler - queuedFrames []wire.Frame - g *connIDGenerator - ) - initialConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} - initialClientDestConnID := protocol.ConnectionID{0xa, 0xb, 0xc, 0xd, 0xe} - - connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken { - return protocol.StatelessResetToken{c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0]} - } - - BeforeEach(func() { - addedConnIDs = nil - retiredConnIDs = nil - removedConnIDs = nil - queuedFrames = nil - replacedWithClosed = make(map[string]packetHandler) - g = newConnIDGenerator( - initialConnID, - initialClientDestConnID, - func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) }, - connIDToToken, - func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, - func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, - func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h }, - func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, - protocol.VersionDraft29, - ) - }) - - It("issues new connection IDs", func() { - Expect(g.SetMaxActiveConnIDs(4)).To(Succeed()) - Expect(retiredConnIDs).To(BeEmpty()) - Expect(addedConnIDs).To(HaveLen(3)) - for i := 0; i < len(addedConnIDs)-1; i++ { - Expect(addedConnIDs[i]).ToNot(Equal(addedConnIDs[i+1])) - } - Expect(queuedFrames).To(HaveLen(3)) - for i := 0; i < 3; i++ { - f := queuedFrames[i] - Expect(f).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{})) - nf := f.(*wire.NewConnectionIDFrame) - Expect(nf.SequenceNumber).To(BeEquivalentTo(i + 1)) - Expect(nf.ConnectionID.Len()).To(Equal(7)) - Expect(nf.StatelessResetToken).To(Equal(connIDToToken(nf.ConnectionID))) - } - }) - - It("limits the number of connection IDs that it issues", func() { - Expect(g.SetMaxActiveConnIDs(9999999)).To(Succeed()) - Expect(retiredConnIDs).To(BeEmpty()) - Expect(addedConnIDs).To(HaveLen(protocol.MaxIssuedConnectionIDs - 1)) - Expect(queuedFrames).To(HaveLen(protocol.MaxIssuedConnectionIDs - 1)) - }) - - // SetMaxActiveConnIDs is called twice when we dialing a 0-RTT connection: - // once for the restored from the old connections, once when we receive the transport parameters - Context("dealing with 0-RTT", func() { - It("doesn't issue new connection IDs when SetMaxActiveConnIDs is called with the same value", func() { - Expect(g.SetMaxActiveConnIDs(4)).To(Succeed()) - Expect(queuedFrames).To(HaveLen(3)) - queuedFrames = nil - Expect(g.SetMaxActiveConnIDs(4)).To(Succeed()) - Expect(queuedFrames).To(BeEmpty()) - }) - - It("issues more connection IDs if the server allows a higher limit on the resumed connection", func() { - Expect(g.SetMaxActiveConnIDs(3)).To(Succeed()) - Expect(queuedFrames).To(HaveLen(2)) - queuedFrames = nil - Expect(g.SetMaxActiveConnIDs(6)).To(Succeed()) - Expect(queuedFrames).To(HaveLen(3)) - }) - - It("issues more connection IDs if the server allows a higher limit on the resumed connection, when connection IDs were retired in between", func() { - Expect(g.SetMaxActiveConnIDs(3)).To(Succeed()) - Expect(queuedFrames).To(HaveLen(2)) - queuedFrames = nil - g.Retire(1, protocol.ConnectionID{}) - Expect(queuedFrames).To(HaveLen(1)) - queuedFrames = nil - Expect(g.SetMaxActiveConnIDs(6)).To(Succeed()) - Expect(queuedFrames).To(HaveLen(3)) - }) - }) - - It("errors if the peers tries to retire a connection ID that wasn't yet issued", func() { - Expect(g.Retire(1, protocol.ConnectionID{})).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "retired connection ID 1 (highest issued: 0)", - })) - }) - - It("errors if the peers tries to retire a connection ID in a packet with that connection ID", func() { - Expect(g.SetMaxActiveConnIDs(4)).To(Succeed()) - Expect(queuedFrames).ToNot(BeEmpty()) - Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{})) - f := queuedFrames[0].(*wire.NewConnectionIDFrame) - Expect(g.Retire(f.SequenceNumber, f.ConnectionID)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", f.SequenceNumber, f.ConnectionID), - })) - }) - - It("issues new connection IDs, when old ones are retired", func() { - Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) - queuedFrames = nil - Expect(retiredConnIDs).To(BeEmpty()) - Expect(g.Retire(3, protocol.ConnectionID{})).To(Succeed()) - Expect(queuedFrames).To(HaveLen(1)) - Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{})) - nf := queuedFrames[0].(*wire.NewConnectionIDFrame) - Expect(nf.SequenceNumber).To(BeEquivalentTo(5)) - Expect(nf.ConnectionID.Len()).To(Equal(7)) - }) - - It("retires the initial connection ID", func() { - Expect(g.Retire(0, protocol.ConnectionID{})).To(Succeed()) - Expect(removedConnIDs).To(BeEmpty()) - Expect(retiredConnIDs).To(HaveLen(1)) - Expect(retiredConnIDs[0]).To(Equal(initialConnID)) - Expect(addedConnIDs).To(BeEmpty()) - }) - - It("handles duplicate retirements", func() { - Expect(g.SetMaxActiveConnIDs(11)).To(Succeed()) - queuedFrames = nil - Expect(retiredConnIDs).To(BeEmpty()) - Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed()) - Expect(retiredConnIDs).To(HaveLen(1)) - Expect(queuedFrames).To(HaveLen(1)) - Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed()) - Expect(retiredConnIDs).To(HaveLen(1)) - Expect(queuedFrames).To(HaveLen(1)) - }) - - It("retires the client's initial destination connection ID when the handshake completes", func() { - g.SetHandshakeComplete() - Expect(retiredConnIDs).To(HaveLen(1)) - Expect(retiredConnIDs[0]).To(Equal(initialClientDestConnID)) - }) - - It("removes all connection IDs", func() { - Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) - Expect(queuedFrames).To(HaveLen(4)) - g.RemoveAll() - Expect(removedConnIDs).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones - Expect(removedConnIDs).To(ContainElement(initialConnID)) - Expect(removedConnIDs).To(ContainElement(initialClientDestConnID)) - for _, f := range queuedFrames { - nf := f.(*wire.NewConnectionIDFrame) - Expect(removedConnIDs).To(ContainElement(nf.ConnectionID)) - } - }) - - It("replaces with a closed connection for all connection IDs", func() { - Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) - Expect(queuedFrames).To(HaveLen(4)) - sess := NewMockPacketHandler(mockCtrl) - g.ReplaceWithClosed(sess) - Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones - Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialClientDestConnID), sess)) - Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess)) - for _, f := range queuedFrames { - nf := f.(*wire.NewConnectionIDFrame) - Expect(replacedWithClosed).To(HaveKeyWithValue(string(nf.ConnectionID), sess)) - } - }) -}) diff --git a/internal/quic-go/conn_id_manager.go b/internal/quic-go/conn_id_manager.go deleted file mode 100644 index bb12de28..00000000 --- a/internal/quic-go/conn_id_manager.go +++ /dev/null @@ -1,207 +0,0 @@ -package quic - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type connIDManager struct { - queue utils.NewConnectionIDList - - handshakeComplete bool - activeSequenceNumber uint64 - highestRetired uint64 - activeConnectionID protocol.ConnectionID - activeStatelessResetToken *protocol.StatelessResetToken - - // We change the connection ID after sending on average - // protocol.PacketsPerConnectionID packets. The actual value is randomized - // hide the packet loss rate from on-path observers. - rand utils.Rand - packetsSinceLastChange uint32 - packetsPerConnectionID uint32 - - addStatelessResetToken func(protocol.StatelessResetToken) - removeStatelessResetToken func(protocol.StatelessResetToken) - queueControlFrame func(wire.Frame) -} - -func newConnIDManager( - initialDestConnID protocol.ConnectionID, - addStatelessResetToken func(protocol.StatelessResetToken), - removeStatelessResetToken func(protocol.StatelessResetToken), - queueControlFrame func(wire.Frame), -) *connIDManager { - return &connIDManager{ - activeConnectionID: initialDestConnID, - addStatelessResetToken: addStatelessResetToken, - removeStatelessResetToken: removeStatelessResetToken, - queueControlFrame: queueControlFrame, - } -} - -func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error { - return h.addConnectionID(1, connID, resetToken) -} - -func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { - if err := h.add(f); err != nil { - return err - } - if h.queue.Len() >= protocol.MaxActiveConnectionIDs { - return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError} - } - return nil -} - -func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { - // If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active - // connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately. - if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired { - h.queueControlFrame(&wire.RetireConnectionIDFrame{ - SequenceNumber: f.SequenceNumber, - }) - return nil - } - - // Retire elements in the queue. - // Doesn't retire the active connection ID. - if f.RetirePriorTo > h.highestRetired { - var next *utils.NewConnectionIDElement - for el := h.queue.Front(); el != nil; el = next { - if el.Value.SequenceNumber >= f.RetirePriorTo { - break - } - next = el.Next() - h.queueControlFrame(&wire.RetireConnectionIDFrame{ - SequenceNumber: el.Value.SequenceNumber, - }) - h.queue.Remove(el) - } - h.highestRetired = f.RetirePriorTo - } - - if f.SequenceNumber == h.activeSequenceNumber { - return nil - } - - if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil { - return err - } - - // Retire the active connection ID, if necessary. - if h.activeSequenceNumber < f.RetirePriorTo { - // The queue is guaranteed to have at least one element at this point. - h.updateConnectionID() - } - return nil -} - -func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error { - // insert a new element at the end - if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq { - h.queue.PushBack(utils.NewConnectionID{ - SequenceNumber: seq, - ConnectionID: connID, - StatelessResetToken: resetToken, - }) - return nil - } - // insert a new element somewhere in the middle - for el := h.queue.Front(); el != nil; el = el.Next() { - if el.Value.SequenceNumber == seq { - if !el.Value.ConnectionID.Equal(connID) { - return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq) - } - if el.Value.StatelessResetToken != resetToken { - return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq) - } - break - } - if el.Value.SequenceNumber > seq { - h.queue.InsertBefore(utils.NewConnectionID{ - SequenceNumber: seq, - ConnectionID: connID, - StatelessResetToken: resetToken, - }, el) - break - } - } - return nil -} - -func (h *connIDManager) updateConnectionID() { - h.queueControlFrame(&wire.RetireConnectionIDFrame{ - SequenceNumber: h.activeSequenceNumber, - }) - h.highestRetired = utils.MaxUint64(h.highestRetired, h.activeSequenceNumber) - if h.activeStatelessResetToken != nil { - h.removeStatelessResetToken(*h.activeStatelessResetToken) - } - - front := h.queue.Remove(h.queue.Front()) - h.activeSequenceNumber = front.SequenceNumber - h.activeConnectionID = front.ConnectionID - h.activeStatelessResetToken = &front.StatelessResetToken - h.packetsSinceLastChange = 0 - h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID)) - h.addStatelessResetToken(*h.activeStatelessResetToken) -} - -func (h *connIDManager) Close() { - if h.activeStatelessResetToken != nil { - h.removeStatelessResetToken(*h.activeStatelessResetToken) - } -} - -// is called when the server performs a Retry -// and when the server changes the connection ID in the first Initial sent -func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) { - if h.activeSequenceNumber != 0 { - panic("expected first connection ID to have sequence number 0") - } - h.activeConnectionID = newConnID -} - -// is called when the server provides a stateless reset token in the transport parameters -func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) { - if h.activeSequenceNumber != 0 { - panic("expected first connection ID to have sequence number 0") - } - h.activeStatelessResetToken = &token - h.addStatelessResetToken(token) -} - -func (h *connIDManager) SentPacket() { - h.packetsSinceLastChange++ -} - -func (h *connIDManager) shouldUpdateConnID() bool { - if !h.handshakeComplete { - return false - } - // initiate the first change as early as possible (after handshake completion) - if h.queue.Len() > 0 && h.activeSequenceNumber == 0 { - return true - } - // For later changes, only change if - // 1. The queue of connection IDs is filled more than 50%. - // 2. We sent at least PacketsPerConnectionID packets - return 2*h.queue.Len() >= protocol.MaxActiveConnectionIDs && - h.packetsSinceLastChange >= h.packetsPerConnectionID -} - -func (h *connIDManager) Get() protocol.ConnectionID { - if h.shouldUpdateConnID() { - h.updateConnectionID() - } - return h.activeConnectionID -} - -func (h *connIDManager) SetHandshakeComplete() { - h.handshakeComplete = true -} diff --git a/internal/quic-go/conn_id_manager_test.go b/internal/quic-go/conn_id_manager_test.go deleted file mode 100644 index 5348a0d7..00000000 --- a/internal/quic-go/conn_id_manager_test.go +++ /dev/null @@ -1,364 +0,0 @@ -package quic - -import ( - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Connection ID Manager", func() { - var ( - m *connIDManager - frameQueue []wire.Frame - tokenAdded *protocol.StatelessResetToken - removedTokens []protocol.StatelessResetToken - ) - initialConnID := protocol.ConnectionID{0, 0, 0, 0} - - BeforeEach(func() { - frameQueue = nil - tokenAdded = nil - removedTokens = nil - m = newConnIDManager( - initialConnID, - func(token protocol.StatelessResetToken) { tokenAdded = &token }, - func(token protocol.StatelessResetToken) { removedTokens = append(removedTokens, token) }, - func(f wire.Frame, - ) { - frameQueue = append(frameQueue, f) - }) - }) - - get := func() (protocol.ConnectionID, protocol.StatelessResetToken) { - if m.queue.Len() == 0 { - return nil, protocol.StatelessResetToken{} - } - val := m.queue.Remove(m.queue.Front()) - return val.ConnectionID, val.StatelessResetToken - } - - It("returns the initial connection ID", func() { - Expect(m.Get()).To(Equal(initialConnID)) - }) - - It("changes the initial connection ID", func() { - m.ChangeInitialConnID(protocol.ConnectionID{1, 2, 3, 4, 5}) - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) - }) - - It("sets the token for the first connection ID", func() { - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - m.SetStatelessResetToken(token) - Expect(*m.activeStatelessResetToken).To(Equal(token)) - Expect(*tokenAdded).To(Equal(token)) - }) - - It("adds and gets connection IDs", func() { - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, - StatelessResetToken: protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, - })).To(Succeed()) - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 4, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, - })).To(Succeed()) - c1, rt1 := get() - Expect(c1).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) - Expect(rt1).To(Equal(protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe})) - c2, rt2 := get() - Expect(c2).To(Equal(protocol.ConnectionID{2, 3, 4, 5})) - Expect(rt2).To(Equal(protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0})) - c3, _ := get() - Expect(c3).To(BeNil()) - }) - - It("accepts duplicates", func() { - f1 := &wire.NewConnectionIDFrame{ - SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, - } - f2 := &wire.NewConnectionIDFrame{ - SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, - } - Expect(m.Add(f1)).To(Succeed()) - Expect(m.Add(f2)).To(Succeed()) - c1, rt1 := get() - Expect(c1).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) - Expect(rt1).To(Equal(protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe})) - c2, _ := get() - Expect(c2).To(BeNil()) - }) - - It("ignores duplicates for the currently used connection ID", func() { - f := &wire.NewConnectionIDFrame{ - SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, - } - m.SetHandshakeComplete() - Expect(m.Add(f)).To(Succeed()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) - c, _ := get() - Expect(c).To(BeNil()) - // Now send the same connection ID again. It should not be queued. - Expect(m.Add(f)).To(Succeed()) - c, _ = get() - Expect(c).To(BeNil()) - }) - - It("rejects duplicates with different connection IDs", func() { - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - })).To(Succeed()) - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, - })).To(MatchError("received conflicting connection IDs for sequence number 42")) - }) - - It("rejects duplicates with different connection IDs", func() { - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, - })).To(Succeed()) - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - StatelessResetToken: protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, - })).To(MatchError("received conflicting stateless reset tokens for sequence number 42")) - }) - - It("retires connection IDs", func() { - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - })).To(Succeed()) - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 13, - ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, - })).To(Succeed()) - Expect(frameQueue).To(BeEmpty()) - Expect(m.Add(&wire.NewConnectionIDFrame{ - RetirePriorTo: 14, - SequenceNumber: 17, - ConnectionID: protocol.ConnectionID{3, 4, 5, 6}, - })).To(Succeed()) - Expect(frameQueue).To(HaveLen(3)) - Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(10)) - Expect(frameQueue[1].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(13)) - Expect(frameQueue[2].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{3, 4, 5, 6})) - }) - - It("ignores reordered connection IDs, if their sequence number was already retired", func() { - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - RetirePriorTo: 5, - })).To(Succeed()) - Expect(frameQueue).To(HaveLen(1)) - Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) - frameQueue = nil - // If this NEW_CONNECTION_ID frame hadn't been reordered, we would have retired it before. - // Make sure it gets retired immediately now. - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 4, - ConnectionID: protocol.ConnectionID{4, 3, 2, 1}, - })).To(Succeed()) - Expect(frameQueue).To(HaveLen(1)) - Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(4)) - }) - - It("ignores reordered connection IDs, if their sequence number was already retired or less than active", func() { - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - RetirePriorTo: 5, - })).To(Succeed()) - Expect(frameQueue).To(HaveLen(1)) - Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) - frameQueue = nil - Expect(m.Get()).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 9, - ConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - RetirePriorTo: 5, - })).To(Succeed()) - Expect(frameQueue).To(HaveLen(1)) - Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(9)) - }) - - It("accepts retransmissions for the connection ID that is in use", func() { - connID := protocol.ConnectionID{1, 2, 3, 4} - - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 1, - ConnectionID: connID, - })).To(Succeed()) - m.SetHandshakeComplete() - Expect(frameQueue).To(BeEmpty()) - Expect(m.Get()).To(Equal(connID)) - Expect(frameQueue).To(HaveLen(1)) - Expect(frameQueue[0]).To(BeAssignableToTypeOf(&wire.RetireConnectionIDFrame{})) - Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) - frameQueue = nil - - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 1, - ConnectionID: connID, - })).To(Succeed()) - Expect(frameQueue).To(BeEmpty()) - }) - - It("errors when the peer sends too connection IDs", func() { - for i := uint8(1); i < protocol.MaxActiveConnectionIDs; i++ { - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: uint64(i), - ConnectionID: protocol.ConnectionID{i, i, i, i}, - StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i}, - })).To(Succeed()) - } - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: uint64(9999), - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - StatelessResetToken: protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - })).To(MatchError(&qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError})) - }) - - It("initiates the first connection ID update as soon as possible", func() { - Expect(m.Get()).To(Equal(initialConnID)) - m.SetHandshakeComplete() - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, - })).To(Succeed()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) - }) - - It("waits until handshake completion before initiating a connection ID update", func() { - Expect(m.Get()).To(Equal(initialConnID)) - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, - })).To(Succeed()) - Expect(m.Get()).To(Equal(initialConnID)) - m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) - }) - - It("initiates subsequent updates when enough packets are sent", func() { - var s uint8 - for s = uint8(1); s < protocol.MaxActiveConnectionIDs; s++ { - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: uint64(s), - ConnectionID: protocol.ConnectionID{s, s, s, s}, - StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, - })).To(Succeed()) - } - - m.SetHandshakeComplete() - lastConnID := m.Get() - Expect(lastConnID).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) - - var counter int - for i := 0; i < 50*protocol.PacketsPerConnectionID; i++ { - m.SentPacket() - - connID := m.Get() - if !connID.Equal(lastConnID) { - counter++ - lastConnID = connID - Expect(removedTokens).To(HaveLen(1)) - removedTokens = nil - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: uint64(s), - ConnectionID: protocol.ConnectionID{s, s, s, s}, - StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, - })).To(Succeed()) - s++ - } - } - Expect(counter).To(BeNumerically("~", 50, 10)) - }) - - It("retires delayed connection IDs that arrive after a higher connection ID was already retired", func() { - for s := uint8(10); s <= 10+protocol.MaxActiveConnectionIDs/2; s++ { - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: uint64(s), - ConnectionID: protocol.ConnectionID{s, s, s, s}, - StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, - })).To(Succeed()) - } - m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{10, 10, 10, 10})) - for { - m.SentPacket() - if m.Get().Equal(protocol.ConnectionID{11, 11, 11, 11}) { - break - } - } - // The active conn ID is now {11, 11, 11, 11} - Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ConnectionID{12, 12, 12, 12})) - // Add a delayed connection ID. It should just be ignored now. - frameQueue = nil - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: uint64(5), - ConnectionID: protocol.ConnectionID{5, 5, 5, 5}, - StatelessResetToken: protocol.StatelessResetToken{5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}, - })).To(Succeed()) - Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ConnectionID{12, 12, 12, 12})) - Expect(frameQueue).To(HaveLen(1)) - Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(5)) - }) - - It("only initiates subsequent updates when enough if enough connection IDs are queued", func() { - for i := uint8(1); i <= protocol.MaxActiveConnectionIDs/2; i++ { - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: uint64(i), - ConnectionID: protocol.ConnectionID{i, i, i, i}, - StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i}, - })).To(Succeed()) - } - m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) - for i := 0; i < 2*protocol.PacketsPerConnectionID; i++ { - m.SentPacket() - } - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 1337, - ConnectionID: protocol.ConnectionID{1, 3, 3, 7}, - })).To(Succeed()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{2, 2, 2, 2})) - Expect(removedTokens).To(HaveLen(1)) - Expect(removedTokens[0]).To(Equal(protocol.StatelessResetToken{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})) - }) - - It("removes the currently active stateless reset token when it is closed", func() { - m.Close() - Expect(removedTokens).To(BeEmpty()) - Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, - })).To(Succeed()) - m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) - m.Close() - Expect(removedTokens).To(HaveLen(1)) - Expect(removedTokens[0]).To(Equal(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})) - }) -}) diff --git a/internal/quic-go/connection.go b/internal/quic-go/connection.go deleted file mode 100644 index fd981709..00000000 --- a/internal/quic-go/connection.go +++ /dev/null @@ -1,2006 +0,0 @@ -package quic - -import ( - "bytes" - "context" - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "reflect" - "sync" - "sync/atomic" - "time" - - "github.com/imroc/req/v3/internal/quic-go/ackhandler" - "github.com/imroc/req/v3/internal/quic-go/flowcontrol" - "github.com/imroc/req/v3/internal/quic-go/handshake" - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/logutils" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type unpacker interface { - Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) -} - -type streamGetter interface { - GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) - GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) -} - -type streamManager interface { - GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) - GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) - OpenStream() (Stream, error) - OpenUniStream() (SendStream, error) - OpenStreamSync(context.Context) (Stream, error) - OpenUniStreamSync(context.Context) (SendStream, error) - AcceptStream(context.Context) (Stream, error) - AcceptUniStream(context.Context) (ReceiveStream, error) - DeleteStream(protocol.StreamID) error - UpdateLimits(*wire.TransportParameters) - HandleMaxStreamsFrame(*wire.MaxStreamsFrame) - CloseWithError(error) - ResetFor0RTT() - UseResetMaps() -} - -type cryptoStreamHandler interface { - RunHandshake() - ChangeConnectionID(protocol.ConnectionID) - SetLargest1RTTAcked(protocol.PacketNumber) error - SetHandshakeConfirmed() - GetSessionTicket() ([]byte, error) - io.Closer - ConnectionState() handshake.ConnectionState -} - -type packetInfo struct { - addr net.IP - ifIndex uint32 -} - -type receivedPacket struct { - buffer *packetBuffer - - remoteAddr net.Addr - rcvTime time.Time - data []byte - - ecn protocol.ECN - - info *packetInfo -} - -func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) } - -func (p *receivedPacket) Clone() *receivedPacket { - return &receivedPacket{ - remoteAddr: p.remoteAddr, - rcvTime: p.rcvTime, - data: p.data, - buffer: p.buffer, - ecn: p.ecn, - info: p.info, - } -} - -type connRunner interface { - Add(protocol.ConnectionID, packetHandler) bool - GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken - Retire(protocol.ConnectionID) - Remove(protocol.ConnectionID) - ReplaceWithClosed(protocol.ConnectionID, packetHandler) - AddResetToken(protocol.StatelessResetToken, packetHandler) - RemoveResetToken(protocol.StatelessResetToken) -} - -type handshakeRunner struct { - onReceivedParams func(*wire.TransportParameters) - onError func(error) - dropKeys func(protocol.EncryptionLevel) - onHandshakeComplete func() -} - -func (r *handshakeRunner) OnReceivedParams(tp *wire.TransportParameters) { r.onReceivedParams(tp) } -func (r *handshakeRunner) OnError(e error) { r.onError(e) } -func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) } -func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() } - -type closeError struct { - err error - remote bool - immediate bool -} - -type errCloseForRecreating struct { - nextPacketNumber protocol.PacketNumber - nextVersion protocol.VersionNumber -} - -func (e *errCloseForRecreating) Error() string { - return "closing connection in order to recreate it" -} - -var connTracingID uint64 // to be accessed atomically -func nextConnTracingID() uint64 { return atomic.AddUint64(&connTracingID, 1) } - -// A Connection is a QUIC connection -type connection struct { - // Destination connection ID used during the handshake. - // Used to check source connection ID on incoming packets. - handshakeDestConnID protocol.ConnectionID - // Set for the client. Destination connection ID used on the first Initial sent. - origDestConnID protocol.ConnectionID - retrySrcConnID *protocol.ConnectionID // only set for the client (and if a Retry was performed) - - srcConnIDLen int - - perspective protocol.Perspective - version protocol.VersionNumber - config *Config - - conn sendConn - sendQueue sender - - streamsMap streamManager - connIDManager *connIDManager - connIDGenerator *connIDGenerator - - rttStats *utils.RTTStats - - cryptoStreamManager *cryptoStreamManager - sentPacketHandler ackhandler.SentPacketHandler - receivedPacketHandler ackhandler.ReceivedPacketHandler - retransmissionQueue *retransmissionQueue - framer framer - windowUpdateQueue *windowUpdateQueue - connFlowController flowcontrol.ConnectionFlowController - tokenStoreKey string // only set for the client - tokenGenerator *handshake.TokenGenerator // only set for the server - - unpacker unpacker - frameParser wire.FrameParser - packer packer - mtuDiscoverer mtuDiscoverer // initialized when the handshake completes - - oneRTTStream cryptoStream // only set for the server - cryptoStreamHandler cryptoStreamHandler - - receivedPackets chan *receivedPacket - sendingScheduled chan struct{} - - closeOnce sync.Once - // closeChan is used to notify the run loop that it should terminate - closeChan chan closeError - - ctx context.Context - ctxCancel context.CancelFunc - handshakeCtx context.Context - handshakeCtxCancel context.CancelFunc - - undecryptablePackets []*receivedPacket // undecryptable packets, waiting for a change in encryption level - undecryptablePacketsToProcess []*receivedPacket - - clientHelloWritten <-chan *wire.TransportParameters - earlyConnReadyChan chan struct{} - handshakeCompleteChan chan struct{} // is closed when the handshake completes - sentFirstPacket bool - handshakeComplete bool - handshakeConfirmed bool - - receivedRetry bool - versionNegotiated bool - receivedFirstPacket bool - - idleTimeout time.Duration - creationTime time.Time - // The idle timeout is set based on the max of the time we received the last packet... - lastPacketReceivedTime time.Time - // ... and the time we sent a new ack-eliciting packet after receiving a packet. - firstAckElicitingPacketAfterIdleSentTime time.Time - // pacingDeadline is the time when the next packet should be sent - pacingDeadline time.Time - - peerParams *wire.TransportParameters - - timer *utils.Timer - // keepAlivePingSent stores whether a keep alive PING is in flight. - // It is reset as soon as we receive a packet from the peer. - keepAlivePingSent bool - keepAliveInterval time.Duration - - datagramQueue *datagramQueue - - logID string - tracer logging.ConnectionTracer - logger utils.Logger -} - -var ( - _ Connection = &connection{} - _ EarlyConnection = &connection{} - _ streamSender = &connection{} - deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine -) - -var newConnection = func( - conn sendConn, - runner connRunner, - origDestConnID protocol.ConnectionID, - retrySrcConnID *protocol.ConnectionID, - clientDestConnID protocol.ConnectionID, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, - statelessResetToken protocol.StatelessResetToken, - conf *Config, - tlsConf *tls.Config, - tokenGenerator *handshake.TokenGenerator, - enable0RTT bool, - tracer logging.ConnectionTracer, - tracingID uint64, - logger utils.Logger, - v protocol.VersionNumber, -) quicConn { - s := &connection{ - conn: conn, - config: conf, - handshakeDestConnID: destConnID, - srcConnIDLen: srcConnID.Len(), - tokenGenerator: tokenGenerator, - oneRTTStream: newCryptoStream(), - perspective: protocol.PerspectiveServer, - handshakeCompleteChan: make(chan struct{}), - tracer: tracer, - logger: logger, - version: v, - } - if origDestConnID != nil { - s.logID = origDestConnID.String() - } else { - s.logID = destConnID.String() - } - s.connIDManager = newConnIDManager( - destConnID, - func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, - runner.RemoveResetToken, - s.queueControlFrame, - ) - s.connIDGenerator = newConnIDGenerator( - srcConnID, - clientDestConnID, - func(connID protocol.ConnectionID) { runner.Add(connID, s) }, - runner.GetStatelessResetToken, - runner.Remove, - runner.Retire, - runner.ReplaceWithClosed, - s.queueControlFrame, - s.version, - ) - s.preSetup() - s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) - s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( - 0, - getMaxPacketSize(s.conn.RemoteAddr()), - s.rttStats, - s.perspective, - s.tracer, - s.logger, - s.version, - ) - initialStream := newCryptoStream() - handshakeStream := newCryptoStream() - params := &wire.TransportParameters{ - InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), - MaxIdleTimeout: s.config.MaxIdleTimeout, - MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), - MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), - MaxAckDelay: protocol.MaxAckDelayInclGranularity, - AckDelayExponent: protocol.AckDelayExponent, - DisableActiveMigration: true, - StatelessResetToken: &statelessResetToken, - OriginalDestinationConnectionID: origDestConnID, - ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, - InitialSourceConnectionID: srcConnID, - RetrySourceConnectionID: retrySrcConnID, - } - if s.config.EnableDatagrams { - params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize - } - if s.tracer != nil { - s.tracer.SentTransportParameters(params) - } - cs := handshake.NewCryptoSetupServer( - initialStream, - handshakeStream, - clientDestConnID, - conn.LocalAddr(), - conn.RemoteAddr(), - params, - &handshakeRunner{ - onReceivedParams: s.handleTransportParameters, - onError: s.closeLocal, - dropKeys: s.dropEncryptionLevel, - onHandshakeComplete: func() { - runner.Retire(clientDestConnID) - close(s.handshakeCompleteChan) - }, - }, - tlsConf, - enable0RTT, - s.rttStats, - tracer, - logger, - s.version, - ) - s.cryptoStreamHandler = cs - s.packer = newPacketPacker( - srcConnID, - s.connIDManager.Get, - initialStream, - handshakeStream, - s.sentPacketHandler, - s.retransmissionQueue, - s.RemoteAddr(), - cs, - s.framer, - s.receivedPacketHandler, - s.datagramQueue, - s.perspective, - s.version, - ) - s.unpacker = newPacketUnpacker(cs, s.version) - s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) - return s -} - -// declare this as a variable, such that we can it mock it in the tests -var newClientConnection = func( - conn sendConn, - runner connRunner, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, - conf *Config, - tlsConf *tls.Config, - initialPacketNumber protocol.PacketNumber, - enable0RTT bool, - hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, - tracingID uint64, - logger utils.Logger, - v protocol.VersionNumber, -) quicConn { - s := &connection{ - conn: conn, - config: conf, - origDestConnID: destConnID, - handshakeDestConnID: destConnID, - srcConnIDLen: srcConnID.Len(), - perspective: protocol.PerspectiveClient, - handshakeCompleteChan: make(chan struct{}), - logID: destConnID.String(), - logger: logger, - tracer: tracer, - versionNegotiated: hasNegotiatedVersion, - version: v, - } - s.connIDManager = newConnIDManager( - destConnID, - func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, - runner.RemoveResetToken, - s.queueControlFrame, - ) - s.connIDGenerator = newConnIDGenerator( - srcConnID, - nil, - func(connID protocol.ConnectionID) { runner.Add(connID, s) }, - runner.GetStatelessResetToken, - runner.Remove, - runner.Retire, - runner.ReplaceWithClosed, - s.queueControlFrame, - s.version, - ) - s.preSetup() - s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) - s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( - initialPacketNumber, - getMaxPacketSize(s.conn.RemoteAddr()), - s.rttStats, - s.perspective, - s.tracer, - s.logger, - s.version, - ) - initialStream := newCryptoStream() - handshakeStream := newCryptoStream() - params := &wire.TransportParameters{ - InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), - MaxIdleTimeout: s.config.MaxIdleTimeout, - MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), - MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), - MaxAckDelay: protocol.MaxAckDelayInclGranularity, - AckDelayExponent: protocol.AckDelayExponent, - DisableActiveMigration: true, - ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, - InitialSourceConnectionID: srcConnID, - } - if s.config.EnableDatagrams { - params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize - } - if s.tracer != nil { - s.tracer.SentTransportParameters(params) - } - cs, clientHelloWritten := handshake.NewCryptoSetupClient( - initialStream, - handshakeStream, - destConnID, - conn.LocalAddr(), - conn.RemoteAddr(), - params, - &handshakeRunner{ - onReceivedParams: s.handleTransportParameters, - onError: s.closeLocal, - dropKeys: s.dropEncryptionLevel, - onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, - }, - tlsConf, - enable0RTT, - s.rttStats, - tracer, - logger, - s.version, - ) - s.clientHelloWritten = clientHelloWritten - s.cryptoStreamHandler = cs - s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) - s.unpacker = newPacketUnpacker(cs, s.version) - s.packer = newPacketPacker( - srcConnID, - s.connIDManager.Get, - initialStream, - handshakeStream, - s.sentPacketHandler, - s.retransmissionQueue, - s.RemoteAddr(), - cs, - s.framer, - s.receivedPacketHandler, - s.datagramQueue, - s.perspective, - s.version, - ) - if len(tlsConf.ServerName) > 0 { - s.tokenStoreKey = tlsConf.ServerName - } else { - s.tokenStoreKey = conn.RemoteAddr().String() - } - if s.config.TokenStore != nil { - if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil { - s.packer.SetToken(token.data) - } - } - return s -} - -func (s *connection) preSetup() { - s.sendQueue = newSendQueue(s.conn) - s.retransmissionQueue = newRetransmissionQueue(s.version) - s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams, s.version) - s.rttStats = &utils.RTTStats{} - s.connFlowController = flowcontrol.NewConnectionFlowController( - protocol.ByteCount(s.config.InitialConnectionReceiveWindow), - protocol.ByteCount(s.config.MaxConnectionReceiveWindow), - s.onHasConnectionWindowUpdate, - func(size protocol.ByteCount) bool { - if s.config.AllowConnectionWindowIncrease == nil { - return true - } - return s.config.AllowConnectionWindowIncrease(s, uint64(size)) - }, - s.rttStats, - s.logger, - ) - s.earlyConnReadyChan = make(chan struct{}) - s.streamsMap = newStreamsMap( - s, - s.newFlowController, - uint64(s.config.MaxIncomingStreams), - uint64(s.config.MaxIncomingUniStreams), - s.perspective, - s.version, - ) - s.framer = newFramer(s.streamsMap, s.version) - s.receivedPackets = make(chan *receivedPacket, protocol.MaxConnUnprocessedPackets) - s.closeChan = make(chan closeError, 1) - s.sendingScheduled = make(chan struct{}, 1) - s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background()) - - now := time.Now() - s.lastPacketReceivedTime = now - s.creationTime = now - - s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) - if s.config.EnableDatagrams { - s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) - } -} - -// run the connection main loop -func (s *connection) run() error { - defer s.ctxCancel() - - s.timer = utils.NewTimer() - - go s.cryptoStreamHandler.RunHandshake() - go func() { - if err := s.sendQueue.Run(); err != nil { - s.destroyImpl(err) - } - }() - - if s.perspective == protocol.PerspectiveClient { - select { - case zeroRTTParams := <-s.clientHelloWritten: - s.scheduleSending() - if zeroRTTParams != nil { - s.restoreTransportParameters(zeroRTTParams) - close(s.earlyConnReadyChan) - } - case closeErr := <-s.closeChan: - // put the close error back into the channel, so that the run loop can receive it - s.closeChan <- closeErr - } - } - - var ( - closeErr closeError - sendQueueAvailable <-chan struct{} - ) - -runLoop: - for { - // Close immediately if requested - select { - case closeErr = <-s.closeChan: - break runLoop - case <-s.handshakeCompleteChan: - s.handleHandshakeComplete() - default: - } - - s.maybeResetTimer() - - var processedUndecryptablePacket bool - if len(s.undecryptablePacketsToProcess) > 0 { - queue := s.undecryptablePacketsToProcess - s.undecryptablePacketsToProcess = nil - for _, p := range queue { - if processed := s.handlePacketImpl(p); processed { - processedUndecryptablePacket = true - } - // Don't set timers and send packets if the packet made us close the connection. - select { - case closeErr = <-s.closeChan: - break runLoop - default: - } - } - } - // If we processed any undecryptable packets, jump to the resetting of the timers directly. - if !processedUndecryptablePacket { - select { - case closeErr = <-s.closeChan: - break runLoop - case <-s.timer.Chan(): - s.timer.SetRead() - // We do all the interesting stuff after the switch statement, so - // nothing to see here. - case <-s.sendingScheduled: - // We do all the interesting stuff after the switch statement, so - // nothing to see here. - case <-sendQueueAvailable: - case firstPacket := <-s.receivedPackets: - wasProcessed := s.handlePacketImpl(firstPacket) - // Don't set timers and send packets if the packet made us close the connection. - select { - case closeErr = <-s.closeChan: - break runLoop - default: - } - if s.handshakeComplete { - // Now process all packets in the receivedPackets channel. - // Limit the number of packets to the length of the receivedPackets channel, - // so we eventually get a chance to send out an ACK when receiving a lot of packets. - numPackets := len(s.receivedPackets) - receiveLoop: - for i := 0; i < numPackets; i++ { - select { - case p := <-s.receivedPackets: - if processed := s.handlePacketImpl(p); processed { - wasProcessed = true - } - select { - case closeErr = <-s.closeChan: - break runLoop - default: - } - default: - break receiveLoop - } - } - } - // Only reset the timers if this packet was actually processed. - // This avoids modifying any state when handling undecryptable packets, - // which could be injected by an attacker. - if !wasProcessed { - continue - } - case <-s.handshakeCompleteChan: - s.handleHandshakeComplete() - } - } - - now := time.Now() - if timeout := s.sentPacketHandler.GetLossDetectionTimeout(); !timeout.IsZero() && timeout.Before(now) { - // This could cause packets to be retransmitted. - // Check it before trying to send packets. - if err := s.sentPacketHandler.OnLossDetectionTimeout(); err != nil { - s.closeLocal(err) - } - } - - if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() && !now.Before(keepAliveTime) { - // send a PING frame since there is no activity in the connection - s.logger.Debugf("Sending a keep-alive PING to keep the connection alive.") - s.framer.QueueControlFrame(&wire.PingFrame{}) - s.keepAlivePingSent = true - } else if !s.handshakeComplete && now.Sub(s.creationTime) >= s.config.handshakeTimeout() { - s.destroyImpl(qerr.ErrHandshakeTimeout) - continue - } else { - idleTimeoutStartTime := s.idleTimeoutStartTime() - if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) || - (s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.idleTimeout) { - s.destroyImpl(qerr.ErrIdleTimeout) - continue - } - } - - if s.sendQueue.WouldBlock() { - // The send queue is still busy sending out packets. - // Wait until there's space to enqueue new packets. - sendQueueAvailable = s.sendQueue.Available() - continue - } - if err := s.sendPackets(); err != nil { - s.closeLocal(err) - } - if s.sendQueue.WouldBlock() { - sendQueueAvailable = s.sendQueue.Available() - } else { - sendQueueAvailable = nil - } - } - - s.handleCloseError(&closeErr) - if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { - s.tracer.Close() - } - s.logger.Infof("Connection %s closed.", s.logID) - s.cryptoStreamHandler.Close() - s.sendQueue.Close() - s.timer.Stop() - return closeErr.err -} - -// blocks until the early connection can be used -func (s *connection) earlyConnReady() <-chan struct{} { - return s.earlyConnReadyChan -} - -func (s *connection) HandshakeComplete() context.Context { - return s.handshakeCtx -} - -func (s *connection) Context() context.Context { - return s.ctx -} - -func (s *connection) supportsDatagrams() bool { - return s.peerParams.MaxDatagramFrameSize != protocol.InvalidByteCount -} - -func (s *connection) ConnectionState() ConnectionState { - return ConnectionState{ - TLS: s.cryptoStreamHandler.ConnectionState(), - SupportsDatagrams: s.supportsDatagrams(), - } -} - -// Time when the next keep-alive packet should be sent. -// It returns a zero time if no keep-alive should be sent. -func (s *connection) nextKeepAliveTime() time.Time { - if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() { - return time.Time{} - } - return s.lastPacketReceivedTime.Add(s.keepAliveInterval) -} - -func (s *connection) maybeResetTimer() { - var deadline time.Time - if !s.handshakeComplete { - deadline = utils.MinTime( - s.creationTime.Add(s.config.handshakeTimeout()), - s.idleTimeoutStartTime().Add(s.config.HandshakeIdleTimeout), - ) - } else { - if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() { - deadline = keepAliveTime - } else { - deadline = s.idleTimeoutStartTime().Add(s.idleTimeout) - } - } - - if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() { - deadline = utils.MinTime(deadline, ackAlarm) - } - if lossTime := s.sentPacketHandler.GetLossDetectionTimeout(); !lossTime.IsZero() { - deadline = utils.MinTime(deadline, lossTime) - } - if !s.pacingDeadline.IsZero() { - deadline = utils.MinTime(deadline, s.pacingDeadline) - } - - s.timer.Reset(deadline) -} - -func (s *connection) idleTimeoutStartTime() time.Time { - return utils.MaxTime(s.lastPacketReceivedTime, s.firstAckElicitingPacketAfterIdleSentTime) -} - -func (s *connection) handleHandshakeComplete() { - s.handshakeComplete = true - s.handshakeCompleteChan = nil // prevent this case from ever being selected again - defer s.handshakeCtxCancel() - // Once the handshake completes, we have derived 1-RTT keys. - // There's no point in queueing undecryptable packets for later decryption any more. - s.undecryptablePackets = nil - - s.connIDManager.SetHandshakeComplete() - s.connIDGenerator.SetHandshakeComplete() - - if s.perspective == protocol.PerspectiveClient { - s.applyTransportParameters() - return - } - - s.handleHandshakeConfirmed() - - ticket, err := s.cryptoStreamHandler.GetSessionTicket() - if err != nil { - s.closeLocal(err) - } - if ticket != nil { - s.oneRTTStream.Write(ticket) - for s.oneRTTStream.HasData() { - s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize)) - } - } - token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr()) - if err != nil { - s.closeLocal(err) - } - s.queueControlFrame(&wire.NewTokenFrame{Token: token}) - s.queueControlFrame(&wire.HandshakeDoneFrame{}) -} - -func (s *connection) handleHandshakeConfirmed() { - s.handshakeConfirmed = true - s.sentPacketHandler.SetHandshakeConfirmed() - s.cryptoStreamHandler.SetHandshakeConfirmed() - - if !s.config.DisablePathMTUDiscovery { - maxPacketSize := s.peerParams.MaxUDPPayloadSize - if maxPacketSize == 0 { - maxPacketSize = protocol.MaxByteCount - } - maxPacketSize = utils.MinByteCount(maxPacketSize, protocol.MaxPacketBufferSize) - s.mtuDiscoverer = newMTUDiscoverer( - s.rttStats, - getMaxPacketSize(s.conn.RemoteAddr()), - maxPacketSize, - func(size protocol.ByteCount) { - s.sentPacketHandler.SetMaxDatagramSize(size) - s.packer.SetMaxPacketSize(size) - }, - ) - } -} - -func (s *connection) handlePacketImpl(rp *receivedPacket) bool { - s.sentPacketHandler.ReceivedBytes(rp.Size()) - - if wire.IsVersionNegotiationPacket(rp.data) { - s.handleVersionNegotiationPacket(rp) - return false - } - - var counter uint8 - var lastConnID protocol.ConnectionID - var processed bool - data := rp.data - p := rp - for len(data) > 0 { - if counter > 0 { - p = p.Clone() - p.data = data - } - - hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen) - if err != nil { - if s.tracer != nil { - dropReason := logging.PacketDropHeaderParseError - if err == wire.ErrUnsupportedVersion { - dropReason = logging.PacketDropUnsupportedVersion - } - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason) - } - s.logger.Debugf("error parsing packet: %s", err) - break - } - - if hdr.IsLongHeader && hdr.Version != s.version { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) - } - s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) - break - } - - if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) - } - s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) - break - } - lastConnID = hdr.DestConnectionID - - if counter > 0 { - p.buffer.Split() - } - counter++ - - // only log if this actually a coalesced packet - if s.logger.Debug() && (counter > 1 || len(rest) > 0) { - s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest)) - } - p.data = packetData - if wasProcessed := s.handleSinglePacket(p, hdr); wasProcessed { - processed = true - } - data = rest - } - p.buffer.MaybeRelease() - return processed -} - -func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { - var wasQueued bool - - defer func() { - // Put back the packet buffer if the packet wasn't queued for later decryption. - if !wasQueued { - p.buffer.Decrement() - } - }() - - if hdr.Type == protocol.PacketTypeRetry { - return s.handleRetryPacket(hdr, p.data) - } - - // The server can change the source connection ID with the first Handshake packet. - // After this, all packets with a different source connection have to be ignored. - if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && !hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) - } - s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID) - return false - } - // drop 0-RTT packets, if we are a client - if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketType0RTT, p.Size(), logging.PacketDropKeyUnavailable) - } - return false - } - - packet, err := s.unpacker.Unpack(hdr, p.rcvTime, p.data) - if err != nil { - switch err { - case handshake.ErrKeysDropped: - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropKeyUnavailable) - } - s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", hdr.PacketType(), p.Size()) - case handshake.ErrKeysNotYetAvailable: - // Sealer for this encryption level not yet available. - // Try again later. - wasQueued = true - s.tryQueueingUndecryptablePacket(p, hdr) - case wire.ErrInvalidReservedBits: - s.closeLocal(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: err.Error(), - }) - case handshake.ErrDecryptionFailed: - // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropPayloadDecryptError) - } - s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", hdr.PacketType(), p.Size(), err) - default: - var headerErr *headerParseError - if errors.As(err, &headerErr) { - // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropHeaderParseError) - } - s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", hdr.PacketType(), p.Size(), err) - } else { - // This is an error returned by the AEAD (other than ErrDecryptionFailed). - // For example, a PROTOCOL_VIOLATION due to key updates. - s.closeLocal(err) - } - } - return false - } - - if s.logger.Debug() { - s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, %s", packet.packetNumber, p.Size(), hdr.DestConnectionID, packet.encryptionLevel) - packet.hdr.Log(s.logger) - } - - if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.packetNumber, packet.encryptionLevel) { - s.logger.Debugf("Dropping (potentially) duplicate packet.") - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate) - } - return false - } - - if err := s.handleUnpackedPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil { - s.closeLocal(err) - return false - } - return true -} - -func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { - if s.perspective == protocol.PerspectiveServer { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) - } - s.logger.Debugf("Ignoring Retry.") - return false - } - if s.receivedFirstPacket { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) - } - s.logger.Debugf("Ignoring Retry, since we already received a packet.") - return false - } - destConnID := s.connIDManager.Get() - if hdr.SrcConnectionID.Equal(destConnID) { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) - } - s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") - return false - } - // If a token is already set, this means that we already received a Retry from the server. - // Ignore this Retry packet. - if s.receivedRetry { - s.logger.Debugf("Ignoring Retry, since a Retry was already received.") - return false - } - - tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version) - if !bytes.Equal(data[len(data)-16:], tag[:]) { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) - } - s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.") - return false - } - - if s.logger.Debug() { - s.logger.Debugf("<- Received Retry:") - (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID) - } - if s.tracer != nil { - s.tracer.ReceivedRetry(hdr) - } - newDestConnID := hdr.SrcConnectionID - s.receivedRetry = true - if err := s.sentPacketHandler.ResetForRetry(); err != nil { - s.closeLocal(err) - return false - } - s.handshakeDestConnID = newDestConnID - s.retrySrcConnID = &newDestConnID - s.cryptoStreamHandler.ChangeConnectionID(newDestConnID) - s.packer.SetToken(hdr.Token) - s.connIDManager.ChangeInitialConnID(newDestConnID) - s.scheduleSending() - return true -} - -func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { - if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets - s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) - } - return - } - - hdr, supportedVersions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(p.data)) - if err != nil { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) - } - s.logger.Debugf("Error parsing Version Negotiation packet: %s", err) - return - } - - for _, v := range supportedVersions { - if v == s.version { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) - } - // The Version Negotiation packet contains the version that we offered. - // This might be a packet sent by an attacker, or it was corrupted. - return - } - } - - s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) - if s.tracer != nil { - s.tracer.ReceivedVersionNegotiationPacket(hdr, supportedVersions) - } - newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) - if !ok { - s.destroyImpl(&VersionNegotiationError{ - Ours: s.config.Versions, - Theirs: supportedVersions, - }) - s.logger.Infof("No compatible QUIC version found.") - return - } - if s.tracer != nil { - s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions) - } - - s.logger.Infof("Switching to QUIC version %s.", newVersion) - nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial) - s.destroyImpl(&errCloseForRecreating{ - nextPacketNumber: nextPN, - nextVersion: newVersion, - }) -} - -func (s *connection) handleUnpackedPacket( - packet *unpackedPacket, - ecn protocol.ECN, - rcvTime time.Time, - packetSize protocol.ByteCount, // only for logging -) error { - if len(packet.data) == 0 { - return &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "empty packet", - } - } - - if !s.receivedFirstPacket { - s.receivedFirstPacket = true - if !s.versionNegotiated && s.tracer != nil { - var clientVersions, serverVersions []protocol.VersionNumber - switch s.perspective { - case protocol.PerspectiveClient: - clientVersions = s.config.Versions - case protocol.PerspectiveServer: - serverVersions = s.config.Versions - } - s.tracer.NegotiatedVersion(s.version, clientVersions, serverVersions) - } - // The server can change the source connection ID with the first Handshake packet. - if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { - cid := packet.hdr.SrcConnectionID - s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid) - s.handshakeDestConnID = cid - s.connIDManager.ChangeInitialConnID(cid) - } - // We create the connection as soon as we receive the first packet from the client. - // We do that before authenticating the packet. - // That means that if the source connection ID was corrupted, - // we might have create a connection with an incorrect source connection ID. - // Once we authenticate the first packet, we need to update it. - if s.perspective == protocol.PerspectiveServer { - if !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { - s.handshakeDestConnID = packet.hdr.SrcConnectionID - s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) - } - if s.tracer != nil { - s.tracer.StartedConnection( - s.conn.LocalAddr(), - s.conn.RemoteAddr(), - packet.hdr.SrcConnectionID, - packet.hdr.DestConnectionID, - ) - } - } - } - - s.lastPacketReceivedTime = rcvTime - s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} - s.keepAlivePingSent = false - - // Only used for tracing. - // If we're not tracing, this slice will always remain empty. - var frames []wire.Frame - r := bytes.NewReader(packet.data) - var isAckEliciting bool - for { - frame, err := s.frameParser.ParseNext(r, packet.encryptionLevel) - if err != nil { - return err - } - if frame == nil { - break - } - if ackhandler.IsFrameAckEliciting(frame) { - isAckEliciting = true - } - // Only process frames now if we're not logging. - // If we're logging, we need to make sure that the packet_received event is logged first. - if s.tracer == nil { - if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { - return err - } - } else { - frames = append(frames, frame) - } - } - - if s.tracer != nil { - fs := make([]logging.Frame, len(frames)) - for i, frame := range frames { - fs[i] = logutils.ConvertFrame(frame) - } - s.tracer.ReceivedPacket(packet.hdr, packetSize, fs) - for _, frame := range frames { - if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { - return err - } - } - } - - return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) -} - -func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { - var err error - wire.LogFrame(s.logger, f, false) - switch frame := f.(type) { - case *wire.CryptoFrame: - err = s.handleCryptoFrame(frame, encLevel) - case *wire.StreamFrame: - err = s.handleStreamFrame(frame) - case *wire.AckFrame: - err = s.handleAckFrame(frame, encLevel) - case *wire.ConnectionCloseFrame: - s.handleConnectionCloseFrame(frame) - case *wire.ResetStreamFrame: - err = s.handleResetStreamFrame(frame) - case *wire.MaxDataFrame: - s.handleMaxDataFrame(frame) - case *wire.MaxStreamDataFrame: - err = s.handleMaxStreamDataFrame(frame) - case *wire.MaxStreamsFrame: - s.handleMaxStreamsFrame(frame) - case *wire.DataBlockedFrame: - case *wire.StreamDataBlockedFrame: - case *wire.StreamsBlockedFrame: - case *wire.StopSendingFrame: - err = s.handleStopSendingFrame(frame) - case *wire.PingFrame: - case *wire.PathChallengeFrame: - s.handlePathChallengeFrame(frame) - case *wire.PathResponseFrame: - // since we don't send PATH_CHALLENGEs, we don't expect PATH_RESPONSEs - err = errors.New("unexpected PATH_RESPONSE frame") - case *wire.NewTokenFrame: - err = s.handleNewTokenFrame(frame) - case *wire.NewConnectionIDFrame: - err = s.handleNewConnectionIDFrame(frame) - case *wire.RetireConnectionIDFrame: - err = s.handleRetireConnectionIDFrame(frame, destConnID) - case *wire.HandshakeDoneFrame: - err = s.handleHandshakeDoneFrame() - case *wire.DatagramFrame: - err = s.handleDatagramFrame(frame) - default: - err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name()) - } - return err -} - -// handlePacket is called by the server with a new packet -func (s *connection) handlePacket(p *receivedPacket) { - // Discard packets once the amount of queued packets is larger than - // the channel size, protocol.MaxConnUnprocessedPackets - select { - case s.receivedPackets <- p: - default: - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) - } - } -} - -func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) { - if frame.IsApplicationError { - s.closeRemote(&qerr.ApplicationError{ - Remote: true, - ErrorCode: qerr.ApplicationErrorCode(frame.ErrorCode), - ErrorMessage: frame.ReasonPhrase, - }) - return - } - s.closeRemote(&qerr.TransportError{ - Remote: true, - ErrorCode: qerr.TransportErrorCode(frame.ErrorCode), - FrameType: frame.FrameType, - ErrorMessage: frame.ReasonPhrase, - }) -} - -func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { - encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) - if err != nil { - return err - } - if encLevelChanged { - // Queue all packets for decryption that have been undecryptable so far. - s.undecryptablePacketsToProcess = s.undecryptablePackets - s.undecryptablePackets = nil - } - return nil -} - -func (s *connection) handleStreamFrame(frame *wire.StreamFrame) error { - str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) - if err != nil { - return err - } - if str == nil { - // Stream is closed and already garbage collected - // ignore this StreamFrame - return nil - } - return str.handleStreamFrame(frame) -} - -func (s *connection) handleMaxDataFrame(frame *wire.MaxDataFrame) { - s.connFlowController.UpdateSendWindow(frame.MaximumData) -} - -func (s *connection) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { - str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) - if err != nil { - return err - } - if str == nil { - // stream is closed and already garbage collected - return nil - } - str.updateSendWindow(frame.MaximumStreamData) - return nil -} - -func (s *connection) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) { - s.streamsMap.HandleMaxStreamsFrame(frame) -} - -func (s *connection) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { - str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) - if err != nil { - return err - } - if str == nil { - // stream is closed and already garbage collected - return nil - } - return str.handleResetStreamFrame(frame) -} - -func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error { - str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) - if err != nil { - return err - } - if str == nil { - // stream is closed and already garbage collected - return nil - } - str.handleStopSendingFrame(frame) - return nil -} - -func (s *connection) handlePathChallengeFrame(frame *wire.PathChallengeFrame) { - s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data}) -} - -func (s *connection) handleNewTokenFrame(frame *wire.NewTokenFrame) error { - if s.perspective == protocol.PerspectiveServer { - return &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "received NEW_TOKEN frame from the client", - } - } - if s.config.TokenStore != nil { - s.config.TokenStore.Put(s.tokenStoreKey, &ClientToken{data: frame.Token}) - } - return nil -} - -func (s *connection) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) error { - return s.connIDManager.Add(f) -} - -func (s *connection) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error { - return s.connIDGenerator.Retire(f.SequenceNumber, destConnID) -} - -func (s *connection) handleHandshakeDoneFrame() error { - if s.perspective == protocol.PerspectiveServer { - return &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "received a HANDSHAKE_DONE frame", - } - } - if !s.handshakeConfirmed { - s.handleHandshakeConfirmed() - } - return nil -} - -func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { - acked1RTTPacket, err := s.sentPacketHandler.ReceivedAck(frame, encLevel, s.lastPacketReceivedTime) - if err != nil { - return err - } - if !acked1RTTPacket { - return nil - } - if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed { - s.handleHandshakeConfirmed() - } - return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) -} - -func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error { - if f.Length(s.version) > protocol.MaxDatagramFrameSize { - return &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "DATAGRAM frame too large", - } - } - s.datagramQueue.HandleDatagramFrame(f) - return nil -} - -// closeLocal closes the connection and send a CONNECTION_CLOSE containing the error -func (s *connection) closeLocal(e error) { - s.closeOnce.Do(func() { - if e == nil { - s.logger.Infof("Closing connection.") - } else { - s.logger.Errorf("Closing connection with error: %s", e) - } - s.closeChan <- closeError{err: e, immediate: false, remote: false} - }) -} - -// destroy closes the connection without sending the error on the wire -func (s *connection) destroy(e error) { - s.destroyImpl(e) - <-s.ctx.Done() -} - -func (s *connection) destroyImpl(e error) { - s.closeOnce.Do(func() { - if nerr, ok := e.(net.Error); ok && nerr.Timeout() { - s.logger.Errorf("Destroying connection: %s", e) - } else { - s.logger.Errorf("Destroying connection with error: %s", e) - } - s.closeChan <- closeError{err: e, immediate: true, remote: false} - }) -} - -func (s *connection) closeRemote(e error) { - s.closeOnce.Do(func() { - s.logger.Errorf("Peer closed connection with error: %s", e) - s.closeChan <- closeError{err: e, immediate: true, remote: true} - }) -} - -// Close the connection. It sends a NO_ERROR application error. -// It waits until the run loop has stopped before returning -func (s *connection) shutdown() { - s.closeLocal(nil) - <-s.ctx.Done() -} - -func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error { - s.closeLocal(&qerr.ApplicationError{ - ErrorCode: code, - ErrorMessage: desc, - }) - <-s.ctx.Done() - return nil -} - -func (s *connection) handleCloseError(closeErr *closeError) { - e := closeErr.err - if e == nil { - e = &qerr.ApplicationError{} - } else { - defer func() { - closeErr.err = e - }() - } - - var ( - statelessResetErr *StatelessResetError - versionNegotiationErr *VersionNegotiationError - recreateErr *errCloseForRecreating - applicationErr *ApplicationError - transportErr *TransportError - ) - switch { - case errors.Is(e, qerr.ErrIdleTimeout), - errors.Is(e, qerr.ErrHandshakeTimeout), - errors.As(e, &statelessResetErr), - errors.As(e, &versionNegotiationErr), - errors.As(e, &recreateErr), - errors.As(e, &applicationErr), - errors.As(e, &transportErr): - default: - e = &qerr.TransportError{ - ErrorCode: qerr.InternalError, - ErrorMessage: e.Error(), - } - } - - s.streamsMap.CloseWithError(e) - s.connIDManager.Close() - if s.datagramQueue != nil { - s.datagramQueue.CloseWithError(e) - } - - if s.tracer != nil && !errors.As(e, &recreateErr) { - s.tracer.ClosedConnection(e) - } - - // If this is a remote close we're done here - if closeErr.remote { - s.connIDGenerator.ReplaceWithClosed(newClosedRemoteConn(s.perspective)) - return - } - if closeErr.immediate { - s.connIDGenerator.RemoveAll() - return - } - // Don't send out any CONNECTION_CLOSE if this is an error that occurred - // before we even sent out the first packet. - if s.perspective == protocol.PerspectiveClient && !s.sentFirstPacket { - s.connIDGenerator.RemoveAll() - return - } - connClosePacket, err := s.sendConnectionClose(e) - if err != nil { - s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) - } - cs := newClosedLocalConn(s.conn, connClosePacket, s.perspective, s.logger) - s.connIDGenerator.ReplaceWithClosed(cs) -} - -func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { - s.sentPacketHandler.DropPackets(encLevel) - s.receivedPacketHandler.DropPackets(encLevel) - if s.tracer != nil { - s.tracer.DroppedEncryptionLevel(encLevel) - } - if encLevel == protocol.Encryption0RTT { - s.streamsMap.ResetFor0RTT() - if err := s.connFlowController.Reset(); err != nil { - s.closeLocal(err) - } - if err := s.framer.Handle0RTTRejection(); err != nil { - s.closeLocal(err) - } - } -} - -// is called for the client, when restoring transport parameters saved for 0-RTT -func (s *connection) restoreTransportParameters(params *wire.TransportParameters) { - if s.logger.Debug() { - s.logger.Debugf("Restoring Transport Parameters: %s", params) - } - - s.peerParams = params - s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) - s.connFlowController.UpdateSendWindow(params.InitialMaxData) - s.streamsMap.UpdateLimits(params) -} - -func (s *connection) handleTransportParameters(params *wire.TransportParameters) { - if err := s.checkTransportParameters(params); err != nil { - s.closeLocal(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: err.Error(), - }) - } - s.peerParams = params - // On the client side we have to wait for handshake completion. - // During a 0-RTT connection, we are only allowed to use the new transport parameters for 1-RTT packets. - if s.perspective == protocol.PerspectiveServer { - s.applyTransportParameters() - // On the server side, the early connection is ready as soon as we processed - // the client's transport parameters. - close(s.earlyConnReadyChan) - } -} - -func (s *connection) checkTransportParameters(params *wire.TransportParameters) error { - if s.logger.Debug() { - s.logger.Debugf("Processed Transport Parameters: %s", params) - } - if s.tracer != nil { - s.tracer.ReceivedTransportParameters(params) - } - - // check the initial_source_connection_id - if !params.InitialSourceConnectionID.Equal(s.handshakeDestConnID) { - return fmt.Errorf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID) - } - - if s.perspective == protocol.PerspectiveServer { - return nil - } - // check the original_destination_connection_id - if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { - return fmt.Errorf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID) - } - if s.retrySrcConnID != nil { // a Retry was performed - if params.RetrySourceConnectionID == nil { - return errors.New("missing retry_source_connection_id") - } - if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) { - return fmt.Errorf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID) - } - } else if params.RetrySourceConnectionID != nil { - return errors.New("received retry_source_connection_id, although no Retry was performed") - } - return nil -} - -func (s *connection) applyTransportParameters() { - params := s.peerParams - // Our local idle timeout will always be > 0. - s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) - s.keepAliveInterval = utils.MinDuration(s.config.KeepAlivePeriod, utils.MinDuration(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) - s.streamsMap.UpdateLimits(params) - s.packer.HandleTransportParameters(params) - s.frameParser.SetAckDelayExponent(params.AckDelayExponent) - s.connFlowController.UpdateSendWindow(params.InitialMaxData) - s.rttStats.SetMaxAckDelay(params.MaxAckDelay) - s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) - if params.StatelessResetToken != nil { - s.connIDManager.SetStatelessResetToken(*params.StatelessResetToken) - } - // We don't support connection migration yet, so we don't have any use for the preferred_address. - if params.PreferredAddress != nil { - // Retire the connection ID. - s.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, params.PreferredAddress.StatelessResetToken) - } -} - -func (s *connection) sendPackets() error { - s.pacingDeadline = time.Time{} - - var sentPacket bool // only used in for packets sent in send mode SendAny - for { - sendMode := s.sentPacketHandler.SendMode() - if sendMode == ackhandler.SendAny && s.handshakeComplete && !s.sentPacketHandler.HasPacingBudget() { - deadline := s.sentPacketHandler.TimeUntilSend() - if deadline.IsZero() { - deadline = deadlineSendImmediately - } - s.pacingDeadline = deadline - // Allow sending of an ACK if we're pacing limit (if we haven't sent out a packet yet). - // This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate) - // sends enough ACKs to allow its peer to utilize the bandwidth. - if sentPacket { - return nil - } - sendMode = ackhandler.SendAck - } - switch sendMode { - case ackhandler.SendNone: - return nil - case ackhandler.SendAck: - // If we already sent packets, and the send mode switches to SendAck, - // as we've just become congestion limited. - // There's no need to try to send an ACK at this moment. - if sentPacket { - return nil - } - // We can at most send a single ACK only packet. - // There will only be a new ACK after receiving new packets. - // SendAck is only returned when we're congestion limited, so we don't need to set the pacingt timer. - return s.maybeSendAckOnlyPacket() - case ackhandler.SendPTOInitial: - if err := s.sendProbePacket(protocol.EncryptionInitial); err != nil { - return err - } - case ackhandler.SendPTOHandshake: - if err := s.sendProbePacket(protocol.EncryptionHandshake); err != nil { - return err - } - case ackhandler.SendPTOAppData: - if err := s.sendProbePacket(protocol.Encryption1RTT); err != nil { - return err - } - case ackhandler.SendAny: - sent, err := s.sendPacket() - if err != nil || !sent { - return err - } - sentPacket = true - default: - return fmt.Errorf("BUG: invalid send mode %d", sendMode) - } - // Prioritize receiving of packets over sending out more packets. - if len(s.receivedPackets) > 0 { - s.pacingDeadline = deadlineSendImmediately - return nil - } - if s.sendQueue.WouldBlock() { - return nil - } - } -} - -func (s *connection) maybeSendAckOnlyPacket() error { - packet, err := s.packer.MaybePackAckPacket(s.handshakeConfirmed) - if err != nil { - return err - } - if packet == nil { - return nil - } - s.sendPackedPacket(packet, time.Now()) - return nil -} - -func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { - // Queue probe packets until we actually send out a packet, - // or until there are no more packets to queue. - var packet *packedPacket - for { - if wasQueued := s.sentPacketHandler.QueueProbePacket(encLevel); !wasQueued { - break - } - var err error - packet, err = s.packer.MaybePackProbePacket(encLevel) - if err != nil { - return err - } - if packet != nil { - break - } - } - if packet == nil { - //nolint:exhaustive // Cannot send probe packets for 0-RTT. - switch encLevel { - case protocol.EncryptionInitial: - s.retransmissionQueue.AddInitial(&wire.PingFrame{}) - case protocol.EncryptionHandshake: - s.retransmissionQueue.AddHandshake(&wire.PingFrame{}) - case protocol.Encryption1RTT: - s.retransmissionQueue.AddAppData(&wire.PingFrame{}) - default: - panic("unexpected encryption level") - } - var err error - packet, err = s.packer.MaybePackProbePacket(encLevel) - if err != nil { - return err - } - } - if packet == nil || packet.packetContents == nil { - return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) - } - s.sendPackedPacket(packet, time.Now()) - return nil -} - -func (s *connection) sendPacket() (bool, error) { - if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { - s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset}) - } - s.windowUpdateQueue.QueueAll() - - now := time.Now() - if !s.handshakeConfirmed { - packet, err := s.packer.PackCoalescedPacket() - if err != nil || packet == nil { - return false, err - } - s.sentFirstPacket = true - s.logCoalescedPacket(packet) - for _, p := range packet.packets { - if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { - s.firstAckElicitingPacketAfterIdleSentTime = now - } - s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue)) - } - s.connIDManager.SentPacket() - s.sendQueue.Send(packet.buffer) - return true, nil - } - if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) { - packet, err := s.packer.PackMTUProbePacket(s.mtuDiscoverer.GetPing()) - if err != nil { - return false, err - } - s.sendPackedPacket(packet, now) - return true, nil - } - packet, err := s.packer.PackPacket() - if err != nil || packet == nil { - return false, err - } - s.sendPackedPacket(packet, now) - return true, nil -} - -func (s *connection) sendPackedPacket(packet *packedPacket, now time.Time) { - if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && packet.IsAckEliciting() { - s.firstAckElicitingPacketAfterIdleSentTime = now - } - s.logPacket(packet) - s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(now, s.retransmissionQueue)) - s.connIDManager.SentPacket() - s.sendQueue.Send(packet.buffer) -} - -func (s *connection) sendConnectionClose(e error) ([]byte, error) { - var packet *coalescedPacket - var err error - var transportErr *qerr.TransportError - var applicationErr *qerr.ApplicationError - if errors.As(e, &transportErr) { - packet, err = s.packer.PackConnectionClose(transportErr) - } else if errors.As(e, &applicationErr) { - packet, err = s.packer.PackApplicationClose(applicationErr) - } else { - packet, err = s.packer.PackConnectionClose(&qerr.TransportError{ - ErrorCode: qerr.InternalError, - ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()), - }) - } - if err != nil { - return nil, err - } - s.logCoalescedPacket(packet) - return packet.buffer.Data, s.conn.Write(packet.buffer.Data) -} - -func (s *connection) logPacketContents(p *packetContents) { - // tracing - if s.tracer != nil { - frames := make([]logging.Frame, 0, len(p.frames)) - for _, f := range p.frames { - frames = append(frames, logutils.ConvertFrame(f.Frame)) - } - s.tracer.SentPacket(p.header, p.length, p.ack, frames) - } - - // quic-go logging - if !s.logger.Debug() { - return - } - p.header.Log(s.logger) - if p.ack != nil { - wire.LogFrame(s.logger, p.ack, true) - } - for _, frame := range p.frames { - wire.LogFrame(s.logger, frame.Frame, true) - } -} - -func (s *connection) logCoalescedPacket(packet *coalescedPacket) { - if s.logger.Debug() { - if len(packet.packets) > 1 { - s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.packets), packet.buffer.Len(), s.logID) - } else { - s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.packets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.packets[0].EncryptionLevel()) - } - } - for _, p := range packet.packets { - s.logPacketContents(p) - } -} - -func (s *connection) logPacket(packet *packedPacket) { - if s.logger.Debug() { - s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.header.PacketNumber, packet.buffer.Len(), s.logID, packet.EncryptionLevel()) - } - s.logPacketContents(packet.packetContents) -} - -// AcceptStream returns the next stream openend by the peer -func (s *connection) AcceptStream(ctx context.Context) (Stream, error) { - return s.streamsMap.AcceptStream(ctx) -} - -func (s *connection) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { - return s.streamsMap.AcceptUniStream(ctx) -} - -// OpenStream opens a stream -func (s *connection) OpenStream() (Stream, error) { - return s.streamsMap.OpenStream() -} - -func (s *connection) OpenStreamSync(ctx context.Context) (Stream, error) { - return s.streamsMap.OpenStreamSync(ctx) -} - -func (s *connection) OpenUniStream() (SendStream, error) { - return s.streamsMap.OpenUniStream() -} - -func (s *connection) OpenUniStreamSync(ctx context.Context) (SendStream, error) { - return s.streamsMap.OpenUniStreamSync(ctx) -} - -func (s *connection) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController { - initialSendWindow := s.peerParams.InitialMaxStreamDataUni - if id.Type() == protocol.StreamTypeBidi { - if id.InitiatedBy() == s.perspective { - initialSendWindow = s.peerParams.InitialMaxStreamDataBidiRemote - } else { - initialSendWindow = s.peerParams.InitialMaxStreamDataBidiLocal - } - } - return flowcontrol.NewStreamFlowController( - id, - s.connFlowController, - protocol.ByteCount(s.config.InitialStreamReceiveWindow), - protocol.ByteCount(s.config.MaxStreamReceiveWindow), - initialSendWindow, - s.onHasStreamWindowUpdate, - s.rttStats, - s.logger, - ) -} - -// scheduleSending signals that we have data for sending -func (s *connection) scheduleSending() { - select { - case s.sendingScheduled <- struct{}{}: - default: - } -} - -func (s *connection) tryQueueingUndecryptablePacket(p *receivedPacket, hdr *wire.Header) { - if s.handshakeComplete { - panic("shouldn't queue undecryptable packets after handshake completion") - } - if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDOSPrevention) - } - s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size()) - return - } - s.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size()) - if s.tracer != nil { - s.tracer.BufferedPacket(logging.PacketTypeFromHeader(hdr)) - } - s.undecryptablePackets = append(s.undecryptablePackets, p) -} - -func (s *connection) queueControlFrame(f wire.Frame) { - s.framer.QueueControlFrame(f) - s.scheduleSending() -} - -func (s *connection) onHasStreamWindowUpdate(id protocol.StreamID) { - s.windowUpdateQueue.AddStream(id) - s.scheduleSending() -} - -func (s *connection) onHasConnectionWindowUpdate() { - s.windowUpdateQueue.AddConnection() - s.scheduleSending() -} - -func (s *connection) onHasStreamData(id protocol.StreamID) { - s.framer.AddActiveStream(id) - s.scheduleSending() -} - -func (s *connection) onStreamCompleted(id protocol.StreamID) { - if err := s.streamsMap.DeleteStream(id); err != nil { - s.closeLocal(err) - } -} - -func (s *connection) SendMessage(p []byte) error { - f := &wire.DatagramFrame{DataLenPresent: true} - if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) { - return errors.New("message too large") - } - f.Data = make([]byte, len(p)) - copy(f.Data, p) - return s.datagramQueue.AddAndWait(f) -} - -func (s *connection) ReceiveMessage() ([]byte, error) { - return s.datagramQueue.Receive() -} - -func (s *connection) LocalAddr() net.Addr { - return s.conn.LocalAddr() -} - -func (s *connection) RemoteAddr() net.Addr { - return s.conn.RemoteAddr() -} - -func (s *connection) getPerspective() protocol.Perspective { - return s.perspective -} - -func (s *connection) GetVersion() protocol.VersionNumber { - return s.version -} - -func (s *connection) NextConnection() Connection { - <-s.HandshakeComplete().Done() - s.streamsMap.UseResetMaps() - return s -} diff --git a/internal/quic-go/connection_test.go b/internal/quic-go/connection_test.go deleted file mode 100644 index 66c88156..00000000 --- a/internal/quic-go/connection_test.go +++ /dev/null @@ -1,3038 +0,0 @@ -package quic - -import ( - "bytes" - "context" - "crypto/rand" - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "runtime/pprof" - "strings" - "time" - - "github.com/imroc/req/v3/internal/quic-go/ackhandler" - "github.com/imroc/req/v3/internal/quic-go/handshake" - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/mocks" - mockackhandler "github.com/imroc/req/v3/internal/quic-go/mocks/ackhandler" - mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/testutils" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func areConnsRunning() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic-go.(*connection).run") -} - -func areClosedConnsRunning() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic-go.(*closedLocalConn).run") -} - -var _ = Describe("Connection", func() { - var ( - conn *connection - connRunner *MockConnRunner - mconn *MockSendConn - streamManager *MockStreamManager - packer *MockPacker - cryptoSetup *mocks.MockCryptoSetup - tracer *mocklogging.MockConnectionTracer - ) - remoteAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7331} - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - - getPacket := func(pn protocol.PacketNumber) *packedPacket { - buffer := getPacketBuffer() - buffer.Data = append(buffer.Data, []byte("foobar")...) - return &packedPacket{ - buffer: buffer, - packetContents: &packetContents{ - header: &wire.ExtendedHeader{PacketNumber: pn}, - length: 6, // foobar - }, - } - } - - expectReplaceWithClosed := func() { - connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).MaxTimes(1) - connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{})) - s.shutdown() - Eventually(areClosedConnsRunning).Should(BeFalse()) - }) - } - - BeforeEach(func() { - Eventually(areConnsRunning).Should(BeFalse()) - - connRunner = NewMockConnRunner(mockCtrl) - mconn = NewMockSendConn(mockCtrl) - mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() - mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() - tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) - tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) - tracer.EXPECT().SentTransportParameters(gomock.Any()) - tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() - tracer.EXPECT().UpdatedCongestionState(gomock.Any()) - conn = newConnection( - mconn, - connRunner, - nil, - nil, - clientDestConnID, - destConnID, - srcConnID, - protocol.StatelessResetToken{}, - populateServerConfig(&Config{DisablePathMTUDiscovery: true}), - nil, // tls.Config - tokenGenerator, - false, - tracer, - 1234, - utils.DefaultLogger, - protocol.VersionTLS, - ).(*connection) - streamManager = NewMockStreamManager(mockCtrl) - conn.streamsMap = streamManager - packer = NewMockPacker(mockCtrl) - conn.packer = packer - cryptoSetup = mocks.NewMockCryptoSetup(mockCtrl) - conn.cryptoStreamHandler = cryptoSetup - conn.handshakeComplete = true - conn.idleTimeout = time.Hour - }) - - AfterEach(func() { - Eventually(areConnsRunning).Should(BeFalse()) - }) - - Context("frame handling", func() { - Context("handling STREAM frames", func() { - It("passes STREAM frames to the stream", func() { - f := &wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xde, 0xca, 0xfb, 0xad}, - } - str := NewMockReceiveStreamI(mockCtrl) - str.EXPECT().handleStreamFrame(f) - streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil) - Expect(conn.handleStreamFrame(f)).To(Succeed()) - }) - - It("returns errors", func() { - testErr := errors.New("test err") - f := &wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xde, 0xca, 0xfb, 0xad}, - } - str := NewMockReceiveStreamI(mockCtrl) - str.EXPECT().handleStreamFrame(f).Return(testErr) - streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil) - Expect(conn.handleStreamFrame(f)).To(MatchError(testErr)) - }) - - It("ignores STREAM frames for closed streams", func() { - streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(nil, nil) // for closed streams, the streamManager returns nil - Expect(conn.handleStreamFrame(&wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - })).To(Succeed()) - }) - }) - - Context("handling ACK frames", func() { - It("informs the SentPacketHandler about ACKs", func() { - f := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 3}}} - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().ReceivedAck(f, protocol.EncryptionHandshake, gomock.Any()) - conn.sentPacketHandler = sph - err := conn.handleAckFrame(f, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - }) - }) - - Context("handling RESET_STREAM frames", func() { - It("closes the streams for writing", func() { - f := &wire.ResetStreamFrame{ - StreamID: 555, - ErrorCode: 42, - FinalSize: 0x1337, - } - str := NewMockReceiveStreamI(mockCtrl) - streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(555)).Return(str, nil) - str.EXPECT().handleResetStreamFrame(f) - err := conn.handleResetStreamFrame(f) - Expect(err).ToNot(HaveOccurred()) - }) - - It("returns errors", func() { - f := &wire.ResetStreamFrame{ - StreamID: 7, - FinalSize: 0x1337, - } - testErr := errors.New("flow control violation") - str := NewMockReceiveStreamI(mockCtrl) - streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(7)).Return(str, nil) - str.EXPECT().handleResetStreamFrame(f).Return(testErr) - err := conn.handleResetStreamFrame(f) - Expect(err).To(MatchError(testErr)) - }) - - It("ignores RESET_STREAM frames for closed streams", func() { - streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(nil, nil) - Expect(conn.handleFrame(&wire.ResetStreamFrame{ - StreamID: 3, - ErrorCode: 42, - }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) - }) - }) - - Context("handling MAX_DATA and MAX_STREAM_DATA frames", func() { - var connFC *mocks.MockConnectionFlowController - - BeforeEach(func() { - connFC = mocks.NewMockConnectionFlowController(mockCtrl) - conn.connFlowController = connFC - }) - - It("updates the flow control window of a stream", func() { - f := &wire.MaxStreamDataFrame{ - StreamID: 12345, - MaximumStreamData: 0x1337, - } - str := NewMockSendStreamI(mockCtrl) - streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(12345)).Return(str, nil) - str.EXPECT().updateSendWindow(protocol.ByteCount(0x1337)) - Expect(conn.handleMaxStreamDataFrame(f)).To(Succeed()) - }) - - It("updates the flow control window of the connection", func() { - offset := protocol.ByteCount(0x800000) - connFC.EXPECT().UpdateSendWindow(offset) - conn.handleMaxDataFrame(&wire.MaxDataFrame{MaximumData: offset}) - }) - - It("ignores MAX_STREAM_DATA frames for a closed stream", func() { - streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(10)).Return(nil, nil) - Expect(conn.handleFrame(&wire.MaxStreamDataFrame{ - StreamID: 10, - MaximumStreamData: 1337, - }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) - }) - }) - - Context("handling MAX_STREAM_ID frames", func() { - It("passes the frame to the streamsMap", func() { - f := &wire.MaxStreamsFrame{ - Type: protocol.StreamTypeUni, - MaxStreamNum: 10, - } - streamManager.EXPECT().HandleMaxStreamsFrame(f) - conn.handleMaxStreamsFrame(f) - }) - }) - - Context("handling STOP_SENDING frames", func() { - It("passes the frame to the stream", func() { - f := &wire.StopSendingFrame{ - StreamID: 5, - ErrorCode: 10, - } - str := NewMockSendStreamI(mockCtrl) - streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(5)).Return(str, nil) - str.EXPECT().handleStopSendingFrame(f) - err := conn.handleStopSendingFrame(f) - Expect(err).ToNot(HaveOccurred()) - }) - - It("ignores STOP_SENDING frames for a closed stream", func() { - streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(3)).Return(nil, nil) - Expect(conn.handleFrame(&wire.StopSendingFrame{ - StreamID: 3, - ErrorCode: 1337, - }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) - }) - }) - - It("handles NEW_CONNECTION_ID frames", func() { - Expect(conn.handleFrame(&wire.NewConnectionIDFrame{ - SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) - Expect(conn.connIDManager.queue.Back().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) - }) - - It("handles PING frames", func() { - err := conn.handleFrame(&wire.PingFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) - Expect(err).NotTo(HaveOccurred()) - }) - - It("rejects PATH_RESPONSE frames", func() { - err := conn.handleFrame(&wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, protocol.Encryption1RTT, protocol.ConnectionID{}) - Expect(err).To(MatchError("unexpected PATH_RESPONSE frame")) - }) - - It("handles PATH_CHALLENGE frames", func() { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - err := conn.handleFrame(&wire.PathChallengeFrame{Data: data}, protocol.Encryption1RTT, protocol.ConnectionID{}) - Expect(err).ToNot(HaveOccurred()) - frames, _ := conn.framer.AppendControlFrames(nil, 1000) - Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PathResponseFrame{Data: data}}})) - }) - - It("rejects NEW_TOKEN frames", func() { - err := conn.handleNewTokenFrame(&wire.NewTokenFrame{}) - Expect(err).To(HaveOccurred()) - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ProtocolViolation)) - }) - - It("handles BLOCKED frames", func() { - err := conn.handleFrame(&wire.DataBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) - Expect(err).NotTo(HaveOccurred()) - }) - - It("handles STREAM_BLOCKED frames", func() { - err := conn.handleFrame(&wire.StreamDataBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) - Expect(err).NotTo(HaveOccurred()) - }) - - It("handles STREAMS_BLOCKED frames", func() { - err := conn.handleFrame(&wire.StreamsBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) - Expect(err).NotTo(HaveOccurred()) - }) - - It("handles CONNECTION_CLOSE frames, with a transport error code", func() { - expectedErr := &qerr.TransportError{ - Remote: true, - ErrorCode: qerr.StreamLimitError, - ErrorMessage: "foobar", - } - streamManager.EXPECT().CloseWithError(expectedErr) - connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) - }) - connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) - }) - cryptoSetup.EXPECT().Close() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(expectedErr), - tracer.EXPECT().Close(), - ) - - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - Expect(conn.run()).To(MatchError(expectedErr)) - }() - Expect(conn.handleFrame(&wire.ConnectionCloseFrame{ - ErrorCode: uint64(qerr.StreamLimitError), - ReasonPhrase: "foobar", - }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("handles CONNECTION_CLOSE frames, with an application error code", func() { - testErr := &qerr.ApplicationError{ - Remote: true, - ErrorCode: 0x1337, - ErrorMessage: "foobar", - } - streamManager.EXPECT().CloseWithError(testErr) - connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) - }) - connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) - }) - cryptoSetup.EXPECT().Close() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(testErr), - tracer.EXPECT().Close(), - ) - - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - Expect(conn.run()).To(MatchError(testErr)) - }() - ccf := &wire.ConnectionCloseFrame{ - ErrorCode: 0x1337, - ReasonPhrase: "foobar", - IsApplicationError: true, - } - Expect(conn.handleFrame(ccf, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("errors on HANDSHAKE_DONE frames", func() { - Expect(conn.handleHandshakeDoneFrame()).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "received a HANDSHAKE_DONE frame", - })) - }) - }) - - It("tells its versions", func() { - conn.version = 4242 - Expect(conn.GetVersion()).To(Equal(protocol.VersionNumber(4242))) - }) - - Context("closing", func() { - var ( - runErr chan error - expectedRunErr error - ) - - BeforeEach(func() { - runErr = make(chan error, 1) - expectedRunErr = nil - }) - - AfterEach(func() { - if expectedRunErr != nil { - Eventually(runErr).Should(Receive(MatchError(expectedRunErr))) - } else { - Eventually(runErr).Should(Receive()) - } - }) - - runConn := func() { - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - runErr <- conn.run() - }() - Eventually(areConnsRunning).Should(BeTrue()) - } - - It("shuts down without error", func() { - conn.handshakeComplete = true - runConn() - streamManager.EXPECT().CloseWithError(&qerr.ApplicationError{}) - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - buffer := getPacketBuffer() - buffer.Data = append(buffer.Data, []byte("connection close")...) - packer.EXPECT().PackApplicationClose(gomock.Any()).DoAndReturn(func(e *qerr.ApplicationError) (*coalescedPacket, error) { - Expect(e.ErrorCode).To(BeEquivalentTo(qerr.NoError)) - Expect(e.ErrorMessage).To(BeEmpty()) - return &coalescedPacket{buffer: buffer}, nil - }) - mconn.EXPECT().Write([]byte("connection close")) - gomock.InOrder( - tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - var appErr *ApplicationError - Expect(errors.As(e, &appErr)).To(BeTrue()) - Expect(appErr.Remote).To(BeFalse()) - Expect(appErr.ErrorCode).To(BeZero()) - }), - tracer.EXPECT().Close(), - ) - conn.shutdown() - Eventually(areConnsRunning).Should(BeFalse()) - Expect(conn.Context().Done()).To(BeClosed()) - }) - - It("only closes once", func() { - runConn() - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - conn.shutdown() - Eventually(areConnsRunning).Should(BeFalse()) - Expect(conn.Context().Done()).To(BeClosed()) - }) - - It("closes with an error", func() { - runConn() - expectedErr := &qerr.ApplicationError{ - ErrorCode: 0x1337, - ErrorMessage: "test error", - } - streamManager.EXPECT().CloseWithError(expectedErr) - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackApplicationClose(expectedErr).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any()) - gomock.InOrder( - tracer.EXPECT().ClosedConnection(expectedErr), - tracer.EXPECT().Close(), - ) - conn.CloseWithError(0x1337, "test error") - Eventually(areConnsRunning).Should(BeFalse()) - Expect(conn.Context().Done()).To(BeClosed()) - }) - - It("includes the frame type in transport-level close frames", func() { - runConn() - expectedErr := &qerr.TransportError{ - ErrorCode: 0x1337, - FrameType: 0x42, - ErrorMessage: "test error", - } - streamManager.EXPECT().CloseWithError(expectedErr) - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(expectedErr).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any()) - gomock.InOrder( - tracer.EXPECT().ClosedConnection(expectedErr), - tracer.EXPECT().Close(), - ) - conn.closeLocal(expectedErr) - Eventually(areConnsRunning).Should(BeFalse()) - Expect(conn.Context().Done()).To(BeClosed()) - }) - - It("destroys the connection", func() { - runConn() - testErr := errors.New("close") - streamManager.EXPECT().CloseWithError(gomock.Any()) - connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - cryptoSetup.EXPECT().Close() - // don't EXPECT any calls to mconn.Write() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - var transportErr *TransportError - Expect(errors.As(e, &transportErr)).To(BeTrue()) - Expect(transportErr.Remote).To(BeFalse()) - Expect(transportErr.ErrorCode).To(Equal(qerr.InternalError)) - }), - tracer.EXPECT().Close(), - ) - conn.destroy(testErr) - Eventually(areConnsRunning).Should(BeFalse()) - expectedRunErr = &qerr.TransportError{ - ErrorCode: qerr.InternalError, - ErrorMessage: testErr.Error(), - } - }) - - It("cancels the context when the run loop exists", func() { - runConn() - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - returned := make(chan struct{}) - go func() { - defer GinkgoRecover() - ctx := conn.Context() - <-ctx.Done() - Expect(ctx.Err()).To(MatchError(context.Canceled)) - close(returned) - }() - Consistently(returned).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(returned).Should(BeClosed()) - }) - - It("doesn't send any more packets after receiving a CONNECTION_CLOSE", func() { - unpacker := NewMockUnpacker(mockCtrl) - conn.handshakeConfirmed = true - conn.unpacker = unpacker - runConn() - cryptoSetup.EXPECT().Close() - streamManager.EXPECT().CloseWithError(gomock.Any()) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes() - buf := &bytes.Buffer{} - hdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumberLen: protocol.PacketNumberLen2, - } - Expect(hdr.Write(buf, conn.version)).To(Succeed()) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) { - buf := &bytes.Buffer{} - Expect((&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Write(buf, conn.version)).To(Succeed()) - return &unpackedPacket{ - hdr: hdr, - data: buf.Bytes(), - encryptionLevel: protocol.Encryption1RTT, - }, nil - }) - gomock.InOrder( - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), - tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()), - tracer.EXPECT().ClosedConnection(gomock.Any()), - tracer.EXPECT().Close(), - ) - // don't EXPECT any calls to packer.PackPacket() - conn.handlePacket(&receivedPacket{ - rcvTime: time.Now(), - remoteAddr: &net.UDPAddr{}, - buffer: getPacketBuffer(), - data: buf.Bytes(), - }) - // Consistently(pack).ShouldNot(Receive()) - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("closes when the sendQueue encounters an error", func() { - conn.handshakeConfirmed = true - sconn := NewMockSendConn(mockCtrl) - sconn.EXPECT().Write(gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() - conn.sendQueue = newSendQueue(sconn) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - // only expect a single SentPacket() call - sph.EXPECT().SentPacket(gomock.Any()) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - streamManager.EXPECT().CloseWithError(gomock.Any()) - connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - cryptoSetup.EXPECT().Close() - conn.sentPacketHandler = sph - p := getPacket(1) - packer.EXPECT().PackPacket().Return(p, nil) - packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() - runConn() - conn.queueControlFrame(&wire.PingFrame{}) - conn.scheduleSending() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("closes due to a stateless reset", func() { - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - runConn() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - var srErr *StatelessResetError - Expect(errors.As(e, &srErr)).To(BeTrue()) - Expect(srErr.Token).To(Equal(token)) - }), - tracer.EXPECT().Close(), - ) - streamManager.EXPECT().CloseWithError(gomock.Any()) - connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - cryptoSetup.EXPECT().Close() - conn.destroy(&StatelessResetError{Token: token}) - }) - }) - - Context("receiving packets", func() { - var unpacker *MockUnpacker - - BeforeEach(func() { - unpacker = NewMockUnpacker(mockCtrl) - conn.unpacker = unpacker - }) - - getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { - buf := &bytes.Buffer{} - Expect(extHdr.Write(buf, conn.version)).To(Succeed()) - return &receivedPacket{ - data: append(buf.Bytes(), data...), - buffer: getPacketBuffer(), - rcvTime: time.Now(), - } - } - - It("drops Retry packets", func() { - p := getPacket(&wire.ExtendedHeader{Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - DestConnectionID: destConnID, - SrcConnectionID: srcConnID, - Version: conn.version, - Token: []byte("foobar"), - }}, make([]byte, 16) /* Retry integrity tag */) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) - Expect(conn.handlePacketImpl(p)).To(BeFalse()) - }) - - It("drops Version Negotiation packets", func() { - b := wire.ComposeVersionNegotiation(srcConnID, destConnID, conn.config.Versions) - tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) - Expect(conn.handlePacketImpl(&receivedPacket{ - data: b, - buffer: getPacketBuffer(), - })).To(BeFalse()) - }) - - It("drops packets for which header decryption fails", func() { - p := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Version: conn.version, - }, - PacketNumberLen: protocol.PacketNumberLen2, - }, nil) - p.data[0] ^= 0x40 // unset the QUIC bit - tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) - Expect(conn.handlePacketImpl(p)).To(BeFalse()) - }) - - It("drops packets for which the version is unsupported", func() { - p := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Version: conn.version + 1, - }, - PacketNumberLen: protocol.PacketNumberLen2, - }, nil) - tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnsupportedVersion) - Expect(conn.handlePacketImpl(p)).To(BeFalse()) - }) - - It("drops packets with an unsupported version", func() { - origSupportedVersions := make([]protocol.VersionNumber, len(protocol.SupportedVersions)) - copy(origSupportedVersions, protocol.SupportedVersions) - defer func() { - protocol.SupportedVersions = origSupportedVersions - }() - - protocol.SupportedVersions = append(protocol.SupportedVersions, conn.version+1) - p := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: destConnID, - SrcConnectionID: srcConnID, - Version: conn.version + 1, - }, - PacketNumberLen: protocol.PacketNumberLen2, - }, nil) - tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedVersion) - Expect(conn.handlePacketImpl(p)).To(BeFalse()) - }) - - It("informs the ReceivedPacketHandler about non-ack-eliciting packets", func() { - hdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumber: 0x37, - PacketNumberLen: protocol.PacketNumberLen1, - } - packet := getPacket(hdr, nil) - packet.ecn = protocol.ECNCE - rcvTime := time.Now().Add(-10 * time.Second) - unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ - packetNumber: 0x1337, - encryptionLevel: protocol.EncryptionInitial, - hdr: hdr, - data: []byte{0}, // one PADDING frame - }, nil) - rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - gomock.InOrder( - rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial), - rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNCE, protocol.EncryptionInitial, rcvTime, false), - ) - conn.receivedPacketHandler = rph - packet.rcvTime = rcvTime - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), []logging.Frame{}) - Expect(conn.handlePacketImpl(packet)).To(BeTrue()) - }) - - It("informs the ReceivedPacketHandler about ack-eliciting packets", func() { - hdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumber: 0x37, - PacketNumberLen: protocol.PacketNumberLen1, - } - rcvTime := time.Now().Add(-10 * time.Second) - buf := &bytes.Buffer{} - Expect((&wire.PingFrame{}).Write(buf, conn.version)).To(Succeed()) - packet := getPacket(hdr, nil) - packet.ecn = protocol.ECT1 - unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ - packetNumber: 0x1337, - encryptionLevel: protocol.Encryption1RTT, - hdr: hdr, - data: buf.Bytes(), - }, nil) - rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - gomock.InOrder( - rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT), - rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECT1, protocol.Encryption1RTT, rcvTime, true), - ) - conn.receivedPacketHandler = rph - packet.rcvTime = rcvTime - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}}) - Expect(conn.handlePacketImpl(packet)).To(BeTrue()) - }) - - It("drops duplicate packets", func() { - hdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumber: 0x37, - PacketNumberLen: protocol.PacketNumberLen1, - } - packet := getPacket(hdr, nil) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - packetNumber: 0x1337, - encryptionLevel: protocol.Encryption1RTT, - hdr: hdr, - data: []byte("foobar"), - }, nil) - rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true) - conn.receivedPacketHandler = rph - tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate) - Expect(conn.handlePacketImpl(packet)).To(BeFalse()) - }) - - It("drops a packet when unpacking fails", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) - streamManager.EXPECT().CloseWithError(gomock.Any()) - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - expectReplaceWithClosed() - p := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: srcConnID, - Version: conn.version, - Length: 2 + 6, - }, - PacketNumber: 0x1337, - PacketNumberLen: protocol.PacketNumberLen2, - }, []byte("foobar")) - tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, p.Size(), logging.PacketDropPayloadDecryptError) - conn.handlePacket(p) - Consistently(conn.Context().Done()).ShouldNot(BeClosed()) - // make the go routine return - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - conn.closeLocal(errors.New("close")) - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("processes multiple received packets before sending one", func() { - conn.creationTime = time.Now() - var pn protocol.PacketNumber - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { - pn++ - return &unpackedPacket{ - data: []byte{0}, // PADDING frame - encryptionLevel: protocol.Encryption1RTT, - packetNumber: pn, - hdr: &wire.ExtendedHeader{Header: *hdr}, - }, nil - }).Times(3) - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ []logging.Frame) { - }).Times(3) - packer.EXPECT().PackCoalescedPacket() // only expect a single call - - for i := 0; i < 3; i++ { - conn.handlePacket(getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumber: 0x1337, - PacketNumberLen: protocol.PacketNumberLen2, - }, []byte("foobar"))) - } - - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - Consistently(conn.Context().Done()).ShouldNot(BeClosed()) - - // make the go routine return - streamManager.EXPECT().CloseWithError(gomock.Any()) - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - expectReplaceWithClosed() - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - conn.closeLocal(errors.New("close")) - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("doesn't processes multiple received packets before sending one before handshake completion", func() { - conn.handshakeComplete = false - conn.creationTime = time.Now() - var pn protocol.PacketNumber - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { - pn++ - return &unpackedPacket{ - data: []byte{0}, // PADDING frame - encryptionLevel: protocol.Encryption1RTT, - packetNumber: pn, - hdr: &wire.ExtendedHeader{Header: *hdr}, - }, nil - }).Times(3) - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ []logging.Frame) { - }).Times(3) - packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call - - for i := 0; i < 3; i++ { - conn.handlePacket(getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumber: 0x1337, - PacketNumberLen: protocol.PacketNumberLen2, - }, []byte("foobar"))) - } - - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - Consistently(conn.Context().Done()).ShouldNot(BeClosed()) - - // make the go routine return - streamManager.EXPECT().CloseWithError(gomock.Any()) - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - expectReplaceWithClosed() - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - conn.closeLocal(errors.New("close")) - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("closes the connection when unpacking fails because the reserved bits were incorrect", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits) - streamManager.EXPECT().CloseWithError(gomock.Any()) - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := conn.run() - Expect(err).To(HaveOccurred()) - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ProtocolViolation)) - close(done) - }() - expectReplaceWithClosed() - mconn.EXPECT().Write(gomock.Any()) - packet := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumberLen: protocol.PacketNumberLen1, - }, nil) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.handlePacket(packet) - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("ignores packets when unpacking the header fails", func() { - testErr := &headerParseError{errors.New("test error")} - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) - streamManager.EXPECT().CloseWithError(gomock.Any()) - cryptoSetup.EXPECT().Close() - runErr := make(chan error) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - runErr <- conn.run() - }() - expectReplaceWithClosed() - tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, gomock.Any(), logging.PacketDropHeaderParseError) - conn.handlePacket(getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumberLen: protocol.PacketNumberLen1, - }, nil)) - Consistently(runErr).ShouldNot(Receive()) - // make the go routine return - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("closes the connection when unpacking fails because of an error other than a decryption error", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}) - streamManager.EXPECT().CloseWithError(gomock.Any()) - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := conn.run() - Expect(err).To(HaveOccurred()) - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ConnectionIDLimitError)) - close(done) - }() - expectReplaceWithClosed() - mconn.EXPECT().Write(gomock.Any()) - packet := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumberLen: protocol.PacketNumberLen1, - }, nil) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.handlePacket(packet) - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("rejects packets with empty payload", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - hdr: &wire.ExtendedHeader{}, - data: []byte{}, // no payload - encryptionLevel: protocol.Encryption1RTT, - }, nil) - streamManager.EXPECT().CloseWithError(gomock.Any()) - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - Expect(conn.run()).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "empty packet", - })) - close(done) - }() - expectReplaceWithClosed() - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.handlePacket(getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumberLen: protocol.PacketNumberLen1, - }, nil)) - Eventually(done).Should(BeClosed()) - }) - - It("ignores packets with a different source connection ID", func() { - hdr1 := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: destConnID, - SrcConnectionID: srcConnID, - Length: 1, - Version: conn.version, - }, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 1, - } - hdr2 := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: destConnID, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - Length: 1, - Version: conn.version, - }, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 2, - } - Expect(srcConnID).ToNot(Equal(hdr2.SrcConnectionID)) - // Send one packet, which might change the connection ID. - // only EXPECT one call to the unpacker - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - encryptionLevel: protocol.Encryption1RTT, - hdr: hdr1, - data: []byte{0}, // one PADDING frame - }, nil) - p1 := getPacket(hdr1, nil) - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any()) - Expect(conn.handlePacketImpl(p1)).To(BeTrue()) - // The next packet has to be ignored, since the source connection ID doesn't match. - p2 := getPacket(hdr2, nil) - tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.ByteCount(len(p2.data)), logging.PacketDropUnknownConnectionID) - Expect(conn.handlePacketImpl(p2)).To(BeFalse()) - }) - - It("queues undecryptable packets", func() { - conn.handshakeComplete = false - hdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: destConnID, - SrcConnectionID: srcConnID, - Length: 1, - Version: conn.version, - }, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 1, - } - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable) - packet := getPacket(hdr, nil) - tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake) - Expect(conn.handlePacketImpl(packet)).To(BeFalse()) - Expect(conn.undecryptablePackets).To(Equal([]*receivedPacket{packet})) - }) - - Context("updating the remote address", func() { - It("doesn't support connection migration", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - encryptionLevel: protocol.Encryption1RTT, - hdr: &wire.ExtendedHeader{}, - data: []byte{0}, // one PADDING frame - }, nil) - packet := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumberLen: protocol.PacketNumberLen1, - }, nil) - packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) - Expect(conn.handlePacketImpl(packet)).To(BeTrue()) - }) - }) - - Context("coalesced packets", func() { - BeforeEach(func() { - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) - }) - getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) { - hdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: connID, - SrcConnectionID: destConnID, - Version: protocol.VersionTLS, - Length: length, - }, - PacketNumberLen: protocol.PacketNumberLen3, - } - hdrLen := hdr.GetLength(conn.version) - b := make([]byte, 1) - rand.Read(b) - packet := getPacket(hdr, bytes.Repeat(b, int(length)-3)) - return int(hdrLen), packet - } - - It("cuts packets to the right length", func() { - hdrLen, packet := getPacketWithLength(srcConnID, 456) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { - Expect(data).To(HaveLen(hdrLen + 456 - 3)) - return &unpackedPacket{ - encryptionLevel: protocol.EncryptionHandshake, - data: []byte{0}, - hdr: &wire.ExtendedHeader{}, - }, nil - }) - tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) - Expect(conn.handlePacketImpl(packet)).To(BeTrue()) - }) - - It("handles coalesced packets", func() { - hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { - Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) - return &unpackedPacket{ - encryptionLevel: protocol.EncryptionHandshake, - data: []byte{0}, - packetNumber: 1, - hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}}, - }, nil - }) - hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { - Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) - return &unpackedPacket{ - encryptionLevel: protocol.EncryptionHandshake, - data: []byte{0}, - packetNumber: 2, - hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}}, - }, nil - }) - gomock.InOrder( - tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), - tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), - ) - packet1.data = append(packet1.data, packet2.data...) - Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) - }) - - It("works with undecryptable packets", func() { - conn.handshakeComplete = false - hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) - hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) - gomock.InOrder( - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable), - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { - Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) - return &unpackedPacket{ - encryptionLevel: protocol.EncryptionHandshake, - data: []byte{0}, - hdr: &wire.ExtendedHeader{}, - }, nil - }), - ) - gomock.InOrder( - tracer.EXPECT().BufferedPacket(gomock.Any()), - tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), - ) - packet1.data = append(packet1.data, packet2.data...) - Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) - - Expect(conn.undecryptablePackets).To(HaveLen(1)) - Expect(conn.undecryptablePackets[0].data).To(HaveLen(hdrLen1 + 456 - 3)) - }) - - It("ignores coalesced packet parts if the destination connection IDs don't match", func() { - wrongConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - Expect(srcConnID).ToNot(Equal(wrongConnID)) - hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { - Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) - return &unpackedPacket{ - encryptionLevel: protocol.EncryptionHandshake, - data: []byte{0}, - hdr: &wire.ExtendedHeader{}, - }, nil - }) - _, packet2 := getPacketWithLength(wrongConnID, 123) - // don't EXPECT any more calls to unpacker.Unpack() - gomock.InOrder( - tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), - tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID), - ) - packet1.data = append(packet1.data, packet2.data...) - Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) - }) - }) - }) - - Context("sending packets", func() { - var ( - connDone chan struct{} - sender *MockSender - ) - - BeforeEach(func() { - sender = NewMockSender(mockCtrl) - sender.EXPECT().Run() - sender.EXPECT().WouldBlock().AnyTimes() - conn.sendQueue = sender - connDone = make(chan struct{}) - }) - - AfterEach(func() { - streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - sender.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - Eventually(connDone).Should(BeClosed()) - }) - - runConn := func() { - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - close(connDone) - }() - } - - It("sends packets", func() { - conn.handshakeConfirmed = true - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any()) - conn.sentPacketHandler = sph - runConn() - p := getPacket(1) - packer.EXPECT().PackPacket().Return(p, nil) - packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() - sent := make(chan struct{}) - sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) - tracer.EXPECT().SentPacket(p.header, p.buffer.Len(), nil, []logging.Frame{}) - conn.scheduleSending() - Eventually(sent).Should(BeClosed()) - }) - - It("doesn't send packets if there's nothing to send", func() { - conn.handshakeConfirmed = true - runConn() - packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() - conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) - conn.scheduleSending() - time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() - }) - - It("sends ACK only packets", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAck) - done := make(chan struct{}) - packer.EXPECT().MaybePackAckPacket(false).Do(func(bool) { close(done) }) - conn.sentPacketHandler = sph - runConn() - conn.scheduleSending() - Eventually(done).Should(BeClosed()) - }) - - It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { - conn.handshakeConfirmed = true - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any()) - conn.sentPacketHandler = sph - fc := mocks.NewMockConnectionFlowController(mockCtrl) - fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) - fc.EXPECT().IsNewlyBlocked() - p := getPacket(1) - packer.EXPECT().PackPacket().Return(p, nil) - packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() - conn.connFlowController = fc - runConn() - sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) - tracer.EXPECT().SentPacket(p.header, p.length, nil, []logging.Frame{}) - conn.scheduleSending() - Eventually(sent).Should(BeClosed()) - frames, _ := conn.framer.AppendControlFrames(nil, 1000) - Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &logging.DataBlockedFrame{MaximumData: 1337}}})) - }) - - It("doesn't send when the SentPacketHandler doesn't allow it", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendNone).AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - conn.sentPacketHandler = sph - runConn() - conn.scheduleSending() - time.Sleep(50 * time.Millisecond) - }) - - for _, enc := range []protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption1RTT} { - encLevel := enc - - Context(fmt.Sprintf("sending %s probe packets", encLevel), func() { - var sendMode ackhandler.SendMode - var getFrame func(protocol.ByteCount) wire.Frame - - BeforeEach(func() { - //nolint:exhaustive - switch encLevel { - case protocol.EncryptionInitial: - sendMode = ackhandler.SendPTOInitial - getFrame = conn.retransmissionQueue.GetInitialFrame - case protocol.EncryptionHandshake: - sendMode = ackhandler.SendPTOHandshake - getFrame = conn.retransmissionQueue.GetHandshakeFrame - case protocol.Encryption1RTT: - sendMode = ackhandler.SendPTOAppData - getFrame = conn.retransmissionQueue.GetAppDataFrame - } - }) - - It("sends a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().SendMode().Return(sendMode) - sph.EXPECT().SendMode().Return(ackhandler.SendNone) - sph.EXPECT().QueueProbePacket(encLevel) - p := getPacket(123) - packer.EXPECT().MaybePackProbePacket(encLevel).Return(p, nil) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) - }) - conn.sentPacketHandler = sph - runConn() - sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) - tracer.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any()) - conn.scheduleSending() - Eventually(sent).Should(BeClosed()) - }) - - It("sends a PING as a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().SendMode().Return(sendMode) - sph.EXPECT().SendMode().Return(ackhandler.SendNone) - sph.EXPECT().QueueProbePacket(encLevel).Return(false) - p := getPacket(123) - packer.EXPECT().MaybePackProbePacket(encLevel).Return(p, nil) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) - }) - conn.sentPacketHandler = sph - runConn() - sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) - tracer.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any()) - conn.scheduleSending() - Eventually(sent).Should(BeClosed()) - // We're using a mock packet packer in this test. - // We therefore need to test separately that the PING was actually queued. - Expect(getFrame(1000)).To(BeAssignableToTypeOf(&wire.PingFrame{})) - }) - }) - } - }) - - Context("packet pacing", func() { - var ( - sph *mockackhandler.MockSentPacketHandler - sender *MockSender - ) - - BeforeEach(func() { - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - conn.handshakeConfirmed = true - conn.handshakeComplete = true - conn.sentPacketHandler = sph - sender = NewMockSender(mockCtrl) - sender.EXPECT().Run() - conn.sendQueue = sender - streamManager.EXPECT().CloseWithError(gomock.Any()) - }) - - AfterEach(func() { - // make the go routine return - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - sender.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("sends multiple packets one by one immediately", func() { - sph.EXPECT().SentPacket(gomock.Any()).Times(2) - sph.EXPECT().HasPacingBudget().Return(true).Times(2) - sph.EXPECT().HasPacingBudget() - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(3) - packer.EXPECT().PackPacket().Return(getPacket(10), nil) - packer.EXPECT().PackPacket().Return(getPacket(11), nil) - sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()).Times(2) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - conn.scheduleSending() - time.Sleep(50 * time.Millisecond) // make sure that only 2 packets are sent - }) - - It("sends multiple packets, when the pacer allows immediate sending", func() { - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) - packer.EXPECT().PackPacket().Return(getPacket(10), nil) - packer.EXPECT().PackPacket().Return(nil, nil) - sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - conn.scheduleSending() - time.Sleep(50 * time.Millisecond) // make sure that only 1 packet is sent - }) - - It("allows an ACK to be sent when pacing limited", func() { - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().HasPacingBudget() - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - sph.EXPECT().SendMode().Return(ackhandler.SendAny) - packer.EXPECT().MaybePackAckPacket(gomock.Any()).Return(getPacket(10), nil) - sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - conn.scheduleSending() - time.Sleep(50 * time.Millisecond) // make sure that only 1 packet is sent - }) - - // when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck - // we shouldn't send the ACK in the same run - It("doesn't send an ACK right after becoming congestion limited", func() { - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().HasPacingBudget().Return(true) - sph.EXPECT().SendMode().Return(ackhandler.SendAny) - sph.EXPECT().SendMode().Return(ackhandler.SendAck) - packer.EXPECT().PackPacket().Return(getPacket(100), nil) - sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - conn.scheduleSending() - time.Sleep(50 * time.Millisecond) // make sure that only 1 packet is sent - }) - - It("paces packets", func() { - pacingDelay := scaleDuration(100 * time.Millisecond) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - gomock.InOrder( - sph.EXPECT().HasPacingBudget().Return(true), - packer.EXPECT().PackPacket().Return(getPacket(100), nil), - sph.EXPECT().SentPacket(gomock.Any()), - sph.EXPECT().HasPacingBudget(), - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), - sph.EXPECT().HasPacingBudget().Return(true), - packer.EXPECT().PackPacket().Return(getPacket(101), nil), - sph.EXPECT().SentPacket(gomock.Any()), - sph.EXPECT().HasPacingBudget(), - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)), - ) - written := make(chan struct{}, 2) - sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }).Times(2) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - conn.scheduleSending() - Eventually(written).Should(HaveLen(1)) - Consistently(written, pacingDelay/2).Should(HaveLen(1)) - Eventually(written, 2*pacingDelay).Should(HaveLen(2)) - }) - - It("sends multiple packets at once", func() { - sph.EXPECT().SentPacket(gomock.Any()).Times(3) - sph.EXPECT().HasPacingBudget().Return(true).Times(3) - sph.EXPECT().HasPacingBudget() - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(4) - packer.EXPECT().PackPacket().Return(getPacket(1000), nil) - packer.EXPECT().PackPacket().Return(getPacket(1001), nil) - packer.EXPECT().PackPacket().Return(getPacket(1002), nil) - written := make(chan struct{}, 3) - sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }).Times(3) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - conn.scheduleSending() - Eventually(written).Should(HaveLen(3)) - }) - - It("doesn't try to send if the send queue is full", func() { - available := make(chan struct{}, 1) - sender.EXPECT().WouldBlock().Return(true) - sender.EXPECT().Available().Return(available) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - conn.scheduleSending() - time.Sleep(scaleDuration(50 * time.Millisecond)) - - written := make(chan struct{}) - sender.EXPECT().WouldBlock().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - packer.EXPECT().PackPacket().Return(getPacket(1000), nil) - packer.EXPECT().PackPacket().Return(nil, nil) - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) - available <- struct{}{} - Eventually(written).Should(BeClosed()) - }) - - It("stops sending when there are new packets to receive", func() { - sender.EXPECT().WouldBlock().AnyTimes() - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - - written := make(chan struct{}) - sender.EXPECT().WouldBlock().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any()).Do(func(*ackhandler.Packet) { - sph.EXPECT().ReceivedBytes(gomock.Any()) - conn.handlePacket(&receivedPacket{buffer: getPacketBuffer()}) - }) - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - packer.EXPECT().PackPacket().Return(getPacket(1000), nil) - packer.EXPECT().PackPacket().Return(nil, nil) - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) - - conn.scheduleSending() - time.Sleep(scaleDuration(50 * time.Millisecond)) - - Eventually(written).Should(BeClosed()) - }) - - It("stops sending when the send queue is full", func() { - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny) - packer.EXPECT().PackPacket().Return(getPacket(1000), nil) - written := make(chan struct{}, 1) - sender.EXPECT().WouldBlock() - sender.EXPECT().WouldBlock().Return(true).Times(2) - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - available := make(chan struct{}, 1) - sender.EXPECT().Available().Return(available) - conn.scheduleSending() - Eventually(written).Should(Receive()) - time.Sleep(scaleDuration(50 * time.Millisecond)) - - // now make room in the send queue - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sender.EXPECT().WouldBlock().AnyTimes() - packer.EXPECT().PackPacket().Return(getPacket(1001), nil) - packer.EXPECT().PackPacket().Return(nil, nil) - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) - available <- struct{}{} - Eventually(written).Should(Receive()) - - // The send queue is not full any more. Sending on the available channel should have no effect. - available <- struct{}{} - time.Sleep(scaleDuration(50 * time.Millisecond)) - }) - - It("doesn't set a pacing timer when there is no data to send", func() { - sph.EXPECT().HasPacingBudget().Return(true) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sender.EXPECT().WouldBlock().AnyTimes() - packer.EXPECT().PackPacket() - // don't EXPECT any calls to mconn.Write() - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - conn.scheduleSending() // no packet will get sent - time.Sleep(50 * time.Millisecond) - }) - - It("sends a Path MTU probe packet", func() { - mtuDiscoverer := NewMockMtuDiscoverer(mockCtrl) - conn.mtuDiscoverer = mtuDiscoverer - conn.config.DisablePathMTUDiscovery = false - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny) - sph.EXPECT().SendMode().Return(ackhandler.SendNone) - written := make(chan struct{}, 1) - sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) - mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true) - ping := ackhandler.Frame{Frame: &wire.PingFrame{}} - mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) - packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234)).Return(getPacket(1), nil) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - conn.scheduleSending() - Eventually(written).Should(Receive()) - }) - }) - - Context("scheduling sending", func() { - var sender *MockSender - - BeforeEach(func() { - sender = NewMockSender(mockCtrl) - sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Run() - conn.sendQueue = sender - conn.handshakeConfirmed = true - }) - - AfterEach(func() { - // make the go routine return - expectReplaceWithClosed() - streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - sender.EXPECT().Close() - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("sends when scheduleSending is called", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any()) - conn.sentPacketHandler = sph - packer.EXPECT().PackPacket().Return(getPacket(1), nil) - packer.EXPECT().PackPacket().Return(nil, nil) - - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - // don't EXPECT any calls to mconn.Write() - time.Sleep(50 * time.Millisecond) - // only EXPECT calls after scheduleSending is called - written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any()).Do(func(*packetBuffer) { close(written) }) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - conn.scheduleSending() - Eventually(written).Should(BeClosed()) - }) - - It("sets the timer to the ack timer", func() { - packer.EXPECT().PackPacket().Return(getPacket(1234), nil) - packer.EXPECT().PackPacket().Return(nil, nil) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(1234))) - }) - conn.sentPacketHandler = sph - rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) - // make the run loop wait - rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1) - conn.receivedPacketHandler = rph - - written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any()).Do(func(*packetBuffer) { close(written) }) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - Eventually(written).Should(BeClosed()) - }) - }) - - It("sends coalesced packets before the handshake is confirmed", func() { - conn.handshakeComplete = false - conn.handshakeConfirmed = false - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - conn.sentPacketHandler = sph - buffer := getPacketBuffer() - buffer.Data = append(buffer.Data, []byte("foobar")...) - packer.EXPECT().PackCoalescedPacket().Return(&coalescedPacket{ - buffer: buffer, - packets: []*packetContents{ - { - header: &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - }, - PacketNumber: 13, - }, - length: 123, - }, - { - header: &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - }, - PacketNumber: 37, - }, - length: 1234, - }, - }, - }, nil) - packer.EXPECT().PackCoalescedPacket().AnyTimes() - - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().TimeUntilSend().Return(time.Now()).AnyTimes() - gomock.InOrder( - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionInitial)) - Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(13))) - Expect(p.Length).To(BeEquivalentTo(123)) - }), - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionHandshake)) - Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(37))) - Expect(p.Length).To(BeEquivalentTo(1234)) - }), - ) - gomock.InOrder( - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) { - Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) - }), - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) { - Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) - }), - ) - - sent := make(chan struct{}) - mconn.EXPECT().Write([]byte("foobar")).Do(func([]byte) { close(sent) }) - - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - - conn.scheduleSending() - Eventually(sent).Should(BeClosed()) - - // make sure the go routine returns - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("cancels the HandshakeComplete context when the handshake completes", func() { - packer.EXPECT().PackCoalescedPacket().AnyTimes() - finishHandshake := make(chan struct{}) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - conn.sentPacketHandler = sph - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().SendMode().AnyTimes() - sph.EXPECT().SetHandshakeConfirmed() - connRunner.EXPECT().Retire(clientDestConnID) - go func() { - defer GinkgoRecover() - <-finishHandshake - cryptoSetup.EXPECT().RunHandshake() - cryptoSetup.EXPECT().SetHandshakeConfirmed() - cryptoSetup.EXPECT().GetSessionTicket() - close(conn.handshakeCompleteChan) - conn.run() - }() - handshakeCtx := conn.HandshakeComplete() - Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) - close(finishHandshake) - Eventually(handshakeCtx.Done()).Should(BeClosed()) - // make sure the go routine returns - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("sends a connection ticket when the handshake completes", func() { - const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2 - packer.EXPECT().PackCoalescedPacket().AnyTimes() - finishHandshake := make(chan struct{}) - connRunner.EXPECT().Retire(clientDestConnID) - go func() { - defer GinkgoRecover() - <-finishHandshake - cryptoSetup.EXPECT().RunHandshake() - cryptoSetup.EXPECT().SetHandshakeConfirmed() - cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) - close(conn.handshakeCompleteChan) - conn.run() - }() - - handshakeCtx := conn.HandshakeComplete() - Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) - close(finishHandshake) - var frames []ackhandler.Frame - Eventually(func() []ackhandler.Frame { - frames, _ = conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) - return frames - }).ShouldNot(BeEmpty()) - var count int - var s int - for _, f := range frames { - if cf, ok := f.Frame.(*wire.CryptoFrame); ok { - count++ - s += len(cf.Data) - Expect(f.Length(conn.version)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize)) - } - } - Expect(size).To(BeEquivalentTo(s)) - // make sure the go routine returns - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { - packer.EXPECT().PackCoalescedPacket().AnyTimes() - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake() - conn.run() - }() - handshakeCtx := conn.HandshakeComplete() - Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any()) - conn.closeLocal(errors.New("handshake error")) - Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SetHandshakeConfirmed() - sph.EXPECT().SentPacket(gomock.Any()) - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - conn.sentPacketHandler = sph - done := make(chan struct{}) - connRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { - frames, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) - Expect(frames).ToNot(BeEmpty()) - Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) - defer close(done) - return &packedPacket{ - packetContents: &packetContents{ - header: &wire.ExtendedHeader{}, - }, - buffer: getPacketBuffer(), - }, nil - }) - packer.EXPECT().PackPacket().AnyTimes() - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake() - cryptoSetup.EXPECT().SetHandshakeConfirmed() - cryptoSetup.EXPECT().GetSessionTicket() - mconn.EXPECT().Write(gomock.Any()) - close(conn.handshakeCompleteChan) - conn.run() - }() - Eventually(done).Should(BeClosed()) - // make sure the go routine returns - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("doesn't return a run error when closing", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - Expect(conn.run()).To(Succeed()) - close(done) - }() - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(done).Should(BeClosed()) - }) - - It("passes errors to the connection runner", func() { - testErr := errors.New("handshake error") - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := conn.run() - Expect(err).To(MatchError(&qerr.ApplicationError{ - ErrorCode: 0x1337, - ErrorMessage: testErr.Error(), - })) - close(done) - }() - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - Expect(conn.CloseWithError(0x1337, testErr.Error())).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - Context("transport parameters", func() { - It("processes transport parameters received from the client", func() { - params := &wire.TransportParameters{ - MaxIdleTimeout: 90 * time.Second, - InitialMaxStreamDataBidiLocal: 0x5000, - InitialMaxData: 0x5000, - ActiveConnectionIDLimit: 3, - // marshaling always sets it to this value - MaxUDPPayloadSize: protocol.MaxPacketBufferSize, - InitialSourceConnectionID: destConnID, - } - streamManager.EXPECT().UpdateLimits(params) - packer.EXPECT().HandleTransportParameters(params) - packer.EXPECT().PackCoalescedPacket().MaxTimes(3) - Expect(conn.earlyConnReady()).ToNot(BeClosed()) - connRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2) - connRunner.EXPECT().Add(gomock.Any(), conn).Times(2) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) - Expect(conn.earlyConnReady()).To(BeClosed()) - }) - }) - - Context("keep-alives", func() { - setRemoteIdleTimeout := func(t time.Duration) { - streamManager.EXPECT().UpdateLimits(gomock.Any()) - packer.EXPECT().HandleTransportParameters(gomock.Any()) - tracer.EXPECT().ReceivedTransportParameters(gomock.Any()) - conn.handleTransportParameters(&wire.TransportParameters{ - MaxIdleTimeout: t, - InitialSourceConnectionID: destConnID, - }) - } - - runConn := func() { - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - } - - BeforeEach(func() { - conn.config.MaxIdleTimeout = 30 * time.Second - conn.config.KeepAlivePeriod = 15 * time.Second - conn.receivedPacketHandler.ReceivedPacket(0, protocol.ECNNon, protocol.EncryptionHandshake, time.Now(), true) - }) - - AfterEach(func() { - // make the go routine return - expectReplaceWithClosed() - streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("sends a PING as a keep-alive after half the idle timeout", func() { - setRemoteIdleTimeout(5 * time.Second) - conn.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2) - sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { - close(sent) - return nil, nil - }) - runConn() - Eventually(sent).Should(BeClosed()) - }) - - It("sends a PING after a maximum of protocol.MaxKeepAliveInterval", func() { - conn.config.MaxIdleTimeout = time.Hour - setRemoteIdleTimeout(time.Hour) - conn.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond) - sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { - close(sent) - return nil, nil - }) - runConn() - Eventually(sent).Should(BeClosed()) - }) - - It("doesn't send a PING packet if keep-alive is disabled", func() { - setRemoteIdleTimeout(5 * time.Second) - conn.config.KeepAlivePeriod = 0 - conn.lastPacketReceivedTime = time.Now().Add(-time.Second * 5 / 2) - runConn() - // don't EXPECT() any calls to mconn.Write() - time.Sleep(50 * time.Millisecond) - }) - - It("doesn't send a PING if the handshake isn't completed yet", func() { - conn.config.HandshakeIdleTimeout = time.Hour - conn.handshakeComplete = false - // Needs to be shorter than our idle timeout. - // Otherwise we'll try to send a CONNECTION_CLOSE. - conn.lastPacketReceivedTime = time.Now().Add(-20 * time.Second) - runConn() - // don't EXPECT() any calls to mconn.Write() - time.Sleep(50 * time.Millisecond) - }) - }) - - Context("timeouts", func() { - BeforeEach(func() { - streamManager.EXPECT().CloseWithError(gomock.Any()) - }) - - It("times out due to no network activity", func() { - connRunner.EXPECT().Remove(gomock.Any()).Times(2) - conn.lastPacketReceivedTime = time.Now().Add(-time.Hour) - done := make(chan struct{}) - cryptoSetup.EXPECT().Close() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - Expect(e).To(MatchError(&qerr.IdleTimeoutError{})) - }), - tracer.EXPECT().Close(), - ) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := conn.run() - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeTrue()) - Expect(err).To(MatchError(qerr.ErrIdleTimeout)) - close(done) - }() - Eventually(done).Should(BeClosed()) - }) - - It("times out due to non-completed handshake", func() { - conn.handshakeComplete = false - conn.creationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) - connRunner.EXPECT().Remove(gomock.Any()).Times(2) - cryptoSetup.EXPECT().Close() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - Expect(e).To(MatchError(&HandshakeTimeoutError{})) - }), - tracer.EXPECT().Close(), - ) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := conn.run() - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeTrue()) - Expect(err).To(MatchError(qerr.ErrHandshakeTimeout)) - close(done) - }() - Eventually(done).Should(BeClosed()) - }) - - It("does not use the idle timeout before the handshake complete", func() { - conn.handshakeComplete = false - conn.config.HandshakeIdleTimeout = 9999 * time.Second - conn.config.MaxIdleTimeout = 9999 * time.Second - conn.lastPacketReceivedTime = time.Now().Add(-time.Minute) - packer.EXPECT().PackApplicationClose(gomock.Any()).DoAndReturn(func(e *qerr.ApplicationError) (*coalescedPacket, error) { - Expect(e.ErrorCode).To(BeZero()) - return &coalescedPacket{buffer: getPacketBuffer()}, nil - }) - gomock.InOrder( - tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - idleTimeout := &IdleTimeoutError{} - handshakeTimeout := &HandshakeTimeoutError{} - Expect(errors.As(e, &idleTimeout)).To(BeFalse()) - Expect(errors.As(e, &handshakeTimeout)).To(BeFalse()) - }), - tracer.EXPECT().Close(), - ) - // the handshake timeout is irrelevant here, since it depends on the time the connection was created, - // and not on the last network activity - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - Consistently(conn.Context().Done()).ShouldNot(BeClosed()) - // make the go routine return - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("closes the connection due to the idle timeout before handshake", func() { - conn.config.HandshakeIdleTimeout = 0 - packer.EXPECT().PackCoalescedPacket().AnyTimes() - connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - cryptoSetup.EXPECT().Close() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - Expect(e).To(MatchError(&IdleTimeoutError{})) - }), - tracer.EXPECT().Close(), - ) - done := make(chan struct{}) - conn.handshakeComplete = false - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) - err := conn.run() - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeTrue()) - Expect(err).To(MatchError(qerr.ErrIdleTimeout)) - close(done) - }() - Eventually(done).Should(BeClosed()) - }) - - It("closes the connection due to the idle timeout after handshake", func() { - packer.EXPECT().PackCoalescedPacket().AnyTimes() - gomock.InOrder( - connRunner.EXPECT().Retire(clientDestConnID), - connRunner.EXPECT().Remove(gomock.Any()), - ) - cryptoSetup.EXPECT().Close() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - Expect(e).To(MatchError(&IdleTimeoutError{})) - }), - tracer.EXPECT().Close(), - ) - conn.idleTimeout = 0 - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) - cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) - close(conn.handshakeCompleteChan) - err := conn.run() - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeTrue()) - Expect(err).To(MatchError(qerr.ErrIdleTimeout)) - close(done) - }() - Eventually(done).Should(BeClosed()) - }) - - It("doesn't time out when it just sent a packet", func() { - conn.lastPacketReceivedTime = time.Now().Add(-time.Hour) - conn.firstAckElicitingPacketAfterIdleSentTime = time.Now().Add(-time.Second) - conn.idleTimeout = 30 * time.Second - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - Consistently(conn.Context().Done()).ShouldNot(BeClosed()) - // make the go routine return - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - }) - - It("stores up to MaxConnUnprocessedPackets packets", func() { - done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, logging.ByteCount(6), logging.PacketDropDOSPrevention).Do(func(logging.PacketType, logging.ByteCount, logging.PacketDropReason) { - close(done) - }) - // Nothing here should block - for i := protocol.PacketNumber(0); i < protocol.MaxConnUnprocessedPackets+1; i++ { - conn.handlePacket(&receivedPacket{data: []byte("foobar")}) - } - Eventually(done).Should(BeClosed()) - }) - - Context("getting streams", func() { - It("opens streams", func() { - mstr := NewMockStreamI(mockCtrl) - streamManager.EXPECT().OpenStream().Return(mstr, nil) - str, err := conn.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(mstr)) - }) - - It("opens streams synchronously", func() { - mstr := NewMockStreamI(mockCtrl) - streamManager.EXPECT().OpenStreamSync(context.Background()).Return(mstr, nil) - str, err := conn.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(mstr)) - }) - - It("opens unidirectional streams", func() { - mstr := NewMockSendStreamI(mockCtrl) - streamManager.EXPECT().OpenUniStream().Return(mstr, nil) - str, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(mstr)) - }) - - It("opens unidirectional streams synchronously", func() { - mstr := NewMockSendStreamI(mockCtrl) - streamManager.EXPECT().OpenUniStreamSync(context.Background()).Return(mstr, nil) - str, err := conn.OpenUniStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(mstr)) - }) - - It("accepts streams", func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - mstr := NewMockStreamI(mockCtrl) - streamManager.EXPECT().AcceptStream(ctx).Return(mstr, nil) - str, err := conn.AcceptStream(ctx) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(mstr)) - }) - - It("accepts unidirectional streams", func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - mstr := NewMockReceiveStreamI(mockCtrl) - streamManager.EXPECT().AcceptUniStream(ctx).Return(mstr, nil) - str, err := conn.AcceptUniStream(ctx) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(mstr)) - }) - }) - - It("returns the local address", func() { - Expect(conn.LocalAddr()).To(Equal(localAddr)) - }) - - It("returns the remote address", func() { - Expect(conn.RemoteAddr()).To(Equal(remoteAddr)) - }) -}) - -var _ = Describe("Client Connection", func() { - var ( - conn *connection - connRunner *MockConnRunner - packer *MockPacker - mconn *MockSendConn - cryptoSetup *mocks.MockCryptoSetup - tracer *mocklogging.MockConnectionTracer - tlsConf *tls.Config - quicConf *Config - ) - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - - getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { - buf := &bytes.Buffer{} - Expect(hdr.Write(buf, conn.version)).To(Succeed()) - return &receivedPacket{ - data: append(buf.Bytes(), data...), - buffer: getPacketBuffer(), - } - } - - expectReplaceWithClosed := func() { - connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - s.shutdown() - Eventually(areClosedConnsRunning).Should(BeFalse()) - }) - } - - BeforeEach(func() { - quicConf = populateClientConfig(&Config{}, true) - tlsConf = nil - }) - - JustBeforeEach(func() { - Eventually(areConnsRunning).Should(BeFalse()) - - mconn = NewMockSendConn(mockCtrl) - mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes() - mconn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() - if tlsConf == nil { - tlsConf = &tls.Config{} - } - connRunner = NewMockConnRunner(mockCtrl) - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) - tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) - tracer.EXPECT().SentTransportParameters(gomock.Any()) - tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() - tracer.EXPECT().UpdatedCongestionState(gomock.Any()) - conn = newClientConnection( - mconn, - connRunner, - destConnID, - protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - quicConf, - tlsConf, - 42, // initial packet number - false, - false, - tracer, - 1234, - utils.DefaultLogger, - protocol.VersionTLS, - ).(*connection) - packer = NewMockPacker(mockCtrl) - conn.packer = packer - cryptoSetup = mocks.NewMockCryptoSetup(mockCtrl) - conn.cryptoStreamHandler = cryptoSetup - conn.sentFirstPacket = true - }) - - It("changes the connection ID when receiving the first packet from the server", func() { - unpacker := NewMockUnpacker(mockCtrl) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { - return &unpackedPacket{ - encryptionLevel: protocol.Encryption1RTT, - hdr: &wire.ExtendedHeader{Header: *hdr}, - data: []byte{0}, // one PADDING frame - }, nil - }) - conn.unpacker = unpacker - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} - p := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - SrcConnectionID: newConnID, - DestConnectionID: srcConnID, - Length: 2 + 6, - Version: conn.version, - }, - PacketNumberLen: protocol.PacketNumberLen2, - }, []byte("foobar")) - tracer.EXPECT().ReceivedPacket(gomock.Any(), p.Size(), []logging.Frame{}) - Expect(conn.handlePacketImpl(p)).To(BeTrue()) - // make sure the go routine returns - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - }) - - It("continues accepting Long Header packets after using a new connection ID", func() { - unpacker := NewMockUnpacker(mockCtrl) - conn.unpacker = unpacker - connRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any()) - conn.connIDManager.SetHandshakeComplete() - conn.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{ - SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, - }) - Expect(conn.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) - // now receive a packet with the original source connection ID - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) { - return &unpackedPacket{ - hdr: &wire.ExtendedHeader{Header: *hdr}, - data: []byte{0}, - encryptionLevel: protocol.EncryptionHandshake, - }, nil - }) - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: srcConnID, - SrcConnectionID: destConnID, - } - tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(conn.handleSinglePacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) - }) - - It("handles HANDSHAKE_DONE frames", func() { - conn.peerParams = &wire.TransportParameters{} - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - conn.sentPacketHandler = sph - sph.EXPECT().SetHandshakeConfirmed() - cryptoSetup.EXPECT().SetHandshakeConfirmed() - Expect(conn.handleHandshakeDoneFrame()).To(Succeed()) - }) - - It("interprets an ACK for 1-RTT packets as confirmation of the handshake", func() { - conn.peerParams = &wire.TransportParameters{} - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - conn.sentPacketHandler = sph - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 3}}} - sph.EXPECT().ReceivedAck(ack, protocol.Encryption1RTT, gomock.Any()).Return(true, nil) - sph.EXPECT().SetHandshakeConfirmed() - cryptoSetup.EXPECT().SetLargest1RTTAcked(protocol.PacketNumber(3)) - cryptoSetup.EXPECT().SetHandshakeConfirmed() - Expect(conn.handleAckFrame(ack, protocol.Encryption1RTT)).To(Succeed()) - }) - - It("doesn't send a CONNECTION_CLOSE when no packet was sent", func() { - conn.sentFirstPacket = false - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - running := make(chan struct{}) - cryptoSetup.EXPECT().RunHandshake().Do(func() { - close(running) - conn.closeLocal(errors.New("early error")) - }) - cryptoSetup.EXPECT().Close() - connRunner.EXPECT().Remove(gomock.Any()) - go func() { - defer GinkgoRecover() - conn.run() - }() - Eventually(running).Should(BeClosed()) - Eventually(areConnsRunning).Should(BeFalse()) - }) - - Context("handling tokens", func() { - var mockTokenStore *MockTokenStore - - BeforeEach(func() { - mockTokenStore = NewMockTokenStore(mockCtrl) - tlsConf = &tls.Config{ServerName: "server"} - quicConf.TokenStore = mockTokenStore - mockTokenStore.EXPECT().Pop(gomock.Any()) - quicConf.TokenStore = mockTokenStore - }) - - It("handles NEW_TOKEN frames", func() { - mockTokenStore.EXPECT().Put("server", &ClientToken{data: []byte("foobar")}) - Expect(conn.handleNewTokenFrame(&wire.NewTokenFrame{Token: []byte("foobar")})).To(Succeed()) - }) - }) - - Context("handling Version Negotiation", func() { - getVNP := func(versions ...protocol.VersionNumber) *receivedPacket { - b := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions) - return &receivedPacket{ - data: b, - buffer: getPacketBuffer(), - } - } - - It("closes and returns the right error", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - conn.sentPacketHandler = sph - sph.EXPECT().ReceivedBytes(gomock.Any()) - sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4) - conn.config.Versions = []protocol.VersionNumber{1234, 4321} - errChan := make(chan error, 1) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - errChan <- conn.run() - }() - connRunner.EXPECT().Remove(srcConnID) - tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any()).Do(func(hdr *wire.Header, versions []logging.VersionNumber) { - Expect(hdr.Version).To(BeZero()) - Expect(versions).To(And( - ContainElement(protocol.VersionNumber(4321)), - ContainElement(protocol.VersionNumber(1337)), - )) - }) - cryptoSetup.EXPECT().Close() - Expect(conn.handlePacketImpl(getVNP(4321, 1337))).To(BeFalse()) - var err error - Eventually(errChan).Should(Receive(&err)) - Expect(err).To(HaveOccurred()) - Expect(err).To(BeAssignableToTypeOf(&errCloseForRecreating{})) - recreateErr := err.(*errCloseForRecreating) - Expect(recreateErr.nextVersion).To(Equal(protocol.VersionNumber(4321))) - Expect(recreateErr.nextPacketNumber).To(Equal(protocol.PacketNumber(128))) - }) - - It("it closes when no matching version is found", func() { - errChan := make(chan error, 1) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - errChan <- conn.run() - }() - connRunner.EXPECT().Remove(srcConnID).MaxTimes(1) - gomock.InOrder( - tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any()), - tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - var vnErr *VersionNegotiationError - Expect(errors.As(e, &vnErr)).To(BeTrue()) - Expect(vnErr.Theirs).To(ContainElement(logging.VersionNumber(12345678))) - }), - tracer.EXPECT().Close(), - ) - cryptoSetup.EXPECT().Close() - Expect(conn.handlePacketImpl(getVNP(12345678))).To(BeFalse()) - var err error - Eventually(errChan).Should(Receive(&err)) - Expect(err).To(HaveOccurred()) - Expect(err).ToNot(BeAssignableToTypeOf(errCloseForRecreating{})) - Expect(err.Error()).To(ContainSubstring("no compatible QUIC version found")) - }) - - It("ignores Version Negotiation packets that offer the current version", func() { - p := getVNP(conn.version) - tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) - Expect(conn.handlePacketImpl(p)).To(BeFalse()) - }) - - It("ignores unparseable Version Negotiation packets", func() { - p := getVNP(conn.version) - p.data = p.data[:len(p.data)-2] - tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) - Expect(conn.handlePacketImpl(p)).To(BeFalse()) - }) - }) - - Context("handling Retry", func() { - origDestConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - - var retryHdr *wire.ExtendedHeader - - JustBeforeEach(func() { - retryHdr = &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - Token: []byte("foobar"), - Version: conn.version, - }, - } - }) - - getRetryTag := func(hdr *wire.ExtendedHeader) []byte { - buf := &bytes.Buffer{} - hdr.Write(buf, conn.version) - return handshake.GetRetryIntegrityTag(buf.Bytes(), origDestConnID, hdr.Version)[:] - } - - It("handles Retry packets", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - conn.sentPacketHandler = sph - sph.EXPECT().ResetForRetry() - sph.EXPECT().ReceivedBytes(gomock.Any()) - cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) - packer.EXPECT().SetToken([]byte("foobar")) - tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { - Expect(hdr.DestConnectionID).To(Equal(retryHdr.DestConnectionID)) - Expect(hdr.SrcConnectionID).To(Equal(retryHdr.SrcConnectionID)) - Expect(hdr.Token).To(Equal(retryHdr.Token)) - }) - Expect(conn.handlePacketImpl(getPacket(retryHdr, getRetryTag(retryHdr)))).To(BeTrue()) - }) - - It("ignores Retry packets after receiving a regular packet", func() { - conn.receivedFirstPacket = true - p := getPacket(retryHdr, getRetryTag(retryHdr)) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) - Expect(conn.handlePacketImpl(p)).To(BeFalse()) - }) - - It("ignores Retry packets if the server didn't change the connection ID", func() { - retryHdr.SrcConnectionID = destConnID - p := getPacket(retryHdr, getRetryTag(retryHdr)) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) - Expect(conn.handlePacketImpl(p)).To(BeFalse()) - }) - - It("ignores Retry packets with the a wrong Integrity tag", func() { - tag := getRetryTag(retryHdr) - tag[0]++ - p := getPacket(retryHdr, tag) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropPayloadDecryptError) - Expect(conn.handlePacketImpl(p)).To(BeFalse()) - }) - }) - - Context("transport parameters", func() { - var ( - closed bool - errChan chan error - ) - - JustBeforeEach(func() { - errChan = make(chan error, 1) - closed = false - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - errChan <- conn.run() - close(errChan) - }() - }) - - expectClose := func(applicationClose bool) { - if !closed { - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{})) - s.shutdown() - }) - if applicationClose { - packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) - } else { - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) - } - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - gomock.InOrder( - tracer.EXPECT().ClosedConnection(gomock.Any()), - tracer.EXPECT().Close(), - ) - } - closed = true - } - - AfterEach(func() { - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) - Eventually(errChan).Should(BeClosed()) - }) - - It("uses the preferred_address connection ID", func() { - params := &wire.TransportParameters{ - OriginalDestinationConnectionID: destConnID, - InitialSourceConnectionID: destConnID, - PreferredAddress: &wire.PreferredAddress{ - IPv4: net.IPv4(127, 0, 0, 1), - IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, - }, - } - packer.EXPECT().HandleTransportParameters(gomock.Any()) - packer.EXPECT().PackCoalescedPacket().MaxTimes(1) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) - conn.handleHandshakeComplete() - // make sure the connection ID is not retired - cf, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) - Expect(cf).To(BeEmpty()) - connRunner.EXPECT().AddResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, conn) - Expect(conn.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) - // shut down - connRunner.EXPECT().RemoveResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}) - expectClose(true) - }) - - It("uses the minimum of the peers' idle timeouts", func() { - conn.config.MaxIdleTimeout = 19 * time.Second - params := &wire.TransportParameters{ - OriginalDestinationConnectionID: destConnID, - InitialSourceConnectionID: destConnID, - MaxIdleTimeout: 18 * time.Second, - } - packer.EXPECT().HandleTransportParameters(gomock.Any()) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) - conn.handleHandshakeComplete() - Expect(conn.idleTimeout).To(Equal(18 * time.Second)) - expectClose(true) - }) - - It("errors if the transport parameters contain a wrong initial_source_connection_id", func() { - conn.handshakeDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - params := &wire.TransportParameters{ - OriginalDestinationConnectionID: destConnID, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - } - expectClose(false) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) - Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "expected initial_source_connection_id to equal deadbeef, is decafbad", - }))) - }) - - It("errors if the transport parameters don't contain the retry_source_connection_id, if a Retry was performed", func() { - conn.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - params := &wire.TransportParameters{ - OriginalDestinationConnectionID: destConnID, - InitialSourceConnectionID: destConnID, - StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - } - expectClose(false) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) - Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "missing retry_source_connection_id", - }))) - }) - - It("errors if the transport parameters contain the wrong retry_source_connection_id, if a Retry was performed", func() { - conn.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - params := &wire.TransportParameters{ - OriginalDestinationConnectionID: destConnID, - InitialSourceConnectionID: destConnID, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - } - expectClose(false) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) - Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "expected retry_source_connection_id to equal deadbeef, is deadc0de", - }))) - }) - - It("errors if the transport parameters contain the retry_source_connection_id, if no Retry was performed", func() { - params := &wire.TransportParameters{ - OriginalDestinationConnectionID: destConnID, - InitialSourceConnectionID: destConnID, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - } - expectClose(false) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) - Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "received retry_source_connection_id, although no Retry was performed", - }))) - }) - - It("errors if the transport parameters contain a wrong original_destination_connection_id", func() { - conn.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - params := &wire.TransportParameters{ - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - InitialSourceConnectionID: conn.handshakeDestConnID, - StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - } - expectClose(false) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) - Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "expected original_destination_connection_id to equal deadbeef, is decafbad", - }))) - }) - }) - - Context("handling potentially injected packets", func() { - var unpacker *MockUnpacker - - getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { - buf := &bytes.Buffer{} - Expect(extHdr.Write(buf, conn.version)).To(Succeed()) - return &receivedPacket{ - data: append(buf.Bytes(), data...), - buffer: getPacketBuffer(), - } - } - - // Convert an already packed raw packet into a receivedPacket - wrapPacket := func(packet []byte) *receivedPacket { - return &receivedPacket{ - data: packet, - buffer: getPacketBuffer(), - } - } - - // Illustrates that attacker may inject an Initial packet with a different - // source connection ID, causing endpoint to ignore a subsequent real Initial packets. - It("ignores Initial packets with a different source connection ID", func() { - // Modified from test "ignores packets with a different source connection ID" - unpacker = NewMockUnpacker(mockCtrl) - conn.unpacker = unpacker - - hdr1 := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: destConnID, - SrcConnectionID: srcConnID, - Length: 1, - Version: conn.version, - }, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 1, - } - hdr2 := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: destConnID, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - Length: 1, - Version: conn.version, - }, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 2, - } - Expect(hdr2.SrcConnectionID).ToNot(Equal(srcConnID)) - // Send one packet, which might change the connection ID. - // only EXPECT one call to the unpacker - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - encryptionLevel: protocol.EncryptionInitial, - hdr: hdr1, - data: []byte{0}, // one PADDING frame - }, nil) - tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(conn.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue()) - // The next packet has to be ignored, since the source connection ID doesn't match. - tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(conn.handlePacketImpl(getPacket(hdr2, nil))).To(BeFalse()) - }) - - It("ignores 0-RTT packets", func() { - p := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketType0RTT, - DestConnectionID: srcConnID, - Length: 2 + 6, - Version: conn.version, - }, - PacketNumber: 0x42, - PacketNumberLen: protocol.PacketNumberLen2, - }, []byte("foobar")) - tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, p.Size(), gomock.Any()) - Expect(conn.handlePacketImpl(p)).To(BeFalse()) - }) - - // Illustrates that an injected Initial with an ACK frame for an unsent packet causes - // the connection to immediately break down - It("fails on Initial-level ACK for unsent packet", func() { - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{ack}) - tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) - }) - - // Illustrates that an injected Initial with a CONNECTION_CLOSE frame causes - // the connection to immediately break down - It("fails on Initial-level CONNECTION_CLOSE frame", func() { - connCloseFrame := &wire.ConnectionCloseFrame{ - IsApplicationError: true, - ReasonPhrase: "mitm attacker", - } - initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{connCloseFrame}) - tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue()) - }) - - // Illustrates that attacker who injects a Retry packet and changes the connection ID - // can cause subsequent real Initial packets to be ignored - It("ignores Initial packets which use original source id, after accepting a Retry", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - conn.sentPacketHandler = sph - sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2) - sph.EXPECT().ResetForRetry() - newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID) - packer.EXPECT().SetToken([]byte("foobar")) - - tracer.EXPECT().ReceivedRetry(gomock.Any()) - conn.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, destConnID, destConnID, []byte("foobar"), conn.version))) - initialPacket := testutils.ComposeInitialPacket(conn.connIDManager.Get(), srcConnID, conn.version, conn.connIDManager.Get(), nil) - tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) - }) - }) -}) diff --git a/internal/quic-go/crypto_stream.go b/internal/quic-go/crypto_stream.go deleted file mode 100644 index 0763b165..00000000 --- a/internal/quic-go/crypto_stream.go +++ /dev/null @@ -1,115 +0,0 @@ -package quic - -import ( - "fmt" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type cryptoStream interface { - // for receiving data - HandleCryptoFrame(*wire.CryptoFrame) error - GetCryptoData() []byte - Finish() error - // for sending data - io.Writer - HasData() bool - PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame -} - -type cryptoStreamImpl struct { - queue *frameSorter - msgBuf []byte - - highestOffset protocol.ByteCount - finished bool - - writeOffset protocol.ByteCount - writeBuf []byte -} - -func newCryptoStream() cryptoStream { - return &cryptoStreamImpl{queue: newFrameSorter()} -} - -func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { - highestOffset := f.Offset + protocol.ByteCount(len(f.Data)) - if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset { - return &qerr.TransportError{ - ErrorCode: qerr.CryptoBufferExceeded, - ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset), - } - } - if s.finished { - if highestOffset > s.highestOffset { - // reject crypto data received after this stream was already finished - return &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "received crypto data after change of encryption level", - } - } - // ignore data with a smaller offset than the highest received - // could e.g. be a retransmission - return nil - } - s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset) - if err := s.queue.Push(f.Data, f.Offset, nil); err != nil { - return err - } - for { - _, data, _ := s.queue.Pop() - if data == nil { - return nil - } - s.msgBuf = append(s.msgBuf, data...) - } -} - -// GetCryptoData retrieves data that was received in CRYPTO frames -func (s *cryptoStreamImpl) GetCryptoData() []byte { - if len(s.msgBuf) < 4 { - return nil - } - msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) - if len(s.msgBuf) < msgLen { - return nil - } - msg := make([]byte, msgLen) - copy(msg, s.msgBuf[:msgLen]) - s.msgBuf = s.msgBuf[msgLen:] - return msg -} - -func (s *cryptoStreamImpl) Finish() error { - if s.queue.HasMoreData() { - return &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "encryption level changed, but crypto stream has more data to read", - } - } - s.finished = true - return nil -} - -// Writes writes data that should be sent out in CRYPTO frames -func (s *cryptoStreamImpl) Write(p []byte) (int, error) { - s.writeBuf = append(s.writeBuf, p...) - return len(p), nil -} - -func (s *cryptoStreamImpl) HasData() bool { - return len(s.writeBuf) > 0 -} - -func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame { - f := &wire.CryptoFrame{Offset: s.writeOffset} - n := utils.MinByteCount(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) - f.Data = s.writeBuf[:n] - s.writeBuf = s.writeBuf[n:] - s.writeOffset += n - return f -} diff --git a/internal/quic-go/crypto_stream_manager.go b/internal/quic-go/crypto_stream_manager.go deleted file mode 100644 index 83a70ae5..00000000 --- a/internal/quic-go/crypto_stream_manager.go +++ /dev/null @@ -1,61 +0,0 @@ -package quic - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type cryptoDataHandler interface { - HandleMessage([]byte, protocol.EncryptionLevel) bool -} - -type cryptoStreamManager struct { - cryptoHandler cryptoDataHandler - - initialStream cryptoStream - handshakeStream cryptoStream - oneRTTStream cryptoStream -} - -func newCryptoStreamManager( - cryptoHandler cryptoDataHandler, - initialStream cryptoStream, - handshakeStream cryptoStream, - oneRTTStream cryptoStream, -) *cryptoStreamManager { - return &cryptoStreamManager{ - cryptoHandler: cryptoHandler, - initialStream: initialStream, - handshakeStream: handshakeStream, - oneRTTStream: oneRTTStream, - } -} - -func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) { - var str cryptoStream - //nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets. - switch encLevel { - case protocol.EncryptionInitial: - str = m.initialStream - case protocol.EncryptionHandshake: - str = m.handshakeStream - case protocol.Encryption1RTT: - str = m.oneRTTStream - default: - return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) - } - if err := str.HandleCryptoFrame(frame); err != nil { - return false, err - } - for { - data := str.GetCryptoData() - if data == nil { - return false, nil - } - if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished { - return true, str.Finish() - } - } -} diff --git a/internal/quic-go/crypto_stream_manager_test.go b/internal/quic-go/crypto_stream_manager_test.go deleted file mode 100644 index d5d7ed85..00000000 --- a/internal/quic-go/crypto_stream_manager_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package quic - -import ( - "errors" - - "github.com/golang/mock/gomock" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Crypto Stream Manager", func() { - var ( - csm *cryptoStreamManager - cs *MockCryptoDataHandler - - initialStream *MockCryptoStream - handshakeStream *MockCryptoStream - oneRTTStream *MockCryptoStream - ) - - BeforeEach(func() { - initialStream = NewMockCryptoStream(mockCtrl) - handshakeStream = NewMockCryptoStream(mockCtrl) - oneRTTStream = NewMockCryptoStream(mockCtrl) - cs = NewMockCryptoDataHandler(mockCtrl) - csm = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream) - }) - - It("passes messages to the initial stream", func() { - cf := &wire.CryptoFrame{Data: []byte("foobar")} - initialStream.EXPECT().HandleCryptoFrame(cf) - initialStream.EXPECT().GetCryptoData().Return([]byte("foobar")) - initialStream.EXPECT().GetCryptoData() - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) - }) - - It("passes messages to the handshake stream", func() { - cf := &wire.CryptoFrame{Data: []byte("foobar")} - handshakeStream.EXPECT().HandleCryptoFrame(cf) - handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")) - handshakeStream.EXPECT().GetCryptoData() - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) - }) - - It("passes messages to the 1-RTT stream", func() { - cf := &wire.CryptoFrame{Data: []byte("foobar")} - oneRTTStream.EXPECT().HandleCryptoFrame(cf) - oneRTTStream.EXPECT().GetCryptoData().Return([]byte("foobar")) - oneRTTStream.EXPECT().GetCryptoData() - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.Encryption1RTT) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) - }) - - It("doesn't call the message handler, if there's no message", func() { - cf := &wire.CryptoFrame{Data: []byte("foobar")} - handshakeStream.EXPECT().HandleCryptoFrame(cf) - handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle - // don't EXPECT any calls to HandleMessage() - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) - }) - - It("processes all messages", func() { - cf := &wire.CryptoFrame{Data: []byte("foobar")} - handshakeStream.EXPECT().HandleCryptoFrame(cf) - handshakeStream.EXPECT().GetCryptoData().Return([]byte("foo")) - handshakeStream.EXPECT().GetCryptoData().Return([]byte("bar")) - handshakeStream.EXPECT().GetCryptoData() - cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake) - cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) - }) - - It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() { - cf := &wire.CryptoFrame{Data: []byte("foobar")} - gomock.InOrder( - handshakeStream.EXPECT().HandleCryptoFrame(cf), - handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), - handshakeStream.EXPECT().Finish(), - ) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeTrue()) - }) - - It("returns errors that occur when finishing a stream", func() { - testErr := errors.New("test error") - cf := &wire.CryptoFrame{Data: []byte("foobar")} - gomock.InOrder( - handshakeStream.EXPECT().HandleCryptoFrame(cf), - handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), - handshakeStream.EXPECT().Finish().Return(testErr), - ) - _, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).To(MatchError(err)) - }) - - It("errors for unknown encryption levels", func() { - _, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, 42) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("received CRYPTO frame with unexpected encryption level")) - }) -}) diff --git a/internal/quic-go/crypto_stream_test.go b/internal/quic-go/crypto_stream_test.go deleted file mode 100644 index 7c8301b7..00000000 --- a/internal/quic-go/crypto_stream_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package quic - -import ( - "crypto/rand" - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func createHandshakeMessage(len int) []byte { - msg := make([]byte, 4+len) - rand.Read(msg[:1]) // random message type - msg[1] = uint8(len >> 16) - msg[2] = uint8(len >> 8) - msg[3] = uint8(len) - rand.Read(msg[4:]) - return msg -} - -var _ = Describe("Crypto Stream", func() { - var str cryptoStream - - BeforeEach(func() { - str = newCryptoStream() - }) - - Context("handling incoming data", func() { - It("handles in-order CRYPTO frames", func() { - msg := createHandshakeMessage(6) - err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg}) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal(msg)) - Expect(str.GetCryptoData()).To(BeNil()) - }) - - It("handles multiple messages in one CRYPTO frame", func() { - msg1 := createHandshakeMessage(6) - msg2 := createHandshakeMessage(10) - msg := append(append([]byte{}, msg1...), msg2...) - err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg}) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal(msg1)) - Expect(str.GetCryptoData()).To(Equal(msg2)) - Expect(str.GetCryptoData()).To(BeNil()) - }) - - It("errors if the frame exceeds the maximum offset", func() { - Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ - Offset: protocol.MaxCryptoStreamOffset - 5, - Data: []byte("foobar"), - })).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.CryptoBufferExceeded, - ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", protocol.MaxCryptoStreamOffset+1, protocol.MaxCryptoStreamOffset), - })) - }) - - It("handles messages split over multiple CRYPTO frames", func() { - msg := createHandshakeMessage(6) - err := str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: msg[:4], - }) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(BeNil()) - err = str.HandleCryptoFrame(&wire.CryptoFrame{ - Offset: 4, - Data: msg[4:], - }) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal(msg)) - Expect(str.GetCryptoData()).To(BeNil()) - }) - - It("handles out-of-order CRYPTO frames", func() { - msg := createHandshakeMessage(6) - err := str.HandleCryptoFrame(&wire.CryptoFrame{ - Offset: 4, - Data: msg[4:], - }) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(BeNil()) - err = str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: msg[:4], - }) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal(msg)) - Expect(str.GetCryptoData()).To(BeNil()) - }) - - Context("finishing", func() { - It("errors if there's still data to read after finishing", func() { - Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: createHandshakeMessage(5), - Offset: 10, - })).To(Succeed()) - Expect(str.Finish()).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "encryption level changed, but crypto stream has more data to read", - })) - }) - - It("works with reordered data", func() { - f1 := &wire.CryptoFrame{ - Data: []byte("foo"), - } - f2 := &wire.CryptoFrame{ - Offset: 3, - Data: []byte("bar"), - } - Expect(str.HandleCryptoFrame(f2)).To(Succeed()) - Expect(str.HandleCryptoFrame(f1)).To(Succeed()) - Expect(str.Finish()).To(Succeed()) - Expect(str.HandleCryptoFrame(f2)).To(Succeed()) - }) - - It("rejects new crypto data after finishing", func() { - Expect(str.Finish()).To(Succeed()) - Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: createHandshakeMessage(5), - })).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "received crypto data after change of encryption level", - })) - }) - - It("ignores crypto data below the maximum offset received before finishing", func() { - msg := createHandshakeMessage(15) - Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: msg, - })).To(Succeed()) - Expect(str.GetCryptoData()).To(Equal(msg)) - Expect(str.Finish()).To(Succeed()) - Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ - Offset: protocol.ByteCount(len(msg) - 6), - Data: []byte("foobar"), - })).To(Succeed()) - }) - }) - }) - - Context("writing data", func() { - It("says if it has data", func() { - Expect(str.HasData()).To(BeFalse()) - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(str.HasData()).To(BeTrue()) - }) - - It("pops crypto frames", func() { - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - f := str.PopCryptoFrame(1000) - Expect(f).ToNot(BeNil()) - Expect(f.Offset).To(BeZero()) - Expect(f.Data).To(Equal([]byte("foobar"))) - }) - - It("coalesces multiple writes", func() { - _, err := str.Write([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write([]byte("bar")) - Expect(err).ToNot(HaveOccurred()) - f := str.PopCryptoFrame(1000) - Expect(f).ToNot(BeNil()) - Expect(f.Offset).To(BeZero()) - Expect(f.Data).To(Equal([]byte("foobar"))) - }) - - It("respects the maximum size", func() { - frameHeaderLen := (&wire.CryptoFrame{}).Length(protocol.VersionWhatever) - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - f := str.PopCryptoFrame(frameHeaderLen + 3) - Expect(f).ToNot(BeNil()) - Expect(f.Offset).To(BeZero()) - Expect(f.Data).To(Equal([]byte("foo"))) - f = str.PopCryptoFrame(frameHeaderLen + 3) - Expect(f).ToNot(BeNil()) - Expect(f.Offset).To(Equal(protocol.ByteCount(3))) - Expect(f.Data).To(Equal([]byte("bar"))) - }) - }) -}) diff --git a/internal/quic-go/datagram_queue.go b/internal/quic-go/datagram_queue.go deleted file mode 100644 index 561b7a8e..00000000 --- a/internal/quic-go/datagram_queue.go +++ /dev/null @@ -1,87 +0,0 @@ -package quic - -import ( - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type datagramQueue struct { - sendQueue chan *wire.DatagramFrame - rcvQueue chan []byte - - closeErr error - closed chan struct{} - - hasData func() - - dequeued chan struct{} - - logger utils.Logger -} - -func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { - return &datagramQueue{ - hasData: hasData, - sendQueue: make(chan *wire.DatagramFrame, 1), - rcvQueue: make(chan []byte, protocol.DatagramRcvQueueLen), - dequeued: make(chan struct{}), - closed: make(chan struct{}), - logger: logger, - } -} - -// AddAndWait queues a new DATAGRAM frame for sending. -// It blocks until the frame has been dequeued. -func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { - select { - case h.sendQueue <- f: - h.hasData() - case <-h.closed: - return h.closeErr - } - - select { - case <-h.dequeued: - return nil - case <-h.closed: - return h.closeErr - } -} - -// Get dequeues a DATAGRAM frame for sending. -func (h *datagramQueue) Get() *wire.DatagramFrame { - select { - case f := <-h.sendQueue: - h.dequeued <- struct{}{} - return f - default: - return nil - } -} - -// HandleDatagramFrame handles a received DATAGRAM frame. -func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { - data := make([]byte, len(f.Data)) - copy(data, f.Data) - select { - case h.rcvQueue <- data: - default: - h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data)) - } -} - -// Receive gets a received DATAGRAM frame. -func (h *datagramQueue) Receive() ([]byte, error) { - select { - case data := <-h.rcvQueue: - return data, nil - case <-h.closed: - return nil, h.closeErr - } -} - -func (h *datagramQueue) CloseWithError(e error) { - h.closeErr = e - close(h.closed) -} diff --git a/internal/quic-go/datagram_queue_test.go b/internal/quic-go/datagram_queue_test.go deleted file mode 100644 index 29351ce6..00000000 --- a/internal/quic-go/datagram_queue_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package quic - -import ( - "errors" - - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Datagram Queue", func() { - var queue *datagramQueue - var queued chan struct{} - - BeforeEach(func() { - queued = make(chan struct{}, 100) - queue = newDatagramQueue(func() { - queued <- struct{}{} - }, utils.DefaultLogger) - }) - - Context("sending", func() { - It("returns nil when there's no datagram to send", func() { - Expect(queue.Get()).To(BeNil()) - }) - - It("queues a datagram", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - Expect(queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")})).To(Succeed()) - }() - - Eventually(queued).Should(HaveLen(1)) - Consistently(done).ShouldNot(BeClosed()) - f := queue.Get() - Expect(f).ToNot(BeNil()) - Expect(f.Data).To(Equal([]byte("foobar"))) - Eventually(done).Should(BeClosed()) - Expect(queue.Get()).To(BeNil()) - }) - - It("closes", func() { - errChan := make(chan error, 1) - go func() { - defer GinkgoRecover() - errChan <- queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")}) - }() - - Consistently(errChan).ShouldNot(Receive()) - queue.CloseWithError(errors.New("test error")) - Eventually(errChan).Should(Receive(MatchError("test error"))) - }) - }) - - Context("receiving", func() { - It("receives DATAGRAM frames", func() { - queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foo")}) - queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("bar")}) - data, err := queue.Receive() - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("foo"))) - data, err = queue.Receive() - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("bar"))) - }) - - It("blocks until a frame is received", func() { - c := make(chan []byte, 1) - go func() { - defer GinkgoRecover() - data, err := queue.Receive() - Expect(err).ToNot(HaveOccurred()) - c <- data - }() - - Consistently(c).ShouldNot(Receive()) - queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foobar")}) - Eventually(c).Should(Receive(Equal([]byte("foobar")))) - }) - - It("closes", func() { - errChan := make(chan error, 1) - go func() { - defer GinkgoRecover() - _, err := queue.Receive() - errChan <- err - }() - - Consistently(errChan).ShouldNot(Receive()) - queue.CloseWithError(errors.New("test error")) - Eventually(errChan).Should(Receive(MatchError("test error"))) - }) - }) -}) diff --git a/internal/quic-go/errors.go b/internal/quic-go/errors.go deleted file mode 100644 index 5f9050ac..00000000 --- a/internal/quic-go/errors.go +++ /dev/null @@ -1,58 +0,0 @@ -package quic - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/qerr" -) - -type ( - TransportError = qerr.TransportError - ApplicationError = qerr.ApplicationError - VersionNegotiationError = qerr.VersionNegotiationError - StatelessResetError = qerr.StatelessResetError - IdleTimeoutError = qerr.IdleTimeoutError - HandshakeTimeoutError = qerr.HandshakeTimeoutError -) - -type ( - TransportErrorCode = qerr.TransportErrorCode - ApplicationErrorCode = qerr.ApplicationErrorCode - StreamErrorCode = qerr.StreamErrorCode -) - -const ( - NoError = qerr.NoError - InternalError = qerr.InternalError - ConnectionRefused = qerr.ConnectionRefused - FlowControlError = qerr.FlowControlError - StreamLimitError = qerr.StreamLimitError - StreamStateError = qerr.StreamStateError - FinalSizeError = qerr.FinalSizeError - FrameEncodingError = qerr.FrameEncodingError - TransportParameterError = qerr.TransportParameterError - ConnectionIDLimitError = qerr.ConnectionIDLimitError - ProtocolViolation = qerr.ProtocolViolation - InvalidToken = qerr.InvalidToken - ApplicationErrorErrorCode = qerr.ApplicationErrorErrorCode - CryptoBufferExceeded = qerr.CryptoBufferExceeded - KeyUpdateError = qerr.KeyUpdateError - AEADLimitReached = qerr.AEADLimitReached - NoViablePathError = qerr.NoViablePathError -) - -// A StreamError is used for Stream.CancelRead and Stream.CancelWrite. -// It is also returned from Stream.Read and Stream.Write if the peer canceled reading or writing. -type StreamError struct { - StreamID StreamID - ErrorCode StreamErrorCode -} - -func (e *StreamError) Is(target error) bool { - _, ok := target.(*StreamError) - return ok -} - -func (e *StreamError) Error() string { - return fmt.Sprintf("stream %d canceled with error code %d", e.StreamID, e.ErrorCode) -} diff --git a/internal/quic-go/flowcontrol/base_flow_controller.go b/internal/quic-go/flowcontrol/base_flow_controller.go deleted file mode 100644 index 4c7bcb70..00000000 --- a/internal/quic-go/flowcontrol/base_flow_controller.go +++ /dev/null @@ -1,125 +0,0 @@ -package flowcontrol - -import ( - "sync" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -type baseFlowController struct { - // for sending data - bytesSent protocol.ByteCount - sendWindow protocol.ByteCount - lastBlockedAt protocol.ByteCount - - // for receiving data - //nolint:structcheck // The mutex is used both by the stream and the connection flow controller - mutex sync.Mutex - bytesRead protocol.ByteCount - highestReceived protocol.ByteCount - receiveWindow protocol.ByteCount - receiveWindowSize protocol.ByteCount - maxReceiveWindowSize protocol.ByteCount - - allowWindowIncrease func(size protocol.ByteCount) bool - - epochStartTime time.Time - epochStartOffset protocol.ByteCount - rttStats *utils.RTTStats - - logger utils.Logger -} - -// IsNewlyBlocked says if it is newly blocked by flow control. -// For every offset, it only returns true once. -// If it is blocked, the offset is returned. -func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { - if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt { - return false, 0 - } - c.lastBlockedAt = c.sendWindow - return true, c.sendWindow -} - -func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { - c.bytesSent += n -} - -// UpdateSendWindow is be called after receiving a MAX_{STREAM_}DATA frame. -func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { - if offset > c.sendWindow { - c.sendWindow = offset - } -} - -func (c *baseFlowController) sendWindowSize() protocol.ByteCount { - // this only happens during connection establishment, when data is sent before we receive the peer's transport parameters - if c.bytesSent > c.sendWindow { - return 0 - } - return c.sendWindow - c.bytesSent -} - -// needs to be called with locked mutex -func (c *baseFlowController) addBytesRead(n protocol.ByteCount) { - // pretend we sent a WindowUpdate when reading the first byte - // this way auto-tuning of the window size already works for the first WindowUpdate - if c.bytesRead == 0 { - c.startNewAutoTuningEpoch(time.Now()) - } - c.bytesRead += n -} - -func (c *baseFlowController) hasWindowUpdate() bool { - bytesRemaining := c.receiveWindow - c.bytesRead - // update the window when more than the threshold was consumed - return bytesRemaining <= protocol.ByteCount(float64(c.receiveWindowSize)*(1-protocol.WindowUpdateThreshold)) -} - -// getWindowUpdate updates the receive window, if necessary -// it returns the new offset -func (c *baseFlowController) getWindowUpdate() protocol.ByteCount { - if !c.hasWindowUpdate() { - return 0 - } - - c.maybeAdjustWindowSize() - c.receiveWindow = c.bytesRead + c.receiveWindowSize - return c.receiveWindow -} - -// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often. -// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing. -func (c *baseFlowController) maybeAdjustWindowSize() { - bytesReadInEpoch := c.bytesRead - c.epochStartOffset - // don't do anything if less than half the window has been consumed - if bytesReadInEpoch <= c.receiveWindowSize/2 { - return - } - rtt := c.rttStats.SmoothedRTT() - if rtt == 0 { - return - } - - fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize) - now := time.Now() - if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) { - // window is consumed too fast, try to increase the window size - newSize := utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize) - if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) { - c.receiveWindowSize = newSize - } - } - c.startNewAutoTuningEpoch(now) -} - -func (c *baseFlowController) startNewAutoTuningEpoch(now time.Time) { - c.epochStartTime = now - c.epochStartOffset = c.bytesRead -} - -func (c *baseFlowController) checkFlowControlViolation() bool { - return c.highestReceived > c.receiveWindow -} diff --git a/internal/quic-go/flowcontrol/base_flow_controller_test.go b/internal/quic-go/flowcontrol/base_flow_controller_test.go deleted file mode 100644 index e5a9f578..00000000 --- a/internal/quic-go/flowcontrol/base_flow_controller_test.go +++ /dev/null @@ -1,236 +0,0 @@ -package flowcontrol - -import ( - "os" - "strconv" - "time" - - "github.com/imroc/req/v3/internal/quic-go/utils" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -// on the CIs, the timing is a lot less precise, so scale every duration by this factor -// -//nolint:unparam -func scaleDuration(t time.Duration) time.Duration { - scaleFactor := 1 - if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set - scaleFactor = f - } - Expect(scaleFactor).ToNot(BeZero()) - return time.Duration(scaleFactor) * t -} - -var _ = Describe("Base Flow controller", func() { - var controller *baseFlowController - - BeforeEach(func() { - controller = &baseFlowController{} - controller.rttStats = &utils.RTTStats{} - }) - - Context("send flow control", func() { - It("adds bytes sent", func() { - controller.bytesSent = 5 - controller.AddBytesSent(6) - Expect(controller.bytesSent).To(Equal(protocol.ByteCount(5 + 6))) - }) - - It("gets the size of the remaining flow control window", func() { - controller.bytesSent = 5 - controller.sendWindow = 12 - Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(12 - 5))) - }) - - It("updates the size of the flow control window", func() { - controller.AddBytesSent(5) - controller.UpdateSendWindow(15) - Expect(controller.sendWindow).To(Equal(protocol.ByteCount(15))) - Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(15 - 5))) - }) - - It("says that the window size is 0 if we sent more than we were allowed to", func() { - controller.AddBytesSent(15) - controller.UpdateSendWindow(10) - Expect(controller.sendWindowSize()).To(BeZero()) - }) - - It("does not decrease the flow control window", func() { - controller.UpdateSendWindow(20) - Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20))) - controller.UpdateSendWindow(10) - Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20))) - }) - - It("says when it's blocked", func() { - controller.UpdateSendWindow(100) - Expect(controller.IsNewlyBlocked()).To(BeFalse()) - controller.AddBytesSent(100) - blocked, offset := controller.IsNewlyBlocked() - Expect(blocked).To(BeTrue()) - Expect(offset).To(Equal(protocol.ByteCount(100))) - }) - - It("doesn't say that it's newly blocked multiple times for the same offset", func() { - controller.UpdateSendWindow(100) - controller.AddBytesSent(100) - newlyBlocked, offset := controller.IsNewlyBlocked() - Expect(newlyBlocked).To(BeTrue()) - Expect(offset).To(Equal(protocol.ByteCount(100))) - newlyBlocked, _ = controller.IsNewlyBlocked() - Expect(newlyBlocked).To(BeFalse()) - controller.UpdateSendWindow(150) - controller.AddBytesSent(150) - newlyBlocked, _ = controller.IsNewlyBlocked() - Expect(newlyBlocked).To(BeTrue()) - }) - }) - - Context("receive flow control", func() { - var ( - receiveWindow protocol.ByteCount = 10000 - receiveWindowSize protocol.ByteCount = 1000 - ) - - BeforeEach(func() { - controller.bytesRead = receiveWindow - receiveWindowSize - controller.receiveWindow = receiveWindow - controller.receiveWindowSize = receiveWindowSize - }) - - It("adds bytes read", func() { - controller.bytesRead = 5 - controller.addBytesRead(6) - Expect(controller.bytesRead).To(Equal(protocol.ByteCount(5 + 6))) - }) - - It("triggers a window update when necessary", func() { - bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold + 1 // consumed 1 byte more than the threshold - bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed) - readPosition := receiveWindow - bytesRemaining - controller.bytesRead = readPosition - offset := controller.getWindowUpdate() - Expect(offset).To(Equal(readPosition + receiveWindowSize)) - Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowSize)) - }) - - It("doesn't trigger a window update when not necessary", func() { - bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold - 1 // consumed 1 byte less than the threshold - bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed) - readPosition := receiveWindow - bytesRemaining - controller.bytesRead = readPosition - offset := controller.getWindowUpdate() - Expect(offset).To(BeZero()) - }) - - Context("receive window size auto-tuning", func() { - var oldWindowSize protocol.ByteCount - - BeforeEach(func() { - oldWindowSize = controller.receiveWindowSize - controller.maxReceiveWindowSize = 5000 - }) - - // update the congestion such that it returns a given value for the smoothed RTT - setRtt := func(t time.Duration) { - controller.rttStats.UpdateRTT(t, 0, time.Now()) - Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked - } - - It("doesn't increase the window size for a new stream", func() { - controller.maybeAdjustWindowSize() - Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) - }) - - It("doesn't increase the window size when no RTT estimate is available", func() { - setRtt(0) - controller.startNewAutoTuningEpoch(time.Now()) - controller.addBytesRead(400) - offset := controller.getWindowUpdate() - Expect(offset).ToNot(BeZero()) // make sure a window update is sent - Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) - }) - - It("increases the window size if read so fast that the window would be consumed in less than 4 RTTs", func() { - bytesRead := controller.bytesRead - rtt := scaleDuration(50 * time.Millisecond) - setRtt(rtt) - // consume more than 2/3 of the window... - dataRead := receiveWindowSize*2/3 + 1 - // ... in 4*2/3 of the RTT - controller.epochStartOffset = controller.bytesRead - controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3) - controller.addBytesRead(dataRead) - offset := controller.getWindowUpdate() - Expect(offset).ToNot(BeZero()) - // check that the window size was increased - newWindowSize := controller.receiveWindowSize - Expect(newWindowSize).To(Equal(2 * oldWindowSize)) - // check that the new window size was used to increase the offset - Expect(offset).To(Equal(bytesRead + dataRead + newWindowSize)) - }) - - It("doesn't increase the window size if data is read so fast that the window would be consumed in less than 4 RTTs, but less than half the window has been read", func() { - bytesRead := controller.bytesRead - rtt := scaleDuration(20 * time.Millisecond) - setRtt(rtt) - // consume more than 2/3 of the window... - dataRead := receiveWindowSize*1/3 + 1 - // ... in 4*2/3 of the RTT - controller.epochStartOffset = controller.bytesRead - controller.epochStartTime = time.Now().Add(-rtt * 4 * 1 / 3) - controller.addBytesRead(dataRead) - offset := controller.getWindowUpdate() - Expect(offset).ToNot(BeZero()) - // check that the window size was not increased - newWindowSize := controller.receiveWindowSize - Expect(newWindowSize).To(Equal(oldWindowSize)) - // check that the new window size was used to increase the offset - Expect(offset).To(Equal(bytesRead + dataRead + newWindowSize)) - }) - - It("doesn't increase the window size if read too slowly", func() { - bytesRead := controller.bytesRead - rtt := scaleDuration(20 * time.Millisecond) - setRtt(rtt) - // consume less than 2/3 of the window... - dataRead := receiveWindowSize*2/3 - 1 - // ... in 4*2/3 of the RTT - controller.epochStartOffset = controller.bytesRead - controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3) - controller.addBytesRead(dataRead) - offset := controller.getWindowUpdate() - Expect(offset).ToNot(BeZero()) - // check that the window size was not increased - Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) - // check that the new window size was used to increase the offset - Expect(offset).To(Equal(bytesRead + dataRead + oldWindowSize)) - }) - - It("doesn't increase the window size to a value higher than the maxReceiveWindowSize", func() { - resetEpoch := func() { - // make sure the next call to maybeAdjustWindowSize will increase the window - controller.epochStartTime = time.Now().Add(-time.Millisecond) - controller.epochStartOffset = controller.bytesRead - controller.addBytesRead(controller.receiveWindowSize/2 + 1) - } - setRtt(scaleDuration(20 * time.Millisecond)) - resetEpoch() - controller.maybeAdjustWindowSize() - Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) // 2000 - // because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowSize() multiple times and get an increase of the window size every time - resetEpoch() - controller.maybeAdjustWindowSize() - Expect(controller.receiveWindowSize).To(Equal(2 * 2 * oldWindowSize)) // 4000 - resetEpoch() - controller.maybeAdjustWindowSize() - Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000 - controller.maybeAdjustWindowSize() - Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000 - }) - }) - }) -}) diff --git a/internal/quic-go/flowcontrol/connection_flow_controller.go b/internal/quic-go/flowcontrol/connection_flow_controller.go deleted file mode 100644 index 3e40e0d5..00000000 --- a/internal/quic-go/flowcontrol/connection_flow_controller.go +++ /dev/null @@ -1,112 +0,0 @@ -package flowcontrol - -import ( - "errors" - "fmt" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -type connectionFlowController struct { - baseFlowController - - queueWindowUpdate func() -} - -var _ ConnectionFlowController = &connectionFlowController{} - -// NewConnectionFlowController gets a new flow controller for the connection -// It is created before we receive the peer's transport parameters, thus it starts with a sendWindow of 0. -func NewConnectionFlowController( - receiveWindow protocol.ByteCount, - maxReceiveWindow protocol.ByteCount, - queueWindowUpdate func(), - allowWindowIncrease func(size protocol.ByteCount) bool, - rttStats *utils.RTTStats, - logger utils.Logger, -) ConnectionFlowController { - return &connectionFlowController{ - baseFlowController: baseFlowController{ - rttStats: rttStats, - receiveWindow: receiveWindow, - receiveWindowSize: receiveWindow, - maxReceiveWindowSize: maxReceiveWindow, - allowWindowIncrease: allowWindowIncrease, - logger: logger, - }, - queueWindowUpdate: queueWindowUpdate, - } -} - -func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { - return c.baseFlowController.sendWindowSize() -} - -// IncrementHighestReceived adds an increment to the highestReceived value -func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - c.highestReceived += increment - if c.checkFlowControlViolation() { - return &qerr.TransportError{ - ErrorCode: qerr.FlowControlError, - ErrorMessage: fmt.Sprintf("received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow), - } - } - return nil -} - -func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) { - c.mutex.Lock() - c.baseFlowController.addBytesRead(n) - shouldQueueWindowUpdate := c.hasWindowUpdate() - c.mutex.Unlock() - if shouldQueueWindowUpdate { - c.queueWindowUpdate() - } -} - -func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { - c.mutex.Lock() - oldWindowSize := c.receiveWindowSize - offset := c.baseFlowController.getWindowUpdate() - if oldWindowSize < c.receiveWindowSize { - c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) - } - c.mutex.Unlock() - return offset -} - -// EnsureMinimumWindowSize sets a minimum window size -// it should make sure that the connection-level window is increased when a stream-level window grows -func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) { - c.mutex.Lock() - if inc > c.receiveWindowSize { - c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10)) - newSize := utils.MinByteCount(inc, c.maxReceiveWindowSize) - if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) { - c.receiveWindowSize = newSize - } - c.startNewAutoTuningEpoch(time.Now()) - } - c.mutex.Unlock() -} - -// Reset rests the flow controller. This happens when 0-RTT is rejected. -// All stream data is invalidated, it's if we had never opened a stream and never sent any data. -// At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet. -func (c *connectionFlowController) Reset() error { - c.mutex.Lock() - defer c.mutex.Unlock() - - if c.bytesRead > 0 || c.highestReceived > 0 || !c.epochStartTime.IsZero() { - return errors.New("flow controller reset after reading data") - } - c.bytesSent = 0 - c.lastBlockedAt = 0 - return nil -} diff --git a/internal/quic-go/flowcontrol/connection_flow_controller_test.go b/internal/quic-go/flowcontrol/connection_flow_controller_test.go deleted file mode 100644 index 32ad79c5..00000000 --- a/internal/quic-go/flowcontrol/connection_flow_controller_test.go +++ /dev/null @@ -1,185 +0,0 @@ -package flowcontrol - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Connection Flow controller", func() { - var ( - controller *connectionFlowController - queuedWindowUpdate bool - ) - - // update the congestion such that it returns a given value for the smoothed RTT - setRtt := func(t time.Duration) { - controller.rttStats.UpdateRTT(t, 0, time.Now()) - Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked - } - - BeforeEach(func() { - queuedWindowUpdate = false - controller = &connectionFlowController{} - controller.rttStats = &utils.RTTStats{} - controller.logger = utils.DefaultLogger - controller.queueWindowUpdate = func() { queuedWindowUpdate = true } - controller.allowWindowIncrease = func(protocol.ByteCount) bool { return true } - }) - - Context("Constructor", func() { - rttStats := &utils.RTTStats{} - - It("sets the send and receive windows", func() { - receiveWindow := protocol.ByteCount(2000) - maxReceiveWindow := protocol.ByteCount(3000) - - fc := NewConnectionFlowController( - receiveWindow, - maxReceiveWindow, - nil, - func(protocol.ByteCount) bool { return true }, - rttStats, - utils.DefaultLogger).(*connectionFlowController) - Expect(fc.receiveWindow).To(Equal(receiveWindow)) - Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow)) - }) - }) - - Context("receive flow control", func() { - It("increases the highestReceived by a given window size", func() { - controller.highestReceived = 1337 - controller.IncrementHighestReceived(123) - Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337 + 123))) - }) - - Context("getting window updates", func() { - BeforeEach(func() { - controller.receiveWindow = 100 - controller.receiveWindowSize = 60 - controller.maxReceiveWindowSize = 1000 - controller.bytesRead = 100 - 60 - }) - - It("queues window updates", func() { - controller.AddBytesRead(1) - Expect(queuedWindowUpdate).To(BeFalse()) - controller.AddBytesRead(29) - Expect(queuedWindowUpdate).To(BeTrue()) - Expect(controller.GetWindowUpdate()).ToNot(BeZero()) - queuedWindowUpdate = false - controller.AddBytesRead(1) - Expect(queuedWindowUpdate).To(BeFalse()) - }) - - It("gets a window update", func() { - windowSize := controller.receiveWindowSize - oldOffset := controller.bytesRead - dataRead := windowSize/2 - 1 // make sure not to trigger auto-tuning - controller.AddBytesRead(dataRead) - offset := controller.GetWindowUpdate() - Expect(offset).To(Equal(oldOffset + dataRead + 60)) - }) - - It("auto-tunes the window", func() { - var allowed protocol.ByteCount - controller.allowWindowIncrease = func(size protocol.ByteCount) bool { - allowed = size - return true - } - oldOffset := controller.bytesRead - oldWindowSize := controller.receiveWindowSize - rtt := scaleDuration(20 * time.Millisecond) - setRtt(rtt) - controller.epochStartTime = time.Now().Add(-time.Millisecond) - controller.epochStartOffset = oldOffset - dataRead := oldWindowSize/2 + 1 - controller.AddBytesRead(dataRead) - offset := controller.GetWindowUpdate() - newWindowSize := controller.receiveWindowSize - Expect(newWindowSize).To(Equal(2 * oldWindowSize)) - Expect(offset).To(Equal(oldOffset + dataRead + newWindowSize)) - Expect(allowed).To(Equal(oldWindowSize)) - }) - - It("doesn't auto-tune the window if it's not allowed", func() { - controller.allowWindowIncrease = func(protocol.ByteCount) bool { return false } - oldOffset := controller.bytesRead - oldWindowSize := controller.receiveWindowSize - rtt := scaleDuration(20 * time.Millisecond) - setRtt(rtt) - controller.epochStartTime = time.Now().Add(-time.Millisecond) - controller.epochStartOffset = oldOffset - dataRead := oldWindowSize/2 + 1 - controller.AddBytesRead(dataRead) - offset := controller.GetWindowUpdate() - newWindowSize := controller.receiveWindowSize - Expect(newWindowSize).To(Equal(oldWindowSize)) - Expect(offset).To(Equal(oldOffset + dataRead + newWindowSize)) - }) - }) - }) - - Context("setting the minimum window size", func() { - var ( - oldWindowSize protocol.ByteCount - receiveWindow protocol.ByteCount = 10000 - receiveWindowSize protocol.ByteCount = 1000 - ) - - BeforeEach(func() { - controller.receiveWindow = receiveWindow - controller.receiveWindowSize = receiveWindowSize - oldWindowSize = controller.receiveWindowSize - controller.maxReceiveWindowSize = 3000 - }) - - It("sets the minimum window window size", func() { - controller.EnsureMinimumWindowSize(1800) - Expect(controller.receiveWindowSize).To(Equal(protocol.ByteCount(1800))) - }) - - It("doesn't reduce the window window size", func() { - controller.EnsureMinimumWindowSize(1) - Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) - }) - - It("doesn't increase the window size beyond the maxReceiveWindowSize", func() { - max := controller.maxReceiveWindowSize - controller.EnsureMinimumWindowSize(2 * max) - Expect(controller.receiveWindowSize).To(Equal(max)) - }) - - It("starts a new epoch after the window size was increased", func() { - controller.EnsureMinimumWindowSize(1912) - Expect(controller.epochStartTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) - }) - - Context("resetting", func() { - It("resets", func() { - const initialWindow protocol.ByteCount = 1337 - controller.UpdateSendWindow(initialWindow) - controller.AddBytesSent(1000) - Expect(controller.SendWindowSize()).To(Equal(initialWindow - 1000)) - Expect(controller.Reset()).To(Succeed()) - Expect(controller.SendWindowSize()).To(Equal(initialWindow)) - }) - - It("says if is blocked after resetting", func() { - const initialWindow protocol.ByteCount = 1337 - controller.UpdateSendWindow(initialWindow) - controller.AddBytesSent(initialWindow) - blocked, _ := controller.IsNewlyBlocked() - Expect(blocked).To(BeTrue()) - Expect(controller.Reset()).To(Succeed()) - controller.AddBytesSent(initialWindow) - blocked, blockedAt := controller.IsNewlyBlocked() - Expect(blocked).To(BeTrue()) - Expect(blockedAt).To(Equal(initialWindow)) - }) - }) -}) diff --git a/internal/quic-go/flowcontrol/flowcontrol_suite_test.go b/internal/quic-go/flowcontrol/flowcontrol_suite_test.go deleted file mode 100644 index 91102815..00000000 --- a/internal/quic-go/flowcontrol/flowcontrol_suite_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package flowcontrol - -import ( - "testing" - - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestFlowControl(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "FlowControl Suite") -} - -var mockCtrl *gomock.Controller - -var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) -}) - -var _ = AfterEach(func() { - mockCtrl.Finish() -}) diff --git a/internal/quic-go/flowcontrol/interface.go b/internal/quic-go/flowcontrol/interface.go deleted file mode 100644 index 58a0e88f..00000000 --- a/internal/quic-go/flowcontrol/interface.go +++ /dev/null @@ -1,42 +0,0 @@ -package flowcontrol - -import "github.com/imroc/req/v3/internal/quic-go/protocol" - -type flowController interface { - // for sending - SendWindowSize() protocol.ByteCount - UpdateSendWindow(protocol.ByteCount) - AddBytesSent(protocol.ByteCount) - // for receiving - AddBytesRead(protocol.ByteCount) - GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary - IsNewlyBlocked() (bool, protocol.ByteCount) -} - -// A StreamFlowController is a flow controller for a QUIC stream. -type StreamFlowController interface { - flowController - // for receiving - // UpdateHighestReceived should be called when a new highest offset is received - // final has to be to true if this is the final offset of the stream, - // as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame - UpdateHighestReceived(offset protocol.ByteCount, final bool) error - // Abandon should be called when reading from the stream is aborted early, - // and there won't be any further calls to AddBytesRead. - Abandon() -} - -// The ConnectionFlowController is the flow controller for the connection. -type ConnectionFlowController interface { - flowController - Reset() error -} - -type connectionFlowControllerI interface { - ConnectionFlowController - // The following two methods are not supposed to be called from outside this packet, but are needed internally - // for sending - EnsureMinimumWindowSize(protocol.ByteCount) - // for receiving - IncrementHighestReceived(protocol.ByteCount) error -} diff --git a/internal/quic-go/flowcontrol/stream_flow_controller.go b/internal/quic-go/flowcontrol/stream_flow_controller.go deleted file mode 100644 index 7ee95862..00000000 --- a/internal/quic-go/flowcontrol/stream_flow_controller.go +++ /dev/null @@ -1,149 +0,0 @@ -package flowcontrol - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -type streamFlowController struct { - baseFlowController - - streamID protocol.StreamID - - queueWindowUpdate func() - - connection connectionFlowControllerI - - receivedFinalOffset bool -} - -var _ StreamFlowController = &streamFlowController{} - -// NewStreamFlowController gets a new flow controller for a stream -func NewStreamFlowController( - streamID protocol.StreamID, - cfc ConnectionFlowController, - receiveWindow protocol.ByteCount, - maxReceiveWindow protocol.ByteCount, - initialSendWindow protocol.ByteCount, - queueWindowUpdate func(protocol.StreamID), - rttStats *utils.RTTStats, - logger utils.Logger, -) StreamFlowController { - return &streamFlowController{ - streamID: streamID, - connection: cfc.(connectionFlowControllerI), - queueWindowUpdate: func() { queueWindowUpdate(streamID) }, - baseFlowController: baseFlowController{ - rttStats: rttStats, - receiveWindow: receiveWindow, - receiveWindowSize: receiveWindow, - maxReceiveWindowSize: maxReceiveWindow, - sendWindow: initialSendWindow, - logger: logger, - }, - } -} - -// UpdateHighestReceived updates the highestReceived value, if the offset is higher. -func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool) error { - // If the final offset for this stream is already known, check for consistency. - if c.receivedFinalOffset { - // If we receive another final offset, check that it's the same. - if final && offset != c.highestReceived { - return &qerr.TransportError{ - ErrorCode: qerr.FinalSizeError, - ErrorMessage: fmt.Sprintf("received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset), - } - } - // Check that the offset is below the final offset. - if offset > c.highestReceived { - return &qerr.TransportError{ - ErrorCode: qerr.FinalSizeError, - ErrorMessage: fmt.Sprintf("received offset %d for stream %d, but final offset was already received at %d", offset, c.streamID, c.highestReceived), - } - } - } - - if final { - c.receivedFinalOffset = true - } - if offset == c.highestReceived { - return nil - } - // A higher offset was received before. - // This can happen due to reordering. - if offset <= c.highestReceived { - if final { - return &qerr.TransportError{ - ErrorCode: qerr.FinalSizeError, - ErrorMessage: fmt.Sprintf("received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived), - } - } - return nil - } - - increment := offset - c.highestReceived - c.highestReceived = offset - if c.checkFlowControlViolation() { - return &qerr.TransportError{ - ErrorCode: qerr.FlowControlError, - ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow), - } - } - return c.connection.IncrementHighestReceived(increment) -} - -func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) { - c.mutex.Lock() - c.baseFlowController.addBytesRead(n) - shouldQueueWindowUpdate := c.shouldQueueWindowUpdate() - c.mutex.Unlock() - if shouldQueueWindowUpdate { - c.queueWindowUpdate() - } - c.connection.AddBytesRead(n) -} - -func (c *streamFlowController) Abandon() { - c.mutex.Lock() - unread := c.highestReceived - c.bytesRead - c.mutex.Unlock() - if unread > 0 { - c.connection.AddBytesRead(unread) - } -} - -func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { - c.baseFlowController.AddBytesSent(n) - c.connection.AddBytesSent(n) -} - -func (c *streamFlowController) SendWindowSize() protocol.ByteCount { - return utils.MinByteCount(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize()) -} - -func (c *streamFlowController) shouldQueueWindowUpdate() bool { - return !c.receivedFinalOffset && c.hasWindowUpdate() -} - -func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { - // If we already received the final offset for this stream, the peer won't need any additional flow control credit. - if c.receivedFinalOffset { - return 0 - } - - // Don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler - c.mutex.Lock() - oldWindowSize := c.receiveWindowSize - offset := c.baseFlowController.getWindowUpdate() - if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size - c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10)) - c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)) - } - c.mutex.Unlock() - return offset -} diff --git a/internal/quic-go/flowcontrol/stream_flow_controller_test.go b/internal/quic-go/flowcontrol/stream_flow_controller_test.go deleted file mode 100644 index 61084795..00000000 --- a/internal/quic-go/flowcontrol/stream_flow_controller_test.go +++ /dev/null @@ -1,272 +0,0 @@ -package flowcontrol - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Stream Flow controller", func() { - var ( - controller *streamFlowController - queuedWindowUpdate bool - ) - - BeforeEach(func() { - queuedWindowUpdate = false - rttStats := &utils.RTTStats{} - controller = &streamFlowController{ - streamID: 10, - connection: NewConnectionFlowController( - 1000, - 1000, - func() {}, - func(protocol.ByteCount) bool { return true }, - rttStats, - utils.DefaultLogger, - ).(*connectionFlowController), - } - controller.maxReceiveWindowSize = 10000 - controller.rttStats = rttStats - controller.logger = utils.DefaultLogger - controller.queueWindowUpdate = func() { queuedWindowUpdate = true } - }) - - Context("Constructor", func() { - rttStats := &utils.RTTStats{} - const receiveWindow protocol.ByteCount = 2000 - const maxReceiveWindow protocol.ByteCount = 3000 - const sendWindow protocol.ByteCount = 4000 - - It("sets the send and receive windows", func() { - cc := NewConnectionFlowController(0, 0, nil, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger) - fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController) - Expect(fc.streamID).To(Equal(protocol.StreamID(5))) - Expect(fc.receiveWindow).To(Equal(receiveWindow)) - Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow)) - Expect(fc.sendWindow).To(Equal(sendWindow)) - }) - - It("queues window updates with the correct stream ID", func() { - var queued bool - queueWindowUpdate := func(id protocol.StreamID) { - Expect(id).To(Equal(protocol.StreamID(5))) - queued = true - } - - cc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, func() {}, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger) - fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController) - fc.AddBytesRead(receiveWindow) - Expect(queued).To(BeTrue()) - }) - }) - - Context("receiving data", func() { - Context("registering received offsets", func() { - var receiveWindow protocol.ByteCount = 10000 - var receiveWindowSize protocol.ByteCount = 600 - - BeforeEach(func() { - controller.receiveWindow = receiveWindow - controller.receiveWindowSize = receiveWindowSize - }) - - It("updates the highestReceived", func() { - controller.highestReceived = 1337 - Expect(controller.UpdateHighestReceived(1338, false)).To(Succeed()) - Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1338))) - }) - - It("informs the connection flow controller about received data", func() { - controller.highestReceived = 10 - controller.connection.(*connectionFlowController).highestReceived = 100 - Expect(controller.UpdateHighestReceived(20, false)).To(Succeed()) - Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100 + 10))) - }) - - It("does not decrease the highestReceived", func() { - controller.highestReceived = 1337 - Expect(controller.UpdateHighestReceived(1000, false)).To(Succeed()) - Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337))) - }) - - It("does nothing when setting the same byte offset", func() { - controller.highestReceived = 1337 - Expect(controller.UpdateHighestReceived(1337, false)).To(Succeed()) - }) - - It("does not give a flow control violation when using the window completely", func() { - controller.connection.(*connectionFlowController).receiveWindow = receiveWindow - Expect(controller.UpdateHighestReceived(receiveWindow, false)).To(Succeed()) - }) - - It("detects a flow control violation", func() { - Expect(controller.UpdateHighestReceived(receiveWindow+1, false)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.FlowControlError, - ErrorMessage: "received 10001 bytes on stream 10, allowed 10000 bytes", - })) - }) - - It("accepts a final offset higher than the highest received", func() { - Expect(controller.UpdateHighestReceived(100, false)).To(Succeed()) - Expect(controller.UpdateHighestReceived(101, true)).To(Succeed()) - Expect(controller.highestReceived).To(Equal(protocol.ByteCount(101))) - }) - - It("errors when receiving a final offset smaller than the highest offset received so far", func() { - controller.UpdateHighestReceived(100, false) - Expect(controller.UpdateHighestReceived(50, true)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.FinalSizeError, - ErrorMessage: "received final offset 50 for stream 10, but already received offset 100 before", - })) - }) - - It("accepts delayed data after receiving a final offset", func() { - Expect(controller.UpdateHighestReceived(300, true)).To(Succeed()) - Expect(controller.UpdateHighestReceived(250, false)).To(Succeed()) - }) - - It("errors when receiving a higher offset after receiving a final offset", func() { - Expect(controller.UpdateHighestReceived(200, true)).To(Succeed()) - Expect(controller.UpdateHighestReceived(250, false)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.FinalSizeError, - ErrorMessage: "received offset 250 for stream 10, but final offset was already received at 200", - })) - }) - - It("accepts duplicate final offsets", func() { - Expect(controller.UpdateHighestReceived(200, true)).To(Succeed()) - Expect(controller.UpdateHighestReceived(200, true)).To(Succeed()) - Expect(controller.highestReceived).To(Equal(protocol.ByteCount(200))) - }) - - It("errors when receiving inconsistent final offsets", func() { - Expect(controller.UpdateHighestReceived(200, true)).To(Succeed()) - Expect(controller.UpdateHighestReceived(201, true)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.FinalSizeError, - ErrorMessage: "received inconsistent final offset for stream 10 (old: 200, new: 201 bytes)", - })) - }) - - It("tells the connection flow controller when a stream is abandoned", func() { - controller.AddBytesRead(5) - Expect(controller.UpdateHighestReceived(100, true)).To(Succeed()) - controller.Abandon() - Expect(controller.connection.(*connectionFlowController).bytesRead).To(Equal(protocol.ByteCount(100))) - }) - }) - - It("saves when data is read", func() { - controller.AddBytesRead(200) - Expect(controller.bytesRead).To(Equal(protocol.ByteCount(200))) - Expect(controller.connection.(*connectionFlowController).bytesRead).To(Equal(protocol.ByteCount(200))) - }) - - Context("generating window updates", func() { - var oldWindowSize protocol.ByteCount - - // update the congestion such that it returns a given value for the smoothed RTT - setRtt := func(t time.Duration) { - controller.rttStats.UpdateRTT(t, 0, time.Now()) - Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked - } - - BeforeEach(func() { - controller.receiveWindow = 100 - controller.receiveWindowSize = 60 - controller.bytesRead = 100 - 60 - controller.connection.(*connectionFlowController).receiveWindow = 100 - controller.connection.(*connectionFlowController).receiveWindowSize = 120 - oldWindowSize = controller.receiveWindowSize - }) - - It("queues window updates", func() { - controller.AddBytesRead(1) - Expect(queuedWindowUpdate).To(BeFalse()) - controller.AddBytesRead(29) - Expect(queuedWindowUpdate).To(BeTrue()) - Expect(controller.GetWindowUpdate()).ToNot(BeZero()) - queuedWindowUpdate = false - controller.AddBytesRead(1) - Expect(queuedWindowUpdate).To(BeFalse()) - }) - - It("tells the connection flow controller when the window was auto-tuned", func() { - var allowed protocol.ByteCount - controller.connection.(*connectionFlowController).allowWindowIncrease = func(size protocol.ByteCount) bool { - allowed = size - return true - } - oldOffset := controller.bytesRead - setRtt(scaleDuration(20 * time.Millisecond)) - controller.epochStartOffset = oldOffset - controller.epochStartTime = time.Now().Add(-time.Millisecond) - controller.AddBytesRead(55) - offset := controller.GetWindowUpdate() - Expect(offset).To(Equal(oldOffset + 55 + 2*oldWindowSize)) - Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) - Expect(allowed).To(Equal(oldWindowSize)) - Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(float64(controller.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))) - }) - - It("doesn't increase the connection flow control window if it's not allowed", func() { - oldOffset := controller.bytesRead - oldConnectionSize := controller.connection.(*connectionFlowController).receiveWindowSize - controller.connection.(*connectionFlowController).allowWindowIncrease = func(protocol.ByteCount) bool { return false } - setRtt(scaleDuration(20 * time.Millisecond)) - controller.epochStartOffset = oldOffset - controller.epochStartTime = time.Now().Add(-time.Millisecond) - controller.AddBytesRead(55) - offset := controller.GetWindowUpdate() - Expect(offset).To(Equal(oldOffset + 55 + 2*oldWindowSize)) - Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) - Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(oldConnectionSize)) - }) - - It("sends a connection-level window update when a large stream is abandoned", func() { - Expect(controller.UpdateHighestReceived(90, true)).To(Succeed()) - Expect(controller.connection.GetWindowUpdate()).To(BeZero()) - controller.Abandon() - Expect(controller.connection.GetWindowUpdate()).ToNot(BeZero()) - }) - - It("doesn't increase the window after a final offset was already received", func() { - Expect(controller.UpdateHighestReceived(90, true)).To(Succeed()) - controller.AddBytesRead(30) - Expect(queuedWindowUpdate).To(BeFalse()) - offset := controller.GetWindowUpdate() - Expect(offset).To(BeZero()) - }) - }) - }) - - Context("sending data", func() { - It("gets the size of the send window", func() { - controller.connection.UpdateSendWindow(1000) - controller.UpdateSendWindow(15) - controller.AddBytesSent(5) - Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10))) - }) - - It("makes sure that it doesn't overflow the connection-level window", func() { - controller.connection.UpdateSendWindow(12) - controller.UpdateSendWindow(20) - controller.AddBytesSent(10) - Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(2))) - }) - - It("doesn't say that it's blocked, if only the connection is blocked", func() { - controller.connection.UpdateSendWindow(50) - controller.UpdateSendWindow(100) - controller.AddBytesSent(50) - blocked, _ := controller.connection.IsNewlyBlocked() - Expect(blocked).To(BeTrue()) - Expect(controller.IsNewlyBlocked()).To(BeFalse()) - }) - }) -}) diff --git a/internal/quic-go/frame_sorter.go b/internal/quic-go/frame_sorter.go deleted file mode 100644 index aa16e38c..00000000 --- a/internal/quic-go/frame_sorter.go +++ /dev/null @@ -1,224 +0,0 @@ -package quic - -import ( - "errors" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -type frameSorterEntry struct { - Data []byte - DoneCb func() -} - -type frameSorter struct { - queue map[protocol.ByteCount]frameSorterEntry - readPos protocol.ByteCount - gaps *utils.ByteIntervalList -} - -var errDuplicateStreamData = errors.New("duplicate stream data") - -func newFrameSorter() *frameSorter { - s := frameSorter{ - gaps: utils.NewByteIntervalList(), - queue: make(map[protocol.ByteCount]frameSorterEntry), - } - s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount}) - return &s -} - -func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error { - err := s.push(data, offset, doneCb) - if err == errDuplicateStreamData { - if doneCb != nil { - doneCb() - } - return nil - } - return err -} - -func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error { - if len(data) == 0 { - return errDuplicateStreamData - } - - start := offset - end := offset + protocol.ByteCount(len(data)) - - if end <= s.gaps.Front().Value.Start { - return errDuplicateStreamData - } - - startGap, startsInGap := s.findStartGap(start) - endGap, endsInGap := s.findEndGap(startGap, end) - - startGapEqualsEndGap := startGap == endGap - - if (startGapEqualsEndGap && end <= startGap.Value.Start) || - (!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) { - return errDuplicateStreamData - } - - startGapNext := startGap.Next() - startGapEnd := startGap.Value.End // save it, in case startGap is modified - endGapStart := endGap.Value.Start // save it, in case endGap is modified - endGapEnd := endGap.Value.End // save it, in case endGap is modified - var adjustedStartGapEnd bool - var wasCut bool - - pos := start - var hasReplacedAtLeastOne bool - for { - oldEntry, ok := s.queue[pos] - if !ok { - break - } - oldEntryLen := protocol.ByteCount(len(oldEntry.Data)) - if end-pos > oldEntryLen || (hasReplacedAtLeastOne && end-pos == oldEntryLen) { - // The existing frame is shorter than the new frame. Replace it. - delete(s.queue, pos) - pos += oldEntryLen - hasReplacedAtLeastOne = true - if oldEntry.DoneCb != nil { - oldEntry.DoneCb() - } - } else { - if !hasReplacedAtLeastOne { - return errDuplicateStreamData - } - // The existing frame is longer than the new frame. - // Cut the new frame such that the end aligns with the start of the existing frame. - data = data[:pos-start] - end = pos - wasCut = true - break - } - } - - if !startsInGap && !hasReplacedAtLeastOne { - // cut the frame, such that it starts at the start of the gap - data = data[startGap.Value.Start-start:] - start = startGap.Value.Start - wasCut = true - } - if start <= startGap.Value.Start { - if end >= startGap.Value.End { - // The frame covers the whole startGap. Delete the gap. - s.gaps.Remove(startGap) - } else { - startGap.Value.Start = end - } - } else if !hasReplacedAtLeastOne { - startGap.Value.End = start - adjustedStartGapEnd = true - } - - if !startGapEqualsEndGap { - s.deleteConsecutive(startGapEnd) - var nextGap *utils.ByteIntervalElement - for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap { - nextGap = gap.Next() - s.deleteConsecutive(gap.Value.End) - s.gaps.Remove(gap) - } - } - - if !endsInGap && start != endGapEnd && end > endGapEnd { - // cut the frame, such that it ends at the end of the gap - data = data[:endGapEnd-start] - end = endGapEnd - wasCut = true - } - if end == endGapEnd { - if !startGapEqualsEndGap { - // The frame covers the whole endGap. Delete the gap. - s.gaps.Remove(endGap) - } - } else { - if startGapEqualsEndGap && adjustedStartGapEnd { - // The frame split the existing gap into two. - s.gaps.InsertAfter(utils.ByteInterval{Start: end, End: startGapEnd}, startGap) - } else if !startGapEqualsEndGap { - endGap.Value.Start = end - } - } - - if wasCut && len(data) < protocol.MinStreamFrameBufferSize { - newData := make([]byte, len(data)) - copy(newData, data) - data = newData - if doneCb != nil { - doneCb() - doneCb = nil - } - } - - if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps { - return errors.New("too many gaps in received data") - } - - s.queue[start] = frameSorterEntry{Data: data, DoneCb: doneCb} - return nil -} - -func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*utils.ByteIntervalElement, bool) { - for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { - if offset >= gap.Value.Start && offset <= gap.Value.End { - return gap, true - } - if offset < gap.Value.Start { - return gap, false - } - } - panic("no gap found") -} - -func (s *frameSorter) findEndGap(startGap *utils.ByteIntervalElement, offset protocol.ByteCount) (*utils.ByteIntervalElement, bool) { - for gap := startGap; gap != nil; gap = gap.Next() { - if offset >= gap.Value.Start && offset < gap.Value.End { - return gap, true - } - if offset < gap.Value.Start { - return gap.Prev(), false - } - } - panic("no gap found") -} - -// deleteConsecutive deletes consecutive frames from the queue, starting at pos -func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) { - for { - oldEntry, ok := s.queue[pos] - if !ok { - break - } - oldEntryLen := protocol.ByteCount(len(oldEntry.Data)) - delete(s.queue, pos) - if oldEntry.DoneCb != nil { - oldEntry.DoneCb() - } - pos += oldEntryLen - } -} - -func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) { - entry, ok := s.queue[s.readPos] - if !ok { - return s.readPos, nil, nil - } - delete(s.queue, s.readPos) - offset := s.readPos - s.readPos += protocol.ByteCount(len(entry.Data)) - if s.gaps.Front().Value.End <= s.readPos { - panic("frame sorter BUG: read position higher than a gap") - } - return offset, entry.Data, entry.DoneCb -} - -// HasMoreData says if there is any more data queued at *any* offset. -func (s *frameSorter) HasMoreData() bool { - return len(s.queue) > 0 -} diff --git a/internal/quic-go/frame_sorter_test.go b/internal/quic-go/frame_sorter_test.go deleted file mode 100644 index 52614111..00000000 --- a/internal/quic-go/frame_sorter_test.go +++ /dev/null @@ -1,1527 +0,0 @@ -package quic - -import ( - "bytes" - "fmt" - "math" - "math/rand" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("frame sorter", func() { - var s *frameSorter - - checkGaps := func(expectedGaps []utils.ByteInterval) { - if s.gaps.Len() != len(expectedGaps) { - fmt.Println("Gaps:") - for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { - fmt.Printf("\t%d - %d\n", gap.Value.Start, gap.Value.End) - } - ExpectWithOffset(1, s.gaps.Len()).To(Equal(len(expectedGaps))) - } - var i int - for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { - ExpectWithOffset(1, gap.Value).To(Equal(expectedGaps[i])) - i++ - } - } - - type callbackTracker struct { - called *bool - cb func() - } - - getCallback := func() (func(), callbackTracker) { - var called bool - cb := func() { - if called { - panic("double free") - } - called = true - } - return cb, callbackTracker{ - cb: cb, - called: &called, - } - } - - checkCallbackCalled := func(t callbackTracker) { - ExpectWithOffset(1, *t.called).To(BeTrue()) - } - - checkCallbackNotCalled := func(t callbackTracker) { - ExpectWithOffset(1, *t.called).To(BeFalse()) - t.cb() - ExpectWithOffset(1, *t.called).To(BeTrue()) - } - - BeforeEach(func() { - s = newFrameSorter() - }) - - It("returns nil when empty", func() { - _, data, doneCb := s.Pop() - Expect(data).To(BeNil()) - Expect(doneCb).To(BeNil()) - }) - - It("inserts and pops a single frame", func() { - cb, t := getCallback() - Expect(s.Push([]byte("foobar"), 0, cb)).To(Succeed()) - offset, data, doneCb := s.Pop() - Expect(offset).To(BeZero()) - Expect(data).To(Equal([]byte("foobar"))) - Expect(doneCb).ToNot(BeNil()) - checkCallbackNotCalled(t) - offset, data, doneCb = s.Pop() - Expect(offset).To(Equal(protocol.ByteCount(6))) - Expect(data).To(BeNil()) - Expect(doneCb).To(BeNil()) - }) - - It("inserts and pops two consecutive frame", func() { - cb1, t1 := getCallback() - cb2, t2 := getCallback() - Expect(s.Push([]byte("bar"), 3, cb2)).To(Succeed()) - Expect(s.Push([]byte("foo"), 0, cb1)).To(Succeed()) - offset, data, doneCb := s.Pop() - Expect(offset).To(BeZero()) - Expect(data).To(Equal([]byte("foo"))) - Expect(doneCb).ToNot(BeNil()) - doneCb() - checkCallbackCalled(t1) - offset, data, doneCb = s.Pop() - Expect(offset).To(Equal(protocol.ByteCount(3))) - Expect(data).To(Equal([]byte("bar"))) - Expect(doneCb).ToNot(BeNil()) - doneCb() - checkCallbackCalled(t2) - offset, data, doneCb = s.Pop() - Expect(offset).To(Equal(protocol.ByteCount(6))) - Expect(data).To(BeNil()) - Expect(doneCb).To(BeNil()) - }) - - It("ignores empty frames", func() { - Expect(s.Push(nil, 0, nil)).To(Succeed()) - _, data, doneCb := s.Pop() - Expect(data).To(BeNil()) - Expect(doneCb).To(BeNil()) - }) - - It("says if has more data", func() { - Expect(s.HasMoreData()).To(BeFalse()) - Expect(s.Push([]byte("foo"), 0, nil)).To(Succeed()) - Expect(s.HasMoreData()).To(BeTrue()) - _, data, _ := s.Pop() - Expect(data).To(Equal([]byte("foo"))) - Expect(s.HasMoreData()).To(BeFalse()) - }) - - Context("Gap handling", func() { - var dataCounter uint8 - - BeforeEach(func() { - dataCounter = 0 - }) - - checkQueue := func(m map[protocol.ByteCount][]byte) { - ExpectWithOffset(1, s.queue).To(HaveLen(len(m))) - for offset, data := range m { - ExpectWithOffset(1, s.queue).To(HaveKey(offset)) - ExpectWithOffset(1, s.queue[offset].Data).To(Equal(data)) - } - } - - getData := func(l protocol.ByteCount) []byte { - dataCounter++ - return bytes.Repeat([]byte{dataCounter}, int(l)) - } - - // ---xxx-------------- - // ++++++ - // => - // ---xxx++++++-------- - It("case 1", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(5) - cb2, t2 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 11 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 6: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 11, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - }) - - // ---xxx----------------- - // +++++++ - // => - // ---xxx---+++++++-------- - It("case 2", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(5) - cb2, t2 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 10, cb2)).To(Succeed()) // 10 -15 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 10: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 6, End: 10}, - {Start: 15, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - }) - - // ---xxx----xxxxxx------- - // ++++ - // => - // ---xxx++++xxxxx-------- - It("case 3", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(4) - cb2, t2 := getCallback() - f3 := getData(5) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f3, 10, cb2)).To(Succeed()) // 10 - 15 - Expect(s.Push(f2, 6, cb3)).To(Succeed()) // 6 - 10 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 6: f2, - 10: f3, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 15, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ----xxxx------- - // ++++ - // => - // ----xxxx++----- - It("case 4", func() { - f1 := getData(4) - cb1, t1 := getCallback() - f2 := getData(4) - cb2, t2 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 7 - Expect(s.Push(f2, 5, cb2)).To(Succeed()) // 5 - 9 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 7: f2[2:], - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 9, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - }) - - It("case 4, for long frames", func() { - mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) - f1 := getData(4 * mult) - cb1, t1 := getCallback() - f2 := getData(4 * mult) - cb2, t2 := getCallback() - Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 7 - Expect(s.Push(f2, 5*mult, cb2)).To(Succeed()) // 5 - 9 - checkQueue(map[protocol.ByteCount][]byte{ - 3 * mult: f1, - 7 * mult: f2[2*mult:], - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3 * mult}, - {Start: 9 * mult, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - }) - - // xxxx------- - // ++++ - // => - // xxxx+++----- - It("case 5", func() { - f1 := getData(4) - cb1, t1 := getCallback() - f2 := getData(4) - cb2, t2 := getCallback() - Expect(s.Push(f1, 0, cb1)).To(Succeed()) // 0 - 4 - Expect(s.Push(f2, 3, cb2)).To(Succeed()) // 3 - 7 - checkQueue(map[protocol.ByteCount][]byte{ - 0: f1, - 4: f2[1:], - }) - checkGaps([]utils.ByteInterval{ - {Start: 7, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - }) - - It("case 5, for long frames", func() { - mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) - f1 := getData(4 * mult) - cb1, t1 := getCallback() - f2 := getData(4 * mult) - cb2, t2 := getCallback() - Expect(s.Push(f1, 0, cb1)).To(Succeed()) // 0 - 4 - Expect(s.Push(f2, 3*mult, cb2)).To(Succeed()) // 3 - 7 - checkQueue(map[protocol.ByteCount][]byte{ - 0: f1, - 4 * mult: f2[mult:], - }) - checkGaps([]utils.ByteInterval{ - {Start: 7 * mult, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - }) - - // ----xxxx------- - // ++++ - // => - // --++xxxx------- - It("case 6", func() { - f1 := getData(4) - cb1, t1 := getCallback() - f2 := getData(4) - cb2, t2 := getCallback() - Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 9 - Expect(s.Push(f2, 3, cb2)).To(Succeed()) // 3 - 7 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f2[:2], - 5: f1, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 9, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - }) - - It("case 6, for long frames", func() { - mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) - f1 := getData(4 * mult) - cb1, t1 := getCallback() - f2 := getData(4 * mult) - cb2, t2 := getCallback() - Expect(s.Push(f1, 5*mult, cb1)).To(Succeed()) // 5 - 9 - Expect(s.Push(f2, 3*mult, cb2)).To(Succeed()) // 3 - 7 - checkQueue(map[protocol.ByteCount][]byte{ - 3 * mult: f2[:2*mult], - 5 * mult: f1, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3 * mult}, - {Start: 9 * mult, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - }) - - // ---xxx----xxxxxx------- - // ++ - // => - // ---xxx++--xxxxx-------- - It("case 7", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(2) - cb2, t2 := getCallback() - f3 := getData(5) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f3, 10, cb2)).To(Succeed()) // 10 - 15 - Expect(s.Push(f2, 6, cb3)).To(Succeed()) // 6 - 8 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 6: f2, - 10: f3, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 8, End: 10}, - {Start: 15, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ---xxx---------xxxxxx-- - // ++ - // => - // ---xxx---++----xxxxx-- - It("case 8", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(2) - cb2, t2 := getCallback() - f3 := getData(5) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f3, 15, cb2)).To(Succeed()) // 15 - 20 - Expect(s.Push(f2, 10, cb3)).To(Succeed()) // 10 - 12 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 10: f2, - 15: f3, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 6, End: 10}, - {Start: 12, End: 15}, - {Start: 20, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ---xxx----xxxxxx------- - // ++ - // => - // ---xxx--++xxxxx-------- - It("case 9", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(2) - cb2, t2 := getCallback() - cb3, t3 := getCallback() - f3 := getData(5) - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f3, 10, cb2)).To(Succeed()) // 10 - 15 - Expect(s.Push(f2, 8, cb3)).To(Succeed()) // 8 - 10 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 8: f2, - 10: f3, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 6, End: 8}, - {Start: 15, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ---xxx----=====------- - // +++++++ - // => - // ---xxx++++=====-------- - It("case 10", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(5) - cb2, t2 := getCallback() - f3 := getData(6) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 10, cb2)).To(Succeed()) // 10 - 15 - Expect(s.Push(f3, 5, cb3)).To(Succeed()) // 5 - 11 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 6: f3[1:5], - 10: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 15, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackCalled(t3) - }) - - It("case 10, for long frames", func() { - mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 4)) - f1 := getData(3 * mult) - cb1, t1 := getCallback() - f2 := getData(5 * mult) - cb2, t2 := getCallback() - f3 := getData(6 * mult) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 10*mult, cb2)).To(Succeed()) // 10 - 15 - Expect(s.Push(f3, 5*mult, cb3)).To(Succeed()) // 5 - 11 - checkQueue(map[protocol.ByteCount][]byte{ - 3 * mult: f1, - 6 * mult: f3[mult : 5*mult], - 10 * mult: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3 * mult}, - {Start: 15 * mult, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ---xxxx----=====------- - // ++++++ - // => - // ---xxx++++=====-------- - It("case 11", func() { - f1 := getData(4) - cb1, t1 := getCallback() - f2 := getData(5) - cb2, t2 := getCallback() - f3 := getData(5) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 7 - Expect(s.Push(f2, 10, cb2)).To(Succeed()) // 10 - 15 - Expect(s.Push(f3, 5, cb3)).To(Succeed()) // 5 - 10 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 7: f3[2:], - 10: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 15, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackCalled(t3) - }) - - // ---xxxx----=====------- - // ++++++ - // => - // ---xxx++++=====-------- - It("case 11, for long frames", func() { - mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 3)) - f1 := getData(4 * mult) - cb1, t1 := getCallback() - f2 := getData(5 * mult) - cb2, t2 := getCallback() - f3 := getData(5 * mult) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 7 - Expect(s.Push(f2, 10*mult, cb2)).To(Succeed()) // 10 - 15 - Expect(s.Push(f3, 5*mult, cb3)).To(Succeed()) // 5 - 10 - checkQueue(map[protocol.ByteCount][]byte{ - 3 * mult: f1, - 7 * mult: f3[2*mult:], - 10 * mult: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3 * mult}, - {Start: 15 * mult, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ----xxxx------- - // +++++++ - // => - // ----+++++++----- - It("case 12", func() { - f1 := getData(4) - cb1, t1 := getCallback() - f2 := getData(7) - cb2, t2 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 7 - Expect(s.Push(f2, 3, cb2)).To(Succeed()) // 3 - 10 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 10, End: protocol.MaxByteCount}, - }) - checkCallbackCalled(t1) - checkCallbackNotCalled(t2) - }) - - // ----xxx===------- - // +++++++ - // => - // ----+++++++----- - It("case 13", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - f3 := getData(7) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 9 - Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 10 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f3, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 10, End: protocol.MaxByteCount}, - }) - checkCallbackCalled(t1) - checkCallbackCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ----xxx====------- - // +++++ - // => - // ----+++====----- - It("case 14", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(4) - cb2, t2 := getCallback() - f3 := getData(5) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 10 - Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 8 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f3[:3], - 6: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 10, End: protocol.MaxByteCount}, - }) - checkCallbackCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackCalled(t3) - }) - - It("case 14, for long frames", func() { - mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 3)) - f1 := getData(3 * mult) - cb1, t1 := getCallback() - f2 := getData(4 * mult) - cb2, t2 := getCallback() - f3 := getData(5 * mult) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6*mult, cb2)).To(Succeed()) // 6 - 10 - Expect(s.Push(f3, 3*mult, cb3)).To(Succeed()) // 3 - 8 - checkQueue(map[protocol.ByteCount][]byte{ - 3 * mult: f3[:3*mult], - 6 * mult: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3 * mult}, - {Start: 10 * mult, End: protocol.MaxByteCount}, - }) - checkCallbackCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ----xxx===------- - // ++++++ - // => - // ----++++++----- - It("case 15", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - f3 := getData(6) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 9 - Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 9 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f3, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 9, End: protocol.MaxByteCount}, - }) - checkCallbackCalled(t1) - checkCallbackCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ---xxxx------- - // ++++ - // => - // ---xxxx----- - It("case 16", func() { - f1 := getData(4) - cb1, t1 := getCallback() - f2 := getData(4) - cb2, t2 := getCallback() - Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 9 - Expect(s.Push(f2, 5, cb2)).To(Succeed()) // 5 - 9 - checkQueue(map[protocol.ByteCount][]byte{ - 5: f1, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 9, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - }) - - // ----xxx===------- - // +++ - // => - // ----xxx===----- - It("case 17", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - f3 := getData(3) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 9 - Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 6 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 6: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 9, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackCalled(t3) - }) - - // ---xxxx------- - // ++ - // => - // ---xxxx----- - It("case 18", func() { - f1 := getData(4) - cb1, t1 := getCallback() - f2 := getData(2) - cb2, t2 := getCallback() - Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 9 - Expect(s.Push(f2, 5, cb2)).To(Succeed()) // 5 - 7 - checkQueue(map[protocol.ByteCount][]byte{ - 5: f1, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 9, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - }) - - // ---xxxxx------ - // ++ - // => - // ---xxxxx---- - It("case 19", func() { - f1 := getData(5) - cb1, t1 := getCallback() - f2 := getData(2) - cb2, t2 := getCallback() - Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 10 - checkQueue(map[protocol.ByteCount][]byte{ - 5: f1, - }) - Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 8 - checkQueue(map[protocol.ByteCount][]byte{ - 5: f1, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 10, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - }) - - // xxxxx------ - // ++ - // => - // xxxxx------ - It("case 20", func() { - f1 := getData(10) - cb1, t1 := getCallback() - f2 := getData(4) - cb2, t2 := getCallback() - Expect(s.Push(f1, 0, cb1)).To(Succeed()) // 0 - 10 - Expect(s.Push(f2, 5, cb2)).To(Succeed()) // 5 - 9 - checkQueue(map[protocol.ByteCount][]byte{ - 0: f1, - }) - checkGaps([]utils.ByteInterval{ - {Start: 10, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - }) - - // ---xxxxx--- - // +++ - // => - // ---xxxxx--- - It("case 21", func() { - f1 := getData(5) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 10 - Expect(s.Push(f2, 7, cb2)).To(Succeed()) // 7 - 10 - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 10, End: protocol.MaxByteCount}, - }) - checkQueue(map[protocol.ByteCount][]byte{ - 5: f1, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - }) - - // ----xxx------ - // +++++ - // => - // --+++++---- - It("case 22", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(5) - cb2, t2 := getCallback() - Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 8 - Expect(s.Push(f2, 3, cb2)).To(Succeed()) // 3 - 8 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 8, End: protocol.MaxByteCount}, - }) - checkCallbackCalled(t1) - checkCallbackNotCalled(t2) - }) - - // ----xxx===------ - // ++++++++ - // => - // --++++++++---- - It("case 23", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - f3 := getData(8) - cb3, t3 := getCallback() - Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 8 - Expect(s.Push(f2, 8, cb2)).To(Succeed()) // 8 - 11 - Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 11 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f3, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 11, End: protocol.MaxByteCount}, - }) - checkCallbackCalled(t1) - checkCallbackCalled(t2) - checkCallbackNotCalled(t3) - }) - - // --xxx---===--- - // ++++++ - // => - // --xxx++++++---- - It("case 24", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - f3 := getData(6) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 12 - Expect(s.Push(f3, 6, cb3)).To(Succeed()) // 6 - 12 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 6: f3, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 12, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - checkCallbackNotCalled(t3) - }) - - // --xxx---===---### - // +++++++++ - // => - // --xxx+++++++++### - It("case 25", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - f3 := getData(3) - cb3, t3 := getCallback() - f4 := getData(9) - cb4, t4 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 12 - Expect(s.Push(f3, 15, cb3)).To(Succeed()) // 15 - 18 - Expect(s.Push(f4, 6, cb4)).To(Succeed()) // 6 - 15 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 6: f4, - 15: f3, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 18, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - checkCallbackNotCalled(t3) - checkCallbackNotCalled(t4) - }) - - // ----xxx------ - // +++++++ - // => - // --+++++++--- - It("case 26", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(10) - cb2, t2 := getCallback() - Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 8 - Expect(s.Push(f2, 3, cb2)).To(Succeed()) // 3 - 13 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 13, End: protocol.MaxByteCount}, - }) - checkCallbackCalled(t1) - checkCallbackNotCalled(t2) - }) - - // ---xxx====--- - // ++++ - // => - // --+xxx====--- - It("case 27", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(4) - cb2, t2 := getCallback() - f3 := getData(4) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 10 - Expect(s.Push(f3, 2, cb3)).To(Succeed()) // 2 - 6 - checkQueue(map[protocol.ByteCount][]byte{ - 2: f3[:1], - 3: f1, - 6: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 2}, - {Start: 10, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackCalled(t3) - }) - - It("case 27, for long frames", func() { - const mult = protocol.MinStreamFrameSize - f1 := getData(3 * mult) - cb1, t1 := getCallback() - f2 := getData(4 * mult) - cb2, t2 := getCallback() - f3 := getData(4 * mult) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6*mult, cb2)).To(Succeed()) // 6 - 10 - Expect(s.Push(f3, 2*mult, cb3)).To(Succeed()) // 2 - 6 - checkQueue(map[protocol.ByteCount][]byte{ - 2 * mult: f3[:mult], - 3 * mult: f1, - 6 * mult: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 2 * mult}, - {Start: 10 * mult, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ---xxx====--- - // ++++++ - // => - // --+xxx====--- - It("case 28", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(4) - cb2, t2 := getCallback() - f3 := getData(6) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 10 - Expect(s.Push(f3, 2, cb3)).To(Succeed()) // 2 - 8 - checkQueue(map[protocol.ByteCount][]byte{ - 2: f3[:1], - 3: f1, - 6: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 2}, - {Start: 10, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackCalled(t3) - }) - - It("case 28, for long frames", func() { - const mult = protocol.MinStreamFrameSize - f1 := getData(3 * mult) - cb1, t1 := getCallback() - f2 := getData(4 * mult) - cb2, t2 := getCallback() - f3 := getData(6 * mult) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6*mult, cb2)).To(Succeed()) // 6 - 10 - Expect(s.Push(f3, 2*mult, cb3)).To(Succeed()) // 2 - 8 - checkQueue(map[protocol.ByteCount][]byte{ - 2 * mult: f3[:mult], - 3 * mult: f1, - 6 * mult: f2, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 2 * mult}, - {Start: 10 * mult, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ---xxx===----- - // +++++ - // => - // ---xxx+++++--- - It("case 29", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - f3 := getData(5) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 9 - Expect(s.Push(f3, 6, cb3)).To(Succeed()) // 6 - 11 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 6: f3, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 11, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ---xxx===---- - // ++++++ - // => - // ---xxx===++-- - It("case 30", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - f3 := getData(6) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6, cb2)).To(Succeed()) // 6 - 9 - Expect(s.Push(f3, 5, cb3)).To(Succeed()) // 5 - 11 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 6: f2, - 9: f3[4:], - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 11, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackCalled(t3) - }) - - It("case 30, for long frames", func() { - mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) - f1 := getData(3 * mult) - cb1, t1 := getCallback() - f2 := getData(3 * mult) - cb2, t2 := getCallback() - f3 := getData(6 * mult) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 6*mult, cb2)).To(Succeed()) // 6 - 9 - Expect(s.Push(f3, 5*mult, cb3)).To(Succeed()) // 5 - 11 - checkQueue(map[protocol.ByteCount][]byte{ - 3 * mult: f1, - 6 * mult: f2, - 9 * mult: f3[4*mult:], - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3 * mult}, - {Start: 11 * mult, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ---xxx---===----- - // ++++++++++ - // => - // ---xxx++++++++--- - It("case 31", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - f3 := getData(10) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 12 - Expect(s.Push(f3, 5, cb3)).To(Succeed()) // 5 - 15 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 6: f3[1:], - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 15, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - checkCallbackCalled(t3) - }) - - It("case 31, for long frames", func() { - mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 9)) - f1 := getData(3 * mult) - cb1, t1 := getCallback() - f2 := getData(3 * mult) - cb2, t2 := getCallback() - f3 := getData(10 * mult) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 9*mult, cb2)).To(Succeed()) // 9 - 12 - Expect(s.Push(f3, 5*mult, cb3)).To(Succeed()) // 5 - 15 - checkQueue(map[protocol.ByteCount][]byte{ - 3 * mult: f1, - 6 * mult: f3[mult:], - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3 * mult}, - {Start: 15 * mult, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ---xxx---===----- - // +++++++++ - // => - // ---+++++++++--- - It("case 32", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - f3 := getData(9) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 12 - Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 12 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f3, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 12, End: protocol.MaxByteCount}, - }) - checkCallbackCalled(t1) - checkCallbackCalled(t2) - checkCallbackNotCalled(t3) - }) - - // ---xxx---===###----- - // ++++++++++++ - // => - // ---xxx++++++++++--- - It("case 33", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(3) - cb2, t2 := getCallback() - f3 := getData(3) - cb3, t3 := getCallback() - f4 := getData(12) - cb4, t4 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 12 - Expect(s.Push(f3, 9, cb3)).To(Succeed()) // 12 - 15 - Expect(s.Push(f4, 5, cb4)).To(Succeed()) // 5 - 17 - checkQueue(map[protocol.ByteCount][]byte{ - 3: f1, - 6: f4[1:], - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 17, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - checkCallbackCalled(t3) - checkCallbackCalled(t4) - }) - - It("case 33, for long frames", func() { - mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 11)) - f1 := getData(3 * mult) - cb1, t1 := getCallback() - f2 := getData(3 * mult) - cb2, t2 := getCallback() - f3 := getData(3 * mult) - cb3, t3 := getCallback() - f4 := getData(12 * mult) - cb4, t4 := getCallback() - Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 9*mult, cb2)).To(Succeed()) // 9 - 12 - Expect(s.Push(f3, 9*mult, cb3)).To(Succeed()) // 12 - 15 - Expect(s.Push(f4, 5*mult, cb4)).To(Succeed()) // 5 - 17 - checkQueue(map[protocol.ByteCount][]byte{ - 3 * mult: f1, - 6 * mult: f4[mult:], - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3 * mult}, - {Start: 17 * mult, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - checkCallbackCalled(t3) - checkCallbackNotCalled(t4) - }) - - // ---xxx===---### - // ++++++ - // => - // ---xxx++++++### - It("case 34", func() { - f1 := getData(5) - cb1, t1 := getCallback() - f2 := getData(5) - cb2, t2 := getCallback() - f3 := getData(10) - cb3, t3 := getCallback() - f4 := getData(5) - cb4, t4 := getCallback() - Expect(s.Push(f1, 5, cb1)).To(Succeed()) // 5 - 10 - Expect(s.Push(f2, 10, cb2)).To(Succeed()) // 10 - 15 - Expect(s.Push(f4, 20, cb3)).To(Succeed()) // 20 - 25 - Expect(s.Push(f3, 10, cb4)).To(Succeed()) // 10 - 20 - checkQueue(map[protocol.ByteCount][]byte{ - 5: f1, - 10: f3, - 20: f4, - }) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 25, End: protocol.MaxByteCount}, - }) - checkCallbackNotCalled(t1) - checkCallbackCalled(t2) - checkCallbackNotCalled(t3) - checkCallbackNotCalled(t4) - }) - - // ---xxx---####--- - // ++++++++ - // => - // ---++++++####--- - It("case 35", func() { - f1 := getData(3) - cb1, t1 := getCallback() - f2 := getData(4) - cb2, t2 := getCallback() - f3 := getData(8) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 9, cb2)).To(Succeed()) // 9 - 13 - Expect(s.Push(f3, 3, cb3)).To(Succeed()) // 3 - 11 - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3}, - {Start: 13, End: protocol.MaxByteCount}, - }) - checkQueue(map[protocol.ByteCount][]byte{ - 3: f3[:6], - 9: f2, - }) - checkCallbackCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackCalled(t3) - }) - - It("case 35, for long frames", func() { - mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 6)) - f1 := getData(3 * mult) - cb1, t1 := getCallback() - f2 := getData(4 * mult) - cb2, t2 := getCallback() - f3 := getData(8 * mult) - cb3, t3 := getCallback() - Expect(s.Push(f1, 3*mult, cb1)).To(Succeed()) // 3 - 6 - Expect(s.Push(f2, 9*mult, cb2)).To(Succeed()) // 9 - 13 - Expect(s.Push(f3, 3*mult, cb3)).To(Succeed()) // 3 - 11 - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 3 * mult}, - {Start: 13 * mult, End: protocol.MaxByteCount}, - }) - checkQueue(map[protocol.ByteCount][]byte{ - 3 * mult: f3[:6*mult], - 9 * mult: f2, - }) - checkCallbackCalled(t1) - checkCallbackNotCalled(t2) - checkCallbackNotCalled(t3) - }) - - Context("receiving data after reads", func() { - It("ignores duplicate frames", func() { - Expect(s.Push([]byte("foobar"), 0, nil)).To(Succeed()) - offset, data, _ := s.Pop() - Expect(offset).To(BeZero()) - Expect(data).To(Equal([]byte("foobar"))) - // now receive the duplicate - Expect(s.Push([]byte("foobar"), 0, nil)).To(Succeed()) - Expect(s.queue).To(BeEmpty()) - checkGaps([]utils.ByteInterval{ - {Start: 6, End: protocol.MaxByteCount}, - }) - }) - - It("ignores parts of frames that have already been read", func() { - Expect(s.Push([]byte("foo"), 0, nil)).To(Succeed()) - offset, data, _ := s.Pop() - Expect(offset).To(BeZero()) - Expect(data).To(Equal([]byte("foo"))) - // now receive the duplicate - Expect(s.Push([]byte("foobar"), 0, nil)).To(Succeed()) - offset, data, _ = s.Pop() - Expect(offset).To(Equal(protocol.ByteCount(3))) - Expect(data).To(Equal([]byte("bar"))) - Expect(s.queue).To(BeEmpty()) - checkGaps([]utils.ByteInterval{ - {Start: 6, End: protocol.MaxByteCount}, - }) - }) - }) - - Context("DoS protection", func() { - It("errors when too many gaps are created", func() { - for i := 0; i < protocol.MaxStreamFrameSorterGaps; i++ { - Expect(s.Push([]byte("foobar"), protocol.ByteCount(i*7), nil)).To(Succeed()) - } - Expect(s.gaps.Len()).To(Equal(protocol.MaxStreamFrameSorterGaps)) - err := s.Push([]byte("foobar"), protocol.ByteCount(protocol.MaxStreamFrameSorterGaps*7)+100, nil) - Expect(err).To(MatchError("too many gaps in received data")) - }) - }) - }) - - Context("stress testing", func() { - type frame struct { - offset protocol.ByteCount - data []byte - } - - for _, lf := range []bool{true, false} { - longFrames := lf - - const num = 1000 - - name := "short" - if longFrames { - name = "long" - } - - Context(fmt.Sprintf("using %s frames", name), func() { - var data []byte - var dataLen protocol.ByteCount - var callbacks []callbackTracker - - BeforeEach(func() { - seed := time.Now().UnixNano() - fmt.Fprintf(GinkgoWriter, "Seed: %d\n", seed) - rand.Seed(seed) - - callbacks = nil - dataLen = 25 - if longFrames { - dataLen = 2 * protocol.MinStreamFrameSize - } - - data = make([]byte, num*dataLen) - for i := 0; i < num; i++ { - for j := protocol.ByteCount(0); j < dataLen; j++ { - data[protocol.ByteCount(i)*dataLen+j] = uint8(i) - } - } - }) - - getRandomFrames := func() []frame { - frames := make([]frame, num) - for i := protocol.ByteCount(0); i < num; i++ { - b := make([]byte, dataLen) - Expect(copy(b, data[i*dataLen:])).To(BeEquivalentTo(dataLen)) - frames[i] = frame{ - offset: i * dataLen, - data: b, - } - } - rand.Shuffle(len(frames), func(i, j int) { frames[i], frames[j] = frames[j], frames[i] }) - return frames - } - - getData := func() []byte { - var data []byte - for { - offset, b, cb := s.Pop() - if b == nil { - break - } - Expect(offset).To(BeEquivalentTo(len(data))) - data = append(data, b...) - if cb != nil { - cb() - } - } - return data - } - - // push pushes data to the frame sorter - // It creates a new callback and adds the - push := func(data []byte, offset protocol.ByteCount) { - cb, t := getCallback() - ExpectWithOffset(1, s.Push(data, offset, cb)).To(Succeed()) - callbacks = append(callbacks, t) - } - - checkCallbacks := func() { - ExpectWithOffset(1, callbacks).ToNot(BeEmpty()) - for _, t := range callbacks { - checkCallbackCalled(t) - } - } - - It("inserting frames in a random order", func() { - frames := getRandomFrames() - - for _, f := range frames { - push(f.data, f.offset) - } - checkGaps([]utils.ByteInterval{{Start: num * dataLen, End: protocol.MaxByteCount}}) - - Expect(getData()).To(Equal(data)) - Expect(s.queue).To(BeEmpty()) - checkCallbacks() - }) - - It("inserting frames in a random order, with some duplicates", func() { - frames := getRandomFrames() - - for _, f := range frames { - push(f.data, f.offset) - if rand.Intn(10) < 5 { - df := frames[rand.Intn(len(frames))] - push(df.data, df.offset) - } - } - checkGaps([]utils.ByteInterval{{Start: num * dataLen, End: protocol.MaxByteCount}}) - - Expect(getData()).To(Equal(data)) - Expect(s.queue).To(BeEmpty()) - checkCallbacks() - }) - - It("inserting frames in a random order, with randomly cut retransmissions", func() { - frames := getRandomFrames() - - for _, f := range frames { - push(f.data, f.offset) - if rand.Intn(10) < 5 { - length := protocol.ByteCount(1 + rand.Intn(int(4*dataLen))) - if length >= num*dataLen { - length = num*dataLen - 1 - } - b := make([]byte, length) - offset := protocol.ByteCount(rand.Intn(int(num*dataLen - length))) - Expect(copy(b, data[offset:offset+length])).To(BeEquivalentTo(length)) - push(b, offset) - } - } - checkGaps([]utils.ByteInterval{{Start: num * dataLen, End: protocol.MaxByteCount}}) - - Expect(getData()).To(Equal(data)) - Expect(s.queue).To(BeEmpty()) - checkCallbacks() - }) - }) - } - }) -}) diff --git a/internal/quic-go/framer.go b/internal/quic-go/framer.go deleted file mode 100644 index db989480..00000000 --- a/internal/quic-go/framer.go +++ /dev/null @@ -1,171 +0,0 @@ -package quic - -import ( - "errors" - "sync" - - "github.com/imroc/req/v3/internal/quic-go/ackhandler" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type framer interface { - HasData() bool - - QueueControlFrame(wire.Frame) - AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) - - AddActiveStream(protocol.StreamID) - AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) - - Handle0RTTRejection() error -} - -type framerI struct { - mutex sync.Mutex - - streamGetter streamGetter - version protocol.VersionNumber - - activeStreams map[protocol.StreamID]struct{} - streamQueue []protocol.StreamID - - controlFrameMutex sync.Mutex - controlFrames []wire.Frame -} - -var _ framer = &framerI{} - -func newFramer( - streamGetter streamGetter, - v protocol.VersionNumber, -) framer { - return &framerI{ - streamGetter: streamGetter, - activeStreams: make(map[protocol.StreamID]struct{}), - version: v, - } -} - -func (f *framerI) HasData() bool { - f.mutex.Lock() - hasData := len(f.streamQueue) > 0 - f.mutex.Unlock() - if hasData { - return true - } - f.controlFrameMutex.Lock() - hasData = len(f.controlFrames) > 0 - f.controlFrameMutex.Unlock() - return hasData -} - -func (f *framerI) QueueControlFrame(frame wire.Frame) { - f.controlFrameMutex.Lock() - f.controlFrames = append(f.controlFrames, frame) - f.controlFrameMutex.Unlock() -} - -func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - var length protocol.ByteCount - f.controlFrameMutex.Lock() - for len(f.controlFrames) > 0 { - frame := f.controlFrames[len(f.controlFrames)-1] - frameLen := frame.Length(f.version) - if length+frameLen > maxLen { - break - } - frames = append(frames, ackhandler.Frame{Frame: frame}) - length += frameLen - f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] - } - f.controlFrameMutex.Unlock() - return frames, length -} - -func (f *framerI) AddActiveStream(id protocol.StreamID) { - f.mutex.Lock() - if _, ok := f.activeStreams[id]; !ok { - f.streamQueue = append(f.streamQueue, id) - f.activeStreams[id] = struct{}{} - } - f.mutex.Unlock() -} - -func (f *framerI) AppendStreamFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - var length protocol.ByteCount - var lastFrame *ackhandler.Frame - f.mutex.Lock() - // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet - numActiveStreams := len(f.streamQueue) - for i := 0; i < numActiveStreams; i++ { - if protocol.MinStreamFrameSize+length > maxLen { - break - } - id := f.streamQueue[0] - f.streamQueue = f.streamQueue[1:] - // This should never return an error. Better check it anyway. - // The stream will only be in the streamQueue, if it enqueued itself there. - str, err := f.streamGetter.GetOrOpenSendStream(id) - // The stream can be nil if it completed after it said it had data. - if str == nil || err != nil { - delete(f.activeStreams, id) - continue - } - remainingLen := maxLen - length - // For the last STREAM frame, we'll remove the DataLen field later. - // Therefore, we can pretend to have more bytes available when popping - // the STREAM frame (which will always have the DataLen set). - remainingLen += quicvarint.Len(uint64(remainingLen)) - frame, hasMoreData := str.popStreamFrame(remainingLen) - if hasMoreData { // put the stream back in the queue (at the end) - f.streamQueue = append(f.streamQueue, id) - } else { // no more data to send. Stream is not active any more - delete(f.activeStreams, id) - } - // The frame can be nil - // * if the receiveStream was canceled after it said it had data - // * the remaining size doesn't allow us to add another STREAM frame - if frame == nil { - continue - } - frames = append(frames, *frame) - length += frame.Length(f.version) - lastFrame = frame - } - f.mutex.Unlock() - if lastFrame != nil { - lastFrameLen := lastFrame.Length(f.version) - // account for the smaller size of the last STREAM frame - lastFrame.Frame.(*wire.StreamFrame).DataLenPresent = false - length += lastFrame.Length(f.version) - lastFrameLen - } - return frames, length -} - -func (f *framerI) Handle0RTTRejection() error { - f.mutex.Lock() - defer f.mutex.Unlock() - - f.controlFrameMutex.Lock() - f.streamQueue = f.streamQueue[:0] - for id := range f.activeStreams { - delete(f.activeStreams, id) - } - var j int - for i, frame := range f.controlFrames { - switch frame.(type) { - case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame: - return errors.New("didn't expect MAX_DATA / MAX_STREAM_DATA / MAX_STREAMS frame to be sent in 0-RTT") - case *wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame: - continue - default: - f.controlFrames[j] = f.controlFrames[i] - j++ - } - } - f.controlFrames = f.controlFrames[:j] - f.controlFrameMutex.Unlock() - return nil -} diff --git a/internal/quic-go/framer_test.go b/internal/quic-go/framer_test.go deleted file mode 100644 index 201053f3..00000000 --- a/internal/quic-go/framer_test.go +++ /dev/null @@ -1,385 +0,0 @@ -package quic - -import ( - "bytes" - "math/rand" - - "github.com/imroc/req/v3/internal/quic-go/ackhandler" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Framer", func() { - const ( - id1 = protocol.StreamID(10) - id2 = protocol.StreamID(11) - ) - - var ( - framer framer - stream1, stream2 *MockSendStreamI - streamGetter *MockStreamGetter - version protocol.VersionNumber - ) - - BeforeEach(func() { - streamGetter = NewMockStreamGetter(mockCtrl) - stream1 = NewMockSendStreamI(mockCtrl) - stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes() - stream2 = NewMockSendStreamI(mockCtrl) - stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() - framer = newFramer(streamGetter, version) - }) - - Context("handling control frames", func() { - It("adds control frames", func() { - mdf := &wire.MaxDataFrame{MaximumData: 0x42} - msf := &wire.MaxStreamsFrame{MaxStreamNum: 0x1337} - framer.QueueControlFrame(mdf) - framer.QueueControlFrame(msf) - frames, length := framer.AppendControlFrames(nil, 1000) - Expect(frames).To(HaveLen(2)) - fs := []wire.Frame{frames[0].Frame, frames[1].Frame} - Expect(fs).To(ContainElement(mdf)) - Expect(fs).To(ContainElement(msf)) - Expect(length).To(Equal(mdf.Length(version) + msf.Length(version))) - }) - - It("says if it has data", func() { - Expect(framer.HasData()).To(BeFalse()) - f := &wire.MaxDataFrame{MaximumData: 0x42} - framer.QueueControlFrame(f) - Expect(framer.HasData()).To(BeTrue()) - frames, _ := framer.AppendControlFrames(nil, 1000) - Expect(frames).To(HaveLen(1)) - Expect(framer.HasData()).To(BeFalse()) - }) - - It("appends to the slice given", func() { - ping := &wire.PingFrame{} - mdf := &wire.MaxDataFrame{MaximumData: 0x42} - framer.QueueControlFrame(mdf) - frames, length := framer.AppendControlFrames([]ackhandler.Frame{{Frame: ping}}, 1000) - Expect(frames).To(HaveLen(2)) - Expect(frames[0].Frame).To(Equal(ping)) - Expect(frames[1].Frame).To(Equal(mdf)) - Expect(length).To(Equal(mdf.Length(version))) - }) - - It("adds the right number of frames", func() { - maxSize := protocol.ByteCount(1000) - bf := &wire.DataBlockedFrame{MaximumData: 0x1337} - bfLen := bf.Length(version) - numFrames := int(maxSize / bfLen) // max number of frames that fit into maxSize - for i := 0; i < numFrames+1; i++ { - framer.QueueControlFrame(bf) - } - frames, length := framer.AppendControlFrames(nil, maxSize) - Expect(frames).To(HaveLen(numFrames)) - Expect(length).To(BeNumerically(">", maxSize-bfLen)) - frames, length = framer.AppendControlFrames(nil, maxSize) - Expect(frames).To(HaveLen(1)) - Expect(length).To(Equal(bfLen)) - }) - - It("drops *_BLOCKED frames when 0-RTT is rejected", func() { - ping := &wire.PingFrame{} - ncid := &wire.NewConnectionIDFrame{SequenceNumber: 10, ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}} - frames := []wire.Frame{ - &wire.DataBlockedFrame{MaximumData: 1337}, - &wire.StreamDataBlockedFrame{StreamID: 42, MaximumStreamData: 1337}, - &wire.StreamsBlockedFrame{StreamLimit: 13}, - ping, - ncid, - } - rand.Shuffle(len(frames), func(i, j int) { frames[i], frames[j] = frames[j], frames[i] }) - for _, f := range frames { - framer.QueueControlFrame(f) - } - Expect(framer.Handle0RTTRejection()).To(Succeed()) - fs, length := framer.AppendControlFrames(nil, protocol.MaxByteCount) - Expect(fs).To(HaveLen(2)) - Expect(length).To(Equal(ping.Length(version) + ncid.Length(version))) - }) - }) - - Context("popping STREAM frames", func() { - It("returns nil when popping an empty framer", func() { - Expect(framer.AppendStreamFrames(nil, 1000)).To(BeEmpty()) - }) - - It("returns STREAM frames", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - f := &wire.StreamFrame{ - StreamID: id1, - Data: []byte("foobar"), - Offset: 42, - DataLenPresent: true, - } - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) - framer.AddActiveStream(id1) - fs, length := framer.AppendStreamFrames(nil, 1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].Frame.(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - Expect(length).To(Equal(f.Length(version))) - }) - - It("says if it has data", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2) - Expect(framer.HasData()).To(BeFalse()) - framer.AddActiveStream(id1) - Expect(framer.HasData()).To(BeTrue()) - f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foo")} - f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("bar")} - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f1}, true) - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f2}, false) - frames, _ := framer.AppendStreamFrames(nil, protocol.MaxByteCount) - Expect(frames).To(HaveLen(1)) - Expect(frames[0].Frame).To(Equal(f1)) - Expect(framer.HasData()).To(BeTrue()) - frames, _ = framer.AppendStreamFrames(nil, protocol.MaxByteCount) - Expect(frames).To(HaveLen(1)) - Expect(frames[0].Frame).To(Equal(f2)) - Expect(framer.HasData()).To(BeFalse()) - }) - - It("appends to a frame slice", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - f := &wire.StreamFrame{ - StreamID: id1, - Data: []byte("foobar"), - DataLenPresent: true, - } - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) - framer.AddActiveStream(id1) - mdf := &wire.MaxDataFrame{MaximumData: 1337} - frames := []ackhandler.Frame{{Frame: mdf}} - fs, length := framer.AppendStreamFrames(frames, 1000) - Expect(fs).To(HaveLen(2)) - Expect(fs[0].Frame).To(Equal(mdf)) - Expect(fs[1].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) - Expect(fs[1].Frame.(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - Expect(length).To(Equal(f.Length(version))) - }) - - It("skips a stream that was reported active, but was completed shortly after", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(nil, nil) - streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) - f := &wire.StreamFrame{ - StreamID: id2, - Data: []byte("foobar"), - DataLenPresent: true, - } - stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) - framer.AddActiveStream(id1) - framer.AddActiveStream(id2) - frames, _ := framer.AppendStreamFrames(nil, 1000) - Expect(frames).To(HaveLen(1)) - Expect(frames[0].Frame).To(Equal(f)) - }) - - It("skips a stream that was reported active, but doesn't have any data", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) - f := &wire.StreamFrame{ - StreamID: id2, - Data: []byte("foobar"), - DataLenPresent: true, - } - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(nil, false) - stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) - framer.AddActiveStream(id1) - framer.AddActiveStream(id2) - frames, _ := framer.AppendStreamFrames(nil, 1000) - Expect(frames).To(HaveLen(1)) - Expect(frames[0].Frame).To(Equal(f)) - }) - - It("pops from a stream multiple times, if it has enough data", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2) - f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} - f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")} - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f1}, true) - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f2}, false) - framer.AddActiveStream(id1) // only add it once - frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - Expect(frames).To(HaveLen(1)) - Expect(frames[0].Frame).To(Equal(f1)) - frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - Expect(frames).To(HaveLen(1)) - Expect(frames[0].Frame).To(Equal(f2)) - // no further calls to popStreamFrame, after popStreamFrame said there's no more data - frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - Expect(frames).To(BeNil()) - }) - - It("re-queues a stream at the end, if it has enough data", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2) - streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) - f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} - f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")} - f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")} - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f11}, true) - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f12}, false) - stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f2}, false) - framer.AddActiveStream(id1) // only add it once - framer.AddActiveStream(id2) - // first a frame from stream 1 - frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - Expect(frames).To(HaveLen(1)) - Expect(frames[0].Frame).To(Equal(f11)) - // then a frame from stream 2 - frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - Expect(frames).To(HaveLen(1)) - Expect(frames[0].Frame).To(Equal(f2)) - // then another frame from stream 1 - frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - Expect(frames).To(HaveLen(1)) - Expect(frames[0].Frame).To(Equal(f12)) - }) - - It("only dequeues data from each stream once per packet", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) - f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} - f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")} - // both streams have more data, and will be re-queued - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f1}, true) - stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f2}, true) - framer.AddActiveStream(id1) - framer.AddActiveStream(id2) - frames, length := framer.AppendStreamFrames(nil, 1000) - Expect(frames).To(HaveLen(2)) - Expect(frames[0].Frame).To(Equal(f1)) - Expect(frames[1].Frame).To(Equal(f2)) - Expect(length).To(Equal(f1.Length(version) + f2.Length(version))) - }) - - It("returns multiple normal frames in the order they were reported active", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) - f1 := &wire.StreamFrame{Data: []byte("foobar")} - f2 := &wire.StreamFrame{Data: []byte("foobaz")} - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f1}, false) - stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f2}, false) - framer.AddActiveStream(id2) - framer.AddActiveStream(id1) - frames, _ := framer.AppendStreamFrames(nil, 1000) - Expect(frames).To(HaveLen(2)) - Expect(frames[0].Frame).To(Equal(f2)) - Expect(frames[1].Frame).To(Equal(f1)) - }) - - It("only asks a stream for data once, even if it was reported active multiple times", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - f := &wire.StreamFrame{Data: []byte("foobar")} - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) // only one call to this function - framer.AddActiveStream(id1) - framer.AddActiveStream(id1) - frames, _ := framer.AppendStreamFrames(nil, 1000) - Expect(frames).To(HaveLen(1)) - }) - - It("does not pop empty frames", func() { - fs, length := framer.AppendStreamFrames(nil, 500) - Expect(fs).To(BeEmpty()) - Expect(length).To(BeZero()) - }) - - It("pops maximum size STREAM frames", func() { - for i := protocol.MinStreamFrameSize; i < 2000; i++ { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - stream1.EXPECT().popStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) (*ackhandler.Frame, bool) { - f := &wire.StreamFrame{ - StreamID: id1, - DataLenPresent: true, - } - f.Data = make([]byte, f.MaxDataLen(size, version)) - Expect(f.Length(version)).To(Equal(size)) - return &ackhandler.Frame{Frame: f}, false - }) - framer.AddActiveStream(id1) - frames, _ := framer.AppendStreamFrames(nil, i) - Expect(frames).To(HaveLen(1)) - f := frames[0].Frame.(*wire.StreamFrame) - Expect(f.DataLenPresent).To(BeFalse()) - Expect(f.Length(version)).To(Equal(i)) - } - }) - - It("pops multiple STREAM frames", func() { - for i := 2 * protocol.MinStreamFrameSize; i < 2000; i++ { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) - stream1.EXPECT().popStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) (*ackhandler.Frame, bool) { - f := &wire.StreamFrame{ - StreamID: id2, - DataLenPresent: true, - } - f.Data = make([]byte, f.MaxDataLen(protocol.MinStreamFrameSize, version)) - return &ackhandler.Frame{Frame: f}, false - }) - stream2.EXPECT().popStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) (*ackhandler.Frame, bool) { - f := &wire.StreamFrame{ - StreamID: id2, - DataLenPresent: true, - } - f.Data = make([]byte, f.MaxDataLen(size, version)) - Expect(f.Length(version)).To(Equal(size)) - return &ackhandler.Frame{Frame: f}, false - }) - framer.AddActiveStream(id1) - framer.AddActiveStream(id2) - frames, _ := framer.AppendStreamFrames(nil, i) - Expect(frames).To(HaveLen(2)) - f1 := frames[0].Frame.(*wire.StreamFrame) - f2 := frames[1].Frame.(*wire.StreamFrame) - Expect(f1.DataLenPresent).To(BeTrue()) - Expect(f2.DataLenPresent).To(BeFalse()) - Expect(f1.Length(version) + f2.Length(version)).To(Equal(i)) - } - }) - - It("pops frames that when asked for the the minimum STREAM frame size", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - f := &wire.StreamFrame{Data: []byte("foobar")} - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) - framer.AddActiveStream(id1) - framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - }) - - It("does not pop frames smaller than the minimum size", func() { - // don't expect a call to PopStreamFrame() - framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize-1) - }) - - It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() { - streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - // pop a frame such that the remaining size is one byte less than the minimum STREAM frame size - f := &wire.StreamFrame{ - StreamID: id1, - Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)), - DataLenPresent: true, - } - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f}, false) - framer.AddActiveStream(id1) - fs, length := framer.AppendStreamFrames(nil, 500) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].Frame).To(Equal(f)) - Expect(length).To(Equal(f.Length(version))) - }) - - It("drops all STREAM frames when 0-RTT is rejected", func() { - framer.AddActiveStream(id1) - Expect(framer.Handle0RTTRejection()).To(Succeed()) - fs, length := framer.AppendStreamFrames(nil, protocol.MaxByteCount) - Expect(fs).To(BeEmpty()) - Expect(length).To(BeZero()) - }) - }) -}) diff --git a/internal/quic-go/handshake/aead.go b/internal/quic-go/handshake/aead.go deleted file mode 100644 index e4e76cba..00000000 --- a/internal/quic-go/handshake/aead.go +++ /dev/null @@ -1,161 +0,0 @@ -package handshake - -import ( - "crypto/cipher" - "encoding/binary" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qtls" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD { - keyLabel := hkdfLabelKeyV1 - ivLabel := hkdfLabelIVV1 - if v == protocol.Version2 { - keyLabel = hkdfLabelKeyV2 - ivLabel = hkdfLabelIVV2 - } - key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen) - iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen()) - return suite.AEAD(key, iv) -} - -type longHeaderSealer struct { - aead cipher.AEAD - headerProtector headerProtector - - // use a single slice to avoid allocations - nonceBuf []byte -} - -var _ LongHeaderSealer = &longHeaderSealer{} - -func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer { - return &longHeaderSealer{ - aead: aead, - headerProtector: headerProtector, - nonceBuf: make([]byte, aead.NonceSize()), - } -} - -func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { - binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - return s.aead.Seal(dst, s.nonceBuf, src, ad) -} - -func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - s.headerProtector.EncryptHeader(sample, firstByte, pnBytes) -} - -func (s *longHeaderSealer) Overhead() int { - return s.aead.Overhead() -} - -type longHeaderOpener struct { - aead cipher.AEAD - headerProtector headerProtector - highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) - - // use a single slice to avoid allocations - nonceBuf []byte -} - -var _ LongHeaderOpener = &longHeaderOpener{} - -func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener { - return &longHeaderOpener{ - aead: aead, - headerProtector: headerProtector, - nonceBuf: make([]byte, aead.NonceSize()), - } -} - -func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { - return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN) -} - -func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { - binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - dec, err := o.aead.Open(dst, o.nonceBuf, src, ad) - if err == nil { - o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn) - } else { - err = ErrDecryptionFailed - } - return dec, err -} - -func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - o.headerProtector.DecryptHeader(sample, firstByte, pnBytes) -} - -type handshakeSealer struct { - LongHeaderSealer - - dropInitialKeys func() - dropped bool -} - -func newHandshakeSealer( - aead cipher.AEAD, - headerProtector headerProtector, - dropInitialKeys func(), - perspective protocol.Perspective, -) LongHeaderSealer { - sealer := newLongHeaderSealer(aead, headerProtector) - // The client drops Initial keys when sending the first Handshake packet. - if perspective == protocol.PerspectiveServer { - return sealer - } - return &handshakeSealer{ - LongHeaderSealer: sealer, - dropInitialKeys: dropInitialKeys, - } -} - -func (s *handshakeSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { - data := s.LongHeaderSealer.Seal(dst, src, pn, ad) - if !s.dropped { - s.dropInitialKeys() - s.dropped = true - } - return data -} - -type handshakeOpener struct { - LongHeaderOpener - - dropInitialKeys func() - dropped bool -} - -func newHandshakeOpener( - aead cipher.AEAD, - headerProtector headerProtector, - dropInitialKeys func(), - perspective protocol.Perspective, -) LongHeaderOpener { - opener := newLongHeaderOpener(aead, headerProtector) - // The server drops Initial keys when first successfully processing a Handshake packet. - if perspective == protocol.PerspectiveClient { - return opener - } - return &handshakeOpener{ - LongHeaderOpener: opener, - dropInitialKeys: dropInitialKeys, - } -} - -func (o *handshakeOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { - dec, err := o.LongHeaderOpener.Open(dst, src, pn, ad) - if err == nil && !o.dropped { - o.dropInitialKeys() - o.dropped = true - } - return dec, err -} diff --git a/internal/quic-go/handshake/aead_test.go b/internal/quic-go/handshake/aead_test.go deleted file mode 100644 index 672d60f0..00000000 --- a/internal/quic-go/handshake/aead_test.go +++ /dev/null @@ -1,204 +0,0 @@ -package handshake - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/tls" - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Long Header AEAD", func() { - for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { - v := ver - - Context(fmt.Sprintf("using version %s", v), func() { - for i := range cipherSuites { - cs := cipherSuites[i] - - Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { - getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { - key := make([]byte, 16) - hpKey := make([]byte, 16) - rand.Read(key) - rand.Read(hpKey) - block, err := aes.NewCipher(key) - Expect(err).ToNot(HaveOccurred()) - aead, err := cipher.NewGCM(block) - Expect(err).ToNot(HaveOccurred()) - - return newLongHeaderSealer(aead, newHeaderProtector(cs, hpKey, true, v)), - newLongHeaderOpener(aead, newHeaderProtector(cs, hpKey, true, v)) - } - - Context("message encryption", func() { - msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad := []byte("Donec in velit neque.") - - It("encrypts and decrypts a message", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - opened, err := opener.Open(nil, encrypted, 0x1337, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - - It("fails to open a message if the associated data is not the same", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("fails to open a message if the packet number is not the same", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x42, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("decodes the packet number", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x1337, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) - }) - - It("ignores packets it can't decrypt for packet number derivation", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted[:len(encrypted)-1], 0x1337, ad) - Expect(err).To(HaveOccurred()) - Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) - }) - }) - - Context("header encryption", func() { - It("encrypts and encrypts the header", func() { - sealer, opener := getSealerAndOpener() - var lastFourBitsDifferent int - for i := 0; i < 100; i++ { - sample := make([]byte, 16) - rand.Read(sample) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sealer.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0xf != 0xb5&0xf { - lastFourBitsDifferent++ - } - Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - } - Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) - }) - - It("encrypts and encrypts the header, for a 0xfff..fff sample", func() { - sealer, opener := getSealerAndOpener() - var lastFourBitsDifferent int - for i := 0; i < 100; i++ { - sample := bytes.Repeat([]byte{0xff}, 16) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sealer.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0xf != 0xb5&0xf { - lastFourBitsDifferent++ - } - Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - } - }) - - It("fails to decrypt the header when using a different sample", func() { - sealer, opener := getSealerAndOpener() - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sample := make([]byte, 16) - rand.Read(sample) - sealer.EncryptHeader(sample, &header[0], header[9:13]) - rand.Read(sample) // use a different sample - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - }) - }) - }) - } - }) - - Describe("Long Header AEAD", func() { - var ( - dropped chan struct{} // use a chan because closing it twice will panic - aead cipher.AEAD - hp headerProtector - ) - dropCb := func() { close(dropped) } - msg := []byte("Lorem ipsum dolor sit amet.") - ad := []byte("Donec in velit neque.") - - BeforeEach(func() { - dropped = make(chan struct{}) - key := make([]byte, 16) - hpKey := make([]byte, 16) - rand.Read(key) - rand.Read(hpKey) - block, err := aes.NewCipher(key) - Expect(err).ToNot(HaveOccurred()) - aead, err = cipher.NewGCM(block) - Expect(err).ToNot(HaveOccurred()) - hp = newHeaderProtector(cipherSuites[0], hpKey, true, protocol.Version1) - }) - - Context("for the server", func() { - It("drops keys when first successfully processing a Handshake packet", func() { - serverOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveServer) - // first try to open an invalid message - _, err := serverOpener.Open(nil, []byte("invalid"), 0, []byte("invalid")) - Expect(err).To(HaveOccurred()) - Expect(dropped).ToNot(BeClosed()) - // then open a valid message - enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 10, ad) - _, err = serverOpener.Open(nil, enc, 10, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(dropped).To(BeClosed()) - // now open the same message again to make sure the callback is only called once - _, err = serverOpener.Open(nil, enc, 10, ad) - Expect(err).ToNot(HaveOccurred()) - }) - - It("doesn't drop keys when sealing a Handshake packet", func() { - serverSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveServer) - serverSealer.Seal(nil, msg, 1, ad) - Expect(dropped).ToNot(BeClosed()) - }) - }) - - Context("for the client", func() { - It("drops keys when first sealing a Handshake packet", func() { - clientSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveClient) - // seal the first message - clientSealer.Seal(nil, msg, 1, ad) - Expect(dropped).To(BeClosed()) - // seal another message to make sure the callback is only called once - clientSealer.Seal(nil, msg, 2, ad) - }) - - It("doesn't drop keys when processing a Handshake packet", func() { - enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 42, ad) - clientOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveClient) - _, err := clientOpener.Open(nil, enc, 42, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(dropped).ToNot(BeClosed()) - }) - }) - }) - } -}) diff --git a/internal/quic-go/handshake/crypto_setup.go b/internal/quic-go/handshake/crypto_setup.go deleted file mode 100644 index 1ef63cd0..00000000 --- a/internal/quic-go/handshake/crypto_setup.go +++ /dev/null @@ -1,819 +0,0 @@ -package handshake - -import ( - "bytes" - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "sync" - "time" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/qtls" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// TLS unexpected_message alert -const alertUnexpectedMessage uint8 = 10 - -type messageType uint8 - -// TLS handshake message types. -const ( - typeClientHello messageType = 1 - typeServerHello messageType = 2 - typeNewSessionTicket messageType = 4 - typeEncryptedExtensions messageType = 8 - typeCertificate messageType = 11 - typeCertificateRequest messageType = 13 - typeCertificateVerify messageType = 15 - typeFinished messageType = 20 -) - -func (m messageType) String() string { - switch m { - case typeClientHello: - return "ClientHello" - case typeServerHello: - return "ServerHello" - case typeNewSessionTicket: - return "NewSessionTicket" - case typeEncryptedExtensions: - return "EncryptedExtensions" - case typeCertificate: - return "Certificate" - case typeCertificateRequest: - return "CertificateRequest" - case typeCertificateVerify: - return "CertificateVerify" - case typeFinished: - return "Finished" - default: - return fmt.Sprintf("unknown message type: %d", m) - } -} - -const clientSessionStateRevision = 3 - -type conn struct { - localAddr, remoteAddr net.Addr - version protocol.VersionNumber -} - -var _ ConnWithVersion = &conn{} - -func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion { - return &conn{ - localAddr: local, - remoteAddr: remote, - version: version, - } -} - -var _ net.Conn = &conn{} - -func (c *conn) Read([]byte) (int, error) { return 0, nil } -func (c *conn) Write([]byte) (int, error) { return 0, nil } -func (c *conn) Close() error { return nil } -func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *conn) LocalAddr() net.Addr { return c.localAddr } -func (c *conn) SetReadDeadline(time.Time) error { return nil } -func (c *conn) SetWriteDeadline(time.Time) error { return nil } -func (c *conn) SetDeadline(time.Time) error { return nil } -func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version } - -type cryptoSetup struct { - tlsConf *tls.Config - extraConf *qtls.ExtraConfig - conn *qtls.Conn - - version protocol.VersionNumber - - messageChan chan []byte - isReadingHandshakeMessage chan struct{} - readFirstHandshakeMessage bool - - ourParams *wire.TransportParameters - peerParams *wire.TransportParameters - paramsChan <-chan []byte - - runner handshakeRunner - - alertChan chan uint8 - // handshakeDone is closed as soon as the go routine running qtls.Handshake() returns - handshakeDone chan struct{} - // is closed when Close() is called - closeChan chan struct{} - - zeroRTTParameters *wire.TransportParameters - clientHelloWritten bool - clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written - zeroRTTParametersChan chan<- *wire.TransportParameters - - rttStats *utils.RTTStats - - tracer logging.ConnectionTracer - logger utils.Logger - - perspective protocol.Perspective - - mutex sync.Mutex // protects all members below - - handshakeCompleteTime time.Time - - readEncLevel protocol.EncryptionLevel - writeEncLevel protocol.EncryptionLevel - - zeroRTTOpener LongHeaderOpener // only set for the server - zeroRTTSealer LongHeaderSealer // only set for the client - - initialStream io.Writer - initialOpener LongHeaderOpener - initialSealer LongHeaderSealer - - handshakeStream io.Writer - handshakeOpener LongHeaderOpener - handshakeSealer LongHeaderSealer - - aead *updatableAEAD - has1RTTSealer bool - has1RTTOpener bool -} - -var ( - _ qtls.RecordLayer = &cryptoSetup{} - _ CryptoSetup = &cryptoSetup{} -) - -// NewCryptoSetupClient creates a new crypto setup for the client -func NewCryptoSetupClient( - initialStream io.Writer, - handshakeStream io.Writer, - connID protocol.ConnectionID, - localAddr net.Addr, - remoteAddr net.Addr, - tp *wire.TransportParameters, - runner handshakeRunner, - tlsConf *tls.Config, - enable0RTT bool, - rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, - logger utils.Logger, - version protocol.VersionNumber, -) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { - cs, clientHelloWritten := newCryptoSetup( - initialStream, - handshakeStream, - connID, - tp, - runner, - tlsConf, - enable0RTT, - rttStats, - tracer, - logger, - protocol.PerspectiveClient, - version, - ) - cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) - return cs, clientHelloWritten -} - -// NewCryptoSetupServer creates a new crypto setup for the server -func NewCryptoSetupServer( - initialStream io.Writer, - handshakeStream io.Writer, - connID protocol.ConnectionID, - localAddr net.Addr, - remoteAddr net.Addr, - tp *wire.TransportParameters, - runner handshakeRunner, - tlsConf *tls.Config, - enable0RTT bool, - rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, - logger utils.Logger, - version protocol.VersionNumber, -) CryptoSetup { - cs, _ := newCryptoSetup( - initialStream, - handshakeStream, - connID, - tp, - runner, - tlsConf, - enable0RTT, - rttStats, - tracer, - logger, - protocol.PerspectiveServer, - version, - ) - cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) - return cs -} - -func newCryptoSetup( - initialStream io.Writer, - handshakeStream io.Writer, - connID protocol.ConnectionID, - tp *wire.TransportParameters, - runner handshakeRunner, - tlsConf *tls.Config, - enable0RTT bool, - rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, - logger utils.Logger, - perspective protocol.Perspective, - version protocol.VersionNumber, -) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { - initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) - if tracer != nil { - tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) - tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) - } - extHandler := newExtensionHandler(tp.Marshal(perspective), perspective, version) - zeroRTTParametersChan := make(chan *wire.TransportParameters, 1) - cs := &cryptoSetup{ - tlsConf: tlsConf, - initialStream: initialStream, - initialSealer: initialSealer, - initialOpener: initialOpener, - handshakeStream: handshakeStream, - aead: newUpdatableAEAD(rttStats, tracer, logger, version), - readEncLevel: protocol.EncryptionInitial, - writeEncLevel: protocol.EncryptionInitial, - runner: runner, - ourParams: tp, - paramsChan: extHandler.TransportParameters(), - rttStats: rttStats, - tracer: tracer, - logger: logger, - perspective: perspective, - handshakeDone: make(chan struct{}), - alertChan: make(chan uint8), - clientHelloWrittenChan: make(chan struct{}), - zeroRTTParametersChan: zeroRTTParametersChan, - messageChan: make(chan []byte, 100), - isReadingHandshakeMessage: make(chan struct{}), - closeChan: make(chan struct{}), - version: version, - } - var maxEarlyData uint32 - if enable0RTT { - maxEarlyData = 0xffffffff - } - cs.extraConf = &qtls.ExtraConfig{ - GetExtensions: extHandler.GetExtensions, - ReceivedExtensions: extHandler.ReceivedExtensions, - AlternativeRecordLayer: cs, - EnforceNextProtoSelection: true, - MaxEarlyData: maxEarlyData, - Accept0RTT: cs.accept0RTT, - Rejected0RTT: cs.rejected0RTT, - Enable0RTT: enable0RTT, - GetAppDataForSessionState: cs.marshalDataForSessionState, - SetAppDataFromSessionState: cs.handleDataFromSessionState, - } - return cs, zeroRTTParametersChan -} - -func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { - initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version) - h.initialSealer = initialSealer - h.initialOpener = initialOpener - if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) - h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) - } -} - -func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error { - return h.aead.SetLargestAcked(pn) -} - -func (h *cryptoSetup) RunHandshake() { - // Handle errors that might occur when HandleData() is called. - handshakeComplete := make(chan struct{}) - handshakeErrChan := make(chan error, 1) - go func() { - defer close(h.handshakeDone) - if err := h.conn.Handshake(); err != nil { - handshakeErrChan <- err - return - } - close(handshakeComplete) - }() - - if h.perspective == protocol.PerspectiveClient { - select { - case err := <-handshakeErrChan: - h.onError(0, err.Error()) - return - case <-h.clientHelloWrittenChan: - } - } - - select { - case <-handshakeComplete: // return when the handshake is done - h.mutex.Lock() - h.handshakeCompleteTime = time.Now() - h.mutex.Unlock() - h.runner.OnHandshakeComplete() - case <-h.closeChan: - // wait until the Handshake() go routine has returned - <-h.handshakeDone - case alert := <-h.alertChan: - handshakeErr := <-handshakeErrChan - h.onError(alert, handshakeErr.Error()) - } -} - -func (h *cryptoSetup) onError(alert uint8, message string) { - var err error - if alert == 0 { - err = &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: message} - } else { - err = qerr.NewCryptoError(alert, message) - } - h.runner.OnError(err) -} - -// Close closes the crypto setup. -// It aborts the handshake, if it is still running. -// It must only be called once. -func (h *cryptoSetup) Close() error { - close(h.closeChan) - // wait until qtls.Handshake() actually returned - <-h.handshakeDone - return nil -} - -// handleMessage handles a TLS handshake message. -// It is called by the crypto streams when a new message is available. -// It returns if it is done with messages on the same encryption level. -func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ { - msgType := messageType(data[0]) - h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) - if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { - h.onError(alertUnexpectedMessage, err.Error()) - return false - } - h.messageChan <- data - if encLevel == protocol.Encryption1RTT { - h.handlePostHandshakeMessage() - return false - } -readLoop: - for { - select { - case data := <-h.paramsChan: - if data == nil { - h.onError(0x6d, "missing quic_transport_parameters extension") - } else { - h.handleTransportParameters(data) - } - case <-h.isReadingHandshakeMessage: - break readLoop - case <-h.handshakeDone: - break readLoop - case <-h.closeChan: - break readLoop - } - } - // We're done with the Initial encryption level after processing a ClientHello / ServerHello, - // but only if a handshake opener and sealer was created. - // Otherwise, a HelloRetryRequest was performed. - // We're done with the Handshake encryption level after processing the Finished message. - return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) || - msgType == typeFinished -} - -func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { - var expected protocol.EncryptionLevel - switch msgType { - case typeClientHello, - typeServerHello: - expected = protocol.EncryptionInitial - case typeEncryptedExtensions, - typeCertificate, - typeCertificateRequest, - typeCertificateVerify, - typeFinished: - expected = protocol.EncryptionHandshake - case typeNewSessionTicket: - expected = protocol.Encryption1RTT - default: - return fmt.Errorf("unexpected handshake message: %d", msgType) - } - if encLevel != expected { - return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel) - } - return nil -} - -func (h *cryptoSetup) handleTransportParameters(data []byte) { - var tp wire.TransportParameters - if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { - h.runner.OnError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: err.Error(), - }) - } - h.peerParams = &tp - h.runner.OnReceivedParams(h.peerParams) -} - -// must be called after receiving the transport parameters -func (h *cryptoSetup) marshalDataForSessionState() []byte { - buf := &bytes.Buffer{} - quicvarint.Write(buf, clientSessionStateRevision) - quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds())) - h.peerParams.MarshalForSessionTicket(buf) - return buf.Bytes() -} - -func (h *cryptoSetup) handleDataFromSessionState(data []byte) { - tp, err := h.handleDataFromSessionStateImpl(data) - if err != nil { - h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) - return - } - h.zeroRTTParameters = tp -} - -func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) { - r := bytes.NewReader(data) - ver, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - if ver != clientSessionStateRevision { - return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) - } - rtt, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) - var tp wire.TransportParameters - if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return nil, err - } - return &tp, nil -} - -// only valid for the server -func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { - var appData []byte - // Save transport parameters to the session ticket if we're allowing 0-RTT. - if h.extraConf.MaxEarlyData > 0 { - appData = (&sessionTicket{ - Parameters: h.ourParams, - RTT: h.rttStats.SmoothedRTT(), - }).Marshal() - } - return h.conn.GetSessionTicket(appData) -} - -// accept0RTT is called for the server when receiving the client's session ticket. -// It decides whether to accept 0-RTT. -func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { - var t sessionTicket - if err := t.Unmarshal(sessionTicketData); err != nil { - h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error()) - return false - } - valid := h.ourParams.ValidFor0RTT(t.Parameters) - if valid { - h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) - h.rttStats.SetInitialRTT(t.RTT) - } else { - h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") - } - return valid -} - -// rejected0RTT is called for the client when the server rejects 0-RTT. -func (h *cryptoSetup) rejected0RTT() { - h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") - - h.mutex.Lock() - had0RTTKeys := h.zeroRTTSealer != nil - h.zeroRTTSealer = nil - h.mutex.Unlock() - - if had0RTTKeys { - h.runner.DropKeys(protocol.Encryption0RTT) - } -} - -func (h *cryptoSetup) handlePostHandshakeMessage() { - // make sure the handshake has already completed - <-h.handshakeDone - - done := make(chan struct{}) - defer close(done) - - // h.alertChan is an unbuffered channel. - // If an error occurs during conn.HandlePostHandshakeMessage, - // it will be sent on this channel. - // Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock. - alertChan := make(chan uint8, 1) - go func() { - <-h.isReadingHandshakeMessage - select { - case alert := <-h.alertChan: - alertChan <- alert - case <-done: - } - }() - - if err := h.conn.HandlePostHandshakeMessage(); err != nil { - select { - case <-h.closeChan: - case alert := <-alertChan: - h.onError(alert, err.Error()) - } - } -} - -// ReadHandshakeMessage is called by TLS. -// It blocks until a new handshake message is available. -func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) { - if !h.readFirstHandshakeMessage { - h.readFirstHandshakeMessage = true - } else { - select { - case h.isReadingHandshakeMessage <- struct{}{}: - case <-h.closeChan: - return nil, errors.New("error while handling the handshake message") - } - } - select { - case msg := <-h.messageChan: - return msg, nil - case <-h.closeChan: - return nil, errors.New("error while handling the handshake message") - } -} - -func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { - h.mutex.Lock() - switch encLevel { - case qtls.Encryption0RTT: - if h.perspective == protocol.PerspectiveClient { - panic("Received 0-RTT read key for the client") - } - h.zeroRTTOpener = newLongHeaderOpener( - createAEAD(suite, trafficSecret, h.version), - newHeaderProtector(suite, trafficSecret, true, h.version), - ) - h.mutex.Unlock() - h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) - if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite()) - } - return - case qtls.EncryptionHandshake: - h.readEncLevel = protocol.EncryptionHandshake - h.handshakeOpener = newHandshakeOpener( - createAEAD(suite, trafficSecret, h.version), - newHeaderProtector(suite, trafficSecret, true, h.version), - h.dropInitialKeys, - h.perspective, - ) - h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) - case qtls.EncryptionApplication: - h.readEncLevel = protocol.Encryption1RTT - h.aead.SetReadKey(suite, trafficSecret) - h.has1RTTOpener = true - h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) - default: - panic("unexpected read encryption level") - } - h.mutex.Unlock() - if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite()) - } -} - -func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { - h.mutex.Lock() - switch encLevel { - case qtls.Encryption0RTT: - if h.perspective == protocol.PerspectiveServer { - panic("Received 0-RTT write key for the server") - } - h.zeroRTTSealer = newLongHeaderSealer( - createAEAD(suite, trafficSecret, h.version), - newHeaderProtector(suite, trafficSecret, true, h.version), - ) - h.mutex.Unlock() - h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) - if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) - } - return - case qtls.EncryptionHandshake: - h.writeEncLevel = protocol.EncryptionHandshake - h.handshakeSealer = newHandshakeSealer( - createAEAD(suite, trafficSecret, h.version), - newHeaderProtector(suite, trafficSecret, true, h.version), - h.dropInitialKeys, - h.perspective, - ) - h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) - case qtls.EncryptionApplication: - h.writeEncLevel = protocol.Encryption1RTT - h.aead.SetWriteKey(suite, trafficSecret) - h.has1RTTSealer = true - h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) - if h.zeroRTTSealer != nil { - h.zeroRTTSealer = nil - h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { - h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) - } - } - default: - panic("unexpected write encryption level") - } - h.mutex.Unlock() - if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective) - } -} - -// WriteRecord is called when TLS writes data -func (h *cryptoSetup) WriteRecord(p []byte) (int, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - //nolint:exhaustive // LS records can only be written for Initial and Handshake. - switch h.writeEncLevel { - case protocol.EncryptionInitial: - // assume that the first WriteRecord call contains the ClientHello - n, err := h.initialStream.Write(p) - if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient { - h.clientHelloWritten = true - close(h.clientHelloWrittenChan) - if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil { - h.logger.Debugf("Doing 0-RTT.") - h.zeroRTTParametersChan <- h.zeroRTTParameters - } else { - h.logger.Debugf("Not doing 0-RTT.") - h.zeroRTTParametersChan <- nil - } - } - return n, err - case protocol.EncryptionHandshake: - return h.handshakeStream.Write(p) - default: - panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel)) - } -} - -func (h *cryptoSetup) SendAlert(alert uint8) { - select { - case h.alertChan <- alert: - case <-h.closeChan: - // no need to send an alert when we've already closed - } -} - -// used a callback in the handshakeSealer and handshakeOpener -func (h *cryptoSetup) dropInitialKeys() { - h.mutex.Lock() - h.initialOpener = nil - h.initialSealer = nil - h.mutex.Unlock() - h.runner.DropKeys(protocol.EncryptionInitial) - h.logger.Debugf("Dropping Initial keys.") -} - -func (h *cryptoSetup) SetHandshakeConfirmed() { - h.aead.SetHandshakeConfirmed() - // drop Handshake keys - var dropped bool - h.mutex.Lock() - if h.handshakeOpener != nil { - h.handshakeOpener = nil - h.handshakeSealer = nil - dropped = true - } - h.mutex.Unlock() - if dropped { - h.runner.DropKeys(protocol.EncryptionHandshake) - h.logger.Debugf("Dropping Handshake keys.") - } -} - -func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.initialSealer == nil { - return nil, ErrKeysDropped - } - return h.initialSealer, nil -} - -func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.zeroRTTSealer == nil { - return nil, ErrKeysDropped - } - return h.zeroRTTSealer, nil -} - -func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.handshakeSealer == nil { - if h.initialSealer == nil { - return nil, ErrKeysDropped - } - return nil, ErrKeysNotYetAvailable - } - return h.handshakeSealer, nil -} - -func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if !h.has1RTTSealer { - return nil, ErrKeysNotYetAvailable - } - return h.aead, nil -} - -func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.initialOpener == nil { - return nil, ErrKeysDropped - } - return h.initialOpener, nil -} - -func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.zeroRTTOpener == nil { - if h.initialOpener != nil { - return nil, ErrKeysNotYetAvailable - } - // if the initial opener is also not available, the keys were already dropped - return nil, ErrKeysDropped - } - return h.zeroRTTOpener, nil -} - -func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.handshakeOpener == nil { - if h.initialOpener != nil { - return nil, ErrKeysNotYetAvailable - } - // if the initial opener is also not available, the keys were already dropped - return nil, ErrKeysDropped - } - return h.handshakeOpener, nil -} - -func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - - if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { - h.zeroRTTOpener = nil - h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { - h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) - } - } - - if !h.has1RTTOpener { - return nil, ErrKeysNotYetAvailable - } - return h.aead, nil -} - -func (h *cryptoSetup) ConnectionState() ConnectionState { - return qtls.GetConnectionState(h.conn) -} diff --git a/internal/quic-go/handshake/crypto_setup_test.go b/internal/quic-go/handshake/crypto_setup_test.go deleted file mode 100644 index 4adefcaa..00000000 --- a/internal/quic-go/handshake/crypto_setup_test.go +++ /dev/null @@ -1,864 +0,0 @@ -package handshake - -import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "math/big" - "time" - - mocktls "github.com/imroc/req/v3/internal/quic-go/mocks/tls" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/testdata" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. - 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, - 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, - 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, -} - -type chunk struct { - data []byte - encLevel protocol.EncryptionLevel -} - -type stream struct { - encLevel protocol.EncryptionLevel - chunkChan chan<- chunk -} - -func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream { - return &stream{ - chunkChan: chunkChan, - encLevel: encLevel, - } -} - -func (s *stream) Write(b []byte) (int, error) { - data := make([]byte, len(b)) - copy(data, b) - select { - case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}: - default: - panic("chunkChan too small") - } - return len(b), nil -} - -var _ = Describe("Crypto Setup TLS", func() { - var clientConf, serverConf *tls.Config - - // unparam incorrectly complains that the first argument is never used. - //nolint:unparam - initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { - chunkChan := make(chan chunk, 100) - initialStream := newStream(chunkChan, protocol.EncryptionInitial) - handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake) - return chunkChan, initialStream, handshakeStream - } - - BeforeEach(func() { - serverConf = testdata.GetTLSConfig() - serverConf.NextProtos = []string{"crypto-setup"} - clientConf = &tls.Config{ - ServerName: "localhost", - RootCAs: testdata.GetRootCA(), - NextProtos: []string{"crypto-setup"}, - } - }) - - It("returns Handshake() when an error occurs in qtls", func() { - sErrChan := make(chan error, 1) - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - _, sInitialStream, sHandshakeStream := initStreams() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - runner, - testdata.GetTLSConfig(), - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "local error: tls: unexpected message", - }))) - close(done) - }() - - fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - handledMessage := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.HandleMessage(fakeCH, protocol.EncryptionInitial) - close(handledMessage) - }() - Eventually(handledMessage).Should(BeClosed()) - Eventually(done).Should(BeClosed()) - }) - - It("handles qtls errors occurring before during ClientHello generation", func() { - sErrChan := make(chan error, 1) - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - _, sInitialStream, sHandshakeStream := initStreams() - tlsConf := testdata.GetTLSConfig() - tlsConf.InsecureSkipVerify = true - tlsConf.NextProtos = []string{""} - cl, _ := NewCryptoSetupClient( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{}, - runner, - tlsConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cl.RunHandshake() - close(done) - }() - - Eventually(done).Should(BeClosed()) - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ - ErrorCode: qerr.InternalError, - ErrorMessage: "tls: invalid NextProtos value", - }))) - }) - - It("errors when a message is received at the wrong encryption level", func() { - sErrChan := make(chan error, 1) - _, sInitialStream, sHandshakeStream := initStreams() - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - runner, - testdata.GetTLSConfig(), - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - close(done) - }() - - fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "expected handshake message ClientHello to have encryption level Initial, has Handshake", - }))) - - // make the go routine return - Expect(server.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("returns Handshake() when handling a message fails", func() { - sErrChan := make(chan error, 1) - _, sInitialStream, sHandshakeStream := initStreams() - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - runner, - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - var err error - Expect(sErrChan).To(Receive(&err)) - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) - close(done) - }() - - fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...) - server.HandleMessage(fakeCH, protocol.EncryptionInitial) // wrong encryption level - Eventually(done).Should(BeClosed()) - }) - - It("returns Handshake() when it is closed", func() { - _, sInitialStream, sHandshakeStream := initStreams() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - NewMockHandshakeRunner(mockCtrl), - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - close(done) - }() - Expect(server.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - Context("doing the handshake", func() { - generateCert := func() tls.Certificate { - priv, err := rsa.GenerateKey(rand.Reader, 2048) - Expect(err).ToNot(HaveOccurred()) - tmpl := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{}, - SignatureAlgorithm: x509.SHA256WithRSA, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour), // valid for an hour - BasicConstraintsValid: true, - } - certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) - Expect(err).ToNot(HaveOccurred()) - return tls.Certificate{ - PrivateKey: priv, - Certificate: [][]byte{certDER}, - } - } - - newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats { - rttStats := &utils.RTTStats{} - rttStats.UpdateRTT(rtt, 0, time.Now()) - ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt)) - return rttStats - } - - handshake := func(client CryptoSetup, cChunkChan <-chan chunk, - server CryptoSetup, sChunkChan <-chan chunk, - ) { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - for { - select { - case c := <-cChunkChan: - msgType := messageType(c.data[0]) - finished := server.HandleMessage(c.data, c.encLevel) - if msgType == typeFinished { - Expect(finished).To(BeTrue()) - } else if msgType == typeClientHello { - // If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys. - _, err := server.GetHandshakeOpener() - Expect(finished).To(Equal(err == nil)) - } else { - Expect(finished).To(BeFalse()) - } - case c := <-sChunkChan: - msgType := messageType(c.data[0]) - finished := client.HandleMessage(c.data, c.encLevel) - if msgType == typeFinished { - Expect(finished).To(BeTrue()) - } else if msgType == typeServerHello { - Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom))) - } else { - Expect(finished).To(BeFalse()) - } - case <-done: // handshake complete - return - } - } - }() - - go func() { - defer GinkgoRecover() - defer close(done) - server.RunHandshake() - ticket, err := server.GetSessionTicket() - Expect(err).ToNot(HaveOccurred()) - if ticket != nil { - client.HandleMessage(ticket, protocol.Encryption1RTT) - } - }() - - client.RunHandshake() - Eventually(done).Should(BeClosed()) - } - - handshakeWithTLSConf := func( - clientConf, serverConf *tls.Config, - clientRTTStats, serverRTTStats *utils.RTTStats, - clientTransportParameters, serverTransportParameters *wire.TransportParameters, - enable0RTT bool, - ) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { - var cHandshakeComplete bool - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cErrChan := make(chan error, 1) - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1) - cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1) - cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1) - client, clientHelloWrittenChan := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - clientTransportParameters, - cRunner, - clientConf, - enable0RTT, - clientRTTStats, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - var sHandshakeComplete bool - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sErrChan := make(chan error, 1) - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1) - sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1) - if serverTransportParameters.StatelessResetToken == nil { - var token protocol.StatelessResetToken - serverTransportParameters.StatelessResetToken = &token - } - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - serverTransportParameters, - sRunner, - serverConf, - enable0RTT, - serverRTTStats, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - handshake(client, cChunkChan, server, sChunkChan) - var cErr, sErr error - select { - case sErr = <-sErrChan: - default: - Expect(sHandshakeComplete).To(BeTrue()) - } - select { - case cErr = <-cErrChan: - default: - Expect(cHandshakeComplete).To(BeTrue()) - } - return clientHelloWrittenChan, client, cErr, server, sErr - } - - It("handshakes", func() { - _, _, clientErr, _, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - &utils.RTTStats{}, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - }) - - It("performs a HelloRetryRequst", func() { - serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} - _, _, clientErr, _, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - &utils.RTTStats{}, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - }) - - It("handshakes with client auth", func() { - clientConf.Certificates = []tls.Certificate{generateCert()} - serverConf.ClientAuth = tls.RequireAnyClientCert - _, _, clientErr, _, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - &utils.RTTStats{}, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - }) - - It("signals when it has written the ClientHello", func() { - runner := NewMockHandshakeRunner(mockCtrl) - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - client, chChan := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{}, - runner, - &tls.Config{InsecureSkipVerify: true}, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - client.RunHandshake() - close(done) - }() - var ch chunk - Eventually(cChunkChan).Should(Receive(&ch)) - Eventually(chChan).Should(Receive(BeNil())) - // make sure the whole ClientHello was written - Expect(len(ch.data)).To(BeNumerically(">=", 4)) - Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) - length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) - Expect(len(ch.data) - 4).To(Equal(length)) - - // make the go routine return - Expect(client.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("receives transport parameters", func() { - var cTransportParametersRcvd, sTransportParametersRcvd *wire.TransportParameters - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cTransportParameters := &wire.TransportParameters{MaxIdleTimeout: 0x42 * time.Second} - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp }) - cRunner.EXPECT().OnHandshakeComplete() - client, _ := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - cTransportParameters, - cRunner, - clientConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - var token protocol.StatelessResetToken - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp }) - sRunner.EXPECT().OnHandshakeComplete() - sTransportParameters := &wire.TransportParameters{ - MaxIdleTimeout: 0x1337 * time.Second, - StatelessResetToken: &token, - } - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - sTransportParameters, - sRunner, - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handshake(client, cChunkChan, server, sChunkChan) - close(done) - }() - Eventually(done).Should(BeClosed()) - Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout)) - Expect(sTransportParametersRcvd).ToNot(BeNil()) - Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout)) - }) - - Context("with session tickets", func() { - It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnHandshakeComplete() - client, _ := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{}, - cRunner, - clientConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnHandshakeComplete() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - sRunner, - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handshake(client, cChunkChan, server, sChunkChan) - close(done) - }() - Eventually(done).Should(BeClosed()) - - // inject an invalid session ticket - cRunner.EXPECT().OnError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake", - }) - b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) - client.HandleMessage(b, protocol.EncryptionHandshake) - }) - - It("errors when handling the NewSessionTicket fails", func() { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnHandshakeComplete() - client, _ := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{}, - cRunner, - clientConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.VersionTLS, - ) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnHandshakeComplete() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - sRunner, - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.VersionTLS, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handshake(client, cChunkChan, server, sChunkChan) - close(done) - }() - Eventually(done).Should(BeClosed()) - - // inject an invalid session ticket - cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) - }) - b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) - client.HandleMessage(b, protocol.Encryption1RTT) - }) - - It("uses session resumption", func() { - csc := mocktls.NewMockClientSessionCache(mockCtrl) - var state *tls.ClientSessionState - receivedSessionTicket := make(chan struct{}) - csc.EXPECT().Get(gomock.Any()) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { - state = css - close(receivedSessionTicket) - }) - clientConf.ClientSessionCache = csc - const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. - clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) - clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - clientOrigRTTStats, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeFalse()) - Expect(client.ConnectionState().DidResume).To(BeFalse()) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) - - csc.EXPECT().Get(gomock.Any()).Return(state, true) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) - clientRTTStats := &utils.RTTStats{} - clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( - clientConf, serverConf, - clientRTTStats, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeTrue()) - Expect(client.ConnectionState().DidResume).To(BeTrue()) - Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) - }) - - It("doesn't use session resumption if the server disabled it", func() { - csc := mocktls.NewMockClientSessionCache(mockCtrl) - var state *tls.ClientSessionState - receivedSessionTicket := make(chan struct{}) - csc.EXPECT().Get(gomock.Any()) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { - state = css - close(receivedSessionTicket) - }) - clientConf.ClientSessionCache = csc - _, client, clientErr, server, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - &utils.RTTStats{}, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeFalse()) - Expect(client.ConnectionState().DidResume).To(BeFalse()) - - serverConf.SessionTicketsDisabled = true - csc.EXPECT().Get(gomock.Any()).Return(state, true) - _, client, clientErr, server, serverErr = handshakeWithTLSConf( - clientConf, serverConf, - &utils.RTTStats{}, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{}, - false, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeFalse()) - Expect(client.ConnectionState().DidResume).To(BeFalse()) - }) - - It("uses 0-RTT", func() { - csc := mocktls.NewMockClientSessionCache(mockCtrl) - var state *tls.ClientSessionState - receivedSessionTicket := make(chan struct{}) - csc.EXPECT().Get(gomock.Any()) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { - state = css - close(receivedSessionTicket) - }) - clientConf.ClientSessionCache = csc - const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. - const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. - serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) - clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) - const initialMaxData protocol.ByteCount = 1337 - clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - clientOrigRTTStats, serverOrigRTTStats, - &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, - true, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeFalse()) - Expect(client.ConnectionState().DidResume).To(BeFalse()) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) - - csc.EXPECT().Get(gomock.Any()).Return(state, true) - csc.EXPECT().Put(gomock.Any(), nil) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) - - clientRTTStats := &utils.RTTStats{} - serverRTTStats := &utils.RTTStats{} - clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( - clientConf, serverConf, - clientRTTStats, serverRTTStats, - &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, - true, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) - Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) - - var tp *wire.TransportParameters - Expect(clientHelloWrittenChan).To(Receive(&tp)) - Expect(tp.InitialMaxData).To(Equal(initialMaxData)) - - Expect(server.ConnectionState().DidResume).To(BeTrue()) - Expect(client.ConnectionState().DidResume).To(BeTrue()) - Expect(server.ConnectionState().Used0RTT).To(BeTrue()) - Expect(client.ConnectionState().Used0RTT).To(BeTrue()) - }) - - It("rejects 0-RTT, when the transport parameters changed", func() { - csc := mocktls.NewMockClientSessionCache(mockCtrl) - var state *tls.ClientSessionState - receivedSessionTicket := make(chan struct{}) - csc.EXPECT().Get(gomock.Any()) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { - state = css - close(receivedSessionTicket) - }) - clientConf.ClientSessionCache = csc - const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. - clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) - const initialMaxData protocol.ByteCount = 1337 - clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( - clientConf, serverConf, - clientOrigRTTStats, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, - true, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) - Expect(server.ConnectionState().DidResume).To(BeFalse()) - Expect(client.ConnectionState().DidResume).To(BeFalse()) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) - - csc.EXPECT().Get(gomock.Any()).Return(state, true) - csc.EXPECT().Put(gomock.Any(), nil) - csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) - - clientRTTStats := &utils.RTTStats{} - clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( - clientConf, serverConf, - clientRTTStats, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData - 1}, - true, - ) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) - - var tp *wire.TransportParameters - Expect(clientHelloWrittenChan).To(Receive(&tp)) - Expect(tp.InitialMaxData).To(Equal(initialMaxData)) - - Expect(server.ConnectionState().DidResume).To(BeTrue()) - Expect(client.ConnectionState().DidResume).To(BeTrue()) - Expect(server.ConnectionState().Used0RTT).To(BeFalse()) - Expect(client.ConnectionState().Used0RTT).To(BeFalse()) - }) - }) - }) -}) diff --git a/internal/quic-go/handshake/handshake_suite_test.go b/internal/quic-go/handshake/handshake_suite_test.go deleted file mode 100644 index a464ba1c..00000000 --- a/internal/quic-go/handshake/handshake_suite_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package handshake - -import ( - "crypto/tls" - "encoding/hex" - "strings" - "testing" - - "github.com/imroc/req/v3/internal/quic-go/qtls" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestHandshake(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Handshake Suite") -} - -var mockCtrl *gomock.Controller - -var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) -}) - -var _ = AfterEach(func() { - mockCtrl.Finish() -}) - -func splitHexString(s string) (slice []byte) { - for _, ss := range strings.Split(s, " ") { - if ss[0:2] == "0x" { - ss = ss[2:] - } - d, err := hex.DecodeString(ss) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - slice = append(slice, d...) - } - return -} - -var cipherSuites = []*qtls.CipherSuiteTLS13{ - qtls.CipherSuiteTLS13ByID(tls.TLS_AES_128_GCM_SHA256), - qtls.CipherSuiteTLS13ByID(tls.TLS_AES_256_GCM_SHA384), - qtls.CipherSuiteTLS13ByID(tls.TLS_CHACHA20_POLY1305_SHA256), -} diff --git a/internal/quic-go/handshake/header_protector.go b/internal/quic-go/handshake/header_protector.go deleted file mode 100644 index c3e96f24..00000000 --- a/internal/quic-go/handshake/header_protector.go +++ /dev/null @@ -1,136 +0,0 @@ -package handshake - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/tls" - "encoding/binary" - "fmt" - - "golang.org/x/crypto/chacha20" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qtls" -) - -type headerProtector interface { - EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) - DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) -} - -func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string { - if v == protocol.Version2 { - return "quicv2 hp" - } - return "quic hp" -} - -func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector { - hkdfLabel := hkdfHeaderProtectionLabel(v) - switch suite.ID { - case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: - return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) - case tls.TLS_CHACHA20_POLY1305_SHA256: - return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) - default: - panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) - } -} - -type aesHeaderProtector struct { - mask []byte - block cipher.Block - isLongHeader bool -} - -var _ headerProtector = &aesHeaderProtector{} - -func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { - hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) - block, err := aes.NewCipher(hpKey) - if err != nil { - panic(fmt.Sprintf("error creating new AES cipher: %s", err)) - } - return &aesHeaderProtector{ - block: block, - mask: make([]byte, block.BlockSize()), - isLongHeader: isLongHeader, - } -} - -func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - p.apply(sample, firstByte, hdrBytes) -} - -func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - p.apply(sample, firstByte, hdrBytes) -} - -func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { - if len(sample) != len(p.mask) { - panic("invalid sample size") - } - p.block.Encrypt(p.mask, sample) - if p.isLongHeader { - *firstByte ^= p.mask[0] & 0xf - } else { - *firstByte ^= p.mask[0] & 0x1f - } - for i := range hdrBytes { - hdrBytes[i] ^= p.mask[i+1] - } -} - -type chachaHeaderProtector struct { - mask [5]byte - - key [32]byte - isLongHeader bool -} - -var _ headerProtector = &chachaHeaderProtector{} - -func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { - hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) - - p := &chachaHeaderProtector{ - isLongHeader: isLongHeader, - } - copy(p.key[:], hpKey) - return p -} - -func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - p.apply(sample, firstByte, hdrBytes) -} - -func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - p.apply(sample, firstByte, hdrBytes) -} - -func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { - if len(sample) != 16 { - panic("invalid sample size") - } - for i := 0; i < 5; i++ { - p.mask[i] = 0 - } - cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:]) - if err != nil { - panic(err) - } - cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4])) - cipher.XORKeyStream(p.mask[:], p.mask[:]) - p.applyMask(firstByte, hdrBytes) -} - -func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) { - if p.isLongHeader { - *firstByte ^= p.mask[0] & 0xf - } else { - *firstByte ^= p.mask[0] & 0x1f - } - for i := range hdrBytes { - hdrBytes[i] ^= p.mask[i+1] - } -} diff --git a/internal/quic-go/handshake/hkdf.go b/internal/quic-go/handshake/hkdf.go deleted file mode 100644 index c4fd86c5..00000000 --- a/internal/quic-go/handshake/hkdf.go +++ /dev/null @@ -1,29 +0,0 @@ -package handshake - -import ( - "crypto" - "encoding/binary" - - "golang.org/x/crypto/hkdf" -) - -// hkdfExpandLabel HKDF expands a label. -// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the -// hkdfExpandLabel in the standard library. -func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte { - b := make([]byte, 3, 3+6+len(label)+1+len(context)) - binary.BigEndian.PutUint16(b, uint16(length)) - b[2] = uint8(6 + len(label)) - b = append(b, []byte("tls13 ")...) - b = append(b, []byte(label)...) - b = b[:3+6+len(label)+1] - b[3+6+len(label)] = uint8(len(context)) - b = append(b, context...) - - out := make([]byte, length) - n, err := hkdf.Expand(hash.New, secret, b).Read(out) - if err != nil || n != length { - panic("quic: HKDF-Expand-Label invocation failed unexpectedly") - } - return out -} diff --git a/internal/quic-go/handshake/hkdf_test.go b/internal/quic-go/handshake/hkdf_test.go deleted file mode 100644 index 16154199..00000000 --- a/internal/quic-go/handshake/hkdf_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package handshake - -import ( - "crypto" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Initial AEAD using AES-GCM", func() { - // Result generated by running in qtls: - // cipherSuiteTLS13ByID(TLS_AES_128_GCM_SHA256).expandLabel([]byte("secret"), []byte("context"), "label", 42) - It("gets the same results as qtls", func() { - expanded := hkdfExpandLabel(crypto.SHA256, []byte("secret"), []byte("context"), "label", 42) - Expect(expanded).To(Equal([]byte{0x78, 0x87, 0x6a, 0xb5, 0x84, 0xa2, 0x26, 0xb7, 0x8, 0x5a, 0x7b, 0x3a, 0x4c, 0xbb, 0x1e, 0xbc, 0x2f, 0x9b, 0x67, 0xd0, 0x6a, 0xa2, 0x24, 0xb4, 0x7d, 0x29, 0x3c, 0x7a, 0xce, 0xc7, 0xc3, 0x74, 0xcd, 0x59, 0x7a, 0xa8, 0x21, 0x5e, 0xe7, 0xca, 0x1, 0xda})) - }) -}) diff --git a/internal/quic-go/handshake/initial_aead.go b/internal/quic-go/handshake/initial_aead.go deleted file mode 100644 index 8a579b20..00000000 --- a/internal/quic-go/handshake/initial_aead.go +++ /dev/null @@ -1,81 +0,0 @@ -package handshake - -import ( - "crypto" - "crypto/tls" - - "golang.org/x/crypto/hkdf" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qtls" -) - -var ( - quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99} - quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} - quicSaltV2 = []byte{0xa7, 0x07, 0xc2, 0x03, 0xa5, 0x9b, 0x47, 0x18, 0x4a, 0x1d, 0x62, 0xca, 0x57, 0x04, 0x06, 0xea, 0x7a, 0xe3, 0xe5, 0xd3} -) - -const ( - hkdfLabelKeyV1 = "quic key" - hkdfLabelKeyV2 = "quicv2 key" - hkdfLabelIVV1 = "quic iv" - hkdfLabelIVV2 = "quicv2 iv" -) - -func getSalt(v protocol.VersionNumber) []byte { - if v == protocol.Version2 { - return quicSaltV2 - } - if v == protocol.Version1 { - return quicSaltV1 - } - return quicSaltOld -} - -var initialSuite = &qtls.CipherSuiteTLS13{ - ID: tls.TLS_AES_128_GCM_SHA256, - KeyLen: 16, - AEAD: qtls.AEADAESGCMTLS13, - Hash: crypto.SHA256, -} - -// NewInitialAEAD creates a new AEAD for Initial encryption / decryption. -func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) { - clientSecret, serverSecret := computeSecrets(connID, v) - var mySecret, otherSecret []byte - if pers == protocol.PerspectiveClient { - mySecret = clientSecret - otherSecret = serverSecret - } else { - mySecret = serverSecret - otherSecret = clientSecret - } - myKey, myIV := computeInitialKeyAndIV(mySecret, v) - otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v) - - encrypter := qtls.AEADAESGCMTLS13(myKey, myIV) - decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) - - return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)), - newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v))) -} - -func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) { - initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(v)) - clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) - serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) - return -} - -func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) { - keyLabel := hkdfLabelKeyV1 - ivLabel := hkdfLabelIVV1 - if v == protocol.Version2 { - keyLabel = hkdfLabelKeyV2 - ivLabel = hkdfLabelIVV2 - } - key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16) - iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12) - return -} diff --git a/internal/quic-go/handshake/initial_aead_test.go b/internal/quic-go/handshake/initial_aead_test.go deleted file mode 100644 index 6a97eb70..00000000 --- a/internal/quic-go/handshake/initial_aead_test.go +++ /dev/null @@ -1,219 +0,0 @@ -package handshake - -import ( - "fmt" - "math/rand" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/ginkgo/extensions/table" - . "github.com/onsi/gomega" -) - -var _ = Describe("Initial AEAD using AES-GCM", func() { - It("converts the string representation used in the draft into byte slices", func() { - Expect(splitHexString("0xdeadbeef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - Expect(splitHexString("deadbeef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - Expect(splitHexString("dead beef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - }) - - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) - - DescribeTable("computes the client key and IV", - func(v protocol.VersionNumber, expectedClientSecret, expectedKey, expectedIV []byte) { - clientSecret, _ := computeSecrets(connID, v) - Expect(clientSecret).To(Equal(expectedClientSecret)) - key, iv := computeInitialKeyAndIV(clientSecret, v) - Expect(key).To(Equal(expectedKey)) - Expect(iv).To(Equal(expectedIV)) - }, - Entry("draft-29", - protocol.VersionDraft29, - splitHexString("0088119288f1d866733ceeed15ff9d50 902cf82952eee27e9d4d4918ea371d87"), - splitHexString("175257a31eb09dea9366d8bb79ad80ba"), - splitHexString("6b26114b9cba2b63a9e8dd4f"), - ), - Entry("QUIC v1", - protocol.Version1, - splitHexString("c00cf151ca5be075ed0ebfb5c80323c4 2d6b7db67881289af4008f1f6c357aea"), - splitHexString("1f369613dd76d5467730efcbe3b1a22d"), - splitHexString("fa044b2f42a3fd3b46fb255c"), - ), - Entry("QUIC v2", - protocol.Version2, - splitHexString("9fe72e1452e91f551b770005054034e4 7575d4a0fb4c27b7c6cb303a338423ae"), - splitHexString("95df2be2e8d549c82e996fc9339f4563"), - splitHexString("ea5e3c95f933db14b7020ad8"), - ), - ) - - DescribeTable("computes the server key and IV", - func(v protocol.VersionNumber, expectedServerSecret, expectedKey, expectedIV []byte) { - _, serverSecret := computeSecrets(connID, v) - Expect(serverSecret).To(Equal(expectedServerSecret)) - key, iv := computeInitialKeyAndIV(serverSecret, v) - Expect(key).To(Equal(expectedKey)) - Expect(iv).To(Equal(expectedIV)) - }, - Entry("draft 29", - protocol.VersionDraft29, - splitHexString("006f881359244dd9ad1acf85f595bad6 7c13f9f5586f5e64e1acae1d9ea8f616"), - splitHexString("149d0b1662ab871fbe63c49b5e655a5d"), - splitHexString("bab2b12a4c76016ace47856d"), - ), - Entry("QUIC v1", - protocol.Version1, - splitHexString("3c199828fd139efd216c155ad844cc81 fb82fa8d7446fa7d78be803acdda951b"), - splitHexString("cf3a5331653c364c88f0f379b6067e37"), - splitHexString("0ac1493ca1905853b0bba03e"), - ), - Entry("QUIC v2", - protocol.Version2, - splitHexString("3c9bf6a9c1c8c71819876967bd8b979e fd98ec665edf27f22c06e9845ba0ae2f"), - splitHexString("15d5b4d9a2b8916aa39b1bfe574d2aad"), - splitHexString("a85e7ac31cd275cbb095c626"), - ), - ) - - DescribeTable("encrypts the client's Initial", - func(v protocol.VersionNumber, header, data, expectedSample []byte, expectedHdrFirstByte byte, expectedHdr, expectedPacket []byte) { - sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveClient, v) - data = append(data, make([]byte, 1162-len(data))...) // add PADDING - sealed := sealer.Seal(nil, data, 2, header) - sample := sealed[0:16] - Expect(sample).To(Equal(expectedSample)) - sealer.EncryptHeader(sample, &header[0], header[len(header)-4:]) - Expect(header[0]).To(Equal(expectedHdrFirstByte)) - Expect(header[len(header)-4:]).To(Equal(expectedHdr)) - packet := append(header, sealed...) - Expect(packet).To(Equal(expectedPacket)) - }, - Entry("draft 29", - protocol.VersionDraft29, - splitHexString("c3ff00001d088394c8f03e5157080000449e00000002"), - splitHexString("060040c4010000c003036660261ff947 cea49cce6cfad687f457cf1b14531ba1 4131a0e8f309a1d0b9c4000006130113 031302010000910000000b0009000006 736572766572ff01000100000a001400 12001d00170018001901000101010201 03010400230000003300260024001d00 204cfdfcd178b784bf328cae793b136f 2aedce005ff183d7bb14952072366470 37002b0003020304000d0020001e0403 05030603020308040805080604010501 060102010402050206020202002d0002 0101001c00024001"), - splitHexString("fb66bc5f93032b7ddd89fe0ff15d9c4f"), - byte(0xc5), - splitHexString("4a95245b"), - splitHexString("c5ff00001d088394c8f03e5157080000 449e4a95245bfb66bc5f93032b7ddd89 fe0ff15d9c4f7050fccdb71c1cd80512 d4431643a53aafa1b0b518b44968b18b 8d3e7a4d04c30b3ed9410325b2abb2da fb1c12f8b70479eb8df98abcaf95dd8f 3d1c78660fbc719f88b23c8aef6771f3 d50e10fdfb4c9d92386d44481b6c52d5 9e5538d3d3942de9f13a7f8b702dc317 24180da9df22714d01003fc5e3d165c9 50e630b8540fbd81c9df0ee63f949970 26c4f2e1887a2def79050ac2d86ba318 e0b3adc4c5aa18bcf63c7cf8e85f5692 49813a2236a7e72269447cd1c755e451 f5e77470eb3de64c8849d29282069802 9cfa18e5d66176fe6e5ba4ed18026f90 900a5b4980e2f58e39151d5cd685b109 29636d4f02e7fad2a5a458249f5c0298 a6d53acbe41a7fc83fa7cc01973f7a74 d1237a51974e097636b6203997f921d0 7bc1940a6f2d0de9f5a11432946159ed 6cc21df65c4ddd1115f86427259a196c 7148b25b6478b0dc7766e1c4d1b1f515 9f90eabc61636226244642ee148b464c 9e619ee50a5e3ddc836227cad938987c 4ea3c1fa7c75bbf88d89e9ada642b2b8 8fe8107b7ea375b1b64889a4e9e5c38a 1c896ce275a5658d250e2d76e1ed3a34 ce7e3a3f383d0c996d0bed106c2899ca 6fc263ef0455e74bb6ac1640ea7bfedc 59f03fee0e1725ea150ff4d69a7660c5 542119c71de270ae7c3ecfd1af2c4ce5 51986949cc34a66b3e216bfe18b347e6 c05fd050f85912db303a8f054ec23e38 f44d1c725ab641ae929fecc8e3cefa56 19df4231f5b4c009fa0c0bbc60bc75f7 6d06ef154fc8577077d9d6a1d2bd9bf0 81dc783ece60111bea7da9e5a9748069 d078b2bef48de04cabe3755b197d52b3 2046949ecaa310274b4aac0d008b1948 c1082cdfe2083e386d4fd84c0ed0666d 3ee26c4515c4fee73433ac703b690a9f 7bf278a77486ace44c489a0c7ac8dfe4 d1a58fb3a730b993ff0f0d61b4d89557 831eb4c752ffd39c10f6b9f46d8db278 da624fd800e4af85548a294c1518893a 8778c4f6d6d73c93df200960104e062b 388ea97dcf4016bced7f62b4f062cb6c 04c20693d9a0e3b74ba8fe74cc012378 84f40d765ae56a51688d985cf0ceaef4 3045ed8c3f0c33bced08537f6882613a cd3b08d665fce9dd8aa73171e2d3771a 61dba2790e491d413d93d987e2745af2 9418e428be34941485c93447520ffe23 1da2304d6a0fd5d07d08372202369661 59bef3cf904d722324dd852513df39ae 030d8173908da6364786d3c1bfcb19ea 77a63b25f1e7fc661def480c5d00d444 56269ebd84efd8e3a8b2c257eec76060 682848cbf5194bc99e49ee75e4d0d254 bad4bfd74970c30e44b65511d4ad0e6e c7398e08e01307eeeea14e46ccd87cf3 6b285221254d8fc6a6765c524ded0085 dca5bd688ddf722e2c0faf9d0fb2ce7a 0c3f2cee19ca0ffba461ca8dc5d2c817 8b0762cf67135558494d2a96f1a139f0 edb42d2af89a9c9122b07acbc29e5e72 2df8615c343702491098478a389c9872 a10b0c9875125e257c7bfdf27eef4060 bd3d00f4c14fd3e3496c38d3c5d1a566 8c39350effbc2d16ca17be4ce29f02ed 969504dda2a8c6b9ff919e693ee79e09 089316e7d1d89ec099db3b2b268725d8 88536a4b8bf9aee8fb43e82a4d919d48 43b1ca70a2d8d3f725ead1391377dcc0"), - ), - Entry("QUIC v1", - protocol.Version1, - splitHexString("c300000001088394c8f03e5157080000449e00000002"), - splitHexString("060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), - splitHexString("d1b1c98dd7689fb8ec11d242b123dc9b"), - byte(0xc0), - splitHexString("7b9aec34"), - splitHexString("c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11 d242b123dc9bd8bab936b47d92ec356c 0bab7df5976d27cd449f63300099f399 1c260ec4c60d17b31f8429157bb35a12 82a643a8d2262cad67500cadb8e7378c 8eb7539ec4d4905fed1bee1fc8aafba1 7c750e2c7ace01e6005f80fcb7df6212 30c83711b39343fa028cea7f7fb5ff89 eac2308249a02252155e2347b63d58c5 457afd84d05dfffdb20392844ae81215 4682e9cf012f9021a6f0be17ddd0c208 4dce25ff9b06cde535d0f920a2db1bf3 62c23e596d11a4f5a6cf3948838a3aec 4e15daf8500a6ef69ec4e3feb6b1d98e 610ac8b7ec3faf6ad760b7bad1db4ba3 485e8a94dc250ae3fdb41ed15fb6a8e5 eba0fc3dd60bc8e30c5c4287e53805db 059ae0648db2f64264ed5e39be2e20d8 2df566da8dd5998ccabdae053060ae6c 7b4378e846d29f37ed7b4ea9ec5d82e7 961b7f25a9323851f681d582363aa5f8 9937f5a67258bf63ad6f1a0b1d96dbd4 faddfcefc5266ba6611722395c906556 be52afe3f565636ad1b17d508b73d874 3eeb524be22b3dcbc2c7468d54119c74 68449a13d8e3b95811a198f3491de3e7 fe942b330407abf82a4ed7c1b311663a c69890f4157015853d91e923037c227a 33cdd5ec281ca3f79c44546b9d90ca00 f064c99e3dd97911d39fe9c5d0b23a22 9a234cb36186c4819e8b9c5927726632 291d6a418211cc2962e20fe47feb3edf 330f2c603a9d48c0fcb5699dbfe58964 25c5bac4aee82e57a85aaf4e2513e4f0 5796b07ba2ee47d80506f8d2c25e50fd 14de71e6c418559302f939b0e1abd576 f279c4b2e0feb85c1f28ff18f58891ff ef132eef2fa09346aee33c28eb130ff2 8f5b766953334113211996d20011a198 e3fc433f9f2541010ae17c1bf202580f 6047472fb36857fe843b19f5984009dd c324044e847a4f4a0ab34f719595de37 252d6235365e9b84392b061085349d73 203a4a13e96f5432ec0fd4a1ee65accd d5e3904df54c1da510b0ff20dcc0c77f cb2c0e0eb605cb0504db87632cf3d8b4 dae6e705769d1de354270123cb11450e fc60ac47683d7b8d0f811365565fd98c 4c8eb936bcab8d069fc33bd801b03ade a2e1fbc5aa463d08ca19896d2bf59a07 1b851e6c239052172f296bfb5e724047 90a2181014f3b94a4e97d117b4381303 68cc39dbb2d198065ae3986547926cd2 162f40a29f0c3c8745c0f50fba3852e5 66d44575c29d39a03f0cda721984b6f4 40591f355e12d439ff150aab7613499d bd49adabc8676eef023b15b65bfc5ca0 6948109f23f350db82123535eb8a7433 bdabcb909271a6ecbcb58b936a88cd4e 8f2e6ff5800175f113253d8fa9ca8885 c2f552e657dc603f252e1a8e308f76f0 be79e2fb8f5d5fbbe2e30ecadd220723 c8c0aea8078cdfcb3868263ff8f09400 54da48781893a7e49ad5aff4af300cd8 04a6b6279ab3ff3afb64491c85194aab 760d58a606654f9f4400e8b38591356f bf6425aca26dc85244259ff2b19c41b9 f96f3ca9ec1dde434da7d2d392b905dd f3d1f9af93d1af5950bd493f5aa731b4 056df31bd267b6b90a079831aaf579be 0a39013137aac6d404f518cfd4684064 7e78bfe706ca4cf5e9c5453e9f7cfd2b 8b4c8d169a44e55c88d4a9a7f9474241 e221af44860018ab0856972e194cd934"), - ), - Entry("QUIC v2", - protocol.Version2, - splitHexString("d3709a50c4088394c8f03e5157080000449e00000002"), - splitHexString("060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), - splitHexString("23b8e610589c83c92d0e97eb7a6e5003"), - byte(0xdd), - splitHexString("4391d848"), - splitHexString("dd709a50c4088394c8f03e5157080000 449e4391d84823b8e610589c83c92d0e 97eb7a6e5003f57764c5c7f0095ba54b 90818f1bfeecc1c97c54fc731edbd2a2 44e3b1e639a9bc75ed545b98649343b2 53615ec6b3e4df0fd2e7fe9d691a09e6 a144b436d8a2c088a404262340dfd995 ec3865694e3026ecd8c6d2561a5a3667 2a1005018168c0f081c10e2bf14d550c 977e28bb9a759c57d0f7ffb1cdfb40bd 774dec589657542047dffefa56fc8089 a4d1ef379c81ba3df71a05ddc7928340 775910feb3ce4cbcfd8d253edd05f161 458f9dc44bea017c3117cca7065a315d eda9464e672ec80c3f79ac993437b441 ef74227ecc4dc9d597f66ab0ab8d214b 55840c70349d7616cbe38e5e1d052d07 f1fedb3dd3c4d8ce295724945e67ed2e efcd9fb52472387f318e3d9d233be7df c79d6bf6080dcbbb41feb180d7858849 7c3e439d38c334748d2b56fd19ab364d 057a9bd5a699ae145d7fdbc8f5777518 1b0a97c3bdedc91a555d6c9b8634e106 d8c9ca45a9d5450a7679edc545da9102 5bc93a7cf9a023a066ffadb9717ffaf3 414c3b646b5738b3cc4116502d18d79d 8227436306d9b2b3afc6c785ce3c817f eb703a42b9c83b59f0dcef1245d0b3e4 0299821ec19549ce489714fe2611e72c d882f4f70dce7d3671296fc045af5c9f 630d7b49a3eb821bbca60f1984dce664 91713bfe06001a56f51bb3abe92f7960 547c4d0a70f4a962b3f05dc25a34bbe8 30a7ea4736d3b0161723500d82beda9b e3327af2aa413821ff678b2a876ec4b0 0bb605ffcc3917ffdc279f187daa2fce 8cde121980bba8ec8f44ca562b0f1319 14c901cfbd847408b778e6738c7bb5b1 b3f97d01b0a24dcca40e3bed29411b1b a8f60843c4a241021b23132b9500509b 9a3516d4a9dd41d3bacbcd426b451393 521828afedcf20fa46ac24f44a8e2973 30b16705d5d5f798eff9e9134a065979 87a1db4617caa2d93837730829d4d89e 16413be4d8a8a38a7e6226623b64a820 178ec3a66954e10710e043ae73dd3fb2 715a0525a46343fb7590e5eac7ee55fc 810e0d8b4b8f7be82cd5a214575a1b99 629d47a9b281b61348c8627cab38e2a6 4db6626e97bb8f77bdcb0fee476aedd7 ba8f5441acaab00f4432edab3791047d 9091b2a753f035648431f6d12f7d6a68 1e64c861f4ac911a0f7d6ec0491a78c9 f192f96b3a5e7560a3f056bc1ca85983 67ad6acb6f2e034c7f37beeb9ed470c4 304af0107f0eb919be36a86f68f37fa6 1dae7aff14decd67ec3157a11488a14f ed0142828348f5f608b0fe03e1f3c0af 3acca0ce36852ed42e220ae9abf8f890 6f00f1b86bff8504c8f16c784fd52d25 e013ff4fda903e9e1eb453c1464b1196 6db9b28e8f26a3fc419e6a60a48d4c72 14ee9c6c6a12b68a32cac8f61580c64f 29cb6922408783c6d12e725b014fe485 cd17e484c5952bf99bc94941d4b1919d 04317b8aa1bd3754ecbaa10ec227de85 40695bf2fb8ee56f6dc526ef366625b9 1aa4970b6ffa5c8284b9b5ab852b905f 9d83f5669c0535bc377bcc05ad5e48e2 81ec0e1917ca3c6a471f8da0894bc82a c2a8965405d6eef3b5e293a88fda203f 09bdc72757b107ab14880eaa3ef7045b 580f4821ce6dd325b5a90655d8c5b55f 76fb846279a9b518c5e9b9a21165c509 3ed49baaacadf1f21873266c767f6769"), - ), - ) - - DescribeTable("encrypts the server's Initial", - func(v protocol.VersionNumber, header, data, expectedSample, expectedHdr, expectedPacket []byte) { - sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveServer, v) - sealed := sealer.Seal(nil, data, 1, header) - sample := sealed[2 : 2+16] - Expect(sample).To(Equal(expectedSample)) - sealer.EncryptHeader(sample, &header[0], header[len(header)-2:]) - Expect(header).To(Equal(expectedHdr)) - packet := append(header, sealed...) - Expect(packet).To(Equal(expectedPacket)) - }, - Entry("draft 29", - protocol.VersionDraft29, - splitHexString("c1ff00001d0008f067a5502a4262b50040740001"), - splitHexString("0d0000000018410a020000560303eefc e7f7b37ba1d1632e96677825ddf73988 cfc79825df566dc5430b9a045a120013 0100002e00330024001d00209d3c940d 89690b84d08a60993c144eca684d1081 287c834d5311bcf32bb9da1a002b0002 0304"), - splitHexString("823a5d3a1207c86ee49132824f046524"), - splitHexString("caff00001d0008f067a5502a4262b5004074aaf2"), - splitHexString("caff00001d0008f067a5502a4262b500 4074aaf2f007823a5d3a1207c86ee491 32824f0465243d082d868b107a38092b c80528664cbf9456ebf27673fb5fa506 1ab573c9f001b81da028a00d52ab00b1 5bebaa70640e106cf2acd043e9c6b441 1c0a79637134d8993701fe779e58c2fe 753d14b0564021565ea92e57bc6faf56 dfc7a40870e6"), - ), - Entry("QUIC v1", - protocol.Version1, - splitHexString("c1000000010008f067a5502a4262b50040750001"), - splitHexString("02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), - splitHexString("2cd0991cd25b0aac406a5816b6394100"), - splitHexString("cf000000010008f067a5502a4262b5004075c0d9"), - splitHexString("cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a 5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3 dbcba3f6ea46c5b7684df3548e7ddeb9 c3bf9c73cc3f3bded74b562bfb19fb84 022f8ef4cdd93795d77d06edbb7aaf2f 58891850abbdca3d20398c276456cbc4 2158407dd074ee"), - ), - Entry("QUIC v2", - protocol.Version2, - splitHexString("d1709a50c40008f067a5502a4262b50040750001"), - splitHexString("02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), - splitHexString("ebb7972fdce59d50e7e49ff2a7e8de76"), - splitHexString("d0709a50c40008f067a5502a4262b5004075103e"), - splitHexString("d0709a50c40008f067a5502a4262b500 4075103e63b4ebb7972fdce59d50e7e4 9ff2a7e8de76b0cd8c10100a1f13d549 dd6fe801588fb14d279bef8d7c53ef62 66a9a7a1a5f2fa026c236a5bf8df5aa0 f9d74773aeccfffe910b0f76814b5e33 f7b7f8ec278d23fd8c7a9e66856b8bbe 72558135bca27c54d63fcc902253461c fc089d4e6b9b19"), - ), - ) - - for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { - v := ver - - Context(fmt.Sprintf("using version %s", v), func() { - It("seals and opens", func() { - connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef} - clientSealer, clientOpener := NewInitialAEAD(connectionID, protocol.PerspectiveClient, v) - serverSealer, serverOpener := NewInitialAEAD(connectionID, protocol.PerspectiveServer, v) - - clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) - m, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad")) - Expect(err).ToNot(HaveOccurred()) - Expect(m).To(Equal([]byte("foobar"))) - serverMessage := serverSealer.Seal(nil, []byte("raboof"), 99, []byte("daa")) - m, err = clientOpener.Open(nil, serverMessage, 99, []byte("daa")) - Expect(err).ToNot(HaveOccurred()) - Expect(m).To(Equal([]byte("raboof"))) - }) - - It("doesn't work if initialized with different connection IDs", func() { - c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1} - c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2} - clientSealer, _ := NewInitialAEAD(c1, protocol.PerspectiveClient, v) - _, serverOpener := NewInitialAEAD(c2, protocol.PerspectiveServer, v) - - clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) - _, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("encrypts und decrypts the header", func() { - connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} - clientSealer, clientOpener := NewInitialAEAD(connID, protocol.PerspectiveClient, v) - serverSealer, serverOpener := NewInitialAEAD(connID, protocol.PerspectiveServer, v) - - // the first byte and the last 4 bytes should be encrypted - header := []byte{0x5e, 0, 1, 2, 3, 4, 0xde, 0xad, 0xbe, 0xef} - sample := make([]byte, 16) - rand.Read(sample) - clientSealer.EncryptHeader(sample, &header[0], header[6:10]) - // only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified - Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0))) - Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) - Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - serverOpener.DecryptHeader(sample, &header[0], header[6:10]) - Expect(header[0]).To(Equal(byte(0x5e))) - Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) - Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - - serverSealer.EncryptHeader(sample, &header[0], header[6:10]) - // only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified - Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0))) - Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) - Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - clientOpener.DecryptHeader(sample, &header[0], header[6:10]) - Expect(header[0]).To(Equal(byte(0x5e))) - Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) - Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - }) - }) - } -}) diff --git a/internal/quic-go/handshake/interface.go b/internal/quic-go/handshake/interface.go deleted file mode 100644 index 43ed0236..00000000 --- a/internal/quic-go/handshake/interface.go +++ /dev/null @@ -1,102 +0,0 @@ -package handshake - -import ( - "errors" - "io" - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qtls" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -var ( - // ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level, - // but the corresponding opener has not yet been initialized - // This can happen when packets arrive out of order. - ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available") - // ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level, - // but the corresponding keys have already been dropped. - ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped") - // ErrDecryptionFailed is returned when the AEAD fails to open the packet. - ErrDecryptionFailed = errors.New("decryption failed") -) - -// ConnectionState contains information about the state of the connection. -type ConnectionState = qtls.ConnectionState - -type headerDecryptor interface { - DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) -} - -// LongHeaderOpener opens a long header packet -type LongHeaderOpener interface { - headerDecryptor - DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber - Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error) -} - -// ShortHeaderOpener opens a short header packet -type ShortHeaderOpener interface { - headerDecryptor - DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber - Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error) -} - -// LongHeaderSealer seals a long header packet -type LongHeaderSealer interface { - Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte - EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) - Overhead() int -} - -// ShortHeaderSealer seals a short header packet -type ShortHeaderSealer interface { - LongHeaderSealer - KeyPhase() protocol.KeyPhaseBit -} - -// A tlsExtensionHandler sends and received the QUIC TLS extension. -type tlsExtensionHandler interface { - GetExtensions(msgType uint8) []qtls.Extension - ReceivedExtensions(msgType uint8, exts []qtls.Extension) - TransportParameters() <-chan []byte -} - -type handshakeRunner interface { - OnReceivedParams(*wire.TransportParameters) - OnHandshakeComplete() - OnError(error) - DropKeys(protocol.EncryptionLevel) -} - -// CryptoSetup handles the handshake and protecting / unprotecting packets -type CryptoSetup interface { - RunHandshake() - io.Closer - ChangeConnectionID(protocol.ConnectionID) - GetSessionTicket() ([]byte, error) - - HandleMessage([]byte, protocol.EncryptionLevel) bool - SetLargest1RTTAcked(protocol.PacketNumber) error - SetHandshakeConfirmed() - ConnectionState() ConnectionState - - GetInitialOpener() (LongHeaderOpener, error) - GetHandshakeOpener() (LongHeaderOpener, error) - Get0RTTOpener() (LongHeaderOpener, error) - Get1RTTOpener() (ShortHeaderOpener, error) - - GetInitialSealer() (LongHeaderSealer, error) - GetHandshakeSealer() (LongHeaderSealer, error) - Get0RTTSealer() (LongHeaderSealer, error) - Get1RTTSealer() (ShortHeaderSealer, error) -} - -// ConnWithVersion is the connection used in the ClientHelloInfo. -// It can be used to determine the QUIC version in use. -type ConnWithVersion interface { - net.Conn - GetQUICVersion() protocol.VersionNumber -} diff --git a/internal/quic-go/handshake/mock_handshake_runner_test.go b/internal/quic-go/handshake/mock_handshake_runner_test.go deleted file mode 100644 index eae6a898..00000000 --- a/internal/quic-go/handshake/mock_handshake_runner_test.go +++ /dev/null @@ -1,84 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: interface.go - -// Package handshake is a generated GoMock package. -package handshake - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockHandshakeRunner is a mock of HandshakeRunner interface. -type MockHandshakeRunner struct { - ctrl *gomock.Controller - recorder *MockHandshakeRunnerMockRecorder -} - -// MockHandshakeRunnerMockRecorder is the mock recorder for MockHandshakeRunner. -type MockHandshakeRunnerMockRecorder struct { - mock *MockHandshakeRunner -} - -// NewMockHandshakeRunner creates a new mock instance. -func NewMockHandshakeRunner(ctrl *gomock.Controller) *MockHandshakeRunner { - mock := &MockHandshakeRunner{ctrl: ctrl} - mock.recorder = &MockHandshakeRunnerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockHandshakeRunner) EXPECT() *MockHandshakeRunnerMockRecorder { - return m.recorder -} - -// DropKeys mocks base method. -func (m *MockHandshakeRunner) DropKeys(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DropKeys", arg0) -} - -// DropKeys indicates an expected call of DropKeys. -func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0) -} - -// OnError mocks base method. -func (m *MockHandshakeRunner) OnError(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnError", arg0) -} - -// OnError indicates an expected call of OnError. -func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0) -} - -// OnHandshakeComplete mocks base method. -func (m *MockHandshakeRunner) OnHandshakeComplete() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnHandshakeComplete") -} - -// OnHandshakeComplete indicates an expected call of OnHandshakeComplete. -func (mr *MockHandshakeRunnerMockRecorder) OnHandshakeComplete() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnHandshakeComplete", reflect.TypeOf((*MockHandshakeRunner)(nil).OnHandshakeComplete)) -} - -// OnReceivedParams mocks base method. -func (m *MockHandshakeRunner) OnReceivedParams(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnReceivedParams", arg0) -} - -// OnReceivedParams indicates an expected call of OnReceivedParams. -func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0) -} diff --git a/internal/quic-go/handshake/mockgen.go b/internal/quic-go/handshake/mockgen.go deleted file mode 100644 index b5534225..00000000 --- a/internal/quic-go/handshake/mockgen.go +++ /dev/null @@ -1,3 +0,0 @@ -package handshake - -//go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/imroc/req/v3/internal/quic-go/handshake handshakeRunner" diff --git a/internal/quic-go/handshake/retry.go b/internal/quic-go/handshake/retry.go deleted file mode 100644 index a9906086..00000000 --- a/internal/quic-go/handshake/retry.go +++ /dev/null @@ -1,62 +0,0 @@ -package handshake - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "fmt" - "sync" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -var ( - oldRetryAEAD cipher.AEAD // used for QUIC draft versions up to 34 - retryAEAD cipher.AEAD // used for QUIC draft-34 -) - -func init() { - oldRetryAEAD = initAEAD([16]byte{0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1}) - retryAEAD = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}) -} - -func initAEAD(key [16]byte) cipher.AEAD { - aes, err := aes.NewCipher(key[:]) - if err != nil { - panic(err) - } - aead, err := cipher.NewGCM(aes) - if err != nil { - panic(err) - } - return aead -} - -var ( - retryBuf bytes.Buffer - retryMutex sync.Mutex - oldRetryNonce = [12]byte{0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c} - retryNonce = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb} -) - -// GetRetryIntegrityTag calculates the integrity tag on a Retry packet -func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte { - retryMutex.Lock() - retryBuf.WriteByte(uint8(origDestConnID.Len())) - retryBuf.Write(origDestConnID.Bytes()) - retryBuf.Write(retry) - - var tag [16]byte - var sealed []byte - if version != protocol.Version1 { - sealed = oldRetryAEAD.Seal(tag[:0], oldRetryNonce[:], nil, retryBuf.Bytes()) - } else { - sealed = retryAEAD.Seal(tag[:0], retryNonce[:], nil, retryBuf.Bytes()) - } - if len(sealed) != 16 { - panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed))) - } - retryBuf.Reset() - retryMutex.Unlock() - return &tag -} diff --git a/internal/quic-go/handshake/retry_test.go b/internal/quic-go/handshake/retry_test.go deleted file mode 100644 index fdb3ff75..00000000 --- a/internal/quic-go/handshake/retry_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package handshake - -import ( - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Retry Integrity Check", func() { - It("calculates retry integrity tags", func() { - fooTag := GetRetryIntegrityTag([]byte("foo"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) - barTag := GetRetryIntegrityTag([]byte("bar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) - Expect(fooTag).ToNot(BeNil()) - Expect(barTag).ToNot(BeNil()) - Expect(*fooTag).ToNot(Equal(*barTag)) - }) - - It("includes the original connection ID in the tag calculation", func() { - t1 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.Version1) - t2 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{4, 3, 2, 1}, protocol.Version1) - Expect(*t1).ToNot(Equal(*t2)) - }) - - It("uses the test vector from the draft, for old draft versions", func() { - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) - data := splitHexString("ffff00001d0008f067a5502a4262b574 6f6b656ed16926d81f6f9ca2953a8aa4 575e1e49") - Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.VersionDraft29)[:]).To(Equal(data[len(data)-16:])) - }) - - It("uses the test vector from the draft, for version 1", func() { - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) - data := splitHexString("ff000000010008f067a5502a4262b574 6f6b656e04a265ba2eff4d829058fb3f 0f2496ba") - Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.Version1)[:]).To(Equal(data[len(data)-16:])) - }) -}) diff --git a/internal/quic-go/handshake/session_ticket.go b/internal/quic-go/handshake/session_ticket.go deleted file mode 100644 index afefe8a7..00000000 --- a/internal/quic-go/handshake/session_ticket.go +++ /dev/null @@ -1,48 +0,0 @@ -package handshake - -import ( - "bytes" - "errors" - "fmt" - "time" - - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -const sessionTicketRevision = 2 - -type sessionTicket struct { - Parameters *wire.TransportParameters - RTT time.Duration // to be encoded in mus -} - -func (t *sessionTicket) Marshal() []byte { - b := &bytes.Buffer{} - quicvarint.Write(b, sessionTicketRevision) - quicvarint.Write(b, uint64(t.RTT.Microseconds())) - t.Parameters.MarshalForSessionTicket(b) - return b.Bytes() -} - -func (t *sessionTicket) Unmarshal(b []byte) error { - r := bytes.NewReader(b) - rev, err := quicvarint.Read(r) - if err != nil { - return errors.New("failed to read session ticket revision") - } - if rev != sessionTicketRevision { - return fmt.Errorf("unknown session ticket revision: %d", rev) - } - rtt, err := quicvarint.Read(r) - if err != nil { - return errors.New("failed to read RTT") - } - var tp wire.TransportParameters - if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) - } - t.Parameters = &tp - t.RTT = time.Duration(rtt) * time.Microsecond - return nil -} diff --git a/internal/quic-go/handshake/session_ticket_test.go b/internal/quic-go/handshake/session_ticket_test.go deleted file mode 100644 index 832def9d..00000000 --- a/internal/quic-go/handshake/session_ticket_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package handshake - -import ( - "bytes" - "time" - - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Session Ticket", func() { - It("marshals and unmarshals a session ticket", func() { - ticket := &sessionTicket{ - Parameters: &wire.TransportParameters{ - InitialMaxStreamDataBidiLocal: 1, - InitialMaxStreamDataBidiRemote: 2, - }, - RTT: 1337 * time.Microsecond, - } - var t sessionTicket - Expect(t.Unmarshal(ticket.Marshal())).To(Succeed()) - Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1)) - Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2)) - Expect(t.RTT).To(Equal(1337 * time.Microsecond)) - }) - - It("refuses to unmarshal if the ticket is too short for the revision", func() { - Expect((&sessionTicket{}).Unmarshal([]byte{})).To(MatchError("failed to read session ticket revision")) - }) - - It("refuses to unmarshal if the revision doesn't match", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, 1337) - Expect((&sessionTicket{}).Unmarshal(b.Bytes())).To(MatchError("unknown session ticket revision: 1337")) - }) - - It("refuses to unmarshal if the RTT cannot be read", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, sessionTicketRevision) - Expect((&sessionTicket{}).Unmarshal(b.Bytes())).To(MatchError("failed to read RTT")) - }) - - It("refuses to unmarshal if unmarshaling the transport parameters fails", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, sessionTicketRevision) - b.Write([]byte("foobar")) - err := (&sessionTicket{}).Unmarshal(b.Bytes()) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("unmarshaling transport parameters from session ticket failed")) - }) -}) diff --git a/internal/quic-go/handshake/tls_extension_handler.go b/internal/quic-go/handshake/tls_extension_handler.go deleted file mode 100644 index 245f27c8..00000000 --- a/internal/quic-go/handshake/tls_extension_handler.go +++ /dev/null @@ -1,68 +0,0 @@ -package handshake - -import ( - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qtls" -) - -const ( - quicTLSExtensionTypeOldDrafts = 0xffa5 - quicTLSExtensionType = 0x39 -) - -type extensionHandler struct { - ourParams []byte - paramsChan chan []byte - - extensionType uint16 - - perspective protocol.Perspective -} - -var _ tlsExtensionHandler = &extensionHandler{} - -// newExtensionHandler creates a new extension handler -func newExtensionHandler(params []byte, pers protocol.Perspective, v protocol.VersionNumber) tlsExtensionHandler { - et := uint16(quicTLSExtensionType) - if v != protocol.Version1 { - et = quicTLSExtensionTypeOldDrafts - } - return &extensionHandler{ - ourParams: params, - paramsChan: make(chan []byte), - perspective: pers, - extensionType: et, - } -} - -func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension { - if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) || - (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) { - return nil - } - return []qtls.Extension{{ - Type: h.extensionType, - Data: h.ourParams, - }} -} - -func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) { - if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) || - (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) { - return - } - - var data []byte - for _, ext := range exts { - if ext.Type == h.extensionType { - data = ext.Data - break - } - } - - h.paramsChan <- data -} - -func (h *extensionHandler) TransportParameters() <-chan []byte { - return h.paramsChan -} diff --git a/internal/quic-go/handshake/tls_extension_handler_test.go b/internal/quic-go/handshake/tls_extension_handler_test.go deleted file mode 100644 index 4fcd48c1..00000000 --- a/internal/quic-go/handshake/tls_extension_handler_test.go +++ /dev/null @@ -1,210 +0,0 @@ -package handshake - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qtls" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("TLS Extension Handler, for the server", func() { - var ( - handlerServer tlsExtensionHandler - handlerClient tlsExtensionHandler - version protocol.VersionNumber - ) - - BeforeEach(func() { - version = protocol.VersionDraft29 - }) - - JustBeforeEach(func() { - handlerServer = newExtensionHandler( - []byte("foobar"), - protocol.PerspectiveServer, - version, - ) - handlerClient = newExtensionHandler( - []byte("raboof"), - protocol.PerspectiveClient, - version, - ) - }) - - Context("for the server", func() { - for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1} { - v := ver - - Context(fmt.Sprintf("sending, for version %s", v), func() { - var extensionType uint16 - - BeforeEach(func() { - version = v - if v == protocol.VersionDraft29 { - extensionType = quicTLSExtensionTypeOldDrafts - } else { - extensionType = quicTLSExtensionType - } - }) - - It("only adds TransportParameters for the Encrypted Extensions", func() { - // test 2 other handshake types - Expect(handlerServer.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) - Expect(handlerServer.GetExtensions(uint8(typeFinished))).To(BeEmpty()) - }) - - It("adds TransportParameters to the EncryptedExtensions message", func() { - exts := handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) - Expect(exts).To(HaveLen(1)) - Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) - Expect(exts[0].Data).To(Equal([]byte("foobar"))) - }) - }) - } - - Context("receiving", func() { - var chExts []qtls.Extension - - JustBeforeEach(func() { - chExts = handlerClient.GetExtensions(uint8(typeClientHello)) - Expect(chExts).To(HaveLen(1)) - }) - - It("sends the extension on the channel", func() { - go func() { - defer GinkgoRecover() - handlerServer.ReceivedExtensions(uint8(typeClientHello), chExts) - }() - - var data []byte - Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) - Expect(data).To(Equal([]byte("raboof"))) - }) - - It("sends nil on the channel if the extension is missing", func() { - go func() { - defer GinkgoRecover() - handlerServer.ReceivedExtensions(uint8(typeClientHello), nil) - }() - - var data []byte - Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions with different code points", func() { - go func() { - defer GinkgoRecover() - exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} - handlerServer.ReceivedExtensions(uint8(typeClientHello), exts) - }() - - var data []byte - Eventually(handlerServer.TransportParameters()).Should(Receive()) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions that are not sent with the ClientHello", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handlerServer.ReceivedExtensions(uint8(typeFinished), chExts) - close(done) - }() - - Consistently(handlerServer.TransportParameters()).ShouldNot(Receive()) - Eventually(done).Should(BeClosed()) - }) - }) - }) - - Context("for the client", func() { - for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1} { - v := ver - - Context(fmt.Sprintf("sending, for version %s", v), func() { - var extensionType uint16 - - BeforeEach(func() { - version = v - if v == protocol.VersionDraft29 { - extensionType = quicTLSExtensionTypeOldDrafts - } else { - extensionType = quicTLSExtensionType - } - }) - - It("only adds TransportParameters for the Encrypted Extensions", func() { - // test 2 other handshake types - Expect(handlerClient.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) - Expect(handlerClient.GetExtensions(uint8(typeFinished))).To(BeEmpty()) - }) - - It("adds TransportParameters to the ClientHello message", func() { - exts := handlerClient.GetExtensions(uint8(typeClientHello)) - Expect(exts).To(HaveLen(1)) - Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) - Expect(exts[0].Data).To(Equal([]byte("raboof"))) - }) - }) - } - - Context("receiving", func() { - var chExts []qtls.Extension - - JustBeforeEach(func() { - chExts = handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) - Expect(chExts).To(HaveLen(1)) - }) - - It("sends the extension on the channel", func() { - go func() { - defer GinkgoRecover() - handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), chExts) - }() - - var data []byte - Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) - Expect(data).To(Equal([]byte("foobar"))) - }) - - It("sends nil on the channel if the extension is missing", func() { - go func() { - defer GinkgoRecover() - handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), nil) - }() - - var data []byte - Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions with different code points", func() { - go func() { - defer GinkgoRecover() - exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} - handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), exts) - }() - - var data []byte - Eventually(handlerClient.TransportParameters()).Should(Receive()) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions that are not sent with the EncryptedExtensions", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handlerClient.ReceivedExtensions(uint8(typeFinished), chExts) - close(done) - }() - - Consistently(handlerClient.TransportParameters()).ShouldNot(Receive()) - Eventually(done).Should(BeClosed()) - }) - }) - }) -}) diff --git a/internal/quic-go/handshake/token_generator.go b/internal/quic-go/handshake/token_generator.go deleted file mode 100644 index 3dcfa090..00000000 --- a/internal/quic-go/handshake/token_generator.go +++ /dev/null @@ -1,134 +0,0 @@ -package handshake - -import ( - "encoding/asn1" - "fmt" - "io" - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -const ( - tokenPrefixIP byte = iota - tokenPrefixString -) - -// A Token is derived from the client address and can be used to verify the ownership of this address. -type Token struct { - IsRetryToken bool - RemoteAddr string - SentTime time.Time - // only set for retry tokens - OriginalDestConnectionID protocol.ConnectionID - RetrySrcConnectionID protocol.ConnectionID -} - -// token is the struct that is used for ASN1 serialization and deserialization -type token struct { - IsRetryToken bool - RemoteAddr []byte - Timestamp int64 - OriginalDestConnectionID []byte - RetrySrcConnectionID []byte -} - -// A TokenGenerator generates tokens -type TokenGenerator struct { - tokenProtector tokenProtector -} - -// NewTokenGenerator initializes a new TookenGenerator -func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) { - tokenProtector, err := newTokenProtector(rand) - if err != nil { - return nil, err - } - return &TokenGenerator{ - tokenProtector: tokenProtector, - }, nil -} - -// NewRetryToken generates a new token for a Retry for a given source address -func (g *TokenGenerator) NewRetryToken( - raddr net.Addr, - origDestConnID protocol.ConnectionID, - retrySrcConnID protocol.ConnectionID, -) ([]byte, error) { - data, err := asn1.Marshal(token{ - IsRetryToken: true, - RemoteAddr: encodeRemoteAddr(raddr), - OriginalDestConnectionID: origDestConnID, - RetrySrcConnectionID: retrySrcConnID, - Timestamp: time.Now().UnixNano(), - }) - if err != nil { - return nil, err - } - return g.tokenProtector.NewToken(data) -} - -// NewToken generates a new token to be sent in a NEW_TOKEN frame -func (g *TokenGenerator) NewToken(raddr net.Addr) ([]byte, error) { - data, err := asn1.Marshal(token{ - RemoteAddr: encodeRemoteAddr(raddr), - Timestamp: time.Now().UnixNano(), - }) - if err != nil { - return nil, err - } - return g.tokenProtector.NewToken(data) -} - -// DecodeToken decodes a token -func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) { - // if the client didn't send any token, DecodeToken will be called with a nil-slice - if len(encrypted) == 0 { - return nil, nil - } - - data, err := g.tokenProtector.DecodeToken(encrypted) - if err != nil { - return nil, err - } - t := &token{} - rest, err := asn1.Unmarshal(data, t) - if err != nil { - return nil, err - } - if len(rest) != 0 { - return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) - } - token := &Token{ - IsRetryToken: t.IsRetryToken, - RemoteAddr: decodeRemoteAddr(t.RemoteAddr), - SentTime: time.Unix(0, t.Timestamp), - } - if t.IsRetryToken { - token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) - token.RetrySrcConnectionID = protocol.ConnectionID(t.RetrySrcConnectionID) - } - return token, nil -} - -// encodeRemoteAddr encodes a remote address such that it can be saved in the token -func encodeRemoteAddr(remoteAddr net.Addr) []byte { - if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { - return append([]byte{tokenPrefixIP}, udpAddr.IP...) - } - return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...) -} - -// decodeRemoteAddr decodes the remote address saved in the token -func decodeRemoteAddr(data []byte) string { - // data will never be empty for a token that we generated. - // Check it to be on the safe side - if len(data) == 0 { - return "" - } - if data[0] == tokenPrefixIP { - return net.IP(data[1:]).String() - } - return string(data[1:]) -} diff --git a/internal/quic-go/handshake/token_generator_test.go b/internal/quic-go/handshake/token_generator_test.go deleted file mode 100644 index a1a22ee1..00000000 --- a/internal/quic-go/handshake/token_generator_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package handshake - -import ( - "crypto/rand" - "encoding/asn1" - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Token Generator", func() { - var tokenGen *TokenGenerator - - BeforeEach(func() { - var err error - tokenGen, err = NewTokenGenerator(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - }) - - It("generates a token", func() { - ip := net.IPv4(127, 0, 0, 1) - token, err := tokenGen.NewRetryToken(&net.UDPAddr{IP: ip, Port: 1337}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(token).ToNot(BeEmpty()) - }) - - It("works with nil tokens", func() { - token, err := tokenGen.DecodeToken(nil) - Expect(err).ToNot(HaveOccurred()) - Expect(token).To(BeNil()) - }) - - It("accepts a valid token", func() { - ip := net.IPv4(192, 168, 0, 1) - tokenEnc, err := tokenGen.NewRetryToken( - &net.UDPAddr{IP: ip, Port: 1337}, - nil, - nil, - ) - Expect(err).ToNot(HaveOccurred()) - token, err := tokenGen.DecodeToken(tokenEnc) - Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal("192.168.0.1")) - Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - Expect(token.OriginalDestConnectionID.Len()).To(BeZero()) - Expect(token.RetrySrcConnectionID.Len()).To(BeZero()) - }) - - It("saves the connection ID", func() { - tokenEnc, err := tokenGen.NewRetryToken( - &net.UDPAddr{}, - protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - ) - Expect(err).ToNot(HaveOccurred()) - token, err := tokenGen.DecodeToken(tokenEnc) - Expect(err).ToNot(HaveOccurred()) - Expect(token.OriginalDestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(token.RetrySrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) - }) - - It("rejects invalid tokens", func() { - _, err := tokenGen.DecodeToken([]byte("invalid token")) - Expect(err).To(HaveOccurred()) - }) - - It("rejects tokens that cannot be decoded", func() { - token, err := tokenGen.tokenProtector.NewToken([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - _, err = tokenGen.DecodeToken(token) - Expect(err).To(HaveOccurred()) - }) - - It("rejects tokens that can be decoded, but have additional payload", func() { - t, err := asn1.Marshal(token{RemoteAddr: []byte("foobar")}) - Expect(err).ToNot(HaveOccurred()) - t = append(t, []byte("rest")...) - enc, err := tokenGen.tokenProtector.NewToken(t) - Expect(err).ToNot(HaveOccurred()) - _, err = tokenGen.DecodeToken(enc) - Expect(err).To(MatchError("rest when unpacking token: 4")) - }) - - // we don't generate tokens that have no data, but we should be able to handle them if we receive one for whatever reason - It("doesn't panic if a tokens has no data", func() { - t, err := asn1.Marshal(token{RemoteAddr: []byte("")}) - Expect(err).ToNot(HaveOccurred()) - enc, err := tokenGen.tokenProtector.NewToken(t) - Expect(err).ToNot(HaveOccurred()) - _, err = tokenGen.DecodeToken(enc) - Expect(err).ToNot(HaveOccurred()) - }) - - It("works with an IPv6 addresses ", func() { - addresses := []string{ - "2001:db8::68", - "2001:0000:4136:e378:8000:63bf:3fff:fdd2", - "2001::1", - "ff01:0:0:0:0:0:0:2", - } - for _, addr := range addresses { - ip := net.ParseIP(addr) - Expect(ip).ToNot(BeNil()) - raddr := &net.UDPAddr{IP: ip, Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) - Expect(err).ToNot(HaveOccurred()) - token, err := tokenGen.DecodeToken(tokenEnc) - Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal(ip.String())) - Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - } - }) - - It("uses the string representation an address that is not a UDP address", func() { - raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) - Expect(err).ToNot(HaveOccurred()) - token, err := tokenGen.DecodeToken(tokenEnc) - Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal("192.168.13.37:1337")) - Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) -}) diff --git a/internal/quic-go/handshake/token_protector.go b/internal/quic-go/handshake/token_protector.go deleted file mode 100644 index 650f230b..00000000 --- a/internal/quic-go/handshake/token_protector.go +++ /dev/null @@ -1,89 +0,0 @@ -package handshake - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/sha256" - "fmt" - "io" - - "golang.org/x/crypto/hkdf" -) - -// TokenProtector is used to create and verify a token -type tokenProtector interface { - // NewToken creates a new token - NewToken([]byte) ([]byte, error) - // DecodeToken decodes a token - DecodeToken([]byte) ([]byte, error) -} - -const ( - tokenSecretSize = 32 - tokenNonceSize = 32 -) - -// tokenProtector is used to create and verify a token -type tokenProtectorImpl struct { - rand io.Reader - secret []byte -} - -// newTokenProtector creates a source for source address tokens -func newTokenProtector(rand io.Reader) (tokenProtector, error) { - secret := make([]byte, tokenSecretSize) - if _, err := rand.Read(secret); err != nil { - return nil, err - } - return &tokenProtectorImpl{ - rand: rand, - secret: secret, - }, nil -} - -// NewToken encodes data into a new token. -func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { - nonce := make([]byte, tokenNonceSize) - if _, err := s.rand.Read(nonce); err != nil { - return nil, err - } - aead, aeadNonce, err := s.createAEAD(nonce) - if err != nil { - return nil, err - } - return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil -} - -// DecodeToken decodes a token. -func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) { - if len(p) < tokenNonceSize { - return nil, fmt.Errorf("token too short: %d", len(p)) - } - nonce := p[:tokenNonceSize] - aead, aeadNonce, err := s.createAEAD(nonce) - if err != nil { - return nil, err - } - return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil) -} - -func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { - h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source")) - key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 - if _, err := io.ReadFull(h, key); err != nil { - return nil, nil, err - } - aeadNonce := make([]byte, 12) - if _, err := io.ReadFull(h, aeadNonce); err != nil { - return nil, nil, err - } - c, err := aes.NewCipher(key) - if err != nil { - return nil, nil, err - } - aead, err := cipher.NewGCM(c) - if err != nil { - return nil, nil, err - } - return aead, aeadNonce, nil -} diff --git a/internal/quic-go/handshake/token_protector_test.go b/internal/quic-go/handshake/token_protector_test.go deleted file mode 100644 index 7171e865..00000000 --- a/internal/quic-go/handshake/token_protector_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package handshake - -import ( - "crypto/rand" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type zeroReader struct{} - -func (r *zeroReader) Read(b []byte) (int, error) { - for i := range b { - b[i] = 0 - } - return len(b), nil -} - -var _ = Describe("Token Protector", func() { - var tp tokenProtector - - BeforeEach(func() { - var err error - tp, err = newTokenProtector(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - }) - - It("uses the random source", func() { - tp1, err := newTokenProtector(&zeroReader{}) - Expect(err).ToNot(HaveOccurred()) - tp2, err := newTokenProtector(&zeroReader{}) - Expect(err).ToNot(HaveOccurred()) - t1, err := tp1.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - t2, err := tp2.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(t1).To(Equal(t2)) - tp3, err := newTokenProtector(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - t3, err := tp3.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(t3).ToNot(Equal(t1)) - }) - - It("encodes and decodes tokens", func() { - token, err := tp.NewToken([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(token).ToNot(ContainSubstring("foobar")) - decoded, err := tp.DecodeToken(token) - Expect(err).ToNot(HaveOccurred()) - Expect(decoded).To(Equal([]byte("foobar"))) - }) - - It("fails deconding invalid tokens", func() { - token, err := tp.NewToken([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - token = token[1:] // remove the first byte - _, err = tp.DecodeToken(token) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("message authentication failed")) - }) - - It("errors when decoding too short tokens", func() { - _, err := tp.DecodeToken([]byte("foobar")) - Expect(err).To(MatchError("token too short: 6")) - }) -}) diff --git a/internal/quic-go/handshake/updatable_aead.go b/internal/quic-go/handshake/updatable_aead.go deleted file mode 100644 index e22cea45..00000000 --- a/internal/quic-go/handshake/updatable_aead.go +++ /dev/null @@ -1,323 +0,0 @@ -package handshake - -import ( - "crypto" - "crypto/cipher" - "crypto/tls" - "encoding/binary" - "fmt" - "time" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/qtls" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. -// It's a package-level variable to allow modifying it for testing purposes. -var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval - -type updatableAEAD struct { - suite *qtls.CipherSuiteTLS13 - - keyPhase protocol.KeyPhase - largestAcked protocol.PacketNumber - firstPacketNumber protocol.PacketNumber - handshakeConfirmed bool - - keyUpdateInterval uint64 - invalidPacketLimit uint64 - invalidPacketCount uint64 - - // Time when the keys should be dropped. Keys are dropped on the next call to Open(). - prevRcvAEADExpiry time.Time - prevRcvAEAD cipher.AEAD - - firstRcvdWithCurrentKey protocol.PacketNumber - firstSentWithCurrentKey protocol.PacketNumber - highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) - numRcvdWithCurrentKey uint64 - numSentWithCurrentKey uint64 - rcvAEAD cipher.AEAD - sendAEAD cipher.AEAD - // caches cipher.AEAD.Overhead(). This speeds up calls to Overhead(). - aeadOverhead int - - nextRcvAEAD cipher.AEAD - nextSendAEAD cipher.AEAD - nextRcvTrafficSecret []byte - nextSendTrafficSecret []byte - - headerDecrypter headerProtector - headerEncrypter headerProtector - - rttStats *utils.RTTStats - - tracer logging.ConnectionTracer - logger utils.Logger - version protocol.VersionNumber - - // use a single slice to avoid allocations - nonceBuf []byte -} - -var ( - _ ShortHeaderOpener = &updatableAEAD{} - _ ShortHeaderSealer = &updatableAEAD{} -) - -func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { - return &updatableAEAD{ - firstPacketNumber: protocol.InvalidPacketNumber, - largestAcked: protocol.InvalidPacketNumber, - firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, - firstSentWithCurrentKey: protocol.InvalidPacketNumber, - keyUpdateInterval: KeyUpdateInterval, - rttStats: rttStats, - tracer: tracer, - logger: logger, - version: version, - } -} - -func (a *updatableAEAD) rollKeys() { - if a.prevRcvAEAD != nil { - a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) - if a.tracer != nil { - a.tracer.DroppedKey(a.keyPhase - 1) - } - a.prevRcvAEADExpiry = time.Time{} - } - - a.keyPhase++ - a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber - a.firstSentWithCurrentKey = protocol.InvalidPacketNumber - a.numRcvdWithCurrentKey = 0 - a.numSentWithCurrentKey = 0 - a.prevRcvAEAD = a.rcvAEAD - a.rcvAEAD = a.nextRcvAEAD - a.sendAEAD = a.nextSendAEAD - - a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret) - a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret) - a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version) - a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version) -} - -func (a *updatableAEAD) startKeyDropTimer(now time.Time) { - d := 3 * a.rttStats.PTO(true) - a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d) - a.prevRcvAEADExpiry = now.Add(d) -} - -func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte { - return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) -} - -// For the client, this function is called before SetWriteKey. -// For the server, this function is called after SetWriteKey. -func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { - a.rcvAEAD = createAEAD(suite, trafficSecret, a.version) - a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version) - if a.suite == nil { - a.setAEADParameters(a.rcvAEAD, suite) - } - - a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) - a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) -} - -// For the client, this function is called after SetReadKey. -// For the server, this function is called before SetWriteKey. -func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { - a.sendAEAD = createAEAD(suite, trafficSecret, a.version) - a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) - if a.suite == nil { - a.setAEADParameters(a.sendAEAD, suite) - } - - a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) - a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version) -} - -func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) { - a.nonceBuf = make([]byte, aead.NonceSize()) - a.aeadOverhead = aead.Overhead() - a.suite = suite - switch suite.ID { - case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: - a.invalidPacketLimit = protocol.InvalidPacketLimitAES - case tls.TLS_CHACHA20_POLY1305_SHA256: - a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha - default: - panic(fmt.Sprintf("unknown cipher suite %d", suite.ID)) - } -} - -func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { - return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN) -} - -func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { - dec, err := a.open(dst, src, rcvTime, pn, kp, ad) - if err == ErrDecryptionFailed { - a.invalidPacketCount++ - if a.invalidPacketCount >= a.invalidPacketLimit { - return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached} - } - } - if err == nil { - a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn) - } - return dec, err -} - -func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { - if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) { - a.prevRcvAEAD = nil - a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) - a.prevRcvAEADExpiry = time.Time{} - if a.tracer != nil { - a.tracer.DroppedKey(a.keyPhase - 1) - } - } - binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) - if kp != a.keyPhase.Bit() { - if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { - if a.prevRcvAEAD == nil { - return nil, ErrKeysDropped - } - // we updated the key, but the peer hasn't updated yet - dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) - if err != nil { - err = ErrDecryptionFailed - } - return dec, err - } - // try opening the packet with the next key phase - dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) - if err != nil { - return nil, ErrDecryptionFailed - } - // Opening succeeded. Check if the peer was allowed to update. - if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { - return nil, &qerr.TransportError{ - ErrorCode: qerr.KeyUpdateError, - ErrorMessage: "keys updated too quickly", - } - } - a.rollKeys() - a.logger.Debugf("Peer updated keys to %d", a.keyPhase) - // The peer initiated this key update. It's safe to drop the keys for the previous generation now. - // Start a timer to drop the previous key generation. - a.startKeyDropTimer(rcvTime) - if a.tracer != nil { - a.tracer.UpdatedKey(a.keyPhase, true) - } - a.firstRcvdWithCurrentKey = pn - return dec, err - } - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad) - if err != nil { - return dec, ErrDecryptionFailed - } - a.numRcvdWithCurrentKey++ - if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { - // We initiated the key updated, and now we received the first packet protected with the new key phase. - // Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys. - if a.keyPhase > 0 { - a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase) - a.startKeyDropTimer(rcvTime) - } - a.firstRcvdWithCurrentKey = pn - } - return dec, err -} - -func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { - if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { - a.firstSentWithCurrentKey = pn - } - if a.firstPacketNumber == protocol.InvalidPacketNumber { - a.firstPacketNumber = pn - } - a.numSentWithCurrentKey++ - binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad) -} - -func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { - if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && - pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 { - return &qerr.TransportError{ - ErrorCode: qerr.KeyUpdateError, - ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase), - } - } - a.largestAcked = pn - return nil -} - -func (a *updatableAEAD) SetHandshakeConfirmed() { - a.handshakeConfirmed = true -} - -func (a *updatableAEAD) updateAllowed() bool { - if !a.handshakeConfirmed { - return false - } - // the first key update is allowed as soon as the handshake is confirmed - return a.keyPhase == 0 || - // subsequent key updates as soon as a packet sent with that key phase has been acknowledged - (a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && - a.largestAcked != protocol.InvalidPacketNumber && - a.largestAcked >= a.firstSentWithCurrentKey) -} - -func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { - if !a.updateAllowed() { - return false - } - if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { - a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) - return true - } - if a.numSentWithCurrentKey >= a.keyUpdateInterval { - a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) - return true - } - return false -} - -func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { - if a.shouldInitiateKeyUpdate() { - a.rollKeys() - a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) - if a.tracer != nil { - a.tracer.UpdatedKey(a.keyPhase, false) - } - } - return a.keyPhase.Bit() -} - -func (a *updatableAEAD) Overhead() int { - return a.aeadOverhead -} - -func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes) -} - -func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { - a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes) -} - -func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber { - return a.firstPacketNumber -} diff --git a/internal/quic-go/handshake/updatable_aead_test.go b/internal/quic-go/handshake/updatable_aead_test.go deleted file mode 100644 index 35ec718f..00000000 --- a/internal/quic-go/handshake/updatable_aead_test.go +++ /dev/null @@ -1,528 +0,0 @@ -package handshake - -import ( - "crypto/rand" - "crypto/tls" - "fmt" - "time" - - "github.com/golang/mock/gomock" - - mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/ginkgo/extensions/table" - . "github.com/onsi/gomega" -) - -var _ = Describe("Updatable AEAD", func() { - DescribeTable("ChaCha test vector", - func(v protocol.VersionNumber, expectedPayload, expectedPacket []byte) { - secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") - aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil, v) - chacha := cipherSuites[2] - Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256)) - aead.SetWriteKey(chacha, secret) - const pnOffset = 1 - header := splitHexString("4200bff4") - payloadOffset := len(header) - plaintext := splitHexString("01") - payload := aead.Seal(nil, plaintext, 654360564, header) - Expect(payload).To(Equal(expectedPayload)) - packet := append(header, payload...) - aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset]) - Expect(packet).To(Equal(expectedPacket)) - }, - Entry("QUIC v1", - protocol.Version1, - splitHexString("655e5cd55c41f69080575d7999c25a5bfb"), - splitHexString("4cfe4189655e5cd55c41f69080575d7999c25a5bfb"), - ), - Entry("QUIC v2", - protocol.Version2, - splitHexString("0ae7b6b932bc27d786f4bc2bb20f2162ba"), - splitHexString("5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"), - ), - ) - - for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { - v := ver - - Context(fmt.Sprintf("using version %s", v), func() { - for i := range cipherSuites { - cs := cipherSuites[i] - - Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { - var ( - client, server *updatableAEAD - serverTracer *mocklogging.MockConnectionTracer - rttStats *utils.RTTStats - ) - - BeforeEach(func() { - serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) - trafficSecret1 := make([]byte, 16) - trafficSecret2 := make([]byte, 16) - rand.Read(trafficSecret1) - rand.Read(trafficSecret2) - - rttStats = utils.NewRTTStats() - client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v) - server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger, v) - client.SetReadKey(cs, trafficSecret2) - client.SetWriteKey(cs, trafficSecret1) - server.SetReadKey(cs, trafficSecret1) - server.SetWriteKey(cs, trafficSecret2) - }) - - Context("header protection", func() { - It("encrypts and decrypts the header", func() { - var lastFiveBitsDifferent int - for i := 0; i < 100; i++ { - sample := make([]byte, 16) - rand.Read(sample) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - client.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0x1f != 0xb5&0x1f { - lastFiveBitsDifferent++ - } - Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - server.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - } - Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) - }) - }) - - Context("message encryption", func() { - var msg, ad []byte - - BeforeEach(func() { - msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad = []byte("Donec in velit neque.") - }) - - It("encrypts and decrypts a message", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - - It("saves the first packet number", func() { - client.Seal(nil, msg, 0x1337, ad) - Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) - client.Seal(nil, msg, 0x1338, ad) - Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) - }) - - It("fails to open a message if the associated data is not the same", func() { - encrypted := client.Seal(nil, msg, 0x1337, ad) - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("fails to open a message if the packet number is not the same", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("decodes the packet number", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - _, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) - }) - - It("ignores packets it can't decrypt for packet number derivation", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - _, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).To(HaveOccurred()) - Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) - }) - - It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() { - client.invalidPacketLimit = 10 - for i := 0; i < 9; i++ { - _, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - } - _, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).To(HaveOccurred()) - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached)) - }) - - Context("key updates", func() { - Context("receiving key updates", func() { - It("updates keys", func() { - now := time.Now() - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - encrypted0 := server.Seal(nil, msg, 0x1337, ad) - server.rollKeys() - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - encrypted1 := server.Seal(nil, msg, 0x1337, ad) - Expect(encrypted0).ToNot(Equal(encrypted1)) - // expect opening to fail. The client didn't roll keys yet - _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - client.rollKeys() - decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - }) - - It("updates the keys when receiving a packet with the next key phase", func() { - now := time.Now() - // receive the first packet at key phase zero - encrypted0 := client.Seal(nil, msg, 0x42, ad) - decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - // send one packet at key phase zero - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - _ = server.Seal(nil, msg, 0x1, ad) - // now received a message at key phase one - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x43, ad) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("opens a reordered packet with the old keys after an update", func() { - now := time.Now() - encrypted01 := client.Seal(nil, msg, 0x42, ad) - encrypted02 := client.Seal(nil, msg, 0x43, ad) - // receive the first packet with key phase 0 - _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // send one packet at key phase zero - _ = server.Seal(nil, msg, 0x1, ad) - // now receive a packet with key phase 1 - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x44, ad) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // now receive a reordered packet with key phase 0 - decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("drops keys 3 PTOs after a key update", func() { - now := time.Now() - rttStats.UpdateRTT(10*time.Millisecond, 0, now) - pto := rttStats.PTO(true) - encrypted01 := client.Seal(nil, msg, 0x42, ad) - encrypted02 := client.Seal(nil, msg, 0x43, ad) - // receive the first packet with key phase 0 - _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // send one packet at key phase zero - _ = server.Seal(nil, msg, 0x1, ad) - // now receive a packet with key phase 1 - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x44, ad) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // now receive a reordered packet with key phase 0 - _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrKeysDropped)) - }) - - It("allows the first key update immediately", func() { - // receive a packet at key phase one, before having sent or received any packets at key phase 0 - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x1337, ad) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - _, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - }) - - It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { - client.rollKeys() - encrypted := client.Seal(nil, msg, 0x1337, ad) - encrypted = encrypted[:len(encrypted)-1] - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("errors when the peer updates keys too frequently", func() { - server.rollKeys() - client.rollKeys() - // receive the first packet at key phase one - encrypted0 := client.Seal(nil, msg, 0x42, ad) - _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - // now receive a packet at key phase two, before having sent any packets - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x42, ad) - _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.KeyUpdateError, - ErrorMessage: "keys updated too quickly", - })) - }) - }) - - Context("initiating key updates", func() { - const keyUpdateInterval = 20 - - BeforeEach(func() { - Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) - server.keyUpdateInterval = keyUpdateInterval - server.SetHandshakeConfirmed() - }) - - It("initiates a key update after sealing the maximum number of packets, for the first update", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - // the first update is allowed without receiving an acknowledgement - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() { - server.rollKeys() - client.rollKeys() - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - server.Seal(nil, msg, pn, ad) - } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // receive an ACK for a packet sent in key phase 0 - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - }) - - It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { - // First make sure that we update our keys. - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // Now that our keys are updated, send a packet using the new keys. - const nextPN = keyUpdateInterval + 1 - server.Seal(nil, msg, nextPN, ad) - // We haven't decrypted any packet in the new key phase yet. - // This means that the ACK must have been sent in the old key phase. - Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.KeyUpdateError, - ErrorMessage: "received ACK for key phase 1, but peer didn't update keys", - })) - }) - - It("doesn't error before actually sending a packet in the new key phase", func() { - // First make sure that we update our keys. - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - // Now that our keys are updated, send a packet using the new keys. - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // We haven't decrypted any packet in the new key phase yet. - // This means that the ACK must have been sent in the old key phase. - Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) - }) - - It("initiates a key update after opening the maximum number of packets, for the first update", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - encrypted := client.Seal(nil, msg, pn, ad) - _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - } - // the first update is allowed without receiving an acknowledgement - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() { - server.rollKeys() - client.rollKeys() - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - encrypted := client.Seal(nil, msg, pn, ad) - _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - server.Seal(nil, msg, 1, ad) - Expect(server.SetLargestAcked(1)).To(Succeed()) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - }) - - It("drops keys 3 PTOs after a key update", func() { - now := time.Now() - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - Expect(server.SetLargestAcked(0)).To(Succeed()) - // Now we've initiated the first key update. - // Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there - threePTO := 3 * rttStats.PTO(false) - dataKeyPhaseZero := client.Seal(nil, msg, 1, ad) - _, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // Now receive a packet with key phase 1. - // This should start the timer to drop the keys after 3 PTOs. - client.rollKeys() - dataKeyPhaseOne := client.Seal(nil, msg, 10, ad) - t := now.Add(threePTO).Add(time.Second) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - _, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - // Make sure the keys are still here. - _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrKeysDropped)) - }) - - It("doesn't drop the first key generation too early", func() { - now := time.Now() - data1 := client.Seal(nil, msg, 1, ad) - _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - Expect(server.SetLargestAcked(pn)).To(Succeed()) - } - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // The server never received a packet at key phase 1. - // Make sure the key phase 0 is still there at a much later point. - data2 := client.Seal(nil, msg, 1, ad) - _, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - }) - - It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - const nextPN = keyUpdateInterval + 1 - // Send and receive an acknowledgement for a packet in key phase 1. - // We are now running a timer to drop the keys with 3 PTO. - server.Seal(nil, msg, nextPN, ad) - client.rollKeys() - dataKeyPhaseOne := client.Seal(nil, msg, 2, ad) - now := time.Now() - _, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.SetLargestAcked(nextPN)) - // Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over. - // This mean that we need to drop the keys for key phase 0 immediately. - client.rollKeys() - dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad) - gomock.InOrder( - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true), - ) - _, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - }) - - It("drops keys early when we initiate another key update within the 3 PTO period", func() { - server.SetHandshakeConfirmed() - // send so many packets that we initiate the first key update - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // send so many packets that we initiate the next key update - for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - server.Seal(nil, msg, pn, ad) - } - client.rollKeys() - b = client.Seal(nil, []byte("foobar"), 2, []byte("ad")) - now := time.Now() - _, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed()) - gomock.InOrder( - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false), - ) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - // We haven't received an ACK for a packet sent in key phase 2 yet. - // Make sure we canceled the timer to drop the previous key phase. - b = client.Seal(nil, []byte("foobar"), 3, []byte("ad")) - _, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - }) - }) - }) - }) - }) - } - }) - } -}) diff --git a/internal/quic-go/interface.go b/internal/quic-go/interface.go deleted file mode 100644 index 76af5fcb..00000000 --- a/internal/quic-go/interface.go +++ /dev/null @@ -1,328 +0,0 @@ -package quic - -import ( - "context" - "errors" - "io" - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go/handshake" - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// The StreamID is the ID of a QUIC stream. -type StreamID = protocol.StreamID - -// A VersionNumber is a QUIC version number. -type VersionNumber = protocol.VersionNumber - -const ( - // VersionDraft29 is IETF QUIC draft-29 - VersionDraft29 = protocol.VersionDraft29 - // Version1 is RFC 9000 - Version1 = protocol.Version1 - Version2 = protocol.Version2 -) - -// A Token can be used to verify the ownership of the client address. -type Token struct { - // IsRetryToken encodes how the client received the token. There are two ways: - // * In a Retry packet sent when trying to establish a new connection. - // * In a NEW_TOKEN frame on a previous connection. - IsRetryToken bool - RemoteAddr string - SentTime time.Time -} - -// A ClientToken is a token received by the client. -// It can be used to skip address validation on future connection attempts. -type ClientToken struct { - data []byte -} - -type TokenStore interface { - // Pop searches for a ClientToken associated with the given key. - // Since tokens are not supposed to be reused, it must remove the token from the cache. - // It returns nil when no token is found. - Pop(key string) (token *ClientToken) - - // Put adds a token to the cache with the given key. It might get called - // multiple times in a connection. - Put(key string, token *ClientToken) -} - -// Err0RTTRejected is the returned from: -// * Open{Uni}Stream{Sync} -// * Accept{Uni}Stream -// * Stream.Read and Stream.Write -// when the server rejects a 0-RTT connection attempt. -var Err0RTTRejected = errors.New("0-RTT rejected") - -// ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection. -// It is set on the Connection.Context() context, -// as well as on the context passed to logging.Tracer.NewConnectionTracer. -var ConnectionTracingKey = connTracingCtxKey{} - -type connTracingCtxKey struct{} - -// Stream is the interface implemented by QUIC streams -// In addition to the errors listed on the Connection, -// calls to stream functions can return a StreamError if the stream is canceled. -type Stream interface { - ReceiveStream - SendStream - // SetDeadline sets the read and write deadlines associated - // with the connection. It is equivalent to calling both - // SetReadDeadline and SetWriteDeadline. - SetDeadline(t time.Time) error -} - -// A ReceiveStream is a unidirectional Receive Stream. -type ReceiveStream interface { - // StreamID returns the stream ID. - StreamID() StreamID - // Read reads data from the stream. - // Read can be made to time out and return a net.Error with Timeout() == true - // after a fixed time limit; see SetDeadline and SetReadDeadline. - // If the stream was canceled by the peer, the error implements the StreamError - // interface, and Canceled() == true. - // If the connection was closed due to a timeout, the error satisfies - // the net.Error interface, and Timeout() will be true. - io.Reader - // CancelRead aborts receiving on this stream. - // It will ask the peer to stop transmitting stream data. - // Read will unblock immediately, and future Read calls will fail. - // When called multiple times or after reading the io.EOF it is a no-op. - CancelRead(StreamErrorCode) - // SetReadDeadline sets the deadline for future Read calls and - // any currently-blocked Read call. - // A zero value for t means Read will not time out. - - SetReadDeadline(t time.Time) error -} - -// A SendStream is a unidirectional Send Stream. -type SendStream interface { - // StreamID returns the stream ID. - StreamID() StreamID - // Write writes data to the stream. - // Write can be made to time out and return a net.Error with Timeout() == true - // after a fixed time limit; see SetDeadline and SetWriteDeadline. - // If the stream was canceled by the peer, the error implements the StreamError - // interface, and Canceled() == true. - // If the connection was closed due to a timeout, the error satisfies - // the net.Error interface, and Timeout() will be true. - io.Writer - // Close closes the write-direction of the stream. - // Future calls to Write are not permitted after calling Close. - // It must not be called concurrently with Write. - // It must not be called after calling CancelWrite. - io.Closer - // CancelWrite aborts sending on this stream. - // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably. - // Write will unblock immediately, and future calls to Write will fail. - // When called multiple times or after closing the stream it is a no-op. - CancelWrite(StreamErrorCode) - // The Context is canceled as soon as the write-side of the stream is closed. - // This happens when Close() or CancelWrite() is called, or when the peer - // cancels the read-side of their stream. - Context() context.Context - // SetWriteDeadline sets the deadline for future Write calls - // and any currently-blocked Write call. - // Even if write times out, it may return n > 0, indicating that - // some data was successfully written. - // A zero value for t means Write will not time out. - SetWriteDeadline(t time.Time) error -} - -// A Connection is a QUIC connection between two peers. -// Calls to the connection (and to streams) can return the following types of errors: -// * ApplicationError: for errors triggered by the application running on top of QUIC -// * TransportError: for errors triggered by the QUIC transport (in many cases a misbehaving peer) -// * IdleTimeoutError: when the peer goes away unexpectedly (this is a net.Error timeout error) -// * HandshakeTimeoutError: when the cryptographic handshake takes too long (this is a net.Error timeout error) -// * StatelessResetError: when we receive a stateless reset (this is a net.Error temporary error) -// * VersionNegotiationError: returned by the client, when there's no version overlap between the peers -type Connection interface { - // AcceptStream returns the next stream opened by the peer, blocking until one is available. - // If the connection was closed due to a timeout, the error satisfies - // the net.Error interface, and Timeout() will be true. - AcceptStream(context.Context) (Stream, error) - // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available. - // If the connection was closed due to a timeout, the error satisfies - // the net.Error interface, and Timeout() will be true. - AcceptUniStream(context.Context) (ReceiveStream, error) - // OpenStream opens a new bidirectional QUIC stream. - // There is no signaling to the peer about new streams: - // The peer can only accept the stream after data has been sent on the stream. - // If the error is non-nil, it satisfies the net.Error interface. - // When reaching the peer's stream limit, err.Temporary() will be true. - // If the connection was closed due to a timeout, Timeout() will be true. - OpenStream() (Stream, error) - // OpenStreamSync opens a new bidirectional QUIC stream. - // It blocks until a new stream can be opened. - // If the error is non-nil, it satisfies the net.Error interface. - // If the connection was closed due to a timeout, Timeout() will be true. - OpenStreamSync(context.Context) (Stream, error) - // OpenUniStream opens a new outgoing unidirectional QUIC stream. - // If the error is non-nil, it satisfies the net.Error interface. - // When reaching the peer's stream limit, Temporary() will be true. - // If the connection was closed due to a timeout, Timeout() will be true. - OpenUniStream() (SendStream, error) - // OpenUniStreamSync opens a new outgoing unidirectional QUIC stream. - // It blocks until a new stream can be opened. - // If the error is non-nil, it satisfies the net.Error interface. - // If the connection was closed due to a timeout, Timeout() will be true. - OpenUniStreamSync(context.Context) (SendStream, error) - // LocalAddr returns the local address. - LocalAddr() net.Addr - // RemoteAddr returns the address of the peer. - RemoteAddr() net.Addr - // CloseWithError closes the connection with an error. - // The error string will be sent to the peer. - CloseWithError(ApplicationErrorCode, string) error - // The context is cancelled when the connection is closed. - Context() context.Context - // ConnectionState returns basic details about the QUIC connection. - // It blocks until the handshake completes. - // Warning: This API should not be considered stable and might change soon. - ConnectionState() ConnectionState - - // SendMessage sends a message as a datagram, as specified in RFC 9221. - SendMessage([]byte) error - // ReceiveMessage gets a message received in a datagram, as specified in RFC 9221. - ReceiveMessage() ([]byte, error) -} - -// An EarlyConnection is a connection that is handshaking. -// Data sent during the handshake is encrypted using the forward secure keys. -// When using client certificates, the client's identity is only verified -// after completion of the handshake. -type EarlyConnection interface { - Connection - - // HandshakeComplete blocks until the handshake completes (or fails). - // Data sent before completion of the handshake is encrypted with 1-RTT keys. - // Note that the client's identity hasn't been verified yet. - HandshakeComplete() context.Context - - NextConnection() Connection -} - -// Config contains all configuration data needed for a QUIC server or client. -type Config struct { - // The QUIC versions that can be negotiated. - // If not set, it uses all versions available. - Versions []VersionNumber - // The length of the connection ID in bytes. - // It can be 0, or any value between 4 and 18. - // If not set, the interpretation depends on where the Config is used: - // If used for dialing an address, a 0 byte connection ID will be used. - // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. - // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. - ConnectionIDLength int - // HandshakeIdleTimeout is the idle timeout before completion of the handshake. - // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. - // If this value is zero, the timeout is set to 5 seconds. - HandshakeIdleTimeout time.Duration - // MaxIdleTimeout is the maximum duration that may pass without any incoming network activity. - // The actual value for the idle timeout is the minimum of this value and the peer's. - // This value only applies after the handshake has completed. - // If the timeout is exceeded, the connection is closed. - // If this value is zero, the timeout is set to 30 seconds. - MaxIdleTimeout time.Duration - // AcceptToken determines if a Token is accepted. - // It is called with token = nil if the client didn't send a token. - // If not set, a default verification function is used: - // * it verifies that the address matches, and - // * if the token is a retry token, that it was issued within the last 5 seconds - // * else, that it was issued within the last 24 hours. - // This option is only valid for the server. - AcceptToken func(clientAddr net.Addr, token *Token) bool - // The TokenStore stores tokens received from the server. - // Tokens are used to skip address validation on future connection attempts. - // The key used to store tokens is the ServerName from the tls.Config, if set - // otherwise the token is associated with the server's IP address. - TokenStore TokenStore - // InitialStreamReceiveWindow is the initial size of the stream-level flow control window for receiving data. - // If the application is consuming data quickly enough, the flow control auto-tuning algorithm - // will increase the window up to MaxStreamReceiveWindow. - // If this value is zero, it will default to 512 KB. - InitialStreamReceiveWindow uint64 - // MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data. - // If this value is zero, it will default to 6 MB. - MaxStreamReceiveWindow uint64 - // InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data. - // If the application is consuming data quickly enough, the flow control auto-tuning algorithm - // will increase the window up to MaxConnectionReceiveWindow. - // If this value is zero, it will default to 512 KB. - InitialConnectionReceiveWindow uint64 - // MaxConnectionReceiveWindow is the connection-level flow control window for receiving data. - // If this value is zero, it will default to 15 MB. - MaxConnectionReceiveWindow uint64 - // AllowConnectionWindowIncrease is called every time the connection flow controller attempts - // to increase the connection flow control window. - // If set, the caller can prevent an increase of the window. Typically, it would do so to - // limit the memory usage. - // To avoid deadlocks, it is not valid to call other functions on the connection or on streams - // in this callback. - AllowConnectionWindowIncrease func(sess Connection, delta uint64) bool - // MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open. - // Values above 2^60 are invalid. - // If not set, it will default to 100. - // If set to a negative value, it doesn't allow any bidirectional streams. - MaxIncomingStreams int64 - // MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open. - // Values above 2^60 are invalid. - // If not set, it will default to 100. - // If set to a negative value, it doesn't allow any unidirectional streams. - MaxIncomingUniStreams int64 - // The StatelessResetKey is used to generate stateless reset tokens. - // If no key is configured, sending of stateless resets is disabled. - StatelessResetKey []byte - // KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive. - // If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most - // every half of MaxIdleTimeout, whichever is smaller). - KeepAlivePeriod time.Duration - // DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899). - // Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. - // Note that if Path MTU discovery is causing issues on your system, please open a new issue - DisablePathMTUDiscovery bool - // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. - // This can be useful if version information is exchanged out-of-band. - // It has no effect for a client. - DisableVersionNegotiationPackets bool - // See https://datatracker.ietf.org/doc/draft-ietf-quic-datagram/. - // Datagrams will only be available when both peers enable datagram support. - EnableDatagrams bool - Tracer logging.Tracer -} - -// ConnectionState records basic details about a QUIC connection -type ConnectionState struct { - TLS handshake.ConnectionState - SupportsDatagrams bool -} - -// A Listener for incoming QUIC connections -type Listener interface { - // Close the server. All active connections will be closed. - Close() error - // Addr returns the local network addr that the server is listening on. - Addr() net.Addr - // Accept returns new connections. It should be called in a loop. - Accept(context.Context) (Connection, error) -} - -// An EarlyListener listens for incoming QUIC connections, -// and returns them before the handshake completes. -type EarlyListener interface { - // Close the server. All active connections will be closed. - Close() error - // Addr returns the local network addr that the server is listening on. - Addr() net.Addr - // Accept returns new early connections. It should be called in a loop. - Accept(context.Context) (EarlyConnection, error) -} diff --git a/internal/quic-go/logging/frame.go b/internal/quic-go/logging/frame.go deleted file mode 100644 index 8675e0f9..00000000 --- a/internal/quic-go/logging/frame.go +++ /dev/null @@ -1,66 +0,0 @@ -package logging - -import "github.com/imroc/req/v3/internal/quic-go/wire" - -// A Frame is a QUIC frame -type Frame interface{} - -// The AckRange is used within the AckFrame. -// It is a range of packet numbers that is being acknowledged. -type AckRange = wire.AckRange - -type ( - // An AckFrame is an ACK frame. - AckFrame = wire.AckFrame - // A ConnectionCloseFrame is a CONNECTION_CLOSE frame. - ConnectionCloseFrame = wire.ConnectionCloseFrame - // A DataBlockedFrame is a DATA_BLOCKED frame. - DataBlockedFrame = wire.DataBlockedFrame - // A HandshakeDoneFrame is a HANDSHAKE_DONE frame. - HandshakeDoneFrame = wire.HandshakeDoneFrame - // A MaxDataFrame is a MAX_DATA frame. - MaxDataFrame = wire.MaxDataFrame - // A MaxStreamDataFrame is a MAX_STREAM_DATA frame. - MaxStreamDataFrame = wire.MaxStreamDataFrame - // A MaxStreamsFrame is a MAX_STREAMS_FRAME. - MaxStreamsFrame = wire.MaxStreamsFrame - // A NewConnectionIDFrame is a NEW_CONNECTION_ID frame. - NewConnectionIDFrame = wire.NewConnectionIDFrame - // A NewTokenFrame is a NEW_TOKEN frame. - NewTokenFrame = wire.NewTokenFrame - // A PathChallengeFrame is a PATH_CHALLENGE frame. - PathChallengeFrame = wire.PathChallengeFrame - // A PathResponseFrame is a PATH_RESPONSE frame. - PathResponseFrame = wire.PathResponseFrame - // A PingFrame is a PING frame. - PingFrame = wire.PingFrame - // A ResetStreamFrame is a RESET_STREAM frame. - ResetStreamFrame = wire.ResetStreamFrame - // A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame. - RetireConnectionIDFrame = wire.RetireConnectionIDFrame - // A StopSendingFrame is a STOP_SENDING frame. - StopSendingFrame = wire.StopSendingFrame - // A StreamsBlockedFrame is a STREAMS_BLOCKED frame. - StreamsBlockedFrame = wire.StreamsBlockedFrame - // A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame. - StreamDataBlockedFrame = wire.StreamDataBlockedFrame -) - -// A CryptoFrame is a CRYPTO frame. -type CryptoFrame struct { - Offset ByteCount - Length ByteCount -} - -// A StreamFrame is a STREAM frame. -type StreamFrame struct { - StreamID StreamID - Offset ByteCount - Length ByteCount - Fin bool -} - -// A DatagramFrame is a DATAGRAM frame. -type DatagramFrame struct { - Length ByteCount -} diff --git a/internal/quic-go/logging/interface.go b/internal/quic-go/logging/interface.go deleted file mode 100644 index f4e64840..00000000 --- a/internal/quic-go/logging/interface.go +++ /dev/null @@ -1,134 +0,0 @@ -// Package logging defines a logging interface for quic-go. -// This package should not be considered stable -package logging - -import ( - "context" - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go/utils" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type ( - // A ByteCount is used to count bytes. - ByteCount = protocol.ByteCount - // A ConnectionID is a QUIC Connection ID. - ConnectionID = protocol.ConnectionID - // The EncryptionLevel is the encryption level of a packet. - EncryptionLevel = protocol.EncryptionLevel - // The KeyPhase is the key phase of the 1-RTT keys. - KeyPhase = protocol.KeyPhase - // The KeyPhaseBit is the value of the key phase bit of the 1-RTT packets. - KeyPhaseBit = protocol.KeyPhaseBit - // The PacketNumber is the packet number of a packet. - PacketNumber = protocol.PacketNumber - // The Perspective is the role of a QUIC endpoint (client or server). - Perspective = protocol.Perspective - // A StatelessResetToken is a stateless reset token. - StatelessResetToken = protocol.StatelessResetToken - // The StreamID is the stream ID. - StreamID = protocol.StreamID - // The StreamNum is the number of the stream. - StreamNum = protocol.StreamNum - // The StreamType is the type of the stream (unidirectional or bidirectional). - StreamType = protocol.StreamType - // The VersionNumber is the QUIC version. - VersionNumber = protocol.VersionNumber - - // The Header is the QUIC packet header, before removing header protection. - Header = wire.Header - // The ExtendedHeader is the QUIC packet header, after removing header protection. - ExtendedHeader = wire.ExtendedHeader - // The TransportParameters are QUIC transport parameters. - TransportParameters = wire.TransportParameters - // The PreferredAddress is the preferred address sent in the transport parameters. - PreferredAddress = wire.PreferredAddress - - // A TransportError is a transport-level error code. - TransportError = qerr.TransportErrorCode - // An ApplicationError is an application-defined error code. - ApplicationError = qerr.TransportErrorCode - - // The RTTStats contain statistics used by the congestion controller. - RTTStats = utils.RTTStats -) - -const ( - // KeyPhaseZero is key phase bit 0 - KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero - // KeyPhaseOne is key phase bit 1 - KeyPhaseOne KeyPhaseBit = protocol.KeyPhaseOne -) - -const ( - // PerspectiveServer is used for a QUIC server - PerspectiveServer Perspective = protocol.PerspectiveServer - // PerspectiveClient is used for a QUIC client - PerspectiveClient Perspective = protocol.PerspectiveClient -) - -const ( - // EncryptionInitial is the Initial encryption level - EncryptionInitial EncryptionLevel = protocol.EncryptionInitial - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake EncryptionLevel = protocol.EncryptionHandshake - // Encryption1RTT is the 1-RTT encryption level - Encryption1RTT EncryptionLevel = protocol.Encryption1RTT - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT EncryptionLevel = protocol.Encryption0RTT -) - -const ( - // StreamTypeUni is a unidirectional stream - StreamTypeUni = protocol.StreamTypeUni - // StreamTypeBidi is a bidirectional stream - StreamTypeBidi = protocol.StreamTypeBidi -) - -// A Tracer traces events. -type Tracer interface { - // TracerForConnection requests a new tracer for a connection. - // The ODCID is the original destination connection ID: - // The destination connection ID that the client used on the first Initial packet it sent on this connection. - // If nil is returned, tracing will be disabled for this connection. - TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer - - SentPacket(net.Addr, *Header, ByteCount, []Frame) - DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) -} - -// A ConnectionTracer records events. -type ConnectionTracer interface { - StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) - NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) - ClosedConnection(error) - SentTransportParameters(*TransportParameters) - ReceivedTransportParameters(*TransportParameters) - RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT - SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) - ReceivedVersionNegotiationPacket(*Header, []VersionNumber) - ReceivedRetry(*Header) - ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) - BufferedPacket(PacketType) - DroppedPacket(PacketType, ByteCount, PacketDropReason) - UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) - AcknowledgedPacket(EncryptionLevel, PacketNumber) - LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) - UpdatedCongestionState(CongestionState) - UpdatedPTOCount(value uint32) - UpdatedKeyFromTLS(EncryptionLevel, Perspective) - UpdatedKey(generation KeyPhase, remote bool) - DroppedEncryptionLevel(EncryptionLevel) - DroppedKey(generation KeyPhase) - SetLossTimer(TimerType, EncryptionLevel, time.Time) - LossTimerExpired(TimerType, EncryptionLevel) - LossTimerCanceled() - // Close is called when the connection is closed. - Close() - Debug(name, msg string) -} diff --git a/internal/quic-go/logging/logging_suite_test.go b/internal/quic-go/logging/logging_suite_test.go deleted file mode 100644 index 0a81943d..00000000 --- a/internal/quic-go/logging/logging_suite_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package logging - -import ( - "testing" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestLogging(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Logging Suite") -} - -var mockCtrl *gomock.Controller - -var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) -}) - -var _ = AfterEach(func() { - mockCtrl.Finish() -}) diff --git a/internal/quic-go/logging/mock_connection_tracer_test.go b/internal/quic-go/logging/mock_connection_tracer_test.go deleted file mode 100644 index 620f181b..00000000 --- a/internal/quic-go/logging/mock_connection_tracer_test.go +++ /dev/null @@ -1,351 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/logging (interfaces: ConnectionTracer) - -// Package logging is a generated GoMock package. -package logging - -import ( - net "net" - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - utils "github.com/imroc/req/v3/internal/quic-go/utils" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockConnectionTracer is a mock of ConnectionTracer interface. -type MockConnectionTracer struct { - ctrl *gomock.Controller - recorder *MockConnectionTracerMockRecorder -} - -// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. -type MockConnectionTracerMockRecorder struct { - mock *MockConnectionTracer -} - -// NewMockConnectionTracer creates a new mock instance. -func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { - mock := &MockConnectionTracer{ctrl: ctrl} - mock.recorder = &MockConnectionTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { - return m.recorder -} - -// AcknowledgedPacket mocks base method. -func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) -} - -// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. -func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) -} - -// BufferedPacket mocks base method. -func (m *MockConnectionTracer) BufferedPacket(arg0 PacketType) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "BufferedPacket", arg0) -} - -// BufferedPacket indicates an expected call of BufferedPacket. -func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0) -} - -// Close mocks base method. -func (m *MockConnectionTracer) Close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Close") -} - -// Close indicates an expected call of Close. -func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) -} - -// ClosedConnection mocks base method. -func (m *MockConnectionTracer) ClosedConnection(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ClosedConnection", arg0) -} - -// ClosedConnection indicates an expected call of ClosedConnection. -func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) -} - -// Debug mocks base method. -func (m *MockConnectionTracer) Debug(arg0, arg1 string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Debug", arg0, arg1) -} - -// Debug indicates an expected call of Debug. -func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) -} - -// DroppedEncryptionLevel mocks base method. -func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) -} - -// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. -func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) -} - -// DroppedKey mocks base method. -func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedKey", arg0) -} - -// DroppedKey indicates an expected call of DroppedKey. -func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) -} - -// DroppedPacket mocks base method. -func (m *MockConnectionTracer) DroppedPacket(arg0 PacketType, arg1 protocol.ByteCount, arg2 PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) -} - -// LossTimerCanceled mocks base method. -func (m *MockConnectionTracer) LossTimerCanceled() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerCanceled") -} - -// LossTimerCanceled indicates an expected call of LossTimerCanceled. -func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) -} - -// LossTimerExpired mocks base method. -func (m *MockConnectionTracer) LossTimerExpired(arg0 TimerType, arg1 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) -} - -// LossTimerExpired indicates an expected call of LossTimerExpired. -func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) -} - -// LostPacket mocks base method. -func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 PacketLossReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) -} - -// LostPacket indicates an expected call of LostPacket. -func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) -} - -// NegotiatedVersion mocks base method. -func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) -} - -// NegotiatedVersion indicates an expected call of NegotiatedVersion. -func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) -} - -// ReceivedPacket mocks base method. -func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2) -} - -// ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedPacket), arg0, arg1, arg2) -} - -// ReceivedRetry mocks base method. -func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedRetry", arg0) -} - -// ReceivedRetry indicates an expected call of ReceivedRetry. -func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) -} - -// ReceivedTransportParameters mocks base method. -func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedTransportParameters", arg0) -} - -// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. -func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) -} - -// ReceivedVersionNegotiationPacket mocks base method. -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header, arg1 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1) -} - -// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1) -} - -// RestoredTransportParameters mocks base method. -func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RestoredTransportParameters", arg0) -} - -// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. -func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) -} - -// SentPacket mocks base method. -func (m *MockConnectionTracer) SentPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockConnectionTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// SentTransportParameters mocks base method. -func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentTransportParameters", arg0) -} - -// SentTransportParameters indicates an expected call of SentTransportParameters. -func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) -} - -// SetLossTimer mocks base method. -func (m *MockConnectionTracer) SetLossTimer(arg0 TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) -} - -// SetLossTimer indicates an expected call of SetLossTimer. -func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) -} - -// StartedConnection mocks base method. -func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) -} - -// StartedConnection indicates an expected call of StartedConnection. -func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) -} - -// UpdatedCongestionState mocks base method. -func (m *MockConnectionTracer) UpdatedCongestionState(arg0 CongestionState) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedCongestionState", arg0) -} - -// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. -func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) -} - -// UpdatedKey mocks base method. -func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKey", arg0, arg1) -} - -// UpdatedKey indicates an expected call of UpdatedKey. -func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) -} - -// UpdatedKeyFromTLS mocks base method. -func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) -} - -// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. -func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) -} - -// UpdatedMetrics mocks base method. -func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) -} - -// UpdatedMetrics indicates an expected call of UpdatedMetrics. -func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) -} - -// UpdatedPTOCount mocks base method. -func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedPTOCount", arg0) -} - -// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. -func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) -} diff --git a/internal/quic-go/logging/mock_tracer_test.go b/internal/quic-go/logging/mock_tracer_test.go deleted file mode 100644 index 6d49601a..00000000 --- a/internal/quic-go/logging/mock_tracer_test.go +++ /dev/null @@ -1,76 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/logging (interfaces: Tracer) - -// Package logging is a generated GoMock package. -package logging - -import ( - context "context" - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockTracer is a mock of Tracer interface. -type MockTracer struct { - ctrl *gomock.Controller - recorder *MockTracerMockRecorder -} - -// MockTracerMockRecorder is the mock recorder for MockTracer. -type MockTracerMockRecorder struct { - mock *MockTracer -} - -// NewMockTracer creates a new mock instance. -func NewMockTracer(ctrl *gomock.Controller) *MockTracer { - mock := &MockTracer{ctrl: ctrl} - mock.recorder = &MockTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTracer) EXPECT() *MockTracerMockRecorder { - return m.recorder -} - -// DroppedPacket mocks base method. -func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 PacketType, arg2 protocol.ByteCount, arg3 PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) -} - -// SentPacket mocks base method. -func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// TracerForConnection mocks base method. -func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) - ret0, _ := ret[0].(ConnectionTracer) - return ret0 -} - -// TracerForConnection indicates an expected call of TracerForConnection. -func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) -} diff --git a/internal/quic-go/logging/mockgen.go b/internal/quic-go/logging/mockgen.go deleted file mode 100644 index 09122480..00000000 --- a/internal/quic-go/logging/mockgen.go +++ /dev/null @@ -1,4 +0,0 @@ -package logging - -//go:generate sh -c "mockgen -package logging -self_package github.com/imroc/req/v3/internal/quic-go/logging -destination mock_connection_tracer_test.go github.com/imroc/req/v3/internal/quic-go/logging ConnectionTracer" -//go:generate sh -c "mockgen -package logging -self_package github.com/imroc/req/v3/internal/quic-go/logging -destination mock_tracer_test.go github.com/imroc/req/v3/internal/quic-go/logging Tracer" diff --git a/internal/quic-go/logging/multiplex.go b/internal/quic-go/logging/multiplex.go deleted file mode 100644 index 8280e8cd..00000000 --- a/internal/quic-go/logging/multiplex.go +++ /dev/null @@ -1,219 +0,0 @@ -package logging - -import ( - "context" - "net" - "time" -) - -type tracerMultiplexer struct { - tracers []Tracer -} - -var _ Tracer = &tracerMultiplexer{} - -// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. -func NewMultiplexedTracer(tracers ...Tracer) Tracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &tracerMultiplexer{tracers} -} - -func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer { - var connTracers []ConnectionTracer - for _, t := range m.tracers { - if ct := t.TracerForConnection(ctx, p, odcid); ct != nil { - connTracers = append(connTracers, ct) - } - } - return NewMultiplexedConnectionTracer(connTracers...) -} - -func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.SentPacket(remote, hdr, size, frames) - } -} - -func (m *tracerMultiplexer) DroppedPacket(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(remote, typ, size, reason) - } -} - -type connTracerMultiplexer struct { - tracers []ConnectionTracer -} - -var _ ConnectionTracer = &connTracerMultiplexer{} - -// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. -func NewMultiplexedConnectionTracer(tracers ...ConnectionTracer) ConnectionTracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &connTracerMultiplexer{tracers: tracers} -} - -func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { - for _, t := range m.tracers { - t.StartedConnection(local, remote, srcConnID, destConnID) - } -} - -func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { - for _, t := range m.tracers { - t.NegotiatedVersion(chosen, clientVersions, serverVersions) - } -} - -func (m *connTracerMultiplexer) ClosedConnection(e error) { - for _, t := range m.tracers { - t.ClosedConnection(e) - } -} - -func (m *connTracerMultiplexer) SentTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.SentTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) ReceivedTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.ReceivedTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.RestoredTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) { - for _, t := range m.tracers { - t.SentPacket(hdr, size, ack, frames) - } -} - -func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(hdr *Header, versions []VersionNumber) { - for _, t := range m.tracers { - t.ReceivedVersionNegotiationPacket(hdr, versions) - } -} - -func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) { - for _, t := range m.tracers { - t.ReceivedRetry(hdr) - } -} - -func (m *connTracerMultiplexer) ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.ReceivedPacket(hdr, size, frames) - } -} - -func (m *connTracerMultiplexer) BufferedPacket(typ PacketType) { - for _, t := range m.tracers { - t.BufferedPacket(typ) - } -} - -func (m *connTracerMultiplexer) DroppedPacket(typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(typ, size, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedCongestionState(state CongestionState) { - for _, t := range m.tracers { - t.UpdatedCongestionState(state) - } -} - -func (m *connTracerMultiplexer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFLight ByteCount, packetsInFlight int) { - for _, t := range m.tracers { - t.UpdatedMetrics(rttStats, cwnd, bytesInFLight, packetsInFlight) - } -} - -func (m *connTracerMultiplexer) AcknowledgedPacket(encLevel EncryptionLevel, pn PacketNumber) { - for _, t := range m.tracers { - t.AcknowledgedPacket(encLevel, pn) - } -} - -func (m *connTracerMultiplexer) LostPacket(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { - for _, t := range m.tracers { - t.LostPacket(encLevel, pn, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedPTOCount(value uint32) { - for _, t := range m.tracers { - t.UpdatedPTOCount(value) - } -} - -func (m *connTracerMultiplexer) UpdatedKeyFromTLS(encLevel EncryptionLevel, perspective Perspective) { - for _, t := range m.tracers { - t.UpdatedKeyFromTLS(encLevel, perspective) - } -} - -func (m *connTracerMultiplexer) UpdatedKey(generation KeyPhase, remote bool) { - for _, t := range m.tracers { - t.UpdatedKey(generation, remote) - } -} - -func (m *connTracerMultiplexer) DroppedEncryptionLevel(encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.DroppedEncryptionLevel(encLevel) - } -} - -func (m *connTracerMultiplexer) DroppedKey(generation KeyPhase) { - for _, t := range m.tracers { - t.DroppedKey(generation) - } -} - -func (m *connTracerMultiplexer) SetLossTimer(typ TimerType, encLevel EncryptionLevel, exp time.Time) { - for _, t := range m.tracers { - t.SetLossTimer(typ, encLevel, exp) - } -} - -func (m *connTracerMultiplexer) LossTimerExpired(typ TimerType, encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.LossTimerExpired(typ, encLevel) - } -} - -func (m *connTracerMultiplexer) LossTimerCanceled() { - for _, t := range m.tracers { - t.LossTimerCanceled() - } -} - -func (m *connTracerMultiplexer) Debug(name, msg string) { - for _, t := range m.tracers { - t.Debug(name, msg) - } -} - -func (m *connTracerMultiplexer) Close() { - for _, t := range m.tracers { - t.Close() - } -} diff --git a/internal/quic-go/logging/multiplex_test.go b/internal/quic-go/logging/multiplex_test.go deleted file mode 100644 index e9458d81..00000000 --- a/internal/quic-go/logging/multiplex_test.go +++ /dev/null @@ -1,266 +0,0 @@ -package logging - -import ( - "context" - "errors" - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Tracing", func() { - Context("Tracer", func() { - It("returns a nil tracer if no tracers are passed in", func() { - Expect(NewMultiplexedTracer()).To(BeNil()) - }) - - It("returns the raw tracer if only one tracer is passed in", func() { - tr := NewMockTracer(mockCtrl) - tracer := NewMultiplexedTracer(tr) - Expect(tracer).To(BeAssignableToTypeOf(&MockTracer{})) - }) - - Context("tracing events", func() { - var ( - tracer Tracer - tr1, tr2 *MockTracer - ) - - BeforeEach(func() { - tr1 = NewMockTracer(mockCtrl) - tr2 = NewMockTracer(mockCtrl) - tracer = NewMultiplexedTracer(tr1, tr2) - }) - - It("multiplexes the TracerForConnection call", func() { - ctx := context.Background() - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - }) - - It("uses multiple connection tracers", func() { - ctx := context.Background() - ctr1 := NewMockConnectionTracer(mockCtrl) - ctr2 := NewMockConnectionTracer(mockCtrl) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) - ctr1.EXPECT().LossTimerCanceled() - ctr2.EXPECT().LossTimerCanceled() - tr.LossTimerCanceled() - }) - - It("handles tracers that return a nil ConnectionTracer", func() { - ctx := context.Background() - ctr1 := NewMockConnectionTracer(mockCtrl) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) - ctr1.EXPECT().LossTimerCanceled() - tr.LossTimerCanceled() - }) - - It("returns nil when all tracers return a nil ConnectionTracer", func() { - ctx := context.Background() - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - Expect(tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil()) - }) - - It("traces the PacketSent event", func() { - remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} - f := &MaxDataFrame{MaximumData: 1337} - tr1.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) - tr2.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) - tracer.SentPacket(remote, hdr, 1024, []Frame{f}) - }) - - It("traces the PacketDropped event", func() { - remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - tr1.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) - tr2.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) - tracer.DroppedPacket(remote, PacketTypeRetry, 1024, PacketDropDuplicate) - }) - }) - }) - - Context("Connection Tracer", func() { - var ( - tracer ConnectionTracer - tr1 *MockConnectionTracer - tr2 *MockConnectionTracer - ) - - BeforeEach(func() { - tr1 = NewMockConnectionTracer(mockCtrl) - tr2 = NewMockConnectionTracer(mockCtrl) - tracer = NewMultiplexedConnectionTracer(tr1, tr2) - }) - - It("trace the ConnectionStarted event", func() { - local := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4)} - remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - tr1.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) - tr2.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) - tracer.StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) - }) - - It("traces the ClosedConnection event", func() { - e := errors.New("test err") - tr1.EXPECT().ClosedConnection(e) - tr2.EXPECT().ClosedConnection(e) - tracer.ClosedConnection(e) - }) - - It("traces the SentTransportParameters event", func() { - tp := &wire.TransportParameters{InitialMaxData: 1337} - tr1.EXPECT().SentTransportParameters(tp) - tr2.EXPECT().SentTransportParameters(tp) - tracer.SentTransportParameters(tp) - }) - - It("traces the ReceivedTransportParameters event", func() { - tp := &wire.TransportParameters{InitialMaxData: 1337} - tr1.EXPECT().ReceivedTransportParameters(tp) - tr2.EXPECT().ReceivedTransportParameters(tp) - tracer.ReceivedTransportParameters(tp) - }) - - It("traces the RestoredTransportParameters event", func() { - tp := &wire.TransportParameters{InitialMaxData: 1337} - tr1.EXPECT().RestoredTransportParameters(tp) - tr2.EXPECT().RestoredTransportParameters(tp) - tracer.RestoredTransportParameters(tp) - }) - - It("traces the SentPacket event", func() { - hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} - ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}} - ping := &PingFrame{} - tr1.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tr2.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tracer.SentPacket(hdr, 1337, ack, []Frame{ping}) - }) - - It("traces the ReceivedVersionNegotiationPacket event", func() { - hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} - tr1.EXPECT().ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) - tr2.EXPECT().ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) - tracer.ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) - }) - - It("traces the ReceivedRetry event", func() { - hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} - tr1.EXPECT().ReceivedRetry(hdr) - tr2.EXPECT().ReceivedRetry(hdr) - tracer.ReceivedRetry(hdr) - }) - - It("traces the ReceivedPacket event", func() { - hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} - ping := &PingFrame{} - tr1.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) - tr2.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) - tracer.ReceivedPacket(hdr, 1337, []Frame{ping}) - }) - - It("traces the BufferedPacket event", func() { - tr1.EXPECT().BufferedPacket(PacketTypeHandshake) - tr2.EXPECT().BufferedPacket(PacketTypeHandshake) - tracer.BufferedPacket(PacketTypeHandshake) - }) - - It("traces the DroppedPacket event", func() { - tr1.EXPECT().DroppedPacket(PacketTypeInitial, ByteCount(1337), PacketDropHeaderParseError) - tr2.EXPECT().DroppedPacket(PacketTypeInitial, ByteCount(1337), PacketDropHeaderParseError) - tracer.DroppedPacket(PacketTypeInitial, 1337, PacketDropHeaderParseError) - }) - - It("traces the UpdatedCongestionState event", func() { - tr1.EXPECT().UpdatedCongestionState(CongestionStateRecovery) - tr2.EXPECT().UpdatedCongestionState(CongestionStateRecovery) - tracer.UpdatedCongestionState(CongestionStateRecovery) - }) - - It("traces the UpdatedMetrics event", func() { - rttStats := &RTTStats{} - rttStats.UpdateRTT(time.Second, 0, time.Now()) - tr1.EXPECT().UpdatedMetrics(rttStats, ByteCount(1337), ByteCount(42), 13) - tr2.EXPECT().UpdatedMetrics(rttStats, ByteCount(1337), ByteCount(42), 13) - tracer.UpdatedMetrics(rttStats, 1337, 42, 13) - }) - - It("traces the AcknowledgedPacket event", func() { - tr1.EXPECT().AcknowledgedPacket(EncryptionHandshake, PacketNumber(42)) - tr2.EXPECT().AcknowledgedPacket(EncryptionHandshake, PacketNumber(42)) - tracer.AcknowledgedPacket(EncryptionHandshake, 42) - }) - - It("traces the LostPacket event", func() { - tr1.EXPECT().LostPacket(EncryptionHandshake, PacketNumber(42), PacketLossReorderingThreshold) - tr2.EXPECT().LostPacket(EncryptionHandshake, PacketNumber(42), PacketLossReorderingThreshold) - tracer.LostPacket(EncryptionHandshake, 42, PacketLossReorderingThreshold) - }) - - It("traces the UpdatedPTOCount event", func() { - tr1.EXPECT().UpdatedPTOCount(uint32(88)) - tr2.EXPECT().UpdatedPTOCount(uint32(88)) - tracer.UpdatedPTOCount(88) - }) - - It("traces the UpdatedKeyFromTLS event", func() { - tr1.EXPECT().UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) - tr2.EXPECT().UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) - tracer.UpdatedKeyFromTLS(EncryptionHandshake, PerspectiveClient) - }) - - It("traces the UpdatedKey event", func() { - tr1.EXPECT().UpdatedKey(KeyPhase(42), true) - tr2.EXPECT().UpdatedKey(KeyPhase(42), true) - tracer.UpdatedKey(KeyPhase(42), true) - }) - - It("traces the DroppedEncryptionLevel event", func() { - tr1.EXPECT().DroppedEncryptionLevel(EncryptionHandshake) - tr2.EXPECT().DroppedEncryptionLevel(EncryptionHandshake) - tracer.DroppedEncryptionLevel(EncryptionHandshake) - }) - - It("traces the DroppedKey event", func() { - tr1.EXPECT().DroppedKey(KeyPhase(123)) - tr2.EXPECT().DroppedKey(KeyPhase(123)) - tracer.DroppedKey(123) - }) - - It("traces the SetLossTimer event", func() { - now := time.Now() - tr1.EXPECT().SetLossTimer(TimerTypePTO, EncryptionHandshake, now) - tr2.EXPECT().SetLossTimer(TimerTypePTO, EncryptionHandshake, now) - tracer.SetLossTimer(TimerTypePTO, EncryptionHandshake, now) - }) - - It("traces the LossTimerExpired event", func() { - tr1.EXPECT().LossTimerExpired(TimerTypePTO, EncryptionHandshake) - tr2.EXPECT().LossTimerExpired(TimerTypePTO, EncryptionHandshake) - tracer.LossTimerExpired(TimerTypePTO, EncryptionHandshake) - }) - - It("traces the LossTimerCanceled event", func() { - tr1.EXPECT().LossTimerCanceled() - tr2.EXPECT().LossTimerCanceled() - tracer.LossTimerCanceled() - }) - - It("traces the Close event", func() { - tr1.EXPECT().Close() - tr2.EXPECT().Close() - tracer.Close() - }) - }) -}) diff --git a/internal/quic-go/logging/packet_header.go b/internal/quic-go/logging/packet_header.go deleted file mode 100644 index 9bb397dd..00000000 --- a/internal/quic-go/logging/packet_header.go +++ /dev/null @@ -1,27 +0,0 @@ -package logging - -import ( - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// PacketTypeFromHeader determines the packet type from a *wire.Header. -func PacketTypeFromHeader(hdr *Header) PacketType { - if !hdr.IsLongHeader { - return PacketType1RTT - } - if hdr.Version == 0 { - return PacketTypeVersionNegotiation - } - switch hdr.Type { - case protocol.PacketTypeInitial: - return PacketTypeInitial - case protocol.PacketTypeHandshake: - return PacketTypeHandshake - case protocol.PacketType0RTT: - return PacketType0RTT - case protocol.PacketTypeRetry: - return PacketTypeRetry - default: - return PacketTypeNotDetermined - } -} diff --git a/internal/quic-go/logging/packet_header_test.go b/internal/quic-go/logging/packet_header_test.go deleted file mode 100644 index de8b3e68..00000000 --- a/internal/quic-go/logging/packet_header_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package logging - -import ( - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Packet Header", func() { - Context("determining the packet type from the header", func() { - It("recognizes Initial packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Version: protocol.VersionTLS, - })).To(Equal(PacketTypeInitial)) - }) - - It("recognizes Handshake packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Version: protocol.VersionTLS, - })).To(Equal(PacketTypeHandshake)) - }) - - It("recognizes Retry packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - Version: protocol.VersionTLS, - })).To(Equal(PacketTypeRetry)) - }) - - It("recognizes 0-RTT packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketType0RTT, - Version: protocol.VersionTLS, - })).To(Equal(PacketType0RTT)) - }) - - It("recognizes Version Negotiation packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{IsLongHeader: true})).To(Equal(PacketTypeVersionNegotiation)) - }) - - It("recognizes 1-RTT packets", func() { - Expect(PacketTypeFromHeader(&wire.Header{})).To(Equal(PacketType1RTT)) - }) - - It("handles unrecognized packet types", func() { - Expect(PacketTypeFromHeader(&wire.Header{ - IsLongHeader: true, - Version: protocol.VersionTLS, - })).To(Equal(PacketTypeNotDetermined)) - }) - }) -}) diff --git a/internal/quic-go/logging/types.go b/internal/quic-go/logging/types.go deleted file mode 100644 index ad800692..00000000 --- a/internal/quic-go/logging/types.go +++ /dev/null @@ -1,94 +0,0 @@ -package logging - -// PacketType is the packet type of a QUIC packet -type PacketType uint8 - -const ( - // PacketTypeInitial is the packet type of an Initial packet - PacketTypeInitial PacketType = iota - // PacketTypeHandshake is the packet type of a Handshake packet - PacketTypeHandshake - // PacketTypeRetry is the packet type of a Retry packet - PacketTypeRetry - // PacketType0RTT is the packet type of a 0-RTT packet - PacketType0RTT - // PacketTypeVersionNegotiation is the packet type of a Version Negotiation packet - PacketTypeVersionNegotiation - // PacketType1RTT is a 1-RTT packet - PacketType1RTT - // PacketTypeStatelessReset is a stateless reset - PacketTypeStatelessReset - // PacketTypeNotDetermined is the packet type when it could not be determined - PacketTypeNotDetermined -) - -type PacketLossReason uint8 - -const ( - // PacketLossReorderingThreshold: when a packet is deemed lost due to reordering threshold - PacketLossReorderingThreshold PacketLossReason = iota - // PacketLossTimeThreshold: when a packet is deemed lost due to time threshold - PacketLossTimeThreshold -) - -type PacketDropReason uint8 - -const ( - // PacketDropKeyUnavailable is used when a packet is dropped because keys are unavailable - PacketDropKeyUnavailable PacketDropReason = iota - // PacketDropUnknownConnectionID is used when a packet is dropped because the connection ID is unknown - PacketDropUnknownConnectionID - // PacketDropHeaderParseError is used when a packet is dropped because header parsing failed - PacketDropHeaderParseError - // PacketDropPayloadDecryptError is used when a packet is dropped because decrypting the payload failed - PacketDropPayloadDecryptError - // PacketDropProtocolViolation is used when a packet is dropped due to a protocol violation - PacketDropProtocolViolation - // PacketDropDOSPrevention is used when a packet is dropped to mitigate a DoS attack - PacketDropDOSPrevention - // PacketDropUnsupportedVersion is used when a packet is dropped because the version is not supported - PacketDropUnsupportedVersion - // PacketDropUnexpectedPacket is used when an unexpected packet is received - PacketDropUnexpectedPacket - // PacketDropUnexpectedSourceConnectionID is used when a packet with an unexpected source connection ID is received - PacketDropUnexpectedSourceConnectionID - // PacketDropUnexpectedVersion is used when a packet with an unexpected version is received - PacketDropUnexpectedVersion - // PacketDropDuplicate is used when a duplicate packet is received - PacketDropDuplicate -) - -// TimerType is the type of the loss detection timer -type TimerType uint8 - -const ( - // TimerTypeACK is the timer type for the early retransmit timer - TimerTypeACK TimerType = iota - // TimerTypePTO is the timer type for the PTO retransmit timer - TimerTypePTO -) - -// TimeoutReason is the reason why a connection is closed -type TimeoutReason uint8 - -const ( - // TimeoutReasonHandshake is used when the connection is closed due to a handshake timeout - // This reason is not defined in the qlog draft, but very useful for debugging. - TimeoutReasonHandshake TimeoutReason = iota - // TimeoutReasonIdle is used when the connection is closed due to an idle timeout - // This reason is not defined in the qlog draft, but very useful for debugging. - TimeoutReasonIdle -) - -type CongestionState uint8 - -const ( - // CongestionStateSlowStart is the slow start phase of Reno / Cubic - CongestionStateSlowStart CongestionState = iota - // CongestionStateCongestionAvoidance is the slow start phase of Reno / Cubic - CongestionStateCongestionAvoidance - // CongestionStateRecovery is the recovery phase of Reno / Cubic - CongestionStateRecovery - // CongestionStateApplicationLimited means that the congestion controller is application limited - CongestionStateApplicationLimited -) diff --git a/internal/quic-go/logutils/frame.go b/internal/quic-go/logutils/frame.go deleted file mode 100644 index 9076c8f0..00000000 --- a/internal/quic-go/logutils/frame.go +++ /dev/null @@ -1,33 +0,0 @@ -package logutils - -import ( - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// ConvertFrame converts a wire.Frame into a logging.Frame. -// This makes it possible for external packages to access the frames. -// Furthermore, it removes the data slices from CRYPTO and STREAM frames. -func ConvertFrame(frame wire.Frame) logging.Frame { - switch f := frame.(type) { - case *wire.CryptoFrame: - return &logging.CryptoFrame{ - Offset: f.Offset, - Length: protocol.ByteCount(len(f.Data)), - } - case *wire.StreamFrame: - return &logging.StreamFrame{ - StreamID: f.StreamID, - Offset: f.Offset, - Length: f.DataLen(), - Fin: f.Fin, - } - case *wire.DatagramFrame: - return &logging.DatagramFrame{ - Length: logging.ByteCount(len(f.Data)), - } - default: - return logging.Frame(frame) - } -} diff --git a/internal/quic-go/logutils/frame_test.go b/internal/quic-go/logutils/frame_test.go deleted file mode 100644 index 9a1acc13..00000000 --- a/internal/quic-go/logutils/frame_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package logutils - -import ( - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/wire" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("CRYPTO frame", func() { - It("converts CRYPTO frames", func() { - f := ConvertFrame(&wire.CryptoFrame{ - Offset: 1234, - Data: []byte("foobar"), - }) - Expect(f).To(BeAssignableToTypeOf(&logging.CryptoFrame{})) - cf := f.(*logging.CryptoFrame) - Expect(cf.Offset).To(Equal(logging.ByteCount(1234))) - Expect(cf.Length).To(Equal(logging.ByteCount(6))) - }) - - It("converts STREAM frames", func() { - f := ConvertFrame(&wire.StreamFrame{ - StreamID: 42, - Offset: 1234, - Data: []byte("foo"), - Fin: true, - }) - Expect(f).To(BeAssignableToTypeOf(&logging.StreamFrame{})) - sf := f.(*logging.StreamFrame) - Expect(sf.StreamID).To(Equal(logging.StreamID(42))) - Expect(sf.Offset).To(Equal(logging.ByteCount(1234))) - Expect(sf.Length).To(Equal(logging.ByteCount(3))) - Expect(sf.Fin).To(BeTrue()) - }) - - It("converts DATAGRAM frames", func() { - f := ConvertFrame(&wire.DatagramFrame{Data: []byte("foobar")}) - Expect(f).To(BeAssignableToTypeOf(&logging.DatagramFrame{})) - df := f.(*logging.DatagramFrame) - Expect(df.Length).To(Equal(logging.ByteCount(6))) - }) - - It("converts other frames", func() { - f := ConvertFrame(&wire.MaxDataFrame{MaximumData: 1234}) - Expect(f).To(BeAssignableToTypeOf(&logging.MaxDataFrame{})) - Expect(f).ToNot(BeAssignableToTypeOf(&logging.MaxStreamDataFrame{})) - mdf := f.(*logging.MaxDataFrame) - Expect(mdf.MaximumData).To(Equal(logging.ByteCount(1234))) - }) -}) diff --git a/internal/quic-go/logutils/logutils_suite_test.go b/internal/quic-go/logutils/logutils_suite_test.go deleted file mode 100644 index dc496b2d..00000000 --- a/internal/quic-go/logutils/logutils_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package logutils - -import ( - "testing" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestLogutils(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Logutils Suite") -} diff --git a/internal/quic-go/mock_ack_frame_source_test.go b/internal/quic-go/mock_ack_frame_source_test.go deleted file mode 100644 index 4d498553..00000000 --- a/internal/quic-go/mock_ack_frame_source_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: packet_packer.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockAckFrameSource is a mock of AckFrameSource interface. -type MockAckFrameSource struct { - ctrl *gomock.Controller - recorder *MockAckFrameSourceMockRecorder -} - -// MockAckFrameSourceMockRecorder is the mock recorder for MockAckFrameSource. -type MockAckFrameSourceMockRecorder struct { - mock *MockAckFrameSource -} - -// NewMockAckFrameSource creates a new mock instance. -func NewMockAckFrameSource(ctrl *gomock.Controller) *MockAckFrameSource { - mock := &MockAckFrameSource{ctrl: ctrl} - mock.recorder = &MockAckFrameSourceMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockAckFrameSource) EXPECT() *MockAckFrameSourceMockRecorder { - return m.recorder -} - -// GetAckFrame mocks base method. -func (m *MockAckFrameSource) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAckFrame", encLevel, onlyIfQueued) - ret0, _ := ret[0].(*wire.AckFrame) - return ret0 -} - -// GetAckFrame indicates an expected call of GetAckFrame. -func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(encLevel, onlyIfQueued interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame), encLevel, onlyIfQueued) -} diff --git a/internal/quic-go/mock_batch_conn_test.go b/internal/quic-go/mock_batch_conn_test.go deleted file mode 100644 index 74032900..00000000 --- a/internal/quic-go/mock_batch_conn_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: sys_conn_oob.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - ipv4 "golang.org/x/net/ipv4" -) - -// MockBatchConn is a mock of BatchConn interface. -type MockBatchConn struct { - ctrl *gomock.Controller - recorder *MockBatchConnMockRecorder -} - -// MockBatchConnMockRecorder is the mock recorder for MockBatchConn. -type MockBatchConnMockRecorder struct { - mock *MockBatchConn -} - -// NewMockBatchConn creates a new mock instance. -func NewMockBatchConn(ctrl *gomock.Controller) *MockBatchConn { - mock := &MockBatchConn{ctrl: ctrl} - mock.recorder = &MockBatchConnMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockBatchConn) EXPECT() *MockBatchConnMockRecorder { - return m.recorder -} - -// ReadBatch mocks base method. -func (m *MockBatchConn) ReadBatch(ms []ipv4.Message, flags int) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadBatch", ms, flags) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReadBatch indicates an expected call of ReadBatch. -func (mr *MockBatchConnMockRecorder) ReadBatch(ms, flags interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadBatch", reflect.TypeOf((*MockBatchConn)(nil).ReadBatch), ms, flags) -} diff --git a/internal/quic-go/mock_conn_runner_test.go b/internal/quic-go/mock_conn_runner_test.go deleted file mode 100644 index 99d7dc8f..00000000 --- a/internal/quic-go/mock_conn_runner_test.go +++ /dev/null @@ -1,123 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: connection.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockConnRunner is a mock of ConnRunner interface. -type MockConnRunner struct { - ctrl *gomock.Controller - recorder *MockConnRunnerMockRecorder -} - -// MockConnRunnerMockRecorder is the mock recorder for MockConnRunner. -type MockConnRunnerMockRecorder struct { - mock *MockConnRunner -} - -// NewMockConnRunner creates a new mock instance. -func NewMockConnRunner(ctrl *gomock.Controller) *MockConnRunner { - mock := &MockConnRunner{ctrl: ctrl} - mock.recorder = &MockConnRunnerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConnRunner) EXPECT() *MockConnRunnerMockRecorder { - return m.recorder -} - -// Add mocks base method. -func (m *MockConnRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Add", arg0, arg1) - ret0, _ := ret[0].(bool) - return ret0 -} - -// Add indicates an expected call of Add. -func (mr *MockConnRunnerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockConnRunner)(nil).Add), arg0, arg1) -} - -// AddResetToken mocks base method. -func (m *MockConnRunner) AddResetToken(arg0 protocol.StatelessResetToken, arg1 packetHandler) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddResetToken", arg0, arg1) -} - -// AddResetToken indicates an expected call of AddResetToken. -func (mr *MockConnRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockConnRunner)(nil).AddResetToken), arg0, arg1) -} - -// GetStatelessResetToken mocks base method. -func (m *MockConnRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) - ret0, _ := ret[0].(protocol.StatelessResetToken) - return ret0 -} - -// GetStatelessResetToken indicates an expected call of GetStatelessResetToken. -func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockConnRunner)(nil).GetStatelessResetToken), arg0) -} - -// Remove mocks base method. -func (m *MockConnRunner) Remove(arg0 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Remove", arg0) -} - -// Remove indicates an expected call of Remove. -func (mr *MockConnRunnerMockRecorder) Remove(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockConnRunner)(nil).Remove), arg0) -} - -// RemoveResetToken mocks base method. -func (m *MockConnRunner) RemoveResetToken(arg0 protocol.StatelessResetToken) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RemoveResetToken", arg0) -} - -// RemoveResetToken indicates an expected call of RemoveResetToken. -func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockConnRunner)(nil).RemoveResetToken), arg0) -} - -// ReplaceWithClosed mocks base method. -func (m *MockConnRunner) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) -} - -// ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1) -} - -// Retire mocks base method. -func (m *MockConnRunner) Retire(arg0 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Retire", arg0) -} - -// Retire indicates an expected call of Retire. -func (mr *MockConnRunnerMockRecorder) Retire(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockConnRunner)(nil).Retire), arg0) -} diff --git a/internal/quic-go/mock_crypto_data_handler_test.go b/internal/quic-go/mock_crypto_data_handler_test.go deleted file mode 100644 index 9c70ff2d..00000000 --- a/internal/quic-go/mock_crypto_data_handler_test.go +++ /dev/null @@ -1,49 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: crypto_stream_manager.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockCryptoDataHandler is a mock of CryptoDataHandler interface. -type MockCryptoDataHandler struct { - ctrl *gomock.Controller - recorder *MockCryptoDataHandlerMockRecorder -} - -// MockCryptoDataHandlerMockRecorder is the mock recorder for MockCryptoDataHandler. -type MockCryptoDataHandlerMockRecorder struct { - mock *MockCryptoDataHandler -} - -// NewMockCryptoDataHandler creates a new mock instance. -func NewMockCryptoDataHandler(ctrl *gomock.Controller) *MockCryptoDataHandler { - mock := &MockCryptoDataHandler{ctrl: ctrl} - mock.recorder = &MockCryptoDataHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder { - return m.recorder -} - -// HandleMessage mocks base method. -func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) - ret0, _ := ret[0].(bool) - return ret0 -} - -// HandleMessage indicates an expected call of HandleMessage. -func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1) -} diff --git a/internal/quic-go/mock_crypto_stream_test.go b/internal/quic-go/mock_crypto_stream_test.go deleted file mode 100644 index 2cdf22de..00000000 --- a/internal/quic-go/mock_crypto_stream_test.go +++ /dev/null @@ -1,121 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: crypto_stream.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockCryptoStream is a mock of CryptoStream interface. -type MockCryptoStream struct { - ctrl *gomock.Controller - recorder *MockCryptoStreamMockRecorder -} - -// MockCryptoStreamMockRecorder is the mock recorder for MockCryptoStream. -type MockCryptoStreamMockRecorder struct { - mock *MockCryptoStream -} - -// NewMockCryptoStream creates a new mock instance. -func NewMockCryptoStream(ctrl *gomock.Controller) *MockCryptoStream { - mock := &MockCryptoStream{ctrl: ctrl} - mock.recorder = &MockCryptoStreamMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder { - return m.recorder -} - -// Finish mocks base method. -func (m *MockCryptoStream) Finish() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Finish") - ret0, _ := ret[0].(error) - return ret0 -} - -// Finish indicates an expected call of Finish. -func (mr *MockCryptoStreamMockRecorder) Finish() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*MockCryptoStream)(nil).Finish)) -} - -// GetCryptoData mocks base method. -func (m *MockCryptoStream) GetCryptoData() []byte { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetCryptoData") - ret0, _ := ret[0].([]byte) - return ret0 -} - -// GetCryptoData indicates an expected call of GetCryptoData. -func (mr *MockCryptoStreamMockRecorder) GetCryptoData() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCryptoData", reflect.TypeOf((*MockCryptoStream)(nil).GetCryptoData)) -} - -// HandleCryptoFrame mocks base method. -func (m *MockCryptoStream) HandleCryptoFrame(arg0 *wire.CryptoFrame) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HandleCryptoFrame", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// HandleCryptoFrame indicates an expected call of HandleCryptoFrame. -func (mr *MockCryptoStreamMockRecorder) HandleCryptoFrame(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).HandleCryptoFrame), arg0) -} - -// HasData mocks base method. -func (m *MockCryptoStream) HasData() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasData") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasData indicates an expected call of HasData. -func (mr *MockCryptoStreamMockRecorder) HasData() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasData", reflect.TypeOf((*MockCryptoStream)(nil).HasData)) -} - -// PopCryptoFrame mocks base method. -func (m *MockCryptoStream) PopCryptoFrame(arg0 protocol.ByteCount) *wire.CryptoFrame { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PopCryptoFrame", arg0) - ret0, _ := ret[0].(*wire.CryptoFrame) - return ret0 -} - -// PopCryptoFrame indicates an expected call of PopCryptoFrame. -func (mr *MockCryptoStreamMockRecorder) PopCryptoFrame(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).PopCryptoFrame), arg0) -} - -// Write mocks base method. -func (m *MockCryptoStream) Write(p []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", p) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. -func (mr *MockCryptoStreamMockRecorder) Write(p interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), p) -} diff --git a/internal/quic-go/mock_frame_source_test.go b/internal/quic-go/mock_frame_source_test.go deleted file mode 100644 index efe0bccf..00000000 --- a/internal/quic-go/mock_frame_source_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: packet_packer.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockFrameSource is a mock of FrameSource interface. -type MockFrameSource struct { - ctrl *gomock.Controller - recorder *MockFrameSourceMockRecorder -} - -// MockFrameSourceMockRecorder is the mock recorder for MockFrameSource. -type MockFrameSourceMockRecorder struct { - mock *MockFrameSource -} - -// NewMockFrameSource creates a new mock instance. -func NewMockFrameSource(ctrl *gomock.Controller) *MockFrameSource { - mock := &MockFrameSource{ctrl: ctrl} - mock.recorder = &MockFrameSourceMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockFrameSource) EXPECT() *MockFrameSourceMockRecorder { - return m.recorder -} - -// AppendControlFrames mocks base method. -func (m *MockFrameSource) AppendControlFrames(arg0 []ackhandler.Frame, arg1 protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppendControlFrames", arg0, arg1) - ret0, _ := ret[0].([]ackhandler.Frame) - ret1, _ := ret[1].(protocol.ByteCount) - return ret0, ret1 -} - -// AppendControlFrames indicates an expected call of AppendControlFrames. -func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendControlFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendControlFrames), arg0, arg1) -} - -// AppendStreamFrames mocks base method. -func (m *MockFrameSource) AppendStreamFrames(arg0 []ackhandler.Frame, arg1 protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppendStreamFrames", arg0, arg1) - ret0, _ := ret[0].([]ackhandler.Frame) - ret1, _ := ret[1].(protocol.ByteCount) - return ret0, ret1 -} - -// AppendStreamFrames indicates an expected call of AppendStreamFrames. -func (mr *MockFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendStreamFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendStreamFrames), arg0, arg1) -} - -// HasData mocks base method. -func (m *MockFrameSource) HasData() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasData") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasData indicates an expected call of HasData. -func (mr *MockFrameSourceMockRecorder) HasData() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasData", reflect.TypeOf((*MockFrameSource)(nil).HasData)) -} diff --git a/internal/quic-go/mock_mtu_discoverer_test.go b/internal/quic-go/mock_mtu_discoverer_test.go deleted file mode 100644 index 57993be1..00000000 --- a/internal/quic-go/mock_mtu_discoverer_test.go +++ /dev/null @@ -1,66 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: mtu_discoverer.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockMtuDiscoverer is a mock of MtuDiscoverer interface. -type MockMtuDiscoverer struct { - ctrl *gomock.Controller - recorder *MockMtuDiscovererMockRecorder -} - -// MockMtuDiscovererMockRecorder is the mock recorder for MockMtuDiscoverer. -type MockMtuDiscovererMockRecorder struct { - mock *MockMtuDiscoverer -} - -// NewMockMtuDiscoverer creates a new mock instance. -func NewMockMtuDiscoverer(ctrl *gomock.Controller) *MockMtuDiscoverer { - mock := &MockMtuDiscoverer{ctrl: ctrl} - mock.recorder = &MockMtuDiscovererMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMtuDiscoverer) EXPECT() *MockMtuDiscovererMockRecorder { - return m.recorder -} - -// GetPing mocks base method. -func (m *MockMtuDiscoverer) GetPing() (ackhandler.Frame, protocol.ByteCount) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPing") - ret0, _ := ret[0].(ackhandler.Frame) - ret1, _ := ret[1].(protocol.ByteCount) - return ret0, ret1 -} - -// GetPing indicates an expected call of GetPing. -func (mr *MockMtuDiscovererMockRecorder) GetPing() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPing", reflect.TypeOf((*MockMtuDiscoverer)(nil).GetPing)) -} - -// ShouldSendProbe mocks base method. -func (m *MockMtuDiscoverer) ShouldSendProbe(now time.Time) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ShouldSendProbe", now) - ret0, _ := ret[0].(bool) - return ret0 -} - -// ShouldSendProbe indicates an expected call of ShouldSendProbe. -func (mr *MockMtuDiscovererMockRecorder) ShouldSendProbe(now interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendProbe", reflect.TypeOf((*MockMtuDiscoverer)(nil).ShouldSendProbe), now) -} diff --git a/internal/quic-go/mock_multiplexer_test.go b/internal/quic-go/mock_multiplexer_test.go deleted file mode 100644 index 2bce0112..00000000 --- a/internal/quic-go/mock_multiplexer_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: multiplexer.go - -// Package quic is a generated GoMock package. -package quic - -import ( - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - logging "github.com/imroc/req/v3/internal/quic-go/logging" -) - -// MockMultiplexer is a mock of Multiplexer interface. -type MockMultiplexer struct { - ctrl *gomock.Controller - recorder *MockMultiplexerMockRecorder -} - -// MockMultiplexerMockRecorder is the mock recorder for MockMultiplexer. -type MockMultiplexerMockRecorder struct { - mock *MockMultiplexer -} - -// NewMockMultiplexer creates a new mock instance. -func NewMockMultiplexer(ctrl *gomock.Controller) *MockMultiplexer { - mock := &MockMultiplexer{ctrl: ctrl} - mock.recorder = &MockMultiplexerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder { - return m.recorder -} - -// AddConn mocks base method. -func (m *MockMultiplexer) AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddConn", c, connIDLen, statelessResetKey, tracer) - ret0, _ := ret[0].(packetHandlerManager) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AddConn indicates an expected call of AddConn. -func (mr *MockMultiplexerMockRecorder) AddConn(c, connIDLen, statelessResetKey, tracer interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), c, connIDLen, statelessResetKey, tracer) -} - -// RemoveConn mocks base method. -func (m *MockMultiplexer) RemoveConn(arg0 indexableConn) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveConn", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemoveConn indicates an expected call of RemoveConn. -func (mr *MockMultiplexerMockRecorder) RemoveConn(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveConn", reflect.TypeOf((*MockMultiplexer)(nil).RemoveConn), arg0) -} diff --git a/internal/quic-go/mock_packer_test.go b/internal/quic-go/mock_packer_test.go deleted file mode 100644 index ec1e5cca..00000000 --- a/internal/quic-go/mock_packer_test.go +++ /dev/null @@ -1,179 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: packet_packer.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - qerr "github.com/imroc/req/v3/internal/quic-go/qerr" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockPacker is a mock of Packer interface. -type MockPacker struct { - ctrl *gomock.Controller - recorder *MockPackerMockRecorder -} - -// MockPackerMockRecorder is the mock recorder for MockPacker. -type MockPackerMockRecorder struct { - mock *MockPacker -} - -// NewMockPacker creates a new mock instance. -func NewMockPacker(ctrl *gomock.Controller) *MockPacker { - mock := &MockPacker{ctrl: ctrl} - mock.recorder = &MockPackerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockPacker) EXPECT() *MockPackerMockRecorder { - return m.recorder -} - -// HandleTransportParameters mocks base method. -func (m *MockPacker) HandleTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "HandleTransportParameters", arg0) -} - -// HandleTransportParameters indicates an expected call of HandleTransportParameters. -func (mr *MockPackerMockRecorder) HandleTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleTransportParameters", reflect.TypeOf((*MockPacker)(nil).HandleTransportParameters), arg0) -} - -// MaybePackAckPacket mocks base method. -func (m *MockPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MaybePackAckPacket", handshakeConfirmed) - ret0, _ := ret[0].(*packedPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// MaybePackAckPacket indicates an expected call of MaybePackAckPacket. -func (mr *MockPackerMockRecorder) MaybePackAckPacket(handshakeConfirmed interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket), handshakeConfirmed) -} - -// MaybePackProbePacket mocks base method. -func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel) (*packedPacket, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MaybePackProbePacket", arg0) - ret0, _ := ret[0].(*packedPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// MaybePackProbePacket indicates an expected call of MaybePackProbePacket. -func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0) -} - -// PackApplicationClose mocks base method. -func (m *MockPacker) PackApplicationClose(arg0 *qerr.ApplicationError) (*coalescedPacket, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackApplicationClose", arg0) - ret0, _ := ret[0].(*coalescedPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// PackApplicationClose indicates an expected call of PackApplicationClose. -func (mr *MockPackerMockRecorder) PackApplicationClose(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackApplicationClose", reflect.TypeOf((*MockPacker)(nil).PackApplicationClose), arg0) -} - -// PackCoalescedPacket mocks base method. -func (m *MockPacker) PackCoalescedPacket() (*coalescedPacket, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackCoalescedPacket") - ret0, _ := ret[0].(*coalescedPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// PackCoalescedPacket indicates an expected call of PackCoalescedPacket. -func (mr *MockPackerMockRecorder) PackCoalescedPacket() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket)) -} - -// PackConnectionClose mocks base method. -func (m *MockPacker) PackConnectionClose(arg0 *qerr.TransportError) (*coalescedPacket, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackConnectionClose", arg0) - ret0, _ := ret[0].(*coalescedPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// PackConnectionClose indicates an expected call of PackConnectionClose. -func (mr *MockPackerMockRecorder) PackConnectionClose(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0) -} - -// PackMTUProbePacket mocks base method. -func (m *MockPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackMTUProbePacket", ping, size) - ret0, _ := ret[0].(*packedPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// PackMTUProbePacket indicates an expected call of PackMTUProbePacket. -func (mr *MockPackerMockRecorder) PackMTUProbePacket(ping, size interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), ping, size) -} - -// PackPacket mocks base method. -func (m *MockPacker) PackPacket() (*packedPacket, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackPacket") - ret0, _ := ret[0].(*packedPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// PackPacket indicates an expected call of PackPacket. -func (mr *MockPackerMockRecorder) PackPacket() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket)) -} - -// SetMaxPacketSize mocks base method. -func (m *MockPacker) SetMaxPacketSize(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetMaxPacketSize", arg0) -} - -// SetMaxPacketSize indicates an expected call of SetMaxPacketSize. -func (mr *MockPackerMockRecorder) SetMaxPacketSize(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxPacketSize", reflect.TypeOf((*MockPacker)(nil).SetMaxPacketSize), arg0) -} - -// SetToken mocks base method. -func (m *MockPacker) SetToken(arg0 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetToken", arg0) -} - -// SetToken indicates an expected call of SetToken. -func (mr *MockPackerMockRecorder) SetToken(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetToken", reflect.TypeOf((*MockPacker)(nil).SetToken), arg0) -} diff --git a/internal/quic-go/mock_packet_handler_manager_test.go b/internal/quic-go/mock_packet_handler_manager_test.go deleted file mode 100644 index eb8539da..00000000 --- a/internal/quic-go/mock_packet_handler_manager_test.go +++ /dev/null @@ -1,175 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: server.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockPacketHandlerManager is a mock of PacketHandlerManager interface. -type MockPacketHandlerManager struct { - ctrl *gomock.Controller - recorder *MockPacketHandlerManagerMockRecorder -} - -// MockPacketHandlerManagerMockRecorder is the mock recorder for MockPacketHandlerManager. -type MockPacketHandlerManagerMockRecorder struct { - mock *MockPacketHandlerManager -} - -// NewMockPacketHandlerManager creates a new mock instance. -func NewMockPacketHandlerManager(ctrl *gomock.Controller) *MockPacketHandlerManager { - mock := &MockPacketHandlerManager{ctrl: ctrl} - mock.recorder = &MockPacketHandlerManagerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorder { - return m.recorder -} - -// Add mocks base method. -func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Add", arg0, arg1) - ret0, _ := ret[0].(bool) - return ret0 -} - -// Add indicates an expected call of Add. -func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) -} - -// AddResetToken mocks base method. -func (m *MockPacketHandlerManager) AddResetToken(arg0 protocol.StatelessResetToken, arg1 packetHandler) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddResetToken", arg0, arg1) -} - -// AddResetToken indicates an expected call of AddResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1) -} - -// AddWithConnID mocks base method. -func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() packetHandler) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2) - ret0, _ := ret[0].(bool) - return ret0 -} - -// AddWithConnID indicates an expected call of AddWithConnID. -func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2) -} - -// CloseServer mocks base method. -func (m *MockPacketHandlerManager) CloseServer() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CloseServer") -} - -// CloseServer indicates an expected call of CloseServer. -func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) -} - -// Destroy mocks base method. -func (m *MockPacketHandlerManager) Destroy() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Destroy") - ret0, _ := ret[0].(error) - return ret0 -} - -// Destroy indicates an expected call of Destroy. -func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy)) -} - -// GetStatelessResetToken mocks base method. -func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) - ret0, _ := ret[0].(protocol.StatelessResetToken) - return ret0 -} - -// GetStatelessResetToken indicates an expected call of GetStatelessResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0) -} - -// Remove mocks base method. -func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Remove", arg0) -} - -// Remove indicates an expected call of Remove. -func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) -} - -// RemoveResetToken mocks base method. -func (m *MockPacketHandlerManager) RemoveResetToken(arg0 protocol.StatelessResetToken) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RemoveResetToken", arg0) -} - -// RemoveResetToken indicates an expected call of RemoveResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0) -} - -// ReplaceWithClosed mocks base method. -func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) -} - -// ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1) -} - -// Retire mocks base method. -func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Retire", arg0) -} - -// Retire indicates an expected call of Retire. -func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0) -} - -// SetServer mocks base method. -func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetServer", arg0) -} - -// SetServer indicates an expected call of SetServer. -func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0) -} diff --git a/internal/quic-go/mock_packet_handler_test.go b/internal/quic-go/mock_packet_handler_test.go deleted file mode 100644 index 82bb383e..00000000 --- a/internal/quic-go/mock_packet_handler_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: server.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockPacketHandler is a mock of PacketHandler interface. -type MockPacketHandler struct { - ctrl *gomock.Controller - recorder *MockPacketHandlerMockRecorder -} - -// MockPacketHandlerMockRecorder is the mock recorder for MockPacketHandler. -type MockPacketHandlerMockRecorder struct { - mock *MockPacketHandler -} - -// NewMockPacketHandler creates a new mock instance. -func NewMockPacketHandler(ctrl *gomock.Controller) *MockPacketHandler { - mock := &MockPacketHandler{ctrl: ctrl} - mock.recorder = &MockPacketHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockPacketHandler) EXPECT() *MockPacketHandlerMockRecorder { - return m.recorder -} - -// destroy mocks base method. -func (m *MockPacketHandler) destroy(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "destroy", arg0) -} - -// destroy indicates an expected call of destroy. -func (mr *MockPacketHandlerMockRecorder) destroy(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0) -} - -// getPerspective mocks base method. -func (m *MockPacketHandler) getPerspective() protocol.Perspective { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getPerspective") - ret0, _ := ret[0].(protocol.Perspective) - return ret0 -} - -// getPerspective indicates an expected call of getPerspective. -func (mr *MockPacketHandlerMockRecorder) getPerspective() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockPacketHandler)(nil).getPerspective)) -} - -// handlePacket mocks base method. -func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "handlePacket", arg0) -} - -// handlePacket indicates an expected call of handlePacket. -func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0) -} - -// shutdown mocks base method. -func (m *MockPacketHandler) shutdown() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "shutdown") -} - -// shutdown indicates an expected call of shutdown. -func (mr *MockPacketHandlerMockRecorder) shutdown() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "shutdown", reflect.TypeOf((*MockPacketHandler)(nil).shutdown)) -} diff --git a/internal/quic-go/mock_packetconn_test.go b/internal/quic-go/mock_packetconn_test.go deleted file mode 100644 index d6731e4a..00000000 --- a/internal/quic-go/mock_packetconn_test.go +++ /dev/null @@ -1,137 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: net (interfaces: PacketConn) - -// Package quic is a generated GoMock package. -package quic - -import ( - net "net" - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" -) - -// MockPacketConn is a mock of PacketConn interface. -type MockPacketConn struct { - ctrl *gomock.Controller - recorder *MockPacketConnMockRecorder -} - -// MockPacketConnMockRecorder is the mock recorder for MockPacketConn. -type MockPacketConnMockRecorder struct { - mock *MockPacketConn -} - -// NewMockPacketConn creates a new mock instance. -func NewMockPacketConn(ctrl *gomock.Controller) *MockPacketConn { - mock := &MockPacketConn{ctrl: ctrl} - mock.recorder = &MockPacketConnMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockPacketConn) EXPECT() *MockPacketConnMockRecorder { - return m.recorder -} - -// Close mocks base method. -func (m *MockPacketConn) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockPacketConnMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketConn)(nil).Close)) -} - -// LocalAddr mocks base method. -func (m *MockPacketConn) LocalAddr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LocalAddr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// LocalAddr indicates an expected call of LocalAddr. -func (mr *MockPacketConnMockRecorder) LocalAddr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockPacketConn)(nil).LocalAddr)) -} - -// ReadFrom mocks base method. -func (m *MockPacketConn) ReadFrom(arg0 []byte) (int, net.Addr, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadFrom", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(net.Addr) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// ReadFrom indicates an expected call of ReadFrom. -func (mr *MockPacketConnMockRecorder) ReadFrom(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFrom", reflect.TypeOf((*MockPacketConn)(nil).ReadFrom), arg0) -} - -// SetDeadline mocks base method. -func (m *MockPacketConn) SetDeadline(arg0 time.Time) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetDeadline", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetDeadline indicates an expected call of SetDeadline. -func (mr *MockPacketConnMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetDeadline), arg0) -} - -// SetReadDeadline mocks base method. -func (m *MockPacketConn) SetReadDeadline(arg0 time.Time) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetReadDeadline", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockPacketConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetReadDeadline), arg0) -} - -// SetWriteDeadline mocks base method. -func (m *MockPacketConn) SetWriteDeadline(arg0 time.Time) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockPacketConnMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetWriteDeadline), arg0) -} - -// WriteTo mocks base method. -func (m *MockPacketConn) WriteTo(arg0 []byte, arg1 net.Addr) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WriteTo", arg0, arg1) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// WriteTo indicates an expected call of WriteTo. -func (mr *MockPacketConnMockRecorder) WriteTo(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockPacketConn)(nil).WriteTo), arg0, arg1) -} diff --git a/internal/quic-go/mock_quic_conn_test.go b/internal/quic-go/mock_quic_conn_test.go deleted file mode 100644 index c79523ed..00000000 --- a/internal/quic-go/mock_quic_conn_test.go +++ /dev/null @@ -1,346 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: server.go - -// Package quic is a generated GoMock package. -package quic - -import ( - context "context" - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockQuicConn is a mock of QuicConn interface. -type MockQuicConn struct { - ctrl *gomock.Controller - recorder *MockQuicConnMockRecorder -} - -// MockQuicConnMockRecorder is the mock recorder for MockQuicConn. -type MockQuicConnMockRecorder struct { - mock *MockQuicConn -} - -// NewMockQuicConn creates a new mock instance. -func NewMockQuicConn(ctrl *gomock.Controller) *MockQuicConn { - mock := &MockQuicConn{ctrl: ctrl} - mock.recorder = &MockQuicConnMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockQuicConn) EXPECT() *MockQuicConnMockRecorder { - return m.recorder -} - -// AcceptStream mocks base method. -func (m *MockQuicConn) AcceptStream(arg0 context.Context) (Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptStream", arg0) - ret0, _ := ret[0].(Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AcceptStream indicates an expected call of AcceptStream. -func (mr *MockQuicConnMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicConn)(nil).AcceptStream), arg0) -} - -// AcceptUniStream mocks base method. -func (m *MockQuicConn) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptUniStream", arg0) - ret0, _ := ret[0].(ReceiveStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockQuicConnMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicConn)(nil).AcceptUniStream), arg0) -} - -// CloseWithError mocks base method. -func (m *MockQuicConn) CloseWithError(arg0 ApplicationErrorCode, arg1 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// CloseWithError indicates an expected call of CloseWithError. -func (mr *MockQuicConnMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicConn)(nil).CloseWithError), arg0, arg1) -} - -// ConnectionState mocks base method. -func (m *MockQuicConn) ConnectionState() ConnectionState { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(ConnectionState) - return ret0 -} - -// ConnectionState indicates an expected call of ConnectionState. -func (mr *MockQuicConnMockRecorder) ConnectionState() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockQuicConn)(nil).ConnectionState)) -} - -// Context mocks base method. -func (m *MockQuicConn) Context() context.Context { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Context") - ret0, _ := ret[0].(context.Context) - return ret0 -} - -// Context indicates an expected call of Context. -func (mr *MockQuicConnMockRecorder) Context() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockQuicConn)(nil).Context)) -} - -// GetVersion mocks base method. -func (m *MockQuicConn) GetVersion() protocol.VersionNumber { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetVersion") - ret0, _ := ret[0].(protocol.VersionNumber) - return ret0 -} - -// GetVersion indicates an expected call of GetVersion. -func (mr *MockQuicConnMockRecorder) GetVersion() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockQuicConn)(nil).GetVersion)) -} - -// HandshakeComplete mocks base method. -func (m *MockQuicConn) HandshakeComplete() context.Context { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HandshakeComplete") - ret0, _ := ret[0].(context.Context) - return ret0 -} - -// HandshakeComplete indicates an expected call of HandshakeComplete. -func (mr *MockQuicConnMockRecorder) HandshakeComplete() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockQuicConn)(nil).HandshakeComplete)) -} - -// LocalAddr mocks base method. -func (m *MockQuicConn) LocalAddr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LocalAddr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// LocalAddr indicates an expected call of LocalAddr. -func (mr *MockQuicConnMockRecorder) LocalAddr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockQuicConn)(nil).LocalAddr)) -} - -// NextConnection mocks base method. -func (m *MockQuicConn) NextConnection() Connection { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NextConnection") - ret0, _ := ret[0].(Connection) - return ret0 -} - -// NextConnection indicates an expected call of NextConnection. -func (mr *MockQuicConnMockRecorder) NextConnection() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQuicConn)(nil).NextConnection)) -} - -// OpenStream mocks base method. -func (m *MockQuicConn) OpenStream() (Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStream") - ret0, _ := ret[0].(Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenStream indicates an expected call of OpenStream. -func (mr *MockQuicConnMockRecorder) OpenStream() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockQuicConn)(nil).OpenStream)) -} - -// OpenStreamSync mocks base method. -func (m *MockQuicConn) OpenStreamSync(arg0 context.Context) (Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStreamSync", arg0) - ret0, _ := ret[0].(Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockQuicConnMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicConn)(nil).OpenStreamSync), arg0) -} - -// OpenUniStream mocks base method. -func (m *MockQuicConn) OpenUniStream() (SendStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStream") - ret0, _ := ret[0].(SendStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenUniStream indicates an expected call of OpenUniStream. -func (mr *MockQuicConnMockRecorder) OpenUniStream() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockQuicConn)(nil).OpenUniStream)) -} - -// OpenUniStreamSync mocks base method. -func (m *MockQuicConn) OpenUniStreamSync(arg0 context.Context) (SendStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) - ret0, _ := ret[0].(SendStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockQuicConnMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicConn)(nil).OpenUniStreamSync), arg0) -} - -// ReceiveMessage mocks base method. -func (m *MockQuicConn) ReceiveMessage() ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceiveMessage") - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockQuicConnMockRecorder) ReceiveMessage() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQuicConn)(nil).ReceiveMessage)) -} - -// RemoteAddr mocks base method. -func (m *MockQuicConn) RemoteAddr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoteAddr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// RemoteAddr indicates an expected call of RemoteAddr. -func (mr *MockQuicConnMockRecorder) RemoteAddr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicConn)(nil).RemoteAddr)) -} - -// SendMessage mocks base method. -func (m *MockQuicConn) SendMessage(arg0 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendMessage", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SendMessage indicates an expected call of SendMessage. -func (mr *MockQuicConnMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockQuicConn)(nil).SendMessage), arg0) -} - -// destroy mocks base method. -func (m *MockQuicConn) destroy(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "destroy", arg0) -} - -// destroy indicates an expected call of destroy. -func (mr *MockQuicConnMockRecorder) destroy(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQuicConn)(nil).destroy), arg0) -} - -// earlyConnReady mocks base method. -func (m *MockQuicConn) earlyConnReady() <-chan struct{} { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "earlyConnReady") - ret0, _ := ret[0].(<-chan struct{}) - return ret0 -} - -// earlyConnReady indicates an expected call of earlyConnReady. -func (mr *MockQuicConnMockRecorder) earlyConnReady() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "earlyConnReady", reflect.TypeOf((*MockQuicConn)(nil).earlyConnReady)) -} - -// getPerspective mocks base method. -func (m *MockQuicConn) getPerspective() protocol.Perspective { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getPerspective") - ret0, _ := ret[0].(protocol.Perspective) - return ret0 -} - -// getPerspective indicates an expected call of getPerspective. -func (mr *MockQuicConnMockRecorder) getPerspective() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockQuicConn)(nil).getPerspective)) -} - -// handlePacket mocks base method. -func (m *MockQuicConn) handlePacket(arg0 *receivedPacket) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "handlePacket", arg0) -} - -// handlePacket indicates an expected call of handlePacket. -func (mr *MockQuicConnMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockQuicConn)(nil).handlePacket), arg0) -} - -// run mocks base method. -func (m *MockQuicConn) run() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "run") - ret0, _ := ret[0].(error) - return ret0 -} - -// run indicates an expected call of run. -func (mr *MockQuicConnMockRecorder) run() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "run", reflect.TypeOf((*MockQuicConn)(nil).run)) -} - -// shutdown mocks base method. -func (m *MockQuicConn) shutdown() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "shutdown") -} - -// shutdown indicates an expected call of shutdown. -func (mr *MockQuicConnMockRecorder) shutdown() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "shutdown", reflect.TypeOf((*MockQuicConn)(nil).shutdown)) -} diff --git a/internal/quic-go/mock_receive_stream_internal_test.go b/internal/quic-go/mock_receive_stream_internal_test.go deleted file mode 100644 index 5389b85f..00000000 --- a/internal/quic-go/mock_receive_stream_internal_test.go +++ /dev/null @@ -1,146 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: receive_stream.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockReceiveStreamI is a mock of ReceiveStreamI interface. -type MockReceiveStreamI struct { - ctrl *gomock.Controller - recorder *MockReceiveStreamIMockRecorder -} - -// MockReceiveStreamIMockRecorder is the mock recorder for MockReceiveStreamI. -type MockReceiveStreamIMockRecorder struct { - mock *MockReceiveStreamI -} - -// NewMockReceiveStreamI creates a new mock instance. -func NewMockReceiveStreamI(ctrl *gomock.Controller) *MockReceiveStreamI { - mock := &MockReceiveStreamI{ctrl: ctrl} - mock.recorder = &MockReceiveStreamIMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockReceiveStreamI) EXPECT() *MockReceiveStreamIMockRecorder { - return m.recorder -} - -// CancelRead mocks base method. -func (m *MockReceiveStreamI) CancelRead(arg0 StreamErrorCode) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CancelRead", arg0) -} - -// CancelRead indicates an expected call of CancelRead. -func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockReceiveStreamI)(nil).CancelRead), arg0) -} - -// Read mocks base method. -func (m *MockReceiveStreamI) Read(p []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Read", p) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Read indicates an expected call of Read. -func (mr *MockReceiveStreamIMockRecorder) Read(p interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReceiveStreamI)(nil).Read), p) -} - -// SetReadDeadline mocks base method. -func (m *MockReceiveStreamI) SetReadDeadline(t time.Time) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetReadDeadline", t) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(t interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockReceiveStreamI)(nil).SetReadDeadline), t) -} - -// StreamID mocks base method. -func (m *MockReceiveStreamI) StreamID() StreamID { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StreamID") - ret0, _ := ret[0].(StreamID) - return ret0 -} - -// StreamID indicates an expected call of StreamID. -func (mr *MockReceiveStreamIMockRecorder) StreamID() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockReceiveStreamI)(nil).StreamID)) -} - -// closeForShutdown mocks base method. -func (m *MockReceiveStreamI) closeForShutdown(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "closeForShutdown", arg0) -} - -// closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockReceiveStreamI)(nil).closeForShutdown), arg0) -} - -// getWindowUpdate mocks base method. -func (m *MockReceiveStreamI) getWindowUpdate() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getWindowUpdate") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// getWindowUpdate indicates an expected call of getWindowUpdate. -func (mr *MockReceiveStreamIMockRecorder) getWindowUpdate() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockReceiveStreamI)(nil).getWindowUpdate)) -} - -// handleResetStreamFrame mocks base method. -func (m *MockReceiveStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "handleResetStreamFrame", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// handleResetStreamFrame indicates an expected call of handleResetStreamFrame. -func (mr *MockReceiveStreamIMockRecorder) handleResetStreamFrame(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleResetStreamFrame), arg0) -} - -// handleStreamFrame mocks base method. -func (m *MockReceiveStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "handleStreamFrame", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// handleStreamFrame indicates an expected call of handleStreamFrame. -func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleStreamFrame), arg0) -} diff --git a/internal/quic-go/mock_sealing_manager_test.go b/internal/quic-go/mock_sealing_manager_test.go deleted file mode 100644 index a046c897..00000000 --- a/internal/quic-go/mock_sealing_manager_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: packet_packer.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - handshake "github.com/imroc/req/v3/internal/quic-go/handshake" -) - -// MockSealingManager is a mock of SealingManager interface. -type MockSealingManager struct { - ctrl *gomock.Controller - recorder *MockSealingManagerMockRecorder -} - -// MockSealingManagerMockRecorder is the mock recorder for MockSealingManager. -type MockSealingManagerMockRecorder struct { - mock *MockSealingManager -} - -// NewMockSealingManager creates a new mock instance. -func NewMockSealingManager(ctrl *gomock.Controller) *MockSealingManager { - mock := &MockSealingManager{ctrl: ctrl} - mock.recorder = &MockSealingManagerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSealingManager) EXPECT() *MockSealingManagerMockRecorder { - return m.recorder -} - -// Get0RTTSealer mocks base method. -func (m *MockSealingManager) Get0RTTSealer() (handshake.LongHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get0RTTSealer") - ret0, _ := ret[0].(handshake.LongHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get0RTTSealer indicates an expected call of Get0RTTSealer. -func (mr *MockSealingManagerMockRecorder) Get0RTTSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockSealingManager)(nil).Get0RTTSealer)) -} - -// Get1RTTSealer mocks base method. -func (m *MockSealingManager) Get1RTTSealer() (handshake.ShortHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get1RTTSealer") - ret0, _ := ret[0].(handshake.ShortHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get1RTTSealer indicates an expected call of Get1RTTSealer. -func (mr *MockSealingManagerMockRecorder) Get1RTTSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockSealingManager)(nil).Get1RTTSealer)) -} - -// GetHandshakeSealer mocks base method. -func (m *MockSealingManager) GetHandshakeSealer() (handshake.LongHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHandshakeSealer") - ret0, _ := ret[0].(handshake.LongHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetHandshakeSealer indicates an expected call of GetHandshakeSealer. -func (mr *MockSealingManagerMockRecorder) GetHandshakeSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockSealingManager)(nil).GetHandshakeSealer)) -} - -// GetInitialSealer mocks base method. -func (m *MockSealingManager) GetInitialSealer() (handshake.LongHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetInitialSealer") - ret0, _ := ret[0].(handshake.LongHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetInitialSealer indicates an expected call of GetInitialSealer. -func (mr *MockSealingManagerMockRecorder) GetInitialSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockSealingManager)(nil).GetInitialSealer)) -} diff --git a/internal/quic-go/mock_send_conn_test.go b/internal/quic-go/mock_send_conn_test.go deleted file mode 100644 index d66fec5f..00000000 --- a/internal/quic-go/mock_send_conn_test.go +++ /dev/null @@ -1,91 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: send_conn.go - -// Package quic is a generated GoMock package. -package quic - -import ( - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockSendConn is a mock of SendConn interface. -type MockSendConn struct { - ctrl *gomock.Controller - recorder *MockSendConnMockRecorder -} - -// MockSendConnMockRecorder is the mock recorder for MockSendConn. -type MockSendConnMockRecorder struct { - mock *MockSendConn -} - -// NewMockSendConn creates a new mock instance. -func NewMockSendConn(ctrl *gomock.Controller) *MockSendConn { - mock := &MockSendConn{ctrl: ctrl} - mock.recorder = &MockSendConnMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSendConn) EXPECT() *MockSendConnMockRecorder { - return m.recorder -} - -// Close mocks base method. -func (m *MockSendConn) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockSendConnMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendConn)(nil).Close)) -} - -// LocalAddr mocks base method. -func (m *MockSendConn) LocalAddr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LocalAddr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// LocalAddr indicates an expected call of LocalAddr. -func (mr *MockSendConnMockRecorder) LocalAddr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockSendConn)(nil).LocalAddr)) -} - -// RemoteAddr mocks base method. -func (m *MockSendConn) RemoteAddr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoteAddr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// RemoteAddr indicates an expected call of RemoteAddr. -func (mr *MockSendConnMockRecorder) RemoteAddr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockSendConn)(nil).RemoteAddr)) -} - -// Write mocks base method. -func (m *MockSendConn) Write(arg0 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Write indicates an expected call of Write. -func (mr *MockSendConnMockRecorder) Write(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0) -} diff --git a/internal/quic-go/mock_send_stream_internal_test.go b/internal/quic-go/mock_send_stream_internal_test.go deleted file mode 100644 index 7ce194aa..00000000 --- a/internal/quic-go/mock_send_stream_internal_test.go +++ /dev/null @@ -1,187 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: send_stream.go - -// Package quic is a generated GoMock package. -package quic - -import ( - context "context" - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockSendStreamI is a mock of SendStreamI interface. -type MockSendStreamI struct { - ctrl *gomock.Controller - recorder *MockSendStreamIMockRecorder -} - -// MockSendStreamIMockRecorder is the mock recorder for MockSendStreamI. -type MockSendStreamIMockRecorder struct { - mock *MockSendStreamI -} - -// NewMockSendStreamI creates a new mock instance. -func NewMockSendStreamI(ctrl *gomock.Controller) *MockSendStreamI { - mock := &MockSendStreamI{ctrl: ctrl} - mock.recorder = &MockSendStreamIMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSendStreamI) EXPECT() *MockSendStreamIMockRecorder { - return m.recorder -} - -// CancelWrite mocks base method. -func (m *MockSendStreamI) CancelWrite(arg0 StreamErrorCode) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CancelWrite", arg0) -} - -// CancelWrite indicates an expected call of CancelWrite. -func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0) -} - -// Close mocks base method. -func (m *MockSendStreamI) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockSendStreamIMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendStreamI)(nil).Close)) -} - -// Context mocks base method. -func (m *MockSendStreamI) Context() context.Context { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Context") - ret0, _ := ret[0].(context.Context) - return ret0 -} - -// Context indicates an expected call of Context. -func (mr *MockSendStreamIMockRecorder) Context() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSendStreamI)(nil).Context)) -} - -// SetWriteDeadline mocks base method. -func (m *MockSendStreamI) SetWriteDeadline(t time.Time) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetWriteDeadline", t) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(t interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), t) -} - -// StreamID mocks base method. -func (m *MockSendStreamI) StreamID() StreamID { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StreamID") - ret0, _ := ret[0].(StreamID) - return ret0 -} - -// StreamID indicates an expected call of StreamID. -func (mr *MockSendStreamIMockRecorder) StreamID() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockSendStreamI)(nil).StreamID)) -} - -// Write mocks base method. -func (m *MockSendStreamI) Write(p []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", p) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. -func (mr *MockSendStreamIMockRecorder) Write(p interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), p) -} - -// closeForShutdown mocks base method. -func (m *MockSendStreamI) closeForShutdown(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "closeForShutdown", arg0) -} - -// closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0) -} - -// handleStopSendingFrame mocks base method. -func (m *MockSendStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "handleStopSendingFrame", arg0) -} - -// handleStopSendingFrame indicates an expected call of handleStopSendingFrame. -func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0) -} - -// hasData mocks base method. -func (m *MockSendStreamI) hasData() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "hasData") - ret0, _ := ret[0].(bool) - return ret0 -} - -// hasData indicates an expected call of hasData. -func (mr *MockSendStreamIMockRecorder) hasData() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockSendStreamI)(nil).hasData)) -} - -// popStreamFrame mocks base method. -func (m *MockSendStreamI) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "popStreamFrame", maxBytes) - ret0, _ := ret[0].(*ackhandler.Frame) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// popStreamFrame indicates an expected call of popStreamFrame. -func (mr *MockSendStreamIMockRecorder) popStreamFrame(maxBytes interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), maxBytes) -} - -// updateSendWindow mocks base method. -func (m *MockSendStreamI) updateSendWindow(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "updateSendWindow", arg0) -} - -// updateSendWindow indicates an expected call of updateSendWindow. -func (mr *MockSendStreamIMockRecorder) updateSendWindow(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockSendStreamI)(nil).updateSendWindow), arg0) -} diff --git a/internal/quic-go/mock_sender_test.go b/internal/quic-go/mock_sender_test.go deleted file mode 100644 index bad5f149..00000000 --- a/internal/quic-go/mock_sender_test.go +++ /dev/null @@ -1,100 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: send_queue.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockSender is a mock of Sender interface. -type MockSender struct { - ctrl *gomock.Controller - recorder *MockSenderMockRecorder -} - -// MockSenderMockRecorder is the mock recorder for MockSender. -type MockSenderMockRecorder struct { - mock *MockSender -} - -// NewMockSender creates a new mock instance. -func NewMockSender(ctrl *gomock.Controller) *MockSender { - mock := &MockSender{ctrl: ctrl} - mock.recorder = &MockSenderMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSender) EXPECT() *MockSenderMockRecorder { - return m.recorder -} - -// Available mocks base method. -func (m *MockSender) Available() <-chan struct{} { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Available") - ret0, _ := ret[0].(<-chan struct{}) - return ret0 -} - -// Available indicates an expected call of Available. -func (mr *MockSenderMockRecorder) Available() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Available", reflect.TypeOf((*MockSender)(nil).Available)) -} - -// Close mocks base method. -func (m *MockSender) Close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Close") -} - -// Close indicates an expected call of Close. -func (mr *MockSenderMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSender)(nil).Close)) -} - -// Run mocks base method. -func (m *MockSender) Run() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Run") - ret0, _ := ret[0].(error) - return ret0 -} - -// Run indicates an expected call of Run. -func (mr *MockSenderMockRecorder) Run() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSender)(nil).Run)) -} - -// Send mocks base method. -func (m *MockSender) Send(p *packetBuffer) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Send", p) -} - -// Send indicates an expected call of Send. -func (mr *MockSenderMockRecorder) Send(p interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), p) -} - -// WouldBlock mocks base method. -func (m *MockSender) WouldBlock() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WouldBlock") - ret0, _ := ret[0].(bool) - return ret0 -} - -// WouldBlock indicates an expected call of WouldBlock. -func (mr *MockSenderMockRecorder) WouldBlock() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WouldBlock", reflect.TypeOf((*MockSender)(nil).WouldBlock)) -} diff --git a/internal/quic-go/mock_stream_getter_test.go b/internal/quic-go/mock_stream_getter_test.go deleted file mode 100644 index 71df5186..00000000 --- a/internal/quic-go/mock_stream_getter_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: connection.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockStreamGetter is a mock of StreamGetter interface. -type MockStreamGetter struct { - ctrl *gomock.Controller - recorder *MockStreamGetterMockRecorder -} - -// MockStreamGetterMockRecorder is the mock recorder for MockStreamGetter. -type MockStreamGetterMockRecorder struct { - mock *MockStreamGetter -} - -// NewMockStreamGetter creates a new mock instance. -func NewMockStreamGetter(ctrl *gomock.Controller) *MockStreamGetter { - mock := &MockStreamGetter{ctrl: ctrl} - mock.recorder = &MockStreamGetterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStreamGetter) EXPECT() *MockStreamGetterMockRecorder { - return m.recorder -} - -// GetOrOpenReceiveStream mocks base method. -func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0) - ret0, _ := ret[0].(receiveStreamI) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream. -func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0) -} - -// GetOrOpenSendStream mocks base method. -func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0) - ret0, _ := ret[0].(sendStreamI) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream. -func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0) -} diff --git a/internal/quic-go/mock_stream_internal_test.go b/internal/quic-go/mock_stream_internal_test.go deleted file mode 100644 index 381eb2ae..00000000 --- a/internal/quic-go/mock_stream_internal_test.go +++ /dev/null @@ -1,284 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: stream.go - -// Package quic is a generated GoMock package. -package quic - -import ( - context "context" - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockStreamI is a mock of StreamI interface. -type MockStreamI struct { - ctrl *gomock.Controller - recorder *MockStreamIMockRecorder -} - -// MockStreamIMockRecorder is the mock recorder for MockStreamI. -type MockStreamIMockRecorder struct { - mock *MockStreamI -} - -// NewMockStreamI creates a new mock instance. -func NewMockStreamI(ctrl *gomock.Controller) *MockStreamI { - mock := &MockStreamI{ctrl: ctrl} - mock.recorder = &MockStreamIMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStreamI) EXPECT() *MockStreamIMockRecorder { - return m.recorder -} - -// CancelRead mocks base method. -func (m *MockStreamI) CancelRead(arg0 StreamErrorCode) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CancelRead", arg0) -} - -// CancelRead indicates an expected call of CancelRead. -func (mr *MockStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStreamI)(nil).CancelRead), arg0) -} - -// CancelWrite mocks base method. -func (m *MockStreamI) CancelWrite(arg0 StreamErrorCode) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CancelWrite", arg0) -} - -// CancelWrite indicates an expected call of CancelWrite. -func (mr *MockStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStreamI)(nil).CancelWrite), arg0) -} - -// Close mocks base method. -func (m *MockStreamI) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockStreamIMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStreamI)(nil).Close)) -} - -// Context mocks base method. -func (m *MockStreamI) Context() context.Context { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Context") - ret0, _ := ret[0].(context.Context) - return ret0 -} - -// Context indicates an expected call of Context. -func (mr *MockStreamIMockRecorder) Context() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStreamI)(nil).Context)) -} - -// Read mocks base method. -func (m *MockStreamI) Read(p []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Read", p) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Read indicates an expected call of Read. -func (mr *MockStreamIMockRecorder) Read(p interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), p) -} - -// SetDeadline mocks base method. -func (m *MockStreamI) SetDeadline(t time.Time) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetDeadline", t) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetDeadline indicates an expected call of SetDeadline. -func (mr *MockStreamIMockRecorder) SetDeadline(t interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStreamI)(nil).SetDeadline), t) -} - -// SetReadDeadline mocks base method. -func (m *MockStreamI) SetReadDeadline(t time.Time) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetReadDeadline", t) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockStreamIMockRecorder) SetReadDeadline(t interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStreamI)(nil).SetReadDeadline), t) -} - -// SetWriteDeadline mocks base method. -func (m *MockStreamI) SetWriteDeadline(t time.Time) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetWriteDeadline", t) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockStreamIMockRecorder) SetWriteDeadline(t interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), t) -} - -// StreamID mocks base method. -func (m *MockStreamI) StreamID() StreamID { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StreamID") - ret0, _ := ret[0].(StreamID) - return ret0 -} - -// StreamID indicates an expected call of StreamID. -func (mr *MockStreamIMockRecorder) StreamID() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStreamI)(nil).StreamID)) -} - -// Write mocks base method. -func (m *MockStreamI) Write(p []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", p) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. -func (mr *MockStreamIMockRecorder) Write(p interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStreamI)(nil).Write), p) -} - -// closeForShutdown mocks base method. -func (m *MockStreamI) closeForShutdown(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "closeForShutdown", arg0) -} - -// closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0) -} - -// getWindowUpdate mocks base method. -func (m *MockStreamI) getWindowUpdate() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getWindowUpdate") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// getWindowUpdate indicates an expected call of getWindowUpdate. -func (mr *MockStreamIMockRecorder) getWindowUpdate() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).getWindowUpdate)) -} - -// handleResetStreamFrame mocks base method. -func (m *MockStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "handleResetStreamFrame", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// handleResetStreamFrame indicates an expected call of handleResetStreamFrame. -func (mr *MockStreamIMockRecorder) handleResetStreamFrame(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleResetStreamFrame), arg0) -} - -// handleStopSendingFrame mocks base method. -func (m *MockStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "handleStopSendingFrame", arg0) -} - -// handleStopSendingFrame indicates an expected call of handleStopSendingFrame. -func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockStreamI)(nil).handleStopSendingFrame), arg0) -} - -// handleStreamFrame mocks base method. -func (m *MockStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "handleStreamFrame", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// handleStreamFrame indicates an expected call of handleStreamFrame. -func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleStreamFrame), arg0) -} - -// hasData mocks base method. -func (m *MockStreamI) hasData() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "hasData") - ret0, _ := ret[0].(bool) - return ret0 -} - -// hasData indicates an expected call of hasData. -func (mr *MockStreamIMockRecorder) hasData() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockStreamI)(nil).hasData)) -} - -// popStreamFrame mocks base method. -func (m *MockStreamI) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "popStreamFrame", maxBytes) - ret0, _ := ret[0].(*ackhandler.Frame) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// popStreamFrame indicates an expected call of popStreamFrame. -func (mr *MockStreamIMockRecorder) popStreamFrame(maxBytes interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamI)(nil).popStreamFrame), maxBytes) -} - -// updateSendWindow mocks base method. -func (m *MockStreamI) updateSendWindow(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "updateSendWindow", arg0) -} - -// updateSendWindow indicates an expected call of updateSendWindow. -func (mr *MockStreamIMockRecorder) updateSendWindow(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockStreamI)(nil).updateSendWindow), arg0) -} diff --git a/internal/quic-go/mock_stream_manager_test.go b/internal/quic-go/mock_stream_manager_test.go deleted file mode 100644 index 34d7b72c..00000000 --- a/internal/quic-go/mock_stream_manager_test.go +++ /dev/null @@ -1,231 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: connection.go - -// Package quic is a generated GoMock package. -package quic - -import ( - context "context" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockStreamManager is a mock of StreamManager interface. -type MockStreamManager struct { - ctrl *gomock.Controller - recorder *MockStreamManagerMockRecorder -} - -// MockStreamManagerMockRecorder is the mock recorder for MockStreamManager. -type MockStreamManagerMockRecorder struct { - mock *MockStreamManager -} - -// NewMockStreamManager creates a new mock instance. -func NewMockStreamManager(ctrl *gomock.Controller) *MockStreamManager { - mock := &MockStreamManager{ctrl: ctrl} - mock.recorder = &MockStreamManagerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder { - return m.recorder -} - -// AcceptStream mocks base method. -func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptStream", arg0) - ret0, _ := ret[0].(Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AcceptStream indicates an expected call of AcceptStream. -func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0) -} - -// AcceptUniStream mocks base method. -func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptUniStream", arg0) - ret0, _ := ret[0].(ReceiveStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0) -} - -// CloseWithError mocks base method. -func (m *MockStreamManager) CloseWithError(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CloseWithError", arg0) -} - -// CloseWithError indicates an expected call of CloseWithError. -func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0) -} - -// DeleteStream mocks base method. -func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteStream", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteStream indicates an expected call of DeleteStream. -func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0) -} - -// GetOrOpenReceiveStream mocks base method. -func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0) - ret0, _ := ret[0].(receiveStreamI) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream. -func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0) -} - -// GetOrOpenSendStream mocks base method. -func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0) - ret0, _ := ret[0].(sendStreamI) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream. -func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0) -} - -// HandleMaxStreamsFrame mocks base method. -func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "HandleMaxStreamsFrame", arg0) -} - -// HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame. -func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0) -} - -// OpenStream mocks base method. -func (m *MockStreamManager) OpenStream() (Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStream") - ret0, _ := ret[0].(Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenStream indicates an expected call of OpenStream. -func (mr *MockStreamManagerMockRecorder) OpenStream() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamManager)(nil).OpenStream)) -} - -// OpenStreamSync mocks base method. -func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStreamSync", arg0) - ret0, _ := ret[0].(Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0) -} - -// OpenUniStream mocks base method. -func (m *MockStreamManager) OpenUniStream() (SendStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStream") - ret0, _ := ret[0].(SendStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenUniStream indicates an expected call of OpenUniStream. -func (mr *MockStreamManagerMockRecorder) OpenUniStream() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStream)) -} - -// OpenUniStreamSync mocks base method. -func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (SendStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) - ret0, _ := ret[0].(SendStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0) -} - -// ResetFor0RTT mocks base method. -func (m *MockStreamManager) ResetFor0RTT() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ResetFor0RTT") -} - -// ResetFor0RTT indicates an expected call of ResetFor0RTT. -func (mr *MockStreamManagerMockRecorder) ResetFor0RTT() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetFor0RTT", reflect.TypeOf((*MockStreamManager)(nil).ResetFor0RTT)) -} - -// UpdateLimits mocks base method. -func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateLimits", arg0) -} - -// UpdateLimits indicates an expected call of UpdateLimits. -func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0) -} - -// UseResetMaps mocks base method. -func (m *MockStreamManager) UseResetMaps() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UseResetMaps") -} - -// UseResetMaps indicates an expected call of UseResetMaps. -func (mr *MockStreamManagerMockRecorder) UseResetMaps() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UseResetMaps", reflect.TypeOf((*MockStreamManager)(nil).UseResetMaps)) -} diff --git a/internal/quic-go/mock_stream_sender_test.go b/internal/quic-go/mock_stream_sender_test.go deleted file mode 100644 index 3cd97f48..00000000 --- a/internal/quic-go/mock_stream_sender_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: stream.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockStreamSender is a mock of StreamSender interface. -type MockStreamSender struct { - ctrl *gomock.Controller - recorder *MockStreamSenderMockRecorder -} - -// MockStreamSenderMockRecorder is the mock recorder for MockStreamSender. -type MockStreamSenderMockRecorder struct { - mock *MockStreamSender -} - -// NewMockStreamSender creates a new mock instance. -func NewMockStreamSender(ctrl *gomock.Controller) *MockStreamSender { - mock := &MockStreamSender{ctrl: ctrl} - mock.recorder = &MockStreamSenderMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder { - return m.recorder -} - -// onHasStreamData mocks base method. -func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "onHasStreamData", arg0) -} - -// onHasStreamData indicates an expected call of onHasStreamData. -func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0) -} - -// onStreamCompleted mocks base method. -func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "onStreamCompleted", arg0) -} - -// onStreamCompleted indicates an expected call of onStreamCompleted. -func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0) -} - -// queueControlFrame mocks base method. -func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "queueControlFrame", arg0) -} - -// queueControlFrame indicates an expected call of queueControlFrame. -func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0) -} diff --git a/internal/quic-go/mock_token_store_test.go b/internal/quic-go/mock_token_store_test.go deleted file mode 100644 index a0f02b41..00000000 --- a/internal/quic-go/mock_token_store_test.go +++ /dev/null @@ -1,60 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go (interfaces: TokenStore) - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockTokenStore is a mock of TokenStore interface. -type MockTokenStore struct { - ctrl *gomock.Controller - recorder *MockTokenStoreMockRecorder -} - -// MockTokenStoreMockRecorder is the mock recorder for MockTokenStore. -type MockTokenStoreMockRecorder struct { - mock *MockTokenStore -} - -// NewMockTokenStore creates a new mock instance. -func NewMockTokenStore(ctrl *gomock.Controller) *MockTokenStore { - mock := &MockTokenStore{ctrl: ctrl} - mock.recorder = &MockTokenStoreMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTokenStore) EXPECT() *MockTokenStoreMockRecorder { - return m.recorder -} - -// Pop mocks base method. -func (m *MockTokenStore) Pop(arg0 string) *ClientToken { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Pop", arg0) - ret0, _ := ret[0].(*ClientToken) - return ret0 -} - -// Pop indicates an expected call of Pop. -func (mr *MockTokenStoreMockRecorder) Pop(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pop", reflect.TypeOf((*MockTokenStore)(nil).Pop), arg0) -} - -// Put mocks base method. -func (m *MockTokenStore) Put(arg0 string, arg1 *ClientToken) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Put", arg0, arg1) -} - -// Put indicates an expected call of Put. -func (mr *MockTokenStoreMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockTokenStore)(nil).Put), arg0, arg1) -} diff --git a/internal/quic-go/mock_unknown_packet_handler_test.go b/internal/quic-go/mock_unknown_packet_handler_test.go deleted file mode 100644 index d82acf1a..00000000 --- a/internal/quic-go/mock_unknown_packet_handler_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: server.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockUnknownPacketHandler is a mock of UnknownPacketHandler interface. -type MockUnknownPacketHandler struct { - ctrl *gomock.Controller - recorder *MockUnknownPacketHandlerMockRecorder -} - -// MockUnknownPacketHandlerMockRecorder is the mock recorder for MockUnknownPacketHandler. -type MockUnknownPacketHandlerMockRecorder struct { - mock *MockUnknownPacketHandler -} - -// NewMockUnknownPacketHandler creates a new mock instance. -func NewMockUnknownPacketHandler(ctrl *gomock.Controller) *MockUnknownPacketHandler { - mock := &MockUnknownPacketHandler{ctrl: ctrl} - mock.recorder = &MockUnknownPacketHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorder { - return m.recorder -} - -// handlePacket mocks base method. -func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "handlePacket", arg0) -} - -// handlePacket indicates an expected call of handlePacket. -func (mr *MockUnknownPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0) -} - -// setCloseError mocks base method. -func (m *MockUnknownPacketHandler) setCloseError(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "setCloseError", arg0) -} - -// setCloseError indicates an expected call of setCloseError. -func (mr *MockUnknownPacketHandlerMockRecorder) setCloseError(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setCloseError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).setCloseError), arg0) -} diff --git a/internal/quic-go/mock_unpacker_test.go b/internal/quic-go/mock_unpacker_test.go deleted file mode 100644 index 0dca2d03..00000000 --- a/internal/quic-go/mock_unpacker_test.go +++ /dev/null @@ -1,51 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: connection.go - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockUnpacker is a mock of Unpacker interface. -type MockUnpacker struct { - ctrl *gomock.Controller - recorder *MockUnpackerMockRecorder -} - -// MockUnpackerMockRecorder is the mock recorder for MockUnpacker. -type MockUnpackerMockRecorder struct { - mock *MockUnpacker -} - -// NewMockUnpacker creates a new mock instance. -func NewMockUnpacker(ctrl *gomock.Controller) *MockUnpacker { - mock := &MockUnpacker{ctrl: ctrl} - mock.recorder = &MockUnpackerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder { - return m.recorder -} - -// Unpack mocks base method. -func (m *MockUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Unpack", hdr, rcvTime, data) - ret0, _ := ret[0].(*unpackedPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Unpack indicates an expected call of Unpack. -func (mr *MockUnpackerMockRecorder) Unpack(hdr, rcvTime, data interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), hdr, rcvTime, data) -} diff --git a/internal/quic-go/mockgen.go b/internal/quic-go/mockgen.go deleted file mode 100644 index edecbbe3..00000000 --- a/internal/quic-go/mockgen.go +++ /dev/null @@ -1,27 +0,0 @@ -package quic - -//go:generate sh -c "./mockgen_private.sh quic mock_send_conn_test.go github.com/imroc/req/v3/internal/quic-go sendConn" -//go:generate sh -c "./mockgen_private.sh quic mock_sender_test.go github.com/imroc/req/v3/internal/quic-go sender" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/imroc/req/v3/internal/quic-go streamI" -//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/imroc/req/v3/internal/quic-go cryptoStream" -//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/imroc/req/v3/internal/quic-go receiveStreamI" -//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/imroc/req/v3/internal/quic-go sendStreamI" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/imroc/req/v3/internal/quic-go streamSender" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/imroc/req/v3/internal/quic-go streamGetter" -//go:generate sh -c "./mockgen_private.sh quic mock_crypto_data_handler_test.go github.com/imroc/req/v3/internal/quic-go cryptoDataHandler" -//go:generate sh -c "./mockgen_private.sh quic mock_frame_source_test.go github.com/imroc/req/v3/internal/quic-go frameSource" -//go:generate sh -c "./mockgen_private.sh quic mock_ack_frame_source_test.go github.com/imroc/req/v3/internal/quic-go ackFrameSource" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/imroc/req/v3/internal/quic-go streamManager" -//go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/imroc/req/v3/internal/quic-go sealingManager" -//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/imroc/req/v3/internal/quic-go unpacker" -//go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/imroc/req/v3/internal/quic-go packer" -//go:generate sh -c "./mockgen_private.sh quic mock_mtu_discoverer_test.go github.com/imroc/req/v3/internal/quic-go mtuDiscoverer" -//go:generate sh -c "./mockgen_private.sh quic mock_conn_runner_test.go github.com/imroc/req/v3/internal/quic-go connRunner" -//go:generate sh -c "./mockgen_private.sh quic mock_quic_conn_test.go github.com/imroc/req/v3/internal/quic-go quicConn" -//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/imroc/req/v3/internal/quic-go packetHandler" -//go:generate sh -c "./mockgen_private.sh quic mock_unknown_packet_handler_test.go github.com/imroc/req/v3/internal/quic-go unknownPacketHandler" -//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/imroc/req/v3/internal/quic-go packetHandlerManager" -//go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/imroc/req/v3/internal/quic-go multiplexer" -//go:generate sh -c "./mockgen_private.sh quic mock_batch_conn_test.go github.com/imroc/req/v3/internal/quic-go batchConn" -//go:generate sh -c "mockgen -package quic -self_package github.com/imroc/req/v3/internal/quic-go -destination mock_token_store_test.go github.com/imroc/req/v3/internal/quic-go TokenStore" -//go:generate sh -c "mockgen -package quic -self_package github.com/imroc/req/v3/internal/quic-go -destination mock_packetconn_test.go net PacketConn" diff --git a/internal/quic-go/mockgen_private.sh b/internal/quic-go/mockgen_private.sh deleted file mode 100755 index 92829d77..00000000 --- a/internal/quic-go/mockgen_private.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash - -DEST=$2 -PACKAGE=$3 -TMPFILE="mockgen_tmp.go" -# uppercase the name of the interface -ORIG_INTERFACE_NAME=$4 -INTERFACE_NAME="$(tr '[:lower:]' '[:upper:]' <<< ${ORIG_INTERFACE_NAME:0:1})${ORIG_INTERFACE_NAME:1}" - -# Gather all files that contain interface definitions. -# These interfaces might be used as embedded interfaces, -# so we need to pass them to mockgen as aux_files. -AUX=() -for f in *.go; do - if [[ -z ${f##*_test.go} ]]; then - # skip test files - continue; - fi - if $(egrep -qe "type (.*) interface" $f); then - AUX+=("github.com/lucas-clemente/quic-go=$f") - fi -done - -# Find the file that defines the interface we're mocking. -for f in *.go; do - if [[ -z ${f##*_test.go} ]]; then - # skip test files - continue; - fi - INTERFACE=$(sed -n "/^type $ORIG_INTERFACE_NAME interface/,/^}/p" $f) - if [[ -n "$INTERFACE" ]]; then - SRC=$f - break - fi -done - -if [[ -z "$INTERFACE" ]]; then - echo "Interface $ORIG_INTERFACE_NAME not found." - exit 1 -fi - -AUX_FILES=$(IFS=, ; echo "${AUX[*]}") - -## create a public alias for the interface, so that mockgen can process it -echo -e "package $1\n" > $TMPFILE -echo "$INTERFACE" | sed "s/$ORIG_INTERFACE_NAME/$INTERFACE_NAME/" >> $TMPFILE -mockgen -package $1 -self_package $3 -destination $DEST -source=$TMPFILE -aux_files $AUX_FILES -sed "s/$TMPFILE/$SRC/" "$DEST" > "$DEST.new" && mv "$DEST.new" "$DEST" -rm "$TMPFILE" diff --git a/internal/quic-go/mocks/ackhandler/received_packet_handler.go b/internal/quic-go/mocks/ackhandler/received_packet_handler.go deleted file mode 100644 index 004fed2e..00000000 --- a/internal/quic-go/mocks/ackhandler/received_packet_handler.go +++ /dev/null @@ -1,105 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/ackhandler (interfaces: ReceivedPacketHandler) - -// Package mockackhandler is a generated GoMock package. -package mockackhandler - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockReceivedPacketHandler is a mock of ReceivedPacketHandler interface. -type MockReceivedPacketHandler struct { - ctrl *gomock.Controller - recorder *MockReceivedPacketHandlerMockRecorder -} - -// MockReceivedPacketHandlerMockRecorder is the mock recorder for MockReceivedPacketHandler. -type MockReceivedPacketHandlerMockRecorder struct { - mock *MockReceivedPacketHandler -} - -// NewMockReceivedPacketHandler creates a new mock instance. -func NewMockReceivedPacketHandler(ctrl *gomock.Controller) *MockReceivedPacketHandler { - mock := &MockReceivedPacketHandler{ctrl: ctrl} - mock.recorder = &MockReceivedPacketHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecorder { - return m.recorder -} - -// DropPackets mocks base method. -func (m *MockReceivedPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DropPackets", arg0) -} - -// DropPackets indicates an expected call of DropPackets. -func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0) -} - -// GetAckFrame mocks base method. -func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel, arg1 bool) *wire.AckFrame { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAckFrame", arg0, arg1) - ret0, _ := ret[0].(*wire.AckFrame) - return ret0 -} - -// GetAckFrame indicates an expected call of GetAckFrame. -func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0, arg1) -} - -// GetAlarmTimeout mocks base method. -func (m *MockReceivedPacketHandler) GetAlarmTimeout() time.Time { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAlarmTimeout") - ret0, _ := ret[0].(time.Time) - return ret0 -} - -// GetAlarmTimeout indicates an expected call of GetAlarmTimeout. -func (mr *MockReceivedPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAlarmTimeout)) -} - -// IsPotentiallyDuplicate mocks base method. -func (m *MockReceivedPacketHandler) IsPotentiallyDuplicate(arg0 protocol.PacketNumber, arg1 protocol.EncryptionLevel) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsPotentiallyDuplicate", arg0, arg1) - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsPotentiallyDuplicate indicates an expected call of IsPotentiallyDuplicate. -func (mr *MockReceivedPacketHandlerMockRecorder) IsPotentiallyDuplicate(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPotentiallyDuplicate", reflect.TypeOf((*MockReceivedPacketHandler)(nil).IsPotentiallyDuplicate), arg0, arg1) -} - -// ReceivedPacket mocks base method. -func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.ECN, arg2 protocol.EncryptionLevel, arg3 time.Time, arg4 bool) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3, arg4) -} diff --git a/internal/quic-go/mocks/ackhandler/sent_packet_handler.go b/internal/quic-go/mocks/ackhandler/sent_packet_handler.go deleted file mode 100644 index ed16a6ed..00000000 --- a/internal/quic-go/mocks/ackhandler/sent_packet_handler.go +++ /dev/null @@ -1,240 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/ackhandler (interfaces: SentPacketHandler) - -// Package mockackhandler is a generated GoMock package. -package mockackhandler - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - ackhandler "github.com/imroc/req/v3/internal/quic-go/ackhandler" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// MockSentPacketHandler is a mock of SentPacketHandler interface. -type MockSentPacketHandler struct { - ctrl *gomock.Controller - recorder *MockSentPacketHandlerMockRecorder -} - -// MockSentPacketHandlerMockRecorder is the mock recorder for MockSentPacketHandler. -type MockSentPacketHandlerMockRecorder struct { - mock *MockSentPacketHandler -} - -// NewMockSentPacketHandler creates a new mock instance. -func NewMockSentPacketHandler(ctrl *gomock.Controller) *MockSentPacketHandler { - mock := &MockSentPacketHandler{ctrl: ctrl} - mock.recorder = &MockSentPacketHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder { - return m.recorder -} - -// DropPackets mocks base method. -func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DropPackets", arg0) -} - -// DropPackets indicates an expected call of DropPackets. -func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) -} - -// GetLossDetectionTimeout mocks base method. -func (m *MockSentPacketHandler) GetLossDetectionTimeout() time.Time { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLossDetectionTimeout") - ret0, _ := ret[0].(time.Time) - return ret0 -} - -// GetLossDetectionTimeout indicates an expected call of GetLossDetectionTimeout. -func (mr *MockSentPacketHandlerMockRecorder) GetLossDetectionTimeout() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLossDetectionTimeout)) -} - -// HasPacingBudget mocks base method. -func (m *MockSentPacketHandler) HasPacingBudget() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasPacingBudget") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasPacingBudget indicates an expected call of HasPacingBudget. -func (mr *MockSentPacketHandlerMockRecorder) HasPacingBudget() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSentPacketHandler)(nil).HasPacingBudget)) -} - -// OnLossDetectionTimeout mocks base method. -func (m *MockSentPacketHandler) OnLossDetectionTimeout() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnLossDetectionTimeout") - ret0, _ := ret[0].(error) - return ret0 -} - -// OnLossDetectionTimeout indicates an expected call of OnLossDetectionTimeout. -func (mr *MockSentPacketHandlerMockRecorder) OnLossDetectionTimeout() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).OnLossDetectionTimeout)) -} - -// PeekPacketNumber mocks base method. -func (m *MockSentPacketHandler) PeekPacketNumber(arg0 protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PeekPacketNumber", arg0) - ret0, _ := ret[0].(protocol.PacketNumber) - ret1, _ := ret[1].(protocol.PacketNumberLen) - return ret0, ret1 -} - -// PeekPacketNumber indicates an expected call of PeekPacketNumber. -func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber), arg0) -} - -// PopPacketNumber mocks base method. -func (m *MockSentPacketHandler) PopPacketNumber(arg0 protocol.EncryptionLevel) protocol.PacketNumber { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PopPacketNumber", arg0) - ret0, _ := ret[0].(protocol.PacketNumber) - return ret0 -} - -// PopPacketNumber indicates an expected call of PopPacketNumber. -func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber), arg0) -} - -// QueueProbePacket mocks base method. -func (m *MockSentPacketHandler) QueueProbePacket(arg0 protocol.EncryptionLevel) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueueProbePacket", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// QueueProbePacket indicates an expected call of QueueProbePacket. -func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket), arg0) -} - -// ReceivedAck mocks base method. -func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.EncryptionLevel, arg2 time.Time) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceivedAck", arg0, arg1, arg2) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReceivedAck indicates an expected call of ReceivedAck. -func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2) -} - -// ReceivedBytes mocks base method. -func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedBytes", arg0) -} - -// ReceivedBytes indicates an expected call of ReceivedBytes. -func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0) -} - -// ResetForRetry mocks base method. -func (m *MockSentPacketHandler) ResetForRetry() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ResetForRetry") - ret0, _ := ret[0].(error) - return ret0 -} - -// ResetForRetry indicates an expected call of ResetForRetry. -func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry)) -} - -// SendMode mocks base method. -func (m *MockSentPacketHandler) SendMode() ackhandler.SendMode { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendMode") - ret0, _ := ret[0].(ackhandler.SendMode) - return ret0 -} - -// SendMode indicates an expected call of SendMode. -func (mr *MockSentPacketHandlerMockRecorder) SendMode() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMode", reflect.TypeOf((*MockSentPacketHandler)(nil).SendMode)) -} - -// SentPacket mocks base method. -func (m *MockSentPacketHandler) SentPacket(arg0 *ackhandler.Packet) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0) -} - -// SetHandshakeConfirmed mocks base method. -func (m *MockSentPacketHandler) SetHandshakeConfirmed() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetHandshakeConfirmed") -} - -// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. -func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeConfirmed() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeConfirmed)) -} - -// SetMaxDatagramSize mocks base method. -func (m *MockSentPacketHandler) SetMaxDatagramSize(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetMaxDatagramSize", arg0) -} - -// SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. -func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSentPacketHandler)(nil).SetMaxDatagramSize), arg0) -} - -// TimeUntilSend mocks base method. -func (m *MockSentPacketHandler) TimeUntilSend() time.Time { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TimeUntilSend") - ret0, _ := ret[0].(time.Time) - return ret0 -} - -// TimeUntilSend indicates an expected call of TimeUntilSend. -func (mr *MockSentPacketHandlerMockRecorder) TimeUntilSend() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSentPacketHandler)(nil).TimeUntilSend)) -} diff --git a/internal/quic-go/mocks/congestion.go b/internal/quic-go/mocks/congestion.go deleted file mode 100644 index 6c92f86a..00000000 --- a/internal/quic-go/mocks/congestion.go +++ /dev/null @@ -1,192 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/congestion (interfaces: SendAlgorithmWithDebugInfos) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockSendAlgorithmWithDebugInfos is a mock of SendAlgorithmWithDebugInfos interface. -type MockSendAlgorithmWithDebugInfos struct { - ctrl *gomock.Controller - recorder *MockSendAlgorithmWithDebugInfosMockRecorder -} - -// MockSendAlgorithmWithDebugInfosMockRecorder is the mock recorder for MockSendAlgorithmWithDebugInfos. -type MockSendAlgorithmWithDebugInfosMockRecorder struct { - mock *MockSendAlgorithmWithDebugInfos -} - -// NewMockSendAlgorithmWithDebugInfos creates a new mock instance. -func NewMockSendAlgorithmWithDebugInfos(ctrl *gomock.Controller) *MockSendAlgorithmWithDebugInfos { - mock := &MockSendAlgorithmWithDebugInfos{ctrl: ctrl} - mock.recorder = &MockSendAlgorithmWithDebugInfosMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSendAlgorithmWithDebugInfos) EXPECT() *MockSendAlgorithmWithDebugInfosMockRecorder { - return m.recorder -} - -// CanSend mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) CanSend(arg0 protocol.ByteCount) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CanSend", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// CanSend indicates an expected call of CanSend. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) CanSend(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).CanSend), arg0) -} - -// GetCongestionWindow mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) GetCongestionWindow() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetCongestionWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetCongestionWindow indicates an expected call of GetCongestionWindow. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) GetCongestionWindow() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCongestionWindow", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).GetCongestionWindow)) -} - -// HasPacingBudget mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) HasPacingBudget() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasPacingBudget") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasPacingBudget indicates an expected call of HasPacingBudget. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) HasPacingBudget() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).HasPacingBudget)) -} - -// InRecovery mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) InRecovery() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InRecovery") - ret0, _ := ret[0].(bool) - return ret0 -} - -// InRecovery indicates an expected call of InRecovery. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InRecovery() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InRecovery", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InRecovery)) -} - -// InSlowStart mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) InSlowStart() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InSlowStart") - ret0, _ := ret[0].(bool) - return ret0 -} - -// InSlowStart indicates an expected call of InSlowStart. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InSlowStart() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InSlowStart)) -} - -// MaybeExitSlowStart mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) MaybeExitSlowStart() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "MaybeExitSlowStart") -} - -// MaybeExitSlowStart indicates an expected call of MaybeExitSlowStart. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) MaybeExitSlowStart() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeExitSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).MaybeExitSlowStart)) -} - -// OnPacketAcked mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount, arg3 time.Time) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2, arg3) -} - -// OnPacketAcked indicates an expected call of OnPacketAcked. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), arg0, arg1, arg2, arg3) -} - -// OnPacketLost mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) OnPacketLost(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPacketLost", arg0, arg1, arg2) -} - -// OnPacketLost indicates an expected call of OnPacketLost. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketLost(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketLost", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketLost), arg0, arg1, arg2) -} - -// OnPacketSent mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPacketSent", arg0, arg1, arg2, arg3, arg4) -} - -// OnPacketSent indicates an expected call of OnPacketSent. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketSent(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketSent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketSent), arg0, arg1, arg2, arg3, arg4) -} - -// OnRetransmissionTimeout mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) OnRetransmissionTimeout(arg0 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnRetransmissionTimeout", arg0) -} - -// OnRetransmissionTimeout indicates an expected call of OnRetransmissionTimeout. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnRetransmissionTimeout), arg0) -} - -// SetMaxDatagramSize mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) SetMaxDatagramSize(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetMaxDatagramSize", arg0) -} - -// SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).SetMaxDatagramSize), arg0) -} - -// TimeUntilSend mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) TimeUntilSend(arg0 protocol.ByteCount) time.Time { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TimeUntilSend", arg0) - ret0, _ := ret[0].(time.Time) - return ret0 -} - -// TimeUntilSend indicates an expected call of TimeUntilSend. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) TimeUntilSend(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).TimeUntilSend), arg0) -} diff --git a/internal/quic-go/mocks/connection_flow_controller.go b/internal/quic-go/mocks/connection_flow_controller.go deleted file mode 100644 index ee8a14ea..00000000 --- a/internal/quic-go/mocks/connection_flow_controller.go +++ /dev/null @@ -1,128 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/flowcontrol (interfaces: ConnectionFlowController) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockConnectionFlowController is a mock of ConnectionFlowController interface. -type MockConnectionFlowController struct { - ctrl *gomock.Controller - recorder *MockConnectionFlowControllerMockRecorder -} - -// MockConnectionFlowControllerMockRecorder is the mock recorder for MockConnectionFlowController. -type MockConnectionFlowControllerMockRecorder struct { - mock *MockConnectionFlowController -} - -// NewMockConnectionFlowController creates a new mock instance. -func NewMockConnectionFlowController(ctrl *gomock.Controller) *MockConnectionFlowController { - mock := &MockConnectionFlowController{ctrl: ctrl} - mock.recorder = &MockConnectionFlowControllerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConnectionFlowController) EXPECT() *MockConnectionFlowControllerMockRecorder { - return m.recorder -} - -// AddBytesRead mocks base method. -func (m *MockConnectionFlowController) AddBytesRead(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddBytesRead", arg0) -} - -// AddBytesRead indicates an expected call of AddBytesRead. -func (mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesRead), arg0) -} - -// AddBytesSent mocks base method. -func (m *MockConnectionFlowController) AddBytesSent(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddBytesSent", arg0) -} - -// AddBytesSent indicates an expected call of AddBytesSent. -func (mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesSent), arg0) -} - -// GetWindowUpdate mocks base method. -func (m *MockConnectionFlowController) GetWindowUpdate() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWindowUpdate") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetWindowUpdate indicates an expected call of GetWindowUpdate. -func (mr *MockConnectionFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).GetWindowUpdate)) -} - -// IsNewlyBlocked mocks base method. -func (m *MockConnectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsNewlyBlocked") - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(protocol.ByteCount) - return ret0, ret1 -} - -// IsNewlyBlocked indicates an expected call of IsNewlyBlocked. -func (mr *MockConnectionFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsNewlyBlocked)) -} - -// Reset mocks base method. -func (m *MockConnectionFlowController) Reset() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Reset") - ret0, _ := ret[0].(error) - return ret0 -} - -// Reset indicates an expected call of Reset. -func (mr *MockConnectionFlowControllerMockRecorder) Reset() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reset", reflect.TypeOf((*MockConnectionFlowController)(nil).Reset)) -} - -// SendWindowSize mocks base method. -func (m *MockConnectionFlowController) SendWindowSize() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendWindowSize") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// SendWindowSize indicates an expected call of SendWindowSize. -func (mr *MockConnectionFlowControllerMockRecorder) SendWindowSize() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockConnectionFlowController)(nil).SendWindowSize)) -} - -// UpdateSendWindow mocks base method. -func (m *MockConnectionFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateSendWindow", arg0) -} - -// UpdateSendWindow indicates an expected call of UpdateSendWindow. -func (mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockConnectionFlowController)(nil).UpdateSendWindow), arg0) -} diff --git a/internal/quic-go/mocks/crypto_setup.go b/internal/quic-go/mocks/crypto_setup.go deleted file mode 100644 index 86e21aa8..00000000 --- a/internal/quic-go/mocks/crypto_setup.go +++ /dev/null @@ -1,264 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/handshake (interfaces: CryptoSetup) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - handshake "github.com/imroc/req/v3/internal/quic-go/handshake" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - qtls "github.com/imroc/req/v3/internal/quic-go/qtls" -) - -// MockCryptoSetup is a mock of CryptoSetup interface. -type MockCryptoSetup struct { - ctrl *gomock.Controller - recorder *MockCryptoSetupMockRecorder -} - -// MockCryptoSetupMockRecorder is the mock recorder for MockCryptoSetup. -type MockCryptoSetupMockRecorder struct { - mock *MockCryptoSetup -} - -// NewMockCryptoSetup creates a new mock instance. -func NewMockCryptoSetup(ctrl *gomock.Controller) *MockCryptoSetup { - mock := &MockCryptoSetup{ctrl: ctrl} - mock.recorder = &MockCryptoSetupMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder { - return m.recorder -} - -// ChangeConnectionID mocks base method. -func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ChangeConnectionID", arg0) -} - -// ChangeConnectionID indicates an expected call of ChangeConnectionID. -func (mr *MockCryptoSetupMockRecorder) ChangeConnectionID(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeConnectionID", reflect.TypeOf((*MockCryptoSetup)(nil).ChangeConnectionID), arg0) -} - -// Close mocks base method. -func (m *MockCryptoSetup) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close)) -} - -// ConnectionState mocks base method. -func (m *MockCryptoSetup) ConnectionState() qtls.ConnectionState { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(qtls.ConnectionState) - return ret0 -} - -// ConnectionState indicates an expected call of ConnectionState. -func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) -} - -// Get0RTTOpener mocks base method. -func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get0RTTOpener") - ret0, _ := ret[0].(handshake.LongHeaderOpener) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get0RTTOpener indicates an expected call of Get0RTTOpener. -func (mr *MockCryptoSetupMockRecorder) Get0RTTOpener() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTOpener)) -} - -// Get0RTTSealer mocks base method. -func (m *MockCryptoSetup) Get0RTTSealer() (handshake.LongHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get0RTTSealer") - ret0, _ := ret[0].(handshake.LongHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get0RTTSealer indicates an expected call of Get0RTTSealer. -func (mr *MockCryptoSetupMockRecorder) Get0RTTSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTSealer)) -} - -// Get1RTTOpener mocks base method. -func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get1RTTOpener") - ret0, _ := ret[0].(handshake.ShortHeaderOpener) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get1RTTOpener indicates an expected call of Get1RTTOpener. -func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTOpener)) -} - -// Get1RTTSealer mocks base method. -func (m *MockCryptoSetup) Get1RTTSealer() (handshake.ShortHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get1RTTSealer") - ret0, _ := ret[0].(handshake.ShortHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get1RTTSealer indicates an expected call of Get1RTTSealer. -func (mr *MockCryptoSetupMockRecorder) Get1RTTSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTSealer)) -} - -// GetHandshakeOpener mocks base method. -func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.LongHeaderOpener, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHandshakeOpener") - ret0, _ := ret[0].(handshake.LongHeaderOpener) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetHandshakeOpener indicates an expected call of GetHandshakeOpener. -func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeOpener)) -} - -// GetHandshakeSealer mocks base method. -func (m *MockCryptoSetup) GetHandshakeSealer() (handshake.LongHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHandshakeSealer") - ret0, _ := ret[0].(handshake.LongHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetHandshakeSealer indicates an expected call of GetHandshakeSealer. -func (mr *MockCryptoSetupMockRecorder) GetHandshakeSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeSealer)) -} - -// GetInitialOpener mocks base method. -func (m *MockCryptoSetup) GetInitialOpener() (handshake.LongHeaderOpener, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetInitialOpener") - ret0, _ := ret[0].(handshake.LongHeaderOpener) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetInitialOpener indicates an expected call of GetInitialOpener. -func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialOpener)) -} - -// GetInitialSealer mocks base method. -func (m *MockCryptoSetup) GetInitialSealer() (handshake.LongHeaderSealer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetInitialSealer") - ret0, _ := ret[0].(handshake.LongHeaderSealer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetInitialSealer indicates an expected call of GetInitialSealer. -func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer)) -} - -// GetSessionTicket mocks base method. -func (m *MockCryptoSetup) GetSessionTicket() ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSessionTicket") - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSessionTicket indicates an expected call of GetSessionTicket. -func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket)) -} - -// HandleMessage mocks base method. -func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) - ret0, _ := ret[0].(bool) - return ret0 -} - -// HandleMessage indicates an expected call of HandleMessage. -func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) -} - -// RunHandshake mocks base method. -func (m *MockCryptoSetup) RunHandshake() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RunHandshake") -} - -// RunHandshake indicates an expected call of RunHandshake. -func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake)) -} - -// SetHandshakeConfirmed mocks base method. -func (m *MockCryptoSetup) SetHandshakeConfirmed() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetHandshakeConfirmed") -} - -// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. -func (mr *MockCryptoSetupMockRecorder) SetHandshakeConfirmed() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockCryptoSetup)(nil).SetHandshakeConfirmed)) -} - -// SetLargest1RTTAcked mocks base method. -func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetLargest1RTTAcked", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetLargest1RTTAcked indicates an expected call of SetLargest1RTTAcked. -func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0) -} diff --git a/internal/quic-go/mocks/logging/connection_tracer.go b/internal/quic-go/mocks/logging/connection_tracer.go deleted file mode 100644 index 9fd58412..00000000 --- a/internal/quic-go/mocks/logging/connection_tracer.go +++ /dev/null @@ -1,352 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/logging (interfaces: ConnectionTracer) - -// Package mocklogging is a generated GoMock package. -package mocklogging - -import ( - net "net" - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - utils "github.com/imroc/req/v3/internal/quic-go/utils" - wire "github.com/imroc/req/v3/internal/quic-go/wire" - logging "github.com/imroc/req/v3/internal/quic-go/logging" -) - -// MockConnectionTracer is a mock of ConnectionTracer interface. -type MockConnectionTracer struct { - ctrl *gomock.Controller - recorder *MockConnectionTracerMockRecorder -} - -// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. -type MockConnectionTracerMockRecorder struct { - mock *MockConnectionTracer -} - -// NewMockConnectionTracer creates a new mock instance. -func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { - mock := &MockConnectionTracer{ctrl: ctrl} - mock.recorder = &MockConnectionTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { - return m.recorder -} - -// AcknowledgedPacket mocks base method. -func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) -} - -// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. -func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) -} - -// BufferedPacket mocks base method. -func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "BufferedPacket", arg0) -} - -// BufferedPacket indicates an expected call of BufferedPacket. -func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0) -} - -// Close mocks base method. -func (m *MockConnectionTracer) Close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Close") -} - -// Close indicates an expected call of Close. -func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) -} - -// ClosedConnection mocks base method. -func (m *MockConnectionTracer) ClosedConnection(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ClosedConnection", arg0) -} - -// ClosedConnection indicates an expected call of ClosedConnection. -func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) -} - -// Debug mocks base method. -func (m *MockConnectionTracer) Debug(arg0, arg1 string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Debug", arg0, arg1) -} - -// Debug indicates an expected call of Debug. -func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) -} - -// DroppedEncryptionLevel mocks base method. -func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) -} - -// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. -func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) -} - -// DroppedKey mocks base method. -func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedKey", arg0) -} - -// DroppedKey indicates an expected call of DroppedKey. -func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) -} - -// DroppedPacket mocks base method. -func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount, arg2 logging.PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) -} - -// LossTimerCanceled mocks base method. -func (m *MockConnectionTracer) LossTimerCanceled() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerCanceled") -} - -// LossTimerCanceled indicates an expected call of LossTimerCanceled. -func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) -} - -// LossTimerExpired mocks base method. -func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) -} - -// LossTimerExpired indicates an expected call of LossTimerExpired. -func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) -} - -// LostPacket mocks base method. -func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 logging.PacketLossReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) -} - -// LostPacket indicates an expected call of LostPacket. -func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) -} - -// NegotiatedVersion mocks base method. -func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) -} - -// NegotiatedVersion indicates an expected call of NegotiatedVersion. -func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) -} - -// ReceivedPacket mocks base method. -func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2) -} - -// ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedPacket), arg0, arg1, arg2) -} - -// ReceivedRetry mocks base method. -func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedRetry", arg0) -} - -// ReceivedRetry indicates an expected call of ReceivedRetry. -func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) -} - -// ReceivedTransportParameters mocks base method. -func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedTransportParameters", arg0) -} - -// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. -func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) -} - -// ReceivedVersionNegotiationPacket mocks base method. -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header, arg1 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1) -} - -// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1) -} - -// RestoredTransportParameters mocks base method. -func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RestoredTransportParameters", arg0) -} - -// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. -func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) -} - -// SentPacket mocks base method. -func (m *MockConnectionTracer) SentPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockConnectionTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// SentTransportParameters mocks base method. -func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentTransportParameters", arg0) -} - -// SentTransportParameters indicates an expected call of SentTransportParameters. -func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) -} - -// SetLossTimer mocks base method. -func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) -} - -// SetLossTimer indicates an expected call of SetLossTimer. -func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) -} - -// StartedConnection mocks base method. -func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) -} - -// StartedConnection indicates an expected call of StartedConnection. -func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) -} - -// UpdatedCongestionState mocks base method. -func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionState) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedCongestionState", arg0) -} - -// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. -func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) -} - -// UpdatedKey mocks base method. -func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKey", arg0, arg1) -} - -// UpdatedKey indicates an expected call of UpdatedKey. -func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) -} - -// UpdatedKeyFromTLS mocks base method. -func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) -} - -// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. -func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) -} - -// UpdatedMetrics mocks base method. -func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) -} - -// UpdatedMetrics indicates an expected call of UpdatedMetrics. -func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) -} - -// UpdatedPTOCount mocks base method. -func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedPTOCount", arg0) -} - -// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. -func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) -} diff --git a/internal/quic-go/mocks/logging/tracer.go b/internal/quic-go/mocks/logging/tracer.go deleted file mode 100644 index b0b9700f..00000000 --- a/internal/quic-go/mocks/logging/tracer.go +++ /dev/null @@ -1,77 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/logging (interfaces: Tracer) - -// Package mocklogging is a generated GoMock package. -package mocklogging - -import ( - context "context" - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - wire "github.com/imroc/req/v3/internal/quic-go/wire" - logging "github.com/imroc/req/v3/internal/quic-go/logging" -) - -// MockTracer is a mock of Tracer interface. -type MockTracer struct { - ctrl *gomock.Controller - recorder *MockTracerMockRecorder -} - -// MockTracerMockRecorder is the mock recorder for MockTracer. -type MockTracerMockRecorder struct { - mock *MockTracer -} - -// NewMockTracer creates a new mock instance. -func NewMockTracer(ctrl *gomock.Controller) *MockTracer { - mock := &MockTracer{ctrl: ctrl} - mock.recorder = &MockTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTracer) EXPECT() *MockTracerMockRecorder { - return m.recorder -} - -// DroppedPacket mocks base method. -func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) -} - -// SentPacket mocks base method. -func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// TracerForConnection mocks base method. -func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) - ret0, _ := ret[0].(logging.ConnectionTracer) - return ret0 -} - -// TracerForConnection indicates an expected call of TracerForConnection. -func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) -} diff --git a/internal/quic-go/mocks/long_header_opener.go b/internal/quic-go/mocks/long_header_opener.go deleted file mode 100644 index 158941ed..00000000 --- a/internal/quic-go/mocks/long_header_opener.go +++ /dev/null @@ -1,76 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/handshake (interfaces: LongHeaderOpener) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockLongHeaderOpener is a mock of LongHeaderOpener interface. -type MockLongHeaderOpener struct { - ctrl *gomock.Controller - recorder *MockLongHeaderOpenerMockRecorder -} - -// MockLongHeaderOpenerMockRecorder is the mock recorder for MockLongHeaderOpener. -type MockLongHeaderOpenerMockRecorder struct { - mock *MockLongHeaderOpener -} - -// NewMockLongHeaderOpener creates a new mock instance. -func NewMockLongHeaderOpener(ctrl *gomock.Controller) *MockLongHeaderOpener { - mock := &MockLongHeaderOpener{ctrl: ctrl} - mock.recorder = &MockLongHeaderOpenerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockLongHeaderOpener) EXPECT() *MockLongHeaderOpenerMockRecorder { - return m.recorder -} - -// DecodePacketNumber mocks base method. -func (m *MockLongHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1) - ret0, _ := ret[0].(protocol.PacketNumber) - return ret0 -} - -// DecodePacketNumber indicates an expected call of DecodePacketNumber. -func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) -} - -// DecryptHeader mocks base method. -func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) -} - -// DecryptHeader indicates an expected call of DecryptHeader. -func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) -} - -// Open mocks base method. -func (m *MockLongHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Open indicates an expected call of Open. -func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3) -} diff --git a/internal/quic-go/mocks/mockgen.go b/internal/quic-go/mocks/mockgen.go deleted file mode 100644 index 7d470aa2..00000000 --- a/internal/quic-go/mocks/mockgen.go +++ /dev/null @@ -1,20 +0,0 @@ -package mocks - -//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/imroc/req/v3/internal/quic-go Stream" -//go:generate sh -c "mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/imroc/req/v3/internal/quic-go EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && goimports -w quic/early_conn.go" -//go:generate sh -c "mockgen -package mockquic -destination quic/early_listener.go github.com/imroc/req/v3/internal/quic-go EarlyListener" -//go:generate sh -c "mockgen -package mocklogging -destination logging/tracer.go github.com/imroc/req/v3/internal/quic-go/logging Tracer" -//go:generate sh -c "mockgen -package mocklogging -destination logging/connection_tracer.go github.com/imroc/req/v3/internal/quic-go/logging ConnectionTracer" -//go:generate sh -c "mockgen -package mocks -destination short_header_sealer.go github.com/imroc/req/v3/internal/quic-go/handshake ShortHeaderSealer" -//go:generate sh -c "mockgen -package mocks -destination short_header_opener.go github.com/imroc/req/v3/internal/quic-go/handshake ShortHeaderOpener" -//go:generate sh -c "mockgen -package mocks -destination long_header_opener.go github.com/imroc/req/v3/internal/quic-go/handshake LongHeaderOpener" -//go:generate sh -c "mockgen -package mocks -destination crypto_setup_tmp.go github.com/imroc/req/v3/internal/quic-go/handshake CryptoSetup && sed -E 's~github.com/marten-seemann/qtls[[:alnum:]_-]*~github.com/imroc/req/v3/internal/quic-go/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && goimports -w crypto_setup.go" -//go:generate sh -c "mockgen -package mocks -destination stream_flow_controller.go github.com/imroc/req/v3/internal/quic-go/flowcontrol StreamFlowController" -//go:generate sh -c "mockgen -package mocks -destination congestion.go github.com/imroc/req/v3/internal/quic-go/congestion SendAlgorithmWithDebugInfos" -//go:generate sh -c "mockgen -package mocks -destination connection_flow_controller.go github.com/imroc/req/v3/internal/quic-go/flowcontrol ConnectionFlowController" -//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/imroc/req/v3/internal/quic-go/ackhandler SentPacketHandler" -//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/imroc/req/v3/internal/quic-go/ackhandler ReceivedPacketHandler" - -// The following command produces a warning message on OSX, however, it still generates the correct mock file. -// See https://github.com/golang/mock/issues/339 for details. -//go:generate sh -c "mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" diff --git a/internal/quic-go/mocks/quic/early_conn.go b/internal/quic-go/mocks/quic/early_conn.go deleted file mode 100644 index 7bb07774..00000000 --- a/internal/quic-go/mocks/quic/early_conn.go +++ /dev/null @@ -1,255 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go (interfaces: EarlyConnection) - -// Package mockquic is a generated GoMock package. -package mockquic - -import ( - context "context" - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - quic "github.com/imroc/req/v3/internal/quic-go" - qerr "github.com/imroc/req/v3/internal/quic-go/qerr" -) - -// MockEarlyConnection is a mock of EarlyConnection interface. -type MockEarlyConnection struct { - ctrl *gomock.Controller - recorder *MockEarlyConnectionMockRecorder -} - -// MockEarlyConnectionMockRecorder is the mock recorder for MockEarlyConnection. -type MockEarlyConnectionMockRecorder struct { - mock *MockEarlyConnection -} - -// NewMockEarlyConnection creates a new mock instance. -func NewMockEarlyConnection(ctrl *gomock.Controller) *MockEarlyConnection { - mock := &MockEarlyConnection{ctrl: ctrl} - mock.recorder = &MockEarlyConnectionMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockEarlyConnection) EXPECT() *MockEarlyConnectionMockRecorder { - return m.recorder -} - -// AcceptStream mocks base method. -func (m *MockEarlyConnection) AcceptStream(arg0 context.Context) (quic.Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptStream", arg0) - ret0, _ := ret[0].(quic.Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AcceptStream indicates an expected call of AcceptStream. -func (mr *MockEarlyConnectionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptStream), arg0) -} - -// AcceptUniStream mocks base method. -func (m *MockEarlyConnection) AcceptUniStream(arg0 context.Context) (quic.ReceiveStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptUniStream", arg0) - ret0, _ := ret[0].(quic.ReceiveStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockEarlyConnectionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptUniStream), arg0) -} - -// CloseWithError mocks base method. -func (m *MockEarlyConnection) CloseWithError(arg0 qerr.ApplicationErrorCode, arg1 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// CloseWithError indicates an expected call of CloseWithError. -func (mr *MockEarlyConnectionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockEarlyConnection)(nil).CloseWithError), arg0, arg1) -} - -// ConnectionState mocks base method. -func (m *MockEarlyConnection) ConnectionState() quic.ConnectionState { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(quic.ConnectionState) - return ret0 -} - -// ConnectionState indicates an expected call of ConnectionState. -func (mr *MockEarlyConnectionMockRecorder) ConnectionState() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockEarlyConnection)(nil).ConnectionState)) -} - -// Context mocks base method. -func (m *MockEarlyConnection) Context() context.Context { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Context") - ret0, _ := ret[0].(context.Context) - return ret0 -} - -// Context indicates an expected call of Context. -func (mr *MockEarlyConnectionMockRecorder) Context() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockEarlyConnection)(nil).Context)) -} - -// HandshakeComplete mocks base method. -func (m *MockEarlyConnection) HandshakeComplete() context.Context { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HandshakeComplete") - ret0, _ := ret[0].(context.Context) - return ret0 -} - -// HandshakeComplete indicates an expected call of HandshakeComplete. -func (mr *MockEarlyConnectionMockRecorder) HandshakeComplete() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockEarlyConnection)(nil).HandshakeComplete)) -} - -// LocalAddr mocks base method. -func (m *MockEarlyConnection) LocalAddr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LocalAddr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// LocalAddr indicates an expected call of LocalAddr. -func (mr *MockEarlyConnectionMockRecorder) LocalAddr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockEarlyConnection)(nil).LocalAddr)) -} - -// NextConnection mocks base method. -func (m *MockEarlyConnection) NextConnection() quic.Connection { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NextConnection") - ret0, _ := ret[0].(quic.Connection) - return ret0 -} - -// NextConnection indicates an expected call of NextConnection. -func (mr *MockEarlyConnectionMockRecorder) NextConnection() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection)) -} - -// OpenStream mocks base method. -func (m *MockEarlyConnection) OpenStream() (quic.Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStream") - ret0, _ := ret[0].(quic.Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenStream indicates an expected call of OpenStream. -func (mr *MockEarlyConnectionMockRecorder) OpenStream() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStream)) -} - -// OpenStreamSync mocks base method. -func (m *MockEarlyConnection) OpenStreamSync(arg0 context.Context) (quic.Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStreamSync", arg0) - ret0, _ := ret[0].(quic.Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockEarlyConnectionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStreamSync), arg0) -} - -// OpenUniStream mocks base method. -func (m *MockEarlyConnection) OpenUniStream() (quic.SendStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStream") - ret0, _ := ret[0].(quic.SendStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenUniStream indicates an expected call of OpenUniStream. -func (mr *MockEarlyConnectionMockRecorder) OpenUniStream() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStream)) -} - -// OpenUniStreamSync mocks base method. -func (m *MockEarlyConnection) OpenUniStreamSync(arg0 context.Context) (quic.SendStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) - ret0, _ := ret[0].(quic.SendStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStreamSync), arg0) -} - -// ReceiveMessage mocks base method. -func (m *MockEarlyConnection) ReceiveMessage() ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceiveMessage") - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage)) -} - -// RemoteAddr mocks base method. -func (m *MockEarlyConnection) RemoteAddr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoteAddr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// RemoteAddr indicates an expected call of RemoteAddr. -func (mr *MockEarlyConnectionMockRecorder) RemoteAddr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockEarlyConnection)(nil).RemoteAddr)) -} - -// SendMessage mocks base method. -func (m *MockEarlyConnection) SendMessage(arg0 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendMessage", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SendMessage indicates an expected call of SendMessage. -func (mr *MockEarlyConnectionMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockEarlyConnection)(nil).SendMessage), arg0) -} diff --git a/internal/quic-go/mocks/quic/early_listener.go b/internal/quic-go/mocks/quic/early_listener.go deleted file mode 100644 index a57247a3..00000000 --- a/internal/quic-go/mocks/quic/early_listener.go +++ /dev/null @@ -1,80 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go (interfaces: EarlyListener) - -// Package mockquic is a generated GoMock package. -package mockquic - -import ( - context "context" - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - quic "github.com/imroc/req/v3/internal/quic-go" -) - -// MockEarlyListener is a mock of EarlyListener interface. -type MockEarlyListener struct { - ctrl *gomock.Controller - recorder *MockEarlyListenerMockRecorder -} - -// MockEarlyListenerMockRecorder is the mock recorder for MockEarlyListener. -type MockEarlyListenerMockRecorder struct { - mock *MockEarlyListener -} - -// NewMockEarlyListener creates a new mock instance. -func NewMockEarlyListener(ctrl *gomock.Controller) *MockEarlyListener { - mock := &MockEarlyListener{ctrl: ctrl} - mock.recorder = &MockEarlyListenerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockEarlyListener) EXPECT() *MockEarlyListenerMockRecorder { - return m.recorder -} - -// Accept mocks base method. -func (m *MockEarlyListener) Accept(arg0 context.Context) (quic.EarlyConnection, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Accept", arg0) - ret0, _ := ret[0].(quic.EarlyConnection) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Accept indicates an expected call of Accept. -func (mr *MockEarlyListenerMockRecorder) Accept(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockEarlyListener)(nil).Accept), arg0) -} - -// Addr mocks base method. -func (m *MockEarlyListener) Addr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Addr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// Addr indicates an expected call of Addr. -func (mr *MockEarlyListenerMockRecorder) Addr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockEarlyListener)(nil).Addr)) -} - -// Close mocks base method. -func (m *MockEarlyListener) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockEarlyListenerMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEarlyListener)(nil).Close)) -} diff --git a/internal/quic-go/mocks/quic/stream.go b/internal/quic-go/mocks/quic/stream.go deleted file mode 100644 index 97a1f042..00000000 --- a/internal/quic-go/mocks/quic/stream.go +++ /dev/null @@ -1,176 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go (interfaces: Stream) - -// Package mockquic is a generated GoMock package. -package mockquic - -import ( - context "context" - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" - qerr "github.com/imroc/req/v3/internal/quic-go/qerr" -) - -// MockStream is a mock of Stream interface. -type MockStream struct { - ctrl *gomock.Controller - recorder *MockStreamMockRecorder -} - -// MockStreamMockRecorder is the mock recorder for MockStream. -type MockStreamMockRecorder struct { - mock *MockStream -} - -// NewMockStream creates a new mock instance. -func NewMockStream(ctrl *gomock.Controller) *MockStream { - mock := &MockStream{ctrl: ctrl} - mock.recorder = &MockStreamMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStream) EXPECT() *MockStreamMockRecorder { - return m.recorder -} - -// CancelRead mocks base method. -func (m *MockStream) CancelRead(arg0 qerr.StreamErrorCode) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CancelRead", arg0) -} - -// CancelRead indicates an expected call of CancelRead. -func (mr *MockStreamMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStream)(nil).CancelRead), arg0) -} - -// CancelWrite mocks base method. -func (m *MockStream) CancelWrite(arg0 qerr.StreamErrorCode) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CancelWrite", arg0) -} - -// CancelWrite indicates an expected call of CancelWrite. -func (mr *MockStreamMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStream)(nil).CancelWrite), arg0) -} - -// Close mocks base method. -func (m *MockStream) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockStreamMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStream)(nil).Close)) -} - -// Context mocks base method. -func (m *MockStream) Context() context.Context { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Context") - ret0, _ := ret[0].(context.Context) - return ret0 -} - -// Context indicates an expected call of Context. -func (mr *MockStreamMockRecorder) Context() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStream)(nil).Context)) -} - -// Read mocks base method. -func (m *MockStream) Read(arg0 []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Read", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Read indicates an expected call of Read. -func (mr *MockStreamMockRecorder) Read(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStream)(nil).Read), arg0) -} - -// SetDeadline mocks base method. -func (m *MockStream) SetDeadline(arg0 time.Time) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetDeadline", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetDeadline indicates an expected call of SetDeadline. -func (mr *MockStreamMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStream)(nil).SetDeadline), arg0) -} - -// SetReadDeadline mocks base method. -func (m *MockStream) SetReadDeadline(arg0 time.Time) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetReadDeadline", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockStreamMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStream)(nil).SetReadDeadline), arg0) -} - -// SetWriteDeadline mocks base method. -func (m *MockStream) SetWriteDeadline(arg0 time.Time) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStream)(nil).SetWriteDeadline), arg0) -} - -// StreamID mocks base method. -func (m *MockStream) StreamID() protocol.StreamID { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StreamID") - ret0, _ := ret[0].(protocol.StreamID) - return ret0 -} - -// StreamID indicates an expected call of StreamID. -func (mr *MockStreamMockRecorder) StreamID() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStream)(nil).StreamID)) -} - -// Write mocks base method. -func (m *MockStream) Write(arg0 []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. -func (mr *MockStreamMockRecorder) Write(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStream)(nil).Write), arg0) -} diff --git a/internal/quic-go/mocks/short_header_opener.go b/internal/quic-go/mocks/short_header_opener.go deleted file mode 100644 index 1109cf0b..00000000 --- a/internal/quic-go/mocks/short_header_opener.go +++ /dev/null @@ -1,77 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/handshake (interfaces: ShortHeaderOpener) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockShortHeaderOpener is a mock of ShortHeaderOpener interface. -type MockShortHeaderOpener struct { - ctrl *gomock.Controller - recorder *MockShortHeaderOpenerMockRecorder -} - -// MockShortHeaderOpenerMockRecorder is the mock recorder for MockShortHeaderOpener. -type MockShortHeaderOpenerMockRecorder struct { - mock *MockShortHeaderOpener -} - -// NewMockShortHeaderOpener creates a new mock instance. -func NewMockShortHeaderOpener(ctrl *gomock.Controller) *MockShortHeaderOpener { - mock := &MockShortHeaderOpener{ctrl: ctrl} - mock.recorder = &MockShortHeaderOpenerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockShortHeaderOpener) EXPECT() *MockShortHeaderOpenerMockRecorder { - return m.recorder -} - -// DecodePacketNumber mocks base method. -func (m *MockShortHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1) - ret0, _ := ret[0].(protocol.PacketNumber) - return ret0 -} - -// DecodePacketNumber indicates an expected call of DecodePacketNumber. -func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) -} - -// DecryptHeader mocks base method. -func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) -} - -// DecryptHeader indicates an expected call of DecryptHeader. -func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) -} - -// Open mocks base method. -func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 time.Time, arg3 protocol.PacketNumber, arg4 protocol.KeyPhaseBit, arg5 []byte) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4, arg5) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Open indicates an expected call of Open. -func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4, arg5) -} diff --git a/internal/quic-go/mocks/short_header_sealer.go b/internal/quic-go/mocks/short_header_sealer.go deleted file mode 100644 index 72c6cbf1..00000000 --- a/internal/quic-go/mocks/short_header_sealer.go +++ /dev/null @@ -1,89 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/handshake (interfaces: ShortHeaderSealer) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockShortHeaderSealer is a mock of ShortHeaderSealer interface. -type MockShortHeaderSealer struct { - ctrl *gomock.Controller - recorder *MockShortHeaderSealerMockRecorder -} - -// MockShortHeaderSealerMockRecorder is the mock recorder for MockShortHeaderSealer. -type MockShortHeaderSealerMockRecorder struct { - mock *MockShortHeaderSealer -} - -// NewMockShortHeaderSealer creates a new mock instance. -func NewMockShortHeaderSealer(ctrl *gomock.Controller) *MockShortHeaderSealer { - mock := &MockShortHeaderSealer{ctrl: ctrl} - mock.recorder = &MockShortHeaderSealerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockShortHeaderSealer) EXPECT() *MockShortHeaderSealerMockRecorder { - return m.recorder -} - -// EncryptHeader mocks base method. -func (m *MockShortHeaderSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2) -} - -// EncryptHeader indicates an expected call of EncryptHeader. -func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockShortHeaderSealer)(nil).EncryptHeader), arg0, arg1, arg2) -} - -// KeyPhase mocks base method. -func (m *MockShortHeaderSealer) KeyPhase() protocol.KeyPhaseBit { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "KeyPhase") - ret0, _ := ret[0].(protocol.KeyPhaseBit) - return ret0 -} - -// KeyPhase indicates an expected call of KeyPhase. -func (mr *MockShortHeaderSealerMockRecorder) KeyPhase() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyPhase", reflect.TypeOf((*MockShortHeaderSealer)(nil).KeyPhase)) -} - -// Overhead mocks base method. -func (m *MockShortHeaderSealer) Overhead() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Overhead") - ret0, _ := ret[0].(int) - return ret0 -} - -// Overhead indicates an expected call of Overhead. -func (mr *MockShortHeaderSealerMockRecorder) Overhead() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockShortHeaderSealer)(nil).Overhead)) -} - -// Seal mocks base method. -func (m *MockShortHeaderSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - return ret0 -} - -// Seal indicates an expected call of Seal. -func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockShortHeaderSealer)(nil).Seal), arg0, arg1, arg2, arg3) -} diff --git a/internal/quic-go/mocks/stream_flow_controller.go b/internal/quic-go/mocks/stream_flow_controller.go deleted file mode 100644 index 66d8c2ac..00000000 --- a/internal/quic-go/mocks/stream_flow_controller.go +++ /dev/null @@ -1,140 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/imroc/req/v3/internal/quic-go/flowcontrol (interfaces: StreamFlowController) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// MockStreamFlowController is a mock of StreamFlowController interface. -type MockStreamFlowController struct { - ctrl *gomock.Controller - recorder *MockStreamFlowControllerMockRecorder -} - -// MockStreamFlowControllerMockRecorder is the mock recorder for MockStreamFlowController. -type MockStreamFlowControllerMockRecorder struct { - mock *MockStreamFlowController -} - -// NewMockStreamFlowController creates a new mock instance. -func NewMockStreamFlowController(ctrl *gomock.Controller) *MockStreamFlowController { - mock := &MockStreamFlowController{ctrl: ctrl} - mock.recorder = &MockStreamFlowControllerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStreamFlowController) EXPECT() *MockStreamFlowControllerMockRecorder { - return m.recorder -} - -// Abandon mocks base method. -func (m *MockStreamFlowController) Abandon() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Abandon") -} - -// Abandon indicates an expected call of Abandon. -func (mr *MockStreamFlowControllerMockRecorder) Abandon() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abandon", reflect.TypeOf((*MockStreamFlowController)(nil).Abandon)) -} - -// AddBytesRead mocks base method. -func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddBytesRead", arg0) -} - -// AddBytesRead indicates an expected call of AddBytesRead. -func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesRead), arg0) -} - -// AddBytesSent mocks base method. -func (m *MockStreamFlowController) AddBytesSent(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddBytesSent", arg0) -} - -// AddBytesSent indicates an expected call of AddBytesSent. -func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesSent), arg0) -} - -// GetWindowUpdate mocks base method. -func (m *MockStreamFlowController) GetWindowUpdate() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWindowUpdate") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetWindowUpdate indicates an expected call of GetWindowUpdate. -func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate)) -} - -// IsNewlyBlocked mocks base method. -func (m *MockStreamFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsNewlyBlocked") - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(protocol.ByteCount) - return ret0, ret1 -} - -// IsNewlyBlocked indicates an expected call of IsNewlyBlocked. -func (mr *MockStreamFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsNewlyBlocked)) -} - -// SendWindowSize mocks base method. -func (m *MockStreamFlowController) SendWindowSize() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendWindowSize") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// SendWindowSize indicates an expected call of SendWindowSize. -func (mr *MockStreamFlowControllerMockRecorder) SendWindowSize() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockStreamFlowController)(nil).SendWindowSize)) -} - -// UpdateHighestReceived mocks base method. -func (m *MockStreamFlowController) UpdateHighestReceived(arg0 protocol.ByteCount, arg1 bool) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateHighestReceived", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateHighestReceived indicates an expected call of UpdateHighestReceived. -func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateHighestReceived), arg0, arg1) -} - -// UpdateSendWindow mocks base method. -func (m *MockStreamFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateSendWindow", arg0) -} - -// UpdateSendWindow indicates an expected call of UpdateSendWindow. -func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateSendWindow), arg0) -} diff --git a/internal/quic-go/mocks/tls/client_session_cache.go b/internal/quic-go/mocks/tls/client_session_cache.go deleted file mode 100644 index e3ae2c8e..00000000 --- a/internal/quic-go/mocks/tls/client_session_cache.go +++ /dev/null @@ -1,62 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: crypto/tls (interfaces: ClientSessionCache) - -// Package mocktls is a generated GoMock package. -package mocktls - -import ( - tls "crypto/tls" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockClientSessionCache is a mock of ClientSessionCache interface. -type MockClientSessionCache struct { - ctrl *gomock.Controller - recorder *MockClientSessionCacheMockRecorder -} - -// MockClientSessionCacheMockRecorder is the mock recorder for MockClientSessionCache. -type MockClientSessionCacheMockRecorder struct { - mock *MockClientSessionCache -} - -// NewMockClientSessionCache creates a new mock instance. -func NewMockClientSessionCache(ctrl *gomock.Controller) *MockClientSessionCache { - mock := &MockClientSessionCache{ctrl: ctrl} - mock.recorder = &MockClientSessionCacheMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockClientSessionCache) EXPECT() *MockClientSessionCacheMockRecorder { - return m.recorder -} - -// Get mocks base method. -func (m *MockClientSessionCache) Get(arg0 string) (*tls.ClientSessionState, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0) - ret0, _ := ret[0].(*tls.ClientSessionState) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// Get indicates an expected call of Get. -func (mr *MockClientSessionCacheMockRecorder) Get(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockClientSessionCache)(nil).Get), arg0) -} - -// Put mocks base method. -func (m *MockClientSessionCache) Put(arg0 string, arg1 *tls.ClientSessionState) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Put", arg0, arg1) -} - -// Put indicates an expected call of Put. -func (mr *MockClientSessionCacheMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockClientSessionCache)(nil).Put), arg0, arg1) -} diff --git a/internal/quic-go/mtu_discoverer.go b/internal/quic-go/mtu_discoverer.go deleted file mode 100644 index d8259fc7..00000000 --- a/internal/quic-go/mtu_discoverer.go +++ /dev/null @@ -1,74 +0,0 @@ -package quic - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/ackhandler" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type mtuDiscoverer interface { - ShouldSendProbe(now time.Time) bool - GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount) -} - -const ( - // At some point, we have to stop searching for a higher MTU. - // We're happy to send a packet that's 10 bytes smaller than the actual MTU. - maxMTUDiff = 20 - // send a probe packet every mtuProbeDelay RTTs - mtuProbeDelay = 5 -) - -type mtuFinder struct { - lastProbeTime time.Time - probeInFlight bool - mtuIncreased func(protocol.ByteCount) - - rttStats *utils.RTTStats - current protocol.ByteCount - max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer) -} - -var _ mtuDiscoverer = &mtuFinder{} - -func newMTUDiscoverer(rttStats *utils.RTTStats, start, max protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) mtuDiscoverer { - return &mtuFinder{ - current: start, - rttStats: rttStats, - lastProbeTime: time.Now(), // to make sure the first probe packet is not sent immediately - mtuIncreased: mtuIncreased, - max: max, - } -} - -func (f *mtuFinder) done() bool { - return f.max-f.current <= maxMTUDiff+1 -} - -func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { - if f.probeInFlight || f.done() { - return false - } - return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT())) -} - -func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { - size := (f.max + f.current) / 2 - f.lastProbeTime = time.Now() - f.probeInFlight = true - return ackhandler.Frame{ - Frame: &wire.PingFrame{}, - OnLost: func(wire.Frame) { - f.probeInFlight = false - f.max = size - }, - OnAcked: func(wire.Frame) { - f.probeInFlight = false - f.current = size - f.mtuIncreased(size) - }, - }, size -} diff --git a/internal/quic-go/mtu_discoverer_test.go b/internal/quic-go/mtu_discoverer_test.go deleted file mode 100644 index f6701827..00000000 --- a/internal/quic-go/mtu_discoverer_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package quic - -import ( - "math/rand" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - "github.com/imroc/req/v3/internal/quic-go/utils" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("MTU Discoverer", func() { - const ( - rtt = 100 * time.Millisecond - startMTU protocol.ByteCount = 1000 - maxMTU protocol.ByteCount = 2000 - ) - - var ( - d mtuDiscoverer - rttStats *utils.RTTStats - now time.Time - discoveredMTU protocol.ByteCount - ) - - BeforeEach(func() { - rttStats = &utils.RTTStats{} - rttStats.SetInitialRTT(rtt) - Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) - d = newMTUDiscoverer(rttStats, startMTU, maxMTU, func(s protocol.ByteCount) { discoveredMTU = s }) - now = time.Now() - _ = discoveredMTU - }) - - It("only allows a probe 5 RTTs after the handshake completes", func() { - Expect(d.ShouldSendProbe(now)).To(BeFalse()) - Expect(d.ShouldSendProbe(now.Add(rtt * 9 / 2))).To(BeFalse()) - Expect(d.ShouldSendProbe(now.Add(rtt * 5))).To(BeTrue()) - }) - - It("doesn't allow a probe if another probe is still in flight", func() { - ping, _ := d.GetPing() - Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeFalse()) - ping.OnLost(ping.Frame) - Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeTrue()) - }) - - It("tries a lower size when a probe is lost", func() { - ping, size := d.GetPing() - Expect(size).To(Equal(protocol.ByteCount(1500))) - ping.OnLost(ping.Frame) - _, size = d.GetPing() - Expect(size).To(Equal(protocol.ByteCount(1250))) - }) - - It("tries a higher size and calls the callback when a probe is acknowledged", func() { - ping, size := d.GetPing() - Expect(size).To(Equal(protocol.ByteCount(1500))) - ping.OnAcked(ping.Frame) - Expect(discoveredMTU).To(Equal(protocol.ByteCount(1500))) - _, size = d.GetPing() - Expect(size).To(Equal(protocol.ByteCount(1750))) - }) - - It("stops discovery after getting close enough to the MTU", func() { - var sizes []protocol.ByteCount - t := now.Add(5 * rtt) - for d.ShouldSendProbe(t) { - ping, size := d.GetPing() - ping.OnAcked(ping.Frame) - sizes = append(sizes, size) - t = t.Add(5 * rtt) - } - Expect(sizes).To(Equal([]protocol.ByteCount{1500, 1750, 1875, 1937, 1968, 1984})) - Expect(d.ShouldSendProbe(t.Add(10 * rtt))).To(BeFalse()) - }) - - It("finds the MTU", func() { - const rep = 3000 - var maxDiff protocol.ByteCount - for i := 0; i < rep; i++ { - max := protocol.ByteCount(rand.Intn(int(3000-startMTU))) + startMTU + 1 - currentMTU := startMTU - d := newMTUDiscoverer(rttStats, startMTU, max, func(s protocol.ByteCount) { currentMTU = s }) - now := time.Now() - realMTU := protocol.ByteCount(rand.Intn(int(max-startMTU))) + startMTU - t := now.Add(mtuProbeDelay * rtt) - var count int - for d.ShouldSendProbe(t) { - if count > 25 { - Fail("too many iterations") - } - count++ - - ping, size := d.GetPing() - if size <= realMTU { - ping.OnAcked(ping.Frame) - } else { - ping.OnLost(ping.Frame) - } - t = t.Add(mtuProbeDelay * rtt) - } - diff := realMTU - currentMTU - Expect(diff).To(BeNumerically(">=", 0)) - maxDiff = utils.MaxByteCount(maxDiff, diff) - } - Expect(maxDiff).To(BeEquivalentTo(maxMTUDiff)) - }) -}) diff --git a/internal/quic-go/multiplexer.go b/internal/quic-go/multiplexer.go deleted file mode 100644 index af943300..00000000 --- a/internal/quic-go/multiplexer.go +++ /dev/null @@ -1,107 +0,0 @@ -package quic - -import ( - "bytes" - "fmt" - "net" - "sync" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -var ( - connMuxerOnce sync.Once - connMuxer multiplexer -) - -type indexableConn interface { - LocalAddr() net.Addr -} - -type multiplexer interface { - AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error) - RemoveConn(indexableConn) error -} - -type connManager struct { - connIDLen int - statelessResetKey []byte - tracer logging.Tracer - manager packetHandlerManager -} - -// The connMultiplexer listens on multiple net.PacketConns and dispatches -// incoming packets to the connection handler. -type connMultiplexer struct { - mutex sync.Mutex - - conns map[string] /* LocalAddr().String() */ connManager - newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests - - logger utils.Logger -} - -var _ multiplexer = &connMultiplexer{} - -func getMultiplexer() multiplexer { - connMuxerOnce.Do(func() { - connMuxer = &connMultiplexer{ - conns: make(map[string]connManager), - logger: utils.DefaultLogger.WithPrefix("muxer"), - newPacketHandlerManager: newPacketHandlerMap, - } - }) - return connMuxer -} - -func (m *connMultiplexer) AddConn( - c net.PacketConn, - connIDLen int, - statelessResetKey []byte, - tracer logging.Tracer, -) (packetHandlerManager, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - addr := c.LocalAddr() - connIndex := addr.Network() + " " + addr.String() - p, ok := m.conns[connIndex] - if !ok { - manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) - if err != nil { - return nil, err - } - p = connManager{ - connIDLen: connIDLen, - statelessResetKey: statelessResetKey, - manager: manager, - tracer: tracer, - } - m.conns[connIndex] = p - } else { - if p.connIDLen != connIDLen { - return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) - } - if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) { - return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") - } - if tracer != p.tracer { - return nil, fmt.Errorf("cannot use different tracers on the same packet conn") - } - } - return p.manager, nil -} - -func (m *connMultiplexer) RemoveConn(c indexableConn) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() - if _, ok := m.conns[connIndex]; !ok { - return fmt.Errorf("cannote remove connection, connection is unknown") - } - - delete(m.conns, connIndex) - return nil -} diff --git a/internal/quic-go/multiplexer_test.go b/internal/quic-go/multiplexer_test.go deleted file mode 100644 index 58c3fd84..00000000 --- a/internal/quic-go/multiplexer_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package quic - -import ( - "net" - - "github.com/golang/mock/gomock" - mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type testConn struct { - counter int - net.PacketConn -} - -var _ = Describe("Multiplexer", func() { - It("adds a new packet conn ", func() { - conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}) - _, err := getMultiplexer().AddConn(conn, 8, nil, nil) - Expect(err).ToNot(HaveOccurred()) - }) - - It("recognizes when the same connection is added twice", func() { - pconn := NewMockPacketConn(mockCtrl) - pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2) - pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn := testConn{PacketConn: pconn} - tracer := mocklogging.NewMockTracer(mockCtrl) - _, err := getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer) - Expect(err).ToNot(HaveOccurred()) - conn.counter++ - _, err = getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer) - Expect(err).ToNot(HaveOccurred()) - Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1)) - }) - - It("errors when adding an existing conn with a different connection ID length", func() { - conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) - _, err := getMultiplexer().AddConn(conn, 5, nil, nil) - Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 6, nil, nil) - Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs")) - }) - - It("errors when adding an existing conn with a different stateless rest key", func() { - conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) - _, err := getMultiplexer().AddConn(conn, 7, []byte("foobar"), nil) - Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 7, []byte("raboof"), nil) - Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn")) - }) - - It("errors when adding an existing conn with different tracers", func() { - conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) - _, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) - Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) - Expect(err).To(MatchError("cannot use different tracers on the same packet conn")) - }) -}) diff --git a/internal/quic-go/packet_handler_map.go b/internal/quic-go/packet_handler_map.go deleted file mode 100644 index 119b011d..00000000 --- a/internal/quic-go/packet_handler_map.go +++ /dev/null @@ -1,489 +0,0 @@ -package quic - -import ( - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "errors" - "fmt" - "hash" - "io" - "log" - "net" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type zeroRTTQueue struct { - queue []*receivedPacket - retireTimer *time.Timer -} - -var _ packetHandler = &zeroRTTQueue{} - -func (h *zeroRTTQueue) handlePacket(p *receivedPacket) { - if len(h.queue) < protocol.Max0RTTQueueLen { - h.queue = append(h.queue, p) - } -} -func (h *zeroRTTQueue) shutdown() {} -func (h *zeroRTTQueue) destroy(error) {} -func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient } -func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) { - for _, p := range h.queue { - sess.handlePacket(p) - } -} - -func (h *zeroRTTQueue) Clear() { - for _, p := range h.queue { - p.buffer.Release() - } -} - -// rawConn is a connection that allow reading of a receivedPacket. -type rawConn interface { - ReadPacket() (*receivedPacket, error) - WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) - LocalAddr() net.Addr - io.Closer -} - -type packetHandlerMapEntry struct { - packetHandler packetHandler - is0RTTQueue bool -} - -// The packetHandlerMap stores packetHandlers, identified by connection ID. -// It is used: -// * by the server to store connections -// * when multiplexing outgoing connections to store clients -type packetHandlerMap struct { - mutex sync.Mutex - - conn rawConn - connIDLen int - - handlers map[string] /* string(ConnectionID)*/ packetHandlerMapEntry - resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler - server unknownPacketHandler - numZeroRTTEntries int - - listening chan struct{} // is closed when listen returns - closed bool - - deleteRetiredConnsAfter time.Duration - zeroRTTQueueDuration time.Duration - - statelessResetEnabled bool - statelessResetMutex sync.Mutex - statelessResetHasher hash.Hash - - tracer logging.Tracer - logger utils.Logger -} - -var _ packetHandlerManager = &packetHandlerMap{} - -func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { - conn, ok := c.(interface{ SetReadBuffer(int) error }) - if !ok { - return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") - } - size, err := inspectReadBuffer(c) - if err != nil { - return fmt.Errorf("failed to determine receive buffer size: %w", err) - } - if size >= protocol.DesiredReceiveBufferSize { - logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) - return nil - } - if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { - return fmt.Errorf("failed to increase receive buffer size: %w", err) - } - newSize, err := inspectReadBuffer(c) - if err != nil { - return fmt.Errorf("failed to determine receive buffer size: %w", err) - } - if newSize == size { - return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) - } - if newSize < protocol.DesiredReceiveBufferSize { - return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) - } - logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) - return nil -} - -// only print warnings about the UDP receive buffer size once -var receiveBufferWarningOnce sync.Once - -func newPacketHandlerMap( - c net.PacketConn, - connIDLen int, - statelessResetKey []byte, - tracer logging.Tracer, - logger utils.Logger, -) (packetHandlerManager, error) { - if err := setReceiveBuffer(c, logger); err != nil { - if !strings.Contains(err.Error(), "use of closed network connection") { - receiveBufferWarningOnce.Do(func() { - if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { - return - } - log.Printf("%s. See https://github.com/imroc/req/v3/internal/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) - }) - } - } - conn, err := wrapConn(c) - if err != nil { - return nil, err - } - m := &packetHandlerMap{ - conn: conn, - connIDLen: connIDLen, - listening: make(chan struct{}), - handlers: make(map[string]packetHandlerMapEntry), - resetTokens: make(map[protocol.StatelessResetToken]packetHandler), - deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, - zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, - statelessResetEnabled: len(statelessResetKey) > 0, - statelessResetHasher: hmac.New(sha256.New, statelessResetKey), - tracer: tracer, - logger: logger, - } - go m.listen() - - if logger.Debug() { - go m.logUsage() - } - return m, nil -} - -func (h *packetHandlerMap) logUsage() { - ticker := time.NewTicker(2 * time.Second) - var printedZero bool - for { - select { - case <-h.listening: - return - case <-ticker.C: - } - - h.mutex.Lock() - numHandlers := len(h.handlers) - numTokens := len(h.resetTokens) - h.mutex.Unlock() - // If the number tracked handlers and tokens is zero, only print it a single time. - hasZero := numHandlers == 0 && numTokens == 0 - if !hasZero || (hasZero && !printedZero) { - h.logger.Debugf("Tracking %d connection IDs and %d reset tokens.\n", numHandlers, numTokens) - printedZero = false - if hasZero { - printedZero = true - } - } - } -} - -func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { - h.mutex.Lock() - defer h.mutex.Unlock() - - if _, ok := h.handlers[string(id)]; ok { - h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) - return false - } - h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} - h.logger.Debugf("Adding connection ID %s.", id) - return true -} - -func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool { - h.mutex.Lock() - defer h.mutex.Unlock() - - var q *zeroRTTQueue - if entry, ok := h.handlers[string(clientDestConnID)]; ok { - if !entry.is0RTTQueue { - h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) - return false - } - q = entry.packetHandler.(*zeroRTTQueue) - q.retireTimer.Stop() - h.numZeroRTTEntries-- - if h.numZeroRTTEntries < 0 { - panic("number of 0-RTT queues < 0") - } - } - sess := fn() - if q != nil { - q.EnqueueAll(sess) - } - h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess} - h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess} - h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) - return true -} - -func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { - h.mutex.Lock() - delete(h.handlers, string(id)) - h.mutex.Unlock() - h.logger.Debugf("Removing connection ID %s.", id) -} - -func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { - h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter) - time.AfterFunc(h.deleteRetiredConnsAfter, func() { - h.mutex.Lock() - delete(h.handlers, string(id)) - h.mutex.Unlock() - h.logger.Debugf("Removing connection ID %s after it has been retired.", id) - }) -} - -func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) { - h.mutex.Lock() - h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} - h.mutex.Unlock() - h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id) - - time.AfterFunc(h.deleteRetiredConnsAfter, func() { - h.mutex.Lock() - handler.shutdown() - delete(h.handlers, string(id)) - h.mutex.Unlock() - h.logger.Debugf("Removing connection ID %s for a closed connection after it has been retired.", id) - }) -} - -func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { - h.mutex.Lock() - h.resetTokens[token] = handler - h.mutex.Unlock() -} - -func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) { - h.mutex.Lock() - delete(h.resetTokens, token) - h.mutex.Unlock() -} - -func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { - h.mutex.Lock() - h.server = s - h.mutex.Unlock() -} - -func (h *packetHandlerMap) CloseServer() { - h.mutex.Lock() - if h.server == nil { - h.mutex.Unlock() - return - } - h.server = nil - var wg sync.WaitGroup - for _, entry := range h.handlers { - if entry.packetHandler.getPerspective() == protocol.PerspectiveServer { - wg.Add(1) - go func(handler packetHandler) { - // blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped - handler.shutdown() - wg.Done() - }(entry.packetHandler) - } - } - h.mutex.Unlock() - wg.Wait() -} - -// Destroy closes the underlying connection and waits until listen() has returned. -// It does not close active connections. -func (h *packetHandlerMap) Destroy() error { - if err := h.conn.Close(); err != nil { - return err - } - <-h.listening // wait until listening returns - return nil -} - -func (h *packetHandlerMap) close(e error) error { - h.mutex.Lock() - if h.closed { - h.mutex.Unlock() - return nil - } - - var wg sync.WaitGroup - for _, entry := range h.handlers { - wg.Add(1) - go func(handler packetHandler) { - handler.destroy(e) - wg.Done() - }(entry.packetHandler) - } - - if h.server != nil { - h.server.setCloseError(e) - } - h.closed = true - h.mutex.Unlock() - wg.Wait() - return getMultiplexer().RemoveConn(h.conn) -} - -func (h *packetHandlerMap) listen() { - defer close(h.listening) - for { - p, err := h.conn.ReadPacket() - //nolint:staticcheck // SA1019 ignore this! - // TODO: This code is used to ignore wsa errors on Windows. - // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. - // See https://github.com/imroc/req/v3/internal/quic-go/issues/1737 for details. - if nerr, ok := err.(net.Error); ok && nerr.Temporary() { - h.logger.Debugf("Temporary error reading from conn: %w", err) - continue - } - if err != nil { - h.close(err) - return - } - h.handlePacket(p) - } -} - -func (h *packetHandlerMap) handlePacket(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, h.connIDLen) - if err != nil { - h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) - if h.tracer != nil { - h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) - } - p.buffer.MaybeRelease() - return - } - - h.mutex.Lock() - defer h.mutex.Unlock() - - if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset { - return - } - - if entry, ok := h.handlers[string(connID)]; ok { - if entry.is0RTTQueue { // only enqueue 0-RTT packets in the 0-RTT queue - if wire.Is0RTTPacket(p.data) { - entry.packetHandler.handlePacket(p) - return - } - } else { // existing connection - entry.packetHandler.handlePacket(p) - return - } - } - if p.data[0]&0x80 == 0 { - go h.maybeSendStatelessReset(p, connID) - return - } - if h.server == nil { // no server set - h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) - return - } - if wire.Is0RTTPacket(p.data) { - if h.numZeroRTTEntries >= protocol.Max0RTTQueues { - return - } - h.numZeroRTTEntries++ - queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)} - h.handlers[string(connID)] = packetHandlerMapEntry{ - packetHandler: queue, - is0RTTQueue: true, - } - queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { - h.mutex.Lock() - defer h.mutex.Unlock() - // The entry might have been replaced by an actual connection. - // Only delete it if it's still a 0-RTT queue. - if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue { - delete(h.handlers, string(connID)) - h.numZeroRTTEntries-- - if h.numZeroRTTEntries < 0 { - panic("number of 0-RTT queues < 0") - } - entry.packetHandler.(*zeroRTTQueue).Clear() - if h.logger.Debug() { - h.logger.Debugf("Removing 0-RTT queue for %s.", connID) - } - } - }) - queue.handlePacket(p) - return - } - h.server.handlePacket(p) -} - -func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { - // stateless resets are always short header packets - if data[0]&0x80 != 0 { - return false - } - if len(data) < 17 /* type byte + 16 bytes for the reset token */ { - return false - } - - var token protocol.StatelessResetToken - copy(token[:], data[len(data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) - go sess.destroy(&StatelessResetError{Token: token}) - return true - } - return false -} - -func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { - var token protocol.StatelessResetToken - if !h.statelessResetEnabled { - // Return a random stateless reset token. - // This token will be sent in the server's transport parameters. - // By using a random token, an off-path attacker won't be able to disrupt the connection. - rand.Read(token[:]) - return token - } - h.statelessResetMutex.Lock() - h.statelessResetHasher.Write(connID.Bytes()) - copy(token[:], h.statelessResetHasher.Sum(nil)) - h.statelessResetHasher.Reset() - h.statelessResetMutex.Unlock() - return token -} - -func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { - defer p.buffer.Release() - if !h.statelessResetEnabled { - return - } - // Don't send a stateless reset in response to very small packets. - // This includes packets that could be stateless resets. - if len(p.data) <= protocol.MinStatelessResetSize { - return - } - token := h.GetStatelessResetToken(connID) - h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) - data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) - rand.Read(data) - data[0] = (data[0] & 0x7f) | 0x40 - data = append(data, token[:]...) - if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { - h.logger.Debugf("Error sending Stateless Reset: %s", err) - } -} diff --git a/internal/quic-go/packet_handler_map_test.go b/internal/quic-go/packet_handler_map_test.go deleted file mode 100644 index 21c1fcbe..00000000 --- a/internal/quic-go/packet_handler_map_test.go +++ /dev/null @@ -1,495 +0,0 @@ -package quic - -import ( - "bytes" - "crypto/rand" - "errors" - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go/logging" - mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Packet Handler Map", func() { - type packetToRead struct { - addr net.Addr - data []byte - err error - } - - var ( - handler *packetHandlerMap - conn *MockPacketConn - tracer *mocklogging.MockTracer - packetChan chan packetToRead - - connIDLen int - statelessResetKey []byte - ) - - getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { - buf := &bytes.Buffer{} - Expect((&wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: t, - DestConnectionID: connID, - Length: length, - Version: protocol.VersionTLS, - }, - PacketNumberLen: protocol.PacketNumberLen2, - }).Write(buf, protocol.VersionWhatever)).To(Succeed()) - return buf.Bytes() - } - - getPacket := func(connID protocol.ConnectionID) []byte { - return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2) - } - - BeforeEach(func() { - statelessResetKey = nil - connIDLen = 0 - tracer = mocklogging.NewMockTracer(mockCtrl) - packetChan = make(chan packetToRead, 10) - }) - - JustBeforeEach(func() { - conn = NewMockPacketConn(mockCtrl) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() - conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { - p, ok := <-packetChan - if !ok { - return 0, nil, errors.New("closed") - } - return copy(b, p.data), p.addr, p.err - }).AnyTimes() - phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger) - Expect(err).ToNot(HaveOccurred()) - handler = phm.(*packetHandlerMap) - }) - - It("closes", func() { - getMultiplexer() // make the sync.Once execute - // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer - mockMultiplexer := NewMockMultiplexer(mockCtrl) - origMultiplexer := connMuxer - connMuxer = mockMultiplexer - - defer func() { - connMuxer = origMultiplexer - }() - - testErr := errors.New("test error ") - conn1 := NewMockPacketHandler(mockCtrl) - conn1.EXPECT().destroy(testErr) - conn2 := NewMockPacketHandler(mockCtrl) - conn2.EXPECT().destroy(testErr) - handler.Add(protocol.ConnectionID{1, 1, 1, 1}, conn1) - handler.Add(protocol.ConnectionID{2, 2, 2, 2}, conn2) - mockMultiplexer.EXPECT().RemoveConn(gomock.Any()) - handler.close(testErr) - close(packetChan) - Eventually(handler.listening).Should(BeClosed()) - }) - - Context("other operations", func() { - AfterEach(func() { - // delete connections and the server before closing - // They might be mock implementations, and we'd have to register the expected calls before otherwise. - handler.mutex.Lock() - for connID := range handler.handlers { - delete(handler.handlers, connID) - } - handler.server = nil - handler.mutex.Unlock() - conn.EXPECT().Close().MaxTimes(1) - close(packetChan) - handler.Destroy() - Eventually(handler.listening).Should(BeClosed()) - }) - - Context("handling packets", func() { - BeforeEach(func() { - connIDLen = 5 - }) - - It("handles packets for different packet handlers on the same packet conn", func() { - connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - packetHandler1 := NewMockPacketHandler(mockCtrl) - packetHandler2 := NewMockPacketHandler(mockCtrl) - handledPacket1 := make(chan struct{}) - handledPacket2 := make(chan struct{}) - packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(connID1)) - close(handledPacket1) - }) - packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(connID2)) - close(handledPacket2) - }) - handler.Add(connID1, packetHandler1) - handler.Add(connID2, packetHandler2) - packetChan <- packetToRead{data: getPacket(connID1)} - packetChan <- packetToRead{data: getPacket(connID2)} - - Eventually(handledPacket1).Should(BeClosed()) - Eventually(handledPacket2).Should(BeClosed()) - }) - - It("drops unparseable packets", func() { - addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} - tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: []byte{0, 1, 2, 3}, - }) - }) - - It("deletes removed connections immediately", func() { - handler.deleteRetiredConnsAfter = time.Hour - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - handler.Add(connID, NewMockPacketHandler(mockCtrl)) - handler.Remove(connID) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - }) - - It("deletes retired connection entries after a wait time", func() { - handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond) - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - conn := NewMockPacketHandler(mockCtrl) - handler.Add(connID, conn) - handler.Retire(connID) - time.Sleep(scaleDuration(30 * time.Millisecond)) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - }) - - It("passes packets arriving late for closed connections to that connection", func() { - handler.deleteRetiredConnsAfter = time.Hour - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - packetHandler := NewMockPacketHandler(mockCtrl) - handled := make(chan struct{}) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - close(handled) - }) - handler.Add(connID, packetHandler) - handler.Retire(connID) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - Eventually(handled).Should(BeClosed()) - }) - - It("drops packets for unknown receivers", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - }) - - It("closes the packet handlers when reading from the conn fails", func() { - done := make(chan struct{}) - packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) { - Expect(e).To(HaveOccurred()) - close(done) - }) - handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) - packetChan <- packetToRead{err: errors.New("read failed")} - Eventually(done).Should(BeClosed()) - }) - - It("continues listening for temporary errors", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) - err := deadlineError{} - Expect(err.Temporary()).To(BeTrue()) - packetChan <- packetToRead{err: err} - // don't EXPECT any calls to packetHandler.destroy - time.Sleep(50 * time.Millisecond) - }) - - It("says if a connection ID is already taken", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) - Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) - }) - - It("says if a connection ID is already taken, for AddWithConnID", func() { - clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - newConnID1 := protocol.ConnectionID{1, 2, 3, 4} - newConnID2 := protocol.ConnectionID{4, 3, 2, 1} - Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue()) - Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse()) - }) - }) - - Context("running a server", func() { - It("adds a server", func() { - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} - p := getPacket(connID) - server := NewMockUnknownPacketHandler(mockCtrl) - server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - cid, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(cid).To(Equal(connID)) - }) - handler.SetServer(server) - handler.handlePacket(&receivedPacket{data: p}) - }) - - It("closes all server connections", func() { - handler.SetServer(NewMockUnknownPacketHandler(mockCtrl)) - clientConn := NewMockPacketHandler(mockCtrl) - clientConn.EXPECT().getPerspective().Return(protocol.PerspectiveClient) - serverConn := NewMockPacketHandler(mockCtrl) - serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer) - serverConn.EXPECT().shutdown() - - handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientConn) - handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverConn) - handler.CloseServer() - }) - - It("stops handling packets with unknown connection IDs after the server is closed", func() { - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} - p := getPacket(connID) - server := NewMockUnknownPacketHandler(mockCtrl) - // don't EXPECT any calls to server.handlePacket - handler.SetServer(server) - handler.CloseServer() - handler.handlePacket(&receivedPacket{data: p}) - }) - }) - - Context("0-RTT", func() { - JustBeforeEach(func() { - handler.zeroRTTQueueDuration = time.Hour - server := NewMockUnknownPacketHandler(mockCtrl) - // we don't expect any calls to server.handlePacket - handler.SetServer(server) - }) - - It("queues 0-RTT packets", func() { - server := NewMockUnknownPacketHandler(mockCtrl) - // don't EXPECT any calls to server.handlePacket - handler.SetServer(server) - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} - p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} - p2 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2)} - p3 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 3)} - handler.handlePacket(p1) - handler.handlePacket(p2) - handler.handlePacket(p3) - conn := NewMockPacketHandler(mockCtrl) - done := make(chan struct{}) - gomock.InOrder( - conn.EXPECT().handlePacket(p1), - conn.EXPECT().handlePacket(p2), - conn.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }), - ) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) - Eventually(done).Should(BeClosed()) - }) - - It("directs 0-RTT packets to existing connections", func() { - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} - conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) - p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} - conn.EXPECT().handlePacket(p1) - handler.handlePacket(p1) - }) - - It("limits the number of 0-RTT queues", func() { - for i := 0; i < protocol.Max0RTTQueues; i++ { - connID := make(protocol.ConnectionID, 8) - rand.Read(connID) - p := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} - handler.handlePacket(p) - } - // We're already storing the maximum number of queues. This packet will be dropped. - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9} - handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}) - // Don't EXPECT any handlePacket() calls. - conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) - time.Sleep(20 * time.Millisecond) - }) - - It("deletes queues if no connection is created for this connection ID", func() { - queueDuration := scaleDuration(10 * time.Millisecond) - handler.zeroRTTQueueDuration = queueDuration - - server := NewMockUnknownPacketHandler(mockCtrl) - // don't EXPECT any calls to server.handlePacket - handler.SetServer(server) - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} - p1 := &receivedPacket{ - data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1), - buffer: getPacketBuffer(), - } - p2 := &receivedPacket{ - data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2), - buffer: getPacketBuffer(), - } - handler.handlePacket(p1) - handler.handlePacket(p2) - // wait a bit. The queue should now already be deleted. - time.Sleep(queueDuration * 3) - // Don't EXPECT any handlePacket() calls. - conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) - time.Sleep(20 * time.Millisecond) - }) - }) - - Context("stateless resets", func() { - BeforeEach(func() { - connIDLen = 5 - }) - - Context("handling", func() { - It("handles stateless resets", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - destroyed := make(chan struct{}) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { - defer GinkgoRecover() - defer close(destroyed) - Expect(err).To(HaveOccurred()) - var resetErr *StatelessResetError - Expect(errors.As(err, &resetErr)).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("received a stateless reset")) - Expect(resetErr.Token).To(Equal(token)) - }) - packetChan <- packetToRead{data: packet} - Eventually(destroyed).Should(BeClosed()) - }) - - It("handles stateless resets for 0-length connection IDs", func() { - handler.connIDLen = 0 - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - destroyed := make(chan struct{}) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { - defer GinkgoRecover() - Expect(err).To(HaveOccurred()) - var resetErr *StatelessResetError - Expect(errors.As(err, &resetErr)).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("received a stateless reset")) - Expect(resetErr.Token).To(Equal(token)) - close(destroyed) - }) - packetChan <- packetToRead{data: packet} - Eventually(destroyed).Should(BeClosed()) - }) - - It("removes reset tokens", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} - packetHandler := NewMockPacketHandler(mockCtrl) - handler.Add(connID, packetHandler) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) - handler.RemoveResetToken(token) - // don't EXPECT any call to packetHandler.destroy() - packetHandler.EXPECT().handlePacket(gomock.Any()) - p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) - p = append(p, make([]byte, 50)...) - p = append(p, token[:]...) - - handler.handlePacket(&receivedPacket{data: p}) - }) - - It("ignores packets too small to contain a stateless reset", func() { - handler.connIDLen = 0 - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - done := make(chan struct{}) - // don't EXPECT any calls here, but register the closing of the done channel - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) { - close(done) - }).AnyTimes() - packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)} - Consistently(done).ShouldNot(BeClosed()) - }) - }) - - Context("generating", func() { - BeforeEach(func() { - key := make([]byte, 32) - rand.Read(key) - statelessResetKey = key - }) - - It("generates stateless reset tokens", func() { - connID1 := []byte{0xde, 0xad, 0xbe, 0xef} - connID2 := []byte{0xde, 0xca, 0xfb, 0xad} - Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2))) - }) - - It("sends stateless resets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, 100)...) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) { - defer close(done) - Expect(b[0] & 0x80).To(BeZero()) // short header packet - Expect(b).To(HaveLen(protocol.MinStatelessResetSize)) - }) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, - }) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't send stateless resets for small packets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, - }) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - }) - - Context("if no key is configured", func() { - It("doesn't send stateless resets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, 100)...) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, - }) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - }) - }) - }) -}) diff --git a/internal/quic-go/packet_packer.go b/internal/quic-go/packet_packer.go deleted file mode 100644 index de9ce11c..00000000 --- a/internal/quic-go/packet_packer.go +++ /dev/null @@ -1,894 +0,0 @@ -package quic - -import ( - "bytes" - "errors" - "fmt" - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go/ackhandler" - "github.com/imroc/req/v3/internal/quic-go/handshake" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type packer interface { - PackCoalescedPacket() (*coalescedPacket, error) - PackPacket() (*packedPacket, error) - MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) - MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) - PackConnectionClose(*qerr.TransportError) (*coalescedPacket, error) - PackApplicationClose(*qerr.ApplicationError) (*coalescedPacket, error) - - SetMaxPacketSize(protocol.ByteCount) - PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) - - HandleTransportParameters(*wire.TransportParameters) - SetToken([]byte) -} - -type sealer interface { - handshake.LongHeaderSealer -} - -type payload struct { - frames []ackhandler.Frame - ack *wire.AckFrame - length protocol.ByteCount -} - -type packedPacket struct { - buffer *packetBuffer - *packetContents -} - -type packetContents struct { - header *wire.ExtendedHeader - ack *wire.AckFrame - frames []ackhandler.Frame - - length protocol.ByteCount - - isMTUProbePacket bool -} - -type coalescedPacket struct { - buffer *packetBuffer - packets []*packetContents -} - -func (p *packetContents) EncryptionLevel() protocol.EncryptionLevel { - if !p.header.IsLongHeader { - return protocol.Encryption1RTT - } - //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data). - switch p.header.Type { - case protocol.PacketTypeInitial: - return protocol.EncryptionInitial - case protocol.PacketTypeHandshake: - return protocol.EncryptionHandshake - case protocol.PacketType0RTT: - return protocol.Encryption0RTT - default: - panic("can't determine encryption level") - } -} - -func (p *packetContents) IsAckEliciting() bool { - return ackhandler.HasAckElicitingFrames(p.frames) -} - -func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet { - largestAcked := protocol.InvalidPacketNumber - if p.ack != nil { - largestAcked = p.ack.LargestAcked() - } - encLevel := p.EncryptionLevel() - for i := range p.frames { - if p.frames[i].OnLost != nil { - continue - } - switch encLevel { - case protocol.EncryptionInitial: - p.frames[i].OnLost = q.AddInitial - case protocol.EncryptionHandshake: - p.frames[i].OnLost = q.AddHandshake - case protocol.Encryption0RTT, protocol.Encryption1RTT: - p.frames[i].OnLost = q.AddAppData - } - } - return &ackhandler.Packet{ - PacketNumber: p.header.PacketNumber, - LargestAcked: largestAcked, - Frames: p.frames, - Length: p.length, - EncryptionLevel: encLevel, - SendTime: now, - IsPathMTUProbePacket: p.isMTUProbePacket, - } -} - -func getMaxPacketSize(addr net.Addr) protocol.ByteCount { - maxSize := protocol.ByteCount(protocol.MinInitialPacketSize) - // If this is not a UDP address, we don't know anything about the MTU. - // Use the minimum size of an Initial packet as the max packet size. - if udpAddr, ok := addr.(*net.UDPAddr); ok { - if utils.IsIPv4(udpAddr.IP) { - maxSize = protocol.InitialPacketSizeIPv4 - } else { - maxSize = protocol.InitialPacketSizeIPv6 - } - } - return maxSize -} - -type packetNumberManager interface { - PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) - PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber -} - -type sealingManager interface { - GetInitialSealer() (handshake.LongHeaderSealer, error) - GetHandshakeSealer() (handshake.LongHeaderSealer, error) - Get0RTTSealer() (handshake.LongHeaderSealer, error) - Get1RTTSealer() (handshake.ShortHeaderSealer, error) -} - -type frameSource interface { - HasData() bool - AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) - AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) -} - -type ackFrameSource interface { - GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame -} - -type packetPacker struct { - srcConnID protocol.ConnectionID - getDestConnID func() protocol.ConnectionID - - perspective protocol.Perspective - version protocol.VersionNumber - cryptoSetup sealingManager - - initialStream cryptoStream - handshakeStream cryptoStream - - token []byte - - pnManager packetNumberManager - framer frameSource - acks ackFrameSource - datagramQueue *datagramQueue - retransmissionQueue *retransmissionQueue - - maxPacketSize protocol.ByteCount - numNonAckElicitingAcks int -} - -var _ packer = &packetPacker{} - -func newPacketPacker( - srcConnID protocol.ConnectionID, - getDestConnID func() protocol.ConnectionID, - initialStream cryptoStream, - handshakeStream cryptoStream, - packetNumberManager packetNumberManager, - retransmissionQueue *retransmissionQueue, - remoteAddr net.Addr, // only used for determining the max packet size - cryptoSetup sealingManager, - framer frameSource, - acks ackFrameSource, - datagramQueue *datagramQueue, - perspective protocol.Perspective, - version protocol.VersionNumber, -) *packetPacker { - return &packetPacker{ - cryptoSetup: cryptoSetup, - getDestConnID: getDestConnID, - srcConnID: srcConnID, - initialStream: initialStream, - handshakeStream: handshakeStream, - retransmissionQueue: retransmissionQueue, - datagramQueue: datagramQueue, - perspective: perspective, - version: version, - framer: framer, - acks: acks, - pnManager: packetNumberManager, - maxPacketSize: getMaxPacketSize(remoteAddr), - } -} - -// PackConnectionClose packs a packet that closes the connection with a transport error. -func (p *packetPacker) PackConnectionClose(e *qerr.TransportError) (*coalescedPacket, error) { - var reason string - // don't send details of crypto errors - if !e.ErrorCode.IsCryptoError() { - reason = e.ErrorMessage - } - return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason) -} - -// PackApplicationClose packs a packet that closes the connection with an application error. -func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError) (*coalescedPacket, error) { - return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage) -} - -func (p *packetPacker) packConnectionClose( - isApplicationError bool, - errorCode uint64, - frameType uint64, - reason string, -) (*coalescedPacket, error) { - var sealers [4]sealer - var hdrs [4]*wire.ExtendedHeader - var payloads [4]*payload - var size protocol.ByteCount - var numPackets uint8 - encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT} - for i, encLevel := range encLevels { - if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT { - continue - } - ccf := &wire.ConnectionCloseFrame{ - IsApplicationError: isApplicationError, - ErrorCode: errorCode, - FrameType: frameType, - ReasonPhrase: reason, - } - // don't send application errors in Initial or Handshake packets - if isApplicationError && (encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake) { - ccf.IsApplicationError = false - ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode) - ccf.ReasonPhrase = "" - } - payload := &payload{ - frames: []ackhandler.Frame{{Frame: ccf}}, - length: ccf.Length(p.version), - } - - var sealer sealer - var err error - var keyPhase protocol.KeyPhaseBit // only set for 1-RTT - switch encLevel { - case protocol.EncryptionInitial: - sealer, err = p.cryptoSetup.GetInitialSealer() - case protocol.EncryptionHandshake: - sealer, err = p.cryptoSetup.GetHandshakeSealer() - case protocol.Encryption0RTT: - sealer, err = p.cryptoSetup.Get0RTTSealer() - case protocol.Encryption1RTT: - var s handshake.ShortHeaderSealer - s, err = p.cryptoSetup.Get1RTTSealer() - if err == nil { - keyPhase = s.KeyPhase() - } - sealer = s - } - if err == handshake.ErrKeysNotYetAvailable || err == handshake.ErrKeysDropped { - continue - } - if err != nil { - return nil, err - } - sealers[i] = sealer - var hdr *wire.ExtendedHeader - if encLevel == protocol.Encryption1RTT { - hdr = p.getShortHeader(keyPhase) - } else { - hdr = p.getLongHeader(encLevel) - } - hdrs[i] = hdr - payloads[i] = payload - size += p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) - numPackets++ - } - contents := make([]*packetContents, 0, numPackets) - buffer := getPacketBuffer() - for i, encLevel := range encLevels { - if sealers[i] == nil { - continue - } - var paddingLen protocol.ByteCount - if encLevel == protocol.EncryptionInitial { - paddingLen = p.initialPaddingLen(payloads[i].frames, size) - } - c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], false) - if err != nil { - return nil, err - } - contents = append(contents, c) - } - return &coalescedPacket{buffer: buffer, packets: contents}, nil -} - -// packetLength calculates the length of the serialized packet. -// It takes into account that packets that have a tiny payload need to be padded, -// such that len(payload) + packet number len >= 4 + AEAD overhead -func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) protocol.ByteCount { - var paddingLen protocol.ByteCount - pnLen := protocol.ByteCount(hdr.PacketNumberLen) - if payload.length < 4-pnLen { - paddingLen = 4 - pnLen - payload.length - } - return hdr.GetLength(p.version) + payload.length + paddingLen -} - -func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { - var encLevel protocol.EncryptionLevel - var ack *wire.AckFrame - if !handshakeConfirmed { - ack = p.acks.GetAckFrame(protocol.EncryptionInitial, true) - if ack != nil { - encLevel = protocol.EncryptionInitial - } else { - ack = p.acks.GetAckFrame(protocol.EncryptionHandshake, true) - if ack != nil { - encLevel = protocol.EncryptionHandshake - } - } - } - if ack == nil { - ack = p.acks.GetAckFrame(protocol.Encryption1RTT, true) - if ack == nil { - return nil, nil - } - encLevel = protocol.Encryption1RTT - } - payload := &payload{ - ack: ack, - length: ack.Length(p.version), - } - - sealer, hdr, err := p.getSealerAndHeader(encLevel) - if err != nil { - return nil, err - } - return p.writeSinglePacket(hdr, payload, encLevel, sealer) -} - -// size is the expected size of the packet, if no padding was applied. -func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { - // For the server, only ack-eliciting Initial packets need to be padded. - if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) { - return 0 - } - if size >= p.maxPacketSize { - return 0 - } - return p.maxPacketSize - size -} - -// PackCoalescedPacket packs a new packet. -// It packs an Initial / Handshake if there is data to send in these packet number spaces. -// It should only be called before the handshake is confirmed. -func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { - maxPacketSize := p.maxPacketSize - if p.perspective == protocol.PerspectiveClient { - maxPacketSize = protocol.MinInitialPacketSize - } - var initialHdr, handshakeHdr, appDataHdr *wire.ExtendedHeader - var initialPayload, handshakePayload, appDataPayload *payload - var numPackets int - // Try packing an Initial packet. - initialSealer, err := p.cryptoSetup.GetInitialSealer() - if err != nil && err != handshake.ErrKeysDropped { - return nil, err - } - var size protocol.ByteCount - if initialSealer != nil { - initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), size, protocol.EncryptionInitial) - if initialPayload != nil { - size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead()) - numPackets++ - } - } - - // Add a Handshake packet. - var handshakeSealer sealer - if size < maxPacketSize-protocol.MinCoalescedPacketSize { - var err error - handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer() - if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { - return nil, err - } - if handshakeSealer != nil { - handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), size, protocol.EncryptionHandshake) - if handshakePayload != nil { - s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead()) - size += s - numPackets++ - } - } - } - - // Add a 0-RTT / 1-RTT packet. - var appDataSealer sealer - appDataEncLevel := protocol.Encryption1RTT - if size < maxPacketSize-protocol.MinCoalescedPacketSize { - var err error - appDataSealer, appDataHdr, appDataPayload = p.maybeGetAppDataPacket(maxPacketSize-size, size) - if err != nil { - return nil, err - } - if appDataHdr != nil { - if appDataHdr.IsLongHeader { - appDataEncLevel = protocol.Encryption0RTT - } - if appDataPayload != nil { - size += p.packetLength(appDataHdr, appDataPayload) + protocol.ByteCount(appDataSealer.Overhead()) - numPackets++ - } - } - } - - if numPackets == 0 { - return nil, nil - } - - buffer := getPacketBuffer() - packet := &coalescedPacket{ - buffer: buffer, - packets: make([]*packetContents, 0, numPackets), - } - if initialPayload != nil { - padding := p.initialPaddingLen(initialPayload.frames, size) - cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, false) - if err != nil { - return nil, err - } - packet.packets = append(packet.packets, cont) - } - if handshakePayload != nil { - cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, false) - if err != nil { - return nil, err - } - packet.packets = append(packet.packets, cont) - } - if appDataPayload != nil { - cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer, false) - if err != nil { - return nil, err - } - packet.packets = append(packet.packets, cont) - } - return packet, nil -} - -// PackPacket packs a packet in the application data packet number space. -// It should be called after the handshake is confirmed. -func (p *packetPacker) PackPacket() (*packedPacket, error) { - sealer, hdr, payload := p.maybeGetAppDataPacket(p.maxPacketSize, 0) - if payload == nil { - return nil, nil - } - buffer := getPacketBuffer() - encLevel := protocol.Encryption1RTT - if hdr.IsLongHeader { - encLevel = protocol.Encryption0RTT - } - cont, err := p.appendPacket(buffer, hdr, payload, 0, encLevel, sealer, false) - if err != nil { - return nil, err - } - return &packedPacket{ - buffer: buffer, - packetContents: cont, - }, nil -} - -func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) { - var s cryptoStream - var hasRetransmission bool - //nolint:exhaustive // Initial and Handshake are the only two encryption levels here. - switch encLevel { - case protocol.EncryptionInitial: - s = p.initialStream - hasRetransmission = p.retransmissionQueue.HasInitialData() - case protocol.EncryptionHandshake: - s = p.handshakeStream - hasRetransmission = p.retransmissionQueue.HasHandshakeData() - } - - hasData := s.HasData() - var ack *wire.AckFrame - if encLevel == protocol.EncryptionInitial || currentSize == 0 { - ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData) - } - if !hasData && !hasRetransmission && ack == nil { - // nothing to send - return nil, nil - } - - var payload payload - if ack != nil { - payload.ack = ack - payload.length = ack.Length(p.version) - maxPacketSize -= payload.length - } - hdr := p.getLongHeader(encLevel) - maxPacketSize -= hdr.GetLength(p.version) - if hasRetransmission { - for { - var f wire.Frame - //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s - switch encLevel { - case protocol.EncryptionInitial: - f = p.retransmissionQueue.GetInitialFrame(maxPacketSize) - case protocol.EncryptionHandshake: - f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize) - } - if f == nil { - break - } - payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) - frameLen := f.Length(p.version) - payload.length += frameLen - maxPacketSize -= frameLen - } - } else if s.HasData() { - cf := s.PopCryptoFrame(maxPacketSize) - payload.frames = []ackhandler.Frame{{Frame: cf}} - payload.length += cf.Length(p.version) - } - return hdr, &payload -} - -func (p *packetPacker) maybeGetAppDataPacket(maxPacketSize, currentSize protocol.ByteCount) (sealer, *wire.ExtendedHeader, *payload) { - var sealer sealer - var encLevel protocol.EncryptionLevel - var hdr *wire.ExtendedHeader - oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() - if err == nil { - encLevel = protocol.Encryption1RTT - sealer = oneRTTSealer - hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) - } else { - // 1-RTT sealer not yet available - if p.perspective != protocol.PerspectiveClient { - return nil, nil, nil - } - sealer, err = p.cryptoSetup.Get0RTTSealer() - if sealer == nil || err != nil { - return nil, nil, nil - } - encLevel = protocol.Encryption0RTT - hdr = p.getLongHeader(protocol.Encryption0RTT) - } - - maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) - payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, encLevel == protocol.Encryption1RTT && currentSize == 0) - return sealer, hdr, payload -} - -func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol.ByteCount, ackAllowed bool) *payload { - payload := p.composeNextPacket(maxPayloadSize, ackAllowed) - - // check if we have anything to send - if len(payload.frames) == 0 { - if payload.ack == nil { - return nil - } - // the packet only contains an ACK - if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { - ping := &wire.PingFrame{} - // don't retransmit the PING frame when it is lost - payload.frames = append(payload.frames, ackhandler.Frame{Frame: ping, OnLost: func(wire.Frame) {}}) - payload.length += ping.Length(p.version) - p.numNonAckElicitingAcks = 0 - } else { - p.numNonAckElicitingAcks++ - } - } else { - p.numNonAckElicitingAcks = 0 - } - return payload -} - -func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload { - payload := &payload{frames: make([]ackhandler.Frame, 0, 1)} - - var hasDatagram bool - if p.datagramQueue != nil { - if datagram := p.datagramQueue.Get(); datagram != nil { - payload.frames = append(payload.frames, ackhandler.Frame{ - Frame: datagram, - // set it to a no-op. Then we won't set the default callback, which would retransmit the frame. - OnLost: func(wire.Frame) {}, - }) - payload.length += datagram.Length(p.version) - hasDatagram = true - } - } - - var ack *wire.AckFrame - hasData := p.framer.HasData() - hasRetransmission := p.retransmissionQueue.HasAppData() - // TODO: make sure ACKs are sent when a lot of DATAGRAMs are queued - if !hasDatagram && ackAllowed { - ack = p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData) - if ack != nil { - payload.ack = ack - payload.length += ack.Length(p.version) - } - } - - if ack == nil && !hasData && !hasRetransmission { - return payload - } - - if hasRetransmission { - for { - remainingLen := maxFrameSize - payload.length - if remainingLen < protocol.MinStreamFrameSize { - break - } - f := p.retransmissionQueue.GetAppDataFrame(remainingLen) - if f == nil { - break - } - payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) - payload.length += f.Length(p.version) - } - } - - if hasData { - var lengthAdded protocol.ByteCount - payload.frames, lengthAdded = p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length) - payload.length += lengthAdded - - payload.frames, lengthAdded = p.framer.AppendStreamFrames(payload.frames, maxFrameSize-payload.length) - payload.length += lengthAdded - } - return payload -} - -func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*packedPacket, error) { - var hdr *wire.ExtendedHeader - var payload *payload - var sealer sealer - //nolint:exhaustive // Probe packets are never sent for 0-RTT. - switch encLevel { - case protocol.EncryptionInitial: - var err error - sealer, err = p.cryptoSetup.GetInitialSealer() - if err != nil { - return nil, err - } - hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionInitial) - case protocol.EncryptionHandshake: - var err error - sealer, err = p.cryptoSetup.GetHandshakeSealer() - if err != nil { - return nil, err - } - hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionHandshake) - case protocol.Encryption1RTT: - oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() - if err != nil { - return nil, err - } - sealer = oneRTTSealer - hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) - payload = p.maybeGetAppDataPacketWithEncLevel(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), true) - default: - panic("unknown encryption level") - } - if payload == nil { - return nil, nil - } - size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) - var padding protocol.ByteCount - if encLevel == protocol.EncryptionInitial { - padding = p.initialPaddingLen(payload.frames, size) - } - buffer := getPacketBuffer() - cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer, false) - if err != nil { - return nil, err - } - return &packedPacket{ - buffer: buffer, - packetContents: cont, - }, nil -} - -func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) { - payload := &payload{ - frames: []ackhandler.Frame{ping}, - length: ping.Length(p.version), - } - buffer := getPacketBuffer() - sealer, err := p.cryptoSetup.Get1RTTSealer() - if err != nil { - return nil, err - } - hdr := p.getShortHeader(sealer.KeyPhase()) - padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead()) - contents, err := p.appendPacket(buffer, hdr, payload, padding, protocol.Encryption1RTT, sealer, true) - if err != nil { - return nil, err - } - contents.isMTUProbePacket = true - return &packedPacket{ - buffer: buffer, - packetContents: contents, - }, nil -} - -func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) { - switch encLevel { - case protocol.EncryptionInitial: - sealer, err := p.cryptoSetup.GetInitialSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.EncryptionInitial) - return sealer, hdr, nil - case protocol.Encryption0RTT: - sealer, err := p.cryptoSetup.Get0RTTSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.Encryption0RTT) - return sealer, hdr, nil - case protocol.EncryptionHandshake: - sealer, err := p.cryptoSetup.GetHandshakeSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.EncryptionHandshake) - return sealer, hdr, nil - case protocol.Encryption1RTT: - sealer, err := p.cryptoSetup.Get1RTTSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getShortHeader(sealer.KeyPhase()) - return sealer, hdr, nil - default: - return nil, nil, fmt.Errorf("unexpected encryption level: %s", encLevel) - } -} - -func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { - pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) - hdr := &wire.ExtendedHeader{} - hdr.PacketNumber = pn - hdr.PacketNumberLen = pnLen - hdr.DestConnectionID = p.getDestConnID() - hdr.KeyPhase = kp - return hdr -} - -func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { - pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) - hdr := &wire.ExtendedHeader{ - PacketNumber: pn, - PacketNumberLen: pnLen, - } - hdr.IsLongHeader = true - hdr.Version = p.version - hdr.SrcConnectionID = p.srcConnID - hdr.DestConnectionID = p.getDestConnID() - - //nolint:exhaustive // 1-RTT packets are not long header packets. - switch encLevel { - case protocol.EncryptionInitial: - hdr.Type = protocol.PacketTypeInitial - hdr.Token = p.token - case protocol.EncryptionHandshake: - hdr.Type = protocol.PacketTypeHandshake - case protocol.Encryption0RTT: - hdr.Type = protocol.PacketType0RTT - } - return hdr -} - -// writeSinglePacket packs a single packet. -func (p *packetPacker) writeSinglePacket( - hdr *wire.ExtendedHeader, - payload *payload, - encLevel protocol.EncryptionLevel, - sealer sealer, -) (*packedPacket, error) { - buffer := getPacketBuffer() - var paddingLen protocol.ByteCount - if encLevel == protocol.EncryptionInitial { - paddingLen = p.initialPaddingLen(payload.frames, hdr.GetLength(p.version)+payload.length+protocol.ByteCount(sealer.Overhead())) - } - contents, err := p.appendPacket(buffer, hdr, payload, paddingLen, encLevel, sealer, false) - if err != nil { - return nil, err - } - return &packedPacket{ - buffer: buffer, - packetContents: contents, - }, nil -} - -func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, isMTUProbePacket bool) (*packetContents, error) { - var paddingLen protocol.ByteCount - pnLen := protocol.ByteCount(header.PacketNumberLen) - if payload.length < 4-pnLen { - paddingLen = 4 - pnLen - payload.length - } - paddingLen += padding - if header.IsLongHeader { - header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen - } - - hdrOffset := buffer.Len() - buf := bytes.NewBuffer(buffer.Data) - if err := header.Write(buf, p.version); err != nil { - return nil, err - } - payloadOffset := buf.Len() - - if payload.ack != nil { - if err := payload.ack.Write(buf, p.version); err != nil { - return nil, err - } - } - if paddingLen > 0 { - buf.Write(make([]byte, paddingLen)) - } - for _, frame := range payload.frames { - if err := frame.Write(buf, p.version); err != nil { - return nil, err - } - } - - if payloadSize := protocol.ByteCount(buf.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { - return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) - } - if !isMTUProbePacket { - if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize { - return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) - } - } - - raw := buffer.Data - // encrypt the packet - raw = raw[:buf.Len()] - _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[hdrOffset:payloadOffset]) - raw = raw[0 : buf.Len()+sealer.Overhead()] - // apply header protection - pnOffset := payloadOffset - int(header.PacketNumberLen) - sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[hdrOffset], raw[pnOffset:payloadOffset]) - buffer.Data = raw - - num := p.pnManager.PopPacketNumber(encLevel) - if num != header.PacketNumber { - return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") - } - return &packetContents{ - header: header, - ack: payload.ack, - frames: payload.frames, - length: buffer.Len() - hdrOffset, - }, nil -} - -func (p *packetPacker) SetToken(token []byte) { - p.token = token -} - -// When a higher MTU is discovered, use it. -func (p *packetPacker) SetMaxPacketSize(s protocol.ByteCount) { - p.maxPacketSize = s -} - -// If the peer sets a max_packet_size that's smaller than the size we're currently using, -// we need to reduce the size of packets we send. -func (p *packetPacker) HandleTransportParameters(params *wire.TransportParameters) { - if params.MaxUDPPayloadSize != 0 { - p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxUDPPayloadSize) - } -} diff --git a/internal/quic-go/packet_packer_test.go b/internal/quic-go/packet_packer_test.go deleted file mode 100644 index d069d0cc..00000000 --- a/internal/quic-go/packet_packer_test.go +++ /dev/null @@ -1,1556 +0,0 @@ -package quic - -import ( - "bytes" - "fmt" - "math/rand" - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go/ackhandler" - "github.com/imroc/req/v3/internal/quic-go/handshake" - "github.com/imroc/req/v3/internal/quic-go/mocks" - mockackhandler "github.com/imroc/req/v3/internal/quic-go/mocks/ackhandler" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/ginkgo/extensions/table" - . "github.com/onsi/gomega" -) - -var _ = Describe("Packet packer", func() { - const maxPacketSize protocol.ByteCount = 1357 - const version = protocol.VersionTLS - - var ( - packer *packetPacker - retransmissionQueue *retransmissionQueue - datagramQueue *datagramQueue - framer *MockFrameSource - ackFramer *MockAckFrameSource - initialStream *MockCryptoStream - handshakeStream *MockCryptoStream - sealingManager *MockSealingManager - pnManager *mockackhandler.MockSentPacketHandler - ) - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - - parsePacket := func(data []byte) []*wire.ExtendedHeader { - var hdrs []*wire.ExtendedHeader - for len(data) > 0 { - hdr, payload, rest, err := wire.ParsePacket(data, connID.Len()) - Expect(err).ToNot(HaveOccurred()) - r := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(r, version) - Expect(err).ToNot(HaveOccurred()) - if extHdr.IsLongHeader { - ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() - len(rest) + int(extHdr.PacketNumberLen))) - ExpectWithOffset(1, extHdr.Length+protocol.ByteCount(extHdr.PacketNumberLen)).To(BeNumerically(">=", 4)) - } else { - ExpectWithOffset(1, len(payload)+int(extHdr.PacketNumberLen)).To(BeNumerically(">=", 4)) - } - data = rest - hdrs = append(hdrs, extHdr) - } - return hdrs - } - - appendFrames := func(fs, frames []ackhandler.Frame) ([]ackhandler.Frame, protocol.ByteCount) { - var length protocol.ByteCount - for _, f := range frames { - length += f.Frame.Length(packer.version) - } - return append(fs, frames...), length - } - - expectAppendStreamFrames := func(frames ...ackhandler.Frame) { - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - return appendFrames(fs, frames) - }) - } - - expectAppendControlFrames := func(frames ...ackhandler.Frame) { - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - return appendFrames(fs, frames) - }) - } - - BeforeEach(func() { - rand.Seed(GinkgoRandomSeed()) - retransmissionQueue = newRetransmissionQueue(version) - mockSender := NewMockStreamSender(mockCtrl) - mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() - initialStream = NewMockCryptoStream(mockCtrl) - handshakeStream = NewMockCryptoStream(mockCtrl) - framer = NewMockFrameSource(mockCtrl) - ackFramer = NewMockAckFrameSource(mockCtrl) - sealingManager = NewMockSealingManager(mockCtrl) - pnManager = mockackhandler.NewMockSentPacketHandler(mockCtrl) - datagramQueue = newDatagramQueue(func() {}, utils.DefaultLogger) - - packer = newPacketPacker( - protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - func() protocol.ConnectionID { return connID }, - initialStream, - handshakeStream, - pnManager, - retransmissionQueue, - &net.TCPAddr{}, - sealingManager, - framer, - ackFramer, - datagramQueue, - protocol.PerspectiveServer, - version, - ) - packer.version = version - packer.maxPacketSize = maxPacketSize - }) - - Context("determining the maximum packet size", func() { - It("uses the minimum initial size, if it can't determine if the remote address is IPv4 or IPv6", func() { - Expect(getMaxPacketSize(&net.TCPAddr{})).To(BeEquivalentTo(protocol.MinInitialPacketSize)) - }) - - It("uses the maximum IPv4 packet size, if the remote address is IPv4", func() { - addr := &net.UDPAddr{IP: net.IPv4(11, 12, 13, 14), Port: 1337} - Expect(getMaxPacketSize(addr)).To(BeEquivalentTo(protocol.InitialPacketSizeIPv4)) - }) - - It("uses the maximum IPv6 packet size, if the remote address is IPv6", func() { - ip := net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334") - addr := &net.UDPAddr{IP: ip, Port: 1337} - Expect(getMaxPacketSize(addr)).To(BeEquivalentTo(protocol.InitialPacketSizeIPv6)) - }) - }) - - Context("generating a packet header", func() { - It("uses the Long Header format", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen3) - h := packer.getLongHeader(protocol.EncryptionHandshake) - Expect(h.IsLongHeader).To(BeTrue()) - Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) - Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen3)) - Expect(h.Version).To(Equal(packer.version)) - }) - - It("sets source and destination connection ID", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - packer.srcConnID = srcConnID - packer.getDestConnID = func() protocol.ConnectionID { return destConnID } - h := packer.getLongHeader(protocol.EncryptionHandshake) - Expect(h.SrcConnectionID).To(Equal(srcConnID)) - Expect(h.DestConnectionID).To(Equal(destConnID)) - }) - - It("gets a short header", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen4) - h := packer.getShortHeader(protocol.KeyPhaseOne) - Expect(h.IsLongHeader).To(BeFalse()) - Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) - Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) - Expect(h.KeyPhase).To(Equal(protocol.KeyPhaseOne)) - }) - }) - - Context("encrypting packets", func() { - It("encrypts a packet", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337)) - sealer := mocks.NewMockShortHeaderSealer(mockCtrl) - sealer.EXPECT().Overhead().Return(4).AnyTimes() - var hdrRaw []byte - gomock.InOrder( - sealer.EXPECT().KeyPhase().Return(protocol.KeyPhaseOne), - sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1337), gomock.Any()).DoAndReturn(func(_, src []byte, _ protocol.PacketNumber, aad []byte) []byte { - hdrRaw = append([]byte{}, aad...) - return append(src, []byte{0xde, 0xca, 0xfb, 0xad}...) - }), - sealer.EXPECT().EncryptHeader(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(sample []byte, firstByte *byte, pnBytes []byte) { - Expect(firstByte).To(Equal(&hdrRaw[0])) - Expect(pnBytes).To(Equal(hdrRaw[len(hdrRaw)-2:])) - *firstByte ^= 0xff // invert the first byte - // invert the packet number bytes - for i := range pnBytes { - pnBytes[i] ^= 0xff - } - }), - ) - framer.EXPECT().HasData().Return(true) - sealingManager.EXPECT().GetInitialSealer().Return(nil, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - expectAppendControlFrames() - f := &wire.StreamFrame{Data: []byte{0xde, 0xca, 0xfb, 0xad}} - expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{{Frame: f}})) - hdrRawEncrypted := append([]byte{}, hdrRaw...) - hdrRawEncrypted[0] ^= 0xff - hdrRawEncrypted[len(hdrRaw)-2] ^= 0xff - hdrRawEncrypted[len(hdrRaw)-1] ^= 0xff - Expect(p.buffer.Data[0:len(hdrRaw)]).To(Equal(hdrRawEncrypted)) - Expect(p.buffer.Data[p.buffer.Len()-4:]).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) - }) - }) - - Context("packing packets", func() { - // getSealer gets a sealer that's expected to seal exactly one packet - getSealer := func() *mocks.MockShortHeaderSealer { - sealer := mocks.NewMockShortHeaderSealer(mockCtrl) - sealer.EXPECT().KeyPhase().Return(protocol.KeyPhaseOne).AnyTimes() - sealer.EXPECT().Overhead().Return(7).AnyTimes() - sealer.EXPECT().EncryptHeader(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte { - return append(src, bytes.Repeat([]byte{'s'}, sealer.Overhead())...) - }).AnyTimes() - return sealer - } - - Context("packing ACK packets", func() { - It("doesn't pack a packet if there's no ACK to send", func() { - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) - p, err := packer.MaybePackAckPacket(false) - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) - }) - - It("packs Initial ACK-only packets, and pads them (for the client)", func() { - packer.perspective = protocol.PerspectiveClient - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).Return(ack) - p, err := packer.MaybePackAckPacket(false) - Expect(err).NotTo(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.ack).To(Equal(ack)) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) - parsePacket(p.buffer.Data) - }) - - It("packs Initial ACK-only packets, and doesn't pads them (for the server)", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).Return(ack) - p, err := packer.MaybePackAckPacket(false) - Expect(err).NotTo(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.ack).To(Equal(ack)) - parsePacket(p.buffer.Data) - }) - - It("packs 1-RTT ACK-only packets", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) - p, err := packer.MaybePackAckPacket(true) - Expect(err).NotTo(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(p.ack).To(Equal(ack)) - parsePacket(p.buffer.Data) - }) - }) - - Context("packing 0-RTT packets", func() { - BeforeEach(func() { - packer.perspective = protocol.PerspectiveClient - sealingManager.EXPECT().GetInitialSealer().Return(nil, nil).AnyTimes() - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil).AnyTimes() - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable).AnyTimes() - initialStream.EXPECT().HasData().AnyTimes() - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).AnyTimes() - handshakeStream.EXPECT().HasData().AnyTimes() - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true).AnyTimes() - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).AnyTimes() - }) - - It("packs a 0-RTT packet", func() { - sealingManager.EXPECT().Get0RTTSealer().Return(getSealer(), nil).AnyTimes() - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42)) - cf := ackhandler.Frame{Frame: &wire.MaxDataFrame{MaximumData: 0x1337}} - framer.EXPECT().HasData().Return(true) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - Expect(frames).To(BeEmpty()) - return append(frames, cf), cf.Length(packer.version) - }) - // TODO: check sizes - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - return frames, 0 - }) - p, err := packer.PackCoalescedPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].header.Type).To(Equal(protocol.PacketType0RTT)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) - Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{cf})) - }) - }) - - Context("packing CONNECTION_CLOSE", func() { - It("clears the reason phrase for crypto errors", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - quicErr := qerr.NewCryptoError(0x42, "crypto error") - quicErr.FrameType = 0x1234 - p, err := packer.PackConnectionClose(quicErr) - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeHandshake)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(0x100 + 0x42)) - Expect(ccf.FrameType).To(BeEquivalentTo(0x1234)) - Expect(ccf.ReasonPhrase).To(BeEmpty()) - }) - - It("packs a CONNECTION_CLOSE in 1-RTT", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - // expect no framer.PopStreamFrames - p, err := packer.PackConnectionClose(&qerr.TransportError{ - ErrorCode: qerr.CryptoBufferExceeded, - ErrorMessage: "test error", - }) - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].header.IsLongHeader).To(BeFalse()) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.CryptoBufferExceeded)) - Expect(ccf.ReasonPhrase).To(Equal("test error")) - }) - - It("packs a CONNECTION_CLOSE in all available encryption levels, and replaces application errors in Initial and Handshake", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(1)) - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(2), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(2)) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(3), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(3)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - p, err := packer.PackApplicationClose(&qerr.ApplicationError{ - ErrorCode: 0x1337, - ErrorMessage: "test error", - }) - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(3)) - Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(p.packets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) - Expect(ccf.ReasonPhrase).To(BeEmpty()) - Expect(p.packets[1].header.Type).To(Equal(protocol.PacketTypeHandshake)) - Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf = p.packets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) - Expect(ccf.ReasonPhrase).To(BeEmpty()) - Expect(p.packets[2].header.IsLongHeader).To(BeFalse()) - Expect(p.packets[2].header.PacketNumber).To(Equal(protocol.PacketNumber(3))) - Expect(p.packets[2].frames).To(HaveLen(1)) - Expect(p.packets[2].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf = p.packets[2].frames[0].Frame.(*wire.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeTrue()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(0x1337)) - Expect(ccf.ReasonPhrase).To(Equal("test error")) - }) - - It("packs a CONNECTION_CLOSE in all available encryption levels, as a client", func() { - packer.perspective = protocol.PerspectiveClient - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(1)) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(2), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(2)) - sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - p, err := packer.PackApplicationClose(&qerr.ApplicationError{ - ErrorCode: 0x1337, - ErrorMessage: "test error", - }) - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(2)) - Expect(p.buffer.Len()).To(BeNumerically("<", protocol.MinInitialPacketSize)) - Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeHandshake)) - Expect(p.packets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) - Expect(ccf.ReasonPhrase).To(BeEmpty()) - Expect(p.packets[1].header.IsLongHeader).To(BeFalse()) - Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf = p.packets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeTrue()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(0x1337)) - Expect(ccf.ReasonPhrase).To(Equal("test error")) - }) - - It("packs a CONNECTION_CLOSE in all available encryption levels and pads, as a client", func() { - packer.perspective = protocol.PerspectiveClient - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(1)) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(2), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(2)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get0RTTSealer().Return(getSealer(), nil) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - p, err := packer.PackApplicationClose(&qerr.ApplicationError{ - ErrorCode: 0x1337, - ErrorMessage: "test error", - }) - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(2)) - Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) - Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) - Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(p.packets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) - Expect(ccf.ReasonPhrase).To(BeEmpty()) - Expect(p.packets[1].header.Type).To(Equal(protocol.PacketType0RTT)) - Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf = p.packets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeTrue()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(0x1337)) - Expect(ccf.ReasonPhrase).To(Equal("test error")) - hdrs := parsePacket(p.buffer.Data) - Expect(hdrs).To(HaveLen(2)) - Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) - Expect(hdrs[1].Type).To(Equal(protocol.PacketType0RTT)) - }) - }) - - Context("packing normal packets", func() { - It("returns nil when no packet is queued", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - // don't expect any calls to PopPacketNumber - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) - framer.EXPECT().HasData() - p, err := packer.PackPacket() - Expect(p).To(BeNil()) - Expect(err).ToNot(HaveOccurred()) - }) - - It("packs single packets", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - expectAppendControlFrames() - f := &wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xde, 0xca, 0xfb, 0xad}, - } - expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - b := &bytes.Buffer{} - f.Write(b, packer.version) - Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: f}})) - Expect(p.buffer.Data).To(ContainSubstring(b.String())) - }) - - It("stores the encryption level a packet was sealed with", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - expectAppendControlFrames() - expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - }}) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - }) - - It("packs a single ACK", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} - framer.EXPECT().HasData() - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - p, err := packer.PackPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.ack).To(Equal(ack)) - }) - - It("packs control frames", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - frames := []ackhandler.Frame{ - {Frame: &wire.ResetStreamFrame{}}, - {Frame: &wire.MaxDataFrame{}}, - } - expectAppendControlFrames(frames...) - expectAppendStreamFrames() - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(Equal(frames)) - Expect(p.buffer.Len()).ToNot(BeZero()) - }) - - It("packs DATAGRAM frames", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - f := &wire.DatagramFrame{ - DataLenPresent: true, - Data: []byte("foobar"), - } - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - datagramQueue.AddAndWait(f) - }() - // make sure the DATAGRAM has actually been queued - time.Sleep(scaleDuration(20 * time.Millisecond)) - - framer.EXPECT().HasData() - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0].Frame).To(Equal(f)) - Expect(p.buffer.Data).ToNot(BeEmpty()) - Eventually(done).Should(BeClosed()) - }) - - It("accounts for the space consumed by control frames", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - var maxSize protocol.ByteCount - gomock.InOrder( - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - maxSize = maxLen - return fs, 444 - }), - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Do(func(fs []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - Expect(maxLen).To(Equal(maxSize - 444)) - return fs, 0 - }), - ) - _, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - }) - - It("pads if payload length + packet number length is smaller than 4, for Long Header packets", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - sealer := getSealer() - sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) - handshakeStream.EXPECT().HasData() - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - packet, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - Expect(packet.packets).To(HaveLen(1)) - // cut off the tag that the mock sealer added - // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] - hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) - Expect(err).ToNot(HaveOccurred()) - r := bytes.NewReader(packet.buffer.Data) - extHdr, err := hdr.ParseExtended(r, packer.version) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) - Expect(r.Len()).To(Equal(4 - 1 /* packet number length */ + sealer.Overhead())) - // the first bytes of the payload should be a 2 PADDING frames... - firstPayloadByte, err := r.ReadByte() - Expect(err).ToNot(HaveOccurred()) - Expect(firstPayloadByte).To(Equal(byte(0))) - secondPayloadByte, err := r.ReadByte() - Expect(err).ToNot(HaveOccurred()) - Expect(secondPayloadByte).To(Equal(byte(0))) - // ... followed by the PING - frameParser := wire.NewFrameParser(false, packer.version) - frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) - Expect(r.Len()).To(Equal(sealer.Overhead())) - }) - - It("pads if payload length + packet number length is smaller than 4", func() { - f := &wire.StreamFrame{ - StreamID: 0x10, // small stream ID, such that only a single byte is consumed - Fin: true, - } - Expect(f.Length(packer.version)).To(BeEquivalentTo(2)) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealer := getSealer() - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - expectAppendControlFrames() - expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - packet, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - // cut off the tag that the mock sealer added - packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] - hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) - Expect(err).ToNot(HaveOccurred()) - r := bytes.NewReader(packet.buffer.Data) - extHdr, err := hdr.ParseExtended(r, packer.version) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) - Expect(r.Len()).To(Equal(4 - 1 /* packet number length */)) - // the first byte of the payload should be a PADDING frame... - firstPayloadByte, err := r.ReadByte() - Expect(err).ToNot(HaveOccurred()) - Expect(firstPayloadByte).To(Equal(byte(0))) - // ... followed by the STREAM frame - frameParser := wire.NewFrameParser(true, packer.version) - frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - sf := frame.(*wire.StreamFrame) - Expect(sf.StreamID).To(Equal(f.StreamID)) - Expect(sf.Fin).To(Equal(f.Fin)) - Expect(sf.Data).To(BeEmpty()) - Expect(r.Len()).To(BeZero()) - }) - - It("packs multiple small STREAM frames into single packet", func() { - f1 := &wire.StreamFrame{ - StreamID: 5, - Data: []byte("frame 1"), - DataLenPresent: true, - } - f2 := &wire.StreamFrame{ - StreamID: 5, - Data: []byte("frame 2"), - DataLenPresent: true, - } - f3 := &wire.StreamFrame{ - StreamID: 3, - Data: []byte("frame 3"), - DataLenPresent: true, - } - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - expectAppendControlFrames() - expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3}) - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(3)) - Expect(p.frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) - Expect(p.frames[1].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) - Expect(p.frames[2].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) - }) - - Context("making ACK packets ack-eliciting", func() { - sendMaxNumNonAckElicitingAcks := func() { - for i := 0; i < protocol.MaxNonAckElicitingAcks; i++ { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) - expectAppendControlFrames() - expectAppendStreamFrames() - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.ack).ToNot(BeNil()) - Expect(p.frames).To(BeEmpty()) - } - } - - It("adds a PING frame when it's supposed to send a ack-eliciting packet", func() { - sendMaxNumNonAckElicitingAcks() - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) - expectAppendControlFrames() - expectAppendStreamFrames() - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - var hasPing bool - for _, f := range p.frames { - if _, ok := f.Frame.(*wire.PingFrame); ok { - hasPing = true - Expect(f.OnLost).ToNot(BeNil()) // make sure the PING is not retransmitted if lost - } - } - Expect(hasPing).To(BeTrue()) - // make sure the next packet doesn't contain another PING - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) - expectAppendControlFrames() - expectAppendStreamFrames() - p, err = packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.ack).ToNot(BeNil()) - Expect(p.frames).To(BeEmpty()) - }) - - It("waits until there's something to send before adding a PING frame", func() { - sendMaxNumNonAckElicitingAcks() - // nothing to send - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - expectAppendControlFrames() - expectAppendStreamFrames() - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) - // now add some frame to send - expectAppendControlFrames() - expectAppendStreamFrames() - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(ack) - p, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.ack).To(Equal(ack)) - var hasPing bool - for _, f := range p.frames { - if _, ok := f.Frame.(*wire.PingFrame); ok { - hasPing = true - Expect(f.OnLost).ToNot(BeNil()) // make sure the PING is not retransmitted if lost - } - } - Expect(hasPing).To(BeTrue()) - }) - - It("doesn't send a PING if it already sent another ack-eliciting frame", func() { - sendMaxNumNonAckElicitingAcks() - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - expectAppendStreamFrames() - expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) - }) - }) - - Context("handling transport parameters", func() { - It("lowers the maximum packet size", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) - framer.EXPECT().HasData().Return(true).Times(2) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Times(2) - var initialMaxPacketSize protocol.ByteCount - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - initialMaxPacketSize = maxLen - return nil, 0 - }) - expectAppendStreamFrames() - _, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - // now reduce the maxPacketSize - packer.HandleTransportParameters(&wire.TransportParameters{ - MaxUDPPayloadSize: maxPacketSize - 10, - }) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - Expect(maxLen).To(Equal(initialMaxPacketSize - 10)) - return nil, 0 - }) - expectAppendStreamFrames() - _, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - }) - - It("doesn't increase the max packet size", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) - framer.EXPECT().HasData().Return(true).Times(2) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Times(2) - var initialMaxPacketSize protocol.ByteCount - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - initialMaxPacketSize = maxLen - return nil, 0 - }) - expectAppendStreamFrames() - _, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - // now try to increase the maxPacketSize - packer.HandleTransportParameters(&wire.TransportParameters{ - MaxUDPPayloadSize: maxPacketSize + 10, - }) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - Expect(maxLen).To(Equal(initialMaxPacketSize)) - return nil, 0 - }) - expectAppendStreamFrames() - _, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - }) - }) - - Context("max packet size", func() { - It("increases the max packet size", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) - framer.EXPECT().HasData().Return(true).Times(2) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Times(2) - var initialMaxPacketSize protocol.ByteCount - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - initialMaxPacketSize = maxLen - return nil, 0 - }) - expectAppendStreamFrames() - _, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - // now reduce the maxPacketSize - const packetSizeIncrease = 50 - packer.SetMaxPacketSize(maxPacketSize + packetSizeIncrease) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - Expect(maxLen).To(Equal(initialMaxPacketSize + packetSizeIncrease)) - return nil, 0 - }) - expectAppendStreamFrames() - _, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - }) - }) - }) - - Context("packing crypto packets", func() { - It("sets the length", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - f := &wire.CryptoFrame{ - Offset: 0x1337, - Data: []byte("foobar"), - } - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - handshakeStream.EXPECT().HasData().Return(true).AnyTimes() - handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) - sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - parsePacket(p.buffer.Data) - }) - - It("packs an Initial packet and pads it", func() { - packer.perspective = protocol.PerspectiveClient - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen1) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) - initialStream.EXPECT().HasData().Return(true).Times(2) - initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { - return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} - }) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) - Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) - hdrs := parsePacket(p.buffer.Data) - Expect(hdrs).To(HaveLen(1)) - Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) - }) - - It("packs a maximum size Handshake packet", func() { - var f *wire.CryptoFrame - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - initialStream.EXPECT().HasData() - handshakeStream.EXPECT().HasData().Return(true).Times(2) - handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { - f = &wire.CryptoFrame{Offset: 0x1337} - f.Data = bytes.Repeat([]byte{'f'}, int(size-f.Length(packer.version)-1)) - Expect(f.Length(packer.version)).To(Equal(size)) - return f - }) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].header.IsLongHeader).To(BeTrue()) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) - parsePacket(p.buffer.Data) - }) - - It("packs a coalesced packet with Initial / Handshake, and pads it", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) - // don't EXPECT any calls for a Handshake ACK frame - initialStream.EXPECT().HasData().Return(true).Times(2) - initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { - return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} - }) - handshakeStream.EXPECT().HasData().Return(true).Times(2) - handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { - return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")} - }) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) - hdrs := parsePacket(p.buffer.Data) - Expect(hdrs).To(HaveLen(2)) - Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) - Expect(hdrs[1].Type).To(Equal(protocol.PacketTypeHandshake)) - }) - - It("packs a coalesced packet with Initial / super short Handshake, and pads it", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) - // don't EXPECT any calls for a Handshake ACK frame - initialStream.EXPECT().HasData().Return(true).Times(2) - initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { - return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} - }) - handshakeStream.EXPECT().HasData() - packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) - hdrs := parsePacket(p.buffer.Data) - Expect(hdrs).To(HaveLen(2)) - Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) - Expect(hdrs[1].Type).To(Equal(protocol.PacketTypeHandshake)) - }) - - It("packs a coalesced packet with super short Initial / super short Handshake, and pads it", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen1) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, gomock.Any()) - initialStream.EXPECT().HasData() - handshakeStream.EXPECT().HasData() - packer.retransmissionQueue.AddInitial(&wire.PingFrame{}) - packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) - hdrs := parsePacket(p.buffer.Data) - Expect(hdrs).To(HaveLen(2)) - Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) - Expect(hdrs[1].Type).To(Equal(protocol.PacketTypeHandshake)) - }) - - It("packs a coalesced packet with Initial / super short 1-RTT, and pads it", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) - initialStream.EXPECT().HasData().Return(true).Times(2) - initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { - return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} - }) - expectAppendControlFrames() - expectAppendStreamFrames() - framer.EXPECT().HasData().Return(true) - packer.retransmissionQueue.AddAppData(&wire.PingFrame{}) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) - hdrs := parsePacket(p.buffer.Data) - Expect(hdrs).To(HaveLen(2)) - Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) - Expect(hdrs[1].IsLongHeader).To(BeFalse()) - }) - - It("packs a coalesced packet with Initial / 0-RTT, and pads it", func() { - packer.perspective = protocol.PerspectiveClient - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get0RTTSealer().Return(getSealer(), nil) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) - // don't EXPECT any calls for a Handshake ACK frame - initialStream.EXPECT().HasData().Return(true).Times(2) - initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { - return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} - }) - expectAppendControlFrames() - expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) - Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) - hdrs := parsePacket(p.buffer.Data) - Expect(hdrs).To(HaveLen(2)) - Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) - Expect(hdrs[1].Type).To(Equal(protocol.PacketType0RTT)) - }) - - It("packs a coalesced packet with Handshake / 1-RTT", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24)) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - framer.EXPECT().HasData().Return(true) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - // don't EXPECT any calls for a 1-RTT ACK frame - handshakeStream.EXPECT().HasData().Return(true).Times(2) - handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { - return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")} - }) - expectAppendControlFrames() - expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeNumerically("<", 100)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) - hdr, _, rest, err := wire.ParsePacket(p.buffer.Data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) - hdr, _, rest, err = wire.ParsePacket(rest, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsLongHeader).To(BeFalse()) - Expect(rest).To(BeEmpty()) - }) - - It("doesn't add a coalesced packet if the remaining size is smaller than MaxCoalescedPacketSize", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24)) - sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - // don't EXPECT any calls to GetHandshakeSealer and Get1RTTSealer - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - handshakeStream.EXPECT().HasData().Return(true).Times(2) - handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { - s := size - protocol.MinCoalescedPacketSize - f := &wire.CryptoFrame{Offset: 0x1337} - f.Data = bytes.Repeat([]byte{'f'}, int(s-f.Length(packer.version)-1)) - Expect(f.Length(packer.version)).To(Equal(s)) - return f - }) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) - Expect(len(p.buffer.Data)).To(BeEquivalentTo(maxPacketSize - protocol.MinCoalescedPacketSize)) - parsePacket(p.buffer.Data) - }) - - It("pads if payload length + packet number length is smaller than 4, for Long Header packets", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - sealer := getSealer() - sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) - handshakeStream.EXPECT().HasData() - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - packet, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - Expect(packet.packets).To(HaveLen(1)) - // cut off the tag that the mock sealer added - // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] - hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) - Expect(err).ToNot(HaveOccurred()) - r := bytes.NewReader(packet.buffer.Data) - extHdr, err := hdr.ParseExtended(r, packer.version) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) - Expect(r.Len()).To(Equal(4 - 1 /* packet number length */ + sealer.Overhead())) - // the first bytes of the payload should be a 2 PADDING frames... - firstPayloadByte, err := r.ReadByte() - Expect(err).ToNot(HaveOccurred()) - Expect(firstPayloadByte).To(Equal(byte(0))) - secondPayloadByte, err := r.ReadByte() - Expect(err).ToNot(HaveOccurred()) - Expect(secondPayloadByte).To(Equal(byte(0))) - // ... followed by the PING - frameParser := wire.NewFrameParser(false, packer.version) - frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) - Expect(r.Len()).To(Equal(sealer.Overhead())) - }) - - It("adds retransmissions", func() { - f := &wire.CryptoFrame{Data: []byte("Initial")} - retransmissionQueue.AddInitial(f) - retransmissionQueue.AddHandshake(&wire.CryptoFrame{Data: []byte("Handshake")}) - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) - initialStream.EXPECT().HasData() - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{{Frame: f}})) - Expect(p.packets[0].header.IsLongHeader).To(BeTrue()) - }) - - It("sends an Initial packet containing only an ACK", func() { - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).Return(ack) - initialStream.EXPECT().HasData().Times(2) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].ack).To(Equal(ack)) - }) - - It("doesn't pack anything if there's nothing to send at Initial and Handshake keys are not yet available", func() { - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - initialStream.EXPECT().HasData() - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) - }) - - It("sends a Handshake packet containing only an ACK", func() { - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true).Return(ack) - initialStream.EXPECT().HasData() - handshakeStream.EXPECT().HasData().Times(2) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].ack).To(Equal(ack)) - }) - - for _, pers := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} { - perspective := pers - - It(fmt.Sprintf("pads Initial packets to the required minimum packet size, for the %s", perspective), func() { - token := []byte("initial token") - packer.SetToken(token) - f := &wire.CryptoFrame{Data: []byte("foobar")} - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) - initialStream.EXPECT().HasData().Return(true).Times(2) - initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) - packer.perspective = protocol.PerspectiveClient - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) - Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].header.Token).To(Equal(token)) - Expect(p.packets[0].frames).To(HaveLen(1)) - cf := p.packets[0].frames[0].Frame.(*wire.CryptoFrame) - Expect(cf.Data).To(Equal([]byte("foobar"))) - }) - } - - It("adds an ACK frame", func() { - f := &wire.CryptoFrame{Data: []byte("foobar")} - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 42, Largest: 1337}}} - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false).Return(ack) - initialStream.EXPECT().HasData().Return(true).Times(2) - initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) - packer.version = protocol.VersionTLS - packer.perspective = protocol.PerspectiveClient - p, err := packer.PackCoalescedPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].ack).To(Equal(ack)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) - }) - }) - - Context("packing probe packets", func() { - for _, pers := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} { - perspective := pers - - It(fmt.Sprintf("packs an Initial probe packet and pads it, for the %s", perspective), func() { - packer.perspective = perspective - f := &wire.CryptoFrame{Data: []byte("Initial")} - retransmissionQueue.AddInitial(f) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) - initialStream.EXPECT().HasData() - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - - packet, err := packer.MaybePackProbePacket(protocol.EncryptionInitial) - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(packet.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) - Expect(packet.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) - Expect(packet.frames).To(HaveLen(1)) - Expect(packet.frames[0].Frame).To(Equal(f)) - parsePacket(packet.buffer.Data) - }) - - It(fmt.Sprintf("packs an Initial probe packet with 1 byte payload, for the %s", perspective), func() { - packer.perspective = perspective - retransmissionQueue.AddInitial(&wire.PingFrame{}) - sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) - initialStream.EXPECT().HasData() - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - - packet, err := packer.MaybePackProbePacket(protocol.EncryptionInitial) - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(packet.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) - Expect(packet.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) - Expect(packet.frames).To(HaveLen(1)) - Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) - parsePacket(packet.buffer.Data) - }) - } - - It("packs a Handshake probe packet", func() { - f := &wire.CryptoFrame{Data: []byte("Handshake")} - retransmissionQueue.AddHandshake(f) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - handshakeStream.EXPECT().HasData() - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - - packet, err := packer.MaybePackProbePacket(protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) - Expect(packet.frames).To(HaveLen(1)) - Expect(packet.frames[0].Frame).To(Equal(f)) - parsePacket(packet.buffer.Data) - }) - - It("packs a full size Handshake probe packet", func() { - f := &wire.CryptoFrame{Data: make([]byte, 2000)} - retransmissionQueue.AddHandshake(f) - sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - handshakeStream.EXPECT().HasData() - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - - packet, err := packer.MaybePackProbePacket(protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) - Expect(packet.frames).To(HaveLen(1)) - Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) - Expect(packet.length).To(Equal(maxPacketSize)) - parsePacket(packet.buffer.Data) - }) - - It("packs a 1-RTT probe packet", func() { - f := &wire.StreamFrame{Data: []byte("1-RTT")} - retransmissionQueue.AddInitial(f) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - framer.EXPECT().HasData().Return(true) - expectAppendControlFrames() - expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - - packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - Expect(packet.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(packet.frames).To(HaveLen(1)) - Expect(packet.frames[0].Frame).To(Equal(f)) - }) - - It("packs a full size 1-RTT probe packet", func() { - f := &wire.StreamFrame{Data: make([]byte, 2000)} - retransmissionQueue.AddInitial(f) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - framer.EXPECT().HasData().Return(true) - expectAppendControlFrames() - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, maxSize protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { - sf, split := f.MaybeSplitOffFrame(maxSize, packer.version) - Expect(split).To(BeTrue()) - return append(fs, ackhandler.Frame{Frame: sf}), sf.Length(packer.version) - }) - - packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - Expect(packet.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(packet.frames).To(HaveLen(1)) - Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - Expect(packet.length).To(Equal(maxPacketSize)) - }) - - It("returns nil if there's no probe data to send", func() { - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) - framer.EXPECT().HasData() - - packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(packet).To(BeNil()) - }) - - It("packs an MTU probe packet", func() { - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43)) - ping := ackhandler.Frame{Frame: &wire.PingFrame{}} - const probePacketSize = maxPacketSize + 42 - p, err := packer.PackMTUProbePacket(ping, probePacketSize) - Expect(err).ToNot(HaveOccurred()) - Expect(p.length).To(BeEquivalentTo(probePacketSize)) - Expect(p.header.IsLongHeader).To(BeFalse()) - Expect(p.header.PacketNumber).To(Equal(protocol.PacketNumber(0x43))) - Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(p.buffer.Data).To(HaveLen(int(probePacketSize))) - Expect(p.packetContents.isMTUProbePacket).To(BeTrue()) - }) - }) - }) -}) - -var _ = Describe("Converting to AckHandler packets", func() { - It("convert a packet", func() { - packet := &packetContents{ - header: &wire.ExtendedHeader{Header: wire.Header{}}, - frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, - ack: &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 80}}}, - length: 42, - } - t := time.Now() - p := packet.ToAckHandlerPacket(t, nil) - Expect(p.Length).To(Equal(protocol.ByteCount(42))) - Expect(p.Frames).To(Equal(packet.frames)) - Expect(p.LargestAcked).To(Equal(protocol.PacketNumber(100))) - Expect(p.SendTime).To(Equal(t)) - }) - - It("sets the LargestAcked to invalid, if the packet doesn't have an ACK frame", func() { - packet := &packetContents{ - header: &wire.ExtendedHeader{Header: wire.Header{}}, - frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, - } - p := packet.ToAckHandlerPacket(time.Now(), nil) - Expect(p.LargestAcked).To(Equal(protocol.InvalidPacketNumber)) - }) - - It("marks MTU probe packets", func() { - packet := &packetContents{ - header: &wire.ExtendedHeader{Header: wire.Header{}}, - isMTUProbePacket: true, - } - Expect(packet.ToAckHandlerPacket(time.Now(), nil).IsPathMTUProbePacket).To(BeTrue()) - }) - - DescribeTable( - "doesn't overwrite the OnLost callback, if it is set", - func(hdr wire.Header) { - var pingLost bool - packet := &packetContents{ - header: &wire.ExtendedHeader{Header: hdr}, - frames: []ackhandler.Frame{ - {Frame: &wire.MaxDataFrame{}}, - {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { pingLost = true }}, - }, - } - p := packet.ToAckHandlerPacket(time.Now(), newRetransmissionQueue(protocol.VersionTLS)) - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames[0].OnLost).ToNot(BeNil()) - p.Frames[1].OnLost(nil) - Expect(pingLost).To(BeTrue()) - }, - Entry(protocol.EncryptionInitial.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketTypeInitial}), - Entry(protocol.EncryptionHandshake.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}), - Entry(protocol.Encryption0RTT.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketType0RTT}), - Entry(protocol.Encryption1RTT.String(), wire.Header{}), - ) -}) diff --git a/internal/quic-go/packet_unpacker.go b/internal/quic-go/packet_unpacker.go deleted file mode 100644 index 829907ef..00000000 --- a/internal/quic-go/packet_unpacker.go +++ /dev/null @@ -1,196 +0,0 @@ -package quic - -import ( - "bytes" - "fmt" - "time" - - "github.com/imroc/req/v3/internal/quic-go/handshake" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type headerDecryptor interface { - DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) -} - -type headerParseError struct { - err error -} - -func (e *headerParseError) Unwrap() error { - return e.err -} - -func (e *headerParseError) Error() string { - return e.err.Error() -} - -type unpackedPacket struct { - packetNumber protocol.PacketNumber // the decoded packet number - hdr *wire.ExtendedHeader - encryptionLevel protocol.EncryptionLevel - data []byte -} - -// The packetUnpacker unpacks QUIC packets. -type packetUnpacker struct { - cs handshake.CryptoSetup - - version protocol.VersionNumber -} - -var _ unpacker = &packetUnpacker{} - -func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) unpacker { - return &packetUnpacker{ - cs: cs, - version: version, - } -} - -// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. -// If any other error occurred when parsing the header, the error is of type headerParseError. -// If decrypting the payload fails for any reason, the error is the error returned by the AEAD. -func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { - var encLevel protocol.EncryptionLevel - var extHdr *wire.ExtendedHeader - var decrypted []byte - //nolint:exhaustive // Retry packets can't be unpacked. - switch hdr.Type { - case protocol.PacketTypeInitial: - encLevel = protocol.EncryptionInitial - opener, err := u.cs.GetInitialOpener() - if err != nil { - return nil, err - } - extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) - if err != nil { - return nil, err - } - case protocol.PacketTypeHandshake: - encLevel = protocol.EncryptionHandshake - opener, err := u.cs.GetHandshakeOpener() - if err != nil { - return nil, err - } - extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) - if err != nil { - return nil, err - } - case protocol.PacketType0RTT: - encLevel = protocol.Encryption0RTT - opener, err := u.cs.Get0RTTOpener() - if err != nil { - return nil, err - } - extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) - if err != nil { - return nil, err - } - default: - if hdr.IsLongHeader { - return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) - } - encLevel = protocol.Encryption1RTT - opener, err := u.cs.Get1RTTOpener() - if err != nil { - return nil, err - } - extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, rcvTime, data) - if err != nil { - return nil, err - } - } - - return &unpackedPacket{ - hdr: extHdr, - packetNumber: extHdr.PacketNumber, - encryptionLevel: encLevel, - data: decrypted, - }, nil -} - -func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { - extHdr, parseErr := u.unpackHeader(opener, hdr, data) - // If the reserved bits are set incorrectly, we still need to continue unpacking. - // This avoids a timing side-channel, which otherwise might allow an attacker - // to gain information about the header encryption. - if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { - return nil, nil, parseErr - } - extHdrLen := extHdr.ParsedLen() - extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) - decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen]) - if err != nil { - return nil, nil, err - } - if parseErr != nil { - return nil, nil, parseErr - } - return extHdr, decrypted, nil -} - -func (u *packetUnpacker) unpackShortHeaderPacket( - opener handshake.ShortHeaderOpener, - hdr *wire.Header, - rcvTime time.Time, - data []byte, -) (*wire.ExtendedHeader, []byte, error) { - extHdr, parseErr := u.unpackHeader(opener, hdr, data) - // If the reserved bits are set incorrectly, we still need to continue unpacking. - // This avoids a timing side-channel, which otherwise might allow an attacker - // to gain information about the header encryption. - if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { - return nil, nil, parseErr - } - extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) - extHdrLen := extHdr.ParsedLen() - decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen]) - if err != nil { - return nil, nil, err - } - if parseErr != nil { - return nil, nil, parseErr - } - return extHdr, decrypted, nil -} - -// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. -func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { - extHdr, err := unpackHeader(hd, hdr, data, u.version) - if err != nil && err != wire.ErrInvalidReservedBits { - return nil, &headerParseError{err: err} - } - return extHdr, err -} - -func unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) { - r := bytes.NewReader(data) - - hdrLen := hdr.ParsedLen() - if protocol.ByteCount(len(data)) < hdrLen+4+16 { - //nolint:stylecheck - return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen) - } - // The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it. - // 1. save a copy of the 4 bytes - origPNBytes := make([]byte, 4) - copy(origPNBytes, data[hdrLen:hdrLen+4]) - // 2. decrypt the header, assuming a 4 byte packet number - hd.DecryptHeader( - data[hdrLen+4:hdrLen+4+16], - &data[0], - data[hdrLen:hdrLen+4], - ) - // 3. parse the header (and learn the actual length of the packet number) - extHdr, parseErr := hdr.ParseExtended(r, version) - if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { - return nil, parseErr - } - // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier - if extHdr.PacketNumberLen != protocol.PacketNumberLen4 { - copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):]) - } - return extHdr, parseErr -} diff --git a/internal/quic-go/packet_unpacker_test.go b/internal/quic-go/packet_unpacker_test.go deleted file mode 100644 index a111faa4..00000000 --- a/internal/quic-go/packet_unpacker_test.go +++ /dev/null @@ -1,292 +0,0 @@ -package quic - -import ( - "bytes" - "errors" - "time" - - "github.com/imroc/req/v3/internal/quic-go/handshake" - "github.com/imroc/req/v3/internal/quic-go/mocks" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/wire" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Packet Unpacker", func() { - const version = protocol.VersionTLS - - var ( - unpacker *packetUnpacker - cs *mocks.MockCryptoSetup - connID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - payload = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ) - - getHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) { - buf := &bytes.Buffer{} - ExpectWithOffset(1, extHdr.Write(buf, version)).To(Succeed()) - hdrLen := buf.Len() - if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) { - buf.Write(make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen))) - } - hdr, _, _, err := wire.ParsePacket(buf.Bytes(), connID.Len()) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - return hdr, buf.Bytes()[:hdrLen] - } - - BeforeEach(func() { - cs = mocks.NewMockCryptoSetup(mockCtrl) - unpacker = newPacketUnpacker(cs, version).(*packetUnpacker) - }) - - It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: connID, - Version: version, - }, - PacketNumber: 1337, - PacketNumberLen: protocol.PacketNumberLen2, - } - hdr, hdrRaw := getHeader(extHdr) - data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) - opener := mocks.NewMockLongHeaderOpener(mockCtrl) - cs.EXPECT().GetHandshakeOpener().Return(opener, nil) - _, err := unpacker.Unpack(hdr, time.Now(), data) - Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) - var headerErr *headerParseError - Expect(errors.As(err, &headerErr)).To(BeTrue()) - Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) - }) - - It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - PacketNumber: 1337, - PacketNumberLen: protocol.PacketNumberLen2, - } - hdr, hdrRaw := getHeader(extHdr) - data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) - opener := mocks.NewMockShortHeaderOpener(mockCtrl) - cs.EXPECT().Get1RTTOpener().Return(opener, nil) - _, err := unpacker.Unpack(hdr, time.Now(), data) - Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) - Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) - }) - - It("opens Initial packets", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Length: 3 + 6, // packet number len + payload - DestConnectionID: connID, - Version: version, - }, - PacketNumber: 2, - PacketNumberLen: 3, - } - hdr, hdrRaw := getHeader(extHdr) - opener := mocks.NewMockLongHeaderOpener(mockCtrl) - gomock.InOrder( - cs.EXPECT().GetInitialOpener().Return(opener, nil), - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()), - opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)), - opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil), - ) - packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) - Expect(err).ToNot(HaveOccurred()) - Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) - Expect(packet.data).To(Equal([]byte("decrypted"))) - }) - - It("opens 0-RTT packets", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketType0RTT, - Length: 3 + 6, // packet number len + payload - DestConnectionID: connID, - Version: version, - }, - PacketNumber: 20, - PacketNumberLen: 2, - } - hdr, hdrRaw := getHeader(extHdr) - opener := mocks.NewMockLongHeaderOpener(mockCtrl) - gomock.InOrder( - cs.EXPECT().Get0RTTOpener().Return(opener, nil), - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()), - opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)), - opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil), - ) - packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) - Expect(err).ToNot(HaveOccurred()) - Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT)) - Expect(packet.data).To(Equal([]byte("decrypted"))) - }) - - It("opens short header packets", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - KeyPhase: protocol.KeyPhaseOne, - PacketNumber: 99, - PacketNumberLen: protocol.PacketNumberLen4, - } - hdr, hdrRaw := getHeader(extHdr) - opener := mocks.NewMockShortHeaderOpener(mockCtrl) - now := time.Now() - gomock.InOrder( - cs.EXPECT().Get1RTTOpener().Return(opener, nil), - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()), - opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)), - opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil), - ) - packet, err := unpacker.Unpack(hdr, now, append(hdrRaw, payload...)) - Expect(err).ToNot(HaveOccurred()) - Expect(packet.encryptionLevel).To(Equal(protocol.Encryption1RTT)) - Expect(packet.data).To(Equal([]byte("decrypted"))) - }) - - It("returns the error when getting the sealer fails", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - PacketNumber: 0x1337, - PacketNumberLen: 2, - } - hdr, hdrRaw := getHeader(extHdr) - cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable) - _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) - Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable)) - }) - - It("returns the error when unpacking fails", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Length: 3, // packet number len - DestConnectionID: connID, - Version: version, - }, - PacketNumber: 2, - PacketNumberLen: 3, - } - hdr, hdrRaw := getHeader(extHdr) - opener := mocks.NewMockLongHeaderOpener(mockCtrl) - cs.EXPECT().GetHandshakeOpener().Return(opener, nil) - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) - unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded} - opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr) - _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) - Expect(err).To(MatchError(unpackErr)) - }) - - It("defends against the timing side-channel when the reserved bits are wrong, for long header packets", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: connID, - Version: version, - }, - PacketNumber: 0x1337, - PacketNumberLen: 2, - } - hdr, hdrRaw := getHeader(extHdr) - hdrRaw[0] |= 0xc - opener := mocks.NewMockLongHeaderOpener(mockCtrl) - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - cs.EXPECT().GetHandshakeOpener().Return(opener, nil) - opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) - _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) - Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) - }) - - It("defends against the timing side-channel when the reserved bits are wrong, for short header packets", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - PacketNumber: 0x1337, - PacketNumberLen: 2, - } - hdr, hdrRaw := getHeader(extHdr) - hdrRaw[0] |= 0x18 - opener := mocks.NewMockShortHeaderOpener(mockCtrl) - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - cs.EXPECT().Get1RTTOpener().Return(opener, nil) - opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) - _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) - Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) - }) - - It("returns the decryption error, when unpacking a packet with wrong reserved bits fails", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - PacketNumber: 0x1337, - PacketNumberLen: 2, - } - hdr, hdrRaw := getHeader(extHdr) - hdrRaw[0] |= 0x18 - opener := mocks.NewMockShortHeaderOpener(mockCtrl) - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - cs.EXPECT().Get1RTTOpener().Return(opener, nil) - opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) - _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) - Expect(err).To(MatchError(handshake.ErrDecryptionFailed)) - }) - - It("decrypts the header", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Length: 3, // packet number len - DestConnectionID: connID, - Version: version, - }, - PacketNumber: 0x1337, - PacketNumberLen: 2, - } - hdr, hdrRaw := getHeader(extHdr) - origHdrRaw := append([]byte{}, hdrRaw...) // save a copy of the header - firstHdrByte := hdrRaw[0] - hdrRaw[0] ^= 0xff // invert the first byte - hdrRaw[len(hdrRaw)-2] ^= 0xff // invert the packet number - hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number - Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte)) - opener := mocks.NewMockLongHeaderOpener(mockCtrl) - cs.EXPECT().GetHandshakeOpener().Return(opener, nil) - gomock.InOrder( - // we're using a 2 byte packet number, so the sample starts at the 3rd payload byte - opener.EXPECT().DecryptHeader( - []byte{3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, - &hdrRaw[0], - append(hdrRaw[len(hdrRaw)-2:], []byte{1, 2}...)).Do(func(_ []byte, firstByte *byte, pnBytes []byte) { - *firstByte ^= 0xff // invert the first byte back - for i := range pnBytes { - pnBytes[i] ^= 0xff // invert the packet number bytes - } - }), - opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2).Return(protocol.PacketNumber(0x7331)), - opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x7331), origHdrRaw).Return([]byte{0}, nil), - ) - data := hdrRaw - for i := 1; i <= 100; i++ { - data = append(data, uint8(i)) - } - packet, err := unpacker.Unpack(hdr, time.Now(), data) - Expect(err).ToNot(HaveOccurred()) - Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x7331))) - }) -}) diff --git a/internal/quic-go/protocol/connection_id.go b/internal/quic-go/protocol/connection_id.go deleted file mode 100644 index 3aec2cd3..00000000 --- a/internal/quic-go/protocol/connection_id.go +++ /dev/null @@ -1,69 +0,0 @@ -package protocol - -import ( - "bytes" - "crypto/rand" - "fmt" - "io" -) - -// A ConnectionID in QUIC -type ConnectionID []byte - -const maxConnectionIDLen = 20 - -// GenerateConnectionID generates a connection ID using cryptographic random -func GenerateConnectionID(len int) (ConnectionID, error) { - b := make([]byte, len) - if _, err := rand.Read(b); err != nil { - return nil, err - } - return ConnectionID(b), nil -} - -// GenerateConnectionIDForInitial generates a connection ID for the Initial packet. -// It uses a length randomly chosen between 8 and 20 bytes. -func GenerateConnectionIDForInitial() (ConnectionID, error) { - r := make([]byte, 1) - if _, err := rand.Read(r); err != nil { - return nil, err - } - len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) - return GenerateConnectionID(len) -} - -// ReadConnectionID reads a connection ID of length len from the given io.Reader. -// It returns io.EOF if there are not enough bytes to read. -func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { - if len == 0 { - return nil, nil - } - c := make(ConnectionID, len) - _, err := io.ReadFull(r, c) - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return c, err -} - -// Equal says if two connection IDs are equal -func (c ConnectionID) Equal(other ConnectionID) bool { - return bytes.Equal(c, other) -} - -// Len returns the length of the connection ID in bytes -func (c ConnectionID) Len() int { - return len(c) -} - -// Bytes returns the byte representation -func (c ConnectionID) Bytes() []byte { - return []byte(c) -} - -func (c ConnectionID) String() string { - if c.Len() == 0 { - return "(empty)" - } - return fmt.Sprintf("%x", c.Bytes()) -} diff --git a/internal/quic-go/protocol/connection_id_test.go b/internal/quic-go/protocol/connection_id_test.go deleted file mode 100644 index 345e656c..00000000 --- a/internal/quic-go/protocol/connection_id_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package protocol - -import ( - "bytes" - "io" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Connection ID generation", func() { - It("generates random connection IDs", func() { - c1, err := GenerateConnectionID(8) - Expect(err).ToNot(HaveOccurred()) - Expect(c1).ToNot(BeZero()) - c2, err := GenerateConnectionID(8) - Expect(err).ToNot(HaveOccurred()) - Expect(c1).ToNot(Equal(c2)) - }) - - It("generates connection IDs with the requested length", func() { - c, err := GenerateConnectionID(5) - Expect(err).ToNot(HaveOccurred()) - Expect(c.Len()).To(Equal(5)) - }) - - It("generates random length destination connection IDs", func() { - var has8ByteConnID, has20ByteConnID bool - for i := 0; i < 1000; i++ { - c, err := GenerateConnectionIDForInitial() - Expect(err).ToNot(HaveOccurred()) - Expect(c.Len()).To(BeNumerically(">=", 8)) - Expect(c.Len()).To(BeNumerically("<=", 20)) - if c.Len() == 8 { - has8ByteConnID = true - } - if c.Len() == 20 { - has20ByteConnID = true - } - } - Expect(has8ByteConnID).To(BeTrue()) - Expect(has20ByteConnID).To(BeTrue()) - }) - - It("says if connection IDs are equal", func() { - c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - Expect(c1.Equal(c1)).To(BeTrue()) - Expect(c2.Equal(c2)).To(BeTrue()) - Expect(c1.Equal(c2)).To(BeFalse()) - Expect(c2.Equal(c1)).To(BeFalse()) - }) - - It("reads the connection ID", func() { - buf := bytes.NewBuffer([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) - c, err := ReadConnectionID(buf, 9) - Expect(err).ToNot(HaveOccurred()) - Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})) - }) - - It("returns io.EOF if there's not enough data to read", func() { - buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) - _, err := ReadConnectionID(buf, 5) - Expect(err).To(MatchError(io.EOF)) - }) - - It("returns nil for a 0 length connection ID", func() { - buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) - c, err := ReadConnectionID(buf, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(c).To(BeNil()) - }) - - It("returns the length", func() { - c := ConnectionID{1, 2, 3, 4, 5, 6, 7} - Expect(c.Len()).To(Equal(7)) - }) - - It("has 0 length for the default value", func() { - var c ConnectionID - Expect(c.Len()).To(BeZero()) - }) - - It("returns the bytes", func() { - c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) - Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7})) - }) - - It("returns a nil byte slice for the default value", func() { - var c ConnectionID - Expect(c.Bytes()).To(BeNil()) - }) - - It("has a string representation", func() { - c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) - Expect(c.String()).To(Equal("deadbeef42")) - }) - - It("has a long string representation", func() { - c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad} - Expect(c.String()).To(Equal("13370000decafbad")) - }) - - It("has a string representation for the default value", func() { - var c ConnectionID - Expect(c.String()).To(Equal("(empty)")) - }) -}) diff --git a/internal/quic-go/protocol/encryption_level.go b/internal/quic-go/protocol/encryption_level.go deleted file mode 100644 index 32d38ab1..00000000 --- a/internal/quic-go/protocol/encryption_level.go +++ /dev/null @@ -1,30 +0,0 @@ -package protocol - -// EncryptionLevel is the encryption level -// Default value is Unencrypted -type EncryptionLevel uint8 - -const ( - // EncryptionInitial is the Initial encryption level - EncryptionInitial EncryptionLevel = 1 + iota - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT - // Encryption1RTT is the 1-RTT encryption level - Encryption1RTT -) - -func (e EncryptionLevel) String() string { - switch e { - case EncryptionInitial: - return "Initial" - case EncryptionHandshake: - return "Handshake" - case Encryption0RTT: - return "0-RTT" - case Encryption1RTT: - return "1-RTT" - } - return "unknown" -} diff --git a/internal/quic-go/protocol/encryption_level_test.go b/internal/quic-go/protocol/encryption_level_test.go deleted file mode 100644 index 9b07b08b..00000000 --- a/internal/quic-go/protocol/encryption_level_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Encryption Level", func() { - It("doesn't use 0 as a value", func() { - // 0 is used in some tests - Expect(EncryptionInitial * EncryptionHandshake * Encryption0RTT * Encryption1RTT).ToNot(BeZero()) - }) - - It("has the correct string representation", func() { - Expect(EncryptionInitial.String()).To(Equal("Initial")) - Expect(EncryptionHandshake.String()).To(Equal("Handshake")) - Expect(Encryption0RTT.String()).To(Equal("0-RTT")) - Expect(Encryption1RTT.String()).To(Equal("1-RTT")) - }) -}) diff --git a/internal/quic-go/protocol/key_phase.go b/internal/quic-go/protocol/key_phase.go deleted file mode 100644 index edd740cf..00000000 --- a/internal/quic-go/protocol/key_phase.go +++ /dev/null @@ -1,36 +0,0 @@ -package protocol - -// KeyPhase is the key phase -type KeyPhase uint64 - -// Bit determines the key phase bit -func (p KeyPhase) Bit() KeyPhaseBit { - if p%2 == 0 { - return KeyPhaseZero - } - return KeyPhaseOne -} - -// KeyPhaseBit is the key phase bit -type KeyPhaseBit uint8 - -const ( - // KeyPhaseUndefined is an undefined key phase - KeyPhaseUndefined KeyPhaseBit = iota - // KeyPhaseZero is key phase 0 - KeyPhaseZero - // KeyPhaseOne is key phase 1 - KeyPhaseOne -) - -func (p KeyPhaseBit) String() string { - //nolint:exhaustive - switch p { - case KeyPhaseZero: - return "0" - case KeyPhaseOne: - return "1" - default: - return "undefined" - } -} diff --git a/internal/quic-go/protocol/key_phase_test.go b/internal/quic-go/protocol/key_phase_test.go deleted file mode 100644 index 92f404a5..00000000 --- a/internal/quic-go/protocol/key_phase_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Key Phases", func() { - It("has undefined as its default value", func() { - var k KeyPhaseBit - Expect(k).To(Equal(KeyPhaseUndefined)) - }) - - It("has the correct string representation", func() { - Expect(KeyPhaseZero.String()).To(Equal("0")) - Expect(KeyPhaseOne.String()).To(Equal("1")) - }) - - It("converts the key phase to the key phase bit", func() { - Expect(KeyPhase(0).Bit()).To(Equal(KeyPhaseZero)) - Expect(KeyPhase(2).Bit()).To(Equal(KeyPhaseZero)) - Expect(KeyPhase(4).Bit()).To(Equal(KeyPhaseZero)) - Expect(KeyPhase(1).Bit()).To(Equal(KeyPhaseOne)) - Expect(KeyPhase(3).Bit()).To(Equal(KeyPhaseOne)) - Expect(KeyPhase(5).Bit()).To(Equal(KeyPhaseOne)) - }) -}) diff --git a/internal/quic-go/protocol/packet_number.go b/internal/quic-go/protocol/packet_number.go deleted file mode 100644 index bd340161..00000000 --- a/internal/quic-go/protocol/packet_number.go +++ /dev/null @@ -1,79 +0,0 @@ -package protocol - -// A PacketNumber in QUIC -type PacketNumber int64 - -// InvalidPacketNumber is a packet number that is never sent. -// In QUIC, 0 is a valid packet number. -const InvalidPacketNumber PacketNumber = -1 - -// PacketNumberLen is the length of the packet number in bytes -type PacketNumberLen uint8 - -const ( - // PacketNumberLen1 is a packet number length of 1 byte - PacketNumberLen1 PacketNumberLen = 1 - // PacketNumberLen2 is a packet number length of 2 bytes - PacketNumberLen2 PacketNumberLen = 2 - // PacketNumberLen3 is a packet number length of 3 bytes - PacketNumberLen3 PacketNumberLen = 3 - // PacketNumberLen4 is a packet number length of 4 bytes - PacketNumberLen4 PacketNumberLen = 4 -) - -// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number -func DecodePacketNumber( - packetNumberLength PacketNumberLen, - lastPacketNumber PacketNumber, - wirePacketNumber PacketNumber, -) PacketNumber { - var epochDelta PacketNumber - switch packetNumberLength { - case PacketNumberLen1: - epochDelta = PacketNumber(1) << 8 - case PacketNumberLen2: - epochDelta = PacketNumber(1) << 16 - case PacketNumberLen3: - epochDelta = PacketNumber(1) << 24 - case PacketNumberLen4: - epochDelta = PacketNumber(1) << 32 - } - epoch := lastPacketNumber & ^(epochDelta - 1) - var prevEpochBegin PacketNumber - if epoch > epochDelta { - prevEpochBegin = epoch - epochDelta - } - nextEpochBegin := epoch + epochDelta - return closestTo( - lastPacketNumber+1, - epoch+wirePacketNumber, - closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber), - ) -} - -func closestTo(target, a, b PacketNumber) PacketNumber { - if delta(target, a) < delta(target, b) { - return a - } - return b -} - -func delta(a, b PacketNumber) PacketNumber { - if a < b { - return b - a - } - return a - b -} - -// GetPacketNumberLengthForHeader gets the length of the packet number for the public header -// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances -func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen { - diff := uint64(packetNumber - leastUnacked) - if diff < (1 << (16 - 1)) { - return PacketNumberLen2 - } - if diff < (1 << (24 - 1)) { - return PacketNumberLen3 - } - return PacketNumberLen4 -} diff --git a/internal/quic-go/protocol/packet_number_test.go b/internal/quic-go/protocol/packet_number_test.go deleted file mode 100644 index d3bfe1d5..00000000 --- a/internal/quic-go/protocol/packet_number_test.go +++ /dev/null @@ -1,204 +0,0 @@ -package protocol - -import ( - "fmt" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -// Tests taken and extended from chrome -var _ = Describe("packet number calculation", func() { - It("InvalidPacketNumber is smaller than all valid packet numbers", func() { - Expect(InvalidPacketNumber).To(BeNumerically("<", 0)) - }) - - It("works with the example from the draft", func() { - Expect(DecodePacketNumber(PacketNumberLen2, 0xa82f30ea, 0x9b32)).To(Equal(PacketNumber(0xa82f9b32))) - }) - - It("works with the examples from the draft", func() { - Expect(GetPacketNumberLengthForHeader(0xac5c02, 0xabe8b3)).To(Equal(PacketNumberLen2)) - Expect(GetPacketNumberLengthForHeader(0xace8fe, 0xabe8b3)).To(Equal(PacketNumberLen3)) - }) - - getEpoch := func(len PacketNumberLen) uint64 { - if len > 4 { - Fail("invalid packet number len") - } - return uint64(1) << (len * 8) - } - - check := func(length PacketNumberLen, expected, last uint64) { - epoch := getEpoch(length) - epochMask := epoch - 1 - wirePacketNumber := expected & epochMask - ExpectWithOffset(1, DecodePacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected))) - } - - for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen3, PacketNumberLen4} { - length := l - - Context(fmt.Sprintf("with %d bytes", length), func() { - epoch := getEpoch(length) - epochMask := epoch - 1 - - It("works near epoch start", func() { - // A few quick manual sanity check - check(length, 1, 0) - check(length, epoch+1, epochMask) - check(length, epoch, epochMask) - - // Cases where the last number was close to the start of the range. - for last := uint64(0); last < 10; last++ { - // Small numbers should not wrap (even if they're out of order). - for j := uint64(0); j < 10; j++ { - check(length, j, last) - } - - // Large numbers should not wrap either (because we're near 0 already). - for j := uint64(0); j < 10; j++ { - check(length, epoch-1-j, last) - } - } - }) - - It("works near epoch end", func() { - // Cases where the last number was close to the end of the range - for i := uint64(0); i < 10; i++ { - last := epoch - i - - // Small numbers should wrap. - for j := uint64(0); j < 10; j++ { - check(length, epoch+j, last) - } - - // Large numbers should not (even if they're out of order). - for j := uint64(0); j < 10; j++ { - check(length, epoch-1-j, last) - } - } - }) - - // Next check where we're in a non-zero epoch to verify we handle - // reverse wrapping, too. - It("works near previous epoch", func() { - prevEpoch := 1 * epoch - curEpoch := 2 * epoch - // Cases where the last number was close to the start of the range - for i := uint64(0); i < 10; i++ { - last := curEpoch + i - // Small number should not wrap (even if they're out of order). - for j := uint64(0); j < 10; j++ { - check(length, curEpoch+j, last) - } - - // But large numbers should reverse wrap. - for j := uint64(0); j < 10; j++ { - num := epoch - 1 - j - check(length, prevEpoch+num, last) - } - } - }) - - It("works near next epoch", func() { - curEpoch := 2 * epoch - nextEpoch := 3 * epoch - // Cases where the last number was close to the end of the range - for i := uint64(0); i < 10; i++ { - last := nextEpoch - 1 - i - - // Small numbers should wrap. - for j := uint64(0); j < 10; j++ { - check(length, nextEpoch+j, last) - } - - // but large numbers should not (even if they're out of order). - for j := uint64(0); j < 10; j++ { - num := epoch - 1 - j - check(length, curEpoch+num, last) - } - } - }) - - Context("shortening a packet number for the header", func() { - Context("shortening", func() { - It("sends out low packet numbers as 2 byte", func() { - length := GetPacketNumberLengthForHeader(4, 2) - Expect(length).To(Equal(PacketNumberLen2)) - }) - - It("sends out high packet numbers as 2 byte, if all ACKs are received", func() { - length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1) - Expect(length).To(Equal(PacketNumberLen2)) - }) - - It("sends out higher packet numbers as 3 bytes, if a lot of ACKs are missing", func() { - length := GetPacketNumberLengthForHeader(40000, 2) - Expect(length).To(Equal(PacketNumberLen3)) - }) - - It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() { - length := GetPacketNumberLengthForHeader(40000000, 2) - Expect(length).To(Equal(PacketNumberLen4)) - }) - }) - - Context("self-consistency", func() { - It("works for small packet numbers", func() { - for i := uint64(1); i < 10000; i++ { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(1) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - } - }) - - It("works for small packet numbers and increasing ACKed packets", func() { - for i := uint64(1); i < 10000; i++ { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(i / 2) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - epochMask := getEpoch(length) - 1 - wirePacketNumber := uint64(packetNumber) & epochMask - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - } - }) - - It("also works for larger packet numbers", func() { - var increment uint64 - for i := uint64(1); i < getEpoch(PacketNumberLen4); i += increment { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(1) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - epochMask := getEpoch(length) - 1 - wirePacketNumber := uint64(packetNumber) & epochMask - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - - increment = getEpoch(length) / 8 - } - }) - - It("works for packet numbers larger than 2^48", func() { - for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(i - 1000) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - } - }) - }) - }) - }) - } -}) diff --git a/internal/quic-go/protocol/params.go b/internal/quic-go/protocol/params.go deleted file mode 100644 index 83137113..00000000 --- a/internal/quic-go/protocol/params.go +++ /dev/null @@ -1,193 +0,0 @@ -package protocol - -import "time" - -// DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use. -const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB - -// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets. -const InitialPacketSizeIPv4 = 1252 - -// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets. -const InitialPacketSizeIPv6 = 1232 - -// MaxCongestionWindowPackets is the maximum congestion window in packet. -const MaxCongestionWindowPackets = 10000 - -// MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the connection. -const MaxUndecryptablePackets = 32 - -// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window -// This is the value that Chromium is using -const ConnectionFlowControlMultiplier = 1.5 - -// DefaultInitialMaxStreamData is the default initial stream-level flow control window for receiving data -const DefaultInitialMaxStreamData = (1 << 10) * 512 // 512 kb - -// DefaultInitialMaxData is the connection-level flow control window for receiving data -const DefaultInitialMaxData = ConnectionFlowControlMultiplier * DefaultInitialMaxStreamData - -// DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data -const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB - -// DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data -const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 15 MB - -// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client -const WindowUpdateThreshold = 0.25 - -// DefaultMaxIncomingStreams is the maximum number of streams that a peer may open -const DefaultMaxIncomingStreams = 100 - -// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open -const DefaultMaxIncomingUniStreams = 100 - -// MaxServerUnprocessedPackets is the max number of packets stored in the server that are not yet processed. -const MaxServerUnprocessedPackets = 1024 - -// MaxConnUnprocessedPackets is the max number of packets stored in each connection that are not yet processed. -const MaxConnUnprocessedPackets = 256 - -// SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack. -// Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod. -const SkipPacketInitialPeriod PacketNumber = 256 - -// SkipPacketMaxPeriod is the maximum period length used for packet number skipping. -const SkipPacketMaxPeriod PacketNumber = 128 * 1024 - -// MaxAcceptQueueSize is the maximum number of connections that the server queues for accepting. -// If the queue is full, new connection attempts will be rejected. -const MaxAcceptQueueSize = 32 - -// TokenValidity is the duration that a (non-retry) token is considered valid -const TokenValidity = 24 * time.Hour - -// RetryTokenValidity is the duration that a retry token is considered valid -const RetryTokenValidity = 10 * time.Second - -// MaxOutstandingSentPackets is maximum number of packets saved for retransmission. -// When reached, it imposes a soft limit on sending new packets: -// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. -const MaxOutstandingSentPackets = 2 * MaxCongestionWindowPackets - -// MaxTrackedSentPackets is maximum number of sent packets saved for retransmission. -// When reached, no more packets will be sent. -// This value *must* be larger than MaxOutstandingSentPackets. -const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4 - -// MaxNonAckElicitingAcks is the maximum number of packets containing an ACK, -// but no ack-eliciting frames, that we send in a row -const MaxNonAckElicitingAcks = 19 - -// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames -// prevents DoS attacks against the streamFrameSorter -const MaxStreamFrameSorterGaps = 1000 - -// MinStreamFrameBufferSize is the minimum data length of a received STREAM frame -// that we use the buffer for. This protects against a DoS where an attacker would send us -// very small STREAM frames to consume a lot of memory. -const MinStreamFrameBufferSize = 128 - -// MinCoalescedPacketSize is the minimum size of a coalesced packet that we pack. -// If a packet has less than this number of bytes, we won't coalesce any more packets onto it. -const MinCoalescedPacketSize = 128 - -// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams. -// This limits the size of the ClientHello and Certificates that can be received. -const MaxCryptoStreamOffset = 16 * (1 << 10) - -// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout -const MinRemoteIdleTimeout = 5 * time.Second - -// DefaultIdleTimeout is the default idle timeout -const DefaultIdleTimeout = 30 * time.Second - -// DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion. -const DefaultHandshakeIdleTimeout = 5 * time.Second - -// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds. -const DefaultHandshakeTimeout = 10 * time.Second - -// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive. -// It should be shorter than the time that NATs clear their mapping. -const MaxKeepAliveInterval = 20 * time.Second - -// RetiredConnectionIDDeleteTimeout is the time we keep closed connections around in order to retransmit the CONNECTION_CLOSE. -// after this time all information about the old connection will be deleted -const RetiredConnectionIDDeleteTimeout = 5 * time.Second - -// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame. -// This avoids splitting up STREAM frames into small pieces, which has 2 advantages: -// 1. it reduces the framing overhead -// 2. it reduces the head-of-line blocking, when a packet is lost -const MinStreamFrameSize ByteCount = 128 - -// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames -// we send after the handshake completes. -const MaxPostHandshakeCryptoFrameSize = 1000 - -// MaxAckFrameSize is the maximum size for an ACK frame that we write -// Due to the varint encoding, ACK frames can grow (almost) indefinitely large. -// The MaxAckFrameSize should be large enough to encode many ACK range, -// but must ensure that a maximum size ACK frame fits into one packet. -const MaxAckFrameSize ByteCount = 1000 - -// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame (RFC 9221). -// The size is chosen such that a DATAGRAM frame fits into a QUIC packet. -const MaxDatagramFrameSize ByteCount = 1220 - -// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221) -const DatagramRcvQueueLen = 128 - -// MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame. -// It also serves as a limit for the packet history. -// If at any point we keep track of more ranges, old ranges are discarded. -const MaxNumAckRanges = 32 - -// MinPacingDelay is the minimum duration that is used for packet pacing -// If the packet packing frequency is higher, multiple packets might be sent at once. -// Example: For a packet pacing delay of 200μs, we would send 5 packets at once, wait for 1ms, and so forth. -const MinPacingDelay = time.Millisecond - -// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections -// if no other value is configured. -const DefaultConnectionIDLength = 4 - -// MaxActiveConnectionIDs is the number of connection IDs that we're storing. -const MaxActiveConnectionIDs = 4 - -// MaxIssuedConnectionIDs is the maximum number of connection IDs that we're issuing at the same time. -const MaxIssuedConnectionIDs = 6 - -// PacketsPerConnectionID is the number of packets we send using one connection ID. -// If the peer provices us with enough new connection IDs, we switch to a new connection ID. -const PacketsPerConnectionID = 10000 - -// AckDelayExponent is the ack delay exponent used when sending ACKs. -const AckDelayExponent = 3 - -// Estimated timer granularity. -// The loss detection timer will not be set to a value smaller than granularity. -const TimerGranularity = time.Millisecond - -// MaxAckDelay is the maximum time by which we delay sending ACKs. -const MaxAckDelay = 25 * time.Millisecond - -// MaxAckDelayInclGranularity is the max_ack_delay including the timer granularity. -// This is the value that should be advertised to the peer. -const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity - -// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. -const KeyUpdateInterval = 100 * 1000 - -// Max0RTTQueueingDuration is the maximum time that we store 0-RTT packets in order to wait for the corresponding Initial to be received. -const Max0RTTQueueingDuration = 100 * time.Millisecond - -// Max0RTTQueues is the maximum number of connections that we buffer 0-RTT packets for. -const Max0RTTQueues = 32 - -// Max0RTTQueueLen is the maximum number of 0-RTT packets that we buffer for each connection. -// When a new connection is created, all buffered packets are passed to the connection immediately. -// To avoid blocking, this value has to be smaller than MaxConnUnprocessedPackets. -// To avoid packets being dropped as undecryptable by the connection, this value has to be smaller than MaxUndecryptablePackets. -const Max0RTTQueueLen = 31 diff --git a/internal/quic-go/protocol/params_test.go b/internal/quic-go/protocol/params_test.go deleted file mode 100644 index 50a260d2..00000000 --- a/internal/quic-go/protocol/params_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Parameters", func() { - It("can queue more packets in the session than in the 0-RTT queue", func() { - Expect(MaxConnUnprocessedPackets).To(BeNumerically(">", Max0RTTQueueLen)) - Expect(MaxUndecryptablePackets).To(BeNumerically(">", Max0RTTQueueLen)) - }) -}) diff --git a/internal/quic-go/protocol/perspective.go b/internal/quic-go/protocol/perspective.go deleted file mode 100644 index 43358fec..00000000 --- a/internal/quic-go/protocol/perspective.go +++ /dev/null @@ -1,26 +0,0 @@ -package protocol - -// Perspective determines if we're acting as a server or a client -type Perspective int - -// the perspectives -const ( - PerspectiveServer Perspective = 1 - PerspectiveClient Perspective = 2 -) - -// Opposite returns the perspective of the peer -func (p Perspective) Opposite() Perspective { - return 3 - p -} - -func (p Perspective) String() string { - switch p { - case PerspectiveServer: - return "Server" - case PerspectiveClient: - return "Client" - default: - return "invalid perspective" - } -} diff --git a/internal/quic-go/protocol/perspective_test.go b/internal/quic-go/protocol/perspective_test.go deleted file mode 100644 index 0ae23d7c..00000000 --- a/internal/quic-go/protocol/perspective_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Perspective", func() { - It("has a string representation", func() { - Expect(PerspectiveClient.String()).To(Equal("Client")) - Expect(PerspectiveServer.String()).To(Equal("Server")) - Expect(Perspective(0).String()).To(Equal("invalid perspective")) - }) - - It("returns the opposite", func() { - Expect(PerspectiveClient.Opposite()).To(Equal(PerspectiveServer)) - Expect(PerspectiveServer.Opposite()).To(Equal(PerspectiveClient)) - }) -}) diff --git a/internal/quic-go/protocol/protocol.go b/internal/quic-go/protocol/protocol.go deleted file mode 100644 index 8241e274..00000000 --- a/internal/quic-go/protocol/protocol.go +++ /dev/null @@ -1,97 +0,0 @@ -package protocol - -import ( - "fmt" - "time" -) - -// The PacketType is the Long Header Type -type PacketType uint8 - -const ( - // PacketTypeInitial is the packet type of an Initial packet - PacketTypeInitial PacketType = 1 + iota - // PacketTypeRetry is the packet type of a Retry packet - PacketTypeRetry - // PacketTypeHandshake is the packet type of a Handshake packet - PacketTypeHandshake - // PacketType0RTT is the packet type of a 0-RTT packet - PacketType0RTT -) - -func (t PacketType) String() string { - switch t { - case PacketTypeInitial: - return "Initial" - case PacketTypeRetry: - return "Retry" - case PacketTypeHandshake: - return "Handshake" - case PacketType0RTT: - return "0-RTT Protected" - default: - return fmt.Sprintf("unknown packet type: %d", t) - } -} - -type ECN uint8 - -const ( - ECNNon ECN = iota // 00 - ECT1 // 01 - ECT0 // 10 - ECNCE // 11 -) - -// A ByteCount in QUIC -type ByteCount int64 - -// MaxByteCount is the maximum value of a ByteCount -const MaxByteCount = ByteCount(1<<62 - 1) - -// InvalidByteCount is an invalid byte count -const InvalidByteCount ByteCount = -1 - -// A StatelessResetToken is a stateless reset token. -type StatelessResetToken [16]byte - -// MaxPacketBufferSize maximum packet size of any QUIC packet, based on -// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, -// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes. -// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452. -const MaxPacketBufferSize ByteCount = 1452 - -// MinInitialPacketSize is the minimum size an Initial packet is required to have. -const MinInitialPacketSize = 1200 - -// MinUnknownVersionPacketSize is the minimum size a packet with an unknown version -// needs to have in order to trigger a Version Negotiation packet. -const MinUnknownVersionPacketSize = MinInitialPacketSize - -// MinStatelessResetSize is the minimum size of a stateless reset packet that we send -const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */ - -// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet. -const MinConnectionIDLenInitial = 8 - -// DefaultAckDelayExponent is the default ack delay exponent -const DefaultAckDelayExponent = 3 - -// MaxAckDelayExponent is the maximum ack delay exponent -const MaxAckDelayExponent = 20 - -// DefaultMaxAckDelay is the default max_ack_delay -const DefaultMaxAckDelay = 25 * time.Millisecond - -// MaxMaxAckDelay is the maximum max_ack_delay -const MaxMaxAckDelay = (1<<14 - 1) * time.Millisecond - -// MaxConnIDLen is the maximum length of the connection ID -const MaxConnIDLen = 20 - -// InvalidPacketLimitAES is the maximum number of packets that we can fail to decrypt when using -// AEAD_AES_128_GCM or AEAD_AES_265_GCM. -const InvalidPacketLimitAES = 1 << 52 - -// InvalidPacketLimitChaCha is the maximum number of packets that we can fail to decrypt when using AEAD_CHACHA20_POLY1305. -const InvalidPacketLimitChaCha = 1 << 36 diff --git a/internal/quic-go/protocol/protocol_suite_test.go b/internal/quic-go/protocol/protocol_suite_test.go deleted file mode 100644 index 60da0157..00000000 --- a/internal/quic-go/protocol/protocol_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package protocol - -import ( - "testing" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestProtocol(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Protocol Suite") -} diff --git a/internal/quic-go/protocol/protocol_test.go b/internal/quic-go/protocol/protocol_test.go deleted file mode 100644 index 117405e4..00000000 --- a/internal/quic-go/protocol/protocol_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Protocol", func() { - Context("Long Header Packet Types", func() { - It("has the correct string representation", func() { - Expect(PacketTypeInitial.String()).To(Equal("Initial")) - Expect(PacketTypeRetry.String()).To(Equal("Retry")) - Expect(PacketTypeHandshake.String()).To(Equal("Handshake")) - Expect(PacketType0RTT.String()).To(Equal("0-RTT Protected")) - Expect(PacketType(10).String()).To(Equal("unknown packet type: 10")) - }) - }) - - It("converts ECN bits from the IP header wire to the correct types", func() { - Expect(ECN(0)).To(Equal(ECNNon)) - Expect(ECN(0b00000010)).To(Equal(ECT0)) - Expect(ECN(0b00000001)).To(Equal(ECT1)) - Expect(ECN(0b00000011)).To(Equal(ECNCE)) - }) -}) diff --git a/internal/quic-go/protocol/stream.go b/internal/quic-go/protocol/stream.go deleted file mode 100644 index ad7de864..00000000 --- a/internal/quic-go/protocol/stream.go +++ /dev/null @@ -1,76 +0,0 @@ -package protocol - -// StreamType encodes if this is a unidirectional or bidirectional stream -type StreamType uint8 - -const ( - // StreamTypeUni is a unidirectional stream - StreamTypeUni StreamType = iota - // StreamTypeBidi is a bidirectional stream - StreamTypeBidi -) - -// InvalidPacketNumber is a stream ID that is invalid. -// The first valid stream ID in QUIC is 0. -const InvalidStreamID StreamID = -1 - -// StreamNum is the stream number -type StreamNum int64 - -const ( - // InvalidStreamNum is an invalid stream number. - InvalidStreamNum = -1 - // MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames - // and as the stream count in the transport parameters - MaxStreamCount StreamNum = 1 << 60 -) - -// StreamID calculates the stream ID. -func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID { - if s == 0 { - return InvalidStreamID - } - var first StreamID - switch stype { - case StreamTypeBidi: - switch pers { - case PerspectiveClient: - first = 0 - case PerspectiveServer: - first = 1 - } - case StreamTypeUni: - switch pers { - case PerspectiveClient: - first = 2 - case PerspectiveServer: - first = 3 - } - } - return first + 4*StreamID(s-1) -} - -// A StreamID in QUIC -type StreamID int64 - -// InitiatedBy says if the stream was initiated by the client or by the server -func (s StreamID) InitiatedBy() Perspective { - if s%2 == 0 { - return PerspectiveClient - } - return PerspectiveServer -} - -// Type says if this is a unidirectional or bidirectional stream -func (s StreamID) Type() StreamType { - if s%4 >= 2 { - return StreamTypeUni - } - return StreamTypeBidi -} - -// StreamNum returns how many streams in total are below this -// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9) -func (s StreamID) StreamNum() StreamNum { - return StreamNum(s/4) + 1 -} diff --git a/internal/quic-go/protocol/stream_test.go b/internal/quic-go/protocol/stream_test.go deleted file mode 100644 index 4209f8a0..00000000 --- a/internal/quic-go/protocol/stream_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Stream ID", func() { - It("InvalidStreamID is smaller than all valid stream IDs", func() { - Expect(InvalidStreamID).To(BeNumerically("<", 0)) - }) - - It("says who initiated a stream", func() { - Expect(StreamID(4).InitiatedBy()).To(Equal(PerspectiveClient)) - Expect(StreamID(5).InitiatedBy()).To(Equal(PerspectiveServer)) - Expect(StreamID(6).InitiatedBy()).To(Equal(PerspectiveClient)) - Expect(StreamID(7).InitiatedBy()).To(Equal(PerspectiveServer)) - }) - - It("tells the directionality", func() { - Expect(StreamID(4).Type()).To(Equal(StreamTypeBidi)) - Expect(StreamID(5).Type()).To(Equal(StreamTypeBidi)) - Expect(StreamID(6).Type()).To(Equal(StreamTypeUni)) - Expect(StreamID(7).Type()).To(Equal(StreamTypeUni)) - }) - - It("tells the stream number", func() { - Expect(StreamID(0).StreamNum()).To(BeEquivalentTo(1)) - Expect(StreamID(1).StreamNum()).To(BeEquivalentTo(1)) - Expect(StreamID(2).StreamNum()).To(BeEquivalentTo(1)) - Expect(StreamID(3).StreamNum()).To(BeEquivalentTo(1)) - Expect(StreamID(8).StreamNum()).To(BeEquivalentTo(3)) - Expect(StreamID(9).StreamNum()).To(BeEquivalentTo(3)) - Expect(StreamID(10).StreamNum()).To(BeEquivalentTo(3)) - Expect(StreamID(11).StreamNum()).To(BeEquivalentTo(3)) - }) - - Context("converting stream nums to stream IDs", func() { - It("handles 0", func() { - Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(InvalidStreamID)) - Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(InvalidStreamID)) - Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(InvalidStreamID)) - Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(InvalidStreamID)) - }) - - It("handles the first", func() { - Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(0))) - Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(1))) - Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(2))) - Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3))) - }) - - It("handles others", func() { - Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(396))) - Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(397))) - Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(398))) - Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(399))) - }) - - It("has the right value for MaxStreamCount", func() { - const maxStreamID = StreamID(1<<62 - 1) - for _, dir := range []StreamType{StreamTypeUni, StreamTypeBidi} { - for _, pers := range []Perspective{PerspectiveClient, PerspectiveServer} { - Expect(MaxStreamCount.StreamID(dir, pers)).To(BeNumerically("<=", maxStreamID)) - Expect((MaxStreamCount + 1).StreamID(dir, pers)).To(BeNumerically(">", maxStreamID)) - } - } - }) - }) -}) diff --git a/internal/quic-go/protocol/version.go b/internal/quic-go/protocol/version.go deleted file mode 100644 index dd54dbd3..00000000 --- a/internal/quic-go/protocol/version.go +++ /dev/null @@ -1,114 +0,0 @@ -package protocol - -import ( - "crypto/rand" - "encoding/binary" - "fmt" - "math" -) - -// VersionNumber is a version number as int -type VersionNumber uint32 - -// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions -const ( - gquicVersion0 = 0x51303030 - maxGquicVersion = 0x51303439 -) - -// The version numbers, making grepping easier -const ( - VersionTLS VersionNumber = 0x1 - VersionWhatever VersionNumber = math.MaxUint32 - 1 // for when the version doesn't matter - VersionUnknown VersionNumber = math.MaxUint32 - VersionDraft29 VersionNumber = 0xff00001d - Version1 VersionNumber = 0x1 - Version2 VersionNumber = 0x709a50c4 -) - -// SupportedVersions lists the versions that the server supports -// must be in sorted descending order -var SupportedVersions = []VersionNumber{Version1, Version2, VersionDraft29} - -// IsValidVersion says if the version is known to quic-go -func IsValidVersion(v VersionNumber) bool { - return v == VersionTLS || IsSupportedVersion(SupportedVersions, v) -} - -func (vn VersionNumber) String() string { - // For releases, VersionTLS will be set to a draft version. - // A switch statement can't contain duplicate cases. - if vn == VersionTLS && VersionTLS != VersionDraft29 && VersionTLS != Version1 { - return "TLS dev version (WIP)" - } - //nolint:exhaustive - switch vn { - case VersionWhatever: - return "whatever" - case VersionUnknown: - return "unknown" - case VersionDraft29: - return "draft-29" - case Version1: - return "v1" - case Version2: - return "v2" - default: - if vn.isGQUIC() { - return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion()) - } - return fmt.Sprintf("%#x", uint32(vn)) - } -} - -func (vn VersionNumber) isGQUIC() bool { - return vn > gquicVersion0 && vn <= maxGquicVersion -} - -func (vn VersionNumber) toGQUICVersion() int { - return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10) -} - -// IsSupportedVersion returns true if the server supports this version -func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { - for _, t := range supported { - if t == v { - return true - } - } - return false -} - -// ChooseSupportedVersion finds the best version in the overlap of ours and theirs -// ours is a slice of versions that we support, sorted by our preference (descending) -// theirs is a slice of versions offered by the peer. The order does not matter. -// The bool returned indicates if a matching version was found. -func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) { - for _, ourVer := range ours { - for _, theirVer := range theirs { - if ourVer == theirVer { - return ourVer, true - } - } - } - return 0, false -} - -// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a) -func generateReservedVersion() VersionNumber { - b := make([]byte, 4) - _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything - return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa) -} - -// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position -func GetGreasedVersions(supported []VersionNumber) []VersionNumber { - b := make([]byte, 1) - _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything - randPos := int(b[0]) % (len(supported) + 1) - greased := make([]VersionNumber, len(supported)+1) - copy(greased, supported[:randPos]) - greased[randPos] = generateReservedVersion() - copy(greased[randPos+1:], supported[randPos:]) - return greased -} diff --git a/internal/quic-go/protocol/version_test.go b/internal/quic-go/protocol/version_test.go deleted file mode 100644 index 33c6598b..00000000 --- a/internal/quic-go/protocol/version_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package protocol - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Version", func() { - isReservedVersion := func(v VersionNumber) bool { - return v&0x0f0f0f0f == 0x0a0a0a0a - } - - It("says if a version is valid", func() { - Expect(IsValidVersion(VersionTLS)).To(BeTrue()) - Expect(IsValidVersion(VersionWhatever)).To(BeFalse()) - Expect(IsValidVersion(VersionUnknown)).To(BeFalse()) - Expect(IsValidVersion(VersionDraft29)).To(BeTrue()) - Expect(IsValidVersion(Version1)).To(BeTrue()) - Expect(IsValidVersion(Version2)).To(BeTrue()) - Expect(IsValidVersion(1234)).To(BeFalse()) - }) - - It("versions don't have reserved version numbers", func() { - Expect(isReservedVersion(VersionTLS)).To(BeFalse()) - }) - - It("has the right string representation", func() { - Expect(VersionWhatever.String()).To(Equal("whatever")) - Expect(VersionUnknown.String()).To(Equal("unknown")) - Expect(VersionDraft29.String()).To(Equal("draft-29")) - Expect(Version1.String()).To(Equal("v1")) - Expect(Version2.String()).To(Equal("v2")) - // check with unsupported version numbers from the wiki - Expect(VersionNumber(0x51303039).String()).To(Equal("gQUIC 9")) - Expect(VersionNumber(0x51303133).String()).To(Equal("gQUIC 13")) - Expect(VersionNumber(0x51303235).String()).To(Equal("gQUIC 25")) - Expect(VersionNumber(0x51303438).String()).To(Equal("gQUIC 48")) - Expect(VersionNumber(0x01234567).String()).To(Equal("0x1234567")) - }) - - It("recognizes supported versions", func() { - Expect(IsSupportedVersion(SupportedVersions, 0)).To(BeFalse()) - Expect(IsSupportedVersion(SupportedVersions, SupportedVersions[0])).To(BeTrue()) - Expect(IsSupportedVersion(SupportedVersions, SupportedVersions[len(SupportedVersions)-1])).To(BeTrue()) - }) - - Context("highest supported version", func() { - It("finds the supported version", func() { - supportedVersions := []VersionNumber{1, 2, 3} - other := []VersionNumber{6, 5, 4, 3} - ver, ok := ChooseSupportedVersion(supportedVersions, other) - Expect(ok).To(BeTrue()) - Expect(ver).To(Equal(VersionNumber(3))) - }) - - It("picks the preferred version", func() { - supportedVersions := []VersionNumber{2, 1, 3} - other := []VersionNumber{3, 6, 1, 8, 2, 10} - ver, ok := ChooseSupportedVersion(supportedVersions, other) - Expect(ok).To(BeTrue()) - Expect(ver).To(Equal(VersionNumber(2))) - }) - - It("says when no matching version was found", func() { - _, ok := ChooseSupportedVersion([]VersionNumber{1}, []VersionNumber{2}) - Expect(ok).To(BeFalse()) - }) - - It("handles empty inputs", func() { - _, ok := ChooseSupportedVersion([]VersionNumber{102, 101}, []VersionNumber{}) - Expect(ok).To(BeFalse()) - _, ok = ChooseSupportedVersion([]VersionNumber{}, []VersionNumber{1, 2}) - Expect(ok).To(BeFalse()) - _, ok = ChooseSupportedVersion([]VersionNumber{}, []VersionNumber{}) - Expect(ok).To(BeFalse()) - }) - }) - - Context("reserved versions", func() { - It("adds a greased version if passed an empty slice", func() { - greased := GetGreasedVersions([]VersionNumber{}) - Expect(greased).To(HaveLen(1)) - Expect(isReservedVersion(greased[0])).To(BeTrue()) - }) - - It("creates greased lists of version numbers", func() { - supported := []VersionNumber{10, 18, 29} - for _, v := range supported { - Expect(isReservedVersion(v)).To(BeFalse()) - } - var greasedVersionFirst, greasedVersionLast, greasedVersionMiddle int - // check that - // 1. the greased version sometimes appears first - // 2. the greased version sometimes appears in the middle - // 3. the greased version sometimes appears last - // 4. the supported versions are kept in order - for i := 0; i < 100; i++ { - greased := GetGreasedVersions(supported) - Expect(greased).To(HaveLen(4)) - var j int - for i, v := range greased { - if isReservedVersion(v) { - if i == 0 { - greasedVersionFirst++ - } - if i == len(greased)-1 { - greasedVersionLast++ - } - greasedVersionMiddle++ - continue - } - Expect(supported[j]).To(Equal(v)) - j++ - } - } - Expect(greasedVersionFirst).ToNot(BeZero()) - Expect(greasedVersionLast).ToNot(BeZero()) - Expect(greasedVersionMiddle).ToNot(BeZero()) - }) - }) -}) diff --git a/internal/quic-go/qerr/error_codes.go b/internal/quic-go/qerr/error_codes.go deleted file mode 100644 index fe61d146..00000000 --- a/internal/quic-go/qerr/error_codes.go +++ /dev/null @@ -1,88 +0,0 @@ -package qerr - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/qtls" -) - -// TransportErrorCode is a QUIC transport error. -type TransportErrorCode uint64 - -// The error codes defined by QUIC -const ( - NoError TransportErrorCode = 0x0 - InternalError TransportErrorCode = 0x1 - ConnectionRefused TransportErrorCode = 0x2 - FlowControlError TransportErrorCode = 0x3 - StreamLimitError TransportErrorCode = 0x4 - StreamStateError TransportErrorCode = 0x5 - FinalSizeError TransportErrorCode = 0x6 - FrameEncodingError TransportErrorCode = 0x7 - TransportParameterError TransportErrorCode = 0x8 - ConnectionIDLimitError TransportErrorCode = 0x9 - ProtocolViolation TransportErrorCode = 0xa - InvalidToken TransportErrorCode = 0xb - ApplicationErrorErrorCode TransportErrorCode = 0xc - CryptoBufferExceeded TransportErrorCode = 0xd - KeyUpdateError TransportErrorCode = 0xe - AEADLimitReached TransportErrorCode = 0xf - NoViablePathError TransportErrorCode = 0x10 -) - -func (e TransportErrorCode) IsCryptoError() bool { - return e >= 0x100 && e < 0x200 -} - -// Message is a description of the error. -// It only returns a non-empty string for crypto errors. -func (e TransportErrorCode) Message() string { - if !e.IsCryptoError() { - return "" - } - return qtls.Alert(e - 0x100).Error() -} - -func (e TransportErrorCode) String() string { - switch e { - case NoError: - return "NO_ERROR" - case InternalError: - return "INTERNAL_ERROR" - case ConnectionRefused: - return "CONNECTION_REFUSED" - case FlowControlError: - return "FLOW_CONTROL_ERROR" - case StreamLimitError: - return "STREAM_LIMIT_ERROR" - case StreamStateError: - return "STREAM_STATE_ERROR" - case FinalSizeError: - return "FINAL_SIZE_ERROR" - case FrameEncodingError: - return "FRAME_ENCODING_ERROR" - case TransportParameterError: - return "TRANSPORT_PARAMETER_ERROR" - case ConnectionIDLimitError: - return "CONNECTION_ID_LIMIT_ERROR" - case ProtocolViolation: - return "PROTOCOL_VIOLATION" - case InvalidToken: - return "INVALID_TOKEN" - case ApplicationErrorErrorCode: - return "APPLICATION_ERROR" - case CryptoBufferExceeded: - return "CRYPTO_BUFFER_EXCEEDED" - case KeyUpdateError: - return "KEY_UPDATE_ERROR" - case AEADLimitReached: - return "AEAD_LIMIT_REACHED" - case NoViablePathError: - return "NO_VIABLE_PATH" - default: - if e.IsCryptoError() { - return fmt.Sprintf("CRYPTO_ERROR (%#x)", uint16(e)) - } - return fmt.Sprintf("unknown error code: %#x", uint16(e)) - } -} diff --git a/internal/quic-go/qerr/errorcodes_test.go b/internal/quic-go/qerr/errorcodes_test.go deleted file mode 100644 index cfc6cd85..00000000 --- a/internal/quic-go/qerr/errorcodes_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package qerr - -import ( - "go/ast" - "go/parser" - "go/token" - "path" - "runtime" - "strconv" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("error codes", func() { - // If this test breaks, you should run `go generate ./...` - It("has a string representation for every error code", func() { - // We parse the error code file, extract all constants, and verify that - // each of them has a string version. Go FTW! - _, thisfile, _, ok := runtime.Caller(0) - if !ok { - panic("Failed to get current frame") - } - filename := path.Join(path.Dir(thisfile), "error_codes.go") - fileAst, err := parser.ParseFile(token.NewFileSet(), filename, nil, 0) - Expect(err).NotTo(HaveOccurred()) - constSpecs := fileAst.Decls[2].(*ast.GenDecl).Specs - Expect(len(constSpecs)).To(BeNumerically(">", 4)) // at time of writing - for _, c := range constSpecs { - valString := c.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value - val, err := strconv.ParseInt(valString, 0, 64) - Expect(err).NotTo(HaveOccurred()) - Expect(TransportErrorCode(val).String()).ToNot(Equal("unknown error code")) - } - }) - - It("has a string representation for unknown error codes", func() { - Expect(TransportErrorCode(0x1337).String()).To(Equal("unknown error code: 0x1337")) - }) - - It("says if an error is a crypto error", func() { - for i := 0; i < 0x100; i++ { - Expect(TransportErrorCode(i).IsCryptoError()).To(BeFalse()) - } - for i := 0x100; i < 0x200; i++ { - Expect(TransportErrorCode(i).IsCryptoError()).To(BeTrue()) - } - for i := 0x200; i < 0x300; i++ { - Expect(TransportErrorCode(i).IsCryptoError()).To(BeFalse()) - } - }) -}) diff --git a/internal/quic-go/qerr/errors.go b/internal/quic-go/qerr/errors.go deleted file mode 100644 index c2be1040..00000000 --- a/internal/quic-go/qerr/errors.go +++ /dev/null @@ -1,124 +0,0 @@ -package qerr - -import ( - "fmt" - "net" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -var ( - ErrHandshakeTimeout = &HandshakeTimeoutError{} - ErrIdleTimeout = &IdleTimeoutError{} -) - -type TransportError struct { - Remote bool - FrameType uint64 - ErrorCode TransportErrorCode - ErrorMessage string -} - -var _ error = &TransportError{} - -// NewCryptoError create a new TransportError instance for a crypto error -func NewCryptoError(tlsAlert uint8, errorMessage string) *TransportError { - return &TransportError{ - ErrorCode: 0x100 + TransportErrorCode(tlsAlert), - ErrorMessage: errorMessage, - } -} - -func (e *TransportError) Error() string { - str := e.ErrorCode.String() - if e.FrameType != 0 { - str += fmt.Sprintf(" (frame type: %#x)", e.FrameType) - } - msg := e.ErrorMessage - if len(msg) == 0 { - msg = e.ErrorCode.Message() - } - if len(msg) == 0 { - return str - } - return str + ": " + msg -} - -func (e *TransportError) Is(target error) bool { - return target == net.ErrClosed -} - -// An ApplicationErrorCode is an application-defined error code. -type ApplicationErrorCode uint64 - -func (e *ApplicationError) Is(target error) bool { - return target == net.ErrClosed -} - -// A StreamErrorCode is an error code used to cancel streams. -type StreamErrorCode uint64 - -type ApplicationError struct { - Remote bool - ErrorCode ApplicationErrorCode - ErrorMessage string -} - -var _ error = &ApplicationError{} - -func (e *ApplicationError) Error() string { - if len(e.ErrorMessage) == 0 { - return fmt.Sprintf("Application error %#x", e.ErrorCode) - } - return fmt.Sprintf("Application error %#x: %s", e.ErrorCode, e.ErrorMessage) -} - -type IdleTimeoutError struct{} - -var _ error = &IdleTimeoutError{} - -func (e *IdleTimeoutError) Timeout() bool { return true } -func (e *IdleTimeoutError) Temporary() bool { return false } -func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" } -func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed } - -type HandshakeTimeoutError struct{} - -var _ error = &HandshakeTimeoutError{} - -func (e *HandshakeTimeoutError) Timeout() bool { return true } -func (e *HandshakeTimeoutError) Temporary() bool { return false } -func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" } -func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed } - -// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version. -type VersionNegotiationError struct { - Ours []protocol.VersionNumber - Theirs []protocol.VersionNumber -} - -func (e *VersionNegotiationError) Error() string { - return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs) -} - -func (e *VersionNegotiationError) Is(target error) bool { - return target == net.ErrClosed -} - -// A StatelessResetError occurs when we receive a stateless reset. -type StatelessResetError struct { - Token protocol.StatelessResetToken -} - -var _ net.Error = &StatelessResetError{} - -func (e *StatelessResetError) Error() string { - return fmt.Sprintf("received a stateless reset with token %x", e.Token) -} - -func (e *StatelessResetError) Is(target error) bool { - return target == net.ErrClosed -} - -func (e *StatelessResetError) Timeout() bool { return false } -func (e *StatelessResetError) Temporary() bool { return true } diff --git a/internal/quic-go/qerr/errors_suite_test.go b/internal/quic-go/qerr/errors_suite_test.go deleted file mode 100644 index 749cdedc..00000000 --- a/internal/quic-go/qerr/errors_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package qerr - -import ( - "testing" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestErrorcodes(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Errors Suite") -} diff --git a/internal/quic-go/qerr/errors_test.go b/internal/quic-go/qerr/errors_test.go deleted file mode 100644 index 3b7f7cae..00000000 --- a/internal/quic-go/qerr/errors_test.go +++ /dev/null @@ -1,124 +0,0 @@ -package qerr - -import ( - "errors" - "net" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("QUIC Errors", func() { - Context("Transport Errors", func() { - It("has a string representation", func() { - Expect((&TransportError{ - ErrorCode: FlowControlError, - ErrorMessage: "foobar", - }).Error()).To(Equal("FLOW_CONTROL_ERROR: foobar")) - }) - - It("has a string representation for empty error phrases", func() { - Expect((&TransportError{ErrorCode: FlowControlError}).Error()).To(Equal("FLOW_CONTROL_ERROR")) - }) - - It("includes the frame type, for errors without a message", func() { - Expect((&TransportError{ - ErrorCode: FlowControlError, - FrameType: 0x1337, - }).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337)")) - }) - - It("includes the frame type, for errors with a message", func() { - Expect((&TransportError{ - ErrorCode: FlowControlError, - FrameType: 0x1337, - ErrorMessage: "foobar", - }).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337): foobar")) - }) - - Context("crypto errors", func() { - It("has a string representation for errors with a message", func() { - err := NewCryptoError(0x42, "foobar") - Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x142): foobar")) - }) - - It("has a string representation for errors without a message", func() { - err := NewCryptoError(0x2a, "") - Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x12a): tls: bad certificate")) - }) - }) - }) - - Context("Application Errors", func() { - It("has a string representation for errors with a message", func() { - Expect((&ApplicationError{ - ErrorCode: 0x42, - ErrorMessage: "foobar", - }).Error()).To(Equal("Application error 0x42: foobar")) - }) - - It("has a string representation for errors without a message", func() { - Expect((&ApplicationError{ - ErrorCode: 0x42, - }).Error()).To(Equal("Application error 0x42")) - }) - }) - - Context("timeout errors", func() { - It("handshake timeouts", func() { - //nolint:gosimple // we need to assign to an interface here - var err error - err = &HandshakeTimeoutError{} - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeTrue()) - Expect(err.Error()).To(Equal("timeout: handshake did not complete in time")) - }) - - It("idle timeouts", func() { - //nolint:gosimple // we need to assign to an interface here - var err error - err = &IdleTimeoutError{} - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeTrue()) - Expect(err.Error()).To(Equal("timeout: no recent network activity")) - }) - }) - - Context("Version Negotiation errors", func() { - It("has a string representation", func() { - Expect((&VersionNegotiationError{ - Ours: []protocol.VersionNumber{2, 3}, - Theirs: []protocol.VersionNumber{4, 5, 6}, - }).Error()).To(Equal("no compatible QUIC version found (we support [0x2 0x3], server offered [0x4 0x5 0x6])")) - }) - }) - - Context("Stateless Reset errors", func() { - token := protocol.StatelessResetToken{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} - - It("has a string representation", func() { - Expect((&StatelessResetError{Token: token}).Error()).To(Equal("received a stateless reset with token 000102030405060708090a0b0c0d0e0f")) - }) - - It("is a net.Error", func() { - //nolint:gosimple // we need to assign to an interface here - var err error - err = &StatelessResetError{} - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeFalse()) - }) - }) - - It("says that errors are net.ErrClosed errors", func() { - Expect(errors.Is(&TransportError{}, net.ErrClosed)).To(BeTrue()) - Expect(errors.Is(&ApplicationError{}, net.ErrClosed)).To(BeTrue()) - Expect(errors.Is(&IdleTimeoutError{}, net.ErrClosed)).To(BeTrue()) - Expect(errors.Is(&HandshakeTimeoutError{}, net.ErrClosed)).To(BeTrue()) - Expect(errors.Is(&StatelessResetError{}, net.ErrClosed)).To(BeTrue()) - Expect(errors.Is(&VersionNegotiationError{}, net.ErrClosed)).To(BeTrue()) - }) -}) diff --git a/internal/quic-go/qlog/event.go b/internal/quic-go/qlog/event.go deleted file mode 100644 index 4d799090..00000000 --- a/internal/quic-go/qlog/event.go +++ /dev/null @@ -1,529 +0,0 @@ -package qlog - -import ( - "errors" - "fmt" - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go" - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - - "github.com/francoispqt/gojay" -) - -func milliseconds(dur time.Duration) float64 { return float64(dur.Nanoseconds()) / 1e6 } - -type eventDetails interface { - Category() category - Name() string - gojay.MarshalerJSONObject -} - -type event struct { - RelativeTime time.Duration - eventDetails -} - -var _ gojay.MarshalerJSONObject = event{} - -func (e event) IsNil() bool { return false } -func (e event) MarshalJSONObject(enc *gojay.Encoder) { - enc.Float64Key("time", milliseconds(e.RelativeTime)) - enc.StringKey("name", e.Category().String()+":"+e.Name()) - enc.ObjectKey("data", e.eventDetails) -} - -type versions []versionNumber - -func (v versions) IsNil() bool { return false } -func (v versions) MarshalJSONArray(enc *gojay.Encoder) { - for _, e := range v { - enc.AddString(e.String()) - } -} - -type rawInfo struct { - Length logging.ByteCount // full packet length, including header and AEAD authentication tag - PayloadLength logging.ByteCount // length of the packet payload, excluding AEAD tag -} - -func (i rawInfo) IsNil() bool { return false } -func (i rawInfo) MarshalJSONObject(enc *gojay.Encoder) { - enc.Uint64Key("length", uint64(i.Length)) - enc.Uint64KeyOmitEmpty("payload_length", uint64(i.PayloadLength)) -} - -type eventConnectionStarted struct { - SrcAddr *net.UDPAddr - DestAddr *net.UDPAddr - - SrcConnectionID protocol.ConnectionID - DestConnectionID protocol.ConnectionID -} - -var _ eventDetails = &eventConnectionStarted{} - -func (e eventConnectionStarted) Category() category { return categoryTransport } -func (e eventConnectionStarted) Name() string { return "connection_started" } -func (e eventConnectionStarted) IsNil() bool { return false } - -func (e eventConnectionStarted) MarshalJSONObject(enc *gojay.Encoder) { - if utils.IsIPv4(e.SrcAddr.IP) { - enc.StringKey("ip_version", "ipv4") - } else { - enc.StringKey("ip_version", "ipv6") - } - enc.StringKey("src_ip", e.SrcAddr.IP.String()) - enc.IntKey("src_port", e.SrcAddr.Port) - enc.StringKey("dst_ip", e.DestAddr.IP.String()) - enc.IntKey("dst_port", e.DestAddr.Port) - enc.StringKey("src_cid", connectionID(e.SrcConnectionID).String()) - enc.StringKey("dst_cid", connectionID(e.DestConnectionID).String()) -} - -type eventVersionNegotiated struct { - clientVersions, serverVersions []versionNumber - chosenVersion versionNumber -} - -func (e eventVersionNegotiated) Category() category { return categoryTransport } -func (e eventVersionNegotiated) Name() string { return "version_information" } -func (e eventVersionNegotiated) IsNil() bool { return false } - -func (e eventVersionNegotiated) MarshalJSONObject(enc *gojay.Encoder) { - if len(e.clientVersions) > 0 { - enc.ArrayKey("client_versions", versions(e.clientVersions)) - } - if len(e.serverVersions) > 0 { - enc.ArrayKey("server_versions", versions(e.serverVersions)) - } - enc.StringKey("chosen_version", e.chosenVersion.String()) -} - -type eventConnectionClosed struct { - e error -} - -func (e eventConnectionClosed) Category() category { return categoryTransport } -func (e eventConnectionClosed) Name() string { return "connection_closed" } -func (e eventConnectionClosed) IsNil() bool { return false } - -func (e eventConnectionClosed) MarshalJSONObject(enc *gojay.Encoder) { - var ( - statelessResetErr *quic.StatelessResetError - handshakeTimeoutErr *quic.HandshakeTimeoutError - idleTimeoutErr *quic.IdleTimeoutError - applicationErr *quic.ApplicationError - transportErr *quic.TransportError - versionNegotiationErr *quic.VersionNegotiationError - ) - switch { - case errors.As(e.e, &statelessResetErr): - enc.StringKey("owner", ownerRemote.String()) - enc.StringKey("trigger", "stateless_reset") - enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", statelessResetErr.Token)) - case errors.As(e.e, &handshakeTimeoutErr): - enc.StringKey("owner", ownerLocal.String()) - enc.StringKey("trigger", "handshake_timeout") - case errors.As(e.e, &idleTimeoutErr): - enc.StringKey("owner", ownerLocal.String()) - enc.StringKey("trigger", "idle_timeout") - case errors.As(e.e, &applicationErr): - owner := ownerLocal - if applicationErr.Remote { - owner = ownerRemote - } - enc.StringKey("owner", owner.String()) - enc.Uint64Key("application_code", uint64(applicationErr.ErrorCode)) - enc.StringKey("reason", applicationErr.ErrorMessage) - case errors.As(e.e, &transportErr): - owner := ownerLocal - if transportErr.Remote { - owner = ownerRemote - } - enc.StringKey("owner", owner.String()) - enc.StringKey("connection_code", transportError(transportErr.ErrorCode).String()) - enc.StringKey("reason", transportErr.ErrorMessage) - case errors.As(e.e, &versionNegotiationErr): - enc.StringKey("owner", ownerRemote.String()) - enc.StringKey("trigger", "version_negotiation") - } -} - -type eventPacketSent struct { - Header packetHeader - Length logging.ByteCount - PayloadLength logging.ByteCount - Frames frames - IsCoalesced bool - Trigger string -} - -var _ eventDetails = eventPacketSent{} - -func (e eventPacketSent) Category() category { return categoryTransport } -func (e eventPacketSent) Name() string { return "packet_sent" } -func (e eventPacketSent) IsNil() bool { return false } - -func (e eventPacketSent) MarshalJSONObject(enc *gojay.Encoder) { - enc.ObjectKey("header", e.Header) - enc.ObjectKey("raw", rawInfo{Length: e.Length, PayloadLength: e.PayloadLength}) - enc.ArrayKeyOmitEmpty("frames", e.Frames) - enc.BoolKeyOmitEmpty("is_coalesced", e.IsCoalesced) - enc.StringKeyOmitEmpty("trigger", e.Trigger) -} - -type eventPacketReceived struct { - Header packetHeader - Length logging.ByteCount - PayloadLength logging.ByteCount - Frames frames - IsCoalesced bool - Trigger string -} - -var _ eventDetails = eventPacketReceived{} - -func (e eventPacketReceived) Category() category { return categoryTransport } -func (e eventPacketReceived) Name() string { return "packet_received" } -func (e eventPacketReceived) IsNil() bool { return false } - -func (e eventPacketReceived) MarshalJSONObject(enc *gojay.Encoder) { - enc.ObjectKey("header", e.Header) - enc.ObjectKey("raw", rawInfo{Length: e.Length, PayloadLength: e.PayloadLength}) - enc.ArrayKeyOmitEmpty("frames", e.Frames) - enc.BoolKeyOmitEmpty("is_coalesced", e.IsCoalesced) - enc.StringKeyOmitEmpty("trigger", e.Trigger) -} - -type eventRetryReceived struct { - Header packetHeader -} - -func (e eventRetryReceived) Category() category { return categoryTransport } -func (e eventRetryReceived) Name() string { return "packet_received" } -func (e eventRetryReceived) IsNil() bool { return false } - -func (e eventRetryReceived) MarshalJSONObject(enc *gojay.Encoder) { - enc.ObjectKey("header", e.Header) -} - -type eventVersionNegotiationReceived struct { - Header packetHeader - SupportedVersions []versionNumber -} - -func (e eventVersionNegotiationReceived) Category() category { return categoryTransport } -func (e eventVersionNegotiationReceived) Name() string { return "packet_received" } -func (e eventVersionNegotiationReceived) IsNil() bool { return false } - -func (e eventVersionNegotiationReceived) MarshalJSONObject(enc *gojay.Encoder) { - enc.ObjectKey("header", e.Header) - enc.ArrayKey("supported_versions", versions(e.SupportedVersions)) -} - -type eventPacketBuffered struct { - PacketType logging.PacketType -} - -func (e eventPacketBuffered) Category() category { return categoryTransport } -func (e eventPacketBuffered) Name() string { return "packet_buffered" } -func (e eventPacketBuffered) IsNil() bool { return false } - -func (e eventPacketBuffered) MarshalJSONObject(enc *gojay.Encoder) { - //nolint:gosimple - enc.ObjectKey("header", packetHeaderWithType{PacketType: e.PacketType}) - enc.StringKey("trigger", "keys_unavailable") -} - -type eventPacketDropped struct { - PacketType logging.PacketType - PacketSize protocol.ByteCount - Trigger packetDropReason -} - -func (e eventPacketDropped) Category() category { return categoryTransport } -func (e eventPacketDropped) Name() string { return "packet_dropped" } -func (e eventPacketDropped) IsNil() bool { return false } - -func (e eventPacketDropped) MarshalJSONObject(enc *gojay.Encoder) { - enc.ObjectKey("header", packetHeaderWithType{PacketType: e.PacketType}) - enc.ObjectKey("raw", rawInfo{Length: e.PacketSize}) - enc.StringKey("trigger", e.Trigger.String()) -} - -type metrics struct { - MinRTT time.Duration - SmoothedRTT time.Duration - LatestRTT time.Duration - RTTVariance time.Duration - - CongestionWindow protocol.ByteCount - BytesInFlight protocol.ByteCount - PacketsInFlight int -} - -type eventMetricsUpdated struct { - Last *metrics - Current *metrics -} - -func (e eventMetricsUpdated) Category() category { return categoryRecovery } -func (e eventMetricsUpdated) Name() string { return "metrics_updated" } -func (e eventMetricsUpdated) IsNil() bool { return false } - -func (e eventMetricsUpdated) MarshalJSONObject(enc *gojay.Encoder) { - if e.Last == nil || e.Last.MinRTT != e.Current.MinRTT { - enc.FloatKey("min_rtt", milliseconds(e.Current.MinRTT)) - } - if e.Last == nil || e.Last.SmoothedRTT != e.Current.SmoothedRTT { - enc.FloatKey("smoothed_rtt", milliseconds(e.Current.SmoothedRTT)) - } - if e.Last == nil || e.Last.LatestRTT != e.Current.LatestRTT { - enc.FloatKey("latest_rtt", milliseconds(e.Current.LatestRTT)) - } - if e.Last == nil || e.Last.RTTVariance != e.Current.RTTVariance { - enc.FloatKey("rtt_variance", milliseconds(e.Current.RTTVariance)) - } - - if e.Last == nil || e.Last.CongestionWindow != e.Current.CongestionWindow { - enc.Uint64Key("congestion_window", uint64(e.Current.CongestionWindow)) - } - if e.Last == nil || e.Last.BytesInFlight != e.Current.BytesInFlight { - enc.Uint64Key("bytes_in_flight", uint64(e.Current.BytesInFlight)) - } - if e.Last == nil || e.Last.PacketsInFlight != e.Current.PacketsInFlight { - enc.Uint64KeyOmitEmpty("packets_in_flight", uint64(e.Current.PacketsInFlight)) - } -} - -type eventUpdatedPTO struct { - Value uint32 -} - -func (e eventUpdatedPTO) Category() category { return categoryRecovery } -func (e eventUpdatedPTO) Name() string { return "metrics_updated" } -func (e eventUpdatedPTO) IsNil() bool { return false } - -func (e eventUpdatedPTO) MarshalJSONObject(enc *gojay.Encoder) { - enc.Uint32Key("pto_count", e.Value) -} - -type eventPacketLost struct { - PacketType logging.PacketType - PacketNumber protocol.PacketNumber - Trigger packetLossReason -} - -func (e eventPacketLost) Category() category { return categoryRecovery } -func (e eventPacketLost) Name() string { return "packet_lost" } -func (e eventPacketLost) IsNil() bool { return false } - -func (e eventPacketLost) MarshalJSONObject(enc *gojay.Encoder) { - enc.ObjectKey("header", packetHeaderWithTypeAndPacketNumber{ - PacketType: e.PacketType, - PacketNumber: e.PacketNumber, - }) - enc.StringKey("trigger", e.Trigger.String()) -} - -type eventKeyUpdated struct { - Trigger keyUpdateTrigger - KeyType keyType - Generation protocol.KeyPhase - // we don't log the keys here, so we don't need `old` and `new`. -} - -func (e eventKeyUpdated) Category() category { return categorySecurity } -func (e eventKeyUpdated) Name() string { return "key_updated" } -func (e eventKeyUpdated) IsNil() bool { return false } - -func (e eventKeyUpdated) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("trigger", e.Trigger.String()) - enc.StringKey("key_type", e.KeyType.String()) - if e.KeyType == keyTypeClient1RTT || e.KeyType == keyTypeServer1RTT { - enc.Uint64Key("generation", uint64(e.Generation)) - } -} - -type eventKeyRetired struct { - KeyType keyType - Generation protocol.KeyPhase -} - -func (e eventKeyRetired) Category() category { return categorySecurity } -func (e eventKeyRetired) Name() string { return "key_retired" } -func (e eventKeyRetired) IsNil() bool { return false } - -func (e eventKeyRetired) MarshalJSONObject(enc *gojay.Encoder) { - if e.KeyType != keyTypeClient1RTT && e.KeyType != keyTypeServer1RTT { - enc.StringKey("trigger", "tls") - } - enc.StringKey("key_type", e.KeyType.String()) - if e.KeyType == keyTypeClient1RTT || e.KeyType == keyTypeServer1RTT { - enc.Uint64Key("generation", uint64(e.Generation)) - } -} - -type eventTransportParameters struct { - Restore bool - Owner owner - SentBy protocol.Perspective - - OriginalDestinationConnectionID protocol.ConnectionID - InitialSourceConnectionID protocol.ConnectionID - RetrySourceConnectionID *protocol.ConnectionID - - StatelessResetToken *protocol.StatelessResetToken - DisableActiveMigration bool - MaxIdleTimeout time.Duration - MaxUDPPayloadSize protocol.ByteCount - AckDelayExponent uint8 - MaxAckDelay time.Duration - ActiveConnectionIDLimit uint64 - - InitialMaxData protocol.ByteCount - InitialMaxStreamDataBidiLocal protocol.ByteCount - InitialMaxStreamDataBidiRemote protocol.ByteCount - InitialMaxStreamDataUni protocol.ByteCount - InitialMaxStreamsBidi int64 - InitialMaxStreamsUni int64 - - PreferredAddress *preferredAddress - - MaxDatagramFrameSize protocol.ByteCount -} - -func (e eventTransportParameters) Category() category { return categoryTransport } -func (e eventTransportParameters) Name() string { - if e.Restore { - return "parameters_restored" - } - return "parameters_set" -} -func (e eventTransportParameters) IsNil() bool { return false } - -func (e eventTransportParameters) MarshalJSONObject(enc *gojay.Encoder) { - if !e.Restore { - enc.StringKey("owner", e.Owner.String()) - if e.SentBy == protocol.PerspectiveServer { - enc.StringKey("original_destination_connection_id", connectionID(e.OriginalDestinationConnectionID).String()) - if e.StatelessResetToken != nil { - enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", e.StatelessResetToken[:])) - } - if e.RetrySourceConnectionID != nil { - enc.StringKey("retry_source_connection_id", connectionID(*e.RetrySourceConnectionID).String()) - } - } - enc.StringKey("initial_source_connection_id", connectionID(e.InitialSourceConnectionID).String()) - } - enc.BoolKey("disable_active_migration", e.DisableActiveMigration) - enc.FloatKeyOmitEmpty("max_idle_timeout", milliseconds(e.MaxIdleTimeout)) - enc.Int64KeyNullEmpty("max_udp_payload_size", int64(e.MaxUDPPayloadSize)) - enc.Uint8KeyOmitEmpty("ack_delay_exponent", e.AckDelayExponent) - enc.FloatKeyOmitEmpty("max_ack_delay", milliseconds(e.MaxAckDelay)) - enc.Uint64KeyOmitEmpty("active_connection_id_limit", e.ActiveConnectionIDLimit) - - enc.Int64KeyOmitEmpty("initial_max_data", int64(e.InitialMaxData)) - enc.Int64KeyOmitEmpty("initial_max_stream_data_bidi_local", int64(e.InitialMaxStreamDataBidiLocal)) - enc.Int64KeyOmitEmpty("initial_max_stream_data_bidi_remote", int64(e.InitialMaxStreamDataBidiRemote)) - enc.Int64KeyOmitEmpty("initial_max_stream_data_uni", int64(e.InitialMaxStreamDataUni)) - enc.Int64KeyOmitEmpty("initial_max_streams_bidi", e.InitialMaxStreamsBidi) - enc.Int64KeyOmitEmpty("initial_max_streams_uni", e.InitialMaxStreamsUni) - - if e.PreferredAddress != nil { - enc.ObjectKey("preferred_address", e.PreferredAddress) - } - if e.MaxDatagramFrameSize != protocol.InvalidByteCount { - enc.Int64Key("max_datagram_frame_size", int64(e.MaxDatagramFrameSize)) - } -} - -type preferredAddress struct { - IPv4, IPv6 net.IP - PortV4, PortV6 uint16 - ConnectionID protocol.ConnectionID - StatelessResetToken protocol.StatelessResetToken -} - -var _ gojay.MarshalerJSONObject = &preferredAddress{} - -func (a preferredAddress) IsNil() bool { return false } -func (a preferredAddress) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("ip_v4", a.IPv4.String()) - enc.Uint16Key("port_v4", a.PortV4) - enc.StringKey("ip_v6", a.IPv6.String()) - enc.Uint16Key("port_v6", a.PortV6) - enc.StringKey("connection_id", connectionID(a.ConnectionID).String()) - enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", a.StatelessResetToken)) -} - -type eventLossTimerSet struct { - TimerType timerType - EncLevel protocol.EncryptionLevel - Delta time.Duration -} - -func (e eventLossTimerSet) Category() category { return categoryRecovery } -func (e eventLossTimerSet) Name() string { return "loss_timer_updated" } -func (e eventLossTimerSet) IsNil() bool { return false } - -func (e eventLossTimerSet) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("event_type", "set") - enc.StringKey("timer_type", e.TimerType.String()) - enc.StringKey("packet_number_space", encLevelToPacketNumberSpace(e.EncLevel)) - enc.Float64Key("delta", milliseconds(e.Delta)) -} - -type eventLossTimerExpired struct { - TimerType timerType - EncLevel protocol.EncryptionLevel -} - -func (e eventLossTimerExpired) Category() category { return categoryRecovery } -func (e eventLossTimerExpired) Name() string { return "loss_timer_updated" } -func (e eventLossTimerExpired) IsNil() bool { return false } - -func (e eventLossTimerExpired) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("event_type", "expired") - enc.StringKey("timer_type", e.TimerType.String()) - enc.StringKey("packet_number_space", encLevelToPacketNumberSpace(e.EncLevel)) -} - -type eventLossTimerCanceled struct{} - -func (e eventLossTimerCanceled) Category() category { return categoryRecovery } -func (e eventLossTimerCanceled) Name() string { return "loss_timer_updated" } -func (e eventLossTimerCanceled) IsNil() bool { return false } - -func (e eventLossTimerCanceled) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("event_type", "cancelled") -} - -type eventCongestionStateUpdated struct { - state congestionState -} - -func (e eventCongestionStateUpdated) Category() category { return categoryRecovery } -func (e eventCongestionStateUpdated) Name() string { return "congestion_state_updated" } -func (e eventCongestionStateUpdated) IsNil() bool { return false } - -func (e eventCongestionStateUpdated) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("new", e.state.String()) -} - -type eventGeneric struct { - name string - msg string -} - -func (e eventGeneric) Category() category { return categoryTransport } -func (e eventGeneric) Name() string { return e.name } -func (e eventGeneric) IsNil() bool { return false } - -func (e eventGeneric) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("details", e.msg) -} diff --git a/internal/quic-go/qlog/event_test.go b/internal/quic-go/qlog/event_test.go deleted file mode 100644 index 4248e69e..00000000 --- a/internal/quic-go/qlog/event_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package qlog - -import ( - "bytes" - "encoding/json" - "time" - - "github.com/francoispqt/gojay" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type mevent struct{} - -var _ eventDetails = mevent{} - -func (mevent) Category() category { return categoryConnectivity } -func (mevent) Name() string { return "mevent" } -func (mevent) IsNil() bool { return false } -func (mevent) MarshalJSONObject(enc *gojay.Encoder) { enc.StringKey("event", "details") } - -var _ = Describe("Events", func() { - It("marshals the fields before the event details", func() { - buf := &bytes.Buffer{} - enc := gojay.NewEncoder(buf) - Expect(enc.Encode(event{ - RelativeTime: 1337 * time.Microsecond, - eventDetails: mevent{}, - })).To(Succeed()) - - var decoded interface{} - Expect(json.Unmarshal(buf.Bytes(), &decoded)).To(Succeed()) - Expect(decoded).To(HaveLen(3)) - - Expect(decoded).To(HaveKeyWithValue("time", 1.337)) - Expect(decoded).To(HaveKeyWithValue("name", "connectivity:mevent")) - Expect(decoded).To(HaveKey("data")) - data := decoded.(map[string]interface{})["data"].(map[string]interface{}) - Expect(data).To(HaveLen(1)) - Expect(data).To(HaveKeyWithValue("event", "details")) - }) -}) diff --git a/internal/quic-go/qlog/frame.go b/internal/quic-go/qlog/frame.go deleted file mode 100644 index c6e58253..00000000 --- a/internal/quic-go/qlog/frame.go +++ /dev/null @@ -1,227 +0,0 @@ -package qlog - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/wire" - - "github.com/francoispqt/gojay" -) - -type frame struct { - Frame logging.Frame -} - -var _ gojay.MarshalerJSONObject = frame{} - -var _ gojay.MarshalerJSONArray = frames{} - -func (f frame) MarshalJSONObject(enc *gojay.Encoder) { - switch frame := f.Frame.(type) { - case *logging.PingFrame: - marshalPingFrame(enc, frame) - case *logging.AckFrame: - marshalAckFrame(enc, frame) - case *logging.ResetStreamFrame: - marshalResetStreamFrame(enc, frame) - case *logging.StopSendingFrame: - marshalStopSendingFrame(enc, frame) - case *logging.CryptoFrame: - marshalCryptoFrame(enc, frame) - case *logging.NewTokenFrame: - marshalNewTokenFrame(enc, frame) - case *logging.StreamFrame: - marshalStreamFrame(enc, frame) - case *logging.MaxDataFrame: - marshalMaxDataFrame(enc, frame) - case *logging.MaxStreamDataFrame: - marshalMaxStreamDataFrame(enc, frame) - case *logging.MaxStreamsFrame: - marshalMaxStreamsFrame(enc, frame) - case *logging.DataBlockedFrame: - marshalDataBlockedFrame(enc, frame) - case *logging.StreamDataBlockedFrame: - marshalStreamDataBlockedFrame(enc, frame) - case *logging.StreamsBlockedFrame: - marshalStreamsBlockedFrame(enc, frame) - case *logging.NewConnectionIDFrame: - marshalNewConnectionIDFrame(enc, frame) - case *logging.RetireConnectionIDFrame: - marshalRetireConnectionIDFrame(enc, frame) - case *logging.PathChallengeFrame: - marshalPathChallengeFrame(enc, frame) - case *logging.PathResponseFrame: - marshalPathResponseFrame(enc, frame) - case *logging.ConnectionCloseFrame: - marshalConnectionCloseFrame(enc, frame) - case *logging.HandshakeDoneFrame: - marshalHandshakeDoneFrame(enc, frame) - case *logging.DatagramFrame: - marshalDatagramFrame(enc, frame) - default: - panic("unknown frame type") - } -} - -func (f frame) IsNil() bool { return false } - -type frames []frame - -func (fs frames) IsNil() bool { return fs == nil } -func (fs frames) MarshalJSONArray(enc *gojay.Encoder) { - for _, f := range fs { - enc.Object(f) - } -} - -func marshalPingFrame(enc *gojay.Encoder, _ *wire.PingFrame) { - enc.StringKey("frame_type", "ping") -} - -type ackRanges []wire.AckRange - -func (ars ackRanges) MarshalJSONArray(enc *gojay.Encoder) { - for _, r := range ars { - enc.Array(ackRange(r)) - } -} - -func (ars ackRanges) IsNil() bool { return false } - -type ackRange wire.AckRange - -func (ar ackRange) MarshalJSONArray(enc *gojay.Encoder) { - enc.AddInt64(int64(ar.Smallest)) - if ar.Smallest != ar.Largest { - enc.AddInt64(int64(ar.Largest)) - } -} - -func (ar ackRange) IsNil() bool { return false } - -func marshalAckFrame(enc *gojay.Encoder, f *logging.AckFrame) { - enc.StringKey("frame_type", "ack") - enc.FloatKeyOmitEmpty("ack_delay", milliseconds(f.DelayTime)) - enc.ArrayKey("acked_ranges", ackRanges(f.AckRanges)) - if hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0; hasECN { - enc.Uint64Key("ect0", f.ECT0) - enc.Uint64Key("ect1", f.ECT1) - enc.Uint64Key("ce", f.ECNCE) - } -} - -func marshalResetStreamFrame(enc *gojay.Encoder, f *logging.ResetStreamFrame) { - enc.StringKey("frame_type", "reset_stream") - enc.Int64Key("stream_id", int64(f.StreamID)) - enc.Int64Key("error_code", int64(f.ErrorCode)) - enc.Int64Key("final_size", int64(f.FinalSize)) -} - -func marshalStopSendingFrame(enc *gojay.Encoder, f *logging.StopSendingFrame) { - enc.StringKey("frame_type", "stop_sending") - enc.Int64Key("stream_id", int64(f.StreamID)) - enc.Int64Key("error_code", int64(f.ErrorCode)) -} - -func marshalCryptoFrame(enc *gojay.Encoder, f *logging.CryptoFrame) { - enc.StringKey("frame_type", "crypto") - enc.Int64Key("offset", int64(f.Offset)) - enc.Int64Key("length", int64(f.Length)) -} - -func marshalNewTokenFrame(enc *gojay.Encoder, f *logging.NewTokenFrame) { - enc.StringKey("frame_type", "new_token") - enc.ObjectKey("token", &token{Raw: f.Token}) -} - -func marshalStreamFrame(enc *gojay.Encoder, f *logging.StreamFrame) { - enc.StringKey("frame_type", "stream") - enc.Int64Key("stream_id", int64(f.StreamID)) - enc.Int64Key("offset", int64(f.Offset)) - enc.IntKey("length", int(f.Length)) - enc.BoolKeyOmitEmpty("fin", f.Fin) -} - -func marshalMaxDataFrame(enc *gojay.Encoder, f *logging.MaxDataFrame) { - enc.StringKey("frame_type", "max_data") - enc.Int64Key("maximum", int64(f.MaximumData)) -} - -func marshalMaxStreamDataFrame(enc *gojay.Encoder, f *logging.MaxStreamDataFrame) { - enc.StringKey("frame_type", "max_stream_data") - enc.Int64Key("stream_id", int64(f.StreamID)) - enc.Int64Key("maximum", int64(f.MaximumStreamData)) -} - -func marshalMaxStreamsFrame(enc *gojay.Encoder, f *logging.MaxStreamsFrame) { - enc.StringKey("frame_type", "max_streams") - enc.StringKey("stream_type", streamType(f.Type).String()) - enc.Int64Key("maximum", int64(f.MaxStreamNum)) -} - -func marshalDataBlockedFrame(enc *gojay.Encoder, f *logging.DataBlockedFrame) { - enc.StringKey("frame_type", "data_blocked") - enc.Int64Key("limit", int64(f.MaximumData)) -} - -func marshalStreamDataBlockedFrame(enc *gojay.Encoder, f *logging.StreamDataBlockedFrame) { - enc.StringKey("frame_type", "stream_data_blocked") - enc.Int64Key("stream_id", int64(f.StreamID)) - enc.Int64Key("limit", int64(f.MaximumStreamData)) -} - -func marshalStreamsBlockedFrame(enc *gojay.Encoder, f *logging.StreamsBlockedFrame) { - enc.StringKey("frame_type", "streams_blocked") - enc.StringKey("stream_type", streamType(f.Type).String()) - enc.Int64Key("limit", int64(f.StreamLimit)) -} - -func marshalNewConnectionIDFrame(enc *gojay.Encoder, f *logging.NewConnectionIDFrame) { - enc.StringKey("frame_type", "new_connection_id") - enc.Int64Key("sequence_number", int64(f.SequenceNumber)) - enc.Int64Key("retire_prior_to", int64(f.RetirePriorTo)) - enc.IntKey("length", f.ConnectionID.Len()) - enc.StringKey("connection_id", connectionID(f.ConnectionID).String()) - enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", f.StatelessResetToken)) -} - -func marshalRetireConnectionIDFrame(enc *gojay.Encoder, f *logging.RetireConnectionIDFrame) { - enc.StringKey("frame_type", "retire_connection_id") - enc.Int64Key("sequence_number", int64(f.SequenceNumber)) -} - -func marshalPathChallengeFrame(enc *gojay.Encoder, f *logging.PathChallengeFrame) { - enc.StringKey("frame_type", "path_challenge") - enc.StringKey("data", fmt.Sprintf("%x", f.Data[:])) -} - -func marshalPathResponseFrame(enc *gojay.Encoder, f *logging.PathResponseFrame) { - enc.StringKey("frame_type", "path_response") - enc.StringKey("data", fmt.Sprintf("%x", f.Data[:])) -} - -func marshalConnectionCloseFrame(enc *gojay.Encoder, f *logging.ConnectionCloseFrame) { - errorSpace := "transport" - if f.IsApplicationError { - errorSpace = "application" - } - enc.StringKey("frame_type", "connection_close") - enc.StringKey("error_space", errorSpace) - if errName := transportError(f.ErrorCode).String(); len(errName) > 0 { - enc.StringKey("error_code", errName) - } else { - enc.Uint64Key("error_code", f.ErrorCode) - } - enc.Uint64Key("raw_error_code", f.ErrorCode) - enc.StringKey("reason", f.ReasonPhrase) -} - -func marshalHandshakeDoneFrame(enc *gojay.Encoder, _ *logging.HandshakeDoneFrame) { - enc.StringKey("frame_type", "handshake_done") -} - -func marshalDatagramFrame(enc *gojay.Encoder, f *logging.DatagramFrame) { - enc.StringKey("frame_type", "datagram") - enc.Int64Key("length", int64(f.Length)) -} diff --git a/internal/quic-go/qlog/frame_test.go b/internal/quic-go/qlog/frame_test.go deleted file mode 100644 index 5e78d8ef..00000000 --- a/internal/quic-go/qlog/frame_test.go +++ /dev/null @@ -1,377 +0,0 @@ -package qlog - -import ( - "bytes" - "encoding/json" - "time" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - - "github.com/francoispqt/gojay" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Frames", func() { - check := func(f logging.Frame, expected map[string]interface{}) { - buf := &bytes.Buffer{} - enc := gojay.NewEncoder(buf) - ExpectWithOffset(1, enc.Encode(frame{Frame: f})).To(Succeed()) - data := buf.Bytes() - ExpectWithOffset(1, json.Valid(data)).To(BeTrue()) - checkEncoding(data, expected) - } - - It("marshals PING frames", func() { - check( - &logging.PingFrame{}, - map[string]interface{}{ - "frame_type": "ping", - }, - ) - }) - - It("marshals ACK frames with a range acknowledging a single packet", func() { - check( - &logging.AckFrame{ - DelayTime: 86 * time.Millisecond, - AckRanges: []logging.AckRange{{Smallest: 120, Largest: 120}}, - }, - map[string]interface{}{ - "frame_type": "ack", - "ack_delay": 86, - "acked_ranges": [][]float64{{120}}, - }, - ) - }) - - It("marshals ACK frames without a delay", func() { - check( - &logging.AckFrame{ - AckRanges: []logging.AckRange{{Smallest: 120, Largest: 120}}, - }, - map[string]interface{}{ - "frame_type": "ack", - "acked_ranges": [][]float64{{120}}, - }, - ) - }) - - It("marshals ACK frames with ECN counts", func() { - check( - &logging.AckFrame{ - AckRanges: []logging.AckRange{{Smallest: 120, Largest: 120}}, - ECT0: 10, - ECT1: 100, - ECNCE: 1000, - }, - map[string]interface{}{ - "frame_type": "ack", - "acked_ranges": [][]float64{{120}}, - "ect0": 10, - "ect1": 100, - "ce": 1000, - }, - ) - }) - - It("marshals ACK frames with a range acknowledging ranges of packets", func() { - check( - &logging.AckFrame{ - DelayTime: 86 * time.Millisecond, - AckRanges: []logging.AckRange{ - {Smallest: 5, Largest: 50}, - {Smallest: 100, Largest: 120}, - }, - }, - map[string]interface{}{ - "frame_type": "ack", - "ack_delay": 86, - "acked_ranges": [][]float64{ - {5, 50}, - {100, 120}, - }, - }, - ) - }) - - It("marshals RESET_STREAM frames", func() { - check( - &logging.ResetStreamFrame{ - StreamID: 987, - FinalSize: 1234, - ErrorCode: 42, - }, - map[string]interface{}{ - "frame_type": "reset_stream", - "stream_id": 987, - "error_code": 42, - "final_size": 1234, - }, - ) - }) - - It("marshals STOP_SENDING frames", func() { - check( - &logging.StopSendingFrame{ - StreamID: 987, - ErrorCode: 42, - }, - map[string]interface{}{ - "frame_type": "stop_sending", - "stream_id": 987, - "error_code": 42, - }, - ) - }) - - It("marshals CRYPTO frames", func() { - check( - &logging.CryptoFrame{ - Offset: 1337, - Length: 6, - }, - map[string]interface{}{ - "frame_type": "crypto", - "offset": 1337, - "length": 6, - }, - ) - }) - - It("marshals NEW_TOKEN frames", func() { - check( - &logging.NewTokenFrame{ - Token: []byte{0xde, 0xad, 0xbe, 0xef}, - }, - map[string]interface{}{ - "frame_type": "new_token", - "token": map[string]interface{}{"data": "deadbeef"}, - }, - ) - }) - - It("marshals STREAM frames with FIN", func() { - check( - &logging.StreamFrame{ - StreamID: 42, - Offset: 1337, - Fin: true, - Length: 9876, - }, - map[string]interface{}{ - "frame_type": "stream", - "stream_id": 42, - "offset": 1337, - "fin": true, - "length": 9876, - }, - ) - }) - - It("marshals STREAM frames without FIN", func() { - check( - &logging.StreamFrame{ - StreamID: 42, - Offset: 1337, - Length: 3, - }, - map[string]interface{}{ - "frame_type": "stream", - "stream_id": 42, - "offset": 1337, - "length": 3, - }, - ) - }) - - It("marshals MAX_DATA frames", func() { - check( - &logging.MaxDataFrame{ - MaximumData: 1337, - }, - map[string]interface{}{ - "frame_type": "max_data", - "maximum": 1337, - }, - ) - }) - - It("marshals MAX_STREAM_DATA frames", func() { - check( - &logging.MaxStreamDataFrame{ - StreamID: 1234, - MaximumStreamData: 1337, - }, - map[string]interface{}{ - "frame_type": "max_stream_data", - "stream_id": 1234, - "maximum": 1337, - }, - ) - }) - - It("marshals MAX_STREAMS frames", func() { - check( - &logging.MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: 42, - }, - map[string]interface{}{ - "frame_type": "max_streams", - "stream_type": "bidirectional", - "maximum": 42, - }, - ) - }) - - It("marshals DATA_BLOCKED frames", func() { - check( - &logging.DataBlockedFrame{ - MaximumData: 1337, - }, - map[string]interface{}{ - "frame_type": "data_blocked", - "limit": 1337, - }, - ) - }) - - It("marshals STREAM_DATA_BLOCKED frames", func() { - check( - &logging.StreamDataBlockedFrame{ - StreamID: 42, - MaximumStreamData: 1337, - }, - map[string]interface{}{ - "frame_type": "stream_data_blocked", - "stream_id": 42, - "limit": 1337, - }, - ) - }) - - It("marshals STREAMS_BLOCKED frames", func() { - check( - &logging.StreamsBlockedFrame{ - Type: protocol.StreamTypeUni, - StreamLimit: 123, - }, - map[string]interface{}{ - "frame_type": "streams_blocked", - "stream_type": "unidirectional", - "limit": 123, - }, - ) - }) - - It("marshals NEW_CONNECTION_ID frames", func() { - check( - &logging.NewConnectionIDFrame{ - SequenceNumber: 42, - RetirePriorTo: 24, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, - }, - map[string]interface{}{ - "frame_type": "new_connection_id", - "sequence_number": 42, - "retire_prior_to": 24, - "length": 4, - "connection_id": "deadbeef", - "stateless_reset_token": "000102030405060708090a0b0c0d0e0f", - }, - ) - }) - - It("marshals RETIRE_CONNECTION_ID frames", func() { - check( - &logging.RetireConnectionIDFrame{ - SequenceNumber: 1337, - }, - map[string]interface{}{ - "frame_type": "retire_connection_id", - "sequence_number": 1337, - }, - ) - }) - - It("marshals PATH_CHALLENGE frames", func() { - check( - &logging.PathChallengeFrame{ - Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xc0, 0x01}, - }, - map[string]interface{}{ - "frame_type": "path_challenge", - "data": "deadbeefcafec001", - }, - ) - }) - - It("marshals PATH_RESPONSE frames", func() { - check( - &logging.PathResponseFrame{ - Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xc0, 0x01}, - }, - map[string]interface{}{ - "frame_type": "path_response", - "data": "deadbeefcafec001", - }, - ) - }) - - It("marshals CONNECTION_CLOSE frames, for application error codes", func() { - check( - &logging.ConnectionCloseFrame{ - IsApplicationError: true, - ErrorCode: 1337, - ReasonPhrase: "lorem ipsum", - }, - map[string]interface{}{ - "frame_type": "connection_close", - "error_space": "application", - "error_code": 1337, - "raw_error_code": 1337, - "reason": "lorem ipsum", - }, - ) - }) - - It("marshals CONNECTION_CLOSE frames, for transport error codes", func() { - check( - &logging.ConnectionCloseFrame{ - ErrorCode: uint64(qerr.FlowControlError), - ReasonPhrase: "lorem ipsum", - }, - map[string]interface{}{ - "frame_type": "connection_close", - "error_space": "transport", - "error_code": "flow_control_error", - "raw_error_code": int(qerr.FlowControlError), - "reason": "lorem ipsum", - }, - ) - }) - - It("marshals HANDSHAKE_DONE frames", func() { - check( - &logging.HandshakeDoneFrame{}, - map[string]interface{}{ - "frame_type": "handshake_done", - }, - ) - }) - - It("marshals DATAGRAM frames", func() { - check( - &logging.DatagramFrame{Length: 1337}, - map[string]interface{}{ - "frame_type": "datagram", - "length": 1337, - }, - ) - }) -}) diff --git a/internal/quic-go/qlog/packet_header.go b/internal/quic-go/qlog/packet_header.go deleted file mode 100644 index d029c8d2..00000000 --- a/internal/quic-go/qlog/packet_header.go +++ /dev/null @@ -1,119 +0,0 @@ -package qlog - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - - "github.com/francoispqt/gojay" -) - -func getPacketTypeFromEncryptionLevel(encLevel protocol.EncryptionLevel) logging.PacketType { - switch encLevel { - case protocol.EncryptionInitial: - return logging.PacketTypeInitial - case protocol.EncryptionHandshake: - return logging.PacketTypeHandshake - case protocol.Encryption0RTT: - return logging.PacketType0RTT - case protocol.Encryption1RTT: - return logging.PacketType1RTT - default: - panic("unknown encryption level") - } -} - -type token struct { - Raw []byte -} - -var _ gojay.MarshalerJSONObject = &token{} - -func (t token) IsNil() bool { return false } -func (t token) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("data", fmt.Sprintf("%x", t.Raw)) -} - -// PacketHeader is a QUIC packet header. -type packetHeader struct { - PacketType logging.PacketType - - KeyPhaseBit logging.KeyPhaseBit - PacketNumber logging.PacketNumber - - Version logging.VersionNumber - SrcConnectionID logging.ConnectionID - DestConnectionID logging.ConnectionID - - Token *token -} - -func transformHeader(hdr *wire.Header) *packetHeader { - h := &packetHeader{ - PacketType: logging.PacketTypeFromHeader(hdr), - SrcConnectionID: hdr.SrcConnectionID, - DestConnectionID: hdr.DestConnectionID, - Version: hdr.Version, - } - if len(hdr.Token) > 0 { - h.Token = &token{Raw: hdr.Token} - } - return h -} - -func transformExtendedHeader(hdr *wire.ExtendedHeader) *packetHeader { - h := transformHeader(&hdr.Header) - h.PacketNumber = hdr.PacketNumber - h.KeyPhaseBit = hdr.KeyPhase - return h -} - -func (h packetHeader) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("packet_type", packetType(h.PacketType).String()) - if h.PacketType != logging.PacketTypeRetry && h.PacketType != logging.PacketTypeVersionNegotiation { - enc.Int64Key("packet_number", int64(h.PacketNumber)) - } - if h.Version != 0 { - enc.StringKey("version", versionNumber(h.Version).String()) - } - if h.PacketType != logging.PacketType1RTT { - enc.IntKey("scil", h.SrcConnectionID.Len()) - if h.SrcConnectionID.Len() > 0 { - enc.StringKey("scid", connectionID(h.SrcConnectionID).String()) - } - } - enc.IntKey("dcil", h.DestConnectionID.Len()) - if h.DestConnectionID.Len() > 0 { - enc.StringKey("dcid", connectionID(h.DestConnectionID).String()) - } - if h.KeyPhaseBit == logging.KeyPhaseZero || h.KeyPhaseBit == logging.KeyPhaseOne { - enc.StringKey("key_phase_bit", h.KeyPhaseBit.String()) - } - if h.Token != nil { - enc.ObjectKey("token", h.Token) - } -} - -// a minimal header that only outputs the packet type -type packetHeaderWithType struct { - PacketType logging.PacketType -} - -func (h packetHeaderWithType) IsNil() bool { return false } -func (h packetHeaderWithType) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("packet_type", packetType(h.PacketType).String()) -} - -// a minimal header that only outputs the packet type -type packetHeaderWithTypeAndPacketNumber struct { - PacketType logging.PacketType - PacketNumber logging.PacketNumber -} - -func (h packetHeaderWithTypeAndPacketNumber) IsNil() bool { return false } -func (h packetHeaderWithTypeAndPacketNumber) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("packet_type", packetType(h.PacketType).String()) - enc.Int64Key("packet_number", int64(h.PacketNumber)) -} diff --git a/internal/quic-go/qlog/packet_header_test.go b/internal/quic-go/qlog/packet_header_test.go deleted file mode 100644 index 91dfeb39..00000000 --- a/internal/quic-go/qlog/packet_header_test.go +++ /dev/null @@ -1,175 +0,0 @@ -package qlog - -import ( - "bytes" - "encoding/json" - - "github.com/francoispqt/gojay" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Packet Header", func() { - It("determines the packet type from the encryption level", func() { - Expect(getPacketTypeFromEncryptionLevel(protocol.EncryptionInitial)).To(BeEquivalentTo(logging.PacketTypeInitial)) - Expect(getPacketTypeFromEncryptionLevel(protocol.EncryptionHandshake)).To(BeEquivalentTo(logging.PacketTypeHandshake)) - Expect(getPacketTypeFromEncryptionLevel(protocol.Encryption0RTT)).To(BeEquivalentTo(logging.PacketType0RTT)) - Expect(getPacketTypeFromEncryptionLevel(protocol.Encryption1RTT)).To(BeEquivalentTo(logging.PacketType1RTT)) - }) - - Context("marshalling", func() { - check := func(hdr *wire.ExtendedHeader, expected map[string]interface{}) { - buf := &bytes.Buffer{} - enc := gojay.NewEncoder(buf) - ExpectWithOffset(1, enc.Encode(transformExtendedHeader(hdr))).To(Succeed()) - data := buf.Bytes() - ExpectWithOffset(1, json.Valid(data)).To(BeTrue()) - checkEncoding(data, expected) - } - - It("marshals a header for a 1-RTT packet", func() { - check( - &wire.ExtendedHeader{ - PacketNumber: 42, - KeyPhase: protocol.KeyPhaseZero, - }, - map[string]interface{}{ - "packet_type": "1RTT", - "packet_number": 42, - "dcil": 0, - "key_phase_bit": "0", - }, - ) - }) - - It("marshals a header with a payload length", func() { - check( - &wire.ExtendedHeader{ - PacketNumber: 42, - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Length: 123, - Version: protocol.VersionNumber(0xdecafbad), - }, - }, - map[string]interface{}{ - "packet_type": "initial", - "packet_number": 42, - "dcil": 0, - "scil": 0, - "version": "decafbad", - }, - ) - }) - - It("marshals an Initial with a token", func() { - check( - &wire.ExtendedHeader{ - PacketNumber: 4242, - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Length: 123, - Version: protocol.VersionNumber(0xdecafbad), - Token: []byte{0xde, 0xad, 0xbe, 0xef}, - }, - }, - map[string]interface{}{ - "packet_type": "initial", - "packet_number": 4242, - "dcil": 0, - "scil": 0, - "version": "decafbad", - "token": map[string]interface{}{"data": "deadbeef"}, - }, - ) - }) - - It("marshals a Retry packet", func() { - check( - &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - SrcConnectionID: protocol.ConnectionID{0x11, 0x22, 0x33, 0x44}, - Version: protocol.VersionNumber(0xdecafbad), - Token: []byte{0xde, 0xad, 0xbe, 0xef}, - }, - }, - map[string]interface{}{ - "packet_type": "retry", - "dcil": 0, - "scil": 4, - "scid": "11223344", - "token": map[string]interface{}{"data": "deadbeef"}, - "version": "decafbad", - }, - ) - }) - - It("marshals a packet with packet number 0", func() { - check( - &wire.ExtendedHeader{ - PacketNumber: 0, - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Version: protocol.VersionNumber(0xdecafbad), - }, - }, - map[string]interface{}{ - "packet_type": "handshake", - "packet_number": 0, - "dcil": 0, - "scil": 0, - "version": "decafbad", - }, - ) - }) - - It("marshals a header with a source connection ID", func() { - check( - &wire.ExtendedHeader{ - PacketNumber: 42, - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - SrcConnectionID: protocol.ConnectionID{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, - Version: protocol.VersionNumber(0xdecafbad), - }, - }, - map[string]interface{}{ - "packet_type": "handshake", - "packet_number": 42, - "dcil": 0, - "scil": 16, - "scid": "00112233445566778899aabbccddeeff", - "version": "decafbad", - }, - ) - }) - - It("marshals a 1-RTT header with a destination connection ID", func() { - check( - &wire.ExtendedHeader{ - PacketNumber: 42, - Header: wire.Header{DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}}, - KeyPhase: protocol.KeyPhaseOne, - }, - map[string]interface{}{ - "packet_type": "1RTT", - "packet_number": 42, - "dcil": 4, - "dcid": "deadbeef", - "key_phase_bit": "1", - }, - ) - }) - }) -}) diff --git a/internal/quic-go/qlog/qlog.go b/internal/quic-go/qlog/qlog.go deleted file mode 100644 index feaa296e..00000000 --- a/internal/quic-go/qlog/qlog.go +++ /dev/null @@ -1,486 +0,0 @@ -package qlog - -import ( - "bytes" - "context" - "fmt" - "io" - "log" - "net" - "runtime/debug" - "sync" - "time" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" - - "github.com/francoispqt/gojay" -) - -// Setting of this only works when quic-go is used as a library. -// When building a binary from this repository, the version can be set using the following go build flag: -// -ldflags="-X github.com/imroc/req/v3/internal/quic-go/qlog.quicGoVersion=foobar" -var quicGoVersion = "(devel)" - -func init() { - if quicGoVersion != "(devel)" { // variable set by ldflags - return - } - info, ok := debug.ReadBuildInfo() - if !ok { // no build info available. This happens when quic-go is not used as a library. - return - } - for _, d := range info.Deps { - if d.Path == "github.com/imroc/req/v3/internal/quic-go" { - quicGoVersion = d.Version - if d.Replace != nil { - if len(d.Replace.Version) > 0 { - quicGoVersion = d.Version - } else { - quicGoVersion += " (replaced)" - } - } - break - } - } -} - -const eventChanSize = 50 - -type tracer struct { - getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser -} - -var _ logging.Tracer = &tracer{} - -// NewTracer creates a new qlog tracer. -func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser) logging.Tracer { - return &tracer{getLogWriter: getLogWriter} -} - -func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { - if w := t.getLogWriter(p, odcid.Bytes()); w != nil { - return NewConnectionTracer(w, p, odcid) - } - return nil -} - -func (t *tracer) SentPacket(net.Addr, *logging.Header, protocol.ByteCount, []logging.Frame) {} -func (t *tracer) DroppedPacket(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { -} - -type connectionTracer struct { - mutex sync.Mutex - - w io.WriteCloser - odcid protocol.ConnectionID - perspective protocol.Perspective - referenceTime time.Time - - events chan event - encodeErr error - runStopped chan struct{} - - lastMetrics *metrics -} - -var _ logging.ConnectionTracer = &connectionTracer{} - -// NewConnectionTracer creates a new tracer to record a qlog for a connection. -func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { - t := &connectionTracer{ - w: w, - perspective: p, - odcid: odcid, - runStopped: make(chan struct{}), - events: make(chan event, eventChanSize), - referenceTime: time.Now(), - } - go t.run() - return t -} - -func (t *connectionTracer) run() { - defer close(t.runStopped) - buf := &bytes.Buffer{} - enc := gojay.NewEncoder(buf) - tl := &topLevel{ - trace: trace{ - VantagePoint: vantagePoint{Type: t.perspective}, - CommonFields: commonFields{ - ODCID: connectionID(t.odcid), - GroupID: connectionID(t.odcid), - ReferenceTime: t.referenceTime, - }, - }, - } - if err := enc.Encode(tl); err != nil { - panic(fmt.Sprintf("qlog encoding into a bytes.Buffer failed: %s", err)) - } - if err := buf.WriteByte('\n'); err != nil { - panic(fmt.Sprintf("qlog encoding into a bytes.Buffer failed: %s", err)) - } - if _, err := t.w.Write(buf.Bytes()); err != nil { - t.encodeErr = err - } - enc = gojay.NewEncoder(t.w) - for ev := range t.events { - if t.encodeErr != nil { // if encoding failed, just continue draining the event channel - continue - } - if err := enc.Encode(ev); err != nil { - t.encodeErr = err - continue - } - if _, err := t.w.Write([]byte{'\n'}); err != nil { - t.encodeErr = err - } - } -} - -func (t *connectionTracer) Close() { - if err := t.export(); err != nil { - log.Printf("exporting qlog failed: %s\n", err) - } -} - -// export writes a qlog. -func (t *connectionTracer) export() error { - close(t.events) - <-t.runStopped - if t.encodeErr != nil { - return t.encodeErr - } - return t.w.Close() -} - -func (t *connectionTracer) recordEvent(eventTime time.Time, details eventDetails) { - t.events <- event{ - RelativeTime: eventTime.Sub(t.referenceTime), - eventDetails: details, - } -} - -func (t *connectionTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID protocol.ConnectionID) { - // ignore this event if we're not dealing with UDP addresses here - localAddr, ok := local.(*net.UDPAddr) - if !ok { - return - } - remoteAddr, ok := remote.(*net.UDPAddr) - if !ok { - return - } - t.mutex.Lock() - t.recordEvent(time.Now(), &eventConnectionStarted{ - SrcAddr: localAddr, - DestAddr: remoteAddr, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) NegotiatedVersion(chosen logging.VersionNumber, client, server []logging.VersionNumber) { - var clientVersions, serverVersions []versionNumber - if len(client) > 0 { - clientVersions = make([]versionNumber, len(client)) - for i, v := range client { - clientVersions[i] = versionNumber(v) - } - } - if len(server) > 0 { - serverVersions = make([]versionNumber, len(server)) - for i, v := range server { - serverVersions[i] = versionNumber(v) - } - } - t.mutex.Lock() - t.recordEvent(time.Now(), &eventVersionNegotiated{ - clientVersions: clientVersions, - serverVersions: serverVersions, - chosenVersion: versionNumber(chosen), - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) ClosedConnection(e error) { - t.mutex.Lock() - t.recordEvent(time.Now(), &eventConnectionClosed{e: e}) - t.mutex.Unlock() -} - -func (t *connectionTracer) SentTransportParameters(tp *wire.TransportParameters) { - t.recordTransportParameters(t.perspective, tp) -} - -func (t *connectionTracer) ReceivedTransportParameters(tp *wire.TransportParameters) { - t.recordTransportParameters(t.perspective.Opposite(), tp) -} - -func (t *connectionTracer) RestoredTransportParameters(tp *wire.TransportParameters) { - ev := t.toTransportParameters(tp) - ev.Restore = true - - t.mutex.Lock() - t.recordEvent(time.Now(), ev) - t.mutex.Unlock() -} - -func (t *connectionTracer) recordTransportParameters(sentBy protocol.Perspective, tp *wire.TransportParameters) { - ev := t.toTransportParameters(tp) - ev.Owner = ownerLocal - if sentBy != t.perspective { - ev.Owner = ownerRemote - } - ev.SentBy = sentBy - - t.mutex.Lock() - t.recordEvent(time.Now(), ev) - t.mutex.Unlock() -} - -func (t *connectionTracer) toTransportParameters(tp *wire.TransportParameters) *eventTransportParameters { - var pa *preferredAddress - if tp.PreferredAddress != nil { - pa = &preferredAddress{ - IPv4: tp.PreferredAddress.IPv4, - PortV4: tp.PreferredAddress.IPv4Port, - IPv6: tp.PreferredAddress.IPv6, - PortV6: tp.PreferredAddress.IPv6Port, - ConnectionID: tp.PreferredAddress.ConnectionID, - StatelessResetToken: tp.PreferredAddress.StatelessResetToken, - } - } - return &eventTransportParameters{ - OriginalDestinationConnectionID: tp.OriginalDestinationConnectionID, - InitialSourceConnectionID: tp.InitialSourceConnectionID, - RetrySourceConnectionID: tp.RetrySourceConnectionID, - StatelessResetToken: tp.StatelessResetToken, - DisableActiveMigration: tp.DisableActiveMigration, - MaxIdleTimeout: tp.MaxIdleTimeout, - MaxUDPPayloadSize: tp.MaxUDPPayloadSize, - AckDelayExponent: tp.AckDelayExponent, - MaxAckDelay: tp.MaxAckDelay, - ActiveConnectionIDLimit: tp.ActiveConnectionIDLimit, - InitialMaxData: tp.InitialMaxData, - InitialMaxStreamDataBidiLocal: tp.InitialMaxStreamDataBidiLocal, - InitialMaxStreamDataBidiRemote: tp.InitialMaxStreamDataBidiRemote, - InitialMaxStreamDataUni: tp.InitialMaxStreamDataUni, - InitialMaxStreamsBidi: int64(tp.MaxBidiStreamNum), - InitialMaxStreamsUni: int64(tp.MaxUniStreamNum), - PreferredAddress: pa, - MaxDatagramFrameSize: tp.MaxDatagramFrameSize, - } -} - -func (t *connectionTracer) SentPacket(hdr *wire.ExtendedHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) { - numFrames := len(frames) - if ack != nil { - numFrames++ - } - fs := make([]frame, 0, numFrames) - if ack != nil { - fs = append(fs, frame{Frame: ack}) - } - for _, f := range frames { - fs = append(fs, frame{Frame: f}) - } - header := *transformExtendedHeader(hdr) - t.mutex.Lock() - t.recordEvent(time.Now(), &eventPacketSent{ - Header: header, - Length: packetSize, - PayloadLength: hdr.Length, - Frames: fs, - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) ReceivedPacket(hdr *wire.ExtendedHeader, packetSize logging.ByteCount, frames []logging.Frame) { - fs := make([]frame, len(frames)) - for i, f := range frames { - fs[i] = frame{Frame: f} - } - header := *transformExtendedHeader(hdr) - t.mutex.Lock() - t.recordEvent(time.Now(), &eventPacketReceived{ - Header: header, - Length: packetSize, - PayloadLength: hdr.Length, - Frames: fs, - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) ReceivedRetry(hdr *wire.Header) { - t.mutex.Lock() - t.recordEvent(time.Now(), &eventRetryReceived{ - Header: *transformHeader(hdr), - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) ReceivedVersionNegotiationPacket(hdr *wire.Header, versions []logging.VersionNumber) { - ver := make([]versionNumber, len(versions)) - for i, v := range versions { - ver[i] = versionNumber(v) - } - t.mutex.Lock() - t.recordEvent(time.Now(), &eventVersionNegotiationReceived{ - Header: *transformHeader(hdr), - SupportedVersions: ver, - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) BufferedPacket(pt logging.PacketType) { - t.mutex.Lock() - t.recordEvent(time.Now(), &eventPacketBuffered{PacketType: pt}) - t.mutex.Unlock() -} - -func (t *connectionTracer) DroppedPacket(pt logging.PacketType, size protocol.ByteCount, reason logging.PacketDropReason) { - t.mutex.Lock() - t.recordEvent(time.Now(), &eventPacketDropped{ - PacketType: pt, - PacketSize: size, - Trigger: packetDropReason(reason), - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) UpdatedMetrics(rttStats *utils.RTTStats, cwnd, bytesInFlight protocol.ByteCount, packetsInFlight int) { - m := &metrics{ - MinRTT: rttStats.MinRTT(), - SmoothedRTT: rttStats.SmoothedRTT(), - LatestRTT: rttStats.LatestRTT(), - RTTVariance: rttStats.MeanDeviation(), - CongestionWindow: cwnd, - BytesInFlight: bytesInFlight, - PacketsInFlight: packetsInFlight, - } - t.mutex.Lock() - t.recordEvent(time.Now(), &eventMetricsUpdated{ - Last: t.lastMetrics, - Current: m, - }) - t.lastMetrics = m - t.mutex.Unlock() -} - -func (t *connectionTracer) AcknowledgedPacket(protocol.EncryptionLevel, protocol.PacketNumber) {} - -func (t *connectionTracer) LostPacket(encLevel protocol.EncryptionLevel, pn protocol.PacketNumber, lossReason logging.PacketLossReason) { - t.mutex.Lock() - t.recordEvent(time.Now(), &eventPacketLost{ - PacketType: getPacketTypeFromEncryptionLevel(encLevel), - PacketNumber: pn, - Trigger: packetLossReason(lossReason), - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) UpdatedCongestionState(state logging.CongestionState) { - t.mutex.Lock() - t.recordEvent(time.Now(), &eventCongestionStateUpdated{state: congestionState(state)}) - t.mutex.Unlock() -} - -func (t *connectionTracer) UpdatedPTOCount(value uint32) { - t.mutex.Lock() - t.recordEvent(time.Now(), &eventUpdatedPTO{Value: value}) - t.mutex.Unlock() -} - -func (t *connectionTracer) UpdatedKeyFromTLS(encLevel protocol.EncryptionLevel, pers protocol.Perspective) { - t.mutex.Lock() - t.recordEvent(time.Now(), &eventKeyUpdated{ - Trigger: keyUpdateTLS, - KeyType: encLevelToKeyType(encLevel, pers), - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) UpdatedKey(generation protocol.KeyPhase, remote bool) { - trigger := keyUpdateLocal - if remote { - trigger = keyUpdateRemote - } - t.mutex.Lock() - now := time.Now() - t.recordEvent(now, &eventKeyUpdated{ - Trigger: trigger, - KeyType: keyTypeClient1RTT, - Generation: generation, - }) - t.recordEvent(now, &eventKeyUpdated{ - Trigger: trigger, - KeyType: keyTypeServer1RTT, - Generation: generation, - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) DroppedEncryptionLevel(encLevel protocol.EncryptionLevel) { - t.mutex.Lock() - now := time.Now() - if encLevel == protocol.Encryption0RTT { - t.recordEvent(now, &eventKeyRetired{KeyType: encLevelToKeyType(encLevel, t.perspective)}) - } else { - t.recordEvent(now, &eventKeyRetired{KeyType: encLevelToKeyType(encLevel, protocol.PerspectiveServer)}) - t.recordEvent(now, &eventKeyRetired{KeyType: encLevelToKeyType(encLevel, protocol.PerspectiveClient)}) - } - t.mutex.Unlock() -} - -func (t *connectionTracer) DroppedKey(generation protocol.KeyPhase) { - t.mutex.Lock() - now := time.Now() - t.recordEvent(now, &eventKeyRetired{ - KeyType: encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveServer), - Generation: generation, - }) - t.recordEvent(now, &eventKeyRetired{ - KeyType: encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveClient), - Generation: generation, - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) SetLossTimer(tt logging.TimerType, encLevel protocol.EncryptionLevel, timeout time.Time) { - t.mutex.Lock() - now := time.Now() - t.recordEvent(now, &eventLossTimerSet{ - TimerType: timerType(tt), - EncLevel: encLevel, - Delta: timeout.Sub(now), - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) LossTimerExpired(tt logging.TimerType, encLevel protocol.EncryptionLevel) { - t.mutex.Lock() - t.recordEvent(time.Now(), &eventLossTimerExpired{ - TimerType: timerType(tt), - EncLevel: encLevel, - }) - t.mutex.Unlock() -} - -func (t *connectionTracer) LossTimerCanceled() { - t.mutex.Lock() - t.recordEvent(time.Now(), &eventLossTimerCanceled{}) - t.mutex.Unlock() -} - -func (t *connectionTracer) Debug(name, msg string) { - t.mutex.Lock() - t.recordEvent(time.Now(), &eventGeneric{ - name: name, - msg: msg, - }) - t.mutex.Unlock() -} diff --git a/internal/quic-go/qlog/qlog_suite_test.go b/internal/quic-go/qlog/qlog_suite_test.go deleted file mode 100644 index 73f4917e..00000000 --- a/internal/quic-go/qlog/qlog_suite_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package qlog - -import ( - "encoding/json" - "os" - "strconv" - "testing" - "time" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestQlog(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "qlog Suite") -} - -//nolint:unparam -func scaleDuration(t time.Duration) time.Duration { - scaleFactor := 1 - if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set - scaleFactor = f - } - Expect(scaleFactor).ToNot(BeZero()) - return time.Duration(scaleFactor) * t -} - -func checkEncoding(data []byte, expected map[string]interface{}) { - // unmarshal the data - m := make(map[string]interface{}) - ExpectWithOffset(1, json.Unmarshal(data, &m)).To(Succeed()) - ExpectWithOffset(1, m).To(HaveLen(len(expected))) - for key, value := range expected { - switch v := value.(type) { - case bool, string, map[string]interface{}: - ExpectWithOffset(1, m).To(HaveKeyWithValue(key, v)) - case int: - ExpectWithOffset(1, m).To(HaveKeyWithValue(key, float64(v))) - case [][]float64: // used in the ACK frame - ExpectWithOffset(1, m).To(HaveKey(key)) - for i, l := range v { - for j, s := range l { - ExpectWithOffset(1, m[key].([]interface{})[i].([]interface{})[j].(float64)).To(Equal(s)) - } - } - default: - Fail("unexpected type") - } - } -} diff --git a/internal/quic-go/qlog/qlog_test.go b/internal/quic-go/qlog/qlog_test.go deleted file mode 100644 index f5849927..00000000 --- a/internal/quic-go/qlog/qlog_test.go +++ /dev/null @@ -1,849 +0,0 @@ -package qlog - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "io" - "log" - "net" - "os" - "time" - - "github.com/imroc/req/v3/internal/quic-go" - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type nopWriteCloserImpl struct{ io.Writer } - -func (nopWriteCloserImpl) Close() error { return nil } - -func nopWriteCloser(w io.Writer) io.WriteCloser { - return &nopWriteCloserImpl{Writer: w} -} - -type limitedWriter struct { - io.WriteCloser - N int - written int -} - -func (w *limitedWriter) Write(p []byte) (int, error) { - if w.written+len(p) > w.N { - return 0, errors.New("writer full") - } - n, err := w.WriteCloser.Write(p) - w.written += n - return n, err -} - -type entry struct { - Time time.Time - Name string - Event map[string]interface{} -} - -var _ = Describe("Tracing", func() { - Context("tracer", func() { - It("returns nil when there's no io.WriteCloser", func() { - t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil }) - Expect(t.TracerForConnection(context.Background(), logging.PerspectiveClient, logging.ConnectionID{1, 2, 3, 4})).To(BeNil()) - }) - }) - - It("stops writing when encountering an error", func() { - buf := &bytes.Buffer{} - t := NewConnectionTracer( - &limitedWriter{WriteCloser: nopWriteCloser(buf), N: 250}, - protocol.PerspectiveServer, - protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - ) - for i := uint32(0); i < 1000; i++ { - t.UpdatedPTOCount(i) - } - - b := &bytes.Buffer{} - log.SetOutput(b) - defer log.SetOutput(os.Stdout) - t.Close() - Expect(b.String()).To(ContainSubstring("writer full")) - }) - - Context("connection tracer", func() { - var ( - tracer logging.ConnectionTracer - buf *bytes.Buffer - ) - - BeforeEach(func() { - buf = &bytes.Buffer{} - t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) }) - tracer = t.TracerForConnection(context.Background(), logging.PerspectiveServer, logging.ConnectionID{0xde, 0xad, 0xbe, 0xef}) - }) - - It("exports a trace that has the right metadata", func() { - tracer.Close() - - m := make(map[string]interface{}) - Expect(json.Unmarshal(buf.Bytes(), &m)).To(Succeed()) - Expect(m).To(HaveKeyWithValue("qlog_version", "draft-02")) - Expect(m).To(HaveKey("title")) - Expect(m).To(HaveKey("trace")) - trace := m["trace"].(map[string]interface{}) - Expect(trace).To(HaveKey(("common_fields"))) - commonFields := trace["common_fields"].(map[string]interface{}) - Expect(commonFields).To(HaveKeyWithValue("ODCID", "deadbeef")) - Expect(commonFields).To(HaveKeyWithValue("group_id", "deadbeef")) - Expect(commonFields).To(HaveKey("reference_time")) - referenceTime := time.Unix(0, int64(commonFields["reference_time"].(float64)*1e6)) - Expect(referenceTime).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(commonFields).To(HaveKeyWithValue("time_format", "relative")) - Expect(trace).To(HaveKey("vantage_point")) - vantagePoint := trace["vantage_point"].(map[string]interface{}) - Expect(vantagePoint).To(HaveKeyWithValue("type", "server")) - }) - - Context("Events", func() { - exportAndParse := func() []entry { - tracer.Close() - - m := make(map[string]interface{}) - line, err := buf.ReadBytes('\n') - Expect(err).ToNot(HaveOccurred()) - Expect(json.Unmarshal(line, &m)).To(Succeed()) - Expect(m).To(HaveKey("trace")) - var entries []entry - trace := m["trace"].(map[string]interface{}) - Expect(trace).To(HaveKey("common_fields")) - commonFields := trace["common_fields"].(map[string]interface{}) - Expect(commonFields).To(HaveKey("reference_time")) - referenceTime := time.Unix(0, int64(commonFields["reference_time"].(float64)*1e6)) - Expect(trace).ToNot(HaveKey("events")) - - for buf.Len() > 0 { - line, err := buf.ReadBytes('\n') - Expect(err).ToNot(HaveOccurred()) - ev := make(map[string]interface{}) - Expect(json.Unmarshal(line, &ev)).To(Succeed()) - Expect(ev).To(HaveLen(3)) - Expect(ev).To(HaveKey("time")) - Expect(ev).To(HaveKey("name")) - Expect(ev).To(HaveKey("data")) - entries = append(entries, entry{ - Time: referenceTime.Add(time.Duration(ev["time"].(float64)*1e6) * time.Nanosecond), - Name: ev["name"].(string), - Event: ev["data"].(map[string]interface{}), - }) - } - return entries - } - - exportAndParseSingle := func() entry { - entries := exportAndParse() - Expect(entries).To(HaveLen(1)) - return entries[0] - } - - It("records connection starts", func() { - tracer.StartedConnection( - &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 42}, - &net.UDPAddr{IP: net.IPv4(192, 168, 12, 34), Port: 24}, - protocol.ConnectionID{1, 2, 3, 4}, - protocol.ConnectionID{5, 6, 7, 8}, - ) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:connection_started")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("ip_version", "ipv4")) - Expect(ev).To(HaveKeyWithValue("src_ip", "192.168.13.37")) - Expect(ev).To(HaveKeyWithValue("src_port", float64(42))) - Expect(ev).To(HaveKeyWithValue("dst_ip", "192.168.12.34")) - Expect(ev).To(HaveKeyWithValue("dst_port", float64(24))) - Expect(ev).To(HaveKeyWithValue("src_cid", "01020304")) - Expect(ev).To(HaveKeyWithValue("dst_cid", "05060708")) - }) - - It("records the version, if no version negotiation happened", func() { - tracer.NegotiatedVersion(0x1337, nil, nil) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:version_information")) - ev := entry.Event - Expect(ev).To(HaveLen(1)) - Expect(ev).To(HaveKeyWithValue("chosen_version", "1337")) - }) - - It("records the version, if version negotiation happened", func() { - tracer.NegotiatedVersion(0x1337, []logging.VersionNumber{1, 2, 3}, []logging.VersionNumber{4, 5, 6}) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:version_information")) - ev := entry.Event - Expect(ev).To(HaveLen(3)) - Expect(ev).To(HaveKeyWithValue("chosen_version", "1337")) - Expect(ev).To(HaveKey("client_versions")) - Expect(ev["client_versions"].([]interface{})).To(Equal([]interface{}{"1", "2", "3"})) - Expect(ev).To(HaveKey("server_versions")) - Expect(ev["server_versions"].([]interface{})).To(Equal([]interface{}{"4", "5", "6"})) - }) - - It("records idle timeouts", func() { - tracer.ClosedConnection(&quic.IdleTimeoutError{}) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:connection_closed")) - ev := entry.Event - Expect(ev).To(HaveLen(2)) - Expect(ev).To(HaveKeyWithValue("owner", "local")) - Expect(ev).To(HaveKeyWithValue("trigger", "idle_timeout")) - }) - - It("records handshake timeouts", func() { - tracer.ClosedConnection(&quic.HandshakeTimeoutError{}) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:connection_closed")) - ev := entry.Event - Expect(ev).To(HaveLen(2)) - Expect(ev).To(HaveKeyWithValue("owner", "local")) - Expect(ev).To(HaveKeyWithValue("trigger", "handshake_timeout")) - }) - - It("records a received stateless reset packet", func() { - tracer.ClosedConnection(&quic.StatelessResetError{ - Token: protocol.StatelessResetToken{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, - }) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:connection_closed")) - ev := entry.Event - Expect(ev).To(HaveLen(3)) - Expect(ev).To(HaveKeyWithValue("owner", "remote")) - Expect(ev).To(HaveKeyWithValue("trigger", "stateless_reset")) - Expect(ev).To(HaveKeyWithValue("stateless_reset_token", "00112233445566778899aabbccddeeff")) - }) - - It("records connection closing due to version negotiation failure", func() { - tracer.ClosedConnection(&quic.VersionNegotiationError{}) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:connection_closed")) - ev := entry.Event - Expect(ev).To(HaveLen(2)) - Expect(ev).To(HaveKeyWithValue("owner", "remote")) - Expect(ev).To(HaveKeyWithValue("trigger", "version_negotiation")) - }) - - It("records application errors", func() { - tracer.ClosedConnection(&quic.ApplicationError{ - Remote: true, - ErrorCode: 1337, - ErrorMessage: "foobar", - }) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:connection_closed")) - ev := entry.Event - Expect(ev).To(HaveLen(3)) - Expect(ev).To(HaveKeyWithValue("owner", "remote")) - Expect(ev).To(HaveKeyWithValue("application_code", float64(1337))) - Expect(ev).To(HaveKeyWithValue("reason", "foobar")) - }) - - It("records transport errors", func() { - tracer.ClosedConnection(&quic.TransportError{ - ErrorCode: qerr.AEADLimitReached, - ErrorMessage: "foobar", - }) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:connection_closed")) - ev := entry.Event - Expect(ev).To(HaveLen(3)) - Expect(ev).To(HaveKeyWithValue("owner", "local")) - Expect(ev).To(HaveKeyWithValue("connection_code", "aead_limit_reached")) - Expect(ev).To(HaveKeyWithValue("reason", "foobar")) - }) - - It("records sent transport parameters", func() { - tracer.SentTransportParameters(&logging.TransportParameters{ - InitialMaxStreamDataBidiLocal: 1000, - InitialMaxStreamDataBidiRemote: 2000, - InitialMaxStreamDataUni: 3000, - InitialMaxData: 4000, - MaxBidiStreamNum: 10, - MaxUniStreamNum: 20, - MaxAckDelay: 123 * time.Millisecond, - AckDelayExponent: 12, - DisableActiveMigration: true, - MaxUDPPayloadSize: 1234, - MaxIdleTimeout: 321 * time.Millisecond, - StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - ActiveConnectionIDLimit: 7, - MaxDatagramFrameSize: protocol.InvalidByteCount, - }) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:parameters_set")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("owner", "local")) - Expect(ev).To(HaveKeyWithValue("original_destination_connection_id", "deadc0de")) - Expect(ev).To(HaveKeyWithValue("initial_source_connection_id", "deadbeef")) - Expect(ev).To(HaveKeyWithValue("retry_source_connection_id", "decafbad")) - Expect(ev).To(HaveKeyWithValue("stateless_reset_token", "112233445566778899aabbccddeeff00")) - Expect(ev).To(HaveKeyWithValue("max_idle_timeout", float64(321))) - Expect(ev).To(HaveKeyWithValue("max_udp_payload_size", float64(1234))) - Expect(ev).To(HaveKeyWithValue("ack_delay_exponent", float64(12))) - Expect(ev).To(HaveKeyWithValue("active_connection_id_limit", float64(7))) - Expect(ev).To(HaveKeyWithValue("initial_max_data", float64(4000))) - Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_local", float64(1000))) - Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_remote", float64(2000))) - Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_uni", float64(3000))) - Expect(ev).To(HaveKeyWithValue("initial_max_streams_bidi", float64(10))) - Expect(ev).To(HaveKeyWithValue("initial_max_streams_uni", float64(20))) - Expect(ev).ToNot(HaveKey("preferred_address")) - Expect(ev).ToNot(HaveKey("max_datagram_frame_size")) - }) - - It("records the server's transport parameters, without a stateless reset token", func() { - tracer.SentTransportParameters(&logging.TransportParameters{ - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - ActiveConnectionIDLimit: 7, - }) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:parameters_set")) - ev := entry.Event - Expect(ev).ToNot(HaveKey("stateless_reset_token")) - }) - - It("records transport parameters without retry_source_connection_id", func() { - tracer.SentTransportParameters(&logging.TransportParameters{ - StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, - }) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:parameters_set")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("owner", "local")) - Expect(ev).ToNot(HaveKey("retry_source_connection_id")) - }) - - It("records transport parameters with a preferred address", func() { - tracer.SentTransportParameters(&logging.TransportParameters{ - PreferredAddress: &logging.PreferredAddress{ - IPv4: net.IPv4(12, 34, 56, 78), - IPv4Port: 123, - IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - IPv6Port: 456, - ConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, - StatelessResetToken: protocol.StatelessResetToken{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, - }, - }) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:parameters_set")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("owner", "local")) - Expect(ev).To(HaveKey("preferred_address")) - pa := ev["preferred_address"].(map[string]interface{}) - Expect(pa).To(HaveKeyWithValue("ip_v4", "12.34.56.78")) - Expect(pa).To(HaveKeyWithValue("port_v4", float64(123))) - Expect(pa).To(HaveKeyWithValue("ip_v6", "102:304:506:708:90a:b0c:d0e:f10")) - Expect(pa).To(HaveKeyWithValue("port_v6", float64(456))) - Expect(pa).To(HaveKeyWithValue("connection_id", "0807060504030201")) - Expect(pa).To(HaveKeyWithValue("stateless_reset_token", "0f0e0d0c0b0a09080706050403020100")) - }) - - It("records transport parameters that enable the datagram extension", func() { - tracer.SentTransportParameters(&logging.TransportParameters{ - MaxDatagramFrameSize: 1337, - }) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:parameters_set")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("max_datagram_frame_size", float64(1337))) - }) - - It("records received transport parameters", func() { - tracer.ReceivedTransportParameters(&logging.TransportParameters{}) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:parameters_set")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("owner", "remote")) - Expect(ev).ToNot(HaveKey("original_destination_connection_id")) - }) - - It("records restored transport parameters", func() { - tracer.RestoredTransportParameters(&logging.TransportParameters{ - InitialMaxStreamDataBidiLocal: 100, - InitialMaxStreamDataBidiRemote: 200, - InitialMaxStreamDataUni: 300, - InitialMaxData: 400, - MaxIdleTimeout: 123 * time.Millisecond, - }) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:parameters_restored")) - ev := entry.Event - Expect(ev).ToNot(HaveKey("owner")) - Expect(ev).ToNot(HaveKey("original_destination_connection_id")) - Expect(ev).ToNot(HaveKey("stateless_reset_token")) - Expect(ev).ToNot(HaveKey("retry_source_connection_id")) - Expect(ev).ToNot(HaveKey("initial_source_connection_id")) - Expect(ev).To(HaveKeyWithValue("max_idle_timeout", float64(123))) - Expect(ev).To(HaveKeyWithValue("initial_max_data", float64(400))) - Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_local", float64(100))) - Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_remote", float64(200))) - Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_uni", float64(300))) - }) - - It("records a sent packet, without an ACK", func() { - tracer.SentPacket( - &logging.ExtendedHeader{ - Header: logging.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, - Length: 1337, - Version: protocol.VersionTLS, - }, - PacketNumber: 1337, - }, - 987, - nil, - []logging.Frame{ - &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, - &logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}, - }, - ) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:packet_sent")) - ev := entry.Event - Expect(ev).To(HaveKey("raw")) - raw := ev["raw"].(map[string]interface{}) - Expect(raw).To(HaveKeyWithValue("length", float64(987))) - Expect(raw).To(HaveKeyWithValue("payload_length", float64(1337))) - Expect(ev).To(HaveKey("header")) - hdr := ev["header"].(map[string]interface{}) - Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) - Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) - Expect(hdr).To(HaveKeyWithValue("scid", "04030201")) - Expect(ev).To(HaveKey("frames")) - frames := ev["frames"].([]interface{}) - Expect(frames).To(HaveLen(2)) - Expect(frames[0].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "max_stream_data")) - Expect(frames[1].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "stream")) - }) - - It("records a sent packet, without an ACK", func() { - tracer.SentPacket( - &logging.ExtendedHeader{ - Header: logging.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}}, - PacketNumber: 1337, - }, - 123, - &logging.AckFrame{AckRanges: []logging.AckRange{{Smallest: 1, Largest: 10}}}, - []logging.Frame{&logging.MaxDataFrame{MaximumData: 987}}, - ) - entry := exportAndParseSingle() - ev := entry.Event - raw := ev["raw"].(map[string]interface{}) - Expect(raw).To(HaveKeyWithValue("length", float64(123))) - Expect(raw).ToNot(HaveKey("payload_length")) - Expect(ev).To(HaveKey("header")) - hdr := ev["header"].(map[string]interface{}) - Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT")) - Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) - Expect(ev).To(HaveKey("frames")) - frames := ev["frames"].([]interface{}) - Expect(frames).To(HaveLen(2)) - Expect(frames[0].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "ack")) - Expect(frames[1].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "max_data")) - }) - - It("records a received packet", func() { - tracer.ReceivedPacket( - &logging.ExtendedHeader{ - Header: logging.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, - Token: []byte{0xde, 0xad, 0xbe, 0xef}, - Length: 1234, - Version: protocol.VersionTLS, - }, - PacketNumber: 1337, - }, - 789, - []logging.Frame{ - &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, - &logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}, - }, - ) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:packet_received")) - ev := entry.Event - Expect(ev).To(HaveKey("raw")) - raw := ev["raw"].(map[string]interface{}) - Expect(raw).To(HaveKeyWithValue("length", float64(789))) - Expect(raw).To(HaveKeyWithValue("payload_length", float64(1234))) - Expect(ev).To(HaveKey("header")) - hdr := ev["header"].(map[string]interface{}) - Expect(hdr).To(HaveKeyWithValue("packet_type", "initial")) - Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) - Expect(hdr).To(HaveKeyWithValue("scid", "04030201")) - Expect(hdr).To(HaveKey("token")) - token := hdr["token"].(map[string]interface{}) - Expect(token).To(HaveKeyWithValue("data", "deadbeef")) - Expect(ev).To(HaveKey("frames")) - Expect(ev["frames"].([]interface{})).To(HaveLen(2)) - }) - - It("records a received Retry packet", func() { - tracer.ReceivedRetry( - &logging.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, - Token: []byte{0xde, 0xad, 0xbe, 0xef}, - Version: protocol.VersionTLS, - }, - ) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:packet_received")) - ev := entry.Event - Expect(ev).ToNot(HaveKey("raw")) - Expect(ev).To(HaveKey("header")) - header := ev["header"].(map[string]interface{}) - Expect(header).To(HaveKeyWithValue("packet_type", "retry")) - Expect(header).ToNot(HaveKey("packet_number")) - Expect(header).To(HaveKey("version")) - Expect(header).To(HaveKey("dcid")) - Expect(header).To(HaveKey("scid")) - Expect(header).To(HaveKey("token")) - token := header["token"].(map[string]interface{}) - Expect(token).To(HaveKeyWithValue("data", "deadbeef")) - Expect(ev).ToNot(HaveKey("frames")) - }) - - It("records a received Version Negotiation packet", func() { - tracer.ReceivedVersionNegotiationPacket( - &logging.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, - }, - []protocol.VersionNumber{0xdeadbeef, 0xdecafbad}, - ) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:packet_received")) - ev := entry.Event - Expect(ev).To(HaveKey("header")) - Expect(ev).ToNot(HaveKey("frames")) - Expect(ev).To(HaveKey("supported_versions")) - Expect(ev["supported_versions"].([]interface{})).To(Equal([]interface{}{"deadbeef", "decafbad"})) - header := ev["header"] - Expect(header).To(HaveKeyWithValue("packet_type", "version_negotiation")) - Expect(header).ToNot(HaveKey("packet_number")) - Expect(header).ToNot(HaveKey("version")) - Expect(header).To(HaveKey("dcid")) - Expect(header).To(HaveKey("scid")) - }) - - It("records buffered packets", func() { - tracer.BufferedPacket(logging.PacketTypeHandshake) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:packet_buffered")) - ev := entry.Event - Expect(ev).To(HaveKey("header")) - hdr := ev["header"].(map[string]interface{}) - Expect(hdr).To(HaveLen(1)) - Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) - Expect(ev).To(HaveKeyWithValue("trigger", "keys_unavailable")) - }) - - It("records dropped packets", func() { - tracer.DroppedPacket(logging.PacketTypeHandshake, 1337, logging.PacketDropPayloadDecryptError) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:packet_dropped")) - ev := entry.Event - Expect(ev).To(HaveKey("raw")) - Expect(ev["raw"].(map[string]interface{})).To(HaveKeyWithValue("length", float64(1337))) - Expect(ev).To(HaveKey("header")) - hdr := ev["header"].(map[string]interface{}) - Expect(hdr).To(HaveLen(1)) - Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) - Expect(ev).To(HaveKeyWithValue("trigger", "payload_decrypt_error")) - }) - - It("records metrics updates", func() { - now := time.Now() - rttStats := utils.NewRTTStats() - rttStats.UpdateRTT(15*time.Millisecond, 0, now) - rttStats.UpdateRTT(20*time.Millisecond, 0, now) - rttStats.UpdateRTT(25*time.Millisecond, 0, now) - Expect(rttStats.MinRTT()).To(Equal(15 * time.Millisecond)) - Expect(rttStats.SmoothedRTT()).To(And( - BeNumerically(">", 15*time.Millisecond), - BeNumerically("<", 25*time.Millisecond), - )) - Expect(rttStats.LatestRTT()).To(Equal(25 * time.Millisecond)) - tracer.UpdatedMetrics( - rttStats, - 4321, - 1234, - 42, - ) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("recovery:metrics_updated")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("min_rtt", float64(15))) - Expect(ev).To(HaveKeyWithValue("latest_rtt", float64(25))) - Expect(ev).To(HaveKey("smoothed_rtt")) - Expect(time.Duration(ev["smoothed_rtt"].(float64)) * time.Millisecond).To(BeNumerically("~", rttStats.SmoothedRTT(), time.Millisecond)) - Expect(ev).To(HaveKey("rtt_variance")) - Expect(time.Duration(ev["rtt_variance"].(float64)) * time.Millisecond).To(BeNumerically("~", rttStats.MeanDeviation(), time.Millisecond)) - Expect(ev).To(HaveKeyWithValue("congestion_window", float64(4321))) - Expect(ev).To(HaveKeyWithValue("bytes_in_flight", float64(1234))) - Expect(ev).To(HaveKeyWithValue("packets_in_flight", float64(42))) - }) - - It("only logs the diff between two metrics updates", func() { - now := time.Now() - rttStats := utils.NewRTTStats() - rttStats.UpdateRTT(15*time.Millisecond, 0, now) - rttStats.UpdateRTT(20*time.Millisecond, 0, now) - rttStats.UpdateRTT(25*time.Millisecond, 0, now) - Expect(rttStats.MinRTT()).To(Equal(15 * time.Millisecond)) - - rttStats2 := utils.NewRTTStats() - rttStats2.UpdateRTT(15*time.Millisecond, 0, now) - rttStats2.UpdateRTT(15*time.Millisecond, 0, now) - rttStats2.UpdateRTT(15*time.Millisecond, 0, now) - Expect(rttStats2.MinRTT()).To(Equal(15 * time.Millisecond)) - - Expect(rttStats.LatestRTT()).To(Equal(25 * time.Millisecond)) - tracer.UpdatedMetrics( - rttStats, - 4321, - 1234, - 42, - ) - tracer.UpdatedMetrics( - rttStats2, - 4321, - 12345, // changed - 42, - ) - entries := exportAndParse() - Expect(entries).To(HaveLen(2)) - Expect(entries[0].Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entries[0].Name).To(Equal("recovery:metrics_updated")) - Expect(entries[0].Event).To(HaveLen(7)) - Expect(entries[1].Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entries[1].Name).To(Equal("recovery:metrics_updated")) - ev := entries[1].Event - Expect(ev).ToNot(HaveKey("min_rtt")) - Expect(ev).ToNot(HaveKey("congestion_window")) - Expect(ev).ToNot(HaveKey("packets_in_flight")) - Expect(ev).To(HaveKeyWithValue("bytes_in_flight", float64(12345))) - Expect(ev).To(HaveKeyWithValue("smoothed_rtt", float64(15))) - }) - - It("records lost packets", func() { - tracer.LostPacket(protocol.EncryptionHandshake, 42, logging.PacketLossReorderingThreshold) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("recovery:packet_lost")) - ev := entry.Event - Expect(ev).To(HaveKey("header")) - hdr := ev["header"].(map[string]interface{}) - Expect(hdr).To(HaveLen(2)) - Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) - Expect(hdr).To(HaveKeyWithValue("packet_number", float64(42))) - Expect(ev).To(HaveKeyWithValue("trigger", "reordering_threshold")) - }) - - It("records congestion state updates", func() { - tracer.UpdatedCongestionState(logging.CongestionStateCongestionAvoidance) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("recovery:congestion_state_updated")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("new", "congestion_avoidance")) - }) - - It("records PTO changes", func() { - tracer.UpdatedPTOCount(42) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("recovery:metrics_updated")) - Expect(entry.Event).To(HaveKeyWithValue("pto_count", float64(42))) - }) - - It("records TLS key updates", func() { - tracer.UpdatedKeyFromTLS(protocol.EncryptionHandshake, protocol.PerspectiveClient) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("security:key_updated")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("key_type", "client_handshake_secret")) - Expect(ev).To(HaveKeyWithValue("trigger", "tls")) - Expect(ev).ToNot(HaveKey("generation")) - Expect(ev).ToNot(HaveKey("old")) - Expect(ev).ToNot(HaveKey("new")) - }) - - It("records TLS key updates, for 1-RTT keys", func() { - tracer.UpdatedKeyFromTLS(protocol.Encryption1RTT, protocol.PerspectiveServer) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("security:key_updated")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("key_type", "server_1rtt_secret")) - Expect(ev).To(HaveKeyWithValue("trigger", "tls")) - Expect(ev).To(HaveKeyWithValue("generation", float64(0))) - Expect(ev).ToNot(HaveKey("old")) - Expect(ev).ToNot(HaveKey("new")) - }) - - It("records QUIC key updates", func() { - tracer.UpdatedKey(1337, true) - entries := exportAndParse() - Expect(entries).To(HaveLen(2)) - var keyTypes []string - for _, entry := range entries { - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("security:key_updated")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("generation", float64(1337))) - Expect(ev).To(HaveKeyWithValue("trigger", "remote_update")) - Expect(ev).To(HaveKey("key_type")) - keyTypes = append(keyTypes, ev["key_type"].(string)) - } - Expect(keyTypes).To(ContainElement("server_1rtt_secret")) - Expect(keyTypes).To(ContainElement("client_1rtt_secret")) - }) - - It("records dropped encryption levels", func() { - tracer.DroppedEncryptionLevel(protocol.EncryptionInitial) - entries := exportAndParse() - Expect(entries).To(HaveLen(2)) - var keyTypes []string - for _, entry := range entries { - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("security:key_retired")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("trigger", "tls")) - Expect(ev).To(HaveKey("key_type")) - keyTypes = append(keyTypes, ev["key_type"].(string)) - } - Expect(keyTypes).To(ContainElement("server_initial_secret")) - Expect(keyTypes).To(ContainElement("client_initial_secret")) - }) - - It("records dropped 0-RTT keys", func() { - tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) - entries := exportAndParse() - Expect(entries).To(HaveLen(1)) - entry := entries[0] - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("security:key_retired")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("trigger", "tls")) - Expect(ev).To(HaveKeyWithValue("key_type", "server_0rtt_secret")) - }) - - It("records dropped keys", func() { - tracer.DroppedKey(42) - entries := exportAndParse() - Expect(entries).To(HaveLen(2)) - var keyTypes []string - for _, entry := range entries { - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("security:key_retired")) - ev := entry.Event - Expect(ev).To(HaveKeyWithValue("generation", float64(42))) - Expect(ev).ToNot(HaveKey("trigger")) - Expect(ev).To(HaveKey("key_type")) - keyTypes = append(keyTypes, ev["key_type"].(string)) - } - Expect(keyTypes).To(ContainElement("server_1rtt_secret")) - Expect(keyTypes).To(ContainElement("client_1rtt_secret")) - }) - - It("records when the timer is set", func() { - timeout := time.Now().Add(137 * time.Millisecond) - tracer.SetLossTimer(logging.TimerTypePTO, protocol.EncryptionHandshake, timeout) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("recovery:loss_timer_updated")) - ev := entry.Event - Expect(ev).To(HaveLen(4)) - Expect(ev).To(HaveKeyWithValue("event_type", "set")) - Expect(ev).To(HaveKeyWithValue("timer_type", "pto")) - Expect(ev).To(HaveKeyWithValue("packet_number_space", "handshake")) - Expect(ev).To(HaveKey("delta")) - delta := time.Duration(ev["delta"].(float64)*1e6) * time.Nanosecond - Expect(entry.Time.Add(delta)).To(BeTemporally("~", timeout, scaleDuration(10*time.Microsecond))) - }) - - It("records when the loss timer expires", func() { - tracer.LossTimerExpired(logging.TimerTypeACK, protocol.Encryption1RTT) - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("recovery:loss_timer_updated")) - ev := entry.Event - Expect(ev).To(HaveLen(3)) - Expect(ev).To(HaveKeyWithValue("event_type", "expired")) - Expect(ev).To(HaveKeyWithValue("timer_type", "ack")) - Expect(ev).To(HaveKeyWithValue("packet_number_space", "application_data")) - }) - - It("records when the timer is canceled", func() { - tracer.LossTimerCanceled() - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("recovery:loss_timer_updated")) - ev := entry.Event - Expect(ev).To(HaveLen(1)) - Expect(ev).To(HaveKeyWithValue("event_type", "cancelled")) - }) - - It("records a generic event", func() { - tracer.Debug("foo", "bar") - entry := exportAndParseSingle() - Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) - Expect(entry.Name).To(Equal("transport:foo")) - ev := entry.Event - Expect(ev).To(HaveLen(1)) - Expect(ev).To(HaveKeyWithValue("details", "bar")) - }) - }) - }) -}) diff --git a/internal/quic-go/qlog/trace.go b/internal/quic-go/qlog/trace.go deleted file mode 100644 index a3ae43b4..00000000 --- a/internal/quic-go/qlog/trace.go +++ /dev/null @@ -1,66 +0,0 @@ -package qlog - -import ( - "time" - - "github.com/francoispqt/gojay" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -type topLevel struct { - trace trace -} - -func (topLevel) IsNil() bool { return false } -func (l topLevel) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("qlog_format", "NDJSON") - enc.StringKey("qlog_version", "draft-02") - enc.StringKeyOmitEmpty("title", "quic-go qlog") - enc.StringKey("code_version", quicGoVersion) - enc.ObjectKey("trace", l.trace) -} - -type vantagePoint struct { - Name string - Type protocol.Perspective -} - -func (p vantagePoint) IsNil() bool { return false } -func (p vantagePoint) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKeyOmitEmpty("name", p.Name) - switch p.Type { - case protocol.PerspectiveClient: - enc.StringKey("type", "client") - case protocol.PerspectiveServer: - enc.StringKey("type", "server") - } -} - -type commonFields struct { - ODCID connectionID - GroupID connectionID - ProtocolType string - ReferenceTime time.Time -} - -func (f commonFields) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("ODCID", f.ODCID.String()) - enc.StringKey("group_id", f.ODCID.String()) - enc.StringKeyOmitEmpty("protocol_type", f.ProtocolType) - enc.Float64Key("reference_time", float64(f.ReferenceTime.UnixNano())/1e6) - enc.StringKey("time_format", "relative") -} - -func (f commonFields) IsNil() bool { return false } - -type trace struct { - VantagePoint vantagePoint - CommonFields commonFields -} - -func (trace) IsNil() bool { return false } -func (t trace) MarshalJSONObject(enc *gojay.Encoder) { - enc.ObjectKey("vantage_point", t.VantagePoint) - enc.ObjectKey("common_fields", t.CommonFields) -} diff --git a/internal/quic-go/qlog/types.go b/internal/quic-go/qlog/types.go deleted file mode 100644 index 30e7f3ee..00000000 --- a/internal/quic-go/qlog/types.go +++ /dev/null @@ -1,320 +0,0 @@ -package qlog - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" -) - -type owner uint8 - -const ( - ownerLocal owner = iota - ownerRemote -) - -func (o owner) String() string { - switch o { - case ownerLocal: - return "local" - case ownerRemote: - return "remote" - default: - return "unknown owner" - } -} - -type streamType protocol.StreamType - -func (s streamType) String() string { - switch protocol.StreamType(s) { - case protocol.StreamTypeUni: - return "unidirectional" - case protocol.StreamTypeBidi: - return "bidirectional" - default: - return "unknown stream type" - } -} - -type connectionID protocol.ConnectionID - -func (c connectionID) String() string { - return fmt.Sprintf("%x", []byte(c)) -} - -// category is the qlog event category. -type category uint8 - -const ( - categoryConnectivity category = iota - categoryTransport - categorySecurity - categoryRecovery -) - -func (c category) String() string { - switch c { - case categoryConnectivity: - return "connectivity" - case categoryTransport: - return "transport" - case categorySecurity: - return "security" - case categoryRecovery: - return "recovery" - default: - return "unknown category" - } -} - -type versionNumber protocol.VersionNumber - -func (v versionNumber) String() string { - return fmt.Sprintf("%x", uint32(v)) -} - -func (packetHeader) IsNil() bool { return false } - -func encLevelToPacketNumberSpace(encLevel protocol.EncryptionLevel) string { - switch encLevel { - case protocol.EncryptionInitial: - return "initial" - case protocol.EncryptionHandshake: - return "handshake" - case protocol.Encryption0RTT, protocol.Encryption1RTT: - return "application_data" - default: - return "unknown encryption level" - } -} - -type keyType uint8 - -const ( - keyTypeServerInitial keyType = 1 + iota - keyTypeClientInitial - keyTypeServerHandshake - keyTypeClientHandshake - keyTypeServer0RTT - keyTypeClient0RTT - keyTypeServer1RTT - keyTypeClient1RTT -) - -func encLevelToKeyType(encLevel protocol.EncryptionLevel, pers protocol.Perspective) keyType { - if pers == protocol.PerspectiveServer { - switch encLevel { - case protocol.EncryptionInitial: - return keyTypeServerInitial - case protocol.EncryptionHandshake: - return keyTypeServerHandshake - case protocol.Encryption0RTT: - return keyTypeServer0RTT - case protocol.Encryption1RTT: - return keyTypeServer1RTT - default: - return 0 - } - } - switch encLevel { - case protocol.EncryptionInitial: - return keyTypeClientInitial - case protocol.EncryptionHandshake: - return keyTypeClientHandshake - case protocol.Encryption0RTT: - return keyTypeClient0RTT - case protocol.Encryption1RTT: - return keyTypeClient1RTT - default: - return 0 - } -} - -func (t keyType) String() string { - switch t { - case keyTypeServerInitial: - return "server_initial_secret" - case keyTypeClientInitial: - return "client_initial_secret" - case keyTypeServerHandshake: - return "server_handshake_secret" - case keyTypeClientHandshake: - return "client_handshake_secret" - case keyTypeServer0RTT: - return "server_0rtt_secret" - case keyTypeClient0RTT: - return "client_0rtt_secret" - case keyTypeServer1RTT: - return "server_1rtt_secret" - case keyTypeClient1RTT: - return "client_1rtt_secret" - default: - return "unknown key type" - } -} - -type keyUpdateTrigger uint8 - -const ( - keyUpdateTLS keyUpdateTrigger = iota - keyUpdateRemote - keyUpdateLocal -) - -func (t keyUpdateTrigger) String() string { - switch t { - case keyUpdateTLS: - return "tls" - case keyUpdateRemote: - return "remote_update" - case keyUpdateLocal: - return "local_update" - default: - return "unknown key update trigger" - } -} - -type transportError uint64 - -func (e transportError) String() string { - switch qerr.TransportErrorCode(e) { - case qerr.NoError: - return "no_error" - case qerr.InternalError: - return "internal_error" - case qerr.ConnectionRefused: - return "connection_refused" - case qerr.FlowControlError: - return "flow_control_error" - case qerr.StreamLimitError: - return "stream_limit_error" - case qerr.StreamStateError: - return "stream_state_error" - case qerr.FinalSizeError: - return "final_size_error" - case qerr.FrameEncodingError: - return "frame_encoding_error" - case qerr.TransportParameterError: - return "transport_parameter_error" - case qerr.ConnectionIDLimitError: - return "connection_id_limit_error" - case qerr.ProtocolViolation: - return "protocol_violation" - case qerr.InvalidToken: - return "invalid_token" - case qerr.ApplicationErrorErrorCode: - return "application_error" - case qerr.CryptoBufferExceeded: - return "crypto_buffer_exceeded" - case qerr.KeyUpdateError: - return "key_update_error" - case qerr.AEADLimitReached: - return "aead_limit_reached" - case qerr.NoViablePathError: - return "no_viable_path" - default: - return "" - } -} - -type packetType logging.PacketType - -func (t packetType) String() string { - switch logging.PacketType(t) { - case logging.PacketTypeInitial: - return "initial" - case logging.PacketTypeHandshake: - return "handshake" - case logging.PacketTypeRetry: - return "retry" - case logging.PacketType0RTT: - return "0RTT" - case logging.PacketTypeVersionNegotiation: - return "version_negotiation" - case logging.PacketTypeStatelessReset: - return "stateless_reset" - case logging.PacketType1RTT: - return "1RTT" - case logging.PacketTypeNotDetermined: - return "" - default: - return "unknown packet type" - } -} - -type packetLossReason logging.PacketLossReason - -func (r packetLossReason) String() string { - switch logging.PacketLossReason(r) { - case logging.PacketLossReorderingThreshold: - return "reordering_threshold" - case logging.PacketLossTimeThreshold: - return "time_threshold" - default: - return "unknown loss reason" - } -} - -type packetDropReason logging.PacketDropReason - -func (r packetDropReason) String() string { - switch logging.PacketDropReason(r) { - case logging.PacketDropKeyUnavailable: - return "key_unavailable" - case logging.PacketDropUnknownConnectionID: - return "unknown_connection_id" - case logging.PacketDropHeaderParseError: - return "header_parse_error" - case logging.PacketDropPayloadDecryptError: - return "payload_decrypt_error" - case logging.PacketDropProtocolViolation: - return "protocol_violation" - case logging.PacketDropDOSPrevention: - return "dos_prevention" - case logging.PacketDropUnsupportedVersion: - return "unsupported_version" - case logging.PacketDropUnexpectedPacket: - return "unexpected_packet" - case logging.PacketDropUnexpectedSourceConnectionID: - return "unexpected_source_connection_id" - case logging.PacketDropUnexpectedVersion: - return "unexpected_version" - case logging.PacketDropDuplicate: - return "duplicate" - default: - return "unknown packet drop reason" - } -} - -type timerType logging.TimerType - -func (t timerType) String() string { - switch logging.TimerType(t) { - case logging.TimerTypeACK: - return "ack" - case logging.TimerTypePTO: - return "pto" - default: - return "unknown timer type" - } -} - -type congestionState logging.CongestionState - -func (s congestionState) String() string { - switch logging.CongestionState(s) { - case logging.CongestionStateSlowStart: - return "slow_start" - case logging.CongestionStateCongestionAvoidance: - return "congestion_avoidance" - case logging.CongestionStateRecovery: - return "recovery" - case logging.CongestionStateApplicationLimited: - return "application_limited" - default: - return "unknown congestion state" - } -} diff --git a/internal/quic-go/qlog/types_test.go b/internal/quic-go/qlog/types_test.go deleted file mode 100644 index cbdc8514..00000000 --- a/internal/quic-go/qlog/types_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package qlog - -import ( - "go/ast" - "go/parser" - gotoken "go/token" - "path" - "runtime" - "strconv" - - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Types", func() { - It("has a string representation for the owner", func() { - Expect(ownerLocal.String()).To(Equal("local")) - Expect(ownerRemote.String()).To(Equal("remote")) - }) - - It("has a string representation for the category", func() { - Expect(categoryConnectivity.String()).To(Equal("connectivity")) - Expect(categoryTransport.String()).To(Equal("transport")) - Expect(categoryRecovery.String()).To(Equal("recovery")) - Expect(categorySecurity.String()).To(Equal("security")) - }) - - It("has a string representation for the packet type", func() { - Expect(packetType(logging.PacketTypeInitial).String()).To(Equal("initial")) - Expect(packetType(logging.PacketTypeHandshake).String()).To(Equal("handshake")) - Expect(packetType(logging.PacketType0RTT).String()).To(Equal("0RTT")) - Expect(packetType(logging.PacketType1RTT).String()).To(Equal("1RTT")) - Expect(packetType(logging.PacketTypeStatelessReset).String()).To(Equal("stateless_reset")) - Expect(packetType(logging.PacketTypeRetry).String()).To(Equal("retry")) - Expect(packetType(logging.PacketTypeVersionNegotiation).String()).To(Equal("version_negotiation")) - Expect(packetType(logging.PacketTypeNotDetermined).String()).To(BeEmpty()) - }) - - It("has a string representation for the packet drop reason", func() { - Expect(packetDropReason(logging.PacketDropKeyUnavailable).String()).To(Equal("key_unavailable")) - Expect(packetDropReason(logging.PacketDropUnknownConnectionID).String()).To(Equal("unknown_connection_id")) - Expect(packetDropReason(logging.PacketDropHeaderParseError).String()).To(Equal("header_parse_error")) - Expect(packetDropReason(logging.PacketDropPayloadDecryptError).String()).To(Equal("payload_decrypt_error")) - Expect(packetDropReason(logging.PacketDropProtocolViolation).String()).To(Equal("protocol_violation")) - Expect(packetDropReason(logging.PacketDropDOSPrevention).String()).To(Equal("dos_prevention")) - Expect(packetDropReason(logging.PacketDropUnsupportedVersion).String()).To(Equal("unsupported_version")) - Expect(packetDropReason(logging.PacketDropUnexpectedPacket).String()).To(Equal("unexpected_packet")) - Expect(packetDropReason(logging.PacketDropUnexpectedSourceConnectionID).String()).To(Equal("unexpected_source_connection_id")) - Expect(packetDropReason(logging.PacketDropUnexpectedVersion).String()).To(Equal("unexpected_version")) - }) - - It("has a string representation for the timer type", func() { - Expect(timerType(logging.TimerTypeACK).String()).To(Equal("ack")) - Expect(timerType(logging.TimerTypePTO).String()).To(Equal("pto")) - }) - - It("has a string representation for the key type", func() { - Expect(encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveClient).String()).To(Equal("client_initial_secret")) - Expect(encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveServer).String()).To(Equal("server_initial_secret")) - Expect(encLevelToKeyType(protocol.EncryptionHandshake, protocol.PerspectiveClient).String()).To(Equal("client_handshake_secret")) - Expect(encLevelToKeyType(protocol.EncryptionHandshake, protocol.PerspectiveServer).String()).To(Equal("server_handshake_secret")) - Expect(encLevelToKeyType(protocol.Encryption0RTT, protocol.PerspectiveClient).String()).To(Equal("client_0rtt_secret")) - Expect(encLevelToKeyType(protocol.Encryption0RTT, protocol.PerspectiveServer).String()).To(Equal("server_0rtt_secret")) - Expect(encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveClient).String()).To(Equal("client_1rtt_secret")) - Expect(encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveServer).String()).To(Equal("server_1rtt_secret")) - }) - - It("has a string representation for the key update trigger", func() { - Expect(keyUpdateTLS.String()).To(Equal("tls")) - Expect(keyUpdateRemote.String()).To(Equal("remote_update")) - Expect(keyUpdateLocal.String()).To(Equal("local_update")) - }) - - It("tells the packet number space from the encryption level", func() { - Expect(encLevelToPacketNumberSpace(protocol.EncryptionInitial)).To(Equal("initial")) - Expect(encLevelToPacketNumberSpace(protocol.EncryptionHandshake)).To(Equal("handshake")) - Expect(encLevelToPacketNumberSpace(protocol.Encryption0RTT)).To(Equal("application_data")) - Expect(encLevelToPacketNumberSpace(protocol.Encryption1RTT)).To(Equal("application_data")) - }) - - Context("transport errors", func() { - It("has a string representation for every error code", func() { - // We parse the error code file, extract all constants, and verify that - // each of them has a string version. Go FTW! - _, thisfile, _, ok := runtime.Caller(0) - if !ok { - panic("Failed to get current frame") - } - filename := path.Join(path.Dir(thisfile), "../qerr/error_codes.go") - fileAst, err := parser.ParseFile(gotoken.NewFileSet(), filename, nil, 0) - Expect(err).NotTo(HaveOccurred()) - constSpecs := fileAst.Decls[2].(*ast.GenDecl).Specs - Expect(len(constSpecs)).To(BeNumerically(">", 4)) // at time of writing - for _, c := range constSpecs { - valString := c.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value - val, err := strconv.ParseInt(valString, 0, 64) - Expect(err).NotTo(HaveOccurred()) - Expect(transportError(val).String()).ToNot(BeEmpty()) - } - }) - - It("has a string representation for transport errors", func() { - Expect(transportError(qerr.NoError).String()).To(Equal("no_error")) - Expect(transportError(qerr.InternalError).String()).To(Equal("internal_error")) - Expect(transportError(qerr.ConnectionRefused).String()).To(Equal("connection_refused")) - Expect(transportError(qerr.FlowControlError).String()).To(Equal("flow_control_error")) - Expect(transportError(qerr.StreamLimitError).String()).To(Equal("stream_limit_error")) - Expect(transportError(qerr.StreamStateError).String()).To(Equal("stream_state_error")) - Expect(transportError(qerr.FrameEncodingError).String()).To(Equal("frame_encoding_error")) - Expect(transportError(qerr.ConnectionIDLimitError).String()).To(Equal("connection_id_limit_error")) - Expect(transportError(qerr.ProtocolViolation).String()).To(Equal("protocol_violation")) - Expect(transportError(qerr.InvalidToken).String()).To(Equal("invalid_token")) - Expect(transportError(qerr.ApplicationErrorErrorCode).String()).To(Equal("application_error")) - Expect(transportError(qerr.CryptoBufferExceeded).String()).To(Equal("crypto_buffer_exceeded")) - Expect(transportError(qerr.NoViablePathError).String()).To(Equal("no_viable_path")) - Expect(transportError(1337).String()).To(BeEmpty()) - }) - }) - - It("has a string representation for congestion state updates", func() { - Expect(congestionState(logging.CongestionStateSlowStart).String()).To(Equal("slow_start")) - Expect(congestionState(logging.CongestionStateCongestionAvoidance).String()).To(Equal("congestion_avoidance")) - Expect(congestionState(logging.CongestionStateApplicationLimited).String()).To(Equal("application_limited")) - Expect(congestionState(logging.CongestionStateRecovery).String()).To(Equal("recovery")) - }) -}) diff --git a/internal/quic-go/qtls/go116.go b/internal/quic-go/qtls/go116.go deleted file mode 100644 index e3024624..00000000 --- a/internal/quic-go/qtls/go116.go +++ /dev/null @@ -1,100 +0,0 @@ -//go:build go1.16 && !go1.17 -// +build go1.16,!go1.17 - -package qtls - -import ( - "crypto" - "crypto/cipher" - "crypto/tls" - "net" - "unsafe" - - "github.com/marten-seemann/qtls-go1-16" -) - -type ( - // Alert is a TLS alert - Alert = qtls.Alert - // A Certificate is qtls.Certificate. - Certificate = qtls.Certificate - // CertificateRequestInfo contains inforamtion about a certificate request. - CertificateRequestInfo = qtls.CertificateRequestInfo - // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 - CipherSuiteTLS13 = qtls.CipherSuiteTLS13 - // ClientHelloInfo contains information about a ClientHello. - ClientHelloInfo = qtls.ClientHelloInfo - // ClientSessionCache is a cache used for session resumption. - ClientSessionCache = qtls.ClientSessionCache - // ClientSessionState is a state needed for session resumption. - ClientSessionState = qtls.ClientSessionState - // A Config is a qtls.Config. - Config = qtls.Config - // A Conn is a qtls.Conn. - Conn = qtls.Conn - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT - // EncryptionLevel is the encryption level of a message. - EncryptionLevel = qtls.EncryptionLevel - // Extension is a TLS extension - Extension = qtls.Extension - // ExtraConfig is the qtls.ExtraConfig - ExtraConfig = qtls.ExtraConfig - // RecordLayer is a qtls RecordLayer. - RecordLayer = qtls.RecordLayer -) - -const ( - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake = qtls.EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT = qtls.Encryption0RTT - // EncryptionApplication is the application data encryption level - EncryptionApplication = qtls.EncryptionApplication -) - -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return qtls.AEADAESGCMTLS13(key, fixedNonce) -} - -// Client returns a new TLS client side connection. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Client(conn, config, extraConfig) -} - -// Server returns a new TLS server side connection. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Server(conn, config, extraConfig) -} - -func GetConnectionState(conn *Conn) ConnectionState { - return conn.ConnectionStateWith0RTT() -} - -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState -} - -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-16.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. -func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { - val := cipherSuiteTLS13ByID(id) - cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - return &qtls.CipherSuiteTLS13{ - ID: cs.ID, - KeyLen: cs.KeyLen, - AEAD: cs.AEAD, - Hash: cs.Hash, - } -} diff --git a/internal/quic-go/qtls/go117.go b/internal/quic-go/qtls/go117.go deleted file mode 100644 index bc385f19..00000000 --- a/internal/quic-go/qtls/go117.go +++ /dev/null @@ -1,100 +0,0 @@ -//go:build go1.17 && !go1.18 -// +build go1.17,!go1.18 - -package qtls - -import ( - "crypto" - "crypto/cipher" - "crypto/tls" - "net" - "unsafe" - - "github.com/marten-seemann/qtls-go1-17" -) - -type ( - // Alert is a TLS alert - Alert = qtls.Alert - // A Certificate is qtls.Certificate. - Certificate = qtls.Certificate - // CertificateRequestInfo contains inforamtion about a certificate request. - CertificateRequestInfo = qtls.CertificateRequestInfo - // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 - CipherSuiteTLS13 = qtls.CipherSuiteTLS13 - // ClientHelloInfo contains information about a ClientHello. - ClientHelloInfo = qtls.ClientHelloInfo - // ClientSessionCache is a cache used for session resumption. - ClientSessionCache = qtls.ClientSessionCache - // ClientSessionState is a state needed for session resumption. - ClientSessionState = qtls.ClientSessionState - // A Config is a qtls.Config. - Config = qtls.Config - // A Conn is a qtls.Conn. - Conn = qtls.Conn - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT - // EncryptionLevel is the encryption level of a message. - EncryptionLevel = qtls.EncryptionLevel - // Extension is a TLS extension - Extension = qtls.Extension - // ExtraConfig is the qtls.ExtraConfig - ExtraConfig = qtls.ExtraConfig - // RecordLayer is a qtls RecordLayer. - RecordLayer = qtls.RecordLayer -) - -const ( - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake = qtls.EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT = qtls.Encryption0RTT - // EncryptionApplication is the application data encryption level - EncryptionApplication = qtls.EncryptionApplication -) - -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return qtls.AEADAESGCMTLS13(key, fixedNonce) -} - -// Client returns a new TLS client side connection. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Client(conn, config, extraConfig) -} - -// Server returns a new TLS server side connection. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Server(conn, config, extraConfig) -} - -func GetConnectionState(conn *Conn) ConnectionState { - return conn.ConnectionStateWith0RTT() -} - -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState -} - -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-17.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. -func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { - val := cipherSuiteTLS13ByID(id) - cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - return &qtls.CipherSuiteTLS13{ - ID: cs.ID, - KeyLen: cs.KeyLen, - AEAD: cs.AEAD, - Hash: cs.Hash, - } -} diff --git a/internal/quic-go/qtls/go118.go b/internal/quic-go/qtls/go118.go deleted file mode 100644 index 5de030c7..00000000 --- a/internal/quic-go/qtls/go118.go +++ /dev/null @@ -1,100 +0,0 @@ -//go:build go1.18 && !go1.19 -// +build go1.18,!go1.19 - -package qtls - -import ( - "crypto" - "crypto/cipher" - "crypto/tls" - "net" - "unsafe" - - "github.com/marten-seemann/qtls-go1-18" -) - -type ( - // Alert is a TLS alert - Alert = qtls.Alert - // A Certificate is qtls.Certificate. - Certificate = qtls.Certificate - // CertificateRequestInfo contains inforamtion about a certificate request. - CertificateRequestInfo = qtls.CertificateRequestInfo - // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 - CipherSuiteTLS13 = qtls.CipherSuiteTLS13 - // ClientHelloInfo contains information about a ClientHello. - ClientHelloInfo = qtls.ClientHelloInfo - // ClientSessionCache is a cache used for session resumption. - ClientSessionCache = qtls.ClientSessionCache - // ClientSessionState is a state needed for session resumption. - ClientSessionState = qtls.ClientSessionState - // A Config is a qtls.Config. - Config = qtls.Config - // A Conn is a qtls.Conn. - Conn = qtls.Conn - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT - // EncryptionLevel is the encryption level of a message. - EncryptionLevel = qtls.EncryptionLevel - // Extension is a TLS extension - Extension = qtls.Extension - // ExtraConfig is the qtls.ExtraConfig - ExtraConfig = qtls.ExtraConfig - // RecordLayer is a qtls RecordLayer. - RecordLayer = qtls.RecordLayer -) - -const ( - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake = qtls.EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT = qtls.Encryption0RTT - // EncryptionApplication is the application data encryption level - EncryptionApplication = qtls.EncryptionApplication -) - -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return qtls.AEADAESGCMTLS13(key, fixedNonce) -} - -// Client returns a new TLS client side connection. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Client(conn, config, extraConfig) -} - -// Server returns a new TLS server side connection. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Server(conn, config, extraConfig) -} - -func GetConnectionState(conn *Conn) ConnectionState { - return conn.ConnectionStateWith0RTT() -} - -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState -} - -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-18.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. -func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { - val := cipherSuiteTLS13ByID(id) - cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - return &qtls.CipherSuiteTLS13{ - ID: cs.ID, - KeyLen: cs.KeyLen, - AEAD: cs.AEAD, - Hash: cs.Hash, - } -} diff --git a/internal/quic-go/qtls/go119.go b/internal/quic-go/qtls/go119.go deleted file mode 100644 index 86dcaea3..00000000 --- a/internal/quic-go/qtls/go119.go +++ /dev/null @@ -1,100 +0,0 @@ -//go:build go1.19 -// +build go1.19 - -package qtls - -import ( - "crypto" - "crypto/cipher" - "crypto/tls" - "net" - "unsafe" - - "github.com/marten-seemann/qtls-go1-19" -) - -type ( - // Alert is a TLS alert - Alert = qtls.Alert - // A Certificate is qtls.Certificate. - Certificate = qtls.Certificate - // CertificateRequestInfo contains information about a certificate request. - CertificateRequestInfo = qtls.CertificateRequestInfo - // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 - CipherSuiteTLS13 = qtls.CipherSuiteTLS13 - // ClientHelloInfo contains information about a ClientHello. - ClientHelloInfo = qtls.ClientHelloInfo - // ClientSessionCache is a cache used for session resumption. - ClientSessionCache = qtls.ClientSessionCache - // ClientSessionState is a state needed for session resumption. - ClientSessionState = qtls.ClientSessionState - // A Config is a qtls.Config. - Config = qtls.Config - // A Conn is a qtls.Conn. - Conn = qtls.Conn - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT - // EncryptionLevel is the encryption level of a message. - EncryptionLevel = qtls.EncryptionLevel - // Extension is a TLS extension - Extension = qtls.Extension - // ExtraConfig is the qtls.ExtraConfig - ExtraConfig = qtls.ExtraConfig - // RecordLayer is a qtls RecordLayer. - RecordLayer = qtls.RecordLayer -) - -const ( - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake = qtls.EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT = qtls.Encryption0RTT - // EncryptionApplication is the application data encryption level - EncryptionApplication = qtls.EncryptionApplication -) - -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return qtls.AEADAESGCMTLS13(key, fixedNonce) -} - -// Client returns a new TLS client side connection. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Client(conn, config, extraConfig) -} - -// Server returns a new TLS server side connection. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Server(conn, config, extraConfig) -} - -func GetConnectionState(conn *Conn) ConnectionState { - return conn.ConnectionStateWith0RTT() -} - -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState -} - -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-19.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. -func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { - val := cipherSuiteTLS13ByID(id) - cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - return &qtls.CipherSuiteTLS13{ - ID: cs.ID, - KeyLen: cs.KeyLen, - AEAD: cs.AEAD, - Hash: cs.Hash, - } -} diff --git a/internal/quic-go/qtls/go_oldversion.go b/internal/quic-go/qtls/go_oldversion.go deleted file mode 100644 index c17b8589..00000000 --- a/internal/quic-go/qtls/go_oldversion.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build (go1.9 || go1.10 || go1.11 || go1.12 || go1.13 || go1.14 || go1.15) && !go1.16 -// +build go1.9 go1.10 go1.11 go1.12 go1.13 go1.14 go1.15 -// +build !go1.16 - -package qtls - -var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/imroc/req/v3/internal/quic-go/wiki/quic-go-and-Go-versions." diff --git a/internal/quic-go/qtls/qtls_suite_test.go b/internal/quic-go/qtls/qtls_suite_test.go deleted file mode 100644 index 24b143b2..00000000 --- a/internal/quic-go/qtls/qtls_suite_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package qtls - -import ( - "testing" - - gomock "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestQTLS(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "qtls Suite") -} - -var mockCtrl *gomock.Controller - -var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) -}) - -var _ = AfterEach(func() { - mockCtrl.Finish() -}) diff --git a/internal/quic-go/qtls/qtls_test.go b/internal/quic-go/qtls/qtls_test.go deleted file mode 100644 index c64c5e9e..00000000 --- a/internal/quic-go/qtls/qtls_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package qtls - -import ( - "crypto/tls" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("qtls wrapper", func() { - It("gets cipher suites", func() { - for _, id := range []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384, tls.TLS_CHACHA20_POLY1305_SHA256} { - cs := CipherSuiteTLS13ByID(id) - Expect(cs.ID).To(Equal(id)) - } - }) -}) diff --git a/internal/quic-go/quic_suite_test.go b/internal/quic-go/quic_suite_test.go deleted file mode 100644 index d8a0a2e0..00000000 --- a/internal/quic-go/quic_suite_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package quic - -import ( - "io/ioutil" - "log" - "sync" - "testing" - - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestQuicGo(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "QUIC Suite") -} - -var mockCtrl *gomock.Controller - -var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) - - // reset the sync.Once - connMuxerOnce = *new(sync.Once) -}) - -var _ = BeforeSuite(func() { - log.SetOutput(ioutil.Discard) -}) - -var _ = AfterEach(func() { - mockCtrl.Finish() -}) diff --git a/internal/quic-go/quicvarint/io_test.go b/internal/quic-go/quicvarint/io_test.go deleted file mode 100644 index 054ab864..00000000 --- a/internal/quic-go/quicvarint/io_test.go +++ /dev/null @@ -1,115 +0,0 @@ -package quicvarint - -import ( - "bytes" - "io" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type nopReader struct{} - -func (r *nopReader) Read(_ []byte) (int, error) { - return 0, io.ErrUnexpectedEOF -} - -var _ io.Reader = &nopReader{} - -type nopWriter struct{} - -func (r *nopWriter) Write(_ []byte) (int, error) { - return 0, io.ErrShortBuffer -} - -// eofReader is a reader that returns data and the io.EOF at the same time in the last Read call -type eofReader struct { - Data []byte - pos int -} - -func (r *eofReader) Read(b []byte) (int, error) { - n := copy(b, r.Data[r.pos:]) - r.pos += n - if r.pos >= len(r.Data) { - return n, io.EOF - } - return n, nil -} - -var _ io.Writer = &nopWriter{} - -var _ = Describe("Varint I/O", func() { - Context("Reader", func() { - Context("NewReader", func() { - It("passes through a Reader unchanged", func() { - b := bytes.NewReader([]byte{0}) - r := NewReader(b) - Expect(r).To(Equal(b)) - }) - - It("wraps an io.Reader", func() { - n := &nopReader{} - r := NewReader(n) - Expect(r).ToNot(Equal(n)) - }) - }) - - It("returns an error when reading from an underlying io.Reader fails", func() { - r := NewReader(&nopReader{}) - val, err := r.ReadByte() - Expect(err).To(Equal(io.ErrUnexpectedEOF)) - Expect(val).To(Equal(byte(0))) - }) - - Context("EOF handling", func() { - It("eofReader works correctly", func() { - r := &eofReader{Data: []byte("foobar")} - b := make([]byte, 3) - n, err := r.Read(b) - Expect(n).To(Equal(3)) - Expect(err).ToNot(HaveOccurred()) - Expect(string(b)).To(Equal("foo")) - n, err = r.Read(b) - Expect(n).To(Equal(3)) - Expect(err).To(MatchError(io.EOF)) - Expect(string(b)).To(Equal("bar")) - n, err = r.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(BeZero()) - }) - - It("correctly handles io.EOF", func() { - buf := &bytes.Buffer{} - Write(buf, 1337) - - r := NewReader(&eofReader{Data: buf.Bytes()}) - n, err := Read(r) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(BeEquivalentTo(1337)) - }) - }) - }) - - Context("Writer", func() { - Context("NewWriter", func() { - It("passes through a Writer unchanged", func() { - b := &bytes.Buffer{} - w := NewWriter(b) - Expect(w).To(Equal(b)) - }) - - It("wraps an io.Writer", func() { - n := &nopWriter{} - w := NewWriter(n) - Expect(w).ToNot(Equal(n)) - }) - }) - - It("returns an error when writing to an underlying io.Writer fails", func() { - w := NewWriter(&nopWriter{}) - err := w.WriteByte(0) - Expect(err).To(Equal(io.ErrShortBuffer)) - }) - }) -}) diff --git a/internal/quic-go/quicvarint/quicvarint_suite_test.go b/internal/quic-go/quicvarint/quicvarint_suite_test.go deleted file mode 100644 index b7b17de7..00000000 --- a/internal/quic-go/quicvarint/quicvarint_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package quicvarint_test - -import ( - "testing" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestQuicVarint(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "QUIC Varint Suite") -} diff --git a/internal/quic-go/quicvarint/varint.go b/internal/quic-go/quicvarint/varint.go index ba7e8772..e4040841 100644 --- a/internal/quic-go/quicvarint/varint.go +++ b/internal/quic-go/quicvarint/varint.go @@ -3,18 +3,10 @@ package quicvarint import ( "fmt" "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" ) // taken from the QUIC draft const ( - // Min is the minimum value allowed for a QUIC varint. - Min = 0 - - // Max is the maximum allowed value for a QUIC varint (2^62-1). - Max = maxVarInt8 - maxVarInt1 = 63 maxVarInt2 = 16383 maxVarInt4 = 1073741823 @@ -88,36 +80,8 @@ func Write(w Writer, i uint64) { } } -// WriteWithLen writes i in the QUIC varint format with the desired length to w. -func WriteWithLen(w Writer, i uint64, length protocol.ByteCount) { - if length != 1 && length != 2 && length != 4 && length != 8 { - panic("invalid varint length") - } - l := Len(i) - if l == length { - Write(w, i) - return - } - if l > length { - panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length)) - } - if length == 2 { - w.WriteByte(0b01000000) - } else if length == 4 { - w.WriteByte(0b10000000) - } else if length == 8 { - w.WriteByte(0b11000000) - } - for j := protocol.ByteCount(1); j < length-l; j++ { - w.WriteByte(0) - } - for j := protocol.ByteCount(0); j < l; j++ { - w.WriteByte(uint8(i >> (8 * (l - 1 - j)))) - } -} - // Len determines the number of bytes that will be needed to write the number i. -func Len(i uint64) protocol.ByteCount { +func Len(i uint64) uint64 { if i <= maxVarInt1 { return 1 } diff --git a/internal/quic-go/quicvarint/varint_test.go b/internal/quic-go/quicvarint/varint_test.go deleted file mode 100644 index acf4a31c..00000000 --- a/internal/quic-go/quicvarint/varint_test.go +++ /dev/null @@ -1,221 +0,0 @@ -package quicvarint - -import ( - "bytes" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Varint encoding / decoding", func() { - Context("limits", func() { - Specify("Min == 0", func() { - Expect(Min).To(Equal(0)) - }) - - Specify("Max == 2^62-1", func() { - Expect(uint64(Max)).To(Equal(uint64(1<<62 - 1))) - }) - }) - - Context("decoding", func() { - It("reads a 1 byte number", func() { - b := bytes.NewReader([]byte{0b00011001}) - val, err := Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint64(25))) - Expect(b.Len()).To(BeZero()) - }) - - It("reads a number that is encoded too long", func() { - b := bytes.NewReader([]byte{0b01000000, 0x25}) - val, err := Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint64(37))) - Expect(b.Len()).To(BeZero()) - }) - - It("reads a 2 byte number", func() { - b := bytes.NewReader([]byte{0b01111011, 0xbd}) - val, err := Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint64(15293))) - Expect(b.Len()).To(BeZero()) - }) - - It("reads a 4 byte number", func() { - b := bytes.NewReader([]byte{0b10011101, 0x7f, 0x3e, 0x7d}) - val, err := Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint64(494878333))) - Expect(b.Len()).To(BeZero()) - }) - - It("reads an 8 byte number", func() { - b := bytes.NewReader([]byte{0b11000010, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c}) - val, err := Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint64(151288809941952652))) - Expect(b.Len()).To(BeZero()) - }) - }) - - Context("encoding", func() { - Context("with minimal length", func() { - It("writes a 1 byte number", func() { - b := &bytes.Buffer{} - Write(b, 37) - Expect(b.Bytes()).To(Equal([]byte{0x25})) - }) - - It("writes the maximum 1 byte number in 1 byte", func() { - b := &bytes.Buffer{} - Write(b, maxVarInt1) - Expect(b.Bytes()).To(Equal([]byte{0b00111111})) - }) - - It("writes the minimum 2 byte number in 2 bytes", func() { - b := &bytes.Buffer{} - Write(b, maxVarInt1+1) - Expect(b.Bytes()).To(Equal([]byte{0x40, maxVarInt1 + 1})) - }) - - It("writes a 2 byte number", func() { - b := &bytes.Buffer{} - Write(b, 15293) - Expect(b.Bytes()).To(Equal([]byte{0b01000000 ^ 0x3b, 0xbd})) - }) - - It("writes the maximum 2 byte number in 2 bytes", func() { - b := &bytes.Buffer{} - Write(b, maxVarInt2) - Expect(b.Bytes()).To(Equal([]byte{0b01111111, 0xff})) - }) - - It("writes the minimum 4 byte number in 4 bytes", func() { - b := &bytes.Buffer{} - Write(b, maxVarInt2+1) - Expect(b.Len()).To(Equal(4)) - num, err := Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(num).To(Equal(uint64(maxVarInt2 + 1))) - }) - - It("writes a 4 byte number", func() { - b := &bytes.Buffer{} - Write(b, 494878333) - Expect(b.Bytes()).To(Equal([]byte{0b10000000 ^ 0x1d, 0x7f, 0x3e, 0x7d})) - }) - - It("writes the maximum 4 byte number in 4 bytes", func() { - b := &bytes.Buffer{} - Write(b, maxVarInt4) - Expect(b.Bytes()).To(Equal([]byte{0b10111111, 0xff, 0xff, 0xff})) - }) - - It("writes the minimum 8 byte number in 8 bytes", func() { - b := &bytes.Buffer{} - Write(b, maxVarInt4+1) - Expect(b.Len()).To(Equal(8)) - num, err := Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(num).To(Equal(uint64(maxVarInt4 + 1))) - }) - - It("writes an 8 byte number", func() { - b := &bytes.Buffer{} - Write(b, 151288809941952652) - Expect(b.Bytes()).To(Equal([]byte{0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c})) - }) - - It("writes the maximum 8 byte number in 8 bytes", func() { - b := &bytes.Buffer{} - Write(b, maxVarInt8) - Expect(b.Bytes()).To(Equal([]byte{0xff /* 11111111 */, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})) - }) - - It("panics when given a too large number (> 62 bit)", func() { - Expect(func() { Write(&bytes.Buffer{}, maxVarInt8+1) }).Should(Panic()) - }) - }) - - Context("with fixed length", func() { - It("panics when given an invalid length", func() { - Expect(func() { WriteWithLen(&bytes.Buffer{}, 25, 3) }).Should(Panic()) - }) - - It("panics when given a too short length", func() { - Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt1+1, 1) }).Should(Panic()) - Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt2+1, 2) }).Should(Panic()) - Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt4+1, 4) }).Should(Panic()) - }) - - It("writes a 1-byte number in minimal encoding", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 37, 1) - Expect(b.Bytes()).To(Equal([]byte{0x25})) - }) - - It("writes a 1-byte number in 2 bytes", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 37, 2) - Expect(b.Bytes()).To(Equal([]byte{0b01000000, 0x25})) - Expect(Read(b)).To(BeEquivalentTo(37)) - }) - - It("writes a 1-byte number in 4 bytes", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 37, 4) - Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0, 0x25})) - Expect(Read(b)).To(BeEquivalentTo(37)) - }) - - It("writes a 1-byte number in 8 bytes", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 37, 8) - Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0, 0, 0, 0x25})) - Expect(Read(b)).To(BeEquivalentTo(37)) - }) - - It("writes a 2-byte number in 4 bytes", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 15293, 4) - Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0x3b, 0xbd})) - Expect(Read(b)).To(BeEquivalentTo(15293)) - }) - - It("write a 4-byte number in 8 bytes", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 494878333, 8) - Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0x1d, 0x7f, 0x3e, 0x7d})) - Expect(Read(b)).To(BeEquivalentTo(494878333)) - }) - }) - }) - - Context("determining the length needed for encoding", func() { - It("for numbers that need 1 byte", func() { - Expect(Len(0)).To(BeEquivalentTo(1)) - Expect(Len(maxVarInt1)).To(BeEquivalentTo(1)) - }) - - It("for numbers that need 2 bytes", func() { - Expect(Len(maxVarInt1 + 1)).To(BeEquivalentTo(2)) - Expect(Len(maxVarInt2)).To(BeEquivalentTo(2)) - }) - - It("for numbers that need 4 bytes", func() { - Expect(Len(maxVarInt2 + 1)).To(BeEquivalentTo(4)) - Expect(Len(maxVarInt4)).To(BeEquivalentTo(4)) - }) - - It("for numbers that need 8 bytes", func() { - Expect(Len(maxVarInt4 + 1)).To(BeEquivalentTo(8)) - Expect(Len(maxVarInt8)).To(BeEquivalentTo(8)) - }) - - It("panics when given a too large number (> 62 bit)", func() { - Expect(func() { Len(maxVarInt8 + 1) }).Should(Panic()) - }) - }) -}) diff --git a/internal/quic-go/receive_stream.go b/internal/quic-go/receive_stream.go deleted file mode 100644 index 6fea6afd..00000000 --- a/internal/quic-go/receive_stream.go +++ /dev/null @@ -1,331 +0,0 @@ -package quic - -import ( - "fmt" - "io" - "sync" - "time" - - "github.com/imroc/req/v3/internal/quic-go/flowcontrol" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type receiveStreamI interface { - ReceiveStream - - handleStreamFrame(*wire.StreamFrame) error - handleResetStreamFrame(*wire.ResetStreamFrame) error - closeForShutdown(error) - getWindowUpdate() protocol.ByteCount -} - -type receiveStream struct { - mutex sync.Mutex - - streamID protocol.StreamID - - sender streamSender - - frameQueue *frameSorter - finalOffset protocol.ByteCount - - currentFrame []byte - currentFrameDone func() - currentFrameIsLast bool // is the currentFrame the last frame on this stream - readPosInFrame int - - closeForShutdownErr error - cancelReadErr error - resetRemotelyErr *StreamError - - closedForShutdown bool // set when CloseForShutdown() is called - finRead bool // set once we read a frame with a Fin - canceledRead bool // set when CancelRead() is called - resetRemotely bool // set when HandleResetStreamFrame() is called - - readChan chan struct{} - readOnce chan struct{} // cap: 1, to protect against concurrent use of Read - deadline time.Time - - flowController flowcontrol.StreamFlowController - version protocol.VersionNumber -} - -var ( - _ ReceiveStream = &receiveStream{} - _ receiveStreamI = &receiveStream{} -) - -func newReceiveStream( - streamID protocol.StreamID, - sender streamSender, - flowController flowcontrol.StreamFlowController, - version protocol.VersionNumber, -) *receiveStream { - return &receiveStream{ - streamID: streamID, - sender: sender, - flowController: flowController, - frameQueue: newFrameSorter(), - readChan: make(chan struct{}, 1), - readOnce: make(chan struct{}, 1), - finalOffset: protocol.MaxByteCount, - version: version, - } -} - -func (s *receiveStream) StreamID() protocol.StreamID { - return s.streamID -} - -// Read implements io.Reader. It is not thread safe! -func (s *receiveStream) Read(p []byte) (int, error) { - // Concurrent use of Read is not permitted (and doesn't make any sense), - // but sometimes people do it anyway. - // Make sure that we only execute one call at any given time to avoid hard to debug failures. - s.readOnce <- struct{}{} - defer func() { <-s.readOnce }() - - s.mutex.Lock() - completed, n, err := s.readImpl(p) - s.mutex.Unlock() - - if completed { - s.sender.onStreamCompleted(s.streamID) - } - return n, err -} - -func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, error) { - if s.finRead { - return false, 0, io.EOF - } - if s.canceledRead { - return false, 0, s.cancelReadErr - } - if s.resetRemotely { - return false, 0, s.resetRemotelyErr - } - if s.closedForShutdown { - return false, 0, s.closeForShutdownErr - } - - var bytesRead int - var deadlineTimer *utils.Timer - for bytesRead < len(p) { - if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) { - s.dequeueNextFrame() - } - if s.currentFrame == nil && bytesRead > 0 { - return false, bytesRead, s.closeForShutdownErr - } - - for { - // Stop waiting on errors - if s.closedForShutdown { - return false, bytesRead, s.closeForShutdownErr - } - if s.canceledRead { - return false, bytesRead, s.cancelReadErr - } - if s.resetRemotely { - return false, bytesRead, s.resetRemotelyErr - } - - deadline := s.deadline - if !deadline.IsZero() { - if !time.Now().Before(deadline) { - return false, bytesRead, errDeadline - } - if deadlineTimer == nil { - deadlineTimer = utils.NewTimer() - defer deadlineTimer.Stop() - } - deadlineTimer.Reset(deadline) - } - - if s.currentFrame != nil || s.currentFrameIsLast { - break - } - - s.mutex.Unlock() - if deadline.IsZero() { - <-s.readChan - } else { - select { - case <-s.readChan: - case <-deadlineTimer.Chan(): - deadlineTimer.SetRead() - } - } - s.mutex.Lock() - if s.currentFrame == nil { - s.dequeueNextFrame() - } - } - - if bytesRead > len(p) { - return false, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) - } - if s.readPosInFrame > len(s.currentFrame) { - return false, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) - } - - m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:]) - s.readPosInFrame += m - bytesRead += m - - // when a RESET_STREAM was received, the was already informed about the final byteOffset for this stream - if !s.resetRemotely { - s.flowController.AddBytesRead(protocol.ByteCount(m)) - } - - if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast { - s.finRead = true - return true, bytesRead, io.EOF - } - } - return false, bytesRead, nil -} - -func (s *receiveStream) dequeueNextFrame() { - var offset protocol.ByteCount - // We're done with the last frame. Release the buffer. - if s.currentFrameDone != nil { - s.currentFrameDone() - } - offset, s.currentFrame, s.currentFrameDone = s.frameQueue.Pop() - s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset - s.readPosInFrame = 0 -} - -func (s *receiveStream) CancelRead(errorCode StreamErrorCode) { - s.mutex.Lock() - completed := s.cancelReadImpl(errorCode) - s.mutex.Unlock() - - if completed { - s.flowController.Abandon() - s.sender.onStreamCompleted(s.streamID) - } -} - -func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ { - if s.finRead || s.canceledRead || s.resetRemotely { - return false - } - s.canceledRead = true - s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) - s.signalRead() - s.sender.queueControlFrame(&wire.StopSendingFrame{ - StreamID: s.streamID, - ErrorCode: errorCode, - }) - // We're done with this stream if the final offset was already received. - return s.finalOffset != protocol.MaxByteCount -} - -func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { - s.mutex.Lock() - completed, err := s.handleStreamFrameImpl(frame) - s.mutex.Unlock() - - if completed { - s.flowController.Abandon() - s.sender.onStreamCompleted(s.streamID) - } - return err -} - -func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* completed */, error) { - maxOffset := frame.Offset + frame.DataLen() - if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil { - return false, err - } - var newlyRcvdFinalOffset bool - if frame.Fin { - newlyRcvdFinalOffset = s.finalOffset == protocol.MaxByteCount - s.finalOffset = maxOffset - } - if s.canceledRead { - return newlyRcvdFinalOffset, nil - } - if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil { - return false, err - } - s.signalRead() - return false, nil -} - -func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { - s.mutex.Lock() - completed, err := s.handleResetStreamFrameImpl(frame) - s.mutex.Unlock() - - if completed { - s.flowController.Abandon() - s.sender.onStreamCompleted(s.streamID) - } - return err -} - -func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) (bool /*completed */, error) { - if s.closedForShutdown { - return false, nil - } - if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil { - return false, err - } - newlyRcvdFinalOffset := s.finalOffset == protocol.MaxByteCount - s.finalOffset = frame.FinalSize - - // ignore duplicate RESET_STREAM frames for this stream (after checking their final offset) - if s.resetRemotely { - return false, nil - } - s.resetRemotely = true - s.resetRemotelyErr = &StreamError{ - StreamID: s.streamID, - ErrorCode: frame.ErrorCode, - } - s.signalRead() - return newlyRcvdFinalOffset, nil -} - -func (s *receiveStream) CloseRemote(offset protocol.ByteCount) { - s.handleStreamFrame(&wire.StreamFrame{Fin: true, Offset: offset}) -} - -func (s *receiveStream) SetReadDeadline(t time.Time) error { - s.mutex.Lock() - s.deadline = t - s.mutex.Unlock() - s.signalRead() - return nil -} - -// CloseForShutdown closes a stream abruptly. -// It makes Read unblock (and return the error) immediately. -// The peer will NOT be informed about this: the stream is closed without sending a FIN or RESET. -func (s *receiveStream) closeForShutdown(err error) { - s.mutex.Lock() - s.closedForShutdown = true - s.closeForShutdownErr = err - s.mutex.Unlock() - s.signalRead() -} - -func (s *receiveStream) getWindowUpdate() protocol.ByteCount { - return s.flowController.GetWindowUpdate() -} - -// signalRead performs a non-blocking send on the readChan -func (s *receiveStream) signalRead() { - select { - case s.readChan <- struct{}{}: - default: - } -} diff --git a/internal/quic-go/receive_stream_test.go b/internal/quic-go/receive_stream_test.go deleted file mode 100644 index 06b30ef9..00000000 --- a/internal/quic-go/receive_stream_test.go +++ /dev/null @@ -1,696 +0,0 @@ -package quic - -import ( - "errors" - "io" - "runtime" - "sync" - "sync/atomic" - "time" - - "github.com/golang/mock/gomock" - "github.com/imroc/req/v3/internal/quic-go/mocks" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "github.com/onsi/gomega/gbytes" -) - -var _ = Describe("Receive Stream", func() { - const streamID protocol.StreamID = 1337 - - var ( - str *receiveStream - strWithTimeout io.Reader // str wrapped with gbytes.TimeoutReader - mockFC *mocks.MockStreamFlowController - mockSender *MockStreamSender - ) - - BeforeEach(func() { - mockSender = NewMockStreamSender(mockCtrl) - mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newReceiveStream(streamID, mockSender, mockFC, protocol.VersionWhatever) - - timeout := scaleDuration(250 * time.Millisecond) - strWithTimeout = gbytes.TimeoutReader(str, timeout) - }) - - It("gets stream id", func() { - Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) - }) - - Context("reading", func() { - It("reads a single STREAM frame", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - err := str.handleStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) - - It("reads a single STREAM frame in multiple goes", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - err := str.handleStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 2) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - Expect(b).To(Equal([]byte{0xDE, 0xAD})) - n, err = strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - Expect(b).To(Equal([]byte{0xBE, 0xEF})) - }) - - It("reads all data available", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - frame2 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - err := str.handleStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.handleStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x00})) - }) - - It("assembles multiple STREAM frames", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - frame2 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - err := str.handleStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.handleStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) - - It("waits until data is available", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - go func() { - defer GinkgoRecover() - frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}} - time.Sleep(10 * time.Millisecond) - err := str.handleStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - }() - b := make([]byte, 2) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - }) - - It("handles STREAM frames in wrong order", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - frame2 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - err := str.handleStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.handleStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) - - It("ignores duplicate STREAM frames", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - frame2 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0x13, 0x37}, - } - frame3 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - err := str.handleStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.handleStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - err = str.handleStreamFrame(&frame3) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) - - It("doesn't rejects a STREAM frames with an overlapping data range", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte("foob"), - } - frame2 := wire.StreamFrame{ - Offset: 2, - Data: []byte("obar"), - } - err := str.handleStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.handleStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b).To(Equal([]byte("foobar"))) - }) - - Context("deadlines", func() { - It("the deadline error has the right net.Error properties", func() { - Expect(errDeadline.Timeout()).To(BeTrue()) - Expect(errDeadline).To(MatchError("deadline exceeded")) - }) - - It("returns an error when Read is called after the deadline", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes() - f := &wire.StreamFrame{Data: []byte("foobar")} - err := str.handleStreamFrame(f) - Expect(err).ToNot(HaveOccurred()) - str.SetReadDeadline(time.Now().Add(-time.Second)) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) - - It("unblocks when the deadline is changed to the past", func() { - str.SetReadDeadline(time.Now().Add(time.Hour)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Read(make([]byte, 6)) - Expect(err).To(MatchError(errDeadline)) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - str.SetReadDeadline(time.Now().Add(-time.Hour)) - Eventually(done).Should(BeClosed()) - }) - - It("unblocks after the deadline", func() { - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetReadDeadline(deadline) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) - }) - - It("doesn't unblock if the deadline is changed before the first one expires", func() { - deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) - str.SetReadDeadline(deadline1) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(20 * time.Millisecond)) - str.SetReadDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline1)) - }() - runtime.Gosched() - b := make([]byte, 10) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) - }) - - It("unblocks earlier, when a new deadline is set", func() { - deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(10 * time.Millisecond)) - str.SetReadDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline2)) - }() - str.SetReadDeadline(deadline1) - runtime.Gosched() - b := make([]byte, 10) - _, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(25*time.Millisecond))) - }) - - It("doesn't unblock if the deadline is removed", func() { - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetReadDeadline(deadline) - deadlineUnset := make(chan struct{}) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(20 * time.Millisecond)) - str.SetReadDeadline(time.Time{}) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline)) - close(deadlineUnset) - }() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Read(make([]byte, 1)) - Expect(err).To(MatchError("test done")) - close(done) - }() - runtime.Gosched() - Eventually(deadlineUnset).Should(BeClosed()) - Consistently(done, scaleDuration(100*time.Millisecond)).ShouldNot(BeClosed()) - // make the go routine return - str.closeForShutdown(errors.New("test done")) - Eventually(done).Should(BeClosed()) - }) - }) - - Context("closing", func() { - Context("with FIN bit", func() { - It("returns EOFs", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - str.handleStreamFrame(&wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - Fin: true, - }) - mockSender.EXPECT().onStreamCompleted(streamID) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - n, err = strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - - It("handles out-of-order frames", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - Fin: true, - } - frame2 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - err := str.handleStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.handleStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().onStreamCompleted(streamID) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - n, err = strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - - It("returns EOFs with partial read", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - err := str.handleStreamFrame(&wire.StreamFrame{ - Offset: 0, - Data: []byte{0xde, 0xad}, - Fin: true, - }) - Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().onStreamCompleted(streamID) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(2)) - Expect(b[:n]).To(Equal([]byte{0xde, 0xad})) - }) - - It("handles immediate FINs", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - err := str.handleStreamFrame(&wire.StreamFrame{ - Offset: 0, - Fin: true, - }) - Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().onStreamCompleted(streamID) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - - // Calling Read concurrently doesn't make any sense (and is forbidden), - // but we still want to make sure that we don't complete the stream more than once - // if the user misuses our API. - // This would lead to an INTERNAL_ERROR ("tried to delete unknown outgoing stream"), - // which can be hard to debug. - // Note that even without the protection built into the receiveStream, this test - // is very timing-dependent, and would need to run a few hundred times to trigger the failure. - It("handles concurrent reads", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any()).AnyTimes() - var bytesRead protocol.ByteCount - mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) { bytesRead += n }).AnyTimes() - - var numCompleted int32 - mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { - atomic.AddInt32(&numCompleted, 1) - }).AnyTimes() - const num = 3 - var wg sync.WaitGroup - wg.Add(num) - for i := 0; i < num; i++ { - go func() { - defer wg.Done() - defer GinkgoRecover() - _, err := str.Read(make([]byte, 8)) - Expect(err).To(MatchError(io.EOF)) - }() - } - str.handleStreamFrame(&wire.StreamFrame{ - Offset: 0, - Data: []byte("foobar"), - Fin: true, - }) - wg.Wait() - Expect(bytesRead).To(BeEquivalentTo(6)) - Expect(atomic.LoadInt32(&numCompleted)).To(BeEquivalentTo(1)) - }) - }) - - It("closes when CloseRemote is called", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - str.CloseRemote(0) - mockSender.EXPECT().onStreamCompleted(streamID) - b := make([]byte, 8) - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - }) - - Context("closing for shutdown", func() { - testErr := errors.New("test error") - - It("immediately returns all reads", func() { - done := make(chan struct{}) - b := make([]byte, 4) - go func() { - defer GinkgoRecover() - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - str.closeForShutdown(testErr) - Eventually(done).Should(BeClosed()) - }) - - It("errors for all following reads", func() { - str.closeForShutdown(testErr) - b := make([]byte, 1) - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - }) - }) - }) - - Context("stream cancelations", func() { - Context("canceling read", func() { - It("unblocks Read", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - str.CancelRead(1234) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't allow further calls to Read", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - str.CancelRead(1234) - _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) - }) - - It("does nothing when CancelRead is called twice", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - str.CancelRead(1234) - str.CancelRead(1234) - _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) - }) - - It("queues a STOP_SENDING frame", func() { - mockSender.EXPECT().queueControlFrame(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 1234, - }) - str.CancelRead(1234) - }) - - It("doesn't send a STOP_SENDING frame, if the FIN was already read", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) - // no calls to mockSender.queueControlFrame - Expect(str.handleStreamFrame(&wire.StreamFrame{ - StreamID: streamID, - Data: []byte("foobar"), - Fin: true, - })).To(Succeed()) - mockSender.EXPECT().onStreamCompleted(streamID) - _, err := strWithTimeout.Read(make([]byte, 100)) - Expect(err).To(MatchError(io.EOF)) - str.CancelRead(1234) - }) - - It("doesn't send a STOP_SENDING frame, if the stream was already reset", func() { - gomock.InOrder( - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true), - mockFC.EXPECT().Abandon(), - ) - mockSender.EXPECT().onStreamCompleted(streamID) - Expect(str.handleResetStreamFrame(&wire.ResetStreamFrame{ - StreamID: streamID, - FinalSize: 42, - })).To(Succeed()) - str.CancelRead(1234) - }) - - It("sends a STOP_SENDING and completes the stream after receiving the final offset", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 1000, - Fin: true, - })).To(Succeed()) - mockFC.EXPECT().Abandon() - mockSender.EXPECT().queueControlFrame(gomock.Any()) - mockSender.EXPECT().onStreamCompleted(streamID) - str.CancelRead(1234) - }) - - It("completes the stream when receiving the Fin after the stream was canceled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - str.CancelRead(1234) - gomock.InOrder( - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true), - mockFC.EXPECT().Abandon(), - ) - mockSender.EXPECT().onStreamCompleted(streamID) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 1000, - Fin: true, - })).To(Succeed()) - }) - - It("handles duplicate FinBits after the stream was canceled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - str.CancelRead(1234) - gomock.InOrder( - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true), - mockFC.EXPECT().Abandon(), - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true), - ) - mockSender.EXPECT().onStreamCompleted(streamID) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 1000, - Fin: true, - })).To(Succeed()) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 1000, - Fin: true, - })).To(Succeed()) - }) - }) - - Context("receiving RESET_STREAM frames", func() { - rst := &wire.ResetStreamFrame{ - StreamID: streamID, - FinalSize: 42, - ErrorCode: 1234, - } - - It("unblocks Read", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - })) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - mockSender.EXPECT().onStreamCompleted(streamID) - gomock.InOrder( - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true), - mockFC.EXPECT().Abandon(), - ) - Expect(str.handleResetStreamFrame(rst)).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't allow further calls to Read", func() { - mockSender.EXPECT().onStreamCompleted(streamID) - gomock.InOrder( - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true), - mockFC.EXPECT().Abandon(), - ) - Expect(str.handleResetStreamFrame(rst)).To(Succeed()) - _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - })) - }) - - It("errors when receiving a RESET_STREAM with an inconsistent offset", func() { - testErr := errors.New("already received a different final offset before") - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Return(testErr) - err := str.handleResetStreamFrame(rst) - Expect(err).To(MatchError(testErr)) - }) - - It("ignores duplicate RESET_STREAM frames", func() { - mockSender.EXPECT().onStreamCompleted(streamID) - mockFC.EXPECT().Abandon() - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2) - Expect(str.handleResetStreamFrame(rst)).To(Succeed()) - Expect(str.handleResetStreamFrame(rst)).To(Succeed()) - }) - - It("doesn't call onStreamCompleted again when the final offset was already received via Fin", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - str.CancelRead(1234) - mockSender.EXPECT().onStreamCompleted(streamID) - mockFC.EXPECT().Abandon() - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - StreamID: streamID, - Offset: rst.FinalSize, - Fin: true, - })).To(Succeed()) - Expect(str.handleResetStreamFrame(rst)).To(Succeed()) - }) - - It("doesn't do anyting when it was closed for shutdown", func() { - str.closeForShutdown(nil) - err := str.handleResetStreamFrame(rst) - Expect(err).ToNot(HaveOccurred()) - }) - }) - }) - - Context("flow control", func() { - It("errors when a STREAM frame causes a flow control violation", func() { - testErr := errors.New("flow control violation") - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false).Return(testErr) - frame := wire.StreamFrame{ - Offset: 2, - Data: []byte("foobar"), - } - err := str.handleStreamFrame(&frame) - Expect(err).To(MatchError(testErr)) - }) - - It("gets a window update", func() { - mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x100)) - Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100))) - }) - }) -}) diff --git a/internal/quic-go/retransmission_queue.go b/internal/quic-go/retransmission_queue.go deleted file mode 100644 index 57d54e5f..00000000 --- a/internal/quic-go/retransmission_queue.go +++ /dev/null @@ -1,131 +0,0 @@ -package quic - -import ( - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type retransmissionQueue struct { - initial []wire.Frame - initialCryptoData []*wire.CryptoFrame - - handshake []wire.Frame - handshakeCryptoData []*wire.CryptoFrame - - appData []wire.Frame - - version protocol.VersionNumber -} - -func newRetransmissionQueue(ver protocol.VersionNumber) *retransmissionQueue { - return &retransmissionQueue{version: ver} -} - -func (q *retransmissionQueue) AddInitial(f wire.Frame) { - if cf, ok := f.(*wire.CryptoFrame); ok { - q.initialCryptoData = append(q.initialCryptoData, cf) - return - } - q.initial = append(q.initial, f) -} - -func (q *retransmissionQueue) AddHandshake(f wire.Frame) { - if cf, ok := f.(*wire.CryptoFrame); ok { - q.handshakeCryptoData = append(q.handshakeCryptoData, cf) - return - } - q.handshake = append(q.handshake, f) -} - -func (q *retransmissionQueue) HasInitialData() bool { - return len(q.initialCryptoData) > 0 || len(q.initial) > 0 -} - -func (q *retransmissionQueue) HasHandshakeData() bool { - return len(q.handshakeCryptoData) > 0 || len(q.handshake) > 0 -} - -func (q *retransmissionQueue) HasAppData() bool { - return len(q.appData) > 0 -} - -func (q *retransmissionQueue) AddAppData(f wire.Frame) { - if _, ok := f.(*wire.StreamFrame); ok { - panic("STREAM frames are handled with their respective streams.") - } - q.appData = append(q.appData, f) -} - -func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount) wire.Frame { - if len(q.initialCryptoData) > 0 { - f := q.initialCryptoData[0] - newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) - if newFrame == nil && !needsSplit { // the whole frame fits - q.initialCryptoData = q.initialCryptoData[1:] - return f - } - if newFrame != nil { // frame was split. Leave the original frame in the queue. - return newFrame - } - } - if len(q.initial) == 0 { - return nil - } - f := q.initial[0] - if f.Length(q.version) > maxLen { - return nil - } - q.initial = q.initial[1:] - return f -} - -func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount) wire.Frame { - if len(q.handshakeCryptoData) > 0 { - f := q.handshakeCryptoData[0] - newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) - if newFrame == nil && !needsSplit { // the whole frame fits - q.handshakeCryptoData = q.handshakeCryptoData[1:] - return f - } - if newFrame != nil { // frame was split. Leave the original frame in the queue. - return newFrame - } - } - if len(q.handshake) == 0 { - return nil - } - f := q.handshake[0] - if f.Length(q.version) > maxLen { - return nil - } - q.handshake = q.handshake[1:] - return f -} - -func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount) wire.Frame { - if len(q.appData) == 0 { - return nil - } - f := q.appData[0] - if f.Length(q.version) > maxLen { - return nil - } - q.appData = q.appData[1:] - return f -} - -func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) { - //nolint:exhaustive // Can only drop Initial and Handshake packet number space. - switch encLevel { - case protocol.EncryptionInitial: - q.initial = nil - q.initialCryptoData = nil - case protocol.EncryptionHandshake: - q.handshake = nil - q.handshakeCryptoData = nil - default: - panic(fmt.Sprintf("unexpected encryption level: %s", encLevel)) - } -} diff --git a/internal/quic-go/retransmission_queue_test.go b/internal/quic-go/retransmission_queue_test.go deleted file mode 100644 index 4780571f..00000000 --- a/internal/quic-go/retransmission_queue_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package quic - -import ( - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Retransmission queue", func() { - const version = protocol.VersionTLS - - var q *retransmissionQueue - - BeforeEach(func() { - q = newRetransmissionQueue(version) - }) - - Context("Initial data", func() { - It("doesn't dequeue anything when it's empty", func() { - Expect(q.HasInitialData()).To(BeFalse()) - Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil()) - }) - - It("queues and retrieves a control frame", func() { - f := &wire.MaxDataFrame{MaximumData: 0x42} - q.AddInitial(f) - Expect(q.HasInitialData()).To(BeTrue()) - Expect(q.GetInitialFrame(f.Length(version) - 1)).To(BeNil()) - Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f)) - Expect(q.HasInitialData()).To(BeFalse()) - }) - - It("queues and retrieves a CRYPTO frame", func() { - f := &wire.CryptoFrame{Data: []byte("foobar")} - q.AddInitial(f) - Expect(q.HasInitialData()).To(BeTrue()) - Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f)) - Expect(q.HasInitialData()).To(BeFalse()) - }) - - It("returns split CRYPTO frames", func() { - f := &wire.CryptoFrame{ - Offset: 100, - Data: []byte("foobar"), - } - q.AddInitial(f) - Expect(q.HasInitialData()).To(BeTrue()) - f1 := q.GetInitialFrame(f.Length(version) - 3) - Expect(f1).ToNot(BeNil()) - Expect(f1).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) - Expect(f1.(*wire.CryptoFrame).Data).To(Equal([]byte("foo"))) - Expect(f1.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(100))) - Expect(q.HasInitialData()).To(BeTrue()) - f2 := q.GetInitialFrame(protocol.MaxByteCount) - Expect(f2).ToNot(BeNil()) - Expect(f2).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) - Expect(f2.(*wire.CryptoFrame).Data).To(Equal([]byte("bar"))) - Expect(f2.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(103))) - Expect(q.HasInitialData()).To(BeFalse()) - }) - - It("returns other frames when a CRYPTO frame wouldn't fit", func() { - f := &wire.CryptoFrame{Data: []byte("foobar")} - q.AddInitial(f) - q.AddInitial(&wire.PingFrame{}) - f1 := q.GetInitialFrame(2) // too small for a CRYPTO frame - Expect(f1).ToNot(BeNil()) - Expect(f1).To(BeAssignableToTypeOf(&wire.PingFrame{})) - Expect(q.HasInitialData()).To(BeTrue()) - f2 := q.GetInitialFrame(protocol.MaxByteCount) - Expect(f2).To(Equal(f)) - }) - - It("retrieves both a CRYPTO frame and a control frame", func() { - cf := &wire.MaxDataFrame{MaximumData: 0x42} - f := &wire.CryptoFrame{Data: []byte("foobar")} - q.AddInitial(f) - q.AddInitial(cf) - Expect(q.HasInitialData()).To(BeTrue()) - Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(f)) - Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(cf)) - Expect(q.HasInitialData()).To(BeFalse()) - }) - - It("drops all Initial frames", func() { - q.AddInitial(&wire.CryptoFrame{Data: []byte("foobar")}) - q.AddInitial(&wire.MaxDataFrame{MaximumData: 0x42}) - q.DropPackets(protocol.EncryptionInitial) - Expect(q.HasInitialData()).To(BeFalse()) - Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil()) - }) - }) - - Context("Handshake data", func() { - It("doesn't dequeue anything when it's empty", func() { - Expect(q.HasHandshakeData()).To(BeFalse()) - Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil()) - }) - - It("queues and retrieves a control frame", func() { - f := &wire.MaxDataFrame{MaximumData: 0x42} - q.AddHandshake(f) - Expect(q.HasHandshakeData()).To(BeTrue()) - Expect(q.GetHandshakeFrame(f.Length(version) - 1)).To(BeNil()) - Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f)) - Expect(q.HasHandshakeData()).To(BeFalse()) - }) - - It("queues and retrieves a CRYPTO frame", func() { - f := &wire.CryptoFrame{Data: []byte("foobar")} - q.AddHandshake(f) - Expect(q.HasHandshakeData()).To(BeTrue()) - Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f)) - Expect(q.HasHandshakeData()).To(BeFalse()) - }) - - It("returns split CRYPTO frames", func() { - f := &wire.CryptoFrame{ - Offset: 100, - Data: []byte("foobar"), - } - q.AddHandshake(f) - Expect(q.HasHandshakeData()).To(BeTrue()) - f1 := q.GetHandshakeFrame(f.Length(version) - 3) - Expect(f1).ToNot(BeNil()) - Expect(f1).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) - Expect(f1.(*wire.CryptoFrame).Data).To(Equal([]byte("foo"))) - Expect(f1.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(100))) - Expect(q.HasHandshakeData()).To(BeTrue()) - f2 := q.GetHandshakeFrame(protocol.MaxByteCount) - Expect(f2).ToNot(BeNil()) - Expect(f2).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) - Expect(f2.(*wire.CryptoFrame).Data).To(Equal([]byte("bar"))) - Expect(f2.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(103))) - Expect(q.HasHandshakeData()).To(BeFalse()) - }) - - It("returns other frames when a CRYPTO frame wouldn't fit", func() { - f := &wire.CryptoFrame{Data: []byte("foobar")} - q.AddHandshake(f) - q.AddHandshake(&wire.PingFrame{}) - f1 := q.GetHandshakeFrame(2) // too small for a CRYPTO frame - Expect(f1).ToNot(BeNil()) - Expect(f1).To(BeAssignableToTypeOf(&wire.PingFrame{})) - Expect(q.HasHandshakeData()).To(BeTrue()) - f2 := q.GetHandshakeFrame(protocol.MaxByteCount) - Expect(f2).To(Equal(f)) - }) - - It("retrieves both a CRYPTO frame and a control frame", func() { - cf := &wire.MaxDataFrame{MaximumData: 0x42} - f := &wire.CryptoFrame{Data: []byte("foobar")} - q.AddHandshake(f) - q.AddHandshake(cf) - Expect(q.HasHandshakeData()).To(BeTrue()) - Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(f)) - Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(cf)) - Expect(q.HasHandshakeData()).To(BeFalse()) - }) - - It("drops all Handshake frames", func() { - q.AddHandshake(&wire.CryptoFrame{Data: []byte("foobar")}) - q.AddHandshake(&wire.MaxDataFrame{MaximumData: 0x42}) - q.DropPackets(protocol.EncryptionHandshake) - Expect(q.HasHandshakeData()).To(BeFalse()) - Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil()) - }) - }) - - Context("Application data", func() { - It("doesn't dequeue anything when it's empty", func() { - Expect(q.GetAppDataFrame(protocol.MaxByteCount)).To(BeNil()) - }) - - It("queues and retrieves a control frame", func() { - f := &wire.MaxDataFrame{MaximumData: 0x42} - Expect(q.HasAppData()).To(BeFalse()) - q.AddAppData(f) - Expect(q.HasAppData()).To(BeTrue()) - Expect(q.GetAppDataFrame(f.Length(version) - 1)).To(BeNil()) - Expect(q.GetAppDataFrame(f.Length(version))).To(Equal(f)) - Expect(q.HasAppData()).To(BeFalse()) - }) - }) -}) diff --git a/internal/quic-go/send_conn.go b/internal/quic-go/send_conn.go deleted file mode 100644 index c53ebdfa..00000000 --- a/internal/quic-go/send_conn.go +++ /dev/null @@ -1,74 +0,0 @@ -package quic - -import ( - "net" -) - -// A sendConn allows sending using a simple Write() on a non-connected packet conn. -type sendConn interface { - Write([]byte) error - Close() error - LocalAddr() net.Addr - RemoteAddr() net.Addr -} - -type sconn struct { - rawConn - - remoteAddr net.Addr - info *packetInfo - oob []byte -} - -var _ sendConn = &sconn{} - -func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn { - return &sconn{ - rawConn: c, - remoteAddr: remote, - info: info, - oob: info.OOB(), - } -} - -func (c *sconn) Write(p []byte) error { - _, err := c.WritePacket(p, c.remoteAddr, c.oob) - return err -} - -func (c *sconn) RemoteAddr() net.Addr { - return c.remoteAddr -} - -func (c *sconn) LocalAddr() net.Addr { - addr := c.rawConn.LocalAddr() - if c.info != nil { - if udpAddr, ok := addr.(*net.UDPAddr); ok { - addrCopy := *udpAddr - addrCopy.IP = c.info.addr - addr = &addrCopy - } - } - return addr -} - -type spconn struct { - net.PacketConn - - remoteAddr net.Addr -} - -var _ sendConn = &spconn{} - -func newSendPconn(c net.PacketConn, remote net.Addr) sendConn { - return &spconn{PacketConn: c, remoteAddr: remote} -} - -func (c *spconn) Write(p []byte) error { - _, err := c.WriteTo(p, c.remoteAddr) - return err -} - -func (c *spconn) RemoteAddr() net.Addr { - return c.remoteAddr -} diff --git a/internal/quic-go/send_conn_test.go b/internal/quic-go/send_conn_test.go deleted file mode 100644 index 15e8f3b4..00000000 --- a/internal/quic-go/send_conn_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package quic - -import ( - "net" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Connection (for sending packets)", func() { - var ( - c sendConn - packetConn *MockPacketConn - addr net.Addr - ) - - BeforeEach(func() { - addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - packetConn = NewMockPacketConn(mockCtrl) - c = newSendPconn(packetConn, addr) - }) - - It("writes", func() { - packetConn.EXPECT().WriteTo([]byte("foobar"), addr) - Expect(c.Write([]byte("foobar"))).To(Succeed()) - }) - - It("gets the remote address", func() { - Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337")) - }) - - It("gets the local address", func() { - addr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 0, 1), - Port: 1234, - } - packetConn.EXPECT().LocalAddr().Return(addr) - Expect(c.LocalAddr()).To(Equal(addr)) - }) - - It("closes", func() { - packetConn.EXPECT().Close() - Expect(c.Close()).To(Succeed()) - }) -}) diff --git a/internal/quic-go/send_queue.go b/internal/quic-go/send_queue.go deleted file mode 100644 index 1fc8c1bf..00000000 --- a/internal/quic-go/send_queue.go +++ /dev/null @@ -1,88 +0,0 @@ -package quic - -type sender interface { - Send(p *packetBuffer) - Run() error - WouldBlock() bool - Available() <-chan struct{} - Close() -} - -type sendQueue struct { - queue chan *packetBuffer - closeCalled chan struct{} // runStopped when Close() is called - runStopped chan struct{} // runStopped when the run loop returns - available chan struct{} - conn sendConn -} - -var _ sender = &sendQueue{} - -const sendQueueCapacity = 8 - -func newSendQueue(conn sendConn) sender { - return &sendQueue{ - conn: conn, - runStopped: make(chan struct{}), - closeCalled: make(chan struct{}), - available: make(chan struct{}, 1), - queue: make(chan *packetBuffer, sendQueueCapacity), - } -} - -// Send sends out a packet. It's guaranteed to not block. -// Callers need to make sure that there's actually space in the send queue by calling WouldBlock. -// Otherwise Send will panic. -func (h *sendQueue) Send(p *packetBuffer) { - select { - case h.queue <- p: - case <-h.runStopped: - default: - panic("sendQueue.Send would have blocked") - } -} - -func (h *sendQueue) WouldBlock() bool { - return len(h.queue) == sendQueueCapacity -} - -func (h *sendQueue) Available() <-chan struct{} { - return h.available -} - -func (h *sendQueue) Run() error { - defer close(h.runStopped) - var shouldClose bool - for { - if shouldClose && len(h.queue) == 0 { - return nil - } - select { - case <-h.closeCalled: - h.closeCalled = nil // prevent this case from being selected again - // make sure that all queued packets are actually sent out - shouldClose = true - case p := <-h.queue: - if err := h.conn.Write(p.Data); err != nil { - // This additional check enables: - // 1. Checking for "datagram too large" message from the kernel, as such, - // 2. Path MTU discovery,and - // 3. Eventual detection of loss PingFrame. - if !isMsgSizeErr(err) { - return err - } - } - p.Release() - select { - case h.available <- struct{}{}: - default: - } - } - } -} - -func (h *sendQueue) Close() { - close(h.closeCalled) - // wait until the run loop returned - <-h.runStopped -} diff --git a/internal/quic-go/send_queue_test.go b/internal/quic-go/send_queue_test.go deleted file mode 100644 index dc3179c4..00000000 --- a/internal/quic-go/send_queue_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package quic - -import ( - "errors" - - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Send Queue", func() { - var q sender - var c *MockSendConn - - BeforeEach(func() { - c = NewMockSendConn(mockCtrl) - q = newSendQueue(c) - }) - - getPacket := func(b []byte) *packetBuffer { - buf := getPacketBuffer() - buf.Data = buf.Data[:len(b)] - copy(buf.Data, b) - return buf - } - - It("sends a packet", func() { - p := getPacket([]byte("foobar")) - q.Send(p) - - written := make(chan struct{}) - c.EXPECT().Write([]byte("foobar")).Do(func([]byte) { close(written) }) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - q.Run() - close(done) - }() - - Eventually(written).Should(BeClosed()) - q.Close() - Eventually(done).Should(BeClosed()) - }) - - It("panics when Send() is called although there's no space in the queue", func() { - for i := 0; i < sendQueueCapacity; i++ { - Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar"))) - } - Expect(q.WouldBlock()).To(BeTrue()) - Expect(func() { q.Send(getPacket([]byte("raboof"))) }).To(Panic()) - }) - - It("signals when sending is possible again", func() { - Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar1"))) - Consistently(q.Available()).ShouldNot(Receive()) - - // now start sending out packets. This should free up queue space. - c.EXPECT().Write(gomock.Any()).MinTimes(1).MaxTimes(2) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - q.Run() - close(done) - }() - - Eventually(q.Available()).Should(Receive()) - Expect(q.WouldBlock()).To(BeFalse()) - Expect(func() { q.Send(getPacket([]byte("foobar2"))) }).ToNot(Panic()) - - q.Close() - Eventually(done).Should(BeClosed()) - }) - - It("does not block pending send after the queue has stopped running", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - q.Run() - close(done) - }() - - // the run loop exits if there is a write error - testErr := errors.New("test error") - c.EXPECT().Write(gomock.Any()).Return(testErr) - q.Send(getPacket([]byte("foobar"))) - Eventually(done).Should(BeClosed()) - - sent := make(chan struct{}) - go func() { - defer GinkgoRecover() - q.Send(getPacket([]byte("raboof"))) - q.Send(getPacket([]byte("quux"))) - close(sent) - }() - - Eventually(sent).Should(BeClosed()) - }) - - It("blocks Close() until the packet has been sent out", func() { - written := make(chan []byte) - c.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - q.Run() - close(done) - }() - - q.Send(getPacket([]byte("foobar"))) - - closed := make(chan struct{}) - go func() { - defer GinkgoRecover() - q.Close() - close(closed) - }() - - Consistently(closed).ShouldNot(BeClosed()) - // now write the packet - Expect(written).To(Receive()) - Eventually(done).Should(BeClosed()) - Eventually(closed).Should(BeClosed()) - }) -}) diff --git a/internal/quic-go/send_stream.go b/internal/quic-go/send_stream.go deleted file mode 100644 index f2e912e8..00000000 --- a/internal/quic-go/send_stream.go +++ /dev/null @@ -1,496 +0,0 @@ -package quic - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/imroc/req/v3/internal/quic-go/ackhandler" - "github.com/imroc/req/v3/internal/quic-go/flowcontrol" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type sendStreamI interface { - SendStream - handleStopSendingFrame(*wire.StopSendingFrame) - hasData() bool - popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) - closeForShutdown(error) - updateSendWindow(protocol.ByteCount) -} - -type sendStream struct { - mutex sync.Mutex - - numOutstandingFrames int64 - retransmissionQueue []*wire.StreamFrame - - ctx context.Context - ctxCancel context.CancelFunc - - streamID protocol.StreamID - sender streamSender - - writeOffset protocol.ByteCount - - cancelWriteErr error - closeForShutdownErr error - - closedForShutdown bool // set when CloseForShutdown() is called - finishedWriting bool // set once Close() is called - canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received - finSent bool // set when a STREAM_FRAME with FIN bit has been sent - completed bool // set when this stream has been reported to the streamSender as completed - - dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out - nextFrame *wire.StreamFrame - - writeChan chan struct{} - writeOnce chan struct{} - deadline time.Time - - flowController flowcontrol.StreamFlowController - - version protocol.VersionNumber -} - -var ( - _ SendStream = &sendStream{} - _ sendStreamI = &sendStream{} -) - -func newSendStream( - streamID protocol.StreamID, - sender streamSender, - flowController flowcontrol.StreamFlowController, - version protocol.VersionNumber, -) *sendStream { - s := &sendStream{ - streamID: streamID, - sender: sender, - flowController: flowController, - writeChan: make(chan struct{}, 1), - writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write - version: version, - } - s.ctx, s.ctxCancel = context.WithCancel(context.Background()) - return s -} - -func (s *sendStream) StreamID() protocol.StreamID { - return s.streamID // same for receiveStream and sendStream -} - -func (s *sendStream) Write(p []byte) (int, error) { - // Concurrent use of Write is not permitted (and doesn't make any sense), - // but sometimes people do it anyway. - // Make sure that we only execute one call at any given time to avoid hard to debug failures. - s.writeOnce <- struct{}{} - defer func() { <-s.writeOnce }() - - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.finishedWriting { - return 0, fmt.Errorf("write on closed stream %d", s.streamID) - } - if s.canceledWrite { - return 0, s.cancelWriteErr - } - if s.closeForShutdownErr != nil { - return 0, s.closeForShutdownErr - } - if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { - return 0, errDeadline - } - if len(p) == 0 { - return 0, nil - } - - s.dataForWriting = p - - var ( - deadlineTimer *utils.Timer - bytesWritten int - notifiedSender bool - ) - for { - var copied bool - var deadline time.Time - // As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame), - // which can the be popped the next time we assemble a packet. - // This allows us to return Write() when all data but x bytes have been sent out. - // When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame, - // allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN). - if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 { - if s.nextFrame == nil { - f := wire.GetStreamFrame() - f.Offset = s.writeOffset - f.StreamID = s.streamID - f.DataLenPresent = true - f.Data = f.Data[:len(s.dataForWriting)] - copy(f.Data, s.dataForWriting) - s.nextFrame = f - } else { - l := len(s.nextFrame.Data) - s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)] - copy(s.nextFrame.Data[l:], s.dataForWriting) - } - s.dataForWriting = nil - bytesWritten = len(p) - copied = true - } else { - bytesWritten = len(p) - len(s.dataForWriting) - deadline = s.deadline - if !deadline.IsZero() { - if !time.Now().Before(deadline) { - s.dataForWriting = nil - return bytesWritten, errDeadline - } - if deadlineTimer == nil { - deadlineTimer = utils.NewTimer() - defer deadlineTimer.Stop() - } - deadlineTimer.Reset(deadline) - } - if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown { - break - } - } - - s.mutex.Unlock() - if !notifiedSender { - s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex - notifiedSender = true - } - if copied { - s.mutex.Lock() - break - } - if deadline.IsZero() { - <-s.writeChan - } else { - select { - case <-s.writeChan: - case <-deadlineTimer.Chan(): - deadlineTimer.SetRead() - } - } - s.mutex.Lock() - } - - if bytesWritten == len(p) { - return bytesWritten, nil - } - if s.closeForShutdownErr != nil { - return bytesWritten, s.closeForShutdownErr - } else if s.cancelWriteErr != nil { - return bytesWritten, s.cancelWriteErr - } - return bytesWritten, nil -} - -func (s *sendStream) canBufferStreamFrame() bool { - var l protocol.ByteCount - if s.nextFrame != nil { - l = s.nextFrame.DataLen() - } - return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize -} - -// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream -// maxBytes is the maximum length this frame (including frame header) will have. -func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool /* has more data to send */) { - s.mutex.Lock() - f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes) - if f != nil { - s.numOutstandingFrames++ - } - s.mutex.Unlock() - - if f == nil { - return nil, hasMoreData - } - return &ackhandler.Frame{Frame: f, OnLost: s.queueRetransmission, OnAcked: s.frameAcked}, hasMoreData -} - -func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) { - if s.canceledWrite || s.closeForShutdownErr != nil { - return nil, false - } - - if len(s.retransmissionQueue) > 0 { - f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes) - if f != nil || hasMoreRetransmissions { - if f == nil { - return nil, true - } - // We always claim that we have more data to send. - // This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future. - return f, true - } - } - - if len(s.dataForWriting) == 0 && s.nextFrame == nil { - if s.finishedWriting && !s.finSent { - s.finSent = true - return &wire.StreamFrame{ - StreamID: s.streamID, - Offset: s.writeOffset, - DataLenPresent: true, - Fin: true, - }, false - } - return nil, false - } - - sendWindow := s.flowController.SendWindowSize() - if sendWindow == 0 { - if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked { - s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{ - StreamID: s.streamID, - MaximumStreamData: offset, - }) - return nil, false - } - return nil, true - } - - f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow) - if dataLen := f.DataLen(); dataLen > 0 { - s.writeOffset += f.DataLen() - s.flowController.AddBytesSent(f.DataLen()) - } - f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent - if f.Fin { - s.finSent = true - } - return f, hasMoreData -} - -func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) (*wire.StreamFrame, bool) { - if s.nextFrame != nil { - nextFrame := s.nextFrame - s.nextFrame = nil - - maxDataLen := utils.MinByteCount(sendWindow, nextFrame.MaxDataLen(maxBytes, s.version)) - if nextFrame.DataLen() > maxDataLen { - s.nextFrame = wire.GetStreamFrame() - s.nextFrame.StreamID = s.streamID - s.nextFrame.Offset = s.writeOffset + maxDataLen - s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen] - s.nextFrame.DataLenPresent = true - copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:]) - nextFrame.Data = nextFrame.Data[:maxDataLen] - } else { - s.signalWrite() - } - return nextFrame, s.nextFrame != nil || s.dataForWriting != nil - } - - f := wire.GetStreamFrame() - f.Fin = false - f.StreamID = s.streamID - f.Offset = s.writeOffset - f.DataLenPresent = true - f.Data = f.Data[:0] - - hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow) - if len(f.Data) == 0 && !f.Fin { - f.PutBack() - return nil, hasMoreData - } - return f, hasMoreData -} - -func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount) bool { - maxDataLen := f.MaxDataLen(maxBytes, s.version) - if maxDataLen == 0 { // a STREAM frame must have at least one byte of data - return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting - } - s.getDataForWriting(f, utils.MinByteCount(maxDataLen, sendWindow)) - - return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting -} - -func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more retransmissions */) { - f := s.retransmissionQueue[0] - newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, s.version) - if needsSplit { - return newFrame, true - } - s.retransmissionQueue = s.retransmissionQueue[1:] - return f, len(s.retransmissionQueue) > 0 -} - -func (s *sendStream) hasData() bool { - s.mutex.Lock() - hasData := len(s.dataForWriting) > 0 - s.mutex.Unlock() - return hasData -} - -func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) { - if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes { - f.Data = f.Data[:len(s.dataForWriting)] - copy(f.Data, s.dataForWriting) - s.dataForWriting = nil - s.signalWrite() - return - } - f.Data = f.Data[:maxBytes] - copy(f.Data, s.dataForWriting) - s.dataForWriting = s.dataForWriting[maxBytes:] - if s.canBufferStreamFrame() { - s.signalWrite() - } -} - -func (s *sendStream) frameAcked(f wire.Frame) { - f.(*wire.StreamFrame).PutBack() - - s.mutex.Lock() - if s.canceledWrite { - s.mutex.Unlock() - return - } - s.numOutstandingFrames-- - if s.numOutstandingFrames < 0 { - panic("numOutStandingFrames negative") - } - newlyCompleted := s.isNewlyCompleted() - s.mutex.Unlock() - - if newlyCompleted { - s.sender.onStreamCompleted(s.streamID) - } -} - -func (s *sendStream) isNewlyCompleted() bool { - completed := (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 - if completed && !s.completed { - s.completed = true - return true - } - return false -} - -func (s *sendStream) queueRetransmission(f wire.Frame) { - sf := f.(*wire.StreamFrame) - sf.DataLenPresent = true - s.mutex.Lock() - if s.canceledWrite { - s.mutex.Unlock() - return - } - s.retransmissionQueue = append(s.retransmissionQueue, sf) - s.numOutstandingFrames-- - if s.numOutstandingFrames < 0 { - panic("numOutStandingFrames negative") - } - s.mutex.Unlock() - - s.sender.onHasStreamData(s.streamID) -} - -func (s *sendStream) Close() error { - s.mutex.Lock() - if s.closedForShutdown { - s.mutex.Unlock() - return nil - } - if s.canceledWrite { - s.mutex.Unlock() - return fmt.Errorf("close called for canceled stream %d", s.streamID) - } - s.ctxCancel() - s.finishedWriting = true - s.mutex.Unlock() - - s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex - return nil -} - -func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { - s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) -} - -// must be called after locking the mutex -func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) { - s.mutex.Lock() - if s.canceledWrite { - s.mutex.Unlock() - return - } - s.ctxCancel() - s.canceledWrite = true - s.cancelWriteErr = writeErr - s.numOutstandingFrames = 0 - s.retransmissionQueue = nil - newlyCompleted := s.isNewlyCompleted() - s.mutex.Unlock() - - s.signalWrite() - s.sender.queueControlFrame(&wire.ResetStreamFrame{ - StreamID: s.streamID, - FinalSize: s.writeOffset, - ErrorCode: errorCode, - }) - if newlyCompleted { - s.sender.onStreamCompleted(s.streamID) - } -} - -func (s *sendStream) updateSendWindow(limit protocol.ByteCount) { - s.mutex.Lock() - hasStreamData := s.dataForWriting != nil || s.nextFrame != nil - s.mutex.Unlock() - - s.flowController.UpdateSendWindow(limit) - if hasStreamData { - s.sender.onHasStreamData(s.streamID) - } -} - -func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { - s.cancelWriteImpl(frame.ErrorCode, &StreamError{ - StreamID: s.streamID, - ErrorCode: frame.ErrorCode, - }) -} - -func (s *sendStream) Context() context.Context { - return s.ctx -} - -func (s *sendStream) SetWriteDeadline(t time.Time) error { - s.mutex.Lock() - s.deadline = t - s.mutex.Unlock() - s.signalWrite() - return nil -} - -// CloseForShutdown closes a stream abruptly. -// It makes Write unblock (and return the error) immediately. -// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. -func (s *sendStream) closeForShutdown(err error) { - s.mutex.Lock() - s.ctxCancel() - s.closedForShutdown = true - s.closeForShutdownErr = err - s.mutex.Unlock() - s.signalWrite() -} - -// signalWrite performs a non-blocking send on the writeChan -func (s *sendStream) signalWrite() { - select { - case s.writeChan <- struct{}{}: - default: - } -} diff --git a/internal/quic-go/send_stream_test.go b/internal/quic-go/send_stream_test.go deleted file mode 100644 index 066cdb8c..00000000 --- a/internal/quic-go/send_stream_test.go +++ /dev/null @@ -1,1159 +0,0 @@ -package quic - -import ( - "bytes" - "errors" - "io" - mrand "math/rand" - "runtime" - "time" - - "github.com/golang/mock/gomock" - "github.com/imroc/req/v3/internal/quic-go/ackhandler" - "github.com/imroc/req/v3/internal/quic-go/mocks" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "github.com/onsi/gomega/gbytes" -) - -var _ = Describe("Send Stream", func() { - const streamID protocol.StreamID = 1337 - - var ( - str *sendStream - strWithTimeout io.Writer // str wrapped with gbytes.TimeoutWriter - mockFC *mocks.MockStreamFlowController - mockSender *MockStreamSender - ) - - BeforeEach(func() { - mockSender = NewMockStreamSender(mockCtrl) - mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newSendStream(streamID, mockSender, mockFC, protocol.VersionWhatever) - - timeout := scaleDuration(250 * time.Millisecond) - strWithTimeout = gbytes.TimeoutWriter(str, timeout) - }) - - expectedFrameHeaderLen := func(offset protocol.ByteCount) protocol.ByteCount { - return (&wire.StreamFrame{ - StreamID: streamID, - Offset: offset, - DataLenPresent: true, - }).Length(protocol.VersionWhatever) - } - - waitForWrite := func() { - EventuallyWithOffset(0, func() bool { - str.mutex.Lock() - hasData := str.dataForWriting != nil || str.nextFrame != nil - str.mutex.Unlock() - return hasData - }).Should(BeTrue()) - } - - getDataAtOffset := func(offset, length protocol.ByteCount) []byte { - b := make([]byte, length) - for i := protocol.ByteCount(0); i < length; i++ { - b[i] = uint8(offset + i) - } - return b - } - - getData := func(length protocol.ByteCount) []byte { - return getDataAtOffset(0, length) - } - - It("gets stream id", func() { - Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) - }) - - Context("writing", func() { - It("writes and gets all data at once", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - }() - waitForWrite() - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - frame, _ := str.popStreamFrame(protocol.MaxByteCount) - f := frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(Equal([]byte("foobar"))) - Expect(f.Fin).To(BeFalse()) - Expect(f.Offset).To(BeZero()) - Expect(f.DataLenPresent).To(BeTrue()) - Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) - Expect(str.dataForWriting).To(BeNil()) - Eventually(done).Should(BeClosed()) - }) - - It("writes and gets data in two turns", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - mockSender.EXPECT().onHasStreamData(streamID) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - close(done) - }() - waitForWrite() - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) - frame, _ := str.popStreamFrame(expectedFrameHeaderLen(0) + 3) - f := frame.Frame.(*wire.StreamFrame) - Expect(f.Offset).To(BeZero()) - Expect(f.Fin).To(BeFalse()) - Expect(f.Data).To(Equal([]byte("foo"))) - Expect(f.DataLenPresent).To(BeTrue()) - frame, _ = str.popStreamFrame(protocol.MaxByteCount) - f = frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(Equal([]byte("bar"))) - Expect(f.Fin).To(BeFalse()) - Expect(f.Offset).To(Equal(protocol.ByteCount(3))) - Expect(f.DataLenPresent).To(BeTrue()) - Expect(str.popStreamFrame(1000)).To(BeNil()) - Eventually(done).Should(BeClosed()) - }) - - It("bundles small writes", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - mockSender.EXPECT().onHasStreamData(streamID).Times(2) - n, err := strWithTimeout.Write([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - n, err = strWithTimeout.Write([]byte("bar")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - close(done) - }() - Eventually(done).Should(BeClosed()) // both Write calls returned without any data having been dequeued yet - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - frame, _ := str.popStreamFrame(protocol.MaxByteCount) - f := frame.Frame.(*wire.StreamFrame) - Expect(f.Offset).To(BeZero()) - Expect(f.Fin).To(BeFalse()) - Expect(f.Data).To(Equal([]byte("foobar"))) - }) - - It("writes and gets data in multiple turns, for large writes", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(5) - var totalBytesSent protocol.ByteCount - mockFC.EXPECT().AddBytesSent(gomock.Any()).Do(func(l protocol.ByteCount) { totalBytesSent += l }).Times(5) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - mockSender.EXPECT().onHasStreamData(streamID) - n, err := strWithTimeout.Write(getData(5000)) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(5000)) - close(done) - }() - waitForWrite() - for i := 0; i < 5; i++ { - frame, _ := str.popStreamFrame(1100) - f := frame.Frame.(*wire.StreamFrame) - Expect(f.Offset).To(BeNumerically("~", 1100*i, 10*i)) - Expect(f.Fin).To(BeFalse()) - Expect(f.Data).To(Equal(getDataAtOffset(f.Offset, f.DataLen()))) - Expect(f.DataLenPresent).To(BeTrue()) - } - Expect(totalBytesSent).To(Equal(protocol.ByteCount(5000))) - Eventually(done).Should(BeClosed()) - }) - - It("unblocks Write as soon as a STREAM frame can be buffered", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID) - _, err := strWithTimeout.Write(getData(protocol.MaxPacketBufferSize + 3)) - Expect(err).ToNot(HaveOccurred()) - }() - waitForWrite() - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) - frame, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0) + 2) - Expect(hasMoreData).To(BeTrue()) - f := frame.Frame.(*wire.StreamFrame) - Expect(f.DataLen()).To(Equal(protocol.ByteCount(2))) - Consistently(done).ShouldNot(BeClosed()) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) - frame, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(1) + 1) - Expect(hasMoreData).To(BeTrue()) - f = frame.Frame.(*wire.StreamFrame) - Expect(f.DataLen()).To(Equal(protocol.ByteCount(1))) - Eventually(done).Should(BeClosed()) - }) - - It("only unblocks Write once a previously buffered STREAM frame has been fully dequeued", func() { - mockSender.EXPECT().onHasStreamData(streamID) - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID) - _, err := str.Write(getData(protocol.MaxPacketBufferSize)) - Expect(err).ToNot(HaveOccurred()) - }() - waitForWrite() - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) - frame, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0) + 2) - Expect(hasMoreData).To(BeTrue()) - f := frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(Equal([]byte("fo"))) - Consistently(done).ShouldNot(BeClosed()) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(4)) - frame, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(2) + 4) - Expect(hasMoreData).To(BeTrue()) - f = frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(Equal([]byte("obar"))) - Eventually(done).Should(BeClosed()) - }) - - It("popStreamFrame returns nil if no data is available", func() { - frame, hasMoreData := str.popStreamFrame(1000) - Expect(frame).To(BeNil()) - Expect(hasMoreData).To(BeFalse()) - }) - - It("says if it has more data for writing", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID) - n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(100)) - }() - waitForWrite() - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) - frame, hasMoreData := str.popStreamFrame(50) - Expect(frame).ToNot(BeNil()) - Expect(frame.Frame.(*wire.StreamFrame).Fin).To(BeFalse()) - Expect(hasMoreData).To(BeTrue()) - frame, hasMoreData = str.popStreamFrame(protocol.MaxByteCount) - Expect(frame).ToNot(BeNil()) - Expect(frame.Frame.(*wire.StreamFrame).Fin).To(BeFalse()) - Expect(hasMoreData).To(BeFalse()) - frame, _ = str.popStreamFrame(protocol.MaxByteCount) - Expect(frame).To(BeNil()) - Eventually(done).Should(BeClosed()) - }) - - It("copies the slice while writing", func() { - frameHeaderSize := protocol.ByteCount(4) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) - s := []byte("foo") - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID) - n, err := strWithTimeout.Write(s) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - }() - waitForWrite() - frame, _ := str.popStreamFrame(frameHeaderSize + 1) - f := frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(Equal([]byte("f"))) - frame, _ = str.popStreamFrame(100) - Expect(frame).ToNot(BeNil()) - f = frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(Equal([]byte("oo"))) - s[1] = 'e' - Expect(f.Data).To(Equal([]byte("oo"))) - Eventually(done).Should(BeClosed()) - }) - - It("returns when given a nil input", func() { - n, err := strWithTimeout.Write(nil) - Expect(n).To(BeZero()) - Expect(err).ToNot(HaveOccurred()) - }) - - It("returns when given an empty slice", func() { - n, err := strWithTimeout.Write([]byte("")) - Expect(n).To(BeZero()) - Expect(err).ToNot(HaveOccurred()) - }) - - It("cancels the context when Close is called", func() { - mockSender.EXPECT().onHasStreamData(streamID) - Expect(str.Context().Done()).ToNot(BeClosed()) - Expect(str.Close()).To(Succeed()) - Expect(str.Context().Done()).To(BeClosed()) - }) - - Context("flow control blocking", func() { - It("queues a BLOCKED frame if the stream is flow control blocked", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(0)) - mockFC.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(12)) - mockSender.EXPECT().queueControlFrame(&wire.StreamDataBlockedFrame{ - StreamID: streamID, - MaximumStreamData: 12, - }) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID) - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - }() - waitForWrite() - f, hasMoreData := str.popStreamFrame(1000) - Expect(f).To(BeNil()) - Expect(hasMoreData).To(BeFalse()) - // make the Write go routine return - str.closeForShutdown(nil) - Eventually(done).Should(BeClosed()) - }) - - It("says that it doesn't have any more data, when it is flow control blocked", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID) - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - }() - waitForWrite() - - // first pop a STREAM frame of the maximum size allowed by flow control - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(3)) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) - f, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0) + 3) - Expect(f).ToNot(BeNil()) - Expect(hasMoreData).To(BeTrue()) - - // try to pop again, this time noticing that we're blocked - mockFC.EXPECT().SendWindowSize() - // don't use offset 3 here, to make sure the BLOCKED frame contains the number returned by the flow controller - mockFC.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(10)) - mockSender.EXPECT().queueControlFrame(&wire.StreamDataBlockedFrame{ - StreamID: streamID, - MaximumStreamData: 10, - }) - f, hasMoreData = str.popStreamFrame(1000) - Expect(f).To(BeNil()) - Expect(hasMoreData).To(BeFalse()) - // make the Write go routine return - str.closeForShutdown(nil) - Eventually(done).Should(BeClosed()) - }) - }) - - Context("deadlines", func() { - It("returns an error when Write is called after the deadline", func() { - str.SetWriteDeadline(time.Now().Add(-time.Second)) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) - - It("unblocks after the deadline", func() { - mockSender.EXPECT().onHasStreamData(streamID) - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetWriteDeadline(deadline) - n, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) - }) - - It("unblocks when the deadline is changed to the past", func() { - mockSender.EXPECT().onHasStreamData(streamID) - str.SetWriteDeadline(time.Now().Add(time.Hour)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - str.SetWriteDeadline(time.Now().Add(-time.Hour)) - Eventually(done).Should(BeClosed()) - }) - - It("returns the number of bytes written, when the deadline expires", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() - mockFC.EXPECT().AddBytesSent(gomock.Any()) - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetWriteDeadline(deadline) - var n int - writeReturned := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(writeReturned) - mockSender.EXPECT().onHasStreamData(streamID) - var err error - n, err = strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) - }() - waitForWrite() - frame, hasMoreData := str.popStreamFrame(50) - Expect(frame).ToNot(BeNil()) - Expect(hasMoreData).To(BeTrue()) - Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) - Expect(n).To(BeEquivalentTo(frame.Frame.(*wire.StreamFrame).DataLen())) - }) - - It("doesn't pop any data after the deadline expired", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() - mockFC.EXPECT().AddBytesSent(gomock.Any()) - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetWriteDeadline(deadline) - writeReturned := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(writeReturned) - mockSender.EXPECT().onHasStreamData(streamID) - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - }() - waitForWrite() - frame, hasMoreData := str.popStreamFrame(50) - Expect(frame).ToNot(BeNil()) - Expect(hasMoreData).To(BeTrue()) - Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) - frame, hasMoreData = str.popStreamFrame(50) - Expect(frame).To(BeNil()) - Expect(hasMoreData).To(BeFalse()) - }) - - It("doesn't unblock if the deadline is changed before the first one expires", func() { - mockSender.EXPECT().onHasStreamData(streamID) - deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) - str.SetWriteDeadline(deadline1) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(20 * time.Millisecond)) - str.SetWriteDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline1)) - close(done) - }() - runtime.Gosched() - n, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) - Eventually(done).Should(BeClosed()) - }) - - It("unblocks earlier, when a new deadline is set", func() { - mockSender.EXPECT().onHasStreamData(streamID) - deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(10 * time.Millisecond)) - str.SetWriteDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline2)) - close(done) - }() - str.SetWriteDeadline(deadline1) - runtime.Gosched() - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't unblock if the deadline is removed", func() { - mockSender.EXPECT().onHasStreamData(streamID) - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetWriteDeadline(deadline) - deadlineUnset := make(chan struct{}) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(20 * time.Millisecond)) - str.SetWriteDeadline(time.Time{}) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline)) - close(deadlineUnset) - }() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError("test done")) - close(done) - }() - runtime.Gosched() - Eventually(deadlineUnset).Should(BeClosed()) - Consistently(done, scaleDuration(100*time.Millisecond)).ShouldNot(BeClosed()) - // make the go routine return - str.closeForShutdown(errors.New("test done")) - Eventually(done).Should(BeClosed()) - }) - }) - - Context("closing", func() { - It("doesn't allow writes after it has been closed", func() { - mockSender.EXPECT().onHasStreamData(streamID) - str.Close() - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError("write on closed stream 1337")) - }) - - It("allows FIN", func() { - mockSender.EXPECT().onHasStreamData(streamID) - str.Close() - frame, hasMoreData := str.popStreamFrame(1000) - Expect(frame).ToNot(BeNil()) - f := frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(BeEmpty()) - Expect(f.Fin).To(BeTrue()) - Expect(f.DataLenPresent).To(BeTrue()) - Expect(hasMoreData).To(BeFalse()) - }) - - It("doesn't send a FIN when there's still data", func() { - const frameHeaderLen protocol.ByteCount = 4 - mockSender.EXPECT().onHasStreamData(streamID).Times(2) - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) - frame, _ := str.popStreamFrame(3 + frameHeaderLen) - Expect(frame).ToNot(BeNil()) - f := frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(Equal([]byte("foo"))) - Expect(f.Fin).To(BeFalse()) - frame, _ = str.popStreamFrame(protocol.MaxByteCount) - f = frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(Equal([]byte("bar"))) - Expect(f.Fin).To(BeTrue()) - }) - - It("doesn't send a FIN when there's still data, for long writes", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID) - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().onHasStreamData(streamID) - Expect(str.Close()).To(Succeed()) - }() - waitForWrite() - for i := 1; i <= 5; i++ { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - if i == 5 { - Eventually(done).Should(BeClosed()) - } - frame, _ := str.popStreamFrame(1100) - Expect(frame).ToNot(BeNil()) - f := frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(Equal(getDataAtOffset(f.Offset, f.DataLen()))) - Expect(f.Fin).To(Equal(i == 5)) // the last frame should have the FIN bit set - } - }) - - It("doesn't allow FIN after it is closed for shutdown", func() { - str.closeForShutdown(errors.New("test")) - f, hasMoreData := str.popStreamFrame(1000) - Expect(f).To(BeNil()) - Expect(hasMoreData).To(BeFalse()) - - Expect(str.Close()).To(Succeed()) - f, hasMoreData = str.popStreamFrame(1000) - Expect(f).To(BeNil()) - Expect(hasMoreData).To(BeFalse()) - }) - - It("doesn't allow FIN twice", func() { - mockSender.EXPECT().onHasStreamData(streamID) - str.Close() - frame, _ := str.popStreamFrame(1000) - Expect(frame).ToNot(BeNil()) - f := frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(BeEmpty()) - Expect(f.Fin).To(BeTrue()) - frame, hasMoreData := str.popStreamFrame(1000) - Expect(frame).To(BeNil()) - Expect(hasMoreData).To(BeFalse()) - }) - }) - - Context("closing for shutdown", func() { - testErr := errors.New("test") - - It("returns errors when the stream is cancelled", func() { - str.closeForShutdown(testErr) - n, err := strWithTimeout.Write([]byte("foo")) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - }) - - It("doesn't get data for writing if an error occurred", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(testErr)) - close(done) - }() - waitForWrite() - frame, hasMoreData := str.popStreamFrame(50) // get a STREAM frame containing some data, but not all - Expect(frame).ToNot(BeNil()) - Expect(hasMoreData).To(BeTrue()) - str.closeForShutdown(testErr) - frame, hasMoreData = str.popStreamFrame(1000) - Expect(frame).To(BeNil()) - Expect(hasMoreData).To(BeFalse()) - Eventually(done).Should(BeClosed()) - }) - - It("cancels the context", func() { - Expect(str.Context().Done()).ToNot(BeClosed()) - str.closeForShutdown(testErr) - Expect(str.Context().Done()).To(BeClosed()) - }) - }) - }) - - Context("handling MAX_STREAM_DATA frames", func() { - It("informs the flow controller", func() { - mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(0x1337)) - str.updateSendWindow(0x1337) - }) - - It("says when it has data for sending", func() { - mockFC.EXPECT().UpdateSendWindow(gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - mockSender.EXPECT().onHasStreamData(streamID) - str.updateSendWindow(42) - // make sure the Write go routine returns - str.closeForShutdown(nil) - Eventually(done).Should(BeClosed()) - }) - }) - - Context("stream cancellations", func() { - Context("canceling writing", func() { - It("queues a RESET_STREAM frame", func() { - gomock.InOrder( - mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{ - StreamID: streamID, - FinalSize: 1234, - ErrorCode: 9876, - }), - mockSender.EXPECT().onStreamCompleted(streamID), - ) - str.writeOffset = 1234 - str.CancelWrite(9876) - }) - - // This test is inherently racy, as it tests a concurrent call to Write() and CancelRead(). - // A single successful run of this test therefore doesn't mean a lot, - // for reliable results it has to be run many times. - It("returns a nil error when the whole slice has been sent out", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).MaxTimes(1) - mockSender.EXPECT().onHasStreamData(streamID).MaxTimes(1) - mockSender.EXPECT().onStreamCompleted(streamID).MaxTimes(1) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).MaxTimes(1) - mockFC.EXPECT().AddBytesSent(gomock.Any()).MaxTimes(1) - errChan := make(chan error) - go func() { - defer GinkgoRecover() - n, err := strWithTimeout.Write(getData(100)) - if n == 0 { - errChan <- nil - return - } - errChan <- err - }() - - runtime.Gosched() - go str.popStreamFrame(protocol.MaxByteCount) - go str.CancelWrite(1234) - Eventually(errChan).Should(Receive(Not(HaveOccurred()))) - }) - - It("unblocks Write", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - writeReturned := make(chan struct{}) - var n int - go func() { - defer GinkgoRecover() - var err error - n, err = strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) - close(writeReturned) - }() - waitForWrite() - frame, _ := str.popStreamFrame(50) - Expect(frame).ToNot(BeNil()) - mockSender.EXPECT().onStreamCompleted(streamID) - str.CancelWrite(1234) - Eventually(writeReturned).Should(BeClosed()) - Expect(n).To(BeEquivalentTo(frame.Frame.(*wire.StreamFrame).DataLen())) - }) - - It("doesn't pop STREAM frames after being canceled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - writeReturned := make(chan struct{}) - go func() { - defer GinkgoRecover() - strWithTimeout.Write(getData(100)) - close(writeReturned) - }() - waitForWrite() - frame, hasMoreData := str.popStreamFrame(50) - Expect(hasMoreData).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - mockSender.EXPECT().onStreamCompleted(streamID) - str.CancelWrite(1234) - frame, hasMoreData = str.popStreamFrame(10) - Expect(frame).To(BeNil()) - Expect(hasMoreData).To(BeFalse()) - Eventually(writeReturned).Should(BeClosed()) - }) - - It("doesn't pop STREAM frames after being canceled, for large writes", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - writeReturned := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) - close(writeReturned) - }() - waitForWrite() - frame, hasMoreData := str.popStreamFrame(50) - Expect(hasMoreData).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - mockSender.EXPECT().onStreamCompleted(streamID) - str.CancelWrite(1234) - frame, hasMoreData = str.popStreamFrame(10) - Expect(hasMoreData).To(BeFalse()) - Expect(frame).To(BeNil()) - Eventually(writeReturned).Should(BeClosed()) - }) - - It("ignores acknowledgements for STREAM frames after it was cancelled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - writeReturned := make(chan struct{}) - go func() { - defer GinkgoRecover() - strWithTimeout.Write(getData(100)) - close(writeReturned) - }() - waitForWrite() - frame, hasMoreData := str.popStreamFrame(50) - Expect(hasMoreData).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - mockSender.EXPECT().onStreamCompleted(streamID) - str.CancelWrite(1234) - frame.OnAcked(frame.Frame) - }) - - It("cancels the context", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - mockSender.EXPECT().onStreamCompleted(gomock.Any()) - Expect(str.Context().Done()).ToNot(BeClosed()) - str.CancelWrite(1234) - Expect(str.Context().Done()).To(BeClosed()) - }) - - It("doesn't allow further calls to Write", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - mockSender.EXPECT().onStreamCompleted(gomock.Any()) - str.CancelWrite(1234) - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) - }) - - It("only cancels once", func() { - mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 1234}) - mockSender.EXPECT().onStreamCompleted(gomock.Any()) - str.CancelWrite(1234) - str.CancelWrite(4321) - }) - - It("queues a RESET_STREAM frame, even if the stream was already closed", func() { - mockSender.EXPECT().onHasStreamData(streamID) - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f).To(BeAssignableToTypeOf(&wire.ResetStreamFrame{})) - }) - mockSender.EXPECT().onStreamCompleted(gomock.Any()) - Expect(str.Close()).To(Succeed()) - // don't EXPECT any calls to queueControlFrame - str.CancelWrite(123) - }) - }) - - Context("receiving STOP_SENDING frames", func() { - It("queues a RESET_STREAM frames, and copies the error code from the STOP_SENDING frame", func() { - mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{ - StreamID: streamID, - ErrorCode: 101, - }) - mockSender.EXPECT().onStreamCompleted(gomock.Any()) - - str.handleStopSendingFrame(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 101, - }) - }) - - It("unblocks Write", func() { - mockSender.EXPECT().onHasStreamData(streamID) - mockSender.EXPECT().queueControlFrame(gomock.Any()) - mockSender.EXPECT().onStreamCompleted(gomock.Any()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Write(getData(5000)) - Expect(err).To(MatchError(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - })) - close(done) - }() - waitForWrite() - str.handleStopSendingFrame(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 123, - }) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't allow further calls to Write", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - mockSender.EXPECT().onStreamCompleted(gomock.Any()) - str.handleStopSendingFrame(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 123, - }) - _, err := str.Write([]byte("foobar")) - Expect(err).To(MatchError(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - })) - }) - }) - }) - - Context("retransmissions", func() { - It("queues and retrieves frames", func() { - str.numOutstandingFrames = 1 - f := &wire.StreamFrame{ - Data: []byte("foobar"), - Offset: 0x42, - DataLenPresent: false, - } - mockSender.EXPECT().onHasStreamData(streamID) - str.queueRetransmission(f) - frame, _ := str.popStreamFrame(protocol.MaxByteCount) - Expect(frame).ToNot(BeNil()) - f = frame.Frame.(*wire.StreamFrame) - Expect(f.Offset).To(Equal(protocol.ByteCount(0x42))) - Expect(f.Data).To(Equal([]byte("foobar"))) - Expect(f.DataLenPresent).To(BeTrue()) - }) - - It("splits a retransmission", func() { - str.numOutstandingFrames = 1 - sf := &wire.StreamFrame{ - Data: []byte("foobar"), - Offset: 0x42, - DataLenPresent: false, - } - mockSender.EXPECT().onHasStreamData(streamID) - str.queueRetransmission(sf) - frame, hasMoreData := str.popStreamFrame(sf.Length(str.version) - 3) - Expect(frame).ToNot(BeNil()) - f := frame.Frame.(*wire.StreamFrame) - Expect(hasMoreData).To(BeTrue()) - Expect(f.Offset).To(Equal(protocol.ByteCount(0x42))) - Expect(f.Data).To(Equal([]byte("foo"))) - Expect(f.DataLenPresent).To(BeTrue()) - frame, _ = str.popStreamFrame(protocol.MaxByteCount) - Expect(frame).ToNot(BeNil()) - f = frame.Frame.(*wire.StreamFrame) - Expect(f.Offset).To(Equal(protocol.ByteCount(0x45))) - Expect(f.Data).To(Equal([]byte("bar"))) - Expect(f.DataLenPresent).To(BeTrue()) - }) - - It("returns nil if the size is too small", func() { - str.numOutstandingFrames = 1 - f := &wire.StreamFrame{ - Data: []byte("foobar"), - Offset: 0x42, - DataLenPresent: false, - } - mockSender.EXPECT().onHasStreamData(streamID) - str.queueRetransmission(f) - frame, hasMoreData := str.popStreamFrame(2) - Expect(hasMoreData).To(BeTrue()) - Expect(frame).To(BeNil()) - }) - - It("queues lost STREAM frames", func() { - mockSender.EXPECT().onHasStreamData(streamID) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - frame, _ := str.popStreamFrame(protocol.MaxByteCount) - Eventually(done).Should(BeClosed()) - Expect(frame).ToNot(BeNil()) - Expect(frame.Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) - - // now lose the frame - mockSender.EXPECT().onHasStreamData(streamID) - frame.OnLost(frame.Frame) - newFrame, _ := str.popStreamFrame(protocol.MaxByteCount) - Expect(newFrame).ToNot(BeNil()) - Expect(newFrame.Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) - }) - - It("doesn't queue retransmissions for a stream that was canceled", func() { - mockSender.EXPECT().onHasStreamData(streamID) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - f, _ := str.popStreamFrame(100) - Expect(f).ToNot(BeNil()) - gomock.InOrder( - mockSender.EXPECT().queueControlFrame(gomock.Any()), - mockSender.EXPECT().onStreamCompleted(streamID), - ) - str.CancelWrite(9876) - // don't EXPECT any calls to onHasStreamData - f.OnLost(f.Frame) - Expect(str.retransmissionQueue).To(BeEmpty()) - }) - }) - - Context("determining when a stream is completed", func() { - BeforeEach(func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() - mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() - }) - - It("says when a stream is completed", func() { - mockSender.EXPECT().onHasStreamData(streamID) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(make([]byte, 100)) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - - // get a bunch of small frames (max. 20 bytes) - var frames []ackhandler.Frame - for { - frame, hasMoreData := str.popStreamFrame(20) - if frame == nil { - continue - } - frames = append(frames, *frame) - if !hasMoreData { - break - } - } - Eventually(done).Should(BeClosed()) - - // Acknowledge all frames. - // We don't expect the stream to be completed, since we still need to send the FIN. - for _, f := range frames { - f.OnAcked(f.Frame) - } - - // Now close the stream and acknowledge the FIN. - mockSender.EXPECT().onHasStreamData(streamID) - Expect(str.Close()).To(Succeed()) - frame, _ := str.popStreamFrame(protocol.MaxByteCount) - Expect(frame).ToNot(BeNil()) - mockSender.EXPECT().onStreamCompleted(streamID) - frame.OnAcked(frame.Frame) - }) - - It("says when a stream is completed, if Close() is called before popping the frame", func() { - mockSender.EXPECT().onHasStreamData(streamID).Times(2) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(make([]byte, 100)) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - Eventually(done).Should(BeClosed()) - Expect(str.Close()).To(Succeed()) - - frame, hasMoreData := str.popStreamFrame(protocol.MaxByteCount) - Expect(hasMoreData).To(BeFalse()) - Expect(frame).ToNot(BeNil()) - Expect(frame.Frame.(*wire.StreamFrame).Fin).To(BeTrue()) - - mockSender.EXPECT().onStreamCompleted(streamID) - frame.OnAcked(frame.Frame) - }) - - It("doesn't say it's completed when there are frames waiting to be retransmitted", func() { - mockSender.EXPECT().onHasStreamData(streamID) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(getData(100)) - Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().onHasStreamData(streamID) - Expect(str.Close()).To(Succeed()) - close(done) - }() - waitForWrite() - - // get a bunch of small frames (max. 20 bytes) - var frames []ackhandler.Frame - for { - frame, _ := str.popStreamFrame(20) - if frame == nil { - continue - } - frames = append(frames, *frame) - if frame.Frame.(*wire.StreamFrame).Fin { - break - } - } - Eventually(done).Should(BeClosed()) - - // lose the first frame, acknowledge all others - for _, f := range frames[1:] { - f.OnAcked(f.Frame) - } - mockSender.EXPECT().onHasStreamData(streamID) - frames[0].OnLost(frames[0].Frame) - - // get the retransmission and acknowledge it - ret, _ := str.popStreamFrame(protocol.MaxByteCount) - Expect(ret).ToNot(BeNil()) - mockSender.EXPECT().onStreamCompleted(streamID) - ret.OnAcked(ret.Frame) - }) - - // This test is kind of an integration test. - // It writes 4 MB of data, and pops STREAM frames that sometimes are and sometimes aren't limited by flow control. - // Half of these STREAM frames are then received and their content saved, while the other half is reported lost - // and has to be retransmitted. - It("retransmits data until everything has been acknowledged", func() { - const dataLen = 1 << 22 // 4 MB - mockSender.EXPECT().onHasStreamData(streamID).AnyTimes() - mockFC.EXPECT().SendWindowSize().DoAndReturn(func() protocol.ByteCount { - return protocol.ByteCount(mrand.Intn(500)) + 50 - }).AnyTimes() - mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() - - data := make([]byte, dataLen) - _, err := mrand.Read(data) - Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - _, err := str.Write(data) - Expect(err).ToNot(HaveOccurred()) - str.Close() - }() - - var completed bool - mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { completed = true }) - - received := make([]byte, dataLen) - for { - if completed { - break - } - f, _ := str.popStreamFrame(protocol.ByteCount(mrand.Intn(300) + 100)) - if f == nil { - continue - } - sf := f.Frame.(*wire.StreamFrame) - // 50%: acknowledge the frame and save the data - // 50%: lose the frame - if mrand.Intn(100) < 50 { - copy(received[sf.Offset:sf.Offset+sf.DataLen()], sf.Data) - f.OnAcked(f.Frame) - } else { - f.OnLost(f.Frame) - } - } - Expect(received).To(Equal(data)) - }) - }) -}) diff --git a/internal/quic-go/server.go b/internal/quic-go/server.go deleted file mode 100644 index 07173357..00000000 --- a/internal/quic-go/server.go +++ /dev/null @@ -1,670 +0,0 @@ -package quic - -import ( - "bytes" - "context" - "crypto/rand" - "crypto/tls" - "errors" - "fmt" - "net" - "sync" - "sync/atomic" - "time" - - "github.com/imroc/req/v3/internal/quic-go/handshake" - "github.com/imroc/req/v3/internal/quic-go/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close. -var ErrServerClosed = errors.New("quic: Server closed") - -// packetHandler handles packets -type packetHandler interface { - handlePacket(*receivedPacket) - shutdown() - destroy(error) - getPerspective() protocol.Perspective -} - -type unknownPacketHandler interface { - handlePacket(*receivedPacket) - setCloseError(error) -} - -type packetHandlerManager interface { - AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool - Destroy() error - connRunner - SetServer(unknownPacketHandler) - CloseServer() -} - -type quicConn interface { - EarlyConnection - earlyConnReady() <-chan struct{} - handlePacket(*receivedPacket) - GetVersion() protocol.VersionNumber - getPerspective() protocol.Perspective - run() error - destroy(error) - shutdown() -} - -// A Listener of QUIC -type baseServer struct { - mutex sync.Mutex - - acceptEarlyConns bool - - tlsConf *tls.Config - config *Config - - conn rawConn - // If the server is started with ListenAddr, we create a packet conn. - // If it is started with Listen, we take a packet conn as a parameter. - createdPacketConn bool - - tokenGenerator *handshake.TokenGenerator - - connHandler packetHandlerManager - - receivedPackets chan *receivedPacket - - // set as a member, so they can be set in the tests - newConn func( - sendConn, - connRunner, - protocol.ConnectionID, /* original dest connection ID */ - *protocol.ConnectionID, /* retry src connection ID */ - protocol.ConnectionID, /* client dest connection ID */ - protocol.ConnectionID, /* destination connection ID */ - protocol.ConnectionID, /* source connection ID */ - protocol.StatelessResetToken, - *Config, - *tls.Config, - *handshake.TokenGenerator, - bool, /* enable 0-RTT */ - logging.ConnectionTracer, - uint64, - utils.Logger, - protocol.VersionNumber, - ) quicConn - - serverError error - errorChan chan struct{} - closed bool - running chan struct{} // closed as soon as run() returns - - connQueue chan quicConn - connQueueLen int32 // to be used as an atomic - - logger utils.Logger -} - -var ( - _ Listener = &baseServer{} - _ unknownPacketHandler = &baseServer{} -) - -type earlyServer struct{ *baseServer } - -var _ EarlyListener = &earlyServer{} - -func (s *earlyServer) Accept(ctx context.Context) (EarlyConnection, error) { - return s.baseServer.accept(ctx) -} - -// ListenAddr creates a QUIC server listening on a given address. -// The tls.Config must not be nil and must contain a certificate configuration. -// The quic.Config may be nil, in that case the default values will be used. -func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { - return listenAddr(addr, tlsConf, config, false) -} - -// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. -func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) { - s, err := listenAddr(addr, tlsConf, config, true) - if err != nil { - return nil, err - } - return &earlyServer{s}, nil -} - -func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - return nil, err - } - serv, err := listen(conn, tlsConf, config, acceptEarly) - if err != nil { - return nil, err - } - serv.createdPacketConn = true - return serv, nil -} - -// Listen listens for QUIC connections on a given net.PacketConn. If the -// PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn -// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP -// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write -// packets. A single net.PacketConn only be used for a single call to Listen. -// The PacketConn can be used for simultaneous calls to Dial. QUIC connection -// IDs are used for demultiplexing the different connections. The tls.Config -// must not be nil and must contain a certificate configuration. The -// tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. Furthermore, -// it must define an application control (using NextProtos). The quic.Config may -// be nil, in that case the default values will be used. -func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { - return listen(conn, tlsConf, config, false) -} - -// ListenEarly works like Listen, but it returns connections before the handshake completes. -func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) { - s, err := listen(conn, tlsConf, config, true) - if err != nil { - return nil, err - } - return &earlyServer{s}, nil -} - -func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(config); err != nil { - return nil, err - } - config = populateServerConfig(config) - for _, v := range config.Versions { - if !protocol.IsValidVersion(v) { - return nil, fmt.Errorf("%s is not a valid QUIC version", v) - } - } - - connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) - if err != nil { - return nil, err - } - tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) - if err != nil { - return nil, err - } - c, err := wrapConn(conn) - if err != nil { - return nil, err - } - s := &baseServer{ - conn: c, - tlsConf: tlsConf, - config: config, - tokenGenerator: tokenGenerator, - connHandler: connHandler, - connQueue: make(chan quicConn), - errorChan: make(chan struct{}), - running: make(chan struct{}), - receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), - newConn: newConnection, - logger: utils.DefaultLogger.WithPrefix("server"), - acceptEarlyConns: acceptEarly, - } - go s.run() - connHandler.SetServer(s) - s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) - return s, nil -} - -func (s *baseServer) run() { - defer close(s.running) - for { - select { - case <-s.errorChan: - return - default: - } - select { - case <-s.errorChan: - return - case p := <-s.receivedPackets: - if bufferStillInUse := s.handlePacketImpl(p); !bufferStillInUse { - p.buffer.Release() - } - } - } -} - -var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool { - if token == nil { - return false - } - validity := protocol.TokenValidity - if token.IsRetryToken { - validity = protocol.RetryTokenValidity - } - if time.Now().After(token.SentTime.Add(validity)) { - return false - } - var sourceAddr string - if udpAddr, ok := clientAddr.(*net.UDPAddr); ok { - sourceAddr = udpAddr.IP.String() - } else { - sourceAddr = clientAddr.String() - } - return sourceAddr == token.RemoteAddr -} - -// Accept returns connections that already completed the handshake. -// It is only valid if acceptEarlyConns is false. -func (s *baseServer) Accept(ctx context.Context) (Connection, error) { - return s.accept(ctx) -} - -func (s *baseServer) accept(ctx context.Context) (quicConn, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case conn := <-s.connQueue: - atomic.AddInt32(&s.connQueueLen, -1) - return conn, nil - case <-s.errorChan: - return nil, s.serverError - } -} - -// Close the server -func (s *baseServer) Close() error { - s.mutex.Lock() - if s.closed { - s.mutex.Unlock() - return nil - } - if s.serverError == nil { - s.serverError = ErrServerClosed - } - // If the server was started with ListenAddr, we created the packet conn. - // We need to close it in order to make the go routine reading from that conn return. - createdPacketConn := s.createdPacketConn - s.closed = true - close(s.errorChan) - s.mutex.Unlock() - - <-s.running - s.connHandler.CloseServer() - if createdPacketConn { - return s.connHandler.Destroy() - } - return nil -} - -func (s *baseServer) setCloseError(e error) { - s.mutex.Lock() - defer s.mutex.Unlock() - if s.closed { - return - } - s.closed = true - s.serverError = e - close(s.errorChan) -} - -// Addr returns the server's network address -func (s *baseServer) Addr() net.Addr { - return s.conn.LocalAddr() -} - -func (s *baseServer) handlePacket(p *receivedPacket) { - select { - case s.receivedPackets <- p: - default: - s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) - } - } -} - -func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { - if wire.IsVersionNegotiationPacket(p.data) { - s.logger.Debugf("Dropping Version Negotiation packet.") - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - // If we're creating a new connection, the packet will be passed to the connection. - // The header will then be parsed again. - hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength) - if err != nil && err != wire.ErrUnsupportedVersion { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) - } - s.logger.Debugf("Error parsing packet: %s", err) - return false - } - // Short header packets should never end up here in the first place - if !hdr.IsLongHeader { - panic(fmt.Sprintf("misrouted packet: %#v", hdr)) - } - if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { - s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - // send a Version Negotiation Packet if the client is speaking a different protocol version - if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - if p.Size() < protocol.MinUnknownVersionPacketSize { - s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - if !s.config.DisableVersionNegotiationPackets { - go s.sendVersionNegotiationPacket(p, hdr) - } - return false - } - if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial { - // Drop long header packets. - // There's little point in sending a Stateless Reset, since the client - // might not have received the token yet. - s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - - s.logger.Debugf("<- Received Initial packet.") - - if err := s.handleInitialImpl(p, hdr); err != nil { - s.logger.Errorf("Error occurred handling initial packet: %s", err) - } - // Don't put the packet buffer back. - // handleInitialImpl deals with the buffer. - return true -} - -func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { - if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { - p.buffer.Release() - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) - } - return errors.New("too short connection ID") - } - - var ( - token *Token - retrySrcConnID *protocol.ConnectionID - ) - origDestConnID := hdr.DestConnectionID - if len(hdr.Token) > 0 { - c, err := s.tokenGenerator.DecodeToken(hdr.Token) - if err == nil { - token = &Token{ - IsRetryToken: c.IsRetryToken, - RemoteAddr: c.RemoteAddr, - SentTime: c.SentTime, - } - if token.IsRetryToken { - origDestConnID = c.OriginalDestConnectionID - retrySrcConnID = &c.RetrySrcConnectionID - } - } - } - if !s.config.AcceptToken(p.remoteAddr, token) { - go func() { - defer p.buffer.Release() - if token != nil && token.IsRetryToken { - if err := s.maybeSendInvalidToken(p, hdr); err != nil { - s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) - } - return - } - if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { - s.logger.Debugf("Error sending Retry: %s", err) - } - }() - return nil - } - - if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize { - s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) - go func() { - defer p.buffer.Release() - if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil { - s.logger.Debugf("Error rejecting connection: %s", err) - } - }() - return nil - } - - connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) - if err != nil { - return err - } - s.logger.Debugf("Changing connection ID to %s.", connID) - var conn quicConn - tracingID := nextConnTracingID() - if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { - var tracer logging.ConnectionTracer - if s.config.Tracer != nil { - // Use the same connection ID that is passed to the client's GetLogWriter callback. - connID := hdr.DestConnectionID - if origDestConnID.Len() > 0 { - connID = origDestConnID - } - tracer = s.config.Tracer.TracerForConnection( - context.WithValue(context.Background(), ConnectionTracingKey, tracingID), - protocol.PerspectiveServer, - connID, - ) - } - conn = s.newConn( - newSendConn(s.conn, p.remoteAddr, p.info), - s.connHandler, - origDestConnID, - retrySrcConnID, - hdr.DestConnectionID, - hdr.SrcConnectionID, - connID, - s.connHandler.GetStatelessResetToken(connID), - s.config, - s.tlsConf, - s.tokenGenerator, - s.acceptEarlyConns, - tracer, - tracingID, - s.logger, - hdr.Version, - ) - conn.handlePacket(p) - return conn - }); !added { - return nil - } - go conn.run() - go s.handleNewConn(conn) - if conn == nil { - p.buffer.Release() - return nil - } - return nil -} - -func (s *baseServer) handleNewConn(conn quicConn) { - connCtx := conn.Context() - if s.acceptEarlyConns { - // wait until the early connection is ready (or the handshake fails) - select { - case <-conn.earlyConnReady(): - case <-connCtx.Done(): - return - } - } else { - // wait until the handshake is complete (or fails) - select { - case <-conn.HandshakeComplete().Done(): - case <-connCtx.Done(): - return - } - } - - atomic.AddInt32(&s.connQueueLen, 1) - select { - case s.connQueue <- conn: - // blocks until the connection is accepted - case <-connCtx.Done(): - atomic.AddInt32(&s.connQueueLen, -1) - // don't pass connections that were already closed to Accept() - } -} - -func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { - // Log the Initial packet now. - // If no Retry is sent, the packet will be logged by the connection. - (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) - if err != nil { - return err - } - token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID, srcConnID) - if err != nil { - return err - } - replyHdr := &wire.ExtendedHeader{} - replyHdr.IsLongHeader = true - replyHdr.Type = protocol.PacketTypeRetry - replyHdr.Version = hdr.Version - replyHdr.SrcConnectionID = srcConnID - replyHdr.DestConnectionID = hdr.SrcConnectionID - replyHdr.Token = token - if s.logger.Debug() { - s.logger.Debugf("Changing connection ID to %s.", srcConnID) - s.logger.Debugf("-> Sending Retry") - replyHdr.Log(s.logger) - } - - packetBuffer := getPacketBuffer() - defer packetBuffer.Release() - buf := bytes.NewBuffer(packetBuffer.Data) - if err := replyHdr.Write(buf, hdr.Version); err != nil { - return err - } - // append the Retry integrity tag - tag := handshake.GetRetryIntegrityTag(buf.Bytes(), hdr.DestConnectionID, hdr.Version) - buf.Write(tag[:]) - if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil) - } - _, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info.OOB()) - return err -} - -func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) error { - // Only send INVALID_TOKEN if we can unprotect the packet. - // This makes sure that we won't send it for packets that were corrupted. - sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) - data := p.data[:hdr.ParsedLen()+hdr.Length] - extHdr, err := unpackHeader(opener, hdr, data, hdr.Version) - if err != nil { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) - } - // don't return the error here. Just drop the packet. - return nil - } - hdrLen := extHdr.ParsedLen() - if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { - // don't return the error here. Just drop the packet. - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) - } - return nil - } - if s.logger.Debug() { - s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr) - } - return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info) -} - -func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { - sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) - return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info) -} - -// sendError sends the error as a response to the packet received with header hdr -func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { - packetBuffer := getPacketBuffer() - defer packetBuffer.Release() - buf := bytes.NewBuffer(packetBuffer.Data) - - ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)} - - replyHdr := &wire.ExtendedHeader{} - replyHdr.IsLongHeader = true - replyHdr.Type = protocol.PacketTypeInitial - replyHdr.Version = hdr.Version - replyHdr.SrcConnectionID = hdr.DestConnectionID - replyHdr.DestConnectionID = hdr.SrcConnectionID - replyHdr.PacketNumberLen = protocol.PacketNumberLen4 - replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead()) - if err := replyHdr.Write(buf, hdr.Version); err != nil { - return err - } - payloadOffset := buf.Len() - - if err := ccf.Write(buf, hdr.Version); err != nil { - return err - } - - raw := buf.Bytes() - _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], replyHdr.PacketNumber, raw[:payloadOffset]) - raw = raw[0 : buf.Len()+sealer.Overhead()] - - pnOffset := payloadOffset - int(replyHdr.PacketNumberLen) - sealer.EncryptHeader( - raw[pnOffset+4:pnOffset+4+16], - &raw[0], - raw[pnOffset:payloadOffset], - ) - - replyHdr.Log(s.logger) - wire.LogFrame(s.logger, ccf, true) - if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf}) - } - _, err := s.conn.WritePacket(raw, remoteAddr, info.OOB()) - return err -} - -func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { - s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) - data := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) - if s.config.Tracer != nil { - s.config.Tracer.SentPacket( - p.remoteAddr, - &wire.Header{ - IsLongHeader: true, - DestConnectionID: hdr.SrcConnectionID, - SrcConnectionID: hdr.DestConnectionID, - }, - protocol.ByteCount(len(data)), - nil, - ) - } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { - s.logger.Debugf("Error sending Version Negotiation: %s", err) - } -} diff --git a/internal/quic-go/server_test.go b/internal/quic-go/server_test.go deleted file mode 100644 index e5e9c48e..00000000 --- a/internal/quic-go/server_test.go +++ /dev/null @@ -1,1237 +0,0 @@ -package quic - -import ( - "bytes" - "context" - "crypto/rand" - "crypto/tls" - "errors" - "net" - "reflect" - "runtime/pprof" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/imroc/req/v3/internal/quic-go/handshake" - "github.com/imroc/req/v3/internal/quic-go/logging" - mocklogging "github.com/imroc/req/v3/internal/quic-go/mocks/logging" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/testdata" - "github.com/imroc/req/v3/internal/quic-go/utils" - "github.com/imroc/req/v3/internal/quic-go/wire" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func areServersRunning() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic-go.(*baseServer).run") -} - -var _ = Describe("Server", func() { - var ( - conn *MockPacketConn - tlsConf *tls.Config - ) - - getPacket := func(hdr *wire.Header, p []byte) *receivedPacket { - buffer := getPacketBuffer() - buf := bytes.NewBuffer(buffer.Data) - if hdr.IsLongHeader { - hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 - } - Expect((&wire.ExtendedHeader{ - Header: *hdr, - PacketNumber: 0x42, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, protocol.VersionTLS)).To(Succeed()) - n := buf.Len() - buf.Write(p) - data := buffer.Data[:buf.Len()] - sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) - _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) - data = data[:len(data)+16] - sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) - return &receivedPacket{ - remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, - data: data, - buffer: buffer, - } - } - - getInitial := func(destConnID protocol.ConnectionID) *receivedPacket { - senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: destConnID, - Version: protocol.VersionTLS, - } - p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - p.buffer = getPacketBuffer() - p.remoteAddr = senderAddr - return p - } - - getInitialWithRandomDestConnID := func() *receivedPacket { - destConnID := make([]byte, 10) - _, err := rand.Read(destConnID) - Expect(err).ToNot(HaveOccurred()) - - return getInitial(destConnID) - } - - parseHeader := func(data []byte) *wire.Header { - hdr, _, _, err := wire.ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - return hdr - } - - BeforeEach(func() { - conn = NewMockPacketConn(mockCtrl) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() - conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1) - tlsConf = testdata.GetTLSConfig() - tlsConf.NextProtos = []string{"proto1"} - }) - - AfterEach(func() { - Eventually(areServersRunning).Should(BeFalse()) - }) - - It("errors when no tls.Config is given", func() { - _, err := ListenAddr("localhost:0", nil, nil) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set")) - }) - - It("errors when the Config contains an invalid version", func() { - version := protocol.VersionNumber(0x1234) - _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) - Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) - }) - - It("fills in default values if options are not set in the Config", func() { - ln, err := Listen(conn, tlsConf, &Config{}) - Expect(err).ToNot(HaveOccurred()) - server := ln.(*baseServer) - Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) - Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) - Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) - Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(defaultAcceptToken))) - Expect(server.config.KeepAlivePeriod).To(Equal(0 * time.Second)) - // stop the listener - Expect(ln.Close()).To(Succeed()) - }) - - It("setups with the right values", func() { - supportedVersions := []protocol.VersionNumber{protocol.VersionTLS} - acceptToken := func(_ net.Addr, _ *Token) bool { return true } - config := Config{ - Versions: supportedVersions, - AcceptToken: acceptToken, - HandshakeIdleTimeout: 1337 * time.Hour, - MaxIdleTimeout: 42 * time.Minute, - KeepAlivePeriod: 5 * time.Second, - StatelessResetKey: []byte("foobar"), - } - ln, err := Listen(conn, tlsConf, &config) - Expect(err).ToNot(HaveOccurred()) - server := ln.(*baseServer) - Expect(server.connHandler).ToNot(BeNil()) - Expect(server.config.Versions).To(Equal(supportedVersions)) - Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) - Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) - Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(acceptToken))) - Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) - Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar"))) - // stop the listener - Expect(ln.Close()).To(Succeed()) - }) - - It("listens on a given address", func() { - addr := "127.0.0.1:13579" - ln, err := ListenAddr(addr, tlsConf, &Config{}) - Expect(err).ToNot(HaveOccurred()) - Expect(ln.Addr().String()).To(Equal(addr)) - // stop the listener - Expect(ln.Close()).To(Succeed()) - }) - - It("errors if given an invalid address", func() { - addr := "127.0.0.1" - _, err := ListenAddr(addr, tlsConf, &Config{}) - Expect(err).To(BeAssignableToTypeOf(&net.AddrError{})) - }) - - It("errors if given an invalid address", func() { - addr := "1.1.1.1:1111" - _, err := ListenAddr(addr, tlsConf, &Config{}) - Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) - }) - - Context("server accepting connections that completed the handshake", func() { - var ( - serv *baseServer - phm *MockPacketHandlerManager - tracer *mocklogging.MockTracer - ) - - BeforeEach(func() { - tracer = mocklogging.NewMockTracer(mockCtrl) - ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer}) - Expect(err).ToNot(HaveOccurred()) - serv = ln.(*baseServer) - phm = NewMockPacketHandlerManager(mockCtrl) - serv.connHandler = phm - }) - - AfterEach(func() { - phm.EXPECT().CloseServer().MaxTimes(1) - serv.Close() - }) - - Context("handling packets", func() { - It("drops Initial packets with a too short connection ID", func() { - p := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Version: serv.config.Versions[0], - }, nil) - tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) - serv.handlePacket(p) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - - It("drops too small Initial", func() { - p := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize-100), - ) - tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) - serv.handlePacket(p) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - - It("drops non-Initial packets", func() { - p := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Version: serv.config.Versions[0], - }, []byte("invalid")) - tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket) - serv.handlePacket(p) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - - It("decodes the token from the Token field", func() { - raddr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 13, 37), - Port: 1337, - } - done := make(chan struct{}) - serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { - Expect(addr).To(Equal(raddr)) - Expect(token).ToNot(BeNil()) - close(done) - return false - } - token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) - Expect(err).ToNot(HaveOccurred()) - packet := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Token: token, - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("passes an empty token to the callback, if decoding fails", func() { - raddr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 13, 37), - Port: 1337, - } - done := make(chan struct{}) - serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { - Expect(addr).To(Equal(raddr)) - Expect(token).To(BeNil()) - close(done) - return false - } - packet := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Token: []byte("foobar"), - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("creates a connection when the token is accepted", func() { - serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true } - retryToken, err := serv.tokenGenerator.NewRetryToken( - &net.UDPAddr{}, - protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - ) - Expect(err).ToNot(HaveOccurred()) - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Version: protocol.VersionTLS, - Token: retryToken, - } - p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - run := make(chan struct{}) - var token protocol.StatelessResetToken - rand.Read(token[:]) - - var newConnID protocol.ConnectionID - phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { - newConnID = c - phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { - newConnID = c - return token - }) - fn() - return true - }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}) - conn := NewMockQuicConn(mockCtrl) - serv.newConn = func( - _ sendConn, - _ connRunner, - origDestConnID protocol.ConnectionID, - retrySrcConnID *protocol.ConnectionID, - clientDestConnID protocol.ConnectionID, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, - tokenP protocol.StatelessResetToken, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - enable0RTT bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - Expect(enable0RTT).To(BeFalse()) - Expect(origDestConnID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) - Expect(retrySrcConnID).To(Equal(&protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) - Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) - Expect(destConnID).To(Equal(hdr.SrcConnectionID)) - // make sure we're using a server-generated connection ID - Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) - Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) - Expect(srcConnID).To(Equal(newConnID)) - Expect(tokenP).To(Equal(token)) - conn.EXPECT().handlePacket(p) - conn.EXPECT().run().Do(func() { close(run) }) - conn.EXPECT().Context().Return(context.Background()) - conn.EXPECT().HandshakeComplete().Return(context.Background()) - return conn - } - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - serv.handlePacket(p) - // the Handshake packet is written by the connection. - // Make sure there are no Write calls on the packet conn. - time.Sleep(50 * time.Millisecond) - close(done) - }() - // make sure we're using a server-generated connection ID - Eventually(run).Should(BeClosed()) - Eventually(done).Should(BeClosed()) - }) - - It("sends a Version Negotiation Packet for unsupported versions", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} - packet := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - Version: 0x42, - }, make([]byte, protocol.MinUnknownVersionPacketSize)) - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { - Expect(replyHdr.IsLongHeader).To(BeTrue()) - Expect(replyHdr.Version).To(BeZero()) - Expect(replyHdr.SrcConnectionID).To(Equal(destConnID)) - Expect(replyHdr.DestConnectionID).To(Equal(srcConnID)) - }) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) - hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(b)) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(srcConnID)) - Expect(hdr.SrcConnectionID).To(Equal(destConnID)) - Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42))) - return len(b), nil - }) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't send a Version Negotiation packets if sending them is disabled", func() { - serv.config.DisableVersionNegotiationPackets = true - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} - packet := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - Version: 0x42, - }, make([]byte, protocol.MinUnknownVersionPacketSize)) - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - packet.remoteAddr = raddr - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).Do(func() { close(done) }).Times(0) - serv.handlePacket(packet) - Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed()) - }) - - It("ignores Version Negotiation packets", func() { - data := wire.ComposeVersionNegotiation( - protocol.ConnectionID{1, 2, 3, 4}, - protocol.ConnectionID{4, 3, 2, 1}, - []protocol.VersionNumber{1, 2, 3}, - ) - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { - close(done) - }) - serv.handlePacket(&receivedPacket{ - remoteAddr: raddr, - data: data, - buffer: getPacketBuffer(), - }) - Eventually(done).Should(BeClosed()) - // make sure no other packet is sent - time.Sleep(scaleDuration(20 * time.Millisecond)) - }) - - It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} - p := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - Version: 0x42, - }, make([]byte, protocol.MinUnknownVersionPacketSize-50)) - Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize)) - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - p.remoteAddr = raddr - done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { - close(done) - }) - serv.handlePacket(p) - Eventually(done).Should(BeClosed()) - // make sure no other packet is sent - time.Sleep(scaleDuration(20 * time.Millisecond)) - }) - - It("replies with a Retry packet, if a Token is required", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Version: protocol.VersionTLS, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(replyHdr.Token).ToNot(BeEmpty()) - }) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - replyHdr := parseHeader(b) - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(replyHdr.Token).ToNot(BeEmpty()) - Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:])) - return len(b), nil - }) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Token: token, - Version: protocol.VersionTLS, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(frames).To(HaveLen(1)) - Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := frames[0].(*logging.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) - }) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - replyHdr := parseHeader(b) - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - _, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) - extHdr, err := unpackHeader(opener, replyHdr, b, hdr.Version) - Expect(err).ToNot(HaveOccurred()) - data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) - Expect(err).ToNot(HaveOccurred()) - f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := f.(*wire.ConnectionCloseFrame) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) - Expect(ccf.ReasonPhrase).To(BeEmpty()) - return len(b), nil - }) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Token: token, - Version: protocol.VersionTLS, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet - packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) - serv.handlePacket(packet) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - Eventually(done).Should(BeClosed()) - }) - - It("creates a connection, if no Token is required", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Version: protocol.VersionTLS, - } - p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - run := make(chan struct{}) - var token protocol.StatelessResetToken - rand.Read(token[:]) - - var newConnID protocol.ConnectionID - phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { - newConnID = c - phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { - newConnID = c - return token - }) - fn() - return true - }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - - conn := NewMockQuicConn(mockCtrl) - serv.newConn = func( - _ sendConn, - _ connRunner, - origDestConnID protocol.ConnectionID, - retrySrcConnID *protocol.ConnectionID, - clientDestConnID protocol.ConnectionID, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, - tokenP protocol.StatelessResetToken, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - enable0RTT bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - Expect(enable0RTT).To(BeFalse()) - Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) - Expect(retrySrcConnID).To(BeNil()) - Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) - Expect(destConnID).To(Equal(hdr.SrcConnectionID)) - // make sure we're using a server-generated connection ID - Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) - Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) - Expect(srcConnID).To(Equal(newConnID)) - Expect(tokenP).To(Equal(token)) - conn.EXPECT().handlePacket(p) - conn.EXPECT().run().Do(func() { close(run) }) - conn.EXPECT().Context().Return(context.Background()) - conn.EXPECT().HandshakeComplete().Return(context.Background()) - return conn - } - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - serv.handlePacket(p) - // the Handshake packet is written by the connection - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - close(done) - }() - // make sure we're using a server-generated connection ID - Eventually(run).Should(BeClosed()) - Eventually(done).Should(BeClosed()) - }) - - It("drops packets if the receive queue is full", func() { - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true - }).AnyTimes() - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() - - serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } - acceptConn := make(chan struct{}) - var counter uint32 // to be used as an atomic, so we query it in Eventually - serv.newConn = func( - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.StatelessResetToken, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - <-acceptConn - atomic.AddUint32(&counter, 1) - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) - conn.EXPECT().run().MaxTimes(1) - conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) - conn.EXPECT().HandshakeComplete().Return(context.Background()).MaxTimes(1) - return conn - } - - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}) - serv.handlePacket(p) - tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1) - var wg sync.WaitGroup - for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ { - wg.Add(1) - go func() { - defer GinkgoRecover() - defer wg.Done() - serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})) - }() - } - wg.Wait() - - close(acceptConn) - Eventually( - func() uint32 { return atomic.LoadUint32(&counter) }, - scaleDuration(100*time.Millisecond), - ).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) - Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) - }) - - It("only creates a single connection for a duplicate Initial", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - var createdConn bool - conn := NewMockQuicConn(mockCtrl) - serv.newConn = func( - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.StatelessResetToken, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - createdConn = true - return conn - } - - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) - phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false) - Expect(serv.handlePacketImpl(p)).To(BeTrue()) - Expect(createdConn).To(BeFalse()) - }) - - It("rejects new connection attempts if the accept queue is full", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - - serv.newConn = func( - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.StatelessResetToken, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().run() - conn.EXPECT().Context().Return(context.Background()) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn.EXPECT().HandshakeComplete().Return(ctx) - return conn - } - - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true - }).Times(protocol.MaxAcceptQueueSize) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize) - - var wg sync.WaitGroup - wg.Add(protocol.MaxAcceptQueueSize) - for i := 0; i < protocol.MaxAcceptQueueSize; i++ { - go func() { - defer GinkgoRecover() - defer wg.Done() - serv.handlePacket(getInitialWithRandomDestConnID()) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }() - } - wg.Wait() - p := getInitialWithRandomDestConnID() - hdr, _, _, err := wire.ParsePacket(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), p.remoteAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - rejectHdr := parseHeader(b) - Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(rejectHdr.Version).To(Equal(hdr.Version)) - Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - return len(b), nil - }) - serv.handlePacket(p) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't accept new connections if they were closed in the mean time", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - ctx, cancel := context.WithCancel(context.Background()) - connCreated := make(chan struct{}) - conn := NewMockQuicConn(mockCtrl) - serv.newConn = func( - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.StatelessResetToken, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - conn.EXPECT().handlePacket(p) - conn.EXPECT().run() - conn.EXPECT().Context().Return(ctx) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn.EXPECT().HandshakeComplete().Return(ctx) - close(connCreated) - return conn - } - - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true - }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) - - serv.handlePacket(p) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - Eventually(connCreated).Should(BeClosed()) - cancel() - time.Sleep(scaleDuration(200 * time.Millisecond)) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - serv.Accept(context.Background()) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - - // make the go routine return - phm.EXPECT().CloseServer() - conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID - Expect(serv.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - }) - - Context("accepting connections", func() { - It("returns Accept when an error occurs", func() { - testErr := errors.New("test err") - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := serv.Accept(context.Background()) - Expect(err).To(MatchError(testErr)) - close(done) - }() - - serv.setCloseError(testErr) - Eventually(done).Should(BeClosed()) - }) - - It("returns immediately, if an error occurred before", func() { - testErr := errors.New("test err") - serv.setCloseError(testErr) - for i := 0; i < 3; i++ { - _, err := serv.Accept(context.Background()) - Expect(err).To(MatchError(testErr)) - } - }) - - It("returns when the context is canceled", func() { - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := serv.Accept(ctx) - Expect(err).To(MatchError("context canceled")) - close(done) - }() - - Consistently(done).ShouldNot(BeClosed()) - cancel() - Eventually(done).Should(BeClosed()) - }) - - It("accepts new connections when the handshake completes", func() { - conn := NewMockQuicConn(mockCtrl) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - s, err := serv.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(s).To(Equal(conn)) - close(done) - }() - - ctx, cancel := context.WithCancel(context.Background()) // handshake context - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - serv.newConn = func( - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.StatelessResetToken, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().HandshakeComplete().Return(ctx) - conn.EXPECT().run().Do(func() {}) - conn.EXPECT().Context().Return(context.Background()) - return conn - } - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true - }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) - serv.handleInitialImpl( - &receivedPacket{buffer: getPacketBuffer()}, - &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, - ) - Consistently(done).ShouldNot(BeClosed()) - cancel() // complete the handshake - Eventually(done).Should(BeClosed()) - }) - }) - }) - - Context("server accepting connections that haven't completed the handshake", func() { - var ( - serv *earlyServer - phm *MockPacketHandlerManager - ) - - BeforeEach(func() { - ln, err := ListenEarly(conn, tlsConf, nil) - Expect(err).ToNot(HaveOccurred()) - serv = ln.(*earlyServer) - phm = NewMockPacketHandlerManager(mockCtrl) - serv.connHandler = phm - }) - - AfterEach(func() { - phm.EXPECT().CloseServer().MaxTimes(1) - serv.Close() - }) - - It("accepts new connections when they become ready", func() { - conn := NewMockQuicConn(mockCtrl) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - s, err := serv.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(s).To(Equal(conn)) - close(done) - }() - - ready := make(chan struct{}) - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - serv.newConn = func( - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.StatelessResetToken, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - enable0RTT bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - Expect(enable0RTT).To(BeTrue()) - conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().run().Do(func() {}) - conn.EXPECT().earlyConnReady().Return(ready) - conn.EXPECT().Context().Return(context.Background()) - return conn - } - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true - }) - serv.handleInitialImpl( - &receivedPacket{buffer: getPacketBuffer()}, - &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, - ) - Consistently(done).ShouldNot(BeClosed()) - close(ready) - Eventually(done).Should(BeClosed()) - }) - - It("rejects new connection attempts if the accept queue is full", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} - - serv.newConn = func( - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.StatelessResetToken, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - ready := make(chan struct{}) - close(ready) - conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().run() - conn.EXPECT().earlyConnReady().Return(ready) - conn.EXPECT().Context().Return(context.Background()) - return conn - } - - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true - }).Times(protocol.MaxAcceptQueueSize) - for i := 0; i < protocol.MaxAcceptQueueSize; i++ { - serv.handlePacket(getInitialWithRandomDestConnID()) - } - - Eventually(func() int32 { return atomic.LoadInt32(&serv.connQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize)) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - - p := getInitialWithRandomDestConnID() - hdr := parseHeader(p.data) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), senderAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - rejectHdr := parseHeader(b) - Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(rejectHdr.Version).To(Equal(hdr.Version)) - Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - return len(b), nil - }) - serv.handlePacket(p) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't accept new connections if they were closed in the mean time", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - ctx, cancel := context.WithCancel(context.Background()) - connCreated := make(chan struct{}) - conn := NewMockQuicConn(mockCtrl) - serv.newConn = func( - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.StatelessResetToken, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - conn.EXPECT().handlePacket(p) - conn.EXPECT().run() - conn.EXPECT().earlyConnReady() - conn.EXPECT().Context().Return(ctx) - close(connCreated) - return conn - } - - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true - }) - serv.handlePacket(p) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - Eventually(connCreated).Should(BeClosed()) - cancel() - time.Sleep(scaleDuration(200 * time.Millisecond)) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - serv.Accept(context.Background()) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - - // make the go routine return - phm.EXPECT().CloseServer() - conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID - Expect(serv.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - }) -}) - -var _ = Describe("default source address verification", func() { - It("accepts a token", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1", - SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(time.Second), // will expire in 1 second - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) - }) - - It("requests verification if no token is provided", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - Expect(defaultAcceptToken(remoteAddr, nil)).To(BeFalse()) - }) - - It("rejects a token if the address doesn't match", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "127.0.0.1", - SentTime: time.Now(), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) - }) - - It("accepts a token for a remote address is not a UDP address", func() { - remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1:1337", - SentTime: time.Now(), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) - }) - - It("rejects an invalid token for a remote address is not a UDP address", func() { - remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1:7331", // mismatching port - SentTime: time.Now(), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) - }) - - It("rejects an expired token", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1", - SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), // expired 1 second ago - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) - }) - - It("accepts a non-retry token", func() { - Expect(protocol.RetryTokenValidity).To(BeNumerically("<", protocol.TokenValidity)) - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: false, - RemoteAddr: "192.168.0.1", - // if this was a retry token, it would have expired one second ago - SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) - }) -}) diff --git a/internal/quic-go/stream.go b/internal/quic-go/stream.go deleted file mode 100644 index 708304d1..00000000 --- a/internal/quic-go/stream.go +++ /dev/null @@ -1,149 +0,0 @@ -package quic - -import ( - "net" - "os" - "sync" - "time" - - "github.com/imroc/req/v3/internal/quic-go/ackhandler" - "github.com/imroc/req/v3/internal/quic-go/flowcontrol" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type deadlineError struct{} - -func (deadlineError) Error() string { return "deadline exceeded" } -func (deadlineError) Temporary() bool { return true } -func (deadlineError) Timeout() bool { return true } -func (deadlineError) Unwrap() error { return os.ErrDeadlineExceeded } - -var errDeadline net.Error = &deadlineError{} - -// The streamSender is notified by the stream about various events. -type streamSender interface { - queueControlFrame(wire.Frame) - onHasStreamData(protocol.StreamID) - // must be called without holding the mutex that is acquired by closeForShutdown - onStreamCompleted(protocol.StreamID) -} - -// Each of the both stream halves gets its own uniStreamSender. -// This is necessary in order to keep track when both halves have been completed. -type uniStreamSender struct { - streamSender - onStreamCompletedImpl func() -} - -func (s *uniStreamSender) queueControlFrame(f wire.Frame) { - s.streamSender.queueControlFrame(f) -} - -func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) { - s.streamSender.onHasStreamData(id) -} - -func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { - s.onStreamCompletedImpl() -} - -var _ streamSender = &uniStreamSender{} - -type streamI interface { - Stream - closeForShutdown(error) - // for receiving - handleStreamFrame(*wire.StreamFrame) error - handleResetStreamFrame(*wire.ResetStreamFrame) error - getWindowUpdate() protocol.ByteCount - // for sending - hasData() bool - handleStopSendingFrame(*wire.StopSendingFrame) - popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) - updateSendWindow(protocol.ByteCount) -} - -var ( - _ receiveStreamI = (streamI)(nil) - _ sendStreamI = (streamI)(nil) -) - -// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface -// -// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. -type stream struct { - receiveStream - sendStream - - completedMutex sync.Mutex - sender streamSender - receiveStreamCompleted bool - sendStreamCompleted bool - - version protocol.VersionNumber -} - -var _ Stream = &stream{} - -// newStream creates a new Stream -func newStream(streamID protocol.StreamID, - sender streamSender, - flowController flowcontrol.StreamFlowController, - version protocol.VersionNumber, -) *stream { - s := &stream{sender: sender, version: version} - senderForSendStream := &uniStreamSender{ - streamSender: sender, - onStreamCompletedImpl: func() { - s.completedMutex.Lock() - s.sendStreamCompleted = true - s.checkIfCompleted() - s.completedMutex.Unlock() - }, - } - s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version) - senderForReceiveStream := &uniStreamSender{ - streamSender: sender, - onStreamCompletedImpl: func() { - s.completedMutex.Lock() - s.receiveStreamCompleted = true - s.checkIfCompleted() - s.completedMutex.Unlock() - }, - } - s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController, version) - return s -} - -// need to define StreamID() here, since both receiveStream and readStream have a StreamID() -func (s *stream) StreamID() protocol.StreamID { - // the result is same for receiveStream and sendStream - return s.sendStream.StreamID() -} - -func (s *stream) Close() error { - return s.sendStream.Close() -} - -func (s *stream) SetDeadline(t time.Time) error { - _ = s.SetReadDeadline(t) // SetReadDeadline never errors - _ = s.SetWriteDeadline(t) // SetWriteDeadline never errors - return nil -} - -// CloseForShutdown closes a stream abruptly. -// It makes Read and Write unblock (and return the error) immediately. -// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. -func (s *stream) closeForShutdown(err error) { - s.sendStream.closeForShutdown(err) - s.receiveStream.closeForShutdown(err) -} - -// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed. -// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed. -func (s *stream) checkIfCompleted() { - if s.sendStreamCompleted && s.receiveStreamCompleted { - s.sender.onStreamCompleted(s.StreamID()) - } -} diff --git a/internal/quic-go/stream_test.go b/internal/quic-go/stream_test.go deleted file mode 100644 index 51f1e131..00000000 --- a/internal/quic-go/stream_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package quic - -import ( - "errors" - "io" - "os" - "strconv" - "time" - - "github.com/imroc/req/v3/internal/quic-go/mocks" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "github.com/onsi/gomega/gbytes" -) - -// in the tests for the stream deadlines we set a deadline -// and wait to make an assertion when Read / Write was unblocked -// on the CIs, the timing is a lot less precise, so scale every duration by this factor -func scaleDuration(t time.Duration) time.Duration { - scaleFactor := 1 - if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set - scaleFactor = f - } - Expect(scaleFactor).ToNot(BeZero()) - return time.Duration(scaleFactor) * t -} - -var _ = Describe("Stream", func() { - const streamID protocol.StreamID = 1337 - - var ( - str *stream - strWithTimeout io.ReadWriter // str wrapped with gbytes.Timeout{Reader,Writer} - mockFC *mocks.MockStreamFlowController - mockSender *MockStreamSender - ) - - BeforeEach(func() { - mockSender = NewMockStreamSender(mockCtrl) - mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newStream(streamID, mockSender, mockFC, protocol.VersionWhatever) - - timeout := scaleDuration(250 * time.Millisecond) - strWithTimeout = struct { - io.Reader - io.Writer - }{ - gbytes.TimeoutReader(str, timeout), - gbytes.TimeoutWriter(str, timeout), - } - }) - - It("gets stream id", func() { - Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) - }) - - Context("deadlines", func() { - It("sets a write deadline, when SetDeadline is called", func() { - str.SetDeadline(time.Now().Add(-time.Second)) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) - - It("sets a read deadline, when SetDeadline is called", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes() - f := &wire.StreamFrame{Data: []byte("foobar")} - err := str.handleStreamFrame(f) - Expect(err).ToNot(HaveOccurred()) - str.SetDeadline(time.Now().Add(-time.Second)) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) - }) - - Context("completing", func() { - It("is not completed when only the receive side is completed", func() { - // don't EXPECT a call to mockSender.onStreamCompleted() - str.receiveStream.sender.onStreamCompleted(streamID) - }) - - It("is not completed when only the send side is completed", func() { - // don't EXPECT a call to mockSender.onStreamCompleted() - str.sendStream.sender.onStreamCompleted(streamID) - }) - - It("is completed when both sides are completed", func() { - mockSender.EXPECT().onStreamCompleted(streamID) - str.sendStream.sender.onStreamCompleted(streamID) - str.receiveStream.sender.onStreamCompleted(streamID) - }) - }) -}) - -var _ = Describe("Deadline Error", func() { - It("is a net.Error that wraps os.ErrDeadlineError", func() { - err := deadlineError{} - Expect(err.Timeout()).To(BeTrue()) - Expect(errors.Is(err, os.ErrDeadlineExceeded)).To(BeTrue()) - Expect(errors.Unwrap(err)).To(Equal(os.ErrDeadlineExceeded)) - }) -}) diff --git a/internal/quic-go/streams_map.go b/internal/quic-go/streams_map.go deleted file mode 100644 index 93465533..00000000 --- a/internal/quic-go/streams_map.go +++ /dev/null @@ -1,317 +0,0 @@ -package quic - -import ( - "context" - "errors" - "fmt" - "net" - "sync" - - "github.com/imroc/req/v3/internal/quic-go/flowcontrol" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type streamError struct { - message string - nums []protocol.StreamNum -} - -func (e streamError) Error() string { - return e.message -} - -func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error { - strError, ok := err.(streamError) - if !ok { - return err - } - ids := make([]interface{}, len(strError.nums)) - for i, num := range strError.nums { - ids[i] = num.StreamID(stype, pers) - } - return fmt.Errorf(strError.Error(), ids...) -} - -type streamOpenErr struct{ error } - -var _ net.Error = &streamOpenErr{} - -func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams } -func (streamOpenErr) Timeout() bool { return false } - -// errTooManyOpenStreams is used internally by the outgoing streams maps. -var errTooManyOpenStreams = errors.New("too many open streams") - -type streamsMap struct { - perspective protocol.Perspective - version protocol.VersionNumber - - maxIncomingBidiStreams uint64 - maxIncomingUniStreams uint64 - - sender streamSender - newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController - - mutex sync.Mutex - outgoingBidiStreams *outgoingBidiStreamsMap - outgoingUniStreams *outgoingUniStreamsMap - incomingBidiStreams *incomingBidiStreamsMap - incomingUniStreams *incomingUniStreamsMap - reset bool -} - -var _ streamManager = &streamsMap{} - -func newStreamsMap( - sender streamSender, - newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController, - maxIncomingBidiStreams uint64, - maxIncomingUniStreams uint64, - perspective protocol.Perspective, - version protocol.VersionNumber, -) streamManager { - m := &streamsMap{ - perspective: perspective, - newFlowController: newFlowController, - maxIncomingBidiStreams: maxIncomingBidiStreams, - maxIncomingUniStreams: maxIncomingUniStreams, - sender: sender, - version: version, - } - m.initMaps() - return m -} - -func (m *streamsMap) initMaps() { - m.outgoingBidiStreams = newOutgoingBidiStreamsMap( - func(num protocol.StreamNum) streamI { - id := num.StreamID(protocol.StreamTypeBidi, m.perspective) - return newStream(id, m.sender, m.newFlowController(id), m.version) - }, - m.sender.queueControlFrame, - ) - m.incomingBidiStreams = newIncomingBidiStreamsMap( - func(num protocol.StreamNum) streamI { - id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite()) - return newStream(id, m.sender, m.newFlowController(id), m.version) - }, - m.maxIncomingBidiStreams, - m.sender.queueControlFrame, - ) - m.outgoingUniStreams = newOutgoingUniStreamsMap( - func(num protocol.StreamNum) sendStreamI { - id := num.StreamID(protocol.StreamTypeUni, m.perspective) - return newSendStream(id, m.sender, m.newFlowController(id), m.version) - }, - m.sender.queueControlFrame, - ) - m.incomingUniStreams = newIncomingUniStreamsMap( - func(num protocol.StreamNum) receiveStreamI { - id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite()) - return newReceiveStream(id, m.sender, m.newFlowController(id), m.version) - }, - m.maxIncomingUniStreams, - m.sender.queueControlFrame, - ) -} - -func (m *streamsMap) OpenStream() (Stream, error) { - m.mutex.Lock() - reset := m.reset - mm := m.outgoingBidiStreams - m.mutex.Unlock() - if reset { - return nil, Err0RTTRejected - } - str, err := mm.OpenStream() - return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) -} - -func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) { - m.mutex.Lock() - reset := m.reset - mm := m.outgoingBidiStreams - m.mutex.Unlock() - if reset { - return nil, Err0RTTRejected - } - str, err := mm.OpenStreamSync(ctx) - return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) -} - -func (m *streamsMap) OpenUniStream() (SendStream, error) { - m.mutex.Lock() - reset := m.reset - mm := m.outgoingUniStreams - m.mutex.Unlock() - if reset { - return nil, Err0RTTRejected - } - str, err := mm.OpenStream() - return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) -} - -func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) { - m.mutex.Lock() - reset := m.reset - mm := m.outgoingUniStreams - m.mutex.Unlock() - if reset { - return nil, Err0RTTRejected - } - str, err := mm.OpenStreamSync(ctx) - return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) -} - -func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) { - m.mutex.Lock() - reset := m.reset - mm := m.incomingBidiStreams - m.mutex.Unlock() - if reset { - return nil, Err0RTTRejected - } - str, err := mm.AcceptStream(ctx) - return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) -} - -func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { - m.mutex.Lock() - reset := m.reset - mm := m.incomingUniStreams - m.mutex.Unlock() - if reset { - return nil, Err0RTTRejected - } - str, err := mm.AcceptStream(ctx) - return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) -} - -func (m *streamsMap) DeleteStream(id protocol.StreamID) error { - num := id.StreamNum() - switch id.Type() { - case protocol.StreamTypeUni: - if id.InitiatedBy() == m.perspective { - return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective) - } - return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite()) - case protocol.StreamTypeBidi: - if id.InitiatedBy() == m.perspective { - return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective) - } - return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite()) - } - panic("") -} - -func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { - str, err := m.getOrOpenReceiveStream(id) - if err != nil { - return nil, &qerr.TransportError{ - ErrorCode: qerr.StreamStateError, - ErrorMessage: err.Error(), - } - } - return str, nil -} - -func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { - num := id.StreamNum() - switch id.Type() { - case protocol.StreamTypeUni: - if id.InitiatedBy() == m.perspective { - // an outgoing unidirectional stream is a send stream, not a receive stream - return nil, fmt.Errorf("peer attempted to open receive stream %d", id) - } - str, err := m.incomingUniStreams.GetOrOpenStream(num) - return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) - case protocol.StreamTypeBidi: - var str receiveStreamI - var err error - if id.InitiatedBy() == m.perspective { - str, err = m.outgoingBidiStreams.GetStream(num) - } else { - str, err = m.incomingBidiStreams.GetOrOpenStream(num) - } - return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) - } - panic("") -} - -func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { - str, err := m.getOrOpenSendStream(id) - if err != nil { - return nil, &qerr.TransportError{ - ErrorCode: qerr.StreamStateError, - ErrorMessage: err.Error(), - } - } - return str, nil -} - -func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { - num := id.StreamNum() - switch id.Type() { - case protocol.StreamTypeUni: - if id.InitiatedBy() == m.perspective { - str, err := m.outgoingUniStreams.GetStream(num) - return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) - } - // an incoming unidirectional stream is a receive stream, not a send stream - return nil, fmt.Errorf("peer attempted to open send stream %d", id) - case protocol.StreamTypeBidi: - var str sendStreamI - var err error - if id.InitiatedBy() == m.perspective { - str, err = m.outgoingBidiStreams.GetStream(num) - } else { - str, err = m.incomingBidiStreams.GetOrOpenStream(num) - } - return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) - } - panic("") -} - -func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) { - switch f.Type { - case protocol.StreamTypeUni: - m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum) - case protocol.StreamTypeBidi: - m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum) - } -} - -func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) { - m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote) - m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum) - m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni) - m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum) -} - -func (m *streamsMap) CloseWithError(err error) { - m.outgoingBidiStreams.CloseWithError(err) - m.outgoingUniStreams.CloseWithError(err) - m.incomingBidiStreams.CloseWithError(err) - m.incomingUniStreams.CloseWithError(err) -} - -// ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are -// 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error. -// 2. reset to their initial state, such that we can immediately process new incoming stream data. -// Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error, -// until UseResetMaps() has been called. -func (m *streamsMap) ResetFor0RTT() { - m.mutex.Lock() - defer m.mutex.Unlock() - m.reset = true - m.CloseWithError(Err0RTTRejected) - m.initMaps() -} - -func (m *streamsMap) UseResetMaps() { - m.mutex.Lock() - m.reset = false - m.mutex.Unlock() -} diff --git a/internal/quic-go/streams_map_generic_helper.go b/internal/quic-go/streams_map_generic_helper.go deleted file mode 100644 index eda9f741..00000000 --- a/internal/quic-go/streams_map_generic_helper.go +++ /dev/null @@ -1,18 +0,0 @@ -package quic - -import ( - "github.com/cheekybits/genny/generic" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// In the auto-generated streams maps, we need to be able to close the streams. -// Therefore, extend the generic.Type with the stream close method. -// This definition must be in a file that Genny doesn't process. -type item interface { - generic.Type - updateSendWindow(protocol.ByteCount) - closeForShutdown(error) -} - -const streamTypeGeneric protocol.StreamType = protocol.StreamTypeUni diff --git a/internal/quic-go/streams_map_incoming_bidi.go b/internal/quic-go/streams_map_incoming_bidi.go deleted file mode 100644 index 6b80a359..00000000 --- a/internal/quic-go/streams_map_incoming_bidi.go +++ /dev/null @@ -1,192 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package quic - -import ( - "context" - "sync" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// When a stream is deleted before it was accepted, we can't delete it from the map immediately. -// We need to wait until the application accepts it, and delete it then. -type streamIEntry struct { - stream streamI - shouldDelete bool -} - -type incomingBidiStreamsMap struct { - mutex sync.RWMutex - newStreamChan chan struct{} - - streams map[protocol.StreamNum]streamIEntry - - nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened - maxStream protocol.StreamNum // the highest stream that the peer is allowed to open - maxNumStreams uint64 // maximum number of streams - - newStream func(protocol.StreamNum) streamI - queueMaxStreamID func(*wire.MaxStreamsFrame) - - closeErr error -} - -func newIncomingBidiStreamsMap( - newStream func(protocol.StreamNum) streamI, - maxStreams uint64, - queueControlFrame func(wire.Frame), -) *incomingBidiStreamsMap { - return &incomingBidiStreamsMap{ - newStreamChan: make(chan struct{}, 1), - streams: make(map[protocol.StreamNum]streamIEntry), - maxStream: protocol.StreamNum(maxStreams), - maxNumStreams: maxStreams, - newStream: newStream, - nextStreamToOpen: 1, - nextStreamToAccept: 1, - queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, - } -} - -func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) { - // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist - select { - case <-m.newStreamChan: - default: - } - - m.mutex.Lock() - - var num protocol.StreamNum - var entry streamIEntry - for { - num = m.nextStreamToAccept - if m.closeErr != nil { - m.mutex.Unlock() - return nil, m.closeErr - } - var ok bool - entry, ok = m.streams[num] - if ok { - break - } - m.mutex.Unlock() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-m.newStreamChan: - } - m.mutex.Lock() - } - m.nextStreamToAccept++ - // If this stream was completed before being accepted, we can delete it now. - if entry.shouldDelete { - if err := m.deleteStream(num); err != nil { - m.mutex.Unlock() - return nil, err - } - } - m.mutex.Unlock() - return entry.stream, nil -} - -func (m *incomingBidiStreamsMap) GetOrOpenStream(num protocol.StreamNum) (streamI, error) { - m.mutex.RLock() - if num > m.maxStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer tried to open stream %d (current limit: %d)", - nums: []protocol.StreamNum{num, m.maxStream}, - } - } - // if the num is smaller than the highest we accepted - // * this stream exists in the map, and we can return it, or - // * this stream was already closed, then we can return the nil - if num < m.nextStreamToOpen { - var s streamI - // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. - if entry, ok := m.streams[num]; ok && !entry.shouldDelete { - s = entry.stream - } - m.mutex.RUnlock() - return s, nil - } - m.mutex.RUnlock() - - m.mutex.Lock() - // no need to check the two error conditions from above again - // * maxStream can only increase, so if the id was valid before, it definitely is valid now - // * highestStream is only modified by this function - for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { - m.streams[newNum] = streamIEntry{stream: m.newStream(newNum)} - select { - case m.newStreamChan <- struct{}{}: - default: - } - } - m.nextStreamToOpen = num + 1 - entry := m.streams[num] - m.mutex.Unlock() - return entry.stream, nil -} - -func (m *incomingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - return m.deleteStream(num) -} - -func (m *incomingBidiStreamsMap) deleteStream(num protocol.StreamNum) error { - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown incoming stream %d", - nums: []protocol.StreamNum{num}, - } - } - - // Don't delete this stream yet, if it was not yet accepted. - // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. - if num >= m.nextStreamToAccept { - entry, ok := m.streams[num] - if ok && entry.shouldDelete { - return streamError{ - message: "tried to delete incoming stream %d multiple times", - nums: []protocol.StreamNum{num}, - } - } - entry.shouldDelete = true - m.streams[num] = entry // can't assign to struct in map, so we need to reassign - return nil - } - - delete(m.streams, num) - // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream - if m.maxNumStreams > uint64(len(m.streams)) { - maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 - // Never send a value larger than protocol.MaxStreamCount. - if maxStream <= protocol.MaxStreamCount { - m.maxStream = maxStream - m.queueMaxStreamID(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: m.maxStream, - }) - } - } - return nil -} - -func (m *incomingBidiStreamsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, entry := range m.streams { - entry.stream.closeForShutdown(err) - } - m.mutex.Unlock() - close(m.newStreamChan) -} diff --git a/internal/quic-go/streams_map_incoming_generic.go b/internal/quic-go/streams_map_incoming_generic.go deleted file mode 100644 index 35a9e12d..00000000 --- a/internal/quic-go/streams_map_incoming_generic.go +++ /dev/null @@ -1,190 +0,0 @@ -package quic - -import ( - "context" - "sync" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// When a stream is deleted before it was accepted, we can't delete it from the map immediately. -// We need to wait until the application accepts it, and delete it then. -type itemEntry struct { - stream item - shouldDelete bool -} - -//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi" -//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni" -type incomingItemsMap struct { - mutex sync.RWMutex - newStreamChan chan struct{} - - streams map[protocol.StreamNum]itemEntry - - nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened - maxStream protocol.StreamNum // the highest stream that the peer is allowed to open - maxNumStreams uint64 // maximum number of streams - - newStream func(protocol.StreamNum) item - queueMaxStreamID func(*wire.MaxStreamsFrame) - - closeErr error -} - -func newIncomingItemsMap( - newStream func(protocol.StreamNum) item, - maxStreams uint64, - queueControlFrame func(wire.Frame), -) *incomingItemsMap { - return &incomingItemsMap{ - newStreamChan: make(chan struct{}, 1), - streams: make(map[protocol.StreamNum]itemEntry), - maxStream: protocol.StreamNum(maxStreams), - maxNumStreams: maxStreams, - newStream: newStream, - nextStreamToOpen: 1, - nextStreamToAccept: 1, - queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, - } -} - -func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) { - // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist - select { - case <-m.newStreamChan: - default: - } - - m.mutex.Lock() - - var num protocol.StreamNum - var entry itemEntry - for { - num = m.nextStreamToAccept - if m.closeErr != nil { - m.mutex.Unlock() - return nil, m.closeErr - } - var ok bool - entry, ok = m.streams[num] - if ok { - break - } - m.mutex.Unlock() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-m.newStreamChan: - } - m.mutex.Lock() - } - m.nextStreamToAccept++ - // If this stream was completed before being accepted, we can delete it now. - if entry.shouldDelete { - if err := m.deleteStream(num); err != nil { - m.mutex.Unlock() - return nil, err - } - } - m.mutex.Unlock() - return entry.stream, nil -} - -func (m *incomingItemsMap) GetOrOpenStream(num protocol.StreamNum) (item, error) { - m.mutex.RLock() - if num > m.maxStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer tried to open stream %d (current limit: %d)", - nums: []protocol.StreamNum{num, m.maxStream}, - } - } - // if the num is smaller than the highest we accepted - // * this stream exists in the map, and we can return it, or - // * this stream was already closed, then we can return the nil - if num < m.nextStreamToOpen { - var s item - // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. - if entry, ok := m.streams[num]; ok && !entry.shouldDelete { - s = entry.stream - } - m.mutex.RUnlock() - return s, nil - } - m.mutex.RUnlock() - - m.mutex.Lock() - // no need to check the two error conditions from above again - // * maxStream can only increase, so if the id was valid before, it definitely is valid now - // * highestStream is only modified by this function - for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { - m.streams[newNum] = itemEntry{stream: m.newStream(newNum)} - select { - case m.newStreamChan <- struct{}{}: - default: - } - } - m.nextStreamToOpen = num + 1 - entry := m.streams[num] - m.mutex.Unlock() - return entry.stream, nil -} - -func (m *incomingItemsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - return m.deleteStream(num) -} - -func (m *incomingItemsMap) deleteStream(num protocol.StreamNum) error { - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown incoming stream %d", - nums: []protocol.StreamNum{num}, - } - } - - // Don't delete this stream yet, if it was not yet accepted. - // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. - if num >= m.nextStreamToAccept { - entry, ok := m.streams[num] - if ok && entry.shouldDelete { - return streamError{ - message: "tried to delete incoming stream %d multiple times", - nums: []protocol.StreamNum{num}, - } - } - entry.shouldDelete = true - m.streams[num] = entry // can't assign to struct in map, so we need to reassign - return nil - } - - delete(m.streams, num) - // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream - if m.maxNumStreams > uint64(len(m.streams)) { - maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 - // Never send a value larger than protocol.MaxStreamCount. - if maxStream <= protocol.MaxStreamCount { - m.maxStream = maxStream - m.queueMaxStreamID(&wire.MaxStreamsFrame{ - Type: streamTypeGeneric, - MaxStreamNum: m.maxStream, - }) - } - } - return nil -} - -func (m *incomingItemsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, entry := range m.streams { - entry.stream.closeForShutdown(err) - } - m.mutex.Unlock() - close(m.newStreamChan) -} diff --git a/internal/quic-go/streams_map_incoming_generic_test.go b/internal/quic-go/streams_map_incoming_generic_test.go deleted file mode 100644 index 0983ba4f..00000000 --- a/internal/quic-go/streams_map_incoming_generic_test.go +++ /dev/null @@ -1,307 +0,0 @@ -package quic - -import ( - "bytes" - "context" - "errors" - "math/rand" - "time" - - "github.com/golang/mock/gomock" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type mockGenericStream struct { - num protocol.StreamNum - - closed bool - closeErr error - sendWindow protocol.ByteCount -} - -func (s *mockGenericStream) closeForShutdown(err error) { - s.closed = true - s.closeErr = err -} - -func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) { - s.sendWindow = limit -} - -var _ = Describe("Streams Map (incoming)", func() { - var ( - m *incomingItemsMap - newItemCounter int - mockSender *MockStreamSender - maxNumStreams uint64 - ) - - // check that the frame can be serialized and deserialized - checkFrameSerialization := func(f wire.Frame) { - b := &bytes.Buffer{} - ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed()) - frame, err := wire.NewFrameParser(false, protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - Expect(f).To(Equal(frame)) - } - - BeforeEach(func() { maxNumStreams = 5 }) - - JustBeforeEach(func() { - newItemCounter = 0 - mockSender = NewMockStreamSender(mockCtrl) - m = newIncomingItemsMap( - func(num protocol.StreamNum) item { - newItemCounter++ - return &mockGenericStream{num: num} - }, - maxNumStreams, - mockSender.queueControlFrame, - ) - }) - - It("opens all streams up to the id on GetOrOpenStream", func() { - _, err := m.GetOrOpenStream(4) - Expect(err).ToNot(HaveOccurred()) - Expect(newItemCounter).To(Equal(4)) - }) - - It("starts opening streams at the right position", func() { - // like the test above, but with 2 calls to GetOrOpenStream - _, err := m.GetOrOpenStream(2) - Expect(err).ToNot(HaveOccurred()) - Expect(newItemCounter).To(Equal(2)) - _, err = m.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - Expect(newItemCounter).To(Equal(5)) - }) - - It("accepts streams in the right order", func() { - _, err := m.GetOrOpenStream(2) // open streams 1 and 2 - Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - str, err = m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) - }) - - It("allows opening the maximum stream ID", func() { - str, err := m.GetOrOpenStream(1) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - }) - - It("errors when trying to get a stream ID higher than the maximum", func() { - _, err := m.GetOrOpenStream(6) - Expect(err).To(HaveOccurred()) - Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)")) - }) - - It("blocks AcceptStream until a new stream is available", func() { - strChan := make(chan item) - go func() { - defer GinkgoRecover() - str, err := m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - strChan <- str - }() - Consistently(strChan).ShouldNot(Receive()) - str, err := m.GetOrOpenStream(1) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - var acceptedStr item - Eventually(strChan).Should(Receive(&acceptedStr)) - Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - }) - - It("unblocks AcceptStream when the context is canceled", func() { - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := m.AcceptStream(ctx) - Expect(err).To(MatchError("context canceled")) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - cancel() - Eventually(done).Should(BeClosed()) - }) - - It("unblocks AcceptStream when it is closed", func() { - testErr := errors.New("test error") - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := m.AcceptStream(context.Background()) - Expect(err).To(MatchError(testErr)) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - m.CloseWithError(testErr) - Eventually(done).Should(BeClosed()) - }) - - It("errors AcceptStream immediately if it is closed", func() { - testErr := errors.New("test error") - m.CloseWithError(testErr) - _, err := m.AcceptStream(context.Background()) - Expect(err).To(MatchError(testErr)) - }) - - It("closes all streams when CloseWithError is called", func() { - str1, err := m.GetOrOpenStream(1) - Expect(err).ToNot(HaveOccurred()) - str2, err := m.GetOrOpenStream(3) - Expect(err).ToNot(HaveOccurred()) - testErr := errors.New("test err") - m.CloseWithError(testErr) - Expect(str1.(*mockGenericStream).closed).To(BeTrue()) - Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr)) - Expect(str2.(*mockGenericStream).closed).To(BeTrue()) - Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr)) - }) - - It("deletes streams", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - _, err := m.GetOrOpenStream(1) - Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - Expect(m.DeleteStream(1)).To(Succeed()) - str, err = m.GetOrOpenStream(1) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeNil()) - }) - - It("waits until a stream is accepted before actually deleting it", func() { - _, err := m.GetOrOpenStream(2) - Expect(err).ToNot(HaveOccurred()) - Expect(m.DeleteStream(2)).To(Succeed()) - str, err := m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued - mockSender.EXPECT().queueControlFrame(gomock.Any()) - str, err = m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) - }) - - It("doesn't return a stream queued for deleting from GetOrOpenStream", func() { - str, err := m.GetOrOpenStream(1) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - Expect(m.DeleteStream(1)).To(Succeed()) - str, err = m.GetOrOpenStream(1) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeNil()) - // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued - mockSender.EXPECT().queueControlFrame(gomock.Any()) - str, err = m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - }) - - It("errors when deleting a non-existing stream", func() { - err := m.DeleteStream(1337) - Expect(err).To(HaveOccurred()) - Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown incoming stream 1337")) - }) - - It("sends MAX_STREAMS frames when streams are deleted", func() { - // open a bunch of streams - _, err := m.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - // accept all streams - for i := 0; i < 5; i++ { - _, err := m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - } - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1))) - checkFrameSerialization(f) - }) - Expect(m.DeleteStream(3)).To(Succeed()) - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2))) - checkFrameSerialization(f) - }) - Expect(m.DeleteStream(4)).To(Succeed()) - }) - - Context("using high stream limits", func() { - BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 }) - - It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() { - // open a bunch of streams - _, err := m.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - // accept all streams - for i := 0; i < 5; i++ { - _, err := m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - } - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1)) - checkFrameSerialization(f) - }) - Expect(m.DeleteStream(4)).To(Succeed()) - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount)) - checkFrameSerialization(f) - }) - Expect(m.DeleteStream(3)).To(Succeed()) - // at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent - Expect(m.DeleteStream(2)).To(Succeed()) - Expect(m.DeleteStream(1)).To(Succeed()) - }) - }) - - Context("randomized tests", func() { - const num = 1000 - - BeforeEach(func() { maxNumStreams = num }) - - It("opens and accepts streams", func() { - rand.Seed(GinkgoRandomSeed()) - ids := make([]protocol.StreamNum, num) - for i := 0; i < num; i++ { - ids[i] = protocol.StreamNum(i + 1) - } - rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] }) - - const timeout = 5 * time.Second - done := make(chan struct{}, 2) - go func() { - defer GinkgoRecover() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - for i := 0; i < num; i++ { - _, err := m.AcceptStream(ctx) - Expect(err).ToNot(HaveOccurred()) - } - done <- struct{}{} - }() - - go func() { - defer GinkgoRecover() - for i := 0; i < num; i++ { - _, err := m.GetOrOpenStream(ids[i]) - Expect(err).ToNot(HaveOccurred()) - } - done <- struct{}{} - }() - - Eventually(done, timeout*3/2).Should(Receive()) - Eventually(done, timeout*3/2).Should(Receive()) - }) - }) -}) diff --git a/internal/quic-go/streams_map_incoming_uni.go b/internal/quic-go/streams_map_incoming_uni.go deleted file mode 100644 index c567a562..00000000 --- a/internal/quic-go/streams_map_incoming_uni.go +++ /dev/null @@ -1,192 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package quic - -import ( - "context" - "sync" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// When a stream is deleted before it was accepted, we can't delete it from the map immediately. -// We need to wait until the application accepts it, and delete it then. -type receiveStreamIEntry struct { - stream receiveStreamI - shouldDelete bool -} - -type incomingUniStreamsMap struct { - mutex sync.RWMutex - newStreamChan chan struct{} - - streams map[protocol.StreamNum]receiveStreamIEntry - - nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened - maxStream protocol.StreamNum // the highest stream that the peer is allowed to open - maxNumStreams uint64 // maximum number of streams - - newStream func(protocol.StreamNum) receiveStreamI - queueMaxStreamID func(*wire.MaxStreamsFrame) - - closeErr error -} - -func newIncomingUniStreamsMap( - newStream func(protocol.StreamNum) receiveStreamI, - maxStreams uint64, - queueControlFrame func(wire.Frame), -) *incomingUniStreamsMap { - return &incomingUniStreamsMap{ - newStreamChan: make(chan struct{}, 1), - streams: make(map[protocol.StreamNum]receiveStreamIEntry), - maxStream: protocol.StreamNum(maxStreams), - maxNumStreams: maxStreams, - newStream: newStream, - nextStreamToOpen: 1, - nextStreamToAccept: 1, - queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, - } -} - -func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) { - // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist - select { - case <-m.newStreamChan: - default: - } - - m.mutex.Lock() - - var num protocol.StreamNum - var entry receiveStreamIEntry - for { - num = m.nextStreamToAccept - if m.closeErr != nil { - m.mutex.Unlock() - return nil, m.closeErr - } - var ok bool - entry, ok = m.streams[num] - if ok { - break - } - m.mutex.Unlock() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-m.newStreamChan: - } - m.mutex.Lock() - } - m.nextStreamToAccept++ - // If this stream was completed before being accepted, we can delete it now. - if entry.shouldDelete { - if err := m.deleteStream(num); err != nil { - m.mutex.Unlock() - return nil, err - } - } - m.mutex.Unlock() - return entry.stream, nil -} - -func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receiveStreamI, error) { - m.mutex.RLock() - if num > m.maxStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer tried to open stream %d (current limit: %d)", - nums: []protocol.StreamNum{num, m.maxStream}, - } - } - // if the num is smaller than the highest we accepted - // * this stream exists in the map, and we can return it, or - // * this stream was already closed, then we can return the nil - if num < m.nextStreamToOpen { - var s receiveStreamI - // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. - if entry, ok := m.streams[num]; ok && !entry.shouldDelete { - s = entry.stream - } - m.mutex.RUnlock() - return s, nil - } - m.mutex.RUnlock() - - m.mutex.Lock() - // no need to check the two error conditions from above again - // * maxStream can only increase, so if the id was valid before, it definitely is valid now - // * highestStream is only modified by this function - for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { - m.streams[newNum] = receiveStreamIEntry{stream: m.newStream(newNum)} - select { - case m.newStreamChan <- struct{}{}: - default: - } - } - m.nextStreamToOpen = num + 1 - entry := m.streams[num] - m.mutex.Unlock() - return entry.stream, nil -} - -func (m *incomingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - return m.deleteStream(num) -} - -func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error { - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown incoming stream %d", - nums: []protocol.StreamNum{num}, - } - } - - // Don't delete this stream yet, if it was not yet accepted. - // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. - if num >= m.nextStreamToAccept { - entry, ok := m.streams[num] - if ok && entry.shouldDelete { - return streamError{ - message: "tried to delete incoming stream %d multiple times", - nums: []protocol.StreamNum{num}, - } - } - entry.shouldDelete = true - m.streams[num] = entry // can't assign to struct in map, so we need to reassign - return nil - } - - delete(m.streams, num) - // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream - if m.maxNumStreams > uint64(len(m.streams)) { - maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 - // Never send a value larger than protocol.MaxStreamCount. - if maxStream <= protocol.MaxStreamCount { - m.maxStream = maxStream - m.queueMaxStreamID(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeUni, - MaxStreamNum: m.maxStream, - }) - } - } - return nil -} - -func (m *incomingUniStreamsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, entry := range m.streams { - entry.stream.closeForShutdown(err) - } - m.mutex.Unlock() - close(m.newStreamChan) -} diff --git a/internal/quic-go/streams_map_outgoing_bidi.go b/internal/quic-go/streams_map_outgoing_bidi.go deleted file mode 100644 index f76eda16..00000000 --- a/internal/quic-go/streams_map_outgoing_bidi.go +++ /dev/null @@ -1,226 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package quic - -import ( - "context" - "sync" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type outgoingBidiStreamsMap struct { - mutex sync.RWMutex - - streams map[protocol.StreamNum]streamI - - openQueue map[uint64]chan struct{} - lowestInQueue uint64 - highestInQueue uint64 - - nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) - maxStream protocol.StreamNum // the maximum stream ID we're allowed to open - blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - - newStream func(protocol.StreamNum) streamI - queueStreamIDBlocked func(*wire.StreamsBlockedFrame) - - closeErr error -} - -func newOutgoingBidiStreamsMap( - newStream func(protocol.StreamNum) streamI, - queueControlFrame func(wire.Frame), -) *outgoingBidiStreamsMap { - return &outgoingBidiStreamsMap{ - streams: make(map[protocol.StreamNum]streamI), - openQueue: make(map[uint64]chan struct{}), - maxStream: protocol.InvalidStreamNum, - nextStream: 1, - newStream: newStream, - queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, - } -} - -func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - // if there are OpenStreamSync calls waiting, return an error here - if len(m.openQueue) > 0 || m.nextStream > m.maxStream { - m.maybeSendBlockedFrame() - return nil, streamOpenErr{errTooManyOpenStreams} - } - return m.openStream(), nil -} - -func (m *outgoingBidiStreamsMap) OpenStreamSync(ctx context.Context) (streamI, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - if err := ctx.Err(); err != nil { - return nil, err - } - - if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { - return m.openStream(), nil - } - - waitChan := make(chan struct{}, 1) - queuePos := m.highestInQueue - m.highestInQueue++ - if len(m.openQueue) == 0 { - m.lowestInQueue = queuePos - } - m.openQueue[queuePos] = waitChan - m.maybeSendBlockedFrame() - - for { - m.mutex.Unlock() - select { - case <-ctx.Done(): - m.mutex.Lock() - delete(m.openQueue, queuePos) - return nil, ctx.Err() - case <-waitChan: - } - m.mutex.Lock() - - if m.closeErr != nil { - return nil, m.closeErr - } - if m.nextStream > m.maxStream { - // no stream available. Continue waiting - continue - } - str := m.openStream() - delete(m.openQueue, queuePos) - m.lowestInQueue = queuePos + 1 - m.unblockOpenSync() - return str, nil - } -} - -func (m *outgoingBidiStreamsMap) openStream() streamI { - s := m.newStream(m.nextStream) - m.streams[m.nextStream] = s - m.nextStream++ - return s -} - -// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, -// if we haven't sent one for this offset yet -func (m *outgoingBidiStreamsMap) maybeSendBlockedFrame() { - if m.blockedSent { - return - } - - var streamNum protocol.StreamNum - if m.maxStream != protocol.InvalidStreamNum { - streamNum = m.maxStream - } - m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ - Type: protocol.StreamTypeBidi, - StreamLimit: streamNum, - }) - m.blockedSent = true -} - -func (m *outgoingBidiStreamsMap) GetStream(num protocol.StreamNum) (streamI, error) { - m.mutex.RLock() - if num >= m.nextStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer attempted to open stream %d", - nums: []protocol.StreamNum{num}, - } - } - s := m.streams[num] - m.mutex.RUnlock() - return s, nil -} - -func (m *outgoingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown outgoing stream %d", - nums: []protocol.StreamNum{num}, - } - } - delete(m.streams, num) - return nil -} - -func (m *outgoingBidiStreamsMap) SetMaxStream(num protocol.StreamNum) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if num <= m.maxStream { - return - } - m.maxStream = num - m.blockedSent = false - if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { - m.maybeSendBlockedFrame() - } - m.unblockOpenSync() -} - -// UpdateSendWindow is called when the peer's transport parameters are received. -// Only in the case of a 0-RTT handshake will we have open streams at this point. -// We might need to update the send window, in case the server increased it. -func (m *outgoingBidiStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { - m.mutex.Lock() - for _, str := range m.streams { - str.updateSendWindow(limit) - } - m.mutex.Unlock() -} - -// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream -func (m *outgoingBidiStreamsMap) unblockOpenSync() { - if len(m.openQueue) == 0 { - return - } - for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { - c, ok := m.openQueue[qp] - if !ok { // entry was deleted because the context was canceled - continue - } - // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. - // It's sufficient to only unblock OpenStreamSync once. - select { - case c <- struct{}{}: - default: - } - return - } -} - -func (m *outgoingBidiStreamsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, str := range m.streams { - str.closeForShutdown(err) - } - for _, c := range m.openQueue { - if c != nil { - close(c) - } - } - m.mutex.Unlock() -} diff --git a/internal/quic-go/streams_map_outgoing_generic.go b/internal/quic-go/streams_map_outgoing_generic.go deleted file mode 100644 index f5449ed2..00000000 --- a/internal/quic-go/streams_map_outgoing_generic.go +++ /dev/null @@ -1,224 +0,0 @@ -package quic - -import ( - "context" - "sync" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -//go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi" -//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni" -type outgoingItemsMap struct { - mutex sync.RWMutex - - streams map[protocol.StreamNum]item - - openQueue map[uint64]chan struct{} - lowestInQueue uint64 - highestInQueue uint64 - - nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) - maxStream protocol.StreamNum // the maximum stream ID we're allowed to open - blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - - newStream func(protocol.StreamNum) item - queueStreamIDBlocked func(*wire.StreamsBlockedFrame) - - closeErr error -} - -func newOutgoingItemsMap( - newStream func(protocol.StreamNum) item, - queueControlFrame func(wire.Frame), -) *outgoingItemsMap { - return &outgoingItemsMap{ - streams: make(map[protocol.StreamNum]item), - openQueue: make(map[uint64]chan struct{}), - maxStream: protocol.InvalidStreamNum, - nextStream: 1, - newStream: newStream, - queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, - } -} - -func (m *outgoingItemsMap) OpenStream() (item, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - // if there are OpenStreamSync calls waiting, return an error here - if len(m.openQueue) > 0 || m.nextStream > m.maxStream { - m.maybeSendBlockedFrame() - return nil, streamOpenErr{errTooManyOpenStreams} - } - return m.openStream(), nil -} - -func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - if err := ctx.Err(); err != nil { - return nil, err - } - - if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { - return m.openStream(), nil - } - - waitChan := make(chan struct{}, 1) - queuePos := m.highestInQueue - m.highestInQueue++ - if len(m.openQueue) == 0 { - m.lowestInQueue = queuePos - } - m.openQueue[queuePos] = waitChan - m.maybeSendBlockedFrame() - - for { - m.mutex.Unlock() - select { - case <-ctx.Done(): - m.mutex.Lock() - delete(m.openQueue, queuePos) - return nil, ctx.Err() - case <-waitChan: - } - m.mutex.Lock() - - if m.closeErr != nil { - return nil, m.closeErr - } - if m.nextStream > m.maxStream { - // no stream available. Continue waiting - continue - } - str := m.openStream() - delete(m.openQueue, queuePos) - m.lowestInQueue = queuePos + 1 - m.unblockOpenSync() - return str, nil - } -} - -func (m *outgoingItemsMap) openStream() item { - s := m.newStream(m.nextStream) - m.streams[m.nextStream] = s - m.nextStream++ - return s -} - -// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, -// if we haven't sent one for this offset yet -func (m *outgoingItemsMap) maybeSendBlockedFrame() { - if m.blockedSent { - return - } - - var streamNum protocol.StreamNum - if m.maxStream != protocol.InvalidStreamNum { - streamNum = m.maxStream - } - m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ - Type: streamTypeGeneric, - StreamLimit: streamNum, - }) - m.blockedSent = true -} - -func (m *outgoingItemsMap) GetStream(num protocol.StreamNum) (item, error) { - m.mutex.RLock() - if num >= m.nextStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer attempted to open stream %d", - nums: []protocol.StreamNum{num}, - } - } - s := m.streams[num] - m.mutex.RUnlock() - return s, nil -} - -func (m *outgoingItemsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown outgoing stream %d", - nums: []protocol.StreamNum{num}, - } - } - delete(m.streams, num) - return nil -} - -func (m *outgoingItemsMap) SetMaxStream(num protocol.StreamNum) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if num <= m.maxStream { - return - } - m.maxStream = num - m.blockedSent = false - if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { - m.maybeSendBlockedFrame() - } - m.unblockOpenSync() -} - -// UpdateSendWindow is called when the peer's transport parameters are received. -// Only in the case of a 0-RTT handshake will we have open streams at this point. -// We might need to update the send window, in case the server increased it. -func (m *outgoingItemsMap) UpdateSendWindow(limit protocol.ByteCount) { - m.mutex.Lock() - for _, str := range m.streams { - str.updateSendWindow(limit) - } - m.mutex.Unlock() -} - -// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream -func (m *outgoingItemsMap) unblockOpenSync() { - if len(m.openQueue) == 0 { - return - } - for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { - c, ok := m.openQueue[qp] - if !ok { // entry was deleted because the context was canceled - continue - } - // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. - // It's sufficient to only unblock OpenStreamSync once. - select { - case c <- struct{}{}: - default: - } - return - } -} - -func (m *outgoingItemsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, str := range m.streams { - str.closeForShutdown(err) - } - for _, c := range m.openQueue { - if c != nil { - close(c) - } - } - m.mutex.Unlock() -} diff --git a/internal/quic-go/streams_map_outgoing_generic_test.go b/internal/quic-go/streams_map_outgoing_generic_test.go deleted file mode 100644 index dc2f13a1..00000000 --- a/internal/quic-go/streams_map_outgoing_generic_test.go +++ /dev/null @@ -1,539 +0,0 @@ -package quic - -import ( - "context" - "errors" - "fmt" - "math/rand" - "sort" - "sync" - "time" - - "github.com/golang/mock/gomock" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Streams Map (outgoing)", func() { - var ( - m *outgoingItemsMap - newItem func(num protocol.StreamNum) item - mockSender *MockStreamSender - ) - - // waitForEnqueued waits until there are n go routines waiting on OpenStreamSync() - waitForEnqueued := func(n int) { - Eventually(func() int { - m.mutex.Lock() - defer m.mutex.Unlock() - return len(m.openQueue) - }, 50*time.Millisecond, 100*time.Microsecond).Should(Equal(n)) - } - - BeforeEach(func() { - newItem = func(num protocol.StreamNum) item { - return &mockGenericStream{num: num} - } - mockSender = NewMockStreamSender(mockCtrl) - m = newOutgoingItemsMap(newItem, mockSender.queueControlFrame) - }) - - Context("no stream ID limit", func() { - BeforeEach(func() { - m.SetMaxStream(0xffffffff) - }) - - It("opens streams", func() { - str, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - str, err = m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) - }) - - It("doesn't open streams after it has been closed", func() { - testErr := errors.New("close") - m.CloseWithError(testErr) - _, err := m.OpenStream() - Expect(err).To(MatchError(testErr)) - }) - - It("gets streams", func() { - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - str, err := m.GetStream(1) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - }) - - It("errors when trying to get a stream that has not yet been opened", func() { - _, err := m.GetStream(1) - Expect(err).To(HaveOccurred()) - Expect(err.(streamError).TestError()).To(MatchError("peer attempted to open stream 1")) - }) - - It("deletes streams", func() { - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(m.DeleteStream(1)).To(Succeed()) - Expect(err).ToNot(HaveOccurred()) - str, err := m.GetStream(1) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeNil()) - }) - - It("errors when deleting a non-existing stream", func() { - err := m.DeleteStream(1337) - Expect(err).To(HaveOccurred()) - Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown outgoing stream 1337")) - }) - - It("errors when deleting a stream twice", func() { - _, err := m.OpenStream() // opens firstNewStream - Expect(err).ToNot(HaveOccurred()) - Expect(m.DeleteStream(1)).To(Succeed()) - err = m.DeleteStream(1) - Expect(err).To(HaveOccurred()) - Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown outgoing stream 1")) - }) - - It("closes all streams when CloseWithError is called", func() { - str1, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - str2, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - testErr := errors.New("test err") - m.CloseWithError(testErr) - Expect(str1.(*mockGenericStream).closed).To(BeTrue()) - Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr)) - Expect(str2.(*mockGenericStream).closed).To(BeTrue()) - Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr)) - }) - - It("updates the send window", func() { - str1, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - str2, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - m.UpdateSendWindow(1337) - Expect(str1.(*mockGenericStream).sendWindow).To(BeEquivalentTo(1337)) - Expect(str2.(*mockGenericStream).sendWindow).To(BeEquivalentTo(1337)) - }) - }) - - Context("with stream ID limits", func() { - It("errors when no stream can be opened immediately", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - _, err := m.OpenStream() - expectTooManyStreamsError(err) - }) - - It("returns immediately when called with a canceled context", func() { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - _, err := m.OpenStreamSync(ctx) - Expect(err).To(MatchError("context canceled")) - }) - - It("blocks until a stream can be opened synchronously", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - str, err := m.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - close(done) - }() - waitForEnqueued(1) - - m.SetMaxStream(1) - Eventually(done).Should(BeClosed()) - }) - - It("unblocks when the context is canceled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := m.OpenStreamSync(ctx) - Expect(err).To(MatchError("context canceled")) - close(done) - }() - waitForEnqueued(1) - - cancel() - Eventually(done).Should(BeClosed()) - - // make sure that the next stream opened is stream 1 - m.SetMaxStream(1000) - str, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - }) - - It("opens streams in the right order", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() - done1 := make(chan struct{}) - go func() { - defer GinkgoRecover() - str, err := m.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - close(done1) - }() - waitForEnqueued(1) - - done2 := make(chan struct{}) - go func() { - defer GinkgoRecover() - str, err := m.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) - close(done2) - }() - waitForEnqueued(2) - - m.SetMaxStream(1) - Eventually(done1).Should(BeClosed()) - Consistently(done2).ShouldNot(BeClosed()) - m.SetMaxStream(2) - Eventually(done2).Should(BeClosed()) - }) - - It("opens streams in the right order, when one of the contexts is canceled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() - done1 := make(chan struct{}) - go func() { - defer GinkgoRecover() - str, err := m.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - close(done1) - }() - waitForEnqueued(1) - - done2 := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) - go func() { - defer GinkgoRecover() - _, err := m.OpenStreamSync(ctx) - Expect(err).To(MatchError(context.Canceled)) - close(done2) - }() - waitForEnqueued(2) - - done3 := make(chan struct{}) - go func() { - defer GinkgoRecover() - str, err := m.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) - close(done3) - }() - waitForEnqueued(3) - - cancel() - Eventually(done2).Should(BeClosed()) - m.SetMaxStream(1000) - Eventually(done1).Should(BeClosed()) - Eventually(done3).Should(BeClosed()) - }) - - It("unblocks multiple OpenStreamSync calls at the same time", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := m.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - done <- struct{}{} - }() - go func() { - defer GinkgoRecover() - _, err := m.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - done <- struct{}{} - }() - waitForEnqueued(2) - go func() { - defer GinkgoRecover() - _, err := m.OpenStreamSync(context.Background()) - Expect(err).To(MatchError("test done")) - done <- struct{}{} - }() - waitForEnqueued(3) - - m.SetMaxStream(2) - Eventually(done).Should(Receive()) - Eventually(done).Should(Receive()) - Consistently(done).ShouldNot(Receive()) - - m.CloseWithError(errors.New("test done")) - Eventually(done).Should(Receive()) - }) - - It("returns an error for OpenStream while an OpenStreamSync call is blocking", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).MaxTimes(2) - openedSync := make(chan struct{}) - go func() { - defer GinkgoRecover() - str, err := m.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - close(openedSync) - }() - waitForEnqueued(1) - - start := make(chan struct{}) - openend := make(chan struct{}) - go func() { - defer GinkgoRecover() - var hasStarted bool - for { - str, err := m.OpenStream() - if err == nil { - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) - close(openend) - return - } - expectTooManyStreamsError(err) - if !hasStarted { - close(start) - hasStarted = true - } - } - }() - - Eventually(start).Should(BeClosed()) - m.SetMaxStream(1) - Eventually(openedSync).Should(BeClosed()) - Consistently(openend).ShouldNot(BeClosed()) - m.SetMaxStream(2) - Eventually(openend).Should(BeClosed()) - }) - - It("stops opening synchronously when it is closed", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - testErr := errors.New("test error") - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := m.OpenStreamSync(context.Background()) - Expect(err).To(MatchError(testErr)) - close(done) - }() - - Consistently(done).ShouldNot(BeClosed()) - m.CloseWithError(testErr) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't reduce the stream limit", func() { - m.SetMaxStream(2) - m.SetMaxStream(1) - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - str, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) - }) - - It("queues a STREAMS_BLOCKED frame if no stream can be opened", func() { - m.SetMaxStream(6) - // open the 6 allowed streams - for i := 0; i < 6; i++ { - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - } - - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(6)) - }) - _, err := m.OpenStream() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal(errTooManyOpenStreams.Error())) - }) - - It("only sends one STREAMS_BLOCKED frame for one stream ID", func() { - m.SetMaxStream(1) - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1)) - }) - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - // try to open a stream twice, but expect only one STREAMS_BLOCKED to be sent - _, err = m.OpenStream() - expectTooManyStreamsError(err) - _, err = m.OpenStream() - expectTooManyStreamsError(err) - }) - - It("queues a STREAMS_BLOCKED frame when there more streams waiting for OpenStreamSync than MAX_STREAMS allows", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(0)) - }) - done := make(chan struct{}, 2) - go func() { - defer GinkgoRecover() - _, err := m.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - done <- struct{}{} - }() - go func() { - defer GinkgoRecover() - _, err := m.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - done <- struct{}{} - }() - waitForEnqueued(2) - - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1)) - }) - m.SetMaxStream(1) - Eventually(done).Should(Receive()) - Consistently(done).ShouldNot(Receive()) - m.SetMaxStream(2) - Eventually(done).Should(Receive()) - }) - }) - - Context("randomized tests", func() { - It("opens streams", func() { - rand.Seed(GinkgoRandomSeed()) - const n = 100 - fmt.Fprintf(GinkgoWriter, "Opening %d streams concurrently.\n", n) - - var blockedAt []protocol.StreamNum - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit) - }).AnyTimes() - done := make(map[int]chan struct{}) - for i := 1; i <= n; i++ { - c := make(chan struct{}) - done[i] = c - - go func(doneChan chan struct{}, id protocol.StreamNum) { - defer GinkgoRecover() - defer close(doneChan) - str, err := m.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(id)) - }(c, protocol.StreamNum(i)) - waitForEnqueued(i) - } - - var limit int - limits := []protocol.StreamNum{0} - for limit < n { - limit += rand.Intn(n/5) + 1 - if limit <= n { - limits = append(limits, protocol.StreamNum(limit)) - } - fmt.Fprintf(GinkgoWriter, "Setting stream limit to %d.\n", limit) - m.SetMaxStream(protocol.StreamNum(limit)) - for i := 1; i <= n; i++ { - if i <= limit { - Eventually(done[i]).Should(BeClosed()) - } else { - Expect(done[i]).ToNot(BeClosed()) - } - } - str, err := m.OpenStream() - if limit <= n { - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal(errTooManyOpenStreams.Error())) - } else { - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(n + 1))) - } - } - Expect(blockedAt).To(Equal(limits)) - }) - - It("opens streams, when some of them are getting canceled", func() { - rand.Seed(GinkgoRandomSeed()) - const n = 100 - fmt.Fprintf(GinkgoWriter, "Opening %d streams concurrently.\n", n) - - var blockedAt []protocol.StreamNum - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit) - }).AnyTimes() - - ctx, cancel := context.WithCancel(context.Background()) - streamsToCancel := make(map[protocol.StreamNum]struct{}) // used as a set - for i := 0; i < 10; i++ { - id := protocol.StreamNum(rand.Intn(n) + 1) - fmt.Fprintf(GinkgoWriter, "Canceling stream %d.\n", id) - streamsToCancel[id] = struct{}{} - } - - streamWillBeCanceled := func(id protocol.StreamNum) bool { - _, ok := streamsToCancel[id] - return ok - } - - var streamIDs []int - var mutex sync.Mutex - done := make(map[int]chan struct{}) - for i := 1; i <= n; i++ { - c := make(chan struct{}) - done[i] = c - - go func(doneChan chan struct{}, id protocol.StreamNum) { - defer GinkgoRecover() - defer close(doneChan) - cont := context.Background() - if streamWillBeCanceled(id) { - cont = ctx - } - str, err := m.OpenStreamSync(cont) - if streamWillBeCanceled(id) { - Expect(err).To(MatchError(context.Canceled)) - return - } - Expect(err).ToNot(HaveOccurred()) - mutex.Lock() - streamIDs = append(streamIDs, int(str.(*mockGenericStream).num)) - mutex.Unlock() - }(c, protocol.StreamNum(i)) - waitForEnqueued(i) - } - - cancel() - for id := range streamsToCancel { - Eventually(done[int(id)]).Should(BeClosed()) - } - var limit int - numStreams := n - len(streamsToCancel) - var limits []protocol.StreamNum - for limit < numStreams { - limits = append(limits, protocol.StreamNum(limit)) - limit += rand.Intn(n/5) + 1 - fmt.Fprintf(GinkgoWriter, "Setting stream limit to %d.\n", limit) - m.SetMaxStream(protocol.StreamNum(limit)) - l := limit - if l > numStreams { - l = numStreams - } - Eventually(func() int { - mutex.Lock() - defer mutex.Unlock() - return len(streamIDs) - }).Should(Equal(l)) - // check that all stream IDs were used - Expect(streamIDs).To(HaveLen(l)) - sort.Ints(streamIDs) - for i := 0; i < l; i++ { - Expect(streamIDs[i]).To(Equal(i + 1)) - } - } - Expect(blockedAt).To(Equal(limits)) - }) - }) -}) diff --git a/internal/quic-go/streams_map_outgoing_uni.go b/internal/quic-go/streams_map_outgoing_uni.go deleted file mode 100644 index 22261fb4..00000000 --- a/internal/quic-go/streams_map_outgoing_uni.go +++ /dev/null @@ -1,226 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package quic - -import ( - "context" - "sync" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type outgoingUniStreamsMap struct { - mutex sync.RWMutex - - streams map[protocol.StreamNum]sendStreamI - - openQueue map[uint64]chan struct{} - lowestInQueue uint64 - highestInQueue uint64 - - nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) - maxStream protocol.StreamNum // the maximum stream ID we're allowed to open - blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - - newStream func(protocol.StreamNum) sendStreamI - queueStreamIDBlocked func(*wire.StreamsBlockedFrame) - - closeErr error -} - -func newOutgoingUniStreamsMap( - newStream func(protocol.StreamNum) sendStreamI, - queueControlFrame func(wire.Frame), -) *outgoingUniStreamsMap { - return &outgoingUniStreamsMap{ - streams: make(map[protocol.StreamNum]sendStreamI), - openQueue: make(map[uint64]chan struct{}), - maxStream: protocol.InvalidStreamNum, - nextStream: 1, - newStream: newStream, - queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, - } -} - -func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - // if there are OpenStreamSync calls waiting, return an error here - if len(m.openQueue) > 0 || m.nextStream > m.maxStream { - m.maybeSendBlockedFrame() - return nil, streamOpenErr{errTooManyOpenStreams} - } - return m.openStream(), nil -} - -func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - if err := ctx.Err(); err != nil { - return nil, err - } - - if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { - return m.openStream(), nil - } - - waitChan := make(chan struct{}, 1) - queuePos := m.highestInQueue - m.highestInQueue++ - if len(m.openQueue) == 0 { - m.lowestInQueue = queuePos - } - m.openQueue[queuePos] = waitChan - m.maybeSendBlockedFrame() - - for { - m.mutex.Unlock() - select { - case <-ctx.Done(): - m.mutex.Lock() - delete(m.openQueue, queuePos) - return nil, ctx.Err() - case <-waitChan: - } - m.mutex.Lock() - - if m.closeErr != nil { - return nil, m.closeErr - } - if m.nextStream > m.maxStream { - // no stream available. Continue waiting - continue - } - str := m.openStream() - delete(m.openQueue, queuePos) - m.lowestInQueue = queuePos + 1 - m.unblockOpenSync() - return str, nil - } -} - -func (m *outgoingUniStreamsMap) openStream() sendStreamI { - s := m.newStream(m.nextStream) - m.streams[m.nextStream] = s - m.nextStream++ - return s -} - -// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, -// if we haven't sent one for this offset yet -func (m *outgoingUniStreamsMap) maybeSendBlockedFrame() { - if m.blockedSent { - return - } - - var streamNum protocol.StreamNum - if m.maxStream != protocol.InvalidStreamNum { - streamNum = m.maxStream - } - m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ - Type: protocol.StreamTypeUni, - StreamLimit: streamNum, - }) - m.blockedSent = true -} - -func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI, error) { - m.mutex.RLock() - if num >= m.nextStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer attempted to open stream %d", - nums: []protocol.StreamNum{num}, - } - } - s := m.streams[num] - m.mutex.RUnlock() - return s, nil -} - -func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown outgoing stream %d", - nums: []protocol.StreamNum{num}, - } - } - delete(m.streams, num) - return nil -} - -func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if num <= m.maxStream { - return - } - m.maxStream = num - m.blockedSent = false - if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { - m.maybeSendBlockedFrame() - } - m.unblockOpenSync() -} - -// UpdateSendWindow is called when the peer's transport parameters are received. -// Only in the case of a 0-RTT handshake will we have open streams at this point. -// We might need to update the send window, in case the server increased it. -func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { - m.mutex.Lock() - for _, str := range m.streams { - str.updateSendWindow(limit) - } - m.mutex.Unlock() -} - -// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream -func (m *outgoingUniStreamsMap) unblockOpenSync() { - if len(m.openQueue) == 0 { - return - } - for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { - c, ok := m.openQueue[qp] - if !ok { // entry was deleted because the context was canceled - continue - } - // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. - // It's sufficient to only unblock OpenStreamSync once. - select { - case c <- struct{}{}: - default: - } - return - } -} - -func (m *outgoingUniStreamsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, str := range m.streams { - str.closeForShutdown(err) - } - for _, c := range m.openQueue { - if c != nil { - close(c) - } - } - m.mutex.Unlock() -} diff --git a/internal/quic-go/streams_map_test.go b/internal/quic-go/streams_map_test.go deleted file mode 100644 index c3e9fd4b..00000000 --- a/internal/quic-go/streams_map_test.go +++ /dev/null @@ -1,499 +0,0 @@ -package quic - -import ( - "context" - "errors" - "fmt" - "net" - - "github.com/golang/mock/gomock" - - "github.com/imroc/req/v3/internal/quic-go/flowcontrol" - "github.com/imroc/req/v3/internal/quic-go/mocks" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func (e streamError) TestError() error { - nums := make([]interface{}, len(e.nums)) - for i, num := range e.nums { - nums[i] = num - } - return fmt.Errorf(e.message, nums...) -} - -type streamMapping struct { - firstIncomingBidiStream protocol.StreamID - firstIncomingUniStream protocol.StreamID - firstOutgoingBidiStream protocol.StreamID - firstOutgoingUniStream protocol.StreamID -} - -func expectTooManyStreamsError(err error) { - ExpectWithOffset(1, err).To(HaveOccurred()) - ExpectWithOffset(1, err.Error()).To(Equal(errTooManyOpenStreams.Error())) - nerr, ok := err.(net.Error) - ExpectWithOffset(1, ok).To(BeTrue()) - ExpectWithOffset(1, nerr.Timeout()).To(BeFalse()) -} - -var _ = Describe("Streams Map", func() { - newFlowController := func(protocol.StreamID) flowcontrol.StreamFlowController { - return mocks.NewMockStreamFlowController(mockCtrl) - } - - serverStreamMapping := streamMapping{ - firstIncomingBidiStream: 0, - firstOutgoingBidiStream: 1, - firstIncomingUniStream: 2, - firstOutgoingUniStream: 3, - } - clientStreamMapping := streamMapping{ - firstIncomingBidiStream: 1, - firstOutgoingBidiStream: 0, - firstIncomingUniStream: 3, - firstOutgoingUniStream: 2, - } - - for _, p := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} { - perspective := p - var ids streamMapping - if perspective == protocol.PerspectiveClient { - ids = clientStreamMapping - } else { - ids = serverStreamMapping - } - - Context(perspective.String(), func() { - var ( - m *streamsMap - mockSender *MockStreamSender - ) - - const ( - MaxBidiStreamNum = 111 - MaxUniStreamNum = 222 - ) - - allowUnlimitedStreams := func() { - m.UpdateLimits(&wire.TransportParameters{ - MaxBidiStreamNum: protocol.MaxStreamCount, - MaxUniStreamNum: protocol.MaxStreamCount, - }) - } - - BeforeEach(func() { - mockSender = NewMockStreamSender(mockCtrl) - m = newStreamsMap(mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective, protocol.VersionWhatever).(*streamsMap) - }) - - Context("opening", func() { - It("opens bidirectional streams", func() { - allowUnlimitedStreams() - str, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeAssignableToTypeOf(&stream{})) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) - str, err = m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeAssignableToTypeOf(&stream{})) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + 4)) - }) - - It("opens unidirectional streams", func() { - allowUnlimitedStreams() - str, err := m.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeAssignableToTypeOf(&sendStream{})) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) - str, err = m.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeAssignableToTypeOf(&sendStream{})) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + 4)) - }) - }) - - Context("accepting", func() { - It("accepts bidirectional streams", func() { - _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) - Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeAssignableToTypeOf(&stream{})) - Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream)) - }) - - It("accepts unidirectional streams", func() { - _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) - Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeAssignableToTypeOf(&receiveStream{})) - Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream)) - }) - }) - - Context("deleting", func() { - BeforeEach(func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() - allowUnlimitedStreams() - }) - - It("deletes outgoing bidirectional streams", func() { - id := ids.firstOutgoingBidiStream - str, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(id)) - Expect(m.DeleteStream(id)).To(Succeed()) - dstr, err := m.GetOrOpenSendStream(id) - Expect(err).ToNot(HaveOccurred()) - Expect(dstr).To(BeNil()) - }) - - It("deletes incoming bidirectional streams", func() { - id := ids.firstIncomingBidiStream - str, err := m.GetOrOpenReceiveStream(id) - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(id)) - Expect(m.DeleteStream(id)).To(Succeed()) - dstr, err := m.GetOrOpenReceiveStream(id) - Expect(err).ToNot(HaveOccurred()) - Expect(dstr).To(BeNil()) - }) - - It("accepts bidirectional streams after they have been deleted", func() { - id := ids.firstIncomingBidiStream - _, err := m.GetOrOpenReceiveStream(id) - Expect(err).ToNot(HaveOccurred()) - Expect(m.DeleteStream(id)).To(Succeed()) - str, err := m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - Expect(str.StreamID()).To(Equal(id)) - }) - - It("deletes outgoing unidirectional streams", func() { - id := ids.firstOutgoingUniStream - str, err := m.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(id)) - Expect(m.DeleteStream(id)).To(Succeed()) - dstr, err := m.GetOrOpenSendStream(id) - Expect(err).ToNot(HaveOccurred()) - Expect(dstr).To(BeNil()) - }) - - It("deletes incoming unidirectional streams", func() { - id := ids.firstIncomingUniStream - str, err := m.GetOrOpenReceiveStream(id) - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(id)) - Expect(m.DeleteStream(id)).To(Succeed()) - dstr, err := m.GetOrOpenReceiveStream(id) - Expect(err).ToNot(HaveOccurred()) - Expect(dstr).To(BeNil()) - }) - - It("accepts unirectional streams after they have been deleted", func() { - id := ids.firstIncomingUniStream - _, err := m.GetOrOpenReceiveStream(id) - Expect(err).ToNot(HaveOccurred()) - Expect(m.DeleteStream(id)).To(Succeed()) - str, err := m.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - Expect(str.StreamID()).To(Equal(id)) - }) - - It("errors when deleting unknown incoming unidirectional streams", func() { - id := ids.firstIncomingUniStream + 4 - Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id))) - }) - - It("errors when deleting unknown outgoing unidirectional streams", func() { - id := ids.firstOutgoingUniStream + 4 - Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id))) - }) - - It("errors when deleting unknown incoming bidirectional streams", func() { - id := ids.firstIncomingBidiStream + 4 - Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id))) - }) - - It("errors when deleting unknown outgoing bidirectional streams", func() { - id := ids.firstOutgoingBidiStream + 4 - Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id))) - }) - }) - - Context("getting streams", func() { - BeforeEach(func() { - allowUnlimitedStreams() - }) - - Context("send streams", func() { - It("gets an outgoing bidirectional stream", func() { - // need to open the stream ourselves first - // the peer is not allowed to create a stream initiated by us - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - str, err := m.GetOrOpenSendStream(ids.firstOutgoingBidiStream) - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) - }) - - It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { - id := ids.firstOutgoingBidiStream + 5*4 - _, err := m.GetOrOpenSendStream(id) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.StreamStateError, - ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), - })) - }) - - It("gets an outgoing unidirectional stream", func() { - // need to open the stream ourselves first - // the peer is not allowed to create a stream initiated by us - _, err := m.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - str, err := m.GetOrOpenSendStream(ids.firstOutgoingUniStream) - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) - }) - - It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { - id := ids.firstOutgoingUniStream + 5*4 - _, err := m.GetOrOpenSendStream(id) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.StreamStateError, - ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), - })) - }) - - It("gets an incoming bidirectional stream", func() { - id := ids.firstIncomingBidiStream + 4*7 - str, err := m.GetOrOpenSendStream(id) - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(id)) - }) - - It("errors when trying to get an incoming unidirectional stream", func() { - id := ids.firstIncomingUniStream - _, err := m.GetOrOpenSendStream(id) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.StreamStateError, - ErrorMessage: fmt.Sprintf("peer attempted to open send stream %d", id), - })) - }) - }) - - Context("receive streams", func() { - It("gets an outgoing bidirectional stream", func() { - // need to open the stream ourselves first - // the peer is not allowed to create a stream initiated by us - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - str, err := m.GetOrOpenReceiveStream(ids.firstOutgoingBidiStream) - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) - }) - - It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { - id := ids.firstOutgoingBidiStream + 5*4 - _, err := m.GetOrOpenReceiveStream(id) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.StreamStateError, - ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), - })) - }) - - It("gets an incoming bidirectional stream", func() { - id := ids.firstIncomingBidiStream + 4*7 - str, err := m.GetOrOpenReceiveStream(id) - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(id)) - }) - - It("gets an incoming unidirectional stream", func() { - id := ids.firstIncomingUniStream + 4*10 - str, err := m.GetOrOpenReceiveStream(id) - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(id)) - }) - - It("errors when trying to get an outgoing unidirectional stream", func() { - id := ids.firstOutgoingUniStream - _, err := m.GetOrOpenReceiveStream(id) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.StreamStateError, - ErrorMessage: fmt.Sprintf("peer attempted to open receive stream %d", id), - })) - }) - }) - }) - - It("processes the parameter for outgoing streams", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - _, err := m.OpenStream() - expectTooManyStreamsError(err) - m.UpdateLimits(&wire.TransportParameters{ - MaxBidiStreamNum: 5, - MaxUniStreamNum: 8, - }) - - mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2) - // test we can only 5 bidirectional streams - for i := 0; i < 5; i++ { - str, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + protocol.StreamID(4*i))) - } - _, err = m.OpenStream() - expectTooManyStreamsError(err) - // test we can only 8 unidirectional streams - for i := 0; i < 8; i++ { - str, err := m.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + protocol.StreamID(4*i))) - } - _, err = m.OpenUniStream() - expectTooManyStreamsError(err) - }) - - if perspective == protocol.PerspectiveClient { - It("applies parameters to existing streams (needed for 0-RTT)", func() { - m.UpdateLimits(&wire.TransportParameters{ - MaxBidiStreamNum: 1000, - MaxUniStreamNum: 1000, - }) - flowControllers := make(map[protocol.StreamID]*mocks.MockStreamFlowController) - m.newFlowController = func(id protocol.StreamID) flowcontrol.StreamFlowController { - fc := mocks.NewMockStreamFlowController(mockCtrl) - flowControllers[id] = fc - return fc - } - - str, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - unistr, err := m.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - - Expect(flowControllers).To(HaveKey(str.StreamID())) - flowControllers[str.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(4321)) - Expect(flowControllers).To(HaveKey(unistr.StreamID())) - flowControllers[unistr.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(1234)) - - m.UpdateLimits(&wire.TransportParameters{ - MaxBidiStreamNum: 1000, - InitialMaxStreamDataUni: 1234, - MaxUniStreamNum: 1000, - InitialMaxStreamDataBidiRemote: 4321, - }) - }) - } - - Context("handling MAX_STREAMS frames", func() { - BeforeEach(func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() - }) - - It("processes IDs for outgoing bidirectional streams", func() { - _, err := m.OpenStream() - expectTooManyStreamsError(err) - m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: 1, - }) - str, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) - _, err = m.OpenStream() - expectTooManyStreamsError(err) - }) - - It("processes IDs for outgoing unidirectional streams", func() { - _, err := m.OpenUniStream() - expectTooManyStreamsError(err) - m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeUni, - MaxStreamNum: 1, - }) - str, err := m.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) - _, err = m.OpenUniStream() - expectTooManyStreamsError(err) - }) - }) - - Context("sending MAX_STREAMS frames", func() { - It("sends a MAX_STREAMS frame for bidirectional streams", func() { - _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) - Expect(err).ToNot(HaveOccurred()) - _, err = m.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: MaxBidiStreamNum + 1, - }) - Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed()) - }) - - It("sends a MAX_STREAMS frame for unidirectional streams", func() { - _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) - Expect(err).ToNot(HaveOccurred()) - _, err = m.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeUni, - MaxStreamNum: MaxUniStreamNum + 1, - }) - Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed()) - }) - }) - - It("closes", func() { - testErr := errors.New("test error") - m.CloseWithError(testErr) - _, err := m.OpenStream() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal(testErr.Error())) - _, err = m.OpenUniStream() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal(testErr.Error())) - _, err = m.AcceptStream(context.Background()) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal(testErr.Error())) - _, err = m.AcceptUniStream(context.Background()) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal(testErr.Error())) - }) - - if perspective == protocol.PerspectiveClient { - It("resets for 0-RTT", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() - m.ResetFor0RTT() - // make sure that calls to open / accept streams fail - _, err := m.OpenStream() - Expect(err).To(MatchError(Err0RTTRejected)) - _, err = m.AcceptStream(context.Background()) - Expect(err).To(MatchError(Err0RTTRejected)) - // make sure that we can still get new streams, as the server might be sending us data - str, err := m.GetOrOpenReceiveStream(3) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - - // now switch to using the new streams map - m.UseResetMaps() - _, err = m.OpenStream() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("too many open streams")) - }) - } - }) - } -}) diff --git a/internal/quic-go/sys_conn.go b/internal/quic-go/sys_conn.go deleted file mode 100644 index 315c26cc..00000000 --- a/internal/quic-go/sys_conn.go +++ /dev/null @@ -1,80 +0,0 @@ -package quic - -import ( - "net" - "syscall" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -// OOBCapablePacketConn is a connection that allows the reading of ECN bits from the IP header. -// If the PacketConn passed to Dial or Listen satisfies this interface, quic-go will use it. -// In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets. -type OOBCapablePacketConn interface { - net.PacketConn - SyscallConn() (syscall.RawConn, error) - ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) - WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) -} - -var _ OOBCapablePacketConn = &net.UDPConn{} - -func wrapConn(pc net.PacketConn) (rawConn, error) { - conn, ok := pc.(interface { - SyscallConn() (syscall.RawConn, error) - }) - if ok { - rawConn, err := conn.SyscallConn() - if err != nil { - return nil, err - } - - if _, ok := pc.LocalAddr().(*net.UDPAddr); ok { - // Only set DF on sockets that we expect to be able to handle that configuration. - err = setDF(rawConn) - if err != nil { - return nil, err - } - } - } - c, ok := pc.(OOBCapablePacketConn) - if !ok { - utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.") - return &basicConn{PacketConn: pc}, nil - } - return newConn(c) -} - -// The basicConn is the most trivial implementation of a connection. -// It reads a single packet from the underlying net.PacketConn. -// It is used when -// * the net.PacketConn is not a OOBCapablePacketConn, and -// * when the OS doesn't support OOB. -type basicConn struct { - net.PacketConn -} - -var _ rawConn = &basicConn{} - -func (c *basicConn) ReadPacket() (*receivedPacket, error) { - buffer := getPacketBuffer() - // The packet size should not exceed protocol.MaxPacketBufferSize bytes - // If it does, we only read a truncated packet, which will then end up undecryptable - buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] - n, addr, err := c.PacketConn.ReadFrom(buffer.Data) - if err != nil { - return nil, err - } - return &receivedPacket{ - remoteAddr: addr, - rcvTime: time.Now(), - data: buffer.Data[:n], - buffer: buffer, - }, nil -} - -func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { - return c.PacketConn.WriteTo(b, addr) -} diff --git a/internal/quic-go/sys_conn_df.go b/internal/quic-go/sys_conn_df.go deleted file mode 100644 index ae9274d9..00000000 --- a/internal/quic-go/sys_conn_df.go +++ /dev/null @@ -1,16 +0,0 @@ -//go:build !linux && !windows -// +build !linux,!windows - -package quic - -import "syscall" - -func setDF(rawConn syscall.RawConn) error { - // no-op on unsupported platforms - return nil -} - -func isMsgSizeErr(err error) bool { - // to be implemented for more specific platforms - return false -} diff --git a/internal/quic-go/sys_conn_df_linux.go b/internal/quic-go/sys_conn_df_linux.go deleted file mode 100644 index d3345658..00000000 --- a/internal/quic-go/sys_conn_df_linux.go +++ /dev/null @@ -1,40 +0,0 @@ -//go:build linux -// +build linux - -package quic - -import ( - "errors" - "syscall" - - "github.com/imroc/req/v3/internal/quic-go/utils" - "golang.org/x/sys/unix" -) - -func setDF(rawConn syscall.RawConn) error { - // Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long" - // and the datagram will not be fragmented - var errDFIPv4, errDFIPv6 error - if err := rawConn.Control(func(fd uintptr) { - errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO) - errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_DO) - }); err != nil { - return err - } - switch { - case errDFIPv4 == nil && errDFIPv6 == nil: - utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.") - case errDFIPv4 == nil && errDFIPv6 != nil: - utils.DefaultLogger.Debugf("Setting DF for IPv4.") - case errDFIPv4 != nil && errDFIPv6 == nil: - utils.DefaultLogger.Debugf("Setting DF for IPv6.") - case errDFIPv4 != nil && errDFIPv6 != nil: - return errors.New("setting DF failed for both IPv4 and IPv6") - } - return nil -} - -func isMsgSizeErr(err error) bool { - // https://man7.org/linux/man-pages/man7/udp.7.html - return errors.Is(err, unix.EMSGSIZE) -} diff --git a/internal/quic-go/sys_conn_df_windows.go b/internal/quic-go/sys_conn_df_windows.go deleted file mode 100644 index e5e2e5b3..00000000 --- a/internal/quic-go/sys_conn_df_windows.go +++ /dev/null @@ -1,46 +0,0 @@ -//go:build windows -// +build windows - -package quic - -import ( - "errors" - "syscall" - - "github.com/imroc/req/v3/internal/quic-go/utils" - "golang.org/x/sys/windows" -) - -const ( - // same for both IPv4 and IPv6 on Windows - // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAG.html - // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html - IP_DONTFRAGMENT = 14 - IPV6_DONTFRAG = 14 -) - -func setDF(rawConn syscall.RawConn) error { - var errDFIPv4, errDFIPv6 error - if err := rawConn.Control(func(fd uintptr) { - errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1) - errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_DONTFRAG, 1) - }); err != nil { - return err - } - switch { - case errDFIPv4 == nil && errDFIPv6 == nil: - utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.") - case errDFIPv4 == nil && errDFIPv6 != nil: - utils.DefaultLogger.Debugf("Setting DF for IPv4.") - case errDFIPv4 != nil && errDFIPv6 == nil: - utils.DefaultLogger.Debugf("Setting DF for IPv6.") - case errDFIPv4 != nil && errDFIPv6 != nil: - return errors.New("setting DF failed for both IPv4 and IPv6") - } - return nil -} - -func isMsgSizeErr(err error) bool { - // https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2 - return errors.Is(err, windows.WSAEMSGSIZE) -} diff --git a/internal/quic-go/sys_conn_helper_darwin.go b/internal/quic-go/sys_conn_helper_darwin.go deleted file mode 100644 index eabf489f..00000000 --- a/internal/quic-go/sys_conn_helper_darwin.go +++ /dev/null @@ -1,22 +0,0 @@ -//go:build darwin -// +build darwin - -package quic - -import "golang.org/x/sys/unix" - -const msgTypeIPTOS = unix.IP_RECVTOS - -const ( - ipv4RECVPKTINFO = unix.IP_RECVPKTINFO - ipv6RECVPKTINFO = 0x3d -) - -const ( - msgTypeIPv4PKTINFO = unix.IP_PKTINFO - msgTypeIPv6PKTINFO = 0x2e -) - -// ReadBatch only returns a single packet on OSX, -// see https://godoc.org/golang.org/x/net/ipv4#PacketConn.ReadBatch. -const batchSize = 1 diff --git a/internal/quic-go/sys_conn_helper_freebsd.go b/internal/quic-go/sys_conn_helper_freebsd.go deleted file mode 100644 index 0b3e8434..00000000 --- a/internal/quic-go/sys_conn_helper_freebsd.go +++ /dev/null @@ -1,22 +0,0 @@ -//go:build freebsd -// +build freebsd - -package quic - -import "golang.org/x/sys/unix" - -const ( - msgTypeIPTOS = unix.IP_RECVTOS -) - -const ( - ipv4RECVPKTINFO = 0x7 - ipv6RECVPKTINFO = 0x24 -) - -const ( - msgTypeIPv4PKTINFO = 0x7 - msgTypeIPv6PKTINFO = 0x2e -) - -const batchSize = 8 diff --git a/internal/quic-go/sys_conn_helper_linux.go b/internal/quic-go/sys_conn_helper_linux.go deleted file mode 100644 index 51bec900..00000000 --- a/internal/quic-go/sys_conn_helper_linux.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build linux -// +build linux - -package quic - -import "golang.org/x/sys/unix" - -const msgTypeIPTOS = unix.IP_TOS - -const ( - ipv4RECVPKTINFO = unix.IP_PKTINFO - ipv6RECVPKTINFO = unix.IPV6_RECVPKTINFO -) - -const ( - msgTypeIPv4PKTINFO = unix.IP_PKTINFO - msgTypeIPv6PKTINFO = unix.IPV6_PKTINFO -) - -const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed) diff --git a/internal/quic-go/sys_conn_no_oob.go b/internal/quic-go/sys_conn_no_oob.go deleted file mode 100644 index e3b0d11f..00000000 --- a/internal/quic-go/sys_conn_no_oob.go +++ /dev/null @@ -1,16 +0,0 @@ -//go:build !darwin && !linux && !freebsd && !windows -// +build !darwin,!linux,!freebsd,!windows - -package quic - -import "net" - -func newConn(c net.PacketConn) (rawConn, error) { - return &basicConn{PacketConn: c}, nil -} - -func inspectReadBuffer(interface{}) (int, error) { - return 0, nil -} - -func (i *packetInfo) OOB() []byte { return nil } diff --git a/internal/quic-go/sys_conn_oob.go b/internal/quic-go/sys_conn_oob.go deleted file mode 100644 index 38fe9831..00000000 --- a/internal/quic-go/sys_conn_oob.go +++ /dev/null @@ -1,257 +0,0 @@ -//go:build darwin || linux || freebsd -// +build darwin linux freebsd - -package quic - -import ( - "encoding/binary" - "errors" - "fmt" - "net" - "syscall" - "time" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "golang.org/x/sys/unix" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -const ( - ecnMask = 0x3 - oobBufferSize = 128 -) - -// Contrary to what the naming suggests, the ipv{4,6}.Message is not dependent on the IP version. -// They're both just aliases for x/net/internal/socket.Message. -// This means we can use this struct to read from a socket that receives both IPv4 and IPv6 messages. -var _ ipv4.Message = ipv6.Message{} - -type batchConn interface { - ReadBatch(ms []ipv4.Message, flags int) (int, error) -} - -func inspectReadBuffer(c interface{}) (int, error) { - conn, ok := c.(interface { - SyscallConn() (syscall.RawConn, error) - }) - if !ok { - return 0, errors.New("doesn't have a SyscallConn") - } - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err) - } - var size int - var serr error - if err := rawConn.Control(func(fd uintptr) { - size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) - }); err != nil { - return 0, err - } - return size, serr -} - -type oobConn struct { - OOBCapablePacketConn - batchConn batchConn - - readPos uint8 - // Packets received from the kernel, but not yet returned by ReadPacket(). - messages []ipv4.Message - buffers [batchSize]*packetBuffer -} - -var _ rawConn = &oobConn{} - -func newConn(c OOBCapablePacketConn) (*oobConn, error) { - rawConn, err := c.SyscallConn() - if err != nil { - return nil, err - } - needsPacketInfo := false - if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() { - needsPacketInfo = true - } - // We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection. - // Try enabling receiving of ECN and packet info for both IP versions. - // We expect at least one of those syscalls to succeed. - var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error - if err := rawConn.Control(func(fd uintptr) { - errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1) - errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1) - - if needsPacketInfo { - errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4RECVPKTINFO, 1) - errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, ipv6RECVPKTINFO, 1) - } - }); err != nil { - return nil, err - } - switch { - case errECNIPv4 == nil && errECNIPv6 == nil: - utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.") - case errECNIPv4 == nil && errECNIPv6 != nil: - utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.") - case errECNIPv4 != nil && errECNIPv6 == nil: - utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.") - case errECNIPv4 != nil && errECNIPv6 != nil: - return nil, errors.New("activating ECN failed for both IPv4 and IPv6") - } - if needsPacketInfo { - switch { - case errPIIPv4 == nil && errPIIPv6 == nil: - utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.") - case errPIIPv4 == nil && errPIIPv6 != nil: - utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.") - case errPIIPv4 != nil && errPIIPv6 == nil: - utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.") - case errPIIPv4 != nil && errPIIPv6 != nil: - return nil, errors.New("activating packet info failed for both IPv4 and IPv6") - } - } - - // Allows callers to pass in a connection that already satisfies batchConn interface - // to make use of the optimisation. Otherwise, ipv4.NewPacketConn would unwrap the file descriptor - // via SyscallConn(), and read it that way, which might not be what the caller wants. - var bc batchConn - if ibc, ok := c.(batchConn); ok { - bc = ibc - } else { - bc = ipv4.NewPacketConn(c) - } - - oobConn := &oobConn{ - OOBCapablePacketConn: c, - batchConn: bc, - messages: make([]ipv4.Message, batchSize), - readPos: batchSize, - } - for i := 0; i < batchSize; i++ { - oobConn.messages[i].OOB = make([]byte, oobBufferSize) - } - return oobConn, nil -} - -func (c *oobConn) ReadPacket() (*receivedPacket, error) { - if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages. - c.messages = c.messages[:batchSize] - // replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call - for i := uint8(0); i < c.readPos; i++ { - buffer := getPacketBuffer() - buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] - c.buffers[i] = buffer - c.messages[i].Buffers = [][]byte{c.buffers[i].Data} - } - c.readPos = 0 - - n, err := c.batchConn.ReadBatch(c.messages, 0) - if n == 0 || err != nil { - return nil, err - } - c.messages = c.messages[:n] - } - - msg := c.messages[c.readPos] - buffer := c.buffers[c.readPos] - c.readPos++ - ctrlMsgs, err := unix.ParseSocketControlMessage(msg.OOB[:msg.NN]) - if err != nil { - return nil, err - } - var ecn protocol.ECN - var destIP net.IP - var ifIndex uint32 - for _, ctrlMsg := range ctrlMsgs { - if ctrlMsg.Header.Level == unix.IPPROTO_IP { - switch ctrlMsg.Header.Type { - case msgTypeIPTOS: - ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask) - case msgTypeIPv4PKTINFO: - // struct in_pktinfo { - // unsigned int ipi_ifindex; /* Interface index */ - // struct in_addr ipi_spec_dst; /* Local address */ - // struct in_addr ipi_addr; /* Header Destination - // address */ - // }; - ip := make([]byte, 4) - if len(ctrlMsg.Data) == 12 { - ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data) - copy(ip, ctrlMsg.Data[8:12]) - } else if len(ctrlMsg.Data) == 4 { - // FreeBSD - copy(ip, ctrlMsg.Data) - } - destIP = net.IP(ip) - } - } - if ctrlMsg.Header.Level == unix.IPPROTO_IPV6 { - switch ctrlMsg.Header.Type { - case unix.IPV6_TCLASS: - ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask) - case msgTypeIPv6PKTINFO: - // struct in6_pktinfo { - // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ - // unsigned int ipi6_ifindex; /* send/recv interface index */ - // }; - if len(ctrlMsg.Data) == 20 { - ip := make([]byte, 16) - copy(ip, ctrlMsg.Data[:16]) - destIP = net.IP(ip) - ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data[16:]) - } - } - } - } - var info *packetInfo - if destIP != nil { - info = &packetInfo{ - addr: destIP, - ifIndex: ifIndex, - } - } - return &receivedPacket{ - remoteAddr: msg.Addr, - rcvTime: time.Now(), - data: msg.Buffers[0][:msg.N], - ecn: ecn, - info: info, - buffer: buffer, - }, nil -} - -func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (n int, err error) { - n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) - return n, err -} - -func (info *packetInfo) OOB() []byte { - if info == nil { - return nil - } - if ip4 := info.addr.To4(); ip4 != nil { - // struct in_pktinfo { - // unsigned int ipi_ifindex; /* Interface index */ - // struct in_addr ipi_spec_dst; /* Local address */ - // struct in_addr ipi_addr; /* Header Destination address */ - // }; - cm := ipv4.ControlMessage{ - Src: ip4, - IfIndex: int(info.ifIndex), - } - return cm.Marshal() - } else if len(info.addr) == 16 { - // struct in6_pktinfo { - // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ - // unsigned int ipi6_ifindex; /* send/recv interface index */ - // }; - cm := ipv6.ControlMessage{ - Src: info.addr, - IfIndex: int(info.ifIndex), - } - return cm.Marshal() - } - return nil -} diff --git a/internal/quic-go/sys_conn_oob_test.go b/internal/quic-go/sys_conn_oob_test.go deleted file mode 100644 index 82afa1c8..00000000 --- a/internal/quic-go/sys_conn_oob_test.go +++ /dev/null @@ -1,243 +0,0 @@ -//go:build !windows -// +build !windows - -package quic - -import ( - "fmt" - "net" - "time" - - "golang.org/x/net/ipv4" - "golang.org/x/sys/unix" - - "github.com/golang/mock/gomock" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("OOB Conn Test", func() { - runServer := func(network, address string) (*net.UDPConn, <-chan *receivedPacket) { - addr, err := net.ResolveUDPAddr(network, address) - Expect(err).ToNot(HaveOccurred()) - udpConn, err := net.ListenUDP(network, addr) - Expect(err).ToNot(HaveOccurred()) - oobConn, err := newConn(udpConn) - Expect(err).ToNot(HaveOccurred()) - - packetChan := make(chan *receivedPacket) - go func() { - defer GinkgoRecover() - for { - p, err := oobConn.ReadPacket() - if err != nil { - return - } - packetChan <- p - } - }() - - return udpConn, packetChan - } - - Context("ECN conn", func() { - sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr { - conn, err := net.DialUDP(network, nil, addr) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - rawConn, err := conn.SyscallConn() - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - ExpectWithOffset(1, rawConn.Control(func(fd uintptr) { - setECN(fd) - })).To(Succeed()) - _, err = conn.Write([]byte("foobar")) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - return conn.LocalAddr() - } - - It("reads ECN flags on IPv4", func() { - conn, packetChan := runServer("udp4", "localhost:0") - defer conn.Close() - - sentFrom := sendPacketWithECN( - "udp4", - conn.LocalAddr().(*net.UDPAddr), - func(fd uintptr) { - Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_TOS, 2)).To(Succeed()) - }, - ) - - var p *receivedPacket - Eventually(packetChan).Should(Receive(&p)) - Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) - Expect(p.data).To(Equal([]byte("foobar"))) - Expect(p.remoteAddr).To(Equal(sentFrom)) - Expect(p.ecn).To(Equal(protocol.ECT0)) - }) - - It("reads ECN flags on IPv6", func() { - conn, packetChan := runServer("udp6", "[::]:0") - defer conn.Close() - - sentFrom := sendPacketWithECN( - "udp6", - conn.LocalAddr().(*net.UDPAddr), - func(fd uintptr) { - Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 3)).To(Succeed()) - }, - ) - - var p *receivedPacket - Eventually(packetChan).Should(Receive(&p)) - Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) - Expect(p.data).To(Equal([]byte("foobar"))) - Expect(p.remoteAddr).To(Equal(sentFrom)) - Expect(p.ecn).To(Equal(protocol.ECNCE)) - }) - - It("reads ECN flags on a connection that supports both IPv4 and IPv6", func() { - conn, packetChan := runServer("udp", "0.0.0.0:0") - defer conn.Close() - port := conn.LocalAddr().(*net.UDPAddr).Port - - // IPv4 - sendPacketWithECN( - "udp4", - &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port}, - func(fd uintptr) { - Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_TOS, 3)).To(Succeed()) - }, - ) - - var p *receivedPacket - Eventually(packetChan).Should(Receive(&p)) - Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue()) - Expect(p.ecn).To(Equal(protocol.ECNCE)) - - // IPv6 - sendPacketWithECN( - "udp6", - &net.UDPAddr{IP: net.IPv6loopback, Port: port}, - func(fd uintptr) { - Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 1)).To(Succeed()) - }, - ) - - Eventually(packetChan).Should(Receive(&p)) - Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse()) - Expect(p.ecn).To(Equal(protocol.ECT1)) - }) - }) - - Context("Packet Info conn", func() { - sendPacket := func(network string, addr *net.UDPAddr) net.Addr { - conn, err := net.DialUDP(network, nil, addr) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - _, err = conn.Write([]byte("foobar")) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - return conn.LocalAddr() - } - - It("reads packet info on IPv4", func() { - conn, packetChan := runServer("udp4", ":0") - defer conn.Close() - - addr := conn.LocalAddr().(*net.UDPAddr) - ip := net.ParseIP("127.0.0.1").To4() - addr.IP = ip - sentFrom := sendPacket("udp4", addr) - - var p *receivedPacket - Eventually(packetChan).Should(Receive(&p)) - Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) - Expect(p.data).To(Equal([]byte("foobar"))) - Expect(p.remoteAddr).To(Equal(sentFrom)) - Expect(p.info).To(Not(BeNil())) - Expect(p.info.addr.To4()).To(Equal(ip)) - }) - - It("reads packet info on IPv6", func() { - conn, packetChan := runServer("udp6", ":0") - defer conn.Close() - - addr := conn.LocalAddr().(*net.UDPAddr) - ip := net.ParseIP("::1") - addr.IP = ip - sentFrom := sendPacket("udp6", addr) - - var p *receivedPacket - Eventually(packetChan).Should(Receive(&p)) - Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) - Expect(p.data).To(Equal([]byte("foobar"))) - Expect(p.remoteAddr).To(Equal(sentFrom)) - Expect(p.info).To(Not(BeNil())) - Expect(p.info.addr).To(Equal(ip)) - }) - - It("reads packet info on a connection that supports both IPv4 and IPv6", func() { - conn, packetChan := runServer("udp", ":0") - defer conn.Close() - port := conn.LocalAddr().(*net.UDPAddr).Port - - // IPv4 - ip4 := net.ParseIP("127.0.0.1").To4() - sendPacket("udp4", &net.UDPAddr{IP: ip4, Port: port}) - - var p *receivedPacket - Eventually(packetChan).Should(Receive(&p)) - Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue()) - Expect(p.info).To(Not(BeNil())) - Expect(p.info.addr.To4()).To(Equal(ip4)) - - // IPv6 - ip6 := net.ParseIP("::1") - sendPacket("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: port}) - - Eventually(packetChan).Should(Receive(&p)) - Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse()) - Expect(p.info).To(Not(BeNil())) - Expect(p.info.addr).To(Equal(ip6)) - }) - }) - - Context("Batch Reading", func() { - var batchConn *MockBatchConn - - BeforeEach(func() { - batchConn = NewMockBatchConn(mockCtrl) - }) - - It("reads multiple messages in one batch", func() { - const numMsgRead = batchSize/2 + 1 - var counter int - batchConn.EXPECT().ReadBatch(gomock.Any(), gomock.Any()).DoAndReturn(func(ms []ipv4.Message, flags int) (int, error) { - Expect(ms).To(HaveLen(batchSize)) - for i := 0; i < numMsgRead; i++ { - Expect(ms[i].Buffers).To(HaveLen(1)) - Expect(ms[i].Buffers[0]).To(HaveLen(int(protocol.MaxPacketBufferSize))) - data := []byte(fmt.Sprintf("message %d", counter)) - counter++ - ms[i].Buffers[0] = data - ms[i].N = len(data) - } - return numMsgRead, nil - }).Times(2) - - addr, err := net.ResolveUDPAddr("udp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - udpConn, err := net.ListenUDP("udp", addr) - Expect(err).ToNot(HaveOccurred()) - oobConn, err := newConn(udpConn) - Expect(err).ToNot(HaveOccurred()) - oobConn.batchConn = batchConn - - for i := 0; i < batchSize+1; i++ { - p, err := oobConn.ReadPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(string(p.data)).To(Equal(fmt.Sprintf("message %d", i))) - } - }) - }) -}) diff --git a/internal/quic-go/sys_conn_test.go b/internal/quic-go/sys_conn_test.go deleted file mode 100644 index 15df7760..00000000 --- a/internal/quic-go/sys_conn_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package quic - -import ( - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - "github.com/golang/mock/gomock" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Basic Conn Test", func() { - It("reads a packet", func() { - c := NewMockPacketConn(mockCtrl) - addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} - c.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { - data := []byte("foobar") - Expect(b).To(HaveLen(int(protocol.MaxPacketBufferSize))) - return copy(b, data), addr, nil - }) - - conn, err := wrapConn(c) - Expect(err).ToNot(HaveOccurred()) - p, err := conn.ReadPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.data).To(Equal([]byte("foobar"))) - Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(100*time.Millisecond))) - Expect(p.remoteAddr).To(Equal(addr)) - }) -}) diff --git a/internal/quic-go/sys_conn_windows.go b/internal/quic-go/sys_conn_windows.go deleted file mode 100644 index f2cc22ab..00000000 --- a/internal/quic-go/sys_conn_windows.go +++ /dev/null @@ -1,40 +0,0 @@ -//go:build windows -// +build windows - -package quic - -import ( - "errors" - "fmt" - "net" - "syscall" - - "golang.org/x/sys/windows" -) - -func newConn(c OOBCapablePacketConn) (rawConn, error) { - return &basicConn{PacketConn: c}, nil -} - -func inspectReadBuffer(c net.PacketConn) (int, error) { - conn, ok := c.(interface { - SyscallConn() (syscall.RawConn, error) - }) - if !ok { - return 0, errors.New("doesn't have a SyscallConn") - } - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err) - } - var size int - var serr error - if err := rawConn.Control(func(fd uintptr) { - size, serr = windows.GetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF) - }); err != nil { - return 0, err - } - return size, serr -} - -func (i *packetInfo) OOB() []byte { return nil } diff --git a/internal/quic-go/sys_conn_windows_test.go b/internal/quic-go/sys_conn_windows_test.go deleted file mode 100644 index 0dae55a9..00000000 --- a/internal/quic-go/sys_conn_windows_test.go +++ /dev/null @@ -1,33 +0,0 @@ -//go:build windows -// +build windows - -package quic - -import ( - "net" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Windows Conn Test", func() { - It("works on IPv4", func() { - addr, err := net.ResolveUDPAddr("udp4", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - udpConn, err := net.ListenUDP("udp4", addr) - Expect(err).ToNot(HaveOccurred()) - conn, err := newConn(udpConn) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.Close()).To(Succeed()) - }) - - It("works on IPv6", func() { - addr, err := net.ResolveUDPAddr("udp6", "[::1]:0") - Expect(err).ToNot(HaveOccurred()) - udpConn, err := net.ListenUDP("udp6", addr) - Expect(err).ToNot(HaveOccurred()) - conn, err := newConn(udpConn) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.Close()).To(Succeed()) - }) -}) diff --git a/internal/quic-go/testdata/ca.pem b/internal/quic-go/testdata/ca.pem deleted file mode 100644 index 67a5545e..00000000 --- a/internal/quic-go/testdata/ca.pem +++ /dev/null @@ -1,17 +0,0 @@ ------BEGIN CERTIFICATE----- -MIICzDCCAbQCCQDA+rLymNnfJzANBgkqhkiG9w0BAQsFADAoMSYwJAYDVQQKDB1x -dWljLWdvIENlcnRpZmljYXRlIEF1dGhvcml0eTAeFw0yMDA4MTgwOTIxMzVaFw0z -MDA4MTYwOTIxMzVaMCgxJjAkBgNVBAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0 -aG9yaXR5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1OcsYrVaSDfh -iDppl6oteVspOY3yFb96T9Y/biaGPJAkBO9VGKcqwOUPmUeiWpedRAUB9LE7Srs6 -qBX4mnl90Icjp8jbIs5cPgIWLkIu8Qm549RghFzB3bn+EmCQSe4cxvyDMN3ndClp -3YMXpZgXWgJGiPOylVi/OwHDdWDBorw4hvry+6yDtpQo2TuI2A/xtxXPT7BgsEJD -WGffdgZOYXChcFA0c1XVLIYlu2w2JhxS8c2TUF6uSDlmcoONNKVoiNCuu1Z9MorS -Qmg7a2G7dSPu123KcTcSQFcmJrt+1G81gOBtHB69kacD8xDmgksj09h/ODPL/gIU -1ZcU2ci1/QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQB0Tb1JbLXp/BvWovSAhO/j -wG7UEaUA1rCtkDB+fV2HS9bxCbV5eErdg8AMHKgB51ygUrq95vm/baZmUILr84XK -uTEoxxrw5S9Z7SrhtbOpKCumoSeTsCPjDvCcwFExHv4XHFk+CPqZwbMHueVIMT0+ -nGWss/KecCPdJLdnUgMRz0tIuXzkoRuOiUiZfUeyBNVNbDFSrLigYshTeAPGaYjX -CypoHxkeS93nWfOMUu8FTYLYkvGMU5i076zDoFGKJiEtbjSiNW+Hei7u2aSEuCzp -qyTKzYPWYffAq3MM2MKJgZdL04e9GEGeuce/qhM1o3q77aI/XJImwEDdut2LDec1 ------END CERTIFICATE----- diff --git a/internal/quic-go/testdata/cert.go b/internal/quic-go/testdata/cert.go deleted file mode 100644 index f862b0cb..00000000 --- a/internal/quic-go/testdata/cert.go +++ /dev/null @@ -1,55 +0,0 @@ -package testdata - -import ( - "crypto/tls" - "crypto/x509" - "io/ioutil" - "path" - "runtime" -) - -var certPath string - -func init() { - _, filename, _, ok := runtime.Caller(0) - if !ok { - panic("Failed to get current frame") - } - - certPath = path.Dir(filename) -} - -// GetCertificatePaths returns the paths to certificate and key -func GetCertificatePaths() (string, string) { - return path.Join(certPath, "cert.pem"), path.Join(certPath, "priv.key") -} - -// GetTLSConfig returns a tls config for quic.clemente.io -func GetTLSConfig() *tls.Config { - cert, err := tls.LoadX509KeyPair(GetCertificatePaths()) - if err != nil { - panic(err) - } - return &tls.Config{ - Certificates: []tls.Certificate{cert}, - } -} - -// AddRootCA adds the root CA certificate to a cert pool -func AddRootCA(certPool *x509.CertPool) { - caCertPath := path.Join(certPath, "ca.pem") - caCertRaw, err := ioutil.ReadFile(caCertPath) - if err != nil { - panic(err) - } - if ok := certPool.AppendCertsFromPEM(caCertRaw); !ok { - panic("Could not add root ceritificate to pool.") - } -} - -// GetRootCA returns an x509.CertPool containing (only) the CA certificate -func GetRootCA() *x509.CertPool { - pool := x509.NewCertPool() - AddRootCA(pool) - return pool -} diff --git a/internal/quic-go/testdata/cert.pem b/internal/quic-go/testdata/cert.pem deleted file mode 100644 index 91d1aa9e..00000000 --- a/internal/quic-go/testdata/cert.pem +++ /dev/null @@ -1,18 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIC1TCCAb2gAwIBAgIJAK2fcqC0BVA7MA0GCSqGSIb3DQEBCwUAMCgxJjAkBgNV -BAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0aG9yaXR5MB4XDTIwMDgxODA5MjEz -NVoXDTMwMDgxNjA5MjEzNVowEjEQMA4GA1UECgwHcXVpYy1nbzCCASIwDQYJKoZI -hvcNAQEBBQADggEPADCCAQoCggEBAN/YwrigSXdJCL/bdBGhb0UpqtU8H+krV870 -+w1yCSykLImH8x3qHZEXt9sr/vgjcJoV6Z15RZmnbEqnAx84sIClIBoIgnk0VPxu -WF+/U/dElbftCfYcfJAddhRckdmGB+yb3Wogb32UJ+q3my++h6NjHsYb+OwpJPnQ -meXjOE7Kkf+bXfFywHF3R8kzVdh5JUFYeKbxYmYgxRps1YTsbCrZCrSy1CbQ9FJw -Wg5C8t+7yvVFmOeWPECypBCz2xS2mu+kycMNIjIWMl0SL7oVM5cBkRKPeVIG/KcM -i5+/4lRSLoPh0Txh2TKBWfpzLbIOdPU8/O7cAukIGWx0XsfHUQMCAwEAAaMYMBYw -FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQAyxxvebdMz -shp5pt1SxMOSXbo8sTa1cpaf2rTmb4nxjXs6KPBEn53hSBz9bhe5wXE4f94SHadf -636rLh3d75KgrLUwO9Yq0HfCxMo1jUV/Ug++XwcHCI9vk58Tk/H4hqEM6C8RrdTj -fYeuegQ0/oNLJ4uTw2P2A8TJbL6FC2dcICEAvUGZUcVyZ8m8tHXNRYYh6MZ7ubCh -hinvL+AA5fY6EVlc5G/P4DN6fYxGn1cFNbiL4uZP4+W3dOmP+NV0YV9ihTyMzz0R -vSoOZ9FeVkyw8EhMb3LoyXYKazvJy2VQST1ltzAGit9RiM1Gv4vuna74WsFzrn1U -A/TbaR0ih/qG ------END CERTIFICATE----- diff --git a/internal/quic-go/testdata/cert_test.go b/internal/quic-go/testdata/cert_test.go deleted file mode 100644 index 0de1bd7b..00000000 --- a/internal/quic-go/testdata/cert_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package testdata - -import ( - "crypto/tls" - "io/ioutil" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("certificates", func() { - It("returns certificates", func() { - ln, err := tls.Listen("tcp", "localhost:4433", GetTLSConfig()) - Expect(err).ToNot(HaveOccurred()) - - go func() { - defer GinkgoRecover() - conn, err := ln.Accept() - Expect(err).ToNot(HaveOccurred()) - defer conn.Close() - _, err = conn.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - }() - - conn, err := tls.Dial("tcp", "localhost:4433", &tls.Config{RootCAs: GetRootCA()}) - Expect(err).ToNot(HaveOccurred()) - data, err := ioutil.ReadAll(conn) - Expect(err).ToNot(HaveOccurred()) - Expect(string(data)).To(Equal("foobar")) - }) -}) diff --git a/internal/quic-go/testdata/generate_key.sh b/internal/quic-go/testdata/generate_key.sh deleted file mode 100755 index 7ecaa966..00000000 --- a/internal/quic-go/testdata/generate_key.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash - -set -e - -echo "Generating CA key and certificate:" -openssl req -x509 -sha256 -nodes -days 3650 -newkey rsa:2048 \ - -keyout ca.key -out ca.pem \ - -subj "/O=quic-go Certificate Authority/" - -echo "Generating CSR" -openssl req -out cert.csr -new -newkey rsa:2048 -nodes -keyout priv.key \ - -subj "/O=quic-go/" - -echo "Sign certificate:" -openssl x509 -req -sha256 -days 3650 -in cert.csr -out cert.pem \ - -CA ca.pem -CAkey ca.key -CAcreateserial \ - -extfile <(printf "subjectAltName=DNS:localhost") - -# debug output the certificate -openssl x509 -noout -text -in cert.pem - -# we don't need the CA key, the serial number and the CSR any more -rm ca.key cert.csr ca.srl - diff --git a/internal/quic-go/testdata/priv.key b/internal/quic-go/testdata/priv.key deleted file mode 100644 index 56b8d894..00000000 --- a/internal/quic-go/testdata/priv.key +++ /dev/null @@ -1,28 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDf2MK4oEl3SQi/ -23QRoW9FKarVPB/pK1fO9PsNcgkspCyJh/Md6h2RF7fbK/74I3CaFemdeUWZp2xK -pwMfOLCApSAaCIJ5NFT8blhfv1P3RJW37Qn2HHyQHXYUXJHZhgfsm91qIG99lCfq -t5svvoejYx7GG/jsKST50Jnl4zhOypH/m13xcsBxd0fJM1XYeSVBWHim8WJmIMUa -bNWE7Gwq2Qq0stQm0PRScFoOQvLfu8r1RZjnljxAsqQQs9sUtprvpMnDDSIyFjJd -Ei+6FTOXAZESj3lSBvynDIufv+JUUi6D4dE8YdkygVn6cy2yDnT1PPzu3ALpCBls -dF7Hx1EDAgMBAAECggEBAMm+mLDBdbUWk9YmuZNyRdC13wvT5obF05vo26OglXgw -dxt09b6OVBuCnuff3SpS9pdJDIYq2HnFlSorH/sxopIvQKF17fHDIp1n7ipNTCXd -IHrmHkY8Il/YzaVIUQMVc2rih0mw9greTqOS20DKnYC6QvAWIeDmrDaitTGl+ge3 -hm7e2lsgZi13R6fTNwQs9geEQSGzP2k7bFceHQFDChOYiQraR5+VZZ8S8AMGjk47 -AUa5EsKeUe6O9t2xuDSFxzYz5eadOAiErKGDos5KXXr3VQgFcC8uPEFFjcJ/yl+8 -tOe4iLeVwGSDJhTAThdR2deJOjaDcarWM7ixmxA3DAECgYEA/WVwmY4gWKwv49IJ -Jnh1Gu93P772GqliMNpukdjTI+joQxfl4jRSt2hk4b1KRwyT9aaKfvdz0HFlXo/r -9NVSAYT3/3vbcw61bfvPhhtz44qRAAKua6b5cUM6XqxVt1hqdP8lrf/blvA5ln+u -O51S8+wpxZMuqKz/29zdWSG6tAMCgYEA4iWXMXX9dZajI6abVkWwuosvOakXdLk4 -tUy7zd+JPF7hmUzzj2gtg4hXoiQPAOi+GY3TX+1Nza3s1LD7iWaXSKeOWvvligw9 -Q/wVTNW2P1+tdhScJf9QudzW69xOm5HNBgx9uWV2cHfjC12vg5aTH0k5axvaq15H -9WBXlH5q3wECgYBYoYGYBDFmMpvxmMagkSOMz1OrlVSpkLOKmOxx0SBRACc1SIec -7mY8RqR6nOX9IfYixyTMMittLiyhvb9vfKnZZDQGRcFFZlCpbplws+t+HDqJgWaW -uumm5zfkY2z7204pLBF24fZhvha2gGRl76pTLTiTJd79Gr3HnmJByd1vFwKBgHL7 -vfYuEeM55lT4Hz8sTAFtR2O/7+cvTgAQteSlZbfGXlp939DonUulhTkxsFc7/3wq -unCpzcdoSWSTYDGqcf1FBIKKVVltg7EPeR0KBJIQabgCHqrLOBZojPZ7m5RJ+765 -lysuxZvFuTFMPzNe2gssRf+JuBMt6tR+WclsxZYBAoGAEEFs1ppDil1xlP5rdH7T -d3TSw/u4eU/X8Ei1zi25hdRUiV76fP9fBELYFmSrPBhugYv91vtSv/LmD4zLfLv/ -yzwAD9j1lGbgM8Of8klCkk+XSJ88ryUwnMTJ5loQJW8t4L+zLv5Le7Ca9SAT0kJ1 -jT0GzDymgLMGp8RPdBkpk+w= ------END PRIVATE KEY----- diff --git a/internal/quic-go/testdata/testdata_suite_test.go b/internal/quic-go/testdata/testdata_suite_test.go deleted file mode 100644 index 4e9011cf..00000000 --- a/internal/quic-go/testdata/testdata_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package testdata - -import ( - "testing" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestTestdata(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Testdata Suite") -} diff --git a/internal/quic-go/testutils/testutils.go b/internal/quic-go/testutils/testutils.go deleted file mode 100644 index cbd24c91..00000000 --- a/internal/quic-go/testutils/testutils.go +++ /dev/null @@ -1,97 +0,0 @@ -package testutils - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/handshake" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -// Utilities for simulating packet injection and man-in-the-middle (MITM) attacker tests. -// Do not use for non-testing purposes. - -// writePacket returns a new raw packet with the specified header and payload -func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte { - buf := &bytes.Buffer{} - hdr.Write(buf, hdr.Version) - return append(buf.Bytes(), data...) -} - -// packRawPayload returns a new raw payload containing given frames -func packRawPayload(version protocol.VersionNumber, frames []wire.Frame) []byte { - buf := new(bytes.Buffer) - for _, cf := range frames { - cf.Write(buf, version) - } - return buf.Bytes() -} - -// ComposeInitialPacket returns an Initial packet encrypted under key -// (the original destination connection ID) containing specified frames -func ComposeInitialPacket(srcConnID protocol.ConnectionID, destConnID protocol.ConnectionID, version protocol.VersionNumber, key protocol.ConnectionID, frames []wire.Frame) []byte { - sealer, _ := handshake.NewInitialAEAD(key, protocol.PerspectiveServer, version) - - // compose payload - var payload []byte - if len(frames) == 0 { - payload = make([]byte, protocol.MinInitialPacketSize) - } else { - payload = packRawPayload(version, frames) - } - - // compose Initial header - payloadSize := len(payload) - pnLength := protocol.PacketNumberLen4 - length := payloadSize + int(pnLength) + sealer.Overhead() - hdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - Length: protocol.ByteCount(length), - Version: version, - }, - PacketNumberLen: pnLength, - PacketNumber: 0x0, - } - - raw := writePacket(hdr, payload) - - // encrypt payload and header - payloadOffset := len(raw) - payloadSize - var encrypted []byte - encrypted = sealer.Seal(encrypted, payload, hdr.PacketNumber, raw[:payloadOffset]) - hdrBytes := raw[0:payloadOffset] - encrypted = append(hdrBytes, encrypted...) - pnOffset := payloadOffset - int(pnLength) // packet number offset - sealer.EncryptHeader( - encrypted[payloadOffset:payloadOffset+16], // first 16 bytes of payload (sample) - &encrypted[0], // first byte of header - encrypted[pnOffset:payloadOffset], // packet number bytes - ) - return encrypted -} - -// ComposeRetryPacket returns a new raw Retry Packet -func ComposeRetryPacket( - srcConnID protocol.ConnectionID, - destConnID protocol.ConnectionID, - origDestConnID protocol.ConnectionID, - token []byte, - version protocol.VersionNumber, -) []byte { - hdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - Token: token, - Version: version, - }, - } - data := writePacket(hdr, nil) - return append(data, handshake.GetRetryIntegrityTag(data, origDestConnID, version)[:]...) -} diff --git a/internal/quic-go/token_store.go b/internal/quic-go/token_store.go deleted file mode 100644 index cbc07830..00000000 --- a/internal/quic-go/token_store.go +++ /dev/null @@ -1,117 +0,0 @@ -package quic - -import ( - "container/list" - "sync" - - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -type singleOriginTokenStore struct { - tokens []*ClientToken - len int - p int -} - -func newSingleOriginTokenStore(size int) *singleOriginTokenStore { - return &singleOriginTokenStore{tokens: make([]*ClientToken, size)} -} - -func (s *singleOriginTokenStore) Add(token *ClientToken) { - s.tokens[s.p] = token - s.p = s.index(s.p + 1) - s.len = utils.Min(s.len+1, len(s.tokens)) -} - -func (s *singleOriginTokenStore) Pop() *ClientToken { - s.p = s.index(s.p - 1) - token := s.tokens[s.p] - s.tokens[s.p] = nil - s.len = utils.Max(s.len-1, 0) - return token -} - -func (s *singleOriginTokenStore) Len() int { - return s.len -} - -func (s *singleOriginTokenStore) index(i int) int { - mod := len(s.tokens) - return (i + mod) % mod -} - -type lruTokenStoreEntry struct { - key string - cache *singleOriginTokenStore -} - -type lruTokenStore struct { - mutex sync.Mutex - - m map[string]*list.Element - q *list.List - capacity int - singleOriginSize int -} - -var _ TokenStore = &lruTokenStore{} - -// NewLRUTokenStore creates a new LRU cache for tokens received by the client. -// maxOrigins specifies how many origins this cache is saving tokens for. -// tokensPerOrigin specifies the maximum number of tokens per origin. -func NewLRUTokenStore(maxOrigins, tokensPerOrigin int) TokenStore { - return &lruTokenStore{ - m: make(map[string]*list.Element), - q: list.New(), - capacity: maxOrigins, - singleOriginSize: tokensPerOrigin, - } -} - -func (s *lruTokenStore) Put(key string, token *ClientToken) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if el, ok := s.m[key]; ok { - entry := el.Value.(*lruTokenStoreEntry) - entry.cache.Add(token) - s.q.MoveToFront(el) - return - } - - if s.q.Len() < s.capacity { - entry := &lruTokenStoreEntry{ - key: key, - cache: newSingleOriginTokenStore(s.singleOriginSize), - } - entry.cache.Add(token) - s.m[key] = s.q.PushFront(entry) - return - } - - elem := s.q.Back() - entry := elem.Value.(*lruTokenStoreEntry) - delete(s.m, entry.key) - entry.key = key - entry.cache = newSingleOriginTokenStore(s.singleOriginSize) - entry.cache.Add(token) - s.q.MoveToFront(elem) - s.m[key] = elem -} - -func (s *lruTokenStore) Pop(key string) *ClientToken { - s.mutex.Lock() - defer s.mutex.Unlock() - - var token *ClientToken - if el, ok := s.m[key]; ok { - s.q.MoveToFront(el) - cache := el.Value.(*lruTokenStoreEntry).cache - token = cache.Pop() - if cache.Len() == 0 { - s.q.Remove(el) - delete(s.m, key) - } - } - return token -} diff --git a/internal/quic-go/token_store_test.go b/internal/quic-go/token_store_test.go deleted file mode 100644 index 01107821..00000000 --- a/internal/quic-go/token_store_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package quic - -import ( - "fmt" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Token Cache", func() { - var s TokenStore - - BeforeEach(func() { - s = NewLRUTokenStore(3, 4) - }) - - mockToken := func(num int) *ClientToken { - return &ClientToken{data: []byte(fmt.Sprintf("%d", num))} - } - - Context("for a single origin", func() { - const origin = "localhost" - - It("adds and gets tokens", func() { - s.Put(origin, mockToken(1)) - s.Put(origin, mockToken(2)) - Expect(s.Pop(origin)).To(Equal(mockToken(2))) - Expect(s.Pop(origin)).To(Equal(mockToken(1))) - Expect(s.Pop(origin)).To(BeNil()) - }) - - It("overwrites old tokens", func() { - s.Put(origin, mockToken(1)) - s.Put(origin, mockToken(2)) - s.Put(origin, mockToken(3)) - s.Put(origin, mockToken(4)) - s.Put(origin, mockToken(5)) - Expect(s.Pop(origin)).To(Equal(mockToken(5))) - Expect(s.Pop(origin)).To(Equal(mockToken(4))) - Expect(s.Pop(origin)).To(Equal(mockToken(3))) - Expect(s.Pop(origin)).To(Equal(mockToken(2))) - Expect(s.Pop(origin)).To(BeNil()) - }) - - It("continues after getting a token", func() { - s.Put(origin, mockToken(1)) - s.Put(origin, mockToken(2)) - s.Put(origin, mockToken(3)) - Expect(s.Pop(origin)).To(Equal(mockToken(3))) - s.Put(origin, mockToken(4)) - s.Put(origin, mockToken(5)) - Expect(s.Pop(origin)).To(Equal(mockToken(5))) - Expect(s.Pop(origin)).To(Equal(mockToken(4))) - Expect(s.Pop(origin)).To(Equal(mockToken(2))) - Expect(s.Pop(origin)).To(Equal(mockToken(1))) - Expect(s.Pop(origin)).To(BeNil()) - }) - }) - - Context("for multiple origins", func() { - It("adds and gets tokens", func() { - s.Put("host1", mockToken(1)) - s.Put("host2", mockToken(2)) - Expect(s.Pop("host1")).To(Equal(mockToken(1))) - Expect(s.Pop("host1")).To(BeNil()) - Expect(s.Pop("host2")).To(Equal(mockToken(2))) - Expect(s.Pop("host2")).To(BeNil()) - }) - - It("evicts old entries", func() { - s.Put("host1", mockToken(1)) - s.Put("host2", mockToken(2)) - s.Put("host3", mockToken(3)) - s.Put("host4", mockToken(4)) - Expect(s.Pop("host1")).To(BeNil()) - Expect(s.Pop("host2")).To(Equal(mockToken(2))) - Expect(s.Pop("host3")).To(Equal(mockToken(3))) - Expect(s.Pop("host4")).To(Equal(mockToken(4))) - }) - - It("moves old entries to the front, when new tokens are added", func() { - s.Put("host1", mockToken(1)) - s.Put("host2", mockToken(2)) - s.Put("host3", mockToken(3)) - s.Put("host1", mockToken(11)) - // make sure one is evicted - s.Put("host4", mockToken(4)) - Expect(s.Pop("host2")).To(BeNil()) - Expect(s.Pop("host1")).To(Equal(mockToken(11))) - Expect(s.Pop("host1")).To(Equal(mockToken(1))) - Expect(s.Pop("host3")).To(Equal(mockToken(3))) - Expect(s.Pop("host4")).To(Equal(mockToken(4))) - }) - - It("deletes hosts that are empty", func() { - s.Put("host1", mockToken(1)) - s.Put("host2", mockToken(2)) - s.Put("host3", mockToken(3)) - Expect(s.Pop("host2")).To(Equal(mockToken(2))) - Expect(s.Pop("host2")).To(BeNil()) - // host2 is now empty and should have been deleted, making space for host4 - s.Put("host4", mockToken(4)) - Expect(s.Pop("host1")).To(Equal(mockToken(1))) - Expect(s.Pop("host3")).To(Equal(mockToken(3))) - Expect(s.Pop("host4")).To(Equal(mockToken(4))) - }) - }) -}) diff --git a/internal/quic-go/tools.go b/internal/quic-go/tools.go deleted file mode 100644 index ee68fafb..00000000 --- a/internal/quic-go/tools.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build tools -// +build tools - -package quic - -import ( - _ "github.com/cheekybits/genny" - _ "github.com/onsi/ginkgo/ginkgo" -) diff --git a/internal/quic-go/utils/atomic_bool.go b/internal/quic-go/utils/atomic_bool.go deleted file mode 100644 index cf464250..00000000 --- a/internal/quic-go/utils/atomic_bool.go +++ /dev/null @@ -1,22 +0,0 @@ -package utils - -import "sync/atomic" - -// An AtomicBool is an atomic bool -type AtomicBool struct { - v int32 -} - -// Set sets the value -func (a *AtomicBool) Set(value bool) { - var n int32 - if value { - n = 1 - } - atomic.StoreInt32(&a.v, n) -} - -// Get gets the value -func (a *AtomicBool) Get() bool { - return atomic.LoadInt32(&a.v) != 0 -} diff --git a/internal/quic-go/utils/atomic_bool_test.go b/internal/quic-go/utils/atomic_bool_test.go deleted file mode 100644 index 83a200c2..00000000 --- a/internal/quic-go/utils/atomic_bool_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package utils - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Atomic Bool", func() { - var a *AtomicBool - - BeforeEach(func() { - a = &AtomicBool{} - }) - - It("has the right default value", func() { - Expect(a.Get()).To(BeFalse()) - }) - - It("sets the value to true", func() { - a.Set(true) - Expect(a.Get()).To(BeTrue()) - }) - - It("sets the value to false", func() { - a.Set(true) - a.Set(false) - Expect(a.Get()).To(BeFalse()) - }) -}) diff --git a/internal/quic-go/utils/buffered_write_closer.go b/internal/quic-go/utils/buffered_write_closer.go deleted file mode 100644 index b5b9d6fc..00000000 --- a/internal/quic-go/utils/buffered_write_closer.go +++ /dev/null @@ -1,26 +0,0 @@ -package utils - -import ( - "bufio" - "io" -) - -type bufferedWriteCloser struct { - *bufio.Writer - io.Closer -} - -// NewBufferedWriteCloser creates an io.WriteCloser from a bufio.Writer and an io.Closer -func NewBufferedWriteCloser(writer *bufio.Writer, closer io.Closer) io.WriteCloser { - return &bufferedWriteCloser{ - Writer: writer, - Closer: closer, - } -} - -func (h bufferedWriteCloser) Close() error { - if err := h.Writer.Flush(); err != nil { - return err - } - return h.Closer.Close() -} diff --git a/internal/quic-go/utils/buffered_write_closer_test.go b/internal/quic-go/utils/buffered_write_closer_test.go deleted file mode 100644 index 9c93d615..00000000 --- a/internal/quic-go/utils/buffered_write_closer_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package utils - -import ( - "bufio" - "bytes" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type nopCloser struct{} - -func (nopCloser) Close() error { return nil } - -var _ = Describe("buffered io.WriteCloser", func() { - It("flushes before closing", func() { - buf := &bytes.Buffer{} - - w := bufio.NewWriter(buf) - wc := NewBufferedWriteCloser(w, &nopCloser{}) - wc.Write([]byte("foobar")) - Expect(buf.Len()).To(BeZero()) - Expect(wc.Close()).To(Succeed()) - Expect(buf.String()).To(Equal("foobar")) - }) -}) diff --git a/internal/quic-go/utils/byteinterval_linkedlist.go b/internal/quic-go/utils/byteinterval_linkedlist.go deleted file mode 100644 index 096023ef..00000000 --- a/internal/quic-go/utils/byteinterval_linkedlist.go +++ /dev/null @@ -1,217 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package utils - -// Linked list implementation from the Go standard library. - -// ByteIntervalElement is an element of a linked list. -type ByteIntervalElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *ByteIntervalElement - - // The list to which this element belongs. - list *ByteIntervalList - - // The value stored with this element. - Value ByteInterval -} - -// Next returns the next list element or nil. -func (e *ByteIntervalElement) Next() *ByteIntervalElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *ByteIntervalElement) Prev() *ByteIntervalElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// ByteIntervalList is a linked list of ByteIntervals. -type ByteIntervalList struct { - root ByteIntervalElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *ByteIntervalList) Init() *ByteIntervalList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewByteIntervalList returns an initialized list. -func NewByteIntervalList() *ByteIntervalList { return new(ByteIntervalList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *ByteIntervalList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *ByteIntervalList) Front() *ByteIntervalElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *ByteIntervalList) Back() *ByteIntervalElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *ByteIntervalList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *ByteIntervalList) insert(e, at *ByteIntervalElement) *ByteIntervalElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *ByteIntervalList) insertValue(v ByteInterval, at *ByteIntervalElement) *ByteIntervalElement { - return l.insert(&ByteIntervalElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. -func (l *ByteIntervalList) remove(e *ByteIntervalElement) *ByteIntervalElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *ByteIntervalList) Remove(e *ByteIntervalElement) ByteInterval { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *ByteIntervalList) PushFront(v ByteInterval) *ByteIntervalElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *ByteIntervalList) MoveToFront(e *ByteIntervalElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *ByteIntervalList) MoveToBack(e *ByteIntervalElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *ByteIntervalList) PushFrontList(other *ByteIntervalList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/internal/quic-go/utils/byteoder_big_endian_test.go b/internal/quic-go/utils/byteoder_big_endian_test.go deleted file mode 100644 index 5d0873a9..00000000 --- a/internal/quic-go/utils/byteoder_big_endian_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package utils - -import ( - "bytes" - "io" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Big Endian encoding / decoding", func() { - Context("ReadUint16", func() { - It("reads a big endian", func() { - b := []byte{0x13, 0xEF} - val, err := BigEndian.ReadUint16(bytes.NewReader(b)) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint16(0x13EF))) - }) - - It("throws an error if less than 2 bytes are passed", func() { - b := []byte{0x13, 0xEF} - for i := 0; i < len(b); i++ { - _, err := BigEndian.ReadUint16(bytes.NewReader(b[:i])) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("ReadUint24", func() { - It("reads a big endian", func() { - b := []byte{0x13, 0xbe, 0xef} - val, err := BigEndian.ReadUint24(bytes.NewReader(b)) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint32(0x13beef))) - }) - - It("throws an error if less than 3 bytes are passed", func() { - b := []byte{0x13, 0xbe, 0xef} - for i := 0; i < len(b); i++ { - _, err := BigEndian.ReadUint24(bytes.NewReader(b[:i])) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("ReadUint32", func() { - It("reads a big endian", func() { - b := []byte{0x12, 0x35, 0xAB, 0xFF} - val, err := BigEndian.ReadUint32(bytes.NewReader(b)) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint32(0x1235ABFF))) - }) - - It("throws an error if less than 4 bytes are passed", func() { - b := []byte{0x12, 0x35, 0xAB, 0xFF} - for i := 0; i < len(b); i++ { - _, err := BigEndian.ReadUint32(bytes.NewReader(b[:i])) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("WriteUint16", func() { - It("outputs 2 bytes", func() { - b := &bytes.Buffer{} - BigEndian.WriteUint16(b, uint16(1)) - Expect(b.Len()).To(Equal(2)) - }) - - It("outputs a big endian", func() { - num := uint16(0xFF11) - b := &bytes.Buffer{} - BigEndian.WriteUint16(b, num) - Expect(b.Bytes()).To(Equal([]byte{0xFF, 0x11})) - }) - }) - - Context("WriteUint24", func() { - It("outputs 3 bytes", func() { - b := &bytes.Buffer{} - BigEndian.WriteUint24(b, uint32(1)) - Expect(b.Len()).To(Equal(3)) - }) - - It("outputs a big endian", func() { - num := uint32(0xff11aa) - b := &bytes.Buffer{} - BigEndian.WriteUint24(b, num) - Expect(b.Bytes()).To(Equal([]byte{0xff, 0x11, 0xaa})) - }) - }) - - Context("WriteUint32", func() { - It("outputs 4 bytes", func() { - b := &bytes.Buffer{} - BigEndian.WriteUint32(b, uint32(1)) - Expect(b.Len()).To(Equal(4)) - }) - - It("outputs a big endian", func() { - num := uint32(0xEFAC3512) - b := &bytes.Buffer{} - BigEndian.WriteUint32(b, num) - Expect(b.Bytes()).To(Equal([]byte{0xEF, 0xAC, 0x35, 0x12})) - }) - }) -}) diff --git a/internal/quic-go/utils/byteorder.go b/internal/quic-go/utils/byteorder.go deleted file mode 100644 index d1f52842..00000000 --- a/internal/quic-go/utils/byteorder.go +++ /dev/null @@ -1,17 +0,0 @@ -package utils - -import ( - "bytes" - "io" -) - -// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. -type ByteOrder interface { - ReadUint32(io.ByteReader) (uint32, error) - ReadUint24(io.ByteReader) (uint32, error) - ReadUint16(io.ByteReader) (uint16, error) - - WriteUint32(*bytes.Buffer, uint32) - WriteUint24(*bytes.Buffer, uint32) - WriteUint16(*bytes.Buffer, uint16) -} diff --git a/internal/quic-go/utils/byteorder_big_endian.go b/internal/quic-go/utils/byteorder_big_endian.go deleted file mode 100644 index d05542e1..00000000 --- a/internal/quic-go/utils/byteorder_big_endian.go +++ /dev/null @@ -1,89 +0,0 @@ -package utils - -import ( - "bytes" - "io" -) - -// BigEndian is the big-endian implementation of ByteOrder. -var BigEndian ByteOrder = bigEndian{} - -type bigEndian struct{} - -var _ ByteOrder = &bigEndian{} - -// ReadUintN reads N bytes -func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) { - var res uint64 - for i := uint8(0); i < length; i++ { - bt, err := b.ReadByte() - if err != nil { - return 0, err - } - res ^= uint64(bt) << ((length - 1 - i) * 8) - } - return res, nil -} - -// ReadUint32 reads a uint32 -func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) { - var b1, b2, b3, b4 uint8 - var err error - if b4, err = b.ReadByte(); err != nil { - return 0, err - } - if b3, err = b.ReadByte(); err != nil { - return 0, err - } - if b2, err = b.ReadByte(); err != nil { - return 0, err - } - if b1, err = b.ReadByte(); err != nil { - return 0, err - } - return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil -} - -// ReadUint24 reads a uint24 -func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) { - var b1, b2, b3 uint8 - var err error - if b3, err = b.ReadByte(); err != nil { - return 0, err - } - if b2, err = b.ReadByte(); err != nil { - return 0, err - } - if b1, err = b.ReadByte(); err != nil { - return 0, err - } - return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16, nil -} - -// ReadUint16 reads a uint16 -func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) { - var b1, b2 uint8 - var err error - if b2, err = b.ReadByte(); err != nil { - return 0, err - } - if b1, err = b.ReadByte(); err != nil { - return 0, err - } - return uint16(b1) + uint16(b2)<<8, nil -} - -// WriteUint32 writes a uint32 -func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) { - b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)}) -} - -// WriteUint24 writes a uint24 -func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) { - b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)}) -} - -// WriteUint16 writes a uint16 -func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) { - b.Write([]byte{uint8(i >> 8), uint8(i)}) -} diff --git a/internal/quic-go/utils/gen.go b/internal/quic-go/utils/gen.go deleted file mode 100644 index 8a63e958..00000000 --- a/internal/quic-go/utils/gen.go +++ /dev/null @@ -1,5 +0,0 @@ -package utils - -//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out byteinterval_linkedlist.go gen Item=ByteInterval -//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out packetinterval_linkedlist.go gen Item=PacketInterval -//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out newconnectionid_linkedlist.go gen Item=NewConnectionID diff --git a/internal/quic-go/utils/ip.go b/internal/quic-go/utils/ip.go deleted file mode 100644 index 7ac7ffec..00000000 --- a/internal/quic-go/utils/ip.go +++ /dev/null @@ -1,10 +0,0 @@ -package utils - -import "net" - -func IsIPv4(ip net.IP) bool { - // If ip is not an IPv4 address, To4 returns nil. - // Note that there might be some corner cases, where this is not correct. - // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6. - return ip.To4() != nil -} diff --git a/internal/quic-go/utils/ip_test.go b/internal/quic-go/utils/ip_test.go deleted file mode 100644 index b61cf529..00000000 --- a/internal/quic-go/utils/ip_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package utils - -import ( - "net" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("IP", func() { - It("tells IPv4 and IPv6 addresses apart", func() { - Expect(IsIPv4(net.IPv4(127, 0, 0, 1))).To(BeTrue()) - Expect(IsIPv4(net.IPv4zero)).To(BeTrue()) - Expect(IsIPv4(net.IPv6zero)).To(BeFalse()) - Expect(IsIPv4(net.IPv6loopback)).To(BeFalse()) - }) -}) diff --git a/internal/quic-go/utils/linkedlist/README.md b/internal/quic-go/utils/linkedlist/README.md deleted file mode 100644 index 15b46dce..00000000 --- a/internal/quic-go/utils/linkedlist/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Usage - -This is the Go standard library implementation of a linked list -(https://golang.org/src/container/list/list.go), modified such that genny -(https://github.com/cheekybits/genny) can be used to generate a typed linked -list. - -To generate, run -``` -genny -pkg $PACKAGE -in linkedlist.go -out $OUTFILE gen Item=$TYPE -``` diff --git a/internal/quic-go/utils/linkedlist/linkedlist.go b/internal/quic-go/utils/linkedlist/linkedlist.go deleted file mode 100644 index 74b815a8..00000000 --- a/internal/quic-go/utils/linkedlist/linkedlist.go +++ /dev/null @@ -1,218 +0,0 @@ -package linkedlist - -import "github.com/cheekybits/genny/generic" - -// Linked list implementation from the Go standard library. - -// Item is a generic type. -type Item generic.Type - -// ItemElement is an element of a linked list. -type ItemElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *ItemElement - - // The list to which this element belongs. - list *ItemList - - // The value stored with this element. - Value Item -} - -// Next returns the next list element or nil. -func (e *ItemElement) Next() *ItemElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *ItemElement) Prev() *ItemElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// ItemList is a linked list of Items. -type ItemList struct { - root ItemElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *ItemList) Init() *ItemList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewItemList returns an initialized list. -func NewItemList() *ItemList { return new(ItemList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *ItemList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *ItemList) Front() *ItemElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *ItemList) Back() *ItemElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *ItemList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *ItemList) insert(e, at *ItemElement) *ItemElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *ItemList) insertValue(v Item, at *ItemElement) *ItemElement { - return l.insert(&ItemElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. -func (l *ItemList) remove(e *ItemElement) *ItemElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *ItemList) Remove(e *ItemElement) Item { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *ItemList) PushFront(v Item) *ItemElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *ItemList) PushBack(v Item) *ItemElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *ItemList) InsertBefore(v Item, mark *ItemElement) *ItemElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *ItemList) InsertAfter(v Item, mark *ItemElement) *ItemElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *ItemList) MoveToFront(e *ItemElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *ItemList) MoveToBack(e *ItemElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *ItemList) MoveBefore(e, mark *ItemElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *ItemList) MoveAfter(e, mark *ItemElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *ItemList) PushBackList(other *ItemList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *ItemList) PushFrontList(other *ItemList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/internal/quic-go/utils/log.go b/internal/quic-go/utils/log.go deleted file mode 100644 index e04ee514..00000000 --- a/internal/quic-go/utils/log.go +++ /dev/null @@ -1,131 +0,0 @@ -package utils - -import ( - "fmt" - "log" - "os" - "strings" - "time" -) - -// LogLevel of quic-go -type LogLevel uint8 - -const ( - // LogLevelNothing disables - LogLevelNothing LogLevel = iota - // LogLevelError enables err logs - LogLevelError - // LogLevelInfo enables info logs (e.g. packets) - LogLevelInfo - // LogLevelDebug enables debug logs (e.g. packet contents) - LogLevelDebug -) - -const logEnv = "QUIC_GO_LOG_LEVEL" - -// A Logger logs. -type Logger interface { - SetLogLevel(LogLevel) - SetLogTimeFormat(format string) - WithPrefix(prefix string) Logger - Debug() bool - - Errorf(format string, args ...interface{}) - Infof(format string, args ...interface{}) - Debugf(format string, args ...interface{}) -} - -// DefaultLogger is used by quic-go for logging. -var DefaultLogger Logger - -type defaultLogger struct { - prefix string - - logLevel LogLevel - timeFormat string -} - -var _ Logger = &defaultLogger{} - -// SetLogLevel sets the log level -func (l *defaultLogger) SetLogLevel(level LogLevel) { - l.logLevel = level -} - -// SetLogTimeFormat sets the format of the timestamp -// an empty string disables the logging of timestamps -func (l *defaultLogger) SetLogTimeFormat(format string) { - log.SetFlags(0) // disable timestamp logging done by the log package - l.timeFormat = format -} - -// Debugf logs something -func (l *defaultLogger) Debugf(format string, args ...interface{}) { - if l.logLevel == LogLevelDebug { - l.logMessage(format, args...) - } -} - -// Infof logs something -func (l *defaultLogger) Infof(format string, args ...interface{}) { - if l.logLevel >= LogLevelInfo { - l.logMessage(format, args...) - } -} - -// Errorf logs something -func (l *defaultLogger) Errorf(format string, args ...interface{}) { - if l.logLevel >= LogLevelError { - l.logMessage(format, args...) - } -} - -func (l *defaultLogger) logMessage(format string, args ...interface{}) { - var pre string - - if len(l.timeFormat) > 0 { - pre = time.Now().Format(l.timeFormat) + " " - } - if len(l.prefix) > 0 { - pre += l.prefix + " " - } - log.Printf(pre+format, args...) -} - -func (l *defaultLogger) WithPrefix(prefix string) Logger { - if len(l.prefix) > 0 { - prefix = l.prefix + " " + prefix - } - return &defaultLogger{ - logLevel: l.logLevel, - timeFormat: l.timeFormat, - prefix: prefix, - } -} - -// Debug returns true if the log level is LogLevelDebug -func (l *defaultLogger) Debug() bool { - return l.logLevel == LogLevelDebug -} - -func init() { - DefaultLogger = &defaultLogger{} - DefaultLogger.SetLogLevel(readLoggingEnv()) -} - -func readLoggingEnv() LogLevel { - switch strings.ToLower(os.Getenv(logEnv)) { - case "": - return LogLevelNothing - case "debug": - return LogLevelDebug - case "info": - return LogLevelInfo - case "error": - return LogLevelError - default: - fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/imroc/req/v3/internal/quic-go/wiki/Logging") - return LogLevelNothing - } -} diff --git a/internal/quic-go/utils/log_test.go b/internal/quic-go/utils/log_test.go deleted file mode 100644 index 36edc1cc..00000000 --- a/internal/quic-go/utils/log_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package utils - -import ( - "bytes" - "log" - "os" - "time" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Log", func() { - var b *bytes.Buffer - - BeforeEach(func() { - b = &bytes.Buffer{} - log.SetOutput(b) - }) - - AfterEach(func() { - log.SetOutput(os.Stdout) - DefaultLogger.SetLogLevel(LogLevelNothing) - }) - - It("the log level has the correct numeric value", func() { - Expect(LogLevelNothing).To(BeEquivalentTo(0)) - Expect(LogLevelError).To(BeEquivalentTo(1)) - Expect(LogLevelInfo).To(BeEquivalentTo(2)) - Expect(LogLevelDebug).To(BeEquivalentTo(3)) - }) - - It("log level nothing", func() { - DefaultLogger.SetLogLevel(LogLevelNothing) - DefaultLogger.Debugf("debug") - DefaultLogger.Infof("info") - DefaultLogger.Errorf("err") - Expect(b.String()).To(BeEmpty()) - }) - - It("log level err", func() { - DefaultLogger.SetLogLevel(LogLevelError) - DefaultLogger.Debugf("debug") - DefaultLogger.Infof("info") - DefaultLogger.Errorf("err") - Expect(b.String()).To(ContainSubstring("err\n")) - Expect(b.String()).ToNot(ContainSubstring("info")) - Expect(b.String()).ToNot(ContainSubstring("debug")) - }) - - It("log level info", func() { - DefaultLogger.SetLogLevel(LogLevelInfo) - DefaultLogger.Debugf("debug") - DefaultLogger.Infof("info") - DefaultLogger.Errorf("err") - Expect(b.String()).To(ContainSubstring("err\n")) - Expect(b.String()).To(ContainSubstring("info\n")) - Expect(b.String()).ToNot(ContainSubstring("debug")) - }) - - It("log level debug", func() { - DefaultLogger.SetLogLevel(LogLevelDebug) - DefaultLogger.Debugf("debug") - DefaultLogger.Infof("info") - DefaultLogger.Errorf("err") - Expect(b.String()).To(ContainSubstring("err\n")) - Expect(b.String()).To(ContainSubstring("info\n")) - Expect(b.String()).To(ContainSubstring("debug\n")) - }) - - It("doesn't add a timestamp if the time format is empty", func() { - DefaultLogger.SetLogLevel(LogLevelDebug) - DefaultLogger.SetLogTimeFormat("") - DefaultLogger.Debugf("debug") - Expect(b.String()).To(Equal("debug\n")) - }) - - It("adds a timestamp", func() { - format := "Jan 2, 2006" - DefaultLogger.SetLogTimeFormat(format) - DefaultLogger.SetLogLevel(LogLevelInfo) - DefaultLogger.Infof("info") - t, err := time.Parse(format, b.String()[:b.Len()-6]) - Expect(err).ToNot(HaveOccurred()) - Expect(t).To(BeTemporally("~", time.Now(), 25*time.Hour)) - }) - - It("says whether debug is enabled", func() { - Expect(DefaultLogger.Debug()).To(BeFalse()) - DefaultLogger.SetLogLevel(LogLevelDebug) - Expect(DefaultLogger.Debug()).To(BeTrue()) - }) - - It("adds a prefix", func() { - DefaultLogger.SetLogLevel(LogLevelDebug) - prefixLogger := DefaultLogger.WithPrefix("prefix") - prefixLogger.Debugf("debug") - Expect(b.String()).To(ContainSubstring("prefix")) - Expect(b.String()).To(ContainSubstring("debug")) - }) - - It("adds multiple prefixes", func() { - DefaultLogger.SetLogLevel(LogLevelDebug) - prefixLogger := DefaultLogger.WithPrefix("prefix1") - prefixPrefixLogger := prefixLogger.WithPrefix("prefix2") - prefixPrefixLogger.Debugf("debug") - Expect(b.String()).To(ContainSubstring("prefix")) - Expect(b.String()).To(ContainSubstring("debug")) - }) - - Context("reading from env", func() { - BeforeEach(func() { - Expect(DefaultLogger.(*defaultLogger).logLevel).To(Equal(LogLevelNothing)) - }) - - It("reads DEBUG", func() { - os.Setenv(logEnv, "DEBUG") - Expect(readLoggingEnv()).To(Equal(LogLevelDebug)) - }) - - It("reads debug", func() { - os.Setenv(logEnv, "debug") - Expect(readLoggingEnv()).To(Equal(LogLevelDebug)) - }) - - It("reads INFO", func() { - os.Setenv(logEnv, "INFO") - readLoggingEnv() - Expect(readLoggingEnv()).To(Equal(LogLevelInfo)) - }) - - It("reads ERROR", func() { - os.Setenv(logEnv, "ERROR") - Expect(readLoggingEnv()).To(Equal(LogLevelError)) - }) - - It("does not error reading invalid log levels from env", func() { - os.Setenv(logEnv, "") - Expect(readLoggingEnv()).To(Equal(LogLevelNothing)) - os.Setenv(logEnv, "asdf") - Expect(readLoggingEnv()).To(Equal(LogLevelNothing)) - }) - }) -}) diff --git a/internal/quic-go/utils/minmax.go b/internal/quic-go/utils/minmax.go deleted file mode 100644 index 1a3448d8..00000000 --- a/internal/quic-go/utils/minmax.go +++ /dev/null @@ -1,170 +0,0 @@ -package utils - -import ( - "math" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// InfDuration is a duration of infinite length -const InfDuration = time.Duration(math.MaxInt64) - -// Max returns the maximum of two Ints -func Max(a, b int) int { - if a < b { - return b - } - return a -} - -// MaxUint32 returns the maximum of two uint32 -func MaxUint32(a, b uint32) uint32 { - if a < b { - return b - } - return a -} - -// MaxUint64 returns the maximum of two uint64 -func MaxUint64(a, b uint64) uint64 { - if a < b { - return b - } - return a -} - -// MinUint64 returns the maximum of two uint64 -func MinUint64(a, b uint64) uint64 { - if a < b { - return a - } - return b -} - -// Min returns the minimum of two Ints -func Min(a, b int) int { - if a < b { - return a - } - return b -} - -// MinUint32 returns the maximum of two uint32 -func MinUint32(a, b uint32) uint32 { - if a < b { - return a - } - return b -} - -// MinInt64 returns the minimum of two int64 -func MinInt64(a, b int64) int64 { - if a < b { - return a - } - return b -} - -// MaxInt64 returns the minimum of two int64 -func MaxInt64(a, b int64) int64 { - if a > b { - return a - } - return b -} - -// MinByteCount returns the minimum of two ByteCounts -func MinByteCount(a, b protocol.ByteCount) protocol.ByteCount { - if a < b { - return a - } - return b -} - -// MaxByteCount returns the maximum of two ByteCounts -func MaxByteCount(a, b protocol.ByteCount) protocol.ByteCount { - if a < b { - return b - } - return a -} - -// MaxDuration returns the max duration -func MaxDuration(a, b time.Duration) time.Duration { - if a > b { - return a - } - return b -} - -// MinDuration returns the minimum duration -func MinDuration(a, b time.Duration) time.Duration { - if a > b { - return b - } - return a -} - -// MinNonZeroDuration return the minimum duration that's not zero. -func MinNonZeroDuration(a, b time.Duration) time.Duration { - if a == 0 { - return b - } - if b == 0 { - return a - } - return MinDuration(a, b) -} - -// AbsDuration returns the absolute value of a time duration -func AbsDuration(d time.Duration) time.Duration { - if d >= 0 { - return d - } - return -d -} - -// MinTime returns the earlier time -func MinTime(a, b time.Time) time.Time { - if a.After(b) { - return b - } - return a -} - -// MinNonZeroTime returns the earlist time that is not time.Time{} -// If both a and b are time.Time{}, it returns time.Time{} -func MinNonZeroTime(a, b time.Time) time.Time { - if a.IsZero() { - return b - } - if b.IsZero() { - return a - } - return MinTime(a, b) -} - -// MaxTime returns the later time -func MaxTime(a, b time.Time) time.Time { - if a.After(b) { - return a - } - return b -} - -// MaxPacketNumber returns the max packet number -func MaxPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { - if a > b { - return a - } - return b -} - -// MinPacketNumber returns the min packet number -func MinPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { - if a < b { - return a - } - return b -} diff --git a/internal/quic-go/utils/minmax_test.go b/internal/quic-go/utils/minmax_test.go deleted file mode 100644 index 5ee0caa5..00000000 --- a/internal/quic-go/utils/minmax_test.go +++ /dev/null @@ -1,123 +0,0 @@ -package utils - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Min / Max", func() { - Context("Max", func() { - It("returns the maximum", func() { - Expect(Max(5, 7)).To(Equal(7)) - Expect(Max(7, 5)).To(Equal(7)) - }) - - It("returns the maximum uint32", func() { - Expect(MaxUint32(5, 7)).To(Equal(uint32(7))) - Expect(MaxUint32(7, 5)).To(Equal(uint32(7))) - }) - - It("returns the maximum uint64", func() { - Expect(MaxUint64(5, 7)).To(Equal(uint64(7))) - Expect(MaxUint64(7, 5)).To(Equal(uint64(7))) - }) - - It("returns the minimum uint64", func() { - Expect(MinUint64(5, 7)).To(Equal(uint64(5))) - Expect(MinUint64(7, 5)).To(Equal(uint64(5))) - }) - - It("returns the maximum int64", func() { - Expect(MaxInt64(5, 7)).To(Equal(int64(7))) - Expect(MaxInt64(7, 5)).To(Equal(int64(7))) - }) - - It("returns the maximum ByteCount", func() { - Expect(MaxByteCount(7, 5)).To(Equal(protocol.ByteCount(7))) - Expect(MaxByteCount(5, 7)).To(Equal(protocol.ByteCount(7))) - }) - - It("returns the maximum duration", func() { - Expect(MaxDuration(time.Microsecond, time.Nanosecond)).To(Equal(time.Microsecond)) - Expect(MaxDuration(time.Nanosecond, time.Microsecond)).To(Equal(time.Microsecond)) - }) - - It("returns the minimum duration", func() { - Expect(MinDuration(time.Microsecond, time.Nanosecond)).To(Equal(time.Nanosecond)) - Expect(MinDuration(time.Nanosecond, time.Microsecond)).To(Equal(time.Nanosecond)) - }) - - It("returns packet number max", func() { - Expect(MaxPacketNumber(1, 2)).To(Equal(protocol.PacketNumber(2))) - Expect(MaxPacketNumber(2, 1)).To(Equal(protocol.PacketNumber(2))) - }) - - It("returns the maximum time", func() { - a := time.Now() - b := a.Add(time.Second) - Expect(MaxTime(a, b)).To(Equal(b)) - Expect(MaxTime(b, a)).To(Equal(b)) - }) - }) - - Context("Min", func() { - It("returns the minimum", func() { - Expect(Min(5, 7)).To(Equal(5)) - Expect(Min(7, 5)).To(Equal(5)) - }) - - It("returns the minimum uint32", func() { - Expect(MinUint32(7, 5)).To(Equal(uint32(5))) - Expect(MinUint32(5, 7)).To(Equal(uint32(5))) - }) - - It("returns the minimum int64", func() { - Expect(MinInt64(7, 5)).To(Equal(int64(5))) - Expect(MinInt64(5, 7)).To(Equal(int64(5))) - }) - - It("returns the minimum ByteCount", func() { - Expect(MinByteCount(7, 5)).To(Equal(protocol.ByteCount(5))) - Expect(MinByteCount(5, 7)).To(Equal(protocol.ByteCount(5))) - }) - - It("returns packet number min", func() { - Expect(MinPacketNumber(1, 2)).To(Equal(protocol.PacketNumber(1))) - Expect(MinPacketNumber(2, 1)).To(Equal(protocol.PacketNumber(1))) - }) - - It("returns the minimum duration", func() { - a := time.Now() - b := a.Add(time.Second) - Expect(MinTime(a, b)).To(Equal(a)) - Expect(MinTime(b, a)).To(Equal(a)) - }) - - It("returns the minium non-zero duration", func() { - var a time.Duration - b := time.Second - Expect(MinNonZeroDuration(0, 0)).To(BeZero()) - Expect(MinNonZeroDuration(a, b)).To(Equal(b)) - Expect(MinNonZeroDuration(b, a)).To(Equal(b)) - Expect(MinNonZeroDuration(time.Minute, time.Hour)).To(Equal(time.Minute)) - }) - - It("returns the minium non-zero time", func() { - a := time.Time{} - b := time.Now() - Expect(MinNonZeroTime(time.Time{}, time.Time{})).To(Equal(time.Time{})) - Expect(MinNonZeroTime(a, b)).To(Equal(b)) - Expect(MinNonZeroTime(b, a)).To(Equal(b)) - Expect(MinNonZeroTime(b, b.Add(time.Second))).To(Equal(b)) - Expect(MinNonZeroTime(b.Add(time.Second), b)).To(Equal(b)) - }) - }) - - It("returns the abs time", func() { - Expect(AbsDuration(time.Microsecond)).To(Equal(time.Microsecond)) - Expect(AbsDuration(-time.Microsecond)).To(Equal(time.Microsecond)) - }) -}) diff --git a/internal/quic-go/utils/new_connection_id.go b/internal/quic-go/utils/new_connection_id.go deleted file mode 100644 index 41b7484f..00000000 --- a/internal/quic-go/utils/new_connection_id.go +++ /dev/null @@ -1,12 +0,0 @@ -package utils - -import ( - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// NewConnectionID is a new connection ID -type NewConnectionID struct { - SequenceNumber uint64 - ConnectionID protocol.ConnectionID - StatelessResetToken protocol.StatelessResetToken -} diff --git a/internal/quic-go/utils/newconnectionid_linkedlist.go b/internal/quic-go/utils/newconnectionid_linkedlist.go deleted file mode 100644 index d59562e5..00000000 --- a/internal/quic-go/utils/newconnectionid_linkedlist.go +++ /dev/null @@ -1,217 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package utils - -// Linked list implementation from the Go standard library. - -// NewConnectionIDElement is an element of a linked list. -type NewConnectionIDElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *NewConnectionIDElement - - // The list to which this element belongs. - list *NewConnectionIDList - - // The value stored with this element. - Value NewConnectionID -} - -// Next returns the next list element or nil. -func (e *NewConnectionIDElement) Next() *NewConnectionIDElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *NewConnectionIDElement) Prev() *NewConnectionIDElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// NewConnectionIDList is a linked list of NewConnectionIDs. -type NewConnectionIDList struct { - root NewConnectionIDElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *NewConnectionIDList) Init() *NewConnectionIDList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewNewConnectionIDList returns an initialized list. -func NewNewConnectionIDList() *NewConnectionIDList { return new(NewConnectionIDList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *NewConnectionIDList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *NewConnectionIDList) Front() *NewConnectionIDElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *NewConnectionIDList) Back() *NewConnectionIDElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *NewConnectionIDList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *NewConnectionIDList) insert(e, at *NewConnectionIDElement) *NewConnectionIDElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *NewConnectionIDList) insertValue(v NewConnectionID, at *NewConnectionIDElement) *NewConnectionIDElement { - return l.insert(&NewConnectionIDElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. -func (l *NewConnectionIDList) remove(e *NewConnectionIDElement) *NewConnectionIDElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *NewConnectionIDList) Remove(e *NewConnectionIDElement) NewConnectionID { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *NewConnectionIDList) PushFront(v NewConnectionID) *NewConnectionIDElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *NewConnectionIDList) PushBack(v NewConnectionID) *NewConnectionIDElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *NewConnectionIDList) InsertBefore(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *NewConnectionIDList) InsertAfter(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *NewConnectionIDList) MoveToFront(e *NewConnectionIDElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *NewConnectionIDList) MoveToBack(e *NewConnectionIDElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *NewConnectionIDList) MoveBefore(e, mark *NewConnectionIDElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *NewConnectionIDList) MoveAfter(e, mark *NewConnectionIDElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *NewConnectionIDList) PushBackList(other *NewConnectionIDList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *NewConnectionIDList) PushFrontList(other *NewConnectionIDList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/internal/quic-go/utils/packet_interval.go b/internal/quic-go/utils/packet_interval.go deleted file mode 100644 index 5518141d..00000000 --- a/internal/quic-go/utils/packet_interval.go +++ /dev/null @@ -1,9 +0,0 @@ -package utils - -import "github.com/imroc/req/v3/internal/quic-go/protocol" - -// PacketInterval is an interval from one PacketNumber to the other -type PacketInterval struct { - Start protocol.PacketNumber - End protocol.PacketNumber -} diff --git a/internal/quic-go/utils/packetinterval_linkedlist.go b/internal/quic-go/utils/packetinterval_linkedlist.go deleted file mode 100644 index b461e85a..00000000 --- a/internal/quic-go/utils/packetinterval_linkedlist.go +++ /dev/null @@ -1,217 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package utils - -// Linked list implementation from the Go standard library. - -// PacketIntervalElement is an element of a linked list. -type PacketIntervalElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *PacketIntervalElement - - // The list to which this element belongs. - list *PacketIntervalList - - // The value stored with this element. - Value PacketInterval -} - -// Next returns the next list element or nil. -func (e *PacketIntervalElement) Next() *PacketIntervalElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *PacketIntervalElement) Prev() *PacketIntervalElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// PacketIntervalList is a linked list of PacketIntervals. -type PacketIntervalList struct { - root PacketIntervalElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *PacketIntervalList) Init() *PacketIntervalList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewPacketIntervalList returns an initialized list. -func NewPacketIntervalList() *PacketIntervalList { return new(PacketIntervalList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *PacketIntervalList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *PacketIntervalList) Front() *PacketIntervalElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *PacketIntervalList) Back() *PacketIntervalElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *PacketIntervalList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *PacketIntervalList) insert(e, at *PacketIntervalElement) *PacketIntervalElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *PacketIntervalList) insertValue(v PacketInterval, at *PacketIntervalElement) *PacketIntervalElement { - return l.insert(&PacketIntervalElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. -func (l *PacketIntervalList) remove(e *PacketIntervalElement) *PacketIntervalElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *PacketIntervalList) Remove(e *PacketIntervalElement) PacketInterval { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *PacketIntervalList) PushFront(v PacketInterval) *PacketIntervalElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *PacketIntervalList) PushBack(v PacketInterval) *PacketIntervalElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *PacketIntervalList) InsertBefore(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *PacketIntervalList) InsertAfter(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *PacketIntervalList) MoveToFront(e *PacketIntervalElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *PacketIntervalList) MoveToBack(e *PacketIntervalElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *PacketIntervalList) PushFrontList(other *PacketIntervalList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/internal/quic-go/utils/rand.go b/internal/quic-go/utils/rand.go deleted file mode 100644 index 30069144..00000000 --- a/internal/quic-go/utils/rand.go +++ /dev/null @@ -1,29 +0,0 @@ -package utils - -import ( - "crypto/rand" - "encoding/binary" -) - -// Rand is a wrapper around crypto/rand that adds some convenience functions known from math/rand. -type Rand struct { - buf [4]byte -} - -func (r *Rand) Int31() int32 { - rand.Read(r.buf[:]) - return int32(binary.BigEndian.Uint32(r.buf[:]) & ^uint32(1<<31)) -} - -// copied from the standard library math/rand implementation of Int63n -func (r *Rand) Int31n(n int32) int32 { - if n&(n-1) == 0 { // n is power of two, can mask - return r.Int31() & (n - 1) - } - max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) - v := r.Int31() - for v > max { - v = r.Int31() - } - return v % n -} diff --git a/internal/quic-go/utils/rand_test.go b/internal/quic-go/utils/rand_test.go deleted file mode 100644 index f15a644e..00000000 --- a/internal/quic-go/utils/rand_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package utils - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Rand", func() { - It("generates random numbers", func() { - const ( - num = 1000 - max = 12345678 - ) - - var values [num]int32 - var r Rand - for i := 0; i < num; i++ { - v := r.Int31n(max) - Expect(v).To(And( - BeNumerically(">=", 0), - BeNumerically("<", max), - )) - values[i] = v - } - - var sum uint64 - for _, n := range values { - sum += uint64(n) - } - Expect(float64(sum) / num).To(BeNumerically("~", max/2, max/25)) - }) -}) diff --git a/internal/quic-go/utils/rtt_stats.go b/internal/quic-go/utils/rtt_stats.go deleted file mode 100644 index 75bfc6d3..00000000 --- a/internal/quic-go/utils/rtt_stats.go +++ /dev/null @@ -1,127 +0,0 @@ -package utils - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -const ( - rttAlpha = 0.125 - oneMinusAlpha = 1 - rttAlpha - rttBeta = 0.25 - oneMinusBeta = 1 - rttBeta - // The default RTT used before an RTT sample is taken. - defaultInitialRTT = 100 * time.Millisecond -) - -// RTTStats provides round-trip statistics -type RTTStats struct { - hasMeasurement bool - - minRTT time.Duration - latestRTT time.Duration - smoothedRTT time.Duration - meanDeviation time.Duration - - maxAckDelay time.Duration -} - -// NewRTTStats makes a properly initialized RTTStats object -func NewRTTStats() *RTTStats { - return &RTTStats{} -} - -// MinRTT Returns the minRTT for the entire connection. -// May return Zero if no valid updates have occurred. -func (r *RTTStats) MinRTT() time.Duration { return r.minRTT } - -// LatestRTT returns the most recent rtt measurement. -// May return Zero if no valid updates have occurred. -func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT } - -// SmoothedRTT returns the smoothed RTT for the connection. -// May return Zero if no valid updates have occurred. -func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT } - -// MeanDeviation gets the mean deviation -func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } - -// MaxAckDelay gets the max_ack_delay advertised by the peer -func (r *RTTStats) MaxAckDelay() time.Duration { return r.maxAckDelay } - -// PTO gets the probe timeout duration. -func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { - if r.SmoothedRTT() == 0 { - return 2 * defaultInitialRTT - } - pto := r.SmoothedRTT() + MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) - if includeMaxAckDelay { - pto += r.MaxAckDelay() - } - return pto -} - -// UpdateRTT updates the RTT based on a new sample. -func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { - if sendDelta == InfDuration || sendDelta <= 0 { - return - } - - // Update r.minRTT first. r.minRTT does not use an rttSample corrected for - // ackDelay but the raw observed sendDelta, since poor clock granularity at - // the client may cause a high ackDelay to result in underestimation of the - // r.minRTT. - if r.minRTT == 0 || r.minRTT > sendDelta { - r.minRTT = sendDelta - } - - // Correct for ackDelay if information received from the peer results in a - // an RTT sample at least as large as minRTT. Otherwise, only use the - // sendDelta. - sample := sendDelta - if sample-r.minRTT >= ackDelay { - sample -= ackDelay - } - r.latestRTT = sample - // First time call. - if !r.hasMeasurement { - r.hasMeasurement = true - r.smoothedRTT = sample - r.meanDeviation = sample / 2 - } else { - r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond - r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond - } -} - -// SetMaxAckDelay sets the max_ack_delay -func (r *RTTStats) SetMaxAckDelay(mad time.Duration) { - r.maxAckDelay = mad -} - -// SetInitialRTT sets the initial RTT. -// It is used during the 0-RTT handshake when restoring the RTT stats from the session state. -func (r *RTTStats) SetInitialRTT(t time.Duration) { - if r.hasMeasurement { - panic("initial RTT set after first measurement") - } - r.smoothedRTT = t - r.latestRTT = t -} - -// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset. -func (r *RTTStats) OnConnectionMigration() { - r.latestRTT = 0 - r.minRTT = 0 - r.smoothedRTT = 0 - r.meanDeviation = 0 -} - -// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt -// is larger. The mean deviation is increased to the most recent deviation if -// it's larger. -func (r *RTTStats) ExpireSmoothedMetrics() { - r.meanDeviation = MaxDuration(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT)) - r.smoothedRTT = MaxDuration(r.smoothedRTT, r.latestRTT) -} diff --git a/internal/quic-go/utils/rtt_stats_test.go b/internal/quic-go/utils/rtt_stats_test.go deleted file mode 100644 index a0de1b93..00000000 --- a/internal/quic-go/utils/rtt_stats_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package utils - -import ( - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("RTT stats", func() { - var rttStats *RTTStats - - BeforeEach(func() { - rttStats = NewRTTStats() - }) - - It("DefaultsBeforeUpdate", func() { - Expect(rttStats.MinRTT()).To(Equal(time.Duration(0))) - Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0))) - }) - - It("SmoothedRTT", func() { - // Verify that ack_delay is ignored in the first measurement. - rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond))) - Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond))) - // Verify that Smoothed RTT includes max ack delay if it's reasonable. - rttStats.UpdateRTT((350 * time.Millisecond), (50 * time.Millisecond), time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond))) - Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond))) - // Verify that large erroneous ack_delay does not change Smoothed RTT. - rttStats.UpdateRTT((200 * time.Millisecond), (300 * time.Millisecond), time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) - Expect(rttStats.SmoothedRTT()).To(Equal((287500 * time.Microsecond))) - }) - - It("MinRTT", func() { - rttStats.UpdateRTT((200 * time.Millisecond), 0, time.Time{}) - Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) - rttStats.UpdateRTT((10 * time.Millisecond), 0, time.Time{}.Add((10 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) - rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((20 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) - rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((30 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) - rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((40 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond))) - // Verify that ack_delay does not go into recording of MinRTT_. - rttStats.UpdateRTT((7 * time.Millisecond), (2 * time.Millisecond), time.Time{}.Add((50 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((7 * time.Millisecond))) - }) - - It("MaxAckDelay", func() { - rttStats.SetMaxAckDelay(42 * time.Minute) - Expect(rttStats.MaxAckDelay()).To(Equal(42 * time.Minute)) - }) - - It("computes the PTO", func() { - maxAckDelay := 42 * time.Minute - rttStats.SetMaxAckDelay(maxAckDelay) - rtt := time.Second - rttStats.UpdateRTT(rtt, 0, time.Time{}) - Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) - Expect(rttStats.MeanDeviation()).To(Equal(rtt / 2)) - Expect(rttStats.PTO(false)).To(Equal(rtt + 4*(rtt/2))) - Expect(rttStats.PTO(true)).To(Equal(rtt + 4*(rtt/2) + maxAckDelay)) - }) - - It("uses the granularity for computing the PTO for short RTTs", func() { - rtt := time.Microsecond - rttStats.UpdateRTT(rtt, 0, time.Time{}) - Expect(rttStats.PTO(true)).To(Equal(rtt + protocol.TimerGranularity)) - }) - - It("ExpireSmoothedMetrics", func() { - initialRtt := (10 * time.Millisecond) - rttStats.UpdateRTT(initialRtt, 0, time.Time{}) - Expect(rttStats.MinRTT()).To(Equal(initialRtt)) - Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) - - Expect(rttStats.MeanDeviation()).To(Equal(initialRtt / 2)) - - // Update once with a 20ms RTT. - doubledRtt := initialRtt * (2) - rttStats.UpdateRTT(doubledRtt, 0, time.Time{}) - Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(float32(initialRtt) * 1.125))) - - // Expire the smoothed metrics, increasing smoothed rtt and mean deviation. - rttStats.ExpireSmoothedMetrics() - Expect(rttStats.SmoothedRTT()).To(Equal(doubledRtt)) - Expect(rttStats.MeanDeviation()).To(Equal(time.Duration(float32(initialRtt) * 0.875))) - - // Now go back down to 5ms and expire the smoothed metrics, and ensure the - // mean deviation increases to 15ms. - halfRtt := initialRtt / 2 - rttStats.UpdateRTT(halfRtt, 0, time.Time{}) - Expect(doubledRtt).To(BeNumerically(">", rttStats.SmoothedRTT())) - Expect(initialRtt).To(BeNumerically("<", rttStats.MeanDeviation())) - }) - - It("UpdateRTTWithBadSendDeltas", func() { - // Make sure we ignore bad RTTs. - // base::test::MockLog log; - - initialRtt := (10 * time.Millisecond) - rttStats.UpdateRTT(initialRtt, 0, time.Time{}) - Expect(rttStats.MinRTT()).To(Equal(initialRtt)) - Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) - - badSendDeltas := []time.Duration{ - 0, - InfDuration, - -1000 * time.Microsecond, - } - // log.StartCapturingLogs(); - - for _, badSendDelta := range badSendDeltas { - // SCOPED_TRACE(Message() << "bad_send_delta = " - // << bad_send_delta.ToMicroseconds()); - // EXPECT_CALL(log, Log(LOG_WARNING, _, _, _, HasSubstr("Ignoring"))); - rttStats.UpdateRTT(badSendDelta, 0, time.Time{}) - Expect(rttStats.MinRTT()).To(Equal(initialRtt)) - Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt)) - } - }) - - It("ResetAfterConnectionMigrations", func() { - rttStats.UpdateRTT(200*time.Millisecond, 0, time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) - Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) - rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond))) - Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond))) - Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond))) - - // Reset rtt stats on connection migrations. - rttStats.OnConnectionMigration() - Expect(rttStats.LatestRTT()).To(Equal(time.Duration(0))) - Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0))) - Expect(rttStats.MinRTT()).To(Equal(time.Duration(0))) - }) - - It("restores the RTT", func() { - rttStats.SetInitialRTT(10 * time.Second) - Expect(rttStats.LatestRTT()).To(Equal(10 * time.Second)) - Expect(rttStats.SmoothedRTT()).To(Equal(10 * time.Second)) - Expect(rttStats.MeanDeviation()).To(BeZero()) - // update the RTT and make sure that the initial value is immediately forgotten - rttStats.UpdateRTT(200*time.Millisecond, 0, time.Time{}) - Expect(rttStats.LatestRTT()).To(Equal(200 * time.Millisecond)) - Expect(rttStats.SmoothedRTT()).To(Equal(200 * time.Millisecond)) - Expect(rttStats.MeanDeviation()).To(Equal(100 * time.Millisecond)) - }) -}) diff --git a/internal/quic-go/utils/streamframe_interval.go b/internal/quic-go/utils/streamframe_interval.go deleted file mode 100644 index 4efbd64c..00000000 --- a/internal/quic-go/utils/streamframe_interval.go +++ /dev/null @@ -1,9 +0,0 @@ -package utils - -import "github.com/imroc/req/v3/internal/quic-go/protocol" - -// ByteInterval is an interval from one ByteCount to the other -type ByteInterval struct { - Start protocol.ByteCount - End protocol.ByteCount -} diff --git a/internal/quic-go/utils/timer.go b/internal/quic-go/utils/timer.go deleted file mode 100644 index a4f5e67a..00000000 --- a/internal/quic-go/utils/timer.go +++ /dev/null @@ -1,53 +0,0 @@ -package utils - -import ( - "math" - "time" -) - -// A Timer wrapper that behaves correctly when resetting -type Timer struct { - t *time.Timer - read bool - deadline time.Time -} - -// NewTimer creates a new timer that is not set -func NewTimer() *Timer { - return &Timer{t: time.NewTimer(time.Duration(math.MaxInt64))} -} - -// Chan returns the channel of the wrapped timer -func (t *Timer) Chan() <-chan time.Time { - return t.t.C -} - -// Reset the timer, no matter whether the value was read or not -func (t *Timer) Reset(deadline time.Time) { - if deadline.Equal(t.deadline) && !t.read { - // No need to reset the timer - return - } - - // We need to drain the timer if the value from its channel was not read yet. - // See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU - if !t.t.Stop() && !t.read { - <-t.t.C - } - if !deadline.IsZero() { - t.t.Reset(time.Until(deadline)) - } - - t.read = false - t.deadline = deadline -} - -// SetRead should be called after the value from the chan was read -func (t *Timer) SetRead() { - t.read = true -} - -// Stop stops the timer -func (t *Timer) Stop() { - t.t.Stop() -} diff --git a/internal/quic-go/utils/timer_test.go b/internal/quic-go/utils/timer_test.go deleted file mode 100644 index 0cbb4a01..00000000 --- a/internal/quic-go/utils/timer_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package utils - -import ( - "time" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Timer", func() { - const d = 10 * time.Millisecond - - It("doesn't fire a newly created timer", func() { - t := NewTimer() - Consistently(t.Chan()).ShouldNot(Receive()) - }) - - It("works", func() { - t := NewTimer() - t.Reset(time.Now().Add(d)) - Eventually(t.Chan()).Should(Receive()) - }) - - It("works multiple times with reading", func() { - t := NewTimer() - for i := 0; i < 10; i++ { - t.Reset(time.Now().Add(d)) - Eventually(t.Chan()).Should(Receive()) - t.SetRead() - } - }) - - It("works multiple times without reading", func() { - t := NewTimer() - for i := 0; i < 10; i++ { - t.Reset(time.Now().Add(d)) - time.Sleep(d * 2) - } - Eventually(t.Chan()).Should(Receive()) - }) - - It("works when resetting without expiration", func() { - t := NewTimer() - for i := 0; i < 10; i++ { - t.Reset(time.Now().Add(time.Hour)) - } - t.Reset(time.Now().Add(d)) - Eventually(t.Chan()).Should(Receive()) - }) - - It("immediately fires the timer, if the deadlines has already passed", func() { - t := NewTimer() - t.Reset(time.Now().Add(-time.Second)) - Eventually(t.Chan()).Should(Receive()) - }) - - It("doesn't set a timer if the deadline is the zero value", func() { - t := NewTimer() - t.Reset(time.Time{}) - Consistently(t.Chan()).ShouldNot(Receive()) - }) - - It("fires the timer twice, if reset to the same deadline", func() { - deadline := time.Now().Add(-time.Millisecond) - t := NewTimer() - t.Reset(deadline) - Eventually(t.Chan()).Should(Receive()) - t.SetRead() - t.Reset(deadline) - Eventually(t.Chan()).Should(Receive()) - }) - - It("only fires the timer once, if it is reset to the same deadline, but not read in between", func() { - deadline := time.Now().Add(-time.Millisecond) - t := NewTimer() - t.Reset(deadline) - Eventually(t.Chan()).Should(Receive()) - Consistently(t.Chan()).ShouldNot(Receive()) - }) - - It("stops", func() { - t := NewTimer() - t.Reset(time.Now().Add(50 * time.Millisecond)) - t.Stop() - Consistently(t.Chan()).ShouldNot(Receive()) - }) -}) diff --git a/internal/quic-go/utils/utils_suite_test.go b/internal/quic-go/utils/utils_suite_test.go deleted file mode 100644 index 9ecb8c05..00000000 --- a/internal/quic-go/utils/utils_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package utils - -import ( - "testing" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestCrypto(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Utils Suite") -} diff --git a/internal/quic-go/window_update_queue.go b/internal/quic-go/window_update_queue.go deleted file mode 100644 index 67e4ac5f..00000000 --- a/internal/quic-go/window_update_queue.go +++ /dev/null @@ -1,71 +0,0 @@ -package quic - -import ( - "sync" - - "github.com/imroc/req/v3/internal/quic-go/flowcontrol" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" -) - -type windowUpdateQueue struct { - mutex sync.Mutex - - queue map[protocol.StreamID]struct{} // used as a set - queuedConn bool // connection-level window update - - streamGetter streamGetter - connFlowController flowcontrol.ConnectionFlowController - callback func(wire.Frame) -} - -func newWindowUpdateQueue( - streamGetter streamGetter, - connFC flowcontrol.ConnectionFlowController, - cb func(wire.Frame), -) *windowUpdateQueue { - return &windowUpdateQueue{ - queue: make(map[protocol.StreamID]struct{}), - streamGetter: streamGetter, - connFlowController: connFC, - callback: cb, - } -} - -func (q *windowUpdateQueue) AddStream(id protocol.StreamID) { - q.mutex.Lock() - q.queue[id] = struct{}{} - q.mutex.Unlock() -} - -func (q *windowUpdateQueue) AddConnection() { - q.mutex.Lock() - q.queuedConn = true - q.mutex.Unlock() -} - -func (q *windowUpdateQueue) QueueAll() { - q.mutex.Lock() - // queue a connection-level window update - if q.queuedConn { - q.callback(&wire.MaxDataFrame{MaximumData: q.connFlowController.GetWindowUpdate()}) - q.queuedConn = false - } - // queue all stream-level window updates - for id := range q.queue { - delete(q.queue, id) - str, err := q.streamGetter.GetOrOpenReceiveStream(id) - if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update - continue - } - offset := str.getWindowUpdate() - if offset == 0 { // can happen if we received a final offset, right after queueing the window update - continue - } - q.callback(&wire.MaxStreamDataFrame{ - StreamID: id, - MaximumStreamData: offset, - }) - } - q.mutex.Unlock() -} diff --git a/internal/quic-go/window_update_queue_test.go b/internal/quic-go/window_update_queue_test.go deleted file mode 100644 index bacefb23..00000000 --- a/internal/quic-go/window_update_queue_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package quic - -import ( - "github.com/imroc/req/v3/internal/quic-go/mocks" - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Window Update Queue", func() { - var ( - q *windowUpdateQueue - streamGetter *MockStreamGetter - connFC *mocks.MockConnectionFlowController - queuedFrames []wire.Frame - ) - - BeforeEach(func() { - streamGetter = NewMockStreamGetter(mockCtrl) - connFC = mocks.NewMockConnectionFlowController(mockCtrl) - queuedFrames = queuedFrames[:0] - q = newWindowUpdateQueue(streamGetter, connFC, func(f wire.Frame) { - queuedFrames = append(queuedFrames, f) - }) - }) - - It("adds stream offsets and gets MAX_STREAM_DATA frames", func() { - stream1 := NewMockStreamI(mockCtrl) - stream1.EXPECT().getWindowUpdate().Return(protocol.ByteCount(10)) - stream3 := NewMockStreamI(mockCtrl) - stream3.EXPECT().getWindowUpdate().Return(protocol.ByteCount(30)) - streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(stream3, nil) - streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(1)).Return(stream1, nil) - q.AddStream(3) - q.AddStream(1) - q.QueueAll() - Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 1, MaximumStreamData: 10})) - Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 3, MaximumStreamData: 30})) - }) - - It("deletes the entry after getting the MAX_STREAM_DATA frame", func() { - stream10 := NewMockStreamI(mockCtrl) - stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(100)) - streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil) - q.AddStream(10) - q.QueueAll() - Expect(queuedFrames).To(HaveLen(1)) - q.QueueAll() - Expect(queuedFrames).To(HaveLen(1)) - }) - - It("doesn't queue a MAX_STREAM_DATA for a closed stream", func() { - q.AddStream(12) - streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(12)).Return(nil, nil) - q.QueueAll() - Expect(queuedFrames).To(BeEmpty()) - }) - - It("removes closed streams from the queue", func() { - q.AddStream(12) - streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(12)).Return(nil, nil) - q.QueueAll() - Expect(queuedFrames).To(BeEmpty()) - // don't EXPECT any further calls to GetOrOpenReceiveStream - q.QueueAll() - Expect(queuedFrames).To(BeEmpty()) - }) - - It("doesn't queue a MAX_STREAM_DATA if the flow controller returns an offset of 0", func() { - stream5 := NewMockStreamI(mockCtrl) - stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0)) - q.AddStream(5) - streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(stream5, nil) - q.QueueAll() - Expect(queuedFrames).To(BeEmpty()) - }) - - It("removes streams for which the flow controller returns an offset of 0 from the queue", func() { - stream5 := NewMockStreamI(mockCtrl) - stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0)) - q.AddStream(5) - streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(stream5, nil) - q.QueueAll() - Expect(queuedFrames).To(BeEmpty()) - // don't EXPECT any further calls to GetOrOpenReveiveStream and to getWindowUpdate - q.QueueAll() - Expect(queuedFrames).To(BeEmpty()) - }) - - It("queues MAX_DATA frames", func() { - connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1337)) - q.AddConnection() - q.QueueAll() - Expect(queuedFrames).To(Equal([]wire.Frame{ - &wire.MaxDataFrame{MaximumData: 0x1337}, - })) - }) - - It("deduplicates", func() { - stream10 := NewMockStreamI(mockCtrl) - stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(200)) - streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil) - q.AddStream(10) - q.AddStream(10) - q.QueueAll() - Expect(queuedFrames).To(Equal([]wire.Frame{ - &wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 200}, - })) - }) -}) diff --git a/internal/quic-go/wire/ack_frame.go b/internal/quic-go/wire/ack_frame.go deleted file mode 100644 index 68dc6aa8..00000000 --- a/internal/quic-go/wire/ack_frame.go +++ /dev/null @@ -1,251 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "sort" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") - -// An AckFrame is an ACK frame -type AckFrame struct { - AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last - DelayTime time.Duration - - ECT0, ECT1, ECNCE uint64 -} - -// parseAckFrame reads an ACK frame -func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - ecn := typeByte&0x1 > 0 - - frame := &AckFrame{} - - la, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - largestAcked := protocol.PacketNumber(la) - delay, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - - delayTime := time.Duration(delay*1< largestAcked { - return nil, errors.New("invalid first ACK range") - } - smallest := largestAcked - ackBlock - - // read all the other ACK ranges - frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked}) - for i := uint64(0); i < numBlocks; i++ { - g, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - gap := protocol.PacketNumber(g) - if smallest < gap+2 { - return nil, errInvalidAckRanges - } - largest := smallest - gap - 2 - - ab, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - ackBlock := protocol.PacketNumber(ab) - - if ackBlock > largest { - return nil, errInvalidAckRanges - } - smallest = largest - ackBlock - frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) - } - - if !frame.validateAckRanges() { - return nil, errInvalidAckRanges - } - - // parse (and skip) the ECN section - if ecn { - for i := 0; i < 3; i++ { - if _, err := quicvarint.Read(r); err != nil { - return nil, err - } - } - } - - return frame, nil -} - -// Write writes an ACK frame. -func (f *AckFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 - if hasECN { - b.WriteByte(0x3) - } else { - b.WriteByte(0x2) - } - quicvarint.Write(b, uint64(f.LargestAcked())) - quicvarint.Write(b, encodeAckDelay(f.DelayTime)) - - numRanges := f.numEncodableAckRanges() - quicvarint.Write(b, uint64(numRanges-1)) - - // write the first range - _, firstRange := f.encodeAckRange(0) - quicvarint.Write(b, firstRange) - - // write all the other range - for i := 1; i < numRanges; i++ { - gap, len := f.encodeAckRange(i) - quicvarint.Write(b, gap) - quicvarint.Write(b, len) - } - - if hasECN { - quicvarint.Write(b, f.ECT0) - quicvarint.Write(b, f.ECT1) - quicvarint.Write(b, f.ECNCE) - } - return nil -} - -// Length of a written frame -func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount { - largestAcked := f.AckRanges[0].Largest - numRanges := f.numEncodableAckRanges() - - length := 1 + quicvarint.Len(uint64(largestAcked)) + quicvarint.Len(encodeAckDelay(f.DelayTime)) - - length += quicvarint.Len(uint64(numRanges - 1)) - lowestInFirstRange := f.AckRanges[0].Smallest - length += quicvarint.Len(uint64(largestAcked - lowestInFirstRange)) - - for i := 1; i < numRanges; i++ { - gap, len := f.encodeAckRange(i) - length += quicvarint.Len(gap) - length += quicvarint.Len(len) - } - if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 { - length += quicvarint.Len(f.ECT0) - length += quicvarint.Len(f.ECT1) - length += quicvarint.Len(f.ECNCE) - } - return length -} - -// gets the number of ACK ranges that can be encoded -// such that the resulting frame is smaller than the maximum ACK frame size -func (f *AckFrame) numEncodableAckRanges() int { - length := 1 + quicvarint.Len(uint64(f.LargestAcked())) + quicvarint.Len(encodeAckDelay(f.DelayTime)) - length += 2 // assume that the number of ranges will consume 2 bytes - for i := 1; i < len(f.AckRanges); i++ { - gap, len := f.encodeAckRange(i) - rangeLen := quicvarint.Len(gap) + quicvarint.Len(len) - if length+rangeLen > protocol.MaxAckFrameSize { - // Writing range i would exceed the MaxAckFrameSize. - // So encode one range less than that. - return i - 1 - } - length += rangeLen - } - return len(f.AckRanges) -} - -func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) { - if i == 0 { - return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest) - } - return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2), - uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest) -} - -// HasMissingRanges returns if this frame reports any missing packets -func (f *AckFrame) HasMissingRanges() bool { - return len(f.AckRanges) > 1 -} - -func (f *AckFrame) validateAckRanges() bool { - if len(f.AckRanges) == 0 { - return false - } - - // check the validity of every single ACK range - for _, ackRange := range f.AckRanges { - if ackRange.Smallest > ackRange.Largest { - return false - } - } - - // check the consistency for ACK with multiple NACK ranges - for i, ackRange := range f.AckRanges { - if i == 0 { - continue - } - lastAckRange := f.AckRanges[i-1] - if lastAckRange.Smallest <= ackRange.Smallest { - return false - } - if lastAckRange.Smallest <= ackRange.Largest+1 { - return false - } - } - - return true -} - -// LargestAcked is the largest acked packet number -func (f *AckFrame) LargestAcked() protocol.PacketNumber { - return f.AckRanges[0].Largest -} - -// LowestAcked is the lowest acked packet number -func (f *AckFrame) LowestAcked() protocol.PacketNumber { - return f.AckRanges[len(f.AckRanges)-1].Smallest -} - -// AcksPacket determines if this ACK frame acks a certain packet number -func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { - if p < f.LowestAcked() || p > f.LargestAcked() { - return false - } - - i := sort.Search(len(f.AckRanges), func(i int) bool { - return p >= f.AckRanges[i].Smallest - }) - // i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked - return p <= f.AckRanges[i].Largest -} - -func encodeAckDelay(delay time.Duration) uint64 { - return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent))) -} diff --git a/internal/quic-go/wire/ack_frame_test.go b/internal/quic-go/wire/ack_frame_test.go deleted file mode 100644 index c57d99ff..00000000 --- a/internal/quic-go/wire/ack_frame_test.go +++ /dev/null @@ -1,454 +0,0 @@ -package wire - -import ( - "bytes" - "io" - "math" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("ACK Frame (for IETF QUIC)", func() { - Context("parsing", func() { - It("parses an ACK frame without any ranges", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(100)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(10)...) // first ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) - }) - - It("parses an ACK frame that only acks a single packet", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(55)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(0)...) // first ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(55))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(55))) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) - }) - - It("accepts an ACK frame that acks all packets from 0 to largest", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(20)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(20)...) // first ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(20))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(0))) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) - }) - - It("rejects an ACK frame that has a first ACK block which is larger than LargestAcked", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(20)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(21)...) // first ack block - b := bytes.NewReader(data) - _, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).To(MatchError("invalid first ACK range")) - }) - - It("parses an ACK frame that has a single block", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(1000)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(1)...) // num blocks - data = append(data, encodeVarInt(100)...) // first ack block - data = append(data, encodeVarInt(98)...) // gap - data = append(data, encodeVarInt(50)...) // ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(1000))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(750))) - Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(frame.AckRanges).To(Equal([]AckRange{ - {Largest: 1000, Smallest: 900}, - {Largest: 800, Smallest: 750}, - })) - Expect(b.Len()).To(BeZero()) - }) - - It("parses an ACK frame that has a multiple blocks", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(100)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(2)...) // num blocks - data = append(data, encodeVarInt(0)...) // first ack block - data = append(data, encodeVarInt(0)...) // gap - data = append(data, encodeVarInt(0)...) // ack block - data = append(data, encodeVarInt(1)...) // gap - data = append(data, encodeVarInt(1)...) // ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(94))) - Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(frame.AckRanges).To(Equal([]AckRange{ - {Largest: 100, Smallest: 100}, - {Largest: 98, Smallest: 98}, - {Largest: 95, Smallest: 94}, - })) - Expect(b.Len()).To(BeZero()) - }) - - It("uses the ack delay exponent", func() { - const delayTime = 1 << 10 * time.Millisecond - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, - DelayTime: delayTime, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - for i := uint8(0); i < 8; i++ { - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent+i, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.DelayTime).To(Equal(delayTime * (1 << i))) - } - }) - - It("gracefully handles overflows of the delay time", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(100)...) // largest acked - data = append(data, encodeVarInt(math.MaxUint64/5)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(0)...) // first ack block - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.DelayTime).To(BeNumerically(">", 0)) - // The maximum encodable duration is ~292 years. - Expect(frame.DelayTime.Hours()).To(BeNumerically("~", 292*365*24, 365*24)) - }) - - It("errors on EOF", func() { - data := []byte{0x2} - data = append(data, encodeVarInt(1000)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(1)...) // num blocks - data = append(data, encodeVarInt(100)...) // first ack block - data = append(data, encodeVarInt(98)...) // gap - data = append(data, encodeVarInt(50)...) // ack block - _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - - Context("ACK_ECN", func() { - It("parses", func() { - data := []byte{0x3} - data = append(data, encodeVarInt(100)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(0)...) // num blocks - data = append(data, encodeVarInt(10)...) // first ack block - data = append(data, encodeVarInt(0x42)...) // ECT(0) - data = append(data, encodeVarInt(0x12345)...) // ECT(1) - data = append(data, encodeVarInt(0x12345678)...) // ECN-CE - b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) - Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOF", func() { - data := []byte{0x3} - data = append(data, encodeVarInt(1000)...) // largest acked - data = append(data, encodeVarInt(0)...) // delay - data = append(data, encodeVarInt(1)...) // num blocks - data = append(data, encodeVarInt(100)...) // first ack block - data = append(data, encodeVarInt(98)...) // gap - data = append(data, encodeVarInt(50)...) // ack block - data = append(data, encodeVarInt(0x42)...) // ECT(0) - data = append(data, encodeVarInt(0x12345)...) // ECT(1) - data = append(data, encodeVarInt(0x12345678)...) // ECN-CE - _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - }) - - Context("when writing", func() { - It("writes a simple frame", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 100, Largest: 1337}}, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - expected := []byte{0x2} - expected = append(expected, encodeVarInt(1337)...) // largest acked - expected = append(expected, 0) // delay - expected = append(expected, encodeVarInt(0)...) // num ranges - expected = append(expected, encodeVarInt(1337-100)...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("writes an ACK-ECN frame", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 10, Largest: 2000}}, - ECT0: 13, - ECT1: 37, - ECNCE: 12345, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - expected := []byte{0x3} - expected = append(expected, encodeVarInt(2000)...) // largest acked - expected = append(expected, 0) // delay - expected = append(expected, encodeVarInt(0)...) // num ranges - expected = append(expected, encodeVarInt(2000-10)...) - expected = append(expected, encodeVarInt(13)...) - expected = append(expected, encodeVarInt(37)...) - expected = append(expected, encodeVarInt(12345)...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("writes a frame that acks a single packet", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 0x2eadbeef, Largest: 0x2eadbeef}}, - DelayTime: 18 * time.Millisecond, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(frame.DelayTime).To(Equal(f.DelayTime)) - Expect(b.Len()).To(BeZero()) - }) - - It("writes a frame that acks many packets", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 0x1337, Largest: 0x2eadbeef}}, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) - }) - - It("writes a frame with a a single gap", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{ - {Smallest: 400, Largest: 1000}, - {Smallest: 100, Largest: 200}, - }, - } - Expect(f.validateAckRanges()).To(BeTrue()) - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(b.Len()).To(BeZero()) - }) - - It("writes a frame with multiple ranges", func() { - buf := &bytes.Buffer{} - f := &AckFrame{ - AckRanges: []AckRange{ - {Smallest: 10, Largest: 10}, - {Smallest: 8, Largest: 8}, - {Smallest: 5, Largest: 6}, - {Smallest: 1, Largest: 3}, - }, - } - Expect(f.validateAckRanges()).To(BeTrue()) - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(b.Len()).To(BeZero()) - }) - - It("limits the maximum size of the ACK frame", func() { - buf := &bytes.Buffer{} - const numRanges = 1000 - ackRanges := make([]AckRange, numRanges) - for i := protocol.PacketNumber(1); i <= numRanges; i++ { - ackRanges[numRanges-i] = AckRange{Smallest: 2 * i, Largest: 2 * i} - } - f := &AckFrame{AckRanges: ackRanges} - Expect(f.validateAckRanges()).To(BeTrue()) - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - // make sure the ACK frame is *a little bit* smaller than the MaxAckFrameSize - Expect(buf.Len()).To(BeNumerically(">", protocol.MaxAckFrameSize-5)) - Expect(buf.Len()).To(BeNumerically("<=", protocol.MaxAckFrameSize)) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(b.Len()).To(BeZero()) - Expect(len(frame.AckRanges)).To(BeNumerically("<", numRanges)) // make sure we dropped some ranges - }) - }) - - Context("ACK range validator", func() { - It("rejects ACKs without ranges", func() { - Expect((&AckFrame{}).validateAckRanges()).To(BeFalse()) - }) - - It("accepts an ACK without NACK Ranges", func() { - ack := AckFrame{ - AckRanges: []AckRange{{Smallest: 1, Largest: 7}}, - } - Expect(ack.validateAckRanges()).To(BeTrue()) - }) - - It("rejects ACK ranges with Smallest greater than Largest", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 8, Largest: 10}, - {Smallest: 4, Largest: 3}, - }, - } - Expect(ack.validateAckRanges()).To(BeFalse()) - }) - - It("rejects ACK ranges in the wrong order", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 2, Largest: 2}, - {Smallest: 6, Largest: 7}, - }, - } - Expect(ack.validateAckRanges()).To(BeFalse()) - }) - - It("rejects with overlapping ACK ranges", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 5, Largest: 7}, - {Smallest: 2, Largest: 5}, - }, - } - Expect(ack.validateAckRanges()).To(BeFalse()) - }) - - It("rejects ACK ranges that are part of a larger ACK range", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 4, Largest: 7}, - {Smallest: 5, Largest: 6}, - }, - } - Expect(ack.validateAckRanges()).To(BeFalse()) - }) - - It("rejects with directly adjacent ACK ranges", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 5, Largest: 7}, - {Smallest: 2, Largest: 4}, - }, - } - Expect(ack.validateAckRanges()).To(BeFalse()) - }) - - It("accepts an ACK with one lost packet", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 5, Largest: 10}, - {Smallest: 1, Largest: 3}, - }, - } - Expect(ack.validateAckRanges()).To(BeTrue()) - }) - - It("accepts an ACK with multiple lost packets", func() { - ack := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 15, Largest: 20}, - {Smallest: 10, Largest: 12}, - {Smallest: 1, Largest: 3}, - }, - } - Expect(ack.validateAckRanges()).To(BeTrue()) - }) - }) - - Context("check if ACK frame acks a certain packet", func() { - It("works with an ACK without any ranges", func() { - f := AckFrame{ - AckRanges: []AckRange{{Smallest: 5, Largest: 10}}, - } - Expect(f.AcksPacket(1)).To(BeFalse()) - Expect(f.AcksPacket(4)).To(BeFalse()) - Expect(f.AcksPacket(5)).To(BeTrue()) - Expect(f.AcksPacket(8)).To(BeTrue()) - Expect(f.AcksPacket(10)).To(BeTrue()) - Expect(f.AcksPacket(11)).To(BeFalse()) - Expect(f.AcksPacket(20)).To(BeFalse()) - }) - - It("works with an ACK with multiple ACK ranges", func() { - f := AckFrame{ - AckRanges: []AckRange{ - {Smallest: 15, Largest: 20}, - {Smallest: 5, Largest: 8}, - }, - } - Expect(f.AcksPacket(4)).To(BeFalse()) - Expect(f.AcksPacket(5)).To(BeTrue()) - Expect(f.AcksPacket(6)).To(BeTrue()) - Expect(f.AcksPacket(7)).To(BeTrue()) - Expect(f.AcksPacket(8)).To(BeTrue()) - Expect(f.AcksPacket(9)).To(BeFalse()) - Expect(f.AcksPacket(14)).To(BeFalse()) - Expect(f.AcksPacket(15)).To(BeTrue()) - Expect(f.AcksPacket(18)).To(BeTrue()) - Expect(f.AcksPacket(19)).To(BeTrue()) - Expect(f.AcksPacket(20)).To(BeTrue()) - Expect(f.AcksPacket(21)).To(BeFalse()) - }) - }) -}) diff --git a/internal/quic-go/wire/ack_range.go b/internal/quic-go/wire/ack_range.go deleted file mode 100644 index e373835c..00000000 --- a/internal/quic-go/wire/ack_range.go +++ /dev/null @@ -1,14 +0,0 @@ -package wire - -import "github.com/imroc/req/v3/internal/quic-go/protocol" - -// AckRange is an ACK range -type AckRange struct { - Smallest protocol.PacketNumber - Largest protocol.PacketNumber -} - -// Len returns the number of packets contained in this ACK range -func (r AckRange) Len() protocol.PacketNumber { - return r.Largest - r.Smallest + 1 -} diff --git a/internal/quic-go/wire/ack_range_test.go b/internal/quic-go/wire/ack_range_test.go deleted file mode 100644 index 84ef71b5..00000000 --- a/internal/quic-go/wire/ack_range_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package wire - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("ACK range", func() { - It("returns the length", func() { - Expect(AckRange{Smallest: 10, Largest: 10}.Len()).To(BeEquivalentTo(1)) - Expect(AckRange{Smallest: 10, Largest: 13}.Len()).To(BeEquivalentTo(4)) - }) -}) diff --git a/internal/quic-go/wire/connection_close_frame.go b/internal/quic-go/wire/connection_close_frame.go deleted file mode 100644 index 1fe48837..00000000 --- a/internal/quic-go/wire/connection_close_frame.go +++ /dev/null @@ -1,83 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A ConnectionCloseFrame is a CONNECTION_CLOSE frame -type ConnectionCloseFrame struct { - IsApplicationError bool - ErrorCode uint64 - FrameType uint64 - ReasonPhrase string -} - -func parseConnectionCloseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - f := &ConnectionCloseFrame{IsApplicationError: typeByte == 0x1d} - ec, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - f.ErrorCode = ec - // read the Frame Type, if this is not an application error - if !f.IsApplicationError { - ft, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - f.FrameType = ft - } - var reasonPhraseLen uint64 - reasonPhraseLen, err = quicvarint.Read(r) - if err != nil { - return nil, err - } - // shortcut to prevent the unnecessary allocation of dataLen bytes - // if the dataLen is larger than the remaining length of the packet - // reading the whole reason phrase would result in EOF when attempting to READ - if int(reasonPhraseLen) > r.Len() { - return nil, io.EOF - } - - reasonPhrase := make([]byte, reasonPhraseLen) - if _, err := io.ReadFull(r, reasonPhrase); err != nil { - // this should never happen, since we already checked the reasonPhraseLen earlier - return nil, err - } - f.ReasonPhrase = string(reasonPhrase) - return f, nil -} - -// Length of a written frame -func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount { - length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) - if !f.IsApplicationError { - length += quicvarint.Len(f.FrameType) // for the frame type - } - return length -} - -func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - if f.IsApplicationError { - b.WriteByte(0x1d) - } else { - b.WriteByte(0x1c) - } - - quicvarint.Write(b, f.ErrorCode) - if !f.IsApplicationError { - quicvarint.Write(b, f.FrameType) - } - quicvarint.Write(b, uint64(len(f.ReasonPhrase))) - b.WriteString(f.ReasonPhrase) - return nil -} diff --git a/internal/quic-go/wire/connection_close_frame_test.go b/internal/quic-go/wire/connection_close_frame_test.go deleted file mode 100644 index c507fc8a..00000000 --- a/internal/quic-go/wire/connection_close_frame_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("CONNECTION_CLOSE Frame", func() { - Context("when parsing", func() { - It("accepts sample frame containing a QUIC error code", func() { - reason := "No recent network activity." - data := []byte{0x1c} - data = append(data, encodeVarInt(0x19)...) - data = append(data, encodeVarInt(0x1337)...) // frame type - data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length - data = append(data, []byte(reason)...) - b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.IsApplicationError).To(BeFalse()) - Expect(frame.ErrorCode).To(BeEquivalentTo(0x19)) - Expect(frame.FrameType).To(BeEquivalentTo(0x1337)) - Expect(frame.ReasonPhrase).To(Equal(reason)) - Expect(b.Len()).To(BeZero()) - }) - - It("accepts sample frame containing an application error code", func() { - reason := "The application messed things up." - data := []byte{0x1d} - data = append(data, encodeVarInt(0xcafe)...) - data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length - data = append(data, reason...) - b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.IsApplicationError).To(BeTrue()) - Expect(frame.ErrorCode).To(BeEquivalentTo(0xcafe)) - Expect(frame.ReasonPhrase).To(Equal(reason)) - Expect(b.Len()).To(BeZero()) - }) - - It("rejects long reason phrases", func() { - data := []byte{0x1c} - data = append(data, encodeVarInt(0xcafe)...) - data = append(data, encodeVarInt(0x42)...) // frame type - data = append(data, encodeVarInt(0xffff)...) // reason phrase length - b := bytes.NewReader(data) - _, err := parseConnectionCloseFrame(b, protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - }) - - It("errors on EOFs", func() { - reason := "No recent network activity." - data := []byte{0x1c} - data = append(data, encodeVarInt(0x19)...) - data = append(data, encodeVarInt(0x1337)...) // frame type - data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length - data = append(data, []byte(reason)...) - _, err := parseConnectionCloseFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseConnectionCloseFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - - It("parses a frame without a reason phrase", func() { - data := []byte{0x1c} - data = append(data, encodeVarInt(0xcafe)...) - data = append(data, encodeVarInt(0x42)...) // frame type - data = append(data, encodeVarInt(0)...) - b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.ReasonPhrase).To(BeEmpty()) - Expect(b.Len()).To(BeZero()) - }) - }) - - Context("when writing", func() { - It("writes a frame without a reason phrase", func() { - b := &bytes.Buffer{} - frame := &ConnectionCloseFrame{ - ErrorCode: 0xbeef, - FrameType: 0x12345, - } - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - expected := []byte{0x1c} - expected = append(expected, encodeVarInt(0xbeef)...) - expected = append(expected, encodeVarInt(0x12345)...) // frame type - expected = append(expected, encodeVarInt(0)...) // reason phrase length - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with a reason phrase", func() { - b := &bytes.Buffer{} - frame := &ConnectionCloseFrame{ - ErrorCode: 0xdead, - ReasonPhrase: "foobar", - } - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - expected := []byte{0x1c} - expected = append(expected, encodeVarInt(0xdead)...) - expected = append(expected, encodeVarInt(0)...) // frame type - expected = append(expected, encodeVarInt(6)...) // reason phrase length - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with an application error code", func() { - b := &bytes.Buffer{} - frame := &ConnectionCloseFrame{ - IsApplicationError: true, - ErrorCode: 0xdead, - ReasonPhrase: "foobar", - } - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - expected := []byte{0x1d} - expected = append(expected, encodeVarInt(0xdead)...) - expected = append(expected, encodeVarInt(6)...) // reason phrase length - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has proper min length, for a frame containing a QUIC error code", func() { - b := &bytes.Buffer{} - f := &ConnectionCloseFrame{ - ErrorCode: 0xcafe, - FrameType: 0xdeadbeef, - ReasonPhrase: "foobar", - } - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) - }) - - It("has proper min length, for a frame containing an application error code", func() { - b := &bytes.Buffer{} - f := &ConnectionCloseFrame{ - IsApplicationError: true, - ErrorCode: 0xcafe, - ReasonPhrase: "foobar", - } - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) - }) - }) -}) diff --git a/internal/quic-go/wire/crypto_frame.go b/internal/quic-go/wire/crypto_frame.go deleted file mode 100644 index 3e7e1808..00000000 --- a/internal/quic-go/wire/crypto_frame.go +++ /dev/null @@ -1,102 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A CryptoFrame is a CRYPTO frame -type CryptoFrame struct { - Offset protocol.ByteCount - Data []byte -} - -func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - frame := &CryptoFrame{} - offset, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - frame.Offset = protocol.ByteCount(offset) - dataLen, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - if dataLen > uint64(r.Len()) { - return nil, io.EOF - } - if dataLen != 0 { - frame.Data = make([]byte, dataLen) - if _, err := io.ReadFull(r, frame.Data); err != nil { - // this should never happen, since we already checked the dataLen earlier - return nil, err - } - } - return frame, nil -} - -func (f *CryptoFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x6) - quicvarint.Write(b, uint64(f.Offset)) - quicvarint.Write(b, uint64(len(f.Data))) - b.Write(f.Data) - return nil -} - -// Length of a written frame -func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data)) -} - -// MaxDataLen returns the maximum data length -func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount { - // pretend that the data size will be 1 bytes - // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards - headerLen := 1 + quicvarint.Len(uint64(f.Offset)) + 1 - if headerLen > maxSize { - return 0 - } - maxDataLen := maxSize - headerLen - if quicvarint.Len(uint64(maxDataLen)) != 1 { - maxDataLen-- - } - return maxDataLen -} - -// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. -// It returns if the frame was actually split. -// The frame might not be split if: -// * the size is large enough to fit the whole frame -// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. -func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*CryptoFrame, bool /* was splitting required */) { - if f.Length(version) <= maxSize { - return nil, false - } - - n := f.MaxDataLen(maxSize) - if n == 0 { - return nil, true - } - - newLen := protocol.ByteCount(len(f.Data)) - n - - new := &CryptoFrame{} - new.Offset = f.Offset - new.Data = make([]byte, newLen) - - // swap the data slices - new.Data, f.Data = f.Data, new.Data - - copy(f.Data, new.Data[n:]) - new.Data = new.Data[:n] - f.Offset += n - - return new, true -} diff --git a/internal/quic-go/wire/crypto_frame_test.go b/internal/quic-go/wire/crypto_frame_test.go deleted file mode 100644 index a52229e1..00000000 --- a/internal/quic-go/wire/crypto_frame_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("CRYPTO frame", func() { - Context("when parsing", func() { - It("parses", func() { - data := []byte{0x6} - data = append(data, encodeVarInt(0xdecafbad)...) // offset - data = append(data, encodeVarInt(6)...) // length - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseCryptoFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) - Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(r.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x6} - data = append(data, encodeVarInt(0xdecafbad)...) // offset - data = append(data, encodeVarInt(6)...) // data length - data = append(data, []byte("foobar")...) - _, err := parseCryptoFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseCryptoFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("when writing", func() { - It("writes a frame", func() { - f := &CryptoFrame{ - Offset: 0x123456, - Data: []byte("foobar"), - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x6} - expected = append(expected, encodeVarInt(0x123456)...) // offset - expected = append(expected, encodeVarInt(6)...) // length - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - }) - - Context("max data length", func() { - const maxSize = 3000 - - It("always returns a data length such that the resulting frame has the right size", func() { - data := make([]byte, maxSize) - f := &CryptoFrame{ - Offset: 0xdeadbeef, - } - b := &bytes.Buffer{} - var frameOneByteTooSmallCounter int - for i := 1; i < maxSize; i++ { - b.Reset() - f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i)) - if maxDataLen == 0 { // 0 means that no valid CRYTPO frame can be written - // check that writing a minimal size CRYPTO frame (i.e. with 1 byte data) is actually larger than the desired size - f.Data = []byte{0} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeNumerically(">", i)) - continue - } - f.Data = data[:int(maxDataLen)] - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - // There's *one* pathological case, where a data length of x can be encoded into 1 byte - // but a data lengths of x+1 needs 2 bytes - // In that case, it's impossible to create a STREAM frame of the desired size - if b.Len() == i-1 { - frameOneByteTooSmallCounter++ - continue - } - Expect(b.Len()).To(Equal(i)) - } - Expect(frameOneByteTooSmallCounter).To(Equal(1)) - }) - }) - - Context("length", func() { - It("has the right length for a frame without offset and data length", func() { - f := &CryptoFrame{ - Offset: 0x1337, - Data: []byte("foobar"), - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(6) + 6)) - }) - }) - - Context("splitting", func() { - It("splits a frame", func() { - f := &CryptoFrame{ - Offset: 0x1337, - Data: []byte("foobar"), - } - hdrLen := f.Length(protocol.Version1) - 6 - new, needsSplit := f.MaybeSplitOffFrame(hdrLen+3, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(new.Data).To(Equal([]byte("foo"))) - Expect(new.Offset).To(Equal(protocol.ByteCount(0x1337))) - Expect(f.Data).To(Equal([]byte("bar"))) - Expect(f.Offset).To(Equal(protocol.ByteCount(0x1337 + 3))) - }) - - It("doesn't split if there's enough space in the frame", func() { - f := &CryptoFrame{ - Offset: 0x1337, - Data: []byte("foobar"), - } - f, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) - Expect(needsSplit).To(BeFalse()) - Expect(f).To(BeNil()) - }) - - It("doesn't split if the size is too small", func() { - f := &CryptoFrame{ - Offset: 0x1337, - Data: []byte("foobar"), - } - length := f.Length(protocol.Version1) - 6 - for i := protocol.ByteCount(0); i <= length; i++ { - f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(f).To(BeNil()) - } - f, needsSplit := f.MaybeSplitOffFrame(length+1, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(f).ToNot(BeNil()) - }) - }) -}) diff --git a/internal/quic-go/wire/data_blocked_frame.go b/internal/quic-go/wire/data_blocked_frame.go deleted file mode 100644 index ddce5a02..00000000 --- a/internal/quic-go/wire/data_blocked_frame.go +++ /dev/null @@ -1,38 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A DataBlockedFrame is a DATA_BLOCKED frame -type DataBlockedFrame struct { - MaximumData protocol.ByteCount -} - -func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - offset, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - return &DataBlockedFrame{ - MaximumData: protocol.ByteCount(offset), - }, nil -} - -func (f *DataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - typeByte := uint8(0x14) - b.WriteByte(typeByte) - quicvarint.Write(b, uint64(f.MaximumData)) - return nil -} - -// Length of a written frame -func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.MaximumData)) -} diff --git a/internal/quic-go/wire/data_blocked_frame_test.go b/internal/quic-go/wire/data_blocked_frame_test.go deleted file mode 100644 index 8f19310b..00000000 --- a/internal/quic-go/wire/data_blocked_frame_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("DATA_BLOCKED frame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - data := []byte{0x14} - data = append(data, encodeVarInt(0x12345678)...) - b := bytes.NewReader(data) - frame, err := parseDataBlockedFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0x12345678))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x14} - data = append(data, encodeVarInt(0x12345678)...) - _, err := parseDataBlockedFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - for i := range data { - _, err := parseDataBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - b := &bytes.Buffer{} - frame := DataBlockedFrame{MaximumData: 0xdeadbeef} - err := frame.Write(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x14} - expected = append(expected, encodeVarInt(0xdeadbeef)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - frame := DataBlockedFrame{MaximumData: 0x12345} - Expect(frame.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x12345))) - }) - }) -}) diff --git a/internal/quic-go/wire/datagram_frame.go b/internal/quic-go/wire/datagram_frame.go deleted file mode 100644 index 1b3aeb96..00000000 --- a/internal/quic-go/wire/datagram_frame.go +++ /dev/null @@ -1,85 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A DatagramFrame is a DATAGRAM frame -type DatagramFrame struct { - DataLenPresent bool - Data []byte -} - -func parseDatagramFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DatagramFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - f := &DatagramFrame{} - f.DataLenPresent = typeByte&0x1 > 0 - - var length uint64 - if f.DataLenPresent { - var err error - len, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - if len > uint64(r.Len()) { - return nil, io.EOF - } - length = len - } else { - length = uint64(r.Len()) - } - f.Data = make([]byte, length) - if _, err := io.ReadFull(r, f.Data); err != nil { - return nil, err - } - return f, nil -} - -func (f *DatagramFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - typeByte := uint8(0x30) - if f.DataLenPresent { - typeByte ^= 0x1 - } - b.WriteByte(typeByte) - if f.DataLenPresent { - quicvarint.Write(b, uint64(len(f.Data))) - } - b.Write(f.Data) - return nil -} - -// MaxDataLen returns the maximum data length -func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { - headerLen := protocol.ByteCount(1) - if f.DataLenPresent { - // pretend that the data size will be 1 bytes - // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards - headerLen++ - } - if headerLen > maxSize { - return 0 - } - maxDataLen := maxSize - headerLen - if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { - maxDataLen-- - } - return maxDataLen -} - -// Length of a written frame -func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { - length := 1 + protocol.ByteCount(len(f.Data)) - if f.DataLenPresent { - length += quicvarint.Len(uint64(len(f.Data))) - } - return length -} diff --git a/internal/quic-go/wire/datagram_frame_test.go b/internal/quic-go/wire/datagram_frame_test.go deleted file mode 100644 index 6fc581ce..00000000 --- a/internal/quic-go/wire/datagram_frame_test.go +++ /dev/null @@ -1,154 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("STREAM frame", func() { - Context("when parsing", func() { - It("parses a frame containing a length", func() { - data := []byte{0x30 ^ 0x1} - data = append(data, encodeVarInt(0x6)...) // length - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseDatagramFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(frame.DataLenPresent).To(BeTrue()) - Expect(r.Len()).To(BeZero()) - }) - - It("parses a frame without length", func() { - data := []byte{0x30} - data = append(data, []byte("Lorem ipsum dolor sit amet")...) - r := bytes.NewReader(data) - frame, err := parseDatagramFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.Data).To(Equal([]byte("Lorem ipsum dolor sit amet"))) - Expect(frame.DataLenPresent).To(BeFalse()) - Expect(r.Len()).To(BeZero()) - }) - - It("errors when the length is longer than the rest of the frame", func() { - data := []byte{0x30 ^ 0x1} - data = append(data, encodeVarInt(0x6)...) // length - data = append(data, []byte("fooba")...) - r := bytes.NewReader(data) - _, err := parseDatagramFrame(r, protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - }) - - It("errors on EOFs", func() { - data := []byte{0x30 ^ 0x1} - data = append(data, encodeVarInt(6)...) // length - data = append(data, []byte("foobar")...) - _, err := parseDatagramFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseDatagramFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a frame with length", func() { - f := &DatagramFrame{ - DataLenPresent: true, - Data: []byte("foobar"), - } - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - expected := []byte{0x30 ^ 0x1} - expected = append(expected, encodeVarInt(0x6)...) - expected = append(expected, []byte("foobar")...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("writes a frame without length", func() { - f := &DatagramFrame{Data: []byte("Lorem ipsum")} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - expected := []byte{0x30} - expected = append(expected, []byte("Lorem ipsum")...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - }) - - Context("length", func() { - It("has the right length for a frame with length", func() { - f := &DatagramFrame{ - DataLenPresent: true, - Data: []byte("foobar"), - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(6) + 6)) - }) - - It("has the right length for a frame without length", func() { - f := &DatagramFrame{Data: []byte("foobar")} - Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(1 + 6))) - }) - }) - - Context("max data length", func() { - const maxSize = 3000 - - It("returns a data length such that the resulting frame has the right size, if data length is not present", func() { - data := make([]byte, maxSize) - f := &DatagramFrame{} - b := &bytes.Buffer{} - for i := 1; i < 3000; i++ { - b.Reset() - f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) - if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written - // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size - f.Data = []byte{0} - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(b.Len()).To(BeNumerically(">", i)) - continue - } - f.Data = data[:int(maxDataLen)] - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(b.Len()).To(Equal(i)) - } - }) - - It("always returns a data length such that the resulting frame has the right size, if data length is present", func() { - data := make([]byte, maxSize) - f := &DatagramFrame{DataLenPresent: true} - b := &bytes.Buffer{} - var frameOneByteTooSmallCounter int - for i := 1; i < 3000; i++ { - b.Reset() - f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) - if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written - // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size - f.Data = []byte{0} - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(b.Len()).To(BeNumerically(">", i)) - continue - } - f.Data = data[:int(maxDataLen)] - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - // There's *one* pathological case, where a data length of x can be encoded into 1 byte - // but a data lengths of x+1 needs 2 bytes - // In that case, it's impossible to create a STREAM frame of the desired size - if b.Len() == i-1 { - frameOneByteTooSmallCounter++ - continue - } - Expect(b.Len()).To(Equal(i)) - } - Expect(frameOneByteTooSmallCounter).To(Equal(1)) - }) - }) -}) diff --git a/internal/quic-go/wire/extended_header.go b/internal/quic-go/wire/extended_header.go deleted file mode 100644 index 766ccbc1..00000000 --- a/internal/quic-go/wire/extended_header.go +++ /dev/null @@ -1,249 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "fmt" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -// ErrInvalidReservedBits is returned when the reserved bits are incorrect. -// When this error is returned, parsing continues, and an ExtendedHeader is returned. -// This is necessary because we need to decrypt the packet in that case, -// in order to avoid a timing side-channel. -var ErrInvalidReservedBits = errors.New("invalid reserved bits") - -// ExtendedHeader is the header of a QUIC packet. -type ExtendedHeader struct { - Header - - typeByte byte - - KeyPhase protocol.KeyPhaseBit - - PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber - - parsedLen protocol.ByteCount -} - -func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) { - startLen := b.Len() - // read the (now unencrypted) first byte - var err error - h.typeByte, err = b.ReadByte() - if err != nil { - return false, err - } - if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil { - return false, err - } - var reservedBitsValid bool - if h.IsLongHeader { - reservedBitsValid, err = h.parseLongHeader(b, v) - } else { - reservedBitsValid, err = h.parseShortHeader(b, v) - } - if err != nil { - return false, err - } - h.parsedLen = protocol.ByteCount(startLen - b.Len()) - return reservedBitsValid, err -} - -func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { - if err := h.readPacketNumber(b); err != nil { - return false, err - } - if h.typeByte&0xc != 0 { - return false, nil - } - return true, nil -} - -func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { - h.KeyPhase = protocol.KeyPhaseZero - if h.typeByte&0x4 > 0 { - h.KeyPhase = protocol.KeyPhaseOne - } - - if err := h.readPacketNumber(b); err != nil { - return false, err - } - if h.typeByte&0x18 != 0 { - return false, nil - } - return true, nil -} - -func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { - h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - n, err := b.ReadByte() - if err != nil { - return err - } - h.PacketNumber = protocol.PacketNumber(n) - case protocol.PacketNumberLen2: - n, err := utils.BigEndian.ReadUint16(b) - if err != nil { - return err - } - h.PacketNumber = protocol.PacketNumber(n) - case protocol.PacketNumberLen3: - n, err := utils.BigEndian.ReadUint24(b) - if err != nil { - return err - } - h.PacketNumber = protocol.PacketNumber(n) - case protocol.PacketNumberLen4: - n, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return err - } - h.PacketNumber = protocol.PacketNumber(n) - default: - return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) - } - return nil -} - -// Write writes the Header. -func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error { - if h.DestConnectionID.Len() > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) - } - if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) - } - if h.IsLongHeader { - return h.writeLongHeader(b, ver) - } - return h.writeShortHeader(b, ver) -} - -func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error { - var packetType uint8 - if version == protocol.Version2 { - //nolint:exhaustive - switch h.Type { - case protocol.PacketTypeInitial: - packetType = 0b01 - case protocol.PacketType0RTT: - packetType = 0b10 - case protocol.PacketTypeHandshake: - packetType = 0b11 - case protocol.PacketTypeRetry: - packetType = 0b00 - } - } else { - //nolint:exhaustive - switch h.Type { - case protocol.PacketTypeInitial: - packetType = 0b00 - case protocol.PacketType0RTT: - packetType = 0b01 - case protocol.PacketTypeHandshake: - packetType = 0b10 - case protocol.PacketTypeRetry: - packetType = 0b11 - } - } - firstByte := 0xc0 | packetType<<4 - if h.Type != protocol.PacketTypeRetry { - // Retry packets don't have a packet number - firstByte |= uint8(h.PacketNumberLen - 1) - } - - b.WriteByte(firstByte) - utils.BigEndian.WriteUint32(b, uint32(h.Version)) - b.WriteByte(uint8(h.DestConnectionID.Len())) - b.Write(h.DestConnectionID.Bytes()) - b.WriteByte(uint8(h.SrcConnectionID.Len())) - b.Write(h.SrcConnectionID.Bytes()) - - //nolint:exhaustive - switch h.Type { - case protocol.PacketTypeRetry: - b.Write(h.Token) - return nil - case protocol.PacketTypeInitial: - quicvarint.Write(b, uint64(len(h.Token))) - b.Write(h.Token) - } - quicvarint.WriteWithLen(b, uint64(h.Length), 2) - return h.writePacketNumber(b) -} - -func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { - typeByte := 0x40 | uint8(h.PacketNumberLen-1) - if h.KeyPhase == protocol.KeyPhaseOne { - typeByte |= byte(1 << 2) - } - - b.WriteByte(typeByte) - b.Write(h.DestConnectionID.Bytes()) - return h.writePacketNumber(b) -} - -func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - b.WriteByte(uint8(h.PacketNumber)) - case protocol.PacketNumberLen2: - utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) - case protocol.PacketNumberLen3: - utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber)) - case protocol.PacketNumberLen4: - utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) - default: - return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) - } - return nil -} - -// ParsedLen returns the number of bytes that were consumed when parsing the header -func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { - return h.parsedLen -} - -// GetLength determines the length of the Header. -func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount { - if h.IsLongHeader { - length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ - if h.Type == protocol.PacketTypeInitial { - length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) - } - return length - } - - length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) - length += protocol.ByteCount(h.PacketNumberLen) - return length -} - -// Log logs the Header -func (h *ExtendedHeader) Log(logger utils.Logger) { - if h.IsLongHeader { - var token string - if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { - if len(h.Token) == 0 { - token = "Token: (empty), " - } else { - token = fmt.Sprintf("Token: %#x, ", h.Token) - } - if h.Type == protocol.PacketTypeRetry { - logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version) - return - } - } - logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) - } else { - logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) - } -} diff --git a/internal/quic-go/wire/extended_header_test.go b/internal/quic-go/wire/extended_header_test.go deleted file mode 100644 index 4ec4cde1..00000000 --- a/internal/quic-go/wire/extended_header_test.go +++ /dev/null @@ -1,481 +0,0 @@ -package wire - -import ( - "bytes" - "log" - "os" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/quic-go/utils" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Header", func() { - const versionIETFHeader = protocol.VersionTLS // a QUIC version that uses the IETF Header format - - Context("Writing", func() { - var buf *bytes.Buffer - - BeforeEach(func() { - buf = &bytes.Buffer{} - }) - - Context("Long Header", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - - It("writes", func() { - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}, - Version: 0x1020304, - Length: protocol.InitialPacketSizeIPv4, - }, - PacketNumber: 0xdecaf, - PacketNumberLen: protocol.PacketNumberLen3, - }).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{ - 0xc0 | 0x2<<4 | 0x2, - 0x1, 0x2, 0x3, 0x4, // version number - 0x6, // dest connection ID length - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // dest connection ID - 0x8, // src connection ID length - 0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37, // source connection ID - } - expected = append(expected, encodeVarInt(protocol.InitialPacketSizeIPv4)...) // length - expected = append(expected, []byte{0xd, 0xec, 0xaf}...) // packet number - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("refuses to write a header with a too long connection ID", func() { - err := (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, // connection IDs must be at most 20 bytes long - Version: 0x1020304, - Type: 0x5, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader) - Expect(err).To(MatchError("invalid connection ID length: 21 bytes")) - }) - - It("writes a header with a 20 byte connection ID", func() { - err := (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, // connection IDs must be at most 20 bytes long - Version: 0x1020304, - Type: 0x5, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader) - Expect(err).ToNot(HaveOccurred()) - Expect(buf.Bytes()).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}))) - }) - - It("writes an Initial containing a token", func() { - token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Version: 0x1020304, - Type: protocol.PacketTypeInitial, - Token: token, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0) - expectedSubstring := append(encodeVarInt(uint64(len(token))), token...) - Expect(buf.Bytes()).To(ContainSubstring(string(expectedSubstring))) - }) - - It("uses a 2-byte encoding for the length on Initial packets", func() { - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Version: 0x1020304, - Type: protocol.PacketTypeInitial, - Length: 37, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader)).To(Succeed()) - b := &bytes.Buffer{} - quicvarint.WriteWithLen(b, 37, 2) - Expect(buf.Bytes()[buf.Len()-6 : buf.Len()-4]).To(Equal(b.Bytes())) - }) - - It("writes a Retry packet", func() { - token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") - Expect((&ExtendedHeader{Header: Header{ - IsLongHeader: true, - Version: protocol.Version1, - Type: protocol.PacketTypeRetry, - Token: token, - }}).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{0xc0 | 0b11<<4} - expected = appendVersion(expected, protocol.Version1) - expected = append(expected, 0x0) // dest connection ID length - expected = append(expected, 0x0) // src connection ID length - expected = append(expected, token...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - }) - - Context("long header, version 2", func() { - It("writes an Initial", func() { - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Version: protocol.Version2, - Type: protocol.PacketTypeInitial, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, protocol.Version2)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0b01) - }) - - It("writes a Retry packet", func() { - token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") - Expect((&ExtendedHeader{Header: Header{ - IsLongHeader: true, - Version: protocol.Version2, - Type: protocol.PacketTypeRetry, - Token: token, - }}).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{0xc0 | 0b11<<4} - expected = appendVersion(expected, protocol.Version2) - expected = append(expected, 0x0) // dest connection ID length - expected = append(expected, 0x0) // src connection ID length - expected = append(expected, token...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("writes a Handshake Packet", func() { - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Version: protocol.Version2, - Type: protocol.PacketTypeHandshake, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, protocol.Version2)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0b11) - }) - - It("writes a 0-RTT Packet", func() { - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Version: protocol.Version2, - Type: protocol.PacketType0RTT, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, protocol.Version2)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0b10) - }) - }) - - Context("short header", func() { - It("writes a header with connection ID", func() { - Expect((&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, - }, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()).To(Equal([]byte{ - 0x40, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID - 0x42, // packet number - })) - }) - - It("writes a header without connection ID", func() { - Expect((&ExtendedHeader{ - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()).To(Equal([]byte{ - 0x40, - 0x42, // packet number - })) - }) - - It("writes a header with a 2 byte packet number", func() { - Expect((&ExtendedHeader{ - PacketNumberLen: protocol.PacketNumberLen2, - PacketNumber: 0x765, - }).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{0x40 | 0x1} - expected = append(expected, []byte{0x7, 0x65}...) // packet number - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("writes a header with a 4 byte packet number", func() { - Expect((&ExtendedHeader{ - PacketNumberLen: protocol.PacketNumberLen4, - PacketNumber: 0x12345678, - }).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{0x40 | 0x3} - expected = append(expected, []byte{0x12, 0x34, 0x56, 0x78}...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("errors when given an invalid packet number length", func() { - err := (&ExtendedHeader{ - PacketNumberLen: 5, - PacketNumber: 0xdecafbad, - }).Write(buf, versionIETFHeader) - Expect(err).To(MatchError("invalid packet number length: 5")) - }) - - It("writes the Key Phase Bit", func() { - Expect((&ExtendedHeader{ - KeyPhase: protocol.KeyPhaseOne, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()).To(Equal([]byte{ - 0x40 | 0x4, - 0x42, // packet number - })) - }) - }) - }) - - Context("getting the length", func() { - var buf *bytes.Buffer - - BeforeEach(func() { - buf = &bytes.Buffer{} - }) - - It("has the right length for the Long Header, for a short length", func() { - h := &ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - Length: 1, - }, - PacketNumberLen: protocol.PacketNumberLen1, - } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* length */ + 1 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) - }) - - It("has the right length for the Long Header, for a long length", func() { - h := &ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - Length: 1500, - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* long len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) - }) - - It("has the right length for an Initial that has a short length", func() { - h := &ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Length: 15, - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) - }) - - It("has the right length for an Initial not containing a Token", func() { - h := &ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Length: 1500, - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) - }) - - It("has the right length for an Initial containing a Token", func() { - h := &ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Type: protocol.PacketTypeInitial, - Length: 1500, - Token: []byte("foo"), - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn id len */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) - }) - - It("has the right length for a Short Header containing a connection ID", func() { - h := &ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - }, - PacketNumberLen: protocol.PacketNumberLen1, - } - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 8 + 1))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(10)) - }) - - It("has the right length for a short header without a connection ID", func() { - h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1} - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 1))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(2)) - }) - - It("has the right length for a short header with a 2 byte packet number", func() { - h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen2} - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 2))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(3)) - }) - - It("has the right length for a short header with a 5 byte packet number", func() { - h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen4} - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 4))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(5)) - }) - }) - - Context("Logging", func() { - var ( - buf *bytes.Buffer - logger utils.Logger - ) - - BeforeEach(func() { - buf = &bytes.Buffer{} - logger = utils.DefaultLogger - logger.SetLogLevel(utils.LogLevelDebug) - log.SetOutput(buf) - }) - - AfterEach(func() { - log.SetOutput(os.Stdout) - }) - - It("logs Long Headers", func() { - (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}, - Type: protocol.PacketTypeHandshake, - Length: 54321, - Version: 0xfeed, - }, - PacketNumber: 1337, - PacketNumberLen: protocol.PacketNumberLen2, - }).Log(logger) - Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, DestConnectionID: deadbeefcafe1337, SrcConnectionID: decafbad13371337, PacketNumber: 1337, PacketNumberLen: 2, Length: 54321, Version: 0xfeed}")) - }) - - It("logs Initial Packets with a Token", func() { - (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - Type: protocol.PacketTypeInitial, - Token: []byte{0xde, 0xad, 0xbe, 0xef}, - Length: 100, - Version: 0xfeed, - }, - PacketNumber: 42, - PacketNumberLen: protocol.PacketNumberLen2, - }).Log(logger) - Expect(buf.String()).To(ContainSubstring("Long Header{Type: Initial, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0xdeadbeef, PacketNumber: 42, PacketNumberLen: 2, Length: 100, Version: 0xfeed}")) - }) - - It("logs Initial packets without a Token", func() { - (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - Type: protocol.PacketTypeInitial, - Length: 100, - Version: 0xfeed, - }, - PacketNumber: 42, - PacketNumberLen: protocol.PacketNumberLen2, - }).Log(logger) - Expect(buf.String()).To(ContainSubstring("Long Header{Type: Initial, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: (empty), PacketNumber: 42, PacketNumberLen: 2, Length: 100, Version: 0xfeed}")) - }) - - It("logs Retry packets with a Token", func() { - (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - Type: protocol.PacketTypeRetry, - Token: []byte{0x12, 0x34, 0x56}, - Version: 0xfeed, - }, - }).Log(logger) - Expect(buf.String()).To(ContainSubstring("Long Header{Type: Retry, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0x123456, Version: 0xfeed}")) - }) - - It("logs Short Headers containing a connection ID", func() { - (&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, - }, - KeyPhase: protocol.KeyPhaseOne, - PacketNumber: 1337, - PacketNumberLen: 4, - }).Log(logger) - Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}")) - }) - }) -}) diff --git a/internal/quic-go/wire/frame_parser.go b/internal/quic-go/wire/frame_parser.go deleted file mode 100644 index b1a3659b..00000000 --- a/internal/quic-go/wire/frame_parser.go +++ /dev/null @@ -1,143 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "fmt" - "reflect" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" -) - -type frameParser struct { - ackDelayExponent uint8 - - supportsDatagrams bool - - version protocol.VersionNumber -} - -// NewFrameParser creates a new frame parser. -func NewFrameParser(supportsDatagrams bool, v protocol.VersionNumber) FrameParser { - return &frameParser{ - supportsDatagrams: supportsDatagrams, - version: v, - } -} - -// ParseNext parses the next frame. -// It skips PADDING frames. -func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel) (Frame, error) { - for r.Len() != 0 { - typeByte, _ := r.ReadByte() - if typeByte == 0x0 { // PADDING frame - continue - } - r.UnreadByte() - - f, err := p.parseFrame(r, typeByte, encLevel) - if err != nil { - return nil, &qerr.TransportError{ - FrameType: uint64(typeByte), - ErrorCode: qerr.FrameEncodingError, - ErrorMessage: err.Error(), - } - } - return f, nil - } - return nil, nil -} - -func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protocol.EncryptionLevel) (Frame, error) { - var frame Frame - var err error - if typeByte&0xf8 == 0x8 { - frame, err = parseStreamFrame(r, p.version) - } else { - switch typeByte { - case 0x1: - frame, err = parsePingFrame(r, p.version) - case 0x2, 0x3: - ackDelayExponent := p.ackDelayExponent - if encLevel != protocol.Encryption1RTT { - ackDelayExponent = protocol.DefaultAckDelayExponent - } - frame, err = parseAckFrame(r, ackDelayExponent, p.version) - case 0x4: - frame, err = parseResetStreamFrame(r, p.version) - case 0x5: - frame, err = parseStopSendingFrame(r, p.version) - case 0x6: - frame, err = parseCryptoFrame(r, p.version) - case 0x7: - frame, err = parseNewTokenFrame(r, p.version) - case 0x10: - frame, err = parseMaxDataFrame(r, p.version) - case 0x11: - frame, err = parseMaxStreamDataFrame(r, p.version) - case 0x12, 0x13: - frame, err = parseMaxStreamsFrame(r, p.version) - case 0x14: - frame, err = parseDataBlockedFrame(r, p.version) - case 0x15: - frame, err = parseStreamDataBlockedFrame(r, p.version) - case 0x16, 0x17: - frame, err = parseStreamsBlockedFrame(r, p.version) - case 0x18: - frame, err = parseNewConnectionIDFrame(r, p.version) - case 0x19: - frame, err = parseRetireConnectionIDFrame(r, p.version) - case 0x1a: - frame, err = parsePathChallengeFrame(r, p.version) - case 0x1b: - frame, err = parsePathResponseFrame(r, p.version) - case 0x1c, 0x1d: - frame, err = parseConnectionCloseFrame(r, p.version) - case 0x1e: - frame, err = parseHandshakeDoneFrame(r, p.version) - case 0x30, 0x31: - if p.supportsDatagrams { - frame, err = parseDatagramFrame(r, p.version) - break - } - fallthrough - default: - err = errors.New("unknown frame type") - } - } - if err != nil { - return nil, err - } - if !p.isAllowedAtEncLevel(frame, encLevel) { - return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel) - } - return frame, nil -} - -func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { - switch encLevel { - case protocol.EncryptionInitial, protocol.EncryptionHandshake: - switch f.(type) { - case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *PingFrame: - return true - default: - return false - } - case protocol.Encryption0RTT: - switch f.(type) { - case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: - return false - default: - return true - } - case protocol.Encryption1RTT: - return true - default: - panic("unknown encryption level") - } -} - -func (p *frameParser) SetAckDelayExponent(exp uint8) { - p.ackDelayExponent = exp -} diff --git a/internal/quic-go/wire/frame_parser_test.go b/internal/quic-go/wire/frame_parser_test.go deleted file mode 100644 index f46bafd8..00000000 --- a/internal/quic-go/wire/frame_parser_test.go +++ /dev/null @@ -1,410 +0,0 @@ -package wire - -import ( - "bytes" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Frame parsing", func() { - var ( - buf *bytes.Buffer - parser FrameParser - ) - - BeforeEach(func() { - buf = &bytes.Buffer{} - parser = NewFrameParser(true, protocol.Version1) - }) - - It("returns nil if there's nothing more to read", func() { - f, err := parser.ParseNext(bytes.NewReader(nil), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeNil()) - }) - - It("skips PADDING frames", func() { - buf.Write([]byte{0}) // PADDING frame - (&PingFrame{}).Write(buf, protocol.Version1) - f, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(Equal(&PingFrame{})) - }) - - It("handles PADDING at the end", func() { - r := bytes.NewReader([]byte{0, 0, 0}) - f, err := parser.ParseNext(r, protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeNil()) - Expect(r.Len()).To(BeZero()) - }) - - It("unpacks ACK frames", func() { - f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(BeAssignableToTypeOf(f)) - Expect(frame.(*AckFrame).LargestAcked()).To(Equal(protocol.PacketNumber(0x13))) - }) - - It("uses the custom ack delay exponent for 1RTT packets", func() { - parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, - DelayTime: time.Second, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - // The ACK frame is always written using the protocol.AckDelayExponent. - // That's why we expect a different value when parsing. - Expect(frame.(*AckFrame).DelayTime).To(Equal(4 * time.Second)) - }) - - It("uses the default ack delay exponent for non-1RTT packets", func() { - parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) - f := &AckFrame{ - AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, - DelayTime: time.Second, - } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.(*AckFrame).DelayTime).To(Equal(time.Second)) - }) - - It("unpacks RESET_STREAM frames", func() { - f := &ResetStreamFrame{ - StreamID: 0xdeadbeef, - FinalSize: 0xdecafbad1234, - ErrorCode: 0x1337, - } - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks STOP_SENDING frames", func() { - f := &StopSendingFrame{StreamID: 0x42} - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks CRYPTO frames", func() { - f := &CryptoFrame{ - Offset: 0x1337, - Data: []byte("lorem ipsum"), - } - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks NEW_TOKEN frames", func() { - f := &NewTokenFrame{Token: []byte("foobar")} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks STREAM frames", func() { - f := &StreamFrame{ - StreamID: 0x42, - Offset: 0x1337, - Fin: true, - Data: []byte("foobar"), - } - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks MAX_DATA frames", func() { - f := &MaxDataFrame{ - MaximumData: 0xcafe, - } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks MAX_STREAM_DATA frames", func() { - f := &MaxStreamDataFrame{ - StreamID: 0xdeadbeef, - MaximumStreamData: 0xdecafbad, - } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks MAX_STREAMS frames", func() { - f := &MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: 0x1337, - } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks DATA_BLOCKED frames", func() { - f := &DataBlockedFrame{MaximumData: 0x1234} - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks STREAM_DATA_BLOCKED frames", func() { - f := &StreamDataBlockedFrame{ - StreamID: 0xdeadbeef, - MaximumStreamData: 0xdead, - } - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks STREAMS_BLOCKED frames", func() { - f := &StreamsBlockedFrame{ - Type: protocol.StreamTypeBidi, - StreamLimit: 0x1234567, - } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks NEW_CONNECTION_ID frames", func() { - f := &NewConnectionIDFrame{ - SequenceNumber: 0x1337, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - } - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks RETIRE_CONNECTION_ID frames", func() { - f := &RetireConnectionIDFrame{SequenceNumber: 0x1337} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks PATH_CHALLENGE frames", func() { - f := &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(BeAssignableToTypeOf(f)) - Expect(frame.(*PathChallengeFrame).Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) - }) - - It("unpacks PATH_RESPONSE frames", func() { - f := &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).ToNot(BeNil()) - Expect(frame).To(BeAssignableToTypeOf(f)) - Expect(frame.(*PathResponseFrame).Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) - }) - - It("unpacks CONNECTION_CLOSE frames", func() { - f := &ConnectionCloseFrame{ - IsApplicationError: true, - ReasonPhrase: "foobar", - } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks HANDSHAKE_DONE frames", func() { - f := &HandshakeDoneFrame{} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("unpacks DATAGRAM frames", func() { - f := &DatagramFrame{Data: []byte("foobar")} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("errors when DATAGRAM frames are not supported", func() { - parser = NewFrameParser(false, protocol.Version1) - f := &DatagramFrame{Data: []byte("foobar")} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - _, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.FrameEncodingError, - FrameType: 0x30, - ErrorMessage: "unknown frame type", - })) - }) - - It("errors on invalid type", func() { - _, err := parser.ParseNext(bytes.NewReader([]byte{0x42}), protocol.Encryption1RTT) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.FrameEncodingError, - FrameType: 0x42, - ErrorMessage: "unknown frame type", - })) - }) - - It("errors on invalid frames", func() { - f := &MaxStreamDataFrame{ - StreamID: 0x1337, - MaximumStreamData: 0xdeadbeef, - } - b := &bytes.Buffer{} - f.Write(b, protocol.Version1) - _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2]), protocol.Encryption1RTT) - Expect(err).To(HaveOccurred()) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) - }) - - Context("encryption level check", func() { - frames := []Frame{ - &PingFrame{}, - &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 42}}}, - &ResetStreamFrame{}, - &StopSendingFrame{}, - &CryptoFrame{}, - &NewTokenFrame{Token: []byte("lorem ipsum")}, - &StreamFrame{Data: []byte("foobar")}, - &MaxDataFrame{}, - &MaxStreamDataFrame{}, - &MaxStreamsFrame{}, - &DataBlockedFrame{}, - &StreamDataBlockedFrame{}, - &StreamsBlockedFrame{}, - &NewConnectionIDFrame{ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}}, - &RetireConnectionIDFrame{}, - &PathChallengeFrame{}, - &PathResponseFrame{}, - &ConnectionCloseFrame{}, - &HandshakeDoneFrame{}, - &DatagramFrame{}, - } - - var framesSerialized [][]byte - - BeforeEach(func() { - framesSerialized = nil - for _, frame := range frames { - buf := &bytes.Buffer{} - Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) - framesSerialized = append(framesSerialized, buf.Bytes()) - } - }) - - It("rejects all frames but ACK, CRYPTO, PING and CONNECTION_CLOSE in Initial packets", func() { - for i, b := range framesSerialized { - _, err := parser.ParseNext(bytes.NewReader(b), protocol.EncryptionInitial) - switch frames[i].(type) { - case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *PingFrame: - Expect(err).ToNot(HaveOccurred()) - default: - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) - Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level Initial")) - } - } - }) - - It("rejects all frames but ACK, CRYPTO, PING and CONNECTION_CLOSE in Handshake packets", func() { - for i, b := range framesSerialized { - _, err := parser.ParseNext(bytes.NewReader(b), protocol.EncryptionHandshake) - switch frames[i].(type) { - case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *PingFrame: - Expect(err).ToNot(HaveOccurred()) - default: - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) - Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level Handshake")) - } - } - }) - - It("rejects all frames but ACK, CRYPTO, CONNECTION_CLOSE, NEW_TOKEN, PATH_RESPONSE and RETIRE_CONNECTION_ID in 0-RTT packets", func() { - for i, b := range framesSerialized { - _, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption0RTT) - switch frames[i].(type) { - case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) - Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level 0-RTT")) - default: - Expect(err).ToNot(HaveOccurred()) - } - } - }) - - It("accepts all frame types in 1-RTT packets", func() { - for _, b := range framesSerialized { - _, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - } - }) - }) -}) diff --git a/internal/quic-go/wire/handshake_done_frame.go b/internal/quic-go/wire/handshake_done_frame.go deleted file mode 100644 index d940ddda..00000000 --- a/internal/quic-go/wire/handshake_done_frame.go +++ /dev/null @@ -1,28 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// A HandshakeDoneFrame is a HANDSHAKE_DONE frame -type HandshakeDoneFrame struct{} - -// ParseHandshakeDoneFrame parses a HandshakeDone frame -func parseHandshakeDoneFrame(r *bytes.Reader, _ protocol.VersionNumber) (*HandshakeDoneFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - return &HandshakeDoneFrame{}, nil -} - -func (f *HandshakeDoneFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x1e) - return nil -} - -// Length of a written frame -func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { - return 1 -} diff --git a/internal/quic-go/wire/header.go b/internal/quic-go/wire/header.go deleted file mode 100644 index 8455e748..00000000 --- a/internal/quic-go/wire/header.go +++ /dev/null @@ -1,274 +0,0 @@ -package wire - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -// ParseConnectionID parses the destination connection ID of a packet. -// It uses the data slice for the connection ID. -// That means that the connection ID must not be used after the packet buffer is released. -func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { - if len(data) == 0 { - return nil, io.EOF - } - isLongHeader := data[0]&0x80 > 0 - if !isLongHeader { - if len(data) < shortHeaderConnIDLen+1 { - return nil, io.EOF - } - return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil - } - if len(data) < 6 { - return nil, io.EOF - } - destConnIDLen := int(data[5]) - if len(data) < 6+destConnIDLen { - return nil, io.EOF - } - return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil -} - -// IsVersionNegotiationPacket says if this is a version negotiation packet -func IsVersionNegotiationPacket(b []byte) bool { - if len(b) < 5 { - return false - } - return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 -} - -// Is0RTTPacket says if this is a 0-RTT packet. -// A packet sent with a version we don't understand can never be a 0-RTT packet. -func Is0RTTPacket(b []byte) bool { - if len(b) < 5 { - return false - } - if b[0]&0x80 == 0 { - return false - } - version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5])) - if !protocol.IsSupportedVersion(protocol.SupportedVersions, version) { - return false - } - if version == protocol.Version2 { - return b[0]>>4&0b11 == 0b10 - } - return b[0]>>4&0b11 == 0b01 -} - -var ErrUnsupportedVersion = errors.New("unsupported version") - -// The Header is the version independent part of the header -type Header struct { - IsLongHeader bool - typeByte byte - Type protocol.PacketType - - Version protocol.VersionNumber - SrcConnectionID protocol.ConnectionID - DestConnectionID protocol.ConnectionID - - Length protocol.ByteCount - - Token []byte - - parsedLen protocol.ByteCount // how many bytes were read while parsing this header -} - -// ParsePacket parses a packet. -// If the packet has a long header, the packet is cut according to the length field. -// If we understand the version, the packet is header up unto the packet number. -// Otherwise, only the invariant part of the header is parsed. -func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) { - hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen) - if err != nil { - if err == ErrUnsupportedVersion { - return hdr, nil, nil, ErrUnsupportedVersion - } - return nil, nil, nil, err - } - var rest []byte - if hdr.IsLongHeader { - if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length { - return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) - } - packetLen := int(hdr.ParsedLen() + hdr.Length) - rest = data[packetLen:] - data = data[:packetLen] - } - return hdr, data, rest, nil -} - -// ParseHeader parses the header. -// For short header packets: up to the packet number. -// For long header packets: -// * if we understand the version: up to the packet number -// * if not, only the invariant part of the header -func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { - startLen := b.Len() - h, err := parseHeaderImpl(b, shortHeaderConnIDLen) - if err != nil { - return h, err - } - h.parsedLen = protocol.ByteCount(startLen - b.Len()) - return h, err -} - -func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { - typeByte, err := b.ReadByte() - if err != nil { - return nil, err - } - - h := &Header{ - typeByte: typeByte, - IsLongHeader: typeByte&0x80 > 0, - } - - if !h.IsLongHeader { - if h.typeByte&0x40 == 0 { - return nil, errors.New("not a QUIC packet") - } - if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil { - return nil, err - } - return h, nil - } - return h, h.parseLongHeader(b) -} - -func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error { - var err error - h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen) - return err -} - -func (h *Header) parseLongHeader(b *bytes.Reader) error { - v, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return err - } - h.Version = protocol.VersionNumber(v) - if h.Version != 0 && h.typeByte&0x40 == 0 { - return errors.New("not a QUIC packet") - } - destConnIDLen, err := b.ReadByte() - if err != nil { - return err - } - h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen)) - if err != nil { - return err - } - srcConnIDLen, err := b.ReadByte() - if err != nil { - return err - } - h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen)) - if err != nil { - return err - } - if h.Version == 0 { // version negotiation packet - return nil - } - // If we don't understand the version, we have no idea how to interpret the rest of the bytes - if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) { - return ErrUnsupportedVersion - } - - if h.Version == protocol.Version2 { - switch h.typeByte >> 4 & 0b11 { - case 0b00: - h.Type = protocol.PacketTypeRetry - case 0b01: - h.Type = protocol.PacketTypeInitial - case 0b10: - h.Type = protocol.PacketType0RTT - case 0b11: - h.Type = protocol.PacketTypeHandshake - } - } else { - switch h.typeByte >> 4 & 0b11 { - case 0b00: - h.Type = protocol.PacketTypeInitial - case 0b01: - h.Type = protocol.PacketType0RTT - case 0b10: - h.Type = protocol.PacketTypeHandshake - case 0b11: - h.Type = protocol.PacketTypeRetry - } - } - - if h.Type == protocol.PacketTypeRetry { - tokenLen := b.Len() - 16 - if tokenLen <= 0 { - return io.EOF - } - h.Token = make([]byte, tokenLen) - if _, err := io.ReadFull(b, h.Token); err != nil { - return err - } - _, err := b.Seek(16, io.SeekCurrent) - return err - } - - if h.Type == protocol.PacketTypeInitial { - tokenLen, err := quicvarint.Read(b) - if err != nil { - return err - } - if tokenLen > uint64(b.Len()) { - return io.EOF - } - h.Token = make([]byte, tokenLen) - if _, err := io.ReadFull(b, h.Token); err != nil { - return err - } - } - - pl, err := quicvarint.Read(b) - if err != nil { - return err - } - h.Length = protocol.ByteCount(pl) - return nil -} - -// ParsedLen returns the number of bytes that were consumed when parsing the header -func (h *Header) ParsedLen() protocol.ByteCount { - return h.parsedLen -} - -// ParseExtended parses the version dependent part of the header. -// The Reader has to be set such that it points to the first byte of the header. -func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) { - extHdr := h.toExtendedHeader() - reservedBitsValid, err := extHdr.parse(b, ver) - if err != nil { - return nil, err - } - if !reservedBitsValid { - return extHdr, ErrInvalidReservedBits - } - return extHdr, nil -} - -func (h *Header) toExtendedHeader() *ExtendedHeader { - return &ExtendedHeader{Header: *h} -} - -// PacketType is the type of the packet, for logging purposes -func (h *Header) PacketType() string { - if h.IsLongHeader { - return h.Type.String() - } - return "1-RTT" -} diff --git a/internal/quic-go/wire/header_test.go b/internal/quic-go/wire/header_test.go deleted file mode 100644 index cdcc08b3..00000000 --- a/internal/quic-go/wire/header_test.go +++ /dev/null @@ -1,583 +0,0 @@ -package wire - -import ( - "bytes" - "encoding/binary" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Header Parsing", func() { - Context("Parsing the Connection ID", func() { - It("parses the connection ID of a long header packet", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, - Version: protocol.Version1, - }, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - connID, err := ParseConnectionID(buf.Bytes(), 8) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) - }) - - It("parses the connection ID of a short header packet", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - }, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - buf.Write([]byte("foobar")) - connID, err := ParseConnectionID(buf.Bytes(), 4) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) - }) - - It("errors on EOF, for short header packets", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - }, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - data := buf.Bytes()[:buf.Len()-2] // cut the packet number - _, err := ParseConnectionID(data, 8) - Expect(err).ToNot(HaveOccurred()) - for i := 0; i < len(data); i++ { - b := make([]byte, i) - copy(b, data[:i]) - _, err := ParseConnectionID(b, 8) - Expect(err).To(MatchError(io.EOF)) - } - }) - - It("errors on EOF, for long header packets", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 8, 9}, - Version: protocol.Version1, - }, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - data := buf.Bytes()[:buf.Len()-2] // cut the packet number - _, err := ParseConnectionID(data, 8) - Expect(err).ToNot(HaveOccurred()) - for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ { - b := make([]byte, i) - copy(b, data[:i]) - _, err := ParseConnectionID(b, 8) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("identifying 0-RTT packets", func() { - It("recognizes 0-RTT packets, for QUIC v1", func() { - zeroRTTHeader := make([]byte, 5) - zeroRTTHeader[0] = 0x80 | 0b01<<4 - binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version1)) - - Expect(Is0RTTPacket(zeroRTTHeader)).To(BeTrue()) - Expect(Is0RTTPacket(zeroRTTHeader[:4])).To(BeFalse()) // too short - Expect(Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})).To(BeFalse()) // unknown version - Expect(Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})).To(BeFalse()) // short header - Expect(Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))).To(BeTrue()) - }) - - It("recognizes 0-RTT packets, for QUIC v2", func() { - zeroRTTHeader := make([]byte, 5) - zeroRTTHeader[0] = 0x80 | 0b10<<4 - binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version2)) - - Expect(Is0RTTPacket(zeroRTTHeader)).To(BeTrue()) - Expect(Is0RTTPacket(zeroRTTHeader[:4])).To(BeFalse()) // too short - Expect(Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})).To(BeFalse()) // unknown version - Expect(Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})).To(BeFalse()) // short header - Expect(Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))).To(BeTrue()) - }) - }) - - Context("Identifying Version Negotiation Packets", func() { - It("identifies version negotiation packets", func() { - Expect(IsVersionNegotiationPacket([]byte{0x80 | 0x56, 0, 0, 0, 0})).To(BeTrue()) - Expect(IsVersionNegotiationPacket([]byte{0x56, 0, 0, 0, 0})).To(BeFalse()) - Expect(IsVersionNegotiationPacket([]byte{0x80, 1, 0, 0, 0})).To(BeFalse()) - Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 1, 0, 0})).To(BeFalse()) - Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 1, 0})).To(BeFalse()) - Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 0, 1})).To(BeFalse()) - }) - - It("returns false on EOF", func() { - vnp := []byte{0x80, 0, 0, 0, 0} - for i := range vnp { - Expect(IsVersionNegotiationPacket(vnp[:i])).To(BeFalse()) - } - }) - }) - - Context("Long Headers", func() { - It("parses a Long Header", func() { - destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} - srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - data := []byte{0xc0 ^ 0x3} - data = appendVersion(data, protocol.Version1) - data = append(data, 0x9) // dest conn id length - data = append(data, destConnID...) - data = append(data, 0x4) // src conn id length - data = append(data, srcConnID...) - data = append(data, encodeVarInt(6)...) // token length - data = append(data, []byte("foobar")...) // token - data = append(data, encodeVarInt(10)...) // length - hdrLen := len(data) - data = append(data, []byte{0, 0, 0xbe, 0xef}...) // packet number - data = append(data, []byte("foobar")...) - Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) - - hdr, pdata, rest, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(pdata).To(Equal(data)) - Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.DestConnectionID).To(Equal(destConnID)) - Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) - Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(hdr.Token).To(Equal([]byte("foobar"))) - Expect(hdr.Length).To(Equal(protocol.ByteCount(10))) - Expect(hdr.Version).To(Equal(protocol.Version1)) - Expect(rest).To(BeEmpty()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef))) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) - Expect(b.Len()).To(Equal(6)) // foobar - Expect(hdr.ParsedLen()).To(BeEquivalentTo(hdrLen)) - Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 4)) - }) - - It("errors if 0x40 is not set", func() { - data := []byte{ - 0x80 | 0x2<<4, - 0x11, // connection ID lengths - 0xde, 0xca, 0xfb, 0xad, // dest conn ID - 0xde, 0xad, 0xbe, 0xef, // src conn ID - } - _, _, _, err := ParsePacket(data, 0) - Expect(err).To(MatchError("not a QUIC packet")) - }) - - It("stops parsing when encountering an unsupported version", func() { - data := []byte{ - 0xc0, - 0xde, 0xad, 0xbe, 0xef, - 0x8, // dest conn ID len - 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, // dest conn ID - 0x8, // src conn ID len - 0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, // src conn ID - 'f', 'o', 'o', 'b', 'a', 'r', // unspecified bytes - } - hdr, _, rest, err := ParsePacket(data, 0) - Expect(err).To(MatchError(ErrUnsupportedVersion)) - Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.Version).To(Equal(protocol.VersionNumber(0xdeadbeef))) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1})) - Expect(rest).To(BeEmpty()) - }) - - It("parses a Long Header without a destination connection ID", func() { - data := []byte{0xc0 ^ 0x1<<4} - data = appendVersion(data, protocol.Version1) - data = append(data, 0x0) // dest conn ID len - data = append(data, 0x4) // src conn ID len - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID - data = append(data, encodeVarInt(0)...) // length - data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.Type).To(Equal(protocol.PacketType0RTT)) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(hdr.DestConnectionID).To(BeEmpty()) - }) - - It("parses a Long Header without a source connection ID", func() { - data := []byte{0xc0 ^ 0x2<<4} - data = appendVersion(data, protocol.Version1) - data = append(data, 0xa) // dest conn ID len - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // dest connection ID - data = append(data, 0x0) // src conn ID len - data = append(data, encodeVarInt(0)...) // length - data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.SrcConnectionID).To(BeEmpty()) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) - }) - - It("parses a Long Header with a 2 byte packet number", func() { - data := []byte{0xc0 ^ 0x1} - data = appendVersion(data, protocol.Version1) // version number - data = append(data, []byte{0x0, 0x0}...) // connection ID lengths - data = append(data, encodeVarInt(0)...) // token length - data = append(data, encodeVarInt(0)...) // length - data = append(data, []byte{0x1, 0x23}...) - - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x123))) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) - Expect(b.Len()).To(BeZero()) - }) - - It("parses a Retry packet, for QUIC v1", func() { - data := []byte{0xc0 | 0b11<<4 | (10 - 3) /* connection ID length */} - data = appendVersion(data, protocol.Version1) - data = append(data, []byte{6}...) // dest conn ID len - data = append(data, []byte{6, 5, 4, 3, 2, 1}...) // dest conn ID - data = append(data, []byte{10}...) // src conn ID len - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID - data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token - data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...) - hdr, pdata, rest, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(hdr.Version).To(Equal(protocol.Version1)) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) - Expect(hdr.Token).To(Equal([]byte("foobar"))) - Expect(pdata).To(Equal(data)) - Expect(rest).To(BeEmpty()) - }) - - It("parses a Retry packet, for QUIC v2", func() { - data := []byte{0xc0 | 0b00<<4 | (10 - 3) /* connection ID length */} - data = appendVersion(data, protocol.Version2) - data = append(data, []byte{6}...) // dest conn ID len - data = append(data, []byte{6, 5, 4, 3, 2, 1}...) // dest conn ID - data = append(data, []byte{10}...) // src conn ID len - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID - data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token - data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...) - hdr, pdata, rest, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(hdr.Version).To(Equal(protocol.Version2)) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) - Expect(hdr.Token).To(Equal([]byte("foobar"))) - Expect(pdata).To(Equal(data)) - Expect(rest).To(BeEmpty()) - }) - - It("errors if the Retry packet is too short for the integrity tag", func() { - data := []byte{0xc0 | 0x3<<4 | (10 - 3) /* connection ID length */} - data = appendVersion(data, protocol.Version1) - data = append(data, []byte{0, 0}...) // conn ID lens - data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) - // this results in a token length of 0 - _, _, _, err := ParsePacket(data, 0) - Expect(err).To(MatchError(io.EOF)) - }) - - It("errors if the token length is too large", func() { - data := []byte{0xc0 ^ 0x1} - data = appendVersion(data, protocol.Version1) - data = append(data, 0x0) // connection ID lengths - data = append(data, encodeVarInt(4)...) // token length: 4 bytes (1 byte too long) - data = append(data, encodeVarInt(0x42)...) // length, 1 byte - data = append(data, []byte{0x12, 0x34}...) // packet number - - _, _, _, err := ParsePacket(data, 0) - Expect(err).To(MatchError(io.EOF)) - }) - - It("errors if the 5th or 6th bit are set", func() { - data := []byte{0xc0 | 0x2<<4 | 0x8 /* set the 5th bit */ | 0x1 /* 2 byte packet number */} - data = appendVersion(data, protocol.Version1) - data = append(data, []byte{0x0, 0x0}...) // connection ID lengths - data = append(data, encodeVarInt(2)...) // length - data = append(data, []byte{0x12, 0x34}...) // packet number - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) - extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) - Expect(err).To(MatchError(ErrInvalidReservedBits)) - Expect(extHdr).ToNot(BeNil()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1234))) - }) - - It("errors on EOF, when parsing the header", func() { - data := []byte{0xc0 ^ 0x2<<4} - data = appendVersion(data, protocol.Version1) - data = append(data, 0x8) // dest conn ID len - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // dest conn ID - data = append(data, 0x8) // src conn ID len - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // src conn ID - for i := 0; i < len(data); i++ { - _, _, _, err := ParsePacket(data[:i], 0) - Expect(err).To(Equal(io.EOF)) - } - }) - - It("errors on EOF, when parsing the extended header", func() { - data := []byte{0xc0 | 0x2<<4 | 0x3} - data = appendVersion(data, protocol.Version1) - data = append(data, []byte{0x0, 0x0}...) // connection ID lengths - data = append(data, encodeVarInt(0)...) // length - hdrLen := len(data) - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number - for i := hdrLen; i < len(data); i++ { - data = data[:i] - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - _, err = hdr.ParseExtended(b, protocol.Version1) - Expect(err).To(Equal(io.EOF)) - } - }) - - It("errors on EOF, for a Retry packet", func() { - data := []byte{0xc0 ^ 0x3<<4} - data = appendVersion(data, protocol.Version1) - data = append(data, []byte{0x0, 0x0}...) // connection ID lengths - data = append(data, 0xa) // Orig Destination Connection ID length - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID - hdrLen := len(data) - for i := hdrLen; i < len(data); i++ { - data = data[:i] - hdr, _, _, err := ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - _, err = hdr.ParseExtended(b, protocol.Version1) - Expect(err).To(Equal(io.EOF)) - } - }) - - Context("coalesced packets", func() { - It("cuts packets", func() { - buf := &bytes.Buffer{} - hdr := Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Length: 2 + 6, - Version: protocol.Version1, - } - Expect((&ExtendedHeader{ - Header: hdr, - PacketNumber: 0x1337, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - hdrRaw := append([]byte{}, buf.Bytes()...) - buf.Write([]byte("foobar")) // payload of the first packet - buf.Write([]byte("raboof")) // second packet - parsedHdr, data, rest, err := ParsePacket(buf.Bytes(), 4) - Expect(err).ToNot(HaveOccurred()) - Expect(parsedHdr.Type).To(Equal(hdr.Type)) - Expect(parsedHdr.DestConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(data).To(Equal(append(hdrRaw, []byte("foobar")...))) - Expect(rest).To(Equal([]byte("raboof"))) - }) - - It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Length: 3, - Version: protocol.Version1, - }, - PacketNumber: 0x1337, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - _, _, _, err := ParsePacket(buf.Bytes(), 4) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)")) - }) - - It("errors on packets that are smaller than the length in the packet header, for too small payload", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Length: 1000, - Version: protocol.Version1, - }, - PacketNumber: 0x1337, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - buf.Write(make([]byte, 500-2 /* for packet number length */)) - _, _, _, err := ParsePacket(buf.Bytes(), 4) - Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) - }) - }) - }) - - Context("Short Headers", func() { - It("reads a Short Header with a 8 byte connection ID", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - data := append([]byte{0x40}, connID...) - data = append(data, 0x42) // packet number - Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) - - hdr, pdata, rest, err := ParsePacket(data, 8) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsLongHeader).To(BeFalse()) - Expect(hdr.DestConnectionID).To(Equal(connID)) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) - Expect(extHdr.DestConnectionID).To(Equal(connID)) - Expect(extHdr.SrcConnectionID).To(BeEmpty()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) - Expect(hdr.ParsedLen()).To(BeEquivalentTo(len(data) - 1)) - Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 1)) - Expect(pdata).To(Equal(data)) - Expect(rest).To(BeEmpty()) - }) - - It("errors if 0x40 is not set", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - data := append([]byte{0x0}, connID...) - _, _, _, err := ParsePacket(data, 8) - Expect(err).To(MatchError("not a QUIC packet")) - }) - - It("errors if the 4th or 5th bit are set", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5} - data := append([]byte{0x40 | 0x10 /* set the 4th bit */}, connID...) - data = append(data, 0x42) // packet number - hdr, _, _, err := ParsePacket(data, 5) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsLongHeader).To(BeFalse()) - extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) - Expect(err).To(MatchError(ErrInvalidReservedBits)) - Expect(extHdr).ToNot(BeNil()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) - }) - - It("reads a Short Header with a 5 byte connection ID", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5} - data := append([]byte{0x40}, connID...) - data = append(data, 0x42) // packet number - hdr, pdata, rest, err := ParsePacket(data, 5) - Expect(err).ToNot(HaveOccurred()) - Expect(pdata).To(HaveLen(len(data))) - Expect(hdr.IsLongHeader).To(BeFalse()) - Expect(hdr.DestConnectionID).To(Equal(connID)) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) - Expect(extHdr.DestConnectionID).To(Equal(connID)) - Expect(extHdr.SrcConnectionID).To(BeEmpty()) - Expect(rest).To(BeEmpty()) - }) - - It("reads the Key Phase Bit", func() { - data := []byte{ - 0x40 ^ 0x4, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID - } - data = append(data, 11) // packet number - hdr, _, _, err := ParsePacket(data, 6) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsLongHeader).To(BeFalse()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseOne)) - Expect(b.Len()).To(BeZero()) - }) - - It("reads a header with a 2 byte packet number", func() { - data := []byte{ - 0x40 | 0x1, - 0xde, 0xad, 0xbe, 0xef, // connection ID - } - data = append(data, []byte{0x13, 0x37}...) // packet number - hdr, _, _, err := ParsePacket(data, 4) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.IsLongHeader).To(BeFalse()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) - Expect(b.Len()).To(BeZero()) - }) - - It("reads a header with a 3 byte packet number", func() { - data := []byte{ - 0x40 | 0x2, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x1, 0x2, 0x3, 0x4, // connection ID - } - data = append(data, []byte{0x99, 0xbe, 0xef}...) // packet number - hdr, _, _, err := ParsePacket(data, 10) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.IsLongHeader).To(BeFalse()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef))) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen3)) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOF, when parsing the header", func() { - data := []byte{ - 0x40 ^ 0x2, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID - } - for i := 0; i < len(data); i++ { - data = data[:i] - _, _, _, err := ParsePacket(data, 8) - Expect(err).To(Equal(io.EOF)) - } - }) - - It("errors on EOF, when parsing the extended header", func() { - data := []byte{ - 0x40 ^ 0x3, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID - } - hdrLen := len(data) - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number - for i := hdrLen; i < len(data); i++ { - data = data[:i] - hdr, _, _, err := ParsePacket(data, 6) - Expect(err).ToNot(HaveOccurred()) - _, err = hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) - Expect(err).To(Equal(io.EOF)) - } - }) - }) - - It("tells its packet type for logging", func() { - Expect((&Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}).PacketType()).To(Equal("Handshake")) - Expect((&Header{}).PacketType()).To(Equal("1-RTT")) - }) -}) diff --git a/internal/quic-go/wire/interface.go b/internal/quic-go/wire/interface.go deleted file mode 100644 index b5804af9..00000000 --- a/internal/quic-go/wire/interface.go +++ /dev/null @@ -1,19 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// A Frame in QUIC -type Frame interface { - Write(b *bytes.Buffer, version protocol.VersionNumber) error - Length(version protocol.VersionNumber) protocol.ByteCount -} - -// A FrameParser parses QUIC frames, one by one. -type FrameParser interface { - ParseNext(*bytes.Reader, protocol.EncryptionLevel) (Frame, error) - SetAckDelayExponent(uint8) -} diff --git a/internal/quic-go/wire/log.go b/internal/quic-go/wire/log.go deleted file mode 100644 index 030549d2..00000000 --- a/internal/quic-go/wire/log.go +++ /dev/null @@ -1,72 +0,0 @@ -package wire - -import ( - "fmt" - "strings" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -// LogFrame logs a frame, either sent or received -func LogFrame(logger utils.Logger, frame Frame, sent bool) { - if !logger.Debug() { - return - } - dir := "<-" - if sent { - dir = "->" - } - switch f := frame.(type) { - case *CryptoFrame: - dataLen := protocol.ByteCount(len(f.Data)) - logger.Debugf("\t%s &wire.CryptoFrame{Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.Offset, dataLen, f.Offset+dataLen) - case *StreamFrame: - logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, Fin: %t, Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.StreamID, f.Fin, f.Offset, f.DataLen(), f.Offset+f.DataLen()) - case *ResetStreamFrame: - logger.Debugf("\t%s &wire.ResetStreamFrame{StreamID: %d, ErrorCode: %#x, FinalSize: %d}", dir, f.StreamID, f.ErrorCode, f.FinalSize) - case *AckFrame: - hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 - var ecn string - if hasECN { - ecn = fmt.Sprintf(", ECT0: %d, ECT1: %d, CE: %d", f.ECT0, f.ECT1, f.ECNCE) - } - if len(f.AckRanges) > 1 { - ackRanges := make([]string, len(f.AckRanges)) - for i, r := range f.AckRanges { - ackRanges[i] = fmt.Sprintf("{Largest: %d, Smallest: %d}", r.Largest, r.Smallest) - } - logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, AckRanges: {%s}, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), strings.Join(ackRanges, ", "), f.DelayTime.String(), ecn) - } else { - logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), f.DelayTime.String(), ecn) - } - case *MaxDataFrame: - logger.Debugf("\t%s &wire.MaxDataFrame{MaximumData: %d}", dir, f.MaximumData) - case *MaxStreamDataFrame: - logger.Debugf("\t%s &wire.MaxStreamDataFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) - case *DataBlockedFrame: - logger.Debugf("\t%s &wire.DataBlockedFrame{MaximumData: %d}", dir, f.MaximumData) - case *StreamDataBlockedFrame: - logger.Debugf("\t%s &wire.StreamDataBlockedFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) - case *MaxStreamsFrame: - switch f.Type { - case protocol.StreamTypeUni: - logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: uni, MaxStreamNum: %d}", dir, f.MaxStreamNum) - case protocol.StreamTypeBidi: - logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: %d}", dir, f.MaxStreamNum) - } - case *StreamsBlockedFrame: - switch f.Type { - case protocol.StreamTypeUni: - logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: uni, MaxStreams: %d}", dir, f.StreamLimit) - case protocol.StreamTypeBidi: - logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit) - } - case *NewConnectionIDFrame: - logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken) - case *NewTokenFrame: - logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token) - default: - logger.Debugf("\t%s %#v", dir, frame) - } -} diff --git a/internal/quic-go/wire/log_test.go b/internal/quic-go/wire/log_test.go deleted file mode 100644 index 38e7b645..00000000 --- a/internal/quic-go/wire/log_test.go +++ /dev/null @@ -1,168 +0,0 @@ -package wire - -import ( - "bytes" - "log" - "os" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Frame logging", func() { - var ( - buf *bytes.Buffer - logger utils.Logger - ) - - BeforeEach(func() { - buf = &bytes.Buffer{} - logger = utils.DefaultLogger - logger.SetLogLevel(utils.LogLevelDebug) - log.SetOutput(buf) - }) - - AfterEach(func() { - log.SetOutput(os.Stdout) - }) - - It("doesn't log when debug is disabled", func() { - logger.SetLogLevel(utils.LogLevelInfo) - LogFrame(logger, &ResetStreamFrame{}, true) - Expect(buf.Len()).To(BeZero()) - }) - - It("logs sent frames", func() { - LogFrame(logger, &ResetStreamFrame{}, true) - Expect(buf.String()).To(ContainSubstring("\t-> &wire.ResetStreamFrame{StreamID: 0, ErrorCode: 0x0, FinalSize: 0}\n")) - }) - - It("logs received frames", func() { - LogFrame(logger, &ResetStreamFrame{}, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.ResetStreamFrame{StreamID: 0, ErrorCode: 0x0, FinalSize: 0}\n")) - }) - - It("logs CRYPTO frames", func() { - frame := &CryptoFrame{ - Offset: 42, - Data: make([]byte, 123), - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.CryptoFrame{Offset: 42, Data length: 123, Offset + Data length: 165}\n")) - }) - - It("logs STREAM frames", func() { - frame := &StreamFrame{ - StreamID: 42, - Offset: 1337, - Data: bytes.Repeat([]byte{'f'}, 100), - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamFrame{StreamID: 42, Fin: false, Offset: 1337, Data length: 100, Offset + Data length: 1437}\n")) - }) - - It("logs ACK frames without missing packets", func() { - frame := &AckFrame{ - AckRanges: []AckRange{{Smallest: 42, Largest: 1337}}, - DelayTime: 1 * time.Millisecond, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 1337, LowestAcked: 42, DelayTime: 1ms}\n")) - }) - - It("logs ACK frames with ECN", func() { - frame := &AckFrame{ - AckRanges: []AckRange{{Smallest: 42, Largest: 1337}}, - DelayTime: 1 * time.Millisecond, - ECT0: 5, - ECT1: 66, - ECNCE: 777, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 1337, LowestAcked: 42, DelayTime: 1ms, ECT0: 5, ECT1: 66, CE: 777}\n")) - }) - - It("logs ACK frames with missing packets", func() { - frame := &AckFrame{ - AckRanges: []AckRange{ - {Smallest: 5, Largest: 8}, - {Smallest: 2, Largest: 3}, - }, - DelayTime: 12 * time.Millisecond, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 8, LowestAcked: 2, AckRanges: {{Largest: 8, Smallest: 5}, {Largest: 3, Smallest: 2}}, DelayTime: 12ms}\n")) - }) - - It("logs MAX_STREAMS frames", func() { - frame := &MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: 42, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: 42}\n")) - }) - - It("logs MAX_DATA frames", func() { - frame := &MaxDataFrame{ - MaximumData: 42, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxDataFrame{MaximumData: 42}\n")) - }) - - It("logs MAX_STREAM_DATA frames", func() { - frame := &MaxStreamDataFrame{ - StreamID: 10, - MaximumStreamData: 42, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 42}\n")) - }) - - It("logs DATA_BLOCKED frames", func() { - frame := &DataBlockedFrame{ - MaximumData: 1000, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.DataBlockedFrame{MaximumData: 1000}\n")) - }) - - It("logs STREAM_DATA_BLOCKED frames", func() { - frame := &StreamDataBlockedFrame{ - StreamID: 42, - MaximumStreamData: 1000, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamDataBlockedFrame{StreamID: 42, MaximumStreamData: 1000}\n")) - }) - - It("logs STREAMS_BLOCKED frames", func() { - frame := &StreamsBlockedFrame{ - Type: protocol.StreamTypeBidi, - StreamLimit: 42, - } - LogFrame(logger, frame, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: 42}\n")) - }) - - It("logs NEW_CONNECTION_ID frames", func() { - LogFrame(logger, &NewConnectionIDFrame{ - SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - StatelessResetToken: protocol.StatelessResetToken{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10}, - }, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.NewConnectionIDFrame{SequenceNumber: 42, ConnectionID: deadbeef, StatelessResetToken: 0x0102030405060708090a0b0c0d0e0f10}")) - }) - - It("logs NEW_TOKEN frames", func() { - LogFrame(logger, &NewTokenFrame{ - Token: []byte{0xde, 0xad, 0xbe, 0xef}, - }, true) - Expect(buf.String()).To(ContainSubstring("\t-> &wire.NewTokenFrame{Token: 0xdeadbeef")) - }) -}) diff --git a/internal/quic-go/wire/max_data_frame.go b/internal/quic-go/wire/max_data_frame.go deleted file mode 100644 index cfa54d7c..00000000 --- a/internal/quic-go/wire/max_data_frame.go +++ /dev/null @@ -1,40 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A MaxDataFrame carries flow control information for the connection -type MaxDataFrame struct { - MaximumData protocol.ByteCount -} - -// parseMaxDataFrame parses a MAX_DATA frame -func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - frame := &MaxDataFrame{} - byteOffset, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - frame.MaximumData = protocol.ByteCount(byteOffset) - return frame, nil -} - -// Write writes a MAX_STREAM_DATA frame -func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x10) - quicvarint.Write(b, uint64(f.MaximumData)) - return nil -} - -// Length of a written frame -func (f *MaxDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.MaximumData)) -} diff --git a/internal/quic-go/wire/max_data_frame_test.go b/internal/quic-go/wire/max_data_frame_test.go deleted file mode 100644 index 73f6a452..00000000 --- a/internal/quic-go/wire/max_data_frame_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("MAX_DATA frame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - data := []byte{0x10} - data = append(data, encodeVarInt(0xdecafbad123456)...) // byte offset - b := bytes.NewReader(data) - frame, err := parseMaxDataFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0xdecafbad123456))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x10} - data = append(data, encodeVarInt(0xdecafbad1234567)...) // byte offset - _, err := parseMaxDataFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseMaxDataFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("writing", func() { - It("has proper min length", func() { - f := &MaxDataFrame{ - MaximumData: 0xdeadbeef, - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0xdeadbeef))) - }) - - It("writes a MAX_DATA frame", func() { - b := &bytes.Buffer{} - f := &MaxDataFrame{ - MaximumData: 0xdeadbeefcafe, - } - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x10} - expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - }) -}) diff --git a/internal/quic-go/wire/max_stream_data_frame.go b/internal/quic-go/wire/max_stream_data_frame.go deleted file mode 100644 index 5c6f37d0..00000000 --- a/internal/quic-go/wire/max_stream_data_frame.go +++ /dev/null @@ -1,46 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A MaxStreamDataFrame is a MAX_STREAM_DATA frame -type MaxStreamDataFrame struct { - StreamID protocol.StreamID - MaximumStreamData protocol.ByteCount -} - -func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamDataFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - sid, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - offset, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - - return &MaxStreamDataFrame{ - StreamID: protocol.StreamID(sid), - MaximumStreamData: protocol.ByteCount(offset), - }, nil -} - -func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x11) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.MaximumStreamData)) - return nil -} - -// Length of a written frame -func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) -} diff --git a/internal/quic-go/wire/max_stream_data_frame_test.go b/internal/quic-go/wire/max_stream_data_frame_test.go deleted file mode 100644 index 4d8e6fd8..00000000 --- a/internal/quic-go/wire/max_stream_data_frame_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("MAX_STREAM_DATA frame", func() { - Context("parsing", func() { - It("accepts sample frame", func() { - data := []byte{0x11} - data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID - data = append(data, encodeVarInt(0x12345678)...) // Offset - b := bytes.NewReader(data) - frame, err := parseMaxStreamDataFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) - Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0x12345678))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x11} - data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID - data = append(data, encodeVarInt(0x12345678)...) // Offset - _, err := parseMaxStreamDataFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseMaxStreamDataFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("writing", func() { - It("has proper min length", func() { - f := &MaxStreamDataFrame{ - StreamID: 0x1337, - MaximumStreamData: 0xdeadbeef, - } - Expect(f.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)))) - }) - - It("writes a sample frame", func() { - b := &bytes.Buffer{} - f := &MaxStreamDataFrame{ - StreamID: 0xdecafbad, - MaximumStreamData: 0xdeadbeefcafe42, - } - expected := []byte{0x11} - expected = append(expected, encodeVarInt(0xdecafbad)...) - expected = append(expected, encodeVarInt(0xdeadbeefcafe42)...) - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal(expected)) - }) - }) -}) diff --git a/internal/quic-go/wire/max_streams_frame.go b/internal/quic-go/wire/max_streams_frame.go deleted file mode 100644 index 0681fa24..00000000 --- a/internal/quic-go/wire/max_streams_frame.go +++ /dev/null @@ -1,55 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A MaxStreamsFrame is a MAX_STREAMS frame -type MaxStreamsFrame struct { - Type protocol.StreamType - MaxStreamNum protocol.StreamNum -} - -func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamsFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - f := &MaxStreamsFrame{} - switch typeByte { - case 0x12: - f.Type = protocol.StreamTypeBidi - case 0x13: - f.Type = protocol.StreamTypeUni - } - streamID, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - f.MaxStreamNum = protocol.StreamNum(streamID) - if f.MaxStreamNum > protocol.MaxStreamCount { - return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum) - } - return f, nil -} - -func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - switch f.Type { - case protocol.StreamTypeBidi: - b.WriteByte(0x12) - case protocol.StreamTypeUni: - b.WriteByte(0x13) - } - quicvarint.Write(b, uint64(f.MaxStreamNum)) - return nil -} - -// Length of a written frame -func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.MaxStreamNum)) -} diff --git a/internal/quic-go/wire/max_streams_frame_test.go b/internal/quic-go/wire/max_streams_frame_test.go deleted file mode 100644 index 114b534d..00000000 --- a/internal/quic-go/wire/max_streams_frame_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("MAX_STREAMS frame", func() { - Context("parsing", func() { - It("accepts a frame for a bidirectional stream", func() { - data := []byte{0x12} - data = append(data, encodeVarInt(0xdecaf)...) - b := bytes.NewReader(data) - f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Type).To(Equal(protocol.StreamTypeBidi)) - Expect(f.MaxStreamNum).To(BeEquivalentTo(0xdecaf)) - Expect(b.Len()).To(BeZero()) - }) - - It("accepts a frame for a bidirectional stream", func() { - data := []byte{0x13} - data = append(data, encodeVarInt(0xdecaf)...) - b := bytes.NewReader(data) - f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Type).To(Equal(protocol.StreamTypeUni)) - Expect(f.MaxStreamNum).To(BeEquivalentTo(0xdecaf)) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x1d} - data = append(data, encodeVarInt(0xdeadbeefcafe13)...) - _, err := parseMaxStreamsFrame(bytes.NewReader(data), protocol.VersionWhatever) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseMaxStreamsFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) - } - }) - - for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { - streamType := t - - It("accepts a frame containing the maximum stream count", func() { - f := &MaxStreamsFrame{ - Type: streamType, - MaxStreamNum: protocol.MaxStreamCount, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - frame, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("errors when receiving a too large stream count", func() { - f := &MaxStreamsFrame{ - Type: streamType, - MaxStreamNum: protocol.MaxStreamCount + 1, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - _, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) - Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) - }) - } - }) - - Context("writing", func() { - It("for a bidirectional stream", func() { - f := &MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: 0xdeadbeef, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - expected := []byte{0x12} - expected = append(expected, encodeVarInt(0xdeadbeef)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("for a unidirectional stream", func() { - f := &MaxStreamsFrame{ - Type: protocol.StreamTypeUni, - MaxStreamNum: 0xdecafbad, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - expected := []byte{0x13} - expected = append(expected, encodeVarInt(0xdecafbad)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - frame := MaxStreamsFrame{MaxStreamNum: 0x1337} - Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(0x1337))) - }) - }) -}) diff --git a/internal/quic-go/wire/new_connection_id_frame.go b/internal/quic-go/wire/new_connection_id_frame.go deleted file mode 100644 index 9eb1fcbc..00000000 --- a/internal/quic-go/wire/new_connection_id_frame.go +++ /dev/null @@ -1,80 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A NewConnectionIDFrame is a NEW_CONNECTION_ID frame -type NewConnectionIDFrame struct { - SequenceNumber uint64 - RetirePriorTo uint64 - ConnectionID protocol.ConnectionID - StatelessResetToken protocol.StatelessResetToken -} - -func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewConnectionIDFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - seq, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - ret, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - if ret > seq { - //nolint:stylecheck - return nil, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq) - } - connIDLen, err := r.ReadByte() - if err != nil { - return nil, err - } - if connIDLen > protocol.MaxConnIDLen { - return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen) - } - connID, err := protocol.ReadConnectionID(r, int(connIDLen)) - if err != nil { - return nil, err - } - frame := &NewConnectionIDFrame{ - SequenceNumber: seq, - RetirePriorTo: ret, - ConnectionID: connID, - } - if _, err := io.ReadFull(r, frame.StatelessResetToken[:]); err != nil { - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return nil, err - } - - return frame, nil -} - -func (f *NewConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x18) - quicvarint.Write(b, f.SequenceNumber) - quicvarint.Write(b, f.RetirePriorTo) - connIDLen := f.ConnectionID.Len() - if connIDLen > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d", connIDLen) - } - b.WriteByte(uint8(connIDLen)) - b.Write(f.ConnectionID.Bytes()) - b.Write(f.StatelessResetToken[:]) - return nil -} - -// Length of a written frame -func (f *NewConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16 -} diff --git a/internal/quic-go/wire/new_connection_id_frame_test.go b/internal/quic-go/wire/new_connection_id_frame_test.go deleted file mode 100644 index 776b9670..00000000 --- a/internal/quic-go/wire/new_connection_id_frame_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("NEW_CONNECTION_ID frame", func() { - Context("when parsing", func() { - It("accepts a sample frame", func() { - data := []byte{0x18} - data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number - data = append(data, encodeVarInt(0xcafe)...) // retire prior to - data = append(data, 10) // connection ID length - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID - data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - b := bytes.NewReader(data) - frame, err := parseNewConnectionIDFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) - Expect(frame.RetirePriorTo).To(Equal(uint64(0xcafe))) - Expect(frame.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) - Expect(string(frame.StatelessResetToken[:])).To(Equal("deadbeefdecafbad")) - }) - - It("errors when the Retire Prior To value is larger than the Sequence Number", func() { - data := []byte{0x18} - data = append(data, encodeVarInt(1000)...) // sequence number - data = append(data, encodeVarInt(1001)...) // retire prior to - data = append(data, 3) - data = append(data, []byte{1, 2, 3}...) - data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - b := bytes.NewReader(data) - _, err := parseNewConnectionIDFrame(b, protocol.Version1) - Expect(err).To(MatchError("Retire Prior To value (1001) larger than Sequence Number (1000)")) - }) - - It("errors when the connection ID has an invalid length", func() { - data := []byte{0x18} - data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number - data = append(data, encodeVarInt(0xcafe)...) // retire prior to - data = append(data, 21) // connection ID length - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}...) // connection ID - data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - b := bytes.NewReader(data) - _, err := parseNewConnectionIDFrame(b, protocol.Version1) - Expect(err).To(MatchError("invalid connection ID length: 21")) - }) - - It("errors on EOFs", func() { - data := []byte{0x18} - data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number - data = append(data, encodeVarInt(0xcafe1234)...) // retire prior to - data = append(data, 10) // connection ID length - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID - data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - _, err := parseNewConnectionIDFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseNewConnectionIDFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - token := protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - frame := &NewConnectionIDFrame{ - SequenceNumber: 0x1337, - RetirePriorTo: 0x42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, - StatelessResetToken: token, - } - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - expected := []byte{0x18} - expected = append(expected, encodeVarInt(0x1337)...) - expected = append(expected, encodeVarInt(0x42)...) - expected = append(expected, 6) - expected = append(expected, []byte{1, 2, 3, 4, 5, 6}...) - expected = append(expected, token[:]...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct length", func() { - token := protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - frame := &NewConnectionIDFrame{ - SequenceNumber: 0xdecafbad, - RetirePriorTo: 0xdeadbeefcafe, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - StatelessResetToken: token, - } - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) - }) - }) -}) diff --git a/internal/quic-go/wire/new_token_frame.go b/internal/quic-go/wire/new_token_frame.go deleted file mode 100644 index 3a44eb21..00000000 --- a/internal/quic-go/wire/new_token_frame.go +++ /dev/null @@ -1,48 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A NewTokenFrame is a NEW_TOKEN frame -type NewTokenFrame struct { - Token []byte -} - -func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - tokenLen, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - if uint64(r.Len()) < tokenLen { - return nil, io.EOF - } - if tokenLen == 0 { - return nil, errors.New("token must not be empty") - } - token := make([]byte, int(tokenLen)) - if _, err := io.ReadFull(r, token); err != nil { - return nil, err - } - return &NewTokenFrame{Token: token}, nil -} - -func (f *NewTokenFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x7) - quicvarint.Write(b, uint64(len(f.Token))) - b.Write(f.Token) - return nil -} - -// Length of a written frame -func (f *NewTokenFrame) Length(protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token)) -} diff --git a/internal/quic-go/wire/new_token_frame_test.go b/internal/quic-go/wire/new_token_frame_test.go deleted file mode 100644 index 3a3389c7..00000000 --- a/internal/quic-go/wire/new_token_frame_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("NEW_TOKEN frame", func() { - Context("parsing", func() { - It("accepts a sample frame", func() { - token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." - data := []byte{0x7} - data = append(data, encodeVarInt(uint64(len(token)))...) - data = append(data, token...) - b := bytes.NewReader(data) - f, err := parseNewTokenFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(string(f.Token)).To(Equal(token)) - Expect(b.Len()).To(BeZero()) - }) - - It("rejects empty tokens", func() { - data := []byte{0x7} - data = append(data, encodeVarInt(uint64(0))...) - b := bytes.NewReader(data) - _, err := parseNewTokenFrame(b, protocol.VersionWhatever) - Expect(err).To(MatchError("token must not be empty")) - }) - - It("errors on EOFs", func() { - token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit" - data := []byte{0x7} - data = append(data, encodeVarInt(uint64(len(token)))...) - data = append(data, token...) - _, err := parseNewTokenFrame(bytes.NewReader(data), protocol.VersionWhatever) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseNewTokenFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("writing", func() { - It("writes a sample frame", func() { - token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat." - f := &NewTokenFrame{Token: []byte(token)} - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - expected := []byte{0x7} - expected = append(expected, encodeVarInt(uint64(len(token)))...) - expected = append(expected, token...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - frame := &NewTokenFrame{Token: []byte("foobar")} - Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(6) + 6)) - }) - }) -}) diff --git a/internal/quic-go/wire/path_challenge_frame.go b/internal/quic-go/wire/path_challenge_frame.go deleted file mode 100644 index 5d802249..00000000 --- a/internal/quic-go/wire/path_challenge_frame.go +++ /dev/null @@ -1,38 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// A PathChallengeFrame is a PATH_CHALLENGE frame -type PathChallengeFrame struct { - Data [8]byte -} - -func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathChallengeFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - frame := &PathChallengeFrame{} - if _, err := io.ReadFull(r, frame.Data[:]); err != nil { - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return nil, err - } - return frame, nil -} - -func (f *PathChallengeFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x1a) - b.Write(f.Data[:]) - return nil -} - -// Length of a written frame -func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { - return 1 + 8 -} diff --git a/internal/quic-go/wire/path_challenge_frame_test.go b/internal/quic-go/wire/path_challenge_frame_test.go deleted file mode 100644 index 620e1b20..00000000 --- a/internal/quic-go/wire/path_challenge_frame_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("PATH_CHALLENGE frame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x1a, 1, 2, 3, 4, 5, 6, 7, 8}) - f, err := parsePathChallengeFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeZero()) - Expect(f.Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) - }) - - It("errors on EOFs", func() { - data := []byte{0x1a, 1, 2, 3, 4, 5, 6, 7, 8} - _, err := parsePathChallengeFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parsePathChallengeFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - b := &bytes.Buffer{} - frame := PathChallengeFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} - err := frame.Write(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x1a, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) - }) - - It("has the correct min length", func() { - frame := PathChallengeFrame{} - Expect(frame.Length(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(9))) - }) - }) -}) diff --git a/internal/quic-go/wire/path_response_frame.go b/internal/quic-go/wire/path_response_frame.go deleted file mode 100644 index d7334ac1..00000000 --- a/internal/quic-go/wire/path_response_frame.go +++ /dev/null @@ -1,38 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// A PathResponseFrame is a PATH_RESPONSE frame -type PathResponseFrame struct { - Data [8]byte -} - -func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathResponseFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - frame := &PathResponseFrame{} - if _, err := io.ReadFull(r, frame.Data[:]); err != nil { - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return nil, err - } - return frame, nil -} - -func (f *PathResponseFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x1b) - b.Write(f.Data[:]) - return nil -} - -// Length of a written frame -func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { - return 1 + 8 -} diff --git a/internal/quic-go/wire/path_response_frame_test.go b/internal/quic-go/wire/path_response_frame_test.go deleted file mode 100644 index 757a08f9..00000000 --- a/internal/quic-go/wire/path_response_frame_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("PATH_RESPONSE frame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x1b, 1, 2, 3, 4, 5, 6, 7, 8}) - f, err := parsePathResponseFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeZero()) - Expect(f.Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) - }) - - It("errors on EOFs", func() { - data := []byte{0x1b, 1, 2, 3, 4, 5, 6, 7, 8} - _, err := parsePathResponseFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parsePathResponseFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - b := &bytes.Buffer{} - frame := PathResponseFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} - err := frame.Write(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x1b, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) - }) - - It("has the correct min length", func() { - frame := PathResponseFrame{} - Expect(frame.Length(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(9))) - }) - }) -}) diff --git a/internal/quic-go/wire/ping_frame.go b/internal/quic-go/wire/ping_frame.go deleted file mode 100644 index d47d8ce9..00000000 --- a/internal/quic-go/wire/ping_frame.go +++ /dev/null @@ -1,27 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -// A PingFrame is a PING frame -type PingFrame struct{} - -func parsePingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PingFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - return &PingFrame{}, nil -} - -func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x1) - return nil -} - -// Length of a written frame -func (f *PingFrame) Length(version protocol.VersionNumber) protocol.ByteCount { - return 1 -} diff --git a/internal/quic-go/wire/ping_frame_test.go b/internal/quic-go/wire/ping_frame_test.go deleted file mode 100644 index cb9b2259..00000000 --- a/internal/quic-go/wire/ping_frame_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("PingFrame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x1}) - _, err := parsePingFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - _, err := parsePingFrame(bytes.NewReader(nil), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - b := &bytes.Buffer{} - frame := PingFrame{} - frame.Write(b, protocol.VersionWhatever) - Expect(b.Bytes()).To(Equal([]byte{0x1})) - }) - - It("has the correct min length", func() { - frame := PingFrame{} - Expect(frame.Length(0)).To(Equal(protocol.ByteCount(1))) - }) - }) -}) diff --git a/internal/quic-go/wire/pool.go b/internal/quic-go/wire/pool.go deleted file mode 100644 index 2fb1f82c..00000000 --- a/internal/quic-go/wire/pool.go +++ /dev/null @@ -1,33 +0,0 @@ -package wire - -import ( - "sync" - - "github.com/imroc/req/v3/internal/quic-go/protocol" -) - -var pool sync.Pool - -func init() { - pool.New = func() interface{} { - return &StreamFrame{ - Data: make([]byte, 0, protocol.MaxPacketBufferSize), - fromPool: true, - } - } -} - -func GetStreamFrame() *StreamFrame { - f := pool.Get().(*StreamFrame) - return f -} - -func putStreamFrame(f *StreamFrame) { - if !f.fromPool { - return - } - if protocol.ByteCount(cap(f.Data)) != protocol.MaxPacketBufferSize { - panic("wire.PutStreamFrame called with packet of wrong size!") - } - pool.Put(f) -} diff --git a/internal/quic-go/wire/pool_test.go b/internal/quic-go/wire/pool_test.go deleted file mode 100644 index b55e493b..00000000 --- a/internal/quic-go/wire/pool_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package wire - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Pool", func() { - It("gets and puts STREAM frames", func() { - f := GetStreamFrame() - putStreamFrame(f) - }) - - It("panics when putting a STREAM frame with a wrong capacity", func() { - f := GetStreamFrame() - f.Data = []byte("foobar") - Expect(func() { putStreamFrame(f) }).To(Panic()) - }) - - It("accepts STREAM frames not from the buffer, but ignores them", func() { - f := &StreamFrame{Data: []byte("foobar")} - putStreamFrame(f) - }) -}) diff --git a/internal/quic-go/wire/reset_stream_frame.go b/internal/quic-go/wire/reset_stream_frame.go deleted file mode 100644 index 29910473..00000000 --- a/internal/quic-go/wire/reset_stream_frame.go +++ /dev/null @@ -1,58 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A ResetStreamFrame is a RESET_STREAM frame in QUIC -type ResetStreamFrame struct { - StreamID protocol.StreamID - ErrorCode qerr.StreamErrorCode - FinalSize protocol.ByteCount -} - -func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStreamFrame, error) { - if _, err := r.ReadByte(); err != nil { // read the TypeByte - return nil, err - } - - var streamID protocol.StreamID - var byteOffset protocol.ByteCount - sid, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - streamID = protocol.StreamID(sid) - errorCode, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - bo, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - byteOffset = protocol.ByteCount(bo) - - return &ResetStreamFrame{ - StreamID: streamID, - ErrorCode: qerr.StreamErrorCode(errorCode), - FinalSize: byteOffset, - }, nil -} - -func (f *ResetStreamFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x4) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.ErrorCode)) - quicvarint.Write(b, uint64(f.FinalSize)) - return nil -} - -// Length of a written frame -func (f *ResetStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize)) -} diff --git a/internal/quic-go/wire/reset_stream_frame_test.go b/internal/quic-go/wire/reset_stream_frame_test.go deleted file mode 100644 index b60ba3f7..00000000 --- a/internal/quic-go/wire/reset_stream_frame_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("RESET_STREAM frame", func() { - Context("when parsing", func() { - It("accepts sample frame", func() { - data := []byte{0x4} - data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID - data = append(data, encodeVarInt(0x1337)...) // error code - data = append(data, encodeVarInt(0x987654321)...) // byte offset - b := bytes.NewReader(data) - frame, err := parseResetStreamFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) - Expect(frame.FinalSize).To(Equal(protocol.ByteCount(0x987654321))) - Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) - }) - - It("errors on EOFs", func() { - data := []byte{0x4} - data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID - data = append(data, encodeVarInt(0x1337)...) // error code - data = append(data, encodeVarInt(0x987654321)...) // byte offset - _, err := parseResetStreamFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseResetStreamFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - frame := ResetStreamFrame{ - StreamID: 0x1337, - FinalSize: 0x11223344decafbad, - ErrorCode: 0xcafe, - } - b := &bytes.Buffer{} - err := frame.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x4} - expected = append(expected, encodeVarInt(0x1337)...) - expected = append(expected, encodeVarInt(0xcafe)...) - expected = append(expected, encodeVarInt(0x11223344decafbad)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - rst := ResetStreamFrame{ - StreamID: 0x1337, - FinalSize: 0x1234567, - ErrorCode: 0xde, - } - expectedLen := 1 + quicvarint.Len(0x1337) + quicvarint.Len(0x1234567) + 2 - Expect(rst.Length(protocol.Version1)).To(Equal(expectedLen)) - }) - }) -}) diff --git a/internal/quic-go/wire/retire_connection_id_frame.go b/internal/quic-go/wire/retire_connection_id_frame.go deleted file mode 100644 index a7e09aab..00000000 --- a/internal/quic-go/wire/retire_connection_id_frame.go +++ /dev/null @@ -1,36 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame -type RetireConnectionIDFrame struct { - SequenceNumber uint64 -} - -func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*RetireConnectionIDFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - seq, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - return &RetireConnectionIDFrame{SequenceNumber: seq}, nil -} - -func (f *RetireConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x19) - quicvarint.Write(b, f.SequenceNumber) - return nil -} - -// Length of a written frame -func (f *RetireConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(f.SequenceNumber) -} diff --git a/internal/quic-go/wire/retire_connection_id_frame_test.go b/internal/quic-go/wire/retire_connection_id_frame_test.go deleted file mode 100644 index 2e531d34..00000000 --- a/internal/quic-go/wire/retire_connection_id_frame_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("NEW_CONNECTION_ID frame", func() { - Context("when parsing", func() { - It("accepts a sample frame", func() { - data := []byte{0x19} - data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number - b := bytes.NewReader(data) - frame, err := parseRetireConnectionIDFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) - }) - - It("errors on EOFs", func() { - data := []byte{0x18} - data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number - _, err := parseRetireConnectionIDFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseRetireConnectionIDFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - }) - - Context("when writing", func() { - It("writes a sample frame", func() { - frame := &RetireConnectionIDFrame{SequenceNumber: 0x1337} - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - expected := []byte{0x19} - expected = append(expected, encodeVarInt(0x1337)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct length", func() { - frame := &RetireConnectionIDFrame{SequenceNumber: 0xdecafbad} - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) - }) - }) -}) diff --git a/internal/quic-go/wire/stop_sending_frame.go b/internal/quic-go/wire/stop_sending_frame.go deleted file mode 100644 index 5e40c4e3..00000000 --- a/internal/quic-go/wire/stop_sending_frame.go +++ /dev/null @@ -1,48 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A StopSendingFrame is a STOP_SENDING frame -type StopSendingFrame struct { - StreamID protocol.StreamID - ErrorCode qerr.StreamErrorCode -} - -// parseStopSendingFrame parses a STOP_SENDING frame -func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - streamID, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - errorCode, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - - return &StopSendingFrame{ - StreamID: protocol.StreamID(streamID), - ErrorCode: qerr.StreamErrorCode(errorCode), - }, nil -} - -// Length of a written frame -func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) -} - -func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x5) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.ErrorCode)) - return nil -} diff --git a/internal/quic-go/wire/stop_sending_frame_test.go b/internal/quic-go/wire/stop_sending_frame_test.go deleted file mode 100644 index e36e75ee..00000000 --- a/internal/quic-go/wire/stop_sending_frame_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("STOP_SENDING frame", func() { - Context("when parsing", func() { - It("parses a sample frame", func() { - data := []byte{0x5} - data = append(data, encodeVarInt(0xdecafbad)...) // stream ID - data = append(data, encodeVarInt(0x1337)...) // error code - b := bytes.NewReader(data) - frame, err := parseStopSendingFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) - Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x5} - data = append(data, encodeVarInt(0xdecafbad)...) // stream ID - data = append(data, encodeVarInt(0x123456)...) // error code - _, err := parseStopSendingFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseStopSendingFrame(bytes.NewReader(data[:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("when writing", func() { - It("writes", func() { - frame := &StopSendingFrame{ - StreamID: 0xdeadbeefcafe, - ErrorCode: 0xdecafbad, - } - buf := &bytes.Buffer{} - Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) - expected := []byte{0x5} - expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) - expected = append(expected, encodeVarInt(0xdecafbad)...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - frame := &StopSendingFrame{ - StreamID: 0xdeadbeef, - ErrorCode: 0x1234567, - } - Expect(frame.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0xdeadbeef) + quicvarint.Len(0x1234567))) - }) - }) -}) diff --git a/internal/quic-go/wire/stream_data_blocked_frame.go b/internal/quic-go/wire/stream_data_blocked_frame.go deleted file mode 100644 index 447dc089..00000000 --- a/internal/quic-go/wire/stream_data_blocked_frame.go +++ /dev/null @@ -1,46 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame -type StreamDataBlockedFrame struct { - StreamID protocol.StreamID - MaximumStreamData protocol.ByteCount -} - -func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - sid, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - offset, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - - return &StreamDataBlockedFrame{ - StreamID: protocol.StreamID(sid), - MaximumStreamData: protocol.ByteCount(offset), - }, nil -} - -func (f *StreamDataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x15) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.MaximumStreamData)) - return nil -} - -// Length of a written frame -func (f *StreamDataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) -} diff --git a/internal/quic-go/wire/stream_data_blocked_frame_test.go b/internal/quic-go/wire/stream_data_blocked_frame_test.go deleted file mode 100644 index 6d5f50db..00000000 --- a/internal/quic-go/wire/stream_data_blocked_frame_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("STREAM_DATA_BLOCKED frame", func() { - Context("parsing", func() { - It("accepts sample frame", func() { - data := []byte{0x15} - data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID - data = append(data, encodeVarInt(0xdecafbad)...) // offset - b := bytes.NewReader(data) - frame, err := parseStreamDataBlockedFrame(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) - Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0xdecafbad))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x15} - data = append(data, encodeVarInt(0xdeadbeef)...) - data = append(data, encodeVarInt(0xc0010ff)...) - _, err := parseStreamDataBlockedFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseStreamDataBlockedFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("writing", func() { - It("has proper min length", func() { - f := &StreamDataBlockedFrame{ - StreamID: 0x1337, - MaximumStreamData: 0xdeadbeef, - } - Expect(f.Length(0)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0xdeadbeef))) - }) - - It("writes a sample frame", func() { - b := &bytes.Buffer{} - f := &StreamDataBlockedFrame{ - StreamID: 0xdecafbad, - MaximumStreamData: 0x1337, - } - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x15} - expected = append(expected, encodeVarInt(uint64(f.StreamID))...) - expected = append(expected, encodeVarInt(uint64(f.MaximumStreamData))...) - Expect(b.Bytes()).To(Equal(expected)) - }) - }) -}) diff --git a/internal/quic-go/wire/stream_frame.go b/internal/quic-go/wire/stream_frame.go deleted file mode 100644 index c4b9db48..00000000 --- a/internal/quic-go/wire/stream_frame.go +++ /dev/null @@ -1,189 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A StreamFrame of QUIC -type StreamFrame struct { - StreamID protocol.StreamID - Offset protocol.ByteCount - Data []byte - Fin bool - DataLenPresent bool - - fromPool bool -} - -func parseStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - hasOffset := typeByte&0x4 > 0 - fin := typeByte&0x1 > 0 - hasDataLen := typeByte&0x2 > 0 - - streamID, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - var offset uint64 - if hasOffset { - offset, err = quicvarint.Read(r) - if err != nil { - return nil, err - } - } - - var dataLen uint64 - if hasDataLen { - var err error - dataLen, err = quicvarint.Read(r) - if err != nil { - return nil, err - } - } else { - // The rest of the packet is data - dataLen = uint64(r.Len()) - } - - var frame *StreamFrame - if dataLen < protocol.MinStreamFrameBufferSize { - frame = &StreamFrame{Data: make([]byte, dataLen)} - } else { - frame = GetStreamFrame() - // The STREAM frame can't be larger than the StreamFrame we obtained from the buffer, - // since those StreamFrames have a buffer length of the maximum packet size. - if dataLen > uint64(cap(frame.Data)) { - return nil, io.EOF - } - frame.Data = frame.Data[:dataLen] - } - - frame.StreamID = protocol.StreamID(streamID) - frame.Offset = protocol.ByteCount(offset) - frame.Fin = fin - frame.DataLenPresent = hasDataLen - - if dataLen != 0 { - if _, err := io.ReadFull(r, frame.Data); err != nil { - return nil, err - } - } - if frame.Offset+frame.DataLen() > protocol.MaxByteCount { - return nil, errors.New("stream data overflows maximum offset") - } - return frame, nil -} - -// Write writes a STREAM frame -func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - if len(f.Data) == 0 && !f.Fin { - return errors.New("StreamFrame: attempting to write empty frame without FIN") - } - - typeByte := byte(0x8) - if f.Fin { - typeByte ^= 0x1 - } - hasOffset := f.Offset != 0 - if f.DataLenPresent { - typeByte ^= 0x2 - } - if hasOffset { - typeByte ^= 0x4 - } - b.WriteByte(typeByte) - quicvarint.Write(b, uint64(f.StreamID)) - if hasOffset { - quicvarint.Write(b, uint64(f.Offset)) - } - if f.DataLenPresent { - quicvarint.Write(b, uint64(f.DataLen())) - } - b.Write(f.Data) - return nil -} - -// Length returns the total length of the STREAM frame -func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { - length := 1 + quicvarint.Len(uint64(f.StreamID)) - if f.Offset != 0 { - length += quicvarint.Len(uint64(f.Offset)) - } - if f.DataLenPresent { - length += quicvarint.Len(uint64(f.DataLen())) - } - return length + f.DataLen() -} - -// DataLen gives the length of data in bytes -func (f *StreamFrame) DataLen() protocol.ByteCount { - return protocol.ByteCount(len(f.Data)) -} - -// MaxDataLen returns the maximum data length -// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data). -func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { - headerLen := 1 + quicvarint.Len(uint64(f.StreamID)) - if f.Offset != 0 { - headerLen += quicvarint.Len(uint64(f.Offset)) - } - if f.DataLenPresent { - // pretend that the data size will be 1 bytes - // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards - headerLen++ - } - if headerLen > maxSize { - return 0 - } - maxDataLen := maxSize - headerLen - if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { - maxDataLen-- - } - return maxDataLen -} - -// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. -// It returns if the frame was actually split. -// The frame might not be split if: -// * the size is large enough to fit the whole frame -// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. -func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, bool /* was splitting required */) { - if maxSize >= f.Length(version) { - return nil, false - } - - n := f.MaxDataLen(maxSize, version) - if n == 0 { - return nil, true - } - - new := GetStreamFrame() - new.StreamID = f.StreamID - new.Offset = f.Offset - new.Fin = false - new.DataLenPresent = f.DataLenPresent - - // swap the data slices - new.Data, f.Data = f.Data, new.Data - new.fromPool, f.fromPool = f.fromPool, new.fromPool - - f.Data = f.Data[:protocol.ByteCount(len(new.Data))-n] - copy(f.Data, new.Data[n:]) - new.Data = new.Data[:n] - f.Offset += n - - return new, true -} - -func (f *StreamFrame) PutBack() { - putStreamFrame(f) -} diff --git a/internal/quic-go/wire/stream_frame_test.go b/internal/quic-go/wire/stream_frame_test.go deleted file mode 100644 index a533a183..00000000 --- a/internal/quic-go/wire/stream_frame_test.go +++ /dev/null @@ -1,443 +0,0 @@ -package wire - -import ( - "bytes" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("STREAM frame", func() { - Context("when parsing", func() { - It("parses a frame with OFF bit", func() { - data := []byte{0x8 ^ 0x4} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, encodeVarInt(0xdecafbad)...) // offset - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) - Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(frame.Fin).To(BeFalse()) - Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) - Expect(r.Len()).To(BeZero()) - }) - - It("respects the LEN when parsing the frame", func() { - data := []byte{0x8 ^ 0x2} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, encodeVarInt(4)...) // data length - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) - Expect(frame.Data).To(Equal([]byte("foob"))) - Expect(frame.Fin).To(BeFalse()) - Expect(frame.Offset).To(BeZero()) - Expect(r.Len()).To(Equal(2)) - }) - - It("parses a frame with FIN bit", func() { - data := []byte{0x8 ^ 0x1} - data = append(data, encodeVarInt(9)...) // stream ID - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(9))) - Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(frame.Fin).To(BeTrue()) - Expect(frame.Offset).To(BeZero()) - Expect(r.Len()).To(BeZero()) - }) - - It("allows empty frames", func() { - data := []byte{0x8 ^ 0x4} - data = append(data, encodeVarInt(0x1337)...) // stream ID - data = append(data, encodeVarInt(0x12345)...) // offset - r := bytes.NewReader(data) - f, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(f.StreamID).To(Equal(protocol.StreamID(0x1337))) - Expect(f.Offset).To(Equal(protocol.ByteCount(0x12345))) - Expect(f.Data).To(BeEmpty()) - Expect(f.Fin).To(BeFalse()) - }) - - It("rejects frames that overflow the maximum offset", func() { - data := []byte{0x8 ^ 0x4} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, encodeVarInt(uint64(protocol.MaxByteCount-5))...) // offset - data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - _, err := parseStreamFrame(r, protocol.Version1) - Expect(err).To(MatchError("stream data overflows maximum offset")) - }) - - It("rejects frames that claim to be longer than the packet size", func() { - data := []byte{0x8 ^ 0x2} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, encodeVarInt(uint64(protocol.MaxPacketBufferSize)+1)...) // data length - data = append(data, make([]byte, protocol.MaxPacketBufferSize+1)...) - r := bytes.NewReader(data) - _, err := parseStreamFrame(r, protocol.Version1) - Expect(err).To(Equal(io.EOF)) - }) - - It("errors on EOFs", func() { - data := []byte{0x8 ^ 0x4 ^ 0x2} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, encodeVarInt(0xdecafbad)...) // offset - data = append(data, encodeVarInt(6)...) // data length - data = append(data, []byte("foobar")...) - _, err := parseStreamFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseStreamFrame(bytes.NewReader(data[0:i]), protocol.Version1) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("using the buffer", func() { - It("uses the buffer for long STREAM frames", func() { - data := []byte{0x8} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) - Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize))) - Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize)) - Expect(frame.Fin).To(BeFalse()) - Expect(frame.fromPool).To(BeTrue()) - Expect(r.Len()).To(BeZero()) - Expect(frame.PutBack).ToNot(Panic()) - }) - - It("doesn't use the buffer for short STREAM frames", func() { - data := []byte{0x8} - data = append(data, encodeVarInt(0x12345)...) // stream ID - data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) - Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1))) - Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize - 1)) - Expect(frame.Fin).To(BeFalse()) - Expect(frame.fromPool).To(BeFalse()) - Expect(r.Len()).To(BeZero()) - Expect(frame.PutBack).ToNot(Panic()) - }) - }) - - Context("when writing", func() { - It("writes a frame without offset", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Data: []byte("foobar"), - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x8} - expected = append(expected, encodeVarInt(0x1337)...) // stream ID - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with offset", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0x123456, - Data: []byte("foobar"), - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x8 ^ 0x4} - expected = append(expected, encodeVarInt(0x1337)...) // stream ID - expected = append(expected, encodeVarInt(0x123456)...) // offset - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with FIN bit", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0x123456, - Fin: true, - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x8 ^ 0x4 ^ 0x1} - expected = append(expected, encodeVarInt(0x1337)...) // stream ID - expected = append(expected, encodeVarInt(0x123456)...) // offset - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with data length", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Data: []byte("foobar"), - DataLenPresent: true, - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x8 ^ 0x2} - expected = append(expected, encodeVarInt(0x1337)...) // stream ID - expected = append(expected, encodeVarInt(6)...) // data length - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame with data length and offset", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Data: []byte("foobar"), - DataLenPresent: true, - Offset: 0x123456, - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - expected := []byte{0x8 ^ 0x4 ^ 0x2} - expected = append(expected, encodeVarInt(0x1337)...) // stream ID - expected = append(expected, encodeVarInt(0x123456)...) // offset - expected = append(expected, encodeVarInt(6)...) // data length - expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("refuses to write an empty frame without FIN", func() { - f := &StreamFrame{ - StreamID: 0x42, - Offset: 0x1337, - } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) - Expect(err).To(MatchError("StreamFrame: attempting to write empty frame without FIN")) - }) - }) - - Context("length", func() { - It("has the right length for a frame without offset and data length", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Data: []byte("foobar"), - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + 6)) - }) - - It("has the right length for a frame with offset", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0x42, - Data: []byte("foobar"), - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x42) + 6)) - }) - - It("has the right length for a frame with data length", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0x1234567, - DataLenPresent: true, - Data: []byte("foobar"), - } - Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x1234567) + quicvarint.Len(6) + 6)) - }) - }) - - Context("max data length", func() { - const maxSize = 3000 - - It("always returns a data length such that the resulting frame has the right size, if data length is not present", func() { - data := make([]byte, maxSize) - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0xdeadbeef, - } - b := &bytes.Buffer{} - for i := 1; i < 3000; i++ { - b.Reset() - f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) - if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written - // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size - f.Data = []byte{0} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeNumerically(">", i)) - continue - } - f.Data = data[:int(maxDataLen)] - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(Equal(i)) - } - }) - - It("always returns a data length such that the resulting frame has the right size, if data length is present", func() { - data := make([]byte, maxSize) - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0xdeadbeef, - DataLenPresent: true, - } - b := &bytes.Buffer{} - var frameOneByteTooSmallCounter int - for i := 1; i < 3000; i++ { - b.Reset() - f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) - if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written - // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size - f.Data = []byte{0} - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeNumerically(">", i)) - continue - } - f.Data = data[:int(maxDataLen)] - err := f.Write(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - // There's *one* pathological case, where a data length of x can be encoded into 1 byte - // but a data lengths of x+1 needs 2 bytes - // In that case, it's impossible to create a STREAM frame of the desired size - if b.Len() == i-1 { - frameOneByteTooSmallCounter++ - continue - } - Expect(b.Len()).To(Equal(i)) - } - Expect(frameOneByteTooSmallCounter).To(Equal(1)) - }) - }) - - Context("splitting", func() { - It("doesn't split if the frame is short enough", func() { - f := &StreamFrame{ - StreamID: 0x1337, - DataLenPresent: true, - Offset: 0xdeadbeef, - Data: make([]byte, 100), - } - frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) - Expect(needsSplit).To(BeFalse()) - Expect(frame).To(BeNil()) - Expect(f.DataLen()).To(BeEquivalentTo(100)) - frame, needsSplit = f.MaybeSplitOffFrame(f.Length(protocol.Version1)-1, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(frame.DataLen()).To(BeEquivalentTo(99)) - f.PutBack() - }) - - It("keeps the data len", func() { - f := &StreamFrame{ - StreamID: 0x1337, - DataLenPresent: true, - Data: make([]byte, 100), - } - frame, needsSplit := f.MaybeSplitOffFrame(66, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - Expect(f.DataLenPresent).To(BeTrue()) - Expect(frame.DataLenPresent).To(BeTrue()) - }) - - It("adjusts the offset", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Offset: 0x100, - Data: []byte("foobar"), - } - frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1)-3, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - Expect(frame.Offset).To(Equal(protocol.ByteCount(0x100))) - Expect(frame.Data).To(Equal([]byte("foo"))) - Expect(f.Offset).To(Equal(protocol.ByteCount(0x100 + 3))) - Expect(f.Data).To(Equal([]byte("bar"))) - }) - - It("preserves the FIN bit", func() { - f := &StreamFrame{ - StreamID: 0x1337, - Fin: true, - Offset: 0xdeadbeef, - Data: make([]byte, 100), - } - frame, needsSplit := f.MaybeSplitOffFrame(50, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - Expect(frame.Offset).To(BeNumerically("<", f.Offset)) - Expect(f.Fin).To(BeTrue()) - Expect(frame.Fin).To(BeFalse()) - }) - - It("produces frames of the correct length, without data len", func() { - const size = 1000 - f := &StreamFrame{ - StreamID: 0xdecafbad, - Offset: 0x1234, - Data: []byte{0}, - } - minFrameSize := f.Length(protocol.Version1) - for i := protocol.ByteCount(0); i < minFrameSize; i++ { - f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(f).To(BeNil()) - } - for i := minFrameSize; i < size; i++ { - f.fromPool = false - f.Data = make([]byte, size) - f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(f.Length(protocol.Version1)).To(Equal(i)) - } - }) - - It("produces frames of the correct length, with data len", func() { - const size = 1000 - f := &StreamFrame{ - StreamID: 0xdecafbad, - Offset: 0x1234, - DataLenPresent: true, - Data: []byte{0}, - } - minFrameSize := f.Length(protocol.Version1) - for i := protocol.ByteCount(0); i < minFrameSize; i++ { - f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - Expect(f).To(BeNil()) - } - var frameOneByteTooSmallCounter int - for i := minFrameSize; i < size; i++ { - f.fromPool = false - f.Data = make([]byte, size) - newFrame, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) - Expect(needsSplit).To(BeTrue()) - // There's *one* pathological case, where a data length of x can be encoded into 1 byte - // but a data lengths of x+1 needs 2 bytes - // In that case, it's impossible to create a STREAM frame of the desired size - if newFrame.Length(protocol.Version1) == i-1 { - frameOneByteTooSmallCounter++ - continue - } - Expect(newFrame.Length(protocol.Version1)).To(Equal(i)) - } - Expect(frameOneByteTooSmallCounter).To(Equal(1)) - }) - }) -}) diff --git a/internal/quic-go/wire/streams_blocked_frame.go b/internal/quic-go/wire/streams_blocked_frame.go deleted file mode 100644 index aab28c24..00000000 --- a/internal/quic-go/wire/streams_blocked_frame.go +++ /dev/null @@ -1,55 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" -) - -// A StreamsBlockedFrame is a STREAMS_BLOCKED frame -type StreamsBlockedFrame struct { - Type protocol.StreamType - StreamLimit protocol.StreamNum -} - -func parseStreamsBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - f := &StreamsBlockedFrame{} - switch typeByte { - case 0x16: - f.Type = protocol.StreamTypeBidi - case 0x17: - f.Type = protocol.StreamTypeUni - } - streamLimit, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - f.StreamLimit = protocol.StreamNum(streamLimit) - if f.StreamLimit > protocol.MaxStreamCount { - return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit) - } - return f, nil -} - -func (f *StreamsBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - switch f.Type { - case protocol.StreamTypeBidi: - b.WriteByte(0x16) - case protocol.StreamTypeUni: - b.WriteByte(0x17) - } - quicvarint.Write(b, uint64(f.StreamLimit)) - return nil -} - -// Length of a written frame -func (f *StreamsBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { - return 1 + quicvarint.Len(uint64(f.StreamLimit)) -} diff --git a/internal/quic-go/wire/streams_blocked_frame_test.go b/internal/quic-go/wire/streams_blocked_frame_test.go deleted file mode 100644 index 39b29dad..00000000 --- a/internal/quic-go/wire/streams_blocked_frame_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - "io" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("STREAMS_BLOCKED frame", func() { - Context("parsing", func() { - It("accepts a frame for bidirectional streams", func() { - expected := []byte{0x16} - expected = append(expected, encodeVarInt(0x1337)...) - b := bytes.NewReader(expected) - f, err := parseStreamsBlockedFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Type).To(Equal(protocol.StreamTypeBidi)) - Expect(f.StreamLimit).To(BeEquivalentTo(0x1337)) - Expect(b.Len()).To(BeZero()) - }) - - It("accepts a frame for unidirectional streams", func() { - expected := []byte{0x17} - expected = append(expected, encodeVarInt(0x7331)...) - b := bytes.NewReader(expected) - f, err := parseStreamsBlockedFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.Type).To(Equal(protocol.StreamTypeUni)) - Expect(f.StreamLimit).To(BeEquivalentTo(0x7331)) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x16} - data = append(data, encodeVarInt(0x12345678)...) - _, err := parseStreamsBlockedFrame(bytes.NewReader(data), protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - for i := range data { - _, err := parseStreamsBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) - Expect(err).To(MatchError(io.EOF)) - } - }) - - for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { - streamType := t - - It("accepts a frame containing the maximum stream count", func() { - f := &StreamsBlockedFrame{ - Type: streamType, - StreamLimit: protocol.MaxStreamCount, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - frame, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - }) - - It("errors when receiving a too large stream count", func() { - f := &StreamsBlockedFrame{ - Type: streamType, - StreamLimit: protocol.MaxStreamCount + 1, - } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - _, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) - Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) - }) - } - }) - - Context("writing", func() { - It("writes a frame for bidirectional streams", func() { - b := &bytes.Buffer{} - f := StreamsBlockedFrame{ - Type: protocol.StreamTypeBidi, - StreamLimit: 0xdeadbeefcafe, - } - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - expected := []byte{0x16} - expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("writes a frame for unidirectional streams", func() { - b := &bytes.Buffer{} - f := StreamsBlockedFrame{ - Type: protocol.StreamTypeUni, - StreamLimit: 0xdeadbeefcafe, - } - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - expected := []byte{0x17} - expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - frame := StreamsBlockedFrame{StreamLimit: 0x123456} - Expect(frame.Length(0)).To(Equal(protocol.ByteCount(1) + quicvarint.Len(0x123456))) - }) - }) -}) diff --git a/internal/quic-go/wire/transport_parameter_test.go b/internal/quic-go/wire/transport_parameter_test.go deleted file mode 100644 index 56d43fad..00000000 --- a/internal/quic-go/wire/transport_parameter_test.go +++ /dev/null @@ -1,612 +0,0 @@ -package wire - -import ( - "bytes" - "fmt" - "math" - "math/rand" - "net" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Transport Parameters", func() { - getRandomValueUpTo := func(max int64) uint64 { - maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} - m := maxVals[int(rand.Int31n(4))] - if m > max { - m = max - } - return uint64(rand.Int63n(m)) - } - - getRandomValue := func() uint64 { - return getRandomValueUpTo(math.MaxInt64) - } - - BeforeEach(func() { - rand.Seed(GinkgoRandomSeed()) - }) - - addInitialSourceConnectionID := func(b *bytes.Buffer) { - quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) - quicvarint.Write(b, 6) - b.Write([]byte("foobar")) - } - - It("has a string representation", func() { - p := &TransportParameters{ - InitialMaxStreamDataBidiLocal: 1234, - InitialMaxStreamDataBidiRemote: 2345, - InitialMaxStreamDataUni: 3456, - InitialMaxData: 4567, - MaxBidiStreamNum: 1337, - MaxUniStreamNum: 7331, - MaxIdleTimeout: 42 * time.Second, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - AckDelayExponent: 14, - MaxAckDelay: 37 * time.Millisecond, - StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, - ActiveConnectionIDLimit: 123, - MaxDatagramFrameSize: 876, - } - Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: deadbeef, InitialSourceConnectionID: decafbad, RetrySourceConnectionID: deadc0de, InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37ms, ActiveConnectionIDLimit: 123, StatelessResetToken: 0x112233445566778899aabbccddeeff00, MaxDatagramFrameSize: 876}")) - }) - - It("has a string representation, if there's no stateless reset token, no Retry source connection id and no datagram support", func() { - p := &TransportParameters{ - InitialMaxStreamDataBidiLocal: 1234, - InitialMaxStreamDataBidiRemote: 2345, - InitialMaxStreamDataUni: 3456, - InitialMaxData: 4567, - MaxBidiStreamNum: 1337, - MaxUniStreamNum: 7331, - MaxIdleTimeout: 42 * time.Second, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{}, - AckDelayExponent: 14, - MaxAckDelay: 37 * time.Second, - ActiveConnectionIDLimit: 89, - MaxDatagramFrameSize: protocol.InvalidByteCount, - } - Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: deadbeef, InitialSourceConnectionID: (empty), InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37s, ActiveConnectionIDLimit: 89}")) - }) - - It("marshals and unmarshals", func() { - var token protocol.StatelessResetToken - rand.Read(token[:]) - params := &TransportParameters{ - InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), - InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), - InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), - InitialMaxData: protocol.ByteCount(getRandomValue()), - MaxIdleTimeout: 0xcafe * time.Second, - MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), - MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), - DisableActiveMigration: true, - StatelessResetToken: &token, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - AckDelayExponent: 13, - MaxAckDelay: 42 * time.Millisecond, - ActiveConnectionIDLimit: getRandomValue(), - MaxDatagramFrameSize: protocol.ByteCount(getRandomValue()), - } - data := params.Marshal(protocol.PerspectiveServer) - - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) - Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) - Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) - Expect(p.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) - Expect(p.InitialMaxData).To(Equal(params.InitialMaxData)) - Expect(p.MaxUniStreamNum).To(Equal(params.MaxUniStreamNum)) - Expect(p.MaxBidiStreamNum).To(Equal(params.MaxBidiStreamNum)) - Expect(p.MaxIdleTimeout).To(Equal(params.MaxIdleTimeout)) - Expect(p.DisableActiveMigration).To(Equal(params.DisableActiveMigration)) - Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken)) - Expect(p.OriginalDestinationConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(p.InitialSourceConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) - Expect(p.RetrySourceConnectionID).To(Equal(&protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) - Expect(p.AckDelayExponent).To(Equal(uint8(13))) - Expect(p.MaxAckDelay).To(Equal(42 * time.Millisecond)) - Expect(p.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) - Expect(p.MaxDatagramFrameSize).To(Equal(params.MaxDatagramFrameSize)) - }) - - It("doesn't marshal a retry_source_connection_id, if no Retry was performed", func() { - data := (&TransportParameters{ - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) - Expect(p.RetrySourceConnectionID).To(BeNil()) - }) - - It("marshals a zero-length retry_source_connection_id", func() { - data := (&TransportParameters{ - RetrySourceConnectionID: &protocol.ConnectionID{}, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) - Expect(p.RetrySourceConnectionID).ToNot(BeNil()) - Expect(p.RetrySourceConnectionID.Len()).To(BeZero()) - }) - - It("errors when the stateless_reset_token has the wrong length", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(statelessResetTokenParameterID)) - quicvarint.Write(b, 15) - b.Write(make([]byte, 15)) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "wrong length for stateless_reset_token: 15 (expected 16)", - })) - }) - - It("errors when the max_packet_size is too small", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(maxUDPPayloadSizeParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(1199))) - quicvarint.Write(b, 1199) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid value for max_packet_size: 1199 (minimum 1200)", - })) - }) - - It("errors when disable_active_migration has content", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) - quicvarint.Write(b, 6) - b.Write([]byte("foobar")) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "wrong length for disable_active_migration: 6 (expected empty)", - })) - }) - - It("errors when the server doesn't set the original_destination_connection_id", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(statelessResetTokenParameterID)) - quicvarint.Write(b, 16) - b.Write(make([]byte, 16)) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "missing original_destination_connection_id", - })) - }) - - It("errors when the initial_source_connection_id is missing", func() { - Expect((&TransportParameters{}).Unmarshal([]byte{}, protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "missing initial_source_connection_id", - })) - }) - - It("errors when the max_ack_delay is too large", func() { - data := (&TransportParameters{ - MaxAckDelay: 1 << 14 * time.Millisecond, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid value for max_ack_delay: 16384ms (maximum 16383ms)", - })) - }) - - It("doesn't send the max_ack_delay, if it has the default value", func() { - const num = 1000 - var defaultLen, dataLen int - // marshal 1000 times to average out the greasing transport parameter - maxAckDelay := protocol.DefaultMaxAckDelay + time.Millisecond - for i := 0; i < num; i++ { - dataDefault := (&TransportParameters{ - MaxAckDelay: protocol.DefaultMaxAckDelay, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - defaultLen += len(dataDefault) - data := (&TransportParameters{ - MaxAckDelay: maxAckDelay, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - dataLen += len(data) - } - entryLen := quicvarint.Len(uint64(ackDelayExponentParameterID)) /* parameter id */ + quicvarint.Len(uint64(quicvarint.Len(uint64(maxAckDelay.Milliseconds())))) /*length */ + quicvarint.Len(uint64(maxAckDelay.Milliseconds())) /* value */ - Expect(float32(dataLen) / num).To(BeNumerically("~", float32(defaultLen)/num+float32(entryLen), 1)) - }) - - It("errors when the ack_delay_exponenent is too large", func() { - data := (&TransportParameters{ - AckDelayExponent: 21, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid value for ack_delay_exponent: 21 (maximum 20)", - })) - }) - - It("doesn't send the ack_delay_exponent, if it has the default value", func() { - const num = 1000 - var defaultLen, dataLen int - // marshal 1000 times to average out the greasing transport parameter - for i := 0; i < num; i++ { - dataDefault := (&TransportParameters{ - AckDelayExponent: protocol.DefaultAckDelayExponent, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - defaultLen += len(dataDefault) - data := (&TransportParameters{ - AckDelayExponent: protocol.DefaultAckDelayExponent + 1, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - dataLen += len(data) - } - entryLen := quicvarint.Len(uint64(ackDelayExponentParameterID)) /* parameter id */ + quicvarint.Len(uint64(quicvarint.Len(protocol.DefaultAckDelayExponent+1))) /* length */ + quicvarint.Len(protocol.DefaultAckDelayExponent+1) /* value */ - Expect(float32(dataLen) / num).To(BeNumerically("~", float32(defaultLen)/num+float32(entryLen), 1)) - }) - - It("sets the default value for the ack_delay_exponent, when no value was sent", func() { - data := (&TransportParameters{ - AckDelayExponent: protocol.DefaultAckDelayExponent, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) - Expect(p.AckDelayExponent).To(BeEquivalentTo(protocol.DefaultAckDelayExponent)) - }) - - It("errors when the varint value has the wrong length", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) - quicvarint.Write(b, 2) - val := uint64(0xdeadbeef) - Expect(quicvarint.Len(val)).ToNot(BeEquivalentTo(2)) - quicvarint.Write(b, val) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: fmt.Sprintf("inconsistent transport parameter length for transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), - })) - }) - - It("errors if initial_max_streams_bidi is too large", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(initialMaxStreamsBidiParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) - quicvarint.Write(b, uint64(protocol.MaxStreamCount+1)) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "initial_max_streams_bidi too large: 1152921504606846977 (maximum 1152921504606846976)", - })) - }) - - It("errors if initial_max_streams_uni is too large", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(initialMaxStreamsUniParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) - quicvarint.Write(b, uint64(protocol.MaxStreamCount+1)) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "initial_max_streams_uni too large: 1152921504606846977 (maximum 1152921504606846976)", - })) - }) - - It("handles huge max_ack_delay values", func() { - b := &bytes.Buffer{} - val := uint64(math.MaxUint64) / 5 - quicvarint.Write(b, uint64(maxAckDelayParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(val))) - quicvarint.Write(b, val) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid value for max_ack_delay: 3689348814741910323ms (maximum 16383ms)", - })) - }) - - It("skips unknown parameters", func() { - b := &bytes.Buffer{} - // write a known parameter - quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) - quicvarint.Write(b, 0x1337) - // write an unknown parameter - quicvarint.Write(b, 0x42) - quicvarint.Write(b, 6) - b.Write([]byte("foobar")) - // write a known parameter - quicvarint.Write(b, uint64(initialMaxStreamDataBidiRemoteParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(0x42))) - quicvarint.Write(b, 0x42) - addInitialSourceConnectionID(b) - p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(Succeed()) - Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(protocol.ByteCount(0x1337))) - Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(protocol.ByteCount(0x42))) - }) - - It("rejects duplicate parameters", func() { - b := &bytes.Buffer{} - // write first parameter - quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) - quicvarint.Write(b, 0x1337) - // write a second parameter - quicvarint.Write(b, uint64(initialMaxStreamDataBidiRemoteParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(0x42))) - quicvarint.Write(b, 0x42) - // write first parameter again - quicvarint.Write(b, uint64(initialMaxStreamDataBidiLocalParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) - quicvarint.Write(b, 0x1337) - addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: fmt.Sprintf("received duplicate transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), - })) - }) - - It("errors if there's not enough data to read", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, 0x42) - quicvarint.Write(b, 7) - b.Write([]byte("foobar")) - p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "remaining length (6) smaller than parameter length (7)", - })) - }) - - It("errors if the client sent a stateless_reset_token", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(statelessResetTokenParameterID)) - quicvarint.Write(b, uint64(quicvarint.Len(16))) - b.Write(make([]byte, 16)) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "client sent a stateless_reset_token", - })) - }) - - It("errors if the client sent the original_destination_connection_id", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) - quicvarint.Write(b, 6) - b.Write([]byte("foobar")) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "client sent an original_destination_connection_id", - })) - }) - - Context("preferred address", func() { - var pa *PreferredAddress - - BeforeEach(func() { - pa = &PreferredAddress{ - IPv4: net.IPv4(127, 0, 0, 1), - IPv4Port: 42, - IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - IPv6Port: 13, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, - } - }) - - It("marshals and unmarshals", func() { - data := (&TransportParameters{ - PreferredAddress: pa, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) - Expect(p.PreferredAddress.IPv4.String()).To(Equal(pa.IPv4.String())) - Expect(p.PreferredAddress.IPv4Port).To(Equal(pa.IPv4Port)) - Expect(p.PreferredAddress.IPv6.String()).To(Equal(pa.IPv6.String())) - Expect(p.PreferredAddress.IPv6Port).To(Equal(pa.IPv6Port)) - Expect(p.PreferredAddress.ConnectionID).To(Equal(pa.ConnectionID)) - Expect(p.PreferredAddress.StatelessResetToken).To(Equal(pa.StatelessResetToken)) - }) - - It("errors if the client sent a preferred_address", func() { - b := &bytes.Buffer{} - quicvarint.Write(b, uint64(preferredAddressParameterID)) - quicvarint.Write(b, 6) - b.Write([]byte("foobar")) - p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "client sent a preferred_address", - })) - }) - - It("errors on zero-length connection IDs", func() { - pa.ConnectionID = protocol.ConnectionID{} - data := (&TransportParameters{ - PreferredAddress: pa, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid connection ID length: 0", - })) - }) - - It("errors on too long connection IDs", func() { - pa.ConnectionID = protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21} - Expect(pa.ConnectionID.Len()).To(BeNumerically(">", protocol.MaxConnIDLen)) - data := (&TransportParameters{ - PreferredAddress: pa, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid connection ID length: 21", - })) - }) - - It("errors on EOF", func() { - raw := []byte{ - 127, 0, 0, 1, // IPv4 - 0, 42, // IPv4 Port - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // IPv6 - 13, 37, // IPv6 Port, - 4, // conn ID len - 0xde, 0xad, 0xbe, 0xef, - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, // stateless reset token - } - for i := 1; i < len(raw); i++ { - buf := &bytes.Buffer{} - quicvarint.Write(buf, uint64(preferredAddressParameterID)) - buf.Write(raw[:i]) - p := &TransportParameters{} - Expect(p.Unmarshal(buf.Bytes(), protocol.PerspectiveServer)).ToNot(Succeed()) - } - }) - }) - - Context("saving and retrieving from a session ticket", func() { - It("saves and retrieves the parameters", func() { - params := &TransportParameters{ - InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), - InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), - InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), - InitialMaxData: protocol.ByteCount(getRandomValue()), - MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), - MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), - ActiveConnectionIDLimit: getRandomValue(), - } - Expect(params.ValidFor0RTT(params)).To(BeTrue()) - b := &bytes.Buffer{} - params.MarshalForSessionTicket(b) - var tp TransportParameters - Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(Succeed()) - Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) - Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) - Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) - Expect(tp.InitialMaxData).To(Equal(params.InitialMaxData)) - Expect(tp.MaxBidiStreamNum).To(Equal(params.MaxBidiStreamNum)) - Expect(tp.MaxUniStreamNum).To(Equal(params.MaxUniStreamNum)) - Expect(tp.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) - }) - - It("rejects the parameters if it can't parse them", func() { - var p TransportParameters - Expect(p.UnmarshalFromSessionTicket(bytes.NewReader([]byte("foobar")))).ToNot(Succeed()) - }) - - It("rejects the parameters if the version changed", func() { - var p TransportParameters - buf := &bytes.Buffer{} - p.MarshalForSessionTicket(buf) - data := buf.Bytes() - b := &bytes.Buffer{} - quicvarint.Write(b, transportParameterMarshalingVersion+1) - b.Write(data[quicvarint.Len(transportParameterMarshalingVersion):]) - Expect(p.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1))) - }) - - Context("rejects the parameters if they changed", func() { - var p TransportParameters - saved := &TransportParameters{ - InitialMaxStreamDataBidiLocal: 1, - InitialMaxStreamDataBidiRemote: 2, - InitialMaxStreamDataUni: 3, - InitialMaxData: 4, - MaxBidiStreamNum: 5, - MaxUniStreamNum: 6, - ActiveConnectionIDLimit: 7, - } - - BeforeEach(func() { - p = *saved - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the InitialMaxStreamDataBidiLocal was reduced", func() { - p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("doesn't reject the parameters if the InitialMaxStreamDataBidiLocal was increased", func() { - p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the InitialMaxStreamDataBidiRemote was reduced", func() { - p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("doesn't reject the parameters if the InitialMaxStreamDataBidiRemote was increased", func() { - p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the InitialMaxStreamDataUni was reduced", func() { - p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("doesn't reject the parameters if the InitialMaxStreamDataUni was increased", func() { - p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the InitialMaxData was reduced", func() { - p.InitialMaxData = saved.InitialMaxData - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("doesn't reject the parameters if the InitialMaxData was increased", func() { - p.InitialMaxData = saved.InitialMaxData + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the MaxBidiStreamNum was reduced", func() { - p.MaxBidiStreamNum = saved.MaxBidiStreamNum - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("accepts the parameters if the MaxBidiStreamNum was increased", func() { - p.MaxBidiStreamNum = saved.MaxBidiStreamNum + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the MaxUniStreamNum changed", func() { - p.MaxUniStreamNum = saved.MaxUniStreamNum - 1 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - - It("accepts the parameters if the MaxUniStreamNum was increased", func() { - p.MaxUniStreamNum = saved.MaxUniStreamNum + 1 - Expect(p.ValidFor0RTT(saved)).To(BeTrue()) - }) - - It("rejects the parameters if the ActiveConnectionIDLimit changed", func() { - p.ActiveConnectionIDLimit = 0 - Expect(p.ValidFor0RTT(saved)).To(BeFalse()) - }) - }) - }) -}) diff --git a/internal/quic-go/wire/transport_parameters.go b/internal/quic-go/wire/transport_parameters.go deleted file mode 100644 index 544e3506..00000000 --- a/internal/quic-go/wire/transport_parameters.go +++ /dev/null @@ -1,476 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "fmt" - "io" - "math/rand" - "net" - "sort" - "time" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/qerr" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -const transportParameterMarshalingVersion = 1 - -func init() { - rand.Seed(time.Now().UTC().UnixNano()) -} - -type transportParameterID uint64 - -const ( - originalDestinationConnectionIDParameterID transportParameterID = 0x0 - maxIdleTimeoutParameterID transportParameterID = 0x1 - statelessResetTokenParameterID transportParameterID = 0x2 - maxUDPPayloadSizeParameterID transportParameterID = 0x3 - initialMaxDataParameterID transportParameterID = 0x4 - initialMaxStreamDataBidiLocalParameterID transportParameterID = 0x5 - initialMaxStreamDataBidiRemoteParameterID transportParameterID = 0x6 - initialMaxStreamDataUniParameterID transportParameterID = 0x7 - initialMaxStreamsBidiParameterID transportParameterID = 0x8 - initialMaxStreamsUniParameterID transportParameterID = 0x9 - ackDelayExponentParameterID transportParameterID = 0xa - maxAckDelayParameterID transportParameterID = 0xb - disableActiveMigrationParameterID transportParameterID = 0xc - preferredAddressParameterID transportParameterID = 0xd - activeConnectionIDLimitParameterID transportParameterID = 0xe - initialSourceConnectionIDParameterID transportParameterID = 0xf - retrySourceConnectionIDParameterID transportParameterID = 0x10 - // RFC 9221 - maxDatagramFrameSizeParameterID transportParameterID = 0x20 -) - -// PreferredAddress is the value encoding in the preferred_address transport parameter -type PreferredAddress struct { - IPv4 net.IP - IPv4Port uint16 - IPv6 net.IP - IPv6Port uint16 - ConnectionID protocol.ConnectionID - StatelessResetToken protocol.StatelessResetToken -} - -// TransportParameters are parameters sent to the peer during the handshake -type TransportParameters struct { - InitialMaxStreamDataBidiLocal protocol.ByteCount - InitialMaxStreamDataBidiRemote protocol.ByteCount - InitialMaxStreamDataUni protocol.ByteCount - InitialMaxData protocol.ByteCount - - MaxAckDelay time.Duration - AckDelayExponent uint8 - - DisableActiveMigration bool - - MaxUDPPayloadSize protocol.ByteCount - - MaxUniStreamNum protocol.StreamNum - MaxBidiStreamNum protocol.StreamNum - - MaxIdleTimeout time.Duration - - PreferredAddress *PreferredAddress - - OriginalDestinationConnectionID protocol.ConnectionID - InitialSourceConnectionID protocol.ConnectionID - RetrySourceConnectionID *protocol.ConnectionID // use a pointer here to distinguish zero-length connection IDs from missing transport parameters - - StatelessResetToken *protocol.StatelessResetToken - ActiveConnectionIDLimit uint64 - - MaxDatagramFrameSize protocol.ByteCount -} - -// Unmarshal the transport parameters -func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error { - if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil { - return &qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: err.Error(), - } - } - return nil -} - -func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error { - // needed to check that every parameter is only sent at most once - var parameterIDs []transportParameterID - - var ( - readOriginalDestinationConnectionID bool - readInitialSourceConnectionID bool - ) - - p.AckDelayExponent = protocol.DefaultAckDelayExponent - p.MaxAckDelay = protocol.DefaultMaxAckDelay - p.MaxDatagramFrameSize = protocol.InvalidByteCount - - for r.Len() > 0 { - paramIDInt, err := quicvarint.Read(r) - if err != nil { - return err - } - paramID := transportParameterID(paramIDInt) - paramLen, err := quicvarint.Read(r) - if err != nil { - return err - } - if uint64(r.Len()) < paramLen { - return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen) - } - parameterIDs = append(parameterIDs, paramID) - switch paramID { - case maxIdleTimeoutParameterID, - maxUDPPayloadSizeParameterID, - initialMaxDataParameterID, - initialMaxStreamDataBidiLocalParameterID, - initialMaxStreamDataBidiRemoteParameterID, - initialMaxStreamDataUniParameterID, - initialMaxStreamsBidiParameterID, - initialMaxStreamsUniParameterID, - maxAckDelayParameterID, - activeConnectionIDLimitParameterID, - maxDatagramFrameSizeParameterID, - ackDelayExponentParameterID: - if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { - return err - } - case preferredAddressParameterID: - if sentBy == protocol.PerspectiveClient { - return errors.New("client sent a preferred_address") - } - if err := p.readPreferredAddress(r, int(paramLen)); err != nil { - return err - } - case disableActiveMigrationParameterID: - if paramLen != 0 { - return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen) - } - p.DisableActiveMigration = true - case statelessResetTokenParameterID: - if sentBy == protocol.PerspectiveClient { - return errors.New("client sent a stateless_reset_token") - } - if paramLen != 16 { - return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen) - } - var token protocol.StatelessResetToken - r.Read(token[:]) - p.StatelessResetToken = &token - case originalDestinationConnectionIDParameterID: - if sentBy == protocol.PerspectiveClient { - return errors.New("client sent an original_destination_connection_id") - } - p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) - readOriginalDestinationConnectionID = true - case initialSourceConnectionIDParameterID: - p.InitialSourceConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) - readInitialSourceConnectionID = true - case retrySourceConnectionIDParameterID: - if sentBy == protocol.PerspectiveClient { - return errors.New("client sent a retry_source_connection_id") - } - connID, _ := protocol.ReadConnectionID(r, int(paramLen)) - p.RetrySourceConnectionID = &connID - default: - r.Seek(int64(paramLen), io.SeekCurrent) - } - } - - if !fromSessionTicket { - if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID { - return errors.New("missing original_destination_connection_id") - } - if p.MaxUDPPayloadSize == 0 { - p.MaxUDPPayloadSize = protocol.MaxByteCount - } - if !readInitialSourceConnectionID { - return errors.New("missing initial_source_connection_id") - } - } - - // check that every transport parameter was sent at most once - sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] }) - for i := 0; i < len(parameterIDs)-1; i++ { - if parameterIDs[i] == parameterIDs[i+1] { - return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i]) - } - } - - return nil -} - -func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error { - remainingLen := r.Len() - pa := &PreferredAddress{} - ipv4 := make([]byte, 4) - if _, err := io.ReadFull(r, ipv4); err != nil { - return err - } - pa.IPv4 = net.IP(ipv4) - port, err := utils.BigEndian.ReadUint16(r) - if err != nil { - return err - } - pa.IPv4Port = port - ipv6 := make([]byte, 16) - if _, err := io.ReadFull(r, ipv6); err != nil { - return err - } - pa.IPv6 = net.IP(ipv6) - port, err = utils.BigEndian.ReadUint16(r) - if err != nil { - return err - } - pa.IPv6Port = port - connIDLen, err := r.ReadByte() - if err != nil { - return err - } - if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d", connIDLen) - } - connID, err := protocol.ReadConnectionID(r, int(connIDLen)) - if err != nil { - return err - } - pa.ConnectionID = connID - if _, err := io.ReadFull(r, pa.StatelessResetToken[:]); err != nil { - return err - } - if bytesRead := remainingLen - r.Len(); bytesRead != expectedLen { - return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead) - } - p.PreferredAddress = pa - return nil -} - -func (p *TransportParameters) readNumericTransportParameter( - r *bytes.Reader, - paramID transportParameterID, - expectedLen int, -) error { - remainingLen := r.Len() - val, err := quicvarint.Read(r) - if err != nil { - return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err) - } - if remainingLen-r.Len() != expectedLen { - return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID) - } - //nolint:exhaustive // This only covers the numeric transport parameters. - switch paramID { - case initialMaxStreamDataBidiLocalParameterID: - p.InitialMaxStreamDataBidiLocal = protocol.ByteCount(val) - case initialMaxStreamDataBidiRemoteParameterID: - p.InitialMaxStreamDataBidiRemote = protocol.ByteCount(val) - case initialMaxStreamDataUniParameterID: - p.InitialMaxStreamDataUni = protocol.ByteCount(val) - case initialMaxDataParameterID: - p.InitialMaxData = protocol.ByteCount(val) - case initialMaxStreamsBidiParameterID: - p.MaxBidiStreamNum = protocol.StreamNum(val) - if p.MaxBidiStreamNum > protocol.MaxStreamCount { - return fmt.Errorf("initial_max_streams_bidi too large: %d (maximum %d)", p.MaxBidiStreamNum, protocol.MaxStreamCount) - } - case initialMaxStreamsUniParameterID: - p.MaxUniStreamNum = protocol.StreamNum(val) - if p.MaxUniStreamNum > protocol.MaxStreamCount { - return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount) - } - case maxIdleTimeoutParameterID: - p.MaxIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) - case maxUDPPayloadSizeParameterID: - if val < 1200 { - return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val) - } - p.MaxUDPPayloadSize = protocol.ByteCount(val) - case ackDelayExponentParameterID: - if val > protocol.MaxAckDelayExponent { - return fmt.Errorf("invalid value for ack_delay_exponent: %d (maximum %d)", val, protocol.MaxAckDelayExponent) - } - p.AckDelayExponent = uint8(val) - case maxAckDelayParameterID: - if val > uint64(protocol.MaxMaxAckDelay/time.Millisecond) { - return fmt.Errorf("invalid value for max_ack_delay: %dms (maximum %dms)", val, protocol.MaxMaxAckDelay/time.Millisecond) - } - p.MaxAckDelay = time.Duration(val) * time.Millisecond - case activeConnectionIDLimitParameterID: - p.ActiveConnectionIDLimit = val - case maxDatagramFrameSizeParameterID: - p.MaxDatagramFrameSize = protocol.ByteCount(val) - default: - return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID) - } - return nil -} - -// Marshal the transport parameters -func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { - b := &bytes.Buffer{} - - // add a greased value - quicvarint.Write(b, uint64(27+31*rand.Intn(100))) - length := rand.Intn(16) - randomData := make([]byte, length) - rand.Read(randomData) - quicvarint.Write(b, uint64(length)) - b.Write(randomData) - - // initial_max_stream_data_bidi_local - p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) - // initial_max_stream_data_bidi_remote - p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) - // initial_max_stream_data_uni - p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) - // initial_max_data - p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) - // initial_max_bidi_streams - p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) - // initial_max_uni_streams - p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) - // idle_timeout - p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond)) - // max_packet_size - p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize)) - // max_ack_delay - // Only send it if is different from the default value. - if p.MaxAckDelay != protocol.DefaultMaxAckDelay { - p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) - } - // ack_delay_exponent - // Only send it if is different from the default value. - if p.AckDelayExponent != protocol.DefaultAckDelayExponent { - p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) - } - // disable_active_migration - if p.DisableActiveMigration { - quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) - quicvarint.Write(b, 0) - } - if pers == protocol.PerspectiveServer { - // stateless_reset_token - if p.StatelessResetToken != nil { - quicvarint.Write(b, uint64(statelessResetTokenParameterID)) - quicvarint.Write(b, 16) - b.Write(p.StatelessResetToken[:]) - } - // original_destination_connection_id - quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.OriginalDestinationConnectionID.Len())) - b.Write(p.OriginalDestinationConnectionID.Bytes()) - // preferred_address - if p.PreferredAddress != nil { - quicvarint.Write(b, uint64(preferredAddressParameterID)) - quicvarint.Write(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) - ipv4 := p.PreferredAddress.IPv4 - b.Write(ipv4[len(ipv4)-4:]) - utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port) - b.Write(p.PreferredAddress.IPv6) - utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port) - b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len())) - b.Write(p.PreferredAddress.ConnectionID.Bytes()) - b.Write(p.PreferredAddress.StatelessResetToken[:]) - } - } - // active_connection_id_limit - p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) - // initial_source_connection_id - quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.InitialSourceConnectionID.Len())) - b.Write(p.InitialSourceConnectionID.Bytes()) - // retry_source_connection_id - if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil { - quicvarint.Write(b, uint64(retrySourceConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.RetrySourceConnectionID.Len())) - b.Write(p.RetrySourceConnectionID.Bytes()) - } - if p.MaxDatagramFrameSize != protocol.InvalidByteCount { - p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) - } - return b.Bytes() -} - -func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportParameterID, val uint64) { - quicvarint.Write(b, uint64(id)) - quicvarint.Write(b, uint64(quicvarint.Len(val))) - quicvarint.Write(b, val) -} - -// MarshalForSessionTicket marshals the transport parameters we save in the session ticket. -// When sending a 0-RTT enabled TLS session tickets, we need to save the transport parameters. -// The client will remember the transport parameters used in the last session, -// and apply those to the 0-RTT data it sends. -// Saving the transport parameters in the ticket gives the server the option to reject 0-RTT -// if the transport parameters changed. -// Since the session ticket is encrypted, the serialization format is defined by the server. -// For convenience, we use the same format that we also use for sending the transport parameters. -func (p *TransportParameters) MarshalForSessionTicket(b *bytes.Buffer) { - quicvarint.Write(b, transportParameterMarshalingVersion) - - // initial_max_stream_data_bidi_local - p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) - // initial_max_stream_data_bidi_remote - p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) - // initial_max_stream_data_uni - p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) - // initial_max_data - p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) - // initial_max_bidi_streams - p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) - // initial_max_uni_streams - p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) - // active_connection_id_limit - p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) -} - -// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket. -func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error { - version, err := quicvarint.Read(r) - if err != nil { - return err - } - if version != transportParameterMarshalingVersion { - return fmt.Errorf("unknown transport parameter marshaling version: %d", version) - } - return p.unmarshal(r, protocol.PerspectiveServer, true) -} - -// ValidFor0RTT checks if the transport parameters match those saved in the session ticket. -func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { - return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && - p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && - p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && - p.InitialMaxData >= saved.InitialMaxData && - p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && - p.MaxUniStreamNum >= saved.MaxUniStreamNum && - p.ActiveConnectionIDLimit == saved.ActiveConnectionIDLimit -} - -// String returns a string representation, intended for logging. -func (p *TransportParameters) String() string { - logString := "&wire.TransportParameters{OriginalDestinationConnectionID: %s, InitialSourceConnectionID: %s, " - logParams := []interface{}{p.OriginalDestinationConnectionID, p.InitialSourceConnectionID} - if p.RetrySourceConnectionID != nil { - logString += "RetrySourceConnectionID: %s, " - logParams = append(logParams, p.RetrySourceConnectionID) - } - logString += "InitialMaxStreamDataBidiLocal: %d, InitialMaxStreamDataBidiRemote: %d, InitialMaxStreamDataUni: %d, InitialMaxData: %d, MaxBidiStreamNum: %d, MaxUniStreamNum: %d, MaxIdleTimeout: %s, AckDelayExponent: %d, MaxAckDelay: %s, ActiveConnectionIDLimit: %d" - logParams = append(logParams, []interface{}{p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreamNum, p.MaxUniStreamNum, p.MaxIdleTimeout, p.AckDelayExponent, p.MaxAckDelay, p.ActiveConnectionIDLimit}...) - if p.StatelessResetToken != nil { // the client never sends a stateless reset token - logString += ", StatelessResetToken: %#x" - logParams = append(logParams, *p.StatelessResetToken) - } - if p.MaxDatagramFrameSize != protocol.InvalidByteCount { - logString += ", MaxDatagramFrameSize: %d" - logParams = append(logParams, p.MaxDatagramFrameSize) - } - logString += "}" - return fmt.Sprintf(logString, logParams...) -} diff --git a/internal/quic-go/wire/version_negotiation.go b/internal/quic-go/wire/version_negotiation.go deleted file mode 100644 index ee1613e4..00000000 --- a/internal/quic-go/wire/version_negotiation.go +++ /dev/null @@ -1,54 +0,0 @@ -package wire - -import ( - "bytes" - "crypto/rand" - "errors" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/utils" -) - -// ParseVersionNegotiationPacket parses a Version Negotiation packet. -func ParseVersionNegotiationPacket(b *bytes.Reader) (*Header, []protocol.VersionNumber, error) { - hdr, err := parseHeader(b, 0) - if err != nil { - return nil, nil, err - } - if b.Len() == 0 { - //nolint:stylecheck - return nil, nil, errors.New("Version Negotiation packet has empty version list") - } - if b.Len()%4 != 0 { - //nolint:stylecheck - return nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") - } - versions := make([]protocol.VersionNumber, b.Len()/4) - for i := 0; b.Len() > 0; i++ { - v, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return nil, nil, err - } - versions[i] = protocol.VersionNumber(v) - } - return hdr, versions, nil -} - -// ComposeVersionNegotiation composes a Version Negotiation -func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { - greasedVersions := protocol.GetGreasedVersions(versions) - expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 - buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) - r := make([]byte, 1) - _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. - buf.WriteByte(r[0] | 0x80) - utils.BigEndian.WriteUint32(buf, 0) // version 0 - buf.WriteByte(uint8(destConnID.Len())) - buf.Write(destConnID) - buf.WriteByte(uint8(srcConnID.Len())) - buf.Write(srcConnID) - for _, v := range greasedVersions { - utils.BigEndian.WriteUint32(buf, uint32(v)) - } - return buf.Bytes() -} diff --git a/internal/quic-go/wire/version_negotiation_test.go b/internal/quic-go/wire/version_negotiation_test.go deleted file mode 100644 index 9a312b82..00000000 --- a/internal/quic-go/wire/version_negotiation_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package wire - -import ( - "bytes" - "encoding/binary" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Version Negotiation Packets", func() { - It("parses a Version Negotiation packet", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} - destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} - versions := []protocol.VersionNumber{0x22334455, 0x33445566} - data := []byte{0x80, 0, 0, 0, 0} - data = append(data, uint8(len(destConnID))) - data = append(data, destConnID...) - data = append(data, uint8(len(srcConnID))) - data = append(data, srcConnID...) - for _, v := range versions { - data = append(data, []byte{0, 0, 0, 0}...) - binary.BigEndian.PutUint32(data[len(data)-4:], uint32(v)) - } - Expect(IsVersionNegotiationPacket(data)).To(BeTrue()) - hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(destConnID)) - Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) - Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.Version).To(BeZero()) - Expect(supportedVersions).To(Equal(versions)) - }) - - It("errors if it contains versions of the wrong length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - versions := []protocol.VersionNumber{0x22334455, 0x33445566} - data := ComposeVersionNegotiation(connID, connID, versions) - _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data[:len(data)-2])) - Expect(err).To(MatchError("Version Negotiation packet has a version list with an invalid length")) - }) - - It("errors if the version list is empty", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - versions := []protocol.VersionNumber{0x22334455} - data := ComposeVersionNegotiation(connID, connID, versions) - // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number - data = data[:len(data)-8] - _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) - Expect(err).To(MatchError("Version Negotiation packet has empty version list")) - }) - - It("adds a reserved version", func() { - srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - versions := []protocol.VersionNumber{1001, 1003} - data := ComposeVersionNegotiation(destConnID, srcConnID, versions) - Expect(data[0] & 0x80).ToNot(BeZero()) - hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(destConnID)) - Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) - Expect(hdr.Version).To(BeZero()) - // the supported versions should include one reserved version number - Expect(supportedVersions).To(HaveLen(len(versions) + 1)) - for _, v := range versions { - Expect(supportedVersions).To(ContainElement(v)) - } - var reservedVersion protocol.VersionNumber - versionLoop: - for _, ver := range supportedVersions { - for _, v := range versions { - if v == ver { - continue versionLoop - } - } - reservedVersion = ver - } - Expect(reservedVersion).ToNot(BeZero()) - Expect(reservedVersion&0x0f0f0f0f == 0x0a0a0a0a).To(BeTrue()) // check that it's a greased version number - }) -}) diff --git a/internal/quic-go/wire/wire_suite_test.go b/internal/quic-go/wire/wire_suite_test.go deleted file mode 100644 index 54e5c70b..00000000 --- a/internal/quic-go/wire/wire_suite_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package wire - -import ( - "bytes" - "encoding/binary" - "testing" - - "github.com/imroc/req/v3/internal/quic-go/protocol" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestWire(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Wire Suite") -} - -func encodeVarInt(i uint64) []byte { - b := &bytes.Buffer{} - quicvarint.Write(b, i) - return b.Bytes() -} - -func appendVersion(data []byte, v protocol.VersionNumber) []byte { - offset := len(data) - data = append(data, []byte{0, 0, 0, 0}...) - binary.BigEndian.PutUint32(data[offset:], uint32(v)) - return data -} diff --git a/transport.go b/transport.go index 2476dbf6..0c006b26 100644 --- a/transport.go +++ b/transport.go @@ -479,7 +479,7 @@ func (t *Transport) EnableHTTP3() { } return } - if !(minorVersion >= 16 && minorVersion <= 19) { + if !(minorVersion >= 18 && minorVersion <= 20) { if t.Debugf != nil { t.Debugf("%s is not support http3", v) } From 436172e144935393d7fcd1d288d5ba6b2a72bf4b Mon Sep 17 00:00:00 2001 From: rockerchen Date: Mon, 6 Feb 2023 19:41:41 +0800 Subject: [PATCH 674/843] Add go1.20 to ci --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a479e10f..be91c2a8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: test: strategy: matrix: - go: [ '1.19.x', '1.18.x', '1.17.x' ] + go: [ '1.20.x', '1.19.x', '1.18.x' ] os: [ ubuntu-latest ] runs-on: ${{ matrix.os }} steps: From 05894c665360b137b913c0d58115ba3099e26554 Mon Sep 17 00:00:00 2001 From: rockerchen Date: Mon, 6 Feb 2023 19:47:06 +0800 Subject: [PATCH 675/843] Remove deprecated ioutil functions --- client.go | 5 +- client_test.go | 4 +- examples/uploadcallback/uploadserver/main.go | 3 +- internal/charsets/charsets_test.go | 4 +- internal/http2/pipe_test.go | 5 +- internal/http2/server_test.go | 5 +- internal/http2/transport_test.go | 63 ++++++++++---------- internal/http3/frames.go | 6 +- internal/testdata/cert.go | 3 +- internal/testdata/cert_test.go | 4 +- middleware.go | 5 +- req_test.go | 21 ++++--- request.go | 7 +-- request_test.go | 13 ++-- response.go | 4 +- retry_test.go | 4 +- roundtrip_js.go | 3 +- transfer.go | 13 ++-- 18 files changed, 78 insertions(+), 94 deletions(-) diff --git a/client.go b/client.go index d80458c9..e3d0a253 100644 --- a/client.go +++ b/client.go @@ -12,7 +12,6 @@ import ( "github.com/imroc/req/v3/internal/util" "golang.org/x/net/publicsuffix" "io" - "io/ioutil" "net" "net/http" "net/http/cookiejar" @@ -284,7 +283,7 @@ func (c *Client) SetRootCertFromString(pemContent string) *Client { // SetRootCertsFromFile set root certificates from files. func (c *Client) SetRootCertsFromFile(pemFiles ...string) *Client { for _, pemFile := range pemFiles { - rootPemData, err := ioutil.ReadFile(pemFile) + rootPemData, err := os.ReadFile(pemFile) if err != nil { c.log.Errorf("failed to read root cert file: %v", err) return c @@ -1377,7 +1376,7 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { if err == nil && !c.disableAutoReadResponse && !r.isSaveResponse && !r.disableAutoReadResponse { _, err = resp.ToBytes() // restore body for re-reads - resp.Body = ioutil.NopCloser(bytes.NewReader(resp.body)) + resp.Body = io.NopCloser(bytes.NewReader(resp.body)) } else if err != nil { resp.Err = err } diff --git a/client_test.go b/client_test.go index ee87f14c..75ec45a2 100644 --- a/client_test.go +++ b/client_test.go @@ -7,7 +7,7 @@ import ( "errors" "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/tests" - "io/ioutil" + "io" "net" "net/http" "net/url" @@ -484,7 +484,7 @@ func testDisableAutoReadResponse(t *testing.T, c *Client) { resp, err = c.R().Get("/") assertSuccess(t, resp, err) - _, err = ioutil.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) tests.AssertNoError(t, err) } diff --git a/examples/uploadcallback/uploadserver/main.go b/examples/uploadcallback/uploadserver/main.go index f1149207..62def71d 100644 --- a/examples/uploadcallback/uploadserver/main.go +++ b/examples/uploadcallback/uploadserver/main.go @@ -3,7 +3,6 @@ package main import ( "github.com/gin-gonic/gin" "io" - "io/ioutil" "net/http" ) @@ -11,7 +10,7 @@ func main() { router := gin.Default() router.POST("/upload", func(c *gin.Context) { body := c.Request.Body - io.Copy(ioutil.Discard, body) + io.Copy(io.Discard, body) c.String(http.StatusOK, "ok") }) router.Run(":8888") diff --git a/internal/charsets/charsets_test.go b/internal/charsets/charsets_test.go index 239e371f..28cd698a 100644 --- a/internal/charsets/charsets_test.go +++ b/internal/charsets/charsets_test.go @@ -6,7 +6,7 @@ package charsets import ( "github.com/imroc/req/v3/internal/tests" - "io/ioutil" + "os" "runtime" "testing" ) @@ -30,7 +30,7 @@ func TestSniff(t *testing.T) { } for _, tc := range sniffTestCases { - content, err := ioutil.ReadFile(tests.GetTestFilePath(tc.filename)) + content, err := os.ReadFile(tests.GetTestFilePath(tc.filename)) if err != nil { t.Errorf("%s: error reading file: %v", tc.filename, err) continue diff --git a/internal/http2/pipe_test.go b/internal/http2/pipe_test.go index 83d2dfd2..c21c007e 100644 --- a/internal/http2/pipe_test.go +++ b/internal/http2/pipe_test.go @@ -8,7 +8,6 @@ import ( "bytes" "errors" "io" - "io/ioutil" "testing" ) @@ -85,7 +84,7 @@ func TestPipeCloseWithError(t *testing.T) { io.WriteString(p, body) a := errors.New("test error") p.CloseWithError(a) - all, err := ioutil.ReadAll(p) + all, err := io.ReadAll(p) if string(all) != body { t.Errorf("read bytes = %q; want %q", all, body) } @@ -112,7 +111,7 @@ func TestPipeBreakWithError(t *testing.T) { io.WriteString(p, "foo") a := errors.New("test err") p.BreakWithError(a) - all, err := ioutil.ReadAll(p) + all, err := io.ReadAll(p) if string(all) != "" { t.Errorf("read bytes = %q; want empty string", all) } diff --git a/internal/http2/server_test.go b/internal/http2/server_test.go index c1ddee69..77d0cf2f 100644 --- a/internal/http2/server_test.go +++ b/internal/http2/server_test.go @@ -10,7 +10,6 @@ import ( "fmt" "github.com/imroc/req/v3/internal/ascii" "io" - "io/ioutil" "log" "math" "net" @@ -4707,7 +4706,7 @@ func stderrv() io.Writer { return os.Stderr } - return ioutil.Discard + return io.Discard } type safeBuffer struct { @@ -4947,7 +4946,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config if quiet { - ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + ts.Config.ErrorLog = log.New(io.Discard, "", 0) } else { ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags) } diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go index c75086b2..0cad0c63 100644 --- a/internal/http2/transport_test.go +++ b/internal/http2/transport_test.go @@ -19,7 +19,6 @@ import ( "github.com/imroc/req/v3/internal/transport" "io" "io/fs" - "io/ioutil" "log" "math/rand" "net" @@ -131,7 +130,7 @@ func TestTransportH2c(t *testing.T) { if res.ProtoMajor != 2 { t.Fatal("proto not h2c") } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -190,7 +189,7 @@ func TestTransport(t *testing.T) { if res.TLS == nil { t.Errorf("%d: Response.TLS = nil; want non-nil", i) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Errorf("%d: Body read: %v", i, err) } else if string(slurp) != body { @@ -224,7 +223,7 @@ func testTransportReusesConns(t *testing.T, wantSame bool, modReq func(*http.Req t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("Body read: %v", err) } @@ -449,7 +448,7 @@ func TestTransportAbortClosesPipes(t *testing.T) { } defer res.Body.Close() st.closeConn() - _, err = ioutil.ReadAll(res.Body) + _, err = io.ReadAll(res.Body) if err == nil { errCh <- errors.New("expected error from res.Body.Read") return @@ -585,7 +584,7 @@ func TestTransportBody(t *testing.T) { gotc := make(chan reqInfo, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - slurp, err := ioutil.ReadAll(r.Body) + slurp, err := io.ReadAll(r.Body) if err != nil { gotc <- reqInfo{err: err} } else { @@ -941,7 +940,7 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) { } io.Copy(body, io.LimitReader(tests.NeverEnding('A'), bodySize/2)) body.CloseWithError(io.EOF) - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { return fmt.Errorf("Slurp: %v", err) } @@ -1061,7 +1060,7 @@ func TestTransportFullDuplex(t *testing.T) { c := &http.Client{Transport: tr} pr, pw := io.Pipe() - req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr)) + req, err := http.NewRequest("PUT", st.ts.URL, io.NopCloser(pr)) if err != nil { t.Fatal(err) } @@ -1238,7 +1237,7 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy if res.StatusCode != 200 { return fmt.Errorf("status code = %v; want 200", res.StatusCode) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { return fmt.Errorf("Slurp: %v", err) } @@ -1424,7 +1423,7 @@ func TestTransportReceiveUndeclaredTrailer(t *testing.T) { if res.StatusCode != 200 { return fmt.Errorf("status code = %v; want 200", res.StatusCode) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil) } @@ -1524,7 +1523,7 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT if res.StatusCode != 200 { return fmt.Errorf("status code = %v; want 200", res.StatusCode) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) se, ok := err.(StreamError) if !ok || se.Cause != wantErr { return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr) @@ -1712,11 +1711,11 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { // Consume body & force client to send // trailers before writing response. - // ioutil.ReadAll returns non-nil err for + // io.ReadAll returns non-nil err for // requests that attempt to send greater than // maxHeaderListSize bytes of trailers, since // those requests generate a stream reset. - ioutil.ReadAll(r.Body) + io.ReadAll(r.Body) r.Body.Close() }, func(ts *httptest.Server) { @@ -2075,7 +2074,7 @@ func TestTransportDisableKeepAlives(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err := ioutil.ReadAll(res.Body); err != nil { + if _, err := io.ReadAll(res.Body); err != nil { t.Fatal(err) } defer res.Body.Close() @@ -2139,7 +2138,7 @@ func TestTransportDisableKeepAlives_Concurrency(t *testing.T) { t.Error(err) return } - if _, err := ioutil.ReadAll(res.Body); err != nil { + if _, err := io.ReadAll(res.Body); err != nil { t.Error(err) return } @@ -2506,7 +2505,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { // the first Read call's gzip.NewReader returning an error. func TestGzipReader_DoubleReadCrash(t *testing.T) { gz := &GzipReader{ - Body: ioutil.NopCloser(strings.NewReader("0123456789")), + Body: io.NopCloser(strings.NewReader("0123456789")), } var buf [1]byte n, err1 := gz.Read(buf[:]) @@ -2631,7 +2630,7 @@ func TestTransportReadHeadResponse(t *testing.T) { if res.ContentLength != 123 { return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { return fmt.Errorf("ReadAll: %v", err) } @@ -2674,7 +2673,7 @@ func TestTransportReadHeadResponse(t *testing.T) { func TestTransportReadHeadResponseWithBody(t *testing.T) { // This test use not valid response format. // Discarding logger output to not spam tests output. - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) response := "redirecting to /elsewhere" @@ -2690,7 +2689,7 @@ func TestTransportReadHeadResponseWithBody(t *testing.T) { if res.ContentLength != int64(len(response)) { return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response)) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { return fmt.Errorf("ReadAll: %v", err) } @@ -2766,7 +2765,7 @@ func TestTransportHandlerBodyClose(t *testing.T) { if err != nil { t.Fatal(err) } - n, err := io.Copy(ioutil.Discard, res.Body) + n, err := io.Copy(io.Discard, res.Body) res.Body.Close() if n != bodySize || err != nil { t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize) @@ -2866,7 +2865,7 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { if err != nil { return fmt.Errorf("unexpected client RoundTrip error: %v", err) } - _, err = io.Copy(ioutil.Discard, res.Body) + _, err = io.Copy(io.Discard, res.Body) res.Body.Close() } want := GoAwayError{ @@ -3420,7 +3419,7 @@ func TestTransportRequestPathPseudo(t *testing.T) { // before we've determined that the ClientConn is usable. func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { const body = "foo" - req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body))) + req, _ := http.NewRequest("POST", "http://foo.com/", io.NopCloser(strings.NewReader(body))) cc := &ClientConn{ closed: true, reqHeaderMu: make(chan struct{}, 1), @@ -3429,7 +3428,7 @@ func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { if err != errClientConnUnusable { t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err) } - slurp, err := ioutil.ReadAll(req.Body) + slurp, err := io.ReadAll(req.Body) if err != nil { t.Errorf("ReadAll = %v", err) } @@ -3491,7 +3490,7 @@ func TestTransportCancelDataResponseRace(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err = io.Copy(ioutil.Discard, res.Body); err == nil { + if _, err = io.Copy(io.Discard, res.Body); err == nil { t.Fatal("unexpected success") } @@ -3499,7 +3498,7 @@ func TestTransportCancelDataResponseRace(t *testing.T) { if err != nil { t.Fatal(err) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -3525,7 +3524,7 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil { + if _, err = io.Copy(io.Discard, resp.Body); err != nil { t.Fatalf("error reading response body: %v", err) } if err := resp.Body.Close(); err != nil { @@ -3646,7 +3645,7 @@ func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.D if res.StatusCode != 200 { return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200) } - _, err = ioutil.ReadAll(res.Body) + _, err = io.ReadAll(res.Body) if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) { return nil } @@ -3960,7 +3959,7 @@ func TestTransportRetryHasLimit(t *testing.T) { func TestTransportResponseDataBeforeHeaders(t *testing.T) { // This test use not valid response format. // Discarding logger output to not spam tests output. - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) ct := newClientTester(t) @@ -4166,7 +4165,7 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) return } - ioutil.ReadAll(resp.Body) + io.ReadAll(resp.Body) resp.Body.Close() if resp.StatusCode != 204 { errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode) @@ -4647,7 +4646,7 @@ func testClientConnClose(t *testing.T, closeMode closeMode) { case closeAtHeaders, closeAtBody: if closeMode == closeAtBody { go close(sendBody) - if _, err := io.Copy(ioutil.Discard, res.Body); err == nil { + if _, err := io.Copy(io.Discard, res.Body); err == nil { t.Error("expected a Copy error, got nil") } } @@ -4698,7 +4697,7 @@ func TestClientConnShutdownCancel(t *testing.T) { func TestTransportUsesGetBodyWhenPresent(t *testing.T) { calls := 0 someBody := func() io.ReadCloser { - return struct{ io.ReadCloser }{ioutil.NopCloser(bytes.NewReader(nil))} + return struct{ io.ReadCloser }{io.NopCloser(bytes.NewReader(nil))} } req := &http.Request{ Body: someBody(), @@ -5149,7 +5148,7 @@ func TestTransportFrameBufferReuse(t *testing.T) { if got, want := r.Header.Get("Big"), filler; got != want { t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, want) } - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) if err != nil { t.Errorf("error reading request body: %v", err) } diff --git a/internal/http3/frames.go b/internal/http3/frames.go index 37eb0290..0ea126dd 100644 --- a/internal/http3/frames.go +++ b/internal/http3/frames.go @@ -4,10 +4,8 @@ import ( "bytes" "errors" "fmt" - "io" - "io/ioutil" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "io" ) // FrameType is the frame type of a HTTP/3 frame @@ -64,7 +62,7 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f case 0xd: // MAX_PUSH_ID } // skip over unknown frames - if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil { + if _, err := io.CopyN(io.Discard, qr, int64(l)); err != nil { return nil, err } } diff --git a/internal/testdata/cert.go b/internal/testdata/cert.go index f862b0cb..bc1a8f44 100644 --- a/internal/testdata/cert.go +++ b/internal/testdata/cert.go @@ -3,7 +3,6 @@ package testdata import ( "crypto/tls" "crypto/x509" - "io/ioutil" "path" "runtime" ) @@ -38,7 +37,7 @@ func GetTLSConfig() *tls.Config { // AddRootCA adds the root CA certificate to a cert pool func AddRootCA(certPool *x509.CertPool) { caCertPath := path.Join(certPath, "ca.pem") - caCertRaw, err := ioutil.ReadFile(caCertPath) + caCertRaw, err := os.ReadFile(caCertPath) if err != nil { panic(err) } diff --git a/internal/testdata/cert_test.go b/internal/testdata/cert_test.go index 0de1bd7b..e21fb61d 100644 --- a/internal/testdata/cert_test.go +++ b/internal/testdata/cert_test.go @@ -2,8 +2,6 @@ package testdata import ( "crypto/tls" - "io/ioutil" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -24,7 +22,7 @@ var _ = Describe("certificates", func() { conn, err := tls.Dial("tcp", "localhost:4433", &tls.Config{RootCAs: GetRootCA()}) Expect(err).ToNot(HaveOccurred()) - data, err := ioutil.ReadAll(conn) + data, err := io.ReadAll(conn) Expect(err).ToNot(HaveOccurred()) Expect(string(data)).To(Equal("foobar")) }) diff --git a/middleware.go b/middleware.go index b8865bdf..be48388b 100644 --- a/middleware.go +++ b/middleware.go @@ -5,7 +5,6 @@ import ( "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/util" "io" - "io/ioutil" "mime/multipart" "net/http" "net/textproto" @@ -143,7 +142,7 @@ func handleMultiPart(c *Client, r *Request) (err error) { w := multipart.NewWriter(buf) writeMultiPart(r, w) r.GetBody = func() (io.ReadCloser, error) { - return ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil + return io.NopCloser(bytes.NewReader(buf.Bytes())), nil } r.Body = buf.Bytes() r.SetContentType(w.FormDataContentType()) @@ -347,7 +346,7 @@ func handleDownload(c *Client, r *Response) (err error) { var body io.ReadCloser if r.body != nil { // already read - body = ioutil.NopCloser(bytes.NewReader(r.body)) + body = io.NopCloser(bytes.NewReader(r.body)) } else { body = r.Body } diff --git a/req_test.go b/req_test.go index bbb61ee4..e89f524f 100644 --- a/req_test.go +++ b/req_test.go @@ -10,7 +10,6 @@ import ( "golang.org/x/text/encoding/simplifiedchinese" "golang.org/x/text/transform" "io" - "io/ioutil" "net/http" "net/http/httptest" "net/url" @@ -68,7 +67,7 @@ func getTestServerURL() string { } func getTestFileContent(t *testing.T, filename string) []byte { - b, err := ioutil.ReadFile(tests.GetTestFilePath(filename)) + b, err := os.ReadFile(tests.GetTestFilePath(filename)) tests.AssertNoError(t, err) return b } @@ -113,15 +112,15 @@ type Echo struct { func handlePost(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) w.Write([]byte("TestPost: text response")) case "/raw-upload": - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) case "/file-text": r.ParseMultipartForm(10e6) files := r.MultipartForm.File["file"] file, _ := files[0].Open() - b, _ := ioutil.ReadAll(file) + b, _ := io.ReadAll(file) r.ParseForm() if a := r.FormValue("attempt"); a != "" && a != "2" { w.WriteHeader(http.StatusInternalServerError) @@ -143,14 +142,14 @@ func handlePost(w http.ResponseWriter, r *http.Request) { case "/search": handleSearch(w, r) case "/redirect": - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) w.Header().Set(header.Location, "/") w.WriteHeader(http.StatusMovedPermanently) case "/content-type": - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) w.Write([]byte(r.Header.Get(header.ContentType))) case "/echo": - b, _ := ioutil.ReadAll(r.Body) + b, _ := io.ReadAll(r.Body) e := Echo{ Header: r.Header, Body: string(b), @@ -216,7 +215,7 @@ func handleSearch(w http.ResponseWriter, r *http.Request) { func toGbk(s string) []byte { reader := transform.NewReader(strings.NewReader(s), simplifiedchinese.GBK.NewEncoder()) - d, e := ioutil.ReadAll(reader) + d, e := io.ReadAll(reader) if e != nil { panic(e) } @@ -286,13 +285,13 @@ func handleGet(w http.ResponseWriter, r *http.Request) { case "/pragma": w.Header().Add("Pragma", "no-cache") case "/payload": - b, _ := ioutil.ReadAll(r.Body) + b, _ := io.ReadAll(r.Body) w.Write(b) case "/gbk": w.Header().Set(header.ContentType, "text/plain; charset=gbk") w.Write(toGbk("我是roc")) case "/gbk-no-charset": - b, err := ioutil.ReadFile(tests.GetTestFilePath("sample-gbk.html")) + b, err := os.ReadFile(tests.GetTestFilePath("sample-gbk.html")) if err != nil { panic(err) } diff --git a/request.go b/request.go index cc145c2b..c7181354 100644 --- a/request.go +++ b/request.go @@ -10,7 +10,6 @@ import ( "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/util" "io" - "io/ioutil" "net/http" urlpkg "net/url" "os" @@ -232,7 +231,7 @@ func (r *Request) SetFileReader(paramName, filename string, reader io.Reader) *R if rc, ok := reader.(io.ReadCloser); ok { return rc, nil } - return ioutil.NopCloser(reader), nil + return io.NopCloser(reader), nil }, }) return r @@ -742,7 +741,7 @@ func (r *Request) SetBody(body interface{}) *Request { return r.unReplayableBody, nil } case io.Reader: - r.unReplayableBody = ioutil.NopCloser(b) + r.unReplayableBody = io.NopCloser(b) r.GetBody = func() (io.ReadCloser, error) { return r.unReplayableBody, nil } @@ -770,7 +769,7 @@ func (r *Request) SetBody(body interface{}) *Request { func (r *Request) SetBodyBytes(body []byte) *Request { r.Body = body r.GetBody = func() (io.ReadCloser, error) { - return ioutil.NopCloser(bytes.NewReader(body)), nil + return io.NopCloser(bytes.NewReader(body)), nil } return r } diff --git a/request_test.go b/request_test.go index c6800846..ac30f85b 100644 --- a/request_test.go +++ b/request_test.go @@ -8,7 +8,6 @@ import ( "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/tests" "io" - "io/ioutil" "net/http" "net/url" "os" @@ -392,7 +391,7 @@ func TestSetSuccessResult(t *testing.T) { func TestSetBody(t *testing.T) { body := "hello" fn := func() (io.ReadCloser, error) { - return ioutil.NopCloser(bytes.NewBufferString(body)), nil + return io.NopCloser(bytes.NewBufferString(body)), nil } c := tc() testCases := []struct { @@ -411,7 +410,7 @@ func TestSetBody(t *testing.T) { }, { SetBody: func(r *Request) { // SetBody with io.ReadCloser - r.SetBody(ioutil.NopCloser(bytes.NewBufferString(body))) + r.SetBody(io.NopCloser(bytes.NewBufferString(body))) }, }, { @@ -903,7 +902,7 @@ func TestSetFileReader(t *testing.T) { buff = bytes.NewBufferString("test") resp = uploadTextFile(t, func(r *Request) { - r.SetFileReader("file", "file.txt", ioutil.NopCloser(buff)) + r.SetFileReader("file", "file.txt", io.NopCloser(buff)) }) tests.AssertEqual(t, "test", resp.String()) } @@ -990,7 +989,7 @@ func TestUploadCallback(t *testing.T) { func TestDownloadCallback(t *testing.T) { n := 0 resp, err := tc().R(). - SetOutput(ioutil.Discard). + SetOutput(io.Discard). SetDownloadCallback(func(info DownloadInfo) { n++ }).Get("/download") @@ -1009,7 +1008,7 @@ func TestRequestDisableAutoReadResponse(t *testing.T) { resp, err = c.R().DisableAutoReadResponse().Get("/") assertSuccess(t, resp, err) - _, err = ioutil.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) tests.AssertNoError(t, err) }) } @@ -1020,7 +1019,7 @@ func TestRestoreResponseBody(t *testing.T) { assertSuccess(t, resp, err) tests.AssertNoError(t, err) tests.AssertEqual(t, true, len(resp.Bytes()) > 0) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) tests.AssertEqual(t, true, len(body) > 0) } diff --git a/response.go b/response.go index 6a3e1451..8f19d410 100644 --- a/response.go +++ b/response.go @@ -3,7 +3,7 @@ package req import ( "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/util" - "io/ioutil" + "io" "net/http" "strings" "time" @@ -234,7 +234,7 @@ func (r *Response) ToBytes() (body []byte, err error) { } r.body = body }() - body, err = ioutil.ReadAll(r.Body) + body, err = io.ReadAll(r.Body) r.setReceivedAt() if err == nil && r.Request.client.responseBodyTransformer != nil { body, err = r.Request.client.responseBodyTransformer(body, r.Request, r) diff --git a/retry_test.go b/retry_test.go index 2bffbf13..bbf41909 100644 --- a/retry_test.go +++ b/retry_test.go @@ -3,7 +3,7 @@ package req import ( "bytes" "github.com/imroc/req/v3/internal/tests" - "io/ioutil" + "io" "math" "net/http" "testing" @@ -123,7 +123,7 @@ func TestRetryWithUnreplayableBody(t *testing.T) { _, err = tc().R(). SetRetryCount(1). - SetBody(ioutil.NopCloser(bytes.NewBufferString("test"))). + SetBody(io.NopCloser(bytes.NewBufferString("test"))). Post("/") tests.AssertEqual(t, errRetryableWithUnReplayableBody, err) } diff --git a/roundtrip_js.go b/roundtrip_js.go index 9d8f3e4a..771a31d1 100644 --- a/roundtrip_js.go +++ b/roundtrip_js.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "strconv" "syscall/js" @@ -102,7 +101,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { // See https://github.com/web-platform-tests/wpt/issues/7693 for WHATWG tests issue. // See https://developer.mozilla.org/en-US/docs/Web/API/Streams_API for more details on the Streams API // and browser support. - body, err := ioutil.ReadAll(req.Body) + body, err := io.ReadAll(req.Body) if err != nil { req.Body.Close() // RoundTrip must always close the body, including on errors. return nil, err diff --git a/transfer.go b/transfer.go index 95001f66..d2ec69db 100644 --- a/transfer.go +++ b/transfer.go @@ -13,7 +13,6 @@ import ( "github.com/imroc/req/v3/internal/ascii" "github.com/imroc/req/v3/internal/dump" "io" - "io/ioutil" "net/http" "net/http/httptrace" "net/textproto" @@ -129,7 +128,7 @@ func newTransferWriter(r *http.Request) (t *transferWriter, err error) { // servers. See Issue 18257, as one example. // // The only reason we'd send such a request is if the user set the Body to a -// non-nil value (say, ioutil.NopCloser(bytes.NewReader(nil))) and didn't +// non-nil value (say, io.NopCloser(bytes.NewReader(nil))) and didn't // set ContentLength, or NewRequest set it to -1 (unknown), so then we assume // there's bytes to send. // @@ -357,7 +356,7 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dump.Dumper) (err error return err } var nextra int64 - nextra, err = t.doBodyCopy(ioutil.Discard, body) + nextra, err = t.doBodyCopy(io.Discard, body) ncopy += nextra } if err != nil { @@ -918,7 +917,7 @@ func (b *body) Close() error { var n int64 // Consume the body, or, which will also lead to us reading // the trailer headers after the body, if present. - n, err = io.CopyN(ioutil.Discard, bodyLocked{b}, maxPostHandlerReadBytes) + n, err = io.CopyN(io.Discard, bodyLocked{b}, maxPostHandlerReadBytes) if err == io.EOF { err = nil } @@ -929,7 +928,7 @@ func (b *body) Close() error { default: // Fully consume the body, which will also lead to us reading // the trailer headers after the body, if present. - _, err = io.Copy(ioutil.Discard, bodyLocked{b}) + _, err = io.Copy(io.Discard, bodyLocked{b}) } b.closed = true return err @@ -1004,8 +1003,8 @@ func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { return } -var nopCloserType = reflect.TypeOf(ioutil.NopCloser(nil)) -var nopCloserWriterToType = reflect.TypeOf(ioutil.NopCloser(struct { +var nopCloserType = reflect.TypeOf(io.NopCloser(nil)) +var nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct { io.Reader io.WriterTo }{})) From f0c0c12a39eb75a725e59d4bc4599e781a133410 Mon Sep 17 00:00:00 2001 From: rockerchen Date: Mon, 6 Feb 2023 19:58:06 +0800 Subject: [PATCH 676/843] Fix TestTransportEventTraceTLSVerify --- transport_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport_test.go b/transport_test.go index 102bbc5a..741a6e9a 100644 --- a/transport_test.go +++ b/transport_test.go @@ -4523,7 +4523,7 @@ func TestTransportEventTraceTLSVerify(t *testing.T) { wantOnce("TLSHandshakeStart") wantOnce("TLSHandshakeDone") - wantOnce("err = x509: certificate is valid for example.com") + wantOnce("x509: certificate is valid for example.com") if t.Failed() { t.Errorf("Output:\n%s", got) From 2ffcfc56710dc6db5ba6edc3ac48d90a8e17b6b9 Mon Sep 17 00:00:00 2001 From: jfilipczyk Date: Fri, 17 Feb 2023 14:36:03 +0100 Subject: [PATCH 677/843] fix: do not retry when RetryCount eq 0 --- request.go | 11 ++++++----- retry_test.go | 28 +++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/request.go b/request.go index c7181354..8fdf5daa 100644 --- a/request.go +++ b/request.go @@ -5,10 +5,6 @@ import ( "context" "errors" "fmt" - "github.com/hashicorp/go-multierror" - "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/util" "io" "net/http" urlpkg "net/url" @@ -17,6 +13,11 @@ import ( "reflect" "strings" "time" + + "github.com/hashicorp/go-multierror" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/util" ) // Request struct is used to compose and fire individual request from @@ -577,7 +578,7 @@ func (r *Request) do() (resp *Response, err error) { resp, err = r.client.roundTrip(r) } - if r.retryOption == nil || (r.RetryAttempt >= r.retryOption.MaxRetries && r.retryOption.MaxRetries > 0) { // absolutely cannot retry. + if r.retryOption == nil || (r.RetryAttempt >= r.retryOption.MaxRetries && r.retryOption.MaxRetries >= 0) { // absolutely cannot retry. return } diff --git a/retry_test.go b/retry_test.go index bbf41909..1f936a31 100644 --- a/retry_test.go +++ b/retry_test.go @@ -2,12 +2,13 @@ package req import ( "bytes" - "github.com/imroc/req/v3/internal/tests" "io" "math" "net/http" "testing" "time" + + "github.com/imroc/req/v3/internal/tests" ) func TestRetryBackOff(t *testing.T) { @@ -170,3 +171,28 @@ func TestRetryFalse(t *testing.T) { tests.AssertIsNil(t, resp.Response) tests.AssertEqual(t, 0, resp.Request.RetryAttempt) } + +func TestRetryTurnedOffByRetryCountEqZero(t *testing.T) { + resp, err := tc().R(). + SetRetryCount(0). + SetRetryCondition(func(resp *Response, err error) bool { + t.Fatal("retry condition should not be executed") + return true + }). + Get("https://non-exists-host.com.cn") + tests.AssertNotNil(t, err) + tests.AssertIsNil(t, resp.Response) + tests.AssertEqual(t, 0, resp.Request.RetryAttempt) + + resp, err = tc(). + SetCommonRetryCount(0). + SetCommonRetryCondition(func(resp *Response, err error) bool { + t.Fatal("retry condition should not be executed") + return true + }). + R(). + Get("https://non-exists-host.com.cn") + tests.AssertNotNil(t, err) + tests.AssertIsNil(t, resp.Response) + tests.AssertEqual(t, 0, resp.Request.RetryAttempt) +} From 4d93b334db76b168c8a09007143c44d5c9452bc3 Mon Sep 17 00:00:00 2001 From: jfilipczyk Date: Fri, 17 Feb 2023 14:54:41 +0100 Subject: [PATCH 678/843] chore: change test name --- retry_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/retry_test.go b/retry_test.go index 1f936a31..20817438 100644 --- a/retry_test.go +++ b/retry_test.go @@ -172,7 +172,7 @@ func TestRetryFalse(t *testing.T) { tests.AssertEqual(t, 0, resp.Request.RetryAttempt) } -func TestRetryTurnedOffByRetryCountEqZero(t *testing.T) { +func TestRetryTurnedOffWhenRetryCountEqZero(t *testing.T) { resp, err := tc().R(). SetRetryCount(0). SetRetryCondition(func(resp *Response, err error) bool { From be24aaad78d304f7af8cc5a123b33a19f73cb35e Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 22 Feb 2023 11:41:41 +0800 Subject: [PATCH 679/843] Fix negative resp.TotalTime() (#214) --- response.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/response.go b/response.go index 8f19d410..bd2f9d7c 100644 --- a/response.go +++ b/response.go @@ -131,7 +131,10 @@ func (r *Response) TotalTime() time.Duration { if r.Request.trace != nil { return r.Request.TraceInfo().TotalTime } - return r.receivedAt.Sub(r.Request.StartTime) + if !r.receivedAt.IsZero() { + return r.receivedAt.Sub(r.Request.StartTime) + } + return r.Request.responseReturnTime.Sub(r.Request.StartTime) } // ReceivedAt returns the timestamp that response we received. From 868660c714a47a07e64b6232006e6a13e100603a Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 24 Feb 2023 17:32:48 +0800 Subject: [PATCH 680/843] Fix: avoid resp.Err been overridden --- client.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index e3d0a253..f3720f2b 100644 --- a/client.go +++ b/client.go @@ -1297,6 +1297,9 @@ func (c *Client) WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { // RoundTrip implements RoundTripper func (c *Client) roundTrip(r *Request) (resp *Response, err error) { resp = &Response{Request: r} + defer func() { + err = resp.Err + }() // setup trace if r.trace == nil && r.client.trace { @@ -1322,8 +1325,8 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { var reqBody io.ReadCloser if r.GetBody != nil { - reqBody, err = r.GetBody() - if err != nil { + reqBody, resp.Err = r.GetBody() + if resp.Err != nil { return } } @@ -1369,21 +1372,18 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { r.StartTime = time.Now() var httpResponse *http.Response - httpResponse, err = c.httpClient.Do(r.RawRequest) + httpResponse, resp.Err = c.httpClient.Do(r.RawRequest) resp.Response = httpResponse // auto-read response body if possible - if err == nil && !c.disableAutoReadResponse && !r.isSaveResponse && !r.disableAutoReadResponse { - _, err = resp.ToBytes() + if resp.Err == nil && !c.disableAutoReadResponse && !r.isSaveResponse && !r.disableAutoReadResponse { + resp.ToBytes() // restore body for re-reads resp.Body = io.NopCloser(bytes.NewReader(resp.body)) - } else if err != nil { - resp.Err = err } for _, f := range r.client.afterResponse { if e := f(r.client, resp); e != nil { - err = e resp.Err = e return } From 9e2d4b189089195babdaf2665d58baa8cd2abf27 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 24 Feb 2023 17:54:34 +0800 Subject: [PATCH 681/843] Ensure response middleware executed when internal middleware returns error --- client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/client.go b/client.go index f3720f2b..a6ef117e 100644 --- a/client.go +++ b/client.go @@ -1385,7 +1385,6 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { for _, f := range r.client.afterResponse { if e := f(r.client, resp); e != nil { resp.Err = e - return } } return From de7450afd66bed0d947c236223edbb84f9f1f60e Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 7 Mar 2023 19:25:25 +0800 Subject: [PATCH 682/843] expose more http2 settings to client and transport --- client.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++ transport.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/client.go b/client.go index a6ef117e..62056984 100644 --- a/client.go +++ b/client.go @@ -1158,6 +1158,63 @@ func (c *Client) EnableHTTP3() *Client { return c } +// SetHTTP2MaxHeaderListSize set the http2 MaxHeaderListSize, +// which is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to +// send in the initial settings frame. It is how many bytes +// of response headers are allowed. Unlike the http2 spec, zero here +// means to use a default limit (currently 10MB). If you actually +// want to advertise an unlimited value to the peer, Transport +// interprets the highest possible value here (0xffffffff or 1<<32-1) +// to mean no limit. +func (c *Client) SetHTTP2MaxHeaderListSize(max uint32) *Client { + c.t.SetHTTP2MaxHeaderListSize(max) + return c +} + +// SetHTTP2StrictMaxConcurrentStreams set the http2 +// StrictMaxConcurrentStreams, which controls whether the +// server's SETTINGS_MAX_CONCURRENT_STREAMS should be respected +// globally. If false, new TCP connections are created to the +// server as needed to keep each under the per-connection +// SETTINGS_MAX_CONCURRENT_STREAMS limit. If true, the +// server's SETTINGS_MAX_CONCURRENT_STREAMS is interpreted as +// a global limit and callers of RoundTrip block when needed, +// waiting for their turn. +func (c *Client) SetHTTP2StrictMaxConcurrentStreams(strict bool) *Client { + c.t.SetHTTP2StrictMaxConcurrentStreams(strict) + return c +} + +// SetHTTP2ReadIdleTimeout set the http2 ReadIdleTimeout, +// which is the timeout after which a health check using ping +// frame will be carried out if no frame is received on the connection. +// Note that a ping response will is considered a received frame, so if +// there is no other traffic on the connection, the health check will +// be performed every ReadIdleTimeout interval. +// If zero, no health check is performed. +func (c *Client) SetHTTP2ReadIdleTimeout(timeout time.Duration) *Client { + c.t.SetHTTP2ReadIdleTimeout(timeout) + return c +} + +// SetHTTP2PingTimeout set the http2 PingTimeout, which is the timeout +// after which the connection will be closed if a response to Ping is +// not received. +// Defaults to 15s +func (c *Client) SetHTTP2PingTimeout(timeout time.Duration) *Client { + c.t.SetHTTP2PingTimeout(timeout) + return c +} + +// SetHTTP2WriteByteTimeout set the http2 WriteByteTimeout, which is the +// timeout after which the connection will be closed no data can be written +// to it. The timeout begins when data is available to write, and is +// extended whenever any bytes are written. +func (c *Client) SetHTTP2WriteByteTimeout(timeout time.Duration) *Client { + c.t.SetHTTP2WriteByteTimeout(timeout) + return c +} + // NewClient is the alias of C func NewClient() *Client { return C() diff --git a/transport.go b/transport.go index 0c006b26..0cd3501d 100644 --- a/transport.go +++ b/transport.go @@ -345,6 +345,63 @@ func (t *Transport) SetMaxResponseHeaderBytes(max int64) *Transport { return t } +// SetHTTP2MaxHeaderListSize set the http2 MaxHeaderListSize, +// which is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to +// send in the initial settings frame. It is how many bytes +// of response headers are allowed. Unlike the http2 spec, zero here +// means to use a default limit (currently 10MB). If you actually +// want to advertise an unlimited value to the peer, Transport +// interprets the highest possible value here (0xffffffff or 1<<32-1) +// to mean no limit. +func (t *Transport) SetHTTP2MaxHeaderListSize(max uint32) *Transport { + t.t2.MaxHeaderListSize = max + return t +} + +// SetHTTP2StrictMaxConcurrentStreams set the http2 +// StrictMaxConcurrentStreams, which controls whether the +// server's SETTINGS_MAX_CONCURRENT_STREAMS should be respected +// globally. If false, new TCP connections are created to the +// server as needed to keep each under the per-connection +// SETTINGS_MAX_CONCURRENT_STREAMS limit. If true, the +// server's SETTINGS_MAX_CONCURRENT_STREAMS is interpreted as +// a global limit and callers of RoundTrip block when needed, +// waiting for their turn. +func (t *Transport) SetHTTP2StrictMaxConcurrentStreams(strict bool) *Transport { + t.t2.StrictMaxConcurrentStreams = strict + return t +} + +// SetHTTP2ReadIdleTimeout set the http2 ReadIdleTimeout, +// which is the timeout after which a health check using ping +// frame will be carried out if no frame is received on the connection. +// Note that a ping response will is considered a received frame, so if +// there is no other traffic on the connection, the health check will +// be performed every ReadIdleTimeout interval. +// If zero, no health check is performed. +func (t *Transport) SetHTTP2ReadIdleTimeout(timeout time.Duration) *Transport { + t.t2.ReadIdleTimeout = timeout + return t +} + +// SetHTTP2PingTimeout set the http2 PingTimeout, which is the timeout +// after which the connection will be closed if a response to Ping is +// not received. +// Defaults to 15s +func (t *Transport) SetHTTP2PingTimeout(timeout time.Duration) *Transport { + t.t2.PingTimeout = timeout + return t +} + +// SetHTTP2WriteByteTimeout set the http2 WriteByteTimeout, which is the +// timeout after which the connection will be closed no data can be written +// to it. The timeout begins when data is available to write, and is +// extended whenever any bytes are written. +func (t *Transport) SetHTTP2WriteByteTimeout(timeout time.Duration) *Transport { + t.t2.WriteByteTimeout = timeout + return t +} + // SetTLSClientConfig set the custom TLSClientConfig, which specifies the TLS configuration to // use with tls.Client. // If nil, the default configuration is used. From 0f702b47bdf9a71a65bc642c88bc5e6f0ce4c3c8 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 9 Mar 2023 17:56:25 +0800 Subject: [PATCH 683/843] Only auto-read response if code > 199 --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 62056984..b5dbb370 100644 --- a/client.go +++ b/client.go @@ -1433,7 +1433,7 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { resp.Response = httpResponse // auto-read response body if possible - if resp.Err == nil && !c.disableAutoReadResponse && !r.isSaveResponse && !r.disableAutoReadResponse { + if resp.Err == nil && !c.disableAutoReadResponse && !r.isSaveResponse && !r.disableAutoReadResponse && resp.StatusCode > 199 { resp.ToBytes() // restore body for re-reads resp.Body = io.NopCloser(bytes.NewReader(resp.body)) From 060235787e2e0dd012fbdc23a3a589b7b0339743 Mon Sep 17 00:00:00 2001 From: kingluo Date: Mon, 27 Mar 2023 00:28:45 +0800 Subject: [PATCH 684/843] fix: use host and port from alt-svc header --- transport.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/transport.go b/transport.go index 0cd3501d..56e5cb42 100644 --- a/transport.go +++ b/transport.go @@ -604,14 +604,16 @@ func (t *Transport) handleAltSvc(req *http.Request, value string) { Entries: entries, } t.pendingAltSvcs[addr] = pas - go t.handlePendingAltSvc(netutil.AuthorityAddr(req.URL.Scheme, req.URL.Host), pas) + go t.handlePendingAltSvc(req.URL, pas) } } -func (t *Transport) handlePendingAltSvc(hostname string, pas *pendingAltSvc) { +func (t *Transport) handlePendingAltSvc(u *url.URL, pas *pendingAltSvc) { for i := pas.CurrentIndex; i < len(pas.Entries); i++ { switch pas.Entries[i].Protocol { case "h3": // only support h3 in alt-svc for now + u2 := altsvcutil.ConvertURL(pas.Entries[i], u) + hostname := u2.Host err := t.t3.AddConn(hostname) if err != nil { if t.Debugf != nil { @@ -805,12 +807,14 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error pas.Mu.Lock() if pas.Transport != nil { pas.LastTime = time.Now() - resp, err = pas.Transport.RoundTrip(req) + r := req.Clone(req.Context()) + r.URL = altsvcutil.ConvertURL(pas.Entries[pas.CurrentIndex], req.URL) + resp, err = pas.Transport.RoundTrip(r) if err != nil { pas.Transport = nil if pas.CurrentIndex+1 < len(pas.Entries) { pas.CurrentIndex++ - go t.handlePendingAltSvc(addr, pas) + go t.handlePendingAltSvc(req.URL, pas) } } else { t.altSvcJar.SetAltSvc(addr, pas.Entries[pas.CurrentIndex]) From 9f87b5b87f03663e0b2b7aabb3bc38ef2ad7388d Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 29 Apr 2023 10:05:17 +0800 Subject: [PATCH 685/843] Support latest quic-go version, requires go1.19 --- go.mod | 28 +++---- go.sum | 28 +++++++ internal/http3/body.go | 3 + internal/http3/client.go | 114 +++++++++++++++----------- internal/http3/frames.go | 32 ++++---- internal/http3/http_stream.go | 15 ++-- internal/http3/request.go | 5 +- internal/http3/request_writer.go | 14 ++-- internal/http3/roundtrip.go | 88 +++++++++++++++++--- internal/quic-go/quicvarint/varint.go | 57 ++++++++++++- 10 files changed, 280 insertions(+), 104 deletions(-) diff --git a/go.mod b/go.mod index f364a57a..f08387ce 100644 --- a/go.mod +++ b/go.mod @@ -1,27 +1,27 @@ module github.com/imroc/req/v3 -go 1.18 +go 1.19 require ( github.com/hashicorp/go-multierror v1.1.1 github.com/quic-go/qpack v0.4.0 - github.com/quic-go/quic-go v0.32.0 - golang.org/x/net v0.4.0 - golang.org/x/text v0.5.0 + github.com/quic-go/quic-go v0.34.0 + golang.org/x/net v0.9.0 + golang.org/x/text v0.9.0 ) require ( - github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/mock v1.6.0 // indirect - github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect + github.com/google/pprof v0.0.0-20230426061923-93006964c1fc // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/onsi/ginkgo/v2 v2.2.0 // indirect + github.com/onsi/ginkgo/v2 v2.9.2 // indirect github.com/quic-go/qtls-go1-18 v0.2.0 // indirect - github.com/quic-go/qtls-go1-19 v0.2.0 // indirect - github.com/quic-go/qtls-go1-20 v0.1.0 // indirect - golang.org/x/crypto v0.4.0 // indirect - golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect - golang.org/x/mod v0.6.0 // indirect - golang.org/x/sys v0.3.0 // indirect - golang.org/x/tools v0.2.0 // indirect + github.com/quic-go/qtls-go1-19 v0.3.2 // indirect + github.com/quic-go/qtls-go1-20 v0.2.2 // indirect + golang.org/x/crypto v0.8.0 // indirect + golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 // indirect + golang.org/x/mod v0.10.0 // indirect + golang.org/x/sys v0.7.0 // indirect + golang.org/x/tools v0.8.0 // indirect ) diff --git a/go.sum b/go.sum index a9b87a39..d9f7ae80 100644 --- a/go.sum +++ b/go.sum @@ -6,12 +6,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20230426061923-93006964c1fc h1:AGDHt781oIcL4EFk7cPnvBUYTwU8BEU6GDTO3ZMn1sE= +github.com/google/pprof v0.0.0-20230426061923-93006964c1fc/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -20,6 +24,8 @@ github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9 github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI= github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= +github.com/onsi/ginkgo/v2 v2.9.2 h1:BA2GMJOtfGAfagzYtrAlufIP0lq6QERkFmHLMLPwFSU= +github.com/onsi/ginkgo/v2 v2.9.2/go.mod h1:WHcJJG2dIlcCqVfBAwUCrJxSPFb6v4azBwgxeMeDuts= github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -29,28 +35,43 @@ github.com/quic-go/qtls-go1-18 v0.2.0 h1:5ViXqBZ90wpUcZS0ge79rf029yx0dYB0McyPJwq github.com/quic-go/qtls-go1-18 v0.2.0/go.mod h1:moGulGHK7o6O8lSPSZNoOwcLvJKJ85vVNc7oJFD65bc= github.com/quic-go/qtls-go1-19 v0.2.0 h1:Cvn2WdhyViFUHoOqK52i51k4nDX8EwIh5VJiVM4nttk= github.com/quic-go/qtls-go1-19 v0.2.0/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= +github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= +github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= github.com/quic-go/qtls-go1-20 v0.1.0 h1:d1PK3ErFy9t7zxKsG3NXBJXZjp/kMLoIb3y/kV54oAI= github.com/quic-go/qtls-go1-20 v0.1.0/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= +github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= +github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= github.com/quic-go/quic-go v0.32.0 h1:lY02md31s1JgPiiyfqJijpu/UX/Iun304FI3yUqX7tA= github.com/quic-go/quic-go v0.32.0/go.mod h1:/fCsKANhQIeD5l76c2JFU+07gVE3KaA0FP+0zMWwfwo= +github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU= +github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= +golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= +golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o= +golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= +golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= +golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU= golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= +golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -61,16 +82,22 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= +golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y= +golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -78,4 +105,5 @@ google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscL gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/http3/body.go b/internal/http3/body.go index 50a64330..15985a1c 100644 --- a/internal/http3/body.go +++ b/internal/http3/body.go @@ -18,12 +18,15 @@ type HTTPStreamer interface { } type StreamCreator interface { + // Context returns a context that is cancelled when the underlying connection is closed. + Context() context.Context OpenStream() (quic.Stream, error) OpenStreamSync(context.Context) (quic.Stream, error) OpenUniStream() (quic.SendStream, error) OpenUniStreamSync(context.Context) (quic.SendStream, error) LocalAddr() net.Addr RemoteAddr() net.Addr + ConnectionState() quic.ConnectionState } var _ StreamCreator = quic.Connection(nil) diff --git a/internal/http3/client.go b/internal/http3/client.go index e9d2957e..8b180569 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -1,22 +1,23 @@ package http3 import ( - "bytes" "context" "crypto/tls" "errors" "fmt" - "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" - "github.com/imroc/req/v3/internal/transport" - "github.com/quic-go/qpack" - "github.com/quic-go/quic-go" "io" "net/http" "reflect" "strconv" "sync" + "sync/atomic" "time" + + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/imroc/req/v3/internal/transport" + "github.com/quic-go/qpack" + "github.com/quic-go/quic-go" ) // MethodGet0RTT allows a GET request to be sent using 0-RTT. @@ -68,12 +69,14 @@ type client struct { decoder *qpack.Decoder hostname string - conn quic.EarlyConnection + conn atomic.Pointer[quic.EarlyConnection] opt *transport.Options } -func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc, opt *transport.Options) (*client, error) { +var _ roundTripCloser = &client{} + +func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc, opt *transport.Options) (roundTripCloser, error) { if conf == nil { conf = defaultQuicConfig.Clone() } else if len(conf.Versions) == 0 { @@ -114,54 +117,56 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con func (c *client) dial(ctx context.Context) error { var err error + var conn quic.EarlyConnection if c.dialer != nil { - c.conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) + conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) } else { - c.conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) + conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) } if err != nil { return err } + c.conn.Store(&conn) // send the SETTINGs frame, using 0-RTT data, if possible go func() { - if err := c.setupConn(); err != nil { + if err := c.setupConn(conn); err != nil { c.opt.Debugf("setting up http3 connection failed: %s", err) - c.conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "") + conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "") } }() if c.opts.StreamHijacker != nil { - go c.handleBidirectionalStreams() + go c.handleBidirectionalStreams(conn) } - go c.handleUnidirectionalStreams() + go c.handleUnidirectionalStreams(conn) return nil } -func (c *client) setupConn() error { +func (c *client) setupConn(conn quic.EarlyConnection) error { // open the control stream - str, err := c.conn.OpenUniStream() + str, err := conn.OpenUniStream() if err != nil { return err } - buf := &bytes.Buffer{} - quicvarint.Write(buf, streamTypeControlStream) + b := make([]byte, 0, 64) + b = quicvarint.Append(b, streamTypeControlStream) // send the SETTINGS frame - (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Write(buf) - _, err = str.Write(buf.Bytes()) + b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b) + _, err = str.Write(b) return err } -func (c *client) handleBidirectionalStreams() { +func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) { for { - str, err := c.conn.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) if err != nil { c.opt.Debugf("accepting bidirectional stream failed: %s", err) return } go func(str quic.Stream) { _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) { - return c.opts.StreamHijacker(ft, c.conn, str, e) + return c.opts.StreamHijacker(ft, conn, str, e) }) if err == errHijacked { return @@ -169,14 +174,14 @@ func (c *client) handleBidirectionalStreams() { if err != nil { c.opt.Debugf("error handling stream: %s", err) } - c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream") + conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream") }(str) } } -func (c *client) handleUnidirectionalStreams() { +func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { for { - str, err := c.conn.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) if err != nil { c.opt.Debugf("accepting unidirectional stream failed: %s", err) return @@ -185,7 +190,7 @@ func (c *client) handleUnidirectionalStreams() { go func(str quic.ReceiveStream) { streamType, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { - if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) { + if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) { return } c.opt.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) @@ -200,10 +205,10 @@ func (c *client) handleUnidirectionalStreams() { return case streamTypePushStream: // We never increased the Push ID, so we don't expect any push streams. - c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") + conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") return default: - if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) { + if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) { return } str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) @@ -211,12 +216,12 @@ func (c *client) handleUnidirectionalStreams() { } f, err := parseNextFrame(str, nil) if err != nil { - c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") + conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") return } sf, ok := f.(*settingsFrame) if !ok { - c.conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") + conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") return } if !sf.Datagram { @@ -225,18 +230,19 @@ func (c *client) handleUnidirectionalStreams() { // If datagram support was enabled on our side as well as on the server side, // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). - if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams { - c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") + if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams { + conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") } }(str) } } func (c *client) Close() error { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { return nil } - return c.conn.CloseWithError(quic.ApplicationErrorCode(errorNoError), "") + return (*conn).CloseWithError(quic.ApplicationErrorCode(errorNoError), "") } func (c *client) maxHeaderBytes() uint64 { @@ -255,24 +261,26 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon c.dialOnce.Do(func() { c.handshakeErr = c.dial(req.Context()) }) - if c.handshakeErr != nil { return nil, c.handshakeErr } + // At this point, c.conn is guaranteed to be set. + conn := *c.conn.Load() + // Immediately send out this request, if this is a 0-RTT request. if req.Method == MethodGet0RTT { req.Method = http.MethodGet } else { // wait for the handshake to complete select { - case <-c.conn.HandshakeComplete().Done(): + case <-conn.HandshakeComplete(): case <-req.Context().Done(): return nil, req.Context().Err() } } - str, err := c.conn.OpenStreamSync(req.Context()) + str, err := conn.OpenStreamSync(req.Context()) if err != nil { return nil, err } @@ -296,7 +304,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon if opt.DontCloseRequestStream { doneChan = nil } - rsp, rerr := c.doRequest(req, str, opt, doneChan) + rsp, rerr := c.doRequest(req, conn, str, opt, doneChan) if rerr.err != nil { // if any error occurred close(reqDone) <-done @@ -308,8 +316,9 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon if rerr.err != nil { reason = rerr.err.Error() } - c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) + conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) } + return nil, rerr.err } if opt.DontCloseRequestStream { close(reqDone) @@ -368,7 +377,7 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.D return nil } -func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) { +func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) { var requestGzip bool if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { requestGzip = true @@ -388,7 +397,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, str.Close() } - hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) + hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) if req.Body != nil { // send the request body asynchronously go func() { @@ -444,9 +453,9 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, return nil, newConnError(errorGeneralProtocolError, err) } - connState, ok := reflect.ValueOf(c.conn.ConnectionState().TLS).Field(0).Interface().(tls.ConnectionState) + connState, ok := reflect.ValueOf(conn.ConnectionState().TLS).Field(0).Interface().(tls.ConnectionState) if !ok { - panic(fmt.Sprintf("bad tls connection state type: %s", reflect.ValueOf(c.conn.ConnectionState().TLS).Field(0).Type().Name())) + panic(fmt.Sprintf("bad tls connection state type: %s", reflect.ValueOf(conn.ConnectionState().TLS).Field(0).Type().Name())) } res := &http.Response{ Proto: "HTTP/3.0", @@ -468,12 +477,12 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, res.Header.Add(hf.Name, hf.Value) } } - respBody := newResponseBody(hstr, c.conn, reqDone) + respBody := newResponseBody(hstr, conn, reqDone) // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. _, hasTransferEncoding := res.Header["Transfer-Encoding"] isInformational := res.StatusCode >= 100 && res.StatusCode < 200 - isNoContent := res.StatusCode == 204 + isNoContent := res.StatusCode == http.StatusNoContent isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300 if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect { res.ContentLength = -1 @@ -496,3 +505,16 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, return res, requestError{} } + +func (c *client) HandshakeComplete() bool { + conn := c.conn.Load() + if conn == nil { + return false + } + select { + case <-(*conn).HandshakeComplete(): + return true + default: + return false + } +} diff --git a/internal/http3/frames.go b/internal/http3/frames.go index 0ea126dd..5e31d72a 100644 --- a/internal/http3/frames.go +++ b/internal/http3/frames.go @@ -4,8 +4,9 @@ import ( "bytes" "errors" "fmt" - "github.com/imroc/req/v3/internal/quic-go/quicvarint" "io" + + "github.com/imroc/req/v3/internal/quic-go/quicvarint" ) // FrameType is the frame type of a HTTP/3 frame @@ -72,18 +73,18 @@ type dataFrame struct { Length uint64 } -func (f *dataFrame) Write(b *bytes.Buffer) { - quicvarint.Write(b, 0x0) - quicvarint.Write(b, f.Length) +func (f *dataFrame) Append(b []byte) []byte { + b = quicvarint.Append(b, 0x0) + return quicvarint.Append(b, f.Length) } type headersFrame struct { Length uint64 } -func (f *headersFrame) Write(b *bytes.Buffer) { - quicvarint.Write(b, 0x1) - quicvarint.Write(b, f.Length) +func (f *headersFrame) Append(b []byte) []byte { + b = quicvarint.Append(b, 0x1) + return quicvarint.Append(b, f.Length) } const settingDatagram = 0xffd277 @@ -140,22 +141,23 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { return frame, nil } -func (f *settingsFrame) Write(b *bytes.Buffer) { - quicvarint.Write(b, 0x4) - var l uint64 +func (f *settingsFrame) Append(b []byte) []byte { + b = quicvarint.Append(b, 0x4) + var l int64 for id, val := range f.Other { l += quicvarint.Len(id) + quicvarint.Len(val) } if f.Datagram { l += quicvarint.Len(settingDatagram) + quicvarint.Len(1) } - quicvarint.Write(b, l) + b = quicvarint.Append(b, uint64(l)) if f.Datagram { - quicvarint.Write(b, settingDatagram) - quicvarint.Write(b, 1) + b = quicvarint.Append(b, settingDatagram) + b = quicvarint.Append(b, 1) } for id, val := range f.Other { - quicvarint.Write(b, id) - quicvarint.Write(b, val) + b = quicvarint.Append(b, id) + b = quicvarint.Append(b, val) } + return b } diff --git a/internal/http3/http_stream.go b/internal/http3/http_stream.go index 9ce69b45..2799e2b3 100644 --- a/internal/http3/http_stream.go +++ b/internal/http3/http_stream.go @@ -1,7 +1,6 @@ package http3 import ( - "bytes" "fmt" "github.com/quic-go/quic-go" @@ -16,6 +15,8 @@ type Stream quic.Stream type stream struct { quic.Stream + buf []byte + onFrameError func() bytesRemainingInFrame uint64 } @@ -23,7 +24,11 @@ type stream struct { var _ Stream = &stream{} func newStream(str quic.Stream, onFrameError func()) *stream { - return &stream{Stream: str, onFrameError: onFrameError} + return &stream{ + Stream: str, + onFrameError: onFrameError, + buf: make([]byte, 0, 16), + } } func (s *stream) Read(b []byte) (int, error) { @@ -62,9 +67,9 @@ func (s *stream) Read(b []byte) (int, error) { } func (s *stream) Write(b []byte) (int, error) { - buf := &bytes.Buffer{} - (&dataFrame{Length: uint64(len(b))}).Write(buf) - if _, err := s.Stream.Write(buf.Bytes()); err != nil { + s.buf = s.buf[:0] + s.buf = (&dataFrame{Length: uint64(len(b))}).Append(s.buf) + if _, err := s.Stream.Write(s.buf); err != nil { return 0, err } return s.Stream.Write(b) diff --git a/internal/http3/request.go b/internal/http3/request.go index a6fe02cd..9af25a57 100644 --- a/internal/http3/request.go +++ b/internal/http3/request.go @@ -1,13 +1,13 @@ package http3 import ( - "crypto/tls" "errors" - "github.com/quic-go/qpack" "net/http" "net/url" "strconv" "strings" + + "github.com/quic-go/qpack" ) func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { @@ -100,7 +100,6 @@ func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { ContentLength: contentLength, Host: authority, RequestURI: requestURI, - TLS: &tls.ConnectionState{}, }, nil } diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index 66c7ead0..1fdcf934 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -3,9 +3,6 @@ package http3 import ( "bytes" "fmt" - "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/header" - "github.com/quic-go/qpack" "io" "net" "net/http" @@ -13,6 +10,10 @@ import ( "strings" "sync" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/header" + "github.com/quic-go/qpack" + "github.com/quic-go/quic-go" "golang.org/x/net/http/httpguts" "golang.org/x/net/http2/hpack" @@ -59,10 +60,9 @@ func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool, return err } - buf := &bytes.Buffer{} - hf := headersFrame{Length: uint64(w.headerBuf.Len())} - hf.Write(buf) - if _, err := wr.Write(buf.Bytes()); err != nil { + b := make([]byte, 0, 128) + b = (&headersFrame{Length: uint64(w.headerBuf.Len())}).Append(b) + if _, err := wr.Write(b); err != nil { return err } _, err := wr.Write(w.headerBuf.Bytes()) diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index a6f4b080..ea16105a 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -5,20 +5,26 @@ import ( "crypto/tls" "errors" "fmt" - "github.com/imroc/req/v3/internal/transport" "io" + "net" "net/http" "strings" "sync" "time" + "github.com/imroc/req/v3/internal/transport" + "github.com/quic-go/quic-go" "golang.org/x/net/http/httpguts" ) +// declare this as a variable, such that we can it mock it in the tests +var quicDialer = quic.DialEarlyContext + type roundTripCloser interface { RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) + HandshakeComplete() bool io.Closer } @@ -56,10 +62,13 @@ type RoundTripper struct { // Dial specifies an optional dial function for creating QUIC // connections for requests. - // If Dial is nil, quic.DialAddrEarlyContext will be used. + // If Dial is nil, a UDPConn will be created at the first request + // and will be reused for subsequent connections to other servers. Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) - clients map[string]roundTripCloser + newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc, opt *transport.Options) (roundTripCloser, error) // so we can mock it in tests + clients map[string]roundTripCloser + udpConn *net.UDPConn } // RoundTripOpt are options for the Transport.RoundTripOpt method. @@ -67,6 +76,7 @@ type RoundTripOpt struct { // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. OnlyCachedConn bool + // DontCloseRequestStream controls whether the request stream is closed after sending the request. // If set, context cancellations have no effect after the response headers are received. DontCloseRequestStream bool } @@ -85,6 +95,10 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. closeRequestBody(req) return nil, errors.New("http3: nil Request.URL") } + if req.URL.Scheme != "https" { + closeRequestBody(req) + return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) + } if req.URL.Host == "" { closeRequestBody(req) return nil, errors.New("http3: no Host in request URL") @@ -116,7 +130,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. } hostname := authorityAddr("https", hostnameFromRequest(req)) - cl, err := r.getClient(hostname, opt.OnlyCachedConn) + cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn) if err != ErrNoCachedConn { if debugf := r.Debugf; debugf != nil { debugf("HTTP/3 %s %s", req.Method, req.URL.String()) @@ -125,7 +139,16 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. if err != nil { return nil, err } - return cl.RoundTripOpt(req, opt) + rsp, err := cl.RoundTripOpt(req, opt) + if err != nil { + r.removeClient(hostname) + if isReused { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + return r.RoundTripOpt(req, opt) + } + } + } + return rsp, err } // RoundTrip does a round trip. @@ -140,7 +163,7 @@ func (r *RoundTripper) RoundTripOnlyCachedConn(req *http.Request) (*http.Respons // AddConn add a http3 connection, dial new conn if not exists. func (r *RoundTripper) AddConn(addr string) error { - c, err := r.getClient(addr, false) + c, _, err := r.getClient(addr, false) if err != nil { return err } @@ -155,7 +178,7 @@ func (r *RoundTripper) AddConn(addr string) error { return client.handshakeErr } -func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripCloser, error) { +func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) { r.mutex.Lock() defer r.mutex.Unlock() @@ -166,10 +189,24 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo client, ok := r.clients[hostname] if !ok { if onlyCached { - return nil, ErrNoCachedConn + return nil, false, ErrNoCachedConn } var err error - client, err = newClient( + newCl := newClient + if r.newClient != nil { + newCl = r.newClient + } + dial := r.Dial + if dial == nil { + if r.udpConn == nil { + r.udpConn, err = net.ListenUDP("udp", nil) + if err != nil { + return nil, false, err + } + } + dial = r.makeDialer() + } + client, err = newCl( hostname, r.TLSClientConfig, &roundTripperOpts{ @@ -181,18 +218,28 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo dump: r.Dump, }, r.QuicConfig, - r.Dial, + dial, r.Options, ) if err != nil { - return nil, err + return nil, false, err } r.clients[hostname] = client } - return client, nil + return client, isReused, nil +} + +func (r *RoundTripper) removeClient(hostname string) { + r.mutex.Lock() + defer r.mutex.Unlock() + if r.clients == nil { + return + } + delete(r.clients, hostname) } -// Close closes the QUIC connections that this RoundTripper has used +// Close closes the QUIC connections that this RoundTripper has used. +// It also closes the underlying UDPConn if it is not nil. func (r *RoundTripper) Close() error { r.mutex.Lock() defer r.mutex.Unlock() @@ -202,6 +249,10 @@ func (r *RoundTripper) Close() error { } } r.clients = nil + if r.udpConn != nil { + r.udpConn.Close() + r.udpConn = nil + } return nil } @@ -232,3 +283,14 @@ func validMethod(method string) bool { func isNotToken(r rune) bool { return !httpguts.IsTokenRune(r) } + +// makeDialer makes a QUIC dialer using r.udpConn. +func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return quicDialer(ctx, r.udpConn, udpAddr, addr, tlsCfg, cfg) + } +} diff --git a/internal/quic-go/quicvarint/varint.go b/internal/quic-go/quicvarint/varint.go index e4040841..fdb00353 100644 --- a/internal/quic-go/quicvarint/varint.go +++ b/internal/quic-go/quicvarint/varint.go @@ -7,6 +7,12 @@ import ( // taken from the QUIC draft const ( + // Min is the minimum value allowed for a QUIC varint. + Min = 0 + + // Max is the maximum allowed value for a QUIC varint (2^62-1). + Max = maxVarInt8 + maxVarInt1 = 63 maxVarInt2 = 16383 maxVarInt4 = 1073741823 @@ -63,6 +69,7 @@ func Read(r io.ByteReader) (uint64, error) { } // Write writes i in the QUIC varint format to w. +// Deprecated: use Append instead. func Write(w Writer, i uint64) { if i <= maxVarInt1 { w.WriteByte(uint8(i)) @@ -80,8 +87,56 @@ func Write(w Writer, i uint64) { } } +// Append appends i in the QUIC varint format. +func Append(b []byte, i uint64) []byte { + if i <= maxVarInt1 { + return append(b, uint8(i)) + } + if i <= maxVarInt2 { + return append(b, []byte{uint8(i>>8) | 0x40, uint8(i)}...) + } + if i <= maxVarInt4 { + return append(b, []byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}...) + } + if i <= maxVarInt8 { + return append(b, []byte{ + uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }...) + } + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) +} + +// AppendWithLen append i in the QUIC varint format with the desired length. +func AppendWithLen(b []byte, i uint64, length int64) []byte { + if length != 1 && length != 2 && length != 4 && length != 8 { + panic("invalid varint length") + } + l := Len(i) + if l == length { + return Append(b, i) + } + if l > length { + panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length)) + } + if length == 2 { + b = append(b, 0b01000000) + } else if length == 4 { + b = append(b, 0b10000000) + } else if length == 8 { + b = append(b, 0b11000000) + } + for j := int64(1); j < length-l; j++ { + b = append(b, 0) + } + for j := int64(0); j < l; j++ { + b = append(b, uint8(i>>(8*(l-1-j)))) + } + return b +} + // Len determines the number of bytes that will be needed to write the number i. -func Len(i uint64) uint64 { +func Len(i uint64) int64 { if i <= maxVarInt1 { return 1 } From 943e35dcfff3b4f0154767c3a08f892c248f6bda Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 29 Apr 2023 10:06:36 +0800 Subject: [PATCH 686/843] Drop go1.18 ci workflow --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index be91c2a8..5883b872 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: test: strategy: matrix: - go: [ '1.20.x', '1.19.x', '1.18.x' ] + go: [ '1.20.x', '1.19.x' ] os: [ ubuntu-latest ] runs-on: ${{ matrix.os }} steps: From d25aa32e2ba148b81f78315405ae4f9b30df0e60 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 29 Apr 2023 10:09:33 +0800 Subject: [PATCH 687/843] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5f7d61f7..7dc4c680 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ Full documentation is available on the official website: https://req.cool. **Install** -You first need [Go](https://go.dev/) installed (version 1.18+ is required), then you can use the below Go command to install req: +You first need [Go](https://go.dev/) installed (version 1.19+ is required), then you can use the below Go command to install req: ``` sh go get github.com/imroc/req/v3 From 83f97bdc9ca6e73c5a780e96f78e50b0e8729e6f Mon Sep 17 00:00:00 2001 From: M-Cosmosss Date: Wed, 17 May 2023 23:31:55 +0800 Subject: [PATCH 688/843] feat: add client func GetCookiesFromJar --- client.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index b5dbb370..c4eeba03 100644 --- a/client.go +++ b/client.go @@ -8,9 +8,6 @@ import ( "encoding/json" "encoding/xml" "errors" - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/util" - "golang.org/x/net/publicsuffix" "io" "net" "net/http" @@ -20,6 +17,11 @@ import ( "reflect" "strings" "time" + + "golang.org/x/net/publicsuffix" + + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/util" ) // DefaultClient returns the global default Client. @@ -942,6 +944,18 @@ func (c *Client) SetCookieJar(jar http.CookieJar) *Client { return c } +// GetCookiesFromJar get cookies from the underlying `http.Client`'s `CookieJar`. +func (c *Client) GetCookiesFromJar(url string) ([]*http.Cookie, error) { + if c.httpClient.Jar == nil { + return nil, errors.New("cookie jar is not enabled") + } + u, err := urlpkg.Parse(url) + if err != nil { + return nil, err + } + return c.httpClient.Jar.Cookies(u), nil +} + // ClearCookies clears all cookies if cookie is enabled. func (c *Client) ClearCookies() *Client { if c.httpClient.Jar != nil { From a10478eec5540c1570c01b1e27144842a783f750 Mon Sep 17 00:00:00 2001 From: M-Cosmosss Date: Thu, 18 May 2023 12:51:32 +0800 Subject: [PATCH 689/843] rename func Signed-off-by: M-Cosmosss --- client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index c4eeba03..dcd7566c 100644 --- a/client.go +++ b/client.go @@ -944,8 +944,8 @@ func (c *Client) SetCookieJar(jar http.CookieJar) *Client { return c } -// GetCookiesFromJar get cookies from the underlying `http.Client`'s `CookieJar`. -func (c *Client) GetCookiesFromJar(url string) ([]*http.Cookie, error) { +// GetCookies get cookies from the underlying `http.Client`'s `CookieJar`. +func (c *Client) GetCookies(url string) ([]*http.Cookie, error) { if c.httpClient.Jar == nil { return nil, errors.New("cookie jar is not enabled") } From 2eb6b13865bd246ff0f8e3bcafbc829047cb9742 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 30 May 2023 20:52:21 +0800 Subject: [PATCH 690/843] support quic-go v0.35.0 --- go.mod | 16 ++++++------ go.sum | 16 ++++++++++++ internal/http3/client.go | 13 +++++++--- internal/http3/roundtrip.go | 35 +++++++++++++++++++++------ internal/quic-go/quicvarint/varint.go | 19 --------------- 5 files changed, 61 insertions(+), 38 deletions(-) diff --git a/go.mod b/go.mod index f08387ce..a148c6e4 100644 --- a/go.mod +++ b/go.mod @@ -5,23 +5,23 @@ go 1.19 require ( github.com/hashicorp/go-multierror v1.1.1 github.com/quic-go/qpack v0.4.0 - github.com/quic-go/quic-go v0.34.0 - golang.org/x/net v0.9.0 + github.com/quic-go/quic-go v0.35.0 + golang.org/x/net v0.10.0 golang.org/x/text v0.9.0 ) require ( github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/mock v1.6.0 // indirect - github.com/google/pprof v0.0.0-20230426061923-93006964c1fc // indirect + github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/onsi/ginkgo/v2 v2.9.2 // indirect + github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/quic-go/qtls-go1-18 v0.2.0 // indirect github.com/quic-go/qtls-go1-19 v0.3.2 // indirect github.com/quic-go/qtls-go1-20 v0.2.2 // indirect - golang.org/x/crypto v0.8.0 // indirect - golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 // indirect + golang.org/x/crypto v0.9.0 // indirect + golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect golang.org/x/mod v0.10.0 // indirect - golang.org/x/sys v0.7.0 // indirect - golang.org/x/tools v0.8.0 // indirect + golang.org/x/sys v0.8.0 // indirect + golang.org/x/tools v0.9.1 // indirect ) diff --git a/go.sum b/go.sum index d9f7ae80..7835462c 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20230426061923-93006964c1fc h1:AGDHt781oIcL4EFk7cPnvBUYTwU8BEU6GDTO3ZMn1sE= github.com/google/pprof v0.0.0-20230426061923-93006964c1fc/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= +github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9tRnGuw+ZVCB3FazXODD6JE1R8= +github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -26,6 +28,8 @@ github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI= github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= github.com/onsi/ginkgo/v2 v2.9.2 h1:BA2GMJOtfGAfagzYtrAlufIP0lq6QERkFmHLMLPwFSU= github.com/onsi/ginkgo/v2 v2.9.2/go.mod h1:WHcJJG2dIlcCqVfBAwUCrJxSPFb6v4azBwgxeMeDuts= +github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= +github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -45,6 +49,8 @@ github.com/quic-go/quic-go v0.32.0 h1:lY02md31s1JgPiiyfqJijpu/UX/Iun304FI3yUqX7t github.com/quic-go/quic-go v0.32.0/go.mod h1:/fCsKANhQIeD5l76c2JFU+07gVE3KaA0FP+0zMWwfwo= github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU= github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= +github.com/quic-go/quic-go v0.35.0 h1:JXIf219xJK+4qGeY52rlnrVqeB2AXUAwfLU9JSoWXwg= +github.com/quic-go/quic-go v0.35.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= @@ -56,10 +62,14 @@ golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o= golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= +golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= @@ -72,6 +82,8 @@ golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU= golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -84,6 +96,8 @@ golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -98,6 +112,8 @@ golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y= golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= +golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= +golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/http3/client.go b/internal/http3/client.go index 8b180569..43249483 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "reflect" "strconv" @@ -37,12 +38,11 @@ const ( var defaultQuicConfig = &quic.Config{ MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams KeepAlivePeriod: 10 * time.Second, - Versions: []quic.VersionNumber{Version1}, } type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) -var dialAddr = quic.DialAddrEarlyContext +var dialAddr dialFunc = quic.DialAddrEarly type roundTripperOpts struct { DisableCompression bool @@ -79,9 +79,10 @@ var _ roundTripCloser = &client{} func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc, opt *transport.Options) (roundTripCloser, error) { if conf == nil { conf = defaultQuicConfig.Clone() - } else if len(conf.Versions) == 0 { + } + if len(conf.Versions) == 0 { conf = conf.Clone() - conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} + conf.Versions = []quic.VersionNumber{Version1} } if len(conf.Versions) != 1 { return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") @@ -100,6 +101,10 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con } else { tlsConf = tlsConf.Clone() } + if tlsConf.ServerName == "" { + host, _, _ := net.SplitHostPort(hostname) + tlsConf.ServerName = host + } // Replace existing ALPNs by H3 tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index ea16105a..ed26e49a 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -10,6 +10,7 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "time" "github.com/imroc/req/v3/internal/transport" @@ -20,7 +21,7 @@ import ( ) // declare this as a variable, such that we can it mock it in the tests -var quicDialer = quic.DialEarlyContext +var quicDialer = quic.DialEarly type roundTripCloser interface { RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) @@ -28,6 +29,11 @@ type roundTripCloser interface { io.Closer } +type roundTripCloserWithCount struct { + roundTripCloser + useCount atomic.Int64 +} + // RoundTripper implements the http.RoundTripper interface type RoundTripper struct { *transport.Options @@ -67,7 +73,7 @@ type RoundTripper struct { Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc, opt *transport.Options) (roundTripCloser, error) // so we can mock it in tests - clients map[string]roundTripCloser + clients map[string]*roundTripCloserWithCount udpConn *net.UDPConn } @@ -139,6 +145,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. if err != nil { return nil, err } + defer cl.useCount.Add(-1) rsp, err := cl.RoundTripOpt(req, opt) if err != nil { r.removeClient(hostname) @@ -163,11 +170,12 @@ func (r *RoundTripper) RoundTripOnlyCachedConn(req *http.Request) (*http.Respons // AddConn add a http3 connection, dial new conn if not exists. func (r *RoundTripper) AddConn(addr string) error { + addr = authorityAddr("https", addr) c, _, err := r.getClient(addr, false) if err != nil { return err } - client, ok := c.(*client) + client, ok := c.roundTripCloser.(*client) if !ok { return errors.New("bad client type") } @@ -178,12 +186,12 @@ func (r *RoundTripper) AddConn(addr string) error { return client.handshakeErr } -func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) { +func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) { r.mutex.Lock() defer r.mutex.Unlock() if r.clients == nil { - r.clients = make(map[string]roundTripCloser) + r.clients = make(map[string]*roundTripCloserWithCount) } client, ok := r.clients[hostname] @@ -206,7 +214,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri } dial = r.makeDialer() } - client, err = newCl( + c, err := newCl( hostname, r.TLSClientConfig, &roundTripperOpts{ @@ -224,8 +232,10 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri if err != nil { return nil, false, err } + client = &roundTripCloserWithCount{roundTripCloser: c} r.clients[hostname] = client } + client.useCount.Add(1) return client, isReused, nil } @@ -291,6 +301,17 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf if err != nil { return nil, err } - return quicDialer(ctx, r.udpConn, udpAddr, addr, tlsCfg, cfg) + return quicDialer(ctx, r.udpConn, udpAddr, tlsCfg, cfg) + } +} + +func (r *RoundTripper) CloseIdleConnections() { + r.mutex.Lock() + defer r.mutex.Unlock() + for hostname, client := range r.clients { + if client.useCount.Load() == 0 { + client.Close() + delete(r.clients, hostname) + } } } diff --git a/internal/quic-go/quicvarint/varint.go b/internal/quic-go/quicvarint/varint.go index fdb00353..60d17e3d 100644 --- a/internal/quic-go/quicvarint/varint.go +++ b/internal/quic-go/quicvarint/varint.go @@ -68,25 +68,6 @@ func Read(r io.ByteReader) (uint64, error) { return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil } -// Write writes i in the QUIC varint format to w. -// Deprecated: use Append instead. -func Write(w Writer, i uint64) { - if i <= maxVarInt1 { - w.WriteByte(uint8(i)) - } else if i <= maxVarInt2 { - w.Write([]byte{uint8(i>>8) | 0x40, uint8(i)}) - } else if i <= maxVarInt4 { - w.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}) - } else if i <= maxVarInt8 { - w.Write([]byte{ - uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), - uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), - }) - } else { - panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) - } -} - // Append appends i in the QUIC varint format. func Append(b []byte, i uint64) []byte { if i <= maxVarInt1 { From bc28d016cf52a03e41ccaefc4e6453805aaddfe2 Mon Sep 17 00:00:00 2001 From: Ronaldinho Date: Thu, 8 Jun 2023 12:50:30 +0800 Subject: [PATCH 691/843] fix: cookies to be added multiple times in retrying --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index dcd7566c..07561a99 100644 --- a/client.go +++ b/client.go @@ -1403,7 +1403,7 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { } req := &http.Request{ Method: r.Method, - Header: r.Headers, + Header: r.Headers.Clone(), URL: r.URL, Host: host, Proto: "HTTP/1.1", From 2ed9f81c09d0eb1925df24cdff599dc6686ff90a Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 14 Jun 2023 10:31:43 +0800 Subject: [PATCH 692/843] go mod tidy --- go.mod | 1 - go.sum | 61 +++++----------------------------------------------------- 2 files changed, 5 insertions(+), 57 deletions(-) diff --git a/go.mod b/go.mod index a148c6e4..4d57d101 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,6 @@ require ( github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect - github.com/quic-go/qtls-go1-18 v0.2.0 // indirect github.com/quic-go/qtls-go1-19 v0.3.2 // indirect github.com/quic-go/qtls-go1-20 v0.2.2 // indirect golang.org/x/crypto v0.9.0 // indirect diff --git a/go.sum b/go.sum index 7835462c..9a77f6b9 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,13 @@ -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= -github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= -github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/google/pprof v0.0.0-20230426061923-93006964c1fc h1:AGDHt781oIcL4EFk7cPnvBUYTwU8BEU6GDTO3ZMn1sE= -github.com/google/pprof v0.0.0-20230426061923-93006964c1fc/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9tRnGuw+ZVCB3FazXODD6JE1R8= github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -23,95 +15,54 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI= -github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= -github.com/onsi/ginkgo/v2 v2.9.2 h1:BA2GMJOtfGAfagzYtrAlufIP0lq6QERkFmHLMLPwFSU= -github.com/onsi/ginkgo/v2 v2.9.2/go.mod h1:WHcJJG2dIlcCqVfBAwUCrJxSPFb6v4azBwgxeMeDuts= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= -github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q= +github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-18 v0.2.0 h1:5ViXqBZ90wpUcZS0ge79rf029yx0dYB0McyPJwqqj7U= -github.com/quic-go/qtls-go1-18 v0.2.0/go.mod h1:moGulGHK7o6O8lSPSZNoOwcLvJKJ85vVNc7oJFD65bc= -github.com/quic-go/qtls-go1-19 v0.2.0 h1:Cvn2WdhyViFUHoOqK52i51k4nDX8EwIh5VJiVM4nttk= -github.com/quic-go/qtls-go1-19 v0.2.0/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= -github.com/quic-go/qtls-go1-20 v0.1.0 h1:d1PK3ErFy9t7zxKsG3NXBJXZjp/kMLoIb3y/kV54oAI= -github.com/quic-go/qtls-go1-20 v0.1.0/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= -github.com/quic-go/quic-go v0.32.0 h1:lY02md31s1JgPiiyfqJijpu/UX/Iun304FI3yUqX7tA= -github.com/quic-go/quic-go v0.32.0/go.mod h1:/fCsKANhQIeD5l76c2JFU+07gVE3KaA0FP+0zMWwfwo= -github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU= -github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/quic-go v0.35.0 h1:JXIf219xJK+4qGeY52rlnrVqeB2AXUAwfLU9JSoWXwg= github.com/quic-go/quic-go v0.35.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= -golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= -golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= -golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= -golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= -golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o= -golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= -golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU= -golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= -golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= -golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= -golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= -golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y= -golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -119,7 +70,5 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= From 344c22c4aca98df9956f8c9b2ba0101986950019 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 14 Jun 2023 11:12:33 +0800 Subject: [PATCH 693/843] integrate utls to support tls fingerprinting resistance --- client.go | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 4 +++ go.sum | 8 ++++++ 3 files changed, 97 insertions(+) diff --git a/client.go b/client.go index 07561a99..fc49d77d 100644 --- a/client.go +++ b/client.go @@ -22,6 +22,7 @@ import ( "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/util" + utls "github.com/refraction-networking/utls" ) // DefaultClient returns the global default Client. @@ -1007,6 +1008,90 @@ func (c *Client) SetDial(fn func(ctx context.Context, network, addr string) (net return c } +// SetTLSFingerprintChrome uses tls fingerprint of Chrome browser. +func (c *Client) SetTLSFingerprintChrome() *Client { + return c.SetTLSFingerprint(utls.HelloChrome_Auto) +} + +// SetTLSFingerprintFirefox uses tls fingerprint of Firefox browser. +func (c *Client) SetTLSFingerprintFirefox() *Client { + return c.SetTLSFingerprint(utls.HelloFirefox_Auto) +} + +// SetTLSFingerprintEdge uses tls fingerprint of Edge browser. +func (c *Client) SetTLSFingerprintEdge() *Client { + return c.SetTLSFingerprint(utls.HelloEdge_Auto) +} + +// SetTLSFingerprintQQ uses tls fingerprint of QQ browser. +func (c *Client) SetTLSFingerprintQQ() *Client { + return c.SetTLSFingerprint(utls.HelloQQ_Auto) +} + +// SetTLSFingerprintSafari uses tls fingerprint of Safari browser. +func (c *Client) SetTLSFingerprintSafari() *Client { + return c.SetTLSFingerprint(utls.HelloSafari_Auto) +} + +// SetTLSFingerprint360 uses tls fingerprint of 360 browser. +func (c *Client) SetTLSFingerprint360() *Client { + return c.SetTLSFingerprint(utls.Hello360_Auto) +} + +// SetTLSFingerprintIOS uses tls fingerprint of IOS. +func (c *Client) SetTLSFingerprintIOS() *Client { + return c.SetTLSFingerprint(utls.HelloIOS_Auto) +} + +// SetTLSFingerprintAndroid uses tls fingerprint of Android. +func (c *Client) SetTLSFingerprintAndroid() *Client { + return c.SetTLSFingerprint(utls.HelloAndroid_11_OkHttp) +} + +// uTLSConn is wrapper of UConn which implements the net.Conn interface. +type uTLSConn struct { + *utls.UConn +} + +func (conn *uTLSConn) ConnectionState() tls.ConnectionState { + cs := conn.Conn.ConnectionState() + return tls.ConnectionState{ + Version: cs.Version, + HandshakeComplete: cs.HandshakeComplete, + DidResume: cs.DidResume, + CipherSuite: cs.CipherSuite, + NegotiatedProtocol: cs.NegotiatedProtocol, + NegotiatedProtocolIsMutual: cs.NegotiatedProtocolIsMutual, + ServerName: cs.ServerName, + PeerCertificates: cs.PeerCertificates, + VerifiedChains: cs.VerifiedChains, + SignedCertificateTimestamps: cs.SignedCertificateTimestamps, + OCSPResponse: cs.OCSPResponse, + TLSUnique: cs.TLSUnique, + } +} + +// SetTLSFingerprint set the tls fingerprint for tls handshake, will use utls +// (https://github.com/refraction-networking/utls) to perform the tls handshake, +// which uses the specified clientHelloID to simulate the tls fingerprint. +func (c *Client) SetTLSFingerprint(clientHelloID utls.ClientHelloID) *Client { + c.SetDialTLS(func(ctx context.Context, network, addr string) (net.Conn, error) { + plainConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + utlsConfig := &utls.Config{ServerName: hostname, NextProtos: c.GetTLSClientConfig().NextProtos} + conn := utls.UClient(plainConn, utlsConfig, clientHelloID) + return &uTLSConn{conn}, nil + }) + return c +} + // SetTLSHandshakeTimeout set the TLS handshake timeout. func (c *Client) SetTLSHandshakeTimeout(timeout time.Duration) *Client { c.t.SetTLSHandshakeTimeout(timeout) diff --git a/go.mod b/go.mod index 4d57d101..06047ad9 100644 --- a/go.mod +++ b/go.mod @@ -11,13 +11,17 @@ require ( ) require ( + github.com/andybalholm/brotli v1.0.4 // indirect + github.com/gaukas/godicttls v0.0.3 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/mock v1.6.0 // indirect github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/klauspost/compress v1.15.15 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/quic-go/qtls-go1-19 v0.3.2 // indirect github.com/quic-go/qtls-go1-20 v0.2.2 // indirect + github.com/refraction-networking/utls v1.3.2 // indirect golang.org/x/crypto v0.9.0 // indirect golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect golang.org/x/mod v0.10.0 // indirect diff --git a/go.sum b/go.sum index 9a77f6b9..bf58ebf9 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,10 @@ +github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= +github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk= +github.com/gaukas/godicttls v0.0.3/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= @@ -15,6 +19,8 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= +github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= @@ -28,6 +34,8 @@ github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8G github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= github.com/quic-go/quic-go v0.35.0 h1:JXIf219xJK+4qGeY52rlnrVqeB2AXUAwfLU9JSoWXwg= github.com/quic-go/quic-go v0.35.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= +github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8= +github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= From 7e40b68c28d668a4fab6550c42e24b310d119488 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 14 Jun 2023 19:16:03 +0800 Subject: [PATCH 694/843] do some chores --- transport.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/transport.go b/transport.go index 56e5cb42..a2425623 100644 --- a/transport.go +++ b/transport.go @@ -1814,7 +1814,7 @@ func (pc *persistConn) addTLS(ctx context.Context, name string, trace *httptrace if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - err := tlsConn.Handshake() + err := tlsConn.HandshakeContext(ctx) if timer != nil { timer.Stop() } @@ -1847,10 +1847,6 @@ func newHttp2NotSupportedError(negotiatedProtocol string) error { return errors.New(errMsg) } -type erringRoundTripper interface { - RoundTripErr() error -} - func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { pconn = &persistConn{ t: t, From 4a6fedcb78f4abc9c7fc5401ac6f75d890e81329 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 15 Jun 2023 11:50:26 +0800 Subject: [PATCH 695/843] let tls fingerprinting works event a proxy is set --- client.go | 42 ++++++++++++++++---- internal/transport/option.go | 4 ++ transport.go | 74 +++++++++++++++++++++++++++++++++--- 3 files changed, 107 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index fc49d77d..0fe4dc89 100644 --- a/client.go +++ b/client.go @@ -1074,21 +1074,47 @@ func (conn *uTLSConn) ConnectionState() tls.ConnectionState { // SetTLSFingerprint set the tls fingerprint for tls handshake, will use utls // (https://github.com/refraction-networking/utls) to perform the tls handshake, // which uses the specified clientHelloID to simulate the tls fingerprint. +// Note this is valid for HTTP1 and HTTP2, not HTTP3. func (c *Client) SetTLSFingerprint(clientHelloID utls.ClientHelloID) *Client { - c.SetDialTLS(func(ctx context.Context, network, addr string) (net.Conn, error) { - plainConn, err := net.Dial(network, addr) - if err != nil { - return nil, err - } + fn := func(ctx context.Context, addr string, plainConn net.Conn) (conn net.Conn, tlsState *tls.ConnectionState, err error) { colonPos := strings.LastIndex(addr, ":") if colonPos == -1 { colonPos = len(addr) } hostname := addr[:colonPos] utlsConfig := &utls.Config{ServerName: hostname, NextProtos: c.GetTLSClientConfig().NextProtos} - conn := utls.UClient(plainConn, utlsConfig, clientHelloID) - return &uTLSConn{conn}, nil - }) + uconn := &uTLSConn{utls.UClient(plainConn, utlsConfig, clientHelloID)} + err = uconn.HandshakeContext(ctx) + if err != nil { + return + } + cs := uconn.Conn.ConnectionState() + conn = uconn + tlsState = &tls.ConnectionState{ + Version: cs.Version, + HandshakeComplete: cs.HandshakeComplete, + DidResume: cs.DidResume, + CipherSuite: cs.CipherSuite, + NegotiatedProtocol: cs.NegotiatedProtocol, + NegotiatedProtocolIsMutual: cs.NegotiatedProtocolIsMutual, + ServerName: cs.ServerName, + PeerCertificates: cs.PeerCertificates, + VerifiedChains: cs.VerifiedChains, + SignedCertificateTimestamps: cs.SignedCertificateTimestamps, + OCSPResponse: cs.OCSPResponse, + TLSUnique: cs.TLSUnique, + } + return + } + c.SetTLSHandshake(fn) + return c +} + +// SetTLSHandshake set the custom tls handshake function, only valid for HTTP1 and HTTP2, not HTTP3, +// it specifies an optional dial function for tls handshake, it works even if a proxy is set, can be +// used to customize the tls fingerprint. +func (c *Client) SetTLSHandshake(fn func(ctx context.Context, addr string, plainConn net.Conn) (conn net.Conn, tlsState *tls.ConnectionState, err error)) *Client { + c.t.SetTLSHandshake(fn) return c } diff --git a/internal/transport/option.go b/internal/transport/option.go index 714aea34..fc033e15 100644 --- a/internal/transport/option.go +++ b/internal/transport/option.go @@ -43,6 +43,10 @@ type Options struct { // past the TLS handshake. DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + // TLSHandshakeContext specifies an optional dial function for tls handshake, + // it works even if a proxy is set, can be used to customize the tls fingerprint. + TLSHandshakeContext func(ctx context.Context, addr string, plainConn net.Conn) (conn net.Conn, tlsState *tls.ConnectionState, err error) + // TLSClientConfig specifies the TLS configuration to use with // tls.Client. // If nil, the default configuration is used. diff --git a/transport.go b/transport.go index a2425623..d5ea2b5d 100644 --- a/transport.go +++ b/transport.go @@ -444,7 +444,8 @@ func (t *Transport) SetDial(fn func(ctx context.Context, network, addr string) ( } // SetDialTLS set the custom DialTLSContext function, only valid for HTTP1 and HTTP2, which specifies -// an optional dial function for creating TLS connections for non-proxied HTTPS requests. +// an optional dial function for creating TLS connections for non-proxied HTTPS requests (proxy will +// not work if set). // // If it is nil, DialContext and TLSClientConfig are used. // @@ -455,6 +456,14 @@ func (t *Transport) SetDialTLS(fn func(ctx context.Context, network, addr string return t } +// SetTLSHandshake set the custom tls handshake function, only valid for HTTP1 and HTTP2, not HTTP3, +// it specifies an optional dial function for tls handshake, it works even if a proxy is set, can be +// used to customize the tls fingerprint. +func (t *Transport) SetTLSHandshake(fn func(ctx context.Context, addr string, plainConn net.Conn) (conn net.Conn, tlsState *tls.ConnectionState, err error)) *Transport { + t.TLSHandshakeContext = fn + return t +} + type pendingAltSvc struct { CurrentIndex int Entries []*altsvc.AltSvc @@ -1793,6 +1802,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { // tunnel, this function establishes a nested TLS session inside the encrypted channel. // The remote endpoint's name may be overridden by TLSClientConfig.ServerName. func (pc *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace, forProxy bool) error { + // Initiate TLS and check remote host name against certificate. cfg := cloneTLSConfig(pc.t.TLSClientConfig) if cfg.ServerName == "" { @@ -1847,6 +1857,42 @@ func newHttp2NotSupportedError(negotiatedProtocol string) error { return errors.New(errMsg) } +func (t *Transport) customTlsHandshake(ctx context.Context, trace *httptrace.ClientTrace, addr string, pconn *persistConn) error { + errc := make(chan error, 2) + var timer *time.Timer // for canceling TLS handshake + if d := t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + conn, tlsState, err := t.TLSHandshakeContext(ctx, addr, pconn.conn) + if err != nil { + if timer != nil { + timer.Stop() + } + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } + } else { + pconn.conn = conn + pconn.tlsState = tlsState + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(*tlsState, nil) + } + } + errc <- err + }() + if err := <-errc; err != nil { + pconn.conn.Close() + return err + } + return nil +} + func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { pconn = &persistConn{ t: t, @@ -1904,12 +1950,23 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { return nil, wrapErr(err) } - if err = pconn.addTLS(ctx, firstTLSHost, trace, cm.proxyURL != nil); err != nil { - return nil, wrapErr(err) + if t.TLSHandshakeContext != nil && cm.proxyURL == nil { + err = t.customTlsHandshake(ctx, trace, firstTLSHost, pconn) + if err != nil { + return nil, err + } + } else { + if err = pconn.addTLS(ctx, firstTLSHost, trace, cm.proxyURL != nil); err != nil { + return nil, wrapErr(err) + } } } } + if t.Debugf != nil && cm.proxyURL != nil { + t.Debugf("connect %s via proxy %s", cm.targetAddr, cm.proxyURL.String()) + } + // Proxy setup. switch { case cm.proxyURL == nil: @@ -2018,8 +2075,15 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } if cm.proxyURL != nil && cm.targetScheme == "https" { - if err := pconn.addTLS(ctx, cm.tlsHost(), trace, false); err != nil { - return nil, err + if t.TLSHandshakeContext != nil { + err := t.customTlsHandshake(ctx, trace, cm.tlsHost(), pconn) + if err != nil { + return nil, err + } + } else { + if err := pconn.addTLS(ctx, cm.tlsHost(), trace, false); err != nil { + return nil, err + } } } From 6b480f8d447a61463f3197d73150bbcb74b2d230 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 15 Jun 2023 12:11:43 +0800 Subject: [PATCH 696/843] use HandshakeContext if custom tls dialer if not handshaked --- pkg/tls/conn.go | 13 +++++++++++++ transport.go | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pkg/tls/conn.go b/pkg/tls/conn.go index 8c6ef1ad..428cd930 100644 --- a/pkg/tls/conn.go +++ b/pkg/tls/conn.go @@ -1,6 +1,7 @@ package tls import ( + "context" "crypto/tls" "net" ) @@ -23,4 +24,16 @@ type Conn interface { // For control over canceling or setting a timeout on a handshake, use // HandshakeContext or the Dialer's DialContext method instead. Handshake() error + + // HandshakeContext runs the client or server handshake + // protocol if it has not yet been run. + // + // The provided Context must be non-nil. If the context is canceled before + // the handshake is complete, the handshake is interrupted and an error is returned. + // Once the handshake has completed, cancellation of the context will not affect the + // connection. + // + // Most uses of this package need not call HandshakeContext explicitly: the + // first Read or Write will call it automatically. + HandshakeContext(ctx context.Context) error } diff --git a/transport.go b/transport.go index d5ea2b5d..9cb52421 100644 --- a/transport.go +++ b/transport.go @@ -1923,7 +1923,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - if err := tc.Handshake(); err != nil { + if err := tc.HandshakeContext(ctx); err != nil { go pconn.conn.Close() if trace != nil && trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(tls.ConnectionState{}, err) From 9db035691e5de06a9ac23c9a3fe4b88a474c74b4 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 15 Jun 2023 13:45:39 +0800 Subject: [PATCH 697/843] add SetTLSFingerprintRandomized --- client.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client.go b/client.go index 0fe4dc89..29bfeedc 100644 --- a/client.go +++ b/client.go @@ -1048,6 +1048,11 @@ func (c *Client) SetTLSFingerprintAndroid() *Client { return c.SetTLSFingerprint(utls.HelloAndroid_11_OkHttp) } +// SetTLSFingerprintRandomized uses randomized tls fingerprint. +func (c *Client) SetTLSFingerprintRandomized() *Client { + return c.SetTLSFingerprint(utls.HelloRandomized) +} + // uTLSConn is wrapper of UConn which implements the net.Conn interface. type uTLSConn struct { *utls.UConn From 1b5cf921807c742c69fbca12be6d03ee487e2c71 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 15 Jun 2023 13:58:24 +0800 Subject: [PATCH 698/843] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 7dc4c680..1d5b6c87 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ Full documentation is available on the official website: https://req.cool. * **Smart by Default**: Detect and decode to utf-8 automatically if possible to avoid garbled characters (See [Auto Decode](https://req.cool/docs/tutorial/auto-decode/)), marshal request body and unmarshal response body automatically according to the Content-Type. * **Support Multiple HTTP Versions**: Support `HTTP/1.1`, `HTTP/2`, and `HTTP/3`, and can automatically detect the server side and select the optimal HTTP version for requests, you can also force the protocol if you want (See [Force HTTP version](https://req.cool/docs/tutorial/force-http-version/)). * **Support Retry**: Support automatic request retry and is fully customizable (See [Retry](https://req.cool/docs/tutorial/retry/)). +* **TLS Fingerprinting**: Support tls fingerprinting resistance, so that we can access websites that prohibit crawler programs by identifying TLS handshake fingerprints (See [TLS Fingerprinting](https://req.cool/docs/tutorial/tls-fingerprinting/)). * **Easy Download and Upload**: You can download and upload files with simple request settings, and even set a callback to show real-time progress (See [Download](https://req.cool/docs/tutorial/download/) and [Upload](https://req.cool/docs/tutorial/upload/)). * **Exportable**: `req.Transport` is exportable. Compared with `http.Transport`, it also supports HTTP3, dump content, middleware, etc. It can directly replace the Transport of `http.Client` in existing projects, and obtain more powerful functions with minimal code change. * **Extensible**: Support Middleware for Request, Response, Client and Transport (See [Request and Response Middleware](https://req.cool/docs/tutorial/middleware-for-request-and-response/)) and [Client and Transport Middleware](https://req.cool/docs/tutorial/middleware-for-client-and-transport/)). From dee4e87023bc77387a26935938654de2e95646f9 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 15 Jun 2023 17:34:26 +0800 Subject: [PATCH 699/843] merge upstream http2: 2023-06-13(f5464d) --- internal/http2/flow.go | 2 +- internal/http2/frame.go | 11 +- internal/http2/pipe.go | 6 +- internal/http2/pipe_test.go | 11 +- internal/http2/server_test.go | 114 ++++++++- internal/http2/transport.go | 68 +++++- internal/http2/transport_test.go | 399 +++++++++++++++++-------------- transport_test.go | 12 - 8 files changed, 408 insertions(+), 215 deletions(-) diff --git a/internal/http2/flow.go b/internal/http2/flow.go index 750ac52f..b7dbd186 100644 --- a/internal/http2/flow.go +++ b/internal/http2/flow.go @@ -18,7 +18,7 @@ type inflow struct { unsent int32 } -// set sets the initial window. +// init sets the initial window. func (f *inflow) init(n int32) { f.avail = n } diff --git a/internal/http2/frame.go b/internal/http2/frame.go index bf829fe9..34284fe5 100644 --- a/internal/http2/frame.go +++ b/internal/http2/frame.go @@ -715,6 +715,15 @@ func (h2f *Framer) WriteData(streamID uint32, endStream bool, data []byte) error // It is the caller's responsibility not to violate the maximum frame size // and to not call other Write methods concurrently. func (h2f *Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { + if err := h2f.startWriteDataPadded(streamID, endStream, data, pad); err != nil { + return err + } + return h2f.endWrite() +} + +// startWriteDataPadded is WriteDataPadded, but only writes the frame to the Framer's internal buffer. +// The caller should call endWrite to flush the frame to the underlying writer. +func (h2f *Framer) startWriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { if !validStreamID(streamID) && !h2f.AllowIllegalWrites { return errStreamID } @@ -744,7 +753,7 @@ func (h2f *Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad [] } h2f.wbuf = append(h2f.wbuf, data...) h2f.wbuf = append(h2f.wbuf, pad...) - return h2f.endWrite() + return nil } // A SettingsFrame conveys configuration parameters that affect how diff --git a/internal/http2/pipe.go b/internal/http2/pipe.go index c15b8a77..684d984f 100644 --- a/internal/http2/pipe.go +++ b/internal/http2/pipe.go @@ -88,13 +88,9 @@ func (p *pipe) Write(d []byte) (n int, err error) { p.c.L = &p.mu } defer p.c.Signal() - if p.err != nil { + if p.err != nil || p.breakErr != nil { return 0, errClosedPipeWrite } - if p.breakErr != nil { - p.unread += len(d) - return len(d), nil // discard when there is no reader - } return p.b.Write(d) } diff --git a/internal/http2/pipe_test.go b/internal/http2/pipe_test.go index c21c007e..391ab3bf 100644 --- a/internal/http2/pipe_test.go +++ b/internal/http2/pipe_test.go @@ -8,6 +8,7 @@ import ( "bytes" "errors" "io" + "io/ioutil" "testing" ) @@ -111,7 +112,7 @@ func TestPipeBreakWithError(t *testing.T) { io.WriteString(p, "foo") a := errors.New("test err") p.BreakWithError(a) - all, err := io.ReadAll(p) + all, err := ioutil.ReadAll(p) if string(all) != "" { t.Errorf("read bytes = %q; want empty string", all) } @@ -124,14 +125,14 @@ func TestPipeBreakWithError(t *testing.T) { if p.Len() != 3 { t.Errorf("pipe should have 3 unread bytes") } - // Write should succeed silently. - if n, err := p.Write([]byte("abc")); err != nil || n != 3 { - t.Errorf("Write(abc) after break\ngot %v, %v\nwant 0, nil", n, err) + // Write should fail. + if n, err := p.Write([]byte("abc")); err != errClosedPipeWrite || n != 0 { + t.Errorf("Write(abc) after break\ngot %v, %v\nwant 0, errClosedPipeWrite", n, err) } if p.b != nil { t.Errorf("buffer should be nil after Write") } - if p.Len() != 6 { + if p.Len() != 3 { t.Errorf("pipe should have 6 unread bytes") } // Read should fail. diff --git a/internal/http2/server_test.go b/internal/http2/server_test.go index 77d0cf2f..be803d76 100644 --- a/internal/http2/server_test.go +++ b/internal/http2/server_test.go @@ -790,7 +790,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { if s.NewWriteScheduler != nil { sc.writeSched = s.NewWriteScheduler() } else { - sc.writeSched = NewRandomWriteScheduler() + sc.writeSched = newRoundRobinWriteScheduler() } // These start at the RFC-specified defaults. If there is a higher @@ -4119,7 +4119,8 @@ func (wr *FrameWriteRequest) replyToWriter(err error) { // writeQueue is used by implementations of WriteScheduler. type writeQueue struct { - s []FrameWriteRequest + s []FrameWriteRequest + prev, next *writeQueue } func (q *writeQueue) empty() bool { return len(q.s) == 0 } @@ -5415,3 +5416,112 @@ func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) { }) <-donec } + +type roundRobinWriteScheduler struct { + // control contains control frames (SETTINGS, PING, etc.). + control writeQueue + + // streams maps stream ID to a queue. + streams map[uint32]*writeQueue + + // stream queues are stored in a circular linked list. + // head is the next stream to write, or nil if there are no streams open. + head *writeQueue + + // pool of empty queues for reuse. + queuePool writeQueuePool +} + +// newRoundRobinWriteScheduler constructs a new write scheduler. +// The round robin scheduler priorizes control frames +// like SETTINGS and PING over DATA frames. +// When there are no control frames to send, it performs a round-robin +// selection from the ready streams. +func newRoundRobinWriteScheduler() WriteScheduler { + ws := &roundRobinWriteScheduler{ + streams: make(map[uint32]*writeQueue), + } + return ws +} + +func (ws *roundRobinWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) { + if ws.streams[streamID] != nil { + panic(fmt.Errorf("stream %d already opened", streamID)) + } + q := ws.queuePool.get() + ws.streams[streamID] = q + if ws.head == nil { + ws.head = q + q.next = q + q.prev = q + } else { + // Queues are stored in a ring. + // Insert the new stream before ws.head, putting it at the end of the list. + q.prev = ws.head.prev + q.next = ws.head + q.prev.next = q + q.next.prev = q + } +} + +func (ws *roundRobinWriteScheduler) CloseStream(streamID uint32) { + q := ws.streams[streamID] + if q == nil { + return + } + if q.next == q { + // This was the only open stream. + ws.head = nil + } else { + q.prev.next = q.next + q.next.prev = q.prev + if ws.head == q { + ws.head = q.next + } + } + delete(ws.streams, streamID) + ws.queuePool.put(q) +} + +func (ws *roundRobinWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) {} + +func (ws *roundRobinWriteScheduler) Push(wr FrameWriteRequest) { + if wr.isControl() { + ws.control.push(wr) + return + } + q := ws.streams[wr.StreamID()] + if q == nil { + // This is a closed stream. + // wr should not be a HEADERS or DATA frame. + // We push the request onto the control queue. + if wr.DataSize() > 0 { + panic("add DATA on non-open stream") + } + ws.control.push(wr) + return + } + q.push(wr) +} + +func (ws *roundRobinWriteScheduler) Pop() (FrameWriteRequest, bool) { + // Control and RST_STREAM frames first. + if !ws.control.empty() { + return ws.control.shift(), true + } + if ws.head == nil { + return FrameWriteRequest{}, false + } + q := ws.head + for { + if wr, ok := q.consume(math.MaxInt32); ok { + ws.head = q.next + return wr, true + } + q = q.next + if q == ws.head { + break + } + } + return FrameWriteRequest{}, false +} diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 37f251ee..7f7f34db 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -288,6 +288,7 @@ func (cs *clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error } func (cs *clientStream) abortStream(err error) { + fmt.Println("abortStream") cs.cc.mu.Lock() defer cs.cc.mu.Unlock() cs.abortStreamLocked(err) @@ -459,10 +460,11 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res traceGotConn(req, cc, reused) res, err := cc.RoundTrip(req) if err != nil && retry <= 6 { + roundTripErr := err if req, err = shouldRetryRequest(req, err); err == nil { // After the first retry, do exponential backoff with 10% jitter. if retry == 0 { - t.vlogf("RoundTrip retrying after failure: %v", err) + t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue } backoff := float64(uint(1) << (uint(retry) - 1)) @@ -471,7 +473,7 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res timer := backoffNewTimer(d) select { case <-timer.C: - t.vlogf("RoundTrip retrying after failure: %v", err) + t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue case <-req.Context().Done(): timer.Stop() @@ -1111,6 +1113,45 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return res, nil } + cancelRequest := func(cs *clientStream, err error) error { + cs.cc.mu.Lock() + fmt.Println("cancelRequest---") + cs.abortStreamLocked(err) + bodyClosed := cs.reqBodyClosed + if cs.ID != 0 { + // This request may have failed because of a problem with the connection, + // or for some unrelated reason. (For example, the user might have canceled + // the request without waiting for a response.) Mark the connection as + // not reusable, since trying to reuse a dead connection is worse than + // unnecessarily creating a new one. + // + // If cs.ID is 0, then the request was never allocated a stream ID and + // whatever went wrong was unrelated to the connection. We might have + // timed out waiting for a stream slot when StrictMaxConcurrentStreams + // is set, for example, in which case retrying on a different connection + // will not help. + cs.cc.doNotReuse = true + } + cs.cc.mu.Unlock() + // Wait for the request body to be closed. + // + // If nothing closed the body before now, abortStreamLocked + // will have started a goroutine to close it. + // + // Closing the body before returning avoids a race condition + // with net/http checking its readTrackingBody to see if the + // body was read from or closed. See golang/go#60041. + // + // The body is closed in a separate goroutine without the + // connection mutex held, but dropping the mutex before waiting + // will keep us from holding it indefinitely if the body + // close is slow for some reason. + if bodyClosed != nil { + <-bodyClosed + } + return err + } + for { select { case <-cs.respHeaderRecv: @@ -1124,16 +1165,16 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { // golang.org/issue/49645 return handleResponseHeaders() default: + fmt.Println("just abort") waitDone() - return nil, cs.abortErr + return nil, cancelRequest(cs, cs.abortErr) } case <-ctx.Done(): - err := ctx.Err() - cs.abortStream(err) - return nil, err + fmt.Println("ctx done") + return nil, cancelRequest(cs, ctx.Err()) case <-cs.reqCancel: - cs.abortStream(common.ErrRequestCanceled) - return nil, common.ErrRequestCanceled + fmt.Println("reqCancel") + return nil, cancelRequest(cs, common.ErrRequestCanceled) } } } @@ -1398,6 +1439,7 @@ func (cs *clientStream) cleanupWriteRequest(err error) { } } if err != nil { + fmt.Println("cleanupWriteRequest") cs.abortStream(err) // possibly redundant, but harmless if cs.sentHeaders { if se, ok := err.(StreamError); ok { @@ -1429,7 +1471,7 @@ func (cs *clientStream) cleanupWriteRequest(err error) { close(cs.donec) } -// awaitOpenSlotForStream waits until len(streams) < maxConcurrentStreams. +// awaitOpenSlotForStreamLocked waits until len(streams) < maxConcurrentStreams. // Must hold cc.mu. func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error { for { @@ -1749,7 +1791,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail // 8.1.2.3 Request Pseudo-Header Fields // The :path pseudo-header field includes the path and query parts of the // target URI (the path-absolute production and optionally a '?' character - // followed by the query production (see Sections 3.3 and 3.4 of + // followed by the query production, see Sections 3.3 and 3.4 of // [RFC3986]). f(":authority", host) m := req.Method @@ -2475,6 +2517,9 @@ func (b transportResponseBody) Close() error { cs := b.cs cc := cs.cc + cs.bufPipe.BreakWithError(errClosedResponseBody) + cs.abortStream(errClosedResponseBody) + unread := cs.bufPipe.Len() if unread > 0 { cc.mu.Lock() @@ -2493,9 +2538,6 @@ func (b transportResponseBody) Close() error { cc.wmu.Unlock() } - cs.bufPipe.BreakWithError(errClosedResponseBody) - cs.abortStream(errClosedResponseBody) - select { case <-cs.donec: case <-cs.ctx.Done(): diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go index 0cad0c63..d587c1ae 100644 --- a/internal/http2/transport_test.go +++ b/internal/http2/transport_test.go @@ -748,7 +748,6 @@ func newClientTester(t *testing.T) *clientTester { cc, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) - } sc, err := ln.Accept() if err != nil { @@ -761,6 +760,18 @@ func newClientTester(t *testing.T) *clientTester { return ct } +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err == nil { + return ln + } + ln, err = net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Fatal(err) + } + return ln +} + func (ct *clientTester) greet(settings ...Setting) { buf := make([]byte, len(ClientPreface)) _, err := io.ReadFull(ct.sc, buf) @@ -1730,6 +1741,17 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { defer tr.CloseIdleConnections() checkRoundTrip := func(req *http.Request, wantErr error, desc string) { + // Make an arbitrary request to ensure we get the server's + // settings frame and initialize peerMaxHeaderListSize. + req0, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatalf("newRequest: NewRequest: %v", err) + } + res0, err := tr.RoundTrip(req0) + if err != nil { + t.Errorf("%v: Initial RoundTrip err = %v", desc, err) + } + res0.Body.Close() res, err := tr.RoundTrip(req) if err != wantErr { if res != nil { @@ -1790,13 +1812,9 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { return req } - // Make an arbitrary request to ensure we get the server's - // settings frame and initialize peerMaxHeaderListSize. + // Validate peerMaxHeaderListSize. req := newRequest() checkRoundTrip(req, nil, "Initial request") - - // Get the ClientConn associated with the request and validate - // peerMaxHeaderListSize. addr := authorityAddr(req.URL.Scheme, req.URL.Host) cc, err := tr.connPool().GetClientConn(req, addr, true) if err != nil { @@ -2729,58 +2747,6 @@ func TestTransportReadHeadResponseWithBody(t *testing.T) { ct.run() } -type neverEnding byte - -func (b neverEnding) Read(p []byte) (int, error) { - for i := range p { - p[i] = byte(b) - } - return len(p), nil -} - -// golang.org/issue/15425: test that a handler closing the request -// body doesn't terminate the stream to the peer. (It just stops -// readability from the handler's side, and eventually the client -// runs out of flow control tokens) -func TestTransportHandlerBodyClose(t *testing.T) { - const bodySize = 10 << 20 - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - r.Body.Close() - io.Copy(w, io.LimitReader(tests.NeverEnding('A'), bodySize)) - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - g0 := runtime.NumGoroutine() - - const numReq = 10 - for i := 0; i < numReq; i++ { - req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(tests.NeverEnding('A'), bodySize)}) - if err != nil { - t.Fatal(err) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - n, err := io.Copy(io.Discard, res.Body) - res.Body.Close() - if n != bodySize || err != nil { - t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize) - } - } - tr.CloseIdleConnections() - - if !waitCondition(5*time.Second, 100*time.Millisecond, func() bool { - gd := runtime.NumGoroutine() - g0 - return gd < numReq/2 - }) { - t.Errorf("appeared to leak goroutines") - } -} - // https://golang.org/issue/15930 func TestTransportFlowControl(t *testing.T) { const bufLen = 64 << 10 @@ -3705,35 +3671,33 @@ func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.D ct.run() } -func TestTransportRetryAfterGOAWAY(t *testing.T) { - var dialer struct { - sync.Mutex - count int - } - ct1 := make(chan *clientTester) - ct2 := make(chan *clientTester) - - ln := tests.NewLocalListener(t) +func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) { + ln := newLocalListener(t) defer ln.Close() + var ( + mu sync.Mutex + count int + conns []net.Conn + ) + var wg sync.WaitGroup tr := &Transport{ Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, } tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { - dialer.Lock() - defer dialer.Unlock() - dialer.count++ - if dialer.count == 3 { - return nil, errors.New("unexpected number of dials") - } + mu.Lock() + defer mu.Unlock() + count++ cc, err := net.Dial("tcp", ln.Addr().String()) if err != nil { return nil, fmt.Errorf("dial error: %v", err) } + conns = append(conns, cc) sc, err := ln.Accept() if err != nil { return nil, fmt.Errorf("accept error: %v", err) } + conns = append(conns, sc) ct := &clientTester{ t: t, tr: tr, @@ -3741,19 +3705,25 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { sc: sc, fr: NewFramer(sc, sc), } - switch dialer.count { - case 1: - ct1 <- ct - case 2: - ct2 <- ct - } + wg.Add(1) + go func(count int) { + defer wg.Done() + server(count, ct) + }(count) return cc, nil } - errs := make(chan error, 3) + client(tr) + tr.CloseIdleConnections() + ln.Close() + for _, c := range conns { + c.Close() + } + wg.Wait() +} - // Client. - go func() { +func TestTransportRetryAfterGOAWAY(t *testing.T) { + client := func(tr *Transport) { req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) res, err := tr.RoundTrip(req) if res != nil { @@ -3763,102 +3733,76 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { } } if err != nil { - err = fmt.Errorf("RoundTrip: %v", err) + t.Errorf("RoundTrip: %v", err) } - errs <- err - }() - - connToClose := make(chan io.Closer, 2) - - // Server for the first request. - go func() { - ct := <-ct1 - - connToClose <- ct.cc - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err) - return - } - t.Logf("server1 got %v", hf) - if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { - errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err) - return - } - errs <- nil - }() + } - // Server for the second request. - go func() { - ct := <-ct2 + server := func(count int, ct *clientTester) { + switch count { + case 1: + ct.greet() + hf, err := ct.firstHeaders() + if err != nil { + t.Errorf("server1 failed reading HEADERS: %v", err) + return + } + t.Logf("server1 got %v", hf) + if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { + t.Errorf("server1 failed writing GOAWAY: %v", err) + return + } + case 2: + ct.greet() + hf, err := ct.firstHeaders() + if err != nil { + t.Errorf("server2 failed reading HEADERS: %v", err) + return + } + t.Logf("server2 got %v", hf) - connToClose <- ct.cc - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err) + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) + err = ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + if err != nil { + t.Errorf("server2 failed writing response HEADERS: %v", err) + } + default: + t.Errorf("unexpected number of dials") return } - t.Logf("server2 got %v", hf) - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - err = ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - if err != nil { - errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err) - } else { - errs <- nil - } - }() - - for k := 0; k < 3; k++ { - err := <-errs - if err != nil { - t.Error(err) - } } - close(connToClose) - for c := range connToClose { - c.Close() - } + testClientMultipleDials(t, client, server) } func TestTransportRetryAfterRefusedStream(t *testing.T) { clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } + client := func(tr *Transport) { defer close(clientDone) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := ct.tr.RoundTrip(req) + resp, err := tr.RoundTrip(req) if err != nil { - return fmt.Errorf("RoundTrip: %v", err) + t.Errorf("RoundTrip: %v", err) + return } resp.Body.Close() if resp.StatusCode != 204 { - return fmt.Errorf("Status = %v; want 204", resp.StatusCode) + t.Errorf("Status = %v; want 204", resp.StatusCode) + return } - return nil } - ct.server = func() error { + + server := func(count int, ct *clientTester) { ct.greet() var buf bytes.Buffer enc := hpack.NewEncoder(&buf) - nreq := 0 - for { f, err := ct.fr.ReadFrame() if err != nil { @@ -3867,19 +3811,19 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { // If the client's done, it // will have reported any // errors on its side. - return nil default: - return err + t.Error(err) } + return } switch f := f.(type) { case *WindowUpdateFrame, *SettingsFrame: case *HeadersFrame: if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + t.Errorf("headers should have END_HEADERS be ended: %v", f) + return } - nreq++ - if nreq == 1 { + if count == 1 { ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) } else { enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) @@ -3891,11 +3835,13 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { }) } default: - return fmt.Errorf("Unexpected client frame %v", f) + t.Errorf("Unexpected client frame %v", f) + return } } } - ct.run() + + testClientMultipleDials(t, client, server) } func TestTransportRetryHasLimit(t *testing.T) { @@ -4072,6 +4018,7 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { greet := make(chan struct{}) // server sends initial SETTINGS frame gotRequest := make(chan struct{}) // server received a request clientDone := make(chan struct{}) + cancelClientRequest := make(chan struct{}) // Collect errors from goroutines. var wg sync.WaitGroup @@ -4150,9 +4097,8 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body) if k == maxConcurrent { // This request will be canceled. - cancel := make(chan struct{}) - req.Cancel = cancel - close(cancel) + req.Cancel = cancelClientRequest + close(cancelClientRequest) _, err := ct.tr.RoundTrip(req) close(clientRequestCancelled) if err == nil { @@ -5526,14 +5472,21 @@ func TestTransportRetriesOnStreamProtocolError(t *testing.T) { } func TestClientConnReservations(t *testing.T) { - cc := &ClientConn{ - reqHeaderMu: make(chan struct{}, 1), - streams: make(map[uint32]*clientStream), - maxConcurrentStreams: initialMaxConcurrentStreams, - nextStreamID: 1, - t: &Transport{Options: &transport.Options{}}, + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + }, func(s *Server) { + s.MaxConcurrentStreams = initialMaxConcurrentStreams + }) + defer st.Close() + + tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} + defer tr.CloseIdleConnections() + + cc, err := tr.newClientConn(st.cc, false) + if err != nil { + t.Fatal(err) } - cc.cond = sync.NewCond(&cc.mu) + + req, _ := http.NewRequest("GET", st.ts.URL, nil) n := 0 for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() { n++ @@ -5541,8 +5494,8 @@ func TestClientConnReservations(t *testing.T) { if n != initialMaxConcurrentStreams { t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams) } - if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) { - t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err) + if _, err := cc.RoundTrip(req); err != nil { + t.Fatalf("RoundTrip error = %v", err) } n2 := 0 for n2 <= 5 && cc.ReserveNewRequest() { @@ -5554,7 +5507,7 @@ func TestClientConnReservations(t *testing.T) { // Use up all the reservations for i := 0; i < n; i++ { - cc.RoundTrip(new(http.Request)) + cc.RoundTrip(req) } n2 = 0 @@ -5959,3 +5912,97 @@ func TestTransportSlowClose(t *testing.T) { } res.Body.Close() } + +type blockReadConn struct { + net.Conn + blockc chan struct{} +} + +func (c *blockReadConn) Read(b []byte) (n int, err error) { + <-c.blockc + return c.Conn.Read(b) +} + +func TestTransportReuseAfterError(t *testing.T) { + serverReqc := make(chan struct{}, 3) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + serverReqc <- struct{}{} + }, optOnlyServer) + defer st.Close() + + var ( + unblockOnce sync.Once + blockc = make(chan struct{}) + connCountMu sync.Mutex + connCount int + ) + tr := &Transport{ + Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + // The first connection dialed will block on reads until blockc is closed. + connCountMu.Lock() + defer connCountMu.Unlock() + connCount++ + conn, err := tls.Dial(network, addr, cfg) + if err != nil { + return nil, err + } + if connCount == 1 { + return &blockReadConn{ + Conn: conn, + blockc: blockc, + }, nil + } + return conn, nil + }, + } + defer tr.CloseIdleConnections() + defer unblockOnce.Do(func() { + // Ensure that reads on blockc are unblocked if we return early. + close(blockc) + }) + + req, _ := http.NewRequest("GET", st.ts.URL, nil) + + // Request 1 is made on conn 1. + // Reading the response will block. + // Wait until the server receives the request, and continue. + req1c := make(chan struct{}) + go func() { + defer close(req1c) + res1, err := tr.RoundTrip(req.Clone(context.Background())) + if err != nil { + t.Errorf("request 1: %v", err) + } else { + res1.Body.Close() + } + }() + <-serverReqc + + // Request 2 is also made on conn 1. + // Reading the response will block. + // The request is canceled once the server receives it. + // Conn 1 should now be flagged as unfit for reuse. + req2Ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-serverReqc + cancel() + }() + _, err := tr.RoundTrip(req.Clone(req2Ctx)) + if err == nil { + t.Errorf("request 2 unexpectedly succeeded (want cancel)") + } + + // Request 3 is made on a new conn, and succeeds. + res3, err := tr.RoundTrip(req.Clone(context.Background())) + if err != nil { + t.Fatalf("request 3: %v", err) + } + res3.Body.Close() + + // Unblock conn 1, and verify that request 1 completes. + unblockOnce.Do(func() { + close(blockc) + }) + <-req1c +} diff --git a/transport_test.go b/transport_test.go index 741a6e9a..8a9e4cf6 100644 --- a/transport_test.go +++ b/transport_test.go @@ -2780,18 +2780,6 @@ func TestTransportCloseResponseBody(t *testing.T) { } } -type fooProto struct{} - -func (fooProto) RoundTrip(req *http.Request) (*http.Response, error) { - res := &http.Response{ - Status: "200 OK", - StatusCode: 200, - Header: make(http.Header), - Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())), - } - return res, nil -} - func TestTransportNoHost(t *testing.T) { defer afterTest(t) tr := T() From f417a9b32b87217fd9b8a47c144fb36fa2020884 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 15 Jun 2023 19:26:01 +0800 Subject: [PATCH 700/843] merge upstream net/http: 2023-06-15(c54632) --- internal/transport/option.go | 7 +- roundtrip.go | 3 +- roundtrip_js.go | 142 +++++++++++++++--- transfer.go | 16 +- transport.go | 48 ++++-- transport_default_other.go | 1 - ...default_js.go => transport_default_wasm.go | 3 +- 7 files changed, 176 insertions(+), 44 deletions(-) rename transport_default_js.go => transport_default_wasm.go (89%) diff --git a/internal/transport/option.go b/internal/transport/option.go index fc033e15..6f9f97f6 100644 --- a/internal/transport/option.go +++ b/internal/transport/option.go @@ -23,6 +23,11 @@ type Options struct { // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy func(*http.Request) (*url.URL, error) + // OnProxyConnectResponse is called when the Transport gets an HTTP response from + // a proxy for a CONNECT request. It's called before the check for a 200 OK response. + // If it returns an error, the request fails with that error. + OnProxyConnectResponse func(ctx context.Context, proxyURL *url.URL, connectReq *http.Request, connectRes *http.Response) error + // DialContext specifies the dial function for creating unencrypted TCP connections. // If DialContext is nil, then the transport dials using package net. // @@ -53,7 +58,7 @@ type Options struct { // If non-nil, HTTP/2 support may not be enabled by default. TLSClientConfig *tls.Config - // TLSHandshakeTimeout specifies the maximum amount of time waiting to + // TLSHandshakeTimeout specifies the maximum amount of time to // wait for a TLS handshake. Zero means no timeout. TLSHandshakeTimeout time.Duration diff --git a/roundtrip.go b/roundtrip.go index 0a63edbb..52b49409 100644 --- a/roundtrip.go +++ b/roundtrip.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js || !wasm -// +build !js !wasm +//go:build !js package req diff --git a/roundtrip_js.go b/roundtrip_js.go index 771a31d1..af51f13e 100644 --- a/roundtrip_js.go +++ b/roundtrip_js.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build js && wasm -// +build js,wasm package req @@ -46,6 +45,44 @@ const jsFetchRedirect = "js.fetch:redirect" // the browser globals. var jsFetchMissing = js.Global().Get("fetch").IsUndefined() +// jsFetchDisabled will be true if the "process" global is present. +// We use this as an indicator that we're running in Node.js. We +// want to disable the Fetch API in Node.js because it breaks +// our wasm tests. See https://go.dev/issue/57613 for more information. +var jsFetchDisabled = !js.Global().Get("process").IsUndefined() + +// Determine whether the JS runtime supports streaming request bodies. +// Courtesy: https://developer.chrome.com/articles/fetch-streaming-requests/#feature-detection +func supportsPostRequestStreams() bool { + requestOpt := js.Global().Get("Object").New() + requestBody := js.Global().Get("ReadableStream").New() + + requestOpt.Set("method", "POST") + requestOpt.Set("body", requestBody) + + // There is quite a dance required to define a getter if you do not have the { get property() { ... } } + // syntax available. However, it is possible: + // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Functions/get#defining_a_getter_on_existing_objects_using_defineproperty + duplexCalled := false + duplexGetterObj := js.Global().Get("Object").New() + duplexGetterFunc := js.FuncOf(func(this js.Value, args []js.Value) any { + duplexCalled = true + return "half" + }) + defer duplexGetterFunc.Release() + duplexGetterObj.Set("get", duplexGetterFunc) + js.Global().Get("Object").Call("defineProperty", requestOpt, "duplex", duplexGetterObj) + + // Slight difference here between the aforementioned example: Non-browser-based runtimes + // do not have a non-empty API Base URL (https://html.spec.whatwg.org/multipage/webappapis.html#api-base-url) + // so we have to supply a valid URL here. + requestObject := js.Global().Get("Request").New("https://www.example.org", requestOpt) + + hasContentTypeHeader := requestObject.Get("headers").Call("has", "Content-Type").Bool() + + return duplexCalled && !hasContentTypeHeader +} + // RoundTrip implements the RoundTripper interface using the WHATWG Fetch API. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { // The Transport has a documented contract that states that if the DialContext or @@ -54,7 +91,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { // though they are deprecated. Therefore, if any of these are set, we should obey // the contract and dial using the regular round-trip instead. Otherwise, we'll try // to fall back on the Fetch API, unless it's not available. - if t.DialContext != nil || t.DialTLSContext != nil || jsFetchMissing { + if t.DialContext != nil || t.DialTLSContext != nil || jsFetchMissing || jsFetchDisabled { return t.roundTrip(req) } @@ -94,23 +131,60 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { } opt.Set("headers", headers) + var readableStreamStart, readableStreamPull, readableStreamCancel js.Func if req.Body != nil { - // TODO(johanbrandhorst): Stream request body when possible. - // See https://bugs.chromium.org/p/chromium/issues/detail?id=688906 for Blink issue. - // See https://bugzilla.mozilla.org/show_bug.cgi?id=1387483 for Firefox issue. - // See https://github.com/web-platform-tests/wpt/issues/7693 for WHATWG tests issue. - // See https://developer.mozilla.org/en-US/docs/Web/API/Streams_API for more details on the Streams API - // and browser support. - body, err := io.ReadAll(req.Body) - if err != nil { - req.Body.Close() // RoundTrip must always close the body, including on errors. - return nil, err - } - req.Body.Close() - if len(body) != 0 { - buf := uint8Array.New(len(body)) - js.CopyBytesToJS(buf, body) - opt.Set("body", buf) + if !supportsPostRequestStreams() { + body, err := io.ReadAll(req.Body) + if err != nil { + req.Body.Close() // RoundTrip must always close the body, including on errors. + return nil, err + } + if len(body) != 0 { + buf := uint8Array.New(len(body)) + js.CopyBytesToJS(buf, body) + opt.Set("body", buf) + } + } else { + readableStreamCtorArg := js.Global().Get("Object").New() + readableStreamCtorArg.Set("type", "bytes") + readableStreamCtorArg.Set("autoAllocateChunkSize", t.writeBufferSize()) + + readableStreamPull = js.FuncOf(func(this js.Value, args []js.Value) any { + controller := args[0] + byobRequest := controller.Get("byobRequest") + if byobRequest.IsNull() { + controller.Call("close") + } + + byobRequestView := byobRequest.Get("view") + + bodyBuf := make([]byte, byobRequestView.Get("byteLength").Int()) + readBytes, readErr := io.ReadFull(req.Body, bodyBuf) + if readBytes > 0 { + buf := uint8Array.New(byobRequestView.Get("buffer")) + js.CopyBytesToJS(buf, bodyBuf) + byobRequest.Call("respond", readBytes) + } + + if readErr == io.EOF || readErr == io.ErrUnexpectedEOF { + controller.Call("close") + } else if readErr != nil { + readErrCauseObject := js.Global().Get("Object").New() + readErrCauseObject.Set("cause", readErr.Error()) + readErr := js.Global().Get("Error").New("io.ReadFull failed while streaming POST body", readErrCauseObject) + controller.Call("error", readErr) + } + // Note: This a return from the pull callback of the controller and *not* RoundTrip(). + return nil + }) + readableStreamCtorArg.Set("pull", readableStreamPull) + + opt.Set("body", js.Global().Get("ReadableStream").New(readableStreamCtorArg)) + // There is a requirement from the WHATWG fetch standard that the duplex property of + // the object given as the options argument to the fetch call be set to 'half' + // when the body property of the same options object is a ReadableStream: + // https://fetch.spec.whatwg.org/#dom-requestinit-duplex + opt.Set("duplex", "half") } } @@ -120,9 +194,14 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { errCh = make(chan error, 1) success, failure js.Func ) - success = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success = js.FuncOf(func(this js.Value, args []js.Value) any { success.Release() failure.Release() + readableStreamCancel.Release() + readableStreamPull.Release() + readableStreamStart.Release() + + req.Body.Close() result := args[0] header := http.Header{} @@ -184,10 +263,31 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { return nil }) - failure = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + failure = js.FuncOf(func(this js.Value, args []js.Value) any { success.Release() failure.Release() - errCh <- fmt.Errorf("net/http: fetch() failed: %s", args[0].Get("message").String()) + readableStreamCancel.Release() + readableStreamPull.Release() + readableStreamStart.Release() + + req.Body.Close() + + err := args[0] + // The error is a JS Error type + // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Error + // We can use the toString() method to get a string representation of the error. + errMsg := err.Call("toString").String() + // Errors can optionally contain a cause. + if cause := err.Get("cause"); !cause.IsUndefined() { + // The exact type of the cause is not defined, + // but if it's another error, we can call toString() on it too. + if !cause.Get("toString").IsUndefined() { + errMsg += ": " + cause.Call("toString").String() + } else if cause.Type() == js.TypeString { + errMsg += ": " + cause.String() + } + } + errCh <- fmt.Errorf("net/http: fetch() failed: %s", errMsg) return nil }) diff --git a/transfer.go b/transfer.go index d2ec69db..58e64793 100644 --- a/transfer.go +++ b/transfer.go @@ -405,7 +405,7 @@ func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err return } -// unwrapBodyReader unwraps the body's inner reader if it's a +// unwrapBody unwraps the body's inner reader if it's a // nopCloser. This is to ensure that body writes sourced from local // files (*os.File types) are properly optimized. // @@ -520,7 +520,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // or close connection when finished, since multipart is not supported yet switch { case t.Chunked: - if noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { + if isResponse && (noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode)) { t.Body = NoBody } else { t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} @@ -555,7 +555,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { return nil } -// Checks whether chunked is part of the encodings stack +// Checks whether chunked is part of the encodings stack. func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } // Checks whether the encoding is explicitly "identity". @@ -590,7 +590,7 @@ func (t *transferReader) parseTransferEncoding() error { if len(raw) != 1 { return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} } - if !ascii.EqualFold(textproto.TrimString(raw[0]), "chunked") { + if !ascii.EqualFold(raw[0], "chunked") { return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} } @@ -638,7 +638,7 @@ func fixLength(isResponse bool, status int, requestMethod string, header http.He } // Logic based on response type or status - if noResponseBodyExpected(requestMethod) { + if isResponse && noResponseBodyExpected(requestMethod) { return 0, nil } if status/100 == 1 { @@ -673,7 +673,7 @@ func fixLength(isResponse bool, status int, requestMethod string, header http.He // Determine whether to hang up after sending a request and body, or // receiving a response and body -// 'header' is the request headers +// 'header' is the request headers. func shouldClose(major, minor int, header http.Header, removeCloseHeader bool) bool { if major < 1 { return true @@ -692,7 +692,7 @@ func shouldClose(major, minor int, header http.Header, removeCloseHeader bool) b return hasClose } -// Parse the trailer header +// Parse the trailer header. func fixTrailer(header http.Header, chunked bool) (http.Header, error) { vv, ok := header["Trailer"] if !ok { @@ -1010,7 +1010,7 @@ var nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct { }{})) // unwrapNopCloser return the underlying reader and true if r is a NopCloser -// else it return false +// else it return false. func unwrapNopCloser(r io.Reader) (underlyingReader io.Reader, isNopCloser bool) { switch reflect.TypeOf(r) { case nopCloserType, nopCloserWriterToType: diff --git a/transport.go b/transport.go index 9cb52421..baca70f2 100644 --- a/transport.go +++ b/transport.go @@ -855,7 +855,8 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error for _, v := range vv { if !httpguts.ValidHeaderFieldValue(v) { closeBody(req) - return nil, fmt.Errorf("net/http: invalid header field value %q for key %v", v, k) + // Don't include the value in the error, because it may be sensitive. + return nil, fmt.Errorf("net/http: invalid header field value for %q", k) } } } @@ -962,6 +963,12 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error if e, ok := err.(transportReadFromServerError); ok { err = e.err } + if b, ok := req.Body.(*readTrackingBody); ok && !b.didClose { + // Issue 49621: Close the request body if pconn.roundTrip + // didn't do so already. This can happen if the pconn + // write loop exits without reading the write request. + closeBody(req) + } return nil, err } testHookRoundTripRetried() @@ -1465,7 +1472,11 @@ var zeroDialer net.Dialer func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) { if t.DialContext != nil { - return t.DialContext(ctx, network, addr) + c, err := t.DialContext(ctx, network, addr) + if c == nil && err == nil { + err = errors.New("net/http: Transport.DialContext hook returned (nil, nil)") + } + return c, err } return zeroDialer.DialContext(ctx, network, addr) } @@ -2064,6 +2075,14 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers conn.Close() return nil, err } + + if t.OnProxyConnectResponse != nil { + err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp) + if err != nil { + return nil, err + } + } + if resp.StatusCode != 200 { _, text, ok := util.CutString(resp.Status, " ") conn.Close() @@ -2471,7 +2490,7 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte if pc.nwrite == startBytesWritten { return nothingWrittenError{err} } - return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err) + return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %w", err) } return err } @@ -2677,7 +2696,7 @@ func (pc *persistConn) readLoopPeekFailLocked(peekErr error) { // common case. pc.closeLocked(errServerClosedIdle) } else { - pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %v", peekErr)) + pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %w", peekErr)) } } @@ -2812,6 +2831,10 @@ type nothingWrittenError struct { error } +func (nwe nothingWrittenError) Unwrap() error { + return nwe.error +} + func (pc *persistConn) writeLoop() { defer close(pc.writeLoopDone) for { @@ -3028,7 +3051,10 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo // maxWriteWaitBeforeConnReuse is how long the a Transport RoundTrip // will wait to see the Request's Body.Write result after getting a // response from the server. See comments in (*persistConn).wroteRequest. -const maxWriteWaitBeforeConnReuse = 50 * time.Millisecond +// +// In tests, we set this to a large value to avoid flakiness from inconsistent +// recycling of connections. +var maxWriteWaitBeforeConnReuse = 50 * time.Millisecond // wroteRequest is a check before recycling a connection that the previous write // (from writeLoop above) happened and was successful. @@ -3224,7 +3250,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, er req.logf("writeErrCh resv: %T/%#v", err, err) } if err != nil { - pc.close(fmt.Errorf("write error: %v", err)) + pc.close(fmt.Errorf("write error: %w", err)) return nil, pc.mapRoundTripError(req, startBytesWritten, err) } if d := pc.t.ResponseHeaderTimeout; d > 0 { @@ -3326,17 +3352,21 @@ var portMap = map[string]string{ "socks5": "1080", } -// canonicalAddr returns url.Host but always with a ":port" suffix -func canonicalAddr(url *url.URL) string { +func idnaASCIIFromURL(url *url.URL) string { addr := url.Hostname() if v, err := idnaASCII(addr); err == nil { addr = v } + return addr +} + +// canonicalAddr returns url.Host but always with a ":port" suffix. +func canonicalAddr(url *url.URL) string { port := url.Port() if port == "" { port = portMap[url.Scheme] } - return net.JoinHostPort(addr, port) + return net.JoinHostPort(idnaASCIIFromURL(url), port) } // bodyEOFSignal is used by the HTTP/1 transport when reading response diff --git a/transport_default_other.go b/transport_default_other.go index a18e66b5..7d4e9103 100644 --- a/transport_default_other.go +++ b/transport_default_other.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js || !wasm -// +build !js !wasm package req diff --git a/transport_default_js.go b/transport_default_wasm.go similarity index 89% rename from transport_default_js.go rename to transport_default_wasm.go index 7cd8e335..731acdf2 100644 --- a/transport_default_js.go +++ b/transport_default_wasm.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build js && wasm -// +build js,wasm +//go:build (js && wasm) || wasip1 package req From 91c613c721f4c39a23ff100a42b7240af595f195 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 15 Jun 2023 19:26:44 +0800 Subject: [PATCH 701/843] go mod tidy --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 06047ad9..2bb2bcb7 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/quic-go/qpack v0.4.0 github.com/quic-go/quic-go v0.35.0 + github.com/refraction-networking/utls v1.3.2 golang.org/x/net v0.10.0 golang.org/x/text v0.9.0 ) @@ -21,7 +22,6 @@ require ( github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/quic-go/qtls-go1-19 v0.3.2 // indirect github.com/quic-go/qtls-go1-20 v0.2.2 // indirect - github.com/refraction-networking/utls v1.3.2 // indirect golang.org/x/crypto v0.9.0 // indirect golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect golang.org/x/mod v0.10.0 // indirect From f4f07846da53721db1dde281826c231a23a6a183 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 16 Jun 2023 10:22:05 +0800 Subject: [PATCH 702/843] remove debug print --- internal/http2/transport.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 7f7f34db..837599e0 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -288,7 +288,6 @@ func (cs *clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error } func (cs *clientStream) abortStream(err error) { - fmt.Println("abortStream") cs.cc.mu.Lock() defer cs.cc.mu.Unlock() cs.abortStreamLocked(err) @@ -1115,7 +1114,6 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { cancelRequest := func(cs *clientStream, err error) error { cs.cc.mu.Lock() - fmt.Println("cancelRequest---") cs.abortStreamLocked(err) bodyClosed := cs.reqBodyClosed if cs.ID != 0 { @@ -1165,15 +1163,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { // golang.org/issue/49645 return handleResponseHeaders() default: - fmt.Println("just abort") waitDone() return nil, cancelRequest(cs, cs.abortErr) } case <-ctx.Done(): - fmt.Println("ctx done") return nil, cancelRequest(cs, ctx.Err()) case <-cs.reqCancel: - fmt.Println("reqCancel") return nil, cancelRequest(cs, common.ErrRequestCanceled) } } @@ -1439,7 +1434,6 @@ func (cs *clientStream) cleanupWriteRequest(err error) { } } if err != nil { - fmt.Println("cleanupWriteRequest") cs.abortStream(err) // possibly redundant, but harmless if cs.sentHeaders { if se, ok := err.(StreamError); ok { From e2b23b9c2f2b79c413e520ca1697c6223b486aba Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 16 Jun 2023 10:29:03 +0800 Subject: [PATCH 703/843] merge go115.go into transport.go in http2 --- internal/http2/go115.go | 28 ---------------------------- internal/http2/transport.go | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 28 deletions(-) delete mode 100644 internal/http2/go115.go diff --git a/internal/http2/go115.go b/internal/http2/go115.go deleted file mode 100644 index a3a4dfc5..00000000 --- a/internal/http2/go115.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.15 -// +build go1.15 - -package http2 - -import ( - "context" - "crypto/tls" - reqtls "github.com/imroc/req/v3/pkg/tls" -) - -// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS -// connection. -func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (reqtls.Conn, error) { - dialer := &tls.Dialer{ - Config: cfg, - } - conn, err := dialer.DialContext(ctx, network, addr) - if err != nil { - return nil, err - } - tlsCn := conn.(reqtls.Conn) - return tlsCn, nil -} diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 837599e0..21bf6099 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -21,6 +21,7 @@ import ( "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/netutil" "github.com/imroc/req/v3/internal/transport" + reqtls "github.com/imroc/req/v3/pkg/tls" "io" "io/fs" "log" @@ -577,6 +578,20 @@ func (t *Transport) newTLSConfig(host string) *tls.Config { return cfg } +// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS +// connection. +func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (reqtls.Conn, error) { + dialer := &tls.Dialer{ + Config: cfg, + } + conn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + tlsCn := conn.(reqtls.Conn) + return tlsCn, nil +} + func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) { if t.DialTLS != nil { return t.DialTLS From 5aa91a6389a637b28b7384c9e57700cd6de9efd8 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 16 Jun 2023 10:41:16 +0800 Subject: [PATCH 704/843] go get -u github.com/quic-go/quic-go --- go.mod | 18 +++++++++--------- go.sum | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 2bb2bcb7..b1f7e59d 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,10 @@ go 1.19 require ( github.com/hashicorp/go-multierror v1.1.1 github.com/quic-go/qpack v0.4.0 - github.com/quic-go/quic-go v0.35.0 + github.com/quic-go/quic-go v0.35.1 github.com/refraction-networking/utls v1.3.2 - golang.org/x/net v0.10.0 - golang.org/x/text v0.9.0 + golang.org/x/net v0.11.0 + golang.org/x/text v0.10.0 ) require ( @@ -16,15 +16,15 @@ require ( github.com/gaukas/godicttls v0.0.3 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/mock v1.6.0 // indirect - github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 // indirect + github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/klauspost/compress v1.15.15 // indirect - github.com/onsi/ginkgo/v2 v2.9.5 // indirect + github.com/onsi/ginkgo/v2 v2.10.0 // indirect github.com/quic-go/qtls-go1-19 v0.3.2 // indirect github.com/quic-go/qtls-go1-20 v0.2.2 // indirect - golang.org/x/crypto v0.9.0 // indirect + golang.org/x/crypto v0.10.0 // indirect golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect - golang.org/x/mod v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect - golang.org/x/tools v0.9.1 // indirect + golang.org/x/mod v0.11.0 // indirect + golang.org/x/sys v0.9.0 // indirect + golang.org/x/tools v0.10.0 // indirect ) diff --git a/go.sum b/go.sum index bf58ebf9..8b4b60f3 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9tRnGuw+ZVCB3FazXODD6JE1R8= github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= +github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs= +github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -23,6 +25,8 @@ github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7y github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= +github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs= +github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -34,6 +38,8 @@ github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8G github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= github.com/quic-go/quic-go v0.35.0 h1:JXIf219xJK+4qGeY52rlnrVqeB2AXUAwfLU9JSoWXwg= github.com/quic-go/quic-go v0.35.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= +github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo= +github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8= github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -44,16 +50,22 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= +golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= +golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -63,16 +75,22 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= +golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= +golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= +golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From dd9c751666b4de36c34dbc4e23447676d3602a4e Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 16 Jun 2023 10:56:07 +0800 Subject: [PATCH 705/843] support quic-go v0.35.1 --- internal/http3/client.go | 8 ++++++-- internal/http3/roundtrip.go | 23 +++++++++++++---------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/internal/http3/client.go b/internal/http3/client.go index 43249483..6c04253c 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -102,8 +102,12 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con tlsConf = tlsConf.Clone() } if tlsConf.ServerName == "" { - host, _, _ := net.SplitHostPort(hostname) - tlsConf.ServerName = host + sni, _, err := net.SplitHostPort(hostname) + if err != nil { + // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. + sni = hostname + } + tlsConf.ServerName = sni } // Replace existing ALPNs by H3 tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index ed26e49a..3171e09d 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -20,9 +20,6 @@ import ( "golang.org/x/net/http/httpguts" ) -// declare this as a variable, such that we can it mock it in the tests -var quicDialer = quic.DialEarly - type roundTripCloser interface { RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) HandshakeComplete() bool @@ -74,7 +71,7 @@ type RoundTripper struct { newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc, opt *transport.Options) (roundTripCloser, error) // so we can mock it in tests clients map[string]*roundTripCloserWithCount - udpConn *net.UDPConn + transport *quic.Transport } // RoundTripOpt are options for the Transport.RoundTripOpt method. @@ -206,11 +203,12 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTr } dial := r.Dial if dial == nil { - if r.udpConn == nil { - r.udpConn, err = net.ListenUDP("udp", nil) + if r.transport == nil { + udpConn, err := net.ListenUDP("udp", nil) if err != nil { return nil, false, err } + r.transport = &quic.Transport{Conn: udpConn} } dial = r.makeDialer() } @@ -259,9 +257,14 @@ func (r *RoundTripper) Close() error { } } r.clients = nil - if r.udpConn != nil { - r.udpConn.Close() - r.udpConn = nil + if r.transport != nil { + if err := r.transport.Close(); err != nil { + return err + } + if err := r.transport.Conn.Close(); err != nil { + return err + } + r.transport = nil } return nil } @@ -301,7 +304,7 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf if err != nil { return nil, err } - return quicDialer(ctx, r.udpConn, udpAddr, tlsCfg, cfg) + return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) } } From a6ae22b5d9cd306ffba0d27637d53ecb9128b7e4 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 16 Jun 2023 11:14:48 +0800 Subject: [PATCH 706/843] do not detect alt-svc if it's already http3 --- roundtrip.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roundtrip.go b/roundtrip.go index 52b49409..69953742 100644 --- a/roundtrip.go +++ b/roundtrip.go @@ -26,7 +26,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error if err != nil { return } - if t.altSvcJar != nil { + if resp.ProtoMajor != 3 && t.altSvcJar != nil { if v := resp.Header.Get("alt-svc"); v != "" { t.handleAltSvc(req, v) } From 394268529f92f715dd0eae094ff4b7362e1a7162 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 16 Jun 2023 12:00:27 +0800 Subject: [PATCH 707/843] add comments to SetContext --- request.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/request.go b/request.go index 8fdf5daa..8c738fc3 100644 --- a/request.go +++ b/request.go @@ -846,6 +846,9 @@ func (r *Request) Context() context.Context { // to interrupt the request execution if ctx.Done() channel is closed. // See https://blog.golang.org/context article and the "context" package // documentation. +// +// Attention: make sure call SetContext before EnableDumpXXX if you want to +// dump at the request level. func (r *Request) SetContext(ctx context.Context) *Request { if ctx != nil { r.ctx = ctx From 314aef5dcd6c48434472dd0b7f40ce8c898519cb Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 16 Jun 2023 14:57:12 +0800 Subject: [PATCH 708/843] fix transport middleware not work after clone(#233) --- transport.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/transport.go b/transport.go index baca70f2..e18ca80f 100644 --- a/transport.go +++ b/transport.go @@ -134,6 +134,7 @@ type Transport struct { // Only valid when DisableAutoDecode is true. autoDecodeContentType func(contentType string) bool wrappedRoundTrip http.RoundTripper + httpRoundTripWrappers []HttpRoundTripWrapper } // NewTransport is an alias of T @@ -194,11 +195,15 @@ func (t *Transport) WrapRoundTrip(wrappers ...HttpRoundTripWrapper) *Transport { return t } if t.wrappedRoundTrip == nil { + t.httpRoundTripWrappers = wrappers fn := func(req *http.Request) (*http.Response, error) { return t.roundTrip(req) } t.wrappedRoundTrip = HttpRoundTripFunc(fn) + } else { + t.httpRoundTripWrappers = append(t.httpRoundTripWrappers, wrappers...) } + for _, w := range wrappers { t.wrappedRoundTrip = w(t.wrappedRoundTrip) } @@ -724,8 +729,17 @@ func (t *Transport) Clone() *Transport { Options: t.Options.Clone(), disableAutoDecode: t.disableAutoDecode, autoDecodeContentType: t.autoDecodeContentType, - wrappedRoundTrip: t.wrappedRoundTrip, forceHttpVersion: t.forceHttpVersion, + httpRoundTripWrappers: t.httpRoundTripWrappers, + } + if len(tt.httpRoundTripWrappers) > 0 { // clone transport middleware + fn := func(req *http.Request) (*http.Response, error) { + return tt.roundTrip(req) + } + tt.wrappedRoundTrip = HttpRoundTripFunc(fn) + for _, w := range tt.httpRoundTripWrappers { + tt.wrappedRoundTrip = w(tt.wrappedRoundTrip) + } } if t.t2 != nil { tt.t2 = &http2.Transport{Options: &tt.Options} From 1b84d71fde26b11dc47ec494882465c8ccdd7b29 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 16 Jun 2023 15:06:26 +0800 Subject: [PATCH 709/843] fix client middleware not work after clone --- client.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/client.go b/client.go index 29bfeedc..8a004c1d 100644 --- a/client.go +++ b/client.go @@ -68,6 +68,7 @@ type Client struct { udBeforeRequest []RequestMiddleware afterResponse []ResponseMiddleware wrappedRoundTrip RoundTripper + roundTripWrappers []RoundTripWrapper responseBodyTransformer func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error) resultStateCheckFunc func(resp *Response) ResultState } @@ -1363,6 +1364,14 @@ func (c *Client) Clone() *Client { client.Transport = cc.t cc.httpClient = &client + // clone client middleware + if len(cc.roundTripWrappers) > 0 { + cc.wrappedRoundTrip = roundTripImpl{&cc} + for _, w := range cc.roundTripWrappers { + cc.wrappedRoundTrip = w(cc.wrappedRoundTrip) + } + } + // clone other fields that may need to be cloned cc.Headers = cloneHeaders(c.Headers) cc.Cookies = cloneCookies(c.Cookies) @@ -1473,7 +1482,10 @@ func (c *Client) WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { return c } if c.wrappedRoundTrip == nil { + c.roundTripWrappers = wrappers c.wrappedRoundTrip = roundTripImpl{c} + } else { + c.roundTripWrappers = append(c.roundTripWrappers, wrappers...) } for _, w := range wrappers { c.wrappedRoundTrip = w(c.wrappedRoundTrip) From d56ac371e841336ffa349dc5cc34c68ad1855ed1 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 16 Jun 2023 15:30:04 +0800 Subject: [PATCH 710/843] Allow PATCH multipart request --- middleware.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware.go b/middleware.go index be48388b..74ceb09a 100644 --- a/middleware.go +++ b/middleware.go @@ -195,7 +195,7 @@ func parseRequestBody(c *Client, r *Request) (err error) { return } // handle multipart - if r.isMultiPart && (r.Method != http.MethodPatch) { + if r.isMultiPart { return handleMultiPart(c, r) } From a50ca79d2993417cecdc76b116f2b494a6a36065 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 16 Jun 2023 15:44:05 +0800 Subject: [PATCH 711/843] Add nil check to SetSuccessResult and SetErrorResult --- request.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/request.go b/request.go index 8c738fc3..4ef20701 100644 --- a/request.go +++ b/request.go @@ -364,6 +364,9 @@ func (r *Request) SetResult(result interface{}) *Request { // Request.SetResultStateCheckFunc or Client.SetResultStateCheckFunc to customize // the result state check logic. func (r *Request) SetSuccessResult(result interface{}) *Request { + if result == nil { + return r + } r.Result = util.GetPointer(result) return r } @@ -374,16 +377,19 @@ func (r *Request) SetSuccessResult(result interface{}) *Request { // or Client.SetResultStateCheckFunc to customize the result state check logic. // // Deprecated: Use SetErrorResult result. -func (r *Request) SetError(error interface{}) *Request { - return r.SetErrorResult(error) +func (r *Request) SetError(err interface{}) *Request { + return r.SetErrorResult(err) } // SetErrorResult set the result that response body will be unmarshalled to if // no error occurs and Response.ResultState() returns ErrorState, by default // it requires HTTP status `code >= 400`, you can also use Request.SetResultStateCheckFunc // or Client.SetResultStateCheckFunc to customize the result state check logic. -func (r *Request) SetErrorResult(error interface{}) *Request { - r.Error = util.GetPointer(error) +func (r *Request) SetErrorResult(err interface{}) *Request { + if err == nil { + return r + } + r.Error = util.GetPointer(err) return r } From 767a6b95f8d270abd672ceefc0624075fbcf3a93 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 16 Jun 2023 16:02:48 +0800 Subject: [PATCH 712/843] reset file reader when retry a multipart file upload --- middleware.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/middleware.go b/middleware.go index 74ceb09a..fb76d872 100644 --- a/middleware.go +++ b/middleware.go @@ -63,6 +63,14 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload, r *Request) e return err } defer content.Close() + if r.RetryAttempt > 0 { // reset file reader when retry a multipart file upload + if rs, ok := content.(io.ReadSeeker); ok { + _, err = rs.Seek(0, io.SeekStart) + if err != nil { + return err + } + } + } // Auto detect actual multipart content type cbuf := make([]byte, 512) seeEOF := false From ef6c1ad6955e09fdcc62b37dc69e9978c2469272 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 17 Jun 2023 10:17:31 +0800 Subject: [PATCH 713/843] Support http digest authentication. * Add SetDigestAuth for Request. * Add SetCommonDigestAuth for Client. * Add OnAfterResponse for Request --- client.go | 63 ++++++--- digest.go | 277 ++++++++++++++++++++++++++++++++++++++ internal/header/header.go | 2 + middleware.go | 5 +- request.go | 49 ++++--- retry_test.go | 6 +- 6 files changed, 363 insertions(+), 39 deletions(-) create mode 100644 digest.go diff --git a/client.go b/client.go index 8a004c1d..96793361 100644 --- a/client.go +++ b/client.go @@ -767,57 +767,69 @@ func (c *Client) EnableAutoReadResponse() *Client { return c } -// SetAutoDecodeContentType set the content types that will be auto-detected and decode -// to utf-8 (e.g. "json", "xml", "html", "text"). +// SetAutoDecodeContentType set the content types that will be auto-detected and decode to utf-8 +// (e.g. "json", "xml", "html", "text"). func (c *Client) SetAutoDecodeContentType(contentTypes ...string) *Client { c.t.SetAutoDecodeContentType(contentTypes...) return c } -// SetAutoDecodeContentTypeFunc set the function that determines whether the -// specified `Content-Type` should be auto-detected and decode to utf-8. +// SetAutoDecodeContentTypeFunc set the function that determines whether the specified `Content-Type` should be auto-detected and decode to utf-8. func (c *Client) SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client { c.t.SetAutoDecodeContentTypeFunc(fn) return c } -// SetAutoDecodeAllContentType enable try auto-detect charset and decode all -// content type to utf-8. +// SetAutoDecodeAllContentType enable try auto-detect charset and decode all content type to utf-8. func (c *Client) SetAutoDecodeAllContentType() *Client { c.t.SetAutoDecodeAllContentType() return c } -// DisableAutoDecode disable auto-detect charset and decode to utf-8 -// (enabled by default). +// DisableAutoDecode disable auto-detect charset and decode to utf-8 (enabled by default). func (c *Client) DisableAutoDecode() *Client { c.t.DisableAutoDecode() return c } -// EnableAutoDecode enable auto-detect charset and decode to utf-8 -// (enabled by default). +// EnableAutoDecode enable auto-detect charset and decode to utf-8 (enabled by default). func (c *Client) EnableAutoDecode() *Client { c.t.EnableAutoDecode() return c } -// SetUserAgent set the "User-Agent" header for requests fired from -// the client. +// SetUserAgent set the "User-Agent" header for requests fired from the client. func (c *Client) SetUserAgent(userAgent string) *Client { return c.SetCommonHeader(header.UserAgent, userAgent) } -// SetCommonBearerAuthToken set the bearer auth token for requests -// fired from the client. +// SetCommonBearerAuthToken set the bearer auth token for requests fired from the client. func (c *Client) SetCommonBearerAuthToken(token string) *Client { - return c.SetCommonHeader("Authorization", "Bearer "+token) + return c.SetCommonHeader(header.Authorization, "Bearer "+token) } // SetCommonBasicAuth set the basic auth for requests fired from // the client. func (c *Client) SetCommonBasicAuth(username, password string) *Client { - c.SetCommonHeader("Authorization", util.BasicAuthHeaderValue(username, password)) + c.SetCommonHeader(header.Authorization, util.BasicAuthHeaderValue(username, password)) + return c +} + +// SetCommonDigestAuth sets the Digest Access auth scheme for requests fired from the client. If a server responds with +// 401 and sends a Digest challenge in the WWW-Authenticate Header, requests will be resent with the appropriate +// Authorization Header. +// +// For Example: To set the Digest scheme with user "roc" and password "123456" +// +// client.SetCommonDigestAuth("roc", "123456") +// +// Information about Digest Access Authentication can be found in RFC7616: +// +// https://datatracker.ietf.org/doc/html/rfc7616 +// +// See `Request.SetDigestAuth` +func (c *Client) SetCommonDigestAuth(username, password string) *Client { + c.OnAfterResponse(handleDigestAuthFunc(username, password)) return c } @@ -1500,6 +1512,21 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { err = resp.Err }() + if r.Headers == nil { + r.Headers = make(http.Header) + } + + for _, f := range r.client.udBeforeRequest { + if err = f(r.client, r); err != nil { + return + } + } + for _, f := range r.client.beforeRequest { + if err = f(r.client, r); err != nil { + return + } + } + // setup trace if r.trace == nil && r.client.trace { r.trace = &clientTrace{} @@ -1581,8 +1608,8 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { resp.Body = io.NopCloser(bytes.NewReader(resp.body)) } - for _, f := range r.client.afterResponse { - if e := f(r.client, resp); e != nil { + for _, f := range c.afterResponse { + if e := f(c, resp); e != nil { resp.Err = e } } diff --git a/digest.go b/digest.go new file mode 100644 index 00000000..11d0b91b --- /dev/null +++ b/digest.go @@ -0,0 +1,277 @@ +package req + +import ( + "crypto/md5" + "crypto/rand" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "github.com/imroc/req/v3/internal/header" + "hash" + "io" + "net/http" + "strings" +) + +var ( + errDigestBadChallenge = errors.New("digest: challenge is bad") + errDigestCharset = errors.New("digest: unsupported charset") + errDigestAlgNotSupported = errors.New("digest: algorithm is not supported") + errDigestQopNotSupported = errors.New("digest: no supported qop in list") + errDigestNoQop = errors.New("digest: qop must be specified") +) + +var hashFuncs = map[string]func() hash.Hash{ + "": md5.New, + "MD5": md5.New, + "MD5-sess": md5.New, + "SHA-256": sha256.New, + "SHA-256-sess": sha256.New, + "SHA-512-256": sha512.New, + "SHA-512-256-sess": sha512.New, +} + +// create response middleware for http digest authentication. +func handleDigestAuthFunc(username, password string) ResponseMiddleware { + return func(client *Client, resp *Response) error { + if resp.Err != nil || resp.StatusCode != http.StatusUnauthorized { + return nil + } + auth, err := createDigestAuth(resp.Response, username, password) + if err != nil { + return err + } + r := resp.Request + req := *r.RawRequest + if req.Body != nil { + err = parseRequestBody(client, r) // re-setup body + if err != nil { + return err + } + if r.GetBody != nil { + body, err := r.GetBody() + if err != nil { + return err + } + req.Body = body + req.GetBody = r.GetBody + } + } + if req.Header == nil { + req.Header = make(http.Header) + } + req.Header.Set(header.Authorization, auth) + resp.Response, err = client.GetTransport().RoundTrip(&req) + return err + } +} + +func createDigestAuth(resp *http.Response, username, password string) (auth string, err error) { + chal := resp.Header.Get(header.WwwAuthenticate) + if chal == "" { + return "", errDigestBadChallenge + } + + c, err := parseChallenge(chal) + if err != nil { + return "", err + } + + // Form credentials based on the challenge + cr := newCredentials(resp.Request.URL.RequestURI(), resp.Request.Method, username, password, c) + auth, err = cr.authorize() + return +} + +func newCredentials(digestURI, method, username, password string, c *challenge) *credentials { + return &credentials{ + username: username, + userhash: c.userhash, + realm: c.realm, + nonce: c.nonce, + digestURI: digestURI, + algorithm: c.algorithm, + sessionAlg: strings.HasSuffix(c.algorithm, "-sess"), + opaque: c.opaque, + messageQop: c.qop, + nc: 0, + method: method, + password: password, + } +} + +type challenge struct { + realm string + domain string + nonce string + opaque string + stale string + algorithm string + qop string + userhash string +} + +func parseChallenge(input string) (*challenge, error) { + const ws = " \n\r\t" + const qs = `"` + s := strings.Trim(input, ws) + if !strings.HasPrefix(s, "Digest ") { + return nil, errDigestBadChallenge + } + s = strings.Trim(s[7:], ws) + sl := strings.Split(s, ", ") + c := &challenge{} + var r []string + for i := range sl { + r = strings.SplitN(sl[i], "=", 2) + if len(r) != 2 { + return nil, errDigestBadChallenge + } + switch r[0] { + case "realm": + c.realm = strings.Trim(r[1], qs) + case "domain": + c.domain = strings.Trim(r[1], qs) + case "nonce": + c.nonce = strings.Trim(r[1], qs) + case "opaque": + c.opaque = strings.Trim(r[1], qs) + case "stale": + c.stale = r[1] + case "algorithm": + c.algorithm = r[1] + case "qop": + c.qop = strings.Trim(r[1], qs) + case "charset": + if strings.ToUpper(strings.Trim(r[1], qs)) != "UTF-8" { + return nil, errDigestCharset + } + case "userhash": + c.userhash = strings.Trim(r[1], qs) + default: + return nil, errDigestBadChallenge + } + } + return c, nil +} + +type credentials struct { + username string + userhash string + realm string + nonce string + digestURI string + algorithm string + sessionAlg bool + cNonce string + opaque string + messageQop string + nc int + method string + password string +} + +func (c *credentials) authorize() (string, error) { + if _, ok := hashFuncs[c.algorithm]; !ok { + return "", errDigestAlgNotSupported + } + + if err := c.validateQop(); err != nil { + return "", err + } + + resp, err := c.resp() + if err != nil { + return "", err + } + + sl := make([]string, 0, 10) + if c.userhash == "true" { + // RFC 7616 3.4.4 + c.username = c.h(fmt.Sprintf("%s:%s", c.username, c.realm)) + sl = append(sl, fmt.Sprintf(`userhash=%s`, c.userhash)) + } + sl = append(sl, fmt.Sprintf(`username="%s"`, c.username)) + sl = append(sl, fmt.Sprintf(`realm="%s"`, c.realm)) + sl = append(sl, fmt.Sprintf(`nonce="%s"`, c.nonce)) + sl = append(sl, fmt.Sprintf(`uri="%s"`, c.digestURI)) + sl = append(sl, fmt.Sprintf(`response="%s"`, resp)) + sl = append(sl, fmt.Sprintf(`algorithm=%s`, c.algorithm)) + if c.opaque != "" { + sl = append(sl, fmt.Sprintf(`opaque="%s"`, c.opaque)) + } + if c.messageQop != "" { + sl = append(sl, fmt.Sprintf("qop=%s", c.messageQop)) + sl = append(sl, fmt.Sprintf("nc=%08x", c.nc)) + sl = append(sl, fmt.Sprintf(`cnonce="%s"`, c.cNonce)) + } + + return fmt.Sprintf("Digest %s", strings.Join(sl, ", ")), nil +} + +func (c *credentials) validateQop() error { + // Currently only supporting auth quality of protection. TODO: add auth-int support + if c.messageQop == "" { + return errDigestNoQop + } + possibleQops := strings.Split(c.messageQop, ", ") + var authSupport bool + for _, qop := range possibleQops { + if qop == "auth" { + authSupport = true + break + } + } + if !authSupport { + return errDigestQopNotSupported + } + + c.messageQop = "auth" + + return nil +} + +func (c *credentials) h(data string) string { + hfCtor := hashFuncs[c.algorithm] + hf := hfCtor() + _, _ = hf.Write([]byte(data)) // Hash.Write never returns an error + return fmt.Sprintf("%x", hf.Sum(nil)) +} + +func (c *credentials) resp() (string, error) { + c.nc++ + + b := make([]byte, 16) + _, err := io.ReadFull(rand.Reader, b) + if err != nil { + return "", err + } + c.cNonce = fmt.Sprintf("%x", b)[:32] + + ha1 := c.ha1() + ha2 := c.ha2() + + return c.kd(ha1, fmt.Sprintf("%s:%08x:%s:%s:%s", + c.nonce, c.nc, c.cNonce, c.messageQop, ha2)), nil +} + +func (c *credentials) kd(secret, data string) string { + return c.h(fmt.Sprintf("%s:%s", secret, data)) +} + +// RFC 7616 3.4.2 +func (c *credentials) ha1() string { + ret := c.h(fmt.Sprintf("%s:%s:%s", c.username, c.realm, c.password)) + if c.sessionAlg { + return c.h(fmt.Sprintf("%s:%s:%s", ret, c.nonce, c.cNonce)) + } + + return ret +} + +// RFC 7616 3.4.3 +func (c *credentials) ha2() string { + // currently no auth-int support + return c.h(fmt.Sprintf("%s:%s", c.method, c.digestURI)) +} diff --git a/internal/header/header.go b/internal/header/header.go index dcf4c56d..b6064558 100644 --- a/internal/header/header.go +++ b/internal/header/header.go @@ -9,4 +9,6 @@ const ( JsonContentType = "application/json; charset=utf-8" XmlContentType = "text/xml; charset=utf-8" FormContentType = "application/x-www-form-urlencoded" + WwwAuthenticate = "WWW-Authenticate" + Authorization = "Authorization" ) diff --git a/middleware.go b/middleware.go index fb76d872..50032e98 100644 --- a/middleware.go +++ b/middleware.go @@ -218,7 +218,10 @@ func parseRequestBody(c *Client, r *Request) (err error) { // handle marshal body if r.marshalBody != nil { - handleMarshalBody(c, r) + err = handleMarshalBody(c, r) + if err != nil { + return + } } if r.Body == nil { diff --git a/request.go b/request.go index 4ef20701..05ef2fc5 100644 --- a/request.go +++ b/request.go @@ -66,6 +66,7 @@ type Request struct { trace *clientTrace dumpBuffer *bytes.Buffer responseReturnTime time.Time + afterResponse []ResponseMiddleware } type GetContentFunc func() (io.ReadCloser, error) @@ -395,12 +396,35 @@ func (r *Request) SetErrorResult(err interface{}) *Request { // SetBearerAuthToken set bearer auth token for the request. func (r *Request) SetBearerAuthToken(token string) *Request { - return r.SetHeader("Authorization", "Bearer "+token) + return r.SetHeader(header.Authorization, "Bearer "+token) } // SetBasicAuth set basic auth for the request. func (r *Request) SetBasicAuth(username, password string) *Request { - return r.SetHeader("Authorization", util.BasicAuthHeaderValue(username, password)) + return r.SetHeader(header.Authorization, util.BasicAuthHeaderValue(username, password)) +} + +// SetDigestAuth sets the Digest Access auth scheme for the HTTP request. If a server responds with 401 and sends a +// Digest challenge in the WWW-Authenticate Header, the request will be resent with the appropriate Authorization Header. +// +// For Example: To set the Digest scheme with username "roc" and password "123456" +// +// client.R().SetDigestAuth("roc", "123456") +// +// Information about Digest Access Authentication can be found in RFC7616: +// +// https://datatracker.ietf.org/doc/html/rfc7616 +// +// This method overrides the username and password set by method `Client.SetCommonDigestAuth`. +func (r *Request) SetDigestAuth(username, password string) *Request { + r.OnAfterResponse(handleDigestAuthFunc(username, password)) + return r +} + +// OnAfterResponse add a response middleware which hooks after response received. +func (r *Request) OnAfterResponse(m ResponseMiddleware) *Request { + r.afterResponse = append(r.afterResponse, m) + return r } // SetHeaders set headers from a map for the request. @@ -563,27 +587,18 @@ func (r *Request) do() (resp *Response, err error) { }() for { - for _, f := range r.client.udBeforeRequest { - if err = f(r.client, r); err != nil { - return - } - } - for _, f := range r.client.beforeRequest { - if err = f(r.client, r); err != nil { - return - } - } - - if r.Headers == nil { - r.Headers = make(http.Header) - } - if r.client.wrappedRoundTrip != nil { resp, err = r.client.wrappedRoundTrip.RoundTrip(r) } else { resp, err = r.client.roundTrip(r) } + for _, f := range r.afterResponse { + if err = f(r.client, resp); err != nil { + return + } + } + if r.retryOption == nil || (r.RetryAttempt >= r.retryOption.MaxRetries && r.retryOption.MaxRetries >= 0) { // absolutely cannot retry. return } diff --git a/retry_test.go b/retry_test.go index 20817438..5814b5fe 100644 --- a/retry_test.go +++ b/retry_test.go @@ -162,7 +162,7 @@ func TestRetryWithModify(t *testing.T) { } func TestRetryFalse(t *testing.T) { - resp, err := tc().R(). + resp, err := tc().SetTimeout(2 * time.Second).R(). SetRetryCount(1). SetRetryCondition(func(resp *Response, err error) bool { return false @@ -173,7 +173,7 @@ func TestRetryFalse(t *testing.T) { } func TestRetryTurnedOffWhenRetryCountEqZero(t *testing.T) { - resp, err := tc().R(). + resp, err := tc().SetTimeout(2 * time.Second).R(). SetRetryCount(0). SetRetryCondition(func(resp *Response, err error) bool { t.Fatal("retry condition should not be executed") @@ -184,7 +184,7 @@ func TestRetryTurnedOffWhenRetryCountEqZero(t *testing.T) { tests.AssertIsNil(t, resp.Response) tests.AssertEqual(t, 0, resp.Request.RetryAttempt) - resp, err = tc(). + resp, err = tc().SetTimeout(2 * time.Second). SetCommonRetryCount(0). SetCommonRetryCondition(func(resp *Response, err error) bool { t.Fatal("retry condition should not be executed") From 80a5a2b29aaf57e8ab2d57e0d2d0deb12d3ffbed Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 17 Jun 2023 10:44:33 +0800 Subject: [PATCH 714/843] add global wrappers --- client_wrapper.go | 283 ++++++++++++++++++++++++++++----------------- request_wrapper.go | 6 + 2 files changed, 184 insertions(+), 105 deletions(-) diff --git a/client_wrapper.go b/client_wrapper.go index 96907b63..d6540450 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -3,6 +3,7 @@ package req import ( "context" "crypto/tls" + utls "github.com/refraction-networking/utls" "io" "net" "net/http" @@ -11,19 +12,19 @@ import ( ) // WrapRoundTrip is a global wrapper methods which delegated -// to the default client's WrapRoundTrip. +// to the default client's Client.WrapRoundTrip. func WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { return defaultClient.WrapRoundTrip(wrappers...) } // WrapRoundTripFunc is a global wrapper methods which delegated -// to the default client's WrapRoundTripFunc. +// to the default client's Client.WrapRoundTripFunc. func WrapRoundTripFunc(funcs ...RoundTripWrapperFunc) *Client { return defaultClient.WrapRoundTripFunc(funcs...) } // SetCommonError is a global wrapper methods which delegated -// to the default client's SetCommonErrorResult. +// to the default client's Client.SetCommonErrorResult. // // Deprecated: Use SetCommonErrorResult instead. func SetCommonError(err interface{}) *Client { @@ -31,613 +32,685 @@ func SetCommonError(err interface{}) *Client { } // SetCommonErrorResult is a global wrapper methods which delegated -// to the default client's SetCommonError. +// to the default client's Client.SetCommonError. func SetCommonErrorResult(err interface{}) *Client { return defaultClient.SetCommonErrorResult(err) } // SetResultStateCheckFunc is a global wrapper methods which delegated -// to the default client's SetCommonResultStateCheckFunc. +// to the default client's Client.SetCommonResultStateCheckFunc. func SetResultStateCheckFunc(fn func(resp *Response) ResultState) *Client { return defaultClient.SetResultStateCheckFunc(fn) } // SetCommonFormDataFromValues is a global wrapper methods which delegated -// to the default client's SetCommonFormDataFromValues. +// to the default client's Client.SetCommonFormDataFromValues. func SetCommonFormDataFromValues(data url.Values) *Client { return defaultClient.SetCommonFormDataFromValues(data) } // SetCommonFormData is a global wrapper methods which delegated -// to the default client's SetCommonFormData. +// to the default client's Client.SetCommonFormData. func SetCommonFormData(data map[string]string) *Client { return defaultClient.SetCommonFormData(data) } // SetBaseURL is a global wrapper methods which delegated -// to the default client's SetBaseURL. +// to the default client's Client.SetBaseURL. func SetBaseURL(u string) *Client { return defaultClient.SetBaseURL(u) } // SetOutputDirectory is a global wrapper methods which delegated -// to the default client's SetOutputDirectory. +// to the default client's Client.SetOutputDirectory. func SetOutputDirectory(dir string) *Client { return defaultClient.SetOutputDirectory(dir) } // SetCertFromFile is a global wrapper methods which delegated -// to the default client's SetCertFromFile. +// to the default client's Client.SetCertFromFile. func SetCertFromFile(certFile, keyFile string) *Client { return defaultClient.SetCertFromFile(certFile, keyFile) } // SetCerts is a global wrapper methods which delegated -// to the default client's SetCerts. +// to the default client's Client.SetCerts. func SetCerts(certs ...tls.Certificate) *Client { return defaultClient.SetCerts(certs...) } // SetRootCertFromString is a global wrapper methods which delegated -// to the default client's SetRootCertFromString. +// to the default client's Client.SetRootCertFromString. func SetRootCertFromString(pemContent string) *Client { return defaultClient.SetRootCertFromString(pemContent) } // SetRootCertsFromFile is a global wrapper methods which delegated -// to the default client's SetRootCertsFromFile. +// to the default client's Client.SetRootCertsFromFile. func SetRootCertsFromFile(pemFiles ...string) *Client { return defaultClient.SetRootCertsFromFile(pemFiles...) } // GetTLSClientConfig is a global wrapper methods which delegated -// to the default client's GetTLSClientConfig. +// to the default client's Client.GetTLSClientConfig. func GetTLSClientConfig() *tls.Config { return defaultClient.GetTLSClientConfig() } // SetRedirectPolicy is a global wrapper methods which delegated -// to the default client's SetRedirectPolicy. +// to the default client's Client.SetRedirectPolicy. func SetRedirectPolicy(policies ...RedirectPolicy) *Client { return defaultClient.SetRedirectPolicy(policies...) } // DisableKeepAlives is a global wrapper methods which delegated -// to the default client's DisableKeepAlives. +// to the default client's Client.DisableKeepAlives. func DisableKeepAlives() *Client { return defaultClient.DisableKeepAlives() } // EnableKeepAlives is a global wrapper methods which delegated -// to the default client's EnableKeepAlives. +// to the default client's Client.EnableKeepAlives. func EnableKeepAlives() *Client { return defaultClient.EnableKeepAlives() } // DisableCompression is a global wrapper methods which delegated -// to the default client's DisableCompression. +// to the default client's Client.DisableCompression. func DisableCompression() *Client { return defaultClient.DisableCompression() } // EnableCompression is a global wrapper methods which delegated -// to the default client's EnableCompression. +// to the default client's Client.EnableCompression. func EnableCompression() *Client { return defaultClient.EnableCompression() } // SetTLSClientConfig is a global wrapper methods which delegated -// to the default client's SetTLSClientConfig. +// to the default client's Client.SetTLSClientConfig. func SetTLSClientConfig(conf *tls.Config) *Client { return defaultClient.SetTLSClientConfig(conf) } // EnableInsecureSkipVerify is a global wrapper methods which delegated -// to the default client's EnableInsecureSkipVerify. +// to the default client's Client.EnableInsecureSkipVerify. func EnableInsecureSkipVerify() *Client { return defaultClient.EnableInsecureSkipVerify() } // DisableInsecureSkipVerify is a global wrapper methods which delegated -// to the default client's DisableInsecureSkipVerify. +// to the default client's Client.DisableInsecureSkipVerify. func DisableInsecureSkipVerify() *Client { return defaultClient.DisableInsecureSkipVerify() } // SetCommonQueryParams is a global wrapper methods which delegated -// to the default client's SetCommonQueryParams. +// to the default client's Client.SetCommonQueryParams. func SetCommonQueryParams(params map[string]string) *Client { return defaultClient.SetCommonQueryParams(params) } // AddCommonQueryParam is a global wrapper methods which delegated -// to the default client's AddCommonQueryParam. +// to the default client's Client.AddCommonQueryParam. func AddCommonQueryParam(key, value string) *Client { return defaultClient.AddCommonQueryParam(key, value) } // AddCommonQueryParams is a global wrapper methods which delegated -// to the default client's AddCommonQueryParams. +// to the default client's Client.AddCommonQueryParams. func AddCommonQueryParams(key string, values ...string) *Client { return defaultClient.AddCommonQueryParams(key, values...) } // SetCommonPathParam is a global wrapper methods which delegated -// to the default client's SetCommonPathParam. +// to the default client's Client.SetCommonPathParam. func SetCommonPathParam(key, value string) *Client { return defaultClient.SetCommonPathParam(key, value) } // SetCommonPathParams is a global wrapper methods which delegated -// to the default client's SetCommonPathParams. +// to the default client's Client.SetCommonPathParams. func SetCommonPathParams(pathParams map[string]string) *Client { return defaultClient.SetCommonPathParams(pathParams) } // SetCommonQueryParam is a global wrapper methods which delegated -// to the default client's SetCommonQueryParam. +// to the default client's Client.SetCommonQueryParam. func SetCommonQueryParam(key, value string) *Client { return defaultClient.SetCommonQueryParam(key, value) } // SetCommonQueryString is a global wrapper methods which delegated -// to the default client's SetCommonQueryString. +// to the default client's Client.SetCommonQueryString. func SetCommonQueryString(query string) *Client { return defaultClient.SetCommonQueryString(query) } // SetCommonCookies is a global wrapper methods which delegated -// to the default client's SetCommonCookies. +// to the default client's Client.SetCommonCookies. func SetCommonCookies(cookies ...*http.Cookie) *Client { return defaultClient.SetCommonCookies(cookies...) } // DisableDebugLog is a global wrapper methods which delegated -// to the default client's DisableDebugLog. +// to the default client's Client.DisableDebugLog. func DisableDebugLog() *Client { return defaultClient.DisableDebugLog() } // EnableDebugLog is a global wrapper methods which delegated -// to the default client's EnableDebugLog. +// to the default client's Client.EnableDebugLog. func EnableDebugLog() *Client { return defaultClient.EnableDebugLog() } // DevMode is a global wrapper methods which delegated -// to the default client's DevMode. +// to the default client's Client.DevMode. func DevMode() *Client { return defaultClient.DevMode() } // SetScheme is a global wrapper methods which delegated -// to the default client's SetScheme. +// to the default client's Client.SetScheme. func SetScheme(scheme string) *Client { return defaultClient.SetScheme(scheme) } // SetLogger is a global wrapper methods which delegated -// to the default client's SetLogger. +// to the default client's Client.SetLogger. func SetLogger(log Logger) *Client { return defaultClient.SetLogger(log) } // SetTimeout is a global wrapper methods which delegated -// to the default client's SetTimeout. +// to the default client's Client.SetTimeout. func SetTimeout(d time.Duration) *Client { return defaultClient.SetTimeout(d) } // EnableDumpAll is a global wrapper methods which delegated -// to the default client's EnableDumpAll. +// to the default client's Client.EnableDumpAll. func EnableDumpAll() *Client { return defaultClient.EnableDumpAll() } // EnableDumpAllToFile is a global wrapper methods which delegated -// to the default client's EnableDumpAllToFile. +// to the default client's Client.EnableDumpAllToFile. func EnableDumpAllToFile(filename string) *Client { return defaultClient.EnableDumpAllToFile(filename) } // EnableDumpAllTo is a global wrapper methods which delegated -// to the default client's EnableDumpAllTo. +// to the default client's Client.EnableDumpAllTo. func EnableDumpAllTo(output io.Writer) *Client { return defaultClient.EnableDumpAllTo(output) } // EnableDumpAllAsync is a global wrapper methods which delegated -// to the default client's EnableDumpAllAsync. +// to the default client's Client.EnableDumpAllAsync. func EnableDumpAllAsync() *Client { return defaultClient.EnableDumpAllAsync() } // EnableDumpAllWithoutRequestBody is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutRequestBody. +// to the default client's Client.EnableDumpAllWithoutRequestBody. func EnableDumpAllWithoutRequestBody() *Client { return defaultClient.EnableDumpAllWithoutRequestBody() } // EnableDumpAllWithoutResponseBody is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutResponseBody. +// to the default client's Client.EnableDumpAllWithoutResponseBody. func EnableDumpAllWithoutResponseBody() *Client { return defaultClient.EnableDumpAllWithoutResponseBody() } // EnableDumpAllWithoutResponse is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutResponse. +// to the default client's Client.EnableDumpAllWithoutResponse. func EnableDumpAllWithoutResponse() *Client { return defaultClient.EnableDumpAllWithoutResponse() } // EnableDumpAllWithoutRequest is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutRequest. +// to the default client's Client.EnableDumpAllWithoutRequest. func EnableDumpAllWithoutRequest() *Client { return defaultClient.EnableDumpAllWithoutRequest() } // EnableDumpAllWithoutHeader is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutHeader. +// to the default client's Client.EnableDumpAllWithoutHeader. func EnableDumpAllWithoutHeader() *Client { return defaultClient.EnableDumpAllWithoutHeader() } // EnableDumpAllWithoutBody is a global wrapper methods which delegated -// to the default client's EnableDumpAllWithoutBody. +// to the default client's Client.EnableDumpAllWithoutBody. func EnableDumpAllWithoutBody() *Client { return defaultClient.EnableDumpAllWithoutBody() } // EnableDumpEachRequest is a global wrapper methods which delegated -// to the default client's EnableDumpEachRequest. +// to the default client's Client.EnableDumpEachRequest. func EnableDumpEachRequest() *Client { return defaultClient.EnableDumpEachRequest() } // EnableDumpEachRequestWithoutBody is a global wrapper methods which delegated -// to the default client's EnableDumpEachRequestWithoutBody. +// to the default client's Client.EnableDumpEachRequestWithoutBody. func EnableDumpEachRequestWithoutBody() *Client { return defaultClient.EnableDumpEachRequestWithoutBody() } // EnableDumpEachRequestWithoutHeader is a global wrapper methods which delegated -// to the default client's EnableDumpEachRequestWithoutHeader. +// to the default client's Client.EnableDumpEachRequestWithoutHeader. func EnableDumpEachRequestWithoutHeader() *Client { return defaultClient.EnableDumpEachRequestWithoutHeader() } // EnableDumpEachRequestWithoutResponse is a global wrapper methods which delegated -// to the default client's EnableDumpEachRequestWithoutResponse. +// to the default client's Client.EnableDumpEachRequestWithoutResponse. func EnableDumpEachRequestWithoutResponse() *Client { return defaultClient.EnableDumpEachRequestWithoutResponse() } // EnableDumpEachRequestWithoutRequest is a global wrapper methods which delegated -// to the default client's EnableDumpEachRequestWithoutRequest. +// to the default client's Client.EnableDumpEachRequestWithoutRequest. func EnableDumpEachRequestWithoutRequest() *Client { return defaultClient.EnableDumpEachRequestWithoutRequest() } // EnableDumpEachRequestWithoutResponseBody is a global wrapper methods which delegated -// to the default client's EnableDumpEachRequestWithoutResponseBody. +// to the default client's Client.EnableDumpEachRequestWithoutResponseBody. func EnableDumpEachRequestWithoutResponseBody() *Client { return defaultClient.EnableDumpEachRequestWithoutResponseBody() } // EnableDumpEachRequestWithoutRequestBody is a global wrapper methods which delegated -// to the default client's EnableDumpEachRequestWithoutRequestBody. +// to the default client's Client.EnableDumpEachRequestWithoutRequestBody. func EnableDumpEachRequestWithoutRequestBody() *Client { return defaultClient.EnableDumpEachRequestWithoutRequestBody() } // DisableAutoReadResponse is a global wrapper methods which delegated -// to the default client's DisableAutoReadResponse. +// to the default client's Client.DisableAutoReadResponse. func DisableAutoReadResponse() *Client { return defaultClient.DisableAutoReadResponse() } // EnableAutoReadResponse is a global wrapper methods which delegated -// to the default client's EnableAutoReadResponse. +// to the default client's Client.EnableAutoReadResponse. func EnableAutoReadResponse() *Client { return defaultClient.EnableAutoReadResponse() } // SetAutoDecodeContentType is a global wrapper methods which delegated -// to the default client's SetAutoDecodeContentType. +// to the default client's Client.SetAutoDecodeContentType. func SetAutoDecodeContentType(contentTypes ...string) *Client { return defaultClient.SetAutoDecodeContentType(contentTypes...) } // SetAutoDecodeContentTypeFunc is a global wrapper methods which delegated -// to the default client's SetAutoDecodeAllTypeFunc. +// to the default client's Client.SetAutoDecodeAllTypeFunc. func SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client { return defaultClient.SetAutoDecodeContentTypeFunc(fn) } // SetAutoDecodeAllContentType is a global wrapper methods which delegated -// to the default client's SetAutoDecodeAllContentType. +// to the default client's Client.SetAutoDecodeAllContentType. func SetAutoDecodeAllContentType() *Client { return defaultClient.SetAutoDecodeAllContentType() } // DisableAutoDecode is a global wrapper methods which delegated -// to the default client's DisableAutoDecode. +// to the default client's Client.DisableAutoDecode. func DisableAutoDecode() *Client { return defaultClient.DisableAutoDecode() } // EnableAutoDecode is a global wrapper methods which delegated -// to the default client's EnableAutoDecode. +// to the default client's Client.EnableAutoDecode. func EnableAutoDecode() *Client { return defaultClient.EnableAutoDecode() } // SetUserAgent is a global wrapper methods which delegated -// to the default client's SetUserAgent. +// to the default client's Client.SetUserAgent. func SetUserAgent(userAgent string) *Client { return defaultClient.SetUserAgent(userAgent) } // SetCommonBearerAuthToken is a global wrapper methods which delegated -// to the default client's SetCommonBearerAuthToken. +// to the default client's Client.SetCommonBearerAuthToken. func SetCommonBearerAuthToken(token string) *Client { return defaultClient.SetCommonBearerAuthToken(token) } // SetCommonBasicAuth is a global wrapper methods which delegated -// to the default client's SetCommonBasicAuth. +// to the default client's Client.SetCommonBasicAuth. func SetCommonBasicAuth(username, password string) *Client { return defaultClient.SetCommonBasicAuth(username, password) } +// SetCommonDigestAuth is a global wrapper methods which delegated +// to the default client's Client.SetCommonDigestAuth. +func SetCommonDigestAuth(username, password string) *Client { + return defaultClient.SetCommonDigestAuth(username, password) +} + // SetCommonHeaders is a global wrapper methods which delegated -// to the default client's SetCommonHeaders. +// to the default client's Client.SetCommonHeaders. func SetCommonHeaders(hdrs map[string]string) *Client { return defaultClient.SetCommonHeaders(hdrs) } // SetCommonHeader is a global wrapper methods which delegated -// to the default client's SetCommonHeader. +// to the default client's Client.SetCommonHeader. func SetCommonHeader(key, value string) *Client { return defaultClient.SetCommonHeader(key, value) } // SetCommonContentType is a global wrapper methods which delegated -// to the default client's SetCommonContentType. +// to the default client's Client.SetCommonContentType. func SetCommonContentType(ct string) *Client { return defaultClient.SetCommonContentType(ct) } // DisableDumpAll is a global wrapper methods which delegated -// to the default client's DisableDumpAll. +// to the default client's Client.DisableDumpAll. func DisableDumpAll() *Client { return defaultClient.DisableDumpAll() } // SetCommonDumpOptions is a global wrapper methods which delegated -// to the default client's SetCommonDumpOptions. +// to the default client's Client.SetCommonDumpOptions. func SetCommonDumpOptions(opt *DumpOptions) *Client { return defaultClient.SetCommonDumpOptions(opt) } // SetProxy is a global wrapper methods which delegated -// to the default client's SetProxy. +// to the default client's Client.SetProxy. func SetProxy(proxy func(*http.Request) (*url.URL, error)) *Client { return defaultClient.SetProxy(proxy) } // OnBeforeRequest is a global wrapper methods which delegated -// to the default client's OnBeforeRequest. +// to the default client's Client.OnBeforeRequest. func OnBeforeRequest(m RequestMiddleware) *Client { return defaultClient.OnBeforeRequest(m) } // OnAfterResponse is a global wrapper methods which delegated -// to the default client's OnAfterResponse. +// to the default client's Client.OnAfterResponse. func OnAfterResponse(m ResponseMiddleware) *Client { return defaultClient.OnAfterResponse(m) } // SetProxyURL is a global wrapper methods which delegated -// to the default client's SetProxyURL. +// to the default client's Client.SetProxyURL. func SetProxyURL(proxyUrl string) *Client { return defaultClient.SetProxyURL(proxyUrl) } // DisableTraceAll is a global wrapper methods which delegated -// to the default client's DisableTraceAll. +// to the default client's Client.DisableTraceAll. func DisableTraceAll() *Client { return defaultClient.DisableTraceAll() } // EnableTraceAll is a global wrapper methods which delegated -// to the default client's EnableTraceAll. +// to the default client's Client.EnableTraceAll. func EnableTraceAll() *Client { return defaultClient.EnableTraceAll() } // SetCookieJar is a global wrapper methods which delegated -// to the default client's SetCookieJar. +// to the default client's Client.SetCookieJar. func SetCookieJar(jar http.CookieJar) *Client { return defaultClient.SetCookieJar(jar) } +// GetCookies is a global wrapper methods which delegated +// to the default client's Client.GetCookies. +func GetCookies(url string) ([]*http.Cookie, error) { + return defaultClient.GetCookies(url) +} + // ClearCookies is a global wrapper methods which delegated -// to the default client's ClearCookies. +// to the default client's Client.ClearCookies. func ClearCookies() *Client { return defaultClient.ClearCookies() } // SetJsonMarshal is a global wrapper methods which delegated -// to the default client's SetJsonMarshal. +// to the default client's Client.SetJsonMarshal. func SetJsonMarshal(fn func(v interface{}) ([]byte, error)) *Client { return defaultClient.SetJsonMarshal(fn) } // SetJsonUnmarshal is a global wrapper methods which delegated -// to the default client's SetJsonUnmarshal. +// to the default client's Client.SetJsonUnmarshal. func SetJsonUnmarshal(fn func(data []byte, v interface{}) error) *Client { return defaultClient.SetJsonUnmarshal(fn) } // SetXmlMarshal is a global wrapper methods which delegated -// to the default client's SetXmlMarshal. +// to the default client's Client.SetXmlMarshal. func SetXmlMarshal(fn func(v interface{}) ([]byte, error)) *Client { return defaultClient.SetXmlMarshal(fn) } // SetXmlUnmarshal is a global wrapper methods which delegated -// to the default client's SetXmlUnmarshal. +// to the default client's Client.SetXmlUnmarshal. func SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Client { return defaultClient.SetXmlUnmarshal(fn) } // SetDialTLS is a global wrapper methods which delegated -// to the default client's SetDialTLS. +// to the default client's Client.SetDialTLS. func SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { return defaultClient.SetDialTLS(fn) } // SetDial is a global wrapper methods which delegated -// to the default client's SetDial. +// to the default client's Client.SetDial. func SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { return defaultClient.SetDial(fn) } // SetTLSHandshakeTimeout is a global wrapper methods which delegated -// to the default client's SetTLSHandshakeTimeout. +// to the default client's Client.SetTLSHandshakeTimeout. func SetTLSHandshakeTimeout(timeout time.Duration) *Client { return defaultClient.SetTLSHandshakeTimeout(timeout) } // EnableForceHTTP1 is a global wrapper methods which delegated -// to the default client's EnableForceHTTP1. +// to the default client's Client.EnableForceHTTP1. func EnableForceHTTP1() *Client { return defaultClient.EnableForceHTTP1() } // EnableForceHTTP2 is a global wrapper methods which delegated -// to the default client's EnableForceHTTP2. +// to the default client's Client.EnableForceHTTP2. func EnableForceHTTP2() *Client { return defaultClient.EnableForceHTTP2() } // EnableForceHTTP3 is a global wrapper methods which delegated -// to the default client's EnableForceHTTP3. +// to the default client's Client.EnableForceHTTP3. func EnableForceHTTP3() *Client { return defaultClient.EnableForceHTTP3() } // EnableHTTP3 is a global wrapper methods which delegated -// to the default client's EnableHTTP3. +// to the default client's Client.EnableHTTP3. func EnableHTTP3() *Client { return defaultClient.EnableHTTP3() } // DisableForceHttpVersion is a global wrapper methods which delegated -// to the default client's DisableForceHttpVersion. +// to the default client's Client.DisableForceHttpVersion. func DisableForceHttpVersion() *Client { return defaultClient.DisableForceHttpVersion() } // EnableH2C is a global wrapper methods which delegated -// to the default client's EnableH2C. +// to the default client's Client.EnableH2C. func EnableH2C() *Client { return defaultClient.EnableH2C() } // DisableH2C is a global wrapper methods which delegated -// to the default client's DisableH2C. +// to the default client's Client.DisableH2C. func DisableH2C() *Client { return defaultClient.DisableH2C() } // DisableAllowGetMethodPayload is a global wrapper methods which delegated -// to the default client's DisableAllowGetMethodPayload. +// to the default client's Client.DisableAllowGetMethodPayload. func DisableAllowGetMethodPayload() *Client { return defaultClient.DisableAllowGetMethodPayload() } // EnableAllowGetMethodPayload is a global wrapper methods which delegated -// to the default client's EnableAllowGetMethodPayload. +// to the default client's Client.EnableAllowGetMethodPayload. func EnableAllowGetMethodPayload() *Client { return defaultClient.EnableAllowGetMethodPayload() } // SetCommonRetryCount is a global wrapper methods which delegated -// to the default client, create a request and SetCommonRetryCount for request. +// to the default client's Client.SetCommonRetryCount. func SetCommonRetryCount(count int) *Client { return defaultClient.SetCommonRetryCount(count) } // SetCommonRetryInterval is a global wrapper methods which delegated -// to the default client, create a request and SetCommonRetryInterval for request. +// to the default client's Client.SetCommonRetryInterval. func SetCommonRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Client { return defaultClient.SetCommonRetryInterval(getRetryIntervalFunc) } // SetCommonRetryFixedInterval is a global wrapper methods which delegated -// to the default client, create a request and SetCommonRetryFixedInterval for request. +// to the default client's Client.SetCommonRetryFixedInterval. func SetCommonRetryFixedInterval(interval time.Duration) *Client { return defaultClient.SetCommonRetryFixedInterval(interval) } // SetCommonRetryBackoffInterval is a global wrapper methods which delegated -// to the default client, create a request and SetCommonRetryBackoffInterval for request. +// to the default client's Client.SetCommonRetryBackoffInterval. func SetCommonRetryBackoffInterval(min, max time.Duration) *Client { return defaultClient.SetCommonRetryBackoffInterval(min, max) } // SetCommonRetryHook is a global wrapper methods which delegated -// to the default client, create a request and SetRetryHook for request. +// to the default client's Client.SetCommonRetryHook. func SetCommonRetryHook(hook RetryHookFunc) *Client { return defaultClient.SetCommonRetryHook(hook) } // AddCommonRetryHook is a global wrapper methods which delegated -// to the default client, create a request and AddCommonRetryHook for request. +// to the default client's Client.AddCommonRetryHook. func AddCommonRetryHook(hook RetryHookFunc) *Client { return defaultClient.AddCommonRetryHook(hook) } // SetCommonRetryCondition is a global wrapper methods which delegated -// to the default client, create a request and SetCommonRetryCondition for request. +// to the default client's Client.SetCommonRetryCondition. func SetCommonRetryCondition(condition RetryConditionFunc) *Client { return defaultClient.SetCommonRetryCondition(condition) } // AddCommonRetryCondition is a global wrapper methods which delegated -// to the default client, create a request and AddCommonRetryCondition for request. +// to the default client's Client.AddCommonRetryCondition. func AddCommonRetryCondition(condition RetryConditionFunc) *Client { return defaultClient.AddCommonRetryCondition(condition) } // SetResponseBodyTransformer is a global wrapper methods which delegated -// to the default client, create a request and SetResponseBodyTransformer for request. +// to the default client's Client.SetResponseBodyTransformer. func SetResponseBodyTransformer(fn func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error)) *Client { return defaultClient.SetResponseBodyTransformer(fn) } // SetUnixSocket is a global wrapper methods which delegated -// to the default client, create a request and SetUnixSocket for request. +// to the default client's Client.SetUnixSocket. func SetUnixSocket(file string) *Client { return defaultClient.SetUnixSocket(file) } +// SetTLSFingerprint is a global wrapper methods which delegated +// to the default client's Client.SetTLSFingerprint. +func SetTLSFingerprint(clientHelloID utls.ClientHelloID) *Client { + return defaultClient.SetTLSFingerprint(clientHelloID) +} + +// SetTLSFingerprintRandomized is a global wrapper methods which delegated +// to the default client's Client.SetTLSFingerprintRandomized. +func SetTLSFingerprintRandomized() *Client { + return defaultClient.SetTLSFingerprintRandomized() +} + +// SetTLSFingerprintChrome is a global wrapper methods which delegated +// to the default client's Client.SetTLSFingerprintChrome. +func SetTLSFingerprintChrome() *Client { + return defaultClient.SetTLSFingerprintChrome() +} + +// SetTLSFingerprintAndroid is a global wrapper methods which delegated +// to the default client's Client.SetTLSFingerprintAndroid. +func SetTLSFingerprintAndroid() *Client { + return defaultClient.SetTLSFingerprintAndroid() +} + +// SetTLSFingerprint360 is a global wrapper methods which delegated +// to the default client's Client.SetTLSFingerprint360. +func SetTLSFingerprint360() *Client { + return defaultClient.SetTLSFingerprint360() +} + +// SetTLSFingerprintEdge is a global wrapper methods which delegated +// to the default client's Client.SetTLSFingerprintEdge. +func SetTLSFingerprintEdge() *Client { + return defaultClient.SetTLSFingerprintEdge() +} + +// SetTLSFingerprintFirefox is a global wrapper methods which delegated +// to the default client's Client.SetTLSFingerprintFirefox. +func SetTLSFingerprintFirefox() *Client { + return defaultClient.SetTLSFingerprintFirefox() +} + +// SetTLSFingerprintQQ is a global wrapper methods which delegated +// to the default client's Client.SetTLSFingerprintQQ. +func SetTLSFingerprintQQ() *Client { + return defaultClient.SetTLSFingerprintQQ() +} + +// SetTLSFingerprintIOS is a global wrapper methods which delegated +// to the default client's Client.SetTLSFingerprintIOS. +func SetTLSFingerprintIOS() *Client { + return defaultClient.SetTLSFingerprintIOS() +} + +// SetTLSFingerprintSafari is a global wrapper methods which delegated +// to the default client's Client.SetTLSFingerprintSafari. +func SetTLSFingerprintSafari() *Client { + return defaultClient.SetTLSFingerprintSafari() +} + // GetClient is a global wrapper methods which delegated -// to the default client's GetClient. +// to the default client's Client.GetClient. func GetClient() *http.Client { return defaultClient.GetClient() } // NewRequest is a global wrapper methods which delegated -// to the default client's NewRequest. +// to the default client's Client.NewRequest. func NewRequest() *Request { return defaultClient.R() } // R is a global wrapper methods which delegated -// to the default client's R(). +// to the default client's Client.R(). func R() *Request { return defaultClient.R() } diff --git a/request_wrapper.go b/request_wrapper.go index daa1667a..050eaad0 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -114,6 +114,12 @@ func SetBasicAuth(username, password string) *Request { return defaultClient.R().SetBasicAuth(username, password) } +// SetDigestAuth is a global wrapper methods which delegated +// to the default client, create a request and SetDigestAuth for request. +func SetDigestAuth(username, password string) *Request { + return defaultClient.R().SetDigestAuth(username, password) +} + // SetHeaders is a global wrapper methods which delegated // to the default client, create a request and SetHeaders for request. func SetHeaders(hdrs map[string]string) *Request { From f6d7e2aef5ce82ea4745c88f6b13286d6d9ab1a9 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 17 Jun 2023 11:39:50 +0800 Subject: [PATCH 715/843] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 1d5b6c87..bc2a3531 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ Full documentation is available on the official website: https://req.cool. * **Support Multiple HTTP Versions**: Support `HTTP/1.1`, `HTTP/2`, and `HTTP/3`, and can automatically detect the server side and select the optimal HTTP version for requests, you can also force the protocol if you want (See [Force HTTP version](https://req.cool/docs/tutorial/force-http-version/)). * **Support Retry**: Support automatic request retry and is fully customizable (See [Retry](https://req.cool/docs/tutorial/retry/)). * **TLS Fingerprinting**: Support tls fingerprinting resistance, so that we can access websites that prohibit crawler programs by identifying TLS handshake fingerprints (See [TLS Fingerprinting](https://req.cool/docs/tutorial/tls-fingerprinting/)). +* **Multiple Authentication Methods**: You can use HTTP Basic Auth, Bearer Auth Token and Digest Auth out of box (see [Authentication](https://req.cool/docs/tutorial/authentication/)). * **Easy Download and Upload**: You can download and upload files with simple request settings, and even set a callback to show real-time progress (See [Download](https://req.cool/docs/tutorial/download/) and [Upload](https://req.cool/docs/tutorial/upload/)). * **Exportable**: `req.Transport` is exportable. Compared with `http.Transport`, it also supports HTTP3, dump content, middleware, etc. It can directly replace the Transport of `http.Client` in existing projects, and obtain more powerful functions with minimal code change. * **Extensible**: Support Middleware for Request, Response, Client and Transport (See [Request and Response Middleware](https://req.cool/docs/tutorial/middleware-for-request-and-response/)) and [Client and Transport Middleware](https://req.cool/docs/tutorial/middleware-for-client-and-transport/)). From 45ace4b352928971318ca82c7282aa8366e2b36e Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 22 Jun 2023 23:42:17 +0800 Subject: [PATCH 716/843] ensure err in client.roundTrip --- client.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 96793361..f95fe32f 100644 --- a/client.go +++ b/client.go @@ -1509,7 +1509,11 @@ func (c *Client) WrapRoundTrip(wrappers ...RoundTripWrapper) *Client { func (c *Client) roundTrip(r *Request) (resp *Response, err error) { resp = &Response{Request: r} defer func() { - err = resp.Err + if err != nil { + resp.Err = err + } else { + err = resp.Err + } }() if r.Headers == nil { From d9fef28c7800369a1f2c10a7d6d87cf11360a031 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 28 Jun 2023 10:26:34 +0800 Subject: [PATCH 717/843] make sure beforeRequests executed before client middleware(fix #248) --- client.go | 15 --------------- request.go | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index f95fe32f..b6316b86 100644 --- a/client.go +++ b/client.go @@ -1516,21 +1516,6 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { } }() - if r.Headers == nil { - r.Headers = make(http.Header) - } - - for _, f := range r.client.udBeforeRequest { - if err = f(r.client, r); err != nil { - return - } - } - for _, f := range r.client.beforeRequest { - if err = f(r.client, r); err != nil { - return - } - } - // setup trace if r.trace == nil && r.client.trace { r.trace = &clientTrace{} diff --git a/request.go b/request.go index 05ef2fc5..63df395d 100644 --- a/request.go +++ b/request.go @@ -587,6 +587,20 @@ func (r *Request) do() (resp *Response, err error) { }() for { + if r.Headers == nil { + r.Headers = make(http.Header) + } + for _, f := range r.client.udBeforeRequest { + if err = f(r.client, r); err != nil { + return + } + } + for _, f := range r.client.beforeRequest { + if err = f(r.client, r); err != nil { + return + } + } + if r.client.wrappedRoundTrip != nil { resp, err = r.client.wrappedRoundTrip.RoundTrip(r) } else { From 4d556bbafbc12c3b7c838f954bc8330c2829c0d4 Mon Sep 17 00:00:00 2001 From: guoguangwu Date: Wed, 12 Jul 2023 10:04:01 +0800 Subject: [PATCH 718/843] chore: remove refs to deprecated io/ioutil --- internal/http2/pipe_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/http2/pipe_test.go b/internal/http2/pipe_test.go index 391ab3bf..326b94de 100644 --- a/internal/http2/pipe_test.go +++ b/internal/http2/pipe_test.go @@ -8,7 +8,6 @@ import ( "bytes" "errors" "io" - "io/ioutil" "testing" ) @@ -112,7 +111,7 @@ func TestPipeBreakWithError(t *testing.T) { io.WriteString(p, "foo") a := errors.New("test err") p.BreakWithError(a) - all, err := ioutil.ReadAll(p) + all, err := io.ReadAll(p) if string(all) != "" { t.Errorf("read bytes = %q; want empty string", all) } From 9355c9b19cb6fe24aa3e6c1f909c8ce158eb4edb Mon Sep 17 00:00:00 2001 From: guoguangwu Date: Wed, 12 Jul 2023 10:18:46 +0800 Subject: [PATCH 719/843] chore: slice replace loop --- middleware.go | 4 +--- retry.go | 8 ++------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/middleware.go b/middleware.go index 50032e98..94c6ad62 100644 --- a/middleware.go +++ b/middleware.go @@ -413,9 +413,7 @@ func parseRequestCookie(c *Client, r *Request) error { if len(c.Cookies) == 0 { return nil } - for _, ck := range c.Cookies { - r.Cookies = append(r.Cookies, ck) - } + r.Cookies = append(r.Cookies, c.Cookies...) return nil } diff --git a/retry.go b/retry.go index b79e0323..fa67c843 100644 --- a/retry.go +++ b/retry.go @@ -53,11 +53,7 @@ func (ro *retryOption) Clone() *retryOption { MaxRetries: ro.MaxRetries, GetRetryInterval: ro.GetRetryInterval, } - for _, c := range ro.RetryConditions { - o.RetryConditions = append(o.RetryConditions, c) - } - for _, h := range ro.RetryHooks { - o.RetryHooks = append(o.RetryHooks, h) - } + o.RetryConditions = append(o.RetryConditions, ro.RetryConditions...) + o.RetryHooks = append(o.RetryHooks, ro.RetryHooks...) return o } From 88d2db669d92864de26610ec5cc4e40f3d7dc552 Mon Sep 17 00:00:00 2001 From: guoguangwu Date: Wed, 12 Jul 2023 11:16:32 +0800 Subject: [PATCH 720/843] chore: use bytes.Equal instead bytes.Compare --- internal/http2/transport_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go index d587c1ae..5acecd3d 100644 --- a/internal/http2/transport_test.go +++ b/internal/http2/transport_test.go @@ -4757,7 +4757,7 @@ func testTransportBodyReadError(t *testing.T, body []byte) { // If the client's done, it // will have reported any // errors on its side. - if bytes.Compare(receivedBody, body) != 0 { + if !bytes.Equal(receivedBody, body) { return fmt.Errorf("body: %q; expected %q", receivedBody, body) } if resetCount != 1 { From 2ad0520cd970532c41e170e1c33216f949629fc6 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 28 Jul 2023 17:32:50 +0800 Subject: [PATCH 721/843] support quic-go v0.37.0 --- examples/find-popular-repo/go.mod | 14 +- examples/find-popular-repo/go.sum | 9 ++ examples/opentelemetry-jaeger-tracing/go.mod | 14 +- examples/opentelemetry-jaeger-tracing/go.sum | 9 ++ examples/upload/uploadclient/go.mod | 14 +- examples/upload/uploadclient/go.sum | 9 ++ examples/uploadcallback/uploadclient/go.mod | 14 +- examples/uploadcallback/uploadclient/go.sum | 9 ++ go.mod | 33 +++-- go.sum | 30 +++++ internal/http3/client.go | 24 +--- internal/http3/headers.go | 129 +++++++++++++++++++ internal/http3/request.go | 111 ---------------- internal/http3/request_writer.go | 4 + 14 files changed, 248 insertions(+), 175 deletions(-) create mode 100644 internal/http3/headers.go delete mode 100644 internal/http3/request.go diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index cc76a320..be0ddc0c 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -9,7 +9,7 @@ require github.com/imroc/req/v3 v3.0.0 require ( github.com/cheekybits/genny v1.0.0 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect - github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/marten-seemann/qpack v0.2.1 // indirect @@ -19,11 +19,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.1 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect - golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/net v0.0.0-20220809012201-f428fae20770 // indirect - golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 // indirect - golang.org/x/text v0.3.7 // indirect - golang.org/x/tools v0.1.12 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/mod v0.12.0 // indirect + golang.org/x/net v0.12.0 // indirect + golang.org/x/sys v0.10.0 // indirect + golang.org/x/text v0.11.0 // indirect + golang.org/x/tools v0.11.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index bad63e63..39bbba8a 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -31,6 +31,7 @@ github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aev github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= @@ -148,6 +149,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= @@ -166,6 +168,7 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -174,6 +177,7 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -194,6 +198,7 @@ golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -230,6 +235,7 @@ golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -238,6 +244,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -252,6 +259,7 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -292,6 +300,7 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/examples/opentelemetry-jaeger-tracing/go.mod b/examples/opentelemetry-jaeger-tracing/go.mod index 888ea567..bc0c021e 100644 --- a/examples/opentelemetry-jaeger-tracing/go.mod +++ b/examples/opentelemetry-jaeger-tracing/go.mod @@ -17,7 +17,7 @@ require ( github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/go-logr/logr v1.2.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/lucas-clemente/quic-go v0.28.1 // indirect @@ -28,11 +28,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect - golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/net v0.0.0-20220809012201-f428fae20770 // indirect - golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 // indirect - golang.org/x/text v0.3.7 // indirect - golang.org/x/tools v0.1.12 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/mod v0.12.0 // indirect + golang.org/x/net v0.12.0 // indirect + golang.org/x/sys v0.10.0 // indirect + golang.org/x/text v0.11.0 // indirect + golang.org/x/tools v0.11.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/opentelemetry-jaeger-tracing/go.sum b/examples/opentelemetry-jaeger-tracing/go.sum index 676e2148..6c63d5db 100644 --- a/examples/opentelemetry-jaeger-tracing/go.sum +++ b/examples/opentelemetry-jaeger-tracing/go.sum @@ -35,6 +35,7 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= @@ -151,6 +152,7 @@ github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= @@ -177,6 +179,7 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -185,6 +188,7 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -208,6 +212,7 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -248,6 +253,7 @@ golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsi golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -257,6 +263,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -270,6 +277,7 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -310,6 +318,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod index cfc07949..eb2cabe0 100644 --- a/examples/upload/uploadclient/go.mod +++ b/examples/upload/uploadclient/go.mod @@ -9,7 +9,7 @@ require github.com/imroc/req/v3 v3.0.0 require ( github.com/cheekybits/genny v1.0.0 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect - github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/lucas-clemente/quic-go v0.28.1 // indirect @@ -20,11 +20,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect - golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/net v0.0.0-20220809012201-f428fae20770 // indirect - golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 // indirect - golang.org/x/text v0.3.7 // indirect - golang.org/x/tools v0.1.12 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/mod v0.12.0 // indirect + golang.org/x/net v0.12.0 // indirect + golang.org/x/sys v0.10.0 // indirect + golang.org/x/text v0.11.0 // indirect + golang.org/x/tools v0.11.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/upload/uploadclient/go.sum b/examples/upload/uploadclient/go.sum index f5d21907..c694e19d 100644 --- a/examples/upload/uploadclient/go.sum +++ b/examples/upload/uploadclient/go.sum @@ -30,6 +30,7 @@ github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aev github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= @@ -147,6 +148,7 @@ github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= @@ -165,6 +167,7 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -173,6 +176,7 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -198,6 +202,7 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -237,6 +242,7 @@ golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -246,6 +252,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -259,6 +266,7 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -299,6 +307,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/examples/uploadcallback/uploadclient/go.mod b/examples/uploadcallback/uploadclient/go.mod index cfc07949..eb2cabe0 100644 --- a/examples/uploadcallback/uploadclient/go.mod +++ b/examples/uploadcallback/uploadclient/go.mod @@ -9,7 +9,7 @@ require github.com/imroc/req/v3 v3.0.0 require ( github.com/cheekybits/genny v1.0.0 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect - github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/lucas-clemente/quic-go v0.28.1 // indirect @@ -20,11 +20,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect - golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/net v0.0.0-20220809012201-f428fae20770 // indirect - golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 // indirect - golang.org/x/text v0.3.7 // indirect - golang.org/x/tools v0.1.12 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/mod v0.12.0 // indirect + golang.org/x/net v0.12.0 // indirect + golang.org/x/sys v0.10.0 // indirect + golang.org/x/text v0.11.0 // indirect + golang.org/x/tools v0.11.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/uploadcallback/uploadclient/go.sum b/examples/uploadcallback/uploadclient/go.sum index f5d21907..c694e19d 100644 --- a/examples/uploadcallback/uploadclient/go.sum +++ b/examples/uploadcallback/uploadclient/go.sum @@ -30,6 +30,7 @@ github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aev github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= @@ -147,6 +148,7 @@ github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= @@ -165,6 +167,7 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -173,6 +176,7 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -198,6 +202,7 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -237,6 +242,7 @@ golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -246,6 +252,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -259,6 +266,7 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -299,6 +307,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/go.mod b/go.mod index b1f7e59d..9471aad8 100644 --- a/go.mod +++ b/go.mod @@ -1,30 +1,29 @@ module github.com/imroc/req/v3 -go 1.19 +go 1.20 require ( github.com/hashicorp/go-multierror v1.1.1 github.com/quic-go/qpack v0.4.0 - github.com/quic-go/quic-go v0.35.1 - github.com/refraction-networking/utls v1.3.2 - golang.org/x/net v0.11.0 - golang.org/x/text v0.10.0 + github.com/quic-go/quic-go v0.37.0 + github.com/refraction-networking/utls v1.3.3 + golang.org/x/net v0.12.0 + golang.org/x/text v0.11.0 ) require ( - github.com/andybalholm/brotli v1.0.4 // indirect - github.com/gaukas/godicttls v0.0.3 // indirect + github.com/andybalholm/brotli v1.0.5 // indirect + github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/mock v1.6.0 // indirect - github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect + github.com/google/pprof v0.0.0-20230705174524-200ffdc848b8 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/klauspost/compress v1.15.15 // indirect - github.com/onsi/ginkgo/v2 v2.10.0 // indirect - github.com/quic-go/qtls-go1-19 v0.3.2 // indirect - github.com/quic-go/qtls-go1-20 v0.2.2 // indirect - golang.org/x/crypto v0.10.0 // indirect - golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect - golang.org/x/mod v0.11.0 // indirect - golang.org/x/sys v0.9.0 // indirect - golang.org/x/tools v0.10.0 // indirect + github.com/klauspost/compress v1.16.7 // indirect + github.com/onsi/ginkgo/v2 v2.11.0 // indirect + github.com/quic-go/qtls-go1-20 v0.3.0 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 // indirect + golang.org/x/mod v0.12.0 // indirect + golang.org/x/sys v0.10.0 // indirect + golang.org/x/tools v0.11.0 // indirect ) diff --git a/go.sum b/go.sum index 8b4b60f3..09460786 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,14 @@ github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= +github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk= github.com/gaukas/godicttls v0.0.3/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= +github.com/gaukas/godicttls v0.0.4 h1:NlRaXb3J6hAnTmWdsEKb9bcSBD6BvcIjdGdeb0zfXbk= +github.com/gaukas/godicttls v0.0.4/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= @@ -16,6 +20,8 @@ github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9 github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs= github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= +github.com/google/pprof v0.0.0-20230705174524-200ffdc848b8 h1:n6vlPhxsA+BW/XsS5+uqi7GyzaLa5MH7qlSLBZtRdiA= +github.com/google/pprof v0.0.0-20230705174524-200ffdc848b8/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -23,10 +29,14 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= +github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= +github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs= github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE= +github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU= +github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -36,12 +46,18 @@ github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc8 github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= +github.com/quic-go/qtls-go1-20 v0.3.0 h1:NrCXmDl8BddZwO67vlvEpBTwT89bJfKYygxv4HQvuDk= +github.com/quic-go/qtls-go1-20 v0.3.0/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/quic-go/quic-go v0.35.0 h1:JXIf219xJK+4qGeY52rlnrVqeB2AXUAwfLU9JSoWXwg= github.com/quic-go/quic-go v0.35.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo= github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= +github.com/quic-go/quic-go v0.37.0 h1:wf/Ym2yeWi98oQn4ahiBSqdnaXVxNQGj2oBQFgiVChc= +github.com/quic-go/quic-go v0.37.0/go.mod h1:XtCUOCALTTWbPyd0IxFfHf6h0sEMubRFvEYHl3QxKw8= github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8= github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E= +github.com/refraction-networking/utls v1.3.3 h1:f/TBLX7KBciRyFH3bwupp+CE4fzoYKCirhdRcC490sw= +github.com/refraction-networking/utls v1.3.3/go.mod h1:DlecWW1LMlMJu+9qpzzQqdHDT/C2LAe03EdpLUz/RL8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -52,13 +68,19 @@ golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= @@ -66,6 +88,8 @@ golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= +golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -77,6 +101,8 @@ golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -84,6 +110,8 @@ golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -91,6 +119,8 @@ golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM= +golang.org/x/tools v0.11.0 h1:EMCa6U9S2LtZXLAMoWiR/R8dAQFRqbAitmbJ2UKhoi8= +golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/http3/client.go b/internal/http3/client.go index 6c04253c..f336cc1a 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -466,26 +466,12 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui if !ok { panic(fmt.Sprintf("bad tls connection state type: %s", reflect.ValueOf(conn.ConnectionState().TLS).Field(0).Type().Name())) } - res := &http.Response{ - Proto: "HTTP/3.0", - ProtoMajor: 3, - Header: http.Header{}, - TLS: &connState, - Request: req, - } - for _, hf := range hfs { - switch hf.Name { - case ":status": - status, err := strconv.Atoi(hf.Value) - if err != nil { - return nil, newStreamError(errorGeneralProtocolError, errors.New("malformed non-numeric status pseudo header")) - } - res.StatusCode = status - res.Status = hf.Value + " " + http.StatusText(status) - default: - res.Header.Add(hf.Name, hf.Value) - } + res, err := responseFromHeaders(hfs) + if err != nil { + return nil, newStreamError(errorMessageError, err) } + res.Request = req + res.TLS = &connState respBody := newResponseBody(hstr, conn, reqDone) // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. diff --git a/internal/http3/headers.go b/internal/http3/headers.go new file mode 100644 index 00000000..2eb5ca29 --- /dev/null +++ b/internal/http3/headers.go @@ -0,0 +1,129 @@ +package http3 + +import ( + "errors" + "fmt" + "github.com/quic-go/qpack" + "golang.org/x/net/http/httpguts" + "net/http" + "strconv" + "strings" +) + +type Header struct { + // Pseudo header fields defined in RFC 9114 + Path string + Method string + Authority string + Scheme string + Status string + // for Extended connect + Protocol string + // parsed and deduplicated + ContentLength int64 + // all non-pseudo headers + Headers http.Header +} + +func parseHeaders(headers []qpack.HeaderField, isRequest bool) (Header, error) { + hdr := Header{Headers: make(http.Header, len(headers))} + var readFirstRegularHeader, readContentLength bool + var contentLengthStr string + for _, h := range headers { + // field names need to be lowercase, see section 4.2 of RFC 9114 + if strings.ToLower(h.Name) != h.Name { + return Header{}, fmt.Errorf("header field is not lower-case: %s", h.Name) + } + if !httpguts.ValidHeaderFieldValue(h.Value) { + return Header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value) + } + if h.IsPseudo() { + if readFirstRegularHeader { + // all pseudo headers must appear before regular header fields, see section 4.3 of RFC 9114 + return Header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name) + } + var isResponsePseudoHeader bool // pseudo headers are either valid for requests or for responses + switch h.Name { + case ":path": + hdr.Path = h.Value + case ":method": + hdr.Method = h.Value + case ":authority": + hdr.Authority = h.Value + case ":protocol": + hdr.Protocol = h.Value + case ":scheme": + hdr.Scheme = h.Value + case ":status": + hdr.Status = h.Value + isResponsePseudoHeader = true + default: + return Header{}, fmt.Errorf("unknown pseudo header: %s", h.Name) + } + if isRequest && isResponsePseudoHeader { + return Header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name) + } + if !isRequest && !isResponsePseudoHeader { + return Header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name) + } + } else { + if !httpguts.ValidHeaderFieldName(h.Name) { + return Header{}, fmt.Errorf("invalid header field name: %q", h.Name) + } + readFirstRegularHeader = true + switch h.Name { + case "content-length": + // Ignore duplicate Content-Length headers. + // Fail if the duplicates differ. + if !readContentLength { + readContentLength = true + contentLengthStr = h.Value + } else if contentLengthStr != h.Value { + return Header{}, fmt.Errorf("contradicting content lengths (%s and %s)", contentLengthStr, h.Value) + } + default: + hdr.Headers.Add(h.Name, h.Value) + } + } + } + if len(contentLengthStr) > 0 { + // use ParseUint instead of ParseInt, so that parsing fails on negative values + cl, err := strconv.ParseUint(contentLengthStr, 10, 63) + if err != nil { + return Header{}, fmt.Errorf("invalid content length: %w", err) + } + hdr.Headers.Set("Content-Length", contentLengthStr) + hdr.ContentLength = int64(cl) + } + return hdr, nil +} + +func hostnameFromRequest(req *http.Request) string { + if req.URL != nil { + return req.URL.Host + } + return "" +} + +func responseFromHeaders(headerFields []qpack.HeaderField) (*http.Response, error) { + hdr, err := parseHeaders(headerFields, false) + if err != nil { + return nil, err + } + if hdr.Status == "" { + return nil, errors.New("missing status field") + } + rsp := &http.Response{ + Proto: "HTTP/3.0", + ProtoMajor: 3, + Header: hdr.Headers, + ContentLength: hdr.ContentLength, + } + status, err := strconv.Atoi(hdr.Status) + if err != nil { + return nil, fmt.Errorf("invalid status code: %w", err) + } + rsp.StatusCode = status + rsp.Status = hdr.Status + " " + http.StatusText(status) + return rsp, nil +} diff --git a/internal/http3/request.go b/internal/http3/request.go deleted file mode 100644 index 9af25a57..00000000 --- a/internal/http3/request.go +++ /dev/null @@ -1,111 +0,0 @@ -package http3 - -import ( - "errors" - "net/http" - "net/url" - "strconv" - "strings" - - "github.com/quic-go/qpack" -) - -func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { - var path, authority, method, protocol, scheme, contentLengthStr string - - httpHeaders := http.Header{} - for _, h := range headers { - switch h.Name { - case ":path": - path = h.Value - case ":method": - method = h.Value - case ":authority": - authority = h.Value - case ":protocol": - protocol = h.Value - case ":scheme": - scheme = h.Value - case "content-length": - contentLengthStr = h.Value - default: - if !h.IsPseudo() { - httpHeaders.Add(h.Name, h.Value) - } - } - } - - // concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4 - if len(httpHeaders["Cookie"]) > 0 { - httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; ")) - } - - isConnect := method == http.MethodConnect - // Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4 - isExtendedConnected := isConnect && protocol != "" - if isExtendedConnected { - if scheme == "" || path == "" || authority == "" { - return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty") - } - } else if isConnect { - if path != "" || authority == "" { // normal CONNECT - return nil, errors.New(":path must be empty and :authority must not be empty") - } - } else if len(path) == 0 || len(authority) == 0 || len(method) == 0 { - return nil, errors.New(":path, :authority and :method must not be empty") - } - - var u *url.URL - var requestURI string - var err error - - if isConnect { - u = &url.URL{} - if isExtendedConnected { - u, err = url.ParseRequestURI(path) - if err != nil { - return nil, err - } - } else { - u.Path = path - } - u.Scheme = scheme - u.Host = authority - requestURI = authority - } else { - protocol = "HTTP/3.0" - u, err = url.ParseRequestURI(path) - if err != nil { - return nil, err - } - requestURI = path - } - - var contentLength int64 - if len(contentLengthStr) > 0 { - contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) - if err != nil { - return nil, err - } - } - - return &http.Request{ - Method: method, - URL: u, - Proto: protocol, - ProtoMajor: 3, - ProtoMinor: 0, - Header: httpHeaders, - Body: nil, - ContentLength: contentLength, - Host: authority, - RequestURI: requestURI, - }, nil -} - -func hostnameFromRequest(req *http.Request) string { - if req.URL != nil { - return req.URL.Host - } - return "" -} diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index 1fdcf934..84330c6b 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -2,6 +2,7 @@ package http3 import ( "bytes" + "errors" "fmt" "io" "net" @@ -82,6 +83,9 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra if err != nil { return err } + if !httpguts.ValidHostHeader(host) { + return errors.New("http3: invalid Host header") + } // http.NewRequest sets this field to HTTP/1.1 isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1" From f7e03fa154ad65ee517da5452739d976025f976e Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 28 Jul 2023 17:36:08 +0800 Subject: [PATCH 722/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index bc2a3531..29157a28 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ Full documentation is available on the official website: https://req.cool. **Install** -You first need [Go](https://go.dev/) installed (version 1.19+ is required), then you can use the below Go command to install req: +You first need [Go](https://go.dev/) installed (version 1.20+ is required), then you can use the below Go command to install req: ``` sh go get github.com/imroc/req/v3 From e2cff66a8948d982ba7b1970d524cf48ce8e6405 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 28 Jul 2023 17:36:59 +0800 Subject: [PATCH 723/843] update CI: remove go1.19 --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5883b872..43e855c7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: test: strategy: matrix: - go: [ '1.20.x', '1.19.x' ] + go: [ '1.20.x' ] os: [ ubuntu-latest ] runs-on: ${{ matrix.os }} steps: From 7c4631441e76e6c98a390e14e9f11a70234aa9db Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 28 Jul 2023 17:52:23 +0800 Subject: [PATCH 724/843] fix http3 tls ConnectionState --- internal/http3/client.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/internal/http3/client.go b/internal/http3/client.go index f336cc1a..4ba9e9bb 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "reflect" "strconv" "sync" "sync/atomic" @@ -462,15 +461,12 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui return nil, newConnError(errorGeneralProtocolError, err) } - connState, ok := reflect.ValueOf(conn.ConnectionState().TLS).Field(0).Interface().(tls.ConnectionState) - if !ok { - panic(fmt.Sprintf("bad tls connection state type: %s", reflect.ValueOf(conn.ConnectionState().TLS).Field(0).Type().Name())) - } res, err := responseFromHeaders(hfs) if err != nil { return nil, newStreamError(errorMessageError, err) } res.Request = req + connState := conn.ConnectionState().TLS res.TLS = &connState respBody := newResponseBody(hstr, conn, reqDone) From bc158ce7f151c3d966cbfb6fbe4cb6d758f05e56 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jul 2023 11:59:03 +0800 Subject: [PATCH 725/843] Support SetHeaderOrder for HTTP/1.1 Request --- header.go | 70 +++++++++++++++++++-------------------- http_request.go | 12 ++++--- internal/header/header.go | 2 ++ internal/header/sort.go | 37 +++++++++++++++++++++ request.go | 21 ++++++++++++ transfer.go | 30 +++++------------ transport.go | 52 +++++++++++++++++++++++------ 7 files changed, 152 insertions(+), 72 deletions(-) create mode 100644 internal/header/sort.go diff --git a/header.go b/header.go index eb991c49..d31da8d7 100644 --- a/header.go +++ b/header.go @@ -1,10 +1,10 @@ package req import ( + "github.com/imroc/req/v3/internal/header" "golang.org/x/net/http/httpguts" "io" "net/http" - "net/http/httptrace" "net/textproto" "sort" "strings" @@ -22,21 +22,16 @@ func (w stringWriter) WriteString(s string) (n int, err error) { return w.w.Write([]byte(s)) } -type keyValues struct { - key string - values []string -} - // A headerSorter implements sort.Interface by sorting a []keyValues // by key. It's used as a pointer, so it can fit in a sort.Interface // interface value without allocation. type headerSorter struct { - kvs []keyValues + kvs []header.KeyValues } func (s *headerSorter) Len() int { return len(s.kvs) } func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } -func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].Key < s.kvs[j].Key } var headerSorterPool = sync.Pool{ New: func() interface{} { return new(headerSorter) }, @@ -60,15 +55,15 @@ func headerHas(h http.Header, key string) bool { // sortedKeyValues returns h's keys sorted in the returned kvs // slice. The headerSorter used to sort is also returned, for possible // return to headerSorterCache. -func headerSortedKeyValues(h http.Header, exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { +func headerSortedKeyValues(h http.Header, exclude map[string]bool) (kvs []header.KeyValues, hs *headerSorter) { hs = headerSorterPool.Get().(*headerSorter) if cap(hs.kvs) < len(h) { - hs.kvs = make([]keyValues, 0, len(h)) + hs.kvs = make([]header.KeyValues, 0, len(h)) } kvs = hs.kvs[:0] for k, vv := range h { if !exclude[k] { - kvs = append(kvs, keyValues{k, vv}) + kvs = append(kvs, header.KeyValues{k, vv}) } } hs.kvs = kvs @@ -76,43 +71,48 @@ func headerSortedKeyValues(h http.Header, exclude map[string]bool) (kvs []keyVal return kvs, hs } -func headerWrite(h http.Header, w io.Writer, trace *httptrace.ClientTrace) error { - return headerWriteSubset(h, w, nil, trace) +func headerWrite(h http.Header, writeHeader func(key string, values ...string) error, sort bool) error { + return headerWriteSubset(h, nil, writeHeader, sort) } -func headerWriteSubset(h http.Header, w io.Writer, exclude map[string]bool, trace *httptrace.ClientTrace) error { - ws, ok := w.(io.StringWriter) - if !ok { - ws = stringWriter{w} +func headerWriteSubset(h http.Header, exclude map[string]bool, writeHeader func(key string, values ...string) error, sort bool) error { + var kvs []header.KeyValues + var hs *headerSorter + if sort { + kvs = make([]header.KeyValues, 0, len(h)) + for k, v := range h { + if !exclude[k] { + kvs = append(kvs, header.KeyValues{k, v}) + } + } + } else { + kvs, hs = headerSortedKeyValues(h, exclude) } - kvs, sorter := headerSortedKeyValues(h, exclude) - var formattedVals []string for _, kv := range kvs { - if !httpguts.ValidHeaderFieldName(kv.key) { + if !httpguts.ValidHeaderFieldName(kv.Key) { // This could be an error. In the common case of // writing response headers, however, we have no good // way to provide the error back to the server // handler, so just drop invalid headers instead. continue } - for _, v := range kv.values { - v = headerNewlineToSpace.Replace(v) - v = textproto.TrimString(v) - for _, s := range []string{kv.key, ": ", v, "\r\n"} { - if _, err := ws.WriteString(s); err != nil { - headerSorterPool.Put(sorter) - return err - } - } - if trace != nil && trace.WroteHeaderField != nil { - formattedVals = append(formattedVals, v) + for i, v := range kv.Values { + vv := headerNewlineToSpace.Replace(v) + vv = textproto.TrimString(v) + if vv != v { + kv.Values[i] = vv } } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField(kv.key, formattedVals) - formattedVals = nil + err := writeHeader(kv.Key, kv.Values...) + if err != nil { + if hs != nil { + headerSorterPool.Put(hs) + } + return err } } - headerSorterPool.Put(sorter) + if hs != nil { + headerSorterPool.Put(hs) + } return nil } diff --git a/http_request.go b/http_request.go index 2d471b4d..c737a6b3 100644 --- a/http_request.go +++ b/http_request.go @@ -3,6 +3,7 @@ package req import ( "errors" "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/header" "golang.org/x/net/http/httpguts" "net/http" "strings" @@ -87,11 +88,12 @@ func closeRequestBody(r *http.Request) error { // Headers that Request.Write handles itself and should be skipped. var reqWriteExcludeHeader = map[string]bool{ - "Host": true, // not in Header map anyway - "User-Agent": true, - "Content-Length": true, - "Transfer-Encoding": true, - "Trailer": true, + "Host": true, // not in Header map anyway + "User-Agent": true, + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, + header.HeaderOderKey: true, } // requestMethodUsuallyLacksBody reports whether the given request diff --git a/internal/header/header.go b/internal/header/header.go index b6064558..0aead102 100644 --- a/internal/header/header.go +++ b/internal/header/header.go @@ -11,4 +11,6 @@ const ( FormContentType = "application/x-www-form-urlencoded" WwwAuthenticate = "WWW-Authenticate" Authorization = "Authorization" + HeaderOderKey = "__Header_Order__" + PseudoHeaderOderKey = "__Pseudo_Header_Order__" ) diff --git a/internal/header/sort.go b/internal/header/sort.go new file mode 100644 index 00000000..8c768c36 --- /dev/null +++ b/internal/header/sort.go @@ -0,0 +1,37 @@ +package header + +import "sort" + +type KeyValues struct { + Key string + Values []string +} + +type sorter struct { + order map[string]int + kvs []KeyValues +} + +func (s *sorter) Len() int { return len(s.kvs) } +func (s *sorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *sorter) Less(i, j int) bool { + if index, ok := s.order[s.kvs[i].Key]; ok { + i = index + } + if index, ok := s.order[s.kvs[j].Key]; ok { + j = index + } + return i < j +} + +func SortKeyValues(kvs []KeyValues, orderedKeys []string) { + order := make(map[string]int) + for i, key := range orderedKeys { + order[key] = i + } + s := &sorter{ + order: order, + kvs: kvs, + } + sort.Sort(s) +} diff --git a/request.go b/request.go index 63df395d..e57d82bc 100644 --- a/request.go +++ b/request.go @@ -463,6 +463,27 @@ func (r *Request) SetHeaderNonCanonical(key, value string) *Request { return r } +const ( + HeaderOderKey = "__Header_Order__" + PseudoHeaderOderKey = "__Pseudo_Header_Order__" +) + +func (r *Request) SetHeaderOrder(keys ...string) *Request { + if r.Headers == nil { + r.Headers = make(http.Header) + } + r.Headers[HeaderOderKey] = append(r.Headers[HeaderOderKey], keys...) + return r +} + +func (r *Request) SetPseudoHeaderOrder(keys ...string) *Request { + if r.Headers == nil { + r.Headers = make(http.Header) + } + r.Headers[PseudoHeaderOderKey] = append(r.Headers[PseudoHeaderOderKey], keys...) + return r +} + // SetOutputFile set the file that response Body will be downloaded to. func (r *Request) SetOutputFile(file string) *Request { r.isSaveResponse = true diff --git a/transfer.go b/transfer.go index 58e64793..c7c623f8 100644 --- a/transfer.go +++ b/transfer.go @@ -14,7 +14,6 @@ import ( "github.com/imroc/req/v3/internal/dump" "io" "net/http" - "net/http/httptrace" "net/textproto" "reflect" "sort" @@ -245,36 +244,27 @@ func (t *transferWriter) shouldSendContentLength() bool { return false } -func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) error { +func (t *transferWriter) writeHeader(writeHeader func(key string, values ...string) error) error { if t.Close && !hasToken(headerGet(t.Header, "Connection"), "close") { - if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil { + err := writeHeader("Connection", "close") + if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Connection", []string{"close"}) - } } // Write Content-Length and/or Transfer-Encoding whose values are a // function of the sanitized field triple (Body, ContentLength, // TransferEncoding) if t.shouldSendContentLength() { - if _, err := io.WriteString(w, "Content-Length: "); err != nil { - return err - } - if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil { + err := writeHeader("Content-Length", strconv.FormatInt(t.ContentLength, 10)) + if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Content-Length", []string{strconv.FormatInt(t.ContentLength, 10)}) - } } else if chunked(t.TransferEncoding) { - if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil { + err := writeHeader("Transfer-Encoding", "chunked") + if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Transfer-Encoding", []string{"chunked"}) - } } // Write Trailer header @@ -292,12 +282,10 @@ func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) sort.Strings(keys) // TODO: could do better allocation-wise here, but trailers are rare, // so being lazy for now. - if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil { + err := writeHeader("Trailer", strings.Join(keys, ",")) + if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Trailer", keys) - } } } diff --git a/transport.go b/transport.go index e18ca80f..f3b4acac 100644 --- a/transport.go +++ b/transport.go @@ -2966,14 +2966,40 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo return err } + _writeHeader := func(key string, values ...string) error { + for _, value := range values { + _, err := fmt.Fprintf(w, "%s: %s\r\n", key, value) + if err != nil { + return err + } + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField(key, values) + } + return nil + } + + var writeHeader func(key string, values ...string) error + var kvs []header.KeyValues + sort := false + + if r.Header != nil && len(r.Header[header.HeaderOderKey]) > 0 { + writeHeader = func(key string, values ...string) error { + kvs = append(kvs, header.KeyValues{ + Key: key, + Values: values, + }) + return nil + } + sort = true + } else { + writeHeader = _writeHeader + } // Header lines - _, err = fmt.Fprintf(w, "Host: %s\r\n", host) + err = writeHeader("Host", host) if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Host", []string{host}) - } // Use the defaultUserAgent unless the Header contains one, which // may be blank to not send the header. @@ -2982,13 +3008,10 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo userAgent = r.Header.Get("User-Agent") } if userAgent != "" { - _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) + err = writeHeader("User-Agent", userAgent) if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("User-Agent", []string{userAgent}) - } } // Process Body,ContentLength,Close,Trailer @@ -2996,23 +3019,30 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo if err != nil { return err } - err = tw.writeHeader(w, trace) + err = tw.writeHeader(writeHeader) if err != nil { return err } - err = headerWriteSubset(r.Header, w, reqWriteExcludeHeader, trace) + err = headerWriteSubset(r.Header, reqWriteExcludeHeader, writeHeader, sort) if err != nil { return err } if extraHeaders != nil { - err = headerWrite(extraHeaders, w, trace) + err = headerWrite(extraHeaders, writeHeader, sort) if err != nil { return err } } + if sort { // sort and write headers + header.SortKeyValues(kvs, r.Header[header.HeaderOderKey]) + for _, kv := range kvs { + _writeHeader(kv.Key, kv.Values...) + } + } + _, err = io.WriteString(w, "\r\n") if err != nil { return err From 22e6fc978118b8a830efac5dde802aa766deb6b4 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jul 2023 16:20:46 +0800 Subject: [PATCH 726/843] Support SetHeaderOrder and SetPseudoHeaderOrder for HTTP/2 Request --- internal/header/header.go | 4 +- internal/header/sort.go | 11 ++-- internal/http2/transport.go | 117 ++++++++++++++++++++++++++++-------- request.go | 4 +- 4 files changed, 102 insertions(+), 34 deletions(-) diff --git a/internal/header/header.go b/internal/header/header.go index 0aead102..33e779bc 100644 --- a/internal/header/header.go +++ b/internal/header/header.go @@ -11,6 +11,6 @@ const ( FormContentType = "application/x-www-form-urlencoded" WwwAuthenticate = "WWW-Authenticate" Authorization = "Authorization" - HeaderOderKey = "__Header_Order__" - PseudoHeaderOderKey = "__Pseudo_Header_Order__" + HeaderOderKey = "__header_order__" + PseudoHeaderOderKey = "__pseudo_header_order__" ) diff --git a/internal/header/sort.go b/internal/header/sort.go index 8c768c36..2c61fd2e 100644 --- a/internal/header/sort.go +++ b/internal/header/sort.go @@ -1,6 +1,9 @@ package header -import "sort" +import ( + "net/textproto" + "sort" +) type KeyValues struct { Key string @@ -15,10 +18,10 @@ type sorter struct { func (s *sorter) Len() int { return len(s.kvs) } func (s *sorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } func (s *sorter) Less(i, j int) bool { - if index, ok := s.order[s.kvs[i].Key]; ok { + if index, ok := s.order[textproto.CanonicalMIMEHeaderKey(s.kvs[i].Key)]; ok { i = index } - if index, ok := s.order[s.kvs[j].Key]; ok { + if index, ok := s.order[textproto.CanonicalMIMEHeaderKey(s.kvs[j].Key)]; ok { j = index } return i < j @@ -27,7 +30,7 @@ func (s *sorter) Less(i, j int) bool { func SortKeyValues(kvs []KeyValues, orderedKeys []string) { order := make(map[string]int) for i, key := range orderedKeys { - order[key] = i + order[textproto.CanonicalMIMEHeaderKey(key)] = i } s := &sorter{ order: order, diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 21bf6099..1208cbeb 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -1748,6 +1748,25 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) } } +var reqWriteExcludeHeader = map[string]bool{ + // Host is :authority, already sent. + // Content-Length is automatic. + "host": true, + "content-length": true, + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + "connection": true, + "proxy-connection": true, + "transfer-encoding": true, + "upgrade": true, + "keep-alive": true, + // Ignore header order keys which is only used internally. + header.HeaderOderKey: true, + header.PseudoHeaderOderKey: true, +} + var errNilRequestURL = errors.New("http2: Request.URI is nil") // requires cc.wmu be held. @@ -1797,40 +1816,75 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail } enumerateHeaders := func(f func(name, value string)) { + var writeHeader func(name string, value ...string) + var kvs []header.KeyValues + sort := false + + if req.Header != nil && len(req.Header[header.PseudoHeaderOderKey]) > 0 { + writeHeader = func(name string, value ...string) { + kvs = append(kvs, header.KeyValues{ + Key: name, + Values: value, + }) + } + sort = true + } else { + writeHeader = func(name string, value ...string) { + for _, v := range value { + f(name, v) + } + } + } + // 8.1.2.3 Request Pseudo-Header Fields // The :path pseudo-header field includes the path and query parts of the // target URI (the path-absolute production and optionally a '?' character // followed by the query production, see Sections 3.3 and 3.4 of // [RFC3986]). - f(":authority", host) + writeHeader(":authority", host) m := req.Method if m == "" { m = http.MethodGet } - f(":method", m) + writeHeader(":method", m) if req.Method != "CONNECT" { - f(":path", path) - f(":scheme", req.URL.Scheme) + writeHeader(":path", path) + writeHeader(":scheme", req.URL.Scheme) + } + if sort { + header.SortKeyValues(kvs, req.Header[header.PseudoHeaderOderKey]) + for _, kv := range kvs { + for _, v := range kv.Values { + f(kv.Key, v) + } + } } + + if req.Header != nil && len(req.Header[header.HeaderOderKey]) > 0 { + sort = true + kvs = nil + writeHeader = func(name string, value ...string) { + kvs = append(kvs, header.KeyValues{ + Key: name, + Values: value, + }) + } + } else { + sort = false + writeHeader = func(name string, value ...string) { + for _, v := range value { + f(name, v) + } + } + } + if trailers != "" { - f("trailer", trailers) + writeHeader("trailer", trailers) } var didUA bool for k, vv := range req.Header { - if ascii.EqualFold(k, "host") || ascii.EqualFold(k, "content-length") { - // Host is :authority, already sent. - // Content-Length is automatic, set below. - continue - } else if ascii.EqualFold(k, "connection") || - ascii.EqualFold(k, "proxy-connection") || - ascii.EqualFold(k, "transfer-encoding") || - ascii.EqualFold(k, "upgrade") || - ascii.EqualFold(k, "keep-alive") { - // Per 8.1.2.2 Connection-Specific Header - // Fields, don't send connection-specific - // fields. We have already checked if any - // are error-worthy so just ignore the rest. + if reqWriteExcludeHeader[strings.ToLower(k)] { continue } else if ascii.EqualFold(k, "user-agent") { // Match Go's http1 behavior: at most one @@ -1846,6 +1900,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail continue } } else if ascii.EqualFold(k, "cookie") { + var vals []string // Per 8.1.2.5 To allow for better compression efficiency, the // Cookie header field MAY be split into separate header fields, // each with one or more cookie-pairs. @@ -1855,7 +1910,8 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail if p < 0 { break } - f("cookie", v[:p]) + vals = append(vals, v[:p]) + //writeHeader("cookie", v[:p]) p++ // strip space after semicolon if any. for p+1 <= len(v) && v[p] == ' ' { @@ -1864,24 +1920,33 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail v = v[p:] } if len(v) > 0 { - f("cookie", v) + vals = append(vals, v) + //writeHeader("cookie", v) } } + writeHeader("cookie", vals...) continue } - for _, v := range vv { - f(k, v) - } + writeHeader(k, vv...) } if shouldSendReqContentLength(req.Method, contentLength) { - f("content-length", strconv.FormatInt(contentLength, 10)) + writeHeader("content-length", strconv.FormatInt(contentLength, 10)) } if addGzipHeader { - f("accept-encoding", "gzip") + writeHeader("accept-encoding", "gzip") } if !didUA { - f("user-agent", header.DefaultUserAgent) + writeHeader("user-agent", header.DefaultUserAgent) + } + + if sort { + header.SortKeyValues(kvs, req.Header[header.HeaderOderKey]) + for _, kv := range kvs { + for _, v := range kv.Values { + f(kv.Key, v) + } + } } } diff --git a/request.go b/request.go index e57d82bc..7e705bb6 100644 --- a/request.go +++ b/request.go @@ -464,8 +464,8 @@ func (r *Request) SetHeaderNonCanonical(key, value string) *Request { } const ( - HeaderOderKey = "__Header_Order__" - PseudoHeaderOderKey = "__Pseudo_Header_Order__" + HeaderOderKey = "__header_order__" + PseudoHeaderOderKey = "__pseudo_header_order__" ) func (r *Request) SetHeaderOrder(keys ...string) *Request { From f0b2f73706f5d8550774fe5543d3b2528f97bf4e Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jul 2023 16:52:55 +0800 Subject: [PATCH 727/843] Support SetHeaderOrder and SetPseudoHeaderOrder for HTTP/3 Request --- internal/http3/request_writer.go | 117 ++++++++++++++++++++++++------- internal/http3/trace.go | 7 ++ 2 files changed, 98 insertions(+), 26 deletions(-) create mode 100644 internal/http3/trace.go diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index 84330c6b..1f6ce445 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/http" + "net/http/httptrace" "strconv" "strings" "sync" @@ -70,6 +71,25 @@ func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool, return err } +var reqWriteExcludeHeader = map[string]bool{ + // Host is :authority, already sent. + // Content-Length is automatic. + "host": true, + "content-length": true, + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + "connection": true, + "proxy-connection": true, + "transfer-encoding": true, + "upgrade": true, + "keep-alive": true, + // Ignore header order keys which is only used internally. + header.HeaderOderKey: true, + header.PseudoHeaderOderKey: true, +} + // copied from net/transport.go // Modified to support Extended CONNECT: // Contrary to what the godoc for the http.Request says, @@ -121,37 +141,73 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } enumerateHeaders := func(f func(name, value string)) { + var writeHeader func(name string, value ...string) + var kvs []header.KeyValues + sort := false + if req.Header != nil && len(req.Header[header.PseudoHeaderOderKey]) > 0 { + writeHeader = func(name string, value ...string) { + kvs = append(kvs, header.KeyValues{ + Key: name, + Values: value, + }) + } + sort = true + } else { + writeHeader = func(name string, value ...string) { + for _, v := range value { + f(name, v) + } + } + } // 8.1.2.3 Request Pseudo-Header Fields // The :path pseudo-header field includes the path and query parts of the // target URI (the path-absolute production and optionally a '?' character // followed by the query production (see Sections 3.3 and 3.4 of // [RFC3986]). - f(":authority", host) - f(":method", req.Method) + writeHeader(":authority", host) + writeHeader(":method", req.Method) if req.Method != http.MethodConnect || isExtendedConnect { - f(":path", path) - f(":scheme", req.URL.Scheme) + writeHeader(":path", path) + writeHeader(":scheme", req.URL.Scheme) } if isExtendedConnect { - f(":protocol", req.Proto) + writeHeader(":protocol", req.Proto) + } + + if sort { + header.SortKeyValues(kvs, req.Header[header.PseudoHeaderOderKey]) + for _, kv := range kvs { + for _, v := range kv.Values { + f(kv.Key, v) + } + } + } + + if req.Header != nil && len(req.Header[header.HeaderOderKey]) > 0 { + sort = true + kvs = nil + writeHeader = func(name string, value ...string) { + kvs = append(kvs, header.KeyValues{ + Key: name, + Values: value, + }) + } + } else { + sort = false + writeHeader = func(name string, value ...string) { + for _, v := range value { + f(name, v) + } + } } + if trailers != "" { - f("trailer", trailers) + writeHeader("trailer", trailers) } var didUA bool for k, vv := range req.Header { - if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") { - // Host is :authority, already sent. - // Content-Length is automatic, set below. - continue - } else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") || - strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") || - strings.EqualFold(k, "keep-alive") { - // Per 8.1.2.2 Connection-Specific Header - // Fields, don't send connection-specific - // fields. We have already checked if any - // are error-worthy so just ignore the rest. + if reqWriteExcludeHeader[strings.ToLower(k)] { continue } else if strings.EqualFold(k, "user-agent") { // Match Go's http1 behavior: at most one @@ -170,17 +226,26 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } for _, v := range vv { - f(k, v) + writeHeader(k, v) } } if shouldSendReqContentLength(req.Method, contentLength) { - f("content-length", strconv.FormatInt(contentLength, 10)) + writeHeader("content-length", strconv.FormatInt(contentLength, 10)) } if addGzipHeader { - f("accept-encoding", "gzip") + writeHeader("accept-encoding", "gzip") } if !didUA { - f("user-agent", header.DefaultUserAgent) + writeHeader("user-agent", header.DefaultUserAgent) + } + + if sort { + header.SortKeyValues(kvs, req.Header[header.HeaderOderKey]) + for _, kv := range kvs { + for _, v := range kv.Values { + f(kv.Key, v) + } + } } } @@ -199,8 +264,8 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra // return errRequestHeaderListSize // } - // trace := httptrace.ContextClientTrace(req.Context()) - // traceHeaders := traceHasWroteHeaderField(trace) + trace := httptrace.ContextClientTrace(req.Context()) + traceHeaders := traceHasWroteHeaderField(trace) // Header list size is ok. Write the headers. enumerateHeaders(func(name, value string) { @@ -209,9 +274,9 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra dump.DumpRequestHeader([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) } w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value}) - // if traceHeaders { - // traceWroteHeaderField(trace, name, value) - // } + if traceHeaders { + trace.WroteHeaderField(name, []string{value}) + } }) for _, dump := range dumps { diff --git a/internal/http3/trace.go b/internal/http3/trace.go new file mode 100644 index 00000000..710072de --- /dev/null +++ b/internal/http3/trace.go @@ -0,0 +1,7 @@ +package http3 + +import "net/http/httptrace" + +func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { + return trace != nil && trace.WroteHeaderField != nil +} From 68815fd4175a10b43e1debcf56118a7c4bf86069 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jul 2023 17:03:23 +0800 Subject: [PATCH 728/843] Refactor: share excluded headers between http2 and http3 --- internal/header/header.go | 28 ++++++++++++++++++++++++++++ internal/http2/transport.go | 21 +-------------------- internal/http3/request_writer.go | 21 +-------------------- 3 files changed, 30 insertions(+), 40 deletions(-) diff --git a/internal/header/header.go b/internal/header/header.go index 33e779bc..4098febe 100644 --- a/internal/header/header.go +++ b/internal/header/header.go @@ -1,5 +1,7 @@ package header +import "strings" + const ( DefaultUserAgent = "req/v3 (https://github.com/imroc/req)" UserAgent = "User-Agent" @@ -14,3 +16,29 @@ const ( HeaderOderKey = "__header_order__" PseudoHeaderOderKey = "__pseudo_header_order__" ) + +var reqWriteExcludeHeader = map[string]bool{ + // Host is :authority, already sent. + // Content-Length is automatic. + "host": true, + "content-length": true, + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + "connection": true, + "proxy-connection": true, + "transfer-encoding": true, + "upgrade": true, + "keep-alive": true, + // Ignore header order keys which is only used internally. + HeaderOderKey: true, + PseudoHeaderOderKey: true, +} + +func IsExcluded(key string) bool { + if reqWriteExcludeHeader[strings.ToLower(key)] { + return true + } + return false +} diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 1208cbeb..75b100e3 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -1748,25 +1748,6 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) } } -var reqWriteExcludeHeader = map[string]bool{ - // Host is :authority, already sent. - // Content-Length is automatic. - "host": true, - "content-length": true, - // Per 8.1.2.2 Connection-Specific Header - // Fields, don't send connection-specific - // fields. We have already checked if any - // are error-worthy so just ignore the rest. - "connection": true, - "proxy-connection": true, - "transfer-encoding": true, - "upgrade": true, - "keep-alive": true, - // Ignore header order keys which is only used internally. - header.HeaderOderKey: true, - header.PseudoHeaderOderKey: true, -} - var errNilRequestURL = errors.New("http2: Request.URI is nil") // requires cc.wmu be held. @@ -1884,7 +1865,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail var didUA bool for k, vv := range req.Header { - if reqWriteExcludeHeader[strings.ToLower(k)] { + if header.IsExcluded(k) { continue } else if ascii.EqualFold(k, "user-agent") { // Match Go's http1 behavior: at most one diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index 1f6ce445..443b34c2 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -71,25 +71,6 @@ func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool, return err } -var reqWriteExcludeHeader = map[string]bool{ - // Host is :authority, already sent. - // Content-Length is automatic. - "host": true, - "content-length": true, - // Per 8.1.2.2 Connection-Specific Header - // Fields, don't send connection-specific - // fields. We have already checked if any - // are error-worthy so just ignore the rest. - "connection": true, - "proxy-connection": true, - "transfer-encoding": true, - "upgrade": true, - "keep-alive": true, - // Ignore header order keys which is only used internally. - header.HeaderOderKey: true, - header.PseudoHeaderOderKey: true, -} - // copied from net/transport.go // Modified to support Extended CONNECT: // Contrary to what the godoc for the http.Request says, @@ -207,7 +188,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra var didUA bool for k, vv := range req.Header { - if reqWriteExcludeHeader[strings.ToLower(k)] { + if header.IsExcluded(k) { continue } else if strings.EqualFold(k, "user-agent") { // Match Go's http1 behavior: at most one From ff63e3bdea36b174331f04373f35d5b9c861b2c7 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jul 2023 17:16:45 +0800 Subject: [PATCH 729/843] implement SetHeaderOrder and SetPseudoHeaderOder at client-level --- client.go | 10 ++++++++++ transport.go | 24 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/client.go b/client.go index b6316b86..b05411ee 100644 --- a/client.go +++ b/client.go @@ -871,6 +871,16 @@ func (c *Client) SetCommonHeadersNonCanonical(hdrs map[string]string) *Client { return c } +func (c *Client) SetCommonHeaderOrder(keys ...string) *Client { + c.t.SetHeaderOrder(keys...) + return c +} + +func (c *Client) SetCommonPseudoHeaderOder(keys ...string) *Client { + c.t.SetPseudoHeaderOder(keys...) + return c +} + // SetCommonContentType set the `Content-Type` header for requests fired // from the client. func (c *Client) SetCommonContentType(ct string) *Client { diff --git a/transport.go b/transport.go index f3b4acac..6bef585c 100644 --- a/transport.go +++ b/transport.go @@ -210,6 +210,30 @@ func (t *Transport) WrapRoundTrip(wrappers ...HttpRoundTripWrapper) *Transport { return t } +func (t *Transport) SetHeaderOrder(keys ...string) { + t.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { + return func(req *http.Request) (resp *http.Response, err error) { + if req.Header == nil { + req.Header = make(http.Header) + } + req.Header[HeaderOderKey] = keys + return rt.RoundTrip(req) + } + }) +} + +func (t *Transport) SetPseudoHeaderOder(keys ...string) { + t.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { + return func(req *http.Request) (resp *http.Response, err error) { + if req.Header == nil { + req.Header = make(http.Header) + } + req.Header[PseudoHeaderOderKey] = keys + return rt.RoundTrip(req) + } + }) +} + // DisableAutoDecode disable auto-detect charset and decode to utf-8 // (enabled by default). func (t *Transport) DisableAutoDecode() *Transport { From 8ca2a65629418f0e78db70e8ac90b9884286c856 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jul 2023 17:37:54 +0800 Subject: [PATCH 730/843] add comments --- client.go | 21 +++++++++++++++++++++ request.go | 33 ++++++++++++++++++++++++++++----- transport.go | 19 +++++++++++++++++++ 3 files changed, 68 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index b05411ee..9cd56774 100644 --- a/client.go +++ b/client.go @@ -871,11 +871,32 @@ func (c *Client) SetCommonHeadersNonCanonical(hdrs map[string]string) *Client { return c } +// SetCommonHeaderOrder set the order of the http header requests fired from the +// client (case-insensitive). +// For example: +// +// client.R().SetCommonHeaderOrder( +// "custom-header", +// "cookie", +// "user-agent", +// "accept-encoding", +// ).Get(url func (c *Client) SetCommonHeaderOrder(keys ...string) *Client { c.t.SetHeaderOrder(keys...) return c } +// SetCommonPseudoHeaderOder set the order of the pseudo http header requests fired +// from the client (case-insensitive). +// Note this is only valid for http2 and http3. +// For example: +// +// client.SetCommonPseudoHeaderOder( +// ":scheme", +// ":authority", +// ":path", +// ":method", +// ) func (c *Client) SetCommonPseudoHeaderOder(keys ...string) *Client { c.t.SetPseudoHeaderOder(keys...) return c diff --git a/request.go b/request.go index 7e705bb6..d0a295a5 100644 --- a/request.go +++ b/request.go @@ -464,10 +464,23 @@ func (r *Request) SetHeaderNonCanonical(key, value string) *Request { } const ( - HeaderOderKey = "__header_order__" + // HeaderOderKey is the key of header order, which specifies the order + // of the http header. + HeaderOderKey = "__header_order__" + // PseudoHeaderOderKey is the key of pseudo header order, which specifies + // the order of the http2 and http3 pseudo header. PseudoHeaderOderKey = "__pseudo_header_order__" ) +// SetHeaderOrder set the order of the http header (case-insensitive). +// For example: +// +// client.R().SetHeaderOrder( +// "custom-header", +// "cookie", +// "user-agent", +// "accept-encoding", +// ).Get(url) func (r *Request) SetHeaderOrder(keys ...string) *Request { if r.Headers == nil { r.Headers = make(http.Header) @@ -476,6 +489,16 @@ func (r *Request) SetHeaderOrder(keys ...string) *Request { return r } +// SetPseudoHeaderOrder set the order of the pseudo http header (case-insensitive). +// Note this is only valid for http2 and http3. +// For example: +// +// client.R().SetPseudoHeaderOrder( +// ":scheme", +// ":authority", +// ":path", +// ":method", +// ).Get(url) func (r *Request) SetPseudoHeaderOrder(keys ...string) *Request { if r.Headers == nil { r.Headers = make(http.Header) @@ -1086,10 +1109,10 @@ func (r *Request) SetRetryCount(count int) *Request { // implement your own backoff retry algorithm. // For example: // -// req.SetRetryInterval(func(resp *req.Response, attempt int) time.Duration { -// sleep := 0.01 * math.Exp2(float64(attempt)) -// return time.Duration(math.Min(2, sleep)) * time.Second -// }) +// req.SetRetryInterval(func(resp *req.Response, attempt int) time.Duration { +// sleep := 0.01 * math.Exp2(float64(attempt)) +// return time.Duration(math.Min(2, sleep)) * time.Second +// }) func (r *Request) SetRetryInterval(getRetryIntervalFunc GetRetryIntervalFunc) *Request { r.getRetryOption().GetRetryInterval = getRetryIntervalFunc return r diff --git a/transport.go b/transport.go index 6bef585c..f1ee7698 100644 --- a/transport.go +++ b/transport.go @@ -210,6 +210,15 @@ func (t *Transport) WrapRoundTrip(wrappers ...HttpRoundTripWrapper) *Transport { return t } +// SetHeaderOrder set the order of the http header (case-insensitive). +// For example: +// +// t.SetHeaderOrder( +// "custom-header", +// "cookie", +// "user-agent", +// "accept-encoding", +// ) func (t *Transport) SetHeaderOrder(keys ...string) { t.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { return func(req *http.Request) (resp *http.Response, err error) { @@ -222,6 +231,16 @@ func (t *Transport) SetHeaderOrder(keys ...string) { }) } +// SetPseudoHeaderOder set the order of the pseudo http header (case-insensitive). +// Note this is only valid for http2 and http3. +// For example: +// +// t.SetPseudoHeaderOrder( +// ":scheme", +// ":authority", +// ":path", +// ":method", +// ) func (t *Transport) SetPseudoHeaderOder(keys ...string) { t.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { return func(req *http.Request) (resp *http.Response, err error) { From 63fafa06f2a15f826a267a3ca4548001e96018cf Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jul 2023 17:41:24 +0800 Subject: [PATCH 731/843] add global wrappers --- client_wrapper.go | 12 ++++++++++++ request_wrapper.go | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/client_wrapper.go b/client_wrapper.go index d6540450..f460fdaf 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -409,6 +409,18 @@ func SetCommonHeader(key, value string) *Client { return defaultClient.SetCommonHeader(key, value) } +// SetCommonHeaderOrder is a global wrapper methods which delegated +// to the default client's Client.SetCommonHeaderOrder. +func SetCommonHeaderOrder(keys ...string) *Client { + return defaultClient.SetCommonHeaderOrder(keys...) +} + +// SetCommonPseudoHeaderOder is a global wrapper methods which delegated +// to the default client's Client.SetCommonPseudoHeaderOder. +func SetCommonPseudoHeaderOder(keys ...string) *Client { + return defaultClient.SetCommonPseudoHeaderOder(keys...) +} + // SetCommonContentType is a global wrapper methods which delegated // to the default client's Client.SetCommonContentType. func SetCommonContentType(ct string) *Client { diff --git a/request_wrapper.go b/request_wrapper.go index 050eaad0..fec48b40 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -132,6 +132,18 @@ func SetHeader(key, value string) *Request { return defaultClient.R().SetHeader(key, value) } +// SetHeaderOrder is a global wrapper methods which delegated +// to the default client, create a request and SetHeaderOrder for request. +func SetHeaderOrder(keys ...string) *Request { + return defaultClient.R().SetHeaderOrder(keys...) +} + +// SetPseudoHeaderOrder is a global wrapper methods which delegated +// to the default client, create a request and SetPseudoHeaderOrder for request. +func SetPseudoHeaderOrder(keys ...string) *Request { + return defaultClient.R().SetPseudoHeaderOrder(keys...) +} + // SetOutputFile is a global wrapper methods which delegated // to the default client, create a request and SetOutputFile for request. func SetOutputFile(file string) *Request { From 853b883f210c77a4397bfbdb037ecc29be72894f Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jul 2023 17:57:56 +0800 Subject: [PATCH 732/843] update comments --- request.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/request.go b/request.go index d0a295a5..596a609e 100644 --- a/request.go +++ b/request.go @@ -480,7 +480,7 @@ const ( // "cookie", // "user-agent", // "accept-encoding", -// ).Get(url) +// ) func (r *Request) SetHeaderOrder(keys ...string) *Request { if r.Headers == nil { r.Headers = make(http.Header) @@ -498,7 +498,7 @@ func (r *Request) SetHeaderOrder(keys ...string) *Request { // ":authority", // ":path", // ":method", -// ).Get(url) +// ) func (r *Request) SetPseudoHeaderOrder(keys ...string) *Request { if r.Headers == nil { r.Headers = make(http.Header) From 6d1b9ec0e6b3e174139f7938109e44236a054d4c Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Aug 2023 15:25:11 +0800 Subject: [PATCH 733/843] export http2 setting --- internal/http2/frame.go | 9 +++++---- internal/http2/http2.go | 38 +++++-------------------------------- internal/http2/transport.go | 15 ++++++++------- pkg/http2/setting.go | 34 +++++++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 44 deletions(-) create mode 100644 pkg/http2/setting.go diff --git a/internal/http2/frame.go b/internal/http2/frame.go index 34284fe5..5f8afc21 100644 --- a/internal/http2/frame.go +++ b/internal/http2/frame.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/pkg/http2" "golang.org/x/net/http/httpguts" "golang.org/x/net/http2/hpack" "io" @@ -794,7 +795,7 @@ func parseSettingsFrame(_ *frameCache, fh FrameHeader, countError func(string), return nil, ConnectionError(ErrCodeFrameSize) } f := &SettingsFrame{FrameHeader: fh, p: p} - if v, ok := f.Value(SettingInitialWindowSize); ok && v > (1<<31)-1 { + if v, ok := f.Value(http2.SettingInitialWindowSize); ok && v > (1<<31)-1 { countError("frame_settings_window_size_too_big") // Values above the maximum flow control window size of 2^31 - 1 MUST // be treated as a connection error (Section 5.4.1) of type @@ -808,7 +809,7 @@ func (f *SettingsFrame) IsAck() bool { return f.FrameHeader.Flags.Has(FlagSettingsAck) } -func (f *SettingsFrame) Value(id SettingID) (v uint32, ok bool) { +func (f *SettingsFrame) Value(id http2.SettingID) (v uint32, ok bool) { f.checkValid() for i := 0; i < f.NumSettings(); i++ { if s := f.Setting(i); s.ID == id { @@ -823,7 +824,7 @@ func (f *SettingsFrame) Value(id SettingID) (v uint32, ok bool) { func (f *SettingsFrame) Setting(i int) Setting { buf := f.p return Setting{ - ID: SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])), + ID: http2.SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])), Val: binary.BigEndian.Uint32(buf[i*6+2 : i*6+6]), } } @@ -850,7 +851,7 @@ func (f *SettingsFrame) HasDuplicates() bool { } return false } - seen := map[SettingID]bool{} + seen := map[http2.SettingID]bool{} for i := 0; i < num; i++ { id := f.Setting(i).ID if seen[id] { diff --git a/internal/http2/http2.go b/internal/http2/http2.go index 0e1fe2b6..15932b5a 100644 --- a/internal/http2/http2.go +++ b/internal/http2/http2.go @@ -8,6 +8,7 @@ import ( "bufio" "crypto/tls" "fmt" + "github.com/imroc/req/v3/pkg/http2" "golang.org/x/net/http/httpguts" "net/http" "os" @@ -59,7 +60,7 @@ var ( type Setting struct { // ID is which setting is being set. // See https://httpwg.org/specs/rfc7540.html#SettingValues - ID SettingID + ID http2.SettingID // Val is the value. Val uint32 @@ -73,15 +74,15 @@ func (s Setting) String() string { func (s Setting) Valid() error { // Limits and error codes from 6.5.2 Defined SETTINGS Parameters switch s.ID { - case SettingEnablePush: + case http2.SettingEnablePush: if s.Val != 1 && s.Val != 0 { return ConnectionError(ErrCodeProtocol) } - case SettingInitialWindowSize: + case http2.SettingInitialWindowSize: if s.Val > 1<<31-1 { return ConnectionError(ErrCodeFlowControl) } - case SettingMaxFrameSize: + case http2.SettingMaxFrameSize: if s.Val < 16384 || s.Val > 1<<24-1 { return ConnectionError(ErrCodeProtocol) } @@ -89,35 +90,6 @@ func (s Setting) Valid() error { return nil } -// A SettingID is an HTTP/2 setting as defined in -// https://httpwg.org/specs/rfc7540.html#iana-settings -type SettingID uint16 - -const ( - SettingHeaderTableSize SettingID = 0x1 - SettingEnablePush SettingID = 0x2 - SettingMaxConcurrentStreams SettingID = 0x3 - SettingInitialWindowSize SettingID = 0x4 - SettingMaxFrameSize SettingID = 0x5 - SettingMaxHeaderListSize SettingID = 0x6 -) - -var settingName = map[SettingID]string{ - SettingHeaderTableSize: "HEADER_TABLE_SIZE", - SettingEnablePush: "ENABLE_PUSH", - SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", - SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", - SettingMaxFrameSize: "MAX_FRAME_SIZE", - SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", -} - -func (s SettingID) String() string { - if v, ok := settingName[s]; ok { - return v - } - return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) -} - // validWireHeaderFieldName reports whether v is a valid header field // name (key). See httpguts.ValidHeaderName for the base rules. // diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 75b100e3..9186085b 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -21,6 +21,7 @@ import ( "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/netutil" "github.com/imroc/req/v3/internal/transport" + "github.com/imroc/req/v3/pkg/http2" reqtls "github.com/imroc/req/v3/pkg/tls" "io" "io/fs" @@ -678,11 +679,11 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro } initialSettings := []Setting{ - {ID: SettingEnablePush, Val: 0}, - {ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow}, + {ID: http2.SettingEnablePush, Val: 0}, + {ID: http2.SettingInitialWindowSize, Val: transportDefaultStreamFlow}, } if max := t.maxHeaderListSize(); max != 0 { - initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max}) + initialSettings = append(initialSettings, Setting{ID: http2.SettingMaxHeaderListSize, Val: max}) } cc.bw.Write(clientPreface) @@ -2809,14 +2810,14 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { var seenMaxConcurrentStreams bool err := f.ForeachSetting(func(s Setting) error { switch s.ID { - case SettingMaxFrameSize: + case http2.SettingMaxFrameSize: cc.maxFrameSize = s.Val - case SettingMaxConcurrentStreams: + case http2.SettingMaxConcurrentStreams: cc.maxConcurrentStreams = s.Val seenMaxConcurrentStreams = true - case SettingMaxHeaderListSize: + case http2.SettingMaxHeaderListSize: cc.peerMaxHeaderListSize = uint64(s.Val) - case SettingInitialWindowSize: + case http2.SettingInitialWindowSize: // Values above the maximum flow-control // window size of 2^31-1 MUST be treated as a // connection error (Section 5.4.1) of type diff --git a/pkg/http2/setting.go b/pkg/http2/setting.go new file mode 100644 index 00000000..6d0e0aa1 --- /dev/null +++ b/pkg/http2/setting.go @@ -0,0 +1,34 @@ +package http2 + +import ( + "fmt" +) + +// A SettingID is an HTTP/2 setting as defined in +// https://httpwg.org/specs/rfc7540.html#iana-settings +type SettingID uint16 + +const ( + SettingHeaderTableSize SettingID = 0x1 + SettingEnablePush SettingID = 0x2 + SettingMaxConcurrentStreams SettingID = 0x3 + SettingInitialWindowSize SettingID = 0x4 + SettingMaxFrameSize SettingID = 0x5 + SettingMaxHeaderListSize SettingID = 0x6 +) + +var settingName = map[SettingID]string{ + SettingHeaderTableSize: "HEADER_TABLE_SIZE", + SettingEnablePush: "ENABLE_PUSH", + SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", + SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", + SettingMaxFrameSize: "MAX_FRAME_SIZE", + SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", +} + +func (s SettingID) String() string { + if v, ok := settingName[s]; ok { + return v + } + return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) +} From 9de13ca9c9aa232221cf22775bc44900220e6150 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Aug 2023 15:28:56 +0800 Subject: [PATCH 734/843] remove http2 tests --- internal/http2/databuffer_test.go | 155 - internal/http2/errors_test.go | 24 - internal/http2/flow_test.go | 139 - internal/http2/frame_test.go | 1479 ------ internal/http2/gotrack_test.go | 56 - internal/http2/http2_test.go | 119 - internal/http2/pipe_test.go | 141 - internal/http2/server_test.go | 5527 ---------------------- internal/http2/transport_go117_test.go | 176 - internal/http2/transport_test.go | 6008 ------------------------ 10 files changed, 13824 deletions(-) delete mode 100644 internal/http2/databuffer_test.go delete mode 100644 internal/http2/errors_test.go delete mode 100644 internal/http2/flow_test.go delete mode 100644 internal/http2/frame_test.go delete mode 100644 internal/http2/gotrack_test.go delete mode 100644 internal/http2/http2_test.go delete mode 100644 internal/http2/pipe_test.go delete mode 100644 internal/http2/server_test.go delete mode 100644 internal/http2/transport_go117_test.go delete mode 100644 internal/http2/transport_test.go diff --git a/internal/http2/databuffer_test.go b/internal/http2/databuffer_test.go deleted file mode 100644 index 32cd5f38..00000000 --- a/internal/http2/databuffer_test.go +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import ( - "bytes" - "fmt" - "reflect" - "testing" -) - -func fmtDataChunk(chunk []byte) string { - out := "" - var last byte - var count int - for _, c := range chunk { - if c != last { - if count > 0 { - out += fmt.Sprintf(" x %d ", count) - count = 0 - } - out += string([]byte{c}) - last = c - } - count++ - } - if count > 0 { - out += fmt.Sprintf(" x %d", count) - } - return out -} - -func fmtDataChunks(chunks [][]byte) string { - var out string - for _, chunk := range chunks { - out += fmt.Sprintf("{%q}", fmtDataChunk(chunk)) - } - return out -} - -func testDataBuffer(t *testing.T, wantBytes []byte, setup func(t *testing.T) *dataBuffer) { - // Run setup, then read the remaining bytes from the dataBuffer and check - // that they match wantBytes. We use different read sizes to check corner - // cases in Read. - for _, readSize := range []int{1, 2, 1 * 1024, 32 * 1024} { - t.Run(fmt.Sprintf("ReadSize=%d", readSize), func(t *testing.T) { - b := setup(t) - buf := make([]byte, readSize) - var gotRead bytes.Buffer - for { - n, err := b.Read(buf) - gotRead.Write(buf[:n]) - if err == errReadEmpty { - break - } - if err != nil { - t.Fatalf("error after %v bytes: %v", gotRead.Len(), err) - } - } - if got, want := gotRead.Bytes(), wantBytes; !bytes.Equal(got, want) { - t.Errorf("FinalRead=%q, want %q", fmtDataChunk(got), fmtDataChunk(want)) - } - }) - } -} - -func TestDataBufferAllocation(t *testing.T) { - writes := [][]byte{ - bytes.Repeat([]byte("a"), 1*1024-1), - []byte("a"), - bytes.Repeat([]byte("b"), 4*1024-1), - []byte("b"), - bytes.Repeat([]byte("c"), 8*1024-1), - []byte("c"), - bytes.Repeat([]byte("d"), 16*1024-1), - []byte("d"), - bytes.Repeat([]byte("e"), 32*1024), - } - var wantRead bytes.Buffer - for _, p := range writes { - wantRead.Write(p) - } - - testDataBuffer(t, wantRead.Bytes(), func(t *testing.T) *dataBuffer { - b := &dataBuffer{} - for _, p := range writes { - if n, err := b.Write(p); n != len(p) || err != nil { - t.Fatalf("Write(%q x %d)=%v,%v want %v,nil", p[:1], len(p), n, err, len(p)) - } - } - want := [][]byte{ - bytes.Repeat([]byte("a"), 1*1024), - bytes.Repeat([]byte("b"), 4*1024), - bytes.Repeat([]byte("c"), 8*1024), - bytes.Repeat([]byte("d"), 16*1024), - bytes.Repeat([]byte("e"), 16*1024), - bytes.Repeat([]byte("e"), 16*1024), - } - if !reflect.DeepEqual(b.chunks, want) { - t.Errorf("dataBuffer.chunks\ngot: %s\nwant: %s", fmtDataChunks(b.chunks), fmtDataChunks(want)) - } - return b - }) -} - -func TestDataBufferAllocationWithExpected(t *testing.T) { - writes := [][]byte{ - bytes.Repeat([]byte("a"), 1*1024), // allocates 16KB - bytes.Repeat([]byte("b"), 14*1024), - bytes.Repeat([]byte("c"), 15*1024), // allocates 16KB more - bytes.Repeat([]byte("d"), 2*1024), - bytes.Repeat([]byte("e"), 1*1024), // overflows 32KB expectation, allocates just 1KB - } - var wantRead bytes.Buffer - for _, p := range writes { - wantRead.Write(p) - } - - testDataBuffer(t, wantRead.Bytes(), func(t *testing.T) *dataBuffer { - b := &dataBuffer{expected: 32 * 1024} - for _, p := range writes { - if n, err := b.Write(p); n != len(p) || err != nil { - t.Fatalf("Write(%q x %d)=%v,%v want %v,nil", p[:1], len(p), n, err, len(p)) - } - } - want := [][]byte{ - append(bytes.Repeat([]byte("a"), 1*1024), append(bytes.Repeat([]byte("b"), 14*1024), bytes.Repeat([]byte("c"), 1*1024)...)...), - append(bytes.Repeat([]byte("c"), 14*1024), bytes.Repeat([]byte("d"), 2*1024)...), - bytes.Repeat([]byte("e"), 1*1024), - } - if !reflect.DeepEqual(b.chunks, want) { - t.Errorf("dataBuffer.chunks\ngot: %s\nwant: %s", fmtDataChunks(b.chunks), fmtDataChunks(want)) - } - return b - }) -} - -func TestDataBufferWriteAfterPartialRead(t *testing.T) { - testDataBuffer(t, []byte("cdxyz"), func(t *testing.T) *dataBuffer { - b := &dataBuffer{} - if n, err := b.Write([]byte("abcd")); n != 4 || err != nil { - t.Fatalf("Write(\"abcd\")=%v,%v want 4,nil", n, err) - } - p := make([]byte, 2) - if n, err := b.Read(p); n != 2 || err != nil || !bytes.Equal(p, []byte("ab")) { - t.Fatalf("Read()=%q,%v,%v want \"ab\",2,nil", p, n, err) - } - if n, err := b.Write([]byte("xyz")); n != 3 || err != nil { - t.Fatalf("Write(\"xyz\")=%v,%v want 3,nil", n, err) - } - return b - }) -} diff --git a/internal/http2/errors_test.go b/internal/http2/errors_test.go deleted file mode 100644 index da5c58c3..00000000 --- a/internal/http2/errors_test.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import "testing" - -func TestErrCodeString(t *testing.T) { - tests := []struct { - err ErrCode - want string - }{ - {ErrCodeProtocol, "PROTOCOL_ERROR"}, - {0xd, "HTTP_1_1_REQUIRED"}, - {0xf, "unknown error code 0xf"}, - } - for i, tt := range tests { - got := tt.err.String() - if got != tt.want { - t.Errorf("%d. Error = %q; want %q", i, got, tt.want) - } - } -} diff --git a/internal/http2/flow_test.go b/internal/http2/flow_test.go deleted file mode 100644 index cae4f38c..00000000 --- a/internal/http2/flow_test.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import "testing" - -func TestInFlowTake(t *testing.T) { - var f inflow - f.init(100) - if !f.take(40) { - t.Fatalf("f.take(40) from 100: got false, want true") - } - if !f.take(40) { - t.Fatalf("f.take(40) from 60: got false, want true") - } - if f.take(40) { - t.Fatalf("f.take(40) from 20: got true, want false") - } - if !f.take(20) { - t.Fatalf("f.take(20) from 20: got false, want true") - } -} - -func TestInflowAddSmall(t *testing.T) { - var f inflow - f.init(0) - // Adding even a small amount when there is no flow causes an immediate send. - if got, want := f.add(1), int32(1); got != want { - t.Fatalf("f.add(1) to 1 = %v, want %v", got, want) - } -} - -func TestInflowAdd(t *testing.T) { - var f inflow - f.init(10 * inflowMinRefresh) - if got, want := f.add(inflowMinRefresh-1), int32(0); got != want { - t.Fatalf("f.add(minRefresh - 1) = %v, want %v", got, want) - } - if got, want := f.add(1), int32(inflowMinRefresh); got != want { - t.Fatalf("f.add(minRefresh) = %v, want %v", got, want) - } -} - -func TestTakeInflows(t *testing.T) { - var a, b inflow - a.init(10) - b.init(20) - if !takeInflows(&a, &b, 5) { - t.Fatalf("takeInflows(a, b, 5) from 10, 20: got false, want true") - } - if takeInflows(&a, &b, 6) { - t.Fatalf("takeInflows(a, b, 6) from 5, 15: got true, want false") - } - if !takeInflows(&a, &b, 5) { - t.Fatalf("takeInflows(a, b, 5) from 5, 15: got false, want true") - } -} - -func TestOutFlow(t *testing.T) { - var st outflow - var conn outflow - st.add(3) - conn.add(2) - - if got, want := st.available(), int32(3); got != want { - t.Errorf("available = %d; want %d", got, want) - } - st.setConnFlow(&conn) - if got, want := st.available(), int32(2); got != want { - t.Errorf("after parent setup, available = %d; want %d", got, want) - } - - st.take(2) - if got, want := conn.available(), int32(0); got != want { - t.Errorf("after taking 2, conn = %d; want %d", got, want) - } - if got, want := st.available(), int32(0); got != want { - t.Errorf("after taking 2, stream = %d; want %d", got, want) - } -} - -func TestOutFlowAdd(t *testing.T) { - var f outflow - if !f.add(1) { - t.Fatal("failed to add 1") - } - if !f.add(-1) { - t.Fatal("failed to add -1") - } - if got, want := f.available(), int32(0); got != want { - t.Fatalf("size = %d; want %d", got, want) - } - if !f.add(1<<31 - 1) { - t.Fatal("failed to add 2^31-1") - } - if got, want := f.available(), int32(1<<31-1); got != want { - t.Fatalf("size = %d; want %d", got, want) - } - if f.add(1) { - t.Fatal("adding 1 to max shouldn't be allowed") - } -} - -func TestOutFlowAddOverflow(t *testing.T) { - var f outflow - if !f.add(0) { - t.Fatal("failed to add 0") - } - if !f.add(-1) { - t.Fatal("failed to add -1") - } - if !f.add(0) { - t.Fatal("failed to add 0") - } - if !f.add(1) { - t.Fatal("failed to add 1") - } - if !f.add(1) { - t.Fatal("failed to add 1") - } - if !f.add(0) { - t.Fatal("failed to add 0") - } - if !f.add(-3) { - t.Fatal("failed to add -3") - } - if got, want := f.available(), int32(-2); got != want { - t.Fatalf("size = %d; want %d", got, want) - } - if !f.add(1<<31 - 1) { - t.Fatal("failed to add 2^31-1") - } - if got, want := f.available(), int32(1+-3+(1<<31-1)); got != want { - t.Fatalf("size = %d; want %d", got, want) - } - -} diff --git a/internal/http2/frame_test.go b/internal/http2/frame_test.go deleted file mode 100644 index ed5ec9c7..00000000 --- a/internal/http2/frame_test.go +++ /dev/null @@ -1,1479 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import ( - "bytes" - "fmt" - "github.com/imroc/req/v3/internal/tests" - "io" - "reflect" - "strings" - "testing" - "unsafe" - - "golang.org/x/net/http2/hpack" -) - -func testFramer() (*Framer, *bytes.Buffer) { - buf := new(bytes.Buffer) - return NewFramer(buf, buf), buf -} - -func TestFrameSizes(t *testing.T) { - // Catch people rearranging the FrameHeader fields. - if got, want := int(unsafe.Sizeof(FrameHeader{})), 12; got != want { - t.Errorf("FrameHeader size = %d; want %d", got, want) - } -} - -func TestFrameTypeString(t *testing.T) { - tests := []struct { - ft FrameType - want string - }{ - {FrameData, "DATA"}, - {FramePing, "PING"}, - {FrameGoAway, "GOAWAY"}, - {0xf, "UNKNOWN_FRAME_TYPE_15"}, - } - - for i, tt := range tests { - got := tt.ft.String() - if got != tt.want { - t.Errorf("%d. String(FrameType %d) = %q; want %q", i, int(tt.ft), got, tt.want) - } - } -} - -func TestWriteRST(t *testing.T) { - fr, buf := testFramer() - var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4 - var errCode uint32 = 7<<24 + 6<<16 + 5<<8 + 4 - fr.WriteRSTStream(streamID, ErrCode(errCode)) - const wantEnc = "\x00\x00\x04\x03\x00\x01\x02\x03\x04\x07\x06\x05\x04" - if buf.String() != wantEnc { - t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) - } - f, err := fr.ReadFrame() - if err != nil { - t.Fatal(err) - } - want := &RSTStreamFrame{ - FrameHeader: FrameHeader{ - valid: true, - Type: 0x3, - Flags: 0x0, - Length: 0x4, - StreamID: 0x1020304, - }, - ErrCode: 0x7060504, - } - if !reflect.DeepEqual(f, want) { - t.Errorf("parsed back %#v; want %#v", f, want) - } -} - -func TestWriteData(t *testing.T) { - fr, buf := testFramer() - var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4 - data := []byte("ABC") - fr.WriteData(streamID, true, data) - const wantEnc = "\x00\x00\x03\x00\x01\x01\x02\x03\x04ABC" - if buf.String() != wantEnc { - t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) - } - f, err := fr.ReadFrame() - if err != nil { - t.Fatal(err) - } - df, ok := f.(*DataFrame) - if !ok { - t.Fatalf("got %T; want *DataFrame", f) - } - if !bytes.Equal(df.Data(), data) { - t.Errorf("got %q; want %q", df.Data(), data) - } - if f.Header().Flags&1 == 0 { - t.Errorf("didn't see END_STREAM flag") - } -} - -func TestWriteDataPadded(t *testing.T) { - tests := [...]struct { - streamID uint32 - endStream bool - data []byte - pad []byte - wantHeader FrameHeader - }{ - // Unpadded: - 0: { - streamID: 1, - endStream: true, - data: []byte("foo"), - pad: nil, - wantHeader: FrameHeader{ - Type: FrameData, - Flags: FlagDataEndStream, - Length: 3, - StreamID: 1, - }, - }, - - // Padded bit set, but no padding: - 1: { - streamID: 1, - endStream: true, - data: []byte("foo"), - pad: []byte{}, - wantHeader: FrameHeader{ - Type: FrameData, - Flags: FlagDataEndStream | FlagDataPadded, - Length: 4, - StreamID: 1, - }, - }, - - // Padded bit set, with padding: - 2: { - streamID: 1, - endStream: false, - data: []byte("foo"), - pad: []byte{0, 0, 0}, - wantHeader: FrameHeader{ - Type: FrameData, - Flags: FlagDataPadded, - Length: 7, - StreamID: 1, - }, - }, - } - for i, tt := range tests { - fr, _ := testFramer() - fr.WriteDataPadded(tt.streamID, tt.endStream, tt.data, tt.pad) - f, err := fr.ReadFrame() - if err != nil { - t.Errorf("%d. ReadFrame: %v", i, err) - continue - } - got := f.Header() - tt.wantHeader.valid = true - if !got.Equal(tt.wantHeader) { - t.Errorf("%d. read %+v; want %+v", i, got, tt.wantHeader) - continue - } - df := f.(*DataFrame) - if !bytes.Equal(df.Data(), tt.data) { - t.Errorf("%d. got %q; want %q", i, df.Data(), tt.data) - } - } -} - -func (fh FrameHeader) Equal(b FrameHeader) bool { - return fh.valid == b.valid && - fh.Type == b.Type && - fh.Flags == b.Flags && - fh.Length == b.Length && - fh.StreamID == b.StreamID -} - -func TestWriteHeaders(t *testing.T) { - tests := []struct { - name string - p HeadersFrameParam - wantEnc string - wantFrame *HeadersFrame - }{ - { - "basic", - HeadersFrameParam{ - StreamID: 42, - BlockFragment: []byte("abc"), - Priority: PriorityParam{}, - }, - "\x00\x00\x03\x01\x00\x00\x00\x00*abc", - &HeadersFrame{ - FrameHeader: FrameHeader{ - valid: true, - StreamID: 42, - Type: FrameHeaders, - Length: uint32(len("abc")), - }, - Priority: PriorityParam{}, - headerFragBuf: []byte("abc"), - }, - }, - { - "basic + end flags", - HeadersFrameParam{ - StreamID: 42, - BlockFragment: []byte("abc"), - EndStream: true, - EndHeaders: true, - Priority: PriorityParam{}, - }, - "\x00\x00\x03\x01\x05\x00\x00\x00*abc", - &HeadersFrame{ - FrameHeader: FrameHeader{ - valid: true, - StreamID: 42, - Type: FrameHeaders, - Flags: FlagHeadersEndStream | FlagHeadersEndHeaders, - Length: uint32(len("abc")), - }, - Priority: PriorityParam{}, - headerFragBuf: []byte("abc"), - }, - }, - { - "with padding", - HeadersFrameParam{ - StreamID: 42, - BlockFragment: []byte("abc"), - EndStream: true, - EndHeaders: true, - PadLength: 5, - Priority: PriorityParam{}, - }, - "\x00\x00\t\x01\r\x00\x00\x00*\x05abc\x00\x00\x00\x00\x00", - &HeadersFrame{ - FrameHeader: FrameHeader{ - valid: true, - StreamID: 42, - Type: FrameHeaders, - Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded, - Length: uint32(1 + len("abc") + 5), // pad length + contents + padding - }, - Priority: PriorityParam{}, - headerFragBuf: []byte("abc"), - }, - }, - { - "with priority", - HeadersFrameParam{ - StreamID: 42, - BlockFragment: []byte("abc"), - EndStream: true, - EndHeaders: true, - PadLength: 2, - Priority: PriorityParam{ - StreamDep: 15, - Exclusive: true, - Weight: 127, - }, - }, - "\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x0f\u007fabc\x00\x00", - &HeadersFrame{ - FrameHeader: FrameHeader{ - valid: true, - StreamID: 42, - Type: FrameHeaders, - Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority, - Length: uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding - }, - Priority: PriorityParam{ - StreamDep: 15, - Exclusive: true, - Weight: 127, - }, - headerFragBuf: []byte("abc"), - }, - }, - { - "with priority stream dep zero", // golang.org/issue/15444 - HeadersFrameParam{ - StreamID: 42, - BlockFragment: []byte("abc"), - EndStream: true, - EndHeaders: true, - PadLength: 2, - Priority: PriorityParam{ - StreamDep: 0, - Exclusive: true, - Weight: 127, - }, - }, - "\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x00\u007fabc\x00\x00", - &HeadersFrame{ - FrameHeader: FrameHeader{ - valid: true, - StreamID: 42, - Type: FrameHeaders, - Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority, - Length: uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding - }, - Priority: PriorityParam{ - StreamDep: 0, - Exclusive: true, - Weight: 127, - }, - headerFragBuf: []byte("abc"), - }, - }, - { - "zero length", - HeadersFrameParam{ - StreamID: 42, - Priority: PriorityParam{}, - }, - "\x00\x00\x00\x01\x00\x00\x00\x00*", - &HeadersFrame{ - FrameHeader: FrameHeader{ - valid: true, - StreamID: 42, - Type: FrameHeaders, - Length: 0, - }, - Priority: PriorityParam{}, - }, - }, - } - for _, tt := range tests { - fr, buf := testFramer() - if err := fr.WriteHeaders(tt.p); err != nil { - t.Errorf("test %q: %v", tt.name, err) - continue - } - if buf.String() != tt.wantEnc { - t.Errorf("test %q: encoded %q; want %q", tt.name, buf.Bytes(), tt.wantEnc) - } - f, err := fr.ReadFrame() - if err != nil { - t.Errorf("test %q: failed to read the frame back: %v", tt.name, err) - continue - } - if !reflect.DeepEqual(f, tt.wantFrame) { - t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame) - } - } -} - -func TestWriteInvalidStreamDep(t *testing.T) { - fr, _ := testFramer() - err := fr.WriteHeaders(HeadersFrameParam{ - StreamID: 42, - Priority: PriorityParam{ - StreamDep: 1 << 31, - }, - }) - if err != errDepStreamID { - t.Errorf("header error = %v; want %q", err, errDepStreamID) - } - - err = fr.WritePriority(2, PriorityParam{StreamDep: 1 << 31}) - if err != errDepStreamID { - t.Errorf("priority error = %v; want %q", err, errDepStreamID) - } -} - -func TestWriteContinuation(t *testing.T) { - const streamID = 42 - tests := []struct { - name string - end bool - frag []byte - - wantFrame *ContinuationFrame - }{ - { - "not end", - false, - []byte("abc"), - &ContinuationFrame{ - FrameHeader: FrameHeader{ - valid: true, - StreamID: streamID, - Type: FrameContinuation, - Length: uint32(len("abc")), - }, - headerFragBuf: []byte("abc"), - }, - }, - { - "end", - true, - []byte("def"), - &ContinuationFrame{ - FrameHeader: FrameHeader{ - valid: true, - StreamID: streamID, - Type: FrameContinuation, - Flags: FlagContinuationEndHeaders, - Length: uint32(len("def")), - }, - headerFragBuf: []byte("def"), - }, - }, - } - for _, tt := range tests { - fr, _ := testFramer() - if err := fr.WriteContinuation(streamID, tt.end, tt.frag); err != nil { - t.Errorf("test %q: %v", tt.name, err) - continue - } - fr.AllowIllegalReads = true - f, err := fr.ReadFrame() - if err != nil { - t.Errorf("test %q: failed to read the frame back: %v", tt.name, err) - continue - } - if !reflect.DeepEqual(f, tt.wantFrame) { - t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame) - } - } -} - -func TestWritePriority(t *testing.T) { - const streamID = 42 - tests := []struct { - name string - priority PriorityParam - wantFrame *PriorityFrame - }{ - { - "not exclusive", - PriorityParam{ - StreamDep: 2, - Exclusive: false, - Weight: 127, - }, - &PriorityFrame{ - FrameHeader{ - valid: true, - StreamID: streamID, - Type: FramePriority, - Length: 5, - }, - PriorityParam{ - StreamDep: 2, - Exclusive: false, - Weight: 127, - }, - }, - }, - - { - "exclusive", - PriorityParam{ - StreamDep: 3, - Exclusive: true, - Weight: 77, - }, - &PriorityFrame{ - FrameHeader{ - valid: true, - StreamID: streamID, - Type: FramePriority, - Length: 5, - }, - PriorityParam{ - StreamDep: 3, - Exclusive: true, - Weight: 77, - }, - }, - }, - } - for _, tt := range tests { - fr, _ := testFramer() - if err := fr.WritePriority(streamID, tt.priority); err != nil { - t.Errorf("test %q: %v", tt.name, err) - continue - } - f, err := fr.ReadFrame() - if err != nil { - t.Errorf("test %q: failed to read the frame back: %v", tt.name, err) - continue - } - if !reflect.DeepEqual(f, tt.wantFrame) { - t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame) - } - } -} - -func TestWriteSettings(t *testing.T) { - fr, buf := testFramer() - settings := []Setting{{1, 2}, {3, 4}} - fr.WriteSettings(settings...) - const wantEnc = "\x00\x00\f\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x03\x00\x00\x00\x04" - if buf.String() != wantEnc { - t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) - } - f, err := fr.ReadFrame() - if err != nil { - t.Fatal(err) - } - sf, ok := f.(*SettingsFrame) - if !ok { - t.Fatalf("Got a %T; want a SettingsFrame", f) - } - var got []Setting - sf.ForeachSetting(func(s Setting) error { - got = append(got, s) - valBack, ok := sf.Value(s.ID) - if !ok || valBack != s.Val { - t.Errorf("Value(%d) = %v, %v; want %v, true", s.ID, valBack, ok, s.Val) - } - return nil - }) - if !reflect.DeepEqual(settings, got) { - t.Errorf("Read settings %+v != written settings %+v", got, settings) - } -} - -func TestWriteSettingsAck(t *testing.T) { - fr, buf := testFramer() - fr.WriteSettingsAck() - const wantEnc = "\x00\x00\x00\x04\x01\x00\x00\x00\x00" - if buf.String() != wantEnc { - t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) - } -} - -func TestWriteWindowUpdate(t *testing.T) { - fr, buf := testFramer() - const streamID = 1<<24 + 2<<16 + 3<<8 + 4 - const incr = 7<<24 + 6<<16 + 5<<8 + 4 - if err := fr.WriteWindowUpdate(streamID, incr); err != nil { - t.Fatal(err) - } - const wantEnc = "\x00\x00\x04\x08\x00\x01\x02\x03\x04\x07\x06\x05\x04" - if buf.String() != wantEnc { - t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) - } - f, err := fr.ReadFrame() - if err != nil { - t.Fatal(err) - } - want := &WindowUpdateFrame{ - FrameHeader: FrameHeader{ - valid: true, - Type: 0x8, - Flags: 0x0, - Length: 0x4, - StreamID: 0x1020304, - }, - Increment: 0x7060504, - } - if !reflect.DeepEqual(f, want) { - t.Errorf("parsed back %#v; want %#v", f, want) - } -} - -func TestWritePing(t *testing.T) { testWritePing(t, false) } -func TestWritePingAck(t *testing.T) { testWritePing(t, true) } - -func testWritePing(t *testing.T, ack bool) { - fr, buf := testFramer() - if err := fr.WritePing(ack, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { - t.Fatal(err) - } - var wantFlags Flags - if ack { - wantFlags = FlagPingAck - } - var wantEnc = "\x00\x00\x08\x06" + string(wantFlags) + "\x00\x00\x00\x00" + "\x01\x02\x03\x04\x05\x06\x07\x08" - if buf.String() != wantEnc { - t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) - } - - f, err := fr.ReadFrame() - if err != nil { - t.Fatal(err) - } - want := &PingFrame{ - FrameHeader: FrameHeader{ - valid: true, - Type: 0x6, - Flags: wantFlags, - Length: 0x8, - StreamID: 0, - }, - Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, - } - if !reflect.DeepEqual(f, want) { - t.Errorf("parsed back %#v; want %#v", f, want) - } -} - -func TestReadFrameHeader(t *testing.T) { - tests := []struct { - in string - want FrameHeader - }{ - {in: "\x00\x00\x00" + "\x00" + "\x00" + "\x00\x00\x00\x00", want: FrameHeader{}}, - {in: "\x01\x02\x03" + "\x04" + "\x05" + "\x06\x07\x08\x09", want: FrameHeader{ - Length: 66051, Type: 4, Flags: 5, StreamID: 101124105, - }}, - // Ignore high bit: - {in: "\xff\xff\xff" + "\xff" + "\xff" + "\xff\xff\xff\xff", want: FrameHeader{ - Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}}, - {in: "\xff\xff\xff" + "\xff" + "\xff" + "\x7f\xff\xff\xff", want: FrameHeader{ - Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}}, - } - for i, tt := range tests { - got, err := readFrameHeader(make([]byte, 9), strings.NewReader(tt.in)) - if err != nil { - t.Errorf("%d. readFrameHeader(%q) = %v", i, tt.in, err) - continue - } - tt.want.valid = true - if !got.Equal(tt.want) { - t.Errorf("%d. readFrameHeader(%q) = %+v; want %+v", i, tt.in, got, tt.want) - } - } -} - -func TestReadWriteFrameHeader(t *testing.T) { - tests := []struct { - len uint32 - typ FrameType - flags Flags - streamID uint32 - }{ - {len: 0, typ: 255, flags: 1, streamID: 0}, - {len: 0, typ: 255, flags: 1, streamID: 1}, - {len: 0, typ: 255, flags: 1, streamID: 255}, - {len: 0, typ: 255, flags: 1, streamID: 256}, - {len: 0, typ: 255, flags: 1, streamID: 65535}, - {len: 0, typ: 255, flags: 1, streamID: 65536}, - - {len: 0, typ: 1, flags: 255, streamID: 1}, - {len: 255, typ: 1, flags: 255, streamID: 1}, - {len: 256, typ: 1, flags: 255, streamID: 1}, - {len: 65535, typ: 1, flags: 255, streamID: 1}, - {len: 65536, typ: 1, flags: 255, streamID: 1}, - {len: 16777215, typ: 1, flags: 255, streamID: 1}, - } - for _, tt := range tests { - fr, buf := testFramer() - fr.startWrite(tt.typ, tt.flags, tt.streamID) - fr.writeBytes(make([]byte, tt.len)) - fr.endWrite() - fh, err := ReadFrameHeader(buf) - if err != nil { - t.Errorf("ReadFrameHeader(%+v) = %v", tt, err) - continue - } - if fh.Type != tt.typ || fh.Flags != tt.flags || fh.Length != tt.len || fh.StreamID != tt.streamID { - t.Errorf("ReadFrameHeader(%+v) = %+v; mismatch", tt, fh) - } - } - -} - -func TestWriteTooLargeFrame(t *testing.T) { - fr, _ := testFramer() - fr.startWrite(0, 1, 1) - fr.writeBytes(make([]byte, 1<<24)) - err := fr.endWrite() - if err != errFrameTooLarge { - t.Errorf("endWrite = %v; want errFrameTooLarge", err) - } -} - -func TestWriteGoAway(t *testing.T) { - const debug = "foo" - fr, buf := testFramer() - if err := fr.WriteGoAway(0x01020304, 0x05060708, []byte(debug)); err != nil { - t.Fatal(err) - } - const wantEnc = "\x00\x00\v\a\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08" + debug - if buf.String() != wantEnc { - t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) - } - f, err := fr.ReadFrame() - if err != nil { - t.Fatal(err) - } - want := &GoAwayFrame{ - FrameHeader: FrameHeader{ - valid: true, - Type: 0x7, - Flags: 0, - Length: uint32(4 + 4 + len(debug)), - StreamID: 0, - }, - LastStreamID: 0x01020304, - ErrCode: 0x05060708, - debugData: []byte(debug), - } - if !reflect.DeepEqual(f, want) { - t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want) - } - if got := string(f.(*GoAwayFrame).DebugData()); got != debug { - t.Errorf("debug data = %q; want %q", got, debug) - } -} - -func TestWritePushPromise(t *testing.T) { - pp := PushPromiseParam{ - StreamID: 42, - PromiseID: 42, - BlockFragment: []byte("abc"), - } - fr, buf := testFramer() - if err := fr.WritePushPromise(pp); err != nil { - t.Fatal(err) - } - const wantEnc = "\x00\x00\x07\x05\x00\x00\x00\x00*\x00\x00\x00*abc" - if buf.String() != wantEnc { - t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) - } - f, err := fr.ReadFrame() - if err != nil { - t.Fatal(err) - } - _, ok := f.(*PushPromiseFrame) - if !ok { - t.Fatalf("got %T; want *PushPromiseFrame", f) - } - want := &PushPromiseFrame{ - FrameHeader: FrameHeader{ - valid: true, - Type: 0x5, - Flags: 0x0, - Length: 0x7, - StreamID: 42, - }, - PromiseID: 42, - headerFragBuf: []byte("abc"), - } - if !reflect.DeepEqual(f, want) { - t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want) - } -} - -// test checkFrameOrder and that HEADERS and CONTINUATION frames can't be intermingled. -func TestReadFrameOrder(t *testing.T) { - head := func(f *Framer, id uint32, end bool) { - f.WriteHeaders(HeadersFrameParam{ - StreamID: id, - BlockFragment: []byte("foo"), // unused, but non-empty - EndHeaders: end, - }) - } - cont := func(f *Framer, id uint32, end bool) { - f.WriteContinuation(id, end, []byte("foo")) - } - - tests := [...]struct { - name string - w func(*Framer) - atLeast int - wantErr string - }{ - 0: { - w: func(f *Framer) { - head(f, 1, true) - }, - }, - 1: { - w: func(f *Framer) { - head(f, 1, true) - head(f, 2, true) - }, - }, - 2: { - wantErr: "got HEADERS for stream 2; expected CONTINUATION following HEADERS for stream 1", - w: func(f *Framer) { - head(f, 1, false) - head(f, 2, true) - }, - }, - 3: { - wantErr: "got DATA for stream 1; expected CONTINUATION following HEADERS for stream 1", - w: func(f *Framer) { - head(f, 1, false) - }, - }, - 4: { - w: func(f *Framer) { - head(f, 1, false) - cont(f, 1, true) - head(f, 2, true) - }, - }, - 5: { - wantErr: "got CONTINUATION for stream 2; expected stream 1", - w: func(f *Framer) { - head(f, 1, false) - cont(f, 2, true) - head(f, 2, true) - }, - }, - 6: { - wantErr: "unexpected CONTINUATION for stream 1", - w: func(f *Framer) { - cont(f, 1, true) - }, - }, - 7: { - wantErr: "unexpected CONTINUATION for stream 1", - w: func(f *Framer) { - cont(f, 1, false) - }, - }, - 8: { - wantErr: "HEADERS frame with stream ID 0", - w: func(f *Framer) { - head(f, 0, true) - }, - }, - 9: { - wantErr: "CONTINUATION frame with stream ID 0", - w: func(f *Framer) { - cont(f, 0, true) - }, - }, - 10: { - wantErr: "unexpected CONTINUATION for stream 1", - atLeast: 5, - w: func(f *Framer) { - head(f, 1, false) - cont(f, 1, false) - cont(f, 1, false) - cont(f, 1, false) - cont(f, 1, true) - cont(f, 1, false) - }, - }, - } - for i, tt := range tests { - buf := new(bytes.Buffer) - f := NewFramer(buf, buf) - f.AllowIllegalWrites = true - tt.w(f) - f.WriteData(1, true, nil) // to test transition away from last step - - var err error - n := 0 - var log bytes.Buffer - for { - var got Frame - got, err = f.ReadFrame() - fmt.Fprintf(&log, " read %v, %v\n", got, err) - if err != nil { - break - } - n++ - } - if err == io.EOF { - err = nil - } - ok := tt.wantErr == "" - if ok && err != nil { - t.Errorf("%d. after %d good frames, ReadFrame = %v; want success\n%s", i, n, err, log.Bytes()) - continue - } - if !ok && err != ConnectionError(ErrCodeProtocol) { - t.Errorf("%d. after %d good frames, ReadFrame = %v; want ConnectionError(ErrCodeProtocol)\n%s", i, n, err, log.Bytes()) - continue - } - if !((f.errDetail == nil && tt.wantErr == "") || (fmt.Sprint(f.errDetail) == tt.wantErr)) { - t.Errorf("%d. framer eror = %q; want %q\n%s", i, f.errDetail, tt.wantErr, log.Bytes()) - } - if n < tt.atLeast { - t.Errorf("%d. framer only read %d frames; want at least %d\n%s", i, n, tt.atLeast, log.Bytes()) - } - } -} - -type hpackEncoder struct { - enc *hpack.Encoder - buf bytes.Buffer -} - -func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte { - if len(headers)%2 == 1 { - panic("odd number of kv args") - } - he.buf.Reset() - if he.enc == nil { - he.enc = hpack.NewEncoder(&he.buf) - } - for len(headers) > 0 { - k, v := headers[0], headers[1] - err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v}) - if err != nil { - t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) - } - headers = headers[2:] - } - return he.buf.Bytes() -} - -func TestMetaFrameHeader(t *testing.T) { - write := func(f *Framer, frags ...[]byte) { - for i, frag := range frags { - end := (i == len(frags)-1) - if i == 0 { - f.WriteHeaders(HeadersFrameParam{ - StreamID: 1, - BlockFragment: frag, - EndHeaders: end, - }) - } else { - f.WriteContinuation(1, end, frag) - } - } - } - - want := func(flags Flags, length uint32, pairs ...string) *MetaHeadersFrame { - mh := &MetaHeadersFrame{ - HeadersFrame: &HeadersFrame{ - FrameHeader: FrameHeader{ - Type: FrameHeaders, - Flags: flags, - Length: length, - StreamID: 1, - }, - }, - Fields: []hpack.HeaderField(nil), - } - for len(pairs) > 0 { - mh.Fields = append(mh.Fields, hpack.HeaderField{ - Name: pairs[0], - Value: pairs[1], - }) - pairs = pairs[2:] - } - return mh - } - truncated := func(mh *MetaHeadersFrame) *MetaHeadersFrame { - mh.Truncated = true - return mh - } - - const noFlags Flags = 0 - - oneKBString := strings.Repeat("a", 1<<10) - - tests := [...]struct { - name string - w func(*Framer) - want interface{} // *MetaHeaderFrame or error - wantErrReason string - maxHeaderListSize uint32 - }{ - 0: { - name: "single_headers", - w: func(f *Framer) { - var he hpackEncoder - all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/") - write(f, all) - }, - want: want(FlagHeadersEndHeaders, 2, ":method", "GET", ":path", "/"), - }, - 1: { - name: "with_continuation", - w: func(f *Framer) { - var he hpackEncoder - all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar") - write(f, all[:1], all[1:]) - }, - want: want(noFlags, 1, ":method", "GET", ":path", "/", "foo", "bar"), - }, - 2: { - name: "with_two_continuation", - w: func(f *Framer) { - var he hpackEncoder - all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar") - write(f, all[:2], all[2:4], all[4:]) - }, - want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", "bar"), - }, - 3: { - name: "big_string_okay", - w: func(f *Framer) { - var he hpackEncoder - all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString) - write(f, all[:2], all[2:]) - }, - want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", oneKBString), - }, - 4: { - name: "big_string_error", - w: func(f *Framer) { - var he hpackEncoder - all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString) - write(f, all[:2], all[2:]) - }, - maxHeaderListSize: (1 << 10) / 2, - want: ConnectionError(ErrCodeCompression), - }, - 5: { - name: "max_header_list_truncated", - w: func(f *Framer) { - var he hpackEncoder - var pairs = []string{":method", "GET", ":path", "/"} - for i := 0; i < 100; i++ { - pairs = append(pairs, "foo", "bar") - } - all := he.encodeHeaderRaw(t, pairs...) - write(f, all[:2], all[2:]) - }, - maxHeaderListSize: (1 << 10) / 2, - want: truncated(want(noFlags, 2, - ":method", "GET", - ":path", "/", - "foo", "bar", - "foo", "bar", - "foo", "bar", - "foo", "bar", - "foo", "bar", - "foo", "bar", - "foo", "bar", - "foo", "bar", - "foo", "bar", - "foo", "bar", - "foo", "bar", // 11 - )), - }, - 6: { - name: "pseudo_order", - w: func(f *Framer) { - write(f, encodeHeaderRaw(t, - ":method", "GET", - "foo", "bar", - ":path", "/", // bogus - )) - }, - want: streamError(1, ErrCodeProtocol), - wantErrReason: "pseudo header field after regular", - }, - 7: { - name: "pseudo_unknown", - w: func(f *Framer) { - write(f, encodeHeaderRaw(t, - ":unknown", "foo", // bogus - "foo", "bar", - )) - }, - want: streamError(1, ErrCodeProtocol), - wantErrReason: "invalid pseudo-header \":unknown\"", - }, - 8: { - name: "pseudo_mix_request_response", - w: func(f *Framer) { - write(f, encodeHeaderRaw(t, - ":method", "GET", - ":status", "100", - )) - }, - want: streamError(1, ErrCodeProtocol), - wantErrReason: "mix of request and response pseudo headers", - }, - 9: { - name: "pseudo_dup", - w: func(f *Framer) { - write(f, encodeHeaderRaw(t, - ":method", "GET", - ":method", "POST", - )) - }, - want: streamError(1, ErrCodeProtocol), - wantErrReason: "duplicate pseudo-header \":method\"", - }, - 10: { - name: "trailer_okay_no_pseudo", - w: func(f *Framer) { write(f, encodeHeaderRaw(t, "foo", "bar")) }, - want: want(FlagHeadersEndHeaders, 8, "foo", "bar"), - }, - 11: { - name: "invalid_field_name", - w: func(f *Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) }, - want: streamError(1, ErrCodeProtocol), - wantErrReason: "invalid header field name \"CapitalBad\"", - }, - 12: { - name: "invalid_field_value", - w: func(f *Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) }, - want: streamError(1, ErrCodeProtocol), - wantErrReason: `invalid header field value for "key"`, - }, - } - for i, tt := range tests { - buf := new(bytes.Buffer) - f := NewFramer(buf, buf) - f.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) - f.MaxHeaderListSize = tt.maxHeaderListSize - tt.w(f) - - name := tt.name - if name == "" { - name = fmt.Sprintf("test index %d", i) - } - - var got interface{} - var err error - got, err = f.ReadFrame() - if err != nil { - got = err - - // Ignore the StreamError.Cause field, if it matches the wantErrReason. - // The test table above predates the Cause field. - if se, ok := err.(StreamError); ok && se.Cause != nil && se.Cause.Error() == tt.wantErrReason { - se.Cause = nil - got = se - } - } - if !reflect.DeepEqual(got, tt.want) { - if mhg, ok := got.(*MetaHeadersFrame); ok { - if mhw, ok := tt.want.(*MetaHeadersFrame); ok { - hg := mhg.HeadersFrame - hw := mhw.HeadersFrame - if hg != nil && hw != nil && !reflect.DeepEqual(*hg, *hw) { - t.Errorf("%s: headers differ:\n got: %+v\nwant: %+v\n", name, *hg, *hw) - } - } - } - str := func(v interface{}) string { - if _, ok := v.(error); ok { - return fmt.Sprintf("error %v", v) - } else { - return fmt.Sprintf("value %#v", v) - } - } - t.Errorf("%s:\n got: %v\nwant: %s", name, str(got), str(tt.want)) - } - if tt.wantErrReason != "" && tt.wantErrReason != fmt.Sprint(f.errDetail) { - t.Errorf("%s: got error reason %q; want %q", name, f.errDetail, tt.wantErrReason) - } - } -} - -func TestSetReuseFrames(t *testing.T) { - fr, buf := testFramer() - fr.SetReuseFrames() - - // Check that DataFrames are reused. Note that - // SetReuseFrames only currently implements reuse of DataFrames. - firstDf := readAndVerifyDataFrame("ABC", 3, fr, buf, t) - - for i := 0; i < 10; i++ { - df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t) - if df != firstDf { - t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf) - } - } - - for i := 0; i < 10; i++ { - df := readAndVerifyDataFrame("", 0, fr, buf, t) - if df != firstDf { - t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf) - } - } - - for i := 0; i < 10; i++ { - df := readAndVerifyDataFrame("HHH", 3, fr, buf, t) - if df != firstDf { - t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf) - } - } -} - -func TestSetReuseFramesMoreThanOnce(t *testing.T) { - fr, buf := testFramer() - fr.SetReuseFrames() - - firstDf := readAndVerifyDataFrame("ABC", 3, fr, buf, t) - fr.SetReuseFrames() - - for i := 0; i < 10; i++ { - df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t) - // SetReuseFrames should be idempotent - fr.SetReuseFrames() - if df != firstDf { - t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf) - } - } -} - -func TestNoSetReuseFrames(t *testing.T) { - fr, buf := testFramer() - const numNewDataFrames = 10 - dfSoFar := make([]interface{}, numNewDataFrames) - - // Check that DataFrames are not reused if SetReuseFrames wasn't called. - // SetReuseFrames only currently implements reuse of DataFrames. - for i := 0; i < numNewDataFrames; i++ { - df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t) - for _, item := range dfSoFar { - if df == item { - t.Errorf("Expected Framer to return new DataFrames since SetNoReuseFrames not set.") - } - } - dfSoFar[i] = df - } -} - -func readAndVerifyDataFrame(data string, length byte, fr *Framer, buf *bytes.Buffer, t *testing.T) *DataFrame { - var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4 - fr.WriteData(streamID, true, []byte(data)) - wantEnc := "\x00\x00" + string(length) + "\x00\x01\x01\x02\x03\x04" + data - if buf.String() != wantEnc { - t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) - } - f, err := fr.ReadFrame() - if err != nil { - t.Fatal(err) - } - df, ok := f.(*DataFrame) - if !ok { - t.Fatalf("got %T; want *DataFrame", f) - } - if !bytes.Equal(df.Data(), []byte(data)) { - t.Errorf("got %q; want %q", df.Data(), []byte(data)) - } - if f.Header().Flags&1 == 0 { - t.Errorf("didn't see END_STREAM flag") - } - return df -} - -func encodeHeaderRaw(t *testing.T, pairs ...string) []byte { - var he hpackEncoder - return he.encodeHeaderRaw(t, pairs...) -} - -func TestSettingsDuplicates(t *testing.T) { - tests := []struct { - settings []Setting - want bool - }{ - {nil, false}, - {[]Setting{{ID: 1}}, false}, - {[]Setting{{ID: 1}, {ID: 2}}, false}, - {[]Setting{{ID: 1}, {ID: 2}}, false}, - {[]Setting{{ID: 1}, {ID: 2}, {ID: 3}}, false}, - {[]Setting{{ID: 1}, {ID: 2}, {ID: 3}}, false}, - {[]Setting{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}}, false}, - - {[]Setting{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 2}}, true}, - {[]Setting{{ID: 4}, {ID: 2}, {ID: 3}, {ID: 4}}, true}, - - {[]Setting{ - {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}, - {ID: 5}, {ID: 6}, {ID: 7}, {ID: 8}, - {ID: 9}, {ID: 10}, {ID: 11}, {ID: 12}, - }, false}, - - {[]Setting{ - {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}, - {ID: 5}, {ID: 6}, {ID: 7}, {ID: 8}, - {ID: 9}, {ID: 10}, {ID: 11}, {ID: 11}, - }, true}, - } - for i, tt := range tests { - fr, _ := testFramer() - fr.WriteSettings(tt.settings...) - f, err := fr.ReadFrame() - if err != nil { - t.Fatalf("%d. ReadFrame: %v", i, err) - } - sf := f.(*SettingsFrame) - got := sf.HasDuplicates() - if got != tt.want { - t.Errorf("%d. HasDuplicates = %v; want %v", i, got, tt.want) - } - } - -} - -func TestParseSettingsFrame(t *testing.T) { - fh := FrameHeader{} - fh.Flags = FlagSettingsAck - fh.Length = 1 - countErr := func(s string) {} - _, err := parseSettingsFrame(nil, fh, countErr, nil) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") - - fh = FrameHeader{StreamID: 1} - _, err = parseSettingsFrame(nil, fh, countErr, nil) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") - - fh = FrameHeader{} - _, err = parseSettingsFrame(nil, fh, countErr, []byte("roc")) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") - - fh = FrameHeader{valid: true} - _, err = parseSettingsFrame(nil, fh, countErr, []byte("rocroc")) - tests.AssertNoError(t, err) -} - -func TestParsePushPromise(t *testing.T) { - fh := FrameHeader{} - countError := func(string) {} - _, err := parsePushPromise(nil, fh, countError, nil) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") - - fh.StreamID = 1 - fh.Flags = FlagPushPromisePadded - _, err = parsePushPromise(nil, fh, countError, nil) - tests.AssertErrorContains(t, err, "EOF") - - fh.Flags = 0 - _, err = parsePushPromise(nil, fh, countError, nil) - tests.AssertErrorContains(t, err, "EOF") - - _, err = parsePushPromise(nil, fh, countError, []byte("ksjfksjksjflskk")) - tests.AssertNoError(t, err) -} - -func TestSummarizeFrame(t *testing.T) { - fh := FrameHeader{valid: true} - var f Frame - f = &SettingsFrame{FrameHeader: fh, p: []byte{0x09, 0x01, 0x80, 0x20, 0x00, 0x11}} - s := summarizeFrame(f) - tests.AssertContains(t, s, "len=0", true) - - f = &DataFrame{FrameHeader: fh} - s = summarizeFrame(f) - tests.AssertContains(t, s, `data=""`, true) - - f = &WindowUpdateFrame{FrameHeader: fh} - s = summarizeFrame(f) - tests.AssertContains(t, s, "conn", true) - - f = &PingFrame{FrameHeader: fh} - s = summarizeFrame(f) - tests.AssertContains(t, s, "ping", true) - - f = &GoAwayFrame{FrameHeader: fh} - s = summarizeFrame(f) - tests.AssertContains(t, s, "laststreamid", true) - - f = &RSTStreamFrame{FrameHeader: fh} - s = summarizeFrame(f) - tests.AssertContains(t, s, "no_error", true) -} - -func TestParseDataFrame(t *testing.T) { - fh := FrameHeader{valid: true} - countError := func(string) {} - _, err := parseDataFrame(nil, fh, countError, nil) - tests.AssertErrorContains(t, err, "DATA frame with stream ID 0") - - fh.StreamID = 1 - fh.Flags = FlagDataPadded - fc := &frameCache{} - payload := []byte{0x09, 0x00, 0x00, 0x98, 0x11, 0x12} - _, err = parseDataFrame(fc, fh, countError, payload) - tests.AssertErrorContains(t, err, "pad size larger than data payload") - - payload = []byte{0x02, 0x00, 0x00, 0x98, 0x11, 0x12} - _, err = parseDataFrame(fc, fh, countError, payload) - tests.AssertNoError(t, err) -} - -func TestParseWindowUpdateFrame(t *testing.T) { - fh := FrameHeader{valid: true} - countError := func(string) {} - _, err := parseWindowUpdateFrame(nil, fh, countError, nil) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") - - p := []byte{0x00, 0x00, 0x00, 0x00} - _, err = parseWindowUpdateFrame(nil, fh, countError, p) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") - - fh.StreamID = 255 - p[0] = 0x01 - p[3] = 0x01 - _, err = parseWindowUpdateFrame(nil, fh, countError, p) - tests.AssertNoError(t, err) -} - -func TestParseUnknownFrame(t *testing.T) { - fh := FrameHeader{valid: true} - countError := func(string) {} - p := []byte("test") - f, err := parseUnknownFrame(nil, fh, countError, p) - tests.AssertNoError(t, err) - uf, ok := f.(*UnknownFrame) - if !ok { - t.Fatalf("not UnknownFrame type: %#+v", f) - } - tests.AssertEqual(t, p, uf.Payload()) -} - -func TestParseRSTStreamFrame(t *testing.T) { - fh := FrameHeader{valid: true} - countError := func(string) {} - p := []byte("test.") - _, err := parseRSTStreamFrame(nil, fh, countError, p) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") - - p = []byte("test") - _, err = parseRSTStreamFrame(nil, fh, countError, p) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") - - fh.StreamID = 1 - _, err = parseRSTStreamFrame(nil, fh, countError, p) - tests.AssertNoError(t, err) -} - -func TestParsePingFrame(t *testing.T) { - fh := FrameHeader{valid: true} - countError := func(string) {} - payload := []byte("") - _, err := parsePingFrame(nil, fh, countError, payload) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") - - payload = []byte("testtest") - fh.StreamID = 1 - _, err = parsePingFrame(nil, fh, countError, payload) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") - - fh.StreamID = 0 - _, err = parsePingFrame(nil, fh, countError, payload) - tests.AssertNoError(t, err) -} - -func TestParseGoAwayFrame(t *testing.T) { - fh := FrameHeader{valid: true} - countError := func(string) {} - payload := []byte("") - - fh.StreamID = 1 - _, err := parseGoAwayFrame(nil, fh, countError, payload) - tests.AssertErrorContains(t, err, "PROTOCOL_ERROR") - - fh.StreamID = 0 - _, err = parseGoAwayFrame(nil, fh, countError, payload) - tests.AssertErrorContains(t, err, "FRAME_SIZE_ERROR") -} - -func TestPushPromiseFrame(t *testing.T) { - fh := FrameHeader{valid: true} - buf := []byte("test") - f := &PushPromiseFrame{FrameHeader: fh, headerFragBuf: buf} - tests.AssertEqual(t, buf, f.HeaderBlockFragment()) - tests.AssertEqual(t, false, f.HeadersEnded()) -} - -func TestH2Framer(t *testing.T) { - f := &Framer{} - f.debugWriteLoggerf = func(s string, i ...interface{}) {} - f.logWrite() - tests.AssertNotNil(t, f.debugFramer) - tests.AssertIsNil(t, f.ErrorDetail()) - - f.w = new(bytes.Buffer) - err := f.WriteRawFrame(FrameData, FlagDataEndStream, 1, nil) - tests.AssertNoError(t, err) - - param := PushPromiseParam{} - err = f.WritePushPromise(param) - tests.AssertErrorContains(t, err, "invalid stream ID") - - param.StreamID = 1 - param.EndHeaders = true - param.PadLength = 2 - f.AllowIllegalWrites = true - err = f.WritePushPromise(param) - tests.AssertNoError(t, err) -} diff --git a/internal/http2/gotrack_test.go b/internal/http2/gotrack_test.go deleted file mode 100644 index 55d2d3a1..00000000 --- a/internal/http2/gotrack_test.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import ( - "fmt" - "github.com/imroc/req/v3/internal/tests" - "strings" - "testing" -) - -func TestGoroutineLock(t *testing.T) { - oldDebug := DebugGoroutines - DebugGoroutines = true - defer func() { DebugGoroutines = oldDebug }() - - g := newGoroutineLock() - g.check() - - sawPanic := make(chan interface{}) - go func() { - defer func() { sawPanic <- recover() }() - g.check() // should panic - }() - e := <-sawPanic - if e == nil { - t.Fatal("did not see panic from check in other goroutine") - } - if !strings.Contains(fmt.Sprint(e), "wrong goroutine") { - t.Errorf("expected on see panic about running on the wrong goroutine; got %v", e) - } -} - -func TestParseUintBytes(t *testing.T) { - s := []byte{} - _, err := parseUintBytes(s, 0, 0) - tests.AssertErrorContains(t, err, "invalid syntax") - - s = []byte("0x") - _, err = parseUintBytes(s, 0, 0) - tests.AssertErrorContains(t, err, "invalid syntax") - - s = []byte("0x01") - _, err = parseUintBytes(s, 0, 0) - tests.AssertNoError(t, err) - - s = []byte("0xa1") - _, err = parseUintBytes(s, 0, 0) - tests.AssertNoError(t, err) - - s = []byte("0xA1") - _, err = parseUintBytes(s, 0, 0) - tests.AssertNoError(t, err) -} diff --git a/internal/http2/http2_test.go b/internal/http2/http2_test.go deleted file mode 100644 index c905758d..00000000 --- a/internal/http2/http2_test.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import ( - "flag" - "fmt" - "github.com/imroc/req/v3/internal/tests" - "net/http" - "testing" - "time" -) - -func init() { - inTests = true - DebugGoroutines = true - flag.BoolVar(&VerboseLogs, "verboseh2", VerboseLogs, "Verbose HTTP/2 debug logging") -} - -func TestSettingString(t *testing.T) { - tests := []struct { - s Setting - want string - }{ - {Setting{SettingMaxFrameSize, 123}, "[MAX_FRAME_SIZE = 123]"}, - {Setting{1<<16 - 1, 123}, "[UNKNOWN_SETTING_65535 = 123]"}, - } - for i, tt := range tests { - got := fmt.Sprint(tt.s) - if got != tt.want { - t.Errorf("%d. for %#v, string = %q; want %q", i, tt.s, got, tt.want) - } - } -} - -func cleanDate(res *http.Response) { - if d := res.Header["Date"]; len(d) == 1 { - d[0] = "XXX" - } -} - -func TestSorterPoolAllocs(t *testing.T) { - ss := []string{"a", "b", "c"} - h := http.Header{ - "a": nil, - "b": nil, - "c": nil, - } - sorter := new(sorter) - - if allocs := testing.AllocsPerRun(100, func() { - sorter.SortStrings(ss) - }); allocs >= 1 { - t.Logf("SortStrings allocs = %v; want <1", allocs) - } - - if allocs := testing.AllocsPerRun(5, func() { - if len(sorter.Keys(h)) != 3 { - t.Fatal("wrong result") - } - }); allocs > 0 { - t.Logf("Keys allocs = %v; want <1", allocs) - } -} - -// waitCondition reports whether fn eventually returned true, -// checking immediately and then every checkEvery amount, -// until waitFor has elapsed, at which point it returns false. -func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { - deadline := time.Now().Add(waitFor) - for time.Now().Before(deadline) { - if fn() { - return true - } - time.Sleep(checkEvery) - } - return false -} - -func TestSettingValid(t *testing.T) { - cases := []struct { - id SettingID - val uint32 - }{ - { - id: SettingEnablePush, - val: 2, - }, - { - id: SettingInitialWindowSize, - val: 1 << 31, - }, - { - id: SettingMaxFrameSize, - val: 0, - }, - } - for _, c := range cases { - s := &Setting{ID: c.id, Val: c.val} - tests.AssertEqual(t, true, s.Valid() != nil) - } - s := &Setting{ID: SettingMaxHeaderListSize} - tests.AssertEqual(t, true, s.Valid() == nil) -} - -func TestBodyAllowedForStatus(t *testing.T) { - tests.AssertEqual(t, false, bodyAllowedForStatus(101)) - tests.AssertEqual(t, false, bodyAllowedForStatus(204)) - tests.AssertEqual(t, false, bodyAllowedForStatus(304)) - tests.AssertEqual(t, true, bodyAllowedForStatus(900)) -} - -func TestHttpError(t *testing.T) { - e := &httpError{msg: "test"} - tests.AssertEqual(t, "test", e.Error()) - tests.AssertEqual(t, true, e.Temporary()) -} diff --git a/internal/http2/pipe_test.go b/internal/http2/pipe_test.go deleted file mode 100644 index 326b94de..00000000 --- a/internal/http2/pipe_test.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import ( - "bytes" - "errors" - "io" - "testing" -) - -func TestPipeClose(t *testing.T) { - var p pipe - p.b = new(bytes.Buffer) - a := errors.New("a") - b := errors.New("b") - p.CloseWithError(a) - p.CloseWithError(b) - _, err := p.Read(make([]byte, 1)) - if err != a { - t.Errorf("err = %v want %v", err, a) - } -} - -func TestPipeDoneChan(t *testing.T) { - var p pipe - done := p.Done() - select { - case <-done: - t.Fatal("done too soon") - default: - } - p.CloseWithError(io.EOF) - select { - case <-done: - default: - t.Fatal("should be done") - } -} - -func TestPipeDoneChan_ErrFirst(t *testing.T) { - var p pipe - p.CloseWithError(io.EOF) - done := p.Done() - select { - case <-done: - default: - t.Fatal("should be done") - } -} - -func TestPipeDoneChan_Break(t *testing.T) { - var p pipe - done := p.Done() - select { - case <-done: - t.Fatal("done too soon") - default: - } - p.BreakWithError(io.EOF) - select { - case <-done: - default: - t.Fatal("should be done") - } -} - -func TestPipeDoneChan_Break_ErrFirst(t *testing.T) { - var p pipe - p.BreakWithError(io.EOF) - done := p.Done() - select { - case <-done: - default: - t.Fatal("should be done") - } -} - -func TestPipeCloseWithError(t *testing.T) { - p := &pipe{b: new(bytes.Buffer)} - const body = "foo" - io.WriteString(p, body) - a := errors.New("test error") - p.CloseWithError(a) - all, err := io.ReadAll(p) - if string(all) != body { - t.Errorf("read bytes = %q; want %q", all, body) - } - if err != a { - t.Logf("read error = %v, %v", err, a) - } - if p.Len() != 0 { - t.Errorf("pipe should have 0 unread bytes") - } - // Read and Write should fail. - if n, err := p.Write([]byte("abc")); err != errClosedPipeWrite || n != 0 { - t.Errorf("Write(abc) after close\ngot %v, %v\nwant 0, %v", n, err, errClosedPipeWrite) - } - if n, err := p.Read(make([]byte, 1)); err == nil || n != 0 { - t.Errorf("Read() after close\ngot %v, nil\nwant 0, %v", n, errClosedPipeWrite) - } - if p.Len() != 0 { - t.Errorf("pipe should have 0 unread bytes") - } -} - -func TestPipeBreakWithError(t *testing.T) { - p := &pipe{b: new(bytes.Buffer)} - io.WriteString(p, "foo") - a := errors.New("test err") - p.BreakWithError(a) - all, err := io.ReadAll(p) - if string(all) != "" { - t.Errorf("read bytes = %q; want empty string", all) - } - if err != a { - t.Logf("read error = %v, %v", err, a) - } - if p.b != nil { - t.Errorf("buffer should be nil after BreakWithError") - } - if p.Len() != 3 { - t.Errorf("pipe should have 3 unread bytes") - } - // Write should fail. - if n, err := p.Write([]byte("abc")); err != errClosedPipeWrite || n != 0 { - t.Errorf("Write(abc) after break\ngot %v, %v\nwant 0, errClosedPipeWrite", n, err) - } - if p.b != nil { - t.Errorf("buffer should be nil after Write") - } - if p.Len() != 3 { - t.Errorf("pipe should have 6 unread bytes") - } - // Read should fail. - if n, err := p.Read(make([]byte, 1)); err == nil || n != 0 { - t.Errorf("Read() after close\ngot %v, nil\nwant 0, not nil", n) - } -} diff --git a/internal/http2/server_test.go b/internal/http2/server_test.go deleted file mode 100644 index be803d76..00000000 --- a/internal/http2/server_test.go +++ /dev/null @@ -1,5527 +0,0 @@ -package http2 - -import ( - "bufio" - "bytes" - "context" - "crypto/tls" - "errors" - "flag" - "fmt" - "github.com/imroc/req/v3/internal/ascii" - "io" - "log" - "math" - "net" - "net/http" - "net/http/httptest" - "net/textproto" - "net/url" - "os" - "reflect" - "runtime" - "sort" - "strconv" - "strings" - "sync" - "testing" - "time" - - "golang.org/x/net/http/httpguts" - "golang.org/x/net/http2/hpack" -) - -// A list of the possible cipher suite ids. Taken from -// https://www.iana.org/assignments/tls-parameters/tls-parameters.txt - -const ( - cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000 - cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001 - cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002 - cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003 - cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004 - cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 - cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006 - cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007 - cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008 - cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009 - cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A - cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B - cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C - cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D - cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E - cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F - cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010 - cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011 - cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012 - cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013 - cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014 - cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015 - cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016 - cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017 - cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018 - cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019 - cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A - cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B - cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E - cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F - cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020 - cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021 - cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022 - cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023 - cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024 - cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025 - cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026 - cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027 - cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028 - cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029 - cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A - cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B - cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C - cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D - cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E - cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F - cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030 - cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031 - cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032 - cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033 - cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034 - cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 - cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036 - cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037 - cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038 - cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039 - cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A - cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B - cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C - cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D - cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E - cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F - cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040 - cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041 - cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042 - cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043 - cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044 - cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045 - cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046 - cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067 - cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068 - cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069 - cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A - cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B - cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C - cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D - cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084 - cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085 - cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086 - cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087 - cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088 - cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089 - cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A - cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B - cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C - cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D - cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E - cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F - cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090 - cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091 - cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092 - cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093 - cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094 - cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095 - cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096 - cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097 - cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098 - cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099 - cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A - cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B - cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C - cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D - cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0 - cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1 - cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4 - cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5 - cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6 - cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7 - cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8 - cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9 - cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC - cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD - cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE - cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF - cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0 - cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1 - cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2 - cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3 - cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4 - cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5 - cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6 - cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7 - cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8 - cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9 - cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA - cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB - cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC - cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD - cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE - cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF - cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0 - cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1 - cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2 - cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3 - cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4 - cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5 - cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF - cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA uint16 = 0xC001 - cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA uint16 = 0xC002 - cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC003 - cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC004 - cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC005 - cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA uint16 = 0xC006 - cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xC007 - cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC008 - cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC009 - cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC00A - cipher_TLS_ECDH_RSA_WITH_NULL_SHA uint16 = 0xC00B - cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA uint16 = 0xC00C - cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC00D - cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC00E - cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC00F - cipher_TLS_ECDHE_RSA_WITH_NULL_SHA uint16 = 0xC010 - cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xC011 - cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC012 - cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC013 - cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC014 - cipher_TLS_ECDH_anon_WITH_NULL_SHA uint16 = 0xC015 - cipher_TLS_ECDH_anon_WITH_RC4_128_SHA uint16 = 0xC016 - cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0xC017 - cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA uint16 = 0xC018 - cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA uint16 = 0xC019 - cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01A - cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01B - cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01C - cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA uint16 = 0xC01D - cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC01E - cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA uint16 = 0xC01F - cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA uint16 = 0xC020 - cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC021 - cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA uint16 = 0xC022 - cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC023 - cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC024 - cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC025 - cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC026 - cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC027 - cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC028 - cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC029 - cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC02A - cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02D - cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02E - cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02F - cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC031 - cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC032 - cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA uint16 = 0xC033 - cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0xC034 - cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0xC035 - cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0xC036 - cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0xC037 - cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0xC038 - cipher_TLS_ECDHE_PSK_WITH_NULL_SHA uint16 = 0xC039 - cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256 uint16 = 0xC03A - cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384 uint16 = 0xC03B - cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03C - cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03D - cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03E - cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03F - cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC040 - cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC041 - cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC042 - cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC043 - cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC044 - cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC045 - cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC046 - cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC047 - cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC048 - cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC049 - cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04A - cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04B - cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04C - cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04D - cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04E - cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04F - cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC050 - cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC051 - cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC054 - cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC055 - cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC058 - cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC059 - cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05A - cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05B - cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05E - cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05F - cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC062 - cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC063 - cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC064 - cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC065 - cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC066 - cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC067 - cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC068 - cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC069 - cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06A - cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06B - cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06E - cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06F - cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC070 - cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC071 - cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC072 - cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC073 - cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC074 - cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC075 - cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC076 - cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC077 - cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC078 - cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC079 - cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07A - cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07B - cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07E - cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07F - cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC082 - cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC083 - cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC084 - cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC085 - cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC088 - cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC089 - cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08C - cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08D - cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08E - cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08F - cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC092 - cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC093 - cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC094 - cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC095 - cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC096 - cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC097 - cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC098 - cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC099 - cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC09A - cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC09B - cipher_TLS_RSA_WITH_AES_128_CCM uint16 = 0xC09C - cipher_TLS_RSA_WITH_AES_256_CCM uint16 = 0xC09D - cipher_TLS_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A0 - cipher_TLS_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A1 - cipher_TLS_PSK_WITH_AES_128_CCM uint16 = 0xC0A4 - cipher_TLS_PSK_WITH_AES_256_CCM uint16 = 0xC0A5 - cipher_TLS_PSK_WITH_AES_128_CCM_8 uint16 = 0xC0A8 - cipher_TLS_PSK_WITH_AES_256_CCM_8 uint16 = 0xC0A9 -) - -// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. -// References: -// https://tools.ietf.org/html/rfc7540#appendix-A -// Reject cipher suites from Appendix A. -// "This list includes those cipher suites that do not -// offer an ephemeral key exchange and those that are -// based on the TLS null, stream or block cipher type" -func isBadCipher(cipher uint16) bool { - switch cipher { - case cipher_TLS_NULL_WITH_NULL_NULL, - cipher_TLS_RSA_WITH_NULL_MD5, - cipher_TLS_RSA_WITH_NULL_SHA, - cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5, - cipher_TLS_RSA_WITH_RC4_128_MD5, - cipher_TLS_RSA_WITH_RC4_128_SHA, - cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5, - cipher_TLS_RSA_WITH_IDEA_CBC_SHA, - cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA, - cipher_TLS_RSA_WITH_DES_CBC_SHA, - cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA, - cipher_TLS_DH_DSS_WITH_DES_CBC_SHA, - cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA, - cipher_TLS_DH_RSA_WITH_DES_CBC_SHA, - cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA, - cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA, - cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA, - cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA, - cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5, - cipher_TLS_DH_anon_WITH_RC4_128_MD5, - cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA, - cipher_TLS_DH_anon_WITH_DES_CBC_SHA, - cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_KRB5_WITH_DES_CBC_SHA, - cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_KRB5_WITH_RC4_128_SHA, - cipher_TLS_KRB5_WITH_IDEA_CBC_SHA, - cipher_TLS_KRB5_WITH_DES_CBC_MD5, - cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5, - cipher_TLS_KRB5_WITH_RC4_128_MD5, - cipher_TLS_KRB5_WITH_IDEA_CBC_MD5, - cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA, - cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA, - cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA, - cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5, - cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5, - cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5, - cipher_TLS_PSK_WITH_NULL_SHA, - cipher_TLS_DHE_PSK_WITH_NULL_SHA, - cipher_TLS_RSA_PSK_WITH_NULL_SHA, - cipher_TLS_RSA_WITH_AES_128_CBC_SHA, - cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA, - cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA, - cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA, - cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA, - cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA, - cipher_TLS_RSA_WITH_AES_256_CBC_SHA, - cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA, - cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA, - cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA, - cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA, - cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA, - cipher_TLS_RSA_WITH_NULL_SHA256, - cipher_TLS_RSA_WITH_AES_128_CBC_SHA256, - cipher_TLS_RSA_WITH_AES_256_CBC_SHA256, - cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256, - cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256, - cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, - cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA, - cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA, - cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA, - cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA, - cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA, - cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA, - cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, - cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256, - cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256, - cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256, - cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, - cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256, - cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256, - cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA, - cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA, - cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA, - cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA, - cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA, - cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA, - cipher_TLS_PSK_WITH_RC4_128_SHA, - cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_PSK_WITH_AES_128_CBC_SHA, - cipher_TLS_PSK_WITH_AES_256_CBC_SHA, - cipher_TLS_DHE_PSK_WITH_RC4_128_SHA, - cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA, - cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA, - cipher_TLS_RSA_PSK_WITH_RC4_128_SHA, - cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA, - cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA, - cipher_TLS_RSA_WITH_SEED_CBC_SHA, - cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA, - cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA, - cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA, - cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA, - cipher_TLS_DH_anon_WITH_SEED_CBC_SHA, - cipher_TLS_RSA_WITH_AES_128_GCM_SHA256, - cipher_TLS_RSA_WITH_AES_256_GCM_SHA384, - cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256, - cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384, - cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256, - cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384, - cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256, - cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384, - cipher_TLS_PSK_WITH_AES_128_GCM_SHA256, - cipher_TLS_PSK_WITH_AES_256_GCM_SHA384, - cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256, - cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384, - cipher_TLS_PSK_WITH_AES_128_CBC_SHA256, - cipher_TLS_PSK_WITH_AES_256_CBC_SHA384, - cipher_TLS_PSK_WITH_NULL_SHA256, - cipher_TLS_PSK_WITH_NULL_SHA384, - cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256, - cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384, - cipher_TLS_DHE_PSK_WITH_NULL_SHA256, - cipher_TLS_DHE_PSK_WITH_NULL_SHA384, - cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256, - cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384, - cipher_TLS_RSA_PSK_WITH_NULL_SHA256, - cipher_TLS_RSA_PSK_WITH_NULL_SHA384, - cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256, - cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256, - cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256, - cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256, - cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256, - cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256, - cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV, - cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA, - cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA, - cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, - cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, - cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA, - cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, - cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - cipher_TLS_ECDH_RSA_WITH_NULL_SHA, - cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA, - cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, - cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, - cipher_TLS_ECDHE_RSA_WITH_NULL_SHA, - cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA, - cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - cipher_TLS_ECDH_anon_WITH_NULL_SHA, - cipher_TLS_ECDH_anon_WITH_RC4_128_SHA, - cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA, - cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA, - cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA, - cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA, - cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA, - cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA, - cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA, - cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA, - cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256, - cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384, - cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, - cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256, - cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384, - cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256, - cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384, - cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256, - cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384, - cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA, - cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA, - cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA, - cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA, - cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, - cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, - cipher_TLS_ECDHE_PSK_WITH_NULL_SHA, - cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256, - cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384, - cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256, - cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384, - cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256, - cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384, - cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256, - cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384, - cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256, - cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384, - cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256, - cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384, - cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256, - cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384, - cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256, - cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384, - cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256, - cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384, - cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256, - cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384, - cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, - cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, - cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384, - cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384, - cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256, - cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384, - cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256, - cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384, - cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256, - cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384, - cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256, - cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384, - cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256, - cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384, - cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256, - cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384, - cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256, - cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384, - cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256, - cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384, - cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384, - cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, - cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384, - cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, - cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, - cipher_TLS_RSA_WITH_AES_128_CCM, - cipher_TLS_RSA_WITH_AES_256_CCM, - cipher_TLS_RSA_WITH_AES_128_CCM_8, - cipher_TLS_RSA_WITH_AES_256_CCM_8, - cipher_TLS_PSK_WITH_AES_128_CCM, - cipher_TLS_PSK_WITH_AES_256_CCM, - cipher_TLS_PSK_WITH_AES_128_CCM_8, - cipher_TLS_PSK_WITH_AES_256_CCM_8: - return true - default: - return false - } -} - -const ( - prefaceTimeout = 10 * time.Second - firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway - handlerChunkWriteSize = 4 << 10 - defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? - maxQueuedControlFrames = 10000 -) - -var ( - errClientDisconnected = errors.New("client disconnected") - errClosedBody = errors.New("body closed by handler") - errHandlerComplete = errors.New("http2: request body closed due to handler exiting") - errStreamClosed = errors.New("http2: stream closed") -) - -var responseWriterStatePool = sync.Pool{ - New: func() interface{} { - rws := &responseWriterState{} - rws.bw = bufio.NewWriterSize(chunkWriter{rws}, handlerChunkWriteSize) - return rws - }, -} - -// Test hooks. -var ( - testHookOnConn func() - testHookGetServerConn func(*serverConn) - testHookOnPanicMu *sync.Mutex // nil except in tests - testHookOnPanic func(sc *serverConn, panicVal interface{}) (rePanic bool) -) - -// Server is an HTTP/2 server. -type Server struct { - // MaxHandlers limits the number of http.Handler ServeHTTP goroutines - // which may run at a time over all connections. - // Negative or zero no limit. - // TODO: implement - MaxHandlers int - - // MaxConcurrentStreams optionally specifies the number of - // concurrent streams that each client may have open at a - // time. This is unrelated to the number of http.Handler goroutines - // which may be active globally, which is MaxHandlers. - // If zero, MaxConcurrentStreams defaults to at least 100, per - // the HTTP/2 spec's recommendations. - MaxConcurrentStreams uint32 - - // MaxReadFrameSize optionally specifies the largest frame - // this server is willing to read. A valid value is between - // 16k and 16M, inclusive. If zero or otherwise invalid, a - // default value is used. - MaxReadFrameSize uint32 - - // PermitProhibitedCipherSuites, if true, permits the use of - // cipher suites prohibited by the HTTP/2 spec. - PermitProhibitedCipherSuites bool - - // IdleTimeout specifies how long until idle clients should be - // closed with a GOAWAY frame. PING frames are not considered - // activity for the purposes of IdleTimeout. - IdleTimeout time.Duration - - // MaxUploadBufferPerConnection is the size of the initial flow - // control window for each connections. The HTTP/2 spec does not - // allow this to be smaller than 65535 or larger than 2^32-1. - // If the value is outside this range, a default value will be - // used instead. - MaxUploadBufferPerConnection int32 - - // MaxUploadBufferPerStream is the size of the initial flow control - // window for each stream. The HTTP/2 spec does not allow this to - // be larger than 2^32-1. If the value is zero or larger than the - // maximum, a default value will be used instead. - MaxUploadBufferPerStream int32 - - // NewWriteScheduler constructs a write scheduler for a connection. - // If nil, a default scheduler is chosen. - NewWriteScheduler func() WriteScheduler - - // CountError, if non-nil, is called on HTTP/2 server errors. - // It's intended to increment a metric for monitoring, such - // as an expvar or Prometheus metric. - // The errType consists of only ASCII word characters. - CountError func(errType string) - - // Internal state. This is a pointer (rather than embedded directly) - // so that we don't embed a Mutex in this struct, which will make the - // struct non-copyable, which might break some callers. - state *serverInternalState -} - -func (s *Server) initialConnRecvWindowSize() int32 { - if s.MaxUploadBufferPerConnection > initialWindowSize { - return s.MaxUploadBufferPerConnection - } - return 1 << 20 -} - -func (s *Server) initialStreamRecvWindowSize() int32 { - if s.MaxUploadBufferPerStream > 0 { - return s.MaxUploadBufferPerStream - } - return 1 << 20 -} - -func (s *Server) maxReadFrameSize() uint32 { - if v := s.MaxReadFrameSize; v >= minMaxFrameSize && v <= maxFrameSize { - return v - } - return defaultMaxReadFrameSize -} - -func (s *Server) maxConcurrentStreams() uint32 { - if v := s.MaxConcurrentStreams; v > 0 { - return v - } - return defaultMaxStreams -} - -// maxQueuedControlFrames is the maximum number of control frames like -// SETTINGS, PING and RST_STREAM that will be queued for writing before -// the connection is closed to prevent memory exhaustion attacks. -func (s *Server) maxQueuedControlFrames() int { - // TODO: if anybody asks, add a Server field, and remember to define the - // behavior of negative values. - return maxQueuedControlFrames -} - -// ServeConn serves HTTP/2 requests on the provided connection and -// blocks until the connection is no longer readable. -// -// ServeConn starts speaking HTTP/2 assuming that c has not had any -// reads or writes. It writes its initial settings frame and expects -// to be able to read the preface and settings frame from the -// client. If c has a ConnectionState method like a *tls.Conn, the -// ConnectionState is used to verify the TLS ciphersuite and to set -// the Request.TLS field in Handlers. -// -// ServeConn does not support h2c by itself. Any h2c support must be -// implemented in terms of providing a suitably-behaving net.Conn. -// -// The opts parameter is optional. If nil, default values are used. -func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { - baseCtx, cancel := serverConnBaseContext(c, opts) - defer cancel() - - sc := &serverConn{ - srv: s, - hs: opts.baseConfig(), - conn: c, - baseCtx: baseCtx, - remoteAddrStr: c.RemoteAddr().String(), - bw: newBufferedWriter(c), - handler: opts.handler(), - streams: make(map[uint32]*stream), - readFrameCh: make(chan readFrameResult), - wantWriteFrameCh: make(chan FrameWriteRequest, 8), - serveMsgCh: make(chan interface{}, 8), - wroteFrameCh: make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync - bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way - doneServing: make(chan struct{}), - clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value" - advMaxStreams: s.maxConcurrentStreams(), - initialStreamSendWindowSize: initialWindowSize, - maxFrameSize: initialMaxFrameSize, - headerTableSize: initialHeaderTableSize, - serveG: newGoroutineLock(), - pushEnabled: true, - } - - s.state.registerConn(sc) - defer s.state.unregisterConn(sc) - - // The net/http package sets the write deadline from the - // http.Server.WriteTimeout during the TLS handshake, but then - // passes the connection off to us with the deadline already set. - // Write deadlines are set per stream in serverConn.newStream. - // Disarm the net.Conn write deadline here. - if sc.hs.WriteTimeout != 0 { - sc.conn.SetWriteDeadline(time.Time{}) - } - - if s.NewWriteScheduler != nil { - sc.writeSched = s.NewWriteScheduler() - } else { - sc.writeSched = newRoundRobinWriteScheduler() - } - - // These start at the RFC-specified defaults. If there is a higher - // configured value for inflow, that will be updated when we send a - // WINDOW_UPDATE shortly after sending SETTINGS. - sc.flow.add(initialWindowSize) - sc.inflow.add(initialWindowSize) - sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) - - fr := NewFramer(sc.bw, c) - fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) - fr.MaxHeaderListSize = sc.maxHeaderListSize() - fr.SetMaxReadFrameSize(s.maxReadFrameSize()) - sc.framer = fr - - if tc, ok := c.(connectionStater); ok { - sc.tlsState = new(tls.ConnectionState) - *sc.tlsState = tc.ConnectionState() - // 9.2 Use of TLS Features - // An implementation of HTTP/2 over TLS MUST use TLS - // 1.2 or higher with the restrictions on feature set - // and cipher suite described in this section. Due to - // implementation limitations, it might not be - // possible to fail TLS negotiation. An endpoint MUST - // immediately terminate an HTTP/2 connection that - // does not meet the TLS requirements described in - // this section with a connection error (Section - // 5.4.1) of type INADEQUATE_SECURITY. - if sc.tlsState.Version < tls.VersionTLS12 { - sc.rejectConn(ErrCodeInadequateSecurity, "TLS version too low") - return - } - - if sc.tlsState.ServerName == "" { - // Client must use SNI, but we don't enforce that anymore, - // since it was causing problems when connecting to bare IP - // addresses during development. - // - // TODO: optionally enforce? Or enforce at the time we receive - // a new request, and verify the ServerName matches the :authority? - // But that precludes proxy situations, perhaps. - // - // So for now, do nothing here again. - } - - if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) { - // "Endpoints MAY choose to generate a connection error - // (Section 5.4.1) of type INADEQUATE_SECURITY if one of - // the prohibited cipher suites are negotiated." - // - // We choose that. In my opinion, the spec is weak - // here. It also says both parties must support at least - // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 so there's no - // excuses here. If we really must, we could allow an - // "AllowInsecureWeakCiphers" option on the server later. - // Let's see how it plays out first. - sc.rejectConn(ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite)) - return - } - } - - if hook := testHookGetServerConn; hook != nil { - hook(sc) - } - sc.serve() -} - -type serverInternalState struct { - mu sync.Mutex - activeConns map[*serverConn]struct{} -} - -func (s *serverInternalState) registerConn(sc *serverConn) { - if s == nil { - return // if the Server was used without calling ConfigureServer - } - s.mu.Lock() - s.activeConns[sc] = struct{}{} - s.mu.Unlock() -} - -func (s *serverInternalState) unregisterConn(sc *serverConn) { - if s == nil { - return // if the Server was used without calling ConfigureServer - } - s.mu.Lock() - delete(s.activeConns, sc) - s.mu.Unlock() -} - -func (s *serverInternalState) startGracefulShutdown() { - if s == nil { - return // if the Server was used without calling ConfigureServer - } - s.mu.Lock() - for sc := range s.activeConns { - sc.startGracefulShutdown() - } - s.mu.Unlock() -} - -// ServeConnOpts are options for the Server.ServeConn method. -type ServeConnOpts struct { - // Context is the base context to use. - // If nil, context.Background is used. - Context context.Context - - // BaseConfig optionally sets the base configuration - // for values. If nil, defaults are used. - BaseConfig *http.Server - - // Handler specifies which handler to use for processing - // requests. If nil, BaseConfig.Handler is used. If BaseConfig - // or BaseConfig.Handler is nil, http.DefaultServeMux is used. - Handler http.Handler -} - -func (o *ServeConnOpts) context() context.Context { - if o != nil && o.Context != nil { - return o.Context - } - return context.Background() -} - -func (o *ServeConnOpts) baseConfig() *http.Server { - if o != nil && o.BaseConfig != nil { - return o.BaseConfig - } - return new(http.Server) -} - -func (o *ServeConnOpts) handler() http.Handler { - if o != nil { - if o.Handler != nil { - return o.Handler - } - if o.BaseConfig != nil && o.BaseConfig.Handler != nil { - return o.BaseConfig.Handler - } - } - return http.DefaultServeMux -} - -func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) { - ctx, cancel = context.WithCancel(opts.context()) - ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr()) - if hs := opts.baseConfig(); hs != nil { - ctx = context.WithValue(ctx, http.ServerContextKey, hs) - } - return -} - -// bufferedWriter is a buffered writer that writes to w. -// Its buffered writer is lazily allocated as needed, to minimize -// idle memory usage with many connections. -type bufferedWriter struct { - _ incomparable - w io.Writer // immutable - bw *bufio.Writer // non-nil when data is buffered -} - -func newBufferedWriter(w io.Writer) *bufferedWriter { - return &bufferedWriter{w: w} -} - -func (w *bufferedWriter) Available() int { - if w.bw == nil { - return bufWriterPoolBufferSize - } - return w.bw.Available() -} - -func (w *bufferedWriter) Write(p []byte) (n int, err error) { - if w.bw == nil { - bw := bufWriterPool.Get().(*bufio.Writer) - bw.Reset(w.w) - w.bw = bw - } - return w.bw.Write(p) -} - -func (w *bufferedWriter) Flush() error { - bw := w.bw - if bw == nil { - return nil - } - err := bw.Flush() - bw.Reset(nil) - bufWriterPool.Put(bw) - w.bw = nil - return err -} - -func (sc *serverConn) rejectConn(err ErrCode, debug string) { - sc.vlogf("http2: server rejecting conn: %v, %s", err, debug) - // ignoring errors. hanging up anyway. - sc.framer.WriteGoAway(0, err, []byte(debug)) - sc.bw.Flush() - sc.conn.Close() -} - -type serverConn struct { - // Immutable: - srv *Server - hs *http.Server - conn net.Conn - bw *bufferedWriter // writing to conn - handler http.Handler - baseCtx context.Context - framer *Framer - doneServing chan struct{} // closed when serverConn.serve ends - readFrameCh chan readFrameResult // written by serverConn.readFrames - wantWriteFrameCh chan FrameWriteRequest // from handlers -> serve - wroteFrameCh chan frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes - bodyReadCh chan bodyReadMsg // from handlers -> serve - serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop - flow outflow // conn-wide (not stream-specific) outbound flow control - inflow inflow // conn-wide inbound flow control - tlsState *tls.ConnectionState // shared by all handlers, like net/http - remoteAddrStr string - writeSched WriteScheduler - - // Everything following is owned by the serve loop; use serveG.check(): - serveG goroutineLock // used to verify funcs are on serve() - pushEnabled bool - sawFirstSettings bool // got the initial SETTINGS frame after the preface - needToSendSettingsAck bool - unackedSettings int // how many SETTINGS have we sent without ACKs? - queuedControlFrames int // control frames in the writeSched queue - clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) - advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client - curClientStreams uint32 // number of open streams initiated by the client - curPushedStreams uint32 // number of open streams initiated by server push - maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests - maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes - streams map[uint32]*stream - initialStreamSendWindowSize int32 - maxFrameSize int32 - headerTableSize uint32 - peerMaxHeaderListSize uint32 // zero means unknown (default) - canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case - writingFrame bool // started writing a frame (on serve goroutine or separate) - writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh - needsFrameFlush bool // last frame write wasn't a flush - inGoAway bool // we've started to or sent GOAWAY - inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop - needToSendGoAway bool // we need to schedule a GOAWAY frame write - goAwayCode ErrCode - shutdownTimer *time.Timer // nil until used - idleTimer *time.Timer // nil if unused - - // Owned by the writeFrameAsync goroutine: - headerWriteBuf bytes.Buffer - hpackEncoder *hpack.Encoder - - // Used by startGracefulShutdown. - shutdownOnce sync.Once -} - -func (sc *serverConn) maxHeaderListSize() uint32 { - n := sc.hs.MaxHeaderBytes - if n <= 0 { - n = http.DefaultMaxHeaderBytes - } - // http2's count is in a slightly different unit and includes 32 bytes per pair. - // So, take the net/http.Server value and pad it up a bit, assuming 10 headers. - const perFieldOverhead = 32 // per http2 spec - const typicalHeaders = 10 // conservative - return uint32(n + typicalHeaders*perFieldOverhead) -} - -func (sc *serverConn) curOpenStreams() uint32 { - sc.serveG.check() - return sc.curClientStreams + sc.curPushedStreams -} - -// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). -type closeWaiter chan struct{} - -// Init makes a closeWaiter usable. -// It exists because so a closeWaiter value can be placed inside a -// larger struct and have the Mutex and Cond's memory in the same -// allocation. -func (cw *closeWaiter) Init() { - *cw = make(chan struct{}) -} - -// Close marks the closeWaiter as closed and unblocks any waiters. -func (cw closeWaiter) Close() { - close(cw) -} - -// Wait waits for the closeWaiter to become closed. -func (cw closeWaiter) Wait() { - <-cw -} - -// stream represents a stream. This is the minimal metadata needed by -// the serve goroutine. Most of the actual stream state is owned by -// the http.Handler's goroutine in the responseWriter. Because the -// responseWriter's responseWriterState is recycled at the end of a -// handler, this struct intentionally has no pointer to the -// *responseWriter{,State} itself, as the Handler ending nils out the -// responseWriter's state field. -type stream struct { - // immutable: - sc *serverConn - id uint32 - body *pipe // non-nil if expecting DATA frames - cw closeWaiter // closed wait stream transitions to closed state - ctx context.Context - cancelCtx func() - - // owned by serverConn's serve loop: - bodyBytes int64 // body bytes seen so far - declBodyBytes int64 // or -1 if undeclared - flow outflow // limits writing from Handler to client - inflow inflow // what the client is allowed to POST/etc to us - state streamState - resetQueued bool // RST_STREAM queued for write; set by sc.resetStream - gotTrailerHeader bool // HEADER frame for trailers was seen - wroteHeaders bool // whether we wrote headers (not status 100) - writeDeadline *time.Timer // nil if unused - - trailer http.Header // accumulated trailers - reqTrailer http.Header // handler's Request.Trailer -} - -func (sc *serverConn) Framer() *Framer { return sc.framer } - -func (sc *serverConn) CloseConn() error { return sc.conn.Close() } - -func (sc *serverConn) Flush() error { return sc.bw.Flush() } - -func (sc *serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) { - return sc.hpackEncoder, &sc.headerWriteBuf -} - -const ( - // SETTINGS_MAX_FRAME_SIZE default - // http://http2.github.io/http2-spec/#rfc.section.6.5.2 - initialMaxFrameSize = 16384 - - defaultMaxReadFrameSize = 1 << 20 -) - -type streamState int - -// HTTP/2 stream states. -// -// See http://tools.ietf.org/html/rfc7540#section-5.1. -// -// For simplicity, the server code merges "reserved (local)" into -// "half-closed (remote)". This is one less state transition to track. -// The only downside is that we send PUSH_PROMISEs slightly less -// liberally than allowable. More discussion here: -// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html -// -// "reserved (remote)" is omitted since the client code does not -// support server push. -const ( - stateIdle streamState = iota - stateOpen - stateHalfClosedLocal - stateHalfClosedRemote - stateClosed -) - -var stateName = [...]string{ - stateIdle: "Idle", - stateOpen: "Open", - stateHalfClosedLocal: "HalfClosedLocal", - stateHalfClosedRemote: "HalfClosedRemote", - stateClosed: "Closed", -} - -func (st streamState) String() string { - return stateName[st] -} - -func (sc *serverConn) state(streamID uint32) (streamState, *stream) { - sc.serveG.check() - // http://tools.ietf.org/html/rfc7540#section-5.1 - if st, ok := sc.streams[streamID]; ok { - return st.state, st - } - // "The first use of a new stream identifier implicitly closes all - // streams in the "idle" state that might have been initiated by - // that peer with a lower-valued stream identifier. For example, if - // a client sends a HEADERS frame on stream 7 without ever sending a - // frame on stream 5, then stream 5 transitions to the "closed" - // state when the first frame for stream 7 is sent or received." - if streamID%2 == 1 { - if streamID <= sc.maxClientStreamID { - return stateClosed, nil - } - } else { - if streamID <= sc.maxPushPromiseID { - return stateClosed, nil - } - } - return stateIdle, nil -} - -// setConnState calls the net/http ConnState hook for this connection, if configured. -// Note that the net/http package does StateNew and StateClosed for us. -// There is currently no plan for StateHijacked or hijacking HTTP/2 connections. -func (sc *serverConn) setConnState(state http.ConnState) { - if sc.hs.ConnState != nil { - sc.hs.ConnState(sc.conn, state) - } -} - -func (sc *serverConn) vlogf(format string, args ...interface{}) { - if VerboseLogs { - sc.logf(format, args...) - } -} - -func (sc *serverConn) logf(format string, args ...interface{}) { - if lg := sc.hs.ErrorLog; lg != nil { - lg.Printf(format, args...) - } else { - log.Printf(format, args...) - } -} - -// errno returns v's underlying uintptr, else 0. -// -// TODO: remove this helper function once http2 can use build -// tags. See comment in isClosedConnError. -func errno(v error) uintptr { - if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr { - return uintptr(rv.Uint()) - } - return 0 -} - -// isClosedConnError reports whether err is an error from use of a closed -// network connection. -func isClosedConnError(err error) bool { - if err == nil { - return false - } - - // TODO: remove this string search and be more like the Windows - // case below. That might involve modifying the standard library - // to return better error types. - str := err.Error() - if strings.Contains(str, "use of closed network connection") { - return true - } - - // TODO(bradfitz): x/tools/cmd/bundle doesn't really support - // build tags, so I can't make an _windows.go file with - // Windows-specific stuff. Fix that and move this, once we - // have a way to bundle this into std's net/http somehow. - if runtime.GOOS == "windows" { - if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { - if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" { - const WSAECONNABORTED = 10053 - const WSAECONNRESET = 10054 - if n := errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED { - return true - } - } - } - } - return false -} - -func (sc *serverConn) condlogf(err error, format string, args ...interface{}) { - if err == nil { - return - } - if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err) || err == errPrefaceTimeout { - // Boring, expected errors. - sc.vlogf(format, args...) - } else { - sc.logf(format, args...) - } -} - -func (sc *serverConn) canonicalHeader(v string) string { - sc.serveG.check() - buildCommonHeaderMapsOnce() - cv, ok := commonCanonHeader[v] - if ok { - return cv - } - cv, ok = sc.canonHeader[v] - if ok { - return cv - } - if sc.canonHeader == nil { - sc.canonHeader = make(map[string]string) - } - cv = http.CanonicalHeaderKey(v) - // maxCachedCanonicalHeaders is an arbitrarily-chosen limit on the number of - // entries in the canonHeader cache. This should be larger than the number - // of unique, uncommon header keys likely to be sent by the peer, while not - // so high as to permit unreasonable memory usage if the peer sends an unbounded - // number of unique header keys. - const maxCachedCanonicalHeaders = 32 - if len(sc.canonHeader) < maxCachedCanonicalHeaders { - sc.canonHeader[v] = cv - } - return cv -} - -type readFrameResult struct { - f Frame // valid until readMore is called - err error - - // readMore should be called once the consumer no longer needs or - // retains f. After readMore, f is invalid and more frames can be - // read. - readMore func() -} - -// A gate lets two goroutines coordinate their activities. -type gate chan struct{} - -func (g gate) Done() { g <- struct{}{} } - -func (g gate) Wait() { <-g } - -// readFrames is the loop that reads incoming frames. -// It takes care to only read one frame at a time, blocking until the -// consumer is done with the frame. -// It's run on its own goroutine. -func (sc *serverConn) readFrames() { - gate := make(gate) - gateDone := gate.Done - for { - f, err := sc.framer.ReadFrame() - select { - case sc.readFrameCh <- readFrameResult{f, err, gateDone}: - case <-sc.doneServing: - return - } - select { - case <-gate: - case <-sc.doneServing: - return - } - if terminalReadFrameError(err) { - return - } - } -} - -// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. -type frameWriteResult struct { - _ incomparable - wr FrameWriteRequest // what was written (or attempted) - err error // result of the writeFrame call -} - -// writeFrameAsync runs in its own goroutine and writes a single frame -// and then reports when it's done. -// At most one goroutine can be running writeFrameAsync at a time per -// serverConn. -func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest) { - err := wr.write.writeFrame(sc) - sc.wroteFrameCh <- frameWriteResult{wr: wr, err: err} -} - -func (sc *serverConn) closeAllStreamsOnConnClose() { - sc.serveG.check() - for _, st := range sc.streams { - sc.closeStream(st, errClientDisconnected) - } -} - -func (sc *serverConn) stopShutdownTimer() { - sc.serveG.check() - if t := sc.shutdownTimer; t != nil { - t.Stop() - } -} - -func (sc *serverConn) notePanic() { - // Note: this is for serverConn.serve panicking, not http.Handler code. - if testHookOnPanicMu != nil { - testHookOnPanicMu.Lock() - defer testHookOnPanicMu.Unlock() - } - if testHookOnPanic != nil { - if e := recover(); e != nil { - if testHookOnPanic(sc, e) { - panic(e) - } - } - } -} - -func (sc *serverConn) serve() { - sc.serveG.check() - defer sc.notePanic() - defer sc.conn.Close() - defer sc.closeAllStreamsOnConnClose() - defer sc.stopShutdownTimer() - defer close(sc.doneServing) // unblocks handlers trying to send - - if VerboseLogs { - sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) - } - - sc.writeFrame(FrameWriteRequest{ - write: writeSettings{ - {SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, - {SettingMaxConcurrentStreams, sc.advMaxStreams}, - {SettingMaxHeaderListSize, sc.maxHeaderListSize()}, - {SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, - }, - }) - sc.unackedSettings++ - - // Each connection starts with initialWindowSize inflow tokens. - // If a higher value is configured, we add more tokens. - if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 { - sc.sendWindowUpdate(nil, int(diff)) - } - - if err := sc.readPreface(); err != nil { - sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err) - return - } - // Now that we've got the preface, get us out of the - // "StateNew" state. We can't go directly to idle, though. - // Active means we read some data and anticipate a request. We'll - // do another Active when we get a HEADERS frame. - sc.setConnState(http.StateActive) - sc.setConnState(http.StateIdle) - - if sc.srv.IdleTimeout != 0 { - sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) - defer sc.idleTimer.Stop() - } - - go sc.readFrames() // closed by defer sc.conn.Close above - - settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer) - defer settingsTimer.Stop() - - loopNum := 0 - for { - loopNum++ - select { - case wr := <-sc.wantWriteFrameCh: - if se, ok := wr.write.(StreamError); ok { - sc.resetStream(se) - break - } - sc.writeFrame(wr) - case res := <-sc.wroteFrameCh: - sc.wroteFrame(res) - case res := <-sc.readFrameCh: - // Process any written frames before reading new frames from the client since a - // written frame could have triggered a new stream to be started. - if sc.writingFrameAsync { - select { - case wroteRes := <-sc.wroteFrameCh: - sc.wroteFrame(wroteRes) - default: - } - } - if !sc.processFrameFromReader(res) { - return - } - res.readMore() - if settingsTimer != nil { - settingsTimer.Stop() - settingsTimer = nil - } - case m := <-sc.bodyReadCh: - sc.noteBodyRead(m.st, m.n) - case msg := <-sc.serveMsgCh: - switch v := msg.(type) { - case func(int): - v(loopNum) // for testing - case *serverMessage: - switch v { - case settingsTimerMsg: - sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) - return - case idleTimerMsg: - sc.vlogf("connection is idle") - sc.goAway(ErrCodeNo) - case shutdownTimerMsg: - sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) - return - case gracefulShutdownMsg: - sc.startGracefulShutdownInternal() - default: - panic("unknown timer") - } - case *startPushRequest: - sc.startPush(v) - default: - panic(fmt.Sprintf("unexpected type %T", v)) - } - } - - // If the peer is causing us to generate a lot of control frames, - // but not reading them from us, assume they are trying to make us - // run out of memory. - if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() { - sc.vlogf("http2: too many control frames in send queue, closing connection") - return - } - - // Start the shutdown timer after sending a GOAWAY. When sending GOAWAY - // with no error code (graceful shutdown), don't start the timer until - // all open streams have been completed. - sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame - gracefulShutdownComplete := sc.goAwayCode == ErrCodeNo && sc.curOpenStreams() == 0 - if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != ErrCodeNo || gracefulShutdownComplete) { - sc.shutDownIn(goAwayTimeout) - } - } -} - -func (sc *serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) { - select { - case <-sc.doneServing: - case <-sharedCh: - close(privateCh) - } -} - -type serverMessage int - -// Message values sent to serveMsgCh. -var ( - settingsTimerMsg = new(serverMessage) - idleTimerMsg = new(serverMessage) - shutdownTimerMsg = new(serverMessage) - gracefulShutdownMsg = new(serverMessage) -) - -func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) } - -func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) } - -func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) } - -func (sc *serverConn) sendServeMsg(msg interface{}) { - sc.serveG.checkNotOn() // NOT - select { - case sc.serveMsgCh <- msg: - case <-sc.doneServing: - } -} - -var errPrefaceTimeout = errors.New("timeout waiting for client preface") - -// readPreface reads the ClientPreface greeting from the peer or -// returns errPrefaceTimeout on timeout, or an error if the greeting -// is invalid. -func (sc *serverConn) readPreface() error { - errc := make(chan error, 1) - go func() { - // Read the client preface - buf := make([]byte, len(ClientPreface)) - if _, err := io.ReadFull(sc.conn, buf); err != nil { - errc <- err - } else if !bytes.Equal(buf, clientPreface) { - errc <- fmt.Errorf("bogus greeting %q", buf) - } else { - errc <- nil - } - }() - timer := time.NewTimer(prefaceTimeout) // TODO: configurable on *Server? - defer timer.Stop() - select { - case <-timer.C: - return errPrefaceTimeout - case err := <-errc: - if err == nil { - if VerboseLogs { - sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr()) - } - } - return err - } -} - -var errChanPool = sync.Pool{ - New: func() interface{} { return make(chan error, 1) }, -} - -var writeDataPool = sync.Pool{ - New: func() interface{} { return new(writeData) }, -} - -// writeDataFromHandler writes DATA response frames from a handler on -// the given stream. -func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStream bool) error { - ch := errChanPool.Get().(chan error) - writeArg := writeDataPool.Get().(*writeData) - *writeArg = writeData{stream.id, data, endStream} - err := sc.writeFrameFromHandler(FrameWriteRequest{ - write: writeArg, - stream: stream, - done: ch, - }) - if err != nil { - return err - } - var frameWriteDone bool // the frame write is done (successfully or not) - select { - case err = <-ch: - frameWriteDone = true - case <-sc.doneServing: - return errClientDisconnected - case <-stream.cw: - // If both ch and stream.cw were ready (as might - // happen on the final Write after an http.Handler - // ends), prefer the write result. Otherwise this - // might just be us successfully closing the stream. - // The writeFrameAsync and serve goroutines guarantee - // that the ch send will happen before the stream.cw - // close. - select { - case err = <-ch: - frameWriteDone = true - default: - return errStreamClosed - } - } - errChanPool.Put(ch) - if frameWriteDone { - writeDataPool.Put(writeArg) - } - return err -} - -// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts -// if the connection has gone away. -// -// This must not be run from the serve goroutine itself, else it might -// deadlock writing to sc.wantWriteFrameCh (which is only mildly -// buffered and is read by serve itself). If you're on the serve -// goroutine, call writeFrame instead. -func (sc *serverConn) writeFrameFromHandler(wr FrameWriteRequest) error { - sc.serveG.checkNotOn() // NOT - select { - case sc.wantWriteFrameCh <- wr: - return nil - case <-sc.doneServing: - // Serve loop is gone. - // Client has closed their connection to the server. - return errClientDisconnected - } -} - -// writeFrame schedules a frame to write and sends it if there's nothing -// already being written. -// -// There is no pushback here (the serve goroutine never blocks). It's -// the http.Handlers that block, waiting for their previous frames to -// make it onto the wire -// -// If you're not on the serve goroutine, use writeFrameFromHandler instead. -func (sc *serverConn) writeFrame(wr FrameWriteRequest) { - sc.serveG.check() - - // If true, wr will not be written and wr.done will not be signaled. - var ignoreWrite bool - - // We are not allowed to write frames on closed streams. RFC 7540 Section - // 5.1.1 says: "An endpoint MUST NOT send frames other than PRIORITY on - // a closed stream." Our server never sends PRIORITY, so that exception - // does not apply. - // - // The serverConn might close an open stream while the stream's handler - // is still running. For example, the server might close a stream when it - // receives bad data from the client. If this happens, the handler might - // attempt to write a frame after the stream has been closed (since the - // handler hasn't yet been notified of the close). In this case, we simply - // ignore the frame. The handler will notice that the stream is closed when - // it waits for the frame to be written. - // - // As an exception to this rule, we allow sending RST_STREAM after close. - // This allows us to immediately reject new streams without tracking any - // state for those streams (except for the queued RST_STREAM frame). This - // may result in duplicate RST_STREAMs in some cases, but the client should - // ignore those. - if wr.StreamID() != 0 { - _, isReset := wr.write.(StreamError) - if state, _ := sc.state(wr.StreamID()); state == stateClosed && !isReset { - ignoreWrite = true - } - } - - // Don't send a 100-continue response if we've already sent headers. - // See golang.org/issue/14030. - switch wr.write.(type) { - case *writeResHeaders: - wr.stream.wroteHeaders = true - case write100ContinueHeadersFrame: - if wr.stream.wroteHeaders { - // We do not need to notify wr.done because this frame is - // never written with wr.done != nil. - if wr.done != nil { - panic("wr.done != nil for write100ContinueHeadersFrame") - } - ignoreWrite = true - } - } - - if !ignoreWrite { - if wr.isControl() { - sc.queuedControlFrames++ - // For extra safety, detect wraparounds, which should not happen, - // and pull the plug. - if sc.queuedControlFrames < 0 { - sc.conn.Close() - } - } - sc.writeSched.Push(wr) - } - sc.scheduleFrameWrite() -} - -// startFrameWrite starts a goroutine to write wr (in a separate -// goroutine since that might block on the network), and updates the -// serve goroutine's state about the world, updated from info in wr. -func (sc *serverConn) startFrameWrite(wr FrameWriteRequest) { - sc.serveG.check() - if sc.writingFrame { - panic("internal error: can only be writing one frame at a time") - } - - st := wr.stream - if st != nil { - switch st.state { - case stateHalfClosedLocal: - switch wr.write.(type) { - case StreamError, handlerPanicRST, writeWindowUpdate: - // RFC 7540 Section 5.1 allows sending RST_STREAM, PRIORITY, and WINDOW_UPDATE - // in this state. (We never send PRIORITY from the server, so that is not checked.) - default: - panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr)) - } - case stateClosed: - panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr)) - } - } - if wpp, ok := wr.write.(*writePushPromise); ok { - var err error - wpp.promisedID, err = wpp.allocatePromisedID() - if err != nil { - sc.writingFrameAsync = false - wr.replyToWriter(err) - return - } - } - - sc.writingFrame = true - sc.needsFrameFlush = true - if wr.write.staysWithinBuffer(sc.bw.Available()) { - sc.writingFrameAsync = false - err := wr.write.writeFrame(sc) - sc.wroteFrame(frameWriteResult{wr: wr, err: err}) - } else { - sc.writingFrameAsync = true - go sc.writeFrameAsync(wr) - } -} - -// errHandlerPanicked is the error given to any callers blocked in a read from -// Request.Body when the main goroutine panics. Since most handlers read in the -// main ServeHTTP goroutine, this will show up rarely. -var errHandlerPanicked = errors.New("http2: handler panicked") - -// wroteFrame is called on the serve goroutine with the result of -// whatever happened on writeFrameAsync. -func (sc *serverConn) wroteFrame(res frameWriteResult) { - sc.serveG.check() - if !sc.writingFrame { - panic("internal error: expected to be already writing a frame") - } - sc.writingFrame = false - sc.writingFrameAsync = false - - wr := res.wr - - if writeEndsStream(wr.write) { - st := wr.stream - if st == nil { - panic("internal error: expecting non-nil stream") - } - switch st.state { - case stateOpen: - // Here we would go to stateHalfClosedLocal in - // theory, but since our handler is done and - // the net/http package provides no mechanism - // for closing a ResponseWriter while still - // reading data (see possible TODO at top of - // this file), we go into closed state here - // anyway, after telling the peer we're - // hanging up on them. We'll transition to - // stateClosed after the RST_STREAM frame is - // written. - st.state = stateHalfClosedLocal - // Section 8.1: a server MAY request that the client abort - // transmission of a request without error by sending a - // RST_STREAM with an error code of NO_ERROR after sending - // a complete response. - sc.resetStream(streamError(st.id, ErrCodeNo)) - case stateHalfClosedRemote: - sc.closeStream(st, errHandlerComplete) - } - } else { - switch v := wr.write.(type) { - case StreamError: - // st may be unknown if the RST_STREAM was generated to reject bad input. - if st, ok := sc.streams[v.StreamID]; ok { - sc.closeStream(st, v) - } - case handlerPanicRST: - sc.closeStream(wr.stream, errHandlerPanicked) - } - } - - // Reply (if requested) to unblock the ServeHTTP goroutine. - wr.replyToWriter(res.err) - - sc.scheduleFrameWrite() -} - -// scheduleFrameWrite tickles the frame writing scheduler. -// -// If a frame is already being written, nothing happens. This will be called again -// when the frame is done being written. -// -// If a frame isn't being written and we need to send one, the best frame -// to send is selected by writeSched. -// -// If a frame isn't being written and there's nothing else to send, we -// flush the write buffer. -func (sc *serverConn) scheduleFrameWrite() { - sc.serveG.check() - if sc.writingFrame || sc.inFrameScheduleLoop { - return - } - sc.inFrameScheduleLoop = true - for !sc.writingFrameAsync { - if sc.needToSendGoAway { - sc.needToSendGoAway = false - sc.startFrameWrite(FrameWriteRequest{ - write: &writeGoAway{ - maxStreamID: sc.maxClientStreamID, - code: sc.goAwayCode, - }, - }) - continue - } - if sc.needToSendSettingsAck { - sc.needToSendSettingsAck = false - sc.startFrameWrite(FrameWriteRequest{write: writeSettingsAck{}}) - continue - } - if !sc.inGoAway || sc.goAwayCode == ErrCodeNo { - if wr, ok := sc.writeSched.Pop(); ok { - if wr.isControl() { - sc.queuedControlFrames-- - } - sc.startFrameWrite(wr) - continue - } - } - if sc.needsFrameFlush { - sc.startFrameWrite(FrameWriteRequest{write: flushFrameWriter{}}) - sc.needsFrameFlush = false // after startFrameWrite, since it sets this true - continue - } - break - } - sc.inFrameScheduleLoop = false -} - -// startGracefulShutdown gracefully shuts down a connection. This -// sends GOAWAY with ErrCodeNo to tell the client we're gracefully -// shutting down. The connection isn't closed until all current -// streams are done. -// -// startGracefulShutdown returns immediately; it does not wait until -// the connection has shut down. -func (sc *serverConn) startGracefulShutdown() { - sc.serveG.checkNotOn() // NOT - sc.shutdownOnce.Do(func() { sc.sendServeMsg(gracefulShutdownMsg) }) -} - -// After sending GOAWAY with an error code (non-graceful shutdown), the -// connection will close after goAwayTimeout. -// -// If we close the connection immediately after sending GOAWAY, there may -// be unsent data in our kernel receive buffer, which will cause the kernel -// to send a TCP RST on close() instead of a FIN. This RST will abort the -// connection immediately, whether or not the client had received the GOAWAY. -// -// Ideally we should delay for at least 1 RTT + epsilon so the client has -// a chance to read the GOAWAY and stop sending messages. Measuring RTT -// is hard, so we approximate with 1 second. See golang.org/issue/18701. -// -// This is a var so it can be shorter in tests, where all requests uses the -// loopback interface making the expected RTT very small. -// -// TODO: configurable? -var goAwayTimeout = 1 * time.Second - -func (sc *serverConn) startGracefulShutdownInternal() { - sc.goAway(ErrCodeNo) -} - -func (sc *serverConn) goAway(code ErrCode) { - sc.serveG.check() - if sc.inGoAway { - return - } - sc.inGoAway = true - sc.needToSendGoAway = true - sc.goAwayCode = code - sc.scheduleFrameWrite() -} - -func (sc *serverConn) shutDownIn(d time.Duration) { - sc.serveG.check() - sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) -} - -func (sc *serverConn) resetStream(se StreamError) { - sc.serveG.check() - sc.writeFrame(FrameWriteRequest{write: se}) - if st, ok := sc.streams[se.StreamID]; ok { - st.resetQueued = true - } -} - -// 6.9.1 The Flow Control Window -// "If a sender receives a WINDOW_UPDATE that causes a flow control -// window to exceed this maximum it MUST terminate either the stream -// or the connection, as appropriate. For streams, [...]; for the -// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code." -type goAwayFlowError struct{} - -func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" } - -// processFrameFromReader processes the serve loop's read from readFrameCh from the -// frame-reading goroutine. -// processFrameFromReader returns whether the connection should be kept open. -func (sc *serverConn) processFrameFromReader(res readFrameResult) bool { - sc.serveG.check() - err := res.err - if err != nil { - if err == errFrameTooLarge { - sc.goAway(ErrCodeFrameSize) - return true // goAway will close the loop - } - clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err) - if clientGone { - // TODO: could we also get into this state if - // the peer does a half close - // (e.g. CloseWrite) because they're done - // sending frames but they're still wanting - // our open replies? Investigate. - // TODO: add CloseWrite to crypto/tls.Conn first - // so we have a way to test this? I suppose - // just for testing we could have a non-TLS mode. - return false - } - } else { - f := res.f - if VerboseLogs { - sc.vlogf("http2: server read frame %v", summarizeFrame(f)) - } - err = sc.processFrame(f) - if err == nil { - return true - } - } - - switch ev := err.(type) { - case StreamError: - sc.resetStream(ev) - return true - case goAwayFlowError: - sc.goAway(ErrCodeFlowControl) - return true - case ConnectionError: - sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev) - sc.goAway(ErrCode(ev)) - return true // goAway will handle shutdown - default: - if res.err != nil { - sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err) - } else { - sc.logf("http2: server closing client connection: %v", err) - } - return false - } -} - -func (sc *serverConn) processFrame(f Frame) error { - sc.serveG.check() - - // First frame received must be SETTINGS. - if !sc.sawFirstSettings { - if _, ok := f.(*SettingsFrame); !ok { - return sc.countError("first_settings", ConnectionError(ErrCodeProtocol)) - } - sc.sawFirstSettings = true - } - - switch f := f.(type) { - case *SettingsFrame: - return sc.processSettings(f) - case *MetaHeadersFrame: - return sc.processHeaders(f) - case *WindowUpdateFrame: - return sc.processWindowUpdate(f) - case *PingFrame: - return sc.processPing(f) - case *DataFrame: - return sc.processData(f) - case *RSTStreamFrame: - return sc.processResetStream(f) - case *PriorityFrame: - return sc.processPriority(f) - case *GoAwayFrame: - return sc.processGoAway(f) - case *PushPromiseFrame: - // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE - // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. - return sc.countError("push_promise", ConnectionError(ErrCodeProtocol)) - default: - sc.vlogf("http2: server ignoring frame: %v", f.Header()) - return nil - } -} - -func (sc *serverConn) processPing(f *PingFrame) error { - sc.serveG.check() - if f.IsAck() { - // 6.7 PING: " An endpoint MUST NOT respond to PING frames - // containing this flag." - return nil - } - if f.StreamID != 0 { - // "PING frames are not associated with any individual - // stream. If a PING frame is received with a stream - // identifier field value other than 0x0, the recipient MUST - // respond with a connection error (Section 5.4.1) of type - // PROTOCOL_ERROR." - return sc.countError("ping_on_stream", ConnectionError(ErrCodeProtocol)) - } - if sc.inGoAway && sc.goAwayCode != ErrCodeNo { - return nil - } - sc.writeFrame(FrameWriteRequest{write: writePingAck{f}}) - return nil -} - -func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error { - sc.serveG.check() - switch { - case f.StreamID != 0: // stream-level flow control - state, st := sc.state(f.StreamID) - if state == stateIdle { - // Section 5.1: "Receiving any frame other than HEADERS - // or PRIORITY on a stream in this state MUST be - // treated as a connection error (Section 5.4.1) of - // type PROTOCOL_ERROR." - return sc.countError("stream_idle", ConnectionError(ErrCodeProtocol)) - } - if st == nil { - // "WINDOW_UPDATE can be sent by a peer that has sent a - // frame bearing the END_STREAM flag. This means that a - // receiver could receive a WINDOW_UPDATE frame on a "half - // closed (remote)" or "closed" stream. A receiver MUST - // NOT treat this as an error, see Section 5.1." - return nil - } - if !st.flow.add(int32(f.Increment)) { - return sc.countError("bad_flow", streamError(f.StreamID, ErrCodeFlowControl)) - } - default: // connection-level flow control - if !sc.flow.add(int32(f.Increment)) { - return goAwayFlowError{} - } - } - sc.scheduleFrameWrite() - return nil -} - -func (sc *serverConn) processResetStream(f *RSTStreamFrame) error { - sc.serveG.check() - - state, st := sc.state(f.StreamID) - if state == stateIdle { - // 6.4 "RST_STREAM frames MUST NOT be sent for a - // stream in the "idle" state. If a RST_STREAM frame - // identifying an idle stream is received, the - // recipient MUST treat this as a connection error - // (Section 5.4.1) of type PROTOCOL_ERROR. - return sc.countError("reset_idle_stream", ConnectionError(ErrCodeProtocol)) - } - if st != nil { - st.cancelCtx() - sc.closeStream(st, streamError(f.StreamID, f.ErrCode)) - } - return nil -} - -func (sc *serverConn) closeStream(st *stream, err error) { - sc.serveG.check() - if st.state == stateIdle || st.state == stateClosed { - panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) - } - st.state = stateClosed - if st.writeDeadline != nil { - st.writeDeadline.Stop() - } - if st.isPushed() { - sc.curPushedStreams-- - } else { - sc.curClientStreams-- - } - delete(sc.streams, st.id) - if len(sc.streams) == 0 { - sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { - sc.idleTimer.Reset(sc.srv.IdleTimeout) - } - if h1ServerKeepAlivesDisabled(sc.hs) { - sc.startGracefulShutdownInternal() - } - } - if p := st.body; p != nil { - // Return any buffered unread bytes worth of conn-level flow control. - // See golang.org/issue/16481 - sc.sendWindowUpdate(nil, p.Len()) - - p.CloseWithError(err) - } - st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc - sc.writeSched.CloseStream(st.id) -} - -func (sc *serverConn) processSettings(f *SettingsFrame) error { - sc.serveG.check() - if f.IsAck() { - sc.unackedSettings-- - if sc.unackedSettings < 0 { - // Why is the peer ACKing settings we never sent? - // The spec doesn't mention this case, but - // hang up on them anyway. - return sc.countError("ack_mystery", ConnectionError(ErrCodeProtocol)) - } - return nil - } - if f.NumSettings() > 100 || f.HasDuplicates() { - // This isn't actually in the spec, but hang up on - // suspiciously large settings frames or those with - // duplicate entries. - return sc.countError("settings_big_or_dups", ConnectionError(ErrCodeProtocol)) - } - if err := f.ForeachSetting(sc.processSetting); err != nil { - return err - } - // TODO: judging by RFC 7540, Section 6.5.3 each SETTINGS frame should be - // acknowledged individually, even if multiple are received before the ACK. - sc.needToSendSettingsAck = true - sc.scheduleFrameWrite() - return nil -} - -func (sc *serverConn) processSetting(s Setting) error { - sc.serveG.check() - if err := s.Valid(); err != nil { - return err - } - if VerboseLogs { - sc.vlogf("http2: server processing setting %v", s) - } - switch s.ID { - case SettingHeaderTableSize: - sc.headerTableSize = s.Val - sc.hpackEncoder.SetMaxDynamicTableSize(s.Val) - case SettingEnablePush: - sc.pushEnabled = s.Val != 0 - case SettingMaxConcurrentStreams: - sc.clientMaxStreams = s.Val - case SettingInitialWindowSize: - return sc.processSettingInitialWindowSize(s.Val) - case SettingMaxFrameSize: - sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31 - case SettingMaxHeaderListSize: - sc.peerMaxHeaderListSize = s.Val - default: - // Unknown setting: "An endpoint that receives a SETTINGS - // frame with any unknown or unsupported identifier MUST - // ignore that setting." - if VerboseLogs { - sc.vlogf("http2: server ignoring unknown setting %v", s) - } - } - return nil -} - -func (sc *serverConn) processSettingInitialWindowSize(val uint32) error { - sc.serveG.check() - // Note: val already validated to be within range by - // processSetting's Valid call. - - // "A SETTINGS frame can alter the initial flow control window - // size for all current streams. When the value of - // SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST - // adjust the size of all stream flow control windows that it - // maintains by the difference between the new value and the - // old value." - old := sc.initialStreamSendWindowSize - sc.initialStreamSendWindowSize = int32(val) - growth := int32(val) - old // may be negative - for _, st := range sc.streams { - if !st.flow.add(growth) { - // 6.9.2 Initial Flow Control Window Size - // "An endpoint MUST treat a change to - // SETTINGS_INITIAL_WINDOW_SIZE that causes any flow - // control window to exceed the maximum size as a - // connection error (Section 5.4.1) of type - // FLOW_CONTROL_ERROR." - return sc.countError("setting_win_size", ConnectionError(ErrCodeFlowControl)) - } - } - return nil -} - -func (sc *serverConn) processData(f *DataFrame) error { - sc.serveG.check() - id := f.Header().StreamID - if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || id > sc.maxClientStreamID) { - // Discard all DATA frames if the GOAWAY is due to an - // error, or: - // - // Section 6.8: After sending a GOAWAY frame, the sender - // can discard frames for streams initiated by the - // receiver with identifiers higher than the identified - // last stream. - return nil - } - - data := f.Data() - state, st := sc.state(id) - if id == 0 || state == stateIdle { - // Section 6.1: "DATA frames MUST be associated with a - // stream. If a DATA frame is received whose stream - // identifier field is 0x0, the recipient MUST respond - // with a connection error (Section 5.4.1) of type - // PROTOCOL_ERROR." - // - // Section 5.1: "Receiving any frame other than HEADERS - // or PRIORITY on a stream in this state MUST be - // treated as a connection error (Section 5.4.1) of - // type PROTOCOL_ERROR." - return sc.countError("data_on_idle", ConnectionError(ErrCodeProtocol)) - } - - // "If a DATA frame is received whose stream is not in "open" - // or "half closed (local)" state, the recipient MUST respond - // with a stream error (Section 5.4.2) of type STREAM_CLOSED." - if st == nil || state != stateOpen || st.gotTrailerHeader || st.resetQueued { - // This includes sending a RST_STREAM if the stream is - // in stateHalfClosedLocal (which currently means that - // the http.Handler returned, so it's done reading & - // done writing). Try to stop the client from sending - // more DATA. - - // But still enforce their connection-level flow control, - // and return any flow control bytes since we're not going - // to consume them. - if !sc.inflow.take(f.Length) { - return sc.countError("data_flow", streamError(id, ErrCodeFlowControl)) - } - sc.sendWindowUpdate(nil, int(f.Length)) // conn-level - - if st != nil && st.resetQueued { - // Already have a stream error in flight. Don't send another. - return nil - } - return sc.countError("closed", streamError(id, ErrCodeStreamClosed)) - } - if st.body == nil { - panic("internal error: should have a body in this state") - } - - // Sender sending more than they'd declared? - if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { - if !sc.inflow.take(f.Length) { - return sc.countError("data_flow", streamError(id, ErrCodeFlowControl)) - } - sc.sendWindowUpdate(nil, int(f.Length)) // conn-level - - st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) - // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the - // value of a content-length header field does not equal the sum of the - // DATA frame payload lengths that form the body. - return sc.countError("send_too_much", streamError(id, ErrCodeProtocol)) - } - if f.Length > 0 { - // Check whether the client has flow control quota. - if !takeInflows(&sc.inflow, &st.inflow, f.Length) { - return sc.countError("flow_on_data_length", streamError(id, ErrCodeFlowControl)) - } - - if len(data) > 0 { - wrote, err := st.body.Write(data) - if err != nil { - sc.sendWindowUpdate(nil, int(f.Length)-wrote) - return sc.countError("body_write_err", streamError(id, ErrCodeStreamClosed)) - } - if wrote != len(data) { - panic("internal error: bad Writer") - } - st.bodyBytes += int64(len(data)) - } - - // Return any padded flow control now, since we won't - // refund it later on body reads. - // Call sendWindowUpdate even if there is no padding, - // to return buffered flow control credit if the sent - // window has shrunk. - pad := int32(f.Length) - int32(len(data)) - sc.sendWindowUpdate32(nil, pad) - sc.sendWindowUpdate32(st, pad) - } - if f.StreamEnded() { - st.endStream() - } - return nil -} - -func (sc *serverConn) processGoAway(f *GoAwayFrame) error { - sc.serveG.check() - if f.ErrCode != ErrCodeNo { - sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f) - } else { - sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) - } - sc.startGracefulShutdownInternal() - // http://tools.ietf.org/html/rfc7540#section-6.8 - // We should not create any new streams, which means we should disable push. - sc.pushEnabled = false - return nil -} - -// isPushed reports whether the stream is server-initiated. -func (st *stream) isPushed() bool { - return st.id%2 == 0 -} - -// endStream closes a Request.Body's pipe. It is called when a DATA -// frame says a request body is over (or after trailers). -func (st *stream) endStream() { - sc := st.sc - sc.serveG.check() - - if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes { - st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes", - st.declBodyBytes, st.bodyBytes)) - } else { - st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest) - st.body.CloseWithError(io.EOF) - } - st.state = stateHalfClosedRemote -} - -// copyTrailersToHandlerRequest is run in the Handler's goroutine in -// its Request.Body.Read just before it gets io.EOF. -func (st *stream) copyTrailersToHandlerRequest() { - for k, vv := range st.trailer { - if _, ok := st.reqTrailer[k]; ok { - // Only copy it over it was pre-declared. - st.reqTrailer[k] = vv - } - } -} - -// onWriteTimeout is run on its own goroutine (from time.AfterFunc) -// when the stream's WriteTimeout has fired. -func (st *stream) onWriteTimeout() { - st.sc.writeFrameFromHandler(FrameWriteRequest{write: streamError(st.id, ErrCodeInternal)}) -} - -func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { - sc.serveG.check() - id := f.StreamID - if sc.inGoAway { - // Ignore. - return nil - } - // http://tools.ietf.org/html/rfc7540#section-5.1.1 - // Streams initiated by a client MUST use odd-numbered stream - // identifiers. [...] An endpoint that receives an unexpected - // stream identifier MUST respond with a connection error - // (Section 5.4.1) of type PROTOCOL_ERROR. - if id%2 != 1 { - return sc.countError("headers_even", ConnectionError(ErrCodeProtocol)) - } - // A HEADERS frame can be used to create a new stream or - // send a trailer for an open one. If we already have a stream - // open, let it process its own HEADERS frame (trailers at this - // point, if it's valid). - if st := sc.streams[f.StreamID]; st != nil { - if st.resetQueued { - // We're sending RST_STREAM to close the stream, so don't bother - // processing this frame. - return nil - } - // RFC 7540, sec 5.1: If an endpoint receives additional frames, other than - // WINDOW_UPDATE, PRIORITY, or RST_STREAM, for a stream that is in - // this state, it MUST respond with a stream error (Section 5.4.2) of - // type STREAM_CLOSED. - if st.state == stateHalfClosedRemote { - return sc.countError("headers_half_closed", streamError(id, ErrCodeStreamClosed)) - } - return st.processTrailerHeaders(f) - } - - // [...] The identifier of a newly established stream MUST be - // numerically greater than all streams that the initiating - // endpoint has opened or reserved. [...] An endpoint that - // receives an unexpected stream identifier MUST respond with - // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. - if id <= sc.maxClientStreamID { - return sc.countError("stream_went_down", ConnectionError(ErrCodeProtocol)) - } - sc.maxClientStreamID = id - - if sc.idleTimer != nil { - sc.idleTimer.Stop() - } - - // http://tools.ietf.org/html/rfc7540#section-5.1.2 - // [...] Endpoints MUST NOT exceed the limit set by their peer. An - // endpoint that receives a HEADERS frame that causes their - // advertised concurrent stream limit to be exceeded MUST treat - // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR - // or REFUSED_STREAM. - if sc.curClientStreams+1 > sc.advMaxStreams { - if sc.unackedSettings == 0 { - // They should know better. - return sc.countError("over_max_streams", streamError(id, ErrCodeProtocol)) - } - // Assume it's a network race, where they just haven't - // received our last SETTINGS update. But actually - // this can't happen yet, because we don't yet provide - // a way for users to adjust server parameters at - // runtime. - return sc.countError("over_max_streams_race", streamError(id, ErrCodeRefusedStream)) - } - - initialState := stateOpen - if f.StreamEnded() { - initialState = stateHalfClosedRemote - } - st := sc.newStream(id, 0, initialState) - - if f.HasPriority() { - if err := sc.checkPriority(f.StreamID, f.Priority); err != nil { - return err - } - sc.writeSched.AdjustStream(st.id, f.Priority) - } - - rw, req, err := sc.newWriterAndRequest(st, f) - if err != nil { - return err - } - st.reqTrailer = req.Trailer - if st.reqTrailer != nil { - st.trailer = make(http.Header) - } - st.body = req.Body.(*requestBody).pipe // may be nil - st.declBodyBytes = req.ContentLength - - handler := sc.handler.ServeHTTP - if f.Truncated { - // Their header list was too long. Send a 431 error. - handler = handleHeaderListTooLong - } else if err := checkValidHTTP2RequestHeaders(req.Header); err != nil { - handler = new400Handler(err) - } - - // The net/http package sets the read deadline from the - // http.Server.ReadTimeout during the TLS handshake, but then - // passes the connection off to us with the deadline already - // set. Disarm it here after the request headers are read, - // similar to how the http1 server works. Here it's - // technically more like the http1 Server's ReadHeaderTimeout - // (in Go 1.8), though. That's a more sane option anyway. - if sc.hs.ReadTimeout != 0 { - sc.conn.SetReadDeadline(time.Time{}) - } - - go sc.runHandler(rw, req, handler) - return nil -} - -func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error { - sc := st.sc - sc.serveG.check() - if st.gotTrailerHeader { - return sc.countError("dup_trailers", ConnectionError(ErrCodeProtocol)) - } - st.gotTrailerHeader = true - if !f.StreamEnded() { - return sc.countError("trailers_not_ended", streamError(st.id, ErrCodeProtocol)) - } - - if len(f.PseudoFields()) > 0 { - return sc.countError("trailers_pseudo", streamError(st.id, ErrCodeProtocol)) - } - if st.trailer != nil { - for _, hf := range f.RegularFields() { - key := sc.canonicalHeader(hf.Name) - if !httpguts.ValidTrailerHeader(key) { - // TODO: send more details to the peer somehow. But http2 has - // no way to send debug data at a stream level. Discuss with - // HTTP folk. - return sc.countError("trailers_bogus", streamError(st.id, ErrCodeProtocol)) - } - st.trailer[key] = append(st.trailer[key], hf.Value) - } - } - st.endStream() - return nil -} - -func (sc *serverConn) checkPriority(streamID uint32, p PriorityParam) error { - if streamID == p.StreamDep { - // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat - // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR." - // Section 5.3.3 says that a stream can depend on one of its dependencies, - // so it's only self-dependencies that are forbidden. - return sc.countError("priority", streamError(streamID, ErrCodeProtocol)) - } - return nil -} - -func (sc *serverConn) processPriority(f *PriorityFrame) error { - if sc.inGoAway { - return nil - } - if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil { - return err - } - sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam) - return nil -} - -func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream { - sc.serveG.check() - if id == 0 { - panic("internal error: cannot create stream with id 0") - } - - ctx, cancelCtx := context.WithCancel(sc.baseCtx) - st := &stream{ - sc: sc, - id: id, - state: state, - ctx: ctx, - cancelCtx: cancelCtx, - } - st.cw.Init() - st.flow.conn = &sc.flow // link to conn-level counter - st.flow.add(sc.initialStreamSendWindowSize) - st.inflow.init(sc.srv.initialStreamRecvWindowSize()) - if sc.hs.WriteTimeout != 0 { - st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) - } - - sc.streams[id] = st - sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID}) - if st.isPushed() { - sc.curPushedStreams++ - } else { - sc.curClientStreams++ - } - if sc.curOpenStreams() == 1 { - sc.setConnState(http.StateActive) - } - - return st -} - -func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) { - sc.serveG.check() - - rp := requestParam{ - method: f.PseudoValue("method"), - scheme: f.PseudoValue("scheme"), - authority: f.PseudoValue("authority"), - path: f.PseudoValue("path"), - } - - isConnect := rp.method == "CONNECT" - if isConnect { - if rp.path != "" || rp.scheme != "" || rp.authority == "" { - return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol)) - } - } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { - // See 8.1.2.6 Malformed Requests and Responses: - // - // Malformed requests or responses that are detected - // MUST be treated as a stream error (Section 5.4.2) - // of type PROTOCOL_ERROR." - // - // 8.1.2.3 Request Pseudo-Header Fields - // "All HTTP/2 requests MUST include exactly one valid - // value for the :method, :scheme, and :path - // pseudo-header fields" - return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol)) - } - - bodyOpen := !f.StreamEnded() - if rp.method == "HEAD" && bodyOpen { - // HEAD requests can't have bodies - return nil, nil, sc.countError("head_body", streamError(f.StreamID, ErrCodeProtocol)) - } - - rp.header = make(http.Header) - for _, hf := range f.RegularFields() { - rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) - } - if rp.authority == "" { - rp.authority = rp.header.Get("Host") - } - - rw, req, err := sc.newWriterAndRequestNoBody(st, rp) - if err != nil { - return nil, nil, err - } - if bodyOpen { - if vv, ok := rp.header["Content-Length"]; ok { - if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil { - req.ContentLength = int64(cl) - } else { - req.ContentLength = 0 - } - } else { - req.ContentLength = -1 - } - req.Body.(*requestBody).pipe = &pipe{ - b: &dataBuffer{expected: req.ContentLength}, - } - } - return rw, req, nil -} - -type requestParam struct { - method string - scheme, authority, path string - header http.Header -} - -func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) { - sc.serveG.check() - - var tlsState *tls.ConnectionState // nil if not scheme https - if rp.scheme == "https" { - tlsState = sc.tlsState - } - - needsContinue := rp.header.Get("Expect") == "100-continue" - if needsContinue { - rp.header.Del("Expect") - } - // Merge Cookie headers into one "; "-delimited value. - if cookies := rp.header["Cookie"]; len(cookies) > 1 { - rp.header.Set("Cookie", strings.Join(cookies, "; ")) - } - - // Setup Trailers - var trailer http.Header - for _, v := range rp.header["Trailer"] { - for _, key := range strings.Split(v, ",") { - key = http.CanonicalHeaderKey(textproto.TrimString(key)) - switch key { - case "Transfer-Encoding", "Trailer", "Content-Length": - // Bogus. (copy of http1 rules) - // Ignore. - default: - if trailer == nil { - trailer = make(http.Header) - } - trailer[key] = nil - } - } - } - delete(rp.header, "Trailer") - - var u *url.URL - var requestURI string - if rp.method == "CONNECT" { - u = &url.URL{Host: rp.authority} - requestURI = rp.authority // mimic HTTP/1 server behavior - } else { - var err error - u, err = url.ParseRequestURI(rp.path) - if err != nil { - return nil, nil, sc.countError("bad_path", streamError(st.id, ErrCodeProtocol)) - } - requestURI = rp.path - } - - body := &requestBody{ - conn: sc, - stream: st, - needsContinue: needsContinue, - } - req := &http.Request{ - Method: rp.method, - URL: u, - RemoteAddr: sc.remoteAddrStr, - Header: rp.header, - RequestURI: requestURI, - Proto: "HTTP/2.0", - ProtoMajor: 2, - ProtoMinor: 0, - TLS: tlsState, - Host: rp.authority, - Body: body, - Trailer: trailer, - } - req = req.WithContext(st.ctx) - - rws := responseWriterStatePool.Get().(*responseWriterState) - bwSave := rws.bw - *rws = responseWriterState{} // zero all the fields - rws.conn = sc - rws.bw = bwSave - rws.bw.Reset(chunkWriter{rws}) - rws.stream = st - rws.req = req - rws.body = body - - rw := &responseWriter{rws: rws} - return rw, req, nil -} - -// Run on its own goroutine. -func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { - didPanic := true - defer func() { - rw.rws.stream.cancelCtx() - if didPanic { - e := recover() - sc.writeFrameFromHandler(FrameWriteRequest{ - write: handlerPanicRST{rw.rws.stream.id}, - stream: rw.rws.stream, - }) - // Same as net/http: - if e != nil && e != http.ErrAbortHandler { - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) - } - return - } - rw.handlerDone() - }() - handler(rw, req) - didPanic = false -} - -func handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) { - // 10.5.1 Limits on Header Block Size: - // .. "A server that receives a larger header block than it is - // willing to handle can send an HTTP 431 (Request Header Fields Too - // Large) status code" - const statusRequestHeaderFieldsTooLarge = 431 // only in Go 1.6+ - w.WriteHeader(statusRequestHeaderFieldsTooLarge) - io.WriteString(w, "

HTTP Error 431

Request Header Field(s) Too Large

") -} - -// called from handler goroutines. -// h may be nil. -func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) error { - sc.serveG.checkNotOn() // NOT on - var errc chan error - if headerData.h != nil { - // If there's a header map (which we don't own), so we have to block on - // waiting for this frame to be written, so an http.Flush mid-handler - // writes out the correct value of keys, before a handler later potentially - // mutates it. - errc = errChanPool.Get().(chan error) - } - if err := sc.writeFrameFromHandler(FrameWriteRequest{ - write: headerData, - stream: st, - done: errc, - }); err != nil { - return err - } - if errc != nil { - select { - case err := <-errc: - errChanPool.Put(errc) - return err - case <-sc.doneServing: - return errClientDisconnected - case <-st.cw: - return errStreamClosed - } - } - return nil -} - -// called from handler goroutines. -func (sc *serverConn) write100ContinueHeaders(st *stream) { - sc.writeFrameFromHandler(FrameWriteRequest{ - write: write100ContinueHeadersFrame{st.id}, - stream: st, - }) -} - -// A bodyReadMsg tells the server loop that the http.Handler read n -// bytes of the DATA from the client on the given stream. -type bodyReadMsg struct { - st *stream - n int -} - -// called from handler goroutines. -// Notes that the handler for the given stream ID read n bytes of its body -// and schedules flow control tokens to be sent. -func (sc *serverConn) noteBodyReadFromHandler(st *stream, n int, err error) { - sc.serveG.checkNotOn() // NOT on - if n > 0 { - select { - case sc.bodyReadCh <- bodyReadMsg{st, n}: - case <-sc.doneServing: - } - } -} - -func (sc *serverConn) noteBodyRead(st *stream, n int) { - sc.serveG.check() - sc.sendWindowUpdate(nil, n) // conn-level - if st.state != stateHalfClosedRemote && st.state != stateClosed { - // Don't send this WINDOW_UPDATE if the stream is closed - // remotely. - sc.sendWindowUpdate(st, n) - } -} - -// st may be nil for conn-level -func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) { - sc.sendWindowUpdate(st, int(n)) -} - -// st may be nil for conn-level -func (sc *serverConn) sendWindowUpdate(st *stream, n int) { - sc.serveG.check() - var streamID uint32 - var send int32 - if st == nil { - send = sc.inflow.add(n) - } else { - streamID = st.id - send = st.inflow.add(n) - } - if send == 0 { - return - } - sc.writeFrame(FrameWriteRequest{ - write: writeWindowUpdate{streamID: streamID, n: uint32(send)}, - stream: st, - }) -} - -// requestBody is the Handler's Request.Body type. -// Read and Close may be called concurrently. -type requestBody struct { - _ incomparable - stream *stream - conn *serverConn - closed bool // for use by Close only - sawEOF bool // for use by Read only - pipe *pipe // non-nil if we have a HTTP entity message body - needsContinue bool // need to send a 100-continue -} - -func (b *requestBody) Close() error { - if b.pipe != nil && !b.closed { - b.pipe.BreakWithError(errClosedBody) - } - b.closed = true - return nil -} - -func (b *requestBody) Read(p []byte) (n int, err error) { - if b.needsContinue { - b.needsContinue = false - b.conn.write100ContinueHeaders(b.stream) - } - if b.pipe == nil || b.sawEOF { - return 0, io.EOF - } - n, err = b.pipe.Read(p) - if err == io.EOF { - b.sawEOF = true - } - if b.conn == nil && inTests { - return - } - b.conn.noteBodyReadFromHandler(b.stream, n, err) - return -} - -// responseWriter is the http.ResponseWriter implementation. It's -// intentionally small (1 pointer wide) to minimize garbage. The -// responseWriterState pointer inside is zeroed at the end of a -// request (in handlerDone) and calls on the responseWriter thereafter -// simply crash (caller's mistake), but the much larger responseWriterState -// and buffers are reused between multiple requests. -type responseWriter struct { - rws *responseWriterState -} - -// from pkg io -type stringWriter interface { - WriteString(s string) (n int, err error) -} - -// Optional http.ResponseWriter interfaces implemented. -var ( - _ http.CloseNotifier = (*responseWriter)(nil) - _ http.Flusher = (*responseWriter)(nil) - _ stringWriter = (*responseWriter)(nil) -) - -type responseWriterState struct { - // immutable within a request: - stream *stream - req *http.Request - body *requestBody // to close at end of request, if DATA frames didn't - conn *serverConn - - // TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc - bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState} - - // mutated by http.Handler goroutine: - handlerHeader http.Header // nil until called - snapHeader http.Header // snapshot of handlerHeader at WriteHeader time - trailers []string // set in writeChunk - status int // status code passed to WriteHeader - wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. - sentHeader bool // have we sent the header frame? - handlerDone bool // handler has finished - dirty bool // a Write failed; don't reuse this responseWriterState - - sentContentLen int64 // non-zero if handler set a Content-Length header - wroteBytes int64 - - closeNotifierMu sync.Mutex // guards closeNotifierCh - closeNotifierCh chan bool // nil until first used -} - -type chunkWriter struct{ rws *responseWriterState } - -func (cw chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } - -func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 } - -func (rws *responseWriterState) hasNonemptyTrailers() bool { - for _, trailer := range rws.trailers { - if _, ok := rws.handlerHeader[trailer]; ok { - return true - } - } - return false -} - -// declareTrailer is called for each Trailer header when the -// response header is written. It notes that a header will need to be -// written in the trailers at the end of the response. -func (rws *responseWriterState) declareTrailer(k string) { - k = http.CanonicalHeaderKey(k) - if !httpguts.ValidTrailerHeader(k) { - // Forbidden by RFC 7230, section 4.1.2. - rws.conn.logf("ignoring invalid trailer %q", k) - return - } - if !strSliceContains(rws.trailers, k) { - rws.trailers = append(rws.trailers, k) - } -} - -// writeChunk writes chunks from the bufio.Writer. But because -// bufio.Writer may bypass its chunking, sometimes p may be -// arbitrarily large. -// -// writeChunk is also responsible (on the first chunk) for sending the -// HEADER response. -func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) { - if !rws.wroteHeader { - rws.writeHeader(200) - } - - isHeadResp := rws.req.Method == "HEAD" - if !rws.sentHeader { - rws.sentHeader = true - var ctype, clen string - if clen = rws.snapHeader.Get("Content-Length"); clen != "" { - rws.snapHeader.Del("Content-Length") - if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { - rws.sentContentLen = int64(cl) - } else { - clen = "" - } - } - if clen == "" && rws.handlerDone && bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) { - clen = strconv.Itoa(len(p)) - } - _, hasContentType := rws.snapHeader["Content-Type"] - // If the Content-Encoding is non-blank, we shouldn't - // sniff the body. See Issue golang.org/issue/31753. - ce := rws.snapHeader.Get("Content-Encoding") - hasCE := len(ce) > 0 - if !hasCE && !hasContentType && bodyAllowedForStatus(rws.status) && len(p) > 0 { - ctype = http.DetectContentType(p) - } - var date string - if _, ok := rws.snapHeader["Date"]; !ok { - // TODO(bradfitz): be faster here, like net/http? measure. - date = time.Now().UTC().Format(http.TimeFormat) - } - - for _, v := range rws.snapHeader["Trailer"] { - foreachHeaderElement(v, rws.declareTrailer) - } - - // "Connection" headers aren't allowed in HTTP/2 (RFC 7540, 8.1.2.2), - // but respect "Connection" == "close" to mean sending a GOAWAY and tearing - // down the TCP connection when idle, like we do for HTTP/1. - // TODO: remove more Connection-specific header fields here, in addition - // to "Connection". - if _, ok := rws.snapHeader["Connection"]; ok { - v := rws.snapHeader.Get("Connection") - delete(rws.snapHeader, "Connection") - if v == "close" { - rws.conn.startGracefulShutdown() - } - } - - endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp - err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{ - streamID: rws.stream.id, - httpResCode: rws.status, - h: rws.snapHeader, - endStream: endStream, - contentType: ctype, - contentLength: clen, - date: date, - }) - if err != nil { - rws.dirty = true - return 0, err - } - if endStream { - return 0, nil - } - } - if isHeadResp { - return len(p), nil - } - if len(p) == 0 && !rws.handlerDone { - return 0, nil - } - - if rws.handlerDone { - rws.promoteUndeclaredTrailers() - } - - // only send trailers if they have actually been defined by the - // server handler. - hasNonemptyTrailers := rws.hasNonemptyTrailers() - endStream := rws.handlerDone && !hasNonemptyTrailers - if len(p) > 0 || endStream { - // only send a 0 byte DATA frame if we're ending the stream. - if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil { - rws.dirty = true - return 0, err - } - } - - if rws.handlerDone && hasNonemptyTrailers { - err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{ - streamID: rws.stream.id, - h: rws.handlerHeader, - trailers: rws.trailers, - endStream: true, - }) - if err != nil { - rws.dirty = true - } - return len(p), err - } - return len(p), nil -} - -// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys -// that, if present, signals that the map entry is actually for -// the response trailers, and not the response headers. The prefix -// is stripped after the ServeHTTP call finishes and the values are -// sent in the trailers. -// -// This mechanism is intended only for trailers that are not known -// prior to the headers being written. If the set of trailers is fixed -// or known before the header is written, the normal Go trailers mechanism -// is preferred: -// -// https://golang.org/pkg/net/http/#ResponseWriter -// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers -const TrailerPrefix = "Trailer:" - -// promoteUndeclaredTrailers permits http.Handlers to set trailers -// after the header has already been flushed. Because the Go -// ResponseWriter interface has no way to set Trailers (only the -// Header), and because we didn't want to expand the ResponseWriter -// interface, and because nobody used trailers, and because RFC 7230 -// says you SHOULD (but not must) predeclare any trailers in the -// header, the official ResponseWriter rules said trailers in Go must -// be predeclared, and then we reuse the same ResponseWriter.Header() -// map to mean both Headers and Trailers. When it's time to write the -// Trailers, we pick out the fields of Headers that were declared as -// trailers. That worked for a while, until we found the first major -// user of Trailers in the wild: gRPC (using them only over http2), -// and gRPC libraries permit setting trailers mid-stream without -// predeclaring them. So: change of plans. We still permit the old -// way, but we also permit this hack: if a Header() key begins with -// "Trailer:", the suffix of that key is a Trailer. Because ':' is an -// invalid token byte anyway, there is no ambiguity. (And it's already -// filtered out) It's mildly hacky, but not terrible. -// -// This method runs after the Handler is done and promotes any Header -// fields to be trailers. -func (rws *responseWriterState) promoteUndeclaredTrailers() { - for k, vv := range rws.handlerHeader { - if !strings.HasPrefix(k, TrailerPrefix) { - continue - } - trailerKey := strings.TrimPrefix(k, TrailerPrefix) - rws.declareTrailer(trailerKey) - rws.handlerHeader[http.CanonicalHeaderKey(trailerKey)] = vv - } - - if len(rws.trailers) > 1 { - sorter := sorterPool.Get().(*sorter) - sorter.SortStrings(rws.trailers) - sorterPool.Put(sorter) - } -} - -func (w *responseWriter) Flush() { - rws := w.rws - if rws == nil { - panic("Header called after Handler finished") - } - if rws.bw.Buffered() > 0 { - if err := rws.bw.Flush(); err != nil { - // Ignore the error. The frame writer already knows. - return - } - } else { - // The bufio.Writer won't call chunkWriter.Write - // (writeChunk with zero bytes, so we have to do it - // ourselves to force the HTTP response header and/or - // final DATA frame (with END_STREAM) to be sent. - rws.writeChunk(nil) - } -} - -func (w *responseWriter) CloseNotify() <-chan bool { - rws := w.rws - if rws == nil { - panic("CloseNotify called after Handler finished") - } - rws.closeNotifierMu.Lock() - ch := rws.closeNotifierCh - if ch == nil { - ch = make(chan bool, 1) - rws.closeNotifierCh = ch - cw := rws.stream.cw - go func() { - cw.Wait() // wait for close - ch <- true - }() - } - rws.closeNotifierMu.Unlock() - return ch -} - -func (w *responseWriter) Header() http.Header { - rws := w.rws - if rws == nil { - panic("Header called after Handler finished") - } - if rws.handlerHeader == nil { - rws.handlerHeader = make(http.Header) - } - return rws.handlerHeader -} - -// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode. -func checkWriteHeaderCode(code int) { - // Issue 22880: require valid WriteHeader status codes. - // For now we only enforce that it's three digits. - // In the future we might block things over 599 (600 and above aren't defined - // at http://httpwg.org/specs/rfc7231.html#status.codes) - // and we might block under 200 (once we have more mature 1xx support). - // But for now any three digits. - // - // We used to send "HTTP/1.1 000 0" on the wire in responses but there's - // no equivalent bogus thing we can realistically send in HTTP/2, - // so we'll consistently panic instead and help people find their bugs - // early. (We can't return an error from WriteHeader even if we wanted to.) - if code < 100 || code > 999 { - panic(fmt.Sprintf("invalid WriteHeader code %v", code)) - } -} - -func (w *responseWriter) WriteHeader(code int) { - rws := w.rws - if rws == nil { - panic("WriteHeader called after Handler finished") - } - rws.writeHeader(code) -} - -func (rws *responseWriterState) writeHeader(code int) { - if !rws.wroteHeader { - checkWriteHeaderCode(code) - rws.wroteHeader = true - rws.status = code - if len(rws.handlerHeader) > 0 { - rws.snapHeader = cloneHeader(rws.handlerHeader) - } - } -} - -func cloneHeader(h http.Header) http.Header { - h2 := make(http.Header, len(h)) - for k, vv := range h { - vv2 := make([]string, len(vv)) - copy(vv2, vv) - h2[k] = vv2 - } - return h2 -} - -// The Life Of A Write is like this: -// -// * Handler calls w.Write or w.WriteString -> -// * -> rws.bw (*bufio.Writer) -> -// * (Handler might call Flush) -// * -> chunkWriter{rws} -// * -> responseWriterState.writeChunk(p []byte) -// * -> responseWriterState.writeChunk (most of the magic; see comment there) -func (w *responseWriter) Write(p []byte) (n int, err error) { - return w.write(len(p), p, "") -} - -func (w *responseWriter) WriteString(s string) (n int, err error) { - return w.write(len(s), nil, s) -} - -// either dataB or dataS is non-zero. -func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) { - rws := w.rws - if rws == nil { - panic("Write called after Handler finished") - } - if !rws.wroteHeader { - w.WriteHeader(200) - } - if !bodyAllowedForStatus(rws.status) { - return 0, http.ErrBodyNotAllowed - } - rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) // only one can be set - if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen { - // TODO: send a RST_STREAM - return 0, errors.New("http2: handler wrote more than declared Content-Length") - } - - if dataB != nil { - return rws.bw.Write(dataB) - } - return rws.bw.WriteString(dataS) -} - -func (w *responseWriter) handlerDone() { - rws := w.rws - dirty := rws.dirty - rws.handlerDone = true - w.Flush() - w.rws = nil - if !dirty { - // Only recycle the pool if all prior Write calls to - // the serverConn goroutine completed successfully. If - // they returned earlier due to resets from the peer - // there might still be write goroutines outstanding - // from the serverConn referencing the rws memory. See - // issue 20704. - responseWriterStatePool.Put(rws) - } -} - -// Push errors. -var ( - errRecursivePush = errors.New("http2: recursive push not allowed") - errPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") -) - -var _ http.Pusher = (*responseWriter)(nil) - -func (w *responseWriter) Push(target string, opts *http.PushOptions) error { - st := w.rws.stream - sc := st.sc - sc.serveG.checkNotOn() - - // No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream." - // http://tools.ietf.org/html/rfc7540#section-6.6 - if st.isPushed() { - return errRecursivePush - } - - if opts == nil { - opts = new(http.PushOptions) - } - - // Default options. - if opts.Method == "" { - opts.Method = "GET" - } - if opts.Header == nil { - opts.Header = http.Header{} - } - wantScheme := "http" - if w.rws.req.TLS != nil { - wantScheme = "https" - } - - // Validate the request. - u, err := url.Parse(target) - if err != nil { - return err - } - if u.Scheme == "" { - if !strings.HasPrefix(target, "/") { - return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target) - } - u.Scheme = wantScheme - u.Host = w.rws.req.Host - } else { - if u.Scheme != wantScheme { - return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme) - } - if u.Host == "" { - return errors.New("URL must have a host") - } - } - for k := range opts.Header { - if strings.HasPrefix(k, ":") { - return fmt.Errorf("promised request headers cannot include pseudo header %q", k) - } - // These headers are meaningful only if the request has a body, - // but PUSH_PROMISE requests cannot have a body. - // http://tools.ietf.org/html/rfc7540#section-8.2 - // Also disallow Host, since the promised URL must be absolute. - if ascii.EqualFold(k, "content-length") || - ascii.EqualFold(k, "content-encoding") || - ascii.EqualFold(k, "trailer") || - ascii.EqualFold(k, "te") || - ascii.EqualFold(k, "expect") || - ascii.EqualFold(k, "host") { - return fmt.Errorf("promised request headers cannot include %q", k) - } - } - if err := checkValidHTTP2RequestHeaders(opts.Header); err != nil { - return err - } - - // The RFC effectively limits promised requests to GET and HEAD: - // "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]" - // http://tools.ietf.org/html/rfc7540#section-8.2 - if opts.Method != "GET" && opts.Method != "HEAD" { - return fmt.Errorf("method %q must be GET or HEAD", opts.Method) - } - - msg := &startPushRequest{ - parent: st, - method: opts.Method, - url: u, - header: cloneHeader(opts.Header), - done: errChanPool.Get().(chan error), - } - - select { - case <-sc.doneServing: - return errClientDisconnected - case <-st.cw: - return errStreamClosed - case sc.serveMsgCh <- msg: - } - - select { - case <-sc.doneServing: - return errClientDisconnected - case <-st.cw: - return errStreamClosed - case err := <-msg.done: - errChanPool.Put(msg.done) - return err - } -} - -type startPushRequest struct { - parent *stream - method string - url *url.URL - header http.Header - done chan error -} - -func (sc *serverConn) startPush(msg *startPushRequest) { - sc.serveG.check() - - // http://tools.ietf.org/html/rfc7540#section-6.6. - // PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that - // is in either the "open" or "half-closed (remote)" state. - if msg.parent.state != stateOpen && msg.parent.state != stateHalfClosedRemote { - // responseWriter.Push checks that the stream is peer-initiated. - msg.done <- errStreamClosed - return - } - - // http://tools.ietf.org/html/rfc7540#section-6.6. - if !sc.pushEnabled { - msg.done <- http.ErrNotSupported - return - } - - // PUSH_PROMISE frames must be sent in increasing order by stream ID, so - // we allocate an ID for the promised stream lazily, when the PUSH_PROMISE - // is written. Once the ID is allocated, we start the request handler. - allocatePromisedID := func() (uint32, error) { - sc.serveG.check() - - // Check this again, just in case. Technically, we might have received - // an updated SETTINGS by the time we got around to writing this frame. - if !sc.pushEnabled { - return 0, http.ErrNotSupported - } - // http://tools.ietf.org/html/rfc7540#section-6.5.2. - if sc.curPushedStreams+1 > sc.clientMaxStreams { - return 0, errPushLimitReached - } - - // http://tools.ietf.org/html/rfc7540#section-5.1.1. - // Streams initiated by the server MUST use even-numbered identifiers. - // A server that is unable to establish a new stream identifier can send a GOAWAY - // frame so that the client is forced to open a new connection for new streams. - if sc.maxPushPromiseID+2 >= 1<<31 { - sc.startGracefulShutdownInternal() - return 0, errPushLimitReached - } - sc.maxPushPromiseID += 2 - promisedID := sc.maxPushPromiseID - - // http://tools.ietf.org/html/rfc7540#section-8.2. - // Strictly speaking, the new stream should start in "reserved (local)", then - // transition to "half closed (remote)" after sending the initial HEADERS, but - // we start in "half closed (remote)" for simplicity. - // See further comments at the definition of stateHalfClosedRemote. - promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote) - rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{ - method: msg.method, - scheme: msg.url.Scheme, - authority: msg.url.Host, - path: msg.url.RequestURI(), - header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE - }) - if err != nil { - // Should not happen, since we've already validated msg.url. - panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) - } - - go sc.runHandler(rw, req, sc.handler.ServeHTTP) - return promisedID, nil - } - - sc.writeFrame(FrameWriteRequest{ - write: &writePushPromise{ - streamID: msg.parent.id, - method: msg.method, - url: msg.url, - h: msg.header, - allocatePromisedID: allocatePromisedID, - }, - stream: msg.parent, - done: msg.done, - }) -} - -// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 -var connHeaders = []string{ - "Connection", - "Keep-Alive", - "Proxy-Connection", - "Transfer-Encoding", - "Upgrade", -} - -// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, -// per RFC 7540 Section 8.1.2.2. -// The returned error is reported to users. -func checkValidHTTP2RequestHeaders(h http.Header) error { - for _, k := range connHeaders { - if _, ok := h[k]; ok { - return fmt.Errorf("request header %q is not valid in HTTP/2", k) - } - } - te := h["Te"] - if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { - return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) - } - return nil -} - -func new400Handler(err error) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - http.Error(w, err.Error(), http.StatusBadRequest) - } -} - -// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives -// disabled. See comments on h1ServerShutdownChan above for why -// the code is written this way. -func h1ServerKeepAlivesDisabled(hs *http.Server) bool { - var x interface{} = hs - type I interface { - doKeepAlives() bool - } - if hs, ok := x.(I); ok { - return !hs.doKeepAlives() - } - return false -} - -func (sc *serverConn) countError(name string, err error) error { - if sc == nil || sc.srv == nil { - return err - } - f := sc.srv.CountError - if f == nil { - return err - } - var typ string - var code ErrCode - switch e := err.(type) { - case ConnectionError: - typ = "conn" - code = ErrCode(e) - case StreamError: - typ = "stream" - code = ErrCode(e.Code) - default: - return err - } - codeStr := errCodeName[code] - if codeStr == "" { - codeStr = strconv.Itoa(int(code)) - } - f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name)) - return err -} - -// writeFramer is implemented by any type that is used to write frames. -type writeFramer interface { - writeFrame(writeContext) error - - // staysWithinBuffer reports whether this writer promises that - // it will only write less than or equal to size bytes, and it - // won't Flush the write context. - staysWithinBuffer(size int) bool -} - -// writeContext is the interface needed by the various frame writer -// types below. All the writeFrame methods below are scheduled via the -// frame writing scheduler (see writeScheduler in writesched.go). -// -// This interface is implemented by *serverConn. -// -// TODO: decide whether to a) use this in the client code (which didn't -// end up using this yet, because it has a simpler design, not -// currently implementing priorities), or b) delete this and -// make the server code a bit more concrete. -type writeContext interface { - Framer() *Framer - Flush() error - CloseConn() error - // HeaderEncoder returns an HPACK encoder that writes to the - // returned buffer. - HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) -} - -// writeEndsStream reports whether w writes a frame that will transition -// the stream to a half-closed local state. This returns false for RST_STREAM, -// which closes the entire stream (not just the local half). -func writeEndsStream(w writeFramer) bool { - switch v := w.(type) { - case *writeData: - return v.endStream - case *writeResHeaders: - return v.endStream - case nil: - // This can only happen if the caller reuses w after it's - // been intentionally nil'ed out to prevent use. Keep this - // here to catch future refactoring breaking it. - panic("writeEndsStream called on nil writeFramer") - } - return false -} - -type flushFrameWriter struct{} - -func (flushFrameWriter) writeFrame(ctx writeContext) error { - return ctx.Flush() -} - -func (flushFrameWriter) staysWithinBuffer(max int) bool { return false } - -type writeSettings []Setting - -func (s writeSettings) staysWithinBuffer(max int) bool { - const settingSize = 6 // uint16 + uint32 - return frameHeaderLen+settingSize*len(s) <= max - -} - -func (s writeSettings) writeFrame(ctx writeContext) error { - return ctx.Framer().WriteSettings([]Setting(s)...) -} - -type writeGoAway struct { - maxStreamID uint32 - code ErrCode -} - -func (p *writeGoAway) writeFrame(ctx writeContext) error { - err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil) - ctx.Flush() // ignore error: we're hanging up on them anyway - return err -} - -func (*writeGoAway) staysWithinBuffer(max int) bool { return false } // flushes - -type writeData struct { - streamID uint32 - p []byte - endStream bool -} - -func (w *writeData) String() string { - return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream) -} - -func (w *writeData) writeFrame(ctx writeContext) error { - return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) -} - -func (w *writeData) staysWithinBuffer(max int) bool { - return frameHeaderLen+len(w.p) <= max -} - -// handlerPanicRST is the message sent from handler goroutines when -// the handler panics. -type handlerPanicRST struct { - StreamID uint32 -} - -func (hp handlerPanicRST) writeFrame(ctx writeContext) error { - return ctx.Framer().WriteRSTStream(hp.StreamID, ErrCodeInternal) -} - -func (hp handlerPanicRST) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max } - -func (se StreamError) writeFrame(ctx writeContext) error { - return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) -} - -func (se StreamError) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max } - -type writePingAck struct{ pf *PingFrame } - -func (w writePingAck) writeFrame(ctx writeContext) error { - return ctx.Framer().WritePing(true, w.pf.Data) -} - -func (w writePingAck) staysWithinBuffer(max int) bool { - return frameHeaderLen+len(w.pf.Data) <= max -} - -type writeSettingsAck struct{} - -func (writeSettingsAck) writeFrame(ctx writeContext) error { - return ctx.Framer().WriteSettingsAck() -} - -func (writeSettingsAck) staysWithinBuffer(max int) bool { return frameHeaderLen <= max } - -// splitHeaderBlock splits headerBlock into fragments so that each fragment fits -// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true -// for the first/last fragment, respectively. -func splitHeaderBlock(ctx writeContext, headerBlock []byte, fn func(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error) error { - // For now we're lazy and just pick the minimum MAX_FRAME_SIZE - // that all peers must support (16KB). Later we could care - // more and send larger frames if the peer advertised it, but - // there's little point. Most headers are small anyway (so we - // generally won't have CONTINUATION frames), and extra frames - // only waste 9 bytes anyway. - const maxFrameSize = 16384 - - first := true - for len(headerBlock) > 0 { - frag := headerBlock - if len(frag) > maxFrameSize { - frag = frag[:maxFrameSize] - } - headerBlock = headerBlock[len(frag):] - if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil { - return err - } - first = false - } - return nil -} - -// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames -// for HTTP response headers or trailers from a server handler. -type writeResHeaders struct { - streamID uint32 - httpResCode int // 0 means no ":status" line - h http.Header // may be nil - trailers []string // if non-nil, which keys of h to write. nil means all. - endStream bool - - date string - contentType string - contentLength string -} - -func encKV(enc *hpack.Encoder, k, v string) { - if VerboseLogs { - log.Printf("http2: server encoding header %q = %q", k, v) - } - enc.WriteField(hpack.HeaderField{Name: k, Value: v}) -} - -func (w *writeResHeaders) staysWithinBuffer(max int) bool { - // TODO: this is a common one. It'd be nice to return true - // here and get into the fast path if we could be clever and - // calculate the size fast enough, or at least a conservative - // upper bound that usually fires. (Maybe if w.h and - // w.trailers are nil, so we don't need to enumerate it.) - // Otherwise I'm afraid that just calculating the length to - // answer this question would be slower than the ~2µs benefit. - return false -} - -func (w *writeResHeaders) writeFrame(ctx writeContext) error { - enc, buf := ctx.HeaderEncoder() - buf.Reset() - - if w.httpResCode != 0 { - encKV(enc, ":status", httpCodeString(w.httpResCode)) - } - - encodeHeaders(enc, w.h, w.trailers) - - if w.contentType != "" { - encKV(enc, "content-type", w.contentType) - } - if w.contentLength != "" { - encKV(enc, "content-length", w.contentLength) - } - if w.date != "" { - encKV(enc, "date", w.date) - } - - headerBlock := buf.Bytes() - if len(headerBlock) == 0 && w.trailers == nil { - panic("unexpected empty hpack") - } - - return splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) -} - -func (w *writeResHeaders) writeHeaderBlock(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error { - if firstFrag { - return ctx.Framer().WriteHeaders(HeadersFrameParam{ - StreamID: w.streamID, - BlockFragment: frag, - EndStream: w.endStream, - EndHeaders: lastFrag, - }) - } - return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) -} - -// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. -type writePushPromise struct { - streamID uint32 // pusher stream - method string // for :method - url *url.URL // for :scheme, :authority, :path - h http.Header - - // Creates an ID for a pushed stream. This runs on serveG just before - // the frame is written. The returned ID is copied to promisedID. - allocatePromisedID func() (uint32, error) - promisedID uint32 -} - -func (w *writePushPromise) staysWithinBuffer(max int) bool { - // TODO: see writeResHeaders.staysWithinBuffer - return false -} - -func (w *writePushPromise) writeFrame(ctx writeContext) error { - enc, buf := ctx.HeaderEncoder() - buf.Reset() - - encKV(enc, ":method", w.method) - encKV(enc, ":scheme", w.url.Scheme) - encKV(enc, ":authority", w.url.Host) - encKV(enc, ":path", w.url.RequestURI()) - encodeHeaders(enc, w.h, nil) - - headerBlock := buf.Bytes() - if len(headerBlock) == 0 { - panic("unexpected empty hpack") - } - - return splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) -} - -func (w *writePushPromise) writeHeaderBlock(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error { - if firstFrag { - return ctx.Framer().WritePushPromise(PushPromiseParam{ - StreamID: w.streamID, - PromiseID: w.promisedID, - BlockFragment: frag, - EndHeaders: lastFrag, - }) - } - return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) -} - -type write100ContinueHeadersFrame struct { - streamID uint32 -} - -func (w write100ContinueHeadersFrame) writeFrame(ctx writeContext) error { - enc, buf := ctx.HeaderEncoder() - buf.Reset() - encKV(enc, ":status", "100") - return ctx.Framer().WriteHeaders(HeadersFrameParam{ - StreamID: w.streamID, - BlockFragment: buf.Bytes(), - EndStream: false, - EndHeaders: true, - }) -} - -func (w write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { - // Sloppy but conservative: - return 9+2*(len(":status")+len("100")) <= max -} - -type writeWindowUpdate struct { - streamID uint32 // or 0 for conn-level - n uint32 -} - -func (wu writeWindowUpdate) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max } - -func (wu writeWindowUpdate) writeFrame(ctx writeContext) error { - return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) -} - -// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) -// is encoded only if k is in keys. -func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { - if keys == nil { - sorter := sorterPool.Get().(*sorter) - // Using defer here, since the returned keys from the - // sorter.Keys method is only valid until the sorter - // is returned: - defer sorterPool.Put(sorter) - keys = sorter.Keys(h) - } - for _, k := range keys { - vv := h[k] - k, ascii := lowerHeader(k) - if !ascii { - // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header - // field names have to be ASCII characters (just as in HTTP/1.x). - continue - } - if !validWireHeaderFieldName(k) { - // Skip it as backup paranoia. Per - // golang.org/issue/14048, these should - // already be rejected at a higher level. - continue - } - isTE := k == "transfer-encoding" - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - // TODO: return an error? golang.org/issue/14048 - // For now just omit it. - continue - } - // TODO: more of "8.1.2.2 Connection-Specific Header Fields" - if isTE && v != "trailers" { - continue - } - encKV(enc, k, v) - } - } -} - -// WriteScheduler is the interface implemented by HTTP/2 write schedulers. -// Methods are never called concurrently. -type WriteScheduler interface { - // OpenStream opens a new stream in the write scheduler. - // It is illegal to call this with streamID=0 or with a streamID that is - // already open -- the call may panic. - OpenStream(streamID uint32, options OpenStreamOptions) - - // CloseStream closes a stream in the write scheduler. Any frames queued on - // this stream should be discarded. It is illegal to call this on a stream - // that is not open -- the call may panic. - CloseStream(streamID uint32) - - // AdjustStream adjusts the priority of the given stream. This may be called - // on a stream that has not yet been opened or has been closed. Note that - // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See: - // https://tools.ietf.org/html/rfc7540#section-5.1 - AdjustStream(streamID uint32, priority PriorityParam) - - // Push queues a frame in the scheduler. In most cases, this will not be - // called with wr.StreamID()!=0 unless that stream is currently open. The one - // exception is RST_STREAM frames, which may be sent on idle or closed streams. - Push(wr FrameWriteRequest) - - // Pop dequeues the next frame to write. Returns false if no frames can - // be written. Frames with a given wr.StreamID() are Pop'd in the same - // order they are Push'd, except RST_STREAM frames. No frames should be - // discarded except by CloseStream. - Pop() (wr FrameWriteRequest, ok bool) -} - -// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream. -type OpenStreamOptions struct { - // PusherID is zero if the stream was initiated by the client. Otherwise, - // PusherID names the stream that pushed the newly opened stream. - PusherID uint32 -} - -// FrameWriteRequest is a request to write a frame. -type FrameWriteRequest struct { - // write is the interface value that does the writing, once the - // WriteScheduler has selected this frame to write. The write - // functions are all defined in write.go. - write writeFramer - - // stream is the stream on which this frame will be written. - // nil for non-stream frames like PING and SETTINGS. - // nil for RST_STREAM streams, which use the StreamError.StreamID field instead. - stream *stream - - // done, if non-nil, must be a buffered channel with space for - // 1 message and is sent the return value from write (or an - // earlier error) when the frame has been written. - done chan error -} - -// StreamID returns the id of the stream this frame will be written to. -// 0 is used for non-stream frames such as PING and SETTINGS. -func (wr FrameWriteRequest) StreamID() uint32 { - if wr.stream == nil { - if se, ok := wr.write.(StreamError); ok { - // (*serverConn).resetStream doesn't set - // stream because it doesn't necessarily have - // one. So special case this type of write - // message. - return se.StreamID - } - return 0 - } - return wr.stream.id -} - -// isControl reports whether wr is a control frame for MaxQueuedControlFrames -// purposes. That includes non-stream frames and RST_STREAM frames. -func (wr FrameWriteRequest) isControl() bool { - return wr.stream == nil -} - -// DataSize returns the number of flow control bytes that must be consumed -// to write this entire frame. This is 0 for non-DATA frames. -func (wr FrameWriteRequest) DataSize() int { - if wd, ok := wr.write.(*writeData); ok { - return len(wd.p) - } - return 0 -} - -// Consume consumes min(n, available) bytes from this frame, where available -// is the number of flow control bytes available on the stream. Consume returns -// 0, 1, or 2 frames, where the integer return value gives the number of frames -// returned. -// -// If flow control prevents consuming any bytes, this returns (_, _, 0). If -// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this -// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and -// 'rest' contains the remaining bytes. The consumed bytes are deducted from the -// underlying stream's flow control budget. -func (wr FrameWriteRequest) Consume(n int32) (FrameWriteRequest, FrameWriteRequest, int) { - var empty FrameWriteRequest - - // Non-DATA frames are always consumed whole. - wd, ok := wr.write.(*writeData) - if !ok || len(wd.p) == 0 { - return wr, empty, 1 - } - - // Might need to split after applying limits. - allowed := wr.stream.flow.available() - if n < allowed { - allowed = n - } - if wr.stream.sc.maxFrameSize < allowed { - allowed = wr.stream.sc.maxFrameSize - } - if allowed <= 0 { - return empty, empty, 0 - } - if len(wd.p) > int(allowed) { - wr.stream.flow.take(allowed) - consumed := FrameWriteRequest{ - stream: wr.stream, - write: &writeData{ - streamID: wd.streamID, - p: wd.p[:allowed], - // Even if the original had endStream set, there - // are bytes remaining because len(wd.p) > allowed, - // so we know endStream is false. - endStream: false, - }, - // Our caller is blocking on the final DATA frame, not - // this intermediate frame, so no need to wait. - done: nil, - } - rest := FrameWriteRequest{ - stream: wr.stream, - write: &writeData{ - streamID: wd.streamID, - p: wd.p[allowed:], - endStream: wd.endStream, - }, - done: wr.done, - } - return consumed, rest, 2 - } - - // The frame is consumed whole. - // NB: This cast cannot overflow because allowed is <= math.MaxInt32. - wr.stream.flow.take(int32(len(wd.p))) - return wr, empty, 1 -} - -// String is for debugging only. -func (wr FrameWriteRequest) String() string { - var des string - if s, ok := wr.write.(fmt.Stringer); ok { - des = s.String() - } else { - des = fmt.Sprintf("%T", wr.write) - } - return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des) -} - -// replyToWriter sends err to wr.done and panics if the send must block -// This does nothing if wr.done is nil. -func (wr *FrameWriteRequest) replyToWriter(err error) { - if wr.done == nil { - return - } - select { - case wr.done <- err: - default: - panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write)) - } - wr.write = nil // prevent use (assume it's tainted after wr.done send) -} - -// writeQueue is used by implementations of WriteScheduler. -type writeQueue struct { - s []FrameWriteRequest - prev, next *writeQueue -} - -func (q *writeQueue) empty() bool { return len(q.s) == 0 } - -func (q *writeQueue) push(wr FrameWriteRequest) { - q.s = append(q.s, wr) -} - -func (q *writeQueue) shift() FrameWriteRequest { - if len(q.s) == 0 { - panic("invalid use of queue") - } - wr := q.s[0] - // TODO: less copy-happy queue. - copy(q.s, q.s[1:]) - q.s[len(q.s)-1] = FrameWriteRequest{} - q.s = q.s[:len(q.s)-1] - return wr -} - -// consume consumes up to n bytes from q.s[0]. If the frame is -// entirely consumed, it is removed from the queue. If the frame -// is partially consumed, the frame is kept with the consumed -// bytes removed. Returns true iff any bytes were consumed. -func (q *writeQueue) consume(n int32) (FrameWriteRequest, bool) { - if len(q.s) == 0 { - return FrameWriteRequest{}, false - } - consumed, rest, numresult := q.s[0].Consume(n) - switch numresult { - case 0: - return FrameWriteRequest{}, false - case 1: - q.shift() - case 2: - q.s[0] = rest - } - return consumed, true -} - -type writeQueuePool []*writeQueue - -// put inserts an unused writeQueue into the pool. - -// put inserts an unused writeQueue into the pool. -func (p *writeQueuePool) put(q *writeQueue) { - for i := range q.s { - q.s[i] = FrameWriteRequest{} - } - q.s = q.s[:0] - *p = append(*p, q) -} - -// get returns an empty writeQueue. -func (p *writeQueuePool) get() *writeQueue { - ln := len(*p) - if ln == 0 { - return new(writeQueue) - } - x := ln - 1 - q := (*p)[x] - (*p)[x] = nil - *p = (*p)[:x] - return q -} - -// RFC 7540, Section 5.3.5: the default weight is 16. -const priorityDefaultWeight = 15 // 16 = 15 + 1 - -// PriorityWriteSchedulerConfig configures a priorityWriteScheduler. -type PriorityWriteSchedulerConfig struct { - // MaxClosedNodesInTree controls the maximum number of closed streams to - // retain in the priority tree. Setting this to zero saves a small amount - // of memory at the cost of performance. - // - // See RFC 7540, Section 5.3.4: - // "It is possible for a stream to become closed while prioritization - // information ... is in transit. ... This potentially creates suboptimal - // prioritization, since the stream could be given a priority that is - // different from what is intended. To avoid these problems, an endpoint - // SHOULD retain stream prioritization state for a period after streams - // become closed. The longer state is retained, the lower the chance that - // streams are assigned incorrect or default priority values." - MaxClosedNodesInTree int - - // MaxIdleNodesInTree controls the maximum number of idle streams to - // retain in the priority tree. Setting this to zero saves a small amount - // of memory at the cost of performance. - // - // See RFC 7540, Section 5.3.4: - // Similarly, streams that are in the "idle" state can be assigned - // priority or become a parent of other streams. This allows for the - // creation of a grouping node in the dependency tree, which enables - // more flexible expressions of priority. Idle streams begin with a - // default priority (Section 5.3.5). - MaxIdleNodesInTree int - - // ThrottleOutOfOrderWrites enables write throttling to help ensure that - // data is delivered in priority order. This works around a race where - // stream B depends on stream A and both streams are about to call Write - // to queue DATA frames. If B wins the race, a naive scheduler would eagerly - // write as much data from B as possible, but this is suboptimal because A - // is a higher-priority stream. With throttling enabled, we write a small - // amount of data from B to minimize the amount of bandwidth that B can - // steal from A. - ThrottleOutOfOrderWrites bool -} - -// NewPriorityWriteScheduler constructs a WriteScheduler that schedules -// frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3. -// If cfg is nil, default options are used. -func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler { - if cfg == nil { - // For justification of these defaults, see: - // https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY - cfg = &PriorityWriteSchedulerConfig{ - MaxClosedNodesInTree: 10, - MaxIdleNodesInTree: 10, - ThrottleOutOfOrderWrites: false, - } - } - - ws := &priorityWriteScheduler{ - nodes: make(map[uint32]*priorityNode), - maxClosedNodesInTree: cfg.MaxClosedNodesInTree, - maxIdleNodesInTree: cfg.MaxIdleNodesInTree, - enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, - } - ws.nodes[0] = &ws.root - if cfg.ThrottleOutOfOrderWrites { - ws.writeThrottleLimit = 1024 - } else { - ws.writeThrottleLimit = math.MaxInt32 - } - return ws -} - -type priorityNodeState int - -const ( - priorityNodeOpen priorityNodeState = iota - priorityNodeClosed - priorityNodeIdle -) - -// priorityNode is a node in an HTTP/2 priority tree. -// Each node is associated with a single stream ID. -// See RFC 7540, Section 5.3. -type priorityNode struct { - q writeQueue // queue of pending frames to write - id uint32 // id of the stream, or 0 for the root of the tree - weight uint8 // the actual weight is weight+1, so the value is in [1,256] - state priorityNodeState // open | closed | idle - bytes int64 // number of bytes written by this node, or 0 if closed - subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree - - // These links form the priority tree. - parent *priorityNode - kids *priorityNode // start of the kids list - prev, next *priorityNode // doubly-linked list of siblings -} - -func (n *priorityNode) setParent(parent *priorityNode) { - if n == parent { - panic("setParent to self") - } - if n.parent == parent { - return - } - // Unlink from current parent. - if parent := n.parent; parent != nil { - if n.prev == nil { - parent.kids = n.next - } else { - n.prev.next = n.next - } - if n.next != nil { - n.next.prev = n.prev - } - } - // Link to new parent. - // If parent=nil, remove n from the tree. - // Always insert at the head of parent.kids (this is assumed by walkReadyInOrder). - n.parent = parent - if parent == nil { - n.next = nil - n.prev = nil - } else { - n.next = parent.kids - n.prev = nil - if n.next != nil { - n.next.prev = n - } - parent.kids = n - } -} - -func (n *priorityNode) addBytes(b int64) { - n.bytes += b - for ; n != nil; n = n.parent { - n.subtreeBytes += b - } -} - -// walkReadyInOrder iterates over the tree in priority order, calling f for each node -// with a non-empty write queue. When f returns true, this function returns true and the -// walk halts. tmp is used as scratch space for sorting. -// -// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true -// if any ancestor p of n is still open (ignoring the root node). -func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f func(*priorityNode, bool) bool) bool { - if !n.q.empty() && f(n, openParent) { - return true - } - if n.kids == nil { - return false - } - - // Don't consider the root "open" when updating openParent since - // we can't send data frames on the root stream (only control frames). - if n.id != 0 { - openParent = openParent || (n.state == priorityNodeOpen) - } - - // Common case: only one kid or all kids have the same weight. - // Some clients don't use weights; other clients (like web browsers) - // use mostly-linear priority trees. - w := n.kids.weight - needSort := false - for k := n.kids.next; k != nil; k = k.next { - if k.weight != w { - needSort = true - break - } - } - if !needSort { - for k := n.kids; k != nil; k = k.next { - if k.walkReadyInOrder(openParent, tmp, f) { - return true - } - } - return false - } - - // Uncommon case: sort the child nodes. We remove the kids from the parent, - // then re-insert after sorting so we can reuse tmp for future sort calls. - *tmp = (*tmp)[:0] - for n.kids != nil { - *tmp = append(*tmp, n.kids) - n.kids.setParent(nil) - } - sort.Sort(sortPriorityNodeSiblings(*tmp)) - for i := len(*tmp) - 1; i >= 0; i-- { - (*tmp)[i].setParent(n) // setParent inserts at the head of n.kids - } - for k := n.kids; k != nil; k = k.next { - if k.walkReadyInOrder(openParent, tmp, f) { - return true - } - } - return false -} - -type sortPriorityNodeSiblings []*priorityNode - -func (z sortPriorityNodeSiblings) Len() int { return len(z) } - -func (z sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } - -func (z sortPriorityNodeSiblings) Less(i, k int) bool { - // Prefer the subtree that has sent fewer bytes relative to its weight. - // See sections 5.3.2 and 5.3.4. - wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) - wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) - if bi == 0 && bk == 0 { - return wi >= wk - } - if bk == 0 { - return false - } - return bi/bk <= wi/wk -} - -type priorityWriteScheduler struct { - // root is the root of the priority tree, where root.id = 0. - // The root queues control frames that are not associated with any stream. - root priorityNode - - // nodes maps stream ids to priority tree nodes. - nodes map[uint32]*priorityNode - - // maxID is the maximum stream id in nodes. - maxID uint32 - - // lists of nodes that have been closed or are idle, but are kept in - // the tree for improved prioritization. When the lengths exceed either - // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. - closedNodes, idleNodes []*priorityNode - - // From the config. - maxClosedNodesInTree int - maxIdleNodesInTree int - writeThrottleLimit int32 - enableWriteThrottle bool - - // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. - tmp []*priorityNode - - // pool of empty queues for reuse. - queuePool writeQueuePool -} - -func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) { - // The stream may be currently idle but cannot be opened or closed. - if curr := ws.nodes[streamID]; curr != nil { - if curr.state != priorityNodeIdle { - panic(fmt.Sprintf("stream %d already opened", streamID)) - } - curr.state = priorityNodeOpen - return - } - - // RFC 7540, Section 5.3.5: - // "All streams are initially assigned a non-exclusive dependency on stream 0x0. - // Pushed streams initially depend on their associated stream. In both cases, - // streams are assigned a default weight of 16." - parent := ws.nodes[options.PusherID] - if parent == nil { - parent = &ws.root - } - n := &priorityNode{ - q: *ws.queuePool.get(), - id: streamID, - weight: priorityDefaultWeight, - state: priorityNodeOpen, - } - n.setParent(parent) - ws.nodes[streamID] = n - if streamID > ws.maxID { - ws.maxID = streamID - } -} - -func (ws *priorityWriteScheduler) CloseStream(streamID uint32) { - if streamID == 0 { - panic("violation of WriteScheduler interface: cannot close stream 0") - } - if ws.nodes[streamID] == nil { - panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) - } - if ws.nodes[streamID].state != priorityNodeOpen { - panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) - } - - n := ws.nodes[streamID] - n.state = priorityNodeClosed - n.addBytes(-n.bytes) - - q := n.q - ws.queuePool.put(&q) - n.q.s = nil - if ws.maxClosedNodesInTree > 0 { - ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) - } else { - ws.removeNode(n) - } -} - -func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) { - if streamID == 0 { - panic("adjustPriority on root") - } - - // If streamID does not exist, there are two cases: - // - A closed stream that has been removed (this will have ID <= maxID) - // - An idle stream that is being used for "grouping" (this will have ID > maxID) - n := ws.nodes[streamID] - if n == nil { - if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 { - return - } - ws.maxID = streamID - n = &priorityNode{ - q: *ws.queuePool.get(), - id: streamID, - weight: priorityDefaultWeight, - state: priorityNodeIdle, - } - n.setParent(&ws.root) - ws.nodes[streamID] = n - ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n) - } - - // Section 5.3.1: A dependency on a stream that is not currently in the tree - // results in that stream being given a default priority (Section 5.3.5). - parent := ws.nodes[priority.StreamDep] - if parent == nil { - n.setParent(&ws.root) - n.weight = priorityDefaultWeight - return - } - - // Ignore if the client tries to make a node its own parent. - if n == parent { - return - } - - // Section 5.3.3: - // "If a stream is made dependent on one of its own dependencies, the - // formerly dependent stream is first moved to be dependent on the - // reprioritized stream's previous parent. The moved dependency retains - // its weight." - // - // That is: if parent depends on n, move parent to depend on n.parent. - for x := parent.parent; x != nil; x = x.parent { - if x == n { - parent.setParent(n.parent) - break - } - } - - // Section 5.3.3: The exclusive flag causes the stream to become the sole - // dependency of its parent stream, causing other dependencies to become - // dependent on the exclusive stream. - if priority.Exclusive { - k := parent.kids - for k != nil { - next := k.next - if k != n { - k.setParent(n) - } - k = next - } - } - - n.setParent(parent) - n.weight = priority.Weight -} - -func (ws *priorityWriteScheduler) Push(wr FrameWriteRequest) { - var n *priorityNode - if id := wr.StreamID(); id == 0 { - n = &ws.root - } else { - n = ws.nodes[id] - if n == nil { - // id is an idle or closed stream. wr should not be a HEADERS or - // DATA frame. However, wr can be a RST_STREAM. In this case, we - // push wr onto the root, rather than creating a new priorityNode, - // since RST_STREAM is tiny and the stream's priority is unknown - // anyway. See issue #17919. - if wr.DataSize() > 0 { - panic("add DATA on non-open stream") - } - n = &ws.root - } - } - n.q.push(wr) -} - -func (ws *priorityWriteScheduler) Pop() (wr FrameWriteRequest, ok bool) { - ws.root.walkReadyInOrder(false, &ws.tmp, func(n *priorityNode, openParent bool) bool { - limit := int32(math.MaxInt32) - if openParent { - limit = ws.writeThrottleLimit - } - wr, ok = n.q.consume(limit) - if !ok { - return false - } - n.addBytes(int64(wr.DataSize())) - // If B depends on A and B continuously has data available but A - // does not, gradually increase the throttling limit to allow B to - // steal more and more bandwidth from A. - if openParent { - ws.writeThrottleLimit += 1024 - if ws.writeThrottleLimit < 0 { - ws.writeThrottleLimit = math.MaxInt32 - } - } else if ws.enableWriteThrottle { - ws.writeThrottleLimit = 1024 - } - return true - }) - return wr, ok -} - -func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, maxSize int, n *priorityNode) { - if maxSize == 0 { - return - } - if len(*list) == maxSize { - // Remove the oldest node, then shift left. - ws.removeNode((*list)[0]) - x := (*list)[1:] - copy(*list, x) - *list = (*list)[:len(x)] - } - *list = append(*list, n) -} - -func (ws *priorityWriteScheduler) removeNode(n *priorityNode) { - for k := n.kids; k != nil; k = k.next { - k.setParent(n.parent) - } - n.setParent(nil) - delete(ws.nodes, n.id) -} - -// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2 -// priorities. Control frames like SETTINGS and PING are written before DATA -// frames, but if no control frames are queued and multiple streams have queued -// HEADERS or DATA frames, Pop selects a ready stream arbitrarily. -func NewRandomWriteScheduler() WriteScheduler { - return &randomWriteScheduler{sq: make(map[uint32]*writeQueue)} -} - -type randomWriteScheduler struct { - // zero are frames not associated with a specific stream. - zero writeQueue - - // sq contains the stream-specific queues, keyed by stream ID. - // When a stream is idle, closed, or emptied, it's deleted - // from the map. - sq map[uint32]*writeQueue - - // pool of empty queues for reuse. - queuePool writeQueuePool -} - -func (ws *randomWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) { - // no-op: idle streams are not tracked -} - -func (ws *randomWriteScheduler) CloseStream(streamID uint32) { - q, ok := ws.sq[streamID] - if !ok { - return - } - delete(ws.sq, streamID) - ws.queuePool.put(q) -} - -func (ws *randomWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) { - // no-op: priorities are ignored -} - -func (ws *randomWriteScheduler) Push(wr FrameWriteRequest) { - if wr.isControl() { - ws.zero.push(wr) - return - } - id := wr.StreamID() - q, ok := ws.sq[id] - if !ok { - q = ws.queuePool.get() - ws.sq[id] = q - } - q.push(wr) -} - -func (ws *randomWriteScheduler) Pop() (FrameWriteRequest, bool) { - // Control and RST_STREAM frames first. - if !ws.zero.empty() { - return ws.zero.shift(), true - } - // Iterate over all non-idle streams until finding one that can be consumed. - for streamID, q := range ws.sq { - if wr, ok := q.consume(math.MaxInt32); ok { - if q.empty() { - delete(ws.sq, streamID) - ws.queuePool.put(q) - } - return wr, true - } - } - return FrameWriteRequest{}, false -} - -var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered") - -func stderrv() io.Writer { - if *stderrVerbose { - return os.Stderr - } - - return io.Discard -} - -type safeBuffer struct { - b bytes.Buffer - m sync.Mutex -} - -func (sb *safeBuffer) Write(d []byte) (int, error) { - sb.m.Lock() - defer sb.m.Unlock() - return sb.b.Write(d) -} - -func (sb *safeBuffer) Bytes() []byte { - sb.m.Lock() - defer sb.m.Unlock() - return sb.b.Bytes() -} - -func (sb *safeBuffer) Len() int { - sb.m.Lock() - defer sb.m.Unlock() - return sb.b.Len() -} - -type serverTester struct { - cc net.Conn // client conn - t testing.TB - ts *httptest.Server - fr *Framer - serverLogBuf safeBuffer // logger for httptest.Server - logFilter []string // substrings to filter out - scMu sync.Mutex // guards sc - sc *serverConn - hpackDec *hpack.Decoder - decodedHeaders [][2]string - - // If debug!=2, then we capture Frame debug logs that will be written - // to t.Log after a test fails. The read and write logs use separate locks - // and buffers so we don't accidentally introduce synchronization between - // the read and write goroutines, which may hide data races. - frameReadLogMu sync.Mutex - frameReadLogBuf bytes.Buffer - frameWriteLogMu sync.Mutex - frameWriteLogBuf bytes.Buffer - - // writing headers: - headerBuf bytes.Buffer - hpackEnc *hpack.Encoder -} - -func (st *serverTester) onHeaderField(f hpack.HeaderField) { - if f.Name == "date" { - return - } - st.decodedHeaders = append(st.decodedHeaders, [2]string{f.Name, f.Value}) -} - -func (st *serverTester) decodeHeader(headerBlock []byte) (pairs [][2]string) { - st.decodedHeaders = nil - if _, err := st.hpackDec.Write(headerBlock); err != nil { - st.t.Fatalf("hpack decoding error: %v", err) - } - if err := st.hpackDec.Close(); err != nil { - st.t.Fatalf("hpack decoding error: %v", err) - } - return st.decodedHeaders -} - -func init() { - testHookOnPanicMu = new(sync.Mutex) - goAwayTimeout = 25 * time.Millisecond -} - -func resetHooks() { - testHookOnPanicMu.Lock() - testHookOnPanic = nil - testHookOnPanicMu.Unlock() -} - -// ConfigureServer adds HTTP/2 support to a net/http Server. -// -// The configuration conf may be nil. -// -// ConfigureServer must be called before s begins serving. -func ConfigureServer(s *http.Server, conf *Server) error { - if s == nil { - panic("nil *http.Server") - } - if conf == nil { - conf = new(Server) - } - conf.state = &serverInternalState{activeConns: make(map[*serverConn]struct{})} - if h1, h2 := s, conf; h2.IdleTimeout == 0 { - if h1.IdleTimeout != 0 { - h2.IdleTimeout = h1.IdleTimeout - } else { - h2.IdleTimeout = h1.ReadTimeout - } - } - s.RegisterOnShutdown(conf.state.startGracefulShutdown) - - if s.TLSConfig == nil { - s.TLSConfig = new(tls.Config) - } else if s.TLSConfig.CipherSuites != nil && s.TLSConfig.MinVersion < tls.VersionTLS13 { - // If they already provided a TLS 1.0–1.2 CipherSuite list, return an - // error if it is missing ECDHE_RSA_WITH_AES_128_GCM_SHA256 or - // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256. - haveRequired := false - for _, cs := range s.TLSConfig.CipherSuites { - switch cs { - case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - // Alternative MTI cipher to not discourage ECDSA-only servers. - // See http://golang.org/cl/30721 for further information. - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: - haveRequired = true - } - } - if !haveRequired { - return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher (need at least one of TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 or TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)") - } - } - - // Note: not setting MinVersion to tls.VersionTLS12, - // as we don't want to interfere with HTTP/1.1 traffic - // on the user's server. We enforce TLS 1.2 later once - // we accept a connection. Ideally this should be done - // during next-proto selection, but using TLS <1.2 with - // HTTP/2 is still the client's bug. - - s.TLSConfig.PreferServerCipherSuites = true - - if !strSliceContains(s.TLSConfig.NextProtos, NextProtoTLS) { - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, NextProtoTLS) - } - if !strSliceContains(s.TLSConfig.NextProtos, "http/1.1") { - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "http/1.1") - } - - if s.TLSNextProto == nil { - s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} - } - protoHandler := func(hs *http.Server, c *tls.Conn, h http.Handler) { - if testHookOnConn != nil { - testHookOnConn() - } - // The TLSNextProto interface predates contexts, so - // the net/http package passes down its per-connection - // base context via an exported but unadvertised - // method on the Handler. This is for internal - // net/http<=>http2 use only. - var ctx context.Context - type baseContexter interface { - BaseContext() context.Context - } - if bc, ok := h.(baseContexter); ok { - ctx = bc.BaseContext() - } - conf.ServeConn(c, &ServeConnOpts{ - Context: ctx, - Handler: h, - BaseConfig: hs, - }) - } - s.TLSNextProto[NextProtoTLS] = protoHandler - return nil -} - -type twriter struct { - t testing.TB - st *serverTester // optional -} - -func (w twriter) Write(p []byte) (n int, err error) { - if w.st != nil { - ps := string(p) - for _, phrase := range w.st.logFilter { - if strings.Contains(ps, phrase) { - return len(p), nil // no logging - } - } - } - w.t.Logf("%s", p) - return len(p), nil -} - -type serverTesterOpt string - -var optOnlyServer = serverTesterOpt("only_server") -var optQuiet = serverTesterOpt("quiet_logging") -var optFramerReuseFrames = serverTesterOpt("frame_reuse_frames") - -func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester { - resetHooks() - - ts := httptest.NewUnstartedServer(handler) - - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - NextProtos: []string{NextProtoTLS}, - } - - var onlyServer, quiet, framerReuseFrames bool - h2server := new(Server) - for _, opt := range opts { - switch v := opt.(type) { - case func(*tls.Config): - v(tlsConfig) - case func(*httptest.Server): - v(ts) - case func(*Server): - v(h2server) - case serverTesterOpt: - switch v { - case optOnlyServer: - onlyServer = true - case optQuiet: - quiet = true - case optFramerReuseFrames: - framerReuseFrames = true - } - case func(net.Conn, http.ConnState): - ts.Config.ConnState = v - default: - t.Fatalf("unknown newServerTester option type %T", v) - } - } - - ConfigureServer(ts.Config, h2server) - - st := &serverTester{ - t: t, - ts: ts, - } - st.hpackEnc = hpack.NewEncoder(&st.headerBuf) - st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField) - - ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config - if quiet { - ts.Config.ErrorLog = log.New(io.Discard, "", 0) - } else { - ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags) - } - ts.StartTLS() - - if VerboseLogs { - t.Logf("Running test server at: %s", ts.URL) - } - testHookGetServerConn = func(v *serverConn) { - st.scMu.Lock() - defer st.scMu.Unlock() - st.sc = v - } - log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st})) - if !onlyServer { - cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) - if err != nil { - t.Fatal(err) - } - st.cc = cc - st.fr = NewFramer(cc, cc) - if framerReuseFrames { - st.fr.SetReuseFrames() - } - if !logFrameReads && !logFrameWrites { - st.fr.debugReadLoggerf = func(m string, v ...interface{}) { - m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n" - st.frameReadLogMu.Lock() - fmt.Fprintf(&st.frameReadLogBuf, m, v...) - st.frameReadLogMu.Unlock() - } - st.fr.debugWriteLoggerf = func(m string, v ...interface{}) { - m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n" - st.frameWriteLogMu.Lock() - fmt.Fprintf(&st.frameWriteLogBuf, m, v...) - st.frameWriteLogMu.Unlock() - } - st.fr.logReads = true - st.fr.logWrites = true - } - } - return st -} - -func (st *serverTester) closeConn() { - st.scMu.Lock() - defer st.scMu.Unlock() - st.sc.conn.Close() -} - -func (st *serverTester) addLogFilter(phrase string) { - st.logFilter = append(st.logFilter, phrase) -} - -func (st *serverTester) stream(id uint32) *stream { - ch := make(chan *stream, 1) - st.sc.serveMsgCh <- func(int) { - ch <- st.sc.streams[id] - } - return <-ch -} - -func (st *serverTester) streamState(id uint32) streamState { - ch := make(chan streamState, 1) - st.sc.serveMsgCh <- func(int) { - state, _ := st.sc.state(id) - ch <- state - } - return <-ch -} - -// loopNum reports how many times this conn's select loop has gone around. -func (st *serverTester) loopNum() int { - lastc := make(chan int, 1) - st.sc.serveMsgCh <- func(loopNum int) { - lastc <- loopNum - } - return <-lastc -} - -// awaitIdle heuristically awaits for the server conn's select loop to be idle. -// The heuristic is that the server connection's serve loop must schedule -// 50 times in a row without any channel sends or receives occurring. -func (st *serverTester) awaitIdle() { - remain := 50 - last := st.loopNum() - for remain > 0 { - n := st.loopNum() - if n == last+1 { - remain-- - } else { - remain = 50 - } - last = n - } -} - -func (st *serverTester) Close() { - if st.t.Failed() { - st.frameReadLogMu.Lock() - if st.frameReadLogBuf.Len() > 0 { - st.t.Logf("Framer read log:\n%s", st.frameReadLogBuf.String()) - } - st.frameReadLogMu.Unlock() - - st.frameWriteLogMu.Lock() - if st.frameWriteLogBuf.Len() > 0 { - st.t.Logf("Framer write log:\n%s", st.frameWriteLogBuf.String()) - } - st.frameWriteLogMu.Unlock() - - // If we failed already (and are likely in a Fatal, - // unwindowing), force close the connection, so the - // httptest.Server doesn't wait forever for the conn - // to close. - if st.cc != nil { - st.cc.Close() - } - } - st.ts.Close() - if st.cc != nil { - st.cc.Close() - } - log.SetOutput(os.Stderr) -} - -// greet initiates the client's HTTP/2 connection into a state where -// frames may be sent. -func (st *serverTester) greet() { - st.greetAndCheckSettings(func(Setting) error { return nil }) -} - -func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error) { - st.writePreface() - st.writeInitialSettings() - st.wantSettings().ForeachSetting(checkSetting) - st.writeSettingsAck() - - // The initial WINDOW_UPDATE and SETTINGS ACK can come in any order. - var gotSettingsAck bool - var gotWindowUpdate bool - - for i := 0; i < 2; i++ { - f, err := st.readFrame() - if err != nil { - st.t.Fatal(err) - } - switch f := f.(type) { - case *SettingsFrame: - if !f.Header().Flags.Has(FlagSettingsAck) { - st.t.Fatal("Settings Frame didn't have ACK set") - } - gotSettingsAck = true - - case *WindowUpdateFrame: - if f.FrameHeader.StreamID != 0 { - st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID) - } - incr := uint32((&Server{}).initialConnRecvWindowSize() - initialWindowSize) - if f.Increment != incr { - st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr) - } - gotWindowUpdate = true - - default: - st.t.Fatalf("Wanting a settings ACK or window update, received a %T", f) - } - } - - if !gotSettingsAck { - st.t.Fatalf("Didn't get a settings ACK") - } - if !gotWindowUpdate { - st.t.Fatalf("Didn't get a window update") - } -} - -func (st *serverTester) writePreface() { - n, err := st.cc.Write(clientPreface) - if err != nil { - st.t.Fatalf("Error writing client preface: %v", err) - } - if n != len(clientPreface) { - st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(clientPreface)) - } -} - -func (st *serverTester) writeInitialSettings() { - if err := st.fr.WriteSettings(); err != nil { - st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err) - } -} - -func (st *serverTester) writeSettingsAck() { - if err := st.fr.WriteSettingsAck(); err != nil { - st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err) - } -} - -func (st *serverTester) writeHeaders(p HeadersFrameParam) { - if err := st.fr.WriteHeaders(p); err != nil { - st.t.Fatalf("Error writing HEADERS: %v", err) - } -} - -func (st *serverTester) writePriority(id uint32, p PriorityParam) { - if err := st.fr.WritePriority(id, p); err != nil { - st.t.Fatalf("Error writing PRIORITY: %v", err) - } -} - -func (st *serverTester) encodeHeaderField(k, v string) { - err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) - if err != nil { - st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) - } -} - -// encodeHeaderRaw is the magic-free version of encodeHeader. -// It takes 0 or more (k, v) pairs and encodes them. -func (st *serverTester) encodeHeaderRaw(headers ...string) []byte { - if len(headers)%2 == 1 { - panic("odd number of kv args") - } - st.headerBuf.Reset() - for len(headers) > 0 { - k, v := headers[0], headers[1] - st.encodeHeaderField(k, v) - headers = headers[2:] - } - return st.headerBuf.Bytes() -} - -// encodeHeader encodes headers and returns their HPACK bytes. headers -// must contain an even number of key/value pairs. There may be -// multiple pairs for keys (e.g. "cookie"). The :method, :path, and -// :scheme headers default to GET, / and https. The :authority header -// defaults to st.ts.Listener.Addr(). -func (st *serverTester) encodeHeader(headers ...string) []byte { - if len(headers)%2 == 1 { - panic("odd number of kv args") - } - - st.headerBuf.Reset() - defaultAuthority := st.ts.Listener.Addr().String() - - if len(headers) == 0 { - // Fast path, mostly for benchmarks, so test code doesn't pollute - // profiles when we're looking to improve server allocations. - st.encodeHeaderField(":method", "GET") - st.encodeHeaderField(":scheme", "https") - st.encodeHeaderField(":authority", defaultAuthority) - st.encodeHeaderField(":path", "/") - return st.headerBuf.Bytes() - } - - if len(headers) == 2 && headers[0] == ":method" { - // Another fast path for benchmarks. - st.encodeHeaderField(":method", headers[1]) - st.encodeHeaderField(":scheme", "https") - st.encodeHeaderField(":authority", defaultAuthority) - st.encodeHeaderField(":path", "/") - return st.headerBuf.Bytes() - } - - pseudoCount := map[string]int{} - keys := []string{":method", ":scheme", ":authority", ":path"} - vals := map[string][]string{ - ":method": {"GET"}, - ":scheme": {"https"}, - ":authority": {defaultAuthority}, - ":path": {"/"}, - } - for len(headers) > 0 { - k, v := headers[0], headers[1] - headers = headers[2:] - if _, ok := vals[k]; !ok { - keys = append(keys, k) - } - if strings.HasPrefix(k, ":") { - pseudoCount[k]++ - if pseudoCount[k] == 1 { - vals[k] = []string{v} - } else { - // Allows testing of invalid headers w/ dup pseudo fields. - vals[k] = append(vals[k], v) - } - } else { - vals[k] = append(vals[k], v) - } - } - for _, k := range keys { - for _, v := range vals[k] { - st.encodeHeaderField(k, v) - } - } - return st.headerBuf.Bytes() -} - -// bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set. -func (st *serverTester) bodylessReq1(headers ...string) { - st.writeHeaders(HeadersFrameParam{ - StreamID: 1, // clients send odd numbers - BlockFragment: st.encodeHeader(headers...), - EndStream: true, - EndHeaders: true, - }) -} - -func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) { - if err := st.fr.WriteData(streamID, endStream, data); err != nil { - st.t.Fatalf("Error writing DATA: %v", err) - } -} - -func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) { - if err := st.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil { - st.t.Fatalf("Error writing DATA: %v", err) - } -} - -// writeReadPing sends a PING and immediately reads the PING ACK. -// It will fail if any other unread data was pending on the connection. -func (st *serverTester) writeReadPing() { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - if err := st.fr.WritePing(false, data); err != nil { - st.t.Fatalf("Error writing PING: %v", err) - } - p := st.wantPing() - if p.Flags&FlagPingAck == 0 { - st.t.Fatalf("got a PING, want a PING ACK") - } - if p.Data != data { - st.t.Fatalf("got PING data = %x, want %x", p.Data, data) - } -} - -func (st *serverTester) readFrame() (Frame, error) { - return st.fr.ReadFrame() -} - -func (st *serverTester) wantHeaders() *HeadersFrame { - f, err := st.readFrame() - if err != nil { - st.t.Fatalf("Error while expecting a HEADERS frame: %v", err) - } - hf, ok := f.(*HeadersFrame) - if !ok { - st.t.Fatalf("got a %T; want *HeadersFrame", f) - } - return hf -} - -func (st *serverTester) wantContinuation() *ContinuationFrame { - f, err := st.readFrame() - if err != nil { - st.t.Fatalf("Error while expecting a CONTINUATION frame: %v", err) - } - cf, ok := f.(*ContinuationFrame) - if !ok { - st.t.Fatalf("got a %T; want *ContinuationFrame", f) - } - return cf -} - -func (st *serverTester) wantData() *DataFrame { - f, err := st.readFrame() - if err != nil { - st.t.Fatalf("Error while expecting a DATA frame: %v", err) - } - df, ok := f.(*DataFrame) - if !ok { - st.t.Fatalf("got a %T; want *DataFrame", f) - } - return df -} - -func (st *serverTester) wantSettings() *SettingsFrame { - f, err := st.readFrame() - if err != nil { - st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err) - } - sf, ok := f.(*SettingsFrame) - if !ok { - st.t.Fatalf("got a %T; want *SettingsFrame", f) - } - return sf -} - -func (st *serverTester) wantPing() *PingFrame { - f, err := st.readFrame() - if err != nil { - st.t.Fatalf("Error while expecting a PING frame: %v", err) - } - pf, ok := f.(*PingFrame) - if !ok { - st.t.Fatalf("got a %T; want *PingFrame", f) - } - return pf -} - -func (st *serverTester) wantGoAway() *GoAwayFrame { - f, err := st.readFrame() - if err != nil { - st.t.Fatalf("Error while expecting a GOAWAY frame: %v", err) - } - gf, ok := f.(*GoAwayFrame) - if !ok { - st.t.Fatalf("got a %T; want *GoAwayFrame", f) - } - return gf -} - -func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) { - f, err := st.readFrame() - if err != nil { - st.t.Fatalf("Error while expecting an RSTStream frame: %v", err) - } - rs, ok := f.(*RSTStreamFrame) - if !ok { - st.t.Fatalf("got a %T; want *RSTStreamFrame", f) - } - if rs.FrameHeader.StreamID != streamID { - st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID) - } - if rs.ErrCode != errCode { - st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode) - } -} - -func (st *serverTester) wantWindowUpdate(streamID, incr uint32) { - f, err := st.readFrame() - if err != nil { - st.t.Fatalf("Error while expecting a WINDOW_UPDATE frame: %v", err) - } - wu, ok := f.(*WindowUpdateFrame) - if !ok { - st.t.Fatalf("got a %T; want *WindowUpdateFrame", f) - } - if wu.FrameHeader.StreamID != streamID { - st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID) - } - if wu.Increment != incr { - st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr) - } -} - -func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) { - var initial int32 - if streamID == 0 { - initial = st.sc.srv.initialConnRecvWindowSize() - } else { - initial = st.sc.srv.initialStreamRecvWindowSize() - } - donec := make(chan struct{}) - st.sc.sendServeMsg(func(sc *serverConn) { - defer close(donec) - var avail int32 - if streamID == 0 { - avail = sc.inflow.avail + sc.inflow.unsent - } else { - } - if got, want := initial-avail, consumed; got != want { - st.t.Errorf("stream %v flow control consumed: %v, want %v", streamID, got, want) - } - }) - <-donec -} - -type roundRobinWriteScheduler struct { - // control contains control frames (SETTINGS, PING, etc.). - control writeQueue - - // streams maps stream ID to a queue. - streams map[uint32]*writeQueue - - // stream queues are stored in a circular linked list. - // head is the next stream to write, or nil if there are no streams open. - head *writeQueue - - // pool of empty queues for reuse. - queuePool writeQueuePool -} - -// newRoundRobinWriteScheduler constructs a new write scheduler. -// The round robin scheduler priorizes control frames -// like SETTINGS and PING over DATA frames. -// When there are no control frames to send, it performs a round-robin -// selection from the ready streams. -func newRoundRobinWriteScheduler() WriteScheduler { - ws := &roundRobinWriteScheduler{ - streams: make(map[uint32]*writeQueue), - } - return ws -} - -func (ws *roundRobinWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) { - if ws.streams[streamID] != nil { - panic(fmt.Errorf("stream %d already opened", streamID)) - } - q := ws.queuePool.get() - ws.streams[streamID] = q - if ws.head == nil { - ws.head = q - q.next = q - q.prev = q - } else { - // Queues are stored in a ring. - // Insert the new stream before ws.head, putting it at the end of the list. - q.prev = ws.head.prev - q.next = ws.head - q.prev.next = q - q.next.prev = q - } -} - -func (ws *roundRobinWriteScheduler) CloseStream(streamID uint32) { - q := ws.streams[streamID] - if q == nil { - return - } - if q.next == q { - // This was the only open stream. - ws.head = nil - } else { - q.prev.next = q.next - q.next.prev = q.prev - if ws.head == q { - ws.head = q.next - } - } - delete(ws.streams, streamID) - ws.queuePool.put(q) -} - -func (ws *roundRobinWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) {} - -func (ws *roundRobinWriteScheduler) Push(wr FrameWriteRequest) { - if wr.isControl() { - ws.control.push(wr) - return - } - q := ws.streams[wr.StreamID()] - if q == nil { - // This is a closed stream. - // wr should not be a HEADERS or DATA frame. - // We push the request onto the control queue. - if wr.DataSize() > 0 { - panic("add DATA on non-open stream") - } - ws.control.push(wr) - return - } - q.push(wr) -} - -func (ws *roundRobinWriteScheduler) Pop() (FrameWriteRequest, bool) { - // Control and RST_STREAM frames first. - if !ws.control.empty() { - return ws.control.shift(), true - } - if ws.head == nil { - return FrameWriteRequest{}, false - } - q := ws.head - for { - if wr, ok := q.consume(math.MaxInt32); ok { - ws.head = q.next - return wr, true - } - q = q.next - if q == ws.head { - break - } - } - return FrameWriteRequest{}, false -} diff --git a/internal/http2/transport_go117_test.go b/internal/http2/transport_go117_test.go deleted file mode 100644 index e46e5fcf..00000000 --- a/internal/http2/transport_go117_test.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.17 -// +build go1.17 - -package http2 - -import ( - "context" - "crypto/tls" - "errors" - "github.com/imroc/req/v3/internal/transport" - "net/http" - "net/http/httptest" - - "testing" -) - -func TestTransportDialTLSContexth2(t *testing.T) { - blockCh := make(chan struct{}) - serverTLSConfigFunc := func(ts *httptest.Server) { - ts.Config.TLSConfig = &tls.Config{ - // Triggers the server to request the clients certificate - // during TLS handshake. - ClientAuth: tls.RequestClientCert, - } - } - ts := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) {}, - optOnlyServer, - serverTLSConfigFunc, - ) - defer ts.Close() - opt := &transport.Options{ - TLSClientConfig: &tls.Config{ - GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { - // Tests that the context provided to `req` is - // passed into this function. - close(blockCh) - <-cri.Context().Done() - return nil, cri.Context().Err() - }, - InsecureSkipVerify: true, - }, - } - tr := &Transport{ - Options: opt, - } - defer tr.CloseIdleConnections() - req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - req = req.WithContext(ctx) - errCh := make(chan error) - go func() { - defer close(errCh) - res, err := tr.RoundTrip(req) - if err != nil { - errCh <- err - return - } - res.Body.Close() - }() - // Wait for GetClientCertificate handler to be called - <-blockCh - // Cancel the context - cancel() - // Expect the cancellation error here - err = <-errCh - if err == nil { - t.Fatal("cancelling context during client certificate fetch did not error as expected") - return - } - if !errors.Is(err, context.Canceled) { - t.Fatalf("unexpected error returned after cancellation: %v", err) - } -} - -// TestDialRaceResumesDial tests that, given two concurrent requests -// to the same address, when the first Dial is interrupted because -// the first request's context is cancelled, the second request -// resumes the dial automatically. -func TestDialRaceResumesDial(t *testing.T) { - blockCh := make(chan struct{}) - serverTLSConfigFunc := func(ts *httptest.Server) { - ts.Config.TLSConfig = &tls.Config{ - // Triggers the server to request the clients certificate - // during TLS handshake. - ClientAuth: tls.RequestClientCert, - } - } - ts := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) {}, - optOnlyServer, - serverTLSConfigFunc, - ) - defer ts.Close() - opt := &transport.Options{ - TLSClientConfig: &tls.Config{ - GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { - select { - case <-blockCh: - // If we already errored, return without error. - return &tls.Certificate{}, nil - default: - } - close(blockCh) - <-cri.Context().Done() - return nil, cri.Context().Err() - }, - InsecureSkipVerify: true, - }, - } - tr := &Transport{ - Options: opt, - } - defer tr.CloseIdleConnections() - req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - // Create two requests with independent cancellation. - ctx1, cancel1 := context.WithCancel(context.Background()) - defer cancel1() - req1 := req.WithContext(ctx1) - ctx2, cancel2 := context.WithCancel(context.Background()) - defer cancel2() - req2 := req.WithContext(ctx2) - errCh := make(chan error) - go func() { - res, err := tr.RoundTrip(req1) - if err != nil { - errCh <- err - return - } - res.Body.Close() - }() - successCh := make(chan struct{}) - go func() { - // Don't start request until first request - // has initiated the handshake. - <-blockCh - res, err := tr.RoundTrip(req2) - if err != nil { - errCh <- err - return - } - res.Body.Close() - // Close successCh to indicate that the second request - // made it to the server successfully. - close(successCh) - }() - // Wait for GetClientCertificate handler to be called - <-blockCh - // Cancel the context first - cancel1() - // Expect the cancellation error here - err = <-errCh - if err == nil { - t.Fatal("cancelling context during client certificate fetch did not error as expected") - return - } - if !errors.Is(err, context.Canceled) { - t.Fatalf("unexpected error returned after cancellation: %v", err) - } - select { - case err := <-errCh: - t.Fatalf("unexpected second error: %v", err) - case <-successCh: - } -} diff --git a/internal/http2/transport_test.go b/internal/http2/transport_test.go deleted file mode 100644 index 5acecd3d..00000000 --- a/internal/http2/transport_test.go +++ /dev/null @@ -1,6008 +0,0 @@ -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import ( - "bufio" - "bytes" - "compress/gzip" - "context" - "crypto/tls" - "encoding/hex" - "errors" - "flag" - "fmt" - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/tests" - "github.com/imroc/req/v3/internal/transport" - "io" - "io/fs" - "log" - "math/rand" - "net" - "net/http" - "net/http/httptest" - "net/http/httptrace" - "net/textproto" - "net/url" - "os" - "reflect" - "runtime" - "sort" - "strconv" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "golang.org/x/net/http2/hpack" -) - -var ( - extNet = flag.Bool("extnet", false, "do external network tests") - transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport") - insecure = flag.Bool("insecure", false, "insecure TLS dials") // TODO: dead code. remove? -) - -var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true} - -var canceledCtx context.Context - -func init() { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - canceledCtx = ctx -} - -func TestTransportExternal(t *testing.T) { - if !*extNet { - t.Skip("skipping external network test") - } - req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil) - opt := &transport.Options{ - TLSClientConfig: tlsConfigInsecure, - } - rt := &Transport{Options: opt} - res, err := rt.RoundTrip(req) - if err != nil { - t.Fatalf("%v", err) - } - res.Write(os.Stdout) -} - -type fakeTLSConn struct { - net.Conn -} - -func (c *fakeTLSConn) ConnectionState() tls.ConnectionState { - return tls.ConnectionState{ - Version: tls.VersionTLS12, - CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - } -} - -func startH2cServer(t *testing.T) net.Listener { - h2Server := &Server{} - l := tests.NewLocalListener(t) - go func() { - conn, err := l.Accept() - if err != nil { - t.Error(err) - return - } - h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil) - })}) - }() - return l -} - -func TestTransportH2c(t *testing.T) { - l := startH2cServer(t) - defer l.Close() - req, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/foobar", nil) - if err != nil { - t.Fatal(err) - } - var gotConnCnt int32 - trace := &httptrace.ClientTrace{ - GotConn: func(connInfo httptrace.GotConnInfo) { - if !connInfo.Reused { - atomic.AddInt32(&gotConnCnt, 1) - } - }, - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - tr := &Transport{ - Options: &transport.Options{}, - AllowHTTP: true, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - return net.Dial(network, addr) - }, - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - if res.ProtoMajor != 2 { - t.Fatal("proto not h2c") - } - body, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if got, want := string(body), "Hello, /foobar, http: true"; got != want { - t.Fatalf("response got %v, want %v", got, want) - } - if got, want := gotConnCnt, int32(1); got != want { - t.Errorf("Too many got connections: %d", gotConnCnt) - } -} - -func TestTransport(t *testing.T) { - const body = "sup" - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, body) - }, optOnlyServer) - defer st.Close() - - opt := &transport.Options{TLSClientConfig: tlsConfigInsecure} - tr := &Transport{Options: opt} - defer tr.CloseIdleConnections() - - u, err := url.Parse(st.ts.URL) - if err != nil { - t.Fatal(err) - } - for i, m := range []string{"GET", ""} { - req := &http.Request{ - Method: m, - URL: u, - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatalf("%d: %s", i, err) - } - - t.Logf("%d: Got res: %+v", i, res) - if g, w := res.StatusCode, 200; g != w { - t.Errorf("%d: StatusCode = %v; want %v", i, g, w) - } - if g, w := res.Status, "200 OK"; g != w { - t.Errorf("%d: Status = %q; want %q", i, g, w) - } - wantHeader := http.Header{ - "Content-Length": []string{"3"}, - "Content-Type": []string{"text/plain; charset=utf-8"}, - "Date": []string{"XXX"}, // see cleanDate - } - cleanDate(res) - if !reflect.DeepEqual(res.Header, wantHeader) { - t.Errorf("%d: res Header = %v; want %v", i, res.Header, wantHeader) - } - if res.Request != req { - t.Errorf("%d: Response.Request = %p; want %p", i, res.Request, req) - } - if res.TLS == nil { - t.Errorf("%d: Response.TLS = nil; want non-nil", i) - } - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.Errorf("%d: Body read: %v", i, err) - } else if string(slurp) != body { - t.Errorf("%d: Body = %q; want %q", i, slurp, body) - } - res.Body.Close() - } -} - -func testTransportReusesConns(t *testing.T, wantSame bool, modReq func(*http.Request)) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, r.RemoteAddr) - }, optOnlyServer, func(c net.Conn, st http.ConnState) { - t.Logf("conn %v is now state %v", c.RemoteAddr(), st) - }) - defer st.Close() - tr := &Transport{ - Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, - } - defer tr.CloseIdleConnections() - get := func() string { - req, err := http.NewRequest("GET", st.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - modReq(req) - var res *http.Response - - res, err = tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("Body read: %v", err) - } - addr := strings.TrimSpace(string(slurp)) - if addr == "" { - t.Fatalf("didn't get an addr in response") - } - return addr - } - first := get() - second := get() - if got := first == second; got != wantSame { - t.Errorf("first and second responses on same connection: %v; want %v", got, wantSame) - } -} - -func TestTransportReusesConns(t *testing.T) { - for _, test := range []struct { - name string - modReq func(*http.Request) - wantSame bool - }{{ - name: "ReuseConn", - modReq: func(*http.Request) {}, - wantSame: true, - }, { - name: "RequestClose", - modReq: func(r *http.Request) { r.Close = true }, - wantSame: false, - }, { - name: "ConnClose", - modReq: func(r *http.Request) { r.Header.Set("Connection", "close") }, - wantSame: false, - }} { - t.Run(test.name, func(t *testing.T) { - t.Run("Transport", func(t *testing.T) { - const useClient = false - testTransportReusesConns(t, test.wantSame, test.modReq) - }) - t.Run("Client", func(t *testing.T) { - const useClient = true - testTransportReusesConns(t, test.wantSame, test.modReq) - }) - }) - } -} - -func TestTransportGetGotConnHooks_HTTP2Transport(t *testing.T) { - testTransportGetGotConnHooks(t) -} - -func testTransportGetGotConnHooks(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, r.RemoteAddr) - }, func(s *httptest.Server) { - s.EnableHTTP2 = true - }, optOnlyServer) - defer st.Close() - - tr := &Transport{ - Options: &transport.Options{ - TLSClientConfig: tlsConfigInsecure, - }, - } - - var ( - getConns int32 - gotConns int32 - ) - for i := 0; i < 2; i++ { - trace := &httptrace.ClientTrace{ - GetConn: func(hostport string) { - atomic.AddInt32(&getConns, 1) - }, - GotConn: func(connInfo httptrace.GotConnInfo) { - got := atomic.AddInt32(&gotConns, 1) - wantReused, wantWasIdle := false, false - if got > 1 { - wantReused, wantWasIdle = true, true - } - if connInfo.Reused != wantReused || connInfo.WasIdle != wantWasIdle { - t.Errorf("GotConn %v: Reused=%v (want %v), WasIdle=%v (want %v)", i, connInfo.Reused, wantReused, connInfo.WasIdle, wantWasIdle) - } - }, - } - req, err := http.NewRequest("GET", st.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - - var res *http.Response - res, err = tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - if get := atomic.LoadInt32(&getConns); get != int32(i+1) { - t.Errorf("after request %v, %v calls to GetConns: want %v", i, get, i+1) - } - if got := atomic.LoadInt32(&gotConns); got != int32(i+1) { - t.Errorf("after request %v, %v calls to GotConns: want %v", i, got, i+1) - } - } -} - -type testNetConn struct { - net.Conn - closed bool - onClose func() -} - -func (c *testNetConn) Close() error { - if !c.closed { - // We can call Close multiple times on the same net.Conn. - c.onClose() - } - c.closed = true - return c.Conn.Close() -} - -// Tests that the Transport only keeps one pending dial open per destination address. -// https://golang.org/issue/13397 -func TestTransportGroupsPendingDials(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - }, optOnlyServer) - defer st.Close() - var ( - mu sync.Mutex - dialCount int - closeCount int - ) - tr := &Transport{ - Options: &transport.Options{ - TLSClientConfig: tlsConfigInsecure, - }, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - mu.Lock() - dialCount++ - mu.Unlock() - c, err := tls.Dial(network, addr, cfg) - return &testNetConn{ - Conn: c, - onClose: func() { - mu.Lock() - closeCount++ - mu.Unlock() - }, - }, err - }, - } - defer tr.CloseIdleConnections() - var wg sync.WaitGroup - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - req, err := http.NewRequest("GET", st.ts.URL, nil) - if err != nil { - t.Error(err) - return - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Error(err) - return - } - res.Body.Close() - }() - } - wg.Wait() - tr.CloseIdleConnections() - if dialCount != 1 { - t.Errorf("saw %d dials; want 1", dialCount) - } - if closeCount != 1 { - t.Errorf("saw %d closes; want 1", closeCount) - } -} - -func retry(tries int, delay time.Duration, fn func() error) error { - var err error - for i := 0; i < tries; i++ { - err = fn() - if err == nil { - return nil - } - time.Sleep(delay) - } - return err -} - -func TestTransportAbortClosesPipes(t *testing.T) { - shutdown := make(chan struct{}) - st := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) { - w.(http.Flusher).Flush() - <-shutdown - }, - optOnlyServer, - ) - defer st.Close() - defer close(shutdown) // we must shutdown before st.Close() to avoid hanging - - errCh := make(chan error) - go func() { - defer close(errCh) - tr := &Transport{ - Options: &transport.Options{ - TLSClientConfig: tlsConfigInsecure, - }, - } - req, err := http.NewRequest("GET", st.ts.URL, nil) - if err != nil { - errCh <- err - return - } - res, err := tr.RoundTrip(req) - if err != nil { - errCh <- err - return - } - defer res.Body.Close() - st.closeConn() - _, err = io.ReadAll(res.Body) - if err == nil { - errCh <- errors.New("expected error from res.Body.Read") - return - } - }() - - select { - case err := <-errCh: - if err != nil { - t.Fatal(err) - } - // deadlock? that's a bug. - case <-time.After(3 * time.Second): - t.Fatal("timeout") - } -} - -// TODO: merge this with TestTransportBody to make TestTransportRequest? This -// could be a table-driven test with extra goodies. -func TestTransportPath(t *testing.T) { - gotc := make(chan *url.URL, 1) - st := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) { - gotc <- r.URL - }, - optOnlyServer, - ) - defer st.Close() - - tr := &Transport{ - Options: &transport.Options{ - TLSClientConfig: tlsConfigInsecure, - }, - } - defer tr.CloseIdleConnections() - const ( - path = "/testpath" - query = "q=1" - ) - surl := st.ts.URL + path + "?" + query - req, err := http.NewRequest("POST", surl, nil) - if err != nil { - t.Fatal(err) - } - c := &http.Client{Transport: tr} - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - got := <-gotc - if got.Path != path { - t.Errorf("Read Path = %q; want %q", got.Path, path) - } - if got.RawQuery != query { - t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query) - } -} - -func randString(n int) string { - rnd := rand.New(rand.NewSource(int64(n))) - b := make([]byte, n) - for i := range b { - b[i] = byte(rnd.Intn(256)) - } - return string(b) -} - -type panicReader struct{} - -func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") } -func (panicReader) Close() error { panic("unexpected Close") } - -func TestActualContentLength(t *testing.T) { - tests := []struct { - req *http.Request - want int64 - }{ - // Verify we don't read from Body: - 0: { - req: &http.Request{Body: panicReader{}}, - want: -1, - }, - // nil Body means 0, regardless of ContentLength: - 1: { - req: &http.Request{Body: nil, ContentLength: 5}, - want: 0, - }, - // ContentLength is used if set. - 2: { - req: &http.Request{Body: panicReader{}, ContentLength: 5}, - want: 5, - }, - // http.NoBody means 0, not -1. - 3: { - req: &http.Request{Body: http.NoBody}, - want: 0, - }, - } - for i, tt := range tests { - got := actualContentLength(tt.req) - if got != tt.want { - t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) - } - } -} - -func TestTransportBody(t *testing.T) { - bodyTests := []struct { - body string - noContentLen bool - }{ - {body: "some message"}, - {body: "some message", noContentLen: true}, - {body: strings.Repeat("a", 1<<20), noContentLen: true}, - {body: strings.Repeat("a", 1<<20)}, - {body: randString(16<<10 - 1)}, - {body: randString(16 << 10)}, - {body: randString(16<<10 + 1)}, - {body: randString(512<<10 - 1)}, - {body: randString(512 << 10)}, - {body: randString(512<<10 + 1)}, - {body: randString(1<<20 - 1)}, - {body: randString(1 << 20)}, - {body: randString(1<<20 + 2)}, - } - - type reqInfo struct { - req *http.Request - slurp []byte - err error - } - gotc := make(chan reqInfo, 1) - st := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) { - slurp, err := io.ReadAll(r.Body) - if err != nil { - gotc <- reqInfo{err: err} - } else { - gotc <- reqInfo{req: r, slurp: slurp} - } - }, - optOnlyServer, - ) - defer st.Close() - - for i, tt := range bodyTests { - tr := &Transport{ - Options: &transport.Options{ - TLSClientConfig: tlsConfigInsecure, - }, - } - defer tr.CloseIdleConnections() - - var body io.Reader = strings.NewReader(tt.body) - if tt.noContentLen { - body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods - } - req, err := http.NewRequest("POST", st.ts.URL, body) - if err != nil { - t.Fatalf("#%d: %v", i, err) - } - c := &http.Client{Transport: tr} - res, err := c.Do(req) - if err != nil { - t.Fatalf("#%d: %v", i, err) - } - defer res.Body.Close() - ri := <-gotc - if ri.err != nil { - t.Errorf("#%d: read error: %v", i, ri.err) - continue - } - if got := string(ri.slurp); got != tt.body { - t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body)) - } - wantLen := int64(len(tt.body)) - if tt.noContentLen && tt.body != "" { - wantLen = -1 - } - if ri.req.ContentLength != wantLen { - t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen) - } - } -} - -func shortString(v string) string { - const maxLen = 100 - if len(v) <= maxLen { - return v - } - return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:maxLen/2], len(v)-maxLen, v[len(v)-maxLen/2:]) -} - -func TestTransportDialTLSh2(t *testing.T) { - var mu sync.Mutex // guards following - var gotReq, didDial bool - - ts := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) { - mu.Lock() - gotReq = true - mu.Unlock() - }, - optOnlyServer, - ) - defer ts.Close() - tr := &Transport{ - Options: &transport.Options{}, - DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) { - mu.Lock() - didDial = true - mu.Unlock() - cfg.InsecureSkipVerify = true - c, err := tls.Dial(netw, addr, cfg) - if err != nil { - return nil, err - } - return c, c.Handshake() - }, - } - defer tr.CloseIdleConnections() - client := &http.Client{Transport: tr} - res, err := client.Get(ts.ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - mu.Lock() - if !gotReq { - t.Error("didn't get request") - } - if !didDial { - t.Error("didn't use dial hook") - } -} - -type capitalizeReader struct { - r io.Reader -} - -func (cr capitalizeReader) Read(p []byte) (n int, err error) { - n, err = cr.r.Read(p) - for i, b := range p[:n] { - if b >= 'a' && b <= 'z' { - p[i] = b - ('a' - 'A') - } - } - return -} - -type flushWriter struct { - w io.Writer -} - -func (fw flushWriter) Write(p []byte) (n int, err error) { - n, err = fw.w.Write(p) - if f, ok := fw.w.(http.Flusher); ok { - f.Flush() - } - return -} - -type clientTester struct { - t *testing.T - tr *Transport - sc, cc net.Conn // server and client conn - fr *Framer // server's framer - client func() error - server func() error -} - -func newClientTester(t *testing.T) *clientTester { - var dialOnce struct { - sync.Mutex - dialed bool - } - ct := &clientTester{ - t: t, - } - ct.tr = &Transport{ - Options: &transport.Options{ - TLSClientConfig: tlsConfigInsecure, - }, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - dialOnce.Lock() - defer dialOnce.Unlock() - if dialOnce.dialed { - return nil, errors.New("only one dial allowed in test mode") - } - dialOnce.dialed = true - return ct.cc, nil - }, - } - - ln := tests.NewLocalListener(t) - cc, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - sc, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - ln.Close() - ct.cc = cc - ct.sc = sc - ct.fr = NewFramer(sc, sc) - return ct -} - -func newLocalListener(t *testing.T) net.Listener { - ln, err := net.Listen("tcp4", "127.0.0.1:0") - if err == nil { - return ln - } - ln, err = net.Listen("tcp6", "[::1]:0") - if err != nil { - t.Fatal(err) - } - return ln -} - -func (ct *clientTester) greet(settings ...Setting) { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - ct.t.Fatalf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - ct.t.Fatalf("Reading client settings frame: %v", err) - } - if sf, ok := f.(*SettingsFrame); !ok { - ct.t.Fatalf("Wanted client settings frame; got %v", f) - _ = sf // stash it away? - } - if err := ct.fr.WriteSettings(settings...); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } -} - -func (ct *clientTester) readNonSettingsFrame() (Frame, error) { - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return nil, err - } - if _, ok := f.(*SettingsFrame); ok { - continue - } - return f, nil - } -} - -// writeReadPing sends a PING and immediately reads the PING ACK. -// It will fail if any other unread data was pending on the connection, -// aside from SETTINGS frames. -func (ct *clientTester) writeReadPing() error { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - if err := ct.fr.WritePing(false, data); err != nil { - return fmt.Errorf("Error writing PING: %v", err) - } - f, err := ct.readNonSettingsFrame() - if err != nil { - return err - } - p, ok := f.(*PingFrame) - if !ok { - return fmt.Errorf("got a %v, want a PING ACK", f) - } - if p.Flags&FlagPingAck == 0 { - return fmt.Errorf("got a PING, want a PING ACK") - } - if p.Data != data { - return fmt.Errorf("got PING data = %x, want %x", p.Data, data) - } - return nil -} - -func (ct *clientTester) inflowWindow(streamID uint32) int32 { - pool := ct.tr.connPoolOrDef.(*clientConnPool) - pool.mu.Lock() - defer pool.mu.Unlock() - if n := len(pool.keys); n != 1 { - ct.t.Errorf("clientConnPool contains %v keys, expected 1", n) - return -1 - } - for cc := range pool.keys { - cc.mu.Lock() - defer cc.mu.Unlock() - if streamID == 0 { - return cc.inflow.avail + cc.inflow.unsent - } - cs := cc.streams[streamID] - if cs == nil { - ct.t.Errorf("no stream with id %v", streamID) - return -1 - } - return cs.inflow.avail + cs.inflow.unsent - } - return -1 -} - -func (ct *clientTester) cleanup() { - ct.tr.CloseIdleConnections() - - // close both connections, ignore the error if its already closed - ct.sc.Close() - ct.cc.Close() -} - -func (ct *clientTester) run() { - var errOnce sync.Once - var wg sync.WaitGroup - - run := func(which string, fn func() error) { - defer wg.Done() - if err := fn(); err != nil { - errOnce.Do(func() { - ct.t.Errorf("%s: %v", which, err) - ct.cleanup() - }) - } - } - - wg.Add(2) - go run("client", ct.client) - go run("server", ct.server) - wg.Wait() - - errOnce.Do(ct.cleanup) // clean up if no error -} - -func (ct *clientTester) readFrame() (Frame, error) { - return ct.fr.ReadFrame() -} - -func (ct *clientTester) firstHeaders() (*HeadersFrame, error) { - for { - f, err := ct.readFrame() - if err != nil { - return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - hf, ok := f.(*HeadersFrame) - if !ok { - return nil, fmt.Errorf("Got %T; want HeadersFrame", f) - } - return hf, nil - } -} - -type countingReader struct { - n *int64 -} - -func (r countingReader) Read(p []byte) (n int, err error) { - for i := range p { - p[i] = byte(i) - } - atomic.AddInt64(r.n, int64(len(p))) - return len(p), err -} - -func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } -func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } - -func testTransportReqBodyAfterResponse(t *testing.T, status int) { - const bodySize = 10 << 20 - clientDone := make(chan struct{}) - ct := newClientTester(t) - recvLen := make(chan int64, 1) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - - body := &pipe{b: new(bytes.Buffer)} - io.Copy(body, io.LimitReader(tests.NeverEnding('A'), bodySize/2)) - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - if res.StatusCode != status { - return fmt.Errorf("status code = %v; want %v", res.StatusCode, status) - } - io.Copy(body, io.LimitReader(tests.NeverEnding('A'), bodySize/2)) - body.CloseWithError(io.EOF) - slurp, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("Slurp: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("unexpected body: %q", slurp) - } - res.Body.Close() - if status == 200 { - if got := <-recvLen; got != bodySize { - return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize) - } - } else { - if got := <-recvLen; got == 0 || got >= bodySize { - return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize) - } - } - return nil - } - ct.server = func() error { - ct.greet() - defer close(recvLen) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var dataRecv int64 - var closed bool - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - // println(fmt.Sprintf("server got frame: %v", f)) - ended := false - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - if f.StreamEnded() { - return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f) - } - case *DataFrame: - dataLen := len(f.Data()) - if dataLen > 0 { - if dataRecv == 0 { - enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { - return err - } - if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { - return err - } - } - dataRecv += int64(dataLen) - - if !closed && ((status != 200 && dataRecv > 0) || - (status == 200 && f.StreamEnded())) { - closed = true - if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { - return err - } - } - - if f.StreamEnded() { - ended = true - } - case *RSTStreamFrame: - if status == 200 { - return fmt.Errorf("Unexpected client frame %v", f) - } - ended = true - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - if ended { - select { - case recvLen <- dataRecv: - default: - } - } - } - } - ct.run() -} - -// See golang.org/issue/13444 -func TestTransportFullDuplex(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) // redundant but for clarity - w.(http.Flusher).Flush() - io.Copy(flushWriter{w}, capitalizeReader{r.Body}) - fmt.Fprintf(w, "bye.\n") - }, optOnlyServer) - defer st.Close() - - tr := &Transport{ - Options: &transport.Options{ - TLSClientConfig: tlsConfigInsecure, - }, - } - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - - pr, pw := io.Pipe() - req, err := http.NewRequest("PUT", st.ts.URL, io.NopCloser(pr)) - if err != nil { - t.Fatal(err) - } - req.ContentLength = -1 - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200) - } - bs := bufio.NewScanner(res.Body) - want := func(v string) { - if !bs.Scan() { - t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err()) - } - } - write := func(v string) { - _, err := io.WriteString(pw, v) - if err != nil { - t.Fatalf("pipe write: %v", err) - } - } - write("foo\n") - want("FOO") - write("bar\n") - want("BAR") - pw.Close() - want("bye.") - if err := bs.Err(); err != nil { - t.Fatal(err) - } -} - -func TestTransportConnectRequest(t *testing.T) { - gotc := make(chan *http.Request, 1) - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - gotc <- r - }, optOnlyServer) - defer st.Close() - - u, err := url.Parse(st.ts.URL) - if err != nil { - t.Fatal(err) - } - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - - tests := []struct { - req *http.Request - want string - }{ - { - req: &http.Request{ - Method: "CONNECT", - Header: http.Header{}, - URL: u, - }, - want: u.Host, - }, - { - req: &http.Request{ - Method: "CONNECT", - Header: http.Header{}, - URL: u, - Host: "example.com:123", - }, - want: "example.com:123", - }, - } - - for i, tt := range tests { - res, err := c.Do(tt.req) - if err != nil { - t.Errorf("%d. RoundTrip = %v", i, err) - continue - } - res.Body.Close() - req := <-gotc - if req.Method != "CONNECT" { - t.Errorf("method = %q; want CONNECT", req.Method) - } - if req.Host != tt.want { - t.Errorf("Host = %q; want %q", req.Host, tt.want) - } - if req.URL.Host != tt.want { - t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want) - } - } -} - -type headerType int - -const ( - noHeader headerType = iota // omitted - oneHeader - splitHeader // broken into continuation on purpose -) - -const ( - f0 = noHeader - f1 = oneHeader - f2 = splitHeader - d0 = false - d1 = true -) - -// Test all 36 combinations of response frame orders: -// -// (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers):func TestTransportResponsePattern_00f0(t *testing.T) { testTransportResponsePattern(h0, h1, false, h0) } -// -// Generated by http://play.golang.org/p/SScqYKJYXd -func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) } -func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) } -func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) } -func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) } -func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) } -func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) } -func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) } -func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) } -func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) } -func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) } -func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) } -func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) } -func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) } -func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) } -func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) } -func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) } -func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) } -func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) } -func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) } -func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) } -func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) } -func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) } -func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) } -func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) } -func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) } -func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) } -func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) } -func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) } -func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) } -func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) } -func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) } -func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) } -func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) } -func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) } -func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) } -func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) } - -func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) { - const reqBody = "some request body" - const resBody = "some response body" - - if resHeader == noHeader { - // TODO: test 100-continue followed by immediate - // server stream reset, without headers in the middle? - panic("invalid combination") - } - - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody)) - if expect100Continue != noHeader { - req.Header.Set("Expect", "100-continue") - } - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("Slurp: %v", err) - } - wantBody := resBody - if !withData { - wantBody = "" - } - if string(slurp) != wantBody { - return fmt.Errorf("body = %q; want %q", slurp, wantBody) - } - if trailers == noHeader { - if len(res.Trailer) > 0 { - t.Errorf("Trailer = %v; want none", res.Trailer) - } - } else { - want := http.Header{"Some-Trailer": {"some-value"}} - if !reflect.DeepEqual(res.Trailer, want) { - t.Errorf("Trailer = %v; want %v", res.Trailer, want) - } - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - endStream := false - send := func(mode headerType) { - hbf := buf.Bytes() - switch mode { - case oneHeader: - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: true, - EndStream: endStream, - BlockFragment: hbf, - }) - case splitHeader: - if len(hbf) < 2 { - panic("too small") - } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: false, - EndStream: endStream, - BlockFragment: hbf[:1], - }) - ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:]) - default: - panic("bogus mode") - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *DataFrame: - if !f.StreamEnded() { - // No need to send flow control tokens. The test request body is tiny. - continue - } - // Response headers (1+ frames; 1 or 2 in this test, but never 0) - { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"}) - enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"}) - if trailers != noHeader { - enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"}) - } - endStream = withData == false && trailers == noHeader - send(resHeader) - } - if withData { - endStream = trailers == noHeader - ct.fr.WriteData(f.StreamID, endStream, []byte(resBody)) - } - if trailers != noHeader { - endStream = true - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"}) - send(trailers) - } - if endStream { - return nil - } - case *HeadersFrame: - if expect100Continue != noHeader { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"}) - send(expect100Continue) - } - } - } - } - ct.run() -} - -// Issue 26189, Issue 17739: ignore unknown 1xx responses -func TestTransportUnknown1xx(t *testing.T) { - var buf bytes.Buffer - defer func() { got1xxFuncForTests = nil }() - got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error { - fmt.Fprintf(&buf, "code=%d header=%v\n", code, header) - return nil - } - - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 204 { - return fmt.Errorf("status code = %v; want 204", res.StatusCode) - } - want := `code=110 header=map[Foo-Bar:[110]] -code=111 header=map[Foo-Bar:[111]] -code=112 header=map[Foo-Bar:[112]] -code=113 header=map[Foo-Bar:[113]] -code=114 header=map[Foo-Bar:[114]] -` - if got := buf.String(); got != want { - t.Errorf("Got trace:\n%s\nWant:\n%s", got, want) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - for i := 110; i <= 114; i++ { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)}) - enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - return nil - } - } - } - ct.run() - -} - -func TestTransportReceiveUndeclaredTrailer(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil) - } - if len(slurp) > 0 { - return fmt.Errorf("body = %q; want nothing", slurp) - } - if _, ok := res.Trailer["Some-Trailer"]; !ok { - return fmt.Errorf("expected Some-Trailer") - } - return nil - } - ct.server = func() error { - ct.greet() - - var n int - var hf *HeadersFrame - for hf == nil && n < 10 { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - hf, _ = f.(*HeadersFrame) - n++ - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - // send headers without Trailer header - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - - // send trailers - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - ct.run() -} - -func TestTransportInvalidTrailerPseudo1(t *testing.T) { - testTransportInvalidTrailerPseudo(t, oneHeader) -} -func TestTransportInvalidTrailerPseudo2(t *testing.T) { - testTransportInvalidTrailerPseudo(t, splitHeader) -} -func testTransportInvalidTrailerPseudo(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - }) -} - -func TestTransportInvalidTrailerCapital1(t *testing.T) { - testTransportInvalidTrailerCapital(t, oneHeader) -} -func TestTransportInvalidTrailerCapital2(t *testing.T) { - testTransportInvalidTrailerCapital(t, splitHeader) -} -func testTransportInvalidTrailerCapital(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"}) - }) -} -func TestTransportInvalidTrailerEmptyFieldName(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"}) - }) -} -func TestTransportInvalidTrailerBinaryFieldValue(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"}) - }) -} - -func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := io.ReadAll(res.Body) - se, ok := err.(StreamError) - if !ok || se.Cause != wantErr { - return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr) - } - if len(slurp) > 0 { - return fmt.Errorf("body = %q; want nothing", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var endStream bool - send := func(mode headerType) { - hbf := buf.Bytes() - switch mode { - case oneHeader: - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: endStream, - BlockFragment: hbf, - }) - case splitHeader: - if len(hbf) < 2 { - panic("too small") - } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: false, - EndStream: endStream, - BlockFragment: hbf[:1], - }) - ct.fr.WriteContinuation(f.StreamID, true, hbf[1:]) - default: - panic("bogus mode") - } - } - // Response headers (1+ frames; 1 or 2 in this test, but never 0) - { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"}) - endStream = false - send(oneHeader) - } - // Trailers: - { - endStream = true - buf.Reset() - writeTrailer(enc) - send(trailers) - } - return nil - } - } - } - ct.run() -} - -// headerListSize returns the HTTP2 header list size of h. -// -// http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE -// http://httpwg.org/specs/rfc7540.html#MaxHeaderBlock -func headerListSize(h http.Header) (size uint32) { - for k, vv := range h { - for _, v := range vv { - hf := hpack.HeaderField{Name: k, Value: v} - size += hf.Size() - } - } - return size -} - -// padHeaders adds data to an http.Header until headerListSize(h) == -// limit. Due to the way header list sizes are calculated, padHeaders -// cannot add fewer than len("Pad-Headers") + 32 bytes to h, and will -// call t.Fatal if asked to do so. PadHeaders first reserves enough -// space for an empty "Pad-Headers" key, then adds as many copies of -// filler as possible. Any remaining bytes necessary to push the -// header list size up to limit are added to h["Pad-Headers"]. -func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) { - if limit > 0xffffffff { - t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit) - } - hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""} - minPadding := uint64(hf.Size()) - size := uint64(headerListSize(h)) - - minlimit := size + minPadding - if limit < minlimit { - t.Fatalf("padHeaders: limit %v < %v", limit, minlimit) - } - - // Use a fixed-width format for name so that fieldSize - // remains constant. - nameFmt := "Pad-Headers-%06d" - hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler} - fieldSize := uint64(hf.Size()) - - // Add as many complete filler values as possible, leaving - // room for at least one empty "Pad-Headers" key. - limit = limit - minPadding - for i := 0; size+fieldSize < limit; i++ { - name := fmt.Sprintf(nameFmt, i) - h.Add(name, filler) - size += fieldSize - } - - // Add enough bytes to reach limit. - remain := limit - size - lastValue := strings.Repeat("*", int(remain)) - h.Add("Pad-Headers", lastValue) -} - -func TestPadHeaders(t *testing.T) { - check := func(h http.Header, limit uint32, fillerLen int) { - if h == nil { - h = make(http.Header) - } - filler := strings.Repeat("f", fillerLen) - padHeaders(t, h, uint64(limit), filler) - gotSize := headerListSize(h) - if gotSize != limit { - t.Errorf("Got size = %v; want %v", gotSize, limit) - } - } - // Try all possible combinations for small fillerLen and limit. - hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""} - minLimit := hf.Size() - for limit := minLimit; limit <= 128; limit++ { - for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ { - check(nil, limit, fillerLen) - } - } - - // Try a few tests with larger limits, plus cumulative - // tests. Since these tests are cumulative, tests[i+1].limit - // must be >= tests[i].limit + minLimit. See the comment on - // padHeaders for more info on why the limit arg has this - // restriction. - tests := []struct { - fillerLen int - limit uint32 - }{ - { - fillerLen: 64, - limit: 1024, - }, - { - fillerLen: 1024, - limit: 1286, - }, - { - fillerLen: 256, - limit: 2048, - }, - { - fillerLen: 1024, - limit: 10 * 1024, - }, - { - fillerLen: 1023, - limit: 11 * 1024, - }, - } - h := make(http.Header) - for _, tc := range tests { - check(nil, tc.limit, tc.fillerLen) - check(h, tc.limit, tc.fillerLen) - } -} - -func TestTransportChecksRequestHeaderListSize(t *testing.T) { - st := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) { - // Consume body & force client to send - // trailers before writing response. - // io.ReadAll returns non-nil err for - // requests that attempt to send greater than - // maxHeaderListSize bytes of trailers, since - // those requests generate a stream reset. - io.ReadAll(r.Body) - r.Body.Close() - }, - func(ts *httptest.Server) { - ts.Config.MaxHeaderBytes = 16 << 10 - }, - optOnlyServer, - optQuiet, - ) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - checkRoundTrip := func(req *http.Request, wantErr error, desc string) { - // Make an arbitrary request to ensure we get the server's - // settings frame and initialize peerMaxHeaderListSize. - req0, err := http.NewRequest("GET", st.ts.URL, nil) - if err != nil { - t.Fatalf("newRequest: NewRequest: %v", err) - } - res0, err := tr.RoundTrip(req0) - if err != nil { - t.Errorf("%v: Initial RoundTrip err = %v", desc, err) - } - res0.Body.Close() - res, err := tr.RoundTrip(req) - if err != wantErr { - if res != nil { - res.Body.Close() - } - t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr) - return - } - if err == nil { - if res == nil { - t.Errorf("%v: response nil; want non-nil.", desc) - return - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK) - } - return - } - if res != nil { - t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err) - } - } - headerListSizeForRequest := func(req *http.Request) (size uint64) { - contentLen := actualContentLength(req) - trailers, err := commaSeparatedTrailers(req) - if err != nil { - t.Fatalf("headerListSizeForRequest: %v", err) - } - cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff} - cc.henc = hpack.NewEncoder(&cc.hbuf) - cc.mu.Lock() - hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen, nil) - cc.mu.Unlock() - if err != nil { - t.Fatalf("headerListSizeForRequest: %v", err) - } - hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) { - size += uint64(hf.Size()) - }) - if len(hdrs) > 0 { - if _, err := hpackDec.Write(hdrs); err != nil { - t.Fatalf("headerListSizeForRequest: %v", err) - } - } - return size - } - // Create a new Request for each test, rather than reusing the - // same Request, to avoid a race when modifying req.Headers. - // See https://github.com/golang/go/issues/21316 - newRequest := func() *http.Request { - // Body must be non-nil to enable writing trailers. - body := strings.NewReader("hello") - req, err := http.NewRequest("POST", st.ts.URL, body) - if err != nil { - t.Fatalf("newRequest: NewRequest: %v", err) - } - return req - } - - // Validate peerMaxHeaderListSize. - req := newRequest() - checkRoundTrip(req, nil, "Initial request") - addr := authorityAddr(req.URL.Scheme, req.URL.Host) - cc, err := tr.connPool().GetClientConn(req, addr, true) - if err != nil { - t.Fatalf("GetClientConn: %v", err) - } - cc.mu.Lock() - peerSize := cc.peerMaxHeaderListSize - cc.mu.Unlock() - st.scMu.Lock() - wantSize := uint64(st.sc.maxHeaderListSize()) - st.scMu.Unlock() - if peerSize != wantSize { - t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize) - } - - // Sanity check peerSize. (*serverConn) maxHeaderListSize adds - // 320 bytes of padding. - wantHeaderBytes := uint64(st.ts.Config.MaxHeaderBytes) + 320 - if peerSize != wantHeaderBytes { - t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes) - } - - // Pad headers & trailers, but stay under peerSize. - req = newRequest() - req.Header = make(http.Header) - req.Trailer = make(http.Header) - filler := strings.Repeat("*", 1024) - padHeaders(t, req.Trailer, peerSize, filler) - // cc.encodeHeaders adds some default headers to the request, - // so we need to leave room for those. - defaultBytes := headerListSizeForRequest(req) - padHeaders(t, req.Header, peerSize-defaultBytes, filler) - checkRoundTrip(req, nil, "Headers & Trailers under limit") - - // Add enough header bytes to push us over peerSize. - req = newRequest() - req.Header = make(http.Header) - padHeaders(t, req.Header, peerSize, filler) - checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit") - - // Push trailers over the limit. - req = newRequest() - req.Trailer = make(http.Header) - padHeaders(t, req.Trailer, peerSize+1, filler) - checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit") - - // Send headers with a single large value. - req = newRequest() - filler = strings.Repeat("*", int(peerSize)) - req.Header = make(http.Header) - req.Header.Set("Big", filler) - checkRoundTrip(req, errRequestHeaderListSize, "Single large header") - - // Send trailers with a single large value. - req = newRequest() - req.Trailer = make(http.Header) - req.Trailer.Set("Big", filler) - checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer") -} - -func TestTransportChecksResponseHeaderListSize(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if e, ok := err.(StreamError); ok { - err = e.Cause - } - if err != errResponseHeaderListSize { - size := int64(0) - if res != nil { - res.Body.Close() - for k, vv := range res.Header { - for _, v := range vv { - size += int64(len(k)) + int64(len(v)) + 32 - } - } - } - return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - large := strings.Repeat("a", 1<<10) - for i := 0; i < 5042; i++ { - enc.WriteField(hpack.HeaderField{Name: large, Value: large}) - } - if size, want := buf.Len(), 6329; size != want { - // Note: this number might change if - // our hpack implementation - // changes. That's fine. This is - // just a sanity check that our - // response can fit in a single - // header block fragment frame. - return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) - } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } - } - ct.run() -} - -func TestTransportCookieHeaderSplit(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - req.Header.Add("Cookie", "a=b;c=d; e=f;") - req.Header.Add("Cookie", "e=f;g=h; ") - req.Header.Add("Cookie", "i=j") - _, err := ct.tr.RoundTrip(req) - return err - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - dec := hpack.NewDecoder(initialHeaderTableSize, nil) - hfs, err := dec.DecodeFull(f.HeaderBlockFragment()) - if err != nil { - return err - } - got := []string{} - want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"} - for _, hf := range hfs { - if hf.Name == "cookie" { - got = append(got, hf.Value) - } - } - if !reflect.DeepEqual(got, want) { - t.Errorf("Cookies = %#v, want %#v", got, want) - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } - } - ct.run() -} - -// Test that the Transport returns a typed error from Response.Body.Read calls -// when the server sends an error. (here we use a panic, since that should generate -// a stream error, but others like cancel should be similar) -func TestTransportBodyReadErrorType(t *testing.T) { - doPanic := make(chan bool, 1) - st := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) { - w.(http.Flusher).Flush() // force headers out - <-doPanic - panic("boom") - }, - optOnlyServer, - optQuiet, - ) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - - res, err := c.Get(st.ts.URL) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - doPanic <- true - buf := make([]byte, 100) - n, err := res.Body.Read(buf) - got, ok := err.(StreamError) - want := StreamError{StreamID: 0x1, Code: 0x2} - if !ok || got.StreamID != want.StreamID || got.Code != want.Code { - t.Errorf("Read = %v, %#v; want error %#v", n, err, want) - } -} - -// golang.org/issue/13924 -// This used to fail after many iterations, especially with -race: -// go test -v -run=TestTransportDoubleCloseOnWriteError -count=500 -race -func TestTransportDoubleCloseOnWriteError(t *testing.T) { - var ( - mu sync.Mutex - conn net.Conn // to close if set - ) - - st := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) { - mu.Lock() - defer mu.Unlock() - if conn != nil { - conn.Close() - } - }, - optOnlyServer, - ) - defer st.Close() - - tr := &Transport{ - Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - tc, err := tls.Dial(network, addr, cfg) - if err != nil { - return nil, err - } - mu.Lock() - defer mu.Unlock() - conn = tc - return tc, nil - }, - } - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - c.Get(st.ts.URL) -} - -// Test that the http1 Transport.DisableKeepAlives option is respected -// and connections are closed as soon as idle. -// See golang.org/issue/14008 -func TestTransportDisableKeepAlives(t *testing.T) { - st := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "hi") - }, - optOnlyServer, - ) - defer st.Close() - - connClosed := make(chan struct{}) // closed on tls.Conn.Close - tr := &Transport{ - Options: &transport.Options{ - DisableKeepAlives: true, - TLSClientConfig: tlsConfigInsecure, - }, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - tc, err := tls.Dial(network, addr, cfg) - if err != nil { - return nil, err - } - return ¬eCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil - }, - } - c := &http.Client{Transport: tr} - res, err := c.Get(st.ts.URL) - if err != nil { - t.Fatal(err) - } - if _, err := io.ReadAll(res.Body); err != nil { - t.Fatal(err) - } - defer res.Body.Close() - - select { - case <-connClosed: - case <-time.After(1 * time.Second): - t.Errorf("timeout") - } - -} - -// Test concurrent requests with Transport.DisableKeepAlives. We can share connections, -// but when things are totally idle, it still needs to close. -func TestTransportDisableKeepAlives_Concurrency(t *testing.T) { - const D = 25 * time.Millisecond - st := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) { - time.Sleep(D) - io.WriteString(w, "hi") - }, - optOnlyServer, - ) - defer st.Close() - - var dials int32 - var conns sync.WaitGroup - tr := &Transport{ - Options: &transport.Options{ - DisableKeepAlives: true, - TLSClientConfig: tlsConfigInsecure, - }, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - tc, err := tls.Dial(network, addr, cfg) - if err != nil { - return nil, err - } - atomic.AddInt32(&dials, 1) - conns.Add(1) - return ¬eCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil - }, - } - c := &http.Client{Transport: tr} - var reqs sync.WaitGroup - const N = 20 - for i := 0; i < N; i++ { - reqs.Add(1) - if i == N-1 { - // For the final request, try to make all the - // others close. This isn't verified in the - // count, other than the Log statement, since - // it's so timing dependent. This test is - // really to make sure we don't interrupt a - // valid request. - time.Sleep(D * 2) - } - go func() { - defer reqs.Done() - res, err := c.Get(st.ts.URL) - if err != nil { - t.Error(err) - return - } - if _, err := io.ReadAll(res.Body); err != nil { - t.Error(err) - return - } - res.Body.Close() - }() - } - reqs.Wait() - conns.Wait() - t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N) -} - -type noteCloseConn struct { - net.Conn - onceClose sync.Once - closefn func() -} - -func (c *noteCloseConn) Close() error { - c.onceClose.Do(c.closefn) - return c.Conn.Close() -} - -func isTimeout(err error) bool { - switch err := err.(type) { - case nil: - return false - case *url.Error: - return isTimeout(err.Err) - case net.Error: - return err.Timeout() - } - return false -} - -// Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent. -func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) { - testTransportResponseHeaderTimeout(t, false) -} -func TestTransportResponseHeaderTimeout_Body(t *testing.T) { - testTransportResponseHeaderTimeout(t, true) -} - -func testTransportResponseHeaderTimeout(t *testing.T, body bool) { - ct := newClientTester(t) - ct.tr.Options.ResponseHeaderTimeout = 5 * time.Millisecond - ct.client = func() error { - c := &http.Client{Transport: ct.tr} - var err error - var n int64 - const bodySize = 4 << 20 - if body { - _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize)) - } else { - _, err = c.Get("https://dummy.tld/") - } - if !isTimeout(err) { - t.Errorf("client expected timeout error; got %#v", err) - } - if body && n != bodySize { - t.Errorf("only read %d bytes of body; want %d", n, bodySize) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - switch f := f.(type) { - case *DataFrame: - dataLen := len(f.Data()) - if dataLen > 0 { - if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { - return err - } - if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { - return err - } - } - case *RSTStreamFrame: - if f.StreamID == 1 && f.ErrCode == ErrCodeCancel { - return nil - } - } - } - } - ct.run() -} - -func TestTransportDisableCompression(t *testing.T) { - const body = "sup" - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - want := http.Header{ - "User-Agent": []string{header.DefaultUserAgent}, - } - if !reflect.DeepEqual(r.Header, want) { - t.Errorf("request headers = %v; want %v", r.Header, want) - } - }, optOnlyServer) - defer st.Close() - - tr := &Transport{ - Options: &transport.Options{ - DisableCompression: true, - TLSClientConfig: tlsConfigInsecure, - }, - } - defer tr.CloseIdleConnections() - - req, err := http.NewRequest("GET", st.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() -} - -// RFC 7540 section 8.1.2.2 -func TestTransportRejectsConnHeaders(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - var got []string - for k := range r.Header { - got = append(got, k) - } - sort.Strings(got) - w.Header().Set("Got-Header", strings.Join(got, ",")) - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - tests := []struct { - key string - value []string - want string - }{ - { - key: "Upgrade", - value: []string{"anything"}, - want: "ERROR: http2: invalid Upgrade request header: [\"anything\"]", - }, - { - key: "Connection", - value: []string{"foo"}, - want: "ERROR: http2: invalid Connection request header: [\"foo\"]", - }, - { - key: "Connection", - value: []string{"close"}, - want: "Accept-Encoding,User-Agent", - }, - { - key: "Connection", - value: []string{"CLoSe"}, - want: "Accept-Encoding,User-Agent", - }, - { - key: "Connection", - value: []string{"close", "something-else"}, - want: "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]", - }, - { - key: "Connection", - value: []string{"keep-alive"}, - want: "Accept-Encoding,User-Agent", - }, - { - key: "Connection", - value: []string{"Keep-ALIVE"}, - want: "Accept-Encoding,User-Agent", - }, - { - key: "Proxy-Connection", // just deleted and ignored - value: []string{"keep-alive"}, - want: "Accept-Encoding,User-Agent", - }, - { - key: "Transfer-Encoding", - value: []string{""}, - want: "Accept-Encoding,User-Agent", - }, - { - key: "Transfer-Encoding", - value: []string{"foo"}, - want: "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]", - }, - { - key: "Transfer-Encoding", - value: []string{"chunked"}, - want: "Accept-Encoding,User-Agent", - }, - { - key: "Transfer-Encoding", - value: []string{"chunKed"}, // Kelvin sign - want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunKed\"]", - }, - { - key: "Transfer-Encoding", - value: []string{"chunked", "other"}, - want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]", - }, - { - key: "Content-Length", - value: []string{"123"}, - want: "Accept-Encoding,User-Agent", - }, - { - key: "Keep-Alive", - value: []string{"doop"}, - want: "Accept-Encoding,User-Agent", - }, - } - - for _, tt := range tests { - req, _ := http.NewRequest("GET", st.ts.URL, nil) - req.Header[tt.key] = tt.value - res, err := tr.RoundTrip(req) - var got string - if err != nil { - got = fmt.Sprintf("ERROR: %v", err) - } else { - got = res.Header.Get("Got-Header") - res.Body.Close() - } - if got != tt.want { - t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want) - } - } -} - -// Reject content-length headers containing a sign. -// See https://golang.org/issue/39017 -func TestTransportRejectsContentLengthWithSign(t *testing.T) { - testCases := []struct { - name string - cl []string - wantCL string - }{ - { - name: "proper content-length", - cl: []string{"3"}, - wantCL: "3", - }, - { - name: "ignore cl with plus sign", - cl: []string{"+3"}, - wantCL: "", - }, - { - name: "ignore cl with minus sign", - cl: []string{"-3"}, - wantCL: "", - }, - { - name: "max int64, for safe uint64->int64 conversion", - cl: []string{"9223372036854775807"}, - wantCL: "9223372036854775807", - }, - { - name: "overflows int64, so ignored", - cl: []string{"9223372036854775808"}, - wantCL: "", - }, - } - - for _, tt := range testCases { - tt := tt - t.Run(tt.name, func(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", tt.cl[0]) - }, optOnlyServer) - defer st.Close() - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - req, _ := http.NewRequest("HEAD", st.ts.URL, nil) - res, err := tr.RoundTrip(req) - - var got string - if err != nil { - got = fmt.Sprintf("ERROR: %v", err) - } else { - got = res.Header.Get("Content-Length") - res.Body.Close() - } - - if got != tt.wantCL { - t.Fatalf("Got: %q\nWant: %q", got, tt.wantCL) - } - }) - } -} - -// golang.org/issue/14048 -func TestTransportFailsOnInvalidHeaders(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - var got []string - for k := range r.Header { - got = append(got, k) - } - sort.Strings(got) - w.Header().Set("Got-Header", strings.Join(got, ",")) - }, optOnlyServer) - defer st.Close() - - testCases := [...]struct { - h http.Header - wantErr string - }{ - 0: { - h: http.Header{"with space": {"foo"}}, - wantErr: `invalid HTTP header name "with space"`, - }, - 1: { - h: http.Header{"name": {"Брэд"}}, - wantErr: "", // okay - }, - 2: { - h: http.Header{"имя": {"Brad"}}, - wantErr: `invalid HTTP header name "имя"`, - }, - 3: { - h: http.Header{"foo": {"foo\x01bar"}}, - wantErr: `invalid HTTP header value for header "foo"`, - }, - } - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - for i, tt := range testCases { - req, _ := http.NewRequest("GET", st.ts.URL, nil) - req.Header = tt.h - res, err := tr.RoundTrip(req) - var bad bool - if tt.wantErr == "" { - if err != nil { - bad = true - t.Errorf("case %d: error = %v; want no error", i, err) - } - } else { - if !strings.Contains(fmt.Sprint(err), tt.wantErr) { - bad = true - t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr) - } - } - if err == nil { - if bad { - t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header")) - } - res.Body.Close() - } - } -} - -// Tests that GzipReader doesn't crash on a second Read call following -// the first Read call's gzip.NewReader returning an error. -func TestGzipReader_DoubleReadCrash(t *testing.T) { - gz := &GzipReader{ - Body: io.NopCloser(strings.NewReader("0123456789")), - } - var buf [1]byte - n, err1 := gz.Read(buf[:]) - if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") { - t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1) - } - n, err2 := gz.Read(buf[:]) - if n != 0 || err2 != err1 { - t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1) - } -} - -func TestGzipReader_ReadAfterClose(t *testing.T) { - body := bytes.Buffer{} - w := gzip.NewWriter(&body) - w.Write([]byte("012345679")) - w.Close() - gz := &GzipReader{ - Body: io.NopCloser(&body), - } - var buf [1]byte - n, err := gz.Read(buf[:]) - if n != 1 || err != nil { - t.Fatalf("first Read = %v, %v; want 1, nil", n, err) - } - if err := gz.Close(); err != nil { - t.Fatalf("gz Close error: %v", err) - } - n, err = gz.Read(buf[:]) - if n != 0 || err != fs.ErrClosed { - t.Fatalf("Read after close = %v, %v; want 0, fs.ErrClosed", n, err) - } -} - -func TestTransportNewTLSConfig(t *testing.T) { - testCases := [...]struct { - conf *tls.Config - host string - want *tls.Config - }{ - // Normal case. - 0: { - conf: nil, - host: "foo.com", - want: &tls.Config{ - ServerName: "foo.com", - NextProtos: []string{NextProtoTLS}, - }, - }, - - // User-provided name (bar.com) takes precedence: - 1: { - conf: &tls.Config{ - ServerName: "bar.com", - }, - host: "foo.com", - want: &tls.Config{ - ServerName: "bar.com", - NextProtos: []string{NextProtoTLS}, - }, - }, - - // NextProto is prepended: - 2: { - conf: &tls.Config{ - NextProtos: []string{"foo", "bar"}, - }, - host: "example.com", - want: &tls.Config{ - ServerName: "example.com", - NextProtos: []string{NextProtoTLS, "foo", "bar"}, - }, - }, - - // NextProto is not duplicated: - 3: { - conf: &tls.Config{ - NextProtos: []string{"foo", "bar", NextProtoTLS}, - }, - host: "example.com", - want: &tls.Config{ - ServerName: "example.com", - NextProtos: []string{"foo", "bar", NextProtoTLS}, - }, - }, - } - for i, tt := range testCases { - // Ignore the session ticket keys part, which ends up populating - // unexported fields in the Config: - if tt.conf != nil { - tt.conf.SessionTicketsDisabled = true - } - - tr := &Transport{ - Options: &transport.Options{ - TLSClientConfig: tt.conf, - }, - } - got := tr.newTLSConfig(tt.host) - - got.SessionTicketsDisabled = false - - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("%d. got %#v; want %#v", i, got, tt.want) - } - } -} - -// The Google GFE responds to HEAD requests with a HEADERS frame -// without END_STREAM, followed by a 0-length DATA frame with -// END_STREAM. Make sure we don't get confused by that. (We did.) -func TestTransportReadHeadResponse(t *testing.T) { - ct := newClientTester(t) - clientDone := make(chan struct{}) - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - if res.ContentLength != 123 { - return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength) - } - slurp, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("ReadAll: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, // as the GFE does - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(hf.StreamID, true, nil) - - <-clientDone - return nil - } - } - ct.run() -} - -func TestTransportReadHeadResponseWithBody(t *testing.T) { - // This test use not valid response format. - // Discarding logger output to not spam tests output. - log.SetOutput(io.Discard) - defer log.SetOutput(os.Stderr) - - response := "redirecting to /elsewhere" - ct := newClientTester(t) - clientDone := make(chan struct{}) - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - if res.ContentLength != int64(len(response)) { - return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response)) - } - slurp, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("ReadAll: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(hf.StreamID, true, []byte(response)) - - <-clientDone - return nil - } - } - ct.run() -} - -// https://golang.org/issue/15930 -func TestTransportFlowControl(t *testing.T) { - const bufLen = 64 << 10 - var total int64 = 100 << 20 // 100MB - if testing.Short() { - total = 10 << 20 - } - - var wrote int64 // updated atomically - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - b := make([]byte, bufLen) - for wrote < total { - n, err := w.Write(b) - atomic.AddInt64(&wrote, int64(n)) - if err != nil { - t.Errorf("ResponseWriter.Write error: %v", err) - break - } - w.(http.Flusher).Flush() - } - }, optOnlyServer) - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - req, err := http.NewRequest("GET", st.ts.URL, nil) - if err != nil { - t.Fatal("NewRequest error:", err) - } - resp, err := tr.RoundTrip(req) - if err != nil { - t.Fatal("RoundTrip error:", err) - } - defer resp.Body.Close() - - var read int64 - b := make([]byte, bufLen) - for { - n, err := resp.Body.Read(b) - if err == io.EOF { - break - } - if err != nil { - t.Fatal("Read error:", err) - } - read += int64(n) - - const max = transportDefaultStreamFlow - if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max { - t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read) - } - - // Let the server get ahead of the client. - time.Sleep(1 * time.Millisecond) - } -} - -// golang.org/issue/14627 -- if the server sends a GOAWAY frame, make -// the Transport remember it and return it back to users (via -// RoundTrip or request body reads) if needed (e.g. if the server -// proceeds to close the TCP connection before the client gets its -// response) -func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) { - testTransportUsesGoAwayDebugError(t, false) -} - -func TestTransportUsesGoAwayDebugError_Body(t *testing.T) { - testTransportUsesGoAwayDebugError(t, true) -} - -func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { - ct := newClientTester(t) - clientDone := make(chan struct{}) - - const goAwayErrCode = ErrCodeHTTP11Required // arbitrary - const goAwayDebugData = "some debug data" - - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if failMidBody { - if err != nil { - return fmt.Errorf("unexpected client RoundTrip error: %v", err) - } - _, err = io.Copy(io.Discard, res.Body) - res.Body.Close() - } - want := GoAwayError{ - LastStreamID: 5, - ErrCode: goAwayErrCode, - DebugData: goAwayDebugData, - } - if !reflect.DeepEqual(err, want) { - t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - if failMidBody { - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - // Write two GOAWAY frames, to test that the Transport takes - // the interesting parts of both. - ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) - ct.fr.WriteGoAway(5, goAwayErrCode, nil) - ct.sc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - ct.sc.(*net.TCPConn).Close() - } - <-clientDone - return nil - } - } - ct.run() -} - -func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { - ct := newClientTester(t) - - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - - if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { - return fmt.Errorf("body read = %v, %v; want 1, nil", n, err) - } - res.Body.Close() // leaving 4999 bytes unread - - return nil - } - ct.server = func() error { - ct.greet() - - var hf *HeadersFrame - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - var ok bool - hf, ok = f.(*HeadersFrame) - if !ok { - return fmt.Errorf("Got %T; want HeadersFrame", f) - } - break - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - initialInflow := ct.inflowWindow(0) - - // Two cases: - // - Send one DATA frame with 5000 bytes. - // - Send two DATA frames with 1 and 4999 bytes each. - // - // In both cases, the client should consume one byte of data, - // refund that byte, then refund the following 4999 bytes. - // - // In the second case, the server waits for the client to reset the - // stream before sending the second DATA frame. This tests the case - // where the client receives a DATA frame after it has reset the stream. - if oneDataFrame { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000)) - } else { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1)) - } - - wantRST := true - wantWUF := true - if !oneDataFrame { - wantWUF = false // flow control update is small, and will not be sent - } - for wantRST || wantWUF { - f, err := ct.readNonSettingsFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *RSTStreamFrame: - if !wantRST { - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) - } - if f.ErrCode != ErrCodeCancel { - return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) - } - wantRST = false - case *WindowUpdateFrame: - if !wantWUF { - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) - } - if f.Increment != 5000 { - return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) - } - wantWUF = false - default: - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) - } - } - if !oneDataFrame { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) - f, err := ct.readNonSettingsFrame() - if err != nil { - return err - } - wuf, ok := f.(*WindowUpdateFrame) - if !ok || wuf.Increment != 5000 { - return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f)) - } - } - if err := ct.writeReadPing(); err != nil { - return err - } - if got, want := ct.inflowWindow(0), initialInflow; got != want { - return fmt.Errorf("connection flow tokens = %v, want %v", got, want) - } - return nil - } - ct.run() -} - -// See golang.org/issue/16481 -func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) { - testTransportReturnsUnusedFlowControl(t, true) -} - -// See golang.org/issue/20469 -func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) { - testTransportReturnsUnusedFlowControl(t, false) -} - -// Issue 16612: adjust flow control on open streams when transport -// receives SETTINGS with INITIAL_WINDOW_SIZE from server. -func TestTransportAdjustsFlowControl(t *testing.T) { - ct := newClientTester(t) - clientDone := make(chan struct{}) - - const bodySize = 1 << 20 - - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - - req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(tests.NeverEnding('A'), bodySize)}) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - res.Body.Close() - return nil - } - ct.server = func() error { - _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface))) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - - var gotBytes int64 - var sentSettings bool - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - return nil - default: - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - } - switch f := f.(type) { - case *DataFrame: - gotBytes += int64(len(f.Data())) - // After we've got half the client's - // initial flow control window's worth - // of request body data, give it just - // enough flow control to finish. - if gotBytes >= initialWindowSize/2 && !sentSettings { - sentSettings = true - - ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) - ct.fr.WriteWindowUpdate(0, bodySize) - ct.fr.WriteSettingsAck() - } - - if f.StreamEnded() { - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - } - } - } - ct.run() -} - -// See golang.org/issue/16556 -func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { - ct := newClientTester(t) - - unblockClient := make(chan bool, 1) - - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - defer res.Body.Close() - <-unblockClient - return nil - } - ct.server = func() error { - ct.greet() - - var hf *HeadersFrame - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - var ok bool - hf, ok = f.(*HeadersFrame) - if !ok { - return fmt.Errorf("Got %T; want HeadersFrame", f) - } - break - } - - initialConnWindow := ct.inflowWindow(0) - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - initialStreamWindow := ct.inflowWindow(hf.StreamID) - pad := make([]byte, 5) - ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream - if err := ct.writeReadPing(); err != nil { - return err - } - // Padding flow control should have been returned. - if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want { - t.Errorf("conn inflow window = %v, want %v", got, want) - } - if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want { - t.Errorf("stream inflow window = %v, want %v", got, want) - } - unblockClient <- true - return nil - } - ct.run() -} - -// golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a -// StreamError as a result of the response HEADERS -func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { - ct := newClientTester(t) - - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err == nil { - res.Body.Close() - return errors.New("unexpected successful GET") - } - want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} - if !reflect.DeepEqual(want, err) { - t.Errorf("RoundTrip error = %#v; want %#v", err, want) - } - return nil - } - ct.server = func() error { - ct.greet() - - hf, err := ct.firstHeaders() - if err != nil { - return err - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - - for { - fr, err := ct.readFrame() - if err != nil { - return fmt.Errorf("error waiting for RST_STREAM from client: %v", err) - } - if _, ok := fr.(*SettingsFrame); ok { - continue - } - if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol { - t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) - } - break - } - - return nil - } - ct.run() -} - -// byteAndEOFReader returns is in an io.Reader which reads one byte -// (the underlying byte) and io.EOF at once in its Read call. -type byteAndEOFReader byte - -func (b byteAndEOFReader) Read(p []byte) (n int, err error) { - if len(p) == 0 { - panic("unexpected useless call") - } - p[0] = byte(b) - return 1, io.EOF -} - -// Issue 16788: the Transport had a regression where it started -// sending a spurious DATA frame with a duplicate END_STREAM bit after -// the request body writer goroutine had already read an EOF from the -// Request.Body and included the END_STREAM on a data-carrying DATA -// frame. -// -// Notably, to trigger this, the requests need to use a Request.Body -// which returns (non-0, io.EOF) and also needs to set the ContentLength -// explicitly. -func TestTransportBodyDoubleEndStream(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - // Nothing. - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - for i := 0; i < 2; i++ { - req, _ := http.NewRequest("POST", st.ts.URL, byteAndEOFReader('a')) - req.ContentLength = 1 - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatalf("failure on req %d: %v", i+1, err) - } - defer res.Body.Close() - } -} - -// golang.org/issue/16847, golang.org/issue/19103 -func TestTransportRequestPathPseudo(t *testing.T) { - type result struct { - path string - err string - } - tests := []struct { - req *http.Request - want result - }{ - 0: { - req: &http.Request{ - Method: "GET", - URL: &url.URL{ - Host: "foo.com", - Path: "/foo", - }, - }, - want: result{path: "/foo"}, - }, - // In Go 1.7, we accepted paths of "//foo". - // In Go 1.8, we rejected it (issue 16847). - // In Go 1.9, we accepted it again (issue 19103). - 1: { - req: &http.Request{ - Method: "GET", - URL: &url.URL{ - Host: "foo.com", - Path: "//foo", - }, - }, - want: result{path: "//foo"}, - }, - - // Opaque with //$Matching_Hostname/path - 2: { - req: &http.Request{ - Method: "GET", - URL: &url.URL{ - Scheme: "https", - Opaque: "//foo.com/path", - Host: "foo.com", - Path: "/ignored", - }, - }, - want: result{path: "/path"}, - }, - - // Opaque with some other Request.Host instead: - 3: { - req: &http.Request{ - Method: "GET", - Host: "bar.com", - URL: &url.URL{ - Scheme: "https", - Opaque: "//bar.com/path", - Host: "foo.com", - Path: "/ignored", - }, - }, - want: result{path: "/path"}, - }, - - // Opaque without the leading "//": - 4: { - req: &http.Request{ - Method: "GET", - URL: &url.URL{ - Opaque: "/path", - Host: "foo.com", - Path: "/ignored", - }, - }, - want: result{path: "/path"}, - }, - - // Opaque we can't handle: - 5: { - req: &http.Request{ - Method: "GET", - URL: &url.URL{ - Scheme: "https", - Opaque: "//unknown_host/path", - Host: "foo.com", - Path: "/ignored", - }, - }, - want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`}, - }, - - // A CONNECT request: - 6: { - req: &http.Request{ - Method: "CONNECT", - URL: &url.URL{ - Host: "foo.com", - }, - }, - want: result{}, - }, - } - for i, tt := range tests { - cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff} - cc.henc = hpack.NewEncoder(&cc.hbuf) - cc.mu.Lock() - hdrs, err := cc.encodeHeaders(tt.req, false, "", -1, nil) - cc.mu.Unlock() - var got result - hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) { - if f.Name == ":path" { - got.path = f.Value - } - }) - if err != nil { - got.err = err.Error() - } else if len(hdrs) > 0 { - if _, err := hpackDec.Write(hdrs); err != nil { - t.Errorf("%d. bogus hpack: %v", i, err) - continue - } - } - if got != tt.want { - t.Errorf("%d. got %+v; want %+v", i, got, tt.want) - } - - } - -} - -// golang.org/issue/17071 -- don't sniff the first byte of the request body -// before we've determined that the ClientConn is usable. -func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { - const body = "foo" - req, _ := http.NewRequest("POST", "http://foo.com/", io.NopCloser(strings.NewReader(body))) - cc := &ClientConn{ - closed: true, - reqHeaderMu: make(chan struct{}, 1), - } - _, err := cc.RoundTrip(req) - if err != errClientConnUnusable { - t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err) - } - slurp, err := io.ReadAll(req.Body) - if err != nil { - t.Errorf("ReadAll = %v", err) - } - if string(slurp) != body { - t.Errorf("Body = %q; want %q", slurp, body) - } -} - -func TestClientConnPing(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) - defer st.Close() - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - ctx := context.Background() - cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) - if err != nil { - t.Fatal(err) - } - if err = cc.Ping(context.Background()); err != nil { - t.Fatal(err) - } -} - -// Issue 16974: if the server sent a DATA frame after the user -// canceled the Transport's Request, the Transport previously wrote to a -// closed pipe, got an error, and ended up closing the whole TCP -// connection. -func TestTransportCancelDataResponseRace(t *testing.T) { - cancel := make(chan struct{}) - clientGotResponse := make(chan bool, 1) - - const msg = "Hello." - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "/hello") { - time.Sleep(50 * time.Millisecond) - io.WriteString(w, msg) - return - } - for i := 0; i < 50; i++ { - io.WriteString(w, "Some data.") - w.(http.Flusher).Flush() - if i == 2 { - <-clientGotResponse - close(cancel) - } - time.Sleep(10 * time.Millisecond) - } - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - c := &http.Client{Transport: tr} - req, _ := http.NewRequest("GET", st.ts.URL, nil) - req.Cancel = cancel - res, err := c.Do(req) - clientGotResponse <- true - if err != nil { - t.Fatal(err) - } - if _, err = io.Copy(io.Discard, res.Body); err == nil { - t.Fatal("unexpected success") - } - - res, err = c.Get(st.ts.URL + "/hello") - if err != nil { - t.Fatal(err) - } - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if string(slurp) != msg { - t.Errorf("Got = %q; want %q", slurp, msg) - } -} - -// Issue 21316: It should be safe to reuse an http.Request after the -// request has completed. -func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - io.WriteString(w, "body") - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - req, _ := http.NewRequest("GET", st.ts.URL, nil) - resp, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - if _, err = io.Copy(io.Discard, resp.Body); err != nil { - t.Fatalf("error reading response body: %v", err) - } - if err := resp.Body.Close(); err != nil { - t.Fatalf("error closing response body: %v", err) - } - - // This access of req.Header should not race with code in the transport. - req.Header = http.Header{} -} - -func TestTransportCloseAfterLostPing(t *testing.T) { - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.tr.PingTimeout = 1 * time.Second - ct.tr.ReadIdleTimeout = 1 * time.Second - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - _, err := ct.tr.RoundTrip(req) - if err == nil || !strings.Contains(err.Error(), "client connection lost") { - return fmt.Errorf("expected to get error about \"connection lost\", got %v", err) - } - return nil - } - ct.server = func() error { - ct.greet() - <-clientDone - return nil - } - ct.run() -} - -func TestTransportPingWriteBlocks(t *testing.T) { - st := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) {}, - optOnlyServer, - ) - defer st.Close() - tr := &Transport{ - Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - s, c := net.Pipe() // unbuffered, unlike a TCP conn - go func() { - // Read initial handshake frames. - // Without this, we block indefinitely in newClientConn, - // and never get to the point of sending a PING. - var buf [1024]byte - s.Read(buf[:]) - }() - return c, nil - }, - PingTimeout: 1 * time.Millisecond, - ReadIdleTimeout: 1 * time.Millisecond, - } - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - _, err := c.Get(st.ts.URL) - if err == nil { - t.Fatalf("Get = nil, want error") - } -} - -func TestTransportPingWhenReading(t *testing.T) { - testCases := []struct { - name string - readIdleTimeout time.Duration - deadline time.Duration - expectedPingCount int - }{ - { - name: "two pings", - readIdleTimeout: 100 * time.Millisecond, - deadline: time.Second, - expectedPingCount: 2, - }, - { - name: "zero ping", - readIdleTimeout: time.Second, - deadline: 200 * time.Millisecond, - expectedPingCount: 0, - }, - { - name: "0 readIdleTimeout means no ping", - readIdleTimeout: 0 * time.Millisecond, - deadline: 500 * time.Millisecond, - expectedPingCount: 0, - }, - } - - for _, tc := range testCases { - tc := tc // capture range variable - t.Run(tc.name, func(t *testing.T) { - testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount) - }) - } -} - -func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) { - var pingCount int - ct := newClientTester(t) - ct.tr.ReadIdleTimeout = readIdleTimeout - - ctx, cancel := context.WithTimeout(context.Background(), deadline) - defer cancel() - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200) - } - _, err = io.ReadAll(res.Body) - if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) { - return nil - } - - cancel() - return err - } - - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var streamID uint32 - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-ctx.Done(): - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - streamID = f.StreamID - case *PingFrame: - pingCount++ - if pingCount == expectedPingCount { - if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil { - return err - } - } - if err := ct.fr.WritePing(true, f.Data); err != nil { - return err - } - case *RSTStreamFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } - } - ct.run() -} - -func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) { - ln := newLocalListener(t) - defer ln.Close() - - var ( - mu sync.Mutex - count int - conns []net.Conn - ) - var wg sync.WaitGroup - tr := &Transport{ - Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, - } - tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { - mu.Lock() - defer mu.Unlock() - count++ - cc, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - return nil, fmt.Errorf("dial error: %v", err) - } - conns = append(conns, cc) - sc, err := ln.Accept() - if err != nil { - return nil, fmt.Errorf("accept error: %v", err) - } - conns = append(conns, sc) - ct := &clientTester{ - t: t, - tr: tr, - cc: cc, - sc: sc, - fr: NewFramer(sc, sc), - } - wg.Add(1) - go func(count int) { - defer wg.Done() - server(count, ct) - }(count) - return cc, nil - } - - client(tr) - tr.CloseIdleConnections() - ln.Close() - for _, c := range conns { - c.Close() - } - wg.Wait() -} - -func TestTransportRetryAfterGOAWAY(t *testing.T) { - client := func(tr *Transport) { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := tr.RoundTrip(req) - if res != nil { - res.Body.Close() - if got := res.Header.Get("Foo"); got != "bar" { - err = fmt.Errorf("foo header = %q; want bar", got) - } - } - if err != nil { - t.Errorf("RoundTrip: %v", err) - } - } - - server := func(count int, ct *clientTester) { - switch count { - case 1: - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - t.Errorf("server1 failed reading HEADERS: %v", err) - return - } - t.Logf("server1 got %v", hf) - if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { - t.Errorf("server1 failed writing GOAWAY: %v", err) - return - } - case 2: - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - t.Errorf("server2 failed reading HEADERS: %v", err) - return - } - t.Logf("server2 got %v", hf) - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - err = ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - if err != nil { - t.Errorf("server2 failed writing response HEADERS: %v", err) - } - default: - t.Errorf("unexpected number of dials") - return - } - } - - testClientMultipleDials(t, client, server) -} - -func TestTransportRetryAfterRefusedStream(t *testing.T) { - clientDone := make(chan struct{}) - client := func(tr *Transport) { - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := tr.RoundTrip(req) - if err != nil { - t.Errorf("RoundTrip: %v", err) - return - } - resp.Body.Close() - if resp.StatusCode != 204 { - t.Errorf("Status = %v; want 204", resp.StatusCode) - return - } - } - - server := func(count int, ct *clientTester) { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - default: - t.Error(err) - } - return - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - t.Errorf("headers should have END_HEADERS be ended: %v", f) - return - } - if count == 1 { - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - } else { - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - default: - t.Errorf("Unexpected client frame %v", f) - return - } - } - } - - testClientMultipleDials(t, client, server) -} - -func TestTransportRetryHasLimit(t *testing.T) { - // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s. - if testing.Short() { - t.Skip("skipping long test in short mode") - } - retryBackoffHook = func(d time.Duration) *time.Timer { - return time.NewTimer(0) // fires immediately - } - defer func() { - retryBackoffHook = nil - }() - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := ct.tr.RoundTrip(req) - if err == nil { - return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) - } - t.Logf("expected error, got: %v", err) - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } - } - ct.run() -} - -func TestTransportResponseDataBeforeHeaders(t *testing.T) { - // This test use not valid response format. - // Discarding logger output to not spam tests output. - log.SetOutput(io.Discard) - defer log.SetOutput(os.Stderr) - - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - req := httptest.NewRequest("GET", "https://dummy.tld/", nil) - // First request is normal to ensure the check is per stream and not per connection. - _, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip expected no error, got: %v", err) - } - // Second request returns a DATA frame with no HEADERS. - resp, err := ct.tr.RoundTrip(req) - if err == nil { - return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) - } - if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol { - return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err == io.EOF { - return nil - } else if err != nil { - return err - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame: - case *HeadersFrame: - switch f.StreamID { - case 1: - // Send a valid response to first request. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - case 3: - ct.fr.WriteData(f.StreamID, true, []byte("payload")) - } - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } - } - ct.run() -} - -func TestTransportRequestsLowServerLimit(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - }, optOnlyServer, func(s *Server) { - s.MaxConcurrentStreams = 1 - }) - defer st.Close() - - var ( - connCountMu sync.Mutex - connCount int - ) - tr := &Transport{ - Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - connCountMu.Lock() - defer connCountMu.Unlock() - connCount++ - return tls.Dial(network, addr, cfg) - }, - } - defer tr.CloseIdleConnections() - - const reqCount = 3 - for i := 0; i < reqCount; i++ { - req, err := http.NewRequest("GET", st.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - if got, want := res.StatusCode, 200; got != want { - t.Errorf("StatusCode = %v; want %v", got, want) - } - if res != nil && res.Body != nil { - res.Body.Close() - } - } - - if connCount != 1 { - t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount) - } -} - -// tests Transport.StrictMaxConcurrentStreams -func TestTransportRequestsStallAtServerLimit(t *testing.T) { - const maxConcurrent = 2 - - greet := make(chan struct{}) // server sends initial SETTINGS frame - gotRequest := make(chan struct{}) // server received a request - clientDone := make(chan struct{}) - cancelClientRequest := make(chan struct{}) - - // Collect errors from goroutines. - var wg sync.WaitGroup - errs := make(chan error, 100) - defer func() { - wg.Wait() - close(errs) - for err := range errs { - t.Error(err) - } - }() - - // We will send maxConcurrent+2 requests. This checker goroutine waits for the - // following stages: - // 1. The first maxConcurrent requests are received by the server. - // 2. The client will cancel the next request - // 3. The server is unblocked so it can service the first maxConcurrent requests - // 4. The client will send the final request - wg.Add(1) - unblockClient := make(chan struct{}) - clientRequestCancelled := make(chan struct{}) - unblockServer := make(chan struct{}) - go func() { - defer wg.Done() - // Stage 1. - for k := 0; k < maxConcurrent; k++ { - <-gotRequest - } - // Stage 2. - close(unblockClient) - <-clientRequestCancelled - // Stage 3: give some time for the final RoundTrip call to be scheduled and - // verify that the final request is not sent. - time.Sleep(50 * time.Millisecond) - select { - case <-gotRequest: - errs <- errors.New("last request did not stall") - close(unblockServer) - return - default: - } - close(unblockServer) - // Stage 4. - <-gotRequest - }() - - ct := newClientTester(t) - ct.tr.StrictMaxConcurrentStreams = true - ct.client = func() error { - var wg sync.WaitGroup - defer func() { - wg.Wait() - close(clientDone) - ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - ct.cc.(*net.TCPConn).Close() - } - }() - for k := 0; k < maxConcurrent+2; k++ { - wg.Add(1) - go func(k int) { - defer wg.Done() - // Don't send the second request until after receiving SETTINGS from the server - // to avoid a race where we use the default SettingMaxConcurrentStreams, which - // is much larger than maxConcurrent. We have to send the first request before - // waiting because the first request triggers the dial and greet. - if k > 0 { - <-greet - } - // Block until maxConcurrent requests are sent before sending any more. - if k >= maxConcurrent { - <-unblockClient - } - body := newStaticCloseChecker("") - req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body) - if k == maxConcurrent { - // This request will be canceled. - req.Cancel = cancelClientRequest - close(cancelClientRequest) - _, err := ct.tr.RoundTrip(req) - close(clientRequestCancelled) - if err == nil { - errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k) - return - } - } else { - resp, err := ct.tr.RoundTrip(req) - if err != nil { - errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) - return - } - io.ReadAll(resp.Body) - resp.Body.Close() - if resp.StatusCode != 204 { - errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode) - return - } - } - if err := body.isClosed(); err != nil { - errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) - } - }(k) - } - return nil - } - - ct.server = func() error { - var wg sync.WaitGroup - defer wg.Wait() - - ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) - - // Server write loop. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - writeResp := make(chan uint32, maxConcurrent+1) - - wg.Add(1) - go func() { - defer wg.Done() - <-unblockServer - for id := range writeResp { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: id, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - }() - - // Server read loop. - var nreq int - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it will have reported any errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame: - case *SettingsFrame: - // Wait for the client SETTINGS ack until ending the greet. - close(greet) - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - gotRequest <- struct{}{} - nreq++ - writeResp <- f.StreamID - if nreq == maxConcurrent+1 { - close(writeResp) - } - case *DataFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } - } - - ct.run() -} - -func TestAuthorityAddr(t *testing.T) { - tests := []struct { - scheme, authority string - want string - }{ - {"http", "foo.com", "foo.com:80"}, - {"https", "foo.com", "foo.com:443"}, - {"https", "foo.com:1234", "foo.com:1234"}, - {"https", "1.2.3.4:1234", "1.2.3.4:1234"}, - {"https", "1.2.3.4", "1.2.3.4:443"}, - {"https", "[::1]:1234", "[::1]:1234"}, - {"https", "[::1]", "[::1]:443"}, - } - for _, tt := range tests { - got := authorityAddr(tt.scheme, tt.authority) - if got != tt.want { - t.Errorf("http2authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want) - } - } -} - -// Issue 20448: stop allocating for DATA frames' payload after -// Response.Body.Close is called. -func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { - megabyteZero := make([]byte, 1<<20) - - writeErr := make(chan error, 1) - - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - w.(http.Flusher).Flush() - var sum int64 - for i := 0; i < 100; i++ { - n, err := w.Write(megabyteZero) - sum += int64(n) - if err != nil { - writeErr <- err - return - } - } - t.Logf("wrote all %d bytes", sum) - writeErr <- nil - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - res, err := c.Get(st.ts.URL) - if err != nil { - t.Fatal(err) - } - var buf [1]byte - if _, err := res.Body.Read(buf[:]); err != nil { - t.Error(err) - } - if err := res.Body.Close(); err != nil { - t.Error(err) - } - - trb, ok := res.Body.(transportResponseBody) - if !ok { - t.Fatalf("res.Body = %T; want transportResponseBody", res.Body) - } - if trb.cs.bufPipe.b != nil { - t.Errorf("response body pipe is still open") - } - - gotErr := <-writeErr - if gotErr == nil { - t.Errorf("Handler unexpectedly managed to write its entire response without getting an error") - } else if gotErr != errStreamClosed { - t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr) - } -} - -// Issue 18891: make sure Request.Body == NoBody means no DATA frame -// is ever sent, even if empty. -func TestTransportNoBodyMeansNoDATA(t *testing.T) { - ct := newClientTester(t) - - unblockClient := make(chan bool) - - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) - ct.tr.RoundTrip(req) - <-unblockClient - return nil - } - ct.server = func() error { - defer close(unblockClient) - defer ct.cc.(*net.TCPConn).Close() - ct.greet() - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f := f.(type) { - default: - return fmt.Errorf("Got %T; want HeadersFrame", f) - case *WindowUpdateFrame, *SettingsFrame: - continue - case *HeadersFrame: - if !f.StreamEnded() { - return fmt.Errorf("got headers frame without END_STREAM") - } - return nil - } - } - } - ct.run() -} - -func disableGoroutineTracking() (restore func()) { - old := DebugGoroutines - DebugGoroutines = false - return func() { DebugGoroutines = old } -} - -func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { - defer disableGoroutineTracking()() - b.ReportAllocs() - st := newServerTester(b, - func(w http.ResponseWriter, r *http.Request) { - for i := 0; i < nResHeader; i++ { - name := fmt.Sprint("A-", i) - w.Header().Set(name, "*") - } - }, - optOnlyServer, - optQuiet, - ) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - req, err := http.NewRequest("GET", st.ts.URL, nil) - if err != nil { - b.Fatal(err) - } - - for i := 0; i < nReqHeaders; i++ { - name := fmt.Sprint("A-", i) - req.Header.Set(name, "*") - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - res, err := tr.RoundTrip(req) - if err != nil { - if res != nil { - res.Body.Close() - } - b.Fatalf("RoundTrip err = %v; want nil", err) - } - res.Body.Close() - if res.StatusCode != http.StatusOK { - b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK) - } - } -} - -type infiniteReader struct{} - -func (r infiniteReader) Read(b []byte) (int, error) { - return len(b), nil -} - -// Issue 20521: it is not an error to receive a response and end stream -// from the server without the body being consumed. -func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - // The request body needs to be big enough to trigger flow control. - req, _ := http.NewRequest("PUT", st.ts.URL, infiniteReader{}) - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - if res.StatusCode != http.StatusOK { - t.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK) - } -} - -// Verify transport doesn't crash when receiving bogus response lacking a :status header. -// Issue 22880. -func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - _, err := ct.tr.RoundTrip(req) - const substr = "malformed response from server: missing status pseudo header" - if !strings.Contains(fmt.Sprint(err), substr) { - return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, // we'll send some DATA to try to crash the transport - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(f.StreamID, true, []byte("payload")) - return nil - } - } - } - ct.run() -} - -func BenchmarkClientRequestHeaders(b *testing.B) { - b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) }) - b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10, 0) }) - b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100, 0) }) - b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000, 0) }) -} - -func BenchmarkClientResponseHeaders(b *testing.B) { - b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) }) - b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 10) }) - b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 100) }) - b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) }) -} - -func activeStreams(cc *ClientConn) int { - count := 0 - cc.mu.Lock() - defer cc.mu.Unlock() - for _, cs := range cc.streams { - select { - case <-cs.abort: - default: - count++ - } - } - return count -} - -type closeMode int - -const ( - closeAtHeaders closeMode = iota - closeAtBody - shutdown - shutdownCancel -) - -// See golang.org/issue/17292 -func testClientConnClose(t *testing.T, closeMode closeMode) { - clientDone := make(chan struct{}) - defer close(clientDone) - handlerDone := make(chan struct{}) - closeDone := make(chan struct{}) - beforeHeader := func() {} - bodyWrite := func(w http.ResponseWriter) {} - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - defer close(handlerDone) - beforeHeader() - w.WriteHeader(http.StatusOK) - w.(http.Flusher).Flush() - bodyWrite(w) - select { - case <-w.(http.CloseNotifier).CloseNotify(): - // client closed connection before completion - if closeMode == shutdown || closeMode == shutdownCancel { - t.Error("expected request to complete") - } - case <-clientDone: - if closeMode == closeAtHeaders || closeMode == closeAtBody { - t.Error("expected connection closed by client") - } - } - }, optOnlyServer) - defer st.Close() - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - ctx := context.Background() - cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) - req, err := http.NewRequest("GET", st.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - if closeMode == closeAtHeaders { - beforeHeader = func() { - if err := cc.Close(); err != nil { - t.Error(err) - } - close(closeDone) - } - } - var sendBody chan struct{} - if closeMode == closeAtBody { - sendBody = make(chan struct{}) - bodyWrite = func(w http.ResponseWriter) { - <-sendBody - b := make([]byte, 32) - w.Write(b) - w.(http.Flusher).Flush() - if err := cc.Close(); err != nil { - t.Errorf("unexpected ClientConn close error: %v", err) - } - close(closeDone) - w.Write(b) - w.(http.Flusher).Flush() - } - } - res, err := cc.RoundTrip(req) - if res != nil { - defer res.Body.Close() - } - if closeMode == closeAtHeaders { - got := fmt.Sprint(err) - want := "http2: client connection force closed via ClientConn.Close" - if got != want { - t.Fatalf("RoundTrip error = %v, want %v", got, want) - } - } else { - if err != nil { - t.Fatalf("RoundTrip: %v", err) - } - if got, want := activeStreams(cc), 1; got != want { - t.Errorf("got %d active streams, want %d", got, want) - } - } - switch closeMode { - case shutdownCancel: - if err = cc.Shutdown(canceledCtx); err != context.Canceled { - t.Errorf("got %v, want %v", err, context.Canceled) - } - if cc.closing == false { - t.Error("expected closing to be true") - } - if cc.CanTakeNewRequest() == true { - t.Error("CanTakeNewRequest to return false") - } - if v, want := len(cc.streams), 1; v != want { - t.Errorf("expected %d active streams, got %d", want, v) - } - clientDone <- struct{}{} - <-handlerDone - case shutdown: - wait := make(chan struct{}) - shutdownEnterWaitStateHook = func() { - close(wait) - shutdownEnterWaitStateHook = func() {} - } - defer func() { shutdownEnterWaitStateHook = func() {} }() - shutdown := make(chan struct{}, 1) - go func() { - if err = cc.Shutdown(context.Background()); err != nil { - t.Error(err) - } - close(shutdown) - }() - // Let the shutdown to enter wait state - <-wait - cc.mu.Lock() - if cc.closing == false { - t.Error("expected closing to be true") - } - cc.mu.Unlock() - if cc.CanTakeNewRequest() == true { - t.Error("CanTakeNewRequest to return false") - } - if got, want := activeStreams(cc), 1; got != want { - t.Errorf("got %d active streams, want %d", got, want) - } - // Let the active request finish - clientDone <- struct{}{} - // Wait for the shutdown to end - select { - case <-shutdown: - case <-time.After(2 * time.Second): - t.Fatal("expected server connection to close") - } - case closeAtHeaders, closeAtBody: - if closeMode == closeAtBody { - go close(sendBody) - if _, err := io.Copy(io.Discard, res.Body); err == nil { - t.Error("expected a Copy error, got nil") - } - } - <-closeDone - if got, want := activeStreams(cc), 0; got != want { - t.Errorf("got %d active streams, want %d", got, want) - } - // wait for server to get the connection close notice - select { - case <-handlerDone: - case <-time.After(2 * time.Second): - t.Fatal("expected server connection to close") - } - } -} - -// The client closes the connection just after the server got the client's HEADERS -// frame, but before the server sends its HEADERS response back. The expected -// result is an error on RoundTrip explaining the client closed the connection. -func TestClientConnCloseAtHeaders(t *testing.T) { - testClientConnClose(t, closeAtHeaders) -} - -// The client closes the connection between two server's response DATA frames. -// The expected behavior is a response body io read error on the client. -func TestClientConnCloseAtBody(t *testing.T) { - testClientConnClose(t, closeAtBody) -} - -// The client sends a GOAWAY frame before the server finished processing a request. -// We expect the connection not to close until the request is completed. -func TestClientConnShutdown(t *testing.T) { - testClientConnClose(t, shutdown) -} - -// The client sends a GOAWAY frame before the server finishes processing a request, -// but cancels the passed context before the request is completed. The expected -// behavior is the client closing the connection after the context is canceled. -func TestClientConnShutdownCancel(t *testing.T) { - testClientConnClose(t, shutdownCancel) -} - -// Issue 25009: use Request.GetBody if present, even if it seems like -// we might not need it. Apparently something else can still read from -// the original request body. Data race? In any case, rewinding -// unconditionally on retry is a nicer model anyway and should -// simplify code in the future (after the Go 1.11 freeze) -func TestTransportUsesGetBodyWhenPresent(t *testing.T) { - calls := 0 - someBody := func() io.ReadCloser { - return struct{ io.ReadCloser }{io.NopCloser(bytes.NewReader(nil))} - } - req := &http.Request{ - Body: someBody(), - GetBody: func() (io.ReadCloser, error) { - calls++ - return someBody(), nil - }, - } - - req2, err := shouldRetryRequest(req, errClientConnUnusable) - if err != nil { - t.Fatal(err) - } - if calls != 1 { - t.Errorf("Calls = %d; want 1", calls) - } - if req2 == req { - t.Error("req2 changed") - } - if req2 == nil { - t.Fatal("req2 is nil") - } - if req2.Body == nil { - t.Fatal("req2.Body is nil") - } - if req2.GetBody == nil { - t.Fatal("req2.GetBody is nil") - } - if req2.Body == req.Body { - t.Error("req2.Body unchanged") - } -} - -type errReader struct { - body []byte - err error -} - -func (r *errReader) Read(p []byte) (int, error) { - if len(r.body) > 0 { - n := copy(p, r.body) - r.body = r.body[n:] - return n, nil - } - return 0, r.err -} - -func testTransportBodyReadError(t *testing.T, body []byte) { - if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { - // So far we've only seen this be flaky on Windows and Plan 9, - // perhaps due to TCP behavior on shutdowns while - // unread data is in flight. This test should be - // fixed, but a skip is better than annoying people - // for now. - t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS) - } - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - - checkNoStreams := func() error { - cp, ok := ct.tr.connPool().(*clientConnPool) - if !ok { - return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool()) - } - cp.mu.Lock() - defer cp.mu.Unlock() - conns, ok := cp.conns["dummy.tld:443"] - if !ok { - return fmt.Errorf("missing connection") - } - if len(conns) != 1 { - return fmt.Errorf("conn pool size: %v; expect 1", len(conns)) - } - if activeStreams(conns[0]) != 0 { - return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0])) - } - return nil - } - bodyReadError := errors.New("body read error") - body := &errReader{body, bodyReadError} - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - _, err = ct.tr.RoundTrip(req) - if err != bodyReadError { - return fmt.Errorf("err = %v; want %v", err, bodyReadError) - } - if err = checkNoStreams(); err != nil { - return err - } - return nil - } - ct.server = func() error { - ct.greet() - var receivedBody []byte - var resetCount int - for { - f, err := ct.fr.ReadFrame() - t.Logf("server: ReadFrame = %v, %v", f, err) - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - if !bytes.Equal(receivedBody, body) { - return fmt.Errorf("body: %q; expected %q", receivedBody, body) - } - if resetCount != 1 { - return fmt.Errorf("stream reset count: %v; expected: 1", resetCount) - } - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - case *DataFrame: - receivedBody = append(receivedBody, f.Data()...) - case *RSTStreamFrame: - resetCount++ - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } - } - ct.run() -} - -func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) } -func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyReadError(t, []byte("123")) } - -// Issue 32254: verify that the client sends END_STREAM flag eagerly with the last -// (or in this test-case the only one) request body data frame, and does not send -// extra zero-len data frames. -func TestTransportBodyEagerEndStream(t *testing.T) { - const reqBody = "some request body" - const resBody = "some response body" - - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - body := strings.NewReader(reqBody) - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - _, err = ct.tr.RoundTrip(req) - if err != nil { - return err - } - return nil - } - ct.server = func() error { - ct.greet() - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - case *DataFrame: - if !f.StreamEnded() { - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - return fmt.Errorf("data frame without END_STREAM %v", f) - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(f.StreamID, true, []byte(resBody)) - return nil - case *RSTStreamFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } - } - ct.run() -} - -type chunkReader struct { - chunks [][]byte -} - -func (r *chunkReader) Read(p []byte) (int, error) { - if len(r.chunks) > 0 { - n := copy(p, r.chunks[0]) - r.chunks = r.chunks[1:] - return n, nil - } - panic("shouldn't read this many times") -} - -// Issue 32254: if the request body is larger than the specified -// content length, the client should refuse to send the extra part -// and abort the stream. -// -// In _len3 case, the first Read() matches the expected content length -// but the second read returns more data. -// -// In _len2 case, the first Read() exceeds the expected content length. -func TestTransportBodyLargerThanSpecifiedContentLength_len3(t *testing.T) { - body := &chunkReader{[][]byte{ - []byte("123"), - []byte("456"), - }} - testTransportBodyLargerThanSpecifiedContentLength(t, body, 3) -} - -func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) { - body := &chunkReader{[][]byte{ - []byte("123"), - }} - testTransportBodyLargerThanSpecifiedContentLength(t, body, 2) -} - -func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunkReader, contentLen int64) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - r.Body.Read(make([]byte, 6)) - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - req, _ := http.NewRequest("POST", st.ts.URL, body) - req.ContentLength = contentLen - _, err := tr.RoundTrip(req) - if err != errReqBodyTooLong { - t.Fatalf("expected %v, got %v", errReqBodyTooLong, err) - } -} - -func TestClientConnTooIdle(t *testing.T) { - tests := []struct { - cc func() *ClientConn - want bool - }{ - { - func() *ClientConn { - return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} - }, - true, - }, - { - func() *ClientConn { - return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Time{}} - }, - false, - }, - { - func() *ClientConn { - return &ClientConn{idleTimeout: 60 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} - }, - false, - }, - { - func() *ClientConn { - return &ClientConn{idleTimeout: 0, lastIdle: time.Now().Add(-10 * time.Second)} - }, - false, - }, - } - for i, tt := range tests { - got := tt.cc().tooIdleLocked() - if got != tt.want { - t.Errorf("%d. got %v; want %v", i, got, tt.want) - } - } -} - -type fakeConnErr struct { - net.Conn - writeErr error - closed bool -} - -func (fce *fakeConnErr) Write(b []byte) (n int, err error) { - return 0, fce.writeErr -} - -func (fce *fakeConnErr) Close() error { - fce.closed = true - return nil -} - -// issue 39337: close the connection on a failed write -func TestTransportNewClientConnCloseOnWriteError(t *testing.T) { - tr := &Transport{Options: &transport.Options{}} - writeErr := errors.New("write error") - fakeConn := &fakeConnErr{writeErr: writeErr} - _, err := tr.NewClientConn(fakeConn) - if err != writeErr { - t.Fatalf("expected %v, got %v", writeErr, err) - } - if !fakeConn.closed { - t.Error("expected closed conn") - } -} - -func TestTransportRoundtripCloseOnWriteError(t *testing.T) { - req, err := http.NewRequest("GET", "https://dummy.tld/", nil) - if err != nil { - t.Fatal(err) - } - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - ctx := context.Background() - cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) - if err != nil { - t.Fatal(err) - } - - writeErr := errors.New("write error") - cc.wmu.Lock() - cc.werr = writeErr - cc.wmu.Unlock() - - _, err = cc.RoundTrip(req) - if err != writeErr { - t.Fatalf("expected %v, got %v", writeErr, err) - } - - cc.mu.Lock() - closed := cc.closed - cc.mu.Unlock() - if !closed { - t.Fatal("expected closed") - } -} - -type closeChecker struct { - io.ReadCloser - closed chan struct{} -} - -func newCloseChecker(r io.ReadCloser) *closeChecker { - return &closeChecker{r, make(chan struct{})} -} - -func newStaticCloseChecker(body string) *closeChecker { - return newCloseChecker(io.NopCloser(strings.NewReader("body"))) -} - -func (rc *closeChecker) Read(b []byte) (n int, err error) { - select { - default: - case <-rc.closed: - // TODO(dneil): Consider restructuring the request write to avoid reading - // from the request body after closing it, and check for read-after-close here. - // Currently, abortRequestBodyWrite races with writeRequestBody. - return 0, errors.New("read after Body.Close") - } - return rc.ReadCloser.Read(b) -} - -func (rc *closeChecker) Close() error { - close(rc.closed) - return rc.ReadCloser.Close() -} - -func (rc *closeChecker) isClosed() error { - // The RoundTrip contract says that it will close the request body, - // but that it may do so in a separate goroutine. Wait a reasonable - // amount of time before concluding that the body isn't being closed. - timeout := time.Duration(10 * time.Second) - select { - case <-rc.closed: - case <-time.After(timeout): - return fmt.Errorf("body not closed after %v", timeout) - } - return nil -} - -// A blockingWriteConn is a net.Conn that blocks in Write after some number of bytes are written. -type blockingWriteConn struct { - net.Conn - writeOnce sync.Once - writec chan struct{} // closed after the write limit is reached - unblockc chan struct{} // closed to unblock writes - count, limit int -} - -func newBlockingWriteConn(conn net.Conn, limit int) *blockingWriteConn { - return &blockingWriteConn{ - Conn: conn, - limit: limit, - writec: make(chan struct{}), - unblockc: make(chan struct{}), - } -} - -// wait waits until the conn blocks writing the limit+1st byte. -func (c *blockingWriteConn) wait() { - <-c.writec -} - -// unblock unblocks writes to the conn. -func (c *blockingWriteConn) unblock() { - close(c.unblockc) -} - -func (c *blockingWriteConn) Write(b []byte) (n int, err error) { - if c.count+len(b) > c.limit { - c.writeOnce.Do(func() { - close(c.writec) - }) - <-c.unblockc - } - n, err = c.Conn.Write(b) - c.count += n - return n, err -} - -// Write several requests to a ClientConn at the same time, looking for race conditions. -// See golang.org/issue/48340 -func TestTransportFrameBufferReuse(t *testing.T) { - filler := hex.EncodeToString([]byte(randString(2048))) - - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - if got, want := r.Header.Get("Big"), filler; got != want { - t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, want) - } - b, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("error reading request body: %v", err) - } - if got, want := string(b), filler; got != want { - t.Errorf("request body = %q, want %q", got, want) - } - if got, want := r.Trailer.Get("Big"), filler; got != want { - t.Errorf(`r.Trailer.Get("Big") = %q, want %q`, got, want) - } - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - var wg sync.WaitGroup - defer wg.Wait() - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader(filler)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Big", filler) - req.Trailer = make(http.Header) - req.Trailer.Set("Big", filler) - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - if got, want := res.StatusCode, 200; got != want { - t.Errorf("StatusCode = %v; want %v", got, want) - } - if res != nil && res.Body != nil { - res.Body.Close() - } - }() - } - -} - -// Ensure that a request blocking while being written to the underlying net.Conn doesn't -// block access to the ClientConn pool. Test requests blocking while writing headers, the body, -// and trailers. -// See golang.org/issue/32388 -func TestTransportBlockingRequestWrite(t *testing.T) { - filler := hex.EncodeToString([]byte(randString(2048))) - for _, test := range []struct { - name string - req func(url string) (*http.Request, error) - }{{ - name: "headers", - req: func(url string) (*http.Request, error) { - req, err := http.NewRequest("POST", url, nil) - if err != nil { - return nil, err - } - req.Header.Set("Big", filler) - return req, err - }, - }, { - name: "body", - req: func(url string) (*http.Request, error) { - req, err := http.NewRequest("POST", url, strings.NewReader(filler)) - if err != nil { - return nil, err - } - return req, err - }, - }, { - name: "trailer", - req: func(url string) (*http.Request, error) { - req, err := http.NewRequest("POST", url, strings.NewReader("body")) - if err != nil { - return nil, err - } - req.Trailer = make(http.Header) - req.Trailer.Set("Big", filler) - return req, err - }, - }} { - test := test - t.Run(test.name, func(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - if v := r.Header.Get("Big"); v != "" && v != filler { - t.Errorf("request header mismatch") - } - if v, _ := io.ReadAll(r.Body); len(v) != 0 && string(v) != "body" && string(v) != filler { - t.Errorf("request body mismatch\ngot: %q\nwant: %q", string(v), filler) - } - if v := r.Trailer.Get("Big"); v != "" && v != filler { - t.Errorf("request trailer mismatch\ngot: %q\nwant: %q", string(v), filler) - } - }, optOnlyServer, func(s *Server) { - s.MaxConcurrentStreams = 1 - }) - defer st.Close() - - // This Transport creates connections that block on writes after 1024 bytes. - connc := make(chan *blockingWriteConn, 1) - connCount := 0 - tr := &Transport{ - Options: &transport.Options{ - TLSClientConfig: tlsConfigInsecure, - }, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - connCount++ - c, err := tls.Dial(network, addr, cfg) - wc := newBlockingWriteConn(c, 1024) - select { - case connc <- wc: - default: - } - return wc, err - }, - } - defer tr.CloseIdleConnections() - - // Request 1: A small request to ensure we read the server MaxConcurrentStreams. - { - req, err := http.NewRequest("POST", st.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - if got, want := res.StatusCode, 200; got != want { - t.Errorf("StatusCode = %v; want %v", got, want) - } - if res != nil && res.Body != nil { - res.Body.Close() - } - } - - // Request 2: A large request that blocks while being written. - reqc := make(chan struct{}) - go func() { - defer close(reqc) - req, err := test.req(st.ts.URL) - if err != nil { - t.Error(err) - return - } - res, _ := tr.RoundTrip(req) - if res != nil && res.Body != nil { - res.Body.Close() - } - }() - conn := <-connc - conn.wait() // wait for the request to block - - // Request 3: A small request that is sent on a new connection, since request 2 - // is hogging the only available stream on the previous connection. - { - req, err := http.NewRequest("POST", st.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - if got, want := res.StatusCode, 200; got != want { - t.Errorf("StatusCode = %v; want %v", got, want) - } - if res != nil && res.Body != nil { - res.Body.Close() - } - } - - // Request 2 should still be blocking at this point. - select { - case <-reqc: - t.Errorf("request 2 unexpectedly completed") - default: - } - - conn.unblock() - <-reqc - - if connCount != 2 { - t.Errorf("created %v connections, want 1", connCount) - } - }) - } -} - -func TestTransportCloseRequestBody(t *testing.T) { - var statusCode int - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(statusCode) - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - ctx := context.Background() - cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) - if err != nil { - t.Fatal(err) - } - - for _, status := range []int{200, 401} { - t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) { - statusCode = status - pr, pw := io.Pipe() - body := newCloseChecker(pr) - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - t.Fatal(err) - } - res, err := cc.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - pw.Close() - if err := body.isClosed(); err != nil { - t.Fatal(err) - } - }) - } -} - -// collectClientsConnPool is a ClientConnPool that wraps lower and -// collects what calls were made on it. -type collectClientsConnPool struct { - ClientConnPool - - mu sync.Mutex - getErrs int - got []*ClientConn -} - -func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) { - cc, err := p.ClientConnPool.GetClientConn(req, addr, dialOnMiss) - p.mu.Lock() - defer p.mu.Unlock() - if err != nil { - p.getErrs++ - return nil, err - } - p.got = append(p.got, cc) - return cc, nil -} - -func TestTransportRetriesOnStreamProtocolError(t *testing.T) { - ct := newClientTester(t) - pool := &collectClientsConnPool{ - ClientConnPool: &clientConnPool{t: ct.tr}, - } - ct.tr.ConnPool = pool - - gotProtoError := make(chan bool, 1) - ct.tr.CountError = func(errType string) { - if errType == "recv_rststream_PROTOCOL_ERROR" { - select { - case gotProtoError <- true: - default: - } - } - } - ct.client = func() error { - // Start two requests. The first is a long request - // that will finish after the second. The second one - // will result in the protocol error. We check that - // after the first one closes, the connection then - // shuts down. - - // The long, outer request. - req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil) - res1, err := ct.tr.RoundTrip(req1) - if err != nil { - return err - } - if got, want := res1.Header.Get("Is-Long"), "1"; got != want { - return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want) - } - - req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil) - res, err := ct.tr.RoundTrip(req) - const want = "only one dial allowed in test mode" - if got := fmt.Sprint(err); got != want { - t.Errorf("didn't dial again: got %#q; want %#q", got, want) - } - if res != nil { - res.Body.Close() - } - select { - case <-gotProtoError: - default: - t.Errorf("didn't get stream protocol error") - } - - if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 { - t.Errorf("unexpected body read %v, %v", n, err) - } - - pool.mu.Lock() - defer pool.mu.Unlock() - if pool.getErrs != 1 { - t.Errorf("pool get errors = %v; want 1", pool.getErrs) - } - if len(pool.got) == 2 { - if pool.got[0] != pool.got[1] { - t.Errorf("requests went on different connections") - } - cc := pool.got[0] - cc.mu.Lock() - if !cc.doNotReuse { - t.Error("ClientConn not marked doNotReuse") - } - cc.mu.Unlock() - - select { - case <-cc.readerDone: - case <-time.After(5 * time.Second): - t.Errorf("timeout waiting for reader to be done") - } - } else { - t.Errorf("pool get success = %v; want 2", len(pool.got)) - } - return nil - } - ct.server = func() error { - ct.greet() - var sentErr bool - var numHeaders int - var firstStreamID uint32 - - var hbuf bytes.Buffer - enc := hpack.NewEncoder(&hbuf) - - for { - f, err := ct.fr.ReadFrame() - if err == io.EOF { - // Client hung up on us, as it should at the end. - return nil - } - if err != nil { - return nil - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - numHeaders++ - if numHeaders == 1 { - firstStreamID = f.StreamID - hbuf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: hbuf.Bytes(), - }) - continue - } - if !sentErr { - sentErr = true - ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol) - ct.fr.WriteData(firstStreamID, true, nil) - continue - } - } - } - return nil - } - ct.run() -} - -func TestClientConnReservations(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - }, func(s *Server) { - s.MaxConcurrentStreams = initialMaxConcurrentStreams - }) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - cc, err := tr.newClientConn(st.cc, false) - if err != nil { - t.Fatal(err) - } - - req, _ := http.NewRequest("GET", st.ts.URL, nil) - n := 0 - for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() { - n++ - } - if n != initialMaxConcurrentStreams { - t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams) - } - if _, err := cc.RoundTrip(req); err != nil { - t.Fatalf("RoundTrip error = %v", err) - } - n2 := 0 - for n2 <= 5 && cc.ReserveNewRequest() { - n2++ - } - if n2 != 1 { - t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2) - } - - // Use up all the reservations - for i := 0; i < n; i++ { - cc.RoundTrip(req) - } - - n2 = 0 - for n2 <= initialMaxConcurrentStreams && cc.ReserveNewRequest() { - n2++ - } - if n2 != n { - t.Errorf("after reset, reservations = %v; want %v", n2, n) - } -} - -func TestTransportTimeoutServerHangs(t *testing.T) { - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - defer close(clientDone) - - req, err := http.NewRequest("PUT", "https://dummy.tld/", nil) - if err != nil { - return err - } - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - req = req.WithContext(ctx) - req.Header.Add("Big", strings.Repeat("a", 1<<20)) - _, err = ct.tr.RoundTrip(req) - if err == nil { - return errors.New("error should not be nil") - } - if ne, ok := err.(net.Error); !ok || !ne.Timeout() { - return fmt.Errorf("error should be a net error timeout: %v", err) - } - return nil - } - ct.server = func() error { - ct.greet() - select { - case <-time.After(5 * time.Second): - case <-clientDone: - } - return nil - } - ct.run() -} - -func TestTransportContentLengthWithoutBody(t *testing.T) { - contentLength := "" - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", contentLength) - }, optOnlyServer) - defer st.Close() - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - for _, test := range []struct { - name string - contentLength string - wantBody string - wantErr error - wantContentLength int64 - }{ - { - name: "non-zero content length", - contentLength: "42", - wantErr: io.ErrUnexpectedEOF, - wantContentLength: 42, - }, - { - name: "zero content length", - contentLength: "0", - wantErr: nil, - wantContentLength: 0, - }, - } { - t.Run(test.name, func(t *testing.T) { - contentLength = test.contentLength - - req, _ := http.NewRequest("GET", st.ts.URL, nil) - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - body, err := io.ReadAll(res.Body) - - if err != test.wantErr { - t.Errorf("Expected error %v, got: %v", test.wantErr, err) - } - if len(body) > 0 { - t.Errorf("Expected empty body, got: %v", body) - } - if res.ContentLength != test.wantContentLength { - t.Errorf("Expected content length %d, got: %d", test.wantContentLength, res.ContentLength) - } - }) - } -} - -func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - w.(http.Flusher).Flush() - io.Copy(io.Discard, r.Body) - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - pr, pw := net.Pipe() - req, err := http.NewRequest("GET", st.ts.URL, pr) - if err != nil { - t.Fatal(err) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - // Closing the Response's Body interrupts the blocked body read. - res.Body.Close() - pw.Close() -} - -func TestTransport300ResponseBody(t *testing.T) { - reqc := make(chan struct{}) - body := []byte("response body") - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(300) - w.(http.Flusher).Flush() - <-reqc - w.Write(body) - }, optOnlyServer) - defer st.Close() - - tr := &Transport{Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}} - defer tr.CloseIdleConnections() - - pr, pw := net.Pipe() - req, err := http.NewRequest("GET", st.ts.URL, pr) - if err != nil { - t.Fatal(err) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - close(reqc) - got, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("error reading response body: %v", err) - } - if !bytes.Equal(got, body) { - t.Errorf("got response body %q, want %q", string(got), string(body)) - } - res.Body.Close() - pw.Close() -} - -func TestTransportWriteByteTimeout(t *testing.T) { - st := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) {}, - optOnlyServer, - ) - defer st.Close() - tr := &Transport{ - Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - _, c := net.Pipe() - return c, nil - }, - WriteByteTimeout: 1 * time.Millisecond, - } - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - - _, err := c.Get(st.ts.URL) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err) - } -} - -type slowWriteConn struct { - net.Conn - hasWriteDeadline bool -} - -func (c *slowWriteConn) SetWriteDeadline(t time.Time) error { - c.hasWriteDeadline = !t.IsZero() - return nil -} - -func (c *slowWriteConn) Write(b []byte) (n int, err error) { - if c.hasWriteDeadline && len(b) > 1 { - n, err = c.Conn.Write(b[:1]) - if err != nil { - return n, err - } - return n, fmt.Errorf("slow write: %w", os.ErrDeadlineExceeded) - } - return c.Conn.Write(b) -} - -func TestTransportSlowWrites(t *testing.T) { - st := newServerTester(t, - func(w http.ResponseWriter, r *http.Request) {}, - optOnlyServer, - ) - defer st.Close() - tr := &Transport{ - Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - cfg.InsecureSkipVerify = true - c, err := tls.Dial(network, addr, cfg) - return &slowWriteConn{Conn: c}, err - }, - WriteByteTimeout: 1 * time.Millisecond, - } - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - - const bodySize = 1 << 20 - resp, err := c.Post(st.ts.URL, "text/foo", io.LimitReader(tests.NeverEnding('A'), bodySize)) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() -} - -func TestCountReadFrameError(t *testing.T) { - cc := &ClientConn{} - errMsg := "" - countError := func(errType string) { - errMsg = errType - } - cc.t = &Transport{CountError: countError} - - var err error - cc.countReadFrameError(err) - tests.AssertEqual(t, "", errMsg) - - err = ConnectionError(ErrCodeInternal) - cc.countReadFrameError(err) - tests.AssertContains(t, errMsg, "read_frame_conn_error", true) - - err = io.EOF - cc.countReadFrameError(err) - tests.AssertContains(t, errMsg, "read_frame_eof", true) - - err = io.ErrUnexpectedEOF - cc.countReadFrameError(err) - tests.AssertContains(t, errMsg, "read_frame_unexpected_eof", true) - - err = errFrameTooLarge - cc.countReadFrameError(err) - tests.AssertContains(t, errMsg, "read_frame_too_large", true) - - err = errors.New("other") - cc.countReadFrameError(err) - tests.AssertContains(t, errMsg, "read_frame_other", true) -} - -func TestProcessHeaders(t *testing.T) { - rl := &clientConnReadLoop{} - cc := &ClientConn{streams: map[uint32]*clientStream{}} - cc.streams[1] = &clientStream{cc: cc, abort: make(chan struct{})} - rl.cc = cc - f := &MetaHeadersFrame{HeadersFrame: &HeadersFrame{ - FrameHeader: FrameHeader{StreamID: 1}, - }} - err := rl.processHeaders(f) - tests.AssertNoError(t, err) - - f.StreamID = 0 - err = rl.processHeaders(f) - tests.AssertNoError(t, err) -} - -func TestTransportClosesConnAfterGoAwayNoStreams(t *testing.T) { - testTransportClosesConnAfterGoAway(t, 0) -} -func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { - testTransportClosesConnAfterGoAway(t, 1) -} - -type closeOnceConn struct { - net.Conn - closed uint32 -} - -var errClosed = errors.New("Close of closed connection") - -func (c *closeOnceConn) Close() error { - if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - return c.Conn.Close() - } - return errClosed -} - -// testTransportClosesConnAfterGoAway verifies that the transport -// closes a connection after reading a GOAWAY from it. -// -// lastStream is the last stream ID in the GOAWAY frame. -// When 0, the transport (unsuccessfully) retries the request (stream 1); -// when 1, the transport reads the response after receiving the GOAWAY. -func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { - ct := newClientTester(t) - ct.cc = &closeOnceConn{Conn: ct.cc} - - var wg sync.WaitGroup - wg.Add(1) - ct.client = func() error { - defer wg.Done() - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err == nil { - res.Body.Close() - } - if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { - t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) - } - if err = ct.cc.Close(); err != errClosed { - return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err) - } - return nil - } - - ct.server = func() error { - defer wg.Wait() - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - return fmt.Errorf("server failed reading HEADERS: %v", err) - } - if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil { - return fmt.Errorf("server failed writing GOAWAY: %v", err) - } - if lastStream > 0 { - // Send a valid response to first request. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - return nil - } - - ct.run() -} - -type slowCloser struct { - closing chan struct{} - closed chan struct{} -} - -func (r *slowCloser) Read([]byte) (int, error) { - return 0, io.EOF -} - -func (r *slowCloser) Close() error { - close(r.closing) - <-r.closed - return nil -} - -func TestTransportSlowClose(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - }, optOnlyServer) - defer st.Close() - - client := st.ts.Client() - body := &slowCloser{ - closing: make(chan struct{}), - closed: make(chan struct{}), - } - - reqc := make(chan struct{}) - go func() { - defer close(reqc) - res, err := client.Post(st.ts.URL, "text/plain", body) - if err != nil { - t.Error(err) - } - res.Body.Close() - }() - defer func() { - close(body.closed) - <-reqc // wait for POST request to finish - }() - - <-body.closing // wait for POST request to call body.Close - // This GET request should not be blocked by the in-progress POST. - res, err := client.Get(st.ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() -} - -type blockReadConn struct { - net.Conn - blockc chan struct{} -} - -func (c *blockReadConn) Read(b []byte) (n int, err error) { - <-c.blockc - return c.Conn.Read(b) -} - -func TestTransportReuseAfterError(t *testing.T) { - serverReqc := make(chan struct{}, 3) - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - serverReqc <- struct{}{} - }, optOnlyServer) - defer st.Close() - - var ( - unblockOnce sync.Once - blockc = make(chan struct{}) - connCountMu sync.Mutex - connCount int - ) - tr := &Transport{ - Options: &transport.Options{TLSClientConfig: tlsConfigInsecure}, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - // The first connection dialed will block on reads until blockc is closed. - connCountMu.Lock() - defer connCountMu.Unlock() - connCount++ - conn, err := tls.Dial(network, addr, cfg) - if err != nil { - return nil, err - } - if connCount == 1 { - return &blockReadConn{ - Conn: conn, - blockc: blockc, - }, nil - } - return conn, nil - }, - } - defer tr.CloseIdleConnections() - defer unblockOnce.Do(func() { - // Ensure that reads on blockc are unblocked if we return early. - close(blockc) - }) - - req, _ := http.NewRequest("GET", st.ts.URL, nil) - - // Request 1 is made on conn 1. - // Reading the response will block. - // Wait until the server receives the request, and continue. - req1c := make(chan struct{}) - go func() { - defer close(req1c) - res1, err := tr.RoundTrip(req.Clone(context.Background())) - if err != nil { - t.Errorf("request 1: %v", err) - } else { - res1.Body.Close() - } - }() - <-serverReqc - - // Request 2 is also made on conn 1. - // Reading the response will block. - // The request is canceled once the server receives it. - // Conn 1 should now be flagged as unfit for reuse. - req2Ctx, cancel := context.WithCancel(context.Background()) - go func() { - <-serverReqc - cancel() - }() - _, err := tr.RoundTrip(req.Clone(req2Ctx)) - if err == nil { - t.Errorf("request 2 unexpectedly succeeded (want cancel)") - } - - // Request 3 is made on a new conn, and succeeds. - res3, err := tr.RoundTrip(req.Clone(context.Background())) - if err != nil { - t.Fatalf("request 3: %v", err) - } - res3.Body.Close() - - // Unblock conn 1, and verify that request 1 completes. - unblockOnce.Do(func() { - close(blockc) - }) - <-req1c -} From 1c7cbd881bb174fc72f66f5c2afed370fe6ff30f Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Aug 2023 15:42:16 +0800 Subject: [PATCH 735/843] export http2 setting --- internal/http2/frame.go | 10 +++++----- internal/http2/http2.go | 36 ------------------------------------ internal/http2/transport.go | 6 +++--- pkg/http2/setting.go | 14 ++++++++++++++ 4 files changed, 22 insertions(+), 44 deletions(-) diff --git a/internal/http2/frame.go b/internal/http2/frame.go index 5f8afc21..3b0922c4 100644 --- a/internal/http2/frame.go +++ b/internal/http2/frame.go @@ -821,9 +821,9 @@ func (f *SettingsFrame) Value(id http2.SettingID) (v uint32, ok bool) { // Setting returns the setting from the frame at the given 0-based index. // The index must be >= 0 and less than f.NumSettings(). -func (f *SettingsFrame) Setting(i int) Setting { +func (f *SettingsFrame) Setting(i int) http2.Setting { buf := f.p - return Setting{ + return http2.Setting{ ID: http2.SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])), Val: binary.BigEndian.Uint32(buf[i*6+2 : i*6+6]), } @@ -864,7 +864,7 @@ func (f *SettingsFrame) HasDuplicates() bool { // ForeachSetting runs fn for each setting. // It stops and returns the first error. -func (f *SettingsFrame) ForeachSetting(fn func(Setting) error) error { +func (f *SettingsFrame) ForeachSetting(fn func(http2.Setting) error) error { f.checkValid() for i := 0; i < f.NumSettings(); i++ { if err := fn(f.Setting(i)); err != nil { @@ -879,7 +879,7 @@ func (f *SettingsFrame) ForeachSetting(fn func(Setting) error) error { // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (h2f *Framer) WriteSettings(settings ...Setting) error { +func (h2f *Framer) WriteSettings(settings ...http2.Setting) error { h2f.startWrite(FrameSettings, 0, 0) for _, s := range settings { h2f.writeUint16(uint16(s.ID)) @@ -1685,7 +1685,7 @@ func summarizeFrame(f Frame) string { switch f := f.(type) { case *SettingsFrame: n := 0 - f.ForeachSetting(func(s Setting) error { + f.ForeachSetting(func(s http2.Setting) error { n++ if n == 1 { buf.WriteString(", settings:") diff --git a/internal/http2/http2.go b/internal/http2/http2.go index 15932b5a..3e3e9350 100644 --- a/internal/http2/http2.go +++ b/internal/http2/http2.go @@ -7,8 +7,6 @@ package http2 import ( "bufio" "crypto/tls" - "fmt" - "github.com/imroc/req/v3/pkg/http2" "golang.org/x/net/http/httpguts" "net/http" "os" @@ -56,40 +54,6 @@ var ( clientPreface = []byte(ClientPreface) ) -// Setting is a setting parameter: which setting it is, and its value. -type Setting struct { - // ID is which setting is being set. - // See https://httpwg.org/specs/rfc7540.html#SettingValues - ID http2.SettingID - - // Val is the value. - Val uint32 -} - -func (s Setting) String() string { - return fmt.Sprintf("[%v = %d]", s.ID, s.Val) -} - -// Valid reports whether the setting is valid. -func (s Setting) Valid() error { - // Limits and error codes from 6.5.2 Defined SETTINGS Parameters - switch s.ID { - case http2.SettingEnablePush: - if s.Val != 1 && s.Val != 0 { - return ConnectionError(ErrCodeProtocol) - } - case http2.SettingInitialWindowSize: - if s.Val > 1<<31-1 { - return ConnectionError(ErrCodeFlowControl) - } - case http2.SettingMaxFrameSize: - if s.Val < 16384 || s.Val > 1<<24-1 { - return ConnectionError(ErrCodeProtocol) - } - } - return nil -} - // validWireHeaderFieldName reports whether v is a valid header field // name (key). See httpguts.ValidHeaderName for the base rules. // diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 9186085b..46f1477a 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -678,12 +678,12 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro cc.tlsState = &state } - initialSettings := []Setting{ + initialSettings := []http2.Setting{ {ID: http2.SettingEnablePush, Val: 0}, {ID: http2.SettingInitialWindowSize, Val: transportDefaultStreamFlow}, } if max := t.maxHeaderListSize(); max != 0 { - initialSettings = append(initialSettings, Setting{ID: http2.SettingMaxHeaderListSize, Val: max}) + initialSettings = append(initialSettings, http2.Setting{ID: http2.SettingMaxHeaderListSize, Val: max}) } cc.bw.Write(clientPreface) @@ -2808,7 +2808,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { } var seenMaxConcurrentStreams bool - err := f.ForeachSetting(func(s Setting) error { + err := f.ForeachSetting(func(s http2.Setting) error { switch s.ID { case http2.SettingMaxFrameSize: cc.maxFrameSize = s.Val diff --git a/pkg/http2/setting.go b/pkg/http2/setting.go index 6d0e0aa1..f351238d 100644 --- a/pkg/http2/setting.go +++ b/pkg/http2/setting.go @@ -32,3 +32,17 @@ func (s SettingID) String() string { } return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) } + +// Setting is a setting parameter: which setting it is, and its value. +type Setting struct { + // ID is which setting is being set. + // See https://httpwg.org/specs/rfc7540.html#SettingValues + ID SettingID + + // Val is the value. + Val uint32 +} + +func (s Setting) String() string { + return fmt.Sprintf("[%v = %d]", s.ID, s.Val) +} From aadac704abb83ca9e491a1ec1c36f878d1323971 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Aug 2023 19:25:55 +0800 Subject: [PATCH 736/843] Support more http fingerprint customization: * Add SetHTTP2SettingsFrame to Transport and Client. * Add SetProfile to Client. * Add ClientProfile_Chrome. --- client.go | 13 ++++++++++++ internal/http2/transport.go | 19 +++++++++++------ profile.go | 41 +++++++++++++++++++++++++++++++++++++ transport.go | 29 ++++++++++++++++---------- 4 files changed, 85 insertions(+), 17 deletions(-) create mode 100644 profile.go diff --git a/client.go b/client.go index 9cd56774..d6d614f8 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "encoding/json" "encoding/xml" "errors" + "github.com/imroc/req/v3/pkg/http2" "io" "net" "net/http" @@ -902,6 +903,18 @@ func (c *Client) SetCommonPseudoHeaderOder(keys ...string) *Client { return c } +// SetHTTP2SettingsFrame set the ordered http2 settings frame. +func (c *Client) SetHTTP2SettingsFrame(settings ...http2.Setting) *Client { + c.t.SetHTTP2SettingsFrame(settings...) + return c +} + +// SetProfile set the http client profile. +func (c *Client) SetProfile(p ClientProfile) *Client { + p(c) + return c +} + // SetCommonContentType set the `Content-Type` header for requests fired // from the client. func (c *Client) SetCommonContentType(ct string) *Client { diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 46f1477a..889c1a4c 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -132,6 +132,8 @@ type Transport struct { // The errType consists of only ASCII word characters. CountError func(errType string) + Settings []http2.Setting + connPoolOnce sync.Once connPoolOrDef ClientConnPool // non-nil version of ConnPool } @@ -678,12 +680,17 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro cc.tlsState = &state } - initialSettings := []http2.Setting{ - {ID: http2.SettingEnablePush, Val: 0}, - {ID: http2.SettingInitialWindowSize, Val: transportDefaultStreamFlow}, - } - if max := t.maxHeaderListSize(); max != 0 { - initialSettings = append(initialSettings, http2.Setting{ID: http2.SettingMaxHeaderListSize, Val: max}) + var initialSettings []http2.Setting + if len(t.Settings) > 0 { + initialSettings = t.Settings + } else { + initialSettings = []http2.Setting{ + {ID: http2.SettingEnablePush, Val: 0}, + {ID: http2.SettingInitialWindowSize, Val: transportDefaultStreamFlow}, + } + if max := t.maxHeaderListSize(); max != 0 { + initialSettings = append(initialSettings, http2.Setting{ID: http2.SettingMaxHeaderListSize, Val: max}) + } } cc.bw.Write(clientPreface) diff --git a/profile.go b/profile.go new file mode 100644 index 00000000..014f6aa9 --- /dev/null +++ b/profile.go @@ -0,0 +1,41 @@ +package req + +import "github.com/imroc/req/v3/pkg/http2" + +type ClientProfile func(c *Client) + +var http2SettingsChrome = []http2.Setting{ + { + ID: http2.SettingHeaderTableSize, + Val: 65536, + }, + { + ID: http2.SettingEnablePush, + Val: 0, + }, + { + ID: http2.SettingMaxConcurrentStreams, + Val: 1000, + }, + { + ID: http2.SettingInitialWindowSize, + Val: 6291456, + }, + { + ID: http2.SettingMaxHeaderListSize, + Val: 262144, + }, +} + +var chromePseudoHeaderOrder = []string{ + ":method", + ":authority", + ":scheme", + ":path", +} + +var ClientProfile_Chrome ClientProfile = func(c *Client) { + c.SetTLSFingerprintChrome(). + SetCommonPseudoHeaderOder(chromePseudoHeaderOrder...). + SetHTTP2SettingsFrame(http2SettingsChrome...) +} diff --git a/transport.go b/transport.go index f1ee7698..1e49c236 100644 --- a/transport.go +++ b/transport.go @@ -22,13 +22,14 @@ import ( "github.com/imroc/req/v3/internal/common" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/http2" + h2internal "github.com/imroc/req/v3/internal/http2" "github.com/imroc/req/v3/internal/http3" "github.com/imroc/req/v3/internal/netutil" "github.com/imroc/req/v3/internal/socks" "github.com/imroc/req/v3/internal/transport" "github.com/imroc/req/v3/internal/util" "github.com/imroc/req/v3/pkg/altsvc" + "github.com/imroc/req/v3/pkg/http2" reqtls "github.com/imroc/req/v3/pkg/tls" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" @@ -122,7 +123,7 @@ type Transport struct { transport.Options - t2 *http2.Transport // non-nil if http2 wired up + t2 *h2internal.Transport // non-nil if http2 wired up t3 *http3.RoundTripper // disableAutoDecode, if true, prevents auto detect response @@ -154,7 +155,7 @@ func T() *Transport { TLSClientConfig: &tls.Config{NextProtos: []string{"http/1.1", "h2"}}, }, } - t.t2 = &http2.Transport{Options: &t.Options} + t.t2 = &h2internal.Transport{Options: &t.Options} return t } @@ -450,6 +451,12 @@ func (t *Transport) SetHTTP2WriteByteTimeout(timeout time.Duration) *Transport { return t } +// SetHTTP2SettingsFrame set the ordered http2 settings frame. +func (t *Transport) SetHTTP2SettingsFrame(settings ...http2.Setting) *Transport { + t.t2.Settings = settings + return t +} + // SetTLSClientConfig set the custom TLSClientConfig, which specifies the TLS configuration to // use with tls.Client. // If nil, the default configuration is used. @@ -692,7 +699,7 @@ func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFu switch b := res.Body.(type) { case *gzipReader: b.body.body = wrap(b.body.body) - case *http2.GzipReader: + case *h2internal.GzipReader: b.Body = wrap(b.Body) case *http3.GzipReader: b.Body = wrap(b.Body) @@ -785,7 +792,7 @@ func (t *Transport) Clone() *Transport { } } if t.t2 != nil { - tt.t2 = &http2.Transport{Options: &tt.Options} + tt.t2 = &h2internal.Transport{Options: &tt.Options} } if t.t3 != nil { tt.EnableHTTP3() @@ -934,7 +941,7 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error if scheme == "https" && t.forceHttpVersion != h1 { resp, err := t.t2.RoundTripOnlyCachedConn(req) - if err != http2.ErrNoCachedConn { + if err != h2internal.ErrNoCachedConn { return resp, err } req, err = rewindBody(req) @@ -1007,7 +1014,7 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error } // Failed. Clean up and determine whether to retry. - if http2.IsNoCachedConnError(err) { + if h2internal.IsNoCachedConnError(err) { if t.removeIdleConn(pconn) { t.decConnsPerHost(pconn.cacheKey) } @@ -1096,7 +1103,7 @@ func rewindBody(req *http.Request) (rewound *http.Request, err error) { // HTTP request on a new connection. The non-nil input error is the // error from roundTrip. func (pc *persistConn) shouldRetryRequest(req *http.Request, err error) bool { - if http2.IsNoCachedConnError(err) { + if h2internal.IsNoCachedConnError(err) { // Issue 16582: if the user started a bunch of // requests at once, they can all pick the same conn // and violate the server's max concurrent streams. @@ -1911,7 +1918,7 @@ func (pc *persistConn) addTLS(ctx context.Context, name string, trace *httptrace } pc.tlsState = &cs pc.conn = tlsConn - if !forProxy && pc.t.forceHttpVersion == h2 && cs.NegotiatedProtocol != http2.NextProtoTLS { + if !forProxy && pc.t.forceHttpVersion == h2 && cs.NegotiatedProtocol != h2internal.NextProtoTLS { return newHttp2NotSupportedError(cs.NegotiatedProtocol) } return nil @@ -2003,7 +2010,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers trace.TLSHandshakeDone(cs, nil) } pconn.tlsState = &cs - if cm.proxyURL == nil && pconn.t.forceHttpVersion == h2 && cs.NegotiatedProtocol != http2.NextProtoTLS { + if cm.proxyURL == nil && pconn.t.forceHttpVersion == h2 && cs.NegotiatedProtocol != h2internal.NextProtoTLS { return nil, newHttp2NotSupportedError(cs.NegotiatedProtocol) } } @@ -2164,7 +2171,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } if s := pconn.tlsState; t.forceHttpVersion != h1 && s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { - if s.NegotiatedProtocol == http2.NextProtoTLS { + if s.NegotiatedProtocol == h2internal.NextProtoTLS { if used, err := t.t2.AddConn(pconn.conn, cm.targetAddr); err != nil { go pconn.conn.Close() return nil, err From fa8a09fe70a97baf3ffca9492551e5e0a0b6c9ab Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Aug 2023 19:48:21 +0800 Subject: [PATCH 737/843] optimize Transport.Clone() --- transport.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/transport.go b/transport.go index 1e49c236..71d97f1b 100644 --- a/transport.go +++ b/transport.go @@ -792,7 +792,15 @@ func (t *Transport) Clone() *Transport { } } if t.t2 != nil { - tt.t2 = &h2internal.Transport{Options: &tt.Options} + tt.t2 = &h2internal.Transport{ + Options: &tt.Options, + Settings: t.t2.Settings, + MaxHeaderListSize: t.t2.MaxHeaderListSize, + StrictMaxConcurrentStreams: t.t2.StrictMaxConcurrentStreams, + ReadIdleTimeout: t.t2.ReadIdleTimeout, + PingTimeout: t.t2.PingTimeout, + WriteByteTimeout: t.t2.WriteByteTimeout, + } } if t.t3 != nil { tt.EnableHTTP3() From b7d9543ca44824786fa1f08a3724d4bba47d926d Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 1 Aug 2023 20:16:52 +0800 Subject: [PATCH 738/843] Add SetHeaders and SetHeadersNonCanonical to Transport --- transport.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/transport.go b/transport.go index 71d97f1b..c215cb96 100644 --- a/transport.go +++ b/transport.go @@ -211,6 +211,54 @@ func (t *Transport) WrapRoundTrip(wrappers ...HttpRoundTripWrapper) *Transport { return t } +// SetHeaders set headers for all requests, if the same header exists at +// the request-level, the request-level header takes precedence. +func (t *Transport) SetHeaders(hdrs http.Header) *Transport { + if len(hdrs) == 0 { + return t + } + t.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { + return func(req *http.Request) (resp *http.Response, err error) { + if req.Header == nil { + req.Header = make(http.Header) + } + for k, v := range hdrs { + if len(v) == 0 { + continue + } + kk := textproto.CanonicalMIMEHeaderKey(k) + if len(req.Header[kk]) == 0 { + req.Header[kk] = v + } + } + return rt.RoundTrip(req) + } + }) + return t +} + +// SetHeadersNonCanonical set non-canonical headers for all requests, if the +// same header exists at the request-level, the request-level header takes precedence. +func (t *Transport) SetHeadersNonCanonical(hdrs http.Header) *Transport { + if len(hdrs) == 0 { + return t + } + t.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { + return func(req *http.Request) (resp *http.Response, err error) { + if req.Header == nil { + req.Header = make(http.Header) + } + for k, v := range hdrs { + if len(v) > 0 && len(req.Header[k]) == 0 { + req.Header[k] = v + } + } + return rt.RoundTrip(req) + } + }) + return t +} + // SetHeaderOrder set the order of the http header (case-insensitive). // For example: // From cdf3a341e47ac46b581be58570a571b554808f69 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Aug 2023 10:28:13 +0800 Subject: [PATCH 739/843] Make Transport as a embed struct in Client --- client.go | 94 +++++++++++++++++++++++------------------------ client_test.go | 36 +++++++++--------- transport_test.go | 6 +-- 3 files changed, 68 insertions(+), 68 deletions(-) diff --git a/client.go b/client.go index d6d614f8..e1a41352 100644 --- a/client.go +++ b/client.go @@ -50,6 +50,7 @@ type Client struct { FormData urlpkg.Values DebugLog bool AllowGetMethodPayload bool + *Transport trace bool disableAutoReadResponse bool @@ -62,7 +63,6 @@ type Client struct { outputDirectory string scheme string log Logger - t *Transport dumpOptions *DumpOptions httpClient *http.Client beforeRequest []RequestMiddleware @@ -154,7 +154,7 @@ func (c *Client) Options(url ...string) *Request { // GetTransport return the underlying transport. func (c *Client) GetTransport() *Transport { - return c.t + return c.Transport } // SetResponseBodyTransformer set the response body transformer, which can modify the @@ -300,12 +300,12 @@ func (c *Client) SetRootCertsFromFile(pemFiles ...string) *Client { // GetTLSClientConfig return the underlying tls.Config. func (c *Client) GetTLSClientConfig() *tls.Config { - if c.t.TLSClientConfig == nil { - c.t.TLSClientConfig = &tls.Config{ + if c.TLSClientConfig == nil { + c.TLSClientConfig = &tls.Config{ NextProtos: []string{"h2", "http/1.1"}, } } - return c.t.TLSClientConfig + return c.TLSClientConfig } func (c *Client) defaultCheckRedirect(req *http.Request, via []*http.Request) error { @@ -350,13 +350,13 @@ func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { // // This is unrelated to the similarly named TCP keep-alives. func (c *Client) DisableKeepAlives() *Client { - c.t.DisableKeepAlives = true + c.Transport.DisableKeepAlives = true return c } // EnableKeepAlives enables HTTP keep-alives (enabled by default). func (c *Client) EnableKeepAlives() *Client { - c.t.DisableKeepAlives = false + c.Transport.DisableKeepAlives = false return c } @@ -369,13 +369,13 @@ func (c *Client) EnableKeepAlives() *Client { // However, if the user explicitly requested gzip it is not // automatically uncompressed. func (c *Client) DisableCompression() *Client { - c.t.DisableCompression = true + c.Transport.DisableCompression = true return c } // EnableCompression enables the compression (enabled by default). func (c *Client) EnableCompression() *Client { - c.t.DisableCompression = false + c.Transport.DisableCompression = false return c } @@ -386,7 +386,7 @@ func (c *Client) EnableCompression() *Client { // overwriting some important configurations, such as not setting NextProtos // will not use http2 by default. func (c *Client) SetTLSClientConfig(conf *tls.Config) *Client { - c.t.TLSClientConfig = conf + c.TLSClientConfig = conf return c } @@ -558,10 +558,10 @@ func (c *Client) getDumpOptions() *DumpOptions { // EnableDumpAll enable dump for requests fired from the client, including // all content for the request and response by default. func (c *Client) EnableDumpAll() *Client { - if c.t.Dump != nil { // dump already started + if c.Dump != nil { // dump already started return c } - c.t.EnableDump(c.getDumpOptions()) + c.EnableDump(c.getDumpOptions()) return c } @@ -771,31 +771,31 @@ func (c *Client) EnableAutoReadResponse() *Client { // SetAutoDecodeContentType set the content types that will be auto-detected and decode to utf-8 // (e.g. "json", "xml", "html", "text"). func (c *Client) SetAutoDecodeContentType(contentTypes ...string) *Client { - c.t.SetAutoDecodeContentType(contentTypes...) + c.Transport.SetAutoDecodeContentType(contentTypes...) return c } // SetAutoDecodeContentTypeFunc set the function that determines whether the specified `Content-Type` should be auto-detected and decode to utf-8. func (c *Client) SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client { - c.t.SetAutoDecodeContentTypeFunc(fn) + c.Transport.SetAutoDecodeContentTypeFunc(fn) return c } // SetAutoDecodeAllContentType enable try auto-detect charset and decode all content type to utf-8. func (c *Client) SetAutoDecodeAllContentType() *Client { - c.t.SetAutoDecodeAllContentType() + c.Transport.SetAutoDecodeAllContentType() return c } // DisableAutoDecode disable auto-detect charset and decode to utf-8 (enabled by default). func (c *Client) DisableAutoDecode() *Client { - c.t.DisableAutoDecode() + c.Transport.DisableAutoDecode() return c } // EnableAutoDecode enable auto-detect charset and decode to utf-8 (enabled by default). func (c *Client) EnableAutoDecode() *Client { - c.t.EnableAutoDecode() + c.Transport.EnableAutoDecode() return c } @@ -883,7 +883,7 @@ func (c *Client) SetCommonHeadersNonCanonical(hdrs map[string]string) *Client { // "accept-encoding", // ).Get(url func (c *Client) SetCommonHeaderOrder(keys ...string) *Client { - c.t.SetHeaderOrder(keys...) + c.SetHeaderOrder(keys...) return c } @@ -899,13 +899,13 @@ func (c *Client) SetCommonHeaderOrder(keys ...string) *Client { // ":method", // ) func (c *Client) SetCommonPseudoHeaderOder(keys ...string) *Client { - c.t.SetPseudoHeaderOder(keys...) + c.SetPseudoHeaderOder(keys...) return c } // SetHTTP2SettingsFrame set the ordered http2 settings frame. func (c *Client) SetHTTP2SettingsFrame(settings ...http2.Setting) *Client { - c.t.SetHTTP2SettingsFrame(settings...) + c.SetHTTP2SettingsFrame(settings...) return c } @@ -924,7 +924,7 @@ func (c *Client) SetCommonContentType(ct string) *Client { // DisableDumpAll disable dump for requests fired from the client. func (c *Client) DisableDumpAll() *Client { - c.t.DisableDump() + c.DisableDump() return c } @@ -942,15 +942,15 @@ func (c *Client) SetCommonDumpOptions(opt *DumpOptions) *Client { } } c.dumpOptions = opt - if c.t.Dump != nil { - c.t.Dump.SetOptions(dumpOptions{opt}) + if c.Dump != nil { + c.Dump.SetOptions(dumpOptions{opt}) } return c } // SetProxy set the proxy function. func (c *Client) SetProxy(proxy func(*http.Request) (*urlpkg.URL, error)) *Client { - c.t.SetProxy(proxy) + c.Transport.SetProxy(proxy) return c } @@ -978,7 +978,7 @@ func (c *Client) SetProxyURL(proxyUrl string) *Client { return c } proxy := http.ProxyURL(u) - c.t.SetProxy(proxy) + c.SetProxy(proxy) return c } @@ -1055,13 +1055,13 @@ func (c *Client) SetXmlUnmarshal(fn func(data []byte, v interface{}) error) *Cli // Make sure the returned `conn` implements pkg/tls.Conn if you want your // customized `conn` supports HTTP2. func (c *Client) SetDialTLS(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { - c.t.SetDialTLS(fn) + c.Transport.SetDialTLS(fn) return c } // SetDial set the customized `DialContext` function to Transport. func (c *Client) SetDial(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { - c.t.SetDial(fn) + c.Transport.SetDial(fn) return c } @@ -1168,7 +1168,7 @@ func (c *Client) SetTLSFingerprint(clientHelloID utls.ClientHelloID) *Client { } return } - c.SetTLSHandshake(fn) + c.Transport.SetTLSHandshake(fn) return c } @@ -1176,52 +1176,52 @@ func (c *Client) SetTLSFingerprint(clientHelloID utls.ClientHelloID) *Client { // it specifies an optional dial function for tls handshake, it works even if a proxy is set, can be // used to customize the tls fingerprint. func (c *Client) SetTLSHandshake(fn func(ctx context.Context, addr string, plainConn net.Conn) (conn net.Conn, tlsState *tls.ConnectionState, err error)) *Client { - c.t.SetTLSHandshake(fn) + c.Transport.SetTLSHandshake(fn) return c } // SetTLSHandshakeTimeout set the TLS handshake timeout. func (c *Client) SetTLSHandshakeTimeout(timeout time.Duration) *Client { - c.t.SetTLSHandshakeTimeout(timeout) + c.Transport.SetTLSHandshakeTimeout(timeout) return c } // EnableForceHTTP1 enable force using HTTP1 (disabled by default). func (c *Client) EnableForceHTTP1() *Client { - c.t.EnableForceHTTP1() + c.Transport.EnableForceHTTP1() return c } // EnableForceHTTP2 enable force using HTTP2 for https requests // (disabled by default). func (c *Client) EnableForceHTTP2() *Client { - c.t.EnableForceHTTP2() + c.Transport.EnableForceHTTP2() return c } // EnableForceHTTP3 enable force using HTTP3 for https requests // (disabled by default). func (c *Client) EnableForceHTTP3() *Client { - c.t.EnableForceHTTP3() + c.Transport.EnableForceHTTP3() return c } // DisableForceHttpVersion disable force using specified http // version (disabled by default). func (c *Client) DisableForceHttpVersion() *Client { - c.t.DisableForceHttpVersion() + c.Transport.DisableForceHttpVersion() return c } // EnableH2C enables HTTP/2 over TCP without TLS. func (c *Client) EnableH2C() *Client { - c.t.EnableH2C() + c.Transport.EnableH2C() return c } // DisableH2C disables HTTP/2 over TCP without TLS. func (c *Client) DisableH2C() *Client { - c.t.DisableH2C() + c.Transport.DisableH2C() return c } @@ -1335,13 +1335,13 @@ func (c *Client) SetUnixSocket(file string) *Client { // DisableHTTP3 disables the http3 protocol. func (c *Client) DisableHTTP3() *Client { - c.t.DisableHTTP3() + c.Transport.DisableHTTP3() return c } // EnableHTTP3 enables the http3 protocol. func (c *Client) EnableHTTP3() *Client { - c.t.EnableHTTP3() + c.Transport.EnableHTTP3() return c } @@ -1354,7 +1354,7 @@ func (c *Client) EnableHTTP3() *Client { // interprets the highest possible value here (0xffffffff or 1<<32-1) // to mean no limit. func (c *Client) SetHTTP2MaxHeaderListSize(max uint32) *Client { - c.t.SetHTTP2MaxHeaderListSize(max) + c.Transport.SetHTTP2MaxHeaderListSize(max) return c } @@ -1368,7 +1368,7 @@ func (c *Client) SetHTTP2MaxHeaderListSize(max uint32) *Client { // a global limit and callers of RoundTrip block when needed, // waiting for their turn. func (c *Client) SetHTTP2StrictMaxConcurrentStreams(strict bool) *Client { - c.t.SetHTTP2StrictMaxConcurrentStreams(strict) + c.Transport.SetHTTP2StrictMaxConcurrentStreams(strict) return c } @@ -1380,7 +1380,7 @@ func (c *Client) SetHTTP2StrictMaxConcurrentStreams(strict bool) *Client { // be performed every ReadIdleTimeout interval. // If zero, no health check is performed. func (c *Client) SetHTTP2ReadIdleTimeout(timeout time.Duration) *Client { - c.t.SetHTTP2ReadIdleTimeout(timeout) + c.Transport.SetHTTP2ReadIdleTimeout(timeout) return c } @@ -1389,7 +1389,7 @@ func (c *Client) SetHTTP2ReadIdleTimeout(timeout time.Duration) *Client { // not received. // Defaults to 15s func (c *Client) SetHTTP2PingTimeout(timeout time.Duration) *Client { - c.t.SetHTTP2PingTimeout(timeout) + c.Transport.SetHTTP2PingTimeout(timeout) return c } @@ -1398,7 +1398,7 @@ func (c *Client) SetHTTP2PingTimeout(timeout time.Duration) *Client { // to it. The timeout begins when data is available to write, and is // extended whenever any bytes are written. func (c *Client) SetHTTP2WriteByteTimeout(timeout time.Duration) *Client { - c.t.SetHTTP2WriteByteTimeout(timeout) + c.Transport.SetHTTP2WriteByteTimeout(timeout) return c } @@ -1412,12 +1412,12 @@ func (c *Client) Clone() *Client { cc := *c // clone Transport - cc.t = c.t.Clone() + cc.Transport = c.Transport.Clone() cc.initTransport() // clone http.Client client := *c.httpClient - client.Transport = cc.t + client.Transport = cc.Transport cc.httpClient = &client // clone client middleware @@ -1468,7 +1468,7 @@ func C() *Client { afterResponse: afterResponse, log: createDefaultLogger(), httpClient: httpClient, - t: t, + Transport: t, jsonMarshal: json.Marshal, jsonUnmarshal: json.Unmarshal, xmlMarshal: xml.Marshal, @@ -1481,7 +1481,7 @@ func C() *Client { } func (c *Client) initTransport() { - c.t.Debugf = func(format string, v ...interface{}) { + c.Debugf = func(format string, v ...interface{}) { if c.DebugLog { c.log.Debugf(format, v...) } diff --git a/client_test.go b/client_test.go index 75ec45a2..2204f8c2 100644 --- a/client_test.go +++ b/client_test.go @@ -63,7 +63,7 @@ func TestAllowGetMethodPayload(t *testing.T) { func TestSetTLSHandshakeTimeout(t *testing.T) { timeout := 2 * time.Second c := tc().SetTLSHandshakeTimeout(timeout) - tests.AssertEqual(t, timeout, c.t.TLSHandshakeTimeout) + tests.AssertEqual(t, timeout, c.TLSHandshakeTimeout) } func TestSetDial(t *testing.T) { @@ -72,7 +72,7 @@ func TestSetDial(t *testing.T) { return nil, testErr } c := tc().SetDial(testDial) - _, err := c.t.DialContext(nil, "", "") + _, err := c.DialContext(nil, "", "") tests.AssertEqual(t, testErr, err) } @@ -82,7 +82,7 @@ func TestSetDialTLS(t *testing.T) { return nil, testErr } c := tc().SetDialTLS(testDialTLS) - _, err := c.t.DialTLSContext(nil, "", "") + _, err := c.DialTLSContext(nil, "", "") tests.AssertEqual(t, testErr, err) } @@ -147,7 +147,7 @@ func TestOnBeforeRequest(t *testing.T) { func TestSetProxyURL(t *testing.T) { c := tc().SetProxyURL("http://dummy.proxy.local") - u, err := c.t.Proxy(nil) + u, err := c.Proxy(nil) tests.AssertNoError(t, err) tests.AssertEqual(t, "http://dummy.proxy.local", u.String()) } @@ -156,7 +156,7 @@ func TestSetProxy(t *testing.T) { u, _ := url.Parse("http://dummy.proxy.local") proxy := http.ProxyURL(u) c := tc().SetProxy(proxy) - uu, err := c.t.Proxy(nil) + uu, err := c.Proxy(nil) tests.AssertNoError(t, err) tests.AssertEqual(t, u.String(), uu.String()) } @@ -316,32 +316,32 @@ func TestSetCommonQueryParams(t *testing.T) { func TestInsecureSkipVerify(t *testing.T) { c := tc().EnableInsecureSkipVerify() - tests.AssertEqual(t, true, c.t.TLSClientConfig.InsecureSkipVerify) + tests.AssertEqual(t, true, c.TLSClientConfig.InsecureSkipVerify) c.DisableInsecureSkipVerify() - tests.AssertEqual(t, false, c.t.TLSClientConfig.InsecureSkipVerify) + tests.AssertEqual(t, false, c.TLSClientConfig.InsecureSkipVerify) } func TestSetTLSClientConfig(t *testing.T) { config := &tls.Config{InsecureSkipVerify: true} c := tc().SetTLSClientConfig(config) - tests.AssertEqual(t, config, c.t.TLSClientConfig) + tests.AssertEqual(t, config, c.TLSClientConfig) } func TestCompression(t *testing.T) { c := tc().DisableCompression() - tests.AssertEqual(t, true, c.t.DisableCompression) + tests.AssertEqual(t, true, c.Transport.DisableCompression) c.EnableCompression() - tests.AssertEqual(t, false, c.t.DisableCompression) + tests.AssertEqual(t, false, c.Transport.DisableCompression) } func TestKeepAlives(t *testing.T) { c := tc().DisableKeepAlives() - tests.AssertEqual(t, true, c.t.DisableKeepAlives) + tests.AssertEqual(t, true, c.Transport.DisableKeepAlives) c.EnableKeepAlives() - tests.AssertEqual(t, false, c.t.DisableKeepAlives) + tests.AssertEqual(t, false, c.Transport.DisableKeepAlives) } func TestRedirect(t *testing.T) { @@ -383,23 +383,23 @@ func TestRedirect(t *testing.T) { func TestGetTLSClientConfig(t *testing.T) { c := tc() config := c.GetTLSClientConfig() - tests.AssertEqual(t, true, c.t.TLSClientConfig != nil) - tests.AssertEqual(t, config, c.t.TLSClientConfig) + tests.AssertEqual(t, true, c.TLSClientConfig != nil) + tests.AssertEqual(t, config, c.TLSClientConfig) } func TestSetRootCertFromFile(t *testing.T) { c := tc().SetRootCertsFromFile(tests.GetTestFilePath("sample-root.pem")) - tests.AssertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) + tests.AssertEqual(t, true, c.TLSClientConfig.RootCAs != nil) } func TestSetRootCertFromString(t *testing.T) { c := tc().SetRootCertFromString(string(getTestFileContent(t, "sample-root.pem"))) - tests.AssertEqual(t, true, c.t.TLSClientConfig.RootCAs != nil) + tests.AssertEqual(t, true, c.TLSClientConfig.RootCAs != nil) } func TestSetCerts(t *testing.T) { c := tc().SetCerts(tls.Certificate{}, tls.Certificate{}) - tests.AssertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 2) + tests.AssertEqual(t, true, len(c.TLSClientConfig.Certificates) == 2) } func TestSetCertFromFile(t *testing.T) { @@ -407,7 +407,7 @@ func TestSetCertFromFile(t *testing.T) { tests.GetTestFilePath("sample-client.pem"), tests.GetTestFilePath("sample-client-key.pem"), ) - tests.AssertEqual(t, true, len(c.t.TLSClientConfig.Certificates) == 1) + tests.AssertEqual(t, true, len(c.TLSClientConfig.Certificates) == 1) } func TestSetOutputDirectory(t *testing.T) { diff --git a/transport_test.go b/transport_test.go index 8a9e4cf6..21ff1aaa 100644 --- a/transport_test.go +++ b/transport_test.go @@ -1213,7 +1213,7 @@ func TestRoundTripGzip(t *testing.T) { } })) defer ts.Close() - tr := tc().t + tr := tc().GetTransport() for i, test := range roundTripTests { // Test basic request (no accept-encoding) @@ -3902,7 +3902,7 @@ func TestTransportResponseCancelRace(t *testing.T) { w.Write(b[:]) })) defer ts.Close() - tr := tc().t + tr := tc().GetTransport() req, err := http.NewRequest("GET", ts.URL, nil) if err != nil { @@ -3970,7 +3970,7 @@ func TestTransportDialCancelRace(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer ts.Close() - tr := tc().t + tr := tc().GetTransport() req, err := http.NewRequest("GET", ts.URL, nil) if err != nil { From 5fbd4c8759e835b31327555e9875bee0748f692d Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Aug 2023 11:33:41 +0800 Subject: [PATCH 740/843] move Client.Headers into Transport.Headers --- client.go | 2 - middleware.go | 7 --- transport.go | 121 ++++++++++++++++++++++++++++------------------ transport_test.go | 3 ++ 4 files changed, 76 insertions(+), 57 deletions(-) diff --git a/client.go b/client.go index e1a41352..bea21cad 100644 --- a/client.go +++ b/client.go @@ -45,7 +45,6 @@ type Client struct { BaseURL string PathParams map[string]string QueryParams urlpkg.Values - Headers http.Header Cookies []*http.Cookie FormData urlpkg.Values DebugLog bool @@ -1429,7 +1428,6 @@ func (c *Client) Clone() *Client { } // clone other fields that may need to be cloned - cc.Headers = cloneHeaders(c.Headers) cc.Cookies = cloneCookies(c.Cookies) cc.PathParams = cloneMap(c.PathParams) cc.QueryParams = cloneUrlValues(c.QueryParams) diff --git a/middleware.go b/middleware.go index 94c6ad62..c90d360a 100644 --- a/middleware.go +++ b/middleware.go @@ -399,13 +399,6 @@ func parseRequestHeader(c *Client, r *Request) error { if r.Headers == nil { r.Headers = make(http.Header) } - for k, vs := range c.Headers { - for _, v := range vs { - if len(r.Headers[k]) == 0 { - r.Headers[k] = append(r.Headers[k], v) - } - } - } return nil } diff --git a/transport.go b/transport.go index c215cb96..867129da 100644 --- a/transport.go +++ b/transport.go @@ -102,6 +102,8 @@ const defaultMaxIdleConnsPerHost = 2 // request is treated as idempotent but the header is not sent on the // wire. type Transport struct { + Headers http.Header + idleMu sync.Mutex closeIdle bool // user has requested to close all idle conns idleConn map[connectMethodKey][]*persistConn // most recently used at end @@ -824,6 +826,7 @@ func (t *Transport) readBufferSize() int { // Clone returns a deep copy of t's exported fields. func (t *Transport) Clone() *Transport { tt := &Transport{ + Headers: cloneHeaders(t.Headers), Options: t.Options.Clone(), disableAutoDecode: t.disableAutoDecode, autoDecodeContentType: t.autoDecodeContentType, @@ -918,6 +921,69 @@ func (t *Transport) roundTripAltSvc(req *http.Request, as *altsvc.AltSvc) (resp return } +func (t *Transport) ensureHeader(req *http.Request, isHTTP bool) error { + if req.Header == nil { + closeBody(req) + return errors.New("http: nil Request.Header") + } + for k, vs := range t.Headers { + if len(req.Header[k]) == 0 { + req.Header[k] = vs + } + } + if !isHTTP { + return nil + } + // TODO: is h2c should also check this? + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + closeBody(req) + return fmt.Errorf("net/http: invalid header field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + closeBody(req) + // Don't include the value in the error, because it may be sensitive. + return fmt.Errorf("net/http: invalid header field value for %q", k) + } + } + } + return nil +} + +func (t *Transport) checkAltSvc(req *http.Request) (resp *http.Response, err error) { + if t.altSvcJar == nil { + return + } + addr := netutil.AuthorityKey(req.URL) + pas, ok := t.pendingAltSvcs[addr] + if ok && pas.Transport != nil { + pas.Mu.Lock() + if pas.Transport != nil { + pas.LastTime = time.Now() + r := req.Clone(req.Context()) + r.URL = altsvcutil.ConvertURL(pas.Entries[pas.CurrentIndex], req.URL) + resp, err = pas.Transport.RoundTrip(r) + if err != nil { + pas.Transport = nil + if pas.CurrentIndex+1 < len(pas.Entries) { + pas.CurrentIndex++ + go t.handlePendingAltSvc(req.URL, pas) + } + } else { + t.altSvcJar.SetAltSvc(addr, pas.Entries[pas.CurrentIndex]) + delete(t.pendingAltSvcs, addr) + } + } + pas.Mu.Unlock() + return + } + if as := t.altSvcJar.GetAltSvc(addr); as != nil { + return t.roundTripAltSvc(req, as) + } + return +} + // roundTrip implements a http.RoundTripper over HTTP. func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error) { ctx := req.Context() @@ -928,58 +994,17 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error return nil, errors.New("http: nil Request.URL") } - if t.altSvcJar != nil { - addr := netutil.AuthorityKey(req.URL) - pas, ok := t.pendingAltSvcs[addr] - if ok { - if pas.Transport != nil { - pas.Mu.Lock() - if pas.Transport != nil { - pas.LastTime = time.Now() - r := req.Clone(req.Context()) - r.URL = altsvcutil.ConvertURL(pas.Entries[pas.CurrentIndex], req.URL) - resp, err = pas.Transport.RoundTrip(r) - if err != nil { - pas.Transport = nil - if pas.CurrentIndex+1 < len(pas.Entries) { - pas.CurrentIndex++ - go t.handlePendingAltSvc(req.URL, pas) - } - } else { - t.altSvcJar.SetAltSvc(addr, pas.Entries[pas.CurrentIndex]) - delete(t.pendingAltSvcs, addr) - } - } - pas.Mu.Unlock() - return - } - } - as := t.altSvcJar.GetAltSvc(addr) - if as != nil { - return t.roundTripAltSvc(req, as) - } + resp, err = t.checkAltSvc(req) + if err != nil || resp != nil { + return } - if req.Header == nil { - closeBody(req) - return nil, errors.New("http: nil Request.Header") - } scheme := req.URL.Scheme isHTTP := scheme == "http" || scheme == "https" - if isHTTP { - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - closeBody(req) - return nil, fmt.Errorf("net/http: invalid header field name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - closeBody(req) - // Don't include the value in the error, because it may be sensitive. - return nil, fmt.Errorf("net/http: invalid header field value for %q", k) - } - } - } + + err = t.ensureHeader(req, isHTTP) + if err != nil { + return } if t.forceHttpVersion != "" { diff --git a/transport_test.go b/transport_test.go index 21ff1aaa..5c524e5f 100644 --- a/transport_test.go +++ b/transport_test.go @@ -5424,6 +5424,9 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { func TestTransportClone(t *testing.T) { tr := &Transport{ + Headers: http.Header{ + "test-key": []string{"test-value"}, + }, forceHttpVersion: h1, Options: transport.Options{ Proxy: func(*http.Request) (*url.URL, error) { panic("") }, From b04b895783a1b3d116bcb3e3757be0db1a4e5d3b Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Aug 2023 11:38:41 +0800 Subject: [PATCH 741/843] fix Client.SetHTTP2SettingsFrame --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index bea21cad..19879446 100644 --- a/client.go +++ b/client.go @@ -904,7 +904,7 @@ func (c *Client) SetCommonPseudoHeaderOder(keys ...string) *Client { // SetHTTP2SettingsFrame set the ordered http2 settings frame. func (c *Client) SetHTTP2SettingsFrame(settings ...http2.Setting) *Client { - c.SetHTTP2SettingsFrame(settings...) + c.Transport.SetHTTP2SettingsFrame(settings...) return c } From 98b9c97259e3a080ba11e2f0c80f33ad1b67c488 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Aug 2023 14:05:44 +0800 Subject: [PATCH 742/843] remove parseRequestHeader --- client.go | 1 - middleware.go | 10 ---------- 2 files changed, 11 deletions(-) diff --git a/client.go b/client.go index 19879446..91267fe0 100644 --- a/client.go +++ b/client.go @@ -1451,7 +1451,6 @@ func C() *Client { Timeout: 2 * time.Minute, } beforeRequest := []RequestMiddleware{ - parseRequestHeader, parseRequestURL, parseRequestBody, parseRequestCookie, diff --git a/middleware.go b/middleware.go index c90d360a..22d1d0dd 100644 --- a/middleware.go +++ b/middleware.go @@ -392,16 +392,6 @@ func handleDownload(c *Client, r *Response) (err error) { return } -func parseRequestHeader(c *Client, r *Request) error { - if c.Headers == nil { - return nil - } - if r.Headers == nil { - r.Headers = make(http.Header) - } - return nil -} - func parseRequestCookie(c *Client, r *Request) error { if len(c.Cookies) == 0 { return nil From e60e1ff1014ca13812bafb6abd167d49da8e0748 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Aug 2023 14:20:20 +0800 Subject: [PATCH 743/843] move Client.Cookies into Transport.Cookies --- client.go | 3 --- middleware.go | 8 -------- transport.go | 9 +++++++-- transport_test.go | 6 ++++++ 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 91267fe0..c7bb284a 100644 --- a/client.go +++ b/client.go @@ -45,7 +45,6 @@ type Client struct { BaseURL string PathParams map[string]string QueryParams urlpkg.Values - Cookies []*http.Cookie FormData urlpkg.Values DebugLog bool AllowGetMethodPayload bool @@ -1428,7 +1427,6 @@ func (c *Client) Clone() *Client { } // clone other fields that may need to be cloned - cc.Cookies = cloneCookies(c.Cookies) cc.PathParams = cloneMap(c.PathParams) cc.QueryParams = cloneUrlValues(c.QueryParams) cc.FormData = cloneUrlValues(c.FormData) @@ -1453,7 +1451,6 @@ func C() *Client { beforeRequest := []RequestMiddleware{ parseRequestURL, parseRequestBody, - parseRequestCookie, } afterResponse := []ResponseMiddleware{ parseResponseBody, diff --git a/middleware.go b/middleware.go index 22d1d0dd..40a9993b 100644 --- a/middleware.go +++ b/middleware.go @@ -392,14 +392,6 @@ func handleDownload(c *Client, r *Response) (err error) { return } -func parseRequestCookie(c *Client, r *Request) error { - if len(c.Cookies) == 0 { - return nil - } - r.Cookies = append(r.Cookies, c.Cookies...) - return nil -} - // generate URL func parseRequestURL(c *Client, r *Request) error { tempURL := r.RawURL diff --git a/transport.go b/transport.go index 867129da..ef2c963c 100644 --- a/transport.go +++ b/transport.go @@ -103,6 +103,7 @@ const defaultMaxIdleConnsPerHost = 2 // wire. type Transport struct { Headers http.Header + Cookies []*http.Cookie idleMu sync.Mutex closeIdle bool // user has requested to close all idle conns @@ -827,6 +828,7 @@ func (t *Transport) readBufferSize() int { func (t *Transport) Clone() *Transport { tt := &Transport{ Headers: cloneHeaders(t.Headers), + Cookies: cloneCookies(t.Cookies), Options: t.Options.Clone(), disableAutoDecode: t.disableAutoDecode, autoDecodeContentType: t.autoDecodeContentType, @@ -921,7 +923,7 @@ func (t *Transport) roundTripAltSvc(req *http.Request, as *altsvc.AltSvc) (resp return } -func (t *Transport) ensureHeader(req *http.Request, isHTTP bool) error { +func (t *Transport) ensureHeaderAndCookie(req *http.Request, isHTTP bool) error { if req.Header == nil { closeBody(req) return errors.New("http: nil Request.Header") @@ -931,6 +933,9 @@ func (t *Transport) ensureHeader(req *http.Request, isHTTP bool) error { req.Header[k] = vs } } + for _, c := range t.Cookies { + req.AddCookie(c) + } if !isHTTP { return nil } @@ -1002,7 +1007,7 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error scheme := req.URL.Scheme isHTTP := scheme == "http" || scheme == "https" - err = t.ensureHeader(req, isHTTP) + err = t.ensureHeaderAndCookie(req, isHTTP) if err != nil { return } diff --git a/transport_test.go b/transport_test.go index 5c524e5f..1858540b 100644 --- a/transport_test.go +++ b/transport_test.go @@ -5427,6 +5427,12 @@ func TestTransportClone(t *testing.T) { Headers: http.Header{ "test-key": []string{"test-value"}, }, + Cookies: []*http.Cookie{ + { + Name: "test", + Value: "test", + }, + }, forceHttpVersion: h1, Options: transport.Options{ Proxy: func(*http.Request) (*url.URL, error) { panic("") }, From 6532cf6531c433a76574fee03001cd124a6c81e4 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Aug 2023 14:28:13 +0800 Subject: [PATCH 744/843] remove SetHeadersNonCanonical and SetHeaders of Transport(use Client to set) --- transport.go | 48 ------------------------------------------------ 1 file changed, 48 deletions(-) diff --git a/transport.go b/transport.go index ef2c963c..79f901cf 100644 --- a/transport.go +++ b/transport.go @@ -214,54 +214,6 @@ func (t *Transport) WrapRoundTrip(wrappers ...HttpRoundTripWrapper) *Transport { return t } -// SetHeaders set headers for all requests, if the same header exists at -// the request-level, the request-level header takes precedence. -func (t *Transport) SetHeaders(hdrs http.Header) *Transport { - if len(hdrs) == 0 { - return t - } - t.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { - return func(req *http.Request) (resp *http.Response, err error) { - if req.Header == nil { - req.Header = make(http.Header) - } - for k, v := range hdrs { - if len(v) == 0 { - continue - } - kk := textproto.CanonicalMIMEHeaderKey(k) - if len(req.Header[kk]) == 0 { - req.Header[kk] = v - } - } - return rt.RoundTrip(req) - } - }) - return t -} - -// SetHeadersNonCanonical set non-canonical headers for all requests, if the -// same header exists at the request-level, the request-level header takes precedence. -func (t *Transport) SetHeadersNonCanonical(hdrs http.Header) *Transport { - if len(hdrs) == 0 { - return t - } - t.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { - return func(req *http.Request) (resp *http.Response, err error) { - if req.Header == nil { - req.Header = make(http.Header) - } - for k, v := range hdrs { - if len(v) > 0 && len(req.Header[k]) == 0 { - req.Header[k] = v - } - } - return rt.RoundTrip(req) - } - }) - return t -} - // SetHeaderOrder set the order of the http header (case-insensitive). // For example: // From 23923940ebbbbc33ae1d54b549f4eb02c3081fc1 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Aug 2023 16:13:05 +0800 Subject: [PATCH 745/843] rename SetProfile to SetClientProfile --- client.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index c7bb284a..8772ca58 100644 --- a/client.go +++ b/client.go @@ -907,12 +907,17 @@ func (c *Client) SetHTTP2SettingsFrame(settings ...http2.Setting) *Client { return c } -// SetProfile set the http client profile. -func (c *Client) SetProfile(p ClientProfile) *Client { +// SetClientProfile set the http client profile. +func (c *Client) SetClientProfile(p ClientProfile) *Client { p(c) return c } +// UseChromeProfile set the http client profile to chrome. +func (c *Client) UseChromeProfile() *Client { + return c.SetClientProfile(ClientProfile_Chrome) +} + // SetCommonContentType set the `Content-Type` header for requests fired // from the client. func (c *Client) SetCommonContentType(ct string) *Client { From 47d0c9090c54d820f2ecfb6d9882c82f9ed66fba Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Aug 2023 17:26:04 +0800 Subject: [PATCH 746/843] optimize ClientProfile_Chrome --- profile.go | 44 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/profile.go b/profile.go index 014f6aa9..8028d588 100644 --- a/profile.go +++ b/profile.go @@ -1,6 +1,9 @@ package req -import "github.com/imroc/req/v3/pkg/http2" +import ( + "github.com/imroc/req/v3/pkg/http2" + utls "github.com/refraction-networking/utls" +) type ClientProfile func(c *Client) @@ -27,6 +30,24 @@ var http2SettingsChrome = []http2.Setting{ }, } +var chromeHeaderOrder = []string{ + "host", + "pragma", + "cache-control", + "sec-ch-ua", + "sec-ch-ua-mobile", + "sec-ch-ua-platform", + "upgrade-insecure-requests", + "user-agent", + "accept", + "sec-fetch-site", + "sec-fetch-mode", + "sec-fetch-user", + "sec-fetch-dest", + "accept-encoding", + "accept-language", +} + var chromePseudoHeaderOrder = []string{ ":method", ":authority", @@ -34,8 +55,27 @@ var chromePseudoHeaderOrder = []string{ ":path", } +var chromeHeaders = map[string]string{ + "pragma": "no-cache", + "cache-control": "no-cache", + "sec-ch-ua": `"Chromium";v="106", "Google Chrome";v="106", "Not;A=Brand";v="99"`, + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": "macOS", + "upgrade-insecure-requests": "1", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36", + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", + "sec-fetch-site": "none", + "sec-fetch-mode": "navigate", + "sec-fetch-user": "?1", + "sec-fetch-dest": "document", + "accept-language": "zh-CN,zh;q=0.9,en;q=0.8,zh-TW;q=0.7,it;q=0.6", +} + var ClientProfile_Chrome ClientProfile = func(c *Client) { - c.SetTLSFingerprintChrome(). + c.SetTLSFingerprint(utls.HelloChrome_106_Shuffle). + SetUserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36"). SetCommonPseudoHeaderOder(chromePseudoHeaderOrder...). + SetCommonHeaders(chromeHeaders). + SetCommonHeaderOrder(chromeHeaderOrder...). SetHTTP2SettingsFrame(http2SettingsChrome...) } From 445e518dba001a25fab7053efbe76c67b232e402 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Aug 2023 17:38:23 +0800 Subject: [PATCH 747/843] optimize ClientProfile_Chrome --- profile.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/profile.go b/profile.go index 8028d588..5ba4de44 100644 --- a/profile.go +++ b/profile.go @@ -46,6 +46,7 @@ var chromeHeaderOrder = []string{ "sec-fetch-dest", "accept-encoding", "accept-language", + "cookie", } var chromePseudoHeaderOrder = []string{ @@ -60,7 +61,7 @@ var chromeHeaders = map[string]string{ "cache-control": "no-cache", "sec-ch-ua": `"Chromium";v="106", "Google Chrome";v="106", "Not;A=Brand";v="99"`, "sec-ch-ua-mobile": "?0", - "sec-ch-ua-platform": "macOS", + "sec-ch-ua-platform": `"macOS"`, "upgrade-insecure-requests": "1", "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36", "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", From c8b8f64ca22a53cf72846366575eeb308d905270 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 2 Aug 2023 17:40:48 +0800 Subject: [PATCH 748/843] remove duplicate UserAgent settting --- profile.go | 1 - 1 file changed, 1 deletion(-) diff --git a/profile.go b/profile.go index 5ba4de44..32399d61 100644 --- a/profile.go +++ b/profile.go @@ -74,7 +74,6 @@ var chromeHeaders = map[string]string{ var ClientProfile_Chrome ClientProfile = func(c *Client) { c.SetTLSFingerprint(utls.HelloChrome_106_Shuffle). - SetUserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36"). SetCommonPseudoHeaderOder(chromePseudoHeaderOrder...). SetCommonHeaders(chromeHeaders). SetCommonHeaderOrder(chromeHeaderOrder...). From a7d4bf22a6ce6c577acea2a99f176fb7dd0f52b1 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 3 Aug 2023 14:01:23 +0800 Subject: [PATCH 749/843] Add SetHTTP2ConnectionFlow to Client and Transport --- client.go | 7 +++++++ internal/http2/transport.go | 10 ++++++++-- profile.go | 3 ++- transport.go | 7 +++++++ 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 8772ca58..23adebc3 100644 --- a/client.go +++ b/client.go @@ -907,6 +907,13 @@ func (c *Client) SetHTTP2SettingsFrame(settings ...http2.Setting) *Client { return c } +// SetHTTP2ConnectionFlow set the default http2 connection flow, which is the increment +// value of initial WINDOW_UPDATE frame. +func (c *Client) SetHTTP2ConnectionFlow(flow uint32) *Client { + c.Transport.SetHTTP2ConnectionFlow(flow) + return c +} + // SetClientProfile set the http client profile. func (c *Client) SetClientProfile(p ClientProfile) *Client { p(c) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 889c1a4c..0b33ede3 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -134,6 +134,8 @@ type Transport struct { Settings []http2.Setting + ConnectionFlow uint32 + connPoolOnce sync.Once connPoolOrDef ClientConnPool // non-nil version of ConnPool } @@ -695,8 +697,12 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro cc.bw.Write(clientPreface) cc.fr.WriteSettings(initialSettings...) - cc.fr.WriteWindowUpdate(0, transportDefaultConnFlow) - cc.inflow.init(transportDefaultConnFlow + initialWindowSize) + connFlow := cc.t.ConnectionFlow + if connFlow < 1 { + connFlow = transportDefaultConnFlow + } + cc.fr.WriteWindowUpdate(0, connFlow) + cc.inflow.init(int32(connFlow) + initialWindowSize) cc.bw.Flush() if cc.werr != nil { cc.Close() diff --git a/profile.go b/profile.go index 32399d61..c14df10b 100644 --- a/profile.go +++ b/profile.go @@ -77,5 +77,6 @@ var ClientProfile_Chrome ClientProfile = func(c *Client) { SetCommonPseudoHeaderOder(chromePseudoHeaderOrder...). SetCommonHeaders(chromeHeaders). SetCommonHeaderOrder(chromeHeaderOrder...). - SetHTTP2SettingsFrame(http2SettingsChrome...) + SetHTTP2SettingsFrame(http2SettingsChrome...). + SetHTTP2ConnectionFlow(15663105) } diff --git a/transport.go b/transport.go index 79f901cf..102c8f96 100644 --- a/transport.go +++ b/transport.go @@ -460,6 +460,13 @@ func (t *Transport) SetHTTP2SettingsFrame(settings ...http2.Setting) *Transport return t } +// SetHTTP2ConnectionFlow set the default http2 connection flow, which is the increment +// value of initial WINDOW_UPDATE frame. +func (t *Transport) SetHTTP2ConnectionFlow(flow uint32) *Transport { + t.t2.ConnectionFlow = flow + return t +} + // SetTLSClientConfig set the custom TLSClientConfig, which specifies the TLS configuration to // use with tls.Client. // If nil, the default configuration is used. From af6757289bc504a198ad09f847d05010f6cd1c5e Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 3 Aug 2023 14:12:31 +0800 Subject: [PATCH 750/843] add client wrappers --- client_wrapper.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/client_wrapper.go b/client_wrapper.go index f460fdaf..964b7b5d 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -3,6 +3,7 @@ package req import ( "context" "crypto/tls" + "github.com/imroc/req/v3/pkg/http2" utls "github.com/refraction-networking/utls" "io" "net" @@ -421,6 +422,60 @@ func SetCommonPseudoHeaderOder(keys ...string) *Client { return defaultClient.SetCommonPseudoHeaderOder(keys...) } +// SetHTTP2SettingsFrame is a global wrapper methods which delegated +// to the default client's Client.SetHTTP2SettingsFrame. +func SetHTTP2SettingsFrame(settings ...http2.Setting) *Client { + return defaultClient.SetHTTP2SettingsFrame(settings...) +} + +// SetHTTP2ConnectionFlow is a global wrapper methods which delegated +// to the default client's Client.SetHTTP2ConnectionFlow. +func SetHTTP2ConnectionFlow(flow uint32) *Client { + return defaultClient.SetHTTP2ConnectionFlow(flow) +} + +// SetHTTP2MaxHeaderListSize is a global wrapper methods which delegated +// to the default client's Client.SetHTTP2MaxHeaderListSize. +func SetHTTP2MaxHeaderListSize(max uint32) *Client { + return defaultClient.SetHTTP2MaxHeaderListSize(max) +} + +// SetHTTP2StrictMaxConcurrentStreams is a global wrapper methods which delegated +// to the default client's Client.SetHTTP2StrictMaxConcurrentStreams. +func SetHTTP2StrictMaxConcurrentStreams(strict bool) *Client { + return defaultClient.SetHTTP2StrictMaxConcurrentStreams(strict) +} + +// SetHTTP2ReadIdleTimeout is a global wrapper methods which delegated +// to the default client's Client.SetHTTP2ReadIdleTimeout. +func SetHTTP2ReadIdleTimeout(timeout time.Duration) *Client { + return defaultClient.SetHTTP2ReadIdleTimeout(timeout) +} + +// SetHTTP2PingTimeout is a global wrapper methods which delegated +// to the default client's Client.SetHTTP2PingTimeout. +func SetHTTP2PingTimeout(timeout time.Duration) *Client { + return defaultClient.SetHTTP2PingTimeout(timeout) +} + +// SetHTTP2WriteByteTimeout is a global wrapper methods which delegated +// to the default client's Client.SetHTTP2WriteByteTimeout. +func SetHTTP2WriteByteTimeout(timeout time.Duration) *Client { + return defaultClient.SetHTTP2WriteByteTimeout(timeout) +} + +// SetClientProfile is a global wrapper methods which delegated +// to the default client's Client.SetClientProfile. +func SetClientProfile(p ClientProfile) *Client { + return defaultClient.SetClientProfile(p) +} + +// UseChromeProfile is a global wrapper methods which delegated +// to the default client's Client.UseChromeProfile. +func UseChromeProfile() *Client { + return defaultClient.UseChromeProfile() +} + // SetCommonContentType is a global wrapper methods which delegated // to the default client's Client.SetCommonContentType. func SetCommonContentType(ct string) *Client { From 874654578f74072f0745a0d23934ea27dbe79543 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 3 Aug 2023 16:12:02 +0800 Subject: [PATCH 751/843] Refactor SetClientProfile to ImpersonateXXX * Remove ClientProfile, SetClientProfile, UseChromeProfile. * Add ImpersonateChrome. --- client.go | 11 ----------- profile.go => client_impersonate.go | 10 +++++----- client_wrapper.go | 14 ++++---------- 3 files changed, 9 insertions(+), 26 deletions(-) rename profile.go => client_impersonate.go (86%) diff --git a/client.go b/client.go index 23adebc3..36c64efc 100644 --- a/client.go +++ b/client.go @@ -914,17 +914,6 @@ func (c *Client) SetHTTP2ConnectionFlow(flow uint32) *Client { return c } -// SetClientProfile set the http client profile. -func (c *Client) SetClientProfile(p ClientProfile) *Client { - p(c) - return c -} - -// UseChromeProfile set the http client profile to chrome. -func (c *Client) UseChromeProfile() *Client { - return c.SetClientProfile(ClientProfile_Chrome) -} - // SetCommonContentType set the `Content-Type` header for requests fired // from the client. func (c *Client) SetCommonContentType(ct string) *Client { diff --git a/profile.go b/client_impersonate.go similarity index 86% rename from profile.go rename to client_impersonate.go index c14df10b..de0fe33f 100644 --- a/profile.go +++ b/client_impersonate.go @@ -5,8 +5,6 @@ import ( utls "github.com/refraction-networking/utls" ) -type ClientProfile func(c *Client) - var http2SettingsChrome = []http2.Setting{ { ID: http2.SettingHeaderTableSize, @@ -59,11 +57,11 @@ var chromePseudoHeaderOrder = []string{ var chromeHeaders = map[string]string{ "pragma": "no-cache", "cache-control": "no-cache", - "sec-ch-ua": `"Chromium";v="106", "Google Chrome";v="106", "Not;A=Brand";v="99"`, + "sec-ch-ua": `"Not_A Brand";v="99", "Google Chrome";v="109", "Chromium";v="109"`, "sec-ch-ua-mobile": "?0", "sec-ch-ua-platform": `"macOS"`, "upgrade-insecure-requests": "1", - "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36", "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", "sec-fetch-site": "none", "sec-fetch-mode": "navigate", @@ -72,11 +70,13 @@ var chromeHeaders = map[string]string{ "accept-language": "zh-CN,zh;q=0.9,en;q=0.8,zh-TW;q=0.7,it;q=0.6", } -var ClientProfile_Chrome ClientProfile = func(c *Client) { +// ImpersonateChrome impersonates Chrome browser (version 109). +func (c *Client) ImpersonateChrome() *Client { c.SetTLSFingerprint(utls.HelloChrome_106_Shuffle). SetCommonPseudoHeaderOder(chromePseudoHeaderOrder...). SetCommonHeaders(chromeHeaders). SetCommonHeaderOrder(chromeHeaderOrder...). SetHTTP2SettingsFrame(http2SettingsChrome...). SetHTTP2ConnectionFlow(15663105) + return c } diff --git a/client_wrapper.go b/client_wrapper.go index 964b7b5d..362e84e8 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -464,16 +464,10 @@ func SetHTTP2WriteByteTimeout(timeout time.Duration) *Client { return defaultClient.SetHTTP2WriteByteTimeout(timeout) } -// SetClientProfile is a global wrapper methods which delegated -// to the default client's Client.SetClientProfile. -func SetClientProfile(p ClientProfile) *Client { - return defaultClient.SetClientProfile(p) -} - -// UseChromeProfile is a global wrapper methods which delegated -// to the default client's Client.UseChromeProfile. -func UseChromeProfile() *Client { - return defaultClient.UseChromeProfile() +// ImpersonateChrome is a global wrapper methods which delegated +// to the default client's Client.ImpersonateChrome. +func ImpersonateChrome() *Client { + return defaultClient.ImpersonateChrome() } // SetCommonContentType is a global wrapper methods which delegated From eece8360d84fdba5ed3e8f5c22e9f4d89ec2b035 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 3 Aug 2023 16:17:01 +0800 Subject: [PATCH 752/843] add referer to chromeHeaderOrder --- client_impersonate.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client_impersonate.go b/client_impersonate.go index de0fe33f..b2a525e2 100644 --- a/client_impersonate.go +++ b/client_impersonate.go @@ -42,6 +42,7 @@ var chromeHeaderOrder = []string{ "sec-fetch-mode", "sec-fetch-user", "sec-fetch-dest", + "referer", "accept-encoding", "accept-language", "cookie", From 89d1ceabf8271e20be21db49a33797c3b6e7cb8e Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 3 Aug 2023 17:35:25 +0800 Subject: [PATCH 753/843] Support http2 header priority and some refactor. --- client.go | 8 +++++++- client_impersonate.go | 2 +- client_wrapper.go | 2 +- http2/priority.go | 22 ++++++++++++++++++++++ {pkg/http2 => http2}/setting.go | 0 internal/http2/frame.go | 33 ++++++--------------------------- internal/http2/transport.go | 4 +++- transport.go | 8 +++++++- 8 files changed, 47 insertions(+), 32 deletions(-) create mode 100644 http2/priority.go rename {pkg/http2 => http2}/setting.go (100%) diff --git a/client.go b/client.go index 36c64efc..83a4301b 100644 --- a/client.go +++ b/client.go @@ -8,7 +8,7 @@ import ( "encoding/json" "encoding/xml" "errors" - "github.com/imroc/req/v3/pkg/http2" + "github.com/imroc/req/v3/http2" "io" "net" "net/http" @@ -914,6 +914,12 @@ func (c *Client) SetHTTP2ConnectionFlow(flow uint32) *Client { return c } +// SetHTTP2HeaderPriority set the header priority param. +func (c *Client) SetHTTP2HeaderPriority(priority http2.PriorityParam) *Client { + c.SetHTTP2HeaderPriority(priority) + return c +} + // SetCommonContentType set the `Content-Type` header for requests fired // from the client. func (c *Client) SetCommonContentType(ct string) *Client { diff --git a/client_impersonate.go b/client_impersonate.go index b2a525e2..1450d809 100644 --- a/client_impersonate.go +++ b/client_impersonate.go @@ -1,7 +1,7 @@ package req import ( - "github.com/imroc/req/v3/pkg/http2" + "github.com/imroc/req/v3/http2" utls "github.com/refraction-networking/utls" ) diff --git a/client_wrapper.go b/client_wrapper.go index 362e84e8..9d458611 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -3,7 +3,7 @@ package req import ( "context" "crypto/tls" - "github.com/imroc/req/v3/pkg/http2" + "github.com/imroc/req/v3/http2" utls "github.com/refraction-networking/utls" "io" "net" diff --git a/http2/priority.go b/http2/priority.go new file mode 100644 index 00000000..9d28327b --- /dev/null +++ b/http2/priority.go @@ -0,0 +1,22 @@ +package http2 + +// PriorityParam are the stream prioritzation parameters. +type PriorityParam struct { + // StreamDep is a 31-bit stream identifier for the + // stream that this stream depends on. Zero means no + // dependency. + StreamDep uint32 + + // Exclusive is whether the dependency is exclusive. + Exclusive bool + + // Weight is the stream's zero-indexed weight. It should be + // set together with StreamDep, or neither should be set. Per + // the spec, "Add one to the value to obtain a weight between + // 1 and 256." + Weight uint8 +} + +func (p PriorityParam) IsZero() bool { + return p == PriorityParam{} +} diff --git a/pkg/http2/setting.go b/http2/setting.go similarity index 100% rename from pkg/http2/setting.go rename to http2/setting.go diff --git a/internal/http2/frame.go b/internal/http2/frame.go index 3b0922c4..2bb88522 100644 --- a/internal/http2/frame.go +++ b/internal/http2/frame.go @@ -10,8 +10,8 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/imroc/req/v3/http2" "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/pkg/http2" "golang.org/x/net/http/httpguts" "golang.org/x/net/http2/hpack" "io" @@ -1049,7 +1049,7 @@ type HeadersFrame struct { FrameHeader // Priority is set if FlagHeadersPriority is set in the FrameHeader. - Priority PriorityParam + Priority http2.PriorityParam headerFragBuf []byte // not owned } @@ -1137,7 +1137,7 @@ type HeadersFrameParam struct { // Priority, if non-zero, includes stream priority information // in the HEADER frame. - Priority PriorityParam + Priority http2.PriorityParam } // WriteHeaders writes a single HEADERS frame. @@ -1189,28 +1189,7 @@ func (h2f *Framer) WriteHeaders(p HeadersFrameParam) error { // See https://httpwg.org/specs/rfc7540.html#rfc.section.6.3 type PriorityFrame struct { FrameHeader - PriorityParam -} - -// PriorityParam are the stream prioritzation parameters. -type PriorityParam struct { - // StreamDep is a 31-bit stream identifier for the - // stream that this stream depends on. Zero means no - // dependency. - StreamDep uint32 - - // Exclusive is whether the dependency is exclusive. - Exclusive bool - - // Weight is the stream's zero-indexed weight. It should be - // set together with StreamDep, or neither should be set. Per - // the spec, "Add one to the value to obtain a weight between - // 1 and 256." - Weight uint8 -} - -func (p PriorityParam) IsZero() bool { - return p == PriorityParam{} + http2.PriorityParam } func parsePriorityFrame(_ *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) { @@ -1226,7 +1205,7 @@ func parsePriorityFrame(_ *frameCache, fh FrameHeader, countError func(string), streamID := v & 0x7fffffff // mask off high bit return &PriorityFrame{ FrameHeader: fh, - PriorityParam: PriorityParam{ + PriorityParam: http2.PriorityParam{ Weight: payload[4], StreamDep: streamID, Exclusive: streamID != v, // was high bit set? @@ -1238,7 +1217,7 @@ func parsePriorityFrame(_ *frameCache, fh FrameHeader, countError func(string), // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. -func (h2f *Framer) WritePriority(streamID uint32, p PriorityParam) error { +func (h2f *Framer) WritePriority(streamID uint32, p http2.PriorityParam) error { if !validStreamID(streamID) && !h2f.AllowIllegalWrites { return errStreamID } diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 0b33ede3..7be9df36 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -15,13 +15,13 @@ import ( "crypto/tls" "errors" "fmt" + "github.com/imroc/req/v3/http2" "github.com/imroc/req/v3/internal/ascii" "github.com/imroc/req/v3/internal/common" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/netutil" "github.com/imroc/req/v3/internal/transport" - "github.com/imroc/req/v3/pkg/http2" reqtls "github.com/imroc/req/v3/pkg/tls" "io" "io/fs" @@ -135,6 +135,7 @@ type Transport struct { Settings []http2.Setting ConnectionFlow uint32 + HeaderPriority http2.PriorityParam connPoolOnce sync.Once connPoolOrDef ClientConnPool // non-nil version of ConnPool @@ -1533,6 +1534,7 @@ func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize BlockFragment: chunk, EndStream: endStream, EndHeaders: endHeaders, + Priority: cc.t.HeaderPriority, }) first = false } else { diff --git a/transport.go b/transport.go index 102c8f96..75ce96d8 100644 --- a/transport.go +++ b/transport.go @@ -17,6 +17,7 @@ import ( "crypto/tls" "errors" "fmt" + "github.com/imroc/req/v3/http2" "github.com/imroc/req/v3/internal/altsvcutil" "github.com/imroc/req/v3/internal/ascii" "github.com/imroc/req/v3/internal/common" @@ -29,7 +30,6 @@ import ( "github.com/imroc/req/v3/internal/transport" "github.com/imroc/req/v3/internal/util" "github.com/imroc/req/v3/pkg/altsvc" - "github.com/imroc/req/v3/pkg/http2" reqtls "github.com/imroc/req/v3/pkg/tls" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" @@ -467,6 +467,12 @@ func (t *Transport) SetHTTP2ConnectionFlow(flow uint32) *Transport { return t } +// SetHTTP2HeaderPriority set the header priority param. +func (t *Transport) SetHTTP2HeaderPriority(priority http2.PriorityParam) *Transport { + t.t2.HeaderPriority = priority + return t +} + // SetTLSClientConfig set the custom TLSClientConfig, which specifies the TLS configuration to // use with tls.Client. // If nil, the default configuration is used. From d74b2ce6b2ba7338914df9b64bd78debc34b4272 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 4 Aug 2023 11:04:18 +0800 Subject: [PATCH 754/843] fix Client.SetHTTP2HeaderPriority --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 83a4301b..a683008f 100644 --- a/client.go +++ b/client.go @@ -916,7 +916,7 @@ func (c *Client) SetHTTP2ConnectionFlow(flow uint32) *Client { // SetHTTP2HeaderPriority set the header priority param. func (c *Client) SetHTTP2HeaderPriority(priority http2.PriorityParam) *Client { - c.SetHTTP2HeaderPriority(priority) + c.Transport.SetHTTP2HeaderPriority(priority) return c } From b960634509820b20555b38e1d270af9ae9345251 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 4 Aug 2023 11:04:46 +0800 Subject: [PATCH 755/843] add SetHTTP2HeaderPriority to ImpersonateChrome --- client_impersonate.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/client_impersonate.go b/client_impersonate.go index 1450d809..ff6f3cf1 100644 --- a/client_impersonate.go +++ b/client_impersonate.go @@ -78,6 +78,11 @@ func (c *Client) ImpersonateChrome() *Client { SetCommonHeaders(chromeHeaders). SetCommonHeaderOrder(chromeHeaderOrder...). SetHTTP2SettingsFrame(http2SettingsChrome...). - SetHTTP2ConnectionFlow(15663105) + SetHTTP2ConnectionFlow(15663105). + SetHTTP2HeaderPriority(http2.PriorityParam{ + StreamDep: 0, + Exclusive: true, + Weight: 255, + }) return c } From 6c5a082457bd5f74b40acdefec027012a6bfbafe Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 4 Aug 2023 11:08:27 +0800 Subject: [PATCH 756/843] move chromeXXX vars into var block --- client_impersonate.go | 139 +++++++++++++++++++++--------------------- 1 file changed, 71 insertions(+), 68 deletions(-) diff --git a/client_impersonate.go b/client_impersonate.go index ff6f3cf1..e20b4e12 100644 --- a/client_impersonate.go +++ b/client_impersonate.go @@ -5,71 +5,78 @@ import ( utls "github.com/refraction-networking/utls" ) -var http2SettingsChrome = []http2.Setting{ - { - ID: http2.SettingHeaderTableSize, - Val: 65536, - }, - { - ID: http2.SettingEnablePush, - Val: 0, - }, - { - ID: http2.SettingMaxConcurrentStreams, - Val: 1000, - }, - { - ID: http2.SettingInitialWindowSize, - Val: 6291456, - }, - { - ID: http2.SettingMaxHeaderListSize, - Val: 262144, - }, -} +var ( + chromeHttp2Settings = []http2.Setting{ + { + ID: http2.SettingHeaderTableSize, + Val: 65536, + }, + { + ID: http2.SettingEnablePush, + Val: 0, + }, + { + ID: http2.SettingMaxConcurrentStreams, + Val: 1000, + }, + { + ID: http2.SettingInitialWindowSize, + Val: 6291456, + }, + { + ID: http2.SettingMaxHeaderListSize, + Val: 262144, + }, + } -var chromeHeaderOrder = []string{ - "host", - "pragma", - "cache-control", - "sec-ch-ua", - "sec-ch-ua-mobile", - "sec-ch-ua-platform", - "upgrade-insecure-requests", - "user-agent", - "accept", - "sec-fetch-site", - "sec-fetch-mode", - "sec-fetch-user", - "sec-fetch-dest", - "referer", - "accept-encoding", - "accept-language", - "cookie", -} + chromeHeaderOrder = []string{ + "host", + "pragma", + "cache-control", + "sec-ch-ua", + "sec-ch-ua-mobile", + "sec-ch-ua-platform", + "upgrade-insecure-requests", + "user-agent", + "accept", + "sec-fetch-site", + "sec-fetch-mode", + "sec-fetch-user", + "sec-fetch-dest", + "referer", + "accept-encoding", + "accept-language", + "cookie", + } -var chromePseudoHeaderOrder = []string{ - ":method", - ":authority", - ":scheme", - ":path", -} + chromePseudoHeaderOrder = []string{ + ":method", + ":authority", + ":scheme", + ":path", + } -var chromeHeaders = map[string]string{ - "pragma": "no-cache", - "cache-control": "no-cache", - "sec-ch-ua": `"Not_A Brand";v="99", "Google Chrome";v="109", "Chromium";v="109"`, - "sec-ch-ua-mobile": "?0", - "sec-ch-ua-platform": `"macOS"`, - "upgrade-insecure-requests": "1", - "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36", - "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", - "sec-fetch-site": "none", - "sec-fetch-mode": "navigate", - "sec-fetch-user": "?1", - "sec-fetch-dest": "document", - "accept-language": "zh-CN,zh;q=0.9,en;q=0.8,zh-TW;q=0.7,it;q=0.6", -} + chromeHeaders = map[string]string{ + "pragma": "no-cache", + "cache-control": "no-cache", + "sec-ch-ua": `"Not_A Brand";v="99", "Google Chrome";v="109", "Chromium";v="109"`, + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": `"macOS"`, + "upgrade-insecure-requests": "1", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36", + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", + "sec-fetch-site": "none", + "sec-fetch-mode": "navigate", + "sec-fetch-user": "?1", + "sec-fetch-dest": "document", + "accept-language": "zh-CN,zh;q=0.9,en;q=0.8,zh-TW;q=0.7,it;q=0.6", + } + chromeHeaderPriority = http2.PriorityParam{ + StreamDep: 0, + Exclusive: true, + Weight: 255, + } +) // ImpersonateChrome impersonates Chrome browser (version 109). func (c *Client) ImpersonateChrome() *Client { @@ -77,12 +84,8 @@ func (c *Client) ImpersonateChrome() *Client { SetCommonPseudoHeaderOder(chromePseudoHeaderOrder...). SetCommonHeaders(chromeHeaders). SetCommonHeaderOrder(chromeHeaderOrder...). - SetHTTP2SettingsFrame(http2SettingsChrome...). + SetHTTP2SettingsFrame(chromeHttp2Settings...). SetHTTP2ConnectionFlow(15663105). - SetHTTP2HeaderPriority(http2.PriorityParam{ - StreamDep: 0, - Exclusive: true, - Weight: 255, - }) + SetHTTP2HeaderPriority(chromeHeaderPriority) return c } From fec12cbfa726021bb5e72cf3f87b320a0d791ccd Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 4 Aug 2023 11:58:59 +0800 Subject: [PATCH 757/843] Add SetHTTP2PriorityFrames support to Client and Transport --- client.go | 6 ++++++ client_wrapper.go | 12 ++++++++++++ http2/priority.go | 6 ++++++ internal/http2/transport.go | 7 +++++++ transport.go | 6 ++++++ 5 files changed, 37 insertions(+) diff --git a/client.go b/client.go index a683008f..143f20c3 100644 --- a/client.go +++ b/client.go @@ -920,6 +920,12 @@ func (c *Client) SetHTTP2HeaderPriority(priority http2.PriorityParam) *Client { return c } +// SetHTTP2PriorityFrames set the ordered http2 priority frames. +func (c *Client) SetHTTP2PriorityFrames(frames ...http2.PriorityFrame) *Client { + c.Transport.SetHTTP2PriorityFrames(frames...) + return c +} + // SetCommonContentType set the `Content-Type` header for requests fired // from the client. func (c *Client) SetCommonContentType(ct string) *Client { diff --git a/client_wrapper.go b/client_wrapper.go index 9d458611..2c1136af 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -434,6 +434,18 @@ func SetHTTP2ConnectionFlow(flow uint32) *Client { return defaultClient.SetHTTP2ConnectionFlow(flow) } +// SetHTTP2HeaderPriority is a global wrapper methods which delegated +// to the default client's Client.SetHTTP2HeaderPriority. +func SetHTTP2HeaderPriority(priority http2.PriorityParam) *Client { + return defaultClient.SetHTTP2HeaderPriority(priority) +} + +// SetHTTP2PriorityFrames is a global wrapper methods which delegated +// to the default client's Client.SetHTTP2PriorityFrames. +func SetHTTP2PriorityFrames(frames ...http2.PriorityFrame) *Client { + return defaultClient.SetHTTP2PriorityFrames(frames...) +} + // SetHTTP2MaxHeaderListSize is a global wrapper methods which delegated // to the default client's Client.SetHTTP2MaxHeaderListSize. func SetHTTP2MaxHeaderListSize(max uint32) *Client { diff --git a/http2/priority.go b/http2/priority.go index 9d28327b..f63846dd 100644 --- a/http2/priority.go +++ b/http2/priority.go @@ -20,3 +20,9 @@ type PriorityParam struct { func (p PriorityParam) IsZero() bool { return p == PriorityParam{} } + +// PriorityFrame represents a http priority frame. +type PriorityFrame struct { + StreamID uint32 + PriorityParam PriorityParam +} diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 7be9df36..4a80a4b9 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -136,6 +136,7 @@ type Transport struct { ConnectionFlow uint32 HeaderPriority http2.PriorityParam + PriorityFrames []http2.PriorityFrame connPoolOnce sync.Once connPoolOrDef ClientConnPool // non-nil version of ConnPool @@ -703,6 +704,12 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro connFlow = transportDefaultConnFlow } cc.fr.WriteWindowUpdate(0, connFlow) + + for _, p := range t.PriorityFrames { + cc.fr.WritePriority(p.StreamID, p.PriorityParam) + cc.nextStreamID = p.StreamID + 2 + } + cc.inflow.init(int32(connFlow) + initialWindowSize) cc.bw.Flush() if cc.werr != nil { diff --git a/transport.go b/transport.go index 75ce96d8..5b7d1e25 100644 --- a/transport.go +++ b/transport.go @@ -473,6 +473,12 @@ func (t *Transport) SetHTTP2HeaderPriority(priority http2.PriorityParam) *Transp return t } +// SetHTTP2PriorityFrames set the ordered http2 priority frames. +func (t *Transport) SetHTTP2PriorityFrames(frames ...http2.PriorityFrame) *Transport { + t.t2.PriorityFrames = frames + return t +} + // SetTLSClientConfig set the custom TLSClientConfig, which specifies the TLS configuration to // use with tls.Client. // If nil, the default configuration is used. From 1913cbbaf54bb8d1a0bf1a7665aeb7f78f5eb5dc Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 4 Aug 2023 15:00:37 +0800 Subject: [PATCH 758/843] add ImpersonateFirefox --- client_impersonate.go | 140 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 129 insertions(+), 11 deletions(-) diff --git a/client_impersonate.go b/client_impersonate.go index e20b4e12..06fb3f47 100644 --- a/client_impersonate.go +++ b/client_impersonate.go @@ -29,6 +29,13 @@ var ( }, } + chromePseudoHeaderOrder = []string{ + ":method", + ":authority", + ":scheme", + ":path", + } + chromeHeaderOrder = []string{ "host", "pragma", @@ -49,13 +56,6 @@ var ( "cookie", } - chromePseudoHeaderOrder = []string{ - ":method", - ":authority", - ":scheme", - ":path", - } - chromeHeaders = map[string]string{ "pragma": "no-cache", "cache-control": "no-cache", @@ -80,12 +80,130 @@ var ( // ImpersonateChrome impersonates Chrome browser (version 109). func (c *Client) ImpersonateChrome() *Client { - c.SetTLSFingerprint(utls.HelloChrome_106_Shuffle). - SetCommonPseudoHeaderOder(chromePseudoHeaderOrder...). - SetCommonHeaders(chromeHeaders). - SetCommonHeaderOrder(chromeHeaderOrder...). + c. + SetTLSFingerprint(utls.HelloChrome_106_Shuffle). // Chrome 106~109 shares the same tls fingerprint. SetHTTP2SettingsFrame(chromeHttp2Settings...). SetHTTP2ConnectionFlow(15663105). + SetCommonPseudoHeaderOder(chromePseudoHeaderOrder...). + SetCommonHeaderOrder(chromeHeaderOrder...). + SetCommonHeaders(chromeHeaders). SetHTTP2HeaderPriority(chromeHeaderPriority) return c } + +var ( + firefoxHttp2Settings = []http2.Setting{ + { + ID: http2.SettingHeaderTableSize, + Val: 65536, + }, + { + ID: http2.SettingInitialWindowSize, + Val: 131072, + }, + { + ID: http2.SettingMaxFrameSize, + Val: 16384, + }, + } + firefoxPriorityFrames = []http2.PriorityFrame{ + { + StreamID: 3, + PriorityParam: http2.PriorityParam{ + StreamDep: 0, + Exclusive: false, + Weight: 200, + }, + }, + { + StreamID: 5, + PriorityParam: http2.PriorityParam{ + StreamDep: 0, + Exclusive: false, + Weight: 100, + }, + }, + { + StreamID: 7, + PriorityParam: http2.PriorityParam{ + StreamDep: 0, + Exclusive: false, + Weight: 0, + }, + }, + { + StreamID: 9, + PriorityParam: http2.PriorityParam{ + StreamDep: 7, + Exclusive: false, + Weight: 0, + }, + }, + { + StreamID: 11, + PriorityParam: http2.PriorityParam{ + StreamDep: 3, + Exclusive: false, + Weight: 0, + }, + }, + { + StreamID: 13, + PriorityParam: http2.PriorityParam{ + StreamDep: 0, + Exclusive: false, + Weight: 240, + }, + }, + } + firefoxPseudoHeaderOrder = []string{ + ":method", + ":path", + ":authority", + ":scheme", + } + firefoxHeaderOrder = []string{ + "user-agent", + "accept", + "accept-language", + "accept-encoding", + "referer", + "cookie", + "upgrade-insecure-requests", + "sec-fetch-dest", + "sec-fetch-mode", + "sec-fetch-site", + "sec-fetch-user", + "te", + } + firefoxHeaders = map[string]string{ + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:105.0) Gecko/20100101 Firefox/105.0", + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + "accept-language": "zh-CN,zh;q=0.8,zh-TW;q=0.7,zh-HK;q=0.5,en-US;q=0.3,en;q=0.2", + "upgrade-insecure-requests": "1", + "sec-fetch-dest": "document", + "sec-fetch-mode": "navigate", + "sec-fetch-site": "same-origin", + "sec-fetch-user": "?1", + //"te": "trailers", + } + firefoxHeaderPriority = http2.PriorityParam{ + StreamDep: 13, + Exclusive: false, + Weight: 41, + } +) + +// ImpersonateFirefox impersonates Firefox browser (version 105). +func (c *Client) ImpersonateFirefox() *Client { + c. + SetTLSFingerprint(utls.HelloFirefox_105). + SetHTTP2SettingsFrame(firefoxHttp2Settings...). + SetHTTP2ConnectionFlow(12517377). + SetHTTP2PriorityFrames(firefoxPriorityFrames...). + SetCommonPseudoHeaderOder(firefoxPseudoHeaderOrder...). + SetCommonHeaderOrder(firefoxHeaderOrder...). + SetCommonHeaders(firefoxHeaders). + SetHTTP2HeaderPriority(firefoxHeaderPriority) + return c +} From 6bfbd552a4b915b4662012348fca92c061afd1e5 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 4 Aug 2023 15:45:49 +0800 Subject: [PATCH 759/843] execute retry condition in reverse order --- request.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/request.go b/request.go index 596a609e..3c7ca623 100644 --- a/request.go +++ b/request.go @@ -662,11 +662,13 @@ func (r *Request) do() (resp *Response, err error) { } // check retry whether is needed. - needRetry := err != nil // default behaviour: retry if error occurs - for _, condition := range r.retryOption.RetryConditions { // override default behaviour if custom RetryConditions has been set. - needRetry = condition(resp, err) - if needRetry { - break + needRetry := err != nil // default behaviour: retry if error occurs + if l := len(r.retryOption.RetryConditions); l > 0 { // override default behaviour if custom RetryConditions has been set. + for i := l - 1; i >= 0; i-- { + needRetry = r.retryOption.RetryConditions[i](resp, err) + if needRetry { + break + } } } if !needRetry { // no retry is needed. From 9059ed0692897ee90ea9d4cdfc92724c1527c3d3 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 4 Aug 2023 15:48:37 +0800 Subject: [PATCH 760/843] execute retry hook in reverse order --- request.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/request.go b/request.go index 3c7ca623..4a40df75 100644 --- a/request.go +++ b/request.go @@ -677,8 +677,10 @@ func (r *Request) do() (resp *Response, err error) { // need retry, attempt to retry r.RetryAttempt++ - for _, hook := range r.retryOption.RetryHooks { // run retry hooks - hook(resp, err) + if l := len(r.retryOption.RetryHooks); l > 0 { + for i := l - 1; i >= 0; i-- { // run retry hooks in reverse order + r.retryOption.RetryHooks[i](resp, err) + } } time.Sleep(r.retryOption.GetRetryInterval(resp, r.RetryAttempt)) From 490ffb527653ca6ddcd545941d4ef8aff03efff4 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 4 Aug 2023 22:07:49 +0800 Subject: [PATCH 761/843] improve Client.Clone --- req.go | 23 +++++------------------ transport.go | 9 ++++++--- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/req.go b/req.go index 81fadac8..05fc2112 100644 --- a/req.go +++ b/req.go @@ -84,26 +84,13 @@ type DownloadInfo struct { // response body download. type DownloadCallback func(info DownloadInfo) -func cloneCookies(cookies []*http.Cookie) []*http.Cookie { - if len(cookies) == 0 { +func cloneSlice[T any](s []T) []T { + if len(s) == 0 { return nil } - c := make([]*http.Cookie, len(cookies)) - copy(c, cookies) - return c -} - -func cloneHeaders(hdrs http.Header) http.Header { - if hdrs == nil { - return nil - } - h := make(http.Header) - for k, vs := range hdrs { - for _, v := range vs { - h.Add(k, v) - } - } - return h + ss := make([]T, len(s)) + copy(ss, s) + return ss } // TODO: change to generics function when generics are commonly used. diff --git a/transport.go b/transport.go index 5b7d1e25..2511b056 100644 --- a/transport.go +++ b/transport.go @@ -798,8 +798,8 @@ func (t *Transport) readBufferSize() int { // Clone returns a deep copy of t's exported fields. func (t *Transport) Clone() *Transport { tt := &Transport{ - Headers: cloneHeaders(t.Headers), - Cookies: cloneCookies(t.Cookies), + Headers: t.Headers.Clone(), + Cookies: cloneSlice(t.Cookies), Options: t.Options.Clone(), disableAutoDecode: t.disableAutoDecode, autoDecodeContentType: t.autoDecodeContentType, @@ -818,12 +818,15 @@ func (t *Transport) Clone() *Transport { if t.t2 != nil { tt.t2 = &h2internal.Transport{ Options: &tt.Options, - Settings: t.t2.Settings, MaxHeaderListSize: t.t2.MaxHeaderListSize, StrictMaxConcurrentStreams: t.t2.StrictMaxConcurrentStreams, ReadIdleTimeout: t.t2.ReadIdleTimeout, PingTimeout: t.t2.PingTimeout, WriteByteTimeout: t.t2.WriteByteTimeout, + ConnectionFlow: t.t2.ConnectionFlow, + Settings: cloneSlice(t.t2.Settings), + HeaderPriority: t.t2.HeaderPriority, + PriorityFrames: cloneSlice(t.t2.PriorityFrames), } } if t.t3 != nil { From ec7dc7a57301afc86f5840a61e7a4d3ee6db7d30 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Aug 2023 09:56:49 +0800 Subject: [PATCH 762/843] refactor: replace cloneXXX with generics function cloneSlice --- client.go | 6 +++--- req.go | 19 ------------------- 2 files changed, 3 insertions(+), 22 deletions(-) diff --git a/client.go b/client.go index 143f20c3..6a36e44a 100644 --- a/client.go +++ b/client.go @@ -1443,9 +1443,9 @@ func (c *Client) Clone() *Client { cc.PathParams = cloneMap(c.PathParams) cc.QueryParams = cloneUrlValues(c.QueryParams) cc.FormData = cloneUrlValues(c.FormData) - cc.beforeRequest = cloneRequestMiddleware(c.beforeRequest) - cc.udBeforeRequest = cloneRequestMiddleware(c.udBeforeRequest) - cc.afterResponse = cloneResponseMiddleware(c.afterResponse) + cc.beforeRequest = cloneSlice(c.beforeRequest) + cc.udBeforeRequest = cloneSlice(c.udBeforeRequest) + cc.afterResponse = cloneSlice(c.afterResponse) cc.dumpOptions = c.dumpOptions.Clone() cc.retryOption = c.retryOption.Clone() return &cc diff --git a/req.go b/req.go index 05fc2112..84a052d7 100644 --- a/req.go +++ b/req.go @@ -93,25 +93,6 @@ func cloneSlice[T any](s []T) []T { return ss } -// TODO: change to generics function when generics are commonly used. -func cloneRequestMiddleware(m []RequestMiddleware) []RequestMiddleware { - if len(m) == 0 { - return nil - } - mm := make([]RequestMiddleware, len(m)) - copy(mm, m) - return mm -} - -func cloneResponseMiddleware(m []ResponseMiddleware) []ResponseMiddleware { - if len(m) == 0 { - return nil - } - mm := make([]ResponseMiddleware, len(m)) - copy(mm, m) - return mm -} - func cloneUrlValues(v url.Values) url.Values { if v == nil { return nil From e6a9abfbb489cff5a9b9e40c4a853aba3a7a7f6d Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Aug 2023 11:00:31 +0800 Subject: [PATCH 763/843] Add ImpersonateSafari --- client_impersonate.go | 60 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/client_impersonate.go b/client_impersonate.go index 06fb3f47..a4fd76fe 100644 --- a/client_impersonate.go +++ b/client_impersonate.go @@ -207,3 +207,63 @@ func (c *Client) ImpersonateFirefox() *Client { SetHTTP2HeaderPriority(firefoxHeaderPriority) return c } + +var ( + safariHttp2Settings = []http2.Setting{ + { + ID: http2.SettingInitialWindowSize, + Val: 4194304, + }, + { + ID: http2.SettingMaxConcurrentStreams, + Val: 100, + }, + } + + safariPseudoHeaderOrder = []string{ + ":method", + ":scheme", + ":path", + ":authority", + } + + safariHeaderOrder = []string{ + "accept", + "sec-fetch-site", + "cookie", + "sec-fetch-dest", + "accept-language", + "sec-fetch-mode", + "user-agent", + "referer", + "accept-encoding", + } + + safariHeaders = map[string]string{ + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", + "sec-fetch-site": "same-origin", + "sec-fetch-dest": "document", + "accept-language": "zh-CN,zh-Hans;q=0.9", + "sec-fetch-mode": "navigate", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/16.6 Safari/605.1.15", + } + + safariHeaderPriority = http2.PriorityParam{ + StreamDep: 0, + Exclusive: false, + Weight: 254, + } +) + +// ImpersonateSafari impersonates Safari browser (version 16). +func (c *Client) ImpersonateSafari() *Client { + c. + SetTLSFingerprint(utls.HelloSafari_16_0). + SetHTTP2SettingsFrame(safariHttp2Settings...). + SetHTTP2ConnectionFlow(10485760). + SetCommonPseudoHeaderOder(safariPseudoHeaderOrder...). + SetCommonHeaderOrder(safariHeaderOrder...). + SetCommonHeaders(safariHeaders). + SetHTTP2HeaderPriority(safariHeaderPriority) + return c +} From bee1599f0f3ee1a3901d175a1cf40e98444784f2 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Aug 2023 11:43:27 +0800 Subject: [PATCH 764/843] unexpose unnecessary methods of Transport --- client.go | 20 ++++++++++++++++++-- transport.go | 43 ------------------------------------------- 2 files changed, 18 insertions(+), 45 deletions(-) diff --git a/client.go b/client.go index 6a36e44a..4bf54009 100644 --- a/client.go +++ b/client.go @@ -881,7 +881,15 @@ func (c *Client) SetCommonHeadersNonCanonical(hdrs map[string]string) *Client { // "accept-encoding", // ).Get(url func (c *Client) SetCommonHeaderOrder(keys ...string) *Client { - c.SetHeaderOrder(keys...) + c.Transport.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { + return func(req *http.Request) (resp *http.Response, err error) { + if req.Header == nil { + req.Header = make(http.Header) + } + req.Header[HeaderOderKey] = keys + return rt.RoundTrip(req) + } + }) return c } @@ -897,7 +905,15 @@ func (c *Client) SetCommonHeaderOrder(keys ...string) *Client { // ":method", // ) func (c *Client) SetCommonPseudoHeaderOder(keys ...string) *Client { - c.SetPseudoHeaderOder(keys...) + c.Transport.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { + return func(req *http.Request) (resp *http.Response, err error) { + if req.Header == nil { + req.Header = make(http.Header) + } + req.Header[PseudoHeaderOderKey] = keys + return rt.RoundTrip(req) + } + }) return c } diff --git a/transport.go b/transport.go index 2511b056..80106d8e 100644 --- a/transport.go +++ b/transport.go @@ -214,49 +214,6 @@ func (t *Transport) WrapRoundTrip(wrappers ...HttpRoundTripWrapper) *Transport { return t } -// SetHeaderOrder set the order of the http header (case-insensitive). -// For example: -// -// t.SetHeaderOrder( -// "custom-header", -// "cookie", -// "user-agent", -// "accept-encoding", -// ) -func (t *Transport) SetHeaderOrder(keys ...string) { - t.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { - return func(req *http.Request) (resp *http.Response, err error) { - if req.Header == nil { - req.Header = make(http.Header) - } - req.Header[HeaderOderKey] = keys - return rt.RoundTrip(req) - } - }) -} - -// SetPseudoHeaderOder set the order of the pseudo http header (case-insensitive). -// Note this is only valid for http2 and http3. -// For example: -// -// t.SetPseudoHeaderOrder( -// ":scheme", -// ":authority", -// ":path", -// ":method", -// ) -func (t *Transport) SetPseudoHeaderOder(keys ...string) { - t.WrapRoundTripFunc(func(rt http.RoundTripper) HttpRoundTripFunc { - return func(req *http.Request) (resp *http.Response, err error) { - if req.Header == nil { - req.Header = make(http.Header) - } - req.Header[PseudoHeaderOderKey] = keys - return rt.RoundTrip(req) - } - }) -} - // DisableAutoDecode disable auto-detect charset and decode to utf-8 // (enabled by default). func (t *Transport) DisableAutoDecode() *Transport { From f76a2f9cdca7e959dffb95ea76142ac327a3634d Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Aug 2023 16:14:53 +0800 Subject: [PATCH 765/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 29157a28..689a4649 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Full documentation is available on the official website: https://req.cool. * **Smart by Default**: Detect and decode to utf-8 automatically if possible to avoid garbled characters (See [Auto Decode](https://req.cool/docs/tutorial/auto-decode/)), marshal request body and unmarshal response body automatically according to the Content-Type. * **Support Multiple HTTP Versions**: Support `HTTP/1.1`, `HTTP/2`, and `HTTP/3`, and can automatically detect the server side and select the optimal HTTP version for requests, you can also force the protocol if you want (See [Force HTTP version](https://req.cool/docs/tutorial/force-http-version/)). * **Support Retry**: Support automatic request retry and is fully customizable (See [Retry](https://req.cool/docs/tutorial/retry/)). -* **TLS Fingerprinting**: Support tls fingerprinting resistance, so that we can access websites that prohibit crawler programs by identifying TLS handshake fingerprints (See [TLS Fingerprinting](https://req.cool/docs/tutorial/tls-fingerprinting/)). +* **HTTP Fingerprinting**: Support http fingerprint impersonation, so that we can access websites that prohibit crawler programs by identifying http fingerprints (See [TLS Fingerprinting](https://req.cool/docs/tutorial/http-fingerprint/)). * **Multiple Authentication Methods**: You can use HTTP Basic Auth, Bearer Auth Token and Digest Auth out of box (see [Authentication](https://req.cool/docs/tutorial/authentication/)). * **Easy Download and Upload**: You can download and upload files with simple request settings, and even set a callback to show real-time progress (See [Download](https://req.cool/docs/tutorial/download/) and [Upload](https://req.cool/docs/tutorial/upload/)). * **Exportable**: `req.Transport` is exportable. Compared with `http.Transport`, it also supports HTTP3, dump content, middleware, etc. It can directly replace the Transport of `http.Client` in existing projects, and obtain more powerful functions with minimal code change. From e012355bd91b156a66043e4b2d2fe4866cb14169 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 5 Aug 2023 16:15:31 +0800 Subject: [PATCH 766/843] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 689a4649..822a6651 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Full documentation is available on the official website: https://req.cool. * **Smart by Default**: Detect and decode to utf-8 automatically if possible to avoid garbled characters (See [Auto Decode](https://req.cool/docs/tutorial/auto-decode/)), marshal request body and unmarshal response body automatically according to the Content-Type. * **Support Multiple HTTP Versions**: Support `HTTP/1.1`, `HTTP/2`, and `HTTP/3`, and can automatically detect the server side and select the optimal HTTP version for requests, you can also force the protocol if you want (See [Force HTTP version](https://req.cool/docs/tutorial/force-http-version/)). * **Support Retry**: Support automatic request retry and is fully customizable (See [Retry](https://req.cool/docs/tutorial/retry/)). -* **HTTP Fingerprinting**: Support http fingerprint impersonation, so that we can access websites that prohibit crawler programs by identifying http fingerprints (See [TLS Fingerprinting](https://req.cool/docs/tutorial/http-fingerprint/)). +* **HTTP Fingerprinting**: Support http fingerprint impersonation, so that we can access websites that prohibit crawler programs by identifying http fingerprints (See [HTTP Fingerprint](https://req.cool/docs/tutorial/http-fingerprint/)). * **Multiple Authentication Methods**: You can use HTTP Basic Auth, Bearer Auth Token and Digest Auth out of box (see [Authentication](https://req.cool/docs/tutorial/authentication/)). * **Easy Download and Upload**: You can download and upload files with simple request settings, and even set a callback to show real-time progress (See [Download](https://req.cool/docs/tutorial/download/) and [Upload](https://req.cool/docs/tutorial/upload/)). * **Exportable**: `req.Transport` is exportable. Compared with `http.Transport`, it also supports HTTP3, dump content, middleware, etc. It can directly replace the Transport of `http.Client` in existing projects, and obtain more powerful functions with minimal code change. From 2ecb93d8801e49b5cf59cf5d6fb8463984456919 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 8 Aug 2023 15:30:22 +0800 Subject: [PATCH 767/843] Improve cookie jar. * Add SetCookeJarFactory. * Use memoryCookieJarFactory to create cookie jar by default when create Client. * Add some comments. --- client.go | 46 ++++++++++++++++++++++++++++++++++++++-------- client_test.go | 16 ++++++++++++++++ 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index 4bf54009..c140d1b5 100644 --- a/client.go +++ b/client.go @@ -50,6 +50,7 @@ type Client struct { AllowGetMethodPayload bool *Transport + cookiejarFactory func() *cookiejar.Jar trace bool disableAutoReadResponse bool commonErrorType reflect.Type @@ -1022,9 +1023,13 @@ func (c *Client) EnableTraceAll() *Client { return c } -// SetCookieJar set the `CookeJar` to the underlying `http.Client`, set to nil if you -// want to disable cookie. +// SetCookieJar set the cookie jar to the underlying `http.Client`, set to nil if you +// want to disable cookies. +// Note: If you use Client.Clone to clone a new Client, the new client will share the same +// cookie jar as the old Client after cloning. Use SetCookieJarFactory instead if you want +// to create a new CookieJar automatically when cloning a client. func (c *Client) SetCookieJar(jar http.CookieJar) *Client { + c.cookiejarFactory = nil c.httpClient.Jar = jar return c } @@ -1042,11 +1047,10 @@ func (c *Client) GetCookies(url string) ([]*http.Cookie, error) { } // ClearCookies clears all cookies if cookie is enabled. +// Note: It has no effect if you called SetCookieJar instead of +// SetCookieJarFactory. func (c *Client) ClearCookies() *Client { - if c.httpClient.Jar != nil { - jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) - c.httpClient.Jar = jar - } + c.initCookieJar() return c } @@ -1446,6 +1450,7 @@ func (c *Client) Clone() *Client { client := *c.httpClient client.Transport = cc.Transport cc.httpClient = &client + cc.initCookieJar() // clone client middleware if len(cc.roundTripWrappers) > 0 { @@ -1467,14 +1472,17 @@ func (c *Client) Clone() *Client { return &cc } +func memoryCookieJarFactory() *cookiejar.Jar { + jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + return jar +} + // C create a new client. func C() *Client { t := T() - jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) httpClient := &http.Client{ Transport: t, - Jar: jar, Timeout: 2 * time.Minute, } beforeRequest := []RequestMiddleware{ @@ -1496,13 +1504,35 @@ func C() *Client { jsonUnmarshal: json.Unmarshal, xmlMarshal: xml.Marshal, xmlUnmarshal: xml.Unmarshal, + cookiejarFactory: memoryCookieJarFactory, } httpClient.CheckRedirect = c.defaultCheckRedirect + c.initCookieJar() c.initTransport() return c } +// SetCookieJarFactory set the functional factory of cookie jar, which creates +// cookie jar that store cookies for underlying `http.Client`. After client clone, +// the cookie jar of the new client will also be regenerated using this factory +// function. +func (c *Client) SetCookieJarFactory(factory func() *cookiejar.Jar) *Client { + c.cookiejarFactory = factory + c.initCookieJar() + return c +} + +func (c *Client) initCookieJar() { + if c.cookiejarFactory == nil { + return + } + jar := c.cookiejarFactory() + if jar != nil { + c.httpClient.Jar = jar + } +} + func (c *Client) initTransport() { c.Debugf = func(format string, v ...interface{}) { if c.DebugLog { diff --git a/client_test.go b/client_test.go index 2204f8c2..98df9f55 100644 --- a/client_test.go +++ b/client_test.go @@ -7,9 +7,11 @@ import ( "errors" "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/tests" + "golang.org/x/net/publicsuffix" "io" "net" "net/http" + "net/http/cookiejar" "net/url" "os" "strings" @@ -632,3 +634,17 @@ func TestSetResultStateCheckFunc(t *testing.T) { tests.AssertNoError(t, err) tests.AssertEqual(t, ErrorState, resp.ResultState()) } +func TestCloneCookieJar(t *testing.T) { + c1 := C() + c2 := c1.Clone() + tests.AssertEqual(t, true, c1.httpClient.Jar != c2.httpClient.Jar) + + jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + c1.SetCookieJar(jar) + c2 = c1.Clone() + tests.AssertEqual(t, true, c1.httpClient.Jar == c2.httpClient.Jar) + + c2.SetCookieJar(nil) + tests.AssertEqual(t, true, c2.cookiejarFactory == nil) + tests.AssertEqual(t, true, c2.httpClient.Jar == nil) +} From b1f61a8862d7717227062ccac3fca083c445dc43 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Aug 2023 09:42:53 +0800 Subject: [PATCH 768/843] Update dependencies and support go1.21 --- .github/workflows/ci.yml | 2 +- examples/find-popular-repo/go.mod | 10 +++++----- examples/find-popular-repo/go.sum | 5 +++++ examples/opentelemetry-jaeger-tracing/go.mod | 10 +++++----- examples/opentelemetry-jaeger-tracing/go.sum | 5 +++++ examples/upload/uploadclient/go.mod | 10 +++++----- examples/upload/uploadclient/go.sum | 5 +++++ examples/uploadcallback/uploadclient/go.mod | 10 +++++----- examples/uploadcallback/uploadclient/go.sum | 5 +++++ go.mod | 20 ++++++++++---------- go.sum | 20 ++++++++++++++++++++ 11 files changed, 71 insertions(+), 31 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 43e855c7..9d236dc9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: test: strategy: matrix: - go: [ '1.20.x' ] + go: [ '1.20.x', '1.21.x' ] os: [ ubuntu-latest ] runs-on: ${{ matrix.os }} steps: diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index be0ddc0c..6e895b2b 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -19,11 +19,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.1 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.11.0 // indirect + golang.org/x/crypto v0.12.0 // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.12.0 // indirect - golang.org/x/sys v0.10.0 // indirect - golang.org/x/text v0.11.0 // indirect - golang.org/x/tools v0.11.0 // indirect + golang.org/x/net v0.14.0 // indirect + golang.org/x/sys v0.11.0 // indirect + golang.org/x/text v0.12.0 // indirect + golang.org/x/tools v0.12.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index 39bbba8a..a96224fb 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -169,6 +169,7 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -199,6 +200,7 @@ golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -236,6 +238,7 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -245,6 +248,7 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -260,6 +264,7 @@ golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4 golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= +golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/examples/opentelemetry-jaeger-tracing/go.mod b/examples/opentelemetry-jaeger-tracing/go.mod index bc0c021e..1fc84c1d 100644 --- a/examples/opentelemetry-jaeger-tracing/go.mod +++ b/examples/opentelemetry-jaeger-tracing/go.mod @@ -28,11 +28,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.11.0 // indirect + golang.org/x/crypto v0.12.0 // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.12.0 // indirect - golang.org/x/sys v0.10.0 // indirect - golang.org/x/text v0.11.0 // indirect - golang.org/x/tools v0.11.0 // indirect + golang.org/x/net v0.14.0 // indirect + golang.org/x/sys v0.11.0 // indirect + golang.org/x/text v0.12.0 // indirect + golang.org/x/tools v0.12.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/opentelemetry-jaeger-tracing/go.sum b/examples/opentelemetry-jaeger-tracing/go.sum index 6c63d5db..b81a8a94 100644 --- a/examples/opentelemetry-jaeger-tracing/go.sum +++ b/examples/opentelemetry-jaeger-tracing/go.sum @@ -180,6 +180,7 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -213,6 +214,7 @@ golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfS golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -254,6 +256,7 @@ golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -264,6 +267,7 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -278,6 +282,7 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= +golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod index eb2cabe0..ca2d105c 100644 --- a/examples/upload/uploadclient/go.mod +++ b/examples/upload/uploadclient/go.mod @@ -20,11 +20,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.11.0 // indirect + golang.org/x/crypto v0.12.0 // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.12.0 // indirect - golang.org/x/sys v0.10.0 // indirect - golang.org/x/text v0.11.0 // indirect - golang.org/x/tools v0.11.0 // indirect + golang.org/x/net v0.14.0 // indirect + golang.org/x/sys v0.11.0 // indirect + golang.org/x/text v0.12.0 // indirect + golang.org/x/tools v0.12.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/upload/uploadclient/go.sum b/examples/upload/uploadclient/go.sum index c694e19d..24e9d3b8 100644 --- a/examples/upload/uploadclient/go.sum +++ b/examples/upload/uploadclient/go.sum @@ -168,6 +168,7 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -203,6 +204,7 @@ golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfS golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -243,6 +245,7 @@ golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -253,6 +256,7 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -267,6 +271,7 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= +golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/examples/uploadcallback/uploadclient/go.mod b/examples/uploadcallback/uploadclient/go.mod index eb2cabe0..ca2d105c 100644 --- a/examples/uploadcallback/uploadclient/go.mod +++ b/examples/uploadcallback/uploadclient/go.mod @@ -20,11 +20,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.11.0 // indirect + golang.org/x/crypto v0.12.0 // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.12.0 // indirect - golang.org/x/sys v0.10.0 // indirect - golang.org/x/text v0.11.0 // indirect - golang.org/x/tools v0.11.0 // indirect + golang.org/x/net v0.14.0 // indirect + golang.org/x/sys v0.11.0 // indirect + golang.org/x/text v0.12.0 // indirect + golang.org/x/tools v0.12.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/uploadcallback/uploadclient/go.sum b/examples/uploadcallback/uploadclient/go.sum index c694e19d..24e9d3b8 100644 --- a/examples/uploadcallback/uploadclient/go.sum +++ b/examples/uploadcallback/uploadclient/go.sum @@ -168,6 +168,7 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -203,6 +204,7 @@ golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b/go.mod h1:YDH+HFinaLZZlnHAfS golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNGnz3kepwuXqFKYDdDMs= golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -243,6 +245,7 @@ golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -253,6 +256,7 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -267,6 +271,7 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= +golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/go.mod b/go.mod index 9471aad8..a4847adc 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,10 @@ go 1.20 require ( github.com/hashicorp/go-multierror v1.1.1 github.com/quic-go/qpack v0.4.0 - github.com/quic-go/quic-go v0.37.0 - github.com/refraction-networking/utls v1.3.3 - golang.org/x/net v0.12.0 - golang.org/x/text v0.11.0 + github.com/quic-go/quic-go v0.37.4 + github.com/refraction-networking/utls v1.4.1 + golang.org/x/net v0.14.0 + golang.org/x/text v0.12.0 ) require ( @@ -16,14 +16,14 @@ require ( github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/mock v1.6.0 // indirect - github.com/google/pprof v0.0.0-20230705174524-200ffdc848b8 // indirect + github.com/google/pprof v0.0.0-20230808223545-4887780b67fb // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/klauspost/compress v1.16.7 // indirect github.com/onsi/ginkgo/v2 v2.11.0 // indirect - github.com/quic-go/qtls-go1-20 v0.3.0 // indirect - golang.org/x/crypto v0.11.0 // indirect - golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 // indirect + github.com/quic-go/qtls-go1-20 v0.3.2 // indirect + golang.org/x/crypto v0.12.0 // indirect + golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819 // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/sys v0.10.0 // indirect - golang.org/x/tools v0.11.0 // indirect + golang.org/x/sys v0.11.0 // indirect + golang.org/x/tools v0.12.0 // indirect ) diff --git a/go.sum b/go.sum index 09460786..ca4152b4 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,8 @@ github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9S github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/google/pprof v0.0.0-20230705174524-200ffdc848b8 h1:n6vlPhxsA+BW/XsS5+uqi7GyzaLa5MH7qlSLBZtRdiA= github.com/google/pprof v0.0.0-20230705174524-200ffdc848b8/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= +github.com/google/pprof v0.0.0-20230808223545-4887780b67fb h1:oqpb3Cwpc7EOml5PVGMYbSGmwNui2R7i8IW83gs4W0c= +github.com/google/pprof v0.0.0-20230808223545-4887780b67fb/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -48,16 +50,22 @@ github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8G github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= github.com/quic-go/qtls-go1-20 v0.3.0 h1:NrCXmDl8BddZwO67vlvEpBTwT89bJfKYygxv4HQvuDk= github.com/quic-go/qtls-go1-20 v0.3.0/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= +github.com/quic-go/qtls-go1-20 v0.3.2 h1:rRgN3WfnKbyik4dBV8A6girlJVxGand/d+jVKbQq5GI= +github.com/quic-go/qtls-go1-20 v0.3.2/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/quic-go/quic-go v0.35.0 h1:JXIf219xJK+4qGeY52rlnrVqeB2AXUAwfLU9JSoWXwg= github.com/quic-go/quic-go v0.35.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo= github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/quic-go v0.37.0 h1:wf/Ym2yeWi98oQn4ahiBSqdnaXVxNQGj2oBQFgiVChc= github.com/quic-go/quic-go v0.37.0/go.mod h1:XtCUOCALTTWbPyd0IxFfHf6h0sEMubRFvEYHl3QxKw8= +github.com/quic-go/quic-go v0.37.4 h1:ke8B73yMCWGq9MfrCCAw0Uzdm7GaViC3i39dsIdDlH4= +github.com/quic-go/quic-go v0.37.4/go.mod h1:YsbH1r4mSHPJcLF4k4zruUkLBqctEMBDR6VPvcYjIsU= github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8= github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E= github.com/refraction-networking/utls v1.3.3 h1:f/TBLX7KBciRyFH3bwupp+CE4fzoYKCirhdRcC490sw= github.com/refraction-networking/utls v1.3.3/go.mod h1:DlecWW1LMlMJu+9qpzzQqdHDT/C2LAe03EdpLUz/RL8= +github.com/refraction-networking/utls v1.4.1 h1:5VXwhNzrnWrvbJW8IVpptJKrErZGqoRbn7wqu2jqMrU= +github.com/refraction-networking/utls v1.4.1/go.mod h1:JkUIj+Pc8eyFB0z+A4RJRZmoT43ajjFZWVMXuZQ8BEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -70,10 +78,14 @@ golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819 h1:EDuYyU/MkFXllv9QF9819VlI9a4tzGuCbhG0ExK9o1U= +golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -90,6 +102,8 @@ golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -103,6 +117,8 @@ golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -112,6 +128,8 @@ golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -121,6 +139,8 @@ golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM= golang.org/x/tools v0.11.0 h1:EMCa6U9S2LtZXLAMoWiR/R8dAQFRqbAitmbJ2UKhoi8= golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= +golang.org/x/tools v0.12.0 h1:YW6HUoUmYBpwSgyaGaZq1fHjrBjX1rlpZ54T6mu2kss= +golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From 924664aa251d294140df8b4fcd848e0e02a7a024 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Aug 2023 14:13:35 +0800 Subject: [PATCH 769/843] update dependencies to fix #263 --- go.mod | 4 ++-- go.sum | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index a4847adc..cc5eb39d 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/quic-go/qpack v0.4.0 github.com/quic-go/quic-go v0.37.4 - github.com/refraction-networking/utls v1.4.1 + github.com/refraction-networking/utls v1.4.2 golang.org/x/net v0.14.0 golang.org/x/text v0.12.0 ) @@ -22,7 +22,7 @@ require ( github.com/onsi/ginkgo/v2 v2.11.0 // indirect github.com/quic-go/qtls-go1-20 v0.3.2 // indirect golang.org/x/crypto v0.12.0 // indirect - golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819 // indirect + golang.org/x/exp v0.0.0-20230810033253-352e893a4cad // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/sys v0.11.0 // indirect golang.org/x/tools v0.12.0 // indirect diff --git a/go.sum b/go.sum index ca4152b4..e503098e 100644 --- a/go.sum +++ b/go.sum @@ -66,6 +66,8 @@ github.com/refraction-networking/utls v1.3.3 h1:f/TBLX7KBciRyFH3bwupp+CE4fzoYKCi github.com/refraction-networking/utls v1.3.3/go.mod h1:DlecWW1LMlMJu+9qpzzQqdHDT/C2LAe03EdpLUz/RL8= github.com/refraction-networking/utls v1.4.1 h1:5VXwhNzrnWrvbJW8IVpptJKrErZGqoRbn7wqu2jqMrU= github.com/refraction-networking/utls v1.4.1/go.mod h1:JkUIj+Pc8eyFB0z+A4RJRZmoT43ajjFZWVMXuZQ8BEQ= +github.com/refraction-networking/utls v1.4.2 h1:7N+928mSM1pEyAJb8x2Y1FbEwTIftGwn2IFykosSzwc= +github.com/refraction-networking/utls v1.4.2/go.mod h1:JkUIj+Pc8eyFB0z+A4RJRZmoT43ajjFZWVMXuZQ8BEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -86,6 +88,8 @@ golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiAp golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819 h1:EDuYyU/MkFXllv9QF9819VlI9a4tzGuCbhG0ExK9o1U= golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/exp v0.0.0-20230810033253-352e893a4cad h1:g0bG7Z4uG+OgH2QDODnjp6ggkk1bJDsINcuWmJN1iJU= +golang.org/x/exp v0.0.0-20230810033253-352e893a4cad/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= From ad9fd1e86aad47e1a13bc704e2fd1ed1ca6104a8 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 11 Aug 2023 23:27:34 +0800 Subject: [PATCH 770/843] fix SetCommonContentType is not respected when SetBody is called (#265) --- middleware.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/middleware.go b/middleware.go index 40a9993b..99268271 100644 --- a/middleware.go +++ b/middleware.go @@ -228,9 +228,14 @@ func parseRequestBody(c *Client, r *Request) (err error) { return } // body is in-memory []byte, so we can guess content type - if r.getHeader(header.ContentType) == "" { - r.SetContentType(http.DetectContentType(r.Body)) + + if c.Headers != nil && c.Headers.Get(header.ContentType) != "" { // ignore if content type set at client-level + return + } + if r.getHeader(header.ContentType) != "" { // ignore if content-type set at request-level + return } + r.SetContentType(http.DetectContentType(r.Body)) return } From abc4962114fb7f3609c2af3aeffb7c3ffed5e98a Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 12 Aug 2023 13:47:00 +0800 Subject: [PATCH 771/843] fix transport middleware cannot access common header and cookies --- client.go | 2 ++ middleware.go | 23 +++++++++++++++++++++ transport.go | 57 +++++++++++++++++++-------------------------------- 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/client.go b/client.go index c140d1b5..4ec1cdc8 100644 --- a/client.go +++ b/client.go @@ -1486,6 +1486,8 @@ func C() *Client { Timeout: 2 * time.Minute, } beforeRequest := []RequestMiddleware{ + parseRequestHeader, + parseRequestCookie, parseRequestURL, parseRequestBody, } diff --git a/middleware.go b/middleware.go index 99268271..c554e89d 100644 --- a/middleware.go +++ b/middleware.go @@ -472,3 +472,26 @@ func parseRequestURL(c *Client, r *Request) error { r.URL = reqURL return nil } + +func parseRequestHeader(c *Client, r *Request) error { + if c.Headers == nil { + return nil + } + if r.Headers == nil { + r.Headers = make(http.Header) + } + for k, vs := range c.Headers { + if len(r.Headers[k]) == 0 { + r.Headers[k] = vs + } + } + return nil +} + +func parseRequestCookie(c *Client, r *Request) error { + if len(c.Cookies) == 0 { + return nil + } + r.Cookies = append(r.Cookies, c.Cookies...) + return nil +} diff --git a/transport.go b/transport.go index 80106d8e..7ff158c0 100644 --- a/transport.go +++ b/transport.go @@ -854,39 +854,6 @@ func (t *Transport) roundTripAltSvc(req *http.Request, as *altsvc.AltSvc) (resp return } -func (t *Transport) ensureHeaderAndCookie(req *http.Request, isHTTP bool) error { - if req.Header == nil { - closeBody(req) - return errors.New("http: nil Request.Header") - } - for k, vs := range t.Headers { - if len(req.Header[k]) == 0 { - req.Header[k] = vs - } - } - for _, c := range t.Cookies { - req.AddCookie(c) - } - if !isHTTP { - return nil - } - // TODO: is h2c should also check this? - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - closeBody(req) - return fmt.Errorf("net/http: invalid header field name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - closeBody(req) - // Don't include the value in the error, because it may be sensitive. - return fmt.Errorf("net/http: invalid header field value for %q", k) - } - } - } - return nil -} - func (t *Transport) checkAltSvc(req *http.Request) (resp *http.Response, err error) { if t.altSvcJar == nil { return @@ -938,9 +905,27 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error scheme := req.URL.Scheme isHTTP := scheme == "http" || scheme == "https" - err = t.ensureHeaderAndCookie(req, isHTTP) - if err != nil { - return + if isHTTP { + // TODO: is h2c should also check this? + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + closeBody(req) + err = fmt.Errorf("net/http: invalid header field name %q", k) + return + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + closeBody(req) + // Don't include the value in the error, because it may be sensitive. + err = fmt.Errorf("net/http: invalid header field value for %q", k) + return + } + } + } + } + + if req.Header == nil { + req.Header = make(http.Header) } if t.forceHttpVersion != "" { From e3b07ec876d77913e305a3d25f9877816c862f02 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 12 Aug 2023 16:14:05 +0800 Subject: [PATCH 772/843] fix InsecureSkipVerify (#268) --- client.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 4ec1cdc8..5b9b033e 100644 --- a/client.go +++ b/client.go @@ -1175,7 +1175,11 @@ func (c *Client) SetTLSFingerprint(clientHelloID utls.ClientHelloID) *Client { colonPos = len(addr) } hostname := addr[:colonPos] - utlsConfig := &utls.Config{ServerName: hostname, NextProtos: c.GetTLSClientConfig().NextProtos} + utlsConfig := &utls.Config{ + ServerName: hostname, + NextProtos: c.GetTLSClientConfig().NextProtos, + InsecureSkipVerify: c.GetTLSClientConfig().InsecureSkipVerify, + } uconn := &uTLSConn{utls.UClient(plainConn, utlsConfig, clientHelloID)} err = uconn.HandshakeContext(ctx) if err != nil { From 13d91d27ee100c7d0f21fc3afdc760a255aad407 Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 12 Aug 2023 17:09:56 +0800 Subject: [PATCH 773/843] upgrade utls v1.4.3 --- go.mod | 6 +++--- go.sum | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index cc5eb39d..95ecf0e9 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/quic-go/qpack v0.4.0 github.com/quic-go/quic-go v0.37.4 - github.com/refraction-networking/utls v1.4.2 + github.com/refraction-networking/utls v1.4.3 golang.org/x/net v0.14.0 golang.org/x/text v0.12.0 ) @@ -16,13 +16,13 @@ require ( github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/mock v1.6.0 // indirect - github.com/google/pprof v0.0.0-20230808223545-4887780b67fb // indirect + github.com/google/pprof v0.0.0-20230811205829-9131a7e9cc17 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/klauspost/compress v1.16.7 // indirect github.com/onsi/ginkgo/v2 v2.11.0 // indirect github.com/quic-go/qtls-go1-20 v0.3.2 // indirect golang.org/x/crypto v0.12.0 // indirect - golang.org/x/exp v0.0.0-20230810033253-352e893a4cad // indirect + golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/sys v0.11.0 // indirect golang.org/x/tools v0.12.0 // indirect diff --git a/go.sum b/go.sum index e503098e..9e738ffc 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/google/pprof v0.0.0-20230705174524-200ffdc848b8 h1:n6vlPhxsA+BW/XsS5+ github.com/google/pprof v0.0.0-20230705174524-200ffdc848b8/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/google/pprof v0.0.0-20230808223545-4887780b67fb h1:oqpb3Cwpc7EOml5PVGMYbSGmwNui2R7i8IW83gs4W0c= github.com/google/pprof v0.0.0-20230808223545-4887780b67fb/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= +github.com/google/pprof v0.0.0-20230811205829-9131a7e9cc17 h1:0h35ESZ02+hN/MFZb7XZOXg+Rl9+Rk8fBIf5YLws9gA= +github.com/google/pprof v0.0.0-20230811205829-9131a7e9cc17/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -68,6 +70,8 @@ github.com/refraction-networking/utls v1.4.1 h1:5VXwhNzrnWrvbJW8IVpptJKrErZGqoRb github.com/refraction-networking/utls v1.4.1/go.mod h1:JkUIj+Pc8eyFB0z+A4RJRZmoT43ajjFZWVMXuZQ8BEQ= github.com/refraction-networking/utls v1.4.2 h1:7N+928mSM1pEyAJb8x2Y1FbEwTIftGwn2IFykosSzwc= github.com/refraction-networking/utls v1.4.2/go.mod h1:JkUIj+Pc8eyFB0z+A4RJRZmoT43ajjFZWVMXuZQ8BEQ= +github.com/refraction-networking/utls v1.4.3 h1:BdWS3BSzCwWCFfMIXP3mjLAyQkdmog7diaD/OqFbAzM= +github.com/refraction-networking/utls v1.4.3/go.mod h1:4u9V/awOSBrRw6+federGmVJQfPtemEqLBXkML1b0bo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -90,6 +94,8 @@ golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819 h1:EDuYyU/MkFXllv9QF9819VlI9 golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230810033253-352e893a4cad h1:g0bG7Z4uG+OgH2QDODnjp6ggkk1bJDsINcuWmJN1iJU= golang.org/x/exp v0.0.0-20230810033253-352e893a4cad/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb h1:mIKbk8weKhSeLH2GmUTrvx8CjkyJmnU1wFmg59CUjFA= +golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= From 66c121d321a2c2bc2fa9ac493bfb66d38296031f Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 18 Aug 2023 09:57:32 +0800 Subject: [PATCH 774/843] Default qop to "auth" in HTTPDigestAuth(#269) --- digest.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/digest.go b/digest.go index 11d0b91b..4cf4f73f 100644 --- a/digest.go +++ b/digest.go @@ -7,11 +7,12 @@ import ( "crypto/sha512" "errors" "fmt" - "github.com/imroc/req/v3/internal/header" "hash" "io" "net/http" "strings" + + "github.com/imroc/req/v3/internal/header" ) var ( @@ -19,7 +20,6 @@ var ( errDigestCharset = errors.New("digest: unsupported charset") errDigestAlgNotSupported = errors.New("digest: algorithm is not supported") errDigestQopNotSupported = errors.New("digest: no supported qop in list") - errDigestNoQop = errors.New("digest: qop must be specified") ) var hashFuncs = map[string]func() hash.Hash{ @@ -213,7 +213,7 @@ func (c *credentials) authorize() (string, error) { func (c *credentials) validateQop() error { // Currently only supporting auth quality of protection. TODO: add auth-int support if c.messageQop == "" { - return errDigestNoQop + c.messageQop = "auth" } possibleQops := strings.Split(c.messageQop, ", ") var authSupport bool @@ -227,8 +227,6 @@ func (c *credentials) validateQop() error { return errDigestQopNotSupported } - c.messageQop = "auth" - return nil } From 08cabbcf9cbdabde0756ea15a56a5c47a96c6ecf Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 20 Aug 2023 16:14:40 +0800 Subject: [PATCH 775/843] clear common cookies in ClearCookies --- client.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 5b9b033e..a5affaa4 100644 --- a/client.go +++ b/client.go @@ -1046,11 +1046,13 @@ func (c *Client) GetCookies(url string) ([]*http.Cookie, error) { return c.httpClient.Jar.Cookies(u), nil } -// ClearCookies clears all cookies if cookie is enabled. -// Note: It has no effect if you called SetCookieJar instead of -// SetCookieJarFactory. +// ClearCookies clears all cookies if cookie is enabled, including +// cookies from cookie jar and cookies set by SetCommonCookies. +// Note: The cookie jar will not be cleared if you called SetCookieJar +// instead of SetCookieJarFactory. func (c *Client) ClearCookies() *Client { c.initCookieJar() + c.Cookies = nil return c } From ea515f94f21175633db8092fd584c505ec5afa8a Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 21 Aug 2023 10:50:12 +0800 Subject: [PATCH 776/843] add comments --- client.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index a5affaa4..ae4035a0 100644 --- a/client.go +++ b/client.go @@ -1224,20 +1224,27 @@ func (c *Client) SetTLSHandshakeTimeout(timeout time.Duration) *Client { } // EnableForceHTTP1 enable force using HTTP1 (disabled by default). +// +// Attention: This method should not be called when ImpersonateXXX, SetTLSFingerPrint or +// SetTLSHandshake and other methods that will customize the tls handshake are called. func (c *Client) EnableForceHTTP1() *Client { c.Transport.EnableForceHTTP1() return c } -// EnableForceHTTP2 enable force using HTTP2 for https requests -// (disabled by default). +// EnableForceHTTP2 enable force using HTTP2 for https requests (disabled by default). +// +// Attention: This method should not be called when ImpersonateXXX, SetTLSFingerPrint or +// SetTLSHandshake and other methods that will customize the tls handshake are called. func (c *Client) EnableForceHTTP2() *Client { c.Transport.EnableForceHTTP2() return c } -// EnableForceHTTP3 enable force using HTTP3 for https requests -// (disabled by default). +// EnableForceHTTP3 enable force using HTTP3 for https requests (disabled by default). +// +// Attention: This method should not be called when ImpersonateXXX, SetTLSFingerPrint or +// SetTLSHandshake and other methods that will customize the tls handshake are called. func (c *Client) EnableForceHTTP3() *Client { c.Transport.EnableForceHTTP3() return c From bc78220c094180c4619fd36423b72ab6becb014c Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 21 Aug 2023 14:10:46 +0800 Subject: [PATCH 777/843] Allow splitting digest parameters without spaces --- digest.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/digest.go b/digest.go index 4cf4f73f..9b341407 100644 --- a/digest.go +++ b/digest.go @@ -120,11 +120,11 @@ func parseChallenge(input string) (*challenge, error) { return nil, errDigestBadChallenge } s = strings.Trim(s[7:], ws) - sl := strings.Split(s, ", ") + sl := strings.Split(s, ",") c := &challenge{} var r []string for i := range sl { - r = strings.SplitN(sl[i], "=", 2) + r = strings.SplitN(strings.TrimSpace(sl[i]), "=", 2) if len(r) != 2 { return nil, errDigestBadChallenge } From 23dae155150bd7bc17edebe8fb20dc1d9d49fed5 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 23 Aug 2023 11:01:41 +0800 Subject: [PATCH 778/843] Support http digest calculation that does not follow the rfc specification --- digest.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/digest.go b/digest.go index 9b341407..ccb97398 100644 --- a/digest.go +++ b/digest.go @@ -213,7 +213,7 @@ func (c *credentials) authorize() (string, error) { func (c *credentials) validateQop() error { // Currently only supporting auth quality of protection. TODO: add auth-int support if c.messageQop == "" { - c.messageQop = "auth" + return nil } possibleQops := strings.Split(c.messageQop, ", ") var authSupport bool @@ -250,6 +250,9 @@ func (c *credentials) resp() (string, error) { ha1 := c.ha1() ha2 := c.ha2() + if len(c.messageQop) == 0 { + return c.h(fmt.Sprintf("%s:%s:%s", ha1, c.nonce, ha2)), nil + } return c.kd(ha1, fmt.Sprintf("%s:%08x:%s:%s:%s", c.nonce, c.nc, c.cNonce, c.messageQop, ha2)), nil } From 93b933a1fabf7c64a7b226e0efb6ea9ec2888949 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 28 Aug 2023 15:54:08 +0800 Subject: [PATCH 779/843] Do not try charset conversion if Accept-Encoding specified --- transport.go | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/transport.go b/transport.go index 7ff158c0..407628ae 100644 --- a/transport.go +++ b/transport.go @@ -17,6 +17,20 @@ import ( "crypto/tls" "errors" "fmt" + "io" + "log" + "mime" + "net" + "net/http" + "net/http/httptrace" + "net/textproto" + "net/url" + "runtime" + "strconv" + "strings" + "sync" + "time" + "github.com/imroc/req/v3/http2" "github.com/imroc/req/v3/internal/altsvcutil" "github.com/imroc/req/v3/internal/ascii" @@ -33,19 +47,6 @@ import ( reqtls "github.com/imroc/req/v3/pkg/tls" htmlcharset "golang.org/x/net/html/charset" "golang.org/x/text/encoding/ianaindex" - "io" - "log" - "mime" - "net" - "net/http" - "net/http/httptrace" - "net/textproto" - "net/url" - "runtime" - "strconv" - "strings" - "sync" - "time" "golang.org/x/net/http/httpguts" ) @@ -688,7 +689,7 @@ func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFu } func (t *Transport) autoDecodeResponseBody(res *http.Response) { - if t.disableAutoDecode { + if t.disableAutoDecode || res.Header.Get("Accept-Encoding") != "" { return } contentType := res.Header.Get("Content-Type") From 011c0c2f2fdd6d79d86dfb729efe3ff2b151a1fb Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 28 Aug 2023 16:37:59 +0800 Subject: [PATCH 780/843] Fix http3 in go1.21(#274) --- transport.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transport.go b/transport.go index 407628ae..2fdcff3c 100644 --- a/transport.go +++ b/transport.go @@ -542,7 +542,9 @@ func (t *Transport) DisableH2C() *Transport { // (disabled by default). func (t *Transport) EnableForceHTTP3() *Transport { t.EnableHTTP3() - t.forceHttpVersion = h3 + if t.t3 != nil { + t.forceHttpVersion = h3 + } return t } @@ -580,7 +582,7 @@ func (t *Transport) EnableHTTP3() { } return } - if !(minorVersion >= 18 && minorVersion <= 20) { + if !(minorVersion >= 20 && minorVersion <= 21) { if t.Debugf != nil { t.Debugf("%s is not support http3", v) } From 16f0680d983b4d4df4a25e096fed92d77f5ab247 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 28 Aug 2023 19:25:12 +0800 Subject: [PATCH 781/843] Fix COMPRESSION_ERROR in ImpersonateXXX (#275) --- internal/http2/transport.go | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 4a80a4b9..88bb9714 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -653,7 +653,23 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro } cc.cond = sync.NewCond(&cc.mu) - cc.flow.add(int32(initialWindowSize)) + + var windowSize int32 = initialWindowSize + var headerTableSize uint32 = initialHeaderTableSize + for _, setting := range t.Settings { + switch setting.ID { + case http2.SettingMaxFrameSize: + cc.maxFrameSize = setting.Val + case http2.SettingMaxHeaderListSize: + t.MaxHeaderListSize = setting.Val + case http2.SettingHeaderTableSize: + headerTableSize = setting.Val + case http2.SettingInitialWindowSize: + windowSize = int32(setting.Val) + } + } + + cc.flow.add(windowSize) // TODO: adjust this writer size to account for frame size + // MTU + crypto/tls record padding. @@ -668,7 +684,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro if t.CountError != nil { cc.fr.countError = t.CountError } - cc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) + cc.fr.ReadMetaHeaders = hpack.NewDecoder(headerTableSize, nil) cc.fr.MaxHeaderListSize = t.maxHeaderListSize() // TODO: SetMaxDynamicTableSize, SetMaxDynamicTableSizeLimit on @@ -710,7 +726,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro cc.nextStreamID = p.StreamID + 2 } - cc.inflow.init(int32(connFlow) + initialWindowSize) + cc.inflow.init(int32(connFlow) + windowSize) cc.bw.Flush() if cc.werr != nil { cc.Close() From 7b2415be97bb2f2c706f6e2c98c2e5bf95b21138 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 31 Aug 2023 15:51:02 +0800 Subject: [PATCH 782/843] Fix FLOW_CONTROL_ERROR in ImpersonateXXX (#275) --- internal/http2/transport.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 88bb9714..115eab6f 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -654,7 +654,6 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro cc.cond = sync.NewCond(&cc.mu) - var windowSize int32 = initialWindowSize var headerTableSize uint32 = initialHeaderTableSize for _, setting := range t.Settings { switch setting.ID { @@ -664,12 +663,10 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro t.MaxHeaderListSize = setting.Val case http2.SettingHeaderTableSize: headerTableSize = setting.Val - case http2.SettingInitialWindowSize: - windowSize = int32(setting.Val) } } - cc.flow.add(windowSize) + cc.flow.add(initialWindowSize) // TODO: adjust this writer size to account for frame size + // MTU + crypto/tls record padding. @@ -726,7 +723,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro cc.nextStreamID = p.StreamID + 2 } - cc.inflow.init(int32(connFlow) + windowSize) + cc.inflow.init(int32(connFlow) + initialWindowSize) cc.bw.Flush() if cc.werr != nil { cc.Close() From 6fe6ed1ffb61f9e18e2dde00e965adc8517f3d66 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 5 Sep 2023 11:13:50 +0800 Subject: [PATCH 783/843] Update dependencies --- go.mod | 19 ++++++++++--------- go.sum | 20 ++++++++++++++++++++ internal/http3/error_codes.go | 2 +- internal/http3/frames.go | 2 +- internal/http3/roundtrip.go | 3 ++- 5 files changed, 34 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 95ecf0e9..69f81ba5 100644 --- a/go.mod +++ b/go.mod @@ -5,25 +5,26 @@ go 1.20 require ( github.com/hashicorp/go-multierror v1.1.1 github.com/quic-go/qpack v0.4.0 - github.com/quic-go/quic-go v0.37.4 - github.com/refraction-networking/utls v1.4.3 + github.com/quic-go/quic-go v0.38.1 + github.com/refraction-networking/utls v1.5.3 golang.org/x/net v0.14.0 - golang.org/x/text v0.12.0 + golang.org/x/text v0.13.0 ) require ( github.com/andybalholm/brotli v1.0.5 // indirect + github.com/cloudflare/circl v1.3.3 // indirect github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/mock v1.6.0 // indirect - github.com/google/pprof v0.0.0-20230811205829-9131a7e9cc17 // indirect + github.com/google/pprof v0.0.0-20230901174712-0191c66da455 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/klauspost/compress v1.16.7 // indirect - github.com/onsi/ginkgo/v2 v2.11.0 // indirect - github.com/quic-go/qtls-go1-20 v0.3.2 // indirect + github.com/onsi/ginkgo/v2 v2.12.0 // indirect + github.com/quic-go/qtls-go1-20 v0.3.3 // indirect golang.org/x/crypto v0.12.0 // indirect - golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb // indirect + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/sys v0.11.0 // indirect - golang.org/x/tools v0.12.0 // indirect + golang.org/x/sys v0.12.0 // indirect + golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 // indirect ) diff --git a/go.sum b/go.sum index 9e738ffc..cf811594 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/cloudflare/circl v1.3.3 h1:fE/Qz0QdIGqeWfnwq0RE0R7MI51s0M2E4Ga9kq5AEMs= +github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -26,6 +28,8 @@ github.com/google/pprof v0.0.0-20230808223545-4887780b67fb h1:oqpb3Cwpc7EOml5PVG github.com/google/pprof v0.0.0-20230808223545-4887780b67fb/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/google/pprof v0.0.0-20230811205829-9131a7e9cc17 h1:0h35ESZ02+hN/MFZb7XZOXg+Rl9+Rk8fBIf5YLws9gA= github.com/google/pprof v0.0.0-20230811205829-9131a7e9cc17/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= +github.com/google/pprof v0.0.0-20230901174712-0191c66da455 h1:YhRUmI1ttDC4sxKY2V62BTI8hCXnyZBV9h38eAanInE= +github.com/google/pprof v0.0.0-20230901174712-0191c66da455/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -41,6 +45,8 @@ github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE= github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU= github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM= +github.com/onsi/ginkgo/v2 v2.12.0 h1:UIVDowFPwpg6yMUpPjGkYvf06K3RAiJXUhCxEwQVHRI= +github.com/onsi/ginkgo/v2 v2.12.0/go.mod h1:ZNEzXISYlqpb8S36iN71ifqLi3vVD1rVJGvWRCJOUpQ= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -54,6 +60,8 @@ github.com/quic-go/qtls-go1-20 v0.3.0 h1:NrCXmDl8BddZwO67vlvEpBTwT89bJfKYygxv4HQ github.com/quic-go/qtls-go1-20 v0.3.0/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/quic-go/qtls-go1-20 v0.3.2 h1:rRgN3WfnKbyik4dBV8A6girlJVxGand/d+jVKbQq5GI= github.com/quic-go/qtls-go1-20 v0.3.2/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= +github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM= +github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/quic-go/quic-go v0.35.0 h1:JXIf219xJK+4qGeY52rlnrVqeB2AXUAwfLU9JSoWXwg= github.com/quic-go/quic-go v0.35.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo= @@ -62,6 +70,8 @@ github.com/quic-go/quic-go v0.37.0 h1:wf/Ym2yeWi98oQn4ahiBSqdnaXVxNQGj2oBQFgiVCh github.com/quic-go/quic-go v0.37.0/go.mod h1:XtCUOCALTTWbPyd0IxFfHf6h0sEMubRFvEYHl3QxKw8= github.com/quic-go/quic-go v0.37.4 h1:ke8B73yMCWGq9MfrCCAw0Uzdm7GaViC3i39dsIdDlH4= github.com/quic-go/quic-go v0.37.4/go.mod h1:YsbH1r4mSHPJcLF4k4zruUkLBqctEMBDR6VPvcYjIsU= +github.com/quic-go/quic-go v0.38.1 h1:M36YWA5dEhEeT+slOu/SwMEucbYd0YFidxG3KlGPZaE= +github.com/quic-go/quic-go v0.38.1/go.mod h1:ijnZM7JsFIkp4cRyjxJNIzdSfCLmUMg9wdyhGmg+SN4= github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8= github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E= github.com/refraction-networking/utls v1.3.3 h1:f/TBLX7KBciRyFH3bwupp+CE4fzoYKCirhdRcC490sw= @@ -72,6 +82,8 @@ github.com/refraction-networking/utls v1.4.2 h1:7N+928mSM1pEyAJb8x2Y1FbEwTIftGwn github.com/refraction-networking/utls v1.4.2/go.mod h1:JkUIj+Pc8eyFB0z+A4RJRZmoT43ajjFZWVMXuZQ8BEQ= github.com/refraction-networking/utls v1.4.3 h1:BdWS3BSzCwWCFfMIXP3mjLAyQkdmog7diaD/OqFbAzM= github.com/refraction-networking/utls v1.4.3/go.mod h1:4u9V/awOSBrRw6+federGmVJQfPtemEqLBXkML1b0bo= +github.com/refraction-networking/utls v1.5.3 h1:Ds5Ocg1+MC1ahNx5iBEcHe0jHeLaA/fLey61EENm7ro= +github.com/refraction-networking/utls v1.5.3/go.mod h1:SPuDbBmgLGp8s+HLNc83FuavwZCFoMmExj+ltUHiHUw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -96,6 +108,8 @@ golang.org/x/exp v0.0.0-20230810033253-352e893a4cad h1:g0bG7Z4uG+OgH2QDODnjp6ggk golang.org/x/exp v0.0.0-20230810033253-352e893a4cad/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb h1:mIKbk8weKhSeLH2GmUTrvx8CjkyJmnU1wFmg59CUjFA= golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -129,6 +143,8 @@ golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -140,6 +156,8 @@ golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -151,6 +169,8 @@ golang.org/x/tools v0.11.0 h1:EMCa6U9S2LtZXLAMoWiR/R8dAQFRqbAitmbJ2UKhoi8= golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/tools v0.12.0 h1:YW6HUoUmYBpwSgyaGaZq1fHjrBjX1rlpZ54T6mu2kss= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= +golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E= +golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/http3/error_codes.go b/internal/http3/error_codes.go index 5df9b5df..148447ea 100644 --- a/internal/http3/error_codes.go +++ b/internal/http3/error_codes.go @@ -26,7 +26,7 @@ const ( errorMessageError errorCode = 0x10e errorConnectError errorCode = 0x10f errorVersionFallback errorCode = 0x110 - errorDatagramError errorCode = 0x4a1268 + errorDatagramError errorCode = 0x33 ) func (e errorCode) String() string { diff --git a/internal/http3/frames.go b/internal/http3/frames.go index 5e31d72a..a3cd88ad 100644 --- a/internal/http3/frames.go +++ b/internal/http3/frames.go @@ -87,7 +87,7 @@ func (f *headersFrame) Append(b []byte) []byte { return quicvarint.Append(b, f.Length) } -const settingDatagram = 0xffd277 +const settingDatagram = 0x33 type settingsFrame struct { Datagram bool diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index 3171e09d..3e99acb8 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -42,7 +42,8 @@ type RoundTripper struct { // Enable support for HTTP/3 datagrams. // If set to true, QuicConfig.EnableDatagram will be set. - // See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html. + // + // See https://datatracker.ietf.org/doc/html/rfc9297. EnableDatagrams bool // Additional HTTP/3 settings. From ebe4e19b04b482e7020b69fb9bd823392845bf01 Mon Sep 17 00:00:00 2001 From: nange Date: Thu, 7 Sep 2023 16:59:40 +0800 Subject: [PATCH 784/843] Fix RootCAs setting when using utls --- client.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index ae4035a0..1a14c6d5 100644 --- a/client.go +++ b/client.go @@ -8,7 +8,6 @@ import ( "encoding/json" "encoding/xml" "errors" - "github.com/imroc/req/v3/http2" "io" "net" "net/http" @@ -19,11 +18,11 @@ import ( "strings" "time" - "golang.org/x/net/publicsuffix" - + "github.com/imroc/req/v3/http2" "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/util" utls "github.com/refraction-networking/utls" + "golang.org/x/net/publicsuffix" ) // DefaultClient returns the global default Client. @@ -1179,6 +1178,7 @@ func (c *Client) SetTLSFingerprint(clientHelloID utls.ClientHelloID) *Client { hostname := addr[:colonPos] utlsConfig := &utls.Config{ ServerName: hostname, + RootCAs: c.GetTLSClientConfig().RootCAs, NextProtos: c.GetTLSClientConfig().NextProtos, InsecureSkipVerify: c.GetTLSClientConfig().InsecureSkipVerify, } From 21fe799e249cca1559e170c1ece38a99f10e7237 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 15 Sep 2023 19:57:35 +0800 Subject: [PATCH 785/843] feat: Add OnError to Client Support the error hook, which will be executed if any error will be returned --- client.go | 15 +++++++++++++-- request.go | 14 +++++++++----- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index 1a14c6d5..f4ed8b0b 100644 --- a/client.go +++ b/client.go @@ -18,11 +18,12 @@ import ( "strings" "time" + utls "github.com/refraction-networking/utls" + "golang.org/x/net/publicsuffix" + "github.com/imroc/req/v3/http2" "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/util" - utls "github.com/refraction-networking/utls" - "golang.org/x/net/publicsuffix" ) // DefaultClient returns the global default Client. @@ -70,8 +71,11 @@ type Client struct { roundTripWrappers []RoundTripWrapper responseBodyTransformer func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error) resultStateCheckFunc func(resp *Response) ResultState + onError ErrorHook } +type ErrorHook func(client *Client, req *Request, resp *Response, err error) + // R create a new request. func (c *Client) R() *Request { return &Request{ @@ -981,6 +985,13 @@ func (c *Client) SetProxy(proxy func(*http.Request) (*urlpkg.URL, error)) *Clien return c } +// OnError set the error hook which will be executed if any error returned, +// even if the occurs before request is sent (e.g. invalid URL). +func (c *Client) OnError(hook ErrorHook) *Client { + c.onError = hook + return c +} + // OnBeforeRequest add a request middleware which hooks before request sent. func (c *Client) OnBeforeRequest(m RequestMiddleware) *Client { c.udBeforeRequest = append(c.udBeforeRequest, m) diff --git a/request.go b/request.go index 4a40df75..2e0fceb6 100644 --- a/request.go +++ b/request.go @@ -15,6 +15,7 @@ import ( "time" "github.com/hashicorp/go-multierror" + "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/util" @@ -285,9 +286,11 @@ func (r *Request) SetFile(paramName, filePath string) *Request { }) } -var errMissingParamName = errors.New("missing param name in multipart file upload") -var errMissingFileName = errors.New("missing filename in multipart file upload") -var errMissingFileContent = errors.New("missing file content in multipart file upload") +var ( + errMissingParamName = errors.New("missing param name in multipart file upload") + errMissingFileName = errors.New("missing filename in multipart file upload") + errMissingFileContent = errors.New("missing file content in multipart file upload") +) // SetFileUpload set the fully custimized multipart file upload options. func (r *Request) SetFileUpload(uploads ...FileUpload) *Request { @@ -695,8 +698,6 @@ func (r *Request) do() (resp *Response, err error) { resp.result = nil resp.error = nil } - - return } // Send fires http request with specified method and url, returns the @@ -705,6 +706,9 @@ func (r *Request) Send(method, url string) (*Response, error) { r.Method = method r.RawURL = url resp := r.Do() + if resp.Err != nil && r.client.onError != nil { + r.client.onError(r.client, r, resp, resp.Err) + } return resp, resp.Err } From 21697a2503417bf78c204499fa6509f01b51073a Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 26 Sep 2023 10:13:57 +0800 Subject: [PATCH 786/843] Ignore PseudoHeaderOderKey while write http/1.1 header --- header.go | 6 ++++-- http_request.go | 21 ++++++++++++--------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/header.go b/header.go index d31da8d7..3c94f7db 100644 --- a/header.go +++ b/header.go @@ -1,14 +1,16 @@ package req import ( - "github.com/imroc/req/v3/internal/header" - "golang.org/x/net/http/httpguts" "io" "net/http" "net/textproto" "sort" "strings" "sync" + + "golang.org/x/net/http/httpguts" + + "github.com/imroc/req/v3/internal/header" ) var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") diff --git a/http_request.go b/http_request.go index c737a6b3..5816de44 100644 --- a/http_request.go +++ b/http_request.go @@ -2,11 +2,13 @@ package req import ( "errors" - "github.com/imroc/req/v3/internal/ascii" - "github.com/imroc/req/v3/internal/header" - "golang.org/x/net/http/httpguts" "net/http" "strings" + + "golang.org/x/net/http/httpguts" + + "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/header" ) // Given a string of the form "host", "host:port", or "[ipv6::address]:port", @@ -88,12 +90,13 @@ func closeRequestBody(r *http.Request) error { // Headers that Request.Write handles itself and should be skipped. var reqWriteExcludeHeader = map[string]bool{ - "Host": true, // not in Header map anyway - "User-Agent": true, - "Content-Length": true, - "Transfer-Encoding": true, - "Trailer": true, - header.HeaderOderKey: true, + "Host": true, // not in Header map anyway + "User-Agent": true, + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, + header.HeaderOderKey: true, + header.PseudoHeaderOderKey: true, } // requestMethodUsuallyLacksBody reports whether the given request From bc2bf8622c2c06437aa1153509ac78058978ceaf Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 6 Oct 2023 18:32:09 +0800 Subject: [PATCH 787/843] fix SetTLSFingerprintXXX does not take effect in subsequent requests(#290) --- internal/http2/transport.go | 94 +++++++++++++++++++++++++++++-------- 1 file changed, 74 insertions(+), 20 deletions(-) diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 115eab6f..0007cdee 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -15,14 +15,6 @@ import ( "crypto/tls" "errors" "fmt" - "github.com/imroc/req/v3/http2" - "github.com/imroc/req/v3/internal/ascii" - "github.com/imroc/req/v3/internal/common" - "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/netutil" - "github.com/imroc/req/v3/internal/transport" - reqtls "github.com/imroc/req/v3/pkg/tls" "io" "io/fs" "log" @@ -43,6 +35,15 @@ import ( "golang.org/x/net/http/httpguts" "golang.org/x/net/http2/hpack" "golang.org/x/net/idna" + + "github.com/imroc/req/v3/http2" + "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/common" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/netutil" + "github.com/imroc/req/v3/internal/transport" + reqtls "github.com/imroc/req/v3/pkg/tls" ) const ( @@ -157,7 +158,6 @@ func (t *Transport) pingTimeout() time.Duration { return 15 * time.Second } return t.PingTimeout - } func (t *Transport) connPool() ClientConnPool { @@ -585,18 +585,72 @@ func (t *Transport) newTLSConfig(host string) *tls.Config { return cfg } +var zeroDialer net.Dialer + +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } + // dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS // connection. func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (reqtls.Conn, error) { - dialer := &tls.Dialer{ - Config: cfg, - } - conn, err := dialer.DialContext(ctx, network, addr) - if err != nil { - return nil, err + if t.TLSHandshakeContext != nil { + conn, err := zeroDialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + var firstTLSHost string + if firstTLSHost, _, err = net.SplitHostPort(addr); err != nil { + return nil, err + } + trace := httptrace.ContextClientTrace(ctx) + errc := make(chan error, 2) + var timer *time.Timer // for canceling TLS handshake + if d := t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + tlsCn, tlsState, err := t.TLSHandshakeContext(ctx, firstTLSHost, conn) + if err != nil { + if timer != nil { + timer.Stop() + } + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } + } else { + conn = tlsCn + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(*tlsState, nil) + } + } + errc <- err + }() + if err := <-errc; err != nil { + conn.Close() + return nil, err + } else { + tlsCn := conn.(reqtls.Conn) + return tlsCn, nil + } + } else { + dialer := &tls.Dialer{ + Config: cfg, + } + conn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + tlsCn := conn.(reqtls.Conn) + return tlsCn, nil } - tlsCn := conn.(reqtls.Conn) - return tlsCn, nil } func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) { @@ -1771,7 +1825,6 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) if a := cs.flow.available(); a > 0 { take := a if int(take) > maxBytes { - take = int32(maxBytes) // can't truncate int; take is int32 } if take > int32(cc.maxFrameSize) { @@ -1928,7 +1981,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail break } vals = append(vals, v[:p]) - //writeHeader("cookie", v[:p]) + // writeHeader("cookie", v[:p]) p++ // strip space after semicolon if any. for p+1 <= len(v) && v[p] == ' ' { @@ -1938,7 +1991,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail } if len(v) > 0 { vals = append(vals, v) - //writeHeader("cookie", v) + // writeHeader("cookie", v) } } writeHeader("cookie", vals...) @@ -2641,6 +2694,7 @@ func (b transportResponseBody) Close() error { } return nil } + func (rl *clientConnReadLoop) processData(f *DataFrame) error { cc := rl.cc cs := rl.streamByID(f.StreamID) From 5323efe519c828b9b97dc52f13386e007e5972b5 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 20 Nov 2023 13:02:21 +0800 Subject: [PATCH 788/843] Fix retry in SetFileBytes (#300) --- request.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/request.go b/request.go index 2e0fceb6..7a5942a2 100644 --- a/request.go +++ b/request.go @@ -242,7 +242,14 @@ func (r *Request) SetFileReader(paramName, filename string, reader io.Reader) *R // SetFileBytes set up a multipart form with given []byte to upload. func (r *Request) SetFileBytes(paramName, filename string, content []byte) *Request { - return r.SetFileReader(paramName, filename, bytes.NewReader(content)) + r.SetFileUpload(FileUpload{ + ParamName: paramName, + FileName: filename, + GetFileContent: func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(content)), nil + }, + }) + return r } // SetFiles set up a multipart form from a map to upload, which From 43359d7be41bed7c24c20767c8def61fce26d92c Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 5 Jan 2024 11:56:44 +0800 Subject: [PATCH 789/843] update go modules, support quic-go v0.40.1 --- go.mod | 31 +++++++------ go.sum | 32 +++++++++++++ internal/http3/body.go | 4 +- internal/http3/client.go | 43 +++++++++--------- internal/http3/error.go | 58 +++++++++++++++++++++++ internal/http3/error_codes.go | 86 +++++++++++++++++++---------------- internal/http3/server.go | 8 ++-- 7 files changed, 181 insertions(+), 81 deletions(-) create mode 100644 internal/http3/error.go diff --git a/go.mod b/go.mod index 69f81ba5..6f0b6ce4 100644 --- a/go.mod +++ b/go.mod @@ -5,26 +5,27 @@ go 1.20 require ( github.com/hashicorp/go-multierror v1.1.1 github.com/quic-go/qpack v0.4.0 - github.com/quic-go/quic-go v0.38.1 - github.com/refraction-networking/utls v1.5.3 - golang.org/x/net v0.14.0 - golang.org/x/text v0.13.0 + github.com/quic-go/quic-go v0.40.1 + github.com/refraction-networking/utls v1.6.0 + golang.org/x/net v0.19.0 + golang.org/x/text v0.14.0 ) require ( - github.com/andybalholm/brotli v1.0.5 // indirect - github.com/cloudflare/circl v1.3.3 // indirect + github.com/andybalholm/brotli v1.0.6 // indirect + github.com/cloudflare/circl v1.3.7 // indirect github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/mock v1.6.0 // indirect - github.com/google/pprof v0.0.0-20230901174712-0191c66da455 // indirect + github.com/google/pprof v0.0.0-20231229205709-960ae82b1e42 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/klauspost/compress v1.16.7 // indirect - github.com/onsi/ginkgo/v2 v2.12.0 // indirect - github.com/quic-go/qtls-go1-20 v0.3.3 // indirect - golang.org/x/crypto v0.12.0 // indirect - golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect - golang.org/x/mod v0.12.0 // indirect - golang.org/x/sys v0.12.0 // indirect - golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 // indirect + github.com/klauspost/compress v1.17.4 // indirect + github.com/onsi/ginkgo/v2 v2.13.2 // indirect + github.com/quic-go/qtls-go1-20 v0.4.1 // indirect + go.uber.org/mock v0.4.0 // indirect + golang.org/x/crypto v0.17.0 // indirect + golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc // indirect + golang.org/x/mod v0.14.0 // indirect + golang.org/x/sys v0.16.0 // indirect + golang.org/x/tools v0.16.1 // indirect ) diff --git a/go.sum b/go.sum index cf811594..b71b469e 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,12 @@ github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= +github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/cloudflare/circl v1.3.3 h1:fE/Qz0QdIGqeWfnwq0RE0R7MI51s0M2E4Ga9kq5AEMs= github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= +github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= +github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -30,6 +34,8 @@ github.com/google/pprof v0.0.0-20230811205829-9131a7e9cc17 h1:0h35ESZ02+hN/MFZb7 github.com/google/pprof v0.0.0-20230811205829-9131a7e9cc17/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/google/pprof v0.0.0-20230901174712-0191c66da455 h1:YhRUmI1ttDC4sxKY2V62BTI8hCXnyZBV9h38eAanInE= github.com/google/pprof v0.0.0-20230901174712-0191c66da455/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= +github.com/google/pprof v0.0.0-20231229205709-960ae82b1e42 h1:dHLYa5D8/Ta0aLR2XcPsrkpAgGeFs6thhMcQK0oQ0n8= +github.com/google/pprof v0.0.0-20231229205709-960ae82b1e42/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -39,6 +45,8 @@ github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7y github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= +github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs= @@ -47,6 +55,8 @@ github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM= github.com/onsi/ginkgo/v2 v2.12.0 h1:UIVDowFPwpg6yMUpPjGkYvf06K3RAiJXUhCxEwQVHRI= github.com/onsi/ginkgo/v2 v2.12.0/go.mod h1:ZNEzXISYlqpb8S36iN71ifqLi3vVD1rVJGvWRCJOUpQ= +github.com/onsi/ginkgo/v2 v2.13.2 h1:Bi2gGVkfn6gQcjNjZJVO8Gf0FHzMPf2phUei9tejVMs= +github.com/onsi/ginkgo/v2 v2.13.2/go.mod h1:XStQ8QcGwLyF4HdfcZB8SFOS/MWCgDuXMSBe6zrvLgM= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -62,6 +72,8 @@ github.com/quic-go/qtls-go1-20 v0.3.2 h1:rRgN3WfnKbyik4dBV8A6girlJVxGand/d+jVKbQ github.com/quic-go/qtls-go1-20 v0.3.2/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM= github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= +github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs= +github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/quic-go/quic-go v0.35.0 h1:JXIf219xJK+4qGeY52rlnrVqeB2AXUAwfLU9JSoWXwg= github.com/quic-go/quic-go v0.35.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo= @@ -72,6 +84,8 @@ github.com/quic-go/quic-go v0.37.4 h1:ke8B73yMCWGq9MfrCCAw0Uzdm7GaViC3i39dsIdDlH github.com/quic-go/quic-go v0.37.4/go.mod h1:YsbH1r4mSHPJcLF4k4zruUkLBqctEMBDR6VPvcYjIsU= github.com/quic-go/quic-go v0.38.1 h1:M36YWA5dEhEeT+slOu/SwMEucbYd0YFidxG3KlGPZaE= github.com/quic-go/quic-go v0.38.1/go.mod h1:ijnZM7JsFIkp4cRyjxJNIzdSfCLmUMg9wdyhGmg+SN4= +github.com/quic-go/quic-go v0.40.1 h1:X3AGzUNFs0jVuO3esAGnTfvdgvL4fq655WaOi1snv1Q= +github.com/quic-go/quic-go v0.40.1/go.mod h1:PeN7kuVJ4xZbxSv/4OX6S1USOX8MJvydwpTx31vx60c= github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8= github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E= github.com/refraction-networking/utls v1.3.3 h1:f/TBLX7KBciRyFH3bwupp+CE4fzoYKCirhdRcC490sw= @@ -84,10 +98,14 @@ github.com/refraction-networking/utls v1.4.3 h1:BdWS3BSzCwWCFfMIXP3mjLAyQkdmog7d github.com/refraction-networking/utls v1.4.3/go.mod h1:4u9V/awOSBrRw6+federGmVJQfPtemEqLBXkML1b0bo= github.com/refraction-networking/utls v1.5.3 h1:Ds5Ocg1+MC1ahNx5iBEcHe0jHeLaA/fLey61EENm7ro= github.com/refraction-networking/utls v1.5.3/go.mod h1:SPuDbBmgLGp8s+HLNc83FuavwZCFoMmExj+ltUHiHUw= +github.com/refraction-networking/utls v1.6.0 h1:X5vQMqVx7dY7ehxxqkFER/W6DSjy8TMqSItXm8hRDYQ= +github.com/refraction-networking/utls v1.6.0/go.mod h1:kHJ6R9DFFA0WsRgBM35iiDku4O7AqPR6y79iuzW7b10= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= @@ -98,6 +116,8 @@ golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= @@ -110,6 +130,8 @@ golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb h1:mIKbk8weKhSeLH2GmUTrvx8Cj golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= +golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc h1:ao2WRsKSzW6KuUY9IWPwWahcHCgR0s52IfwutMfEbdM= +golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -117,6 +139,8 @@ golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= @@ -128,6 +152,8 @@ golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -145,6 +171,8 @@ golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -158,6 +186,8 @@ golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -171,6 +201,8 @@ golang.org/x/tools v0.12.0 h1:YW6HUoUmYBpwSgyaGaZq1fHjrBjX1rlpZ54T6mu2kss= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E= golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= +golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= +golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/http3/body.go b/internal/http3/body.go index 15985a1c..63ff4366 100644 --- a/internal/http3/body.go +++ b/internal/http3/body.go @@ -67,7 +67,7 @@ func (r *body) Read(b []byte) (int, error) { } func (r *body) Close() error { - r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) return nil } @@ -126,7 +126,7 @@ func (r *body) StreamID() quic.StreamID { func (r *hijackableBody) Close() error { r.requestDone() // If the EOF was read, CancelRead() is a no-op. - r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) return nil } diff --git a/internal/http3/client.go b/internal/http3/client.go index 4ba9e9bb..673c85df 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -140,7 +140,7 @@ func (c *client) dial(ctx context.Context) error { go func() { if err := c.setupConn(conn); err != nil { c.opt.Debugf("setting up http3 connection failed: %s", err) - conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "") + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") } }() @@ -182,7 +182,7 @@ func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) { if err != nil { c.opt.Debugf("error handling stream: %s", err) } - conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream") + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") }(str) } } @@ -213,23 +213,23 @@ func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { return case streamTypePushStream: // We never increased the Push ID, so we don't expect any push streams. - conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") return default: if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) { return } - str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) + str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) return } f, err := parseNextFrame(str, nil) if err != nil { - conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") return } sf, ok := f.(*settingsFrame) if !ok { - conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") return } if !sf.Datagram { @@ -239,7 +239,7 @@ func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams { - conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") } }(str) } @@ -250,7 +250,7 @@ func (c *client) Close() error { if conn == nil { return nil } - return (*conn).CloseWithError(quic.ApplicationErrorCode(errorNoError), "") + return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") } func (c *client) maxHeaderBytes() uint64 { @@ -302,8 +302,8 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon defer close(done) select { case <-req.Context().Done(): - str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) - str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) + str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) case <-reqDone: } }() @@ -326,13 +326,14 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) } - return nil, rerr.err + return nil, maybeReplaceError(rerr.err) + } if opt.DontCloseRequestStream { close(reqDone) <-done } - return rsp, rerr.err + return rsp, maybeReplaceError(rerr.err) } func (c *client) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.Dumper) error { @@ -378,7 +379,7 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.D } break } - str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) + str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) return rerr } } @@ -398,14 +399,14 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui } } if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip, headerDumps); err != nil { - return nil, newStreamError(errorInternalError, err) + return nil, newStreamError(ErrCodeInternalError, err) } if req.Body == nil && !opt.DontCloseRequestStream { str.Close() } - hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) + hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") }) if req.Body != nil { // send the request body asynchronously go func() { @@ -426,18 +427,18 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui frame, err := parseNextFrame(str, nil) if err != nil { - return nil, newStreamError(errorFrameError, err) + return nil, newStreamError(ErrCodeFrameError, err) } hf, ok := frame.(*headersFrame) if !ok { - return nil, newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) + return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) } if hf.Length > c.maxHeaderBytes() { - return nil, newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes())) + return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes())) } headerBlock := make([]byte, hf.Length) if _, err := io.ReadFull(str, headerBlock); err != nil { - return nil, newStreamError(errorRequestIncomplete, err) + return nil, newStreamError(ErrCodeRequestIncomplete, err) } var respHeaderDumps []*dump.Dumper for _, dump := range dumps { @@ -458,12 +459,12 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui } if err != nil { // TODO: use the right error code - return nil, newConnError(errorGeneralProtocolError, err) + return nil, newConnError(ErrCodeGeneralProtocolError, err) } res, err := responseFromHeaders(hfs) if err != nil { - return nil, newStreamError(errorMessageError, err) + return nil, newStreamError(ErrCodeMessageError, err) } res.Request = req connState := conn.ConnectionState().TLS diff --git a/internal/http3/error.go b/internal/http3/error.go new file mode 100644 index 00000000..b96ebeec --- /dev/null +++ b/internal/http3/error.go @@ -0,0 +1,58 @@ +package http3 + +import ( + "errors" + "fmt" + + "github.com/quic-go/quic-go" +) + +// Error is returned from the round tripper (for HTTP clients) +// and inside the HTTP handler (for HTTP servers) if an HTTP/3 error occurs. +// See section 8 of RFC 9114. +type Error struct { + Remote bool + ErrorCode ErrCode + ErrorMessage string +} + +var _ error = &Error{} + +func (e *Error) Error() string { + s := e.ErrorCode.string() + if s == "" { + s = fmt.Sprintf("H3 error (%#x)", uint64(e.ErrorCode)) + } + // Usually errors are remote. Only make it explicit for local errors. + if !e.Remote { + s += " (local)" + } + if e.ErrorMessage != "" { + s += ": " + e.ErrorMessage + } + return s +} + +func maybeReplaceError(err error) error { + if err == nil { + return nil + } + + var ( + e Error + strErr *quic.StreamError + appErr *quic.ApplicationError + ) + switch { + default: + return err + case errors.As(err, &strErr): + e.Remote = strErr.Remote + e.ErrorCode = ErrCode(strErr.ErrorCode) + case errors.As(err, &appErr): + e.Remote = appErr.Remote + e.ErrorCode = ErrCode(appErr.ErrorCode) + e.ErrorMessage = appErr.ErrorMessage + } + return &e +} diff --git a/internal/http3/error_codes.go b/internal/http3/error_codes.go index 148447ea..ae646586 100644 --- a/internal/http3/error_codes.go +++ b/internal/http3/error_codes.go @@ -6,68 +6,76 @@ import ( "github.com/quic-go/quic-go" ) -type errorCode quic.ApplicationErrorCode +type ErrCode quic.ApplicationErrorCode const ( - errorNoError errorCode = 0x100 - errorGeneralProtocolError errorCode = 0x101 - errorInternalError errorCode = 0x102 - errorStreamCreationError errorCode = 0x103 - errorClosedCriticalStream errorCode = 0x104 - errorFrameUnexpected errorCode = 0x105 - errorFrameError errorCode = 0x106 - errorExcessiveLoad errorCode = 0x107 - errorIDError errorCode = 0x108 - errorSettingsError errorCode = 0x109 - errorMissingSettings errorCode = 0x10a - errorRequestRejected errorCode = 0x10b - errorRequestCanceled errorCode = 0x10c - errorRequestIncomplete errorCode = 0x10d - errorMessageError errorCode = 0x10e - errorConnectError errorCode = 0x10f - errorVersionFallback errorCode = 0x110 - errorDatagramError errorCode = 0x33 + ErrCodeNoError ErrCode = 0x100 + ErrCodeGeneralProtocolError ErrCode = 0x101 + ErrCodeInternalError ErrCode = 0x102 + ErrCodeStreamCreationError ErrCode = 0x103 + ErrCodeClosedCriticalStream ErrCode = 0x104 + ErrCodeFrameUnexpected ErrCode = 0x105 + ErrCodeFrameError ErrCode = 0x106 + ErrCodeExcessiveLoad ErrCode = 0x107 + ErrCodeIDError ErrCode = 0x108 + ErrCodeSettingsError ErrCode = 0x109 + ErrCodeMissingSettings ErrCode = 0x10a + ErrCodeRequestRejected ErrCode = 0x10b + ErrCodeRequestCanceled ErrCode = 0x10c + ErrCodeRequestIncomplete ErrCode = 0x10d + ErrCodeMessageError ErrCode = 0x10e + ErrCodeConnectError ErrCode = 0x10f + ErrCodeVersionFallback ErrCode = 0x110 + ErrCodeDatagramError ErrCode = 0x33 ) -func (e errorCode) String() string { +func (e ErrCode) String() string { + s := e.string() + if s != "" { + return s + } + return fmt.Sprintf("unknown error code: %#x", uint16(e)) +} + +func (e ErrCode) string() string { switch e { - case errorNoError: + case ErrCodeNoError: return "H3_NO_ERROR" - case errorGeneralProtocolError: + case ErrCodeGeneralProtocolError: return "H3_GENERAL_PROTOCOL_ERROR" - case errorInternalError: + case ErrCodeInternalError: return "H3_INTERNAL_ERROR" - case errorStreamCreationError: + case ErrCodeStreamCreationError: return "H3_STREAM_CREATION_ERROR" - case errorClosedCriticalStream: + case ErrCodeClosedCriticalStream: return "H3_CLOSED_CRITICAL_STREAM" - case errorFrameUnexpected: + case ErrCodeFrameUnexpected: return "H3_FRAME_UNEXPECTED" - case errorFrameError: + case ErrCodeFrameError: return "H3_FRAME_ERROR" - case errorExcessiveLoad: + case ErrCodeExcessiveLoad: return "H3_EXCESSIVE_LOAD" - case errorIDError: + case ErrCodeIDError: return "H3_ID_ERROR" - case errorSettingsError: + case ErrCodeSettingsError: return "H3_SETTINGS_ERROR" - case errorMissingSettings: + case ErrCodeMissingSettings: return "H3_MISSING_SETTINGS" - case errorRequestRejected: + case ErrCodeRequestRejected: return "H3_REQUEST_REJECTED" - case errorRequestCanceled: + case ErrCodeRequestCanceled: return "H3_REQUEST_CANCELLED" - case errorRequestIncomplete: + case ErrCodeRequestIncomplete: return "H3_INCOMPLETE_REQUEST" - case errorMessageError: + case ErrCodeMessageError: return "H3_MESSAGE_ERROR" - case errorConnectError: + case ErrCodeConnectError: return "H3_CONNECT_ERROR" - case errorVersionFallback: + case ErrCodeVersionFallback: return "H3_VERSION_FALLBACK" - case errorDatagramError: + case ErrCodeDatagramError: return "H3_DATAGRAM_ERROR" default: - return fmt.Sprintf("unknown error code: %#x", uint16(e)) + return "" } } diff --git a/internal/http3/server.go b/internal/http3/server.go index 7449c3c7..2b9d8658 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -31,14 +31,14 @@ func versionToALPN(v quic.VersionNumber) string { type requestError struct { err error - streamErr errorCode - connErr errorCode + streamErr ErrCode + connErr ErrCode } -func newStreamError(code errorCode, err error) requestError { +func newStreamError(code ErrCode, err error) requestError { return requestError{err: err, streamErr: code} } -func newConnError(code errorCode, err error) requestError { +func newConnError(code ErrCode, err error) requestError { return requestError{err: err, connErr: code} } From b2189eef86f5ee25e05efbb5700414bd2540eaf9 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Mar 2024 15:19:00 +0800 Subject: [PATCH 790/843] update go modules --- go.mod | 29 ++++---- go.sum | 217 +++++++++------------------------------------------------ 2 files changed, 46 insertions(+), 200 deletions(-) diff --git a/go.mod b/go.mod index 6f0b6ce4..7b617180 100644 --- a/go.mod +++ b/go.mod @@ -1,31 +1,28 @@ module github.com/imroc/req/v3 -go 1.20 +go 1.21 require ( github.com/hashicorp/go-multierror v1.1.1 github.com/quic-go/qpack v0.4.0 - github.com/quic-go/quic-go v0.40.1 - github.com/refraction-networking/utls v1.6.0 - golang.org/x/net v0.19.0 + github.com/quic-go/quic-go v0.41.0 + github.com/refraction-networking/utls v1.6.3 + golang.org/x/net v0.22.0 golang.org/x/text v0.14.0 ) require ( - github.com/andybalholm/brotli v1.0.6 // indirect + github.com/andybalholm/brotli v1.1.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect - github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect - github.com/golang/mock v1.6.0 // indirect - github.com/google/pprof v0.0.0-20231229205709-960ae82b1e42 // indirect + github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/klauspost/compress v1.17.4 // indirect - github.com/onsi/ginkgo/v2 v2.13.2 // indirect - github.com/quic-go/qtls-go1-20 v0.4.1 // indirect + github.com/klauspost/compress v1.17.7 // indirect + github.com/onsi/ginkgo/v2 v2.16.0 // indirect go.uber.org/mock v0.4.0 // indirect - golang.org/x/crypto v0.17.0 // indirect - golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc // indirect - golang.org/x/mod v0.14.0 // indirect - golang.org/x/sys v0.16.0 // indirect - golang.org/x/tools v0.16.1 // indirect + golang.org/x/crypto v0.21.0 // indirect + golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect + golang.org/x/mod v0.16.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/tools v0.19.0 // indirect ) diff --git a/go.sum b/go.sum index b71b469e..7eb022ef 100644 --- a/go.sum +++ b/go.sum @@ -1,212 +1,61 @@ -github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= -github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= -github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= -github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/cloudflare/circl v1.3.3 h1:fE/Qz0QdIGqeWfnwq0RE0R7MI51s0M2E4Ga9kq5AEMs= -github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk= -github.com/gaukas/godicttls v0.0.3/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= -github.com/gaukas/godicttls v0.0.4 h1:NlRaXb3J6hAnTmWdsEKb9bcSBD6BvcIjdGdeb0zfXbk= -github.com/gaukas/godicttls v0.0.4/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= -github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= +github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9tRnGuw+ZVCB3FazXODD6JE1R8= -github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= -github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs= -github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= -github.com/google/pprof v0.0.0-20230705174524-200ffdc848b8 h1:n6vlPhxsA+BW/XsS5+uqi7GyzaLa5MH7qlSLBZtRdiA= -github.com/google/pprof v0.0.0-20230705174524-200ffdc848b8/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= -github.com/google/pprof v0.0.0-20230808223545-4887780b67fb h1:oqpb3Cwpc7EOml5PVGMYbSGmwNui2R7i8IW83gs4W0c= -github.com/google/pprof v0.0.0-20230808223545-4887780b67fb/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= -github.com/google/pprof v0.0.0-20230811205829-9131a7e9cc17 h1:0h35ESZ02+hN/MFZb7XZOXg+Rl9+Rk8fBIf5YLws9gA= -github.com/google/pprof v0.0.0-20230811205829-9131a7e9cc17/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= -github.com/google/pprof v0.0.0-20230901174712-0191c66da455 h1:YhRUmI1ttDC4sxKY2V62BTI8hCXnyZBV9h38eAanInE= -github.com/google/pprof v0.0.0-20230901174712-0191c66da455/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= -github.com/google/pprof v0.0.0-20231229205709-960ae82b1e42 h1:dHLYa5D8/Ta0aLR2XcPsrkpAgGeFs6thhMcQK0oQ0n8= -github.com/google/pprof v0.0.0-20231229205709-960ae82b1e42/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 h1:y3N7Bm7Y9/CtpiVkw/ZWj6lSlDF3F74SfKwfTCer72Q= +github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= -github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= -github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= -github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= -github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= -github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= -github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= -github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= -github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs= -github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE= -github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU= -github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM= -github.com/onsi/ginkgo/v2 v2.12.0 h1:UIVDowFPwpg6yMUpPjGkYvf06K3RAiJXUhCxEwQVHRI= -github.com/onsi/ginkgo/v2 v2.12.0/go.mod h1:ZNEzXISYlqpb8S36iN71ifqLi3vVD1rVJGvWRCJOUpQ= -github.com/onsi/ginkgo/v2 v2.13.2 h1:Bi2gGVkfn6gQcjNjZJVO8Gf0FHzMPf2phUei9tejVMs= -github.com/onsi/ginkgo/v2 v2.13.2/go.mod h1:XStQ8QcGwLyF4HdfcZB8SFOS/MWCgDuXMSBe6zrvLgM= -github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= +github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= +github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/onsi/ginkgo/v2 v2.16.0 h1:7q1w9frJDzninhXxjZd+Y/x54XNjG/UlRLIYPZafsPM= +github.com/onsi/ginkgo/v2 v2.16.0/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs= +github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8= +github.com/onsi/gomega v1.30.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= -github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= -github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= -github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= -github.com/quic-go/qtls-go1-20 v0.3.0 h1:NrCXmDl8BddZwO67vlvEpBTwT89bJfKYygxv4HQvuDk= -github.com/quic-go/qtls-go1-20 v0.3.0/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= -github.com/quic-go/qtls-go1-20 v0.3.2 h1:rRgN3WfnKbyik4dBV8A6girlJVxGand/d+jVKbQq5GI= -github.com/quic-go/qtls-go1-20 v0.3.2/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= -github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM= -github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= -github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs= -github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= -github.com/quic-go/quic-go v0.35.0 h1:JXIf219xJK+4qGeY52rlnrVqeB2AXUAwfLU9JSoWXwg= -github.com/quic-go/quic-go v0.35.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= -github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo= -github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= -github.com/quic-go/quic-go v0.37.0 h1:wf/Ym2yeWi98oQn4ahiBSqdnaXVxNQGj2oBQFgiVChc= -github.com/quic-go/quic-go v0.37.0/go.mod h1:XtCUOCALTTWbPyd0IxFfHf6h0sEMubRFvEYHl3QxKw8= -github.com/quic-go/quic-go v0.37.4 h1:ke8B73yMCWGq9MfrCCAw0Uzdm7GaViC3i39dsIdDlH4= -github.com/quic-go/quic-go v0.37.4/go.mod h1:YsbH1r4mSHPJcLF4k4zruUkLBqctEMBDR6VPvcYjIsU= -github.com/quic-go/quic-go v0.38.1 h1:M36YWA5dEhEeT+slOu/SwMEucbYd0YFidxG3KlGPZaE= -github.com/quic-go/quic-go v0.38.1/go.mod h1:ijnZM7JsFIkp4cRyjxJNIzdSfCLmUMg9wdyhGmg+SN4= -github.com/quic-go/quic-go v0.40.1 h1:X3AGzUNFs0jVuO3esAGnTfvdgvL4fq655WaOi1snv1Q= -github.com/quic-go/quic-go v0.40.1/go.mod h1:PeN7kuVJ4xZbxSv/4OX6S1USOX8MJvydwpTx31vx60c= -github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8= -github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E= -github.com/refraction-networking/utls v1.3.3 h1:f/TBLX7KBciRyFH3bwupp+CE4fzoYKCirhdRcC490sw= -github.com/refraction-networking/utls v1.3.3/go.mod h1:DlecWW1LMlMJu+9qpzzQqdHDT/C2LAe03EdpLUz/RL8= -github.com/refraction-networking/utls v1.4.1 h1:5VXwhNzrnWrvbJW8IVpptJKrErZGqoRbn7wqu2jqMrU= -github.com/refraction-networking/utls v1.4.1/go.mod h1:JkUIj+Pc8eyFB0z+A4RJRZmoT43ajjFZWVMXuZQ8BEQ= -github.com/refraction-networking/utls v1.4.2 h1:7N+928mSM1pEyAJb8x2Y1FbEwTIftGwn2IFykosSzwc= -github.com/refraction-networking/utls v1.4.2/go.mod h1:JkUIj+Pc8eyFB0z+A4RJRZmoT43ajjFZWVMXuZQ8BEQ= -github.com/refraction-networking/utls v1.4.3 h1:BdWS3BSzCwWCFfMIXP3mjLAyQkdmog7diaD/OqFbAzM= -github.com/refraction-networking/utls v1.4.3/go.mod h1:4u9V/awOSBrRw6+federGmVJQfPtemEqLBXkML1b0bo= -github.com/refraction-networking/utls v1.5.3 h1:Ds5Ocg1+MC1ahNx5iBEcHe0jHeLaA/fLey61EENm7ro= -github.com/refraction-networking/utls v1.5.3/go.mod h1:SPuDbBmgLGp8s+HLNc83FuavwZCFoMmExj+ltUHiHUw= -github.com/refraction-networking/utls v1.6.0 h1:X5vQMqVx7dY7ehxxqkFER/W6DSjy8TMqSItXm8hRDYQ= -github.com/refraction-networking/utls v1.6.0/go.mod h1:kHJ6R9DFFA0WsRgBM35iiDku4O7AqPR6y79iuzW7b10= +github.com/quic-go/quic-go v0.41.0 h1:aD8MmHfgqTURWNJy48IYFg2OnxwHT3JL7ahGs73lb4k= +github.com/quic-go/quic-go v0.41.0/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA= +github.com/refraction-networking/utls v1.6.3 h1:MFOfRN35sSx6K5AZNIoESsBuBxS2LCgRilRIdHb6fDc= +github.com/refraction-networking/utls v1.6.3/go.mod h1:yil9+7qSl+gBwJqztoQseO6Pr3h62pQoY1lXiNR/FPs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= -golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= -golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= -golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= -golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= -golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= -golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= -golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= -golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= -golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819 h1:EDuYyU/MkFXllv9QF9819VlI9a4tzGuCbhG0ExK9o1U= -golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= -golang.org/x/exp v0.0.0-20230810033253-352e893a4cad h1:g0bG7Z4uG+OgH2QDODnjp6ggkk1bJDsINcuWmJN1iJU= -golang.org/x/exp v0.0.0-20230810033253-352e893a4cad/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= -golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb h1:mIKbk8weKhSeLH2GmUTrvx8CjkyJmnU1wFmg59CUjFA= -golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= -golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= -golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= -golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc h1:ao2WRsKSzW6KuUY9IWPwWahcHCgR0s52IfwutMfEbdM= -golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= -golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= -golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= -golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= -golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= -golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= -golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= -golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= -golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= -golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= -golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= -golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= -golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= -golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= +golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= +golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= -golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= -golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= -golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM= -golang.org/x/tools v0.11.0 h1:EMCa6U9S2LtZXLAMoWiR/R8dAQFRqbAitmbJ2UKhoi8= -golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= -golang.org/x/tools v0.12.0 h1:YW6HUoUmYBpwSgyaGaZq1fHjrBjX1rlpZ54T6mu2kss= -golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= -golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E= -golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= -golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= -golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 9c88606080c6c8185470e11d2acbc4e2e069c8d5 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Mar 2024 15:34:44 +0800 Subject: [PATCH 791/843] upgrade quic-go to v0.41.0 --- internal/http3/client.go | 11 +++++++- internal/http3/http_stream.go | 47 +++++++++++++++++++++++++++++++++++ internal/http3/roundtrip.go | 5 ++-- 3 files changed, 60 insertions(+), 3 deletions(-) diff --git a/internal/http3/client.go b/internal/http3/client.go index 673c85df..59ffb413 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -260,8 +260,17 @@ func (c *client) maxHeaderBytes() uint64 { return uint64(c.opts.MaxHeaderBytes) } -// RoundTripOpt executes a request and returns a response func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + rsp, err := c.roundTripOpt(req, opt) + if err != nil && req.Context().Err() != nil { + // if the context was canceled, return the context cancellation error + err = req.Context().Err() + } + return rsp, err +} + +// RoundTripOpt executes a request and returns a response +func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) } diff --git a/internal/http3/http_stream.go b/internal/http3/http_stream.go index 2799e2b3..bfaf4214 100644 --- a/internal/http3/http_stream.go +++ b/internal/http3/http_stream.go @@ -1,6 +1,7 @@ package http3 import ( + "errors" "fmt" "github.com/quic-go/quic-go" @@ -66,6 +67,10 @@ func (s *stream) Read(b []byte) (int, error) { return n, err } +func (s *stream) hasMoreData() bool { + return s.bytesRemainingInFrame > 0 +} + func (s *stream) Write(b []byte) (int, error) { s.buf = s.buf[:0] s.buf = (&dataFrame{Length: uint64(len(b))}).Append(s.buf) @@ -74,3 +79,45 @@ func (s *stream) Write(b []byte) (int, error) { } return s.Stream.Write(b) } + +var errTooMuchData = errors.New("peer sent too much data") + +type lengthLimitedStream struct { + *stream + contentLength int64 + read int64 + resetStream bool +} + +var _ Stream = &lengthLimitedStream{} + +func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream { + return &lengthLimitedStream{ + stream: str, + contentLength: contentLength, + } +} + +func (s *lengthLimitedStream) checkContentLengthViolation() error { + if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() { + if !s.resetStream { + s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) + s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) + s.resetStream = true + } + return errTooMuchData + } + return nil +} + +func (s *lengthLimitedStream) Read(b []byte) (int, error) { + if err := s.checkContentLengthViolation(); err != nil { + return 0, err + } + n, err := s.stream.Read(b[:min(int64(len(b)), s.contentLength-s.read)]) + s.read += int64(n) + if err := s.checkContentLengthViolation(); err != nil { + return n, err + } + return n, err +} diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index 3e99acb8..89624094 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -42,8 +42,8 @@ type RoundTripper struct { // Enable support for HTTP/3 datagrams. // If set to true, QuicConfig.EnableDatagram will be set. - // - // See https://datatracker.ietf.org/doc/html/rfc9297. + // + // See https://datatracker.ietf.org/doc/html/rfc9297. EnableDatagrams bool // Additional HTTP/3 settings. @@ -223,6 +223,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTr StreamHijacker: r.StreamHijacker, UniStreamHijacker: r.UniStreamHijacker, dump: r.Dump, + AdditionalSettings: r.AdditionalSettings, }, r.QuicConfig, dial, From f3df9fd4c84ceb8b75cd27b565d4e56924e8b8bd Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Mar 2024 16:29:43 +0800 Subject: [PATCH 792/843] merge upstream http2: 2024-03-04(7ee34a) --- internal/http2/databuffer.go | 59 +++++++++++++------------ internal/http2/frame.go | 24 +++++------ internal/http2/transport.go | 83 +++++++++++++++++++++--------------- 3 files changed, 92 insertions(+), 74 deletions(-) diff --git a/internal/http2/databuffer.go b/internal/http2/databuffer.go index a3067f8d..e6f55cbd 100644 --- a/internal/http2/databuffer.go +++ b/internal/http2/databuffer.go @@ -20,41 +20,44 @@ import ( // TODO: Benchmark to determine if the pools are necessary. The GC may have // improved enough that we can instead allocate chunks like this: // make([]byte, max(16<<10, expectedBytesRemaining)) -var ( - dataChunkSizeClasses = []int{ - 1 << 10, - 2 << 10, - 4 << 10, - 8 << 10, - 16 << 10, - } - dataChunkPools = [...]sync.Pool{ - {New: func() interface{} { return make([]byte, 1<<10) }}, - {New: func() interface{} { return make([]byte, 2<<10) }}, - {New: func() interface{} { return make([]byte, 4<<10) }}, - {New: func() interface{} { return make([]byte, 8<<10) }}, - {New: func() interface{} { return make([]byte, 16<<10) }}, - } -) +var dataChunkPools = [...]sync.Pool{ + {New: func() interface{} { return new([1 << 10]byte) }}, + {New: func() interface{} { return new([2 << 10]byte) }}, + {New: func() interface{} { return new([4 << 10]byte) }}, + {New: func() interface{} { return new([8 << 10]byte) }}, + {New: func() interface{} { return new([16 << 10]byte) }}, +} func getDataBufferChunk(size int64) []byte { - i := 0 - for ; i < len(dataChunkSizeClasses)-1; i++ { - if size <= int64(dataChunkSizeClasses[i]) { - break - } + switch { + case size <= 1<<10: + return dataChunkPools[0].Get().(*[1 << 10]byte)[:] + case size <= 2<<10: + return dataChunkPools[1].Get().(*[2 << 10]byte)[:] + case size <= 4<<10: + return dataChunkPools[2].Get().(*[4 << 10]byte)[:] + case size <= 8<<10: + return dataChunkPools[3].Get().(*[8 << 10]byte)[:] + default: + return dataChunkPools[4].Get().(*[16 << 10]byte)[:] } - return dataChunkPools[i].Get().([]byte) } func putDataBufferChunk(p []byte) { - for i, n := range dataChunkSizeClasses { - if len(p) == n { - dataChunkPools[i].Put(p) - return - } + switch len(p) { + case 1 << 10: + dataChunkPools[0].Put((*[1 << 10]byte)(p)) + case 2 << 10: + dataChunkPools[1].Put((*[2 << 10]byte)(p)) + case 4 << 10: + dataChunkPools[2].Put((*[4 << 10]byte)(p)) + case 8 << 10: + dataChunkPools[3].Put((*[8 << 10]byte)(p)) + case 16 << 10: + dataChunkPools[4].Put((*[16 << 10]byte)(p)) + default: + panic(fmt.Sprintf("unexpected buffer len=%v", len(p))) } - panic(fmt.Sprintf("unexpected buffer len=%v", len(p))) } // dataBuffer is an io.ReadWriter backed by a list of data chunks. diff --git a/internal/http2/frame.go b/internal/http2/frame.go index 2bb88522..81b3c082 100644 --- a/internal/http2/frame.go +++ b/internal/http2/frame.go @@ -10,15 +10,16 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/imroc/req/v3/http2" - "github.com/imroc/req/v3/internal/dump" - "golang.org/x/net/http/httpguts" - "golang.org/x/net/http2/hpack" "io" "log" "net/http" "strings" "sync" + + "github.com/imroc/req/v3/http2" + "github.com/imroc/req/v3/internal/dump" + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http2/hpack" ) const frameHeaderLen = 9 @@ -1542,14 +1543,13 @@ func (mh *MetaHeadersFrame) checkPseudos() error { return nil } -func (h2f *Framer) maxHeaderStringLen() int { - v := h2f.maxHeaderListSize() - if uint32(int(v)) == v { - return int(v) +func (fr *Framer) maxHeaderStringLen() int { + v := int(fr.maxHeaderListSize()) + if v < 0 { + // If maxHeaderListSize overflows an int, use no limit (0). + return 0 } - // They had a crazy big number for MaxHeaderBytes anyway, - // so give them unlimited header lengths: - return 0 + return v } // readMetaFrame returns 0 or more CONTINUATION frames from fr and @@ -1562,7 +1562,7 @@ func (h2f *Framer) readMetaFrame(hf *HeadersFrame, dumps []*dump.Dumper) (*MetaH mh := &MetaHeadersFrame{ HeadersFrame: hf, } - var remainSize = h2f.maxHeaderListSize() + remainSize := h2f.maxHeaderListSize() var sawRegular bool var invalid error // pseudo header field errors diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 0007cdee..fd3eaf76 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -19,6 +19,7 @@ import ( "io/fs" "log" "math" + "math/bits" mathrand "math/rand" "net" "net/http" @@ -177,8 +178,7 @@ func (t *Transport) initConnPool() { // HTTP/2 server. type ClientConn struct { t *Transport - tconn net.Conn // usually TLSConn, except specialized impls - tconnClosed bool + tconn net.Conn // usually TLSConn, except specialized impls tlsState *tls.ConnectionState // nil only for specialized impls reused uint32 // whether conn is being reused; atomic singleUse bool // whether being used for a single http.Request @@ -410,11 +410,14 @@ func (t *Transport) RoundTripOnlyCachedConn(req *http.Request) (*http.Response, func authorityAddr(scheme string, authority string) (addr string) { host, port, err := net.SplitHostPort(authority) if err != nil { // authority didn't have a port + host = authority + port = "" + } + if port == "" { // authority's port was empty port = "443" if scheme == "http" { port = "80" } - host = authority } if a, err := idna.ToASCII(host); err == nil { host = a @@ -956,19 +959,10 @@ func (cc *ClientConn) closeConn() { cc.tconn.Close() } -// netConnWrapper is the interface to get underlying connection, which is -// introduced in go1.18 for *tls.Conn. -type netConnWrapper interface { - // NetConn returns the underlying connection that is wrapped by c. - // Note that writing to or reading from this connection directly will corrupt the - // TLS session. - NetConn() net.Conn -} - // A tls.Conn.Close can hang for a long time if the peer is unresponsive. // Try to shut it down more aggressively. func (cc *ClientConn) forceCloseConn() { - tc, ok := cc.tconn.(netConnWrapper) + tc, ok := cc.tconn.(*tls.Conn) if !ok { return } @@ -1218,22 +1212,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { cancelRequest := func(cs *clientStream, err error) error { cs.cc.mu.Lock() - cs.abortStreamLocked(err) bodyClosed := cs.reqBodyClosed - if cs.ID != 0 { - // This request may have failed because of a problem with the connection, - // or for some unrelated reason. (For example, the user might have canceled - // the request without waiting for a response.) Mark the connection as - // not reusable, since trying to reuse a dead connection is worse than - // unnecessarily creating a new one. - // - // If cs.ID is 0, then the request was never allocated a stream ID and - // whatever went wrong was unrelated to the connection. We might have - // timed out waiting for a stream slot when StrictMaxConcurrentStreams - // is set, for example, in which case retrying on a different connection - // will not help. - cs.cc.doNotReuse = true - } cs.cc.mu.Unlock() // Wait for the request body to be closed. // @@ -1268,11 +1247,14 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return handleResponseHeaders() default: waitDone() - return nil, cancelRequest(cs, cs.abortErr) + return nil, cs.abortErr } case <-ctx.Done(): - return nil, cancelRequest(cs, ctx.Err()) + err := ctx.Err() + cs.abortStream(err) + return nil, cancelRequest(cs, err) case <-cs.reqCancel: + cs.abortStream(common.ErrRequestCanceled) return nil, cancelRequest(cs, common.ErrRequestCanceled) } } @@ -1654,7 +1636,27 @@ func (cs *clientStream) frameScratchBufferLen(maxFrameSize int) int { return int(n) // doesn't truncate; max is 512K } -var bufPool sync.Pool // of *[]byte +// Seven bufPools manage different frame sizes. This helps to avoid scenarios where long-running +// streaming requests using small frame sizes occupy large buffers initially allocated for prior +// requests needing big buffers. The size ranges are as follows: +// {0 KB, 16 KB], {16 KB, 32 KB], {32 KB, 64 KB], {64 KB, 128 KB], {128 KB, 256 KB], +// {256 KB, 512 KB], {512 KB, infinity} +// In practice, the maximum scratch buffer size should not exceed 512 KB due to +// frameScratchBufferLen(maxFrameSize), thus the "infinity pool" should never be used. +// It exists mainly as a safety measure, for potential future increases in max buffer size. +var bufPools [7]sync.Pool // of *[]byte +func bufPoolIndex(size int) int { + if size <= 16384 { + return 0 + } + size -= 1 + bits := bits.Len(uint(size)) + index := bits - 14 + if index >= len(bufPools) { + return len(bufPools) - 1 + } + return index +} func (cs *clientStream) writeRequestBody(req *http.Request, dumps []*dump.Dumper) (err error) { cc := cs.cc @@ -1672,12 +1674,13 @@ func (cs *clientStream) writeRequestBody(req *http.Request, dumps []*dump.Dumper // Scratch buffer for reading into & writing from. scratchLen := cs.frameScratchBufferLen(maxFrameSize) var buf []byte - if bp, ok := bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen { - defer bufPool.Put(bp) + index := bufPoolIndex(scratchLen) + if bp, ok := bufPools[index].Get().(*[]byte); ok && len(*bp) >= scratchLen { + defer bufPools[index].Put(bp) buf = *bp } else { buf = make([]byte, scratchLen) - defer bufPool.Put(&buf) + defer bufPools[index].Put(&buf) } writeData := cc.fr.WriteData @@ -1854,6 +1857,9 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail if err != nil { return nil, err } + if !httpguts.ValidHostHeader(host) { + return nil, errors.New("http2: invalid Host header") + } var path string if req.Method != "CONNECT" { @@ -2964,6 +2970,15 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { fl = &cs.flow } if !fl.add(int32(f.Increment)) { + // For stream, the sender sends RST_STREAM with an error code of FLOW_CONTROL_ERROR + if cs != nil { + rl.endStreamError(cs, StreamError{ + StreamID: f.StreamID, + Code: ErrCodeFlowControl, + }) + return nil + } + return ConnectionError(ErrCodeFlowControl) } cc.cond.Broadcast() From f8dca3e3f96ad64afff43b504eb9910ceabcfe97 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Mar 2024 16:48:24 +0800 Subject: [PATCH 793/843] update github actions --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9d236dc9..665084ff 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: test: strategy: matrix: - go: [ '1.20.x', '1.21.x' ] + go: [ '1.21.x', '1.22.x' ] os: [ ubuntu-latest ] runs-on: ${{ matrix.os }} steps: @@ -30,4 +30,4 @@ jobs: with: go-version: ${{ matrix.go }} - name: Test - run: go test ./... -coverprofile=coverage.txt \ No newline at end of file + run: go test ./... -coverprofile=coverage.txt From c0a83bd7ea57d0d0674c82597d1b83cce2d1862e Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Mar 2024 16:58:14 +0800 Subject: [PATCH 794/843] remove tests from std net/http --- http_test.go | 64 - response_test.go | 73 - transfer_test.go | 364 --- transport_internal_test.go | 188 -- transport_test.go | 5996 ------------------------------------ 5 files changed, 6685 deletions(-) delete mode 100644 http_test.go delete mode 100644 response_test.go delete mode 100644 transfer_test.go delete mode 100644 transport_internal_test.go delete mode 100644 transport_test.go diff --git a/http_test.go b/http_test.go deleted file mode 100644 index 7dca3a45..00000000 --- a/http_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package req - -import ( - "reflect" - "testing" -) - -func TestForeachHeaderElement(t *testing.T) { - tests := []struct { - in string - want []string - }{ - {"Foo", []string{"Foo"}}, - {" Foo", []string{"Foo"}}, - {"Foo ", []string{"Foo"}}, - {" Foo ", []string{"Foo"}}, - - {"foo", []string{"foo"}}, - {"anY-cAsE", []string{"anY-cAsE"}}, - - {"", nil}, - {",,,, , ,, ,,, ,", nil}, - - {" Foo,Bar, Baz,lower,,Quux ", []string{"Foo", "Bar", "Baz", "lower", "Quux"}}, - } - for _, tt := range tests { - var got []string - foreachHeaderElement(tt.in, func(v string) { - got = append(got, v) - }) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("foreachHeaderElement(%q) = %q; want %q", tt.in, got, tt.want) - } - } -} - -func TestCleanHost(t *testing.T) { - tests := []struct { - in, want string - }{ - {"www.google.com", "www.google.com"}, - {"www.google.com foo", "www.google.com"}, - {"www.google.com/foo", "www.google.com"}, - {" first character is a space", ""}, - {"[1::6]:8080", "[1::6]:8080"}, - - // Punycode: - {"гофер.рф/foo", "xn--c1ae0ajs.xn--p1ai"}, - {"bücher.de", "xn--bcher-kva.de"}, - {"bücher.de:8080", "xn--bcher-kva.de:8080"}, - // Verify we convert to lowercase before punycode: - {"BÜCHER.de", "xn--bcher-kva.de"}, - {"BÜCHER.de:8080", "xn--bcher-kva.de:8080"}, - // Verify we normalize to NFC before punycode: - {"gophér.nfc", "xn--gophr-esa.nfc"}, // NFC input; no work needed - {"goph\u0065\u0301r.nfd", "xn--gophr-esa.nfd"}, // NFD input - } - for _, tt := range tests { - got := cleanHost(tt.in) - if tt.want != got { - t.Errorf("cleanHost(%q) = %q, want %q", tt.in, got, tt.want) - } - } -} diff --git a/response_test.go b/response_test.go deleted file mode 100644 index 7ec82376..00000000 --- a/response_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package req - -import ( - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/tests" - "net/http" - "testing" -) - -type User struct { - Name string `json:"name" xml:"name"` -} - -type Message struct { - Message string `json:"message"` -} - -func TestUnmarshalJson(t *testing.T) { - var user User - resp, err := tc().R().Get("/json") - assertSuccess(t, resp, err) - err = resp.UnmarshalJson(&user) - tests.AssertNoError(t, err) - tests.AssertEqual(t, "roc", user.Name) -} - -func TestUnmarshalXml(t *testing.T) { - var user User - resp, err := tc().R().Get("/xml") - assertSuccess(t, resp, err) - err = resp.UnmarshalXml(&user) - tests.AssertNoError(t, err) - tests.AssertEqual(t, "roc", user.Name) -} - -func TestUnmarshal(t *testing.T) { - var user User - resp, err := tc().R().Get("/xml") - assertSuccess(t, resp, err) - err = resp.Unmarshal(&user) - tests.AssertNoError(t, err) - tests.AssertEqual(t, "roc", user.Name) -} - -func TestResponseResult(t *testing.T) { - resp, _ := tc().R().SetResult(&User{}).Get("/json") - user, ok := resp.Result().(*User) - if !ok { - t.Fatal("Response.Result() should return *User") - } - tests.AssertEqual(t, "roc", user.Name) - - tests.AssertEqual(t, true, resp.TotalTime() > 0) - tests.AssertEqual(t, false, resp.ReceivedAt().IsZero()) -} - -func TestResponseError(t *testing.T) { - resp, _ := tc().R().SetError(&Message{}).Get("/json?error=yes") - msg, ok := resp.Error().(*Message) - if !ok { - t.Fatal("Response.Error() should return *Message") - } - tests.AssertEqual(t, "not allowed", msg.Message) -} - -func TestResponseWrap(t *testing.T) { - resp, err := tc().R().Get("/json") - assertSuccess(t, resp, err) - tests.AssertEqual(t, true, resp.GetStatusCode() == http.StatusOK) - tests.AssertEqual(t, true, resp.GetStatus() == "200 OK") - tests.AssertEqual(t, true, resp.GetHeader(header.ContentType) == header.JsonContentType) - tests.AssertEqual(t, true, len(resp.GetHeaderValues(header.ContentType)) == 1) -} diff --git a/transfer_test.go b/transfer_test.go deleted file mode 100644 index 0721aeed..00000000 --- a/transfer_test.go +++ /dev/null @@ -1,364 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package req - -import ( - "bufio" - "bytes" - "crypto/rand" - "fmt" - "io" - "net/http" - "os" - "reflect" - "strings" - "testing" -) - -func TestBodyReadBadTrailer(t *testing.T) { - b := &body{ - src: strings.NewReader("foobar"), - hdr: true, // force reading the trailer - r: bufio.NewReader(strings.NewReader("")), - } - buf := make([]byte, 7) - n, err := b.Read(buf[:3]) - got := string(buf[:n]) - if got != "foo" || err != nil { - t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err) - } - - n, err = b.Read(buf[:]) - got = string(buf[:n]) - if got != "bar" || err != nil { - t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err) - } - - n, err = b.Read(buf[:]) - got = string(buf[:n]) - if err == nil { - t.Errorf("final Read was successful (%q), expected error from trailer read", got) - } -} - -func TestFinalChunkedBodyReadEOF(t *testing.T) { - res, err := http.ReadResponse(bufio.NewReader(strings.NewReader( - "HTTP/1.1 200 OK\r\n"+ - "Transfer-Encoding: chunked\r\n"+ - "\r\n"+ - "0a\r\n"+ - "Body here\n\r\n"+ - "09\r\n"+ - "continued\r\n"+ - "0\r\n"+ - "\r\n")), nil) - if err != nil { - t.Fatal(err) - } - want := "Body here\ncontinued" - buf := make([]byte, len(want)) - n, err := res.Body.Read(buf) - if n != len(want) || err != io.EOF { - t.Logf("body = %#v", res.Body) - t.Errorf("Read = %v, %v; want %d, EOF", n, err, len(want)) - } - if string(buf) != want { - t.Errorf("buf = %q; want %q", buf, want) - } -} - -func TestDetectInMemoryReaders(t *testing.T) { - pr, _ := io.Pipe() - tests := []struct { - r io.Reader - want bool - }{ - {pr, false}, - - {bytes.NewReader(nil), true}, - {bytes.NewBuffer(nil), true}, - {strings.NewReader(""), true}, - - {io.NopCloser(pr), false}, - - {io.NopCloser(bytes.NewReader(nil)), true}, - {io.NopCloser(bytes.NewBuffer(nil)), true}, - {io.NopCloser(strings.NewReader("")), true}, - } - for i, tt := range tests { - got := isKnownInMemoryReader(tt.r) - if got != tt.want { - t.Errorf("%d: got = %v; want %v", i, got, tt.want) - } - } -} - -type mockTransferWriter struct { - CalledReader io.Reader - WriteCalled bool -} - -var _ io.ReaderFrom = (*mockTransferWriter)(nil) - -func (w *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) { - w.CalledReader = r - return io.Copy(io.Discard, r) -} - -func (w *mockTransferWriter) Write(p []byte) (int, error) { - w.WriteCalled = true - return io.Discard.Write(p) -} - -func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { - fileType := reflect.TypeOf(&os.File{}) - bufferType := reflect.TypeOf(&bytes.Buffer{}) - - nBytes := int64(1 << 10) - newFileFunc := func() (r io.Reader, done func(), err error) { - f, err := os.CreateTemp("", "net-http-newfilefunc") - if err != nil { - return nil, nil, err - } - - // Write some bytes to the file to enable reading. - if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { - return nil, nil, fmt.Errorf("failed to write data to file: %v", err) - } - if _, err := f.Seek(0, 0); err != nil { - return nil, nil, fmt.Errorf("failed to seek to front: %v", err) - } - - done = func() { - f.Close() - os.Remove(f.Name()) - } - - return f, done, nil - } - - newBufferFunc := func() (io.Reader, func(), error) { - return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil - } - - cases := []struct { - name string - bodyFunc func() (io.Reader, func(), error) - method string - contentLength int64 - transferEncoding []string - limitedReader bool - expectedReader reflect.Type - expectedWrite bool - }{ - { - name: "file, non-chunked, size set", - bodyFunc: newFileFunc, - method: "PUT", - contentLength: nBytes, - limitedReader: true, - expectedReader: fileType, - }, - { - name: "file, non-chunked, size set, nopCloser wrapped", - method: "PUT", - bodyFunc: func() (io.Reader, func(), error) { - r, cleanup, err := newFileFunc() - return io.NopCloser(r), cleanup, err - }, - contentLength: nBytes, - limitedReader: true, - expectedReader: fileType, - }, - { - name: "file, non-chunked, negative size", - method: "PUT", - bodyFunc: newFileFunc, - contentLength: -1, - expectedReader: fileType, - }, - { - name: "file, non-chunked, CONNECT, negative size", - method: "CONNECT", - bodyFunc: newFileFunc, - contentLength: -1, - expectedReader: fileType, - }, - { - name: "file, chunked", - method: "PUT", - bodyFunc: newFileFunc, - transferEncoding: []string{"chunked"}, - expectedWrite: true, - }, - { - name: "buffer, non-chunked, size set", - bodyFunc: newBufferFunc, - method: "PUT", - contentLength: nBytes, - limitedReader: true, - expectedReader: bufferType, - }, - { - name: "buffer, non-chunked, size set, nopCloser wrapped", - method: "PUT", - bodyFunc: func() (io.Reader, func(), error) { - r, cleanup, err := newBufferFunc() - return io.NopCloser(r), cleanup, err - }, - contentLength: nBytes, - limitedReader: true, - expectedReader: bufferType, - }, - { - name: "buffer, non-chunked, negative size", - method: "PUT", - bodyFunc: newBufferFunc, - contentLength: -1, - expectedWrite: true, - }, - { - name: "buffer, non-chunked, CONNECT, negative size", - method: "CONNECT", - bodyFunc: newBufferFunc, - contentLength: -1, - expectedWrite: true, - }, - { - name: "buffer, chunked", - method: "PUT", - bodyFunc: newBufferFunc, - transferEncoding: []string{"chunked"}, - expectedWrite: true, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - body, cleanup, err := tc.bodyFunc() - if err != nil { - t.Fatal(err) - } - defer cleanup() - - mw := &mockTransferWriter{} - tw := &transferWriter{ - Body: body, - ContentLength: tc.contentLength, - TransferEncoding: tc.transferEncoding, - } - - if err := tw.writeBody(mw, nil); err != nil { - t.Fatal(err) - } - - if tc.expectedReader != nil { - if mw.CalledReader == nil { - t.Fatal("did not call ReadFrom") - } - - var actualReader reflect.Type - lr, ok := mw.CalledReader.(*io.LimitedReader) - if ok && tc.limitedReader { - actualReader = reflect.TypeOf(lr.R) - } else { - actualReader = reflect.TypeOf(mw.CalledReader) - } - - if tc.expectedReader != actualReader { - t.Fatalf("got reader %s want %s", actualReader, tc.expectedReader) - } - } - - if tc.expectedWrite && !mw.WriteCalled { - t.Fatal("did not invoke Write") - } - }) - } -} - -func TestParseTransferEncoding(t *testing.T) { - tests := []struct { - hdr http.Header - wantErr error - }{ - { - hdr: http.Header{"Transfer-Encoding": {"fugazi"}}, - wantErr: &unsupportedTEError{`unsupported transfer encoding: "fugazi"`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {"chunked, chunked", "identity", "chunked"}}, - wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked, chunked" "identity" "chunked"]`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {""}}, - wantErr: &unsupportedTEError{`unsupported transfer encoding: ""`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {"chunked, identity"}}, - wantErr: &unsupportedTEError{`unsupported transfer encoding: "chunked, identity"`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {"chunked", "identity"}}, - wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked" "identity"]`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {"\x0bchunked"}}, - wantErr: &unsupportedTEError{`unsupported transfer encoding: "\vchunked"`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {"chunked"}}, - wantErr: nil, - }, - } - - for i, tt := range tests { - tr := &transferReader{ - Header: tt.hdr, - ProtoMajor: 1, - ProtoMinor: 1, - } - gotErr := tr.parseTransferEncoding() - if !reflect.DeepEqual(gotErr, tt.wantErr) { - t.Errorf("%d.\ngot error:\n%v\nwant error:\n%v\n\n", i, gotErr, tt.wantErr) - } - } -} - -// issue 39017 - disallow Content-Length values such as "+3" -func TestParseContentLength(t *testing.T) { - tests := []struct { - cl string - wantErr error - }{ - { - cl: "3", - wantErr: nil, - }, - { - cl: "+3", - wantErr: badStringError("bad Content-Length", "+3"), - }, - { - cl: "-3", - wantErr: badStringError("bad Content-Length", "-3"), - }, - { - // max int64, for safe conversion before returning - cl: "9223372036854775807", - wantErr: nil, - }, - { - cl: "9223372036854775808", - wantErr: badStringError("bad Content-Length", "9223372036854775808"), - }, - } - - for _, tt := range tests { - if _, gotErr := parseContentLength(tt.cl); !reflect.DeepEqual(gotErr, tt.wantErr) { - t.Errorf("%q:\n\tgot=%v\n\twant=%v", tt.cl, gotErr, tt.wantErr) - } - } -} diff --git a/transport_internal_test.go b/transport_internal_test.go deleted file mode 100644 index 91bea4cd..00000000 --- a/transport_internal_test.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// White-box tests for transport.go (in package http instead of http_test). - -package req - -import ( - "context" - "errors" - "github.com/imroc/req/v3/internal/http2" - "github.com/imroc/req/v3/internal/tests" - "net" - "net/http" - "strings" - "testing" -) - -func withT(r *http.Request, t *testing.T) *http.Request { - return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) -} - -// Issue 15446: incorrect wrapping of errors when server closes an idle connection. -func TestTransportPersistConnReadLoopEOF(t *testing.T) { - ln := tests.NewLocalListener(t) - defer ln.Close() - - connc := make(chan net.Conn, 1) - go func() { - defer close(connc) - c, err := ln.Accept() - if err != nil { - t.Error(err) - return - } - connc <- c - }() - - tr := new(Transport) - req, _ := http.NewRequest("GET", "http://"+ln.Addr().String(), nil) - req = withT(req, t) - treq := &transportRequest{Request: req} - cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()} - pc, err := tr.getConn(treq, cm) - if err != nil { - t.Fatal(err) - } - defer pc.close(errors.New("test over")) - - conn := <-connc - if conn == nil { - // Already called t.Error in the accept goroutine. - return - } - conn.Close() // simulate the server hanging up on the client - - _, err = pc.roundTrip(treq) - if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle { - t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle, transportReadFromServerError, or nothingWrittenError", err, err) - } - - <-pc.closech - err = pc.closed - if !isTransportReadFromServerError(err) && err != errServerClosedIdle { - t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err) - } -} - -func isNothingWrittenError(err error) bool { - _, ok := err.(nothingWrittenError) - return ok -} - -func isTransportReadFromServerError(err error) bool { - _, ok := err.(transportReadFromServerError) - return ok -} - -func dummyRequest(method string) *http.Request { - req, err := http.NewRequest(method, "http://fake.tld/", nil) - if err != nil { - panic(err) - } - return req -} -func dummyRequestWithBody(method string) *http.Request { - req, err := http.NewRequest(method, "http://fake.tld/", strings.NewReader("foo")) - if err != nil { - panic(err) - } - return req -} - -func dummyRequestWithBodyNoGetBody(method string) *http.Request { - req := dummyRequestWithBody(method) - req.GetBody = nil - return req -} - -// issue22091Error acts like a golang.org/x/net/http2.ErrNoCachedConn. -type issue22091Error struct{} - -func (issue22091Error) IsHTTP2NoCachedConnError() {} -func (issue22091Error) Error() string { return "issue22091Error" } - -func TestTransportShouldRetryRequest(t *testing.T) { - tests := []struct { - pc *persistConn - req *http.Request - - err error - want bool - }{ - 0: { - pc: &persistConn{reused: false}, - req: dummyRequest("POST"), - err: nothingWrittenError{}, - want: false, - }, - 1: { - pc: &persistConn{reused: true}, - req: dummyRequest("POST"), - err: nothingWrittenError{}, - want: true, - }, - 2: { - pc: &persistConn{reused: true}, - req: dummyRequest("POST"), - err: http2.ErrNoCachedConn, - want: true, - }, - 3: { - pc: nil, - req: nil, - err: issue22091Error{}, // like an external http2ErrNoCachedConn - want: true, - }, - 4: { - pc: &persistConn{reused: true}, - req: dummyRequest("POST"), - err: errMissingHost, - want: false, - }, - 5: { - pc: &persistConn{reused: true}, - req: dummyRequest("POST"), - err: transportReadFromServerError{}, - want: false, - }, - 6: { - pc: &persistConn{reused: true}, - req: dummyRequest("GET"), - err: transportReadFromServerError{}, - want: true, - }, - 7: { - pc: &persistConn{reused: true}, - req: dummyRequest("GET"), - err: errServerClosedIdle, - want: true, - }, - 8: { - pc: &persistConn{reused: true}, - req: dummyRequestWithBody("POST"), - err: nothingWrittenError{}, - want: true, - }, - 9: { - pc: &persistConn{reused: true}, - req: dummyRequestWithBodyNoGetBody("POST"), - err: nothingWrittenError{}, - want: false, - }, - } - for i, tt := range tests { - got := tt.pc.shouldRetryRequest(tt.req, tt.err) - if got != tt.want { - t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want) - } - } -} - -type roundTripFunc func(r *http.Request) (*http.Response, error) - -func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { - return f(r) -} diff --git a/transport_test.go b/transport_test.go deleted file mode 100644 index 1858540b..00000000 --- a/transport_test.go +++ /dev/null @@ -1,5996 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Tests for transport.go. -// -// More tests are in clientserver_test.go (for things testing both client & server for both -// HTTP/1 and HTTP/2). This - -package req - -import ( - "bufio" - "bytes" - "compress/gzip" - "context" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "encoding/binary" - "errors" - "fmt" - "github.com/imroc/req/v3/internal/common" - "github.com/imroc/req/v3/internal/tests" - "github.com/imroc/req/v3/internal/transport" - "go/token" - "golang.org/x/net/http/httpproxy" - nethttp2 "golang.org/x/net/http2" - "io" - "log" - mrand "math/rand" - "net" - "net/http" - "net/http/httptest" - "net/http/httptrace" - "net/http/httputil" - "net/textproto" - "net/url" - "os" - "reflect" - "runtime" - "sort" - "strconv" - "strings" - "sync" - "sync/atomic" - "testing" - "testing/iotest" - "time" - - "golang.org/x/net/http/httpguts" -) - -func (t *Transport) NumPendingRequestsForTesting() int { - t.reqMu.Lock() - defer t.reqMu.Unlock() - return len(t.reqCanceler) -} - -func (t *Transport) IdleConnKeysForTesting() (keys []string) { - keys = make([]string, 0) - t.idleMu.Lock() - defer t.idleMu.Unlock() - for key := range t.idleConn { - keys = append(keys, key.String()) - } - sort.Strings(keys) - return -} - -func (t *Transport) IdleConnKeyCountForTesting() int { - t.idleMu.Lock() - defer t.idleMu.Unlock() - return len(t.idleConn) -} - -func (t *Transport) IdleConnStrsForTesting() []string { - var ret []string - t.idleMu.Lock() - defer t.idleMu.Unlock() - for _, conns := range t.idleConn { - for _, pc := range conns { - ret = append(ret, pc.conn.LocalAddr().String()+"/"+pc.conn.RemoteAddr().String()) - } - } - sort.Strings(ret) - return ret -} - -func (t *Transport) IdleConnCountForTesting(scheme, addr string) int { - t.idleMu.Lock() - defer t.idleMu.Unlock() - key := connectMethodKey{"", scheme, addr, false} - cacheKey := key.String() - for k, conns := range t.idleConn { - if k.String() == cacheKey { - return len(conns) - } - } - return 0 -} - -func (t *Transport) IdleConnWaitMapSizeForTesting() int { - t.idleMu.Lock() - defer t.idleMu.Unlock() - return len(t.idleConnWait) -} - -func (t *Transport) IsIdleForTesting() bool { - t.idleMu.Lock() - defer t.idleMu.Unlock() - return t.closeIdle -} - -func (t *Transport) QueueForIdleConnForTesting() { - t.queueForIdleConn(nil) -} - -// PutIdleTestConn reports whether it was able to insert a fresh -// persistConn for scheme, addr into the idle connection pool. -func (t *Transport) PutIdleTestConn(scheme, addr string) bool { - c, _ := net.Pipe() - key := connectMethodKey{"", scheme, addr, false} - - if t.MaxConnsPerHost > 0 { - // Transport is tracking conns-per-host. - // Increment connection count to account - // for new persistConn created below. - t.connsPerHostMu.Lock() - if t.connsPerHost == nil { - t.connsPerHost = make(map[connectMethodKey]int) - } - t.connsPerHost[key]++ - t.connsPerHostMu.Unlock() - } - - return t.tryPutIdleConn(&persistConn{ - t: t, - conn: c, // dummy - closech: make(chan struct{}), // so it can be closed - cacheKey: key, - }) == nil -} - -// PutIdleTestConnH2 reports whether it was able to insert a fresh -// HTTP/2 persistConn for scheme, addr into the idle connection pool. -func (t *Transport) PutIdleTestConnH2(scheme, addr string, alt http.RoundTripper) bool { - key := connectMethodKey{"", scheme, addr, false} - - if t.MaxConnsPerHost > 0 { - // Transport is tracking conns-per-host. - // Increment connection count to account - // for new persistConn created below. - t.connsPerHostMu.Lock() - if t.connsPerHost == nil { - t.connsPerHost = make(map[connectMethodKey]int) - } - t.connsPerHost[key]++ - t.connsPerHostMu.Unlock() - } - - return t.tryPutIdleConn(&persistConn{ - t: t, - alt: alt, - cacheKey: key, - }) == nil -} - -// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close -// and then verify that the final 2 responses get errors back. - -// hostPortHandler writes back the client's "host:port". -var hostPortHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.FormValue("close") == "true" { - w.Header().Set("Connection", "close") - } - w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close)) - w.Write([]byte(r.RemoteAddr)) -}) - -// testCloseConn is a net.Conn tracked by a testConnSet. -type testCloseConn struct { - net.Conn - set *testConnSet -} - -func (c *testCloseConn) Close() error { - c.set.remove(c) - return c.Conn.Close() -} - -// testConnSet tracks a set of TCP connections and whether they've -// been closed. -type testConnSet struct { - t *testing.T - mu sync.Mutex // guards closed and list - closed map[net.Conn]bool - list []net.Conn // in order created -} - -func (tcs *testConnSet) insert(c net.Conn) { - tcs.mu.Lock() - defer tcs.mu.Unlock() - tcs.closed[c] = false - tcs.list = append(tcs.list, c) -} - -func (tcs *testConnSet) remove(c net.Conn) { - tcs.mu.Lock() - defer tcs.mu.Unlock() - tcs.closed[c] = true -} - -// some tests use this to manage raw tcp connections for later inspection -func makeTestDial(t *testing.T) (*testConnSet, func(ctx context.Context, n, addr string) (net.Conn, error)) { - connSet := &testConnSet{ - t: t, - closed: make(map[net.Conn]bool), - } - dial := func(_ context.Context, n, addr string) (net.Conn, error) { - c, err := net.Dial(n, addr) - if err != nil { - return nil, err - } - tc := &testCloseConn{c, connSet} - connSet.insert(tc) - return tc, nil - } - return connSet, dial -} - -func (tcs *testConnSet) check(t *testing.T) { - tcs.mu.Lock() - defer tcs.mu.Unlock() - for i := 4; i >= 0; i-- { - for i, c := range tcs.list { - if tcs.closed[c] { - continue - } - if i != 0 { - tcs.mu.Unlock() - time.Sleep(50 * time.Millisecond) - tcs.mu.Lock() - continue - } - t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) - } - } -} - -func TestReuseRequest(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("{}")) - })) - defer ts.Close() - - c := tc().httpClient - req, _ := http.NewRequest("GET", ts.URL, nil) - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - err = res.Body.Close() - if err != nil { - t.Fatal(err) - } - - res, err = c.Do(req) - if err != nil { - t.Fatal(err) - } - err = res.Body.Close() - if err != nil { - t.Fatal(err) - } -} - -// Two subsequent requests and verify their response is the same. -// The response from the server is our own IP:port -func TestTransportKeepAlives(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - - c := tc().httpClient - for _, disableKeepAlive := range []bool{false, true} { - c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive - fetch := func(n int) string { - res, err := c.Get(ts.URL) - if err != nil { - t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) - } - body, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) - } - return string(body) - } - - body1 := fetch(1) - body2 := fetch(2) - - bodiesDiffer := body1 != body2 - if bodiesDiffer != disableKeepAlive { - t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", - disableKeepAlive, bodiesDiffer, body1, body2) - } - } -} - -func interestingGoroutines() (gs []string) { - buf := make([]byte, 2<<20) - buf = buf[:runtime.Stack(buf, true)] - for _, g := range strings.Split(string(buf), "\n\n") { - sl := strings.SplitN(g, "\n", 2) - if len(sl) != 2 { - continue - } - stack := strings.TrimSpace(sl[1]) - if stack == "" || - strings.Contains(stack, "testing.(*M).before.func1") || - strings.Contains(stack, "os/signal.signal_recv") || - strings.Contains(stack, "created by net.startServer") || - strings.Contains(stack, "created by testing.RunTests") || - strings.Contains(stack, "closeWriteAndWait") || - strings.Contains(stack, "testing.Main(") || - // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) - strings.Contains(stack, "runtime.goexit") || - strings.Contains(stack, "created by runtime.gc") || - strings.Contains(stack, "net/http_test.interestingGoroutines") || - strings.Contains(stack, "runtime.MHeap_Scavenger") { - continue - } - gs = append(gs, stack) - } - sort.Strings(gs) - return -} - -func afterTest(t testing.TB) { - http.DefaultTransport.(*http.Transport).CloseIdleConnections() - if testing.Short() { - return - } - // var bad string - // badSubstring := map[string]string{ - // ").readLoop(": "a Transport", - // ").writeLoop(": "a Transport", - // "created by net/http/httptest.(*Server).Start": "an httptest.Server", - // "timeoutHandler": "a TimeoutHandler", - // "net.(*netFD).connect(": "a timing out dial", - // ").noteClientGone(": "a closenotifier sender", - // } - // var stacks string - // for i := 0; i < 10; i++ { - // bad = "" - // stacks = strings.Join(interestingGoroutines(), "\n\n") - // for substr, what := range badSubstring { - // if strings.Contains(stacks, substr) { - // bad = what - // } - // } - // if bad == "" { - // return - // } - // // Bad stuff found, but goroutines might just still be - // // shutting down, so give it some time. - // time.Sleep(250 * time.Millisecond) - // } - // t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks) -} - -func TestTransportConnectionCloseOnResponse(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - - connSet, testDial := makeTestDial(t) - - c := tc().httpClient - tr := c.Transport.(*Transport) - tr.DialContext = testDial - - for _, connectionClose := range []bool{false, true} { - fetch := func(n int) string { - req := new(http.Request) - var err error - req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose)) - if err != nil { - t.Fatalf("URL parse error: %v", err) - } - req.Method = "GET" - req.Proto = "HTTP/1.1" - req.ProtoMajor = 1 - req.ProtoMinor = 1 - - res, err := c.Do(req) - if err != nil { - t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) - } - defer res.Body.Close() - body, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) - } - return string(body) - } - - body1 := fetch(1) - body2 := fetch(2) - bodiesDiffer := body1 != body2 - if bodiesDiffer != connectionClose { - t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", - connectionClose, bodiesDiffer, body1, body2) - } - - tr.CloseIdleConnections() - } - - connSet.check(t) -} - -// TestTransportConnectionCloseOnRequest tests that the Transport's doesn't reuse -// an underlying TCP connection after making an http.Request with Request.Close set. -// -// It tests the behavior by making an HTTP request to a server which -// describes the source source connection it got (remote port number + -// address of its net.Conn) -func TestTransportConnectionCloseOnRequest(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - - connSet, testDial := makeTestDial(t) - - c := tc().httpClient - tr := c.Transport.(*Transport) - tr.DialContext = testDial - for _, reqClose := range []bool{false, true} { - fetch := func(n int) string { - req := new(http.Request) - var err error - req.URL, err = url.Parse(ts.URL) - if err != nil { - t.Fatalf("URL parse error: %v", err) - } - req.Method = "GET" - req.Proto = "HTTP/1.1" - req.ProtoMajor = 1 - req.ProtoMinor = 1 - req.Close = reqClose - - res, err := c.Do(req) - if err != nil { - t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err) - } - if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want { - t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v", - reqClose, got, !reqClose) - } - body, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err) - } - return string(body) - } - - body1 := fetch(1) - body2 := fetch(2) - - got := 1 - if body1 != body2 { - got++ - } - want := 1 - if reqClose { - want = 2 - } - if got != want { - t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q", - reqClose, got, want, body1, body2) - } - - tr.CloseIdleConnections() - } - - connSet.check(t) -} - -// if the Transport's DisableKeepAlives is set, all requests should -// send Connection: close. -// HTTP/1-only (Connection: close doesn't exist in h2) -func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - - c := tc().httpClient - c.Transport.(*Transport).DisableKeepAlives = true - - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - if res.Header.Get("X-Saw-Close") != "true" { - t.Errorf("handler didn't see Connection: close ") - } -} - -// Test that Transport only sends one "Connection: close", regardless of -// how "close" was indicated. -func TestTransportRespectRequestWantsClose(t *testing.T) { - tests := []struct { - disableKeepAlives bool - close bool - }{ - {disableKeepAlives: false, close: false}, - {disableKeepAlives: false, close: true}, - {disableKeepAlives: true, close: false}, - {disableKeepAlives: true, close: true}, - } - - for _, testCase := range tests { - t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", testCase.disableKeepAlives, testCase.close), - func(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - - c := tc().httpClient - c.Transport.(*Transport).DisableKeepAlives = testCase.disableKeepAlives - req, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - count := 0 - trace := &httptrace.ClientTrace{ - WroteHeaderField: func(key string, field []string) { - if key != "Connection" { - return - } - if httpguts.HeaderValuesContainsToken(field, "close") { - count += 1 - } - }, - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - req.Close = testCase.close - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - if want := testCase.disableKeepAlives || testCase.close; count > 1 || (count == 1) != want { - t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count) - } - }) - } - -} - -func TestTransportIdleCacheKeys(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - c := tc().httpClient - tr := c.Transport.(*Transport) - - if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { - t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) - } - - resp, err := c.Get(ts.URL) - if err != nil { - t.Error(err) - } - io.ReadAll(resp.Body) - - keys := tr.IdleConnKeysForTesting() - if e, g := 1, len(keys); e != g { - t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g) - } - - if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e { - t.Errorf("Expected idle cache key %q; got %q", e, keys[0]) - } - - tr.CloseIdleConnections() - if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { - t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) - } -} - -// Tests that the HTTP transport re-uses connections when a client -// reads to the end of a response Body without closing it. -func TestTransportReadToEndReusesConn(t *testing.T) { - defer afterTest(t) - const msg = "foobar" - - var addrSeen map[string]int - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - addrSeen[r.RemoteAddr]++ - if r.URL.Path == "/chunked/" { - w.WriteHeader(200) - w.(http.Flusher).Flush() - } else { - w.Header().Set("Content-Length", strconv.Itoa(len(msg))) - w.WriteHeader(200) - } - w.Write([]byte(msg)) - })) - defer ts.Close() - - buf := make([]byte, len(msg)) - - for pi, path := range []string{"/content-length/", "/chunked/"} { - wantLen := []int{len(msg), -1}[pi] - addrSeen = make(map[string]int) - for i := 0; i < 3; i++ { - res, err := http.Get(ts.URL + path) - if err != nil { - t.Errorf("Get %s: %v", path, err) - continue - } - // We want to close this body eventually (before the - // defer afterTest at top runs), but not before the - // len(addrSeen) check at the bottom of this test, - // since Closing this early in the loop would risk - // making connections be re-used for the wrong reason. - defer res.Body.Close() - - if res.ContentLength != int64(wantLen) { - t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) - } - n, err := res.Body.Read(buf) - if n != len(msg) || err != io.EOF { - t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg)) - } - } - if len(addrSeen) != 1 { - t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen)) - } - } -} - -func TestTransportMaxPerHostIdleConns(t *testing.T) { - defer afterTest(t) - stop := make(chan struct{}) // stop marks the exit of main Test goroutine - defer close(stop) - - resch := make(chan string) - gotReq := make(chan bool) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotReq <- true - var msg string - select { - case <-stop: - return - case msg = <-resch: - } - _, err := w.Write([]byte(msg)) - if err != nil { - t.Errorf("Write: %v", err) - return - } - })) - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - maxIdleConnsPerHost := 2 - tr.MaxIdleConnsPerHost = maxIdleConnsPerHost - - // Start 3 outstanding requests and wait for the server to get them. - // Their responses will hang until we write to resch, though. - donech := make(chan bool) - doReq := func() { - defer func() { - select { - case <-stop: - return - case donech <- t.Failed(): - } - }() - resp, err := c.Get(ts.URL) - if err != nil { - t.Error(err) - return - } - if _, err := io.ReadAll(resp.Body); err != nil { - t.Errorf("ReadAll: %v", err) - return - } - } - go doReq() - <-gotReq - go doReq() - <-gotReq - go doReq() - <-gotReq - - if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { - t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) - } - - resch <- "res1" - <-donech - keys := tr.IdleConnKeysForTesting() - if e, g := 1, len(keys); e != g { - t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g) - } - addr := ts.Listener.Addr().String() - cacheKey := "|http|" + addr - if keys[0] != cacheKey { - t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0]) - } - if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g { - t.Errorf("after first response, expected %d idle conns; got %d", e, g) - } - - resch <- "res2" - <-donech - if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w { - t.Errorf("after second response, idle conns = %d; want %d", g, w) - } - - resch <- "res3" - <-donech - if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w { - t.Errorf("after third response, idle conns = %d; want %d", g, w) - } -} - -func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("foo")) - if err != nil { - t.Fatalf("Write: %v", err) - } - })) - defer ts.Close() - c := tc().httpClient - tr := c.Transport.(*Transport) - dialStarted := make(chan struct{}) - stallDial := make(chan struct{}) - tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { - dialStarted <- struct{}{} - <-stallDial - return net.Dial(network, addr) - } - - tr.DisableKeepAlives = true - tr.MaxConnsPerHost = 1 - - preDial := make(chan struct{}) - reqComplete := make(chan struct{}) - doReq := func(reqId string) { - req, _ := http.NewRequest("GET", ts.URL, nil) - trace := &httptrace.ClientTrace{ - GetConn: func(hostPort string) { - preDial <- struct{}{} - }, - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - resp, err := tr.RoundTrip(req) - if err != nil { - t.Errorf("unexpected error for request %s: %v", reqId, err) - } - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Errorf("unexpected error for request %s: %v", reqId, err) - } - reqComplete <- struct{}{} - } - // get req1 to dial-in-progress - go doReq("req1") - <-preDial - <-dialStarted - - // get req2 to waiting on conns per host to go down below max - go doReq("req2") - <-preDial - select { - case <-dialStarted: - t.Error("req2 dial started while req1 dial in progress") - return - default: - } - - // let req1 complete - stallDial <- struct{}{} - <-reqComplete - - // let req2 complete - <-dialStarted - stallDial <- struct{}{} - <-reqComplete -} - -func TestTransportMaxConnsPerHost(t *testing.T) { - defer afterTest(t) - - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("foo")) - if err != nil { - t.Fatalf("Write: %v", err) - } - }) - - testMaxConns := func(scheme string, ts *httptest.Server) { - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - tr.MaxConnsPerHost = 1 - - mu := sync.Mutex{} - var conns []net.Conn - var dialCnt, gotConnCnt, tlsHandshakeCnt int32 - tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { - atomic.AddInt32(&dialCnt, 1) - c, err := net.Dial(network, addr) - mu.Lock() - defer mu.Unlock() - conns = append(conns, c) - return c, err - } - - doReq := func() { - trace := &httptrace.ClientTrace{ - GotConn: func(connInfo httptrace.GotConnInfo) { - if !connInfo.Reused { - atomic.AddInt32(&gotConnCnt, 1) - } - }, - TLSHandshakeStart: func() { - atomic.AddInt32(&tlsHandshakeCnt, 1) - }, - } - req, _ := http.NewRequest("GET", ts.URL, nil) - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - - resp, err := c.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read body failed: %v", err) - } - } - - wg := sync.WaitGroup{} - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - doReq() - }() - } - wg.Wait() - - expected := int32(tr.MaxConnsPerHost) - if dialCnt != expected { - t.Errorf("round 1: too many dials (%s): %d != %d", scheme, dialCnt, expected) - } - if gotConnCnt != expected { - t.Errorf("round 1: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) - } - if ts.TLS != nil && tlsHandshakeCnt != expected { - t.Errorf("round 1: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) - } - - if t.Failed() { - t.FailNow() - } - - mu.Lock() - for _, c := range conns { - c.Close() - } - conns = nil - mu.Unlock() - tr.CloseIdleConnections() - - doReq() - expected++ - if dialCnt != expected { - t.Errorf("round 2: too many dials (%s): %d", scheme, dialCnt) - } - if gotConnCnt != expected { - t.Errorf("round 2: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) - } - if ts.TLS != nil && tlsHandshakeCnt != expected { - t.Errorf("round 2: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) - } - } - - testMaxConns("http", httptest.NewServer(h)) - testMaxConns("https", httptest.NewTLSServer(h)) - - ts := httptest.NewUnstartedServer(h) - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - testMaxConns("http2", ts) -} - -func TestTransportRemovesDeadIdleConnections(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, r.RemoteAddr) - })) - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - - doReq := func(name string) string { - // Do a POST instead of a GET to prevent the Transport's - // idempotent request retry logic from kicking in... - res, err := c.Post(ts.URL, "", nil) - if err != nil { - t.Fatalf("%s: %v", name, err) - } - if res.StatusCode != 200 { - t.Fatalf("%s: %v", name, res.Status) - } - defer res.Body.Close() - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("%s: %v", name, err) - } - return string(slurp) - } - - first := doReq("first") - keys1 := tr.IdleConnKeysForTesting() - - ts.CloseClientConnections() - - var keys2 []string - if !tests.WaitCondition(3*time.Second, 50*time.Millisecond, func() bool { - keys2 = tr.IdleConnKeysForTesting() - return len(keys2) == 0 - }) { - t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keys1, keys2) - } - - second := doReq("second") - if first == second { - t.Errorf("expected a different connection between requests. got %q both times", first) - } -} - -// ExportCloseTransportConnsAbruptly closes all idle connections from -// tr in an abrupt way, just reaching into the underlying Conns and -// closing them, without telling the Transport or its persistConns -// that it's doing so. This is to simulate the server closing connections -// on the Transport. -func ExportCloseTransportConnsAbruptly(tr *Transport) { - tr.idleMu.Lock() - for _, pcs := range tr.idleConn { - for _, pc := range pcs { - pc.conn.Close() - } - } - tr.idleMu.Unlock() -} - -// Test that the Transport notices when a server hangs up on its -// unexpectedly (a keep-alive connection is closed). -func TestTransportServerClosingUnexpectedly(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - c := tc().httpClient - - fetch := func(n, retries int) string { - condFatalf := func(format string, arg ...interface{}) { - if retries <= 0 { - t.Fatalf(format, arg...) - } - t.Logf("retrying shortly after expected error: "+format, arg...) - time.Sleep(time.Second / time.Duration(retries)) - } - for retries >= 0 { - retries-- - res, err := c.Get(ts.URL) - if err != nil { - condFatalf("error in req #%d, GET: %v", n, err) - continue - } - body, err := io.ReadAll(res.Body) - if err != nil { - condFatalf("error in req #%d, ReadAll: %v", n, err) - continue - } - res.Body.Close() - return string(body) - } - panic("unreachable") - } - - body1 := fetch(1, 0) - body2 := fetch(2, 0) - - // Close all the idle connections in a way that's similar to - // the server hanging up on us. We don't use - // httptest.Server.CloseClientConnections because it's - // best-effort and stops blocking after 5 seconds. On a loaded - // machine running many tests concurrently it's possible for - // that method to be async and cause the body3 fetch below to - // run on an old connection. This function is synchronous. - ExportCloseTransportConnsAbruptly(c.Transport.(*Transport)) - - body3 := fetch(3, 5) - - if body1 != body2 { - t.Errorf("expected body1 and body2 to be equal") - } - if body2 == body3 { - t.Errorf("expected body2 and body3 to be different") - } -} - -// Test for https://golang.org/issue/2616 (appropriate issue number) -// This fails pretty reliably with GOMAXPROCS=100 or something high. -func TestStressSurpriseServerCloses(t *testing.T) { - defer afterTest(t) - if testing.Short() { - t.Skip("skipping test in short mode") - } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", "5") - w.Header().Set("Content-Type", "text/plain") - w.Write([]byte("Hello")) - w.(http.Flusher).Flush() - conn, buf, _ := w.(http.Hijacker).Hijack() - buf.Flush() - conn.Close() - })) - defer ts.Close() - c := tc().httpClient - - // Do a bunch of traffic from different goroutines. Send to activityc - // after each request completes, regardless of whether it failed. - // If these are too high, OS X exhausts its ephemeral ports - // and hangs waiting for them to transition TCP states. That's - // not what we want to test. TODO(bradfitz): use an io.Pipe - // dialer for this test instead? - const ( - numClients = 20 - reqsPerClient = 25 - ) - activityc := make(chan bool) - for i := 0; i < numClients; i++ { - go func() { - for i := 0; i < reqsPerClient; i++ { - res, err := c.Get(ts.URL) - if err == nil { - // We expect errors since the server is - // hanging up on us after telling us to - // send more requests, so we don't - // actually care what the error is. - // But we want to close the body in cases - // where we won the race. - res.Body.Close() - } - if !<-activityc { // Receives false when close(activityc) is executed - return - } - } - }() - } - - // Make sure all the request come back, one way or another. - for i := 0; i < numClients*reqsPerClient; i++ { - select { - case activityc <- true: - case <-time.After(5 * time.Second): - close(activityc) - t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile") - } - } -} - -// TestTransportHeadResponses verifies that we deal with Content-Lengths -// with no bodies properly -func TestTransportHeadResponses(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "HEAD" { - panic("expected HEAD; got " + r.Method) - } - w.Header().Set("Content-Length", "123") - w.WriteHeader(200) - })) - defer ts.Close() - c := tc().httpClient - - for i := 0; i < 2; i++ { - res, err := c.Head(ts.URL) - if err != nil { - t.Errorf("error on loop %d: %v", i, err) - continue - } - if e, g := "123", res.Header.Get("Content-Length"); e != g { - t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) - } - if e, g := int64(123), res.ContentLength; e != g { - t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) - } - if all, err := io.ReadAll(res.Body); err != nil { - t.Errorf("loop %d: Body ReadAll: %v", i, err) - } else if len(all) != 0 { - t.Errorf("Bogus body %q", all) - } - } -} - -// All test hooks must be non-nil so they can be called directly, -// but the tests use nil to mean hook disabled. -func unnilTestHook(f *func()) { - if *f == nil { - *f = nop - } -} - -func SetReadLoopBeforeNextReadHook(f func()) { - testHookMu.Lock() - defer testHookMu.Unlock() - unnilTestHook(&f) - testHookReadLoopBeforeNextRead = f -} - -// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding -// on responses to HEAD requests. -func TestTransportHeadChunkedResponse(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "HEAD" { - panic("expected HEAD; got " + r.Method) - } - w.Header().Set("Transfer-Encoding", "chunked") // client should ignore - w.Header().Set("x-client-ipport", r.RemoteAddr) - w.WriteHeader(200) - })) - defer ts.Close() - c := tc().httpClient - - // Ensure that we wait for the readLoop to complete before - // calling Head again - didRead := make(chan bool) - SetReadLoopBeforeNextReadHook(func() { didRead <- true }) - defer SetReadLoopBeforeNextReadHook(nil) - - res1, err := c.Head(ts.URL) - <-didRead - - if err != nil { - t.Fatalf("request 1 error: %v", err) - } - - res2, err := c.Head(ts.URL) - <-didRead - - if err != nil { - t.Fatalf("request 2 error: %v", err) - } - if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 { - t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2) - } -} - -var roundTripTests = []struct { - accept string - expectAccept string - compressed bool -}{ - // Requests with no accept-encoding header use transparent compression - {"", "gzip", false}, - // Requests with other accept-encoding should pass through unmodified - {"foo", "foo", false}, - // Requests with accept-encoding == gzip should be passed through - {"gzip", "gzip", true}, -} - -// Test that the modification made to the Request by the http.RoundTripper is cleaned up -func TestRoundTripGzip(t *testing.T) { - setParallel(t) - defer afterTest(t) - const responseBody = "test response body" - ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - accept := req.Header.Get("Accept-Encoding") - if expect := req.FormValue("expect_accept"); accept != expect { - t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", - req.FormValue("testnum"), accept, expect) - } - if accept == "gzip" { - rw.Header().Set("Content-Encoding", "gzip") - gz := gzip.NewWriter(rw) - gz.Write([]byte(responseBody)) - gz.Close() - } else { - rw.Header().Set("Content-Encoding", accept) - rw.Write([]byte(responseBody)) - } - })) - defer ts.Close() - tr := tc().GetTransport() - - for i, test := range roundTripTests { - // Test basic request (no accept-encoding) - req, _ := http.NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) - if test.accept != "" { - req.Header.Set("Accept-Encoding", test.accept) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Errorf("%d. RoundTrip: %v", i, err) - continue - } - var body []byte - if test.compressed { - var r *gzip.Reader - r, err = gzip.NewReader(res.Body) - if err != nil { - t.Errorf("%d. gzip NewReader: %v", i, err) - continue - } - body, err = io.ReadAll(r) - res.Body.Close() - } else { - body, err = io.ReadAll(res.Body) - } - if err != nil { - t.Errorf("%d. Error: %q", i, err) - continue - } - if g, e := string(body), responseBody; g != e { - t.Errorf("%d. body = %q; want %q", i, g, e) - } - if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { - t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) - } - if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { - t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) - } - } - -} - -func TestTransportGzip(t *testing.T) { - setParallel(t) - defer afterTest(t) - const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - const nRandBytes = 1024 * 1024 - ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method == "HEAD" { - if g := req.Header.Get("Accept-Encoding"); g != "" { - t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) - } - return - } - if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { - t.Errorf("Accept-Encoding = %q, want %q", g, e) - } - rw.Header().Set("Content-Encoding", "gzip") - - var w io.Writer = rw - var buf bytes.Buffer - if req.FormValue("chunked") == "0" { - w = &buf - defer io.Copy(rw, &buf) - defer func() { - rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) - }() - } - gz := gzip.NewWriter(w) - gz.Write([]byte(testString)) - if req.FormValue("body") == "large" { - io.CopyN(gz, rand.Reader, nRandBytes) - } - gz.Close() - })) - defer ts.Close() - c := tc().httpClient - - for _, chunked := range []string{"1", "0"} { - // First fetch something large, but only read some of it. - res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) - if err != nil { - t.Fatalf("large get: %v", err) - } - buf := make([]byte, len(testString)) - n, err := io.ReadFull(res.Body, buf) - if err != nil { - t.Fatalf("partial read of large response: size=%d, %v", n, err) - } - if e, g := testString, string(buf); e != g { - t.Errorf("partial read got %q, expected %q", g, e) - } - res.Body.Close() - // Read on the body, even though it's closed - n, err = res.Body.Read(buf) - if n != 0 || err == nil { - t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) - } - - // Then something small. - res, err = c.Get(ts.URL + "/?chunked=" + chunked) - if err != nil { - t.Fatal(err) - } - body, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if g, e := string(body), testString; g != e { - t.Fatalf("body = %q; want %q", g, e) - } - if g, e := res.Header.Get("Content-Encoding"), ""; g != e { - t.Fatalf("Content-Encoding = %q; want %q", g, e) - } - - // Read on the body after it's been fully read: - n, err = res.Body.Read(buf) - if n != 0 || err == nil { - t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) - } - res.Body.Close() - n, err = res.Body.Read(buf) - if n != 0 || err == nil { - t.Errorf("expected Read error after Close; got %d, %v", n, err) - } - } - - // And a HEAD request too, because they're always weird. - res, err := c.Head(ts.URL) - if err != nil { - t.Fatalf("Head: %v", err) - } - if res.StatusCode != 200 { - t.Errorf("Head status=%d; want=200", res.StatusCode) - } -} - -// setParallel marks t as a parallel test if we're in short mode -// (all.bash), but as a serial test otherwise. Using t.Parallel isn't -// compatible with the afterTest func in non-short mode. -func setParallel(t *testing.T) { - if testing.Short() { - t.Parallel() - } -} - -// If a request has Expect:100-continue header, the request blocks sending body until the first response. -// Premature consumption of the request body should not be occurred. -func TestTransportExpect100Continue(t *testing.T) { - setParallel(t) - defer afterTest(t) - - ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - switch req.URL.Path { - case "/100": - // This endpoint implicitly responds 100 Continue and reads body. - if _, err := io.Copy(io.Discard, req.Body); err != nil { - t.Error("Failed to read Body", err) - } - rw.WriteHeader(http.StatusOK) - case "/200": - // Go 1.5 adds Connection: close header if the client expect - // continue but not entire request body is consumed. - rw.WriteHeader(http.StatusOK) - case "/500": - rw.WriteHeader(http.StatusInternalServerError) - case "/keepalive": - // This hijacked endpoint responds error without Connection:close. - _, bufrw, err := rw.(http.Hijacker).Hijack() - if err != nil { - log.Fatal(err) - } - bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n") - bufrw.WriteString("Content-Length: 0\r\n\r\n") - bufrw.Flush() - case "/timeout": - // This endpoint tries to read body without 100 (Continue) response. - // After ExpectContinueTimeout, the reading will be started. - conn, bufrw, err := rw.(http.Hijacker).Hijack() - if err != nil { - log.Fatal(err) - } - if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil { - t.Error("Failed to read Body", err) - } - bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") - bufrw.Flush() - conn.Close() - } - - })) - defer ts.Close() - - tests := []struct { - path string - body []byte - sent int - status int - }{ - {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. - {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. - {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. - {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent. - {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. - } - - c := tc().httpClient - for i, v := range tests { - tr := T() - tr.ExpectContinueTimeout = 2 * time.Second - defer tr.CloseIdleConnections() - c.Transport = tr - body := bytes.NewReader(v.body) - req, err := http.NewRequest("PUT", ts.URL+v.path, body) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Expect", "100-continue") - req.ContentLength = int64(len(v.body)) - - resp, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - - sent := len(v.body) - body.Len() - if v.status != resp.StatusCode { - t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path) - } - if v.sent != sent { - t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path) - } - } -} - -func TestSOCKS5Proxy(t *testing.T) { - defer afterTest(t) - ch := make(chan string, 1) - l := tests.NewLocalListener(t) - defer l.Close() - defer close(ch) - proxy := func(t *testing.T) { - s, err := l.Accept() - if err != nil { - t.Errorf("socks5 proxy Accept(): %v", err) - return - } - defer s.Close() - var buf [22]byte - if _, err := io.ReadFull(s, buf[:3]); err != nil { - t.Errorf("socks5 proxy initial read: %v", err) - return - } - if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { - t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want) - return - } - if _, err := s.Write([]byte{5, 0}); err != nil { - t.Errorf("socks5 proxy initial write: %v", err) - return - } - if _, err := io.ReadFull(s, buf[:4]); err != nil { - t.Errorf("socks5 proxy second read: %v", err) - return - } - if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { - t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want) - return - } - var ipLen int - switch buf[3] { - case 1: - ipLen = net.IPv4len - case 4: - ipLen = net.IPv6len - default: - t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4]) - return - } - if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil { - t.Errorf("socks5 proxy address read: %v", err) - return - } - ip := net.IP(buf[4 : ipLen+4]) - port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6]) - copy(buf[:3], []byte{5, 0, 0}) - if _, err := s.Write(buf[:ipLen+6]); err != nil { - t.Errorf("socks5 proxy connect write: %v", err) - return - } - ch <- fmt.Sprintf("proxy for %s:%d", ip, port) - - // Implement proxying. - targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) - targetConn, err := net.Dial("tcp", targetHost) - if err != nil { - t.Errorf("net.Dial failed") - return - } - go io.Copy(targetConn, s) - io.Copy(s, targetConn) // Wait for the client to close the socket. - targetConn.Close() - } - - pu, err := url.Parse("socks5://" + l.Addr().String()) - if err != nil { - t.Fatal(err) - } - - sentinelHeader := "X-Sentinel" - sentinelValue := "12345" - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(sentinelHeader, sentinelValue) - }) - for _, useTLS := range []bool{false, true} { - t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { - var ts *httptest.Server - if useTLS { - ts = httptest.NewTLSServer(h) - } else { - ts = httptest.NewServer(h) - } - go proxy(t) - c := tc().httpClient - c.Transport.(*Transport).Proxy = http.ProxyURL(pu) - r, err := c.Head(ts.URL) - if err != nil { - t.Fatal(err) - } - if r.Header.Get(sentinelHeader) != sentinelValue { - t.Errorf("Failed to retrieve sentinel value") - } - var got string - select { - case got = <-ch: - case <-time.After(5 * time.Second): - t.Fatal("timeout connecting to socks5 proxy") - } - ts.Close() - tsu, err := url.Parse(ts.URL) - if err != nil { - t.Fatal(err) - } - want := "proxy for " + tsu.Host - if got != want { - t.Errorf("got %q, want %q", got, want) - } - }) - } -} - -func TestTransportProxy(t *testing.T) { - defer afterTest(t) - testCases := []struct{ httpsSite, httpsProxy bool }{ - {false, false}, - {false, true}, - {true, false}, - {true, true}, - } - for _, testCase := range testCases { - httpsSite := testCase.httpsSite - httpsProxy := testCase.httpsProxy - t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) { - siteCh := make(chan *http.Request, 1) - h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - siteCh <- r - }) - proxyCh := make(chan *http.Request, 1) - h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proxyCh <- r - // Implement an entire CONNECT proxy - if r.Method == "CONNECT" { - hijacker, ok := w.(http.Hijacker) - if !ok { - t.Errorf("hijack not allowed") - return - } - clientConn, _, err := hijacker.Hijack() - if err != nil { - t.Errorf("hijacking failed") - return - } - res := &http.Response{ - StatusCode: http.StatusOK, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - } - - targetConn, err := net.Dial("tcp", r.URL.Host) - if err != nil { - t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) - return - } - - if err := res.Write(clientConn); err != nil { - t.Errorf("Writing 200 OK failed: %v", err) - return - } - - go io.Copy(targetConn, clientConn) - go func() { - io.Copy(clientConn, targetConn) - targetConn.Close() - }() - } - }) - var ts *httptest.Server - if httpsSite { - ts = httptest.NewTLSServer(h1) - } else { - ts = httptest.NewServer(h1) - } - var proxy *httptest.Server - if httpsProxy { - proxy = httptest.NewTLSServer(h2) - } else { - proxy = httptest.NewServer(h2) - } - - pu, err := url.Parse(proxy.URL) - if err != nil { - t.Fatal(err) - } - - // If neither server is HTTPS or both are, then c may be derived from either. - // If only one server is HTTPS, c must be derived from that server in order - // to ensure that it is configured to use the fake root CA from testcert.go. - c := tc().httpClient - - c.Transport.(*Transport).Proxy = http.ProxyURL(pu) - if _, err := c.Head(ts.URL); err != nil { - t.Error(err) - } - var got *http.Request - select { - case got = <-proxyCh: - case <-time.After(5 * time.Second): - t.Fatal("timeout connecting to http proxy") - } - c.Transport.(*Transport).CloseIdleConnections() - ts.Close() - proxy.Close() - if httpsSite { - // First message should be a CONNECT, asking for a socket to the real server, - if got.Method != "CONNECT" { - t.Errorf("Wrong method for secure proxying: %q", got.Method) - } - gotHost := got.URL.Host - pu, err := url.Parse(ts.URL) - if err != nil { - t.Fatal("Invalid site URL") - } - if wantHost := pu.Host; gotHost != wantHost { - t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost) - } - - // The next message on the channel should be from the site's server. - next := <-siteCh - if next.Method != "HEAD" { - t.Errorf("Wrong method at destination: %s", next.Method) - } - if nextURL := next.URL.String(); nextURL != "/" { - t.Errorf("Wrong URL at destination: %s", nextURL) - } - } else { - if got.Method != "HEAD" { - t.Errorf("Wrong method for destination: %q", got.Method) - } - gotURL := got.URL.String() - wantURL := ts.URL + "/" - if gotURL != wantURL { - t.Errorf("Got URL %q, want %q", gotURL, wantURL) - } - } - }) - } -} - -// Issue 28012: verify that the Transport closes its TCP connection to http proxies -// when they're slow to reply to HTTPS CONNECT responses. -func TestTransportProxyHTTPSConnectLeak(t *testing.T) { - setParallel(t) - defer afterTest(t) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ln := tests.NewLocalListener(t) - defer ln.Close() - listenerDone := make(chan struct{}) - go func() { - defer close(listenerDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Accept: %v", err) - return - } - defer c.Close() - // Read the CONNECT request - br := bufio.NewReader(c) - cr, err := http.ReadRequest(br) - if err != nil { - t.Errorf("proxy server failed to read CONNECT request") - return - } - if cr.Method != "CONNECT" { - t.Errorf("unexpected method %q", cr.Method) - return - } - - // Now hang and never write a response; instead, cancel the request and wait - // for the client to close. - // (Prior to Issue 28012 being fixed, we never closed.) - cancel() - var buf [1]byte - _, err = br.Read(buf[:]) - if err != io.EOF { - t.Errorf("proxy server Read err = %v; want EOF", err) - } - return - }() - - tr := T().SetProxy(func(*http.Request) (*url.URL, error) { - return url.Parse("http://" + ln.Addr().String()) - }) - c := &http.Client{ - Transport: tr, - } - req, err := http.NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil) - if err != nil { - t.Fatal(err) - } - _, err = c.Do(req) - if err == nil { - t.Errorf("unexpected Get success") - } - - // Wait unconditionally for the listener goroutine to exit: this should never - // hang, so if it does we want a full goroutine dump — and that's exactly what - // the testing package will give us when the test run times out. - <-listenerDone -} - -// Issue 16997: test transport dial preserves typed errors -func TestTransportDialPreservesNetOpProxyError(t *testing.T) { - defer afterTest(t) - - var errDial = errors.New("some dial error") - - tr := T().SetProxy(func(*http.Request) (*url.URL, error) { - return url.Parse("http://proxy.fake.tld/") - }).SetDial(func(context.Context, string, string) (net.Conn, error) { - return nil, errDial - }) - defer tr.CloseIdleConnections() - - c := &http.Client{Transport: tr} - req, _ := http.NewRequest("GET", "http://fake.tld", nil) - res, err := c.Do(req) - if err == nil { - res.Body.Close() - t.Fatal("wanted a non-nil error") - } - - uerr, ok := err.(*url.Error) - if !ok { - t.Fatalf("got %T, want *url.Error", err) - } - oe, ok := uerr.Err.(*net.OpError) - if !ok { - t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err) - } - want := &net.OpError{ - Op: "proxyconnect", - Net: "tcp", - Err: errDial, // original error, unwrapped. - } - if !reflect.DeepEqual(oe, want) { - t.Errorf("Got error %#v; want %#v", oe, want) - } -} - -// Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader. -// -// (A bug caused dialConn to instead write the per-request Proxy-Authorization -// header through to the shared Header instance, introducing a data race.) -func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { - setParallel(t) - defer afterTest(t) - - proxy := httptest.NewTLSServer(http.NotFoundHandler()) - defer proxy.Close() - c := tc().httpClient - - tr := c.Transport.(*Transport) - tr.Proxy = func(*http.Request) (*url.URL, error) { - u, _ := url.Parse(proxy.URL) - u.User = url.UserPassword("aladdin", "opensesame") - return u, nil - } - h := tr.ProxyConnectHeader - if h == nil { - h = make(http.Header) - } - tr.ProxyConnectHeader = h.Clone() - - req, err := http.NewRequest("GET", "https://golang.fake.tld/", nil) - if err != nil { - t.Fatal(err) - } - _, err = c.Do(req) - if err == nil { - t.Errorf("unexpected Get success") - } - - if !reflect.DeepEqual(tr.ProxyConnectHeader, h) { - t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h) - } -} - -// TestTransportGzipRecursive sends a gzip quine and checks that the -// client gets the same value back. This is more cute than anything, -// but checks that we don't recurse forever, and checks that -// Content-Encoding is removed. -func TestTransportGzipRecursive(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Encoding", "gzip") - w.Write(rgz) - })) - defer ts.Close() - - c := tc().httpClient - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - body, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(body, rgz) { - t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", - body, rgz) - } - if g, e := res.Header.Get("Content-Encoding"), ""; g != e { - t.Fatalf("Content-Encoding = %q; want %q", g, e) - } -} - -// golang.org/issue/7750: request fails when server replies with -// a short gzip body -func TestTransportGzipShort(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Encoding", "gzip") - w.Write([]byte{0x1f, 0x8b}) - })) - defer ts.Close() - - c := tc().httpClient - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - _, err = io.ReadAll(res.Body) - if err == nil { - t.Fatal("Expect an error from reading a body.") - } - if err != io.ErrUnexpectedEOF { - t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err) - } -} - -// Wait until number of goroutines is no greater than nmax, or time out. -func waitNumGoroutine(nmax int) int { - nfinal := runtime.NumGoroutine() - for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- { - time.Sleep(50 * time.Millisecond) - runtime.GC() - nfinal = runtime.NumGoroutine() - } - return nfinal -} - -// tests that persistent goroutine connections shut down when no longer desired. -func TestTransportPersistConnLeak(t *testing.T) { - // Not parallel: counts goroutines - defer afterTest(t) - - const numReq = 25 - gotReqCh := make(chan bool, numReq) - unblockCh := make(chan bool, numReq) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotReqCh <- true - <-unblockCh - w.Header().Set("Content-Length", "0") - w.WriteHeader(204) - })) - defer ts.Close() - c := tc().httpClient - tr := c.Transport.(*Transport) - - n0 := runtime.NumGoroutine() - - didReqCh := make(chan bool, numReq) - failed := make(chan bool, numReq) - for i := 0; i < numReq; i++ { - go func() { - res, err := c.Get(ts.URL) - didReqCh <- true - if err != nil { - t.Logf("client fetch error: %v", err) - failed <- true - return - } - res.Body.Close() - }() - } - - // Wait for all goroutines to be stuck in the Handler. - for i := 0; i < numReq; i++ { - select { - case <-gotReqCh: - // ok - case <-failed: - // Not great but not what we are testing: - // sometimes an overloaded system will fail to make all the connections. - } - } - - nhigh := runtime.NumGoroutine() - - // Tell all handlers to unblock and reply. - close(unblockCh) - - // Wait for all HTTP clients to be done. - for i := 0; i < numReq; i++ { - <-didReqCh - } - - tr.CloseIdleConnections() - nfinal := waitNumGoroutine(n0 + 5) - - growth := nfinal - n0 - - // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. - // Previously we were leaking one per numReq. - if int(growth) > 5 { - t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) - t.Error("too many new goroutines") - } -} - -// golang.org/issue/4531: Transport leaks goroutines when -// request.ContentLength is explicitly short -func TestTransportPersistConnLeakShortBody(t *testing.T) { - // Not parallel: measures goroutines. - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - })) - defer ts.Close() - c := tc().httpClient - tr := c.Transport.(*Transport) - - n0 := runtime.NumGoroutine() - body := []byte("Hello") - for i := 0; i < 20; i++ { - req, err := http.NewRequest("POST", ts.URL, bytes.NewReader(body)) - if err != nil { - t.Fatal(err) - } - req.ContentLength = int64(len(body) - 2) // explicitly short - _, err = c.Do(req) - if err == nil { - t.Fatal("Expect an error from writing too long of a body.") - } - } - nhigh := runtime.NumGoroutine() - tr.CloseIdleConnections() - nfinal := waitNumGoroutine(n0 + 5) - - growth := nfinal - n0 - - // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. - // Previously we were leaking one per numReq. - t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) - if int(growth) > 5 { - t.Error("too many new goroutines") - } -} - -// A countedConn is a net.Conn that decrements an atomic counter when finalized. -type countedConn struct { - net.Conn -} - -// A countingDialer dials connections and counts the number that remain reachable. -type countingDialer struct { - dialer net.Dialer - mu sync.Mutex - total, live int64 -} - -func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - conn, err := d.dialer.DialContext(ctx, network, address) - if err != nil { - return nil, err - } - - counted := new(countedConn) - counted.Conn = conn - - d.mu.Lock() - defer d.mu.Unlock() - d.total++ - d.live++ - - runtime.SetFinalizer(counted, d.decrement) - return counted, nil -} - -func (d *countingDialer) decrement(*countedConn) { - d.mu.Lock() - defer d.mu.Unlock() - d.live-- -} - -func (d *countingDialer) Read() (total, live int64) { - d.mu.Lock() - defer d.mu.Unlock() - return d.total, d.live -} - -func TestTransportPersistConnLeakNeverIdle(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Close every connection so that it cannot be kept alive. - conn, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Errorf("Hijack failed unexpectedly: %v", err) - return - } - conn.Close() - })) - defer ts.Close() - - var d countingDialer - c := tc().httpClient - c.Transport.(*Transport).DialContext = d.DialContext - - body := []byte("Hello") - for i := 0; ; i++ { - total, live := d.Read() - if live < total { - break - } - if i >= 1<<12 { - t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i) - } - - req, err := http.NewRequest("POST", ts.URL, bytes.NewReader(body)) - if err != nil { - t.Fatal(err) - } - _, err = c.Do(req) - if err == nil { - t.Fatal("expected broken connection") - } - - runtime.GC() - } -} - -type countedContext struct { - context.Context -} - -type contextCounter struct { - mu sync.Mutex - live int64 -} - -func (cc *contextCounter) Track(ctx context.Context) context.Context { - counted := new(countedContext) - counted.Context = ctx - cc.mu.Lock() - defer cc.mu.Unlock() - cc.live++ - runtime.SetFinalizer(counted, cc.decrement) - return counted -} - -func (cc *contextCounter) decrement(*countedContext) { - cc.mu.Lock() - defer cc.mu.Unlock() - cc.live-- -} - -func (cc *contextCounter) Read() (live int64) { - cc.mu.Lock() - defer cc.mu.Unlock() - return cc.live -} - -// This used to crash; https://golang.org/issue/3266 -func TestTransportIdleConnCrash(t *testing.T) { - defer afterTest(t) - var tr *Transport - - unblockCh := make(chan bool, 1) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - <-unblockCh - tr.CloseIdleConnections() - })) - defer ts.Close() - c := tc().httpClient - tr = c.Transport.(*Transport) - - didreq := make(chan bool) - go func() { - res, err := c.Get(ts.URL) - if err != nil { - t.Error(err) - } else { - res.Body.Close() // returns idle conn - } - didreq <- true - }() - unblockCh <- true - <-didreq -} - -// Test that the transport doesn't close the TCP connection early, -// before the response body has been read. This was a regression -// which sadly lacked a triggering test. The large response body made -// the old race easier to trigger. -func TestIssue3644(t *testing.T) { - defer afterTest(t) - const numFoos = 5000 - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Connection", "close") - for i := 0; i < numFoos; i++ { - w.Write([]byte("foo ")) - } - })) - defer ts.Close() - c := tc().httpClient - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - bs, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if len(bs) != numFoos*len("foo ") { - t.Errorf("unexpected response length") - } -} - -// Test that a client receives a server's reply, even if the server doesn't read -// the entire request body. -func TestIssue3595(t *testing.T) { - setParallel(t) - defer afterTest(t) - const deniedMsg = "sorry, denied." - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, deniedMsg, http.StatusUnauthorized) - })) - defer ts.Close() - c := tc().httpClient - res, err := c.Post(ts.URL, "application/octet-stream", tests.NeverEnding('a')) - if err != nil { - t.Errorf("Post: %v", err) - return - } - got, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("Body ReadAll: %v", err) - } - if !strings.Contains(string(got), deniedMsg) { - t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) - } -} - -// From https://golang.org/issue/4454 , -// "client fails to handle requests with no body and chunked encoding" -func TestChunkedNoContent(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - })) - defer ts.Close() - - c := tc().httpClient - for _, closeBody := range []bool{true, false} { - const n = 4 - for i := 1; i <= n; i++ { - res, err := c.Get(ts.URL) - if err != nil { - t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err) - } else { - if closeBody { - res.Body.Close() - } - } - } - } -} - -// SetPendingDialHooks sets the hooks that run before and after handling -// pending dials. -func SetPendingDialHooks(before, after func()) { - unnilTestHook(&before) - unnilTestHook(&after) - testHookPrePendingDial, testHookPostPendingDial = before, after -} - -func TestTransportConcurrency(t *testing.T) { - // Not parallel: uses global test hooks. - defer afterTest(t) - maxProcs, numReqs := 16, 500 - if testing.Short() { - maxProcs, numReqs = 4, 50 - } - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "%v", r.FormValue("echo")) - })) - defer ts.Close() - - var wg sync.WaitGroup - wg.Add(numReqs) - - // Due to the Transport's "socket late binding" (see - // idleConnCh in transport.go), the numReqs HTTP requests - // below can finish with a dial still outstanding. To keep - // the leak checker happy, keep track of pending dials and - // wait for them to finish (and be closed or returned to the - // idle pool) before we close idle connections. - SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) - defer SetPendingDialHooks(nil, nil) - - c := tc().httpClient - reqs := make(chan string) - defer close(reqs) - - for i := 0; i < maxProcs*2; i++ { - go func() { - for req := range reqs { - res, err := c.Get(ts.URL + "/?echo=" + req) - if err != nil { - if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") { - // https://go.dev/issue/52168: this test was observed to fail with - // ECONNRESET errors in Dial on various netbsd builders. - t.Logf("error on req %s: %v", req, err) - t.Logf("(see https://go.dev/issue/52168)") - } else { - t.Errorf("error on req %s: %v", req, err) - } - wg.Done() - continue - } - all, err := io.ReadAll(res.Body) - if err != nil { - t.Errorf("read error on req %s: %v", req, err) - } else if string(all) != req { - t.Errorf("body of req %s = %q; want %q", req, all, req) - } - res.Body.Close() - wg.Done() - } - }() - } - for i := 0; i < numReqs; i++ { - reqs <- fmt.Sprintf("request-%d", i) - } - wg.Wait() -} - -// loggingConn is used for debugging. -type loggingConn struct { - name string - net.Conn -} - -var ( - uniqNameMu sync.Mutex - uniqNameNext = make(map[string]int) -) - -func newLoggingConn(baseName string, c net.Conn) net.Conn { - uniqNameMu.Lock() - defer uniqNameMu.Unlock() - uniqNameNext[baseName]++ - return &loggingConn{ - name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]), - Conn: c, - } -} - -func (c *loggingConn) Write(p []byte) (n int, err error) { - log.Printf("%s.Write(%d) = ....", c.name, len(p)) - n, err = c.Conn.Write(p) - log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err) - return -} - -func (c *loggingConn) Read(p []byte) (n int, err error) { - log.Printf("%s.Read(%d) = ....", c.name, len(p)) - n, err = c.Conn.Read(p) - log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err) - return -} - -func (c *loggingConn) Close() (err error) { - log.Printf("%s.Close() = ...", c.name) - err = c.Conn.Close() - log.Printf("%s.Close() = %v", c.name, err) - return -} - -func TestIssue4191_InfiniteGetTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - const debug = false - mux := http.NewServeMux() - mux.HandleFunc("/get", func(w http.ResponseWriter, r *http.Request) { - io.Copy(w, tests.NeverEnding('a')) - }) - ts := httptest.NewServer(mux) - defer ts.Close() - timeout := 100 * time.Millisecond - - c := tc().httpClient - c.Transport.(*Transport).DialContext = func(_ context.Context, n, addr string) (net.Conn, error) { - conn, err := net.Dial(n, addr) - if err != nil { - return nil, err - } - conn.SetDeadline(time.Now().Add(timeout)) - if debug { - conn = newLoggingConn("client", conn) - } - return conn, nil - } - - getFailed := false - nRuns := 5 - if testing.Short() { - nRuns = 1 - } - for i := 0; i < nRuns; i++ { - if debug { - println("run", i+1, "of", nRuns) - } - sres, err := c.Get(ts.URL + "/get") - if err != nil { - if !getFailed { - // Make the timeout longer, once. - getFailed = true - t.Logf("increasing timeout") - i-- - timeout *= 10 - continue - } - t.Errorf("Error issuing GET: %v", err) - break - } - _, err = io.Copy(io.Discard, sres.Body) - if err == nil { - t.Errorf("Unexpected successful copy") - break - } - } - if debug { - println("tests complete; waiting for handlers to finish") - } -} - -func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - const debug = false - mux := http.NewServeMux() - mux.HandleFunc("/get", func(w http.ResponseWriter, r *http.Request) { - io.Copy(w, tests.NeverEnding('a')) - }) - mux.HandleFunc("/put", func(w http.ResponseWriter, r *http.Request) { - defer r.Body.Close() - io.Copy(io.Discard, r.Body) - }) - ts := httptest.NewServer(mux) - timeout := 100 * time.Millisecond - - c := tc().httpClient - c.Transport.(*Transport).DialContext = func(_ context.Context, n, addr string) (net.Conn, error) { - conn, err := net.Dial(n, addr) - if err != nil { - return nil, err - } - conn.SetDeadline(time.Now().Add(timeout)) - if debug { - conn = newLoggingConn("client", conn) - } - return conn, nil - } - - getFailed := false - nRuns := 5 - if testing.Short() { - nRuns = 1 - } - for i := 0; i < nRuns; i++ { - if debug { - println("run", i+1, "of", nRuns) - } - sres, err := c.Get(ts.URL + "/get") - if err != nil { - if !getFailed { - // Make the timeout longer, once. - getFailed = true - t.Logf("increasing timeout") - i-- - timeout *= 10 - continue - } - t.Errorf("Error issuing GET: %v", err) - break - } - req, _ := http.NewRequest("PUT", ts.URL+"/put", sres.Body) - _, err = c.Do(req) - if err == nil { - sres.Body.Close() - t.Errorf("Unexpected successful PUT") - break - } - sres.Body.Close() - } - if debug { - println("tests complete; waiting for handlers to finish") - } - ts.Close() -} - -func reqWithT(r *http.Request, t *testing.T) *http.Request { - return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) -} - -func TestTransportResponseHeaderTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - if testing.Short() { - t.Skip("skipping timeout test in -short mode") - } - inHandler := make(chan bool, 1) - mux := http.NewServeMux() - mux.HandleFunc("/fast", func(w http.ResponseWriter, r *http.Request) { - inHandler <- true - }) - mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) { - inHandler <- true - time.Sleep(2 * time.Second) - }) - ts := httptest.NewServer(mux) - defer ts.Close() - - c := tc().httpClient - c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond - - tests := []struct { - path string - want int - wantErr string - }{ - {path: "/fast", want: 200}, - {path: "/slow", wantErr: "timeout awaiting response headers"}, - {path: "/fast", want: 200}, - } - for i, tt := range tests { - req, _ := http.NewRequest("GET", ts.URL+tt.path, nil) - req = reqWithT(req, t) - res, err := c.Do(req) - select { - case <-inHandler: - case <-time.After(5 * time.Second): - t.Errorf("never entered handler for test index %d, %s", i, tt.path) - continue - } - if err != nil { - uerr, ok := err.(*url.Error) - if !ok { - t.Errorf("error is not an url.Error; got: %#v", err) - continue - } - nerr, ok := uerr.Err.(net.Error) - if !ok { - t.Errorf("error does not satisfy net.Error interface; got: %#v", err) - continue - } - if !nerr.Timeout() { - t.Errorf("want timeout error; got: %q", nerr) - continue - } - if strings.Contains(err.Error(), tt.wantErr) { - continue - } - t.Errorf("%d. unexpected error: %v", i, err) - continue - } - if tt.wantErr != "" { - t.Errorf("%d. no error. expected error: %v", i, tt.wantErr) - continue - } - if res.StatusCode != tt.want { - t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) - } - } -} - -func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { - setParallel(t) - defer afterTest(t) - if testing.Short() { - t.Skip("skipping test in -short mode") - } - unblockc := make(chan bool) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - <-unblockc - })) - defer ts.Close() - defer close(unblockc) - - c := tc().httpClient - tr := c.Transport.(*Transport) - - donec := make(chan bool) - req, _ := http.NewRequest("GET", ts.URL, body) - go func() { - defer close(donec) - c.Do(req) - }() - start := time.Now() - timeout := 10 * time.Second - for time.Since(start) < timeout { - time.Sleep(100 * time.Millisecond) - tr.CancelRequest(req) - select { - case <-donec: - return - default: - } - } - t.Errorf("Do of canceled request has not returned after %v", timeout) -} - -func TestCancelRequestWithChannel(t *testing.T) { - setParallel(t) - defer afterTest(t) - if testing.Short() { - t.Skip("skipping test in -short mode") - } - unblockc := make(chan bool) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Hello") - w.(http.Flusher).Flush() // send headers and some body - <-unblockc - })) - defer ts.Close() - defer close(unblockc) - - c := tc().httpClient - tr := c.Transport.(*Transport) - - req, _ := http.NewRequest("GET", ts.URL, nil) - ch := make(chan struct{}) - req.Cancel = ch - - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - go func() { - time.Sleep(1 * time.Second) - close(ch) - }() - t0 := time.Now() - body, err := io.ReadAll(res.Body) - d := time.Since(t0) - - if err != common.ErrRequestCanceled { - t.Errorf("Body.Read error = %v; want errRequestCanceled", err) - } - if string(body) != "Hello" { - t.Errorf("Body = %q; want Hello", body) - } - if d < 500*time.Millisecond { - t.Errorf("expected ~1 second delay; got %v", d) - } - // Verify no outstanding requests after readLoop/writeLoop - // goroutines shut down. - for tries := 5; tries > 0; tries-- { - n := tr.NumPendingRequestsForTesting() - if n == 0 { - break - } - time.Sleep(100 * time.Millisecond) - if tries == 1 { - t.Errorf("pending requests = %d; want 0", n) - } - } -} - -func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { - testCancelRequestWithChannelBeforeDo(t, false) -} -func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) { - testCancelRequestWithChannelBeforeDo(t, true) -} -func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { - setParallel(t) - defer afterTest(t) - unblockc := make(chan bool) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - <-unblockc - })) - defer ts.Close() - defer close(unblockc) - - c := tc().httpClient - - req, _ := http.NewRequest("GET", ts.URL, nil) - if withCtx { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - req = req.WithContext(ctx) - } else { - ch := make(chan struct{}) - req.Cancel = ch - close(ch) - } - - _, err := c.Do(req) - if ue, ok := err.(*url.Error); ok { - err = ue.Err - } - if withCtx { - if err != context.Canceled { - t.Errorf("Do error = %v; want %v", err, context.Canceled) - } - } else { - if err == nil || !strings.Contains(err.Error(), "canceled") { - t.Errorf("Do error = %v; want cancellation", err) - } - } -} - -// Issue 11020. The returned error message should be errRequestCanceled -func TestTransportCancelBeforeResponseHeaders(t *testing.T) { - defer afterTest(t) - - serverConnCh := make(chan net.Conn, 1) - tr := T().SetDial(func(_ context.Context, network, addr string) (net.Conn, error) { - cc, sc := net.Pipe() - serverConnCh <- sc - return cc, nil - }) - defer tr.CloseIdleConnections() - errc := make(chan error, 1) - req, _ := http.NewRequest("GET", "http://example.com/", nil) - go func() { - _, err := tr.RoundTrip(req) - errc <- err - }() - - sc := <-serverConnCh - verb := make([]byte, 3) - if _, err := io.ReadFull(sc, verb); err != nil { - t.Errorf("Error reading HTTP verb from server: %v", err) - } - if string(verb) != "GET" { - t.Errorf("server received %q; want GET", verb) - } - defer sc.Close() - - tr.CancelRequest(req) - - err := <-errc - if err == nil { - t.Fatalf("unexpected success from RoundTrip") - } - if err != common.ErrRequestCanceled { - t.Errorf("RoundTrip error = %v; want errRequestCanceled", err) - } -} - -// golang.org/issue/3672 -- Client can't close HTTP stream -// Calling Close on a Response.Body used to just read until EOF. -// Now it actually closes the TCP connection. -func TestTransportCloseResponseBody(t *testing.T) { - defer afterTest(t) - writeErr := make(chan error, 1) - msg := []byte("young\n") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for { - _, err := w.Write(msg) - if err != nil { - writeErr <- err - return - } - w.(http.Flusher).Flush() - } - })) - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - - req, _ := http.NewRequest("GET", ts.URL, nil) - defer tr.CancelRequest(req) - - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - - const repeats = 3 - buf := make([]byte, len(msg)*repeats) - want := bytes.Repeat(msg, repeats) - - _, err = io.ReadFull(res.Body, buf) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(buf, want) { - t.Fatalf("read %q; want %q", buf, want) - } - didClose := make(chan error, 1) - go func() { - didClose <- res.Body.Close() - }() - select { - case err := <-didClose: - if err != nil { - t.Errorf("Close = %v", err) - } - case <-time.After(10 * time.Second): - t.Fatal("too long waiting for close") - } - select { - case err := <-writeErr: - if err == nil { - t.Errorf("expected non-nil write error") - } - case <-time.After(10 * time.Second): - t.Fatal("too long waiting for write error") - } -} - -func TestTransportNoHost(t *testing.T) { - defer afterTest(t) - tr := T() - _, err := tr.RoundTrip(&http.Request{ - Header: make(http.Header), - URL: &url.URL{ - Scheme: "http", - }, - }) - want := "http: no Host in request URL" - if got := fmt.Sprint(err); got != want { - t.Errorf("error = %v; want %q", err, want) - } -} - -// Issue 13311 -func TestTransportEmptyMethod(t *testing.T) { - req, _ := http.NewRequest("GET", "http://foo.com/", nil) - req.Method = "" // docs say "For client requests an empty string means GET" - got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport - if err != nil { - t.Fatal(err) - } - if !strings.Contains(string(got), "GET ") { - t.Fatalf("expected substring 'GET '; got: %s", got) - } -} - -func TestTransportSocketLateBinding(t *testing.T) { - setParallel(t) - defer afterTest(t) - - mux := http.NewServeMux() - fooGate := make(chan bool, 1) - mux.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("foo-ipport", r.RemoteAddr) - w.(http.Flusher).Flush() - <-fooGate - }) - mux.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("bar-ipport", r.RemoteAddr) - }) - ts := httptest.NewServer(mux) - defer ts.Close() - - dialGate := make(chan bool, 1) - c := tc().httpClient - c.Transport.(*Transport).DialContext = func(_ context.Context, n, addr string) (net.Conn, error) { - if <-dialGate { - return net.Dial(n, addr) - } - return nil, errors.New("manually closed") - } - - dialGate <- true // only allow one dial - fooRes, err := c.Get(ts.URL + "/foo") - if err != nil { - t.Fatal(err) - } - fooAddr := fooRes.Header.Get("foo-ipport") - if fooAddr == "" { - t.Fatal("No addr on /foo request") - } - time.AfterFunc(200*time.Millisecond, func() { - // let the foo response finish so we can use its - // connection for /bar - fooGate <- true - io.Copy(io.Discard, fooRes.Body) - fooRes.Body.Close() - }) - - barRes, err := c.Get(ts.URL + "/bar") - if err != nil { - t.Fatal(err) - } - barAddr := barRes.Header.Get("bar-ipport") - if barAddr != fooAddr { - t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) - } - barRes.Body.Close() - dialGate <- false -} - -type dummyAddr string -type oneConnListener struct { - conn net.Conn -} - -func (l *oneConnListener) Accept() (c net.Conn, err error) { - c = l.conn - if c == nil { - err = io.EOF - return - } - err = nil - l.conn = nil - return -} - -func (l *oneConnListener) Close() error { - return nil -} - -func (l *oneConnListener) Addr() net.Addr { - return dummyAddr("test-address") -} - -func (a dummyAddr) Network() string { - return string(a) -} - -func (a dummyAddr) String() string { - return string(a) -} - -type noopConn struct{} - -func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") } -func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") } -func (noopConn) SetDeadline(t time.Time) error { return nil } -func (noopConn) SetReadDeadline(t time.Time) error { return nil } -func (noopConn) SetWriteDeadline(t time.Time) error { return nil } - -type rwTestConn struct { - io.Reader - io.Writer - noopConn - - closeFunc func() error // called if non-nil - closec chan bool // else, if non-nil, send value to it on close -} - -func (c *rwTestConn) Close() error { - if c.closeFunc != nil { - return c.closeFunc() - } - select { - case c.closec <- true: - default: - } - return nil -} - -// Issue 2184 -func TestTransportReading100Continue(t *testing.T) { - defer afterTest(t) - - const numReqs = 5 - reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) } - reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) } - - send100Response := func(w *io.PipeWriter, r *io.PipeReader) { - defer w.Close() - defer r.Close() - br := bufio.NewReader(r) - n := 0 - for { - n++ - req, err := http.ReadRequest(br) - if err == io.EOF { - return - } - if err != nil { - t.Error(err) - return - } - slurp, err := io.ReadAll(req.Body) - if err != nil { - t.Errorf("Server request body slurp: %v", err) - return - } - id := req.Header.Get("Request-Id") - resCode := req.Header.Get("X-Want-Response-Code") - if resCode == "" { - resCode = "100 Continue" - if string(slurp) != reqBody(n) { - t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n)) - } - } - body := fmt.Sprintf("Response number %d", n) - v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s -Date: Thu, 28 Feb 2013 17:55:41 GMT - -HTTP/1.1 200 OK -Content-Type: text/html -Echo-Request-Id: %s -Content-Length: %d - -%s`, resCode, id, len(body), body), "\n", "\r\n", -1)) - w.Write(v) - if id == reqID(numReqs) { - return - } - } - - } - - tr := T().SetDial(func(_ context.Context, n, addr string) (net.Conn, error) { - sr, sw := io.Pipe() // server read/write - cr, cw := io.Pipe() // client read/write - conn := &rwTestConn{ - Reader: cr, - Writer: sw, - closeFunc: func() error { - sw.Close() - cw.Close() - return nil - }, - } - go send100Response(cw, sr) - return conn, nil - }) - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - - testResponse := func(req *http.Request, name string, wantCode int) { - t.Helper() - res, err := c.Do(req) - if err != nil { - t.Fatalf("%s: Do: %v", name, err) - } - if res.StatusCode != wantCode { - t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode) - } - if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { - t.Errorf("%s: response id %q != request id %q", name, idBack, id) - } - _, err = io.ReadAll(res.Body) - if err != nil { - t.Fatalf("%s: Slurp error: %v", name, err) - } - } - - // Few 100 responses, making sure we're not off-by-one. - for i := 1; i <= numReqs; i++ { - req, _ := http.NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i))) - req.Header.Set("Request-Id", reqID(i)) - testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) - } -} - -type clientServerTest struct { - t *testing.T - h2 bool - h http.Handler - ts *httptest.Server - tr *Transport - c *http.Client -} - -func (t *clientServerTest) close() { - t.tr.CloseIdleConnections() - t.ts.Close() -} - -func (t *clientServerTest) getURL(u string) string { - res, err := t.c.Get(u) - if err != nil { - t.t.Fatal(err) - } - defer res.Body.Close() - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.t.Fatal(err) - } - return string(slurp) -} - -func (t *clientServerTest) scheme() string { - if t.h2 { - return "https" - } - return "http" -} - -const ( - h1Mode = false - h2Mode = true -) - -var quietLog = log.New(io.Discard, "", 0) - -var optQuietLog = func(ts *httptest.Server) { - ts.Config.ErrorLog = quietLog -} - -func newClientServerTest(t *testing.T, h2 bool, h http.Handler, opts ...interface{}) *clientServerTest { - cst := &clientServerTest{ - t: t, - h2: h2, - h: h, - tr: T(), - } - cst.c = &http.Client{Transport: cst.tr} - cst.ts = httptest.NewUnstartedServer(h) - - for _, opt := range opts { - switch opt := opt.(type) { - case func(*Transport): - opt(cst.tr) - case func(*httptest.Server): - opt(cst.ts) - default: - t.Fatalf("unhandled option type %T", opt) - } - } - - if !h2 { - cst.ts.Start() - return cst - } - nethttp2.ConfigureServer(cst.ts.Config, nil) - cst.ts.TLS = cst.ts.Config.TLSConfig - cst.ts.StartTLS() - - cst.tr.TLSClientConfig.InsecureSkipVerify = true - return cst -} - -// Issue 17739: the HTTP client must ignore any unknown 1xx -// informational responses before the actual response. -func TestTransportIgnore1xxResponses(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, buf, _ := w.(http.Hijacker).Hijack() - buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) - buf.Flush() - conn.Close() - })) - defer cst.close() - cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway - - var got bytes.Buffer - - req, _ := http.NewRequest("GET", cst.ts.URL, nil) - req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ - Got1xxResponse: func(code int, header textproto.MIMEHeader) error { - fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header) - return nil - }, - })) - res, err := cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - - res.Write(&got) - want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello" - if got.String() != want { - t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want) - } -} - -func TestTransportLimits1xxResponses(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, buf, _ := w.(http.Hijacker).Hijack() - for i := 0; i < 10; i++ { - buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) - } - buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) - buf.Flush() - conn.Close() - })) - defer cst.close() - cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway - - res, err := cst.c.Get(cst.ts.URL) - if res != nil { - defer res.Body.Close() - } - got := fmt.Sprint(err) - wantSub := "too many 1xx informational responses" - if !strings.Contains(got, wantSub) { - t.Errorf("Get error = %v; want substring %q", err, wantSub) - } -} - -// Issue 26161: the HTTP client must treat 101 responses -// as the final response. -func TestTransportTreat101Terminal(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, buf, _ := w.(http.Hijacker).Hijack() - buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n")) - buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) - buf.Flush() - conn.Close() - })) - defer cst.close() - res, err := cst.c.Get(cst.ts.URL) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - if res.StatusCode != http.StatusSwitchingProtocols { - t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode) - } -} - -type proxyFromEnvTest struct { - req string // URL to fetch; blank means "http://example.com" - - env string // HTTP_PROXY - httpsenv string // HTTPS_PROXY - noenv string // NO_PROXY - reqmeth string // REQUEST_METHOD - - want string - wanterr error -} - -func (t proxyFromEnvTest) String() string { - var buf bytes.Buffer - space := func() { - if buf.Len() > 0 { - buf.WriteByte(' ') - } - } - if t.env != "" { - fmt.Fprintf(&buf, "http_proxy=%q", t.env) - } - if t.httpsenv != "" { - space() - fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv) - } - if t.noenv != "" { - space() - fmt.Fprintf(&buf, "no_proxy=%q", t.noenv) - } - if t.reqmeth != "" { - space() - fmt.Fprintf(&buf, "request_method=%q", t.reqmeth) - } - req := "http://example.com" - if t.req != "" { - req = t.req - } - space() - fmt.Fprintf(&buf, "req=%q", req) - return strings.TrimSpace(buf.String()) -} - -var proxyFromEnvTests = []proxyFromEnvTest{ - {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"}, - {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"}, - {env: "cache.corp.example.com", want: "http://cache.corp.example.com"}, - {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"}, - {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"}, - {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"}, - {env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"}, - - // Don't use secure for http - {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"}, - // Use secure for https. - {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"}, - {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"}, - - // Issue 16405: don't use HTTP_PROXY in a CGI environment, - // where HTTP_PROXY can be attacker-controlled. - {env: "http://10.1.2.3:8080", reqmeth: "POST", - want: "", - wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")}, - - {want: ""}, - - {noenv: "example.com", req: "http://example.com/", env: "proxy", want: ""}, - {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, - {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, - {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: ""}, - {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, -} - -func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *http.Request) (*url.URL, error)) { - t.Helper() - reqURL := tt.req - if reqURL == "" { - reqURL = "http://example.com" - } - req, _ := http.NewRequest("GET", reqURL, nil) - url, err := proxyForRequest(req) - if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e { - t.Errorf("%v: got error = %q, want %q", tt, g, e) - return - } - if got := fmt.Sprintf("%s", url); got != tt.want { - t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want) - } -} - -func ResetProxyEnv() { - for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy", "REQUEST_METHOD"} { - os.Unsetenv(v) - } -} - -func TestProxyFromEnvironment(t *testing.T) { - ResetProxyEnv() - defer ResetProxyEnv() - for _, tt := range proxyFromEnvTests { - testProxyForRequest(t, tt, func(req *http.Request) (*url.URL, error) { - os.Setenv("HTTP_PROXY", tt.env) - os.Setenv("HTTPS_PROXY", tt.httpsenv) - os.Setenv("NO_PROXY", tt.noenv) - os.Setenv("REQUEST_METHOD", tt.reqmeth) - return httpproxy.FromEnvironment().ProxyFunc()(req.URL) - }) - } -} - -func TestProxyFromEnvironmentLowerCase(t *testing.T) { - ResetProxyEnv() - defer ResetProxyEnv() - for _, tt := range proxyFromEnvTests { - testProxyForRequest(t, tt, func(req *http.Request) (*url.URL, error) { - os.Setenv("http_proxy", tt.env) - os.Setenv("https_proxy", tt.httpsenv) - os.Setenv("no_proxy", tt.noenv) - os.Setenv("REQUEST_METHOD", tt.reqmeth) - return httpproxy.FromEnvironment().ProxyFunc()(req.URL) - }) - } -} - -func TestIdleConnChannelLeak(t *testing.T) { - // Not parallel: uses global test hooks. - var mu sync.Mutex - var n int - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mu.Lock() - n++ - mu.Unlock() - })) - defer ts.Close() - - const nReqs = 5 - didRead := make(chan bool, nReqs) - SetReadLoopBeforeNextReadHook(func() { didRead <- true }) - defer SetReadLoopBeforeNextReadHook(nil) - - c := tc().httpClient - tr := c.Transport.(*Transport) - tr.DialContext = func(_ context.Context, netw, addr string) (net.Conn, error) { - return net.Dial(netw, ts.Listener.Addr().String()) - } - - // First, without keep-alives. - for _, disableKeep := range []bool{true, false} { - tr.DisableKeepAlives = disableKeep - for i := 0; i < nReqs; i++ { - _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i)) - if err != nil { - t.Fatal(err) - } - // Note: no res.Body.Close is needed here, since the - // response Content-Length is zero. Perhaps the test - // should be more explicit and use a HEAD, but tests - // elsewhere guarantee that zero byte responses generate - // a "Content-Length: 0" instead of chunking. - } - - // At this point, each of the 5 Transport.readLoop goroutines - // are scheduling noting that there are no response bodies (see - // earlier comment), and are then calling putIdleConn, which - // decrements this count. Usually that happens quickly, which is - // why this test has seemed to work for ages. But it's still - // racey: we have wait for them to finish first. See Issue 10427 - for i := 0; i < nReqs; i++ { - <-didRead - } - - if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 { - t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got) - } - } -} - -// Verify the status quo: that the Client.Post function coerces its -// body into a ReadCloser if it's a Closer, and that the Transport -// then closes it. -func TestTransportClosesRequestBody(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.Copy(io.Discard, r.Body) - })) - defer ts.Close() - - c := tc().httpClient - - closes := 0 - - res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - if closes != 1 { - t.Errorf("closes = %d; want 1", closes) - } -} - -func TestTransportTLSHandshakeTimeout(t *testing.T) { - defer afterTest(t) - if testing.Short() { - t.Skip("skipping in short mode") - } - ln := tests.NewLocalListener(t) - defer ln.Close() - testdonec := make(chan struct{}) - defer close(testdonec) - - go func() { - c, err := ln.Accept() - if err != nil { - t.Error(err) - return - } - <-testdonec - c.Close() - }() - - getdonec := make(chan struct{}) - go func() { - defer close(getdonec) - tr := T().SetDial(func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial("tcp", ln.Addr().String()) - }).SetTLSHandshakeTimeout(250 * time.Millisecond) - cl := &http.Client{Transport: tr} - _, err := cl.Get("https://dummy.tld/") - if err == nil { - t.Error("expected error") - return - } - ue, ok := err.(*url.Error) - if !ok { - t.Errorf("expected url.Error; got %#v", err) - return - } - ne, ok := ue.Err.(net.Error) - if !ok { - t.Errorf("expected net.Error; got %#v", err) - return - } - if !ne.Timeout() { - t.Errorf("expected timeout error; got %v", err) - } - if !strings.Contains(err.Error(), "handshake timeout") { - t.Errorf("expected 'handshake timeout' in error; got %v", err) - } - }() - select { - case <-getdonec: - case <-time.After(5 * time.Second): - t.Error("test timeout; TLS handshake hung?") - } -} - -// Trying to repro golang.org/issue/3514 -func TestTLSServerClosesConnection(t *testing.T) { - defer afterTest(t) - - closedc := make(chan bool, 1) - ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "/keep-alive-then-die") { - conn, _, _ := w.(http.Hijacker).Hijack() - conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) - conn.Close() - closedc <- true - return - } - fmt.Fprintf(w, "hello") - })) - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - - var nSuccess = 0 - var errs []error - const trials = 20 - for i := 0; i < trials; i++ { - tr.CloseIdleConnections() - res, err := c.Get(ts.URL + "/keep-alive-then-die") - if err != nil { - t.Fatal(err) - } - <-closedc - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if string(slurp) != "foo" { - t.Errorf("Got %q, want foo", slurp) - } - - // Now try again and see if we successfully - // pick a new connection. - res, err = c.Get(ts.URL + "/") - if err != nil { - errs = append(errs, err) - continue - } - slurp, err = io.ReadAll(res.Body) - if err != nil { - errs = append(errs, err) - continue - } - nSuccess++ - } - if nSuccess > 0 { - t.Logf("successes = %d of %d", nSuccess, trials) - } else { - t.Errorf("All runs failed:") - } - for _, err := range errs { - t.Logf(" err: %v", err) - } -} - -// byteFromChanReader is an io.Reader that reads a single byte at a -// time from the channel. When the channel is closed, the reader -// returns io.EOF. -type byteFromChanReader chan byte - -func (c byteFromChanReader) Read(p []byte) (n int, err error) { - if len(p) == 0 { - return - } - b, ok := <-c - if !ok { - return 0, io.EOF - } - p[0] = b - return 1, nil -} - -// Verifies that the Transport doesn't reuse a connection in the case -// where the server replies before the request has been fully -// written. We still honor that reply (see TestIssue3595), but don't -// send future requests on the connection because it's then in a -// questionable state. -// golang.org/issue/7569 -func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) - var sconn struct { - sync.Mutex - c net.Conn - } - var getOkay bool - closeConn := func() { - sconn.Lock() - defer sconn.Unlock() - if sconn.c != nil { - sconn.c.Close() - sconn.c = nil - if !getOkay { - t.Logf("Closed server connection") - } - } - } - defer closeConn() - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == "GET" { - io.WriteString(w, "bar") - return - } - conn, _, _ := w.(http.Hijacker).Hijack() - sconn.Lock() - sconn.c = conn - sconn.Unlock() - conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive - go io.Copy(io.Discard, conn) - })) - defer ts.Close() - c := tc().httpClient - - const bodySize = 256 << 10 - finalBit := make(byteFromChanReader, 1) - req, _ := http.NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(tests.NeverEnding('x'), bodySize-1), finalBit)) - req.ContentLength = bodySize - res, err := c.Do(req) - if err := wantBody(res, err, "foo"); err != nil { - t.Errorf("POST response: %v", err) - } - donec := make(chan bool) - go func() { - defer close(donec) - res, err = c.Get(ts.URL) - if err := wantBody(res, err, "bar"); err != nil { - t.Errorf("GET response: %v", err) - return - } - getOkay = true // suppress test noise - }() - time.AfterFunc(5*time.Second, closeConn) - select { - case <-donec: - finalBit <- 'x' // unblock the writeloop of the first Post - close(finalBit) - case <-time.After(7 * time.Second): - t.Fatal("timeout waiting for GET request to finish") - } -} - -// Tests that we don't leak Transport persistConn.readLoop goroutines -// when a server hangs up immediately after saying it would keep-alive. -func TestTransportIssue10457(t *testing.T) { - defer afterTest(t) // used to fail in goroutine leak check - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Send a response with no body, keep-alive - // (implicit), and then lie and immediately close the - // connection. This forces the Transport's readLoop to - // immediately Peek an io.EOF and get to the point - // that used to hang. - conn, _, _ := w.(http.Hijacker).Hijack() - conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive - conn.Close() - })) - defer ts.Close() - c := tc().httpClient - - res, err := c.Get(ts.URL) - if err != nil { - t.Fatalf("Get: %v", err) - } - defer res.Body.Close() - - // Just a sanity check that we at least get the response. The real - // test here is that the "defer afterTest" above doesn't find any - // leaked goroutines. - if got, want := res.Header.Get("Foo"), "Bar"; got != want { - t.Errorf("Foo header = %q; want %q", got, want) - } -} - -type closerFunc func() error - -func (f closerFunc) Close() error { return f() } - -type writerFuncConn struct { - net.Conn - write func(p []byte) (n int, err error) -} - -func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) } - -func hookSetter(dst *func()) func(func()) { - return func(fn func()) { - unnilTestHook(&fn) - *dst = fn - } -} - -var ( - SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip) - SetRoundTripRetried = hookSetter(&testHookRoundTripRetried) -) - -// Issue 6981 -func TestTransportClosesBodyOnError(t *testing.T) { - setParallel(t) - defer afterTest(t) - readBody := make(chan error, 1) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := io.ReadAll(r.Body) - readBody <- err - })) - defer ts.Close() - c := tc().httpClient - fakeErr := errors.New("fake error") - didClose := make(chan bool, 1) - req, _ := http.NewRequest("POST", ts.URL, struct { - io.Reader - io.Closer - }{ - io.MultiReader(io.LimitReader(tests.NeverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)), - closerFunc(func() error { - select { - case didClose <- true: - default: - } - return nil - }), - }) - res, err := c.Do(req) - if res != nil { - defer res.Body.Close() - } - if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { - t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) - } - select { - case err := <-readBody: - if err == nil { - t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") - } - case <-time.After(5 * time.Second): - t.Error("timeout waiting for server handler to complete") - } - select { - case <-didClose: - default: - t.Errorf("didn't see Body.Close") - } -} - -func TestTransportDialTLS(t *testing.T) { - setParallel(t) - defer afterTest(t) - var mu sync.Mutex // guards following - var gotReq, didDial bool - - ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mu.Lock() - gotReq = true - mu.Unlock() - })) - defer ts.Close() - c := tc().httpClient - c.Transport.(*Transport).DialTLSContext = func(_ context.Context, netw, addr string) (net.Conn, error) { - mu.Lock() - didDial = true - mu.Unlock() - c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) - if err != nil { - return nil, err - } - return c, c.Handshake() - } - - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - mu.Lock() - if !gotReq { - t.Error("didn't get request") - } - if !didDial { - t.Error("didn't use dial hook") - } -} - -// Test for issue 8755 -// Ensure that if a proxy returns an error, it is exposed by RoundTrip -func TestRoundTripReturnsProxyError(t *testing.T) { - badProxy := func(*http.Request) (*url.URL, error) { - return nil, errors.New("errorMessage") - } - - tr := T().SetProxy(badProxy) - - req, _ := http.NewRequest("GET", "http://example.com", nil) - - _, err := tr.RoundTrip(req) - - if err == nil { - t.Error("Expected proxy error to be returned by RoundTrip") - } -} - -// tests that putting an idle conn after a call to CloseIdleConns does return it -func TestTransportCloseIdleConnsThenReturn(t *testing.T) { - tr := T() - wantIdle := func(when string, n int) bool { - got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn - if got == n { - return true - } - t.Errorf("%s: idle conns = %d; want %d", when, got, n) - return false - } - wantIdle("start", 0) - if !tr.PutIdleTestConn("http", "example.com") { - t.Fatal("put failed") - } - if !tr.PutIdleTestConn("http", "example.com") { - t.Fatal("second put failed") - } - wantIdle("after put", 2) - tr.CloseIdleConnections() - if !tr.IsIdleForTesting() { - t.Error("should be idle after CloseIdleConnections") - } - wantIdle("after close idle", 0) - if tr.PutIdleTestConn("http", "example.com") { - t.Fatal("put didn't fail") - } - wantIdle("after second put", 0) - - tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode - if tr.IsIdleForTesting() { - t.Error("shouldn't be idle after QueueForIdleConnForTesting") - } - if !tr.PutIdleTestConn("http", "example.com") { - t.Fatal("after re-activation") - } - wantIdle("after final put", 1) -} - -// Test for issue 34282 -// Ensure that getConn doesn't call the GotConn trace hook on a HTTP/2 idle conn -func TestTransportTraceGotConnH2IdleConns(t *testing.T) { - tr := T() - wantIdle := func(when string, n int) bool { - got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2 - if got == n { - return true - } - t.Errorf("%s: idle conns = %d; want %d", when, got, n) - return false - } - wantIdle("start", 0) - alt := funcRoundTripper(func() {}) - if !tr.PutIdleTestConnH2("https", "example.com:443", alt) { - t.Fatal("put failed") - } - wantIdle("after put", 1) - ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ - GotConn: func(httptrace.GotConnInfo) { - // tr.getConn should leave it for the HTTP/2 alt to call GotConn. - t.Error("GotConn called") - }, - }) - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "https://example.com", nil) - _, err := tr.RoundTrip(req) - if err != errFakeRoundTrip { - t.Errorf("got error: %v; want %q", err, errFakeRoundTrip) - } - wantIdle("after round trip", 1) -} - -func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { - if testing.Short() { - t.Skip("skipping in short mode") - } - - trFunc := func(tr *Transport) { - tr.MaxConnsPerHost = 1 - tr.MaxIdleConnsPerHost = 1 - tr.IdleConnTimeout = 10 * time.Millisecond - } - cst := newClientServerTest(t, h2Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), trFunc) - defer cst.close() - - if _, err := cst.c.Get(cst.ts.URL); err != nil { - t.Fatalf("got error: %s", err) - } - - time.Sleep(100 * time.Millisecond) - got := make(chan error) - go func() { - if _, err := cst.c.Get(cst.ts.URL); err != nil { - got <- err - } - close(got) - }() - - timeout := time.NewTimer(5 * time.Second) - defer timeout.Stop() - select { - case err := <-got: - if err != nil { - t.Fatalf("got error: %s", err) - } - case <-timeout.C: - t.Fatal("request never completed") - } -} - -// This tests that a client requesting a content range won't also -// implicitly ask for gzip support. If they want that, they need to do it -// on their own. -// golang.org/issue/8923 -func TestTransportRangeAndGzip(t *testing.T) { - defer afterTest(t) - reqc := make(chan *http.Request, 1) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - reqc <- r - })) - defer ts.Close() - c := tc().httpClient - - req, _ := http.NewRequest("GET", ts.URL, nil) - req.Header.Set("Range", "bytes=7-11") - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - - select { - case r := <-reqc: - if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { - t.Error("Transport advertised gzip support in the Accept header") - } - if r.Header.Get("Range") == "" { - t.Error("no Range in request") - } - case <-time.After(10 * time.Second): - t.Fatal("timeout") - } - res.Body.Close() -} - -// Test for issue 10474 -func TestTransportResponseCancelRace(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // important that this response has a body. - var b [1024]byte - w.Write(b[:]) - })) - defer ts.Close() - tr := tc().GetTransport() - - req, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - // If we do an early close, Transport just throws the connection away and - // doesn't reuse it. In order to trigger the bug, it has to reuse the connection - // so read the body - if _, err := io.Copy(io.Discard, res.Body); err != nil { - t.Fatal(err) - } - - req2, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - tr.CancelRequest(req) - res, err = tr.RoundTrip(req2) - if err != nil { - t.Fatal(err) - } - res.Body.Close() -} - -// Test for issue 19248: Content-Encoding's value is case insensitive. -func TestTransportContentEncodingCaseInsensitive(t *testing.T) { - setParallel(t) - defer afterTest(t) - for _, ce := range []string{"gzip", "GZIP"} { - ce := ce - t.Run(ce, func(t *testing.T) { - const encodedString = "Hello Gopher" - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Encoding", ce) - gz := gzip.NewWriter(w) - gz.Write([]byte(encodedString)) - gz.Close() - })) - defer ts.Close() - - res, err := ts.Client().Get(ts.URL) - if err != nil { - t.Fatal(err) - } - - body, err := io.ReadAll(res.Body) - res.Body.Close() - if err != nil { - t.Fatal(err) - } - - if string(body) != encodedString { - t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body)) - } - }) - } -} - -func TestTransportDialCancelRace(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer ts.Close() - tr := tc().GetTransport() - - req, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - SetEnterRoundTripHook(func() { - tr.CancelRequest(req) - }) - defer SetEnterRoundTripHook(nil) - res, err := tr.RoundTrip(req) - if err != common.ErrRequestCanceled { - t.Errorf("expected canceled request error; got %v", err) - if err == nil { - res.Body.Close() - } - } -} - -// logWritesConn is a net.Conn that logs each Write call to writes -// and then proxies to w. -// It proxies Read calls to a reader it receives from rch. -type logWritesConn struct { - net.Conn // nil. crash on use. - - w io.Writer - - rch <-chan io.Reader - r io.Reader // nil until received by rch - - mu sync.Mutex - writes []string -} - -func (c *logWritesConn) Write(p []byte) (n int, err error) { - c.mu.Lock() - defer c.mu.Unlock() - c.writes = append(c.writes, string(p)) - return c.w.Write(p) -} - -func (c *logWritesConn) Read(p []byte) (n int, err error) { - if c.r == nil { - c.r = <-c.rch - } - return c.r.Read(p) -} - -func (c *logWritesConn) Close() error { return nil } - -// Issue 6574 -func TestTransportFlushesBodyChunks(t *testing.T) { - defer afterTest(t) - resBody := make(chan io.Reader, 1) - connr, connw := io.Pipe() // connection pipe pair - lw := &logWritesConn{ - rch: resBody, - w: connw, - } - tr := T().SetDial(func(_ context.Context, network, addr string) (net.Conn, error) { - return lw, nil - }) - bodyr, bodyw := io.Pipe() // body pipe pair - go func() { - defer bodyw.Close() - for i := 0; i < 3; i++ { - fmt.Fprintf(bodyw, "num%d\n", i) - } - }() - resc := make(chan *http.Response) - go func() { - req, _ := http.NewRequest("POST", "http://localhost:8080", bodyr) - req.Header.Set("User-Agent", "x") // known value for test - res, err := tr.RoundTrip(req) - if err != nil { - t.Errorf("RoundTrip: %v", err) - close(resc) - return - } - resc <- res - - }() - // Fully consume the request before checking the Write log vs. want. - req, err := http.ReadRequest(bufio.NewReader(connr)) - if err != nil { - t.Fatal(err) - } - io.Copy(io.Discard, req.Body) - - // Unblock the transport's roundTrip goroutine. - resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") - res, ok := <-resc - if !ok { - return - } - defer res.Body.Close() - - want := []string{ - "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n", - "5\r\nnum0\n\r\n", - "5\r\nnum1\n\r\n", - "5\r\nnum2\n\r\n", - "0\r\n\r\n", - } - if !reflect.DeepEqual(lw.writes, want) { - t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want) - } -} - -// Issue 22088: flush Transport request headers if we're not sure the body won't block on read. -func TestTransportFlushesRequestHeader(t *testing.T) { - defer afterTest(t) - gotReq := make(chan struct{}) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(gotReq) - })) - defer cst.close() - - pr, pw := io.Pipe() - req, err := http.NewRequest("POST", cst.ts.URL, pr) - if err != nil { - t.Fatal(err) - } - gotRes := make(chan struct{}) - go func() { - defer close(gotRes) - res, err := cst.tr.RoundTrip(req) - if err != nil { - t.Error(err) - return - } - res.Body.Close() - }() - - select { - case <-gotReq: - pw.Close() - case <-time.After(5 * time.Second): - t.Fatal("timeout waiting for handler to get request") - } - <-gotRes -} - -// Issue 11745. -func TestTransportPrefersResponseOverWriteError(t *testing.T) { - if testing.Short() { - t.Skip("skipping in short mode") - } - defer afterTest(t) - const contentLengthLimit = 1024 * 1024 // 1MB - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.ContentLength >= contentLengthLimit { - w.WriteHeader(http.StatusBadRequest) - r.Body.Close() - return - } - w.WriteHeader(http.StatusOK) - })) - defer ts.Close() - c := tc().httpClient - - fail := 0 - count := 100 - bigBody := strings.Repeat("a", contentLengthLimit*2) - for i := 0; i < count; i++ { - req, err := http.NewRequest("PUT", ts.URL, strings.NewReader(bigBody)) - if err != nil { - t.Fatal(err) - } - resp, err := c.Do(req) - if err != nil { - fail++ - t.Logf("%d = %#v", i, err) - if ue, ok := err.(*url.Error); ok { - t.Logf("urlErr = %#v", ue.Err) - if ne, ok := ue.Err.(*net.OpError); ok { - t.Logf("netOpError = %#v", ne.Err) - } - } - } else { - resp.Body.Close() - if resp.StatusCode != 400 { - t.Errorf("Expected status code 400, got %v", resp.Status) - } - } - } - if fail > 0 { - t.Errorf("Failed %v out of %v\n", fail, count) - } -} - -// Issue 13633: there was a race where we returned bodyless responses -// to callers before recycling the persistent connection, which meant -// a client doing two subsequent requests could end up on different -// connections. It's somewhat harmless but enough tests assume it's -// not true in order to test other things that it's worth fixing. -// Plus it's nice to be consistent and not have timing-dependent -// behavior. -func TestTransportReuseConnEmptyResponseBody(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Addr", r.RemoteAddr) - // Empty response body. - })) - defer cst.close() - n := 100 - if testing.Short() { - n = 10 - } - var firstAddr string - for i := 0; i < n; i++ { - res, err := cst.c.Get(cst.ts.URL) - if err != nil { - log.Fatal(err) - } - addr := res.Header.Get("X-Addr") - if i == 0 { - firstAddr = addr - } else if addr != firstAddr { - t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr) - } - res.Body.Close() - } -} - -func TestTransportReuseConnectionGzipChunked(t *testing.T) { - testTransportReuseConnectionGzip(t, true) -} - -func TestTransportReuseConnectionGzipContentLength(t *testing.T) { - testTransportReuseConnectionGzip(t, false) -} - -// Make sure we re-use underlying TCP connection for gzipped responses too. -func testTransportReuseConnectionGzip(t *testing.T, chunked bool) { - setParallel(t) - defer afterTest(t) - addr := make(chan string, 2) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - addr <- r.RemoteAddr - w.Header().Set("Content-Encoding", "gzip") - if chunked { - w.(http.Flusher).Flush() - } - w.Write(rgz) // arbitrary gzip response - })) - defer ts.Close() - c := tc().httpClient - - for i := 0; i < 2; i++ { - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - buf := make([]byte, len(rgz)) - if n, err := io.ReadFull(res.Body, buf); err != nil { - t.Errorf("%d. ReadFull = %v, %v", i, n, err) - } - // Note: no res.Body.Close call. It should work without it, - // since the flate.Reader's internal buffering will hit EOF - // and that should be sufficient. - } - a1, a2 := <-addr, <-addr - if a1 != a2 { - t.Fatalf("didn't reuse connection") - } -} - -func TestTransportResponseHeaderLength(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/long" { - w.Header().Set("Long", strings.Repeat("a", 1<<20)) - } - })) - defer ts.Close() - c := tc().httpClient - c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 - - if res, err := c.Get(ts.URL); err != nil { - t.Fatal(err) - } else { - res.Body.Close() - } - - res, err := c.Get(ts.URL + "/long") - if err == nil { - defer res.Body.Close() - var n int64 - for k, vv := range res.Header { - for _, v := range vv { - n += int64(len(k)) + int64(len(v)) - } - } - t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n) - } - if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) { - t.Errorf("got error: %v; want %q", err, want) - } -} - -type lookupIPAltResolverKey struct{} - -func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { - defer afterTest(t) - const resBody = "some body" - gotWroteReqEvent := make(chan struct{}, 500) - cst := newClientServerTest(t, h2, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == "GET" { - // Do nothing for the second request. - return - } - if _, err := io.ReadAll(r.Body); err != nil { - t.Error(err) - } - if !noHooks { - select { - case <-gotWroteReqEvent: - case <-time.After(5 * time.Second): - t.Error("timeout waiting for WroteRequest event") - } - } - io.WriteString(w, resBody) - })) - defer cst.close() - - cst.tr.ExpectContinueTimeout = 1 * time.Second - - var mu sync.Mutex // guards buf - var buf bytes.Buffer - logf := func(format string, args ...interface{}) { - mu.Lock() - defer mu.Unlock() - fmt.Fprintf(&buf, format, args...) - buf.WriteByte('\n') - } - - addrStr := cst.ts.Listener.Addr().String() - ip, port, err := net.SplitHostPort(addrStr) - if err != nil { - t.Fatal(err) - } - - // Install a fake DNS server. - ctx := context.WithValue(context.Background(), lookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { - if host != "dns-is-faked.golang" { - t.Errorf("unexpected DNS host lookup for %q/%q", network, host) - return nil, nil - } - return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil - }) - - body := "some body" - req, _ := http.NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body)) - req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"} - trace := &httptrace.ClientTrace{ - GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) }, - GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) }, - GotFirstResponseByte: func() { logf("first response byte") }, - PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) }, - DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) }, - DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) }, - ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) }, - ConnectDone: func(network, addr string, err error) { - if err != nil { - t.Errorf("ConnectDone: %v", err) - } - logf("ConnectDone: connected to %s %s = %v", network, addr, err) - }, - WroteHeaderField: func(key string, value []string) { - logf("WroteHeaderField: %s: %v", key, value) - }, - WroteHeaders: func() { - logf("WroteHeaders") - }, - Wait100Continue: func() { logf("Wait100Continue") }, - Got100Continue: func() { logf("Got100Continue") }, - WroteRequest: func(e httptrace.WroteRequestInfo) { - logf("WroteRequest: %+v", e) - gotWroteReqEvent <- struct{}{} - }, - } - if h2 { - trace.TLSHandshakeStart = func() { logf("tls handshake start") } - trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) { - logf("tls handshake done. ConnectionState = %v \n err = %v", s, err) - } - } - if noHooks { - // zero out all func pointers, trying to get some path to crash - *trace = httptrace.ClientTrace{} - } - req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) - - req.Header.Set("Expect", "100-continue") - res, err := cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - logf("got roundtrip.response") - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - logf("consumed body") - if string(slurp) != resBody || res.StatusCode != 200 { - t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody) - } - res.Body.Close() - - if noHooks { - // Done at this point. Just testing a full HTTP - // requests can happen with a trace pointing to a zero - // ClientTrace, full of nil func pointers. - return - } - - mu.Lock() - got := buf.String() - mu.Unlock() - - wantOnce := func(sub string) { - if strings.Count(got, sub) != 1 { - t.Errorf("expected substring %q exactly once in output.", sub) - } - } - wantOnceOrMore := func(sub string) { - if strings.Count(got, sub) == 0 { - t.Errorf("expected substring %q at least once in output.", sub) - } - } - wantOnce("Getting conn for dns-is-faked.golang:" + port) - wantOnce("DNS start: {Host:dns-is-faked.golang}") - wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err: Coalesced:false}") - wantOnce("got conn: {") - wantOnceOrMore("Connecting to tcp " + addrStr) - wantOnceOrMore("connected to tcp " + addrStr + " = ") - wantOnce("Reused:false WasIdle:false IdleTime:0s") - wantOnce("first response byte") - if h2 { - wantOnce("tls handshake start") - wantOnce("tls handshake done") - } else { - wantOnce("PutIdleConn = ") - wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]") - // TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the - // WroteHeaderField hook is not yet implemented in h2.) - wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port)) - wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body))) - wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]") - wantOnce("WroteHeaderField: Accept-Encoding: [gzip]") - } - wantOnce("WroteHeaders") - wantOnce("Wait100Continue") - wantOnce("Got100Continue") - wantOnce("WroteRequest: {Err:}") - if strings.Contains(got, " to udp ") { - t.Errorf("should not see UDP (DNS) connections") - } - if t.Failed() { - t.Errorf("Output:\n%s", got) - } - - // And do a second request: - req, _ = http.NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil) - req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) - res, err = cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - if res.StatusCode != 200 { - t.Fatal(res.Status) - } - res.Body.Close() - - mu.Lock() - got = buf.String() - mu.Unlock() - - sub := "Getting conn for dns-is-faked.golang:" - if gotn, want := strings.Count(got, sub), 2; gotn != want { - t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got) - } - -} - -func TestTransportEventTraceTLSVerify(t *testing.T) { - var mu sync.Mutex - var buf bytes.Buffer - logf := func(format string, args ...interface{}) { - mu.Lock() - defer mu.Unlock() - fmt.Fprintf(&buf, format, args...) - buf.WriteByte('\n') - } - - ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Error("Unexpected request") - })) - defer ts.Close() - ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { - logf("%s", p) - return len(p), nil - }), "", 0) - - certpool := x509.NewCertPool() - certpool.AddCert(ts.Certificate()) - - tr := T().SetTLSClientConfig(&tls.Config{ - ServerName: "dns-is-faked.golang", - RootCAs: certpool, - }) - c := &http.Client{Transport: tr} - - trace := &httptrace.ClientTrace{ - TLSHandshakeStart: func() { logf("TLSHandshakeStart") }, - TLSHandshakeDone: func(s tls.ConnectionState, err error) { - logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err) - }, - } - - req, _ := http.NewRequest("GET", ts.URL, nil) - req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) - _, err := c.Do(req) - if err == nil { - t.Error("Expected request to fail TLS verification") - } - - mu.Lock() - got := buf.String() - mu.Unlock() - - wantOnce := func(sub string) { - if strings.Count(got, sub) != 1 { - t.Errorf("expected substring %q exactly once in output.", sub) - } - } - - wantOnce("TLSHandshakeStart") - wantOnce("TLSHandshakeDone") - wantOnce("x509: certificate is valid for example.com") - - if t.Failed() { - t.Errorf("Output:\n%s", got) - } -} - -var ( - isDNSHijackedOnce sync.Once - isDNSHijacked bool -) - -func skipIfDNSHijacked(t *testing.T) { - // Skip this test if the user is using a shady/ISP - // DNS server hijacking queries. - // See issues 16732, 16716. - isDNSHijackedOnce.Do(func() { - addrs, _ := net.LookupHost("dns-should-not-resolve.golang") - isDNSHijacked = len(addrs) != 0 - }) - if isDNSHijacked { - t.Skip("skipping; test requires non-hijacking DNS server") - } -} - -func TestTransportEventTraceRealDNS(t *testing.T) { - skipIfDNSHijacked(t) - defer afterTest(t) - tr := T() - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - - var mu sync.Mutex // guards buf - var buf bytes.Buffer - logf := func(format string, args ...interface{}) { - mu.Lock() - defer mu.Unlock() - fmt.Fprintf(&buf, format, args...) - buf.WriteByte('\n') - } - - req, _ := http.NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil) - trace := &httptrace.ClientTrace{ - DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) }, - DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) }, - ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) }, - ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) }, - } - req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) - - resp, err := c.Do(req) - if err == nil { - resp.Body.Close() - t.Fatal("expected error during DNS lookup") - } - - mu.Lock() - got := buf.String() - mu.Unlock() - - wantSub := func(sub string) { - if !strings.Contains(got, sub) { - t.Errorf("expected substring %q in output.", sub) - } - } - wantSub("DNSStart: {Host:dns-should-not-resolve.golang}") - wantSub("DNSDone: {Addrs:[] Err:") - if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") { - t.Errorf("should not see Connect events") - } - if t.Failed() { - t.Errorf("Output:\n%s", got) - } -} - -// Issue 14353: port can only contain digits. -func TestTransportRejectsAlphaPort(t *testing.T) { - res, err := http.Get("http://dummy.tld:123foo/bar") - if err == nil { - res.Body.Close() - t.Fatal("unexpected success") - } - ue, ok := err.(*url.Error) - if !ok { - t.Fatalf("got %#v; want *url.Error", err) - } - got := ue.Err.Error() - want := `invalid port ":123foo" after host` - if got != want { - t.Errorf("got error %q; want %q", got, want) - } -} - -// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 -// connections. The http2 test is done in TestTransportEventTrace_h2 -func TestTLSHandshakeTrace(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer ts.Close() - - var mu sync.Mutex - var start, done bool - trace := &httptrace.ClientTrace{ - TLSHandshakeStart: func() { - mu.Lock() - defer mu.Unlock() - start = true - }, - TLSHandshakeDone: func(s tls.ConnectionState, err error) { - mu.Lock() - defer mu.Unlock() - done = true - if err != nil { - t.Fatal("Expected error to be nil but was:", err) - } - }, - } - - c := tc().httpClient - req, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal("Unable to construct test request:", err) - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - - r, err := c.Do(req) - if err != nil { - t.Fatal("Unexpected error making request:", err) - } - r.Body.Close() - mu.Lock() - defer mu.Unlock() - if !start { - t.Fatal("Expected TLSHandshakeStart to be called, but wasn't") - } - if !done { - t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't") - } -} - -// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an -// HTTP/2 connection was established but its caller no longer -// wanted it. (Assuming the connection cache was enabled, which it is -// by default) -// -// This test reproduced the crash by setting the IdleConnTimeout low -// (to make the test reasonable) and then making a request which is -// canceled by the DialTLS hook, which then also waits to return the -// real connection until after the RoundTrip saw the error. Then we -// know the successful tls.Dial from DialTLS will need to go into the -// idle pool. Then we give it a of time to explode. -func TestIdleConnH2Crash(t *testing.T) { - setParallel(t) - cst := newClientServerTest(t, h2Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // nothing - })) - defer cst.close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - sawDoErr := make(chan bool, 1) - testDone := make(chan struct{}) - defer close(testDone) - - cst.tr.IdleConnTimeout = 5 * time.Millisecond - cst.tr.DialTLSContext = func(_ context.Context, network, addr string) (net.Conn, error) { - c, err := tls.Dial(network, addr, &tls.Config{ - InsecureSkipVerify: true, - NextProtos: []string{"h2"}, - }) - if err != nil { - t.Error(err) - return nil, err - } - if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" { - t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2") - c.Close() - return nil, errors.New("bogus") - } - - cancel() - - failTimer := time.NewTimer(5 * time.Second) - defer failTimer.Stop() - select { - case <-sawDoErr: - case <-testDone: - case <-failTimer.C: - t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail") - } - return c, nil - } - - req, _ := http.NewRequest("GET", cst.ts.URL, nil) - req = req.WithContext(ctx) - res, err := cst.c.Do(req) - if err == nil { - res.Body.Close() - t.Fatal("unexpected success") - } - sawDoErr <- true - - // Wait for the explosion. - time.Sleep(cst.tr.IdleConnTimeout * 10) -} - -type funcConn struct { - net.Conn - read func([]byte) (int, error) - write func([]byte) (int, error) -} - -func (c funcConn) Read(p []byte) (int, error) { return c.read(p) } -func (c funcConn) Write(p []byte) (int, error) { return c.write(p) } -func (c funcConn) Close() error { return nil } - -// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek -// back to the caller. -func TestTransportReturnsPeekError(t *testing.T) { - errValue := errors.New("specific error value") - - wrote := make(chan struct{}) - var wroteOnce sync.Once - - tr := T().SetDial(func(_ context.Context, network, addr string) (net.Conn, error) { - c := funcConn{ - read: func([]byte) (int, error) { - <-wrote - return 0, errValue - }, - write: func(p []byte) (int, error) { - wroteOnce.Do(func() { close(wrote) }) - return len(p), nil - }, - } - return c, nil - }) - - _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)) - if err != errValue { - t.Errorf("error = %#v; want %v", err, errValue) - } -} - -// Issue 13290: send User-Agent in proxy CONNECT -func TestTransportProxyConnectHeader(t *testing.T) { - defer afterTest(t) - reqc := make(chan *http.Request, 1) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "CONNECT" { - t.Errorf("method = %q; want CONNECT", r.Method) - } - reqc <- r - c, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Errorf("Hijack: %v", err) - return - } - c.Close() - })) - defer ts.Close() - - c := tc().httpClient - c.Transport.(*Transport).Proxy = func(r *http.Request) (*url.URL, error) { - return url.Parse(ts.URL) - } - c.Transport.(*Transport).ProxyConnectHeader = http.Header{ - "User-Agent": {"foo"}, - "Other": {"bar"}, - } - - res, err := c.Get("https://dummy.tld/") // https to force a CONNECT - if err == nil { - res.Body.Close() - t.Errorf("unexpected success") - } - select { - case <-time.After(3 * time.Second): - t.Fatal("timeout") - case r := <-reqc: - if got, want := r.Header.Get("User-Agent"), "foo"; got != want { - t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) - } - if got, want := r.Header.Get("Other"), "bar"; got != want { - t.Errorf("CONNECT request Other = %q; want %q", got, want) - } - } -} - -func TestTransportProxyGetConnectHeader(t *testing.T) { - defer afterTest(t) - reqc := make(chan *http.Request, 1) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "CONNECT" { - t.Errorf("method = %q; want CONNECT", r.Method) - } - reqc <- r - c, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Errorf("Hijack: %v", err) - return - } - c.Close() - })) - defer ts.Close() - - c := tc().httpClient - c.Transport.(*Transport).Proxy = func(r *http.Request) (*url.URL, error) { - return url.Parse(ts.URL) - } - // These should be ignored: - c.Transport.(*Transport).ProxyConnectHeader = http.Header{ - "User-Agent": {"foo"}, - "Other": {"bar"}, - } - c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) { - return http.Header{ - "User-Agent": {"foo2"}, - "Other": {"bar2"}, - }, nil - } - - res, err := c.Get("https://dummy.tld/") // https to force a CONNECT - if err == nil { - res.Body.Close() - t.Errorf("unexpected success") - } - select { - case <-time.After(3 * time.Second): - t.Fatal("timeout") - case r := <-reqc: - if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { - t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) - } - if got, want := r.Header.Get("Other"), "bar2"; got != want { - t.Errorf("CONNECT request Other = %q; want %q", got, want) - } - } -} - -var errFakeRoundTrip = errors.New("fake roundtrip") - -type funcRoundTripper func() - -func (fn funcRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { - fn() - return nil, errFakeRoundTrip -} - -func wantBody(res *http.Response, err error, want string) error { - if err != nil { - return err - } - slurp, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("error reading body: %v", err) - } - if string(slurp) != want { - return fmt.Errorf("body = %q; want %q", slurp, want) - } - if err := res.Body.Close(); err != nil { - return fmt.Errorf("body Close = %v", err) - } - return nil -} - -type countCloseReader struct { - n *int - io.Reader -} - -func (cr countCloseReader) Close() error { - (*cr.n)++ - return nil -} - -// rgz is a gzip quine that uncompresses to itself. -var rgz = []byte{ - 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, - 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, - 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, - 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, - 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, - 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, - 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, - 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, - 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, - 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, - 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, - 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, - 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, - 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, - 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, - 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, - 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, - 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, - 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, - 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, - 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, - 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, - 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, - 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, - 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, - 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, - 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, - 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, - 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, - 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, - 0x00, 0x00, -} - -// Ensure that a missing status doesn't make the server panic -// See Issue https://golang.org/issues/21701 -func TestMissingStatusNoPanic(t *testing.T) { - t.Parallel() - - const want = "unknown status code" - - ln := tests.NewLocalListener(t) - addr := ln.Addr().String() - done := make(chan bool) - fullAddrURL := fmt.Sprintf("http://%s", addr) - raw := "HTTP/1.1 400\r\n" + - "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + - "Content-Type: text/html; charset=utf-8\r\n" + - "Content-Length: 10\r\n" + - "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" + - "Vary: Accept-Encoding\r\n\r\n" + - "Aloha Olaa" - - go func() { - defer close(done) - - conn, _ := ln.Accept() - if conn != nil { - io.WriteString(conn, raw) - io.ReadAll(conn) - conn.Close() - } - }() - - proxyURL, err := url.Parse(fullAddrURL) - if err != nil { - t.Fatalf("proxyURL: %v", err) - } - - tr := T().SetProxy(http.ProxyURL(proxyURL)) - - req, _ := http.NewRequest("GET", "https://golang.org/", nil) - res, err, panicked := doFetchCheckPanic(tr, req) - if panicked { - t.Error("panicked, expecting an error") - } - if res != nil && res.Body != nil { - io.Copy(io.Discard, res.Body) - res.Body.Close() - } - - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("got=%v want=%q", err, want) - } - - ln.Close() - <-done -} - -func doFetchCheckPanic(tr *Transport, req *http.Request) (res *http.Response, err error, panicked bool) { - defer func() { - if r := recover(); r != nil { - panicked = true - } - }() - res, err = tr.RoundTrip(req) - return -} - -// Issue 22330: do not allow the response body to be read when the status code -// forbids a response body. -func TestNoBodyOnChunked304Response(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, buf, _ := w.(http.Hijacker).Hijack() - buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) - buf.Flush() - conn.Close() - })) - defer cst.close() - - // Our test server above is sending back bogus data after the - // response (the "0\r\n\r\n" part), which causes the Transport - // code to log spam. Disable keep-alives so we never even try - // to reuse the connection. - cst.tr.DisableKeepAlives = true - - res, err := cst.c.Get(cst.ts.URL) - if err != nil { - t.Fatal(err) - } - - if res.Body != NoBody { - t.Errorf("Unexpected body on 304 response") - } -} - -type funcWriter func([]byte) (int, error) - -func (f funcWriter) Write(p []byte) (int, error) { return f(p) } - -type doneContext struct { - context.Context - err error -} - -func (doneContext) Done() <-chan struct{} { - c := make(chan struct{}) - close(c) - return c -} - -func (d doneContext) Err() error { return d.err } - -// Issue 25852: Transport should check whether Context is done early. -func TestTransportCheckContextDoneEarly(t *testing.T) { - tr := T() - req, _ := http.NewRequest("GET", "http://fake.example/", nil) - wantErr := errors.New("some error") - req = req.WithContext(doneContext{context.Background(), wantErr}) - _, err := tr.RoundTrip(req) - if err != wantErr { - t.Errorf("error = %v; want %v", err, wantErr) - } -} - -// Issue 23399: verify that if a client request times out, the Transport's -// conn is closed so that it's not reused. -// -// This is the test variant that times out before the server replies with -// any response headers. -func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { - setParallel(t) - defer afterTest(t) - inHandler := make(chan net.Conn, 1) - handlerReadReturned := make(chan bool, 1) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Error(err) - return - } - inHandler <- conn - n, err := conn.Read([]byte{0}) - if n != 0 || err != io.EOF { - t.Errorf("unexpected Read result: %v, %v", n, err) - } - handlerReadReturned <- true - })) - defer cst.close() - - const timeout = 50 * time.Millisecond - cst.c.Timeout = timeout - - _, err := cst.c.Get(cst.ts.URL) - if err == nil { - t.Fatal("unexpected Get succeess") - } - - select { - case c := <-inHandler: - select { - case <-handlerReadReturned: - // Success. - return - case <-time.After(5 * time.Second): - t.Error("Handler's conn.Read seems to be stuck in Read") - c.Close() // close it to unblock Handler - } - case <-time.After(timeout * 10): - // If we didn't get into the Handler in 50ms, that probably means - // the builder was just slow and the Get failed in that time - // but never made it to the server. That's fine. We'll usually - // test the part above on faster machines. - t.Skip("skipping test on slow builder") - } -} - -// Issue 23399: verify that if a client request times out, the Transport's -// conn is closed so that it's not reused. -// -// This is the test variant that has the server send response headers -// first, and time out during the write of the response body. -func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { - setParallel(t) - defer afterTest(t) - inHandler := make(chan net.Conn, 1) - handlerResult := make(chan error, 1) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", "100") - w.(http.Flusher).Flush() - conn, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Error(err) - return - } - conn.Write([]byte("foo")) - inHandler <- conn - n, err := conn.Read([]byte{0}) - // The error should be io.EOF or "read tcp - // 127.0.0.1:35827->127.0.0.1:40290: read: connection - // reset by peer" depending on timing. Really we just - // care that it returns at all. But if it returns with - // data, that's weird. - if n != 0 || err == nil { - handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err) - return - } - handlerResult <- nil - })) - defer cst.close() - - // Set Timeout to something very long but non-zero to exercise - // the codepaths that check for it. But rather than wait for it to fire - // (which would make the test slow), we send on the req.Cancel channel instead, - // which happens to exercise the same code paths. - cst.c.Timeout = time.Minute // just to be non-zero, not to hit it. - req, _ := http.NewRequest("GET", cst.ts.URL, nil) - cancel := make(chan struct{}) - req.Cancel = cancel - - res, err := cst.c.Do(req) - if err != nil { - select { - case <-inHandler: - t.Fatalf("Get error: %v", err) - default: - // Failed before entering handler. Ignore result. - t.Skip("skipping test on slow builder") - } - } - - close(cancel) - got, err := io.ReadAll(res.Body) - if err == nil { - t.Fatalf("unexpected success; read %q, nil", got) - } - - select { - case c := <-inHandler: - select { - case err := <-handlerResult: - if err != nil { - t.Errorf("handler: %v", err) - } - return - case <-time.After(5 * time.Second): - t.Error("Handler's conn.Read seems to be stuck in Read") - c.Close() // close it to unblock Handler - } - case <-time.After(5 * time.Second): - t.Fatal("timeout") - } -} - -func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { - setParallel(t) - defer afterTest(t) - done := make(chan struct{}) - defer close(done) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Error(err) - return - } - defer conn.Close() - io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n") - bs := bufio.NewScanner(conn) - bs.Scan() - fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) - <-done - })) - defer cst.close() - - req, _ := http.NewRequest("GET", cst.ts.URL, nil) - req.Header.Set("Upgrade", "foo") - req.Header.Set("Connection", "upgrade") - res, err := cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - if res.StatusCode != 101 { - t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header) - } - rwc, ok := res.Body.(io.ReadWriteCloser) - if !ok { - t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body) - } - defer rwc.Close() - bs := bufio.NewScanner(rwc) - if !bs.Scan() { - t.Fatalf("expected readable input") - } - if got, want := bs.Text(), "Some buffered data"; got != want { - t.Errorf("read %q; want %q", got, want) - } - io.WriteString(rwc, "echo\n") - if !bs.Scan() { - t.Fatalf("expected another line") - } - if got, want := bs.Text(), "ECHO"; got != want { - t.Errorf("read %q; want %q", got, want) - } -} - -func TestTransportRequestReplayable(t *testing.T) { - someBody := io.NopCloser(strings.NewReader("")) - tests := []struct { - name string - req *http.Request - want bool - }{ - { - name: "GET", - req: &http.Request{Method: "GET"}, - want: true, - }, - { - name: "GET_http.NoBody", - req: &http.Request{Method: "GET", Body: NoBody}, - want: true, - }, - { - name: "GET_body", - req: &http.Request{Method: "GET", Body: someBody}, - want: false, - }, - { - name: "POST", - req: &http.Request{Method: "POST"}, - want: false, - }, - { - name: "POST_idempotency-key", - req: &http.Request{Method: "POST", Header: http.Header{"Idempotency-Key": {"x"}}}, - want: true, - }, - { - name: "POST_x-idempotency-key", - req: &http.Request{Method: "POST", Header: http.Header{"X-Idempotency-Key": {"x"}}}, - want: true, - }, - { - name: "POST_body", - req: &http.Request{Method: "POST", Header: http.Header{"Idempotency-Key": {"x"}}, Body: someBody}, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := isReplayable(tt.req) - if got != tt.want { - t.Errorf("replyable = %v; want %v", got, tt.want) - } - }) - } -} - -// testMockTCPConn is a mock TCP connection used to test that -// ReadFrom is called when sending the request body. -type testMockTCPConn struct { - *net.TCPConn - - ReadFromCalled bool -} - -func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { - c.ReadFromCalled = true - return c.TCPConn.ReadFrom(r) -} - -func TestTransportRequestWriteRoundTrip(t *testing.T) { - nBytes := int64(1 << 10) - newFileFunc := func() (r io.Reader, done func(), err error) { - f, err := os.CreateTemp("", "net-http-newfilefunc") - if err != nil { - return nil, nil, err - } - - // Write some bytes to the file to enable reading. - if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { - return nil, nil, fmt.Errorf("failed to write data to file: %v", err) - } - if _, err := f.Seek(0, 0); err != nil { - return nil, nil, fmt.Errorf("failed to seek to front: %v", err) - } - - done = func() { - f.Close() - os.Remove(f.Name()) - } - - return f, done, nil - } - - newBufferFunc := func() (io.Reader, func(), error) { - return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil - } - - cases := []struct { - name string - readerFunc func() (io.Reader, func(), error) - contentLength int64 - expectedReadFrom bool - }{ - { - name: "file, length", - readerFunc: newFileFunc, - contentLength: nBytes, - expectedReadFrom: true, - }, - { - name: "file, no length", - readerFunc: newFileFunc, - }, - { - name: "file, negative length", - readerFunc: newFileFunc, - contentLength: -1, - }, - { - name: "buffer", - contentLength: nBytes, - readerFunc: newBufferFunc, - }, - { - name: "buffer, no length", - readerFunc: newBufferFunc, - }, - { - name: "buffer, length -1", - contentLength: -1, - readerFunc: newBufferFunc, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - r, cleanup, err := tc.readerFunc() - if err != nil { - t.Fatal(err) - } - defer cleanup() - - tConn := &testMockTCPConn{} - trFunc := func(tr *Transport) { - tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - var d net.Dialer - conn, err := d.DialContext(ctx, network, addr) - if err != nil { - return nil, err - } - - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr) - } - - tConn.TCPConn = tcpConn - return tConn, nil - } - } - - cst := newClientServerTest( - t, - h1Mode, - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.Copy(io.Discard, r.Body) - r.Body.Close() - w.WriteHeader(200) - }), - trFunc, - ) - defer cst.close() - - req, err := http.NewRequest("PUT", cst.ts.URL, r) - if err != nil { - t.Fatal(err) - } - req.ContentLength = tc.contentLength - req.Header.Set("Content-Type", "application/octet-stream") - resp, err := cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - t.Fatalf("status code = %d; want 200", resp.StatusCode) - } - - if !tConn.ReadFromCalled && tc.expectedReadFrom { - t.Fatalf("did not call ReadFrom") - } - - if tConn.ReadFromCalled && !tc.expectedReadFrom { - t.Fatalf("ReadFrom was unexpectedly invoked") - } - }) - } -} - -func TestTransportClone(t *testing.T) { - tr := &Transport{ - Headers: http.Header{ - "test-key": []string{"test-value"}, - }, - Cookies: []*http.Cookie{ - { - Name: "test", - Value: "test", - }, - }, - forceHttpVersion: h1, - Options: transport.Options{ - Proxy: func(*http.Request) (*url.URL, error) { panic("") }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, - DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, - TLSClientConfig: new(tls.Config), - TLSHandshakeTimeout: time.Second, - DisableKeepAlives: true, - DisableCompression: true, - MaxIdleConns: 1, - MaxIdleConnsPerHost: 1, - MaxConnsPerHost: 1, - IdleConnTimeout: time.Second, - ResponseHeaderTimeout: time.Second, - ExpectContinueTimeout: time.Second, - ProxyConnectHeader: http.Header{}, - GetProxyConnectHeader: func(context.Context, *url.URL, string) (http.Header, error) { return nil, nil }, - MaxResponseHeaderBytes: 1, - ReadBufferSize: 1, - WriteBufferSize: 1, - Debugf: func(format string, v ...interface{}) {}, - }, - } - tr2 := tr.Clone() - rv := reflect.ValueOf(tr2).Elem() - rt := rv.Type() - for i := 0; i < rt.NumField(); i++ { - sf := rt.Field(i) - if !token.IsExported(sf.Name) { - continue - } - if rv.Field(i).IsZero() { - t.Errorf("cloned field t2.%s is zero", sf.Name) - } - } - - // But test that a nil TLSNextProto is kept nil: - tr = new(Transport) - tr2 = tr.Clone() -} - -func TestIs408(t *testing.T) { - tests := []struct { - in string - want bool - }{ - {"HTTP/1.0 408", true}, - {"HTTP/1.1 408", true}, - {"HTTP/1.8 408", true}, - {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now. - {"HTTP/1.1 408 ", true}, - {"HTTP/1.1 40", false}, - {"http/1.0 408", false}, - {"HTTP/1-1 408", false}, - } - for _, tt := range tests { - if got := is408Message([]byte(tt.in)); got != tt.want { - t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want) - } - } -} - -func TestTransportIgnores408(t *testing.T) { - // Not parallel. Relies on mutating the log package's global Output. - defer log.SetOutput(log.Writer()) - - var logout bytes.Buffer - log.SetOutput(&logout) - - defer afterTest(t) - const target = "backend:443" - - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - nc, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Error(err) - return - } - defer nc.Close() - nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) - nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail - })) - defer cst.close() - req, err := http.NewRequest("GET", cst.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - res, err := cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if err != nil { - t.Fatal(err) - } - if string(slurp) != "ok" { - t.Fatalf("got %q; want ok", slurp) - } - - t0 := time.Now() - for i := 0; i < 50; i++ { - time.Sleep(time.Duration(i) * 5 * time.Millisecond) - if cst.tr.IdleConnKeyCountForTesting() == 0 { - if got := logout.String(); got != "" { - t.Fatalf("expected no log output; got: %s", got) - } - return - } - } - t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0)) -} - -func TestInvalidHeaderResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, buf, _ := w.(http.Hijacker).Hijack() - buf.Write([]byte("HTTP/1.1 200 OK\r\n" + - "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + - "Content-Type: text/html; charset=utf-8\r\n" + - "Content-Length: 0\r\n" + - "Foo : bar\r\n\r\n")) - buf.Flush() - conn.Close() - })) - defer cst.close() - res, err := cst.c.Get(cst.ts.URL) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - if v := res.Header.Get("Foo"); v != "" { - t.Errorf(`unexpected "Foo" header: %q`, v) - } - if v := res.Header.Get("Foo "); v != "bar" { - t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar") - } -} - -type bodyCloser bool - -func (bc *bodyCloser) Close() error { - *bc = true - return nil -} -func (bc *bodyCloser) Read(b []byte) (n int, err error) { - return 0, io.EOF -} - -// Issue 35015: ensure that Transport closes the body on any error -// with an invalid request, as promised by Client.Do docs. -func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { - cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Errorf("Should not have been invoked") - })) - defer cst.Close() - - u, _ := url.Parse(cst.URL) - - tests := []struct { - name string - req *http.Request - wantErr string - }{ - { - name: "invalid method", - req: &http.Request{ - Method: " ", - URL: u, - }, - wantErr: "invalid method", - }, - { - name: "nil URL", - req: &http.Request{ - Method: "GET", - }, - wantErr: "nil Request.URL", - }, - { - name: "invalid header key", - req: &http.Request{ - Method: "GET", - Header: http.Header{"💡": {"emoji"}}, - URL: u, - }, - wantErr: "invalid header field name", - }, - { - name: "invalid header value", - req: &http.Request{ - Method: "POST", - Header: http.Header{"key": {"\x19"}}, - URL: u, - }, - wantErr: "invalid header field value", - }, - { - name: "non HTTP(s) scheme", - req: &http.Request{ - Method: "POST", - URL: &url.URL{Scheme: "faux"}, - }, - wantErr: "unsupported protocol scheme", - }, - { - name: "no Host in URL", - req: &http.Request{ - Method: "POST", - URL: &url.URL{Scheme: "http"}, - }, - wantErr: "no Host", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var bc bodyCloser - req := tt.req - req.Body = &bc - _, err := DefaultClient().httpClient.Do(tt.req) - if err == nil { - t.Fatal("Expected an error") - } - if !bc { - t.Fatal("Expected body to have been closed") - } - if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) { - t.Fatalf("Error mismatch\n\t%q\ndoes not contain\n\t%q", g, w) - } - }) - } -} - -// breakableConn is a net.Conn wrapper with a Write method -// that will fail when its brokenState is true. -type breakableConn struct { - net.Conn - *brokenState -} - -type brokenState struct { - sync.Mutex - broken bool -} - -func (w *breakableConn) Write(b []byte) (n int, err error) { - w.Lock() - defer w.Unlock() - if w.broken { - return 0, errors.New("some write error") - } - return w.Conn.Write(b) -} - -// Issue 34978: don't cache a broken HTTP/2 connection -func TestDontCacheBrokenHTTP2Conn(t *testing.T) { - cst := newClientServerTest(t, h2Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), optQuietLog) - defer cst.close() - - var brokenState brokenState - - const numReqs = 5 - var numDials, gotConns uint32 // atomic - - cst.tr.DialContext = func(_ context.Context, netw, addr string) (net.Conn, error) { - atomic.AddUint32(&numDials, 1) - c, err := net.Dial(netw, addr) - if err != nil { - t.Errorf("unexpected Dial error: %v", err) - return nil, err - } - return &breakableConn{c, &brokenState}, err - } - - for i := 1; i <= numReqs; i++ { - brokenState.Lock() - brokenState.broken = false - brokenState.Unlock() - - // doBreak controls whether we break the TCP connection after the TLS - // handshake (before the HTTP/2 handshake). We test a few failures - // in a row followed by a final success. - doBreak := i != numReqs - - ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ - GotConn: func(info httptrace.GotConnInfo) { - t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime) - atomic.AddUint32(&gotConns, 1) - }, - TLSHandshakeDone: func(cfg tls.ConnectionState, err error) { - brokenState.Lock() - defer brokenState.Unlock() - if doBreak { - brokenState.broken = true - } - }, - }) - req, err := http.NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - _, err = cst.c.Do(req) - if doBreak != (err != nil) { - t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err) - } - } - if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want { - t.Errorf("GotConn calls = %v; want %v", got, want) - } - if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want { - t.Errorf("Dials = %v; want %v", got, want) - } -} - -// Issue 34941 -// When the client has too many concurrent requests on a single connection, -// http.http2noCachedConnError is reported on multiple requests. There should -// only be one decrement regardless of the number of failures. -func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { - defer afterTest(t) - - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("foo")) - if err != nil { - t.Fatalf("Write: %v", err) - } - }) - - ts := httptest.NewUnstartedServer(h) - ts.EnableHTTP2 = true - ts.StartTLS() - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - tr.MaxConnsPerHost = 1 - - errCh := make(chan error, 300) - doReq := func() { - resp, err := c.Get(ts.URL) - if err != nil { - errCh <- fmt.Errorf("request failed: %v", err) - return - } - defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) - if err != nil { - errCh <- fmt.Errorf("read body failed: %v", err) - } - } - - var wg sync.WaitGroup - for i := 0; i < 300; i++ { - wg.Add(1) - go func() { - defer wg.Done() - doReq() - }() - } - wg.Wait() - close(errCh) - - for err := range errCh { - t.Errorf("error occurred: %v", err) - } -} - -// Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers -// that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. -func TestTransportRejectsSignInContentLength(t *testing.T) { - cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", "+3") - w.Write([]byte("abc")) - })) - defer cst.Close() - - c := cst.Client() - res, err := c.Get(cst.URL) - if err == nil || res != nil { - t.Fatal("Expected a non-nil error and a nil http.Response") - } - if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) { - t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want) - } -} - -// dumpConn is a net.Conn which writes to Writer and reads from Reader -type dumpConn struct { - io.Writer - io.Reader -} - -func (c *dumpConn) Close() error { return nil } -func (c *dumpConn) LocalAddr() net.Addr { return nil } -func (c *dumpConn) RemoteAddr() net.Addr { return nil } -func (c *dumpConn) SetDeadline(t time.Time) error { return nil } -func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } -func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } - -// delegateReader is a reader that delegates to another reader, -// once it arrives on a channel. -type delegateReader struct { - c chan io.Reader - r io.Reader // nil until received from c -} - -func (r *delegateReader) Read(p []byte) (int, error) { - if r.r == nil { - var ok bool - if r.r, ok = <-r.c; !ok { - return 0, errors.New("delegate closed") - } - } - return r.r.Read(p) -} - -func testTransportRace(req *http.Request) { - save := req.Body - pr, pw := io.Pipe() - defer pr.Close() - defer pw.Close() - dr := &delegateReader{c: make(chan io.Reader)} - - t := T().SetDial(func(_ context.Context, net, addr string) (net.Conn, error) { - return &dumpConn{pw, dr}, nil - }) - defer t.CloseIdleConnections() - - quitReadCh := make(chan struct{}) - // Wait for the request before replying with a dummy response: - go func() { - defer close(quitReadCh) - - req, err := http.ReadRequest(bufio.NewReader(pr)) - if err == nil { - // Ensure all the body is read; otherwise - // we'll get a partial dump. - io.Copy(io.Discard, req.Body) - req.Body.Close() - } - select { - case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"): - case quitReadCh <- struct{}{}: - // Ensure delegate is closed so Read doesn't block forever. - close(dr.c) - } - }() - - t.RoundTrip(req) - - // Ensure the reader returns before we reset req.Body to prevent - // a data race on req.Body. - pw.Close() - <-quitReadCh - - req.Body = save -} - -// Issue 37669 -// Test that a cancellation doesn't result in a data race due to the writeLoop -// goroutine being left running, if the caller mutates the processed Request -// upon completion. -func TestErrorWriteLoopRace(t *testing.T) { - if testing.Short() { - return - } - t.Parallel() - for i := 0; i < 1000; i++ { - delay := time.Duration(mrand.Intn(5)) * time.Millisecond - ctx, cancel := context.WithTimeout(context.Background(), delay) - defer cancel() - - r := bytes.NewBuffer(make([]byte, 10000)) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://example.com", r) - if err != nil { - t.Fatal(err) - } - - testTransportRace(req) - } -} - -// Issue 41600 -// Test that a new request which uses the connection of an active request -// cannot cause it to be canceled as well. -func TestCancelRequestWhenSharingConnection(t *testing.T) { - reqc := make(chan chan struct{}, 2) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ch := make(chan struct{}, 1) - reqc <- ch - <-ch - w.Header().Add("Content-Length", "0") - })) - defer ts.Close() - - client := tc().httpClient - transport := client.Transport.(*Transport) - transport.MaxIdleConns = 1 - transport.MaxConnsPerHost = 1 - - var wg sync.WaitGroup - - wg.Add(1) - putidlec := make(chan chan struct{}) - go func() { - defer wg.Done() - ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ - PutIdleConn: func(error) { - // Signal that the idle conn has been returned to the pool, - // and wait for the order to proceed. - ch := make(chan struct{}) - putidlec <- ch - <-ch - }, - }) - req, _ := http.NewRequestWithContext(ctx, "GET", ts.URL, nil) - res, err := client.Do(req) - if err == nil { - res.Body.Close() - } - if err != nil { - t.Errorf("request 1: got err %v, want nil", err) - } - }() - - // Wait for the first request to receive a response and return the - // connection to the idle pool. - r1c := <-reqc - close(r1c) - idlec := <-putidlec - - wg.Add(1) - cancelctx, cancel := context.WithCancel(context.Background()) - go func() { - defer wg.Done() - req, _ := http.NewRequestWithContext(cancelctx, "GET", ts.URL, nil) - res, err := client.Do(req) - if err == nil { - res.Body.Close() - } - if !errors.Is(err, context.Canceled) { - t.Errorf("request 2: got err %v, want Canceled", err) - } - }() - - // Wait for the second request to arrive at the server, and then cancel - // the request context. - r2c := <-reqc - cancel() - - // Give the cancelation a moment to take effect, and then unblock the first request. - time.Sleep(1 * time.Millisecond) - close(idlec) - - close(r2c) - wg.Wait() -} From d0373d27487a0fac2a5005e258747bcccb5e44e8 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Mar 2024 18:01:11 +0800 Subject: [PATCH 795/843] merge upstream net/http: 2024-01-30(e8b5bc) --- internal/bisect/bisect.go | 794 +++++++++++++++++++++++++++++++ internal/godebug/godebug.go | 249 +++++++++- internal/godebug/godebug_test.go | 34 -- internal/godebugs/table.go | 78 +++ request_test.go | 8 +- roundtrip_js.go | 142 ++---- server.go | 18 + textproto_reader.go | 165 +++++-- transfer.go | 112 +++-- transport.go | 40 +- 10 files changed, 1404 insertions(+), 236 deletions(-) create mode 100644 internal/bisect/bisect.go delete mode 100644 internal/godebug/godebug_test.go create mode 100644 internal/godebugs/table.go create mode 100644 server.go diff --git a/internal/bisect/bisect.go b/internal/bisect/bisect.go new file mode 100644 index 00000000..3e5a6849 --- /dev/null +++ b/internal/bisect/bisect.go @@ -0,0 +1,794 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package bisect can be used by compilers and other programs +// to serve as a target for the bisect debugging tool. +// See [golang.org/x/tools/cmd/bisect] for details about using the tool. +// +// To be a bisect target, allowing bisect to help determine which of a set of independent +// changes provokes a failure, a program needs to: +// +// 1. Define a way to accept a change pattern on its command line or in its environment. +// The most common mechanism is a command-line flag. +// The pattern can be passed to [New] to create a [Matcher], the compiled form of a pattern. +// +// 2. Assign each change a unique ID. One possibility is to use a sequence number, +// but the most common mechanism is to hash some kind of identifying information +// like the file and line number where the change might be applied. +// [Hash] hashes its arguments to compute an ID. +// +// 3. Enable each change that the pattern says should be enabled. +// The [Matcher.ShouldEnable] method answers this question for a given change ID. +// +// 4. Print a report identifying each change that the pattern says should be printed. +// The [Matcher.ShouldPrint] method answers this question for a given change ID. +// The report consists of one more lines on standard error or standard output +// that contain a “match marker”. [Marker] returns the match marker for a given ID. +// When bisect reports a change as causing the failure, it identifies the change +// by printing the report lines with the match marker removed. +// +// # Example Usage +// +// A program starts by defining how it receives the pattern. In this example, we will assume a flag. +// The next step is to compile the pattern: +// +// m, err := bisect.New(patternFlag) +// if err != nil { +// log.Fatal(err) +// } +// +// Then, each time a potential change is considered, the program computes +// a change ID by hashing identifying information (source file and line, in this case) +// and then calls m.ShouldPrint and m.ShouldEnable to decide whether to +// print and enable the change, respectively. The two can return different values +// depending on whether bisect is trying to find a minimal set of changes to +// disable or to enable to provoke the failure. +// +// It is usually helpful to write a helper function that accepts the identifying information +// and then takes care of hashing, printing, and reporting whether the identified change +// should be enabled. For example, a helper for changes identified by a file and line number +// would be: +// +// func ShouldEnable(file string, line int) { +// h := bisect.Hash(file, line) +// if m.ShouldPrint(h) { +// fmt.Fprintf(os.Stderr, "%v %s:%d\n", bisect.Marker(h), file, line) +// } +// return m.ShouldEnable(h) +// } +// +// Finally, note that New returns a nil Matcher when there is no pattern, +// meaning that the target is not running under bisect at all, +// so all changes should be enabled and none should be printed. +// In that common case, the computation of the hash can be avoided entirely +// by checking for m == nil first: +// +// func ShouldEnable(file string, line int) bool { +// if m == nil { +// return true +// } +// h := bisect.Hash(file, line) +// if m.ShouldPrint(h) { +// fmt.Fprintf(os.Stderr, "%v %s:%d\n", bisect.Marker(h), file, line) +// } +// return m.ShouldEnable(h) +// } +// +// When the identifying information is expensive to format, this code can call +// [Matcher.MarkerOnly] to find out whether short report lines containing only the +// marker are permitted for a given run. (Bisect permits such lines when it is +// still exploring the space of possible changes and will not be showing the +// output to the user.) If so, the client can choose to print only the marker: +// +// func ShouldEnable(file string, line int) bool { +// if m == nil { +// return true +// } +// h := bisect.Hash(file, line) +// if m.ShouldPrint(h) { +// if m.MarkerOnly() { +// bisect.PrintMarker(os.Stderr, h) +// } else { +// fmt.Fprintf(os.Stderr, "%v %s:%d\n", bisect.Marker(h), file, line) +// } +// } +// return m.ShouldEnable(h) +// } +// +// This specific helper – deciding whether to enable a change identified by +// file and line number and printing about the change when necessary – is +// provided by the [Matcher.FileLine] method. +// +// Another common usage is deciding whether to make a change in a function +// based on the caller's stack, to identify the specific calling contexts that the +// change breaks. The [Matcher.Stack] method takes care of obtaining the stack, +// printing it when necessary, and reporting whether to enable the change +// based on that stack. +// +// # Pattern Syntax +// +// Patterns are generated by the bisect tool and interpreted by [New]. +// Users should not have to understand the patterns except when +// debugging a target's bisect support or debugging the bisect tool itself. +// +// The pattern syntax selecting a change is a sequence of bit strings +// separated by + and - operators. Each bit string denotes the set of +// changes with IDs ending in those bits, + is set addition, - is set subtraction, +// and the expression is evaluated in the usual left-to-right order. +// The special binary number “y” denotes the set of all changes, +// standing in for the empty bit string. +// In the expression, all the + operators must appear before all the - operators. +// A leading + adds to an empty set. A leading - subtracts from the set of all +// possible suffixes. +// +// For example: +// +// - “01+10” and “+01+10” both denote the set of changes +// with IDs ending with the bits 01 or 10. +// +// - “01+10-1001” denotes the set of changes with IDs +// ending with the bits 01 or 10, but excluding those ending in 1001. +// +// - “-01-1000” and “y-01-1000 both denote the set of all changes +// with IDs not ending in 01 nor 1000. +// +// - “0+1-01+001” is not a valid pattern, because all the + operators do not +// appear before all the - operators. +// +// In the syntaxes described so far, the pattern specifies the changes to +// enable and report. If a pattern is prefixed by a “!”, the meaning +// changes: the pattern specifies the changes to DISABLE and report. This +// mode of operation is needed when a program passes with all changes +// enabled but fails with no changes enabled. In this case, bisect +// searches for minimal sets of changes to disable. +// Put another way, the leading “!” inverts the result from [Matcher.ShouldEnable] +// but does not invert the result from [Matcher.ShouldPrint]. +// +// As a convenience for manual debugging, “n” is an alias for “!y”, +// meaning to disable and report all changes. +// +// Finally, a leading “v” in the pattern indicates that the reports will be shown +// to the user of bisect to describe the changes involved in a failure. +// At the API level, the leading “v” causes [Matcher.Visible] to return true. +// See the next section for details. +// +// # Match Reports +// +// The target program must enable only those changed matched +// by the pattern, and it must print a match report for each such change. +// A match report consists of one or more lines of text that will be +// printed by the bisect tool to describe a change implicated in causing +// a failure. Each line in the report for a given change must contain a +// match marker with that change ID, as returned by [Marker]. +// The markers are elided when displaying the lines to the user. +// +// A match marker has the form “[bisect-match 0x1234]” where +// 0x1234 is the change ID in hexadecimal. +// An alternate form is “[bisect-match 010101]”, giving the change ID in binary. +// +// When [Matcher.Visible] returns false, the match reports are only +// being processed by bisect to learn the set of enabled changes, +// not shown to the user, meaning that each report can be a match +// marker on a line by itself, eliding the usual textual description. +// When the textual description is expensive to compute, +// checking [Matcher.Visible] can help the avoid that expense +// in most runs. +package bisect + +import ( + "runtime" + "sync" + "sync/atomic" + "unsafe" +) + +// New creates and returns a new Matcher implementing the given pattern. +// The pattern syntax is defined in the package doc comment. +// +// In addition to the pattern syntax syntax, New("") returns nil, nil. +// The nil *Matcher is valid for use: it returns true from ShouldEnable +// and false from ShouldPrint for all changes. Callers can avoid calling +// [Hash], [Matcher.ShouldEnable], and [Matcher.ShouldPrint] entirely +// when they recognize the nil Matcher. +func New(pattern string) (*Matcher, error) { + if pattern == "" { + return nil, nil + } + + m := new(Matcher) + + p := pattern + // Special case for leading 'q' so that 'qn' quietly disables, e.g. fmahash=qn to disable fma + // Any instance of 'v' disables 'q'. + if len(p) > 0 && p[0] == 'q' { + m.quiet = true + p = p[1:] + if p == "" { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + } + // Allow multiple v, so that “bisect cmd vPATTERN” can force verbose all the time. + for len(p) > 0 && p[0] == 'v' { + m.verbose = true + m.quiet = false + p = p[1:] + if p == "" { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + } + + // Allow multiple !, each negating the last, so that “bisect cmd !PATTERN” works + // even when bisect chooses to add its own !. + m.enable = true + for len(p) > 0 && p[0] == '!' { + m.enable = !m.enable + p = p[1:] + if p == "" { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + } + + if p == "n" { + // n is an alias for !y. + m.enable = !m.enable + p = "y" + } + + // Parse actual pattern syntax. + result := true + bits := uint64(0) + start := 0 + wid := 1 // 1-bit (binary); sometimes 4-bit (hex) + for i := 0; i <= len(p); i++ { + // Imagine a trailing - at the end of the pattern to flush final suffix + c := byte('-') + if i < len(p) { + c = p[i] + } + if i == start && wid == 1 && c == 'x' { // leading x for hex + start = i + 1 + wid = 4 + continue + } + switch c { + default: + return nil, &parseError{"invalid pattern syntax: " + pattern} + case '2', '3', '4', '5', '6', '7', '8', '9': + if wid != 4 { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + fallthrough + case '0', '1': + bits <<= wid + bits |= uint64(c - '0') + case 'a', 'b', 'c', 'd', 'e', 'f', 'A', 'B', 'C', 'D', 'E', 'F': + if wid != 4 { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + bits <<= 4 + bits |= uint64(c&^0x20 - 'A' + 10) + case 'y': + if i+1 < len(p) && (p[i+1] == '0' || p[i+1] == '1') { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + bits = 0 + case '+', '-': + if c == '+' && result == false { + // Have already seen a -. Should be - from here on. + return nil, &parseError{"invalid pattern syntax (+ after -): " + pattern} + } + if i > 0 { + n := (i - start) * wid + if n > 64 { + return nil, &parseError{"pattern bits too long: " + pattern} + } + if n <= 0 { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + if p[start] == 'y' { + n = 0 + } + mask := uint64(1)<= 0; i-- { + c := &m.list[i] + if id&c.mask == c.bits { + return c.result + } + } + return false +} + +// FileLine reports whether the change identified by file and line should be enabled. +// If the change should be printed, FileLine prints a one-line report to w. +func (m *Matcher) FileLine(w Writer, file string, line int) bool { + if m == nil { + return true + } + return m.fileLine(w, file, line) +} + +// fileLine does the real work for FileLine. +// This lets FileLine's body handle m == nil and potentially be inlined. +func (m *Matcher) fileLine(w Writer, file string, line int) bool { + h := Hash(file, line) + if m.ShouldPrint(h) { + if m.MarkerOnly() { + PrintMarker(w, h) + } else { + printFileLine(w, h, file, line) + } + } + return m.ShouldEnable(h) +} + +// printFileLine prints a non-marker-only report for file:line to w. +func printFileLine(w Writer, h uint64, file string, line int) error { + const markerLen = 40 // overestimate + b := make([]byte, 0, markerLen+len(file)+24) + b = AppendMarker(b, h) + b = appendFileLine(b, file, line) + b = append(b, '\n') + _, err := w.Write(b) + return err +} + +// appendFileLine appends file:line to dst, returning the extended slice. +func appendFileLine(dst []byte, file string, line int) []byte { + dst = append(dst, file...) + dst = append(dst, ':') + u := uint(line) + if line < 0 { + dst = append(dst, '-') + u = -u + } + var buf [24]byte + i := len(buf) + for i == len(buf) || u > 0 { + i-- + buf[i] = '0' + byte(u%10) + u /= 10 + } + dst = append(dst, buf[i:]...) + return dst +} + +// MatchStack assigns the current call stack a change ID. +// If the stack should be printed, MatchStack prints it. +// Then MatchStack reports whether a change at the current call stack should be enabled. +func (m *Matcher) Stack(w Writer) bool { + if m == nil { + return true + } + return m.stack(w) +} + +// stack does the real work for Stack. +// This lets stack's body handle m == nil and potentially be inlined. +func (m *Matcher) stack(w Writer) bool { + const maxStack = 16 + var stk [maxStack]uintptr + n := runtime.Callers(2, stk[:]) + // caller #2 is not for printing; need it to normalize PCs if ASLR. + if n <= 1 { + return false + } + + base := stk[0] + // normalize PCs + for i := range stk[:n] { + stk[i] -= base + } + + h := Hash(stk[:n]) + if m.ShouldPrint(h) { + var d *dedup + for { + d = m.dedup.Load() + if d != nil { + break + } + d = new(dedup) + if m.dedup.CompareAndSwap(nil, d) { + break + } + } + + if m.MarkerOnly() { + if !d.seenLossy(h) { + PrintMarker(w, h) + } + } else { + if !d.seen(h) { + // Restore PCs in stack for printing + for i := range stk[:n] { + stk[i] += base + } + printStack(w, h, stk[1:n]) + } + } + } + return m.ShouldEnable(h) +} + +// Writer is the same interface as io.Writer. +// It is duplicated here to avoid importing io. +type Writer interface { + Write([]byte) (int, error) +} + +// PrintMarker prints to w a one-line report containing only the marker for h. +// It is appropriate to use when [Matcher.ShouldPrint] and [Matcher.MarkerOnly] both return true. +func PrintMarker(w Writer, h uint64) error { + var buf [50]byte + b := AppendMarker(buf[:0], h) + b = append(b, '\n') + _, err := w.Write(b) + return err +} + +// printStack prints to w a multi-line report containing a formatting of the call stack stk, +// with each line preceded by the marker for h. +func printStack(w Writer, h uint64, stk []uintptr) error { + buf := make([]byte, 0, 2048) + + var prefixBuf [100]byte + prefix := AppendMarker(prefixBuf[:0], h) + + frames := runtime.CallersFrames(stk) + for { + f, more := frames.Next() + buf = append(buf, prefix...) + buf = append(buf, f.Func.Name()...) + buf = append(buf, "()\n"...) + buf = append(buf, prefix...) + buf = append(buf, '\t') + buf = appendFileLine(buf, f.File, f.Line) + buf = append(buf, '\n') + if !more { + break + } + } + buf = append(buf, prefix...) + buf = append(buf, '\n') + _, err := w.Write(buf) + return err +} + +// Marker returns the match marker text to use on any line reporting details +// about a match of the given ID. +// It always returns the hexadecimal format. +func Marker(id uint64) string { + return string(AppendMarker(nil, id)) +} + +// AppendMarker is like [Marker] but appends the marker to dst. +func AppendMarker(dst []byte, id uint64) []byte { + const prefix = "[bisect-match 0x" + var buf [len(prefix) + 16 + 1]byte + copy(buf[:], prefix) + for i := 0; i < 16; i++ { + buf[len(prefix)+i] = "0123456789abcdef"[id>>60] + id <<= 4 + } + buf[len(prefix)+16] = ']' + return append(dst, buf[:]...) +} + +// CutMarker finds the first match marker in line and removes it, +// returning the shortened line (with the marker removed), +// the ID from the match marker, +// and whether a marker was found at all. +// If there is no marker, CutMarker returns line, 0, false. +func CutMarker(line string) (short string, id uint64, ok bool) { + // Find first instance of prefix. + prefix := "[bisect-match " + i := 0 + for ; ; i++ { + if i >= len(line)-len(prefix) { + return line, 0, false + } + if line[i] == '[' && line[i:i+len(prefix)] == prefix { + break + } + } + + // Scan to ]. + j := i + len(prefix) + for j < len(line) && line[j] != ']' { + j++ + } + if j >= len(line) { + return line, 0, false + } + + // Parse id. + idstr := line[i+len(prefix) : j] + if len(idstr) >= 3 && idstr[:2] == "0x" { + // parse hex + if len(idstr) > 2+16 { // max 0x + 16 digits + return line, 0, false + } + for i := 2; i < len(idstr); i++ { + id <<= 4 + switch c := idstr[i]; { + case '0' <= c && c <= '9': + id |= uint64(c - '0') + case 'a' <= c && c <= 'f': + id |= uint64(c - 'a' + 10) + case 'A' <= c && c <= 'F': + id |= uint64(c - 'A' + 10) + } + } + } else { + if idstr == "" || len(idstr) > 64 { // min 1 digit, max 64 digits + return line, 0, false + } + // parse binary + for i := 0; i < len(idstr); i++ { + id <<= 1 + switch c := idstr[i]; c { + default: + return line, 0, false + case '0', '1': + id |= uint64(c - '0') + } + } + } + + // Construct shortened line. + // Remove at most one space from around the marker, + // so that "foo [marker] bar" shortens to "foo bar". + j++ // skip ] + if i > 0 && line[i-1] == ' ' { + i-- + } else if j < len(line) && line[j] == ' ' { + j++ + } + short = line[:i] + line[j:] + return short, id, true +} + +// Hash computes a hash of the data arguments, +// each of which must be of type string, byte, int, uint, int32, uint32, int64, uint64, uintptr, or a slice of one of those types. +func Hash(data ...any) uint64 { + h := offset64 + for _, v := range data { + switch v := v.(type) { + default: + // Note: Not printing the type, because reflect.ValueOf(v) + // would make the interfaces prepared by the caller escape + // and therefore allocate. This way, Hash(file, line) runs + // without any allocation. It should be clear from the + // source code calling Hash what the bad argument was. + panic("bisect.Hash: unexpected argument type") + case string: + h = fnvString(h, v) + case byte: + h = fnv(h, v) + case int: + h = fnvUint64(h, uint64(v)) + case uint: + h = fnvUint64(h, uint64(v)) + case int32: + h = fnvUint32(h, uint32(v)) + case uint32: + h = fnvUint32(h, v) + case int64: + h = fnvUint64(h, uint64(v)) + case uint64: + h = fnvUint64(h, v) + case uintptr: + h = fnvUint64(h, uint64(v)) + case []string: + for _, x := range v { + h = fnvString(h, x) + } + case []byte: + for _, x := range v { + h = fnv(h, x) + } + case []int: + for _, x := range v { + h = fnvUint64(h, uint64(x)) + } + case []uint: + for _, x := range v { + h = fnvUint64(h, uint64(x)) + } + case []int32: + for _, x := range v { + h = fnvUint32(h, uint32(x)) + } + case []uint32: + for _, x := range v { + h = fnvUint32(h, x) + } + case []int64: + for _, x := range v { + h = fnvUint64(h, uint64(x)) + } + case []uint64: + for _, x := range v { + h = fnvUint64(h, x) + } + case []uintptr: + for _, x := range v { + h = fnvUint64(h, uint64(x)) + } + } + } + return h +} + +// Trivial error implementation, here to avoid importing errors. + +// parseError is a trivial error implementation, +// defined here to avoid importing errors. +type parseError struct{ text string } + +func (e *parseError) Error() string { return e.text } + +// FNV-1a implementation. See Go's hash/fnv/fnv.go. +// Copied here for simplicity (can handle integers more directly) +// and to avoid importing hash/fnv. + +const ( + offset64 uint64 = 14695981039346656037 + prime64 uint64 = 1099511628211 +) + +func fnv(h uint64, x byte) uint64 { + h ^= uint64(x) + h *= prime64 + return h +} + +func fnvString(h uint64, x string) uint64 { + for i := 0; i < len(x); i++ { + h ^= uint64(x[i]) + h *= prime64 + } + return h +} + +func fnvUint64(h uint64, x uint64) uint64 { + for i := 0; i < 8; i++ { + h ^= x & 0xFF + x >>= 8 + h *= prime64 + } + return h +} + +func fnvUint32(h uint64, x uint32) uint64 { + for i := 0; i < 4; i++ { + h ^= uint64(x & 0xFF) + x >>= 8 + h *= prime64 + } + return h +} + +// A dedup is a deduplicator for call stacks, so that we only print +// a report for new call stacks, not for call stacks we've already +// reported. +// +// It has two modes: an approximate but lock-free mode that +// may still emit some duplicates, and a precise mode that uses +// a lock and never emits duplicates. +type dedup struct { + // 128-entry 4-way, lossy cache for seenLossy + recent [128][4]uint64 + + // complete history for seen + mu sync.Mutex + m map[uint64]bool +} + +// seen records that h has now been seen and reports whether it was seen before. +// When seen returns false, the caller is expected to print a report for h. +func (d *dedup) seen(h uint64) bool { + d.mu.Lock() + if d.m == nil { + d.m = make(map[uint64]bool) + } + seen := d.m[h] + d.m[h] = true + d.mu.Unlock() + return seen +} + +// seenLossy is a variant of seen that avoids a lock by using a cache of recently seen hashes. +// Each cache entry is N-way set-associative: h can appear in any of the slots. +// If h does not appear in any of them, then it is inserted into a random slot, +// overwriting whatever was there before. +func (d *dedup) seenLossy(h uint64) bool { + cache := &d.recent[uint(h)%uint(len(d.recent))] + for i := 0; i < len(cache); i++ { + if atomic.LoadUint64(&cache[i]) == h { + return true + } + } + + // Compute index in set to evict as hash of current set. + ch := offset64 + for _, x := range cache { + ch = fnvUint64(ch, x) + } + atomic.StoreUint64(&cache[uint(ch)%uint(len(cache))], h) + return false +} diff --git a/internal/godebug/godebug.go b/internal/godebug/godebug.go index ac434e5f..f5a0f53a 100644 --- a/internal/godebug/godebug.go +++ b/internal/godebug/godebug.go @@ -2,33 +2,244 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package godebug parses the GODEBUG environment variable. +// Package godebug makes the settings in the $GODEBUG environment variable +// available to other packages. These settings are often used for compatibility +// tweaks, when we need to change a default behavior but want to let users +// opt back in to the original. For example GODEBUG=http2server=0 disables +// HTTP/2 support in the net/http server. +// +// In typical usage, code should declare a Setting as a global +// and then call Value each time the current setting value is needed: +// +// var http2server = godebug.New("http2server") +// +// func ServeConn(c net.Conn) { +// if http2server.Value() == "0" { +// disallow HTTP/2 +// ... +// } +// ... +// } +// +// Each time a non-default setting causes a change in program behavior, +// code should call [Setting.IncNonDefault] to increment a counter that can +// be reported by [runtime/metrics.Read]. +// Note that counters used with IncNonDefault must be added to +// various tables in other packages. See the [Setting.IncNonDefault] +// documentation for details. package godebug -import "os" +// Note: Be careful about new imports here. Any package +// that internal/godebug imports cannot itself import internal/godebug, +// meaning it cannot introduce a GODEBUG setting of its own. +// We keep imports to the absolute bare minimum. +import ( + "sync" + "sync/atomic" + _ "unsafe" // go:linkname -// Get returns the value for the provided GODEBUG key. -func Get(key string) string { - return get(os.Getenv("GODEBUG"), key) + "github.com/imroc/req/v3/internal/bisect" + "github.com/imroc/req/v3/internal/godebugs" +) + +// A Setting is a single setting in the $GODEBUG environment variable. +type Setting struct { + name string + once sync.Once + *setting +} + +type setting struct { + value atomic.Pointer[value] + nonDefaultOnce sync.Once + nonDefault atomic.Uint64 + info *godebugs.Info +} + +type value struct { + text string + bisect *bisect.Matcher +} + +// New returns a new Setting for the $GODEBUG setting with the given name. +// +// GODEBUGs meant for use by end users must be listed in ../godebugs/table.go, +// which is used for generating and checking various documentation. +// If the name is not listed in that table, New will succeed but calling Value +// on the returned Setting will panic. +// To disable that panic for access to an undocumented setting, +// prefix the name with a #, as in godebug.New("#gofsystrace"). +// The # is a signal to New but not part of the key used in $GODEBUG. +func New(name string) *Setting { + return &Setting{name: name} +} + +// Name returns the name of the setting. +func (s *Setting) Name() string { + if s.name != "" && s.name[0] == '#' { + return s.name[1:] + } + return s.name } -// get returns the value part of key=value in s (a GODEBUG value). -func get(s, key string) string { - for i := 0; i < len(s)-len(key)-1; i++ { - if i > 0 && s[i-1] != ',' { - continue +// Undocumented reports whether this is an undocumented setting. +func (s *Setting) Undocumented() bool { + return s.name != "" && s.name[0] == '#' +} + +// String returns a printable form for the setting: name=value. +func (s *Setting) String() string { + return s.Name() + "=" + s.Value() +} + +// IncNonDefault increments the non-default behavior counter +// associated with the given setting. +// This counter is exposed in the runtime/metrics value +// /godebug/non-default-behavior/:events. +// +// Note that Value must be called at least once before IncNonDefault. +func (s *Setting) IncNonDefault() { + s.nonDefaultOnce.Do(s.register) + s.nonDefault.Add(1) +} + +func (s *Setting) register() { + if s.info == nil || s.info.Opaque { + panic("godebug: unexpected IncNonDefault of " + s.name) + } +} + +// cache is a cache of all the GODEBUG settings, +// a locked map[string]*atomic.Pointer[string]. +// +// All Settings with the same name share a single +// *atomic.Pointer[string], so that when GODEBUG +// changes only that single atomic string pointer +// needs to be updated. +// +// A name appears in the values map either if it is the +// name of a Setting for which Value has been called +// at least once, or if the name has ever appeared in +// a name=value pair in the $GODEBUG environment variable. +// Once entered into the map, the name is never removed. +var cache sync.Map // name string -> value *atomic.Pointer[string] + +var empty value + +// Value returns the current value for the GODEBUG setting s. +// +// Value maintains an internal cache that is synchronized +// with changes to the $GODEBUG environment variable, +// making Value efficient to call as frequently as needed. +// Clients should therefore typically not attempt their own +// caching of Value's result. +func (s *Setting) Value() string { + s.once.Do(func() { + s.setting = lookup(s.Name()) + if s.info == nil && !s.Undocumented() { + panic("godebug: Value of name not listed in godebugs.All: " + s.name) } - afterKey := s[i+len(key):] - if afterKey[0] != '=' || s[i:i+len(key)] != key { - continue + }) + v := *s.value.Load() + if v.bisect != nil && !v.bisect.Stack(&stderr) { + return "" + } + return v.text +} + +// lookup returns the unique *setting value for the given name. +func lookup(name string) *setting { + if v, ok := cache.Load(name); ok { + return v.(*setting) + } + s := new(setting) + s.info = godebugs.Lookup(name) + s.value.Store(&empty) + if v, loaded := cache.LoadOrStore(name, s); loaded { + // Lost race: someone else created it. Use theirs. + return v.(*setting) + } + + return s +} + +func newIncNonDefault(name string) func() { + s := New(name) + s.Value() + return s.IncNonDefault +} + +var updateMu sync.Mutex + +// update records an updated GODEBUG setting. +// def is the default GODEBUG setting for the running binary, +// and env is the current value of the $GODEBUG environment variable. +func update(def, env string) { + updateMu.Lock() + defer updateMu.Unlock() + + // Update all the cached values, creating new ones as needed. + // We parse the environment variable first, so that any settings it has + // are already locked in place (did[name] = true) before we consider + // the defaults. + did := make(map[string]bool) + parse(did, env) + parse(did, def) + + // Clear any cached values that are no longer present. + cache.Range(func(name, s any) bool { + if !did[name.(string)] { + s.(*setting).value.Store(&empty) } - val := afterKey[1:] - for i, b := range val { - if b == ',' { - return val[:i] + return true + }) +} + +// parse parses the GODEBUG setting string s, +// which has the form k=v,k2=v2,k3=v3. +// Later settings override earlier ones. +// Parse only updates settings k=v for which did[k] = false. +// It also sets did[k] = true for settings that it updates. +// Each value v can also have the form v#pattern, +// in which case the GODEBUG is only enabled for call stacks +// matching pattern, for use with golang.org/x/tools/cmd/bisect. +func parse(did map[string]bool, s string) { + // Scan the string backward so that later settings are used + // and earlier settings are ignored. + // Note that a forward scan would cause cached values + // to temporarily use the ignored value before being + // updated to the "correct" one. + end := len(s) + eq := -1 + for i := end - 1; i >= -1; i-- { + if i == -1 || s[i] == ',' { + if eq >= 0 { + name, arg := s[i+1:eq], s[eq+1:end] + if !did[name] { + did[name] = true + v := &value{text: arg} + for j := 0; j < len(arg); j++ { + if arg[j] == '#' { + v.text = arg[:j] + v.bisect, _ = bisect.New(arg[j+1:]) + break + } + } + lookup(name).value.Store(v) + } } + eq = -1 + end = i + } else if s[i] == '=' { + eq = i } - return val } - return "" +} + +type runtimeStderr struct{} + +var stderr runtimeStderr + +func (*runtimeStderr) Write(b []byte) (int, error) { + return len(b), nil } diff --git a/internal/godebug/godebug_test.go b/internal/godebug/godebug_test.go deleted file mode 100644 index 41b9117b..00000000 --- a/internal/godebug/godebug_test.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package godebug - -import "testing" - -func TestGet(t *testing.T) { - tests := []struct { - godebug string - key string - want string - }{ - {"", "", ""}, - {"", "foo", ""}, - {"foo=bar", "foo", "bar"}, - {"foo=bar,after=x", "foo", "bar"}, - {"before=x,foo=bar,after=x", "foo", "bar"}, - {"before=x,foo=bar", "foo", "bar"}, - {",,,foo=bar,,,", "foo", "bar"}, - {"foodecoy=wrong,foo=bar", "foo", "bar"}, - {"foo=", "foo", ""}, - {"foo", "foo", ""}, - {",foo", "foo", ""}, - {"foo=bar,baz", "loooooooong", ""}, - } - for _, tt := range tests { - got := get(tt.godebug, tt.key) - if got != tt.want { - t.Errorf("get(%q, %q) = %q; want %q", tt.godebug, tt.key, got, tt.want) - } - } -} diff --git a/internal/godebugs/table.go b/internal/godebugs/table.go new file mode 100644 index 00000000..d5ac707a --- /dev/null +++ b/internal/godebugs/table.go @@ -0,0 +1,78 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package godebugs provides a table of known GODEBUG settings, +// for use by a variety of other packages, including internal/godebug, +// runtime, runtime/metrics, and cmd/go/internal/load. +package godebugs + +// An Info describes a single known GODEBUG setting. +type Info struct { + Name string // name of the setting ("panicnil") + Package string // package that uses the setting ("runtime") + Changed int // minor version when default changed, if any; 21 means Go 1.21 + Old string // value that restores behavior prior to Changed + Opaque bool // setting does not export information to runtime/metrics using [internal/godebug.Setting.IncNonDefault] +} + +// All is the table of known settings, sorted by Name. +// +// Note: After adding entries to this table, run 'go generate runtime/metrics' +// to update the runtime/metrics doc comment. +// (Otherwise the runtime/metrics test will fail.) +// +// Note: After adding entries to this table, update the list in doc/godebug.md as well. +// (Otherwise the test in this package will fail.) +var All = []Info{ + {Name: "execerrdot", Package: "os/exec"}, + {Name: "gocachehash", Package: "cmd/go"}, + {Name: "gocachetest", Package: "cmd/go"}, + {Name: "gocacheverify", Package: "cmd/go"}, + {Name: "gotypesalias", Package: "go/types"}, + {Name: "http2client", Package: "net/http"}, + {Name: "http2debug", Package: "net/http", Opaque: true}, + {Name: "http2server", Package: "net/http"}, + {Name: "httplaxcontentlength", Package: "net/http", Changed: 22, Old: "1"}, + {Name: "httpmuxgo121", Package: "net/http", Changed: 22, Old: "1"}, + {Name: "installgoroot", Package: "go/build"}, + {Name: "jstmpllitinterp", Package: "html/template"}, + //{Name: "multipartfiles", Package: "mime/multipart"}, + {Name: "multipartmaxheaders", Package: "mime/multipart"}, + {Name: "multipartmaxparts", Package: "mime/multipart"}, + {Name: "multipathtcp", Package: "net"}, + {Name: "netdns", Package: "net", Opaque: true}, + {Name: "panicnil", Package: "runtime", Changed: 21, Old: "1"}, + {Name: "randautoseed", Package: "math/rand"}, + {Name: "tarinsecurepath", Package: "archive/tar"}, + {Name: "tls10server", Package: "crypto/tls", Changed: 22, Old: "1"}, + {Name: "tlsmaxrsasize", Package: "crypto/tls"}, + {Name: "tlsrsakex", Package: "crypto/tls", Changed: 22, Old: "1"}, + {Name: "tlsunsafeekm", Package: "crypto/tls", Changed: 22, Old: "1"}, + {Name: "winreadlinkvolume", Package: "os", Changed: 22, Old: "0"}, + {Name: "winsymlink", Package: "os", Changed: 22, Old: "0"}, + {Name: "x509sha1", Package: "crypto/x509"}, + {Name: "x509usefallbackroots", Package: "crypto/x509"}, + {Name: "x509usepolicies", Package: "crypto/x509"}, + {Name: "zipinsecurepath", Package: "archive/zip"}, +} + +// Lookup returns the Info with the given name. +func Lookup(name string) *Info { + // binary search, avoiding import of sort. + lo := 0 + hi := len(All) + for lo < hi { + m := int(uint(lo+hi) >> 1) + mid := All[m].Name + if name == mid { + return &All[m] + } + if name < mid { + hi = m + } else { + lo = m + 1 + } + } + return nil +} diff --git a/request_test.go b/request_test.go index ac30f85b..a326e1f2 100644 --- a/request_test.go +++ b/request_test.go @@ -5,8 +5,6 @@ import ( "encoding/json" "encoding/xml" "fmt" - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/tests" "io" "net/http" "net/url" @@ -15,6 +13,9 @@ import ( "strings" "testing" "time" + + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/tests" ) func TestMustSendMethods(t *testing.T) { @@ -635,7 +636,6 @@ func testQueryParam(t *testing.T, c *Client) { Get("/query-parameter") assertSuccess(t, resp, err) tests.AssertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5&key6=value6&key6=value66", resp.String()) - } func TestPathParam(t *testing.T) { @@ -963,7 +963,7 @@ func (r *SlowReader) Read(p []byte) (int, error) { func TestUploadCallback(t *testing.T) { r := tc().R() - file := "transport_test.go" + file := "transport.go" fileInfo, err := os.Stat(file) if err != nil { t.Fatal(err) diff --git a/roundtrip_js.go b/roundtrip_js.go index af51f13e..9c6b6c4a 100644 --- a/roundtrip_js.go +++ b/roundtrip_js.go @@ -12,7 +12,10 @@ import ( "io" "net/http" "strconv" + "strings" "syscall/js" + + "github.com/imroc/req/v3/internal/ascii" ) var uint8Array = js.Global().Get("Uint8Array") @@ -45,45 +48,17 @@ const jsFetchRedirect = "js.fetch:redirect" // the browser globals. var jsFetchMissing = js.Global().Get("fetch").IsUndefined() -// jsFetchDisabled will be true if the "process" global is present. -// We use this as an indicator that we're running in Node.js. We -// want to disable the Fetch API in Node.js because it breaks -// our wasm tests. See https://go.dev/issue/57613 for more information. -var jsFetchDisabled = !js.Global().Get("process").IsUndefined() - -// Determine whether the JS runtime supports streaming request bodies. -// Courtesy: https://developer.chrome.com/articles/fetch-streaming-requests/#feature-detection -func supportsPostRequestStreams() bool { - requestOpt := js.Global().Get("Object").New() - requestBody := js.Global().Get("ReadableStream").New() - - requestOpt.Set("method", "POST") - requestOpt.Set("body", requestBody) - - // There is quite a dance required to define a getter if you do not have the { get property() { ... } } - // syntax available. However, it is possible: - // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Functions/get#defining_a_getter_on_existing_objects_using_defineproperty - duplexCalled := false - duplexGetterObj := js.Global().Get("Object").New() - duplexGetterFunc := js.FuncOf(func(this js.Value, args []js.Value) any { - duplexCalled = true - return "half" - }) - defer duplexGetterFunc.Release() - duplexGetterObj.Set("get", duplexGetterFunc) - js.Global().Get("Object").Call("defineProperty", requestOpt, "duplex", duplexGetterObj) - - // Slight difference here between the aforementioned example: Non-browser-based runtimes - // do not have a non-empty API Base URL (https://html.spec.whatwg.org/multipage/webappapis.html#api-base-url) - // so we have to supply a valid URL here. - requestObject := js.Global().Get("Request").New("https://www.example.org", requestOpt) - - hasContentTypeHeader := requestObject.Get("headers").Call("has", "Content-Type").Bool() - - return duplexCalled && !hasContentTypeHeader -} +// jsFetchDisabled controls whether the use of Fetch API is disabled. +// It's set to true when we detect we're running in Node.js, so that +// RoundTrip ends up talking over the same fake network the HTTP servers +// currently use in various tests and examples. See go.dev/issue/57613. +// +// TODO(go.dev/issue/60810): See if it's viable to test the Fetch API +// code path. +var jsFetchDisabled = js.Global().Get("process").Type() == js.TypeObject && + strings.HasPrefix(js.Global().Get("process").Get("argv0").String(), "node") -// RoundTrip implements the RoundTripper interface using the WHATWG Fetch API. +// RoundTrip implements the [RoundTripper] interface using the WHATWG Fetch API. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { // The Transport has a documented contract that states that if the DialContext or // DialTLSContext functions are set, they will be used to set up the connections. @@ -131,60 +106,25 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { } opt.Set("headers", headers) - var readableStreamStart, readableStreamPull, readableStreamCancel js.Func if req.Body != nil { - if !supportsPostRequestStreams() { - body, err := io.ReadAll(req.Body) - if err != nil { - req.Body.Close() // RoundTrip must always close the body, including on errors. - return nil, err - } - if len(body) != 0 { - buf := uint8Array.New(len(body)) - js.CopyBytesToJS(buf, body) - opt.Set("body", buf) - } - } else { - readableStreamCtorArg := js.Global().Get("Object").New() - readableStreamCtorArg.Set("type", "bytes") - readableStreamCtorArg.Set("autoAllocateChunkSize", t.writeBufferSize()) - - readableStreamPull = js.FuncOf(func(this js.Value, args []js.Value) any { - controller := args[0] - byobRequest := controller.Get("byobRequest") - if byobRequest.IsNull() { - controller.Call("close") - } - - byobRequestView := byobRequest.Get("view") - - bodyBuf := make([]byte, byobRequestView.Get("byteLength").Int()) - readBytes, readErr := io.ReadFull(req.Body, bodyBuf) - if readBytes > 0 { - buf := uint8Array.New(byobRequestView.Get("buffer")) - js.CopyBytesToJS(buf, bodyBuf) - byobRequest.Call("respond", readBytes) - } - - if readErr == io.EOF || readErr == io.ErrUnexpectedEOF { - controller.Call("close") - } else if readErr != nil { - readErrCauseObject := js.Global().Get("Object").New() - readErrCauseObject.Set("cause", readErr.Error()) - readErr := js.Global().Get("Error").New("io.ReadFull failed while streaming POST body", readErrCauseObject) - controller.Call("error", readErr) - } - // Note: This a return from the pull callback of the controller and *not* RoundTrip(). - return nil - }) - readableStreamCtorArg.Set("pull", readableStreamPull) - - opt.Set("body", js.Global().Get("ReadableStream").New(readableStreamCtorArg)) - // There is a requirement from the WHATWG fetch standard that the duplex property of - // the object given as the options argument to the fetch call be set to 'half' - // when the body property of the same options object is a ReadableStream: - // https://fetch.spec.whatwg.org/#dom-requestinit-duplex - opt.Set("duplex", "half") + // TODO(johanbrandhorst): Stream request body when possible. + // See https://bugs.chromium.org/p/chromium/issues/detail?id=688906 for Blink issue. + // See https://bugzilla.mozilla.org/show_bug.cgi?id=1387483 for Firefox issue. + // See https://github.com/web-platform-tests/wpt/issues/7693 for WHATWG tests issue. + // See https://developer.mozilla.org/en-US/docs/Web/API/Streams_API for more details on the Streams API + // and browser support. + // NOTE(haruyama480): Ensure HTTP/1 fallback exists. + // See https://go.dev/issue/61889 for discussion. + body, err := io.ReadAll(req.Body) + if err != nil { + req.Body.Close() // RoundTrip must always close the body, including on errors. + return nil, err + } + req.Body.Close() + if len(body) != 0 { + buf := uint8Array.New(len(body)) + js.CopyBytesToJS(buf, body) + opt.Set("body", buf) } } @@ -197,11 +137,6 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { success = js.FuncOf(func(this js.Value, args []js.Value) any { success.Release() failure.Release() - readableStreamCancel.Release() - readableStreamPull.Release() - readableStreamStart.Release() - - req.Body.Close() result := args[0] header := http.Header{} @@ -252,11 +187,21 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { } code := result.Get("status").Int() + + uncompressed := false + if ascii.EqualFold(header.Get("Content-Encoding"), "gzip") { + // The fetch api will decode the gzip, but Content-Encoding not be deleted. + header.Del("Content-Encoding") + header.Del("Content-Length") + contentLength = -1 + uncompressed = true + } respCh <- &http.Response{ Status: fmt.Sprintf("%d %s", code, http.StatusText(code)), StatusCode: code, Header: header, ContentLength: contentLength, + Uncompressed: uncompressed, Body: body, Request: req, } @@ -266,11 +211,6 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { failure = js.FuncOf(func(this js.Value, args []js.Value) any { success.Release() failure.Release() - readableStreamCancel.Release() - readableStreamPull.Release() - readableStreamStart.Release() - - req.Body.Close() err := args[0] // The error is a JS Error type diff --git a/server.go b/server.go new file mode 100644 index 00000000..8cd25f11 --- /dev/null +++ b/server.go @@ -0,0 +1,18 @@ +package req + +import "sync" + +const copyBufPoolSize = 32 * 1024 + +var copyBufPool = sync.Pool{New: func() any { return new([copyBufPoolSize]byte) }} + +func getCopyBuf() []byte { + return copyBufPool.Get().(*[copyBufPoolSize]byte)[:] +} + +func putCopyBuf(b []byte) { + if len(b) != copyBufPoolSize { + panic("trying to put back buffer of the wrong size in the copyBufPool") + } + copyBufPool.Put((*[copyBufPoolSize]byte)(b)) +} diff --git a/textproto_reader.go b/textproto_reader.go index 46103c07..1c09872a 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -7,12 +7,13 @@ package req import ( "bufio" "bytes" + "errors" "fmt" - "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/util" + "math" "net/textproto" - "strings" "sync" + + "github.com/imroc/req/v3/internal/dump" ) func isASCIILetter(b byte) bool { @@ -20,6 +21,10 @@ func isASCIILetter(b byte) bool { return 'a' <= b && b <= 'z' } +// TODO: This should be a distinguishable error (ErrMessageTooLarge) +// to allow mime/multipart to detect it. +var errMessageTooLarge = errors.New("message too large") + // A textprotoReader implements convenience methods for reading requests // or responses from a text protocol network connection. type textprotoReader struct { @@ -67,11 +72,14 @@ func newTextprotoReader(r *bufio.Reader, ds dump.Dumpers) *textprotoReader { // ReadLine reads a single line from r, // eliding the final \n or \r\n from the returned string. func (r *textprotoReader) ReadLine() (string, error) { - line, err := r.readLineSlice() + line, err := r.readLineSlice(-1) return string(line), err } -func (r *textprotoReader) readLineSlice() ([]byte, error) { +// readLineSlice reads a single line from r, +// up to lim bytes long (or unlimited if lim is less than 0), +// eliding the final \r or \r\n from the returned string. +func (r *textprotoReader) readLineSlice(lim int64) ([]byte, error) { var line []byte for { @@ -79,6 +87,9 @@ func (r *textprotoReader) readLineSlice() ([]byte, error) { if err != nil { return nil, err } + if lim >= 0 && int64(len(line))+int64(len(l)) > lim { + return nil, errMessageTooLarge + } // Avoid the copy if the first call produced a full line. if line == nil && !more { return l, nil @@ -109,13 +120,14 @@ func trim(s []byte) []byte { // returning a byte slice with all lines. The validateFirstLine function // is run on the first read line, and if it returns an error then this // error is returned from readContinuedLineSlice. -func (r *textprotoReader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) { +// It reads up to lim bytes of data (or unlimited if lim is less than 0). +func (r *textprotoReader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) { if validateFirstLine == nil { return nil, fmt.Errorf("missing validateFirstLine func") } // Read the first line. - line, err := r.readLineSlice() + line, err := r.readLineSlice(lim) if err != nil { return nil, err } @@ -143,13 +155,21 @@ func (r *textprotoReader) readContinuedLineSlice(validateFirstLine func([]byte) // copy the slice into buf. r.buf = append(r.buf[:0], trim(line)...) + if lim < 0 { + lim = math.MaxInt64 + } + lim -= int64(len(r.buf)) + // Read continuation lines. for r.skipSpace() > 0 { - line, err := r.readLineSlice() + r.buf = append(r.buf, ' ') + if int64(len(r.buf)) >= lim { + return nil, errMessageTooLarge + } + line, err := r.readLineSlice(lim - int64(len(r.buf))) if err != nil { break } - r.buf = append(r.buf, ' ') r.buf = append(r.buf, trim(line)...) } return r.buf, nil @@ -186,7 +206,7 @@ var colon = []byte(":") // ReadMIMEHeader reads a MIME-style header from r. // The header is a sequence of possibly continued Key: Value lines // ending in a blank line. -// The returned map m maps CanonicalMIMEHeaderKey(key) to a +// The returned map m maps [CanonicalMIMEHeaderKey](key) to a // sequence of values in the same order encountered in the input. // // For example, consider this input: @@ -203,20 +223,36 @@ var colon = []byte(":") // "Long-Key": {"Even Longer Value"}, // } func (r *textprotoReader) ReadMIMEHeader() (textproto.MIMEHeader, error) { + return r.readMIMEHeader(math.MaxInt64, math.MaxInt64) +} + +// readMIMEHeader is a version of ReadMIMEHeader which takes a limit on the header size. +// It is called by the mime/multipart package. +func (r *textprotoReader) readMIMEHeader(maxMemory, maxHeaders int64) (textproto.MIMEHeader, error) { // Avoid lots of small slice allocations later by allocating one // large one ahead of time which we'll cut up into smaller // slices. If this isn't big enough later, we allocate small ones. var strs []string - hint := r.upcomingHeaderNewlines() + hint := r.upcomingHeaderKeys() if hint > 0 { + if hint > 1000 { + hint = 1000 // set a cap to avoid overallocation + } strs = make([]string, hint) } m := make(textproto.MIMEHeader, hint) + // Account for 400 bytes of overhead for the MIMEHeader, plus 200 bytes per entry. + // Benchmarking map creation as of go1.20, a one-entry MIMEHeader is 416 bytes and large + // MIMEHeaders average about 200 bytes per entry. + maxMemory -= 400 + const mapEntryOverhead = 200 + // The first line cannot start with a leading space. if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') { - line, err := r.readLineSlice() + const errorLimit = 80 // arbitrary limit on how much of the line we'll quote + line, err := r.readLineSlice(errorLimit) if err != nil { return m, err } @@ -224,29 +260,43 @@ func (r *textprotoReader) ReadMIMEHeader() (textproto.MIMEHeader, error) { } for { - kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon) + kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon) if len(kv) == 0 { return m, err } // Key ends at first colon. - k, v, ok := util.CutBytes(kv, colon) + k, v, ok := bytes.Cut(kv, colon) + if !ok { + return m, protocolError("malformed MIME header line: " + string(kv)) + } + key, ok := canonicalMIMEHeaderKey(k) if !ok { return m, protocolError("malformed MIME header line: " + string(kv)) } - key := canonicalMIMEHeaderKey(k) + for _, c := range v { + if !validHeaderValueByte(c) { + return m, protocolError("malformed MIME header line: " + string(kv)) + } + } - // As per RFC 7230 field-name is a token, tokens consist of one or more chars. - // We could return a protocolError here, but better to be liberal in what we - // accept, so if we get an empty key, skip it. - if key == "" { - continue + maxHeaders-- + if maxHeaders < 0 { + return nil, errMessageTooLarge } // Skip initial spaces in value. - value := strings.TrimLeft(string(v), " \t") + value := string(bytes.TrimLeft(v, " \t")) vv := m[key] + if vv == nil { + maxMemory -= int64(len(key)) + maxMemory -= mapEntryOverhead + } + maxMemory -= int64(len(value)) + if maxMemory < 0 { + return m, errMessageTooLarge + } if vv == nil && len(strs) > 0 { // More than likely this will be a single-element key. // Most headers aren't multi-valued. @@ -277,9 +327,9 @@ func mustHaveFieldNameColon(line []byte) error { var nl = []byte("\n") -// upcomingHeaderNewlines returns an approximation of the number of newlines +// upcomingHeaderKeys returns an approximation of the number of keys // that will be in this header. If it gets confused, it returns 0. -func (r *textprotoReader) upcomingHeaderNewlines() (n int) { +func (r *textprotoReader) upcomingHeaderKeys() (n int) { // Try to determine the 'hint' size. r.R.Peek(1) // force a buffer load if empty s := r.R.Buffered() @@ -287,7 +337,20 @@ func (r *textprotoReader) upcomingHeaderNewlines() (n int) { return } peek, _ := r.R.Peek(s) - return bytes.Count(peek, nl) + for len(peek) > 0 && n < 1000 { + var line []byte + line, peek, _ = bytes.Cut(peek, nl) + if len(line) == 0 || (len(line) == 1 && line[0] == '\r') { + // Blank line separating headers from the body. + break + } + if line[0] == ' ' || line[0] == '\t' { + // Folded continuation of the previous line. + continue + } + n++ + } + return n } const toLower = 'a' - 'A' @@ -310,14 +373,33 @@ func validHeaderFieldByte(b byte) bool { // // For invalid inputs (if a contains spaces or non-token bytes), a // is unchanged and a string copy is returned. -func canonicalMIMEHeaderKey(a []byte) string { +// +// ok is true if the header key contains only valid characters and spaces. +// ReadMIMEHeader accepts header keys containing spaces, but does not +// canonicalize them. +func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) { + if len(a) == 0 { + return "", false + } + // See if a looks like a header key. If not, return it unchanged. + noCanon := false for _, c := range a { if validHeaderFieldByte(c) { continue } // Don't canonicalize. - return string(a) + if c == ' ' { + // We accept invalid headers with a space before the + // colon, but must not canonicalize them. + // See https://go.dev/issue/34540. + noCanon = true + continue + } + return string(a), false + } + if noCanon { + return string(a), true } upper := true @@ -334,13 +416,40 @@ func canonicalMIMEHeaderKey(a []byte) string { a[i] = c upper = c == '-' // for next time } + commonHeaderOnce.Do(initCommonHeader) // The compiler recognizes m[string(byteSlice)] as a special // case, so a copy of a's bytes into a new string does not // happen in this map lookup: if v := commonHeader[string(a)]; v != "" { - return v + return v, true } - return string(a) + return string(a), true +} + +// validHeaderValueByte reports whether c is a valid byte in a header +// field value. RFC 7230 says: +// +// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +// field-vchar = VCHAR / obs-text +// obs-text = %x80-FF +// +// RFC 5234 says: +// +// HTAB = %x09 +// SP = %x20 +// VCHAR = %x21-7E +func validHeaderValueByte(c byte) bool { + // mask is a 128-bit bitmap with 1s for allowed bytes, + // so that the byte c can be tested with a shift and an and. + // If c >= 128, then 1<>64)) == 0 } // commonHeader interns common header strings. diff --git a/transfer.go b/transfer.go index c7c623f8..92a5d305 100644 --- a/transfer.go +++ b/transfer.go @@ -9,9 +9,6 @@ import ( "bytes" "errors" "fmt" - "github.com/imroc/req/v3/internal" - "github.com/imroc/req/v3/internal/ascii" - "github.com/imroc/req/v3/internal/dump" "io" "net/http" "net/textproto" @@ -22,6 +19,11 @@ import ( "sync" "time" + "github.com/imroc/req/v3/internal" + "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/godebug" + "golang.org/x/net/http/httpguts" ) @@ -317,7 +319,7 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dump.Dumper) (err error // OS-level optimizations in the event that the body is an // *os.File. if t.Body != nil { - var body = t.unwrapBody() + body := t.unwrapBody() if chunked(t.TransferEncoding) { if bw, ok := rw.(*bufio.Writer); ok { rw = &internal.FlushAfterChunkWriter{Writer: bw} @@ -386,7 +388,9 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dump.Dumper) (err error // // This function is only intended for use in writeBody. func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) { - n, err = io.Copy(dst, src) + buf := getCopyBuf() + defer putCopyBuf(buf) + n, err = io.CopyBuffer(dst, src, buf) if err != nil && err != io.EOF { t.bodyReadError = err } @@ -478,7 +482,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { return err } if isResponse && t.RequestMethod == "HEAD" { - if n, err := parseContentLength(headerGet(t.Header, "Content-Length")); err != nil { + if n, err := parseContentLength(t.Header["Content-Length"]); err != nil { return err } else { t.ContentLength = n @@ -582,19 +586,6 @@ func (t *transferReader) parseTransferEncoding() error { return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} } - // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field - // in any message that contains a Transfer-Encoding header field." - // - // but also: "If a message is received with both a Transfer-Encoding and a - // Content-Length header field, the Transfer-Encoding overrides the - // Content-Length. Such a message might indicate an attempt to perform - // request smuggling (Section 9.5) or response splitting (Section 9.4) and - // ought to be handled as an error. A sender MUST remove the received - // Content-Length field prior to forwarding such a message downstream." - // - // Reportedly, these appear in the wild. - delete(t.Header, "Content-Length") - t.Chunked = true return nil } @@ -602,7 +593,8 @@ func (t *transferReader) parseTransferEncoding() error { // Determine the expected body length, using RFC 7230 Section 3.3. This // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. -func fixLength(isResponse bool, status int, requestMethod string, header http.Header, chunked bool) (int64, error) { +func fixLength(isResponse bool, status int, requestMethod string, header http.Header, chunked bool) (n int64, err error) { + isRequest := !isResponse contentLens := header["Content-Length"] // Hardening against HTTP request smuggling @@ -625,6 +617,14 @@ func fixLength(isResponse bool, status int, requestMethod string, header http.He contentLens = header["Content-Length"] } + // Reject requests with invalid Content-Length headers. + if len(contentLens) > 0 { + n, err = parseContentLength(contentLens) + if err != nil { + return -1, err + } + } + // Logic based on response type or status if isResponse && noResponseBodyExpected(requestMethod) { return 0, nil @@ -637,25 +637,43 @@ func fixLength(isResponse bool, status int, requestMethod string, header http.He return 0, nil } + // According to RFC 9112, "If a message is received with both a + // Transfer-Encoding and a Content-Length header field, the Transfer-Encoding + // overrides the Content-Length. Such a message might indicate an attempt to + // perform request smuggling (Section 11.2) or response splitting (Section 11.1) + // and ought to be handled as an error. An intermediary that chooses to forward + // the message MUST first remove the received Content-Length field and process + // the Transfer-Encoding (as described below) prior to forwarding the message downstream." + // + // Chunked-encoding requests with either valid Content-Length + // headers or no Content-Length headers are accepted after removing + // the Content-Length field from header. + // // Logic based on Transfer-Encoding if chunked { + header.Del("Content-Length") return -1, nil } // Logic based on Content-Length - var cl string - if len(contentLens) == 1 { - cl = textproto.TrimString(contentLens[0]) - } - if cl != "" { - n, err := parseContentLength(cl) - if err != nil { - return -1, err - } + if len(contentLens) > 0 { return n, nil } + header.Del("Content-Length") + if isRequest { + // RFC 7230 neither explicitly permits nor forbids an + // entity-body on a GET request so we permit one if + // declared, but we default to 0 here (not -1 below) + // if there's no mention of a body. + // Likewise, all other request methods are assumed to have + // no body if neither Transfer-Encoding chunked nor a + // Content-Length are set. + return 0, nil + } + + // Body-EOF logic based on other methods (like closing, or chunked coding) return -1, nil } @@ -955,19 +973,31 @@ func (bl bodyLocked) Read(p []byte) (n int, err error) { return bl.b.readLocked(p) } -// parseContentLength trims whitespace from s and returns -1 if no value -// is set, or the value if it's >= 0. -func parseContentLength(cl string) (int64, error) { - cl = textproto.TrimString(cl) - if cl == "" { +var laxContentLength = godebug.New("httplaxcontentlength") + +// parseContentLength checks that the header is valid and then trims +// whitespace. It returns -1 if no value is set otherwise the value +// if it's >= 0. +func parseContentLength(clHeaders []string) (int64, error) { + if len(clHeaders) == 0 { return -1, nil } + cl := textproto.TrimString(clHeaders[0]) + + // The Content-Length must be a valid numeric value. + // See: https://datatracker.ietf.org/doc/html/rfc2616/#section-14.13 + if cl == "" { + if laxContentLength.Value() == "1" { + laxContentLength.IncNonDefault() + return -1, nil + } + return 0, badStringError("invalid empty Content-Length", cl) + } n, err := strconv.ParseUint(cl, 10, 63) if err != nil { return 0, badStringError("bad Content-Length", cl) } return int64(n), nil - } // finishAsyncByteRead finishes reading the 1-byte sniff @@ -991,11 +1021,13 @@ func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { return } -var nopCloserType = reflect.TypeOf(io.NopCloser(nil)) -var nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct { - io.Reader - io.WriterTo -}{})) +var ( + nopCloserType = reflect.TypeOf(io.NopCloser(nil)) + nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct { + io.Reader + io.WriterTo + }{})) +) // unwrapNopCloser return the underlying reader and true if r is a NopCloser // else it return false. diff --git a/transport.go b/transport.go index 2fdcff3c..e79c005f 100644 --- a/transport.go +++ b/transport.go @@ -1583,6 +1583,13 @@ func (w *wantConn) waiting() bool { } } +// getCtxForDial returns context for dial or nil if connection was delivered or canceled. +func (w *wantConn) getCtxForDial() context.Context { + w.mu.Lock() + defer w.mu.Unlock() + return w.ctx +} + // tryDeliver attempts to deliver pc, err to w and reports whether it succeeded. func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { w.mu.Lock() @@ -1592,6 +1599,7 @@ func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { return false } + w.ctx = nil w.pc = pc w.err = err if w.pc == nil && w.err == nil { @@ -1609,6 +1617,7 @@ func (w *wantConn) cancel(t *Transport, err error) { close(w.ready) // catch misbehavior in future delivery } pc := w.pc + w.ctx = nil w.pc = nil w.err = err w.mu.Unlock() @@ -1814,6 +1823,11 @@ func (t *Transport) queueForDial(w *wantConn) { // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. func (t *Transport) dialConnFor(w *wantConn) { defer w.afterDial() + ctx := w.getCtxForDial() + if ctx == nil { + t.decConnsPerHost(w.key) + return + } pc, err := t.dialConn(w.ctx, w.cm) delivered := w.tryDeliver(pc, err) @@ -1882,7 +1896,6 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { // tunnel, this function establishes a nested TLS session inside the encrypted channel. // The remote endpoint's name may be overridden by TLSClientConfig.ServerName. func (pc *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace, forProxy bool) error { - // Initiate TLS and check remote host name against certificate. cfg := cloneTLSConfig(pc.t.TLSClientConfig) if cfg.ServerName == "" { @@ -1912,6 +1925,11 @@ func (pc *persistConn) addTLS(ctx context.Context, name string, trace *httptrace }() if err := <-errc; err != nil { plainConn.Close() + if err == (tlsHandshakeTimeoutError{}) { + // Now that we have closed the connection, + // wait for the call to HandshakeContext to return. + <-errc + } if trace != nil && trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(tls.ConnectionState{}, err) } @@ -2148,6 +2166,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if t.OnProxyConnectResponse != nil { err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp) if err != nil { + conn.Close() return nil, err } } @@ -2690,7 +2709,6 @@ func (pc *persistConn) readLoop() { waitForBodyRead <- false <-eofc // will be closed by deferred call at the end of the function return nil - }, fn: func(err error) error { isEOF := err == io.EOF @@ -2737,7 +2755,7 @@ func (pc *persistConn) readLoop() { } case <-rc.req.Cancel: alive = false - pc.t.CancelRequest(rc.req) + pc.t.cancelRequest(rc.cancelKey, common.ErrRequestCanceled) case <-rc.req.Context().Done(): alive = false pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) @@ -3227,16 +3245,18 @@ type writeRequest struct { continueCh <-chan struct{} } -type httpError struct { - err string - timeout bool +// httpTimeoutError represents a timeout. +// It implements net.Error and wraps context.DeadlineExceeded. +type timeoutError struct { + err string } -func (e *httpError) Error() string { return e.err } -func (e *httpError) Timeout() bool { return e.timeout } -func (e *httpError) Temporary() bool { return true } +func (e *timeoutError) Error() string { return e.err } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } +func (e *timeoutError) Is(err error) bool { return err == context.DeadlineExceeded } -var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} +var errTimeout error = &timeoutError{"net/http: timeout awaiting response headers"} var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? From 7e8e5f43b18850689751eb29d709e737b8636fa9 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 12 Mar 2024 11:41:46 +0800 Subject: [PATCH 796/843] fix DisableForceMultipart (#333) --- request.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/request.go b/request.go index 7a5942a2..79f69f49 100644 --- a/request.go +++ b/request.go @@ -1102,7 +1102,7 @@ func (r *Request) EnableForceMultipart() *Request { // DisableForceMultipart disables force using multipart to upload form data. func (r *Request) DisableForceMultipart() *Request { - r.isMultiPart = true + r.isMultiPart = false return r } From 5d1cb2b78f229ec0ccdbfd818a31b855c253863b Mon Sep 17 00:00:00 2001 From: Mark Niehe Date: Wed, 3 Apr 2024 19:29:35 -0700 Subject: [PATCH 797/843] Check if the context was cancelled before retrying --- client_test.go | 22 +++++++++++++++++++--- request.go | 6 +++++- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/client_test.go b/client_test.go index 98df9f55..28acdfe5 100644 --- a/client_test.go +++ b/client_test.go @@ -5,9 +5,6 @@ import ( "context" "crypto/tls" "errors" - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/tests" - "golang.org/x/net/publicsuffix" "io" "net" "net/http" @@ -17,8 +14,27 @@ import ( "strings" "testing" "time" + + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/tests" + "golang.org/x/net/publicsuffix" ) +func TestPassContext(t *testing.T) { + cancelledCtx, done := context.WithCancel(context.Background()) + done() + + client := tc(). + SetCommonRetryCount(2). + SetCommonRetryBackoffInterval(1*time.Second, 5*time.Second) + + res, err := client.R().SetContext(cancelledCtx).Get("/") + + tests.AssertEqual(t, 0, res.Request.RetryAttempt) + tests.AssertNotNil(t, err) + tests.AssertErrorContains(t, err, "context canceled") +} + func TestWrapRoundTrip(t *testing.T) { i, j, a, b := 0, 0, 0, 0 c := tc().WrapRoundTripFunc(func(rt RoundTripper) RoundTripFunc { diff --git a/request.go b/request.go index 79f69f49..88b368cb 100644 --- a/request.go +++ b/request.go @@ -661,13 +661,17 @@ func (r *Request) do() (resp *Response, err error) { resp, err = r.client.roundTrip(r) } + // Determine if the error is from a cancelled context. Store it here so it doesn't get lost + // when processing the AfterRespon middleware. + contextCanceled := errors.Is(err, context.Canceled) + for _, f := range r.afterResponse { if err = f(r.client, resp); err != nil { return } } - if r.retryOption == nil || (r.RetryAttempt >= r.retryOption.MaxRetries && r.retryOption.MaxRetries >= 0) { // absolutely cannot retry. + if contextCanceled || r.retryOption == nil || (r.RetryAttempt >= r.retryOption.MaxRetries && r.retryOption.MaxRetries >= 0) { // absolutely cannot retry. return } From 9953cc77c189da55cb8206a969dfe77d810ae8dc Mon Sep 17 00:00:00 2001 From: Mark Niehe Date: Wed, 3 Apr 2024 19:30:45 -0700 Subject: [PATCH 798/843] Rename test function --- client_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client_test.go b/client_test.go index 28acdfe5..ccf62f57 100644 --- a/client_test.go +++ b/client_test.go @@ -20,7 +20,7 @@ import ( "golang.org/x/net/publicsuffix" ) -func TestPassContext(t *testing.T) { +func TestRetryCancelledContext(t *testing.T) { cancelledCtx, done := context.WithCancel(context.Background()) done() From 2b95fc59e652b76f7cd5f3e12ef9537e43409834 Mon Sep 17 00:00:00 2001 From: Ronaldinho Date: Thu, 4 Apr 2024 11:49:00 +0800 Subject: [PATCH 799/843] allow RetryCondition and RetryHook wrap the original error if resp.Err has been set, wont overwrite it --- request.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/request.go b/request.go index 79f69f49..ff7e1ef6 100644 --- a/request.go +++ b/request.go @@ -635,7 +635,7 @@ func (r *Request) do() (resp *Response, err error) { if resp == nil { resp = &Response{Request: r} } - if err != nil { + if err != nil && resp.Err == nil { resp.Err = err } }() From 18c160944dbd0fd1855807bae9f9095edcc5cf2c Mon Sep 17 00:00:00 2001 From: ferhat elmas Date: Fri, 5 Apr 2024 21:09:12 +0200 Subject: [PATCH 800/843] fix: use correct context for dial --- transport.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport.go b/transport.go index e79c005f..620e2923 100644 --- a/transport.go +++ b/transport.go @@ -1829,7 +1829,7 @@ func (t *Transport) dialConnFor(w *wantConn) { return } - pc, err := t.dialConn(w.ctx, w.cm) + pc, err := t.dialConn(ctx, w.cm) delivered := w.tryDeliver(pc, err) if err == nil && (!delivered || pc.alt != nil) { // pconn was not passed to w, From 34bd5228604d79fb7c697702422d8ef622ceaeb1 Mon Sep 17 00:00:00 2001 From: roc Date: Sun, 7 Apr 2024 11:16:35 +0800 Subject: [PATCH 801/843] fix typo --- request.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/request.go b/request.go index 9669adb7..4bbb98a0 100644 --- a/request.go +++ b/request.go @@ -661,8 +661,8 @@ func (r *Request) do() (resp *Response, err error) { resp, err = r.client.roundTrip(r) } - // Determine if the error is from a cancelled context. Store it here so it doesn't get lost - // when processing the AfterRespon middleware. + // Determine if the error is from a canceled context. + // Store it here so it doesn't get lost when processing the AfterResponse middleware. contextCanceled := errors.Is(err, context.Canceled) for _, f := range r.afterResponse { From 14d332b19b39b6f66fb8ff10677f64a374d1f363 Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Tue, 9 Apr 2024 22:20:11 -0700 Subject: [PATCH 802/843] Allow logger creation from an existing standard logger --- logger.go | 4 ++++ logger_test.go | 14 +++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/logger.go b/logger.go index af461064..c0020713 100644 --- a/logger.go +++ b/logger.go @@ -19,6 +19,10 @@ func NewLogger(output io.Writer, prefix string, flag int) Logger { return &logger{l: log.New(output, prefix, flag)} } +func NewStandardLogger(l *log.Logger) Logger { + return &logger{l: l} +} + func createDefaultLogger() Logger { return NewLogger(os.Stdout, "", log.Ldate|log.Lmicroseconds) } diff --git a/logger_test.go b/logger_test.go index c15dd6f4..64b9bdcd 100644 --- a/logger_test.go +++ b/logger_test.go @@ -2,9 +2,10 @@ package req import ( "bytes" - "github.com/imroc/req/v3/internal/tests" "log" "testing" + + "github.com/imroc/req/v3/internal/tests" ) func TestLogger(t *testing.T) { @@ -17,3 +18,14 @@ func TestLogger(t *testing.T) { c.R().SetOutput(nil) tests.AssertContains(t, buf.String(), "warn", true) } + +func TestStandardLogger(t *testing.T) { + buf := new(bytes.Buffer) + l := NewStandardLogger(log.New(buf, "", log.Ldate|log.Lmicroseconds)) + c := tc().SetLogger(l) + c.SetProxyURL(":=\\<>ksfj&*&sf") + tests.AssertContains(t, buf.String(), "error", true) + buf.Reset() + c.R().SetOutput(nil) + tests.AssertContains(t, buf.String(), "warn", true) +} From 1c6644396a4210d16a7971b55cf1c736cb6b9391 Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Fri, 12 Apr 2024 07:58:47 -0700 Subject: [PATCH 803/843] Rename for clarity --- logger.go | 2 +- logger_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/logger.go b/logger.go index c0020713..6b8f897a 100644 --- a/logger.go +++ b/logger.go @@ -19,7 +19,7 @@ func NewLogger(output io.Writer, prefix string, flag int) Logger { return &logger{l: log.New(output, prefix, flag)} } -func NewStandardLogger(l *log.Logger) Logger { +func NewFromStandardLogger(l *log.Logger) Logger { return &logger{l: l} } diff --git a/logger_test.go b/logger_test.go index 64b9bdcd..ac49ee9f 100644 --- a/logger_test.go +++ b/logger_test.go @@ -19,9 +19,9 @@ func TestLogger(t *testing.T) { tests.AssertContains(t, buf.String(), "warn", true) } -func TestStandardLogger(t *testing.T) { +func TestFromStandardLogger(t *testing.T) { buf := new(bytes.Buffer) - l := NewStandardLogger(log.New(buf, "", log.Ldate|log.Lmicroseconds)) + l := NewFromStandardLogger(log.New(buf, "", log.Ldate|log.Lmicroseconds)) c := tc().SetLogger(l) c.SetProxyURL(":=\\<>ksfj&*&sf") tests.AssertContains(t, buf.String(), "error", true) From f3578479d68aebe90865c01b871a29deb3e7cb50 Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Sat, 13 Apr 2024 08:46:24 -0700 Subject: [PATCH 804/843] more descriptive function name --- logger.go | 2 +- logger_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/logger.go b/logger.go index 6b8f897a..749ec6b0 100644 --- a/logger.go +++ b/logger.go @@ -19,7 +19,7 @@ func NewLogger(output io.Writer, prefix string, flag int) Logger { return &logger{l: log.New(output, prefix, flag)} } -func NewFromStandardLogger(l *log.Logger) Logger { +func NewLoggerFromStandardLogger(l *log.Logger) Logger { return &logger{l: l} } diff --git a/logger_test.go b/logger_test.go index ac49ee9f..8a456da8 100644 --- a/logger_test.go +++ b/logger_test.go @@ -19,9 +19,9 @@ func TestLogger(t *testing.T) { tests.AssertContains(t, buf.String(), "warn", true) } -func TestFromStandardLogger(t *testing.T) { +func TestLoggerFromStandardLogger(t *testing.T) { buf := new(bytes.Buffer) - l := NewFromStandardLogger(log.New(buf, "", log.Ldate|log.Lmicroseconds)) + l := NewLoggerFromStandardLogger(log.New(buf, "", log.Ldate|log.Lmicroseconds)) c := tc().SetLogger(l) c.SetProxyURL(":=\\<>ksfj&*&sf") tests.AssertContains(t, buf.String(), "error", true) From 757eb95873eea26dbccf6b1ccad0eac1cf600df2 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 26 Apr 2024 23:18:44 +0800 Subject: [PATCH 805/843] Create SECURITY.md --- SECURITY.md | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..4a5665da --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,9 @@ +# Security Policy + +## Supported Versions + +req version >= `v3.43.x` + +## Reporting a Vulnerability + +Email: roc@imroc.cc From 2b406026f4aa980175ecb08bc540cf0e685467c1 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 8 May 2024 14:16:57 +0800 Subject: [PATCH 806/843] prevent successful requests from invalid host --- http.go | 40 ++++------------------------------------ transport.go | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 38 deletions(-) diff --git a/http.go b/http.go index ff99699f..61a63676 100644 --- a/http.go +++ b/http.go @@ -3,14 +3,14 @@ package req import ( "encoding/base64" "fmt" - "github.com/imroc/req/v3/internal/ascii" - "golang.org/x/net/http/httpguts" - "golang.org/x/net/idna" "io" - "net" "net/http" "net/textproto" "strings" + + "github.com/imroc/req/v3/internal/ascii" + "golang.org/x/net/http/httpguts" + "golang.org/x/net/idna" ) // maxInt64 is the effective "infinite" value for the Server and @@ -165,38 +165,6 @@ func idnaASCII(v string) (string, error) { return idna.Lookup.ToASCII(v) } -// cleanHost cleans up the host sent in request's Host header. -// -// It both strips anything after '/' or ' ', and puts the value -// into Punycode form, if necessary. -// -// Ideally we'd clean the Host header according to the spec: -// https://tools.ietf.org/html/rfc7230#section-5.4 (Host = uri-host [ ":" port ]") -// https://tools.ietf.org/html/rfc7230#section-2.7 (uri-host -> rfc3986's host) -// https://tools.ietf.org/html/rfc3986#section-3.2.2 (definition of host) -// But practically, what we are trying to avoid is the situation in -// issue 11206, where a malformed Host header used in the proxy context -// would create a bad request. So it is enough to just truncate at the -// first offending character. -func cleanHost(in string) string { - if i := strings.IndexAny(in, " /"); i != -1 { - in = in[:i] - } - host, port, err := net.SplitHostPort(in) - if err != nil { // input was just a host - a, err := idnaASCII(in) - if err != nil { - return in // garbage in, garbage out - } - return a - } - a, err := idnaASCII(host) - if err != nil { - return in // garbage in, garbage out - } - return net.JoinHostPort(a, port) -} - // removeZone removes IPv6 zone identifier from host. // E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" func removeZone(host string) string { diff --git a/transport.go b/transport.go index 620e2923..a1f847b0 100644 --- a/transport.go +++ b/transport.go @@ -2986,12 +2986,41 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo // is not given, use the host from the request URL. // // Clean the host, in case it arrives with unexpected stuff in it. - host := cleanHost(r.Host) + host := r.Host if host == "" { if r.URL == nil { return errMissingHost } - host = cleanHost(r.URL.Host) + host = r.URL.Host + } + host, err = httpguts.PunycodeHostPort(host) + if err != nil { + return err + } + + // Validate that the Host header is a valid header in general, + // but don't validate the host itself. This is sufficient to avoid + // header or request smuggling via the Host field. + // The server can (and will, if it's a net/http server) reject + // the request if it doesn't consider the host valid. + if !httpguts.ValidHostHeader(host) { + // Historically, we would truncate the Host header after '/' or ' '. + // Some users have relied on this truncation to convert a network + // address such as Unix domain socket path into a valid, ignored + // Host header (see https://go.dev/issue/61431). + // + // We don't preserve the truncation, because sending an altered + // header field opens a smuggling vector. Instead, zero out the + // Host header entirely if it isn't valid. (An empty Host is valid; + // see RFC 9112 Section 3.2.) + // + // Return an error if we're sending to a proxy, since the proxy + // probably can't do anything useful with an empty Host header. + if !usingProxy { + host = "" + } else { + return errors.New("http: invalid Host header") + } } // According to RFC 6874, an HTTP client, proxy, or other From 47d303390169a842d085d113d1b84bd2c78dbb77 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 16 May 2024 19:02:00 +0800 Subject: [PATCH 807/843] fix: avoid repeated append cookies when retry (#353) --- middleware.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/middleware.go b/middleware.go index c554e89d..59cd46f5 100644 --- a/middleware.go +++ b/middleware.go @@ -2,8 +2,6 @@ package req import ( "bytes" - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/util" "io" "mime/multipart" "net/http" @@ -14,6 +12,9 @@ import ( "reflect" "strings" "time" + + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/util" ) type ( @@ -489,7 +490,7 @@ func parseRequestHeader(c *Client, r *Request) error { } func parseRequestCookie(c *Client, r *Request) error { - if len(c.Cookies) == 0 { + if len(c.Cookies) == 0 || r.RetryAttempt > 0 { return nil } r.Cookies = append(r.Cookies, c.Cookies...) From 382b9b174fe3db8a4cbbce2c8dd34e80fdaae30d Mon Sep 17 00:00:00 2001 From: ferhat elmas Date: Tue, 21 May 2024 12:09:25 +0200 Subject: [PATCH 808/843] fix: header cleanup replaced value should be trimmed, otherwise escape characters won't be handled correctly. --- header.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/header.go b/header.go index 3c94f7db..c78b9a71 100644 --- a/header.go +++ b/header.go @@ -100,7 +100,7 @@ func headerWriteSubset(h http.Header, exclude map[string]bool, writeHeader func( } for i, v := range kv.Values { vv := headerNewlineToSpace.Replace(v) - vv = textproto.TrimString(v) + vv = textproto.TrimString(vv) if vv != v { kv.Values[i] = vv } From e3649726ea8214bf2201736ac8e15d08a62de270 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 28 May 2024 14:33:05 +0800 Subject: [PATCH 809/843] add SetContextData and GetContextData(#358) --- request.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/request.go b/request.go index 4bbb98a0..7696d881 100644 --- a/request.go +++ b/request.go @@ -954,6 +954,18 @@ func (r *Request) SetContext(ctx context.Context) *Request { return r } +// SetContextData sets the key-value pair data for current Request, so you +// can access some extra context info for current Request in hook or middleware. +func (r *Request) SetContextData(key, val any) *Request { + r.ctx = context.WithValue(r.Context(), key, val) + return r +} + +// GetContextData returns the context data of specified key, which set by SetContextData. +func (r *Request) GetContextData(key any) any { + return r.Context().Value(key) +} + // DisableAutoReadResponse disable read response body automatically (enabled by default). func (r *Request) DisableAutoReadResponse() *Request { r.disableAutoReadResponse = true From 682cde21c366e4f7fa9429ab1897a5829a5feef3 Mon Sep 17 00:00:00 2001 From: roc Date: Fri, 31 May 2024 13:49:54 +0800 Subject: [PATCH 810/843] fix digest auth parse (#359) --- digest.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/digest.go b/digest.go index ccb97398..442b5953 100644 --- a/digest.go +++ b/digest.go @@ -138,9 +138,9 @@ func parseChallenge(input string) (*challenge, error) { case "opaque": c.opaque = strings.Trim(r[1], qs) case "stale": - c.stale = r[1] + c.stale = strings.Trim(r[1], qs) case "algorithm": - c.algorithm = r[1] + c.algorithm = strings.Trim(r[1], qs) case "qop": c.qop = strings.Trim(r[1], qs) case "charset": From 6559503ea1156d03d9f5485d26edd7e1b6105555 Mon Sep 17 00:00:00 2001 From: shrimp Date: Sat, 8 Jun 2024 16:52:14 +0400 Subject: [PATCH 811/843] Added auto decompression feature This commit updates several dependencies in different go modules and introduces the automatic decompression feature. --- client.go | 12 ++++++ examples/find-popular-repo/go.mod | 16 ++++---- examples/find-popular-repo/go.sum | 6 +++ examples/opentelemetry-jaeger-tracing/go.mod | 16 ++++---- examples/opentelemetry-jaeger-tracing/go.sum | 6 +++ examples/upload/uploadclient/go.mod | 16 ++++---- examples/upload/uploadclient/go.sum | 6 +++ examples/uploadcallback/uploadclient/go.mod | 16 ++++---- examples/uploadcallback/uploadclient/go.sum | 6 +++ go.mod | 2 +- go.sum | 2 + internal/http3/brotli_reader.go | 30 +++++++++++++++ internal/http3/client.go | 29 ++++++++++++++ internal/http3/deflate_reader.go | 40 ++++++++++++++++++++ internal/http3/zstd_reader.go | 37 ++++++++++++++++++ internal/transport/option.go | 6 +++ 16 files changed, 217 insertions(+), 29 deletions(-) create mode 100644 internal/http3/brotli_reader.go create mode 100644 internal/http3/deflate_reader.go create mode 100644 internal/http3/zstd_reader.go diff --git a/client.go b/client.go index f4ed8b0b..2631e6ce 100644 --- a/client.go +++ b/client.go @@ -381,6 +381,18 @@ func (c *Client) EnableCompression() *Client { return c } +// EnableAutoDecompress enables the automatic decompression (disabled by default). +func (c *Client) EnableAutoDecompress() *Client { + c.Transport.AutoDecompression = true + return c +} + +// DisableAutoDecompress disables the automatic decompression (disabled by default). +func (c *Client) DisableAutoDecompress() *Client { + c.Transport.AutoDecompression = false + return c +} + // SetTLSClientConfig set the TLS client config. Be careful! Usually // you don't need this, you can directly set the tls configuration with // methods like EnableInsecureSkipVerify, SetCerts etc. Or you can call diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index 6e895b2b..430089f7 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -1,6 +1,8 @@ module find-popular-repo -go 1.18 +go 1.21 + +toolchain go1.22.3 replace github.com/imroc/req/v3 => ../../ @@ -19,11 +21,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.1 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.12.0 // indirect - golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.14.0 // indirect - golang.org/x/sys v0.11.0 // indirect - golang.org/x/text v0.12.0 // indirect - golang.org/x/tools v0.12.0 // indirect + golang.org/x/crypto v0.21.0 // indirect + golang.org/x/mod v0.16.0 // indirect + golang.org/x/net v0.22.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/text v0.14.0 // indirect + golang.org/x/tools v0.19.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index a96224fb..382380ce 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -170,6 +170,7 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -179,6 +180,7 @@ golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -201,6 +203,7 @@ golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNG golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -239,6 +242,7 @@ golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7 golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -249,6 +253,7 @@ golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -265,6 +270,7 @@ golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/examples/opentelemetry-jaeger-tracing/go.mod b/examples/opentelemetry-jaeger-tracing/go.mod index 1fc84c1d..09afcab2 100644 --- a/examples/opentelemetry-jaeger-tracing/go.mod +++ b/examples/opentelemetry-jaeger-tracing/go.mod @@ -1,6 +1,8 @@ module opentelemetry-jaeger-tracing -go 1.18 +go 1.21 + +toolchain go1.22.3 replace github.com/imroc/req/v3 => ../../ @@ -28,11 +30,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.12.0 // indirect - golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.14.0 // indirect - golang.org/x/sys v0.11.0 // indirect - golang.org/x/text v0.12.0 // indirect - golang.org/x/tools v0.12.0 // indirect + golang.org/x/crypto v0.21.0 // indirect + golang.org/x/mod v0.16.0 // indirect + golang.org/x/net v0.22.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/text v0.14.0 // indirect + golang.org/x/tools v0.19.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/opentelemetry-jaeger-tracing/go.sum b/examples/opentelemetry-jaeger-tracing/go.sum index b81a8a94..6ca9b455 100644 --- a/examples/opentelemetry-jaeger-tracing/go.sum +++ b/examples/opentelemetry-jaeger-tracing/go.sum @@ -181,6 +181,7 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -190,6 +191,7 @@ golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -215,6 +217,7 @@ golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNG golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -257,6 +260,7 @@ golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7 golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -268,6 +272,7 @@ golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -283,6 +288,7 @@ golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod index ca2d105c..8913edf1 100644 --- a/examples/upload/uploadclient/go.mod +++ b/examples/upload/uploadclient/go.mod @@ -1,6 +1,8 @@ module uploadclient -go 1.18 +go 1.21 + +toolchain go1.22.3 replace github.com/imroc/req/v3 => ../../../ @@ -20,11 +22,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.12.0 // indirect - golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.14.0 // indirect - golang.org/x/sys v0.11.0 // indirect - golang.org/x/text v0.12.0 // indirect - golang.org/x/tools v0.12.0 // indirect + golang.org/x/crypto v0.21.0 // indirect + golang.org/x/mod v0.16.0 // indirect + golang.org/x/net v0.22.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/text v0.14.0 // indirect + golang.org/x/tools v0.19.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/upload/uploadclient/go.sum b/examples/upload/uploadclient/go.sum index 24e9d3b8..3d8d9499 100644 --- a/examples/upload/uploadclient/go.sum +++ b/examples/upload/uploadclient/go.sum @@ -169,6 +169,7 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -178,6 +179,7 @@ golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -205,6 +207,7 @@ golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNG golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -246,6 +249,7 @@ golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7 golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -257,6 +261,7 @@ golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -272,6 +277,7 @@ golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/examples/uploadcallback/uploadclient/go.mod b/examples/uploadcallback/uploadclient/go.mod index ca2d105c..8913edf1 100644 --- a/examples/uploadcallback/uploadclient/go.mod +++ b/examples/uploadcallback/uploadclient/go.mod @@ -1,6 +1,8 @@ module uploadclient -go 1.18 +go 1.21 + +toolchain go1.22.3 replace github.com/imroc/req/v3 => ../../../ @@ -20,11 +22,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.12.0 // indirect - golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.14.0 // indirect - golang.org/x/sys v0.11.0 // indirect - golang.org/x/text v0.12.0 // indirect - golang.org/x/tools v0.12.0 // indirect + golang.org/x/crypto v0.21.0 // indirect + golang.org/x/mod v0.16.0 // indirect + golang.org/x/net v0.22.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/text v0.14.0 // indirect + golang.org/x/tools v0.19.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/uploadcallback/uploadclient/go.sum b/examples/uploadcallback/uploadclient/go.sum index 24e9d3b8..3d8d9499 100644 --- a/examples/uploadcallback/uploadclient/go.sum +++ b/examples/uploadcallback/uploadclient/go.sum @@ -169,6 +169,7 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -178,6 +179,7 @@ golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -205,6 +207,7 @@ golang.org/x/net v0.0.0-20220809012201-f428fae20770 h1:dIi4qVdvjZEjiMDv7vhokAZNG golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -246,6 +249,7 @@ golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7 golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -257,6 +261,7 @@ golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -272,6 +277,7 @@ golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/go.mod b/go.mod index 7b617180..85849bbc 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/klauspost/compress v1.17.7 // indirect + github.com/klauspost/compress v1.17.8 // indirect github.com/onsi/ginkgo/v2 v2.16.0 // indirect go.uber.org/mock v0.4.0 // indirect golang.org/x/crypto v0.21.0 // indirect diff --git a/go.sum b/go.sum index 7eb022ef..c9e55489 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,8 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= +github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/onsi/ginkgo/v2 v2.16.0 h1:7q1w9frJDzninhXxjZd+Y/x54XNjG/UlRLIYPZafsPM= github.com/onsi/ginkgo/v2 v2.16.0/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs= github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8= diff --git a/internal/http3/brotli_reader.go b/internal/http3/brotli_reader.go new file mode 100644 index 00000000..7ee5c13f --- /dev/null +++ b/internal/http3/brotli_reader.go @@ -0,0 +1,30 @@ +package http3 + +import ( + "github.com/andybalholm/brotli" + "io" +) + +type BrotliReader struct { + Body io.ReadCloser // underlying Response.Body + br io.Reader // lazily-initialized brotli reader + berr error // sticky error +} + +func newBrotliReader(body io.ReadCloser) io.ReadCloser { + return &BrotliReader{Body: body} +} + +func (br *BrotliReader) Read(p []byte) (n int, err error) { + if br.berr != nil { + return 0, br.berr + } + if br.br == nil { + br.br = brotli.NewReader(br.Body) + } + return br.br.Read(p) +} + +func (br *BrotliReader) Close() error { + return br.Body.Close() +} diff --git a/internal/http3/client.go b/internal/http3/client.go index 59ffb413..106d138b 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -500,6 +500,35 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui res.ContentLength = -1 res.Body = newGzipReader(respBody) res.Uncompressed = true + } else if c.opt.AutoDecompression { + switch res.Header.Get("Content-Encoding") { + case "gzip": + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = newGzipReader(respBody) + res.Uncompressed = true + case "deflate": + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = newDeflateReader(respBody) + res.Uncompressed = true + case "br": + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = newBrotliReader(respBody) + res.Uncompressed = true + case "zstd": + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = newZstdReader(respBody) + res.Uncompressed = true + default: + res.Uncompressed = false + } } else { res.Body = respBody } diff --git a/internal/http3/deflate_reader.go b/internal/http3/deflate_reader.go new file mode 100644 index 00000000..1e0a3871 --- /dev/null +++ b/internal/http3/deflate_reader.go @@ -0,0 +1,40 @@ +package http3 + +import ( + "compress/flate" + "io" +) + +type DeflateReader struct { + Body io.ReadCloser // underlying Response.Body + dr io.ReadCloser // lazily-initialized deflate reader + derr error // sticky error +} + +func newDeflateReader(body io.ReadCloser) io.ReadCloser { + return &DeflateReader{Body: body} +} + +func (df *DeflateReader) Read(p []byte) (n int, err error) { + if df.derr != nil { + return 0, df.derr + } + if df.dr == nil { + df.dr = flate.NewReader(df.Body) + if df.dr == nil { + df.derr = io.ErrUnexpectedEOF + return 0, df.derr + } + } + return df.dr.Read(p) +} + +func (df *DeflateReader) Close() error { + if df.dr != nil { + err := df.dr.Close() + if err != nil { + return err + } + } + return df.Body.Close() +} diff --git a/internal/http3/zstd_reader.go b/internal/http3/zstd_reader.go new file mode 100644 index 00000000..b2116c33 --- /dev/null +++ b/internal/http3/zstd_reader.go @@ -0,0 +1,37 @@ +package http3 + +import ( + "github.com/klauspost/compress/zstd" + "io" +) + +type ZstdReader struct { + Body io.ReadCloser // underlying Response.Body + zr *zstd.Decoder // lazily-initialized zstd reader + zerr error // sticky error +} + +func newZstdReader(body io.ReadCloser) io.ReadCloser { + return &ZstdReader{Body: body} +} + +func (zr *ZstdReader) Read(p []byte) (n int, err error) { + if zr.zerr != nil { + return 0, zr.zerr + } + if zr.zr == nil { + zr.zr, err = zstd.NewReader(zr.Body) + if err != nil { + zr.zerr = err + return 0, err + } + } + return zr.zr.Read(p) +} + +func (zr *ZstdReader) Close() error { + if zr.zr != nil { + zr.zr.Close() + } + return zr.Body.Close() +} diff --git a/internal/transport/option.go b/internal/transport/option.go index 6f9f97f6..e9403f39 100644 --- a/internal/transport/option.go +++ b/internal/transport/option.go @@ -79,6 +79,12 @@ type Options struct { // uncompressed. DisableCompression bool + // AutoDecompression, if true, enables automatic decompression of + // compressed responses. It is equivalent to setting the Accept-Encoding + // header to "gzip, deflate, br, zstd" and the Transport will handle the + // decompression of the response transparently, returning the uncompressed. + AutoDecompression bool + // EnableH2C, if true, enables http2 over plain http without tls. EnableH2C bool From 2069ef940dcb1c493d80750609b589765be3c465 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 11 Jun 2024 14:44:12 +0800 Subject: [PATCH 812/843] refactor auto decompression --- internal/{http3 => compress}/brotli_reader.go | 15 +++++-- .../{http3 => compress}/deflate_reader.go | 21 ++++----- internal/{http3 => compress}/gzip_reader.go | 24 +++++++--- internal/compress/reader.go | 23 ++++++++++ internal/{http3 => compress}/zstd_reader.go | 15 +++++-- internal/http2/transport.go | 45 +++++-------------- internal/http3/client.go | 29 +++--------- transport.go | 16 +++++-- 8 files changed, 103 insertions(+), 85 deletions(-) rename internal/{http3 => compress}/brotli_reader.go (68%) rename internal/{http3 => compress}/deflate_reader.go (66%) rename internal/{http3 => compress}/gzip_reader.go (65%) create mode 100644 internal/compress/reader.go rename internal/{http3 => compress}/zstd_reader.go (72%) diff --git a/internal/http3/brotli_reader.go b/internal/compress/brotli_reader.go similarity index 68% rename from internal/http3/brotli_reader.go rename to internal/compress/brotli_reader.go index 7ee5c13f..6eaeffb5 100644 --- a/internal/http3/brotli_reader.go +++ b/internal/compress/brotli_reader.go @@ -1,8 +1,9 @@ -package http3 +package compress import ( - "github.com/andybalholm/brotli" "io" + + "github.com/andybalholm/brotli" ) type BrotliReader struct { @@ -11,7 +12,7 @@ type BrotliReader struct { berr error // sticky error } -func newBrotliReader(body io.ReadCloser) io.ReadCloser { +func NewBrotliReader(body io.ReadCloser) *BrotliReader { return &BrotliReader{Body: body} } @@ -28,3 +29,11 @@ func (br *BrotliReader) Read(p []byte) (n int, err error) { func (br *BrotliReader) Close() error { return br.Body.Close() } + +func (br *BrotliReader) GetUnderlyingBody() io.ReadCloser { + return br.Body +} + +func (br *BrotliReader) SetUnderlyingBody(body io.ReadCloser) { + br.Body = body +} diff --git a/internal/http3/deflate_reader.go b/internal/compress/deflate_reader.go similarity index 66% rename from internal/http3/deflate_reader.go rename to internal/compress/deflate_reader.go index 1e0a3871..44cf4943 100644 --- a/internal/http3/deflate_reader.go +++ b/internal/compress/deflate_reader.go @@ -1,4 +1,4 @@ -package http3 +package compress import ( "compress/flate" @@ -11,7 +11,7 @@ type DeflateReader struct { derr error // sticky error } -func newDeflateReader(body io.ReadCloser) io.ReadCloser { +func NewDeflateReader(body io.ReadCloser) *DeflateReader { return &DeflateReader{Body: body} } @@ -21,20 +21,21 @@ func (df *DeflateReader) Read(p []byte) (n int, err error) { } if df.dr == nil { df.dr = flate.NewReader(df.Body) - if df.dr == nil { - df.derr = io.ErrUnexpectedEOF - return 0, df.derr - } } return df.dr.Read(p) } func (df *DeflateReader) Close() error { if df.dr != nil { - err := df.dr.Close() - if err != nil { - return err - } + return df.dr.Close() } return df.Body.Close() } + +func (df *DeflateReader) GetUnderlyingBody() io.ReadCloser { + return df.Body +} + +func (df *DeflateReader) SetUnderlyingBody(body io.ReadCloser) { + df.Body = body +} diff --git a/internal/http3/gzip_reader.go b/internal/compress/gzip_reader.go similarity index 65% rename from internal/http3/gzip_reader.go rename to internal/compress/gzip_reader.go index 9050623e..615d7339 100644 --- a/internal/http3/gzip_reader.go +++ b/internal/compress/gzip_reader.go @@ -1,14 +1,12 @@ -package http3 +package compress -// copied from net/transport.go - -// GzipReader wraps a response body so it can lazily -// call gzip.NewReader on the first call to Read import ( "compress/gzip" "io" + "io/fs" ) +// GzipReader wraps a response body so it can lazily // call gzip.NewReader on the first call to Read type GzipReader struct { Body io.ReadCloser // underlying Response.Body @@ -16,7 +14,7 @@ type GzipReader struct { zerr error // sticky error } -func newGzipReader(body io.ReadCloser) io.ReadCloser { +func NewGzipReader(body io.ReadCloser) *GzipReader { return &GzipReader{Body: body} } @@ -35,5 +33,17 @@ func (gz *GzipReader) Read(p []byte) (n int, err error) { } func (gz *GzipReader) Close() error { - return gz.Body.Close() + if err := gz.Body.Close(); err != nil { + return err + } + gz.zerr = fs.ErrClosed + return nil +} + +func (gz *GzipReader) GetUnderlyingBody() io.ReadCloser { + return gz.Body +} + +func (gz *GzipReader) SetUnderlyingBody(body io.ReadCloser) { + gz.Body = body } diff --git a/internal/compress/reader.go b/internal/compress/reader.go new file mode 100644 index 00000000..0a65a3df --- /dev/null +++ b/internal/compress/reader.go @@ -0,0 +1,23 @@ +package compress + +import "io" + +type CompressReader interface { + io.ReadCloser + GetUnderlyingBody() io.ReadCloser + SetUnderlyingBody(body io.ReadCloser) +} + +func NewCompressReader(body io.ReadCloser, contentEncoding string) CompressReader { + switch contentEncoding { + case "gzip": + return NewGzipReader(body) + case "deflate": + return NewDeflateReader(body) + case "br": + return NewBrotliReader(body) + case "zstd": + return NewZstdReader(body) + } + return nil +} diff --git a/internal/http3/zstd_reader.go b/internal/compress/zstd_reader.go similarity index 72% rename from internal/http3/zstd_reader.go rename to internal/compress/zstd_reader.go index b2116c33..5e0b68d9 100644 --- a/internal/http3/zstd_reader.go +++ b/internal/compress/zstd_reader.go @@ -1,8 +1,9 @@ -package http3 +package compress import ( - "github.com/klauspost/compress/zstd" "io" + + "github.com/klauspost/compress/zstd" ) type ZstdReader struct { @@ -11,7 +12,7 @@ type ZstdReader struct { zerr error // sticky error } -func newZstdReader(body io.ReadCloser) io.ReadCloser { +func NewZstdReader(body io.ReadCloser) *ZstdReader { return &ZstdReader{Body: body} } @@ -35,3 +36,11 @@ func (zr *ZstdReader) Close() error { } return zr.Body.Close() } + +func (zr *ZstdReader) GetUnderlyingBody() io.ReadCloser { + return zr.Body +} + +func (zr *ZstdReader) SetUnderlyingBody(body io.ReadCloser) { + zr.Body = body +} diff --git a/internal/http2/transport.go b/internal/http2/transport.go index fd3eaf76..483fc9c9 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -9,14 +9,12 @@ package http2 import ( "bufio" "bytes" - "compress/gzip" "context" "crypto/rand" "crypto/tls" "errors" "fmt" "io" - "io/fs" "log" "math" "math/bits" @@ -40,6 +38,7 @@ import ( "github.com/imroc/req/v3/http2" "github.com/imroc/req/v3/internal/ascii" "github.com/imroc/req/v3/internal/common" + "github.com/imroc/req/v3/internal/compress" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/header" "github.com/imroc/req/v3/internal/netutil" @@ -2568,8 +2567,17 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") res.ContentLength = -1 - res.Body = &GzipReader{Body: res.Body} + res.Body = compress.NewGzipReader(res.Body) res.Uncompressed = true + } else if cs.cc.t.AutoDecompression { + contentEncoding := res.Header.Get("Content-Encoding") + if contentEncoding != "" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Uncompressed = true + res.Body = compress.NewCompressReader(res.Body, contentEncoding) + } } return res, nil @@ -3145,37 +3153,6 @@ func (rt erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { return nil, rt.err } -// GzipReader wraps a response body so it can lazily -// call gzip.NewReader on the first call to Read -type GzipReader struct { - _ incomparable - Body io.ReadCloser // underlying Response.Body - zr *gzip.Reader // lazily-initialized gzip reader - zerr error // sticky error -} - -func (gz *GzipReader) Read(p []byte) (n int, err error) { - if gz.zerr != nil { - return 0, gz.zerr - } - if gz.zr == nil { - gz.zr, err = gzip.NewReader(gz.Body) - if err != nil { - gz.zerr = err - return 0, err - } - } - return gz.zr.Read(p) -} - -func (gz *GzipReader) Close() error { - if err := gz.Body.Close(); err != nil { - return err - } - gz.zerr = fs.ErrClosed - return nil -} - // isConnectionCloseRequest reports whether req should use its own // connection for a single request and then close the connection. func isConnectionCloseRequest(req *http.Request) bool { diff --git a/internal/http3/client.go b/internal/http3/client.go index 106d138b..92f16c6a 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -13,6 +13,7 @@ import ( "sync/atomic" "time" + "github.com/imroc/req/v3/internal/compress" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/quic-go/quicvarint" "github.com/imroc/req/v3/internal/transport" @@ -498,36 +499,16 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") res.ContentLength = -1 - res.Body = newGzipReader(respBody) + res.Body = compress.NewGzipReader(respBody) res.Uncompressed = true } else if c.opt.AutoDecompression { - switch res.Header.Get("Content-Encoding") { - case "gzip": + contentEncoding := res.Header.Get("Content-Encoding") + if contentEncoding != "" { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") res.ContentLength = -1 - res.Body = newGzipReader(respBody) res.Uncompressed = true - case "deflate": - res.Header.Del("Content-Encoding") - res.Header.Del("Content-Length") - res.ContentLength = -1 - res.Body = newDeflateReader(respBody) - res.Uncompressed = true - case "br": - res.Header.Del("Content-Encoding") - res.Header.Del("Content-Length") - res.ContentLength = -1 - res.Body = newBrotliReader(respBody) - res.Uncompressed = true - case "zstd": - res.Header.Del("Content-Encoding") - res.Header.Del("Content-Length") - res.ContentLength = -1 - res.Body = newZstdReader(respBody) - res.Uncompressed = true - default: - res.Uncompressed = false + res.Body = compress.NewCompressReader(respBody, contentEncoding) } } else { res.Body = respBody diff --git a/transport.go b/transport.go index a1f847b0..b2bd31fd 100644 --- a/transport.go +++ b/transport.go @@ -35,6 +35,7 @@ import ( "github.com/imroc/req/v3/internal/altsvcutil" "github.com/imroc/req/v3/internal/ascii" "github.com/imroc/req/v3/internal/common" + "github.com/imroc/req/v3/internal/compress" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/header" h2internal "github.com/imroc/req/v3/internal/http2" @@ -681,10 +682,8 @@ func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFu switch b := res.Body.(type) { case *gzipReader: b.body.body = wrap(b.body.body) - case *h2internal.GzipReader: - b.Body = wrap(b.Body) - case *http3.GzipReader: - b.Body = wrap(b.Body) + case compress.CompressReader: + b.SetUnderlyingBody(wrap(b.GetUnderlyingBody())) default: res.Body = wrap(res.Body) } @@ -2731,6 +2730,15 @@ func (pc *persistConn) readLoop() { resp.Header.Del("Content-Length") resp.ContentLength = -1 resp.Uncompressed = true + } else if pc.t.AutoDecompression { + contentEncoding := resp.Header.Get("Content-Encoding") + if contentEncoding != "" { + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + resp.Uncompressed = true + resp.Body = compress.NewCompressReader(resp.Body, contentEncoding) + } } select { From 88a68fe748eb0a7ca09e8f4e71d54d97482c103a Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 13 Jun 2024 11:23:25 +0800 Subject: [PATCH 813/843] remove QQ group in README --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 822a6651..16c2d0c9 100644 --- a/README.md +++ b/README.md @@ -501,7 +501,6 @@ If you have questions, feel free to reach out to us in the following ways: * [Github Discussion](https://github.com/imroc/req/discussions) * [Slack](https://imroc-req.slack.com/archives/C03UFPGSNC8) | [Join](https://slack.req.cool/) -* QQ Group (Chinese): 621411351 - ## Sponsors From 9479f861707635f81d57e859bdfb99b25116e96e Mon Sep 17 00:00:00 2001 From: Qiming Xu <33349132+xqm32@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:43:41 +0800 Subject: [PATCH 814/843] Fix typo "hava" to "have" --- response.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/response.go b/response.go index bd2f9d7c..65b9f521 100644 --- a/response.go +++ b/response.go @@ -195,7 +195,7 @@ func (r *Response) Into(v interface{}) error { return r.Unmarshal(v) } -// Bytes return the response body as []bytes that hava already been read, could be +// Bytes return the response body as []bytes that have already been read, could be // nil if not read, the following cases are already read: // 1. `Request.SetResult` or `Request.SetError` is called. // 2. `Client.DisableAutoReadResponse` and `Request.DisableAutoReadResponse` is not @@ -204,7 +204,7 @@ func (r *Response) Bytes() []byte { return r.body } -// String returns the response body as string that hava already been read, could be +// String returns the response body as string that have already been read, could be // nil if not read, the following cases are already read: // 1. `Request.SetResult` or `Request.SetError` is called. // 2. `Client.DisableAutoReadResponse` and `Request.DisableAutoReadResponse` is not From b46a0977df4a59c829aa38768b157afc644ba595 Mon Sep 17 00:00:00 2001 From: xiaok29 <1526783667@qq.com> Date: Thu, 5 Sep 2024 16:10:23 +0800 Subject: [PATCH 815/843] =?UTF-8?q?=F0=9F=A6=84=20refactor:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete unnecessary return --- internal/altsvcutil/altsvcutil.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/altsvcutil/altsvcutil.go b/internal/altsvcutil/altsvcutil.go index 8aa160a0..84978d7e 100644 --- a/internal/altsvcutil/altsvcutil.go +++ b/internal/altsvcutil/altsvcutil.go @@ -79,7 +79,6 @@ func (p *altAvcParser) Parse() (as []*altsvc.AltSvc, err error) { } } } - return } func (p *altAvcParser) parseKv() (key, value string, haveNextField bool, err error) { From 9454280a852c5875f13609e77ebc6ed4d68cde66 Mon Sep 17 00:00:00 2001 From: xiaok29 <1526783667@qq.com> Date: Thu, 5 Sep 2024 16:27:42 +0800 Subject: [PATCH 816/843] =?UTF-8?q?=F0=9F=A6=84=20refactor:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete redundant judgment --- middleware.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/middleware.go b/middleware.go index 59cd46f5..22d5504d 100644 --- a/middleware.go +++ b/middleware.go @@ -490,9 +490,9 @@ func parseRequestHeader(c *Client, r *Request) error { } func parseRequestCookie(c *Client, r *Request) error { - if len(c.Cookies) == 0 || r.RetryAttempt > 0 { - return nil + if len(c.Cookies) > 0 || r.RetryAttempt <= 0 { + r.Cookies = append(r.Cookies, c.Cookies...) } - r.Cookies = append(r.Cookies, c.Cookies...) + return nil } From 373b0b233d06b834c4834e85d1a107ec7535477c Mon Sep 17 00:00:00 2001 From: xiaok29 <1526783667@qq.com> Date: Thu, 5 Sep 2024 16:29:36 +0800 Subject: [PATCH 817/843] =?UTF-8?q?=F0=9F=A6=84=20refactor:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete redundant judgment --- client_test.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/client_test.go b/client_test.go index ccf62f57..b87745cb 100644 --- a/client_test.go +++ b/client_test.go @@ -242,10 +242,7 @@ func TestAutoDecode(t *testing.T) { assertSuccess(t, resp, err) tests.AssertEqual(t, "我是roc", resp.String()) resp, err = c.SetAutoDecodeContentTypeFunc(func(contentType string) bool { - if strings.Contains(contentType, "text") { - return true - } - return false + return strings.Contains(contentType, "text") }).R().Get("/gbk") assertSuccess(t, resp, err) tests.AssertEqual(t, "我是roc", resp.String()) From 545f631981f33bbe8172f8d0f401360f7a14faf2 Mon Sep 17 00:00:00 2001 From: xiaok29 <1526783667@qq.com> Date: Thu, 5 Sep 2024 16:30:52 +0800 Subject: [PATCH 818/843] =?UTF-8?q?=F0=9F=A6=84=20refactor:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete unnecessary return --- client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/client.go b/client.go index 2631e6ce..dc7016d1 100644 --- a/client.go +++ b/client.go @@ -278,7 +278,6 @@ func (c *Client) appendRootCertData(data []byte) { config.RootCAs = x509.NewCertPool() } config.RootCAs.AppendCertsFromPEM(data) - return } // SetRootCertFromString set root certificates from string. From 4b43a2beb5c479d348efb077463ee0a411c538f7 Mon Sep 17 00:00:00 2001 From: xiaok29 <1526783667@qq.com> Date: Thu, 5 Sep 2024 16:44:46 +0800 Subject: [PATCH 819/843] =?UTF-8?q?=F0=9F=A6=84=20refactor:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete unnecessary return --- middleware.go | 1 - 1 file changed, 1 deletion(-) diff --git a/middleware.go b/middleware.go index 22d5504d..c310cd3d 100644 --- a/middleware.go +++ b/middleware.go @@ -256,7 +256,6 @@ func unmarshalBody(c *Client, r *Response, v interface{}) (err error) { } return c.jsonUnmarshal(body, v) } - return } func defaultResultStateChecker(resp *Response) ResultState { From 004636e1ffa5d28021a57bf60bb958635b2fcdaa Mon Sep 17 00:00:00 2001 From: xiaok29 <1526783667@qq.com> Date: Thu, 5 Sep 2024 16:47:46 +0800 Subject: [PATCH 820/843] =?UTF-8?q?=F0=9F=8E=88=20perf:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace map[string]bool with map[string]struct{} to reduce memory footprint --- redirect.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redirect.go b/redirect.go index 364a55eb..ed395845 100644 --- a/redirect.go +++ b/redirect.go @@ -55,9 +55,9 @@ func SameHostRedirectPolicy() RedirectPolicy { // AllowedHostRedirectPolicy allows redirect only if the redirected host // match one of the host that specified. func AllowedHostRedirectPolicy(hosts ...string) RedirectPolicy { - m := make(map[string]bool) + m := make(map[string]struct{}) for _, h := range hosts { - m[strings.ToLower(getHostname(h))] = true + m[strings.ToLower(getHostname(h))] = struct{}{} } return func(req *http.Request, via []*http.Request) error { From abca3e38b6aef91db1cd0b891c70cf91c80cb2fb Mon Sep 17 00:00:00 2001 From: xiaok29 <1526783667@qq.com> Date: Thu, 5 Sep 2024 16:49:09 +0800 Subject: [PATCH 821/843] =?UTF-8?q?=F0=9F=8E=88=20perf:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace map[string]bool with map[string]struct{} to reduce memory footprint --- redirect.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redirect.go b/redirect.go index ed395845..f1cc4331 100644 --- a/redirect.go +++ b/redirect.go @@ -72,9 +72,9 @@ func AllowedHostRedirectPolicy(hosts ...string) RedirectPolicy { // AllowedDomainRedirectPolicy allows redirect only if the redirected domain // match one of the domain that specified. func AllowedDomainRedirectPolicy(hosts ...string) RedirectPolicy { - domains := make(map[string]bool) + domains := make(map[string]struct{}) for _, h := range hosts { - domains[strings.ToLower(getDomain(h))] = true + domains[strings.ToLower(getDomain(h))] = struct{}{} } return func(req *http.Request, via []*http.Request) error { From 79007b2b97d25d3b4a6e4966eaa94beb471426f4 Mon Sep 17 00:00:00 2001 From: Igor Agapie Date: Mon, 9 Sep 2024 12:07:42 +0200 Subject: [PATCH 822/843] fix: issue imroc/req#384 --- go.mod | 37 +++++++++-------- go.sum | 88 +++++++++++++++++++--------------------- internal/http3/client.go | 13 +++--- internal/http3/server.go | 2 +- 4 files changed, 70 insertions(+), 70 deletions(-) diff --git a/go.mod b/go.mod index 85849bbc..ed1d8947 100644 --- a/go.mod +++ b/go.mod @@ -1,28 +1,31 @@ module github.com/imroc/req/v3 -go 1.21 +go 1.22.0 + +toolchain go1.23.1 require ( + github.com/andybalholm/brotli v1.1.0 github.com/hashicorp/go-multierror v1.1.1 - github.com/quic-go/qpack v0.4.0 - github.com/quic-go/quic-go v0.41.0 - github.com/refraction-networking/utls v1.6.3 - golang.org/x/net v0.22.0 - golang.org/x/text v0.14.0 + github.com/klauspost/compress v1.17.9 + github.com/quic-go/qpack v0.5.1 + github.com/quic-go/quic-go v0.47.0 + github.com/refraction-networking/utls v1.6.7 + golang.org/x/net v0.29.0 + golang.org/x/text v0.18.0 ) require ( - github.com/andybalholm/brotli v1.1.0 // indirect - github.com/cloudflare/circl v1.3.7 // indirect - github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect - github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 // indirect + github.com/cloudflare/circl v1.4.0 // indirect + github.com/go-task/slim-sprig/v3 v3.0.0 // indirect + github.com/google/pprof v0.0.0-20240903155634-a8630aee4ab9 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/klauspost/compress v1.17.8 // indirect - github.com/onsi/ginkgo/v2 v2.16.0 // indirect + github.com/onsi/ginkgo/v2 v2.20.2 // indirect go.uber.org/mock v0.4.0 // indirect - golang.org/x/crypto v0.21.0 // indirect - golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect - golang.org/x/mod v0.16.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/tools v0.19.0 // indirect + golang.org/x/crypto v0.27.0 // indirect + golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e // indirect + golang.org/x/mod v0.21.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/tools v0.24.0 // indirect ) diff --git a/go.sum b/go.sum index c9e55489..4147995e 100644 --- a/go.sum +++ b/go.sum @@ -1,63 +1,59 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= -github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= -github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/cloudflare/circl v1.4.0 h1:BV7h5MgrktNzytKmWjpOtdYrf0lkkbF8YMlBGPhJQrY= +github.com/cloudflare/circl v1.4.0/go.mod h1:PDRU+oXvdD7KCtgKxW95M5Z8BpSCJXQORiZFnBQS5QU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= -github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 h1:y3N7Bm7Y9/CtpiVkw/ZWj6lSlDF3F74SfKwfTCer72Q= -github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= +github.com/google/pprof v0.0.0-20240903155634-a8630aee4ab9 h1:q5g0N9eal4bmJwXHC5z0QCKs8qhS35hFfq0BAYsIwZI= +github.com/google/pprof v0.0.0-20240903155634-a8630aee4ab9/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= -github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= -github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= -github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= -github.com/onsi/ginkgo/v2 v2.16.0 h1:7q1w9frJDzninhXxjZd+Y/x54XNjG/UlRLIYPZafsPM= -github.com/onsi/ginkgo/v2 v2.16.0/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs= -github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8= -github.com/onsi/gomega v1.30.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/onsi/ginkgo/v2 v2.20.2 h1:7NVCeyIWROIAheY21RLS+3j2bb52W0W82tkberYytp4= +github.com/onsi/ginkgo/v2 v2.20.2/go.mod h1:K9gyxPIlb+aIvnZ8bd9Ak+YP18w3APlR+5coaZoE2ag= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= -github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/quic-go v0.41.0 h1:aD8MmHfgqTURWNJy48IYFg2OnxwHT3JL7ahGs73lb4k= -github.com/quic-go/quic-go v0.41.0/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA= -github.com/refraction-networking/utls v1.6.3 h1:MFOfRN35sSx6K5AZNIoESsBuBxS2LCgRilRIdHb6fDc= -github.com/refraction-networking/utls v1.6.3/go.mod h1:yil9+7qSl+gBwJqztoQseO6Pr3h62pQoY1lXiNR/FPs= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= +github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= +github.com/quic-go/quic-go v0.47.0 h1:yXs3v7r2bm1wmPTYNLKAAJTHMYkPEsfYJmTazXrCZ7Y= +github.com/quic-go/quic-go v0.47.0/go.mod h1:3bCapYsJvXGZcipOHuu7plYtaV6tnF+z7wIFsU0WK9E= +github.com/refraction-networking/utls v1.6.7 h1:zVJ7sP1dJx/WtVuITug3qYUq034cDq9B2MR1K67ULZM= +github.com/refraction-networking/utls v1.6.7/go.mod h1:BC3O4vQzye5hqpmDTWUqi4P5DDhzJfkV1tdqtawQIH0= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= -golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= -golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= -golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= -golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= -google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= -google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e h1:I88y4caeGeuDQxgdoFPUq097j7kNfw6uvuiNxUBfcBk= +golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/http3/client.go b/internal/http3/client.go index 92f16c6a..8c85cf63 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -13,12 +13,13 @@ import ( "sync/atomic" "time" + "github.com/quic-go/qpack" + "github.com/quic-go/quic-go" + "github.com/imroc/req/v3/internal/compress" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/quic-go/quicvarint" "github.com/imroc/req/v3/internal/transport" - "github.com/quic-go/qpack" - "github.com/quic-go/quic-go" ) // MethodGet0RTT allows a GET request to be sent using 0-RTT. @@ -30,9 +31,9 @@ const ( ) const ( - VersionDraft29 quic.VersionNumber = 0xff00001d - Version1 quic.VersionNumber = 0x1 - Version2 quic.VersionNumber = 0x6b3343cf + VersionDraft29 quic.Version = 0xff00001d + Version1 quic.Version = 0x1 + Version2 quic.Version = 0x6b3343cf ) var defaultQuicConfig = &quic.Config{ @@ -82,7 +83,7 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con } if len(conf.Versions) == 0 { conf = conf.Clone() - conf.Versions = []quic.VersionNumber{Version1} + conf.Versions = []quic.Version{Version1} } if len(conf.Versions) != 1 { return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") diff --git a/internal/http3/server.go b/internal/http3/server.go index 2b9d8658..9f94e7b5 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -19,7 +19,7 @@ const ( streamTypeQPACKDecoderStream = 3 ) -func versionToALPN(v quic.VersionNumber) string { +func versionToALPN(v quic.Version) string { switch v { case Version1, Version2: return nextProtoH3 From fcc942e9bc45a5d0d6b5077e25bfe19c8646c2f1 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 10 Sep 2024 09:40:08 +0800 Subject: [PATCH 823/843] update go version --- .github/workflows/ci.yml | 2 +- README.md | 2 +- go.mod | 2 -- transport.go | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 665084ff..de1e1701 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: test: strategy: matrix: - go: [ '1.21.x', '1.22.x' ] + go: [ '1.22.x', '1.23.x' ] os: [ ubuntu-latest ] runs-on: ${{ matrix.os }} steps: diff --git a/README.md b/README.md index 16c2d0c9..c8191fd2 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ Full documentation is available on the official website: https://req.cool. **Install** -You first need [Go](https://go.dev/) installed (version 1.20+ is required), then you can use the below Go command to install req: +You first need [Go](https://go.dev/) installed (version 1.22+ is required), then you can use the below Go command to install req: ``` sh go get github.com/imroc/req/v3 diff --git a/go.mod b/go.mod index ed1d8947..ef6f380a 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,6 @@ module github.com/imroc/req/v3 go 1.22.0 -toolchain go1.23.1 - require ( github.com/andybalholm/brotli v1.1.0 github.com/hashicorp/go-multierror v1.1.1 diff --git a/transport.go b/transport.go index b2bd31fd..69c8932f 100644 --- a/transport.go +++ b/transport.go @@ -583,7 +583,7 @@ func (t *Transport) EnableHTTP3() { } return } - if !(minorVersion >= 20 && minorVersion <= 21) { + if minorVersion < 22 || minorVersion > 23 { if t.Debugf != nil { t.Debugf("%s is not support http3", v) } From c0fb6d6660899e854f1a5c129e066c80c3713c69 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 10 Sep 2024 19:42:01 +0800 Subject: [PATCH 824/843] align upstream quic-go to v0.47.0 --- internal/dump/dump.go | 1 - internal/http3/body.go | 135 +++--- internal/http3/client.go | 591 ++++++++---------------- internal/http3/conn.go | 328 +++++++++++++ internal/http3/datagram.go | 98 ++++ internal/http3/frames.go | 59 ++- internal/http3/headers.go | 184 ++++++-- internal/http3/http_stream.go | 260 +++++++++-- internal/http3/protocol.go | 119 +++++ internal/http3/request_writer.go | 36 +- internal/http3/roundtrip.go | 335 +++++++++----- internal/http3/server.go | 34 +- internal/http3/state_tracking_stream.go | 116 +++++ internal/quic-go/quicvarint/varint.go | 41 +- transport.go | 2 +- 15 files changed, 1614 insertions(+), 725 deletions(-) create mode 100644 internal/http3/conn.go create mode 100644 internal/http3/datagram.go create mode 100644 internal/http3/protocol.go create mode 100644 internal/http3/state_tracking_stream.go diff --git a/internal/dump/dump.go b/internal/dump/dump.go index 231c4beb..0f71da7e 100644 --- a/internal/dump/dump.go +++ b/internal/dump/dump.go @@ -138,7 +138,6 @@ func NewDumper(opt Options) *Dumper { func (d *Dumper) SetOptions(opt Options) { d.Options = opt - return } func (d *Dumper) Clone() *Dumper { diff --git a/internal/http3/body.go b/internal/http3/body.go index 63ff4366..fa023ce4 100644 --- a/internal/http3/body.go +++ b/internal/http3/body.go @@ -2,68 +2,68 @@ package http3 import ( "context" + "errors" "io" - "net" "github.com/quic-go/quic-go" ) -// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by: -// * for the server: the http.Request.Body -// * for the client: the http.Response.Body -// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set. -// When a stream is taken over, it's the caller's responsibility to close the stream. -type HTTPStreamer interface { - HTTPStream() Stream -} - -type StreamCreator interface { - // Context returns a context that is cancelled when the underlying connection is closed. - Context() context.Context - OpenStream() (quic.Stream, error) - OpenStreamSync(context.Context) (quic.Stream, error) - OpenUniStream() (quic.SendStream, error) - OpenUniStreamSync(context.Context) (quic.SendStream, error) - LocalAddr() net.Addr - RemoteAddr() net.Addr - ConnectionState() quic.ConnectionState -} - -var _ StreamCreator = quic.Connection(nil) - // A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body. // It is used by WebTransport to create WebTransport streams after a session has been established. type Hijacker interface { - StreamCreator() StreamCreator + Connection() Connection } -// The body of a http.Request or http.Response. +var errTooMuchData = errors.New("peer sent too much data") + +// The body is used in the requestBody (for a http.Request) and the responseBody (for a http.Response). type body struct { - str quic.Stream + str *stream - wasHijacked bool // set when HTTPStream is called + remainingContentLength int64 + violatedContentLength bool + hasContentLength bool } -var ( - _ io.ReadCloser = &body{} - _ HTTPStreamer = &body{} -) - -func newRequestBody(str Stream) *body { - return &body{str: str} +func newBody(str *stream, contentLength int64) *body { + b := &body{str: str} + if contentLength >= 0 { + b.hasContentLength = true + b.remainingContentLength = contentLength + } + return b } -func (r *body) HTTPStream() Stream { - r.wasHijacked = true - return r.str -} +func (r *body) StreamID() quic.StreamID { return r.str.StreamID() } -func (r *body) wasStreamHijacked() bool { - return r.wasHijacked +func (r *body) checkContentLengthViolation() error { + if !r.hasContentLength { + return nil + } + if r.remainingContentLength < 0 || r.remainingContentLength == 0 && r.str.hasMoreData() { + if !r.violatedContentLength { + r.str.CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) + r.str.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) + r.violatedContentLength = true + } + return errTooMuchData + } + return nil } func (r *body) Read(b []byte) (int, error) { - return r.str.Read(b) + if err := r.checkContentLengthViolation(); err != nil { + return 0, err + } + if r.hasContentLength { + b = b[:min(int64(len(b)), r.remainingContentLength)] + } + n, err := r.str.Read(b) + r.remainingContentLength -= int64(n) + if err := r.checkContentLengthViolation(); err != nil { + return n, err + } + return n, maybeReplaceError(err) } func (r *body) Close() error { @@ -71,9 +71,26 @@ func (r *body) Close() error { return nil } -type hijackableBody struct { +type requestBody struct { body - conn quic.Connection // only needed to implement Hijacker + connCtx context.Context + rcvdSettings <-chan struct{} + getSettings func() *Settings +} + +var _ io.ReadCloser = &requestBody{} + +func newRequestBody(str *stream, contentLength int64, connCtx context.Context, rcvdSettings <-chan struct{}, getSettings func() *Settings) *requestBody { + return &requestBody{ + body: *newBody(str, contentLength), + connCtx: connCtx, + rcvdSettings: rcvdSettings, + getSettings: getSettings, + } +} + +type hijackableBody struct { + body body // only set for the http.Response // The channel is closed when the user is done with this response: @@ -82,31 +99,21 @@ type hijackableBody struct { reqDoneClosed bool } -var ( - _ Hijacker = &hijackableBody{} - _ HTTPStreamer = &hijackableBody{} -) +var _ io.ReadCloser = &hijackableBody{} -func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody { +func newResponseBody(str *stream, contentLength int64, done chan<- struct{}) *hijackableBody { return &hijackableBody{ - body: body{ - str: str, - }, + body: *newBody(str, contentLength), reqDone: done, - conn: conn, } } -func (r *hijackableBody) StreamCreator() StreamCreator { - return r.conn -} - func (r *hijackableBody) Read(b []byte) (int, error) { - n, err := r.str.Read(b) + n, err := r.body.Read(b) if err != nil { r.requestDone() } - return n, err + return n, maybeReplaceError(err) } func (r *hijackableBody) requestDone() { @@ -119,17 +126,9 @@ func (r *hijackableBody) requestDone() { r.reqDoneClosed = true } -func (r *body) StreamID() quic.StreamID { - return r.str.StreamID() -} - func (r *hijackableBody) Close() error { r.requestDone() // If the EOF was read, CancelRead() is a no-op. - r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) + r.body.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) return nil } - -func (r *hijackableBody) HTTPStream() Stream { - return r.str -} diff --git a/internal/http3/client.go b/internal/http3/client.go index 8c85cf63..664f0f76 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -2,38 +2,34 @@ package http3 import ( "context" - "crypto/tls" "errors" - "fmt" "io" - "net" + "log/slog" "net/http" - "strconv" + "net/http/httptrace" + "net/textproto" "sync" - "sync/atomic" "time" "github.com/quic-go/qpack" "github.com/quic-go/quic-go" - "github.com/imroc/req/v3/internal/compress" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/quic-go/quicvarint" "github.com/imroc/req/v3/internal/transport" ) -// MethodGet0RTT allows a GET request to be sent using 0-RTT. -// Note that 0-RTT data doesn't provide replay protection. -const MethodGet0RTT = "GET_0RTT" - const ( - defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB + // MethodGet0RTT allows a GET request to be sent using 0-RTT. + // Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests. + MethodGet0RTT = "GET_0RTT" + // MethodHead0RTT allows a HEAD request to be sent using 0-RTT. + // Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests. + MethodHead0RTT = "HEAD_0RTT" ) const ( - VersionDraft29 quic.Version = 0xff00001d - Version1 quic.Version = 0x1 - Version2 quic.Version = 0x6b3343cf + defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB ) var defaultQuicConfig = &quic.Config{ @@ -41,119 +37,70 @@ var defaultQuicConfig = &quic.Config{ KeepAlivePeriod: 10 * time.Second, } -type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) +// SingleDestinationRoundTripper is an HTTP/3 client doing requests to a single remote server. +type SingleDestinationRoundTripper struct { + *transport.Options -var dialAddr dialFunc = quic.DialAddrEarly + Connection quic.Connection -type roundTripperOpts struct { - DisableCompression bool - EnableDatagram bool - MaxHeaderBytes int64 + // Enable support for HTTP/3 datagrams (RFC 9297). + // If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams. + EnableDatagrams bool + + // Additional HTTP/3 settings. + // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). AdditionalSettings map[uint64]uint64 - StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) - UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) - dump *dump.Dumper -} + StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error) + UniStreamHijacker func(ServerStreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool) -// client is a HTTP3 client doing requests -type client struct { - tlsConf *tls.Config - config *quic.Config - opts *roundTripperOpts + // MaxResponseHeaderBytes specifies a limit on how many response bytes are + // allowed in the server's response header. + // Zero means to use a default limit. + MaxResponseHeaderBytes int64 - dialOnce sync.Once - dialer dialFunc - handshakeErr error + Logger *slog.Logger + initOnce sync.Once + hconn *connection requestWriter *requestWriter - - decoder *qpack.Decoder - - hostname string - conn atomic.Pointer[quic.EarlyConnection] - - opt *transport.Options + decoder *qpack.Decoder } -var _ roundTripCloser = &client{} +var _ http.RoundTripper = &SingleDestinationRoundTripper{} -func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc, opt *transport.Options) (roundTripCloser, error) { - if conf == nil { - conf = defaultQuicConfig.Clone() - } - if len(conf.Versions) == 0 { - conf = conf.Clone() - conf.Versions = []quic.Version{Version1} - } - if len(conf.Versions) != 1 { - return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") - } - if conf.MaxIncomingStreams == 0 { - conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams - } - conf.EnableDatagrams = opts.EnableDatagram - var debugf func(format string, v ...interface{}) - if opt != nil && opt.Debugf != nil { - debugf = opt.Debugf - } - - if tlsConf == nil { - tlsConf = &tls.Config{} - } else { - tlsConf = tlsConf.Clone() - } - if tlsConf.ServerName == "" { - sni, _, err := net.SplitHostPort(hostname) - if err != nil { - // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. - sni = hostname - } - tlsConf.ServerName = sni - } - // Replace existing ALPNs by H3 - tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} - - return &client{ - hostname: authorityAddr("https", hostname), - tlsConf: tlsConf, - requestWriter: newRequestWriter(debugf), - decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), - config: conf, - opts: opts, - dialer: dialer, - opt: opt, - }, nil +func (c *SingleDestinationRoundTripper) Start() Connection { + c.initOnce.Do(func() { c.init() }) + return c.hconn } -func (c *client) dial(ctx context.Context) error { - var err error - var conn quic.EarlyConnection - if c.dialer != nil { - conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) - } else { - conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) - } - if err != nil { - return err - } - c.conn.Store(&conn) - +func (c *SingleDestinationRoundTripper) init() { + c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {}) + c.requestWriter = newRequestWriter() + c.hconn = newConnection( + c.Connection.Context(), + c.Connection, + c.EnableDatagrams, + PerspectiveClient, + c.Logger, + 0, + c.Options, + ) // send the SETTINGs frame, using 0-RTT data, if possible go func() { - if err := c.setupConn(conn); err != nil { - c.opt.Debugf("setting up http3 connection failed: %s", err) - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") + if err := c.setupConn(c.hconn); err != nil { + if c.Logger != nil { + c.Logger.Debug("Setting up connection failed", "error", err) + } + c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") } }() - - if c.opts.StreamHijacker != nil { - go c.handleBidirectionalStreams(conn) + if c.StreamHijacker != nil { + go c.handleBidirectionalStreams() } - go c.handleUnidirectionalStreams(conn) - return nil + go c.hconn.HandleUnidirectionalStreams(c.UniStreamHijacker) } -func (c *client) setupConn(conn quic.EarlyConnection) error { +func (c *SingleDestinationRoundTripper) setupConn(conn *connection) error { // open the control stream str, err := conn.OpenUniStream() if err != nil { @@ -162,108 +109,54 @@ func (c *client) setupConn(conn quic.EarlyConnection) error { b := make([]byte, 0, 64) b = quicvarint.Append(b, streamTypeControlStream) // send the SETTINGS frame - b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b) + b = (&settingsFrame{Datagram: c.EnableDatagrams, Other: c.AdditionalSettings}).Append(b) _, err = str.Write(b) return err } -func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) { +func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() { for { - str, err := conn.AcceptStream(context.Background()) + str, err := c.hconn.AcceptStream(context.Background()) if err != nil { - c.opt.Debugf("accepting bidirectional stream failed: %s", err) - return - } - go func(str quic.Stream) { - _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) { - return c.opts.StreamHijacker(ft, conn, str, e) - }) - if err == errHijacked { - return + if c.Logger != nil { + c.Logger.Debug("accepting bidirectional stream failed", "error", err) } - if err != nil { - c.opt.Debugf("error handling stream: %s", err) - } - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") - }(str) - } -} - -func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { - for { - str, err := conn.AcceptUniStream(context.Background()) - if err != nil { - c.opt.Debugf("accepting unidirectional stream failed: %s", err) return } - - go func(str quic.ReceiveStream) { - streamType, err := quicvarint.Read(quicvarint.NewReader(str)) - if err != nil { - if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) { - return - } - c.opt.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) - return - } - // We're only interested in the control stream here. - switch streamType { - case streamTypeControlStream: - case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream: - // Our QPACK implementation doesn't use the dynamic table yet. - // TODO: check that only one stream of each type is opened. - return - case streamTypePushStream: - // We never increased the Push ID, so we don't expect any push streams. - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") - return - default: - if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) { - return - } - str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) + fp := &frameParser{ + r: str, + conn: c.hconn, + unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) { + id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) + return c.StreamHijacker(ft, id, str, e) + }, + } + go func() { + if _, err := fp.ParseNext(); err == errHijacked { return } - f, err := parseNextFrame(str, nil) if err != nil { - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") - return - } - sf, ok := f.(*settingsFrame) - if !ok { - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") - return - } - if !sf.Datagram { - return - } - // If datagram support was enabled on our side as well as on the server side, - // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. - // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). - if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams { - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") + if c.Logger != nil { + c.Logger.Debug("error handling stream", "error", err) + } } - }(str) - } -} - -func (c *client) Close() error { - conn := c.conn.Load() - if conn == nil { - return nil + c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") + }() } - return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") } -func (c *client) maxHeaderBytes() uint64 { - if c.opts.MaxHeaderBytes <= 0 { +func (c *SingleDestinationRoundTripper) maxHeaderBytes() uint64 { + if c.MaxResponseHeaderBytes <= 0 { return defaultMaxResponseHeaderBytes } - return uint64(c.opts.MaxHeaderBytes) + return uint64(c.MaxResponseHeaderBytes) } -func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { - rsp, err := c.roundTripOpt(req, opt) +// RoundTrip executes a request and returns a response +func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + c.initOnce.Do(func() { c.init() }) + + rsp, err := c.roundTrip(req) if err != nil && req.Context().Err() != nil { // if the context was canceled, return the context cancellation error err = req.Context().Err() @@ -271,35 +164,48 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon return rsp, err } -// RoundTripOpt executes a request and returns a response -func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { - if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { - return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) - } - - c.dialOnce.Do(func() { - c.handshakeErr = c.dial(req.Context()) - }) - if c.handshakeErr != nil { - return nil, c.handshakeErr - } - - // At this point, c.conn is guaranteed to be set. - conn := *c.conn.Load() - +func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Response, error) { // Immediately send out this request, if this is a 0-RTT request. - if req.Method == MethodGet0RTT { + switch req.Method { + case MethodGet0RTT: + // don't modify the original request + reqCopy := *req + req = &reqCopy req.Method = http.MethodGet - } else { + case MethodHead0RTT: + // don't modify the original request + reqCopy := *req + req = &reqCopy + req.Method = http.MethodHead + default: // wait for the handshake to complete + earlyConn, ok := c.Connection.(quic.EarlyConnection) + if ok { + select { + case <-earlyConn.HandshakeComplete(): + case <-req.Context().Done(): + return nil, req.Context().Err() + } + } + } + + // It is only possible to send an Extended CONNECT request once the SETTINGS were received. + // See section 3 of RFC 8441. + if isExtendedConnectRequest(req) { + connCtx := c.Connection.Context() + // wait for the server's SETTINGS frame to arrive select { - case <-conn.HandshakeComplete(): - case <-req.Context().Done(): - return nil, req.Context().Err() + case <-c.hconn.ReceivedSettings(): + case <-connCtx.Done(): + return nil, context.Cause(connCtx) + } + if !c.hconn.Settings().EnableExtendedConnect { + return nil, errors.New("http3: server didn't enable Extended CONNECT") } } - str, err := conn.OpenStreamSync(req.Context()) + reqDone := make(chan struct{}) + str, err := c.hconn.openRequestStream(req.Context(), c.requestWriter, reqDone, c.DisableCompression, c.maxHeaderBytes()) if err != nil { return nil, err } @@ -307,7 +213,6 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon // Request Cancellation: // This go routine keeps running even after RoundTripOpt() returns. // It is shut down when the application is done processing the body. - reqDone := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) @@ -319,214 +224,110 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } }() - doneChan := reqDone - if opt.DontCloseRequestStream { - doneChan = nil - } - rsp, rerr := c.doRequest(req, conn, str, opt, doneChan) - if rerr.err != nil { // if any error occurred + rsp, err := c.doRequest(req, str) + if err != nil { // if any error occurred close(reqDone) <-done - if rerr.streamErr != 0 { // if it was a stream error - str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) - } - if rerr.connErr != 0 { // if it was a connection error - var reason string - if rerr.err != nil { - reason = rerr.err.Error() - } - conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) - } - return nil, maybeReplaceError(rerr.err) - + return nil, maybeReplaceError(err) } - if opt.DontCloseRequestStream { - close(reqDone) - <-done + return rsp, maybeReplaceError(err) +} + +func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (RequestStream, error) { + c.initOnce.Do(func() { c.init() }) + + return c.hconn.openRequestStream(ctx, c.requestWriter, nil, c.DisableCompression, c.maxHeaderBytes()) +} + +// cancelingReader reads from the io.Reader. +// It cancels writing on the stream if any error other than io.EOF occurs. +type cancelingReader struct { + r io.Reader + str Stream +} + +func (r *cancelingReader) Read(b []byte) (int, error) { + n, err := r.r.Read(b) + if err != nil && err != io.EOF { + r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) } - return rsp, maybeReplaceError(rerr.err) + return n, err } -func (c *client) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.Dumper) error { +func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.Dumper) error { defer body.Close() - b := make([]byte, bodyCopyBufferSize) - writeData := func(data []byte) error { - if _, err := str.Write(data); err != nil { - return err - } - return nil - } + buf := make([]byte, bodyCopyBufferSize) + sr := &cancelingReader{str: str, r: body} + var w io.Writer = str if len(dumps) > 0 { - writeData = func(data []byte) error { - for _, dump := range dumps { - dump.DumpRequestBody(data) - } - if _, err := str.Write(data); err != nil { - return err - } - return nil + for _, d := range dumps { + w = io.MultiWriter(w, d.RequestBodyOutput()) } } - for { - n, rerr := body.Read(b) - if n == 0 { - if rerr == nil { - continue - } - if rerr == io.EOF { - for _, dump := range dumps { - dump.DumpDefault([]byte("\r\n\r\n")) - } - break - } - } - if err := writeData(b[:n]); err != nil { - return err - } - if rerr != nil { - if rerr == io.EOF { - for _, dump := range dumps { - dump.DumpDefault([]byte("\r\n\r\n")) - } - break - } - str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) - return rerr + writeTail := func() { + for _, d := range dumps { + d.Output().Write([]byte("\r\n\r\n")) } } - return nil + written, err := io.CopyBuffer(w, sr, buf) + if len(dumps) > 0 && err == nil && written > 0 { + writeTail() + } + + return err } -func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) { - var requestGzip bool - if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { - requestGzip = true - } - dumps := dump.GetDumpers(req.Context(), c.opts.dump) - var headerDumps []*dump.Dumper - for _, dump := range dumps { - if dump.RequestHeader() { - headerDumps = append(headerDumps, dump) - } - } - if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip, headerDumps); err != nil { - return nil, newStreamError(ErrCodeInternalError, err) +func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *requestStream) (*http.Response, error) { + if err := str.SendRequestHeader(req); err != nil { + return nil, err } - - if req.Body == nil && !opt.DontCloseRequestStream { + if req.Body == nil { str.Close() - } - - hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") }) - if req.Body != nil { + } else { // send the request body asynchronously go func() { - var bodyDumps []*dump.Dumper - for _, dump := range dumps { - if dump.RequestBody() { - bodyDumps = append(bodyDumps, dump) + dumps := dump.GetDumpers(req.Context(), c.Dump) + if err := c.sendRequestBody(str, req.Body, dumps); err != nil { + if c.Logger != nil { + c.Logger.Debug("error writing request", "error", err) } } - if err := c.sendRequestBody(hstr, req.Body, bodyDumps); err != nil { - c.opt.Debugf("error writing request: %s", err) - } - if !opt.DontCloseRequestStream { - hstr.Close() - } + str.Close() }() } - frame, err := parseNextFrame(str, nil) - if err != nil { - return nil, newStreamError(ErrCodeFrameError, err) - } - hf, ok := frame.(*headersFrame) - if !ok { - return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) - } - if hf.Length > c.maxHeaderBytes() { - return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes())) - } - headerBlock := make([]byte, hf.Length) - if _, err := io.ReadFull(str, headerBlock); err != nil { - return nil, newStreamError(ErrCodeRequestIncomplete, err) - } - var respHeaderDumps []*dump.Dumper - for _, dump := range dumps { - if dump.ResponseHeader() { - respHeaderDumps = append(respHeaderDumps, dump) + // copy from net/http: support 1xx responses + trace := httptrace.ContextClientTrace(req.Context()) + num1xx := 0 // number of informational 1xx headers received + const max1xxResponses = 5 // arbitrary bound on number of informational responses + + var res *http.Response + for { + var err error + res, err = str.ReadResponse() + if err != nil { + return nil, err } - } - hfs, err := c.decoder.DecodeFull(headerBlock) - if len(respHeaderDumps) > 0 { - for _, hf := range hfs { - for _, dump := range respHeaderDumps { - dump.DumpResponseHeader([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) + resCode := res.StatusCode + is1xx := 100 <= resCode && resCode <= 199 + // treat 101 as a terminal status, see https://github.com/golang/go/issues/26161 + is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols + if is1xxNonTerminal { + num1xx++ + if num1xx > max1xxResponses { + return nil, errors.New("http: too many 1xx informational responses") } - } - for _, dump := range respHeaderDumps { - dump.DumpResponseHeader([]byte("\r\n")) - } - } - if err != nil { - // TODO: use the right error code - return nil, newConnError(ErrCodeGeneralProtocolError, err) - } - - res, err := responseFromHeaders(hfs) - if err != nil { - return nil, newStreamError(ErrCodeMessageError, err) - } - res.Request = req - connState := conn.ConnectionState().TLS - res.TLS = &connState - respBody := newResponseBody(hstr, conn, reqDone) - - // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. - _, hasTransferEncoding := res.Header["Transfer-Encoding"] - isInformational := res.StatusCode >= 100 && res.StatusCode < 200 - isNoContent := res.StatusCode == http.StatusNoContent - isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300 - if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect { - res.ContentLength = -1 - if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 { - if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { - res.ContentLength = clen64 + if trace != nil && trace.Got1xxResponse != nil { + if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(res.Header)); err != nil { + return nil, err + } } + continue } + break } - - if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { - res.Header.Del("Content-Encoding") - res.Header.Del("Content-Length") - res.ContentLength = -1 - res.Body = compress.NewGzipReader(respBody) - res.Uncompressed = true - } else if c.opt.AutoDecompression { - contentEncoding := res.Header.Get("Content-Encoding") - if contentEncoding != "" { - res.Header.Del("Content-Encoding") - res.Header.Del("Content-Length") - res.ContentLength = -1 - res.Uncompressed = true - res.Body = compress.NewCompressReader(respBody, contentEncoding) - } - } else { - res.Body = respBody - } - - return res, requestError{} -} - -func (c *client) HandshakeComplete() bool { - conn := c.conn.Load() - if conn == nil { - return false - } - select { - case <-(*conn).HandshakeComplete(): - return true - default: - return false - } + connState := c.hconn.ConnectionState().TLS + res.TLS = &connState + res.Request = req + return res, nil } diff --git a/internal/http3/conn.go b/internal/http3/conn.go new file mode 100644 index 00000000..60f0e259 --- /dev/null +++ b/internal/http3/conn.go @@ -0,0 +1,328 @@ +package http3 + +import ( + "context" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/imroc/req/v3/internal/transport" + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/quicvarint" + + "github.com/quic-go/qpack" +) + +// Connection is an HTTP/3 connection. +// It has all methods from the quic.Connection expect for AcceptStream, AcceptUniStream, +// SendDatagram and ReceiveDatagram. +type Connection interface { + OpenStream() (quic.Stream, error) + OpenStreamSync(context.Context) (quic.Stream, error) + OpenUniStream() (quic.SendStream, error) + OpenUniStreamSync(context.Context) (quic.SendStream, error) + LocalAddr() net.Addr + RemoteAddr() net.Addr + CloseWithError(quic.ApplicationErrorCode, string) error + Context() context.Context + ConnectionState() quic.ConnectionState + + // ReceivedSettings returns a channel that is closed once the client's SETTINGS frame was received. + ReceivedSettings() <-chan struct{} + // Settings returns the settings received on this connection. + Settings() *Settings +} + +type connection struct { + quic.Connection + *transport.Options + ctx context.Context + + perspective Perspective + logger *slog.Logger + + enableDatagrams bool + + decoder *qpack.Decoder + + streamMx sync.Mutex + streams map[quic.StreamID]*datagrammer + + settings *Settings + receivedSettings chan struct{} + + idleTimeout time.Duration + idleTimer *time.Timer +} + +func newConnection( + ctx context.Context, + quicConn quic.Connection, + enableDatagrams bool, + perspective Perspective, + logger *slog.Logger, + idleTimeout time.Duration, + options *transport.Options, +) *connection { + c := &connection{ + ctx: ctx, + Connection: quicConn, + Options: options, + perspective: perspective, + logger: logger, + idleTimeout: idleTimeout, + enableDatagrams: enableDatagrams, + decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), + receivedSettings: make(chan struct{}), + streams: make(map[quic.StreamID]*datagrammer), + } + if idleTimeout > 0 { + c.idleTimer = time.AfterFunc(idleTimeout, c.onIdleTimer) + } + return c +} + +func (c *connection) onIdleTimer() { + c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "idle timeout") +} + +func (c *connection) clearStream(id quic.StreamID) { + c.streamMx.Lock() + defer c.streamMx.Unlock() + + delete(c.streams, id) + if c.idleTimeout > 0 && len(c.streams) == 0 { + c.idleTimer.Reset(c.idleTimeout) + } +} + +func (c *connection) openRequestStream( + ctx context.Context, + requestWriter *requestWriter, + reqDone chan<- struct{}, + disableCompression bool, + maxHeaderBytes uint64, +) (*requestStream, error) { + str, err := c.Connection.OpenStreamSync(ctx) + if err != nil { + return nil, err + } + datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) }) + c.streamMx.Lock() + c.streams[str.StreamID()] = datagrams + c.streamMx.Unlock() + qstr := newStateTrackingStream(str, c, datagrams) + rsp := &http.Response{} + hstr := newStream(qstr, c, datagrams, func(r io.Reader, l uint64) error { + hdr, err := c.decodeTrailers(r, l, maxHeaderBytes) + if err != nil { + return err + } + rsp.Trailer = hdr + return nil + }) + return newRequestStream(ctx, c.Options, hstr, requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes, rsp), nil +} + +func (c *connection) decodeTrailers(r io.Reader, l, maxHeaderBytes uint64) (http.Header, error) { + if l > maxHeaderBytes { + return nil, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", l, maxHeaderBytes) + } + + b := make([]byte, l) + if _, err := io.ReadFull(r, b); err != nil { + return nil, err + } + fields, err := c.decoder.DecodeFull(b) + if err != nil { + return nil, err + } + return parseTrailers(fields) +} + +func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagrammer, error) { + str, err := c.AcceptStream(ctx) + if err != nil { + return nil, nil, err + } + datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) }) + if c.perspective == PerspectiveServer { + strID := str.StreamID() + c.streamMx.Lock() + c.streams[strID] = datagrams + if c.idleTimeout > 0 { + if len(c.streams) == 1 { + c.idleTimer.Stop() + } + } + c.streamMx.Unlock() + str = newStateTrackingStream(str, c, datagrams) + } + return str, datagrams, nil +} + +func (c *connection) CloseWithError(code quic.ApplicationErrorCode, msg string) error { + if c.idleTimer != nil { + c.idleTimer.Stop() + } + return c.Connection.CloseWithError(code, msg) +} + +func (c *connection) HandleUnidirectionalStreams(hijack func(ServerStreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) { + var ( + rcvdControlStr atomic.Bool + rcvdQPACKEncoderStr atomic.Bool + rcvdQPACKDecoderStr atomic.Bool + ) + + for { + str, err := c.Connection.AcceptUniStream(context.Background()) + if err != nil { + if c.logger != nil { + c.logger.Debug("accepting unidirectional stream failed", "error", err) + } + return + } + + go func(str quic.ReceiveStream) { + streamType, err := quicvarint.Read(quicvarint.NewReader(str)) + if err != nil { + id := c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) + if hijack != nil && hijack(ServerStreamType(streamType), id, str, err) { + return + } + if c.logger != nil { + c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err) + } + return + } + // We're only interested in the control stream here. + switch streamType { + case streamTypeControlStream: + case streamTypeQPACKEncoderStream: + if isFirst := rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream") + } + // Our QPACK implementation doesn't use the dynamic table yet. + return + case streamTypeQPACKDecoderStream: + if isFirst := rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream") + } + // Our QPACK implementation doesn't use the dynamic table yet. + return + case streamTypePushStream: + switch c.perspective { + case PerspectiveClient: + // we never increased the Push ID, so we don't expect any push streams + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") + case PerspectiveServer: + // only the server can push + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "") + } + return + default: + if hijack != nil { + if hijack( + ServerStreamType(streamType), + c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID), + str, + nil, + ) { + return + } + } + str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) + return + } + // Only a single control stream is allowed. + if isFirstControlStr := rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") + return + } + fp := &frameParser{conn: c.Connection, r: str} + f, err := fp.ParseNext() + if err != nil { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") + return + } + sf, ok := f.(*settingsFrame) + if !ok { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") + return + } + c.settings = &Settings{ + EnableDatagrams: sf.Datagram, + EnableExtendedConnect: sf.ExtendedConnect, + Other: sf.Other, + } + close(c.receivedSettings) + if !sf.Datagram { + return + } + // If datagram support was enabled on our side as well as on the server side, + // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. + // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). + if c.enableDatagrams && !c.Connection.ConnectionState().SupportsDatagrams { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") + return + } + go func() { + if err := c.receiveDatagrams(); err != nil { + if c.logger != nil { + c.logger.Debug("receiving datagrams failed", "error", err) + } + } + }() + }(str) + } +} + +func (c *connection) sendDatagram(streamID quic.StreamID, b []byte) error { + // TODO: this creates a lot of garbage and an additional copy + data := make([]byte, 0, len(b)+8) + data = quicvarint.Append(data, uint64(streamID/4)) + data = append(data, b...) + return c.Connection.SendDatagram(data) +} + +func (c *connection) receiveDatagrams() error { + for { + b, err := c.Connection.ReceiveDatagram(context.Background()) + if err != nil { + return err + } + quarterStreamID, n, err := quicvarint.Parse(b) + if err != nil { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "") + return fmt.Errorf("could not read quarter stream id: %w", err) + } + if quarterStreamID > maxQuarterStreamID { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "") + return fmt.Errorf("invalid quarter stream id: %w", err) + } + streamID := quic.StreamID(4 * quarterStreamID) + c.streamMx.Lock() + dg, ok := c.streams[streamID] + if !ok { + c.streamMx.Unlock() + return nil + } + c.streamMx.Unlock() + dg.enqueue(b[n:]) + } +} + +// ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received. +func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSettings } + +// Settings returns the settings received on this connection. +// It is only valid to call this function after the channel returned by ReceivedSettings was closed. +func (c *connection) Settings() *Settings { return c.settings } + +func (c *connection) Context() context.Context { return c.ctx } diff --git a/internal/http3/datagram.go b/internal/http3/datagram.go new file mode 100644 index 00000000..6d570e6b --- /dev/null +++ b/internal/http3/datagram.go @@ -0,0 +1,98 @@ +package http3 + +import ( + "context" + "sync" +) + +const maxQuarterStreamID = 1<<60 - 1 + +const streamDatagramQueueLen = 32 + +type datagrammer struct { + sendDatagram func([]byte) error + + hasData chan struct{} + queue [][]byte // TODO: use a ring buffer + + mx sync.Mutex + sendErr error + receiveErr error +} + +func newDatagrammer(sendDatagram func([]byte) error) *datagrammer { + return &datagrammer{ + sendDatagram: sendDatagram, + hasData: make(chan struct{}, 1), + } +} + +func (d *datagrammer) SetReceiveError(err error) { + d.mx.Lock() + defer d.mx.Unlock() + + d.receiveErr = err + d.signalHasData() +} + +func (d *datagrammer) SetSendError(err error) { + d.mx.Lock() + defer d.mx.Unlock() + + d.sendErr = err +} + +func (d *datagrammer) Send(b []byte) error { + d.mx.Lock() + sendErr := d.sendErr + d.mx.Unlock() + if sendErr != nil { + return sendErr + } + + return d.sendDatagram(b) +} + +func (d *datagrammer) signalHasData() { + select { + case d.hasData <- struct{}{}: + default: + } +} + +func (d *datagrammer) enqueue(data []byte) { + d.mx.Lock() + defer d.mx.Unlock() + + if d.receiveErr != nil { + return + } + if len(d.queue) >= streamDatagramQueueLen { + return + } + d.queue = append(d.queue, data) + d.signalHasData() +} + +func (d *datagrammer) Receive(ctx context.Context) ([]byte, error) { +start: + d.mx.Lock() + if len(d.queue) >= 1 { + data := d.queue[0] + d.queue = d.queue[1:] + d.mx.Unlock() + return data, nil + } + if receiveErr := d.receiveErr; receiveErr != nil { + d.mx.Unlock() + return nil, receiveErr + } + d.mx.Unlock() + + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case <-d.hasData: + } + goto start +} diff --git a/internal/http3/frames.go b/internal/http3/frames.go index a3cd88ad..b2d59a52 100644 --- a/internal/http3/frames.go +++ b/internal/http3/frames.go @@ -7,6 +7,7 @@ import ( "io" "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/quic-go/quic-go" ) // FrameType is the frame type of a HTTP/3 frame @@ -18,13 +19,19 @@ type frame interface{} var errHijacked = errors.New("hijacked") -func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) { - qr := quicvarint.NewReader(r) +type frameParser struct { + r io.Reader + conn quic.Connection + unknownFrameHandler unknownFrameHandlerFunc +} + +func (p *frameParser) ParseNext() (frame, error) { + qr := quicvarint.NewReader(p.r) for { t, err := quicvarint.Read(qr) if err != nil { - if unknownFrameHandler != nil { - hijacked, err := unknownFrameHandler(0, err) + if p.unknownFrameHandler != nil { + hijacked, err := p.unknownFrameHandler(0, err) if err != nil { return nil, err } @@ -35,8 +42,8 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f return nil, err } // Call the unknownFrameHandler for frames not defined in the HTTP/3 spec - if t > 0xd && unknownFrameHandler != nil { - hijacked, err := unknownFrameHandler(FrameType(t), nil) + if t > 0xd && p.unknownFrameHandler != nil { + hijacked, err := p.unknownFrameHandler(FrameType(t), nil) if err != nil { return nil, err } @@ -56,11 +63,14 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f case 0x1: return &headersFrame{Length: l}, nil case 0x4: - return parseSettingsFrame(r, l) + return parseSettingsFrame(p.r, l) case 0x3: // CANCEL_PUSH case 0x5: // PUSH_PROMISE case 0x7: // GOAWAY case 0xd: // MAX_PUSH_ID + case 0x2, 0x6, 0x8, 0x9: + p.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") + return nil, fmt.Errorf("http3: reserved frame type: %d", t) } // skip over unknown frames if _, err := io.CopyN(io.Discard, qr, int64(l)); err != nil { @@ -87,11 +97,18 @@ func (f *headersFrame) Append(b []byte) []byte { return quicvarint.Append(b, f.Length) } -const settingDatagram = 0x33 +const ( + // Extended CONNECT, RFC 9220 + settingExtendedConnect = 0x8 + // HTTP Datagrams, RFC 9297 + settingDatagram = 0x33 +) type settingsFrame struct { - Datagram bool - Other map[uint64]uint64 // all settings that we don't explicitly recognize + Datagram bool // HTTP Datagrams, RFC 9297 + ExtendedConnect bool // Extended CONNECT, RFC 9220 + + Other map[uint64]uint64 // all settings that we don't explicitly recognize } func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { @@ -107,7 +124,7 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { } frame := &settingsFrame{} b := bytes.NewReader(buf) - var readDatagram bool + var readDatagram, readExtendedConnect bool for b.Len() > 0 { id, err := quicvarint.Read(b) if err != nil { // should not happen. We allocated the whole frame already. @@ -119,13 +136,22 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { } switch id { + case settingExtendedConnect: + if readExtendedConnect { + return nil, fmt.Errorf("duplicate setting: %d", id) + } + readExtendedConnect = true + if val != 0 && val != 1 { + return nil, fmt.Errorf("invalid value for SETTINGS_ENABLE_CONNECT_PROTOCOL: %d", val) + } + frame.ExtendedConnect = val == 1 case settingDatagram: if readDatagram { return nil, fmt.Errorf("duplicate setting: %d", id) } readDatagram = true if val != 0 && val != 1 { - return nil, fmt.Errorf("invalid value for H3_DATAGRAM: %d", val) + return nil, fmt.Errorf("invalid value for SETTINGS_H3_DATAGRAM: %d", val) } frame.Datagram = val == 1 default: @@ -143,18 +169,25 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { func (f *settingsFrame) Append(b []byte) []byte { b = quicvarint.Append(b, 0x4) - var l int64 + var l int for id, val := range f.Other { l += quicvarint.Len(id) + quicvarint.Len(val) } if f.Datagram { l += quicvarint.Len(settingDatagram) + quicvarint.Len(1) } + if f.ExtendedConnect { + l += quicvarint.Len(settingExtendedConnect) + quicvarint.Len(1) + } b = quicvarint.Append(b, uint64(l)) if f.Datagram { b = quicvarint.Append(b, settingDatagram) b = quicvarint.Append(b, 1) } + if f.ExtendedConnect { + b = quicvarint.Append(b, settingExtendedConnect) + b = quicvarint.Append(b, 1) + } for id, val := range f.Other { b = quicvarint.Append(b, id) b = quicvarint.Append(b, val) diff --git a/internal/http3/headers.go b/internal/http3/headers.go index 2eb5ca29..cbd79ecd 100644 --- a/internal/http3/headers.go +++ b/internal/http3/headers.go @@ -3,14 +3,17 @@ package http3 import ( "errors" "fmt" - "github.com/quic-go/qpack" - "golang.org/x/net/http/httpguts" "net/http" + "net/textproto" + "net/url" "strconv" "strings" + + "github.com/quic-go/qpack" + "golang.org/x/net/http/httpguts" ) -type Header struct { +type header struct { // Pseudo header fields defined in RFC 9114 Path string Method string @@ -19,28 +22,37 @@ type Header struct { Status string // for Extended connect Protocol string - // parsed and deduplicated + // parsed and deduplicated. -1 if no Content-Length header is sent ContentLength int64 // all non-pseudo headers Headers http.Header } -func parseHeaders(headers []qpack.HeaderField, isRequest bool) (Header, error) { - hdr := Header{Headers: make(http.Header, len(headers))} +// connection-specific header fields must not be sent on HTTP/3 +var invalidHeaderFields = [...]string{ + "connection", + "keep-alive", + "proxy-connection", + "transfer-encoding", + "upgrade", +} + +func parseHeaders(headers []qpack.HeaderField, isRequest bool) (header, error) { + hdr := header{Headers: make(http.Header, len(headers))} var readFirstRegularHeader, readContentLength bool var contentLengthStr string for _, h := range headers { // field names need to be lowercase, see section 4.2 of RFC 9114 if strings.ToLower(h.Name) != h.Name { - return Header{}, fmt.Errorf("header field is not lower-case: %s", h.Name) + return header{}, fmt.Errorf("header field is not lower-case: %s", h.Name) } if !httpguts.ValidHeaderFieldValue(h.Value) { - return Header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value) + return header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value) } if h.IsPseudo() { if readFirstRegularHeader { // all pseudo headers must appear before regular header fields, see section 4.3 of RFC 9114 - return Header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name) + return header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name) } var isResponsePseudoHeader bool // pseudo headers are either valid for requests or for responses switch h.Name { @@ -58,17 +70,25 @@ func parseHeaders(headers []qpack.HeaderField, isRequest bool) (Header, error) { hdr.Status = h.Value isResponsePseudoHeader = true default: - return Header{}, fmt.Errorf("unknown pseudo header: %s", h.Name) + return header{}, fmt.Errorf("unknown pseudo header: %s", h.Name) } if isRequest && isResponsePseudoHeader { - return Header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name) + return header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name) } if !isRequest && !isResponsePseudoHeader { - return Header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name) + return header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name) } } else { if !httpguts.ValidHeaderFieldName(h.Name) { - return Header{}, fmt.Errorf("invalid header field name: %q", h.Name) + return header{}, fmt.Errorf("invalid header field name: %q", h.Name) + } + for _, invalidField := range invalidHeaderFields { + if h.Name == invalidField { + return header{}, fmt.Errorf("invalid header field name: %q", h.Name) + } + } + if h.Name == "te" && h.Value != "trailers" { + return header{}, fmt.Errorf("invalid TE header field value: %q", h.Value) } readFirstRegularHeader = true switch h.Name { @@ -79,18 +99,19 @@ func parseHeaders(headers []qpack.HeaderField, isRequest bool) (Header, error) { readContentLength = true contentLengthStr = h.Value } else if contentLengthStr != h.Value { - return Header{}, fmt.Errorf("contradicting content lengths (%s and %s)", contentLengthStr, h.Value) + return header{}, fmt.Errorf("contradicting content lengths (%s and %s)", contentLengthStr, h.Value) } default: hdr.Headers.Add(h.Name, h.Value) } } } + hdr.ContentLength = -1 if len(contentLengthStr) > 0 { // use ParseUint instead of ParseInt, so that parsing fails on negative values cl, err := strconv.ParseUint(contentLengthStr, 10, 63) if err != nil { - return Header{}, fmt.Errorf("invalid content length: %w", err) + return header{}, fmt.Errorf("invalid content length: %w", err) } hdr.Headers.Set("Content-Length", contentLengthStr) hdr.ContentLength = int64(cl) @@ -98,32 +119,141 @@ func parseHeaders(headers []qpack.HeaderField, isRequest bool) (Header, error) { return hdr, nil } -func hostnameFromRequest(req *http.Request) string { - if req.URL != nil { - return req.URL.Host +func parseTrailers(headers []qpack.HeaderField) (http.Header, error) { + h := make(http.Header, len(headers)) + for _, field := range headers { + if field.IsPseudo() { + return nil, fmt.Errorf("http3: received pseudo header in trailer: %s", field.Name) + } + h.Add(field.Name, field.Value) } - return "" + return h, nil } -func responseFromHeaders(headerFields []qpack.HeaderField) (*http.Response, error) { - hdr, err := parseHeaders(headerFields, false) +func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error) { + hdr, err := parseHeaders(headerFields, true) if err != nil { return nil, err } - if hdr.Status == "" { - return nil, errors.New("missing status field") + // concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4 + if len(hdr.Headers["Cookie"]) > 0 { + hdr.Headers.Set("Cookie", strings.Join(hdr.Headers["Cookie"], "; ")) + } + + isConnect := hdr.Method == http.MethodConnect + // Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4 + isExtendedConnected := isConnect && hdr.Protocol != "" + if isExtendedConnected { + if hdr.Scheme == "" || hdr.Path == "" || hdr.Authority == "" { + return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty") + } + } else if isConnect { + if hdr.Path != "" || hdr.Authority == "" { // normal CONNECT + return nil, errors.New(":path must be empty and :authority must not be empty") + } + } else if len(hdr.Path) == 0 || len(hdr.Authority) == 0 || len(hdr.Method) == 0 { + return nil, errors.New(":path, :authority and :method must not be empty") } - rsp := &http.Response{ - Proto: "HTTP/3.0", + + if !isExtendedConnected && len(hdr.Protocol) > 0 { + return nil, errors.New(":protocol must be empty") + } + + var u *url.URL + var requestURI string + + protocol := "HTTP/3.0" + + if isConnect { + u = &url.URL{} + if isExtendedConnected { + u, err = url.ParseRequestURI(hdr.Path) + if err != nil { + return nil, err + } + protocol = hdr.Protocol + } else { + u.Path = hdr.Path + } + u.Scheme = hdr.Scheme + u.Host = hdr.Authority + requestURI = hdr.Authority + } else { + u, err = url.ParseRequestURI(hdr.Path) + if err != nil { + return nil, fmt.Errorf("invalid content length: %w", err) + } + requestURI = hdr.Path + } + + return &http.Request{ + Method: hdr.Method, + URL: u, + Proto: protocol, ProtoMajor: 3, + ProtoMinor: 0, Header: hdr.Headers, + Body: nil, ContentLength: hdr.ContentLength, + Host: hdr.Authority, + RequestURI: requestURI, + }, nil +} + +func hostnameFromURL(url *url.URL) string { + if url != nil { + return url.Host } + return "" +} + +// updateResponseFromHeaders sets up http.Response as an HTTP/3 response, +// using the decoded qpack header filed. +// It is only called for the HTTP header (and not the HTTP trailer). +// It takes an http.Response as an argument to allow the caller to set the trailer later on. +func updateResponseFromHeaders(rsp *http.Response, headerFields []qpack.HeaderField) error { + hdr, err := parseHeaders(headerFields, false) + if err != nil { + return err + } + if hdr.Status == "" { + return errors.New("missing status field") + } + rsp.Proto = "HTTP/3.0" + rsp.ProtoMajor = 3 + rsp.Header = hdr.Headers + processTrailers(rsp) + rsp.ContentLength = hdr.ContentLength + status, err := strconv.Atoi(hdr.Status) if err != nil { - return nil, fmt.Errorf("invalid status code: %w", err) + return fmt.Errorf("invalid status code: %w", err) } rsp.StatusCode = status rsp.Status = hdr.Status + " " + http.StatusText(status) - return rsp, nil + return nil +} + +// processTrailers initializes the rsp.Trailer map, and adds keys for every announced header value. +// The Trailer header is removed from the http.Response.Header map. +// It handles both duplicate as well as comma-separated values for the Trailer header. +// For example: +// +// Trailer: Trailer1, Trailer2 +// Trailer: Trailer3 +// +// Will result in a http.Response.Trailer map containing the keys "Trailer1", "Trailer2", "Trailer3". +func processTrailers(rsp *http.Response) { + rawTrailers, ok := rsp.Header["Trailer"] + if !ok { + return + } + + rsp.Trailer = make(http.Header) + for _, rawVal := range rawTrailers { + for _, val := range strings.Split(rawVal, ",") { + rsp.Trailer[http.CanonicalHeaderKey(textproto.TrimString(val))] = nil + } + } + delete(rsp.Header, "Trailer") } diff --git a/internal/http3/http_stream.go b/internal/http3/http_stream.go index bfaf4214..7c969090 100644 --- a/internal/http3/http_stream.go +++ b/internal/http3/http_stream.go @@ -1,54 +1,103 @@ package http3 import ( + "context" "errors" "fmt" + "io" + "net/http" + + "github.com/imroc/req/v3/internal/compress" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/transport" "github.com/quic-go/quic-go" + + "github.com/quic-go/qpack" ) -// A Stream is a HTTP/3 stream. +// A Stream is an HTTP/3 request stream. // When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. -type Stream quic.Stream +type Stream interface { + quic.Stream + + SendDatagram([]byte) error + ReceiveDatagram(context.Context) ([]byte, error) +} + +// A RequestStream is an HTTP/3 request stream. +// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. +type RequestStream interface { + Stream + + // SendRequestHeader sends the HTTP request. + // It is invalid to call it more than once. + // It is invalid to call it after Write has been called. + SendRequestHeader(req *http.Request) error + + // ReadResponse reads the HTTP response from the stream. + // It is invalid to call it more than once. + // It doesn't set Response.Request and Response.TLS. + // It is invalid to call it after Read has been called. + ReadResponse() (*http.Response, error) +} -// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly -// from the QUIC stream, it writes to and reads from the HTTP stream. type stream struct { quic.Stream + conn *connection - buf []byte + buf []byte // used as a temporary buffer when writing the HTTP/3 frame headers - onFrameError func() bytesRemainingInFrame uint64 + + datagrams *datagrammer + + parseTrailer func(io.Reader, uint64) error + parsedTrailer bool } var _ Stream = &stream{} -func newStream(str quic.Stream, onFrameError func()) *stream { +func newStream(str quic.Stream, conn *connection, datagrams *datagrammer, parseTrailer func(io.Reader, uint64) error) *stream { return &stream{ Stream: str, - onFrameError: onFrameError, - buf: make([]byte, 0, 16), + conn: conn, + buf: make([]byte, 16), + datagrams: datagrams, + parseTrailer: parseTrailer, } } func (s *stream) Read(b []byte) (int, error) { + fp := &frameParser{ + r: s.Stream, + conn: s.conn, + } if s.bytesRemainingInFrame == 0 { parseLoop: for { - frame, err := parseNextFrame(s.Stream, nil) + frame, err := fp.ParseNext() if err != nil { return 0, err } switch f := frame.(type) { - case *headersFrame: - // skip HEADERS frames - continue case *dataFrame: + if s.parsedTrailer { + return 0, errors.New("DATA frame received after trailers") + } s.bytesRemainingInFrame = f.Length break parseLoop + case *headersFrame: + if s.conn.perspective == PerspectiveServer { + continue + } + if s.parsedTrailer { + return 0, errors.New("additional HEADERS frame received after trailers") + } + s.parsedTrailer = true + return 0, s.parseTrailer(s.Stream, f.Length) default: - s.onFrameError() + s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") // parseNextFrame skips over unknown frame types // Therefore, this condition is only entered when we parsed another known frame type. return 0, fmt.Errorf("peer sent an unexpected frame: %T", f) @@ -80,44 +129,175 @@ func (s *stream) Write(b []byte) (int, error) { return s.Stream.Write(b) } -var errTooMuchData = errors.New("peer sent too much data") +func (s *stream) writeUnframed(b []byte) (int, error) { + return s.Stream.Write(b) +} -type lengthLimitedStream struct { +func (s *stream) StreamID() quic.StreamID { + return s.Stream.StreamID() +} + +// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly +// from the QUIC stream, it writes to and reads from the HTTP stream. +type requestStream struct { + ctx context.Context *stream - contentLength int64 - read int64 - resetStream bool + *transport.Options + + responseBody io.ReadCloser // set by ReadResponse + + decoder *qpack.Decoder + requestWriter *requestWriter + maxHeaderBytes uint64 + reqDone chan<- struct{} + disableCompression bool + response *http.Response + + sentRequest bool + requestedGzip bool + isConnect bool } -var _ Stream = &lengthLimitedStream{} +var _ RequestStream = &requestStream{} -func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream { - return &lengthLimitedStream{ - stream: str, - contentLength: contentLength, +func newRequestStream( + ctx context.Context, + options *transport.Options, + str *stream, + requestWriter *requestWriter, + reqDone chan<- struct{}, + decoder *qpack.Decoder, + disableCompression bool, + maxHeaderBytes uint64, + rsp *http.Response, +) *requestStream { + return &requestStream{ + ctx: ctx, + Options: options, + stream: str, + requestWriter: requestWriter, + reqDone: reqDone, + decoder: decoder, + disableCompression: disableCompression, + maxHeaderBytes: maxHeaderBytes, + response: rsp, } } -func (s *lengthLimitedStream) checkContentLengthViolation() error { - if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() { - if !s.resetStream { - s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) - s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) - s.resetStream = true +func (s *requestStream) Read(b []byte) (int, error) { + if s.responseBody == nil { + return 0, errors.New("http3: invalid use of RequestStream.Read: need to call ReadResponse first") + } + return s.responseBody.Read(b) +} + +func (s *requestStream) SendRequestHeader(req *http.Request) error { + if s.sentRequest { + return errors.New("http3: invalid duplicate use of SendRequestHeader") + } + if !s.DisableCompression && !s.disableCompression && req.Method != http.MethodHead && + req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { + s.requestedGzip = true + } + dumps := dump.GetDumpers(req.Context(), s.Dump) + var headerDumps []*dump.Dumper + for _, dump := range dumps { + if dump.RequestHeader() { + headerDumps = append(headerDumps, dump) } - return errTooMuchData } - return nil + + s.isConnect = req.Method == http.MethodConnect + s.sentRequest = true + return s.requestWriter.WriteRequestHeader(s.Stream, req, s.requestedGzip, headerDumps) } -func (s *lengthLimitedStream) Read(b []byte) (int, error) { - if err := s.checkContentLengthViolation(); err != nil { - return 0, err +func (s *requestStream) ReadResponse() (*http.Response, error) { + fp := &frameParser{ + r: s.Stream, + conn: s.conn, } - n, err := s.stream.Read(b[:min(int64(len(b)), s.contentLength-s.read)]) - s.read += int64(n) - if err := s.checkContentLengthViolation(); err != nil { - return n, err + frame, err := fp.ParseNext() + if err != nil { + s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) + s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) + return nil, fmt.Errorf("http3: parsing frame failed: %w", err) } - return n, err + hf, ok := frame.(*headersFrame) + if !ok { + s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame") + return nil, errors.New("http3: expected first frame to be a HEADERS frame") + } + if hf.Length > s.maxHeaderBytes { + s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) + s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) + return nil, fmt.Errorf("http3: HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes) + } + headerBlock := make([]byte, hf.Length) + if _, err := io.ReadFull(s.Stream, headerBlock); err != nil { + s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) + s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) + return nil, fmt.Errorf("http3: failed to read response headers: %w", err) + } + hfs, err := s.decoder.DecodeFull(headerBlock) + if err != nil { + // TODO: use the right error code + s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeGeneralProtocolError), "") + return nil, fmt.Errorf("http3: failed to decode response headers: %w", err) + } + ds := dump.GetResponseHeaderDumpers(s.ctx, s.Dump) + if ds.ShouldDump() { + for _, h := range hfs { + ds.DumpResponseHeader([]byte(fmt.Sprintf("%s: %s\r\n", h.Name, h.Value))) + } + ds.DumpResponseHeader([]byte("\r\n")) + } + res := s.response + if err := updateResponseFromHeaders(res, hfs); err != nil { + s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) + s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) + return nil, fmt.Errorf("http3: invalid response: %w", err) + } + + // Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set). + // See section 4.1.2 of RFC 9114. + respBody := newResponseBody(s.stream, res.ContentLength, s.reqDone) + + // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. + isInformational := res.StatusCode >= 100 && res.StatusCode < 200 + isNoContent := res.StatusCode == http.StatusNoContent + isSuccessfulConnect := s.isConnect && res.StatusCode >= 200 && res.StatusCode < 300 + if (isInformational || isNoContent || isSuccessfulConnect) && res.ContentLength == -1 { + res.ContentLength = 0 + } + if s.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + s.responseBody = compress.NewGzipReader(respBody) + res.Uncompressed = true + } else if s.AutoDecompression { + contentEncoding := res.Header.Get("Content-Encoding") + if contentEncoding != "" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Uncompressed = true + res.Body = compress.NewCompressReader(respBody, contentEncoding) + } + } else { + s.responseBody = respBody + } + res.Body = s.responseBody + return res, nil +} + +func (s *stream) SendDatagram(b []byte) error { + // TODO: reject if datagrams are not negotiated (yet) + return s.datagrams.Send(b) +} + +func (s *stream) ReceiveDatagram(ctx context.Context) ([]byte, error) { + // TODO: reject if datagrams are not negotiated (yet) + return s.datagrams.Receive(ctx) } diff --git a/internal/http3/protocol.go b/internal/http3/protocol.go new file mode 100644 index 00000000..d5ba5bb6 --- /dev/null +++ b/internal/http3/protocol.go @@ -0,0 +1,119 @@ +package http3 + +import ( + "math" + + "github.com/quic-go/quic-go" +) + +// Perspective determines if we're acting as a server or a client +type Perspective int + +// the perspectives +const ( + PerspectiveServer Perspective = 1 + PerspectiveClient Perspective = 2 +) + +// Opposite returns the perspective of the peer +func (p Perspective) Opposite() Perspective { + return 3 - p +} + +func (p Perspective) String() string { + switch p { + case PerspectiveServer: + return "server" + case PerspectiveClient: + return "client" + default: + return "invalid perspective" + } +} + +// The version numbers, making grepping easier +const ( + VersionUnknown quic.Version = math.MaxUint32 + versionDraft29 quic.Version = 0xff00001d // draft-29 used to be a widely deployed version + Version1 quic.Version = 0x1 + Version2 quic.Version = 0x6b3343cf +) + +// SupportedVersions lists the versions that the server supports +// must be in sorted descending order +var SupportedVersions = []quic.Version{Version1, Version2} + +// StreamType encodes if this is a unidirectional or bidirectional stream +type StreamType uint8 + +const ( + // StreamTypeUni is a unidirectional stream + StreamTypeUni StreamType = iota + // StreamTypeBidi is a bidirectional stream + StreamTypeBidi +) + +// A StreamID in QUIC +type StreamID int64 + +// InitiatedBy says if the stream was initiated by the client or by the server +func (s StreamID) InitiatedBy() Perspective { + if s%2 == 0 { + return PerspectiveClient + } + return PerspectiveServer +} + +// Type says if this is a unidirectional or bidirectional stream +func (s StreamID) Type() StreamType { + if s%4 >= 2 { + return StreamTypeUni + } + return StreamTypeBidi +} + +// StreamNum returns how many streams in total are below this +// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9) +func (s StreamID) StreamNum() StreamNum { + return StreamNum(s/4) + 1 +} + +// InvalidPacketNumber is a stream ID that is invalid. +// The first valid stream ID in QUIC is 0. +const InvalidStreamID StreamID = -1 + +// StreamNum is the stream number +type StreamNum int64 + +const ( + // InvalidStreamNum is an invalid stream number. + InvalidStreamNum = -1 + // MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames + // and as the stream count in the transport parameters + MaxStreamCount StreamNum = 1 << 60 +) + +// StreamID calculates the stream ID. +func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID { + if s == 0 { + return InvalidStreamID + } + var first StreamID + switch stype { + case StreamTypeBidi: + switch pers { + case PerspectiveClient: + first = 0 + case PerspectiveServer: + first = 1 + } + case StreamTypeUni: + switch pers { + case PerspectiveClient: + first = 2 + case PerspectiveServer: + first = 3 + } + } + return first + 4*StreamID(s-1) +} diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index 443b34c2..2af1ce3f 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -13,7 +13,7 @@ import ( "sync" "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/header" + reqheader "github.com/imroc/req/v3/internal/header" "github.com/quic-go/qpack" "github.com/quic-go/quic-go" @@ -28,17 +28,14 @@ type requestWriter struct { mutex sync.Mutex encoder *qpack.Encoder headerBuf *bytes.Buffer - - debugf func(format string, v ...interface{}) } -func newRequestWriter(debugf func(format string, v ...interface{})) *requestWriter { +func newRequestWriter() *requestWriter { headerBuf := &bytes.Buffer{} encoder := qpack.NewEncoder(headerBuf) return &requestWriter{ encoder: encoder, headerBuf: headerBuf, - debugf: debugf, } } @@ -71,6 +68,10 @@ func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool, return err } +func isExtendedConnectRequest(req *http.Request) bool { + return req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1" +} + // copied from net/transport.go // Modified to support Extended CONNECT: // Contrary to what the godoc for the http.Request says, @@ -89,7 +90,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } // http.NewRequest sets this field to HTTP/1.1 - isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1" + isExtendedConnect := isExtendedConnectRequest(req) var path string if req.Method != http.MethodConnect || isExtendedConnect { @@ -123,11 +124,11 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra enumerateHeaders := func(f func(name, value string)) { var writeHeader func(name string, value ...string) - var kvs []header.KeyValues + var kvs []reqheader.KeyValues sort := false - if req.Header != nil && len(req.Header[header.PseudoHeaderOderKey]) > 0 { + if req.Header != nil && len(req.Header[reqheader.PseudoHeaderOderKey]) > 0 { writeHeader = func(name string, value ...string) { - kvs = append(kvs, header.KeyValues{ + kvs = append(kvs, reqheader.KeyValues{ Key: name, Values: value, }) @@ -156,7 +157,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } if sort { - header.SortKeyValues(kvs, req.Header[header.PseudoHeaderOderKey]) + reqheader.SortKeyValues(kvs, req.Header[reqheader.PseudoHeaderOderKey]) for _, kv := range kvs { for _, v := range kv.Values { f(kv.Key, v) @@ -164,11 +165,11 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } } - if req.Header != nil && len(req.Header[header.HeaderOderKey]) > 0 { + if req.Header != nil && len(req.Header[reqheader.HeaderOderKey]) > 0 { sort = true kvs = nil writeHeader = func(name string, value ...string) { - kvs = append(kvs, header.KeyValues{ + kvs = append(kvs, reqheader.KeyValues{ Key: name, Values: value, }) @@ -188,7 +189,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra var didUA bool for k, vv := range req.Header { - if header.IsExcluded(k) { + if reqheader.IsExcluded(k) { continue } else if strings.EqualFold(k, "user-agent") { // Match Go's http1 behavior: at most one @@ -217,11 +218,11 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra writeHeader("accept-encoding", "gzip") } if !didUA { - writeHeader("user-agent", header.DefaultUserAgent) + writeHeader("user-agent", reqheader.DefaultUserAgent) } if sort { - header.SortKeyValues(kvs, req.Header[header.HeaderOderKey]) + reqheader.SortKeyValues(kvs, req.Header[reqheader.HeaderOderKey]) for _, kv := range kvs { for _, v := range kv.Values { f(kv.Key, v) @@ -269,13 +270,10 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) // and returns a host:port. The port 443 is added if needed. -func authorityAddr(scheme string, authority string) (addr string) { +func authorityAddr(authority string) (addr string) { host, port, err := net.SplitHostPort(authority) if err != nil { // authority didn't have a port port = "443" - if scheme == "http" { - port = "80" - } host = authority } if a, err := idna.ToASCII(host); err == nil { diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index 89624094..afef7807 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -11,7 +11,6 @@ import ( "strings" "sync" "sync/atomic" - "time" "github.com/imroc/req/v3/internal/transport" @@ -20,71 +19,88 @@ import ( "golang.org/x/net/http/httpguts" ) -type roundTripCloser interface { - RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) - HandshakeComplete() bool - io.Closer +// Settings are HTTP/3 settings that apply to the underlying connection. +type Settings struct { + // Support for HTTP/3 datagrams (RFC 9297) + EnableDatagrams bool + // Extended CONNECT, RFC 9220 + EnableExtendedConnect bool + // Other settings, defined by the application + Other map[uint64]uint64 } -type roundTripCloserWithCount struct { - roundTripCloser +// RoundTripOpt are options for the Transport.RoundTripOpt method. +type RoundTripOpt struct { + // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. + // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. + OnlyCachedConn bool +} + +type singleRoundTripper interface { + OpenRequestStream(context.Context) (RequestStream, error) + RoundTrip(*http.Request) (*http.Response, error) +} + +type roundTripperWithCount struct { + cancel context.CancelFunc + dialing chan struct{} // closed as soon as quic.Dial(Early) returned + dialErr error + conn quic.EarlyConnection + rt singleRoundTripper + useCount atomic.Int64 } +func (r *roundTripperWithCount) Close() error { + r.cancel() + <-r.dialing + if r.conn != nil { + return r.conn.CloseWithError(0, "") + } + return nil +} + // RoundTripper implements the http.RoundTripper interface type RoundTripper struct { *transport.Options mutex sync.Mutex - // QuicConfig is the quic.Config used for dialing new connections. + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + + // QUICConfig is the quic.Config used for dialing new connections. // If nil, reasonable default values will be used. - QuicConfig *quic.Config + QUICConfig *quic.Config + + // Dial specifies an optional dial function for creating QUIC + // connections for requests. + // If Dial is nil, a UDPConn will be created at the first request + // and will be reused for subsequent connections to other servers. + Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) - // Enable support for HTTP/3 datagrams. - // If set to true, QuicConfig.EnableDatagram will be set. - // - // See https://datatracker.ietf.org/doc/html/rfc9297. + // Enable support for HTTP/3 datagrams (RFC 9297). + // If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams. EnableDatagrams bool // Additional HTTP/3 settings. - // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. + // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). AdditionalSettings map[uint64]uint64 - // When set, this callback is called for the first unknown frame parsed on a bidirectional stream. - // It is called right after parsing the frame type. - // If parsing the frame type fails, the error is passed to the callback. - // In that case, the frame type will not be set. - // Callers can either ignore the frame and return control of the stream back to HTTP/3 - // (by returning hijacked false). - // Alternatively, callers can take over the QUIC stream (by returning hijacked true). - StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) + // MaxResponseHeaderBytes specifies a limit on how many response bytes are + // allowed in the server's response header. + // Zero means to use a default limit. + MaxResponseHeaderBytes int64 - // When set, this callback is called for unknown unidirectional stream of unknown stream type. - // If parsing the stream type fails, the error is passed to the callback. - // In that case, the stream type will not be set. - UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) + initOnce sync.Once + initErr error - // Dial specifies an optional dial function for creating QUIC - // connections for requests. - // If Dial is nil, a UDPConn will be created at the first request - // and will be reused for subsequent connections to other servers. - Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) + newClient func(quic.EarlyConnection) singleRoundTripper - newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc, opt *transport.Options) (roundTripCloser, error) // so we can mock it in tests - clients map[string]*roundTripCloserWithCount + clients map[string]*roundTripperWithCount transport *quic.Transport } -// RoundTripOpt are options for the Transport.RoundTripOpt method. -type RoundTripOpt struct { - // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. - // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. - OnlyCachedConn bool - // DontCloseRequestStream controls whether the request stream is closed after sending the request. - // If set, context cancellations have no effect after the response headers are received. - DontCloseRequestStream bool -} - var ( _ http.RoundTripper = &RoundTripper{} _ io.Closer = &RoundTripper{} @@ -95,6 +111,11 @@ var ErrNoCachedConn = errors.New("http3: no cached connection was available") // RoundTripOpt is like RoundTrip, but takes options. func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + r.initOnce.Do(func() { r.initErr = r.init() }) + if r.initErr != nil { + return nil, r.initErr + } + if req.URL == nil { closeRequestBody(req) return nil, errors.New("http3: nil Request.URL") @@ -111,21 +132,15 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. closeRequestBody(req) return nil, errors.New("http3: nil Request.Header") } - - if req.URL.Scheme == "https" { - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - return nil, fmt.Errorf("http3: invalid http header field name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k) - } + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("http3: invalid http header field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k) } } - } else { - closeRequestBody(req) - return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) } if req.Method != "" && !validMethod(req.Method) { @@ -133,8 +148,8 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. return nil, fmt.Errorf("http3: invalid method %q", req.Method) } - hostname := authorityAddr("https", hostnameFromRequest(req)) - cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn) + hostname := authorityAddr(hostnameFromURL(req.URL)) + cl, isReused, err := r.getClient(req.Context(), hostname, opt.OnlyCachedConn) if err != ErrNoCachedConn { if debugf := r.Debugf; debugf != nil { debugf("HTTP/3 %s %s", req.Method, req.URL.String()) @@ -143,10 +158,27 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. if err != nil { return nil, err } + + select { + case <-cl.dialing: + case <-req.Context().Done(): + return nil, context.Cause(req.Context()) + } + + if cl.dialErr != nil { + r.removeClient(hostname) + return nil, cl.dialErr + } defer cl.useCount.Add(-1) - rsp, err := cl.RoundTripOpt(req, opt) + rsp, err := cl.rt.RoundTrip(req) if err != nil { - r.removeClient(hostname) + // non-nil errors on roundtrip are likely due to a problem with the connection + // so we remove the client from the cache so that subsequent trips reconnect + // context cancelation is excluded as is does not signify a connection error + if !errors.Is(err, context.Canceled) { + r.removeClient(hostname) + } + if isReused { if nerr, ok := err.(net.Error); ok && nerr.Timeout() { return r.RoundTripOpt(req, opt) @@ -161,82 +193,142 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r.RoundTripOpt(req, RoundTripOpt{}) } +func (r *RoundTripper) init() error { + if r.newClient == nil { + r.newClient = func(conn quic.EarlyConnection) singleRoundTripper { + return &SingleDestinationRoundTripper{ + Options: r.Options, + Connection: conn, + EnableDatagrams: r.EnableDatagrams, + AdditionalSettings: r.AdditionalSettings, + MaxResponseHeaderBytes: r.MaxResponseHeaderBytes, + } + } + } + if r.QUICConfig == nil { + r.QUICConfig = defaultQuicConfig.Clone() + r.QUICConfig.EnableDatagrams = r.EnableDatagrams + } + if r.EnableDatagrams && !r.QUICConfig.EnableDatagrams { + return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled") + } + if len(r.QUICConfig.Versions) == 0 { + r.QUICConfig = r.QUICConfig.Clone() + r.QUICConfig.Versions = []quic.Version{SupportedVersions[0]} + } + if len(r.QUICConfig.Versions) != 1 { + return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") + } + if r.QUICConfig.MaxIncomingStreams == 0 { + r.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams + } + return nil +} + // RoundTripOnlyCachedConn round trip only cached conn. func (r *RoundTripper) RoundTripOnlyCachedConn(req *http.Request) (*http.Response, error) { return r.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) } // AddConn add a http3 connection, dial new conn if not exists. -func (r *RoundTripper) AddConn(addr string) error { - addr = authorityAddr("https", addr) - c, _, err := r.getClient(addr, false) - if err != nil { - return err - } - client, ok := c.roundTripCloser.(*client) - if !ok { - return errors.New("bad client type") +func (r *RoundTripper) AddConn(ctx context.Context, addr string) error { + addr = authorityAddr(addr) + cl, _, err := r.getClient(ctx, addr, false) + if err == nil { + cl.useCount.Add(-1) } - client.dialOnce.Do(func() { - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) - client.handshakeErr = client.dial(ctx) - }) - return client.handshakeErr + return err } -func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) { +func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) { r.mutex.Lock() defer r.mutex.Unlock() if r.clients == nil { - r.clients = make(map[string]*roundTripCloserWithCount) + r.clients = make(map[string]*roundTripperWithCount) } - client, ok := r.clients[hostname] + cl, ok := r.clients[hostname] if !ok { if onlyCached { return nil, false, ErrNoCachedConn } - var err error - newCl := newClient - if r.newClient != nil { - newCl = r.newClient + ctx, cancel := context.WithCancel(ctx) + cl = &roundTripperWithCount{ + dialing: make(chan struct{}), + cancel: cancel, } - dial := r.Dial - if dial == nil { - if r.transport == nil { - udpConn, err := net.ListenUDP("udp", nil) - if err != nil { - return nil, false, err - } - r.transport = &quic.Transport{Conn: udpConn} + go func() { + defer close(cl.dialing) + defer cancel() + conn, rt, err := r.dial(ctx, hostname) + if err != nil { + cl.dialErr = err + return } - dial = r.makeDialer() + cl.conn = conn + cl.rt = rt + }() + r.clients[hostname] = cl + } + select { + case <-cl.dialing: + if cl.dialErr != nil { + delete(r.clients, hostname) + return nil, false, cl.dialErr + } + select { + case <-cl.conn.HandshakeComplete(): + isReused = true + default: } - c, err := newCl( - hostname, - r.TLSClientConfig, - &roundTripperOpts{ - EnableDatagram: r.EnableDatagrams, - DisableCompression: r.DisableCompression, - MaxHeaderBytes: r.MaxResponseHeaderBytes, - StreamHijacker: r.StreamHijacker, - UniStreamHijacker: r.UniStreamHijacker, - dump: r.Dump, - AdditionalSettings: r.AdditionalSettings, - }, - r.QuicConfig, - dial, - r.Options, - ) + default: + } + cl.useCount.Add(1) + return cl, isReused, nil +} + +func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) { + var tlsConf *tls.Config + if r.TLSClientConfig == nil { + tlsConf = &tls.Config{} + } else { + tlsConf = r.TLSClientConfig.Clone() + } + if tlsConf.ServerName == "" { + sni, _, err := net.SplitHostPort(hostname) if err != nil { - return nil, false, err + // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. + sni = hostname } - client = &roundTripCloserWithCount{roundTripCloser: c} - r.clients[hostname] = client + tlsConf.ServerName = sni } - client.useCount.Add(1) - return client, isReused, nil + // Replace existing ALPNs by H3 + tlsConf.NextProtos = []string{versionToALPN(r.QUICConfig.Versions[0])} + + dial := r.Dial + if dial == nil { + if r.transport == nil { + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, nil, err + } + r.transport = &quic.Transport{Conn: udpConn} + } + dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) + } + } + + conn, err := dial(ctx, hostname, tlsConf, r.QUICConfig) + if err != nil { + return nil, nil, err + } + return conn, r.newClient(conn), nil } func (r *RoundTripper) removeClient(hostname string) { @@ -253,8 +345,8 @@ func (r *RoundTripper) removeClient(hostname string) { func (r *RoundTripper) Close() error { r.mutex.Lock() defer r.mutex.Unlock() - for _, client := range r.clients { - if err := client.Close(); err != nil { + for _, cl := range r.clients { + if err := cl.Close(); err != nil { return err } } @@ -299,23 +391,12 @@ func isNotToken(r rune) bool { return !httpguts.IsTokenRune(r) } -// makeDialer makes a QUIC dialer using r.udpConn. -func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) - } -} - func (r *RoundTripper) CloseIdleConnections() { r.mutex.Lock() defer r.mutex.Unlock() - for hostname, client := range r.clients { - if client.useCount.Load() == 0 { - client.Close() + for hostname, cl := range r.clients { + if cl.useCount.Load() == 0 { + cl.Close() delete(r.clients, hostname) } } diff --git a/internal/http3/server.go b/internal/http3/server.go index 9f94e7b5..4c36f0f0 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -1,16 +1,12 @@ package http3 -import ( - "github.com/quic-go/quic-go" -) +import "github.com/quic-go/quic-go" -const ( - nextProtoH3Draft29 = "h3-29" - nextProtoH3 = "h3" -) +// NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2. +const NextProtoH3 = "h3" // StreamType is the stream type of a unidirectional stream. -type StreamType uint64 +type ServerStreamType uint64 const ( streamTypeControlStream = 0 @@ -20,25 +16,11 @@ const ( ) func versionToALPN(v quic.Version) string { + //nolint:exhaustive // These are all the versions we care about. switch v { case Version1, Version2: - return nextProtoH3 - case VersionDraft29: - return nextProtoH3Draft29 + return NextProtoH3 + default: + return "" } - return "" -} - -type requestError struct { - err error - streamErr ErrCode - connErr ErrCode -} - -func newStreamError(code ErrCode, err error) requestError { - return requestError{err: err, streamErr: code} -} - -func newConnError(code ErrCode, err error) requestError { - return requestError{err: err, connErr: code} } diff --git a/internal/http3/state_tracking_stream.go b/internal/http3/state_tracking_stream.go new file mode 100644 index 00000000..9cf17f5e --- /dev/null +++ b/internal/http3/state_tracking_stream.go @@ -0,0 +1,116 @@ +package http3 + +import ( + "context" + "errors" + "os" + "sync" + + "github.com/quic-go/quic-go" +) + +var _ quic.Stream = &stateTrackingStream{} + +// stateTrackingStream is an implementation of quic.Stream that delegates +// to an underlying stream +// it takes care of proxying send and receive errors onto an implementation of +// the errorSetter interface (intended to be occupied by a datagrammer) +// it is also responsible for clearing the stream based on its ID from its +// parent connection, this is done through the streamClearer interface when +// both the send and receive sides are closed +type stateTrackingStream struct { + quic.Stream + + mx sync.Mutex + sendErr error + recvErr error + + clearer streamClearer + setter errorSetter +} + +type streamClearer interface { + clearStream(quic.StreamID) +} + +type errorSetter interface { + SetSendError(error) + SetReceiveError(error) +} + +func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream { + t := &stateTrackingStream{ + Stream: s, + clearer: clearer, + setter: setter, + } + + context.AfterFunc(s.Context(), func() { + t.closeSend(context.Cause(s.Context())) + }) + + return t +} + +func (s *stateTrackingStream) closeSend(e error) { + s.mx.Lock() + defer s.mx.Unlock() + + // clear the stream the first time both the send + // and receive are finished + if s.sendErr == nil { + if s.recvErr != nil { + s.clearer.clearStream(s.StreamID()) + } + + s.setter.SetSendError(e) + s.sendErr = e + } +} + +func (s *stateTrackingStream) closeReceive(e error) { + s.mx.Lock() + defer s.mx.Unlock() + + // clear the stream the first time both the send + // and receive are finished + if s.recvErr == nil { + if s.sendErr != nil { + s.clearer.clearStream(s.StreamID()) + } + + s.setter.SetReceiveError(e) + s.recvErr = e + } +} + +func (s *stateTrackingStream) Close() error { + s.closeSend(errors.New("write on closed stream")) + return s.Stream.Close() +} + +func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) { + s.closeSend(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) + s.Stream.CancelWrite(e) +} + +func (s *stateTrackingStream) Write(b []byte) (int, error) { + n, err := s.Stream.Write(b) + if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { + s.closeSend(err) + } + return n, err +} + +func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) { + s.closeReceive(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) + s.Stream.CancelRead(e) +} + +func (s *stateTrackingStream) Read(b []byte) (int, error) { + n, err := s.Stream.Read(b) + if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { + s.closeReceive(err) + } + return n, err +} diff --git a/internal/quic-go/quicvarint/varint.go b/internal/quic-go/quicvarint/varint.go index 60d17e3d..9a22e334 100644 --- a/internal/quic-go/quicvarint/varint.go +++ b/internal/quic-go/quicvarint/varint.go @@ -26,16 +26,16 @@ func Read(r io.ByteReader) (uint64, error) { return 0, err } // the first two bits of the first byte encode the length - len := 1 << ((firstByte & 0xc0) >> 6) + l := 1 << ((firstByte & 0xc0) >> 6) b1 := firstByte & (0xff - 0xc0) - if len == 1 { + if l == 1 { return uint64(b1), nil } b2, err := r.ReadByte() if err != nil { return 0, err } - if len == 2 { + if l == 2 { return uint64(b2) + uint64(b1)<<8, nil } b3, err := r.ReadByte() @@ -46,7 +46,7 @@ func Read(r io.ByteReader) (uint64, error) { if err != nil { return 0, err } - if len == 4 { + if l == 4 { return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil } b5, err := r.ReadByte() @@ -68,6 +68,31 @@ func Read(r io.ByteReader) (uint64, error) { return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil } +// Parse reads a number in the QUIC varint format. +// It returns the number of bytes consumed. +func Parse(b []byte) (uint64 /* value */, int /* bytes consumed */, error) { + if len(b) == 0 { + return 0, 0, io.EOF + } + firstByte := b[0] + // the first two bits of the first byte encode the length + l := 1 << ((firstByte & 0xc0) >> 6) + if len(b) < l { + return 0, 0, io.ErrUnexpectedEOF + } + b0 := firstByte & (0xff - 0xc0) + if l == 1 { + return uint64(b0), 1, nil + } + if l == 2 { + return uint64(b[1]) + uint64(b0)<<8, 2, nil + } + if l == 4 { + return uint64(b[3]) + uint64(b[2])<<8 + uint64(b[1])<<16 + uint64(b0)<<24, 4, nil + } + return uint64(b[7]) + uint64(b[6])<<8 + uint64(b[5])<<16 + uint64(b[4])<<24 + uint64(b[3])<<32 + uint64(b[2])<<40 + uint64(b[1])<<48 + uint64(b0)<<56, 8, nil +} + // Append appends i in the QUIC varint format. func Append(b []byte, i uint64) []byte { if i <= maxVarInt1 { @@ -89,7 +114,7 @@ func Append(b []byte, i uint64) []byte { } // AppendWithLen append i in the QUIC varint format with the desired length. -func AppendWithLen(b []byte, i uint64, length int64) []byte { +func AppendWithLen(b []byte, i uint64, length int) []byte { if length != 1 && length != 2 && length != 4 && length != 8 { panic("invalid varint length") } @@ -107,17 +132,17 @@ func AppendWithLen(b []byte, i uint64, length int64) []byte { } else if length == 8 { b = append(b, 0b11000000) } - for j := int64(1); j < length-l; j++ { + for j := 1; j < length-l; j++ { b = append(b, 0) } - for j := int64(0); j < l; j++ { + for j := 0; j < l; j++ { b = append(b, uint8(i>>(8*(l-1-j)))) } return b } // Len determines the number of bytes that will be needed to write the number i. -func Len(i uint64) int64 { +func Len(i uint64) int { if i <= maxVarInt1 { return 1 } diff --git a/transport.go b/transport.go index 69c8932f..4dd2bceb 100644 --- a/transport.go +++ b/transport.go @@ -661,7 +661,7 @@ func (t *Transport) handlePendingAltSvc(u *url.URL, pas *pendingAltSvc) { case "h3": // only support h3 in alt-svc for now u2 := altsvcutil.ConvertURL(pas.Entries[i], u) hostname := u2.Host - err := t.t3.AddConn(hostname) + err := t.t3.AddConn(context.Background(), hostname) if err != nil { if t.Debugf != nil { t.Debugf("failed to get http3 connection: %s", err.Error()) From 16b3f61d8066b4c62d109a73abcebc30ca8c391d Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 11 Sep 2024 11:00:18 +0800 Subject: [PATCH 825/843] update go modules --- go.mod | 6 +++--- go.sum | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index ef6f380a..87ce144f 100644 --- a/go.mod +++ b/go.mod @@ -16,14 +16,14 @@ require ( require ( github.com/cloudflare/circl v1.4.0 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect - github.com/google/pprof v0.0.0-20240903155634-a8630aee4ab9 // indirect + github.com/google/pprof v0.0.0-20240910150728-a0b0bb1d4134 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/onsi/ginkgo/v2 v2.20.2 // indirect go.uber.org/mock v0.4.0 // indirect golang.org/x/crypto v0.27.0 // indirect - golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e // indirect + golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect golang.org/x/mod v0.21.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.25.0 // indirect - golang.org/x/tools v0.24.0 // indirect + golang.org/x/tools v0.25.0 // indirect ) diff --git a/go.sum b/go.sum index 4147995e..5976620d 100644 --- a/go.sum +++ b/go.sum @@ -6,12 +6,15 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20240903155634-a8630aee4ab9 h1:q5g0N9eal4bmJwXHC5z0QCKs8qhS35hFfq0BAYsIwZI= github.com/google/pprof v0.0.0-20240903155634-a8630aee4ab9/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= +github.com/google/pprof v0.0.0-20240910150728-a0b0bb1d4134 h1:c5FlPPgxOn7kJz3VoPLkQYQXGBS3EklQ4Zfi57uOuqQ= +github.com/google/pprof v0.0.0-20240910150728-a0b0bb1d4134/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -39,6 +42,8 @@ golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e h1:I88y4caeGeuDQxgdoFPUq097j7kNfw6uvuiNxUBfcBk= golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= @@ -53,6 +58,8 @@ golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= +golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= +golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= From dcdcd1af6ed2b0cccfc8d078c2da8b082c44d993 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 11 Sep 2024 14:22:17 +0800 Subject: [PATCH 826/843] http3: use MaxResponseHeaderBytes option from transport --- internal/http3/client.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/internal/http3/client.go b/internal/http3/client.go index 664f0f76..9f2e2b87 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -53,11 +53,6 @@ type SingleDestinationRoundTripper struct { StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error) UniStreamHijacker func(ServerStreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool) - // MaxResponseHeaderBytes specifies a limit on how many response bytes are - // allowed in the server's response header. - // Zero means to use a default limit. - MaxResponseHeaderBytes int64 - Logger *slog.Logger initOnce sync.Once From 551b96d2162ebac225e219033517073708a387fa Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 11 Sep 2024 14:30:16 +0800 Subject: [PATCH 827/843] http3: replace slog with req's logger --- internal/http3/client.go | 20 ++++++++------------ internal/http3/conn.go | 16 ++++++---------- internal/http3/roundtrip.go | 14 ++++---------- 3 files changed, 18 insertions(+), 32 deletions(-) diff --git a/internal/http3/client.go b/internal/http3/client.go index 9f2e2b87..f9341488 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -4,7 +4,6 @@ import ( "context" "errors" "io" - "log/slog" "net/http" "net/http/httptrace" "net/textproto" @@ -53,8 +52,6 @@ type SingleDestinationRoundTripper struct { StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error) UniStreamHijacker func(ServerStreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool) - Logger *slog.Logger - initOnce sync.Once hconn *connection requestWriter *requestWriter @@ -76,15 +73,14 @@ func (c *SingleDestinationRoundTripper) init() { c.Connection, c.EnableDatagrams, PerspectiveClient, - c.Logger, 0, c.Options, ) // send the SETTINGs frame, using 0-RTT data, if possible go func() { if err := c.setupConn(c.hconn); err != nil { - if c.Logger != nil { - c.Logger.Debug("Setting up connection failed", "error", err) + if c.Debugf != nil { + c.Debugf("Setting up connection failed: %s", err.Error()) } c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") } @@ -113,8 +109,8 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() { for { str, err := c.hconn.AcceptStream(context.Background()) if err != nil { - if c.Logger != nil { - c.Logger.Debug("accepting bidirectional stream failed", "error", err) + if c.Debugf != nil { + c.Debugf("accepting bidirectional stream failed: %s", err.Error()) } return } @@ -131,8 +127,8 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() { return } if err != nil { - if c.Logger != nil { - c.Logger.Debug("error handling stream", "error", err) + if c.Debugf != nil { + c.Debugf("error handling stream: %s", err.Error()) } } c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") @@ -283,8 +279,8 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *reques go func() { dumps := dump.GetDumpers(req.Context(), c.Dump) if err := c.sendRequestBody(str, req.Body, dumps); err != nil { - if c.Logger != nil { - c.Logger.Debug("error writing request", "error", err) + if c.Debugf != nil { + c.Debugf("error writing request: %s", err.Error()) } } str.Close() diff --git a/internal/http3/conn.go b/internal/http3/conn.go index 60f0e259..fa5302f9 100644 --- a/internal/http3/conn.go +++ b/internal/http3/conn.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "log/slog" "net" "net/http" "sync" @@ -44,7 +43,6 @@ type connection struct { ctx context.Context perspective Perspective - logger *slog.Logger enableDatagrams bool @@ -65,7 +63,6 @@ func newConnection( quicConn quic.Connection, enableDatagrams bool, perspective Perspective, - logger *slog.Logger, idleTimeout time.Duration, options *transport.Options, ) *connection { @@ -74,7 +71,6 @@ func newConnection( Connection: quicConn, Options: options, perspective: perspective, - logger: logger, idleTimeout: idleTimeout, enableDatagrams: enableDatagrams, decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), @@ -183,8 +179,8 @@ func (c *connection) HandleUnidirectionalStreams(hijack func(ServerStreamType, q for { str, err := c.Connection.AcceptUniStream(context.Background()) if err != nil { - if c.logger != nil { - c.logger.Debug("accepting unidirectional stream failed", "error", err) + if c.Debugf != nil { + c.Debugf("accepting unidirectional stream failed: %s", err.Error()) } return } @@ -196,8 +192,8 @@ func (c *connection) HandleUnidirectionalStreams(hijack func(ServerStreamType, q if hijack != nil && hijack(ServerStreamType(streamType), id, str, err) { return } - if c.logger != nil { - c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err) + if c.Debugf != nil { + c.Debugf("reading stream type on stream failed (id %v): %s", str.StreamID(), err.Error()) } return } @@ -274,8 +270,8 @@ func (c *connection) HandleUnidirectionalStreams(hijack func(ServerStreamType, q } go func() { if err := c.receiveDatagrams(); err != nil { - if c.logger != nil { - c.logger.Debug("receiving datagrams failed", "error", err) + if c.Debugf != nil { + c.Debugf("receiving datagrams failed: %s", err.Error()) } } }() diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index afef7807..26157b2e 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -87,11 +87,6 @@ type RoundTripper struct { // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). AdditionalSettings map[uint64]uint64 - // MaxResponseHeaderBytes specifies a limit on how many response bytes are - // allowed in the server's response header. - // Zero means to use a default limit. - MaxResponseHeaderBytes int64 - initOnce sync.Once initErr error @@ -197,11 +192,10 @@ func (r *RoundTripper) init() error { if r.newClient == nil { r.newClient = func(conn quic.EarlyConnection) singleRoundTripper { return &SingleDestinationRoundTripper{ - Options: r.Options, - Connection: conn, - EnableDatagrams: r.EnableDatagrams, - AdditionalSettings: r.AdditionalSettings, - MaxResponseHeaderBytes: r.MaxResponseHeaderBytes, + Options: r.Options, + Connection: conn, + EnableDatagrams: r.EnableDatagrams, + AdditionalSettings: r.AdditionalSettings, } } } From 5c12b054fabe769b33fe1f7410eb98face4a2a6b Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 11 Sep 2024 16:26:37 +0800 Subject: [PATCH 828/843] merge upstream net/http: 2024-09-10(ad6ee2) --- internal/transport/option.go | 10 +- transfer.go | 10 +- transport.go | 586 ++++++++++++++++++++--------------- 3 files changed, 349 insertions(+), 257 deletions(-) diff --git a/internal/transport/option.go b/internal/transport/option.go index e9403f39..78c88c19 100644 --- a/internal/transport/option.go +++ b/internal/transport/option.go @@ -3,11 +3,12 @@ package transport import ( "context" "crypto/tls" - "github.com/imroc/req/v3/internal/dump" "net" "net/http" "net/url" "time" + + "github.com/imroc/req/v3/internal/dump" ) // Options is transport's options. @@ -17,8 +18,13 @@ type Options struct { // request is aborted with the provided error. // // The proxy type is determined by the URL scheme. "http", - // "https", and "socks5" are supported. If the scheme is empty, + // "https", "socks5", and "socks5h" are supported. If the scheme is empty, // "http" is assumed. + // "socks5" is treated the same as "socks5h". + // + // If the proxy URL contains a userinfo subcomponent, + // the proxy request will pass the username and password + // in a Proxy-Authorization header. // // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy func(*http.Request) (*url.URL, error) diff --git a/transfer.go b/transfer.go index 92a5d305..e9cd5a56 100644 --- a/transfer.go +++ b/transfer.go @@ -13,7 +13,7 @@ import ( "net/http" "net/textproto" "reflect" - "sort" + "slices" "strconv" "strings" "sync" @@ -281,7 +281,7 @@ func (t *transferWriter) writeHeader(writeHeader func(key string, values ...stri keys = append(keys, k) } if len(keys) > 0 { - sort.Strings(keys) + slices.Sort(keys) // TODO: could do better allocation-wise here, but trailers are rare, // so being lazy for now. err := writeHeader("Trailer", strings.Join(keys, ",")) @@ -973,7 +973,7 @@ func (bl bodyLocked) Read(p []byte) (n int, err error) { return bl.b.readLocked(p) } -var laxContentLength = godebug.New("httplaxcontentlength") +var httplaxContentLength = godebug.New("httplaxcontentlength") // parseContentLength checks that the header is valid and then trims // whitespace. It returns -1 if no value is set otherwise the value @@ -987,8 +987,8 @@ func parseContentLength(clHeaders []string) (int64, error) { // The Content-Length must be a valid numeric value. // See: https://datatracker.ietf.org/doc/html/rfc2616/#section-14.13 if cl == "" { - if laxContentLength.Value() == "1" { - laxContentLength.IncNonDefault() + if httplaxContentLength.Value() == "1" { + httplaxContentLength.IncNonDefault() return -1, nil } return 0, badStringError("invalid empty Content-Length", cl) diff --git a/transport.go b/transport.go index 4dd2bceb..f4aa9e3b 100644 --- a/transport.go +++ b/transport.go @@ -30,6 +30,7 @@ import ( "strings" "sync" "time" + _ "unsafe" "github.com/imroc/req/v3/http2" "github.com/imroc/req/v3/internal/altsvcutil" @@ -114,11 +115,13 @@ type Transport struct { idleLRU connLRU reqMu sync.Mutex - reqCanceler map[cancelKey]func(error) + reqCanceler map[*http.Request]context.CancelCauseFunc connsPerHostMu sync.Mutex connsPerHost map[connectMethodKey]int connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns + dialsInProgress wantConnQueue + altSvcJar altsvc.Jar pendingAltSvcs map[string]*pendingAltSvc pendingAltSvcsMu sync.Mutex @@ -733,13 +736,6 @@ func (t *Transport) autoDecodeResponseBody(res *http.Response) { res.Body = newAutoDecodeReadCloser(res.Body, t) } -// A cancelKey is the key of the reqCanceler map. -// We wrap the *http.Request in this type since we want to use the original request, -// not any transient one created by roundTrip. -type cancelKey struct { - req *http.Request -} - func (t *Transport) writeBufferSize() int { if t.WriteBufferSize > 0 { return t.WriteBufferSize @@ -813,14 +809,16 @@ func (t *Transport) hasCustomTLSDialer() bool { return t.DialTLSContext != nil } -// transportRequest is a wrapper around a *http.Request that adds +// transportRequest is a wrapper around a *Request that adds // optional extra headers to write and stores any error to return // from roundTrip. type transportRequest struct { *http.Request // original request, not to be mutated extra http.Header // extra headers to write, or nil trace *httptrace.ClientTrace // optional - cancelKey cancelKey + + ctx context.Context // canceled when we are done with the request + cancel context.CancelCauseFunc mu sync.Mutex // guards err err error // first setError value for mapRoundTripError to consider @@ -889,6 +887,22 @@ func (t *Transport) checkAltSvc(req *http.Request) (resp *http.Response, err err return } +func validateHeaders(hdrs http.Header) string { + for k, vv := range hdrs { + if !httpguts.ValidHeaderFieldName(k) { + return fmt.Sprintf("field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // Don't include the value in the error, + // because it may be sensitive. + return fmt.Sprintf("field value for %q", k) + } + } + } + return "" +} + // roundTrip implements a http.RoundTripper over HTTP. func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error) { ctx := req.Context() @@ -908,21 +922,16 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error isHTTP := scheme == "http" || scheme == "https" if isHTTP { - // TODO: is h2c should also check this? - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - closeBody(req) - err = fmt.Errorf("net/http: invalid header field name %q", k) - return - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - closeBody(req) - // Don't include the value in the error, because it may be sensitive. - err = fmt.Errorf("net/http: invalid header field value for %q", k) - return - } - } + // Validate the outgoing headers. + if err := validateHeaders(req.Header); err != "" { + closeBody(req) + return nil, fmt.Errorf("net/http: invalid header %s", err) + } + + // Validate the outgoing trailers too. + if err := validateHeaders(req.Trailer); err != "" { + closeBody(req) + return nil, fmt.Errorf("net/http: invalid trailer %s", err) } } @@ -940,7 +949,6 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error } origReq := req - cancelKey := cancelKey{origReq} req = setupRewindBody(req) if scheme == "https" && t.forceHttpVersion != h1 { @@ -977,16 +985,44 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error return nil, errors.New("http: no Host in request URL") } + // Transport request context. + // + // If RoundTrip returns an error, it cancels this context before returning. + // + // If RoundTrip returns no error: + // - For an HTTP/1 request, persistConn.readLoop cancels this context + // after reading the request body. + // - For an HTTP/2 request, RoundTrip cancels this context after the HTTP/2 + // RoundTripper returns. + ctx, cancel := context.WithCancelCause(req.Context()) + + // Convert Request.Cancel into context cancelation. + if origReq.Cancel != nil { + go awaitLegacyCancel(ctx, cancel, origReq) + } + + // Convert Transport.CancelRequest into context cancelation. + // + // This is lamentably expensive. CancelRequest has been deprecated for a long time + // and doesn't work on HTTP/2 requests. Perhaps we should drop support for it entirely. + cancel = t.prepareTransportCancel(origReq, cancel) + + defer func() { + if err != nil { + cancel(err) + } + }() + for { select { case <-ctx.Done(): closeBody(req) - return nil, ctx.Err() + return nil, context.Cause(ctx) default: } // treq gets modified by roundTrip, so we need to recreate for each retry. - treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey} + treq := &transportRequest{Request: req, trace: trace, ctx: ctx, cancel: cancel} cm, err := t.connectMethodForRequest(treq) if err != nil { closeBody(req) @@ -999,7 +1035,6 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error // to send it requests. pconn, err := t.getConn(treq, cm) if err != nil { - t.setReqCanceler(cancelKey, nil) closeBody(req) return nil, err } @@ -1007,12 +1042,20 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error var resp *http.Response if t.forceHttpVersion != h1 && pconn.alt != nil { // HTTP/2 path. - t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest resp, err = pconn.alt.RoundTrip(req) } else { resp, err = pconn.roundTrip(treq) } if err == nil { + if pconn.alt != nil { + // HTTP/2 requests are not cancelable with CancelRequest, + // so we have no further need for the request context. + // + // On the HTTP/1 path, roundTrip takes responsibility for + // canceling the context after the response body is read. + cancel(errRequestDone) + } + resp.Request = origReq return resp, nil } @@ -1049,6 +1092,14 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error } } +func awaitLegacyCancel(ctx context.Context, cancel context.CancelCauseFunc, req *http.Request) { + select { + case <-req.Cancel: + cancel(common.ErrRequestCanceled) + case <-ctx.Done(): + } +} + var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss") type readTrackingBody struct { @@ -1169,35 +1220,55 @@ func (t *Transport) CloseIdleConnections() { pconn.close(errCloseIdleConns) } } + t.connsPerHostMu.Lock() + t.dialsInProgress.all(func(w *wantConn) { + if w.cancelCtx != nil && !w.waiting() { + w.cancelCtx() + } + }) + t.connsPerHostMu.Unlock() + if t2 := t.t2; t2 != nil { t2.CloseIdleConnections() } } +// prepareTransportCancel sets up state to convert Transport.CancelRequest into context cancelation. +func (t *Transport) prepareTransportCancel(req *http.Request, origCancel context.CancelCauseFunc) context.CancelCauseFunc { + // Historically, RoundTrip has not modified the Request in any way. + // We could avoid the need to keep a map of all in-flight requests by adding + // a field to the Request containing its cancel func, and setting that field + // while the request is in-flight. Callers aren't supposed to reuse a Request + // until after the response body is closed, so this wouldn't violate any + // concurrency guarantees. + cancel := func(err error) { + origCancel(err) + t.reqMu.Lock() + delete(t.reqCanceler, req) + t.reqMu.Unlock() + } + t.reqMu.Lock() + if t.reqCanceler == nil { + t.reqCanceler = make(map[*http.Request]context.CancelCauseFunc) + } + t.reqCanceler[req] = cancel + t.reqMu.Unlock() + return cancel +} + // CancelRequest cancels an in-flight request by closing its connection. -// CancelRequest should only be called after RoundTrip has returned. +// CancelRequest should only be called after [Transport.RoundTrip] has returned. // -// Deprecated: Use Request.WithContext to create a request with a +// Deprecated: Use [Request.WithContext] to create a request with a // cancelable context instead. CancelRequest cannot cancel HTTP/2 -// requests. +// requests. This may become a no-op in a future release of Go. func (t *Transport) CancelRequest(req *http.Request) { - t.cancelRequest(cancelKey{req}, common.ErrRequestCanceled) -} - -// Cancel an in-flight request, recording the error value. -// Returns whether the request was canceled. -func (t *Transport) cancelRequest(key cancelKey, err error) bool { - // This function must not return until the cancel func has completed. - // See: https://golang.org/issue/34658 t.reqMu.Lock() - defer t.reqMu.Unlock() - cancel := t.reqCanceler[key] - delete(t.reqCanceler, key) + cancel := t.reqCanceler[req] + t.reqMu.Unlock() if cancel != nil { - cancel(err) + cancel(common.ErrRequestCanceled) } - - return cancel != nil } // resetProxyConfig is used by tests. @@ -1313,7 +1384,7 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { // Loop over the waiting list until we find a w that isn't done already, and hand it pconn. for q.len() > 0 { w := q.popFront() - if w.tryDeliver(pconn, nil) { + if w.tryDeliver(pconn, nil, time.Time{}) { done = true break } @@ -1325,7 +1396,7 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { // list unconditionally, for any future clients too. for q.len() > 0 { w := q.popFront() - w.tryDeliver(pconn, nil) + w.tryDeliver(pconn, nil, time.Time{}) } } if q.len() == 0 { @@ -1429,7 +1500,7 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) { list = list[:len(list)-1] continue } - delivered = w.tryDeliver(pconn, nil) + delivered = w.tryDeliver(pconn, nil, pconn.idleAt) if delivered { if pconn.alt != nil { // HTTP/2: multiple clients can share pconn. @@ -1458,7 +1529,7 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) { t.idleConnWait = make(map[connectMethodKey]wantConnQueue) } q := t.idleConnWait[w.key] - q.cleanFront() + q.cleanFrontNotWaiting() q.pushBack(w) t.idleConnWait[w.key] = q return false @@ -1504,38 +1575,6 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool { return removed } -func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { - t.reqMu.Lock() - defer t.reqMu.Unlock() - if t.reqCanceler == nil { - t.reqCanceler = make(map[cancelKey]func(error)) - } - if fn != nil { - t.reqCanceler[key] = fn - } else { - delete(t.reqCanceler, key) - } -} - -// replaceReqCanceler replaces an existing cancel function. If there is no cancel function -// for the request, we don't set the function and return false. -// Since CancelRequest will clear the canceler, we can use the return value to detect if -// the request was canceled since the last setReqCancel call. -func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { - t.reqMu.Lock() - defer t.reqMu.Unlock() - _, ok := t.reqCanceler[key] - if !ok { - return false - } - if fn != nil { - t.reqCanceler[key] = fn - } else { - delete(t.reqCanceler, key) - } - return true -} - var zeroDialer net.Dialer func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) { @@ -1556,10 +1595,8 @@ func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, e // These three options are racing against each other and use // wantConn to coordinate and agree about the winning outcome. type wantConn struct { - cm connectMethod - key connectMethodKey // cm.key() - ctx context.Context // context for dial - ready chan struct{} // closed when pc, err pair is delivered + cm connectMethod + key connectMethodKey // cm.key() // hooks for testing to know when dials are done // beforeDial is called in the getConn goroutine when the dial is queued. @@ -1567,44 +1604,52 @@ type wantConn struct { beforeDial func() afterDial func() - mu sync.Mutex // protects pc, err, close(ready) - pc *persistConn - err error + mu sync.Mutex // protects ctx, done and sending of the result + ctx context.Context // context for dial, cleared after delivered or canceled + cancelCtx context.CancelFunc + done bool // true after delivered or canceled + result chan connOrError // channel to deliver connection or error +} + +type connOrError struct { + pc *persistConn + err error + idleAt time.Time } // waiting reports whether w is still waiting for an answer (connection or error). func (w *wantConn) waiting() bool { - select { - case <-w.ready: - return false - default: - return true - } + w.mu.Lock() + defer w.mu.Unlock() + + return !w.done } // getCtxForDial returns context for dial or nil if connection was delivered or canceled. func (w *wantConn) getCtxForDial() context.Context { w.mu.Lock() defer w.mu.Unlock() + return w.ctx } // tryDeliver attempts to deliver pc, err to w and reports whether it succeeded. -func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { +func (w *wantConn) tryDeliver(pc *persistConn, err error, idleAt time.Time) bool { w.mu.Lock() defer w.mu.Unlock() - if w.pc != nil || w.err != nil { + if w.done { return false } - - w.ctx = nil - w.pc = pc - w.err = err - if w.pc == nil && w.err == nil { + if (pc == nil) == (err == nil) { panic("net/http: internal error: misuse of tryDeliver") } - close(w.ready) + w.ctx = nil + w.done = true + + w.result <- connOrError{pc: pc, err: err, idleAt: idleAt} + close(w.result) + return true } @@ -1612,13 +1657,16 @@ func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { // If a connection has been delivered already, cancel returns it with t.putOrCloseIdleConn. func (w *wantConn) cancel(t *Transport, err error) { w.mu.Lock() - if w.pc == nil && w.err == nil { - close(w.ready) // catch misbehavior in future delivery + var pc *persistConn + if w.done { + if r, ok := <-w.result; ok { + pc = r.pc + } + } else { + close(w.result) } - pc := w.pc w.ctx = nil - w.pc = nil - w.err = err + w.done = true w.mu.Unlock() if pc != nil { @@ -1679,9 +1727,9 @@ func (q *wantConnQueue) peekFront() *wantConn { return nil } -// cleanFront pops any wantConns that are no longer waiting from the head of the +// cleanFrontNotWaiting pops any wantConns that are no longer waiting from the head of the // queue, reporting whether any were popped. -func (q *wantConnQueue) cleanFront() (cleaned bool) { +func (q *wantConnQueue) cleanFrontNotWaiting() (cleaned bool) { for { w := q.peekFront() if w == nil || w.waiting() { @@ -1692,6 +1740,28 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) { } } +// cleanFrontCanceled pops any wantConns with canceled dials from the head of the queue. +func (q *wantConnQueue) cleanFrontCanceled() { + for { + w := q.peekFront() + if w == nil || w.cancelCtx != nil { + return + } + q.popFront() + } +} + +// all iterates over all wantConns in the queue. +// The caller must not modify the queue while iterating. +func (q *wantConnQueue) all(f func(*wantConn)) { + for _, w := range q.head[q.headPos:] { + f(w) + } + for _, w := range q.tail { + f(w) + } +} + func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) { conn, err = t.DialTLSContext(ctx, network, addr) @@ -1713,11 +1783,19 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi trace.GetConn(cm.addr()) } + // Detach from the request context's cancellation signal. + // The dial should proceed even if the request is canceled, + // because a future request may be able to make use of the connection. + // + // We retain the request context's values. + dialCtx, dialCancel := context.WithCancel(context.WithoutCancel(ctx)) + w := &wantConn{ cm: cm, key: cm.key(), - ctx: ctx, - ready: make(chan struct{}, 1), + ctx: dialCtx, + cancelCtx: dialCancel, + result: make(chan connOrError, 1), beforeDial: testHookPrePendingDial, afterDial: testHookPostPendingDial, } @@ -1728,44 +1806,33 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi }() // Queue for idle connection. - if delivered := t.queueForIdleConn(w); delivered { - pc := w.pc - // Trace only for HTTP/1. - // HTTP/2 calls trace.GotConn itself. - if pc.alt == nil && trace != nil && trace.GotConn != nil { - trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) - } - // set request canceler to some non-nil function so we - // can detect whether it was cleared between now and when - // we enter roundTrip - t.setReqCanceler(treq.cancelKey, func(error) {}) - return pc, nil + if delivered := t.queueForIdleConn(w); !delivered { + t.queueForDial(w) } - cancelc := make(chan error, 1) - t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err }) - - // Queue for permission to dial. - t.queueForDial(w) - // Wait for completion or cancellation. select { - case <-w.ready: + case r := <-w.result: // Trace success but only for HTTP/1. // HTTP/2 calls trace.GotConn itself. - if w.pc != nil && w.pc.alt == nil && trace != nil && trace.GotConn != nil { - trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) + if r.pc != nil && r.pc.alt == nil && trace != nil && trace.GotConn != nil { + info := httptrace.GotConnInfo{ + Conn: r.pc.conn, + Reused: r.pc.isReused(), + } + if !r.idleAt.IsZero() { + info.WasIdle = true + info.IdleTime = time.Since(r.idleAt) + } + trace.GotConn(info) } - if w.err != nil { + if r.err != nil { // If the request has been canceled, that's probably - // what caused w.err; if so, prefer to return the + // what caused r.err; if so, prefer to return the // cancellation error (see golang.org/issue/16049). select { - case <-req.Cancel: - return nil, errRequestCanceledConn - case <-req.Context().Done(): - return nil, req.Context().Err() - case err := <-cancelc: + case <-treq.ctx.Done(): + err := context.Cause(treq.ctx) if err == common.ErrRequestCanceled { err = errRequestCanceledConn } @@ -1774,12 +1841,9 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // return below } } - return w.pc, w.err - case <-req.Cancel: - return nil, errRequestCanceledConn - case <-req.Context().Done(): - return nil, req.Context().Err() - case err := <-cancelc: + return r.pc, r.err + case <-treq.ctx.Done(): + err := context.Cause(treq.ctx) if err == common.ErrRequestCanceled { err = errRequestCanceledConn } @@ -1791,20 +1855,21 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // Once w receives permission to dial, it will do so in a separate goroutine. func (t *Transport) queueForDial(w *wantConn) { w.beforeDial() - if t.MaxConnsPerHost <= 0 { - go t.dialConnFor(w) - return - } t.connsPerHostMu.Lock() defer t.connsPerHostMu.Unlock() + if t.MaxConnsPerHost <= 0 { + t.startDialConnForLocked(w) + return + } + if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost { if t.connsPerHost == nil { t.connsPerHost = make(map[connectMethodKey]int) } t.connsPerHost[w.key] = n + 1 - go t.dialConnFor(w) + t.startDialConnForLocked(w) return } @@ -1812,11 +1877,24 @@ func (t *Transport) queueForDial(w *wantConn) { t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue) } q := t.connsPerHostWait[w.key] - q.cleanFront() + q.cleanFrontNotWaiting() q.pushBack(w) t.connsPerHostWait[w.key] = q } +// startDialConnFor calls dialConn in a new goroutine. +// t.connsPerHostMu must be held. +func (t *Transport) startDialConnForLocked(w *wantConn) { + t.dialsInProgress.cleanFrontCanceled() + t.dialsInProgress.pushBack(w) + go func() { + t.dialConnFor(w) + t.connsPerHostMu.Lock() + defer t.connsPerHostMu.Unlock() + w.cancelCtx = nil + }() +} + // dialConnFor dials on behalf of w and delivers the result to w. // dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()]. // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. @@ -1829,7 +1907,7 @@ func (t *Transport) dialConnFor(w *wantConn) { } pc, err := t.dialConn(ctx, w.cm) - delivered := w.tryDeliver(pc, err) + delivered := w.tryDeliver(pc, err, time.Time{}) if err == nil && (!delivered || pc.alt != nil) { // pconn was not passed to w, // or it is HTTP/2 and can be shared. @@ -1866,7 +1944,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { for q.len() > 0 { w := q.popFront() if w.waiting() { - go t.dialConnFor(w) + t.startDialConnForLocked(w) done = true break } @@ -1990,6 +2068,8 @@ func (t *Transport) customTlsHandshake(ctx context.Context, trace *httptrace.Cli return nil } +var testHookProxyConnectTimeout = context.WithTimeout + func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { pconn = &persistConn{ t: t, @@ -2068,7 +2148,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers switch { case cm.proxyURL == nil: // Do nothing. Not using a proxy. - case cm.proxyURL.Scheme == "socks5": + case cm.proxyURL.Scheme == "socks5" || cm.proxyURL.Scheme == "socks5h": conn := pconn.conn d := socks.NewDialer("tcp", conn.RemoteAddr().String()) if u := cm.proxyURL.User; u != nil { @@ -2120,17 +2200,11 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers Header: hdr, } - // If there's no done channel (no deadline or cancellation - // from the caller possible), at least set some (long) - // timeout here. This will make sure we don't block forever - // and leak a goroutine if the connection stops replying - // after the TCP connect. - connectCtx := ctx - if ctx.Done() == nil { - newCtx, cancel := context.WithTimeout(ctx, 1*time.Minute) - defer cancel() - connectCtx = newCtx - } + // Set a (long) timeout here to make sure we don't block forever + // and leak a goroutine if the connection stops replying after + // the TCP connect. + connectCtx, cancel := testHookProxyConnectTimeout(ctx, 1*time.Minute) + defer cancel() didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails var ( @@ -2490,18 +2564,6 @@ func (pc *persistConn) isReused() bool { return r } -func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnInfo) { - pc.mu.Lock() - defer pc.mu.Unlock() - t.Reused = pc.reused - t.Conn = pc.conn - t.WasIdle = true - if !idleAt.IsZero() { - t.IdleTime = time.Since(idleAt) - } - return -} - func (pc *persistConn) cancelRequest(err error) { pc.mu.Lock() defer pc.mu.Unlock() @@ -2594,7 +2656,8 @@ func (pc *persistConn) readLoop() { pc.t.removeIdleConn(pc) }() - tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { + tryPutIdleConn := func(treq *transportRequest) bool { + trace := treq.trace if err := pc.t.tryPutIdleConn(pc); err != nil { closeErr = err if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { @@ -2633,7 +2696,7 @@ func (pc *persistConn) readLoop() { pc.mu.Unlock() rc := <-pc.reqch - trace := httptrace.ContextClientTrace(rc.req.Context()) + trace := rc.treq.trace var resp *http.Response if err == nil { @@ -2662,9 +2725,9 @@ func (pc *persistConn) readLoop() { pc.mu.Unlock() bodyWritable := bodyIsWritable(resp) - hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 + hasBody := rc.treq.Request.Method != "HEAD" && resp.ContentLength != 0 - if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable { + if resp.Close || rc.treq.Request.Close || resp.StatusCode <= 199 || bodyWritable { // Don't do keep-alive on error if either party requested a close // or we get an unexpected informational (1xx) response. // StatusCode 100 is already handled above. @@ -2672,8 +2735,6 @@ func (pc *persistConn) readLoop() { } if !hasBody || bodyWritable { - replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) - // Put the idle conn back into the pool before we send the response // so if they process it quickly and make another request, they'll // get this same conn. But we use the unbuffered channel 'rc' @@ -2682,7 +2743,7 @@ func (pc *persistConn) readLoop() { alive = alive && !pc.sawEOF && pc.wroteRequest() && - replaced && tryPutIdleConn(trace) + tryPutIdleConn(rc.treq) if bodyWritable { closeErr = errCallerOwnsConn @@ -2694,6 +2755,8 @@ func (pc *persistConn) readLoop() { return } + rc.treq.cancel(errRequestDone) + // Now that they've read from the unbuffered channel, they're safely // out of the select that also waits on this goroutine to die, so // we're allowed to exit now if needed (if alive is false) @@ -2748,29 +2811,26 @@ func (pc *persistConn) readLoop() { } // Before looping back to the top of this function and peeking on - // the bufio.textprotoReader, wait for the caller goroutine to finish + // the bufio.Reader, wait for the caller goroutine to finish // reading the response body. (or for cancellation or death) select { case bodyEOF := <-waitForBodyRead: - replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool alive = alive && bodyEOF && !pc.sawEOF && pc.wroteRequest() && - replaced && tryPutIdleConn(trace) + tryPutIdleConn(rc.treq) if bodyEOF { eofc <- struct{}{} } - case <-rc.req.Cancel: + case <-rc.treq.ctx.Done(): alive = false - pc.t.cancelRequest(rc.cancelKey, common.ErrRequestCanceled) - case <-rc.req.Context().Done(): - alive = false - pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) + pc.cancelRequest(context.Cause(rc.treq.ctx)) case <-pc.closech: alive = false } + rc.treq.cancel(errRequestDone) testHookReadLoopBeforeNextRead() } } @@ -2822,22 +2882,17 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr continueCh := rc.continueCh for { - resp, err = pc._readResponse(rc.req) + resp, err = pc._readResponse(rc.treq.Request) if err != nil { return } resCode := resp.StatusCode - if continueCh != nil { - if resCode == 100 { - if trace != nil && trace.Got100Continue != nil { - trace.Got100Continue() - } - continueCh <- struct{}{} - continueCh = nil - } else if resCode >= 200 { - close(continueCh) - continueCh = nil + if continueCh != nil && resCode == http.StatusContinue { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() } + continueCh <- struct{}{} + continueCh = nil } is1xx := 100 <= resCode && resCode <= 199 // treat 101 as a terminal status, see issue 26161 @@ -2861,6 +2916,25 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr if isProtocolSwitch(resp) { resp.Body = newReadWriteCloserBody(pc.br, pc.conn) } + if continueCh != nil { + // We send an "Expect: 100-continue" header, but the server + // responded with a terminal status and no 100 Continue. + // + // If we're going to keep using the connection, we need to send the request body. + // Tell writeLoop to skip sending the body if we're going to close the connection, + // or to send it otherwise. + // + // The case where we receive a 101 Switching Protocols response is a bit + // ambiguous, since we don't know what protocol we're switching to. + // Conceivably, it's one that doesn't need us to send the body. + // Given that we'll send the body if ExpectContinueTimeout expires, + // be consistent and always send it if we aren't closing the connection. + if resp.Close || rc.treq.Request.Close { + close(continueCh) // don't send the body; the connection will close + } else { + continueCh <- struct{}{} // send the body + } + } resp.TLS = pc.tlsState return @@ -3249,10 +3323,9 @@ type responseAndError struct { } type requestAndChan struct { - _ incomparable - req *http.Request - cancelKey cancelKey - ch chan responseAndError // unbuffered; always send in select on callerGone + _ incomparable + treq *transportRequest + ch chan responseAndError // unbuffered; always send in select on callerGone // whether the Transport (as opposed to the user client code) // added the Accept-Encoding gzip header. If the Transport @@ -3297,6 +3370,10 @@ var errTimeout error = &timeoutError{"net/http: timeout awaiting response header var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? +// errRequestDone is used to cancel the round trip Context after a request is successfully done. +// It should not be seen by the user. +var errRequestDone = errors.New("net/http: request completed") + func nop() {} // testHooks. Always non-nil. @@ -3316,10 +3393,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, er pc.t.Debugf("HTTP/1.1 %s %s", req.Method, req.URL.String()) } testHookEnterRoundTrip() - if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { - pc.t.putOrCloseIdleConn(pc) - return nil, common.ErrRequestCanceled - } pc.mu.Lock() pc.numExpectedResponses++ headerFn := pc.mutateHeaderFunc @@ -3368,12 +3441,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, er gone := make(chan struct{}) defer close(gone) - defer func() { - if err != nil { - pc.t.setReqCanceler(req.cancelKey, nil) - } - }() - const debugRoundTrip = false // Write the request concurrently with waiting for a response, @@ -3385,25 +3452,34 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, er resc := make(chan responseAndError) pc.reqch <- requestAndChan{ - req: req.Request, - cancelKey: req.cancelKey, + treq: req, ch: resc, addedGzip: requestedGzip, continueCh: continueCh, callerGone: gone, } + handleResponse := func(re responseAndError) (*http.Response, error) { + if (re.res == nil) == (re.err == nil) { + panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil)) + } + if debugRoundTrip { + req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) + } + if re.err != nil { + return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) + } + return re.res, nil + } var respHeaderTimer <-chan time.Time - cancelChan := req.Request.Cancel - ctxDoneChan := req.Context().Done() + ctxDoneChan := req.ctx.Done() pcClosed := pc.closech - canceled := false for { testHookWaitResLoop() select { case err := <-writeErrCh: if debugRoundTrip { - req.logf("writeErrCh resv: %T/%#v", err, err) + req.logf("writeErrCh recv: %T/%#v", err, err) } if err != nil { pc.close(fmt.Errorf("write error: %w", err)) @@ -3418,13 +3494,18 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, er respHeaderTimer = timer.C } case <-pcClosed: - pcClosed = nil - if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { - if debugRoundTrip { - req.logf("closech recv: %T %#v", pc.closed, pc.closed) - } - return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) + select { + case re := <-resc: + // The pconn closing raced with the response to the request, + // probably after the server wrote a response and immediately + // closed the connection. Use the response. + return handleResponse(re) + default: + } + if debugRoundTrip { + req.logf("closech recv: %T %#v", pc.closed, pc.closed) } + return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) case <-respHeaderTimer: if debugRoundTrip { req.logf("timeout waiting for response headers.") @@ -3432,23 +3513,17 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, er pc.close(errTimeout) return nil, errTimeout case re := <-resc: - if (re.res == nil) == (re.err == nil) { - panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil)) - } - if debugRoundTrip { - req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) - } - if re.err != nil { - return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) - } - return re.res, nil - case <-cancelChan: - canceled = pc.t.cancelRequest(req.cancelKey, common.ErrRequestCanceled) - cancelChan = nil + return handleResponse(re) case <-ctxDoneChan: - canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err()) - cancelChan = nil - ctxDoneChan = nil + select { + case re := <-resc: + // readLoop is responsible for canceling req.ctx after + // it reads the response body. Check for a response racing + // the context close, and use the response if available. + return handleResponse(re) + default: + } + pc.cancelRequest(context.Cause(req.ctx)) } } } @@ -3503,9 +3578,10 @@ func (pc *persistConn) closeLocked(err error) { } var portMap = map[string]string{ - "http": "80", - "https": "443", - "socks5": "1080", + "http": "80", + "https": "443", + "socks5": "1080", + "socks5h": "1080", } func idnaASCIIFromURL(url *url.URL) string { @@ -3646,6 +3722,16 @@ func (fakeLocker) Unlock() {} // cloneTLSConfig returns a shallow clone of cfg, or a new zero tls.Config if // cfg is nil. This is safe to call even if cfg is in active use by a TLS // client or server. +// +// cloneTLSConfig should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/searKing/golang +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname cloneTLSConfig func cloneTLSConfig(cfg *tls.Config) *tls.Config { if cfg == nil { return &tls.Config{} From fb51d5bc7856ca8adff512cc1a328c8cb8d69e75 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 11 Sep 2024 16:31:06 +0800 Subject: [PATCH 829/843] use latest tls fingerprint in ImpersonateXXX --- client_impersonate.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/client_impersonate.go b/client_impersonate.go index a4fd76fe..581ca08f 100644 --- a/client_impersonate.go +++ b/client_impersonate.go @@ -81,7 +81,7 @@ var ( // ImpersonateChrome impersonates Chrome browser (version 109). func (c *Client) ImpersonateChrome() *Client { c. - SetTLSFingerprint(utls.HelloChrome_106_Shuffle). // Chrome 106~109 shares the same tls fingerprint. + SetTLSFingerprint(utls.HelloChrome_Auto). // Chrome 106~109 shares the same tls fingerprint. SetHTTP2SettingsFrame(chromeHttp2Settings...). SetHTTP2ConnectionFlow(15663105). SetCommonPseudoHeaderOder(chromePseudoHeaderOrder...). @@ -197,7 +197,7 @@ var ( // ImpersonateFirefox impersonates Firefox browser (version 105). func (c *Client) ImpersonateFirefox() *Client { c. - SetTLSFingerprint(utls.HelloFirefox_105). + SetTLSFingerprint(utls.HelloFirefox_Auto). SetHTTP2SettingsFrame(firefoxHttp2Settings...). SetHTTP2ConnectionFlow(12517377). SetHTTP2PriorityFrames(firefoxPriorityFrames...). @@ -258,7 +258,7 @@ var ( // ImpersonateSafari impersonates Safari browser (version 16). func (c *Client) ImpersonateSafari() *Client { c. - SetTLSFingerprint(utls.HelloSafari_16_0). + SetTLSFingerprint(utls.HelloSafari_Auto). SetHTTP2SettingsFrame(safariHttp2Settings...). SetHTTP2ConnectionFlow(10485760). SetCommonPseudoHeaderOder(safariPseudoHeaderOrder...). From db181cc471b795444f43e905d993274b541752c0 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 11 Sep 2024 17:13:12 +0800 Subject: [PATCH 830/843] merge upstream http2: 2024-09-06(3c333c) --- internal/http2/frame.go | 40 +++++++- internal/http2/http2.go | 7 +- internal/http2/pipe.go | 11 ++- internal/http2/timer.go | 20 ++++ internal/http2/transport.go | 179 ++++++++++++++++++++++-------------- 5 files changed, 180 insertions(+), 77 deletions(-) create mode 100644 internal/http2/timer.go diff --git a/internal/http2/frame.go b/internal/http2/frame.go index 81b3c082..077e181e 100644 --- a/internal/http2/frame.go +++ b/internal/http2/frame.go @@ -520,6 +520,9 @@ func (h2f *Framer) currentRequest(id uint32) *http.Request { // returned error is errFrameTooLarge. Other errors may be of type // ConnectionError, StreamError, or anything else from the underlying // reader. +// +// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID +// indicates the stream responsible for the error. func (h2f *Framer) ReadFrame() (Frame, error) { h2f.errDetail = nil if h2f.lastFrame != nil { @@ -1555,7 +1558,7 @@ func (fr *Framer) maxHeaderStringLen() int { // readMetaFrame returns 0 or more CONTINUATION frames from fr and // merge them into the provided hf and returns a MetaHeadersFrame // with the decoded hpack values. -func (h2f *Framer) readMetaFrame(hf *HeadersFrame, dumps []*dump.Dumper) (*MetaHeadersFrame, error) { +func (h2f *Framer) readMetaFrame(hf *HeadersFrame, dumps []*dump.Dumper) (Frame, error) { if h2f.AllowIllegalReads { return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders") } @@ -1598,6 +1601,7 @@ func (h2f *Framer) readMetaFrame(hf *HeadersFrame, dumps []*dump.Dumper) (*MetaH if size > remainSize { hdec.SetEmitEnabled(false) mh.Truncated = true + remainSize = 0 return } remainSize -= size @@ -1621,8 +1625,38 @@ func (h2f *Framer) readMetaFrame(hf *HeadersFrame, dumps []*dump.Dumper) (*MetaH var hc headersOrContinuation = hf for { frag := hc.HeaderBlockFragment() + + // Avoid parsing large amounts of headers that we will then discard. + // If the sender exceeds the max header list size by too much, + // skip parsing the fragment and close the connection. + // + // "Too much" is either any CONTINUATION frame after we've already + // exceeded the max header list size (in which case remainSize is 0), + // or a frame whose encoded size is more than twice the remaining + // header list bytes we're willing to accept. + if int64(len(frag)) > int64(2*remainSize) { + if VerboseLogs { + log.Printf("http2: header list too large") + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the structure of the server's frame writer makes this difficult. + return mh, ConnectionError(ErrCodeProtocol) + } + + // Also close the connection after any CONTINUATION frame following an + // invalid header, since we stop tracking the size of the headers after + // an invalid one. + if invalid != nil { + if VerboseLogs { + log.Printf("http2: invalid header: %v", invalid) + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the structure of the server's frame writer makes this difficult. + return mh, ConnectionError(ErrCodeProtocol) + } + if _, err := hdec.Write(frag); err != nil { - return nil, ConnectionError(ErrCodeCompression) + return mh, ConnectionError(ErrCodeCompression) } if hc.HeadersEnded() { @@ -1639,7 +1673,7 @@ func (h2f *Framer) readMetaFrame(hf *HeadersFrame, dumps []*dump.Dumper) (*MetaH mh.HeadersFrame.invalidate() if err := hdec.Close(); err != nil { - return nil, ConnectionError(ErrCodeCompression) + return mh, ConnectionError(ErrCodeCompression) } if invalid != nil { h2f.errDetail = invalid diff --git a/internal/http2/http2.go b/internal/http2/http2.go index 3e3e9350..4d38ac69 100644 --- a/internal/http2/http2.go +++ b/internal/http2/http2.go @@ -7,13 +7,14 @@ package http2 import ( "bufio" "crypto/tls" - "golang.org/x/net/http/httpguts" "net/http" "os" "sort" "strconv" "strings" "sync" + + "golang.org/x/net/http/httpguts" ) var ( @@ -50,9 +51,7 @@ const ( initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size ) -var ( - clientPreface = []byte(ClientPreface) -) +var clientPreface = []byte(ClientPreface) // validWireHeaderFieldName reports whether v is a valid header field // name (key). See httpguts.ValidHeaderName for the base rules. diff --git a/internal/http2/pipe.go b/internal/http2/pipe.go index 684d984f..3b9f06b9 100644 --- a/internal/http2/pipe.go +++ b/internal/http2/pipe.go @@ -77,7 +77,10 @@ func (p *pipe) Read(d []byte) (n int, err error) { } } -var errClosedPipeWrite = errors.New("write on closed buffer") +var ( + errClosedPipeWrite = errors.New("write on closed buffer") + errUninitializedPipeWrite = errors.New("write on uninitialized buffer") +) // Write copies bytes from p into the buffer and wakes a reader. // It is an error to write more data than the buffer can hold. @@ -91,6 +94,12 @@ func (p *pipe) Write(d []byte) (n int, err error) { if p.err != nil || p.breakErr != nil { return 0, errClosedPipeWrite } + // pipe.setBuffer is never invoked, leaving the buffer uninitialized. + // We shouldn't try to write to an uninitialized pipe, + // but returning an error is better than panicking. + if p.b == nil { + return 0, errUninitializedPipeWrite + } return p.b.Write(d) } diff --git a/internal/http2/timer.go b/internal/http2/timer.go new file mode 100644 index 00000000..0b1c17b8 --- /dev/null +++ b/internal/http2/timer.go @@ -0,0 +1,20 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package http2 + +import "time" + +// A timer is a time.Timer, as an interface which can be replaced in tests. +type timer = interface { + C() <-chan time.Time + Reset(d time.Duration) bool + Stop() bool +} + +// timeTimer adapts a time.Timer to the timer interface. +type timeTimer struct { + *time.Timer +} + +func (t timeTimer) C() <-chan time.Time { return t.Timer.C } diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 483fc9c9..fea6685c 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -109,6 +109,12 @@ type Transport struct { // waiting for their turn. StrictMaxConcurrentStreams bool + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout time.Duration + // ReadIdleTimeout is the timeout after which a health check using ping // frame will be carried out if no frame is received on the connection. // Note that a ping response will is considered a received frame, so if @@ -143,6 +149,20 @@ type Transport struct { connPoolOrDef ClientConnPool // non-nil version of ConnPool } +// newTimer creates a new time.Timer, or a synthetic timer in tests. +func (t *Transport) newTimer(d time.Duration) timer { + return timeTimer{time.NewTimer(d)} +} + +// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. +func (t *Transport) afterFunc(d time.Duration, f func()) timer { + return timeTimer{time.AfterFunc(d, f)} +} + +func (t *Transport) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(ctx, d) +} + func (t *Transport) maxHeaderListSize() uint32 { if t.MaxHeaderListSize == 0 { return 10 << 20 @@ -188,7 +208,7 @@ type ClientConn struct { readerErr error // set before readerDone is closed idleTimeout time.Duration // or 0 for never - idleTimer *time.Timer + idleTimer timer mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes @@ -433,15 +453,6 @@ func (t *Transport) AddConn(conn net.Conn, addr string) (used bool, err error) { return } -var retryBackoffHook func(time.Duration) *time.Timer - -func backoffNewTimer(d time.Duration) *time.Timer { - if retryBackoffHook != nil { - return retryBackoffHook(d) - } - return time.NewTimer(d) -} - // RoundTripOpt is like RoundTrip, but takes options. func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { @@ -479,13 +490,13 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) d := time.Second * time.Duration(backoff) - timer := backoffNewTimer(d) + tm := t.newTimer(d) select { - case <-timer.C: + case <-tm.C(): t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue case <-req.Context().Done(): - timer.Stop() + tm.Stop() err = req.Context().Err() } } @@ -700,10 +711,6 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), } - if d := t.IdleConnTimeout; d != 0 { - cc.idleTimeout = d - cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) - } if VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) } @@ -744,10 +751,6 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro // henc in response to SETTINGS frames? cc.henc = hpack.NewEncoder(&cc.hbuf) - if t.AllowHTTP { - cc.nextStreamID = 3 - } - if cs, ok := c.(connectionStater); ok { state := cs.ConnectionState() cc.tlsState = &state @@ -786,6 +789,12 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro return nil, cc.werr } + // Start the idle timer after the connection is fully initialized. + if d := t.IdleConnTimeout; d != 0 { + cc.idleTimeout = d + cc.idleTimer = t.afterFunc(d, cc.onIdleTimeout) + } + go cc.readLoop() return cc, nil } @@ -794,7 +803,7 @@ func (cc *ClientConn) healthCheck() { pingTimeout := cc.t.pingTimeout() // We don't need to periodically ping in the health check, because the readLoop of ClientConn will // trigger the healthCheck again if there is no frame received. - ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + ctx, cancel := cc.t.contextWithTimeout(context.Background(), pingTimeout) defer cancel() cc.vlogf("http2: Transport sending health check") err := cc.Ping(ctx) @@ -830,7 +839,20 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) { } last := f.LastStreamID for streamID, cs := range cc.streams { - if streamID > last { + if streamID <= last { + // The server's GOAWAY indicates that it received this stream. + // It will either finish processing it, or close the connection + // without doing so. Either way, leave the stream alone for now. + continue + } + if streamID == 1 && cc.goAway.ErrCode != ErrCodeNo { + // Don't retry the first stream on a connection if we get a non-NO error. + // If the server is sending an error on a new connection, + // retrying the request on a new one probably isn't going to work. + cs.abortStreamLocked(fmt.Errorf("http2: Transport received GOAWAY from server ErrCode:%v", cc.goAway.ErrCode)) + } else { + // Aborting the stream with errClentConnGotGoAway indicates that + // the request should be retried on a new connection. cs.abortStreamLocked(errClientConnGotGoAway) } } @@ -1151,6 +1173,10 @@ func (cc *ClientConn) decrStreamReservationsLocked() { } func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + return cc.roundTrip(req, nil) +} + +func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) { if cc.t != nil && cc.t.Debugf != nil { cc.t.Debugf("HTTP/2 %s %s", req.Method, req.URL.String()) } @@ -1169,7 +1195,27 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { respHeaderRecv: make(chan struct{}), donec: make(chan struct{}), } - go cs.doRequest(req) + + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? + if !cc.t.DisableCompression && + req.Header.Get("Accept-Encoding") == "" && + req.Header.Get("Range") == "" && + !cs.isHead { + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: https://zlib.net/zlib_faq.html#faq39 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // http://trac.nginx.org/nginx/ticket/358 + // https://golang.org/issue/5522 + // + // We don't request gzip if the request is for a range, since + // auto-decoding a portion of a gzipped document will just fail + // anyway. See https://golang.org/issue/8923 + cs.requestedGzip = true + } + go cs.doRequest(req, streamf) waitDone := func() error { select { @@ -1262,8 +1308,8 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { // doRequest runs for the duration of the request lifetime. // // It sends the request and performs post-request cleanup (closing Request.Body, etc.). -func (cs *clientStream) doRequest(req *http.Request) { - err := cs.writeRequest(req) +func (cs *clientStream) doRequest(req *http.Request, streamf func(*clientStream)) { + err := cs.writeRequest(req, streamf) cs.cleanupWriteRequest(err) } @@ -1274,7 +1320,7 @@ func (cs *clientStream) doRequest(req *http.Request) { // // It returns non-nil if the request ends otherwise. // If the returned error is StreamError, the error Code may be used in resetting the stream. -func (cs *clientStream) writeRequest(req *http.Request) (err error) { +func (cs *clientStream) writeRequest(req *http.Request, streamf func(*clientStream)) (err error) { cc := cs.cc ctx := cs.ctx @@ -1312,24 +1358,8 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { } cc.mu.Unlock() - // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? - if !cc.t.DisableCompression && - req.Header.Get("Accept-Encoding") == "" && - req.Header.Get("Range") == "" && - !cs.isHead { - // Request gzip only, not deflate. Deflate is ambiguous and - // not as universally supported anyway. - // See: https://zlib.net/zlib_faq.html#faq39 - // - // Note that we don't request this for HEAD requests, - // due to a bug in nginx: - // http://trac.nginx.org/nginx/ticket/358 - // https://golang.org/issue/5522 - // - // We don't request gzip if the request is for a range, since - // auto-decoding a portion of a gzipped document will just fail - // anyway. See https://golang.org/issue/8923 - cs.requestedGzip = true + if streamf != nil { + streamf(cs) } continueTimeout := cc.t.ExpectContinueTimeout @@ -1406,9 +1436,9 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { var respHeaderTimer <-chan time.Time var respHeaderRecv chan struct{} if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) + timer := cc.t.newTimer(d) defer timer.Stop() - respHeaderTimer = timer.C + respHeaderTimer = timer.C() respHeaderRecv = cs.respHeaderRecv } // Wait until the peer half-closes its end of the stream, @@ -1839,6 +1869,22 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) } } +func validateHeaders(hdrs http.Header) string { + for k, vv := range hdrs { + if !httpguts.ValidHeaderFieldName(k) { + return fmt.Sprintf("name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // Don't include the value in the error, + // because it may be sensitive. + return fmt.Sprintf("value for header %q", k) + } + } + } + return "" +} + var errNilRequestURL = errors.New("http2: Request.URI is nil") // requires cc.wmu be held. @@ -1875,19 +1921,14 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail } } - // Check for any invalid headers and return an error before we + // Check for any invalid headers+trailers and return an error before we // potentially pollute our hpack state. (We want to be able to // continue to reuse the hpack encoder for future requests) - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - return nil, fmt.Errorf("invalid HTTP header name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - // Don't include the value in the error, because it may be sensitive. - return nil, fmt.Errorf("invalid HTTP header value for header %q", k) - } - } + if err := validateHeaders(req.Header); err != "" { + return nil, fmt.Errorf("invalid HTTP header %s", err) + } + if err := validateHeaders(req.Trailer); err != "" { + return nil, fmt.Errorf("invalid HTTP trailer %s", err) } enumerateHeaders := func(f func(name, value string)) { @@ -2307,10 +2348,9 @@ func (rl *clientConnReadLoop) run() error { cc := rl.cc gotSettings := false readIdleTimeout := cc.t.ReadIdleTimeout - var t *time.Timer + var t timer if readIdleTimeout != 0 { - t = time.AfterFunc(readIdleTimeout, cc.healthCheck) - defer t.Stop() + t = cc.t.afterFunc(readIdleTimeout, cc.healthCheck) } for { f, err := cc.fr.ReadFrame() @@ -2753,7 +2793,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { }) return nil } - if !cs.firstByte { + if !cs.pastHeaders { cc.logf("protocol error: received DATA before a HEADERS frame") rl.endStreamError(cs, StreamError{ StreamID: f.StreamID, @@ -3031,24 +3071,25 @@ func (cc *ClientConn) Ping(ctx context.Context) error { } cc.mu.Unlock() } - errc := make(chan error, 1) + var pingError error + errc := make(chan struct{}) go func() { cc.wmu.Lock() defer cc.wmu.Unlock() - if err := cc.fr.WritePing(false, p); err != nil { - errc <- err + if pingError = cc.fr.WritePing(false, p); pingError != nil { + close(errc) return } - if err := cc.bw.Flush(); err != nil { - errc <- err + if pingError = cc.bw.Flush(); pingError != nil { + close(errc) return } }() select { case <-c: return nil - case err := <-errc: - return err + case <-errc: + return pingError case <-ctx.Done(): return ctx.Err() case <-cc.readerDone: From 8d3a13408832a108b8d17f13f4bbd4ef2d0e11bc Mon Sep 17 00:00:00 2001 From: roc Date: Sat, 21 Sep 2024 14:23:36 +0800 Subject: [PATCH 831/843] copy more tls config to utls (fix #387) --- client.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index dc7016d1..360a5df0 100644 --- a/client.go +++ b/client.go @@ -1198,11 +1198,21 @@ func (c *Client) SetTLSFingerprint(clientHelloID utls.ClientHelloID) *Client { colonPos = len(addr) } hostname := addr[:colonPos] + tlsConfig := c.GetTLSClientConfig() utlsConfig := &utls.Config{ - ServerName: hostname, - RootCAs: c.GetTLSClientConfig().RootCAs, - NextProtos: c.GetTLSClientConfig().NextProtos, - InsecureSkipVerify: c.GetTLSClientConfig().InsecureSkipVerify, + ServerName: hostname, + Rand: tlsConfig.Rand, + Time: tlsConfig.Time, + RootCAs: tlsConfig.RootCAs, + NextProtos: tlsConfig.NextProtos, + ClientCAs: tlsConfig.ClientCAs, + InsecureSkipVerify: tlsConfig.InsecureSkipVerify, + CipherSuites: tlsConfig.CipherSuites, + SessionTicketsDisabled: tlsConfig.SessionTicketsDisabled, + MinVersion: tlsConfig.MinVersion, + MaxVersion: tlsConfig.MaxVersion, + DynamicRecordSizingDisabled: tlsConfig.DynamicRecordSizingDisabled, + KeyLogWriter: tlsConfig.KeyLogWriter, } uconn := &uTLSConn{utls.UClient(plainConn, utlsConfig, clientHelloID)} err = uconn.HandshakeContext(ctx) From cfff2507b184d921cca82594b680a6a7028e8263 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Oct 2024 14:13:01 +0800 Subject: [PATCH 832/843] Support ordered form data (#391) --- middleware.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++----- request.go | 35 +++++++++++++++++++++-------------- 2 files changed, 67 insertions(+), 19 deletions(-) diff --git a/middleware.go b/middleware.go index c310cd3d..438be661 100644 --- a/middleware.go +++ b/middleware.go @@ -2,6 +2,7 @@ package req import ( "bytes" + "errors" "io" "mime/multipart" "net/http" @@ -124,9 +125,22 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload, r *Request) e func writeMultiPart(r *Request, w *multipart.Writer) { defer w.Close() // close multipart to write tailer boundary - for k, vs := range r.FormData { - for _, v := range vs { - w.WriteField(k, v) + if len(r.FormData) > 0 { + for k, vs := range r.FormData { + for _, v := range vs { + w.WriteField(k, v) + } + } + } else if len(r.OrderedFormData) > 0 { + if len(r.OrderedFormData)%2 != 0 { + r.error = errBadOrderedFormData + return + } + maxIndex := len(r.OrderedFormData) - 2 + for i := 0; i <= maxIndex; i += 2 { + key := r.OrderedFormData[i] + value := r.OrderedFormData[i+1] + w.WriteField(key, value) } } for _, file := range r.uploadFiles { @@ -134,7 +148,7 @@ func writeMultiPart(r *Request, w *multipart.Writer) { } } -func handleMultiPart(c *Client, r *Request) (err error) { +func handleMultiPart(r *Request) (err error) { if r.forceChunkedEncoding { pr, pw := io.Pipe() r.GetBody = func() (io.ReadCloser, error) { @@ -164,6 +178,29 @@ func handleFormData(r *Request) { r.SetBodyBytes([]byte(r.FormData.Encode())) } +var errBadOrderedFormData = errors.New("bad ordered form data, the number of key-value pairs should be an even number") + +func handleOrderedFormData(r *Request) { + r.SetContentType(header.FormContentType) + if len(r.OrderedFormData)%2 != 0 { + r.error = errBadOrderedFormData + return + } + maxIndex := len(r.OrderedFormData) - 2 + var buf strings.Builder + for i := 0; i <= maxIndex; i += 2 { + key := r.OrderedFormData[i] + value := r.OrderedFormData[i+1] + if buf.Len() > 0 { + buf.WriteByte('&') + } + buf.WriteString(url.QueryEscape(key)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(value)) + } + r.SetBodyString(buf.String()) +} + func handleMarshalBody(c *Client, r *Request) error { ct := "" if r.Headers != nil { @@ -205,16 +242,20 @@ func parseRequestBody(c *Client, r *Request) (err error) { } // handle multipart if r.isMultiPart { - return handleMultiPart(c, r) + return handleMultiPart(r) } // handle form data if len(c.FormData) > 0 { r.SetFormDataFromValues(c.FormData) } + if len(r.FormData) > 0 { handleFormData(r) return + } else if len(r.OrderedFormData) > 0 { + handleOrderedFormData(r) + return } // handle marshal body diff --git a/request.go b/request.go index 7696d881..c8bea419 100644 --- a/request.go +++ b/request.go @@ -25,20 +25,21 @@ import ( // req client. Request provides lots of chainable settings which can // override client level settings. type Request struct { - PathParams map[string]string - QueryParams urlpkg.Values - FormData urlpkg.Values - Headers http.Header - Cookies []*http.Cookie - Result interface{} - Error interface{} - RawRequest *http.Request - StartTime time.Time - RetryAttempt int - RawURL string // read only - Method string - Body []byte - GetBody GetContentFunc + PathParams map[string]string + QueryParams urlpkg.Values + FormData urlpkg.Values + OrderedFormData []string + Headers http.Header + Cookies []*http.Cookie + Result interface{} + Error interface{} + RawRequest *http.Request + StartTime time.Time + RetryAttempt int + RawURL string // read only + Method string + Body []byte + GetBody GetContentFunc // URL is an auto-generated field, and is nil in request middleware (OnBeforeRequest), // consider using RawURL if you want, it's not nil in client middleware (WrapRoundTripFunc) URL *urlpkg.URL @@ -187,6 +188,12 @@ func (r *Request) SetFormData(data map[string]string) *Request { return r } +// SetOrderedFormData set the ordered form data from key-values pairs. +func (r *Request) SetOrderedFormData(kvs ...string) *Request { + r.OrderedFormData = append(r.OrderedFormData, kvs...) + return r +} + // SetFormDataAnyType set the form data from a map, which value could be any type, // will convert to string automatically. // It will not been used if request method does not allow payload. From 5f9bf49af01a3beaabf0c5e4a19b8f8aa80eb01a Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 10 Oct 2024 14:15:27 +0800 Subject: [PATCH 833/843] Add global wrapper for Request.SetOrderedFormData --- request_wrapper.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/request_wrapper.go b/request_wrapper.go index fec48b40..cdc3f685 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -26,6 +26,12 @@ func SetFormData(data map[string]string) *Request { return defaultClient.R().SetFormData(data) } +// SetOrderedFormData is a global wrapper methods which delegated +// to the default client, create a request and SetOrderedFormData for request. +func SetOrderedFormData(kvs ...string) *Request { + return defaultClient.R().SetOrderedFormData(kvs...) +} + // SetFormDataAnyType is a global wrapper methods which delegated // to the default client, create a request and SetFormDataAnyType for request. func SetFormDataAnyType(data map[string]interface{}) *Request { From 5e4950e6dc7f94d9fa670b7b3f1245fae63ddeff Mon Sep 17 00:00:00 2001 From: rosahaj <141790572+rosahaj@users.noreply.github.com> Date: Sun, 13 Oct 2024 18:49:16 +0200 Subject: [PATCH 834/843] Add client.SetMultipartBoundaryFunc and port Blink/WebKit/Firefox implementations --- client.go | 12 +++++++++ client_impersonate.go | 61 ++++++++++++++++++++++++++++++++++++++++--- client_test.go | 31 ++++++++++++++++++++++ client_wrapper.go | 23 ++++++++++++++-- middleware.go | 15 +++++++++-- response.go | 5 ++-- 6 files changed, 138 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index 360a5df0..1b79591c 100644 --- a/client.go +++ b/client.go @@ -59,6 +59,7 @@ type Client struct { jsonUnmarshal func(data []byte, v interface{}) error xmlMarshal func(v interface{}) ([]byte, error) xmlUnmarshal func(data []byte, v interface{}) error + multipartBoundaryFunc func() string outputDirectory string scheme string log Logger @@ -239,6 +240,17 @@ func (c *Client) SetCommonFormData(data map[string]string) *Client { return c } +// SetMultipartBoundaryFunc overrides the default function used to generate +// boundary delimiters for "multipart/form-data" requests with a customized one, +// which returns a boundary delimiter (without the two leading hyphens). +// +// Boundary delimiter may only contain certain ASCII characters, and must be +// non-empty and at most 70 bytes long (see RFC 2046, Section 5.1.1). +func (c *Client) SetMultipartBoundaryFunc(fn func() string) *Client { + c.multipartBoundaryFunc = fn + return c +} + // SetBaseURL set the default base URL, will be used if request URL is // a relative URL. func (c *Client) SetBaseURL(u string) *Client { diff --git a/client_impersonate.go b/client_impersonate.go index 581ca08f..e2fd94c8 100644 --- a/client_impersonate.go +++ b/client_impersonate.go @@ -1,10 +1,56 @@ package req import ( + "crypto/rand" + "encoding/binary" + "math/big" + "strconv" + "strings" + "github.com/imroc/req/v3/http2" utls "github.com/refraction-networking/utls" ) +// Identical for both Blink-based browsers (Chrome, Chromium, etc.) and WebKit-based browsers (Safari, etc.) +// Blink implementation: https://source.chromium.org/chromium/chromium/src/+/main:third_party/blink/renderer/platform/network/form_data_encoder.cc;drc=1d694679493c7b2f7b9df00e967b4f8699321093;l=130 +// WebKit implementation: https://github.com/WebKit/WebKit/blob/47eea119fe9462721e5cc75527a4280c6d5f5214/Source/WebCore/platform/network/FormDataBuilder.cpp#L120 +func webkitMultipartBoundaryFunc() string { + const letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789AB" + + sb := strings.Builder{} + sb.WriteString("----WebKitFormBoundary") + + for i := 0; i < 16; i++ { + index, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters)-1))) + if err != nil { + panic(err) + } + + sb.WriteByte(letters[index.Int64()]) + } + + return sb.String() +} + +// Firefox implementation: https://searchfox.org/mozilla-central/source/dom/html/HTMLFormSubmission.cpp#355 +func firefoxMultipartBoundaryFunc() string { + sb := strings.Builder{} + sb.WriteString("-------------------------") + + for i := 0; i < 3; i++ { + var b [8]byte + if _, err := rand.Read(b[:]); err != nil { + panic(err) + } + u32 := binary.LittleEndian.Uint32(b[:]) + s := strconv.FormatUint(uint64(u32), 10) + + sb.WriteString(s) + } + + return sb.String() +} + var ( chromeHttp2Settings = []http2.Setting{ { @@ -71,6 +117,7 @@ var ( "sec-fetch-dest": "document", "accept-language": "zh-CN,zh;q=0.9,en;q=0.8,zh-TW;q=0.7,it;q=0.6", } + chromeHeaderPriority = http2.PriorityParam{ StreamDep: 0, Exclusive: true, @@ -87,7 +134,8 @@ func (c *Client) ImpersonateChrome() *Client { SetCommonPseudoHeaderOder(chromePseudoHeaderOrder...). SetCommonHeaderOrder(chromeHeaderOrder...). SetCommonHeaders(chromeHeaders). - SetHTTP2HeaderPriority(chromeHeaderPriority) + SetHTTP2HeaderPriority(chromeHeaderPriority). + SetMultipartBoundaryFunc(webkitMultipartBoundaryFunc) return c } @@ -106,6 +154,7 @@ var ( Val: 16384, }, } + firefoxPriorityFrames = []http2.PriorityFrame{ { StreamID: 3, @@ -156,12 +205,14 @@ var ( }, }, } + firefoxPseudoHeaderOrder = []string{ ":method", ":path", ":authority", ":scheme", } + firefoxHeaderOrder = []string{ "user-agent", "accept", @@ -176,6 +227,7 @@ var ( "sec-fetch-user", "te", } + firefoxHeaders = map[string]string{ "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:105.0) Gecko/20100101 Firefox/105.0", "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", @@ -187,6 +239,7 @@ var ( "sec-fetch-user": "?1", //"te": "trailers", } + firefoxHeaderPriority = http2.PriorityParam{ StreamDep: 13, Exclusive: false, @@ -204,7 +257,8 @@ func (c *Client) ImpersonateFirefox() *Client { SetCommonPseudoHeaderOder(firefoxPseudoHeaderOrder...). SetCommonHeaderOrder(firefoxHeaderOrder...). SetCommonHeaders(firefoxHeaders). - SetHTTP2HeaderPriority(firefoxHeaderPriority) + SetHTTP2HeaderPriority(firefoxHeaderPriority). + SetMultipartBoundaryFunc(firefoxMultipartBoundaryFunc) return c } @@ -264,6 +318,7 @@ func (c *Client) ImpersonateSafari() *Client { SetCommonPseudoHeaderOder(safariPseudoHeaderOrder...). SetCommonHeaderOrder(safariHeaderOrder...). SetCommonHeaders(safariHeaders). - SetHTTP2HeaderPriority(safariHeaderPriority) + SetHTTP2HeaderPriority(safariHeaderPriority). + SetMultipartBoundaryFunc(webkitMultipartBoundaryFunc) return c } diff --git a/client_test.go b/client_test.go index b87745cb..e9e9f75e 100644 --- a/client_test.go +++ b/client_test.go @@ -5,12 +5,14 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "net" "net/http" "net/http/cookiejar" "net/url" "os" + "regexp" "strings" "testing" "time" @@ -468,6 +470,35 @@ func TestSetCommonFormData(t *testing.T) { tests.AssertEqual(t, "test", form.Get("test")) } +func TestSetMultipartBoundaryFunc(t *testing.T) { + delimiter := "test-delimiter" + expectedContentType := fmt.Sprintf("multipart/form-data; boundary=%s", delimiter) + resp, err := tc(). + SetMultipartBoundaryFunc(func() string { + return delimiter + }).R(). + EnableForceMultipart(). + SetFormData( + map[string]string{ + "test": "test", + }). + Post("/content-type") + assertSuccess(t, resp, err) + tests.AssertEqual(t, expectedContentType, resp.String()) +} + +func TestFirefoxMultipartBoundaryFunc(t *testing.T) { + r := regexp.MustCompile(`^-------------------------\d{1,10}\d{1,10}\d{1,10}$`) + b := firefoxMultipartBoundaryFunc() + tests.AssertEqual(t, true, r.MatchString(b)) +} + +func TestWebkitMultipartBoundaryFunc(t *testing.T) { + r := regexp.MustCompile(`^----WebKitFormBoundary[0-9a-zA-Z]{16}$`) + b := webkitMultipartBoundaryFunc() + tests.AssertEqual(t, true, r.MatchString(b)) +} + func TestClientClone(t *testing.T) { c1 := tc().DevMode(). SetCommonHeader("test", "test"). diff --git a/client_wrapper.go b/client_wrapper.go index 2c1136af..1e773d97 100644 --- a/client_wrapper.go +++ b/client_wrapper.go @@ -3,13 +3,14 @@ package req import ( "context" "crypto/tls" - "github.com/imroc/req/v3/http2" - utls "github.com/refraction-networking/utls" "io" "net" "net/http" "net/url" "time" + + "github.com/imroc/req/v3/http2" + utls "github.com/refraction-networking/utls" ) // WrapRoundTrip is a global wrapper methods which delegated @@ -56,6 +57,12 @@ func SetCommonFormData(data map[string]string) *Client { return defaultClient.SetCommonFormData(data) } +// SetMultipartBoundaryFunc is a global wrapper methods which delegated +// to the default client's Client.SetMultipartBoundaryFunc. +func SetMultipartBoundaryFunc(fn func() string) *Client { + return defaultClient.SetMultipartBoundaryFunc(fn) +} + // SetBaseURL is a global wrapper methods which delegated // to the default client's Client.SetBaseURL. func SetBaseURL(u string) *Client { @@ -482,6 +489,18 @@ func ImpersonateChrome() *Client { return defaultClient.ImpersonateChrome() } +// ImpersonateChrome is a global wrapper methods which delegated +// to the default client's Client.ImpersonateChrome. +func ImpersonateFirefox() *Client { + return defaultClient.ImpersonateFirefox() +} + +// ImpersonateChrome is a global wrapper methods which delegated +// to the default client's Client.ImpersonateChrome. +func ImpersonateSafari() *Client { + return defaultClient.ImpersonateFirefox() +} + // SetCommonContentType is a global wrapper methods which delegated // to the default client's Client.SetCommonContentType. func SetCommonContentType(ct string) *Client { diff --git a/middleware.go b/middleware.go index 438be661..df5d9d9f 100644 --- a/middleware.go +++ b/middleware.go @@ -148,13 +148,21 @@ func writeMultiPart(r *Request, w *multipart.Writer) { } } -func handleMultiPart(r *Request) (err error) { +func handleMultiPart(c *Client, r *Request) (err error) { + var b string + if c.multipartBoundaryFunc != nil { + b = c.multipartBoundaryFunc() + } + if r.forceChunkedEncoding { pr, pw := io.Pipe() r.GetBody = func() (io.ReadCloser, error) { return pr, nil } w := multipart.NewWriter(pw) + if len(b) > 0 { + w.SetBoundary(b) + } r.SetContentType(w.FormDataContentType()) go func() { writeMultiPart(r, w) @@ -163,6 +171,9 @@ func handleMultiPart(r *Request) (err error) { } else { buf := new(bytes.Buffer) w := multipart.NewWriter(buf) + if len(b) > 0 { + w.SetBoundary(b) + } writeMultiPart(r, w) r.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(buf.Bytes())), nil @@ -242,7 +253,7 @@ func parseRequestBody(c *Client, r *Request) (err error) { } // handle multipart if r.isMultiPart { - return handleMultiPart(r) + return handleMultiPart(c, r) } // handle form data diff --git a/response.go b/response.go index 65b9f521..0e51900a 100644 --- a/response.go +++ b/response.go @@ -1,12 +1,13 @@ package req import ( - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/util" "io" "net/http" "strings" "time" + + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/util" ) // Response is the http response. From 87ad469c9402fef3632a864514f7e238fa28b24a Mon Sep 17 00:00:00 2001 From: rosahaj <141790572+rosahaj@users.noreply.github.com> Date: Sun, 1 Dec 2024 21:27:43 +0100 Subject: [PATCH 835/843] Add DefaultRedirectPolicy --- client.go | 16 +++------------- client_test.go | 4 ++++ redirect.go | 5 +++++ 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 1b79591c..b1eea7d9 100644 --- a/client.go +++ b/client.go @@ -321,20 +321,10 @@ func (c *Client) GetTLSClientConfig() *tls.Config { return c.TLSClientConfig } -func (c *Client) defaultCheckRedirect(req *http.Request, via []*http.Request) error { - if len(via) >= 10 { - return errors.New("stopped after 10 redirects") - } - if c.DebugLog { - c.log.Debugf(" %s %s", req.Method, req.URL.String()) - } - return nil -} - // SetRedirectPolicy set the RedirectPolicy which controls the behavior of receiving redirect // responses (usually responses with 301 and 302 status code), see the predefined -// AllowedDomainRedirectPolicy, AllowedHostRedirectPolicy, MaxRedirectPolicy, NoRedirectPolicy, -// SameDomainRedirectPolicy and SameHostRedirectPolicy. +// AllowedDomainRedirectPolicy, AllowedHostRedirectPolicy, DefaultRedirectPolicy, MaxRedirectPolicy, +// NoRedirectPolicy, SameDomainRedirectPolicy and SameHostRedirectPolicy. func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { if len(policies) == 0 { return c @@ -1565,7 +1555,7 @@ func C() *Client { xmlUnmarshal: xml.Unmarshal, cookiejarFactory: memoryCookieJarFactory, } - httpClient.CheckRedirect = c.defaultCheckRedirect + c.SetRedirectPolicy(DefaultRedirectPolicy()) c.initCookieJar() c.initTransport() diff --git a/client_test.go b/client_test.go index e9e9f75e..7a6aeebe 100644 --- a/client_test.go +++ b/client_test.go @@ -369,6 +369,10 @@ func TestRedirect(t *testing.T) { tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "stopped after 3 redirects", true) + _, err = tc().SetRedirectPolicy(MaxRedirectPolicy(20)).SetRedirectPolicy(DefaultRedirectPolicy()).R().Get("/unlimited-redirect") + tests.AssertNotNil(t, err) + tests.AssertContains(t, err.Error(), "stopped after 10 redirects", true) + _, err = tc().SetRedirectPolicy(SameDomainRedirectPolicy()).R().Get("/redirect-to-other") tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "different domain name is not allowed", true) diff --git a/redirect.go b/redirect.go index f1cc4331..fcc13e4b 100644 --- a/redirect.go +++ b/redirect.go @@ -21,6 +21,11 @@ func MaxRedirectPolicy(noOfRedirect int) RedirectPolicy { } } +// DefaultRedirectPolicy allows up to 10 redirects +func DefaultRedirectPolicy() RedirectPolicy { + return MaxRedirectPolicy(10) +} + // NoRedirectPolicy disable redirect behaviour func NoRedirectPolicy() RedirectPolicy { return func(req *http.Request, via []*http.Request) error { From 3d70b6b5a5aa0c4dbe958fdd411b1bd24991de26 Mon Sep 17 00:00:00 2001 From: simonren-tes Date: Wed, 4 Dec 2024 17:58:11 +0800 Subject: [PATCH 836/843] feat: upgrade quic-go to v0.48.2, fix CVE-2024-53259 --- examples/find-popular-repo/go.mod | 14 ++--- examples/find-popular-repo/go.sum | 12 ++++ examples/opentelemetry-jaeger-tracing/go.mod | 14 ++--- examples/opentelemetry-jaeger-tracing/go.sum | 12 ++++ examples/upload/uploadclient/go.mod | 14 ++--- examples/upload/uploadclient/go.sum | 12 ++++ examples/uploadcallback/uploadclient/go.mod | 14 ++--- examples/uploadcallback/uploadclient/go.sum | 12 ++++ go.mod | 24 ++++---- go.sum | 59 +++++++++----------- 10 files changed, 114 insertions(+), 73 deletions(-) diff --git a/examples/find-popular-repo/go.mod b/examples/find-popular-repo/go.mod index 430089f7..890f1462 100644 --- a/examples/find-popular-repo/go.mod +++ b/examples/find-popular-repo/go.mod @@ -1,6 +1,6 @@ module find-popular-repo -go 1.21 +go 1.22.0 toolchain go1.22.3 @@ -21,11 +21,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.1 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.21.0 // indirect - golang.org/x/mod v0.16.0 // indirect - golang.org/x/net v0.22.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.19.0 // indirect + golang.org/x/crypto v0.29.0 // indirect + golang.org/x/mod v0.22.0 // indirect + golang.org/x/net v0.31.0 // indirect + golang.org/x/sys v0.27.0 // indirect + golang.org/x/text v0.20.0 // indirect + golang.org/x/tools v0.27.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/find-popular-repo/go.sum b/examples/find-popular-repo/go.sum index 382380ce..b70f2661 100644 --- a/examples/find-popular-repo/go.sum +++ b/examples/find-popular-repo/go.sum @@ -171,6 +171,8 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -181,6 +183,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVD golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -204,6 +208,8 @@ golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfS golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -243,6 +249,8 @@ golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -254,6 +262,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -271,6 +281,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= +golang.org/x/tools v0.27.0/go.mod h1:sUi0ZgbwW9ZPAq26Ekut+weQPR5eIM6GQLQ1Yjm1H0Q= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/examples/opentelemetry-jaeger-tracing/go.mod b/examples/opentelemetry-jaeger-tracing/go.mod index 09afcab2..6454f319 100644 --- a/examples/opentelemetry-jaeger-tracing/go.mod +++ b/examples/opentelemetry-jaeger-tracing/go.mod @@ -1,6 +1,6 @@ module opentelemetry-jaeger-tracing -go 1.21 +go 1.22.0 toolchain go1.22.3 @@ -30,11 +30,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.21.0 // indirect - golang.org/x/mod v0.16.0 // indirect - golang.org/x/net v0.22.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.19.0 // indirect + golang.org/x/crypto v0.29.0 // indirect + golang.org/x/mod v0.22.0 // indirect + golang.org/x/net v0.31.0 // indirect + golang.org/x/sys v0.27.0 // indirect + golang.org/x/text v0.20.0 // indirect + golang.org/x/tools v0.27.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/opentelemetry-jaeger-tracing/go.sum b/examples/opentelemetry-jaeger-tracing/go.sum index 6ca9b455..768df1c9 100644 --- a/examples/opentelemetry-jaeger-tracing/go.sum +++ b/examples/opentelemetry-jaeger-tracing/go.sum @@ -182,6 +182,8 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -192,6 +194,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVD golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -218,6 +222,8 @@ golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfS golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -261,6 +267,8 @@ golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -273,6 +281,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -289,6 +299,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= +golang.org/x/tools v0.27.0/go.mod h1:sUi0ZgbwW9ZPAq26Ekut+weQPR5eIM6GQLQ1Yjm1H0Q= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/examples/upload/uploadclient/go.mod b/examples/upload/uploadclient/go.mod index 8913edf1..e338f5cb 100644 --- a/examples/upload/uploadclient/go.mod +++ b/examples/upload/uploadclient/go.mod @@ -1,6 +1,6 @@ module uploadclient -go 1.21 +go 1.22.0 toolchain go1.22.3 @@ -22,11 +22,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.21.0 // indirect - golang.org/x/mod v0.16.0 // indirect - golang.org/x/net v0.22.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.19.0 // indirect + golang.org/x/crypto v0.29.0 // indirect + golang.org/x/mod v0.22.0 // indirect + golang.org/x/net v0.31.0 // indirect + golang.org/x/sys v0.27.0 // indirect + golang.org/x/text v0.20.0 // indirect + golang.org/x/tools v0.27.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/upload/uploadclient/go.sum b/examples/upload/uploadclient/go.sum index 3d8d9499..b445b360 100644 --- a/examples/upload/uploadclient/go.sum +++ b/examples/upload/uploadclient/go.sum @@ -170,6 +170,8 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -180,6 +182,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVD golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -208,6 +212,8 @@ golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfS golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -250,6 +256,8 @@ golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -262,6 +270,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -278,6 +288,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= +golang.org/x/tools v0.27.0/go.mod h1:sUi0ZgbwW9ZPAq26Ekut+weQPR5eIM6GQLQ1Yjm1H0Q= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/examples/uploadcallback/uploadclient/go.mod b/examples/uploadcallback/uploadclient/go.mod index 8913edf1..e338f5cb 100644 --- a/examples/uploadcallback/uploadclient/go.mod +++ b/examples/uploadcallback/uploadclient/go.mod @@ -1,6 +1,6 @@ module uploadclient -go 1.21 +go 1.22.0 toolchain go1.22.3 @@ -22,11 +22,11 @@ require ( github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.5 // indirect - golang.org/x/crypto v0.21.0 // indirect - golang.org/x/mod v0.16.0 // indirect - golang.org/x/net v0.22.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.19.0 // indirect + golang.org/x/crypto v0.29.0 // indirect + golang.org/x/mod v0.22.0 // indirect + golang.org/x/net v0.31.0 // indirect + golang.org/x/sys v0.27.0 // indirect + golang.org/x/text v0.20.0 // indirect + golang.org/x/tools v0.27.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/examples/uploadcallback/uploadclient/go.sum b/examples/uploadcallback/uploadclient/go.sum index 3d8d9499..b445b360 100644 --- a/examples/uploadcallback/uploadclient/go.sum +++ b/examples/uploadcallback/uploadclient/go.sum @@ -170,6 +170,8 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -180,6 +182,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVD golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -208,6 +212,8 @@ golang.org/x/net v0.0.0-20220809012201-f428fae20770/go.mod h1:YDH+HFinaLZZlnHAfS golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -250,6 +256,8 @@ golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -262,6 +270,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -278,6 +288,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= +golang.org/x/tools v0.27.0/go.mod h1:sUi0ZgbwW9ZPAq26Ekut+weQPR5eIM6GQLQ1Yjm1H0Q= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/go.mod b/go.mod index 87ce144f..dc69fa45 100644 --- a/go.mod +++ b/go.mod @@ -7,23 +7,23 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/klauspost/compress v1.17.9 github.com/quic-go/qpack v0.5.1 - github.com/quic-go/quic-go v0.47.0 + github.com/quic-go/quic-go v0.48.2 github.com/refraction-networking/utls v1.6.7 - golang.org/x/net v0.29.0 - golang.org/x/text v0.18.0 + golang.org/x/net v0.31.0 + golang.org/x/text v0.20.0 ) require ( github.com/cloudflare/circl v1.4.0 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect - github.com/google/pprof v0.0.0-20240910150728-a0b0bb1d4134 // indirect + github.com/google/pprof v0.0.0-20241203143554-1e3fdc7de467 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/onsi/ginkgo/v2 v2.20.2 // indirect - go.uber.org/mock v0.4.0 // indirect - golang.org/x/crypto v0.27.0 // indirect - golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect - golang.org/x/mod v0.21.0 // indirect - golang.org/x/sync v0.8.0 // indirect - golang.org/x/sys v0.25.0 // indirect - golang.org/x/tools v0.25.0 // indirect + github.com/onsi/ginkgo/v2 v2.22.0 // indirect + go.uber.org/mock v0.5.0 // indirect + golang.org/x/crypto v0.29.0 // indirect + golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f // indirect + golang.org/x/mod v0.22.0 // indirect + golang.org/x/sync v0.9.0 // indirect + golang.org/x/sys v0.27.0 // indirect + golang.org/x/tools v0.27.0 // indirect ) diff --git a/go.sum b/go.sum index 5976620d..3fa5fec8 100644 --- a/go.sum +++ b/go.sum @@ -6,15 +6,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/pprof v0.0.0-20240903155634-a8630aee4ab9 h1:q5g0N9eal4bmJwXHC5z0QCKs8qhS35hFfq0BAYsIwZI= -github.com/google/pprof v0.0.0-20240903155634-a8630aee4ab9/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= -github.com/google/pprof v0.0.0-20240910150728-a0b0bb1d4134 h1:c5FlPPgxOn7kJz3VoPLkQYQXGBS3EklQ4Zfi57uOuqQ= -github.com/google/pprof v0.0.0-20240910150728-a0b0bb1d4134/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= +github.com/google/pprof v0.0.0-20241203143554-1e3fdc7de467 h1:keEZFtbLJugfE0qHn+Ge1JCE71spzkchQobDf3mzS/4= +github.com/google/pprof v0.0.0-20241203143554-1e3fdc7de467/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -22,44 +19,40 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= -github.com/onsi/ginkgo/v2 v2.20.2 h1:7NVCeyIWROIAheY21RLS+3j2bb52W0W82tkberYytp4= -github.com/onsi/ginkgo/v2 v2.20.2/go.mod h1:K9gyxPIlb+aIvnZ8bd9Ak+YP18w3APlR+5coaZoE2ag= -github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= -github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= +github.com/onsi/ginkgo/v2 v2.22.0 h1:Yed107/8DjTr0lKCNt7Dn8yQ6ybuDRQoMGrNFKzMfHg= +github.com/onsi/ginkgo/v2 v2.22.0/go.mod h1:7Du3c42kxCUegi0IImZ1wUQzMBVecgIHjR1C+NkhLQo= +github.com/onsi/gomega v1.34.2 h1:pNCwDkzrsv7MS9kpaQvVb1aVLahQXyJ/Tv5oAZMI3i8= +github.com/onsi/gomega v1.34.2/go.mod h1:v1xfxRgk0KIsG+QOdm7p8UosrOzPYRo60fd3B/1Dukc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.47.0 h1:yXs3v7r2bm1wmPTYNLKAAJTHMYkPEsfYJmTazXrCZ7Y= -github.com/quic-go/quic-go v0.47.0/go.mod h1:3bCapYsJvXGZcipOHuu7plYtaV6tnF+z7wIFsU0WK9E= +github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= +github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/refraction-networking/utls v1.6.7 h1:zVJ7sP1dJx/WtVuITug3qYUq034cDq9B2MR1K67ULZM= github.com/refraction-networking/utls v1.6.7/go.mod h1:BC3O4vQzye5hqpmDTWUqi4P5DDhzJfkV1tdqtawQIH0= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= -go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= -golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= -golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= -golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e h1:I88y4caeGeuDQxgdoFPUq097j7kNfw6uvuiNxUBfcBk= -golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= -golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= -golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= -golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= -golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= -golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= -golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= +golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo= +golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak= +golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= -golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= -golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= -golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= +golang.org/x/tools v0.27.0 h1:qEKojBykQkQ4EynWy4S8Weg69NumxKdn40Fce3uc/8o= +golang.org/x/tools v0.27.0/go.mod h1:sUi0ZgbwW9ZPAq26Ekut+weQPR5eIM6GQLQ1Yjm1H0Q= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= From abcca35dc3af43f1958002336a06c8add261a684 Mon Sep 17 00:00:00 2001 From: rosahaj <141790572+rosahaj@users.noreply.github.com> Date: Thu, 5 Dec 2024 04:10:27 +0100 Subject: [PATCH 837/843] Update client.ImpersonateXXX methods --- client_impersonate.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/client_impersonate.go b/client_impersonate.go index e2fd94c8..dfa9235e 100644 --- a/client_impersonate.go +++ b/client_impersonate.go @@ -105,17 +105,17 @@ var ( chromeHeaders = map[string]string{ "pragma": "no-cache", "cache-control": "no-cache", - "sec-ch-ua": `"Not_A Brand";v="99", "Google Chrome";v="109", "Chromium";v="109"`, + "sec-ch-ua": `"Not_A Brand";v="8", "Chromium";v="120", "Google Chrome";v="120"`, "sec-ch-ua-mobile": "?0", "sec-ch-ua-platform": `"macOS"`, "upgrade-insecure-requests": "1", - "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36", - "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", "sec-fetch-site": "none", "sec-fetch-mode": "navigate", "sec-fetch-user": "?1", "sec-fetch-dest": "document", - "accept-language": "zh-CN,zh;q=0.9,en;q=0.8,zh-TW;q=0.7,it;q=0.6", + "accept-language": "zh-CN,zh;q=0.9", } chromeHeaderPriority = http2.PriorityParam{ @@ -125,10 +125,10 @@ var ( } ) -// ImpersonateChrome impersonates Chrome browser (version 109). +// ImpersonateChrome impersonates Chrome browser (version 120). func (c *Client) ImpersonateChrome() *Client { c. - SetTLSFingerprint(utls.HelloChrome_Auto). // Chrome 106~109 shares the same tls fingerprint. + SetTLSFingerprint(utls.HelloChrome_120). SetHTTP2SettingsFrame(chromeHttp2Settings...). SetHTTP2ConnectionFlow(15663105). SetCommonPseudoHeaderOder(chromePseudoHeaderOrder...). @@ -229,7 +229,7 @@ var ( } firefoxHeaders = map[string]string{ - "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:105.0) Gecko/20100101 Firefox/105.0", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:120.0) Gecko/20100101 Firefox/120.0", "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", "accept-language": "zh-CN,zh;q=0.8,zh-TW;q=0.7,zh-HK;q=0.5,en-US;q=0.3,en;q=0.2", "upgrade-insecure-requests": "1", @@ -247,10 +247,10 @@ var ( } ) -// ImpersonateFirefox impersonates Firefox browser (version 105). +// ImpersonateFirefox impersonates Firefox browser (version 120). func (c *Client) ImpersonateFirefox() *Client { c. - SetTLSFingerprint(utls.HelloFirefox_Auto). + SetTLSFingerprint(utls.HelloFirefox_120). SetHTTP2SettingsFrame(firefoxHttp2Settings...). SetHTTP2ConnectionFlow(12517377). SetHTTP2PriorityFrames(firefoxPriorityFrames...). @@ -309,10 +309,10 @@ var ( } ) -// ImpersonateSafari impersonates Safari browser (version 16). +// ImpersonateSafari impersonates Safari browser (version 16.6). func (c *Client) ImpersonateSafari() *Client { c. - SetTLSFingerprint(utls.HelloSafari_Auto). + SetTLSFingerprint(utls.HelloSafari_16_0). SetHTTP2SettingsFrame(safariHttp2Settings...). SetHTTP2ConnectionFlow(10485760). SetCommonPseudoHeaderOder(safariPseudoHeaderOrder...). From 8c6ee8ea41631495c515af0c354139f163637d64 Mon Sep 17 00:00:00 2001 From: god Date: Sat, 14 Dec 2024 14:48:02 +0800 Subject: [PATCH 838/843] =?UTF-8?q?feat(response):=20=E6=B7=BB=E5=8A=A0=20?= =?UTF-8?q?SetBody=20=E5=92=8C=20SetBodyStr=20=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在使用中间件对响应体进行加密/解密处理时,直接修改 Response.Body 无法生效,因为 body 已被读取到内存中。新增这两个方法可以直接 设置解密后的内容到 Response.body 字段,确保后续获取响应内容时 能正确获取到解密后的数据。 - 添加 SetBody 方法用于设置字节数组形式的响应体 - 添加 SetBodyStr 方法用于设置字符串形式的响应体 相关场景: - 响应体加密/解密中间件 - 响应内容统一处理 --- response.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/response.go b/response.go index 0e51900a..7f50f617 100644 --- a/response.go +++ b/response.go @@ -196,6 +196,16 @@ func (r *Response) Into(v interface{}) error { return r.Unmarshal(v) } +// Set response body with byte array content +func (r *Response) SetBody(body []byte) { + r.body = body +} + +// Set response body with string content +func (r *Response) SetBodyStr(body string) { + r.body = []byte(body) +} + // Bytes return the response body as []bytes that have already been read, could be // nil if not read, the following cases are already read: // 1. `Request.SetResult` or `Request.SetError` is called. From bc7cb8e32d2655f830783c475a8ae3355516aafa Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 17 Dec 2024 10:16:31 +0800 Subject: [PATCH 839/843] update go modules --- go.mod | 22 +++++++++++----------- go.sum | 24 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/go.mod b/go.mod index dc69fa45..6bef9175 100644 --- a/go.mod +++ b/go.mod @@ -3,27 +3,27 @@ module github.com/imroc/req/v3 go 1.22.0 require ( - github.com/andybalholm/brotli v1.1.0 + github.com/andybalholm/brotli v1.1.1 github.com/hashicorp/go-multierror v1.1.1 - github.com/klauspost/compress v1.17.9 + github.com/klauspost/compress v1.17.11 github.com/quic-go/qpack v0.5.1 github.com/quic-go/quic-go v0.48.2 github.com/refraction-networking/utls v1.6.7 - golang.org/x/net v0.31.0 - golang.org/x/text v0.20.0 + golang.org/x/net v0.32.0 + golang.org/x/text v0.21.0 ) require ( - github.com/cloudflare/circl v1.4.0 // indirect + github.com/cloudflare/circl v1.5.0 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect - github.com/google/pprof v0.0.0-20241203143554-1e3fdc7de467 // indirect + github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/onsi/ginkgo/v2 v2.22.0 // indirect go.uber.org/mock v0.5.0 // indirect - golang.org/x/crypto v0.29.0 // indirect - golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f // indirect + golang.org/x/crypto v0.31.0 // indirect + golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e // indirect golang.org/x/mod v0.22.0 // indirect - golang.org/x/sync v0.9.0 // indirect - golang.org/x/sys v0.27.0 // indirect - golang.org/x/tools v0.27.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/tools v0.28.0 // indirect ) diff --git a/go.sum b/go.sum index 3fa5fec8..e4968097 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,24 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/cloudflare/circl v1.4.0 h1:BV7h5MgrktNzytKmWjpOtdYrf0lkkbF8YMlBGPhJQrY= github.com/cloudflare/circl v1.4.0/go.mod h1:PDRU+oXvdD7KCtgKxW95M5Z8BpSCJXQORiZFnBQS5QU= +github.com/cloudflare/circl v1.5.0 h1:hxIWksrX6XN5a1L2TI/h53AGPhNHoUBo+TD1ms9+pys= +github.com/cloudflare/circl v1.5.0/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20241203143554-1e3fdc7de467 h1:keEZFtbLJugfE0qHn+Ge1JCE71spzkchQobDf3mzS/4= github.com/google/pprof v0.0.0-20241203143554-1e3fdc7de467/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad h1:a6HEuzUHeKH6hwfN/ZoQgRgVIWFJljSWa/zetS2WTvg= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -19,6 +26,8 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/onsi/ginkgo/v2 v2.22.0 h1:Yed107/8DjTr0lKCNt7Dn8yQ6ybuDRQoMGrNFKzMfHg= github.com/onsi/ginkgo/v2 v2.22.0/go.mod h1:7Du3c42kxCUegi0IImZ1wUQzMBVecgIHjR1C+NkhLQo= github.com/onsi/gomega v1.34.2 h1:pNCwDkzrsv7MS9kpaQvVb1aVLahQXyJ/Tv5oAZMI3i8= @@ -33,26 +42,41 @@ github.com/refraction-networking/utls v1.6.7 h1:zVJ7sP1dJx/WtVuITug3qYUq034cDq9B github.com/refraction-networking/utls v1.6.7/go.mod h1:BC3O4vQzye5hqpmDTWUqi4P5DDhzJfkV1tdqtawQIH0= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo= golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak= +golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e h1:4qufH0hlUYs6AO6XmZC3GqfDPGSXHVXUFR6OND+iJX4= +golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= +golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= +golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.27.0 h1:qEKojBykQkQ4EynWy4S8Weg69NumxKdn40Fce3uc/8o= golang.org/x/tools v0.27.0/go.mod h1:sUi0ZgbwW9ZPAq26Ekut+weQPR5eIM6GQLQ1Yjm1H0Q= +golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8= +golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= From 551efe5944695d710f2413a5699d9d1bd059f520 Mon Sep 17 00:00:00 2001 From: roc Date: Tue, 17 Dec 2024 10:23:04 +0800 Subject: [PATCH 840/843] rename SetBodyStr to SetBodyString(#408) --- response.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/response.go b/response.go index 7f50f617..a8e99025 100644 --- a/response.go +++ b/response.go @@ -202,7 +202,7 @@ func (r *Response) SetBody(body []byte) { } // Set response body with string content -func (r *Response) SetBodyStr(body string) { +func (r *Response) SetBodyString(body string) { r.body = []byte(body) } From 03fab1a258bf63f1a682b82d6bb1fde10ded4694 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 25 Dec 2024 14:52:20 +0800 Subject: [PATCH 841/843] Update golang.org/x/net to v0.33.0 (#409) --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 6bef9175..c54597be 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/quic-go/qpack v0.5.1 github.com/quic-go/quic-go v0.48.2 github.com/refraction-networking/utls v1.6.7 - golang.org/x/net v0.32.0 + golang.org/x/net v0.33.0 golang.org/x/text v0.21.0 ) diff --git a/go.sum b/go.sum index e4968097..98c5fd3b 100644 --- a/go.sum +++ b/go.sum @@ -59,6 +59,8 @@ golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= From 80a65c471b0a4a5b8f978429af686ea43f56a1e9 Mon Sep 17 00:00:00 2001 From: bingoohuang Date: Tue, 4 Mar 2025 19:31:09 +0800 Subject: [PATCH 842/843] feat: Add LocalAddr to TraceInfo to capture local network address --- request.go | 1 + trace.go | 3 +++ 2 files changed, 4 insertions(+) diff --git a/request.go b/request.go index c8bea419..9036d72a 100644 --- a/request.go +++ b/request.go @@ -146,6 +146,7 @@ func (r *Request) TraceInfo() TraceInfo { // Capture remote address info when connection is non-nil if ct.gotConnInfo.Conn != nil { ti.RemoteAddr = ct.gotConnInfo.Conn.RemoteAddr() + ti.LocalAddr = ct.gotConnInfo.Conn.LocalAddr() } return ti diff --git a/trace.go b/trace.go index 1d9926ca..de5f1a6f 100644 --- a/trace.go +++ b/trace.go @@ -102,6 +102,9 @@ type TraceInfo struct { // RemoteAddr returns the remote network address. RemoteAddr net.Addr + + // LocalAddr returns the local network address. + LocalAddr net.Addr } type clientTrace struct { From b77d10510ad34654e12b9ef08a4cefef7e46dd94 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 5 Mar 2025 10:22:00 +0800 Subject: [PATCH 843/843] Add LocalAddr field to TraceInfo.String() --- trace.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/trace.go b/trace.go index de5f1a6f..c87127e9 100644 --- a/trace.go +++ b/trace.go @@ -17,12 +17,14 @@ TLSHandshakeTime : %v FirstResponseTime : %v ResponseTime : %v IsConnReused: : false -RemoteAddr : %v` +RemoteAddr : %v +LocalAddr : %v` traceReusedFmt = `TotalTime : %v FirstResponseTime : %v ResponseTime : %v IsConnReused: : true -RemoteAddr : %v` +RemoteAddr : %v +LocalAddr : %v` ) // Blame return the human-readable reason of why request is slowing. @@ -57,9 +59,9 @@ func (t TraceInfo) String() string { return "trace is not enabled" } if t.IsConnReused { - return fmt.Sprintf(traceReusedFmt, t.TotalTime, t.FirstResponseTime, t.ResponseTime, t.RemoteAddr) + return fmt.Sprintf(traceReusedFmt, t.TotalTime, t.FirstResponseTime, t.ResponseTime, t.RemoteAddr, t.LocalAddr) } - return fmt.Sprintf(traceFmt, t.TotalTime, t.DNSLookupTime, t.TCPConnectTime, t.TLSHandshakeTime, t.FirstResponseTime, t.ResponseTime, t.RemoteAddr) + return fmt.Sprintf(traceFmt, t.TotalTime, t.DNSLookupTime, t.TCPConnectTime, t.TLSHandshakeTime, t.FirstResponseTime, t.ResponseTime, t.RemoteAddr, t.LocalAddr) } // TraceInfo represents the trace information.